"""PointMLP31: a PointMLP-style encoder/decoder for per-point classification.

Design history (kept from the original notes):

* Based on PointMLP9, change to 128, 256, 512, 512.
* Based on PointMLP4, use FPS to replace random sampling.
* Based on PointMLP3, change dropout to 0.1.
* Based on PointMLP2, add blocks in the decoder, change channel numbers.
* Based on PointMLP1, use PosExtraction to replace the MLP in the decoder.
* PointsformerE2: 1) change ReLU to GELU, 2) change backbone to model24.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import einsum  # noqa: F401  (unused here; kept for importers of this module)

try:  # Unused here; guarded so a missing optional package does not break import.
    from einops import rearrange, repeat  # noqa: F401
except ImportError:
    rearrange = repeat = None

try:  # Compiled FPS kernel; optional, see ``_sample_anchor_indices``.
    from pointnet2_ops import pointnet2_utils
except ImportError:
    pointnet2_utils = None


# Activation name -> zero-argument constructor. Unknown names fall back to ReLU.
_ACTIVATION_FACTORIES = {
    "gelu": nn.GELU,
    "rrelu": lambda: nn.RReLU(inplace=True),
    "selu": lambda: nn.SELU(inplace=True),
    "silu": lambda: nn.SiLU(inplace=True),
    "hardswish": lambda: nn.Hardswish(inplace=True),
    "leakyrelu": lambda: nn.LeakyReLU(inplace=True),
    "leakyrelu0.2": lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True),
}


def get_activation(activation):
    """Return a new activation module for the (case-insensitive) name.

    Any name that is not in the table above yields ``nn.ReLU(inplace=True)``.
    """
    factory = _ACTIVATION_FACTORIES.get(activation.lower())
    if factory is None:
        return nn.ReLU(inplace=True)
    return factory()


def square_distance(src, dst):
    """Calculate the squared Euclidean distance between each pair of points.

    Uses the expansion
    ``|s - d|^2 = sum(s**2, -1) + sum(d**2, -1) - 2 * s @ d^T``.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point squared distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """Gather rows of ``points`` per batch element.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (or [B, S, K], ...)
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, C], ...)
    """
    batch_size = points.shape[0]
    # Shape [B, 1, ..., 1] so the batch index broadcasts against ``idx``
    # during advanced indexing (no need to materialise a repeated copy).
    view_shape = [batch_size] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(
        batch_size, dtype=torch.long, device=points.device
    ).view(view_shape)
    return points[batch_indices, idx, :]


def farthest_point_sample(xyz, npoint):
    """Iterative farthest point sampling implemented with plain torch ops.

    The first centroid of every batch element is drawn at random.

    Input:
        xyz: pointcloud data, [B, N, C]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Running minimum distance from every point to the chosen set.
    distance = torch.full((B, N), 1e10, dtype=xyz.dtype, device=device)
    farthest = torch.randint(0, N, (B,), dtype=torch.long, device=device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        distance = torch.min(distance, dist)
        farthest = torch.max(distance, -1)[1]
    return centroids


def query_ball_point(radius, nsample, xyz, new_xyz):
    """Group up to ``nsample`` points lying within ``radius`` of each query.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = (
        torch.arange(N, dtype=torch.long, device=device)
        .view(1, 1, N)
        .repeat([B, S, 1])
    )
    sqrdists = square_distance(new_xyz, xyz)
    # Mark out-of-radius points with the sentinel N so they sort to the end.
    group_idx[sqrdists > radius ** 2] = N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
    # Pad groups that have fewer than nsample members with their first member.
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx


def knn_point(nsample, xyz, new_xyz):
    """Return the indices of the ``nsample`` nearest points for each query.

    Input:
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample] (unsorted)
    """
    sqrdists = square_distance(new_xyz, xyz)
    _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
    return group_idx


def _sample_anchor_indices(xyz, num_anchors):
    """Pick ``num_anchors`` indices per batch element by farthest point sampling.

    Uses the compiled ``pointnet2_ops`` kernel when it is importable and the
    input lives on a CUDA device (NOTE(review): the kernel is CUDA-only as far
    as the pointnet2_ops sources show - confirm for your build); otherwise
    falls back to the pure-torch ``farthest_point_sample``.
    """
    if pointnet2_utils is not None and xyz.is_cuda:
        return pointnet2_utils.furthest_point_sample(xyz, num_anchors).long()
    return farthest_point_sample(xyz, num_anchors)


