diff --git a/src/predict.py b/src/predict.py
index ec03fe6..6d569b4 100755
--- a/src/predict.py
+++ b/src/predict.py
@@ -91,12 +91,6 @@ def get_args():
         default=0.5,
         help="Scale factor for the input images",
     )
-    parser.add_argument(
-        "--bilinear",
-        action="store_true",
-        default=False,
-        help="Use bilinear upsampling",
-    )

     return parser.parse_args()

@@ -121,7 +115,7 @@ if __name__ == "__main__":
     in_files = args.input
     out_files = get_output_filenames(args)

-    net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear)
+    net = UNet(n_channels=3, n_classes=2)

     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     logging.info(f"Loading model {args.model}")
diff --git a/src/train.py b/src/train.py
index 7e0b7ab..471b1c3 100644
--- a/src/train.py
+++ b/src/train.py
@@ -214,12 +214,6 @@ def get_args():
         default=False,
         help="Use mixed precision",
     )
-    parser.add_argument(
-        "--bilinear",
-        action="store_true",
-        default=False,
-        help="Use bilinear upsampling",
-    )
     parser.add_argument(
         "--classes",
         "-c",
@@ -241,13 +235,12 @@ if __name__ == "__main__":
     # Change here to adapt to your data
     # n_channels=3 for RGB images
     # n_classes is the number of probabilities you want to get per pixel
-    net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
+    net = UNet(n_channels=3, n_classes=args.classes)

     logging.info(
         f"Network:\n"
         f"\t{net.n_channels} input channels\n"
-        f"\t{net.n_classes} output channels (classes)\n"
-        f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling'
+        f"\t{net.n_classes} output channels (classes)"
     )

     if args.load:
diff --git a/src/unet/blocks.py b/src/unet/blocks.py
index c8a5e47..1f4a854 100644
--- a/src/unet/blocks.py
+++ b/src/unet/blocks.py
@@ -10,12 +10,16 @@ class DoubleConv(nn.Module):

     def __init__(self, in_channels, out_channels, mid_channels=None):
         super().__init__()
+
         if not mid_channels:
             mid_channels = out_channels
+
         self.double_conv = nn.Sequential(
+            # first convolution
             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
             nn.BatchNorm2d(mid_channels),
             nn.ReLU(inplace=True),
+            # second convolution
             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
             nn.BatchNorm2d(out_channels),
             nn.ReLU(inplace=True),
@@ -30,7 +34,11 @@ class Down(nn.Module):

     def __init__(self, in_channels, out_channels):
         super().__init__()
-        self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))
+
+        self.maxpool_conv = nn.Sequential(
+            nn.MaxPool2d(2),
+            DoubleConv(in_channels, out_channels),
+        )

     def forward(self, x):
         return self.maxpool_conv(x)

@@ -39,34 +47,29 @@ class Down(nn.Module):
 class Up(nn.Module):
     """Upscaling then double conv"""

-    def __init__(self, in_channels, out_channels, bilinear=True):
+    def __init__(self, in_channels, out_channels):
         super().__init__()

-        # if bilinear, use the normal convolutions to reduce the number of channels
-        if bilinear:
-            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
-            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
-        else:
-            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
-            self.conv = DoubleConv(in_channels, out_channels)
+        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
+        self.conv = DoubleConv(in_channels, out_channels)

     def forward(self, x1, x2):
         x1 = self.up(x1)
+
         # input is CHW
         diffY = x2.size()[2] - x1.size()[2]
         diffX = x2.size()[3] - x1.size()[3]

         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
-        # if you have padding issues, see
-        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
-        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
         x = torch.cat([x2, x1], dim=1)
+
         return self.conv(x)


 class OutConv(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(OutConv, self).__init__()
+
         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

     def forward(self, x):
diff --git a/src/unet/model.py b/src/unet/model.py
index 9fa9c0e..b1bb22c 100644
--- a/src/unet/model.py
+++ b/src/unet/model.py
@@ -4,23 +4,26 @@ from .blocks import *

 class UNet(nn.Module):
-    def __init__(self, n_channels, n_classes, bilinear=False):
+    def __init__(self, n_channels, n_classes, features=(64, 128, 256, 512)):
         super(UNet, self).__init__()
         self.n_channels = n_channels
         self.n_classes = n_classes
-        self.bilinear = bilinear

-        self.inc = DoubleConv(n_channels, 64)
-        self.down1 = Down(64, 128)
-        self.down2 = Down(128, 256)
-        self.down3 = Down(256, 512)
-        factor = 2 if bilinear else 1
-        self.down4 = Down(512, 1024 // factor)
-        self.up1 = Up(1024, 512 // factor, bilinear)
-        self.up2 = Up(512, 256 // factor, bilinear)
-        self.up3 = Up(256, 128 // factor, bilinear)
-        self.up4 = Up(128, 64, bilinear)
-        self.outc = OutConv(64, n_classes)
+        self.inc = DoubleConv(n_channels, features[0])
+
+        self.downs = nn.ModuleList()
+        for i in range(len(features) - 1):
+            self.downs.append(
+                Down(*features[i : i + 2]),
+            )
+
+        self.ups = nn.ModuleList()
+        for i in range(len(features) - 1):
+            self.ups.append(
+                Up(*features[::-1][i : i + 2]),
+            )
+
+        self.outc = OutConv(features[0], n_classes)

     def forward(self, x):
         x1 = self.inc(x)
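
A quick shape sanity check can confirm the refactored UNet still wires up end to end. This is a minimal sketch, not part of the patch: it assumes forward() (truncated in the last hunk above) was updated to iterate over the new downs/ups ModuleLists, that src/ is on the import path, and the file name smoke_test.py is hypothetical.

    # smoke_test.py -- hypothetical sanity check for the refactored UNet.
    # Assumes forward() loops over the new `downs`/`ups` ModuleLists;
    # adjust the import to your layout.
    import torch

    from unet.model import UNet

    # Default widths from the patch; any increasing chain of widths works,
    # as long as each spatial dim is divisible by 2 ** (len(features) - 1).
    net = UNet(n_channels=3, n_classes=2, features=(64, 128, 256, 512))

    x = torch.randn(1, 3, 256, 256)  # NCHW input, divisible by 2**3 = 8
    with torch.no_grad():
        logits = net(x)

    # Output keeps the spatial size and has one channel per class.
    assert logits.shape == (1, 2, 256, 256)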