mirror of
https://github.com/finegrain-ai/refiners.git
synced 2024-11-21 13:48:46 +00:00
add comfyui custom nodes
This commit is contained in:
parent
b6f547d14a
commit
b67efb26df
|
@ -92,6 +92,7 @@ dev-dependencies = [
|
|||
"pytest>=8.0.0",
|
||||
"coverage>=7.4.1",
|
||||
"typos>=1.18.2",
|
||||
"comfy-cli>=1.1.6",
|
||||
]
|
||||
|
||||
|
||||
|
|
1
src/comfyui-refiners/LICENSE
Symbolic link
1
src/comfyui-refiners/LICENSE
Symbolic link
|
@ -0,0 +1 @@
|
|||
../../LICENSE
|
30
src/comfyui-refiners/README.md
Normal file
30
src/comfyui-refiners/README.md
Normal file
|
@ -0,0 +1,30 @@
|
|||
<div align="center">
|
||||
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/logo_dark.png">
|
||||
<source media="(prefers-color-scheme: light)" srcset="https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/logo_light.png">
|
||||
<img alt="Finegrain Refiners Library" width="352" height="128" style="max-width: 100%;">
|
||||
</picture>
|
||||
|
||||
**The simplest way to train and run adapters on top of foundation models**
|
||||
|
||||
[**Manifesto**](https://refine.rs/home/why/) |
|
||||
[**Docs**](https://refine.rs) |
|
||||
[**Guides**](https://refine.rs/guides/adapting_sdxl/) |
|
||||
[**Discussions**](https://github.com/finegrain-ai/refiners/discussions) |
|
||||
[**Discord**](https://discord.gg/mCmjNUVV7d)
|
||||
|
||||
</div>
|
||||
|
||||
## Installation
|
||||
|
||||
The nodes are published at https://registry.comfy.org/publishers/finegrain/nodes/comfyui-refiners.
|
||||
|
||||
To easily install the nodes, run the following command:
|
||||
```bash
|
||||
comfy node registry-install comfyui-refiners
|
||||
```
|
||||
|
||||
You may also download the nodes by clicking the "Download Latest" button and unzipping the content of the archive into your custom_nodes directory.
|
||||
|
||||
See https://docs.comfy.org/registry/overview for more information.
|
15
src/comfyui-refiners/__init__.py
Normal file
15
src/comfyui-refiners/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
from typing import Any
|
||||
|
||||
from .box_segmenter import NODE_CLASS_MAPPINGS as box_segmenter_mappings
|
||||
from .grounding_dino import NODE_CLASS_MAPPINGS as grounding_dino_mappings
|
||||
from .huggingface import NODE_CLASS_MAPPINGS as huggingface_mappings
|
||||
from .utils import NODE_CLASS_MAPPINGS as utils_mappings
|
||||
|
||||
# Aggregate the node registries exported by every submodule of this package.
# On a key collision the later module wins, exactly as with sequential dict.update calls.
NODE_CLASS_MAPPINGS: dict[str, Any] = {
    **box_segmenter_mappings,
    **grounding_dino_mappings,
    **huggingface_mappings,
    **utils_mappings,
}

# Every node is displayed under the name of the class that implements it.
NODE_DISPLAY_NAME_MAPPINGS = {
    node_id: node_class.__name__ for node_id, node_class in NODE_CLASS_MAPPINGS.items()
}

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
|
105
src/comfyui-refiners/box_segmenter.py
Normal file
105
src/comfyui-refiners/box_segmenter.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image
|
||||
from refiners.solutions import BoxSegmenter as _BoxSegmenter
|
||||
from refiners.solutions.box_segmenter import BoundingBox
|
||||
|
||||
|
||||
class LoadBoxSegmenter:
    """Node that instantiates a refiners BoxSegmenter from a checkpoint."""

    RETURN_TYPES = ("MODEL",)
    RETURN_NAMES = ("model",)
    DESCRIPTION = "Load a BoxSegmenter refiners model."
    CATEGORY = "Refiners/Solutions"
    FUNCTION = "load"

    @classmethod
    def INPUT_TYPES(cls) -> dict[str, Any]:
        """Describe the inputs of the node: checkpoint path, bbox margin and device."""
        margin_options = {"default": 0.05, "min": 0.0, "max": 1.0, "step": 0.01}
        required_inputs = {
            "checkpoint": ("PATH", {}),
            "margin": ("FLOAT", margin_options),
            "device": ("STRING", {"default": "cuda"}),
        }
        return {"required": required_inputs}

    def load(self, checkpoint: str, margin: float, device: str) -> tuple[_BoxSegmenter]:
        """Build a BoxSegmenter from the given checkpoint.

        Args:
            checkpoint: The path to the checkpoint file.
            margin: The bbox margin to use when processing images.
            device: The torch device to load the model on.

        Returns:
            A one-element tuple holding the BoxSegmenter instance.
        """
        segmenter = _BoxSegmenter(weights=checkpoint, margin=margin, device=device)
        return (segmenter,)
|
||||
|
||||
|
||||
class BoxSegmenter:
    """Node that runs a BoxSegmenter model on an image, optionally guided by a bbox."""

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("mask",)
    DESCRIPTION = "Segment an image using a BoxSegmenter model and a bbox."
    CATEGORY = "Refiners/Solutions"
    FUNCTION = "process"

    @classmethod
    def INPUT_TYPES(cls) -> dict[str, Any]:
        """Describe the inputs of the node; the bounding box is optional."""
        return {
            "required": {"model": ("MODEL", {}), "image": ("IMAGE", {})},
            "optional": {"bbox": ("BOUNDING_BOX", {})},
        }

    @no_grad()
    def process(
        self,
        model: _BoxSegmenter,
        image: torch.Tensor,
        bbox: BoundingBox | None = None,
    ) -> tuple[torch.Tensor]:
        """Compute the segmentation mask of an image.

        Args:
            model: The BoxSegmenter model to use.
            image: The input image to process.
            bbox: Where in the image to apply the model, forwarded as the box prompt.

        Returns:
            A one-element tuple holding the mask of the segmented object.
        """
        # Move the last axis to position 1 before the PIL conversion
        # (presumably channels-last -> channels-first; confirm against the caller).
        channels_first = image.permute(0, 3, 1, 2)
        mask_image = model(img=tensor_to_image(channels_first), box_prompt=bbox)

        # Drop dimension 1 of the converted mask (a no-op unless it has size 1).
        return (image_to_tensor(mask_image).squeeze(1),)
|
||||
|
||||
|
||||
# Nodes exported by this module, keyed by their node identifier.
NODE_CLASS_MAPPINGS: dict[str, Any] = {
    "BoxSegmenter": BoxSegmenter,
    "LoadBoxSegmenter": LoadBoxSegmenter,
}
|
191
src/comfyui-refiners/grounding_dino.py
Normal file
191
src/comfyui-refiners/grounding_dino.py
Normal file
|
@ -0,0 +1,191 @@
|
|||
from typing import Any, Sequence
|
||||
|
||||
import torch
|
||||
from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor # type: ignore
|
||||
|
||||
from refiners.fluxion.utils import no_grad, tensor_to_image
|
||||
|
||||
from .utils import BoundingBox, get_dtype
|
||||
|
||||
|
||||
class LoadGroundingDino:
    """Node that loads a GroundingDino processor and model from a checkpoint folder."""

    RETURN_TYPES = ("PROCESSOR", "MODEL")
    RETURN_NAMES = ("processor", "model")
    DESCRIPTION = "Load a grounding dino model."
    CATEGORY = "Refiners/Solutions"
    FUNCTION = "load"

    @classmethod
    def INPUT_TYPES(cls) -> dict[str, Any]:
        """Describe the inputs of the node: checkpoint path, dtype name and device."""
        required_inputs = {
            "checkpoint": ("PATH", {}),
            "dtype": ("STRING", {"default": "float32"}),
            "device": ("STRING", {"default": "cuda"}),
        }
        return {"required": required_inputs}

    def load(
        self,
        checkpoint: str,
        dtype: str,
        device: str,
    ) -> tuple[GroundingDinoProcessor, GroundingDinoForObjectDetection]:
        """Load the processor and the model stored in a checkpoint folder.

        Args:
            checkpoint: The path to the checkpoint folder.
            dtype: The torch data type to use, given by name.
            device: The torch device to load the model on.

        Returns:
            The grounding dino processor and model instances.
        """
        processor = GroundingDinoProcessor.from_pretrained(checkpoint)  # type: ignore
        assert isinstance(processor, GroundingDinoProcessor)

        # The dtype name is resolved only after the processor has been loaded.
        torch_dtype = get_dtype(dtype)
        model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=torch_dtype).to(device=device)  # type: ignore
        assert isinstance(model, GroundingDinoForObjectDetection)

        return (processor, model)
|
||||
|
||||
|
||||
# NOTE: not yet natively supported in Refiners, hence the transformers dependency
class GroundingDino:
    """Node that detects a prompted object in an image with a GroundingDino model.

    Every detection that passes the thresholds is merged into one bounding box.
    """

    @classmethod
    def INPUT_TYPES(cls) -> dict[str, Any]:
        """Describe the inputs of the node; all of them are required."""
        return {
            "required": {
                "processor": ("PROCESSOR", {}),
                "model": ("MODEL", {}),
                "image": ("IMAGE", {}),
                "prompt": ("STRING", {}),
                "box_threshold": (
                    "FLOAT",
                    {
                        "default": 0.25,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.01,
                    },
                ),
                "text_threshold": (
                    "FLOAT",
                    {
                        "default": 0.25,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.01,
                    },
                ),
            },
        }

    RETURN_TYPES = ("BOUNDING_BOX",)
    RETURN_NAMES = ("bbox",)
    DESCRIPTION = "Detect an object in an image using a GroundingDino model."
    CATEGORY = "Refiners/Solutions"
    FUNCTION = "process"

    @staticmethod
    def corners_to_pixels_format(
        bboxes: torch.Tensor,
        width: int,
        height: int,
    ) -> torch.Tensor:
        """Round corner-format boxes to integer pixels and clamp them to the image bounds.

        Args:
            bboxes: Boxes whose last dimension holds (x1, y1, x2, y2).
            width: The image width, used as the upper bound for x coordinates.
            height: The image height, used as the upper bound for y coordinates.

        Returns:
            An int32 tensor with the same layout as the input.
        """
        x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
        return torch.stack(
            tensors=(
                x1.clamp(0, width),
                y1.clamp(0, height),
                x2.clamp(0, width),
                y2.clamp(0, height),
            ),
            dim=-1,
        )

    @staticmethod
    def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
        """Return the smallest box enclosing all the given boxes.

        Args:
            bboxes: Boxes given as lists of four integers (x1, y1, x2, y2).

        Returns:
            The enclosing box, or None when no box is given.
        """
        if not bboxes:
            return None
        for bbox in bboxes:
            assert len(bbox) == 4
            assert all(isinstance(x, int) for x in bbox)
        return (
            min(bbox[0] for bbox in bboxes),
            min(bbox[1] for bbox in bboxes),
            max(bbox[2] for bbox in bboxes),
            max(bbox[3] for bbox in bboxes),
        )

    @no_grad()
    def process(
        self,
        processor: GroundingDinoProcessor,
        model: GroundingDinoForObjectDetection,
        image: torch.Tensor,
        prompt: str,
        box_threshold: float,
        text_threshold: float,
    ) -> tuple[BoundingBox]:
        """Detect an object in an image using a GroundingDino model and a text prompt.

        Args:
            processor: The image processor to use.
            model: The grounding dino model to use.
            image: The input image to detect in.
            prompt: The text prompt of what to detect in the image.
            box_threshold: The score threshold for the bounding boxes.
            text_threshold: The score threshold for the text.

        Returns:
            The union of the bounding boxes found in the image.

        Raises:
            ValueError: If no bounding box passes the thresholds.
        """
        # prepare the inputs
        # (the permute presumably turns a channels-last batch into channels-first; confirm against the caller)
        pil_image = tensor_to_image(image.permute(0, 3, 1, 2))

        # NOTE: queries must be in lower case + end with a dot. See:
        # https://github.com/IDEA-Research/GroundingDINO/blob/856dde2/groundingdino/util/inference.py#L22-L26
        # Like the referenced helper, strip whitespace and do not append a second dot.
        caption = prompt.lower().strip()
        if not caption.endswith("."):
            caption += "."
        inputs = processor(images=pil_image, text=caption, return_tensors="pt").to(device=model.device)

        # get the model's prediction
        outputs = model(**inputs)

        # post-process the model's prediction
        results: dict[str, Any] = processor.post_process_grounded_object_detection(  # type: ignore
            outputs=outputs,
            input_ids=inputs["input_ids"],
            target_sizes=[(pil_image.height, pil_image.width)],
            box_threshold=box_threshold,
            text_threshold=text_threshold,
        )[0]

        # retrieve the bounding boxes
        assert "boxes" in results
        bboxes = results["boxes"].cpu()  # type: ignore
        assert isinstance(bboxes, torch.Tensor)
        # This depends on user input, so raise explicitly: an assert would vanish under `python -O`.
        if bboxes.shape[0] == 0:
            raise ValueError("No bounding boxes found. Try adjusting the thresholds or pick another prompt.")
        bboxes = self.corners_to_pixels_format(bboxes, pil_image.width, pil_image.height)  # type: ignore

        # compute the union of the bounding boxes
        bbox = self.bbox_union(bboxes.numpy().tolist())  # type: ignore
        assert bbox is not None  # narrows the type: at least one box exists at this point

        return (bbox,)
|
||||
|
||||
|
||||
# Nodes exported by this module, keyed by their node identifier.
NODE_CLASS_MAPPINGS: dict[str, Any] = {
    "GroundingDino": GroundingDino,
    "LoadGroundingDino": LoadGroundingDino,
}
|
63
src/comfyui-refiners/huggingface.py
Normal file
63
src/comfyui-refiners/huggingface.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from huggingface_hub import hf_hub_download, snapshot_download # type: ignore
|
||||
|
||||
|
||||
class HfHubDownload:
    """Node that downloads a single file or a whole repository from the HuggingFace Hub."""

    @classmethod
    def INPUT_TYPES(cls) -> dict[str, Any]:
        """Describe the inputs of the node; only the repository ID is required."""
        return {
            "required": {
                "repo_id": ("STRING", {}),
            },
            "optional": {
                "filename": ("STRING", {}),
                "revision": (
                    "STRING",
                    {
                        "default": "main",
                    },
                ),
            },
        }

    RETURN_TYPES = ("PATH",)
    RETURN_NAMES = ("path",)
    DESCRIPTION = "Download file(s) from the HuggingFace Hub."
    CATEGORY = "Refiners/HuggingFace"
    FUNCTION = "download"

    def download(
        self,
        repo_id: str,
        filename: str = "",
        revision: str = "main",
    ) -> tuple[Path]:
        """Download file(s) from the HuggingFace Hub.

        `filename` and `revision` are declared as optional inputs, so they carry
        defaults here: the call must not fail when they are omitted.

        Args:
            repo_id: The HuggingFace repository ID.
            filename: The filename to download, if empty, the entire repository will be downloaded.
            revision: The git revision to download.

        Returns:
            The path to the downloaded file(s).
        """
        if filename:
            path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                revision=revision,
            )
        else:
            # No filename given: fetch a snapshot of the whole repository.
            path = snapshot_download(
                repo_id=repo_id,
                revision=revision,
            )
        return (Path(path),)
|
||||
|
||||
|
||||
# Nodes exported by this module, keyed by their node identifier.
NODE_CLASS_MAPPINGS: dict[str, Any] = {
    "HfHubDownload": HfHubDownload,
}
|
18
src/comfyui-refiners/pyproject.toml
Normal file
18
src/comfyui-refiners/pyproject.toml
Normal file
|
@ -0,0 +1,18 @@
|
|||
[project]
|
||||
name = "comfyui-refiners"
|
||||
description = "ComfyUI custom nodes for refiners models"
|
||||
version = "1.0.1"
|
||||
license = { file = "LICENSE" }
|
||||
dependencies = [
|
||||
"refiners @ git+https://github.com/finegrain-ai/refiners.git",
|
||||
"huggingface_hub",
|
||||
"transformers",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://github.com/finegrain-ai/refiners"
|
||||
|
||||
[tool.comfy]
|
||||
PublisherId = "finegrain"
|
||||
DisplayName = "refiners"
|
||||
Icon = "https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy.png"
|
86
src/comfyui-refiners/utils.py
Normal file
86
src/comfyui-refiners/utils.py
Normal file
|
@ -0,0 +1,86 @@
|
|||
from typing import Any
|
||||
|
||||
import torch
|
||||
from PIL import ImageDraw
|
||||
|
||||
from refiners.fluxion.utils import image_to_tensor, tensor_to_image
|
||||
|
||||
# Box described by four integers; the grounding dino node produces it as (x1, y1, x2, y2) pixel corners.
BoundingBox = tuple[int, int, int, int]
|
||||
|
||||
|
||||
class DrawBoundingBox:
    """Node that outlines a bounding box on top of an image."""

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("image",)
    DESCRIPTION = "Draw a bounding box on an image."
    CATEGORY = "Refiners/Helpers"
    FUNCTION = "process"

    @classmethod
    def INPUT_TYPES(cls) -> dict[str, Any]:
        """Describe the inputs of the node: image, bbox, outline color and width."""
        required_inputs = {
            "image": ("IMAGE", {}),
            "bbox": ("BOUNDING_BOX", {}),
            "color": ("STRING", {"default": "red"}),
            "width": ("INT", {"default": 3}),
        }
        return {"required": required_inputs}

    def process(
        self,
        image: torch.Tensor,
        bbox: BoundingBox,
        color: str,
        width: int,
    ) -> tuple[torch.Tensor]:
        """Outline a bounding box on an image.

        Args:
            image: The image to draw on.
            bbox: The bounding box to draw.
            color: The color of the outline.
            width: The width of the outline.

        Returns:
            A one-element tuple holding the image with the box drawn on it.
        """
        # Move the last axis to position 1 for the PIL conversion; the inverse permute is applied below.
        canvas = tensor_to_image(image.permute(0, 3, 1, 2))
        ImageDraw.Draw(canvas).rectangle(bbox, outline=color, width=width)
        return (image_to_tensor(canvas).permute(0, 2, 3, 1),)
|
||||
|
||||
|
||||
# Accepted dtype names (canonical names and their aliases) and the torch.dtype each one denotes.
_DTYPES_BY_NAME: dict[str, torch.dtype] = {
    "float32": torch.float32,
    "float": torch.float32,
    "float64": torch.float64,
    "double": torch.float64,
    "complex64": torch.complex64,
    "cfloat": torch.complex64,
    "complex128": torch.complex128,
    "cdouble": torch.complex128,
    "float16": torch.float16,
    "half": torch.float16,
    "bfloat16": torch.bfloat16,
    "uint8": torch.uint8,
    "int8": torch.int8,
    "int16": torch.int16,
    "short": torch.int16,
    "int32": torch.int32,
    "int": torch.int32,
    "int64": torch.int64,
    "long": torch.int64,
    "bool": torch.bool,
}


def get_dtype(dtype: str) -> torch.dtype:
    """Return the torch.dtype designated by a dtype name.

    See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype

    Args:
        dtype: The name of the dtype, e.g. "float32" or its alias "float".

    Returns:
        The matching torch.dtype.

    Raises:
        ValueError: If the name is not one of the supported names.
    """
    try:
        return _DTYPES_BY_NAME[dtype]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments, which are unknown names as well.
        raise ValueError(f"Unknown dtype: {dtype}") from None
|
||||
|
||||
|
||||
# Nodes exported by this module, keyed by their node identifier.
NODE_CLASS_MAPPINGS: dict[str, Any] = {
    "DrawBoundingBox": DrawBoundingBox,
}
|
Loading…
Reference in a new issue