mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
Fix tuple annotation for pyright 1.1.325
This commit is contained in:
parent
44e184d4d5
commit
9d2fbf6dbd
|
@ -24,7 +24,7 @@ class Converter(ContextModule):
|
|||
self.set_device = set_device
|
||||
self.set_dtype = set_dtype
|
||||
|
||||
def forward(self, *inputs: Tensor) -> tuple[Tensor]:
|
||||
def forward(self, *inputs: Tensor) -> tuple[Tensor, ...]:
|
||||
parent = self.ensure_parent
|
||||
converted_tensors: list[Tensor] = []
|
||||
|
||||
|
|
|
@ -160,7 +160,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
|||
self.current_step = random_step
|
||||
return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)
|
||||
|
||||
def sample_noise(self, size: tuple[int, int, int, int], dtype: DType | None = None) -> Tensor:
|
||||
def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor:
|
||||
return sample_noise(
|
||||
size=size, offset_noise=self.config.latent_diffusion.offset_noise, device=self.device, dtype=dtype
|
||||
)
|
||||
|
@ -207,7 +207,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
|||
|
||||
|
||||
def sample_noise(
|
||||
size: tuple[int, int, int, int],
|
||||
size: tuple[int, ...],
|
||||
offset_noise: float = 0.1,
|
||||
device: Device | str = "cpu",
|
||||
dtype: DType | None = None,
|
||||
|
|
Loading…
Reference in a new issue