Summer cleanup

Former-commit-id: f6185d67a4bc50aa7ec1b8168aab3f92721c4965
This commit is contained in:
milesial 2021-08-16 02:53:00 +02:00
parent 2f4f7edd5d
commit 063cbbc599
15 changed files with 343 additions and 391 deletions

2
.gitignore vendored
View file

@ -4,5 +4,5 @@ __pycache__/
checkpoints/
*.pth
*.jpg
SUBMISSION*
venv/
.idea/

9
Dockerfile Normal file
View file

@ -0,0 +1,9 @@
FROM nvcr.io/nvidia/pytorch:21.06-py3
RUN rm -rf /workspace/*
WORKDIR /workspace/unet
ADD requirements.txt .
RUN pip install --no-cache-dir --upgrade --pre pip
RUN pip install --no-cache-dir -r requirements.txt
ADD . .

80
data_loading.py Normal file
View file

@ -0,0 +1,80 @@
import logging
from os import listdir
from os.path import splitext
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
class BasicDataset(Dataset):
def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ''):
self.images_dir = Path(images_dir)
self.masks_dir = Path(masks_dir)
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
self.scale = scale
self.mask_suffix = mask_suffix
self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith('.')]
if not self.ids:
raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')
logging.info(f'Creating dataset with {len(self.ids)} examples')
def __len__(self):
return len(self.ids)
@classmethod
def preprocess(cls, pil_img, scale, is_mask):
w, h = pil_img.size
newW, newH = int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
pil_img = pil_img.resize((newW, newH))
img_ndarray = np.asarray(pil_img)
if img_ndarray.ndim == 2 and not is_mask:
img_ndarray = img_ndarray[np.newaxis, ...]
elif not is_mask:
img_ndarray = img_ndarray.transpose((2, 0, 1))
if not is_mask:
img_ndarray = img_ndarray / 255
return img_ndarray
@classmethod
def load(cls, filename):
ext = splitext(filename)[1]
if ext in ['.npz', '.npy']:
return Image.fromarray(np.load(filename))
elif ext in ['.pt', '.pth']:
return Image.fromarray(torch.load(filename).numpy())
else:
return Image.open(filename)
def __getitem__(self, idx):
name = self.ids[idx]
mask_file = list(self.masks_dir.glob(name + self.mask_suffix + '.*'))
img_file = list(self.images_dir.glob(name + '.*'))
assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
mask = self.load(mask_file[0])
img = self.load(img_file[0])
assert img.size == mask.size, \
'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'
img = self.preprocess(img, self.scale, is_mask=False)
mask = self.preprocess(mask, self.scale, is_mask=True)
return {
'image': torch.as_tensor(img.copy()).float().contiguous(),
'mask': torch.as_tensor(mask.copy()).long().contiguous()
}
class CarvanaDataset(BasicDataset):
def __init__(self, images_dir, masks_dir, scale=1):
super().__init__(images_dir, masks_dir, scale, mask_suffix='_mask')

View file

@ -1,42 +0,0 @@
import torch
from torch.autograd import Function
class DiceCoeff(Function):
"""Dice coeff for individual examples"""
def forward(self, input, target):
self.save_for_backward(input, target)
eps = 0.0001
self.inter = torch.dot(input.view(-1), target.view(-1))
self.union = torch.sum(input) + torch.sum(target) + eps
t = (2 * self.inter.float() + eps) / self.union.float()
return t
# This function has only a single output, so it gets only one gradient
def backward(self, grad_output):
input, target = self.saved_variables
grad_input = grad_target = None
if self.needs_input_grad[0]:
grad_input = grad_output * 2 * (target * self.union - self.inter) \
/ (self.union * self.union)
if self.needs_input_grad[1]:
grad_target = None
return grad_input, grad_target
def dice_coeff(input, target):
"""Dice coeff for batches"""
if input.is_cuda:
s = torch.FloatTensor(1).cuda().zero_()
else:
s = torch.FloatTensor(1).zero_()
for i, c in enumerate(zip(input, target)):
s = s + DiceCoeff().forward(c[0], c[1])
return s / (i + 1)

40
dice_score.py Normal file
View file

@ -0,0 +1,40 @@
import torch
from torch import Tensor
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
# Average of Dice coefficient for all batches, or for a single mask
assert input.size() == target.size()
if input.dim() == 2 and reduce_batch_first:
raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')
if input.dim() == 2 or reduce_batch_first:
inter = torch.dot(input.view(-1), target.view(-1))
sets_sum = torch.sum(input) + torch.sum(target)
if sets_sum.item() == 0:
sets_sum = 2 * inter
return (2 * inter + epsilon) / (sets_sum + epsilon)
else:
# compute and average metric for each batch element
dice = 0
for i in range(input.shape[0]):
dice += dice_coeff(input[i, ...], target[i, ...])
return dice / input.shape[0]
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
# Average of Dice coefficient for all classes
assert input.size() == target.size()
dice = 0
for channel in range(input.shape[1]):
dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
return dice / input.shape[1]
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
# Dice loss (objective to minimize) between 0 and 1
assert input.size() == target.size()
fn = multiclass_dice_coeff if multiclass else dice_coeff
return 1 - fn(input, target, reduce_batch_first=True)

33
eval.py
View file

@ -1,33 +0,0 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm
from dice_loss import dice_coeff
def eval_net(net, loader, device):
"""Evaluation without the densecrf with the dice coefficient"""
net.eval()
mask_type = torch.float32 if net.n_classes == 1 else torch.long
n_val = len(loader) # the number of batch
tot = 0
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
for batch in loader:
imgs, true_masks = batch['image'], batch['mask']
imgs = imgs.to(device=device, dtype=torch.float32)
true_masks = true_masks.to(device=device, dtype=mask_type)
with torch.no_grad():
mask_pred = net(imgs)
if net.n_classes > 1:
tot += F.cross_entropy(mask_pred, true_masks).item()
else:
pred = torch.sigmoid(mask_pred)
pred = (pred > 0.5).float()
tot += dice_coeff(pred, true_masks).item()
pbar.update()
net.train()
return tot / n_val

35
evaluate.py Normal file
View file

@ -0,0 +1,35 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm
from dice_score import multiclass_dice_coeff
def evaluate(net, dataloader, device):
net.eval()
num_val_batches = len(dataloader)
dice_score = 0
# iterate over the validation set
for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
image, mask_true = batch['image'], batch['mask']
# move images and labels to correct device and type
image = image.to(device=device, dtype=torch.float32)
mask_true = mask_true.to(device=device, dtype=torch.long)
mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
with torch.no_grad():
# predict the mask
mask_pred = net(image)
# convert to one-hot format
if net.n_classes == 1:
mask_pred = (F.sigmoid(mask_pred) > 0).float()
else:
mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
# compute the Dice score, ignoring background
dice_score += multiclass_dice_coeff(mask_pred[:, :1, ...], mask_true[:, :1, ...], reduce_batch_first=False)
net.train()
return dice_score / num_val_batches

View file

@ -8,9 +8,9 @@ import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from data_loading import BasicDataset
from unet import UNet
from utils.data_vis import plot_img_and_mask
from utils.dataset import BasicDataset
from utils import plot_img_and_mask
def predict_img(net,
@ -19,9 +19,7 @@ def predict_img(net,
scale_factor=1,
out_threshold=0.5):
net.eval()
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
img = img.unsqueeze(0)
img = img.to(device=device, dtype=torch.float32)
@ -29,94 +27,75 @@ def predict_img(net,
output = net(img)
if net.n_classes > 1:
probs = F.softmax(output, dim=1)
probs = F.softmax(output, dim=1)[0]
else:
probs = torch.sigmoid(output)
probs = torch.sigmoid(output)[0]
probs = probs.squeeze(0)
tf = transforms.Compose(
[
tf = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(full_img.size[1]),
transforms.Resize((full_img.size[1], full_img.size[0])),
transforms.ToTensor()
]
)
])
probs = tf(probs.cpu())
full_mask = probs.squeeze().cpu().numpy()
full_mask = tf(probs.cpu()).squeeze()
return full_mask > out_threshold
if net.n_classes == 1:
return (full_mask > out_threshold).numpy()
else:
return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
def get_args():
parser = argparse.ArgumentParser(description='Predict masks from input images',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--model', '-m', default='MODEL.pth',
metavar='FILE',
help="Specify the file in which the model is stored")
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
help='filenames of input images', required=True)
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
help='Filenames of ouput images')
parser = argparse.ArgumentParser(description='Predict masks from input images')
parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE',
help='Specify the file in which the model is stored')
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True)
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+', help='Filenames of output images')
parser.add_argument('--viz', '-v', action='store_true',
help="Visualize the images as they are processed",
default=False)
parser.add_argument('--no-save', '-n', action='store_true',
help="Do not save the output masks",
default=False)
parser.add_argument('--mask-threshold', '-t', type=float,
help="Minimum probability value to consider a mask pixel white",
default=0.5)
parser.add_argument('--scale', '-s', type=float,
help="Scale factor for the input images",
default=0.5)
help='Visualize the images as they are processed')
parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
help='Minimum probability value to consider a mask pixel white')
parser.add_argument('--scale', '-s', type=float, default=0.5,
help='Scale factor for the input images')
return parser.parse_args()
def get_output_filenames(args):
in_files = args.input
out_files = []
def _generate_name(fn):
split = os.path.splitext(fn)
return f'{split[0]}_OUT{split[1]}'
if not args.output:
for f in in_files:
pathsplit = os.path.splitext(f)
out_files.append("{}_OUT{}".format(pathsplit[0], pathsplit[1]))
elif len(in_files) != len(args.output):
logging.error("Input files and output files are not of the same length")
raise SystemExit()
else:
out_files = args.output
return out_files
return args.output or list(map(_generate_name, args.input))
def mask_to_image(mask):
def mask_to_image(mask: np.ndarray):
if mask.ndim == 2:
return Image.fromarray((mask * 255).astype(np.uint8))
elif mask.ndim == 3:
return Image.fromarray((np.argmax(mask, dim=0) * 255 / mask.shape[0]).astype(np.uint8))
if __name__ == "__main__":
if __name__ == '__main__':
args = get_args()
in_files = args.input
out_files = get_output_filenames(args)
net = UNet(n_channels=3, n_classes=1)
logging.info("Loading model {}".format(args.model))
net = UNet(n_channels=3, n_classes=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Loading model {args.model}')
logging.info(f'Using device {device}')
net.to(device=device)
net.load_state_dict(torch.load(args.model, map_location=device))
logging.info("Model loaded !")
logging.info('Model loaded!')
for i, fn in enumerate(in_files):
logging.info("\nPredicting image {} ...".format(fn))
img = Image.open(fn)
for i, filename in enumerate(in_files):
logging.info(f'\nPredicting image {filename} ...')
img = Image.open(filename)
mask = predict_img(net=net,
full_img=img,
@ -125,12 +104,11 @@ if __name__ == "__main__":
device=device)
if not args.no_save:
out_fn = out_files[i]
out_filename = out_files[i]
result = mask_to_image(mask)
result.save(out_files[i])
logging.info("Mask saved to {}".format(out_files[i]))
result.save(out_filename)
logging.info(f'Mask saved to {out_filename}')
if args.viz:
logging.info("Visualizing results for image {}, close to continue ...".format(fn))
logging.info(f'Visualizing results for image {filename}, close to continue...')
plot_img_and_mask(img, mask)

View file

@ -3,6 +3,5 @@ numpy
Pillow
torch
torchvision
tensorboard
future
tqdm
wandb

View file

@ -1,46 +0,0 @@
""" Submit code specific to the kaggle challenge"""
import os
import torch
from PIL import Image
import numpy as np
from predict import predict_img
from unet import UNet
# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
def rle_encode(mask_image):
pixels = mask_image.flatten()
# We avoid issues with '1' at the start or end (at the corners of
# the original image) by setting those pixels to '0' explicitly.
# We do not expect these to be non-zero for an accurate mask,
# so this should not harm the score.
pixels[0] = 0
pixels[-1] = 0
runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
runs[1::2] = runs[1::2] - runs[:-1:2]
return runs
def submit(net):
"""Used for Kaggle submission: predicts and encode all test images"""
dir = 'data/test/'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
N = len(list(os.listdir(dir)))
with open('SUBMISSION.csv', 'a') as f:
f.write('img,rle_mask\n')
for index, i in enumerate(os.listdir(dir)):
print('{}/{}'.format(index, N))
img = Image.open(dir + i)
mask = predict_img(net, img, device)
enc = rle_encode(mask)
f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
if __name__ == '__main__':
net = UNet(3, 1).cuda()
net.load_state_dict(torch.load('MODEL.pth'))
submit(net)

212
train.py
View file

@ -1,187 +1,193 @@
import argparse
import logging
import os
import sys
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from eval import eval_net
from data_loading import BasicDataset, CarvanaDataset
from dice_score import dice_loss
from evaluate import evaluate
from unet import UNet
from torch.utils.tensorboard import SummaryWriter
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split
dir_img = 'data/imgs/'
dir_mask = 'data/masks/'
dir_checkpoint = 'checkpoints/'
dir_img = Path('./data/imgs/')
dir_mask = Path('./data/masks/')
dir_checkpoint = Path('./checkpoints/')
def train_net(net,
device,
epochs=5,
batch_size=1,
lr=0.001,
val_percent=0.1,
save_cp=True,
img_scale=0.5):
epochs: int = 5,
batch_size: int = 1,
learning_rate: float = 0.001,
val_percent: float = 0.1,
save_checkpoint: bool = True,
img_scale: float = 0.5,
amp: bool = False):
# 1. Create dataset
try:
dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
except (AssertionError, RuntimeError):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
# 2. Split into train / validation partitions
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))
writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
global_step = 0
# 3. Create data loaders
loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
# (Initialise logging)
experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
amp=amp))
logging.info(f'''Starting training:
Epochs: {epochs}
Batch size: {batch_size}
Learning rate: {lr}
Learning rate: {learning_rate}
Training size: {n_train}
Validation size: {n_val}
Checkpoints: {save_cp}
Checkpoints: {save_checkpoint}
Device: {device.type}
Images scaling: {img_scale}
Mixed Precision: {amp}
''')
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
if net.n_classes > 1:
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2) # goal: maximize Dice score
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
criterion = nn.CrossEntropyLoss()
else:
criterion = nn.BCEWithLogitsLoss()
global_step = 0
# 5. Begin training
for epoch in range(epochs):
net.train()
epoch_loss = 0
with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
for batch in train_loader:
imgs = batch['image']
images = batch['image']
true_masks = batch['mask']
assert imgs.shape[1] == net.n_channels, \
assert images.shape[1] == net.n_channels, \
f'Network has been defined with {net.n_channels} input channels, ' \
f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
f'but loaded images have {images.shape[1]} channels. Please check that ' \
'the images are loaded correctly.'
imgs = imgs.to(device=device, dtype=torch.float32)
mask_type = torch.float32 if net.n_classes == 1 else torch.long
true_masks = true_masks.to(device=device, dtype=mask_type)
images = images.to(device=device, dtype=torch.float32)
true_masks = true_masks.to(device=device, dtype=torch.long)
masks_pred = net(imgs)
loss = criterion(masks_pred, true_masks)
with torch.cuda.amp.autocast(enabled=amp):
masks_pred = net(images)
loss = criterion(masks_pred, true_masks) \
+ dice_loss(F.softmax(masks_pred, dim=1).float(),
F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
multiclass=True)
optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(loss).backward()
grad_scaler.step(optimizer)
grad_scaler.update()
pbar.update(images.shape[0])
global_step += 1
epoch_loss += loss.item()
writer.add_scalar('Loss/train', loss.item(), global_step)
experiment.log({
'train loss': loss.item(),
'step': global_step,
'epoch': epoch
})
pbar.set_postfix(**{'loss (batch)': loss.item()})
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_value_(net.parameters(), 0.1)
optimizer.step()
pbar.update(imgs.shape[0])
global_step += 1
# Evaluation round
if global_step % (n_train // (10 * batch_size)) == 0:
histograms = {}
for tag, value in net.named_parameters():
tag = tag.replace('.', '/')
writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
val_score = eval_net(net, val_loader, device)
tag = tag.replace('/', '.')
histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
val_score = evaluate(net, val_loader, device)
scheduler.step(val_score)
writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
if net.n_classes > 1:
logging.info('Validation cross entropy: {}'.format(val_score))
writer.add_scalar('Loss/test', val_score, global_step)
else:
logging.info('Validation Dice Coeff: {}'.format(val_score))
writer.add_scalar('Dice/test', val_score, global_step)
logging.info('Validation Dice score: {}'.format(val_score))
experiment.log({
'learning rate': optimizer.param_groups[0]['lr'],
'validation Dice': val_score,
'images': wandb.Image(images[0].cpu()),
'masks': {
'true': wandb.Image(true_masks[0].float().cpu()),
'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
},
'step': global_step,
'epoch': epoch,
**histograms
})
writer.add_images('images', imgs, global_step)
if net.n_classes == 1:
writer.add_images('masks/true', true_masks, global_step)
writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
if save_cp:
try:
os.mkdir(dir_checkpoint)
logging.info('Created checkpoint directory')
except OSError:
pass
torch.save(net.state_dict(),
dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
logging.info(f'Checkpoint {epoch + 1} saved !')
writer.close()
if save_checkpoint:
Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
logging.info(f'Checkpoint {epoch + 1} saved!')
def get_args():
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
help='Number of epochs', dest='epochs')
parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
help='Batch size', dest='batchsize')
parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.0001,
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
parser.add_argument('--epochs', '-e', metavar='E', type=int, default=50, help='Number of epochs')
parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.00001,
help='Learning rate', dest='lr')
parser.add_argument('-f', '--load', dest='load', type=str, default=False,
help='Load model from a .pth file')
parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
help='Downscaling factor of the images')
parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
help='Percent of the data that is used as validation (0-100)')
parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
return parser.parse_args()
if __name__ == '__main__':
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
args = get_args()
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
# - For 1 class and background, use n_classes=1
# - For 2 classes, use n_classes=1
# - For N > 2 classes, use n_classes=N
net = UNet(n_channels=3, n_classes=1, bilinear=True)
net = UNet(n_channels=3, n_classes=2, bilinear=True)
logging.info(f'Network:\n'
f'\t{net.n_channels} input channels\n'
f'\t{net.n_classes} output channels (classes)\n'
f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')
if args.load:
net.load_state_dict(
torch.load(args.load, map_location=device)
)
net.load_state_dict(torch.load(args.load, map_location=device))
logging.info(f'Model loaded from {args.load}')
net.to(device=device)
# faster convolutions, but more memory
# cudnn.benchmark = True
try:
train_net(net=net,
epochs=args.epochs,
batch_size=args.batchsize,
lr=args.lr,
batch_size=args.batch_size,
learning_rate=args.lr,
device=device,
img_scale=args.scale,
val_percent=args.val / 100)
val_percent=args.val / 100,
amp=args.amp)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
try:
sys.exit(0)
except SystemExit:
os._exit(0)

View file

@ -1,7 +1,5 @@
""" Full assembly of the parts to form the complete network """
import torch.nn.functional as F
from .unet_parts import *

View file

@ -50,10 +50,9 @@ class Up(nn.Module):
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW

View file

@ -2,14 +2,14 @@ import matplotlib.pyplot as plt
def plot_img_and_mask(img, mask):
classes = mask.shape[2] if len(mask.shape) > 2 else 1
classes = mask.shape[0] if len(mask.shape) > 2 else 1
fig, ax = plt.subplots(1, classes + 1)
ax[0].set_title('Input image')
ax[0].imshow(img)
if classes > 1:
for i in range(classes):
ax[i+1].set_title(f'Output mask (class {i+1})')
ax[i+1].imshow(mask[:, :, i])
ax[i + 1].set_title(f'Output mask (class {i + 1})')
ax[i + 1].imshow(mask[:, :, i])
else:
ax[1].set_title(f'Output mask')
ax[1].imshow(mask)

View file

@ -1,71 +0,0 @@
from os.path import splitext
from os import listdir
import numpy as np
from glob import glob
import torch
from torch.utils.data import Dataset
import logging
from PIL import Image
class BasicDataset(Dataset):
def __init__(self, imgs_dir, masks_dir, scale=1, mask_suffix=''):
self.imgs_dir = imgs_dir
self.masks_dir = masks_dir
self.scale = scale
self.mask_suffix = mask_suffix
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
if not file.startswith('.')]
logging.info(f'Creating dataset with {len(self.ids)} examples')
def __len__(self):
return len(self.ids)
@classmethod
def preprocess(cls, pil_img, scale):
w, h = pil_img.size
newW, newH = int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, 'Scale is too small'
pil_img = pil_img.resize((newW, newH))
img_nd = np.array(pil_img)
if len(img_nd.shape) == 2:
img_nd = np.expand_dims(img_nd, axis=2)
# HWC to CHW
img_trans = img_nd.transpose((2, 0, 1))
if img_trans.max() > 1:
img_trans = img_trans / 255
return img_trans
def __getitem__(self, i):
idx = self.ids[i]
mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*')
img_file = glob(self.imgs_dir + idx + '.*')
assert len(mask_file) == 1, \
f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
assert len(img_file) == 1, \
f'Either no image or multiple images found for the ID {idx}: {img_file}'
mask = Image.open(mask_file[0])
img = Image.open(img_file[0])
assert img.size == mask.size, \
f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
img = self.preprocess(img, self.scale)
mask = self.preprocess(mask, self.scale)
return {
'image': torch.from_numpy(img).type(torch.FloatTensor),
'mask': torch.from_numpy(mask).type(torch.FloatTensor)
}
class CarvanaDataset(BasicDataset):
def __init__(self, imgs_dir, masks_dir, scale=1):
super().__init__(imgs_dir, masks_dir, scale, mask_suffix='_mask')