From 22f4f4faf1170a98d3294ae46b4a828d6cf8dd05 Mon Sep 17 00:00:00 2001
From: limiteinductive
Date: Wed, 24 Apr 2024 20:15:44 +0000
Subject: [PATCH] DataLoader validation

---
 src/refiners/training_utils/data_loader.py | 20 +++++-
 tests/training_utils/test_data_loader.py   | 48 ++++++++++++++++++++++
 2 files changed, 64 insertions(+), 4 deletions(-)
 create mode 100644 tests/training_utils/test_data_loader.py

diff --git a/src/refiners/training_utils/data_loader.py b/src/refiners/training_utils/data_loader.py
index 3cd79bc..f51dd19 100644
--- a/src/refiners/training_utils/data_loader.py
+++ b/src/refiners/training_utils/data_loader.py
@@ -1,22 +1,34 @@
 from typing import Callable, TypeVar
 
-from pydantic import BaseModel, ConfigDict, PositiveInt
+from pydantic import BaseModel, ConfigDict, NonNegativeInt, PositiveInt, model_validator
 from torch.utils.data import DataLoader, Dataset
+from typing_extensions import Self
 
 BatchT = TypeVar("BatchT")
 
 
 class DataLoaderConfig(BaseModel):
     batch_size: PositiveInt = 1
-    num_workers: int = 0
+    num_workers: NonNegativeInt = 0
     pin_memory: bool = False
-    prefetch_factor: int | None = None
+    prefetch_factor: PositiveInt | None = None
     persistent_workers: bool = False
     drop_last: bool = False
     shuffle: bool = True
 
     model_config = ConfigDict(extra="forbid")
-    # TODO: Add more validation to the config
+
+    @model_validator(mode="after")
+    def check_prefetch_factor(self) -> Self:
+        if self.prefetch_factor is not None and self.num_workers == 0:
+            raise ValueError(f"prefetch_factor={self.prefetch_factor} requires num_workers > 0")
+        return self
+
+    @model_validator(mode="after")
+    def check_num_workers(self) -> Self:
+        if self.num_workers == 0 and self.persistent_workers is True:
+            raise ValueError(f"persistent_workers={self.persistent_workers} requires num_workers > 0")
+        return self
 
 
 class DatasetFromCallable(Dataset[BatchT]):
diff --git a/tests/training_utils/test_data_loader.py b/tests/training_utils/test_data_loader.py
new file mode 100644
index 0000000..d1a6e2d
--- /dev/null
+++ b/tests/training_utils/test_data_loader.py
@@ -0,0 +1,48 @@
+import pytest
+from pydantic import ValidationError
+from torch.utils.data import DataLoader
+
+from refiners.training_utils.data_loader import DataLoaderConfig, DatasetFromCallable, create_data_loader
+
+
+def get_item(index: int) -> int:
+    return index * 2
+
+
+@pytest.fixture
+def config() -> DataLoaderConfig:
+    return DataLoaderConfig(batch_size=2, num_workers=2, persistent_workers=True)
+
+
+def test_dataloader_config_valid(config: DataLoaderConfig) -> None:
+    assert config.batch_size == 2
+    assert config.num_workers == 2
+    assert config.persistent_workers is True
+
+
+def test_dataloader_config_invalid() -> None:
+    with pytest.raises(ValidationError):
+        DataLoaderConfig(num_workers=0, prefetch_factor=2)
+
+    with pytest.raises(ValidationError):
+        DataLoaderConfig(num_workers=0, persistent_workers=True)
+
+
+def test_dataset_from_callable() -> None:
+    dataset = DatasetFromCallable(get_item, 200)
+    assert len(dataset) == 200
+    assert dataset[0] == 0
+    assert dataset[5] == 10
+
+
+def test_create_data_loader(config: DataLoaderConfig) -> None:
+    data_loader = create_data_loader(get_item, 100, config)
+    assert isinstance(data_loader, DataLoader)
+
+
+def test_create_data_loader_with_collate_fn(config: DataLoaderConfig) -> None:
+    def collate_fn(batch: list[int]) -> int:
+        return sum(batch)
+
+    data_loader = create_data_loader(get_item, 20, config=config, collate_fn=collate_fn)
+    assert isinstance(data_loader, DataLoader)
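
Note (not part of the patch): below is a minimal usage sketch of the stricter config, mirroring the
get_item helper and the create_data_loader call exercised by the tests above. It only illustrates the
intended effect of the new validators: invalid worker/prefetch combinations now fail when the config
is constructed, before any torch DataLoader is built.

    from pydantic import ValidationError

    from refiners.training_utils.data_loader import DataLoaderConfig, create_data_loader


    def get_item(index: int) -> int:
        return index * 2


    # Valid: prefetching and persistent workers both require worker processes.
    config = DataLoaderConfig(batch_size=4, num_workers=2, prefetch_factor=2, persistent_workers=True)
    data_loader = create_data_loader(get_item, 100, config)

    # Invalid: prefetch_factor (or persistent_workers) with num_workers=0 is now
    # rejected by the model validators, surfacing as a pydantic ValidationError.
    try:
        DataLoaderConfig(num_workers=0, prefetch_factor=2)
    except ValidationError as error:
        print(error)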