import torch
import torch.nn as nn

__all__ = ["SharedMLP"]


class Swish(nn.Module):
    """Swish (a.k.a. SiLU) activation: x * sigmoid(x)."""

    def forward(self, x):
        return x * torch.sigmoid(x)


class SharedMLP(nn.Module):
    """A pointwise MLP shared across points: a stack of kernel-size-1
    convolutions, each followed by GroupNorm (8 groups) and Swish.
    `dim` selects between Conv1d and Conv2d."""

    def __init__(self, in_channels, out_channels, dim=1):
        super().__init__()
        if dim == 1:
            conv = nn.Conv1d
            bn = nn.GroupNorm  # GroupNorm is used despite the "bn" name
        elif dim == 2:
            conv = nn.Conv2d
            bn = nn.GroupNorm
        else:
            raise ValueError(f"dim must be 1 or 2, got {dim}")
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [out_channels]
        layers = []
        for oc in out_channels:
            layers.extend(
                [
                    conv(in_channels, oc, 1),  # kernel size 1: per-point linear map
                    bn(8, oc),  # 8 groups; out channels must be divisible by 8
                    Swish(),
                ]
            )
            in_channels = oc
        self.layers = nn.Sequential(*layers)

    def forward(self, inputs):
        # Accept either a bare feature tensor or a (features, *extras) tuple;
        # only the features are transformed, the extras pass through unchanged.
        if isinstance(inputs, (list, tuple)):
            return (self.layers(inputs[0]), *inputs[1:])
        else:
            return self.layers(inputs)
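

# Minimal usage sketch (illustrative only; the shapes below are assumptions,
# not part of the original module). With dim=1, SharedMLP expects features
# shaped (batch, channels, num_points).
if __name__ == "__main__":
    mlp = SharedMLP(in_channels=3, out_channels=[64, 128])
    points = torch.randn(4, 3, 1024)  # 4 clouds, 3 coords, 1024 points each
    out = mlp(points)
    print(out.shape)  # torch.Size([4, 128, 1024])

    # Tuple inputs: only the first element is transformed.
    feats, coords = mlp((points, torch.randn(4, 3, 1024)))
    print(feats.shape, coords.shape)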