"""zhusuan.framework.stochastic_tensor: stochastic nodes of a BayesianNet."""

import torch
from zhusuan.distributions.base import Distribution


class StochasticTensor(object):
    """
    The :class:`StochasticTensor` class represents the stochastic nodes in a
    :class:`~zhusuan.framework.bn.BayesianNet`.

    We can use any distribution available in :mod:`zhusuan.distributions` to
    construct a stochastic node in a :class:`~zhusuan.framework.bn.BayesianNet`.
    For example::

        class Net(BayesianNet):
            def __init__(self):
                x = self.stochastic_node('Normal', name='x', mean=0., std=1.)

    will build a stochastic node in ``Net`` with the
    :class:`~zhusuan.distributions.normal.Normal` distribution. The returned
    ``x`` will be an instance of :class:`StochasticTensor`.

    A :class:`StochasticTensor` is not itself a ``torch.Tensor``. Use the
    :attr:`tensor` property (or :meth:`sample`) to obtain its value, which can
    then be mixed freely with PyTorch operations.

    .. seealso::

        For more information, please refer to :doc:`/tutorials/concepts`.

    :param bn: A :class:`~zhusuan.framework.bn.BayesianNet`, or None. When it
        is None, :attr:`tensor` / :meth:`sample` fall back to `observation`.
    :param name: A string. The name of the :class:`StochasticTensor`. Must be
        unique in a :class:`~zhusuan.framework.bn.BayesianNet`.
    :param dist: A :class:`~zhusuan.distributions.base.Distribution` instance
        that determines the distribution used in this stochastic node.
    :param observation: A tensor-like value, which should match the shape of
        `dist`. Anything that is not already a ``torch.Tensor`` of
        :attr:`dtype` is converted with ``torch.as_tensor``. If specified, the
        node is observed (see :meth:`is_observed`).
    :param n_samples: Passed as ``n_samples`` to ``dist.sample`` whenever this
        node draws samples.
    :param reduce_mean_dims: (keyword) An int or a list of ints. Dimensions
        averaged over in :meth:`log_prob`.
    :param reduce_sum_dims: (keyword) An int or a list of ints. Dimensions
        summed over in :meth:`log_prob`.
    :param multiplier: (keyword) If not None, :meth:`log_prob` results are
        multiplied by it.
    """

    # Sentinel meaning "no observed value"; distinct from any real value.
    _UNOBSERVED = object()

    def __init__(self, bn, name: str, dist: "Distribution", observation=None,
                 n_samples=None, **kwargs):
        super(StochasticTensor, self).__init__()
        self._bn = bn
        self._name: str = name
        self._dist: "Distribution" = dist
        # Must be set before _check_observation, which reads self.dtype.
        self._dtype: torch.dtype = dist.dtype
        self._n_samples = n_samples
        # Keep the dtype-converted observation, not the raw argument.
        self._observation = self._check_observation(observation)
        # Optional post-processing applied by log_prob().
        self._reduce_mean_dims = kwargs.get("reduce_mean_dims", None)
        self._reduce_sum_dims = kwargs.get("reduce_sum_dims", None)
        self._multiplier = kwargs.get("multiplier", None)

    def _check_observation(self, observation):
        """
        Convert `observation` to a tensor of :attr:`dtype`.

        :param observation: None or a tensor-like value (tensor, array, list,
            scalar).
        :return: None if `observation` is None; `observation` itself if it is
            already a tensor of :attr:`dtype`; otherwise the result of
            ``torch.as_tensor(observation, dtype=self.dtype)``.
        """
        if observation is None:
            return None
        if torch.is_tensor(observation) and observation.dtype == self.dtype:
            return observation
        return torch.as_tensor(observation, dtype=self.dtype)

    @staticmethod
    def _as_dim_list(dims):
        """Normalize a reduction-dims option (None, int or iterable of ints) to a list."""
        if dims is None:
            return []
        if isinstance(dims, int):
            return [dims]
        return list(dims)

    def _observed_value(self):
        """
        Look up the observed value of this node.

        With a BayesianNet, the value is ``bn.observed[name]`` when present.
        Without one, it is the `observation` given at construction.

        :return: The observed value, or ``_UNOBSERVED`` if there is none.
        """
        if self._bn is None:
            if self._observation is None:
                return self._UNOBSERVED
            return self._observation
        observed = self._bn.observed
        if self._name in observed:
            return observed[self._name]
        return self._UNOBSERVED

    @property
    def bn(self):
        """
        The :class:`~zhusuan.framework.bn.BayesianNet` where the
        :class:`StochasticTensor` lives (may be None).

        :return: A :class:`~zhusuan.framework.bn.BayesianNet` instance or None.
        """
        return self._bn

    @property
    def name(self):
        """
        The name of the :class:`StochasticTensor`.

        :return: A string.
        """
        return self._name

    @property
    def dtype(self):
        """
        The sample type of the :class:`StochasticTensor`, taken from
        ``dist.dtype``.

        :return: A ``torch.dtype`` instance.
        """
        return self._dtype

    @property
    def dist(self):
        """
        The distribution followed by the :class:`StochasticTensor`.

        :return: A :class:`~zhusuan.distributions.base.Distribution` instance.
        """
        return self._dist

    def is_observed(self):
        """
        Whether an `observation` was given when this node was constructed.

        Note: this only inspects the constructor's `observation`. Observations
        registered on :attr:`bn` (which :attr:`tensor` uses) are not
        consulted here.

        :return: A bool.
        """
        return self._observation is not None

    @property
    def tensor(self):
        """
        The value of this :class:`StochasticTensor`; same as ``sample()``.

        If the node is observed, the observation is returned. Otherwise
        ``dist.sample(n_samples=...)`` is called.

        :return: A tensor.
        """
        return self.sample()

    def sample(self, force=False):
        """
        The value of this :class:`StochasticTensor`.

        If it is observed (and `force` is False), the observation is stored in
        ``dist.sample_cache`` and returned. Otherwise samples are drawn with
        ``dist.sample(n_samples=...)``.

        :param force: If True, always sample and ignore any observed value.
            Defaults to False.
        :return: A tensor.
        """
        value = self._UNOBSERVED if force else self._observed_value()
        if value is self._UNOBSERVED:
            return self._dist.sample(n_samples=self._n_samples)
        self._dist.sample_cache = value
        return value

    @property
    def shape(self):
        """
        The shape of :attr:`tensor`.

        For an unobserved node this calls ``dist.sample`` on every access.

        :return: A ``torch.Size`` instance.
        """
        return self.tensor.shape

    def get_shape(self):
        """
        Alias of :attr:`shape`.

        :return: A ``torch.Size`` instance.
        """
        return self.shape

    def log_prob(self, sample=None):
        """
        Log density of `sample` under :attr:`dist`, post-processed by the
        options given at construction.

        Steps, in order:

        1. mean over ``reduce_mean_dims`` (if any), keeping dimensions;
        2. sum over ``reduce_sum_dims`` (if any), keeping dimensions;
        3. squeeze out every reduced dimension (a fully reduced result becomes
           a 0-D tensor);
        4. multiply by ``multiplier`` if it is not None (0 is honoured).

        :param sample: Passed unchanged to ``dist.log_prob``.
        :return: A tensor of log probabilities.
        """
        log_probs = self._dist.log_prob(sample)
        mean_dims = self._as_dim_list(self._reduce_mean_dims)
        sum_dims = self._as_dim_list(self._reduce_sum_dims)
        if mean_dims:
            log_probs = torch.mean(log_probs, tuple(mean_dims), keepdim=True)
        if sum_dims:
            log_probs = torch.sum(log_probs, tuple(sum_dims), keepdim=True)
        if mean_dims or sum_dims:
            # keepdim=True preserved the rank, so every requested dim still
            # indexes the original layout. Normalize negative indices and
            # deduplicate, then squeeze from the highest index down so that
            # removing one axis never shifts the ones still to be removed.
            ndim = max(log_probs.dim(), 1)
            squeeze_dims = sorted({d % ndim for d in mean_dims + sum_dims},
                                  reverse=True)
            for d in squeeze_dims:
                log_probs = torch.squeeze(log_probs, d)
        if self._multiplier is not None:
            log_probs = log_probs * self._multiplier
        return log_probs