mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
(training_utils) add NeptuneCallback
This commit is contained in:
parent
3a7f14e4dc
commit
3ad7f592db
|
@ -33,6 +33,7 @@ training = [
|
|||
"torchvision>=0.16.1",
|
||||
"loguru>=0.7.2",
|
||||
"wandb>=0.16.0",
|
||||
"neptune>=1.10.4",
|
||||
"datasets>=2.15.0",
|
||||
"tomli>=2.0.1",
|
||||
"gitpython>=3.1.43",
|
||||
|
|
|
@ -17,16 +17,29 @@ annotated-types==0.6.0
|
|||
# via pydantic
|
||||
appdirs==1.4.4
|
||||
# via wandb
|
||||
arrow==1.3.0
|
||||
# via isoduration
|
||||
async-timeout==4.0.3
|
||||
# via aiohttp
|
||||
attrs==23.2.0
|
||||
# via aiohttp
|
||||
# via jsonschema
|
||||
# via referencing
|
||||
babel==2.14.0
|
||||
# via mkdocs-material
|
||||
bitsandbytes==0.43.0
|
||||
# via refiners
|
||||
black==24.3.0
|
||||
# via refiners
|
||||
boto3==1.34.113
|
||||
# via neptune
|
||||
botocore==1.34.113
|
||||
# via boto3
|
||||
# via s3transfer
|
||||
bravado==11.0.3
|
||||
# via neptune
|
||||
bravado-core==6.1.1
|
||||
# via bravado
|
||||
certifi==2024.2.2
|
||||
# via requests
|
||||
# via sentry-sdk
|
||||
|
@ -36,6 +49,7 @@ click==8.1.7
|
|||
# via black
|
||||
# via mkdocs
|
||||
# via mkdocstrings
|
||||
# via neptune
|
||||
# via wandb
|
||||
colorama==0.4.6
|
||||
# via griffe
|
||||
|
@ -56,6 +70,8 @@ filelock==3.13.3
|
|||
# via torch
|
||||
# via transformers
|
||||
# via triton
|
||||
fqdn==1.5.1
|
||||
# via jsonschema
|
||||
frozenlist==1.4.1
|
||||
# via aiohttp
|
||||
# via aiosignal
|
||||
|
@ -63,11 +79,14 @@ fsspec==2024.2.0
|
|||
# via datasets
|
||||
# via huggingface-hub
|
||||
# via torch
|
||||
future==1.0.0
|
||||
# via neptune
|
||||
ghp-import==2.1.0
|
||||
# via mkdocs
|
||||
gitdb==4.0.11
|
||||
# via gitpython
|
||||
gitpython==3.1.43
|
||||
# via neptune
|
||||
# via refiners
|
||||
# via wandb
|
||||
griffe==0.42.1
|
||||
|
@ -79,10 +98,13 @@ huggingface-hub==0.22.2
|
|||
# via tokenizers
|
||||
# via transformers
|
||||
idna==3.6
|
||||
# via jsonschema
|
||||
# via requests
|
||||
# via yarl
|
||||
importlib-metadata==7.1.0
|
||||
# via diffusers
|
||||
isoduration==20.11.0
|
||||
# via jsonschema
|
||||
jaxtyping==0.2.28
|
||||
# via refiners
|
||||
jinja2==3.1.3
|
||||
|
@ -90,6 +112,18 @@ jinja2==3.1.3
|
|||
# via mkdocs-material
|
||||
# via mkdocstrings
|
||||
# via torch
|
||||
jmespath==1.0.1
|
||||
# via boto3
|
||||
# via botocore
|
||||
jsonpointer==2.4
|
||||
# via jsonschema
|
||||
jsonref==1.1.0
|
||||
# via bravado-core
|
||||
jsonschema==4.22.0
|
||||
# via bravado-core
|
||||
# via swagger-spec-validator
|
||||
jsonschema-specifications==2023.12.1
|
||||
# via jsonschema
|
||||
loguru==0.7.2
|
||||
# via refiners
|
||||
markdown==3.6
|
||||
|
@ -123,8 +157,13 @@ mkdocstrings==0.24.1
|
|||
# via refiners
|
||||
mkdocstrings-python==1.8.0
|
||||
# via mkdocstrings
|
||||
monotonic==1.6
|
||||
# via bravado
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
msgpack==1.0.8
|
||||
# via bravado
|
||||
# via bravado-core
|
||||
multidict==6.0.5
|
||||
# via aiohttp
|
||||
# via yarl
|
||||
|
@ -132,6 +171,8 @@ multiprocess==0.70.16
|
|||
# via datasets
|
||||
mypy-extensions==1.0.0
|
||||
# via black
|
||||
neptune==1.10.4
|
||||
# via refiners
|
||||
networkx==3.2.1
|
||||
# via torch
|
||||
numpy==1.26.4
|
||||
|
@ -172,22 +213,28 @@ nvidia-nvjitlink-cu12==12.4.99
|
|||
# via nvidia-cusparse-cu12
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
# via torch
|
||||
oauthlib==3.2.2
|
||||
# via neptune
|
||||
# via requests-oauthlib
|
||||
packaging==24.0
|
||||
# via black
|
||||
# via datasets
|
||||
# via huggingface-hub
|
||||
# via mkdocs
|
||||
# via neptune
|
||||
# via refiners
|
||||
# via transformers
|
||||
paginate==0.5.6
|
||||
# via mkdocs-material
|
||||
pandas==2.2.1
|
||||
# via datasets
|
||||
# via neptune
|
||||
pathspec==0.12.1
|
||||
# via black
|
||||
# via mkdocs
|
||||
pillow==10.3.0
|
||||
# via diffusers
|
||||
# via neptune
|
||||
# via refiners
|
||||
# via torchvision
|
||||
piq==0.8.0
|
||||
|
@ -201,6 +248,7 @@ prodigyopt==1.0
|
|||
protobuf==4.25.3
|
||||
# via wandb
|
||||
psutil==5.9.8
|
||||
# via neptune
|
||||
# via wandb
|
||||
pyarrow==15.0.2
|
||||
# via datasets
|
||||
|
@ -212,37 +260,65 @@ pydantic-core==2.16.3
|
|||
# via pydantic
|
||||
pygments==2.17.2
|
||||
# via mkdocs-material
|
||||
pyjwt==2.8.0
|
||||
# via neptune
|
||||
pymdown-extensions==10.7.1
|
||||
# via mkdocs-material
|
||||
# via mkdocstrings
|
||||
python-dateutil==2.9.0.post0
|
||||
# via arrow
|
||||
# via botocore
|
||||
# via bravado
|
||||
# via bravado-core
|
||||
# via ghp-import
|
||||
# via pandas
|
||||
pytz==2024.1
|
||||
# via bravado-core
|
||||
# via pandas
|
||||
pyyaml==6.0.1
|
||||
# via bravado
|
||||
# via bravado-core
|
||||
# via datasets
|
||||
# via huggingface-hub
|
||||
# via mkdocs
|
||||
# via pymdown-extensions
|
||||
# via pyyaml-env-tag
|
||||
# via swagger-spec-validator
|
||||
# via timm
|
||||
# via transformers
|
||||
# via wandb
|
||||
pyyaml-env-tag==0.1
|
||||
# via mkdocs
|
||||
referencing==0.35.1
|
||||
# via jsonschema
|
||||
# via jsonschema-specifications
|
||||
regex==2023.12.25
|
||||
# via diffusers
|
||||
# via mkdocs-material
|
||||
# via transformers
|
||||
requests==2.31.0
|
||||
# via bravado
|
||||
# via bravado-core
|
||||
# via datasets
|
||||
# via diffusers
|
||||
# via huggingface-hub
|
||||
# via mkdocs-material
|
||||
# via neptune
|
||||
# via refiners
|
||||
# via requests-oauthlib
|
||||
# via transformers
|
||||
# via wandb
|
||||
requests-oauthlib==2.0.0
|
||||
# via neptune
|
||||
rfc3339-validator==0.1.4
|
||||
# via jsonschema
|
||||
rfc3986-validator==0.1.1
|
||||
# via jsonschema
|
||||
rpds-py==0.18.1
|
||||
# via jsonschema
|
||||
# via referencing
|
||||
s3transfer==0.10.1
|
||||
# via boto3
|
||||
safetensors==0.4.2
|
||||
# via diffusers
|
||||
# via refiners
|
||||
|
@ -258,11 +334,21 @@ setproctitle==1.3.3
|
|||
# via wandb
|
||||
setuptools==69.2.0
|
||||
# via wandb
|
||||
simplejson==3.19.2
|
||||
# via bravado
|
||||
# via bravado-core
|
||||
six==1.16.0
|
||||
# via bravado
|
||||
# via bravado-core
|
||||
# via docker-pycreds
|
||||
# via neptune
|
||||
# via python-dateutil
|
||||
# via rfc3339-validator
|
||||
smmap==5.0.1
|
||||
# via gitdb
|
||||
swagger-spec-validator==3.0.3
|
||||
# via bravado-core
|
||||
# via neptune
|
||||
sympy==1.12
|
||||
# via torch
|
||||
timm==0.9.16
|
||||
|
@ -296,21 +382,34 @@ triton==2.2.0
|
|||
# via torch
|
||||
typeguard==2.13.3
|
||||
# via jaxtyping
|
||||
types-python-dateutil==2.9.0.20240316
|
||||
# via arrow
|
||||
typing-extensions==4.10.0
|
||||
# via black
|
||||
# via bravado
|
||||
# via huggingface-hub
|
||||
# via neptune
|
||||
# via pydantic
|
||||
# via pydantic-core
|
||||
# via swagger-spec-validator
|
||||
# via torch
|
||||
tzdata==2024.1
|
||||
# via pandas
|
||||
uri-template==1.3.0
|
||||
# via jsonschema
|
||||
urllib3==2.2.1
|
||||
# via botocore
|
||||
# via neptune
|
||||
# via requests
|
||||
# via sentry-sdk
|
||||
wandb==0.16.5
|
||||
# via refiners
|
||||
watchdog==4.0.0
|
||||
# via mkdocs
|
||||
webcolors==1.13
|
||||
# via jsonschema
|
||||
websocket-client==1.8.0
|
||||
# via neptune
|
||||
xxhash==3.4.1
|
||||
# via datasets
|
||||
yarl==1.9.4
|
||||
|
|
105
src/refiners/training_utils/neptune.py
Normal file
105
src/refiners/training_utils/neptune.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
from abc import ABC
|
||||
from os import PathLike
|
||||
from typing import Any, Literal
|
||||
|
||||
from neptune import Run, init_run # type: ignore
|
||||
from neptune.internal.init.parameters import ( # type: ignore
|
||||
ASYNC_LAG_THRESHOLD,
|
||||
ASYNC_NO_PROGRESS_THRESHOLD,
|
||||
DEFAULT_FLUSH_PERIOD,
|
||||
)
|
||||
from neptune.metadata_containers.abstract import NeptuneObjectCallback # type: ignore
|
||||
from neptune.types.atoms.git_ref import GitRef, GitRefDisabled # type: ignore
|
||||
|
||||
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||
from refiners.training_utils.config import BaseConfig
|
||||
from refiners.training_utils.trainer import Trainer, register_callback
|
||||
|
||||
# Alias used to annotate the `trainer` argument of the callback hooks in this module.
AnyTrainer = Trainer[BaseConfig, Any]
|
||||
|
||||
|
||||
class NeptuneConfig(CallbackConfig):
    """Neptune.ai run configuration

    Every field declared here is dumped with `model_dump()` and forwarded as a
    keyword argument to `neptune.init_run` when training begins, so the field
    names must stay in sync with that function's parameters.

    See https://docs.neptune.ai/api/neptune#init_run
    and https://github.com/neptune-ai/neptune-client/blob/1cd8452045e8524318f59216d151d73328f85bd1/src/neptune/objects/run.py#L131
    """

    project: str | None = None
    api_token: str | None = None
    with_id: str | None = None
    custom_run_id: str | None = None
    mode: Literal["async", "sync", "offline", "read-only", "debug"] | None = None
    name: str | None = None
    description: str | None = None
    tags: str | list[str] | None = None
    source_files: str | list[str] | None = None
    capture_stdout: bool | None = None
    capture_stderr: bool | None = None
    capture_hardware_metrics: bool | None = None
    fail_on_exception: bool = True
    monitoring_namespace: str | None = None
    # default taken from neptune's own init parameters module
    flush_period: float = DEFAULT_FLUSH_PERIOD
    proxies: dict[str, str] | None = None
    capture_traceback: bool = True
    # NOTE(review): `model_dump()` may convert a `GitRef` instance to a plain dict
    # before it reaches `init_run` - confirm a non-None value round-trips correctly.
    git_ref: GitRef | GitRefDisabled | None = None
    dependencies: PathLike[str] | str | None = None
    async_lag_callback: NeptuneObjectCallback | None = None
    async_no_progress_callback: NeptuneObjectCallback | None = None
    # defaults taken from neptune's own init parameters module
    async_lag_threshold: float = ASYNC_LAG_THRESHOLD
    async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD
|
||||
|
||||
|
||||
class NeptuneCallback(Callback[AnyTrainer]):
    """Neptune.ai callback for logging metrics

    The callback opens a Neptune `Run` when training begins, appends the
    following series during training, and stops the run when training ends:

    - `train/step_loss`: loss of every computed step
    - `train/total_grad_norm`: gradient norm at every optimizer step
    - `train/average_iteration_loss`: mean loss since the previous optimizer step
    - `train/average_epoch_loss`: mean loss over the epoch
    - `train/epoch`: epoch counter
    - `train/learning_rate`: learning rate of the first optimizer param group
    """

    run: Run

    def __init__(self, config: NeptuneConfig) -> None:
        """Initialize Neptune.ai callback

        Args:
            config: Neptune.ai run configuration
        """
        self.config = config
        # losses buffered since the last epoch end / optimizer step, respectively
        self.epoch_losses: list[float] = []
        self.iteration_losses: list[float] = []

    def on_train_begin(self, trainer: AnyTrainer) -> None:
        """Create the Neptune run and record the trainer configuration."""
        # initialize Neptune `Run` (see https://docs.neptune.ai/api/run/)
        self.run = init_run(**self.config.model_dump())
        self.run["config"] = trainer.config.model_dump()

        # reset epoch and iteration losses
        self.epoch_losses = []
        self.iteration_losses = []

    def on_compute_loss_end(self, trainer: AnyTrainer) -> None:
        """Log the step loss and buffer it for the iteration/epoch averages."""
        loss_value = trainer.loss.detach().cpu().item()
        self.epoch_losses.append(loss_value)
        self.iteration_losses.append(loss_value)
        self.run["train/step_loss"].append(loss_value, step=trainer.clock.step)  # type: ignore

    def on_optimizer_step_end(self, trainer: AnyTrainer) -> None:
        """Log the gradient norm and the mean loss since the previous optimizer step."""
        self.run["train/total_grad_norm"].append(trainer.grad_norm, step=trainer.clock.step)  # type: ignore
        # Skip the average when no loss was buffered: a logging hook must not
        # abort training with a ZeroDivisionError.
        if self.iteration_losses:
            avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
            self.run["train/average_iteration_loss"].append(avg_iteration_loss, step=trainer.clock.step)  # type: ignore
        self.iteration_losses = []

    def on_epoch_end(self, trainer: AnyTrainer) -> None:
        """Log the mean loss of the epoch and the epoch counter."""
        # Same guard as above: an epoch that computed no loss has no average to report.
        if self.epoch_losses:
            avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses)
            self.run["train/average_epoch_loss"].append(avg_epoch_loss, step=trainer.clock.step)  # type: ignore
        self.run["train/epoch"].append(trainer.clock.epoch, step=trainer.clock.step)  # type: ignore
        self.epoch_losses = []

    def on_lr_scheduler_step_end(self, trainer: AnyTrainer) -> None:
        """Log the learning rate of the first optimizer param group."""
        self.run["train/learning_rate"].append(trainer.optimizer.param_groups[0]["lr"], step=trainer.clock.step)  # type: ignore

    def on_train_end(self, trainer: AnyTrainer) -> None:
        """Stop the Neptune run."""
        self.run.stop()
|
||||
|
||||
|
||||
class NeptuneMixin(ABC):
    """Mixin that registers a `neptune` callback built from a `NeptuneConfig`."""

    @register_callback()
    def neptune(self, config: NeptuneConfig) -> NeptuneCallback:
        """Build the Neptune.ai callback for the given run configuration."""
        callback = NeptuneCallback(config)
        return callback
|
Loading…
Reference in a new issue