mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-24 23:28:45 +00:00
add comfyui custom nodes
This commit is contained in:
parent
b6f547d14a
commit
b67efb26df
|
@ -92,6 +92,7 @@ dev-dependencies = [
|
||||||
"pytest>=8.0.0",
|
"pytest>=8.0.0",
|
||||||
"coverage>=7.4.1",
|
"coverage>=7.4.1",
|
||||||
"typos>=1.18.2",
|
"typos>=1.18.2",
|
||||||
|
"comfy-cli>=1.1.6",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
1
src/comfyui-refiners/LICENSE
Symbolic link
1
src/comfyui-refiners/LICENSE
Symbolic link
|
@ -0,0 +1 @@
|
||||||
|
../../LICENSE
|
30
src/comfyui-refiners/README.md
Normal file
30
src/comfyui-refiners/README.md
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|
<picture>
|
||||||
|
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/logo_dark.png">
|
||||||
|
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/logo_light.png">
|
||||||
|
<img alt="Finegrain Refiners Library" width="352" height="128" style="max-width: 100%;">
|
||||||
|
</picture>
|
||||||
|
|
||||||
|
**The simplest way to train and run adapters on top of foundation models**
|
||||||
|
|
||||||
|
[**Manifesto**](https://refine.rs/home/why/) |
|
||||||
|
[**Docs**](https://refine.rs) |
|
||||||
|
[**Guides**](https://refine.rs/guides/adapting_sdxl/) |
|
||||||
|
[**Discussions**](https://github.com/finegrain-ai/refiners/discussions) |
|
||||||
|
[**Discord**](https://discord.gg/mCmjNUVV7d)
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
The nodes are published at https://registry.comfy.org/publishers/finegrain/nodes/comfyui-refiners.
|
||||||
|
|
||||||
|
To easily install the nodes, run the following command:
|
||||||
|
```bash
|
||||||
|
comfy node registry-install comfyui-refiners
|
||||||
|
```
|
||||||
|
|
||||||
|
You may also download the nodes by cliking the "Download Latest" button and unzipping the content of the archive into you custom_nodes directory.
|
||||||
|
|
||||||
|
See https://docs.comfy.org/registry/overview for more information.
|
15
src/comfyui-refiners/__init__.py
Normal file
15
src/comfyui-refiners/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from .box_segmenter import NODE_CLASS_MAPPINGS as box_segmenter_mappings
|
||||||
|
from .grounding_dino import NODE_CLASS_MAPPINGS as grounding_dino_mappings
|
||||||
|
from .huggingface import NODE_CLASS_MAPPINGS as huggingface_mappings
|
||||||
|
from .utils import NODE_CLASS_MAPPINGS as utils_mappings
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS: dict[str, Any] = {}
|
||||||
|
NODE_CLASS_MAPPINGS.update(box_segmenter_mappings)
|
||||||
|
NODE_CLASS_MAPPINGS.update(grounding_dino_mappings)
|
||||||
|
NODE_CLASS_MAPPINGS.update(huggingface_mappings)
|
||||||
|
NODE_CLASS_MAPPINGS.update(utils_mappings)
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {k: v.__name__ for k, v in NODE_CLASS_MAPPINGS.items()}
|
||||||
|
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
|
105
src/comfyui-refiners/box_segmenter.py
Normal file
105
src/comfyui-refiners/box_segmenter.py
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image
|
||||||
|
from refiners.solutions import BoxSegmenter as _BoxSegmenter
|
||||||
|
from refiners.solutions.box_segmenter import BoundingBox
|
||||||
|
|
||||||
|
|
||||||
|
class LoadBoxSegmenter:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"checkpoint": ("PATH", {}),
|
||||||
|
"margin": (
|
||||||
|
"FLOAT",
|
||||||
|
{
|
||||||
|
"default": 0.05,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 1.0,
|
||||||
|
"step": 0.01,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"device": ("STRING", {"default": "cuda"}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MODEL",)
|
||||||
|
RETURN_NAMES = ("model",)
|
||||||
|
DESCRIPTION = "Load a BoxSegmenter refiners model."
|
||||||
|
CATEGORY = "Refiners/Solutions"
|
||||||
|
FUNCTION = "load"
|
||||||
|
|
||||||
|
def load(
|
||||||
|
self,
|
||||||
|
checkpoint: str,
|
||||||
|
margin: float,
|
||||||
|
device: str,
|
||||||
|
) -> tuple[_BoxSegmenter]:
|
||||||
|
"""Load a BoxSegmenter refiners model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint: The path to the checkpoint file.
|
||||||
|
margin: The bbox margin to use when processing images.
|
||||||
|
device: The torch device to load the model on.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A BoxSegmenter model instance.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
_BoxSegmenter(
|
||||||
|
weights=checkpoint,
|
||||||
|
margin=margin,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BoxSegmenter:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"model": ("MODEL", {}),
|
||||||
|
"image": ("IMAGE", {}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"bbox": ("BOUNDING_BOX", {}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("MASK",)
|
||||||
|
RETURN_NAMES = ("mask",)
|
||||||
|
DESCRIPTION = "Segment an image using a BoxSegmenter model and a bbox."
|
||||||
|
CATEGORY = "Refiners/Solutions"
|
||||||
|
FUNCTION = "process"
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def process(
|
||||||
|
self,
|
||||||
|
model: _BoxSegmenter,
|
||||||
|
image: torch.Tensor,
|
||||||
|
bbox: BoundingBox | None = None,
|
||||||
|
) -> tuple[torch.Tensor]:
|
||||||
|
"""Segment an image using a BoxSegmenter model and a bbox.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The BoxSegmenter model to use.
|
||||||
|
image: The input image to process.
|
||||||
|
bbox: Where in the image to apply the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The mask of the segmented object.
|
||||||
|
"""
|
||||||
|
pil_image = tensor_to_image(image.permute(0, 3, 1, 2))
|
||||||
|
mask = model(img=pil_image, box_prompt=bbox)
|
||||||
|
mask_tensor = image_to_tensor(mask).squeeze(1)
|
||||||
|
return (mask_tensor,)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS: dict[str, Any] = {
|
||||||
|
"BoxSegmenter": BoxSegmenter,
|
||||||
|
"LoadBoxSegmenter": LoadBoxSegmenter,
|
||||||
|
}
|
191
src/comfyui-refiners/grounding_dino.py
Normal file
191
src/comfyui-refiners/grounding_dino.py
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
from typing import Any, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor # type: ignore
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import no_grad, tensor_to_image
|
||||||
|
|
||||||
|
from .utils import BoundingBox, get_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class LoadGroundingDino:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"checkpoint": ("PATH", {}),
|
||||||
|
"dtype": (
|
||||||
|
"STRING",
|
||||||
|
{
|
||||||
|
"default": "float32",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"device": (
|
||||||
|
"STRING",
|
||||||
|
{
|
||||||
|
"default": "cuda",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("PROCESSOR", "MODEL")
|
||||||
|
RETURN_NAMES = ("processor", "model")
|
||||||
|
DESCRIPTION = "Load a grounding dino model."
|
||||||
|
CATEGORY = "Refiners/Solutions"
|
||||||
|
FUNCTION = "load"
|
||||||
|
|
||||||
|
def load(
|
||||||
|
self,
|
||||||
|
checkpoint: str,
|
||||||
|
dtype: str,
|
||||||
|
device: str,
|
||||||
|
) -> tuple[GroundingDinoProcessor, GroundingDinoForObjectDetection]:
|
||||||
|
"""Load a grounding dino model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
checkpoint: The path to the checkpoint folder.
|
||||||
|
dtype: The torch data type to use.
|
||||||
|
device: The torch device to load the model on.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The grounding dino processor and model instances.
|
||||||
|
"""
|
||||||
|
processor = GroundingDinoProcessor.from_pretrained(checkpoint) # type: ignore
|
||||||
|
assert isinstance(processor, GroundingDinoProcessor)
|
||||||
|
|
||||||
|
model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=get_dtype(dtype)) # type: ignore
|
||||||
|
model = model.to(device=device) # type: ignore
|
||||||
|
assert isinstance(model, GroundingDinoForObjectDetection)
|
||||||
|
|
||||||
|
return (processor, model)
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: not yet natively supported in Refiners, hence the transformers dependency
|
||||||
|
class GroundingDino:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"processor": ("PROCESSOR", {}),
|
||||||
|
"model": ("MODEL", {}),
|
||||||
|
"image": ("IMAGE", {}),
|
||||||
|
"prompt": ("STRING", {}),
|
||||||
|
"box_threshold": (
|
||||||
|
"FLOAT",
|
||||||
|
{
|
||||||
|
"default": 0.25,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 1.0,
|
||||||
|
"step": 0.01,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"text_threshold": (
|
||||||
|
"FLOAT",
|
||||||
|
{
|
||||||
|
"default": 0.25,
|
||||||
|
"min": 0.0,
|
||||||
|
"max": 1.0,
|
||||||
|
"step": 0.01,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("BOUNDING_BOX",)
|
||||||
|
RETURN_NAMES = ("bbox",)
|
||||||
|
DESCRIPTION = "Detect an object in an image using a GroundingDino model."
|
||||||
|
CATEGORY = "Refiners/Solutions"
|
||||||
|
FUNCTION = "process"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def corners_to_pixels_format(
|
||||||
|
bboxes: torch.Tensor,
|
||||||
|
width: int,
|
||||||
|
height: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
|
||||||
|
return torch.stack(
|
||||||
|
tensors=(
|
||||||
|
x1.clamp_(0, width),
|
||||||
|
y1.clamp_(0, height),
|
||||||
|
x2.clamp_(0, width),
|
||||||
|
y2.clamp_(0, height),
|
||||||
|
),
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
|
||||||
|
if not bboxes:
|
||||||
|
return None
|
||||||
|
for bbox in bboxes:
|
||||||
|
assert len(bbox) == 4
|
||||||
|
assert all(isinstance(x, int) for x in bbox)
|
||||||
|
return (
|
||||||
|
min(bbox[0] for bbox in bboxes),
|
||||||
|
min(bbox[1] for bbox in bboxes),
|
||||||
|
max(bbox[2] for bbox in bboxes),
|
||||||
|
max(bbox[3] for bbox in bboxes),
|
||||||
|
)
|
||||||
|
|
||||||
|
@no_grad()
|
||||||
|
def process(
|
||||||
|
self,
|
||||||
|
processor: GroundingDinoProcessor,
|
||||||
|
model: GroundingDinoForObjectDetection,
|
||||||
|
image: torch.Tensor,
|
||||||
|
prompt: str,
|
||||||
|
box_threshold: float,
|
||||||
|
text_threshold: float,
|
||||||
|
) -> tuple[BoundingBox]:
|
||||||
|
"""Detect an object in an image using a GroundingDino model and a text prompt.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
processor: The image processor to use.
|
||||||
|
model: The grounding dino model to use.
|
||||||
|
image: The input image to detect in.
|
||||||
|
prompt: The text prompt of what to detect in the image.
|
||||||
|
box_threshold: The score threshold for the bounding boxes.
|
||||||
|
text_threshold: The score threshold for the text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The union of the bounding boxes found in the image.
|
||||||
|
"""
|
||||||
|
# prepare the inputs
|
||||||
|
pil_image = tensor_to_image(image.permute(0, 3, 1, 2))
|
||||||
|
|
||||||
|
# NOTE: queries must be in lower cas + end with a dot. See:
|
||||||
|
# https://github.com/IDEA-Research/GroundingDINO/blob/856dde2/groundingdino/util/inference.py#L22-L26
|
||||||
|
inputs = processor(images=pil_image, text=f"{prompt.lower()}.", return_tensors="pt").to(device=model.device)
|
||||||
|
|
||||||
|
# get the model's prediction
|
||||||
|
outputs = model(**inputs)
|
||||||
|
|
||||||
|
# post-process the model's prediction
|
||||||
|
results: dict[str, Any] = processor.post_process_grounded_object_detection( # type: ignore
|
||||||
|
outputs=outputs,
|
||||||
|
input_ids=inputs["input_ids"],
|
||||||
|
target_sizes=[(pil_image.height, pil_image.width)],
|
||||||
|
box_threshold=box_threshold,
|
||||||
|
text_threshold=text_threshold,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# retrieve the bounding boxes
|
||||||
|
assert "boxes" in results
|
||||||
|
bboxes = results["boxes"].cpu() # type: ignore
|
||||||
|
assert isinstance(bboxes, torch.Tensor)
|
||||||
|
assert bboxes.shape[0] != 0, "No bounding boxes found. Try adjusting the thresholds or pick another prompt."
|
||||||
|
bboxes = self.corners_to_pixels_format(bboxes, pil_image.width, pil_image.height) # type: ignore
|
||||||
|
|
||||||
|
# compute the union of the bounding boxes
|
||||||
|
bbox = self.bbox_union(bboxes.numpy().tolist()) # type: ignore
|
||||||
|
assert bbox is not None
|
||||||
|
|
||||||
|
return (bbox,)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS: dict[str, Any] = {
|
||||||
|
"GroundingDino": GroundingDino,
|
||||||
|
"LoadGroundingDino": LoadGroundingDino,
|
||||||
|
}
|
63
src/comfyui-refiners/huggingface.py
Normal file
63
src/comfyui-refiners/huggingface.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from huggingface_hub import hf_hub_download, snapshot_download # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class HfHubDownload:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"repo_id": ("STRING", {}),
|
||||||
|
},
|
||||||
|
"optional": {
|
||||||
|
"filename": ("STRING", {}),
|
||||||
|
"revision": (
|
||||||
|
"STRING",
|
||||||
|
{
|
||||||
|
"default": "main",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("PATH",)
|
||||||
|
RETURN_NAMES = ("path",)
|
||||||
|
DESCRIPTION = "Download file(s) from the HuggingFace Hub."
|
||||||
|
CATEGORY = "Refiners/HuggingFace"
|
||||||
|
FUNCTION = "download"
|
||||||
|
|
||||||
|
def download(
|
||||||
|
self,
|
||||||
|
repo_id: str,
|
||||||
|
filename: str,
|
||||||
|
revision: str,
|
||||||
|
) -> tuple[Path]:
|
||||||
|
"""Download file(s) from the HuggingFace Hub.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repo_id: The HuggingFace repository ID.
|
||||||
|
filename: The filename to download, if empty, the entire repository will be downloaded.
|
||||||
|
revision: The git revision to download.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The path to the downloaded file(s).
|
||||||
|
"""
|
||||||
|
if filename == "":
|
||||||
|
path = snapshot_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
path = hf_hub_download(
|
||||||
|
repo_id=repo_id,
|
||||||
|
filename=filename,
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
return (Path(path),)
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS: dict[str, Any] = {
|
||||||
|
"HfHubDownload": HfHubDownload,
|
||||||
|
}
|
18
src/comfyui-refiners/pyproject.toml
Normal file
18
src/comfyui-refiners/pyproject.toml
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
[project]
|
||||||
|
name = "comfyui-refiners"
|
||||||
|
description = "ComfyUI custom nodes for refiners models"
|
||||||
|
version = "1.0.1"
|
||||||
|
license = { file = "LICENSE" }
|
||||||
|
dependencies = [
|
||||||
|
"refiners @ git+https://github.com/finegrain-ai/refiners.git",
|
||||||
|
"huggingface_hub",
|
||||||
|
"transformers",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Repository = "https://github.com/finegrain-ai/refiners"
|
||||||
|
|
||||||
|
[tool.comfy]
|
||||||
|
PublisherId = "finegrain"
|
||||||
|
DisplayName = "refiners"
|
||||||
|
Icon = "https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy.png"
|
86
src/comfyui-refiners/utils.py
Normal file
86
src/comfyui-refiners/utils.py
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import ImageDraw
|
||||||
|
|
||||||
|
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
|
||||||
|
|
||||||
|
BoundingBox = tuple[int, int, int, int]
|
||||||
|
|
||||||
|
|
||||||
|
class DrawBoundingBox:
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(cls) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"required": {
|
||||||
|
"image": ("IMAGE", {}),
|
||||||
|
"bbox": ("BOUNDING_BOX", {}),
|
||||||
|
"color": ("STRING", {"default": "red"}),
|
||||||
|
"width": ("INT", {"default": 3}),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_TYPES = ("IMAGE",)
|
||||||
|
RETURN_NAMES = ("image",)
|
||||||
|
DESCRIPTION = "Draw a bounding box on an image."
|
||||||
|
CATEGORY = "Refiners/Helpers"
|
||||||
|
FUNCTION = "process"
|
||||||
|
|
||||||
|
def process(
|
||||||
|
self,
|
||||||
|
image: torch.Tensor,
|
||||||
|
bbox: BoundingBox,
|
||||||
|
color: str,
|
||||||
|
width: int,
|
||||||
|
) -> tuple[torch.Tensor]:
|
||||||
|
"""Draw a bounding box on an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: The image to draw on.
|
||||||
|
bbox: The bounding box to draw.
|
||||||
|
color: The color of the bounding box.
|
||||||
|
width: The width of the bounding box.
|
||||||
|
"""
|
||||||
|
pil_image = tensor_to_image(image.permute(0, 3, 1, 2))
|
||||||
|
draw = ImageDraw.Draw(pil_image)
|
||||||
|
draw.rectangle(bbox, outline=color, width=width)
|
||||||
|
image = image_to_tensor(pil_image).permute(0, 2, 3, 1)
|
||||||
|
return (image,)
|
||||||
|
|
||||||
|
|
||||||
|
def get_dtype(dtype: str) -> torch.dtype:
|
||||||
|
"""Converts a string dtype to a torch.dtype.
|
||||||
|
|
||||||
|
See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype"""
|
||||||
|
match dtype:
|
||||||
|
case "float32" | "float":
|
||||||
|
return torch.float32
|
||||||
|
case "float64" | "double":
|
||||||
|
return torch.float64
|
||||||
|
case "complex64" | "cfloat":
|
||||||
|
return torch.complex64
|
||||||
|
case "complex128" | "cdouble":
|
||||||
|
return torch.complex128
|
||||||
|
case "float16" | "half":
|
||||||
|
return torch.float16
|
||||||
|
case "bfloat16":
|
||||||
|
return torch.bfloat16
|
||||||
|
case "uint8":
|
||||||
|
return torch.uint8
|
||||||
|
case "int8":
|
||||||
|
return torch.int8
|
||||||
|
case "int16" | "short":
|
||||||
|
return torch.int16
|
||||||
|
case "int32" | "int":
|
||||||
|
return torch.int32
|
||||||
|
case "int64" | "long":
|
||||||
|
return torch.int64
|
||||||
|
case "bool":
|
||||||
|
return torch.bool
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unknown dtype: {dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
NODE_CLASS_MAPPINGS: dict[str, Any] = {
|
||||||
|
"DrawBoundingBox": DrawBoundingBox,
|
||||||
|
}
|
Loading…
Reference in a new issue