implement data_iterable

limiteinductive 2024-04-24 17:14:32 +00:00 committed by Benjamin Trom
parent 05a63ef44e
commit 38bddc49bd
6 changed files with 16 additions and 7 deletions

@@ -24,7 +24,6 @@ class TrainingConfig(BaseModel):
     dtype: str = "float32"
     duration: TimeValue = Iteration(1)
     seed: int = 0
-    batch_size: int = 1
     gradient_accumulation: Step = Step(1)
     gradient_clipping_max_norm: float | None = None
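With batch_size gone from TrainingConfig (and from both TOML configs below), batching presumably moves to the data-loader side of the configuration. A hedged sketch of where the field could now live, assuming DataLoaderConfig is a pydantic model like TrainingConfig; DataLoaderConfig's actual fields are not shown in this diff:

    from pydantic import BaseModel

    class DataLoaderConfig(BaseModel):
        # assumption: batch_size relocated here from TrainingConfig
        batch_size: int = 1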

@@ -39,7 +39,7 @@ class DatasetFromCallable(Dataset[BatchT]):
 def create_data_loader(
     get_item: Callable[[int], BatchT],
     length: int,
-    config: DataloaderConfig,
+    config: DataLoaderConfig,
     collate_fn: Callable[[list[BatchT]], BatchT] | None = None,
 ) -> DataLoader[BatchT]:
     return DataLoader(
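For reference, a hypothetical call site for the renamed config (values are illustrative, and the batch_size keyword is an assumption based on the field removed from TrainingConfig above):

    loader = create_data_loader(
        get_item=lambda index: index,  # any Callable[[int], BatchT]
        length=20,
        config=DataLoaderConfig(batch_size=4),  # assumed constructor argument
    )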

@@ -280,7 +280,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
     @abstractmethod
     def create_data_iterable(self) -> Iterable[Batch]: ...

-    @property
+    @cached_property
     def data_iterable(self) -> Iterable[Batch]:
         return self.create_data_iterable()
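The swap from property to cached_property is the core of the change: create_data_iterable now runs once, on first access, and the resulting iterable is reused across the whole run instead of being rebuilt on every access. A minimal standalone sketch (names illustrative) of the behavior difference:

    from functools import cached_property
    from typing import Iterable

    class Example:
        def create_data_iterable(self) -> Iterable[int]:
            print("building")  # with a plain @property this would print on every access
            return [1, 2, 3]

        @cached_property
        def data_iterable(self) -> Iterable[int]:
            return self.create_data_iterable()

    e = Example()
    e.data_iterable  # prints "building"
    e.data_iterable  # cached, nothing printed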

@@ -16,7 +16,6 @@ duration = "100:epoch"
 seed = 0
 device = "cpu"
 dtype = "float32"
-batch_size = 4
 gradient_accumulation = "4:step"
 gradient_clipping_max_norm = 1.0

@@ -11,7 +11,6 @@ verbose = false
 [training]
 duration = "100:epoch"
 seed = 0
-batch_size = 4
 gradient_accumulation = "4:step"
 gradient_clipping_max_norm = 1.0
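Both TOML configs drop batch_size from [training]. Assuming the field moved to the data-loader config (MockConfig in the test file below declares a data_loader: DataLoaderConfig field), the equivalent setting would now be expressed along these lines; the section name is inferred from that test config, not shown in this diff:

    [data_loader]
    batch_size = 4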

@@ -26,7 +26,7 @@ from refiners.training_utils.common import (
     scoped_seed,
 )
 from refiners.training_utils.config import BaseConfig, ModelConfig
-from refiners.training_utils.data_loader import DataloaderConfig, create_data_loader
+from refiners.training_utils.data_loader import DataLoaderConfig, create_data_loader
 from refiners.training_utils.trainer import (
     Trainer,
     TrainingClock,
@@ -65,7 +65,7 @@ class MockConfig(BaseConfig):
     mock_model: MockModelConfig
     mock_callback: MockCallbackConfig
-    data_loader: DataloaderConfig
+    data_loader: DataLoaderConfig


 class MockModel(fl.Chain):
@@ -334,3 +334,15 @@ def mock_trainer_2_models(mock_config_2_models: MockConfig) -> MockTrainerWith2M
 def test_optimizer_parameters(mock_trainer_2_models: MockTrainerWith2Models) -> None:
     assert len(mock_trainer_2_models.optimizer.param_groups) == 2
     assert mock_trainer_2_models.optimizer.param_groups[0]["lr"] == 1e-5
+
+
+class MockTrainerNoDataLoader(MockTrainer):
+    def create_data_iterable(self) -> list[MockBatch]:  # type: ignore
+        return [MockBatch(inputs=torch.randn(4, 10), targets=torch.randn(4, 10)) for _ in range(5)]
+
+
+def test_trainer_no_data_loader(mock_config: MockConfig) -> None:
+    trainer = MockTrainerNoDataLoader(config=mock_config)
+    trainer.train()
+    assert trainer.step_counter == 500
+    assert trainer.clock.epoch == 100
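The expected counts in the new test follow directly from the iterable: create_data_iterable yields 5 batches per epoch and the config trains for "100:epoch", so the trainer takes 5 × 100 = 500 steps and the clock ends at epoch 100.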