# Source code for zhusuan.invertible.base

from __future__ import absolute_import
from __future__ import division
import torch.nn as nn

__all__ = [
    "RevNet"
]


class RevNet(nn.Module):
    """Abstract base class for reversible (invertible) networks.

    Subclasses must override both ``_forward`` and ``_inverse``. Each hook
    should return a pair ``(y, log_det_J)``, where ``y`` is the transformed
    tensor and ``log_det_J`` is the log-determinant of the Jacobian of the
    transformation.
    """

    def _forward(self, *inputs, **kwargs):
        """Forward transformation; must be provided by subclasses."""
        raise NotImplementedError

    def _inverse(self, *inputs, **kwargs):
        """Inverse transformation; must be provided by subclasses."""
        raise NotImplementedError

    def forward(self, *inputs, reverse=False, **kwargs):
        """Run the transformation in the requested direction.

        ``model.forward(x)`` / ``model.forward(x, reverse=False)`` delegates
        to ``_forward(x)``; ``model.forward(x, reverse=True)`` delegates to
        ``_inverse(x)``. All other positional and keyword arguments are
        passed through to the selected hook unchanged.

        Args:
            *inputs: Positional arguments forwarded to the hook.
            reverse: Keyword-only flag; a truthy value selects ``_inverse``.
            **kwargs: Keyword arguments forwarded to the hook.

        Returns:
            Whatever the selected hook returns, expected to be
            ``(y, log_det_J)``.
        """
        # Only the selected hook is looked up and called.
        direction = self._inverse if reverse else self._forward
        return direction(*inputs, **kwargs)