diff --git a/unet/unet_parts.py b/unet/unet_parts.py index 66149d2..c7128d0 100644 --- a/unet/unet_parts.py +++ b/unet/unet_parts.py @@ -57,7 +57,7 @@ class up(nn.Module): if bilinear: self.up = nn.UpsamplingBilinear2d(scale_factor=2) else: - self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2) + self.up = nn.ConvTranspose2d(in_ch//2, in_ch//2, 2, stride=2) self.conv = double_conv(in_ch, out_ch)