# Source code for zhusuan.distributions.uniform

import torch
from zhusuan.distributions import Distribution
from zhusuan.distributions.utils import (
    assert_same_log_float_dtype,
    check_broadcast
)


class Uniform(Distribution):
    """
    The class of univariate Uniform distribution.
    See :class:`~zhusuan.distributions.base.Distribution` for details.

    :param low: A 'float' Var. Lower range (inclusive).
    :param high: A 'float' Var. Upper range (exclusive).
    :param dtype: Optional torch dtype used when converting `low` and `high`
        to tensors. The dtype handed to the base class is the one returned by
        ``assert_same_log_float_dtype`` for `low` and `high`.
    :param is_continuous: Forwarded to the base ``Distribution``.
    :param is_reparameterized: If True, samples are a differentiable affine
        transform of unit-interval noise, so gradients flow back to `low` and
        `high`. If False, returned samples are detached from the graph.
    :param group_ndims: Forwarded to the base ``Distribution``.
    :param device: Device that `low` and `high` are moved to; also forwarded
        to the base ``Distribution``.
    """

    def __init__(self,
                 low,
                 high,
                 dtype=None,
                 is_continuous=True,
                 is_reparameterized=True,
                 group_ndims=0,
                 device=torch.device('cpu'),
                 **kwargs):
        self._low = torch.as_tensor(low, dtype=dtype).to(device)
        self._high = torch.as_tensor(high, dtype=dtype).to(device)
        # Validate parameter shapes/dtypes via zhusuan.distributions.utils
        # helpers (presumably raising on incompatible inputs -- see utils).
        check_broadcast(self.low, self.high)
        dtype = assert_same_log_float_dtype([(self._low, "Uniform.low"),
                                             (self._high, "Uniform.high")])
        super(Uniform, self).__init__(dtype,
                                      is_continuous,
                                      is_reparameterized,
                                      group_ndims=group_ndims,
                                      device=device,
                                      **kwargs)

    @property
    def low(self):
        """Lower range (inclusive) of the Uniform distribution."""
        return self._low

    @property
    def high(self):
        """Upper range (exclusive) of the Uniform distribution."""
        return self._high

    def _batch_shape(self):
        # Batch shape is the broadcast of the two parameter shapes.
        return torch.broadcast_shapes(self.low.shape, self.high.shape)

    def _sample(self, n_samples=1):
        """
        Draw samples as ``low + u * (high - low)`` with ``u ~ U[0, 1)``.

        :param n_samples: Number of samples. When greater than 1, a leading
            dimension of size `n_samples` is prepended to the batch shape;
            otherwise the result has exactly the batch shape.
        :return: A tensor of samples, also stored in ``self.sample_cache``.
        """
        low = torch.as_tensor(self._low, dtype=self._dtype)
        high = torch.as_tensor(self._high, dtype=self._dtype)
        # Use the broadcast batch shape (not just low's shape) so every
        # element gets its own independent noise even when low and high
        # have different shapes.
        shape = self._batch_shape()
        if n_samples > 1:
            shape = torch.Size([n_samples]) + shape
        # Allocate the noise on the parameters' device so the affine
        # transform below does not mix CPU and GPU tensors.
        u = torch.rand(shape, dtype=self._dtype, device=low.device)
        sample = low + u * (high - low)
        if not self.is_reparameterized:
            # Non-reparameterized samples must not propagate gradients to
            # the parameters.
            sample = sample.detach()
        # Cache the sample actually returned so that _log_prob(None)
        # scores the same values the caller received.
        self.sample_cache = sample
        return sample

    def _log_prob(self, sample=None):
        """
        Log density of `sample` under ``U[low, high)``.

        :param sample: Tensor to score. Defaults to ``self.sample_cache``
            (the most recent output of ``_sample``). Leading sample
            dimensions are handled by broadcasting against `low`/`high`.
        :return: Element-wise log-probabilities as computed by
            ``torch.distributions.uniform.Uniform.log_prob``.
        """
        if sample is None:
            sample = self.sample_cache
        # torch's Uniform broadcasts low/high with each other and with
        # `sample`, so the parameters need no manual repetition.
        return torch.distributions.uniform.Uniform(
            self._low, self._high).log_prob(sample)

    def _prob(self, given):
        """Density of `given`, computed as ``exp(_log_prob(given))``."""
        return torch.exp(self._log_prob(given))