add context_key getter and setter to RangeAdapter2d

This commit is contained in:
Laurent 2024-02-14 15:21:49 +00:00 committed by Laureηt
parent 35b6e2f7c5
commit a54808e757

View file

@ -55,14 +55,31 @@ class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
) -> None:
self.channels = channels
self.embedding_dim = embedding_dim
self.context_key = context_key
with self.setup_adapter(target):
super().__init__(
target,
fl.Chain(
fl.UseContext("range_adapter", context_key),
fl.SiLU(),
fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype),
fl.Linear(
in_features=embedding_dim,
out_features=channels,
device=device,
dtype=dtype,
),
fl.Reshape(channels, 1, 1),
),
)
@property
def context_key(self) -> str:
use_context_module = self.ensure_find(fl.UseContext)
assert use_context_module.context == "range_adapter"
return use_context_module.key
@context_key.setter
def context_key(self, value: str) -> None:
use_context_module = self.ensure_find(fl.UseContext)
assert use_context_module.context == "range_adapter"
use_context_module.key = value