mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-09 15:02:03 +00:00
Summer cleanup
Former-commit-id: f6185d67a4bc50aa7ec1b8168aab3f92721c4965
This commit is contained in:
parent
2f4f7edd5d
commit
063cbbc599
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -4,5 +4,5 @@ __pycache__/
|
|||
checkpoints/
|
||||
*.pth
|
||||
*.jpg
|
||||
SUBMISSION*
|
||||
venv/
|
||||
.idea/
|
9
Dockerfile
Normal file
9
Dockerfile
Normal file
|
@ -0,0 +1,9 @@
|
|||
FROM nvcr.io/nvidia/pytorch:21.06-py3
|
||||
|
||||
RUN rm -rf /workspace/*
|
||||
WORKDIR /workspace/unet
|
||||
|
||||
ADD requirements.txt .
|
||||
RUN pip install --no-cache-dir --upgrade --pre pip
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
ADD . .
|
80
data_loading.py
Normal file
80
data_loading.py
Normal file
|
@ -0,0 +1,80 @@
|
|||
import logging
|
||||
from os import listdir
|
||||
from os.path import splitext
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
|
||||
class BasicDataset(Dataset):
|
||||
def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ''):
|
||||
self.images_dir = Path(images_dir)
|
||||
self.masks_dir = Path(masks_dir)
|
||||
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
|
||||
self.scale = scale
|
||||
self.mask_suffix = mask_suffix
|
||||
|
||||
self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith('.')]
|
||||
if not self.ids:
|
||||
raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')
|
||||
logging.info(f'Creating dataset with {len(self.ids)} examples')
|
||||
|
||||
def __len__(self):
|
||||
return len(self.ids)
|
||||
|
||||
@classmethod
|
||||
def preprocess(cls, pil_img, scale, is_mask):
|
||||
w, h = pil_img.size
|
||||
newW, newH = int(scale * w), int(scale * h)
|
||||
assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
|
||||
pil_img = pil_img.resize((newW, newH))
|
||||
img_ndarray = np.asarray(pil_img)
|
||||
|
||||
if img_ndarray.ndim == 2 and not is_mask:
|
||||
img_ndarray = img_ndarray[np.newaxis, ...]
|
||||
elif not is_mask:
|
||||
img_ndarray = img_ndarray.transpose((2, 0, 1))
|
||||
|
||||
if not is_mask:
|
||||
img_ndarray = img_ndarray / 255
|
||||
|
||||
return img_ndarray
|
||||
|
||||
@classmethod
|
||||
def load(cls, filename):
|
||||
ext = splitext(filename)[1]
|
||||
if ext in ['.npz', '.npy']:
|
||||
return Image.fromarray(np.load(filename))
|
||||
elif ext in ['.pt', '.pth']:
|
||||
return Image.fromarray(torch.load(filename).numpy())
|
||||
else:
|
||||
return Image.open(filename)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
name = self.ids[idx]
|
||||
mask_file = list(self.masks_dir.glob(name + self.mask_suffix + '.*'))
|
||||
img_file = list(self.images_dir.glob(name + '.*'))
|
||||
|
||||
assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
|
||||
assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
|
||||
mask = self.load(mask_file[0])
|
||||
img = self.load(img_file[0])
|
||||
|
||||
assert img.size == mask.size, \
|
||||
'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'
|
||||
|
||||
img = self.preprocess(img, self.scale, is_mask=False)
|
||||
mask = self.preprocess(mask, self.scale, is_mask=True)
|
||||
|
||||
return {
|
||||
'image': torch.as_tensor(img.copy()).float().contiguous(),
|
||||
'mask': torch.as_tensor(mask.copy()).long().contiguous()
|
||||
}
|
||||
|
||||
|
||||
class CarvanaDataset(BasicDataset):
|
||||
def __init__(self, images_dir, masks_dir, scale=1):
|
||||
super().__init__(images_dir, masks_dir, scale, mask_suffix='_mask')
|
42
dice_loss.py
42
dice_loss.py
|
@ -1,42 +0,0 @@
|
|||
import torch
|
||||
from torch.autograd import Function
|
||||
|
||||
|
||||
class DiceCoeff(Function):
|
||||
"""Dice coeff for individual examples"""
|
||||
|
||||
def forward(self, input, target):
|
||||
self.save_for_backward(input, target)
|
||||
eps = 0.0001
|
||||
self.inter = torch.dot(input.view(-1), target.view(-1))
|
||||
self.union = torch.sum(input) + torch.sum(target) + eps
|
||||
|
||||
t = (2 * self.inter.float() + eps) / self.union.float()
|
||||
return t
|
||||
|
||||
# This function has only a single output, so it gets only one gradient
|
||||
def backward(self, grad_output):
|
||||
|
||||
input, target = self.saved_variables
|
||||
grad_input = grad_target = None
|
||||
|
||||
if self.needs_input_grad[0]:
|
||||
grad_input = grad_output * 2 * (target * self.union - self.inter) \
|
||||
/ (self.union * self.union)
|
||||
if self.needs_input_grad[1]:
|
||||
grad_target = None
|
||||
|
||||
return grad_input, grad_target
|
||||
|
||||
|
||||
def dice_coeff(input, target):
|
||||
"""Dice coeff for batches"""
|
||||
if input.is_cuda:
|
||||
s = torch.FloatTensor(1).cuda().zero_()
|
||||
else:
|
||||
s = torch.FloatTensor(1).zero_()
|
||||
|
||||
for i, c in enumerate(zip(input, target)):
|
||||
s = s + DiceCoeff().forward(c[0], c[1])
|
||||
|
||||
return s / (i + 1)
|
40
dice_score.py
Normal file
40
dice_score.py
Normal file
|
@ -0,0 +1,40 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
|
||||
# Average of Dice coefficient for all batches, or for a single mask
|
||||
assert input.size() == target.size()
|
||||
if input.dim() == 2 and reduce_batch_first:
|
||||
raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')
|
||||
|
||||
if input.dim() == 2 or reduce_batch_first:
|
||||
inter = torch.dot(input.view(-1), target.view(-1))
|
||||
sets_sum = torch.sum(input) + torch.sum(target)
|
||||
if sets_sum.item() == 0:
|
||||
sets_sum = 2 * inter
|
||||
|
||||
return (2 * inter + epsilon) / (sets_sum + epsilon)
|
||||
else:
|
||||
# compute and average metric for each batch element
|
||||
dice = 0
|
||||
for i in range(input.shape[0]):
|
||||
dice += dice_coeff(input[i, ...], target[i, ...])
|
||||
return dice / input.shape[0]
|
||||
|
||||
|
||||
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon=1e-6):
|
||||
# Average of Dice coefficient for all classes
|
||||
assert input.size() == target.size()
|
||||
dice = 0
|
||||
for channel in range(input.shape[1]):
|
||||
dice += dice_coeff(input[:, channel, ...], target[:, channel, ...], reduce_batch_first, epsilon)
|
||||
|
||||
return dice / input.shape[1]
|
||||
|
||||
|
||||
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
|
||||
# Dice loss (objective to minimize) between 0 and 1
|
||||
assert input.size() == target.size()
|
||||
fn = multiclass_dice_coeff if multiclass else dice_coeff
|
||||
return 1 - fn(input, target, reduce_batch_first=True)
|
33
eval.py
33
eval.py
|
@ -1,33 +0,0 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
|
||||
from dice_loss import dice_coeff
|
||||
|
||||
|
||||
def eval_net(net, loader, device):
|
||||
"""Evaluation without the densecrf with the dice coefficient"""
|
||||
net.eval()
|
||||
mask_type = torch.float32 if net.n_classes == 1 else torch.long
|
||||
n_val = len(loader) # the number of batch
|
||||
tot = 0
|
||||
|
||||
with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
|
||||
for batch in loader:
|
||||
imgs, true_masks = batch['image'], batch['mask']
|
||||
imgs = imgs.to(device=device, dtype=torch.float32)
|
||||
true_masks = true_masks.to(device=device, dtype=mask_type)
|
||||
|
||||
with torch.no_grad():
|
||||
mask_pred = net(imgs)
|
||||
|
||||
if net.n_classes > 1:
|
||||
tot += F.cross_entropy(mask_pred, true_masks).item()
|
||||
else:
|
||||
pred = torch.sigmoid(mask_pred)
|
||||
pred = (pred > 0.5).float()
|
||||
tot += dice_coeff(pred, true_masks).item()
|
||||
pbar.update()
|
||||
|
||||
net.train()
|
||||
return tot / n_val
|
35
evaluate.py
Normal file
35
evaluate.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from tqdm import tqdm
|
||||
|
||||
from dice_score import multiclass_dice_coeff
|
||||
|
||||
|
||||
def evaluate(net, dataloader, device):
|
||||
net.eval()
|
||||
num_val_batches = len(dataloader)
|
||||
dice_score = 0
|
||||
|
||||
# iterate over the validation set
|
||||
for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
|
||||
image, mask_true = batch['image'], batch['mask']
|
||||
# move images and labels to correct device and type
|
||||
image = image.to(device=device, dtype=torch.float32)
|
||||
mask_true = mask_true.to(device=device, dtype=torch.long)
|
||||
mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
|
||||
|
||||
with torch.no_grad():
|
||||
# predict the mask
|
||||
mask_pred = net(image)
|
||||
|
||||
# convert to one-hot format
|
||||
if net.n_classes == 1:
|
||||
mask_pred = (F.sigmoid(mask_pred) > 0).float()
|
||||
else:
|
||||
mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
|
||||
|
||||
# compute the Dice score, ignoring background
|
||||
dice_score += multiclass_dice_coeff(mask_pred[:, :1, ...], mask_true[:, :1, ...], reduce_batch_first=False)
|
||||
|
||||
net.train()
|
||||
return dice_score / num_val_batches
|
110
predict.py
110
predict.py
|
@ -8,9 +8,9 @@ import torch.nn.functional as F
|
|||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from data_loading import BasicDataset
|
||||
from unet import UNet
|
||||
from utils.data_vis import plot_img_and_mask
|
||||
from utils.dataset import BasicDataset
|
||||
from utils import plot_img_and_mask
|
||||
|
||||
|
||||
def predict_img(net,
|
||||
|
@ -19,9 +19,7 @@ def predict_img(net,
|
|||
scale_factor=1,
|
||||
out_threshold=0.5):
|
||||
net.eval()
|
||||
|
||||
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
|
||||
|
||||
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
|
||||
img = img.unsqueeze(0)
|
||||
img = img.to(device=device, dtype=torch.float32)
|
||||
|
||||
|
@ -29,94 +27,75 @@ def predict_img(net,
|
|||
output = net(img)
|
||||
|
||||
if net.n_classes > 1:
|
||||
probs = F.softmax(output, dim=1)
|
||||
probs = F.softmax(output, dim=1)[0]
|
||||
else:
|
||||
probs = torch.sigmoid(output)
|
||||
probs = torch.sigmoid(output)[0]
|
||||
|
||||
probs = probs.squeeze(0)
|
||||
|
||||
tf = transforms.Compose(
|
||||
[
|
||||
tf = transforms.Compose([
|
||||
transforms.ToPILImage(),
|
||||
transforms.Resize(full_img.size[1]),
|
||||
transforms.Resize((full_img.size[1], full_img.size[0])),
|
||||
transforms.ToTensor()
|
||||
]
|
||||
)
|
||||
])
|
||||
|
||||
probs = tf(probs.cpu())
|
||||
full_mask = probs.squeeze().cpu().numpy()
|
||||
full_mask = tf(probs.cpu()).squeeze()
|
||||
|
||||
return full_mask > out_threshold
|
||||
if net.n_classes == 1:
|
||||
return (full_mask > out_threshold).numpy()
|
||||
else:
|
||||
return F.one_hot(full_mask.argmax(dim=0), net.n_classes).permute(2, 0, 1).numpy()
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='Predict masks from input images',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--model', '-m', default='MODEL.pth',
|
||||
metavar='FILE',
|
||||
help="Specify the file in which the model is stored")
|
||||
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+',
|
||||
help='filenames of input images', required=True)
|
||||
|
||||
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+',
|
||||
help='Filenames of ouput images')
|
||||
parser = argparse.ArgumentParser(description='Predict masks from input images')
|
||||
parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE',
|
||||
help='Specify the file in which the model is stored')
|
||||
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True)
|
||||
parser.add_argument('--output', '-o', metavar='INPUT', nargs='+', help='Filenames of output images')
|
||||
parser.add_argument('--viz', '-v', action='store_true',
|
||||
help="Visualize the images as they are processed",
|
||||
default=False)
|
||||
parser.add_argument('--no-save', '-n', action='store_true',
|
||||
help="Do not save the output masks",
|
||||
default=False)
|
||||
parser.add_argument('--mask-threshold', '-t', type=float,
|
||||
help="Minimum probability value to consider a mask pixel white",
|
||||
default=0.5)
|
||||
parser.add_argument('--scale', '-s', type=float,
|
||||
help="Scale factor for the input images",
|
||||
default=0.5)
|
||||
help='Visualize the images as they are processed')
|
||||
parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
|
||||
parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
|
||||
help='Minimum probability value to consider a mask pixel white')
|
||||
parser.add_argument('--scale', '-s', type=float, default=0.5,
|
||||
help='Scale factor for the input images')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_output_filenames(args):
|
||||
in_files = args.input
|
||||
out_files = []
|
||||
def _generate_name(fn):
|
||||
split = os.path.splitext(fn)
|
||||
return f'{split[0]}_OUT{split[1]}'
|
||||
|
||||
if not args.output:
|
||||
for f in in_files:
|
||||
pathsplit = os.path.splitext(f)
|
||||
out_files.append("{}_OUT{}".format(pathsplit[0], pathsplit[1]))
|
||||
elif len(in_files) != len(args.output):
|
||||
logging.error("Input files and output files are not of the same length")
|
||||
raise SystemExit()
|
||||
else:
|
||||
out_files = args.output
|
||||
|
||||
return out_files
|
||||
return args.output or list(map(_generate_name, args.input))
|
||||
|
||||
|
||||
def mask_to_image(mask):
|
||||
def mask_to_image(mask: np.ndarray):
|
||||
if mask.ndim == 2:
|
||||
return Image.fromarray((mask * 255).astype(np.uint8))
|
||||
elif mask.ndim == 3:
|
||||
return Image.fromarray((np.argmax(mask, dim=0) * 255 / mask.shape[0]).astype(np.uint8))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if __name__ == '__main__':
|
||||
args = get_args()
|
||||
in_files = args.input
|
||||
out_files = get_output_filenames(args)
|
||||
|
||||
net = UNet(n_channels=3, n_classes=1)
|
||||
|
||||
logging.info("Loading model {}".format(args.model))
|
||||
net = UNet(n_channels=3, n_classes=2)
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
logging.info(f'Loading model {args.model}')
|
||||
logging.info(f'Using device {device}')
|
||||
|
||||
net.to(device=device)
|
||||
net.load_state_dict(torch.load(args.model, map_location=device))
|
||||
|
||||
logging.info("Model loaded !")
|
||||
logging.info('Model loaded!')
|
||||
|
||||
for i, fn in enumerate(in_files):
|
||||
logging.info("\nPredicting image {} ...".format(fn))
|
||||
|
||||
img = Image.open(fn)
|
||||
for i, filename in enumerate(in_files):
|
||||
logging.info(f'\nPredicting image {filename} ...')
|
||||
img = Image.open(filename)
|
||||
|
||||
mask = predict_img(net=net,
|
||||
full_img=img,
|
||||
|
@ -125,12 +104,11 @@ if __name__ == "__main__":
|
|||
device=device)
|
||||
|
||||
if not args.no_save:
|
||||
out_fn = out_files[i]
|
||||
out_filename = out_files[i]
|
||||
result = mask_to_image(mask)
|
||||
result.save(out_files[i])
|
||||
|
||||
logging.info("Mask saved to {}".format(out_files[i]))
|
||||
result.save(out_filename)
|
||||
logging.info(f'Mask saved to {out_filename}')
|
||||
|
||||
if args.viz:
|
||||
logging.info("Visualizing results for image {}, close to continue ...".format(fn))
|
||||
logging.info(f'Visualizing results for image {filename}, close to continue...')
|
||||
plot_img_and_mask(img, mask)
|
||||
|
|
|
@ -3,6 +3,5 @@ numpy
|
|||
Pillow
|
||||
torch
|
||||
torchvision
|
||||
tensorboard
|
||||
future
|
||||
tqdm
|
||||
wandb
|
46
submit.py
46
submit.py
|
@ -1,46 +0,0 @@
|
|||
""" Submit code specific to the kaggle challenge"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
from predict import predict_img
|
||||
from unet import UNet
|
||||
|
||||
# credits to https://stackoverflow.com/users/6076729/manuel-lagunas
|
||||
def rle_encode(mask_image):
|
||||
pixels = mask_image.flatten()
|
||||
# We avoid issues with '1' at the start or end (at the corners of
|
||||
# the original image) by setting those pixels to '0' explicitly.
|
||||
# We do not expect these to be non-zero for an accurate mask,
|
||||
# so this should not harm the score.
|
||||
pixels[0] = 0
|
||||
pixels[-1] = 0
|
||||
runs = np.where(pixels[1:] != pixels[:-1])[0] + 2
|
||||
runs[1::2] = runs[1::2] - runs[:-1:2]
|
||||
return runs
|
||||
|
||||
|
||||
def submit(net):
|
||||
"""Used for Kaggle submission: predicts and encode all test images"""
|
||||
dir = 'data/test/'
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
N = len(list(os.listdir(dir)))
|
||||
with open('SUBMISSION.csv', 'a') as f:
|
||||
f.write('img,rle_mask\n')
|
||||
for index, i in enumerate(os.listdir(dir)):
|
||||
print('{}/{}'.format(index, N))
|
||||
|
||||
img = Image.open(dir + i)
|
||||
|
||||
mask = predict_img(net, img, device)
|
||||
enc = rle_encode(mask)
|
||||
f.write('{},{}\n'.format(i, ' '.join(map(str, enc))))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
net = UNet(3, 1).cuda()
|
||||
net.load_state_dict(torch.load('MODEL.pth'))
|
||||
submit(net)
|
210
train.py
210
train.py
|
@ -1,187 +1,193 @@
|
|||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import wandb
|
||||
from torch import optim
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
from tqdm import tqdm
|
||||
|
||||
from eval import eval_net
|
||||
from data_loading import BasicDataset, CarvanaDataset
|
||||
from dice_score import dice_loss
|
||||
from evaluate import evaluate
|
||||
from unet import UNet
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from utils.dataset import BasicDataset
|
||||
from torch.utils.data import DataLoader, random_split
|
||||
|
||||
dir_img = 'data/imgs/'
|
||||
dir_mask = 'data/masks/'
|
||||
dir_checkpoint = 'checkpoints/'
|
||||
dir_img = Path('./data/imgs/')
|
||||
dir_mask = Path('./data/masks/')
|
||||
dir_checkpoint = Path('./checkpoints/')
|
||||
|
||||
|
||||
def train_net(net,
|
||||
device,
|
||||
epochs=5,
|
||||
batch_size=1,
|
||||
lr=0.001,
|
||||
val_percent=0.1,
|
||||
save_cp=True,
|
||||
img_scale=0.5):
|
||||
|
||||
epochs: int = 5,
|
||||
batch_size: int = 1,
|
||||
learning_rate: float = 0.001,
|
||||
val_percent: float = 0.1,
|
||||
save_checkpoint: bool = True,
|
||||
img_scale: float = 0.5,
|
||||
amp: bool = False):
|
||||
# 1. Create dataset
|
||||
try:
|
||||
dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
|
||||
except (AssertionError, RuntimeError):
|
||||
dataset = BasicDataset(dir_img, dir_mask, img_scale)
|
||||
|
||||
# 2. Split into train / validation partitions
|
||||
n_val = int(len(dataset) * val_percent)
|
||||
n_train = len(dataset) - n_val
|
||||
train, val = random_split(dataset, [n_train, n_val])
|
||||
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)
|
||||
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=8, pin_memory=True, drop_last=True)
|
||||
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))
|
||||
|
||||
writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
|
||||
global_step = 0
|
||||
# 3. Create data loaders
|
||||
loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
|
||||
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
|
||||
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
|
||||
|
||||
# (Initialise logging)
|
||||
experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
|
||||
experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
|
||||
val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
|
||||
amp=amp))
|
||||
|
||||
logging.info(f'''Starting training:
|
||||
Epochs: {epochs}
|
||||
Batch size: {batch_size}
|
||||
Learning rate: {lr}
|
||||
Learning rate: {learning_rate}
|
||||
Training size: {n_train}
|
||||
Validation size: {n_val}
|
||||
Checkpoints: {save_cp}
|
||||
Checkpoints: {save_checkpoint}
|
||||
Device: {device.type}
|
||||
Images scaling: {img_scale}
|
||||
Mixed Precision: {amp}
|
||||
''')
|
||||
|
||||
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
|
||||
if net.n_classes > 1:
|
||||
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
|
||||
optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
|
||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2) # goal: maximize Dice score
|
||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
else:
|
||||
criterion = nn.BCEWithLogitsLoss()
|
||||
global_step = 0
|
||||
|
||||
# 5. Begin training
|
||||
for epoch in range(epochs):
|
||||
net.train()
|
||||
|
||||
epoch_loss = 0
|
||||
with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}', unit='img') as pbar:
|
||||
for batch in train_loader:
|
||||
imgs = batch['image']
|
||||
images = batch['image']
|
||||
true_masks = batch['mask']
|
||||
assert imgs.shape[1] == net.n_channels, \
|
||||
|
||||
assert images.shape[1] == net.n_channels, \
|
||||
f'Network has been defined with {net.n_channels} input channels, ' \
|
||||
f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
|
||||
f'but loaded images have {images.shape[1]} channels. Please check that ' \
|
||||
'the images are loaded correctly.'
|
||||
|
||||
imgs = imgs.to(device=device, dtype=torch.float32)
|
||||
mask_type = torch.float32 if net.n_classes == 1 else torch.long
|
||||
true_masks = true_masks.to(device=device, dtype=mask_type)
|
||||
images = images.to(device=device, dtype=torch.float32)
|
||||
true_masks = true_masks.to(device=device, dtype=torch.long)
|
||||
|
||||
masks_pred = net(imgs)
|
||||
loss = criterion(masks_pred, true_masks)
|
||||
with torch.cuda.amp.autocast(enabled=amp):
|
||||
masks_pred = net(images)
|
||||
loss = criterion(masks_pred, true_masks) \
|
||||
+ dice_loss(F.softmax(masks_pred, dim=1).float(),
|
||||
F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
|
||||
multiclass=True)
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
grad_scaler.scale(loss).backward()
|
||||
grad_scaler.step(optimizer)
|
||||
grad_scaler.update()
|
||||
|
||||
pbar.update(images.shape[0])
|
||||
global_step += 1
|
||||
epoch_loss += loss.item()
|
||||
writer.add_scalar('Loss/train', loss.item(), global_step)
|
||||
|
||||
experiment.log({
|
||||
'train loss': loss.item(),
|
||||
'step': global_step,
|
||||
'epoch': epoch
|
||||
})
|
||||
pbar.set_postfix(**{'loss (batch)': loss.item()})
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
nn.utils.clip_grad_value_(net.parameters(), 0.1)
|
||||
optimizer.step()
|
||||
|
||||
pbar.update(imgs.shape[0])
|
||||
global_step += 1
|
||||
# Evaluation round
|
||||
if global_step % (n_train // (10 * batch_size)) == 0:
|
||||
histograms = {}
|
||||
for tag, value in net.named_parameters():
|
||||
tag = tag.replace('.', '/')
|
||||
writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)
|
||||
writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), global_step)
|
||||
val_score = eval_net(net, val_loader, device)
|
||||
tag = tag.replace('/', '.')
|
||||
histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
|
||||
histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
|
||||
|
||||
val_score = evaluate(net, val_loader, device)
|
||||
scheduler.step(val_score)
|
||||
writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], global_step)
|
||||
|
||||
if net.n_classes > 1:
|
||||
logging.info('Validation cross entropy: {}'.format(val_score))
|
||||
writer.add_scalar('Loss/test', val_score, global_step)
|
||||
else:
|
||||
logging.info('Validation Dice Coeff: {}'.format(val_score))
|
||||
writer.add_scalar('Dice/test', val_score, global_step)
|
||||
logging.info('Validation Dice score: {}'.format(val_score))
|
||||
experiment.log({
|
||||
'learning rate': optimizer.param_groups[0]['lr'],
|
||||
'validation Dice': val_score,
|
||||
'images': wandb.Image(images[0].cpu()),
|
||||
'masks': {
|
||||
'true': wandb.Image(true_masks[0].float().cpu()),
|
||||
'pred': wandb.Image(torch.softmax(masks_pred, dim=1)[0].float().cpu()),
|
||||
},
|
||||
'step': global_step,
|
||||
'epoch': epoch,
|
||||
**histograms
|
||||
})
|
||||
|
||||
writer.add_images('images', imgs, global_step)
|
||||
if net.n_classes == 1:
|
||||
writer.add_images('masks/true', true_masks, global_step)
|
||||
writer.add_images('masks/pred', torch.sigmoid(masks_pred) > 0.5, global_step)
|
||||
|
||||
if save_cp:
|
||||
try:
|
||||
os.mkdir(dir_checkpoint)
|
||||
logging.info('Created checkpoint directory')
|
||||
except OSError:
|
||||
pass
|
||||
torch.save(net.state_dict(),
|
||||
dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
|
||||
if save_checkpoint:
|
||||
Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
|
||||
torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch + 1)))
|
||||
logging.info(f'Checkpoint {epoch + 1} saved!')
|
||||
|
||||
writer.close()
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks',
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
|
||||
help='Number of epochs', dest='epochs')
|
||||
parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
|
||||
help='Batch size', dest='batchsize')
|
||||
parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.0001,
|
||||
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
|
||||
parser.add_argument('--epochs', '-e', metavar='E', type=int, default=50, help='Number of epochs')
|
||||
parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
|
||||
parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.00001,
|
||||
help='Learning rate', dest='lr')
|
||||
parser.add_argument('-f', '--load', dest='load', type=str, default=False,
|
||||
help='Load model from a .pth file')
|
||||
parser.add_argument('-s', '--scale', dest='scale', type=float, default=0.5,
|
||||
help='Downscaling factor of the images')
|
||||
parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
|
||||
parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
|
||||
parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
|
||||
parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
|
||||
help='Percent of the data that is used as validation (0-100)')
|
||||
parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||
args = get_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
logging.info(f'Using device {device}')
|
||||
|
||||
# Change here to adapt to your data
|
||||
# n_channels=3 for RGB images
|
||||
# n_classes is the number of probabilities you want to get per pixel
|
||||
# - For 1 class and background, use n_classes=1
|
||||
# - For 2 classes, use n_classes=1
|
||||
# - For N > 2 classes, use n_classes=N
|
||||
net = UNet(n_channels=3, n_classes=1, bilinear=True)
|
||||
net = UNet(n_channels=3, n_classes=2, bilinear=True)
|
||||
|
||||
logging.info(f'Network:\n'
|
||||
f'\t{net.n_channels} input channels\n'
|
||||
f'\t{net.n_classes} output channels (classes)\n'
|
||||
f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')
|
||||
|
||||
if args.load:
|
||||
net.load_state_dict(
|
||||
torch.load(args.load, map_location=device)
|
||||
)
|
||||
net.load_state_dict(torch.load(args.load, map_location=device))
|
||||
logging.info(f'Model loaded from {args.load}')
|
||||
|
||||
net.to(device=device)
|
||||
# faster convolutions, but more memory
|
||||
# cudnn.benchmark = True
|
||||
|
||||
try:
|
||||
train_net(net=net,
|
||||
epochs=args.epochs,
|
||||
batch_size=args.batchsize,
|
||||
lr=args.lr,
|
||||
batch_size=args.batch_size,
|
||||
learning_rate=args.lr,
|
||||
device=device,
|
||||
img_scale=args.scale,
|
||||
val_percent=args.val / 100)
|
||||
val_percent=args.val / 100,
|
||||
amp=args.amp)
|
||||
except KeyboardInterrupt:
|
||||
torch.save(net.state_dict(), 'INTERRUPTED.pth')
|
||||
logging.info('Saved interrupt')
|
||||
try:
|
||||
sys.exit(0)
|
||||
except SystemExit:
|
||||
os._exit(0)
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
""" Full assembly of the parts to form the complete network """
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .unet_parts import *
|
||||
|
||||
|
||||
|
|
|
@ -53,7 +53,6 @@ class Up(nn.Module):
|
|||
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
|
||||
self.conv = DoubleConv(in_channels, out_channels)
|
||||
|
||||
|
||||
def forward(self, x1, x2):
|
||||
x1 = self.up(x1)
|
||||
# input is CHW
|
||||
|
|
|
@ -2,7 +2,7 @@ import matplotlib.pyplot as plt
|
|||
|
||||
|
||||
def plot_img_and_mask(img, mask):
|
||||
classes = mask.shape[2] if len(mask.shape) > 2 else 1
|
||||
classes = mask.shape[0] if len(mask.shape) > 2 else 1
|
||||
fig, ax = plt.subplots(1, classes + 1)
|
||||
ax[0].set_title('Input image')
|
||||
ax[0].imshow(img)
|
|
@ -1,71 +0,0 @@
|
|||
from os.path import splitext
|
||||
from os import listdir
|
||||
import numpy as np
|
||||
from glob import glob
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
import logging
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class BasicDataset(Dataset):
|
||||
def __init__(self, imgs_dir, masks_dir, scale=1, mask_suffix=''):
|
||||
self.imgs_dir = imgs_dir
|
||||
self.masks_dir = masks_dir
|
||||
self.scale = scale
|
||||
self.mask_suffix = mask_suffix
|
||||
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
|
||||
|
||||
self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
|
||||
if not file.startswith('.')]
|
||||
logging.info(f'Creating dataset with {len(self.ids)} examples')
|
||||
|
||||
def __len__(self):
|
||||
return len(self.ids)
|
||||
|
||||
@classmethod
|
||||
def preprocess(cls, pil_img, scale):
|
||||
w, h = pil_img.size
|
||||
newW, newH = int(scale * w), int(scale * h)
|
||||
assert newW > 0 and newH > 0, 'Scale is too small'
|
||||
pil_img = pil_img.resize((newW, newH))
|
||||
|
||||
img_nd = np.array(pil_img)
|
||||
|
||||
if len(img_nd.shape) == 2:
|
||||
img_nd = np.expand_dims(img_nd, axis=2)
|
||||
|
||||
# HWC to CHW
|
||||
img_trans = img_nd.transpose((2, 0, 1))
|
||||
if img_trans.max() > 1:
|
||||
img_trans = img_trans / 255
|
||||
|
||||
return img_trans
|
||||
|
||||
def __getitem__(self, i):
|
||||
idx = self.ids[i]
|
||||
mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*')
|
||||
img_file = glob(self.imgs_dir + idx + '.*')
|
||||
|
||||
assert len(mask_file) == 1, \
|
||||
f'Either no mask or multiple masks found for the ID {idx}: {mask_file}'
|
||||
assert len(img_file) == 1, \
|
||||
f'Either no image or multiple images found for the ID {idx}: {img_file}'
|
||||
mask = Image.open(mask_file[0])
|
||||
img = Image.open(img_file[0])
|
||||
|
||||
assert img.size == mask.size, \
|
||||
f'Image and mask {idx} should be the same size, but are {img.size} and {mask.size}'
|
||||
|
||||
img = self.preprocess(img, self.scale)
|
||||
mask = self.preprocess(mask, self.scale)
|
||||
|
||||
return {
|
||||
'image': torch.from_numpy(img).type(torch.FloatTensor),
|
||||
'mask': torch.from_numpy(mask).type(torch.FloatTensor)
|
||||
}
|
||||
|
||||
|
||||
class CarvanaDataset(BasicDataset):
|
||||
def __init__(self, imgs_dir, masks_dir, scale=1):
|
||||
super().__init__(imgs_dir, masks_dir, scale, mask_suffix='_mask')
|
Loading…
Reference in a new issue