LION/models/dense.py

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
""" copied and modified from https://github.com/CW-Huang/sdeflow-light/blob/524650bc5ad69522b3e0905672deef0650374512/lib/models/unet.py """
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import _calculate_fan_in_and_fan_out
import numpy as np


def _calculate_correct_fan(tensor, mode):
    """
    copied and modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py#L337
    """
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out', 'fan_avg']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    # note: in this modified version, 'fan_avg' is accepted but resolves to fan_out
    return fan_in if mode == 'fan_in' else fan_out
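
# A small worked example of the fan computation (a sketch; the layer size is
# arbitrary): for an ``nn.Linear(4, 8)`` layer, the weight has shape ``[8, 4]``,
# so ``_calculate_fan_in_and_fan_out`` yields fan_in=4 and fan_out=8; mode
# 'fan_in' returns 4, while 'fan_out' and 'fan_avg' both return 8 here.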


def kaiming_uniform_(tensor, gain=1., mode='fan_in'):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \sqrt{\frac{3 \times \text{gain}}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: multiplier applied to the variance (``var = gain / fan``)
        mode: one of ``'fan_in'`` (default), ``'fan_out'``, or ``'fan_avg'``.
            Choosing ``'fan_in'`` preserves the magnitude of the variance of
            the weights in the forward pass. Choosing ``'fan_out'`` preserves
            the magnitudes in the backward pass.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> kaiming_uniform_(w, mode='fan_in')
    """
    fan = _calculate_correct_fan(tensor, mode)
    # gain = calculate_gain(nonlinearity, a)
    var = gain / max(1., fan)
    bound = math.sqrt(3.0 * var)  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)
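
# A quick sanity check of the bound (a sketch using the shapes from the
# docstring example): for ``w = torch.empty(3, 5)`` with mode='fan_in',
# fan = 5, so var = 1 / 5 = 0.2 and bound = sqrt(3 * 0.2) ~= 0.7746, i.e. the
# entries of ``w`` are drawn uniformly from roughly (-0.7746, 0.7746).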


def variance_scaling_init_(tensor, scale):
    # a scale of 0 is replaced by a tiny gain (1e-10), giving a near-zero
    # initialization instead of a degenerate all-zero one
    return kaiming_uniform_(tensor, gain=1e-10 if scale == 0 else scale, mode='fan_avg')


def dense(in_channels, out_channels, init_scale=1.):
    lin = nn.Linear(in_channels, out_channels)
    variance_scaling_init_(lin.weight, scale=init_scale)
    nn.init.zeros_(lin.bias)
    return lin


def conv2d(in_planes, out_planes, kernel_size=(3, 3), stride=1, dilation=1, padding=1, bias=True, padding_mode='zeros',
           init_scale=1.):
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
                     bias=bias, padding_mode=padding_mode)
    variance_scaling_init_(conv.weight, scale=init_scale)
    if bias:
        nn.init.zeros_(conv.bias)
    return conv
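

# A minimal usage sketch of the helpers above (the layer sizes here are
# arbitrary): init_scale=1 gives the usual variance-scaled uniform init,
# while init_scale=0 maps to the 1e-10 gain and therefore near-zero weights.
if __name__ == "__main__":
    lin = dense(16, 32, init_scale=1.)
    conv = conv2d(3, 64, kernel_size=(3, 3), init_scale=0.)
    print('dense weight range:', lin.weight.min().item(), lin.weight.max().item())
    print('conv weight abs max:', conv.weight.abs().max().item())  # near zero because init_scale=0
    print('biases zeroed:', bool((lin.bias == 0).all() and (conv.bias == 0).all()))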