make trainer an abstract class

This commit is contained in:
limiteinductive 2024-01-11 17:45:40 +01:00 committed by Benjamin Trom
parent eafbc8a99a
commit 14ce2f50f9

View file

@ -1,5 +1,6 @@
import random import random
import time import time
from abc import ABC, abstractmethod
from functools import cached_property, wraps from functools import cached_property, wraps
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Generic, Iterable, TypeVar, cast from typing import Any, Callable, Generic, Iterable, TypeVar, cast
@ -264,7 +265,7 @@ Batch = TypeVar("Batch")
ConfigType = TypeVar("ConfigType", bound=BaseConfig) ConfigType = TypeVar("ConfigType", bound=BaseConfig)
class Trainer(Generic[ConfigType, Batch]): class Trainer(Generic[ConfigType, Batch], ABC):
def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None: def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None:
self.config = config self.config = config
self.clock = TrainingClock( self.clock = TrainingClock(
@ -440,11 +441,13 @@ class Trainer(Generic[ConfigType, Batch]):
self.checkpoints_save_folder = None self.checkpoints_save_folder = None
logger.info("Checkpointing disabled: configure `save_folder` to turn it on.") logger.info("Checkpointing disabled: configure `save_folder` to turn it on.")
@abstractmethod
def load_models(self) -> dict[str, fl.Module]: def load_models(self) -> dict[str, fl.Module]:
raise NotImplementedError("The `load_models` method must be implemented in the subclass.") ...
@abstractmethod
def load_dataset(self) -> Dataset[Batch]: def load_dataset(self) -> Dataset[Batch]:
raise NotImplementedError("The `load_dataset` method must be implemented in the subclass.") ...
@cached_property @cached_property
def dataset(self) -> Dataset[Batch]: def dataset(self) -> Dataset[Batch]:
@ -471,8 +474,9 @@ class Trainer(Generic[ConfigType, Batch]):
assert self.checkpoints_save_folder is not None assert self.checkpoints_save_folder is not None
return self.checkpoints_save_folder return self.checkpoints_save_folder
@abstractmethod
def compute_loss(self, batch: Batch) -> Tensor: def compute_loss(self, batch: Batch) -> Tensor:
raise NotImplementedError("The `compute_loss` method must be implemented in the subclass.") ...
def compute_evaluation(self) -> None: def compute_evaluation(self) -> None:
pass pass