refactor: removed bilinear stuff + simplified the construction of the Downs and Ups
Former-commit-id: 4c1e0a5a9fc02047b788b13d9bfc3ad7313413e3
This commit is contained in:
parent
cc4e8089ec
commit
bc2aeacfa3
|
@ -91,12 +91,6 @@ def get_args():
|
||||||
default=0.5,
|
default=0.5,
|
||||||
help="Scale factor for the input images",
|
help="Scale factor for the input images",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--bilinear",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Use bilinear upsampling",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
@ -121,7 +115,7 @@ if __name__ == "__main__":
|
||||||
in_files = args.input
|
in_files = args.input
|
||||||
out_files = get_output_filenames(args)
|
out_files = get_output_filenames(args)
|
||||||
|
|
||||||
net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear)
|
net = UNet(n_channels=3, n_classes=2)
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
logging.info(f"Loading model {args.model}")
|
logging.info(f"Loading model {args.model}")
|
||||||
|
|
17
src/train.py
17
src/train.py
|
@ -214,12 +214,6 @@ def get_args():
|
||||||
default=False,
|
default=False,
|
||||||
help="Use mixed precision",
|
help="Use mixed precision",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--bilinear",
|
|
||||||
action="store_true",
|
|
||||||
default=False,
|
|
||||||
help="Use bilinear upsampling",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--classes",
|
"--classes",
|
||||||
"-c",
|
"-c",
|
||||||
|
@ -241,13 +235,14 @@ if __name__ == "__main__":
|
||||||
# Change here to adapt to your data
|
# Change here to adapt to your data
|
||||||
# n_channels=3 for RGB images
|
# n_channels=3 for RGB images
|
||||||
# n_classes is the number of probabilities you want to get per pixel
|
# n_classes is the number of probabilities you want to get per pixel
|
||||||
net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
|
net = UNet(n_channels=3, n_classes=args.classes)
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Network:\n"
|
f"""
|
||||||
f"\t{net.n_channels} input channels\n"
|
Network:\n
|
||||||
f"\t{net.n_classes} output channels (classes)\n"
|
\t{net.n_channels} input channels\n
|
||||||
f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling'
|
\t{net.n_classes} output channels (classes)\n
|
||||||
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.load:
|
if args.load:
|
||||||
|
|
|
@ -10,12 +10,16 @@ class DoubleConv(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, mid_channels=None):
|
def __init__(self, in_channels, out_channels, mid_channels=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if not mid_channels:
|
if not mid_channels:
|
||||||
mid_channels = out_channels
|
mid_channels = out_channels
|
||||||
|
|
||||||
self.double_conv = nn.Sequential(
|
self.double_conv = nn.Sequential(
|
||||||
|
# first convolution
|
||||||
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
|
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
|
||||||
nn.BatchNorm2d(mid_channels),
|
nn.BatchNorm2d(mid_channels),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
|
# second convolution
|
||||||
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
|
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
|
||||||
nn.BatchNorm2d(out_channels),
|
nn.BatchNorm2d(out_channels),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
|
@ -30,7 +34,11 @@ class Down(nn.Module):
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels):
|
def __init__(self, in_channels, out_channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))
|
|
||||||
|
self.maxpool_conv = nn.Sequential(
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
DoubleConv(in_channels, out_channels),
|
||||||
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.maxpool_conv(x)
|
return self.maxpool_conv(x)
|
||||||
|
@ -39,34 +47,29 @@ class Down(nn.Module):
|
||||||
class Up(nn.Module):
|
class Up(nn.Module):
|
||||||
"""Upscaling then double conv"""
|
"""Upscaling then double conv"""
|
||||||
|
|
||||||
def __init__(self, in_channels, out_channels, bilinear=True):
|
def __init__(self, in_channels, out_channels):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# if bilinear, use the normal convolutions to reduce the number of channels
|
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
|
||||||
if bilinear:
|
self.conv = DoubleConv(in_channels, out_channels)
|
||||||
self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
|
|
||||||
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
|
|
||||||
else:
|
|
||||||
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
|
|
||||||
self.conv = DoubleConv(in_channels, out_channels)
|
|
||||||
|
|
||||||
def forward(self, x1, x2):
|
def forward(self, x1, x2):
|
||||||
x1 = self.up(x1)
|
x1 = self.up(x1)
|
||||||
|
|
||||||
# input is CHW
|
# input is CHW
|
||||||
diffY = x2.size()[2] - x1.size()[2]
|
diffY = x2.size()[2] - x1.size()[2]
|
||||||
diffX = x2.size()[3] - x1.size()[3]
|
diffX = x2.size()[3] - x1.size()[3]
|
||||||
|
|
||||||
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
|
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
|
||||||
# if you have padding issues, see
|
|
||||||
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
|
|
||||||
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
|
|
||||||
x = torch.cat([x2, x1], dim=1)
|
x = torch.cat([x2, x1], dim=1)
|
||||||
|
|
||||||
return self.conv(x)
|
return self.conv(x)
|
||||||
|
|
||||||
|
|
||||||
class OutConv(nn.Module):
|
class OutConv(nn.Module):
|
||||||
def __init__(self, in_channels, out_channels):
|
def __init__(self, in_channels, out_channels):
|
||||||
super(OutConv, self).__init__()
|
super(OutConv, self).__init__()
|
||||||
|
|
||||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
@ -4,23 +4,26 @@ from .blocks import *
|
||||||
|
|
||||||
|
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
def __init__(self, n_channels, n_classes, bilinear=False):
|
def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
|
||||||
super(UNet, self).__init__()
|
super(UNet, self).__init__()
|
||||||
self.n_channels = n_channels
|
self.n_channels = n_channels
|
||||||
self.n_classes = n_classes
|
self.n_classes = n_classes
|
||||||
self.bilinear = bilinear
|
|
||||||
|
|
||||||
self.inc = DoubleConv(n_channels, 64)
|
self.inc = DoubleConv(n_channels, features[0])
|
||||||
self.down1 = Down(64, 128)
|
|
||||||
self.down2 = Down(128, 256)
|
self.downs = nn.ModuleList()
|
||||||
self.down3 = Down(256, 512)
|
for i in range(len(features) - 1):
|
||||||
factor = 2 if bilinear else 1
|
self.downs.append(
|
||||||
self.down4 = Down(512, 1024 // factor)
|
Down(*features[i : i + 2]),
|
||||||
self.up1 = Up(1024, 512 // factor, bilinear)
|
)
|
||||||
self.up2 = Up(512, 256 // factor, bilinear)
|
|
||||||
self.up3 = Up(256, 128 // factor, bilinear)
|
self.ups = nn.ModuleList()
|
||||||
self.up4 = Up(128, 64, bilinear)
|
for i in range(len(features) - 1):
|
||||||
self.outc = OutConv(64, n_classes)
|
self.ups.append(
|
||||||
|
Up(*features[::-1][i : i + 2]),
|
||||||
|
)
|
||||||
|
|
||||||
|
self.outc = OutConv(features[0], n_classes)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x1 = self.inc(x)
|
x1 = self.inc(x)
|
||||||
|
|
Loading…
Reference in a new issue