use factories in context example

Using the same instance multiple times is a bad idea
because PyTorch memorizes things internally. Among
other things this breaks Chain's `__repr__`.
This commit is contained in:
Pierre Chapuis 2024-04-05 17:58:43 +02:00
parent bbb46e3fc7
commit e033306f60

View file

@ -53,22 +53,22 @@ Another use of the context is simplifying complex models, in particular those wi
To emulate this, let us consider this toy example with a structure somewhat similar to a U-Net: To emulate this, let us consider this toy example with a structure somewhat similar to a U-Net:
```py ```py
square = fl.Lambda(lambda x: x ** 2) square = lambda: fl.Lambda(lambda x: x ** 2)
sqrt = fl.Lambda(lambda x: x ** 0.5) sqrt = lambda: fl.Lambda(lambda x: x ** 0.5)
m1 = fl.Chain( m1 = fl.Chain(
fl.Residual( fl.Residual(
square, square(),
fl.Residual( fl.Residual(
square, square(),
fl.Residual( fl.Residual(
square, square(),
), ),
sqrt, sqrt(),
), ),
sqrt, sqrt(),
), ),
sqrt, sqrt(),
) )
``` ```
@ -86,7 +86,12 @@ class MyModel(fl.Chain):
def init_context(self) -> Contexts: def init_context(self) -> Contexts:
return {"mymodel": {"residuals": []}} return {"mymodel": {"residuals": []}}
push_residual = fl.SetContext("mymodel", "residuals", callback=lambda l, x: l.append(x)) def push_residual():
return fl.SetContext(
"mymodel",
"residuals",
callback=lambda l, x: l.append(x),
)
class ApplyResidual(fl.Sum): class ApplyResidual(fl.Sum):
def __init__(self): def __init__(self):
@ -95,8 +100,8 @@ class ApplyResidual(fl.Sum):
fl.UseContext("mymodel", "residuals").compose(lambda x: x.pop()), fl.UseContext("mymodel", "residuals").compose(lambda x: x.pop()),
) )
squares = fl.Chain(x for _ in range(3) for x in (push_residual, square)) squares = fl.Chain(x for _ in range(3) for x in (push_residual(), square()))
sqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt)) sqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt()))
m2 = MyModel(squares, sqrts) m2 = MyModel(squares, sqrts)
``` ```