mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
fix device in DDPM / DDIM timesteps
This commit is contained in:
parent
b046f0cf3f
commit
72854de669
|
@ -20,7 +20,7 @@ class DDIM(Scheduler):
|
|||
similar to diffusers settings for the DDIM scheduler in Stable Diffusion 1.5
|
||||
"""
|
||||
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
|
||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
|
||||
return timesteps.flip(0)
|
||||
|
||||
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||
|
|
|
@ -20,7 +20,7 @@ class DDPM(Scheduler):
|
|||
|
||||
def _generate_timesteps(self) -> Tensor:
|
||||
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio
|
||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
|
||||
return timesteps.flip(0)
|
||||
|
||||
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||
|
|
Loading…
Reference in a new issue