mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 14:48:45 +00:00
(doc/foundationals) add IPAdapter
, related docstrings
This commit is contained in:
parent
0643a40cd2
commit
1cbad09ca0
|
@ -7,3 +7,5 @@
|
||||||
::: refiners.foundationals.latent_diffusion.solvers
|
::: refiners.foundationals.latent_diffusion.solvers
|
||||||
|
|
||||||
::: refiners.foundationals.latent_diffusion.lora
|
::: refiners.foundationals.latent_diffusion.lora
|
||||||
|
|
||||||
|
::: refiners.foundationals.latent_diffusion.image_prompt
|
||||||
|
|
|
@ -329,6 +329,12 @@ class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
|
||||||
|
|
||||||
|
|
||||||
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
|
"""Image Prompt adapter for a Stable Diffusion U-Net model.
|
||||||
|
|
||||||
|
See [[arXiv:2308.06721] IP-Adapter: Text Compatible Image Prompt Adapter for Text-to-Image Diffusion Models](https://arxiv.org/abs/2308.06721)
|
||||||
|
for more details.
|
||||||
|
"""
|
||||||
|
|
||||||
# Prevent PyTorch module registration
|
# Prevent PyTorch module registration
|
||||||
_clip_image_encoder: list[CLIPImageEncoderH]
|
_clip_image_encoder: list[CLIPImageEncoderH]
|
||||||
_grid_image_encoder: list[CLIPImageEncoderH]
|
_grid_image_encoder: list[CLIPImageEncoderH]
|
||||||
|
@ -343,6 +349,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
fine_grained: bool = False,
|
fine_grained: bool = False,
|
||||||
weights: dict[str, Tensor] | None = None,
|
weights: dict[str, Tensor] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize the adapter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: The target model to adapt.
|
||||||
|
clip_image_encoder: The CLIP image encoder to use.
|
||||||
|
image_proj: The image projection to use.
|
||||||
|
scale: The scale to use for the image prompt.
|
||||||
|
fine_grained: Whether to use fine-grained image prompt.
|
||||||
|
weights: The weights of the IPAdapter.
|
||||||
|
"""
|
||||||
with self.setup_adapter(target):
|
with self.setup_adapter(target):
|
||||||
super().__init__(target)
|
super().__init__(target)
|
||||||
|
|
||||||
|
@ -376,6 +392,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def clip_image_encoder(self) -> CLIPImageEncoderH:
|
def clip_image_encoder(self) -> CLIPImageEncoderH:
|
||||||
|
"""The CLIP image encoder of the adapter."""
|
||||||
return self._clip_image_encoder[0]
|
return self._clip_image_encoder[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -399,6 +416,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scale(self) -> float:
|
def scale(self) -> float:
|
||||||
|
"""The scale of the adapter."""
|
||||||
return self.sub_adapters[0].scale
|
return self.sub_adapters[0].scale
|
||||||
|
|
||||||
@scale.setter
|
@scale.setter
|
||||||
|
@ -411,6 +429,14 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
cross_attn.scale = scale
|
cross_attn.scale = scale
|
||||||
|
|
||||||
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
|
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
|
||||||
|
"""Set the CLIP image embedding context.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This is required by `ImageCrossAttention`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_embedding: The CLIP image embedding to set.
|
||||||
|
"""
|
||||||
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
|
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
@ -433,6 +459,16 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
weights: list[float] | None = None,
|
weights: list[float] | None = None,
|
||||||
concat_batches: bool = True,
|
concat_batches: bool = True,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
"""Compute the CLIP image embedding.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_prompt: The image prompt to use.
|
||||||
|
weights: The scale to use for the image prompt.
|
||||||
|
concat_batches: Whether to concatenate the batches.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The CLIP image embedding.
|
||||||
|
"""
|
||||||
if isinstance(image_prompt, Image.Image):
|
if isinstance(image_prompt, Image.Image):
|
||||||
image_prompt = self.preprocess_image(image_prompt)
|
image_prompt = self.preprocess_image(image_prompt)
|
||||||
elif isinstance(image_prompt, list):
|
elif isinstance(image_prompt, list):
|
||||||
|
@ -478,7 +514,18 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
||||||
mean: list[float] | None = None,
|
mean: list[float] | None = None,
|
||||||
std: list[float] | None = None,
|
std: list[float] | None = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
# Default mean and std are parameters from https://github.com/openai/CLIP
|
"""Preprocess the image.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
The default mean and std are parameters from
|
||||||
|
https://github.com/openai/CLIP
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: The image to preprocess.
|
||||||
|
size: The size to resize the image to.
|
||||||
|
mean: The mean to use for normalization.
|
||||||
|
std: The standard deviation to use for normalization.
|
||||||
|
"""
|
||||||
return normalize(
|
return normalize(
|
||||||
image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype),
|
image_to_tensor(image.resize(size), device=self.target.device, dtype=self.target.dtype),
|
||||||
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
|
mean=[0.48145466, 0.4578275, 0.40821073] if mean is None else mean,
|
||||||
|
|
|
@ -7,6 +7,8 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDX
|
||||||
|
|
||||||
|
|
||||||
class SDXLIPAdapter(IPAdapter[SDXLUNet]):
|
class SDXLIPAdapter(IPAdapter[SDXLUNet]):
|
||||||
|
"""Image Prompt adapter for the Stable Diffusion XL U-Net model."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
target: SDXLUNet,
|
target: SDXLUNet,
|
||||||
|
@ -16,6 +18,16 @@ class SDXLIPAdapter(IPAdapter[SDXLUNet]):
|
||||||
fine_grained: bool = False,
|
fine_grained: bool = False,
|
||||||
weights: dict[str, Tensor] | None = None,
|
weights: dict[str, Tensor] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize the adapter.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: The SDXLUNet model to adapt.
|
||||||
|
clip_image_encoder: The CLIP image encoder to use.
|
||||||
|
image_proj: The image projection to use.
|
||||||
|
scale: The scale to use for the image prompt.
|
||||||
|
fine_grained: Whether to use fine-grained image prompt.
|
||||||
|
weights: The weights of the IPAdapter.
|
||||||
|
"""
|
||||||
clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
|
clip_image_encoder = clip_image_encoder or CLIPImageEncoderH(device=target.device, dtype=target.dtype)
|
||||||
|
|
||||||
if image_proj is None:
|
if image_proj is None:
|
||||||
|
|
Loading…
Reference in a new issue