2019-10-24 19:37:21 +00:00
|
|
|
import argparse
|
|
|
|
import logging
|
2021-08-16 00:53:00 +00:00
|
|
|
from pathlib import Path
|
2018-04-09 03:15:24 +00:00
|
|
|
|
2017-08-17 19:16:19 +00:00
|
|
|
import torch
|
2017-08-19 08:59:51 +00:00
|
|
|
import torch.nn as nn
|
2021-08-16 00:53:00 +00:00
|
|
|
import torch.nn.functional as F
|
|
|
|
import wandb
|
2018-04-09 03:15:24 +00:00
|
|
|
from torch import optim
|
2021-08-16 00:53:00 +00:00
|
|
|
from torch.utils.data import DataLoader, random_split
|
2019-10-24 19:37:21 +00:00
|
|
|
from tqdm import tqdm
|
2017-08-17 19:16:19 +00:00
|
|
|
|
2021-08-16 00:53:00 +00:00
|
|
|
from evaluate import evaluate
|
2017-11-30 05:45:19 +00:00
|
|
|
from unet import UNet
|
2022-06-27 13:39:44 +00:00
|
|
|
from utils.data_loading import BasicDataset, CarvanaDataset
|
|
|
|
from utils.dice_score import dice_loss
|
2019-10-24 19:37:21 +00:00
|
|
|
|
2022-06-27 13:39:44 +00:00
|
|
|
# Default locations, resolved relative to the current working directory:
# the input images, their segmentation masks, and the per-epoch checkpoint folder.
dir_img = Path("data", "imgs")
dir_mask = Path("data", "masks")
dir_checkpoint = Path("checkpoints")
|
2019-10-24 19:37:21 +00:00
|
|
|
|
2017-08-17 19:16:19 +00:00
|
|
|
|
2022-06-27 13:39:44 +00:00
|
|
|
def _param_histograms(net):
    """Build wandb histograms of every parameter's weights and, when usable, its gradients.

    A gradient histogram is skipped when the parameter has no gradient (``grad`` is None,
    which can happen because gradients are cleared with ``set_to_none=True``) or when the
    gradient contains inf/NaN. The AMP GradScaler leaves such gradients on overflow steps,
    and wandb.Histogram cannot bin them (numpy rejects a non-finite histogram range).
    """
    histograms = {}
    for tag, value in net.named_parameters():
        tag = tag.replace("/", ".")
        histograms["Weights/" + tag] = wandb.Histogram(value.data.cpu())
        grad = value.grad
        if grad is not None and torch.isfinite(grad).all():
            histograms["Gradients/" + tag] = wandb.Histogram(grad.data.cpu())
    return histograms


def train_net(
    net,
    device,
    epochs: int = 5,
    batch_size: int = 1,
    learning_rate: float = 1e-5,
    val_percent: float = 0.1,
    save_checkpoint: bool = True,
    img_scale: float = 0.5,
    amp: bool = False,
):
    """Train a segmentation network on the images/masks found under ``dir_img``/``dir_mask``.

    The dataset is built with ``CarvanaDataset``. If that raises ``AssertionError`` or
    ``RuntimeError``, it falls back to ``BasicDataset``. The data is split into
    train/validation partitions with a fixed seed (0), so the split is reproducible.

    Each step minimises cross-entropy plus multiclass Dice loss with RMSprop.
    Every ``n_train // (10 * batch_size)`` steps (about ten times per epoch; never if that
    is 0), the model is evaluated on the validation split. The resulting Dice score drives
    a ``ReduceLROnPlateau`` scheduler, and metrics, sample masks and parameter histograms
    are logged to Weights & Biases.

    Args:
        net: Model to train. It must expose ``n_channels`` and ``n_classes`` and return
            per-class logits of shape (batch, n_classes, H, W).
        device: ``torch.device`` the batches are moved to; ``net`` should already be on it.
        epochs: Number of passes over the training split.
        batch_size: Batch size for both the training and validation loaders.
        learning_rate: Initial RMSprop learning rate.
        val_percent: Fraction (0-1) of the dataset held out for validation.
        save_checkpoint: If True, save ``net.state_dict()`` to
            ``dir_checkpoint/checkpoint_epoch{N}.pth`` after every epoch.
        img_scale: Image scaling factor forwarded to the dataset.
        amp: Enable ``torch.cuda.amp`` autocast and gradient scaling.

    Raises:
        AssertionError: if a loaded image batch does not have ``net.n_channels`` channels.
    """
    # 1. Create dataset
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

    # 2. Split into train / validation partitions (seeded generator -> same split every run)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    # drop_last=True: a trailing partial batch is not used for validation
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    experiment = wandb.init(project="U-Net", resume="allow", anonymous="must")
    experiment.config.update(
        dict(
            epochs=epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            val_percent=val_percent,
            save_checkpoint=save_checkpoint,
            img_scale=img_scale,
            amp=amp,
        )
    )

    logging.info(
        f"""Starting training:
        Epochs: {epochs}
        Batch size: {batch_size}
        Learning rate: {learning_rate}
        Training size: {n_train}
        Validation size: {n_val}
        Checkpoints: {save_checkpoint}
        Device: {device.type}
        Images scaling: {img_scale}
        Mixed Precision: {amp}
    """
    )

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)  # goal: maximize Dice score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss()
    global_step = 0
    # Evaluate about 10 times per epoch; 0 (tiny training set) disables periodic evaluation.
    division_step = n_train // (10 * batch_size)

    # 5. Begin training
    for epoch in range(1, epochs + 1):
        net.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f"Epoch {epoch}/{epochs}", unit="img") as pbar:
            for batch in train_loader:
                images = batch["image"]
                true_masks = batch["mask"]

                assert images.shape[1] == net.n_channels, (
                    f"Network has been defined with {net.n_channels} input channels, "
                    f"but loaded images have {images.shape[1]} channels. Please check that "
                    "the images are loaded correctly."
                )

                images = images.to(device=device, dtype=torch.float32)
                # Masks hold class indices; CrossEntropyLoss / one_hot need an integer dtype.
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.cuda.amp.autocast(enabled=amp):
                    masks_pred = net(images)
                    # one_hot appends the class axis last; permute moves it to dim 1 so the
                    # target matches the (batch, n_classes, H, W) prediction layout.
                    loss = criterion(masks_pred, true_masks) + dice_loss(
                        F.softmax(masks_pred, dim=1).float(),
                        F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
                        multiclass=True,
                    )

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                grad_scaler.step(optimizer)
                grad_scaler.update()

                # .item() synchronises with the device; do it once per step.
                batch_loss = loss.item()
                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += batch_loss
                experiment.log({"train loss": batch_loss, "step": global_step, "epoch": epoch})
                pbar.set_postfix(**{"loss (batch)": batch_loss})

                # Evaluation round
                if division_step > 0 and global_step % division_step == 0:
                    histograms = _param_histograms(net)

                    val_score = evaluate(net, val_loader, device)
                    scheduler.step(val_score)

                    logging.info("Validation Dice score: {}".format(val_score))
                    experiment.log(
                        {
                            "learning rate": optimizer.param_groups[0]["lr"],
                            "validation Dice": val_score,
                            "images": wandb.Image(images[0].cpu()),
                            "masks": {
                                "true": wandb.Image(true_masks[0].float().cpu()),
                                "pred": wandb.Image(
                                    torch.softmax(masks_pred, dim=1).argmax(dim=1)[0].float().cpu()
                                ),
                            },
                            "step": global_step,
                            "epoch": epoch,
                            **histograms,
                        }
                    )

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            torch.save(net.state_dict(), str(dir_checkpoint / "checkpoint_epoch{}.pth".format(epoch)))
            logging.info(f"Checkpoint {epoch} saved!")
|
2017-08-17 19:16:19 +00:00
|
|
|
|
2018-06-08 17:27:32 +00:00
|
|
|
|
|
|
|
def get_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Sequence of argument strings to parse, without the program name.
            ``None`` (the default) makes argparse parse ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``epochs``, ``batch_size``, ``lr``,
        ``load`` (``None`` unless ``--load`` is given), ``scale``, ``val`` (validation
        percentage), ``amp``, ``bilinear`` and ``classes``.

    Raises:
        SystemExit: raised by argparse (``parser.error``) for unparsable arguments or for
            a ``--validation`` percentage outside the 0-100 range.
    """
    parser = argparse.ArgumentParser(
        description="Train the UNet on images and target masks",
    )
    parser.add_argument("--epochs", "-e", metavar="E", type=int, default=5, help="Number of epochs")
    parser.add_argument(
        "--batch-size", "-b", dest="batch_size", metavar="B", type=int, default=1, help="Batch size"
    )
    parser.add_argument(
        "--learning-rate", "-l", metavar="LR", type=float, default=1e-5, help="Learning rate", dest="lr"
    )
    # None is the natural "not given" value for a str option; like the former False it
    # is falsy, so `if args.load:` checks behave the same.
    parser.add_argument("--load", "-f", type=str, default=None, help="Load model from a .pth file")
    parser.add_argument("--scale", "-s", type=float, default=0.5, help="Downscaling factor of the images")
    parser.add_argument(
        "--validation",
        "-v",
        dest="val",
        type=float,
        default=10.0,
        help="Percent of the data that is used as validation (0-100)",
    )
    parser.add_argument("--amp", action="store_true", default=False, help="Use mixed precision")
    parser.add_argument("--bilinear", action="store_true", default=False, help="Use bilinear upsampling")
    parser.add_argument("--classes", "-c", type=int, default=2, help="Number of classes")

    args = parser.parse_args(argv)
    # The percentage becomes a fraction of the dataset size; anything outside 0-100
    # (or NaN) would yield a negative train or validation split size.
    if not 0.0 <= args.val <= 100.0:
        parser.error(f"--validation must be between 0 and 100, got {args.val}")
    return args
|
|
|
|
|
|
|
|
|
2022-06-27 13:39:44 +00:00
|
|
|
if __name__ == "__main__":
    args = get_args()
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    # Train on the GPU when PyTorch can see one, otherwise fall back to the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    logging.info(f"Using device {device}")

    # Adapt to your data here: n_channels=3 is for RGB images, and n_classes is the
    # number of probabilities produced per pixel.
    net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
    logging.info(
        "Network:\n"
        f"\t{net.n_channels} input channels\n"
        f"\t{net.n_classes} output channels (classes)\n"
        f"\t{'Bilinear' if net.bilinear else 'Transposed conv'} upscaling"
    )

    if args.load:
        # NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
        # only load checkpoints from trusted sources.
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f"Model loaded from {args.load}")

    net.to(device=device)
    try:
        train_net(
            net=net,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            img_scale=args.scale,
            val_percent=args.val / 100,  # CLI takes a percentage, train_net a fraction
            amp=args.amp,
        )
    except KeyboardInterrupt:
        # Persist the partially trained weights, then let the interrupt propagate.
        torch.save(net.state_dict(), "INTERRUPTED.pth")
        logging.info("Saved interrupt")
        raise
|