"""Train a UNet on images from DIR_TRAIN_IMG, validating on DIR_VALID_IMG.

Both the training and validation transform pipelines include the project's
``RandomPaste`` transform, fed with DIR_SPHERE_IMG / DIR_SPHERE_MASK.
Configuration, metrics and model weights are logged to Weights & Biases, and
the final weights are exported to ONNX.
"""

import argparse
import logging
from pathlib import Path

import albumentations as A
import torch
import torch.nn as nn
import torch.onnx
from albumentations.pytorch import ToTensorV2
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

import wandb
from evaluate import evaluate

# NOTE(review): this import is rooted at `src.` while `utils.paste` below is
# not; confirm both resolve from the directory the script is launched from.
from src.utils.dataset import SphereDataset
from unet import UNet
from utils.paste import RandomPaste

CHECKPOINT_DIR = Path("./checkpoints/")

# NOTE(review): the training set points at "smolval2017" and the validation
# set at "smoltrain2017" -- this looks swapped; confirm it is intentional.
DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017")
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017/")
DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")


def get_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace with attributes ``epochs``, ``batch_size``, ``lr``,
        ``load`` (path string or None), ``amp`` (bool) and ``classes``.
    """
    parser = argparse.ArgumentParser(
        description="Train the UNet on images and target masks",
    )
    parser.add_argument(
        "--epochs",
        "-e",
        metavar="E",
        type=int,
        default=5,
        help="Number of epochs",
    )
    parser.add_argument(
        "--batch-size",
        "-b",
        dest="batch_size",
        metavar="B",
        type=int,
        default=70,
        help="Batch size",
    )
    parser.add_argument(
        "--learning-rate",
        "-l",
        metavar="LR",
        type=float,
        default=1e-5,
        help="Learning rate",
        dest="lr",
    )
    parser.add_argument(
        "--load",
        "-f",
        type=str,
        default=None,
        help="Load model from a .pth file",
    )
    # Mixed precision is on by default. A lone `store_true` with default=True
    # could never be switched off, so `--no-amp` writes False into the same
    # destination; `--amp` is kept so existing command lines still work.
    parser.add_argument(
        "--amp",
        action="store_true",
        default=True,
        help="Use mixed precision",
    )
    parser.add_argument(
        "--no-amp",
        dest="amp",
        action="store_false",
        help="Disable mixed precision",
    )
    parser.add_argument(
        "--classes",
        "-c",
        type=int,
        default=1,
        help="Number of classes",
    )

    return parser.parse_args()


def _log_model_artifact(path="model.pth"):
    """Upload the weights file at *path* to the active wandb run as a "model" artifact."""
    artifact = wandb.Artifact("model", type="model")
    artifact.add_file(path)
    wandb.run.log_artifact(artifact)


def main():
    """Build the network, data pipeline and optimizer, then train and export.

    Side effects: writes ``model.pth`` (before training and after every
    epoch), ``model.onnx`` after the last epoch and ``INTERRUPTED.pth`` on
    Ctrl-C; logs config, metrics and weight artifacts to wandb.
    """
    # get args from cli
    args = get_args()

    # setup logging
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    # enable cuda, if possible
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device {device}")

    # torch.cuda.amp autocast/GradScaler only act on CUDA; on CPU they would
    # merely warn and disable themselves, so gate them explicitly.
    use_amp = args.amp and device.type == "cuda"

    # enable cudnn benchmarking
    # torch.backends.cudnn.benchmark = True

    # side length images are resized to; also used for the ONNX dummy input
    img_size = 512

    # 0. Create network
    features = [16, 32, 64, 128]
    net = UNet(n_channels=3, n_classes=args.classes, features=features)
    nb_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    logging.info(
        f"""Network:
        input channels:  {net.n_channels}
        output channels: {net.n_classes}
        nb parameters:   {nb_params}
        features:        {features}
    """
    )

    # Load weights, if needed
    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f"Model loaded from {args.load}")

    # save initial weights so they can be uploaded as the first wandb artifact
    torch.save(net.state_dict(), "model.pth")

    # transfer network to device
    net.to(device=device)

    # 1. Create transforms
    tf_train = A.Compose(
        [
            A.Resize(img_size, img_size),
            A.Flip(),
            A.ColorJitter(),
            RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
            A.GaussianBlur(),
            A.ISONoise(),
            A.ToFloat(max_value=255),
            ToTensorV2(),
        ],
    )
    tf_valid = A.Compose(
        [
            A.Resize(img_size, img_size),
            RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
            A.ToFloat(max_value=255),
            ToTensorV2(),
        ],
    )

    # 2. Create datasets
    ds_train = SphereDataset(image_dir=DIR_TRAIN_IMG, transform=tf_train)
    ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)

    # 3. Create data loaders
    loader_args = dict(batch_size=args.batch_size, num_workers=8, pin_memory=True)
    train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
    # NOTE(review): with drop_last=True a validation set smaller than one batch
    # yields no batches at all; confirm `evaluate` copes with that.
    val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
    # "max": the LR is reduced when the validation score stops increasing
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
    grad_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    criterion = nn.BCEWithLogitsLoss()

    # setup wandb
    wandb.init(
        project="U-Net-tmp",
        config=dict(
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            amp=use_amp,
            features=features,
            parameters=nb_params,
        ),
    )

    # roughly four gradient/parameter logs per epoch; clamped to >= 1 because a
    # dataset smaller than 4 batches would otherwise give log_freq == 0
    wandb.watch(net, log_freq=max(1, len(ds_train) // args.batch_size // 4))
    _log_model_artifact("model.pth")

    logging.info(
        f"""Starting training:
        Epochs:          {args.epochs}
        Batch size:      {args.batch_size}
        Learning rate:   {args.lr}
        Training size:   {len(ds_train)}
        Validation size: {len(ds_valid)}
        Device:          {device.type}
        Mixed Precision: {use_amp}
    """
    )

    try:
        for epoch in range(1, args.epochs + 1):
            # stays NaN if the training loader yields no batch this epoch
            loss_value = float("nan")

            with tqdm(total=len(ds_train), desc=f"{epoch}/{args.epochs}", unit="img") as pbar:
                # Training round
                for step, (images, true_masks) in enumerate(train_loader):
                    if images.shape[1] != net.n_channels:
                        raise ValueError(
                            f"Network has been defined with {net.n_channels} input channels, "
                            f"but loaded images have {images.shape[1]} channels. Please check that "
                            "the images are loaded correctly."
                        )

                    # transfer images to device; unsqueeze(1) inserts an axis at
                    # index 1, presumably so the target matches the logits' channel
                    # axis -- TODO confirm against SphereDataset's mask shape
                    images = images.to(device=device)
                    true_masks = true_masks.unsqueeze(1).to(device=device)

                    # forward
                    with torch.cuda.amp.autocast(enabled=use_amp):
                        pred_masks = net(images)
                        train_loss = criterion(pred_masks, true_masks)

                    # backward
                    optimizer.zero_grad(set_to_none=True)
                    grad_scaler.scale(train_loss).backward()
                    grad_scaler.step(optimizer)
                    grad_scaler.update()

                    # single host sync; the plain float is reused below so no
                    # graph-attached tensor is handed to tqdm/wandb/logging
                    loss_value = train_loss.item()

                    # update tqdm progress bar
                    pbar.update(images.shape[0])
                    pbar.set_postfix(loss=loss_value)

                    # log training metrics
                    wandb.log(
                        {
                            "train/epoch": epoch - 1 + step / len(train_loader),
                            "train/train_loss": loss_value,
                        }
                    )

            # Evaluation round
            val_score = evaluate(net, val_loader, device)
            scheduler.step(val_score)

            # log validation metrics
            wandb.log(
                {
                    "val/val_score": val_score,
                }
            )

            logging.info(
                f"""Validation ended:
                    Train Loss: {loss_value}
                    Valid Score: {val_score}
                """
            )

            # save weights when epoch end
            torch.save(net.state_dict(), "model.pth")
            _log_model_artifact("model.pth")
            logging.info("model saved!")

        # export model to onnx format (no gradient needed on the tracing input)
        dummy_input = torch.randn(1, net.n_channels, img_size, img_size, device=device)
        torch.onnx.export(net, dummy_input, "model.onnx")

        wandb.run.finish()

    except KeyboardInterrupt:
        # keep the partially-trained weights, then let the interrupt propagate
        torch.save(net.state_dict(), "INTERRUPTED.pth")
        logging.info("Saved interrupt")
        raise


if __name__ == "__main__":
    main()

# TODO: fix all the metrics: loss, accuracy, dice...