apply @deltheil suggestions

This commit is contained in:
Laurent 2024-02-02 15:05:23 +00:00 committed by Laureηt
parent f62e71da1c
commit 093527a7de
6 changed files with 19 additions and 12 deletions

View file

@ -154,7 +154,7 @@ class CLIPTextEncoderL(CLIPTextEncoder):
Note:
We replace the GeLU activation function with an approximate GeLU to comply with the original CLIP implementation
of OpenAI (https://github.com/openai/CLIP/blob/main/clip/model.py#L166)
of OpenAI (https://github.com/openai/CLIP/blob/a1d0717/clip/model.py#L166)
See [[arXiv:2103.00020] Learning Transferable Visual Models From Natural Language Supervision](https://arxiv.org/abs/2103.00020)
for more details.

View file

@ -238,8 +238,6 @@ class LatentDiffusionAutoencoder(Chain):
x = decoder(x / self.encoder_scale)
return x
# backward-compatibility alias
# TODO: deprecate this method
def image_to_latents(self, image: Image.Image) -> Tensor:
return self.images_to_latents([image])
@ -261,7 +259,6 @@ class LatentDiffusionAutoencoder(Chain):
def decode_latents(self, x: Tensor) -> Image.Image:
return self.latents_to_image(x)
# TODO: deprecated this method ?
def latents_to_image(self, x: Tensor) -> Image.Image:
if x.shape[0] != 1:
raise ValueError(f"Expected batch size of 1, got {x.shape[0]}")

View file

@ -121,6 +121,10 @@ class Solver(fl.Module, ABC):
def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
"""Remove noise from the input tensor using the current step of the diffusion process.
Note:
See [[arXiv:2006.11239] Denoising Diffusion Probabilistic Models, Equation 15](https://arxiv.org/abs/2006.11239)
and [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939).
Args:
x: The input tensor to remove noise from.
noise: The noise tensor to remove from the input tensor.
@ -132,9 +136,6 @@ class Solver(fl.Module, ABC):
timestep = self.timesteps[step]
cumulative_scale_factors = self.cumulative_scale_factors[timestep]
noise_stds = self.noise_std[timestep]
# See equation (15) from https://arxiv.org/pdf/2006.11239.pdf.
# Useful to preview progress or for guidance
# See also https://arxiv.org/pdf/2210.00939.pdf (self-attention guidance)
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
return denoised_x
@ -196,6 +197,9 @@ class Solver(fl.Module, ABC):
This method should only be overridden by solvers that
need to scale the input according to the current timestep.
By default, this method does not scale the input.
(scale=1)
Args:
x: The input tensor to scale.
step: The current step of the diffusion process.

View file

@ -96,6 +96,9 @@ class StableDiffusion_1(LatentDiffusionModel):
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
"""Set whether to enable self-attention guidance.
See [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939)
for more details.
Args:
enable: Whether to enable self-attention guidance.
scale: The scale to use.
@ -114,7 +117,7 @@ class StableDiffusion_1(LatentDiffusionModel):
return self._find_sag_adapter() is not None
def _find_sag_adapter(self) -> SD1SAGAdapter | None:
"""Finds the self-attention guidance adapter."""
"""Finds the self-attention guidance adapter, if any."""
for p in self.unet.get_parents():
if isinstance(p, SD1SAGAdapter):
return p

View file

@ -140,6 +140,9 @@ class StableDiffusion_XL(LatentDiffusionModel):
def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
"""Sets the self-attention guidance.
See [[arXiv:2210.00939] Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://arxiv.org/abs/2210.00939)
for more details.
Args:
enable: Whether to enable self-attention guidance or not.
scale: The scale to use.
@ -158,7 +161,7 @@ class StableDiffusion_XL(LatentDiffusionModel):
return self._find_sag_adapter() is not None
def _find_sag_adapter(self) -> SDXLSAGAdapter | None:
"""Finds the self-attention guidance adapter."""
"""Finds the self-attention guidance adapter, if any."""
for p in self.unet.get_parents():
if isinstance(p, SDXLSAGAdapter):
return p

View file

@ -154,10 +154,10 @@ class SegmentAnything(fl.Module):
return w
def compute_target_size(self, size: tuple[int, int]) -> tuple[int, int]:
"""Compute the target size for a given size.
"""Compute the target size as expected by the image encoder.
Args:
size: The size of the image.
size: The size of the input image.
Returns:
The target height.
@ -171,7 +171,7 @@ class SegmentAnything(fl.Module):
return (newh, neww)
def preprocess_image(self, image: Image.Image, target_size: tuple[int, int]) -> Tensor:
"""Preprocess an image.
"""Preprocess an image without distorting its aspect ratio.
Args:
image: The image to preprocess.