# Copyright 2018 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cumsum bijector."""

from tensorflow_probability.python.internal.backend.jax.compat import v2 as tf
from tensorflow_probability.substrates.jax.bijectors import bijector
from tensorflow_probability.substrates.jax.internal import prefer_static as ps

__all__ = [
    'Cumsum',
]


class Cumsum(bijector.AutoCompositeTensorBijector):
  """Bijector that takes a running (cumulative) sum along a single axis.

  The sum runs over `axis`, which must be a negative Python `int`. By default
  it is the rightmost dimension (`axis=-1`). The inverse undoes the running
  sum by taking first differences along the same axis. The forward map is
  linear with unit diagonal, so its log-det-Jacobian is identically zero.

  #### Example

  ```python
  b = tfb.Cumsum()

  b.forward([[1., 1.],
             [2., 2.],
             [3., 3.]])
  # ==> [[1., 2.],
  #      [2., 4.],
  #      [3., 6.]]

  b = tfb.Cumsum(axis=-2)

  b.forward([[1., 1.],
             [2., 2.],
             [3., 3.]])
  # ==> [[1., 1.],
  #      [3., 3.],
  #      [6., 6.]]
  ```

  """

  def __init__(self, axis=-1, validate_args=False, name='cumsum'):
    """Builds a `Cumsum` bijector.

    Args:
      axis: Negative Python `int` selecting the dimension the running sum is
        taken over. Zero and positive values are rejected.
      validate_args: Python `bool`; when `True`, arguments are checked for
        correctness.
      name: Python `str` used as the name scope for ops created here.

    Raises:
      TypeError: if `axis` is not an `int`.
      ValueError: if `axis` is not negative.
    """
    # Must be the first statement so only the constructor arguments are
    # captured.
    parameters = dict(locals())
    with tf.name_scope(name) as name:
      # Type check precedes the sign check so non-ints never reach `>=`.
      if not isinstance(axis, int):
        raise TypeError(
            'Argument `axis` is not an `int` type; got {}'.format(axis))
      if axis >= 0:
        raise ValueError(
            'Argument `axis` must be negative; got {}'.format(axis))
      self._axis = axis
      super().__init__(
          # `axis < 0` was enforced above, so this is a positive count of the
          # trailing dimensions the bijector acts on.
          forward_min_event_ndims=-axis,
          is_constant_jacobian=True,
          validate_args=validate_args,
          parameters=parameters,
          name=name)

  @classmethod
  def _parameter_properties(cls, dtype):
    # No tensor-valued parameters: `axis` is a static Python int.
    return {}

  @property
  def axis(self):
    """The (negative) axis along which the cumulative sum is computed."""
    return self._axis

  def _forward(self, x):
    return tf.cumsum(x, axis=self.axis)

  def _inverse(self, y):
    rank = ps.rank(y)
    # The cumsum axis expressed as a non-negative index.
    axis_index = rank + self.axis
    # Slice sizes equal y's shape except one shorter along the cumsum axis,
    # which drops the final entry there.
    slice_sizes = ps.shape(y) - ps.one_hot(axis_index, rank, dtype=tf.int32)
    truncated = ps.slice(y, ps.zeros(rank, dtype=tf.int32), slice_sizes)
    # The marker is 0 on the cumsum axis and -1 elsewhere. One-hot encoding
    # it with depth 2 yields rows of [1, 0] on that axis and rows of [0, 0]
    # elsewhere, because the out-of-range index -1 encodes to all zeros.
    # Padding with these rows prepends a single zero along the cumsum axis.
    marker = ps.one_hot(axis_index, rank, on_value=0, off_value=-1)
    paddings = ps.one_hot(marker, 2, dtype=tf.int32)
    shifted = ps.pad(truncated, paddings=paddings)
    # Subtracting the shifted copy recovers first differences along the axis.
    return y - shifted

  def _forward_log_det_jacobian(self, x):
    # The Jacobian is triangular with ones on the diagonal, so log|det| is 0.
    return tf.constant(0., dtype=x.dtype)

  @property
  def _composite_tensor_shape_params(self):
    return ('axis',)


# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
# This file is auto-generated by substrates/meta/rewrite.py
# It will be surfaced by the build system as a symlink at:
#   `tensorflow_probability/substrates/jax/bijectors/cumsum.py`
# For more info, see substrate_runfiles_symlinks in build_defs.bzl
# @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
