refactor: remove the bilinear upsampling option and simplify construction of the Down and Up blocks
Former-commit-id: 4c1e0a5a9fc02047b788b13d9bfc3ad7313413e3
commit bc2aeacfa3
parent cc4e8089ec
@@ -91,12 +91,6 @@ def get_args():
         default=0.5,
         help="Scale factor for the input images",
     )
-    parser.add_argument(
-        "--bilinear",
-        action="store_true",
-        default=False,
-        help="Use bilinear upsampling",
-    )
 
     return parser.parse_args()
 
@@ -121,7 +115,7 @@ if __name__ == "__main__":
     in_files = args.input
     out_files = get_output_filenames(args)
 
-    net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear)
+    net = UNet(n_channels=3, n_classes=2)
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     logging.info(f"Loading model {args.model}")
src/train.py (17 changed lines)
@@ -214,12 +214,6 @@ def get_args():
         default=False,
         help="Use mixed precision",
     )
-    parser.add_argument(
-        "--bilinear",
-        action="store_true",
-        default=False,
-        help="Use bilinear upsampling",
-    )
     parser.add_argument(
         "--classes",
         "-c",
@@ -241,13 +235,14 @@ if __name__ == "__main__":
     # Change here to adapt to your data
     # n_channels=3 for RGB images
     # n_classes is the number of probabilities you want to get per pixel
-    net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
+    net = UNet(n_channels=3, n_classes=args.classes)
 
     logging.info(
-        f"Network:\n"
-        f"\t{net.n_channels} input channels\n"
-        f"\t{net.n_classes} output channels (classes)\n"
-        f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling'
+        f"""
+        Network:\n
+        \t{net.n_channels} input channels\n
+        \t{net.n_classes} output channels (classes)\n
+        """
     )
 
     if args.load:
@@ -10,12 +10,16 @@ class DoubleConv(nn.Module):
 
     def __init__(self, in_channels, out_channels, mid_channels=None):
         super().__init__()
+
         if not mid_channels:
             mid_channels = out_channels
+
         self.double_conv = nn.Sequential(
+            # first convolution
             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
             nn.BatchNorm2d(mid_channels),
             nn.ReLU(inplace=True),
+            # second convolution
             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
             nn.BatchNorm2d(out_channels),
             nn.ReLU(inplace=True),
@@ -30,7 +34,11 @@ class Down(nn.Module):
 
     def __init__(self, in_channels, out_channels):
         super().__init__()
-        self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))
+
+        self.maxpool_conv = nn.Sequential(
+            nn.MaxPool2d(2),
+            DoubleConv(in_channels, out_channels),
+        )
 
     def forward(self, x):
         return self.maxpool_conv(x)
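For context, a minimal shape check of the restructured Down block (a sketch, not part of the diff; it assumes the block definitions in this file are importable and the tensor sizes are illustrative):

import torch

# Hypothetical smoke test: Down(64, 128) should halve the spatial resolution
# (MaxPool2d(2)) and map 64 -> 128 channels (DoubleConv, padding=1 keeps H and W).
down = Down(in_channels=64, out_channels=128)
x = torch.randn(1, 64, 64, 64)   # (batch, channels, height, width)
print(down(x).shape)             # expected: torch.Size([1, 128, 32, 32])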
@@ -39,34 +47,29 @@ class Down(nn.Module):
 class Up(nn.Module):
     """Upscaling then double conv"""
 
-    def __init__(self, in_channels, out_channels, bilinear=True):
+    def __init__(self, in_channels, out_channels):
         super().__init__()
 
-        # if bilinear, use the normal convolutions to reduce the number of channels
-        if bilinear:
-            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
-            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
-        else:
-            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
-            self.conv = DoubleConv(in_channels, out_channels)
+        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
+        self.conv = DoubleConv(in_channels, out_channels)
 
     def forward(self, x1, x2):
         x1 = self.up(x1)
 
         # input is CHW
         diffY = x2.size()[2] - x1.size()[2]
         diffX = x2.size()[3] - x1.size()[3]
 
         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
         # if you have padding issues, see
         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
         x = torch.cat([x2, x1], dim=1)
 
         return self.conv(x)
 
 
 class OutConv(nn.Module):
     def __init__(self, in_channels, out_channels):
         super(OutConv, self).__init__()
 
         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
 
     def forward(self, x):
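For reference, a minimal shape check of the refactored Up block (a sketch, not part of the diff; it assumes the definitions above and uses illustrative tensor sizes):

import torch

# Hypothetical smoke test: Up(512, 256) applied to a decoder feature map and its skip connection.
up = Up(in_channels=512, out_channels=256)
x1 = torch.randn(1, 512, 16, 16)   # feature map to be upsampled
x2 = torch.randn(1, 256, 32, 32)   # skip connection from the encoder
# ConvTranspose2d: 512 -> 256 channels at 32x32, concat with skip -> 512, DoubleConv -> 256.
print(up(x1, x2).shape)            # expected: torch.Size([1, 256, 32, 32])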
@@ -4,23 +4,26 @@ from .blocks import *
 
 
 class UNet(nn.Module):
-    def __init__(self, n_channels, n_classes, bilinear=False):
+    def __init__(self, n_channels, n_classes, features=[64, 128, 256, 512]):
         super(UNet, self).__init__()
         self.n_channels = n_channels
         self.n_classes = n_classes
-        self.bilinear = bilinear
 
-        self.inc = DoubleConv(n_channels, 64)
-        self.down1 = Down(64, 128)
-        self.down2 = Down(128, 256)
-        self.down3 = Down(256, 512)
-        factor = 2 if bilinear else 1
-        self.down4 = Down(512, 1024 // factor)
-        self.up1 = Up(1024, 512 // factor, bilinear)
-        self.up2 = Up(512, 256 // factor, bilinear)
-        self.up3 = Up(256, 128 // factor, bilinear)
-        self.up4 = Up(128, 64, bilinear)
-        self.outc = OutConv(64, n_classes)
+        self.inc = DoubleConv(n_channels, features[0])
+
+        self.downs = nn.ModuleList()
+        for i in range(len(features) - 1):
+            self.downs.append(
+                Down(*features[i : i + 2]),
+            )
+
+        self.ups = nn.ModuleList()
+        for i in range(len(features) - 1):
+            self.ups.append(
+                Up(*features[::-1][i : i + 2]),
+            )
+
+        self.outc = OutConv(features[0], n_classes)
 
     def forward(self, x):
         x1 = self.inc(x)
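As a sanity check on the new constructor, the slicing in the two loops yields the following (in_channels, out_channels) pairs for the default features=[64, 128, 256, 512] (a sketch, not part of the diff):

features = [64, 128, 256, 512]

# Pairs consumed by Down(*...) and Up(*...) in the constructor loops above.
down_pairs = [tuple(features[i : i + 2]) for i in range(len(features) - 1)]
up_pairs = [tuple(features[::-1][i : i + 2]) for i in range(len(features) - 1)]

print(down_pairs)  # [(64, 128), (128, 256), (256, 512)]
print(up_pairs)    # [(512, 256), (256, 128), (128, 64)]

Each Up pair satisfies in_channels = 2 * out_channels, which matches the refactored Up block: the transposed convolution halves the channels and the concatenated skip connection restores them to in_channels before DoubleConv.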