Add DataloadeConfig to Trainer

This commit is contained in:
limiteinductive 2024-04-15 14:34:54 +00:00 committed by Benjamin Trom
parent b9b999ccfe
commit be7d065a33
2 changed files with 22 additions and 1 deletions

View file

@ -146,6 +146,17 @@ class OptimizerConfig(BaseModel):
) )
class DataloaderConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
num_workers: int = 0
pin_memory: bool = False
prefetch_factor: int | None = None
persistent_workers: bool = False
drop_last: bool = False
shuffle: bool = True
class ModelConfig(BaseModel): class ModelConfig(BaseModel):
# If None, then requires_grad will NOT be changed when loading the model # If None, then requires_grad will NOT be changed when loading the model
# this can be useful if you want to train only a part of the model # this can be useful if you want to train only a part of the model
@ -167,6 +178,7 @@ class BaseConfig(BaseModel):
optimizer: OptimizerConfig optimizer: OptimizerConfig
lr_scheduler: LRSchedulerConfig lr_scheduler: LRSchedulerConfig
clock: ClockConfig = ClockConfig() clock: ClockConfig = ClockConfig()
dataloader: DataloaderConfig = DataloaderConfig()
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="forbid")

View file

@ -329,8 +329,17 @@ class Trainer(Generic[ConfigType, Batch], ABC):
@cached_property @cached_property
def dataloader(self) -> DataLoader[Any]: def dataloader(self) -> DataLoader[Any]:
config = self.config.dataloader
return DataLoader( return DataLoader(
dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=self.collate_fn dataset=self.dataset,
batch_size=self.config.training.batch_size,
collate_fn=self.collate_fn,
num_workers=config.num_workers,
prefetch_factor=config.prefetch_factor,
persistent_workers=config.persistent_workers,
pin_memory=config.pin_memory,
shuffle=config.shuffle,
drop_last=config.drop_last,
) )
@abstractmethod @abstractmethod