mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 07:08:45 +00:00
replace poetry by rye for python dependency management
Co-authored-by: Cédric Deltheil <cedric@deltheil.me> Co-authored-by: Pierre Chapuis <git@catwell.info>
This commit is contained in:
parent
807ef5551c
commit
86c54977b9
39
.github/workflows/ci.yml
vendored
39
.github/workflows/ci.yml
vendored
|
@ -1,34 +1,35 @@
|
|||
name: CI
|
||||
on: push
|
||||
jobs:
|
||||
|
||||
on: push
|
||||
|
||||
jobs:
|
||||
lint_and_typecheck:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up python
|
||||
id: setup-python
|
||||
uses: actions/setup-python@v4
|
||||
- name: Install Rye
|
||||
uses: eifinger/setup-rye@v1
|
||||
with:
|
||||
python-version: "3.10"
|
||||
enable-cache: true
|
||||
cache-prefix: 'refiners-rye-cache'
|
||||
|
||||
- name: Install Poetry
|
||||
uses: snok/install-poetry@v1
|
||||
with:
|
||||
virtualenvs-create: true
|
||||
virtualenvs-in-project: true
|
||||
installer-parallel: true
|
||||
- name: add home shims dir to PATH
|
||||
run: echo "$HOME/.rye/shims" >> $GITHUB_PATH
|
||||
|
||||
- name: poetry install
|
||||
run: poetry install --no-interaction --all-extras
|
||||
- name: pin python
|
||||
run: rye pin 3.10
|
||||
|
||||
- name: black
|
||||
run: poetry run black --check .
|
||||
- name: rye sync
|
||||
run: rye sync --all-features
|
||||
|
||||
- name: lint
|
||||
run: poetry run ruff check .
|
||||
- name: ruff format
|
||||
run: rye run ruff format --check .
|
||||
|
||||
- name: ruff check
|
||||
run: rye run ruff check .
|
||||
|
||||
- name: typecheck
|
||||
run: poetry run pyright
|
||||
run: rye run pyright
|
||||
|
|
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -29,3 +29,6 @@ wandb/
|
|||
|
||||
# env variable definitions file
|
||||
.env
|
||||
|
||||
# lock files
|
||||
requirements-dev.lock
|
||||
|
|
74
CONTRIBUTING.md
Normal file
74
CONTRIBUTING.md
Normal file
|
@ -0,0 +1,74 @@
|
|||
We are happy to accept contributions from everyone. Feel free to browse our [bounty list](https://www.finegrain.ai/bounties) to find a task you would like to work on.
|
||||
|
||||
This document describes the process for contributing to Refiners.
|
||||
|
||||
## Licensing
|
||||
|
||||
Refiners is a library that is freely available to use and modify under the MIT License. It's essential to exercise caution when using external code, as any code that can affect the licensing of Refiners, including proprietary code, should not be copied and pasted. It's worth noting that some open-source licenses can also be problematic. For instance, we'd need to redistribute an Apache 2.0 license if you're using code from an Apache 2.0 codebase.
|
||||
|
||||
## Design principles
|
||||
|
||||
We do not enforce strict rules on the design of the code, but we do have a few guidelines that we try to follow:
|
||||
|
||||
- No dead code. We keep the codebase clean and remove unnecessary code/functionality.
|
||||
- No unnecessary dependencies. We keep the number of dependencies to a minimum and only add new ones if necessary.
|
||||
- Separate concerns. We separate the code into different modules and avoid having too many dependencies between modules. In particular, we try not to revisit existing code/models when adding new functionality. Instead, we add new functionality in a separate module with the `Adapter` pattern.
|
||||
- Declarative style. We make the code as declarative, self-documenting, and easily read as possible. By reading the model's `repr`, you should understand how it works. We use explicit names for the different components of the models or the variables in the code.
|
||||
|
||||
## Setting up your environment
|
||||
|
||||
We use [Rye](https://rye-up.com/guide/installation/) to manage our development environment. Please follow the instructions on the Rye website to install it.
|
||||
|
||||
Once Rye is installed, you can clone the repository and run `rye sync` to install the dependencies.
|
||||
|
||||
## Linting
|
||||
|
||||
We use [ruff](https://docs.astral.sh/ruff/) to lint our code. You can lint your code by running.
|
||||
|
||||
```bash
|
||||
rye run lint
|
||||
```
|
||||
|
||||
We also enforce strict type checking with [pyright](https://github.com/microsoft/pyright). You can run the type checker with:
|
||||
|
||||
```bash
|
||||
rye run pyright
|
||||
```
|
||||
|
||||
## Running the tests
|
||||
|
||||
Running end-to-end tests is pretty compute-intensive, and you must convert all the model weights to the correct format before you can run them.
|
||||
|
||||
First, install test dependencies with:
|
||||
|
||||
```bash
|
||||
rye sync --features test
|
||||
```
|
||||
|
||||
Then, download and convert all the necessary weights. Be aware that this will use around 50 GB of disk space:
|
||||
|
||||
```bash
|
||||
./scripts/prepare-test-weights.sh
|
||||
```
|
||||
|
||||
To run all the tests, you will need to set the following environment variables:
|
||||
|
||||
```bash
|
||||
export REFINERS_TEST_DEVICE=cuda:0
|
||||
|
||||
export REFINERS_TEST_WEIGHTS_DIR=$(pwd)/tests/weights
|
||||
|
||||
rye run pytest
|
||||
```
|
||||
|
||||
In particular, `-k` is handy only to run tests that match a given expression, e.g.:
|
||||
|
||||
```bash
|
||||
rye run pytest -k diffusion_std_init_image tests/e2e/test_diffusion.py
|
||||
```
|
||||
|
||||
You can also run tests that are lightweight and will run on CPU:
|
||||
|
||||
```bash
|
||||
rye run pytest -k "not test_diffusion"
|
||||
```
|
25
README.md
25
README.md
|
@ -33,32 +33,17 @@ ______________________________________________________________________
|
|||
|
||||
### Install
|
||||
|
||||
Refiners is still an early stage project so we recommend using the `main` branch directly with [Poetry](https://python-poetry.org).
|
||||
|
||||
If you just want to use Refiners directly, clone the repository and run:
|
||||
Refiners is still an early stage project, and we do not release minor versions yet. We recommend
|
||||
installing the latest version via a git install:
|
||||
|
||||
```bash
|
||||
poetry install --all-extras
|
||||
pip install git+https://github.com/finegrain-ai/refiners.git
|
||||
```
|
||||
|
||||
There is currently [a bug with PyTorch 2.0.1 and Poetry](https://github.com/pytorch/pytorch/issues/100974), to work around it run:
|
||||
To include the training utils, use:
|
||||
|
||||
```bash
|
||||
poetry run pip install --upgrade torch torchvision
|
||||
```
|
||||
|
||||
If you want to depend on Refiners in your project which uses Poetry, you can do so:
|
||||
|
||||
```bash
|
||||
poetry add git+ssh://git@github.com:finegrain-ai/refiners.git#main
|
||||
```
|
||||
|
||||
If you want to run tests, we provide a script to download and convert all the necessary weights first. Be aware that this will use around 50 GB of disk space.
|
||||
|
||||
```bash
|
||||
poetry shell
|
||||
./scripts/prepare-test-weights.sh
|
||||
pytest
|
||||
pip install 'refiners[training] @ git+https://github.com/finegrain-ai/refiners.git'
|
||||
```
|
||||
|
||||
### Hello World
|
||||
|
|
3438
poetry.lock
generated
3438
poetry.lock
generated
File diff suppressed because it is too large
Load diff
111
pyproject.toml
111
pyproject.toml
|
@ -1,56 +1,67 @@
|
|||
[tool.poetry]
|
||||
[project]
|
||||
name = "refiners"
|
||||
version = "0.2.0"
|
||||
description = "The simplest way to train and run adapters on top of foundational models"
|
||||
authors = [
|
||||
"The Finegrain Team <bonjour@lagon.tech>",
|
||||
]
|
||||
authors = [{ name = "The Finegrain Team", email = "bonjour@lagon.tech" }]
|
||||
license = "MIT"
|
||||
dependencies = [
|
||||
"torch>=2.1.1",
|
||||
"safetensors>=0.4.0",
|
||||
"pillow>=10.1.0",
|
||||
"jaxtyping>=0.2.23",
|
||||
]
|
||||
readme = "README.md"
|
||||
packages = [{include = "refiners", from = "src"}]
|
||||
requires-python = ">= 3.10"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10,<3.12"
|
||||
jaxtyping = "^0.2.14"
|
||||
torch = "^2.1.0"
|
||||
safetensors = "^0.3.0"
|
||||
numpy = "^1.24.2"
|
||||
pillow = ">=9.5.0"
|
||||
datasets = {version = "^2.14.0", optional = true}
|
||||
tomli = {version = "^2.0.1", optional = true}
|
||||
wandb = {version = "^0.15.7", optional = true}
|
||||
loguru = {version = "^0.7.0", optional = true}
|
||||
bitsandbytes = {version = "^0.41.0", optional = true}
|
||||
prodigyopt = {version = "^1.0", optional = true}
|
||||
pydantic = {version = "~2.0.3", optional = true}
|
||||
# Added scipy as a work around until this PR gets merged:
|
||||
# https://github.com/TimDettmers/bitsandbytes/pull/525
|
||||
scipy = {version = "^1.11.1", optional = true}
|
||||
torchvision = {version = "^0.16.0", optional = true}
|
||||
diffusers = {version = "^0.21.0", optional = true}
|
||||
transformers = {version = "^4.27.4", optional = true}
|
||||
piq = {version = "^0.7.1", optional = true}
|
||||
invisible-watermark = {version = "^0.2.0", optional = true}
|
||||
# An unofficial Python package for Meta AI's Segment Anything Model:
|
||||
# https://github.com/opengeos/segment-anything
|
||||
segment-anything-py = {version = "1.0", optional = true}
|
||||
|
||||
[tool.poetry.extras]
|
||||
training = ["datasets", "tomli", "wandb", "loguru", "bitsandbytes", "prodigyopt", "pydantic", "scipy", "torchvision"]
|
||||
conversion = ["diffusers", "transformers", "segment-anything-py"]
|
||||
test = ["diffusers", "transformers", "piq", "invisible-watermark", "segment-anything-py", "torchvision"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^23.1.0"
|
||||
pytest = "^7.2.2"
|
||||
isort = "^5.12.0"
|
||||
ipykernel = "^6.22.0"
|
||||
pyright = "^1.1.318"
|
||||
ruff = "^0.0.281"
|
||||
[project.optional-dependencies]
|
||||
training = [
|
||||
"bitsandbytes>=0.41.2.post2",
|
||||
"pydantic>=2.5.2",
|
||||
"prodigyopt>=1.0",
|
||||
"torchvision>=0.16.1",
|
||||
"loguru>=0.7.2",
|
||||
"wandb>=0.16.0",
|
||||
# Added scipy as a work around until this PR gets merged:
|
||||
# https://github.com/TimDettmers/bitsandbytes/pull/525
|
||||
"scipy>=1.11.4",
|
||||
"datasets>=2.15.0",
|
||||
]
|
||||
test = [
|
||||
"diffusers>=0.24.0",
|
||||
"transformers>=4.35.2",
|
||||
"piq>=0.8.0",
|
||||
"invisible-watermark>=0.2.0",
|
||||
"torchvision>=0.16.1",
|
||||
# An unofficial Python package for Meta AI's Segment Anything Model:
|
||||
# https://github.com/opengeos/segment-anything
|
||||
"segment-anything-py>=1.0",
|
||||
]
|
||||
conversion = [
|
||||
"diffusers>=0.24.0",
|
||||
"transformers>=4.35.2",
|
||||
"segment-anything-py>=1.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
|
||||
[tool.rye]
|
||||
managed = true
|
||||
dev-dependencies = [
|
||||
"pyright == 1.1.333",
|
||||
"ruff>=0.0.292",
|
||||
"docformatter>=1.7.5",
|
||||
"pytest>=7.4.2",
|
||||
]
|
||||
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
[tool.rye.scripts]
|
||||
lint = { chain = ["ruff format .", "ruff --fix ."] }
|
||||
|
||||
[tool.black]
|
||||
line-length = 120
|
||||
|
@ -59,10 +70,18 @@ line-length = 120
|
|||
ignore = [
|
||||
"F722", # forward-annotation-syntax-error, because of Jaxtyping
|
||||
"E731", # do-not-assign-lambda
|
||||
"E501", # line-too-long, because Black (https://beta.ruff.rs/docs/faq/#is-ruff-compatible-with-black)
|
||||
]
|
||||
line-length = 120
|
||||
|
||||
[tool.docformatter]
|
||||
black = true
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
force_sort_within_sections = true
|
||||
order_by_type = true
|
||||
combine_as_imports = true
|
||||
|
||||
[tool.pyright]
|
||||
include = ["src/refiners", "tests", "scripts"]
|
||||
strict = ["*"]
|
||||
|
|
95
requirements.lock
Normal file
95
requirements.lock
Normal file
|
@ -0,0 +1,95 @@
|
|||
# generated by rye
|
||||
# use `rye lock` or `rye sync` to update this lockfile
|
||||
#
|
||||
# last locked with the following flags:
|
||||
# pre: false
|
||||
# features: []
|
||||
# all-features: true
|
||||
|
||||
-e file:.
|
||||
aiohttp==3.9.1
|
||||
aiosignal==1.3.1
|
||||
annotated-types==0.6.0
|
||||
appdirs==1.4.4
|
||||
async-timeout==4.0.3
|
||||
attrs==23.1.0
|
||||
bitsandbytes==0.41.3
|
||||
certifi==2023.11.17
|
||||
charset-normalizer==3.3.2
|
||||
click==8.1.7
|
||||
datasets==2.15.0
|
||||
diffusers==0.24.0
|
||||
dill==0.3.7
|
||||
docker-pycreds==0.4.0
|
||||
filelock==3.13.1
|
||||
frozenlist==1.4.0
|
||||
fsspec==2023.10.0
|
||||
gitdb==4.0.11
|
||||
gitpython==3.1.40
|
||||
huggingface-hub==0.19.4
|
||||
idna==3.6
|
||||
importlib-metadata==7.0.0
|
||||
invisible-watermark==0.2.0
|
||||
jaxtyping==0.2.24
|
||||
jinja2==3.1.2
|
||||
loguru==0.7.2
|
||||
markupsafe==2.1.3
|
||||
mpmath==1.3.0
|
||||
multidict==6.0.4
|
||||
multiprocess==0.70.15
|
||||
networkx==3.2.1
|
||||
numpy==1.26.2
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
nvidia-cudnn-cu12==8.9.2.26
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
nvidia-nccl-cu12==2.18.1
|
||||
nvidia-nvjitlink-cu12==12.3.101
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
opencv-python==4.8.1.78
|
||||
packaging==23.2
|
||||
pandas==2.1.3
|
||||
pillow==10.1.0
|
||||
piq==0.8.0
|
||||
prodigyopt==1.0
|
||||
protobuf==4.25.1
|
||||
psutil==5.9.6
|
||||
pyarrow==14.0.1
|
||||
pyarrow-hotfix==0.6
|
||||
pydantic==2.5.2
|
||||
pydantic-core==2.14.5
|
||||
python-dateutil==2.8.2
|
||||
pytz==2023.3.post1
|
||||
pywavelets==1.5.0
|
||||
pyyaml==6.0.1
|
||||
regex==2023.10.3
|
||||
requests==2.31.0
|
||||
safetensors==0.4.1
|
||||
scipy==1.11.4
|
||||
segment-anything-py==1.0
|
||||
sentry-sdk==1.38.0
|
||||
setproctitle==1.3.3
|
||||
six==1.16.0
|
||||
smmap==5.0.1
|
||||
sympy==1.12
|
||||
tokenizers==0.15.0
|
||||
torch==2.1.1
|
||||
torchvision==0.16.1
|
||||
tqdm==4.66.1
|
||||
transformers==4.35.2
|
||||
triton==2.1.0
|
||||
typeguard==2.13.3
|
||||
typing-extensions==4.8.0
|
||||
tzdata==2023.3
|
||||
urllib3==2.1.0
|
||||
wandb==0.16.1
|
||||
xxhash==3.4.1
|
||||
yarl==1.9.4
|
||||
zipp==3.17.0
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
setuptools==69.0.2
|
|
@ -16,7 +16,9 @@ class Args(argparse.Namespace):
|
|||
|
||||
def setup_converter(args: Args) -> ModelConverter:
|
||||
target = LatentDiffusionAutoencoder()
|
||||
source: nn.Module = AutoencoderKL.from_pretrained(pretrained_model_name_or_path=args.source_path, subfolder=args.subfolder) # type: ignore
|
||||
source: nn.Module = AutoencoderKL.from_pretrained( # type: ignore
|
||||
pretrained_model_name_or_path=args.source_path, subfolder=args.subfolder
|
||||
) # type: ignore
|
||||
x = torch.randn(1, 3, 512, 512)
|
||||
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
|
||||
if not converter.run(source_args=(x,)):
|
||||
|
|
|
@ -51,7 +51,9 @@ def convert_mask_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
|
|||
|
||||
def convert_point_encoder(prompt_encoder: nn.Module) -> dict[str, Tensor]:
|
||||
manual_seed(seed=0)
|
||||
point_embeddings: list[Tensor] = [pe.weight for pe in prompt_encoder.point_embeddings] + [prompt_encoder.not_a_point_embed.weight] # type: ignore
|
||||
point_embeddings: list[Tensor] = [pe.weight for pe in prompt_encoder.point_embeddings] + [
|
||||
prompt_encoder.not_a_point_embed.weight
|
||||
] # type: ignore
|
||||
pe = prompt_encoder.pe_layer.positional_encoding_gaussian_matrix # type: ignore
|
||||
assert isinstance(pe, Tensor)
|
||||
state_dict: dict[str, Tensor] = {
|
||||
|
@ -161,8 +163,14 @@ def convert_mask_decoder(mask_decoder: nn.Module) -> dict[str, Tensor]:
|
|||
assert mapping is not None
|
||||
mapping["IOUMaskEncoder"] = "iou_token"
|
||||
|
||||
state_dict = converter._convert_state_dict(source_state_dict=mask_decoder.state_dict(), target_state_dict=refiners_mask_decoder.state_dict(), state_dict_mapping=mapping) # type: ignore
|
||||
state_dict["IOUMaskEncoder.weight"] = torch.cat(tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0) # type: ignore
|
||||
state_dict = converter._convert_state_dict( # type: ignore
|
||||
source_state_dict=mask_decoder.state_dict(),
|
||||
target_state_dict=refiners_mask_decoder.state_dict(),
|
||||
state_dict_mapping=mapping,
|
||||
)
|
||||
state_dict["IOUMaskEncoder.weight"] = torch.cat(
|
||||
tensors=[mask_decoder.iou_token.weight, mask_decoder.mask_tokens.weight], dim=0
|
||||
) # type: ignore
|
||||
|
||||
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
|
||||
|
||||
|
|
|
@ -102,5 +102,5 @@ class InformativeDrawings(fl.Chain):
|
|||
dtype=dtype,
|
||||
),
|
||||
fl.Sigmoid(),
|
||||
)
|
||||
),
|
||||
)
|
||||
|
|
|
@ -42,10 +42,13 @@ class DDIM(Scheduler):
|
|||
else tensor(data=[0], device=self.device, dtype=self.dtype)
|
||||
),
|
||||
)
|
||||
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
|
||||
current_scale_factor, previous_scale_factor = (
|
||||
self.cumulative_scale_factors[timestep],
|
||||
(
|
||||
self.cumulative_scale_factors[previous_timestep]
|
||||
if previous_timestep > 0
|
||||
else self.cumulative_scale_factors[0]
|
||||
),
|
||||
)
|
||||
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
|
||||
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
|
||||
|
|
|
@ -65,9 +65,7 @@ class DDPM(Scheduler):
|
|||
),
|
||||
)
|
||||
current_factor = current_cumulative_factor / previous_cumulative_scale_factor
|
||||
estimated_denoised_data = (
|
||||
x - (1 - current_cumulative_factor) ** 0.5 * noise
|
||||
) / current_cumulative_factor**0.5
|
||||
estimated_denoised_data = (x - (1 - current_cumulative_factor) ** 0.5 * noise) / current_cumulative_factor**0.5
|
||||
estimated_denoised_data = estimated_denoised_data.clamp(-1, 1)
|
||||
original_data_coeff = (previous_cumulative_scale_factor**0.5 * (1 - current_factor)) / (
|
||||
1 - current_cumulative_factor
|
||||
|
|
|
@ -33,7 +33,9 @@ def stabilityai_unclip_weights_path(test_weights_path: Path):
|
|||
|
||||
@pytest.fixture(scope="module")
|
||||
def ref_encoder(stabilityai_unclip_weights_path: Path, test_device: torch.device) -> CLIPVisionModelWithProjection:
|
||||
return CLIPVisionModelWithProjection.from_pretrained(stabilityai_unclip_weights_path, subfolder="image_encoder").to(test_device) # type: ignore
|
||||
return CLIPVisionModelWithProjection.from_pretrained(stabilityai_unclip_weights_path, subfolder="image_encoder").to( # type: ignore
|
||||
test_device # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def test_encoder(
|
||||
|
|
|
@ -21,7 +21,7 @@ def test_dpm_solver_diffusers():
|
|||
for step, timestep in enumerate(diffusers_scheduler.timesteps):
|
||||
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore
|
||||
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
|
||||
assert allclose(diffusers_output, refiners_output), f"outputs differ at step {step}"
|
||||
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
|
||||
|
||||
|
||||
def test_ddim_solver_diffusers():
|
||||
|
|
|
@ -82,11 +82,13 @@ def test_double_text_encoder(diffusers_sdxl: DiffusersSDXL, double_text_encoder:
|
|||
assert pooled_embedding.shape == torch.Size([1, 1280])
|
||||
|
||||
embedding_1, embedding_2 = cast(
|
||||
tuple[Tensor, Tensor], prompt_embeds.split(split_size=[768, 1280], dim=-1) # type: ignore
|
||||
tuple[Tensor, Tensor],
|
||||
prompt_embeds.split(split_size=[768, 1280], dim=-1), # type: ignore
|
||||
)
|
||||
|
||||
rembedding_1, rembedding_2 = cast(
|
||||
tuple[Tensor, Tensor], double_embedding.split(split_size=[768, 1280], dim=-1) # type: ignore
|
||||
tuple[Tensor, Tensor],
|
||||
double_embedding.split(split_size=[768, 1280], dim=-1), # type: ignore
|
||||
)
|
||||
|
||||
assert torch.allclose(input=embedding_1, other=rembedding_1, rtol=1e-3, atol=1e-3)
|
||||
|
|
|
@ -264,8 +264,14 @@ def test_mask_decoder(facebook_sam_h: FacebookSAM, sam_h: SegmentAnythingH) -> N
|
|||
assert mapping is not None
|
||||
mapping["IOUMaskEncoder"] = "iou_token"
|
||||
|
||||
state_dict = converter._convert_state_dict(source_state_dict=facebook_mask_decoder.state_dict(), target_state_dict=refiners_mask_decoder.state_dict(), state_dict_mapping=mapping) # type: ignore
|
||||
state_dict["IOUMaskEncoder.weight"] = torch.cat([facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0) # type: ignore
|
||||
state_dict = converter._convert_state_dict( # type: ignore
|
||||
source_state_dict=facebook_mask_decoder.state_dict(),
|
||||
target_state_dict=refiners_mask_decoder.state_dict(),
|
||||
state_dict_mapping=mapping,
|
||||
)
|
||||
state_dict["IOUMaskEncoder.weight"] = torch.cat(
|
||||
[facebook_mask_decoder.iou_token.weight, facebook_mask_decoder.mask_tokens.weight], dim=0
|
||||
) # type: ignore
|
||||
refiners_mask_decoder.load_state_dict(state_dict=state_dict)
|
||||
|
||||
facebook_output = facebook_mask_decoder(**inputs)
|
||||
|
|
Loading…
Reference in a new issue