mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
Add DataloadeConfig to Trainer
This commit is contained in:
parent
b9b999ccfe
commit
be7d065a33
|
@ -146,6 +146,17 @@ class OptimizerConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DataloaderConfig(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
num_workers: int = 0
|
||||||
|
pin_memory: bool = False
|
||||||
|
prefetch_factor: int | None = None
|
||||||
|
persistent_workers: bool = False
|
||||||
|
drop_last: bool = False
|
||||||
|
shuffle: bool = True
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
# If None, then requires_grad will NOT be changed when loading the model
|
# If None, then requires_grad will NOT be changed when loading the model
|
||||||
# this can be useful if you want to train only a part of the model
|
# this can be useful if you want to train only a part of the model
|
||||||
|
@ -167,6 +178,7 @@ class BaseConfig(BaseModel):
|
||||||
optimizer: OptimizerConfig
|
optimizer: OptimizerConfig
|
||||||
lr_scheduler: LRSchedulerConfig
|
lr_scheduler: LRSchedulerConfig
|
||||||
clock: ClockConfig = ClockConfig()
|
clock: ClockConfig = ClockConfig()
|
||||||
|
dataloader: DataloaderConfig = DataloaderConfig()
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
|
|
@ -329,8 +329,17 @@ class Trainer(Generic[ConfigType, Batch], ABC):
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def dataloader(self) -> DataLoader[Any]:
|
def dataloader(self) -> DataLoader[Any]:
|
||||||
|
config = self.config.dataloader
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
dataset=self.dataset, batch_size=self.config.training.batch_size, shuffle=True, collate_fn=self.collate_fn
|
dataset=self.dataset,
|
||||||
|
batch_size=self.config.training.batch_size,
|
||||||
|
collate_fn=self.collate_fn,
|
||||||
|
num_workers=config.num_workers,
|
||||||
|
prefetch_factor=config.prefetch_factor,
|
||||||
|
persistent_workers=config.persistent_workers,
|
||||||
|
pin_memory=config.pin_memory,
|
||||||
|
shuffle=config.shuffle,
|
||||||
|
drop_last=config.drop_last,
|
||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
Loading…
Reference in a new issue