# REVA-QCAV/src/train.py
#
# Repository web-view header: 264 lines, 7.9 KiB, Python.

import argparse
import logging
from pathlib import Path
import albumentations as A
import torch
# 2017-08-19 08:59:51 +00:00 (commit timestamp left over from the repository web view)
import torch.nn as nn
import torch.onnx
from albumentations.pytorch import ToTensorV2
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
from evaluate import evaluate
from src.utils.dataset import SphereDataset
from unet import UNet
from utils.paste import RandomPaste
# Checkpoint directory, relative to the current working directory.
CHECKPOINT_DIR = Path("./checkpoints/")

# Every dataset directory used by this script lives under the same root.
_DATA_ROOT = Path("/home/lilian/data_disk/lfainsin")

# NOTE(review): the training images point at "smolval2017" and the validation
# images at "smoltrain2017" -- the names look swapped; confirm this is intended.
DIR_TRAIN_IMG = _DATA_ROOT / "smolval2017"
DIR_VALID_IMG = _DATA_ROOT / "smoltrain2017"

# Image and mask directories handed to the RandomPaste transform.
DIR_SPHERE_IMG = _DATA_ROOT / "spheres" / "Images"
DIR_SPHERE_MASK = _DATA_ROOT / "spheres" / "Masks"
def get_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``epochs``, ``batch_size``,
        ``lr``, ``load``, ``amp`` and ``classes``.
    """
    parser = argparse.ArgumentParser(
        description="Train the UNet on images and target masks",
    )
    parser.add_argument(
        "--epochs",
        "-e",
        metavar="E",
        type=int,
        default=5,
        help="Number of epochs",
    )
    parser.add_argument(
        "--batch-size",
        "-b",
        dest="batch_size",
        metavar="B",
        type=int,
        default=70,
        help="Batch size",
    )
    parser.add_argument(
        "--learning-rate",
        "-l",
        metavar="LR",
        type=float,
        default=1e-5,
        help="Learning rate",
        dest="lr",
    )
    parser.add_argument(
        "--load",
        "-f",
        type=str,
        # None instead of False: still falsy for ``if args.load:`` but
        # consistent with the declared ``str`` type.
        default=None,
        help="Load model from a .pth file",
    )
    parser.add_argument(
        "--amp",
        action="store_true",
        default=True,
        help="Use mixed precision",
    )
    # ``--amp`` alone can never switch the option off because it already
    # defaults to True; ``--no-amp`` writes False into the same destination.
    parser.add_argument(
        "--no-amp",
        dest="amp",
        action="store_false",
        help="Disable mixed precision",
    )
    parser.add_argument(
        "--classes",
        "-c",
        type=int,
        default=1,
        help="Number of classes",
    )
    return parser.parse_args(argv)
def main():
    """Train the UNet and track the run with wandb.

    Reads the hyper-parameters from the command line (``get_args``), builds
    the network, the albumentations pipelines, the datasets and the data
    loaders, then runs the train/validation loop. Weights are written to
    ``model.pth`` before training and at the end of each epoch (each time
    uploaded as a wandb artifact); after the last epoch the network is
    exported to ``model.onnx``.

    Raises:
        KeyboardInterrupt: re-raised after the current weights have been
            saved to ``INTERRUPTED.pth``.
    """
    # get args from cli
    args = get_args()

    # setup logging
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    # enable cuda, if possible
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device {device}")

    # enable cudnn benchmarking
    # torch.backends.cudnn.benchmark = True

    # 0. Create network
    features = [16, 32, 64, 128]
    net = UNet(n_channels=3, n_classes=args.classes, features=features)
    nb_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    logging.info(
        "Network:\n"
        f"input channels: {net.n_channels}\n"
        f"output channels: {net.n_classes}\n"
        f"nb parameters: {nb_params}\n"
        f"features: {features}\n"
    )

    # Load weights, if needed
    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f"Model loaded from {args.load}")

    # save initial model.pth
    torch.save(net.state_dict(), "model.pth")

    # transfer network to device
    net.to(device=device)

    # 1. Create transforms
    tf_train = A.Compose(
        [
            A.Resize(512, 512),
            A.Flip(),
            A.ColorJitter(),
            RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
            A.GaussianBlur(),
            A.ISONoise(),
            A.ToFloat(max_value=255),
            ToTensorV2(),
        ],
    )
    # validation pipeline: same resize/paste/conversion, no flip, colour, blur or noise
    tf_valid = A.Compose(
        [
            A.Resize(512, 512),
            RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
            A.ToFloat(max_value=255),
            ToTensorV2(),
        ],
    )

    # 2. Create datasets
    ds_train = SphereDataset(image_dir=DIR_TRAIN_IMG, transform=tf_train)
    ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)

    # 3. Create data loaders
    loader_args = dict(batch_size=args.batch_size, num_workers=8, pin_memory=True)
    train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
    val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
    # "max" mode: the scheduler is stepped with the validation score below
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    criterion = nn.BCEWithLogitsLoss()

    # setup wandb
    wandb.init(
        project="U-Net-tmp",
        config=dict(
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            amp=args.amp,
            features=features,
            parameters=nb_params,
        ),
    )
    # the integer divisions yield 0 when the dataset holds fewer than four
    # batches; keep the frequency a positive integer
    wandb.watch(net, log_freq=max(1, len(ds_train) // args.batch_size // 4))

    artifact = wandb.Artifact("model", type="model")
    artifact.add_file("model.pth")
    wandb.run.log_artifact(artifact)

    logging.info(
        "Starting training:\n"
        f"Epochs: {args.epochs}\n"
        f"Batch size: {args.batch_size}\n"
        f"Learning rate: {args.lr}\n"
        f"Training size: {len(ds_train)}\n"
        f"Validation size: {len(ds_valid)}\n"
        f"Device: {device.type}\n"
        f"Mixed Precision: {args.amp}\n"
    )

    try:
        for epoch in range(1, args.epochs + 1):
            # last training loss of the epoch as a plain float; stays NaN
            # if the loader yields no batch, so the summary below never
            # hits an unbound name
            loss_value = float("nan")

            with tqdm(total=len(ds_train), desc=f"{epoch}/{args.epochs}", unit="img") as pbar:
                # Training round
                for step, (images, true_masks) in enumerate(train_loader):
                    assert images.shape[1] == net.n_channels, (
                        f"Network has been defined with {net.n_channels} input channels, "
                        f"but loaded images have {images.shape[1]} channels. Please check that "
                        "the images are loaded correctly."
                    )

                    # transfer images to device
                    images = images.to(device=device)
                    # insert a dimension at index 1 before comparing with the prediction
                    true_masks = true_masks.unsqueeze(1).to(device=device)

                    # forward
                    with torch.cuda.amp.autocast(enabled=args.amp):
                        pred_masks = net(images)
                        train_loss = criterion(pred_masks, true_masks)

                    # backward
                    optimizer.zero_grad(set_to_none=True)
                    grad_scaler.scale(train_loss).backward()
                    grad_scaler.step(optimizer)
                    grad_scaler.update()

                    # convert once; the float is reused by the progress bar,
                    # wandb and the epoch summary instead of the live tensor
                    loss_value = train_loss.item()

                    # update tqdm progress bar
                    pbar.update(images.shape[0])
                    pbar.set_postfix(loss=loss_value)

                    # log training metrics
                    wandb.log(
                        {
                            "train/epoch": epoch - 1 + step / len(train_loader),
                            "train/train_loss": loss_value,
                        }
                    )

                # Evaluation round
                val_score = evaluate(net, val_loader, device)
                scheduler.step(val_score)

                # log validation metrics
                wandb.log(
                    {
                        "val/val_score": val_score,
                    }
                )

                logging.info(
                    "Validation ended:\n"
                    f"Train Loss: {loss_value}\n"
                    f"Valid Score: {val_score}\n"
                )

            # save weights when epoch end
            torch.save(net.state_dict(), "model.pth")
            artifact = wandb.Artifact("model", type="model")
            artifact.add_file("model.pth")
            wandb.run.log_artifact(artifact)
            logging.info("model saved!")

        # export model to onnx format
        dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
        torch.onnx.export(net, dummy_input, "model.onnx")

        wandb.run.finish()

    except KeyboardInterrupt:
        # keep the weights reached so far, then let the interrupt propagate
        torch.save(net.state_dict(), "INTERRUPTED.pth")
        logging.info("Saved interrupt")
        raise
# Run the training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

# TODO: fix all the metrics, loss, accuracy, dice...