mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
make trainer an abstract class
This commit is contained in:
parent
eafbc8a99a
commit
14ce2f50f9
|
@ -1,5 +1,6 @@
|
|||
import random
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property, wraps
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Generic, Iterable, TypeVar, cast
|
||||
|
@ -264,7 +265,7 @@ Batch = TypeVar("Batch")
|
|||
ConfigType = TypeVar("ConfigType", bound=BaseConfig)
|
||||
|
||||
|
||||
class Trainer(Generic[ConfigType, Batch]):
|
||||
class Trainer(Generic[ConfigType, Batch], ABC):
|
||||
def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None:
|
||||
self.config = config
|
||||
self.clock = TrainingClock(
|
||||
|
@ -440,11 +441,13 @@ class Trainer(Generic[ConfigType, Batch]):
|
|||
self.checkpoints_save_folder = None
|
||||
logger.info("Checkpointing disabled: configure `save_folder` to turn it on.")
|
||||
|
||||
@abstractmethod
|
||||
def load_models(self) -> dict[str, fl.Module]:
|
||||
raise NotImplementedError("The `load_models` method must be implemented in the subclass.")
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def load_dataset(self) -> Dataset[Batch]:
|
||||
raise NotImplementedError("The `load_dataset` method must be implemented in the subclass.")
|
||||
...
|
||||
|
||||
@cached_property
|
||||
def dataset(self) -> Dataset[Batch]:
|
||||
|
@ -471,8 +474,9 @@ class Trainer(Generic[ConfigType, Batch]):
|
|||
assert self.checkpoints_save_folder is not None
|
||||
return self.checkpoints_save_folder
|
||||
|
||||
@abstractmethod
|
||||
def compute_loss(self, batch: Batch) -> Tensor:
|
||||
raise NotImplementedError("The `compute_loss` method must be implemented in the subclass.")
|
||||
...
|
||||
|
||||
def compute_evaluation(self) -> None:
|
||||
pass
|
||||
|
|
Loading…
Reference in a new issue