LION/utils/model_helper.py

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import importlib

import torch
import torch.distributed as dist  # required by average_gradients below
import torch.nn.functional as F
from loguru import logger
from torch.autograd import grad

from utils.evaluation_metrics_fast import distChamferCUDA, emd_approx, distChamferCUDA_l1


def loss_fn(predv, targetv, loss_type, point_dim, batch_size, loss_weight_emd=0.02,
            loss_weight_cdnorm=1,
            return_dict=False):
    """Reconstruction loss between predicted and target point clouds.

    `loss_type` selects the metric (Chamfer / EMD / MSE / L1 variants).
    Returns the loss, plus a dict of values to log when `return_dict=True`.
    """
    B = batch_size
    output = {}
    if loss_type == 'dcd':
        from evaluation.dist_aware_cd import calc_dcd
        res = calc_dcd(predv, targetv)
        loss = res[0]
        output['print/rec_dcd'] = loss
    elif loss_type == 'cd1_sum_emd':  # use l1 loss in chamfer distance, take the sum
        dl, dr = distChamferCUDA_l1(predv, targetv, point_dim)
        loss = dl + dr  # .view(B,-1).sum(-1) + dr.view(B,-1).sum(-1)
        output['print/rec_cd1_sum'] = loss
        emd = emd_approx(predv, targetv)
        emd = emd.view(B, -1) * predv.view(B, -1).shape[1]
        output['print/rec_emd'] = emd
        loss = loss + emd
    elif loss_type == 'cd1_sum':  # use l1 loss in chamfer distance, take the sum
        dl, dr = distChamferCUDA_l1(predv, targetv, point_dim)
        loss = dl + dr  # .view(B,-1).sum(-1) + dr.view(B,-1).sum(-1)
        output['print/rec_cd1_sum'] = loss
    # use l2 loss in chamfer distance, take the sum over N points, but its mean over point dim (3)
    elif loss_type == 'cd_sum':
        dl, dr = distChamferCUDA(predv, targetv)
        loss = dl.view(B, -1).sum(-1) + dr.view(B, -1).sum(-1)
        output['print/rec_cd1_sum'] = loss
    elif loss_type == 'chamfer':
        dl, dr = distChamferCUDA(predv, targetv)
        loss = dl.view(B, -1).mean(-1) + dr.view(B, -1).mean(-1)
        output['print/rec_cd'] = loss
    elif loss_type == 'mse_sum':
        loss = F.mse_loss(
            predv.contiguous().view(-1, point_dim), targetv.view(-1, point_dim),
            reduction='sum')
        output['print/rec_mse'] = loss
    elif loss_type == 'l1_sum':
        loss = F.l1_loss(
            predv.contiguous().view(-1, point_dim), targetv.view(-1, point_dim),
            reduction='sum')
        output['print/rec_l1'] = loss
    elif loss_type == 'l1_cd':
        loss = F.l1_loss(
            predv.contiguous().view(-1, point_dim), targetv.view(-1, point_dim),
            reduction='sum')
        output['print/rec_l1'] = loss
        dl, dr = distChamferCUDA(predv, targetv)
        cd_loss = dl.view(B, -1).sum(-1) + dr.view(B, -1).sum(-1)
        output['print/rec_cd'] = cd_loss
        loss = loss + cd_loss
    elif loss_type == 'mse':
        loss = F.mse_loss(
            predv.contiguous().view(-1, point_dim), targetv.view(-1, point_dim),
            reduction='mean')
        output['print/rec_mse'] = loss
    elif loss_type == 'emd':
        emd = emd_approx(predv, targetv)
        # dl.view(B,-1).mean(-1) + dr.view(B,-1).mean(-1)
        loss = emd.view(B, -1)
        output['print/rec_emd'] = loss
    elif loss_type == 'chamfer_emd':
        dl, dr = distChamferCUDA(predv, targetv)
        cd = dl.view(B, -1).mean(-1) + dr.view(B, -1).mean(-1)
        cd = cd.view(B, -1)
        emd = emd_approx(predv, targetv).view(B, -1)
        loss = cd + emd * loss_weight_emd  # balance the scale of the two losses
        output['print/rec_emd'] = emd.mean()
        output['print/rec_weight_emd'] = loss_weight_emd
        output['print/rec_cd'] = cd.mean()
    else:
        raise ValueError(loss_type)
    if return_dict:
        return loss, output
    return loss


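# Usage sketch (illustrative only; `pred`, `target` and the shapes below are
# assumptions, not part of this module): compute a Chamfer reconstruction loss
# between two (B, N, 3) point clouds and also get the values to log.
#
#   pred, target = ...                                   # (B, N, 3) CUDA tensors
#   loss, log_dict = loss_fn(pred, target, loss_type='chamfer', point_dim=3,
#                            batch_size=pred.shape[0], return_dict=True)
#   loss.mean().backward()

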
def import_model(model_str):
    logger.info('import: {}', model_str)
    p, m = model_str.rsplit('.', 1)
    mod = importlib.import_module(p)
    Model = getattr(mod, m)
    return Model


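# Usage sketch (the dotted path is a hypothetical placeholder, not a module
# guaranteed to exist in this repo): `import_model` turns a
# "package.module.ClassName" string from the config into the class object,
# which the caller then instantiates, e.g. as in the commented-out line below.
#
#   Model = import_model('models.my_encoder.MyEncoder')

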
## self.encoder = Model(zdim=latent_dim, input_dim=args.ddpm.input_dim, args=args)
class DataParallelPassthrough(torch.nn.parallel.DistributedDataParallel):
    """DDP wrapper that forwards unknown attribute lookups to the wrapped module."""

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)


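# Usage sketch (assumes torch.distributed is already initialized; `net` and
# `local_rank` are placeholders supplied by the caller): attributes and methods
# of the wrapped module stay reachable on the DDP wrapper.
#
#   ddp_net = DataParallelPassthrough(net.cuda(), device_ids=[local_rank])
#   z = ddp_net.encode(x)   # resolved via __getattr__ -> ddp_net.module.encode

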
def average_gradients(model, rank=-1):
    size = float(dist.get_world_size())
    for name, param in model.named_parameters():
        if not param.requires_grad or param.grad is None:
            continue
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM, async_op=True)
        param.grad.data /= size
    torch.cuda.synchronize()


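# Usage sketch (assumes the default process group is initialized; `model`,
# `loss` and `optimizer` are placeholders): manual gradient averaging for a
# model that is not wrapped in DistributedDataParallel.
#
#   loss.backward()
#   average_gradients(model)   # all-reduce(sum) each grad, then divide by world size
#   optimizer.step()

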
def count_parameters(model):
    return sum(p.numel() for p in model.parameters())  # if p.requires_grad)


def get_device(model):
    param = next(model.parameters())
    return param.device
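

# Usage sketch (`model` is a placeholder):
#   logger.info('model has {} parameters on {}', count_parameters(model), get_device(model))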