diff --git a/pyproject.toml b/pyproject.toml index edcf9e0..baf81b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ training = [ "wandb>=0.16.0", "datasets>=2.15.0", "tomli>=2.0.1", + "gitpython>=3.1.43", ] test = [ "diffusers>=0.26.1", diff --git a/requirements.lock b/requirements.lock index 9a7dd72..967219b 100644 --- a/requirements.lock +++ b/requirements.lock @@ -68,6 +68,7 @@ ghp-import==2.1.0 gitdb==4.0.11 # via gitpython gitpython==3.1.43 + # via refiners # via wandb griffe==0.42.1 # via mkdocstrings-python diff --git a/src/refiners/training_utils/forcecommit.py b/src/refiners/training_utils/forcecommit.py new file mode 100644 index 0000000..aa453da --- /dev/null +++ b/src/refiners/training_utils/forcecommit.py @@ -0,0 +1,81 @@ +from typing import Any + +import wandb +from git import Repo +from loguru import logger + +from refiners.training_utils.callback import Callback, CallbackConfig +from refiners.training_utils.config import BaseConfig +from refiners.training_utils.trainer import Trainer + +AnyTrainer = Trainer[BaseConfig, Any] + + +class ForceCommitConfig(CallbackConfig): + """Configuration of the ForceCommit callback. + + Attributes: + check_changed: Whether to check if there are modified files. + check_untracked: Whether to check if there are untracked files. + upload_wandb_patch: Whether to upload the patch of the changes. + search_parent_directories: Whether to search parent directories for the git repository. + exclusions: List of files to exclude from the checks. + """ + + check_changed: bool = True + check_untracked: bool = False + upload_wandb_patch: bool = False + search_parent_directories: bool = False + exclusions: list[str] = [] + + +class ForceCommit(Callback[AnyTrainer]): + """Callback to force user to commit or stash changes before running the training. + + This callback assumes that the training is run from a git repository. + """ + + def __init__(self, config: ForceCommitConfig) -> None: + """Initialize the callback. + + Args: + config: Configuration of the callback. + """ + self.check_changed = config.check_changed + self.check_untracked = config.check_untracked + self.upload_wandb_patch = config.upload_wandb_patch + self.search_parent_directories = config.search_parent_directories + self.exclusions = config.exclusions + + def on_init_begin(self, trainer: AnyTrainer) -> None: + # get git repo and diff list + repo = Repo(search_parent_directories=self.search_parent_directories) + logger.info(f"Git repository: {repo.working_dir}") + logger.info(f"Git branch: {repo.active_branch}") + logger.info(f"Git commit: {repo.head.commit.hexsha}") + diffs = repo.index.diff(other=None, create_patch=True) # type: ignore + + # get list of modified files + modified_files: list[str] = [item.a_path for item in diffs] # type: ignore + modified_files: set[str] = set(modified_files) - set(self.exclusions) + logger.info(f"Modified files: {modified_files}") + if self.check_changed and modified_files: + raise RuntimeError( + "There are modified files. Please commit or stash them before running the training.", + ) + + # get list of untracked files + untracked_files = repo.untracked_files + untracked_files = set(untracked_files) - set(self.exclusions) + logger.info(f"Untracked files: {untracked_files}") + if self.check_untracked and untracked_files: + raise RuntimeError( + "There are untracked files. Please add them to the repository before running the training.", + ) + + # create patch + if self.upload_wandb_patch: + patch = str(repo.git.diff()).replace("\n", "
") + artifact = wandb.Artifact(name="git", type="metadata") + artifact.add(name="patch", obj=wandb.Html(patch)) + wandb.log_artifact(artifact) # type: ignore