PointMLP/classification_ScanObjectNN/models/modelelite3.py

640 lines
31 KiB
Python
Raw Normal View History

2021-10-04 07:22:15 +00:00
"""
Based on model31, different configurations for the elite version.
Based on model30, change the grouper operation to use normalization.
Based on model28, only change configurations, mainly the reducer.
Based on model27, change to x-a, reorganized structure.
Based on model25, simple LocalGrouper (not x-a), reorganized structure.
Based on model24, using ReLU to replace GELU.
Based on model22, remove attention.
Based on model21, change FPS to random sampling.
Exactly based on Model10, but ReLU to GELU.
Based on Model8, add dropout and max, avg combine.
Based on Local model, add residual connections.
The extraction is doubled for depth.
Learning Point Cloud with Progressively Local representation.
[B,3,N] - {[B,G,K,d]-[B,G,d]} - {[B,G',K,d]-[B,G',d]} -cls
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch import einsum
# from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from pointnet2_ops import pointnet2_utils
def get_activation(activation):
    """Return the activation module named by ``activation`` (case-insensitive).

    Recognised names are ``gelu``, ``rrelu``, ``selu``, ``silu``, ``hardswish``
    and ``leakyrelu``; any other string falls back to ``nn.ReLU``. Every
    activation except GELU is constructed with ``inplace=True``.
    """
    name = activation.lower()
    if name == 'gelu':
        # GELU is built without the inplace flag.
        return nn.GELU()
    inplace_factories = {
        'rrelu': nn.RReLU,
        'selu': nn.SELU,
        'silu': nn.SiLU,
        'hardswish': nn.Hardswish,
        'leakyrelu': nn.LeakyReLU,
    }
    factory = inplace_factories.get(name, nn.ReLU)
    return factory(inplace=True)
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Uses the expansion ||a - b||^2 = -2 a.b + ||a||^2 + ||b||^2, so the whole
    [N, M] table comes from a single batched matrix multiply.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    cross = torch.matmul(src, dst.transpose(1, 2))  # [B, N, M]
    src_sq = src.pow(2).sum(dim=-1).view(batch, n_src, 1)  # [B, N, 1]
    dst_sq = dst.pow(2).sum(dim=-1).view(batch, 1, n_dst)  # [B, 1, M]
    # Summed left to right: (-2 * cross + |src|^2) + |dst|^2.
    return -2 * cross + src_sq + dst_sq
def index_points(points, idx):
    """
    Gather rows of ``points`` per batch element using ``idx``.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (any [B, ...] index tensor works)
    Return:
        new_points:, indexed points data, [B, S, C] (generally [B, ..., C])
    """
    lead = idx.shape[0]
    batch_idx = torch.arange(points.shape[0], dtype=torch.long).to(points.device)
    # Reshape to [B, 1, ..., 1] and broadcast to idx's shape so that each index
    # in idx is paired with its own batch row.
    batch_idx = batch_idx.view([lead] + [1] * (idx.dim() - 1)).expand_as(idx)
    return points[batch_idx, idx, :]
def farthest_point_sample(xyz, npoint):
    """
    Iterative farthest point sampling (pure PyTorch implementation).

    Starting from a random point per batch element, repeatedly select the point
    whose squared distance to the already-selected set is largest.

    Input:
        xyz: pointcloud data, [B, N, C] (C is usually 3, but any C works)
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Running minimum squared distance to the selected set; starts very large.
    distance = torch.full((B, N), 1e10, device=device)
    # The start index is drawn on the CPU RNG and then moved to ``device``.
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the real point dimension C instead of a hard-coded 3.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.min(distance, dist)
        farthest = torch.max(distance, -1)[1]
    return centroids
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Ball query: for every query point, collect up to ``nsample`` indices of
    points whose squared distance is at most ``radius ** 2``.

    Indices are returned in ascending order. Slots that cannot be filled
    (fewer than ``nsample`` points in the ball) are padded with the first
    (smallest) in-ball index of that query. NOTE: if no point lies inside the
    ball, every slot keeps the sentinel value N, which is out of range.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    # Candidate index for every (query, point) pair.
    candidates = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    # Out-of-ball points get the sentinel N so they sort to the end.
    outside = square_distance(new_xyz, xyz) > radius ** 2
    candidates[outside] = N
    nearest = torch.sort(candidates, dim=-1).values[:, :, :nsample]
    first = nearest[:, :, :1].repeat([1, 1, nsample])
    unfilled = nearest == N
    nearest[unfilled] = first[unfilled]
    return nearest
