REVA-QCAV/src/train.py
Your Name 842aaa81c2 wtf am i doing
Former-commit-id: dde43cce52408ec8f67372b365796b9014ceee57
2022-06-28 11:36:43 +02:00

227 lines
6.8 KiB
Python

import argparse
import logging
import os
from pathlib import Path

import albumentations as A
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from albumentations.pytorch import ToTensorV2
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from evaluate import evaluate
from src.utils.dataset import SphereDataset
from src.utils.dice import dice_loss
from unet import UNet
from utils.paste import RandomPaste
# Root of the data disk holding the training/validation images and the sphere crops.
# Defaults to the original machine-specific location; set REVA_DATA_DIR to train elsewhere.
DATA_DIR = Path(os.environ.get("REVA_DATA_DIR", "/home/lilian/data_disk/lfainsin"))

# Per-epoch checkpoints are written here (relative to the current working directory).
CHECKPOINT_DIR = Path("./checkpoints/")
# Image folders for SphereDataset (training and validation splits).
DIR_TRAIN_IMG = DATA_DIR / "smoltrain2017"
DIR_VALID_IMG = DATA_DIR / "smolval2017"
# Sphere images and their masks, handed to RandomPaste in the augmentation pipelines.
DIR_SPHERE_IMG = DATA_DIR / "spheres" / "Images"
DIR_SPHERE_MASK = DATA_DIR / "spheres" / "Masks"
def get_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (useful for tests and notebooks). ``None`` keeps
            the default argparse behaviour.

    Returns:
        argparse.Namespace with ``epochs``, ``batch_size``, ``lr``, ``load``,
        ``amp`` and ``classes`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Train the UNet on images and target masks",
    )
    parser.add_argument(
        "--epochs",
        "-e",
        metavar="E",
        type=int,
        default=5,
        help="Number of epochs",
    )
    parser.add_argument(
        "--batch-size",
        "-b",
        dest="batch_size",
        metavar="B",
        type=int,
        default=10,
        help="Batch size",
    )
    parser.add_argument(
        "--learning-rate",
        "-l",
        metavar="LR",
        type=float,
        default=1e-5,
        help="Learning rate",
        dest="lr",
    )
    parser.add_argument(
        "--load",
        "-f",
        type=str,
        default=None,  # None (not False) signals "no checkpoint to load"
        help="Load model from a .pth file",
    )
    # Mixed precision is on by default. ``store_true`` with ``default=True`` can never
    # produce False, so ``--amp`` alone gave no way to disable it; ``--no-amp`` writes
    # False to the same destination. ``--amp`` is kept for backward compatibility.
    parser.add_argument(
        "--amp",
        dest="amp",
        action="store_true",
        default=True,
        help="Use mixed precision (default)",
    )
    parser.add_argument(
        "--no-amp",
        dest="amp",
        action="store_false",
        help="Disable mixed precision",
    )
    parser.add_argument(
        "--classes",
        "-c",
        type=int,
        default=1,
        help="Number of classes",
    )
    return parser.parse_args(argv)
def main():
    """Train the UNet on ``SphereDataset`` images, validating and checkpointing every epoch.

    Configuration comes from the command line (see ``get_args``). Per-step
    training metrics and per-epoch validation metrics are logged to Weights &
    Biases. A ``checkpoint_epoch{N}.pth`` file is written to ``CHECKPOINT_DIR``
    after each epoch. On Ctrl-C the current weights are saved to
    ``INTERRUPTED.pth`` (current working directory) and the
    ``KeyboardInterrupt`` is re-raised.
    """
    # get args from cli
    args = get_args()

    # setup logging
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    # enable cuda, if possible
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device {device}")

    # torch.cuda.amp only works on CUDA: on a CPU-only machine autocast/GradScaler
    # disable themselves with a warning. Decide once and report the effective value.
    use_amp = args.amp and device.type == "cuda"

    # 0. Create network
    net = UNet(n_channels=3, n_classes=args.classes)
    logging.info(
        "Network:\n"
        f"input channels: {net.n_channels}\n"
        f"output channels: {net.n_classes}\n"
    )

    # Load weights, if needed
    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f"Model loaded from {args.load}")

    # transfer network to device
    net.to(device=device)

    # 1. Create transforms (both pipelines use the directly imported ToTensorV2)
    tf_train = A.Compose(
        [
            A.Resize(500, 500),
            A.Flip(),
            A.ColorJitter(),
            RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
            A.ISONoise(),
            A.ToFloat(max_value=255),
            ToTensorV2(),
        ],
    )
    tf_valid = A.Compose(
        [
            A.Resize(500, 500),
            RandomPaste(5, 0.2, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
            A.ToFloat(max_value=255),
            ToTensorV2(),
        ],
    )

    # 2. Create datasets
    ds_train = SphereDataset(image_dir=DIR_TRAIN_IMG, transform=tf_train)
    ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)

    # 3. Create data loaders
    loader_args = dict(batch_size=args.batch_size, num_workers=4, pin_memory=True)
    train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
    val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
    # NOTE(review): "max" follows the original "maximize Dice score" intent. If ``evaluate``
    # actually returns a loss (as the ``val_loss`` name suggests), this must be "min",
    # otherwise the learning rate is reduced while the model is still improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
    grad_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    criterion = nn.BCEWithLogitsLoss()

    # connect to wandb
    wandb.init(
        project="U-Net",
        config=dict(
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            amp=use_amp,
        ),
    )

    logging.info(
        "Starting training:\n"
        f"Epochs: {args.epochs}\n"
        f"Batch size: {args.batch_size}\n"
        f"Learning rate: {args.lr}\n"
        f"Training size: {len(ds_train)}\n"
        f"Validation size: {len(ds_valid)}\n"
        f"Device: {device.type}\n"
        f"Mixed Precision: {use_amp}\n"
    )

    # Create the checkpoint folder once, instead of on every epoch.
    checkpoint_dir = Path(CHECKPOINT_DIR)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    n_train_batches = len(train_loader)

    try:
        for epoch in range(1, args.epochs + 1):
            # ``evaluate`` lives in another module and may switch the network to eval
            # mode; restore training mode at the start of every epoch.
            net.train()
            epoch_loss = 0.0

            with tqdm(total=len(ds_train), desc=f"{epoch}/{args.epochs}", unit="img") as pbar:
                # Training round
                for step, (images, true_masks) in enumerate(train_loader):
                    if images.shape[1] != net.n_channels:
                        raise ValueError(
                            f"Network has been defined with {net.n_channels} input channels, "
                            f"but loaded images have {images.shape[1]} channels. Please check that "
                            "the images are loaded correctly."
                        )

                    images = images.to(device=device)
                    # Add a channel axis so the target lines up with the network output
                    # (assumes masks arrive as (B, H, W) — TODO confirm against SphereDataset).
                    true_masks = true_masks.unsqueeze(1).to(device=device)

                    with torch.cuda.amp.autocast(enabled=use_amp):
                        masks_pred = net(images)
                        train_loss = criterion(masks_pred, true_masks)  # TODO: add the dice loss

                    optimizer.zero_grad(set_to_none=True)
                    grad_scaler.scale(train_loss).backward()
                    grad_scaler.step(optimizer)
                    grad_scaler.update()

                    # Extract the Python float once: used for the epoch mean, tqdm and wandb.
                    loss_value = train_loss.item()
                    epoch_loss += loss_value

                    pbar.update(images.shape[0])
                    pbar.set_postfix(loss=loss_value)
                    wandb.log(  # log training metrics
                        {
                            # Fractional epoch progress after this step: epoch 1 covers (0, 1].
                            "train/epoch": epoch - 1 + (step + 1) / n_train_batches,
                            "train/train_loss": loss_value,
                        }
                    )

            # Mean over the epoch's batches; the last batch alone is noisy and does not
            # exist when the loader is empty.
            mean_train_loss = epoch_loss / max(n_train_batches, 1)

            # Evaluation round
            val_loss = evaluate(net, val_loader, device)
            scheduler.step(val_loss)
            wandb.log(  # log validation metrics
                {
                    "val/val_loss": val_loss,
                }
            )
            print(f"Train Loss: {mean_train_loss:.3f}, Valid Loss: {val_loss:.3f}")

            # save weights when epoch end
            torch.save(net.state_dict(), str(checkpoint_dir / f"checkpoint_epoch{epoch}.pth"))
            logging.info(f"Checkpoint {epoch} saved!")

    except KeyboardInterrupt:
        # Keep the partially trained weights, then let the interrupt propagate.
        torch.save(net.state_dict(), "INTERRUPTED.pth")
        logging.info("Saved interrupt")
        raise
# Script entry point: run training only when this file is executed directly, not on import.
if __name__ == "__main__":
    main()