PVD/modules/pointnet.py

114 lines
4.7 KiB
Python
Raw Normal View History

2021-10-19 20:54:46 +00:00
import torch
import torch.nn as nn
import modules.functional as F
from modules.ball_query import BallQuery
from modules.shared_mlp import SharedMLP

# Names exported by `from modules.pointnet import *`.
__all__ = [
    "PointNetAModule",
    "PointNetSAModule",
    "PointNetFPModule",
]


class PointNetAModule(nn.Module):
    """Apply one or more SharedMLP branches to all points and max-pool each over the last axis.

    Every branch sees the same input. Each branch output is max-pooled over its last
    axis (presumably the point axis — inputs look like (B, C, N); confirm against
    callers), and the pooled branch outputs are concatenated along dim 1.

    Args:
        in_channels: number of input feature channels, not counting coordinates.
        out_channels: layer widths of the MLP branches. An int ``c`` is treated as
            ``[[c]]``; a flat list/tuple of ints is one branch; a list/tuple of
            lists/tuples gives one branch per inner list.
        include_coordinates: when True, ``coords`` is concatenated after the features
            along dim 1 before the MLPs, so each MLP gets 3 extra input channels.
    """

    def __init__(self, in_channels, out_channels, include_coordinates=True):
        super().__init__()
        # Normalize out_channels into a list of per-branch layer-width sequences.
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]]
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels]
        mlp_in_channels = in_channels + (3 if include_coordinates else 0)
        self.include_coordinates = include_coordinates
        # Branch outputs are concatenated, so the total width is the sum of each
        # branch's last layer width.
        self.out_channels = sum(widths[-1] for widths in out_channels)
        self.mlps = nn.ModuleList(
            [SharedMLP(in_channels=mlp_in_channels, out_channels=widths, dim=1) for widths in out_channels]
        )

    def forward(self, inputs):
        """Run the branches and pool.

        Args:
            inputs: ``(features, coords)`` tuple.

        Returns:
            ``(pooled_features, new_coords)`` where ``pooled_features`` keeps a trailing
            axis of size 1 and ``new_coords`` is an all-zero tensor of shape
            ``(coords.size(0), 3, 1)`` with the same dtype and device as ``coords``.
        """
        features, coords = inputs
        if self.include_coordinates:
            features = torch.cat([features, coords], dim=1)
        # new_zeros inherits coords' dtype as well as its device; torch.zeros would
        # fall back to the default dtype (e.g. float32) even for half/double inputs.
        new_coords = coords.new_zeros((coords.size(0), 3, 1))
        pooled = [mlp(features).max(dim=-1, keepdim=True).values for mlp in self.mlps]
        if len(pooled) == 1:
            return pooled[0], new_coords
        return torch.cat(pooled, dim=1), new_coords

    def extra_repr(self):
        return f"out_channels={self.out_channels}, include_coordinates={self.include_coordinates}"


