Rework of the transposed conv / bilinear up route
Former-commit-id: 07debeb9f2621b53ed513e5ab8a0307b4da57767
This commit is contained in:
parent
2bef826dcc
commit
490ecf8383
4
train.py
4
train.py
|
@ -154,11 +154,11 @@ if __name__ == '__main__':
|
|||
# - For 1 class and background, use n_classes=1
|
||||
# - For 2 classes, use n_classes=1
|
||||
# - For N > 2 classes, use n_classes=N
|
||||
net = UNet(n_channels=3, n_classes=1)
|
||||
net = UNet(n_channels=3, n_classes=1, bilinear=True)
|
||||
logging.info(f'Network:\n'
|
||||
f'\t{net.n_channels} input channels\n'
|
||||
f'\t{net.n_classes} output channels (classes)\n'
|
||||
f'\t{"Bilinear" if net.bilinear else "Dilated conv"} upscaling')
|
||||
f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')
|
||||
|
||||
if args.load:
|
||||
net.load_state_dict(
|
||||
|
|
|
@ -16,11 +16,12 @@ class UNet(nn.Module):
|
|||
self.down1 = Down(64, 128)
|
||||
self.down2 = Down(128, 256)
|
||||
self.down3 = Down(256, 512)
|
||||
self.down4 = Down(512, 512)
|
||||
self.up1 = Up(1024, 256, bilinear)
|
||||
self.up2 = Up(512, 128, bilinear)
|
||||
self.up3 = Up(256, 64, bilinear)
|
||||
self.up4 = Up(128, 64, bilinear)
|
||||
factor = 2 if bilinear else 1
|
||||
self.down4 = Down(512, 1024 // factor)
|
||||
self.up1 = Up(1024, 512, bilinear)
|
||||
self.up2 = Up(512, 256, bilinear)
|
||||
self.up3 = Up(256, 128, bilinear)
|
||||
self.up4 = Up(128, 64 * factor, bilinear)
|
||||
self.outc = OutConv(64, n_classes)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -8,13 +8,15 @@ import torch.nn.functional as F
|
|||
class DoubleConv(nn.Module):
|
||||
"""(convolution => [BN] => ReLU) * 2"""
|
||||
|
||||
def __init__(self, in_channels, out_channels):
|
||||
def __init__(self, in_channels, out_channels, mid_channels=None):
|
||||
super().__init__()
|
||||
if not mid_channels:
|
||||
mid_channels = out_channels
|
||||
self.double_conv = nn.Sequential(
|
||||
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(mid_channels),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
nn.ReLU(inplace=True)
|
||||
)
|
||||
|
@ -46,10 +48,11 @@ class Up(nn.Module):
|
|||
# if bilinear, use the normal convolutions to reduce the number of channels
|
||||
if bilinear:
|
||||
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
||||
self.conv = DoubleConv(in_channels, out_channels // 2, in_channels // 2)
|
||||
else:
|
||||
self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
|
||||
self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
|
||||
self.conv = DoubleConv(in_channels, out_channels)
|
||||
|
||||
self.conv = DoubleConv(in_channels, out_channels)
|
||||
|
||||
def forward(self, x1, x2):
|
||||
x1 = self.up(x1)
|
||||
|
|
Loading…
Reference in a new issue