(doc/foundationals) add IPAdapter, related docstrings

This commit is contained in:
Laurent 2024-02-02 16:01:59 +00:00 committed by Laureηt
parent 6b35f1cc84
commit 1dcb36e1e0
3 changed files with 62 additions and 1 deletions

View file

@ -7,3 +7,5 @@
::: refiners.foundationals.latent_diffusion.solvers
::: refiners.foundationals.latent_diffusion.lora
::: refiners.foundationals.latent_diffusion.image_prompt

View file

@ -329,6 +329,12 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
"""Image Prompt adapter for a Stable Diffusion U-Net model.
See [[arXiv:2308.06721] IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models](https://arxiv.org/abs/2308.06721)
for more details.
"""
# Prevent PyTorch module registration
_clip_image_encoder: list[CLIPImageEncoderH]
_grid_image_encoder: list[CLIPImageEncoderH]
@ -343,6 +349,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
fine_grained: bool = False,
weights: dict[str, Tensor] | None = None,
) -> None:
"""Initialize the adapter.
Args:
target: The target model to adapt.
clip_image_encoder: The CLIP image encoder to use.
image_proj: The image projection to use.
scale: The scale to use for the image prompt.
fine_grained: Whether to use fine-grained image prompt.
weights: The weights of the IPAdapter.
"""
with self.setup_adapter(target):
super().__init__(target)
@ -376,6 +392,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
@property
def clip_image_encoder(self) -> CLIPImageEncoderH:
"""The CLIP image encoder of the adapter."""
return self._clip_image_encoder[0]
@property
@ -399,6 +416,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
@property
def scale(self) -> float:
"""The scale of the adapter."""
return self.sub_adapters[0].scale
@scale.setter
@ -411,6 +429,14 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
cross_attn.scale = scale
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
"""Set the CLIP image embedding context.
Note:
This is required by `ImageCrossAttention`.
Args:
image_embedding: The CLIP image embedding to set.
"""
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
@overload
@ -433,6 +459,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
weights: list[float] | None = None,
concat_batches: bool = True,
) -> Tensor:
"""Compute the CLIP image embedding.
Args:
image_prompt: The image prompt to use.
weights: The scale to use for the image prompt.
concat_batches: Whether to concatenate the batches.
Returns:
The CLIP image embedding.
"""
if isinstance(image_prompt, Image.Image):
image_prompt = self.preprocess_image(image_prompt)
elif isinstance(image_prompt, list):
@ -478,7 +514,18 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
mean: list[float] | None = None,
std: list[float] | None = None,
) -> Tensor:
# Default mean and std are parameters from https://github.com/openai/CLIP
"""Preprocess the image.
Note:
The default mean and std are parameters from
https://github.com/openai/CLIP
Args:
image: The image to preprocess.
size: The size to resize the image to.
mean: The mean to use for normalization.
std: The standard deviation to use for normalization.
"""
return normalize(
image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype),
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,

View file

@ -7,6 +7,8 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDX
class SDXLIPAdapter(IPAdapter[SDXLUNet]):
"""Image Prompt adapter for the Stable Diffusion XL U-Net model."""
def __init__(
self,
target: SDXLUNet,
@ -16,6 +18,16 @@ class SDXLIPAdapter(IPAdapter[SDXLUNet]):
fine_grained: bool = False,
weights: dict[str, Tensor] | None = None,
) -> None:
"""Initialize the adapter.
Args:
target: The SDXLUNet model to adapt.
clip_image_encoder: The CLIP image encoder to use.
image_proj: The image projection to use.
scale: The scale to use for the image prompt.
fine_grained: Whether to use fine-grained image prompt.
weights: The weights of the IPAdapter.
"""
clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
if image_proj is None: