From e033306f601c83a4fbbd42f84d2ead8d314f90b9 Mon Sep 17 00:00:00 2001 From: Pierre Chapuis Date: Fri, 5 Apr 2024 17:58:43 +0200 Subject: [PATCH] use factories in context example Using the same instance multiple times is a bad idea because PyTorch memorizes things internally. Among other things this breaks Chain's `__repr__`. --- docs/concepts/context.md | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/docs/concepts/context.md b/docs/concepts/context.md index 5a57531..4e0270d 100644 --- a/docs/concepts/context.md +++ b/docs/concepts/context.md @@ -53,22 +53,22 @@ Another use of the context is simplifying complex models, in particular those wi To emulate this, let us consider this toy example with a structure somewhat similar to a U-Net: ```py -square = fl.Lambda(lambda x: x ** 2) -sqrt = fl.Lambda(lambda x: x ** 0.5) +square = lambda: fl.Lambda(lambda x: x ** 2) +sqrt = lambda: fl.Lambda(lambda x: x ** 0.5) m1 = fl.Chain( fl.Residual( - square, + square(), fl.Residual( - square, + square(), fl.Residual( - square, + square(), ), - sqrt, + sqrt(), ), - sqrt, + sqrt(), ), - sqrt, + sqrt(), ) ``` @@ -86,7 +86,12 @@ class MyModel(fl.Chain): def init_context(self) -> Contexts: return {"mymodel": {"residuals": []}} -push_residual = fl.SetContext("mymodel", "residuals", callback=lambda l, x: l.append(x)) +def push_residual(): + return fl.SetContext( + "mymodel", + "residuals", + callback=lambda l, x: l.append(x), + ) class ApplyResidual(fl.Sum): def __init__(self): @@ -95,8 +100,8 @@ class ApplyResidual(fl.Sum): fl.UseContext("mymodel", "residuals").compose(lambda x: x.pop()), ) -squares = fl.Chain(x for _ in range(3) for x in (push_residual, square)) -sqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt)) +squares = fl.Chain(x for _ in range(3) for x in (push_residual(), square())) +sqrts = fl.Chain(x for _ in range(3) for x in (ApplyResidual(), sqrt())) m2 = MyModel(squares, sqrts) ```