import torch
import torch.nn as nn

import modules.functional as F
from modules.ball_query import BallQuery
from modules.shared_mlp import SharedMLP

__all__ = ["PointNetAModule", "PointNetSAModule", "PointNetFPModule"]


class PointNetAModule(nn.Module):
    """Aggregate features by running shared MLP branch(es) and max-pooling.

    Args:
        in_channels: Number of input feature channels (coordinates excluded).
        out_channels: An int, a flat list of ints (one branch), or a list of
            lists (one entry per branch) forwarded to ``SharedMLP``.
        include_coordinates: When true, ``coords`` is concatenated to the
            features along dim 1 and three extra input channels are reserved.
    """

    def __init__(self, in_channels, out_channels, include_coordinates=True):
        super().__init__()
        # Normalise ``out_channels`` to a list of per-branch channel lists.
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]]
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels]

        mlp_in_channels = in_channels + (3 if include_coordinates else 0)
        mlps = []
        total_out_channels = 0
        for _out_channels in out_channels:
            mlps.append(SharedMLP(in_channels=mlp_in_channels, out_channels=_out_channels, dim=1))
            # Each branch contributes the width of its final layer.
            total_out_channels += _out_channels[-1]

        self.include_coordinates = include_coordinates
        self.out_channels = total_out_channels
        self.mlps = nn.ModuleList(mlps)

    def forward(self, inputs):
        """Pool features over the last axis.

        Args:
            inputs: ``(features, coords)`` pair.
                # assumes both are (batch, channels, num_points) -- TODO confirm

        Returns:
            ``(pooled_features, coords)`` where ``pooled_features`` keeps a
            trailing singleton axis and ``coords`` is a zero tensor of shape
            ``(coords.size(0), 3, 1)`` on the same device as the input coords.
        """
        features, coords = inputs
        if self.include_coordinates:
            features = torch.cat([features, coords], dim=1)
        # The output is a single pooled location, so its coordinates are zeros.
        coords = torch.zeros((coords.size(0), 3, 1), device=coords.device)
        if len(self.mlps) > 1:
            features_list = [mlp(features).max(dim=-1, keepdim=True).values for mlp in self.mlps]
            return torch.cat(features_list, dim=1), coords
        return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords

    def extra_repr(self):
        return f"out_channels={self.out_channels}, include_coordinates={self.include_coordinates}"


class PointNetSAModule(nn.Module):
    """Set-abstraction module: sample centers, group neighbours, MLP, max-pool.

    One ``BallQuery`` grouper and one ``SharedMLP`` are built per radius
    ("scale"). ``self.out_channels`` is the sum of the final layer widths of
    all scales, which is the channel count of the concatenated output.

    Args:
        num_centers: Number of centers passed to ``F.furthest_point_sample``.
        radius: A single radius or a list/tuple of radii (one per scale).
        num_neighbors: A single count (shared by all scales) or one per scale.
        in_channels: Number of input feature channels (coordinates excluded).
        out_channels: An int, a flat list (shared by all scales), or a list of
            lists (one per scale) forwarded to ``SharedMLP``.
        include_coordinates: Forwarded to ``BallQuery``; when true three extra
            input channels are reserved in each MLP.
    """

    def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True):
        super().__init__()
        # Broadcast scalar arguments so that every scale has its own entry.
        if not isinstance(radius, (list, tuple)):
            radius = [radius]
        if not isinstance(num_neighbors, (list, tuple)):
            num_neighbors = [num_neighbors] * len(radius)
        assert len(radius) == len(num_neighbors)
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]] * len(radius)
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels] * len(radius)
        assert len(radius) == len(out_channels)

        mlp_in_channels = in_channels + (3 if include_coordinates else 0)
        groupers, mlps = [], []
        total_out_channels = 0
        for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
            groupers.append(
                BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
            )
            mlps.append(SharedMLP(in_channels=mlp_in_channels, out_channels=_out_channels, dim=2))
            total_out_channels += _out_channels[-1]

        self.num_centers = num_centers
        self.out_channels = total_out_channels
        self.groupers = nn.ModuleList(groupers)
        self.mlps = nn.ModuleList(mlps)

    def forward(self, inputs):
        """Down-sample to ``num_centers`` centers and pool grouped features.

        Args:
            inputs: ``(features, coords, temb)`` triple.
                # assumes temb is a per-point embedding tensor -- TODO confirm

        Returns:
            ``(features, centers_coords, temb)``. With several scales the
            per-scale pooled features are concatenated along dim 1. ``temb``
            is max-pooled over its last axis unless its dim-1 size is zero,
            in which case it is returned as produced by the grouper/MLP.
        """
        features, coords, temb = inputs
        centers_coords = F.furthest_point_sample(coords, self.num_centers)

        features_list = []
        for grouper, mlp in zip(self.groupers, self.mlps):
            # Use fresh local names: every scale must group the ORIGINAL
            # ``features``/``temb``. Rebinding them here (as the code used to)
            # fed the previous scale's MLP output into the next grouper.
            scale_features, scale_temb = mlp(grouper(coords, centers_coords, temb, features))
            features_list.append(scale_features.max(dim=-1).values)

        if len(features_list) > 1:
            # Multi-scale: concatenate so the channel count matches
            # ``self.out_channels`` (previously only scale 0 was returned).
            out_features = torch.cat(features_list, dim=1)
        else:
            out_features = features_list[0]

        # The embedding of the last scale is returned.
        # NOTE(review): presumably the pooled embedding is identical across
        # scales (same centers) -- confirm against BallQuery.
        if scale_temb.shape[1] > 0:
            scale_temb = scale_temb.max(dim=-1).values
        return out_features, centers_coords, scale_temb

    def extra_repr(self):
        return f"num_centers={self.num_centers}, out_channels={self.out_channels}"


class PointNetFPModule(nn.Module):
    """Feature-propagation module: interpolate center features back to points.

    Args:
        in_channels: Input channel count forwarded to ``SharedMLP``.
        out_channels: Output channel spec forwarded to ``SharedMLP``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)

    def forward(self, inputs):
        """Interpolate features and embedding from centers onto points.

        Args:
            inputs: Either the 4-tuple
                ``(points_coords, centers_coords, centers_features, temb)`` or
                the 5-tuple ``(points_coords, centers_coords,
                centers_features, points_features, temb)``.

        Returns:
            ``(self.mlp(features), points_coords, interpolated_temb)`` where
            ``features`` is the interpolated center features, concatenated
            along dim 1 with ``points_features`` when those are supplied.
        """
        # The short form carries four items; the previous ``== 3`` guard could
        # never unpack successfully and sent valid 4-tuples to the 5-way
        # unpack below.
        if len(inputs) == 4:
            points_coords, centers_coords, centers_features, temb = inputs
            points_features = None
        else:
            points_coords, centers_coords, centers_features, points_features, temb = inputs

        interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
        interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb)
        if points_features is not None:
            # Skip connection: append the features already attached to points.
            interpolated_features = torch.cat([interpolated_features, points_features], dim=1)
        return self.mlp(interpolated_features), points_coords, interpolated_temb