"""Invertible (log-)scaling layer (``zhusuan.invertible.scaling``)."""

from __future__ import absolute_import
from __future__ import division
import torch
import torch.nn as nn

from zhusuan.invertible.base import RevNet

# Names exported by ``from zhusuan.invertible.scaling import *``.
__all__ = ["Scaling"]


class Scaling(RevNet):
    """
    Initialize a (log-)scaling layer.

    In the forward pass, given an input tensor ``x``, it returns
    (``y``, ``log_det_J``), where ``y = x * exp(log_scale)`` is the
    transformed tensor and ``log_det_J`` is the log-determinant of the
    Jacobian of that transformation.

    :param dim: input/output dimensions.
    """

    def __init__(self, dim):
        super(Scaling, self).__init__()
        # Learnable log scale of shape [1, dim], initialized to zeros so that
        # exp(log_scale) == 1 and the layer starts out as the identity map.
        # Presumably broadcast against inputs of shape (batch, dim) — confirm.
        self.log_scale = nn.Parameter(torch.zeros([1, dim]), requires_grad=True)

    def _forward(self, x, **kwargs):
        """
        Scale ``x`` element-wise by ``exp(log_scale)``.

        :param x: input tensor; assumed to broadcast against ``log_scale``
            (shape [1, dim]) — TODO confirm against callers.
        :return: tuple (``y``, ``log_det_J``), where ``y = x * exp(log_scale)``
            and ``log_det_J = sum(log_scale)`` is a scalar tensor.
        """
        log_det_J = torch.sum(self.log_scale)
        # Out-of-place multiply so the caller's tensor is left untouched;
        # in-place ops on a leaf tensor that requires grad raise in autograd.
        y = x * torch.exp(self.log_scale)
        return y, log_det_J

    def _inverse(self, y, **kwargs):
        """
        Undo the scaling: divide ``y`` element-wise by ``exp(log_scale)``.

        :param y: transformed tensor; assumed to broadcast against
            ``log_scale`` (shape [1, dim]) — TODO confirm against callers.
        :return: tuple (``x``, ``log_det_J``), where
            ``x = y * exp(-log_scale)`` and ``log_det_J = sum(log_scale)``.
        """
        # NOTE(review): this returns +sum(log_scale), the same value as the
        # forward pass, whereas the log-determinant of the inverse map itself
        # is -sum(log_scale). Left unchanged because callers of RevNet may
        # depend on this convention — confirm before changing the sign.
        log_det_J = torch.sum(self.log_scale)
        # Out-of-place multiply so the caller's tensor is left untouched.
        x = y * torch.exp(-self.log_scale)
        return x, log_det_J