# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
import torch.nn as nn
from loguru import logger

from .distributions import Normal
from utils.model_helper import import_model, loss_fn
from utils import utils as helper


class Model(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.num_total_iter = 0
        self.args = args
        self.input_dim = args.ddpm.input_dim
        latent_dim = args.shapelatent.latent_dim
        self.latent_dim = latent_dim
        self.kl_weight = args.shapelatent.kl_weight

        self.num_points = args.data.tr_max_sample_points
        # ---- global ---- #
        # build encoder
        self.style_encoder = import_model(args.latent_pts.style_encoder)(
            zdim=args.latent_pts.style_dim,
            input_dim=self.input_dim,
            args=args)
        if len(args.latent_pts.style_mlp):
            self.style_mlp = import_model(args.latent_pts.style_mlp)(args)
        else:
            self.style_mlp = None

        self.encoder = import_model(args.shapelatent.encoder_type)(
            zdim=latent_dim,
            input_dim=self.input_dim,
            args=args)

        # build decoder
        self.decoder = import_model(args.shapelatent.decoder_type)(
            context_dim=latent_dim,
            point_dim=args.ddpm.input_dim,
            args=args)
        logger.info('[Build Model] style_encoder: {}, encoder: {}, decoder: {}',
                    args.latent_pts.style_encoder,
                    args.shapelatent.encoder_type,
                    args.shapelatent.decoder_type)
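
    # NOTE: when args.data.cond_on_cat is set, `encode`/`recont` below call
    # self.class_embedding, which is not built in this __init__; it is
    # assumed to be attached elsewhere (e.g., an nn.Embedding added by a
    # subclass or builder) before class-conditional use.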

    @torch.no_grad()
    def encode(self, x, class_label=None):
        batch_size, _, point_dim = x.size()
        assert x.shape[2] == self.input_dim, f'expect input in ' \
            f'[B,Npoint,PointDim={self.input_dim}], get: {x.shape}'
        latent_list = []
        all_eps = []
        all_log_q = []
        if self.args.data.cond_on_cat:
            assert class_label is not None, 'require class label input for cond on cat'
            cls_emb = self.class_embedding(class_label)
            enc_input = x, cls_emb
        else:
            enc_input = x

        # ---- global style encoder ---- #
        z = self.style_encoder(enc_input)
        z_mu, z_sigma = z['mu_1d'], z['sigma_1d']  # z_sigma holds log-sigma
        dist = Normal(mu=z_mu, log_sigma=z_sigma)  # (B, F)
        z_global = dist.sample()[0]
        all_eps.append(z_global)
        all_log_q.append(dist.log_p(z_global))
        latent_list.append([z_global, z_mu, z_sigma])

        # ---- original encoder ---- #
        style = z_global
        style = self.style_mlp(style) if self.style_mlp is not None else style
        z = self.encoder([x, style])
        z_mu, z_sigma = z['mu_1d'], z['sigma_1d']
        z_sigma = z_sigma - self.args.shapelatent.log_sigma_offset
        dist = Normal(mu=z_mu, log_sigma=z_sigma)  # (B, F)
        z_local = dist.sample()[0]
        all_eps.append(z_local)
        all_log_q.append(dist.log_p(z_local))
        latent_list.append([z_local, z_mu, z_sigma])
        all_eps = self.compose_eps(all_eps)
        if self.args.data.cond_on_cat:
            return all_eps, all_log_q, latent_list, cls_emb
        else:
            return all_eps, all_log_q, latent_list

    def compose_eps(self, all_eps):
        return torch.cat(all_eps, dim=1)  # style: [B,D1], latent pts: [B,ND2]

    def decompose_eps(self, all_eps):
        eps_style = all_eps[:, :self.args.latent_pts.style_dim]
        eps_local = all_eps[:, self.args.latent_pts.style_dim:]
        return [eps_style, eps_local]
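
    # Shape sanity sketch for the round trip above (hypothetical sizes:
    # style_dim D1=128, 2048 latent points of width latent_dim+input_dim=4):
    #   eps = self.compose_eps([torch.randn(B, 128), torch.randn(B, 2048 * 4)])
    #   eps_style, eps_local = self.decompose_eps(eps)  # [B,128], [B,8192]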

    def encode_global(self, x, class_label=None):
        batch_size, N, point_dim = x.size()
        if self.args.data.cond_on_cat:
            assert class_label is not None, 'require class label input for cond on cat'
            cls_emb = self.class_embedding(class_label)
            enc_input = x, cls_emb
        else:
            enc_input = x

        z = self.style_encoder(enc_input)
        z_mu, z_sigma = z['mu_1d'], z['sigma_1d']  # z_sigma holds log-sigma
        dist = Normal(mu=z_mu, log_sigma=z_sigma)  # (B, F)
        return dist

    def global2style(self, style):
        Ndim = len(style.shape)
        if Ndim == 4:
            style = style.squeeze(-1).squeeze(-1)
        style = self.style_mlp(style) if self.style_mlp is not None else style
        if Ndim == 4:
            style = style.unsqueeze(-1).unsqueeze(-1)
        return style

    def encode_local(self, x, style):
        # ---- original encoder ---- #
        z = self.encoder([x, style])
        z_mu, z_sigma = z['mu_1d'], z['sigma_1d']  # z_sigma holds log-sigma
        z_sigma = z_sigma - self.args.shapelatent.log_sigma_offset
        dist = Normal(mu=z_mu, log_sigma=z_sigma)  # (B, F)
        return dist
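
    # Two-stage encoding sketch using the helpers above (x: [B,N,input_dim]):
    #   dist_global = self.encode_global(x)
    #   style = self.global2style(dist_global.sample()[0])
    #   dist_local = self.encode_local(x, style)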

    def recont(self, x, target=None, class_label=None, cls_emb=None):
        batch_size, N, point_dim = x.size()
        assert x.shape[2] == self.input_dim, f'expect input in ' \
            f'[B,Npoint,PointDim={self.input_dim}], get: {x.shape}'
        x_0_target = x if target is None else target
        latent_list = []
        all_eps = []
        all_log_q = []

        # ---- global style encoder ---- #
        if self.args.data.cond_on_cat:
            # embed class_label if given, otherwise expect a precomputed cls_emb
            if class_label is not None:
                cls_emb = self.class_embedding(class_label)
            else:
                assert cls_emb is not None
            enc_input = x, cls_emb
        else:
            enc_input = x
        z = self.style_encoder(enc_input)
        z_mu, z_sigma = z['mu_1d'], z['sigma_1d']  # z_sigma holds log-sigma
        dist = Normal(mu=z_mu, log_sigma=z_sigma)  # (B, F)

        z_global = dist.sample()[0]
        all_eps.append(z_global)
        all_log_q.append(dist.log_p(z_global))
        latent_list.append([z_global, z_mu, z_sigma])

        # ---- original encoder ---- #
        style = torch.cat([z_global, cls_emb], dim=1) if self.args.data.cond_on_cat else z_global
        style = self.style_mlp(style) if self.style_mlp is not None else style
        z = self.encoder([x, style])
        z_mu, z_sigma = z['mu_1d'], z['sigma_1d']  # z_sigma holds log-sigma
        z_sigma = z_sigma - self.args.shapelatent.log_sigma_offset
        dist = Normal(mu=z_mu, log_sigma=z_sigma)  # (B, F)
        z_local = dist.sample()[0]
        all_eps.append(z_local)
        all_log_q.append(dist.log_p(z_local))
        latent_list.append([z_local, z_mu, z_sigma])

        # ---- decoder ---- #
        x_0_pred = self.decoder(None, beta=None, context=z_local, style=style)  # (B,ncenter,3)

        # lift (B,F) tensors to (B,F,1,1) and (B,F,G) tensors to (B,F,G,1)
        make_4d = lambda t: t.unsqueeze(-1).unsqueeze(-1) if len(t.shape) == 2 else t.unsqueeze(-1)
        all_eps = [make_4d(e) for e in all_eps]
        all_log_q = [make_4d(e) for e in all_log_q]

        output = {
            'all_eps': all_eps,
            'all_log_q': all_log_q,
            'latent_list': latent_list,
            'x_0_pred': x_0_pred,
            'x_0_target': x_0_target,
            'x_t': torch.zeros_like(x_0_target),  # zero placeholder: no diffusion in this forward
            't': torch.zeros(batch_size),
            'x_0': x_0_target
        }
        output['hist/global_var'] = latent_list[0][2].exp()  # exp(log-sigma) of the global posterior

        if 'LatentPoint' in self.args.shapelatent.decoder_type:
            latent_shape = [batch_size, -1, self.latent_dim + self.input_dim]
            if 'Hir' in self.args.shapelatent.decoder_type:
                latent_pts = z_local[:, :-self.args.latent_pts.latent_dim_ext[0]].view(*latent_shape)[:, :, :3].contiguous().clone()
            else:
                latent_pts = z_local.view(*latent_shape)[:, :, :self.input_dim].contiguous().clone()

            output['vis/latent_pts'] = latent_pts.detach().cpu().view(batch_size,
                                                                      -1, self.input_dim)  # B,N,3
        output['final_pred'] = output['x_0_pred']
        return output
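
    # recont returns a dict with the reconstruction ('x_0_pred', 'x_0_target'),
    # the posterior samples and statistics ('all_eps', 'all_log_q',
    # 'latent_list'), and logging entries ('hist/...', 'vis/...');
    # get_loss below consumes it.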

    def get_loss(self, x, writer=None, it=None,
                 noisy_input=None, class_label=None, **kwargs):
        """Compute the VAE loss: shape latents z ~ q(z|x_0) are sampled in
        `recont`, x_0 is reconstructed, and the reconstruction loss is
        combined with the (optionally annealed) weighted KL terms.

        Args:
            x: Input point clouds, (B, N, d).
        """
        if self.args.trainer.anneal_kl and self.num_total_iter > 0:
            global_step = it
            kl_weight = helper.kl_coeff(step=global_step,
                                        total_step=self.args.sde.kl_anneal_portion_vada * self.num_total_iter,
                                        constant_step=self.args.sde.kl_const_portion_vada * self.num_total_iter,
                                        min_kl_coeff=self.args.sde.kl_const_coeff_vada,
                                        max_kl_coeff=self.args.sde.kl_max_coeff_vada)
        else:
            kl_weight = self.kl_weight

        batch_size = x.shape[0]
        assert x.shape[2] == self.input_dim

        inputs = noisy_input if noisy_input is not None else x
        output = self.recont(inputs, target=x, class_label=class_label)

        x_0_pred, x_0_target = output['x_0_pred'], output['x_0_target']
        loss_0 = loss_fn(x_0_pred, x_0_target, self.args.ddpm.loss_type,
                         self.input_dim, batch_size).mean()
        rec_loss = loss_0
        output['print/loss_0'] = loss_0
        output['rec_loss'] = rec_loss

        # KL terms over the [global, local] posteriors
        kl_term_list = []
        weighted_kl_terms = []
        for pairs_id, pairs in enumerate(output['latent_list']):
            cz, cmu, csigma = pairs
            log_sigma = csigma
            # closed-form KL(N(mu, sigma^2) || N(0, 1))
            #   = 0.5*sigma^2 + 0.5*mu^2 - log(sigma) - 0.5, elementwise
            kl_term_close = (0.5 * log_sigma.exp()**2 +
                             0.5 * cmu**2 - log_sigma - 0.5).view(
                batch_size, -1)
            if 'LatentPoint' in self.args.shapelatent.decoder_type and 'Hir' not in self.args.shapelatent.decoder_type:
                if pairs_id == 1:
                    # split the latent-point KL into point (xyz) and feature parts
                    latent_shape = [batch_size, -1, self.latent_dim + self.input_dim]
                    kl_pt = kl_term_close.view(*latent_shape)[:, :, :self.input_dim]
                    kl_feat = kl_term_close.view(*latent_shape)[:, :, self.input_dim:]
                    weighted_kl_terms.append(kl_pt.sum(2).sum(1) * self.args.latent_pts.weight_kl_pt)
                    weighted_kl_terms.append(kl_feat.sum(2).sum(1) * self.args.latent_pts.weight_kl_feat)

                    output['print/kl_pt%d' % pairs_id] = kl_pt.sum(2).sum(1)
                    output['print/kl_feat%d' % pairs_id] = kl_feat.sum(2).sum(1)

                    output['print/z_var_pt%d' % pairs_id] = (log_sigma.view(*latent_shape)[:, :, :self.input_dim]).exp()**2
                    output['print/z_var_feat%d' % pairs_id] = (log_sigma.view(*latent_shape)[:, :, self.input_dim:]).exp()**2
                    output['print/z_mean_feat%d' % pairs_id] = cmu.view(*latent_shape)[:, :, self.input_dim:].mean()
                elif pairs_id == 0:
                    kl_style = kl_term_close
                    weighted_kl_terms.append(kl_style.sum(-1) * self.args.latent_pts.weight_kl_glb)

                    output['print/kl_glb%d' % pairs_id] = kl_style.sum(-1)
                    output['print/z_var_glb%d' % pairs_id] = log_sigma.exp()**2

            kl_term_close = kl_term_close.sum(-1)
            kl_term_list.append(kl_term_close)
            output['print/kl_%d' % pairs_id] = kl_term_close
            output['print/z_mean_%d' % pairs_id] = cmu.mean()
            output['print/z_mag_%d' % pairs_id] = cmu.abs().max()
            output['print/z_var_%d' % pairs_id] = log_sigma.exp()**2
            output['print/z_logsigma_%d' % pairs_id] = log_sigma
        output['print/kl_weight'] = kl_weight

        loss_recons = rec_loss
        if len(weighted_kl_terms) > 0:
            kl = kl_weight * sum(weighted_kl_terms)
        else:
            kl = kl_weight * sum(kl_term_list)
        loss = kl + loss_recons * self.args.weight_recont
        output['msg/kl'] = kl
        output['msg/rec'] = loss_recons
        output['loss'] = loss
        return output
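
    # Training-step sketch using get_loss (hypothetical `optimizer`/`loader`
    # names; assumes model.num_total_iter was set beforehand if KL annealing
    # is enabled):
    #   for it, x in enumerate(loader):        # x: [B, N, input_dim]
    #       out = model.get_loss(x, it=it)
    #       optimizer.zero_grad()
    #       out['loss'].backward()
    #       optimizer.step()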

    def pz(self, w):
        return w

    def sample(self, num_samples=10, temp=None, decomposed_eps=[],
               enable_autocast=False, device_str='cuda', cls_emb=None):
        """Draw shapes from the prior; sampling conditioned only at the
        local level is currently not supported.

        Return:
            model_output: [B,N,D]
        """
        if 'LatentPoint' in self.args.shapelatent.decoder_type:
            # Latent Point model: local latent shape is [B, N*D]
            latent_shape = (num_samples, self.num_points * (self.latent_dim + self.input_dim))
            style_latent_shape = (num_samples, self.args.latent_pts.style_dim)
        else:
            raise NotImplementedError

        if len(decomposed_eps) == 0:
            z_local = torch.zeros(*latent_shape).to(
                torch.device(device_str)).normal_()
            z_global = torch.zeros(*style_latent_shape).to(
                torch.device(device_str)).normal_()
        else:
            z_global = decomposed_eps[0]
            z_local = decomposed_eps[1]

        z_local = z_local.view(*latent_shape)
        z_global = z_global.view(*style_latent_shape)

        style = z_global
        style = self.style_mlp(style) if self.style_mlp is not None else style
        x_0_pred = self.decoder(None, beta=None,
                                context=z_local, style=style)  # (B,ncenter,3)
        return x_0_pred
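
    # Prior-sampling sketch (assumes a CUDA device and a constructed `model`):
    #   with torch.no_grad():
    #       pts = model.sample(num_samples=8)  # -> [8, num_points, input_dim]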

    def latent_shape(self):
        # latent shapes in the 4D [C,1,1] layout produced by make_4d in recont:
        # global style latent first, then the flattened latent points
        return [
            [self.args.latent_pts.style_dim, 1, 1],
            [self.num_points * (self.latent_dim + self.input_dim), 1, 1]
        ]