(training_utils) add NeptuneCallback

This commit is contained in:
Laurent 2024-05-28 14:24:16 +00:00 committed by Laureηt
parent 3a7f14e4dc
commit 3ad7f592db
3 changed files with 205 additions and 0 deletions

View file

@ -33,6 +33,7 @@ training = [
"torchvision>=0.16.1", "torchvision>=0.16.1",
"loguru>=0.7.2", "loguru>=0.7.2",
"wandb>=0.16.0", "wandb>=0.16.0",
"neptune>=1.10.4",
"datasets>=2.15.0", "datasets>=2.15.0",
"tomli>=2.0.1", "tomli>=2.0.1",
"gitpython>=3.1.43", "gitpython>=3.1.43",

View file

@ -17,16 +17,29 @@ annotated-types==0.6.0
# via pydantic # via pydantic
appdirs==1.4.4 appdirs==1.4.4
# via wandb # via wandb
arrow==1.3.0
# via isoduration
async-timeout==4.0.3 async-timeout==4.0.3
# via aiohttp # via aiohttp
attrs==23.2.0 attrs==23.2.0
# via aiohttp # via aiohttp
# via jsonschema
# via referencing
babel==2.14.0 babel==2.14.0
# via mkdocs-material # via mkdocs-material
bitsandbytes==0.43.0 bitsandbytes==0.43.0
# via refiners # via refiners
black==24.3.0 black==24.3.0
# via refiners # via refiners
boto3==1.34.113
# via neptune
botocore==1.34.113
# via boto3
# via s3transfer
bravado==11.0.3
# via neptune
bravado-core==6.1.1
# via bravado
certifi==2024.2.2 certifi==2024.2.2
# via requests # via requests
# via sentry-sdk # via sentry-sdk
@ -36,6 +49,7 @@ click==8.1.7
# via black # via black
# via mkdocs # via mkdocs
# via mkdocstrings # via mkdocstrings
# via neptune
# via wandb # via wandb
colorama==0.4.6 colorama==0.4.6
# via griffe # via griffe
@ -56,6 +70,8 @@ filelock==3.13.3
# via torch # via torch
# via transformers # via transformers
# via triton # via triton
fqdn==1.5.1
# via jsonschema
frozenlist==1.4.1 frozenlist==1.4.1
# via aiohttp # via aiohttp
# via aiosignal # via aiosignal
@ -63,11 +79,14 @@ fsspec==2024.2.0
# via datasets # via datasets
# via huggingface-hub # via huggingface-hub
# via torch # via torch
future==1.0.0
# via neptune
ghp-import==2.1.0 ghp-import==2.1.0
# via mkdocs # via mkdocs
gitdb==4.0.11 gitdb==4.0.11
# via gitpython # via gitpython
gitpython==3.1.43 gitpython==3.1.43
# via neptune
# via refiners # via refiners
# via wandb # via wandb
griffe==0.42.1 griffe==0.42.1
@ -79,10 +98,13 @@ huggingface-hub==0.22.2
# via tokenizers # via tokenizers
# via transformers # via transformers
idna==3.6 idna==3.6
# via jsonschema
# via requests # via requests
# via yarl # via yarl
importlib-metadata==7.1.0 importlib-metadata==7.1.0
# via diffusers # via diffusers
isoduration==20.11.0
# via jsonschema
jaxtyping==0.2.28 jaxtyping==0.2.28
# via refiners # via refiners
jinja2==3.1.3 jinja2==3.1.3
@ -90,6 +112,18 @@ jinja2==3.1.3
# via mkdocs-material # via mkdocs-material
# via mkdocstrings # via mkdocstrings
# via torch # via torch
jmespath==1.0.1
# via boto3
# via botocore
jsonpointer==2.4
# via jsonschema
jsonref==1.1.0
# via bravado-core
jsonschema==4.22.0
# via bravado-core
# via swagger-spec-validator
jsonschema-specifications==2023.12.1
# via jsonschema
loguru==0.7.2 loguru==0.7.2
# via refiners # via refiners
markdown==3.6 markdown==3.6
@ -123,8 +157,13 @@ mkdocstrings==0.24.1
# via refiners # via refiners
mkdocstrings-python==1.8.0 mkdocstrings-python==1.8.0
# via mkdocstrings # via mkdocstrings
monotonic==1.6
# via bravado
mpmath==1.3.0 mpmath==1.3.0
# via sympy # via sympy
msgpack==1.0.8
# via bravado
# via bravado-core
multidict==6.0.5 multidict==6.0.5
# via aiohttp # via aiohttp
# via yarl # via yarl
@ -132,6 +171,8 @@ multiprocess==0.70.16
# via datasets # via datasets
mypy-extensions==1.0.0 mypy-extensions==1.0.0
# via black # via black
neptune==1.10.4
# via refiners
networkx==3.2.1 networkx==3.2.1
# via torch # via torch
numpy==1.26.4 numpy==1.26.4
@ -172,22 +213,28 @@ nvidia-nvjitlink-cu12==12.4.99
# via nvidia-cusparse-cu12 # via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105 nvidia-nvtx-cu12==12.1.105
# via torch # via torch
oauthlib==3.2.2
# via neptune
# via requests-oauthlib
packaging==24.0 packaging==24.0
# via black # via black
# via datasets # via datasets
# via huggingface-hub # via huggingface-hub
# via mkdocs # via mkdocs
# via neptune
# via refiners # via refiners
# via transformers # via transformers
paginate==0.5.6 paginate==0.5.6
# via mkdocs-material # via mkdocs-material
pandas==2.2.1 pandas==2.2.1
# via datasets # via datasets
# via neptune
pathspec==0.12.1 pathspec==0.12.1
# via black # via black
# via mkdocs # via mkdocs
pillow==10.3.0 pillow==10.3.0
# via diffusers # via diffusers
# via neptune
# via refiners # via refiners
# via torchvision # via torchvision
piq==0.8.0 piq==0.8.0
@ -201,6 +248,7 @@ prodigyopt==1.0
protobuf==4.25.3 protobuf==4.25.3
# via wandb # via wandb
psutil==5.9.8 psutil==5.9.8
# via neptune
# via wandb # via wandb
pyarrow==15.0.2 pyarrow==15.0.2
# via datasets # via datasets
@ -212,37 +260,65 @@ pydantic-core==2.16.3
# via pydantic # via pydantic
pygments==2.17.2 pygments==2.17.2
# via mkdocs-material # via mkdocs-material
pyjwt==2.8.0
# via neptune
pymdown-extensions==10.7.1 pymdown-extensions==10.7.1
# via mkdocs-material # via mkdocs-material
# via mkdocstrings # via mkdocstrings
python-dateutil==2.9.0.post0 python-dateutil==2.9.0.post0
# via arrow
# via botocore
# via bravado
# via bravado-core
# via ghp-import # via ghp-import
# via pandas # via pandas
pytz==2024.1 pytz==2024.1
# via bravado-core
# via pandas # via pandas
pyyaml==6.0.1 pyyaml==6.0.1
# via bravado
# via bravado-core
# via datasets # via datasets
# via huggingface-hub # via huggingface-hub
# via mkdocs # via mkdocs
# via pymdown-extensions # via pymdown-extensions
# via pyyaml-env-tag # via pyyaml-env-tag
# via swagger-spec-validator
# via timm # via timm
# via transformers # via transformers
# via wandb # via wandb
pyyaml-env-tag==0.1 pyyaml-env-tag==0.1
# via mkdocs # via mkdocs
referencing==0.35.1
# via jsonschema
# via jsonschema-specifications
regex==2023.12.25 regex==2023.12.25
# via diffusers # via diffusers
# via mkdocs-material # via mkdocs-material
# via transformers # via transformers
requests==2.31.0 requests==2.31.0
# via bravado
# via bravado-core
# via datasets # via datasets
# via diffusers # via diffusers
# via huggingface-hub # via huggingface-hub
# via mkdocs-material # via mkdocs-material
# via neptune
# via refiners # via refiners
# via requests-oauthlib
# via transformers # via transformers
# via wandb # via wandb
requests-oauthlib==2.0.0
# via neptune
rfc3339-validator==0.1.4
# via jsonschema
rfc3986-validator==0.1.1
# via jsonschema
rpds-py==0.18.1
# via jsonschema
# via referencing
s3transfer==0.10.1
# via boto3
safetensors==0.4.2 safetensors==0.4.2
# via diffusers # via diffusers
# via refiners # via refiners
@ -258,11 +334,21 @@ setproctitle==1.3.3
# via wandb # via wandb
setuptools==69.2.0 setuptools==69.2.0
# via wandb # via wandb
simplejson==3.19.2
# via bravado
# via bravado-core
six==1.16.0 six==1.16.0
# via bravado
# via bravado-core
# via docker-pycreds # via docker-pycreds
# via neptune
# via python-dateutil # via python-dateutil
# via rfc3339-validator
smmap==5.0.1 smmap==5.0.1
# via gitdb # via gitdb
swagger-spec-validator==3.0.3
# via bravado-core
# via neptune
sympy==1.12 sympy==1.12
# via torch # via torch
timm==0.9.16 timm==0.9.16
@ -296,21 +382,34 @@ triton==2.2.0
# via torch # via torch
typeguard==2.13.3 typeguard==2.13.3
# via jaxtyping # via jaxtyping
types-python-dateutil==2.9.0.20240316
# via arrow
typing-extensions==4.10.0 typing-extensions==4.10.0
# via black # via black
# via bravado
# via huggingface-hub # via huggingface-hub
# via neptune
# via pydantic # via pydantic
# via pydantic-core # via pydantic-core
# via swagger-spec-validator
# via torch # via torch
tzdata==2024.1 tzdata==2024.1
# via pandas # via pandas
uri-template==1.3.0
# via jsonschema
urllib3==2.2.1 urllib3==2.2.1
# via botocore
# via neptune
# via requests # via requests
# via sentry-sdk # via sentry-sdk
wandb==0.16.5 wandb==0.16.5
# via refiners # via refiners
watchdog==4.0.0 watchdog==4.0.0
# via mkdocs # via mkdocs
webcolors==1.13
# via jsonschema
websocket-client==1.8.0
# via neptune
xxhash==3.4.1 xxhash==3.4.1
# via datasets # via datasets
yarl==1.9.4 yarl==1.9.4

View file

@ -0,0 +1,105 @@
from abc import ABC
from os import PathLike
from typing import Any, Literal
from neptune import Run, init_run # type: ignore
from neptune.internal.init.parameters import ( # type: ignore
ASYNC_LAG_THRESHOLD,
ASYNC_NO_PROGRESS_THRESHOLD,
DEFAULT_FLUSH_PERIOD,
)
from neptune.metadata_containers.abstract import NeptuneObjectCallback # type: ignore
from neptune.types.atoms.git_ref import GitRef, GitRefDisabled # type: ignore
from refiners.training_utils.callback import Callback, CallbackConfig
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer, register_callback
# Trainer specialized over any config — the callback below only uses generic hooks.
AnyTrainer = Trainer[BaseConfig, Any]
class NeptuneConfig(CallbackConfig):
    """Neptune.ai run configuration.

    Each field mirrors a keyword argument of `neptune.init_run`; the whole
    model is forwarded as `init_run(**config.model_dump())` by `NeptuneCallback`.

    See https://docs.neptune.ai/api/neptune#init_run
    and https://github.com/neptune-ai/neptune-client/blob/1cd8452045e8524318f59216d151d73328f85bd1/src/neptune/objects/run.py#L131
    """
    # Target project, e.g. "workspace/project"; None defers to Neptune's defaults.
    project: str | None = None
    # API token; presumably Neptune falls back to its env var when None — confirm against init_run docs.
    api_token: str | None = None
    # Resume an existing run by its id.
    with_id: str | None = None
    custom_run_id: str | None = None
    # Connectivity mode of the run.
    mode: Literal["async", "sync", "offline", "read-only", "debug"] | None = None
    name: str | None = None
    description: str | None = None
    tags: str | list[str] | None = None
    # Source file paths (or glob patterns) to snapshot with the run.
    source_files: str | list[str] | None = None
    capture_stdout: bool | None = None
    capture_stderr: bool | None = None
    capture_hardware_metrics: bool | None = None
    fail_on_exception: bool = True
    monitoring_namespace: str | None = None
    # Async flush period, in seconds (Neptune's own default).
    flush_period: float = DEFAULT_FLUSH_PERIOD
    proxies: dict[str, str] | None = None
    capture_traceback: bool = True
    # Git repository tracking; GitRefDisabled turns it off.
    git_ref: GitRef | GitRefDisabled | None = None
    # Path to a dependency file (e.g. requirements.txt) to upload.
    dependencies: PathLike[str] | str | None = None
    # Hooks fired when the async queue lags or stops progressing.
    async_lag_callback: NeptuneObjectCallback | None = None
    async_no_progress_callback: NeptuneObjectCallback | None = None
    async_lag_threshold: float = ASYNC_LAG_THRESHOLD
    async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD
class NeptuneCallback(Callback[AnyTrainer]):
    """Neptune.ai callback for logging training metrics.

    Logs per-step loss, total gradient norm, average iteration/epoch losses,
    the learning rate and the epoch counter to a Neptune `Run`.
    """

    # Neptune run handle; created lazily in `on_train_begin`.
    run: Run

    def __init__(self, config: NeptuneConfig) -> None:
        """Initialize Neptune.ai callback.

        Args:
            config: Neptune.ai run configuration.
        """
        self.config = config
        # Losses accumulated since the last epoch end / optimizer step, respectively.
        self.epoch_losses: list[float] = []
        self.iteration_losses: list[float] = []

    def on_train_begin(self, trainer: AnyTrainer) -> None:
        # initialize Neptune `Run` (see https://docs.neptune.ai/api/run/)
        self.run = init_run(**self.config.model_dump())
        # log the full trainer configuration alongside the run
        self.run["config"] = trainer.config.model_dump()
        # reset epoch and iteration losses
        self.epoch_losses = []
        self.iteration_losses = []

    def on_compute_loss_end(self, trainer: AnyTrainer) -> None:
        loss_value = trainer.loss.detach().cpu().item()
        self.epoch_losses.append(loss_value)
        self.iteration_losses.append(loss_value)
        self.run["train/step_loss"].append(loss_value, step=trainer.clock.step)  # type: ignore

    def on_optimizer_step_end(self, trainer: AnyTrainer) -> None:
        self.run["train/total_grad_norm"].append(trainer.grad_norm, step=trainer.clock.step)  # type: ignore
        # Guard against an optimizer step with no recorded losses (would raise
        # ZeroDivisionError, e.g. if the hook fires before any loss computation).
        if self.iteration_losses:
            avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
            self.run["train/average_iteration_loss"].append(avg_iteration_loss, step=trainer.clock.step)  # type: ignore
        self.iteration_losses = []

    def on_epoch_end(self, trainer: AnyTrainer) -> None:
        # Same empty-list guard as above: skip the average when no loss was recorded.
        if self.epoch_losses:
            avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses)
            self.run["train/average_epoch_loss"].append(avg_epoch_loss, step=trainer.clock.step)  # type: ignore
        self.run["train/epoch"].append(trainer.clock.epoch, step=trainer.clock.step)  # type: ignore
        self.epoch_losses = []

    def on_lr_scheduler_step_end(self, trainer: AnyTrainer) -> None:
        # log the first param group's learning rate (single-group assumption — TODO confirm)
        self.run["train/learning_rate"].append(trainer.optimizer.param_groups[0]["lr"], step=trainer.clock.step)  # type: ignore

    def on_train_end(self, trainer: AnyTrainer) -> None:
        # flush pending operations and close the run
        self.run.stop()
class NeptuneMixin(ABC):
    """Mixin registering a Neptune callback on the trainer."""

    @register_callback()
    def neptune(self, config: NeptuneConfig) -> NeptuneCallback:
        """Build the Neptune callback from its configuration."""
        callback = NeptuneCallback(config)
        return callback