diff --git a/src/utils/callback.py b/src/utils/callback.py index 15cd5e5..7adf0e0 100644 --- a/src/utils/callback.py +++ b/src/utils/callback.py @@ -70,6 +70,10 @@ class ArtifactLog(Callback): def on_fit_start(self, trainer, pl_module): self.best = 1 + def on_train_epoch_end(self, trainer, pl_module): + # create checkpoint + torch.save(pl_module.state_dict(), "checkpoints/model.pth") + def on_validation_epoch_start(self, trainer, pl_module): self.dices = []