diff --git a/unet/unet_parts.py b/unet/unet_parts.py index 7f68f52..986ba25 100644 --- a/unet/unet_parts.py +++ b/unet/unet_parts.py @@ -13,10 +13,10 @@ class DoubleConv(nn.Module): if not mid_channels: mid_channels = out_channels self.double_conv = nn.Sequential( - nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(mid_channels), nn.ReLU(inplace=True), - nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) )