diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py index 9663fa6..b44a8b2 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/unet.py @@ -54,7 +54,12 @@ class TextTimeEmbedding(fl.Chain): class TimestepEncoder(fl.Passthrough): - def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None: + def __init__( + self, + context_key: str = "timestep_embedding", + device: Device | str | None = None, + dtype: DType | None = None, + ) -> None: self.timestep_embedding_dim = 1280 super().__init__( fl.Sum( @@ -69,9 +74,21 @@ class TimestepEncoder(fl.Passthrough): ), TextTimeEmbedding(device=device, dtype=dtype), ), - fl.SetContext(context="range_adapter", key="timestep_embedding"), + fl.SetContext(context="range_adapter", key=context_key), ) + @property + def context_key(self) -> str: + set_context_module = self.ensure_find(fl.SetContext) + assert set_context_module.context == "range_adapter" + return set_context_module.key + + @context_key.setter + def context_key(self, value: str) -> None: + set_context_module = self.ensure_find(fl.SetContext) + assert set_context_module.context == "range_adapter" + set_context_module.key = value + class SDXLCrossAttention(CrossAttentionBlock2d): def __init__(