class LocalGrouper(nn.Module):
    """Sample anchor points and gather (optionally normalised) kNN groups."""

    def __init__(self, channel, groups, kneighbors, use_xyz=True, normalize="anchor", **kwargs):
        """
        Give xyz[b,p,3] and fea[b,p,d], return new_xyz[b,g,3] and new_fea[b,g,k,d]
        :param channel: feature channels d of the incoming points
        :param groups: number of groups (anchor points)
        :param kneighbors: k nearest neighbours gathered per anchor
        :param use_xyz: concatenate neighbour coordinates to the features
        :param normalize: "center", "anchor" or None (no normalisation)
        :param kwargs: others (ignored)
        """
        super(LocalGrouper, self).__init__()
        self.groups = groups
        self.kneighbors = kneighbors
        self.use_xyz = use_xyz
        self.normalize = normalize.lower() if normalize is not None else None
        # Only warn for a value the caller actually supplied; None is a valid
        # way to disable normalisation.
        if self.normalize is not None and self.normalize not in ["center", "anchor"]:
            print(f"Unrecognized normalize parameter ({self.normalize}), set to None. Should be one of [center, anchor].")
            self.normalize = None
        if self.normalize is not None:
            add_channel = 3 if self.use_xyz else 0
            self.affine_alpha = nn.Parameter(torch.ones([1, 1, 1, channel + add_channel]))
            self.affine_beta = nn.Parameter(torch.zeros([1, 1, 1, channel + add_channel]))

    def forward(self, xyz, points):
        """Group ``points`` around sampled anchors.

        :param xyz: coordinates, [batch, points, xyz]
        :param points: features, [batch, points, d]
        :return: (new_xyz [B, npoint, 3], new_points [B, npoint, k, ...])
        """
        B, N, C = xyz.shape
        S = self.groups
        xyz = xyz.contiguous()

        fps_idx = _sample_anchor_indices(xyz, self.groups)  # [B, npoint]
        new_xyz = index_points(xyz, fps_idx)  # [B, npoint, 3]
        new_points = index_points(points, fps_idx)  # [B, npoint, d]

        # kNN grouping (query_ball_point would be the radius-based alternative).
        idx = knn_point(self.kneighbors, xyz, new_xyz)
        grouped_xyz = index_points(xyz, idx)  # [B, npoint, k, 3]
        grouped_points = index_points(points, idx)  # [B, npoint, k, d]
        if self.use_xyz:
            grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)  # [B, npoint, k, d+3]
        if self.normalize is not None:
            if self.normalize == "center":
                std, mean = torch.std_mean(grouped_points, dim=2, keepdim=True)
            if self.normalize == "anchor":
                mean = torch.cat([new_points, new_xyz], dim=-1) if self.use_xyz else new_points
                mean = mean.unsqueeze(dim=-2)  # [B, npoint, 1, d+3]
                # NOTE: a single scalar std over the whole batch; kept as-is so
                # existing checkpoints behave identically.
                std = torch.std(grouped_points - mean)
            grouped_points = (grouped_points - mean) / (std + 1e-5)
            grouped_points = self.affine_alpha * grouped_points + self.affine_beta

        # Append the anchor feature to every neighbour of its group.
        anchor_feature = new_points.view(B, S, 1, -1).repeat(1, 1, self.kneighbors, 1)
        new_points = torch.cat([grouped_points, anchor_feature], dim=-1)
        return new_xyz, new_points


class ConvBNReLU1D(nn.Module):
    """Conv1d -> BatchNorm1d -> activation."""

    def __init__(self, in_channels, out_channels, kernel_size=1, bias=True, activation='relu'):
        super(ConvBNReLU1D, self).__init__()
        self.act = get_activation(activation)
        self.net = nn.Sequential(
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias),
            nn.BatchNorm1d(out_channels),
            self.act
        )

    def forward(self, x):
        return self.net(x)


