add helper for multiple image prompts

Cédric Deltheil 2024-01-29 17:06:56 +00:00 committed by Cédric Deltheil
parent fd01ba910e
commit ca5c5a7ca5

@@ -1,9 +1,9 @@
 import math
-from typing import TYPE_CHECKING, Any, Generic, TypeVar
+from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
 
 from jaxtyping import Float
 from PIL import Image
-from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like
+from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, tensor, zeros_like
 
 import refiners.fluxion.layers as fl
 from refiners.fluxion.adapters.adapter import Adapter
@@ -413,7 +413,53 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
     def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
         self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
 
+    @overload
     def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
+        ...
+
+    @overload
+    def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor:
+        ...
+
+    @overload
+    def compute_clip_image_embedding(
+        self, image_prompt: list[Image.Image], weights: list[float] | None = None
+    ) -> Tensor:
+        ...
+
+    def compute_clip_image_embedding(
+        self,
+        image_prompt: Tensor | Image.Image | list[Image.Image],
+        weights: list[float] | None = None,
+        concat_batches: bool = True,
+    ) -> Tensor:
+        if isinstance(image_prompt, Image.Image):
+            image_prompt = self.preprocess_image(image_prompt)
+        elif isinstance(image_prompt, list):
+            assert all(isinstance(image, Image.Image) for image in image_prompt)
+            image_prompt = cat([self.preprocess_image(image) for image in image_prompt])
+
+        negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
+
+        batch_size = image_prompt.shape[0]
+        if weights is not None:
+            assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
+            if any(weight != 1.0 for weight in weights):
+                conditional_embedding *= (
+                    tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
+                    .unsqueeze(-1)
+                    .unsqueeze(-1)
+                )
+
+        if batch_size > 1 and concat_batches:
+            # Create a longer image tokens sequence when a batch of images is given
+            # See https://github.com/tencent-ailab/IP-Adapter/issues/99
+            negative_embedding = cat(negative_embedding.chunk(batch_size), dim=1)
+            conditional_embedding = cat(conditional_embedding.chunk(batch_size), dim=1)
+
+        return cat((negative_embedding, conditional_embedding))
+
+    def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
         image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
         clip_embedding = image_encoder(image_prompt)
         conditional_embedding = self.image_proj(clip_embedding)
@@ -423,7 +469,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
         # See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
         clip_embedding = image_encoder(zeros_like(image_prompt))
         negative_embedding = self.image_proj(clip_embedding)
-        return cat((negative_embedding, conditional_embedding))
+        return negative_embedding, conditional_embedding
 
     def preprocess_image(
         self,
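
The new overloads let callers pass a single tensor, a single PIL image, or a list of images with optional per-image weights. A minimal usage sketch of the helper (the `ip_adapter` instance and the file names below are illustrative assumptions, not part of this commit):

from PIL import Image

# assume `ip_adapter` is an IPAdapter already injected into a target model (illustrative)
images = [Image.open("style.png"), Image.open("subject.png")]

# single image: same behavior as before this commit
embedding = ip_adapter.compute_clip_image_embedding(images[0])

# multiple images: weights scale each conditional embedding; with
# concat_batches=True (the default) the per-image token sequences are
# concatenated along dim=1 into one longer sequence
embedding = ip_adapter.compute_clip_image_embedding(images, weights=[1.0, 0.5])
ip_adapter.set_clip_image_embedding(embedding)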