Source code for zhusuan.distributions.bernoulli

import torch
import numpy as np

from zhusuan.distributions.base import Distribution
from zhusuan.distributions.utils import (
    assert_same_log_float_dtype
)


class Bernoulli(Distribution):
    """
    The class of univariate Bernoulli distribution.
    See :class:`~zhusuan.distributions.base.Distribution` for details.

    Exactly one of ``logits`` and ``probs`` must be given; the other one is
    derived from it and both are stored on ``device``.

    :param logits: A `float` Tensor. The log-odds of probabilities of being 1.

        .. math:: \\mathrm{logits} = \\log \\frac{p}{1 - p}

    :param probs: A `float` Tensor. The p param of bernoulli distribution.
    :param dtype: Optional torch dtype used when converting ``logits`` /
        ``probs`` to tensors (``None`` keeps the input's own dtype).
        NOTE(review): the dtype forwarded to the base class is the value
        returned by ``assert_same_log_float_dtype`` for ``probs``, not this
        argument -- confirm this is the intended meaning of ``dtype``.
    :param is_continuous: Forwarded unchanged to the base class constructor.
        Default is `False`.
    :param group_ndims: A 0-D `int32` Tensor representing the number of
        dimensions in `batch_shape` (counted from the end) that are grouped
        into a single event, so that their probabilities are calculated
        together. Default is 0, which means a single value is an event.
        See :class:`~zhusuan.distributions.base.Distribution` for more
        detailed explanation.
    :param device: The `torch.device` the parameter tensors are moved to.
        Default is the CPU.
    :param kwargs: Extra keyword arguments forwarded to the base class
        constructor.

    :raises ValueError: If both or neither of ``logits`` and ``probs`` are
        passed.
    """

    def __init__(self,
                 logits=None,
                 probs=None,
                 dtype=None,
                 is_continuous=False,
                 group_ndims=0,
                 device=torch.device('cpu'),
                 **kwargs):
        if (logits is None) == (probs is None):
            raise ValueError(
                "Either `probs` or `logits` should be passed. It is not allowed "
                "that both are specified or both are not.")

        if logits is None:
            probs_t = torch.as_tensor(probs, dtype=dtype).to(device)
            # `1 - probs_t` is evaluated on probs_t's own device; building a
            # separate `torch.ones(...)` would live on the default device and
            # break for non-CPU devices.
            logits_t = torch.log(probs_t / (1 - probs_t))
        else:  # probs is None
            # Derive both parameters from the same tensor so that the stored
            # logits and probs always agree in dtype and device.
            logits_t = torch.as_tensor(logits, dtype=dtype).to(device)
            assert_same_log_float_dtype([(logits_t, "Bernoulli.logits")])
            probs_t = torch.sigmoid(logits_t)

        self._probs: torch.Tensor = probs_t
        self._logits: torch.Tensor = logits_t

        # The dtype handed to the base class is taken from `probs`
        # (upstream note: dtype of probs must be float32 or float64).
        dtype = assert_same_log_float_dtype([(self._probs, "Bernoulli.probs")])

        super().__init__(dtype,
                         is_continuous,
                         # reparameterization trick is not applied for
                         # Bernoulli distribution
                         is_reparameterized=False,
                         group_ndims=group_ndims,
                         device=device,
                         **kwargs)

    @property
    def probs(self):
        """The probabilities of being 1, as stored on ``device``."""
        return self._probs

    @property
    def logits(self):
        """The log-odds of being 1, as stored on ``device``."""
        return self._logits

    def _batch_shape(self):
        """Return the shape of the ``probs`` tensor."""
        return self.probs.shape

    def _sample(self, n_samples: int = 1, **kwargs):
        """
        Draw samples with ``torch.bernoulli`` and cache them.

        :param n_samples: Number of samples. When greater than 1 the result
            has shape ``(n_samples, *batch_shape)``; otherwise it has the
            shape of ``probs`` (no leading sample dimension is added).
        :param kwargs: Ignored.
        :return: The sampled tensor, also stored in ``self.sample_cache``.
        """
        if n_samples > 1:
            target_shape = (int(n_samples), *(int(d) for d in self.batch_shape))
            # `expand` yields a broadcast view on the same device and with the
            # same dtype as `probs`; no dense tensor of ones is allocated.
            probs = self._probs.expand(target_shape)
        else:
            probs = self._probs
        drawn = torch.bernoulli(probs)
        self.sample_cache = drawn
        return drawn

    def _log_prob(self, sample=None):
        """
        Compute the element-wise log-probability of ``sample``.

        :param sample: Tensor of 0/1 values broadcastable against ``probs``.
            When ``None``, the most recent result of :meth:`_sample`
            (``self.sample_cache``) is used.
        :return: ``sample * log(p + 1e-8) + (1 - sample) * log(1 - p + 1e-8)``,
            broadcast over ``sample`` and ``probs``.
        :raises ValueError: If ``sample`` is ``None`` and no cached sample is
            available.
        """
        if sample is None:
            sample = getattr(self, "sample_cache", None)
            if sample is None:
                raise ValueError(
                    "`sample` was not given and no cached sample is available; "
                    "draw a sample first or pass one explicitly.")
        probs = self._probs
        # Small constant keeps log() finite when probs is exactly 0 or 1.
        eps = 1e-8
        # Standard broadcasting aligns `probs` with any number of leading
        # sample dimensions, so no manual expansion of `probs` is needed.
        log_prob = (sample * torch.log(probs + eps)
                    + (1 - sample) * torch.log(1 - probs + eps))
        return log_prob

    def _prob(self, given):
        """Return the element-wise probability of ``given``, i.e. ``exp(_log_prob(given))``."""
        return torch.exp(self._log_prob(given))