mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add helper for multiple image prompts
This commit is contained in:
parent
fd01ba910e
commit
ca5c5a7ca5
|
@ -1,9 +1,9 @@
|
|||
import math
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
|
||||
|
||||
from jaxtyping import Float
|
||||
from PIL import Image
|
||||
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like
|
||||
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, tensor, zeros_like
|
||||
|
||||
import refiners.fluxion.layers as fl
|
||||
from refiners.fluxion.adapters.adapter import Adapter
|
||||
|
@ -413,7 +413,53 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
|
||||
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})
|
||||
|
||||
@overload
|
||||
def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
|
||||
...
|
||||
|
||||
@overload
|
||||
def compute_clip_image_embedding(self, image_prompt: Image.Image) -> Tensor:
|
||||
...
|
||||
|
||||
@overload
|
||||
def compute_clip_image_embedding(
|
||||
self, image_prompt: list[Image.Image], weights: list[float] | None = None
|
||||
) -> Tensor:
|
||||
...
|
||||
|
||||
def compute_clip_image_embedding(
|
||||
self,
|
||||
image_prompt: Tensor | Image.Image | list[Image.Image],
|
||||
weights: list[float] | None = None,
|
||||
concat_batches: bool = True,
|
||||
) -> Tensor:
|
||||
if isinstance(image_prompt, Image.Image):
|
||||
image_prompt = self.preprocess_image(image_prompt)
|
||||
elif isinstance(image_prompt, list):
|
||||
assert all(isinstance(image, Image.Image) for image in image_prompt)
|
||||
image_prompt = cat([self.preprocess_image(image) for image in image_prompt])
|
||||
|
||||
negative_embedding, conditional_embedding = self._compute_clip_image_embedding(image_prompt)
|
||||
|
||||
batch_size = image_prompt.shape[0]
|
||||
if weights is not None:
|
||||
assert len(weights) == batch_size, f"Got {len(weights)} weights for {batch_size} images"
|
||||
if any(weight != 1.0 for weight in weights):
|
||||
conditional_embedding *= (
|
||||
tensor(weights, device=conditional_embedding.device, dtype=conditional_embedding.dtype)
|
||||
.unsqueeze(-1)
|
||||
.unsqueeze(-1)
|
||||
)
|
||||
|
||||
if batch_size > 1 and concat_batches:
|
||||
# Create a longer image tokens sequence when a batch of images is given
|
||||
# See https://github.com/tencent-ailab/IP-Adapter/issues/99
|
||||
negative_embedding = cat(negative_embedding.chunk(batch_size), dim=1)
|
||||
conditional_embedding = cat(conditional_embedding.chunk(batch_size), dim=1)
|
||||
|
||||
return cat((negative_embedding, conditional_embedding))
|
||||
|
||||
def _compute_clip_image_embedding(self, image_prompt: Tensor) -> tuple[Tensor, Tensor]:
|
||||
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
|
||||
clip_embedding = image_encoder(image_prompt)
|
||||
conditional_embedding = self.image_proj(clip_embedding)
|
||||
|
@ -423,7 +469,7 @@ class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
|
|||
# See https://github.com/tencent-ailab/IP-Adapter/blob/d580c50/tutorial_train_plus.py#L351-L352
|
||||
clip_embedding = image_encoder(zeros_like(image_prompt))
|
||||
negative_embedding = self.image_proj(clip_embedding)
|
||||
return cat((negative_embedding, conditional_embedding))
|
||||
return negative_embedding, conditional_embedding
|
||||
|
||||
def preprocess_image(
|
||||
self,
|
||||
|
|
Loading…
Reference in a new issue