mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
add context_key getter and setter to RangeAdapter2d
This commit is contained in:
parent
35b6e2f7c5
commit
a54808e757
|
@ -55,14 +55,31 @@ class RangeAdapter2d(fl.Sum, Adapter[fl.Conv2d]):
|
||||||
) -> None:
|
) -> None:
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
self.context_key = context_key
|
|
||||||
with self.setup_adapter(target):
|
with self.setup_adapter(target):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
target,
|
target,
|
||||||
fl.Chain(
|
fl.Chain(
|
||||||
fl.UseContext("range_adapter", context_key),
|
fl.UseContext("range_adapter", context_key),
|
||||||
fl.SiLU(),
|
fl.SiLU(),
|
||||||
fl.Linear(in_features=embedding_dim, out_features=channels, device=device, dtype=dtype),
|
fl.Linear(
|
||||||
|
in_features=embedding_dim,
|
||||||
|
out_features=channels,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype,
|
||||||
|
),
|
||||||
fl.Reshape(channels, 1, 1),
|
fl.Reshape(channels, 1, 1),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def context_key(self) -> str:
|
||||||
|
use_context_module = self.ensure_find(fl.UseContext)
|
||||||
|
assert use_context_module.context == "range_adapter"
|
||||||
|
return use_context_module.key
|
||||||
|
|
||||||
|
@context_key.setter
|
||||||
|
def context_key(self, value: str) -> None:
|
||||||
|
use_context_module = self.ensure_find(fl.UseContext)
|
||||||
|
assert use_context_module.context == "range_adapter"
|
||||||
|
use_context_module.key = value
|
||||||
|
|
Loading…
Reference in a new issue