JESUS CHRIST DID THEY CHECK THEIR CODE ?

This commit is contained in:
Laurent FAINSIN 2023-08-03 15:51:53 +02:00
parent 62e57ecfc0
commit c50897112d
2 changed files with 19 additions and 9 deletions

View file

@ -360,9 +360,15 @@ def pointMLPElite(num_classes=40, **kwargs) -> Model:
k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs) k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs)
if __name__ == '__main__': if __name__ == '__main__':
data = torch.rand(2, 3, 1024) data = torch.rand(2, 3, 1024).cuda()
print(data.shape)
print("===> testing pointMLP ...") print("===> testing pointMLP ...")
model = pointMLP() model = pointMLP().cuda()
out = model(data) out = model(data)
print(out.shape) print(out.shape)
print("===> testing pointMLPElite ...")
model = pointMLPElite().cuda()
out = model(data)
print(out.shape)

View file

@ -455,10 +455,14 @@ def pointMLP(num_classes=50, **kwargs) -> PointMLP:
if __name__ == '__main__': if __name__ == '__main__':
data = torch.rand(2, 3, 2048) data = torch.rand(2, 3, 2048).cuda()
norm = torch.rand(2, 3, 2048) norm = torch.rand(2, 3, 2048).cuda()
cls_label = torch.rand([2, 16]) cls_label = torch.rand([2, 16]).cuda()
print("===> testing modelD ...") print(f"data shape: {data.shape}")
model = pointMLP(50) print(f"norm shape: {norm.shape}")
out = model(data, cls_label) # [2,2048,50] print(f"cls_label shape: {cls_label.shape}")
print(out.shape)
print("===> testing pointMLP (segmentation) ...")
model = pointMLP(50).cuda()
out = model(data, norm, cls_label) # [2,2048,50]
print(f"out shape: {out.shape}")