add comfyui custom nodes

This commit is contained in:
Laurent 2024-09-05 09:30:48 +00:00 committed by Laureηt
parent b6f547d14a
commit 99e2554b9c
9 changed files with 510 additions and 0 deletions

View file

@ -92,6 +92,7 @@ dev-dependencies = [
"pytest>=8.0.0", "pytest>=8.0.0",
"coverage>=7.4.1", "coverage>=7.4.1",
"typos>=1.18.2", "typos>=1.18.2",
"comfy-cli>=1.1.6",
] ]

View file

@ -0,0 +1 @@
../../LICENSE

View file

@ -0,0 +1,30 @@
<div align="center">
<picture>
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/logo_dark.png">
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/logo_light.png">
<img alt="Finegrain Refiners Library" width="352" height="128" style="max-width: 100%;">
</picture>
**The simplest way to train and run adapters on top of foundation models**
[**Manifesto**](https://refine.rs/home/why/) |
[**Docs**](https://refine.rs) |
[**Guides**](https://refine.rs/guides/adapting_sdxl/) |
[**Discussions**](https://github.com/finegrain-ai/refiners/discussions) |
[**Discord**](https://discord.gg/mCmjNUVV7d)
</div>
## Installation
The nodes are published at https://registry.comfy.org/publishers/finegrain/nodes/comfyui-refiners.
To easily install the nodes, run the following command:
```bash
comfy node registry-install comfyui-refiners
```
You may also download the nodes by cliking the "Download Latest" button and unzipping the content of the archive into you custom_nodes directory.
See https://docs.comfy.org/registry/overview for more information.

View file

@ -0,0 +1,15 @@
from typing import Any
from .box_segmenter import NODE_CLASS_MAPPINGS as box_segmenter_mappings
from .grounding_dino import NODE_CLASS_MAPPINGS as grounding_dino_mappings
from .huggingface import NODE_CLASS_MAPPINGS as huggingface_mappings
from .utils import NODE_CLASS_MAPPINGS as utils_mappings
NODE_CLASS_MAPPINGS: dict[str, Any] = {}
NODE_CLASS_MAPPINGS.update(box_segmenter_mappings)
NODE_CLASS_MAPPINGS.update(grounding_dino_mappings)
NODE_CLASS_MAPPINGS.update(huggingface_mappings)
NODE_CLASS_MAPPINGS.update(utils_mappings)
NODE_DISPLAY_NAME_MAPPINGS = {k: v.__name__ for k, v in NODE_CLASS_MAPPINGS.items()}
__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]

View file

@ -0,0 +1,105 @@
from typing import Any
import torch
from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image
from refiners.solutions import BoxSegmenter as _BoxSegmenter
from refiners.solutions.box_segmenter import BoundingBox
class LoadBoxSegmenter:
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]:
return {
"required": {
"checkpoint": ("PATH", {}),
"margin": (
"FLOAT",
{
"default": 0.05,
"min": 0.0,
"max": 1.0,
"step": 0.01,
},
),
"device": ("STRING", {"default": "cuda"}),
}
}
RETURN_TYPES = ("MODEL",)
RETURN_NAMES = ("model",)
DESCRIPTION = "Load a BoxSegmenter refiners model."
CATEGORY = "Refiners/Solutions"
FUNCTION = "load"
def load(
self,
checkpoint: str,
margin: float,
device: str,
) -> tuple[_BoxSegmenter]:
"""Load a BoxSegmenter refiners model.
Args:
checkpoint: The path to the checkpoint file.
margin: The bbox margin to use when processing images.
device: The torch device to load the model on.
Returns:
A BoxSegmenter model instance.
"""
return (
_BoxSegmenter(
weights=checkpoint,
margin=margin,
device=device,
),
)
class BoxSegmenter:
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]:
return {
"required": {
"model": ("MODEL", {}),
"image": ("IMAGE", {}),
},
"optional": {
"bbox": ("BOUNDING_BOX", {}),
},
}
RETURN_TYPES = ("MASK",)
RETURN_NAMES = ("mask",)
DESCRIPTION = "Segment an image using a BoxSegmenter model and a bbox."
CATEGORY = "Refiners/Solutions"
FUNCTION = "process"
@no_grad()
def process(
self,
model: _BoxSegmenter,
image: torch.Tensor,
bbox: BoundingBox | None = None,
) -> tuple[torch.Tensor]:
"""Segment an image using a BoxSegmenter model and a bbox.
Args:
model: The BoxSegmenter model to use.
image: The input image to process.
bbox: Where in the image to apply the model.
Returns:
The mask of the segmented object.
"""
pil_image = tensor_to_image(image.permute(0, 3, 1, 2))
mask = model(img=pil_image, box_prompt=bbox)
mask_tensor = image_to_tensor(mask).squeeze(1)
return (mask_tensor,)
NODE_CLASS_MAPPINGS: dict[str, Any] = {
"BoxSegmenter": BoxSegmenter,
"LoadBoxSegmenter": LoadBoxSegmenter,
}

View file

@ -0,0 +1,191 @@
from typing import Any, Sequence
import torch
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor # type: ignore
from refiners.fluxion.utils import no_grad, tensor_to_image
from .utils import BoundingBox, get_dtype
class LoadGroundingDino:
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]:
return {
"required": {
"checkpoint": ("PATH", {}),
"dtype": (
"STRING",
{
"default": "float32",
},
),
"device": (
"STRING",
{
"default": "cuda",
},
),
}
}
RETURN_TYPES = ("PROCESSOR", "MODEL")
RETURN_NAMES = ("processor", "model")
DESCRIPTION = "Load a grounding dino model."
CATEGORY = "Refiners/Solutions"
FUNCTION = "load"
def load(
self,
checkpoint: str,
dtype: str,
device: str,
) -> tuple[GroundingDinoProcessor, GroundingDinoForObjectDetection]:
"""Load a grounding dino model.
Args:
checkpoint: The path to the checkpoint folder.
dtype: The torch data type to use.
device: The torch device to load the model on.
Returns:
The grounding dino processor and model instances.
"""
processor = GroundingDinoProcessor.from_pretrained(checkpoint) # type: ignore
assert isinstance(processor, GroundingDinoProcessor)
model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=get_dtype(dtype)) # type: ignore
model = model.to(device=device) # type: ignore
assert isinstance(model, GroundingDinoForObjectDetection)
return (processor, model)
# NOTE: not yet natively supported in Refiners, hence the transformers dependency
class GroundingDino:
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]:
return {
"required": {
"processor": ("PROCESSOR", {}),
"model": ("MODEL", {}),
"image": ("IMAGE", {}),
"prompt": ("STRING", {}),
"box_threshold": (
"FLOAT",
{
"default": 0.25,
"min": 0.0,
"max": 1.0,
"step": 0.01,
},
),
"text_threshold": (
"FLOAT",
{
"default": 0.25,
"min": 0.0,
"max": 1.0,
"step": 0.01,
},
),
},
}
RETURN_TYPES = ("BOUNDING_BOX",)
RETURN_NAMES = ("bbox",)
DESCRIPTION = "Detect an object in an image using a GroundingDino model."
CATEGORY = "Refiners/Solutions"
FUNCTION = "process"
@staticmethod
def corners_to_pixels_format(
bboxes: torch.Tensor,
width: int,
height: int,
) -> torch.Tensor:
x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
return torch.stack(
tensors=(
x1.clamp_(0, width),
y1.clamp_(0, height),
x2.clamp_(0, width),
y2.clamp_(0, height),
),
dim=-1,
)
@staticmethod
def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
if not bboxes:
return None
for bbox in bboxes:
assert len(bbox) == 4
assert all(isinstance(x, int) for x in bbox)
return (
min(bbox[0] for bbox in bboxes),
min(bbox[1] for bbox in bboxes),
max(bbox[2] for bbox in bboxes),
max(bbox[3] for bbox in bboxes),
)
@no_grad()
def process(
self,
processor: GroundingDinoProcessor,
model: GroundingDinoForObjectDetection,
image: torch.Tensor,
prompt: str,
box_threshold: float,
text_threshold: float,
) -> tuple[BoundingBox]:
"""Detect an object in an image using a GroundingDino model and a text prompt.
Args:
processor: The image processor to use.
model: The grounding dino model to use.
image: The input image to detect in.
prompt: The text prompt of what to detect in the image.
box_threshold: The score threshold for the bounding boxes.
text_threshold: The score threshold for the text.
Returns:
The union of the bounding boxes found in the image.
"""
# prepare the inputs
pil_image = tensor_to_image(image.permute(0, 3, 1, 2))
# NOTE: queries must be in lower cas + end with a dot. See:
# https://github.com/IDEA-Research/GroundingDINO/blob/856dde2/groundingdino/util/inference.py#L22-L26
inputs = processor(images=pil_image, text=f"{prompt.lower()}.", return_tensors="pt").to(device=model.device)
# get the model's prediction
outputs = model(**inputs)
# post-process the model's prediction
results: dict[str, Any] = processor.post_process_grounded_object_detection( # type: ignore
outputs=outputs,
input_ids=inputs["input_ids"],
target_sizes=[(pil_image.height, pil_image.width)],
box_threshold=box_threshold,
text_threshold=text_threshold,
)[0]
# retrieve the bounding boxes
assert "boxes" in results
bboxes = results["boxes"].cpu() # type: ignore
assert isinstance(bboxes, torch.Tensor)
assert bboxes.shape[0] != 0, "No bounding boxes found. Try adjusting the thresholds or pick another prompt."
bboxes = self.corners_to_pixels_format(bboxes, pil_image.width, pil_image.height) # type: ignore
# compute the union of the bounding boxes
bbox = self.bbox_union(bboxes.numpy().tolist()) # type: ignore
assert bbox is not None
return (bbox,)
NODE_CLASS_MAPPINGS: dict[str, Any] = {
"GroundingDino": GroundingDino,
"LoadGroundingDino": LoadGroundingDino,
}

View file

@ -0,0 +1,63 @@
from pathlib import Path
from typing import Any
from huggingface_hub import hf_hub_download, snapshot_download # type: ignore
class HfHubDownload:
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]:
return {
"required": {
"repo_id": ("STRING", {}),
},
"optional": {
"filename": ("STRING", {}),
"revision": (
"STRING",
{
"default": "main",
},
),
},
}
RETURN_TYPES = ("PATH",)
RETURN_NAMES = ("path",)
DESCRIPTION = "Download file(s) from the HuggingFace Hub."
CATEGORY = "Refiners/HuggingFace"
FUNCTION = "download"
def download(
self,
repo_id: str,
filename: str,
revision: str,
) -> tuple[Path]:
"""Download file(s) from the HuggingFace Hub.
Args:
repo_id: The HuggingFace repository ID.
filename: The filename to download, if empty, the entire repository will be downloaded.
revision: The git revision to download.
Returns:
The path to the downloaded file(s).
"""
if filename == "":
path = snapshot_download(
repo_id=repo_id,
revision=revision,
)
else:
path = hf_hub_download(
repo_id=repo_id,
filename=filename,
revision=revision,
)
return (Path(path),)
NODE_CLASS_MAPPINGS: dict[str, Any] = {
"HfHubDownload": HfHubDownload,
}

View file

@ -0,0 +1,18 @@
[project]
name = "comfyui-refiners"
description = "ComfyUI custom nodes for refiners models"
version = "1.0.1"
license = { file = "LICENSE" }
dependencies = [
"refiners @ git+https://github.com/finegrain-ai/refiners.git",
"huggingface_hub",
"transformers",
]
[project.urls]
Repository = "https://github.com/finegrain-ai/refiners"
[tool.comfy]
PublisherId = "finegrain"
DisplayName = "refiners"
Icon = "https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy.png"

View file

@ -0,0 +1,86 @@
from typing import Any
import torch
from PIL import ImageDraw
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
BoundingBox = tuple[int, int, int, int]
class DrawBoundingBox:
@classmethod
def INPUT_TYPES(cls) -> dict[str, Any]:
return {
"required": {
"image": ("IMAGE", {}),
"bbox": ("BOUNDING_BOX", {}),
"color": ("STRING", {"default": "red"}),
"width": ("INT", {"default": 3}),
},
}
RETURN_TYPES = ("IMAGE",)
RETURN_NAMES = ("image",)
DESCRIPTION = "Draw a bounding box on an image."
CATEGORY = "Refiners/Helpers"
FUNCTION = "process"
def process(
self,
image: torch.Tensor,
bbox: BoundingBox,
color: str,
width: int,
) -> tuple[torch.Tensor]:
"""Draw a bounding box on an image.
Args:
image: The image to draw on.
bbox: The bounding box to draw.
color: The color of the bounding box.
width: The width of the bounding box.
"""
pil_image = tensor_to_image(image.permute(0, 3, 1, 2))
draw = ImageDraw.Draw(pil_image)
draw.rectangle(bbox, outline=color, width=width)
image = image_to_tensor(pil_image).permute(0, 2, 3, 1)
return (image,)
def get_dtype(dtype: str) -> torch.dtype:
"""Converts a string dtype to a torch.dtype.
See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype"""
match dtype:
case "float32" | "float":
return torch.float32
case "float64" | "double":
return torch.float64
case "complex64" | "cfloat":
return torch.complex64
case "complex128" | "cdouble":
return torch.complex128
case "float16" | "half":
return torch.float16
case "bfloat16":
return torch.bfloat16
case "uint8":
return torch.uint8
case "int8":
return torch.int8
case "int16" | "short":
return torch.int16
case "int32" | "int":
return torch.int32
case "int64" | "long":
return torch.int64
case "bool":
return torch.bool
case _:
raise ValueError(f"Unknown dtype: {dtype}")
NODE_CLASS_MAPPINGS: dict[str, Any] = {
"DrawBoundingBox": DrawBoundingBox,
}