From b67efb26df642a2d81494823eaa9f81bdabcde0b Mon Sep 17 00:00:00 2001
From: Laurent
Date: Thu, 5 Sep 2024 09:30:48 +0000
Subject: [PATCH] add comfyui custom nodes

---
 pyproject.toml                         |   1 +
 src/comfyui-refiners/LICENSE           |   1 +
 src/comfyui-refiners/README.md         |  30 ++++
 src/comfyui-refiners/__init__.py       |  15 ++
 src/comfyui-refiners/box_segmenter.py  | 105 ++++++++++++++
 src/comfyui-refiners/grounding_dino.py | 191 +++++++++++++++++++++++++
 src/comfyui-refiners/huggingface.py    |  63 ++++++++
 src/comfyui-refiners/pyproject.toml    |  18 +++
 src/comfyui-refiners/utils.py          |  86 +++++++++++
 9 files changed, 510 insertions(+)
 create mode 120000 src/comfyui-refiners/LICENSE
 create mode 100644 src/comfyui-refiners/README.md
 create mode 100644 src/comfyui-refiners/__init__.py
 create mode 100644 src/comfyui-refiners/box_segmenter.py
 create mode 100644 src/comfyui-refiners/grounding_dino.py
 create mode 100644 src/comfyui-refiners/huggingface.py
 create mode 100644 src/comfyui-refiners/pyproject.toml
 create mode 100644 src/comfyui-refiners/utils.py

diff --git a/pyproject.toml b/pyproject.toml
index 02ac925..0818e80 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -92,6 +92,7 @@ dev-dependencies = [
     "pytest>=8.0.0",
     "coverage>=7.4.1",
     "typos>=1.18.2",
+    "comfy-cli>=1.1.6",
 ]
diff --git a/src/comfyui-refiners/LICENSE b/src/comfyui-refiners/LICENSE
new file mode 120000
index 0000000..30cff74
--- /dev/null
+++ b/src/comfyui-refiners/LICENSE
@@ -0,0 +1 @@
+../../LICENSE
\ No newline at end of file
diff --git a/src/comfyui-refiners/README.md b/src/comfyui-refiners/README.md
new file mode 100644
index 0000000..371607e
--- /dev/null
+++ b/src/comfyui-refiners/README.md
@@ -0,0 +1,30 @@
+<div align="center">
+
+<!-- Finegrain Refiners Library logo -->
+
+**The simplest way to train and run adapters on top of foundation models**
+
+[**Manifesto**](https://refine.rs/home/why/) |
+[**Docs**](https://refine.rs) |
+[**Guides**](https://refine.rs/guides/adapting_sdxl/) |
+[**Discussions**](https://github.com/finegrain-ai/refiners/discussions) |
+[**Discord**](https://discord.gg/mCmjNUVV7d)
+
+</div>
+
+## Installation
+
+The nodes are published at https://registry.comfy.org/publishers/finegrain/nodes/comfyui-refiners.
+
+To install the nodes easily, run the following command:
+```bash
+comfy node registry-install comfyui-refiners
+```
+
+You may also download the nodes by clicking the "Download Latest" button and unzipping the content of the archive into your `custom_nodes` directory.
+
+See https://docs.comfy.org/registry/overview for more information.
diff --git a/src/comfyui-refiners/__init__.py b/src/comfyui-refiners/__init__.py
new file mode 100644
index 0000000..7e2c4c7
--- /dev/null
+++ b/src/comfyui-refiners/__init__.py
@@ -0,0 +1,15 @@
+from typing import Any
+
+from .box_segmenter import NODE_CLASS_MAPPINGS as box_segmenter_mappings
+from .grounding_dino import NODE_CLASS_MAPPINGS as grounding_dino_mappings
+from .huggingface import NODE_CLASS_MAPPINGS as huggingface_mappings
+from .utils import NODE_CLASS_MAPPINGS as utils_mappings
+
+NODE_CLASS_MAPPINGS: dict[str, Any] = {}
+NODE_CLASS_MAPPINGS.update(box_segmenter_mappings)
+NODE_CLASS_MAPPINGS.update(grounding_dino_mappings)
+NODE_CLASS_MAPPINGS.update(huggingface_mappings)
+NODE_CLASS_MAPPINGS.update(utils_mappings)
+
+NODE_DISPLAY_NAME_MAPPINGS = {k: v.__name__ for k, v in NODE_CLASS_MAPPINGS.items()}
+__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
diff --git a/src/comfyui-refiners/box_segmenter.py b/src/comfyui-refiners/box_segmenter.py
new file mode 100644
index 0000000..9dc37ce
--- /dev/null
+++ b/src/comfyui-refiners/box_segmenter.py
@@ -0,0 +1,105 @@
+from typing import Any
+
+import torch
+
+from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image
+from refiners.solutions import BoxSegmenter as _BoxSegmenter
+from refiners.solutions.box_segmenter import BoundingBox
+
+
+class LoadBoxSegmenter:
+    @classmethod
+    def INPUT_TYPES(cls) -> dict[str, Any]:
+        return {
+            "required": {
+                "checkpoint": ("PATH", {}),
+                "margin": (
+                    "FLOAT",
+                    {
+                        "default": 0.05,
+                        "min": 0.0,
+                        "max": 1.0,
+                        "step": 0.01,
+                    },
+                ),
+                "device": ("STRING", {"default": "cuda"}),
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    RETURN_NAMES = ("model",)
+    DESCRIPTION = "Load a BoxSegmenter refiners model."
+    CATEGORY = "Refiners/Solutions"
+    FUNCTION = "load"
+
+    def load(
+        self,
+        checkpoint: str,
+        margin: float,
+        device: str,
+    ) -> tuple[_BoxSegmenter]:
+        """Load a BoxSegmenter refiners model.
+
+        Args:
+            checkpoint: The path to the checkpoint file.
+            margin: The bbox margin to use when processing images.
+            device: The torch device to load the model on.
+
+        Returns:
+            A BoxSegmenter model instance.
+        """
+        return (
+            _BoxSegmenter(
+                weights=checkpoint,
+                margin=margin,
+                device=device,
+            ),
+        )
+
+
+class BoxSegmenter:
+    @classmethod
+    def INPUT_TYPES(cls) -> dict[str, Any]:
+        return {
+            "required": {
+                "model": ("MODEL", {}),
+                "image": ("IMAGE", {}),
+            },
+            "optional": {
+                "bbox": ("BOUNDING_BOX", {}),
+            },
+        }
+
+    RETURN_TYPES = ("MASK",)
+    RETURN_NAMES = ("mask",)
+    DESCRIPTION = "Segment an image using a BoxSegmenter model and a bbox."
+    CATEGORY = "Refiners/Solutions"
+    FUNCTION = "process"
+
+    @no_grad()
+    def process(
+        self,
+        model: _BoxSegmenter,
+        image: torch.Tensor,
+        bbox: BoundingBox | None = None,
+    ) -> tuple[torch.Tensor]:
+        """Segment an image using a BoxSegmenter model and a bbox.
+
+        Args:
+            model: The BoxSegmenter model to use.
+            image: The input image to process.
+            bbox: Where in the image to apply the model.
+
+        Returns:
+            The mask of the segmented object.
+        """
+        pil_image = tensor_to_image(image.permute(0, 3, 1, 2))
+        mask = model(img=pil_image, box_prompt=bbox)
+        mask_tensor = image_to_tensor(mask).squeeze(1)
+        return (mask_tensor,)
+
+
+NODE_CLASS_MAPPINGS: dict[str, Any] = {
+    "BoxSegmenter": BoxSegmenter,
+    "LoadBoxSegmenter": LoadBoxSegmenter,
+}
diff --git a/src/comfyui-refiners/grounding_dino.py b/src/comfyui-refiners/grounding_dino.py
new file mode 100644
index 0000000..4675872
--- /dev/null
+++ b/src/comfyui-refiners/grounding_dino.py
@@ -0,0 +1,191 @@
+from typing import Any, Sequence
+
+import torch
+from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor  # type: ignore
+
+from refiners.fluxion.utils import no_grad, tensor_to_image
+
+from .utils import BoundingBox, get_dtype
+
+
+class LoadGroundingDino:
+    @classmethod
+    def INPUT_TYPES(cls) -> dict[str, Any]:
+        return {
+            "required": {
+                "checkpoint": ("PATH", {}),
+                "dtype": (
+                    "STRING",
+                    {
+                        "default": "float32",
+                    },
+                ),
+                "device": (
+                    "STRING",
+                    {
+                        "default": "cuda",
+                    },
+                ),
+            }
+        }
+
+    RETURN_TYPES = ("PROCESSOR", "MODEL")
+    RETURN_NAMES = ("processor", "model")
+    DESCRIPTION = "Load a grounding dino model."
+    CATEGORY = "Refiners/Solutions"
+    FUNCTION = "load"
+
+    def load(
+        self,
+        checkpoint: str,
+        dtype: str,
+        device: str,
+    ) -> tuple[GroundingDinoProcessor, GroundingDinoForObjectDetection]:
+        """Load a grounding dino model.
+
+        Args:
+            checkpoint: The path to the checkpoint folder.
+            dtype: The torch data type to use.
+            device: The torch device to load the model on.
+
+        Returns:
+            The grounding dino processor and model instances.
+        """
+        processor = GroundingDinoProcessor.from_pretrained(checkpoint)  # type: ignore
+        assert isinstance(processor, GroundingDinoProcessor)
+
+        model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=get_dtype(dtype))  # type: ignore
+        model = model.to(device=device)  # type: ignore
+        assert isinstance(model, GroundingDinoForObjectDetection)
+
+        return (processor, model)
+
+
+# NOTE: not yet natively supported in Refiners, hence the transformers dependency
+class GroundingDino:
+    @classmethod
+    def INPUT_TYPES(cls) -> dict[str, Any]:
+        return {
+            "required": {
+                "processor": ("PROCESSOR", {}),
+                "model": ("MODEL", {}),
+                "image": ("IMAGE", {}),
+                "prompt": ("STRING", {}),
+                "box_threshold": (
+                    "FLOAT",
+                    {
+                        "default": 0.25,
+                        "min": 0.0,
+                        "max": 1.0,
+                        "step": 0.01,
+                    },
+                ),
+                "text_threshold": (
+                    "FLOAT",
+                    {
+                        "default": 0.25,
+                        "min": 0.0,
+                        "max": 1.0,
+                        "step": 0.01,
+                    },
+                ),
+            },
+        }
+
+    RETURN_TYPES = ("BOUNDING_BOX",)
+    RETURN_NAMES = ("bbox",)
+    DESCRIPTION = "Detect an object in an image using a GroundingDino model."
+    CATEGORY = "Refiners/Solutions"
+    FUNCTION = "process"
+
+    @staticmethod
+    def corners_to_pixels_format(
+        bboxes: torch.Tensor,
+        width: int,
+        height: int,
+    ) -> torch.Tensor:
+        x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
+        return torch.stack(
+            tensors=(
+                x1.clamp_(0, width),
+                y1.clamp_(0, height),
+                x2.clamp_(0, width),
+                y2.clamp_(0, height),
+            ),
+            dim=-1,
+        )
+
+    @staticmethod
+    def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
+        if not bboxes:
+            return None
+        for bbox in bboxes:
+            assert len(bbox) == 4
+            assert all(isinstance(x, int) for x in bbox)
+        return (
+            min(bbox[0] for bbox in bboxes),
+            min(bbox[1] for bbox in bboxes),
+            max(bbox[2] for bbox in bboxes),
+            max(bbox[3] for bbox in bboxes),
+        )
+
+    @no_grad()
+    def process(
+        self,
+        processor: GroundingDinoProcessor,
+        model: GroundingDinoForObjectDetection,
+        image: torch.Tensor,
+        prompt: str,
+        box_threshold: float,
+        text_threshold: float,
+    ) -> tuple[BoundingBox]:
+        """Detect an object in an image using a GroundingDino model and a text prompt.
+
+        Args:
+            processor: The image processor to use.
+            model: The grounding dino model to use.
+            image: The input image to detect in.
+            prompt: The text prompt describing what to detect in the image.
+            box_threshold: The score threshold for the bounding boxes.
+            text_threshold: The score threshold for the text.
+
+        Returns:
+            The union of the bounding boxes found in the image.
+        """
+        # prepare the inputs
+        pil_image = tensor_to_image(image.permute(0, 3, 1, 2))
+
+        # NOTE: queries must be in lower case and end with a dot. See:
+        # https://github.com/IDEA-Research/GroundingDINO/blob/856dde2/groundingdino/util/inference.py#L22-L26
+        inputs = processor(images=pil_image, text=f"{prompt.lower()}.", return_tensors="pt").to(device=model.device)
+
+        # get the model's prediction
+        outputs = model(**inputs)
+
+        # post-process the model's prediction
+        results: dict[str, Any] = processor.post_process_grounded_object_detection(  # type: ignore
+            outputs=outputs,
+            input_ids=inputs["input_ids"],
+            target_sizes=[(pil_image.height, pil_image.width)],
+            box_threshold=box_threshold,
+            text_threshold=text_threshold,
+        )[0]
+
+        # retrieve the bounding boxes
+        assert "boxes" in results
+        bboxes = results["boxes"].cpu()  # type: ignore
+        assert isinstance(bboxes, torch.Tensor)
+        assert bboxes.shape[0] != 0, "No bounding boxes found. Try adjusting the thresholds or picking another prompt."
+        bboxes = self.corners_to_pixels_format(bboxes, pil_image.width, pil_image.height)  # type: ignore
+
+        # compute the union of the bounding boxes
+        bbox = self.bbox_union(bboxes.numpy().tolist())  # type: ignore
+        assert bbox is not None
+
+        return (bbox,)
+
+
+NODE_CLASS_MAPPINGS: dict[str, Any] = {
+    "GroundingDino": GroundingDino,
+    "LoadGroundingDino": LoadGroundingDino,
+}
diff --git a/src/comfyui-refiners/huggingface.py b/src/comfyui-refiners/huggingface.py
new file mode 100644
index 0000000..00624da
--- /dev/null
+++ b/src/comfyui-refiners/huggingface.py
@@ -0,0 +1,63 @@
+from pathlib import Path
+from typing import Any
+
+from huggingface_hub import hf_hub_download, snapshot_download  # type: ignore
+
+
+class HfHubDownload:
+    @classmethod
+    def INPUT_TYPES(cls) -> dict[str, Any]:
+        return {
+            "required": {
+                "repo_id": ("STRING", {}),
+            },
+            "optional": {
+                "filename": ("STRING", {}),
+                "revision": (
+                    "STRING",
+                    {
+                        "default": "main",
+                    },
+                ),
+            },
+        }
+
+    RETURN_TYPES = ("PATH",)
+    RETURN_NAMES = ("path",)
+    DESCRIPTION = "Download file(s) from the HuggingFace Hub."
+    CATEGORY = "Refiners/HuggingFace"
+    FUNCTION = "download"
+
+    def download(
+        self,
+        repo_id: str,
+        filename: str,
+        revision: str,
+    ) -> tuple[Path]:
+        """Download file(s) from the HuggingFace Hub.
+
+        Args:
+            repo_id: The HuggingFace repository ID.
+            filename: The filename to download; if empty, the entire repository is downloaded.
+            revision: The git revision to download.
+
+        Returns:
+            The path to the downloaded file(s).
+        """
+        if filename == "":
+            path = snapshot_download(
+                repo_id=repo_id,
+                revision=revision,
+            )
+        else:
+            path = hf_hub_download(
+                repo_id=repo_id,
+                filename=filename,
+                revision=revision,
+            )
+        return (Path(path),)
+
+
+NODE_CLASS_MAPPINGS: dict[str, Any] = {
+    "HfHubDownload": HfHubDownload,
+}
diff --git a/src/comfyui-refiners/pyproject.toml b/src/comfyui-refiners/pyproject.toml
new file mode 100644
index 0000000..64fa98e
--- /dev/null
+++ b/src/comfyui-refiners/pyproject.toml
@@ -0,0 +1,18 @@
+[project]
+name = "comfyui-refiners"
+description = "ComfyUI custom nodes for refiners models"
+version = "1.0.1"
+license = { file = "LICENSE" }
+dependencies = [
+    "refiners @ git+https://github.com/finegrain-ai/refiners.git",
+    "huggingface_hub",
+    "transformers",
+]
+
+[project.urls]
+Repository = "https://github.com/finegrain-ai/refiners"
+
+[tool.comfy]
+PublisherId = "finegrain"
+DisplayName = "refiners"
+Icon = "https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy.png"
diff --git a/src/comfyui-refiners/utils.py b/src/comfyui-refiners/utils.py
new file mode 100644
index 0000000..9595376
--- /dev/null
+++ b/src/comfyui-refiners/utils.py
@@ -0,0 +1,86 @@
+from typing import Any
+
+import torch
+from PIL import ImageDraw
+
+from refiners.fluxion.utils import image_to_tensor, tensor_to_image
+
+BoundingBox = tuple[int, int, int, int]
+
+
+class DrawBoundingBox:
+    @classmethod
+    def INPUT_TYPES(cls) -> dict[str, Any]:
+        return {
+            "required": {
+                "image": ("IMAGE", {}),
+                "bbox": ("BOUNDING_BOX", {}),
+                "color": ("STRING", {"default": "red"}),
+                "width": ("INT", {"default": 3}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    RETURN_NAMES = ("image",)
+    DESCRIPTION = "Draw a bounding box on an image."
+    CATEGORY = "Refiners/Helpers"
+    FUNCTION = "process"
+
+    def process(
+        self,
+        image: torch.Tensor,
+        bbox: BoundingBox,
+        color: str,
+        width: int,
+    ) -> tuple[torch.Tensor]:
+        """Draw a bounding box on an image.
+
+        Args:
+            image: The image to draw on.
+            bbox: The bounding box to draw.
+            color: The color of the bounding box.
+            width: The width of the bounding box outline.
+
+        Returns:
+            The image with the bounding box drawn on it.
+        """
+        pil_image = tensor_to_image(image.permute(0, 3, 1, 2))
+        draw = ImageDraw.Draw(pil_image)
+        draw.rectangle(bbox, outline=color, width=width)
+        image = image_to_tensor(pil_image).permute(0, 2, 3, 1)
+        return (image,)
+
+
+def get_dtype(dtype: str) -> torch.dtype:
+    """Converts a string dtype to a torch.dtype.
+
+    See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype"""
+    match dtype:
+        case "float32" | "float":
+            return torch.float32
+        case "float64" | "double":
+            return torch.float64
+        case "complex64" | "cfloat":
+            return torch.complex64
+        case "complex128" | "cdouble":
+            return torch.complex128
+        case "float16" | "half":
+            return torch.float16
+        case "bfloat16":
+            return torch.bfloat16
+        case "uint8":
+            return torch.uint8
+        case "int8":
+            return torch.int8
+        case "int16" | "short":
+            return torch.int16
+        case "int32" | "int":
+            return torch.int32
+        case "int64" | "long":
+            return torch.int64
+        case "bool":
+            return torch.bool
+        case _:
+            raise ValueError(f"Unknown dtype: {dtype}")
+
+
+NODE_CLASS_MAPPINGS: dict[str, Any] = {
+    "DrawBoundingBox": DrawBoundingBox,
+}
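Usage note: chained together, these nodes form a text-driven segmentation pipeline: HfHubDownload fetches weights, LoadGroundingDino/GroundingDino turn a prompt into a bounding box, LoadBoxSegmenter/BoxSegmenter turn that box into a mask, and DrawBoundingBox serves as a debugging helper. Below is a minimal sketch that exercises the same chain outside ComfyUI. Assumptions, not pinned by this patch: the Hub repo IDs (`IDEA-Research/grounding-dino-tiny`, `finegrain/finegrain-box-segmenter`) and the input file `input.png` are illustrative stand-ins, the script runs from the repository root, and a CUDA device is available (use `"cpu"` otherwise).

```python
# Minimal sketch: drive the new node classes outside ComfyUI (assumptions above).
import importlib
import sys

import numpy as np
import torch
from PIL import Image

sys.path.append("src")
# the package directory name contains a hyphen, hence importlib instead of a plain import
nodes = importlib.import_module("comfyui-refiners")
m = nodes.NODE_CLASS_MAPPINGS

# download weights from the Hub (PATH outputs); empty filename means a full snapshot
(gd_path,) = m["HfHubDownload"]().download("IDEA-Research/grounding-dino-tiny", "", "main")
(bs_path,) = m["HfHubDownload"]().download("finegrain/finegrain-box-segmenter", "model.safetensors", "main")

# load the models (PROCESSOR and MODEL outputs)
processor, gd_model = m["LoadGroundingDino"]().load(str(gd_path), "float32", "cuda")
(segmenter,) = m["LoadBoxSegmenter"]().load(str(bs_path), 0.05, "cuda")

# ComfyUI images are (batch, height, width, channels) floats in [0, 1]
pil = Image.open("input.png").convert("RGB")
image = torch.from_numpy(np.array(pil)).float().div(255).unsqueeze(0)

# prompt -> BOUNDING_BOX -> MASK
(bbox,) = m["GroundingDino"]().process(processor, gd_model, image, "the cat", 0.25, 0.25)
(mask,) = m["BoxSegmenter"]().process(segmenter, image, bbox)
print(bbox, mask.shape)  # mask is (batch, height, width)
```

Inside ComfyUI the same chain is wired graphically, with the PATH, MODEL, BOUNDING_BOX, and MASK sockets matching each node's RETURN_TYPES.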