def knn_point(nsample, xyz, new_xyz):
    """
    k-nearest-neighbour search by squared Euclidean distance.

    Input:
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]. The order of the k
            indices within a row is not guaranteed (``sorted=False``).
    """
    pairwise = square_distance(new_xyz, xyz)  # [B, S, N]
    nearest = torch.topk(pairwise, k=nsample, dim=-1, largest=False, sorted=False)
    return nearest.indices
class LocalGrouper(nn.Module):
    def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs):
        """
        Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d']
        where d' = 2*d (+3 when use_xyz).
        :param channel: feature dimension d of the input ``points``
        :param groups: groups number (anchors chosen by furthest point sampling)
        :param kneighbors: k-neighbors gathered around each anchor
        :param use_xyz: concatenate neighbour coordinates to neighbour features
        :param normalize: "center", "anchor" or None (case-insensitive); any other
            value prints a message and disables normalization
        :param kwargs: others (ignored)
        """
        super(LocalGrouper, self).__init__()
        self.groups = groups
        self.kneighbors = kneighbors
        self.use_xyz = use_xyz
        self.normalize = normalize.lower() if normalize is not None else None
        # Only complain about values that were given but not recognised; an
        # explicit None simply disables normalization.
        if self.normalize is not None and self.normalize not in ["center", "anchor"]:
            print(f"Unrecognized normalize parameter ({self.normalize}), set to None. Should be one of [center, anchor].")
            self.normalize = None
        if self.normalize is not None:
            add_channel = 3 if self.use_xyz else 0
            # Learnable per-channel affine transform applied after normalization.
            self.affine_alpha = nn.Parameter(torch.ones([1, 1, 1, channel + add_channel]))
            self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel]))

    def forward(self, xyz, points):
        """
        :param xyz: point coordinates, [B, N, 3]
        :param points: point features, [B, N, d]
        :return: new_xyz [B, groups, 3] and new_points [B, groups, k, 2d (+3 if use_xyz)]
        """
        B = xyz.shape[0]
        S = self.groups
        xyz = xyz.contiguous()  # xyz [batch, points, xyz]; contiguous input for the compiled FPS op
        # Pure-PyTorch alternatives in this file: farthest_point_sample / query_ball_point.
        fps_idx = pointnet2_utils.furthest_point_sample(xyz, self.groups).long()  # [B, npoint]
        new_xyz = index_points(xyz, fps_idx)  # [B, npoint, 3]
        new_points = index_points(points, fps_idx)  # [B, npoint, d]
        idx = knn_point(self.kneighbors, xyz, new_xyz)  # [B, npoint, k]
        grouped_xyz = index_points(xyz, idx)  # [B, npoint, k, 3]
        grouped_points = index_points(points, idx)  # [B, npoint, k, d]
        if self.use_xyz:
            grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)  # [B, npoint, k, d+3]
        if self.normalize is not None:
            if self.normalize == "center":
                # Per-group, per-channel statistics over the k neighbours.
                std, mean = torch.std_mean(grouped_points, dim=2, keepdim=True)
            elif self.normalize == "anchor":
                mean = torch.cat([new_points, new_xyz], dim=-1) if self.use_xyz else new_points
                mean = mean.unsqueeze(dim=-2)  # [B, npoint, 1, d+3]
                # NOTE(review): this is one scalar std over the whole batch, so a
                # sample's normalization depends on the other samples in the batch.
                std = torch.std(grouped_points - mean)
            grouped_points = (grouped_points - mean) / (std + 1e-5)
            grouped_points = self.affine_alpha * grouped_points + self.affine_beta
        anchors = new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)
        new_points = torch.cat([grouped_points, anchors], dim=-1)
        return new_xyz, new_points
