mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 21:58:47 +00:00
fix typing for informative drawings convert script
This commit is contained in:
parent
0fd46f9ec4
commit
9da00e6fcf
|
@ -5,32 +5,29 @@
|
|||
|
||||
import torch
|
||||
|
||||
from safetensors.torch import save_file
|
||||
from refiners.fluxion.utils import (
|
||||
create_state_dict_mapping,
|
||||
convert_state_dict,
|
||||
)
|
||||
from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors
|
||||
|
||||
from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings
|
||||
from model import Generator
|
||||
from model import Generator # type: ignore
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert(checkpoint: str, device: torch.device) -> dict[str, torch.Tensor]:
|
||||
src_model = Generator(3, 1, 3)
|
||||
src_model.load_state_dict(torch.load(checkpoint, map_location=device))
|
||||
src_model.eval()
|
||||
def convert(checkpoint: str, device: torch.device | str) -> dict[str, torch.Tensor]:
|
||||
src_model = Generator(3, 1, 3) # type: ignore
|
||||
src_model.load_state_dict(torch.load(checkpoint, map_location=device)) # type: ignore
|
||||
src_model.eval() # type: ignore
|
||||
|
||||
dst_model = InformativeDrawings()
|
||||
|
||||
x = torch.randn(1, 3, 512, 512)
|
||||
|
||||
mapping = create_state_dict_mapping(src_model, dst_model, [x])
|
||||
state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping)
|
||||
mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore
|
||||
assert mapping is not None, "Model conversion failed"
|
||||
state_dict = convert_state_dict(source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping) # type: ignore
|
||||
return {k: v.half() for k, v in state_dict.items()}
|
||||
|
||||
|
||||
def main():
|
||||
def main() -> None:
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -52,8 +49,8 @@ def main():
|
|||
args = parser.parse_args()
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
tensors = convert(args.source, device)
|
||||
save_file(tensors, args.output_file)
|
||||
tensors = convert(checkpoint=args.source, device=device)
|
||||
save_to_safetensors(path=args.output_file, tensors=tensors)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in a new issue