mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 15:18:46 +00:00
DataLoader validation
This commit is contained in:
parent
38bddc49bd
commit
22f4f4faf1
|
@ -1,22 +1,34 @@
|
|||
from typing import Callable, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, PositiveInt
|
||||
from pydantic import BaseModel, ConfigDict, NonNegativeInt, PositiveInt, model_validator
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from typing_extensions import Self
|
||||
|
||||
BatchT = TypeVar("BatchT")
|
||||
|
||||
|
||||
class DataLoaderConfig(BaseModel):
    """Validated configuration for constructing a torch ``DataLoader``.

    Field names mirror the corresponding ``DataLoader`` keyword arguments.
    Unknown fields are rejected (``extra="forbid"``), and cross-field
    constraints that torch only reports at runtime are checked up front.
    """

    batch_size: PositiveInt = 1
    num_workers: NonNegativeInt = 0
    pin_memory: bool = False
    prefetch_factor: PositiveInt | None = None
    persistent_workers: bool = False
    drop_last: bool = False
    shuffle: bool = True

    model_config = ConfigDict(extra="forbid")

    # TODO: Add more validation to the config

    @model_validator(mode="after")
    def check_prefetch_factor(self) -> Self:
        """Reject prefetching when there are no worker processes to prefetch."""
        # Prefetching happens inside worker processes, so it is meaningless
        # (and rejected by torch) when num_workers == 0.
        if self.num_workers == 0 and self.prefetch_factor is not None:
            raise ValueError(f"prefetch_factor={self.prefetch_factor} requires num_workers > 0")
        return self

    @model_validator(mode="after")
    def check_num_workers(self) -> Self:
        """Reject persistent workers when no workers exist to persist."""
        if self.persistent_workers is True and self.num_workers == 0:
            raise ValueError(f"persistent_workers={self.persistent_workers} option needs num_workers > 0")
        return self
|
||||
|
||||
|
||||
class DatasetFromCallable(Dataset[BatchT]):
|
||||
|
|
48
tests/training_utils/test_data_loader.py
Normal file
48
tests/training_utils/test_data_loader.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
import pytest
|
||||
from pydantic import ValidationError
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from refiners.training_utils.data_loader import DataLoaderConfig, DatasetFromCallable, create_data_loader
|
||||
|
||||
|
||||
def get_item(index: int) -> int:
    """Toy dataset accessor: map an index to twice its value."""
    return 2 * index
|
||||
|
||||
|
||||
@pytest.fixture
def config() -> DataLoaderConfig:
    """A valid config exercising batching and persistent worker processes."""
    cfg = DataLoaderConfig(batch_size=2, num_workers=2, persistent_workers=True)
    return cfg
|
||||
|
||||
|
||||
def test_dataloader_config_valid(config: DataLoaderConfig) -> None:
    """The fixture's values survive validation unchanged."""
    assert config.batch_size == 2
    assert config.num_workers == 2
    # `is True` rather than `== True`: PEP 8 idiom for singleton comparison.
    assert config.persistent_workers is True
|
||||
|
||||
|
||||
def test_dataloader_config_invalid() -> None:
    """Options that require worker processes are rejected when num_workers == 0."""
    for extra_kwargs in ({"prefetch_factor": 2}, {"persistent_workers": True}):
        with pytest.raises(ValidationError):
            DataLoaderConfig(num_workers=0, **extra_kwargs)
|
||||
|
||||
|
||||
def test_dataset_from_callable() -> None:
    """Length and indexing delegate to the wrapped callable.

    Annotated `-> None` for consistency with every other test in this file.
    """
    dataset = DatasetFromCallable(get_item, 200)
    assert len(dataset) == 200
    assert dataset[0] == 0
    assert dataset[5] == 10
|
||||
|
||||
|
||||
def test_create_data_loader(config: DataLoaderConfig) -> None:
    """create_data_loader wires a callable and config into a torch DataLoader."""
    loader = create_data_loader(get_item, 100, config)
    assert isinstance(loader, DataLoader)
|
||||
|
||||
|
||||
def test_create_data_loader_with_collate_fn(config: DataLoaderConfig) -> None:
    """A custom collate_fn is accepted and still yields a DataLoader."""

    # Annotated `-> int` for consistency with the rest of this fully-typed file.
    def collate_fn(batch: list[int]) -> int:
        # Collapse each batch to a single integer to prove the hook is wired.
        return sum(batch)

    data_loader = create_data_loader(get_item, 20, config=config, collate_fn=collate_fn)
    assert isinstance(data_loader, DataLoader)
|
Loading…
Reference in a new issue