replace poetry by rye for python dependency management

Co-authored-by: Cédric Deltheil <cedric@deltheil.me>
Co-authored-by: Pierre Chapuis <git@catwell.info>
This commit is contained in:
limiteinductive 2023-12-08 12:26:50 +01:00 committed by Cédric Deltheil
parent 807ef5551c
commit 86c54977b9
16 changed files with 303 additions and 3543 deletions

View file

@ -1,34 +1,35 @@
name: CI name: CI
on: push
jobs:
on: push
jobs:
lint_and_typecheck: lint_and_typecheck:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: checkout - name: checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up python - name: Install Rye
id: setup-python uses: eifinger/setup-rye@v1
uses: actions/setup-python@v4
with: with:
python-version: "3.10" enable-cache: true
cache-prefix: 'refiners-rye-cache'
- name: Install Poetry - name: add home shims dir to PATH
uses: snok/install-poetry@v1 run: echo "$HOME/.rye/shims" >> $GITHUB_PATH
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
- name: poetry install - name: pin python
run: poetry install --no-interaction --all-extras run: rye pin 3.10
- name: black - name: rye sync
run: poetry run black --check . run: rye sync --all-features
- name: lint - name: ruff format
run: poetry run ruff check . run: rye run ruff format --check .
- name: ruff check
run: rye run ruff check .
- name: typecheck - name: typecheck
run: poetry run pyright run: rye run pyright

3
.gitignore vendored
View file

@ -29,3 +29,6 @@ wandb/
# env variable definitions file # env variable definitions file
.env .env
# lock files
requirements-dev.lock

74
CONTRIBUTING.md Normal file
View file

