LION/models/utils.py
2023-01-23 00:14:49 -05:00

53 lines
2 KiB
Python

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
import math
import torch.nn as nn
def mask_inactive_variables(x, is_active):
    """Zero out the inactive entries of ``x``.

    Computes ``x * is_active`` and returns the product; the name ``x`` is
    rebound rather than updated with ``*=``.

    Args:
        x: Values to mask.
        is_active: Multiplier applied to ``x``. Presumably a 0/1 mask that
            broadcasts against ``x`` (TODO confirm with callers).

    Returns:
        The element-wise product ``x * is_active``.
    """
    return x * is_active
class PositionalEmbedding(nn.Module):
    """Sinusoidal embedding of scalar timesteps.

    Each timestep ``t`` is first multiplied by ``scale`` and then mapped to
    ``[sin(t * f_0), ..., sin(t * f_{h-1}), cos(t * f_0), ..., cos(t * f_{h-1})]``,
    where ``h = embedding_dim // 2``. For ``h >= 2`` the frequencies ``f_i``
    are geometrically spaced from 1 down to 1/10000. For ``h == 1`` the single
    frequency is 1.

    Args:
        embedding_dim: Number of output features. When odd, a zero column is
            appended so the output always has exactly ``embedding_dim``
            features.
        scale: Multiplier applied to the timesteps before embedding.
    """

    def __init__(self, embedding_dim, scale):
        super(PositionalEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.scale = scale

    def forward(self, timesteps):
        """Embed a 1-D tensor of timesteps.

        Args:
            timesteps: 1-D tensor of shape ``(N,)``.

        Returns:
            Tensor of shape ``(N, embedding_dim)``.
        """
        assert len(timesteps.shape) == 1
        timesteps = timesteps * self.scale
        half_dim = self.embedding_dim // 2
        # With half_dim == 1 the denominator would be zero; clamp it so a
        # single frequency exp(0) == 1 is used instead of raising.
        log_step = math.log(10000) / max(half_dim - 1, 1)
        # Create the frequencies directly on the input's device rather than
        # allocating on the CPU and copying them over on every call.
        freqs = torch.exp(torch.arange(half_dim, device=timesteps.device) * -log_step)
        emb = timesteps[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if self.embedding_dim % 2 == 1:
            # The sin/cos halves only cover 2 * half_dim features; pad the
            # last column with zeros so the width matches embedding_dim.
            emb = nn.functional.pad(emb, (0, 1))
        return emb
class RandomFourierEmbedding(nn.Module):
    """Random Fourier-feature embedding of scalar timesteps.

    A fixed (non-trainable) row of frequencies ``w`` with shape
    ``(1, embedding_dim // 2)`` is drawn as standard-normal samples
    multiplied by ``scale``. A timestep ``t`` maps to
    ``[sin(2*pi*t*w), cos(2*pi*t*w)]``.

    Args:
        embedding_dim: Number of output features. When odd, a zero column is
            appended so the output always has exactly ``embedding_dim``
            features.
        scale: Multiplier applied to the standard-normal frequency draws.
    """

    def __init__(self, embedding_dim, scale):
        super(RandomFourierEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        # Kept as a frozen nn.Parameter (not a buffer) so the existing
        # state_dict / parameters() layout is unchanged.
        self.w = nn.Parameter(torch.randn(size=(1, embedding_dim // 2)) * scale, requires_grad=False)

    def forward(self, timesteps):
        """Embed a 1-D tensor of timesteps.

        Args:
            timesteps: 1-D tensor of shape ``(N,)``; any numeric dtype.

        Returns:
            Tensor of shape ``(N, embedding_dim)`` in the dtype of ``w``.
        """
        # torch.mm requires both operands to share a dtype, so integer or
        # float64 timesteps are cast to the frequency dtype first.
        t = timesteps[:, None].to(self.w.dtype)
        emb = torch.mm(t, self.w * 2 * math.pi)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        # Fall back to the width implied by w for instances created before
        # embedding_dim was stored (e.g. whole-module pickles).
        embedding_dim = getattr(self, "embedding_dim", 2 * self.w.shape[1])
        if embedding_dim % 2 == 1:
            # Pad the missing last feature so the width matches embedding_dim.
            emb = nn.functional.pad(emb, (0, 1))
        return emb
def init_temb_fun(embedding_type, embedding_scale, embedding_dim):
    """Build the timestep-embedding module selected by ``embedding_type``.

    Args:
        embedding_type: ``'positional'`` selects ``PositionalEmbedding``;
            ``'fourier'`` selects ``RandomFourierEmbedding``.
        embedding_scale: Forwarded as ``scale`` to the chosen module.
        embedding_dim: Forwarded as ``embedding_dim`` to the chosen module.

    Returns:
        The constructed embedding module.

    Raises:
        NotImplementedError: If ``embedding_type`` is not one of the above.
    """
    if embedding_type == 'positional':
        return PositionalEmbedding(embedding_dim, embedding_scale)
    if embedding_type == 'fourier':
        return RandomFourierEmbedding(embedding_dim, embedding_scale)
    # Name the rejected value so configuration typos are easy to spot.
    raise NotImplementedError(
        f"unknown embedding_type {embedding_type!r}; expected 'positional' or 'fourier'"
    )