diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 81d5835..ee88ac5 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -446,7 +446,7 @@ class Trainer(Generic[ConfigType, Batch], ABC): ... @abstractmethod - def get_item(self, _: int) -> Batch: + def get_item(self, index: int) -> Batch: """ Returns a batch of data. diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index 26edd7c..e8c3643 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -46,7 +46,7 @@ class MockTrainer(Trainer[MockConfig, MockBatch]): def dataset_length(self) -> int: return 20 - def get_item(self, _: int) -> MockBatch: + def get_item(self, index: int) -> MockBatch: return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10)) def collate_fn(self, batch: list[MockBatch]) -> MockBatch: