mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
Merge pull request #161 from somehower/master
optimize the passed in parameters in the case of bilinear Former-commit-id: 5f37e8a6dc592563593210371cebd40f95788080
This commit is contained in:
commit
c57180e90c
|
@ -18,10 +18,10 @@ class UNet(nn.Module):
|
||||||
self.down3 = Down(256, 512)
|
self.down3 = Down(256, 512)
|
||||||
factor = 2 if bilinear else 1
|
factor = 2 if bilinear else 1
|
||||||
self.down4 = Down(512, 1024 // factor)
|
self.down4 = Down(512, 1024 // factor)
|
||||||
self.up1 = Up(1024, 512, bilinear)
|
self.up1 = Up(1024, 512 // factor, bilinear)
|
||||||
self.up2 = Up(512, 256, bilinear)
|
self.up2 = Up(512, 256 // factor, bilinear)
|
||||||
self.up3 = Up(256, 128, bilinear)
|
self.up3 = Up(256, 128 // factor, bilinear)
|
||||||
self.up4 = Up(128, 64 * factor, bilinear)
|
self.up4 = Up(128, 64, bilinear)
|
||||||
self.outc = OutConv(64, n_classes)
|
self.outc = OutConv(64, n_classes)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
@ -48,7 +48,7 @@ class Up(nn.Module):
|
||||||
# if bilinear, use the normal convolutions to reduce the number of channels
|
# if bilinear, use the normal convolutions to reduce the number of channels
|
||||||
if bilinear:
|
if bilinear:
|
||||||
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
||||||
self.conv = DoubleConv(in_channels, out_channels // 2, in_channels // 2)
|
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
|
||||||
else:
|
else:
|
||||||
self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
|
self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
|
||||||
self.conv = DoubleConv(in_channels, out_channels)
|
self.conv = DoubleConv(in_channels, out_channels)
|
||||||
|
|
Loading…
Reference in a new issue