Source code for zhusuan.invertible.coupling

import torch
import math
import torch.nn as nn
from .base import RevNet
from .sequential import RevSequential


def get_coupling_mask(n_dim, n_channel, n_mask, split_type="OddEven", dtype=torch.float32):
    """
    Build a list of alternating binary coupling masks.

    The first mask is selected by ``split_type``; every following mask is the
    complement (``1 - mask``) of the one before it, so consecutive masks swap
    which entries are 0 and which are 1.

    :param n_dim: length of each 1-D mask.
    :param n_channel: number of channels; only ``1`` is implemented.
    :param n_mask: number of masks to return.
    :param split_type: ``"OddEven"`` (pattern 0, 1, 0, 1, ...), ``"Half"``
        (the first ``n_dim // 2`` entries are 0, the remaining ones are 1) or
        ``"RandomHalf"`` (each entry drawn independently from {0, 1}).
    :param dtype: dtype of every returned mask.
    :return: a list of ``n_mask`` tensors of shape ``(n_dim,)``.
    :raises NotImplementedError: if ``n_channel`` is not 1.
    :raises ValueError: if ``split_type`` is not one of the supported names.
    """
    if n_channel != 1:
        raise NotImplementedError()

    if split_type == "OddEven":
        mask = torch.arange(n_dim, dtype=dtype) % 2
    elif split_type == "Half":
        half_len = n_dim // 2
        # Pass dtype explicitly so this branch honours the requested dtype
        # just like the other split types do.
        zeros = torch.zeros(half_len, dtype=dtype)
        ones = torch.ones(n_dim - half_len, dtype=dtype)
        mask = torch.cat([zeros, ones], -1)
    elif split_type == "RandomHalf":
        mask = torch.randint(0, 2, (n_dim,), dtype=dtype)
    else:
        # Fail loudly instead of silently handing back an empty list.
        raise ValueError(
            "Unknown split_type {!r}; expected 'OddEven', 'Half' or "
            "'RandomHalf'".format(split_type)
        )

    masks = []
    for _ in range(n_mask):
        masks.append(mask)
        # Integer literal keeps the mask's dtype for both float and int dtypes.
        mask = 1 - mask
    return masks


class MaskCoupling(RevNet):
    """
    A coupling layer.
    The mask decides which entries are kept as they are and which are
    transformed: entries where the mask is 1 pass through unchanged and are
    fed to the inner network, entries where the mask is 0 are shifted by the
    network output.

    :param in_out_dim: input/output dimensions.
    :param mid_dim: number of units in a hidden layer
    :param hidden: number of hidden layers
    :param mask: mask given by the user, often generated by
        function:`~zhusuan.invertible.coupling.get_coupling_mask`
    :param inner_nn: optional network used in place of the default MLP. It is
        called on the masked input and its output is used as the shift. When
        given, ``in_out_dim``, ``mid_dim`` and ``hidden`` are not used.
    """

    def __init__(self, in_out_dim=-1, mid_dim=-1, hidden=-1, mask=None, inner_nn=None):
        super(MaskCoupling, self).__init__()
        if inner_nn is None:
            # Assemble the layers in a local list so that no plain Python
            # list is ever stored on the ``nn`` attribute.
            layers = [nn.Linear(in_out_dim, mid_dim), nn.ReLU()]
            for _ in range(hidden - 1):
                layers.append(nn.Linear(mid_dim, mid_dim))
                layers.append(nn.ReLU())
            layers.append(nn.Linear(mid_dim, in_out_dim))
            self.nn = nn.Sequential(*layers)
        else:
            self.nn = inner_nn

        if isinstance(mask, torch.Tensor) and not isinstance(mask, nn.Parameter):
            # Register as a buffer so the mask follows ``.to(device)`` and
            # dtype casts; non-persistent keeps ``state_dict`` keys unchanged.
            # NOTE(review): assumes RevNet derives from torch.nn.Module - confirm.
            self.register_buffer("mask", mask, persistent=False)
        else:
            self.mask = mask

    def _split_and_shift(self, t):
        """
        Split ``t`` with the mask and compute the shift for the transformed part.

        :param t: input tensor, multiplied elementwise with the mask.
        :return: ``(kept, moved, shift)`` where ``kept = mask * t``,
            ``moved = (1 - mask) * t`` and ``shift`` is the network output
            for ``kept`` restricted to the transformed entries.
        """
        inv_mask = 1. - self.mask
        kept, moved = self.mask * t, inv_mask * t
        shift = self.nn(kept) * inv_mask
        return kept, moved, shift

    def _forward(self, x, **kwargs):
        """
        Add the shift to the entries selected by ``1 - mask``.

        :param x: input tensor.
        :return: tuple ``(y, None)``; the second element is always ``None``.
        """
        kept, moved, shift = self._split_and_shift(x)
        return kept + (moved + shift), None

    def _inverse(self, y, **kwargs):
        """
        Subtract the shift from the entries selected by ``1 - mask``.

        :param y: tensor to invert.
        :return: tuple ``(x, None)``; the second element is always ``None``.
        """
        kept, moved, shift = self._split_and_shift(y)
        return kept + (moved - shift), None
