114 lines
4.7 KiB
Python
114 lines
4.7 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
import modules.functional as F
|
|
from modules.ball_query import BallQuery
|
|
from modules.shared_mlp import SharedMLP
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule']
|
|
|
|
|
|
class PointNetAModule(nn.Module):
    """Apply one or more shared MLPs and max-reduce each over the last axis.

    Each MLP branch is applied to the (optionally coordinate-augmented)
    features, reduced with ``max`` over the last dimension (kept as size 1),
    and the branch results are concatenated along ``dim=1``.

    Args:
        in_channels: Number of feature channels fed to the module (excluding
            the 3 coordinate channels that are appended when
            ``include_coordinates`` is true).
        out_channels: Either a single int, a list/tuple of ints (one MLP), or
            a list/tuple of lists/tuples (one MLP per inner sequence).
        include_coordinates: When true, ``coords`` is concatenated to
            ``features`` along ``dim=1`` before the MLPs are applied.

    Attributes set by the constructor:
        include_coordinates: The flag passed in.
        out_channels: Sum of the last channel count of every MLP branch.
        mlps: ``nn.ModuleList`` holding one ``SharedMLP`` per branch.
    """

    def __init__(self, in_channels, out_channels, include_coordinates=True):
        super().__init__()
        # Normalise ``out_channels`` to a sequence of per-branch channel lists.
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]]
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels]

        # Coordinates contribute 3 extra input channels when they are appended.
        mlp_in_channels = in_channels + (3 if include_coordinates else 0)
        branches = [
            SharedMLP(in_channels=mlp_in_channels, out_channels=branch_channels, dim=1)
            for branch_channels in out_channels
        ]

        self.include_coordinates = include_coordinates
        self.out_channels = sum(branch_channels[-1] for branch_channels in out_channels)
        self.mlps = nn.ModuleList(branches)

    def forward(self, inputs):
        """Aggregate features over the last axis.

        Args:
            inputs: A ``(features, coords)`` pair. Both are concatenated along
                ``dim=1`` when ``include_coordinates`` is set, so they must
                agree on every other dimension (presumably ``(B, C, N)`` and
                ``(B, 3, N)`` -- confirm against the caller).

        Returns:
            ``(aggregated_features, coords)`` where ``aggregated_features`` has
            a trailing dimension of size 1 and ``coords`` is an all-zero tensor
            of shape ``(coords.size(0), 3, 1)`` on the same device and with the
            same dtype as the input coordinates.
        """
        features, coords = inputs
        if self.include_coordinates:
            features = torch.cat([features, coords], dim=1)
        # Placeholder coordinates for the aggregated output. The dtype is taken
        # from the input so non-float32 pipelines do not silently receive
        # float32 coordinates.
        coords = torch.zeros(
            (coords.size(0), 3, 1), dtype=coords.dtype, device=coords.device
        )
        pooled = [mlp(features).max(dim=-1, keepdim=True).values for mlp in self.mlps]
        if len(pooled) > 1:
            return torch.cat(pooled, dim=1), coords
        return pooled[0], coords

    def extra_repr(self):
        """Return the extra fields shown in the module's ``repr``."""
        return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
