81 lines
3.2 KiB
Python
81 lines
3.2 KiB
Python
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
#
|
|
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
|
# and proprietary rights in and to this software, related documentation
|
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
# distribution of this software and related documentation without an express
|
|
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
|
|
|
""" copied and modified from https://github.com/CW-Huang/sdeflow-light/blob/524650bc5ad69522b3e0905672deef0650374512/lib/models/unet.py """
|
|
|
|
import math
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.nn.init import _calculate_fan_in_and_fan_out
|
|
import numpy as np
|
|
|
|
|
|
def _calculate_correct_fan(tensor, mode):
    """Return the fan value of ``tensor`` selected by ``mode``.

    Copied and modified from
    https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py#L337

    Args:
        tensor: weight tensor whose fans are computed by
            ``torch.nn.init._calculate_fan_in_and_fan_out``.
        mode: ``'fan_in'``, ``'fan_out'`` or ``'fan_avg'`` (case-insensitive).

    Returns:
        ``fan_in`` for ``'fan_in'``, ``fan_out`` for ``'fan_out'`` and the
        arithmetic mean ``(fan_in + fan_out) / 2`` for ``'fan_avg'``.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out', 'fan_avg']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == 'fan_in':
        return fan_in
    if mode == 'fan_out':
        return fan_out
    # 'fan_avg' is listed as valid, so it must yield the average of both fans
    # instead of silently falling through to fan_out.
    return (fan_in + fan_out) / 2.
|
|
|
|
|
|
def kaiming_uniform_(tensor, gain=1., mode='fan_in'):
    r"""Fill ``tensor`` in place with samples from a symmetric uniform distribution.

    Variant of the He initializer described in `Delving deep into rectifiers:
    Surpassing human-level performance on ImageNet classification` -
    He, K. et al. (2015). Values are drawn from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \sqrt{\frac{3 \cdot \text{gain}}{\max(1, \text{fan})}}

    i.e. ``gain`` scales the variance (it sits inside the square root).

    Args:
        tensor: an n-dimensional `torch.Tensor`, modified in place
        gain: multiplier applied to the variance ``1 / fan``
        mode: forwarded to ``_calculate_correct_fan`` to choose the fan value
            (``'fan_in'`` by default)

    Returns:
        The same ``tensor`` object after it has been filled.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> _ = kaiming_uniform_(w, mode='fan_in')
    """
    fan = _calculate_correct_fan(tensor, mode)
    # Guard against a fan below 1 so the variance never exceeds ``gain``.
    variance = gain / max(1., fan)
    # U(-a, a) has variance a^2 / 3, hence a = sqrt(3 * variance).
    limit = math.sqrt(3.0 * variance)
    with torch.no_grad():
        tensor.uniform_(-limit, limit)
    return tensor
|
|
|
|
|
|
def variance_scaling_init_(tensor, scale):
    """Initialise ``tensor`` in place with ``kaiming_uniform_`` in ``'fan_avg'`` mode.

    Args:
        tensor: the `torch.Tensor` to fill.
        scale: value passed on as ``gain``; a scale equal to 0 is replaced
            by ``1e-10`` before being forwarded.

    Returns:
        Whatever ``kaiming_uniform_`` returns for ``tensor``.
    """
    if scale == 0:
        effective_gain = 1e-10
    else:
        effective_gain = scale
    return kaiming_uniform_(tensor, gain=effective_gain, mode='fan_avg')
|
|
|
|
|
|
def dense(in_channels, out_channels, init_scale=1.):
    """Build an ``nn.Linear`` layer with variance-scaling weights and zero bias.

    Args:
        in_channels: number of input features.
        out_channels: number of output features.
        init_scale: scale forwarded to ``variance_scaling_init_`` for the weight.

    Returns:
        The initialised ``nn.Linear`` module.
    """
    layer = nn.Linear(in_features=in_channels, out_features=out_channels)
    # Overwrite the default initialisation of both parameters.
    variance_scaling_init_(layer.weight, scale=init_scale)
    nn.init.zeros_(layer.bias)
    return layer
|
|
|
|
def conv2d(in_planes, out_planes, kernel_size=(3, 3), stride=1, dilation=1, padding=1, bias=True, padding_mode='zeros',
           init_scale=1.):
    """Build an ``nn.Conv2d`` layer with variance-scaling weights and zero bias.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size, stride, dilation, padding, bias, padding_mode: forwarded
            unchanged to ``nn.Conv2d``.
        init_scale: scale forwarded to ``variance_scaling_init_`` for the weight.

    Returns:
        The initialised ``nn.Conv2d`` module.
    """
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias,
        padding_mode=padding_mode,
    )
    variance_scaling_init_(layer.weight, scale=init_scale)
    # nn.Conv2d only creates a bias parameter when ``bias`` is truthy.
    if layer.bias is not None:
        nn.init.zeros_(layer.bias)
    return layer
|
|
|
|
|
|
|