"""
Based on model31, different configurations for the elite version.
Based on model30, change the grouper operation by normalization.
Based on model28, only change configurations, mainly the reducer.
Based on model27, change to x-a, reorganized structure
Based on model25, simple LocalGrouper (not x-a), reorganized structure
Based on model24, using ReLU to replace GELU
Based on model22, remove attention
Based on model21, change FPS to random sampling.
Exactly based on Model10, but ReLU to GeLU
Based on Model8, add dropout and max, avg combine.
Based on Local model, add residual connections.
The extraction is doubled for depth.

Learning Point Cloud with Progressively Local representation.
[B,3,N] - {[B,G,K,d]-[B,G,d]} - {[B,G',K,d]-[B,G',d]} -cls
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

# from torch import einsum
# from einops import rearrange, repeat
try:
    # Not used by anything in this module; guarded so that the module stays
    # importable when einops is not installed.
    from einops.layers.torch import Rearrange
except ImportError:
    Rearrange = None

try:
    # Compiled farthest-point-sampling kernels. When the extension is missing
    # LocalGrouper falls back to the pure-PyTorch farthest_point_sample below.
    from pointnet2_ops import pointnet2_utils
except ImportError:
    pointnet2_utils = None


def get_activation(activation):
    """
    Build an activation module from its (case-insensitive) name.

    Recognised names: gelu, rrelu, selu, silu, hardswish, leakyrelu.
    Any other name yields nn.ReLU. Every activation except GELU is created
    with inplace=True.
    """
    name = activation.lower()
    if name == 'gelu':
        return nn.GELU()
    inplace_activations = {
        'rrelu': nn.RReLU,
        'selu': nn.SELU,
        'silu': nn.SiLU,
        'hardswish': nn.Hardswish,
        'leakyrelu': nn.LeakyReLU,
    }
    return inplace_activations.get(name, nn.ReLU)(inplace=True)


def square_distance(src, dst):
    """
    Calculate the squared Euclidean distance between each pair of points.
    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """
    Gather rows of `points` per batch element.
    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (or [B, S, K], ...)
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, C], ...)
    """
    B = points.shape[0]
    # Shape [B, 1, ..., 1]: advanced indexing broadcasts it against idx, so the
    # batch index does not have to be materialised with repeat().
    batch_shape = [B] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(B, dtype=torch.long, device=points.device).view(batch_shape)
    return points[batch_indices, idx, :]


