import torch
from zhusuan.distributions.base import Distribution
class StochasticTensor(object):
    """
    The :class:`StochasticTensor` class represents the stochastic nodes in a
    :class:`~zhusuan.framework.bn.BayesianNet`.

    We can use any distribution available in :mod:`zhusuan.distributions` to
    construct a stochastic node in a :class:`~zhusuan.framework.bn.BayesianNet`.
    For example::

        class Net(BayesianNet):
            def __init__(self):
                self.stochastic_node('Normal', name='x', mean=0., std=1.)

    will build a stochastic node named ``x`` in ``Net`` with the
    :class:`~zhusuan.distributions.normal.Normal` distribution, represented
    by a :class:`StochasticTensor`.

    A :class:`StochasticTensor` is not itself a ``torch.Tensor``. Use the
    :attr:`tensor` property (or :meth:`sample`) to obtain a tensor value that
    can be passed into any PyTorch operation. This makes it easy to build
    Bayesian networks by mixing stochastic nodes and PyTorch primitives.

    .. seealso::

        For more information, please refer to :doc:`/tutorials/concepts`.

    :param bn: A :class:`~zhusuan.framework.bn.BayesianNet`, or ``None``. When
        given, its ``observed`` mapping is consulted for an observed value of
        this node (keyed by `name`).
    :param name: A string. The name of the :class:`StochasticTensor`. Must be
        unique in a :class:`~zhusuan.framework.bn.BayesianNet`.
    :param dist: A :class:`~zhusuan.distributions.base.Distribution`
        instance that determines the distribution used in this stochastic node.
    :param observation: A tensor (or tensor-like value) matching the shape of
        `dist`. If specified, the :class:`StochasticTensor` is observed and
        :attr:`tensor` returns the `observation` converted to :attr:`dtype`,
        unless ``bn.observed`` provides a value for this node (which takes
        precedence).
    :param n_samples: A 0-D integer. Number of samples generated by
        this :class:`StochasticTensor`. Passed to ``dist.sample``.
    :param kwargs: Optional ``reduce_mean_dims`` / ``reduce_sum_dims`` (an int
        or a sequence of ints) and ``multiplier``, all used by
        :meth:`log_prob`. Other keyword arguments are ignored.
    """

    def __init__(self,
                 bn,
                 name: str,
                 dist: "Distribution",
                 observation=None,
                 n_samples=None,
                 **kwargs):
        super(StochasticTensor, self).__init__()
        self._bn = bn
        self._name: str = name
        self._dist: "Distribution" = dist
        self._dtype: torch.dtype = dist.dtype
        self._n_samples = n_samples
        # Keep the dtype-converted observation (the converted value used to
        # be computed and then thrown away).
        self._observation = self._check_observation(observation)
        self._reduce_mean_dims = kwargs.get("reduce_mean_dims", None)
        self._reduce_sum_dims = kwargs.get("reduce_sum_dims", None)
        self._multiplier = kwargs.get("multiplier", None)

    def _check_observation(self, observation):
        """
        Convert `observation` to a tensor with this node's :attr:`dtype`.

        :param observation: ``None`` or a tensor-like value (tensor, array,
            Python number or nested sequence).
        :return: ``None`` if `observation` is ``None``. Otherwise a tensor
            whose dtype is :attr:`dtype`. A tensor that already has that
            dtype is returned as-is (``torch.as_tensor`` does not copy it).
        """
        if observation is None:
            return None
        return torch.as_tensor(observation, dtype=self.dtype)

    def _observed_value(self):
        """
        Return the value this node is bound to, or ``None`` if it has none.

        An entry in ``bn.observed`` takes precedence over the `observation`
        given at construction. A missing ``bn`` (``None``) is treated as
        having no observations.
        """
        if self._bn is not None and self._name in self._bn.observed:
            return self._bn.observed[self._name]
        return self._observation

    @staticmethod
    def _as_dims(dims):
        """Normalize a ``None`` / int / sequence dim spec to a tuple of ints."""
        if dims is None:
            return ()
        if isinstance(dims, int):
            return (dims,)
        return tuple(dims)

    @property
    def bn(self):
        """
        The :class:`~zhusuan.framework.bn.BayesianNet` where the
        :class:`StochasticTensor` lives.

        :return: A :class:`~zhusuan.framework.bn.BayesianNet` instance, or
            ``None`` if the node was built without one.
        """
        return self._bn

    @property
    def name(self):
        """
        The name of the :class:`StochasticTensor`.

        :return: A string.
        """
        return self._name

    @property
    def dtype(self):
        """
        The sample type of the :class:`StochasticTensor`, taken from
        ``dist.dtype``.

        :return: A ``torch.dtype``.
        """
        return self._dtype

    @property
    def dist(self):
        """
        The distribution followed by the :class:`StochasticTensor`.

        :return: A :class:`~zhusuan.distributions.base.Distribution` instance.
        """
        return self._dist

    def is_observed(self):
        """
        Whether the :class:`StochasticTensor` is observed or not.

        A node is observed when ``bn.observed`` has a value for it or an
        `observation` was given at construction. This is the same rule
        :attr:`tensor` and :meth:`sample` use.

        :return: A bool.
        """
        return self._observed_value() is not None

    @property
    def tensor(self):
        """
        The value of this :class:`StochasticTensor`. If it is observed, then
        the observation is returned, otherwise samples are drawn from
        :attr:`dist`.

        Equivalent to ``self.sample()``. Each unobserved access draws anew.

        :return: A ``torch.Tensor``.
        """
        return self.sample()

    def sample(self, force=False):
        """
        The value of this :class:`StochasticTensor`. If it is observed, then
        the observation is returned, otherwise samples are returned.

        When an observed value is returned, it is also stored as
        ``dist.sample_cache``.

        :param force: If True, draw from :attr:`dist` even when an observed
            value exists. Defaults to False.
        :return: A ``torch.Tensor``.
        """
        if not force:
            observed = self._observed_value()
            if observed is not None:
                self._dist.sample_cache = observed
                return observed
        return self._dist.sample(n_samples=self._n_samples)

    @property
    def shape(self):
        """
        Return the shape of this :class:`StochasticTensor`'s value.

        Note: this evaluates :attr:`tensor`, so an unobserved node draws a
        fresh sample on every access.

        :return: A ``torch.Size`` instance.
        """
        return self.tensor.shape

    def get_shape(self):
        """
        Alias of :attr:`shape`.

        :return: A ``torch.Size`` instance.
        """
        return self.shape

    def log_prob(self, sample=None):
        """
        Compute the log probability of `sample` under :attr:`dist`, with
        optional reduction and scaling.

        If ``reduce_mean_dims`` and/or ``reduce_sum_dims`` were given at
        construction, the result is first averaged over the mean dims, then
        summed over the sum dims. Each reduced dim is then removed from the
        result, so a full reduction yields a 0-D tensor. Dims may be negative
        and may appear in both lists. If ``multiplier`` was given (including
        ``0``), the result is finally multiplied by it.

        :param sample: The value to evaluate. It is passed unchanged to
            ``dist.log_prob``.
        :return: A ``torch.Tensor`` of (possibly reduced and scaled) log
            probabilities.
        """
        log_probs = self._dist.log_prob(sample)
        mean_dims = self._as_dims(self._reduce_mean_dims)
        sum_dims = self._as_dims(self._reduce_sum_dims)

        # Reduce with keepdim=True so dim indices stay valid for the second
        # reduction. The reduced dims are squeezed out afterwards.
        if mean_dims:
            log_probs = torch.mean(log_probs, mean_dims, keepdim=True)
        if sum_dims:
            log_probs = torch.sum(log_probs, sum_dims, keepdim=True)

        reduced = mean_dims + sum_dims
        if reduced:
            ndim = max(log_probs.dim(), 1)
            # Normalize negative indices and drop duplicates. Then squeeze
            # from the highest index down, so that earlier squeezes do not
            # shift the positions of dims still to be removed.
            to_squeeze = sorted({dim % ndim for dim in reduced}, reverse=True)
            for dim in to_squeeze:
                log_probs = torch.squeeze(log_probs, dim)

        # `is not None` so that multiplier=0 (or a tensor) is honoured.
        if self._multiplier is not None:
            log_probs = log_probs * self._multiplier
        return log_probs