@ -0,0 +1,74 @@
We are happy to accept contributions from everyone. Feel free to browse our [bounty list](https://www.finegrain.ai/bounties) to find a task you would like to work on.
This document describes the process for contributing to Refiners.
## Licensing
Refiners is a library that is freely available to use and modify under the MIT License. It's essential to exercise caution when using external code, as any code that can affect the licensing of Refiners, including proprietary code, should not be copied and pasted. It's worth noting that some open-source licenses can also be problematic. For instance, we'd need to redistribute an Apache 2.0 license if you're using code from an Apache 2.0 codebase.
## Design principles
We do not enforce strict rules on the design of the code, but we do have a few guidelines that we try to follow:
- No dead code. We keep the codebase clean and remove unnecessary code/functionality.
- No unnecessary dependencies. We keep the number of dependencies to a minimum and only add new ones if necessary.
- Separate concerns. We separate the code into different modules and avoid having too many dependencies between modules. In particular, we try not to revisit existing code/models when adding new functionality. Instead, we add new functionality in a separate module with the `Adapter` pattern.
- Declarative style. We make the code as declarative, self-documenting, and easily read as possible. By reading the model's `repr`, you should understand how it works. We use explicit names for the different components of the models or the variables in the code.
## Setting up your environment
We use [Rye](https://rye-up.com/guide/installation/) to manage our development environment. Please follow the instructions on the Rye website to install it.
Once Rye is installed, you can clone the repository and run `rye sync` to install the dependencies.
## Linting
We use [ruff](https://docs.astral.sh/ruff/) to lint our code. You can lint your code by running.
```bash
rye run lint
```
We also enforce strict type checking with [pyright](https://github.com/microsoft/pyright). You can run the type checker with:
```bash
rye run pyright
```
## Running the tests
Running end-to-end tests is pretty compute-intensive, and you must convert all the model weights to the correct format before you can run them.
First, install test dependencies with:
```bash
rye sync --features test
```
Then, download and convert all the necessary weights. Be aware that this will use around 50 GB of disk space:
```bash
./scripts/prepare-test-weights.sh
```
To run all the tests, you will need to set the following environment variables:
```bash
export REFINERS_TEST_DEVICE=cuda:0
export REFINERS_TEST_WEIGHTS_DIR=$(pwd)/tests/weights
rye run pytest
```
In particular, `-k` is handy only to run tests that match a given expression, e.g.:
```bash
rye run pytest -k diffusion_std_init_image tests/e2e/test_diffusion.py
```
You can also run tests that are lightweight and will run on CPU:
```bash
rye run pytest -k "not test_diffusion"
```

View file

@ -33,32 +33,17 @@ ______________________________________________________________________
### Install ### Install
Refiners is still an early stage project so we recommend using the `main` branch directly with [Poetry](https://python-poetry.org). Refiners is still an early stage project, and we do not release minor versions yet. We recommend
installing the latest version via a git install:
If you just want to use Refiners directly, clone the repository and run:
```bash ```bash
poetry install --all-extras pip install git+https://github.com/finegrain-ai/refiners.git
``` ```
There is currently [a bug with PyTorch 2.0.1 and Poetry](https://github.com/pytorch/pytorch/issues/100974), to work around it run: To include the training utils, use:
```bash ```bash
poetry run pip install --upgrade torch torchvision pip install 'refiners[training] @ git+https://github.com/finegrain-ai/refiners.git'
```
If you want to depend on Refiners in your project which uses Poetry, you can do so:
```bash
poetry add git+ssh://git@github.com:finegrain-ai/refiners.git#main
```
If you want to run tests, we provide a script to download and convert all the necessary weights first. Be aware that this will use around 50 GB of disk space.
```bash
poetry shell
./scripts/prepare-test-weights.sh
pytest
``` ```
### Hello World ### Hello World

3438
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,56 +1,67 @@
[tool.poetry] [project]
name = "refiners" name = "refiners"
version = "0.2.0" version = "0.2.0"
description = "The simplest way to train and run adapters on top of foundational models" description = "The simplest way to train and run adapters on top of foundational models"
authors = [ authors = [{ name = "The Finegrain Team", email = "bonjour@lagon.tech" }]
"The Finegrain Team <bonjour@lagon.tech>",
]
license = "MIT" license = "MIT"
dependencies = [
"torch>=2.1.1",
"safetensors>=0.4.0",
"pillow>=10.1.0",
"jaxtyping>=0.2.23",
]
readme = "README.md" readme = "README.md"
packages = [{include = "refiners", from = "src"}] requires-python = ">= 3.10"
[tool.poetry.dependencies] [project.optional-dependencies]
python = ">=3.10,<3.12" training = [
jaxtyping = "^0.2.14" "bitsandbytes>=0.41.2.post2",
torch = "^2.1.0" "pydantic>=2.5.2",
safetensors = "^0.3.0" "prodigyopt>=1.0",
numpy = "^1.24.2" "torchvision>=0.16.1",
pillow = ">=9.5.0" "loguru>=0.7.2",
datasets = {version = "^2.14.0", optional = true} "wandb>=0.16.0",
tomli = {version = "^2.0.1", optional = true}
wandb = {version = "^0.15.7", optional = true}
loguru = {version = "^0.7.0", optional = true}
bitsandbytes = {version = "^0.41.0", optional = true}
prodigyopt = {version = "^1.0", optional = true}
pydantic = {version = "~2.0.3", optional = true}
# Added scipy as a work around until this PR gets merged: # Added scipy as a work around until this PR gets merged:
# https://github.com/TimDettmers/bitsandbytes/pull/525 # https://github.com/TimDettmers/bitsandbytes/pull/525
scipy = {version = "^1.11.1", optional = true} "scipy>=1.11.4",
torchvision = {version = "^0.16.0", optional = true} "datasets>=2.15.0",
diffusers = {version = "^0.21.0", optional = true} ]
transformers = {version = "^4.27.4", optional = true} test = [
piq = {version = "^0.7.1", optional = true} "diffusers>=0.24.0",
invisible-watermark = {version = "^0.2.0", optional = true} "transformers>=4.35.2",
"piq>=0.8.0",
"invisible-watermark>=0.2.0",
"torchvision>=0.16.1",
# An unofficial Python package for Meta AI's Segment Anything Model: # An unofficial Python package for Meta AI's Segment Anything Model:
# https://github.com/opengeos/segment-anything # https://github.com/opengeos/segment-anything
segment-anything-py = {version = "1.0", optional = true} "segment-anything-py>=1.0",
]
[tool.poetry.extras] conversion = [
training = ["datasets", "tomli", "wandb", "loguru", "bitsandbytes", "prodigyopt", "pydantic", "scipy", "torchvision"] "diffusers>=0.24.0",
conversion = ["diffusers", "transformers", "segment-anything-py"] "transformers>=4.35.2",
test = ["diffusers", "transformers", "piq", "invisible-watermark", "segment-anything-py", "torchvision"] "segment-anything-py>=1.0",
]
[tool.poetry.group.dev.dependencies]
black = "^23.1.0"
pytest = "^7.2.2"
isort = "^5.12.0"
ipykernel = "^6.22.0"
pyright = "^1.1.318"
ruff = "^0.0.281"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["hatchling"]
build-backend = "poetry.core.masonry.api" build-backend = "hatchling.build"
[tool.rye]
managed = true
dev-dependencies = [
"pyright == 1.1.333",
"ruff>=0.0.292",
"docformatter>=1.7.5",
"pytest>=7.4.2",
]
[tool.hatch.metadata]
allow-direct-references = true
[tool.rye.scripts]
lint = { chain = ["ruff format .", "ruff --fix ."] }
[tool.black] [tool.black]
line-length = 120 line-length = 120
@ -59,10 +70,18 @@ line-length = 120
ignore = [ ignore = [
"F722", # forward-annotation-syntax-error, because of Jaxtyping "F722", # forward-annotation-syntax-error, because of Jaxtyping
"E731", # do-not-assign-lambda "E731", # do-not-assign-lambda
"E501", # line-too-long, because Black (https://beta.ruff.rs/docs/faq/#is-ruff-compatible-with-black)
] ]
line-length = 120 line-length = 120
[tool.docformatter]
black = true
[tool.isort]
profile = "black"
force_sort_within_sections = true
order_by_type = true
combine_as_imports = true
[tool.pyright] [tool.pyright]
include = ["src/refiners", "tests", "scripts"] include = ["src/refiners", "tests", "scripts"]
strict = ["*"] strict = ["*"]

95
requirements.lock Normal file
View file

@ -0,0 +1,95 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: []
# all-features: true
-e file:.
aiohttp==3.9.1
aiosignal==1.3.1
annotated-types==0.6.0
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.1.0
bitsandbytes==0.41.3
certifi==2023.11.17
charset-normalizer==3.3.2
click==8.1.7
datasets==2.15.0
diffusers==0.24.0
dill==0.3.7
docker-pycreds==0.4.0
filelock==3.13.1
frozenlist==1.4.0
fsspec==2023.10.0
gitdb==4.0.11
gitpython==3.1.40
huggingface-hub==0.19.4
idna==3.6
importlib-metadata==7.0.0
invisible-watermark==0.2.0
jaxtyping==0.2.24
jinja2==3.1.2
loguru==0.7.2
markupsafe==2.1.3
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
networkx==3.2.1
numpy==1.26.2
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
opencv-python==4.8.1.78
packaging==23.2
pandas==2.1.3
pillow==10.1.0
piq==0.8.0
prodigyopt==1.0
protobuf==4.25.1
psutil==5.9.6
pyarrow==14.0.1
pyarrow-hotfix==0.6
pydantic==2.5.2
pydantic-core==2.14.5
python-dateutil==2.8.2
pytz==2023.3.post1
pywavelets==1.5.0
pyyaml==6.0.1
regex==2023.10.3
requests==2.31.0
safetensors==0.4.1
scipy==1.11.4
segment-anything-py==1.0
sentry-sdk==1.38.0
setproctitle==1.3.3
six==1.16.0
smmap==5.0.1
sympy==1.12
tokenizers==0.15.0
torch==2.1.1
torchvision==0.16.1
tqdm==4.66.1
transformers==4.35.2
triton==2.1.0
typeguard==2.13.3
typing-extensions==4.8.0
tzdata==2023.3
urllib3==2.1.0
wandb==0.16.1
xxhash==3.4.1
yarl==1.9.4
zipp==3.17.0
# The following packages are considered to be unsafe in a requirements file:
setuptools==69.0.2

View file

@ -16,7 +16,9 @@ class Args(argparse.Namespace):
def setup_converter(args: Args) -> ModelConverter: def setup_converter(args: Args) -> ModelConverter:
target = LatentDiffusionAutoencoder() target = LatentDiffusionAutoencoder()
source: nn.Module = AutoencoderKL.from_pretrained(pretrained_model_name_or_path=args.source_path, subfolder=args.subfolder) # type: ignore source: nn.Module = AutoencoderKL.from_pretrained( # type: ignore
pretrained_model_name_or_path=args.source_path, subfolder=args.subfolder
) # type: ignore
x = torch.randn(1, 3, 512, 512) x = torch.randn(1, 3, 512, 512)
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose) converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
if not converter.run(source_args=(x,)): if not converter.run(source_args=(x,)):

View file

@ -51,7 +51,9 @@ def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
def convert_point_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]: def convert_point_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
manual_seed(seed=0) manual_seed(seed=0)
point_embeddings: list[Tensor] = [pe.weight for pe in prompt_encoder.point_embeddings] + [prompt_encoder.not_a_point_embed.weight] # type: ignore point_embeddings: list[Tensor] = [pe.weight for pe in prompt_encoder.point_embeddings] + [
prompt_encoder.not_a_point_embed.weight
] # type: ignore
pe = prompt_encoder.pe_layer.positional_encoding_gaussian_matrix # type: ignore pe = prompt_encoder.pe_layer.positional_encoding_gaussian_matrix # type: ignore
assert isinstance(pe, Tensor) assert isinstance(pe, Tensor)
state_dict: dict[str, Tensor] = { state_dict: dict[str, Tensor] = {
@ -161,8 +163,14 @@ def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]:
assert mapping is not None assert mapping is not None
mapping["IOUMaskEncoder"] = "iou_token" mapping["IOUMaskEncoder"] = "iou_token"
state_dict = converter._convert_state_dict(source_state_dict=mask_decoder.state_dict(), target_state_dict=refiners_mask_decoder.state_dict(), state_dict_mapping=mapping) # type: ignore state_dict = converter._convert_state_dict( # type: ignore
state_dict["IOUMaskEncoder.weight"] = torch.cat(tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0) # type: ignore source_state_dict=mask_decoder.state_dict(),
target_state_dict=refiners_mask_decoder.state_dict(),
state_dict_mapping=mapping,
)
state_dict["IOUMaskEncoder.weight"] = torch.cat(
tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0
) # type: ignore
refiners_mask_decoder.load_state_dict(state_dict=state_dict) refiners_mask_decoder.load_state_dict(state_dict=state_dict)

View file

@ -102,5 +102,5 @@ class InformativeDrawings(fl.Chain):
dtype=dtype, dtype=dtype,
), ),
fl.Sigmoid(), fl.Sigmoid(),
) ),
) )

