Fix tuple annotation for pyright 1.1.325

This commit is contained in:
limiteinductive 2023-08-31 12:40:04 +02:00 committed by Benjamin Trom
parent 44e184d4d5
commit 9d2fbf6dbd
2 changed files with 3 additions and 3 deletions

View file

@ -24,7 +24,7 @@ class Converter(ContextModule):
self.set_device = set_device
self.set_dtype = set_dtype
def forward(self, *inputs: Tensor) -> tuple[Tensor]:
def forward(self, *inputs: Tensor) -> tuple[Tensor, ...]:
parent = self.ensure_parent
converted_tensors: list[Tensor] = []

View file

@ -160,7 +160,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
self.current_step = random_step
return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)
def sample_noise(self, size: tuple[int, int, int, int], dtype: DType | None = None) -> Tensor:
def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor:
return sample_noise(
size=size, offset_noise=self.config.latent_diffusion.offset_noise, device=self.device, dtype=dtype
)
@ -207,7 +207,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
def sample_noise(
size: tuple[int, int, int, int],
size: tuple[int, ...],
offset_noise: float = 0.1,
device: Device | str = "cpu",
dtype: DType | None = None,