add context_key getter and setter to TimestepEncoder

This commit is contained in:
Laurent 2024-02-14 15:21:34 +00:00 committed by Laureηt
parent 0230971543
commit 35b6e2f7c5

View file

@ -54,7 +54,12 @@ class TextTimeEmbedding(fl.Chain):
class TimestepEncoder(fl.Passthrough): class TimestepEncoder(fl.Passthrough):
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None: def __init__(
self,
context_key: str = "timestep_embedding",
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.timestep_embedding_dim = 1280 self.timestep_embedding_dim = 1280
super().__init__( super().__init__(
fl.Sum( fl.Sum(
@ -69,9 +74,21 @@ class TimestepEncoder(fl.Passthrough):
), ),
TextTimeEmbedding(device=device, dtype=dtype), TextTimeEmbedding(device=device, dtype=dtype),
), ),
fl.SetContext(context="range_adapter", key="timestep_embedding"), fl.SetContext(context="range_adapter", key=context_key),
) )
@property
def context_key(self) -> str:
set_context_module = self.ensure_find(fl.SetContext)
assert set_context_module.context == "range_adapter"
return set_context_module.key
@context_key.setter
def context_key(self, value: str) -> None:
set_context_module = self.ensure_find(fl.SetContext)
assert set_context_module.context == "range_adapter"
set_context_module.key = value
class SDXLCrossAttention(CrossAttentionBlock2d): class SDXLCrossAttention(CrossAttentionBlock2d):
def __init__( def __init__(