feat: export to onnx

Former-commit-id: fd7e5a5ab785263a16381545ca31fd9e7fe86743 [formerly 10fdf9732fbcf4d922d945adc625e948e5f6e775]
Former-commit-id: 871745033b59e626fc38b38bfc8685c6a6366ecf
This commit is contained in:
Laurent Fainsin 2022-06-30 11:44:20 +02:00
parent 81938b944e
commit dc4a399c0f
2 changed files with 30 additions and 11 deletions

2
.gitignore vendored
View file

@@ -6,5 +6,7 @@ wandb/
images/ images/
*.pth *.pth
*.onnx
*.png *.png
*.jpg *.jpg

View file

@@ -5,6 +5,7 @@ from pathlib import Path
import albumentations as A import albumentations as A
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.onnx
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from torch import optim from torch import optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
@@ -17,7 +18,7 @@ from unet import UNet
from utils.paste import RandomPaste from utils.paste import RandomPaste
CHECKPOINT_DIR = Path("./checkpoints/") CHECKPOINT_DIR = Path("./checkpoints/")
DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/val2017") DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017")
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017/") DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017/")
DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/") DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/") DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")
@@ -89,16 +90,17 @@ def main():
logging.info(f"Using device {device}") logging.info(f"Using device {device}")
# enable cudnn benchmarking # enable cudnn benchmarking
torch.backends.cudnn.benchmark = True # torch.backends.cudnn.benchmark = True
# 0. Create network # 0. Create network
features = [16, 32, 64, 128] features = [16, 32, 64, 128]
net = UNet(n_channels=args.n_channels, n_classes=args.classes, features=features) net = UNet(n_channels=3, n_classes=args.classes, features=features)
nb_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
logging.info( logging.info(
f"""Network: f"""Network:
input channels: {net.n_channels} input channels: {net.n_channels}
output channels: {net.n_classes} output channels: {net.n_classes}
nb parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad)} nb parameters: {nb_params}
features: {features} features: {features}
""" """
) )
@@ -152,19 +154,21 @@ def main():
criterion = nn.BCEWithLogitsLoss() criterion = nn.BCEWithLogitsLoss()
# setup wandb # setup wandb
run = wandb.init( wandb.init(
project="U-Net-tmp", project="U-Net-tmp",
config=dict( config=dict(
epochs=args.epochs, epochs=args.epochs,
batch_size=args.batch_size, batch_size=args.batch_size,
learning_rate=args.lr, learning_rate=args.lr,
amp=args.amp, amp=args.amp,
features=features,
parameters=nb_params,
), ),
) )
wandb.watch(net, log_freq=100) wandb.watch(net, log_freq=len(ds_train) // args.batch_size // 4)
artifact_model = wandb.Artifact("model", type="model") artifact = wandb.Artifact("model", type="model")
artifact_model.add_file("model.pth") artifact.add_file("model.pth")
run.log_artifact(artifact_model) wandb.run.log_artifact(artifact)
logging.info( logging.info(
f"""Starting training: f"""Starting training:
@@ -228,13 +232,25 @@ def main():
} }
) )
print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}") logging.info(
f"""Validation ended:
Train Loss: {train_loss}
Valid Score: {val_score}
"""
)
# save weights when epoch end # save weights when epoch end
torch.save(net.state_dict(), "model.pth") torch.save(net.state_dict(), "model.pth")
artifact = wandb.Artifact("model", type="model")
artifact.add_file("model.pth")
wandb.run.log_artifact(artifact)
logging.info(f"model saved!") logging.info(f"model saved!")
run.finish() # export model to onnx format
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
torch.onnx.export(net, dummy_input, "model.onnx")
wandb.run.finish()
except KeyboardInterrupt: except KeyboardInterrupt:
torch.save(net.state_dict(), "INTERRUPTED.pth") torch.save(net.state_dict(), "INTERRUPTED.pth")
@@ -244,3 +260,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()
# TODO: fix toutes les metrics, loss, accuracy, dice...