optimize the passed in parameters in the case of bilinear

Former-commit-id: 6234bd1c4608cac8a64d8658e9af82b60a227859
This commit is contained in:
hushenghao 2020-04-17 23:00:11 +08:00
parent 8b1fee8730
commit b5a3b207a1
2 changed files with 5 additions and 5 deletions

View file

@ -18,10 +18,10 @@ class UNet(nn.Module):
self.down3 = Down(256, 512)
factor = 2 if bilinear else 1
self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512, bilinear)
self.up2 = Up(512, 256, bilinear)
self.up3 = Up(256, 128, bilinear)
self.up4 = Up(128, 64 * factor, bilinear)
self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)
def forward(self, x):

View file

@ -48,7 +48,7 @@ class Up(nn.Module):
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels // 2, in_channels // 2)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)