123 lines
3.8 KiB
Python
123 lines
3.8 KiB
Python
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
#
|
|
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
|
# and proprietary rights in and to this software, related documentation
|
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
# distribution of this software and related documentation without an express
|
|
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
|
import time
|
|
import os
|
|
import numpy as np
|
|
from loguru import logger
|
|
from math import isnan
|
|
from calmsize import size as calmsize
|
|
|
|
|
|
def parse_cfg_str(cfg_str):
    """Flatten a ``k1=v1-k2=v2`` string into ``[k1, v1, k2, v2]``.

    Entries are separated by ``-``; empty entries (produced by leading,
    trailing or doubled dashes, or by an empty input) are skipped. Keys and
    values are returned as strings, alternating key, value, key, value, ...

    Raises:
        ValueError: if an entry does not contain exactly one ``=``.
    """
    flat = []
    for entry in cfg_str.split('-'):
        if not entry:
            continue
        # Tuple unpacking enforces the "exactly one '='" rule per entry.
        key, value = entry.split('=')
        flat.extend((key, value))
    return flat
|
|
|
|
##cfg_dict = {}
|
|
# if cfg_str == '':
|
|
# return cfg_dict
|
|
##cfg_str_list = cfg_str.split(',')
|
|
# for p in cfg_str_list:
|
|
## kvs = p.split('=')
|
|
## assert(len(kvs) == 2), f'wrong format, expect k1=v1 for {p}'
|
|
## k, v = kvs
|
|
## cfg_dict[k] = v
|
|
# return cfg_dict
|
|
|
|
|
|
def readable_size(num_bytes: int) -> str:
    """Format a byte count for display, or return '' when it is NaN.

    Non-NaN values are passed through ``calmsize`` and rendered with the
    ``'{:.1f}'`` format spec.
    """
    if isnan(num_bytes):
        return ''
    return '{:.1f}'.format(calmsize(num_bytes))
|
|
|
|
|
|
class ExpTimer(object):
    """Track per-epoch wall-clock time and estimate the time remaining.

    Call ``tic()`` when an epoch starts and ``toc()`` when it ends; each
    ``toc()`` records the elapsed seconds and advances the epoch counter.

    Args:
        num_epoch: total number of epochs planned for the run.
        start_epoch: epoch index the run starts from.
    """

    def __init__(self, num_epoch, start_epoch=0):
        self.cur_epoch = start_epoch
        self.num_epoch = num_epoch
        self.time_list = []  # seconds spent in each completed epoch

    def tic(self):
        """Mark the start of an epoch."""
        self.last_tic = time.time()

    def toc(self):
        """Record the time since the last ``tic()`` and advance one epoch.

        ``tic()`` must have been called first; otherwise ``self.last_tic``
        does not exist and an ``AttributeError`` is raised.
        """
        self.time_list.append(time.time() - self.last_tic)
        self.cur_epoch += 1

    def hours_left(self):
        """Return the estimated hours remaining.

        The estimate is the mean recorded epoch time multiplied by the
        number of epochs left. Returns 0 when no epoch has been timed yet,
        and never goes below zero once ``cur_epoch`` passes ``num_epoch``.
        """
        if not self.time_list:
            return 0
        # Clamp so an over-run does not report a negative estimate.
        num_epoch_left = max(self.num_epoch - self.cur_epoch, 0)
        mean_epoch_time = np.array(self.time_list).mean()
        return (mean_epoch_time * num_epoch_left) / 3600.0  # seconds -> hours

    def print(self):
        """Log the estimated remaining time in hours with one decimal."""
        # hours_left must be *called*; passing the bound method itself makes
        # the format call fail. '.1f' gives one decimal place, whereas the
        # bare '.1' means one significant digit (12.3 -> '1e+01').
        logger.info('est: {:.1f}h', self.hours_left())
