diff --git a/src/unet/model.py b/src/unet/model.py index be5712e..008d784 100644 --- a/src/unet/model.py +++ b/src/unet/model.py @@ -83,7 +83,10 @@ class UNet(pl.LightningModule): ), ) - wandb.log({log_key: table}) # replace by self.log + wandb.log( + {log_key: table}, + commit=False, + ) # replace by self.log def training_step(self, batch, batch_idx): # unpacking @@ -153,7 +156,7 @@ class UNet(pl.LightningModule): mae = torch.stack([d["mae"] for d in validation_outputs]).mean() # logging - wandb.log( + self.log_dict( { "val/accuracy": accuracy, "val/bce": loss,