LION/models/shapelatent_modules.py
2023-01-23 00:14:49 -05:00

55 lines
2.1 KiB
Python

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch.nn as nn
from loguru import logger
from .pvcnn2 import create_pointnet2_sa_components
# Global encoder for the VAE model (PointNet++-style set-abstraction backbone).
class PointNetPlusEncoder(nn.Module):
    """Global point-cloud encoder for the VAE model.

    Builds set-abstraction layers from ``sa_blocks`` via
    ``create_pointnet2_sa_components``, max-pools their output over the point
    axis, and projects the pooled vector with one linear layer to
    ``2 * zdim`` values, which are split into ``mu_1d`` and ``sigma_1d``.
    """

    # Per-stage configuration passed to create_pointnet2_sa_components.
    # Each entry is [first config list, second config list]; the last element
    # of the second list is a list of channel widths (see voxel_dim below).
    # NOTE(review): the meaning of each number is defined in .pvcnn2 — confirm there.
    sa_blocks = [
        [[32, 2, 32], [1024, 0.1, 32, [32, 32]]],
        [[32, 1, 16], [256, 0.2, 32, [32, 64]]]
    ]
    # Forwarded as ``force_att``. Original note: "add attention to all layers";
    # presumably a nonzero value forces attention everywhere and 0 leaves it
    # off — TODO confirm in .pvcnn2.
    force_att = 0

    def __init__(self, zdim, input_dim, extra_feature_channels=0, args=None):
        """
        Args:
            zdim: latent size; the linear head outputs ``2 * zdim`` values.
            input_dim: forwarded to ``create_pointnet2_sa_components``.
            extra_feature_channels: forwarded to
                ``create_pointnet2_sa_components``.
            args: unused; kept only for interface compatibility. Defaults to
                ``None`` rather than a shared mutable ``{}``.
        """
        super().__init__()
        layers, _sa_in_channels, channels_sa_features, _ = \
            create_pointnet2_sa_components(
                self.sa_blocks, extra_feature_channels, input_dim=input_dim,
                embed_dim=0, force_att=self.force_att,
                use_att=True, with_se=True)
        # mlp is registered before layers, as originally, so the module's
        # parameter / state_dict ordering is unchanged.
        self.mlp = nn.Linear(channels_sa_features, zdim * 2)
        self.zdim = zdim
        logger.info('[Encoder] zdim={}, out_sigma={}; force_att: {}', zdim, True, self.force_att)
        self.layers = nn.ModuleList(layers)
        # Last channel width listed in each stage's second config list.
        # NOTE(review): despite the name this reads a channel width, not an
        # obvious voxel resolution — confirm intent before relying on it.
        self.voxel_dim = [n[1][-1][-1] for n in self.sa_blocks]

    def forward(self, x):
        """Encode a batch of point clouds into two latent vectors.

        Args:
            x: tensor of shape (B, N, C); transposed to (B, C, N) before the
                layers. C is presumably 3 (xyz) — TODO confirm for other
                ``input_dim`` values.

        Returns:
            dict with keys ``'mu_1d'`` and ``'sigma_1d'``, each of shape
            (B, zdim): the first and second halves of the linear head's
            output. Whether ``sigma_1d`` is a std, log-std or log-variance is
            decided by the caller, not here.
        """
        x = x.transpose(1, 2)  # B,C,N
        # NOTE(review): all C channels serve as both initial coordinates and
        # features (no x[:, :3] slice) — confirm the layers expect this.
        xyz = x
        features = x
        for layer in self.layers:
            # Each layer consumes and returns a 3-item (features, xyz, extra) tuple.
            features, xyz, _ = layer((features, xyz, None))
        # Global max-pool over the last (point) axis: (B, D, N') -> (B, D).
        features = features.max(dim=-1).values
        features = self.mlp(features)
        mu_1d = features[:, :self.zdim]
        sigma_1d = features[:, self.zdim:]
        return {'mu_1d': mu_1d, 'sigma_1d': sigma_1d}