Source code for zhusuan.distributions.normal

# -*- coding: utf-8 -*-
import torch
import numpy as np

from zhusuan.distributions.base import Distribution
from zhusuan.distributions.utils import (
    assert_same_float_dtype,
    assert_same_log_float_dtype,
    check_broadcast
)


class Normal(Distribution):
    """
    The class of univariate Normal distribution.
    See :class:`~zhusuan.distributions.base.Distribution` for details.

    :param mean: A `float` Var. The mean of the Normal distribution.
        Should be broadcastable to match `std` or `logstd`.
    :param std: A `float` Var. The standard deviation of the Normal
        distribution. Should be positive and broadcastable to match `mean`.
    :param logstd: A `float` Var. The log standard deviation of the Normal
        distribution. Should be broadcastable to match `mean`.
    :param dtype: Optional dtype passed to :func:`torch.as_tensor` when
        converting `mean` and `std` / `logstd` to tensors.
    :param is_continuous: A Bool, forwarded to the base class.
    :param group_ndims: A 0-D `int32` Var representing the number of
        dimensions in `batch_shape` (counted from the end) that are grouped
        into a single event, so that their probabilities are calculated
        together. Default is 0, which means a single value is an event.
        See :class:`~zhusuan.distributions.base.Distribution` for more
        detailed explanation.
    :param is_reparameterized: A Bool. If True, gradients on samples from
        this distribution are allowed to propagate into inputs, using the
        reparametrization trick from (Kingma, 2013).
    :param device: The device that `mean` and `std` are moved to.
    :param use_path_derivative: A bool. Whether when taking the gradients
        of the log-probability to propagate them through the parameters
        of the distribution (False meaning you do propagate them). This is
        based on the paper "Sticking the Landing: Simple, Lower-Variance
        Gradient Estimators for Variational Inference". It is not a named
        parameter of this constructor; like any other extra keyword it is
        forwarded via ``**kwargs`` to the base class.

    Exactly one of `std` and `logstd` must be given, otherwise a
    :class:`ValueError` is raised.
    """

    def __init__(self,
                 mean=0.,
                 std=None,
                 logstd=None,
                 dtype=None,
                 is_continuous=True,
                 is_reparameterized=True,
                 group_ndims=0,
                 device=torch.device('cpu'),
                 **kwargs):
        self._mean: torch.Tensor = torch.as_tensor(mean, dtype=dtype).to(device)
        if (logstd is None) == (std is None):
            raise ValueError(
                "Either `std` or `logstd` should be passed. It is not allowed "
                "that both are specified or both are not.")
        elif std is None:
            # Only `logstd` was given: store the distribution by its std.
            self._std: torch.Tensor = torch.exp(
                torch.as_tensor(logstd, dtype=dtype)).to(device)
        else:  # logstd is None
            self._std: torch.Tensor = torch.as_tensor(std, dtype=dtype).to(device)
        # Validate shapes and dtypes of the parameters.
        check_broadcast(self._std, self._mean)
        dtype = assert_same_log_float_dtype(
            [(self._mean, "Normal.mean"), (self._std, "Normal.std")])
        super(Normal, self).__init__(dtype=dtype,
                                     is_continuous=is_continuous,
                                     is_reparameterized=is_reparameterized,
                                     group_ndims=group_ndims,
                                     device=device,
                                     **kwargs)

    @property
    def mean(self):
        """The mean of the Normal distribution."""
        return self._mean

    @property
    def logstd(self):
        """The log standard deviation of the Normal distribution."""
        return torch.log(self._std)

    @property
    def std(self):
        """The standard deviation of the Normal distribution."""
        return self._std

    def _batch_shape(self):
        """Broadcast shape of `mean` and `std`."""
        return torch.broadcast_shapes(self.mean.shape, self.std.shape)

    def _sample(self, n_samples=1):
        """Draw samples and store them in ``self.sample_cache``.

        :param n_samples: Number of samples. When greater than 1, a leading
            dimension of size `n_samples` is prepended to the batch shape;
            otherwise the result has exactly the batch shape.
        :return: The sampled tensor.
        """
        batch_shape = self._batch_shape()
        # Expand both parameters to the common batch shape so mismatched but
        # broadcastable `mean` / `std` shapes sample correctly.
        mean = self._mean.expand(batch_shape)
        std = self._std.expand(batch_shape)
        if n_samples > 1:
            shape = torch.Size([n_samples]) + batch_shape
        else:
            shape = batch_shape

        if not self.is_reparameterized:
            sample = torch.normal(mean.expand(shape), std.expand(shape))
        else:
            # Draw the noise directly in the parameters' dtype and device to
            # avoid a float32 CPU tensor plus a host-to-device copy.
            epsilon = torch.randn(shape, dtype=std.dtype, device=std.device)
            sample = mean + std * epsilon
        self.sample_cache = sample
        return sample

    def _log_prob(self, sample=None):
        """Element-wise log-density of `sample`.

        :param sample: Values to evaluate; defaults to ``self.sample_cache``.
            Any extra leading dimensions (e.g. a sample dimension) are
            handled by broadcasting against `mean` and `std`.
        :return: Tensor of log-densities with the broadcast shape.
        """
        if sample is None:
            sample = self.sample_cache
        if not torch.is_tensor(sample):
            sample = torch.as_tensor(sample, dtype=self._mean.dtype,
                                     device=self._mean.device)
        logstd = torch.log(self._std)
        # Python float constant: it takes on the tensors' dtype, so float64
        # inputs keep full precision (a float32 tensor constant would not).
        c = -0.5 * float(np.log(2 * np.pi))
        z = (sample - self._mean) / self._std
        return c - logstd - 0.5 * z ** 2

    def _prob(self, given):
        """Element-wise density of `given`, i.e. ``exp(_log_prob(given))``."""
        return torch.exp(self._log_prob(given))