From 8332f891c35b7d6425b8e990a0e9a3454168a871 Mon Sep 17 00:00:00 2001 From: milesial Date: Thu, 17 Aug 2017 15:33:47 +0200 Subject: [PATCH] Changed from deconv to bilinear for upsampling --- .gitignore | 1 + unet_model.py | 8 ++++---- unet_parts.py | 3 ++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index a3f457d..a014da0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.pyc data/ __pycache__/ +*.pth diff --git a/unet_model.py b/unet_model.py index 2ed6b73..6d5dc39 100644 --- a/unet_model.py +++ b/unet_model.py @@ -11,10 +11,10 @@ class UNet(nn.Module): self.down1 = down(64, 128) self.down2 = down(128, 256) self.down3 = down(256, 512) - self.down4 = down(512, 1024) - self.up1 = up(1024, 512) - self.up2 = up(512, 256) - self.up3 = up(256, 128) + self.down4 = down(512, 512) + self.up1 = up(1024, 256) + self.up2 = up(512, 128) + self.up3 = up(256, 64) self.up4 = up(128, 64) self.outc = outconv(64, n_classes) diff --git a/unet_parts.py b/unet_parts.py index 9416dbb..37ec10e 100644 --- a/unet_parts.py +++ b/unet_parts.py @@ -41,7 +41,8 @@ class down(nn.Module): class up(nn.Module): def __init__(self, in_ch, out_ch): super(up, self).__init__() - self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + #self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2) self.conv = double_conv(in_ch, out_ch) def forward(self, x1, x2):