diff --git a/pyproject.toml b/pyproject.toml
index 02ac925..0818e80 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -92,6 +92,7 @@ dev-dependencies = [
"pytest>=8.0.0",
"coverage>=7.4.1",
"typos>=1.18.2",
+ "comfy-cli>=1.1.6",
]
diff --git a/src/comfyui-refiners/LICENSE b/src/comfyui-refiners/LICENSE
new file mode 120000
index 0000000..30cff74
--- /dev/null
+++ b/src/comfyui-refiners/LICENSE
@@ -0,0 +1 @@
+../../LICENSE
\ No newline at end of file
diff --git a/src/comfyui-refiners/README.md b/src/comfyui-refiners/README.md
new file mode 100644
index 0000000..371607e
--- /dev/null
+++ b/src/comfyui-refiners/README.md
@@ -0,0 +1,30 @@
+
+
+
+
+**The simplest way to train and run adapters on top of foundation models**
+
+[**Manifesto**](https://refine.rs/home/why/) |
+[**Docs**](https://refine.rs) |
+[**Guides**](https://refine.rs/guides/adapting_sdxl/) |
+[**Discussions**](https://github.com/finegrain-ai/refiners/discussions) |
+[**Discord**](https://discord.gg/mCmjNUVV7d)
+
+
+
+## Installation
+
+The nodes are published at https://registry.comfy.org/publishers/finegrain/nodes/comfyui-refiners.
+
+To easily install the nodes, run the following command:
+```bash
+comfy node registry-install comfyui-refiners
+```
+
+You may also download the nodes by cliking the "Download Latest" button and unzipping the content of the archive into you custom_nodes directory.
+
+See https://docs.comfy.org/registry/overview for more information.
diff --git a/src/comfyui-refiners/__init__.py b/src/comfyui-refiners/__init__.py
new file mode 100644
index 0000000..7e2c4c7
--- /dev/null
+++ b/src/comfyui-refiners/__init__.py
@@ -0,0 +1,15 @@
+from typing import Any
+
+from .box_segmenter import NODE_CLASS_MAPPINGS as box_segmenter_mappings
+from .grounding_dino import NODE_CLASS_MAPPINGS as grounding_dino_mappings
+from .huggingface import NODE_CLASS_MAPPINGS as huggingface_mappings
+from .utils import NODE_CLASS_MAPPINGS as utils_mappings
+
+NODE_CLASS_MAPPINGS: dict[str, Any] = {}
+NODE_CLASS_MAPPINGS.update(box_segmenter_mappings)
+NODE_CLASS_MAPPINGS.update(grounding_dino_mappings)
+NODE_CLASS_MAPPINGS.update(huggingface_mappings)
+NODE_CLASS_MAPPINGS.update(utils_mappings)
+
+NODE_DISPLAY_NAME_MAPPINGS = {k: v.__name__ for k, v in NODE_CLASS_MAPPINGS.items()}
+__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
diff --git a/src/comfyui-refiners/box_segmenter.py b/src/comfyui-refiners/box_segmenter.py
new file mode 100644
index 0000000..9dc37ce
--- /dev/null
+++ b/src/comfyui-refiners/box_segmenter.py
@@ -0,0 +1,105 @@
+from typing import Any
+
+import torch
+
+from refiners.fluxion.utils import image_to_tensor, no_grad, tensor_to_image
+from refiners.solutions import BoxSegmenter as _BoxSegmenter
+from refiners.solutions.box_segmenter import BoundingBox
+
+
+class LoadBoxSegmenter:
+ @classmethod
+ def INPUT_TYPES(cls) -> dict[str, Any]:
+ return {
+ "required": {
+ "checkpoint": ("PATH", {}),
+ "margin": (
+ "FLOAT",
+ {
+ "default": 0.05,
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ },
+ ),
+ "device": ("STRING", {"default": "cuda"}),
+ }
+ }
+
+ RETURN_TYPES = ("MODEL",)
+ RETURN_NAMES = ("model",)
+ DESCRIPTION = "Load a BoxSegmenter refiners model."
+ CATEGORY = "Refiners/Solutions"
+ FUNCTION = "load"
+
+ def load(
+ self,
+ checkpoint: str,
+ margin: float,
+ device: str,
+ ) -> tuple[_BoxSegmenter]:
+ """Load a BoxSegmenter refiners model.
+
+ Args:
+ checkpoint: The path to the checkpoint file.
+ margin: The bbox margin to use when processing images.
+ device: The torch device to load the model on.
+
+ Returns:
+ A BoxSegmenter model instance.
+ """
+ return (
+ _BoxSegmenter(
+ weights=checkpoint,
+ margin=margin,
+ device=device,
+ ),
+ )
+
+
+class BoxSegmenter:
+ @classmethod
+ def INPUT_TYPES(cls) -> dict[str, Any]:
+ return {
+ "required": {
+ "model": ("MODEL", {}),
+ "image": ("IMAGE", {}),
+ },
+ "optional": {
+ "bbox": ("BOUNDING_BOX", {}),
+ },
+ }
+
+ RETURN_TYPES = ("MASK",)
+ RETURN_NAMES = ("mask",)
+ DESCRIPTION = "Segment an image using a BoxSegmenter model and a bbox."
+ CATEGORY = "Refiners/Solutions"
+ FUNCTION = "process"
+
+ @no_grad()
+ def process(
+ self,
+ model: _BoxSegmenter,
+ image: torch.Tensor,
+ bbox: BoundingBox | None = None,
+ ) -> tuple[torch.Tensor]:
+ """Segment an image using a BoxSegmenter model and a bbox.
+
+ Args:
+ model: The BoxSegmenter model to use.
+ image: The input image to process.
+ bbox: Where in the image to apply the model.
+
+ Returns:
+ The mask of the segmented object.
+ """
+ pil_image = tensor_to_image(image.permute(0, 3, 1, 2))
+ mask = model(img=pil_image, box_prompt=bbox)
+ mask_tensor = image_to_tensor(mask).squeeze(1)
+ return (mask_tensor,)
+
+
+NODE_CLASS_MAPPINGS: dict[str, Any] = {
+ "BoxSegmenter": BoxSegmenter,
+ "LoadBoxSegmenter": LoadBoxSegmenter,
+}
diff --git a/src/comfyui-refiners/grounding_dino.py b/src/comfyui-refiners/grounding_dino.py
new file mode 100644
index 0000000..4675872
--- /dev/null
+++ b/src/comfyui-refiners/grounding_dino.py
@@ -0,0 +1,191 @@
+from typing import Any, Sequence
+
+import torch
+from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor # type: ignore
+
+from refiners.fluxion.utils import no_grad, tensor_to_image
+
+from .utils import BoundingBox, get_dtype
+
+
+class LoadGroundingDino:
+ @classmethod
+ def INPUT_TYPES(cls) -> dict[str, Any]:
+ return {
+ "required": {
+ "checkpoint": ("PATH", {}),
+ "dtype": (
+ "STRING",
+ {
+ "default": "float32",
+ },
+ ),
+ "device": (
+ "STRING",
+ {
+ "default": "cuda",
+ },
+ ),
+ }
+ }
+
+ RETURN_TYPES = ("PROCESSOR", "MODEL")
+ RETURN_NAMES = ("processor", "model")
+ DESCRIPTION = "Load a grounding dino model."
+ CATEGORY = "Refiners/Solutions"
+ FUNCTION = "load"
+
+ def load(
+ self,
+ checkpoint: str,
+ dtype: str,
+ device: str,
+ ) -> tuple[GroundingDinoProcessor, GroundingDinoForObjectDetection]:
+ """Load a grounding dino model.
+
+ Args:
+ checkpoint: The path to the checkpoint folder.
+ dtype: The torch data type to use.
+ device: The torch device to load the model on.
+
+ Returns:
+ The grounding dino processor and model instances.
+ """
+ processor = GroundingDinoProcessor.from_pretrained(checkpoint) # type: ignore
+ assert isinstance(processor, GroundingDinoProcessor)
+
+ model = GroundingDinoForObjectDetection.from_pretrained(checkpoint, torch_dtype=get_dtype(dtype)) # type: ignore
+ model = model.to(device=device) # type: ignore
+ assert isinstance(model, GroundingDinoForObjectDetection)
+
+ return (processor, model)
+
+
+# NOTE: not yet natively supported in Refiners, hence the transformers dependency
+class GroundingDino:
+ @classmethod
+ def INPUT_TYPES(cls) -> dict[str, Any]:
+ return {
+ "required": {
+ "processor": ("PROCESSOR", {}),
+ "model": ("MODEL", {}),
+ "image": ("IMAGE", {}),
+ "prompt": ("STRING", {}),
+ "box_threshold": (
+ "FLOAT",
+ {
+ "default": 0.25,
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ },
+ ),
+ "text_threshold": (
+ "FLOAT",
+ {
+ "default": 0.25,
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ },
+ ),
+ },
+ }
+
+ RETURN_TYPES = ("BOUNDING_BOX",)
+ RETURN_NAMES = ("bbox",)
+ DESCRIPTION = "Detect an object in an image using a GroundingDino model."
+ CATEGORY = "Refiners/Solutions"
+ FUNCTION = "process"
+
+ @staticmethod
+ def corners_to_pixels_format(
+ bboxes: torch.Tensor,
+ width: int,
+ height: int,
+ ) -> torch.Tensor:
+ x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
+ return torch.stack(
+ tensors=(
+ x1.clamp_(0, width),
+ y1.clamp_(0, height),
+ x2.clamp_(0, width),
+ y2.clamp_(0, height),
+ ),
+ dim=-1,
+ )
+
+ @staticmethod
+ def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
+ if not bboxes:
+ return None
+ for bbox in bboxes:
+ assert len(bbox) == 4
+ assert all(isinstance(x, int) for x in bbox)
+ return (
+ min(bbox[0] for bbox in bboxes),
+ min(bbox[1] for bbox in bboxes),
+ max(bbox[2] for bbox in bboxes),
+ max(bbox[3] for bbox in bboxes),
+ )
+
+ @no_grad()
+ def process(
+ self,
+ processor: GroundingDinoProcessor,
+ model: GroundingDinoForObjectDetection,
+ image: torch.Tensor,
+ prompt: str,
+ box_threshold: float,
+ text_threshold: float,
+ ) -> tuple[BoundingBox]:
+ """Detect an object in an image using a GroundingDino model and a text prompt.
+
+ Args:
+ processor: The image processor to use.
+ model: The grounding dino model to use.
+ image: The input image to detect in.
+ prompt: The text prompt of what to detect in the image.
+ box_threshold: The score threshold for the bounding boxes.
+ text_threshold: The score threshold for the text.
+
+ Returns:
+ The union of the bounding boxes found in the image.
+ """
+ # prepare the inputs
+ pil_image = tensor_to_image(image.permute(0, 3, 1, 2))
+
+ # NOTE: queries must be in lower cas + end with a dot. See:
+ # https://github.com/IDEA-Research/GroundingDINO/blob/856dde2/groundingdino/util/inference.py#L22-L26
+ inputs = processor(images=pil_image, text=f"{prompt.lower()}.", return_tensors="pt").to(device=model.device)
+
+ # get the model's prediction
+ outputs = model(**inputs)
+
+ # post-process the model's prediction
+ results: dict[str, Any] = processor.post_process_grounded_object_detection( # type: ignore
+ outputs=outputs,
+ input_ids=inputs["input_ids"],
+ target_sizes=[(pil_image.height, pil_image.width)],
+ box_threshold=box_threshold,
+ text_threshold=text_threshold,
+ )[0]
+
+ # retrieve the bounding boxes
+ assert "boxes" in results
+ bboxes = results["boxes"].cpu() # type: ignore
+ assert isinstance(bboxes, torch.Tensor)
+ assert bboxes.shape[0] != 0, "No bounding boxes found. Try adjusting the thresholds or pick another prompt."
+ bboxes = self.corners_to_pixels_format(bboxes, pil_image.width, pil_image.height) # type: ignore
+
+ # compute the union of the bounding boxes
+ bbox = self.bbox_union(bboxes.numpy().tolist()) # type: ignore
+ assert bbox is not None
+
+ return (bbox,)
+
+
+NODE_CLASS_MAPPINGS: dict[str, Any] = {
+ "GroundingDino": GroundingDino,
+ "LoadGroundingDino": LoadGroundingDino,
+}
diff --git a/src/comfyui-refiners/huggingface.py b/src/comfyui-refiners/huggingface.py
new file mode 100644
index 0000000..00624da
--- /dev/null
+++ b/src/comfyui-refiners/huggingface.py
@@ -0,0 +1,63 @@
+from pathlib import Path
+from typing import Any
+
+from huggingface_hub import hf_hub_download, snapshot_download # type: ignore
+
+
+class HfHubDownload:
+ @classmethod
+ def INPUT_TYPES(cls) -> dict[str, Any]:
+ return {
+ "required": {
+ "repo_id": ("STRING", {}),
+ },
+ "optional": {
+ "filename": ("STRING", {}),
+ "revision": (
+ "STRING",
+ {
+ "default": "main",
+ },
+ ),
+ },
+ }
+
+ RETURN_TYPES = ("PATH",)
+ RETURN_NAMES = ("path",)
+ DESCRIPTION = "Download file(s) from the HuggingFace Hub."
+ CATEGORY = "Refiners/HuggingFace"
+ FUNCTION = "download"
+
+ def download(
+ self,
+ repo_id: str,
+ filename: str,
+ revision: str,
+ ) -> tuple[Path]:
+ """Download file(s) from the HuggingFace Hub.
+
+ Args:
+ repo_id: The HuggingFace repository ID.
+ filename: The filename to download, if empty, the entire repository will be downloaded.
+ revision: The git revision to download.
+
+ Returns:
+ The path to the downloaded file(s).
+ """
+ if filename == "":
+ path = snapshot_download(
+ repo_id=repo_id,
+ revision=revision,
+ )
+ else:
+ path = hf_hub_download(
+ repo_id=repo_id,
+ filename=filename,
+ revision=revision,
+ )
+ return (Path(path),)
+
+
+NODE_CLASS_MAPPINGS: dict[str, Any] = {
+ "HfHubDownload": HfHubDownload,
+}
diff --git a/src/comfyui-refiners/pyproject.toml b/src/comfyui-refiners/pyproject.toml
new file mode 100644
index 0000000..64fa98e
--- /dev/null
+++ b/src/comfyui-refiners/pyproject.toml
@@ -0,0 +1,18 @@
+[project]
+name = "comfyui-refiners"
+description = "ComfyUI custom nodes for refiners models"
+version = "1.0.1"
+license = { file = "LICENSE" }
+dependencies = [
+ "refiners @ git+https://github.com/finegrain-ai/refiners.git",
+ "huggingface_hub",
+ "transformers",
+]
+
+[project.urls]
+Repository = "https://github.com/finegrain-ai/refiners"
+
+[tool.comfy]
+PublisherId = "finegrain"
+DisplayName = "refiners"
+Icon = "https://raw.githubusercontent.com/finegrain-ai/refiners/main/assets/dropy.png"
diff --git a/src/comfyui-refiners/utils.py b/src/comfyui-refiners/utils.py
new file mode 100644
index 0000000..9595376
--- /dev/null
+++ b/src/comfyui-refiners/utils.py
@@ -0,0 +1,86 @@
+from typing import Any
+
+import torch
+from PIL import ImageDraw
+
+from refiners.fluxion.utils import image_to_tensor, tensor_to_image
+
+BoundingBox = tuple[int, int, int, int]
+
+
+class DrawBoundingBox:
+ @classmethod
+ def INPUT_TYPES(cls) -> dict[str, Any]:
+ return {
+ "required": {
+ "image": ("IMAGE", {}),
+ "bbox": ("BOUNDING_BOX", {}),
+ "color": ("STRING", {"default": "red"}),
+ "width": ("INT", {"default": 3}),
+ },
+ }
+
+ RETURN_TYPES = ("IMAGE",)
+ RETURN_NAMES = ("image",)
+ DESCRIPTION = "Draw a bounding box on an image."
+ CATEGORY = "Refiners/Helpers"
+ FUNCTION = "process"
+
+ def process(
+ self,
+ image: torch.Tensor,
+ bbox: BoundingBox,
+ color: str,
+ width: int,
+ ) -> tuple[torch.Tensor]:
+ """Draw a bounding box on an image.
+
+ Args:
+ image: The image to draw on.
+ bbox: The bounding box to draw.
+ color: The color of the bounding box.
+ width: The width of the bounding box.
+ """
+ pil_image = tensor_to_image(image.permute(0, 3, 1, 2))
+ draw = ImageDraw.Draw(pil_image)
+ draw.rectangle(bbox, outline=color, width=width)
+ image = image_to_tensor(pil_image).permute(0, 2, 3, 1)
+ return (image,)
+
+
+def get_dtype(dtype: str) -> torch.dtype:
+ """Converts a string dtype to a torch.dtype.
+
+ See also https://pytorch.org/docs/stable/tensor_attributes.html#torch-dtype"""
+ match dtype:
+ case "float32" | "float":
+ return torch.float32
+ case "float64" | "double":
+ return torch.float64
+ case "complex64" | "cfloat":
+ return torch.complex64
+ case "complex128" | "cdouble":
+ return torch.complex128
+ case "float16" | "half":
+ return torch.float16
+ case "bfloat16":
+ return torch.bfloat16
+ case "uint8":
+ return torch.uint8
+ case "int8":
+ return torch.int8
+ case "int16" | "short":
+ return torch.int16
+ case "int32" | "int":
+ return torch.int32
+ case "int64" | "long":
+ return torch.int64
+ case "bool":
+ return torch.bool
+ case _:
+ raise ValueError(f"Unknown dtype: {dtype}")
+
+
+NODE_CLASS_MAPPINGS: dict[str, Any] = {
+ "DrawBoundingBox": DrawBoundingBox,
+}