# LION/models/pvcnn2_ada.py

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
"""
copied and modified from source:
https://github.com/alexzhou907/PVD/blob/9747265a5f141e5546fd4f862bfa66aa59f1bd33/model/pvcnn_generation.py
and functions under
https://github.com/alexzhou907/PVD/tree/9747265a5f141e5546fd4f862bfa66aa59f1bd33/modules
"""
import copy
import functools
from loguru import logger
from einops import rearrange
import torch.nn as nn
import torch
import numpy as np
import third_party.pvcnn.functional as F
# from utils.checker import *
from torch.cuda.amp import autocast, GradScaler, custom_fwd, custom_bwd
from .adagn import AdaGN
import os
quiet = int(os.environ.get('quiet', 0))


class SE3d(nn.Module):
    def __init__(self, channel, reduction=8):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )
        self.channel = channel

    def __repr__(self):
        return f"SE({self.channel}, {self.channel})"

    def forward(self, inputs):
        return inputs * self.fc(inputs.mean(-1).mean(-1).mean(-1)).view(inputs.shape[0], inputs.shape[1], 1, 1, 1)
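

# Minimal usage sketch (not part of the original file): SE3d pools a voxel
# grid over its spatial dims and rescales channels, squeeze-and-excitation
# style. `_demo_se3d` is a hypothetical helper added for illustration only.
def _demo_se3d():
    se = SE3d(channel=32, reduction=8)
    vox = torch.randn(2, 32, 8, 8, 8)  # (B, C, R, R, R) voxel features
    out = se(vox)                      # same shape, channel-wise rescaled
    assert out.shape == vox.shape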


class LinearAttention(nn.Module):
    """
    copied and modified from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L159
    """
    def __init__(self, dim, heads=4, dim_head=32, verbose=True):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        '''
        Args:
            x: torch.tensor (B,C,N), C=num-channels, N=num-points
        Returns:
            out: torch.tensor (B,C,N)
        '''
        x = x.unsqueeze(-1)  # add w dimension
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        out = self.to_out(out)
        out = out.squeeze(-1)  # B,C,N,1 -> B,C,N
        return out
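

# Minimal usage sketch (not part of the original file): linear attention over
# a point cloud treated as a (B, C, N) tensor. `_demo_linear_attention` is a
# hypothetical helper added for illustration only.
def _demo_linear_attention():
    attn = LinearAttention(dim=64, heads=4, dim_head=32)
    x = torch.randn(2, 64, 1024)  # (B, C, N) per-point features
    out = attn(x)                 # (B, C, N), same shape as the input
    assert out.shape == x.shape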


def swish(input):
    return input * torch.sigmoid(input)


class Swish(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return swish(input)


class BallQuery(nn.Module):
    def __init__(self, radius, num_neighbors, include_coordinates=True):
        super().__init__()
        self.radius = radius
        self.num_neighbors = num_neighbors
        self.include_coordinates = include_coordinates

    @custom_bwd
    def backward(self, *args, **kwargs):
        return super().backward(*args, **kwargs)

    @custom_fwd(cast_inputs=torch.float32)
    def forward(self, points_coords, centers_coords, points_features=None):
        # input: BCN, BCN
        # neighbor_features: B,D(+3),Ncenter
        points_coords = points_coords.contiguous()
        centers_coords = centers_coords.contiguous()
        neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
        neighbor_coordinates = F.grouping(points_coords, neighbor_indices)
        neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
        if points_features is None:
            assert self.include_coordinates, 'No Features For Grouping'
            neighbor_features = neighbor_coordinates
        else:
            neighbor_features = F.grouping(points_features, neighbor_indices)
            if self.include_coordinates:
                neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1)
        return neighbor_features

    def extra_repr(self):
        return 'radius={}, num_neighbors={}{}'.format(
            self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
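

# Minimal shape sketch (not part of the original file): BallQuery groups the
# `num_neighbors` points within `radius` of each center. It relies on the
# compiled third_party.pvcnn CUDA ops, so this hypothetical `_demo_ball_query`
# helper assumes a GPU build of those extensions.
def _demo_ball_query():
    grouper = BallQuery(radius=0.2, num_neighbors=16, include_coordinates=True)
    points = torch.rand(2, 3, 1024, device='cuda')   # (B, 3, N)
    centers = points[:, :, :256]                     # (B, 3, M) query centers
    feats = torch.randn(2, 32, 1024, device='cuda')  # (B, C, N)
    out = grouper(points, centers, feats)            # (B, C + 3, M, num_neighbors)
    assert out.shape == (2, 35, 256, 16)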


class SharedMLP(nn.Module):
    def __init__(self, in_channels, out_channels, dim=1, cfg={}):
        assert (len(cfg) > 0), cfg
        super().__init__()
        if dim == 1:
            conv = nn.Conv1d
        else:
            conv = nn.Conv2d
        bn = functools.partial(AdaGN, dim, cfg)
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [out_channels]
        layers = []
        for oc in out_channels:
            layers.append(conv(in_channels, oc, 1))
            layers.append(bn(oc))
            layers.append(Swish())
            in_channels = oc
        self.layers = nn.ModuleList(layers)

    def forward(self, *inputs):
        if len(inputs) == 1 and len(inputs[0]) == 4:
            # unpack when SharedMLP is the first layer and receives the
            # (features, coords, time_emb, style) tuple as a single argument
            inputs = inputs[0]
        if len(inputs) == 1:
            raise NotImplementedError
        elif len(inputs) == 4:
            x, _, _, style = inputs  # (features, coords, time_emb, style)
            for l in self.layers:
                if isinstance(l, AdaGN):
                    x = l(x, style)
                else:
                    x = l(x)
            return (x, *inputs[1:])
        elif len(inputs) == 2:
            x, style = inputs
            for l in self.layers:
                if isinstance(l, AdaGN):
                    x = l(x, style)
                else:
                    x = l(x)
            return x
        else:
            raise NotImplementedError
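

# Minimal usage sketch (not part of the original file): SharedMLP stacks 1x1
# convs with AdaGN conditioning on a style latent. The cfg schema is defined
# by AdaGN in ./adagn.py and is not reproduced here, so this hypothetical
# `_demo_shared_mlp` helper takes cfg as an argument; the style width must
# match whatever cfg specifies.
def _demo_shared_mlp(cfg, style_dim):
    mlp = SharedMLP(16, [32, 64], dim=1, cfg=cfg)
    x = torch.randn(2, 16, 1024)       # (B, C_in, N)
    style = torch.randn(2, style_dim)  # global latent consumed by AdaGN
    out = mlp(x, style)                # (B, 64, N)
    return out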


class Voxelization(nn.Module):
    def __init__(self, resolution, normalize=True, eps=0):
        super().__init__()
        self.r = int(resolution)
        self.normalize = normalize
        self.eps = eps

    def forward(self, features, coords):
        # features: B,D,N
        # coords: B,3,N
        coords = coords.detach()
        norm_coords = coords - coords.mean(2, keepdim=True)
        if self.normalize:
            norm_coords = norm_coords / (norm_coords.norm(
                dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 +
                self.eps) + 0.5
        else:
            norm_coords = (norm_coords + 1) / 2.0
        norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
        vox_coords = torch.round(norm_coords).to(torch.int32)
        if features is None:
            return features, norm_coords
        return F.avg_voxelize(features, vox_coords, self.r), norm_coords

    def extra_repr(self):
        return 'resolution={}{}'.format(
            self.r,
            ', normalized eps = {}'.format(self.eps) if self.normalize else '')
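

# Minimal usage sketch (not part of the original file): with features=None the
# module only computes normalized voxel coordinates, which runs on CPU; the
# features path additionally calls the compiled F.avg_voxelize op.
# `_demo_voxelization` is a hypothetical helper added for illustration only.
def _demo_voxelization():
    vox = Voxelization(resolution=8)
    coords = torch.rand(2, 3, 1024) * 2 - 1  # (B, 3, N) in [-1, 1]
    _, norm_coords = vox(None, coords)       # (B, 3, N) in [0, r - 1]
    assert norm_coords.min() >= 0 and norm_coords.max() <= 7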


class PVConv(nn.Module):
    def __init__(self, in_channels, out_channels,
                 kernel_size, resolution,
                 normalize=1, eps=0, with_se=False,
                 add_point_feat=True, attention=False,
                 dropout=0.1, verbose=True,
                 cfg={}):
        super().__init__()
        assert (len(cfg) > 0), cfg
        self.resolution = resolution
        self.voxelization = Voxelization(resolution,
                                         normalize=normalize,
                                         eps=eps)
        # For each PVConv we use (Conv3d, GroupNorm(8), Swish, dropout, Conv3d, GroupNorm(8), Attention)
        NormLayer = functools.partial(AdaGN, 3, cfg)
        voxel_layers = [
            nn.Conv3d(in_channels,
                      out_channels,
                      kernel_size, stride=1,
                      padding=kernel_size // 2),
            NormLayer(out_channels),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv3d(out_channels, out_channels,
                      kernel_size, stride=1,
                      padding=kernel_size // 2),
            NormLayer(out_channels)
        ]
        if with_se:
            voxel_layers.append(SE3d(out_channels))
        self.voxel_layers = nn.ModuleList(voxel_layers)
        if attention:
            self.attn = LinearAttention(out_channels, verbose=verbose)
        else:
            self.attn = None
        if add_point_feat:
            self.point_features = SharedMLP(in_channels, out_channels, cfg=cfg)
        self.add_point_feat = add_point_feat

    def forward(self, inputs):
        '''
        Args:
            inputs: tuple of features and coords
                features: B,feat-dim,num-points
                coords: B,3,num-points
                time_emb: B,D; time embedding
                style: B,D; global latent
        Returns:
            fused_features: in (B,out-feat-dim,num-points)
            coords: in (B, 3 or 6, num_points); same as the input coords
        '''
        features = inputs[0]
        coords_input = inputs[1]
        time_emb = inputs[2]
        style = inputs[3]
        if coords_input.shape[1] > 3:
            coords = coords_input[:, :3]
        else:
            coords = coords_input
        assert (features.shape[0] == coords.shape[0]
                ), f'get feat: {features.shape} and {coords.shape}'
        assert (features.shape[2] == coords.shape[2]
                ), f'get feat: {features.shape} and {coords.shape}'
        assert (coords.shape[1] == 3
                ), f'expect coords: B,3,Npoint, get: {coords.shape}'
        # features: B,D,N; point_features
        # coords: B,3,N
        voxel_features_4d, voxel_coords = self.voxelization(features, coords)
        r = self.resolution
        for voxel_layer in self.voxel_layers:
            if isinstance(voxel_layer, AdaGN):
                voxel_features_4d = voxel_layer(voxel_features_4d, style)
            else:
                voxel_features_4d = voxel_layer(voxel_features_4d)
        voxel_features = F.trilinear_devoxelize(voxel_features_4d, voxel_coords,
                                                r, self.training)
        fused_features = voxel_features
        if self.add_point_feat:
            fused_features = fused_features + self.point_features(features, style)
        if self.attn is not None:
            fused_features = self.attn(fused_features)
        return fused_features, coords_input, time_emb, style
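

# Minimal usage sketch (not part of the original file): PVConv fuses a voxel
# branch with a per-point SharedMLP branch and passes the 4-tuple through.
# It needs the compiled third_party.pvcnn CUDA ops plus an AdaGN cfg (see
# ./adagn.py), so this hypothetical `_demo_pvconv` helper takes cfg and the
# style width as arguments and assumes a GPU build.
def _demo_pvconv(cfg, style_dim):
    conv = PVConv(32, 64, kernel_size=3, resolution=16, cfg=cfg).cuda()
    feats = torch.randn(2, 32, 1024, device='cuda')  # (B, C_in, N)
    coords = torch.rand(2, 3, 1024, device='cuda')   # (B, 3, N)
    time_emb = torch.randn(2, 64, 1024, device='cuda')
    style = torch.randn(2, style_dim, device='cuda')
    out, coords_out, _, _ = conv((feats, coords, time_emb, style))
    assert out.shape == (2, 64, 1024)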


class PointNetAModule(nn.Module):
    def __init__(self, in_channels, out_channels, include_coordinates=True, cfg={}):
        super().__init__()
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]]
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels]
        mlps = []
        total_out_channels = 0
        for _out_channels in out_channels:
            mlps.append(
                SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
                          out_channels=_out_channels, dim=1, cfg=cfg)
            )
            total_out_channels += _out_channels[-1]
        self.include_coordinates = include_coordinates
        self.out_channels = total_out_channels
        self.mlps = nn.ModuleList(mlps)

    def forward(self, inputs):
        features, coords, time_emb, style = inputs
        if self.include_coordinates:
            features = torch.cat([features, coords], dim=1)
        coords = torch.zeros((coords.size(0), 3, 1), device=coords.device)
        # note: unlike the other modules, style is not propagated further
        if len(self.mlps) > 1:
            features_list = []
            for mlp in self.mlps:
                features_list.append(mlp(features, style).max(dim=-1, keepdim=True).values)
            return torch.cat(features_list, dim=1), coords, time_emb
        else:
            return self.mlps[0](features, style).max(dim=-1, keepdim=True).values, coords, time_emb

    def extra_repr(self):
        return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
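

# Minimal usage sketch (not part of the original file): the A-module pools the
# whole cloud into a single global feature, so coords collapse to (B, 3, 1)
# and the returned tuple drops style. `_demo_pointnet_a` is a hypothetical
# helper and needs an AdaGN cfg with a matching style width.
def _demo_pointnet_a(cfg, style_dim):
    mod = PointNetAModule(in_channels=32, out_channels=[64, 128], cfg=cfg)
    feats = torch.randn(2, 32, 1024)
    coords = torch.rand(2, 3, 1024)
    style = torch.randn(2, style_dim)
    out, coords_out, _ = mod((feats, coords, None, style))
    assert out.shape == (2, 128, 1) and coords_out.shape == (2, 3, 1)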


class PointNetSAModule(nn.Module):
    def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True,
                 cfg={}):
        super().__init__()
        if not isinstance(radius, (list, tuple)):
            radius = [radius]
        if not isinstance(num_neighbors, (list, tuple)):
            num_neighbors = [num_neighbors] * len(radius)
        assert len(radius) == len(num_neighbors)
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]] * len(radius)
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels] * len(radius)
        assert len(radius) == len(out_channels)
        groupers, mlps = [], []
        total_out_channels = 0
        for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
            groupers.append(
                BallQuery(radius=_radius, num_neighbors=_num_neighbors,
                          include_coordinates=include_coordinates)
            )
            mlps.append(
                SharedMLP(in_channels=in_channels + (3 if include_coordinates else 0),
                          out_channels=_out_channels, dim=2, cfg=cfg)
            )
            total_out_channels += _out_channels[-1]
        self.num_centers = num_centers
        self.out_channels = total_out_channels
        self.groupers = nn.ModuleList(groupers)
        self.mlps = nn.ModuleList(mlps)

    def forward(self, inputs):
        features = inputs[0]
        coords = inputs[1]  # B3N
        style = inputs[3]
        if coords.shape[1] > 3:
            coords = coords[:, :3]
        centers_coords = F.furthest_point_sample(coords, self.num_centers)
        # centers_coords: B,D,N
        S = centers_coords.shape[-1]
        time_emb = inputs[2]
        time_emb = time_emb[:, :, :S] if \
            time_emb is not None and type(time_emb) is not dict \
            else time_emb
        features_list = []
        for grouper, mlp in zip(self.groupers, self.mlps):
            grouper_output = grouper(coords, centers_coords, features)
            features_list.append(
                mlp(grouper_output, style
                    ).max(dim=-1).values
            )
        if len(features_list) > 1:
            return torch.cat(features_list, dim=1), centers_coords, time_emb, style
        else:
            return features_list[0], centers_coords, time_emb, style

    def extra_repr(self):
        return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
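

# Minimal usage sketch (not part of the original file): the SA-module samples
# num_centers points by furthest-point sampling, ball-queries each center, and
# max-pools a SharedMLP over every neighborhood. It depends on the compiled
# third_party.pvcnn CUDA ops and an AdaGN cfg, so this hypothetical
# `_demo_pointnet_sa` helper assumes a GPU build.
def _demo_pointnet_sa(cfg, style_dim):
    mod = PointNetSAModule(num_centers=256, radius=0.2, num_neighbors=16,
                           in_channels=32, out_channels=[64, 128], cfg=cfg).cuda()
    feats = torch.randn(2, 32, 1024, device='cuda')
    coords = torch.rand(2, 3, 1024, device='cuda')
    time_emb = torch.randn(2, 64, 1024, device='cuda')
    style = torch.randn(2, style_dim, device='cuda')
    out, centers, t, s = mod((feats, coords, time_emb, style))
    assert out.shape == (2, 128, 256) and centers.shape == (2, 3, 256)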


class PointNetFPModule(nn.Module):
    def __init__(self, in_channels, out_channels, cfg={}):
        super().__init__()
        self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1, cfg=cfg)

    def forward(self, inputs):
        if len(inputs) == 5:
            points_coords, centers_coords, centers_features, time_emb, style = inputs
            points_features = None
        elif len(inputs) == 6:
            points_coords, centers_coords, centers_features, points_features, time_emb, style = inputs
        else:
            raise NotImplementedError
        interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
        if points_features is not None:
            interpolated_features = torch.cat(
                [interpolated_features, points_features], dim=1
            )
        if time_emb is not None:
            B, D, S = time_emb.shape
            N = points_coords.shape[-1]
            time_emb = time_emb[:, :, 0:1].expand(-1, -1, N)
        return self.mlp(interpolated_features, style), points_coords, time_emb, style
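

# Minimal usage sketch (not part of the original file): the FP-module
# upsamples coarse center features back onto the dense points via
# nearest-neighbor interpolation (a compiled third_party.pvcnn op), optionally
# fusing skip features via the 6-tuple form. `_demo_pointnet_fp` is a
# hypothetical helper and assumes a GPU build plus an AdaGN cfg.
def _demo_pointnet_fp(cfg, style_dim):
    mod = PointNetFPModule(in_channels=128 + 32, out_channels=(64,), cfg=cfg).cuda()
    points = torch.rand(2, 3, 1024, device='cuda')  # dense coords
    centers = points[:, :, :256]                    # coarse coords
    center_feats = torch.randn(2, 128, 256, device='cuda')
    skip_feats = torch.randn(2, 32, 1024, device='cuda')
    time_emb = torch.randn(2, 64, 256, device='cuda')
    style = torch.randn(2, style_dim, device='cuda')
    out, _, t, _ = mod((points, centers, center_feats, skip_feats, time_emb, style))
    assert out.shape == (2, 64, 1024) and t.shape == (2, 64, 1024)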


def _linear_gn_relu(in_channels, out_channels, cfg=None):
    # cfg is accepted (and ignored) so the dim==1 branch of
    # create_mlp_components can call this interchangeably with SharedMLP
    return nn.Sequential(nn.Linear(in_channels, out_channels), nn.GroupNorm(8, out_channels), Swish())


def create_mlp_components(in_channels, out_channels, classifier=False, dim=2, width_multiplier=1, cfg={}):
    r = width_multiplier

    if dim == 1:
        block = _linear_gn_relu
    else:
        block = SharedMLP
    if not isinstance(out_channels, (list, tuple)):
        out_channels = [out_channels]
    if len(out_channels) == 0 or (len(out_channels) == 1 and out_channels[0] is None):
        return nn.Sequential(), in_channels, in_channels
    layers = []
    for oc in out_channels[:-1]:
        if oc < 1:
            # fractional entries are interpreted as dropout probabilities
            layers.append(nn.Dropout(oc))
        else:
            oc = int(r * oc)
            layers.append(block(in_channels, oc, cfg=cfg))
            in_channels = oc
    if dim == 1:
        if classifier:
            layers.append(nn.Linear(in_channels, out_channels[-1]))
        else:
            layers.append(_linear_gn_relu(in_channels, int(r * out_channels[-1])))
    else:
        if classifier:
            layers.append(nn.Conv1d(in_channels, out_channels[-1], 1))
        else:
            layers.append(SharedMLP(in_channels, int(r * out_channels[-1]), cfg=cfg))
    return layers, out_channels[-1] if classifier else int(r * out_channels[-1])
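

# Minimal usage sketch (not part of the original file): builds a small
# classifier head. Fractional out_channels entries become Dropout layers and
# the final layer is a 1x1 Conv1d when classifier=True. `_demo_mlp_head` is a
# hypothetical helper; the hidden SharedMLP still needs an AdaGN cfg.
def _demo_mlp_head(cfg):
    layers, out_ch = create_mlp_components(
        in_channels=256, out_channels=[128, 0.5, 3],
        classifier=True, dim=2, cfg=cfg)
    # layers is a plain list; SharedMLP entries expect (x, style) at call
    # time, so the caller decides how to wire them together.
    assert out_ch == 3 and isinstance(layers[-1], nn.Conv1d)
    return layers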


def create_pointnet2_sa_components(sa_blocks, extra_feature_channels,
                                   input_dim=3,
                                   embed_dim=64, use_att=False, force_att=0,
                                   dropout=0.1, with_se=False, normalize=True, eps=0, has_temb=1,
                                   width_multiplier=1, voxel_resolution_multiplier=1, verbose=True,
                                   cfg={}):
    """
    Returns:
        sa_layers: list of set-abstraction stages (one nn.Module per stage)
        sa_in_channels: per-stage input channel counts (consumed by the FP path)
        in_channels: the last output channels of the sa blocks
        num_centers: centers remaining after the last stage (1 if global)
    """
    assert (len(cfg) > 0), cfg
    r, vr = width_multiplier, voxel_resolution_multiplier
    in_channels = extra_feature_channels + input_dim

    sa_layers, sa_in_channels = [], []
    c = 0
    num_centers = None
    for conv_configs, sa_configs in sa_blocks:
        k = 0
        sa_in_channels.append(in_channels)
        blocks = []
        if conv_configs is not None:
            out_channels, num_blocks, voxel_resolution = conv_configs
            out_channels = int(r * out_channels)
            for p in range(num_blocks):
                attention = ((c + 1) % 2 == 0 and use_att and p == 0) or (force_att and c > 0)
                if voxel_resolution is None:
                    block = SharedMLP
                else:
                    block = functools.partial(
                        PVConv, kernel_size=3,
                        resolution=int(vr * voxel_resolution), attention=attention,
                        dropout=dropout,
                        with_se=with_se,  # with_se_relu=True,
                        normalize=normalize, eps=eps, verbose=verbose, cfg=cfg)
                if c == 0:
                    blocks.append(block(in_channels, out_channels, cfg=cfg))
                elif k == 0:
                    blocks.append(block(in_channels + embed_dim * has_temb, out_channels, cfg=cfg))
                in_channels = out_channels
                k += 1
            extra_feature_channels = in_channels
        if sa_configs is not None:
            num_centers, radius, num_neighbors, out_channels = sa_configs
            _out_channels = []
            for oc in out_channels:
                if isinstance(oc, (list, tuple)):
                    _out_channels.append([int(r * _oc) for _oc in oc])
                else:
                    _out_channels.append(int(r * oc))
            out_channels = _out_channels
            if num_centers is None:
                block = PointNetAModule
            else:
                block = functools.partial(PointNetSAModule, num_centers=num_centers, radius=radius,
                                          num_neighbors=num_neighbors)
            blocks.append(block(cfg=cfg,
                                in_channels=extra_feature_channels + (embed_dim * has_temb if k == 0 else 0),
                                out_channels=out_channels,
                                include_coordinates=True))
            in_channels = extra_feature_channels = blocks[-1].out_channels
        c += 1
        if len(blocks) == 1:
            sa_layers.append(blocks[0])
        else:
            sa_layers.append(nn.Sequential(*blocks))

    return sa_layers, sa_in_channels, in_channels, 1 if num_centers is None else num_centers
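

# Minimal configuration sketch (not part of the original file): the per-stage
# tuple layout is inferred from the unpacking above; the concrete numbers in
# this hypothetical `_demo_sa_components` helper are illustrative only.
def _demo_sa_components(cfg):
    sa_blocks = [
        # (conv_configs, sa_configs) where
        # conv_configs = (out_channels, num_blocks, voxel_resolution) or None
        # sa_configs   = (num_centers, radius, num_neighbors, out_channels) or None
        ((32, 1, 32), (1024, 0.1, 32, (32, 64))),
        ((64, 1, 16), (256, 0.2, 32, (64, 128))),
    ]
    layers, sa_in_channels, out_ch, num_centers = create_pointnet2_sa_components(
        sa_blocks, extra_feature_channels=0, cfg=cfg)
    return layers, sa_in_channels, out_ch, num_centers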


def create_pointnet2_fp_modules(fp_blocks, in_channels, sa_in_channels, embed_dim=64, use_att=False,
                                dropout=0.1, has_temb=1,
                                with_se=False, normalize=True, eps=0,
                                width_multiplier=1, voxel_resolution_multiplier=1,
                                verbose=True, cfg={}):
    assert (len(cfg) > 0), cfg
    r, vr = width_multiplier, voxel_resolution_multiplier

    fp_layers = []
    c = 0
    for fp_idx, (fp_configs, conv_configs) in enumerate(fp_blocks):
        blocks = []
        out_channels = tuple(int(r * oc) for oc in fp_configs)
        blocks.append(
            PointNetFPModule(
                in_channels=in_channels + sa_in_channels[-1 - fp_idx] + embed_dim * has_temb,
                out_channels=out_channels,
                cfg=cfg)
        )
        in_channels = out_channels[-1]

        if conv_configs is not None:
            out_channels, num_blocks, voxel_resolution = conv_configs
            out_channels = int(r * out_channels)
            for p in range(num_blocks):
                attention = (c + 1) % 2 == 0 and c < len(blocks) - 1 and use_att and p == 0
                if voxel_resolution is None:
                    block = functools.partial(SharedMLP, cfg=cfg)
                else:
                    block = functools.partial(PVConv, kernel_size=3,
                                              resolution=int(vr * voxel_resolution), attention=attention,
                                              dropout=dropout,
                                              with_se=with_se,  # with_se_relu=True,
                                              normalize=normalize, eps=eps,
                                              verbose=verbose,
                                              cfg=cfg)
                blocks.append(block(in_channels, out_channels))
                in_channels = out_channels

        if len(blocks) == 1:
            fp_layers.append(blocks[0])
        else:
            fp_layers.append(nn.Sequential(*blocks))

        c += 1

    return fp_layers, in_channels
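

# Minimal configuration sketch (not part of the original file): each FP entry
# is (fp_configs, conv_configs), where fp_configs lists the FP-module MLP
# widths and conv_configs mirrors the SA side. The numbers in this
# hypothetical `_demo_fp_modules` helper are illustrative only; in_channels
# and sa_in_channels should come from create_pointnet2_sa_components.
def _demo_fp_modules(cfg, in_channels, sa_in_channels):
    fp_blocks = [
        ((128, 128), (128, 1, 16)),
        ((64, 64), (64, 1, 32)),
    ]
    fp_layers, out_ch = create_pointnet2_fp_modules(
        fp_blocks, in_channels, sa_in_channels, cfg=cfg)
    return fp_layers, out_ch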