53 lines
2 KiB
Python
53 lines
2 KiB
Python
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
#
|
|
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
|
# and proprietary rights in and to this software, related documentation
|
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
# distribution of this software and related documentation without an express
|
|
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
|
import torch
|
|
import math
|
|
import torch.nn as nn
|
|
|
|
def mask_inactive_variables(x, is_active):
    """Return ``x`` multiplied elementwise by ``is_active``.

    Presumably ``is_active`` is a 0/1 (or boolean) mask broadcastable to ``x``
    (TODO confirm against callers), in which case entries whose mask is 0/False
    become zero and the rest pass through unchanged.

    Args:
        x: values to mask.
        is_active: multiplicative mask.

    Returns:
        The product ``x * is_active``; ``x`` itself is not modified in place.
    """
    return x * is_active
|
|
|
|
class PositionalEmbedding(nn.Module):
    """Sinusoidal (sin/cos) embedding of a 1-D batch of timesteps.

    The timesteps are multiplied by ``scale`` and then by ``embedding_dim // 2``
    geometrically spaced frequencies.  For ``embedding_dim >= 4`` these run
    from 1 down to 1/10000.  The output concatenates the sines of these
    products followed by their cosines.

    Note: for an odd ``embedding_dim`` the output has ``embedding_dim - 1``
    columns; the missing column is not padded.
    """

    def __init__(self, embedding_dim, scale):
        """
        Args:
            embedding_dim: requested embedding width; ``embedding_dim // 2``
                frequencies are used.
            scale: multiplier applied to the timesteps before embedding.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.scale = scale

    def forward(self, timesteps):
        """Embed ``timesteps``.

        Args:
            timesteps: 1-D tensor of shape (N,).

        Returns:
            Tensor of shape (N, 2 * (embedding_dim // 2)): sines, then cosines.

        Raises:
            AssertionError: if ``timesteps`` is not 1-D.
        """
        assert timesteps.dim() == 1, "timesteps must be a 1-D tensor"
        timesteps = timesteps * self.scale
        half_dim = self.embedding_dim // 2
        # Frequencies exp(-k * log(10000) / (half_dim - 1)) for k = 0..half_dim-1.
        # The denominator is clamped so that half_dim == 1 (embedding_dim 2 or 3)
        # yields the single frequency 1 instead of raising ZeroDivisionError.
        exponent = math.log(10000) / max(half_dim - 1, 1)
        # Build the table directly on the input's device, which avoids a
        # host-to-device copy on every call.
        freqs = torch.exp(torch.arange(half_dim, device=timesteps.device) * -exponent)
        args = timesteps[:, None] * freqs[None, :]
        return torch.cat([torch.sin(args), torch.cos(args)], dim=1)
|
|
|
|
|
|
class RandomFourierEmbedding(nn.Module):
    """Random Fourier-feature embedding of a 1-D batch of timesteps.

    A row of ``embedding_dim // 2`` frequencies is drawn once at construction
    from a standard normal scaled by ``scale``.  It is never trained
    (``requires_grad=False``).  Timesteps are multiplied by ``2 * pi * w``, and
    the sines and cosines of the result are concatenated.
    """

    def __init__(self, embedding_dim, scale):
        """
        Args:
            embedding_dim: requested embedding width; ``embedding_dim // 2``
                random frequencies are drawn.
            scale: multiplier applied to the standard-normal frequencies.
        """
        super().__init__()
        # Fixed random frequencies of shape (1, embedding_dim // 2). They are
        # kept as a non-trainable Parameter rather than a plain tensor attribute.
        self.w = nn.Parameter(torch.randn(size=(1, embedding_dim // 2)) * scale, requires_grad=False)

    def forward(self, timesteps):
        """Embed ``timesteps``.

        Args:
            timesteps: 1-D tensor of shape (N,); integer or floating dtype.

        Returns:
            Tensor of shape (N, 2 * (embedding_dim // 2)): sines, then cosines.

        Raises:
            AssertionError: if ``timesteps`` is not 1-D.
        """
        assert timesteps.dim() == 1, "timesteps must be a 1-D tensor"
        # Broadcasting outer product (N, 1) * (1, D) -> (N, D).  Unlike torch.mm,
        # this applies type promotion, so integer timesteps are accepted.
        emb = timesteps[:, None] * (self.w * (2 * math.pi))
        return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
|
|
|
|
|
def init_temb_fun(embedding_type, embedding_scale, embedding_dim):
    """Build the timestep-embedding module selected by ``embedding_type``.

    Args:
        embedding_type: ``'positional'`` for ``PositionalEmbedding`` or
            ``'fourier'`` for ``RandomFourierEmbedding``.
        embedding_scale: forwarded to the module as its ``scale`` argument.
        embedding_dim: forwarded to the module as its ``embedding_dim`` argument.

    Returns:
        The constructed embedding module.

    Raises:
        NotImplementedError: if ``embedding_type`` is not one of the values above.
    """
    if embedding_type == 'positional':
        return PositionalEmbedding(embedding_dim, embedding_scale)
    if embedding_type == 'fourier':
        return RandomFourierEmbedding(embedding_dim, embedding_scale)
    # Name the offending value so a config typo is easy to diagnose.
    raise NotImplementedError(
        "unknown embedding_type {!r}; expected 'positional' or 'fourier'".format(embedding_type))
|