commit 8eacdfda7bc8f8a6ad4963e612f785caf7edd7a6 Author: Laurent Date: Sat Jul 6 16:25:30 2024 +0000 initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1f51344 --- /dev/null +++ b/.gitignore @@ -0,0 +1,163 @@ +# https://github.com/github/gitignore/blob/main/Python.gitignore +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..171a6a9 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.12.1 diff --git a/README.md b/README.md new file mode 100644 index 0000000..0ad0491 --- /dev/null +++ b/README.md @@ -0,0 +1,20 @@ +# Refiners Model Explorer + +A Model Explorer extension to visualize and explore Refiners-based model. + +## Installation + +```bash +pip install git+https://github.com/finegrain-ai/model_explorer_refiners.git +``` + +## Usage + +```python +from refiners.foundationals.dinov2 import DINOv2_small_reg +from model_explorer_refiners import RefinersAdapter + +model = DINOv2_small_reg() +RefinersAdapter.visualize(model) + +``` diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..55fb788 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,60 @@ +[project] +version = "0.1.0" +readme = "README.md" +requires-python = ">= 3.9" +name = "model-explorer-refiners" +description = "Refiners adapter for Model Explorer" +authors = [{ name = "Laurent", email = "laurent@lagon.tech" }] +dependencies = [ + "ai-edge-model-explorer", + "refiners @ git+https://github.com/finegrain-ai/refiners", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.rye] +managed = true +dev-dependencies = [] + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["src/model_explorer_refiners"] + +[tool.ruff] +src = ["src"] +line-length = 120 + +[tool.ruff.lint] +select = [ + "A", # flake8-builtins + "B", # flake8-bugbear + "C90", # mccabe + "COM", # flake8-commas + "EM", # flake8-errmsg + "E", # pycodestyle errors + "F", # Pyflakes + "G", # flake8-logging-format + "I", # isort + "N", # pep8-naming + "PIE", # flake8-pie + "PTH", # flake8-use-pathlib + "TD", # flake8-todo + "FIX", # flake8-fixme + "RUF", # ruff + "S", # flake8-bandit + "TCH", # flake8-type-checking + "TID", # flake8-tidy-imports + "UP", # pyupgrade + "W", # pycodestyle warnings +] + + +[tool.pyright] +include = ["model_explorer_refiners"] +reportMissingTypeStubs = false +pythonVersion = "3.12" +pythonPlatform = "Linux" diff --git a/requirements-dev.lock b/requirements-dev.lock new file mode 100644 index 0000000..c7bf8e5 --- /dev/null +++ b/requirements-dev.lock @@ -0,0 +1,140 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: true +# with-sources: false +# generate-hashes: false + +-e file:. +ai-edge-model-explorer==0.1.7 + # via model-explorer-refiners +ai-edge-model-explorer-adapter==0.1.3 + # via ai-edge-model-explorer +asttokens==2.4.1 + # via stack-data +blinker==1.8.2 + # via flask +certifi==2024.6.2 + # via requests +charset-normalizer==3.3.2 + # via requests +click==8.1.7 + # via flask +decorator==5.1.1 + # via ipython +executing==2.0.1 + # via stack-data +filelock==3.15.4 + # via torch +flask==3.0.3 + # via ai-edge-model-explorer +fsspec==2024.6.1 + # via torch +idna==3.7 + # via requests +ipython==8.26.0 + # via ai-edge-model-explorer +itsdangerous==2.2.0 + # via flask +jaxtyping==0.2.31 + # via refiners +jedi==0.19.1 + # via ipython +jinja2==3.1.4 + # via flask + # via torch +markupsafe==2.1.5 + # via jinja2 + # via werkzeug +matplotlib-inline==0.1.7 + # via ipython +mpmath==1.3.0 + # via sympy +networkx==3.3 + # via torch +numpy==2.0.0 + # via ai-edge-model-explorer + # via refiners +nvidia-cublas-cu12==12.1.3.1 + # via nvidia-cudnn-cu12 + # via nvidia-cusolver-cu12 + # via torch +nvidia-cuda-cupti-cu12==12.1.105 + # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 + # via torch +nvidia-cuda-runtime-cu12==12.1.105 + # via torch +nvidia-cudnn-cu12==8.9.2.26 + # via torch +nvidia-cufft-cu12==11.0.2.54 + # via torch +nvidia-curand-cu12==10.3.2.106 + # via torch +nvidia-cusolver-cu12==11.4.5.107 + # via torch +nvidia-cusparse-cu12==12.1.0.106 + # via nvidia-cusolver-cu12 + # via torch +nvidia-nccl-cu12==2.20.5 + # via torch +nvidia-nvjitlink-cu12==12.5.40 + # via nvidia-cusolver-cu12 + # via nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 + # via torch +packaging==24.1 + # via ai-edge-model-explorer + # via refiners +parso==0.8.4 + # via jedi +pexpect==4.9.0 + # via ipython +pillow==10.3.0 + # via refiners +portpicker==1.6.0 + # via ai-edge-model-explorer +prompt-toolkit==3.0.47 + # via ipython +psutil==6.0.0 + # via portpicker +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.2 + # via stack-data +pygments==2.18.0 + # via ipython +refiners @ git+https://github.com/finegrain-ai/refiners@e091788b885c5024b82e18cef082a922cd050481 + # via model-explorer-refiners +requests==2.32.3 + # via ai-edge-model-explorer +safetensors==0.4.3 + # via refiners +six==1.16.0 + # via asttokens +stack-data==0.6.3 + # via ipython +sympy==1.12.1 + # via torch +termcolor==2.4.0 + # via ai-edge-model-explorer +torch==2.3.1 + # via ai-edge-model-explorer + # via refiners +traitlets==5.14.3 + # via ipython + # via matplotlib-inline +typeguard==2.13.3 + # via jaxtyping +typing-extensions==4.12.2 + # via ai-edge-model-explorer + # via torch +urllib3==2.2.2 + # via requests +wcwidth==0.2.13 + # via prompt-toolkit +werkzeug==3.0.3 + # via flask diff --git a/requirements.lock b/requirements.lock new file mode 100644 index 0000000..c7bf8e5 --- /dev/null +++ b/requirements.lock @@ -0,0 +1,140 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: true +# with-sources: false +# generate-hashes: false + +-e file:. +ai-edge-model-explorer==0.1.7 + # via model-explorer-refiners +ai-edge-model-explorer-adapter==0.1.3 + # via ai-edge-model-explorer +asttokens==2.4.1 + # via stack-data +blinker==1.8.2 + # via flask +certifi==2024.6.2 + # via requests +charset-normalizer==3.3.2 + # via requests +click==8.1.7 + # via flask +decorator==5.1.1 + # via ipython +executing==2.0.1 + # via stack-data +filelock==3.15.4 + # via torch +flask==3.0.3 + # via ai-edge-model-explorer +fsspec==2024.6.1 + # via torch +idna==3.7 + # via requests +ipython==8.26.0 + # via ai-edge-model-explorer +itsdangerous==2.2.0 + # via flask +jaxtyping==0.2.31 + # via refiners +jedi==0.19.1 + # via ipython +jinja2==3.1.4 + # via flask + # via torch +markupsafe==2.1.5 + # via jinja2 + # via werkzeug +matplotlib-inline==0.1.7 + # via ipython +mpmath==1.3.0 + # via sympy +networkx==3.3 + # via torch +numpy==2.0.0 + # via ai-edge-model-explorer + # via refiners +nvidia-cublas-cu12==12.1.3.1 + # via nvidia-cudnn-cu12 + # via nvidia-cusolver-cu12 + # via torch +nvidia-cuda-cupti-cu12==12.1.105 + # via torch +nvidia-cuda-nvrtc-cu12==12.1.105 + # via torch +nvidia-cuda-runtime-cu12==12.1.105 + # via torch +nvidia-cudnn-cu12==8.9.2.26 + # via torch +nvidia-cufft-cu12==11.0.2.54 + # via torch +nvidia-curand-cu12==10.3.2.106 + # via torch +nvidia-cusolver-cu12==11.4.5.107 + # via torch +nvidia-cusparse-cu12==12.1.0.106 + # via nvidia-cusolver-cu12 + # via torch +nvidia-nccl-cu12==2.20.5 + # via torch +nvidia-nvjitlink-cu12==12.5.40 + # via nvidia-cusolver-cu12 + # via nvidia-cusparse-cu12 +nvidia-nvtx-cu12==12.1.105 + # via torch +packaging==24.1 + # via ai-edge-model-explorer + # via refiners +parso==0.8.4 + # via jedi +pexpect==4.9.0 + # via ipython +pillow==10.3.0 + # via refiners +portpicker==1.6.0 + # via ai-edge-model-explorer +prompt-toolkit==3.0.47 + # via ipython +psutil==6.0.0 + # via portpicker +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.2 + # via stack-data +pygments==2.18.0 + # via ipython +refiners @ git+https://github.com/finegrain-ai/refiners@e091788b885c5024b82e18cef082a922cd050481 + # via model-explorer-refiners +requests==2.32.3 + # via ai-edge-model-explorer +safetensors==0.4.3 + # via refiners +six==1.16.0 + # via asttokens +stack-data==0.6.3 + # via ipython +sympy==1.12.1 + # via torch +termcolor==2.4.0 + # via ai-edge-model-explorer +torch==2.3.1 + # via ai-edge-model-explorer + # via refiners +traitlets==5.14.3 + # via ipython + # via matplotlib-inline +typeguard==2.13.3 + # via jaxtyping +typing-extensions==4.12.2 + # via ai-edge-model-explorer + # via torch +urllib3==2.2.2 + # via requests +wcwidth==0.2.13 + # via prompt-toolkit +werkzeug==3.0.3 + # via flask diff --git a/src/model_explorer_refiners/__init__.py b/src/model_explorer_refiners/__init__.py new file mode 100644 index 0000000..0fdae92 --- /dev/null +++ b/src/model_explorer_refiners/__init__.py @@ -0,0 +1,3 @@ +from model_explorer_refiners.main import RefinersAdapter + +__all__ = ["RefinersAdapter"] diff --git a/src/model_explorer_refiners/convert.py b/src/model_explorer_refiners/convert.py new file mode 100644 index 0000000..30d535f --- /dev/null +++ b/src/model_explorer_refiners/convert.py @@ -0,0 +1,298 @@ +import refiners.fluxion.layers as fl +from model_explorer import ModelExplorerGraphs +from model_explorer.graph_builder import Graph, GraphNode, IncomingEdge, KeyValue + +from model_explorer_refiners.utils import find_node + + +def convert_chain( + chain: fl.Chain, + input_nodes: list[GraphNode], +) -> tuple[list[GraphNode], list[GraphNode]]: + """Convert a refiners chain layer to nodes. + + Args: + chain: the chain layer to convert. + input_nodes: the input nodes of the chain. + """ + nodes: list[GraphNode] = [] + previous_nodes = input_nodes + for layer, parent_layer in chain.walk(recurse=False): + module_nodes, previous_nodes = convert_module(layer, parent_layer, previous_nodes) + nodes.extend(module_nodes) + + return nodes, previous_nodes + + +def convert_passthrough( + passthrough: fl.Passthrough, + input_nodes: list[GraphNode], +) -> tuple[list[GraphNode], list[GraphNode]]: + """Convert a refiners passthrough layer to nodes. + + Args: + passthrough: the passthrough layer to convert. + input_nodes: the input nodes of the passthrough. + """ + nodes, _ = convert_chain(passthrough, input_nodes) + return nodes, input_nodes + + +def convert_residual( + residual: fl.Residual, + input_nodes: list[GraphNode], +) -> tuple[list[GraphNode], list[GraphNode]]: + """Convert a refiners residual layer to nodes. + + Args: + residual: the residual layer to convert. + input_nodes: the input nodes of the residual. + """ + nodes: list[GraphNode] = [] + previous_nodes = input_nodes + for layer, parent_layer in residual.walk(recurse=False): + module_nodes, previous_nodes = convert_module(layer, parent_layer, previous_nodes) + nodes.extend(module_nodes) + + # add a summation node + summation_node = GraphNode( + namespace=residual.get_path().replace(".", "/"), + id=residual.get_path(), + label="+", + ) + for node in previous_nodes: + summation_node.incomingEdges.append( + IncomingEdge( + sourceNodeId=node.id, + ), + ) + for input_node in input_nodes: + summation_node.incomingEdges.append( + IncomingEdge( + sourceNodeId=input_node.id, + ), + ) + nodes.append(summation_node) + + return nodes, [summation_node] + + +def convert_parallel( + parallel: fl.Parallel, + input_nodes: list[GraphNode], +) -> tuple[list[GraphNode], list[GraphNode]]: + """Convert a refiners parallel layer to nodes. + + Args: + parallel: the parallel layer to convert. + input_nodes: the input nodes of the parallel layer. + """ + nodes: list[GraphNode] = [] + output_nodes: list[GraphNode] = [] + for layer, parent_layer in parallel.walk(recurse=False): + module_nodes, module_output_nodes = convert_module(layer, parent_layer, input_nodes) + nodes.extend(module_nodes) + output_nodes.extend(module_output_nodes) + + return nodes, output_nodes + + +def convert_concatenate( + concatenate: fl.Concatenate, + input_nodes: list[GraphNode], +) -> tuple[list[GraphNode], list[GraphNode]]: + """Convert a refiners concatenate layer to nodes. + + Args: + concatenate: the concatenate layer to convert. + input_nodes: the input nodes of the concatenate layer. + """ + nodes: list[GraphNode] = [] + output_nodes: list[GraphNode] = [] + for layer, parent_layer in concatenate.walk(recurse=False): + module_nodes, module_output_nodes = convert_module(layer, parent_layer, input_nodes) + nodes.extend(module_nodes) + output_nodes.extend(module_output_nodes) + + # add a concatenation node + concat_node = GraphNode( + namespace=concatenate.get_path().replace(".", "/"), + id=concatenate.get_path(), + label="+", + ) + for node in output_nodes: + concat_node.incomingEdges.append( + IncomingEdge( + sourceNodeId=node.id, + ), + ) + nodes.append(concat_node) + + return nodes, [concat_node] + + +def convert_distribute( + distribute: fl.Distribute, + input_nodes: list[GraphNode], +) -> tuple[list[GraphNode], list[GraphNode]]: + """Convert a refiners distribute layer to nodes. + + Args: + distribute: the distribute layer to convert. + input_nodes: the input nodes of the distribute layer. + """ + nodes: list[GraphNode] = [] + output_nodes: list[GraphNode] = [] + for (layer, parent_layer), input_node in zip(distribute.walk(recurse=False), input_nodes, strict=True): + module_nodes, previous_nodes = convert_module(layer, parent_layer, [input_node]) + nodes.extend(module_nodes) + output_nodes.extend(previous_nodes) + + return nodes, output_nodes + + +def convert_other( + module: fl.Module, + parent_module: fl.Chain | None, + input_nodes: list[GraphNode], +) -> tuple[list[GraphNode], list[GraphNode]]: + """Convert a refiners module to nodes. + + Args: + module: the module to convert. + parent_module: the parent module of the module. + input_nodes: the input nodes of the module. + """ + node = GraphNode( + namespace=parent_module.get_path().replace(".", "/") if parent_module else "", + id=module.get_path(parent_module), + label=module._get_name(), # type: ignore + attrs=[ + KeyValue( + key=key, + value=str(value), + ) + for key, value in module.basic_attributes().items() + ], + ) + setattr(node, "refiners_module", module) # noqa: B010 + for input_node in input_nodes: + node.incomingEdges.append( + IncomingEdge( + sourceNodeId=input_node.id, + ), + ) + return [node], [node] + + +def convert_module( # type: ignore + module: fl.Module, + parent_module: fl.Chain | None, + input_nodes: list[GraphNode], +) -> tuple[list[GraphNode], list[GraphNode]]: + """Convert a refiners module to nodes. + + Args: + module: the module to convert. + parent_module: the parent module of the module. + input_nodes: the input nodes of the module. + """ + match module: + case fl.Parallel(): + return convert_parallel(module, input_nodes) + case fl.Distribute(): + return convert_distribute(module, input_nodes) + case fl.Concatenate(): + return convert_concatenate(module, input_nodes) + case fl.Residual(): + return convert_residual(module, input_nodes) + case fl.Passthrough(): + return convert_passthrough(module, input_nodes) + case fl.Chain(): + return convert_chain(module, input_nodes) + case fl.UseContext(): + return convert_other(module, parent_module=parent_module, input_nodes=[]) + case fl.Parameter(): + return convert_other(module, parent_module=parent_module, input_nodes=[]) + case _: + return convert_other(module, parent_module, input_nodes) + + +def convert_model(model: fl.Chain) -> ModelExplorerGraphs: + """Convert a refiners model to a Graph. + + Args: + model: the model to convert. + """ + # initialize the graph, with an input and output node + graph = Graph(id=model._get_name()) # type: ignore + input_node = GraphNode(id="input", label="input") + graph.nodes.append(input_node) + output_node = GraphNode(id="output", label="output") + graph.nodes.append(output_node) + + # convert the model + model_nodes, model_output_nodes = convert_module( + module=model, + parent_module=None, + input_nodes=[input_node], + ) + graph.nodes.extend(model_nodes) + + # connect the model's output to the output node + for node in model_output_nodes: + output_node.incomingEdges.append( + IncomingEdge( + sourceNodeId=node.id, + ), + ) + + # connect the model's context nodes + for node in graph.nodes: + refiners_module = getattr(node, "refiners_module", None) + if isinstance(refiners_module, fl.SetContext): + # find the context node + context_node = find_node( + graph=graph, + node_id=f"context.{refiners_module.context}.{refiners_module.key}", + ) + + # if it doesn't exist, create it + if context_node is None: + context_node = GraphNode( + namespace=f"context/{refiners_module.context}/{refiners_module.key}", + id=f"context.{refiners_module.context}.{refiners_module.key}", + label=refiners_module.key, + ) + + # connect the nodes: node -> context_node + context_node.incomingEdges.append( + IncomingEdge( + sourceNodeId=node.id, + ), + ) + graph.nodes.append(context_node) + if isinstance(refiners_module, fl.UseContext): + # find the context node + context_node = find_node( + graph=graph, + node_id=f"context.{refiners_module.context}.{refiners_module.key}", + ) + + # if it doesn't exist, create it + if context_node is None: + context_node = GraphNode( + namespace=f"context/{refiners_module.context}/{refiners_module.key}", + id=f"context.{refiners_module.context}.{refiners_module.key}", + label=refiners_module.key, + ) + graph.nodes.append(context_node) + + # connect the nodes: context_node -> node + node.incomingEdges.append( + IncomingEdge( + sourceNodeId=f"context.{refiners_module.context}.{refiners_module.key}", + ), + ) + + return ModelExplorerGraphs(graphs=[graph]) diff --git a/src/model_explorer_refiners/main.py b/src/model_explorer_refiners/main.py new file mode 100644 index 0000000..7268778 --- /dev/null +++ b/src/model_explorer_refiners/main.py @@ -0,0 +1,44 @@ +import refiners.fluxion.layers as fl +from model_explorer import Adapter, AdapterMetadata, server +from model_explorer.config import ModelExplorerConfig + +from model_explorer_refiners.convert import convert_model + + +class RefinersAdapter(Adapter): + metadata = AdapterMetadata( + id="refiners", + name="Refiners adapter", + description="Refiners adapter for Model Explorer", + source_repo="https://github.com/finegrain/model-explorer-refiners", + ) + + def __init__(self): + super().__init__() + + @staticmethod + def visualize( + model: fl.Chain, + host: str = "localhost", + port: int = 8080, + ) -> None: + """Visualize a refiners model in Model Explorer. + + Args: + model: the refiners model to visualize. + host: the host of the server. + port: the port of the server. + """ + # construct config + config = ModelExplorerConfig() + graph = convert_model(model) + config.graphs_list.append(graph) + index = len(config.graphs_list) - 1 + config.model_sources.append({"url": f"graphs://refiners/{index}"}) + + # start server + server.start( + config=config, + host=host, + port=port, + ) diff --git a/src/model_explorer_refiners/utils.py b/src/model_explorer_refiners/utils.py new file mode 100644 index 0000000..b93ccff --- /dev/null +++ b/src/model_explorer_refiners/utils.py @@ -0,0 +1,17 @@ +from model_explorer.graph_builder import Graph, GraphNode + + +def find_node(graph: Graph, node_id: str) -> GraphNode | None: + """Find a node in a graph. + + Args: + graph: the graph to search. + node_id: the id of the node to find. + + Returns: + The node node which matches the node_id, if found, otherwise None. + """ + return next( + (context_node for context_node in graph.nodes if context_node.id == node_id), + None, + )