diff --git a/src/comfyui-refiners/grounding_dino.py b/src/comfyui-refiners/grounding_dino.py index 4675872..e325886 100644 --- a/src/comfyui-refiners/grounding_dino.py +++ b/src/comfyui-refiners/grounding_dino.py @@ -3,9 +3,9 @@ from typing import Any, Sequence import torch from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor # type: ignore -from refiners.fluxion.utils import no_grad, tensor_to_image +from refiners.fluxion.utils import no_grad, str_to_dtype, tensor_to_image -from .utils import BoundingBox, get_dtype +from .utils import BoundingBox class LoadGroundingDino: @@ -54,7 +54,7 @@ class LoadGroundingDino: processor = GroundingDinoProcessor.from_pretrained(checkpoint) # type: ignore assert isinstance(processor, GroundingDinoProcessor) - model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=get_dtype(dtype)) # type: ignore + model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=str_to_dtype(dtype)) # type: ignore model = model.to(device=device) # type: ignore assert isinstance(model, GroundingDinoForObjectDetection) diff --git a/src/comfyui-refiners/utils.py b/src/comfyui-refiners/utils.py index 9595376..e1e4613 100644 --- a/src/comfyui-refiners/utils.py +++ b/src/comfyui-refiners/utils.py @@ -48,39 +48,6 @@ class DrawBoundingBox: return (image,) -def get_dtype(dtype: str) -> torch.dtype: - """Converts a string dtype to a torch.dtype. - - See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype""" - match dtype: - case "float32" | "float": - return torch.float32 - case "float64" | "double": - return torch.float64 - case "complex64" | "cfloat": - return torch.complex64 - case "complex128" | "cdouble": - return torch.complex128 - case "float16" | "half": - return torch.float16 - case "bfloat16": - return torch.bfloat16 - case "uint8": - return torch.uint8 - case "int8": - return torch.int8 - case "int16" | "short": - return torch.int16 - case "int32" | "int": - return torch.int32 - case "int64" | "long": - return torch.int64 - case "bool": - return torch.bool - case _: - raise ValueError(f"Unknown dtype: {dtype}") - - NODE_CLASS_MAPPINGS: dict[str, Any] = { "DrawBoundingBox": DrawBoundingBox, } diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index f80589d..5de2f99 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -270,3 +270,37 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str: ) return "Tensor(" + ", ".join(info_list) + ")" + + +def str_to_dtype(dtype: str) -> torch.dtype: + """Converts a string dtype to a torch.dtype. + + See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype + """ + match dtype.lower(): + case "float32" | "float": + return torch.float32 + case "float64" | "double": + return torch.float64 + case "complex64" | "cfloat": + return torch.complex64 + case "complex128" | "cdouble": + return torch.complex128 + case "float16" | "half": + return torch.float16 + case "bfloat16": + return torch.bfloat16 + case "uint8": + return torch.uint8 + case "int8": + return torch.int8 + case "int16" | "short": + return torch.int16 + case "int32" | "int": + return torch.int32 + case "int64" | "long": + return torch.int64 + case "bool": + return torch.bool + case _: + raise ValueError(f"Unknown dtype: {dtype}")