apply @deltheil suggestions

Laurent 2024-02-02 15:05:23 +00:00 committed by Laureηt
parent f62e71da1c
commit 093527a7de
6 changed files with 19 additions and 12 deletions

View file

@@ -154,7 +154,7 @@ class CLIPTextEncoderL(CLIPTextEncoder):
     Note:
         We replace the GeLU activation function with an approximate GeLU to comply with the original CLIP implementation
-        of OpenAI (https://github.com/openai/CLIP/blob/main/clip/model.py#L166)
+        of OpenAI (https://github.com/openai/CLIP/blob/a1d0717/clip/model.py#L166)
         See [[arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)
         for more details.
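The hunk above only pins the CLIP reference to a fixed commit. For context, the two common GeLU approximations (CLIP's sigmoid-based QuickGELU and PyTorch's tanh variant) can be compared in a few lines; this is an illustrative sketch, not code from this repository:

```python
import torch

# Standalone sketch (not from this commit): CLIP's QuickGELU is
# x * sigmoid(1.702 * x); PyTorch also ships a tanh-based approximation
# via gelu(..., approximate="tanh"). Compare both against the exact GeLU.
def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(1.702 * x)

x = torch.randn(8)
exact = torch.nn.functional.gelu(x)
tanh_approx = torch.nn.functional.gelu(x, approximate="tanh")
print((quick_gelu(x) - exact).abs().max())
print((tanh_approx - exact).abs().max())
```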

View file

@@ -238,8 +238,6 @@ class LatentDiffusionAutoencoder(Chain):
         x = decoder(x / self.encoder_scale)
         return x
-    # backward-compatibility alias
-    # TODO: deprecate this method
     def image_to_latents(self, image: Image.Image) -> Tensor:
         return self.images_to_latents([image])
@@ -261,7 +259,6 @@ class LatentDiffusionAutoencoder(Chain):
     def decode_latents(self, x: Tensor) -> Image.Image:
         return self.latents_to_image(x)
-    # TODO: deprecated this method ?
     def latents_to_image(self, x: Tensor) -> Image.Image:
         if x.shape[0] != 1:
             raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")

View file

@@ -121,6 +121,10 @@ class Solver(fl.Module, ABC):
     def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
         """Remove noise from the input tensor using the current step of the diffusion process.
+        Note:
+            See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models, Equation 15](https://arxiv.org/abs/2006.11239)
+            and [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939).
         Args:
             x: The input tensor to remove noise from.
             noise: The noise tensor to remove from the input tensor.
@@ -132,9 +136,6 @@ class Solver(fl.Module, ABC):
         timestep = self.timesteps[step]
         cumulative_scale_factors = self.cumulative_scale_factors[timestep]
         noise_stds = self.noise_std[timestep]
-        # See equation (15) from https://arxiv.org/pdf/2006.11239.pdf.
-        # Useful to preview progress or for guidance
-        # See also https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
         denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
         return denoised_x
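The new docstring note and the removed inline comments point at the same identity: Equation (15) of the DDPM paper, which recovers an estimate of the clean sample from the noisy sample and the predicted noise. A minimal numerical check of that identity (made-up values, not refiners code):

```python
import torch

# Sanity-check DDPM Eq. (15):
#   x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
# by building x_t from a known x0 and recovering it.
alpha_bar_t = torch.tensor(0.85)   # cumulative product of alphas at step t (made-up value)
scale = alpha_bar_t.sqrt()         # plays the role of cumulative_scale_factors
std = (1 - alpha_bar_t).sqrt()     # plays the role of noise_std

x0 = torch.randn(1, 4, 64, 64)
eps = torch.randn_like(x0)
x_t = scale * x0 + std * eps                 # forward diffusion at step t
x0_hat = (x_t - std * eps) / scale           # what remove_noise computes
assert torch.allclose(x0_hat, x0, atol=1e-5)
```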
@@ -196,6 +197,9 @@ class Solver(fl.Module, ABC):
         This method should only be overridden by solvers that
         need to scale the input according to the current timestep.
+        By default, this method does not scale the input.
+        (scale=1)
         Args:
             x: The input tensor to scale.
             step: The current step of the diffusion process.
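The added sentence makes the default contract explicit: the base `scale_model_input` is an identity (scale = 1), and only solvers that need timestep-dependent input scaling override it. A hedged sketch of that contract follows; the classes and scaling rule are illustrative, not the refiners implementation:

```python
import torch
from torch import Tensor

# Illustrative only: the base solver leaves the input untouched (scale = 1);
# a hypothetical sigma-based solver rescales it per step.
class BaseSolver:
    def scale_model_input(self, x: Tensor, step: int) -> Tensor:
        return x  # default: no scaling

class SigmaScaledSolver(BaseSolver):
    def __init__(self, sigmas: Tensor) -> None:
        self.sigmas = sigmas

    def scale_model_input(self, x: Tensor, step: int) -> Tensor:
        sigma = self.sigmas[step]
        return x / (sigma**2 + 1) ** 0.5  # Karras-style input scaling

x = torch.randn(1, 4, 8, 8)
print(torch.equal(BaseSolver().scale_model_input(x, 0), x))        # True
print(SigmaScaledSolver(torch.tensor([14.6, 7.0])).scale_model_input(x, 0).std())
```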

View file

@@ -96,6 +96,9 @@ class StableDiffusion_1(LatentDiffusionModel):
     def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
         """Set whether to enable self-attention guidance.
+        See [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939)
+        for more details.
         Args:
             enable: Whether to enable self-attention guidance.
             scale: The scale to use.
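The added cross-reference documents where self-attention guidance (SAG) comes from; usage-wise the toggle is a one-liner. A hedged usage sketch: only `set_self_attention_guidance()` comes from the diff above, while the import path, constructor call, and scale value are assumptions.

```python
# Hypothetical usage sketch; import path and construction are assumed, and in
# practice model weights would be loaded before sampling.
from refiners.foundationals.latent_diffusion import StableDiffusion_1

sd = StableDiffusion_1()
sd.set_self_attention_guidance(enable=True, scale=0.75)  # inject the SAG adapter
# ... run the usual denoising loop; SAG perturbs predictions as per arXiv:2210.00939 ...
```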
@@ -114,7 +117,7 @@ class StableDiffusion_1(LatentDiffusionModel):
         return self._find_sag_adapter() is not None
     def _find_sag_adapter(self) -> SD1SAGAdapter | None:
-        """Finds the self-attention guidance adapter."""
+        """Finds the self-attention guidance adapter, if any."""
         for p in self.unet.get_parents():
             if isinstance(p, SD1SAGAdapter):
                 return p

View file

@@ -140,6 +140,9 @@ class StableDiffusion_XL(LatentDiffusionModel):
     def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
         """Sets the self-attention guidance.
+        See [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939)
+        for more details.
         Args:
             enable: Whether to enable self-attention guidance or not.
             scale: The scale to use.
@@ -158,7 +161,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
         return self._find_sag_adapter() is not None
     def _find_sag_adapter(self) -> SDXLSAGAdapter | None:
-        """Finds the self-attention guidance adapter."""
+        """Finds the self-attention guidance adapter, if any."""
        for p in self.unet.get_parents():
             if isinstance(p, SDXLSAGAdapter):
                 return p

View file

@@ -154,10 +154,10 @@ class SegmentAnything(fl.Module):
         return w
     def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
-        """Compute the target size for a given size.
+        """Compute the target size as expected by the image encoder.
         Args:
-            size: The size of the image.
+            size: The size of the input image.
         Returns:
             The target height.
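The clarified docstring ties the target size to what the image encoder expects; for SAM-style preprocessing this usually means rescaling so the longest side matches the encoder's input length while preserving the aspect ratio. A self-contained sketch of that computation (the 1024-pixel side length is an assumption carried over from the original SAM, not stated in this diff):

```python
# Standalone sketch, not the refiners implementation: rescale (h, w) so the
# longest side equals `long_side`, keeping the aspect ratio.
def compute_target_size(size: tuple[int, int], long_side: int = 1024) -> tuple[int, int]:
    h, w = size
    scale = long_side / max(h, w)
    newh, neww = round(h * scale), round(w * scale)
    return (newh, neww)

print(compute_target_size((480, 640)))   # -> (768, 1024)
print(compute_target_size((1024, 768)))  # -> (1024, 768)
```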
@@ -171,7 +171,7 @@ class SegmentAnything(fl.Module):
         return (newh, neww)
     def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
-        """Preprocess an image.
+        """Preprocess an image without distorting its aspect ratio.
         Args:
             image: The image to preprocess.
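"Without distorting its aspect ratio" in practice means: resize to the target size computed above, then pad (rather than stretch) up to the encoder's square input. A hedged sketch with PIL and torch; padding to a 1024x1024 square and skipping pixel normalization are assumptions, not details taken from this diff.

```python
import torch
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor

# Standalone sketch, not the refiners implementation: resize to the
# aspect-preserving target size, then zero-pad to a square encoder input.
def preprocess_image(image: Image.Image, target_size: tuple[int, int], side: int = 1024) -> torch.Tensor:
    newh, neww = target_size
    resized = image.resize((neww, newh))          # PIL expects (width, height)
    x = pil_to_tensor(resized).float() / 255.0    # (3, newh, neww), values in [0, 1]
    padded = torch.zeros(3, side, side)
    padded[:, :newh, :neww] = x                   # pad right/bottom, no stretching
    return padded.unsqueeze(0)                    # (1, 3, side, side)

img = Image.new("RGB", (640, 480), color=(128, 128, 128))
print(preprocess_image(img, (768, 1024)).shape)   # torch.Size([1, 3, 1024, 1024])
```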