refactor: removed bilinear stuff + simplified the construction of the Downs and Ups

Former-commit-id: 4c1e0a5a9fc02047b788b13d9bfc3ad7313413e3
This commit is contained in:
Your Name 2022-06-27 16:13:38 +02:00
parent cc4e8089ec
commit bc2aeacfa3
4 changed files with 38 additions and 43 deletions

View file

@ -91,12 +91,6 @@ def get_args():
default=0.5, default=0.5,
help="Scale factor for the input images", help="Scale factor for the input images",
) )
parser.add_argument(
"--bilinear",
action="store_true",
default=False,
help="Use bilinear upsampling",
)
return parser.parse_args() return parser.parse_args()
@ -121,7 +115,7 @@ if __name__ == "__main__":
in_files = args.input in_files = args.input
out_files = get_output_filenames(args) out_files = get_output_filenames(args)
net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear) net = UNet(n_channels=3, n_classes=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Loading model {args.model}") logging.info(f"Loading model {args.model}")

View file

@ -214,12 +214,6 @@ def get_args():
default=False, default=False,
help="Use mixed precision", help="Use mixed precision",
) )
parser.add_argument(
"--bilinear",
action="store_true",
default=False,
help="Use bilinear upsampling",
)
parser.add_argument( parser.add_argument(
"--classes", "--classes",
"-c", "-c",
@ -241,13 +235,14 @@ if __name__ == "__main__":
# Change here to adapt to your data # Change here to adapt to your data
# n_channels=3 for RGB images # n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel # n_classes is the number of probabilities you want to get per pixel
net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear) net = UNet(n_channels=3, n_classes=args.classes)
logging.info( logging.info(
f"Network:\n" f"""
f"\t{net.n_channels} input channels\n" Network:\n
f"\t{net.n_classes} output channels (classes)\n" \t{net.n_channels} input channels\n
f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling' \t{net.n_classes} output channels (classes)\n
"""
) )
if args.load: if args.load:

View file

@ -10,12 +10,16 @@ class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels, mid_channels=None): def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__() super().__init__()
if not mid_channels: if not mid_channels:
mid_channels = out_channels mid_channels = out_channels
self.double_conv = nn.Sequential( self.double_conv = nn.Sequential(
# first convolution
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels), nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
# second convolution
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels), nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
@ -30,7 +34,11 @@ class Down(nn.Module):
def __init__(self, in_channels, out_channels): def __init__(self, in_channels, out_channels):
super().__init__() super().__init__()
self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels),
)
def forward(self, x): def forward(self, x):
return self.maxpool_conv(x) return self.maxpool_conv(x)
@ -39,34 +47,29 @@ class Down(nn.Module):
class Up(nn.Module): class Up(nn.Module):
"""Upscaling then double conv""" """Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True): def __init__(self, in_channels, out_channels):
super().__init__() super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
if bilinear: self.conv = DoubleConv(in_channels, out_channels)
self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2): def forward(self, x1, x2):
x1 = self.up(x1) x1 = self.up(x1)
# input is CHW # input is CHW
diffY = x2.size()[2] - x1.size()[2] diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3] diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2]) x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1) x = torch.cat([x2, x1], dim=1)
return self.conv(x) return self.conv(x)
class OutConv(nn.Module): class OutConv(nn.Module):
def __init__(self, in_channels, out_channels): def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__() super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x): def forward(self, x):

View file

@ -4,23 +4,26 @@ from .blocks import *
class UNet(nn.Module): class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=False): def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
super(UNet, self).__init__() super(UNet, self).__init__()
self.n_channels = n_channels self.n_channels = n_channels
self.n_classes = n_classes self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64) self.inc = DoubleConv(n_channels, features[0])
self.down1 = Down(64, 128)
self.down2 = Down(128, 256) self.downs = nn.ModuleList()
self.down3 = Down(256, 512) for i in range(len(features) - 1):
factor = 2 if bilinear else 1 self.downs.append(
self.down4 = Down(512, 1024 // factor) Down(*features[i : i + 2]),
self.up1 = Up(1024, 512 // factor, bilinear) )
self.up2 = Up(512, 256 // factor, bilinear)
self.up3 = Up(256, 128 // factor, bilinear) self.ups = nn.ModuleList()
self.up4 = Up(128, 64, bilinear) for i in range(len(features) - 1):
self.outc = OutConv(64, n_classes) self.ups.append(
Up(*features[::-1][i : i + 2]),
)
self.outc = OutConv(features[0], n_classes)
def forward(self, x): def forward(self, x):
x1 = self.inc(x) x1 = self.inc(x)