LION/models/distributions.py
2023-01-23 00:14:49 -05:00

38 lines
1.2 KiB
Python

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
import numpy as np
@torch.jit.script
def sample_normal_jit(mu: torch.Tensor, sigma: torch.Tensor):
    """Draw a reparameterized sample from a Gaussian with mean ``mu`` and std ``sigma``.

    Returns:
        (z, rho): ``z = rho * sigma + mu`` is the sample, and ``rho`` is the
        standard-normal noise (same shape as ``mu``) that produced it.
    """
    # Fresh standard-normal noise. It must stay a separate tensor from ``z``:
    # the previous in-place ``rho.mul_(sigma).add_(mu)`` overwrote the noise
    # and returned the same tensor twice, so callers never saw the real rho.
    rho = torch.randn_like(mu)
    # Out-of-place reparameterization: gradients flow to mu and sigma, and
    # sigma may broadcast against mu.
    z = rho * sigma + mu
    return z, rho
class Normal:
    """Element-wise Gaussian defined by a mean and a log standard deviation.

    Args:
        mu: mean tensor.
        log_sigma: log of the standard deviation; used directly by ``log_p``.
        sigma: optional precomputed standard deviation. When omitted it is
            derived as ``exp(log_sigma)``. NOTE(review): a supplied ``sigma``
            is not checked against ``log_sigma``; callers must keep the two
            consistent or ``log_p`` will disagree with ``sample``.
    """

    def __init__(self, mu, log_sigma, sigma=None):
        self.mu = mu
        self.log_sigma = log_sigma
        if sigma is None:
            sigma = torch.exp(log_sigma)
        self.sigma = sigma

    def sample(self, t=1.):
        """Draw a reparameterized sample with the std scaled by temperature ``t``.

        Delegates to ``sample_normal_jit`` with mean ``mu`` and scale
        ``sigma * t`` and returns its result unchanged.
        """
        scaled_sigma = self.sigma * t
        return sample_normal_jit(self.mu, scaled_sigma)

    def sample_given_rho(self, rho):
        """Map externally supplied noise ``rho`` to a sample: ``rho * sigma + mu``."""
        stretched = rho * self.sigma
        return stretched + self.mu

    def mean(self):
        """Return the mean tensor ``mu`` given at construction."""
        return self.mu

    def log_p(self, samples):
        """Element-wise log density of ``samples``; no reduction is applied.

        Computes ``-0.5 * ((x - mu) / sigma)**2 - 0.5 * log(2*pi) - log_sigma``.
        """
        z_score = (samples - self.mu) / self.sigma
        quadratic_term = -0.5 * z_score * z_score
        half_log_two_pi = 0.5 * np.log(2 * np.pi)
        return quadratic_term - half_log_two_pi - self.log_sigma