JESUS CHRIST DID THEY CHECK THEIR CODE ?
This commit is contained in:
parent
62e57ecfc0
commit
c50897112d
|
@ -360,9 +360,15 @@ def pointMLPElite(num_classes=40, **kwargs) -> Model:
|
||||||
k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs)
|
k_neighbors=[24,24,24,24], reducers=[2, 2, 2, 2], **kwargs)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
data = torch.rand(2, 3, 1024)
|
data = torch.rand(2, 3, 1024).cuda()
|
||||||
|
print(data.shape)
|
||||||
|
|
||||||
print("===> testing pointMLP ...")
|
print("===> testing pointMLP ...")
|
||||||
model = pointMLP()
|
model = pointMLP().cuda()
|
||||||
out = model(data)
|
out = model(data)
|
||||||
print(out.shape)
|
print(out.shape)
|
||||||
|
|
||||||
|
print("===> testing pointMLPElite ...")
|
||||||
|
model = pointMLPElite().cuda()
|
||||||
|
out = model(data)
|
||||||
|
print(out.shape)
|
||||||
|
|
|
@ -455,10 +455,14 @@ def pointMLP(num_classes=50, **kwargs) -> PointMLP:
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
data = torch.rand(2, 3, 2048)
|
data = torch.rand(2, 3, 2048).cuda()
|
||||||
norm = torch.rand(2, 3, 2048)
|
norm = torch.rand(2, 3, 2048).cuda()
|
||||||
cls_label = torch.rand([2, 16])
|
cls_label = torch.rand([2, 16]).cuda()
|
||||||
print("===> testing modelD ...")
|
print(f"data shape: {data.shape}")
|
||||||
model = pointMLP(50)
|
print(f"norm shape: {norm.shape}")
|
||||||
out = model(data, cls_label) # [2,2048,50]
|
print(f"cls_label shape: {cls_label.shape}")
|
||||||
print(out.shape)
|
|
||||||
|
print("===> testing pointMLP (segmentation) ...")
|
||||||
|
model = pointMLP(50).cuda()
|
||||||
|
out = model(data, norm, cls_label) # [2,2048,50]
|
||||||
|
print(f"out shape: {out.shape}")
|
||||||
|
|
Loading…
Reference in a new issue