Source code for zhusuan.variational.importance_weighted_objective

import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from zhusuan.framework.stochastic_tensor import StochasticTensor

from zhusuan import log_mean_exp

__all__ = [
    'iw_objective',
    'ImportanceWeightedObjective',
]


def compute_iw_term(x, axis):
    """
    Build the importance weighted surrogate term in a numerically stable way.

    The log-weights are shifted by their maximum along ``axis`` before being
    exponentiated, then normalized along ``axis``. The normalized weights are
    detached from the autograd graph, so gradients flow only through the
    ``x`` factor of the weighted sum.

    :param x: A Tensor of log importance weights, i.e. the ``log(w)`` term.
    :param axis: The sample axis along which the weights are normalized and
        the weighted sum is taken.

    :return: A Tensor with ``axis`` reduced, holding the sum over samples of
        ``normalized_weight * x``.
    """
    # Shift by the per-slice maximum so that exp() cannot overflow.
    peak = x.max(dim=axis, keepdim=True).values
    unnormalized = torch.exp(x - peak)
    normalizer = torch.sum(unnormalized, dim=axis, keepdim=True)
    # The normalized weights act as constants during back-propagation.
    weights = torch.div(unnormalized, normalizer).detach()
    return torch.sum(weights * x, dim=axis)


class ImportanceWeightedObjective(nn.Module):
    """
    The class that represents the importance weighted objective for
    variational inference (Burda, 2015).

    As a variational objective, :class:`ImportanceWeightedObjective` provides
    two gradient estimators for the variational (proposal) parameters:

    * :meth:`sgvb`: The Stochastic Gradient Variational Bayes (SGVB)
      estimator, also known as "the reparameterization trick", or "path
      derivative estimator".
    * :meth:`vimco`: The multi-sample score function estimator with variance
      reduction, also known as "VIMCO".

    :param generator: generator part of importance weighted objective
    :param variational: variational part of importance weighted objective
    :param axis: The sample dimension to reduce when computing the outer
        expectation in the objective. Must be specified.
    :param estimator: the estimator, a str in either 'sgvb' or 'vimco'

    :raises ValueError: If ``axis`` is None.
    :raises NotImplementedError: If ``estimator`` is not supported.
    """

    def __init__(self, generator, variational, axis=None, estimator='sgvb'):
        super().__init__()
        self.generator = generator
        self.variational = variational
        if axis is None:
            raise ValueError(
                "ImportanceWeightedObjective is a multi-sample objective, "
                "the `axis` argument must be specified.")
        self._axis = axis
        supported_estimator = ['sgvb', 'vimco']
        if estimator not in supported_estimator:
            raise NotImplementedError(
                f"Unsupported estimator {estimator!r}, "
                f"expected one of {supported_estimator}.")
        self.estimator = estimator

    def log_joint(self, nodes):
        """
        Sum the ``log_prob()`` of every node in ``nodes``.

        The sum is accumulated out-of-place, so no node's own log-probability
        tensor is modified and terms of different (broadcastable) shapes are
        added together rather than being dropped.

        :param nodes: A dict mapping node names to objects that expose a
            ``log_prob()`` method.

        :return: The sum of all log-probabilities, or None when ``nodes`` is
            empty.
        """
        log_joint_ = None
        for node in nodes.values():
            log_prob = node.log_prob()
            if log_joint_ is None:
                log_joint_ = log_prob
            else:
                log_joint_ = log_joint_ + log_prob
        return log_joint_

    def forward(self, observed, reduce_mean=True):
        """
        Compute the surrogate cost for the given observations.

        The variational network is run on ``observed``; the ``tensor`` of
        each of its nodes is merged with ``observed`` (values in ``observed``
        take precedence on key clashes) and fed to the generator. The log
        joint of both networks is then passed to the selected estimator.

        :param observed: A dict of observed values, keyed by node name.
        :param reduce_mean: Whether to average the cost into a scalar.

        :return: A Tensor. The surrogate cost for optimizers to minimize.
        :raises ValueError: If the estimator is 'vimco' and a variational
            `StochasticTensor` node is reparameterized.
        """
        # Feed the observations forward through the variational network.
        self.variational(observed)
        nodes_q: dict = self.variational.nodes

        _v_inputs = {}
        for k, v in nodes_q.items():
            _v_inputs[k] = v.tensor
            if self.estimator == "vimco" \
                    and isinstance(v, StochasticTensor) \
                    and v.dist.is_reparameterized:
                raise ValueError(
                    "with vimco estimator, the is_reparameterized must be false")

        _observed = {**_v_inputs, **observed}
        nodes_p = self.generator(_observed).nodes

        logpxz = self.log_joint(nodes_p)
        logqz = self.log_joint(nodes_q)

        if self.estimator == 'sgvb':
            return self.sgvb(logpxz, logqz, reduce_mean)
        # self.estimator == 'vimco'
        return self.vimco(logpxz, logqz, reduce_mean)

    def sgvb(self, logpxz, logqz, reduce_mean=True):
        """
        Implements the stochastic gradient variational bayes (SGVB) gradient
        estimator for the objective, also known as "reparameterization trick"
        or "path derivative estimator". It was first used for importance
        weighted objectives in (Burda, 2015), where it's named "IWAE".

        It only works for latent `StochasticTensor` s that can be
        reparameterized (Kingma, 2013). For example,
        :class:`~zhusuan.distribution.Normal`
        and :class:`~zhusuan.framework.stochastic.Concrete`.

        .. note::

            To use the :meth:`sgvb` estimator, the ``is_reparameterized``
            property of each latent `StochasticTensor` must be True (which is
            the default setting when they are constructed).

        :param logpxz: Log joint of the generator.
        :param logqz: Log joint of the variational network.
        :param reduce_mean: Whether to average the cost into a scalar.

        :return: A Tensor. The surrogate cost for optimizers to minimize.
        """
        log_w = logpxz - logqz
        if self._axis is not None:
            lower_bound = compute_iw_term(log_w, self._axis)
        else:
            lower_bound = log_w

        if reduce_mean:
            return torch.mean(-lower_bound)
        return -lower_bound

    def vimco(self, logpxz, logqz, reduce_mean=True):
        """
        Implements the multi-sample score function gradient estimator for
        the objective, also known as "VIMCO", which is named
        by authors of the original paper (Minh, 2016).

        It works for all kinds of latent `StochasticTensor` s.

        .. note::

            To use the :meth:`vimco` estimator, the ``is_reparameterized``
            property of each reparameterizable latent `StochasticTensor` must
            be set False.

        :param logpxz: Log joint of the generator.
        :param logqz: Log joint of the variational network.
        :param reduce_mean: Whether to average the cost into a scalar. When
            False, the cost with only the sample axis reduced is returned.

        :return: A Tensor. The surrogate cost for optimizers to minimize.
        :raises ValueError: If the size of ``logpxz - logqz`` along the
            sample axis is smaller than 2 or the axis is not valid for it.
        """
        log_w = logpxz - logqz

        # Check the size along the sample axis.
        err_msg = "VIMCO is a multi-sample gradient estimator, size along " \
                  "`axis` in the objective should be larger than 1."
        try:
            n_samples = log_w.shape[self._axis]
        except (AttributeError, IndexError, TypeError) as err:
            raise ValueError(err_msg) from err
        if n_samples < 2:
            raise ValueError(err_msg)

        n_dim = log_w.dim()
        last = n_dim - 1
        # Normalize so that a negative sample axis is handled as well.
        axis = self._axis % n_dim

        # For every sample, the mean of the *other* samples' log-weights
        # along the sample axis (the sample's own value is left out).
        mean_except_signal = \
            (torch.sum(log_w, dim=axis, keepdim=True) - log_w) / (n_samples - 1)

        # Move the sample axis to the last position. Swapping two axes is
        # its own inverse, so the same swap restores the layout afterwards.
        x = log_w.transpose(axis, last)
        sub_x = mean_except_signal.transpose(axis, last)

        # x_ex[..., i, j] equals x[..., i] for i != j and sub_x[..., i] on
        # the diagonal: column j is the sample set with sample j replaced by
        # the mean of the others.
        x_ex = x.unsqueeze(-1) - torch.diag_embed(x) + torch.diag_embed(sub_x)
        control_variate = log_mean_exp(x_ex, last).transpose(axis, last)

        # Variance-reduced learning signal, used as a constant multiplier.
        l_signal = log_mean_exp(log_w, axis, keepdims=True) - control_variate
        fake_term = torch.sum(logqz * l_signal.detach(), axis)
        iw_term = compute_iw_term(log_w, axis)
        cost = -fake_term - iw_term

        if reduce_mean:
            return cost.mean()
        return cost
# Lower-case alias exported through ``__all__``. It is bound to the class
# itself so that ``iw_objective(...)`` constructs an
# ImportanceWeightedObjective; the previous value was a 1-tuple holding the
# class name as a string, which could not be called.
iw_objective = ImportanceWeightedObjective