View file

@ -42,10 +42,13 @@ class DDIM(Scheduler):
else tensor(data=[0], device=self.device, dtype=self.dtype) else tensor(data=[0], device=self.device, dtype=self.dtype)
), ),
) )
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], ( current_scale_factor, previous_scale_factor = (
self.cumulative_scale_factors[timestep],
(
self.cumulative_scale_factors[previous_timestep] self.cumulative_scale_factors[previous_timestep]
if previous_timestep > 0 if previous_timestep > 0
else self.cumulative_scale_factors[0] else self.cumulative_scale_factors[0]
),
) )
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise

View file

@ -65,9 +65,7 @@ class DDPM(Scheduler):
), ),
) )
current_factor = current_cumulative_factor / previous_cumulative_scale_factor current_factor = current_cumulative_factor / previous_cumulative_scale_factor
estimated_denoised_data = ( estimated_denoised_data = (x - (1 - current_cumulative_factor) ** 0.5 * noise) / current_cumulative_factor**0.5
x - (1 - current_cumulative_factor) ** 0.5 * noise
) / current_cumulative_factor**0.5
estimated_denoised_data = estimated_denoised_data.clamp(-1, 1) estimated_denoised_data = estimated_denoised_data.clamp(-1, 1)
original_data_coeff = (previous_cumulative_scale_factor**0.5 * (1 - current_factor)) / ( original_data_coeff = (previous_cumulative_scale_factor**0.5 * (1 - current_factor)) / (
1 - current_cumulative_factor 1 - current_cumulative_factor

View file

@ -33,7 +33,9 @@ def stabilityai_unclip_weights_path(test_weights_path: Path):
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def ref_encoder(stabilityai_unclip_weights_path: Path, test_device: torch.device) -> CLIPVisionModelWithProjection: def ref_encoder(stabilityai_unclip_weights_path: Path, test_device: torch.device) -> CLIPVisionModelWithProjection:
return CLIPVisionModelWithProjection.from_pretrained(stabilityai_unclip_weights_path, subfolder="image_encoder").to(test_device) # type: ignore return CLIPVisionModelWithProjection.from_pretrained(stabilityai_unclip_weights_path, subfolder="image_encoder").to( # type: ignore
test_device # type: ignore
)
def test_encoder( def test_encoder(

View file

@ -21,7 +21,7 @@ def test_dpm_solver_diffusers():
for step, timestep in enumerate(diffusers_scheduler.timesteps): for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step) refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
assert allclose(diffusers_output, refiners_output), f"outputs differ at step {step}" assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
def test_ddim_solver_diffusers(): def test_ddim_solver_diffusers():

View file

@ -82,11 +82,13 @@ def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder:
assert pooled_embedding.shape == torch.Size([1, 1280]) assert pooled_embedding.shape == torch.Size([1, 1280])
embedding_1, embedding_2 = cast( embedding_1, embedding_2 = cast(
tuple[Tensor, Tensor], prompt_embeds.split(split_size=[768, 1280], dim=-1) # type: ignore tuple[Tensor, Tensor],
prompt_embeds.split(split_size=[768, 1280], dim=-1), # type: ignore
) )
rembedding_1, rembedding_2 = cast( rembedding_1, rembedding_2 = cast(
tuple[Tensor, Tensor], double_embedding.split(split_size=[768, 1280], dim=-1) # type: ignore tuple[Tensor, Tensor],
double_embedding.split(split_size=[768, 1280], dim=-1), # type: ignore
) )
assert torch.allclose(input=embedding_1, other=rembedding_1, rtol=1e-3, atol=1e-3) assert torch.allclose(input=embedding_1, other=rembedding_1, rtol=1e-3, atol=1e-3)

View file

@ -264,8 +264,14 @@ def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> N
assert mapping is not None assert mapping is not None
mapping["IOUMaskEncoder"] = "iou_token" mapping["IOUMaskEncoder"] = "iou_token"
state_dict = converter._convert_state_dict(source_state_dict=facebook_mask_decoder.state_dict(), target_state_dict=refiners_mask_decoder.state_dict(), state_dict_mapping=mapping) # type: ignore state_dict = converter._convert_state_dict( # type: ignore
state_dict["IOUMaskEncoder.weight"] = torch.cat([facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0) # type: ignore source_state_dict=facebook_mask_decoder.state_dict(),
target_state_dict=refiners_mask_decoder.state_dict(),
state_dict_mapping=mapping,
)
state_dict["IOUMaskEncoder.weight"] = torch.cat(
[facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0
) # type: ignore
refiners_mask_decoder.load_state_dict(state_dict=state_dict) refiners_mask_decoder.load_state_dict(state_dict=state_dict)
facebook_output = facebook_mask_decoder(**inputs) facebook_output = facebook_mask_decoder(**inputs)