mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-15 09:38:14 +00:00
26 lines
738 B
Python
|
import refiners.fluxion.layers as fl
|
||
|
from refiners.fluxion.adapters.adapter import Adapter
|
||
|
from refiners.fluxion.context import Contexts
|
||
|
|
||
|
|
||
|
class ContextAdapter(fl.Chain, Adapter[fl.Chain]):
    """Test adapter whose chain emits a constant and hands it to a context setter.

    When called, its two layers run in order: a zero-argument ``fl.Lambda``
    that produces ``42``, followed by ``fl.SetContext("foo", "bar")``, which
    receives that value.
    """

    def __init__(self, target: fl.Chain) -> None:
        # Layers are built first; the Chain constructor itself is invoked
        # within ``setup_adapter`` for the given target.
        produce_answer = fl.Lambda(lambda: 42)
        store_answer = fl.SetContext("foo", "bar")
        with self.setup_adapter(target):
            super().__init__(produce_answer, store_answer)
|
||
|
|
||
|
|
||
|
class ContextChain(fl.Chain):
    """Chain that declares a ``"foo"`` context containing a single ``"bar"`` entry."""

    def init_context(self) -> Contexts:
        # "bar" is declared with an initial value of None.
        contexts: Contexts = {"foo": {"bar": None}}
        return contexts
|
||
|
|
||
|
|
||
|
def test_adapter_can_access_parent_context() -> None:
    """An adapter injected into a ContextChain can feed the parent's context.

    The adapter wraps the parent's empty inner ``Chain`` and is injected into
    the parent. Calling the parent is expected to return 42: the value the
    adapter passes to ``SetContext("foo", "bar")``, which the sibling
    ``UseContext("foo", "bar")`` layer reads back.
    """
    parent = ContextChain(fl.Chain(), fl.UseContext("foo", "bar"))
    target = parent.layer("Chain", fl.Chain)

    adapter = ContextAdapter(target)
    adapter.inject(parent)

    result = parent()
    assert result == 42
|