what in tarnation is this shit ?
This commit is contained in:
parent
4fcb30ca42
commit
62e57ecfc0
|
@ -93,7 +93,7 @@ pip install pointnet2_ops_lib/.
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
## Useage
|
## Usage
|
||||||
|
|
||||||
### Classification ModelNet40
|
### Classification ModelNet40
|
||||||
**Train**: The dataset will be automatically downloaded, run following command to train.
|
**Train**: The dataset will be automatically downloaded, run following command to train.
|
||||||
|
|
|
@ -45,22 +45,29 @@ def main():
|
||||||
else:
|
else:
|
||||||
device = 'cpu'
|
device = 'cpu'
|
||||||
print(f"==> Using device: {device}")
|
print(f"==> Using device: {device}")
|
||||||
if args.msg is None:
|
# if args.msg is None:
|
||||||
message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
|
# message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
|
||||||
else:
|
# else:
|
||||||
message = "-"+args.msg
|
# message = "-"+args.msg
|
||||||
args.checkpoint = 'checkpoints/' + args.model + message
|
# args.checkpoint = 'checkpoints/' + args.model + message
|
||||||
|
if args.checkpoint is not None:
|
||||||
|
print(f"==> Using checkpoint: {args.checkpoint}")
|
||||||
|
|
||||||
print('==> Preparing data..')
|
print('==> Preparing data..')
|
||||||
test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=4,
|
test_loader = DataLoader(
|
||||||
batch_size=args.batch_size, shuffle=False, drop_last=False)
|
ModelNet40(partition='test', num_points=args.num_points),
|
||||||
|
num_workers=4,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
drop_last=False
|
||||||
|
)
|
||||||
# Model
|
# Model
|
||||||
print('==> Building model..')
|
print('==> Building model..')
|
||||||
net = models.__dict__[args.model]()
|
net = models.__dict__[args.model]()
|
||||||
criterion = cal_loss
|
criterion = cal_loss
|
||||||
net = net.to(device)
|
net = net.to(device)
|
||||||
checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth')
|
# checkpoint_path = os.path.join(args.checkpoint, 'best_checkpoint.pth')
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
checkpoint = torch.load(args.checkpoint, map_location=torch.device('cpu'))
|
||||||
# criterion = criterion.to(device)
|
# criterion = criterion.to(device)
|
||||||
if device == 'cuda':
|
if device == 'cuda':
|
||||||
net = torch.nn.DataParallel(net)
|
net = torch.nn.DataParallel(net)
|
||||||
|
|
Loading…
Reference in a new issue