diff --git a/src/refiners/training_utils/data_loader.py b/src/refiners/training_utils/data_loader.py index 9a3ac74..f6b75aa 100644 --- a/src/refiners/training_utils/data_loader.py +++ b/src/refiners/training_utils/data_loader.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader, Dataset BatchT = TypeVar("BatchT") -class DataloaderConfig(BaseModel): +class DataLoaderConfig(BaseModel): batch_size: PositiveInt = 1 num_workers: int = 0 pin_memory: bool = False