feat: use wandb config instead of args

Former-commit-id: ffb1cb9a6e978c41b3b62388c657ccdb13c4ad67 [formerly d557639e5a203e2ba44ebcf4466c42074f215fa0]
Former-commit-id: 0d3dd6a81a66348fd4caa840a2727680554854f3
This commit is contained in:
Laurent Fainsin 2022-06-30 14:04:02 +02:00
parent f4ed2f799e
commit 8c9ed80c6a

View file

@ -1,13 +1,9 @@
import argparse
import logging import logging
from pathlib import Path
import albumentations as A import albumentations as A
import torch import torch
import torch.nn as nn import yaml
import torch.onnx
from albumentations.pytorch import ToTensorV2 from albumentations.pytorch import ToTensorV2
from torch import optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import tqdm from tqdm import tqdm
@ -17,98 +13,42 @@ from src.utils.dataset import SphereDataset
from unet import UNet from unet import UNet
from utils.paste import RandomPaste from utils.paste import RandomPaste
CHECKPOINT_DIR = Path("./checkpoints/")
DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017")
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017/")
DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")
def get_args(argv=None):
    """Parse the command-line options of the UNet training script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the process arguments are read, exactly as
            before this parameter existed.

    Returns:
        An ``argparse.Namespace`` exposing ``epochs``, ``batch_size``,
        ``lr``, ``load``, ``amp`` and ``classes``.
    """
    parser = argparse.ArgumentParser(
        description="Train the UNet on images and target masks",
    )
    parser.add_argument(
        "--epochs",
        "-e",
        metavar="E",
        type=int,
        default=5,
        help="Number of epochs",
    )
    parser.add_argument(
        "--batch-size",
        "-b",
        dest="batch_size",
        metavar="B",
        type=int,
        default=70,
        help="Batch size",
    )
    parser.add_argument(
        "--learning-rate",
        "-l",
        metavar="LR",
        type=float,
        default=1e-5,
        help="Learning rate",
        dest="lr",
    )
    # The default stays ``False`` (a falsy "nothing to load" marker) so that
    # existing ``if args.load:`` checks keep behaving the same way.
    parser.add_argument(
        "--load",
        "-f",
        type=str,
        default=False,
        help="Load model from a .pth file",
    )
    # ``--amp`` alone could never switch mixed precision off because it is a
    # ``store_true`` flag whose default is already True. ``--no-amp`` writes
    # to the same destination to make the option actually controllable.
    parser.add_argument(
        "--amp",
        dest="amp",
        action="store_true",
        default=True,
        help="Use mixed precision",
    )
    parser.add_argument(
        "--no-amp",
        dest="amp",
        action="store_false",
        help="Disable mixed precision",
    )
    parser.add_argument(
        "--classes",
        "-c",
        type=int,
        default=1,
        help="Number of classes",
    )

    return parser.parse_args(argv)
