style: autoformating

Former-commit-id: 8c5c75469afa61e8d3728959390b1354033be462
This commit is contained in:
Your Name 2022-06-27 15:39:44 +02:00
parent cae795b9ad
commit a71e67690a
11 changed files with 650 additions and 204 deletions

13
.vscode/settings.json vendored Normal file
View file

@ -0,0 +1,13 @@
{
"python.formatting.provider": "black",
"editor.formatOnSave": true,
"python.linting.enabled": true,
"python.linting.lintOnSave": true,
// "python.linting.flake8Enabled": true,
// "python.linting.banditEnabled": true,
"[python]": {
"editor.codeActionsOnSave": {
"source.organizeImports": true
}
},
}

View file

@ -11,8 +11,8 @@ def evaluate(net, dataloader, device):
dice_score = 0
# iterate over the validation set
for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
image, mask_true = batch['image'], batch['mask']
for batch in tqdm(dataloader, total=num_val_batches, desc="Validation round", unit="batch", leave=False):
image, mask_true = batch["image"], batch["mask"]
# move images and labels to correct device and type
image = image.to(device=device, dtype=torch.float32)
mask_true = mask_true.to(device=device, dtype=torch.long)
@ -30,9 +30,9 @@ def evaluate(net, dataloader, device):
else:
mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
# compute the Dice score, ignoring background
dice_score += multiclass_dice_coeff(mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False)
dice_score += multiclass_dice_coeff(
mask_pred[:, 1:, ...], mask_true[:, 1:, ...], reduce_batch_first=False
)
net.train()

353
poetry.lock generated
View file

