LION/train_dist.py
2023-01-23 00:14:49 -05:00

252 lines
11 KiB
Python

# ---------------------------------------------------------------
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
# ---------------------------------------------------------------
import importlib
import argparse
from loguru import logger
from comet_ml import Experiment
import torch
import numpy as np
import os
import sys
import torch.distributed as dist
from torch.multiprocessing import Process
from default_config import cfg as config
from utils import exp_helper, io_helper
from utils import utils
@logger.catch(onerror=lambda _: sys.exit(1), reraise=False)
def main(args, config):
    """Build the trainer named by ``config.trainer.type`` and train or evaluate.

    Any exception raised in here is logged by ``logger.catch`` and, because of
    ``onerror``/``reraise=False``, the process exits with status 1.

    Args:
        args: command-line namespace produced by ``get_args`` (plus the
            distributed fields set by the launcher). ``args.resume`` and
            ``args.pretrained`` are overwritten here when a snapshot is found
            under ``<config.save_dir>/checkpoints/snapshot``.
        config: merged experiment config produced by ``get_args``.

    Raises:
        ValueError: resume / eval_generation was requested but no checkpoint
            path is available (logged by the decorator, process exits 1).
    """
    # -- trainer -- #
    logger.info('use trainer: {}', config.trainer.type)
    # config.trainer.type is an importable module name exposing `Trainer`
    trainer_lib = importlib.import_module(config.trainer.type)
    Trainer = trainer_lib.Trainer
    if config.set_detect_anomaly:
        # attention: this makes thing slow
        torch.autograd.set_detect_anomaly(True)
        logger.info(
            '\n\n' + '!'*30 + '\nWARNING: ths set_detect_anomaly is turned on, it can slow down the training! \n' + '!'*30)

    # -- command init -- #
    comet_key = config.comet_key
    _, writer = utils.common_init(args.global_rank,
                                  config.trainer.seed, config.save_dir, comet_key)
    trainer = Trainer(config, args)
    writer.add_hparams(config.to_dict(), vars(args))
    nparam = utils.count_parameters_in_M(trainer.model)
    logger.info('param size = %fM ' % nparam)
    writer.log_other('nparam', nparam)
    if args.global_rank == 0:
        trainer.set_writer(writer)
        writer.set_model_graph('{}'.format(trainer.model), overwrite=True)
        if config.bash_name:
            # upload the launch script, and the copy of it kept in save_dir if present
            if os.path.exists(config.bash_name):
                writer.log_asset(config.bash_name)
            saved_bash = os.path.join(
                config.save_dir, os.path.basename(config.bash_name))
            if os.path.exists(saved_bash):
                writer.log_asset(saved_bash)

    ckpt_dir = os.path.join(config.save_dir, 'checkpoints')
    snapshot_file = os.path.join(ckpt_dir, 'snapshot')
    # -- check if prev saved ckpt exist -- #
    # An existing snapshot means a previous run of this experiment was
    # interrupted (e.g. preempted): resume from it automatically.
    if os.path.exists(snapshot_file):
        logger.info(
            '[Detect saved snapshot at the checkpoint dir] resume from preemption!!! ')
        args.resume = True
        args.pretrained = snapshot_file
    else:
        logger.info('not find any checkpoint: {}, (exist={}), or snapshot {}, (exist={})',
                    ckpt_dir, os.path.exists(ckpt_dir), snapshot_file, os.path.exists(snapshot_file))

    # -- prepare -- #
    if args.resume or args.eval_generation:
        if args.pretrained is None:
            raise ValueError(
                'resume / eval_generation requires a checkpoint path (--pretrained)')
        trainer.start_epoch = trainer.resume(
            args.pretrained, eval_generation=args.eval_generation)
    elif args.pretrained is not None:
        # plain initialization from a pretrained VAE, not a full resume
        trainer.load_vae(args.pretrained)

    if not args.eval_generation:
        trainer.train_epochs()
    else:
        logger.info('[skip_sample]={}', args.skip_sample)
        # NOTE(review): save_file is never reassigned, so the log below always
        # prints None and vis_sample receives None; eval_nll's return value is
        # discarded. Confirm whether eval_nll returns the saved path.
        save_file = None
        if not args.skip_nll:
            trainer.eval_nll(trainer.step, ntest=args.ntest, save_file=True)
            logger.info('save as : {}', save_file)
        # vis sampled output
        if not args.skip_sample:
            trainer.vis_sample(num_vis=8, writer=trainer.writer,
                               step=trainer.step, include_pred_x0=False,
                               save_file=save_file)
            trainer.eval_sample(trainer.step)
    logger.info('done')
    # make all nodes wait for rank 0 to finish saving the files
    # if args.distributed:
    #     dist.barrier()
