38 lines
1.2 KiB
Python
38 lines
1.2 KiB
Python
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
#
|
|
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
|
# and proprietary rights in and to this software, related documentation
|
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
# distribution of this software and related documentation without an express
|
|
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
|
import torch
|
|
import numpy as np
|
|
|
|
@torch.jit.script
def sample_normal_jit(mu, sigma):
    """Draw a reparameterized Gaussian sample ``z = mu + sigma * rho``.

    Args:
        mu: Mean tensor; the noise ``rho`` is drawn with this tensor's shape,
            dtype and device.
        sigma: Standard-deviation tensor, broadcastable against ``mu``.

    Returns:
        ``(z, rho)`` where ``rho`` is the standard-normal noise actually used
        and ``z = rho * sigma + mu``. The two are distinct tensors, so
        ``rho`` can later be fed back to recompute ``z`` for other
        ``mu``/``sigma``.
    """
    # Standard-normal noise with mu's shape/dtype/device; it does not carry
    # a dependency on mu in the autograd graph.
    rho = torch.randn_like(mu)
    # Out-of-place on purpose: the previous in-place ``rho.mul_(...).add_(...)``
    # overwrote the noise, so the returned ``rho`` was the sample itself.
    z = rho * sigma + mu
    return z, rho
|
|
|
|
class Normal:
    """Elementwise Gaussian defined by a mean and a (log) standard deviation.

    All operations act entry by entry; ``log_p`` returns per-entry
    log-densities without any reduction.

    Args:
        mu: Mean tensor.
        log_sigma: Log standard deviation. May be ``None`` when ``sigma`` is
            supplied, in which case it is derived as ``log(sigma)``.
        sigma: Standard deviation. When ``None`` it is computed as
            ``exp(log_sigma)``.

    Raises:
        ValueError: if both ``log_sigma`` and ``sigma`` are ``None``.
    """

    def __init__(self, mu, log_sigma=None, sigma=None):
        if log_sigma is None and sigma is None:
            raise ValueError("Normal requires at least one of log_sigma or sigma")
        self.mu = mu
        # log_p reads log_sigma directly, so derive it when only sigma is given.
        self.log_sigma = torch.log(sigma) if log_sigma is None else log_sigma
        # When both are given, the caller's sigma is used as-is; it is not
        # checked for consistency with exp(log_sigma).
        self.sigma = torch.exp(log_sigma) if sigma is None else sigma

    def sample(self, t=1.):
        """Draw a sample with the standard deviation scaled by ``t``.

        Args:
            t: Scale factor applied to ``sigma`` before sampling
                (``1.`` leaves it unchanged).

        Returns:
            The ``(z, rho)`` pair produced by ``sample_normal_jit`` for mean
            ``mu`` and standard deviation ``sigma * t``.
        """
        return sample_normal_jit(self.mu, self.sigma * t)

    def sample_given_rho(self, rho):
        """Return ``mu + sigma * rho`` for caller-provided noise ``rho``.

        The unscaled ``sigma`` is used; no ``t`` factor is applied here.
        """
        return rho * self.sigma + self.mu

    def mean(self):
        """Return the mean tensor ``mu``."""
        return self.mu

    def log_p(self, samples):
        """Return the per-entry Gaussian log-density of ``samples``.

        Computes ``-0.5 * ((x - mu) / sigma)**2 - 0.5 * log(2*pi) - log_sigma``
        elementwise, with no summation over any dimension.
        """
        normalized_samples = (samples - self.mu) / self.sigma
        log_p = - 0.5 * normalized_samples * normalized_samples - 0.5 * np.log(2 * np.pi) - self.log_sigma
        return log_p
|
|
|
|
|