mirror of
https://github.com/Laurent2916/REVA-QCAV.git
synced 2024-11-12 16:18:25 +00:00
feat: export to onnx
Former-commit-id: fd7e5a5ab785263a16381545ca31fd9e7fe86743 [formerly 10fdf9732fbcf4d922d945adc625e948e5f6e775] Former-commit-id: 871745033b59e626fc38b38bfc8685c6a6366ecf
This commit is contained in:
parent
81938b944e
commit
dc4a399c0f
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -6,5 +6,7 @@ wandb/
|
||||||
images/
|
images/
|
||||||
|
|
||||||
*.pth
|
*.pth
|
||||||
|
*.onnx
|
||||||
|
|
||||||
*.png
|
*.png
|
||||||
*.jpg
|
*.jpg
|
||||||
|
|
39
src/train.py
39
src/train.py
|
@ -5,6 +5,7 @@ from pathlib import Path
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.onnx
|
||||||
from albumentations.pytorch import ToTensorV2
|
from albumentations.pytorch import ToTensorV2
|
||||||
from torch import optim
|
from torch import optim
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
@ -17,7 +18,7 @@ from unet import UNet
|
||||||
from utils.paste import RandomPaste
|
from utils.paste import RandomPaste
|
||||||
|
|
||||||
CHECKPOINT_DIR = Path("./checkpoints/")
|
CHECKPOINT_DIR = Path("./checkpoints/")
|
||||||
DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/val2017")
|
DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017")
|
||||||
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017/")
|
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017/")
|
||||||
DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
|
DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
|
||||||
DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")
|
DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")
|
||||||
|
@ -89,16 +90,17 @@ def main():
|
||||||
logging.info(f"Using device {device}")
|
logging.info(f"Using device {device}")
|
||||||
|
|
||||||
# enable cudnn benchmarking
|
# enable cudnn benchmarking
|
||||||
torch.backends.cudnn.benchmark = True
|
# torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
# 0. Create network
|
# 0. Create network
|
||||||
features = [16, 32, 64, 128]
|
features = [16, 32, 64, 128]
|
||||||
net = UNet(n_channels=args.n_channels, n_classes=args.classes, features=features)
|
net = UNet(n_channels=3, n_classes=args.classes, features=features)
|
||||||
|
nb_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
||||||
logging.info(
|
logging.info(
|
||||||
f"""Network:
|
f"""Network:
|
||||||
input channels: {net.n_channels}
|
input channels: {net.n_channels}
|
||||||
output channels: {net.n_classes}
|
output channels: {net.n_classes}
|
||||||
nb parameters: {sum(p.numel() for p in net.parameters() if p.requires_grad)}
|
nb parameters: {nb_params}
|
||||||
features: {features}
|
features: {features}
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
@ -152,19 +154,21 @@ def main():
|
||||||
criterion = nn.BCEWithLogitsLoss()
|
criterion = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
# setup wandb
|
# setup wandb
|
||||||
run = wandb.init(
|
wandb.init(
|
||||||
project="U-Net-tmp",
|
project="U-Net-tmp",
|
||||||
config=dict(
|
config=dict(
|
||||||
epochs=args.epochs,
|
epochs=args.epochs,
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
learning_rate=args.lr,
|
learning_rate=args.lr,
|
||||||
amp=args.amp,
|
amp=args.amp,
|
||||||
|
features=features,
|
||||||
|
parameters=nb_params,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
wandb.watch(net, log_freq=100)
|
wandb.watch(net, log_freq=len(ds_train) // args.batch_size // 4)
|
||||||
artifact_model = wandb.Artifact("model", type="model")
|
artifact = wandb.Artifact("model", type="model")
|
||||||
artifact_model.add_file("model.pth")
|
artifact.add_file("model.pth")
|
||||||
run.log_artifact(artifact_model)
|
wandb.run.log_artifact(artifact)
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"""Starting training:
|
f"""Starting training:
|
||||||
|
@ -228,13 +232,25 @@ def main():
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Train Loss: {train_loss:.3f}, Valid Score: {val_score:3f}")
|
logging.info(
|
||||||
|
f"""Validation ended:
|
||||||
|
Train Loss: {train_loss}
|
||||||
|
Valid Score: {val_score}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
# save weights when epoch end
|
# save weights when epoch end
|
||||||
torch.save(net.state_dict(), "model.pth")
|
torch.save(net.state_dict(), "model.pth")
|
||||||
|
artifact = wandb.Artifact("model", type="model")
|
||||||
|
artifact.add_file("model.pth")
|
||||||
|
wandb.run.log_artifact(artifact)
|
||||||
logging.info(f"model saved!")
|
logging.info(f"model saved!")
|
||||||
|
|
||||||
run.finish()
|
# export model to onnx format
|
||||||
|
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
||||||
|
torch.onnx.export(net, dummy_input, "model.onnx")
|
||||||
|
|
||||||
|
wandb.run.finish()
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
torch.save(net.state_dict(), "INTERRUPTED.pth")
|
torch.save(net.state_dict(), "INTERRUPTED.pth")
|
||||||
|
@ -244,3 +260,4 @@ def main():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
# TODO: fix toutes les metrics, loss, accuracy, dice...
|
||||||
|
|
Loading…
Reference in a new issue