|
|
|
|
|
|
class PointNetSAModule(nn.Module):
    """Set-abstraction module: sample centers, group neighbours, apply MLPs.

    One ``BallQuery`` grouper and one ``SharedMLP`` are created per scale
    (one scale per entry of ``radius``). In ``forward`` every scale is applied
    to the same inputs, max-reduced over the last axis, and the per-scale
    results are concatenated along ``dim=1``.

    Args:
        num_centers: Number of centers requested from
            ``F.furthest_point_sample``.
        radius: A single radius or a list/tuple of radii (one per scale).
        num_neighbors: A single neighbour count (shared by all scales) or a
            list/tuple with one count per scale.
        in_channels: Number of input feature channels (excluding the 3
            coordinate channels added when ``include_coordinates`` is true).
        out_channels: A single int, a list/tuple of ints (shared by all
            scales), or a list/tuple of lists/tuples (one per scale).
        include_coordinates: Forwarded to ``BallQuery``; also adds 3 to the
            MLP input channel count.

    Attributes set by the constructor:
        num_centers: The value passed in.
        out_channels: Sum of the last channel count of every scale.
        groupers: ``nn.ModuleList`` of ``BallQuery`` modules.
        mlps: ``nn.ModuleList`` of ``SharedMLP`` modules.
    """

    def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True):
        super().__init__()
        # Broadcast scalar arguments so there is one entry per scale.
        if not isinstance(radius, (list, tuple)):
            radius = [radius]
        if not isinstance(num_neighbors, (list, tuple)):
            num_neighbors = [num_neighbors] * len(radius)
        assert len(radius) == len(num_neighbors)
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]] * len(radius)
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels] * len(radius)
        assert len(radius) == len(out_channels)

        # Coordinates contribute 3 extra input channels when they are included.
        mlp_in_channels = in_channels + (3 if include_coordinates else 0)
        groupers, mlps = [], []
        total_out_channels = 0
        for scale_radius, scale_channels, scale_neighbors in zip(radius, out_channels, num_neighbors):
            groupers.append(
                BallQuery(radius=scale_radius, num_neighbors=scale_neighbors,
                          include_coordinates=include_coordinates)
            )
            mlps.append(
                SharedMLP(in_channels=mlp_in_channels, out_channels=scale_channels, dim=2)
            )
            total_out_channels += scale_channels[-1]

        self.num_centers = num_centers
        self.out_channels = total_out_channels
        self.groupers = nn.ModuleList(groupers)
        self.mlps = nn.ModuleList(mlps)

    def forward(self, inputs):
        """Run every scale on the inputs and merge the results.

        Args:
            inputs: A ``(features, coords, temb)`` triple.

        Returns:
            ``(features, centers_coords, temb)`` where ``features`` is the
            per-scale output max-reduced over the last axis (concatenated along
            ``dim=1`` when there is more than one scale), ``centers_coords`` is
            the result of ``F.furthest_point_sample`` and ``temb`` is the
            embedding produced by the last scale, max-reduced over the last
            axis unless its ``shape[1]`` is 0.
        """
        features, coords, temb = inputs
        centers_coords = F.furthest_point_sample(coords, self.num_centers)
        features_list = []
        # Use dedicated names inside the loop: every scale must see the
        # module's original ``features``/``temb``, not the previous scale's
        # MLP output.
        scale_temb = temb
        for grouper, mlp in zip(self.groupers, self.mlps):
            scale_features, scale_temb = mlp(grouper(coords, centers_coords, temb, features))
            features_list.append(scale_features.max(dim=-1).values)
        out_temb = scale_temb.max(dim=-1).values if scale_temb.shape[1] > 0 else scale_temb
        if len(features_list) > 1:
            # Concatenate all scales so the channel count matches
            # ``self.out_channels`` (the sum over scales).
            return torch.cat(features_list, dim=1), centers_coords, out_temb
        return features_list[0], centers_coords, out_temb

    def extra_repr(self):
        """Return the extra fields shown in the module's ``repr``."""
        return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
|
|
|
|
|
|
class PointNetFPModule(nn.Module):
    """Feature-propagation module: interpolate center features onto points.

    Center features (and the embedding ``temb``) are interpolated onto the
    point set with ``F.nearest_neighbor_interpolate``, optionally concatenated
    with the points' own features along ``dim=1``, and passed through a
    ``SharedMLP``.

    Args:
        in_channels: Input channel count handed to ``SharedMLP``.
        out_channels: Output channel specification handed to ``SharedMLP``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)

    def forward(self, inputs):
        """Propagate features from centers to points.

        Args:
            inputs: Either the 4-tuple
                ``(points_coords, centers_coords, centers_features, temb)``
                or the 5-tuple
                ``(points_coords, centers_coords, centers_features,
                points_features, temb)``.

        Returns:
            ``(mlp_output, points_coords, interpolated_temb)``.
        """
        # The variant without skip features carries four items; the length
        # check must match the unpacking below or that path can never succeed.
        if len(inputs) == 4:
            points_coords, centers_coords, centers_features, temb = inputs
            points_features = None
        else:
            points_coords, centers_coords, centers_features, points_features, temb = inputs
        interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
        interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb)
        if points_features is not None:
            # Append the points' own (skip) features along the channel axis.
            interpolated_features = torch.cat(
                [interpolated_features, points_features], dim=1
            )
        return self.mlp(interpolated_features), points_coords, interpolated_temb
|