mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-25 07:38:45 +00:00
(training_utils) add NeptuneCallback
This commit is contained in:
parent
3a7f14e4dc
commit
3ad7f592db
|
@ -33,6 +33,7 @@ training = [
|
||||||
"torchvision>=0.16.1",
|
"torchvision>=0.16.1",
|
||||||
"loguru>=0.7.2",
|
"loguru>=0.7.2",
|
||||||
"wandb>=0.16.0",
|
"wandb>=0.16.0",
|
||||||
|
"neptune>=1.10.4",
|
||||||
"datasets>=2.15.0",
|
"datasets>=2.15.0",
|
||||||
"tomli>=2.0.1",
|
"tomli>=2.0.1",
|
||||||
"gitpython>=3.1.43",
|
"gitpython>=3.1.43",
|
||||||
|
|
|
@ -17,16 +17,29 @@ annotated-types==0.6.0
|
||||||
# via pydantic
|
# via pydantic
|
||||||
appdirs==1.4.4
|
appdirs==1.4.4
|
||||||
# via wandb
|
# via wandb
|
||||||
|
arrow==1.3.0
|
||||||
|
# via isoduration
|
||||||
async-timeout==4.0.3
|
async-timeout==4.0.3
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
attrs==23.2.0
|
attrs==23.2.0
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
|
# via jsonschema
|
||||||
|
# via referencing
|
||||||
babel==2.14.0
|
babel==2.14.0
|
||||||
# via mkdocs-material
|
# via mkdocs-material
|
||||||
bitsandbytes==0.43.0
|
bitsandbytes==0.43.0
|
||||||
# via refiners
|
# via refiners
|
||||||
black==24.3.0
|
black==24.3.0
|
||||||
# via refiners
|
# via refiners
|
||||||
|
boto3==1.34.113
|
||||||
|
# via neptune
|
||||||
|
botocore==1.34.113
|
||||||
|
# via boto3
|
||||||
|
# via s3transfer
|
||||||
|
bravado==11.0.3
|
||||||
|
# via neptune
|
||||||
|
bravado-core==6.1.1
|
||||||
|
# via bravado
|
||||||
certifi==2024.2.2
|
certifi==2024.2.2
|
||||||
# via requests
|
# via requests
|
||||||
# via sentry-sdk
|
# via sentry-sdk
|
||||||
|
@ -36,6 +49,7 @@ click==8.1.7
|
||||||
# via black
|
# via black
|
||||||
# via mkdocs
|
# via mkdocs
|
||||||
# via mkdocstrings
|
# via mkdocstrings
|
||||||
|
# via neptune
|
||||||
# via wandb
|
# via wandb
|
||||||
colorama==0.4.6
|
colorama==0.4.6
|
||||||
# via griffe
|
# via griffe
|
||||||
|
@ -56,6 +70,8 @@ filelock==3.13.3
|
||||||
# via torch
|
# via torch
|
||||||
# via transformers
|
# via transformers
|
||||||
# via triton
|
# via triton
|
||||||
|
fqdn==1.5.1
|
||||||
|
# via jsonschema
|
||||||
frozenlist==1.4.1
|
frozenlist==1.4.1
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
# via aiosignal
|
# via aiosignal
|
||||||
|
@ -63,11 +79,14 @@ fsspec==2024.2.0
|
||||||
# via datasets
|
# via datasets
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
# via torch
|
# via torch
|
||||||
|
future==1.0.0
|
||||||
|
# via neptune
|
||||||
ghp-import==2.1.0
|
ghp-import==2.1.0
|
||||||
# via mkdocs
|
# via mkdocs
|
||||||
gitdb==4.0.11
|
gitdb==4.0.11
|
||||||
# via gitpython
|
# via gitpython
|
||||||
gitpython==3.1.43
|
gitpython==3.1.43
|
||||||
|
# via neptune
|
||||||
# via refiners
|
# via refiners
|
||||||
# via wandb
|
# via wandb
|
||||||
griffe==0.42.1
|
griffe==0.42.1
|
||||||
|
@ -79,10 +98,13 @@ huggingface-hub==0.22.2
|
||||||
# via tokenizers
|
# via tokenizers
|
||||||
# via transformers
|
# via transformers
|
||||||
idna==3.6
|
idna==3.6
|
||||||
|
# via jsonschema
|
||||||
# via requests
|
# via requests
|
||||||
# via yarl
|
# via yarl
|
||||||
importlib-metadata==7.1.0
|
importlib-metadata==7.1.0
|
||||||
# via diffusers
|
# via diffusers
|
||||||
|
isoduration==20.11.0
|
||||||
|
# via jsonschema
|
||||||
jaxtyping==0.2.28
|
jaxtyping==0.2.28
|
||||||
# via refiners
|
# via refiners
|
||||||
jinja2==3.1.3
|
jinja2==3.1.3
|
||||||
|
@ -90,6 +112,18 @@ jinja2==3.1.3
|
||||||
# via mkdocs-material
|
# via mkdocs-material
|
||||||
# via mkdocstrings
|
# via mkdocstrings
|
||||||
# via torch
|
# via torch
|
||||||
|
jmespath==1.0.1
|
||||||
|
# via boto3
|
||||||
|
# via botocore
|
||||||
|
jsonpointer==2.4
|
||||||
|
# via jsonschema
|
||||||
|
jsonref==1.1.0
|
||||||
|
# via bravado-core
|
||||||
|
jsonschema==4.22.0
|
||||||
|
# via bravado-core
|
||||||
|
# via swagger-spec-validator
|
||||||
|
jsonschema-specifications==2023.12.1
|
||||||
|
# via jsonschema
|
||||||
loguru==0.7.2
|
loguru==0.7.2
|
||||||
# via refiners
|
# via refiners
|
||||||
markdown==3.6
|
markdown==3.6
|
||||||
|
@ -123,8 +157,13 @@ mkdocstrings==0.24.1
|
||||||
# via refiners
|
# via refiners
|
||||||
mkdocstrings-python==1.8.0
|
mkdocstrings-python==1.8.0
|
||||||
# via mkdocstrings
|
# via mkdocstrings
|
||||||
|
monotonic==1.6
|
||||||
|
# via bravado
|
||||||
mpmath==1.3.0
|
mpmath==1.3.0
|
||||||
# via sympy
|
# via sympy
|
||||||
|
msgpack==1.0.8
|
||||||
|
# via bravado
|
||||||
|
# via bravado-core
|
||||||
multidict==6.0.5
|
multidict==6.0.5
|
||||||
# via aiohttp
|
# via aiohttp
|
||||||
# via yarl
|
# via yarl
|
||||||
|
@ -132,6 +171,8 @@ multiprocess==0.70.16
|
||||||
# via datasets
|
# via datasets
|
||||||
mypy-extensions==1.0.0
|
mypy-extensions==1.0.0
|
||||||
# via black
|
# via black
|
||||||
|
neptune==1.10.4
|
||||||
|
# via refiners
|
||||||
networkx==3.2.1
|
networkx==3.2.1
|
||||||
# via torch
|
# via torch
|
||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
|
@ -172,22 +213,28 @@ nvidia-nvjitlink-cu12==12.4.99
|
||||||
# via nvidia-cusparse-cu12
|
# via nvidia-cusparse-cu12
|
||||||
nvidia-nvtx-cu12==12.1.105
|
nvidia-nvtx-cu12==12.1.105
|
||||||
# via torch
|
# via torch
|
||||||
|
oauthlib==3.2.2
|
||||||
|
# via neptune
|
||||||
|
# via requests-oauthlib
|
||||||
packaging==24.0
|
packaging==24.0
|
||||||
# via black
|
# via black
|
||||||
# via datasets
|
# via datasets
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
# via mkdocs
|
# via mkdocs
|
||||||
|
# via neptune
|
||||||
# via refiners
|
# via refiners
|
||||||
# via transformers
|
# via transformers
|
||||||
paginate==0.5.6
|
paginate==0.5.6
|
||||||
# via mkdocs-material
|
# via mkdocs-material
|
||||||
pandas==2.2.1
|
pandas==2.2.1
|
||||||
# via datasets
|
# via datasets
|
||||||
|
# via neptune
|
||||||
pathspec==0.12.1
|
pathspec==0.12.1
|
||||||
# via black
|
# via black
|
||||||
# via mkdocs
|
# via mkdocs
|
||||||
pillow==10.3.0
|
pillow==10.3.0
|
||||||
# via diffusers
|
# via diffusers
|
||||||
|
# via neptune
|
||||||
# via refiners
|
# via refiners
|
||||||
# via torchvision
|
# via torchvision
|
||||||
piq==0.8.0
|
piq==0.8.0
|
||||||
|
@ -201,6 +248,7 @@ prodigyopt==1.0
|
||||||
protobuf==4.25.3
|
protobuf==4.25.3
|
||||||
# via wandb
|
# via wandb
|
||||||
psutil==5.9.8
|
psutil==5.9.8
|
||||||
|
# via neptune
|
||||||
# via wandb
|
# via wandb
|
||||||
pyarrow==15.0.2
|
pyarrow==15.0.2
|
||||||
# via datasets
|
# via datasets
|
||||||
|
@ -212,37 +260,65 @@ pydantic-core==2.16.3
|
||||||
# via pydantic
|
# via pydantic
|
||||||
pygments==2.17.2
|
pygments==2.17.2
|
||||||
# via mkdocs-material
|
# via mkdocs-material
|
||||||
|
pyjwt==2.8.0
|
||||||
|
# via neptune
|
||||||
pymdown-extensions==10.7.1
|
pymdown-extensions==10.7.1
|
||||||
# via mkdocs-material
|
# via mkdocs-material
|
||||||
# via mkdocstrings
|
# via mkdocstrings
|
||||||
python-dateutil==2.9.0.post0
|
python-dateutil==2.9.0.post0
|
||||||
|
# via arrow
|
||||||
|
# via botocore
|
||||||
|
# via bravado
|
||||||
|
# via bravado-core
|
||||||
# via ghp-import
|
# via ghp-import
|
||||||
# via pandas
|
# via pandas
|
||||||
pytz==2024.1
|
pytz==2024.1
|
||||||
|
# via bravado-core
|
||||||
# via pandas
|
# via pandas
|
||||||
pyyaml==6.0.1
|
pyyaml==6.0.1
|
||||||
|
# via bravado
|
||||||
|
# via bravado-core
|
||||||
# via datasets
|
# via datasets
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
# via mkdocs
|
# via mkdocs
|
||||||
# via pymdown-extensions
|
# via pymdown-extensions
|
||||||
# via pyyaml-env-tag
|
# via pyyaml-env-tag
|
||||||
|
# via swagger-spec-validator
|
||||||
# via timm
|
# via timm
|
||||||
# via transformers
|
# via transformers
|
||||||
# via wandb
|
# via wandb
|
||||||
pyyaml-env-tag==0.1
|
pyyaml-env-tag==0.1
|
||||||
# via mkdocs
|
# via mkdocs
|
||||||
|
referencing==0.35.1
|
||||||
|
# via jsonschema
|
||||||
|
# via jsonschema-specifications
|
||||||
regex==2023.12.25
|
regex==2023.12.25
|
||||||
# via diffusers
|
# via diffusers
|
||||||
# via mkdocs-material
|
# via mkdocs-material
|
||||||
# via transformers
|
# via transformers
|
||||||
requests==2.31.0
|
requests==2.31.0
|
||||||
|
# via bravado
|
||||||
|
# via bravado-core
|
||||||
# via datasets
|
# via datasets
|
||||||
# via diffusers
|
# via diffusers
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
# via mkdocs-material
|
# via mkdocs-material
|
||||||
|
# via neptune
|
||||||
# via refiners
|
# via refiners
|
||||||
|
# via requests-oauthlib
|
||||||
# via transformers
|
# via transformers
|
||||||
# via wandb
|
# via wandb
|
||||||
|
requests-oauthlib==2.0.0
|
||||||
|
# via neptune
|
||||||
|
rfc3339-validator==0.1.4
|
||||||
|
# via jsonschema
|
||||||
|
rfc3986-validator==0.1.1
|
||||||
|
# via jsonschema
|
||||||
|
rpds-py==0.18.1
|
||||||
|
# via jsonschema
|
||||||
|
# via referencing
|
||||||
|
s3transfer==0.10.1
|
||||||
|
# via boto3
|
||||||
safetensors==0.4.2
|
safetensors==0.4.2
|
||||||
# via diffusers
|
# via diffusers
|
||||||
# via refiners
|
# via refiners
|
||||||
|
@ -258,11 +334,21 @@ setproctitle==1.3.3
|
||||||
# via wandb
|
# via wandb
|
||||||
setuptools==69.2.0
|
setuptools==69.2.0
|
||||||
# via wandb
|
# via wandb
|
||||||
|
simplejson==3.19.2
|
||||||
|
# via bravado
|
||||||
|
# via bravado-core
|
||||||
six==1.16.0
|
six==1.16.0
|
||||||
|
# via bravado
|
||||||
|
# via bravado-core
|
||||||
# via docker-pycreds
|
# via docker-pycreds
|
||||||
|
# via neptune
|
||||||
# via python-dateutil
|
# via python-dateutil
|
||||||
|
# via rfc3339-validator
|
||||||
smmap==5.0.1
|
smmap==5.0.1
|
||||||
# via gitdb
|
# via gitdb
|
||||||
|
swagger-spec-validator==3.0.3
|
||||||
|
# via bravado-core
|
||||||
|
# via neptune
|
||||||
sympy==1.12
|
sympy==1.12
|
||||||
# via torch
|
# via torch
|
||||||
timm==0.9.16
|
timm==0.9.16
|
||||||
|
@ -296,21 +382,34 @@ triton==2.2.0
|
||||||
# via torch
|
# via torch
|
||||||
typeguard==2.13.3
|
typeguard==2.13.3
|
||||||
# via jaxtyping
|
# via jaxtyping
|
||||||
|
types-python-dateutil==2.9.0.20240316
|
||||||
|
# via arrow
|
||||||
typing-extensions==4.10.0
|
typing-extensions==4.10.0
|
||||||
# via black
|
# via black
|
||||||
|
# via bravado
|
||||||
# via huggingface-hub
|
# via huggingface-hub
|
||||||
|
# via neptune
|
||||||
# via pydantic
|
# via pydantic
|
||||||
# via pydantic-core
|
# via pydantic-core
|
||||||
|
# via swagger-spec-validator
|
||||||
# via torch
|
# via torch
|
||||||
tzdata==2024.1
|
tzdata==2024.1
|
||||||
# via pandas
|
# via pandas
|
||||||
|
uri-template==1.3.0
|
||||||
|
# via jsonschema
|
||||||
urllib3==2.2.1
|
urllib3==2.2.1
|
||||||
|
# via botocore
|
||||||
|
# via neptune
|
||||||
# via requests
|
# via requests
|
||||||
# via sentry-sdk
|
# via sentry-sdk
|
||||||
wandb==0.16.5
|
wandb==0.16.5
|
||||||
# via refiners
|
# via refiners
|
||||||
watchdog==4.0.0
|
watchdog==4.0.0
|
||||||
# via mkdocs
|
# via mkdocs
|
||||||
|
webcolors==1.13
|
||||||
|
# via jsonschema
|
||||||
|
websocket-client==1.8.0
|
||||||
|
# via neptune
|
||||||
xxhash==3.4.1
|
xxhash==3.4.1
|
||||||
# via datasets
|
# via datasets
|
||||||
yarl==1.9.4
|
yarl==1.9.4
|
||||||
|
|
105
src/refiners/training_utils/neptune.py
Normal file
105
src/refiners/training_utils/neptune.py
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
from abc import ABC
|
||||||
|
from os import PathLike
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from neptune import Run, init_run # type: ignore
|
||||||
|
from neptune.internal.init.parameters import ( # type: ignore
|
||||||
|
ASYNC_LAG_THRESHOLD,
|
||||||
|
ASYNC_NO_PROGRESS_THRESHOLD,
|
||||||
|
DEFAULT_FLUSH_PERIOD,
|
||||||
|
)
|
||||||
|
from neptune.metadata_containers.abstract import NeptuneObjectCallback # type: ignore
|
||||||
|
from neptune.types.atoms.git_ref import GitRef, GitRefDisabled # type: ignore
|
||||||
|
|
||||||
|
from refiners.training_utils.callback import Callback, CallbackConfig
|
||||||
|
from refiners.training_utils.config import BaseConfig
|
||||||
|
from refiners.training_utils.trainer import Trainer, register_callback
|
||||||
|
|
||||||
|
AnyTrainer = Trainer[BaseConfig, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class NeptuneConfig(CallbackConfig):
|
||||||
|
"""Neptune.ai run configuration
|
||||||
|
|
||||||
|
See https://docs.neptune.ai/api/neptune#init_run
|
||||||
|
and https://github.com/neptune-ai/neptune-client/blob/1cd8452045e8524318f59216d151d73328f85bd1/src/neptune/objects/run.py#L131
|
||||||
|
"""
|
||||||
|
|
||||||
|
project: str | None = None
|
||||||
|
api_token: str | None = None
|
||||||
|
with_id: str | None = None
|
||||||
|
custom_run_id: str | None = None
|
||||||
|
mode: Literal["async", "sync", "offline", "read-only", "debug"] | None = None
|
||||||
|
name: str | None = None
|
||||||
|
description: str | None = None
|
||||||
|
tags: str | list[str] | None = None
|
||||||
|
source_files: str | list[str] | None = None
|
||||||
|
capture_stdout: bool | None = None
|
||||||
|
capture_stderr: bool | None = None
|
||||||
|
capture_hardware_metrics: bool | None = None
|
||||||
|
fail_on_exception: bool = True
|
||||||
|
monitoring_namespace: str | None = None
|
||||||
|
flush_period: float = DEFAULT_FLUSH_PERIOD
|
||||||
|
proxies: dict[str, str] | None = None
|
||||||
|
capture_traceback: bool = True
|
||||||
|
git_ref: GitRef | GitRefDisabled | None = None
|
||||||
|
dependencies: PathLike[str] | str | None = None
|
||||||
|
async_lag_callback: NeptuneObjectCallback | None = None
|
||||||
|
async_no_progress_callback: NeptuneObjectCallback | None = None
|
||||||
|
async_lag_threshold: float = ASYNC_LAG_THRESHOLD
|
||||||
|
async_no_progress_threshold: float = ASYNC_NO_PROGRESS_THRESHOLD
|
||||||
|
|
||||||
|
|
||||||
|
class NeptuneCallback(Callback[AnyTrainer]):
|
||||||
|
"""Neptune.ai callback for logging metrics"""
|
||||||
|
|
||||||
|
run: Run
|
||||||
|
|
||||||
|
def __init__(self, config: NeptuneConfig) -> None:
|
||||||
|
"""Initialize Neptune.ai callback
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Neptune.ai run configuration
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
self.epoch_losses: list[float] = []
|
||||||
|
self.iteration_losses: list[float] = []
|
||||||
|
|
||||||
|
def on_train_begin(self, trainer: AnyTrainer) -> None:
|
||||||
|
# initialize Neptune `Run` (see https://docs.neptune.ai/api/run/)
|
||||||
|
self.run = init_run(**self.config.model_dump())
|
||||||
|
self.run["config"] = trainer.config.model_dump()
|
||||||
|
|
||||||
|
# reset epoch and iteration losses
|
||||||
|
self.epoch_losses = []
|
||||||
|
self.iteration_losses = []
|
||||||
|
|
||||||
|
def on_compute_loss_end(self, trainer: AnyTrainer) -> None:
|
||||||
|
loss_value = trainer.loss.detach().cpu().item()
|
||||||
|
self.epoch_losses.append(loss_value)
|
||||||
|
self.iteration_losses.append(loss_value)
|
||||||
|
self.run["train/step_loss"].append(loss_value, step=trainer.clock.step) # type: ignore
|
||||||
|
|
||||||
|
def on_optimizer_step_end(self, trainer: AnyTrainer) -> None:
|
||||||
|
self.run["train/total_grad_norm"].append(trainer.grad_norm, step=trainer.clock.step) # type: ignore
|
||||||
|
avg_iteration_loss = sum(self.iteration_losses) / len(self.iteration_losses)
|
||||||
|
self.run["train/average_iteration_loss"].append(avg_iteration_loss, step=trainer.clock.step) # type: ignore
|
||||||
|
self.iteration_losses = []
|
||||||
|
|
||||||
|
def on_epoch_end(self, trainer: AnyTrainer) -> None:
|
||||||
|
avg_epoch_loss = sum(self.epoch_losses) / len(self.epoch_losses)
|
||||||
|
self.run["train/average_epoch_loss"].append(avg_epoch_loss, step=trainer.clock.step) # type: ignore
|
||||||
|
self.run["train/epoch"].append(trainer.clock.epoch, step=trainer.clock.step) # type: ignore
|
||||||
|
self.epoch_losses = []
|
||||||
|
|
||||||
|
def on_lr_scheduler_step_end(self, trainer: AnyTrainer) -> None:
|
||||||
|
self.run["train/learning_rate"].append(trainer.optimizer.param_groups[0]["lr"], step=trainer.clock.step) # type: ignore
|
||||||
|
|
||||||
|
def on_train_end(self, trainer: AnyTrainer) -> None:
|
||||||
|
self.run.stop()
|
||||||
|
|
||||||
|
|
||||||
|
class NeptuneMixin(ABC):
|
||||||
|
@register_callback()
|
||||||
|
def neptune(self, config: NeptuneConfig) -> NeptuneCallback:
|
||||||
|
return NeptuneCallback(config)
|
Loading…
Reference in a new issue