From 14ce2f50f95561d5ee54d1bb3389d348e5947a26 Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Thu, 11 Jan 2024 17:45:40 +0100 Subject: [PATCH] make trainer an abstract class --- src/refiners/training_utils/trainer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 730bb8a..40cad58 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -1,5 +1,6 @@ import random import time +from abc import ABC, abstractmethod from functools import cached_property, wraps from pathlib import Path from typing import Any, Callable, Generic, Iterable, TypeVar, cast @@ -264,7 +265,7 @@ Batch = TypeVar("Batch") ConfigType = TypeVar("ConfigType", bound=BaseConfig) -class Trainer(Generic[ConfigType, Batch]): +class Trainer(Generic[ConfigType, Batch], ABC): def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None: self.config = config self.clock = TrainingClock( @@ -440,11 +441,13 @@ class Trainer(Generic[ConfigType, Batch]): self.checkpoints_save_folder = None logger.info("Checkpointing disabled: configure `save_folder` to turn it on.") + @abstractmethod def load_models(self) -> dict[str, fl.Module]: - raise NotImplementedError("The `load_models` method must be implemented in the subclass.") + ... + @abstractmethod def load_dataset(self) -> Dataset[Batch]: - raise NotImplementedError("The `load_dataset` method must be implemented in the subclass.") + ... @cached_property def dataset(self) -> Dataset[Batch]: @@ -471,8 +474,9 @@ class Trainer(Generic[ConfigType, Batch]): assert self.checkpoints_save_folder is not None return self.checkpoints_save_folder + @abstractmethod def compute_loss(self, batch: Batch) -> Tensor: - raise NotImplementedError("The `compute_loss` method must be implemented in the subclass.") + ... def compute_evaluation(self) -> None: pass