mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-12 16:18:22 +00:00
fix device in DDPM / DDIM timesteps
This commit is contained in:
parent
b046f0cf3f
commit
72854de669
|
@ -20,7 +20,7 @@ class DDIM(Scheduler):
|
||||||
similar to diffusers settings for the DDIM scheduler in Stable Diffusion 1.5
|
similar to diffusers settings for the DDIM scheduler in Stable Diffusion 1.5
|
||||||
"""
|
"""
|
||||||
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
||||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
|
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
|
||||||
return timesteps.flip(0)
|
return timesteps.flip(0)
|
||||||
|
|
||||||
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
|
||||||
|
|
|
@ -20,7 +20,7 @@ class DDPM(Scheduler):
|
||||||
|
|
||||||
def _generate_timesteps(self) -> Tensor:
|
def _generate_timesteps(self) -> Tensor:
|
||||||
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
step_ratio = self.num_train_timesteps // self.num_inference_steps
|
||||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio
|
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
|
||||||
return timesteps.flip(0)
|
return timesteps.flip(0)
|
||||||
|
|
||||||
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
|
|
Loading…
Reference in a new issue