change param name of abstract get_item method

This commit is contained in:
limiteinductive 2024-02-08 16:16:00 +00:00 committed by Benjamin Trom
parent 6d599d53fd
commit 41508e0865
2 changed files with 2 additions and 2 deletions

View file

@ -446,7 +446,7 @@ class Trainer(Generic[ConfigType, Batch], ABC):
...
@abstractmethod
def get_item(self, _: int) -> Batch:
def get_item(self, index: int) -> Batch:
"""
Returns a batch of data.

View file

@ -46,7 +46,7 @@ class MockTrainer(Trainer[MockConfig, MockBatch]):
def dataset_length(self) -> int:
return 20
def get_item(self, _: int) -> MockBatch:
def get_item(self, index: int) -> MockBatch:
return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10))
def collate_fn(self, batch: list[MockBatch]) -> MockBatch: