# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch
import numpy as np

# Constant term of the Gaussian log-density, 0.5 * log(2 * pi), computed once.
_HALF_LOG_2PI = float(0.5 * np.log(2 * np.pi))


@torch.jit.script
def sample_normal_jit(mu: torch.Tensor, sigma: torch.Tensor):
    """Draw a reparameterized sample z = mu + sigma * rho with rho ~ N(0, I).

    Returns the pair (z, rho). ``rho`` is a tensor distinct from ``z`` that
    holds the raw standard-normal noise, so ``rho * sigma + mu`` reproduces
    ``z``. (Computing z in place would make both returned values alias the
    same tensor and lose the noise.)
    """
    # Fresh noise with mu's shape/dtype/device; it does not require grad, so
    # gradients flow to mu and sigma only through the affine transform below.
    rho = torch.randn_like(mu)
    z = rho * sigma + mu
    return z, rho


class Normal:
    """Diagonal Gaussian parameterized by a mean and a (log) standard deviation.

    Args:
        mu: mean tensor.
        log_sigma: log standard deviation. May be None when ``sigma`` is given,
            in which case it is derived as ``log(sigma)``.
        sigma: optional standard deviation. When omitted it is computed as
            ``exp(log_sigma)``.

    Raises:
        TypeError: if neither ``log_sigma`` nor ``sigma`` is provided.
    """

    def __init__(self, mu, log_sigma, sigma=None):
        if sigma is None:
            if log_sigma is None:
                raise TypeError("Normal requires either log_sigma or sigma")
            sigma = torch.exp(log_sigma)
        elif log_sigma is None:
            # Keep log_p usable when the caller only supplies sigma.
            log_sigma = torch.log(sigma)
        self.mu = mu
        self.log_sigma = log_sigma
        self.sigma = sigma

    def sample(self, t=1.):
        """Return (z, rho): a sample with std scaled by ``t`` and its N(0, I) noise."""
        return sample_normal_jit(self.mu, self.sigma * t)

    def sample_given_rho(self, rho):
        """Map pre-drawn standard-normal noise ``rho`` to a sample of this distribution."""
        return rho * self.sigma + self.mu

    def mean(self):
        """Return the mean ``mu``."""
        return self.mu

    def log_p(self, samples):
        """Element-wise Gaussian log-density of ``samples`` (not summed over any axis)."""
        normalized_samples = (samples - self.mu) / self.sigma
        log_p = -0.5 * normalized_samples * normalized_samples - _HALF_LOG_2PI - self.log_sigma
        return log_p