fix device in DDPM / DDIM timesteps

This commit is contained in:
Pierre Chapuis 2023-09-21 15:31:46 +02:00
parent b046f0cf3f
commit 72854de669
2 changed files with 2 additions and 2 deletions

View file

@ -20,7 +20,7 @@ class DDIM(Scheduler):
similar to diffusers settings for the DDIM scheduler in Stable Diffusion 1.5
"""
step_ratio = self.num_train_timesteps // self.num_inference_steps
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:

View file

@ -20,7 +20,7 @@ class DDPM(Scheduler):
def _generate_timesteps(self) -> Tensor:
step_ratio = self.num_train_timesteps // self.num_inference_steps
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: