2019-10-24 19:37:21 +00:00
|
|
|
import argparse
|
|
|
|
import logging
|
2021-08-16 00:53:00 +00:00
|
|
|
from pathlib import Path
|
2018-04-09 03:15:24 +00:00
|
|
|
|
2022-06-28 07:36:21 +00:00
|
|
|
import albumentations as A
|
2017-08-17 19:16:19 +00:00
|
|
|
import torch
|
2017-08-19 08:59:51 +00:00
|
|
|
import torch.nn as nn
|
2022-06-28 07:36:21 +00:00
|
|
|
from albumentations.pytorch import ToTensorV2
|
2018-04-09 03:15:24 +00:00
|
|
|
from torch import optim
|
2022-06-28 07:36:21 +00:00
|
|
|
from torch.utils.data import DataLoader
|
2019-10-24 19:37:21 +00:00
|
|
|
from tqdm import tqdm
|
2017-08-17 19:16:19 +00:00
|
|
|
|
2022-06-28 07:36:21 +00:00
|
|
|
import wandb
|
2021-08-16 00:53:00 +00:00
|
|
|
from evaluate import evaluate
|
2022-06-28 07:36:21 +00:00
|
|
|
from src.utils.dataset import SphereDataset
|
|
|
|
from unet import UNet
|
|
|
|
from utils.paste import RandomPaste
|
2019-10-24 19:37:21 +00:00
|
|
|
|
2022-06-28 07:36:21 +00:00
|
|
|
# Checkpoint directory (not referenced by the training code visible in this file,
# which writes "model.pth" / "INTERRUPTED.pth" to the working directory).
CHECKPOINT_DIR = Path("./checkpoints/")

# Common parent folder of every dataset directory below.
_DATA_ROOT = Path("/home/lilian/data_disk/lfainsin")

# Background images for the training and validation SphereDatasets.
DIR_TRAIN_IMG = _DATA_ROOT / "val2017"
DIR_VALID_IMG = _DATA_ROOT / "smoltrain2017"

