diff --git a/.gitignore b/.gitignore index a3f457d..a014da0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.pyc data/ __pycache__/ +*.pth diff --git a/unet_model.py b/unet_model.py index 2ed6b73..6d5dc39 100644 --- a/unet_model.py +++ b/unet_model.py @@ -11,10 +11,10 @@ class UNet(nn.Module): self.down1 = down(64, 128) self.down2 = down(128, 256) self.down3 = down(256, 512) - self.down4 = down(512, 1024) - self.up1 = up(1024, 512) - self.up2 = up(512, 256) - self.up3 = up(256, 128) + self.down4 = down(512, 512) + self.up1 = up(1024, 256) + self.up2 = up(512, 128) + self.up3 = up(256, 64) self.up4 = up(128, 64) self.outc = outconv(64, n_classes) diff --git a/unet_parts.py b/unet_parts.py index 9416dbb..37ec10e 100644 --- a/unet_parts.py +++ b/unet_parts.py @@ -41,7 +41,8 @@ class down(nn.Module): class up(nn.Module): def __init__(self, in_ch, out_ch): super(up, self).__init__() - self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2) + self.up = nn.UpsamplingBilinear2d(scale_factor=2) + #self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2) self.conv = double_conv(in_ch, out_ch) def forward(self, x1, x2):