# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger

from .latent_points_ada import PVCNN2Unet
from .utils import mask_inactive_variables


# Diffusion model for latent points.
class PVCNN2Prior(PVCNN2Unet):
    sa_blocks = [  # (conv_configs, sa_configs)
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 128))),
        (None, (16, 0.8, 32, (128, 128, 128))),
    ]
    fp_blocks = [  # (fp_configs, conv_configs)
        ((128, 128), (128, 3, 8)),
        ((128, 128), (128, 3, 8)),
        ((128, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]
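
    # Reading the block configs above (an editorial gloss following the usual
    # PVCNN convention; the field names are assumptions, not from this file):
    #   conv_configs = (out_channels, num_blocks, voxel_resolution) for the
    #     point-voxel convolutions (None disables them),
    #   sa_configs   = (num_centers, radius, num_neighbors, mlp_channels) for
    #     the set-abstraction layers,
    #   fp_configs   = MLP widths of each feature-propagation layer.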
    def __init__(self, args, num_input_channels, cfg):
        # Only cfg is used; args and num_input_channels are kept for
        # interface compatibility and overridden below.
        self.clip_forge_enable = cfg.clipforge.enable
        clip_forge_dim = cfg.clipforge.feat_dim
        num_input_channels = num_classes = cfg.shapelatent.latent_dim + cfg.ddpm.input_dim
        self.num_classes = num_classes
        embed_dim = cfg.ddpm.time_dim
        use_att = True
        extra_feature_channels = cfg.shapelatent.latent_dim
        self.num_points = cfg.data.tr_max_sample_points
        dropout = cfg.ddpm.dropout
        # Note: the time-embedding scale is taken from cfg.sde.embedding_scale
        # (1k by default), not from a ddpm-specific setting.
        time_emb_scales = cfg.sde.embedding_scale
        logger.info('[Build Prior Model] nclass={}, embed_dim={}, use_att={}, '
                    'extra_feature_channels={}, dropout={}, time_emb_scales={}, num_points={}',
                    num_classes, embed_dim, use_att, extra_feature_channels,
                    dropout, time_emb_scales, self.num_points)
        super().__init__(
            num_classes, embed_dim, use_att, dropout=dropout,
            input_dim=cfg.ddpm.input_dim,
            extra_feature_channels=extra_feature_channels,
            time_emb_scales=time_emb_scales,
            verbose=True,
            condition_input=False,
            cfg=cfg,
            sa_blocks=self.sa_blocks,
            fp_blocks=self.fp_blocks,
            clip_forge_enable=self.clip_forge_enable,
            clip_forge_dim=clip_forge_dim)

        # Init the mixing logit; this enables mixed prediction.
        self.mixed_prediction = cfg.sde.mixed_prediction
        if self.mixed_prediction:
            logger.info('init-mixing_logit = {}, after sigmoid = {}',
                        cfg.sde.mixing_logit_init,
                        torch.sigmoid(torch.tensor(cfg.sde.mixing_logit_init)))
            init = cfg.sde.mixing_logit_init * torch.ones(
                size=[1, num_input_channels * self.num_points, 1, 1])
            self.mixing_logit = torch.nn.Parameter(init, requires_grad=True)
            self.is_active = None
        else:  # no mixing logit
            self.mixing_logit = None
            self.is_active = None
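
    # Editorial note (an assumption, based on the LSGM-style mixed prediction
    # this flag appears to implement): when mixed_prediction is on, the
    # sampler is expected to blend the network output with the analytically
    # known Gaussian component per latent dimension, weighted by
    # torch.sigmoid(self.mixing_logit); self.is_active can later flag latent
    # dimensions as inactive so that forward() zeroes them out.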
    def forward(self, x, t, *args, **kwargs):
        # Input x: (B, N*D) or (B, N*D, 1, 1); the U-Net expects (B, C, N).
        assert 'condition_input' in kwargs, 'require condition_input'
        if self.mixed_prediction and self.is_active is not None:
            x = mask_inactive_variables(x, self.is_active)
        input_shape = x.shape
        # (B, N*D[, 1, 1]) -> (B, N, D) -> (B, D, N)
        x = x.view(-1, self.num_points, self.num_classes).permute(0, 2, 1).contiguous()
        out = super().forward(
            x, t=t,
            style=kwargs['condition_input'].squeeze(-1).squeeze(-1),
            clip_feat=kwargs.get('clip_feat', None))
        # (B, D, N) -> (B, N, D) -> flattened back to the input shape.
        return out.permute(0, 2, 1).contiguous().view(input_shape)
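

if __name__ == "__main__":
    # Minimal shape-contract sketch (an editorial addition, not part of the
    # original file): it exercises only the reshape round-trip performed in
    # forward() above, without building the network. B, N, D are illustrative.
    B, N, D = 4, 2048, 4                    # batch, num_points, num_classes
    x = torch.randn(B, N * D, 1, 1)         # flat latent, as the prior sees it
    y = x.view(-1, N, D).permute(0, 2, 1).contiguous()  # (B, D, N) for the U-Net
    restored = y.permute(0, 2, 1).contiguous().view(x.shape)
    assert torch.equal(x, restored)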