mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
Add DataloadeConfig to Trainer
This commit is contained in:
parent
b9b999ccfe
commit
be7d065a33
|
@ -146,6 +146,17 @@ class OptimizerConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class DataloaderConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
num_workers: int = 0
|
||||
pin_memory: bool = False
|
||||
prefetch_factor: int | None = None
|
||||
persistent_workers: bool = False
|
||||
drop_last: bool = False
|
||||
shuffle: bool = True
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
# If None, then requires_grad will NOT be changed when loading the model
|
||||
# this can be useful if you want to train only a part of the model
|
||||
|
@ -167,6 +178,7 @@ class BaseConfig(BaseModel):
|
|||
optimizer: OptimizerConfig
|
||||
lr_scheduler: LRSchedulerConfig
|
||||
clock: ClockConfig = ClockConfig()
|
||||
dataloader: DataloaderConfig = DataloaderConfig()
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
|
|
@ -329,8 +329,17 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
|||
|
||||
@cached_property
|
||||
def dataloader(self) -> DataLoader[Any]:
|
||||
config = self.config.dataloader
|
||||
return DataLoader(
|
||||
dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=self.collate_fn
|
||||
dataset=self.dataset,
|
||||
batch_size=self.config.training.batch_size,
|
||||
collate_fn=self.collate_fn,
|
||||
num_workers=config.num_workers,
|
||||
prefetch_factor=config.prefetch_factor,
|
||||
persistent_workers=config.persistent_workers,
|
||||
pin_memory=config.pin_memory,
|
||||
shuffle=config.shuffle,
|
||||
drop_last=config.drop_last,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
|
|
Loading…
Reference in a new issue