what in tarnation is this shit ?

This commit is contained in:
Laurent FAINSIN 2023-08-03 15:39:35 +02:00
parent 4fcb30ca42
commit 62e57ecfc0
2 changed files with 17 additions and 10 deletions

View file

@ -93,7 +93,7 @@ pip install pointnet2_ops_lib/.
```
## Useage
## Usage
### Classification ModelNet40
**Train**: The dataset will be automatically downloaded, run following command to train.

View file

@ -45,22 +45,29 @@ def main():
else:
device = 'cpu'
print(f"==> Using device: {device}")
if args.msg is None:
message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
else:
message = "-"+args.msg
args.checkpoint = 'checkpoints/' + args.model + message
# if args.msg is None:
# message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
# else:
# message = "-"+args.msg
# args.checkpoint = 'checkpoints/' + args.model + message
if args.checkpoint is not None:
print(f"==> Using checkpoint: {args.checkpoint}")
print('==> Preparing data..')
test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=4,
batch_size=args.batch_size, shuffle=False, drop_last=False)
test_loader = DataLoader(
ModelNet40(partition='test', num_points=args.num_points),
num_workers=4,
batch_size=args.batch_size,
shuffle=False,
drop_last=False
)
# Model
print('==> Building model..')
net = models.__dict__[args.model]()
criterion = cal_loss
net = net.to(device)
checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth')
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
# checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth')
checkpoint = torch.load(args.checkpoint, map_location=torch.device('cpu'))
# criterion = criterion.to(device)
if device == 'cuda':
net = torch.nn.DataParallel(net)