class ConvBNReLURes1D(nn.Module):
    """Residual block: act(net2(net1(x)) + x) with 1D conv / batch-norm layers."""

    def __init__(self, channel, kernel_size=1, groups=1, res_expansion=1.0, bias=True, activation='relu'):
        super(ConvBNReLURes1D, self).__init__()
        hidden = int(channel * res_expansion)
        self.act = get_activation(activation)
        self.net1 = nn.Sequential(
            nn.Conv1d(in_channels=channel, out_channels=hidden,
                      kernel_size=kernel_size, groups=groups, bias=bias),
            nn.BatchNorm1d(hidden),
            self.act
        )
        if groups > 1:
            # Grouped conv followed by a dense conv.
            self.net2 = nn.Sequential(
                nn.Conv1d(in_channels=hidden, out_channels=channel,
                          kernel_size=kernel_size, groups=groups, bias=bias),
                nn.BatchNorm1d(channel),
                self.act,
                nn.Conv1d(in_channels=channel, out_channels=channel,
                          kernel_size=kernel_size, bias=bias),
                nn.BatchNorm1d(channel),
            )
        else:
            self.net2 = nn.Sequential(
                nn.Conv1d(in_channels=hidden, out_channels=channel,
                          kernel_size=kernel_size, bias=bias),
                nn.BatchNorm1d(channel)
            )

    def forward(self, x):
        return self.act(self.net2(self.net1(x)) + x)


class PreExtraction(nn.Module):
    """Per-group feature extraction followed by max-pooling over neighbours."""

    def __init__(self, channels, out_channels, blocks=1, groups=1, res_expansion=1, bias=True,
                 activation='relu', use_xyz=True):
        """
        input: [b,g,k,d]: output:[b,d,g]
        :param channels: feature channels fed to the matching LocalGrouper
        :param out_channels: output channels
        :param blocks: number of residual blocks
        """
        super(PreExtraction, self).__init__()
        # LocalGrouper emits grouped features (+3 coords when use_xyz) plus the
        # repeated anchor features, hence 2 * channels (+ 3).
        in_channels = 3 + 2 * channels if use_xyz else 2 * channels
        self.transfer = ConvBNReLU1D(in_channels, out_channels, bias=bias, activation=activation)
        operation = [
            ConvBNReLURes1D(out_channels, groups=groups, res_expansion=res_expansion,
                            bias=bias, activation=activation)
            for _ in range(blocks)
        ]
        self.operation = nn.Sequential(*operation)

    def forward(self, x):
        b, n, s, d = x.size()  # e.g. torch.Size([32, 512, 32, 6])
        x = x.permute(0, 1, 3, 2)
        x = x.reshape(-1, d, s)  # fold groups into the batch axis
        x = self.transfer(x)
        batch_size, _, _ = x.size()
        x = self.operation(x)  # [b*g, d, k]
        x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1)
        x = x.reshape(b, n, -1).permute(0, 2, 1)
        return x


class PosExtraction(nn.Module):
    """Stack of residual blocks applied to [b, d, g] features."""

    def __init__(self, channels, blocks=1, groups=1, res_expansion=1, bias=True, activation='relu'):
        """
        input[b,d,g]; output[b,d,g]
        :param channels: feature channels (unchanged by the blocks)
        :param blocks: number of residual blocks
        """
        super(PosExtraction, self).__init__()
        operation = [
            ConvBNReLURes1D(channels, groups=groups, res_expansion=res_expansion,
                            bias=bias, activation=activation)
            for _ in range(blocks)
        ]
        self.operation = nn.Sequential(*operation)

    def forward(self, x):  # [b, d, g]
        return self.operation(x)


