mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
refactor: rename noise => predicted_noise
and in euler, `alt_noise` can now be simply `noise`
This commit is contained in:
parent
695c24dd3a
commit
12a5439fc4
|
@ -96,5 +96,5 @@ black = true
|
||||||
[tool.pyright]
|
[tool.pyright]
|
||||||
include = ["src/refiners", "tests", "scripts"]
|
include = ["src/refiners", "tests", "scripts"]
|
||||||
strict = ["*"]
|
strict = ["*"]
|
||||||
exclude = ["**/__pycache__"]
|
exclude = ["**/__pycache__", "tests/weights"]
|
||||||
reportMissingTypeStubs = "warning"
|
reportMissingTypeStubs = "warning"
|
||||||
|
|
|
@ -91,15 +91,17 @@ class LatentDiffusionModel(fl.Module, ABC):
|
||||||
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
|
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
|
||||||
|
|
||||||
# classifier-free guidance
|
# classifier-free guidance
|
||||||
noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
|
predicted_noise = unconditional_prediction + condition_scale * (
|
||||||
|
conditional_prediction - unconditional_prediction
|
||||||
|
)
|
||||||
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
|
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
|
||||||
|
|
||||||
if self.has_self_attention_guidance():
|
if self.has_self_attention_guidance():
|
||||||
noise += self.compute_self_attention_guidance(
|
predicted_noise += self.compute_self_attention_guidance(
|
||||||
x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
|
x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.scheduler(x, noise=noise, step=step)
|
return self.scheduler(x, predicted_noise=predicted_noise, step=step)
|
||||||
|
|
||||||
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
|
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
|
|
|
@ -36,7 +36,7 @@ class DDIM(Scheduler):
|
||||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
|
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
|
||||||
return timesteps.flip(0)
|
return timesteps.flip(0)
|
||||||
|
|
||||||
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
|
||||||
|
|
||||||
timestep, previous_timestep = (
|
timestep, previous_timestep = (
|
||||||
|
@ -55,13 +55,13 @@ class DDIM(Scheduler):
|
||||||
else self.cumulative_scale_factors[0]
|
else self.cumulative_scale_factors[0]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
|
predicted_x = (x - sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
|
||||||
noise_factor = sqrt(1 - previous_scale_factor**2)
|
noise_factor = sqrt(1 - previous_scale_factor**2)
|
||||||
|
|
||||||
# Do not add noise at the last step to avoid visual artifacts.
|
# Do not add noise at the last step to avoid visual artifacts.
|
||||||
if step == self.num_inference_steps - 1:
|
if step == self.num_inference_steps - 1:
|
||||||
noise_factor = 0
|
noise_factor = 0
|
||||||
|
|
||||||
denoised_x = previous_scale_factor * predicted_x + noise_factor * noise
|
denoised_x = previous_scale_factor * predicted_x + noise_factor * predicted_noise
|
||||||
|
|
||||||
return denoised_x
|
return denoised_x
|
||||||
|
|
|
@ -35,5 +35,5 @@ class DDPM(Scheduler):
|
||||||
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
|
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
|
||||||
return timesteps.flip(0)
|
return timesteps.flip(0)
|
||||||
|
|
||||||
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -106,7 +106,7 @@ class DPMSolver(Scheduler):
|
||||||
)
|
)
|
||||||
return denoised_x
|
return denoised_x
|
||||||
|
|
||||||
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
|
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
|
||||||
|
|
||||||
|
@ -118,7 +118,7 @@ class DPMSolver(Scheduler):
|
||||||
|
|
||||||
current_timestep = self.timesteps[step]
|
current_timestep = self.timesteps[step]
|
||||||
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
|
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
|
||||||
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
|
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
|
||||||
self.estimated_data.append(estimated_denoised_data)
|
self.estimated_data.append(estimated_denoised_data)
|
||||||
|
|
||||||
if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
|
if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
|
||||||
|
|
|
@ -58,7 +58,7 @@ class EulerScheduler(Scheduler):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
noise: Tensor,
|
predicted_noise: Tensor,
|
||||||
step: int,
|
step: int,
|
||||||
generator: Generator | None = None,
|
generator: Generator | None = None,
|
||||||
s_churn: float = 0.0,
|
s_churn: float = 0.0,
|
||||||
|
@ -72,13 +72,15 @@ class EulerScheduler(Scheduler):
|
||||||
|
|
||||||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
|
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
|
||||||
|
|
||||||
alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype)
|
noise = torch.randn(
|
||||||
eps = alt_noise * s_noise
|
predicted_noise.shape, generator=generator, device=predicted_noise.device, dtype=predicted_noise.dtype
|
||||||
|
)
|
||||||
|
eps = noise * s_noise
|
||||||
sigma_hat = sigma * (gamma + 1)
|
sigma_hat = sigma * (gamma + 1)
|
||||||
if gamma > 0:
|
if gamma > 0:
|
||||||
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||||
|
|
||||||
predicted_x = x - sigma_hat * noise
|
predicted_x = x - sigma_hat * predicted_noise
|
||||||
|
|
||||||
# 1st order Euler
|
# 1st order Euler
|
||||||
derivative = (x - predicted_x) / sigma_hat
|
derivative = (x - predicted_x) / sigma_hat
|
||||||
|
|
|
@ -52,9 +52,9 @@ class Scheduler(ABC):
|
||||||
self.timesteps = self._generate_timesteps()
|
self.timesteps = self._generate_timesteps()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.
|
Applies a step of the diffusion process to the input tensor `x` using the provided `predicted_noise` and `timestep`.
|
||||||
|
|
||||||
This method should be overridden by subclasses to implement the specific diffusion process.
|
This method should be overridden by subclasses to implement the specific diffusion process.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -35,11 +35,11 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
|
||||||
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
|
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
|
||||||
|
|
||||||
sample = randn(1, 3, 32, 32)
|
sample = randn(1, 3, 32, 32)
|
||||||
noise = randn(1, 3, 32, 32)
|
predicted_noise = randn(1, 3, 32, 32)
|
||||||
|
|
||||||
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
||||||
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore
|
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
|
||||||
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
|
refiners_output = refiners_scheduler(x=sample, predicted_noise=predicted_noise, step=step)
|
||||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,11 +60,11 @@ def test_ddim_diffusers():
|
||||||
refiners_scheduler = DDIM(num_inference_steps=30)
|
refiners_scheduler = DDIM(num_inference_steps=30)
|
||||||
|
|
||||||
sample = randn(1, 4, 32, 32)
|
sample = randn(1, 4, 32, 32)
|
||||||
noise = randn(1, 4, 32, 32)
|
predicted_noise = randn(1, 4, 32, 32)
|
||||||
|
|
||||||
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
||||||
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore
|
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
|
||||||
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
|
refiners_output = refiners_scheduler(x=sample, predicted_noise=predicted_noise, step=step)
|
||||||
|
|
||||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
@ -86,15 +86,15 @@ def test_euler_diffusers():
|
||||||
refiners_scheduler = EulerScheduler(num_inference_steps=30)
|
refiners_scheduler = EulerScheduler(num_inference_steps=30)
|
||||||
|
|
||||||
sample = randn(1, 4, 32, 32)
|
sample = randn(1, 4, 32, 32)
|
||||||
noise = randn(1, 4, 32, 32)
|
predicted_noise = randn(1, 4, 32, 32)
|
||||||
|
|
||||||
ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma # type: ignore
|
ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma # type: ignore
|
||||||
assert isinstance(ref_init_noise_sigma, Tensor)
|
assert isinstance(ref_init_noise_sigma, Tensor)
|
||||||
assert isclose(ref_init_noise_sigma, refiners_scheduler.init_noise_sigma), "init_noise_sigma differ"
|
assert isclose(ref_init_noise_sigma, refiners_scheduler.init_noise_sigma), "init_noise_sigma differ"
|
||||||
|
|
||||||
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
||||||
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore
|
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
|
||||||
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
|
refiners_output = refiners_scheduler(x=sample, predicted_noise=predicted_noise, step=step)
|
||||||
|
|
||||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue