diff --git a/classification_ModelNet40/models/pointmlp.py b/classification_ModelNet40/models/pointmlp.py index 597ba46..d350efb 100644 --- a/classification_ModelNet40/models/pointmlp.py +++ b/classification_ModelNet40/models/pointmlp.py @@ -360,9 +360,15 @@ def pointMLPElite(num_classes=40, **kwargs) -> Model: k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs) if __name__ == '__main__': - data = torch.rand(2, 3, 1024) + data = torch.rand(2, 3, 1024).cuda() + print(data.shape) + print("===> testing pointMLP ...") - model = pointMLP() + model = pointMLP().cuda() out = model(data) print(out.shape) + print("===> testing pointMLPElite ...") + model = pointMLPElite().cuda() + out = model(data) + print(out.shape) diff --git a/part_segmentation/model/pointMLP.py b/part_segmentation/model/pointMLP.py index 7790387..1b62414 100644 --- a/part_segmentation/model/pointMLP.py +++ b/part_segmentation/model/pointMLP.py @@ -455,10 +455,14 @@ def pointMLP(num_classes=50, **kwargs) -> PointMLP: if __name__ == '__main__': - data = torch.rand(2, 3, 2048) - norm = torch.rand(2, 3, 2048) - cls_label = torch.rand([2, 16]) - print("===> testing modelD ...") - model = pointMLP(50) - out = model(data, cls_label) # [2,2048,50] - print(out.shape) + data = torch.rand(2, 3, 2048).cuda() + norm = torch.rand(2, 3, 2048).cuda() + cls_label = torch.rand([2, 16]).cuda() + print(f"data shape: {data.shape}") + print(f"norm shape: {norm.shape}") + print(f"cls_label shape: {cls_label.shape}") + + print("===> testing pointMLP (segmentation) ...") + model = pointMLP(50).cuda() + out = model(data, norm, cls_label) # [2,2048,50] + print(f"out shape: {out.shape}")