mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
add context_key getter and setter to TimestepEncoder
This commit is contained in:
parent
0230971543
commit
35b6e2f7c5
|
@ -54,7 +54,12 @@ class TextTimeEmbedding(fl.Chain):
|
||||||
|
|
||||||
|
|
||||||
class TimestepEncoder(fl.Passthrough):
|
class TimestepEncoder(fl.Passthrough):
|
||||||
def __init__(self, device: Device | str | None = None, dtype: DType | None = None) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
context_key: str = "timestep_embedding",
|
||||||
|
device: Device | str | None = None,
|
||||||
|
dtype: DType | None = None,
|
||||||
|
) -> None:
|
||||||
self.timestep_embedding_dim = 1280
|
self.timestep_embedding_dim = 1280
|
||||||
super().__init__(
|
super().__init__(
|
||||||
fl.Sum(
|
fl.Sum(
|
||||||
|
@ -69,9 +74,21 @@ class TimestepEncoder(fl.Passthrough):
|
||||||
),
|
),
|
||||||
TextTimeEmbedding(device=device, dtype=dtype),
|
TextTimeEmbedding(device=device, dtype=dtype),
|
||||||
),
|
),
|
||||||
fl.SetContext(context="range_adapter", key="timestep_embedding"),
|
fl.SetContext(context="range_adapter", key=context_key),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def context_key(self) -> str:
|
||||||
|
set_context_module = self.ensure_find(fl.SetContext)
|
||||||
|
assert set_context_module.context == "range_adapter"
|
||||||
|
return set_context_module.key
|
||||||
|
|
||||||
|
@context_key.setter
|
||||||
|
def context_key(self, value: str) -> None:
|
||||||
|
set_context_module = self.ensure_find(fl.SetContext)
|
||||||
|
assert set_context_module.context == "range_adapter"
|
||||||
|
set_context_module.key = value
|
||||||
|
|
||||||
|
|
||||||
class SDXLCrossAttention(CrossAttentionBlock2d):
|
class SDXLCrossAttention(CrossAttentionBlock2d):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
Loading…
Reference in a new issue