class PointNetSAModule(nn.Module):
    """Set-abstraction layer: sample centers, group neighbors per scale, MLP, max-pool.

    ``F.furthest_point_sample`` selects ``num_centers`` centers from ``coords``. For each
    scale, a ``BallQuery`` grouper gathers neighbors of every center and a ``SharedMLP``
    (``dim=2``) is applied. The result is max-pooled over the last (neighbor) axis, and
    the pooled features of all scales are concatenated along dim 1.

    An extra tensor ``temb`` (presumably a time embedding — confirm against callers) is
    grouped alongside the features and max-pooled over neighbors as well.

    Args:
        num_centers: number of centers passed to ``F.furthest_point_sample``.
        radius: a number or a list/tuple of numbers; one scale per entry.
        num_neighbors: an int shared by all scales, or a list/tuple with one entry per radius.
        in_channels: number of input feature channels, not counting coordinates.
        out_channels: an int ``c`` (each scale uses ``[c]``), a flat list/tuple of widths
            shared by every scale, or a list/tuple with one width list per scale.
        include_coordinates: forwarded to ``BallQuery``; adds 3 input channels to each MLP.

    Raises:
        AssertionError: if ``radius``, ``num_neighbors`` and ``out_channels`` disagree in length.
    """

    def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True):
        super().__init__()
        if not isinstance(radius, (list, tuple)):
            radius = [radius]
        if not isinstance(num_neighbors, (list, tuple)):
            num_neighbors = [num_neighbors] * len(radius)
        assert len(radius) == len(num_neighbors), "radius and num_neighbors must have the same length"
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]] * len(radius)
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels] * len(radius)
        assert len(radius) == len(out_channels), "radius and out_channels must have the same length"
        mlp_in_channels = in_channels + (3 if include_coordinates else 0)
        groupers, mlps = [], []
        for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
            groupers.append(
                BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
            )
            mlps.append(SharedMLP(in_channels=mlp_in_channels, out_channels=_out_channels, dim=2))
        self.num_centers = num_centers
        # Scale outputs are concatenated, so the total width is the sum of last widths.
        self.out_channels = sum(widths[-1] for widths in out_channels)
        self.groupers = nn.ModuleList(groupers)
        self.mlps = nn.ModuleList(mlps)

    def forward(self, inputs):
        """Abstract ``(features, coords, temb)`` onto sampled centers.

        Returns:
            ``(new_features, centers_coords, new_temb)``. ``new_features`` concatenates
            all scales along dim 1. ``new_temb`` is the first scale's grouped ``temb``
            max-pooled over neighbors; it is returned unpooled when it has 0 channels.
        """
        features, coords, temb = inputs
        centers_coords = F.furthest_point_sample(coords, self.num_centers)
        features_list = []
        grouped_temb = None
        for grouper, mlp in zip(self.groupers, self.mlps):
            # Fresh loop-local names: every scale must group the ORIGINAL point features
            # and temb. Reassigning `features`/`temb` here would feed one scale's grouped
            # output into the next scale's grouper.
            scale_features, scale_temb = mlp(grouper(coords, centers_coords, temb, features))
            features_list.append(scale_features.max(dim=-1).values)
            if grouped_temb is None:
                grouped_temb = scale_temb
        new_features, new_temb = self._pool_outputs(features_list, grouped_temb)
        return new_features, centers_coords, new_temb

    @staticmethod
    def _pool_outputs(features_list, grouped_temb):
        """Concatenate per-scale pooled features on dim 1 and max-pool ``grouped_temb``.

        A single scale is returned as-is (no copy). ``grouped_temb`` is max-pooled over
        its last axis unless its dim 1 is empty, in which case it is returned unchanged.
        """
        if len(features_list) > 1:
            merged = torch.cat(features_list, dim=1)
        else:
            merged = features_list[0]
        if grouped_temb.shape[1] > 0:
            grouped_temb = grouped_temb.max(dim=-1).values
        return merged, grouped_temb

    def extra_repr(self):
        return f"num_centers={self.num_centers}, out_channels={self.out_channels}"


class PointNetFPModule(nn.Module):
    """Feature-propagation layer that also propagates ``temb``.

    Center features and ``temb`` are interpolated onto ``points_coords`` with
    ``F.nearest_neighbor_interpolate``. Optional skip features of the points are
    concatenated after the interpolated features along dim 1, and a ``SharedMLP``
    (``dim=1``) is applied to the result.

    Args:
        in_channels: MLP input channels, i.e. the interpolated center-feature channels
            plus the skip point-feature channels when those are supplied.
        out_channels: forwarded to ``SharedMLP``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)

    @staticmethod
    def _unpack_inputs(inputs):
        """Normalize ``inputs`` to a 5-tuple.

        Accepts ``(points_coords, centers_coords, centers_features, temb)``, in which case
        ``points_features`` is ``None``, or
        ``(points_coords, centers_coords, centers_features, points_features, temb)``.

        Returns:
            ``(points_coords, centers_coords, centers_features, points_features, temb)``.

        Raises:
            ValueError: from tuple unpacking, if ``inputs`` has any other length.
        """
        # The temb-free 4-tuple is the "no skip features" form. A length-3 check
        # would never match it and would send it into the 5-value unpack.
        if len(inputs) == 4:
            points_coords, centers_coords, centers_features, temb = inputs
            points_features = None
        else:
            points_coords, centers_coords, centers_features, points_features, temb = inputs
        return points_coords, centers_coords, centers_features, points_features, temb

    def forward(self, inputs):
        """Propagate center features/temb to the points.

        Returns:
            ``(mlp_output, points_coords, interpolated_temb)``.
        """
        points_coords, centers_coords, centers_features, points_features, temb = self._unpack_inputs(inputs)
        interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
        interpolated_temb = F.nearest_neighbor_interpolate(points_coords, centers_coords, temb)
        if points_features is not None:
            # Skip connection: append the points' own features after the interpolated ones.
            interpolated_features = torch.cat([interpolated_features, points_features], dim=1)
        return self.mlp(interpolated_features), points_coords, interpolated_temb