274 lines
12 KiB
Python
274 lines
12 KiB
Python
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
#
|
|
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
|
# and proprietary rights in and to this software, related documentation
|
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
# distribution of this software and related documentation without an express
|
|
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
|
|
|
import torch
|
|
from loguru import logger
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
from .pvcnn2_ada import \
|
|
create_pointnet2_sa_components, create_pointnet2_fp_modules, LinearAttention, create_mlp_components, SharedMLP
|
|
|
|
# the building block of the encoder and decoder for the VAE
|
|
|
|
class PVCNN2Unet(nn.Module):
    """U-Net built from point-voxel set-abstraction (SA) and feature-propagation (FP) modules.

    Copied and modified from
    https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/model/pvcnn_generation.py#L172

    ``forward`` takes inputs laid out as (B, C, N); the first ``input_dim``
    channels are used as coordinates and the whole tensor as SA input features.
    """

    def __init__(self,
                 num_classes, embed_dim, use_att, dropout=0.1,
                 extra_feature_channels=3,
                 input_dim=3,
                 width_multiplier=1,
                 voxel_resolution_multiplier=1,
                 time_emb_scales=1.0,
                 verbose=True,
                 condition_input=False,
                 point_as_feat=1, cfg=None,
                 sa_blocks=None, fp_blocks=None,
                 clip_forge_enable=0,
                 clip_forge_dim=512
                 ):
        """Build the SA stack, optional global attention, FP stack and classifier head.

        Args:
            num_classes: number of output channels of the classifier head.
            embed_dim: width of the time embedding; when > 0 a small MLP
                (``self.embedf``) is built for it, 0 disables it.
            use_att: when truthy, a ``LinearAttention`` is applied after the SA
                stack; also forwarded to the SA/FP builders.
            dropout: forwarded to the SA/FP builders and placed in the
                classifier's ``out_channels`` spec.
            extra_feature_channels: number of non-coordinate input channels;
                must be >= 0.
            input_dim: number of leading input channels used as coordinates.
            width_multiplier, voxel_resolution_multiplier, verbose: forwarded
                to the component builders.
            time_emb_scales: multiplier applied to timesteps before embedding.
            condition_input, point_as_feat: stored on the instance only.
            cfg: config object forwarded to the builders; when
                ``clip_forge_enable`` is set it must expose
                ``cfg.latent_pts.style_dim``. ``None`` means ``{}``.
            sa_blocks, fp_blocks: SA / FP block specifications. ``None`` means ``{}``.
            clip_forge_enable: when truthy, ``forward`` fuses a CLIP feature
                into ``style``.
            clip_forge_dim: width of that CLIP feature.
        """
        super().__init__()
        # Use fresh containers per call instead of shared mutable defaults.
        cfg = {} if cfg is None else cfg
        sa_blocks = {} if sa_blocks is None else sa_blocks
        fp_blocks = {} if fp_blocks is None else fp_blocks

        logger.info('[Build Unet] extra_feature_channels={}, input_dim={}',
                    extra_feature_channels, input_dim)
        self.input_dim = input_dim
        self.clip_forge_enable = clip_forge_enable
        self.sa_blocks = sa_blocks
        self.fp_blocks = fp_blocks
        self.point_as_feat = point_as_feat
        self.condition_input = condition_input
        assert extra_feature_channels >= 0
        self.time_emb_scales = time_emb_scales
        self.embed_dim = embed_dim

        if self.embed_dim > 0:
            # Time-embedding MLP: used by the prior model; the VAE passes embed_dim=0.
            self.embedf = nn.Sequential(
                nn.Linear(embed_dim, embed_dim),
                nn.LeakyReLU(0.1, inplace=True),
                nn.Linear(embed_dim, embed_dim),
            )

        if self.clip_forge_enable:
            self.clip_forge_mapping = nn.Linear(clip_forge_dim, embed_dim)
            style_dim = cfg.latent_pts.style_dim
            # Maps concat(style, mapped clip feature) back to style_dim.
            self.style_clip = nn.Linear(style_dim + embed_dim, style_dim)

        # Not read anywhere else in this class.
        self.in_channels = extra_feature_channels + 3

        sa_layers, sa_in_channels, channels_sa_features, _ = \
            create_pointnet2_sa_components(
                input_dim=input_dim,
                sa_blocks=self.sa_blocks,
                extra_feature_channels=extra_feature_channels,
                with_se=True,
                embed_dim=embed_dim,  # time embedding dim
                use_att=use_att, dropout=dropout,
                width_multiplier=width_multiplier,
                voxel_resolution_multiplier=voxel_resolution_multiplier,
                verbose=verbose, cfg=cfg
            )
        self.sa_layers = nn.ModuleList(sa_layers)

        self.global_att = None if not use_att else LinearAttention(channels_sa_features, 8, verbose=verbose)

        # The last FP module receives inputs[:, 3:] as its skip features (see
        # forward), i.e. every input channel except the first 3.
        sa_in_channels[0] = extra_feature_channels + input_dim - 3
        fp_layers, channels_fp_features = create_pointnet2_fp_modules(
            fp_blocks=self.fp_blocks, in_channels=channels_sa_features,
            sa_in_channels=sa_in_channels,
            with_se=True, embed_dim=embed_dim,
            use_att=use_att, dropout=dropout,
            width_multiplier=width_multiplier, voxel_resolution_multiplier=voxel_resolution_multiplier,
            verbose=verbose, cfg=cfg
        )
        self.fp_layers = nn.ModuleList(fp_layers)

        layers, _ = create_mlp_components(
            in_channels=channels_fp_features,
            out_channels=[128, dropout, num_classes],  # was 0.5
            classifier=True, dim=2, width_multiplier=width_multiplier,
            cfg=cfg)
        self.classifier = nn.ModuleList(layers)

    def get_timestep_embedding(self, timesteps, device):
        """Return sinusoidal embeddings of ``timesteps``.

        Args:
            timesteps: tensor of shape (B,) or (B, 1); the latter is squeezed.
            device: device on which the frequency table is created.

        Returns:
            Tensor of shape (B, embed_dim) holding ``[sin(t*f), cos(t*f)]`` for
            ``embed_dim // 2`` geometrically spaced frequencies ``f`` from 1
            down to 1/10000, with one zero column appended when ``embed_dim``
            is odd.
        """
        if len(timesteps.shape) == 2 and timesteps.shape[1] == 1:
            timesteps = timesteps[:, 0]
        assert(len(timesteps.shape) == 1), f'get shape: {timesteps.shape}'
        timesteps = timesteps * self.time_emb_scales

        half_dim = self.embed_dim // 2
        # Guard half_dim == 1 (embed_dim 2 or 3): a zero divisor would make the
        # single frequency exp(0 * -inf) = NaN. Results for other sizes are unchanged.
        emb = np.log(10000) / max(half_dim - 1, 1)
        emb = torch.from_numpy(np.exp(np.arange(0, half_dim) * -emb)).float().to(device)
        emb = timesteps[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if self.embed_dim % 2 == 1:  # zero pad
            emb = nn.functional.pad(emb, (0, 1), "constant", 0)
        assert emb.shape == torch.Size([timesteps.shape[0], self.embed_dim])
        return emb

    def forward(self, inputs, **kwargs):
        """Run the SA stack, optional global attention, FP stack and classifier.

        Args:
            inputs: tensor laid out as (B, C, N); the first ``input_dim``
                channels are used as coordinates.
            **kwargs:
                style: required; passed to every SA/FP block and to the
                    ``SharedMLP`` layers of the classifier.
                t: optional timestep(s): a 0-d, (B,) or (B, 1) tensor, or a
                    Python number. Requires ``embed_dim > 0``.
                clip_feat: required when ``clip_forge_enable`` is set.

        Returns:
            The classifier head's output.
        """
        B = inputs.shape[0]
        coords = inputs[:, :self.input_dim, :].contiguous()
        features = inputs
        temb = kwargs.get('t', None)
        if temb is not None:
            # Accept Python scalars and tensors living on another device.
            t = torch.as_tensor(temb, device=inputs.device)
            if t.ndim == 0:
                # A single timestep is shared by the whole batch.
                t = t.view(1).expand(B)
            # (B, embed_dim) -> (B, embed_dim, N), broadcast over points.
            temb = self.embedf(self.get_timestep_embedding(t, inputs.device))[:, :, None].expand(
                -1, -1, inputs.shape[-1])

        style = kwargs['style']
        if self.clip_forge_enable:
            clip_feat = kwargs['clip_feat']
            assert(clip_feat is not None), 'require clip_feat as input'
            clip_feat = self.clip_forge_mapping(clip_feat)
            style = torch.cat([style, clip_feat], dim=1).contiguous()
            style = self.style_clip(style)

        coords_list, in_features_list = [], []
        for i, sa_block in enumerate(self.sa_layers):
            in_features_list.append(features)
            coords_list.append(coords)
            if i > 0 and temb is not None:
                # TODO: implement a sa_blocks forward function; check if is PVConv layer and kwargs get grid_emb, take as additional input
                features = torch.cat([features, temb], dim=1)
            features, coords, temb, _ = sa_block((features, coords, temb, style))

        # Skip features for the last FP module: all channels except the first 3
        # (matches sa_in_channels[0] in __init__).
        in_features_list[0] = inputs[:, 3:, :].contiguous()
        if self.global_att is not None:
            features = self.global_att(features)

        for fp_idx, fp_block in enumerate(self.fp_layers):
            fp_in = features if temb is None else torch.cat([features, temb], dim=1)
            features, coords, temb, _ = fp_block((
                coords_list[-1 - fp_idx], coords,
                fp_in,
                in_features_list[-1 - fp_idx], temb, style))

        for layer in self.classifier:
            if isinstance(layer, SharedMLP):
                features = layer(features, style)
            else:
                features = layer(features)
        return features
|
|
|
|
class PointTransPVC(nn.Module):
    """Point-cloud encoder producing flattened per-point means and sigmas.

    A ``PVCNN2Unet`` maps (B, N, D) points to ``2 * input_dim + 2 * zdim``
    channels per point. These are split into coordinate mean/sigma and
    (when ``zdim > 0``) latent-feature mean/sigma.
    """

    # Set-abstraction blocks: (conv_configs, sa_configs).
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (128, 128, 128))),
    ]
    # Feature-propagation blocks: (fp_configs, conv_configs).
    fp_blocks = [
        ((128, 128), (128, 3, 8)),
        ((128, 128), (128, 3, 8)),
        ((128, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, zdim, input_dim, args={}):
        """
        Args:
            zdim: number of latent feature channels per point (0 disables them).
            input_dim: number of coordinate channels per point.
            args: config exposing ``ddpm.input_dim``, ``ddpm.dropout``,
                ``latent_pts.skip_weight`` and ``latent_pts.pts_sigma_offset``.
        """
        super().__init__()
        self.zdim = zdim
        out_channels = 2 * zdim + input_dim * 2
        self.layers = PVCNN2Unet(
            out_channels,
            embed_dim=0,
            use_att=1,
            extra_feature_channels=0,
            input_dim=args.ddpm.input_dim,
            cfg=args,
            sa_blocks=self.sa_blocks,
            fp_blocks=self.fp_blocks,
            dropout=args.ddpm.dropout,
        )
        latent_cfg = args.latent_pts
        self.skip_weight = latent_cfg.skip_weight
        self.pts_sigma_offset = latent_cfg.pts_sigma_offset
        self.input_dim = input_dim

    def forward(self, inputs):
        """Encode a batch of points.

        Args:
            inputs: pair ``(x, style)``; ``x`` has shape (B, N, D) and
                ``style`` is handed to the U-Net.

        Returns:
            ``{'mu_1d': ..., 'sigma_1d': ...}``, each reshaped to (B, -1). Per
            point, the first ``input_dim`` values are coordinates (the mean adds
            ``x`` as a residual scaled against ``skip_weight``; the sigma is
            shifted by ``-pts_sigma_offset``). When ``zdim > 0`` they are
            followed by ``zdim`` latent-feature values.
        """
        x, style = inputs
        batch, _, _ = x.shape
        # The U-Net is channels-first: (B, N, D) -> (B, D, N) and back.
        net_out = self.layers(x.permute(0, 2, 1).contiguous(), style=style)
        net_out = net_out.permute(0, 2, 1).contiguous()  # B, N, C

        d = self.input_dim
        mu_pts = self.skip_weight * net_out[:, :, :d].contiguous() + x
        sigma_pts = net_out[:, :, d:2 * d].contiguous() - self.pts_sigma_offset

        if not self.zdim > 0:
            return {'mu_1d': mu_pts.view(batch, -1).contiguous(),
                    'sigma_1d': sigma_pts.view(batch, -1).contiguous()}

        mu_feat = net_out[:, :, 2 * d:-self.zdim].contiguous()
        sigma_feat = net_out[:, :, -self.zdim:].contiguous()
        mu_all = torch.cat([mu_pts, mu_feat], dim=2).view(batch, -1).contiguous()
        sigma_all = torch.cat([sigma_pts, sigma_feat], dim=2).view(batch, -1).contiguous()
        return {'mu_1d': mu_all, 'sigma_1d': sigma_all}
|
|
|
|
class LatentPointDecPVC(nn.Module):
    """Decoder that refines flattened latent points into output points.

    ``forward`` reshapes ``context`` to (B, num_points, point_dim + context_dim),
    runs it through a ``PVCNN2Unet``, and adds the U-Net output (scaled by
    ``skip_weight``) to the first ``point_dim`` channels of each latent point.
    """

    # Set-abstraction blocks: (conv_configs, sa_configs).
    sa_blocks = [
        ((32, 2, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 3, 16), (256, 0.2, 32, (64, 128))),
        ((128, 3, 8), (64, 0.4, 32, (128, 256))),
        (None, (16, 0.8, 32, (128, 128, 128))),
    ]
    # Feature-propagation blocks: (fp_configs, conv_configs).
    fp_blocks = [
        ((128, 128), (128, 3, 8)),
        ((128, 128), (128, 3, 8)),
        ((128, 128), (128, 2, 16)),
        ((128, 128, 64), (64, 2, 32)),
    ]

    def __init__(self, point_dim, context_dim, num_points=None, args={}, **kwargs):
        """
        Args:
            point_dim: number of coordinate channels per point (also the U-Net's
                number of output channels).
            context_dim: number of extra latent channels per point.
            num_points: points per shape; ``None`` reads
                ``args.data.tr_max_sample_points``.
            args: config exposing ``ddpm.input_dim``, ``ddpm.dropout`` and
                ``latent_pts.skip_weight`` (plus ``data.tr_max_sample_points``
                when ``num_points`` is None).
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.point_dim = point_dim
        logger.info('[Build Dec] point_dim={}, context_dim={}', point_dim, context_dim)
        # Per-point width of the latent points: coordinates + extra channels.
        self.context_dim = context_dim + self.point_dim
        self.num_points = args.data.tr_max_sample_points if num_points is None else num_points
        self.layers = PVCNN2Unet(
            point_dim,
            embed_dim=0,
            use_att=1,
            extra_feature_channels=context_dim,
            input_dim=args.ddpm.input_dim,
            cfg=args,
            sa_blocks=self.sa_blocks,
            fp_blocks=self.fp_blocks,
            dropout=args.ddpm.dropout,
        )
        self.skip_weight = args.latent_pts.skip_weight

    def forward(self, x, beta, context, style):
        """Decode latent points.

        Args:
            x: unused.
            beta: unused.
            context: flattened latent points of shape
                (B, num_points * (point_dim + context_dim)).
            style: shape latent handed to the U-Net.

        Returns:
            ``U-Net(latent) * skip_weight + latent[..., :point_dim]``, laid out as
            (B, num_points, C); C is expected to be ``point_dim`` since the
            U-Net is built with ``num_classes=point_dim``.
        """
        n_pts, width = self.num_points, self.context_dim
        assert(context.shape[1] == n_pts * width)
        latent = context.view(-1, n_pts, width)  # B, N, width
        base = latent[:, :, :self.point_dim]
        delta = self.layers(latent.permute(0, 2, 1).contiguous(), style=style)
        return delta.permute(0, 2, 1).contiguous() * self.skip_weight + base
|
|
|