feat: use wandb config instead of args
Former-commit-id: ffb1cb9a6e978c41b3b62388c657ccdb13c4ad67 [formerly d557639e5a203e2ba44ebcf4466c42074f215fa0] Former-commit-id: 0d3dd6a81a66348fd4caa840a2727680554854f3
This commit is contained in:
parent
f4ed2f799e
commit
8c9ed80c6a
187
src/train.py
187
src/train.py
|
@ -1,13 +1,9 @@
|
||||||
import argparse
|
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import albumentations as A
|
import albumentations as A
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import yaml
|
||||||
import torch.onnx
|
|
||||||
from albumentations.pytorch import ToTensorV2
|
from albumentations.pytorch import ToTensorV2
|
||||||
from torch import optim
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
@ -17,98 +13,42 @@ from src.utils.dataset import SphereDataset
|
||||||
from unet import UNet
|
from unet import UNet
|
||||||
from utils.paste import RandomPaste
|
from utils.paste import RandomPaste
|
||||||
|
|
||||||
CHECKPOINT_DIR = Path("./checkpoints/")
|
|
||||||
DIR_TRAIN_IMG = Path("/home/lilian/data_disk/lfainsin/smolval2017")
|
|
||||||
DIR_VALID_IMG = Path("/home/lilian/data_disk/lfainsin/smoltrain2017/")
|
|
||||||
DIR_SPHERE_IMG = Path("/home/lilian/data_disk/lfainsin/spheres/Images/")
|
|
||||||
DIR_SPHERE_MASK = Path("/home/lilian/data_disk/lfainsin/spheres/Masks/")
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Train the UNet on images and target masks",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--epochs",
|
|
||||||
"-e",
|
|
||||||
metavar="E",
|
|
||||||
type=int,
|
|
||||||
default=5,
|
|
||||||
help="Number of epochs",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--batch-size",
|
|
||||||
"-b",
|
|
||||||
dest="batch_size",
|
|
||||||
metavar="B",
|
|
||||||
type=int,
|
|
||||||
default=70,
|
|
||||||
help="Batch size",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--learning-rate",
|
|
||||||
"-l",
|
|
||||||
metavar="LR",
|
|
||||||
type=float,
|
|
||||||
default=1e-5,
|
|
||||||
help="Learning rate",
|
|
||||||
dest="lr",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--load",
|
|
||||||
"-f",
|
|
||||||
type=str,
|
|
||||||
default=False,
|
|
||||||
help="Load model from a .pth file",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--amp",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Use mixed precision",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--classes",
|
|
||||||
"-c",
|
|
||||||
type=int,
|
|
||||||
default=1,
|
|
||||||
help="Number of classes",
|
|
||||||
)
|
|
||||||
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# get args from cli
|
|
||||||
args = get_args()
|
|
||||||
|
|
||||||
# setup logging
|
# setup logging
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
|
|
||||||
# enable cuda, if possible
|
# enable cuda, if possible
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
logging.info(f"Using device {device}")
|
|
||||||
|
|
||||||
# enable cudnn benchmarking
|
# setup wandb
|
||||||
# torch.backends.cudnn.benchmark = True
|
wandb.init(
|
||||||
|
project="U-Net",
|
||||||
# 0. Create network
|
config=dict(
|
||||||
features = [16, 32, 64, 128]
|
n_channels=3,
|
||||||
net = UNet(n_channels=3, n_classes=args.classes, features=features)
|
n_classes=1,
|
||||||
nb_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
epochs=5,
|
||||||
logging.info(
|
batch_size=70,
|
||||||
f"""Network:
|
learning_rate=1e-5,
|
||||||
input channels: {net.n_channels}
|
amp=True,
|
||||||
output channels: {net.n_classes}
|
num_workers=8,
|
||||||
nb parameters: {nb_params}
|
pin_memory=True,
|
||||||
features: {features}
|
features=[16, 32, 64, 128],
|
||||||
"""
|
benchmark=False,
|
||||||
|
device=device.type,
|
||||||
|
DIR_TRAIN_IMG="/home/lilian/data_disk/lfainsin/val2017",
|
||||||
|
DIR_VALID_IMG="/home/lilian/data_disk/lfainsin/smoltrain2017/",
|
||||||
|
DIR_SPHERE_IMG="/home/lilian/data_disk/lfainsin/spheres/Images/",
|
||||||
|
DIR_SPHERE_MASK="/home/lilian/data_disk/lfainsin/spheres/Masks/",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load weights, if needed
|
# enable cudnn benchmarking
|
||||||
if args.load:
|
torch.backends.cudnn.benchmark = wandb.config.benchmark
|
||||||
net.load_state_dict(torch.load(args.load, map_location=device))
|
|
||||||
logging.info(f"Model loaded from {args.load}")
|
# 0. Create network
|
||||||
|
net = UNet(n_channels=3, n_classes=wandb.config.n_classes, features=wandb.config.features)
|
||||||
|
wandb.config.params = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
||||||
|
|
||||||
# save initial model.pth
|
# save initial model.pth
|
||||||
torch.save(net.state_dict(), "model.pth")
|
torch.save(net.state_dict(), "model.pth")
|
||||||
|
@ -122,7 +62,7 @@ def main():
|
||||||
A.Resize(512, 512),
|
A.Resize(512, 512),
|
||||||
A.Flip(),
|
A.Flip(),
|
||||||
A.ColorJitter(),
|
A.ColorJitter(),
|
||||||
RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
|
RandomPaste(5, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
|
||||||
A.GaussianBlur(),
|
A.GaussianBlur(),
|
||||||
A.ISONoise(),
|
A.ISONoise(),
|
||||||
A.ToFloat(max_value=255),
|
A.ToFloat(max_value=255),
|
||||||
|
@ -132,59 +72,50 @@ def main():
|
||||||
tf_valid = A.Compose(
|
tf_valid = A.Compose(
|
||||||
[
|
[
|
||||||
A.Resize(512, 512),
|
A.Resize(512, 512),
|
||||||
RandomPaste(5, DIR_SPHERE_IMG, DIR_SPHERE_MASK),
|
RandomPaste(5, wandb.config.DIR_SPHERE_IMG, wandb.config.DIR_SPHERE_MASK),
|
||||||
A.ToFloat(max_value=255),
|
A.ToFloat(max_value=255),
|
||||||
ToTensorV2(),
|
ToTensorV2(),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Create datasets
|
# 2. Create datasets
|
||||||
ds_train = SphereDataset(image_dir=DIR_TRAIN_IMG, transform=tf_train)
|
ds_train = SphereDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transform=tf_train)
|
||||||
ds_valid = SphereDataset(image_dir=DIR_VALID_IMG, transform=tf_valid)
|
ds_valid = SphereDataset(image_dir=wandb.config.DIR_VALID_IMG, transform=tf_valid)
|
||||||
|
|
||||||
# 3. Create data loaders
|
# 3. Create data loaders
|
||||||
loader_args = dict(batch_size=args.batch_size, num_workers=8, pin_memory=True)
|
loader_args = dict(
|
||||||
|
batch_size=wandb.config.batch_size, num_workers=wandb.config.num_workers, pin_memory=wandb.config.pin_memory
|
||||||
|
)
|
||||||
train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
|
train_loader = DataLoader(ds_train, shuffle=True, **loader_args)
|
||||||
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
|
val_loader = DataLoader(ds_valid, shuffle=False, drop_last=True, **loader_args)
|
||||||
|
|
||||||
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
|
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
|
||||||
optimizer = optim.RMSprop(net.parameters(), lr=args.lr, weight_decay=1e-8, momentum=0.9)
|
optimizer = torch.optim.RMSprop(net.parameters(), lr=wandb.config.learning_rate, weight_decay=1e-8, momentum=0.9)
|
||||||
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
|
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2)
|
||||||
grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
|
grad_scaler = torch.cuda.amp.GradScaler(enabled=wandb.config.amp)
|
||||||
criterion = nn.BCEWithLogitsLoss()
|
criterion = torch.nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
# setup wandb
|
# save model.pth
|
||||||
wandb.init(
|
wandb.watch(net, log_freq=100)
|
||||||
project="U-Net-tmp",
|
artifact = wandb.Artifact("pth", type="model")
|
||||||
config=dict(
|
|
||||||
epochs=args.epochs,
|
|
||||||
batch_size=args.batch_size,
|
|
||||||
learning_rate=args.lr,
|
|
||||||
amp=args.amp,
|
|
||||||
features=features,
|
|
||||||
parameters=nb_params,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
wandb.watch(net, log_freq=len(ds_train) // args.batch_size // 4)
|
|
||||||
artifact = wandb.Artifact("model", type="model")
|
|
||||||
artifact.add_file("model.pth")
|
artifact.add_file("model.pth")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
|
logging.info("model.pth saved")
|
||||||
|
|
||||||
logging.info(
|
# save model.onxx
|
||||||
f"""Starting training:
|
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
||||||
Epochs: {args.epochs}
|
torch.onnx.export(net, dummy_input, "model.onnx")
|
||||||
Batch size: {args.batch_size}
|
artifact = wandb.Artifact("onnx", type="model")
|
||||||
Learning rate: {args.lr}
|
artifact.add_file("model.onnx")
|
||||||
Training size: {len(ds_train)}
|
wandb.run.log_artifact(artifact)
|
||||||
Validation size: {len(ds_valid)}
|
logging.info("model.onnx saved")
|
||||||
Device: {device.type}
|
|
||||||
Mixed Precision: {args.amp}
|
# print the config
|
||||||
"""
|
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for epoch in range(1, args.epochs + 1):
|
for epoch in range(1, wandb.config.epochs + 1):
|
||||||
with tqdm(total=len(ds_train), desc=f"{epoch}/{args.epochs}", unit="img") as pbar:
|
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.epochs}", unit="img") as pbar:
|
||||||
|
|
||||||
# Training round
|
# Training round
|
||||||
for step, (images, true_masks) in enumerate(train_loader):
|
for step, (images, true_masks) in enumerate(train_loader):
|
||||||
|
@ -199,7 +130,7 @@ def main():
|
||||||
true_masks = true_masks.unsqueeze(1).to(device=device)
|
true_masks = true_masks.unsqueeze(1).to(device=device)
|
||||||
|
|
||||||
# forward
|
# forward
|
||||||
with torch.cuda.amp.autocast(enabled=args.amp):
|
with torch.cuda.amp.autocast(enabled=wandb.config.amp):
|
||||||
pred_masks = net(images)
|
pred_masks = net(images)
|
||||||
train_loss = criterion(pred_masks, true_masks)
|
train_loss = criterion(pred_masks, true_masks)
|
||||||
|
|
||||||
|
@ -241,14 +172,18 @@ def main():
|
||||||
|
|
||||||
# save weights when epoch end
|
# save weights when epoch end
|
||||||
torch.save(net.state_dict(), "model.pth")
|
torch.save(net.state_dict(), "model.pth")
|
||||||
artifact = wandb.Artifact("model", type="model")
|
artifact = wandb.Artifact("pth", type="model")
|
||||||
artifact.add_file("model.pth")
|
artifact.add_file("model.pth")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
logging.info(f"model saved!")
|
logging.info("model.pth saved")
|
||||||
|
|
||||||
# export model to onnx format
|
# export model to onnx format
|
||||||
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=True).to(device)
|
||||||
torch.onnx.export(net, dummy_input, "model.onnx")
|
torch.onnx.export(net, dummy_input, "model.onnx")
|
||||||
|
artifact = wandb.Artifact("pnnx", type="model")
|
||||||
|
artifact.add_file("model.onnx")
|
||||||
|
wandb.run.log_artifact(artifact)
|
||||||
|
logging.info("model.onnx saved")
|
||||||
|
|
||||||
wandb.run.finish()
|
wandb.run.finish()
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue