mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-23 06:38:45 +00:00
gradute comfyui-refiners get_dtype to refiners
This commit is contained in:
parent
4360aa046f
commit
16714e6745
|
@ -3,9 +3,9 @@ from typing import Any, Sequence
|
|||
import torch
|
||||
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor # type: ignore
|
||||
|
||||
from refiners.fluxion.utils import no_grad, tensor_to_image
|
||||
from refiners.fluxion.utils import no_grad, str_to_dtype, tensor_to_image
|
||||
|
||||
from .utils import BoundingBox, get_dtype
|
||||
from .utils import BoundingBox
|
||||
|
||||
|
||||
class LoadGroundingDino:
|
||||
|
@ -54,7 +54,7 @@ class LoadGroundingDino:
|
|||
processor = GroundingDinoProcessor.from_pretrained(checkpoint) # type: ignore
|
||||
assert isinstance(processor, GroundingDinoProcessor)
|
||||
|
||||
model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=get_dtype(dtype)) # type: ignore
|
||||
model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=str_to_dtype(dtype)) # type: ignore
|
||||
model = model.to(device=device) # type: ignore
|
||||
assert isinstance(model, GroundingDinoForObjectDetection)
|
||||
|
||||
|
|
|
@ -48,39 +48,6 @@ class DrawBoundingBox:
|
|||
return (image,)
|
||||
|
||||
|
||||
def get_dtype(dtype: str) -> torch.dtype:
|
||||
"""Converts a string dtype to a torch.dtype.
|
||||
|
||||
See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype"""
|
||||
match dtype:
|
||||
case "float32" | "float":
|
||||
return torch.float32
|
||||
case "float64" | "double":
|
||||
return torch.float64
|
||||
case "complex64" | "cfloat":
|
||||
return torch.complex64
|
||||
case "complex128" | "cdouble":
|
||||
return torch.complex128
|
||||
case "float16" | "half":
|
||||
return torch.float16
|
||||
case "bfloat16":
|
||||
return torch.bfloat16
|
||||
case "uint8":
|
||||
return torch.uint8
|
||||
case "int8":
|
||||
return torch.int8
|
||||
case "int16" | "short":
|
||||
return torch.int16
|
||||
case "int32" | "int":
|
||||
return torch.int32
|
||||
case "int64" | "long":
|
||||
return torch.int64
|
||||
case "bool":
|
||||
return torch.bool
|
||||
case _:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS: dict[str, Any] = {
|
||||
"DrawBoundingBox": DrawBoundingBox,
|
||||
}
|
||||
|
|
|
@ -270,3 +270,37 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
|
|||
)
|
||||
|
||||
return "Tensor(" + ", ".join(info_list) + ")"
|
||||
|
||||
|
||||
def str_to_dtype(dtype: str) -> torch.dtype:
|
||||
"""Converts a string dtype to a torch.dtype.
|
||||
|
||||
See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype
|
||||
"""
|
||||
match dtype.lower():
|
||||
case "float32" | "float":
|
||||
return torch.float32
|
||||
case "float64" | "double":
|
||||
return torch.float64
|
||||
case "complex64" | "cfloat":
|
||||
return torch.complex64
|
||||
case "complex128" | "cdouble":
|
||||
return torch.complex128
|
||||
case "float16" | "half":
|
||||
return torch.float16
|
||||
case "bfloat16":
|
||||
return torch.bfloat16
|
||||
case "uint8":
|
||||
return torch.uint8
|
||||
case "int8":
|
||||
return torch.int8
|
||||
case "int16" | "short":
|
||||
return torch.int16
|
||||
case "int32" | "int":
|
||||
return torch.int32
|
||||
case "int64" | "long":
|
||||
return torch.int64
|
||||
case "bool":
|
||||
return torch.bool
|
||||
case _:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
|
|
Loading…
Reference in a new issue