# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property # and proprietary rights in and to this software, related documentation # and any modifications thereto. Any use, reproduction, disclosure or # distribution of this software and related documentation without an express # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. """copied and modified from https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/util/utils.py""" from loguru import logger from comet_ml import Experiment, ExistingExperiment import wandb as WB import os import math import shutil import json import time import sys import types from PIL import Image import torch import torch.nn as nn import numpy as np from torch import optim import torch.distributed as dist from torch.cuda.amp import autocast, GradScaler USE_COMET = int(os.environ.get('USE_COMET', 1)) USE_TFB = int(os.environ.get('USE_TFB', 0)) USE_WB = int(os.environ.get('USE_WB', 0)) print(f'utils/utils.py: USE_COMET={USE_COMET}, USE_WB={USE_WB}') class PixelNormal(object): def __init__(self, param, fixed_log_scales=None): size = param.size() C = size[1] if fixed_log_scales is None: self.num_c = C // 2 # B, 1 or 3, H, W self.means = param[:, :self.num_c, :, :] self.log_scales = torch.clamp( param[:, self.num_c:, :, :], min=-7.0) # B, 1 or 3, H, W raise NotImplementedError else: self.num_c = C # B, 1 or 3, H, W self.means = param # B, 1 or 3, H, W self.log_scales = view4D(fixed_log_scales, size) def get_params(self): return self.means, self.log_scales, self.num_c def log_prob(self, samples): B, C, H, W = samples.size() assert C == self.num_c log_probs = -0.5 * torch.square(self.means - samples) * torch.exp(-2.0 * self.log_scales) - self.log_scales - 0.9189385332 # -0.5*log(2*pi) return log_probs def sample(self, t=1.): z, rho = sample_normal_jit( self.means, torch.exp(self.log_scales)*t) # B, 3, H, W return z def log_prob_discrete(self, samples): """ Calculates discrete pixel probabilities. """ # samples should be in [-1, 1] already B, C, H, W = samples.size() assert C == self.num_c centered = samples - self.means inv_stdv = torch.exp(- self.log_scales) plus_in = inv_stdv * (centered + 1. / 255.) cdf_plus = torch.distributions.Normal(0, 1).cdf(plus_in) min_in = inv_stdv * (centered - 1. / 255.) cdf_min = torch.distributions.Normal(0, 1).cdf(min_in) log_cdf_plus = torch.log(torch.clamp(cdf_plus, min=1e-12)) log_one_minus_cdf_min = torch.log(torch.clamp(1. 
- cdf_min, min=1e-12)) cdf_delta = cdf_plus - cdf_min log_probs = torch.where(samples < -0.999, log_cdf_plus, torch.where(samples > 0.999, log_one_minus_cdf_min, torch.log(torch.clamp(cdf_delta, min=1e-12)))) assert log_probs.size() == samples.size() return log_probs def mean(self): return self.means class DummyGradScalar(object): def __init__(self, *args, **kwargs): pass def scale(self, input): return input def update(self): pass def state_dict(self): return {} def load_state_dict(self, x): pass def step(self, opt): opt.step() def unscale_(self, x): return x def get_opt(params, cfgopt, use_ema, other_cfg=None): if cfgopt.type == 'adam': optimizer = optim.Adam(params, lr=float(cfgopt.lr), betas=(cfgopt.beta1, cfgopt.beta2), weight_decay=cfgopt.weight_decay) elif cfgopt.type == 'sgd': optimizer = torch.optim.SGD(params, lr=float(cfgopt.lr), momentum=cfgopt.momentum) elif cfgopt.type == 'adamax': from utils.adamax import Adamax logger.info('[Optimizer] Adamax, lr={}, weight_decay={}, eps={}', cfgopt.lr, cfgopt.weight_decay, 1e-4) optimizer = Adamax(params, float(cfgopt.lr), weight_decay=args.weight_decay, eps=1e-4) else: assert 0, "Optimizer type should be either 'adam' or 'sgd'" if use_ema: logger.info('use_ema') ema_decay = 0.9999 from .ema import EMA optimizer = EMA(optimizer, ema_decay=ema_decay) scheduler = optim.lr_scheduler.LambdaLR( optimizer, lr_lambda=lambda x: 1.0) # constant lr scheduler_type = getattr(cfgopt, "scheduler", None) if scheduler_type is not None and len(scheduler_type) > 0: logger.info('get scheduler_type: {}', scheduler_type) if scheduler_type == 'exponential': decay = float(getattr(cfgopt, "step_decay", 0.1)) scheduler = optim.lr_scheduler.ExponentialLR(optimizer, decay) elif scheduler_type == 'step': step_size = int(getattr(cfgopt, "step_epoch", 500)) decay = float(getattr(cfgopt, "step_decay", 0.1)) scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=decay) elif scheduler_type == 'linear': # use default setting from shapeLatent start_epoch = int(getattr(cfgopt, 'sched_start_epoch', 200*1e3)) end_epoch = int(getattr(cfgopt, 'sched_end_epoch', 400*1e3)) end_lr = float(getattr(cfgopt, 'end_lr', 1e-4)) start_lr = cfgopt.lr def lambda_rule(epoch): if epoch <= start_epoch: return 1.0 elif epoch <= end_epoch: total = end_epoch - start_epoch delta = epoch - start_epoch frac = delta / total return (1 - frac) * 1.0 + frac * (end_lr / start_lr) else: return end_lr / start_lr scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) elif scheduler_type == 'lambda': # linear': step_size = int(getattr(cfgopt, "step_epoch", 2000)) final_ratio = float(getattr(cfgopt, "final_ratio", 0.01)) start_ratio = float(getattr(cfgopt, "start_ratio", 0.5)) duration_ratio = float(getattr(cfgopt, "duration_ratio", 0.45)) def lambda_rule(ep): lr_l = 1.0 - min( 1, max(0, ep - start_ratio * step_size) / float(duration_ratio * step_size)) * (1 - final_ratio) return lr_l scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) elif scheduler_type == 'cosine_anneal_nocycle': ## logger.info('scheduler_type: {}', scheduler_type) assert(other_cfg is not None) final_lr_ratio = float(getattr(cfgopt, "final_lr_ratio", 0.01)) eta_min = float(cfgopt.lr) * final_lr_ratio eta_max = float(cfgopt.lr) total_epoch = int(other_cfg.trainer.epochs) ##getattr(cfgopt, "step_epoch", 2000) start_ratio = float(getattr(cfgopt, "start_ratio", 0.6)) T_max = total_epoch * (1 - start_ratio) def lambda_rule(ep): curr_ep = max(0., ep - start_ratio * total_epoch) lr = 
eta_min + 0.5 * (eta_max - eta_min) * ( 1 + np.cos(np.pi * curr_ep / T_max)) lr_l = lr / eta_max return lr_l scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) else: assert 0, "args.schedulers should be either 'exponential' or 'linear' or 'step'" return optimizer, scheduler class AvgrageMeter(object): def __init__(self): self.reset() def reset(self): self.avg = 0 self.sum = 0 self.cnt = 0 def update(self, val, n=1): self.sum += val * n self.cnt += n self.avg = self.sum / self.cnt class ExpMovingAvgrageMeter(object): def __init__(self, momentum=0.9): self.momentum = momentum self.reset() def reset(self): self.avg = 0 def update(self, val): self.avg = (1. - self.momentum) * self.avg + self.momentum * val class DummyDDP(nn.Module): def __init__(self, model): super(DummyDDP, self).__init__() self.module = model def forward(self, *input, **kwargs): return self.module(*input, **kwargs) def count_parameters_in_M(model): return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6 def save_checkpoint(state, is_best, save): filename = os.path.join(save, 'checkpoint.pth.tar') torch.save(state, filename) if is_best: best_filename = os.path.join(save, 'model_best.pth.tar') shutil.copyfile(filename, best_filename) def save(model, model_path): torch.save(model.state_dict(), model_path) def load(model, model_path): model.load_state_dict(torch.load(model_path)) # def create_exp_dir(path, scripts_to_save=None): # if not os.path.exists(path): # os.makedirs(path, exist_ok=True) # print('Experiment dir : {}'.format(path)) # # if scripts_to_save is not None: # if not os.path.exists(os.path.join(path, 'scripts')): # os.mkdir(os.path.join(path, 'scripts')) # for script in scripts_to_save: # dst_file = os.path.join(path, 'scripts', os.path.basename(script)) # shutil.copyfile(script, dst_file) # # class Logger(object): # def __init__(self, rank, save): # # other libraries may set logging before arriving at this line. # # by reloading logging, we can get rid of previous configs set by other libraries. 
# from importlib import reload # reload(logging) # self.rank = rank # if self.rank == 0: # log_format = '%(asctime)s %(message)s' # logging.basicConfig(stream=sys.stdout, level=logging.INFO, # format=log_format, datefmt='%m/%d %I:%M:%S %p') # fh = logging.FileHandler(os.path.join(save, 'log.txt')) # fh.setFormatter(logging.Formatter(log_format)) # logging.getLogger().addHandler(fh) # self.start_time = time.time() # # def info(self, string, *args): # if self.rank == 0: # elapsed_time = time.time() - self.start_time # elapsed_time = time.strftime( # '(Elapsed: %H:%M:%S) ', time.gmtime(elapsed_time)) # if isinstance(string, str): # string = elapsed_time + string # else: # logging.info(elapsed_time) # logging.info(string, *args) def flatten_dict(dd, separator='_', prefix=''): return {prefix + separator + k if prefix else k: v for kk, vv in dd.items() for k, v in flatten_dict(vv, separator, kk).items()} \ if isinstance(dd, dict) else {prefix: dd} class Writer(object): def __init__(self, rank=0, save=None, exp=None, wandb=False): self.rank = rank self.exp = None self.wandb = False self.meter_dict = {} if self.rank == 0: self.exp = exp if USE_TFB and save is not None: logger.info('init TFB: {}', save) from torch.utils.tensorboard import SummaryWriter self.writer = SummaryWriter(log_dir=save, flush_secs=20) else: logger.info('Not init TFB') self.writer = None if self.exp is not None and save is not None: with open(os.path.join(save, 'url.txt'), 'a') as f: f.write(self.exp.url) f.write('\n') self.wandb = wandb else: logger.info('rank={}, init writer as a blackhole', rank) def set_model_graph(self, *args, **kwargs): if self.rank == 0 and self.exp is not None: self.exp.set_model_graph(*args, **kwargs) @property def url(self): if self.exp is not None: return self.exp.url else: return 'none' def add_hparams(self, cfg, args): # **kwargs): if self.exp is not None: self.exp.log_parameters(flatten_dict(cfg)) self.exp.log_parameters(flatten_dict(args)) if self.wandb: WB.config.update(flatten_dict(cfg)) WB.config.update(flatten_dict(args)) def avg_meter(self, name, value, step=None, epoch=None): if self.rank == 0: if name not in self.meter_dict: self.meter_dict[name] = AvgrageMeter() self.meter_dict[name].update(value) def upload_meter(self, step=None, epoch=None): for name, value in self.meter_dict.items(): self.add_scalar(name, value.avg, step=step, epoch=epoch) self.meter_dict = {} def add_scalar(self, *args, **kwargs): if self.rank == 0 and self.writer is not None: if 'step' in kwargs: self.writer.add_scalar(*args, global_step=kwargs['step']) else: self.writer.add_scalar(*args, **kwargs) if self.exp is not None: self.exp.log_metric(*args, **kwargs) if self.wandb: name = args[0] v = args[1] WB.log({name: v}) def log_model(self, name, path): pass def log_other(self, name, value): if self.rank == 0 and self.exp is not None: self.exp.log_other(name, value) # if self.rank == 0 and self.exp is not None: # self.exp.log_model(name, path) def watch(self, model): if self.wandb: WB.watch(model) def log_points_3d(self, scene_name, points, step=0): # *args, **kwargs): if self.rank == 0 and self.exp is not None: self.exp.log_points_3d(*args, **kwargs) if self.wandb: WB.log({"point_cloud": WB.Object3D(points)}) def add_figure(self, *args, **kwargs): if self.rank == 0 and self.writer is not None: self.writer.add_figure(*args, **kwargs) def add_image(self, *args, **kwargs): if self.rank == 0 and self.writer is not None: self.writer.add_image(*args, **kwargs) self.writer.flush() if self.exp is not None: name, img, i = 
args if isinstance(img, Image.Image): # logger.debug('log PIL Imgae: {}, {}', name, i) self.exp.log_image(img, name, step=i) elif type(img) is str: # logger.debug('log str image: {}, {}: {}', name, i, img) self.exp.log_image(img, name, step=i) elif torch.is_tensor(img): if img.shape[0] in [3, 4] and len(img.shape) == 3: # 3,H,W img = img.permute(1, 2, 0).contiguous() # 3,H,W -> H,W,3 if img.max() < 100: # [0-1] ndarr = img.mul(255).add_(0.5).clamp_( 0, 255).to('cpu') # .squeeze() ndarr = ndarr.numpy().astype(np.uint8) # .reshape(-1, ndarr.shape[-1])) im = Image.fromarray(ndarr) self.exp.log_image(im, name, step=i) else: im = img.to('cpu').numpy() self.exp.log_image(im, name, step=i) elif isinstance(img, (np.ndarray, np.generic)): if img.shape[0] == 3 and len(img.shape) == 3: # 3,H,W img = img.transpose(1, 2, 0) self.exp.log_image(img, name, step=i) if self.wandb and torch.is_tensor(img) and self.rank == 0: ## print(img.shape, img.max(), img.type()) WB.log({name: WB.Image(img.numpy())}) def add_histogram(self, *args, **kwargs): if self.rank == 0 and self.writer is not None: self.writer.add_histogram(*args, **kwargs) if self.exp is not None: name, value, step = args self.exp.log_histogram_3d(value, name, step) # *args, **kwargs) def add_histogram_if(self, write, *args, **kwargs): if write and False: # Used for debugging. self.add_histogram(*args, **kwargs) def close(self, *args, **kwargs): if self.rank == 0 and self.writer is not None: self.writer.close() def log_asset(self, *args, **kwargs): if self.exp is not None: self.exp.log_asset(*args, **kwargs) def common_init(rank, seed, save_dir, comet_key=''): # we use different seeds per gpu. But we sync the weights after model initialization. logger.info('[common-init] at rank={}, seed={}', rank, seed) torch.manual_seed(rank + seed) np.random.seed(rank + seed) torch.cuda.manual_seed(rank + seed) torch.cuda.manual_seed_all(rank + seed) torch.backends.cudnn.benchmark = True # prepare logging and tensorboard summary #logging = Logger(rank, save_dir) logging = None if rank == 0: if os.path.exists('.comet_api'): comet_args = json.load(open('.comet_api', 'r')) exp = Experiment(display_summary_level=0, disabled=USE_COMET == 0, **comet_args) exp.set_name(save_dir.split('exp/')[-1]) exp.set_cmd_args() exp.log_code(folder='./models/') exp.log_code(folder='./trainers/') exp.log_code(folder='./utils/') exp.log_code(folder='./datasets/') else: exp = None if os.path.exists('.wandb_api'): wb_args = json.load(open('.wandb_api', 'r')) wb_dir = '../exp/wandb/' if not os.path.exists( '/workspace/result') else '/workspace/result/wandb/' if not os.path.exists(wb_dir): os.makedirs(wb_dir) WB.init( project=wb_args['project'], entity=wb_args['entity'], name=save_dir.split('exp/')[-1], dir=wb_dir ) wandb = True else: wandb = False else: exp = None wandb = False writer = Writer(rank, save_dir, exp, wandb) logger.info('[common-init] DONE') return logging, writer def reduce_tensor(tensor, world_size): rt = tensor.clone() dist.all_reduce(rt, op=dist.ReduceOp.SUM) rt /= world_size return rt def get_stride_for_cell_type(cell_type): if cell_type.startswith('normal') or cell_type.startswith('combiner'): stride = 1 elif cell_type.startswith('down'): stride = 2 elif cell_type.startswith('up'): stride = -1 else: raise NotImplementedError(cell_type) return stride def get_cout(cin, stride): if stride == 1: cout = cin elif stride == -1: cout = cin // 2 elif stride == 2: cout = 2 * cin return cout def kl_balancer_coeff(num_scales, groups_per_scale, fun='square'): if fun == 'equal': 
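        # The `fun` modes below build one KL weight per latent group: 'equal' gives every
        # group weight 1, 'linear' weights scale index i by 2**i, 'sqrt' by sqrt(2**i),
        # and 'square' by (2**i)**2 divided by that scale's group count (with i running
        # over the scales in reverse order). The result is then normalized below so that
        # the smallest coefficient equals 1.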
coeff = torch.cat([torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda() elif fun == 'linear': coeff = torch.cat([(2 ** i) * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda() elif fun == 'sqrt': coeff = torch.cat( [np.sqrt(2 ** i) * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda() elif fun == 'square': coeff = torch.cat( [np.square(2 ** i) / groups_per_scale[num_scales - i - 1] * torch.ones(groups_per_scale[num_scales - i - 1]) for i in range(num_scales)], dim=0).cuda() else: raise NotImplementedError # convert min to 1. coeff /= torch.min(coeff) return coeff def kl_per_group(kl_all): kl_vals = torch.mean(kl_all, dim=0) kl_coeff_i = torch.abs(kl_all) kl_coeff_i = torch.mean(kl_coeff_i, dim=0, keepdim=True) + 0.01 return kl_coeff_i, kl_vals def rec_balancer(rec_all, rec_coeff=1.0, npoints=None): # layer depth increase, alpha_i increase, 1/alpha_i decrease; kl_coeff decrease # the rec with more points should have higher loss min_points = min(npoints) coeff = [] rec_loss = 0 assert(len(rec_all) == len(npoints)) for ni, n in enumerate(npoints): c = rec_coeff*np.sqrt(n/min_points) rec_loss += rec_all[ni] * c coeff.append(c) # the smallest points' loss weight is 1 return rec_loss, coeff, rec_all def kl_balancer(kl_all, kl_coeff=1.0, kl_balance=False, alpha_i=None): # layer depth increase, alpha_i increase, 1/alpha_i decrease; kl_coeff decrease if kl_balance and kl_coeff < 1.0: alpha_i = alpha_i.unsqueeze(0) kl_all = torch.stack(kl_all, dim=1) kl_coeff_i, kl_vals = kl_per_group(kl_all) total_kl = torch.sum(kl_coeff_i) # kl = ( sum * kl / alpha ) kl_coeff_i = kl_coeff_i / alpha_i * total_kl kl_coeff_i = kl_coeff_i / torch.mean(kl_coeff_i, dim=1, keepdim=True) kl = torch.sum(kl_all * kl_coeff_i.detach(), dim=1) # for reporting kl_coeffs = kl_coeff_i.squeeze(0) else: kl_all = torch.stack(kl_all, dim=1) kl_vals = torch.mean(kl_all, dim=0) kl = torch.sum(kl_all, dim=1) kl_coeffs = torch.ones(size=(len(kl_vals),)) return kl_coeff * kl, kl_coeffs, kl_vals def kl_per_group_vada(all_log_q, all_neg_log_p): assert(len(all_log_q) == len(all_neg_log_p) ), f'get len={len(all_log_q)} and {len(all_neg_log_p)}' kl_all_list = [] kl_diag = [] for log_q, neg_log_p in zip(all_log_q, all_neg_log_p): kl_diag.append(torch.mean( torch.sum(neg_log_p + log_q, dim=[2, 3]), dim=0)) kl_all_list.append(torch.sum(neg_log_p + log_q, dim=[1, 2, 3])) # sum over D,H,W # kl_all = torch.stack(kl_all, dim=1) # batch x num_total_groups kl_vals = torch.mean(torch.stack(kl_all_list, dim=1), dim=0) # mean per group return kl_all_list, kl_vals, kl_diag def kl_coeff(step, total_step, constant_step, min_kl_coeff, max_kl_coeff): # return max(min(max_kl_coeff * (step - constant_step) / total_step, max_kl_coeff), min_kl_coeff) return max(min(min_kl_coeff + (max_kl_coeff - min_kl_coeff) * (step - constant_step) / total_step, max_kl_coeff), min_kl_coeff) def log_iw(decoder, x, log_q, log_p, crop=False): recon = reconstruction_loss(decoder, x, crop) return - recon - log_q + log_p def reconstruction_loss(decoder, x, crop=False): recon = decoder.log_p(x) if crop: recon = recon[:, :, 2:30, 2:30] if isinstance(decoder, DiscMixLogistic): return - torch.sum(recon, dim=[1, 2]) # summation over RGB is done. else: return - torch.sum(recon, dim=[1, 2, 3]) def vae_terms(all_log_q, all_eps): # compute kl kl_all = [] kl_diag = [] log_p, log_q = 0., 0. 
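    # The loop below forms a Monte-Carlo estimate of KL(q || p) per latent group,
    # with p the standard normal: `kl_per_var` holds the per-dimension terms
    # log q(eps) - log p(eps), `kl_diag` keeps a per-channel diagnostic (summed over
    # the spatial dims, averaged over the batch), and `kl_all` the per-sample KL
    # summed over all dimensions of each group.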
for log_q_conv, eps in zip(all_log_q, all_eps): log_p_conv = log_p_standard_normal(eps) kl_per_var = log_q_conv - log_p_conv kl_diag.append(torch.mean(torch.sum(kl_per_var, dim=[2, 3]), dim=0)) kl_all.append(torch.sum(kl_per_var, dim=[1, 2, 3])) log_q += torch.sum(log_q_conv, dim=[1, 2, 3]) log_p += torch.sum(log_p_conv, dim=[1, 2, 3]) return log_q, log_p, kl_all, kl_diag def sum_log_q(all_log_q): log_q = 0. for log_q_conv in all_log_q: log_q += torch.sum(log_q_conv, dim=[1, 2, 3]) return log_q def cross_entropy_normal(all_eps): cross_entropy = 0. neg_log_p_per_group = [] for eps in all_eps: neg_log_p_conv = - log_p_standard_normal(eps) neg_log_p = torch.sum(neg_log_p_conv, dim=[1, 2, 3]) cross_entropy += neg_log_p neg_log_p_per_group.append(neg_log_p_conv) return cross_entropy, neg_log_p_per_group def tile_image(batch_image, n, m=None): if m is None: m = n assert n * m == batch_image.size(0) channels, height, width = batch_image.size( 1), batch_image.size(2), batch_image.size(3) batch_image = batch_image.view(n, m, channels, height, width) batch_image = batch_image.permute(2, 0, 3, 1, 4) # n, height, n, width, c batch_image = batch_image.contiguous().view(channels, n * height, m * width) return batch_image def average_gradients_naive(params, is_distributed): """ Gradient averaging. """ if is_distributed: size = float(dist.get_world_size()) for param in params: if param.requires_grad: param.grad.data /= size dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) def average_gradients(params, is_distributed): """ Gradient averaging. """ if is_distributed: if isinstance(params, types.GeneratorType): params = [p for p in params] size = float(dist.get_world_size()) grad_data = [] grad_size = [] grad_shapes = [] # Gather all grad values for param in params: if param.requires_grad: if param.grad is not None: grad_size.append(param.grad.data.numel()) grad_shapes.append(list(param.grad.data.shape)) grad_data.append(param.grad.data.flatten()) grad_data = torch.cat(grad_data).contiguous() # All-reduce grad values grad_data /= size dist.all_reduce(grad_data, op=dist.ReduceOp.SUM) # Put back the reduce grad values to parameters base = 0 i = 0 for param in params: if param.requires_grad and param.grad is not None: param.grad.data = grad_data[base:base + grad_size[i]].view(grad_shapes[i]) base += grad_size[i] i += 1 def average_params(params, is_distributed): """ parameter averaging. 
""" if is_distributed: size = float(dist.get_world_size()) for param in params: param.data /= size dist.all_reduce(param.data, op=dist.ReduceOp.SUM) def average_tensor(t, is_distributed): if is_distributed: size = float(dist.get_world_size()) dist.all_reduce(t.data, op=dist.ReduceOp.SUM) t.data /= size def broadcast_params(params, is_distributed): if is_distributed: for param in params: dist.broadcast(param.data, src=0) def num_output(dataset): if dataset in {'mnist', 'omniglot'}: return 28 * 28 elif dataset == 'cifar10': return 3 * 32 * 32 elif dataset.startswith('celeba') or dataset.startswith('imagenet') or dataset.startswith('lsun'): size = int(dataset.split('_')[-1]) return 3 * size * size elif dataset == 'ffhq': return 3 * 256 * 256 else: raise NotImplementedError def get_input_size(dataset): if dataset in {'mnist', 'omniglot'}: return 32 elif dataset == 'cifar10': return 32 elif dataset.startswith('celeba') or dataset.startswith('imagenet') or dataset.startswith('lsun'): size = int(dataset.split('_')[-1]) return size elif dataset == 'ffhq': return 256 elif dataset.startswith('shape'): return 1 # 2048 else: raise NotImplementedError def get_bpd_coeff(dataset): n = num_output(dataset) return 1. / np.log(2.) / n def get_channel_multiplier(dataset, num_scales): if dataset in {'cifar10', 'omniglot'}: mult = (1, 1, 1) elif dataset in {'celeba_256', 'ffhq', 'lsun_church_256'}: if num_scales == 3: mult = (1, 1, 1) # used for prior at 16 elif num_scales == 4: mult = (1, 2, 2, 2) # used for prior at 32 elif num_scales == 5: mult = (1, 1, 2, 2, 2) # used for prior at 64 elif dataset == 'mnist': mult = (1, 1) else: mult = (1, 1) # raise NotImplementedError return mult def get_attention_scales(dataset): if dataset in {'cifar10', 'omniglot'}: attn = (True, False, False) elif dataset in {'celeba_256', 'ffhq', 'lsun_church_256'}: # attn = (False, True, False, False) # used for 32 attn = (False, False, True, False, False) # used for 64 elif dataset == 'mnist': attn = (True, False) else: raise NotImplementedError return attn def change_bit_length(x, num_bits): if num_bits != 8: x = torch.floor(x * 255 / 2 ** (8 - num_bits)) x /= (2 ** num_bits - 1) return x def view4D(t, size, inplace=True): """ Equal to view(-1, 1, 1, 1).expand(size) Designed because of this bug: https://github.com/pytorch/pytorch/pull/48696 """ if inplace: return t.unsqueeze_(-1).unsqueeze_(-1).unsqueeze_(-1).expand(size) else: return t.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(size) def get_arch_cells(arch_type, use_se): if arch_type == 'res_mbconv': arch_cells = dict() arch_cells['normal_enc'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['down_enc'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['normal_dec'] = { 'conv_branch': ['mconv_e6k5g0'], 'se': use_se} arch_cells['up_dec'] = {'conv_branch': ['mconv_e6k5g0'], 'se': use_se} arch_cells['normal_pre'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['down_pre'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['normal_post'] = { 'conv_branch': ['mconv_e3k5g0'], 'se': use_se} arch_cells['up_post'] = {'conv_branch': ['mconv_e3k5g0'], 'se': use_se} arch_cells['ar_nn'] = [''] elif arch_type == 'res_bnswish': arch_cells = dict() arch_cells['normal_enc'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['down_enc'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['normal_dec'] = {'conv_branch': [ 'res_bnswish', 
'res_bnswish'], 'se': use_se} arch_cells['up_dec'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['normal_pre'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['down_pre'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['normal_post'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['up_post'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['ar_nn'] = [''] elif arch_type == 'res_bnswish2': arch_cells = dict() arch_cells['normal_enc'] = { 'conv_branch': ['res_bnswish_x2'], 'se': use_se} arch_cells['down_enc'] = { 'conv_branch': ['res_bnswish_x2'], 'se': use_se} arch_cells['normal_dec'] = { 'conv_branch': ['res_bnswish_x2'], 'se': use_se} arch_cells['up_dec'] = {'conv_branch': [ 'res_bnswish_x2'], 'se': use_se} arch_cells['normal_pre'] = { 'conv_branch': ['res_bnswish_x2'], 'se': use_se} arch_cells['down_pre'] = { 'conv_branch': ['res_bnswish_x2'], 'se': use_se} arch_cells['normal_post'] = { 'conv_branch': ['res_bnswish_x2'], 'se': use_se} arch_cells['up_post'] = {'conv_branch': [ 'res_bnswish_x2'], 'se': use_se} arch_cells['ar_nn'] = [''] elif arch_type == 'res_mbconv_attn': arch_cells = dict() arch_cells['normal_enc'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish', ], 'se': use_se, 'attn_type': 'attn'} arch_cells['down_enc'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se, 'attn_type': 'attn'} arch_cells['normal_dec'] = {'conv_branch': [ 'mconv_e6k5g0'], 'se': use_se, 'attn_type': 'attn'} arch_cells['up_dec'] = {'conv_branch': [ 'mconv_e6k5g0'], 'se': use_se, 'attn_type': 'attn'} arch_cells['normal_pre'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['down_pre'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['normal_post'] = { 'conv_branch': ['mconv_e3k5g0'], 'se': use_se} arch_cells['up_post'] = {'conv_branch': ['mconv_e3k5g0'], 'se': use_se} arch_cells['ar_nn'] = [''] elif arch_type == 'res_mbconv_attn_half': arch_cells = dict() arch_cells['normal_enc'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['down_enc'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['normal_dec'] = {'conv_branch': [ 'mconv_e6k5g0'], 'se': use_se, 'attn_type': 'attn'} arch_cells['up_dec'] = {'conv_branch': [ 'mconv_e6k5g0'], 'se': use_se, 'attn_type': 'attn'} arch_cells['normal_pre'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['down_pre'] = {'conv_branch': [ 'res_bnswish', 'res_bnswish'], 'se': use_se} arch_cells['normal_post'] = { 'conv_branch': ['mconv_e3k5g0'], 'se': use_se} arch_cells['up_post'] = {'conv_branch': ['mconv_e3k5g0'], 'se': use_se} arch_cells['ar_nn'] = [''] else: raise NotImplementedError return arch_cells def get_arch_cells_denoising(arch_type, use_se, apply_sqrt2): if arch_type == 'res_mbconv': arch_cells = dict() arch_cells['normal_enc_diff'] = { 'conv_branch': ['mconv_e6k5g0_gn'], 'se': use_se} arch_cells['down_enc_diff'] = { 'conv_branch': ['mconv_e6k5g0_gn'], 'se': use_se} arch_cells['normal_dec_diff'] = { 'conv_branch': ['mconv_e6k5g0_gn'], 'se': use_se} arch_cells['up_dec_diff'] = { 'conv_branch': ['mconv_e6k5g0_gn'], 'se': use_se} elif arch_type == 'res_ho': arch_cells = dict() arch_cells['normal_enc_diff'] = { 'conv_branch': ['res_gnswish_x2'], 'se': use_se} arch_cells['down_enc_diff'] = { 'conv_branch': ['res_gnswish_x2'], 'se': use_se} arch_cells['normal_dec_diff'] = { 'conv_branch': 
['res_gnswish_x2'], 'se': use_se} arch_cells['up_dec_diff'] = { 'conv_branch': ['res_gnswish_x2'], 'se': use_se} elif arch_type == 'res_ho_p1': arch_cells = dict() arch_cells['normal_enc_diff'] = { 'conv_branch': ['res_gnswish_x2_p1'], 'se': use_se} arch_cells['down_enc_diff'] = { 'conv_branch': ['res_gnswish_x2_p1'], 'se': use_se} arch_cells['normal_dec_diff'] = { 'conv_branch': ['res_gnswish_x2_p1'], 'se': use_se} arch_cells['up_dec_diff'] = { 'conv_branch': ['res_gnswish_x2_p1'], 'se': use_se} elif arch_type == 'res_ho_attn': arch_cells = dict() arch_cells['normal_enc_diff'] = { 'conv_branch': ['res_gnswish_x2'], 'se': use_se} arch_cells['down_enc_diff'] = { 'conv_branch': ['res_gnswish_x2'], 'se': use_se} arch_cells['normal_dec_diff'] = { 'conv_branch': ['res_gnswish_x2'], 'se': use_se} arch_cells['up_dec_diff'] = { 'conv_branch': ['res_gnswish_x2'], 'se': use_se} else: raise NotImplementedError for k in arch_cells: arch_cells[k]['apply_sqrt2'] = apply_sqrt2 return arch_cells def groups_per_scale(num_scales, num_groups_per_scale): g = [] n = num_groups_per_scale for s in range(num_scales): assert n >= 1 g.append(n) return g #class PositionalEmbedding(nn.Module): # def __init__(self, embedding_dim, scale): # super(PositionalEmbedding, self).__init__() # self.embedding_dim = embedding_dim # self.scale = scale # # def forward(self, timesteps): # assert len(timesteps.shape) == 1 # timesteps = timesteps * self.scale # half_dim = self.embedding_dim // 2 # emb = math.log(10000) / (half_dim - 1) # emb = torch.exp(torch.arange(half_dim) * -emb) # emb = emb.to(device=timesteps.device) # emb = timesteps[:, None] * emb[None, :] # emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) # return emb # # #class RandomFourierEmbedding(nn.Module): # def __init__(self, embedding_dim, scale): # super(RandomFourierEmbedding, self).__init__() # self.w = nn.Parameter(torch.randn( # size=(1, embedding_dim // 2)) * scale, requires_grad=False) # # def forward(self, timesteps): # emb = torch.mm(timesteps[:, None], self.w * 2 * 3.14159265359) # return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) # # #def init_temb_fun(embedding_type, embedding_scale, embedding_dim): # if embedding_type == 'positional': # temb_fun = PositionalEmbedding(embedding_dim, embedding_scale) # elif embedding_type == 'fourier': # temb_fun = RandomFourierEmbedding(embedding_dim, embedding_scale) # else: # raise NotImplementedError # # return temb_fun def symmetrize_image_data(images): return 2.0 * images - 1.0 def unsymmetrize_image_data(images): return (images + 1.) / 2. def normalize_symmetric(images): """ Normalize images by dividing the largest intensity. Used for visualizing the intermediate steps. """ b = images.shape[0] m, _ = torch.max(torch.abs(images).view(b, -1), dim=1) images /= (m.view(b, 1, 1, 1) + 1e-3) return images @torch.jit.script def soft_clamp5(x: torch.Tensor): # 5. * torch.tanh(x / 5.) <--> soft differentiable clamp between [-5, 5] return x.div(5.).tanh_().mul(5.) 
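# Example (illustrative sketch only; `_example_symmetrize_and_clamp` is a hypothetical
# helper not used elsewhere in this file): round trip of the [0, 1] <-> [-1, 1] image
# mappings above and the soft clamp. Shapes and values are arbitrary.
def _example_symmetrize_and_clamp():
    x = torch.rand(2, 3, 8, 8)                    # fake images in [0, 1]
    y = symmetrize_image_data(x)                  # mapped to [-1, 1]
    assert torch.allclose(unsymmetrize_image_data(y), x, atol=1e-6)
    z = soft_clamp5(torch.tensor([-12.0, -1.0, 0.5, 7.0]))  # smoothly squashed into [-5, 5]
    assert z.abs().max() <= 5.0
    return y, z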
@torch.jit.script def soft_clamp(x: torch.Tensor, a: torch.Tensor): return x.div(a).tanh_().mul(a) class SoftClamp5(nn.Module): def __init__(self): super(SoftClamp5, self).__init__() def forward(self, x): return soft_clamp5(x) def override_architecture_fields(args, stored_args, logging): # list of architecture parameters used in NVAE: architecture_fields = ['arch_instance', 'num_nf', 'num_latent_scales', 'num_groups_per_scale', 'num_latent_per_group', 'num_channels_enc', 'num_preprocess_blocks', 'num_preprocess_cells', 'num_cell_per_cond_enc', 'num_channels_dec', 'num_postprocess_blocks', 'num_postprocess_cells', 'num_cell_per_cond_dec', 'decoder_dist', 'num_x_bits', 'log_sig_q_scale', 'latent_grad_cutoff', 'progressive_output_vae', 'progressive_input_vae', 'channel_mult'] # backward compatibility """ We have broken backward compatibility. No need to se these manually if not hasattr(stored_args, 'log_sig_q_scale'): logging.info('*** Setting %s manually ****', 'log_sig_q_scale') setattr(stored_args, 'log_sig_q_scale', 5.) if not hasattr(stored_args, 'latent_grad_cutoff'): logging.info('*** Setting %s manually ****', 'latent_grad_cutoff') setattr(stored_args, 'latent_grad_cutoff', 0.) if not hasattr(stored_args, 'progressive_input_vae'): logging.info('*** Setting %s manually ****', 'progressive_input_vae') setattr(stored_args, 'progressive_input_vae', 'none') if not hasattr(stored_args, 'progressive_output_vae'): logging.info('*** Setting %s manually ****', 'progressive_output_vae') setattr(stored_args, 'progressive_output_vae', 'none') """ for f in architecture_fields: if not hasattr(args, f) or getattr(args, f) != getattr(stored_args, f): logging.info('Setting %s from loaded checkpoint', f) setattr(args, f, getattr(stored_args, f)) def init_processes(rank, size, fn, args, config): """ Initialize the distributed environment. 
""" os.environ['MASTER_ADDR'] = args.master_address os.environ['MASTER_PORT'] = '6020' logger.info('set MASTER_PORT: {}, MASTER_PORT: {}', os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) # if args.num_proc_node == 1: # try to solve the port occupied issue # import socket # import errno # a_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # for p in range(6010, 6030): # location = (args.master_address, p) # "127.0.0.1", p) # try: # a_socket.bind((args.master_address, p)) # logger.debug('set port as {}', p) # os.environ['MASTER_PORT'] = '%d' % p # a_socket.close() # break # except socket.error as e: # a = 0 # # if e.errno == errno.EADDRINUSE: # # # logger.debug("Port {} is already in use", p) # # else: # # logger.debug(e) logger.info('init_process: rank={}, world_size={}', rank, size) torch.cuda.set_device(args.local_rank) dist.init_process_group( backend='nccl', init_method='env://', rank=rank, world_size=size) fn(args, config) logger.info('barrier: rank={}, world_size={}', rank, size) dist.barrier() logger.info('skip destroy_process_group: rank={}, world_size={}', rank, size) # dist.destroy_process_group() logger.info('skip destroy fini') def sample_rademacher_like(y): return torch.randint(low=0, high=2, size=y.shape, device='cuda') * 2 - 1 def sample_gaussian_like(y): return torch.randn_like(y, device='cuda') def trace_df_dx_hutchinson(f, x, noise, no_autograd): """ Hutchinson's trace estimator for Jacobian df/dx, O(1) call to autograd """ if no_autograd: # the following is compatible with checkpointing torch.sum(f * noise).backward() # torch.autograd.backward(tensors=[f], grad_tensors=[noise]) jvp = x.grad trJ = torch.sum(jvp * noise, dim=[1, 2, 3]) x.grad = None else: jvp = torch.autograd.grad(f, x, noise, create_graph=False)[0] trJ = torch.sum(jvp * noise, dim=[1, 2, 3]) # trJ = torch.einsum('bijk,bijk->b', jvp, noise) # we could test if there's a speed difference in einsum vs sum return trJ def calc_jacobian_regularization(pred_params, eps_t, dae, var_t, m_t, f_t, g2_t, var_N_t, args): """ Calculates Jabobian regularization loss. For reference implementations, see https://github.com/facebookresearch/jacobian_regularizer/blob/master/jacobian/jacobian.py or https://github.com/cfinlay/ffjord-rnode/blob/master/lib/layers/odefunc.py. 
""" # eps_t_jvp = eps_t.detach() # eps_t_jvp = eps_t.detach().requires_grad_() if args.no_autograd_jvp: raise NotImplementedError( "We have not implemented no_autograd_jvp for jacobian reg.") jvp_ode_func_norms = [] alpha = torch.sigmoid(dae.mixing_logit.detach()) for _ in range(args.jac_reg_samples): noise = sample_gaussian_like(eps_t) jvp = torch.autograd.grad( pred_params, eps_t, noise, create_graph=True)[0] if args.sde_type in ['geometric_sde', 'vpsde', 'power_vpsde']: jvp_ode_func = alpha * (noise * torch.sqrt(var_t) - jvp) if not args.jac_kin_reg_drop_weights: jvp_ode_func = f_t / torch.sqrt(var_t) * jvp_ode_func elif args.sde_type in ['sub_vpsde', 'sub_power_vpsde']: sigma2_N_t = (1.0 - m_t ** 2) ** 2 + m_t ** 2 jvp_ode_func = noise * torch.sqrt(var_t) / (1.0 - m_t ** 4) - ( (1.0 - alpha) * noise * torch.sqrt(var_t) / sigma2_N_t + alpha * jvp) if not args.jac_kin_reg_drop_weights: jvp_ode_func = f_t * (1.0 - m_t ** 4) / \ torch.sqrt(var_t) * jvp_ode_func elif args.sde_type in ['vesde']: jvp_ode_func = (1.0 - alpha) * noise * \ torch.sqrt(var_t) / var_N_t + alpha * jvp if not args.jac_kin_reg_drop_weights: jvp_ode_func = 0.5 * g2_t / torch.sqrt(var_t) * jvp_ode_func else: raise ValueError("Unrecognized SDE type: {}".format(args.sde_type)) jvp_ode_func_norms.append(jvp_ode_func.view( eps_t.size(0), -1).pow(2).sum(dim=1, keepdim=True)) jac_reg_loss = torch.cat(jvp_ode_func_norms, dim=1).mean() # jac_reg_loss = torch.mean(jvp_ode_func.view(eps_t.size(0), -1).pow(2).sum(dim=1)) return jac_reg_loss def calc_kinetic_regularization(pred_params, eps_t, dae, var_t, m_t, f_t, g2_t, var_N_t, args): """ Calculates kinetic regularization loss. For a reference implementation, see https://github.com/cfinlay/ffjord-rnode/blob/master/lib/layers/wrappers/cnf_regularization.py """ # eps_t_kin = eps_t.detach() alpha = torch.sigmoid(dae.mixing_logit.detach()) if args.sde_type in ['geometric_sde', 'vpsde', 'power_vpsde']: ode_func = alpha * (eps_t * torch.sqrt(var_t) - pred_params) if not args.jac_kin_reg_drop_weights: ode_func = f_t / torch.sqrt(var_t) * ode_func elif args.sde_type in ['sub_vpsde', 'sub_power_vpsde']: sigma2_N_t = (1.0 - m_t ** 2) ** 2 + m_t ** 2 ode_func = eps_t * torch.sqrt(var_t) / (1.0 - m_t ** 4) - ( (1.0 - alpha) * eps_t * torch.sqrt(var_t) / sigma2_N_t + alpha * pred_params) if not args.jac_kin_reg_drop_weights: ode_func = f_t * (1.0 - m_t ** 4) / torch.sqrt(var_t) * ode_func elif args.sde_type in ['vesde']: ode_func = (1.0 - alpha) * eps_t * torch.sqrt(var_t) / \ var_N_t + alpha * pred_params if not args.jac_kin_reg_drop_weights: ode_func = 0.5 * g2_t / torch.sqrt(var_t) * ode_func else: raise ValueError("Unrecognized SDE type: {}".format(args.sde_type)) kin_reg_loss = torch.mean(ode_func.view( eps_t.size(0), -1).pow(2).sum(dim=1)) return kin_reg_loss def different_p_q_objectives(iw_sample_p, iw_sample_q): assert iw_sample_p in ['ll_uniform', 'drop_all_uniform', 'll_iw', 'drop_all_iw', 'drop_sigma2t_iw', 'rescale_iw', 'drop_sigma2t_uniform'] assert iw_sample_q in ['reweight_p_samples', 'll_uniform', 'll_iw'] # Removed assert below. It may be stupid, but user can still do it. It may make sense for debugging purposes. # assert iw_sample_p != iw_sample_q, 'It does not make sense to use the same objectives for p and q, but train ' \ # 'with separated q and p updates. To reuse the p objective for q, specify ' \ # '"reweight_p_samples" instead (for the ll-based objectives, the ' \ # 'reweighting factor will simply be 1.0 then)!' 
# In these cases, we reuse the likelihood-based p-objective (either the uniform sampling version or the importance # sampling version) also for q. if iw_sample_p in ['ll_uniform', 'll_iw'] and iw_sample_q == 'reweight_p_samples': return False # In these cases, we are using a non-likelihood-based objective for p, and hence definitly need to use another q # objective. else: return True def decoder_output(dataset, logits, fixed_log_scales=None): if dataset in {'cifar10', 'celeba_64', 'celeba_256', 'imagenet_32', 'imagenet_64', 'ffhq', 'lsun_bedroom_128', 'lsun_bedroom_256', 'mnist', 'omniglot', 'lsun_church_256'}: return PixelNormal(logits, fixed_log_scales) else: return PixelNormal(logits, fixed_log_scales) # raise NotImplementedError def get_mixed_prediction(mixed_prediction, param, mixing_logit, mixing_component=None): if mixed_prediction: assert mixing_component is not None, 'Provide mixing component when mixed_prediction is enabled.' coeff = torch.sigmoid(mixing_logit) param = (1 - coeff) * mixing_component + coeff * param return param def set_vesde_sigma_max(args, vae, train_queue, logging, is_distributed): logging.info('') logging.info( 'Calculating max. pairwise distance in latent space to set sigma2_max for VESDE...') eps_list = [] vae.eval() for step, x in enumerate(train_queue): x = x[0] if len(x) > 1 else x x = x.cuda() x = symmetrize_image_data(x) # run vae with autocast(enabled=args.autocast_train): with torch.set_grad_enabled(False): logits, all_log_q, all_eps = vae(x) eps = torch.cat(all_eps, dim=1) eps_list.append(eps.detach()) # if step > 5: ### DEBUG # break ### DEBUG # concat eps tensor on each GPU and then gather all on all GPUs eps_this_rank = torch.cat(eps_list, dim=0) if is_distributed: eps_all_gathered = [torch.zeros_like( eps_this_rank)] * dist.get_world_size() dist.all_gather(eps_all_gathered, eps_this_rank) eps_full = torch.cat(eps_all_gathered, dim=0) else: eps_full = eps_this_rank # max pairwise distance squared between all latent encodings, is computed on CPU eps_full = eps_full.cpu().float() eps_full = eps_full.flatten(start_dim=1).unsqueeze(0) max_pairwise_dist_sqr = torch.cdist(eps_full, eps_full).square().max() max_pairwise_dist_sqr = max_pairwise_dist_sqr.cuda() # to be safe, we broadcast to all GPUs if we are in distributed environment. Shouldn't be necessary in principle. if is_distributed: dist.broadcast(max_pairwise_dist_sqr, src=0) args.sigma2_max = max_pairwise_dist_sqr.item() logging.info('Done! 
Set args.sigma2_max set to {}'.format(args.sigma2_max)) logging.info('') return args def mask_inactive_variables(x, is_active): x = x * is_active return x def common_x_operations(x, num_x_bits): x = x[0] if len(x) > 1 else x x = x.cuda() # change bit length x = change_bit_length(x, num_x_bits) x = symmetrize_image_data(x) return x def vae_regularization(args, vae_sn_calculator, loss_weight=None): """ when using hvae_trainer, we pass args=None, and loss_weight value """ regularization_q, vae_norm_loss, vae_bn_loss, vae_wdn_coeff = 0., 0., 0., args.weight_decay_norm_vae if loss_weight is None else loss_weight if loss_weight is not None or args.train_vae: vae_norm_loss = vae_sn_calculator.spectral_norm_parallel() vae_bn_loss = vae_sn_calculator.batchnorm_loss() regularization_q = (vae_norm_loss + vae_bn_loss) * vae_wdn_coeff return regularization_q, vae_norm_loss, vae_bn_loss, vae_wdn_coeff def dae_regularization(args, dae_sn_calculator, diffusion, dae, step, t, pred_params_p, eps_t_p, var_t_p, m_t_p, g2_t_p): dae_wdn_coeff = args.weight_decay_norm_dae dae_norm_loss = dae_sn_calculator.spectral_norm_parallel() dae_bn_loss = dae_sn_calculator.batchnorm_loss() regularization_p = (dae_norm_loss + dae_bn_loss) * dae_wdn_coeff # Jacobian regularization jac_reg_loss = 0. if args.jac_reg_coeff > 0.0 and step % args.jac_reg_freq == 0: f_t = diffusion.f(t).view(-1, 1, 1, 1) var_N_t = diffusion.var_N( t).view(-1, 1, 1, 1) if args.sde_type == 'vesde' else None """ # Arash: Please remove the following if it looks correct to you, Karsten. # jac_reg_loss = utils.calc_jacobian_regularization(pred_params, eps_t, dae, var_t, m_t, f_t, args) if args.iw_sample_q in ['ll_uniform', 'll_iw']: pred_params_jac_reg = torch.chunk(pred_params, chunks=2, dim=0)[0] var_t_jac_reg, m_t_jac_reg, f_t_jac_reg = torch.chunk(var_t, chunks=2, dim=0)[0], \ torch.chunk(m_t, chunks=2, dim=0)[0], \ torch.chunk(f_t, chunks=2, dim=0)[0] g2_t_jac_reg = torch.chunk(g2_t, chunks=2, dim=0)[0] var_N_t_jac_reg = torch.chunk(var_N_t, chunks=2, dim=0)[0] if args.sde_type == 'vesde' else None else: pred_params_jac_reg = pred_params var_t_jac_reg, m_t_jac_reg, f_t_jac_reg, g2_t_jac_reg, var_N_t_jac_reg = var_t, m_t, f_t, g2_t, var_N_t jac_reg_loss = utils.calc_jacobian_regularization(pred_params_jac_reg, eps_t_p, dae, var_t_jac_reg, m_t_jac_reg, f_t_jac_reg, g2_t_jac_reg, var_N_t_jac_reg, args) """ jac_reg_loss = calc_jacobian_regularization(pred_params_p, eps_t_p, dae, var_t_p, m_t_p, f_t, g2_t_p, var_N_t, args) regularization_p += args.jac_reg_coeff * jac_reg_loss # Kinetic regularization kin_reg_loss = 0. if args.kin_reg_coeff > 0.0: f_t = diffusion.f(t).view(-1, 1, 1, 1) var_N_t = diffusion.var_N( t).view(-1, 1, 1, 1) if args.sde_type == 'vesde' else None """ # Arash: Please remove the following if it looks correct to you, Karsten. 
# kin_reg_loss = utils.calc_kinetic_regularization(pred_params, eps_t, dae, var_t, m_t, f_t, args) if args.iw_sample_q in ['ll_uniform', 'll_iw']: pred_params_kin_reg = torch.chunk(pred_params, chunks=2, dim=0)[0] var_t_kin_reg, m_t_kin_reg, f_t_kin_reg = torch.chunk(var_t, chunks=2, dim=0)[0], \ torch.chunk(m_t, chunks=2, dim=0)[0], \ torch.chunk(f_t, chunks=2, dim=0)[0] g2_t_kin_reg = torch.chunk(g2_t, chunks=2, dim=0)[0] var_N_t_kin_reg = torch.chunk(var_N_t, chunks=2, dim=0)[0] if args.sde_type == 'vesde' else None else: pred_params_kin_reg = pred_params var_t_kin_reg, m_t_kin_reg, f_t_kin_reg, g2_t_kin_reg, var_N_t_kin_reg = var_t, m_t, f_t, g2_t, var_N_t kin_reg_loss = utils.calc_kinetic_regularization(pred_params_kin_reg, eps_t_p, dae, var_t_kin_reg, m_t_kin_reg, f_t_kin_reg, g2_t_kin_reg, var_N_t_kin_reg, args) """ kin_reg_loss = calc_kinetic_regularization(pred_params_p, eps_t_p, dae, var_t_p, m_t_p, f_t, g2_t_p, var_N_t, args) regularization_p += args.kin_reg_coeff * kin_reg_loss return regularization_p, dae_norm_loss, dae_bn_loss, dae_wdn_coeff, jac_reg_loss, kin_reg_loss def update_vae_lr(args, global_step, warmup_iters, vae_optimizer): if global_step < warmup_iters: lr = args.trainer.opt.lr * float(global_step) / warmup_iters for param_group in vae_optimizer.param_groups: param_group['lr'] = lr # use same lr if lr for local-dae is not specified def update_lr(args, global_step, warmup_iters, dae_optimizer, vae_optimizer, dae_local_optimizer=None): if global_step < warmup_iters: lr = args.learning_rate_dae * float(global_step) / warmup_iters if args.learning_rate_mlogit > 0 and len(dae_optimizer.param_groups) > 1: lr_mlogit = args.learning_rate_mlogit * \ float(global_step) / warmup_iters for i, param_group in enumerate(dae_optimizer.param_groups): if i == 0: param_group['lr'] = lr_mlogit else: param_group['lr'] = lr else: for param_group in dae_optimizer.param_groups: param_group['lr'] = lr # use same lr if lr for local-dae is not specified lr = lr if args.learning_rate_dae_local <= 0 else args.learning_rate_dae_local * \ float(global_step) / warmup_iters if dae_local_optimizer is not None: for param_group in dae_local_optimizer.param_groups: param_group['lr'] = lr if args.train_vae: lr = args.learning_rate_vae * float(global_step) / warmup_iters for param_group in vae_optimizer.param_groups: param_group['lr'] = lr def start_meters(): tr_loss_meter = AvgrageMeter() vae_recon_meter = AvgrageMeter() vae_kl_meter = AvgrageMeter() vae_nelbo_meter = AvgrageMeter() kl_per_group_ema = AvgrageMeter() return tr_loss_meter, vae_recon_meter, vae_kl_meter, vae_nelbo_meter, kl_per_group_ema def epoch_logging(args, writer, step, vae_recon_meter, vae_kl_meter, vae_nelbo_meter, tr_loss_meter, kl_per_group_ema): average_tensor(vae_recon_meter.avg, args.distributed) average_tensor(vae_kl_meter.avg, args.distributed) average_tensor(vae_nelbo_meter.avg, args.distributed) average_tensor(tr_loss_meter.avg, args.distributed) average_tensor(kl_per_group_ema.avg, args.distributed) writer.add_scalar('epoch/vae_recon', vae_recon_meter.avg, step) writer.add_scalar('epoch/vae_kl', vae_kl_meter.avg, step) writer.add_scalar('epoch/vae_nelbo', vae_nelbo_meter.avg, step) writer.add_scalar('epoch/total_loss', tr_loss_meter.avg, step) # add kl value per group to tensorboard for i in range(len(kl_per_group_ema.avg)): writer.add_scalar('kl_value/group_%d' % i, kl_per_group_ema.avg[i], step) def infer_active_variables(train_queue, vae, args, device, distributed, max_iter=None): kl_meter = AvgrageMeter() vae.eval() 
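    # The loop below accumulates, per latent dimension, the KL between the encoder
    # posterior and the standard-normal prior over (optionally up to `max_iter`)
    # training batches; a dimension is reported as active when its average KL
    # exceeds 0.1 nat. `x['tr_points']` is the point-cloud batch from the loader.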
    for step, x in enumerate(train_queue):
        if max_iter is not None and step > max_iter:
            break
        tr_pts = x['tr_points']
        with autocast(enabled=args.autocast_train):
            # apply vae:
            with torch.set_grad_enabled(False):
                # output = model.recont(val_x) ## torch.cat([val_x, tr_x]))
                # local variable renamed from `dist` to `q_dist` so it does not
                # shadow torch.distributed (imported at the top as `dist`)
                q_dist = vae.encode(tr_pts.to(device))
                eps = q_dist.sample()[0]
                all_log_q = [q_dist.log_p(eps)]
                ## _, all_log_q, all_eps = vae(x)
                ## all_eps = vae.concat_eps_per_scale(all_eps)
                ## all_log_q = vae.concat_eps_per_scale(all_log_q)
                all_eps = [eps]

        # lift the 2D/3D latents to 4D (B, C, ..., 1) so the NVAE-style helpers apply
        def make_4d(xlist):
            return [x.unsqueeze(-1).unsqueeze(-1) if len(x.shape) == 2
                    else x.unsqueeze(-1) for x in xlist]

        log_q, log_p, kl_all, kl_diag = vae_terms(
            make_4d(all_log_q), make_4d(all_eps))
        kl_meter.update(kl_diag[0], 1)  # only the top scale
    average_tensor(kl_meter.avg, distributed)
    return kl_meter.avg > 0.1
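# Example (illustrative sketch only; `_example_kl_and_mixing_utilities` is a hypothetical
# helper not used elsewhere in this file): how the KL warm-up, the plain branch of
# `kl_balancer` and `get_mixed_prediction` above combine. All tensors are fake CPU tensors.
def _example_kl_and_mixing_utilities():
    batch, num_groups = 4, 3
    # one tensor of per-sample KLs per latent group
    kl_all = [torch.rand(batch) for _ in range(num_groups)]
    # linear warm-up of the KL coefficient from 1e-4 to 1.0 over 10k steps,
    # after 1k constant steps
    coeff = kl_coeff(step=2500, total_step=10000, constant_step=1000,
                     min_kl_coeff=1e-4, max_kl_coeff=1.0)
    kl, kl_coeffs, kl_vals = kl_balancer(kl_all, kl_coeff=coeff, kl_balance=False)
    assert kl.shape == (batch,) and kl_vals.shape == (num_groups,)

    # get_mixed_prediction blends a fixed mixing component with the network output;
    # sigmoid(0) == 0.5, so with a zero logit the result is the plain average
    mixed = get_mixed_prediction(True, param=torch.zeros(4, 8),
                                 mixing_logit=torch.zeros(1),
                                 mixing_component=torch.ones(4, 8))
    assert torch.allclose(mixed, 0.5 * torch.ones(4, 8))
    return kl, kl_coeffs, kl_vals, mixed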