feat: even more wandb logging
Former-commit-id: 1a3c28040a734ca2229e33603405054abc8e3000 [formerly 907e4f7cae3c25a84baf0eaa5ec4d03ddaea0bdb] Former-commit-id: fdfb7dcb7d0573efbff79956e7a4bebfe26e2171
This commit is contained in:
parent
7bdac6583b
commit
2ab95734e4
42
src/train.py
42
src/train.py
|
@ -40,6 +40,9 @@ def main():
|
||||||
IMG_SIZE=512,
|
IMG_SIZE=512,
|
||||||
SPHERES=5,
|
SPHERES=5,
|
||||||
),
|
),
|
||||||
|
settings=wandb.Settings(
|
||||||
|
code_dir="./src/",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# create device
|
# create device
|
||||||
|
@ -51,7 +54,6 @@ def main():
|
||||||
# 0. Create network
|
# 0. Create network
|
||||||
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
|
net = UNet(n_channels=wandb.config.N_CHANNELS, n_classes=wandb.config.N_CLASSES, features=wandb.config.FEATURES)
|
||||||
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
wandb.config.PARAMETERS = sum(p.numel() for p in net.parameters() if p.requires_grad)
|
||||||
wandb.watch(net, log_freq=100) # TODO: 1/4 epochs
|
|
||||||
|
|
||||||
# transfer network to device
|
# transfer network to device
|
||||||
net.to(device=device)
|
net.to(device=device)
|
||||||
|
@ -125,15 +127,11 @@ def main():
|
||||||
artifact.add_file("checkpoints/model-0.onnx")
|
artifact.add_file("checkpoints/model-0.onnx")
|
||||||
wandb.run.log_artifact(artifact)
|
wandb.run.log_artifact(artifact)
|
||||||
|
|
||||||
# print the config
|
# log gradients and weights four time per epoch
|
||||||
logging.info(
|
wandb.watch(net, log_freq=(len(train_loader) + len(val_loader)) // 4)
|
||||||
f"""wandb config:
|
|
||||||
{yaml.dump(wandb.config.as_dict())}
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
# setup wandb table for saving images
|
# print the config
|
||||||
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for epoch in range(1, wandb.config.EPOCHS + 1):
|
for epoch in range(1, wandb.config.EPOCHS + 1):
|
||||||
|
@ -165,6 +163,7 @@ def main():
|
||||||
# compute metrics
|
# compute metrics
|
||||||
pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
|
pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
|
||||||
accuracy = (true_masks == pred_masks_bin).float().mean()
|
accuracy = (true_masks == pred_masks_bin).float().mean()
|
||||||
|
dice = dice_coeff(pred_masks_bin, true_masks)
|
||||||
mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
|
mae = torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
|
||||||
|
|
||||||
# update tqdm progress bar
|
# update tqdm progress bar
|
||||||
|
@ -174,9 +173,10 @@ def main():
|
||||||
# log metrics
|
# log metrics
|
||||||
wandb.log(
|
wandb.log(
|
||||||
{
|
{
|
||||||
"train/epoch": epoch - 1 + step / len(train_loader),
|
"epoch": epoch - 1 + step / len(train_loader),
|
||||||
"train/accuracy": accuracy,
|
"train/accuracy": accuracy,
|
||||||
"train/bce": train_loss,
|
"train/bce": train_loss,
|
||||||
|
"train/dice": dice,
|
||||||
"train/mae": mae,
|
"train/mae": mae,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -184,6 +184,7 @@ def main():
|
||||||
# Evaluation round
|
# Evaluation round
|
||||||
net.eval()
|
net.eval()
|
||||||
accuracy = 0
|
accuracy = 0
|
||||||
|
val_loss = 0
|
||||||
dice = 0
|
dice = 0
|
||||||
mae = 0
|
mae = 0
|
||||||
with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar:
|
with tqdm(val_loader, total=len(ds_valid), desc="val", unit="img", leave=False) as pbar:
|
||||||
|
@ -198,19 +199,22 @@ def main():
|
||||||
masks_pred = net(images)
|
masks_pred = net(images)
|
||||||
|
|
||||||
# compute metrics
|
# compute metrics
|
||||||
|
val_loss += criterion(pred_masks, true_masks)
|
||||||
|
mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks)
|
||||||
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
|
||||||
accuracy += (true_masks == pred_masks_bin).float().sum()
|
accuracy += (true_masks == pred_masks_bin).float().mean()
|
||||||
dice += dice_coeff(masks_pred_bin, masks_true, reduce_batch_first=False)
|
dice += dice_coeff(masks_pred_bin, masks_true)
|
||||||
mae += torch.nn.functional.l1_loss(pred_masks_bin, true_masks, reduction="sum")
|
|
||||||
|
|
||||||
# update progress bar
|
# update progress bar
|
||||||
pbar.update(images.shape[0])
|
pbar.update(images.shape[0])
|
||||||
|
|
||||||
accuracy /= len(ds_valid)
|
accuracy /= len(val_loader)
|
||||||
dice /= len(val_loader) # TODO: fix dice_coeff to not average
|
val_loss /= len(val_loader)
|
||||||
mae /= len(ds_valid)
|
dice /= len(val_loader)
|
||||||
|
mae /= len(val_loader)
|
||||||
|
|
||||||
# save the last validation batch to table
|
# save the last validation batch to table
|
||||||
|
table = wandb.Table(columns=["ID", "image", "ground truth", "prediction"])
|
||||||
for i, (img, mask, pred) in enumerate(
|
for i, (img, mask, pred) in enumerate(
|
||||||
zip(
|
zip(
|
||||||
images.to("cpu"),
|
images.to("cpu"),
|
||||||
|
@ -223,11 +227,13 @@ def main():
|
||||||
# log validation metrics
|
# log validation metrics
|
||||||
wandb.log(
|
wandb.log(
|
||||||
{
|
{
|
||||||
"val/predictions": table,
|
"predictions": table,
|
||||||
"val/accuracy": accuracy,
|
"val/accuracy": accuracy,
|
||||||
|
"val/bce": val_loss,
|
||||||
"val/dice": dice,
|
"val/dice": dice,
|
||||||
"val/mae": mae,
|
"val/mae": mae,
|
||||||
}
|
},
|
||||||
|
commit=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# update hyperparameters
|
# update hyperparameters
|
||||||
|
|
Loading…
Reference in a new issue