mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-09 23:12:02 +00:00
refactor: convert bash script to python
Ran successfully to completion. But on a repeat run `convert_unclip` didn't pass the hash check for some reason. - fix inpainting model download urls - shows a progress bar for downloads - skips downloading existing files - uses a temporary file to prevent partial downloads - can do a dry run to check if url is valid `DRY_RUN=1 python scripts/prepare_test_weights.py` - displays the downloaded file hash
This commit is contained in:
parent
77fb8032c2
commit
5ca1549c96
|
@ -48,7 +48,7 @@ rye sync --features test,conversion
|
|||
Then, download and convert all the necessary weights. Be aware that this will use around 50 GB of disk space:
|
||||
|
||||
```bash
|
||||
./scripts/prepare-test-weights.sh
|
||||
rye run python scripts/prepare_test_weights.py
|
||||
```
|
||||
|
||||
Finally, run the tests:
|
||||
|
|
|
@ -41,6 +41,8 @@ conversion = [
|
|||
"diffusers>=0.24.0",
|
||||
"transformers>=4.35.2",
|
||||
"segment-anything-py>=1.0",
|
||||
"requests>=2.26.0",
|
||||
"tqdm>=4.62.3",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
|
|
|
@ -18,8 +18,11 @@ class Args(argparse.Namespace):
|
|||
|
||||
def setup_converter(args: Args) -> ModelConverter:
|
||||
target = LatentDiffusionAutoencoder()
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
source: nn.Module = AutoencoderKL.from_pretrained( # type: ignore
|
||||
pretrained_model_name_or_path=args.source_path, subfolder=args.subfolder
|
||||
pretrained_model_name_or_path=args.source_path,
|
||||
subfolder=args.subfolder,
|
||||
low_cpu_mem_usage=False,
|
||||
) # type: ignore
|
||||
x = torch.randn(1, 3, 512, 512)
|
||||
converter = ModelConverter(source_model=source, target_model=target, skip_output_check=True, verbose=args.verbose)
|
||||
|
|
|
@ -22,7 +22,11 @@ class Args(argparse.Namespace):
|
|||
|
||||
@torch.no_grad()
|
||||
def convert(args: Args) -> dict[str, torch.Tensor]:
|
||||
controlnet_src: nn.Module = ControlNetModel.from_pretrained(pretrained_model_name_or_path=args.source_path) # type: ignore
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
controlnet_src: nn.Module = ControlNetModel.from_pretrained( # type: ignore
|
||||
pretrained_model_name_or_path=args.source_path,
|
||||
low_cpu_mem_usage=False,
|
||||
)
|
||||
unet = SD1UNet(in_channels=4)
|
||||
adapter = SD1ControlnetAdapter(unet, name="mycn").inject()
|
||||
controlnet = unet.Controlnet
|
||||
|
|
|
@ -40,7 +40,11 @@ class Args(argparse.Namespace):
|
|||
@torch.no_grad()
|
||||
def process(args: Args) -> None:
|
||||
diffusers_state_dict = cast(dict[str, Tensor], torch.load(args.source_path, map_location="cpu")) # type: ignore
|
||||
diffusers_sd = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path=args.base_model) # type: ignore
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
diffusers_sd = DiffusionPipeline.from_pretrained( # type: ignore
|
||||
pretrained_model_name_or_path=args.base_model,
|
||||
low_cpu_mem_usage=False,
|
||||
)
|
||||
diffusers_model = cast(fl.Module, diffusers_sd.unet) # type: ignore
|
||||
|
||||
refiners_model = SD1UNet(in_channels=4)
|
||||
|
|
|
@ -48,7 +48,11 @@ if __name__ == "__main__":
|
|||
|
||||
sdxl = "xl" in args.source_path
|
||||
target = ConditionEncoderXL() if sdxl else ConditionEncoder()
|
||||
source: nn.Module = T2IAdapter.from_pretrained(pretrained_model_name_or_path=args.source_path) # type: ignore
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
source: nn.Module = T2IAdapter.from_pretrained( # type: ignore
|
||||
pretrained_model_name_or_path=args.source_path,
|
||||
low_cpu_mem_usage=False,
|
||||
)
|
||||
assert isinstance(source, nn.Module), "Source model is not a nn.Module"
|
||||
|
||||
x = torch.randn(1, 3, 1024, 1024) if sdxl else torch.randn(1, 3, 512, 512)
|
||||
|
|
|
@ -17,8 +17,11 @@ class Args(argparse.Namespace):
|
|||
|
||||
|
||||
def setup_converter(args: Args) -> ModelConverter:
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
source: nn.Module = UNet2DConditionModel.from_pretrained( # type: ignore
|
||||
pretrained_model_name_or_path=args.source_path, subfolder="unet"
|
||||
pretrained_model_name_or_path=args.source_path,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=False,
|
||||
)
|
||||
source_in_channels: int = source.config.in_channels # type: ignore
|
||||
source_clip_embedding_dim: int = source.config.cross_attention_dim # type: ignore
|
||||
|
|
|
@ -21,8 +21,11 @@ class Args(argparse.Namespace):
|
|||
|
||||
|
||||
def setup_converter(args: Args) -> ModelConverter:
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
source: nn.Module = CLIPVisionModelWithProjection.from_pretrained( # type: ignore
|
||||
pretrained_model_name_or_path=args.source_path, subfolder=args.subfolder
|
||||
pretrained_model_name_or_path=args.source_path,
|
||||
subfolder=args.subfolder,
|
||||
low_cpu_mem_usage=False,
|
||||
)
|
||||
assert isinstance(source, nn.Module), "Source model is not a nn.Module"
|
||||
architecture: str = source.config.architectures[0] # type: ignore
|
||||
|
|
|
@ -22,8 +22,11 @@ class Args(argparse.Namespace):
|
|||
|
||||
|
||||
def setup_converter(args: Args) -> ModelConverter:
|
||||
# low_cpu_mem_usage=False stops some annoying console messages us to `pip install accelerate`
|
||||
source: nn.Module = CLIPTextModelWithProjection.from_pretrained( # type: ignore
|
||||
pretrained_model_name_or_path=args.source_path, subfolder=args.subfolder
|
||||
pretrained_model_name_or_path=args.source_path,
|
||||
subfolder=args.subfolder,
|
||||
low_cpu_mem_usage=False,
|
||||
)
|
||||
assert isinstance(source, nn.Module), "Source model is not a nn.Module"
|
||||
architecture: str = source.config.architectures[0] # type: ignore
|
||||
|
|
|
@ -1,389 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# This script downloads source weights from Hugging Face using cURL.
|
||||
# We want to convert from local directories (not use the network in conversion
|
||||
# scripts) but we also do not want to clone full repositories to save space.
|
||||
# This approach is more verbose but it lets us pick and choose.
|
||||
|
||||
set -x
|
||||
cd "$(dirname "$0")/.."
|
||||
|
||||
die () { >&2 echo "$@" ; exit 1 ; }
|
||||
|
||||
check_hash () { # (path, hash)
|
||||
_path="$1"; shift
|
||||
_expected="$1"
|
||||
_found="$(b2sum -l 32 "$_path" | cut -d' ' -f1)"
|
||||
[ "$_found" = "$_expected" ] || die "invalid hash for $_path ($_found != $_expected)"
|
||||
}
|
||||
|
||||
download_sd_text_encoder () { # (base="runwayml/stable-diffusion-v1-5" subdir="text_encoder")
|
||||
_base="$1"; shift
|
||||
_subdir="$1"
|
||||
mkdir tests/weights/$_base/$_subdir
|
||||
pushd tests/weights/$_base/$_subdir
|
||||
curl -LO https://huggingface.co/$_base/raw/main/$_subdir/config.json
|
||||
curl -LO https://huggingface.co/$_base/resolve/main/$_subdir/model.safetensors
|
||||
popd
|
||||
}
|
||||
|
||||
download_sd_tokenizer () { # (base="runwayml/stable-diffusion-v1-5" subdir="tokenizer")
|
||||
_base="$1"; shift
|
||||
_subdir="$1"
|
||||
mkdir tests/weights/$_base/$_subdir
|
||||
pushd tests/weights/$_base/$_subdir
|
||||
curl -LO https://huggingface.co/$_base/raw/main/$_subdir/merges.txt
|
||||
curl -LO https://huggingface.co/$_base/raw/main/$_subdir/special_tokens_map.json
|
||||
curl -LO https://huggingface.co/$_base/raw/main/$_subdir/tokenizer_config.json
|
||||
curl -LO https://huggingface.co/$_base/raw/main/$_subdir/vocab.json
|
||||
popd
|
||||
}
|
||||
|
||||
download_sd_base () { # (base="runwayml/stable-diffusion-v1-5")
|
||||
_base="$1"
|
||||
|
||||
# Inpainting source does not have safetensors.
|
||||
_ext="safetensors"
|
||||
grep -q "inpainting" <<< $_base && _ext="bin"
|
||||
|
||||
mkdir -p tests/weights/$_base
|
||||
pushd tests/weights/$_base
|
||||
curl -LO https://huggingface.co/$_base/raw/main/model_index.json
|
||||
mkdir scheduler unet vae
|
||||
pushd scheduler
|
||||
curl -LO https://huggingface.co/$_base/raw/main/scheduler/scheduler_config.json
|
||||
popd
|
||||
pushd unet
|
||||
curl -LO https://huggingface.co/$_base/raw/main/unet/config.json
|
||||
curl -LO https://huggingface.co/$_base/resolve/main/unet/diffusion_pytorch_model.$_ext
|
||||
popd
|
||||
pushd vae
|
||||
curl -LO https://huggingface.co/$_base/raw/main/vae/config.json
|
||||
curl -LO https://huggingface.co/$_base/resolve/main/vae/diffusion_pytorch_model.$_ext
|
||||
popd
|
||||
popd
|
||||
download_sd_text_encoder $_base text_encoder
|
||||
download_sd_tokenizer $_base tokenizer
|
||||
}
|
||||
|
||||
download_sd15 () { # (base="runwayml/stable-diffusion-v1-5")
|
||||
_base="$1"
|
||||
download_sd_base $_base
|
||||
pushd tests/weights/$_base
|
||||
mkdir feature_extractor safety_checker
|
||||
pushd feature_extractor
|
||||
curl -LO https://huggingface.co/$_base/raw/main/feature_extractor/preprocessor_config.json
|
||||
popd
|
||||
pushd safety_checker
|
||||
curl -LO https://huggingface.co/$_base/raw/main/safety_checker/config.json
|
||||
curl -LO https://huggingface.co/$_base/resolve/main/safety_checker/model.safetensors
|
||||
popd
|
||||
popd
|
||||
}
|
||||
|
||||
download_sdxl () { # (base="stabilityai/stable-diffusion-xl-base-1.0")
|
||||
_base="$1"
|
||||
download_sd_base $_base
|
||||
download_sd_text_encoder $_base text_encoder_2
|
||||
download_sd_tokenizer $_base tokenizer_2
|
||||
}
|
||||
|
||||
download_vae_ft_mse () {
|
||||
mkdir -p tests/weights/stabilityai/sd-vae-ft-mse
|
||||
pushd tests/weights/stabilityai/sd-vae-ft-mse
|
||||
curl -LO https://huggingface.co/stabilityai/sd-vae-ft-mse/raw/main/config.json
|
||||
curl -LO https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.safetensors
|
||||
popd
|
||||
}
|
||||
|
||||
download_lora () {
|
||||
mkdir -p tests/weights/pcuenq/pokemon-lora
|
||||
pushd tests/weights/pcuenq/pokemon-lora
|
||||
curl -LO https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin
|
||||
popd
|
||||
}
|
||||
|
||||
download_preprocessors () {
|
||||
mkdir -p tests/weights/carolineec/informativedrawings
|
||||
pushd tests/weights/carolineec/informativedrawings
|
||||
curl -LO https://huggingface.co/spaces/carolineec/informativedrawings/resolve/main/model2.pth
|
||||
popd
|
||||
}
|
||||
|
||||
download_controlnet () {
|
||||
mkdir -p tests/weights/lllyasviel
|
||||
pushd tests/weights/lllyasviel
|
||||
mkdir control_v11p_sd15_canny
|
||||
pushd control_v11p_sd15_canny
|
||||
curl -LO https://huggingface.co/lllyasviel/control_v11p_sd15_canny/raw/main/config.json
|
||||
curl -LO https://huggingface.co/lllyasviel/control_v11p_sd15_canny/resolve/main/diffusion_pytorch_model.safetensors
|
||||
popd
|
||||
mkdir control_v11f1p_sd15_depth
|
||||
pushd control_v11f1p_sd15_depth
|
||||
curl -LO https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/raw/main/config.json
|
||||
curl -LO https://huggingface.co/lllyasviel/control_v11f1p_sd15_depth/resolve/main/diffusion_pytorch_model.safetensors
|
||||
popd
|
||||
mkdir control_v11p_sd15_normalbae
|
||||
pushd control_v11p_sd15_normalbae
|
||||
curl -LO https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/raw/main/config.json
|
||||
curl -LO https://huggingface.co/lllyasviel/control_v11p_sd15_normalbae/resolve/main/diffusion_pytorch_model.safetensors
|
||||
popd
|
||||
mkdir control_v11p_sd15_lineart
|
||||
pushd control_v11p_sd15_lineart
|
||||
curl -LO https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/raw/main/config.json
|
||||
curl -LO https://huggingface.co/lllyasviel/control_v11p_sd15_lineart/resolve/main/diffusion_pytorch_model.safetensors
|
||||
popd
|
||||
popd
|
||||
|
||||
mkdir -p tests/weights/mfidabel/controlnet-segment-anything
|
||||
pushd tests/weights/mfidabel/controlnet-segment-anything
|
||||
curl -LO https://huggingface.co/mfidabel/controlnet-segment-anything/raw/main/config.json
|
||||
curl -LO https://huggingface.co/mfidabel/controlnet-segment-anything/resolve/main/diffusion_pytorch_model.bin
|
||||
popd
|
||||
}
|
||||
|
||||
download_unclip () {
|
||||
mkdir -p tests/weights/stabilityai/stable-diffusion-2-1-unclip
|
||||
pushd tests/weights/stabilityai/stable-diffusion-2-1-unclip
|
||||
curl -LO https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/raw/main/model_index.json
|
||||
mkdir image_encoder
|
||||
pushd image_encoder
|
||||
curl -LO https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/raw/main/image_encoder/config.json
|
||||
curl -LO https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/resolve/main/image_encoder/model.safetensors
|
||||
popd
|
||||
popd
|
||||
}
|
||||
|
||||
download_ip_adapter () {
|
||||
mkdir -p tests/weights/h94/IP-Adapter
|
||||
pushd tests/weights/h94/IP-Adapter
|
||||
mkdir -p models
|
||||
pushd models
|
||||
curl -LO https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15.bin
|
||||
curl -LO https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus_sd15.bin
|
||||
popd
|
||||
mkdir -p sdxl_models
|
||||
pushd sdxl_models
|
||||
curl -LO https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl_vit-h.bin
|
||||
curl -LO https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin
|
||||
popd
|
||||
popd
|
||||
}
|
||||
|
||||
download_t2i_adapter () {
|
||||
mkdir -p tests/weights/TencentARC/t2iadapter_depth_sd15v2
|
||||
pushd tests/weights/TencentARC/t2iadapter_depth_sd15v2
|
||||
curl -LO https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2/raw/main/config.json
|
||||
curl -LO https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2/resolve/main/diffusion_pytorch_model.bin
|
||||
popd
|
||||
|
||||
mkdir -p tests/weights/TencentARC/t2i-adapter-canny-sdxl-1.0
|
||||
pushd tests/weights/TencentARC/t2i-adapter-canny-sdxl-1.0
|
||||
curl -LO https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0/raw/main/config.json
|
||||
curl -LO https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0/resolve/main/diffusion_pytorch_model.safetensors
|
||||
popd
|
||||
}
|
||||
|
||||
download_sam () {
|
||||
mkdir -p tests/weights
|
||||
pushd tests/weights
|
||||
curl -LO https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
||||
popd
|
||||
check_hash "tests/weights/sam_vit_h_4b8939.pth" 06785e66
|
||||
}
|
||||
|
||||
convert_sd15 () {
|
||||
python scripts/conversion/convert_transformers_clip_text_model.py \
|
||||
--from "tests/weights/runwayml/stable-diffusion-v1-5" \
|
||||
--to "tests/weights/CLIPTextEncoderL.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/CLIPTextEncoderL.safetensors" 6c9cbc59
|
||||
|
||||
python scripts/conversion/convert_diffusers_autoencoder_kl.py \
|
||||
--from "tests/weights/runwayml/stable-diffusion-v1-5" \
|
||||
--to "tests/weights/lda.safetensors"
|
||||
check_hash "tests/weights/lda.safetensors" 329e369c
|
||||
|
||||
python scripts/conversion/convert_diffusers_unet.py \
|
||||
--from "tests/weights/runwayml/stable-diffusion-v1-5" \
|
||||
--to "tests/weights/unet.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/unet.safetensors" f81ac65a
|
||||
|
||||
mkdir tests/weights/inpainting
|
||||
|
||||
python scripts/conversion/convert_diffusers_unet.py \
|
||||
--from "tests/weights/runwayml/stable-diffusion-inpainting" \
|
||||
--to "tests/weights/inpainting/unet.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/inpainting/unet.safetensors" c07a8c61
|
||||
}
|
||||
|
||||
convert_sdxl () {
|
||||
python scripts/conversion/convert_transformers_clip_text_model.py \
|
||||
--from "tests/weights/stabilityai/stable-diffusion-xl-base-1.0" \
|
||||
--to "tests/weights/DoubleCLIPTextEncoder.safetensors" \
|
||||
--subfolder2 text_encoder_2 \
|
||||
--half
|
||||
check_hash "tests/weights/DoubleCLIPTextEncoder.safetensors" 7f99c30b
|
||||
|
||||
python scripts/conversion/convert_diffusers_autoencoder_kl.py \
|
||||
--from "tests/weights/stabilityai/stable-diffusion-xl-base-1.0" \
|
||||
--to "tests/weights/sdxl-lda.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/sdxl-lda.safetensors" 7464e9dc
|
||||
|
||||
python scripts/conversion/convert_diffusers_unet.py \
|
||||
--from "tests/weights/stabilityai/stable-diffusion-xl-base-1.0" \
|
||||
--to "tests/weights/sdxl-unet.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/sdxl-unet.safetensors" 2e5c4911
|
||||
}
|
||||
|
||||
convert_vae_ft_mse () {
|
||||
python scripts/conversion/convert_diffusers_autoencoder_kl.py \
|
||||
--from "tests/weights/stabilityai/sd-vae-ft-mse" \
|
||||
--to "tests/weights/lda_ft_mse.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/lda_ft_mse.safetensors" 4d0bae7e
|
||||
}
|
||||
|
||||
convert_lora () {
|
||||
mkdir tests/weights/loras
|
||||
|
||||
python scripts/conversion/convert_diffusers_lora.py \
|
||||
--from "tests/weights/pcuenq/pokemon-lora/pytorch_lora_weights.bin" \
|
||||
--base-model "tests/weights/runwayml/stable-diffusion-v1-5" \
|
||||
--to "tests/weights/loras/pcuenq_pokemon_lora.safetensors"
|
||||
check_hash "tests/weights/loras/pcuenq_pokemon_lora.safetensors" a9d7e08e
|
||||
}
|
||||
|
||||
convert_preprocessors () {
|
||||
curl -L https://raw.githubusercontent.com/carolineec/informative-drawings/main/model.py \
|
||||
-o src/model.py
|
||||
python scripts/conversion/convert_informative_drawings.py \
|
||||
--from "tests/weights/carolineec/informativedrawings/model2.pth" \
|
||||
--to "tests/weights/informative-drawings.safetensors"
|
||||
rm -f src/model.py
|
||||
check_hash "tests/weights/informative-drawings.safetensors" 93dca207
|
||||
}
|
||||
|
||||
convert_controlnet () {
|
||||
mkdir tests/weights/controlnet
|
||||
|
||||
python scripts/conversion/convert_diffusers_controlnet.py \
|
||||
--from "tests/weights/lllyasviel/control_v11p_sd15_canny" \
|
||||
--to "tests/weights/controlnet/lllyasviel_control_v11p_sd15_canny.safetensors"
|
||||
check_hash "tests/weights/controlnet/lllyasviel_control_v11p_sd15_canny.safetensors" 9a1a48cf
|
||||
|
||||
python scripts/conversion/convert_diffusers_controlnet.py \
|
||||
--from "tests/weights/lllyasviel/control_v11f1p_sd15_depth" \
|
||||
--to "tests/weights/controlnet/lllyasviel_control_v11f1p_sd15_depth.safetensors"
|
||||
check_hash "tests/weights/controlnet/lllyasviel_control_v11f1p_sd15_depth.safetensors" bbe7e5a6
|
||||
|
||||
python scripts/conversion/convert_diffusers_controlnet.py \
|
||||
--from "tests/weights/lllyasviel/control_v11p_sd15_normalbae" \
|
||||
--to "tests/weights/controlnet/lllyasviel_control_v11p_sd15_normalbae.safetensors"
|
||||
check_hash "tests/weights/controlnet/lllyasviel_control_v11p_sd15_normalbae.safetensors" 9fa88ed5
|
||||
|
||||
python scripts/conversion/convert_diffusers_controlnet.py \
|
||||
--from "tests/weights/lllyasviel/control_v11p_sd15_lineart" \
|
||||
--to "tests/weights/controlnet/lllyasviel_control_v11p_sd15_lineart.safetensors"
|
||||
check_hash "tests/weights/controlnet/lllyasviel_control_v11p_sd15_lineart.safetensors" c29e8c03
|
||||
|
||||
python scripts/conversion/convert_diffusers_controlnet.py \
|
||||
--from "tests/weights/mfidabel/controlnet-segment-anything" \
|
||||
--to "tests/weights/controlnet/mfidabel_controlnet-segment-anything.safetensors"
|
||||
check_hash "tests/weights/controlnet/mfidabel_controlnet-segment-anything.safetensors" d536eebb
|
||||
}
|
||||
|
||||
convert_unclip () {
|
||||
python scripts/conversion/convert_transformers_clip_image_model.py \
|
||||
--from "tests/weights/stabilityai/stable-diffusion-2-1-unclip" \
|
||||
--to "tests/weights/CLIPImageEncoderH.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/CLIPImageEncoderH.safetensors" 4ddb44d2
|
||||
}
|
||||
|
||||
convert_ip_adapter () {
|
||||
python scripts/conversion/convert_diffusers_ip_adapter.py \
|
||||
--from "tests/weights/h94/IP-Adapter/models/ip-adapter_sd15.bin" \
|
||||
--to "tests/weights/ip-adapter_sd15.safetensors"
|
||||
check_hash "tests/weights/ip-adapter_sd15.safetensors" 9579b465
|
||||
|
||||
python scripts/conversion/convert_diffusers_ip_adapter.py \
|
||||
--from "tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter_sdxl_vit-h.bin" \
|
||||
--to "tests/weights/ip-adapter_sdxl_vit-h.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/ip-adapter_sdxl_vit-h.safetensors" 739504c6
|
||||
|
||||
python scripts/conversion/convert_diffusers_ip_adapter.py \
|
||||
--from "tests/weights/h94/IP-Adapter/models/ip-adapter-plus_sd15.bin" \
|
||||
--to "tests/weights/ip-adapter-plus_sd15.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/ip-adapter-plus_sd15.safetensors" 842b20e2
|
||||
|
||||
python scripts/conversion/convert_diffusers_ip_adapter.py \
|
||||
--from "tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin" \
|
||||
--to "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors" 0409974b
|
||||
}
|
||||
|
||||
convert_t2i_adapter () {
|
||||
mkdir tests/weights/T2I-Adapter
|
||||
python scripts/conversion/convert_diffusers_t2i_adapter.py \
|
||||
--from "tests/weights/TencentARC/t2iadapter_depth_sd15v2" \
|
||||
--to "tests/weights/T2I-Adapter/t2iadapter_depth_sd15v2.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/T2I-Adapter/t2iadapter_depth_sd15v2.safetensors" bb2b3115
|
||||
|
||||
python scripts/conversion/convert_diffusers_t2i_adapter.py \
|
||||
--from "tests/weights/TencentARC/t2i-adapter-canny-sdxl-1.0" \
|
||||
--to "tests/weights/T2I-Adapter/t2i-adapter-canny-sdxl-1.0.safetensors" \
|
||||
--half
|
||||
check_hash "tests/weights/T2I-Adapter/t2i-adapter-canny-sdxl-1.0.safetensors" f07249a6
|
||||
}
|
||||
|
||||
convert_sam () {
|
||||
python scripts/conversion/convert_segment_anything.py \
|
||||
--from "tests/weights/sam_vit_h_4b8939.pth" \
|
||||
--to "tests/weights/segment-anything-h.safetensors"
|
||||
check_hash "tests/weights/segment-anything-h.safetensors" 6b843800
|
||||
}
|
||||
|
||||
download_all () {
|
||||
download_sd15 runwayml/stable-diffusion-v1-5
|
||||
download_sd15 runwayml/stable-diffusion-inpainting
|
||||
download_sdxl stabilityai/stable-diffusion-xl-base-1.0
|
||||
download_vae_ft_mse
|
||||
download_lora
|
||||
download_preprocessors
|
||||
download_controlnet
|
||||
download_unclip
|
||||
download_ip_adapter
|
||||
download_t2i_adapter
|
||||
download_sam
|
||||
}
|
||||
|
||||
convert_all () {
|
||||
convert_sd15
|
||||
convert_sdxl
|
||||
convert_vae_ft_mse
|
||||
convert_lora
|
||||
convert_preprocessors
|
||||
convert_controlnet
|
||||
convert_unclip
|
||||
convert_ip_adapter
|
||||
convert_t2i_adapter
|
||||
convert_sam
|
||||
}
|
||||
|
||||
main () {
|
||||
git lfs install || die "could not install git lfs"
|
||||
rm -rf tests/weights
|
||||
download_all
|
||||
convert_all
|
||||
}
|
||||
|
||||
main
|
584
scripts/prepare_test_weights.py
Normal file
584
scripts/prepare_test_weights.py
Normal file
|
@ -0,0 +1,584 @@
|
|||
"""
|
||||
Download and convert weights for testing
|
||||
|
||||
To see what weights will be downloaded and converted, run:
|
||||
DRY_RUN=1 python scripts/prepare_test_weights.py
|
||||
"""
|
||||
import hashlib
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
# Set the base directory to the parent directory of the script
|
||||
project_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
test_weights_dir = os.path.join(project_dir, "tests", "weights")
|
||||
|
||||
previous_line = "\033[F"
|
||||
|
||||
download_count = 0
|
||||
bytes_count = 0
|
||||
|
||||
|
||||
def die(message: str) -> None:
|
||||
print(message, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def rel(path: str) -> str:
|
||||
return os.path.relpath(path, project_dir)
|
||||
|
||||
|
||||
def calc_hash(filepath: str) -> str:
|
||||
with open(filepath, "rb") as f:
|
||||
data = f.read()
|
||||
found = hashlib.blake2b(data, digest_size=int(32 / 8)).hexdigest()
|
||||
return found
|
||||
|
||||
|
||||
def check_hash(path: str, expected: str) -> str:
|
||||
found = calc_hash(path)
|
||||
if found != expected:
|
||||
die(f"❌ Invalid hash for {path} ({found} != {expected})")
|
||||
return found
|
||||
|
||||
|
||||
def download_file(
|
||||
url: str,
|
||||
dest_folder: str,
|
||||
dry_run: bool | None = None,
|
||||
skip_existing: bool = True,
|
||||
expected_hash: str | None = None,
|
||||
):
|
||||
"""
|
||||
Downloads a file
|
||||
|
||||
Features:
|
||||
- shows a progress bar
|
||||
- skips existing files
|
||||
- uses a temporary file to prevent partial downloads
|
||||
- can do a dry run to check the url is valid
|
||||
- displays the downloaded file hash
|
||||
|
||||
"""
|
||||
global download_count, bytes_count
|
||||
filename = os.path.basename(urlparse(url).path)
|
||||
dest_filename = os.path.join(dest_folder, filename)
|
||||
temp_filename = dest_filename + ".part"
|
||||
dry_run = bool(os.environ.get("DRY_RUN") == "1") if dry_run is None else dry_run
|
||||
|
||||
is_downloaded = os.path.exists(dest_filename)
|
||||
if is_downloaded and skip_existing:
|
||||
skip_icon = "✖️ "
|
||||
else:
|
||||
skip_icon = "🔽"
|
||||
|
||||
if dry_run:
|
||||
response = requests.head(url, allow_redirects=True)
|
||||
readable_size = ""
|
||||
|
||||
if response.status_code == 200:
|
||||
content_length = response.headers.get("content-length")
|
||||
|
||||
if content_length:
|
||||
size_in_bytes = int(content_length)
|
||||
readable_size = human_readable_size(size_in_bytes)
|
||||
download_count += 1
|
||||
bytes_count += size_in_bytes
|
||||
print(f"✅{skip_icon} {response.status_code} READY {readable_size:<8} {url}")
|
||||
|
||||
else:
|
||||
print(f"❌{skip_icon} {response.status_code} ERROR {readable_size:<8} {url}")
|
||||
return
|
||||
|
||||
if skip_existing and os.path.exists(dest_filename):
|
||||
print(f"{skip_icon}️ Skipping previously downloaded {url}")
|
||||
return
|
||||
|
||||
os.makedirs(dest_folder, exist_ok=True)
|
||||
|
||||
print(f"🔽 Downloading {url} => '{rel(dest_filename)}'", end="\n")
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code != 200:
|
||||
print(response.content[:1000])
|
||||
die(f"Failed to download {url}. Status code: {response.status_code}")
|
||||
total = int(response.headers.get("content-length", 0))
|
||||
bar = tqdm(
|
||||
desc=filename,
|
||||
total=total,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
leave=False,
|
||||
)
|
||||
with open(temp_filename, "wb") as f, bar:
|
||||
for data in response.iter_content(chunk_size=1024 * 1000):
|
||||
size = f.write(data)
|
||||
bar.update(size)
|
||||
|
||||
os.rename(temp_filename, dest_filename)
|
||||
calculated_hash = calc_hash(dest_filename)
|
||||
|
||||
print(f"{previous_line}✅ Downloaded {calculated_hash} {url} => '{rel(dest_filename)}' ")
|
||||
if expected_hash is not None:
|
||||
check_hash(dest_filename, expected_hash)
|
||||
|
||||
|
||||
def download_files(urls: list[str], dest_folder: str):
|
||||
for url in urls:
|
||||
download_file(url, dest_folder)
|
||||
|
||||
|
||||
def human_readable_size(size: int | float, decimal_places: int = 2) -> str:
|
||||
for unit in ["B", "KB", "MB", "GB", "TB", "PB"]:
|
||||
if size < 1024.0:
|
||||
break
|
||||
size /= 1024.0
|
||||
return f"{size:.{decimal_places}f}{unit}" # type: ignore
|
||||
|
||||
|
||||
def download_sd_text_encoder(hf_repo_id: str = "runwayml/stable-diffusion-v1-5", subdir: str = "text_encoder"):
|
||||
encoder_filename = "model.safetensors" if "inpainting" not in hf_repo_id else "model.fp16.safetensors"
|
||||
base_url = f"https://huggingface.co/{hf_repo_id}"
|
||||
download_files(
|
||||
urls=[
|
||||
f"{base_url}/raw/main/{subdir}/config.json",
|
||||
f"{base_url}/resolve/main/{subdir}/{encoder_filename}",
|
||||
],
|
||||
dest_folder=os.path.join(test_weights_dir, hf_repo_id, subdir),
|
||||
)
|
||||
|
||||
|
||||
def download_sd_tokenizer(hf_repo_id: str = "runwayml/stable-diffusion-v1-5", subdir: str = "tokenizer"):
|
||||
download_files(
|
||||
urls=[
|
||||
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/merges.txt",
|
||||
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/special_tokens_map.json",
|
||||
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/tokenizer_config.json",
|
||||
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/vocab.json",
|
||||
],
|
||||
dest_folder=os.path.join(test_weights_dir, hf_repo_id, subdir),
|
||||
)
|
||||
|
||||
|
||||
def download_sd_base(hf_repo_id: str = "runwayml/stable-diffusion-v1-5"):
|
||||
is_inpainting = "inpainting" in hf_repo_id
|
||||
ext = "safetensors" if not is_inpainting else "bin"
|
||||
base_folder = os.path.join(test_weights_dir, hf_repo_id)
|
||||
download_file(f"https://huggingface.co/{hf_repo_id}/raw/main/model_index.json", base_folder)
|
||||
download_file(
|
||||
f"https://huggingface.co/{hf_repo_id}/raw/main/scheduler/scheduler_config.json",
|
||||
os.path.join(base_folder, "scheduler"),
|
||||
)
|
||||
|
||||
for subdir in ["unet", "vae"]:
|
||||
subdir_folder = os.path.join(base_folder, subdir)
|
||||
download_file(f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/config.json", subdir_folder)
|
||||
download_file(
|
||||
f"https://huggingface.co/{hf_repo_id}/resolve/main/{subdir}/diffusion_pytorch_model.{ext}", subdir_folder
|
||||
)
|
||||
# we only need the unet for the inpainting model
|
||||
if not is_inpainting:
|
||||
download_sd_text_encoder(hf_repo_id, "text_encoder")
|
||||
download_sd_tokenizer(hf_repo_id, "tokenizer")
|
||||
|
||||
|
||||
def download_sd15(hf_repo_id: str = "runwayml/stable-diffusion-v1-5"):
|
||||
download_sd_base(hf_repo_id)
|
||||
base_folder = os.path.join(test_weights_dir, hf_repo_id)
|
||||
|
||||
subdir = "feature_extractor"
|
||||
download_file(
|
||||
f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/preprocessor_config.json",
|
||||
os.path.join(base_folder, subdir),
|
||||
)
|
||||
|
||||
if "inpainting" not in hf_repo_id:
|
||||
subdir = "safety_checker"
|
||||
subdir_folder = os.path.join(base_folder, subdir)
|
||||
download_file(f"https://huggingface.co/{hf_repo_id}/raw/main/{subdir}/config.json", subdir_folder)
|
||||
download_file(f"https://huggingface.co/{hf_repo_id}/resolve/main/{subdir}/model.safetensors", subdir_folder)
|
||||
|
||||
|
||||
def download_sdxl(hf_repo_id: str = "stabilityai/stable-diffusion-xl-base-1.0"):
|
||||
download_sd_base(hf_repo_id)
|
||||
download_sd_text_encoder(hf_repo_id, "text_encoder_2")
|
||||
download_sd_tokenizer(hf_repo_id, "tokenizer_2")
|
||||
|
||||
|
||||
def download_vae_ft_mse():
|
||||
download_files(
|
||||
urls=[
|
||||
"https://huggingface.co/stabilityai/sd-vae-ft-mse/raw/main/config.json",
|
||||
"https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.safetensors",
|
||||
],
|
||||
dest_folder=os.path.join(test_weights_dir, "stabilityai", "sd-vae-ft-mse"),
|
||||
)
|
||||
|
||||
|
||||
def download_lora():
|
||||
dest_folder = os.path.join(test_weights_dir, "pcuenq", "pokemon-lora")
|
||||
download_file("https://huggingface.co/pcuenq/pokemon-lora/resolve/main/pytorch_lora_weights.bin", dest_folder)
|
||||
|
||||
|
||||
def download_preprocessors():
|
||||
dest_folder = os.path.join(test_weights_dir, "carolineec", "informativedrawings")
|
||||
download_file("https://huggingface.co/spaces/carolineec/informativedrawings/resolve/main/model2.pth", dest_folder)
|
||||
|
||||
|
||||
def download_controlnet():
|
||||
base_folder = os.path.join(test_weights_dir, "lllyasviel")
|
||||
controlnets = [
|
||||
"control_v11p_sd15_canny",
|
||||
"control_v11f1p_sd15_depth",
|
||||
"control_v11p_sd15_normalbae",
|
||||
"control_v11p_sd15_lineart",
|
||||
]
|
||||
for net in controlnets:
|
||||
net_folder = os.path.join(base_folder, net)
|
||||
urls = [
|
||||
f"https://huggingface.co/lllyasviel/{net}/raw/main/config.json",
|
||||
f"https://huggingface.co/lllyasviel/{net}/resolve/main/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
download_files(urls, net_folder)
|
||||
|
||||
mfidabel_folder = os.path.join(test_weights_dir, "mfidabel", "controlnet-segment-anything")
|
||||
urls = [
|
||||
"https://huggingface.co/mfidabel/controlnet-segment-anything/raw/main/config.json",
|
||||
"https://huggingface.co/mfidabel/controlnet-segment-anything/resolve/main/diffusion_pytorch_model.bin",
|
||||
]
|
||||
download_files(urls, mfidabel_folder)
|
||||
|
||||
|
||||
def download_unclip():
|
||||
base_folder = os.path.join(test_weights_dir, "stabilityai", "stable-diffusion-2-1-unclip")
|
||||
download_file(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/raw/main/model_index.json", base_folder
|
||||
)
|
||||
image_encoder_folder = os.path.join(base_folder, "image_encoder")
|
||||
urls = [
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/raw/main/image_encoder/config.json",
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-2-1-unclip/resolve/main/image_encoder/model.safetensors",
|
||||
]
|
||||
download_files(urls, image_encoder_folder)
|
||||
|
||||
|
||||
def download_ip_adapter():
|
||||
base_folder = os.path.join(test_weights_dir, "h94", "IP-Adapter")
|
||||
models_folder = os.path.join(base_folder, "models")
|
||||
urls = [
|
||||
"https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter_sd15.bin",
|
||||
"https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus_sd15.bin",
|
||||
]
|
||||
download_files(urls, models_folder)
|
||||
|
||||
sdxl_models_folder = os.path.join(base_folder, "sdxl_models")
|
||||
urls = [
|
||||
"https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter_sdxl_vit-h.bin",
|
||||
"https://huggingface.co/h94/IP-Adapter/resolve/main/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin",
|
||||
]
|
||||
download_files(urls, sdxl_models_folder)
|
||||
|
||||
|
||||
def download_t2i_adapter():
|
||||
base_folder = os.path.join(test_weights_dir, "TencentARC", "t2iadapter_depth_sd15v2")
|
||||
urls = [
|
||||
"https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2/raw/main/config.json",
|
||||
"https://huggingface.co/TencentARC/t2iadapter_depth_sd15v2/resolve/main/diffusion_pytorch_model.bin",
|
||||
]
|
||||
download_files(urls, base_folder)
|
||||
|
||||
canny_sdxl_folder = os.path.join(test_weights_dir, "TencentARC", "t2i-adapter-canny-sdxl-1.0")
|
||||
urls = [
|
||||
"https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0/raw/main/config.json",
|
||||
"https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0/resolve/main/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
download_files(urls, canny_sdxl_folder)
|
||||
|
||||
|
||||
def download_sam():
|
||||
weights_folder = os.path.join(test_weights_dir)
|
||||
download_file(
|
||||
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", weights_folder, expected_hash="06785e66"
|
||||
)
|
||||
|
||||
|
||||
def printg(msg: str):
|
||||
"""print in green color"""
|
||||
print("\033[92m" + msg + "\033[0m")
|
||||
|
||||
|
||||
def run_conversion_script(
|
||||
script_filename: str,
|
||||
from_weights: str,
|
||||
to_weights: str,
|
||||
half: bool = False,
|
||||
expected_hash: str | None = None,
|
||||
additional_args: list[str] | None = None,
|
||||
skip_existing: bool = True,
|
||||
):
|
||||
if skip_existing and expected_hash and os.path.exists(to_weights):
|
||||
found_hash = check_hash(to_weights, expected_hash)
|
||||
if expected_hash == found_hash:
|
||||
printg(f"✖️ Skipping converted from {from_weights} to {to_weights} (hash {found_hash} confirmed) ")
|
||||
return
|
||||
|
||||
msg = f"Converting {from_weights} to {to_weights}"
|
||||
printg(msg)
|
||||
|
||||
args = ["python", f"scripts/conversion/{script_filename}", "--from", from_weights, "--to", to_weights]
|
||||
if half:
|
||||
args.append("--half")
|
||||
if additional_args:
|
||||
args.extend(additional_args)
|
||||
|
||||
subprocess.run(args, check=True)
|
||||
if expected_hash is not None:
|
||||
found_hash = check_hash(to_weights, expected_hash)
|
||||
printg(f"✅ Converted from {from_weights} to {to_weights} (hash {found_hash} confirmed) ")
|
||||
else:
|
||||
printg(f"✅⚠️ Converted from {from_weights} to {to_weights} (no hash check performed)")
|
||||
|
||||
|
||||
def convert_sd15():
|
||||
run_conversion_script(
|
||||
script_filename="convert_transformers_clip_text_model.py",
|
||||
from_weights="tests/weights/runwayml/stable-diffusion-v1-5",
|
||||
to_weights="tests/weights/CLIPTextEncoderL.safetensors",
|
||||
half=True,
|
||||
expected_hash="6c9cbc59",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_autoencoder_kl.py",
|
||||
"tests/weights/runwayml/stable-diffusion-v1-5",
|
||||
"tests/weights/lda.safetensors",
|
||||
expected_hash="329e369c",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_unet.py",
|
||||
"tests/weights/runwayml/stable-diffusion-v1-5",
|
||||
"tests/weights/unet.safetensors",
|
||||
half=True,
|
||||
expected_hash="f81ac65a",
|
||||
)
|
||||
os.makedirs("tests/weights/inpainting", exist_ok=True)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_unet.py",
|
||||
"tests/weights/runwayml/stable-diffusion-inpainting",
|
||||
"tests/weights/inpainting/unet.safetensors",
|
||||
half=True,
|
||||
expected_hash="c07a8c61",
|
||||
)
|
||||
|
||||
|
||||
def convert_sdxl():
|
||||
run_conversion_script(
|
||||
"convert_transformers_clip_text_model.py",
|
||||
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"tests/weights/DoubleCLIPTextEncoder.safetensors",
|
||||
half=True,
|
||||
expected_hash="7f99c30b",
|
||||
additional_args=["--subfolder2", "text_encoder_2"],
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_autoencoder_kl.py",
|
||||
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"tests/weights/sdxl-lda.safetensors",
|
||||
half=True,
|
||||
expected_hash="7464e9dc",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_unet.py",
|
||||
"tests/weights/stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"tests/weights/sdxl-unet.safetensors",
|
||||
half=True,
|
||||
expected_hash="2e5c4911",
|
||||
)
|
||||
|
||||
|
||||
def convert_vae_ft_mse():
|
||||
run_conversion_script(
|
||||
"convert_diffusers_autoencoder_kl.py",
|
||||
"tests/weights/stabilityai/sd-vae-ft-mse",
|
||||
"tests/weights/lda_ft_mse.safetensors",
|
||||
half=True,
|
||||
expected_hash="4d0bae7e",
|
||||
)
|
||||
|
||||
|
||||
def convert_lora():
|
||||
os.makedirs("tests/weights/loras", exist_ok=True)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_lora.py",
|
||||
"tests/weights/pcuenq/pokemon-lora/pytorch_lora_weights.bin",
|
||||
"tests/weights/loras/pcuenq_pokemon_lora.safetensors",
|
||||
additional_args=["--base-model", "tests/weights/runwayml/stable-diffusion-v1-5"],
|
||||
expected_hash="a9d7e08e",
|
||||
)
|
||||
|
||||
|
||||
def convert_preprocessors():
|
||||
subprocess.run(
|
||||
[
|
||||
"curl",
|
||||
"-L",
|
||||
"https://raw.githubusercontent.com/carolineec/informative-drawings/main/model.py",
|
||||
"-o",
|
||||
"src/model.py",
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_informative_drawings.py",
|
||||
"tests/weights/carolineec/informativedrawings/model2.pth",
|
||||
"tests/weights/informative-drawings.safetensors",
|
||||
expected_hash="93dca207",
|
||||
)
|
||||
os.remove("src/model.py")
|
||||
|
||||
|
||||
def convert_controlnet():
|
||||
os.makedirs("tests/weights/controlnet", exist_ok=True)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_controlnet.py",
|
||||
"tests/weights/lllyasviel/control_v11p_sd15_canny",
|
||||
"tests/weights/controlnet/lllyasviel_control_v11p_sd15_canny.safetensors",
|
||||
expected_hash="9a1a48cf",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_controlnet.py",
|
||||
"tests/weights/lllyasviel/control_v11f1p_sd15_depth",
|
||||
"tests/weights/controlnet/lllyasviel_control_v11f1p_sd15_depth.safetensors",
|
||||
expected_hash="bbe7e5a6",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_controlnet.py",
|
||||
"tests/weights/lllyasviel/control_v11p_sd15_normalbae",
|
||||
"tests/weights/controlnet/lllyasviel_control_v11p_sd15_normalbae.safetensors",
|
||||
expected_hash="9fa88ed5",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_controlnet.py",
|
||||
"tests/weights/lllyasviel/control_v11p_sd15_lineart",
|
||||
"tests/weights/controlnet/lllyasviel_control_v11p_sd15_lineart.safetensors",
|
||||
expected_hash="c29e8c03",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_controlnet.py",
|
||||
"tests/weights/mfidabel/controlnet-segment-anything",
|
||||
"tests/weights/controlnet/mfidabel_controlnet-segment-anything.safetensors",
|
||||
expected_hash="d536eebb",
|
||||
)
|
||||
|
||||
|
||||
def convert_unclip():
|
||||
run_conversion_script(
|
||||
"convert_transformers_clip_image_model.py",
|
||||
"tests/weights/stabilityai/stable-diffusion-2-1-unclip",
|
||||
"tests/weights/CLIPImageEncoderH.safetensors",
|
||||
half=True,
|
||||
expected_hash="4ddb44d2",
|
||||
)
|
||||
|
||||
|
||||
def convert_ip_adapter():
|
||||
run_conversion_script(
|
||||
"convert_diffusers_ip_adapter.py",
|
||||
"tests/weights/h94/IP-Adapter/models/ip-adapter_sd15.bin",
|
||||
"tests/weights/ip-adapter_sd15.safetensors",
|
||||
expected_hash="9579b465",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_ip_adapter.py",
|
||||
"tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter_sdxl_vit-h.bin",
|
||||
"tests/weights/ip-adapter_sdxl_vit-h.safetensors",
|
||||
half=True,
|
||||
expected_hash="739504c6",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_ip_adapter.py",
|
||||
"tests/weights/h94/IP-Adapter/models/ip-adapter-plus_sd15.bin",
|
||||
"tests/weights/ip-adapter-plus_sd15.safetensors",
|
||||
half=True,
|
||||
expected_hash="842b20e2",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_ip_adapter.py",
|
||||
"tests/weights/h94/IP-Adapter/sdxl_models/ip-adapter-plus_sdxl_vit-h.bin",
|
||||
"tests/weights/ip-adapter-plus_sdxl_vit-h.safetensors",
|
||||
half=True,
|
||||
expected_hash="0409974b",
|
||||
)
|
||||
|
||||
|
||||
def convert_t2i_adapter():
|
||||
os.makedirs("tests/weights/T2I-Adapter", exist_ok=True)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_t2i_adapter.py",
|
||||
"tests/weights/TencentARC/t2iadapter_depth_sd15v2",
|
||||
"tests/weights/T2I-Adapter/t2iadapter_depth_sd15v2.safetensors",
|
||||
half=True,
|
||||
expected_hash="bb2b3115",
|
||||
)
|
||||
run_conversion_script(
|
||||
"convert_diffusers_t2i_adapter.py",
|
||||
"tests/weights/TencentARC/t2i-adapter-canny-sdxl-1.0",
|
||||
"tests/weights/T2I-Adapter/t2i-adapter-canny-sdxl-1.0.safetensors",
|
||||
half=True,
|
||||
expected_hash="f07249a6",
|
||||
)
|
||||
|
||||
|
||||
def convert_sam():
|
||||
run_conversion_script(
|
||||
"convert_segment_anything.py",
|
||||
"tests/weights/sam_vit_h_4b8939.pth",
|
||||
"tests/weights/segment-anything-h.safetensors",
|
||||
expected_hash="6b843800",
|
||||
)
|
||||
|
||||
|
||||
def download_all():
|
||||
print(f"\nAll weights will be downloaded to {test_weights_dir}\n")
|
||||
download_sd15("runwayml/stable-diffusion-v1-5")
|
||||
download_sd15("runwayml/stable-diffusion-inpainting")
|
||||
download_sdxl("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
download_vae_ft_mse()
|
||||
download_lora()
|
||||
download_preprocessors()
|
||||
download_controlnet()
|
||||
download_unclip()
|
||||
download_ip_adapter()
|
||||
download_t2i_adapter()
|
||||
download_sam()
|
||||
|
||||
|
||||
def convert_all():
|
||||
convert_sd15()
|
||||
convert_sdxl()
|
||||
convert_vae_ft_mse()
|
||||
convert_lora()
|
||||
convert_preprocessors()
|
||||
convert_controlnet()
|
||||
convert_unclip()
|
||||
convert_ip_adapter()
|
||||
convert_t2i_adapter()
|
||||
convert_sam()
|
||||
|
||||
|
||||
def main():
|
||||
try:
|
||||
download_all()
|
||||
print(f"{download_count} files ({human_readable_size(bytes_count)})\n")
|
||||
if not bool(os.environ.get("DRY_RUN") == "1"):
|
||||
printg("Converting weights to refiners format\n")
|
||||
convert_all()
|
||||
except KeyboardInterrupt:
|
||||
print("Stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in a new issue