refactor: rename noise => predicted_noise

and in euler, `alt_noise` can now be simply `noise`
This commit is contained in:
Bryce 2024-01-20 09:37:49 -08:00 committed by Cédric Deltheil
parent 695c24dd3a
commit 12a5439fc4
8 changed files with 29 additions and 25 deletions

View file

@ -96,5 +96,5 @@ black = true
[tool.pyright] [tool.pyright]
include = ["src/refiners", "tests", "scripts"] include = ["src/refiners", "tests", "scripts"]
strict = ["*"] strict = ["*"]
exclude = ["**/__pycache__"] exclude = ["**/__pycache__", "tests/weights"]
reportMissingTypeStubs = "warning" reportMissingTypeStubs = "warning"

View file

@ -91,15 +91,17 @@ class LatentDiffusionModel(fl.Module, ABC):
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2) unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
# classifier-free guidance # classifier-free guidance
noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction) predicted_noise = unconditional_prediction + condition_scale * (
conditional_prediction - unconditional_prediction
)
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting
if self.has_self_attention_guidance(): if self.has_self_attention_guidance():
noise += self.compute_self_attention_guidance( predicted_noise += self.compute_self_attention_guidance(
x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
) )
return self.scheduler(x, noise=noise, step=step) return self.scheduler(x, predicted_noise=predicted_noise, step=step)
def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel: def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
return self.__class__( return self.__class__(

View file

@ -36,7 +36,7 @@ class DDIM(Scheduler):
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1 timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
return timesteps.flip(0) return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}" assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"
timestep, previous_timestep = ( timestep, previous_timestep = (
@ -55,13 +55,13 @@ class DDIM(Scheduler):
else self.cumulative_scale_factors[0] else self.cumulative_scale_factors[0]
), ),
) )
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor predicted_x = (x - sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
noise_factor = sqrt(1 - previous_scale_factor**2) noise_factor = sqrt(1 - previous_scale_factor**2)
# Do not add noise at the last step to avoid visual artifacts. # Do not add noise at the last step to avoid visual artifacts.
if step == self.num_inference_steps - 1: if step == self.num_inference_steps - 1:
noise_factor = 0 noise_factor = 0
denoised_x = previous_scale_factor * predicted_x + noise_factor * noise denoised_x = previous_scale_factor * predicted_x + noise_factor * predicted_noise
return denoised_x return denoised_x

View file

@ -35,5 +35,5 @@ class DDPM(Scheduler):
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
return timesteps.flip(0) return timesteps.flip(0)
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
raise NotImplementedError raise NotImplementedError

View file

@ -106,7 +106,7 @@ class DPMSolver(Scheduler):
) )
return denoised_x return denoised_x
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
""" """
Represents one step of the backward diffusion process that iteratively denoises the input data `x`. Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
@ -118,7 +118,7 @@ class DPMSolver(Scheduler):
current_timestep = self.timesteps[step] current_timestep = self.timesteps[step]
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep] scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
self.estimated_data.append(estimated_denoised_data) self.estimated_data.append(estimated_denoised_data)
if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1): if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):

View file

@ -58,7 +58,7 @@ class EulerScheduler(Scheduler):
def __call__( def __call__(
self, self,
x: Tensor, x: Tensor,
noise: Tensor, predicted_noise: Tensor,
step: int, step: int,
generator: Generator | None = None, generator: Generator | None = None,
s_churn: float = 0.0, s_churn: float = 0.0,
@ -72,13 +72,15 @@ class EulerScheduler(Scheduler):
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0 gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0
alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype) noise = torch.randn(
eps = alt_noise * s_noise predicted_noise.shape, generator=generator, device=predicted_noise.device, dtype=predicted_noise.dtype
)
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1) sigma_hat = sigma * (gamma + 1)
if gamma > 0: if gamma > 0:
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5 x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5
predicted_x = x - sigma_hat * noise predicted_x = x - sigma_hat * predicted_noise
# 1st order Euler # 1st order Euler
derivative = (x - predicted_x) / sigma_hat derivative = (x - predicted_x) / sigma_hat

View file

@ -52,9 +52,9 @@ class Scheduler(ABC):
self.timesteps = self._generate_timesteps() self.timesteps = self._generate_timesteps()
@abstractmethod @abstractmethod
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
""" """
Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`. Applies a step of the diffusion process to the input tensor `x` using the provided `predicted_noise` and `timestep`.
This method should be overridden by subclasses to implement the specific diffusion process. This method should be overridden by subclasses to implement the specific diffusion process.
""" """

View file

@ -35,11 +35,11 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order) refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)
sample = randn(1, 3, 32, 32) sample = randn(1, 3, 32, 32)
noise = randn(1, 3, 32, 32) predicted_noise = randn(1, 3, 32, 32)
for step, timestep in enumerate(diffusers_scheduler.timesteps): for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step) refiners_output = refiners_scheduler(x=sample, predicted_noise=predicted_noise, step=step)
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}" assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
@ -60,11 +60,11 @@ def test_ddim_diffusers():
refiners_scheduler = DDIM(num_inference_steps=30) refiners_scheduler = DDIM(num_inference_steps=30)
sample = randn(1, 4, 32, 32) sample = randn(1, 4, 32, 32)
noise = randn(1, 4, 32, 32) predicted_noise = randn(1, 4, 32, 32)
for step, timestep in enumerate(diffusers_scheduler.timesteps): for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step) refiners_output = refiners_scheduler(x=sample, predicted_noise=predicted_noise, step=step)
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}" assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
@ -86,15 +86,15 @@ def test_euler_diffusers():
refiners_scheduler = EulerScheduler(num_inference_steps=30) refiners_scheduler = EulerScheduler(num_inference_steps=30)
sample = randn(1, 4, 32, 32) sample = randn(1, 4, 32, 32)
noise = randn(1, 4, 32, 32) predicted_noise = randn(1, 4, 32, 32)
ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma # type: ignore ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma # type: ignore
assert isinstance(ref_init_noise_sigma, Tensor) assert isinstance(ref_init_noise_sigma, Tensor)
assert isclose(ref_init_noise_sigma, refiners_scheduler.init_noise_sigma), "init_noise_sigma differ" assert isclose(ref_init_noise_sigma, refiners_scheduler.init_noise_sigma), "init_noise_sigma differ"
for step, timestep in enumerate(diffusers_scheduler.timesteps): for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step) refiners_output = refiners_scheduler(x=sample, predicted_noise=predicted_noise, step=step)
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}" assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"