<!--
refiners/guides/training_101/index.html

2299 lines
210 KiB
HTML
Raw Permalink Normal View History
-->

<!doctype html>
<html lang="en" class="no-js">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<meta name="description" content="A micro framework on top of PyTorch with first class citizen APIs for foundation model adaptation">
<link rel="prev" href="../adapting_sdxl/">
<link rel="next" href="../comfyui_refiners/">
<link rel="icon" href="../../assets/favicon.svg">
<meta name="generator" content="mkdocs-1.6.1, mkdocs-material-9.5.45">
<title>Training 101 - Refiners</title>
<link rel="stylesheet" href="../../assets/stylesheets/main.0253249f.min.css">
<link rel="stylesheet" href="../../assets/stylesheets/palette.06af60db.min.css">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300i,400,400i,700,700i%7CRoboto+Mono:400,400i,700,700i&display=fallback">
<style>:root{--md-text-font:"Roboto";--md-code-font:"Roboto Mono"}</style>
<link rel="stylesheet" href="../../assets/_mkdocstrings.css">
<link rel="stylesheet" href="../../stylesheets/extra.css">
<script>
  /* Global helpers shared with the theme bundle. Values are kept in
     localStorage under keys prefixed with the site scope's pathname. */

  // Site root used to namespace storage keys.
  __md_scope = new URL("../..", location);

  // Rolling string hash (h * 31 + charCode, with `<<` truncating to 32 bits);
  // compared against a stored value to decide whether to hide the banner.
  __md_hash = e => [...e].reduce((e, _) => (e << 5) - e + _.charCodeAt(0), 0);

  // Read and JSON-decode a stored value. Yields null when the key is missing,
  // when the stored value is not valid JSON, or when storage is inaccessible
  // (merely touching localStorage can throw a SecurityError when the browser
  // blocks site data, so the default storage is resolved inside the try).
  __md_get = (e, _, t = __md_scope) => {
    try {
      const storage = _ === undefined ? localStorage : _;
      return JSON.parse(storage.getItem(t.pathname + "." + e));
    } catch (err) {
      return null;
    }
  };

  // JSON-encode and store a value. Persisting is best-effort: quota errors
  // and blocked storage (including resolving the default storage) are ignored.
  __md_set = (e, _, t, a = __md_scope) => {
    try {
      const storage = t === undefined ? localStorage : t;
      storage.setItem(a.pathname + "." + e, JSON.stringify(_));
    } catch (err) {}
  };
</script>
</head>
<body dir="ltr" data-md-color-scheme="default" data-md-color-primary="deep-orange" data-md-color-accent="deep-orange">
<input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off">
<input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off">
<label class="md-overlay" for="__drawer"></label>
<div data-md-component="skip">
<a href="#training-101" class="md-skip">
Skip to content
</a>
</div>
<div data-md-component="announce">
<aside class="md-banner">
<div class="md-banner__inner md-grid md-typeset">
<button class="md-banner__button md-icon" aria-label="Don't show this again">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z"/></svg>
</button>
Check out our <a href="https://finegrain.ai/bounties">Bounty Program</a> 💰!
</div>
<script>
  /* Hide the announcement banner when the hash of its current content
     matches the hash stored under "__announce" (presumably written by the
     theme bundle when "Don't show this again" is clicked). */
  var el = document.querySelector("[data-md-component=announce]");
  if (el) {
    var content = el.querySelector(".md-typeset");
    if (__md_hash(content.innerHTML) === __md_get("__announce")) {
      el.hidden = true;
    }
  }
</script>
</aside>
</div>
<header class="md-header md-header--shadow md-header--lifted" data-md-component="header">
<nav class="md-header__inner md-grid" aria-label="Header">
<a href="../.." title="Refiners" class="md-header__button md-logo" aria-label="Refiners" data-md-component="logo">
<img src="../../assets/favicon.svg" alt="logo">
</a>
<label class="md-header__button md-icon" for="__drawer">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3zm0 5h18v2H3zm0 5h18v2H3z"/></svg>
</label>
<div class="md-header__title" data-md-component="header-title">
<div class="md-header__ellipsis">
<div class="md-header__topic">
<span class="md-ellipsis">
Refiners
</span>
</div>
<div class="md-header__topic" data-md-component="header-topic">
<span class="md-ellipsis">
Training 101
</span>
</div>
</div>
</div>
<label class="md-header__button md-icon" for="__search">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
</label>
<div class="md-search" data-md-component="search" role="dialog">
<label class="md-search__overlay" for="__search"></label>
<div class="md-search__inner" role="search">
<form class="md-search__form" name="search">
<input type="text" class="md-search__input" name="query" aria-label="Search" placeholder="Search" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" required>
<label class="md-search__icon md-icon" for="__search">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
</label>
<nav class="md-search__options" aria-label="Search">
<button type="reset" class="md-search__icon md-icon" title="Clear" aria-label="Clear" tabindex="-1">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z"/></svg>
</button>
</nav>
</form>
<div class="md-search__output">
<div class="md-search__scrollwrap" tabindex="0" data-md-scrollfix>
<div class="md-search-result" data-md-component="search-result">
<div class="md-search-result__meta">
Initializing search
</div>
<ol class="md-search-result__list" role="presentation"></ol>
</div>
</div>
</div>
</div>
</div>
<div class="md-header__source">
<a href="https://github.com/finegrain-ai/refiners" title="Go to repository" class="md-source" data-md-component="source">
<div class="md-source__icon md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 6.6.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M439.55 236.05 244 40.45a28.87 28.87 0 0 0-40.81 0l-40.66 40.63 51.52 51.52c27.06-9.14 52.68 16.77 43.39 43.68l49.66 49.66c34.23-11.8 61.18 31 35.47 56.69-26.49 26.49-70.21-2.87-56-37.34L240.22 199v121.85c25.3 12.54 22.26 41.85 9.08 55a34.34 34.34 0 0 1-48.55 0c-17.57-17.6-11.07-46.91 11.25-56v-123c-20.8-8.51-24.6-30.74-18.64-45L142.57 101 8.45 235.14a28.86 28.86 0 0 0 0 40.81l195.61 195.6a28.86 28.86 0 0 0 40.8 0l194.69-194.69a28.86 28.86 0 0 0 0-40.81"/></svg>
</div>
<div class="md-source__repository">
Refiners
</div>
</a>
</div>
</nav>
<nav class="md-tabs" aria-label="Tabs" data-md-component="tabs">
<div class="md-grid">
<ul class="md-tabs__list">
<li class="md-tabs__item">
<a href="../.." class="md-tabs__link">
Home
</a>
</li>
<li class="md-tabs__item">
<a href="../../getting-started/recommended/" class="md-tabs__link">
Getting started
</a>
</li>
<li class="md-tabs__item">
<a href="../../concepts/chain/" class="md-tabs__link">
Key Concepts
</a>
</li>
<li class="md-tabs__item md-tabs__item--active">
<a href="../adapting_sdxl/" class="md-tabs__link">
Guides
</a>
</li>
<li class="md-tabs__item">
<a href="../../reference/fluxion/adapters/" class="md-tabs__link">
API Reference
</a>
</li>
</ul>
</div>
</nav>
</header>
<div class="md-container" data-md-component="container">
<main class="md-main" data-md-component="main">
<div class="md-main__inner md-grid">
<div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" >
<div class="md-sidebar__scrollwrap">
<div class="md-sidebar__inner">
<nav class="md-nav md-nav--primary md-nav--lifted" aria-label="Navigation" data-md-level="0">
<label class="md-nav__title" for="__drawer">
<a href="../.." title="Refiners" class="md-nav__button md-logo" aria-label="Refiners" data-md-component="logo">
<img src="../../assets/favicon.svg" alt="logo">
</a>
Refiners
</label>
<div class="md-nav__source">
<a href="https://github.com/finegrain-ai/refiners" title="Go to repository" class="md-source" data-md-component="source">
<div class="md-source__icon md-icon">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 6.6.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M439.55 236.05 244 40.45a28.87 28.87 0 0 0-40.81 0l-40.66 40.63 51.52 51.52c27.06-9.14 52.68 16.77 43.39 43.68l49.66 49.66c34.23-11.8 61.18 31 35.47 56.69-26.49 26.49-70.21-2.87-56-37.34L240.22 199v121.85c25.3 12.54 22.26 41.85 9.08 55a34.34 34.34 0 0 1-48.55 0c-17.57-17.6-11.07-46.91 11.25-56v-123c-20.8-8.51-24.6-30.74-18.64-45L142.57 101 8.45 235.14a28.86 28.86 0 0 0 0 40.81l195.61 195.6a28.86 28.86 0 0 0 40.8 0l194.69-194.69a28.86 28.86 0 0 0 0-40.81"/></svg>
</div>
<div class="md-source__repository">
Refiners
</div>
</a>
</div>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_1" >
<label class="md-nav__link" for="__nav_1" id="__nav_1_label" tabindex="0">
<span class="md-ellipsis">
Home
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_1_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_1">
<span class="md-nav__icon md-icon"></span>
Home
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../.." class="md-nav__link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="m12 3.77-.75.84S9.97 6.06 8.68 7.94 6 12.07 6 14.23a6 6 0 0 0 6 6 6 6 0 0 0 6-6c0-2.16-1.39-4.41-2.68-6.29s-2.57-3.33-2.57-3.33zm0 3.13c.44.52.84.95 1.68 2.17 1.21 1.76 2.32 4 2.32 5.16 0 2.22-1.78 4-4 4s-4-1.78-4-4c0-1.16 1.11-3.4 2.32-5.16.84-1.22 1.24-1.65 1.68-2.17"/></svg>
<span class="md-ellipsis">
Welcome
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../home/why/" class="md-nav__link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11h3v2h-3zM1 11h3v2H1zM13 1v3h-2V1zM4.92 3.5l2.13 2.14-1.42 1.41L3.5 4.93zm12.03 2.13 2.12-2.13 1.43 1.43-2.13 2.12zM12 6a6 6 0 0 1 6 6c0 2.22-1.21 4.16-3 5.2V19a1 1 0 0 1-1 1h-4a1 1 0 0 1-1-1v-1.8c-1.79-1.04-3-2.98-3-5.2a6 6 0 0 1 6-6m2 15v1a1 1 0 0 1-1 1h-2a1 1 0 0 1-1-1v-1zm-3-3h2v-2.13c1.73-.44 3-2.01 3-3.87a4 4 0 0 0-4-4 4 4 0 0 0-4 4c0 1.86 1.27 3.43 3 3.87z"/></svg>
<span class="md-ellipsis">
Manifesto
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_2" >
<label class="md-nav__link" for="__nav_2" id="__nav_2_label" tabindex="0">
<span class="md-ellipsis">
Getting started
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_2_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_2">
<span class="md-nav__icon md-icon"></span>
Getting started
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../getting-started/recommended/" class="md-nav__link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="m12 15.39-3.76 2.27.99-4.28-3.32-2.88 4.38-.37L12 6.09l1.71 4.04 4.38.37-3.32 2.88.99 4.28M22 9.24l-7.19-.61L12 2 9.19 8.63 2 9.24l5.45 4.73L5.82 21 12 17.27 18.18 21l-1.64-7.03z"/></svg>
<span class="md-ellipsis">
Recommended usage
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../getting-started/advanced/" class="md-nav__link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9 1.09V6H7V1.09C4.16 1.57 2 4.03 2 7c0 2.22 1.21 4.15 3 5.19V21c0 .55.45 1 1 1h4c.55 0 1-.45 1-1v-8.81c1.79-1.04 3-2.97 3-5.19 0-2.97-2.16-5.43-5-5.91m1 9.37-1 .58V20H7v-8.96l-1-.58C4.77 9.74 4 8.42 4 7c0-1 .37-1.94 1-2.65V8h6V4.35c.63.71 1 1.65 1 2.65 0 1.42-.77 2.74-2 3.46m10.94 7.48a3.3 3.3 0 0 0 0-.89l.97-.73a.22.22 0 0 0 .06-.29l-.92-1.56c-.05-.1-.18-.14-.29-.1l-1.15.45c-.24-.17-.49-.32-.78-.44l-.17-1.19a.235.235 0 0 0-.23-.19h-1.85c-.12 0-.22.08-.24.19l-.17 1.19c-.29.12-.54.27-.78.44l-1.15-.45c-.1-.04-.24 0-.28.1l-.93 1.56c-.06.1-.03.22.06.29l.97.73c-.01.15-.03.3-.03.45s.02.29.03.44l-.97.74a.22.22 0 0 0-.06.29l.93 1.56c.04.1.18.13.28.1l1.15-.46c.24.18.49.33.78.45l.17 1.19c.02.11.12.19.24.19h1.85c.11 0 .21-.08.23-.19l.17-1.19c.29-.12.54-.27.78-.45l1.15.46c.11.03.24 0 .29-.1l.92-1.56a.22.22 0 0 0-.06-.29zM17.5 19c-.83 0-1.5-.67-1.5-1.5s.67-1.5 1.5-1.5 1.5.67 1.5 1.5-.67 1.5-1.5 1.5"/></svg>
<span class="md-ellipsis">
Advanced usage
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_3" >
<label class="md-nav__link" for="__nav_3" id="__nav_3_label" tabindex="0">
<span class="md-ellipsis">
Key Concepts
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_3_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_3">
<span class="md-nav__icon md-icon"></span>
Key Concepts
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../concepts/chain/" class="md-nav__link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 1a2.5 2.5 0 0 0-2.5 2.5A2.5 2.5 0 0 0 11 5.79V7H7a2 2 0 0 0-2 2v.71A2.5 2.5 0 0 0 3.5 12 2.5 2.5 0 0 0 5 14.29V15H4a2 2 0 0 0-2 2v1.21A2.5 2.5 0 0 0 .5 20.5 2.5 2.5 0 0 0 3 23a2.5 2.5 0 0 0 2.5-2.5A2.5 2.5 0 0 0 4 18.21V17h4v1.21a2.5 2.5 0 0 0-1.5 2.29A2.5 2.5 0 0 0 9 23a2.5 2.5 0 0 0 2.5-2.5 2.5 2.5 0 0 0-1.5-2.29V17a2 2 0 0 0-2-2H7v-.71A2.5 2.5 0 0 0 8.5 12 2.5 2.5 0 0 0 7 9.71V9h10v.71A2.5 2.5 0 0 0 15.5 12a2.5 2.5 0 0 0 1.5 2.29V15h-1a2 2 0 0 0-2 2v1.21a2.5 2.5 0 0 0-1.5 2.29A2.5 2.5 0 0 0 15 23a2.5 2.5 0 0 0 2.5-2.5 2.5 2.5 0 0 0-1.5-2.29V17h4v1.21a2.5 2.5 0 0 0-1.5 2.29A2.5 2.5 0 0 0 21 23a2.5 2.5 0 0 0 2.5-2.5 2.5 2.5 0 0 0-1.5-2.29V17a2 2 0 0 0-2-2h-1v-.71A2.5 2.5 0 0 0 20.5 12 2.5 2.5 0 0 0 19 9.71V9a2 2 0 0 0-2-2h-4V5.79a2.5 2.5 0 0 0 1.5-2.29A2.5 2.5 0 0 0 12 1m0 1.5a1 1 0 0 1 1 1 1 1 0 0 1-1 1 1 1 0 0 1-1-1 1 1 0 0 1 1-1M6 11a1 1 0 0 1 1 1 1 1 0 0 1-1 1 1 1 0 0 1-1-1 1 1 0 0 1 1-1m12 0a1 1 0 0 1 1 1 1 1 0 0 1-1 1 1 1 0 0 1-1-1 1 1 0 0 1 1-1M3 19.5a1 1 0 0 1 1 1 1 1 0 0 1-1 1 1 1 0 0 1-1-1 1 1 0 0 1 1-1m6 0a1 1 0 0 1 1 1 1 1 0 0 1-1 1 1 1 0 0 1-1-1 1 1 0 0 1 1-1m6 0a1 1 0 0 1 1 1 1 1 0 0 1-1 1 1 1 0 0 1-1-1 1 1 0 0 1 1-1m6 0a1 1 0 0 1 1 1 1 1 0 0 1-1 1 1 1 0 0 1-1-1 1 1 0 0 1 1-1"/></svg>
<span class="md-ellipsis">
Chain
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../concepts/context/" class="md-nav__link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9 22a1 1 0 0 1-1-1v-3H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h16a2 2 0 0 1 2 2v12a2 2 0 0 1-2 2h-6.1l-3.7 3.71c-.2.19-.45.29-.7.29zm1-6v3.08L13.08 16H20V4H4v12zm3-6h-2V6h2zm0 4h-2v-2h2z"/></svg>
<span class="md-ellipsis">
Context
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../concepts/adapter/" class="md-nav__link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M2 12h2v5h16v-5h2v5a2 2 0 0 1-2 2H4a2 2 0 0 1-2-2m9-12h2v3h3v2h-3v3h-2v-3H8V8h3Z"/></svg>
<span class="md-ellipsis">
Adapter
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--active md-nav__item--section md-nav__item--nested">
<input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_4" checked>
<label class="md-nav__link" for="__nav_4" id="__nav_4_label" tabindex="">
<span class="md-ellipsis">
Guides
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_4_label" aria-expanded="true">
<label class="md-nav__title" for="__nav_4">
<span class="md-nav__icon md-icon"></span>
Guides
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../adapting_sdxl/" class="md-nav__link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M2 13h2v2h2v-2h2v2h2v-2h2v2h2v-5l3-3V1h2l4 2-4 2v2l3 3v12H11v-3a2 2 0 0 0-2-2 2 2 0 0 0-2 2v3H2zm16-3c-.55 0-1 .54-1 1.2V13h2v-1.8c0-.66-.45-1.2-1-1.2"/></svg>
<span class="md-ellipsis">
Adapting SDXL
</span>
</a>
</li>
<li class="md-nav__item md-nav__item--active">
<input class="md-nav__toggle md-toggle" type="checkbox" id="__toc">
<label class="md-nav__link md-nav__link--active" for="__toc">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M18 22a2 2 0 0 0 2-2V4a2 2 0 0 0-2-2h-6v7L9.5 7.5 7 9V2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2z"/></svg>
<span class="md-ellipsis">
Training 101
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<a href="./" class="md-nav__link md-nav__link--active">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M18 22a2 2 0 0 0 2-2V4a2 2 0 0 0-2-2h-6v7L9.5 7.5 7 9V2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2z"/></svg>
<span class="md-ellipsis">
Training 101
</span>
</a>
<nav class="md-nav md-nav--secondary" aria-label="Table of contents">
<label class="md-nav__title" for="__toc">
<span class="md-nav__icon md-icon"></span>
Table of contents
</label>
<ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
<li class="md-nav__item">
<a href="#pre-requisites" class="md-nav__link">
<span class="md-ellipsis">
Pre-requisites
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#model" class="md-nav__link">
<span class="md-ellipsis">
Model
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#dataset" class="md-nav__link">
<span class="md-ellipsis">
Dataset
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#trainer" class="md-nav__link">
<span class="md-ellipsis">
Trainer
</span>
</a>
<nav class="md-nav" aria-label="Trainer">
<ul class="md-nav__list">
<li class="md-nav__item">
<a href="#batch" class="md-nav__link">
<span class="md-ellipsis">
Batch
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#config" class="md-nav__link">
<span class="md-ellipsis">
Config
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#subclass" class="md-nav__link">
<span class="md-ellipsis">
Subclass
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#model-registration" class="md-nav__link">
<span class="md-ellipsis">
Model registration
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item">
<a href="#logging" class="md-nav__link">
<span class="md-ellipsis">
Logging
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#evaluation" class="md-nav__link">
<span class="md-ellipsis">
Evaluation
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#wrap-up" class="md-nav__link">
<span class="md-ellipsis">
Wrap up
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item">
<a href="../comfyui_refiners/" class="md-nav__link">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M15 3v4.59L7.59 15H3v6h6v-4.58L16.42 9H21V3m-4 2h2v2h-2M5 17h2v2H5"/></svg>
<span class="md-ellipsis">
ComfyUI Refiners
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_5" >
<label class="md-nav__link" for="__nav_5" id="__nav_5_label" tabindex="0">
<span class="md-ellipsis">
API Reference
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="1" aria-labelledby="__nav_5_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_5">
<span class="md-nav__icon md-icon"></span>
API Reference
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_5_1" >
<label class="md-nav__link" for="__nav_5_1" id="__nav_5_1_label" tabindex="0">
<span class="md-ellipsis">
Refiners
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="2" aria-labelledby="__nav_5_1_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_5_1">
<span class="md-nav__icon md-icon"></span>
Refiners
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_5_1_1" >
<label class="md-nav__link" for="__nav_5_1_1" id="__nav_5_1_1_label" tabindex="0">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Fluxion
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_5_1_1_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_5_1_1">
<span class="md-nav__icon md-icon"></span>
<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Fluxion
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../reference/fluxion/adapters/" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Adapters
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../reference/fluxion/layers/" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Layers
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../reference/fluxion/context/" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Context
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../reference/fluxion/utils/" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Utils
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item md-nav__item--nested">
<input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_5_1_2" >
<label class="md-nav__link" for="__nav_5_1_2" id="__nav_5_1_2_label" tabindex="0">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Foundation Models
</span>
<span class="md-nav__icon md-icon"></span>
</label>
<nav class="md-nav" data-md-level="3" aria-labelledby="__nav_5_1_2_label" aria-expanded="false">
<label class="md-nav__title" for="__nav_5_1_2">
<span class="md-nav__icon md-icon"></span>
<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Foundation Models
</label>
<ul class="md-nav__list" data-md-scrollfix>
<li class="md-nav__item">
<a href="../../reference/foundationals/clip/" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> CLIP
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../reference/foundationals/dinov2/" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> DINOv2
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../reference/foundationals/latent_diffusion/" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Latent Diffusion
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../reference/foundationals/segment_anything/" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Segment Anything
</span>
</a>
</li>
<li class="md-nav__item">
<a href="../../reference/foundationals/swin/" class="md-nav__link">
<span class="md-ellipsis">
<code class="doc-symbol doc-symbol-nav doc-symbol-module"></code> Swin Transformers
</span>
</a>
</li>
</ul>
</nav>
</li>
</ul>
</nav>
</li>
</ul>
</nav>
</li>
</ul>
</nav>
</div>
</div>
</div>
<div class="md-sidebar md-sidebar--secondary" data-md-component="sidebar" data-md-type="toc" >
<div class="md-sidebar__scrollwrap">
<div class="md-sidebar__inner">
<nav class="md-nav md-nav--secondary" aria-label="Table of contents">
<label class="md-nav__title" for="__toc">
<span class="md-nav__icon md-icon"></span>
Table of contents
</label>
<ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
<li class="md-nav__item">
<a href="#pre-requisites" class="md-nav__link">
<span class="md-ellipsis">
Pre-requisites
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#model" class="md-nav__link">
<span class="md-ellipsis">
Model
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#dataset" class="md-nav__link">
<span class="md-ellipsis">
Dataset
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#trainer" class="md-nav__link">
<span class="md-ellipsis">
Trainer
</span>
</a>
<nav class="md-nav" aria-label="Trainer">
<ul class="md-nav__list">
<li class="md-nav__item">
<a href="#batch" class="md-nav__link">
<span class="md-ellipsis">
Batch
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#config" class="md-nav__link">
<span class="md-ellipsis">
Config
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#subclass" class="md-nav__link">
<span class="md-ellipsis">
Subclass
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#model-registration" class="md-nav__link">
<span class="md-ellipsis">
Model registration
</span>
</a>
</li>
</ul>
</nav>
</li>
<li class="md-nav__item">
<a href="#logging" class="md-nav__link">
<span class="md-ellipsis">
Logging
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#evaluation" class="md-nav__link">
<span class="md-ellipsis">
Evaluation
</span>
</a>
</li>
<li class="md-nav__item">
<a href="#wrap-up" class="md-nav__link">
<span class="md-ellipsis">
Wrap up
</span>
</a>
</li>
</ul>
</nav>
</div>
</div>
</div>
<div class="md-content" data-md-component="content">
<article class="md-content__inner md-typeset">
<h1 id="training-101">Training 101<a class="headerlink" href="#training-101" title="Permanent link">&para;</a></h1>
<p>This guide will walk you through training a model using Refiners. We built the <code>training_utils</code> module to provide a simple, flexible, statically type-safe interface.</p>
<p>For demonstration purposes, we will use a small model and a toy dataset: the model is a simple <a href="https://en.wikipedia.org/wiki/Autoencoder">autoencoder</a>, and the dataset is a synthetic set of rectangles
of different sizes.</p>
<h2 id="pre-requisites">Pre-requisites<a class="headerlink" href="#pre-requisites" title="Permanent link">&para;</a></h2>
<p>We recommend pinning Refiners to a specific commit hash to avoid unexpected API changes. This also
gives you a reproducible environment.</p>
<ul>
<li>
<p>with rye (recommended):
<div class="language-bash highlight"><pre><span></span><code><span id="__span-0-1"><a id="__codelineno-0-1" name="__codelineno-0-1" href="#__codelineno-0-1"></a>rye<span class="w"> </span>add<span class="w"> </span><span class="s2">&quot;refiners[training]&quot;</span><span class="w"> </span>--git<span class="o">=</span>https://github.com/finegrain-ai/refiners.git<span class="w"> </span>--rev<span class="o">=</span>&lt;insert-latest-commit-hash&gt;
</span></code></pre></div></p>
</li>
<li>
<p>with pip:
<div class="language-bash highlight"><pre><span></span><code><span id="__span-1-1"><a id="__codelineno-1-1" name="__codelineno-1-1" href="#__codelineno-1-1"></a>pip<span class="w"> </span>install<span class="w"> </span><span class="s2">&quot;refiners[training] @ git+https://github.com/finegrain-ai/refiners.git@&lt;insert-latest-commit-hash&gt;&quot;</span>
</span></code></pre></div></p>
</li>
</ul>
<h2 id="model">Model<a class="headerlink" href="#model" title="Permanent link">&para;</a></h2>
<p>Let's start by building our autoencoder using Refiners.</p>
<details class="autoencoder">
<summary>Expand to see the autoencoder model.</summary>
<div class="language-py highlight"><pre><span></span><code><span id="__span-2-1"><a id="__codelineno-2-1" name="__codelineno-2-1" href="#__codelineno-2-1"></a><span class="kn">from</span> <span class="nn">refiners.fluxion</span> <span class="kn">import</span> <span class="n">layers</span> <span class="k">as</span> <span class="n">fl</span>
</span><span id="__span-2-2"><a id="__codelineno-2-2" name="__codelineno-2-2" href="#__codelineno-2-2"></a>
</span><span id="__span-2-3"><a id="__codelineno-2-3" name="__codelineno-2-3" href="#__codelineno-2-3"></a>
</span><span id="__span-2-4"><a id="__codelineno-2-4" name="__codelineno-2-4" href="#__codelineno-2-4"></a><span class="k">class</span> <span class="nc">ConvBlock</span><span class="p">(</span><span class="n">fl</span><span class="o">.</span><span class="n">Chain</span><span class="p">):</span>
</span><span id="__span-2-5"><a id="__codelineno-2-5" name="__codelineno-2-5" href="#__codelineno-2-5"></a>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-2-6"><a id="__codelineno-2-6" name="__codelineno-2-6" href="#__codelineno-2-6"></a>        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
</span><span id="__span-2-7"><a id="__codelineno-2-7" name="__codelineno-2-7" href="#__codelineno-2-7"></a>            <span class="n">fl</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span>
</span><span id="__span-2-8"><a id="__codelineno-2-8" name="__codelineno-2-8" href="#__codelineno-2-8"></a>                <span class="n">in_channels</span><span class="o">=</span><span class="n">in_channels</span><span class="p">,</span>
</span><span id="__span-2-9"><a id="__codelineno-2-9" name="__codelineno-2-9" href="#__codelineno-2-9"></a>                <span class="n">out_channels</span><span class="o">=</span><span class="n">out_channels</span><span class="p">,</span>
</span><span id="__span-2-10"><a id="__codelineno-2-10" name="__codelineno-2-10" href="#__codelineno-2-10"></a>                <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
</span><span id="__span-2-11"><a id="__codelineno-2-11" name="__codelineno-2-11" href="#__codelineno-2-11"></a>                <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
</span><span id="__span-2-12"><a id="__codelineno-2-12" name="__codelineno-2-12" href="#__codelineno-2-12"></a>                <span class="n">groups</span><span class="o">=</span><span class="nb">min</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">)</span>
</span><span id="__span-2-13"><a id="__codelineno-2-13" name="__codelineno-2-13" href="#__codelineno-2-13"></a>            <span class="p">),</span>
</span><span id="__span-2-14"><a id="__codelineno-2-14" name="__codelineno-2-14" href="#__codelineno-2-14"></a>            <span class="n">fl</span><span class="o">.</span><span class="n">LayerNorm2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">),</span>
</span><span id="__span-2-15"><a id="__codelineno-2-15" name="__codelineno-2-15" href="#__codelineno-2-15"></a>            <span class="n">fl</span><span class="o">.</span><span class="n">SiLU</span><span class="p">(),</span>
</span><span id="__span-2-16"><a id="__codelineno-2-16" name="__codelineno-2-16" href="#__codelineno-2-16"></a>            <span class="n">fl</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span>
</span><span id="__span-2-17"><a id="__codelineno-2-17" name="__codelineno-2-17" href="#__codelineno-2-17"></a>                <span class="n">in_channels</span><span class="o">=</span><span class="n">out_channels</span><span class="p">,</span>
</span><span id="__span-2-18"><a id="__codelineno-2-18" name="__codelineno-2-18" href="#__codelineno-2-18"></a>                <span class="n">out_channels</span><span class="o">=</span><span class="n">out_channels</span><span class="p">,</span>
</span><span id="__span-2-19"><a id="__codelineno-2-19" name="__codelineno-2-19" href="#__codelineno-2-19"></a>                <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
</span><span id="__span-2-20"><a id="__codelineno-2-20" name="__codelineno-2-20" href="#__codelineno-2-20"></a>                <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
</span><span id="__span-2-21"><a id="__codelineno-2-21" name="__codelineno-2-21" href="#__codelineno-2-21"></a>            <span class="p">),</span>
</span><span id="__span-2-22"><a id="__codelineno-2-22" name="__codelineno-2-22" href="#__codelineno-2-22"></a>            <span class="n">fl</span><span class="o">.</span><span class="n">LayerNorm2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">),</span>
</span><span id="__span-2-23"><a id="__codelineno-2-23" name="__codelineno-2-23" href="#__codelineno-2-23"></a>            <span class="n">fl</span><span class="o">.</span><span class="n">SiLU</span><span class="p">(),</span>
</span><span id="__span-2-24"><a id="__codelineno-2-24" name="__codelineno-2-24" href="#__codelineno-2-24"></a>        <span class="p">)</span>
</span><span id="__span-2-25"><a id="__codelineno-2-25" name="__codelineno-2-25" href="#__codelineno-2-25"></a>
</span><span id="__span-2-26"><a id="__codelineno-2-26" name="__codelineno-2-26" href="#__codelineno-2-26"></a>
</span><span id="__span-2-27"><a id="__codelineno-2-27" name="__codelineno-2-27" href="#__codelineno-2-27"></a><span class="k">class</span> <span class="nc">ResidualBlock</span><span class="p">(</span><span class="n">fl</span><span class="o">.</span><span class="n">Sum</span><span class="p">):</span>
</span><span id="__span-2-28"><a id="__codelineno-2-28" name="__codelineno-2-28" href="#__codelineno-2-28"></a> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-2-29"><a id="__codelineno-2-29" name="__codelineno-2-29" href="#__codelineno-2-29"></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
</span><span id="__span-2-30"><a id="__codelineno-2-30" name="__codelineno-2-30" href="#__codelineno-2-30"></a> <span class="n">ConvBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="n">out_channels</span><span class="p">),</span>
</span><span id="__span-2-31"><a id="__codelineno-2-31" name="__codelineno-2-31" href="#__codelineno-2-31"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span>
</span><span id="__span-2-32"><a id="__codelineno-2-32" name="__codelineno-2-32" href="#__codelineno-2-32"></a> <span class="n">in_channels</span><span class="o">=</span><span class="n">in_channels</span><span class="p">,</span>
</span><span id="__span-2-33"><a id="__codelineno-2-33" name="__codelineno-2-33" href="#__codelineno-2-33"></a> <span class="n">out_channels</span><span class="o">=</span><span class="n">out_channels</span><span class="p">,</span>
</span><span id="__span-2-34"><a id="__codelineno-2-34" name="__codelineno-2-34" href="#__codelineno-2-34"></a> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
</span><span id="__span-2-35"><a id="__codelineno-2-35" name="__codelineno-2-35" href="#__codelineno-2-35"></a> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
</span><span id="__span-2-36"><a id="__codelineno-2-36" name="__codelineno-2-36" href="#__codelineno-2-36"></a> <span class="p">),</span>
</span><span id="__span-2-37"><a id="__codelineno-2-37" name="__codelineno-2-37" href="#__codelineno-2-37"></a> <span class="p">)</span>
</span><span id="__span-2-38"><a id="__codelineno-2-38" name="__codelineno-2-38" href="#__codelineno-2-38"></a>
</span><span id="__span-2-39"><a id="__codelineno-2-39" name="__codelineno-2-39" href="#__codelineno-2-39"></a>
</span><span id="__span-2-40"><a id="__codelineno-2-40" name="__codelineno-2-40" href="#__codelineno-2-40"></a><span class="k">class</span> <span class="nc">Encoder</span><span class="p">(</span><span class="n">fl</span><span class="o">.</span><span class="n">Chain</span><span class="p">):</span>
</span><span id="__span-2-41"><a id="__codelineno-2-41" name="__codelineno-2-41" href="#__codelineno-2-41"></a> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-2-42"><a id="__codelineno-2-42" name="__codelineno-2-42" href="#__codelineno-2-42"></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
</span><span id="__span-2-43"><a id="__codelineno-2-43" name="__codelineno-2-43" href="#__codelineno-2-43"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
</span><span id="__span-2-44"><a id="__codelineno-2-44" name="__codelineno-2-44" href="#__codelineno-2-44"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Downsample</span><span class="p">(</span><span class="n">channels</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">scale_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">register_shape</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
</span><span id="__span-2-45"><a id="__codelineno-2-45" name="__codelineno-2-45" href="#__codelineno-2-45"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">16</span><span class="p">),</span>
</span><span id="__span-2-46"><a id="__codelineno-2-46" name="__codelineno-2-46" href="#__codelineno-2-46"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Downsample</span><span class="p">(</span><span class="n">channels</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">scale_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">register_shape</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
</span><span id="__span-2-47"><a id="__codelineno-2-47" name="__codelineno-2-47" href="#__codelineno-2-47"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">),</span>
</span><span id="__span-2-48"><a id="__codelineno-2-48" name="__codelineno-2-48" href="#__codelineno-2-48"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Downsample</span><span class="p">(</span><span class="n">channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">scale_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">register_shape</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
</span><span id="__span-2-49"><a id="__codelineno-2-49" name="__codelineno-2-49" href="#__codelineno-2-49"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Reshape</span><span class="p">(</span><span class="mi">2048</span><span class="p">),</span>
</span><span id="__span-2-50"><a id="__codelineno-2-50" name="__codelineno-2-50" href="#__codelineno-2-50"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">256</span><span class="p">),</span>
</span><span id="__span-2-51"><a id="__codelineno-2-51" name="__codelineno-2-51" href="#__codelineno-2-51"></a> <span class="n">fl</span><span class="o">.</span><span class="n">SiLU</span><span class="p">(),</span>
</span><span id="__span-2-52"><a id="__codelineno-2-52" name="__codelineno-2-52" href="#__codelineno-2-52"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">256</span><span class="p">),</span>
</span><span id="__span-2-53"><a id="__codelineno-2-53" name="__codelineno-2-53" href="#__codelineno-2-53"></a> <span class="p">)</span>
</span><span id="__span-2-54"><a id="__codelineno-2-54" name="__codelineno-2-54" href="#__codelineno-2-54"></a>
</span><span id="__span-2-55"><a id="__codelineno-2-55" name="__codelineno-2-55" href="#__codelineno-2-55"></a>
</span><span id="__span-2-56"><a id="__codelineno-2-56" name="__codelineno-2-56" href="#__codelineno-2-56"></a><span class="k">class</span> <span class="nc">Decoder</span><span class="p">(</span><span class="n">fl</span><span class="o">.</span><span class="n">Chain</span><span class="p">):</span>
</span><span id="__span-2-57"><a id="__codelineno-2-57" name="__codelineno-2-57" href="#__codelineno-2-57"></a> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-2-58"><a id="__codelineno-2-58" name="__codelineno-2-58" href="#__codelineno-2-58"></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
</span><span id="__span-2-59"><a id="__codelineno-2-59" name="__codelineno-2-59" href="#__codelineno-2-59"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">256</span><span class="p">),</span>
</span><span id="__span-2-60"><a id="__codelineno-2-60" name="__codelineno-2-60" href="#__codelineno-2-60"></a> <span class="n">fl</span><span class="o">.</span><span class="n">SiLU</span><span class="p">(),</span>
</span><span id="__span-2-61"><a id="__codelineno-2-61" name="__codelineno-2-61" href="#__codelineno-2-61"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">2048</span><span class="p">),</span>
</span><span id="__span-2-62"><a id="__codelineno-2-62" name="__codelineno-2-62" href="#__codelineno-2-62"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Reshape</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">8</span><span class="p">),</span>
</span><span id="__span-2-63"><a id="__codelineno-2-63" name="__codelineno-2-63" href="#__codelineno-2-63"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">),</span>
</span><span id="__span-2-64"><a id="__codelineno-2-64" name="__codelineno-2-64" href="#__codelineno-2-64"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">),</span>
</span><span id="__span-2-65"><a id="__codelineno-2-65" name="__codelineno-2-65" href="#__codelineno-2-65"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Upsample</span><span class="p">(</span><span class="n">channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">upsample_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
</span><span id="__span-2-66"><a id="__codelineno-2-66" name="__codelineno-2-66" href="#__codelineno-2-66"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">16</span><span class="p">),</span>
</span><span id="__span-2-67"><a id="__codelineno-2-67" name="__codelineno-2-67" href="#__codelineno-2-67"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">16</span><span class="p">),</span>
</span><span id="__span-2-68"><a id="__codelineno-2-68" name="__codelineno-2-68" href="#__codelineno-2-68"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Upsample</span><span class="p">(</span><span class="n">channels</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">upsample_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
</span><span id="__span-2-69"><a id="__codelineno-2-69" name="__codelineno-2-69" href="#__codelineno-2-69"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
</span><span id="__span-2-70"><a id="__codelineno-2-70" name="__codelineno-2-70" href="#__codelineno-2-70"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
</span><span id="__span-2-71"><a id="__codelineno-2-71" name="__codelineno-2-71" href="#__codelineno-2-71"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Upsample</span><span class="p">(</span><span class="n">channels</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">upsample_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
</span><span id="__span-2-72"><a id="__codelineno-2-72" name="__codelineno-2-72" href="#__codelineno-2-72"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
</span><span id="__span-2-73"><a id="__codelineno-2-73" name="__codelineno-2-73" href="#__codelineno-2-73"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
</span><span id="__span-2-74"><a id="__codelineno-2-74" name="__codelineno-2-74" href="#__codelineno-2-74"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Sigmoid</span><span class="p">(),</span>
</span><span id="__span-2-75"><a id="__codelineno-2-75" name="__codelineno-2-75" href="#__codelineno-2-75"></a> <span class="p">)</span>
</span><span id="__span-2-76"><a id="__codelineno-2-76" name="__codelineno-2-76" href="#__codelineno-2-76"></a>
</span><span id="__span-2-77"><a id="__codelineno-2-77" name="__codelineno-2-77" href="#__codelineno-2-77"></a>
</span><span id="__span-2-78"><a id="__codelineno-2-78" name="__codelineno-2-78" href="#__codelineno-2-78"></a><span class="k">class</span> <span class="nc">Autoencoder</span><span class="p">(</span><span class="n">fl</span><span class="o">.</span><span class="n">Chain</span><span class="p">):</span>
</span><span id="__span-2-79"><a id="__codelineno-2-79" name="__codelineno-2-79" href="#__codelineno-2-79"></a> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-2-80"><a id="__codelineno-2-80" name="__codelineno-2-80" href="#__codelineno-2-80"></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
</span><span id="__span-2-81"><a id="__codelineno-2-81" name="__codelineno-2-81" href="#__codelineno-2-81"></a> <span class="n">Encoder</span><span class="p">(),</span>
</span><span id="__span-2-82"><a id="__codelineno-2-82" name="__codelineno-2-82" href="#__codelineno-2-82"></a> <span class="n">Decoder</span><span class="p">(),</span>
</span><span id="__span-2-83"><a id="__codelineno-2-83" name="__codelineno-2-83" href="#__codelineno-2-83"></a> <span class="p">)</span>
</span><span id="__span-2-84"><a id="__codelineno-2-84" name="__codelineno-2-84" href="#__codelineno-2-84"></a>
</span><span id="__span-2-85"><a id="__codelineno-2-85" name="__codelineno-2-85" href="#__codelineno-2-85"></a> <span class="nd">@property</span>
</span><span id="__span-2-86"><a id="__codelineno-2-86" name="__codelineno-2-86" href="#__codelineno-2-86"></a> <span class="k">def</span> <span class="nf">encoder</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Encoder</span><span class="p">:</span>
</span><span id="__span-2-87"><a id="__codelineno-2-87" name="__codelineno-2-87" href="#__codelineno-2-87"></a> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">ensure_find</span><span class="p">(</span><span class="n">Encoder</span><span class="p">)</span>
</span><span id="__span-2-88"><a id="__codelineno-2-88" name="__codelineno-2-88" href="#__codelineno-2-88"></a>
</span><span id="__span-2-89"><a id="__codelineno-2-89" name="__codelineno-2-89" href="#__codelineno-2-89"></a> <span class="nd">@property</span>
</span><span id="__span-2-90"><a id="__codelineno-2-90" name="__codelineno-2-90" href="#__codelineno-2-90"></a> <span class="k">def</span> <span class="nf">decoder</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Decoder</span><span class="p">:</span>
</span><span id="__span-2-91"><a id="__codelineno-2-91" name="__codelineno-2-91" href="#__codelineno-2-91"></a> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">ensure_find</span><span class="p">(</span><span class="n">Decoder</span><span class="p">)</span>
</span></code></pre></div>
</details>
<p>We now have a fully functional autoencoder: the encoder takes a single-channel 64x64 image (4096 values) and
compresses it into a vector of size 256 (a 16x compression). The decoder then takes this vector and reconstructs the original image.</p>
<div class="language-py highlight"><pre><span></span><code><span id="__span-3-1"><a id="__codelineno-3-1" name="__codelineno-3-1" href="#__codelineno-3-1"></a><span class="kn">import</span> <span class="nn">torch</span>
</span><span id="__span-3-2"><a id="__codelineno-3-2" name="__codelineno-3-2" href="#__codelineno-3-2"></a>
</span><span id="__span-3-3"><a id="__codelineno-3-3" name="__codelineno-3-3" href="#__codelineno-3-3"></a><span class="n">autoencoder</span> <span class="o">=</span> <span class="n">Autoencoder</span><span class="p">()</span>
</span><span id="__span-3-4"><a id="__codelineno-3-4" name="__codelineno-3-4" href="#__codelineno-3-4"></a>
</span><span id="__span-3-5"><a id="__codelineno-3-5" name="__codelineno-3-5" href="#__codelineno-3-5"></a><span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">)</span> <span class="c1"># batch of 2 images</span>
</span><span id="__span-3-6"><a id="__codelineno-3-6" name="__codelineno-3-6" href="#__codelineno-3-6"></a>
</span><span id="__span-3-7"><a id="__codelineno-3-7" name="__codelineno-3-7" href="#__codelineno-3-7"></a><span class="n">z</span> <span class="o">=</span> <span class="n">autoencoder</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># [2, 256]</span>
</span><span id="__span-3-8"><a id="__codelineno-3-8" name="__codelineno-3-8" href="#__codelineno-3-8"></a>
</span><span id="__span-3-9"><a id="__codelineno-3-9" name="__codelineno-3-9" href="#__codelineno-3-9"></a><span class="n">x_reconstructed</span> <span class="o">=</span> <span class="n">autoencoder</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="c1"># [2, 1, 64, 64]</span>
</span></code></pre></div>
<h2 id="dataset">Dataset<a class="headerlink" href="#dataset" title="Permanent link">&para;</a></h2>
<p>We will use a synthetic dataset of binary masks, each containing a single rectangle of random size and position (the mask is
inverted half of the time). The masks are generated on the fly by this simple, infinite generator function:</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-4-1"><a id="__codelineno-4-1" name="__codelineno-4-1" href="#__codelineno-4-1"></a><span class="kn">import</span> <span class="nn">random</span>
</span><span id="__span-4-2"><a id="__codelineno-4-2" name="__codelineno-4-2" href="#__codelineno-4-2"></a><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Generator</span>
</span><span id="__span-4-3"><a id="__codelineno-4-3" name="__codelineno-4-3" href="#__codelineno-4-3"></a><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
</span><span id="__span-4-4"><a id="__codelineno-4-4" name="__codelineno-4-4" href="#__codelineno-4-4"></a>
</span><span id="__span-4-5"><a id="__codelineno-4-5" name="__codelineno-4-5" href="#__codelineno-4-5"></a><span class="kn">from</span> <span class="nn">refiners.fluxion.utils</span> <span class="kn">import</span> <span class="n">image_to_tensor</span>
</span><span id="__span-4-6"><a id="__codelineno-4-6" name="__codelineno-4-6" href="#__codelineno-4-6"></a>
</span><span id="__span-4-7"><a id="__codelineno-4-7" name="__codelineno-4-7" href="#__codelineno-4-7"></a><span class="k">def</span> <span class="nf">generate_mask</span><span class="p">(</span><span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">seed</span><span class="p">:</span> <span class="nb">int</span> <span class="o">|</span> <span class="kc">None</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Generator</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">]:</span>
</span><span id="__span-4-8"><a id="__codelineno-4-8" name="__codelineno-4-8" href="#__codelineno-4-8"></a><span class="w"> </span><span class="sd">&quot;&quot;&quot;Generate a tensor of a binary mask of size `size` using random rectangles.&quot;&quot;&quot;</span>
</span><span id="__span-4-9"><a id="__codelineno-4-9" name="__codelineno-4-9" href="#__codelineno-4-9"></a> <span class="k">if</span> <span class="n">seed</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-4-10"><a id="__codelineno-4-10" name="__codelineno-4-10" href="#__codelineno-4-10"></a> <span class="n">seed</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="o">**</span><span class="mi">32</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
</span><span id="__span-4-11"><a id="__codelineno-4-11" name="__codelineno-4-11" href="#__codelineno-4-11"></a> <span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
</span><span id="__span-4-12"><a id="__codelineno-4-12" name="__codelineno-4-12" href="#__codelineno-4-12"></a>
</span><span id="__span-4-13"><a id="__codelineno-4-13" name="__codelineno-4-13" href="#__codelineno-4-13"></a> <span class="k">while</span> <span class="kc">True</span><span class="p">:</span>
</span><span id="__span-4-14"><a id="__codelineno-4-14" name="__codelineno-4-14" href="#__codelineno-4-14"></a> <span class="n">rectangle</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">new</span><span class="p">(</span>
</span><span id="__span-4-15"><a id="__codelineno-4-15" name="__codelineno-4-15" href="#__codelineno-4-15"></a> <span class="s2">&quot;L&quot;</span><span class="p">,</span> <span class="p">(</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="p">),</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="p">)),</span> <span class="n">color</span><span class="o">=</span><span class="mi">255</span>
</span><span id="__span-4-16"><a id="__codelineno-4-16" name="__codelineno-4-16" href="#__codelineno-4-16"></a> <span class="p">)</span>
</span><span id="__span-4-17"><a id="__codelineno-4-17" name="__codelineno-4-17" href="#__codelineno-4-17"></a> <span class="n">mask</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="s2">&quot;L&quot;</span><span class="p">,</span> <span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">size</span><span class="p">))</span>
</span><span id="__span-4-18"><a id="__codelineno-4-18" name="__codelineno-4-18" href="#__codelineno-4-18"></a> <span class="n">mask</span><span class="o">.</span><span class="n">paste</span><span class="p">(</span>
</span><span id="__span-4-19"><a id="__codelineno-4-19" name="__codelineno-4-19" href="#__codelineno-4-19"></a> <span class="n">rectangle</span><span class="p">,</span>
</span><span id="__span-4-20"><a id="__codelineno-4-20" name="__codelineno-4-20" href="#__codelineno-4-20"></a> <span class="p">(</span>
</span><span id="__span-4-21"><a id="__codelineno-4-21" name="__codelineno-4-21" href="#__codelineno-4-21"></a> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">size</span> <span class="o">-</span> <span class="n">rectangle</span><span class="o">.</span><span class="n">width</span><span class="p">),</span>
</span><span id="__span-4-22"><a id="__codelineno-4-22" name="__codelineno-4-22" href="#__codelineno-4-22"></a> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">size</span> <span class="o">-</span> <span class="n">rectangle</span><span class="o">.</span><span class="n">height</span><span class="p">),</span>
</span><span id="__span-4-23"><a id="__codelineno-4-23" name="__codelineno-4-23" href="#__codelineno-4-23"></a> <span class="p">),</span>
</span><span id="__span-4-24"><a id="__codelineno-4-24" name="__codelineno-4-24" href="#__codelineno-4-24"></a> <span class="p">)</span>
</span><span id="__span-4-25"><a id="__codelineno-4-25" name="__codelineno-4-25" href="#__codelineno-4-25"></a> <span class="n">tensor</span> <span class="o">=</span> <span class="n">image_to_tensor</span><span class="p">(</span><span class="n">mask</span><span class="p">)</span>
</span><span id="__span-4-26"><a id="__codelineno-4-26" name="__codelineno-4-26" href="#__codelineno-4-26"></a>
</span><span id="__span-4-27"><a id="__codelineno-4-27" name="__codelineno-4-27" href="#__codelineno-4-27"></a> <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mf">0.5</span><span class="p">:</span>
</span><span id="__span-4-28"><a id="__codelineno-4-28" name="__codelineno-4-28" href="#__codelineno-4-28"></a> <span class="n">tensor</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">tensor</span>
</span><span id="__span-4-29"><a id="__codelineno-4-29" name="__codelineno-4-29" href="#__codelineno-4-29"></a>
</span><span id="__span-4-30"><a id="__codelineno-4-30" name="__codelineno-4-30" href="#__codelineno-4-30"></a> <span class="k">yield</span> <span class="n">tensor</span>
</span></code></pre></div>
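<p>Since <code>generate_mask</code> is an infinite generator, draw a mask from it with <code>next</code> and save it as an image:</p>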
<div class="language-python highlight"><pre><span></span><code><span id="__span-5-1"><a id="__codelineno-5-1" name="__codelineno-5-1" href="#__codelineno-5-1"></a><span class="kn">from</span> <span class="nn">refiners.fluxion.utils</span> <span class="kn">import</span> <span class="n">tensor_to_image</span>
</span><span id="__span-5-2"><a id="__codelineno-5-2" name="__codelineno-5-2" href="#__codelineno-5-2"></a>
</span><span id="__span-5-3"><a id="__codelineno-5-3" name="__codelineno-5-3" href="#__codelineno-5-3"></a><span class="n">mask</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">generate_mask</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">42</span><span class="p">))</span>
</span><span id="__span-5-4"><a id="__codelineno-5-4" name="__codelineno-5-4" href="#__codelineno-5-4"></a><span class="n">tensor_to_image</span><span class="p">(</span><span class="n">mask</span><span class="p">)</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">&quot;mask.png&quot;</span><span class="p">)</span>
</span></code></pre></div>
<p>Here are two examples of generated masks:
<img alt="First example of a generated random mask" src="sample-0.png" />
<img alt="Second example of a generated random mask" src="sample-1.png" /></p>
<h2 id="trainer">Trainer<a class="headerlink" href="#trainer" title="Permanent link">&para;</a></h2>
<p>We will now create a Trainer class to handle the training loop. This class will manage the model, the optimizer, the loss function, and the dataset. It will also orchestrate the training loop and the evaluation loop.</p>
<p>But first, we need to define two things: the batch type used to represent a batch during the forward and backward passes, and the configuration associated with the trainer.</p>
<h3 id="batch">Batch<a class="headerlink" href="#batch" title="Permanent link">&para;</a></h3>
<p>Our batches are composed of a single tensor representing the images. We will define a simple <code>Batch</code> type to implement this.</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-6-1"><a id="__codelineno-6-1" name="__codelineno-6-1" href="#__codelineno-6-1"></a><span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
</span><span id="__span-6-2"><a id="__codelineno-6-2" name="__codelineno-6-2" href="#__codelineno-6-2"></a>
</span><span id="__span-6-3"><a id="__codelineno-6-3" name="__codelineno-6-3" href="#__codelineno-6-3"></a><span class="nd">@dataclass</span>
</span><span id="__span-6-4"><a id="__codelineno-6-4" name="__codelineno-6-4" href="#__codelineno-6-4"></a><span class="k">class</span> <span class="nc">Batch</span><span class="p">:</span>
</span><span id="__span-6-5"><a id="__codelineno-6-5" name="__codelineno-6-5" href="#__codelineno-6-5"></a> <span class="n">image</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span>
</span></code></pre></div>
<h3 id="config">Config<a class="headerlink" href="#config" title="Permanent link">&para;</a></h3>
<p>We will now define the configuration for the autoencoder. It holds the configuration for the training loop, the optimizer, and the learning rate scheduler. It should inherit from <code>refiners.training_utils.BaseConfig</code> and have the following mandatory attributes:</p>
<ul>
<li><code>training</code> (a <code>TrainingConfig</code>): the configuration for the training loop, including the duration of the training, the batch size, device, dtype, etc.</li>
<li><code>optimizer</code> (an <code>OptimizerConfig</code>): the configuration for the optimizer, including the learning rate, weight decay, etc.</li>
<li><code>lr_scheduler</code> (an <code>LRSchedulerConfig</code>): the configuration for the learning rate scheduler, including the scheduler type, parameters, etc.</li>
</ul>
<p>Example:</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-7-1"><a id="__codelineno-7-1" name="__codelineno-7-1" href="#__codelineno-7-1"></a><span class="kn">from</span> <span class="nn">refiners.training_utils</span> <span class="kn">import</span> <span class="n">BaseConfig</span><span class="p">,</span> <span class="n">TrainingConfig</span><span class="p">,</span> <span class="n">OptimizerConfig</span><span class="p">,</span> <span class="n">LRSchedulerConfig</span><span class="p">,</span> <span class="n">Optimizers</span><span class="p">,</span> <span class="n">LRSchedulerType</span><span class="p">,</span> <span class="n">Epoch</span>
</span><span id="__span-7-2"><a id="__codelineno-7-2" name="__codelineno-7-2" href="#__codelineno-7-2"></a>
</span><span id="__span-7-3"><a id="__codelineno-7-3" name="__codelineno-7-3" href="#__codelineno-7-3"></a><span class="k">class</span> <span class="nc">AutoencoderConfig</span><span class="p">(</span><span class="n">BaseConfig</span><span class="p">):</span>
</span><span id="__span-7-4"><a id="__codelineno-7-4" name="__codelineno-7-4" href="#__codelineno-7-4"></a> <span class="o">...</span>
</span><span id="__span-7-5"><a id="__codelineno-7-5" name="__codelineno-7-5" href="#__codelineno-7-5"></a>
</span><span id="__span-7-6"><a id="__codelineno-7-6" name="__codelineno-7-6" href="#__codelineno-7-6"></a><span class="n">training</span> <span class="o">=</span> <span class="n">TrainingConfig</span><span class="p">(</span>
</span><span id="__span-7-7"><a id="__codelineno-7-7" name="__codelineno-7-7" href="#__codelineno-7-7"></a> <span class="n">duration</span><span class="o">=</span><span class="n">Epoch</span><span class="p">(</span><span class="mi">1000</span><span class="p">),</span>
</span><span id="__span-7-8"><a id="__codelineno-7-8" name="__codelineno-7-8" href="#__codelineno-7-8"></a> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s2">&quot;cpu&quot;</span><span class="p">,</span>
</span><span id="__span-7-9"><a id="__codelineno-7-9" name="__codelineno-7-9" href="#__codelineno-7-9"></a> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;float32&quot;</span>
</span><span id="__span-7-10"><a id="__codelineno-7-10" name="__codelineno-7-10" href="#__codelineno-7-10"></a><span class="p">)</span>
</span><span id="__span-7-11"><a id="__codelineno-7-11" name="__codelineno-7-11" href="#__codelineno-7-11"></a>
</span><span id="__span-7-12"><a id="__codelineno-7-12" name="__codelineno-7-12" href="#__codelineno-7-12"></a><span class="n">optimizer</span> <span class="o">=</span> <span class="n">OptimizerConfig</span><span class="p">(</span>
</span><span id="__span-7-13"><a id="__codelineno-7-13" name="__codelineno-7-13" href="#__codelineno-7-13"></a> <span class="n">optimizer</span><span class="o">=</span><span class="n">Optimizers</span><span class="o">.</span><span class="n">AdamW</span><span class="p">,</span>
</span><span id="__span-7-14"><a id="__codelineno-7-14" name="__codelineno-7-14" href="#__codelineno-7-14"></a> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">,</span>
</span><span id="__span-7-15"><a id="__codelineno-7-15" name="__codelineno-7-15" href="#__codelineno-7-15"></a><span class="p">)</span>
</span><span id="__span-7-16"><a id="__codelineno-7-16" name="__codelineno-7-16" href="#__codelineno-7-16"></a>
</span><span id="__span-7-17"><a id="__codelineno-7-17" name="__codelineno-7-17" href="#__codelineno-7-17"></a><span class="n">lr_scheduler</span> <span class="o">=</span> <span class="n">LRSchedulerConfig</span><span class="p">(</span>
</span><span id="__span-7-18"><a id="__codelineno-7-18" name="__codelineno-7-18" href="#__codelineno-7-18"></a> <span class="nb">type</span><span class="o">=</span><span class="n">LRSchedulerType</span><span class="o">.</span><span class="n">ConstantLR</span>
</span><span id="__span-7-19"><a id="__codelineno-7-19" name="__codelineno-7-19" href="#__codelineno-7-19"></a><span class="p">)</span>
</span><span id="__span-7-20"><a id="__codelineno-7-20" name="__codelineno-7-20" href="#__codelineno-7-20"></a>
</span><span id="__span-7-21"><a id="__codelineno-7-21" name="__codelineno-7-21" href="#__codelineno-7-21"></a><span class="n">config</span> <span class="o">=</span> <span class="n">AutoencoderConfig</span><span class="p">(</span>
</span><span id="__span-7-22"><a id="__codelineno-7-22" name="__codelineno-7-22" href="#__codelineno-7-22"></a> <span class="n">training</span><span class="o">=</span><span class="n">training</span><span class="p">,</span>
</span><span id="__span-7-23"><a id="__codelineno-7-23" name="__codelineno-7-23" href="#__codelineno-7-23"></a> <span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span>
</span><span id="__span-7-24"><a id="__codelineno-7-24" name="__codelineno-7-24" href="#__codelineno-7-24"></a> <span class="n">lr_scheduler</span><span class="o">=</span><span class="n">lr_scheduler</span><span class="p">,</span>
</span><span id="__span-7-25"><a id="__codelineno-7-25" name="__codelineno-7-25" href="#__codelineno-7-25"></a><span class="p">)</span>
</span></code></pre></div>
<h3 id="subclass">Subclass<a class="headerlink" href="#subclass" title="Permanent link">&para;</a></h3>
<p>We can now define the Trainer subclass. It should inherit from <code>refiners.training_utils.Trainer</code> and implement the following methods:</p>
<ul>
<li><code>create_data_iterable</code>: The <code>Trainer</code> will call this method to create and cache the data iterable. During training, the loop will pull batches from this iterable and pass them to the <code>compute_loss</code> method. Every time the iterable is exhausted, an epoch ends.</li>
<li><code>compute_loss</code>: This method should take a <code>Batch</code> and return the loss tensor.</li>
</ul>
<p>Here is a simple implementation of the <code>create_data_iterable</code> method. For this toy example, we will generate a simple list of <code>Batch</code> objects containing random masks. Later you can replace this with <code>torch.utils.data.DataLoader</code> or any other data loader with more complex features that support shuffling, parallel loading, etc.</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-8-1"><a id="__codelineno-8-1" name="__codelineno-8-1" href="#__codelineno-8-1"></a><span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">cached_property</span>
</span><span id="__span-8-2"><a id="__codelineno-8-2" name="__codelineno-8-2" href="#__codelineno-8-2"></a><span class="kn">from</span> <span class="nn">refiners.training_utils</span> <span class="kn">import</span> <span class="n">Trainer</span>
</span><span id="__span-8-3"><a id="__codelineno-8-3" name="__codelineno-8-3" href="#__codelineno-8-3"></a>
</span><span id="__span-8-4"><a id="__codelineno-8-4" name="__codelineno-8-4" href="#__codelineno-8-4"></a>
</span><span id="__span-8-5"><a id="__codelineno-8-5" name="__codelineno-8-5" href="#__codelineno-8-5"></a><span class="k">class</span> <span class="nc">AutoencoderConfig</span><span class="p">(</span><span class="n">BaseConfig</span><span class="p">):</span>
</span><span id="__span-8-6"><a id="__codelineno-8-6" name="__codelineno-8-6" href="#__codelineno-8-6"></a> <span class="n">num_images</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span>
</span><span id="__span-8-7"><a id="__codelineno-8-7" name="__codelineno-8-7" href="#__codelineno-8-7"></a> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span>
</span><span id="__span-8-8"><a id="__codelineno-8-8" name="__codelineno-8-8" href="#__codelineno-8-8"></a>
</span><span id="__span-8-9"><a id="__codelineno-8-9" name="__codelineno-8-9" href="#__codelineno-8-9"></a>
</span><span id="__span-8-10"><a id="__codelineno-8-10" name="__codelineno-8-10" href="#__codelineno-8-10"></a><span class="k">class</span> <span class="nc">AutoencoderTrainer</span><span class="p">(</span><span class="n">Trainer</span><span class="p">[</span><span class="n">AutoencoderConfig</span><span class="p">,</span> <span class="n">Batch</span><span class="p">]):</span>
</span><span id="__span-8-11"><a id="__codelineno-8-11" name="__codelineno-8-11" href="#__codelineno-8-11"></a> <span class="k">def</span> <span class="nf">create_data_iterable</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">[</span><span class="n">Batch</span><span class="p">]:</span>
</span><span id="__span-8-12"><a id="__codelineno-8-12" name="__codelineno-8-12" href="#__codelineno-8-12"></a> <span class="n">dataset</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Batch</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
</span><span id="__span-8-13"><a id="__codelineno-8-13" name="__codelineno-8-13" href="#__codelineno-8-13"></a> <span class="n">generator</span> <span class="o">=</span> <span class="n">generate_mask</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span>
</span><span id="__span-8-14"><a id="__codelineno-8-14" name="__codelineno-8-14" href="#__codelineno-8-14"></a>
</span><span id="__span-8-15"><a id="__codelineno-8-15" name="__codelineno-8-15" href="#__codelineno-8-15"></a> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">num_images</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
</span><span id="__span-8-16"><a id="__codelineno-8-16" name="__codelineno-8-16" href="#__codelineno-8-16"></a> <span class="n">masks</span> <span class="o">=</span> <span class="p">[</span><span class="nb">next</span><span class="p">(</span><span class="n">generator</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)]</span>
</span><span id="__span-8-17"><a id="__codelineno-8-17" name="__codelineno-8-17" href="#__codelineno-8-17"></a> <span class="n">dataset</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">Batch</span><span class="p">(</span><span class="n">image</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">masks</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)))</span>
</span><span id="__span-8-18"><a id="__codelineno-8-18" name="__codelineno-8-18" href="#__codelineno-8-18"></a>
</span><span id="__span-8-19"><a id="__codelineno-8-19" name="__codelineno-8-19" href="#__codelineno-8-19"></a> <span class="k">return</span> <span class="n">dataset</span>
</span><span id="__span-8-20"><a id="__codelineno-8-20" name="__codelineno-8-20" href="#__codelineno-8-20"></a>
</span><span id="__span-8-21"><a id="__codelineno-8-21" name="__codelineno-8-21" href="#__codelineno-8-21"></a> <span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Batch</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
</span><span id="__span-8-22"><a id="__codelineno-8-22" name="__codelineno-8-22" href="#__codelineno-8-22"></a> <span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">(</span><span class="s2">&quot;We&#39;ll implement this later&quot;</span><span class="p">)</span>
</span><span id="__span-8-23"><a id="__codelineno-8-23" name="__codelineno-8-23" href="#__codelineno-8-23"></a>
</span><span id="__span-8-24"><a id="__codelineno-8-24" name="__codelineno-8-24" href="#__codelineno-8-24"></a>
</span><span id="__span-8-25"><a id="__codelineno-8-25" name="__codelineno-8-25" href="#__codelineno-8-25"></a><span class="n">trainer</span> <span class="o">=</span> <span class="n">AutoencoderTrainer</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>
</span></code></pre></div>
<h3 id="model-registration">Model registration<a class="headerlink" href="#model-registration" title="Permanent link">&para;</a></h3>
<p>For the Trainer to be able to handle the model, we need to register it. </p>
<p>We need two things to do so: </p>
<ul>
<li>Add a <code>refiners.training_utils.ModelConfig</code> attribute named <code>autoencoder</code> to the config.</li>
<li>Add a method to the Trainer subclass, decorated with the <code>@register_model</code> decorator, that returns the model. As in the example below, the method has the same name as the config attribute (here <code>autoencoder</code>) and takes the <code>ModelConfig</code> as an argument. The Trainer's <code>__init__</code> will register the models and add any of their parameters that have <code>requires_grad</code> enabled to the optimizer.</li>
</ul>
<p>After registering the model, the <code>self.autoencoder</code> attribute will be available in the Trainer.</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-9-1"><a id="__codelineno-9-1" name="__codelineno-9-1" href="#__codelineno-9-1"></a><span class="kn">from</span> <span class="nn">refiners.training_utils</span> <span class="kn">import</span> <span class="n">ModelConfig</span><span class="p">,</span> <span class="n">register_model</span>
</span><span id="__span-9-2"><a id="__codelineno-9-2" name="__codelineno-9-2" href="#__codelineno-9-2"></a>
</span><span id="__span-9-3"><a id="__codelineno-9-3" name="__codelineno-9-3" href="#__codelineno-9-3"></a>
</span><span id="__span-9-4"><a id="__codelineno-9-4" name="__codelineno-9-4" href="#__codelineno-9-4"></a><span class="k">class</span> <span class="nc">AutoencoderModelConfig</span><span class="p">(</span><span class="n">ModelConfig</span><span class="p">):</span>
</span><span id="__span-9-5"><a id="__codelineno-9-5" name="__codelineno-9-5" href="#__codelineno-9-5"></a> <span class="k">pass</span>
</span><span id="__span-9-6"><a id="__codelineno-9-6" name="__codelineno-9-6" href="#__codelineno-9-6"></a>
</span><span id="__span-9-7"><a id="__codelineno-9-7" name="__codelineno-9-7" href="#__codelineno-9-7"></a>
</span><span id="__span-9-8"><a id="__codelineno-9-8" name="__codelineno-9-8" href="#__codelineno-9-8"></a><span class="k">class</span> <span class="nc">AutoencoderConfig</span><span class="p">(</span><span class="n">BaseConfig</span><span class="p">):</span>
</span><span id="__span-9-9"><a id="__codelineno-9-9" name="__codelineno-9-9" href="#__codelineno-9-9"></a> <span class="n">num_images</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span>
</span><span id="__span-9-10"><a id="__codelineno-9-10" name="__codelineno-9-10" href="#__codelineno-9-10"></a> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span>
</span><span id="__span-9-11"><a id="__codelineno-9-11" name="__codelineno-9-11" href="#__codelineno-9-11"></a> <span class="n">autoencoder</span><span class="p">:</span> <span class="n">AutoencoderModelConfig</span>
</span><span id="__span-9-12"><a id="__codelineno-9-12" name="__codelineno-9-12" href="#__codelineno-9-12"></a>
</span><span id="__span-9-13"><a id="__codelineno-9-13" name="__codelineno-9-13" href="#__codelineno-9-13"></a>
</span><span id="__span-9-14"><a id="__codelineno-9-14" name="__codelineno-9-14" href="#__codelineno-9-14"></a><span class="k">class</span> <span class="nc">AutoencoderTrainer</span><span class="p">(</span><span class="n">Trainer</span><span class="p">[</span><span class="n">AutoencoderConfig</span><span class="p">,</span> <span class="n">Batch</span><span class="p">]):</span>
</span><span id="__span-9-15"><a id="__codelineno-9-15" name="__codelineno-9-15" href="#__codelineno-9-15"></a> <span class="c1"># ... other methods</span>
</span><span id="__span-9-16"><a id="__codelineno-9-16" name="__codelineno-9-16" href="#__codelineno-9-16"></a>
</span><span id="__span-9-17"><a id="__codelineno-9-17" name="__codelineno-9-17" href="#__codelineno-9-17"></a> <span class="nd">@register_model</span><span class="p">()</span>
</span><span id="__span-9-18"><a id="__codelineno-9-18" name="__codelineno-9-18" href="#__codelineno-9-18"></a> <span class="k">def</span> <span class="nf">autoencoder</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">AutoencoderModelConfig</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Autoencoder</span><span class="p">:</span>
</span><span id="__span-9-19"><a id="__codelineno-9-19" name="__codelineno-9-19" href="#__codelineno-9-19"></a> <span class="k">return</span> <span class="n">Autoencoder</span><span class="p">()</span>
</span><span id="__span-9-20"><a id="__codelineno-9-20" name="__codelineno-9-20" href="#__codelineno-9-20"></a>
</span><span id="__span-9-21"><a id="__codelineno-9-21" name="__codelineno-9-21" href="#__codelineno-9-21"></a> <span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Batch</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span>
</span><span id="__span-9-22"><a id="__codelineno-9-22" name="__codelineno-9-22" href="#__codelineno-9-22"></a> <span class="n">batch</span><span class="o">.</span><span class="n">image</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span><span id="__span-9-23"><a id="__codelineno-9-23" name="__codelineno-9-23" href="#__codelineno-9-23"></a> <span class="n">x_reconstructed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">autoencoder</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span>
</span><span id="__span-9-24"><a id="__codelineno-9-24" name="__codelineno-9-24" href="#__codelineno-9-24"></a> <span class="bp">self</span><span class="o">.</span><span class="n">autoencoder</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">image</span><span class="p">)</span>
</span><span id="__span-9-25"><a id="__codelineno-9-25" name="__codelineno-9-25" href="#__codelineno-9-25"></a> <span class="p">)</span>
</span><span id="__span-9-26"><a id="__codelineno-9-26" name="__codelineno-9-26" href="#__codelineno-9-26"></a> <span class="k">return</span> <span class="n">F</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">x_reconstructed</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">image</span><span class="p">)</span>
</span></code></pre></div>
<p>We now have a fully functional Trainer that can train our autoencoder. We can call the <code>train</code> method to start the training loop.</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-10-1"><a id="__codelineno-10-1" name="__codelineno-10-1" href="#__codelineno-10-1"></a><span class="n">trainer</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
</span></code></pre></div>
<p><img alt="Terminal output of the training loop" src="terminal-logging.png" /></p>
<h2 id="logging">Logging<a class="headerlink" href="#logging" title="Permanent link">&para;</a></h2>
<p>Let's write a simple logging callback to log the loss during training. A callback is a class that inherits from <code>refiners.training_utils.Callback</code> and implements any of the following methods:</p>
<ul>
<li><code>on_init_begin</code></li>
<li><code>on_init_end</code></li>
<li><code>on_train_begin</code></li>
<li><code>on_train_end</code></li>
<li><code>on_epoch_begin</code></li>
<li><code>on_epoch_end</code></li>
<li><code>on_step_begin</code></li>
<li><code>on_step_end</code></li>
<li><code>on_backward_begin</code></li>
<li><code>on_backward_end</code></li>
<li><code>on_optimizer_step_begin</code></li>
<li><code>on_optimizer_step_end</code></li>
<li><code>on_compute_loss_begin</code></li>
<li><code>on_compute_loss_end</code></li>
<li><code>on_evaluate_begin</code></li>
<li><code>on_evaluate_end</code></li>
<li><code>on_lr_scheduler_step_begin</code></li>
<li><code>on_lr_scheduler_step_end</code></li>
</ul>
<p>We will implement the <code>on_compute_loss_end</code> method to store each batch's loss in a list, and the <code>on_epoch_end</code> method to log the mean loss over the epoch and reset the list.</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-11-1"><a id="__codelineno-11-1" name="__codelineno-11-1" href="#__codelineno-11-1"></a><span class="kn">from</span> <span class="nn">refiners.training_utils</span> <span class="kn">import</span> <span class="n">Callback</span>
</span><span id="__span-11-2"><a id="__codelineno-11-2" name="__codelineno-11-2" href="#__codelineno-11-2"></a><span class="kn">from</span> <span class="nn">loguru</span> <span class="kn">import</span> <span class="n">logger</span>
</span><span id="__span-11-3"><a id="__codelineno-11-3" name="__codelineno-11-3" href="#__codelineno-11-3"></a><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span>
</span><span id="__span-11-4"><a id="__codelineno-11-4" name="__codelineno-11-4" href="#__codelineno-11-4"></a>
</span><span id="__span-11-5"><a id="__codelineno-11-5" name="__codelineno-11-5" href="#__codelineno-11-5"></a>
</span><span id="__span-11-6"><a id="__codelineno-11-6" name="__codelineno-11-6" href="#__codelineno-11-6"></a><span class="k">class</span> <span class="nc">LoggingCallback</span><span class="p">(</span><span class="n">Callback</span><span class="p">[</span><span class="n">Any</span><span class="p">]):</span>
</span><span id="__span-11-7"><a id="__codelineno-11-7" name="__codelineno-11-7" href="#__codelineno-11-7"></a> <span class="n">losses</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
</span><span id="__span-11-8"><a id="__codelineno-11-8" name="__codelineno-11-8" href="#__codelineno-11-8"></a>
</span><span id="__span-11-9"><a id="__codelineno-11-9" name="__codelineno-11-9" href="#__codelineno-11-9"></a> <span class="k">def</span> <span class="nf">on_compute_loss_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">loss</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-11-10"><a id="__codelineno-11-10" name="__codelineno-11-10" href="#__codelineno-11-10"></a> <span class="bp">self</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
</span><span id="__span-11-11"><a id="__codelineno-11-11" name="__codelineno-11-11" href="#__codelineno-11-11"></a>
</span><span id="__span-11-12"><a id="__codelineno-11-12" name="__codelineno-11-12" href="#__codelineno-11-12"></a> <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-11-13"><a id="__codelineno-11-13" name="__codelineno-11-13" href="#__codelineno-11-13"></a> <span class="n">mean_loss</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">losses</span><span class="p">)</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">losses</span><span class="p">)</span>
</span><span id="__span-11-14"><a id="__codelineno-11-14" name="__codelineno-11-14" href="#__codelineno-11-14"></a> <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Mean loss: </span><span class="si">{</span><span class="n">mean_loss</span><span class="si">}</span><span class="s2">, epoch: </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</span><span id="__span-11-15"><a id="__codelineno-11-15" name="__codelineno-11-15" href="#__codelineno-11-15"></a> <span class="bp">self</span><span class="o">.</span><span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
</span></code></pre></div>
<p>Exactly like models, we need to register the callback with the Trainer. We can do so by adding a <code>CallbackConfig</code> attribute named <code>logging</code> to the config, and adding a method with the same name to the Trainer class, decorated with the <code>@register_callback</code> decorator, that returns the callback.</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-12-1"><a id="__codelineno-12-1" name="__codelineno-12-1" href="#__codelineno-12-1"></a><span class="kn">from</span> <span class="nn">refiners.training_utils</span> <span class="kn">import</span> <span class="n">CallbackConfig</span><span class="p">,</span> <span class="n">register_callback</span>
</span><span id="__span-12-2"><a id="__codelineno-12-2" name="__codelineno-12-2" href="#__codelineno-12-2"></a>
</span><span id="__span-12-3"><a id="__codelineno-12-3" name="__codelineno-12-3" href="#__codelineno-12-3"></a><span class="k">class</span> <span class="nc">AutoencoderConfig</span><span class="p">(</span><span class="n">BaseConfig</span><span class="p">):</span>
</span><span id="__span-12-4"><a id="__codelineno-12-4" name="__codelineno-12-4" href="#__codelineno-12-4"></a> <span class="c1"># ... other properties</span>
</span><span id="__span-12-5"><a id="__codelineno-12-5" name="__codelineno-12-5" href="#__codelineno-12-5"></a> <span class="n">logging</span><span class="p">:</span> <span class="n">CallbackConfig</span> <span class="o">=</span> <span class="n">CallbackConfig</span><span class="p">()</span>
</span><span id="__span-12-6"><a id="__codelineno-12-6" name="__codelineno-12-6" href="#__codelineno-12-6"></a>
</span><span id="__span-12-7"><a id="__codelineno-12-7" name="__codelineno-12-7" href="#__codelineno-12-7"></a>
</span><span id="__span-12-8"><a id="__codelineno-12-8" name="__codelineno-12-8" href="#__codelineno-12-8"></a><span class="k">class</span> <span class="nc">AutoencoderTrainer</span><span class="p">(</span><span class="n">Trainer</span><span class="p">[</span><span class="n">AutoencoderConfig</span><span class="p">,</span> <span class="n">Batch</span><span class="p">]):</span>
</span><span id="__span-12-9"><a id="__codelineno-12-9" name="__codelineno-12-9" href="#__codelineno-12-9"></a> <span class="c1"># ... other methods</span>
</span><span id="__span-12-10"><a id="__codelineno-12-10" name="__codelineno-12-10" href="#__codelineno-12-10"></a>
</span><span id="__span-12-11"><a id="__codelineno-12-11" name="__codelineno-12-11" href="#__codelineno-12-11"></a> <span class="nd">@register_callback</span><span class="p">()</span>
</span><span id="__span-12-12"><a id="__codelineno-12-12" name="__codelineno-12-12" href="#__codelineno-12-12"></a> <span class="k">def</span> <span class="nf">logging</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">CallbackConfig</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">LoggingCallback</span><span class="p">:</span>
</span><span id="__span-12-13"><a id="__codelineno-12-13" name="__codelineno-12-13" href="#__codelineno-12-13"></a> <span class="k">return</span> <span class="n">LoggingCallback</span><span class="p">()</span>
</span></code></pre></div>
<p><img alt="Training loss logged by the LoggingCallback" src="loss-logging.png"></p>
<h2 id="evaluation">Evaluation<a class="headerlink" href="#evaluation" title="Permanent link">&para;</a></h2>
<p>Let's add an evaluation step to the Trainer. We will generate a few masks and their reconstructions and save them to a file. We start by implementing a <code>compute_evaluation</code> method, then we register a callback to call this method at regular intervals.</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-13-1"><a id="__codelineno-13-1" name="__codelineno-13-1" href="#__codelineno-13-1"></a><span class="k">class</span> <span class="nc">AutoencoderTrainer</span><span class="p">(</span><span class="n">Trainer</span><span class="p">[</span><span class="n">AutoencoderConfig</span><span class="p">,</span> <span class="n">Batch</span><span class="p">]):</span>
</span><span id="__span-13-2"><a id="__codelineno-13-2" name="__codelineno-13-2" href="#__codelineno-13-2"></a> <span class="c1"># ... other methods</span>
</span><span id="__span-13-3"><a id="__codelineno-13-3" name="__codelineno-13-3" href="#__codelineno-13-3"></a>
</span><span id="__span-13-4"><a id="__codelineno-13-4" name="__codelineno-13-4" href="#__codelineno-13-4"></a> <span class="k">def</span> <span class="nf">compute_evaluation</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-13-5"><a id="__codelineno-13-5" name="__codelineno-13-5" href="#__codelineno-13-5"></a> <span class="n">generator</span> <span class="o">=</span> <span class="n">generate_mask</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span><span id="__span-13-6"><a id="__codelineno-13-6" name="__codelineno-13-6" href="#__codelineno-13-6"></a>
</span><span id="__span-13-7"><a id="__codelineno-13-7" name="__codelineno-13-7" href="#__codelineno-13-7"></a> <span class="n">grid</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">tuple</span><span class="p">[</span><span class="n">Image</span><span class="o">.</span><span class="n">Image</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">Image</span><span class="p">]]</span> <span class="o">=</span> <span class="p">[]</span>
</span><span id="__span-13-8"><a id="__codelineno-13-8" name="__codelineno-13-8" href="#__codelineno-13-8"></a> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">):</span>
</span><span id="__span-13-9"><a id="__codelineno-13-9" name="__codelineno-13-9" href="#__codelineno-13-9"></a> <span class="n">mask</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">generator</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span><span id="__span-13-10"><a id="__codelineno-13-10" name="__codelineno-13-10" href="#__codelineno-13-10"></a> <span class="n">x_reconstructed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">autoencoder</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span>
</span><span id="__span-13-11"><a id="__codelineno-13-11" name="__codelineno-13-11" href="#__codelineno-13-11"></a> <span class="bp">self</span><span class="o">.</span><span class="n">autoencoder</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">mask</span><span class="p">)</span>
</span><span id="__span-13-12"><a id="__codelineno-13-12" name="__codelineno-13-12" href="#__codelineno-13-12"></a> <span class="p">)</span>
</span><span id="__span-13-13"><a id="__codelineno-13-13" name="__codelineno-13-13" href="#__codelineno-13-13"></a> <span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">mse_loss</span><span class="p">(</span><span class="n">x_reconstructed</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>
</span><span id="__span-13-14"><a id="__codelineno-13-14" name="__codelineno-13-14" href="#__codelineno-13-14"></a> <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Validation loss: </span><span class="si">{</span><span class="n">loss</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</span><span id="__span-13-15"><a id="__codelineno-13-15" name="__codelineno-13-15" href="#__codelineno-13-15"></a> <span class="n">grid</span><span class="o">.</span><span class="n">append</span><span class="p">(</span>
</span><span id="__span-13-16"><a id="__codelineno-13-16" name="__codelineno-13-16" href="#__codelineno-13-16"></a> <span class="p">(</span><span class="n">tensor_to_image</span><span class="p">(</span><span class="n">mask</span><span class="p">),</span> <span class="n">tensor_to_image</span><span class="p">((</span><span class="n">x_reconstructed</span><span class="o">&gt;</span><span class="mf">0.5</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">()))</span>
</span><span id="__span-13-17"><a id="__codelineno-13-17" name="__codelineno-13-17" href="#__codelineno-13-17"></a> <span class="p">)</span>
</span><span id="__span-13-18"><a id="__codelineno-13-18" name="__codelineno-13-18" href="#__codelineno-13-18"></a>
</span><span id="__span-13-19"><a id="__codelineno-13-19" name="__codelineno-13-19" href="#__codelineno-13-19"></a> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
</span><span id="__span-13-20"><a id="__codelineno-13-20" name="__codelineno-13-20" href="#__codelineno-13-20"></a>
</span><span id="__span-13-21"><a id="__codelineno-13-21" name="__codelineno-13-21" href="#__codelineno-13-21"></a> <span class="n">_</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">16</span><span class="p">))</span>
</span><span id="__span-13-22"><a id="__codelineno-13-22" name="__codelineno-13-22" href="#__codelineno-13-22"></a>
</span><span id="__span-13-23"><a id="__codelineno-13-23" name="__codelineno-13-23" href="#__codelineno-13-23"></a> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">reconstructed</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">grid</span><span class="p">):</span>
</span><span id="__span-13-24"><a id="__codelineno-13-24" name="__codelineno-13-24" href="#__codelineno-13-24"></a> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s1">&#39;gray&#39;</span><span class="p">)</span>
</span><span id="__span-13-25"><a id="__codelineno-13-25" name="__codelineno-13-25" href="#__codelineno-13-25"></a> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span>
</span><span id="__span-13-26"><a id="__codelineno-13-26" name="__codelineno-13-26" href="#__codelineno-13-26"></a> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s1">&#39;Mask&#39;</span><span class="p">)</span>
</span><span id="__span-13-27"><a id="__codelineno-13-27" name="__codelineno-13-27" href="#__codelineno-13-27"></a>
</span><span id="__span-13-28"><a id="__codelineno-13-28" name="__codelineno-13-28" href="#__codelineno-13-28"></a> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">reconstructed</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s1">&#39;gray&#39;</span><span class="p">)</span>
</span><span id="__span-13-29"><a id="__codelineno-13-29" name="__codelineno-13-29" href="#__codelineno-13-29"></a> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s1">&#39;off&#39;</span><span class="p">)</span>
</span><span id="__span-13-30"><a id="__codelineno-13-30" name="__codelineno-13-30" href="#__codelineno-13-30"></a> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s1">&#39;Reconstructed&#39;</span><span class="p">)</span>
</span><span id="__span-13-31"><a id="__codelineno-13-31" name="__codelineno-13-31" href="#__codelineno-13-31"></a>
</span><span id="__span-13-32"><a id="__codelineno-13-32" name="__codelineno-13-32" href="#__codelineno-13-32"></a> <span class="n">plt</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">()</span>
</span><span id="__span-13-33"><a id="__codelineno-13-33" name="__codelineno-13-33" href="#__codelineno-13-33"></a> <span class="n">plt</span><span class="o">.</span><span class="n">savefig</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;result_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">clock</span><span class="o">.</span><span class="n">epoch</span><span class="si">}</span><span class="s2">.png&quot;</span><span class="p">)</span>
</span><span id="__span-13-34"><a id="__codelineno-13-34" name="__codelineno-13-34" href="#__codelineno-13-34"></a> <span class="n">plt</span><span class="o">.</span><span class="n">close</span><span class="p">()</span>
</span></code></pre></div>
<p>Next, we implement an <code>EvaluationConfig</code> that controls the evaluation interval and the seed for the random generator.</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-14-1"><a id="__codelineno-14-1" name="__codelineno-14-1" href="#__codelineno-14-1"></a><span class="kn">from</span> <span class="nn">refiners.training_utils.config</span> <span class="kn">import</span> <span class="n">TimeValueField</span>
</span><span id="__span-14-2"><a id="__codelineno-14-2" name="__codelineno-14-2" href="#__codelineno-14-2"></a>
</span><span id="__span-14-3"><a id="__codelineno-14-3" name="__codelineno-14-3" href="#__codelineno-14-3"></a><span class="k">class</span> <span class="nc">EvaluationConfig</span><span class="p">(</span><span class="n">CallbackConfig</span><span class="p">):</span>
</span><span id="__span-14-4"><a id="__codelineno-14-4" name="__codelineno-14-4" href="#__codelineno-14-4"></a> <span class="n">interval</span><span class="p">:</span> <span class="n">TimeValueField</span>
</span><span id="__span-14-5"><a id="__codelineno-14-5" name="__codelineno-14-5" href="#__codelineno-14-5"></a> <span class="n">seed</span><span class="p">:</span> <span class="nb">int</span>
</span></code></pre></div>
<p>The <code>TimeValueField</code> is a custom field that allows Pydantic to parse a string representing a time value (e.g., <code>"50:epochs"</code>) into a <code>TimeValue</code> object. This lets you specify the evaluation interval directly in the configuration file.</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-15-1"><a id="__codelineno-15-1" name="__codelineno-15-1" href="#__codelineno-15-1"></a><span class="kn">from</span> <span class="nn">refiners.training_utils</span> <span class="kn">import</span> <span class="n">scoped_seed</span><span class="p">,</span> <span class="n">Callback</span>
</span><span id="__span-15-2"><a id="__codelineno-15-2" name="__codelineno-15-2" href="#__codelineno-15-2"></a>
</span><span id="__span-15-3"><a id="__codelineno-15-3" name="__codelineno-15-3" href="#__codelineno-15-3"></a><span class="k">class</span> <span class="nc">EvaluationCallback</span><span class="p">(</span><span class="n">Callback</span><span class="p">[</span><span class="n">Any</span><span class="p">]):</span>
</span><span id="__span-15-4"><a id="__codelineno-15-4" name="__codelineno-15-4" href="#__codelineno-15-4"></a> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">EvaluationConfig</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-15-5"><a id="__codelineno-15-5" name="__codelineno-15-5" href="#__codelineno-15-5"></a> <span class="bp">self</span><span class="o">.</span><span class="n">config</span> <span class="o">=</span> <span class="n">config</span>
</span><span id="__span-15-6"><a id="__codelineno-15-6" name="__codelineno-15-6" href="#__codelineno-15-6"></a>
</span><span id="__span-15-7"><a id="__codelineno-15-7" name="__codelineno-15-7" href="#__codelineno-15-7"></a> <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="n">Trainer</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-15-8"><a id="__codelineno-15-8" name="__codelineno-15-8" href="#__codelineno-15-8"></a> <span class="c1"># The `is_due` method checks if the current epoch is a multiple of the interval.</span>
</span><span id="__span-15-9"><a id="__codelineno-15-9" name="__codelineno-15-9" href="#__codelineno-15-9"></a> <span class="k">if</span> <span class="ow">not</span> <span class="n">trainer</span><span class="o">.</span><span class="n">clock</span><span class="o">.</span><span class="n">is_due</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">interval</span><span class="p">):</span>
</span><span id="__span-15-10"><a id="__codelineno-15-10" name="__codelineno-15-10" href="#__codelineno-15-10"></a> <span class="k">return</span>
</span><span id="__span-15-11"><a id="__codelineno-15-11" name="__codelineno-15-11" href="#__codelineno-15-11"></a>
</span><span id="__span-15-12"><a id="__codelineno-15-12" name="__codelineno-15-12" href="#__codelineno-15-12"></a> <span class="c1"># The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the </span>
</span><span id="__span-15-13"><a id="__codelineno-15-13" name="__codelineno-15-13" href="#__codelineno-15-13"></a> <span class="c1"># evaluation.</span>
</span><span id="__span-15-14"><a id="__codelineno-15-14" name="__codelineno-15-14" href="#__codelineno-15-14"></a> <span class="k">with</span> <span class="n">scoped_seed</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">seed</span><span class="p">):</span>
</span><span id="__span-15-15"><a id="__codelineno-15-15" name="__codelineno-15-15" href="#__codelineno-15-15"></a> <span class="n">trainer</span><span class="o">.</span><span class="n">compute_evaluation</span><span class="p">()</span>
</span></code></pre></div>
<p>We can now add the evaluation settings to the training config and register the callback with the Trainer.</p>
<div class="language-python highlight"><pre><span></span><code><span id="__span-16-1"><a id="__codelineno-16-1" name="__codelineno-16-1" href="#__codelineno-16-1"></a><span class="k">class</span> <span class="nc">AutoencoderConfig</span><span class="p">(</span><span class="n">BaseConfig</span><span class="p">):</span>
</span><span id="__span-16-2"><a id="__codelineno-16-2" name="__codelineno-16-2" href="#__codelineno-16-2"></a> <span class="c1"># ... other properties</span>
</span><span id="__span-16-3"><a id="__codelineno-16-3" name="__codelineno-16-3" href="#__codelineno-16-3"></a> <span class="n">evaluation</span><span class="p">:</span> <span class="n">EvaluationConfig</span>
</span></code></pre></div>
<div class="language-python highlight"><pre><span></span><code><span id="__span-17-1"><a id="__codelineno-17-1" name="__codelineno-17-1" href="#__codelineno-17-1"></a><span class="k">class</span> <span class="nc">AutoencoderTrainer</span><span class="p">(</span><span class="n">Trainer</span><span class="p">[</span><span class="n">AutoencoderConfig</span><span class="p">,</span> <span class="n">Batch</span><span class="p">]):</span>
</span><span id="__span-17-2"><a id="__codelineno-17-2" name="__codelineno-17-2" href="#__codelineno-17-2"></a> <span class="c1"># ... other methods</span>
</span><span id="__span-17-3"><a id="__codelineno-17-3" name="__codelineno-17-3" href="#__codelineno-17-3"></a>
</span><span id="__span-17-4"><a id="__codelineno-17-4" name="__codelineno-17-4" href="#__codelineno-17-4"></a> <span class="nd">@register_callback</span><span class="p">()</span>
</span><span id="__span-17-5"><a id="__codelineno-17-5" name="__codelineno-17-5" href="#__codelineno-17-5"></a> <span class="k">def</span> <span class="nf">evaluation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">EvaluationConfig</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">EvaluationCallback</span><span class="p">:</span>
</span><span id="__span-17-6"><a id="__codelineno-17-6" name="__codelineno-17-6" href="#__codelineno-17-6"></a> <span class="k">return</span> <span class="n">EvaluationCallback</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>
</span></code></pre></div>
<p>We can now train the model and see the results in the <code>result_{epoch}.png</code> files.</p>
<p><img alt="Evaluation grid showing generated masks next to their reconstructions" src="evaluation.png"></p>
<h2 id="wrap-up">Wrap up<a class="headerlink" href="#wrap-up" title="Permanent link">&para;</a></h2>
<p>You can train this toy model using the code below:</p>
<details class="complete end-to-end code">
<summary>Expand to see the full code.</summary>
<div class="language-py highlight"><pre><span></span><code><span id="__span-18-1"><a id="__codelineno-18-1" name="__codelineno-18-1" href="#__codelineno-18-1"></a><span class="kn">import</span> <span class="nn">random</span>
</span><span id="__span-18-2"><a id="__codelineno-18-2" name="__codelineno-18-2" href="#__codelineno-18-2"></a><span class="kn">from</span> <span class="nn">dataclasses</span> <span class="kn">import</span> <span class="n">dataclass</span>
</span><span id="__span-18-3"><a id="__codelineno-18-3" name="__codelineno-18-3" href="#__codelineno-18-3"></a><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Any</span><span class="p">,</span> <span class="n">Generator</span>
</span><span id="__span-18-4"><a id="__codelineno-18-4" name="__codelineno-18-4" href="#__codelineno-18-4"></a>
</span><span id="__span-18-5"><a id="__codelineno-18-5" name="__codelineno-18-5" href="#__codelineno-18-5"></a><span class="kn">import</span> <span class="nn">torch</span>
</span><span id="__span-18-6"><a id="__codelineno-18-6" name="__codelineno-18-6" href="#__codelineno-18-6"></a><span class="kn">from</span> <span class="nn">loguru</span> <span class="kn">import</span> <span class="n">logger</span>
</span><span id="__span-18-7"><a id="__codelineno-18-7" name="__codelineno-18-7" href="#__codelineno-18-7"></a><span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span>
</span><span id="__span-18-8"><a id="__codelineno-18-8" name="__codelineno-18-8" href="#__codelineno-18-8"></a><span class="kn">from</span> <span class="nn">torch.nn</span> <span class="kn">import</span> <span class="n">functional</span> <span class="k">as</span> <span class="n">F</span>
</span><span id="__span-18-9"><a id="__codelineno-18-9" name="__codelineno-18-9" href="#__codelineno-18-9"></a>
</span><span id="__span-18-10"><a id="__codelineno-18-10" name="__codelineno-18-10" href="#__codelineno-18-10"></a><span class="kn">from</span> <span class="nn">refiners.fluxion</span> <span class="kn">import</span> <span class="n">layers</span> <span class="k">as</span> <span class="n">fl</span>
</span><span id="__span-18-11"><a id="__codelineno-18-11" name="__codelineno-18-11" href="#__codelineno-18-11"></a><span class="kn">from</span> <span class="nn">refiners.fluxion.utils</span> <span class="kn">import</span> <span class="n">image_to_tensor</span><span class="p">,</span> <span class="n">tensor_to_image</span>
</span><span id="__span-18-12"><a id="__codelineno-18-12" name="__codelineno-18-12" href="#__codelineno-18-12"></a><span class="kn">from</span> <span class="nn">refiners.training_utils</span> <span class="kn">import</span> <span class="p">(</span>
</span><span id="__span-18-13"><a id="__codelineno-18-13" name="__codelineno-18-13" href="#__codelineno-18-13"></a> <span class="n">BaseConfig</span><span class="p">,</span>
</span><span id="__span-18-14"><a id="__codelineno-18-14" name="__codelineno-18-14" href="#__codelineno-18-14"></a> <span class="n">Callback</span><span class="p">,</span>
</span><span id="__span-18-15"><a id="__codelineno-18-15" name="__codelineno-18-15" href="#__codelineno-18-15"></a> <span class="n">CallbackConfig</span><span class="p">,</span>
</span><span id="__span-18-16"><a id="__codelineno-18-16" name="__codelineno-18-16" href="#__codelineno-18-16"></a> <span class="n">ClockConfig</span><span class="p">,</span>
</span><span id="__span-18-17"><a id="__codelineno-18-17" name="__codelineno-18-17" href="#__codelineno-18-17"></a> <span class="n">Epoch</span><span class="p">,</span>
</span><span id="__span-18-18"><a id="__codelineno-18-18" name="__codelineno-18-18" href="#__codelineno-18-18"></a> <span class="n">LRSchedulerConfig</span><span class="p">,</span>
</span><span id="__span-18-19"><a id="__codelineno-18-19" name="__codelineno-18-19" href="#__codelineno-18-19"></a> <span class="n">LRSchedulerType</span><span class="p">,</span>
</span><span id="__span-18-20"><a id="__codelineno-18-20" name="__codelineno-18-20" href="#__codelineno-18-20"></a> <span class="n">ModelConfig</span><span class="p">,</span>
</span><span id="__span-18-21"><a id="__codelineno-18-21" name="__codelineno-18-21" href="#__codelineno-18-21"></a> <span class="n">OptimizerConfig</span><span class="p">,</span>
</span><span id="__span-18-22"><a id="__codelineno-18-22" name="__codelineno-18-22" href="#__codelineno-18-22"></a> <span class="n">Optimizers</span><span class="p">,</span>
</span><span id="__span-18-23"><a id="__codelineno-18-23" name="__codelineno-18-23" href="#__codelineno-18-23"></a> <span class="n">Trainer</span><span class="p">,</span>
</span><span id="__span-18-24"><a id="__codelineno-18-24" name="__codelineno-18-24" href="#__codelineno-18-24"></a> <span class="n">TrainingConfig</span><span class="p">,</span>
</span><span id="__span-18-25"><a id="__codelineno-18-25" name="__codelineno-18-25" href="#__codelineno-18-25"></a> <span class="n">register_callback</span><span class="p">,</span>
</span><span id="__span-18-26"><a id="__codelineno-18-26" name="__codelineno-18-26" href="#__codelineno-18-26"></a> <span class="n">register_model</span><span class="p">,</span>
</span><span id="__span-18-27"><a id="__codelineno-18-27" name="__codelineno-18-27" href="#__codelineno-18-27"></a><span class="p">)</span>
</span><span id="__span-18-28"><a id="__codelineno-18-28" name="__codelineno-18-28" href="#__codelineno-18-28"></a><span class="kn">from</span> <span class="nn">refiners.training_utils.common</span> <span class="kn">import</span> <span class="n">scoped_seed</span>
</span><span id="__span-18-29"><a id="__codelineno-18-29" name="__codelineno-18-29" href="#__codelineno-18-29"></a><span class="kn">from</span> <span class="nn">refiners.training_utils.config</span> <span class="kn">import</span> <span class="n">TimeValueField</span>
</span><span id="__span-18-30"><a id="__codelineno-18-30" name="__codelineno-18-30" href="#__codelineno-18-30"></a>
</span><span id="__span-18-31"><a id="__codelineno-18-31" name="__codelineno-18-31" href="#__codelineno-18-31"></a>
</span><span id="__span-18-32"><a id="__codelineno-18-32" name="__codelineno-18-32" href="#__codelineno-18-32"></a><span class="k">class</span> <span class="nc">ConvBlock</span><span class="p">(</span><span class="n">fl</span><span class="o">.</span><span class="n">Chain</span><span class="p">):</span>
</span><span id="__span-18-33"><a id="__codelineno-18-33" name="__codelineno-18-33" href="#__codelineno-18-33"></a> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-18-34"><a id="__codelineno-18-34" name="__codelineno-18-34" href="#__codelineno-18-34"></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
</span><span id="__span-18-35"><a id="__codelineno-18-35" name="__codelineno-18-35" href="#__codelineno-18-35"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span>
</span><span id="__span-18-36"><a id="__codelineno-18-36" name="__codelineno-18-36" href="#__codelineno-18-36"></a> <span class="n">in_channels</span><span class="o">=</span><span class="n">in_channels</span><span class="p">,</span>
</span><span id="__span-18-37"><a id="__codelineno-18-37" name="__codelineno-18-37" href="#__codelineno-18-37"></a> <span class="n">out_channels</span><span class="o">=</span><span class="n">out_channels</span><span class="p">,</span>
</span><span id="__span-18-38"><a id="__codelineno-18-38" name="__codelineno-18-38" href="#__codelineno-18-38"></a> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
</span><span id="__span-18-39"><a id="__codelineno-18-39" name="__codelineno-18-39" href="#__codelineno-18-39"></a> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
</span><span id="__span-18-40"><a id="__codelineno-18-40" name="__codelineno-18-40" href="#__codelineno-18-40"></a> <span class="n">groups</span><span class="o">=</span><span class="nb">min</span><span class="p">(</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">),</span>
</span><span id="__span-18-41"><a id="__codelineno-18-41" name="__codelineno-18-41" href="#__codelineno-18-41"></a> <span class="p">),</span>
</span><span id="__span-18-42"><a id="__codelineno-18-42" name="__codelineno-18-42" href="#__codelineno-18-42"></a> <span class="n">fl</span><span class="o">.</span><span class="n">LayerNorm2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">),</span>
</span><span id="__span-18-43"><a id="__codelineno-18-43" name="__codelineno-18-43" href="#__codelineno-18-43"></a> <span class="n">fl</span><span class="o">.</span><span class="n">SiLU</span><span class="p">(),</span>
</span><span id="__span-18-44"><a id="__codelineno-18-44" name="__codelineno-18-44" href="#__codelineno-18-44"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span>
</span><span id="__span-18-45"><a id="__codelineno-18-45" name="__codelineno-18-45" href="#__codelineno-18-45"></a> <span class="n">in_channels</span><span class="o">=</span><span class="n">out_channels</span><span class="p">,</span>
</span><span id="__span-18-46"><a id="__codelineno-18-46" name="__codelineno-18-46" href="#__codelineno-18-46"></a> <span class="n">out_channels</span><span class="o">=</span><span class="n">out_channels</span><span class="p">,</span>
</span><span id="__span-18-47"><a id="__codelineno-18-47" name="__codelineno-18-47" href="#__codelineno-18-47"></a> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
</span><span id="__span-18-48"><a id="__codelineno-18-48" name="__codelineno-18-48" href="#__codelineno-18-48"></a> <span class="n">padding</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
</span><span id="__span-18-49"><a id="__codelineno-18-49" name="__codelineno-18-49" href="#__codelineno-18-49"></a> <span class="p">),</span>
</span><span id="__span-18-50"><a id="__codelineno-18-50" name="__codelineno-18-50" href="#__codelineno-18-50"></a> <span class="n">fl</span><span class="o">.</span><span class="n">LayerNorm2d</span><span class="p">(</span><span class="n">out_channels</span><span class="p">),</span>
</span><span id="__span-18-51"><a id="__codelineno-18-51" name="__codelineno-18-51" href="#__codelineno-18-51"></a> <span class="n">fl</span><span class="o">.</span><span class="n">SiLU</span><span class="p">(),</span>
</span><span id="__span-18-52"><a id="__codelineno-18-52" name="__codelineno-18-52" href="#__codelineno-18-52"></a> <span class="p">)</span>
</span><span id="__span-18-53"><a id="__codelineno-18-53" name="__codelineno-18-53" href="#__codelineno-18-53"></a>
</span><span id="__span-18-54"><a id="__codelineno-18-54" name="__codelineno-18-54" href="#__codelineno-18-54"></a>
</span><span id="__span-18-55"><a id="__codelineno-18-55" name="__codelineno-18-55" href="#__codelineno-18-55"></a><span class="k">class</span> <span class="nc">ResidualBlock</span><span class="p">(</span><span class="n">fl</span><span class="o">.</span><span class="n">Sum</span><span class="p">):</span>
</span><span id="__span-18-56"><a id="__codelineno-18-56" name="__codelineno-18-56" href="#__codelineno-18-56"></a> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">in_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">out_channels</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-18-57"><a id="__codelineno-18-57" name="__codelineno-18-57" href="#__codelineno-18-57"></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
</span><span id="__span-18-58"><a id="__codelineno-18-58" name="__codelineno-18-58" href="#__codelineno-18-58"></a> <span class="n">ConvBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="n">in_channels</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="n">out_channels</span><span class="p">),</span>
</span><span id="__span-18-59"><a id="__codelineno-18-59" name="__codelineno-18-59" href="#__codelineno-18-59"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span>
</span><span id="__span-18-60"><a id="__codelineno-18-60" name="__codelineno-18-60" href="#__codelineno-18-60"></a> <span class="n">in_channels</span><span class="o">=</span><span class="n">in_channels</span><span class="p">,</span>
</span><span id="__span-18-61"><a id="__codelineno-18-61" name="__codelineno-18-61" href="#__codelineno-18-61"></a> <span class="n">out_channels</span><span class="o">=</span><span class="n">out_channels</span><span class="p">,</span>
</span><span id="__span-18-62"><a id="__codelineno-18-62" name="__codelineno-18-62" href="#__codelineno-18-62"></a> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span>
</span><span id="__span-18-63"><a id="__codelineno-18-63" name="__codelineno-18-63" href="#__codelineno-18-63"></a> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
</span><span id="__span-18-64"><a id="__codelineno-18-64" name="__codelineno-18-64" href="#__codelineno-18-64"></a> <span class="p">),</span>
</span><span id="__span-18-65"><a id="__codelineno-18-65" name="__codelineno-18-65" href="#__codelineno-18-65"></a> <span class="p">)</span>
</span><span id="__span-18-66"><a id="__codelineno-18-66" name="__codelineno-18-66" href="#__codelineno-18-66"></a>
</span><span id="__span-18-67"><a id="__codelineno-18-67" name="__codelineno-18-67" href="#__codelineno-18-67"></a>
</span><span id="__span-18-68"><a id="__codelineno-18-68" name="__codelineno-18-68" href="#__codelineno-18-68"></a><span class="k">class</span> <span class="nc">Encoder</span><span class="p">(</span><span class="n">fl</span><span class="o">.</span><span class="n">Chain</span><span class="p">):</span>
</span><span id="__span-18-69"><a id="__codelineno-18-69" name="__codelineno-18-69" href="#__codelineno-18-69"></a> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-18-70"><a id="__codelineno-18-70" name="__codelineno-18-70" href="#__codelineno-18-70"></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
</span><span id="__span-18-71"><a id="__codelineno-18-71" name="__codelineno-18-71" href="#__codelineno-18-71"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
</span><span id="__span-18-72"><a id="__codelineno-18-72" name="__codelineno-18-72" href="#__codelineno-18-72"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Downsample</span><span class="p">(</span><span class="n">channels</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">scale_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">register_shape</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
</span><span id="__span-18-73"><a id="__codelineno-18-73" name="__codelineno-18-73" href="#__codelineno-18-73"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">16</span><span class="p">),</span>
</span><span id="__span-18-74"><a id="__codelineno-18-74" name="__codelineno-18-74" href="#__codelineno-18-74"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Downsample</span><span class="p">(</span><span class="n">channels</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">scale_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">register_shape</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
</span><span id="__span-18-75"><a id="__codelineno-18-75" name="__codelineno-18-75" href="#__codelineno-18-75"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">),</span>
</span><span id="__span-18-76"><a id="__codelineno-18-76" name="__codelineno-18-76" href="#__codelineno-18-76"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Downsample</span><span class="p">(</span><span class="n">channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">scale_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">register_shape</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
</span><span id="__span-18-77"><a id="__codelineno-18-77" name="__codelineno-18-77" href="#__codelineno-18-77"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Reshape</span><span class="p">(</span><span class="mi">2048</span><span class="p">),</span>
</span><span id="__span-18-78"><a id="__codelineno-18-78" name="__codelineno-18-78" href="#__codelineno-18-78"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">256</span><span class="p">),</span>
</span><span id="__span-18-79"><a id="__codelineno-18-79" name="__codelineno-18-79" href="#__codelineno-18-79"></a> <span class="n">fl</span><span class="o">.</span><span class="n">SiLU</span><span class="p">(),</span>
</span><span id="__span-18-80"><a id="__codelineno-18-80" name="__codelineno-18-80" href="#__codelineno-18-80"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">256</span><span class="p">),</span>
</span><span id="__span-18-81"><a id="__codelineno-18-81" name="__codelineno-18-81" href="#__codelineno-18-81"></a> <span class="p">)</span>
</span><span id="__span-18-82"><a id="__codelineno-18-82" name="__codelineno-18-82" href="#__codelineno-18-82"></a>
</span><span id="__span-18-83"><a id="__codelineno-18-83" name="__codelineno-18-83" href="#__codelineno-18-83"></a>
</span><span id="__span-18-84"><a id="__codelineno-18-84" name="__codelineno-18-84" href="#__codelineno-18-84"></a><span class="k">class</span> <span class="nc">Decoder</span><span class="p">(</span><span class="n">fl</span><span class="o">.</span><span class="n">Chain</span><span class="p">):</span>
</span><span id="__span-18-85"><a id="__codelineno-18-85" name="__codelineno-18-85" href="#__codelineno-18-85"></a> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-18-86"><a id="__codelineno-18-86" name="__codelineno-18-86" href="#__codelineno-18-86"></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
</span><span id="__span-18-87"><a id="__codelineno-18-87" name="__codelineno-18-87" href="#__codelineno-18-87"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">256</span><span class="p">),</span>
</span><span id="__span-18-88"><a id="__codelineno-18-88" name="__codelineno-18-88" href="#__codelineno-18-88"></a> <span class="n">fl</span><span class="o">.</span><span class="n">SiLU</span><span class="p">(),</span>
</span><span id="__span-18-89"><a id="__codelineno-18-89" name="__codelineno-18-89" href="#__codelineno-18-89"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">2048</span><span class="p">),</span>
</span><span id="__span-18-90"><a id="__codelineno-18-90" name="__codelineno-18-90" href="#__codelineno-18-90"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Reshape</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">8</span><span class="p">),</span>
</span><span id="__span-18-91"><a id="__codelineno-18-91" name="__codelineno-18-91" href="#__codelineno-18-91"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">),</span>
</span><span id="__span-18-92"><a id="__codelineno-18-92" name="__codelineno-18-92" href="#__codelineno-18-92"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">),</span>
</span><span id="__span-18-93"><a id="__codelineno-18-93" name="__codelineno-18-93" href="#__codelineno-18-93"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Upsample</span><span class="p">(</span><span class="n">channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">upsample_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
</span><span id="__span-18-94"><a id="__codelineno-18-94" name="__codelineno-18-94" href="#__codelineno-18-94"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">16</span><span class="p">),</span>
</span><span id="__span-18-95"><a id="__codelineno-18-95" name="__codelineno-18-95" href="#__codelineno-18-95"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">16</span><span class="p">),</span>
</span><span id="__span-18-96"><a id="__codelineno-18-96" name="__codelineno-18-96" href="#__codelineno-18-96"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Upsample</span><span class="p">(</span><span class="n">channels</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">upsample_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
</span><span id="__span-18-97"><a id="__codelineno-18-97" name="__codelineno-18-97" href="#__codelineno-18-97"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
</span><span id="__span-18-98"><a id="__codelineno-18-98" name="__codelineno-18-98" href="#__codelineno-18-98"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
</span><span id="__span-18-99"><a id="__codelineno-18-99" name="__codelineno-18-99" href="#__codelineno-18-99"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Upsample</span><span class="p">(</span><span class="n">channels</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">upsample_factor</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
</span><span id="__span-18-100"><a id="__codelineno-18-100" name="__codelineno-18-100" href="#__codelineno-18-100"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">),</span>
</span><span id="__span-18-101"><a id="__codelineno-18-101" name="__codelineno-18-101" href="#__codelineno-18-101"></a> <span class="n">ResidualBlock</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
</span><span id="__span-18-102"><a id="__codelineno-18-102" name="__codelineno-18-102" href="#__codelineno-18-102"></a> <span class="n">fl</span><span class="o">.</span><span class="n">Sigmoid</span><span class="p">(),</span>
</span><span id="__span-18-103"><a id="__codelineno-18-103" name="__codelineno-18-103" href="#__codelineno-18-103"></a> <span class="p">)</span>
</span><span id="__span-18-104"><a id="__codelineno-18-104" name="__codelineno-18-104" href="#__codelineno-18-104"></a>
</span><span id="__span-18-105"><a id="__codelineno-18-105" name="__codelineno-18-105" href="#__codelineno-18-105"></a>
</span><span id="__span-18-106"><a id="__codelineno-18-106" name="__codelineno-18-106" href="#__codelineno-18-106"></a><span class="k">class</span> <span class="nc">Autoencoder</span><span class="p">(</span><span class="n">fl</span><span class="o">.</span><span class="n">Chain</span><span class="p">):</span>
</span><span id="__span-18-107"><a id="__codelineno-18-107" name="__codelineno-18-107" href="#__codelineno-18-107"></a> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-18-108"><a id="__codelineno-18-108" name="__codelineno-18-108" href="#__codelineno-18-108"></a> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span>
</span><span id="__span-18-109"><a id="__codelineno-18-109" name="__codelineno-18-109" href="#__codelineno-18-109"></a> <span class="n">Encoder</span><span class="p">(),</span>
</span><span id="__span-18-110"><a id="__codelineno-18-110" name="__codelineno-18-110" href="#__codelineno-18-110"></a> <span class="n">Decoder</span><span class="p">(),</span>
</span><span id="__span-18-111"><a id="__codelineno-18-111" name="__codelineno-18-111" href="#__codelineno-18-111"></a> <span class="p">)</span>
</span><span id="__span-18-112"><a id="__codelineno-18-112" name="__codelineno-18-112" href="#__codelineno-18-112"></a>
</span><span id="__span-18-113"><a id="__codelineno-18-113" name="__codelineno-18-113" href="#__codelineno-18-113"></a> <span class="nd">@property</span>
</span><span id="__span-18-114"><a id="__codelineno-18-114" name="__codelineno-18-114" href="#__codelineno-18-114"></a> <span class="k">def</span> <span class="nf">encoder</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Encoder</span><span class="p">:</span>
</span><span id="__span-18-115"><a id="__codelineno-18-115" name="__codelineno-18-115" href="#__codelineno-18-115"></a> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">ensure_find</span><span class="p">(</span><span class="n">Encoder</span><span class="p">)</span>
</span><span id="__span-18-116"><a id="__codelineno-18-116" name="__codelineno-18-116" href="#__codelineno-18-116"></a>
</span><span id="__span-18-117"><a id="__codelineno-18-117" name="__codelineno-18-117" href="#__codelineno-18-117"></a> <span class="nd">@property</span>
</span><span id="__span-18-118"><a id="__codelineno-18-118" name="__codelineno-18-118" href="#__codelineno-18-118"></a> <span class="k">def</span> <span class="nf">decoder</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Decoder</span><span class="p">:</span>
</span><span id="__span-18-119"><a id="__codelineno-18-119" name="__codelineno-18-119" href="#__codelineno-18-119"></a> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">ensure_find</span><span class="p">(</span><span class="n">Decoder</span><span class="p">)</span>
</span><span id="__span-18-120"><a id="__codelineno-18-120" name="__codelineno-18-120" href="#__codelineno-18-120"></a>
</span><span id="__span-18-121"><a id="__codelineno-18-121" name="__codelineno-18-121" href="#__codelineno-18-121"></a>
</span><span id="__span-18-122"><a id="__codelineno-18-122" name="__codelineno-18-122" href="#__codelineno-18-122"></a><span class="k">def</span> <span class="nf">generate_mask</span><span class="p">(</span><span class="n">size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">seed</span><span class="p">:</span> <span class="nb">int</span> <span class="o">|</span> <span class="kc">None</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Generator</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">]:</span>
</span><span id="__span-18-123"><a id="__codelineno-18-123" name="__codelineno-18-123" href="#__codelineno-18-123"></a><span class="w"> </span><span class="sd">&quot;&quot;&quot;Generate a tensor of a binary mask of size `size` using random rectangles.&quot;&quot;&quot;</span>
</span><span id="__span-18-124"><a id="__codelineno-18-124" name="__codelineno-18-124" href="#__codelineno-18-124"></a> <span class="k">if</span> <span class="n">seed</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-18-125"><a id="__codelineno-18-125" name="__codelineno-18-125" href="#__codelineno-18-125"></a> <span class="n">seed</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="o">**</span><span class="mi">32</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span>
</span><span id="__span-18-126"><a id="__codelineno-18-126" name="__codelineno-18-126" href="#__codelineno-18-126"></a> <span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
</span><span id="__span-18-127"><a id="__codelineno-18-127" name="__codelineno-18-127" href="#__codelineno-18-127"></a>
</span><span id="__span-18-128"><a id="__codelineno-18-128" name="__codelineno-18-128" href="#__codelineno-18-128"></a> <span class="k">while</span> <span class="kc">True</span><span class="p">:</span>
</span><span id="__span-18-129"><a id="__codelineno-18-129" name="__codelineno-18-129" href="#__codelineno-18-129"></a> <span class="n">rectangle</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="s2">&quot;L&quot;</span><span class="p">,</span> <span class="p">(</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="p">),</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">size</span><span class="p">)),</span> <span class="n">color</span><span class="o">=</span><span class="mi">255</span><span class="p">)</span>
</span><span id="__span-18-130"><a id="__codelineno-18-130" name="__codelineno-18-130" href="#__codelineno-18-130"></a> <span class="n">mask</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">new</span><span class="p">(</span><span class="s2">&quot;L&quot;</span><span class="p">,</span> <span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">size</span><span class="p">))</span>
</span><span id="__span-18-131"><a id="__codelineno-18-131" name="__codelineno-18-131" href="#__codelineno-18-131"></a> <span class="n">mask</span><span class="o">.</span><span class="n">paste</span><span class="p">(</span>
</span><span id="__span-18-132"><a id="__codelineno-18-132" name="__codelineno-18-132" href="#__codelineno-18-132"></a> <span class="n">rectangle</span><span class="p">,</span>
</span><span id="__span-18-133"><a id="__codelineno-18-133" name="__codelineno-18-133" href="#__codelineno-18-133"></a> <span class="p">(</span>
</span><span id="__span-18-134"><a id="__codelineno-18-134" name="__codelineno-18-134" href="#__codelineno-18-134"></a> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">size</span> <span class="o">-</span> <span class="n">rectangle</span><span class="o">.</span><span class="n">width</span><span class="p">),</span>
</span><span id="__span-18-135"><a id="__codelineno-18-135" name="__codelineno-18-135" href="#__codelineno-18-135"></a> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">size</span> <span class="o">-</span> <span class="n">rectangle</span><span class="o">.</span><span class="n">height</span><span class="p">),</span>
</span><span id="__span-18-136"><a id="__codelineno-18-136" name="__codelineno-18-136" href="#__codelineno-18-136"></a> <span class="p">),</span>
</span><span id="__span-18-137"><a id="__codelineno-18-137" name="__codelineno-18-137" href="#__codelineno-18-137"></a> <span class="p">)</span>
</span><span id="__span-18-138"><a id="__codelineno-18-138" name="__codelineno-18-138" href="#__codelineno-18-138"></a> <span class="n">tensor</span> <span class="o">=</span> <span class="n">image_to_tensor</span><span class="p">(</span><span class="n">mask</span><span class="p">)</span>
</span><span id="__span-18-139"><a id="__codelineno-18-139" name="__codelineno-18-139" href="#__codelineno-18-139"></a>
</span><span id="__span-18-140"><a id="__codelineno-18-140" name="__codelineno-18-140" href="#__codelineno-18-140"></a> <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&gt;</span> <span class="mf">0.5</span><span class="p">:</span>
</span><span id="__span-18-141"><a id="__codelineno-18-141" name="__codelineno-18-141" href="#__codelineno-18-141"></a> <span class="n">tensor</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">tensor</span>
</span><span id="__span-18-142"><a id="__codelineno-18-142" name="__codelineno-18-142" href="#__codelineno-18-142"></a>
</span><span id="__span-18-143"><a id="__codelineno-18-143" name="__codelineno-18-143" href="#__codelineno-18-143"></a> <span class="k">yield</span> <span class="n">tensor</span>
</span><span id="__span-18-144"><a id="__codelineno-18-144" name="__codelineno-18-144" href="#__codelineno-18-144"></a>
</span><span id="__span-18-145"><a id="__codelineno-18-145" name="__codelineno-18-145" href="#__codelineno-18-145"></a>
</span><span id="__span-18-146"><a id="__codelineno-18-146" name="__codelineno-18-146" href="#__codelineno-18-146"></a><span class="nd">@dataclass</span>
</span><span id="__span-18-147"><a id="__codelineno-18-147" name="__codelineno-18-147" href="#__codelineno-18-147"></a><span class="k">class</span> <span class="nc">Batch</span><span class="p">:</span>
</span><span id="__span-18-148"><a id="__codelineno-18-148" name="__codelineno-18-148" href="#__codelineno-18-148"></a> <span class="n">image</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span>
</span><span id="__span-18-149"><a id="__codelineno-18-149" name="__codelineno-18-149" href="#__codelineno-18-149"></a>
</span><span id="__span-18-150"><a id="__codelineno-18-150" name="__codelineno-18-150" href="#__codelineno-18-150"></a>
</span><span id="__span-18-151"><a id="__codelineno-18-151" name="__codelineno-18-151" href="#__codelineno-18-151"></a><span class="k">class</span> <span class="nc">AutoencoderModelConfig</span><span class="p">(</span><span class="n">ModelConfig</span><span class="p">):</span>
</span><span id="__span-18-152"><a id="__codelineno-18-152" name="__codelineno-18-152" href="#__codelineno-18-152"></a> <span class="k">pass</span>
</span><span id="__span-18-153"><a id="__codelineno-18-153" name="__codelineno-18-153" href="#__codelineno-18-153"></a>
</span><span id="__span-18-154"><a id="__codelineno-18-154" name="__codelineno-18-154" href="#__codelineno-18-154"></a>
</span><span id="__span-18-155"><a id="__codelineno-18-155" name="__codelineno-18-155" href="#__codelineno-18-155"></a><span class="k">class</span> <span class="nc">LoggingCallback</span><span class="p">(</span><span class="n">Callback</span><span class="p">[</span><span class="n">Trainer</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]):</span>
</span><span id="__span-18-156"><a id="__codelineno-18-156" name="__codelineno-18-156" href="#__codelineno-18-156"></a> <span class="n">losses</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
</span><span id="__span-18-157"><a id="__codelineno-18-157" name="__codelineno-18-157" href="#__codelineno-18-157"></a>
</span><span id="__span-18-158"><a id="__codelineno-18-158" name="__codelineno-18-158" href="#__codelineno-18-158"></a> <span class="k">def</span> <span class="nf">on_compute_loss_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="n">Trainer</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-18-159"><a id="__codelineno-18-159" name="__codelineno-18-159" href="#__codelineno-18-159"></a> <span class="bp">self</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">trainer</span><span class="o">.</span><span class="n">loss</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
</span><span id="__span-18-160"><a id="__codelineno-18-160" name="__codelineno-18-160" href="#__codelineno-18-160"></a>
</span><span id="__span-18-161"><a id="__codelineno-18-161" name="__codelineno-18-161" href="#__codelineno-18-161"></a> <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="n">Trainer</span><span class="p">[</span><span class="n">Any</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span><!-- on_epoch_end: averages the per-step losses accumulated in self.losses, logs the mean together with trainer.clock.epoch, then resets the list for the next epoch. NOTE(review): raises ZeroDivisionError if an epoch ends with no recorded step (empty self.losses). -->
</span><span id="__span-18-162"><a id="__codelineno-18-162" name="__codelineno-18-162" href="#__codelineno-18-162"></a> <span class="n">mean_loss</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">losses</span><span class="p">)</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">losses</span><span class="p">)</span>
</span><span id="__span-18-163"><a id="__codelineno-18-163" name="__codelineno-18-163" href="#__codelineno-18-163"></a> <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Mean loss: </span><span class="si">{</span><span class="n">mean_loss</span><span class="si">}</span><span class="s2">, epoch: </span><span class="si">{</span><span class="n">trainer</span><span class="o">.</span><span class="n">clock</span><span class="o">.</span><span class="n">epoch</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</span><span id="__span-18-164"><a id="__codelineno-18-164" name="__codelineno-18-164" href="#__codelineno-18-164"></a> <span class="bp">self</span><span class="o">.</span><span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
</span><span id="__span-18-165"><a id="__codelineno-18-165" name="__codelineno-18-165" href="#__codelineno-18-165"></a>
</span><span id="__span-18-166"><a id="__codelineno-18-166" name="__codelineno-18-166" href="#__codelineno-18-166"></a>
</span><span id="__span-18-167"><a id="__codelineno-18-167" name="__codelineno-18-167" href="#__codelineno-18-167"></a><span class="k">class</span> <span class="nc">EvaluationConfig</span><span class="p">(</span><span class="n">CallbackConfig</span><span class="p">):</span><!-- EvaluationConfig: interval is passed to trainer.clock.is_due by EvaluationCallback to decide when to evaluate; seed is passed to scoped_seed around trainer.compute_evaluation(). -->
</span><span id="__span-18-168"><a id="__codelineno-18-168" name="__codelineno-18-168" href="#__codelineno-18-168"></a> <span class="n">interval</span><span class="p">:</span> <span class="n">TimeValueField</span>
</span><span id="__span-18-169"><a id="__codelineno-18-169" name="__codelineno-18-169" href="#__codelineno-18-169"></a> <span class="n">seed</span><span class="p">:</span> <span class="nb">int</span>
</span><span id="__span-18-170"><a id="__codelineno-18-170" name="__codelineno-18-170" href="#__codelineno-18-170"></a>
</span><span id="__span-18-171"><a id="__codelineno-18-171" name="__codelineno-18-171" href="#__codelineno-18-171"></a>
</span><span id="__span-18-172"><a id="__codelineno-18-172" name="__codelineno-18-172" href="#__codelineno-18-172"></a><span class="k">class</span> <span class="nc">EvaluationCallback</span><span class="p">(</span><span class="n">Callback</span><span class="p">[</span><span class="s2">&quot;AutoencoderTrainer&quot;</span><span class="p">]):</span><!-- EvaluationCallback: stores its EvaluationConfig; at each epoch end it returns early unless trainer.clock.is_due(config.interval), otherwise it runs trainer.compute_evaluation() inside scoped_seed(config.seed). -->
</span><span id="__span-18-173"><a id="__codelineno-18-173" name="__codelineno-18-173" href="#__codelineno-18-173"></a> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">EvaluationConfig</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-18-174"><a id="__codelineno-18-174" name="__codelineno-18-174" href="#__codelineno-18-174"></a> <span class="bp">self</span><span class="o">.</span><span class="n">config</span> <span class="o">=</span> <span class="n">config</span>
</span><span id="__span-18-175"><a id="__codelineno-18-175" name="__codelineno-18-175" href="#__codelineno-18-175"></a>
</span><span id="__span-18-176"><a id="__codelineno-18-176" name="__codelineno-18-176" href="#__codelineno-18-176"></a> <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">trainer</span><span class="p">:</span> <span class="s2">&quot;AutoencoderTrainer&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
</span><span id="__span-18-177"><a id="__codelineno-18-177" name="__codelineno-18-177" href="#__codelineno-18-177"></a> <span class="c1"># The `is_due` method checks if the current epoch is a multiple of the interval.</span>
</span><span id="__span-18-178"><a id="__codelineno-18-178" name="__codelineno-18-178" href="#__codelineno-18-178"></a> <span class="k">if</span> <span class="ow">not</span> <span class="n">trainer</span><span class="o">.</span><span class="n">clock</span><span class="o">.</span><span class="n">is_due</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">interval</span><span class="p">):</span>
</span><span id="__span-18-179"><a id="__codelineno-18-179" name="__codelineno-18-179" href="#__codelineno-18-179"></a> <span class="k">return</span>
</span><span id="__span-18-180"><a id="__codelineno-18-180" name="__codelineno-18-180" href="#__codelineno-18-180"></a>
</span><span id="__span-18-181"><a id="__codelineno-18-181" name="__codelineno-18-181" href="#__codelineno-18-181"></a> <span class="c1"># The `scoped_seed` context manager encapsulates the random state for the evaluation and restores it after the</span>
</span><span id="__span-18-182"><a id="__codelineno-18-182" name="__codelineno-18-182" href="#__codelineno-18-182"></a> <span class="c1"># evaluation.</span>
</span><span id="__span-18-183"><a id="__codelineno-18-183" name="__codelineno-18-183" href="#__codelineno-18-183"></a> <span class="k">with</span> <span class="n">scoped_seed</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">seed</span><span class="p">):</span>
</span><span id="__span-18-184"><a id="__codelineno-18-184" name="__codelineno-18-184" href="#__codelineno-18-184"></a> <span class="n">trainer</span><span class="o">.</span><span class="n">compute_evaluation</span><span class="p">()</span>
</span><span id="__span-18-185"><a id="__codelineno-18-185" name="__codelineno-18-185" href="#__codelineno-18-185"></a>
</span><span id="__span-18-186"><a id="__codelineno-18-186" name="__codelineno-18-186" href="#__codelineno-18-186"></a>
</span><span id="__span-18-187"><a id="__codelineno-18-187" name="__codelineno-18-187" href="#__codelineno-18-187"></a><span class="k">class</span> <span class="nc">AutoencoderConfig</span><span class="p">(</span><span class="n">BaseConfig</span><span class="p">):</span><!-- AutoencoderConfig: num_images and batch_size drive create_data_iterable (num_images // batch_size batches). The autoencoder, evaluation and logging fields share their names with the register_model / register_callback methods of AutoencoderTrainer; presumably they are matched by name; confirm against the refiners training_utils docs. -->
</span><span id="__span-18-188"><a id="__codelineno-18-188" name="__codelineno-18-188" href="#__codelineno-18-188"></a> <span class="n">num_images</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2048</span>
</span><span id="__span-18-189"><a id="__codelineno-18-189" name="__codelineno-18-189" href="#__codelineno-18-189"></a> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span>
</span><span id="__span-18-190"><a id="__codelineno-18-190" name="__codelineno-18-190" href="#__codelineno-18-190"></a> <span class="n">autoencoder</span><span class="p">:</span> <span class="n">AutoencoderModelConfig</span>
</span><span id="__span-18-191"><a id="__codelineno-18-191" name="__codelineno-18-191" href="#__codelineno-18-191"></a> <span class="n">evaluation</span><span class="p">:</span> <span class="n">EvaluationConfig</span>
</span><span id="__span-18-192"><a id="__codelineno-18-192" name="__codelineno-18-192" href="#__codelineno-18-192"></a> <span class="n">logging</span><span class="p">:</span> <span class="n">CallbackConfig</span> <span class="o">=</span> <span class="n">CallbackConfig</span><span class="p">()</span>
</span><span id="__span-18-193"><a id="__codelineno-18-193" name="__codelineno-18-193" href="#__codelineno-18-193"></a>
</span><span id="__span-18-194"><a id="__codelineno-18-194" name="__codelineno-18-194" href="#__codelineno-18-194"></a>
</span><span id="__span-18-195"><a id="__codelineno-18-195" name="__codelineno-18-195" href="#__codelineno-18-195"></a><span class="n">autoencoder_config</span> <span class="o">=</span> <span class="n">AutoencoderModelConfig</span><span class="p">(</span><!-- Module-level configuration: trainable autoencoder, 200 training epochs in float32 on cuda when available (else cpu), AdamW with lr 1e-4 and a constant LR schedule, evaluation every 50 epochs with seed 0, and the default clock logging disabled. -->
</span><span id="__span-18-196"><a id="__codelineno-18-196" name="__codelineno-18-196" href="#__codelineno-18-196"></a> <span class="n">requires_grad</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="c1"># set during registration to set the requires_grad attribute of the model.</span>
</span><span id="__span-18-197"><a id="__codelineno-18-197" name="__codelineno-18-197" href="#__codelineno-18-197"></a><span class="p">)</span>
</span><span id="__span-18-198"><a id="__codelineno-18-198" name="__codelineno-18-198" href="#__codelineno-18-198"></a>
</span><span id="__span-18-199"><a id="__codelineno-18-199" name="__codelineno-18-199" href="#__codelineno-18-199"></a><span class="n">training</span> <span class="o">=</span> <span class="n">TrainingConfig</span><span class="p">(</span>
</span><span id="__span-18-200"><a id="__codelineno-18-200" name="__codelineno-18-200" href="#__codelineno-18-200"></a> <span class="n">duration</span><span class="o">=</span><span class="n">Epoch</span><span class="p">(</span><span class="mi">200</span><span class="p">),</span>
</span><span id="__span-18-201"><a id="__codelineno-18-201" name="__codelineno-18-201" href="#__codelineno-18-201"></a> <span class="n">device</span><span class="o">=</span><span class="s2">&quot;cuda&quot;</span> <span class="k">if</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="s2">&quot;cpu&quot;</span><span class="p">,</span>
</span><span id="__span-18-202"><a id="__codelineno-18-202" name="__codelineno-18-202" href="#__codelineno-18-202"></a> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;float32&quot;</span><span class="p">,</span>
</span><span id="__span-18-203"><a id="__codelineno-18-203" name="__codelineno-18-203" href="#__codelineno-18-203"></a><span class="p">)</span>
</span><span id="__span-18-204"><a id="__codelineno-18-204" name="__codelineno-18-204" href="#__codelineno-18-204"></a>
</span><span id="__span-18-205"><a id="__codelineno-18-205" name="__codelineno-18-205" href="#__codelineno-18-205"></a><span class="n">optimizer</span> <span class="o">=</span> <span class="n">OptimizerConfig</span><span class="p">(</span>
</span><span id="__span-18-206"><a id="__codelineno-18-206" name="__codelineno-18-206" href="#__codelineno-18-206"></a> <span class="n">optimizer</span><span class="o">=</span><span class="n">Optimizers</span><span class="o">.</span><span class="n">AdamW</span><span class="p">,</span>
</span><span id="__span-18-207"><a id="__codelineno-18-207" name="__codelineno-18-207" href="#__codelineno-18-207"></a> <span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">,</span>
</span><span id="__span-18-208"><a id="__codelineno-18-208" name="__codelineno-18-208" href="#__codelineno-18-208"></a><span class="p">)</span>
</span><span id="__span-18-209"><a id="__codelineno-18-209" name="__codelineno-18-209" href="#__codelineno-18-209"></a>
</span><span id="__span-18-210"><a id="__codelineno-18-210" name="__codelineno-18-210" href="#__codelineno-18-210"></a><span class="n">lr_scheduler</span> <span class="o">=</span> <span class="n">LRSchedulerConfig</span><span class="p">(</span><span class="nb">type</span><span class="o">=</span><span class="n">LRSchedulerType</span><span class="o">.</span><span class="n">CONSTANT_LR</span><span class="p">)</span>
</span><span id="__span-18-211"><a id="__codelineno-18-211" name="__codelineno-18-211" href="#__codelineno-18-211"></a>
</span><span id="__span-18-212"><a id="__codelineno-18-212" name="__codelineno-18-212" href="#__codelineno-18-212"></a><span class="n">config</span> <span class="o">=</span> <span class="n">AutoencoderConfig</span><span class="p">(</span>
</span><span id="__span-18-213"><a id="__codelineno-18-213" name="__codelineno-18-213" href="#__codelineno-18-213"></a> <span class="n">training</span><span class="o">=</span><span class="n">training</span><span class="p">,</span>
</span><span id="__span-18-214"><a id="__codelineno-18-214" name="__codelineno-18-214" href="#__codelineno-18-214"></a> <span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span>
</span><span id="__span-18-215"><a id="__codelineno-18-215" name="__codelineno-18-215" href="#__codelineno-18-215"></a> <span class="n">lr_scheduler</span><span class="o">=</span><span class="n">lr_scheduler</span><span class="p">,</span>
</span><span id="__span-18-216"><a id="__codelineno-18-216" name="__codelineno-18-216" href="#__codelineno-18-216"></a> <span class="n">autoencoder</span><span class="o">=</span><span class="n">autoencoder_config</span><span class="p">,</span>
</span><span id="__span-18-217"><a id="__codelineno-18-217" name="__codelineno-18-217" href="#__codelineno-18-217"></a> <span class="n">evaluation</span><span class="o">=</span><span class="n">EvaluationConfig</span><span class="p">(</span><span class="n">interval</span><span class="o">=</span><span class="n">Epoch</span><span class="p">(</span><span class="mi">50</span><span class="p">),</span> <span class="n">seed</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span>
</span><span id="__span-18-218"><a id="__codelineno-18-218" name="__codelineno-18-218" href="#__codelineno-18-218"></a> <span class="n">clock</span><span class="o">=</span><span class="n">ClockConfig</span><span class="p">(</span><span class="n">verbose</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span> <span class="c1"># to disable the default clock logging</span>
</span><span id="__span-18-219"><a id="__codelineno-18-219" name="__codelineno-18-219" href="#__codelineno-18-219"></a><span class="p">)</span>
</span><span id="__span-18-220"><a id="__codelineno-18-220" name="__codelineno-18-220" href="#__codelineno-18-220"></a>
</span><span id="__span-18-221"><a id="__codelineno-18-221" name="__codelineno-18-221" href="#__codelineno-18-221"></a>
</span><span id="__span-18-222"><a id="__codelineno-18-222" name="__codelineno-18-222" href="#__codelineno-18-222"></a><span class="k">class</span> <span class="nc">AutoencoderTrainer</span><span class="p">(</span><span class="n">Trainer</span><span class="p">[</span><span class="n">AutoencoderConfig</span><span class="p">,</span> <span class="n">Batch</span><span class="p">]):</span>
</span><span id="__span-18-223"><a id="__codelineno-18-223" name="__codelineno-18-223" href="#__codelineno-18-223"></a> <span class="k">def</span> <span class="nf">create_data_iterable</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">list</span><span class="p">[</span><span class="n">Batch</span><span class="p">]:</span><!-- create_data_iterable: eagerly builds num_images // batch_size Batch objects. Each batch concatenates batch_size masks drawn from generate_mask(size=64) along dim 0, already moved to self.device / self.dtype. Images beyond the last full batch are dropped by the floor division. -->
</span><span id="__span-18-224"><a id="__codelineno-18-224" name="__codelineno-18-224" href="#__codelineno-18-224"></a> <span class="n">dataset</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="n">Batch</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
</span><span id="__span-18-225"><a id="__codelineno-18-225" name="__codelineno-18-225" href="#__codelineno-18-225"></a> <span class="n">generator</span> <span class="o">=</span> <span class="n">generate_mask</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span>
</span><span id="__span-18-226"><a id="__codelineno-18-226" name="__codelineno-18-226" href="#__codelineno-18-226"></a>
</span><span id="__span-18-227"><a id="__codelineno-18-227" name="__codelineno-18-227" href="#__codelineno-18-227"></a> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">num_images</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span>
</span><span id="__span-18-228"><a id="__codelineno-18-228" name="__codelineno-18-228" href="#__codelineno-18-228"></a> <span class="n">masks</span> <span class="o">=</span> <span class="p">[</span><span class="nb">next</span><span class="p">(</span><span class="n">generator</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">batch_size</span><span class="p">)]</span>
</span><span id="__span-18-229"><a id="__codelineno-18-229" name="__codelineno-18-229" href="#__codelineno-18-229"></a> <span class="n">dataset</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">Batch</span><span class="p">(</span><span class="n">image</span><span class="o">=</span><span class="n">torch</span><span class="o">.</span><span class="n">cat</span><span class="p">(</span><span class="n">masks</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)))</span>
</span><span id="__span-18-230"><a id="__codelineno-18-230" name="__codelineno-18-230" href="#__codelineno-18-230"></a>
</span><span id="__span-18-231"><a id="__codelineno-18-231" name="__codelineno-18-231" href="#__codelineno-18-231"></a> <span class="k">return</span> <span class="n">dataset</span>
</span><span id="__span-18-232"><a id="__codelineno-18-232" name="__codelineno-18-232" href="#__codelineno-18-232"></a>
</span><span id="__span-18-233"><a id="__codelineno-18-233" name="__codelineno-18-233" href="#__codelineno-18-233"></a> <span class="nd">@register_model</span><span class="p">()</span>
</span><span id="__span-18-234"><a id="__codelineno-18-234" name="__codelineno-18-234" href="#__codelineno-18-234"></a> <span class="k">def</span> <span class="nf">autoencoder</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">AutoencoderModelConfig</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Autoencoder</span><span class="p">:</span><!-- autoencoder: registered with @register_model(); returns a freshly constructed Autoencoder() and does not read config in its body (requires_grad=True is set on autoencoder_config, which the listing says is applied during registration). -->
</span><span id="__span-18-235"><a id="__codelineno-18-235" name="__codelineno-18-235" href="#__codelineno-18-235"></a> <span class="k">return</span> <span class="n">Autoencoder</span><span class="p">()</span>
</span><span id="__span-18-236"><a id="__codelineno-18-236" name="__codelineno-18-236" href="#__codelineno-18-236"></a>
</span><span id="__span-18-237"><a id="__codelineno-18-237" name="__codelineno-18-237" href="#__codelineno-18-237"></a> <span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">:</span> <span class="n">Batch</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">:</span><!-- compute_loss: moves batch.image to self.device / self.dtype, runs encoder then decoder, and returns F.binary_cross_entropy(reconstruction, input) as the training loss. NOTE(review): binary_cross_entropy requires the decoder output to lie in [0, 1] (e.g. a final sigmoid); confirm against the Autoencoder definition. Validation in compute_evaluation reports F.mse_loss instead, so train and validation losses are on different scales. -->
</span><span id="__span-18-238"><a id="__codelineno-18-238" name="__codelineno-18-238" href="#__codelineno-18-238"></a> <span class="n">batch</span><span class="o">.</span><span class="n">image</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span><span id="__span-18-239"><a id="__codelineno-18-239" name="__codelineno-18-239" href="#__codelineno-18-239"></a> <span class="n">x_reconstructed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">autoencoder</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">autoencoder</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">batch</span><span class="o">.</span><span class="n">image</span><span class="p">))</span>
</span><span id="__span-18-240"><a id="__codelineno-18-240" name="__codelineno-18-240" href="#__codelineno-18-240"></a> <span class="k">return</span> <span class="n">F</span><span class="o">.</span><span class="n">binary_cross_entropy</span><span class="p">(</span><span class="n">x_reconstructed</span><span class="p">,</span> <span class="n">batch</span><span class="o">.</span><span class="n">image</span><span class="p">)</span>
</span><span id="__span-18-241"><a id="__codelineno-18-241" name="__codelineno-18-241" href="#__codelineno-18-241"></a>
</span><span id="__span-18-242"><a id="__codelineno-18-242" name="__codelineno-18-242" href="#__codelineno-18-242"></a> <span class="k">def</span> <span class="nf">compute_evaluation</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span><!-- compute_evaluation: reconstructs 4 masks from generate_mask(size=64, seed=0), logs their mean F.mse_loss with self.clock.epoch, and saves a 4x2 grid of mask / reconstruction (thresholded at 0.5) to result_{epoch}.png. Fixed: the savefig filename previously read trainer.clock.epoch, but trainer is neither a parameter nor a local of this method, so it raised NameError or silently relied on a module-level global; it now uses self.clock.epoch like the log line. NOTE(review): no torch.no_grad() is visible here, so this forward pass builds an autograd graph for the trainable model. -->
</span><span id="__span-18-243"><a id="__codelineno-18-243" name="__codelineno-18-243" href="#__codelineno-18-243"></a> <span class="n">generator</span> <span class="o">=</span> <span class="n">generate_mask</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</span><span id="__span-18-244"><a id="__codelineno-18-244" name="__codelineno-18-244" href="#__codelineno-18-244"></a>
</span><span id="__span-18-245"><a id="__codelineno-18-245" name="__codelineno-18-245" href="#__codelineno-18-245"></a> <span class="n">grid</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">tuple</span><span class="p">[</span><span class="n">Image</span><span class="o">.</span><span class="n">Image</span><span class="p">,</span> <span class="n">Image</span><span class="o">.</span><span class="n">Image</span><span class="p">]]</span> <span class="o">=</span> <span class="p">[]</span>
</span><span id="__span-18-246"><a id="__codelineno-18-246" name="__codelineno-18-246" href="#__codelineno-18-246"></a> <span class="n">validation_losses</span><span class="p">:</span> <span class="nb">list</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
</span><span id="__span-18-247"><a id="__codelineno-18-247" name="__codelineno-18-247" href="#__codelineno-18-247"></a> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">4</span><span class="p">):</span>
</span><span id="__span-18-248"><a id="__codelineno-18-248" name="__codelineno-18-248" href="#__codelineno-18-248"></a> <span class="n">mask</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">generator</span><span class="p">)</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">device</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span>
</span><span id="__span-18-249"><a id="__codelineno-18-249" name="__codelineno-18-249" href="#__codelineno-18-249"></a> <span class="n">x_reconstructed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">autoencoder</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">autoencoder</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">mask</span><span class="p">))</span>
</span><span id="__span-18-250"><a id="__codelineno-18-250" name="__codelineno-18-250" href="#__codelineno-18-250"></a> <span class="n">loss</span> <span class="o">=</span> <span class="n">F</span><span class="o">.</span><span class="n">mse_loss</span><span class="p">(</span><span class="n">x_reconstructed</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>
</span><span id="__span-18-251"><a id="__codelineno-18-251" name="__codelineno-18-251" href="#__codelineno-18-251"></a> <span class="n">validation_losses</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>
</span><span id="__span-18-252"><a id="__codelineno-18-252" name="__codelineno-18-252" href="#__codelineno-18-252"></a> <span class="n">grid</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">tensor_to_image</span><span class="p">(</span><span class="n">mask</span><span class="p">),</span> <span class="n">tensor_to_image</span><span class="p">((</span><span class="n">x_reconstructed</span> <span class="o">&gt;</span> <span class="mf">0.5</span><span class="p">)</span><span class="o">.</span><span class="n">float</span><span class="p">())))</span>
</span><span id="__span-18-253"><a id="__codelineno-18-253" name="__codelineno-18-253" href="#__codelineno-18-253"></a>
</span><span id="__span-18-254"><a id="__codelineno-18-254" name="__codelineno-18-254" href="#__codelineno-18-254"></a> <span class="n">mean_loss</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">validation_losses</span><span class="p">)</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">validation_losses</span><span class="p">)</span>
</span><span id="__span-18-255"><a id="__codelineno-18-255" name="__codelineno-18-255" href="#__codelineno-18-255"></a> <span class="n">logger</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Mean validation loss: </span><span class="si">{</span><span class="n">mean_loss</span><span class="si">}</span><span class="s2">, epoch: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">clock</span><span class="o">.</span><span class="n">epoch</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
</span><span id="__span-18-256"><a id="__codelineno-18-256" name="__codelineno-18-256" href="#__codelineno-18-256"></a>
</span><span id="__span-18-257"><a id="__codelineno-18-257" name="__codelineno-18-257" href="#__codelineno-18-257"></a> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
</span><span id="__span-18-258"><a id="__codelineno-18-258" name="__codelineno-18-258" href="#__codelineno-18-258"></a>
</span><span id="__span-18-259"><a id="__codelineno-18-259" name="__codelineno-18-259" href="#__codelineno-18-259"></a> <span class="n">_</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">16</span><span class="p">))</span> <span class="c1"># type: ignore</span>
</span><span id="__span-18-260"><a id="__codelineno-18-260" name="__codelineno-18-260" href="#__codelineno-18-260"></a>
</span><span id="__span-18-261"><a id="__codelineno-18-261" name="__codelineno-18-261" href="#__codelineno-18-261"></a> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">reconstructed</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">grid</span><span class="p">):</span>
</span><span id="__span-18-262"><a id="__codelineno-18-262" name="__codelineno-18-262" href="#__codelineno-18-262"></a> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s2">&quot;gray&quot;</span><span class="p">)</span>
</span><span id="__span-18-263"><a id="__codelineno-18-263" name="__codelineno-18-263" href="#__codelineno-18-263"></a> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span>
</span><span id="__span-18-264"><a id="__codelineno-18-264" name="__codelineno-18-264" href="#__codelineno-18-264"></a> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Mask&quot;</span><span class="p">)</span>
</span><span id="__span-18-265"><a id="__codelineno-18-265" name="__codelineno-18-265" href="#__codelineno-18-265"></a>
</span><span id="__span-18-266"><a id="__codelineno-18-266" name="__codelineno-18-266" href="#__codelineno-18-266"></a> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">reconstructed</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s2">&quot;gray&quot;</span><span class="p">)</span>
</span><span id="__span-18-267"><a id="__codelineno-18-267" name="__codelineno-18-267" href="#__codelineno-18-267"></a> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span>
</span><span id="__span-18-268"><a id="__codelineno-18-268" name="__codelineno-18-268" href="#__codelineno-18-268"></a> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Reconstructed&quot;</span><span class="p">)</span>
</span><span id="__span-18-269"><a id="__codelineno-18-269" name="__codelineno-18-269" href="#__codelineno-18-269"></a>
</span><span id="__span-18-270"><a id="__codelineno-18-270" name="__codelineno-18-270" href="#__codelineno-18-270"></a> <span class="n">plt</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">()</span> <span class="c1"># type: ignore</span>
</span><span id="__span-18-271"><a id="__codelineno-18-271" name="__codelineno-18-271" href="#__codelineno-18-271"></a> <span class="n">plt</span><span class="o">.</span><span class="n">savefig</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;result_</span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">clock</span><span class="o">.</span><span class="n">epoch</span><span class="si">}</span><span class="s2">.png&quot;</span><span class="p">)</span> <span class="c1"># type: ignore</span>
</span><span id="__span-18-272"><a id="__codelineno-18-272" name="__codelineno-18-272" href="#__codelineno-18-272"></a> <span class="n">plt</span><span class="o">.</span><span class="n">close</span><span class="p">()</span> <span class="c1"># type: ignore</span>
</span><span id="__span-18-273"><a id="__codelineno-18-273" name="__codelineno-18-273" href="#__codelineno-18-273"></a>
</span><span id="__span-18-274"><a id="__codelineno-18-274" name="__codelineno-18-274" href="#__codelineno-18-274"></a> <span class="nd">@register_callback</span><span class="p">()</span>
</span><span id="__span-18-275"><a id="__codelineno-18-275" name="__codelineno-18-275" href="#__codelineno-18-275"></a> <span class="k">def</span> <span class="nf">evaluation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">EvaluationConfig</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">EvaluationCallback</span><span class="p">:</span>
</span><span id="__span-18-276"><a id="__codelineno-18-276" name="__codelineno-18-276" href="#__codelineno-18-276"></a> <span class="k">return</span> <span class="n">EvaluationCallback</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>
</span><span id="__span-18-277"><a id="__codelineno-18-277" name="__codelineno-18-277" href="#__codelineno-18-277"></a>
</span><span id="__span-18-278"><a id="__codelineno-18-278" name="__codelineno-18-278" href="#__codelineno-18-278"></a> <span class="nd">@register_callback</span><span class="p">()</span>
</span><span id="__span-18-279"><a id="__codelineno-18-279" name="__codelineno-18-279" href="#__codelineno-18-279"></a> <span class="k">def</span> <span class="nf">logging</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">:</span> <span class="n">CallbackConfig</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">LoggingCallback</span><span class="p">:</span>
</span><span id="__span-18-280"><a id="__codelineno-18-280" name="__codelineno-18-280" href="#__codelineno-18-280"></a> <span class="k">return</span> <span class="n">LoggingCallback</span><span class="p">()</span>
</span><span id="__span-18-281"><a id="__codelineno-18-281" name="__codelineno-18-281" href="#__codelineno-18-281"></a>
</span><span id="__span-18-282"><a id="__codelineno-18-282" name="__codelineno-18-282" href="#__codelineno-18-282"></a>
</span><span id="__span-18-283"><a id="__codelineno-18-283" name="__codelineno-18-283" href="#__codelineno-18-283"></a><span class="n">trainer</span> <span class="o">=</span> <span class="n">AutoencoderTrainer</span><span class="p">(</span><span class="n">config</span><span class="p">)</span>
</span><span id="__span-18-284"><a id="__codelineno-18-284" name="__codelineno-18-284" href="#__codelineno-18-284"></a>
</span><span id="__span-18-285"><a id="__codelineno-18-285" name="__codelineno-18-285" href="#__codelineno-18-285"></a><span class="n">trainer</span><span class="o">.</span><span class="n">train</span><span class="p">()</span>
</span></code></pre></div>
</details>
</article>
</div>
<script>
  /* Sync tab state with the URL fragment.
     Look up the element whose id equals the fragment (without "#"). If it exists
     and carries a non-empty `name`, set its `checked` property: true when the name
     begins with "__tabbed_", false otherwise. Elements without a name are untouched.
     `target` intentionally stays a page-global `var`. */
  var target = document.getElementById(location.hash.slice(1));
  if (target && target.name) {
    target.checked = target.name.startsWith("__tabbed_");
  }
</script>
</div>
<button type="button" class="md-top md-icon" data-md-component="top" hidden>
<!-- Decorative arrow icon: the visible "Back to top" text already provides the
     button's accessible name, so the SVG is hidden from assistive technology. -->
<svg aria-hidden="true" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M13 20h-2V8l-5.5 5.5-1.42-1.42L12 4.16l7.92 7.92-1.42 1.42L13 8z"/></svg>
Back to top
</button>
</main>
<footer class="md-footer">
<div class="md-footer-meta md-typeset">
<div class="md-footer-meta__inner md-grid">
<!-- Copyright notice and theme attribution. -->
<div class="md-copyright">
<div class="md-copyright__highlight">
© Lagon Technologies
</div>
Made with
<a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener">
Material for MkDocs
</a>
</div>
<!-- Icon-only social links. The SVGs carry no text and the bare-domain `title`
     is not a meaningful accessible name, so each link gets a descriptive
     aria-label (mentioning the new tab opened by target="_blank") and its icon
     is hidden from assistive technology. The `title` tooltip is kept for sighted users. -->
<div class="md-social">
<a href="https://discord.gg/mCmjNUVV7d" target="_blank" rel="noopener" title="discord.gg" class="md-social__link" aria-label="Join us on Discord (opens in a new tab)">
<svg aria-hidden="true" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 640 512"><!--! Font Awesome Free 6.6.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M524.531 69.836a1.5 1.5 0 0 0-.764-.7A485 485 0 0 0 404.081 32.03a1.82 1.82 0 0 0-1.923.91 338 338 0 0 0-14.9 30.6 447.9 447.9 0 0 0-134.426 0 310 310 0 0 0-15.135-30.6 1.89 1.89 0 0 0-1.924-.91 483.7 483.7 0 0 0-119.688 37.107 1.7 1.7 0 0 0-.788.676C39.068 183.651 18.186 294.69 28.43 404.354a2.02 2.02 0 0 0 .765 1.375 487.7 487.7 0 0 0 146.825 74.189 1.9 1.9 0 0 0 2.063-.676A348 348 0 0 0 208.12 430.4a1.86 1.86 0 0 0-1.019-2.588 321 321 0 0 1-45.868-21.853 1.885 1.885 0 0 1-.185-3.126 251 251 0 0 0 9.109-7.137 1.82 1.82 0 0 1 1.9-.256c96.229 43.917 200.41 43.917 295.5 0a1.81 1.81 0 0 1 1.924.233 235 235 0 0 0 9.132 7.16 1.884 1.884 0 0 1-.162 3.126 301.4 301.4 0 0 1-45.89 21.83 1.875 1.875 0 0 0-1 2.611 391 391 0 0 0 30.014 48.815 1.86 1.86 0 0 0 2.063.7A486 486 0 0 0 610.7 405.729a1.88 1.88 0 0 0 .765-1.352c12.264-126.783-20.532-236.912-86.934-334.541M222.491 337.58c-28.972 0-52.844-26.587-52.844-59.239s23.409-59.241 52.844-59.241c29.665 0 53.306 26.82 52.843 59.239 0 32.654-23.41 59.241-52.843 59.241m195.38 0c-28.971 0-52.843-26.587-52.843-59.239s23.409-59.241 52.843-59.241c29.667 0 53.307 26.82 52.844 59.239 0 32.654-23.177 59.241-52.844 59.241"/></svg>
</a>
<a href="https://github.com/finegrain-ai/refiners" target="_blank" rel="noopener" title="github.com" class="md-social__link" aria-label="Refiners on GitHub (opens in a new tab)">
<svg aria-hidden="true" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 496 512"><!--! Font Awesome Free 6.6.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M165.9 397.4c0 2-2.3 3.6-5.2 3.6-3.3.3-5.6-1.3-5.6-3.6 0-2 2.3-3.6 5.2-3.6 3-.3 5.6 1.3 5.6 3.6m-31.1-4.5c-.7 2 1.3 4.3 4.3 4.9 2.6 1 5.6 0 6.2-2s-1.3-4.3-4.3-5.2c-2.6-.7-5.5.3-6.2 2.3m44.2-1.7c-2.9.7-4.9 2.6-4.6 4.9.3 2 2.9 3.3 5.9 2.6 2.9-.7 4.9-2.6 4.6-4.6-.3-1.9-3-3.2-5.9-2.9M244.8 8C106.1 8 0 113.3 0 252c0 110.9 69.8 205.8 169.5 239.2 12.8 2.3 17.3-5.6 17.3-12.1 0-6.2-.3-40.4-.3-61.4 0 0-70 15-84.7-29.8 0 0-11.4-29.1-27.8-36.6 0 0-22.9-15.7 1.6-15.4 0 0 24.9 2 38.6 25.8 21.9 38.6 58.6 27.5 72.9 20.9 2.3-16 8.8-27.1 16-33.7-55.9-6.2-112.3-14.3-112.3-110.5 0-27.5 7.6-41.3 23.6-58.9-2.6-6.5-11.1-33.3 2.6-67.9 20.9-6.5 69 27 69 27 20-5.6 41.5-8.5 62.8-8.5s42.8 2.9 62.8 8.5c0 0 48.1-33.6 69-27 13.7 34.7 5.2 61.4 2.6 67.9 16 17.7 25.8 31.5 25.8 58.9 0 96.5-58.9 104.2-114.8 110.5 9.2 7.9 17 22.9 17 46.4 0 33.7-.3 75.4-.3 83.6 0 6.5 4.6 14.4 17.3 12.1C428.2 457.8 496 362.9 496 252 496 113.3 383.5 8 244.8 8M97.2 352.9c-1.3 1-1 3.3.7 5.2 1.6 1.6 3.9 2.3 5.2 1 1.3-1 1-3.3-.7-5.2-1.6-1.6-3.9-2.3-5.2-1m-10.8-8.1c-.7 1.3.3 2.9 2.3 3.9 1.6 1 3.6.7 4.3-.7.7-1.3-.3-2.9-2.3-3.9-2-.6-3.6-.3-4.3.7m32.4 35.6c-1.6 1.3-1 4.3 1.3 6.2 2.3 2.3 5.2 2.6 6.5 1 1.3-1.3.7-4.3-1.3-6.2-2.2-2.3-5.2-2.6-6.5-1m-11.4-14.7c-1.6 1-1.6 3.6 0 5.9s4.3 3.3 5.6 2.3c1.6-1.3 1.6-3.9 0-6.2-1.4-2.3-4-3.3-5.6-2"/></svg>
</a>
<a href="https://twitter.com/finegrain_ai" target="_blank" rel="noopener" title="twitter.com" class="md-social__link" aria-label="Finegrain on Twitter (opens in a new tab)">
<svg aria-hidden="true" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 512 512"><!--! Font Awesome Free 6.6.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M459.37 151.716c.325 4.548.325 9.097.325 13.645 0 138.72-105.583 298.558-298.558 298.558-59.452 0-114.68-17.219-161.137-47.106 8.447.974 16.568 1.299 25.34 1.299 49.055 0 94.213-16.568 130.274-44.832-46.132-.975-84.792-31.188-98.112-72.772 6.498.974 12.995 1.624 19.818 1.624 9.421 0 18.843-1.3 27.614-3.573-48.081-9.747-84.143-51.98-84.143-102.985v-1.299c13.969 7.797 30.214 12.67 47.431 13.319-28.264-18.843-46.781-51.005-46.781-87.391 0-19.492 5.197-37.36 14.294-52.954 51.655 63.675 129.3 105.258 216.365 109.807-1.624-7.797-2.599-15.918-2.599-24.04 0-57.828 46.782-104.934 104.934-104.934 30.213 0 57.502 12.67 76.67 33.137 23.715-4.548 46.456-13.32 66.599-25.34-7.798 24.366-24.366 44.833-46.132 57.827 21.117-2.273 41.584-8.122 60.426-16.243-14.292 20.791-32.161 39.308-52.628 54.253"/></svg>
</a>
<a href="https://www.linkedin.com/company/finegrain-ai/" target="_blank" rel="noopener" title="www.linkedin.com" class="md-social__link" aria-label="Finegrain on LinkedIn (opens in a new tab)">
<svg aria-hidden="true" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 6.6.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M416 32H31.9C14.3 32 0 46.5 0 64.3v383.4C0 465.5 14.3 480 31.9 480H416c17.6 0 32-14.5 32-32.3V64.3c0-17.8-14.4-32.3-32-32.3M135.4 416H69V202.2h66.5V416zm-33.2-243c-21.3 0-38.5-17.3-38.5-38.5S80.9 96 102.2 96c21.2 0 38.5 17.3 38.5 38.5 0 21.3-17.2 38.5-38.5 38.5m282.1 243h-66.4V312c0-24.8-.5-56.7-34.5-56.7-34.6 0-39.9 27-39.9 54.9V416h-66.4V202.2h63.7v29.2h.9c8.9-16.8 30.6-34.5 62.9-34.5 67.2 0 79.7 44.3 79.7 101.9z"/></svg>
</a>
</div>
</div>
</div>
</footer>
</div>
<!-- Dialog container tagged data-md-component="dialog". Its inner element is
     empty in the static page; presumably the theme bundle fills it at runtime
     (e.g. with the "clipboard.copied" string from __config). TODO confirm. -->
<div class="md-dialog" data-md-component="dialog">
<div class="md-dialog__inner md-typeset"></div>
</div>
<!-- Theme runtime configuration as inert JSON (type="application/json" is not
     executed): site base path, enabled theme features, search worker script path,
     and UI translation strings. Presumably looked up by its "__config" id from the
     bundle below; confirm against the theme. JSON allows no comments, so notes live here. -->
<script id="__config" type="application/json">{"base": "../..", "features": ["navigation.tabs", "navigation.sections", "navigation.top", "navigation.tracking", "navigation.expand", "navigation.path", "toc.follow", "navigation.tabs.sticky", "content.code.copy", "announce.dismiss"], "search": "../../assets/javascripts/workers/search.6ce7567c.min.js", "translations": {"clipboard.copied": "Copied to clipboard", "clipboard.copy": "Copy to clipboard", "search.result.more.one": "1 more on this page", "search.result.more.other": "# more on this page", "search.result.none": "No matching documents", "search.result.one": "1 matching document", "search.result.other": "# matching documents", "search.result.placeholder": "Type to start searching", "search.result.term.missing": "Missing", "select.version": "Select version"}}</script>
<!-- Main theme JavaScript bundle: a classic script with no defer/async, placed at
     the end of body so the markup above is already parsed when it runs. -->
<script src="../../assets/javascripts/bundle.83f73b43.min.js"></script>
</body>
</html>