mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 06:38:45 +00:00
(doc/foundationals) add IPAdapter
, related docstrings
This commit is contained in:
parent
0643a40cd2
commit
1cbad09ca0
|
@ -7,3 +7,5 @@
|
|||
::: refiners.foundationals.latent_diffusion.solvers
|
||||
|
||||
::: refiners.foundationals.latent_diffusion.lora
|
||||
|
||||
::: refiners.foundationals.latent_diffusion.image_prompt
|
||||
|
|
|
@ -329,6 +329,12 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
|||
|
||||
|
||||
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||
"""Image Prompt adapter for a Stable Diffusion U-Net model.
|
||||
|
||||
See [[arXiv:2308.06721] IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models](https://arxiv.org/abs/2308.06721)
|
||||
for more details.
|
||||
"""
|
||||
|
||||
# Prevent PyTorch module registration
|
||||
_clip_image_encoder: list[CLIPImageEncoderH]
|
||||
_grid_image_encoder: list[CLIPImageEncoderH]
|
||||
|
@ -343,6 +349,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
fine_grained: bool = False,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the adapter.
|
||||
|
||||
Args:
|
||||
target: The target model to adapt.
|
||||
clip_image_encoder: The CLIP image encoder to use.
|
||||
image_proj: The image projection to use.
|
||||
scale: The scale to use for the image prompt.
|
||||
fine_grained: Whether to use fine-grained image prompt.
|
||||
weights: The weights of the IPAdapter.
|
||||
"""
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(target)
|
||||
|
||||
|
@ -376,6 +392,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
|
||||
@property
|
||||
def clip_image_encoder(self) -> CLIPImageEncoderH:
|
||||
"""The CLIP image encoder of the adapter."""
|
||||
return self._clip_image_encoder[0]
|
||||
|
||||
@property
|
||||
|
@ -399,6 +416,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
|
||||
@property
|
||||
def scale(self) -> float:
|
||||
"""The scale of the adapter."""
|
||||
return self.sub_adapters[0].scale
|
||||
|
||||
@scale.setter
|
||||
|
@ -411,6 +429,14 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
cross_attn.scale = scale
|
||||
|
||||
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
|
||||
"""Set the CLIP image embedding context.
|
||||
|
||||
Note:
|
||||
This is required by `ImageCrossAttention`.
|
||||
|
||||
Args:
|
||||
image_embedding: The CLIP image embedding to set.
|
||||
"""
|
||||
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
|
||||
|
||||
@overload
|
||||
|
@ -433,6 +459,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
weights: list[float] | None = None,
|
||||
concat_batches: bool = True,
|
||||
) -> Tensor:
|
||||
"""Compute the CLIP image embedding.
|
||||
|
||||
Args:
|
||||
image_prompt: The image prompt to use.
|
||||
weights: The scale to use for the image prompt.
|
||||
concat_batches: Whether to concatenate the batches.
|
||||
|
||||
Returns:
|
||||
The CLIP image embedding.
|
||||
"""
|
||||
if isinstance(image_prompt, Image.Image):
|
||||
image_prompt = self.preprocess_image(image_prompt)
|
||||
elif isinstance(image_prompt, list):
|
||||
|
@ -478,7 +514,18 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
mean: list[float] | None = None,
|
||||
std: list[float] | None = None,
|
||||
) -> Tensor:
|
||||
# Default mean and std are parameters from https://github.com/openai/CLIP
|
||||
"""Preprocess the image.
|
||||
|
||||
Note:
|
||||
The default mean and std are parameters from
|
||||
https://github.com/openai/CLIP
|
||||
|
||||
Args:
|
||||
image: The image to preprocess.
|
||||
size: The size to resize the image to.
|
||||
mean: The mean to use for normalization.
|
||||
std: The standard deviation to use for normalization.
|
||||
"""
|
||||
return normalize(
|
||||
image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype),
|
||||
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
|
||||
|
|
|
@ -7,6 +7,8 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDX
|
|||
|
||||
|
||||
class SDXLIPAdapter(IPAdapter[SDXLUNet]):
|
||||
"""Image Prompt adapter for the Stable Diffusion XL U-Net model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
target: SDXLUNet,
|
||||
|
@ -16,6 +18,16 @@ class SDXLIPAdapter(IPAdapter[SDXLUNet]):
|
|||
fine_grained: bool = False,
|
||||
weights: dict[str, Tensor] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the adapter.
|
||||
|
||||
Args:
|
||||
target: The SDXLUNet model to adapt.
|
||||
clip_image_encoder: The CLIP image encoder to use.
|
||||
image_proj: The image projection to use.
|
||||
scale: The scale to use for the image prompt.
|
||||
fine_grained: Whether to use fine-grained image prompt.
|
||||
weights: The weights of the IPAdapter.
|
||||
"""
|
||||
clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
|
||||
|
||||
if image_proj is None:
|
||||
|
|
Loading…
Reference in a new issue