Source code for zhusuan.utils

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import math

import torch


def log_mean_exp(x, dim=None, keepdims=False):
    """
    Numerically stable log mean of exps across the `dim`.

    Computes ``log(mean(exp(x)))`` over the given dimensions as
    ``logsumexp(x) - log(N)``, where ``N`` is the number of elements reduced
    into each output entry. Using ``torch.logsumexp`` avoids overflow for
    large inputs and yields ``-inf`` (not NaN) when every reduced entry is
    ``-inf``.

    :param x: A Tensor.
    :param dim: An int or list or tuple. The dimensions to reduce. If `None`
        (the default), reduces all dimensions. Negative indices are allowed.
        An empty list/tuple reduces nothing and returns a copy of `x`.
    :param keepdims: Bool. If true, retains reduced dimensions with length 1.
        Default to be False.

    :return: A Tensor after the computation of log mean exp along given axes
        of x.
    """
    # Normalize `dim` to a tuple of ints, which torch.logsumexp accepts
    # (torch.max, used previously, only accepted a single int and no None).
    if dim is None:
        dims = tuple(range(x.dim()))
    elif isinstance(dim, int):
        dims = (dim,)
    else:
        dims = tuple(dim)

    if not dims:
        # Nothing to reduce (e.g. a 0-d tensor with dim=None):
        # log(mean(exp(x))) over zero dimensions is x itself.
        return x.clone()

    lse = torch.logsumexp(x, dim=dims, keepdim=keepdims)
    if lse.numel() == 0:
        # Output is empty (a non-reduced dimension has size 0); nothing to scale.
        return lse

    # Elements folded into each output entry; works for either keepdims value.
    count = x.numel() // lse.numel()
    # Mean over an empty reduction is undefined; match torch.mean, which gives NaN.
    log_count = math.log(count) if count > 0 else float("nan")
    return lse - log_count