def main(): def main():
# get args from cli
args = get_args()
# setup logging # setup logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
# enable cuda, if possible # enable cuda, if possible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device {device}")
# enable cudnn benchmarking # setup wandb
# torch.backends.cudnn.benchmark = True wandb.init(
project="U-Net",
# 0. Create network config=dict(
features = [16, 32, 64, 128] n_channels=3,
net = UNet(n_channels=3, n_classes=args.classes, features=features) n_classes=1,
nb_params = sum(p.numel() for p in net.parameters() if p.requires_grad) epochs=5,
logging.info( batch_size=70,
f"""Network: learning_rate=1e-5,
input channels: {net.n_channels} amp=True,
output channels: {net.n_classes} num_workers=8,
nb parameters: {nb_params} pin_memory=True,
features: {features} features=[16, 32, 64, 128],
""" benchmark=False,
device=device.type,
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017",
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
),
) )
# Load weights, if needed # enable cudnn benchmarking
if args.load: torch.backends.cudnn.benchmark = wandb.config.benchmark
net.load_state_dict(torch.load(args.load, map_location=device))
logging.info(f"Model loaded from {args.load}") # 0. Create network
net = UNet(n_channels=3, n_classes=wandb.config.n_classes, features=wandb.config.features)
wandb.config.params = sum(p.numel() for p in net.parameters() if p.requires_grad)
# save initial model.pth # save initial model.pth
torch.save(net.state_dict(), "model.pth") torch.save(net.state_dict(), "model.pth")
@ -122,7 +62,7 @@ def main():
A.Resize(512, 512), A.Resize(512, 512),
A.Flip(), A.Flip(),
A.ColorJitter(), A.ColorJitter(),
RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK), RandomPaste(5, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
A.GaussianBlur(), A.GaussianBlur(),
A.ISONoise(), A.ISONoise(),
A.ToFloat(max_value=255), A.ToFloat(max_value=255),
@ -132,59 +72,50 @@ def main():
tf_valid = A.Compose( tf_valid = A.Compose(
[ [
A.Resize(512, 512), A.Resize(512, 512),
RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK), RandomPaste(5, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
A.ToFloat(max_value=255), A.ToFloat(max_value=255),
ToTensorV2(), ToTensorV2(),
], ],
) )
# 2. Create datasets # 2. Create datasets
ds_train = SphereDataset(image_dir=DIR_TRAIN_IMG, transform=tf_train) ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid) ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
# 3. Create data loaders # 3. Create data loaders
loader_args = dict(batch_size=args.batch_size, num_workers=8, pin_memory=True) loader_args = dict(
batch_size=wandb.config.batch_size, num_workers=wandb.config.num_workers, pin_memory=wandb.config.pin_memory
)
train_loader = DataLoader(ds_train, shuffle=True, **loader_args) train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args) val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9) optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.learning_rate, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp) grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.amp)
criterion = nn.BCEWithLogitsLoss() criterion = torch.nn.BCEWithLogitsLoss()
# setup wandb # save model.pth
wandb.init( wandb.watch(net, log_freq=100)
project="U-Net-tmp", artifact = wandb.Artifact("pth", type="model")
config=dict(
epochs=args.epochs,
batch_size=args.batch_size,
learning_rate=args.lr,
amp=args.amp,
features=features,
parameters=nb_params,
),
)
wandb.watch(net, log_freq=len(ds_train) // args.batch_size // 4)
artifact = wandb.Artifact("model", type="model")
artifact.add_file("model.pth") artifact.add_file("model.pth")
wandb.run.log_artifact(artifact) wandb.run.log_artifact(artifact)
logging.info("model.pth saved")
logging.info( # save model.onnx
f"""Starting training: dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
Epochs: {args.epochs} torch.onnx.export(net, dummy_input, "model.onnx")
Batch size: {args.batch_size} artifact = wandb.Artifact("onnx", type="model")
Learning rate: {args.lr} artifact.add_file("model.onnx")
Training size: {len(ds_train)} wandb.run.log_artifact(artifact)
Validation size: {len(ds_valid)} logging.info("model.onnx saved")
Device: {device.type}
Mixed Precision: {args.amp} # print the config
""" logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
)
try: try:
for epoch in range(1, args.epochs + 1): for epoch in range(1, wandb.config.epochs + 1):
with tqdm(total=len(ds_train), desc=f"{epoch}/{args.epochs}", unit="img") as pbar: with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.epochs}", unit="img") as pbar:
# Training round # Training round
for step, (images, true_masks) in enumerate(train_loader): for step, (images, true_masks) in enumerate(train_loader):
@ -199,7 +130,7 @@ def main():
true_masks = true_masks.unsqueeze(1).to(device=device) true_masks = true_masks.unsqueeze(1).to(device=device)
# forward # forward
with torch.cuda.amp.autocast(enabled=args.amp): with torch.cuda.amp.autocast(enabled=wandb.config.amp):
pred_masks = net(images) pred_masks = net(images)
train_loss = criterion(pred_masks, true_masks) train_loss = criterion(pred_masks, true_masks)
@ -241,14 +172,18 @@ def main():
# save weights when epoch end # save weights when epoch end
torch.save(net.state_dict(), "model.pth") torch.save(net.state_dict(), "model.pth")
artifact = wandb.Artifact("model", type="model") artifact = wandb.Artifact("pth", type="model")
artifact.add_file("model.pth") artifact.add_file("model.pth")
wandb.run.log_artifact(artifact) wandb.run.log_artifact(artifact)
logging.info(f"model saved!") logging.info("model.pth saved")
# export model to onnx format # export model to onnx format
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device) dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
torch.onnx.export(net, dummy_input, "model.onnx") torch.onnx.export(net, dummy_input, "model.onnx")
artifact = wandb.Artifact("pnnx", type="model")
artifact.add_file("model.onnx")
wandb.run.log_artifact(artifact)
logging.info("model.onnx saved")
wandb.run.finish() wandb.run.finish()