Summer cleanup

Former-commit-id: f6185d67a4bc50aa7ec1b8168aab3f92721c4965
Author: milesial, 2021-08-16 02:53:00 +02:00
Parent: 2f4f7edd5d
Commit: 063cbbc599
15 changed files with 343 additions and 391 deletions

.gitignore (vendored, 2 changes)

@@ -4,5 +4,5 @@ __pycache__/
 checkpoints/
 *.pth
 *.jpg
-SUBMISSION*
 venv/
+.idea/

Dockerfile (new file, 9 lines)

@@ -0,0 +1,9 @@
FROM nvcr.io/nvidia/pytorch:21.06-py3
RUN rm -rf /workspace/*
WORKDIR /workspace/unet
ADD requirements.txt .
RUN pip install --no-cache-dir --upgrade --pre pip
RUN pip install --no-cache-dir -r requirements.txt
ADD . .

data_loading.py (new file, 80 lines)

@@ -0,0 +1,80 @@
import logging
from os import listdir
from os.path import splitext
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
class BasicDataset(Dataset):
def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ''):
self.images_dir = Path(images_dir)
self.masks_dir = Path(masks_dir)
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
self.scale = scale
self.mask_suffix = mask_suffix
self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith('.')]
if not self.ids:
raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')
logging.info(f'Creating dataset with {len(self.ids)} examples')
def __len__(self):
return len(self.ids)
@classmethod
def preprocess(cls, pil_img, scale, is_mask):
w, h = pil_img.size
newW, newH = int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
pil_img = pil_img.resize((newW, newH))
img_ndarray = np.asarray(pil_img)
if img_ndarray.ndim == 2 and not is_mask:
img_ndarray = img_ndarray[np.newaxis, ...]
elif not is_mask:
img_ndarray = img_ndarray.transpose((2, 0, 1))
if not is_mask:
img_ndarray = img_ndarray / 255
return img_ndarray
@classmethod
def load(cls, filename):
ext = splitext(filename)[1]
if ext in ['.npz', '.npy']:
return Image.fromarray(np.load(filename))
elif ext in ['.pt', '.pth']:
return Image.fromarray(torch.load(filename).numpy())
else:
return Image.open(filename)
def __getitem__(self, idx):
name = self.ids[idx]
mask_file = list(self.masks_dir.glob(name + self.mask_suffix + '.*'))
img_file = list(self.images_dir.glob(name + '.*'))
assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
mask = self.load(mask_file[0])
img = self.load(img_file[0])
assert img.size == mask.size, \
f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'
img = self.preprocess(img, self.scale, is_mask=False)
mask = self.preprocess(mask, self.scale, is_mask=True)
return {
'image': torch.as_tensor(img.copy()).float().contiguous(),
'mask': torch.as_tensor(mask.copy()).long().contiguous()
}
class CarvanaDataset(BasicDataset):
def __init__(self, images_dir, masks_dir, scale=1):
super().__init__(images_dir, masks_dir, scale, mask_suffix='_mask')
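
A minimal usage sketch for the new dataset module (the directory layout mirrors the defaults used later in train.py; the '_mask' suffix comes from CarvanaDataset, the exact file extensions are an assumption):

from torch.utils.data import DataLoader
from data_loading import CarvanaDataset

# Carvana-style layout: image 'x.jpg' in ./data/imgs, mask 'x_mask.gif' in ./data/masks
dataset = CarvanaDataset('./data/imgs', './data/masks', scale=0.5)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)
batch = next(iter(loader))
print(batch['image'].shape, batch['mask'].shape)  # float images (N, C, H, W), long masks (N, H, W)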

dice_loss.py (deleted, 42 lines)

@@ -1,42 +0,0 @@
import torch
from torch.autograd import Function
class DiceCoeff(Function):
"""Dice coeff for individual examples"""
def forward(self, input, target):
self.save_for_backward(input, target)
eps = 0.0001
self.inter = torch.dot(input.view(-1), target.view(-1))
self.union = torch.sum(input) + torch.sum(target) + eps
t = (2 * self.inter.float() + eps) / self.union.float()
return t
# This function has only a single output, so it gets only one gradient
def backward(self, grad_output):
input, target = self.saved_variables
grad_input = grad_target = None
if self.needs_input_grad[0]:
grad_input = grad_output * 2 * (target * self.union - self.inter) \
/ (self.union * self.union)
if self.needs_input_grad[1]:
grad_target = None
return grad_input, grad_target
def dice_coeff(input, target):
"""Dice coeff for batches"""
if input.is_cuda:
s = torch.FloatTensor(1).cuda().zero_()
else:
s = torch.FloatTensor(1).zero_()
for i, c in enumerate(zip(input, target)):
s = s + DiceCoeff().forward(c[0], c[1])
return s / (i + 1)

dice_score.py (new file, 40 lines)

@@ -0,0 +1,40 @@
import torch
from torch import Tensor
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
# Average of Dice coefficient for all batches, or for a single mask
assert input.size() == target.size()
if input.dim() == 2 and reduce_batch_first:
raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')
if input.dim() == 2 or reduce_batch_first:
inter = torch.dot(input.view(-1), target.view(-1))
sets_sum = torch.sum(input) + torch.sum(target)
if sets_sum.item() == 0:
sets_sum = 2 * inter
return (2 * inter + epsilon) / (sets_sum + epsilon)
else:
# compute and average metric for each batch element
dice = 0
for i in range(input.shape[0]):
dice += dice_coeff(input[i, ...], target[i, ...])
return dice / input.shape[0]
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
# Average of Dice coefficient for all classes
assert input.size() == target.size()
dice = 0
for channel in range(input.shape[1]):
dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
return dice / input.shape[1]
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
# Dice loss (objective to minimize) between 0 and 1
assert input.size() == target.size()
fn = multiclass_dice_coeff if multiclass else dice_coeff
return 1 - fn(input, target, reduce_batch_first=True)
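
A quick sanity check of the new Dice helpers, with tensors small enough to verify by hand (toy values, not repository data):

import torch
import torch.nn.functional as F
from dice_score import dice_coeff, dice_loss

pred = torch.tensor([[1., 1.], [0., 0.]])    # 2x2 binary mask, two foreground pixels
target = torch.tensor([[1., 0.], [0., 0.]])  # one foreground pixel, one overlapping
# inter = 1, sets_sum = 2 + 1 = 3, so Dice = 2*1/3
print(dice_coeff(pred, target))  # tensor(0.6667)

# A perfect one-hot prediction gives zero loss
one_hot = F.one_hot(torch.tensor([[[0, 1]]]), num_classes=2).permute(0, 3, 1, 2).float()
print(dice_loss(one_hot, one_hot, multiclass=True))  # tensor(0.)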

eval.py (deleted, 33 lines)

@@ -1,33 +0,0 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm
from dice_loss import dice_coeff
def eval_net(net, loader, device):
"""Evaluation without the densecrf with the dice coefficient"""
net.eval()
mask_type = torch.float32 if net.n_classes == 1 else torch.long
n_val = len(loader) # the number of batch
tot = 0
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
for batch in loader:
imgs, true_masks = batch['image'], batch['mask']
imgs = imgs.to(device=device, dtype=torch.float32)
true_masks = true_masks.to(device=device, dtype=mask_type)
with torch.no_grad():
mask_pred = net(imgs)
if net.n_classes > 1:
tot += F.cross_entropy(mask_pred, true_masks).item()
else:
pred = torch.sigmoid(mask_pred)
pred = (pred > 0.5).float()
tot += dice_coeff(pred, true_masks).item()
pbar.update()
net.train()
return tot / n_val

evaluate.py (new file, 35 lines)

@@ -0,0 +1,35 @@
import torch
import torch.nn.functional as F
from tqdm import tqdm
from dice_score import multiclass_dice_coeff
def evaluate(net, dataloader, device):
net.eval()
num_val_batches = len(dataloader)
dice_score = 0
# iterate over the validation set
for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
image, mask_true = batch['image'], batch['mask']
# move images and labels to correct device and type
image = image.to(device=device, dtype=torch.float32)
mask_true = mask_true.to(device=device, dtype=torch.long)
mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
with torch.no_grad():
# predict the mask
mask_pred = net(image)
# convert to one-hot format
if net.n_classes == 1:
mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
else:
mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
# compute the Dice score, ignoring background
dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False)
net.train()
return dice_score / num_val_batches
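
The F.one_hot(...).permute(0, 3, 1, 2) pattern above converts (N, H, W) class-index masks into the (N, C, H, W) layout of the network output; a toy illustration (not repository data):

import torch
import torch.nn.functional as F

mask = torch.tensor([[[0, 1], [1, 0]]])                    # (N=1, H=2, W=2) class indices
one_hot = F.one_hot(mask, 2).permute(0, 3, 1, 2).float()   # (1, 2, 2, 2): one channel per class
print(one_hot[0, 1])                                       # binary map of class 1: [[0., 1.], [1., 0.]]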

predict.py

@@ -8,9 +8,9 @@ import torch.nn.functional as F
 from PIL import Image
 from torchvision import transforms

+from data_loading import BasicDataset
 from unet import UNet
-from utils.data_vis import plot_img_and_mask
-from utils.dataset import BasicDataset
+from utils import plot_img_and_mask


 def predict_img(net,
@@ -19,9 +19,7 @@ def predict_img(net,
                 scale_factor=1,
                 out_threshold=0.5):
     net.eval()
-    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
+    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
     img = img.unsqueeze(0)
     img = img.to(device=device, dtype=torch.float32)

@@ -29,94 +27,75 @@ def predict_img(net,
         output = net(img)

         if net.n_classes > 1:
-            probs = F.softmax(output, dim=1)
+            probs = F.softmax(output, dim=1)[0]
         else:
-            probs = torch.sigmoid(output)
+            probs = torch.sigmoid(output)[0]

-        probs = probs.squeeze(0)
-        tf = transforms.Compose(
-            [
-                transforms.ToPILImage(),
-                transforms.Resize(full_img.size[1]),
-                transforms.ToTensor()
-            ]
-        )
+        tf = transforms.Compose([
+            transforms.ToPILImage(),
+            transforms.Resize((full_img.size[1], full_img.size[0])),
+            transforms.ToTensor()
+        ])

-        probs = tf(probs.cpu())
-        full_mask = probs.squeeze().cpu().numpy()
-    return full_mask > out_threshold
+        full_mask = tf(probs.cpu()).squeeze()

+    if net.n_classes == 1:
+        return (full_mask > out_threshold).numpy()
+    else:
+        return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()


 def get_args():
-    parser = argparse.ArgumentParser(description='Predict masks from input images',
-                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--model', '-m', default='MODEL.pth',
-                        metavar='FILE',
-                        help="Specify the file in which the model is stored")
-    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
-                        help='filenames of input images', required=True)
-    parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
-                        help='Filenames of ouput images')
-    parser.add_argument('--viz', '-v', action='store_true',
-                        help="Visualize the images as they are processed",
-                        default=False)
-    parser.add_argument('--no-save', '-n', action='store_true',
-                        help="Do not save the output masks",
-                        default=False)
-    parser.add_argument('--mask-threshold', '-t', type=float,
-                        help="Minimum probability value to consider a mask pixel white",
-                        default=0.5)
-    parser.add_argument('--scale', '-s', type=float,
-                        help="Scale factor for the input images",
-                        default=0.5)
+    parser = argparse.ArgumentParser(description='Predict masks from input images')
+    parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE',
+                        help='Specify the file in which the model is stored')
+    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True)
+    parser.add_argument('--output', '-o', metavar='INPUT', nargs='+', help='Filenames of output images')
+    parser.add_argument('--viz', '-v', action='store_true',
+                        help='Visualize the images as they are processed')
+    parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
+    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
+                        help='Minimum probability value to consider a mask pixel white')
+    parser.add_argument('--scale', '-s', type=float, default=0.5,
+                        help='Scale factor for the input images')
     return parser.parse_args()


 def get_output_filenames(args):
-    in_files = args.input
-    out_files = []
-
-    if not args.output:
-        for f in in_files:
-            pathsplit = os.path.splitext(f)
-            out_files.append("{}_OUT{}".format(pathsplit[0], pathsplit[1]))
-    elif len(in_files) != len(args.output):
-        logging.error("Input files and output files are not of the same length")
-        raise SystemExit()
-    else:
-        out_files = args.output
-    return out_files
+    def _generate_name(fn):
+        split = os.path.splitext(fn)
+        return f'{split[0]}_OUT{split[1]}'
+
+    return args.output or list(map(_generate_name, args.input))


-def mask_to_image(mask):
-    return Image.fromarray((mask * 255).astype(np.uint8))
+def mask_to_image(mask: np.ndarray):
+    if mask.ndim == 2:
+        return Image.fromarray((mask * 255).astype(np.uint8))
+    elif mask.ndim == 3:
+        return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))


-if __name__ == "__main__":
+if __name__ == '__main__':
     args = get_args()
     in_files = args.input
     out_files = get_output_filenames(args)

-    net = UNet(n_channels=3, n_classes=1)
+    net = UNet(n_channels=3, n_classes=2)

-    logging.info("Loading model {}".format(args.model))
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    logging.info(f'Loading model {args.model}')
     logging.info(f'Using device {device}')

     net.to(device=device)
     net.load_state_dict(torch.load(args.model, map_location=device))

-    logging.info("Model loaded !")
+    logging.info('Model loaded!')

-    for i, fn in enumerate(in_files):
-        logging.info("\nPredicting image {} ...".format(fn))
-        img = Image.open(fn)
+    for i, filename in enumerate(in_files):
+        logging.info(f'\nPredicting image {filename} ...')
+        img = Image.open(filename)

         mask = predict_img(net=net,
                            full_img=img,
@@ -125,12 +104,11 @@ if __name__ == "__main__":
                            device=device)

         if not args.no_save:
-            out_fn = out_files[i]
+            out_filename = out_files[i]
             result = mask_to_image(mask)
-            result.save(out_files[i])
-
-            logging.info("Mask saved to {}".format(out_files[i]))
+            result.save(out_filename)
+            logging.info(f'Mask saved to {out_filename}')

         if args.viz:
-            logging.info("Visualizing results for image {}, close to continue ...".format(fn))
+            logging.info(f'Visualizing results for image {filename}, close to continue...')
             plot_img_and_mask(img, mask)
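
For reference, a sketch of calling the updated predict_img from Python rather than the CLI; 'image.jpg' is a placeholder and MODEL.pth follows the script's default:

import torch
from PIL import Image

from predict import predict_img
from unet import UNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = UNet(n_channels=3, n_classes=2)
net.to(device=device)
net.load_state_dict(torch.load('MODEL.pth', map_location=device))

mask = predict_img(net=net, full_img=Image.open('image.jpg'), device=device,
                   scale_factor=0.5, out_threshold=0.5)
print(mask.shape)  # (n_classes, H, W) one-hot ndarray when n_classes > 1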

requirements.txt

@@ -3,6 +3,5 @@ numpy
 Pillow
 torch
 torchvision
-tensorboard
-future
 tqdm
+wandb

submit.py (deleted, 46 lines)

@@ -1,46 +0,0 @@
""" Submit code specific to the kaggle challenge"""
import os
import torch
from PIL import Image
import numpy as np
from predict import predict_img
from unet import UNet
# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
def rle_encode(mask_image):
pixels = mask_image.flatten()
# We avoid issues with '1' at the start or end (at the corners of
# the original image) by setting those pixels to '0' explicitly.
# We do not expect these to be non-zero for an accurate mask,
# so this should not harm the score.
pixels[0] = 0
pixels[-1] = 0
runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
runs[1::2] = runs[1::2] - runs[:-1:2]
return runs
def submit(net):
"""Used for Kaggle submission: predicts and encode all test images"""
dir = 'data/test/'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
N = len(list(os.listdir(dir)))
with open('SUBMISSION.csv', 'a') as f:
f.write('img,rle_mask\n')
for index, i in enumerate(os.listdir(dir)):
print('{}/{}'.format(index, N))
img = Image.open(dir + i)
mask = predict_img(net, img, device)
enc = rle_encode(mask)
f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
if __name__ == '__main__':
net = UNet(3, 1).cuda()
net.load_state_dict(torch.load('MODEL.pth'))
submit(net)
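
For the record, the core of the run-length encoding this deleted file implemented, worked on a toy array (values computed by hand; the 2-D flattening step is omitted):

import numpy as np

mask = np.array([0, 1, 1, 0], dtype=np.uint8)  # one run of two foreground pixels
runs = np.where(mask[1:] != mask[:-1])[0] + 2  # change points, shifted to 1-indexed run starts
runs[1::2] = runs[1::2] - runs[:-1:2]          # turn every second entry into a run length
print(runs)  # [2 2] -> run starting at pixel 2 (1-indexed), length 2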

train.py (210 changes)

@@ -1,187 +1,193 @@
 import argparse
 import logging
-import os
 import sys
+from pathlib import Path

-import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+import wandb
 from torch import optim
+from torch.utils.data import DataLoader, random_split
 from tqdm import tqdm

-from eval import eval_net
+from data_loading import BasicDataset, CarvanaDataset
+from dice_score import dice_loss
+from evaluate import evaluate
 from unet import UNet
-from torch.utils.tensorboard import SummaryWriter
-from utils.dataset import BasicDataset
-from torch.utils.data import DataLoader, random_split

-dir_img = 'data/imgs/'
-dir_mask = 'data/masks/'
-dir_checkpoint = 'checkpoints/'
+dir_img = Path('./data/imgs/')
+dir_mask = Path('./data/masks/')
+dir_checkpoint = Path('./checkpoints/')


 def train_net(net,
               device,
-              epochs=5,
-              batch_size=1,
-              lr=0.001,
-              val_percent=0.1,
-              save_cp=True,
-              img_scale=0.5):
+              epochs: int = 5,
+              batch_size: int = 1,
+              learning_rate: float = 0.001,
+              val_percent: float = 0.1,
+              save_checkpoint: bool = True,
+              img_scale: float = 0.5,
+              amp: bool = False):
-    dataset = BasicDataset(dir_img, dir_mask, img_scale)
+    # 1. Create dataset
+    try:
+        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
+    except (AssertionError, RuntimeError):
+        dataset = BasicDataset(dir_img, dir_mask, img_scale)

+    # 2. Split into train / validation partitions
     n_val = int(len(dataset) * val_percent)
     n_train = len(dataset) - n_val
-    train, val = random_split(dataset, [n_train, n_val])
-    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
-    val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)
+    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

-    writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
-    global_step = 0
+    # 3. Create data loaders
+    loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
+    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
+    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

+    # (Initialise logging)
+    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
+    experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
+                                  val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
+                                  amp=amp))

     logging.info(f'''Starting training:
         Epochs: {epochs}
         Batch size: {batch_size}
-        Learning rate: {lr}
+        Learning rate: {learning_rate}
         Training size: {n_train}
         Validation size: {n_val}
-        Checkpoints: {save_cp}
+        Checkpoints: {save_checkpoint}
         Device: {device.type}
         Images scaling: {img_scale}
+        Mixed Precision: {amp}
     ''')

-    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
-    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
-    if net.n_classes > 1:
-        criterion = nn.CrossEntropyLoss()
-    else:
-        criterion = nn.BCEWithLogitsLoss()
+    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
+    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)  # goal: maximize Dice score
+    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    criterion = nn.CrossEntropyLoss()
+    global_step = 0

+    # 5. Begin training
     for epoch in range(epochs):
         net.train()
         epoch_loss = 0
         with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
             for batch in train_loader:
-                imgs = batch['image']
+                images = batch['image']
                 true_masks = batch['mask']

-                assert imgs.shape[1] == net.n_channels, \
+                assert images.shape[1] == net.n_channels, \
                     f'Network has been defined with {net.n_channels} input channels, ' \
-                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
+                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                     'the images are loaded correctly.'

-                imgs = imgs.to(device=device, dtype=torch.float32)
-                mask_type = torch.float32 if net.n_classes == 1 else torch.long
-                true_masks = true_masks.to(device=device, dtype=mask_type)
+                images = images.to(device=device, dtype=torch.float32)
+                true_masks = true_masks.to(device=device, dtype=torch.long)

-                masks_pred = net(imgs)
-                loss = criterion(masks_pred, true_masks)
+                with torch.cuda.amp.autocast(enabled=amp):
+                    masks_pred = net(images)
+                    loss = criterion(masks_pred, true_masks) \
+                           + dice_loss(F.softmax(masks_pred, dim=1).float(),
+                                       F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
+                                       multiclass=True)

+                optimizer.zero_grad(set_to_none=True)
+                grad_scaler.scale(loss).backward()
+                grad_scaler.step(optimizer)
+                grad_scaler.update()

+                pbar.update(images.shape[0])
+                global_step += 1
                 epoch_loss += loss.item()
-                writer.add_scalar('Loss/train', loss.item(), global_step)
+                experiment.log({
+                    'train loss': loss.item(),
+                    'step': global_step,
+                    'epoch': epoch
+                })
                 pbar.set_postfix(**{'loss (batch)': loss.item()})

-                optimizer.zero_grad()
-                loss.backward()
-                nn.utils.clip_grad_value_(net.parameters(), 0.1)
-                optimizer.step()
-
-                pbar.update(imgs.shape[0])
-                global_step += 1
-
+                # Evaluation round
                 if global_step % (n_train // (10 * batch_size)) == 0:
+                    histograms = {}
                     for tag, value in net.named_parameters():
-                        tag = tag.replace('.', '/')
-                        writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
-                        writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
-                    val_score = eval_net(net, val_loader, device)
+                        tag = tag.replace('/', '.')
+                        histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
+                        histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
+
+                    val_score = evaluate(net, val_loader, device)
                     scheduler.step(val_score)
-                    writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)

-                    if net.n_classes > 1:
-                        logging.info('Validation cross entropy: {}'.format(val_score))
-                        writer.add_scalar('Loss/test', val_score, global_step)
-                    else:
-                        logging.info('Validation Dice Coeff: {}'.format(val_score))
-                        writer.add_scalar('Dice/test', val_score, global_step)
-
-                    writer.add_images('images', imgs, global_step)
-                    if net.n_classes == 1:
-                        writer.add_images('masks/true', true_masks, global_step)
-                        writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
+                    logging.info('Validation Dice score: {}'.format(val_score))
+                    experiment.log({
+                        'learning rate': optimizer.param_groups[0]['lr'],
+                        'validation Dice': val_score,
+                        'images': wandb.Image(images[0].cpu()),
+                        'masks': {
+                            'true': wandb.Image(true_masks[0].float().cpu()),
+                            'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
+                        },
+                        'step': global_step,
+                        'epoch': epoch,
+                        **histograms
+                    })

-        if save_cp:
-            try:
-                os.mkdir(dir_checkpoint)
-                logging.info('Created checkpoint directory')
-            except OSError:
-                pass
-            torch.save(net.state_dict(),
-                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
+        if save_checkpoint:
+            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
+            torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
             logging.info(f'Checkpoint {epoch + 1} saved!')

-    writer.close()


 def get_args():
-    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
-                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
-                        help='Number of epochs', dest='epochs')
-    parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
-                        help='Batch size', dest='batchsize')
-    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.0001,
-                        help='Learning rate', dest='lr')
-    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
-                        help='Load model from a .pth file')
-    parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
-                        help='Downscaling factor of the images')
-    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
-                        help='Percent of the data that is used as validation (0-100)')
+    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
+    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=50, help='Number of epochs')
+    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
+    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.00001,
+                        help='Learning rate', dest='lr')
+    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
+    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
+    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
+                        help='Percent of the data that is used as validation (0-100)')
+    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')

     return parser.parse_args()


 if __name__ == '__main__':
+    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
     args = get_args()
-    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     logging.info(f'Using device {device}')

     # Change here to adapt to your data
     # n_channels=3 for RGB images
     # n_classes is the number of probabilities you want to get per pixel
-    # - For 1 class and background, use n_classes=1
-    # - For 2 classes, use n_classes=1
-    # - For N > 2 classes, use n_classes=N
-    net = UNet(n_channels=3, n_classes=1, bilinear=True)
+    net = UNet(n_channels=3, n_classes=2, bilinear=True)

     logging.info(f'Network:\n'
                  f'\t{net.n_channels} input channels\n'
                  f'\t{net.n_classes} output channels (classes)\n'
                  f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')

     if args.load:
-        net.load_state_dict(
-            torch.load(args.load, map_location=device)
-        )
+        net.load_state_dict(torch.load(args.load, map_location=device))
         logging.info(f'Model loaded from {args.load}')

     net.to(device=device)
-    # faster convolutions, but more memory
-    # cudnn.benchmark = True

     try:
         train_net(net=net,
                   epochs=args.epochs,
-                  batch_size=args.batchsize,
-                  lr=args.lr,
+                  batch_size=args.batch_size,
+                  learning_rate=args.lr,
                   device=device,
                   img_scale=args.scale,
-                  val_percent=args.val / 100)
+                  val_percent=args.val / 100,
+                  amp=args.amp)
     except KeyboardInterrupt:
         torch.save(net.state_dict(), 'INTERRUPTED.pth')
         logging.info('Saved interrupt')
-        try:
-            sys.exit(0)
-        except SystemExit:
-            os._exit(0)
+        sys.exit(0)
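
A hedged sketch of driving the reworked train_net directly from Python rather than the CLI (CPU, one epoch; assumes the data directories configured by dir_img/dir_mask are populated and that an anonymous wandb run is acceptable):

import torch
from train import train_net
from unet import UNet

net = UNet(n_channels=3, n_classes=2, bilinear=True)
device = torch.device('cpu')
net.to(device=device)
train_net(net=net, device=device, epochs=1, batch_size=2,
          learning_rate=1e-5, val_percent=0.1, img_scale=0.5, amp=False)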

unet/unet_model.py

@@ -1,7 +1,5 @@
 """ Full assembly of the parts to form the complete network """

-import torch.nn.functional as F
 from .unet_parts import *

unet/unet_parts.py

@@ -53,7 +53,6 @@ class Up(nn.Module):
         self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
         self.conv = DoubleConv(in_channels, out_channels)

     def forward(self, x1, x2):
         x1 = self.up(x1)
         # input is CHW

utils/utils.py

@@ -2,7 +2,7 @@ import matplotlib.pyplot as plt
 def plot_img_and_mask(img, mask):
-    classes = mask.shape[2] if len(mask.shape) > 2 else 1
+    classes = mask.shape[0] if len(mask.shape) > 2 else 1
     fig, ax = plt.subplots(1, classes + 1)
     ax[0].set_title('Input image')
     ax[0].imshow(img)

utils/dataset.py (deleted, 71 lines)

@@ -1,71 +0,0 @@
from os.path import splitext
from os import listdir
import numpy as np
from glob import glob
import torch
from torch.utils.data import Dataset
import logging
from PIL import Image
class BasicDataset(Dataset):
def __init__(self, imgs_dir, masks_dir, scale=1, mask_suffix=''):
self.imgs_dir = imgs_dir
self.masks_dir = masks_dir
self.scale = scale
self.mask_suffix = mask_suffix
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
if not file.startswith('.')]
logging.info(f'Creating dataset with {len(self.ids)} examples')
def __len__(self):
return len(self.ids)
@classmethod
def preprocess(cls, pil_img, scale):
w, h = pil_img.size
newW, newH = int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, 'Scale is too small'
pil_img = pil_img.resize((newW, newH))
img_nd = np.array(pil_img)
if len(img_nd.shape) == 2:
img_nd = np.expand_dims(img_nd, axis=2)
# HWC to CHW
img_trans = img_nd.transpose((2, 0, 1))
if img_trans.max() > 1:
img_trans = img_trans / 255
return img_trans
def __getitem__(self, i):
idx = self.ids[i]
mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*')
img_file = glob(self.imgs_dir + idx + '.*')
assert len(mask_file) == 1, \
f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
assert len(img_file) == 1, \
f'Either no image or multiple images found for the ID {idx}: {img_file}'
mask = Image.open(mask_file[0])
img = Image.open(img_file[0])
assert img.size == mask.size, \
f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
img = self.preprocess(img, self.scale)
mask = self.preprocess(mask, self.scale)
return {
'image': torch.from_numpy(img).type(torch.FloatTensor),
'mask': torch.from_numpy(mask).type(torch.FloatTensor)
}
class CarvanaDataset(BasicDataset):
def __init__(self, imgs_dir, masks_dir, scale=1):
super().__init__(imgs_dir, masks_dir, scale, mask_suffix='_mask')