68 lines
2.2 KiB
Python
68 lines
2.2 KiB
Python
|
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||
|
#
|
||
|
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||
|
# and proprietary rights in and to this software, related documentation
|
||
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
||
|
# distribution of this software and related documentation without an express
|
||
|
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||
|
"""
|
||
|
adaptive group norm
|
||
|
"""
|
||
|
from loguru import logger
|
||
|
import torch.nn as nn
|
||
|
import torch
|
||
|
import numpy as np
|
||
|
from utils.checker import *
|
||
|
from .dense import dense
|
||
|
import os
|
||
|
|
||
|
class AdaGN(nn.Module):
    """Adaptive group normalization.

    The input feature map is group-normalized and then modulated with a
    per-channel scale and shift that are predicted from a style vector:

        out = GroupNorm(image) * factor + bias

    where ``factor`` and ``bias`` are the two halves (split along the channel
    axis) of ``self.emd(style)``.
    """

    # Number of groups used by the GroupNorm layer; also reported by __repr__.
    # nn.GroupNorm requires n_channel to be divisible by this value.
    NUM_GROUPS = 8

    def __init__(self, ndim, cfg, n_channel):
        """Build the normalization layer and the style projection.

        Args:
            ndim: number of axes that follow the (batch, channel) axes of the
                input handled by ``forward``; 1, 2 and 3 are supported
                (layouts ``B,D,N`` / ``B,D,N,1`` / ``B,D,V,V,V``).
                The value is only validated when ``forward`` is called.
            cfg: config object; ``cfg.latent_pts.style_dim`` gives the size of
                the style vector and ``cfg.latent_pts.ada_mlp_init_scale`` is
                forwarded to ``dense`` as ``init_scale``.
            n_channel: number of channels of the input feature map.
        """
        super().__init__()
        style_dim = cfg.latent_pts.style_dim
        init_scale = cfg.latent_pts.ada_mlp_init_scale

        self.ndim = ndim
        self.n_channel = n_channel
        self.style_dim = style_dim
        # One scale and one shift value per input channel.
        self.out_dim = n_channel * 2

        self.norm = nn.GroupNorm(self.NUM_GROUPS, n_channel)
        # NOTE(review): `dense` is a project helper; presumably it returns a
        # linear layer exposing a `.bias` tensor of length `out_dim` -- confirm.
        self.emd = dense(style_dim, self.out_dim, init_scale=init_scale)

        # The first half of the projection output is used as the scale and the
        # second half as the shift (see `forward`), so start the bias at
        # scale = 1 and shift = 0.
        with torch.no_grad():
            self.emd.bias[:n_channel].fill_(1.0)
            self.emd.bias[n_channel:].fill_(0.0)

    def __repr__(self):
        """Return a short summary of the norm and projection sizes."""
        return (
            f"AdaGN(GN({self.NUM_GROUPS}, {self.n_channel}), "
            f"Linear({self.style_dim}, {self.out_dim}))"
        )

    def forward(self, image, style):
        """Normalize ``image`` and modulate it with the projected ``style``.

        Args:
            image: feature map laid out as ``B,D,N`` (ndim == 1),
                ``B,D,N,1`` (ndim == 2) or ``B,D,V,V,V`` (ndim == 3).
            style: style features laid out as ``B,D``.

        Returns:
            ``self.norm(image) * factor + bias``, where ``factor`` and
            ``bias`` are broadcast over every axis after the channel axis.

        Raises:
            NotImplementedError: if ``self.ndim`` is not 1, 2 or 3.
        """
        # NOTE(review): the CHECK*D helpers come from utils.checker;
        # presumably they assert the tensor rank -- confirm.
        CHECK2D(style)
        style = self.emd(style)

        # Append singleton axes so the projected style broadcasts against the
        # normalized image.
        if self.ndim == 3:  # image: B,D,V,V,V
            CHECK5D(image)
            style = style.view(style.shape[0], -1, 1, 1, 1)  # 5D
        elif self.ndim == 2:  # image: B,D,N,1
            CHECK4D(image)
            style = style.view(style.shape[0], -1, 1, 1)  # 4D
        elif self.ndim == 1:  # image: B,D,N
            CHECK3D(image)
            style = style.view(style.shape[0], -1, 1)  # 3D
        else:
            raise NotImplementedError(
                f"AdaGN supports ndim in (1, 2, 3); got {self.ndim!r}"
            )

        # Split along the channel axis: first half scales, second half shifts.
        factor, bias = style.chunk(2, 1)
        return self.norm(image) * factor + bias
|
||
|
|
||
|
|