# Source code for zhusuan.distributions.utils

# -*- coding: utf-8 -*-
import torch

# Floating-point dtypes accepted where a "float" tensor is required.
floating_dtypes = (
    torch.float32,
    torch.float16,
    torch.float64,
)

# Floating-point dtypes accepted where log/exp operations are needed
# (float16 is deliberately left out of this subset).
log_floating_dtypes = (
    torch.float32,
    torch.float64,
)

# Integer dtypes accepted where an "integer" tensor is required.
integer_dtypes = (
    torch.int32,
    torch.int16,
    torch.int64,
)

# Each integer dtype paired with the floating dtype of the same bit width.
int2float_mapping = dict([
    (torch.int32, torch.float32),
    (torch.int64, torch.float64),
    (torch.int16, torch.float16),
])


def assert_same_dtype_in(tensors_with_name, dtypes=None):
    """
    Check that all tensors in `tensors_with_name` share the same dtype and
    that this dtype is one of the allowed `dtypes`.

    :param tensors_with_name: An iterable of (tensor, tensor_name) pairs.
    :param dtypes: A sequence (or any iterable) of allowed dtypes. If `None`
        or empty, then all dtypes are allowed.
    :return: The common dtype of the tensors, or `None` if
        `tensors_with_name` is empty.
    :raises TypeError: If a tensor's dtype is not in `dtypes`, or differs
        from the dtype of the first tensor.
    """
    # Materialize non-sequence iterables (sets, generators) once so they can
    # be measured, indexed and reported in error messages. Lists and tuples
    # are kept as-is so existing error messages are unchanged.
    if dtypes is not None and not isinstance(dtypes, (list, tuple)):
        dtypes = tuple(dtypes)
    dtypes_set = set(dtypes) if dtypes else None

    # Remember the first (tensor, name) pair rather than indexing
    # `tensors_with_name[0]`, so any iterable of pairs is supported.
    first = None
    for tensor, tensor_name in tensors_with_name:
        if dtypes_set is not None and tensor.dtype not in dtypes_set:
            if len(dtypes) == 1:
                raise TypeError(
                    '{}({}) must have dtype {}.'.format(
                        tensor_name, tensor.dtype, dtypes[0]))
            raise TypeError(
                '{}({}) must have a dtype in {}.'.format(
                    tensor_name, tensor.dtype, dtypes))
        if first is None:
            first = (tensor, tensor_name)
        elif tensor.dtype != first[0].dtype:
            tensor0, tensor0_name = first
            raise TypeError(
                '{}({}) must have the same dtype as {}({}).'.format(
                    tensor_name, tensor.dtype, tensor0_name, tensor0.dtype))
    return None if first is None else first[0].dtype
def assert_same_float_dtype(tensors_with_name):
    """
    Check that every tensor in `tensors_with_name` has one common dtype
    drawn from `floating_dtypes`.

    :param tensors_with_name: A list of (tensor, tensor_name).
    :return: The shared dtype, as returned by `assert_same_dtype_in`.
    :raises TypeError: Propagated from `assert_same_dtype_in` when a dtype
        is not floating or the dtypes disagree.
    """
    allowed = floating_dtypes
    return assert_same_dtype_in(tensors_with_name, dtypes=allowed)
def assert_same_log_float_dtype(tensors_with_name):
    """
    Check that every tensor in `tensors_with_name` has one common dtype
    drawn from `log_floating_dtypes`, i.e. a floating type that also
    supports log/exp operations.

    :param tensors_with_name: A list of (tensor, tensor_name).
    :return: The shared dtype, as returned by `assert_same_dtype_in`.
    :raises TypeError: Propagated from `assert_same_dtype_in` when a dtype
        is not allowed or the dtypes disagree.
    """
    allowed = log_floating_dtypes
    return assert_same_dtype_in(tensors_with_name, dtypes=allowed)
def check_broadcast(mean, std):
    """
    Check that the shapes of `mean` and `std` are broadcast-compatible.

    Only the shapes are compared: no elementwise result is allocated and
    nothing is recorded in the autograd graph. Device and dtype
    compatibility of the two arguments are not checked here.

    :param mean: A tensor, or anything accepted by `torch.as_tensor`
        (e.g. a Python number).
    :param std: A tensor, or anything accepted by `torch.as_tensor`.
    :return: The broadcast shape, as a `torch.Size`.
    :raises RuntimeError: If the two shapes cannot be broadcast together.
    """
    # `torch.as_tensor` returns tensors unchanged, and wraps Python scalars
    # so that `mean + std` style scalar inputs keep working.
    mean_shape = torch.as_tensor(mean).shape
    std_shape = torch.as_tensor(std).shape
    return torch.broadcast_shapes(mean_shape, std_shape)