class ConvBNReLU1D(nn.Module):
    """Conv1d (pointwise by default) followed by BatchNorm1d and an activation."""

    def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'):
        super(ConvBNReLU1D, self).__init__()
        # The activation is stored as an attribute and also reused inside ``net``
        # (the same module object).
        self.act = get_activation(activation)
        conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                         kernel_size=kernel_size, bias=bias)
        norm = nn.BatchNorm1d(out_channels)
        self.net = nn.Sequential(conv, norm, self.act)

    def forward(self, x):
        """Apply convolution, batch normalization and activation in order."""
        out = self.net(x)
        return out
class ConvBNReLURes1D(nn.Module):
    """Residual block: Conv-BN-act -> Conv-BN, added to the input, then activated.

    The hidden width is ``int(channel * res_expansion)``; input and output both
    have ``channel`` channels so the skip connection is a plain addition.
    """

    def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation='relu'):
        super(ConvBNReLURes1D, self).__init__()
        self.act = get_activation(activation)
        hidden = int(channel * res_expansion)
        self.net1 = nn.Sequential(
            nn.Conv1d(in_channels=channel, out_channels=hidden,
                      kernel_size=kernel_size, groups=groups, bias=bias),
            nn.BatchNorm1d(hidden),
            self.act,
        )
        self.net2 = nn.Sequential(
            nn.Conv1d(in_channels=hidden, out_channels=channel,
                      kernel_size=kernel_size, groups=groups, bias=bias),
            nn.BatchNorm1d(channel),
        )

    def forward(self, x):
        branch = self.net2(self.net1(x))
        return self.act(branch + x)
class PreExtraction(nn.Module):
    def __init__(self, channels, out_channels, blocks=1, groups=1, res_expansion=1, bias=True,
                 activation='relu', use_xyz=True):
        """
        Per-group feature extractor. input: [b,g,k,d]; output: [b,out_channels,g]
        :param channels: feature width the preceding grouper was built with; the
            expected input width d is 2*channels (+3 when use_xyz)
        :param out_channels: output feature width
        :param blocks: number of residual ConvBNReLURes1D blocks
        """
        super(PreExtraction, self).__init__()
        in_channels = 2 * channels + (3 if use_xyz else 0)
        self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
        self.operation = nn.Sequential(*[
            ConvBNReLURes1D(out_channels, groups=groups, res_expansion=res_expansion,
                            bias=bias, activation=activation)
            for _ in range(blocks)
        ])

    def forward(self, x):
        b, n, s, d = x.size()  # e.g. torch.Size([32, 512, 32, 6])
        # Fold the groups into the batch axis: [b*n, d, s].
        x = x.permute(0, 1, 3, 2).reshape(-1, d, s)
        x = self.operation(self.transfer(x))  # [b*n, out_channels, s]
        # Max over the s neighbours of every group.
        x = F.adaptive_max_pool1d(x, 1).view(x.size(0), -1)
        return x.reshape(b, n, -1).permute(0, 2, 1)
class PosExtraction(nn.Module):
    def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation='relu'):
        """
        Stack of ``blocks`` residual Conv1d blocks. input: [b,d,g]; output: [b,d,g]
        :param channels: feature width d (unchanged by the block)
        :param blocks: number of ConvBNReLURes1D blocks (0 gives an identity)
        """
        super(PosExtraction, self).__init__()
        layers = [
            ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion,
                            bias=bias, activation=activation)
            for _ in range(blocks)
        ]
        self.operation = nn.Sequential(*layers)

    def forward(self, x):  # [b, d, g]
        out = self.operation(x)
        return out
