mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-15 01:28:14 +00:00
gradute comfyui-refiners get_dtype to refiners
This commit is contained in:
parent
4360aa046f
commit
16714e6745
|
@ -3,9 +3,9 @@ from typing import Any, Sequence
|
||||||
import torch
|
import torch
|
||||||
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor # type: ignore
|
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor # type: ignore
|
||||||
|
|
||||||
from refiners.fluxion.utils import no_grad, tensor_to_image
|
from refiners.fluxion.utils import no_grad, str_to_dtype, tensor_to_image
|
||||||
|
|
||||||
from .utils import BoundingBox, get_dtype
|
from .utils import BoundingBox
|
||||||
|
|
||||||
|
|
||||||
class LoadGroundingDino:
|
class LoadGroundingDino:
|
||||||
|
@ -54,7 +54,7 @@ class LoadGroundingDino:
|
||||||
processor = GroundingDinoProcessor.from_pretrained(checkpoint) # type: ignore
|
processor = GroundingDinoProcessor.from_pretrained(checkpoint) # type: ignore
|
||||||
assert isinstance(processor, GroundingDinoProcessor)
|
assert isinstance(processor, GroundingDinoProcessor)
|
||||||
|
|
||||||
model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=get_dtype(dtype)) # type: ignore
|
model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=str_to_dtype(dtype)) # type: ignore
|
||||||
model = model.to(device=device) # type: ignore
|
model = model.to(device=device) # type: ignore
|
||||||
assert isinstance(model, GroundingDinoForObjectDetection)
|
assert isinstance(model, GroundingDinoForObjectDetection)
|
||||||
|
|
||||||
|
|
|
@ -48,39 +48,6 @@ class DrawBoundingBox:
|
||||||
return (image,)
|
return (image,)
|
||||||
|
|
||||||
|
|
||||||
def get_dtype(dtype: str) -> torch.dtype:
|
|
||||||
"""Converts a string dtype to a torch.dtype.
|
|
||||||
|
|
||||||
See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype"""
|
|
||||||
match dtype:
|
|
||||||
case "float32" | "float":
|
|
||||||
return torch.float32
|
|
||||||
case "float64" | "double":
|
|
||||||
return torch.float64
|
|
||||||
case "complex64" | "cfloat":
|
|
||||||
return torch.complex64
|
|
||||||
case "complex128" | "cdouble":
|
|
||||||
return torch.complex128
|
|
||||||
case "float16" | "half":
|
|
||||||
return torch.float16
|
|
||||||
case "bfloat16":
|
|
||||||
return torch.bfloat16
|
|
||||||
case "uint8":
|
|
||||||
return torch.uint8
|
|
||||||
case "int8":
|
|
||||||
return torch.int8
|
|
||||||
case "int16" | "short":
|
|
||||||
return torch.int16
|
|
||||||
case "int32" | "int":
|
|
||||||
return torch.int32
|
|
||||||
case "int64" | "long":
|
|
||||||
return torch.int64
|
|
||||||
case "bool":
|
|
||||||
return torch.bool
|
|
||||||
case _:
|
|
||||||
raise ValueError(f"Unknown dtype: {dtype}")
|
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS: dict[str, Any] = {
|
NODE_CLASS_MAPPINGS: dict[str, Any] = {
|
||||||
"DrawBoundingBox": DrawBoundingBox,
|
"DrawBoundingBox": DrawBoundingBox,
|
||||||
}
|
}
|
||||||
|
|
|
@ -270,3 +270,37 @@ def summarize_tensor(tensor: torch.Tensor, /) -> str:
|
||||||
)
|
)
|
||||||
|
|
||||||
return "Tensor(" + ", ".join(info_list) + ")"
|
return "Tensor(" + ", ".join(info_list) + ")"
|
||||||
|
|
||||||
|
|
||||||
|
def str_to_dtype(dtype: str) -> torch.dtype:
|
||||||
|
"""Converts a string dtype to a torch.dtype.
|
||||||
|
|
||||||
|
See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype
|
||||||
|
"""
|
||||||
|
match dtype.lower():
|
||||||
|
case "float32" | "float":
|
||||||
|
return torch.float32
|
||||||
|
case "float64" | "double":
|
||||||
|
return torch.float64
|
||||||
|
case "complex64" | "cfloat":
|
||||||
|
return torch.complex64
|
||||||
|
case "complex128" | "cdouble":
|
||||||
|
return torch.complex128
|
||||||
|
case "float16" | "half":
|
||||||
|
return torch.float16
|
||||||
|
case "bfloat16":
|
||||||
|
return torch.bfloat16
|
||||||
|
case "uint8":
|
||||||
|
return torch.uint8
|
||||||
|
case "int8":
|
||||||
|
return torch.int8
|
||||||
|
case "int16" | "short":
|
||||||
|
return torch.int16
|
||||||
|
case "int32" | "int":
|
||||||
|
return torch.int32
|
||||||
|
case "int64" | "long":
|
||||||
|
return torch.int64
|
||||||
|
case "bool":
|
||||||
|
return torch.bool
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown dtype: {dtype}")
|
||||||
|
|
Loading…
Reference in a new issue