From 9da00e6fcf43a39dcb366c23fc30ab98ef161e7c Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Thu, 17 Aug 2023 12:32:08 +0200 Subject: [PATCH] fix typing for informative drawings convert script --- .../convert-informative-drawings-weights.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/scripts/convert-informative-drawings-weights.py b/scripts/convert-informative-drawings-weights.py index adac5f0..49034df 100644 --- a/scripts/convert-informative-drawings-weights.py +++ b/scripts/convert-informative-drawings-weights.py @@ -5,32 +5,29 @@ import torch -from safetensors.torch import save_file -from refiners.fluxion.utils import ( - create_state_dict_mapping, - convert_state_dict, -) +from refiners.fluxion.utils import create_state_dict_mapping, convert_state_dict, save_to_safetensors from refiners.foundationals.latent_diffusion.preprocessors.informative_drawings import InformativeDrawings -from model import Generator +from model import Generator # type: ignore @torch.no_grad() -def convert(checkpoint: str, device: torch.device) -> dict[str, torch.Tensor]: - src_model = Generator(3, 1, 3) - src_model.load_state_dict(torch.load(checkpoint, map_location=device)) - src_model.eval() +def convert(checkpoint: str, device: torch.device | str) -> dict[str, torch.Tensor]: + src_model = Generator(3, 1, 3) # type: ignore + src_model.load_state_dict(torch.load(checkpoint, map_location=device)) # type: ignore + src_model.eval() # type: ignore dst_model = InformativeDrawings() x = torch.randn(1, 3, 512, 512) - mapping = create_state_dict_mapping(src_model, dst_model, [x]) - state_dict = convert_state_dict(src_model.state_dict(), dst_model.state_dict(), mapping) + mapping = create_state_dict_mapping(source_model=src_model, target_model=dst_model, source_args=[x]) # type: ignore + assert mapping is not None, "Model conversion failed" + state_dict = convert_state_dict(source_state_dict=src_model.state_dict(), target_state_dict=dst_model.state_dict(), state_dict_mapping=mapping) # type: ignore return {k: v.half() for k, v in state_dict.items()} -def main(): +def main() -> None: import argparse parser = argparse.ArgumentParser() @@ -52,8 +49,8 @@ def main(): args = parser.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" - tensors = convert(args.source, device) - save_file(tensors, args.output_file) + tensors = convert(checkpoint=args.source, device=device) + save_to_safetensors(path=args.output_file, tensors=tensors) if __name__ == "__main__":