LION/utils/exp_helper.py
2023-01-23 00:14:49 -05:00

123 lines
3.8 KiB
Python

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import time
import os
import numpy as np
from loguru import logger
from math import isnan
from calmsize import size as calmsize
def parse_cfg_str(cfg_str):
    """Flatten a ``k1=v1-k2=v2`` string into ``[k1, v1, k2, v2]``.

    Pairs are separated by ``-`` and each pair is written ``key=value``.
    Empty segments (leading, trailing or doubled ``-``) are skipped, so an
    empty string yields an empty list. Because ``-`` is the pair separator,
    neither keys nor values may contain it.

    Args:
        cfg_str: the string to parse, e.g. ``'lr=0.1-bs=32'``.

    Returns:
        A flat list of strings alternating key, value, in input order.

    Raises:
        ValueError: if a non-empty segment does not contain exactly one
            ``=``.
    """
    cfg_expand_list = []
    for pair in cfg_str.split('-'):
        if not pair:
            continue
        # Exactly one '=' is required; anything else would either fail to
        # unpack or silently mis-split the value.
        if pair.count('=') != 1:
            raise ValueError(
                f'wrong format, expect k=v but got {pair!r} in {cfg_str!r}')
        cfg_expand_list.extend(pair.split('='))
    return cfg_expand_list
def readable_size(num_bytes: int) -> str:
    """Render a byte count as a short human-readable string.

    A NaN input produces an empty string. Any other value is converted with
    ``calmsize`` and formatted using the ``.1f`` format spec.
    """
    if isnan(num_bytes):
        return ''
    pretty = calmsize(num_bytes)
    return '{:.1f}'.format(pretty)
class ExpTimer(object):
    """Record per-epoch wall-clock durations and estimate the time left.

    Call ``tic()`` when an epoch starts and ``toc()`` when it ends. Each
    ``toc()`` stores the elapsed seconds and advances the epoch counter.

    Args:
        num_epoch: total number of epochs the run is expected to reach.
        start_epoch: epoch counter value to start from (default 0).
    """

    def __init__(self, num_epoch, start_epoch=0):
        self.cur_epoch = start_epoch
        self.num_epoch = num_epoch
        self.time_list = []  # seconds taken by each completed epoch
        self.last_tic = None  # timestamp of the most recent tic()

    def tic(self):
        """Mark the start of an epoch."""
        self.last_tic = time.time()

    def toc(self):
        """Mark the end of an epoch and record its duration in seconds.

        Raises:
            RuntimeError: if called before any ``tic()``.
        """
        if self.last_tic is None:
            raise RuntimeError('ExpTimer.toc() called before tic()')
        self.time_list.append(time.time() - self.last_tic)
        self.cur_epoch += 1

    def hours_left(self):
        """Return the estimated remaining time in hours.

        The estimate is the mean recorded epoch duration multiplied by the
        number of epochs still to run. Returns 0 when no epoch has been
        timed yet, and never goes negative if ``cur_epoch`` has passed
        ``num_epoch``.
        """
        if not self.time_list:
            return 0
        num_epoch_left = max(self.num_epoch - self.cur_epoch, 0)
        mean_epoch_time = np.mean(self.time_list)
        return (mean_epoch_time * num_epoch_left) / 3600.0  # sec -> hours

    def print(self):
        """Log the estimated remaining hours with one decimal place."""
        # hours_left must be *called*; '.1f' also accepts the int 0 that
        # hours_left() returns before any epoch has been timed.
        logger.info('est: {:.1f}h', self.hours_left())
def format_e(n):
    """Format a number as a compact scientific-notation tag.

    The mantissa has trailing zeros (and a bare trailing dot) removed.
    Negative exponents are written without their sign and without leading
    zeros, e.g. ``1e-4 -> '1E4'`` and ``2.5e-3 -> '2.5E3'``. Non-negative
    exponents keep an explicit ``+`` so they cannot be confused with the
    negative ones, e.g. ``100 -> '1E+2'``.

    Args:
        n: the number to format.

    Returns:
        ``'0'`` for zero, the bare ``'%E'`` rendering for values that have
        no exponent part (inf / nan), otherwise ``<mantissa>E<exponent>``
        as described above.
    """
    if n == 0:
        return '0'
    mantissa, sep, exponent = ('%E' % n).partition('E')
    if not sep:
        # inf / nan render as 'INF' / 'NAN' and carry no exponent.
        return mantissa
    mantissa = mantissa.rstrip('0').rstrip('.')
    exp = int(exponent)
    if exp < 0:
        # Sign is dropped for negative exponents: 1e-4 -> '1E4'.
        return '%sE%d' % (mantissa, -exp)
    return '%sE+%d' % (mantissa, exp)
def get_evalname(config):
    """Build the tag string used to label generated evaluation samples.

    The tag is assembled, in order, from:
      * ``config.ddpm.model_var_type`` unless it is ``'fixedlarge'``
      * ``'noema'`` when ``config.ddpm.ema`` is falsy
      * ``'s<seed>'`` from ``config.trainer.seed``
      * ``'N<points>'`` when ``config.data.te_max_sample_points`` is not 2048
      * ``'ddim<step>_<skip_type><kappa>'`` when ``config.eval_ddim_step > 0``
      * ``'H'`` followed by the first 5 characters of ``git rev-parse HEAD``

    If the git command prints nothing, the tag simply ends in ``'H'``.

    Args:
        config: experiment config exposing the attributes listed above.

    Returns:
        The assembled tag string.
    """
    # generate tag for the generated samples
    tag = ''
    if config.ddpm.model_var_type != 'fixedlarge':
        tag += config.ddpm.model_var_type
    if not config.ddpm.ema:
        tag += 'noema'
    tag += f"s{config.trainer.seed}"
    if config.data.te_max_sample_points != 2048:
        tag += 'N%d' % config.data.te_max_sample_points
    if config.eval_ddim_step > 0:
        tag += 'ddim%d_%s%.1f' % (
            config.eval_ddim_step,
            config.sde.ddim_skip_type,
            config.sde.ddim_kappa)
    # Close the pipe deterministically instead of leaving it to the GC.
    with os.popen('git rev-parse HEAD') as pipe:
        githash = pipe.read().strip()[:5]
    logger.info('git hash: {}', githash)
    tag += f"H{githash}"
    return tag
def get_expname(config):
    """Return the experiment name, generating one if none was configured.

    When ``config.exp_name`` is anything other than ``''`` or ``'none'`` it
    is returned unchanged. Otherwise the name is generated as::

        <MMDD>/[ns]<cate>/[<hash>_]<trainer>_[<cmt>_]B<batch>[N<points>]

    where
      * ``MMDD`` is today's date from ``time.strftime('%m%d')``
      * ``ns`` is added when ``config.data.type`` is
        ``'datasets.neuralspline_datasets'``
      * ``cate`` is ``config.data.cates`` itself if it is a string, else
        its first element
      * ``trainer`` is the last dotted component of ``config.trainer.type``
        up to its first underscore
      * ``hash`` and ``cmt`` are included only when non-empty
      * ``N<points>`` is added when ``config.data.tr_max_sample_points``
        is not 2048

    Args:
        config: experiment config exposing the attributes listed above.

    Returns:
        The experiment name string.
    """
    if config.exp_name not in ('', 'none'):
        return config.exp_name

    cates = config.data.cates
    cate = cates if isinstance(cates, str) else cates[0]

    parts = []
    if config.data.type == 'datasets.neuralspline_datasets':
        parts.append('ns')
    parts.append('%s/' % cate)
    # Truthiness rather than len() so a None hash/cmt is treated as empty.
    if config.hash:
        parts.append('%s_' % config.hash)
    trainer_name = config.trainer.type.split('.')[-1].split('_')[0]
    parts.append(f"{trainer_name}_")
    if config.cmt:
        parts.append(config.cmt + '_')
    parts.append('B%d' % config.data.batch_size)
    if config.data.tr_max_sample_points != 2048:
        parts.append('N%d' % config.data.tr_max_sample_points)

    run_time = time.strftime('%m%d')
    return run_time + '/' + ''.join(parts)