class modelelite3(nn.Module):
    """
    Hierarchical point-cloud classifier (elite variant 3).

    A pointwise embedding of the raw coordinates is followed by
    ``len(pre_blocks)`` stages of LocalGrouper -> PreExtraction -> PosExtraction.
    Stage i divides the number of anchor points by ``reducers[i]`` and multiplies
    the channel width by ``dim_expansion[i]``. A global max pool and an MLP head
    produce the class scores.

    Input ``x`` is [B, 3, N]; the output is [B, class_num] (no softmax applied).
    """

    def __init__(self, points=1024, class_num=40, embed_dim=64, groups=1, res_expansion=1.0,
                 activation="relu", bias=True, use_xyz=True, normalize="center",
                 dim_expansion=(2, 2, 2, 2), pre_blocks=(2, 2, 2, 2), pos_blocks=(2, 2, 2, 2),
                 k_neighbors=(32, 32, 32, 32), reducers=(2, 2, 2, 2), **kwargs):
        # Per-stage defaults are tuples (immutable); lists passed by callers work too.
        super(modelelite3, self).__init__()
        self.stages = len(pre_blocks)
        self.class_num = class_num
        self.points = points
        self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation)
        assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \
            "Please check stage number consistent for pre_blocks, pos_blocks, k_neighbors, reducers, dim_expansion."
        self.local_grouper_list = nn.ModuleList()
        self.pre_blocks_list = nn.ModuleList()
        self.pos_blocks_list = nn.ModuleList()
        last_channel = embed_dim
        anchor_points = self.points
        for i in range(self.stages):
            out_channel = last_channel * dim_expansion[i]
            anchor_points = anchor_points // reducers[i]
            # Groups the current features around anchor_points anchors: [b,g,k,d]
            self.local_grouper_list.append(
                LocalGrouper(last_channel, anchor_points, k_neighbors[i], use_xyz, normalize))
            self.pre_blocks_list.append(
                PreExtraction(last_channel, out_channel, pre_blocks[i], groups=groups,
                              res_expansion=res_expansion, bias=bias,
                              activation=activation, use_xyz=use_xyz))
            self.pos_blocks_list.append(
                PosExtraction(out_channel, pos_blocks[i], groups=groups,
                              res_expansion=res_expansion, bias=bias, activation=activation))
            last_channel = out_channel
        self.act = get_activation(activation)
        self.classifier = nn.Sequential(
            nn.Linear(last_channel, 512),
            nn.BatchNorm1d(512),
            self.act,
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            self.act,
            nn.Dropout(0.5),
            nn.Linear(256, self.class_num)
        )

    def forward(self, x):
        """
        :param x: raw point coordinates, [B, 3, N]
        :return: class scores, [B, class_num]
        """
        xyz = x.permute(0, 2, 1)  # [B, N, 3]
        x = self.embedding(x)  # B,D,N
        for grouper, pre_block, pos_block in zip(self.local_grouper_list,
                                                 self.pre_blocks_list,
                                                 self.pos_blocks_list):
            # Give xyz[b, p, 3] and fea[b, p, d], return new_xyz[b, g, 3] and new_fea[b, g, k, d]
            xyz, x = grouper(xyz, x.permute(0, 2, 1))  # [b,g,3] [b,g,k,d]
            x = pre_block(x)  # [b,d,g]
            x = pos_block(x)  # [b,d,g]
        x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1)  # global max pool -> [B, D]
        return self.classifier(x)
def _elite3_k24(num_classes, *, embed_dim, groups, res_expansion, bias, use_xyz,
                dim_expansion, pre_blocks, pos_blocks, **kwargs) -> modelelite3:
    """Build a modelelite3 with the settings shared by the ``*1`` variants.

    Fixed settings: 1024 points, ReLU, "anchor" normalization, 24 neighbours
    and a 2x point reduction at every stage.
    """
    stages = len(pre_blocks)
    return modelelite3(points=1024, class_num=num_classes, embed_dim=embed_dim, groups=groups,
                       res_expansion=res_expansion, activation="relu", bias=bias,
                       use_xyz=use_xyz, normalize="anchor", dim_expansion=dim_expansion,
                       pre_blocks=pre_blocks, pos_blocks=pos_blocks,
                       k_neighbors=[24] * stages, reducers=[2] * stages, **kwargs)


