mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
make trainer an abstract class
This commit is contained in:
parent
eafbc8a99a
commit
14ce2f50f9
|
@ -1,5 +1,6 @@
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from functools import cached_property, wraps
|
from functools import cached_property, wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
||||||
|
@ -264,7 +265,7 @@ Batch = TypeVar("Batch")
|
||||||
ConfigType = TypeVar("ConfigType", bound=BaseConfig)
|
ConfigType = TypeVar("ConfigType", bound=BaseConfig)
|
||||||
|
|
||||||
|
|
||||||
class Trainer(Generic[ConfigType, Batch]):
|
class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None:
|
def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self.clock = TrainingClock(
|
self.clock = TrainingClock(
|
||||||
|
@ -440,11 +441,13 @@ class Trainer(Generic[ConfigType, Batch]):
|
||||||
self.checkpoints_save_folder = None
|
self.checkpoints_save_folder = None
|
||||||
logger.info("Checkpointing disabled: configure `save_folder` to turn it on.")
|
logger.info("Checkpointing disabled: configure `save_folder` to turn it on.")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def load_models(self) -> dict[str, fl.Module]:
|
def load_models(self) -> dict[str, fl.Module]:
|
||||||
raise NotImplementedError("The `load_models` method must be implemented in the subclass.")
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def load_dataset(self) -> Dataset[Batch]:
|
def load_dataset(self) -> Dataset[Batch]:
|
||||||
raise NotImplementedError("The `load_dataset` method must be implemented in the subclass.")
|
...
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def dataset(self) -> Dataset[Batch]:
|
def dataset(self) -> Dataset[Batch]:
|
||||||
|
@ -471,8 +474,9 @@ class Trainer(Generic[ConfigType, Batch]):
|
||||||
assert self.checkpoints_save_folder is not None
|
assert self.checkpoints_save_folder is not None
|
||||||
return self.checkpoints_save_folder
|
return self.checkpoints_save_folder
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
def compute_loss(self, batch: Batch) -> Tensor:
|
def compute_loss(self, batch: Batch) -> Tensor:
|
||||||
raise NotImplementedError("The `compute_loss` method must be implemented in the subclass.")
|
...
|
||||||
|
|
||||||
def compute_evaluation(self) -> None:
|
def compute_evaluation(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in a new issue