update
This commit is contained in:
parent
eb59980e47
commit
36fa3171ba
|
@ -1,4 +1,3 @@
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
@ -26,10 +25,10 @@ def parse_args():
|
||||||
parser.add_argument('--msg', type=str, help='message after checkpoint')
|
parser.add_argument('--msg', type=str, help='message after checkpoint')
|
||||||
parser.add_argument('--batch_size', type=int, default=32, help='batch size in training')
|
parser.add_argument('--batch_size', type=int, default=32, help='batch size in training')
|
||||||
parser.add_argument('--model', default='PointNet', help='model name [default: pointnet_cls]')
|
parser.add_argument('--model', default='PointNet', help='model name [default: pointnet_cls]')
|
||||||
parser.add_argument('--epoch', default=200, type=int, help='number of epoch in training')
|
parser.add_argument('--epoch', default=300, type=int, help='number of epoch in training')
|
||||||
parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
|
parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
|
||||||
parser.add_argument('--learning_rate', default=0.1, type=float, help='learning rate in training')
|
parser.add_argument('--learning_rate', default=0.1, type=float, help='learning rate in training')
|
||||||
parser.add_argument('--weight_decay', type=float, default=1e-4, help='decay rate')
|
parser.add_argument('--weight_decay', type=float, default=2e-4, help='decay rate')
|
||||||
parser.add_argument('--seed', type=int, help='random seed')
|
parser.add_argument('--seed', type=int, help='random seed')
|
||||||
parser.add_argument('--workers', default=8, type=int, help='workers')
|
parser.add_argument('--workers', default=8, type=int, help='workers')
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
@ -56,8 +55,8 @@ def main():
|
||||||
if args.msg is None:
|
if args.msg is None:
|
||||||
message = time_str
|
message = time_str
|
||||||
else:
|
else:
|
||||||
message = "-"+args.msg
|
message = "-" + args.msg
|
||||||
args.checkpoint = 'checkpoints/' + args.model + message + '-'+str(args.seed)
|
args.checkpoint = 'checkpoints/' + args.model + message + '-' + str(args.seed)
|
||||||
if not os.path.isdir(args.checkpoint):
|
if not os.path.isdir(args.checkpoint):
|
||||||
mkdir_p(args.checkpoint)
|
mkdir_p(args.checkpoint)
|
||||||
|
|
||||||
|
@ -68,12 +67,11 @@ def main():
|
||||||
file_handler.setLevel(logging.INFO)
|
file_handler.setLevel(logging.INFO)
|
||||||
file_handler.setFormatter(formatter)
|
file_handler.setFormatter(formatter)
|
||||||
screen_logger.addHandler(file_handler)
|
screen_logger.addHandler(file_handler)
|
||||||
|
|
||||||
def printf(str):
|
def printf(str):
|
||||||
screen_logger.info(str)
|
screen_logger.info(str)
|
||||||
print(str)
|
print(str)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Model
|
# Model
|
||||||
printf(f"args: {args}")
|
printf(f"args: {args}")
|
||||||
printf('==> Building model..')
|
printf('==> Building model..')
|
||||||
|
@ -94,7 +92,6 @@ def main():
|
||||||
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
|
start_epoch = 0 # start from epoch 0 or last checkpoint epoch
|
||||||
optimizer_dict = None
|
optimizer_dict = None
|
||||||
|
|
||||||
|
|
||||||
if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")):
|
if not os.path.isfile(os.path.join(args.checkpoint, "last_checkpoint.pth")):
|
||||||
save_args(args)
|
save_args(args)
|
||||||
logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model)
|
logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model)
|
||||||
|
@ -116,19 +113,16 @@ def main():
|
||||||
logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model, resume=True)
|
logger = Logger(os.path.join(args.checkpoint, 'log.txt'), title="ModelNet" + args.model, resume=True)
|
||||||
optimizer_dict = checkpoint['optimizer']
|
optimizer_dict = checkpoint['optimizer']
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
printf('==> Preparing data..')
|
printf('==> Preparing data..')
|
||||||
train_loader = DataLoader(ModelNet40(partition='train', num_points=args.num_points), num_workers=args.workers,
|
train_loader = DataLoader(ModelNet40(partition='train', num_points=args.num_points), num_workers=args.workers,
|
||||||
batch_size=args.batch_size, shuffle=True, drop_last=True)
|
batch_size=args.batch_size, shuffle=True, drop_last=True)
|
||||||
test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=args.workers,
|
test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=args.workers,
|
||||||
batch_size=args.batch_size//2, shuffle=False, drop_last=False)
|
batch_size=args.batch_size // 2, shuffle=False, drop_last=False)
|
||||||
|
|
||||||
optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
|
optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate, momentum=0.9, weight_decay=args.weight_decay)
|
||||||
if optimizer_dict is not None:
|
if optimizer_dict is not None:
|
||||||
optimizer.load_state_dict(optimizer_dict)
|
optimizer.load_state_dict(optimizer_dict)
|
||||||
scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=1e-3, last_epoch=start_epoch-1)
|
scheduler = CosineAnnealingLR(optimizer, args.epoch, eta_min=1e-3, last_epoch=start_epoch - 1)
|
||||||
|
|
||||||
|
|
||||||
for epoch in range(start_epoch, args.epoch):
|
for epoch in range(start_epoch, args.epoch):
|
||||||
printf('Epoch(%d/%s) Learning Rate %s:' % (epoch + 1, args.epoch, optimizer.param_groups[0]['lr']))
|
printf('Epoch(%d/%s) Learning Rate %s:' % (epoch + 1, args.epoch, optimizer.param_groups[0]['lr']))
|
||||||
|
@ -149,16 +143,15 @@ def main():
|
||||||
best_test_loss = test_out["loss"] if (test_out["loss"] < best_test_loss) else best_test_loss
|
best_test_loss = test_out["loss"] if (test_out["loss"] < best_test_loss) else best_test_loss
|
||||||
best_train_loss = train_out["loss"] if (train_out["loss"] < best_train_loss) else best_train_loss
|
best_train_loss = train_out["loss"] if (train_out["loss"] < best_train_loss) else best_train_loss
|
||||||
|
|
||||||
|
|
||||||
save_model(
|
save_model(
|
||||||
net, epoch, path=args.checkpoint, acc=test_out["acc"], is_best=is_best,
|
net, epoch, path=args.checkpoint, acc=test_out["acc"], is_best=is_best,
|
||||||
best_test_acc=best_test_acc, # best test accuracy
|
best_test_acc=best_test_acc, # best test accuracy
|
||||||
best_train_acc = best_train_acc,
|
best_train_acc=best_train_acc,
|
||||||
best_test_acc_avg = best_test_acc_avg,
|
best_test_acc_avg=best_test_acc_avg,
|
||||||
best_train_acc_avg = best_train_acc_avg,
|
best_train_acc_avg=best_train_acc_avg,
|
||||||
best_test_loss = best_test_loss,
|
best_test_loss=best_test_loss,
|
||||||
best_train_loss = best_train_loss,
|
best_train_loss=best_train_loss,
|
||||||
optimizer = optimizer.state_dict()
|
optimizer=optimizer.state_dict()
|
||||||
)
|
)
|
||||||
logger.append([epoch, optimizer.param_groups[0]['lr'],
|
logger.append([epoch, optimizer.param_groups[0]['lr'],
|
||||||
train_out["loss"], train_out["acc_avg"], train_out["acc"],
|
train_out["loss"], train_out["acc_avg"], train_out["acc"],
|
||||||
|
@ -178,8 +171,6 @@ def main():
|
||||||
printf(f"++++++++" * 5)
|
printf(f"++++++++" * 5)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def train(net, trainloader, optimizer, criterion, device):
|
def train(net, trainloader, optimizer, criterion, device):
|
||||||
net.train()
|
net.train()
|
||||||
train_loss = 0
|
train_loss = 0
|
||||||
|
|
|
@ -1,6 +1,3 @@
|
||||||
"""
|
|
||||||
nohup python voting.py --model model31A --msg 20210818204651 > model31A_20210818204651_voting.out &
|
|
||||||
"""
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import datetime
|
import datetime
|
||||||
|
@ -32,10 +29,7 @@ def parse_args():
|
||||||
parser.add_argument('--batch_size', type=int, default=32, help='batch size in training')
|
parser.add_argument('--batch_size', type=int, default=32, help='batch size in training')
|
||||||
parser.add_argument('--model', default='model31A', help='model name [default: pointnet_cls]')
|
parser.add_argument('--model', default='model31A', help='model name [default: pointnet_cls]')
|
||||||
parser.add_argument('--num_classes', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
|
parser.add_argument('--num_classes', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
|
||||||
parser.add_argument('--epoch', default=350, type=int, help='number of epoch in training')
|
|
||||||
parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
|
parser.add_argument('--num_points', type=int, default=1024, help='Point Number')
|
||||||
parser.add_argument('--learning_rate', default=0.01, type=float, help='learning rate in training')
|
|
||||||
parser.add_argument('--weight_decay', type=float, default=1e-4, help='decay rate')
|
|
||||||
parser.add_argument('--seed', type=int, help='random seed (default: 1)')
|
parser.add_argument('--seed', type=int, help='random seed (default: 1)')
|
||||||
|
|
||||||
# Voting evaluation, referring: https://github.com/CVMI-Lab/PAConv/blob/main/obj_cls/eval_voting.py
|
# Voting evaluation, referring: https://github.com/CVMI-Lab/PAConv/blob/main/obj_cls/eval_voting.py
|
||||||
|
@ -46,7 +40,7 @@ def parse_args():
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
class PointcloudScale(object): # input random scaling
|
class PointcloudScale(object): # input random scaling
|
||||||
def __init__(self, scale_low=2. / 3., scale_high=3. / 2.):
|
def __init__(self, scale_low=2. / 3., scale_high=3. / 2.):
|
||||||
self.scale_low = scale_low
|
self.scale_low = scale_low
|
||||||
self.scale_high = scale_high
|
self.scale_high = scale_high
|
||||||
|
@ -59,6 +53,7 @@ class PointcloudScale(object): # input random scaling
|
||||||
|
|
||||||
return pc
|
return pc
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
print(f"args: {args}")
|
print(f"args: {args}")
|
||||||
|
@ -82,12 +77,12 @@ def main():
|
||||||
if args.msg is None:
|
if args.msg is None:
|
||||||
message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
|
message = str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))
|
||||||
else:
|
else:
|
||||||
message = "-"+args.msg
|
message = "-" + args.msg
|
||||||
args.checkpoint = 'checkpoints/' + args.model + message
|
args.checkpoint = 'checkpoints/' + args.model + message
|
||||||
|
|
||||||
print('==> Preparing data..')
|
print('==> Preparing data..')
|
||||||
test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=4,
|
test_loader = DataLoader(ModelNet40(partition='test', num_points=args.num_points), num_workers=4,
|
||||||
batch_size=args.batch_size//2, shuffle=False, drop_last=False)
|
batch_size=args.batch_size // 2, shuffle=False, drop_last=False)
|
||||||
# Model
|
# Model
|
||||||
print('==> Building model..')
|
print('==> Building model..')
|
||||||
net = models.__dict__[args.model]()
|
net = models.__dict__[args.model]()
|
||||||
|
@ -115,8 +110,6 @@ def main():
|
||||||
voting(net, test_loader, device, args)
|
voting(net, test_loader, device, args)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def validate(net, testloader, criterion, device):
|
def validate(net, testloader, criterion, device):
|
||||||
net.eval()
|
net.eval()
|
||||||
test_loss = 0
|
test_loss = 0
|
||||||
|
@ -152,11 +145,11 @@ def validate(net, testloader, criterion, device):
|
||||||
|
|
||||||
|
|
||||||
def voting(net, testloader, device, args):
|
def voting(net, testloader, device, args):
|
||||||
name ='/evaluate_voting'+str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S'))+'seed_'+str(args.seed)+'.log'
|
name = '/evaluate_voting' + str(datetime.datetime.now().strftime('-%Y%m%d%H%M%S')) + 'seed_' + str(
|
||||||
|
args.seed) + '.log'
|
||||||
io = IOStream(args.checkpoint + name)
|
io = IOStream(args.checkpoint + name)
|
||||||
io.cprint(str(args))
|
io.cprint(str(args))
|
||||||
|
|
||||||
|
|
||||||
net.eval()
|
net.eval()
|
||||||
best_acc = 0
|
best_acc = 0
|
||||||
best_mean_acc = 0
|
best_mean_acc = 0
|
||||||
|
@ -199,11 +192,5 @@ def voting(net, testloader, device, args):
|
||||||
io.cprint(final_outstr)
|
io.cprint(final_outstr)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
Loading…
Reference in a new issue