DataLoader validation

This commit is contained in:
limiteinductive 2024-04-24 20:15:44 +00:00 committed by Benjamin Trom
parent 38bddc49bd
commit 22f4f4faf1
2 changed files with 64 additions and 4 deletions

View file

@ -1,22 +1,34 @@
from typing import Callable, TypeVar
from pydantic import BaseModel, ConfigDict, PositiveInt
from pydantic import BaseModel, ConfigDict, NonNegativeInt, PositiveInt, model_validator
from torch.utils.data import DataLoader, Dataset
from typing_extensions import Self
BatchT = TypeVar("BatchT")
class DataLoaderConfig(BaseModel):
batch_size: PositiveInt = 1
num_workers: int = 0
num_workers: NonNegativeInt = 0
pin_memory: bool = False
prefetch_factor: int | None = None
prefetch_factor: PositiveInt | None = None
persistent_workers: bool = False
drop_last: bool = False
shuffle: bool = True
model_config = ConfigDict(extra="forbid")
# TODO: Add more validation to the config
@model_validator(mode="after")
def check_prefetch_factor(self) -> Self:
if self.prefetch_factor is not None and self.num_workers == 0:
raise ValueError(f"prefetch_factor={self.prefetch_factor} requires num_workers > 0")
return self
@model_validator(mode="after")
def check_num_workers(self) -> Self:
if self.num_workers == 0 and self.persistent_workers is True:
raise ValueError(f"persistent_workers={self.persistent_workers} option needs num_workers > 0")
return self
class DatasetFromCallable(Dataset[BatchT]):

View file

@ -0,0 +1,48 @@
import pytest
from pydantic import ValidationError
from torch.utils.data import DataLoader
from refiners.training_utils.data_loader import DataLoaderConfig, DatasetFromCallable, create_data_loader
def get_item(index: int) -> int:
return index * 2
@pytest.fixture
def config() -> DataLoaderConfig:
return DataLoaderConfig(batch_size=2, num_workers=2, persistent_workers=True)
def test_dataloader_config_valid(config: DataLoaderConfig) -> None:
assert config.batch_size == 2
assert config.num_workers == 2
assert config.persistent_workers == True
def test_dataloader_config_invalid() -> None:
with pytest.raises(ValidationError):
DataLoaderConfig(num_workers=0, prefetch_factor=2)
with pytest.raises(ValidationError):
DataLoaderConfig(num_workers=0, persistent_workers=True)
def test_dataset_from_callable():
dataset = DatasetFromCallable(get_item, 200)
assert len(dataset) == 200
assert dataset[0] == 0
assert dataset[5] == 10
def test_create_data_loader(config: DataLoaderConfig) -> None:
data_loader = create_data_loader(get_item, 100, config)
assert isinstance(data_loader, DataLoader)
def test_create_data_loader_with_collate_fn(config: DataLoaderConfig) -> None:
def collate_fn(batch: list[int]):
return sum(batch)
data_loader = create_data_loader(get_item, 20, config=config, collate_fn=collate_fn)
assert isinstance(data_loader, DataLoader)