mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
DataLoader validation
This commit is contained in:
parent
38bddc49bd
commit
22f4f4faf1
|
@ -1,22 +1,34 @@
|
||||||
from typing import Callable, TypeVar
|
from typing import Callable, TypeVar
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, PositiveInt
|
from pydantic import BaseModel, ConfigDict, NonNegativeInt, PositiveInt, model_validator
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
BatchT = TypeVar("BatchT")
|
BatchT = TypeVar("BatchT")
|
||||||
|
|
||||||
|
|
||||||
class DataLoaderConfig(BaseModel):
|
class DataLoaderConfig(BaseModel):
|
||||||
batch_size: PositiveInt = 1
|
batch_size: PositiveInt = 1
|
||||||
num_workers: int = 0
|
num_workers: NonNegativeInt = 0
|
||||||
pin_memory: bool = False
|
pin_memory: bool = False
|
||||||
prefetch_factor: int | None = None
|
prefetch_factor: PositiveInt | None = None
|
||||||
persistent_workers: bool = False
|
persistent_workers: bool = False
|
||||||
drop_last: bool = False
|
drop_last: bool = False
|
||||||
shuffle: bool = True
|
shuffle: bool = True
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
# TODO: Add more validation to the config
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_prefetch_factor(self) -> Self:
|
||||||
|
if self.prefetch_factor is not None and self.num_workers == 0:
|
||||||
|
raise ValueError(f"prefetch_factor={self.prefetch_factor} requires num_workers > 0")
|
||||||
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_num_workers(self) -> Self:
|
||||||
|
if self.num_workers == 0 and self.persistent_workers is True:
|
||||||
|
raise ValueError(f"persistent_workers={self.persistent_workers} option needs num_workers > 0")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
class DatasetFromCallable(Dataset[BatchT]):
|
class DatasetFromCallable(Dataset[BatchT]):
|
||||||
|
|
48
tests/training_utils/test_data_loader.py
Normal file
48
tests/training_utils/test_data_loader.py
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from refiners.training_utils.data_loader import DataLoaderConfig, DatasetFromCallable, create_data_loader
|
||||||
|
|
||||||
|
|
||||||
|
def get_item(index: int) -> int:
|
||||||
|
return index * 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def config() -> DataLoaderConfig:
|
||||||
|
return DataLoaderConfig(batch_size=2, num_workers=2, persistent_workers=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataloader_config_valid(config: DataLoaderConfig) -> None:
|
||||||
|
assert config.batch_size == 2
|
||||||
|
assert config.num_workers == 2
|
||||||
|
assert config.persistent_workers == True
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataloader_config_invalid() -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
DataLoaderConfig(num_workers=0, prefetch_factor=2)
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
DataLoaderConfig(num_workers=0, persistent_workers=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dataset_from_callable():
|
||||||
|
dataset = DatasetFromCallable(get_item, 200)
|
||||||
|
assert len(dataset) == 200
|
||||||
|
assert dataset[0] == 0
|
||||||
|
assert dataset[5] == 10
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_data_loader(config: DataLoaderConfig) -> None:
|
||||||
|
data_loader = create_data_loader(get_item, 100, config)
|
||||||
|
assert isinstance(data_loader, DataLoader)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_data_loader_with_collate_fn(config: DataLoaderConfig) -> None:
|
||||||
|
def collate_fn(batch: list[int]):
|
||||||
|
return sum(batch)
|
||||||
|
|
||||||
|
data_loader = create_data_loader(get_item, 20, config=config, collate_fn=collate_fn)
|
||||||
|
assert isinstance(data_loader, DataLoader)
|
Loading…
Reference in a new issue