def farthest_point_sample(xyz, npoint):
    """
    Iterative farthest point sampling in pure PyTorch.
    Input:
        xyz: pointcloud data, [B, N, C] (C is 3 for plain coordinates)
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Distance of every point to the closest centroid selected so far.
    distance = torch.full((B, N), 1e10, device=device)
    # The first centroid is random. It is drawn on the CPU and then moved, so
    # the consumed random stream is the same whatever device xyz lives on.
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.min(distance, dist)
        farthest = torch.max(distance, -1)[1]
    return centroids


def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long, device=device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    # N is used as an out-of-range sentinel for points outside the ball; the
    # ascending sort pushes the sentinels behind all valid indices.
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # Slots still holding the sentinel are filled with the first index of the group.
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx


def knn_point(nsample, xyz, new_xyz):
    """
    Input:
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]
            (the nsample nearest points, in no particular order)
    """
    sqrdists = square_distance(new_xyz, xyz)
    _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
    return group_idx


class LocalGrouper(nn.Module):
    def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="center", **kwargs):
        """
        Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
        :param channel: feature dimension d of the incoming points
        :param groups: groups number
        :param kneighbors: k-neighbors
        :param use_xyz: append the neighbour coordinates to the grouped features
        :param normalize: "center", "anchor" or None (case-insensitive); any
            other value is reported on stdout and treated as None
        :param kwargs: others (ignored)
        """
        super(LocalGrouper, self).__init__()
        self.groups = groups
        self.kneighbors = kneighbors
        self.use_xyz = use_xyz
        self.normalize = normalize.lower() if normalize is not None else None
        # An explicit None is a valid "no normalization" request, so only
        # unknown names are reported.
        if self.normalize is not None and self.normalize not in ("center", "anchor"):
            print(f"Unrecognized normalize parameter ({self.normalize}), set to None. Should be one of [center, anchor].")
            self.normalize = None
        if self.normalize is not None:
            add_channel = 3 if self.use_xyz else 0
            self.affine_alpha = nn.Parameter(torch.ones([1, 1, 1, channel + add_channel]))
            self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel]))

    def _sample_anchor_indices(self, xyz):
        """Return [B, groups] long indices of the anchor points sampled by FPS."""
        if pointnet2_utils is not None and xyz.is_cuda:
            return pointnet2_utils.furthest_point_sample(xyz, self.groups).long()  # [B, npoint]
        # The compiled pointnet2_ops kernel is CUDA-only, so CPU tensors (or a
        # missing extension) use the pure-PyTorch sampler. It starts from a
        # random point, so its indices may differ from the CUDA kernel's.
        return farthest_point_sample(xyz, self.groups).long()

    def forward(self, xyz, points):
        """
        :param xyz: coordinates, [B, N, 3]
        :param points: features, [B, N, d]
        :return: new_xyz [B, groups, 3] and the grouped (optionally normalized)
            neighbour features concatenated with the anchor feature,
            [B, groups, k, 2*d (+3 when use_xyz)]
        """
        B = xyz.shape[0]
        S = self.groups
        xyz = xyz.contiguous()  # xyz [batch, points, xyz]

        fps_idx = self._sample_anchor_indices(xyz)  # [B, npoint]
        new_xyz = index_points(xyz, fps_idx)  # [B, npoint, 3]
        new_points = index_points(points, fps_idx)  # [B, npoint, d]

        idx = knn_point(self.kneighbors, xyz, new_xyz)
        # idx = query_ball_point(radius, nsample, xyz, new_xyz)
        grouped_xyz = index_points(xyz, idx)  # [B, npoint, k, 3]
        grouped_points = index_points(points, idx)  # [B, npoint, k, d]
        if self.use_xyz:
            grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)  # [B, npoint, k, d+3]

        if self.normalize is not None:
            if self.normalize == "center":
                # Statistics over the k neighbours of each group.
                std, mean = torch.std_mean(grouped_points, dim=2, keepdim=True)
            else:  # "anchor": offsets relative to the sampled anchor itself
                mean = torch.cat([new_points, new_xyz], dim=-1) if self.use_xyz else new_points
                mean = mean.unsqueeze(dim=-2)  # [B, npoint, 1, d+3]
                # A single scalar std over the whole tensor (batch included).
                std = torch.std(grouped_points - mean)
            grouped_points = (grouped_points - mean) / (std + 1e-5)
            grouped_points = self.affine_alpha * grouped_points + self.affine_beta

        # Attach the anchor feature to each of its k neighbours.
        anchor_features = new_points.view(B, S, 1, -1).expand(-1, -1, self.kneighbors, -1)
        new_points = torch.cat([grouped_points, anchor_features], dim=-1)
        return new_xyz, new_points


class ConvBNReLU1D(nn.Module):
    """Conv1d -> BatchNorm1d -> activation over [B, C, L] tensors."""

    def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'):
        super(ConvBNReLU1D, self).__init__()
        self.act = get_activation(activation)
        self.net = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
            nn.BatchNorm1d(out_channels),
            self.act
        )

    def forward(self, x):
        return self.net(x)


class ConvBNReLURes1D(nn.Module):
    """
    Residual bottleneck: channel -> int(channel * res_expansion) -> channel,
    then act(block(x) + x). Input and output are both [B, channel, L].
    """

    def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation='relu'):
        super(ConvBNReLURes1D, self).__init__()
        hidden_channel = int(channel * res_expansion)
        self.act = get_activation(activation)
        self.net1 = nn.Sequential(
            nn.Conv1d(in_channels=channel, out_channels=hidden_channel,
                      kernel_size=kernel_size, groups=groups, bias=bias),
            nn.BatchNorm1d(hidden_channel),
            self.act
        )
        self.net2 = nn.Sequential(
            nn.Conv1d(in_channels=hidden_channel, out_channels=channel,
                      kernel_size=kernel_size, groups=groups, bias=bias),
            nn.BatchNorm1d(channel)
        )

    def forward(self, x):
        return self.act(self.net2(self.net1(x)) + x)


class PreExtraction(nn.Module):
    def __init__(self, channels, out_channels, blocks=1, groups=1, res_expansion=1, bias=True,
                 activation='relu', use_xyz=True):
        """
        input: [b,g,k,d]: output:[b,d,g]
        :param channels: feature dimension fed to the LocalGrouper; the grouped
            input carries 2*channels (+3 when use_xyz) features
        :param out_channels: feature dimension produced for every group
        :param blocks: number of residual blocks
        """
        super(PreExtraction, self).__init__()
        in_channels = 3 + 2 * channels if use_xyz else 2 * channels
        self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
        self.operation = nn.Sequential(*[
            ConvBNReLURes1D(out_channels, groups=groups, res_expansion=res_expansion,
                            bias=bias, activation=activation)
            for _ in range(blocks)
        ])

    def forward(self, x):
        b, n, s, d = x.size()  # e.g. torch.Size([32, 512, 32, 6])
        # Fold the groups into the batch axis so each group is one 1-D sequence.
        x = x.permute(0, 1, 3, 2).reshape(-1, d, s)
        x = self.transfer(x)
        x = self.operation(x)  # [b*g, d, k]
        # Max over the k neighbours, then restore the group axis.
        x = F.adaptive_max_pool1d(x, 1).view(b, n, -1)
        return x.permute(0, 2, 1)


class PosExtraction(nn.Module):
    def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation='relu'):
        """
        input[b,d,g]; output[b,d,g]
        :param channels: feature dimension (unchanged by the blocks)
        :param blocks: number of residual blocks; 0 gives an identity mapping
        """
        super(PosExtraction, self).__init__()
        self.operation = nn.Sequential(*[
            ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion,
                            bias=bias, activation=activation)
            for _ in range(blocks)
        ])

    def forward(self, x):  # [b, d, g]
        return self.operation(x)


class modelelite3(nn.Module):
    """
    Stage-wise point cloud classifier. Each stage runs a LocalGrouper, a
    PreExtraction and a PosExtraction; the result is max-pooled over the
    remaining points and classified by an MLP.

    Input [B, 3, N], output logits [B, class_num].
    """

    def __init__(self, points=1024, class_num=40, embed_dim=64, groups=1, res_expansion=1.0,
                 activation="relu", bias=True, use_xyz=True, normalize="center",
                 dim_expansion=(2, 2, 2, 2), pre_blocks=(2, 2, 2, 2), pos_blocks=(2, 2, 2, 2),
                 k_neighbors=(32, 32, 32, 32), reducers=(2, 2, 2, 2), **kwargs):
        super(modelelite3, self).__init__()
        self.stages = len(pre_blocks)
        self.class_num = class_num
        self.points = points
        self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation)
        assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \
            "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
        self.local_grouper_list = nn.ModuleList()
        self.pre_blocks_list = nn.ModuleList()
        self.pos_blocks_list = nn.ModuleList()
        last_channel = embed_dim
        anchor_points = self.points
        stage_configs = zip(dim_expansion, pre_blocks, pos_blocks, k_neighbors, reducers)
        for expansion, pre_block_num, pos_block_num, kneighbor, reduce in stage_configs:
            out_channel = last_channel * expansion
            anchor_points = anchor_points // reduce
            # append local_grouper_list
            local_grouper = LocalGrouper(last_channel, anchor_points, kneighbor, use_xyz, normalize)  # [b,g,k,d]
            self.local_grouper_list.append(local_grouper)
            # append pre_block_list
            pre_block_module = PreExtraction(last_channel, out_channel, pre_block_num, groups=groups,
                                             res_expansion=res_expansion,
                                             bias=bias, activation=activation, use_xyz=use_xyz)
            self.pre_blocks_list.append(pre_block_module)
            # append pos_block_list
            pos_block_module = PosExtraction(out_channel, pos_block_num, groups=groups,
                                             res_expansion=res_expansion, bias=bias, activation=activation)
            self.pos_blocks_list.append(pos_block_module)

            last_channel = out_channel

        self.act = get_activation(activation)
        self.classifier = nn.Sequential(
            nn.Linear(last_channel, 512),
            nn.BatchNorm1d(512),
            self.act,
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            self.act,
            nn.Dropout(0.5),
            nn.Linear(256, self.class_num)
        )

    def forward(self, x):
        xyz = x.permute(0, 2, 1)
        x = self.embedding(x)  # B,D,N
        stages = zip(self.local_grouper_list, self.pre_blocks_list, self.pos_blocks_list)
        for local_grouper, pre_block, pos_block in stages:
            # Give xyz[b, p, 3] and fea[b, p, d], return new_xyz[b, g, 3] and new_fea[b, g, k, d]
            xyz, x = local_grouper(xyz, x.permute(0, 2, 1))  # [b,g,3]  [b,g,k,d]
            x = pre_block(x)  # [b,d,g]
            x = pos_block(x)  # [b,d,g]

        x = F.adaptive_max_pool1d(x, 1).squeeze(dim=-1)
        x = self.classifier(x)
        return x


# Factory functions. The trailing comment on each `def` line is the original
# author's annotation (presumably the parameter count; the meaning of the
# extra figures on X11-X13 is not documented in this file).

def modelelite3A1(num_classes=40, **kwargs) -> modelelite3:  # 3.48M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3B1(num_classes=40, **kwargs) -> modelelite3:  # 2.78M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=0.0625,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3C1(num_classes=40, **kwargs) -> modelelite3:  # 4.87M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=0.25,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3D1(num_classes=40, **kwargs) -> modelelite3:  # 2.26M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=8, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3E1(num_classes=40, **kwargs) -> modelelite3:  # 2.43M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=4, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3F1(num_classes=40, **kwargs) -> modelelite3:  # 0.85M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3G1(num_classes=40, **kwargs) -> modelelite3:  # 0.85M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3H1(num_classes=40, **kwargs) -> modelelite3:  # 3.56M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=1, res_expansion=1,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3I1(num_classes=40, **kwargs) -> modelelite3:  # 0.90M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[3, 3, 3, 3], pos_blocks=[3, 3, 3, 3],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3J1(num_classes=40, **kwargs) -> modelelite3:  # 0.93M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[3, 4, 6, 3], pos_blocks=[3, 4, 6, 3],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3K1(num_classes=40, **kwargs) -> modelelite3:  # 0.93M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=8, res_expansion=0.25,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[3, 4, 6, 3], pos_blocks=[3, 4, 6, 3],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3L1(num_classes=40, **kwargs) -> modelelite3:  # 0.95M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=8, res_expansion=0.25,
                       activation="relu", bias=True, use_xyz=True, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[3, 4, 6, 3], pos_blocks=[3, 4, 6, 3],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3M1(num_classes=40, **kwargs) -> modelelite3:  # 0.94M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=4, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2], pre_blocks=[4, 4, 4], pos_blocks=[4, 4, 4],
                       k_neighbors=[24, 24, 24], reducers=[2, 2, 2], **kwargs)


########version 2: 64 neighbors with 0.5 drop ratio ###########
# NOTE(review): the configurations below actually use k_neighbors=32 per
# stage, not 64 -- confirm which one the heading was meant to describe.

def modelelite3A2(num_classes=40, **kwargs) -> modelelite3:  # 3.48M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3B2(num_classes=40, **kwargs) -> modelelite3:  # 2.78M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=0.0625,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3C2(num_classes=40, **kwargs) -> modelelite3:  # 4.87M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=0.25,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3D2(num_classes=40, **kwargs) -> modelelite3:  # 2.26M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=8, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3E2(num_classes=40, **kwargs) -> modelelite3:  # 2.43M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=4, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3F2(num_classes=40, **kwargs) -> modelelite3:  # 0.85M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3G2(num_classes=40, **kwargs) -> modelelite3:  # 0.85M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3H2(num_classes=40, **kwargs) -> modelelite3:  # 3.56M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=1, res_expansion=1,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3I2(num_classes=40, **kwargs) -> modelelite3:  # 0.90M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[3, 3, 3, 3], pos_blocks=[3, 3, 3, 3],
                       k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3J2(num_classes=40, **kwargs) -> modelelite3:  # 0.93M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=4, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[3, 4, 6, 3], pos_blocks=[3, 4, 6, 3],
                       k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3K2(num_classes=40, **kwargs) -> modelelite3:  # 0.93M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=8, res_expansion=0.25,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[3, 4, 6, 3], pos_blocks=[3, 4, 6, 3],
                       k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3L2(num_classes=40, **kwargs) -> modelelite3:  # 0.95M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=8, res_expansion=0.25,
                       activation="relu", bias=True, use_xyz=True, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[3, 4, 6, 3], pos_blocks=[3, 4, 6, 3],
                       k_neighbors=[32, 32, 32, 32], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3M2(num_classes=40, **kwargs) -> modelelite3:  # 0.94M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=4, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2], pre_blocks=[4, 4, 4], pos_blocks=[4, 4, 4],
                       k_neighbors=[32, 32, 32], reducers=[2, 2, 2], **kwargs)


def modelelite3X1(num_classes=40, **kwargs) -> modelelite3:  # 1.11M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=1, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[20, 20, 20, 20], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3X2(num_classes=40, **kwargs) -> modelelite3:  # 0.94M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=2, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[20, 20, 20, 20], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3X3(num_classes=40, **kwargs) -> modelelite3:  # 0.94
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=1, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 1],
                       k_neighbors=[20, 20, 20, 20], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3X4(num_classes=40, **kwargs) -> modelelite3:  # 2.77
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 1],
                       k_neighbors=[20, 20, 20, 20], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3X5(num_classes=40, **kwargs) -> modelelite3:  # 1.59
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 1], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 1],
                       k_neighbors=[20, 20, 20, 20], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3X6(num_classes=40, **kwargs) -> modelelite3:  # 1.44
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=4, res_expansion=0.25,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 1], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 1],
                       k_neighbors=[20, 20, 20, 20], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3X7(num_classes=40, **kwargs) -> modelelite3:  # 1.11M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=1, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[2, 2, 2, 2], pos_blocks=[2, 2, 2, 2],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3X8(num_classes=40, **kwargs) -> modelelite3:  # 0.95M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=1, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 2], pre_blocks=[1, 1, 2, 1], pos_blocks=[1, 1, 1, 1],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3X9(num_classes=40, **kwargs) -> modelelite3:  # 1.59M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 1], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 1],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3X10(num_classes=40, **kwargs) -> modelelite3:  # 0.72M
    return modelelite3(points=1024, class_num=num_classes, embed_dim=32, groups=1, res_expansion=0.25,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[2, 2, 2, 1], pre_blocks=[1, 1, 2, 1], pos_blocks=[1, 1, 2, 1],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3X11(num_classes=40, **kwargs) -> modelelite3:  # 0.98M 79/13s
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=0.25,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[1, 2, 2, 2], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 0],
                       k_neighbors=[20, 20, 20, 20], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3X12(num_classes=40, **kwargs) -> modelelite3:  # 0.98M 78/13s
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=0.25,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[1, 2, 2, 2], pre_blocks=[1, 1, 1, 1], pos_blocks=[1, 1, 1, 0],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


def modelelite3X13(num_classes=40, **kwargs) -> modelelite3:  # 0.94M 90/15s
    return modelelite3(points=1024, class_num=num_classes, embed_dim=64, groups=1, res_expansion=0.125,
                       activation="relu", bias=False, use_xyz=False, normalize="anchor",
                       dim_expansion=[1, 2, 2, 2], pre_blocks=[1, 2, 2, 2], pos_blocks=[1, 1, 0, 0],
                       k_neighbors=[24, 24, 24, 24], reducers=[2, 2, 2, 2], **kwargs)


if __name__ == '__main__':
    # data = torch.rand(2, 128, 10)
    # model = ConvBNReLURes1D(128, groups=2, activation='relu')
    # out = model(data)
    # print(out.shape)
    #
    # batch, groups, neighbors, dim = 2, 512, 32, 16
    # x = torch.rand(batch, groups, neighbors, dim)
    # pre_extractor = PreExtraction(dim, 3)
    # out = pre_extractor(x)
    # print(out.shape)
    #
    # x = torch.rand(batch, dim, groups)
    # pos_extractor = PosExtraction(dim, 3)
    # out = pos_extractor(x)
    # print(out.shape)

    data = torch.rand(2, 3, 1024)
    print("===> testing modelN ...")
    model = modelelite3A1()
    out = model(data)
    print(out.shape)