LION/models/score_sde/resnet.py
2023-04-07 13:33:06 +02:00

231 lines
9.2 KiB
Python

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
""" implement the gloabl prior for LION
"""
import torch.nn as nn
from loguru import logger
import functools
import torch
from ..utils import init_temb_fun, mask_inactive_variables
class SE(nn.Module):
    """Channel gate named after squeeze-and-excitation, built from 1x1 convs.

    A bottleneck ``channel -> channel // reduction -> channel`` (1x1 convs
    without bias, ReLU in between) followed by a sigmoid produces weights in
    (0, 1) that multiply the input element-wise. There is no global pooling,
    so the gate is computed independently at every spatial position.

    Args:
        channel: number of input (and output) channels.
        reduction: bottleneck divisor; the hidden width is
            ``channel // reduction`` (so ``channel`` should be at least
            ``reduction`` for a non-empty bottleneck).
    """

    def __init__(self, channel, reduction=8):
        super().__init__()
        hidden = channel // reduction
        # Keep this exact layer order: checkpoints address it as fc.0 / fc.2.
        layers = [
            nn.Conv2d(channel, hidden, 1, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, 1, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, inputs):
        """Return ``inputs`` rescaled by the learned sigmoid gate."""
        gate = self.fc(inputs)
        return inputs * gate
class ResBlockSEClip(nn.Module):
    """Residual SE block conditioned on a time embedding plus a CLIP feature.

    The conditioning tensor ``t`` given to :meth:`forward` holds two parts
    along dim 1: the first ``input_dim`` channels are added to ``x``; the
    remaining channels (the CLIP feature) are concatenated onto that sum
    before ``conv1``. ``conv1`` takes ``input_dim * 2`` channels, so the CLIP
    part must itself have ``input_dim`` channels.

    Residual branch: ``conv1 -> ReLU -> conv2 -> ReLU -> SE``. The shortcut is
    the unconditioned ``x``, so ``output_dim`` must equal ``input_dim`` for
    the final addition.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.non_linearity = nn.ReLU(inplace=True)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(input_dim*2, output_dim, 1, 1)
        self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1)
        self.SE = SE(output_dim)

    def forward(self, x, t):
        """Apply the block.

        Args:
            x: features with ``input_dim`` channels on dim 1; assumed 4-D
                (B, C, H, W) since the convs are ``Conv2d`` — TODO confirm.
            t: conditioning with ``input_dim`` time channels followed by
                ``input_dim`` CLIP channels on dim 1. Its batch / spatial
                sizes may be 1 and are broadcast against ``x``.

        Returns:
            ``x`` plus the residual branch output.
        """
        clip_feat = t[:, self.input_dim:]
        t = t[:, :self.input_dim]
        output = x + t
        # ``x + t`` broadcasts ``t``; broadcast the CLIP part the same way so
        # the concat does not fail when ``t`` has batch 1 or 1x1 spatial size.
        # ``expand`` is a no-op view when the shapes already match.
        clip_feat = clip_feat.expand(output.shape[0], -1, *output.shape[2:])
        output = torch.cat([output, clip_feat], dim=1)
        output = self.conv1(output)
        output = self.non_linearity(output)
        output = self.conv2(output)
        output = self.non_linearity(output)
        output = self.SE(output)
        return x + output

    def __repr__(self):
        return "ResBlockSEClip(%d, %d)"%(self.input_dim, self.output_dim)
class ResBlockSEDrop(nn.Module):
    """Residual SE block with dropout after the first activation.

    The conditioning ``t`` is added to ``x`` before ``conv1``. The residual
    branch is ``conv1 -> ReLU -> Dropout -> conv2 -> ReLU -> SE``, and the
    shortcut is the unconditioned ``x`` (so ``output_dim`` must equal
    ``input_dim``).

    Args:
        input_dim: channels of ``x`` (and of ``t``).
        output_dim: channels produced by the convs.
        dropout: drop probability passed to ``nn.Dropout``.
    """

    def __init__(self, input_dim, output_dim, dropout):
        super().__init__()
        self.non_linearity = nn.ReLU(inplace=True)
        self.input_dim, self.output_dim = input_dim, output_dim
        self.conv1 = nn.Conv2d(input_dim, output_dim, 1, 1)
        self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1)
        self.SE = SE(output_dim)
        self.dropout = nn.Dropout(dropout)
        self.dropout_ratio = dropout  # kept only for __repr__

    def forward(self, x, t):
        """Return ``x`` plus the gated residual of ``x + t``."""
        hidden = self.dropout(self.non_linearity(self.conv1(x + t)))
        hidden = self.non_linearity(self.conv2(hidden))
        return x + self.SE(hidden)

    def __repr__(self):
        return "ResBlockSE_withdropout(%d, %d, drop=%f)" % (
            self.input_dim, self.output_dim, self.dropout_ratio)
class ResBlock(nn.Module):
    """Residual block of two 1x1 convs, each followed by GroupNorm and ELU.

    The conditioning ``t`` is added to ``x`` first, and that conditioned sum
    is also used as the shortcut. In other words, the output is
    ``(x + t) + branch(x + t)``, not ``x + branch(x + t)``.

    GroupNorm uses ``min(output_dim // 4, 32)`` groups. PyTorch requires this
    to be positive and to divide ``output_dim``.

    Args:
        input_dim: channels of ``x`` / ``t``.
        output_dim: channels produced by the convs (must equal ``input_dim``
            for the residual addition).
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.non_linearity = nn.ELU()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(input_dim, output_dim, 1, 1)
        self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1)
        groups = min(output_dim // 4, 32)
        self.normalize1 = nn.GroupNorm(num_groups=groups, num_channels=output_dim, eps=1e-6)
        self.normalize2 = nn.GroupNorm(num_groups=groups, num_channels=output_dim, eps=1e-6)

    def forward(self, x, t):
        """Return ``(x + t)`` plus the conv/norm/ELU branch applied to it."""
        conditioned = x + t
        hidden = conditioned
        for conv, norm in ((self.conv1, self.normalize1), (self.conv2, self.normalize2)):
            hidden = self.non_linearity(norm(conv(hidden)))
        return conditioned + hidden

    def __repr__(self):
        return "ResBlock(%d, %d)" % (self.input_dim, self.output_dim)
class Prior(nn.Module):
    """ResNet-style global prior built from 1x1 convolutions.

    The time / noise-level input ``t`` is embedded by ``temb_fun`` and then a
    two-layer 1x1-conv stack ``temb_layer``. The input is lifted from
    ``num_input_channels`` to ``nf`` channels by ``input_layer``, passed
    through ``building_block(nf, nf)`` residual blocks (each also receives
    the embedding), and projected back by ``output_layer``.

    When ``cfg.clipforge.enable`` is set, a CLIP feature mapped to ``nf``
    channels is concatenated onto the time embedding, so each block receives
    ``2 * nf`` conditioning channels.

    Subclasses choose the residual block by overriding ``building_block``,
    either as a class attribute or as an instance attribute set before this
    ``__init__`` runs.

    Args:
        args: the SDE config (``cfg.sde``). Fields read here:
            ``num_scales_dae``, ``num_channels_dae``,
            ``num_cell_per_scale_dae``, ``mixed_prediction``,
            ``mixing_logit_init``, ``learn_mixing_logit`` (only logged),
            ``embedding_type``, ``embedding_scale`` and ``embedding_dim``.
        num_input_channels: channels of the input to :meth:`forward` (and of
            its output).
        *oargs: ``oargs[0]`` must be the global config; its
            ``clipforge.enable`` and ``clipforge.feat_dim`` are read.
        **kwargs: optional ``condition_input`` (stored, default ``False``)
            and ``num_cell_per_scale_dae`` (overrides the value in ``args``).
    """

    building_block = ResBlock

    def __init__(self, args, num_input_channels, *oargs, **kwargs):
        super().__init__()
        # NOTE: the construction/registration order below is deliberate. It
        # fixes the RNG-dependent parameter init and the state_dict key order
        # that existing checkpoints rely on.
        self.condition_input = kwargs.get('condition_input', False)
        self.cfg = oargs[0]
        self.clip_forge_enable = self.cfg.clipforge.enable
        logger.info('[Build Resnet Prior] Has condition input: {}; clipforge {}; '
                    'learn_mixing_logit={}, ', self.condition_input,
                    self.clip_forge_enable, args.learn_mixing_logit)
        self.act = nn.SiLU()  # NOTE(review): registered but unused in forward()
        self.num_scales = args.num_scales_dae
        self.num_input_channels = num_input_channels
        self.nf = nf = args.num_channels_dae
        # A keyword argument overrides the config. The block loop below must
        # use this value; it used to read ``args`` directly, silently
        # ignoring the override.
        if 'num_cell_per_scale_dae' in kwargs:
            num_cell_per_scale_dae = kwargs['num_cell_per_scale_dae']
        else:
            num_cell_per_scale_dae = args.num_cell_per_scale_dae

        # Maps the CLIP feature (feat_dim channels) to nf channels so it can
        # be concatenated onto the time embedding in forward().
        if self.clip_forge_enable:
            self.clip_feat_mapping = nn.Conv1d(self.cfg.clipforge.feat_dim, self.nf, 1)

        # Mixed prediction: a learnable per-channel mixing logit.
        self.mixed_prediction = args.mixed_prediction
        if self.mixed_prediction:
            logger.info('init-mixing_logit = {}, after sigmoid = {}',
                        args.mixing_logit_init, torch.sigmoid(torch.tensor(args.mixing_logit_init)))
            # NOTE(review): this is a truthiness check, so mixing_logit_init == 0
            # is rejected as well. Confirm that is intended.
            assert args.mixing_logit_init, 'require learning'
            # Multiply instead of torch.full so an integer config value still
            # yields a floating-point (trainable) parameter.
            init = args.mixing_logit_init * torch.ones(size=[1, num_input_channels, 1, 1])
            self.mixing_logit = torch.nn.Parameter(init, requires_grad=True)
        else:
            self.mixing_logit = None
        # Nothing in this class sets a mask. forward() only masks if some
        # caller assigns is_active later.
        self.is_active = None

        self.embedding_dim = args.embedding_dim
        self.embedding_dim_mult = 4
        self.temb_fun = init_temb_fun(args.embedding_type, args.embedding_scale, args.embedding_dim)
        logger.info('[temb_fun] embedding_type={}, embedding_scale={}, embedding_dim={}',
                    args.embedding_type, args.embedding_scale, args.embedding_dim)
        hidden_dim = self.embedding_dim * self.embedding_dim_mult
        # NOTE(review): there is no nonlinearity between these two convs.
        self.temb_layer = nn.Sequential(
            nn.Conv2d(self.embedding_dim, hidden_dim, 1, 1),
            nn.Conv2d(hidden_dim, nf, 1, 1),
        )

        self.input_layer = nn.Conv2d(num_input_channels, nf, 1, 1)
        # Blocks are constructed before output_layer but registered after it.
        blocks = [self.building_block(nf, nf) for _ in range(num_cell_per_scale_dae)]
        self.output_layer = nn.Conv2d(nf, num_input_channels, 1, 1)
        self.all_modules = nn.ModuleList(blocks)

    def forward(self, x, t, **kwargs):
        """Run the prior network.

        Args:
            x: input with ``num_input_channels`` channels on dim 1. The
                embeddings are made 4-D, which suggests (B, C, H, W),
                presumably with H = W = 1 for the global latent (TODO confirm).
            t: time / noise level; a 0-d tensor is expanded to shape (1,),
                otherwise it is passed to ``temb_fun`` unchanged.
            **kwargs: ``clip_feat`` (B, feat_dim) is required when clipforge
                is enabled. A batch of 1 is broadcast against the time
                embedding's batch, and vice versa.

        Returns:
            The ``output_layer`` result, with ``num_input_channels`` channels.
        """
        if t.dim() == 0:
            t = t.expand(1)
        temb = self.temb_fun(t)[:, :, None, None]  # make it 4-D for the 1x1 convs
        temb = self.temb_layer(temb)

        if self.clip_forge_enable:
            clip_feat = kwargs['clip_feat']
            # B,D -> B,D,1 -> (Conv1d) B,nf,1 -> B,nf,1,1
            clip_feat = self.clip_feat_mapping(clip_feat[:, :, None])[:, :, :, None]
            # Broadcast whichever side has batch 1 so the concat succeeds.
            if temb.shape[0] == 1 and temb.shape[0] < clip_feat.shape[0]:
                temb = temb.expand(clip_feat.shape[0], -1, -1, -1)
            elif clip_feat.shape[0] == 1 and clip_feat.shape[0] < temb.shape[0]:
                clip_feat = clip_feat.expand(temb.shape[0], -1, -1, -1)
            temb = torch.cat([temb, clip_feat], dim=1)  # add to temb feature

        if self.mixed_prediction and self.is_active is not None:
            x = mask_inactive_variables(x, self.is_active)
        x = self.input_layer(x)
        for layer in self.all_modules:
            x = layer(x, temb)
        return self.output_layer(x)
class PriorSEDrop(Prior):
    """:class:`Prior` whose residual blocks are :class:`ResBlockSEDrop`.

    The dropout rate comes from ``args[0].dropout``, the SDE config passed as
    the first positional argument. The block factory is therefore bound per
    instance, before the parent constructor builds the blocks.
    """

    def __init__(self, *args, **kwargs):
        sde_cfg = args[0]
        # Must be set before Prior.__init__, which calls self.building_block.
        self.building_block = functools.partial(ResBlockSEDrop, dropout=sde_cfg.dropout)
        super().__init__(*args, **kwargs)
class PriorSEClip(Prior):
    """:class:`Prior` built from :class:`ResBlockSEClip` residual blocks.

    ``ResBlockSEClip`` expects conditioning with ``2 * nf`` channels.
    ``Prior.forward`` produces that only when the clipforge branch
    concatenates the mapped CLIP feature onto the time embedding.
    Construction is inherited unchanged from :class:`Prior`.
    """

    building_block = ResBlockSEClip