def get_args():
    """Parse the command line and build the experiment config.

    The module-level ``config`` (from ``default_config``) is mutated in place:

    * with ``--eval_generation`` or ``--resume``, the ``cfg.yml`` located at
      ``<dirname(--pretrained)>/../cfg.yml`` is merged in;
    * otherwise, if ``--config`` is given, that file is merged in while keeping
      the current ``exp_name`` and ``hash``;
    * the trailing ``opt`` tokens are then applied via ``merge_from_list``.

    ``save_dir``/``log_dir``/``log_name`` are then set. For evaluation they
    point to an ``eval``/``eval_ddim`` subfolder next to the loaded cfg. For
    training they point to ``<exp_root>/<experiment name>``, and on rank 0 the
    final config is written to ``cfg.yml`` there.

    Returns:
        tuple: ``(args, config)``.
    """
    parser = argparse.ArgumentParser('encoder decoder examiner')
    # experimental results
    parser.add_argument('--exp_root', type=str, default='../exp',
                        help='location of the results')
    parser.add_argument('--skip_sample', type=int, default=0,
                        help='only eval nll, no sampling')
    parser.add_argument('--skip_nll', type=int, default=0,
                        help='skip eval nll ')
    # data
    parser.add_argument('--ntest', type=str, default=None,
                        help='number of samples in eval_nll, if None, eval the whole val set')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'celeba_64', 'celeba_256',
                                 'imagenet_32', 'ffhq', 'lsun_bedroom_128'],
                        help='which dataset to use')
    parser.add_argument('--data', type=str, default='/tmp/nvae-diff/data',
                        help='location of the data corpus')
    # DDP.
    # The --autocast_* flags default to True, so passing them has no effect;
    # the --no_autocast_* variants write the same dest and allow turning FP16 off.
    parser.add_argument('--autocast_train', action='store_true', default=True,
                        help='This flag enables FP16 in training.')
    parser.add_argument('--no_autocast_train', dest='autocast_train',
                        action='store_false',
                        help='Disable FP16 in training.')
    parser.add_argument('--autocast_eval', action='store_true', default=True,
                        help='This flag enables FP16 in evaluation.')
    parser.add_argument('--no_autocast_eval', dest='autocast_eval',
                        action='store_false',
                        help='Disable FP16 in evaluation.')
    parser.add_argument('--num_proc_node', type=int, default=1,
                        help='The number of nodes in multi node env.')
    parser.add_argument('--node_rank', type=int, default=0,
                        help='The index of node.')
    parser.add_argument('--local_rank', type=int, default=0,
                        help='rank of process in the node')
    parser.add_argument('--global_rank', type=int, default=0,
                        help='rank of process among all the processes')
    parser.add_argument('--num_process_per_node', type=int, default=1,
                        help='number of gpus')
    parser.add_argument('--master_address', type=str, default='127.0.0.1',
                        help='address for master')
    parser.add_argument('--seed', type=int, default=1,
                        help='seed used for initialization')
    parser.add_argument('--config', type=str,
                        help='The configuration file.', default='none')
    # everything after the known options is forwarded to config.merge_from_list
    parser.add_argument("opt",
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    # Resume:
    parser.add_argument('--resume', default=False, action='store_true')
    parser.add_argument('--eval_generation',
                        default=False, action='store_true')
    parser.add_argument('--pretrained',
                        default=None,
                        type=str,
                        help="Pretrained cehckpoint")
    args = parser.parse_args()

    # update config
    if args.eval_generation or args.resume:
        if args.pretrained is None:
            # without this check os.path.dirname(None) raises a TypeError below
            parser.error(
                '--pretrained is required with --resume or --eval_generation')
        logger.info('[pretrained]: {}', args.pretrained)
        # the cfg is expected one level above the checkpoint's directory;
        # os.path.join avoids producing '/../cfg.yml' when dirname() is ''
        args.config = os.path.join(
            os.path.dirname(args.pretrained), '..', 'cfg.yml')
        config.merge_from_file(args.config)
    elif args.config != 'none':
        logger.info('load config: {}', args.config)
        cur_exp_name = config.exp_name
        cur_hash = config.hash
        config.merge_from_file(args.config)
        config.exp_name = cur_exp_name  # not following the exp name here
        config.hash = cur_hash  # not following the exp name here
    config.merge_from_list(args.opt)

    # Create log_name
    EXP_ROOT = args.exp_root  # os.environ.get('EXP_ROOT', '../exp/')
    if config.exp_name in ('', 'none'):
        config.hash = io_helper.hash_str('%s' % config) + 'h'
        cfg_file_name = exp_helper.get_expname(config)
    else:
        cfg_file_name = config.exp_name

    # Currently save dir and log_dir are the same
    if args.eval_generation:
        tag = 'eval_ddim' if config.trainer.type == 'ddim' else 'eval'
        eval_dir = os.path.dirname(args.config) + f'/{tag}/'
        config.save_dir = config.log_dir = config.log_name = eval_dir
        cfg_file_name += f'/{tag}/'
    else:
        exp_dir = os.path.join(EXP_ROOT, cfg_file_name)
        config.log_name = config.save_dir = config.log_dir = exp_dir
    os.makedirs(config.log_dir, exist_ok=True)

    # save config and log
    if args.global_rank == 0 and not args.eval_generation:
        logger.add(os.path.join(config.log_dir, 'train.log'))
        logger.info('EXP_ROOT: {} + exp name: {}, save dir: {}', EXP_ROOT,
                    cfg_file_name, config.save_dir)
        saved_cfg = os.path.join(config.log_dir, 'cfg.yml')
        with open(saved_cfg, 'w') as f:
            f.write(config.dump())
        logger.info('save config at {}', saved_cfg)
    elif args.eval_generation:
        logger.add(os.path.join(config.log_dir, 'eval_gen.log'))
        logger.info('log dir: {}', config.log_dir)
    return args, config
if __name__ == '__main__':
    args, config = get_args()
    # --ntest is parsed as a string so that its default can stay None
    args.ntest = int(args.ntest) if args.ntest is not None else None
    size = args.num_process_per_node
    if size > 1:
        # one process per local GPU; each runs utils.init_processes(..., main, ...)
        args.distributed = True
        global_size = args.num_proc_node * args.num_process_per_node
        args.global_size = global_size
        processes = []
        for rank in range(size):
            logger.info('In Rank={}', rank)
            global_rank = rank + args.node_rank * args.num_process_per_node
            # give every child its own copy so the per-rank fields are not
            # shared through a single mutated namespace
            proc_args = argparse.Namespace(**vars(args))
            proc_args.local_rank = rank
            proc_args.global_rank = global_rank
            logger.info('Node rank %d, local proc %d, global proc %d' %
                        (args.node_rank, rank, global_rank))
            p = Process(target=utils.init_processes,
                        args=(global_rank, global_size, main, proc_args, config))
            p.start()
            processes.append(p)
        for rank, p in enumerate(processes):
            logger.info('join {}', rank)
            p.join()
        # main() exits with status 1 on error; propagate that from the
        # children so the launcher's exit status matches single-process mode
        failed = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
        if failed:
            logger.error('local ranks {} exited with a non-zero status', failed)
            sys.exit(1)
    else:
        # for debugging
        # NOTE(review): this path is also taken with --num_proc_node > 1 and one
        # process per node, in which case node_rank is ignored; confirm intended.
        args.distributed = False
        args.global_size = 1
        utils.init_processes(0, size, main, args, config)
    logger.info('should end now')
    # if args.distributed:
    #     logger.info('destroy_process_group')
    #     dist.destroy_process_group()