mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 06:38:45 +00:00
use factories in context example
Using the same instance multiple times is a bad idea because PyTorch memorizes things internally. Among other things this breaks Chain's `__repr__`.
This commit is contained in:
parent
bbb46e3fc7
commit
e033306f60
|
@ -53,22 +53,22 @@ Another use of the context is simplifying complex models, in particular those wi
|
||||||
To emulate this, let us consider this toy example with a structure somewhat similar to a U-Net:
|
To emulate this, let us consider this toy example with a structure somewhat similar to a U-Net:
|
||||||
|
|
||||||
```py
|
```py
|
||||||
square = fl.Lambda(lambda x: x ** 2)
|
square = lambda: fl.Lambda(lambda x: x ** 2)
|
||||||
sqrt = fl.Lambda(lambda x: x ** 0.5)
|
sqrt = lambda: fl.Lambda(lambda x: x ** 0.5)
|
||||||
|
|
||||||
m1 = fl.Chain(
|
m1 = fl.Chain(
|
||||||
fl.Residual(
|
fl.Residual(
|
||||||
square,
|
square(),
|
||||||
fl.Residual(
|
fl.Residual(
|
||||||
square,
|
square(),
|
||||||
fl.Residual(
|
fl.Residual(
|
||||||
square,
|
square(),
|
||||||
),
|
),
|
||||||
sqrt,
|
sqrt(),
|
||||||
),
|
),
|
||||||
sqrt,
|
sqrt(),
|
||||||
),
|
),
|
||||||
sqrt,
|
sqrt(),
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -86,7 +86,12 @@ class MyModel(fl.Chain):
|
||||||
def init_context(self) -> Contexts:
|
def init_context(self) -> Contexts:
|
||||||
return {"mymodel": {"residuals": []}}
|
return {"mymodel": {"residuals": []}}
|
||||||
|
|
||||||
push_residual = fl.SetContext("mymodel", "residuals", callback=lambda l, x: l.append(x))
|
def push_residual():
|
||||||
|
return fl.SetContext(
|
||||||
|
"mymodel",
|
||||||
|
"residuals",
|
||||||
|
callback=lambda l, x: l.append(x),
|
||||||
|
)
|
||||||
|
|
||||||
class ApplyResidual(fl.Sum):
|
class ApplyResidual(fl.Sum):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -95,8 +100,8 @@ class ApplyResidual(fl.Sum):
|
||||||
fl.UseContext("mymodel", "residuals").compose(lambda x: x.pop()),
|
fl.UseContext("mymodel", "residuals").compose(lambda x: x.pop()),
|
||||||
)
|
)
|
||||||
|
|
||||||
squares = fl.Chain(x for _ in range(3) for x in (push_residual, square))
|
squares = fl.Chain(x for _ in range(3) for x in (push_residual(), square()))
|
||||||
sqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt))
|
sqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt()))
|
||||||
m2 = MyModel(squares, sqrts)
|
m2 = MyModel(squares, sqrts)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue