use factories in context example

Using the same instance multiple times is a bad idea
because PyTorch memorizes things internally. Among
other things this breaks Chain's `__repr__`.
This commit is contained in:
Pierre Chapuis 2024-04-05 17:58:43 +02:00
parent bbb46e3fc7
commit e033306f60

View file

@ -53,22 +53,22 @@ Another use of the context is simplifying complex models, in particular those wi
To emulate this, let us consider this toy example with a structure somewhat similar to a U-Net:
```py
square = fl.Lambda(lambda x: x ** 2)
sqrt = fl.Lambda(lambda x: x ** 0.5)
square = lambda: fl.Lambda(lambda x: x ** 2)
sqrt = lambda: fl.Lambda(lambda x: x ** 0.5)
m1 = fl.Chain(
fl.Residual(
square,
square(),
fl.Residual(
square,
square(),
fl.Residual(
square,
square(),
),
sqrt,
sqrt(),
),
sqrt,
sqrt(),
),
sqrt,
sqrt(),
)
```
@ -86,7 +86,12 @@ class MyModel(fl.Chain):
def init_context(self) -> Contexts:
return {"mymodel": {"residuals": []}}
push_residual = fl.SetContext("mymodel", "residuals", callback=lambda l, x: l.append(x))
def push_residual():
return fl.SetContext(
"mymodel",
"residuals",
callback=lambda l, x: l.append(x),
)
class ApplyResidual(fl.Sum):
def __init__(self):
@ -95,8 +100,8 @@ class ApplyResidual(fl.Sum):
fl.UseContext("mymodel", "residuals").compose(lambda x: x.pop()),
)
squares = fl.Chain(x for _ in range(3) for x in (push_residual, square))
sqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt))
squares = fl.Chain(x for _ in range(3) for x in (push_residual(), square()))
sqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt()))
m2 = MyModel(squares, sqrts)
```