# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. import torch import math import torch.nn as nn def mask_inactive_variables(x, is_active): x = x * is_active return x class PositionalEmbedding(nn.Module): def __init__(self, embedding_dim, scale): super(PositionalEmbedding, self).__init__() self.embedding_dim = embedding_dim self.scale = scale def forward(self, timesteps): assert len(timesteps.shape) == 1 timesteps = timesteps * self.scale half_dim = self.embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) return emb class RandomFourierEmbedding(nn.Module): def __init__(self, embedding_dim, scale): super(RandomFourierEmbedding, self).__init__() self.w = nn.Parameter(torch.randn(size=(1, embedding_dim // 2)) * scale, requires_grad=False) def forward(self, timesteps): emb = torch.mm(timesteps[:, None], self.w * 2 * 3.14159265359) return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) def init_temb_fun(embedding_type, embedding_scale, embedding_dim): if embedding_type == 'positional': temb_fun = PositionalEmbedding(embedding_dim, embedding_scale) elif embedding_type == 'fourier': temb_fun = RandomFourierEmbedding(embedding_dim, embedding_scale) else: raise NotImplementedError return temb_fun