# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
"""
copied and modified from https://github.com/CW-Huang/sdeflow-light/blob/524650bc5ad69522b3e0905672deef0650374512/lib/models/unet.py
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import _calculate_fan_in_and_fan_out
import numpy as np


def _calculate_correct_fan(tensor, mode):
    """Return the fan value used to scale the initialisation of ``tensor``.

    copied and modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py#L337

    Args:
        tensor: weight tensor whose fans are computed with
            ``torch.nn.init._calculate_fan_in_and_fan_out``.
        mode: ``'fan_in'``, ``'fan_out'`` or ``'fan_avg'`` (case-insensitive).

    Returns:
        ``fan_in`` for ``'fan_in'``, ``fan_out`` for ``'fan_out'``, and the
        mean ``(fan_in + fan_out) / 2`` for ``'fan_avg'``.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out', 'fan_avg']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        return fan_in
    if mode == 'fan_out':
        return fan_out
    # 'fan_avg': mean of both fans. Earlier revisions fell through to fan_out
    # here, silently ignoring the requested averaging (NOTE: this changes the
    # default init of dense/conv2d relative to those revisions).
    return (fan_in + fan_out) / 2.


def kaiming_uniform_(tensor, gain=1., mode='fan_in'):
    r"""Fill ``tensor`` in place with samples from a scaled uniform distribution.

    Variant of the method described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015),
    using a uniform distribution. Values are sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \sqrt{\frac{3 \times \text{gain}}{\max(1, \text{fan\_mode})}}

    i.e. ``gain`` multiplies the *variance* (``gain / fan_mode``) of the
    weights, not their standard deviation. With ``gain=2`` and
    ``mode='fan_in'`` this is the usual He (Kaiming) uniform initialisation.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: multiplier applied to the variance of the distribution.
        mode: ``'fan_in'`` (default), ``'fan_out'`` or ``'fan_avg'``
            (case-insensitive). Choosing ``'fan_in'`` preserves the magnitude of
            the variance of the weights in the forward pass, ``'fan_out'``
            preserves the magnitudes in the backwards pass, and ``'fan_avg'``
            uses ``(fan_in + fan_out) / 2``.

    Returns:
        ``tensor`` itself, filled in place.

    Raises:
        ValueError: if ``mode`` is not a supported mode.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> _ = kaiming_uniform_(w, mode='fan_in')
    """
    fan = _calculate_correct_fan(tensor, mode)
    # Clamp the fan at 1 so a zero fan cannot cause a division by zero.
    var = gain / max(1., fan)
    # A uniform distribution on [-b, b] has variance b^2 / 3, hence b = sqrt(3 * var).
    bound = math.sqrt(3.0 * var)
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)


def variance_scaling_init_(tensor, scale):
    """Initialise ``tensor`` in place with 'fan_avg' variance scaling.

    Args:
        tensor: weight tensor to fill.
        scale: variance multiplier forwarded to ``kaiming_uniform_`` as
            ``gain``. A scale of exactly 0 is replaced by ``1e-10`` so the
            weights end up tiny rather than identically zero.

    Returns:
        ``tensor``, filled in place.
    """
    return kaiming_uniform_(tensor, gain=1e-10 if scale == 0 else scale, mode='fan_avg')


def dense(in_channels, out_channels, init_scale=1.):
    """Build an ``nn.Linear`` with variance-scaled weights and a zero bias.

    Args:
        in_channels: input feature size.
        out_channels: output feature size.
        init_scale: variance multiplier passed to ``variance_scaling_init_``.

    Returns:
        The initialised ``nn.Linear`` module.
    """
    lin = nn.Linear(in_channels, out_channels)
    variance_scaling_init_(lin.weight, scale=init_scale)
    nn.init.zeros_(lin.bias)
    return lin


def conv2d(in_planes, out_planes, kernel_size=(3, 3), stride=1, dilation=1, padding=1,
           bias=True, padding_mode='zeros', init_scale=1.):
    """Build an ``nn.Conv2d`` with variance-scaled weights and a zero bias.

    All convolution arguments are forwarded unchanged to ``nn.Conv2d``.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: convolution kernel size (default ``(3, 3)``).
        stride: convolution stride.
        dilation: convolution dilation.
        padding: convolution padding (default 1).
        bias: whether the convolution has a bias; if so it is zero-initialised.
        padding_mode: padding mode forwarded to ``nn.Conv2d``.
        init_scale: variance multiplier passed to ``variance_scaling_init_``.

    Returns:
        The initialised ``nn.Conv2d`` module.
    """
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dilation, bias=bias, padding_mode=padding_mode)
    variance_scaling_init_(conv.weight, scale=init_scale)
    if bias:
        nn.init.zeros_(conv.bias)
    return conv