Merge pull request #161 from somehower/master

Optimize the passed-in channel parameters for the bilinear case

Former-commit-id: 5f37e8a6dc592563593210371cebd40f95788080
This commit is contained in:
milesial 2020-04-17 09:29:20 -07:00 committed by GitHub
commit c57180e90c
2 changed files with 5 additions and 5 deletions

View file

@@ -18,10 +18,10 @@ class UNet(nn.Module):
self.down3 = Down(256, 512) self.down3 = Down(256, 512)
factor = 2 if bilinear else 1 factor = 2 if bilinear else 1
self.down4 = Down(512, 1024 // factor) self.down4 = Down(512, 1024 // factor)
self.up1 = Up(1024, 512, bilinear) self.up1 = Up(1024, 512 // factor, bilinear)
self.up2 = Up(512, 256, bilinear) self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128, bilinear) self.up3 = Up(256, 128 // factor, bilinear)
self.up4 = Up(128, 64 * factor, bilinear) self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes) self.outc = OutConv(64, n_classes)
def forward(self, x): def forward(self, x):

View file

@@ -48,7 +48,7 @@ class Up(nn.Module):
# if bilinear, use the normal convolutions to reduce the number of channels # if bilinear, use the normal convolutions to reduce the number of channels
if bilinear: if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels // 2, in_channels // 2) self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else: else:
self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels) self.conv = DoubleConv(in_channels, out_channels)