Source code for zhusuan.distributions.flow_distribution

import torch

from zhusuan.distributions import Distribution

__all__ = [
    "FlowDistribution",
]


[docs]class FlowDistribution(Distribution): """ A class for sample from Flow networks by provide the latent distribution and the flow network, when calling `sample` method, it returns the sample from flow network, when calling `log_prob` method it return the loss item of flow network. :param latents: An instance of `Distribution` class, as the prior(or the latent variable)of FlowDistrubution :param transformation: A RevNet instance, the Flow net work built by user :param flow_kwargs: additional info to be recode :param dtype: data type :param device: device of Distribution """ def __init__(self, latents, transformation, flow_kwargs=None, dtype=torch.float32, group_ndims=0, device=torch.device('cpu'), **kwargs): self._latents = latents self._transformation = transformation self._flow_kwargs = flow_kwargs super(FlowDistribution, self).__init__( dtype=dtype, is_continuous=True, is_reparameterized=False, group_ndims=group_ndims, device=device, **kwargs ) def _sample(self, n_samples=-1, **kwargs): if n_samples != -1: z = self._latents.sample(n_samples) else: # if n_samples == -1, then return None as no sample return None x, _ = self._transformation.forward(z, reverse=True, **kwargs) return x def _log_prob(self, *given, **kwargs): z, log_det_J = self._transformation.forward(*given, **kwargs, reverse=False) log_ll = torch.sum(self._latents.log_prob(z), dim=1) return log_ll + log_det_J