From 55c4d05d2af65aceae1c2d5ec3d1d302f3e5df62 Mon Sep 17 00:00:00 2001 From: Laurent Fainsin Date: Mon, 4 Jul 2022 15:38:49 +0200 Subject: [PATCH] fix: moved a line wrongly Former-commit-id: 93f2c7ee9a1926248716328dc9dd5a46395373be [formerly 598da90e69422bb29bf0c2770b789773c6360981] Former-commit-id: 354585b269180f4785f8bd05164eb8c8b0d16dba --- src/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/train.py b/src/train.py index b82462c..9ddac8e 100644 --- a/src/train.py +++ b/src/train.py @@ -174,6 +174,7 @@ if __name__ == "__main__": # forward with torch.cuda.amp.autocast(enabled=wandb.config.AMP): pred_masks = net(images) + train_loss = criterion(pred_masks, true_masks) # backward optimizer.zero_grad(set_to_none=True) @@ -182,7 +183,6 @@ if __name__ == "__main__": grad_scaler.update() # compute metrics - train_loss = criterion(pred_masks, true_masks) pred_masks_bin = (torch.sigmoid(pred_masks) > 0.5).float() accuracy = (true_masks == pred_masks_bin).float().mean() dice = dice_coeff(pred_masks_bin, true_masks) @@ -203,7 +203,7 @@ if __name__ == "__main__": } ) - if step and (step % 100 == 0 or step == len(train_loader)): + if step and (step % 250 == 0 or step == len(train_loader)): # Evaluation round net.eval() accuracy = 0