|
|
|
|
|
|
def format_e(n):
    """Format a number as a compact scientific-notation tag.

    Trailing zeros of the mantissa and leading zeros of the exponent are
    dropped. For negative exponents the sign is omitted, so small values
    such as learning rates stay short: ``1e-4 -> '1E4'``,
    ``2.5e-3 -> '2.5E3'``, ``1e-10 -> '1E10'``. Non-negative exponents keep
    an explicit ``+`` so they cannot collide with the negative form:
    ``1e4 -> '1E+4'``, ``1 -> '1E+0'``.

    Zero is returned as ``'0'``; inf/nan are returned as produced by
    ``'%E'`` (``'INF'``, ``'NAN'``).
    """
    if n == 0:
        return '0'
    sci = '%E' % n
    mantissa, sep, exponent = sci.partition('E')
    if not sep:
        # 'INF' / 'NAN' have no exponent part to split off.
        return sci
    mantissa = mantissa.rstrip('0').rstrip('.')
    # '%E' always emits a signed exponent with at least two digits.
    sign = exponent[0]
    digits = exponent[1:].lstrip('0') or '0'
    if sign == '-':
        return mantissa + 'E' + digits
    return mantissa + 'E+' + digits
|
|
|
|
|
|
def get_evalname(config):
    """Build the tag string used to label a set of generated samples.

    The tag is assembled, in order, from:
      * ``config.ddpm.model_var_type`` unless it is ``'fixedlarge'``;
      * ``'noema'`` when ``config.ddpm.ema`` is falsy;
      * ``'s<seed>'`` from ``config.trainer.seed``;
      * ``'N<points>'`` when ``config.data.te_max_sample_points`` is not 2048;
      * ``'ddim<step>_<skip_type><kappa>'`` when ``config.eval_ddim_step > 0``;
      * ``'H<hash>'``, the first 5 characters of ``git rev-parse HEAD``
        (also logged).
    """
    name = ''

    var_type = config.ddpm.model_var_type
    if var_type != 'fixedlarge':
        name += var_type
    if not config.ddpm.ema:
        name += 'noema'

    name += 's{}'.format(config.trainer.seed)

    num_points = config.data.te_max_sample_points
    if num_points != 2048:
        name += 'N%d' % num_points

    ddim_step = config.eval_ddim_step
    if ddim_step > 0:
        sde_cfg = config.sde
        name += 'ddim%d_%s%.1f' % (
            ddim_step, sde_cfg.ddim_skip_type, sde_cfg.ddim_kappa)

    # Short commit hash of the current checkout; empty if git prints nothing.
    head_commit = os.popen('git rev-parse HEAD').read()
    githash = head_commit.strip()[:5]
    logger.info('git hash: {}', githash)
    return name + 'H' + githash
|
|
|
|
|
|
def get_expname(config):
    """Return the experiment name, generating one when none is configured.

    If ``config.exp_name`` is anything other than ``''`` or ``'none'`` it is
    returned unchanged. Otherwise the name is built as::

        <MMDD>/[ns]<cate>/[<hash>_]<trainer>_[<cmt>_]B<batch>[N<points>]

    where ``<cate>`` is ``config.data.cates`` itself when it is a ``str``,
    else its first element; ``ns`` is added for the neuralspline dataset
    type; ``<trainer>`` is the part of the last dotted component of
    ``config.trainer.type`` before its first underscore; and ``N<points>``
    appears only when ``config.data.tr_max_sample_points`` is not 2048.
    """
    # An explicitly configured name wins over the generated one.
    if config.exp_name not in ('', 'none'):
        return config.exp_name

    data_cfg = config.data
    cates = data_cfg.cates
    category = cates if type(cates) is str else cates[0]

    prefix = 'ns' if data_cfg.type == 'datasets.neuralspline_datasets' else ''
    name = prefix + '%s/' % category

    if len(config.hash):
        name += '%s_' % config.hash

    trainer_kind = config.trainer.type.split('.')[-1].split('_')[0]
    name += f"{trainer_kind}_"

    if len(config.cmt):
        name += config.cmt + '_'

    name += 'B%d' % data_cfg.batch_size

    num_points = data_cfg.tr_max_sample_points
    if num_points != 2048:
        name += 'N%d' % num_points

    # Prefix with the current month/day so runs are grouped by date.
    return time.strftime('%m%d') + '/' + name
|