class PointNetFeaturePropagation(nn.Module):
    """Upsample coarse features by inverse-distance interpolation, then refine."""

    def __init__(self, in_channel, out_channel, blocks=1, groups=1, res_expansion=1.0, bias=True, activation='relu'):
        super(PointNetFeaturePropagation, self).__init__()
        # NOTE(review): ``activation`` is not forwarded here, so the fuse layer
        # always uses ConvBNReLU1D's default ('relu'). Left unchanged to keep
        # the behaviour of existing checkpoints.
        self.fuse = ConvBNReLU1D(in_channel, out_channel, 1, bias=bias)
        self.extraction = PosExtraction(out_channel, blocks, groups=groups,
                                        res_expansion=res_expansion, bias=bias, activation=activation)

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, N, 3]
            xyz2: sampled input points position data, [B, S, 3]
            points1: input points data, [B, D', N] (may be None)
            points2: input points data, [B, D'', S]
        Return:
            new_points: upsampled points data, [B, D''', N]
        """
        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)
        else:
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            # Interpolate from up to three nearest coarse points; the coarse
            # level may hold only two.
            k = min(3, S)
            dists, idx = dists[:, :, :k], idx[:, :, :k]  # [B, N, k]

            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            interpolated_points = torch.sum(index_points(points2, idx) * weight.unsqueeze(-1), dim=2)

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)
        new_points = self.fuse(new_points)
        new_points = self.extraction(new_points)
        return new_points


class PointMLP31(nn.Module):
    """Encoder/decoder PointMLP producing per-point class scores."""

    def __init__(self, num_classes=50, points=2048, embed_dim=64, groups=1, res_expansion=1.0,
                 activation="relu", bias=True, use_xyz=True, normalize="anchor",
                 dim_expansion=(2, 2, 2, 2), pre_blocks=(2, 2, 2, 2), pos_blocks=(2, 2, 2, 2),
                 k_neighbors=(32, 32, 32, 32), reducers=(4, 4, 4, 4),
                 de_dims=(512, 256, 128, 128), de_blocks=(2, 2, 2, 2),
                 gmp_dim=64, cls_dim=64, **kwargs):
        """Build the encoder stages, the decoder and the classification head.

        The per-stage sequences (``dim_expansion``, ``pre_blocks``,
        ``pos_blocks``, ``k_neighbors``, ``reducers``) must share one length;
        ``de_dims`` and ``de_blocks`` need one entry per stage. None of the
        sequence arguments is modified.
        """
        super(PointMLP31, self).__init__()
        self.stages = len(pre_blocks)
        self.class_num = num_classes
        self.points = points
        self.embedding = ConvBNReLU1D(3, embed_dim, bias=bias, activation=activation)
        assert len(pre_blocks) == len(k_neighbors) == len(reducers) == len(pos_blocks) == len(dim_expansion), \
            "Please check stage number consistent for pre_blocks, pos_blocks k_neighbors, reducers."
        self.local_grouper_list = nn.ModuleList()
        self.pre_blocks_list = nn.ModuleList()
        self.pos_blocks_list = nn.ModuleList()
        last_channel = embed_dim
        anchor_points = self.points
        en_dims = [last_channel]

        ### Building Encoder #####
        for i in range(len(pre_blocks)):
            out_channel = last_channel * dim_expansion[i]
            anchor_points = anchor_points // reducers[i]
            # append local_grouper_list
            self.local_grouper_list.append(
                LocalGrouper(last_channel, anchor_points, k_neighbors[i], use_xyz, normalize)  # [b,g,k,d]
            )
            # append pre_block_list
            self.pre_blocks_list.append(
                PreExtraction(last_channel, out_channel, pre_blocks[i], groups=groups,
                              res_expansion=res_expansion,
                              bias=bias, activation=activation, use_xyz=use_xyz)
            )
            # append pos_block_list
            self.pos_blocks_list.append(
                PosExtraction(out_channel, pos_blocks[i], groups=groups,
                              res_expansion=res_expansion, bias=bias, activation=activation)
            )
            last_channel = out_channel
            en_dims.append(last_channel)

        ### Building Decoder #####
        self.decode_list = nn.ModuleList()
        en_dims.reverse()
        # Build a new list instead of ``de_dims.insert(0, ...)`` so neither the
        # caller's sequence nor a shared default is mutated across constructions.
        de_dims = [en_dims[0], *de_dims]
        assert len(en_dims) == len(de_dims) == len(de_blocks) + 1
        for i in range(len(en_dims) - 1):
            self.decode_list.append(
                PointNetFeaturePropagation(de_dims[i] + en_dims[i + 1], de_dims[i + 1],
                                           blocks=de_blocks[i], groups=groups, res_expansion=res_expansion,
                                           bias=bias, activation=activation)
            )

        self.act = get_activation(activation)

        # class label mapping (cls_label is consumed as 16 input channels)
        self.cls_map = nn.Sequential(
            ConvBNReLU1D(16, cls_dim, bias=bias, activation=activation),
            ConvBNReLU1D(cls_dim, cls_dim, bias=bias, activation=activation)
        )
        # global max pooling mapping
        self.gmp_map_list = nn.ModuleList()
        for en_dim in en_dims:
            self.gmp_map_list.append(ConvBNReLU1D(en_dim, gmp_dim, bias=bias, activation=activation))
        self.gmp_map_end = ConvBNReLU1D(gmp_dim * len(en_dims), gmp_dim, bias=bias, activation=activation)

        # classifier
        self.classifier = nn.Sequential(
            nn.Conv1d(gmp_dim + cls_dim + de_dims[-1], 128, 1, bias=bias),
            nn.BatchNorm1d(128),
            self.act,
            nn.Dropout(),
            nn.Conv1d(128, num_classes, 1, bias=bias)
        )
        self.en_dims = en_dims

    def forward(self, x, cls_label):
        """Return raw per-point scores of shape [B, num_classes, N].

        :param x: point coordinates, [B, 3, N]
        :param cls_label: per-sample label vector, [B, 16]
        """
        xyz = x.permute(0, 2, 1)
        x = self.embedding(x)  # B,D,N

        xyz_list = [xyz]  # [B, N, 3]
        x_list = [x]  # [B, D, N]

        # here is the encoder
        for i in range(self.stages):
            # Give xyz[b, p, 3] and fea[b, p, d], return new_xyz[b, g, 3] and new_fea[b, g, k, d]
            xyz, x = self.local_grouper_list[i](xyz, x.permute(0, 2, 1))  # [b,g,3]  [b,g,k,d]
            x = self.pre_blocks_list[i](x)  # [b,d,g]
            x = self.pos_blocks_list[i](x)  # [b,d,g]
            xyz_list.append(xyz)
            x_list.append(x)

        # here is the decoder (coarsest level first)
        xyz_list.reverse()
        x_list.reverse()
        x = x_list[0]
        for i, decode in enumerate(self.decode_list):
            x = decode(xyz_list[i + 1], xyz_list[i], x_list[i + 1], x)

        # here is the global context
        gmp_list = [
            F.adaptive_max_pool1d(gmp_map(feature), 1)
            for gmp_map, feature in zip(self.gmp_map_list, x_list)
        ]
        global_context = self.gmp_map_end(torch.cat(gmp_list, dim=1))  # [b, gmp_dim, 1]

        # here is the cls_token
        cls_token = self.cls_map(cls_label.unsqueeze(dim=-1))  # [b, cls_dim, 1]
        x = torch.cat([x, global_context.repeat([1, 1, x.shape[-1]]), cls_token.repeat([1, 1, x.shape[-1]])], dim=1)
        x = self.classifier(x)
        # No log_softmax / permute is applied: callers receive logits [B, num_classes, N].
        return x


# Settings shared by every model31* factory; each factory lists only what differs.
_BASE_CONFIG = {
    "points": 2048,
    "embed_dim": 64,
    "groups": 1,
    "res_expansion": 1.0,
    "activation": "relu",
    "bias": True,
    "use_xyz": True,
    "normalize": "anchor",
    "dim_expansion": (2, 2, 2, 2),
    "pre_blocks": (2, 2, 2, 2),
    "pos_blocks": (2, 2, 2, 2),
    "k_neighbors": (32, 32, 32, 32),
    "reducers": (4, 4, 4, 4),
    "de_dims": (512, 256, 128, 128),
    "de_blocks": (4, 4, 4, 4),
    "gmp_dim": 64,
    "cls_dim": 64,
}


def _build_model(num_classes, overrides, kwargs) -> PointMLP31:
    """Create a PointMLP31 from the base config, variant overrides and caller kwargs.

    Caller ``kwargs`` take precedence over the variant's presets.
    """
    config = {
        key: list(value) if isinstance(value, tuple) else value
        for key, value in _BASE_CONFIG.items()
    }
    config.update(overrides)
    config.update(kwargs)
    return PointMLP31(num_classes=num_classes, **config)


def model31A(num_classes=50, **kwargs) -> PointMLP31:
    """Variant A: features only (use_xyz=False), 4 decoder blocks per stage."""
    return _build_model(num_classes, {"use_xyz": False}, kwargs)


def model31B(num_classes=50, **kwargs) -> PointMLP31:
    """Variant B: use_xyz=False, 3 pre/pos/decoder blocks per stage."""
    return _build_model(num_classes, {
        "use_xyz": False,
        "pre_blocks": [3, 3, 3, 3], "pos_blocks": [3, 3, 3, 3], "de_blocks": [3, 3, 3, 3],
    }, kwargs)


def model31C(num_classes=50, **kwargs) -> PointMLP31:
    """Variant C: use_xyz=False, 24 neighbours, 3 decoder blocks per stage."""
    return _build_model(num_classes, {
        "use_xyz": False,
        "k_neighbors": [24, 24, 24, 24], "de_blocks": [3, 3, 3, 3],
    }, kwargs)


def model31D(num_classes=50, **kwargs) -> PointMLP31:
    """Variant D: use_xyz=False, embed_dim=32."""
    return _build_model(num_classes, {"use_xyz": False, "embed_dim": 32}, kwargs)


def model31E(num_classes=50, **kwargs) -> PointMLP31:
    """Variant E: use_xyz=False, 4 pre/pos blocks and 6 decoder blocks per stage."""
    return _build_model(num_classes, {
        "use_xyz": False,
        "pre_blocks": [4, 4, 4, 4], "pos_blocks": [4, 4, 4, 4], "de_blocks": [6, 6, 6, 6],
    }, kwargs)


def model31F(num_classes=50, **kwargs) -> PointMLP31:
    """Variant F: use_xyz=False, reducers [4, 2, 4, 2]."""
    return _build_model(num_classes, {"use_xyz": False, "reducers": [4, 2, 4, 2]}, kwargs)


def model31G(num_classes=50, **kwargs) -> PointMLP31:
    """Variant G: base configuration (use_xyz=True, ReLU)."""
    return _build_model(num_classes, {}, kwargs)


def model31G1(num_classes=50, **kwargs) -> PointMLP31:
    """Variant G1: base configuration with LeakyReLU(0.2)."""
    return _build_model(num_classes, {"activation": "leakyrelu0.2"}, kwargs)


def model31G2(num_classes=50, **kwargs) -> PointMLP31:
    """Variant G2: base configuration with LeakyReLU."""
    return _build_model(num_classes, {"activation": "leakyrelu"}, kwargs)


def model31G3(num_classes=50, **kwargs) -> PointMLP31:
    """Variant G3: base configuration with GELU."""
    return _build_model(num_classes, {"activation": "gelu"}, kwargs)


def model31H(num_classes=50, **kwargs) -> PointMLP31:
    """Variant H: base configuration with GELU (same settings as G3)."""
    return _build_model(num_classes, {"activation": "gelu"}, kwargs)


def model31I(num_classes=50, **kwargs) -> PointMLP31:
    """Variant I: dim_expansion [2, 2, 2, 1], 3 pre/pos blocks per stage."""
    return _build_model(num_classes, {
        "dim_expansion": [2, 2, 2, 1],
        "pre_blocks": [3, 3, 3, 3], "pos_blocks": [3, 3, 3, 3],
    }, kwargs)


def model31J(num_classes=50, **kwargs) -> PointMLP31:
    """Variant J: base configuration with LeakyReLU (same settings as G2)."""
    return _build_model(num_classes, {"activation": "leakyrelu"}, kwargs)


def model31K(num_classes=50, **kwargs) -> PointMLP31:
    """Variant K: base configuration with LeakyReLU(0.2) (same settings as G1)."""
    return _build_model(num_classes, {"activation": "leakyrelu0.2"}, kwargs)


if __name__ == '__main__':
    data = torch.rand(2, 3, 2048)
    cls_label = torch.rand([2, 16])
    print("===> testing modelD ...")
    model = model31D(50)
    out = model(data, cls_label)  # [2, 50, 2048]
    print(out.shape)

    model = model31A(50)
    out = model(data, cls_label)  # [2, 50, 2048]
    print(out.shape)

    model = model31B(50)
    out = model(data, cls_label)  # [2, 50, 2048]
    print(out.shape)

    model = model31C(50)
    out = model(data, cls_label)  # [2, 50, 2048]
    print(out.shape)