(training_utils) add NeptuneCallback

Laurent 2024-05-28 14:24:16 +00:00 committed by Laureηt
parent 3a7f14e4dc
commit 3ad7f592db
3 changed files with 205 additions and 0 deletions

@@ -33,6 +33,7 @@ training = [
"torchvision>=0.16.1",
"loguru>=0.7.2",
"wandb>=0.16.0",
"neptune>=1.10.4",
"datasets>=2.15.0",
"tomli>=2.0.1",
"gitpython>=3.1.43",

@@ -17,16 +17,29 @@ annotated-types==0.6.0
# via pydantic
appdirs==1.4.4
# via wandb
arrow==1.3.0
# via isoduration
async-timeout==4.0.3
# via aiohttp
attrs==23.2.0
# via aiohttp
# via jsonschema
# via referencing
babel==2.14.0
# via mkdocs-material
bitsandbytes==0.43.0
# via refiners
black==24.3.0
# via refiners
boto3==1.34.113
# via neptune
botocore==1.34.113
# via boto3
# via s3transfer
bravado==11.0.3
# via neptune
bravado-core==6.1.1
# via bravado
certifi==2024.2.2
# via requests
# via sentry-sdk
@@ -36,6 +49,7 @@ click==8.1.7
# via black
# via mkdocs
# via mkdocstrings
# via neptune
# via wandb
colorama==0.4.6
# via griffe
@@ -56,6 +70,8 @@ filelock==3.13.3
# via torch
# via transformers
# via triton
fqdn==1.5.1
# via jsonschema
frozenlist==1.4.1
# via aiohttp
# via aiosignal
@@ -63,11 +79,14 @@ fsspec==2024.2.0
# via datasets
# via huggingface-hub
# via torch
future==1.0.0
# via neptune
ghp-import==2.1.0
# via mkdocs
gitdb==4.0.11
# via gitpython
gitpython==3.1.43
# via neptune
# via refiners
# via wandb
griffe==0.42.1
@@ -79,10 +98,13 @@ huggingface-hub==0.22.2
# via tokenizers
# via transformers
idna==3.6
# via jsonschema
# via requests
# via yarl
importlib-metadata==7.1.0
# via diffusers
isoduration==20.11.0
# via jsonschema
jaxtyping==0.2.28
# via refiners
jinja2==3.1.3
@@ -90,6 +112,18 @@ jinja2==3.1.3
# via mkdocs-material
# via mkdocstrings
# via torch
jmespath==1.0.1
# via boto3
# via botocore
jsonpointer==2.4
# via jsonschema
jsonref==1.1.0
# via bravado-core
jsonschema==4.22.0
# via bravado-core
# via swagger-spec-validator
jsonschema-specifications==2023.12.1
# via jsonschema
loguru==0.7.2
# via refiners
markdown==3.6
@@ -123,8 +157,13 @@ mkdocstrings==0.24.1
# via refiners
mkdocstrings-python==1.8.0
# via mkdocstrings
monotonic==1.6
# via bravado
mpmath==1.3.0
# via sympy
msgpack==1.0.8
# via bravado
# via bravado-core
multidict==6.0.5
# via aiohttp
# via yarl
@@ -132,6 +171,8 @@ multiprocess==0.70.16
# via datasets
mypy-extensions==1.0.0
# via black
neptune==1.10.4
# via refiners
networkx==3.2.1
# via torch
numpy==1.26.4
@@ -172,22 +213,28 @@ nvidia-nvjitlink-cu12==12.4.99
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
oauthlib==3.2.2
# via neptune
# via requests-oauthlib
packaging==24.0
# via black
# via datasets
# via huggingface-hub
# via mkdocs
# via neptune
# via refiners
# via transformers
paginate==0.5.6
# via mkdocs-material
pandas==2.2.1
# via datasets
# via neptune
pathspec==0.12.1
# via black
# via mkdocs
pillow==10.3.0
# via diffusers
# via neptune
# via refiners
# via torchvision
piq==0.8.0
@@ -201,6 +248,7 @@ prodigyopt==1.0
protobuf==4.25.3
# via wandb
psutil==5.9.8
# via neptune
# via wandb
pyarrow==15.0.2
# via datasets
@@ -212,37 +260,65 @@ pydantic-core==2.16.3
# via pydantic
pygments==2.17.2
# via mkdocs-material
pyjwt==2.8.0
# via neptune
pymdown-extensions==10.7.1
# via mkdocs-material
# via mkdocstrings
python-dateutil==2.9.0.post0
# via arrow
# via botocore
# via bravado
# via bravado-core
# via ghp-import
# via pandas
pytz==2024.1
# via bravado-core
# via pandas
pyyaml==6.0.1
# via bravado
# via bravado-core
# via datasets
# via huggingface-hub
# via mkdocs
# via pymdown-extensions
# via pyyaml-env-tag
# via swagger-spec-validator
# via timm
# via transformers
# via wandb
pyyaml-env-tag==0.1
# via mkdocs
referencing==0.35.1
# via jsonschema
# via jsonschema-specifications
regex==2023.12.25
# via diffusers
# via mkdocs-material
# via transformers
requests==2.31.0
# via bravado
# via bravado-core
# via datasets
# via diffusers
# via huggingface-hub
# via mkdocs-material
# via neptune
# via refiners
# via requests-oauthlib
# via transformers
# via wandb
requests-oauthlib==2.0.0
# via neptune
rfc3339-validator==0.1.4
# via jsonschema
rfc3986-validator==0.1.1
# via jsonschema
rpds-py==0.18.1
# via jsonschema
# via referencing
s3transfer==0.10.1
# via boto3
safetensors==0.4.2
# via diffusers
# via refiners
@@ -258,11 +334,21 @@ setproctitle==1.3.3
# via wandb
setuptools==69.2.0
# via wandb
simplejson==3.19.2
# via bravado
# via bravado-core
six==1.16.0
# via bravado
# via bravado-core
# via docker-pycreds
# via neptune
# via python-dateutil
# via rfc3339-validator
smmap==5.0.1
# via gitdb
swagger-spec-validator==3.0.3
# via bravado-core
# via neptune
sympy==1.12
# via torch
timm==0.9.16
@@ -296,21 +382,34 @@ triton==2.2.0
# via torch
typeguard==2.13.3
# via jaxtyping
types-python-dateutil==2.9.0.20240316
# via arrow
typing-extensions==4.10.0
# via black
# via bravado
# via huggingface-hub
# via neptune
# via pydantic
# via pydantic-core
# via swagger-spec-validator
# via torch
tzdata==2024.1
# via pandas
uri-template==1.3.0
# via jsonschema
urllib3==2.2.1
# via botocore
# via neptune
# via requests
# via sentry-sdk
wandb==0.16.5
# via refiners
watchdog==4.0.0
# via mkdocs
webcolors==1.13
# via jsonschema
websocket-client==1.8.0
# via neptune
xxhash==3.4.1
# via datasets
yarl==1.9.4

@@ -0,0 +1,105 @@
from abc import ABC
from os import PathLike
from typing import Any, Literal

from neptune import Run, init_run  # type: ignore
from neptune.internal.init.parameters import (  # type: ignore
    ASYNC_LAG_THRESHOLD,
    ASYNC_NO_PROGRESS_THRESHOLD,
    DEFAULT_FLUSH_PERIOD,
)
from neptune.metadata_containers.abstract import NeptuneObjectCallback  # type: ignore
from neptune.types.atoms.git_ref import GitRef, GitRefDisabled  # type: ignore

from refiners.training_utils.callback import Callback, CallbackConfig
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer, register_callback

AnyTrainer = Trainer[BaseConfig, Any]


class NeptuneConfig(CallbackConfig):
    """Neptune.ai run configuration

    See https://docs.neptune.ai/api/neptune#init_run
    and https://github.com/neptune-ai/neptune-client/blob/1cd8452045e8524318f59216d151d73328f85bd1/src/neptune/objects/run.py#L131
    """

    project: str | None = None
    api_token: str | None = None
    with_id: str | None = None
    custom_run_id: str | None = None
    mode: Literal["async", "sync", "offline", "read-only", "debug"] | None = None
    name: str | None = None
    description: str | None = None
    tags: str | list[str] | None = None
    source_files: str | list[str] | None = None
    capture_stdout: bool | None = None
    capture_stderr: bool | None = None
    capture_hardware_metrics: bool | None = None
    fail_on_exception: bool = True
    monitoring_namespace: str | None = None
    flush_period: float = DEFAULT_FLUSH_PERIOD
    proxies: dict[str, str] | None = None
    capture_traceback: bool = True
    git_ref: GitRef | GitRefDisabled | None = None
    dependencies: PathLike[str] | str | None = None
    async_lag_callback: NeptuneObjectCallback | None = None
    async_no_progress_callback: NeptuneObjectCallback | None = None
    async_lag_threshold: float = ASYNC_LAG_THRESHOLD
    async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD


class NeptuneCallback(Callback[AnyTrainer]):
    """Neptune.ai callback for logging metrics"""

    run: Run

    def __init__(self, config: NeptuneConfig) -> None:
        """Initialize Neptune.ai callback

        Args:
            config: Neptune.ai run configuration
        """
        self.config = config
        self.epoch_losses: list[float] = []
        self.iteration_losses: list[float] = []

    def on_train_begin(self, trainer: AnyTrainer) -> None:
        # initialize Neptune `Run` (see https://docs.neptune.ai/api/run/)
        self.run = init_run(**self.config.model_dump())
        self.run["config"] = trainer.config.model_dump()

        # reset epoch and iteration losses
        self.epoch_losses = []
        self.iteration_losses = []

    def on_compute_loss_end(self, trainer: AnyTrainer) -> None:
        loss_value = trainer.loss.detach().cpu().item()
        self.epoch_losses.append(loss_value)
        self.iteration_losses.append(loss_value)
        self.run["train/step_loss"].append(loss_value, step=trainer.clock.step)  # type: ignore

    def on_optimizer_step_end(self, trainer: AnyTrainer) -> None:
        self.run["train/total_grad_norm"].append(trainer.grad_norm, step=trainer.clock.step)  # type: ignore
        avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
        self.run["train/average_iteration_loss"].append(avg_iteration_loss, step=trainer.clock.step)  # type: ignore
        self.iteration_losses = []

    def on_epoch_end(self, trainer: AnyTrainer) -> None:
        avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses)
        self.run["train/average_epoch_loss"].append(avg_epoch_loss, step=trainer.clock.step)  # type: ignore
        self.run["train/epoch"].append(trainer.clock.epoch, step=trainer.clock.step)  # type: ignore
        self.epoch_losses = []

    def on_lr_scheduler_step_end(self, trainer: AnyTrainer) -> None:
        self.run["train/learning_rate"].append(trainer.optimizer.param_groups[0]["lr"], step=trainer.clock.step)  # type: ignore

    def on_train_end(self, trainer: AnyTrainer) -> None:
        self.run.stop()


class NeptuneMixin(ABC):
    @register_callback()
    def neptune(self, config: NeptuneConfig) -> NeptuneCallback:
        return NeptuneCallback(config)
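
A minimal usage sketch (not part of this commit) of how the mixin is meant to be wired in: the trainer class inherits NeptuneMixin, and its config exposes a `neptune` field matching the callback name registered above. The module path `refiners.training_utils.neptune_callback`, the names `FinetuneConfig` and `FinetuneTrainer`, the `finetune.toml` file, and the `load_from_toml` loader are assumptions, not confirmed by this diff.

from typing import Any

from refiners.training_utils.config import BaseConfig
from refiners.training_utils.neptune_callback import NeptuneConfig, NeptuneMixin  # hypothetical module path
from refiners.training_utils.trainer import Trainer


class FinetuneConfig(BaseConfig):  # hypothetical config
    # field name is expected to match the `neptune` callback registered by NeptuneMixin
    neptune: NeptuneConfig


class FinetuneTrainer(Trainer[FinetuneConfig, Any], NeptuneMixin):  # hypothetical trainer
    ...  # a concrete trainer also defines its models and loss, elided here


config = FinetuneConfig.load_from_toml("finetune.toml")  # assumed TOML loader on BaseConfig
trainer = FinetuneTrainer(config=config)
trainer.train()  # NeptuneCallback opens the run on train begin and stops it on train end

Since every Neptune field in NeptuneConfig defaults to None, `init_run` falls back to the `NEPTUNE_PROJECT` and `NEPTUNE_API_TOKEN` environment variables when project and API token are not set in the config.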