def modelelite3A1(num_classes=40, **kwargs) -> modelelite3:  # 3.48M
    return _elite3_k24(num_classes, embed_dim=64, groups=1, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3B1(num_classes=40, **kwargs) -> modelelite3:  # 2.78M
    return _elite3_k24(num_classes, embed_dim=64, groups=1, res_expansion=0.0625,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3C1(num_classes=40, **kwargs) -> modelelite3:  # 4.87M
    return _elite3_k24(num_classes, embed_dim=64, groups=1, res_expansion=0.25,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3D1(num_classes=40, **kwargs) -> modelelite3:  # 2.26M
    return _elite3_k24(num_classes, embed_dim=64, groups=8, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3E1(num_classes=40, **kwargs) -> modelelite3:  # 2.43M
    return _elite3_k24(num_classes, embed_dim=64, groups=4, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3F1(num_classes=40, **kwargs) -> modelelite3:  # 0.85M
    return _elite3_k24(num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3G1(num_classes=40, **kwargs) -> modelelite3:  # 0.85M
    # Same configuration as modelelite3F1.
    return _elite3_k24(num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3H1(num_classes=40, **kwargs) -> modelelite3:  # 3.56M
    return _elite3_k24(num_classes, embed_dim=32, groups=1, res_expansion=1,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3I1(num_classes=40, **kwargs) -> modelelite3:  # 0.90M
    return _elite3_k24(num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[3, 3, 3, 3], pos_blocks=[3, 3, 3, 3], **kwargs)


def modelelite3J1(num_classes=40, **kwargs) -> modelelite3:  # 0.93M
    return _elite3_k24(num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[3, 4, 6, 3], pos_blocks=[3, 4, 6, 3], **kwargs)


def modelelite3K1(num_classes=40, **kwargs) -> modelelite3:  # 0.93M
    return _elite3_k24(num_classes, embed_dim=32, groups=8, res_expansion=0.25,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[3, 4, 6, 3], pos_blocks=[3, 4, 6, 3], **kwargs)


def modelelite3L1(num_classes=40, **kwargs) -> modelelite3:  # 0.95M
    return _elite3_k24(num_classes, embed_dim=32, groups=8, res_expansion=0.25,
                       bias=True, use_xyz=True, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[3, 4, 6, 3], pos_blocks=[3, 4, 6, 3], **kwargs)


def modelelite3M1(num_classes=40, **kwargs) -> modelelite3:  # 0.94M
    # Three stages only.
    return _elite3_k24(num_classes, embed_dim=64, groups=4, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2],
                       pre_blocks=[4, 4, 4], pos_blocks=[4, 4, 4], **kwargs)
######## version 2: 32 neighbors with 0.5 dropout ratio ########
def _elite3_k32(num_classes, *, embed_dim, groups, res_expansion, bias, use_xyz,
                dim_expansion, pre_blocks, pos_blocks, **kwargs) -> modelelite3:
    """Build a modelelite3 with the settings shared by the ``*2`` variants.

    Fixed settings: 1024 points, ReLU, "anchor" normalization, 32 neighbours
    and a 2x point reduction at every stage.
    """
    stages = len(pre_blocks)
    return modelelite3(points=1024, class_num=num_classes, embed_dim=embed_dim, groups=groups,
                       res_expansion=res_expansion, activation="relu", bias=bias,
                       use_xyz=use_xyz, normalize="anchor", dim_expansion=dim_expansion,
                       pre_blocks=pre_blocks, pos_blocks=pos_blocks,
                       k_neighbors=[32] * stages, reducers=[2] * stages, **kwargs)


def modelelite3A2(num_classes=40, **kwargs) -> modelelite3:  # 3.48M
    return _elite3_k32(num_classes, embed_dim=64, groups=1, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3B2(num_classes=40, **kwargs) -> modelelite3:  # 2.78M
    return _elite3_k32(num_classes, embed_dim=64, groups=1, res_expansion=0.0625,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3C2(num_classes=40, **kwargs) -> modelelite3:  # 4.87M
    return _elite3_k32(num_classes, embed_dim=64, groups=1, res_expansion=0.25,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3D2(num_classes=40, **kwargs) -> modelelite3:  # 2.26M
    return _elite3_k32(num_classes, embed_dim=64, groups=8, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3E2(num_classes=40, **kwargs) -> modelelite3:  # 2.43M
    return _elite3_k32(num_classes, embed_dim=64, groups=4, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3F2(num_classes=40, **kwargs) -> modelelite3:  # 0.85M
    return _elite3_k32(num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3G2(num_classes=40, **kwargs) -> modelelite3:  # 0.85M
    # Same configuration as modelelite3F2.
    return _elite3_k32(num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3H2(num_classes=40, **kwargs) -> modelelite3:  # 3.56M
    return _elite3_k32(num_classes, embed_dim=32, groups=1, res_expansion=1,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2], **kwargs)


def modelelite3I2(num_classes=40, **kwargs) -> modelelite3:  # 0.90M
    return _elite3_k32(num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[3, 3, 3, 3], pos_blocks=[3, 3, 3, 3], **kwargs)


def modelelite3J2(num_classes=40, **kwargs) -> modelelite3:  # 0.93M
    return _elite3_k32(num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[3, 4, 6, 3], pos_blocks=[3, 4, 6, 3], **kwargs)


def modelelite3K2(num_classes=40, **kwargs) -> modelelite3:  # 0.93M
    return _elite3_k32(num_classes, embed_dim=32, groups=8, res_expansion=0.25,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[3, 4, 6, 3], pos_blocks=[3, 4, 6, 3], **kwargs)


def modelelite3L2(num_classes=40, **kwargs) -> modelelite3:  # 0.95M
    return _elite3_k32(num_classes, embed_dim=32, groups=8, res_expansion=0.25,
                       bias=True, use_xyz=True, dim_expansion=[2, 2, 2, 2],
                       pre_blocks=[3, 4, 6, 3], pos_blocks=[3, 4, 6, 3], **kwargs)


def modelelite3M2(num_classes=40, **kwargs) -> modelelite3:  # 0.94M
    # Three stages only.
    return _elite3_k32(num_classes, embed_dim=64, groups=4, res_expansion=0.125,
                       bias=False, use_xyz=False, dim_expansion=[2, 2, 2],
                       pre_blocks=[4, 4, 4], pos_blocks=[4, 4, 4], **kwargs)
def _elite3_x(num_classes, *, embed_dim, groups, res_expansion, dim_expansion,
              pre_blocks, pos_blocks, k_neighbors, **kwargs) -> modelelite3:
    """Build a modelelite3 with the settings shared by the ``X*`` variants.

    Fixed settings: 1024 points, ReLU, no bias, no xyz concatenation,
    "anchor" normalization and a 2x point reduction at every stage.
    """
    return modelelite3(points=1024, class_num=num_classes, embed_dim=embed_dim, groups=groups,
                       res_expansion=res_expansion, activation="relu", bias=False,
                       use_xyz=False, normalize="anchor", dim_expansion=dim_expansion,
                       pre_blocks=pre_blocks, pos_blocks=pos_blocks,
                       k_neighbors=k_neighbors, reducers=[2] * len(pre_blocks), **kwargs)


def modelelite3X1(num_classes=40, **kwargs) -> modelelite3:  # 1.11M
    return _elite3_x(num_classes, embed_dim=32, groups=1, res_expansion=0.125,
                     dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                     k_neighbors=[20, 20, 20, 20], **kwargs)


def modelelite3X2(num_classes=40, **kwargs) -> modelelite3:  # 0.94M
    return _elite3_x(num_classes, embed_dim=32, groups=2, res_expansion=0.125,
                     dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                     k_neighbors=[20, 20, 20, 20], **kwargs)


def modelelite3X3(num_classes=40, **kwargs) -> modelelite3:  # 0.94
    return _elite3_x(num_classes, embed_dim=32, groups=1, res_expansion=0.125,
                     dim_expansion=[2, 2, 2, 2], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 1],
                     k_neighbors=[20, 20, 20, 20], **kwargs)


def modelelite3X4(num_classes=40, **kwargs) -> modelelite3:  # 2.77
    return _elite3_x(num_classes, embed_dim=64, groups=1, res_expansion=0.125,
                     dim_expansion=[2, 2, 2, 2], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 1],
                     k_neighbors=[20, 20, 20, 20], **kwargs)


def modelelite3X5(num_classes=40, **kwargs) -> modelelite3:  # 1.59
    return _elite3_x(num_classes, embed_dim=64, groups=1, res_expansion=0.125,
                     dim_expansion=[2, 2, 2, 1], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 1],
                     k_neighbors=[20, 20, 20, 20], **kwargs)


def modelelite3X6(num_classes=40, **kwargs) -> modelelite3:  # 1.44
    return _elite3_x(num_classes, embed_dim=64, groups=4, res_expansion=0.25,
                     dim_expansion=[2, 2, 2, 1], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 1],
                     k_neighbors=[20, 20, 20, 20], **kwargs)


def modelelite3X7(num_classes=40, **kwargs) -> modelelite3:  # 1.11M
    return _elite3_x(num_classes, embed_dim=32, groups=1, res_expansion=0.125,
                     dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                     k_neighbors=[24, 24, 24, 24], **kwargs)


def modelelite3X8(num_classes=40, **kwargs) -> modelelite3:  # 0.95M
    return _elite3_x(num_classes, embed_dim=32, groups=1, res_expansion=0.125,
                     dim_expansion=[2, 2, 2, 2], pre_blocks=[1, 1, 2, 1], pos_blocks=[1, 1, 1, 1],
                     k_neighbors=[24, 24, 24, 24], **kwargs)


def modelelite3X9(num_classes=40, **kwargs) -> modelelite3:  # 1.59M
    return _elite3_x(num_classes, embed_dim=64, groups=1, res_expansion=0.125,
                     dim_expansion=[2, 2, 2, 1], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 1],
                     k_neighbors=[24, 24, 24, 24], **kwargs)


def modelelite3X10(num_classes=40, **kwargs) -> modelelite3:  # 0.72M
    return _elite3_x(num_classes, embed_dim=32, groups=1, res_expansion=0.25,
                     dim_expansion=[2, 2, 2, 1], pre_blocks=[1, 1, 2, 1], pos_blocks=[1, 1, 2, 1],
                     k_neighbors=[24, 24, 24, 24], **kwargs)


def modelelite3X11(num_classes=40, **kwargs) -> modelelite3:  # 0.98M 79/13s
    return _elite3_x(num_classes, embed_dim=64, groups=1, res_expansion=0.25,
                     dim_expansion=[1, 2, 2, 2], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 0],
                     k_neighbors=[20, 20, 20, 20], **kwargs)


def modelelite3X12(num_classes=40, **kwargs) -> modelelite3:  # 0.98M 78/13s
    return _elite3_x(num_classes, embed_dim=64, groups=1, res_expansion=0.25,
                     dim_expansion=[1, 2, 2, 2], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 0],
                     k_neighbors=[24, 24, 24, 24], **kwargs)


def modelelite3X13(num_classes=40, **kwargs) -> modelelite3:  # 0.94M 90/15s
    return _elite3_x(num_classes, embed_dim=64, groups=1, res_expansion=0.125,
                     dim_expansion=[1, 2, 2, 2], pre_blocks=[1, 2, 2, 2], pos_blocks=[1, 1, 0, 0],
                     k_neighbors=[24, 24, 24, 24], **kwargs)
if __name__ == '__main__':
    # Smoke test: one forward pass of modelelite3A1 on random input.
    # NOTE(review): LocalGrouper calls the compiled pointnet2_ops FPS op, which is
    # presumably CUDA-only, so run on the GPU whenever one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = torch.rand(2, 3, 1024).to(device)
    print("===> testing modelelite3A1 ...")
    model = modelelite3A1().to(device)
    with torch.no_grad():
        out = model(data)
    print(out.shape)