From b5a3b207a12668b7be1de52f4b00bf5770089b21 Mon Sep 17 00:00:00 2001 From: hushenghao Date: Fri, 17 Apr 2020 23:00:11 +0800 Subject: [PATCH] optimize the passed in parameters in the case of bilinear Former-commit-id: 6234bd1c4608cac8a64d8658e9af82b60a227859 --- unet/unet_model.py | 8 ++++---- unet/unet_parts.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/unet/unet_model.py b/unet/unet_model.py index e5d2388..40291c5 100644 --- a/unet/unet_model.py +++ b/unet/unet_model.py @@ -18,10 +18,10 @@ class UNet(nn.Module): self.down3 = Down(256, 512) factor = 2 if bilinear else 1 self.down4 = Down(512, 1024 // factor) - self.up1 = Up(1024, 512, bilinear) - self.up2 = Up(512, 256, bilinear) - self.up3 = Up(256, 128, bilinear) - self.up4 = Up(128, 64 * factor, bilinear) + self.up1 = Up(1024, 512 // factor, bilinear) + self.up2 = Up(512, 256 // factor, bilinear) + self.up3 = Up(256, 128 // factor, bilinear) + self.up4 = Up(128, 64, bilinear) self.outc = OutConv(64, n_classes) def forward(self, x): diff --git a/unet/unet_parts.py b/unet/unet_parts.py index daaa2da..7c64dc1 100644 --- a/unet/unet_parts.py +++ b/unet/unet_parts.py @@ -48,7 +48,7 @@ class Up(nn.Module): # if bilinear, use the normal convolutions to reduce the number of channels if bilinear: self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - self.conv = DoubleConv(in_channels, out_channels // 2, in_channels // 2) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) else: self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) self.conv = DoubleConv(in_channels, out_channels)