diff --git a/docs/reference/foundationals/latent_diffusion.md b/docs/reference/foundationals/latent_diffusion.md index 0e290e6..85ae5d6 100644 --- a/docs/reference/foundationals/latent_diffusion.md +++ b/docs/reference/foundationals/latent_diffusion.md @@ -7,3 +7,5 @@ ::: refiners.foundationals.latent_diffusion.solvers ::: refiners.foundationals.latent_diffusion.lora + +::: refiners.foundationals.latent_diffusion.image_prompt diff --git a/src/refiners/foundationals/latent_diffusion/image_prompt.py b/src/refiners/foundationals/latent_diffusion/image_prompt.py index 2771255..1c35eb9 100644 --- a/src/refiners/foundationals/latent_diffusion/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/image_prompt.py @@ -329,6 +329,12 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]): class IPAdapter(Generic[T], fl.Chain, Adapter[T]): + """Image Prompt adapter for a Stable Diffusion U-Net model. + + See [[arXiv:2308.06721] IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models](https://arxiv.org/abs/2308.06721) + for more details. + """ + # Prevent PyTorch module registration _clip_image_encoder: list[CLIPImageEncoderH] _grid_image_encoder: list[CLIPImageEncoderH] @@ -343,6 +349,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): fine_grained: bool = False, weights: dict[str, Tensor] | None = None, ) -> None: + """Initialize the adapter. + + Args: + target: The target model to adapt. + clip_image_encoder: The CLIP image encoder to use. + image_proj: The image projection to use. + scale: The scale to use for the image prompt. + fine_grained: Whether to use fine-grained image prompt. + weights: The weights of the IPAdapter. + """ with self.setup_adapter(target): super().__init__(target) @@ -376,6 +392,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): @property def clip_image_encoder(self) -> CLIPImageEncoderH: + """The CLIP image encoder of the adapter.""" return self._clip_image_encoder[0] @property @@ -399,6 +416,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): @property def scale(self) -> float: + """The scale of the adapter.""" return self.sub_adapters[0].scale @scale.setter @@ -411,6 +429,14 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): cross_attn.scale = scale def set_clip_image_embedding(self, image_embedding: Tensor) -> None: + """Set the CLIP image embedding context. + + Note: + This is required by `ImageCrossAttention`. + + Args: + image_embedding: The CLIP image embedding to set. + """ self.set_context("ip_adapter", {"clip_image_embedding": image_embedding}) @overload @@ -433,6 +459,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): weights: list[float] | None = None, concat_batches: bool = True, ) -> Tensor: + """Compute the CLIP image embedding. + + Args: + image_prompt: The image prompt to use. + weights: The scale to use for the image prompt. + concat_batches: Whether to concatenate the batches. + + Returns: + The CLIP image embedding. + """ if isinstance(image_prompt, Image.Image): image_prompt = self.preprocess_image(image_prompt) elif isinstance(image_prompt, list): @@ -478,7 +514,18 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]): mean: list[float] | None = None, std: list[float] | None = None, ) -> Tensor: - # Default mean and std are parameters from https://github.com/openai/CLIP + """Preprocess the image. + + Note: + The default mean and std are parameters from + https://github.com/openai/CLIP + + Args: + image: The image to preprocess. + size: The size to resize the image to. + mean: The mean to use for normalization. + std: The standard deviation to use for normalization. + """ return normalize( image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype), mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean, diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py index 934ad29..c129766 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/image_prompt.py @@ -7,6 +7,8 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDX class SDXLIPAdapter(IPAdapter[SDXLUNet]): + """Image Prompt adapter for the Stable Diffusion XL U-Net model.""" + def __init__( self, target: SDXLUNet, @@ -16,6 +18,16 @@ class SDXLIPAdapter(IPAdapter[SDXLUNet]): fine_grained: bool = False, weights: dict[str, Tensor] | None = None, ) -> None: + """Initialize the adapter. + + Args: + target: The SDXLUNet model to adapt. + clip_image_encoder: The CLIP image encoder to use. + image_proj: The image projection to use. + scale: The scale to use for the image prompt. + fine_grained: Whether to use fine-grained image prompt. + weights: The weights of the IPAdapter. + """ clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype) if image_proj is None: