# Shape-as-Point/train.py
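#
# Usage sketch (the config path is a placeholder; pass whichever experiment
# YAML you want to train with):
#   python train.py <path/to/config.yaml> [--no_cuda] [--seed 1]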

import os
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
os.chdir(dname)
import torch
import torch.optim as optim
import open3d as o3d
import numpy as np
np.set_printoptions(precision=4)
import shutil, argparse, time
from torch.utils.tensorboard import SummaryWriter
from src import config
from src.data import collate_remove_none, collate_stack_together, worker_init_fn
from src.training import Trainer
from src.model import Encode2Points
from src.utils import load_config, initialize_logger, \
    AverageMeter, load_model_manual


def main():
    parser = argparse.ArgumentParser(description='Shape-as-Point training')
    parser.add_argument('config', type=str, help='Path to config file.')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    args = parser.parse_args()

    cfg = load_config(args.config, 'configs/default.yaml')
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    input_type = cfg['data']['input_type']
    batch_size = cfg['train']['batch_size']
    model_selection_metric = cfg['train']['model_selection_metric']
    # requires PyTorch >= 1.0.0
    assert int(torch.__version__.split('.')[0]) >= 1
    # boiler-plate
    if cfg['train']['timestamp']:
        cfg['train']['out_dir'] += '_' + time.strftime("%Y_%m_%d_%H_%M_%S")
    logger = initialize_logger(cfg)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    shutil.copyfile(args.config, os.path.join(cfg['train']['out_dir'], 'config.yaml'))

    if use_cuda:
        logger.info("using GPU: " + torch.cuda.get_device_name(0))
    # TensorBoard writer
    tblogdir = os.path.join(cfg['train']['out_dir'], "tensorboard_log")
    os.makedirs(tblogdir, exist_ok=True)
    writer = SummaryWriter(log_dir=tblogdir)
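
    # datasets and loaders: training batches use the configured batch size,
    # while validation and visualization run one sample at a time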
    inputs = None
    train_dataset = config.get_dataset('train', cfg)
    val_dataset = config.get_dataset('val', cfg)
    vis_dataset = config.get_dataset('vis', cfg)
    collate_fn = collate_remove_none
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, num_workers=cfg['train']['n_workers'],
        shuffle=True, collate_fn=collate_fn, worker_init_fn=worker_init_fn)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'],
        shuffle=False, collate_fn=collate_remove_none, worker_init_fn=worker_init_fn)
    vis_loader = torch.utils.data.DataLoader(
        vis_dataset, batch_size=1, num_workers=cfg['train']['n_workers_val'],
        shuffle=False, collate_fn=collate_fn, worker_init_fn=worker_init_fn)
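
    # build the Encode2Points model; wrap it in DataParallel when more than one GPU is available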
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(Encode2Points(cfg)).to(device)
    else:
        model = Encode2Points(cfg).to(device)

    n_parameter = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info('Number of parameters: %d' % n_parameter)
    # resume from an existing checkpoint in out_dir if one is found
    try:
        state_dict = torch.load(os.path.join(cfg['train']['out_dir'], 'model.pt'))
        load_model_manual(state_dict['state_dict'], model)
        logger.info("Load model from iteration %d" % state_dict.get('it', 0))
    except Exception:
        # no usable checkpoint: start training from scratch
        state_dict = dict()

    metric_val_best = state_dict.get('loss_val_best', np.inf)
    logger.info('Current best validation metric (%s): %.8f'
                % (model_selection_metric, metric_val_best))
    LR = float(cfg['train']['lr'])
    optimizer = optim.Adam(model.parameters(), lr=LR)
    start_epoch = state_dict.get('epoch', -1)
    it = state_dict.get('it', -1)

    trainer = Trainer(cfg, optimizer, device=device)
    runtime = {}
    runtime['all'] = AverageMeter()
    # training loop
    for epoch in range(start_epoch + 1, cfg['train']['total_epochs'] + 1):
        for batch in train_loader:
            it += 1

            start = time.time()
            loss, loss_each = trainer.train_step(inputs, batch, model)
            # measure elapsed time
            end = time.time()
            runtime['all'].update(end - start)
            if it % cfg['train']['print_every'] == 0:
                log_text = '[Epoch %02d] it=%d, loss=%.4f' % (epoch, it, loss)
                writer.add_scalar('train/loss', loss, it)
                if loss_each is not None:
                    for k, l in loss_each.items():
                        if l.item() != 0.:
                            log_text += ' loss_%s=%.4f' % (k, l.item())
                        writer.add_scalar('train/%s' % k, l, it)
                log_text += ' time=%.3f / %.2f' % (runtime['all'].val, runtime['all'].sum)
                logger.info(log_text)
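
            # periodically export meshes and point clouds (at most 5 visualization batches)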
            if it > 0 and it % cfg['train']['visualize_every'] == 0:
                for i, batch_vis in enumerate(vis_loader):
                    trainer.save(model, batch_vis, it, i)
                    if i >= 4:
                        break
                logger.info('Saved mesh and pointcloud')
            # run validation
            if it > 0 and (it % cfg['train']['validate_every']) == 0:
                eval_dict = trainer.evaluate(val_loader, model)
                metric_val = eval_dict[model_selection_metric]
                logger.info('Validation metric (%s): %.4f'
                            % (model_selection_metric, metric_val))
                for k, v in eval_dict.items():
                    writer.add_scalar('val/%s' % k, v, it)

                # lower is better for the selection metric
                if metric_val <= metric_val_best:
                    metric_val_best = metric_val
                    logger.info('New best model (loss %.4f)' % metric_val_best)
                    state = {'epoch': epoch,
                             'it': it,
                             'loss_val_best': metric_val_best,
                             'state_dict': model.state_dict()}
                    torch.save(state, os.path.join(cfg['train']['out_dir'], 'model_best.pt'))
            # save checkpoint
            if epoch > 0 and it % cfg['train']['checkpoint_every'] == 0:
                state = {'epoch': epoch,
                         'it': it,
                         'loss_val_best': metric_val_best,
                         'state_dict': model.state_dict()}
                torch.save(state, os.path.join(cfg['train']['out_dir'], 'model.pt'))

                if it % cfg['train']['backup_every'] == 0:
                    torch.save(state, os.path.join(cfg['train']['dir_model'], '%04d' % it + '.pt'))
                    logger.info("Backup model at iteration %d" % it)
                logger.info("Save new model at iteration %d" % it)
    done = time.time()


if __name__ == '__main__':
    main()