# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

""" copied and modified from https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/util/sr_utils.py """
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
from loguru import logger
|
||
|
|
||
|
|
||
|
@torch.jit.script
def fused_abs_max_add(weight: torch.Tensor, loss: torch.Tensor) -> torch.Tensor:
    loss += torch.max(torch.abs(weight))
    return loss
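
# Quick sanity check for fused_abs_max_add (illustrative, not part of the
# original file): starting from a zero scalar loss, the result should equal
# the largest absolute entry of the weight, e.g.
#
#   w = torch.tensor([[1.0, -3.0], [2.0, 0.5]])
#   assert fused_abs_max_add(w, torch.zeros(())) == w.abs().max()  # tensor(3.)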


class SpectralNormCalculator:
    def __init__(self, num_power_iter=4, custom_conv=False):
        self.num_power_iter = num_power_iter
        # increase the number of power iterations for the first call
        self.num_power_iter_init = 10 * num_power_iter
        self.all_conv_layers = []
        # left/right singular vectors used for spectral regularization (SR)
        self.sr_u = {}
        self.sr_v = {}
        self.all_bn_layers = []
        self.custom_conv = custom_conv

    def add_conv_layers(self, model):
        for n, layer in model.named_modules():
            if self.custom_conv:
                # add our customized conv layers. Conv2D and ARConv2d are this
                # repo's own layer classes; they must be imported from the
                # project before custom_conv=True can be used.
                if isinstance(layer, (Conv2D, ARConv2d)):
                    self.all_conv_layers.append(layer)
            else:
                # add standard pytorch conv/linear layers
                if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear)):
                    self.all_conv_layers.append(layer)

    def add_bn_layers(self, model):
        for n, layer in model.named_modules():
            if isinstance(layer, (nn.BatchNorm2d, nn.SyncBatchNorm, nn.GroupNorm)):
                self.all_bn_layers.append(layer)

    def spectral_norm_parallel(self):
        """ Computes the spectral-norm loss for all conv layers in parallel.

        This method should be called after the forward pass of all the conv
        layers in each iteration. It returns the sum of the (approximate)
        largest singular values of the layers' weight matrices.
        """
        weights = {}  # a dictionary indexed by the shape of weights
        for l in self.all_conv_layers:
            weight = l.weight_normalized if self.custom_conv else l.weight
            if not isinstance(l, nn.Linear):
                # flatten conv kernels to (out_channels, -1) matrices
                weight_mat = weight.view(weight.size(0), -1)
            else:
                weight_mat = weight
            # logger.info('mat weight: {} | weight: {}', weight_mat.shape, weight.shape)

            if weight_mat.shape not in weights:
                weights[weight_mat.shape] = []
            weights[weight_mat.shape].append(weight_mat)

        loss = 0
        for i in weights:
            # stack same-shape weight matrices so one batched power iteration
            # covers all of them at once
            weights[i] = torch.stack(weights[i], dim=0)
            with torch.no_grad():
                num_iter = self.num_power_iter
                if i not in self.sr_u:
                    num_w, row, col = weights[i].shape
                    self.sr_u[i] = F.normalize(torch.ones(num_w, row).normal_(0, 1).cuda(),
                                               dim=1, eps=1e-3)
                    self.sr_v[i] = F.normalize(torch.ones(num_w, col).normal_(0, 1).cuda(),
                                               dim=1, eps=1e-3)
                    num_iter = self.num_power_iter_init

                for j in range(num_iter):
                    # The spectral norm of weight equals `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    self.sr_v[i] = F.normalize(torch.matmul(self.sr_u[i].unsqueeze(1), weights[i]).squeeze(1),
                                               dim=1, eps=1e-3)  # bx1xr * bxrxc --> bx1xc --> bxc
                    self.sr_u[i] = F.normalize(torch.matmul(weights[i], self.sr_v[i].unsqueeze(2)).squeeze(2),
                                               dim=1, eps=1e-3)  # bxrxc * bxcx1 --> bxrx1 --> bxr

            sigma = torch.matmul(self.sr_u[i].unsqueeze(1),
                                 torch.matmul(weights[i], self.sr_v[i].unsqueeze(2)))
            loss += torch.sum(sigma)
        return loss
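
    # Clarifying note (added, not in the original file): each entry of `sigma`
    # approximates the largest singular value of one stacked weight matrix, so
    # the returned loss is roughly the sum of the layers' spectral norms. For
    # a single matrix this can be checked against a full SVD:
    #
    #   W = torch.randn(8, 27)
    #   u = F.normalize(torch.randn(8), dim=0, eps=1e-3)
    #   for _ in range(40):
    #       v = F.normalize(W.t() @ u, dim=0, eps=1e-3)
    #       u = F.normalize(W @ v, dim=0, eps=1e-3)
    #   sigma = u @ W @ v              # power-iteration estimate
    #   torch.linalg.svdvals(W)[0]     # reference value, approximately equal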

    def batchnorm_loss(self):
        loss = torch.zeros(size=()).cuda()
        for l in self.all_bn_layers:
            if l.affine:
                # penalize the largest absolute normalization scale
                loss = fused_abs_max_add(l.weight, loss)
        return loss

    def state_dict(self):
        return {
            'sr_v': self.sr_v,
            'sr_u': self.sr_u
        }

    def load_state_dict(self, state_dict, device):
        # move each loaded singular vector to the given device
        for s in state_dict['sr_v']:
            self.sr_v[s] = state_dict['sr_v'][s].to(device)

        for s in state_dict['sr_u']:
            self.sr_u[s] = state_dict['sr_u'][s].to(device)
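

# Usage sketch (illustrative, not part of the original file). The model and
# the regularization coefficients below are made-up placeholders; LSGM uses
# its own networks and schedules. Note that SpectralNormCalculator allocates
# its singular vectors with .cuda(), so this only runs on a GPU machine.
if __name__ == '__main__':
    if torch.cuda.is_available():
        model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.GroupNorm(2, 8),
                              nn.Conv2d(8, 8, 3)).cuda()

        sn_calculator = SpectralNormCalculator(num_power_iter=4, custom_conv=False)
        sn_calculator.add_conv_layers(model)
        sn_calculator.add_bn_layers(model)

        x = torch.randn(2, 3, 32, 32, device='cuda')
        task_loss = model(x).pow(2).mean()  # placeholder task loss

        # hypothetical coefficients for the two regularizers
        lambda_sn, lambda_bn = 1e-2, 1e-2
        loss = task_loss \
            + lambda_sn * sn_calculator.spectral_norm_parallel() \
            + lambda_bn * sn_calculator.batchnorm_loss()
        loss.backward()
        logger.info('total loss: {}', loss.item())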