From 41508e0865022627d59231f91fa573dd4f7b88f7 Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Thu, 8 Feb 2024 16:16:00 +0000 Subject: [PATCH] change param name of abstract get_item method --- src/refiners/training_utils/trainer.py | 2 +- tests/training_utils/test_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 81d5835..ee88ac5 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -446,7 +446,7 @@ class Trainer(Generic[ConfigType, Batch], ABC): ... @abstractmethod - def get_item(self, _: int) -> Batch: + def get_item(self, index: int) -> Batch: """ Returns a batch of data. diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py index 26edd7c..e8c3643 100644 --- a/tests/training_utils/test_trainer.py +++ b/tests/training_utils/test_trainer.py @@ -46,7 +46,7 @@ class MockTrainer(Trainer[MockConfig, MockBatch]): def dataset_length(self) -> int: return 20 - def get_item(self, _: int) -> MockBatch: + def get_item(self, index: int) -> MockBatch: return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10)) def collate_fn(self, batch: list[MockBatch]) -> MockBatch: