252 lines
11 KiB
Python
252 lines
11 KiB
Python
|
# ---------------------------------------------------------------
|
||
|
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||
|
#
|
||
|
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||
|
# and proprietary rights in and to this software, related documentation
|
||
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
||
|
# distribution of this software and related documentation without an express
|
||
|
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||
|
# ---------------------------------------------------------------
|
||
|
|
||
|
import importlib
|
||
|
import argparse
|
||
|
from loguru import logger
|
||
|
from comet_ml import Experiment
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
import os
|
||
|
import sys
|
||
|
import torch.distributed as dist
|
||
|
from torch.multiprocessing import Process
|
||
|
from default_config import cfg as config
|
||
|
from utils import exp_helper, io_helper
|
||
|
from utils import utils
|
||
|
|
||
|
|
||
|
@logger.catch(onerror=lambda _: sys.exit(1), reraise=False)
def main(args, config):
    """Build the trainer for this process and run training or evaluation.

    The ``logger.catch`` decorator logs any uncaught exception and turns it
    into ``sys.exit(1)`` in the process that runs this function.

    Args:
        args: parsed command-line namespace from ``get_args`` plus the
            ``distributed`` / ``global_size`` / ``global_rank`` fields set by
            the launcher. ``args.resume`` and ``args.pretrained`` are
            overwritten here when a snapshot already exists under
            ``config.save_dir/checkpoints``.
        config: experiment config node. ``config.trainer.type`` is the name
            of an importable module that defines a ``Trainer`` class.

    Raises:
        NotImplementedError: resuming or evaluating was requested but no
            checkpoint path is available in ``args.pretrained``.
    """
    # -- trainer -- #
    logger.info('use trainer: {}', config.trainer.type)
    trainer_lib = importlib.import_module(config.trainer.type)
    Trainer = trainer_lib.Trainer

    if config.set_detect_anomaly:
        # attention: this makes things slow
        torch.autograd.set_detect_anomaly(True)
        logger.info(
            '\n\n' + '!'*30 + '\nWARNING: ths set_detect_anomaly is turned on, it can slow down the training! \n' + '!'*30)

    # -- common init -- #
    comet_key = config.comet_key
    _, writer = utils.common_init(args.global_rank,
                                  config.trainer.seed, config.save_dir, comet_key)
    trainer = Trainer(config, args)
    writer.add_hparams(config.to_dict(), vars(args))
    nparam = utils.count_parameters_in_M(trainer.model)
    logger.info('param size = %fM ' % nparam)
    writer.log_other('nparam', nparam)

    if args.global_rank == 0:
        # Only rank 0 attaches the experiment writer and uploads assets.
        trainer.set_writer(writer)
        writer.set_model_graph('{}'.format(trainer.model), overwrite=True)
        if len(config.bash_name) > 0:
            # Upload the launch script from its original location and the
            # copy (same file name) inside save_dir, whichever exist.
            if os.path.exists(config.bash_name):
                writer.log_asset(config.bash_name)
            bash_copy = os.path.join(
                config.save_dir, config.bash_name.split('/')[-1])
            if os.path.exists(bash_copy):
                writer.log_asset(bash_copy)

    ckpt_dir = os.path.join(config.save_dir, 'checkpoints')
    snapshot_file = os.path.join(ckpt_dir, 'snapshot')

    # -- check if prev saved ckpt exist -- #
    # A snapshot left by an earlier (e.g. preempted) run forces resumption
    # from it, overriding any --resume / --pretrained given on the CLI.
    if os.path.exists(ckpt_dir) and os.path.exists(snapshot_file):
        logger.info(
            '[Detect saved snapshot at the checkpoint dir] resume from preemption!!! ')
        args.resume = True
        args.pretrained = snapshot_file
    else:
        logger.info('not find any checkpoint: {}, (exist={}), or snapshot {}, (exist={})',
                    ckpt_dir, os.path.exists(ckpt_dir), snapshot_file, os.path.exists(snapshot_file))

    # -- prepare -- #
    if args.resume or args.eval_generation:
        if args.pretrained is None:
            raise NotImplementedError(
                'resume / eval_generation requires a checkpoint path '
                '(--pretrained) or an existing snapshot in the checkpoint dir')
        trainer.start_epoch = trainer.resume(
            args.pretrained, eval_generation=args.eval_generation)
    elif args.pretrained is not None:
        # Fresh training run initialised from pretrained VAE weights.
        trainer.load_vae(args.pretrained)

    if not args.eval_generation:
        trainer.train_epochs()
    else:
        logger.info('[skip_sample]={}', args.skip_sample)

        # NOTE(review): save_file is never reassigned. eval_nll is called
        # with save_file=True and its return value is discarded, so the
        # 'save as' log below always prints None. Confirm whether eval_nll
        # returns the written path before wiring it through.
        save_file = None
        if not args.skip_nll:
            trainer.eval_nll(trainer.step, ntest=args.ntest, save_file=True)
            logger.info('save as : {}', save_file)
        # vis sampled output
        if not args.skip_sample:
            trainer.vis_sample(num_vis=8, writer=trainer.writer,
                               step=trainer.step, include_pred_x0=False,
                               save_file=save_file)
            trainer.eval_sample(trainer.step)
        logger.info('done')

    # make all nodes wait for rank 0 to finish saving the files
    # if args.distributed:
    #     dist.barrier()
|
||
|
|
||
|
|
||
|
def get_args():
    """Parse the command line and build the experiment config.

    The module-level ``config`` is updated in place from these sources:
      * with ``--resume`` / ``--eval_generation``: from the ``cfg.yml`` one
        directory above the directory of ``--pretrained``;
      * otherwise, with ``--config``: from that file, keeping the current
        ``exp_name`` and ``hash``;
      * in all cases: from the trailing ``KEY VALUE`` overrides (``opt``).

    The function then derives ``save_dir`` / ``log_dir`` / ``log_name`` and
    creates ``log_dir``. It also adds a loguru file sink. On rank 0 of a
    training run, it writes the final config to ``<log_dir>/cfg.yml``.

    Returns:
        tuple: ``(args, config)``.
    """
    parser = argparse.ArgumentParser('encoder decoder examiner')
    # experimental results
    parser.add_argument('--exp_root', type=str, default='../exp',
                        help='location of the results')
    # parser.add_argument('--save', type=str, default='exp',
    #                     help='id used for storing intermediate results')
    # parser.add_argument('--recont_with_local_prior', type=bool, default=False,
    #                     help='eval nll with local prior sampled from normal distribution')
    parser.add_argument('--skip_sample', type=int, default=0,
                        help='only eval nll, no sampling')
    parser.add_argument('--skip_nll', type=int, default=0,
                        help='skip eval nll ')
    # data
    parser.add_argument('--ntest', type=str, default=None,
                        help='number of samples in eval_nll, if None, eval the whole val set')
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'celeba_64', 'celeba_256',
                                 'imagenet_32', 'ffhq', 'lsun_bedroom_128'],
                        help='which dataset to use')
    parser.add_argument('--data', type=str, default='/tmp/nvae-diff/data',
                        help='location of the data corpus')
    # DDP.
    # The --autocast_* flags default to True, so store_true alone could never
    # turn them off; the --no_autocast_* flags write False to the same dest.
    parser.add_argument('--autocast_train', action='store_true', default=True,
                        help='This flag enables FP16 in training.')
    parser.add_argument('--no_autocast_train', dest='autocast_train',
                        action='store_false',
                        help='Disable FP16 in training.')
    parser.add_argument('--autocast_eval', action='store_true', default=True,
                        help='This flag enables FP16 in evaluation.')
    parser.add_argument('--no_autocast_eval', dest='autocast_eval',
                        action='store_false',
                        help='Disable FP16 in evaluation.')
    parser.add_argument('--num_proc_node', type=int, default=1,
                        help='The number of nodes in multi node env.')
    parser.add_argument('--node_rank', type=int, default=0,
                        help='The index of node.')
    parser.add_argument('--local_rank', type=int, default=0,
                        help='rank of process in the node')
    parser.add_argument('--global_rank', type=int, default=0,
                        help='rank of process among all the processes')
    parser.add_argument('--num_process_per_node', type=int, default=1,
                        help='number of gpus')
    parser.add_argument('--master_address', type=str, default='127.0.0.1',
                        help='address for master')
    parser.add_argument('--seed', type=int, default=1,
                        help='seed used for initialization')
    parser.add_argument('--config', type=str,
                        help='The configuration file.', default='none')
    parser.add_argument("opt",
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)
    # Resume:
    parser.add_argument('--resume', default=False, action='store_true')
    parser.add_argument('--eval_generation',
                        default=False, action='store_true')
    parser.add_argument('--pretrained',
                        default=None,
                        type=str,
                        help="Pretrained cehckpoint")

    args = parser.parse_args()

    # update config
    if args.eval_generation or args.resume:
        if args.pretrained is None:
            # Without this, os.path.dirname(None) below fails with TypeError.
            parser.error(
                '--pretrained is required with --resume or --eval_generation')
        logger.info('[pretrained]: {}', args.pretrained)
        # The run's cfg.yml lives one directory above the checkpoint's dir.
        args.config = os.path.dirname(args.pretrained) + '/../cfg.yml'
        config.merge_from_file(args.config)
    elif args.config != 'none':
        logger.info('load config: {}', args.config)
        cur_exp_name = config.exp_name
        cur_hash = config.hash

        config.merge_from_file(args.config)
        config.exp_name = cur_exp_name  # not following the exp name here
        config.hash = cur_hash  # not following the exp name here
    config.merge_from_list(args.opt)

    # Create log_name
    EXP_ROOT = args.exp_root  # os.environ.get('EXP_ROOT', '../exp/')
    if config.exp_name == '' or config.exp_name == 'none':
        # No explicit name: derive one from a hash of the full config.
        config.hash = io_helper.hash_str('%s' % config) + 'h'
        cfg_file_name = exp_helper.get_expname(config)
    else:
        cfg_file_name = config.exp_name

    # Currently save dir and log_dir are the same
    if args.eval_generation:
        # Evaluation output goes into a sub-folder of the evaluated run.
        config.save_dir = config.log_dir = config.log_name = os.path.dirname(
            args.config)
        tag = 'eval_ddim' if config.trainer.type == 'ddim' else 'eval'
        cfg_file_name += f'/{tag}/'
        config.log_name += f'/{tag}/'
        config.save_dir += f'/{tag}/'
        config.log_dir += f'/{tag}/'
    else:
        config.log_name = os.path.join(EXP_ROOT, cfg_file_name)
        config.save_dir = os.path.join(EXP_ROOT, cfg_file_name)
        config.log_dir = os.path.join(EXP_ROOT, cfg_file_name)
    os.makedirs(config.log_dir, exist_ok=True)

    # save config and log
    if args.global_rank == 0 and not args.eval_generation:
        logger.add(config.log_dir + '/train.log')
        logger.info('EXP_ROOT: {} + exp name: {}, save dir: {}', EXP_ROOT,
                    cfg_file_name, config.save_dir)
        saved_cfg = os.path.join(config.log_dir, 'cfg.yml')
        with open(saved_cfg, 'w', encoding='utf-8') as cfg_fh:
            cfg_fh.write(config.dump())
        logger.info('save config at {}', saved_cfg)
    elif args.eval_generation:
        logger.add(config.log_dir + '/eval_gen.log')
        logger.info('log dir: {}', config.log_dir)

    return args, config
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    args, config = get_args()
    # --ntest is parsed as a string; None means "evaluate the whole val set".
    args.ntest = int(args.ntest) if args.ntest is not None else None
    size = args.num_process_per_node

    if size > 1:
        args.distributed = True
        # World size is the same for every local process; compute it once.
        global_size = args.num_proc_node * args.num_process_per_node
        args.global_size = global_size
        processes = []
        for rank in range(size):
            logger.info('In Rank={}', rank)
            args.local_rank = rank
            global_rank = rank + args.node_rank * args.num_process_per_node
            args.global_rank = global_rank
            logger.info('Node rank %d, local proc %d, global proc %d' %
                        (args.node_rank, rank, global_rank))
            # `args` is reused across iterations; each child gets the values
            # current at p.start() (fork snapshot or pickled copy).
            p = Process(target=utils.init_processes,
                        args=(global_rank, global_size, main, args, config))
            p.start()
            processes.append(p)

        failed = []
        for local_rank, p in enumerate(processes):
            logger.info('join {}', local_rank)
            p.join()
            if p.exitcode != 0:
                failed.append((local_rank, p.exitcode))
        if failed:
            # main() exits with code 1 on error (see its logger.catch); surface
            # that to the caller instead of silently exiting 0.
            logger.error('worker processes failed (local_rank, exitcode): {}',
                         failed)
            sys.exit(1)
    else:
        # for debugging
        args.distributed = False
        args.global_size = 1
        utils.init_processes(0, size, main, args, config)
    logger.info('should end now')
    # if args.distributed:
    #     logger.info('destroy_process_group')
    #     dist.destroy_process_group()
|