LION/models/latent_points_ada_localprior.py
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
from loguru import logger

from .latent_points_ada import PVCNN2Unet
from .utils import mask_inactive_variables
# Diffusion (prior) model over latent points: a PVCNN2-style U-Net operating
# in LION's latent-point space.
class PVCNN2Prior(PVCNN2Unet):
    # Set-abstraction blocks: (conv_configs, sa_configs), where
    # conv_configs = (out_channels, num_blocks, voxel_resolution) and
    # sa_configs = (num_centers, radius, num_neighbors, mlp_channels).
    # The four stages downsample 1024 -> 256 -> 64 -> 16 centers.
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 128))),
        (None, (16, 0.8, 32, (128, 128, 128))),
    ]
    # Feature-propagation blocks: (fp_configs, conv_configs), where
    # fp_configs lists the MLP channels of the FP interpolation stage.
    fp_blocks = [
        ((128, 128), (128, 3, 8)),
        ((128, 128), (128, 3, 8)),
        ((128, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, args, num_input_channels, cfg):
        # Only cfg is used; args and num_input_channels are kept for interface
        # compatibility (num_input_channels is recomputed below).
        self.clip_forge_enable = cfg.clipforge.enable
        clip_forge_dim = cfg.clipforge.feat_dim
        # Each latent point carries its coordinates (ddpm.input_dim) plus a
        # per-point latent feature (shapelatent.latent_dim).
        num_input_channels = num_classes = cfg.shapelatent.latent_dim + cfg.ddpm.input_dim
        self.num_classes = num_classes
        embed_dim = cfg.ddpm.time_dim
        use_att = True
        extra_feature_channels = cfg.shapelatent.latent_dim
        self.num_points = cfg.data.tr_max_sample_points
        dropout = cfg.ddpm.dropout
        time_emb_scales = cfg.sde.embedding_scale  # 1k by default
        logger.info(
            '[Build Prior Model] num_classes={}, embed_dim={}, use_att={}, '
            'extra_feature_channels={}, dropout={}, time_emb_scales={}, num_points={}',
            num_classes, embed_dim, use_att, extra_feature_channels, dropout,
            time_emb_scales, self.num_points)
        # Note: the time-embedding scale is taken from cfg.sde.embedding_scale,
        # not from the ddpm config.
        super().__init__(
            num_classes, embed_dim, use_att, dropout=dropout,
            input_dim=cfg.ddpm.input_dim,
            extra_feature_channels=extra_feature_channels,
            time_emb_scales=time_emb_scales,
            verbose=True,
            condition_input=False,
            cfg=cfg,
            sa_blocks=self.sa_blocks,
            fp_blocks=self.fp_blocks,
            clip_forge_enable=self.clip_forge_enable,
            clip_forge_dim=clip_forge_dim)
        # Initialize the mixing logit for LSGM-style mixed score prediction.
        self.mixed_prediction = cfg.sde.mixed_prediction
        if self.mixed_prediction:
            logger.info('init-mixing_logit = {}, after sigmoid = {}',
                        cfg.sde.mixing_logit_init,
                        torch.sigmoid(torch.tensor(cfg.sde.mixing_logit_init)))
            # One logit per latent dimension, flattened over points and channels.
            init = cfg.sde.mixing_logit_init * torch.ones(
                size=[1, num_input_channels * self.num_points, 1, 1])
            self.mixing_logit = torch.nn.Parameter(init, requires_grad=True)
            self.is_active = None
        else:  # no mixing logit
            self.mixing_logit = None
            self.is_active = None
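
    # How the mixing logit is typically consumed downstream (a sketch in the
    # style of LSGM's mixed score parameterization; `mixing_component` and
    # `model_output` are illustrative names, not this file's API):
    #
    #     coeff = torch.sigmoid(self.mixing_logit)
    #     pred = (1 - coeff) * mixing_component + coeff * model_output
    #
    # i.e. the network output is blended with an analytic score component
    # using a learnable per-dimension weight.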

    def forward(self, x, t, *args, **kwargs):
        # Input x: (B, N*D) or (B, N*D, 1, 1); the U-Net expects (B, D, N),
        # where D == self.num_classes and N == self.num_points.
        assert 'condition_input' in kwargs, 'require condition_input'
        if self.mixed_prediction and self.is_active is not None:
            x = mask_inactive_variables(x, self.is_active)
        input_shape = x.shape
        # (B, N*D[, 1, 1]) -> (B, N, D) -> (B, D, N)
        x = x.view(-1, self.num_points, self.num_classes).permute(0, 2, 1).contiguous()
        out = super().forward(
            x, t=t,
            style=kwargs['condition_input'].squeeze(-1).squeeze(-1),
            clip_feat=kwargs.get('clip_feat', None))
        # (B, D, N) -> (B, N, D) -> original input shape
        return out.permute(0, 2, 1).contiguous().view(input_shape)
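

# Standalone check of the shape round-trip used in forward() above (a minimal,
# self-contained sketch; the B, N, D values are illustrative). x arrives
# flattened as (B, N*D, 1, 1), is reshaped to (B, D, N) for the point-voxel
# U-Net, and is mapped back to its original shape on the way out.
def _demo_shape_roundtrip():
    B, N, D = 2, 2048, 4  # batch, points, per-point channels (xyz + latent feature)
    x = torch.randn(B, N * D, 1, 1)
    x_bdn = x.view(-1, N, D).permute(0, 2, 1).contiguous()    # (B, D, N)
    x_back = x_bdn.permute(0, 2, 1).contiguous().view(x.shape)
    assert torch.equal(x, x_back)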