class Coupling(RevNet):
    """
    coupling layer class

    The ``W`` features of a ``(B, W)`` input are viewed as ``W // 2`` adjacent
    pairs. One element of every pair is left unchanged and fed to an MLP; the
    MLP output is added to (forward) or subtracted from (inverse) the other
    element of the pair.

    :param in_out_dim: input/output dimensions; the MLP maps
        ``in_out_dim // 2`` features to ``in_out_dim // 2`` features.
    :param mid_dim: number of units in a hidden layer.
    :param hidden: number of hidden layers.
    :param mask_config: if truthy, the first element of each pair (positions
        0, 2, 4, ...) is transformed and the second element conditions the
        shift; otherwise the roles are swapped.
    """

    def __init__(self, in_out_dim, mid_dim, hidden, mask_config):
        super(Coupling, self).__init__()
        self.mask_config = mask_config

        self.in_block = nn.Sequential(
            nn.Linear(in_out_dim // 2, mid_dim),
            nn.ReLU())
        self.mid_block = nn.ModuleList([
            nn.Sequential(
                nn.Linear(mid_dim, mid_dim),
                nn.ReLU())
            for _ in range(hidden - 1)])
        self.out_block = nn.Linear(mid_dim, in_out_dim // 2)

    def _split_units(self, x):
        """Return ``(on, off)``: the units to transform and the units to condition on."""
        batch, width = x.size()
        pairs = x.reshape((batch, width // 2, 2))
        first, second = pairs[:, :, 0], pairs[:, :, 1]
        if self.mask_config:
            return first, second
        return second, first

    def _compute_shift(self, off):
        """Run the conditioning units through the MLP to obtain the shift."""
        hidden_act = self.in_block(off)
        for block in self.mid_block:
            hidden_act = block(hidden_act)
        return self.out_block(hidden_act)

    def _merge_units(self, on, off):
        """Interleave ``on`` and ``off`` back into a ``(B, W)`` tensor."""
        ordered = (on, off) if self.mask_config else (off, on)
        pairs = torch.stack(ordered, dim=2)
        # Each of on/off is (B, W // 2), so the flattened width is twice that.
        return pairs.reshape((pairs.size(0), -1))

    def _forward(self, x, **kwargs):
        """
        Forward pass: add the shift to the transformed units.

        :param x: input tensor of shape ``(B, W)``.

        Returns: tuple ``(transformed tensor, None)``.
        """
        on, off = self._split_units(x)
        on = on + self._compute_shift(off)
        return self._merge_units(on, off), None

    def _inverse(self, x, **kwargs):
        """
        Inverse pass: subtract the shift from the transformed units.

        :param x: input tensor of shape ``(B, W)``.

        Returns: tuple ``(transformed tensor, None)``.
        """
        on, off = self._split_units(x)
        on = on - self._compute_shift(off)
        return self._merge_units(on, off), None