refiners/tests/adapters/test_adapter_context.py

26 lines
738 B
Python

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts
class ContextAdapter(fl.Chain, Adapter[fl.Chain]):
def __init__(self, target: fl.Chain):
with self.setup_adapter(target):
super().__init__(
fl.Lambda(lambda: 42),
fl.SetContext("foo", "bar"),
)
class ContextChain(fl.Chain):
def init_context(self) -> Contexts:
return {"foo": {"bar": None}}
def test_adapter_can_access_parent_context():
chain = ContextChain(fl.Chain(), fl.UseContext("foo", "bar"))
adaptee = chain.layer("Chain", fl.Chain)
ContextAdapter(adaptee).inject(chain)
assert chain() == 42