mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 15:02:01 +00:00
add context_key getter and setter to RangeAdapter2d
This commit is contained in:
parent
35b6e2f7c5
commit
a54808e757
|
@ -55,14 +55,31 @@ class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
|
|||
) -> None:
|
||||
self.channels = channels
|
||||
self.embedding_dim = embedding_dim
|
||||
self.context_key = context_key
|
||||
|
||||
with self.setup_adapter(target):
|
||||
super().__init__(
|
||||
target,
|
||||
fl.Chain(
|
||||
fl.UseContext("range_adapter", context_key),
|
||||
fl.SiLU(),
|
||||
fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype),
|
||||
fl.Linear(
|
||||
in_features=embedding_dim,
|
||||
out_features=channels,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
),
|
||||
fl.Reshape(channels, 1, 1),
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def context_key(self) -> str:
|
||||
use_context_module = self.ensure_find(fl.UseContext)
|
||||
assert use_context_module.context == "range_adapter"
|
||||
return use_context_module.key
|
||||
|
||||
@context_key.setter
|
||||
def context_key(self, value: str) -> None:
|
||||
use_context_module = self.ensure_find(fl.UseContext)
|
||||
assert use_context_module.context == "range_adapter"
|
||||
use_context_module.key = value
|
||||
|
|
Loading…
Reference in a new issue