gradute comfyui-refiners get_dtype to refiners

This commit is contained in:
Laurent 2024-10-03 08:46:23 +00:00 committed by Laureηt
parent 4360aa046f
commit 16714e6745
3 changed files with 37 additions and 36 deletions

View file

@ -3,9 +3,9 @@ from typing import Any, Sequence
import torch
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor # type: ignore
from refiners.fluxion.utils import no_grad, tensor_to_image
from refiners.fluxion.utils import no_grad, str_to_dtype, tensor_to_image
from .utils import BoundingBox, get_dtype
from .utils import BoundingBox
class LoadGroundingDino:
@ -54,7 +54,7 @@ class LoadGroundingDino:
processor = GroundingDinoProcessor.from_pretrained(checkpoint) # type: ignore
assert isinstance(processor, GroundingDinoProcessor)
model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=get_dtype(dtype)) # type: ignore
model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=str_to_dtype(dtype)) # type: ignore
model = model.to(device=device) # type: ignore
assert isinstance(model, GroundingDinoForObjectDetection)

View file

@ -48,39 +48,6 @@ class DrawBoundingBox:
return (image,)
def get_dtype(dtype: str) -> torch.dtype:
"""Converts a string dtype to a torch.dtype.
See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype"""
match dtype:
case "float32" | "float":
return torch.float32
case "float64" | "double":
return torch.float64
case "complex64" | "cfloat":
return torch.complex64
case "complex128" | "cdouble":
return torch.complex128
case "float16" | "half":
return torch.float16
case "bfloat16":
return torch.bfloat16
case "uint8":
return torch.uint8
case "int8":
return torch.int8
case "int16" | "short":
return torch.int16
case "int32" | "int":
return torch.int32
case "int64" | "long":
return torch.int64
case "bool":
return torch.bool
case _:
raise ValueError(f"Unknown dtype: {dtype}")
NODE_CLASS_MAPPINGS: dict[str, Any] = {
"DrawBoundingBox": DrawBoundingBox,
}

View file

@ -270,3 +270,37 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
)
return "Tensor(" + ", ".join(info_list) + ")"
def str_to_dtype(dtype: str) -> torch.dtype:
"""Converts a string dtype to a torch.dtype.
See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype
"""
match dtype.lower():
case "float32" | "float":
return torch.float32
case "float64" | "double":
return torch.float64
case "complex64" | "cfloat":
return torch.complex64
case "complex128" | "cdouble":
return torch.complex128
case "float16" | "half":
return torch.float16
case "bfloat16":
return torch.bfloat16
case "uint8":
return torch.uint8
case "int8":
return torch.int8
case "int16" | "short":
return torch.int16
case "int32" | "int":
return torch.int32
case "int64" | "long":
return torch.int64
case "bool":
return torch.bool
case _:
raise ValueError(f"Unknown dtype: {dtype}")