JESUS CHRIST DID THEY CHECK THEIR CODE ?
This commit is contained in:
parent
62e57ecfc0
commit
c50897112d
|
@ -360,9 +360,15 @@ def pointMLPElite(num_classes=40, **kwargs) -> Model:
|
|||
k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs)
|
||||
|
||||
if __name__ == '__main__':
|
||||
data = torch.rand(2, 3, 1024)
|
||||
data = torch.rand(2, 3, 1024).cuda()
|
||||
print(data.shape)
|
||||
|
||||
print("===> testing pointMLP ...")
|
||||
model = pointMLP()
|
||||
model = pointMLP().cuda()
|
||||
out = model(data)
|
||||
print(out.shape)
|
||||
|
||||
print("===> testing pointMLPElite ...")
|
||||
model = pointMLPElite().cuda()
|
||||
out = model(data)
|
||||
print(out.shape)
|
||||
|
|
|
@ -455,10 +455,14 @@ def pointMLP(num_classes=50, **kwargs) -> PointMLP:
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data = torch.rand(2, 3, 2048)
|
||||
norm = torch.rand(2, 3, 2048)
|
||||
cls_label = torch.rand([2, 16])
|
||||
print("===> testing modelD ...")
|
||||
model = pointMLP(50)
|
||||
out = model(data, cls_label) # [2,2048,50]
|
||||
print(out.shape)
|
||||
data = torch.rand(2, 3, 2048).cuda()
|
||||
norm = torch.rand(2, 3, 2048).cuda()
|
||||
cls_label = torch.rand([2, 16]).cuda()
|
||||
print(f"data shape: {data.shape}")
|
||||
print(f"norm shape: {norm.shape}")
|
||||
print(f"cls_label shape: {cls_label.shape}")
|
||||
|
||||
print("===> testing pointMLP (segmentation) ...")
|
||||
model = pointMLP(50).cuda()
|
||||
out = model(data, norm, cls_label) # [2,2048,50]
|
||||
print(f"out shape: {out.shape}")
|
||||
|
|
Loading…
Reference in a new issue