From 38bddc49bd4d48d72cda710a863600801b5cf3a7 Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Wed, 24 Apr 2024 17:14:32 +0000 Subject: [PATCH] implement data_iterable --- src/refiners/training_utils/config.py | 1 - src/refiners/training_utils/data_loader.py | 2 +- src/refiners/training_utils/trainer.py | 2 +- tests/training_utils/mock_config.toml | 1 - tests/training_utils/mock_config_2_models.toml | 1 - tests/training_utils/test_trainer.py | 16 ++++++++++++++-- 6 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index f482968..a7d3ab6 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -24,7 +24,6 @@ class TrainingConfig(BaseModel): dtype: str = "float32" duration: TimeValue = Iteration(1) seed: int = 0 - batch_size: int = 1 gradient_accumulation: Step = Step(1) gradient_clipping_max_norm: float | None = None diff --git a/src/refiners/training_utils/data_loader.py b/src/refiners/training_utils/data_loader.py index f6b75aa..3cd79bc 100644 --- a/src/refiners/training_utils/data_loader.py +++ b/src/refiners/training_utils/data_loader.py @@ -39,7 +39,7 @@ class DatasetFromCallable(Dataset[BatchT]): def create_data_loader( get_item: Callable[[int], BatchT], length: int, - config: DataloaderConfig, + config: DataLoaderConfig, collate_fn: Callable[[list[BatchT]], BatchT] | None = None, ) -> DataLoader[BatchT]: return DataLoader( diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 3c2afc2..924fe15 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -280,7 +280,7 @@ class Trainer(Generic[ConfigType, Batch], ABC): @abstractmethod def create_data_iterable(self) -> Iterable[Batch]: ... - @property + @cached_property def data_iterable(self) -> Iterable[Batch]: return self.create_data_iterable() diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml index 9a6f167..09472a3 100644 --- a/tests/training_utils/mock_config.toml +++ b/tests/training_utils/mock_config.toml @@ -16,7 +16,6 @@ duration = "100:epoch" seed = 0 device = "cpu" dtype = "float32" -batch_size = 4 gradient_accumulation = "4:step" gradient_clipping_max_norm = 1.0 diff --git a/tests/training_utils/mock_config_2_models.toml b/tests/training_utils/mock_config_2_models.toml index 641bef2..1e0cbcd 100644 --- a/tests/training_utils/mock_config_2_models.toml +++ b/tests/training_utils/mock_config_2_models.toml @@ -11,7 +11,6 @@ verbose = false [training] duration = "100:epoch" seed = 0 -batch_size = 4 gradient_accumulation = "4:step" gradient_clipping_max_norm = 1.0 diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index 56e9881..34dc8e9 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -26,7 +26,7 @@ from refiners.training_utils.common import ( scoped_seed, ) from refiners.training_utils.config import BaseConfig, ModelConfig -from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader +from refiners.training_utils.data_loader import DataLoaderConfig, create_data_loader from refiners.training_utils.trainer import ( Trainer, TrainingClock, @@ -65,7 +65,7 @@ class MockConfig(BaseConfig): mock_model: MockModelConfig mock_callback: MockCallbackConfig - data_loader: DataloaderConfig + data_loader: DataLoaderConfig class MockModel(fl.Chain): @@ -334,3 +334,15 @@ def mock_trainer_2_models(mock_config_2_models: MockConfig) -> MockTrainerWith2M def test_optimizer_parameters(mock_trainer_2_models: MockTrainerWith2Models) -> None: assert len(mock_trainer_2_models.optimizer.param_groups) == 2 assert mock_trainer_2_models.optimizer.param_groups[0]["lr"] == 1e-5 + + +class MockTrainerNoDataLoader(MockTrainer): + def create_data_iterable(self) -> list[MockBatch]: # type: ignore + return [MockBatch(inputs=torch.randn(4, 10), targets=torch.randn(4, 10)) for _ in range(5)] + + +def test_trainer_no_data_loader(mock_config: MockConfig) -> None: + trainer = MockTrainerNoDataLoader(config=mock_config) + trainer.train() + assert trainer.step_counter == 500 + assert trainer.clock.epoch == 100