mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-22 06:08:46 +00:00
Fix tuple annotation for pyright 1.1.325
This commit is contained in:
parent
44e184d4d5
commit
9d2fbf6dbd
|
@ -24,7 +24,7 @@ class Converter(ContextModule):
|
||||||
self.set_device = set_device
|
self.set_device = set_device
|
||||||
self.set_dtype = set_dtype
|
self.set_dtype = set_dtype
|
||||||
|
|
||||||
def forward(self, *inputs: Tensor) -> tuple[Tensor]:
|
def forward(self, *inputs: Tensor) -> tuple[Tensor, ...]:
|
||||||
parent = self.ensure_parent
|
parent = self.ensure_parent
|
||||||
converted_tensors: list[Tensor] = []
|
converted_tensors: list[Tensor] = []
|
||||||
|
|
||||||
|
|
|
@ -160,7 +160,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
||||||
self.current_step = random_step
|
self.current_step = random_step
|
||||||
return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)
|
return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)
|
||||||
|
|
||||||
def sample_noise(self, size: tuple[int, int, int, int], dtype: DType | None = None) -> Tensor:
|
def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor:
|
||||||
return sample_noise(
|
return sample_noise(
|
||||||
size=size, offset_noise=self.config.latent_diffusion.offset_noise, device=self.device, dtype=dtype
|
size=size, offset_noise=self.config.latent_diffusion.offset_noise, device=self.device, dtype=dtype
|
||||||
)
|
)
|
||||||
|
@ -207,7 +207,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
|
||||||
|
|
||||||
|
|
||||||
def sample_noise(
|
def sample_noise(
|
||||||
size: tuple[int, int, int, int],
|
size: tuple[int, ...],
|
||||||
offset_noise: float = 0.1,
|
offset_noise: float = 0.1,
|
||||||
device: Device | str = "cpu",
|
device: Device | str = "cpu",
|
||||||
dtype: DType | None = None,
|
dtype: DType | None = None,
|
||||||
|
|
Loading…
Reference in a new issue