# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.

"""Implement the global prior for LION."""
import functools

import torch
import torch.nn as nn
from loguru import logger

from ..utils import init_temb_fun, mask_inactive_variables


class SE(nn.Module):
    def __init__(self, channel, reduction=8):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, 1, bias=False),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        return inputs * self.fc(inputs)

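# Usage sketch for SE, assuming the [B, C, 1, 1] feature maps that the 1x1-conv
# blocks in this file operate on: the block computes per-channel gates in (0, 1)
# and rescales its input channel-wise.
#
#   se = SE(channel=128)                  # reduction=8 -> 16 hidden channels
#   y = se(torch.randn(4, 128, 1, 1))     # y.shape == (4, 128, 1, 1)
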
class ResBlockSEClip(nn.Module):
    """Residual SE block that also consumes a CLIP feature stacked onto the time
    embedding; fixes the unused-conv0 bug in ResBlockSE.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.non_linearity = nn.ReLU(inplace=True)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(input_dim * 2, output_dim, 1, 1)
        self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1)
        in_ch = self.output_dim
        self.SE = SE(in_ch)

    def forward(self, x, t):
        # t carries [time embedding | clip feature] along the channel dimension
        clip_feat = t[:, self.input_dim:].contiguous()
        t = t[:, :self.input_dim].contiguous()
        output = x + t
        output = torch.cat([output, clip_feat], dim=1).contiguous()
        output = self.conv1(output)
        output = self.non_linearity(output)
        output = self.conv2(output)
        output = self.non_linearity(output)
        output = self.SE(output)
        shortcut = x
        return shortcut + output

    def __repr__(self):
        return "ResBlockSEClip(%d, %d)" % (self.input_dim, self.output_dim)

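# Conditioning-layout sketch for ResBlockSEClip, matching how Prior.forward below
# builds its `t` argument: x is [B, nf, 1, 1] and t is [B, 2 * nf, 1, 1], where the
# first nf channels of t hold the time embedding and the remaining nf channels hold
# the mapped CLIP feature.
#
#   block = ResBlockSEClip(input_dim=128, output_dim=128)
#   out = block(torch.randn(2, 128, 1, 1),
#               torch.randn(2, 256, 1, 1))  # out.shape == (2, 128, 1, 1)
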
class ResBlockSEDrop(nn.Module):
    """Residual SE block with dropout; fixes the unused-conv0 bug in ResBlockSE."""

    def __init__(self, input_dim, output_dim, dropout):
        super().__init__()
        self.non_linearity = nn.ReLU(inplace=True)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(input_dim, output_dim, 1, 1)
        self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1)
        in_ch = self.output_dim
        self.SE = SE(in_ch)
        self.dropout = nn.Dropout(dropout)
        self.dropout_ratio = dropout

    def forward(self, x, t):
        output = x + t
        output = self.conv1(output)
        output = self.non_linearity(output)
        output = self.dropout(output)
        output = self.conv2(output)
        output = self.non_linearity(output)
        output = self.SE(output)
        shortcut = x
        return shortcut + output

    def __repr__(self):
        return "ResBlockSE_withdropout(%d, %d, drop=%f)" % (
            self.input_dim, self.output_dim, self.dropout_ratio)

class ResBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.non_linearity = nn.ELU()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(input_dim, output_dim, 1, 1)
        self.conv2 = nn.Conv2d(output_dim, output_dim, 1, 1)
        in_ch = self.output_dim
        self.normalize1 = nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                       num_channels=in_ch, eps=1e-6)
        self.normalize2 = nn.GroupNorm(num_groups=min(in_ch // 4, 32),
                                       num_channels=in_ch, eps=1e-6)

    def forward(self, x, t):
        x = x + t
        output = self.conv1(x)
        output = self.normalize1(output)
        output = self.non_linearity(output)
        output = self.conv2(output)
        output = self.normalize2(output)
        output = self.non_linearity(output)
        shortcut = x
        return shortcut + output

    def __repr__(self):
        return "ResBlock(%d, %d)" % (self.input_dim, self.output_dim)

class Prior(nn.Module):
    building_block = ResBlock

    def __init__(self, args, num_input_channels, *oargs, **kwargs):
        super().__init__()
        # args: cfg.sde
        # oargs: other arguments; oargs[0] is the global config
        self.condition_input = kwargs.get('condition_input', False)
        self.cfg = oargs[0]
        self.clip_forge_enable = self.cfg.clipforge.enable

        logger.info('[Build Resnet Prior] Has condition input: {}; clipforge {}; '
                    'learn_mixing_logit={}', self.condition_input,
                    self.clip_forge_enable, args.learn_mixing_logit)

        self.act = nn.SiLU()
        self.num_scales = args.num_scales_dae
        self.num_input_channels = num_input_channels

        self.nf = nf = args.num_channels_dae
        num_cell_per_scale_dae = kwargs.get('num_cell_per_scale_dae', args.num_cell_per_scale_dae)

        # take the CLIP feature as additional conditioning input
        if self.clip_forge_enable:
            self.clip_feat_mapping = nn.Conv1d(self.cfg.clipforge.feat_dim, self.nf, 1)

        # mixed prediction: a learnable per-channel logit that gates how much of the
        # prediction comes from the network vs. an analytic Gaussian component
        # (the blending is applied in the diffusion training/sampling code, not here)
        self.mixed_prediction = args.mixed_prediction
        if self.mixed_prediction:
            logger.info('init-mixing_logit = {}, after sigmoid = {}',
                        args.mixing_logit_init, torch.sigmoid(torch.tensor(args.mixing_logit_init)))
            assert args.mixing_logit_init, 'mixed prediction requires a non-zero mixing_logit_init'
            init = args.mixing_logit_init * torch.ones(size=[1, num_input_channels, 1, 1])
            self.mixing_logit = torch.nn.Parameter(init, requires_grad=True)
            self.is_active = None
        else:  # no mixing_logit
            self.mixing_logit = None
            self.is_active = None

        self.embedding_dim = args.embedding_dim
        self.embedding_dim_mult = 4
        self.temb_fun = init_temb_fun(args.embedding_type, args.embedding_scale, args.embedding_dim)
        logger.info('[temb_fun] embedding_type={}, embedding_scale={}, embedding_dim={}',
                    args.embedding_type, args.embedding_scale, args.embedding_dim)

        # two 1x1 convs that lift the time embedding to nf channels
        modules = []
        modules.append(nn.Conv2d(self.embedding_dim, self.embedding_dim * 4, 1, 1))
        modules.append(nn.Conv2d(self.embedding_dim * 4, nf, 1, 1))
        self.temb_layer = nn.Sequential(*modules)

        # input projection, a stack of residual building blocks, output projection
        modules = []
        input_channels = num_input_channels
        self.input_layer = nn.Conv2d(input_channels, nf, 1, 1)
        for i_block in range(num_cell_per_scale_dae):
            modules.append(self.building_block(nf, nf))
        self.output_layer = nn.Conv2d(nf, input_channels, 1, 1)
        self.all_modules = nn.ModuleList(modules)

    def forward(self, x, t, **kwargs):
        # timestep / noise-level embedding; only for continuous training
        if t.dim() == 0:
            t = t.expand(1)
        temb = self.temb_fun(t)[:, :, None, None]  # make it 4d
        temb = self.temb_layer(temb)

        if self.clip_forge_enable:
            clip_feat = kwargs['clip_feat']
            clip_feat = self.clip_feat_mapping(clip_feat[:, :, None])[:, :, :, None]  # B,D -> B,nf,1 -> B,nf,1,1
            if temb.shape[0] == 1 and temb.shape[0] < clip_feat.shape[0]:
                temb = temb.expand(clip_feat.shape[0], -1, -1, -1)
            temb = torch.cat([temb, clip_feat], dim=1)  # append the clip feature to temb

        # mask out inactive variables
        if self.mixed_prediction and self.is_active is not None:
            x = mask_inactive_variables(x, self.is_active)
        x = self.input_layer(x)
        for layer in self.all_modules:
            enc_input = x
            x = layer(enc_input, temb)

        h = self.output_layer(x)
        return h

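# Forward-call sketch (shapes assumed): the prior only uses 1x1 convolutions, so
# x is typically a [B, num_input_channels, 1, 1] latent and t a scalar or [B]
# tensor of diffusion times. When cfg.clipforge.enable is set, a CLIP feature of
# shape [B, cfg.clipforge.feat_dim] must be passed as `clip_feat`:
#
#   out = prior(x, t)
#   out = prior(x, t, clip_feat=clip_feat)
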
class PriorSEDrop(Prior):
    def __init__(self, *args, **kwargs):
        # bind the dropout ratio from the sde config into the building block
        self.building_block = functools.partial(ResBlockSEDrop, dropout=args[0].dropout)
        super().__init__(*args, **kwargs)

class PriorSEClip(Prior):
    building_block = ResBlockSEClip

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
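
# Construction sketch, inferred from the constructors above: `args` is cfg.sde and
# must provide num_scales_dae, num_channels_dae, num_cell_per_scale_dae,
# mixed_prediction, mixing_logit_init, learn_mixing_logit, embedding_type,
# embedding_scale, embedding_dim (and dropout for PriorSEDrop); the global cfg is
# passed as the first extra positional argument and must provide clipforge.enable
# and clipforge.feat_dim.
#
#   prior = PriorSEClip(cfg.sde, num_input_channels, cfg)
#   out = prior(x, t, clip_feat=clip_feat)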