From 35b6e2f7c5194d49735b8ca35d113e76e230d7ff Mon Sep 17 00:00:00 2001 From: Laurent Date: Wed, 14 Feb 2024 15:21:34 +0000 Subject: [PATCH] add context_key getter and setter to `TimestepEncoder` --- .../stable_diffusion_xl/unet.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py index 9663fa6..b44a8b2 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py @@ -54,7 +54,12 @@ class TextTimeEmbedding(fl.Chain): class TimestepEncoder(fl.Passthrough): - def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None: + def __init__( + self, + context_key: str = "timestep_embedding", + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: self.timestep_embedding_dim = 1280 super().__init__( fl.Sum( @@ -69,9 +74,21 @@ class TimestepEncoder(fl.Passthrough): ), TextTimeEmbedding(device=device, dtype=dtype), ), - fl.SetContext(context="range_adapter", key="timestep_embedding"), + fl.SetContext(context="range_adapter", key=context_key), ) + @property + def context_key(self) -> str: + set_context_module = self.ensure_find(fl.SetContext) + assert set_context_module.context == "range_adapter" + return set_context_module.key + + @context_key.setter + def context_key(self, value: str) -> None: + set_context_module = self.ensure_find(fl.SetContext) + assert set_context_module.context == "range_adapter" + set_context_module.key = value + class SDXLCrossAttention(CrossAttentionBlock2d): def __init__(