mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-14 00:58:13 +00:00
58 lines
2 KiB
Python
58 lines
2 KiB
Python
# Original weights can be found here: https://huggingface.co/spaces/carolineec/informativedrawings
# Code is at https://github.com/carolineec/informative-drawings
# Copy `model.py` into your `PYTHONPATH`. You can edit it to remove unnecessary code
# and imports if you want; we only need `Generator`.
|
|
|
|
import torch
|
|
|
|
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
|
|
|
|
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
|
from model import Generator # type: ignore
|
|
|
|
|
|
@torch.no_grad()
def convert(
    checkpoint: str,
    device: torch.device | str,
    dtype: torch.dtype = torch.float16,
) -> dict[str, torch.Tensor]:
    """Convert an Informative Drawings `Generator` checkpoint into `InformativeDrawings` weights.

    The original `Generator(3, 1, 3)` is restored from `checkpoint`. A key mapping to refiners'
    `InformativeDrawings` is then built with `create_state_dict_mapping`, using a random
    1x3x512x512 input. The source weights are re-keyed with `convert_state_dict`.

    Args:
        checkpoint: Path to the original PyTorch checkpoint (e.g. `model2.pth`).
        device: Forwarded to `torch.load` as `map_location`. Neither model is moved to this
            device here; `load_state_dict` copies the loaded values into the modules' existing
            parameters.
        dtype: dtype every converted tensor is cast to. Defaults to `torch.float16`, which
            matches the historical `.half()` behavior.

    Returns:
        The converted state dict, keyed by `InformativeDrawings` parameter names.

    Raises:
        RuntimeError: If no state dict mapping between the two models could be created.
    """
    src_model = Generator(3, 1, 3)  # type: ignore
    # NOTE(security): `torch.load` unpickles the file and can execute arbitrary code;
    # only convert checkpoints obtained from a trusted source.
    src_model.load_state_dict(torch.load(checkpoint, map_location=device))  # type: ignore
    # Inference mode (e.g. disables dropout) so the forward pass used for mapping matches runtime.
    src_model.eval()  # type: ignore

    dst_model = InformativeDrawings()

    # Dummy input handed to `create_state_dict_mapping` as the source model's forward args.
    # Presumably only its shape matters, not its values; TODO confirm.
    x = torch.randn(1, 3, 512, 512)

    mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x])  # type: ignore
    # Explicit check instead of `assert`: assertions are stripped when Python runs with `-O`,
    # which would let a `None` mapping reach `convert_state_dict`.
    if mapping is None:
        raise RuntimeError("Model conversion failed")

    state_dict = convert_state_dict(source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping)  # type: ignore
    return {k: v.to(dtype=dtype) for k, v in state_dict.items()}
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: convert the `--from` checkpoint and save it to `--output-file`."""
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--from",
        dest="source",
        type=str,
        default="model2.pth",
        required=False,
        help="Source model",
    )
    cli.add_argument(
        "--output-file",
        type=str,
        default="informative-drawings.safetensors",
        required=False,
        help="Path for the output file",
    )
    options = cli.parse_args()

    # Use the GPU when one is visible; `convert` forwards this to `torch.load` as `map_location`.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    save_to_safetensors(
        path=options.output_file,
        tensors=convert(checkpoint=options.source, device=device),
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the conversion only when executed as a script, never on import.
    main()
|