From 3ad7f592dbde9a3f1e322ccb34e7f5065ef61530 Mon Sep 17 00:00:00 2001 From: Laurent Date: Tue, 28 May 2024 14:24:16 +0000 Subject: [PATCH] (training_utils) add NeptuneCallback --- pyproject.toml | 1 + requirements.lock | 99 +++++++++++++++++++++++ src/refiners/training_utils/neptune.py | 105 +++++++++++++++++++++++++ 3 files changed, 205 insertions(+) create mode 100644 src/refiners/training_utils/neptune.py diff --git a/pyproject.toml b/pyproject.toml index baf81b3..362c143 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ training = [ "torchvision>=0.16.1", "loguru>=0.7.2", "wandb>=0.16.0", + "neptune>=1.10.4", "datasets>=2.15.0", "tomli>=2.0.1", "gitpython>=3.1.43", diff --git a/requirements.lock b/requirements.lock index 967219b..ebeee29 100644 --- a/requirements.lock +++ b/requirements.lock @@ -17,16 +17,29 @@ annotated-types==0.6.0 # via pydantic appdirs==1.4.4 # via wandb +arrow==1.3.0 + # via isoduration async-timeout==4.0.3 # via aiohttp attrs==23.2.0 # via aiohttp + # via jsonschema + # via referencing babel==2.14.0 # via mkdocs-material bitsandbytes==0.43.0 # via refiners black==24.3.0 # via refiners +boto3==1.34.113 + # via neptune +botocore==1.34.113 + # via boto3 + # via s3transfer +bravado==11.0.3 + # via neptune +bravado-core==6.1.1 + # via bravado certifi==2024.2.2 # via requests # via sentry-sdk @@ -36,6 +49,7 @@ click==8.1.7 # via black # via mkdocs # via mkdocstrings + # via neptune # via wandb colorama==0.4.6 # via griffe @@ -56,6 +70,8 @@ filelock==3.13.3 # via torch # via transformers # via triton +fqdn==1.5.1 + # via jsonschema frozenlist==1.4.1 # via aiohttp # via aiosignal @@ -63,11 +79,14 @@ fsspec==2024.2.0 # via datasets # via huggingface-hub # via torch +future==1.0.0 + # via neptune ghp-import==2.1.0 # via mkdocs gitdb==4.0.11 # via gitpython gitpython==3.1.43 + # via neptune # via refiners # via wandb griffe==0.42.1 @@ -79,10 +98,13 @@ huggingface-hub==0.22.2 # via tokenizers # via transformers idna==3.6 + # via 
jsonschema # via requests # via yarl importlib-metadata==7.1.0 # via diffusers +isoduration==20.11.0 + # via jsonschema jaxtyping==0.2.28 # via refiners jinja2==3.1.3 @@ -90,6 +112,18 @@ jinja2==3.1.3 # via mkdocs-material # via mkdocstrings # via torch +jmespath==1.0.1 + # via boto3 + # via botocore +jsonpointer==2.4 + # via jsonschema +jsonref==1.1.0 + # via bravado-core +jsonschema==4.22.0 + # via bravado-core + # via swagger-spec-validator +jsonschema-specifications==2023.12.1 + # via jsonschema loguru==0.7.2 # via refiners markdown==3.6 @@ -123,8 +157,13 @@ mkdocstrings==0.24.1 # via refiners mkdocstrings-python==1.8.0 # via mkdocstrings +monotonic==1.6 + # via bravado mpmath==1.3.0 # via sympy +msgpack==1.0.8 + # via bravado + # via bravado-core multidict==6.0.5 # via aiohttp # via yarl @@ -132,6 +171,8 @@ multiprocess==0.70.16 # via datasets mypy-extensions==1.0.0 # via black +neptune==1.10.4 + # via refiners networkx==3.2.1 # via torch numpy==1.26.4 @@ -172,22 +213,28 @@ nvidia-nvjitlink-cu12==12.4.99 # via nvidia-cusparse-cu12 nvidia-nvtx-cu12==12.1.105 # via torch +oauthlib==3.2.2 + # via neptune + # via requests-oauthlib packaging==24.0 # via black # via datasets # via huggingface-hub # via mkdocs + # via neptune # via refiners # via transformers paginate==0.5.6 # via mkdocs-material pandas==2.2.1 # via datasets + # via neptune pathspec==0.12.1 # via black # via mkdocs pillow==10.3.0 # via diffusers + # via neptune # via refiners # via torchvision piq==0.8.0 @@ -201,6 +248,7 @@ prodigyopt==1.0 protobuf==4.25.3 # via wandb psutil==5.9.8 + # via neptune # via wandb pyarrow==15.0.2 # via datasets @@ -212,37 +260,65 @@ pydantic-core==2.16.3 # via pydantic pygments==2.17.2 # via mkdocs-material +pyjwt==2.8.0 + # via neptune pymdown-extensions==10.7.1 # via mkdocs-material # via mkdocstrings python-dateutil==2.9.0.post0 + # via arrow + # via botocore + # via bravado + # via bravado-core # via ghp-import # via pandas pytz==2024.1 + # via bravado-core # via 
pandas pyyaml==6.0.1 + # via bravado + # via bravado-core # via datasets # via huggingface-hub # via mkdocs # via pymdown-extensions # via pyyaml-env-tag + # via swagger-spec-validator # via timm # via transformers # via wandb pyyaml-env-tag==0.1 # via mkdocs +referencing==0.35.1 + # via jsonschema + # via jsonschema-specifications regex==2023.12.25 # via diffusers # via mkdocs-material # via transformers requests==2.31.0 + # via bravado + # via bravado-core # via datasets # via diffusers # via huggingface-hub # via mkdocs-material + # via neptune # via refiners + # via requests-oauthlib # via transformers # via wandb +requests-oauthlib==2.0.0 + # via neptune +rfc3339-validator==0.1.4 + # via jsonschema +rfc3986-validator==0.1.1 + # via jsonschema +rpds-py==0.18.1 + # via jsonschema + # via referencing +s3transfer==0.10.1 + # via boto3 safetensors==0.4.2 # via diffusers # via refiners @@ -258,11 +334,21 @@ setproctitle==1.3.3 # via wandb setuptools==69.2.0 # via wandb +simplejson==3.19.2 + # via bravado + # via bravado-core six==1.16.0 + # via bravado + # via bravado-core # via docker-pycreds + # via neptune # via python-dateutil + # via rfc3339-validator smmap==5.0.1 # via gitdb +swagger-spec-validator==3.0.3 + # via bravado-core + # via neptune sympy==1.12 # via torch timm==0.9.16 @@ -296,21 +382,34 @@ triton==2.2.0 # via torch typeguard==2.13.3 # via jaxtyping +types-python-dateutil==2.9.0.20240316 + # via arrow typing-extensions==4.10.0 # via black + # via bravado # via huggingface-hub + # via neptune # via pydantic # via pydantic-core + # via swagger-spec-validator # via torch tzdata==2024.1 # via pandas +uri-template==1.3.0 + # via jsonschema urllib3==2.2.1 + # via botocore + # via neptune # via requests # via sentry-sdk wandb==0.16.5 # via refiners watchdog==4.0.0 # via mkdocs +webcolors==1.13 + # via jsonschema +websocket-client==1.8.0 + # via neptune xxhash==3.4.1 # via datasets yarl==1.9.4 diff --git a/src/refiners/training_utils/neptune.py 
from abc import ABC
from os import PathLike
from typing import Any, Literal

from neptune import Run, init_run  # type: ignore
from neptune.internal.init.parameters import (  # type: ignore
    ASYNC_LAG_THRESHOLD,
    ASYNC_NO_PROGRESS_THRESHOLD,
    DEFAULT_FLUSH_PERIOD,
)
from neptune.metadata_containers.abstract import NeptuneObjectCallback  # type: ignore
from neptune.types.atoms.git_ref import GitRef, GitRefDisabled  # type: ignore

from refiners.training_utils.callback import Callback, CallbackConfig
from refiners.training_utils.config import BaseConfig
from refiners.training_utils.trainer import Trainer, register_callback

# Trainer alias: this callback is config-agnostic, so any BaseConfig trainer works.
AnyTrainer = Trainer[BaseConfig, Any]


class NeptuneConfig(CallbackConfig):
    """Neptune.ai run configuration.

    Field names mirror the keyword arguments of ``neptune.init_run`` so the
    whole config can be forwarded via ``init_run(**config.model_dump())``.

    See https://docs.neptune.ai/api/neptune#init_run
    and https://github.com/neptune-ai/neptune-client/blob/1cd8452045e8524318f59216d151d73328f85bd1/src/neptune/objects/run.py#L131
    """

    project: str | None = None
    api_token: str | None = None
    with_id: str | None = None
    custom_run_id: str | None = None
    mode: Literal["async", "sync", "offline", "read-only", "debug"] | None = None
    name: str | None = None
    description: str | None = None
    tags: str | list[str] | None = None
    source_files: str | list[str] | None = None
    capture_stdout: bool | None = None
    capture_stderr: bool | None = None
    capture_hardware_metrics: bool | None = None
    fail_on_exception: bool = True
    monitoring_namespace: str | None = None
    flush_period: float = DEFAULT_FLUSH_PERIOD
    proxies: dict[str, str] | None = None
    capture_traceback: bool = True
    git_ref: GitRef | GitRefDisabled | None = None
    dependencies: PathLike[str] | str | None = None
    async_lag_callback: NeptuneObjectCallback | None = None
    async_no_progress_callback: NeptuneObjectCallback | None = None
    async_lag_threshold: float = ASYNC_LAG_THRESHOLD
    async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD


class NeptuneCallback(Callback[AnyTrainer]):
    """Neptune.ai callback for logging training metrics.

    Logs per-step loss, per-iteration and per-epoch average loss, gradient
    norm, learning rate and epoch counter to a Neptune ``Run``.
    """

    # The Neptune Run; only assigned once `on_train_begin` has executed.
    run: Run

    def __init__(self, config: NeptuneConfig) -> None:
        """Initialize Neptune.ai callback.

        Args:
            config: Neptune.ai run configuration, forwarded to `init_run`.
        """
        self.config = config
        # Losses accumulated since the last epoch / optimizer-step boundary.
        self.epoch_losses: list[float] = []
        self.iteration_losses: list[float] = []

    def on_train_begin(self, trainer: AnyTrainer) -> None:
        # initialize Neptune `Run` (see https://docs.neptune.ai/api/run/)
        self.run = init_run(**self.config.model_dump())
        self.run["config"] = trainer.config.model_dump()

        # reset epoch and iteration losses (a trainer may be reused for several runs)
        self.epoch_losses = []
        self.iteration_losses = []

    def on_compute_loss_end(self, trainer: AnyTrainer) -> None:
        # Detach to CPU so logging never keeps the autograd graph alive.
        loss_value = trainer.loss.detach().cpu().item()
        self.epoch_losses.append(loss_value)
        self.iteration_losses.append(loss_value)
        self.run["train/step_loss"].append(loss_value, step=trainer.clock.step)  # type: ignore

    def on_optimizer_step_end(self, trainer: AnyTrainer) -> None:
        self.run["train/total_grad_norm"].append(trainer.grad_norm, step=trainer.clock.step)  # type: ignore
        # Guard against an optimizer step with no recorded losses
        # (would otherwise raise ZeroDivisionError).
        if self.iteration_losses:
            avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
            self.run["train/average_iteration_loss"].append(avg_iteration_loss, step=trainer.clock.step)  # type: ignore
        self.iteration_losses = []

    def on_epoch_end(self, trainer: AnyTrainer) -> None:
        # Guard against an empty epoch (would otherwise raise ZeroDivisionError).
        if self.epoch_losses:
            avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses)
            self.run["train/average_epoch_loss"].append(avg_epoch_loss, step=trainer.clock.step)  # type: ignore
        self.run["train/epoch"].append(trainer.clock.epoch, step=trainer.clock.step)  # type: ignore
        self.epoch_losses = []

    def on_lr_scheduler_step_end(self, trainer: AnyTrainer) -> None:
        self.run["train/learning_rate"].append(trainer.optimizer.param_groups[0]["lr"], step=trainer.clock.step)  # type: ignore

    def on_train_end(self, trainer: AnyTrainer) -> None:
        # `run` only exists if `on_train_begin` ran; guard against a
        # partial lifecycle (e.g. training aborted before it started).
        run: Run | None = getattr(self, "run", None)
        if run is not None:
            run.stop()


class NeptuneMixin(ABC):
    """Mixin registering the Neptune callback on a trainer."""

    @register_callback()
    def neptune(self, config: NeptuneConfig) -> NeptuneCallback:
        """Build the Neptune callback from its configuration."""
        return NeptuneCallback(config)