import torch
__all__ = [
'Distribution',
]
[docs]class Distribution(object):
"""
The :class:`Distribution` class is the base class for various probabilistic
distributions which support batch inputs, generating batches of samples and
evaluate probabilities at batches of given values.
The typical input shape for a :class:`Distribution` is like
``batch_shape + input_shape``. where ``input_shape`` represents the shape
of non-batch input parameter, :attr:`batch_shape` represents how many
independent inputs are fed into the distribution.
Samples generated are of shape
``([n_samples]+ )batch_shape + value_shape``. The first additional axis
is omitted only when passed `n_samples` is None (by default), in which
case one sample is generated. :attr:`value_shape` is the non-batch value
shape of the distribution. For a univariate distribution, its
:attr:`value_shape` is [].
There are cases where a batch of random variables are grouped into a
single event so that their probabilities should be computed together. This
is achieved by setting `group_ndims` argument, which defaults to 0.
The last `group_ndims` number of axes in :attr:`batch_shape` are
grouped into a single event. For example,
``Normal(..., group_ndims=1)`` will set the last axis of its
:attr:`batch_shape` to a single event, i.e., a multivariate Normal with
identity covariance matrix.
When evaluating probabilities at given values, the given Tensor should be
broadcastable to shape ``(... + )batch_shape + value_shape``. The returned
Tensor has shape ``(... + )batch_shape[:-group_ndims]``.
.. seealso::
For more details and examples, please refer to
:doc:`/tutorials/concepts`.
For both, the parameter `dtype` represents type of samples. For discrete,
can be set by user. For continuous, automatically determined from parameter
types.
`dtype` must be among `torch.int16`, `torch.int32`, `torch.int64`,
`torch.float16`, `torch.float32` and `torch.float64`.
When two or more parameters are tensors and they have different type,
`TypeError` will be raised.
:param dtype: The value type of samples from the distribution.
:param is_continuous: Whether the distribution is continuous.
:param is_reparameterized: A bool. Whether the gradients of samples can
and are allowed to propagate back into inputs, using the
reparametrization trick from (Kingma, 2013).
:param use_path_derivative: A bool. Whether when taking the gradients
of the log-probability to propagate them through the parameters
of the distribution (False meaning you do propagate them). This
is based on the paper "Sticking the Landing: Simple,
Lower-Variance Gradient Estimators for Variational Inference"
:param group_ndims: A 0-D `int32` Tensor representing the number of
dimensions in :attr:`batch_shape` (counted from the end) that are
grouped into a single event, so that their probabilities are calculated
together. Default is 0, which means a single value is an event.
See above for more detailed explanation.
"""
def __init__(self,
dtype,
is_continuous,
is_reparameterized,
use_path_derivative=False,
group_ndims=0,
device=torch.device('cpu'),
**kwargs):
self._dtype = dtype
self._is_continuous = is_continuous
self._is_reparameterized = is_reparameterized
self._use_path_derivative = use_path_derivative
self._device = device
if isinstance(group_ndims, int):
if group_ndims < 0:
raise ValueError("group_ndims must be non-negative.")
self._group_ndims = group_ndims
else:
#TODO
pass
@property
def dtype(self):
"""The sample type of the distribution."""
return self._dtype
@property
def device(self):
"""
The device this distribution lies at.
:return: torch.device
"""
return self._device
@property
def is_reparameterized(self):
"""
Whether the gradients of samples can and are allowed to propagate back
into inputs, using the reparametrization trick from (Kingma, 2013).
"""
return self._is_reparameterized
@property
def batch_shape(self):
"""
The shape showing how many independent inputs (which we call batches)
are fed into the distribution. For batch inputs, the shape of a
generated sample is ``batch_shape + value_shape``.
"""
#TODO
return self._batch_shape()
def _batch_shape(self):
"""
Private method for subclasses to rewrite the :attr:`batch_shape`
property.
"""
raise NotImplementedError()
[docs] def sample(self, n_samples=None, **kwargs):
"""
sample(n_samples=None)
Return samples from the distribution. When `n_samples` is None (by
default), one sample of shape ``batch_shape + value_shape`` is
generated. For a scalar `n_samples`, the returned Var has a new
sample dimension with size `n_samples` inserted at ``axis=0``, i.e.,
the shape of samples is ``[n_samples] + batch_shape + value_shape``.
:param n_samples: A 0-D `int32` Tensor or None. How many independent
samples to draw from the distribution.
:return: A Var of samples.
"""
if n_samples is None:
samples = self._sample(n_samples=1, **kwargs)
return samples
elif isinstance(n_samples, int):
return self._sample(n_samples, **kwargs)
else:
#TODO
pass
def _sample(self, n_samples, **kwargs):
"""
Private method for subclasses to rewrite the :meth:`sample` method.
"""
raise NotImplementedError()
[docs] def log_prob(self, given):
"""
log_prob(given)
Compute log probability density (mass) function at `given` value.
:param given: A Var. The value at which to evaluate log probability
density (mass) function. Must be able to broadcast to have a shape
of ``(... + )batch_shape + value_shape``.
:return: A Var of shape ``(... + )batch_shape[:-group_ndims]``.
"""
if given is not None:
given = torch.as_tensor(given, dtype=self.dtype)
log_p = self._log_prob(given)
if self._group_ndims > 0:
return torch.sum(log_p, [i for i in range(-self._group_ndims, 0)])
else:
return log_p
def _log_prob(self, given):
"""
Private method for subclasses to rewrite the :meth:`log_prob` method.
"""
raise NotImplementedError()
[docs] def prob(self, given):
return self._prob(given)
def _prob(self, given):
"""
Private method for subclasses to rewrite the :meth:`prob` method.
"""
raise NotImplementedError()