diff --git a/README.md b/README.md index 3dc748c..133ac76 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ ______________________________________________________________________ ## Latest News 🔥 +- Added [ELLA](https://arxiv.org/abs/2403.05135) for better prompts handling (contributed by [@ily-R](https://github.com/ily-R)) - Added the Box Segmenter all-in-one solution ([model](https://huggingface.co/finegrain/finegrain-box-segmenter), [HF Space](https://huggingface.co/spaces/finegrain/finegrain-object-cutter)) - Added [MVANet](https://arxiv.org/abs/2404.07445) for high resolution segmentation - Added [IC-Light](https://github.com/lllyasviel/IC-Light) to manipulate the illumination of images diff --git a/docs/reference/foundationals/latent_diffusion.md b/docs/reference/foundationals/latent_diffusion.md index 35731bf..7f01803 100644 --- a/docs/reference/foundationals/latent_diffusion.md +++ b/docs/reference/foundationals/latent_diffusion.md @@ -15,3 +15,5 @@ ::: refiners.foundationals.latent_diffusion.style_aligned ::: refiners.foundationals.latent_diffusion.multi_diffusion + +::: refiners.foundationals.latent_diffusion.ella_adapter diff --git a/src/refiners/foundationals/latent_diffusion/ella_adapter.py b/src/refiners/foundationals/latent_diffusion/ella_adapter.py index ed2b543..59bd781 100644 --- a/src/refiners/foundationals/latent_diffusion/ella_adapter.py +++ b/src/refiners/foundationals/latent_diffusion/ella_adapter.py @@ -210,6 +210,11 @@ class PerceiverResampler(fl.Chain): class ELLA(fl.Passthrough): + """ELLA latents encoder. + + See [[arXiv:2403.05135] ELLA: Equip Diffusion Models with LLM for Enhanced Semantic Alignment](https://arxiv.org/abs/2403.05135) for more details. + """ + def __init__( self, time_channel: int, @@ -249,6 +254,8 @@ class ELLACrossAttentionAdapter(fl.Chain, Adapter[fl.UseContext]): class ELLAAdapter(Generic[T], fl.Chain, Adapter[T]): + """Adapter for [`ELLA`][refiners.foundationals.latent_diffusion.ella_adapter.ELLA].""" + def __init__(self, target: T, latents_encoder: ELLA, weights: dict[str, Tensor] | None = None) -> None: if weights is not None: latents_encoder.load_state_dict(weights) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ella_adapter.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ella_adapter.py index 4b4fccd..507f8e4 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ella_adapter.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/ella_adapter.py @@ -5,7 +5,15 @@ from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1U class SD1ELLAAdapter(ELLAAdapter[SD1UNet]): + """[`ELLA`][refiners.foundationals.latent_diffusion.ella_adapter.ELLA] adapter for Stable Diffusion 1.5.""" + def __init__(self, target: SD1UNet, weights: dict[str, Tensor] | None = None) -> None: + """Initialize the adapter. + + Args: + target: The target model to adapt. + weights: The weights of the ELLA adapter (see `scripts/conversion/convert_ella_adapter.py`). + """ latents_encoder = ELLA( time_channel=320, timestep_embedding_dim=768,