# Sphere images and masks handed to the RandomPaste augmentation.
DIR_SPHERE_IMG = _DATA_ROOT / "spheres" / "Images"
DIR_SPHERE_MASK = _DATA_ROOT / "spheres" / "Masks"
|
2019-10-24 19:37:21 +00:00
|
|
|
|
2017-08-17 19:16:19 +00:00
|
|
|
|
2018-06-08 17:27:32 +00:00
|
|
|
def get_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests or notebooks). ``None`` keeps the
            default behaviour of reading the real command line.

    Returns:
        argparse.Namespace with the attributes ``epochs``, ``batch_size``,
        ``lr``, ``load``, ``amp``, ``classes`` and ``n_channels``.
    """
    parser = argparse.ArgumentParser(
        description="Train the UNet on images and target masks",
    )
    parser.add_argument(
        "--epochs",
        "-e",
        metavar="E",
        type=int,
        default=5,
        help="Number of epochs",
    )
    parser.add_argument(
        "--batch-size",
        "-b",
        dest="batch_size",
        metavar="B",
        type=int,
        default=70,
        help="Batch size",
    )
    parser.add_argument(
        "--learning-rate",
        "-l",
        metavar="LR",
        type=float,
        default=1e-5,
        help="Learning rate",
        dest="lr",
    )
    # None (not False) when no checkpoint is given: the value is either a path or absent.
    parser.add_argument(
        "--load",
        "-f",
        type=str,
        default=None,
        help="Load model from a .pth file",
    )
    # Mixed precision is on by default. With store_true alone it could never be
    # switched off, so a paired --no-amp writes False to the same destination.
    parser.add_argument(
        "--amp",
        action="store_true",
        default=True,
        help="Use mixed precision",
    )
    parser.add_argument(
        "--no-amp",
        dest="amp",
        action="store_false",
        help="Disable mixed precision",
    )
    parser.add_argument(
        "--classes",
        "-c",
        type=int,
        default=1,
        help="Number of classes",
    )
    # main() builds the UNet with args.n_channels; this option provides it.
    parser.add_argument(
        "--channels",
        dest="n_channels",
        metavar="C",
        type=int,
        default=3,
        help="Number of input image channels",
    )

    return parser.parse_args(argv)
|
|
|
|
|
|
|
|
|
2022-06-28 09:36:43 +00:00
|
|
|
def main():
    """Train the UNet and log the run to Weights & Biases.

    Steps:
      * parse the CLI;
      * build the network, optionally loading weights from ``args.load``;
      * create the train/validation augmentation pipelines, datasets and data loaders;
      * train for ``args.epochs`` epochs with RMSprop and BCE-with-logits loss.

    After every epoch the network is evaluated on the validation loader, the LR
    scheduler is stepped with the score, and the weights are saved to ``model.pth``.
    On Ctrl-C the current weights are saved to ``INTERRUPTED.pth`` and the
    KeyboardInterrupt is re-raised.
    """
    # get args from cli
    args = get_args()

    # setup logging
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    # enable cuda, if possible
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device {device}")

    # torch.cuda.amp (autocast / GradScaler) only has an effect on CUDA; on CPU it
    # would merely warn and disable itself, so request it only when it can work.
    use_amp = args.amp and device.type == "cuda"

    # enable cudnn benchmarking
    torch.backends.cudnn.benchmark = True

    # 0. Create network
    # Fall back to 3 input channels (RGB, which the ColorJitter/ISONoise
    # augmentations below operate on) when the CLI defines no channel option.
    n_channels = getattr(args, "n_channels", 3)
    features = [16, 32, 64, 128]
    net = UNet(n_channels=n_channels, n_classes=args.classes, features=features)
    logging.info(
        f"""Network:
        input channels: {net.n_channels}
        output channels: {net.n_classes}
        nb parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad)}
        features: {features}
    """
    )

    # Load weights, if needed
    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f"Model loaded from {args.load}")

    # save initial model.pth (uploaded below as a wandb artifact)
    torch.save(net.state_dict(), "model.pth")

    # transfer network to device
    net.to(device=device)

    # 1. Create transforms
    tf_train = A.Compose(
        [
            A.Resize(512, 512),
            A.Flip(),
            A.ColorJitter(),
            RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
            A.GaussianBlur(),
            A.ISONoise(),
            A.ToFloat(max_value=255),
            ToTensorV2(),
        ],
    )
    tf_valid = A.Compose(
        [
            A.Resize(512, 512),
            RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
            A.ToFloat(max_value=255),
            ToTensorV2(),
        ],
    )

    # 2. Create datasets
    ds_train = SphereDataset(image_dir=DIR_TRAIN_IMG, transform=tf_train)
    ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)

    # 3. Create data loaders
    # Pinned host memory only speeds up host->GPU copies; skip it on CPU.
    loader_args = dict(
        batch_size=args.batch_size,
        num_workers=8,
        pin_memory=device.type == "cuda",
    )
    train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
    # NOTE(review): drop_last=True skips the final partial validation batch, and
    # yields no batches at all if the validation set is smaller than batch_size.
    val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
    # mode "max": the LR is reduced when the validation score stops increasing
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
    grad_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    criterion = nn.BCEWithLogitsLoss()

    # setup wandb
    run = wandb.init(
        project="U-Net-tmp",
        config=dict(
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            amp=use_amp,
        ),
    )
    wandb.watch(net, log_freq=100)
    artifact_model = wandb.Artifact("model", type="model")
    artifact_model.add_file("model.pth")
    run.log_artifact(artifact_model)

    logging.info(
        f"""Starting training:
        Epochs: {args.epochs}
        Batch size: {args.batch_size}
        Learning rate: {args.lr}
        Training size: {len(ds_train)}
        Validation size: {len(ds_valid)}
        Device: {device.type}
        Mixed Precision: {use_amp}
    """
    )

    try:
        for epoch in range(1, args.epochs + 1):
            # evaluate() may switch the network to eval mode; make sure each
            # epoch trains in training mode (dropout / batch-norm updates).
            net.train()
            # Defined up front so the epoch summary below works even if the
            # training loader yields no batches.
            loss_value = float("nan")

            with tqdm(total=len(ds_train), desc=f"{epoch}/{args.epochs}", unit="img") as pbar:
                # Training round
                for step, (images, true_masks) in enumerate(train_loader):
                    # Explicit check (an assert would be stripped under python -O).
                    if images.shape[1] != net.n_channels:
                        raise ValueError(
                            f"Network has been defined with {net.n_channels} input channels, "
                            f"but loaded images have {images.shape[1]} channels. Please check that "
                            "the images are loaded correctly."
                        )

                    # transfer images to device; masks get a channel axis to match the predictions
                    images = images.to(device=device)
                    true_masks = true_masks.unsqueeze(1).to(device=device)

                    # forward
                    with torch.cuda.amp.autocast(enabled=use_amp):
                        pred_masks = net(images)
                        train_loss = criterion(pred_masks, true_masks)

                    # backward
                    optimizer.zero_grad(set_to_none=True)
                    grad_scaler.scale(train_loss).backward()
                    grad_scaler.step(optimizer)
                    grad_scaler.update()

                    # One device->host sync per step; the float is reused for tqdm and wandb.
                    loss_value = train_loss.item()

                    # update tqdm progress bar
                    pbar.update(images.shape[0])
                    pbar.set_postfix(loss=loss_value)

                    # log training metrics
                    wandb.log(
                        {
                            "train/epoch": epoch - 1 + step / len(train_loader),
                            "train/train_loss": loss_value,
                        }
                    )

            # Evaluation round
            val_score = evaluate(net, val_loader, device)
            scheduler.step(val_score)

            # log validation metrics
            wandb.log(
                {
                    "val/val_score": val_score,
                }
            )

            print(f"Train Loss: {loss_value:.3f}, Valid Score: {val_score:.3f}")

            # save weights when epoch end
            torch.save(net.state_dict(), "model.pth")
            logging.info("model saved!")

        run.finish()

    except KeyboardInterrupt:
        torch.save(net.state_dict(), "INTERRUPTED.pth")
        logging.info("Saved interrupt")
        raise
|
2022-06-28 09:36:43 +00:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Start training only when executed as a script, not when imported.
    main()
|