what in tarnation is this shit ?

This commit is contained in:
Laurent FAINSIN 2023-08-03 15:39:35 +02:00
parent 4fcb30ca42
commit 62e57ecfc0
2 changed files with 17 additions and 10 deletions

View file

@ -93,7 +93,7 @@ pip install pointnet2_ops_lib/.
``` ```
## Useage ## Usage
### Classification ModelNet40 ### Classification ModelNet40
**Train**: The dataset will be automatically downloaded, run following command to train. **Train**: The dataset will be automatically downloaded, run following command to train.

View file

@ -45,22 +45,29 @@ def main():
else: else:
device = 'cpu' device = 'cpu'
print(f"==> Using device: {device}") print(f"==> Using device: {device}")
if args.msg is None: # if args.msg is None:
message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S')) # message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
else: # else:
message = "-"+args.msg # message = "-"+args.msg
args.checkpoint = 'checkpoints/' + args.model + message # args.checkpoint = 'checkpoints/' + args.model + message
if args.checkpoint is not None:
print(f"==> Using checkpoint: {args.checkpoint}")
print('==> Preparing data..') print('==> Preparing data..')
test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=4, test_loader = DataLoader(
batch_size=args.batch_size, shuffle=False, drop_last=False) ModelNet40(partition='test', num_points=args.num_points),
num_workers=4,
batch_size=args.batch_size,
shuffle=False,
drop_last=False
)
# Model # Model
print('==> Building model..') print('==> Building model..')
net = models.__dict__[args.model]() net = models.__dict__[args.model]()
criterion = cal_loss criterion = cal_loss
net = net.to(device) net = net.to(device)
checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth') # checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth')
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu')) checkpoint = torch.load(args.checkpoint, map_location=torch.device('cpu'))
# criterion = criterion.to(device) # criterion = criterion.to(device)
if device == 'cuda': if device == 'cuda':
net = torch.nn.DataParallel(net) net = torch.nn.DataParallel(net)