fix: retyped the deleted inital logging

Former-commit-id: 219b4a406826f90145d6806d1f1d438d9bd44282 [formerly abe23af73cef108a97791ee21218f7a121f60ac2]
Former-commit-id: c1f0d82783ec3ee4ee1518c10d2b9a9c91e6ec26
This commit is contained in:
Laurent Fainsin 2022-07-04 14:50:29 +02:00
parent 0d6f85518e
commit c69579f9da

View file

@ -147,6 +147,14 @@ if __name__ == "__main__":
# print the config
logging.info(f"wandb config:\n{yaml.dump(wandb.config.as_dict())}")
# wandb init log
wandb.log(
{
"train/learning_rate": optimizer.state_dict()["param_groups"][0]["lr"],
},
commit=False,
)
try:
for epoch in range(1, wandb.config.EPOCHS + 1):
with tqdm(total=len(ds_train), desc=f"{epoch}/{wandb.config.EPOCHS}", unit="img") as pbar:
@ -166,7 +174,6 @@ if __name__ == "__main__":
# forward
with torch.cuda.amp.autocast(enabled=wandb.config.AMP):
pred_masks = net(images)
train_loss = criterion(pred_masks, true_masks)
# backward
optimizer.zero_grad(set_to_none=True)
@ -175,6 +182,7 @@ if __name__ == "__main__":
grad_scaler.update()
# compute metrics
train_loss = criterion(pred_masks, true_masks)
pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float()
accuracy = (true_masks == pred_masks_bin).float().mean()
dice = dice_coeff(pred_masks_bin, true_masks)
@ -214,7 +222,7 @@ if __name__ == "__main__":
masks_pred = net(images)
# compute metrics
val_loss += criterion(pred_masks, masks_true)
val_loss += criterion(masks_pred, masks_true)
masks_pred_bin = (torch.sigmoid(masks_pred) > 0.5).float()
mae += torch.nn.functional.l1_loss(masks_pred_bin, masks_true)
accuracy += (masks_true == masks_pred_bin).float().mean()