diff --git a/src/train.py b/src/train.py index 65183a8..f3bda57 100644 --- a/src/train.py +++ b/src/train.py @@ -1,13 +1,9 @@ -import argparse import logging -from pathlib import Path import albumentations as A import torch -import torch.nn as nn -import torch.onnx +import yaml from albumentations.pytorch import ToTensorV2 -from torch import optim from torch.utils.data import DataLoader from tqdm import tqdm @@ -17,98 +13,42 @@ from src.utils.dataset import SphereDataset from unet import UNet from utils.paste import RandomPaste -CHECKPOINT_DIR = Path("./checkpoints/") -DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017") -DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017/") -DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/") -DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/") - - -def get_args(): - parser = argparse.ArgumentParser( - description="Train the UNet on images and target masks", - ) - parser.add_argument( - "--epochs", - "-e", - metavar="E", - type=int, - default=5, - help="Number of epochs", - ) - parser.add_argument( - "--batch-size", - "-b", - dest="batch_size", - metavar="B", - type=int, - default=70, - help="Batch size", - ) - parser.add_argument( - "--learning-rate", - "-l", - metavar="LR", - type=float, - default=1e-5, - help="Learning rate", - dest="lr", - ) - parser.add_argument( - "--load", - "-f", - type=str, - default=False, - help="Load model from a .pth file", - ) - parser.add_argument( - "--amp", - action="store_true", - default=True, - help="Use mixed precision", - ) - parser.add_argument( - "--classes", - "-c", - type=int, - default=1, - help="Number of classes", - ) - - return parser.parse_args() - def main(): - # get args from cli - args = get_args() - # setup logging logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") # enable cuda, if possible device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - logging.info(f"Using device {device}") - # enable cudnn benchmarking - # torch.backends.cudnn.benchmark = True - - # 0. Create network - features = [16, 32, 64, 128] - net = UNet(n_channels=3, n_classes=args.classes, features=features) - nb_params = sum(p.numel() for p in net.parameters() if p.requires_grad) - logging.info( - f"""Network: - input channels: {net.n_channels} - output channels: {net.n_classes} - nb parameters: {nb_params} - features: {features} - """ + # setup wandb + wandb.init( + project="U-Net", + config=dict( + n_channels=3, + n_classes=1, + epochs=5, + batch_size=70, + learning_rate=1e-5, + amp=True, + num_workers=8, + pin_memory=True, + features=[16, 32, 64, 128], + benchmark=False, + device=device.type, + DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017", + DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/", + DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/", + DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/", + ), ) - # Load weights, if needed - if args.load: - net.load_state_dict(torch.load(args.load, map_location=device)) - logging.info(f"Model loaded from {args.load}") + # enable cudnn benchmarking + torch.backends.cudnn.benchmark = wandb.config.benchmark + + # 0. Create network + net = UNet(n_channels=3, n_classes=wandb.config.n_classes, features=wandb.config.features) + wandb.config.params = sum(p.numel() for p in net.parameters() if p.requires_grad) # save initial model.pth torch.save(net.state_dict(), "model.pth") @@ -122,7 +62,7 @@ def main(): A.Resize(512, 512), A.Flip(), A.ColorJitter(), - RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK), + RandomPaste(5, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), A.GaussianBlur(), A.ISONoise(), A.ToFloat(max_value=255), @@ -132,59 +72,50 @@ def main(): tf_valid = A.Compose( [ A.Resize(512, 512), - RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK), + RandomPaste(5, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK), A.ToFloat(max_value=255), ToTensorV2(), ], ) # 2. Create datasets - ds_train = SphereDataset(image_dir=DIR_TRAIN_IMG, transform=tf_train) - ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid) + ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train) + ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid) # 3. Create data loaders - loader_args = dict(batch_size=args.batch_size, num_workers=8, pin_memory=True) + loader_args = dict( + batch_size=wandb.config.batch_size, num_workers=wandb.config.num_workers, pin_memory=wandb.config.pin_memory + ) train_loader = DataLoader(ds_train, shuffle=True, **loader_args) val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args) # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP - optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9) - scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) - grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp) - criterion = nn.BCEWithLogitsLoss() + optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.learning_rate, weight_decay=1e-8, momentum=0.9) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) + grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.amp) + criterion = torch.nn.BCEWithLogitsLoss() - # setup wandb - wandb.init( - project="U-Net-tmp", - config=dict( - epochs=args.epochs, - batch_size=args.batch_size, - learning_rate=args.lr, - amp=args.amp, - features=features, - parameters=nb_params, - ), - ) - wandb.watch(net, log_freq=len(ds_train) // args.batch_size // 4) - artifact = wandb.Artifact("model", type="model") + # save model.pth + wandb.watch(net, log_freq=100) + artifact = wandb.Artifact("pth", type="model") artifact.add_file("model.pth") wandb.run.log_artifact(artifact) + logging.info("model.pth saved") - logging.info( - f"""Starting training: - Epochs: {args.epochs} - Batch size: {args.batch_size} - Learning rate: {args.lr} - Training size: {len(ds_train)} - Validation size: {len(ds_valid)} - Device: {device.type} - Mixed Precision: {args.amp} - """ - ) + # save model.onxx + dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) + torch.onnx.export(net, dummy_input, "model.onnx") + artifact = wandb.Artifact("onnx", type="model") + artifact.add_file("model.onnx") + wandb.run.log_artifact(artifact) + logging.info("model.onnx saved") + + # print the config + logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}") try: - for epoch in range(1, args.epochs + 1): - with tqdm(total=len(ds_train), desc=f"{epoch}/{args.epochs}", unit="img") as pbar: + for epoch in range(1, wandb.config.epochs + 1): + with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.epochs}", unit="img") as pbar: # Training round for step, (images, true_masks) in enumerate(train_loader): @@ -199,7 +130,7 @@ def main(): true_masks = true_masks.unsqueeze(1).to(device=device) # forward - with torch.cuda.amp.autocast(enabled=args.amp): + with torch.cuda.amp.autocast(enabled=wandb.config.amp): pred_masks = net(images) train_loss = criterion(pred_masks, true_masks) @@ -241,14 +172,18 @@ def main(): # save weights when epoch end torch.save(net.state_dict(), "model.pth") - artifact = wandb.Artifact("model", type="model") + artifact = wandb.Artifact("pth", type="model") artifact.add_file("model.pth") wandb.run.log_artifact(artifact) - logging.info(f"model saved!") + logging.info("model.pth saved") # export model to onnx format dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) torch.onnx.export(net, dummy_input, "model.onnx") + artifact = wandb.Artifact("pnnx", type="model") + artifact.add_file("model.onnx") + wandb.run.log_artifact(artifact) + logging.info("model.onnx saved") wandb.run.finish()