@ -64,7 +64,7 @@ unicode_backport = ["unicodedata2"]
name = "click"
version = "8.1.3"
description = "Composable command line interface toolkit"
category = "dev"
category = "main"
optional = false
python-versions = ">=3.7"
@ -87,6 +87,17 @@ category = "main"
optional = false
python-versions = ">=3.6"
[[package]]
name = "docker-pycreds"
version = "0.4.0"
description = "Python bindings for the docker credentials store API"
category = "main"
optional = false
python-versions = "*"
[package.dependencies]
six = ">=1.4.0"
[[package]]
name = "fonttools"
version = "4.33.3"
@ -109,6 +120,28 @@ ufo = ["fs (>=2.2.0,<3)"]
unicode = ["unicodedata2 (>=14.0.0)"]
woff = ["zopfli (>=0.1.4)", "brotlicffi (>=0.8.0)", "brotli (>=1.0.1)"]
[[package]]
name = "gitdb"
version = "4.0.9"
description = "Git Object Database"
category = "main"
optional = false
python-versions = ">=3.6"
[package.dependencies]
smmap = ">=3.0.1,<6"
[[package]]
name = "gitpython"
version = "3.1.27"
description = "GitPython is a python library used to interact with Git repositories"
category = "main"
optional = false
python-versions = ">=3.7"
[package.dependencies]
gitdb = ">=4.0.1,<5"
[[package]]
name = "idna"
version = "3.3"
@ -146,6 +179,20 @@ pyav = ["av"]
test = ["invoke", "pytest", "pytest-cov", "fsspec"]
tifffile = ["tifffile"]
[[package]]
name = "isort"
version = "5.10.1"
description = "A Python utility / library to sort Python imports."
category = "main"
optional = false
python-versions = ">=3.6.1,<4.0"
[package.extras]
pipfile_deprecated_finder = ["pipreqs", "requirementslib"]
requirements_deprecated_finder = ["pipreqs", "pip-api"]
colors = ["colorama (>=0.4.3,<0.5.0)"]
plugins = ["setuptools"]
[[package]]
name = "joblib"
version = "1.1.0"
@ -247,6 +294,14 @@ category = "dev"
optional = false
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7"
[[package]]
name = "pathtools"
version = "0.1.2"
description = "File system general utilities"
category = "main"
optional = false
python-versions = "*"
[[package]]
name = "pillow"
version = "9.1.1"
@ -271,6 +326,39 @@ python-versions = ">=3.7"
docs = ["furo (>=2021.7.5b38)", "proselint (>=0.10.2)", "sphinx-autodoc-typehints (>=1.12)", "sphinx (>=4)"]
test = ["appdirs (==1.4.4)", "pytest-cov (>=2.7)", "pytest-mock (>=3.6)", "pytest (>=6)"]
[[package]]
name = "promise"
version = "2.3"
description = "Promises/A+ implementation for Python"
category = "main"
optional = false
python-versions = "*"
[package.dependencies]
six = "*"
[package.extras]
test = ["pytest (>=2.7.3)", "pytest-cov", "coveralls", "futures", "pytest-benchmark", "mock"]
[[package]]
name = "protobuf"
version = "3.20.1"
description = "Protocol Buffers"
category = "main"
optional = false
python-versions = ">=3.7"
[[package]]
name = "psutil"
version = "5.9.1"
description = "Cross-platform lib for process and system monitoring in Python."
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
[package.extras]
test = ["ipaddress", "mock", "enum34", "pywin32", "wmi"]
[[package]]
name = "pyparsing"
version = "3.0.9"
@ -399,6 +487,47 @@ python-versions = ">=3.8,<3.11"
[package.dependencies]
numpy = ">=1.17.3,<1.25.0"
[[package]]
name = "sentry-sdk"
version = "1.6.0"
description = "Python client for Sentry (https://sentry.io)"
category = "main"
optional = false
python-versions = "*"
[package.dependencies]
certifi = "*"
urllib3 = ">=1.10.0"
[package.extras]
aiohttp = ["aiohttp (>=3.5)"]
beam = ["apache-beam (>=2.12)"]
bottle = ["bottle (>=0.12.13)"]
celery = ["celery (>=3)"]
chalice = ["chalice (>=1.16.0)"]
django = ["django (>=1.8)"]
falcon = ["falcon (>=1.4)"]
flask = ["flask (>=0.11)", "blinker (>=1.1)"]
httpx = ["httpx (>=0.16.0)"]
pure_eval = ["pure-eval", "executing", "asttokens"]
pyspark = ["pyspark (>=2.4.4)"]
quart = ["quart (>=0.16.1)", "blinker (>=1.1)"]
rq = ["rq (>=0.6)"]
sanic = ["sanic (>=0.8)"]
sqlalchemy = ["sqlalchemy (>=1.2)"]
tornado = ["tornado (>=5)"]
[[package]]
name = "setproctitle"
version = "1.2.3"
description = "A Python module to customize the process title"
category = "main"
optional = false
python-versions = ">=3.6"
[package.extras]
test = ["pytest"]
[[package]]
name = "setuptools-scm"
version = "7.0.2"
@ -416,6 +545,14 @@ typing-extensions = "*"
test = ["pytest (>=6.2)", "virtualenv (>20)"]
toml = ["setuptools (>=42)"]
[[package]]
name = "shortuuid"
version = "1.0.9"
description = "A generator library for concise, unambiguous and URL-safe UUIDs."
category = "main"
optional = false
python-versions = ">=3.5"
[[package]]
name = "six"
version = "1.16.0"
@ -424,6 +561,14 @@ category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
[[package]]
name = "smmap"
version = "5.0.0"
description = "A pure Python implementation of a sliding window memory map manager"
category = "main"
optional = false
python-versions = ">=3.6"
[[package]]
name = "threadpoolctl"
version = "3.1.0"
@ -521,10 +666,43 @@ brotli = ["brotlicffi (>=0.8.0)", "brotli (>=1.0.9)", "brotlipy (>=0.6.0)"]
secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"]
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
[[package]]
name = "wandb"
version = "0.12.19"
description = "A CLI and library for interacting with the Weights and Biases API."
category = "main"
optional = false
python-versions = ">=3.6"
[package.dependencies]
Click = ">=7.0,<8.0.0 || >8.0.0"
docker-pycreds = ">=0.4.0"
GitPython = ">=1.0.0"
pathtools = "*"
promise = ">=2.0,<3"
protobuf = ">=3.12.0,<4.0dev"
psutil = ">=5.0.0"
PyYAML = "*"
requests = ">=2.0.0,<3"
sentry-sdk = ">=1.0.0"
setproctitle = "*"
shortuuid = ">=0.5.0"
six = ">=1.13.0"
[package.extras]
aws = ["boto3"]
azure = ["azure-storage-blob"]
gcp = ["google-cloud-storage"]
grpc = ["grpcio (>=1.27.2)"]
kubeflow = ["kubernetes", "minio", "google-cloud-storage", "sh"]
launch = ["nbconvert", "nbformat", "chardet", "iso8601", "typing-extensions", "boto3", "google-cloud-storage", "kubernetes"]
media = ["numpy", "moviepy", "pillow", "bokeh", "soundfile", "plotly", "rdkit-pypi"]
sweeps = ["sweeps (>=0.1.0)"]
[metadata]
lock-version = "1.1"
python-versions = ">=3.8,<3.11"
content-hash = "92571d493bf6312f32dfc9b3703ceaad5ec6fee2f827632a81c38a6ae9d8875b"
content-hash = "c176396881b9758ed2f1e14e68ca042ea676d349ff4e8100cae92ca8220046be"
[metadata.files]
albumentations = [
@ -576,10 +754,22 @@ cycler = [
{file = "cycler-0.11.0-py3-none-any.whl", hash = "sha256:3a27e95f763a428a739d2add979fa7494c912a32c17c4c38c4d5f082cad165a3"},
{file = "cycler-0.11.0.tar.gz", hash = "sha256:9c87405839a19696e837b3b818fed3f5f69f16f1eec1a1ad77e043dcea9c772f"},
]
docker-pycreds = [
{file = "docker-pycreds-0.4.0.tar.gz", hash = "sha256:6ce3270bcaf404cc4c3e27e4b6c70d3521deae82fb508767870fdbf772d584d4"},
{file = "docker_pycreds-0.4.0-py2.py3-none-any.whl", hash = "sha256:7266112468627868005106ec19cd0d722702d2b7d5912a28e19b826c3d37af49"},
]
fonttools = [
{file = "fonttools-4.33.3-py3-none-any.whl", hash = "sha256:f829c579a8678fa939a1d9e9894d01941db869de44390adb49ce67055a06cc2a"},
{file = "fonttools-4.33.3.zip", hash = "sha256:c0fdcfa8ceebd7c1b2021240bd46ef77aa8e7408cf10434be55df52384865f8e"},
]
gitdb = [
{file = "gitdb-4.0.9-py3-none-any.whl", hash = "sha256:8033ad4e853066ba6ca92050b9df2f89301b8fc8bf7e9324d412a63f8bf1a8fd"},
{file = "gitdb-4.0.9.tar.gz", hash = "sha256:bac2fd45c0a1c9cf619e63a90d62bdc63892ef92387424b855792a6cabe789aa"},
]
gitpython = [
{file = "GitPython-3.1.27-py3-none-any.whl", hash = "sha256:5b68b000463593e05ff2b261acff0ff0972df8ab1b70d3cdbd41b546c8b8fc3d"},
{file = "GitPython-3.1.27.tar.gz", hash = "sha256:1c885ce809e8ba2d88a29befeb385fcea06338d3640712b59ca623c220bb5704"},
]
idna = [
{file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"},
{file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"},
@ -588,6 +778,10 @@ imageio = [
{file = "imageio-2.19.3-py3-none-any.whl", hash = "sha256:d36ab8616175a093676693a4dbc85c6cc767f981c9ce93041422569c76d06347"},
{file = "imageio-2.19.3.tar.gz", hash = "sha256:0c9df80e42f2ee68bea92001e7fcf612aa149910efe040eb757f5ce323250ae1"},
]
isort = [
{file = "isort-5.10.1-py3-none-any.whl", hash = "sha256:6f62d78e2f89b4500b080fe3a81690850cd254227f27f75c3a0c491a1f351ba7"},
{file = "isort-5.10.1.tar.gz", hash = "sha256:e8443a5e7a020e9d7f97f1d7d9cd17c88bcb3bc7e218bf9cf5095fe550be2951"},
]
joblib = [
{file = "joblib-1.1.0-py2.py3-none-any.whl", hash = "sha256:f21f109b3c7ff9d95f8387f752d0d9c34a02aa2f7060c2135f465da0e5160ff6"},
{file = "joblib-1.1.0.tar.gz", hash = "sha256:4158fcecd13733f8be669be0683b96ebdbbd38d23559f54dca7205aea1bf1e35"},
@ -723,6 +917,9 @@ pathspec = [
{file = "pathspec-0.9.0-py2.py3-none-any.whl", hash = "sha256:7d15c4ddb0b5c802d161efc417ec1a2558ea2653c2e8ad9c19098201dc1c993a"},
{file = "pathspec-0.9.0.tar.gz", hash = "sha256:e564499435a2673d586f6b2130bb5b95f04a3ba06f81b8f895b651a3c76aabb1"},
]
pathtools = [
{file = "pathtools-0.1.2.tar.gz", hash = "sha256:7c35c5421a39bb82e58018febd90e3b6e5db34c5443aaaf742b3f33d4655f1c0"},
]
pillow = [
{file = "Pillow-9.1.1-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:42dfefbef90eb67c10c45a73a9bc1599d4dac920f7dfcbf4ec6b80cb620757fe"},
{file = "Pillow-9.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ffde4c6fabb52891d81606411cbfaf77756e3b561b566efd270b3ed3791fde4e"},
@ -767,6 +964,69 @@ platformdirs = [
{file = "platformdirs-2.5.2-py3-none-any.whl", hash = "sha256:027d8e83a2d7de06bbac4e5ef7e023c02b863d7ea5d079477e722bb41ab25788"},
{file = "platformdirs-2.5.2.tar.gz", hash = "sha256:58c8abb07dcb441e6ee4b11d8df0ac856038f944ab98b7be6b27b2a3c7feef19"},
]
promise = [
{file = "promise-2.3.tar.gz", hash = "sha256:dfd18337c523ba4b6a58801c164c1904a9d4d1b1747c7d5dbf45b693a49d93d0"},
]
protobuf = [
{file = "protobuf-3.20.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3cc797c9d15d7689ed507b165cd05913acb992d78b379f6014e013f9ecb20996"},
{file = "protobuf-3.20.1-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:ff8d8fa42675249bb456f5db06c00de6c2f4c27a065955917b28c4f15978b9c3"},
{file = "protobuf-3.20.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:cd68be2559e2a3b84f517fb029ee611546f7812b1fdd0aa2ecc9bc6ec0e4fdde"},
{file = "protobuf-3.20.1-cp310-cp310-win32.whl", hash = "sha256:9016d01c91e8e625141d24ec1b20fed584703e527d28512aa8c8707f105a683c"},
{file = "protobuf-3.20.1-cp310-cp310-win_amd64.whl", hash = "sha256:32ca378605b41fd180dfe4e14d3226386d8d1b002ab31c969c366549e66a2bb7"},
{file = "protobuf-3.20.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:9be73ad47579abc26c12024239d3540e6b765182a91dbc88e23658ab71767153"},
{file = "protobuf-3.20.1-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:097c5d8a9808302fb0da7e20edf0b8d4703274d140fd25c5edabddcde43e081f"},
{file = "protobuf-3.20.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:e250a42f15bf9d5b09fe1b293bdba2801cd520a9f5ea2d7fb7536d4441811d20"},
{file = "protobuf-3.20.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:cdee09140e1cd184ba9324ec1df410e7147242b94b5f8b0c64fc89e38a8ba531"},
{file = "protobuf-3.20.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:af0ebadc74e281a517141daad9d0f2c5d93ab78e9d455113719a45a49da9db4e"},
{file = "protobuf-3.20.1-cp37-cp37m-win32.whl", hash = "sha256:755f3aee41354ae395e104d62119cb223339a8f3276a0cd009ffabfcdd46bb0c"},
{file = "protobuf-3.20.1-cp37-cp37m-win_amd64.whl", hash = "sha256:62f1b5c4cd6c5402b4e2d63804ba49a327e0c386c99b1675c8a0fefda23b2067"},
{file = "protobuf-3.20.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:06059eb6953ff01e56a25cd02cca1a9649a75a7e65397b5b9b4e929ed71d10cf"},
{file = "protobuf-3.20.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:cb29edb9eab15742d791e1025dd7b6a8f6fcb53802ad2f6e3adcb102051063ab"},
{file = "protobuf-3.20.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:69ccfdf3657ba59569c64295b7d51325f91af586f8d5793b734260dfe2e94e2c"},
{file = "protobuf-3.20.1-cp38-cp38-win32.whl", hash = "sha256:dd5789b2948ca702c17027c84c2accb552fc30f4622a98ab5c51fcfe8c50d3e7"},
{file = "protobuf-3.20.1-cp38-cp38-win_amd64.whl", hash = "sha256:77053d28427a29987ca9caf7b72ccafee011257561259faba8dd308fda9a8739"},
{file = "protobuf-3.20.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f50601512a3d23625d8a85b1638d914a0970f17920ff39cec63aaef80a93fb7"},
{file = "protobuf-3.20.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:284f86a6207c897542d7e956eb243a36bb8f9564c1742b253462386e96c6b78f"},
{file = "protobuf-3.20.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7403941f6d0992d40161aa8bb23e12575637008a5a02283a930addc0508982f9"},
{file = "protobuf-3.20.1-cp39-cp39-win32.whl", hash = "sha256:db977c4ca738dd9ce508557d4fce0f5aebd105e158c725beec86feb1f6bc20d8"},
{file = "protobuf-3.20.1-cp39-cp39-win_amd64.whl", hash = "sha256:7e371f10abe57cee5021797126c93479f59fccc9693dafd6bd5633ab67808a91"},
{file = "protobuf-3.20.1-py2.py3-none-any.whl", hash = "sha256:adfc6cf69c7f8c50fd24c793964eef18f0ac321315439d94945820612849c388"},
{file = "protobuf-3.20.1.tar.gz", hash = "sha256:adc31566d027f45efe3f44eeb5b1f329da43891634d61c75a5944e9be6dd42c9"},
]
psutil = [
{file = "psutil-5.9.1-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:799759d809c31aab5fe4579e50addf84565e71c1dc9f1c31258f159ff70d3f87"},
{file = "psutil-5.9.1-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:9272167b5f5fbfe16945be3db475b3ce8d792386907e673a209da686176552af"},
{file = "psutil-5.9.1-cp27-cp27m-win32.whl", hash = "sha256:0904727e0b0a038830b019551cf3204dd48ef5c6868adc776e06e93d615fc5fc"},
{file = "psutil-5.9.1-cp27-cp27m-win_amd64.whl", hash = "sha256:e7e10454cb1ab62cc6ce776e1c135a64045a11ec4c6d254d3f7689c16eb3efd2"},
{file = "psutil-5.9.1-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:56960b9e8edcca1456f8c86a196f0c3d8e3e361320071c93378d41445ffd28b0"},
{file = "psutil-5.9.1-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:44d1826150d49ffd62035785a9e2c56afcea66e55b43b8b630d7706276e87f22"},
{file = "psutil-5.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c7be9d7f5b0d206f0bbc3794b8e16fb7dbc53ec9e40bbe8787c6f2d38efcf6c9"},
{file = "psutil-5.9.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:abd9246e4cdd5b554a2ddd97c157e292ac11ef3e7af25ac56b08b455c829dca8"},
{file = "psutil-5.9.1-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:29a442e25fab1f4d05e2655bb1b8ab6887981838d22effa2396d584b740194de"},
{file = "psutil-5.9.1-cp310-cp310-win32.whl", hash = "sha256:20b27771b077dcaa0de1de3ad52d22538fe101f9946d6dc7869e6f694f079329"},
{file = "psutil-5.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:58678bbadae12e0db55186dc58f2888839228ac9f41cc7848853539b70490021"},
{file = "psutil-5.9.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:3a76ad658641172d9c6e593de6fe248ddde825b5866464c3b2ee26c35da9d237"},
{file = "psutil-5.9.1-cp36-cp36m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a6a11e48cb93a5fa606306493f439b4aa7c56cb03fc9ace7f6bfa21aaf07c453"},
{file = "psutil-5.9.1-cp36-cp36m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:068935df39055bf27a29824b95c801c7a5130f118b806eee663cad28dca97685"},
{file = "psutil-5.9.1-cp36-cp36m-win32.whl", hash = "sha256:0f15a19a05f39a09327345bc279c1ba4a8cfb0172cc0d3c7f7d16c813b2e7d36"},
{file = "psutil-5.9.1-cp36-cp36m-win_amd64.whl", hash = "sha256:db417f0865f90bdc07fa30e1aadc69b6f4cad7f86324b02aa842034efe8d8c4d"},
{file = "psutil-5.9.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:91c7ff2a40c373d0cc9121d54bc5f31c4fa09c346528e6a08d1845bce5771ffc"},
{file = "psutil-5.9.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fea896b54f3a4ae6f790ac1d017101252c93f6fe075d0e7571543510f11d2676"},
{file = "psutil-5.9.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3054e923204b8e9c23a55b23b6df73a8089ae1d075cb0bf711d3e9da1724ded4"},
{file = "psutil-5.9.1-cp37-cp37m-win32.whl", hash = "sha256:d2d006286fbcb60f0b391741f520862e9b69f4019b4d738a2a45728c7e952f1b"},
{file = "psutil-5.9.1-cp37-cp37m-win_amd64.whl", hash = "sha256:b14ee12da9338f5e5b3a3ef7ca58b3cba30f5b66f7662159762932e6d0b8f680"},
{file = "psutil-5.9.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:19f36c16012ba9cfc742604df189f2f28d2720e23ff7d1e81602dbe066be9fd1"},
{file = "psutil-5.9.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:944c4b4b82dc4a1b805329c980f270f170fdc9945464223f2ec8e57563139cf4"},
{file = "psutil-5.9.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b6750a73a9c4a4e689490ccb862d53c7b976a2a35c4e1846d049dcc3f17d83b"},
{file = "psutil-5.9.1-cp38-cp38-win32.whl", hash = "sha256:a8746bfe4e8f659528c5c7e9af5090c5a7d252f32b2e859c584ef7d8efb1e689"},
{file = "psutil-5.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:79c9108d9aa7fa6fba6e668b61b82facc067a6b81517cab34d07a84aa89f3df0"},
{file = "psutil-5.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:28976df6c64ddd6320d281128817f32c29b539a52bdae5e192537bc338a9ec81"},
{file = "psutil-5.9.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b88f75005586131276634027f4219d06e0561292be8bd6bc7f2f00bdabd63c4e"},
{file = "psutil-5.9.1-cp39-cp39-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:645bd4f7bb5b8633803e0b6746ff1628724668681a434482546887d22c7a9537"},
{file = "psutil-5.9.1-cp39-cp39-win32.whl", hash = "sha256:32c52611756096ae91f5d1499fe6c53b86f4a9ada147ee42db4991ba1520e574"},
{file = "psutil-5.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:f65f9a46d984b8cd9b3750c2bdb419b2996895b005aefa6cbaba9a143b1ce2c5"},
{file = "psutil-5.9.1.tar.gz", hash = "sha256:57f1819b5d9e95cdfb0c881a8a5b7d542ed0b7c522d575706a80bedc848c8954"},
]
pyparsing = [
{file = "pyparsing-3.0.9-py3-none-any.whl", hash = "sha256:5026bae9a10eeaefb61dab2f09052b9f4307d44aee4eda64b309723d8d206bbc"},
{file = "pyparsing-3.0.9.tar.gz", hash = "sha256:2b020ecf7d21b687f219b71ecad3631f644a47f01403fa1d1036b0c6416d70fb"},
@ -916,14 +1176,99 @@ scipy = [
{file = "scipy-1.8.1-cp39-cp39-win_amd64.whl", hash = "sha256:9dd4012ac599a1e7eb63c114d1eee1bcfc6dc75a29b589ff0ad0bb3d9412034f"},
{file = "scipy-1.8.1.tar.gz", hash = "sha256:9e3fb1b0e896f14a85aa9a28d5f755daaeeb54c897b746df7a55ccb02b340f33"},
]
sentry-sdk = [
{file = "sentry-sdk-1.6.0.tar.gz", hash = "sha256:b82ad57306d5546713f15d5d70daea0408cf7f998c7566db16e0e6257e51e561"},
{file = "sentry_sdk-1.6.0-py2.py3-none-any.whl", hash = "sha256:ddbd191b6f4e696b7845b4d87389898ae1207981faf114f968a57363aa6be03c"},
]
setproctitle = [
{file = "setproctitle-1.2.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0a668acec8b61a971de54bc4c733869ea7b0eb1348eae5a32b9477f788908e5c"},
{file = "setproctitle-1.2.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:52265182fe5ac237d179d8e949248d307882a2e6ec7f189c8dac1c9d1b3631fa"},
{file = "setproctitle-1.2.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:71d00ef63a1f78e13c236895badac77b6c8503377467b9c1a4f81fe729d16e03"},
{file = "setproctitle-1.2.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb82a49aaf440232c762539ab3737b5174d31aba0141fd4bf4d8739c28d18624"},
{file = "setproctitle-1.2.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:791bed39e4ecbdd008b64999a60c9cc560d17b3836ca0c27cd4708e8e1bcf495"},
{file = "setproctitle-1.2.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8e4da68d4d4ba46d4c5db6ae5eb61b11de9c520f25ae8334570f4d0018a8611"},
{file = "setproctitle-1.2.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:47f97f591ea2335b7d35f5e9ad7d806385338182dc6de5732d091e9c70ed1cc0"},
{file = "setproctitle-1.2.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:501c084cf3df7d848e91c97d4f8c44d799ba545858a79c6960326ce6f285b4e4"},
{file = "setproctitle-1.2.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:a39b30d7400c0d50941fe19e1fe0b7d35676186fec4d9c010129ac91b883fd26"},
{file = "setproctitle-1.2.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:b213376fc779c0e1a4b60008f3fd03f74e9baa9665db37fa6646e98d31baa6d8"},
{file = "setproctitle-1.2.3-cp310-cp310-win32.whl", hash = "sha256:e24fa9251cc22ddb88ef183070063fdca826c9636381f1c4fb9d2a1dccb7c2a4"},
{file = "setproctitle-1.2.3-cp310-cp310-win_amd64.whl", hash = "sha256:3b1883ccdbee624386dc046cfbcd80c4e75e24c478f35627984a79892e088b88"},
{file = "setproctitle-1.2.3-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:f9cf1098205c23fbcaaaef798afaff714fa9ffadf24166f5e85e6d16b9ef82a1"},
{file = "setproctitle-1.2.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a546cd2dfaecb227d24122257b98b2e062762871888835c7b608f1c41c3a77ad"},
{file = "setproctitle-1.2.3-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e40c35564081983eab6a07f9eb5693867bc447b0edf9c61b69446223d6593814"},
{file = "setproctitle-1.2.3-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d083cae02e344e760bd21c28d591ac5f7ddbd6e1a0ecba62092ae724abd5c28"},
{file = "setproctitle-1.2.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2fa9f4b382a6cf88f2f345044d0916a92f37cac21355585bd14bc7ee91af187"},
{file = "setproctitle-1.2.3-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:38855b06a124361dc73c198853dee3f2b775531c4f4b7472f0e3d441192b3d8a"},
{file = "setproctitle-1.2.3-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:a81067bdc015fee1cc148c79b346f24fdad1224a8898b4239c7cbdee1add8a60"},
{file = "setproctitle-1.2.3-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:409a39f92e123be061626fdfd3e76625b04db103479bb4ba1c85b587db0b9498"},
{file = "setproctitle-1.2.3-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:a993610383028f093112dce7f77b262e88fce9d70127535fcdc78953179857e8"},
{file = "setproctitle-1.2.3-cp36-cp36m-win32.whl", hash = "sha256:4eed53c12146de5df959d84384ffc2774651cab406ee4854e12728cf0eee5297"},
{file = "setproctitle-1.2.3-cp36-cp36m-win_amd64.whl", hash = "sha256:335750c9eb5b18326a138a09266862a52b4f474277c3e410b419bea9a1df8bee"},
{file = "setproctitle-1.2.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:7a72bbe53191fbe574c94c0f8b9451dce535b398b7c47ce2e26e21d55eaa1d7e"},
{file = "setproctitle-1.2.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5464e6812d050c986e6e9b97d54ab88c23dbe9d81151a2fa10b48bb5133a1e2c"},
{file = "setproctitle-1.2.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ec7c3a27460ae7811e868e5494e3d8aee5012912744c48fa2d80b5e614b1b972"},
{file = "setproctitle-1.2.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:01cef383afc7ea7a3b1696818c8712029bf2f1d64f5d4777dbaf0166becf2c00"},
{file = "setproctitle-1.2.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54c7315e53b49ef2227d47a75c3d28c4c51ea9ee46a066460732c0d0f8e605a7"},
{file = "setproctitle-1.2.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0b444ed4051161a3b0a85dec2bb9b50922f37c75f5fb86f7784b235cf6754336"},
{file = "setproctitle-1.2.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:be0b46beeb1c92450079a7f30a025d69b63fd6a5de040ebc478fd6e6bf3b63fc"},
{file = "setproctitle-1.2.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:60f7a2f5da36a3075dda7edbee2173be5b765b0460b8d401ee01a11f68dee1d2"},
{file = "setproctitle-1.2.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:138bfa853e607f06d95b0f253e9152b32a00af3d0dbec96abf0871236a483932"},
{file = "setproctitle-1.2.3-cp37-cp37m-win32.whl", hash = "sha256:e80fc59739a738b5c67afbbb9d1c238aa47b6d290c2ada872b15c819350ec5f8"},
{file = "setproctitle-1.2.3-cp37-cp37m-win_amd64.whl", hash = "sha256:a912df3f065572cef211e9ed9f157a0dd2bd73d150281f18f00728afa1b1e5d2"},
{file = "setproctitle-1.2.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:d45dbe4171f8c27a515ecb4562f4cd9ef67d98474bea18e0c14dfbdc2b225050"},
{file = "setproctitle-1.2.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9d905ac84dde5227de6516ec08639759f99684148bb88ba05f4cbdaebff5d69"},
{file = "setproctitle-1.2.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f272b84d79bbe15af26ecf6f7c129bbe642f628866c9253659cdb519216f138f"},
{file = "setproctitle-1.2.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc586f002fd5dd8695718e22a83771fd9f744f081a2b8e614bf6b5f44135964a"},
{file = "setproctitle-1.2.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4051c3a3b07f8a4cca205cd45366a22f322da2f26491c0d6b313a10f8c77b734"},
{file = "setproctitle-1.2.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25538341e56f9e75e9759229ff674282dccb5b1ce79a974f968d36208d465674"},
{file = "setproctitle-1.2.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:fdb2231db176e0848b757fc5d9bed08bc8a498b5b9abb8b640f39e9720f309fc"},
{file = "setproctitle-1.2.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0670f2130a7ca0e167d3d5a7c8e3c707340b8693d6af7416ff55c18ab2a0a43f"},
{file = "setproctitle-1.2.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:9a92978030616f5e20617b7b832efee398df82072b7239c53db41c8026f5fe55"},
{file = "setproctitle-1.2.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:28e0df80d5069586a08a3cb463fb23503a37cbb805826ef93164bc4bfb5f35b9"},
{file = "setproctitle-1.2.3-cp38-cp38-win32.whl", hash = "sha256:35b869e416a105c59133a48b569c6e808159485d916f55e80c7394a42667a386"},
{file = "setproctitle-1.2.3-cp38-cp38-win_amd64.whl", hash = "sha256:f47f6704880869d8e8f52efac2f2f60f5ed4cb9662b98fc1c7e916eefe76e61d"},
{file = "setproctitle-1.2.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:ccb0b5334dbf248f7504d88b5e9e9a09a0da119eeafacd6f7247f7c055443522"},
{file = "setproctitle-1.2.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:14641a4ec2f2110cf4afc666eaecc82ba67814e927e02647fa1f4cf74476e752"},
{file = "setproctitle-1.2.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4a3cb19346a0cd680617742f5e39fdd14596f6fd91d6c9038272663e37441b4"},
{file = "setproctitle-1.2.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e2ac0ebd9c63c3d19f768966be2f771bf088bc7373c63ed6fcbb3444a30d0f62"},
{file = "setproctitle-1.2.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32a84cc309b9e595f06a55bec2fa335a23c307a55d2989864b60ecd71ea87897"},
{file = "setproctitle-1.2.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3f55493c987935fa540ef9ffb7ee7db03b4a18a9d5cc103681e2e6a6dfbd7054"},
{file = "setproctitle-1.2.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f2a137984d3436f13e4bf7c8ca6f6f292df119c009c5e39556cabba4f4bfbf92"},
{file = "setproctitle-1.2.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f06ff922254023eaabef6af6631f89e5f2f420cf0112865d57d7703f933d4e9f"},
{file = "setproctitle-1.2.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:eb06c1086cf8c8cf12ce45a02450befcb408dfd646d0ccb47d388fd6e73c333a"},
{file = "setproctitle-1.2.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2c8c245e08f6a296fdaa1b36894ec40e20464a4fc6458e6178c8d55a2f83457a"},
{file = "setproctitle-1.2.3-cp39-cp39-win32.whl", hash = "sha256:21d6e064b8fee4e58eb00cdd8771c638de1bc30bb6c02d0208af9ca0a1c00898"},
{file = "setproctitle-1.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:efb3001fd9e71d3ae939d826bf436f0446fd30a6ac01e0ce08cd7eb55ee5ac57"},
{file = "setproctitle-1.2.3-pp37-pypy37_pp73-macosx_10_9_x86_64.whl", hash = "sha256:3dbe87e76197f9a303451512088c18c96f09a6fc4f871a92e5bd695f46f94a26"},
{file = "setproctitle-1.2.3-pp37-pypy37_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0b207de9e4f4aa5265b36dd826a1f6ef6566b064a042033bd7447efb7e9a7664"},
{file = "setproctitle-1.2.3-pp37-pypy37_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48ac48a94040ef21be37366cbc8270fcba2ca103d6c64da6099d5a7b034f72d0"},
{file = "setproctitle-1.2.3-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:9fb5d2e66f94eebc3d06cda9e71a3fffef24c5273971180a4b5628a37fae05a5"},
{file = "setproctitle-1.2.3-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:423f8a6d8116acf975ebf93d6b5c4a752f7d2039fa9aafe175a62de86e17016e"},
{file = "setproctitle-1.2.3-pp38-pypy38_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c0be45535e934deab3aa72ed1a8487174af4ea12cec124478c68a312e1c8b13"},
{file = "setproctitle-1.2.3-pp38-pypy38_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65a9384cafdfed98f91416e93705ad08f049c298afcb9c515882beba23153bd0"},
{file = "setproctitle-1.2.3-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:d312a170f539895c8093b5e68ba126aa131c9f0d00f6360410db27ec50bf7afa"},
{file = "setproctitle-1.2.3-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:c93a2272740e60cddf59d3e1d35dbb89fcc3676f5ca9618bb4e6ae9633fdf13c"},
{file = "setproctitle-1.2.3-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76f59444a25fb42ca07f53a4474b1545d97a06f016e6c6b8246eee5b146820b5"},
{file = "setproctitle-1.2.3-pp39-pypy39_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06aab65e68163ead9d046b452dd9ad1fc6834ce6bde490f63fdce3be53e9cc73"},
{file = "setproctitle-1.2.3-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:97accd117392b1e57e09888792750c403d7729b7e4b193005178b3736b325ea0"},
{file = "setproctitle-1.2.3.tar.gz", hash = "sha256:ecf28b1c07a799d76f4326e508157b71aeda07b84b90368ea451c0710dbd32c0"},
]
setuptools-scm = [
{file = "setuptools_scm-7.0.2-py3-none-any.whl", hash = "sha256:ec120c99027a785c7a349667480a5b2100dfc7d5063e545c93f3735508945aef"},
{file = "setuptools_scm-7.0.2.tar.gz", hash = "sha256:53afbaa882f43e5ec8e14f2d93b9c268dc62eb3f6cc8b2790f8cb97903e53c02"},
]
shortuuid = [
{file = "shortuuid-1.0.9-py3-none-any.whl", hash = "sha256:b2bb9eb7773170e253bb7ba25971023acb473517a8b76803d9618668cb1dd46f"},
{file = "shortuuid-1.0.9.tar.gz", hash = "sha256:459f12fa1acc34ff213b1371467c0325169645a31ed989e268872339af7563d5"},
]
six = [
{file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
{file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
]
smmap = [
{file = "smmap-5.0.0-py3-none-any.whl", hash = "sha256:2aba19d6a040e78d8b09de5c57e96207b09ed71d8e55ce0959eeee6c8e190d94"},
{file = "smmap-5.0.0.tar.gz", hash = "sha256:c840e62059cd3be204b0c9c9f74be2c09d5648eddd4580d9314c3ecde0b30936"},
]
threadpoolctl = [
{file = "threadpoolctl-3.1.0-py3-none-any.whl", hash = "sha256:8b99adda265feb6773280df41eece7b2e6561b772d21ffd52e372f999024907b"},
{file = "threadpoolctl-3.1.0.tar.gz", hash = "sha256:a335baacfaa4400ae1f0d8e3a58d6674d2f8828e3716bb2802c44955ad391380"},
@ -990,3 +1335,7 @@ urllib3 = [
{file = "urllib3-1.26.9-py2.py3-none-any.whl", hash = "sha256:44ece4d53fb1706f667c9bd1c648f5469a2ec925fcf3a776667042d645472c14"},
{file = "urllib3-1.26.9.tar.gz", hash = "sha256:aabaf16477806a5e1dd19aa41f8c2b7950dd3c746362d7e3223dbe6de6ac448e"},
]
wandb = [
{file = "wandb-0.12.19-py2.py3-none-any.whl", hash = "sha256:ff344e35a31a0d6ccfccaf0390895e4267392f994e395d16b348d7d2b2f67a18"},
{file = "wandb-0.12.19.tar.gz", hash = "sha256:72c54a918e2453ffa661934faf6c3c8b39269423daba0999b9919bcaef3ae320"},
]

View file

@ -12,11 +12,8 @@ from utils.data_loading import BasicDataset
from unet import UNet
from utils.utils import plot_img_and_mask
def predict_img(net,
full_img,
device,
scale_factor=1,
out_threshold=0.5):
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
net.eval()
img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor, is_mask=False))
img = img.unsqueeze(0)
@ -30,11 +27,9 @@ def predict_img(net,
else:
probs = torch.sigmoid(output)[0]
tf = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize((full_img.size[1], full_img.size[0])),
transforms.ToTensor()
])
tf = transforms.Compose(
[transforms.ToPILImage(), transforms.Resize((full_img.size[1], full_img.size[0])), transforms.ToTensor()]
)
full_mask = tf(probs.cpu()).squeeze()
@ -45,19 +40,63 @@ def predict_img(net,
def get_args():
parser = argparse.ArgumentParser(description='Predict masks from input images')
parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE',
help='Specify the file in which the model is stored')
parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True)
parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images')
parser.add_argument('--viz', '-v', action='store_true',
help='Visualize the images as they are processed')
parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
help='Minimum probability value to consider a mask pixel white')
parser.add_argument('--scale', '-s', type=float, default=0.5,
help='Scale factor for the input images')
parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
parser = argparse.ArgumentParser(
description="Predict masks from input images",
)
parser.add_argument(
"--model",
"-m",
default="MODEL.pth",
metavar="FILE",
help="Specify the file in which the model is stored",
)
parser.add_argument(
"--input",
"-i",
metavar="INPUT",
nargs="+",
help="Filenames of input images",
required=True,
)
parser.add_argument(
"--output",
"-o",
metavar="OUTPUT",
nargs="+",
help="Filenames of output images",
)
parser.add_argument(
"--viz",
"-v",
action="store_true",
help="Visualize the images as they are processed",
)
parser.add_argument(
"--no-save",
"-n",
action="store_true",
help="Do not save the output masks",
)
parser.add_argument(
"--mask-threshold",
"-t",
type=float,
default=0.5,
help="Minimum probability value to consider a mask pixel white",
)
parser.add_argument(
"--scale",
"-s",
type=float,
default=0.5,
help="Scale factor for the input images",
)
parser.add_argument(
"--bilinear",
action="store_true",
default=False,
help="Use bilinear upsampling",
)
return parser.parse_args()
@ -65,7 +104,7 @@ def get_args():
def get_output_filenames(args):
def _generate_name(fn):
split = os.path.splitext(fn)
return f'{split[0]}_OUT{split[1]}'
return f"{split[0]}_OUT{split[1]}"
return args.output or list(map(_generate_name, args.input))
@ -77,38 +116,36 @@ def mask_to_image(mask: np.ndarray):
return Image.fromarray((np.argmax(mask, axis=0) * 255 / mask.shape[0]).astype(np.uint8))
if __name__ == '__main__':
if __name__ == "__main__":
args = get_args()
in_files = args.input
out_files = get_output_filenames(args)
net = UNet(n_channels=3, n_classes=2, bilinear=args.bilinear)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Loading model {args.model}')
logging.info(f'Using device {device}')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Loading model {args.model}")
logging.info(f"Using device {device}")
net.to(device=device)
net.load_state_dict(torch.load(args.model, map_location=device))
logging.info('Model loaded!')
logging.info("Model loaded!")
for i, filename in enumerate(in_files):
logging.info(f'\nPredicting image {filename} ...')
logging.info(f"\nPredicting image {filename} ...")
img = Image.open(filename)
mask = predict_img(net=net,
full_img=img,
scale_factor=args.scale,
out_threshold=args.mask_threshold,
device=device)
mask = predict_img(
net=net, full_img=img, scale_factor=args.scale, out_threshold=args.mask_threshold, device=device
)
if not args.no_save:
out_filename = out_files[i]
result = mask_to_image(mask)
result.save(out_filename)
logging.info(f'Mask saved to {out_filename}')
logging.info(f"Mask saved to {out_filename}")
if args.viz:
logging.info(f'Visualizing results for image {filename}, close to continue...')
logging.info(f"Visualizing results for image {filename}, close to continue...")
plot_img_and_mask(img, mask)

View file

@ -13,9 +13,11 @@ albumentations = "^1.2.0"
tqdm = "^4.64.0"
torchvision = "^0.12.0"
matplotlib = "^3.5.2"
wandb = "^0.12.19"
[tool.poetry.dev-dependencies]
black = "^22.3.0"
isort = "^5.10.1"
[build-system]
requires = ["poetry-core>=1.0.0"]

View file

@ -1,27 +0,0 @@
#!/bin/bash
if [[ ! -f ~/.kaggle/kaggle.json ]]; then
echo -n "Kaggle username: "
read USERNAME
echo
echo -n "Kaggle API key: "
read APIKEY
mkdir -p ~/.kaggle
echo "{\"username\":\"$USERNAME\",\"key\":\"$APIKEY\"}" > ~/.kaggle/kaggle.json
chmod 600 ~/.kaggle/kaggle.json
fi
pip install kaggle --upgrade
kaggle competitions download -c carvana-image-masking-challenge -f train_hq.zip
unzip train_hq.zip
mv train_hq/* data/imgs/
rm -d train_hq
rm train_hq.zip
kaggle competitions download -c carvana-image-masking-challenge -f train_masks.zip
unzip train_masks.zip
mv train_masks/* data/masks/
rm -d train_masks
rm train_masks.zip

255
train.py
View file

@ -1,6 +1,5 @@
import argparse
import logging
import sys
from pathlib import Path
import torch
@ -11,25 +10,27 @@ from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss
from evaluate import evaluate
from unet import UNet
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss
dir_img = Path('./data/imgs/')
dir_mask = Path('./data/masks/')
dir_checkpoint = Path('./checkpoints/')
dir_img = Path("./data/imgs/")
dir_mask = Path("./data/masks/")
dir_checkpoint = Path("./checkpoints/")
def train_net(net,
device,
epochs: int = 5,
batch_size: int = 1,
learning_rate: float = 1e-5,
val_percent: float = 0.1,
save_checkpoint: bool = True,
img_scale: float = 0.5,
amp: bool = False):
def train_net(
net,
device,
epochs: int = 5,
batch_size: int = 1,
learning_rate: float = 1e-5,
val_percent: float = 0.1,
save_checkpoint: bool = True,
img_scale: float = 0.5,
amp: bool = False,
):
# 1. Create dataset
try:
dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
@ -47,12 +48,21 @@ def train_net(net,
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
# (Initialize logging)
experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
experiment.config.update(dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale,
amp=amp))
experiment = wandb.init(project="U-Net", resume="allow", anonymous="must")
experiment.config.update(
dict(
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
val_percent=val_percent,
save_checkpoint=save_checkpoint,
img_scale=img_scale,
amp=amp,
)
)
logging.info(f'''Starting training:
logging.info(
f"""Starting training:
Epochs: {epochs}
Batch size: {batch_size}
Learning rate: {learning_rate}
@ -62,38 +72,41 @@ def train_net(net,
Device: {device.type}
Images scaling: {img_scale}
Mixed Precision: {amp}
''')
"""
)
# 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
optimizer = optim.RMSprop(net.parameters(), lr=learning_rate, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2) # goal: maximize Dice score
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "max", patience=2) # goal: maximize Dice score
grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
criterion = nn.CrossEntropyLoss()
global_step = 0
# 5. Begin training
for epoch in range(1, epochs+1):
for epoch in range(1, epochs + 1):
net.train()
epoch_loss = 0
with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
with tqdm(total=n_train, desc=f"Epoch {epoch}/{epochs}", unit="img") as pbar:
for batch in train_loader:
images = batch['image']
true_masks = batch['mask']
images = batch["image"]
true_masks = batch["mask"]
assert images.shape[1] == net.n_channels, \
f'Network has been defined with {net.n_channels} input channels, ' \
f'but loaded images have {images.shape[1]} channels. Please check that ' \
'the images are loaded correctly.'
assert images.shape[1] == net.n_channels, (
f"Network has been defined with {net.n_channels} input channels, "
f"but loaded images have {images.shape[1]} channels. Please check that "
"the images are loaded correctly."
)
images = images.to(device=device, dtype=torch.float32)
true_masks = true_masks.to(device=device, dtype=torch.long)
with torch.cuda.amp.autocast(enabled=amp):
masks_pred = net(images)
loss = criterion(masks_pred, true_masks) \
+ dice_loss(F.softmax(masks_pred, dim=1).float(),
F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
multiclass=True)
loss = criterion(masks_pred, true_masks) + dice_loss(
F.softmax(masks_pred, dim=1).float(),
F.one_hot(true_masks, net.n_classes).permute(0, 3, 1, 2).float(),
multiclass=True,
)
optimizer.zero_grad(set_to_none=True)
grad_scaler.scale(loss).backward()
@ -103,95 +116,157 @@ def train_net(net,
pbar.update(images.shape[0])
global_step += 1
epoch_loss += loss.item()
experiment.log({
'train loss': loss.item(),
'step': global_step,
'epoch': epoch
})
pbar.set_postfix(**{'loss (batch)': loss.item()})
experiment.log({"train loss": loss.item(), "step": global_step, "epoch": epoch})
pbar.set_postfix(**{"loss (batch)": loss.item()})
# Evaluation round
division_step = (n_train // (10 * batch_size))
division_step = n_train // (10 * batch_size)
if division_step > 0:
if global_step % division_step == 0:
histograms = {}
for tag, value in net.named_parameters():
tag = tag.replace('/', '.')
histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())
tag = tag.replace("/", ".")
histograms["Weights/" + tag] = wandb.Histogram(value.data.cpu())
histograms["Gradients/" + tag] = wandb.Histogram(value.grad.data.cpu())
val_score = evaluate(net, val_loader, device)
scheduler.step(val_score)
logging.info('Validation Dice score: {}'.format(val_score))
experiment.log({
'learning rate': optimizer.param_groups[0]['lr'],
'validation Dice': val_score,
'images': wandb.Image(images[0].cpu()),
'masks': {
'true': wandb.Image(true_masks[0].float().cpu()),
'pred': wandb.Image(torch.softmax(masks_pred, dim=1).argmax(dim=1)[0].float().cpu()),
},
'step': global_step,
'epoch': epoch,
**histograms
})
logging.info("Validation Dice score: {}".format(val_score))
experiment.log(
{
"learning rate": optimizer.param_groups[0]["lr"],
"validation Dice": val_score,
"images": wandb.Image(images[0].cpu()),
"masks": {
"true": wandb.Image(true_masks[0].float().cpu()),
"pred": wandb.Image(
torch.softmax(masks_pred, dim=1).argmax(dim=1)[0].float().cpu()
),
},
"step": global_step,
"epoch": epoch,
**histograms,
}
)
if save_checkpoint:
Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
torch.save(net.state_dict(), str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
logging.info(f'Checkpoint {epoch} saved!')
torch.save(net.state_dict(), str(dir_checkpoint / "checkpoint_epoch{}.pth".format(epoch)))
logging.info(f"Checkpoint {epoch} saved!")
def get_args():
parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
help='Learning rate', dest='lr')
parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
help='Percent of the data that is used as validation (0-100)')
parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')
parser = argparse.ArgumentParser(
description="Train the UNet on images and target masks",
)
parser.add_argument(
"--epochs",
"-e",
metavar="E",
type=int,
default=5,
help="Number of epochs",
)
parser.add_argument(
"--batch-size",
"-b",
dest="batch_size",
metavar="B",
type=int,
default=1,
help="Batch size",
)
parser.add_argument(
"--learning-rate",
"-l",
metavar="LR",
type=float,
default=1e-5,
help="Learning rate",
dest="lr",
)
parser.add_argument(
"--load",
"-f",
type=str,
default=False,
help="Load model from a .pth file",
)
parser.add_argument(
"--scale",
"-s",
type=float,
default=0.5,
help="Downscaling factor of the images",
)
parser.add_argument(
"--validation",
"-v",
dest="val",
type=float,
default=10.0,
help="Percent of the data that is used as validation (0-100)",
)
parser.add_argument(
"--amp",
action="store_true",
default=False,
help="Use mixed precision",
)
parser.add_argument(
"--bilinear",
action="store_true",
default=False,
help="Use bilinear upsampling",
)
parser.add_argument(
"--classes",
"-c",
type=int,
default=2,
help="Number of classes",
)
return parser.parse_args()
if __name__ == '__main__':
if __name__ == "__main__":
args = get_args()
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logging.info(f'Using device {device}')
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device {device}")
# Change here to adapt to your data
# n_channels=3 for RGB images
# n_classes is the number of probabilities you want to get per pixel
net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
logging.info(f'Network:\n'
f'\t{net.n_channels} input channels\n'
f'\t{net.n_classes} output channels (classes)\n'
f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')
logging.info(
f"Network:\n"
f"\t{net.n_channels} input channels\n"
f"\t{net.n_classes} output channels (classes)\n"
f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling'
)
if args.load:
net.load_state_dict(torch.load(args.load, map_location=device))
logging.info(f'Model loaded from {args.load}')
logging.info(f"Model loaded from {args.load}")
net.to(device=device)
try:
train_net(net=net,
epochs=args.epochs,
batch_size=args.batch_size,
learning_rate=args.lr,
device=device,
img_scale=args.scale,
val_percent=args.val / 100,
amp=args.amp)
train_net(
net=net,
epochs=args.epochs,
batch_size=args.batch_size,
learning_rate=args.lr,
device=device,
img_scale=args.scale,
val_percent=args.val / 100,
amp=args.amp,
)
except KeyboardInterrupt:
torch.save(net.state_dict(), 'INTERRUPTED.pth')
logging.info('Saved interrupt')
torch.save(net.state_dict(), "INTERRUPTED.pth")
logging.info("Saved interrupt")
raise

View file

@ -18,7 +18,7 @@ class DoubleConv(nn.Module):
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
nn.ReLU(inplace=True),
)
def forward(self, x):
@ -30,10 +30,7 @@ class Down(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), DoubleConv(in_channels, out_channels))
def forward(self, x):
return self.maxpool_conv(x)
@ -47,7 +44,7 @@ class Up(nn.Module):
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
@ -59,8 +56,7 @@ class Up(nn.Module):
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd

View file

@ -10,17 +10,17 @@ from torch.utils.data import Dataset
class BasicDataset(Dataset):
def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ''):
def __init__(self, images_dir: str, masks_dir: str, scale: float = 1.0, mask_suffix: str = ""):
self.images_dir = Path(images_dir)
self.masks_dir = Path(masks_dir)
assert 0 < scale <= 1, 'Scale must be between 0 and 1'
assert 0 < scale <= 1, "Scale must be between 0 and 1"
self.scale = scale
self.mask_suffix = mask_suffix
self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith('.')]
self.ids = [splitext(file)[0] for file in listdir(images_dir) if not file.startswith(".")]
if not self.ids:
raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')
logging.info(f'Creating dataset with {len(self.ids)} examples')
raise RuntimeError(f"No input file found in {images_dir}, make sure you put your images there")
logging.info(f"Creating dataset with {len(self.ids)} examples")
def __len__(self):
return len(self.ids)
@ -29,7 +29,7 @@ class BasicDataset(Dataset):
def preprocess(pil_img, scale, is_mask):
w, h = pil_img.size
newW, newH = int(scale * w), int(scale * h)
assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
assert newW > 0 and newH > 0, "Scale is too small, resized images would have no pixel"
pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
img_ndarray = np.asarray(pil_img)
@ -46,35 +46,36 @@ class BasicDataset(Dataset):
@staticmethod
def load(filename):
ext = splitext(filename)[1]
if ext in ['.npz', '.npy']:
if ext in [".npz", ".npy"]:
return Image.fromarray(np.load(filename))
elif ext in ['.pt', '.pth']:
elif ext in [".pt", ".pth"]:
return Image.fromarray(torch.load(filename).numpy())
else:
return Image.open(filename)
def __getitem__(self, idx):
name = self.ids[idx]
mask_file = list(self.masks_dir.glob(name + self.mask_suffix + '.*'))
img_file = list(self.images_dir.glob(name + '.*'))
mask_file = list(self.masks_dir.glob(name + self.mask_suffix + ".*"))
img_file = list(self.images_dir.glob(name + ".*"))
assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
assert len(img_file) == 1, f"Either no image or multiple images found for the ID {name}: {img_file}"
assert len(mask_file) == 1, f"Either no mask or multiple masks found for the ID {name}: {mask_file}"
mask = self.load(mask_file[0])
img = self.load(img_file[0])
assert img.size == mask.size, \
f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'
assert (
img.size == mask.size
), f"Image and mask {name} should be the same size, but are {img.size} and {mask.size}"
img = self.preprocess(img, self.scale, is_mask=False)
mask = self.preprocess(mask, self.scale, is_mask=True)
return {
'image': torch.as_tensor(img.copy()).float().contiguous(),
'mask': torch.as_tensor(mask.copy()).long().contiguous()
"image": torch.as_tensor(img.copy()).float().contiguous(),
"mask": torch.as_tensor(mask.copy()).long().contiguous(),
}
class CarvanaDataset(BasicDataset):
def __init__(self, images_dir, masks_dir, scale=1):
super().__init__(images_dir, masks_dir, scale, mask_suffix='_mask')
super().__init__(images_dir, masks_dir, scale, mask_suffix="_mask")

View file

@ -6,7 +6,7 @@ def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False,
# Average of Dice coefficient for all batches, or for a single mask
assert input.size() == target.size()
if input.dim() == 2 and reduce_batch_first:
raise ValueError(f'Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})')
raise ValueError(f"Dice: asked to reduce batch but got tensor without batch dimension (shape {input.shape})")
if input.dim() == 2 or reduce_batch_first:
inter = torch.dot(input.reshape(-1), target.reshape(-1))

View file

@ -1,17 +1,17 @@
import matplotlib.pyplot as plt
def plot_img_and_mask(img, mask):
classes = mask.shape[0] if len(mask.shape) > 2 else 1
fig, ax = plt.subplots(1, classes + 1)
ax[0].set_title('Input image')
ax[0].imshow(img)
if classes > 1:
for i in range(classes):
ax[i + 1].set_title(f'Output mask (class {i + 1})')
ax[i + 1].imshow(mask[1, :, :])
else:
ax[1].set_title(f'Output mask')
ax[1].imshow(mask)
plt.xticks([]), plt.yticks([])
plt.show()
import matplotlib.pyplot as plt
def plot_img_and_mask(img, mask):
classes = mask.shape[0] if len(mask.shape) > 2 else 1
fig, ax = plt.subplots(1, classes + 1)
ax[0].set_title("Input image")
ax[0].imshow(img)
if classes > 1:
for i in range(classes):
ax[i + 1].set_title(f"Output mask (class {i + 1})")
ax[i + 1].imshow(mask[1, :, :])
else:
ax[1].set_title(f"Output mask")
ax[1].imshow(mask)
plt.xticks([]), plt.yticks([])
plt.show()