From be7d065a332bb9a0f30ded953a3456f0951ec05b Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Mon, 15 Apr 2024 14:34:54 +0000 Subject: [PATCH] Add DataloaderConfig to Trainer --- src/refiners/training_utils/config.py | 12 ++++++++++++ src/refiners/training_utils/trainer.py | 11 ++++++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index 0523e92..dfe536b 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -146,6 +146,17 @@ class OptimizerConfig(BaseModel): ) +class DataloaderConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + num_workers: int = 0 + pin_memory: bool = False + prefetch_factor: int | None = None + persistent_workers: bool = False + drop_last: bool = False + shuffle: bool = True + + class ModelConfig(BaseModel): # If None, then requires_grad will NOT be changed when loading the model # this can be useful if you want to train only a part of the model @@ -167,6 +178,7 @@ class BaseConfig(BaseModel): optimizer: OptimizerConfig lr_scheduler: LRSchedulerConfig clock: ClockConfig = ClockConfig() + dataloader: DataloaderConfig = DataloaderConfig() model_config = ConfigDict(extra="forbid") diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 84c8d57..0eb2335 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -329,8 +329,17 @@ class Trainer(Generic[ConfigType, Batch], ABC): @cached_property def dataloader(self) -> DataLoader[Any]: + config = self.config.dataloader return DataLoader( - dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=self.collate_fn + dataset=self.dataset, + batch_size=self.config.training.batch_size, + collate_fn=self.collate_fn, + num_workers=config.num_workers, + prefetch_factor=config.prefetch_factor, + persistent_workers=config.persistent_workers, 
+ pin_memory=config.pin_memory, + shuffle=config.shuffle, + drop_last=config.drop_last, ) @abstractmethod