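"""Context-conditioned linear layers sharing a ``forward(context, x)`` interface.

Judging from the ``1 + dim_c`` widths below (an inference, not stated in the
file), ``context`` is a ``(batch, 1 + dim_c)`` tensor -- a scalar such as time
concatenated with a ``dim_c``-dimensional condition -- and ``x`` is either a
2-D ``(batch, dim_in)`` or 3-D ``(batch, n_points, dim_in)`` input.
"""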
import torch
import torch.nn as nn


def weights_init(m):
    """Zero-initialize Linear/Conv weights; give biases small Gaussian noise.

    Zero weights make each matched layer start out as a (near-)zero map.
    """
    classname = m.__class__.__name__
    if classname.find('Linear') != -1 or classname.find('Conv') != -1:
        nn.init.constant_(m.weight, 0)
        if m.bias is not None:  # hyper-bias heads below are built with bias=False
            nn.init.normal_(m.bias, 0, 0.01)

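# Usage sketch (illustrative, not part of the original file): the initializer
# is meant to be applied recursively via ``nn.Module.apply``, e.g.
#
#     model.apply(weights_init)
#
# where ``model`` is any network built from the layers below.
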
class IgnoreLinear(nn.Module):
    """Plain linear layer that discards the conditioning context."""

    def __init__(self, dim_in, dim_out, dim_c):
        super(IgnoreLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)

    def forward(self, context, x):
        return self._layer(x)

class ConcatLinear(nn.Module):
    """Conditions by concatenating the context onto the input features."""

    def __init__(self, dim_in, dim_out, dim_c):
        super(ConcatLinear, self).__init__()
        self._layer = nn.Linear(dim_in + 1 + dim_c, dim_out)

    def forward(self, context, x):
        if x.dim() == 3:
            # Broadcast the per-batch context across the point dimension.
            context = context.unsqueeze(1).expand(-1, x.size(1), -1)
        # Concatenate along the last dim so both 2-D and 3-D inputs work.
        x_context = torch.cat((x, context), dim=-1)
        return self._layer(x_context)

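# The remaining variants condition through small hypernetwork heads instead of
# concatenation: a context-dependent bias and/or elementwise gate is combined
# with the output of a context-free linear layer.
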
class ConcatLinear_v2(nn.Module):
    """Linear layer plus a context-dependent bias from a hypernetwork head."""

    def __init__(self, dim_in, dim_out, dim_c):
        super(ConcatLinear_v2, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False)

    def forward(self, context, x):
        bias = self._hyper_bias(context)
        if x.dim() == 3:
            bias = bias.unsqueeze(1)
        return self._layer(x) + bias

class SquashLinear(nn.Module):
    """Linear layer gated elementwise by a sigmoid of the context."""

    def __init__(self, dim_in, dim_out, dim_c):
        super(SquashLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper = nn.Linear(1 + dim_c, dim_out)

    def forward(self, context, x):
        gate = torch.sigmoid(self._hyper(context))
        if x.dim() == 3:
            gate = gate.unsqueeze(1)
        return self._layer(x) * gate

class ScaleLinear(nn.Module):
    """Like SquashLinear, but the gate is an unsquashed context projection."""

    def __init__(self, dim_in, dim_out, dim_c):
        super(ScaleLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper = nn.Linear(1 + dim_c, dim_out)

    def forward(self, context, x):
        gate = self._hyper(context)
        if x.dim() == 3:
            gate = gate.unsqueeze(1)
        return self._layer(x) * gate

class ConcatSquashLinear(nn.Module):
    """Linear layer with both a sigmoid gate and a bias driven by the context."""

    def __init__(self, dim_in, dim_out, dim_c):
        super(ConcatSquashLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False)
        self._hyper_gate = nn.Linear(1 + dim_c, dim_out)

    def forward(self, context, x):
        gate = torch.sigmoid(self._hyper_gate(context))
        bias = self._hyper_bias(context)
        if x.dim() == 3:
            gate = gate.unsqueeze(1)
            bias = bias.unsqueeze(1)
        return self._layer(x) * gate + bias

class ConcatScaleLinear(nn.Module):
    """Like ConcatSquashLinear, but with an unsquashed (linear) gate."""

    def __init__(self, dim_in, dim_out, dim_c):
        super(ConcatScaleLinear, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(1 + dim_c, dim_out, bias=False)
        self._hyper_gate = nn.Linear(1 + dim_c, dim_out)

    def forward(self, context, x):
        gate = self._hyper_gate(context)
        bias = self._hyper_bias(context)
        if x.dim() == 3:
            gate = gate.unsqueeze(1)
            bias = bias.unsqueeze(1)
        return self._layer(x) * gate + bias
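

# Minimal smoke test -- an illustrative addition, not part of the original
# file. It assumes the context is a scalar (e.g. time) concatenated with a
# dim_c-dimensional condition, matching the ``1 + dim_c`` widths above, and
# that ``x`` is a 3-D point-cloud-style batch.
if __name__ == "__main__":
    batch, n_points, dim_in, dim_out, dim_c = 4, 128, 3, 64, 16

    x = torch.randn(batch, n_points, dim_in)
    t = torch.zeros(batch, 1)            # scalar part of the context
    c = torch.randn(batch, dim_c)        # conditioning vector
    context = torch.cat((t, c), dim=1)   # shape: (batch, 1 + dim_c)

    for layer_cls in (IgnoreLinear, ConcatLinear, ConcatLinear_v2,
                      SquashLinear, ScaleLinear, ConcatSquashLinear,
                      ConcatScaleLinear):
        layer = layer_cls(dim_in, dim_out, dim_c)
        layer.apply(weights_init)
        out = layer(context, x)
        assert out.shape == (batch, n_points, dim_out), layer_cls.__name__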