diff --git a/.gitignore b/.gitignore index da4c1bc..32a40fb 100644 --- a/.gitignore +++ b/.gitignore @@ -6,5 +6,7 @@ wandb/ images/ *.pth +*.onnx + *.png *.jpg diff --git a/src/train.py b/src/train.py index 49f6dc5..65183a8 100644 --- a/src/train.py +++ b/src/train.py @@ -5,6 +5,7 @@ from pathlib import Path import albumentations as A import torch import torch.nn as nn +import torch.onnx from albumentations.pytorch import ToTensorV2 from torch import optim from torch.utils.data import DataLoader @@ -17,7 +18,7 @@ from unet import UNet from utils.paste import RandomPaste CHECKPOINT_DIR = Path("./checkpoints/") -DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/val2017") +DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017") DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017/") DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/") DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/") @@ -89,16 +90,17 @@ def main(): logging.info(f"Using device {device}") # enable cudnn benchmarking - torch.backends.cudnn.benchmark = True + # torch.backends.cudnn.benchmark = True # 0. 
Create network features = [16, 32, 64, 128] - net = UNet(n_channels=args.n_channels, n_classes=args.classes, features=features) + net = UNet(n_channels=3, n_classes=args.classes, features=features) + nb_params = sum(p.numel() for p in net.parameters() if p.requires_grad) logging.info( f"""Network: input channels: {net.n_channels} output channels: {net.n_classes} - nb parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad)} + nb parameters: {nb_params} features: {features} """ ) @@ -152,19 +154,21 @@ def main(): criterion = nn.BCEWithLogitsLoss() # setup wandb - run = wandb.init( + wandb.init( project="U-Net-tmp", config=dict( epochs=args.epochs, batch_size=args.batch_size, learning_rate=args.lr, amp=args.amp, + features=features, + parameters=nb_params, ), ) - wandb.watch(net, log_freq=100) - artifact_model = wandb.Artifact("model", type="model") - artifact_model.add_file("model.pth") - run.log_artifact(artifact_model) + wandb.watch(net, log_freq=len(ds_train) // args.batch_size // 4) + artifact = wandb.Artifact("model", type="model") + artifact.add_file("model.pth") + wandb.run.log_artifact(artifact) logging.info( f"""Starting training: @@ -228,13 +232,25 @@ def main(): } ) - print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}") + logging.info( + f"""Validation ended: + Train Loss: {train_loss} + Valid Score: {val_score} + """ + ) # save weights when epoch end torch.save(net.state_dict(), "model.pth") + artifact = wandb.Artifact("model", type="model") + artifact.add_file("model.pth") + wandb.run.log_artifact(artifact) logging.info(f"model saved!") - run.finish() + # export model to onnx format + dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) + torch.onnx.export(net, dummy_input, "model.onnx") + + wandb.run.finish() except KeyboardInterrupt: torch.save(net.state_dict(), "INTERRUPTED.pth") @@ -244,3 +260,4 @@ def main(): if __name__ == "__main__": main() + # TODO: fix all the metrics, loss, accuracy, dice...