JESUS CHRIST DID THEY CHECK THEIR CODE ?

This commit is contained in:
Laurent FAINSIN 2023-08-03 15:51:53 +02:00
parent 62e57ecfc0
commit c50897112d
2 changed files with 19 additions and 9 deletions

View file

@ -360,9 +360,15 @@ def pointMLPElite(num_classes=40, **kwargs) -> Model:
k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs)
if __name__ == '__main__':
data = torch.rand(2, 3, 1024)
data = torch.rand(2, 3, 1024).cuda()
print(data.shape)
print("===> testing pointMLP ...")
model = pointMLP()
model = pointMLP().cuda()
out = model(data)
print(out.shape)
print("===> testing pointMLPElite ...")
model = pointMLPElite().cuda()
out = model(data)
print(out.shape)

View file

@ -455,10 +455,14 @@ def pointMLP(num_classes=50, **kwargs) -> PointMLP:
if __name__ == '__main__':
data = torch.rand(2, 3, 2048)
norm = torch.rand(2, 3, 2048)
cls_label = torch.rand([2, 16])
print("===> testing modelD ...")
model = pointMLP(50)
out = model(data, cls_label) # [2,2048,50]
print(out.shape)
data = torch.rand(2, 3, 2048).cuda()
norm = torch.rand(2, 3, 2048).cuda()
cls_label = torch.rand([2, 16]).cuda()
print(f"data shape: {data.shape}")
print(f"norm shape: {norm.shape}")
print(f"cls_label shape: {cls_label.shape}")
print("===> testing pointMLP (segmentation) ...")
model = pointMLP(50).cuda()
out = model(data, norm, cls_label) # [2,2048,50]
print(f"out shape: {out.shape}")