initial commit

This commit is contained in:
Laurent 2024-07-06 16:25:30 +00:00
commit 8eacdfda7b
No known key found for this signature in database
10 changed files with 886 additions and 0 deletions

163
.gitignore vendored Normal file
View file

@ -0,0 +1,163 @@
# https://github.com/github/gitignore/blob/main/Python.gitignore
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

1
.python-version Normal file
View file

@ -0,0 +1 @@
3.12.1

20
README.md Normal file
View file

@ -0,0 +1,20 @@
# Refiners Model Explorer
A Model Explorer extension to visualize and explore Refiners-based models.
## Installation
```bash
pip install git+https://github.com/finegrain-ai/model_explorer_refiners.git
```
## Usage
```python
from refiners.foundationals.dinov2 import DINOv2_small_reg
from model_explorer_refiners import RefinersAdapter
model = DINOv2_small_reg()
RefinersAdapter.visualize(model)
```

60
pyproject.toml Normal file
View file

@ -0,0 +1,60 @@
[project]
version = "0.1.0"
readme = "README.md"
requires-python = ">= 3.10"
name = "model-explorer-refiners"
description = "Refiners adapter for Model Explorer"
authors = [{ name = "Laurent", email = "laurent@lagon.tech" }]
dependencies = [
"ai-edge-model-explorer",
"refiners @ git+https://github.com/finegrain-ai/refiners",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.rye]
managed = true
dev-dependencies = []
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["src/model_explorer_refiners"]
[tool.ruff]
src = ["src"]
line-length = 120
[tool.ruff.lint]
select = [
"A", # flake8-builtins
"B", # flake8-bugbear
"C90", # mccabe
"COM", # flake8-commas
"EM", # flake8-errmsg
"E", # pycodestyle errors
"F", # Pyflakes
"G", # flake8-logging-format
"I", # isort
"N", # pep8-naming
"PIE", # flake8-pie
"PTH", # flake8-use-pathlib
"TD", # flake8-todo
"FIX", # flake8-fixme
"RUF", # ruff
"S", # flake8-bandit
"TCH", # flake8-type-checking
"TID", # flake8-tidy-imports
"UP", # pyupgrade
"W", # pycodestyle warnings
]
[tool.pyright]
include = ["model_explorer_refiners"]
reportMissingTypeStubs = false
pythonVersion = "3.12"
pythonPlatform = "Linux"

140
requirements-dev.lock Normal file
View file

@ -0,0 +1,140 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: []
# all-features: true
# with-sources: false
# generate-hashes: false
-e file:.
ai-edge-model-explorer==0.1.7
# via model-explorer-refiners
ai-edge-model-explorer-adapter==0.1.3
# via ai-edge-model-explorer
asttokens==2.4.1
# via stack-data
blinker==1.8.2
# via flask
certifi==2024.6.2
# via requests
charset-normalizer==3.3.2
# via requests
click==8.1.7
# via flask
decorator==5.1.1
# via ipython
executing==2.0.1
# via stack-data
filelock==3.15.4
# via torch
flask==3.0.3
# via ai-edge-model-explorer
fsspec==2024.6.1
# via torch
idna==3.7
# via requests
ipython==8.26.0
# via ai-edge-model-explorer
itsdangerous==2.2.0
# via flask
jaxtyping==0.2.31
# via refiners
jedi==0.19.1
# via ipython
jinja2==3.1.4
# via flask
# via torch
markupsafe==2.1.5
# via jinja2
# via werkzeug
matplotlib-inline==0.1.7
# via ipython
mpmath==1.3.0
# via sympy
networkx==3.3
# via torch
numpy==2.0.0
# via ai-edge-model-explorer
# via refiners
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.20.5
# via torch
nvidia-nvjitlink-cu12==12.5.40
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
packaging==24.1
# via ai-edge-model-explorer
# via refiners
parso==0.8.4
# via jedi
pexpect==4.9.0
# via ipython
pillow==10.3.0
# via refiners
portpicker==1.6.0
# via ai-edge-model-explorer
prompt-toolkit==3.0.47
# via ipython
psutil==6.0.0
# via portpicker
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.2
# via stack-data
pygments==2.18.0
# via ipython
refiners @ git+https://github.com/finegrain-ai/refiners@e091788b885c5024b82e18cef082a922cd050481
# via model-explorer-refiners
requests==2.32.3
# via ai-edge-model-explorer
safetensors==0.4.3
# via refiners
six==1.16.0
# via asttokens
stack-data==0.6.3
# via ipython
sympy==1.12.1
# via torch
termcolor==2.4.0
# via ai-edge-model-explorer
torch==2.3.1
# via ai-edge-model-explorer
# via refiners
traitlets==5.14.3
# via ipython
# via matplotlib-inline
typeguard==2.13.3
# via jaxtyping
typing-extensions==4.12.2
# via ai-edge-model-explorer
# via torch
urllib3==2.2.2
# via requests
wcwidth==0.2.13
# via prompt-toolkit
werkzeug==3.0.3
# via flask

140
requirements.lock Normal file
View file

@ -0,0 +1,140 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: []
# all-features: true
# with-sources: false
# generate-hashes: false
-e file:.
ai-edge-model-explorer==0.1.7
# via model-explorer-refiners
ai-edge-model-explorer-adapter==0.1.3
# via ai-edge-model-explorer
asttokens==2.4.1
# via stack-data
blinker==1.8.2
# via flask
certifi==2024.6.2
# via requests
charset-normalizer==3.3.2
# via requests
click==8.1.7
# via flask
decorator==5.1.1
# via ipython
executing==2.0.1
# via stack-data
filelock==3.15.4
# via torch
flask==3.0.3
# via ai-edge-model-explorer
fsspec==2024.6.1
# via torch
idna==3.7
# via requests
ipython==8.26.0
# via ai-edge-model-explorer
itsdangerous==2.2.0
# via flask
jaxtyping==0.2.31
# via refiners
jedi==0.19.1
# via ipython
jinja2==3.1.4
# via flask
# via torch
markupsafe==2.1.5
# via jinja2
# via werkzeug
matplotlib-inline==0.1.7
# via ipython
mpmath==1.3.0
# via sympy
networkx==3.3
# via torch
numpy==2.0.0
# via ai-edge-model-explorer
# via refiners
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.20.5
# via torch
nvidia-nvjitlink-cu12==12.5.40
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
packaging==24.1
# via ai-edge-model-explorer
# via refiners
parso==0.8.4
# via jedi
pexpect==4.9.0
# via ipython
pillow==10.3.0
# via refiners
portpicker==1.6.0
# via ai-edge-model-explorer
prompt-toolkit==3.0.47
# via ipython
psutil==6.0.0
# via portpicker
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.2
# via stack-data
pygments==2.18.0
# via ipython
refiners @ git+https://github.com/finegrain-ai/refiners@e091788b885c5024b82e18cef082a922cd050481
# via model-explorer-refiners
requests==2.32.3
# via ai-edge-model-explorer
safetensors==0.4.3
# via refiners
six==1.16.0
# via asttokens
stack-data==0.6.3
# via ipython
sympy==1.12.1
# via torch
termcolor==2.4.0
# via ai-edge-model-explorer
torch==2.3.1
# via ai-edge-model-explorer
# via refiners
traitlets==5.14.3
# via ipython
# via matplotlib-inline
typeguard==2.13.3
# via jaxtyping
typing-extensions==4.12.2
# via ai-edge-model-explorer
# via torch
urllib3==2.2.2
# via requests
wcwidth==0.2.13
# via prompt-toolkit
werkzeug==3.0.3
# via flask

View file

@ -0,0 +1,3 @@
# Re-export the adapter at package level so users can write
# `from model_explorer_refiners import RefinersAdapter`.
from model_explorer_refiners.main import RefinersAdapter
__all__ = ["RefinersAdapter"]

View file

@ -0,0 +1,298 @@
import refiners.fluxion.layers as fl
from model_explorer import ModelExplorerGraphs
from model_explorer.graph_builder import Graph, GraphNode, IncomingEdge, KeyValue
from model_explorer_refiners.utils import find_node
def convert_chain(
    chain: fl.Chain,
    input_nodes: list[GraphNode],
) -> tuple[list[GraphNode], list[GraphNode]]:
    """Convert a refiners chain layer to nodes.

    Children are converted sequentially: each child's outputs become the
    inputs of the next one.

    Args:
        chain: the chain layer to convert.
        input_nodes: the input nodes of the chain.

    Returns:
        All created nodes, and the outputs of the last child.
    """
    collected: list[GraphNode] = []
    current_outputs = input_nodes
    for child, parent in chain.walk(recurse=False):
        child_nodes, current_outputs = convert_module(child, parent, current_outputs)
        collected += child_nodes
    return collected, current_outputs
def convert_passthrough(
    passthrough: fl.Passthrough,
    input_nodes: list[GraphNode],
) -> tuple[list[GraphNode], list[GraphNode]]:
    """Convert a refiners passthrough layer to nodes.

    Args:
        passthrough: the passthrough layer to convert.
        input_nodes: the input nodes of the passthrough.

    Returns:
        The nodes of the inner chain, and the original inputs — a passthrough
        forwards its inputs unchanged, so its inner outputs are discarded.
    """
    inner_nodes, _ignored_outputs = convert_chain(passthrough, input_nodes)
    return inner_nodes, input_nodes
def convert_residual(
    residual: fl.Residual,
    input_nodes: list[GraphNode],
) -> tuple[list[GraphNode], list[GraphNode]]:
    """Convert a refiners residual layer to nodes.

    The wrapped sub-chain is converted like a regular chain, then a summation
    node ("+") adds the sub-chain's outputs back to the residual's inputs.

    Args:
        residual: the residual layer to convert.
        input_nodes: the input nodes of the residual.

    Returns:
        All created nodes, and the summation node as the single output.
    """
    # A Residual is a Chain: reuse convert_chain instead of duplicating its walk loop.
    nodes, previous_nodes = convert_chain(residual, input_nodes)

    # add a summation node fed by both the sub-chain outputs and the original inputs
    summation_node = GraphNode(
        namespace=residual.get_path().replace(".", "/"),
        id=residual.get_path(),
        label="+",
    )
    for source in (*previous_nodes, *input_nodes):
        summation_node.incomingEdges.append(
            IncomingEdge(
                sourceNodeId=source.id,
            ),
        )
    nodes.append(summation_node)
    return nodes, [summation_node]
def convert_parallel(
    parallel: fl.Parallel,
    input_nodes: list[GraphNode],
) -> tuple[list[GraphNode], list[GraphNode]]:
    """Convert a refiners parallel layer to nodes.

    Every child receives the same input nodes; the outputs of all children
    are gathered as the parallel layer's outputs.

    Args:
        parallel: the parallel layer to convert.
        input_nodes: the input nodes of the parallel layer.
    """
    all_nodes: list[GraphNode] = []
    gathered_outputs: list[GraphNode] = []
    for child, parent in parallel.walk(recurse=False):
        child_nodes, child_outputs = convert_module(child, parent, input_nodes)
        all_nodes += child_nodes
        gathered_outputs += child_outputs
    return all_nodes, gathered_outputs
def convert_concatenate(
    concatenate: fl.Concatenate,
    input_nodes: list[GraphNode],
) -> tuple[list[GraphNode], list[GraphNode]]:
    """Convert a refiners concatenate layer to nodes.

    Every child receives the same input nodes; a trailing concatenation node
    merges all the children's outputs.

    Args:
        concatenate: the concatenate layer to convert.
        input_nodes: the input nodes of the concatenate layer.

    Returns:
        All created nodes, and the concatenation node as the single output.
    """
    nodes: list[GraphNode] = []
    output_nodes: list[GraphNode] = []
    for layer, parent_layer in concatenate.walk(recurse=False):
        module_nodes, module_output_nodes = convert_module(layer, parent_layer, input_nodes)
        nodes.extend(module_nodes)
        output_nodes.extend(module_output_nodes)

    # add a concatenation node
    concat_node = GraphNode(
        namespace=concatenate.get_path().replace(".", "/"),
        id=concatenate.get_path(),
        label="concat",  # was "+": copy-paste from the residual's summation node
    )
    for node in output_nodes:
        concat_node.incomingEdges.append(
            IncomingEdge(
                sourceNodeId=node.id,
            ),
        )
    nodes.append(concat_node)
    return nodes, [concat_node]
def convert_distribute(
    distribute: fl.Distribute,
    input_nodes: list[GraphNode],
) -> tuple[list[GraphNode], list[GraphNode]]:
    """Convert a refiners distribute layer to nodes.

    The i-th child receives the i-th input node; `strict=True` enforces that
    there are exactly as many inputs as children.

    Args:
        distribute: the distribute layer to convert.
        input_nodes: the input nodes of the distribute layer.
    """
    all_nodes: list[GraphNode] = []
    outputs: list[GraphNode] = []
    children = distribute.walk(recurse=False)
    for (child, parent), source_node in zip(children, input_nodes, strict=True):
        child_nodes, child_outputs = convert_module(child, parent, [source_node])
        all_nodes += child_nodes
        outputs += child_outputs
    return all_nodes, outputs
def convert_other(
    module: fl.Module,
    parent_module: fl.Chain | None,
    input_nodes: list[GraphNode],
) -> tuple[list[GraphNode], list[GraphNode]]:
    """Convert a refiners module (any non-container layer) to a single node.

    Args:
        module: the module to convert.
        parent_module: the parent module of the module.
        input_nodes: the input nodes of the module.

    Returns:
        The created node, both as the node list and as the output list.
    """
    namespace = "" if parent_module is None else parent_module.get_path().replace(".", "/")
    attributes = [
        KeyValue(
            key=name,
            value=str(val),
        )
        for name, val in module.basic_attributes().items()
    ]
    node = GraphNode(
        namespace=namespace,
        id=module.get_path(parent_module),
        label=module._get_name(),  # type: ignore
        attrs=attributes,
    )
    # keep a handle on the refiners module so context edges can be wired later
    setattr(node, "refiners_module", module)  # noqa: B010
    node.incomingEdges.extend(IncomingEdge(sourceNodeId=source.id) for source in input_nodes)
    return [node], [node]
def convert_module(  # type: ignore
    module: fl.Module,
    parent_module: fl.Chain | None,
    input_nodes: list[GraphNode],
) -> tuple[list[GraphNode], list[GraphNode]]:
    """Convert a refiners module to nodes, dispatching on its concrete type.

    Args:
        module: the module to convert.
        parent_module: the parent module of the module.
        input_nodes: the input nodes of the module.
    """
    # Specialized container types are tested first; plain Chain must come
    # last among the containers since the others subclass it.
    if isinstance(module, fl.Parallel):
        return convert_parallel(module, input_nodes)
    if isinstance(module, fl.Distribute):
        return convert_distribute(module, input_nodes)
    if isinstance(module, fl.Concatenate):
        return convert_concatenate(module, input_nodes)
    if isinstance(module, fl.Residual):
        return convert_residual(module, input_nodes)
    if isinstance(module, fl.Passthrough):
        return convert_passthrough(module, input_nodes)
    if isinstance(module, fl.Chain):
        return convert_chain(module, input_nodes)
    if isinstance(module, (fl.UseContext, fl.Parameter)):
        # these produce their output from the context / their own weights,
        # not from the data flow, so they get no incoming edges
        return convert_other(module, parent_module=parent_module, input_nodes=[])
    return convert_other(module, parent_module, input_nodes)
def _get_or_create_context_node(graph: Graph, context: str, key: str) -> GraphNode:
    """Return the node representing context entry `context`/`key`, creating and
    registering it in the graph if it does not exist yet."""
    node_id = f"context.{context}.{key}"
    context_node = find_node(graph=graph, node_id=node_id)
    if context_node is None:
        context_node = GraphNode(
            namespace=f"context/{context}/{key}",
            id=node_id,
            label=key,
        )
        graph.nodes.append(context_node)
    return context_node


def convert_model(model: fl.Chain) -> ModelExplorerGraphs:
    """Convert a refiners model to a Graph.

    Args:
        model: the model to convert.
    """
    # initialize the graph, with an input and output node
    graph = Graph(id=model._get_name())  # type: ignore
    input_node = GraphNode(id="input", label="input")
    graph.nodes.append(input_node)
    output_node = GraphNode(id="output", label="output")
    graph.nodes.append(output_node)

    # convert the model
    model_nodes, model_output_nodes = convert_module(
        module=model,
        parent_module=None,
        input_nodes=[input_node],
    )
    graph.nodes.extend(model_nodes)

    # connect the model's outputs to the output node
    for node in model_output_nodes:
        output_node.incomingEdges.append(
            IncomingEdge(
                sourceNodeId=node.id,
            ),
        )

    # connect the model's context nodes; iterate over a snapshot because
    # _get_or_create_context_node may append new nodes to graph.nodes while
    # we are looping (context nodes themselves never carry refiners_module,
    # so they need no second pass)
    for node in list(graph.nodes):
        refiners_module = getattr(node, "refiners_module", None)
        if isinstance(refiners_module, fl.SetContext):
            # SetContext writes into the context: node -> context_node
            context_node = _get_or_create_context_node(graph, refiners_module.context, refiners_module.key)
            context_node.incomingEdges.append(
                IncomingEdge(
                    sourceNodeId=node.id,
                ),
            )
        if isinstance(refiners_module, fl.UseContext):
            # UseContext reads from the context: context_node -> node
            context_node = _get_or_create_context_node(graph, refiners_module.context, refiners_module.key)
            node.incomingEdges.append(
                IncomingEdge(
                    sourceNodeId=context_node.id,
                ),
            )
    return ModelExplorerGraphs(graphs=[graph])

View file

@ -0,0 +1,44 @@
import refiners.fluxion.layers as fl
from model_explorer import Adapter, AdapterMetadata, server
from model_explorer.config import ModelExplorerConfig
from model_explorer_refiners.convert import convert_model
class RefinersAdapter(Adapter):
    """Model Explorer adapter exposing Refiners models as explorable graphs."""

    metadata = AdapterMetadata(
        id="refiners",
        name="Refiners adapter",
        description="Refiners adapter for Model Explorer",
        # must match the actual repository (the README installs from
        # finegrain-ai/model_explorer_refiners; the old value pointed to
        # finegrain/model-explorer-refiners, which is inconsistent)
        source_repo="https://github.com/finegrain-ai/model_explorer_refiners",
    )

    # NOTE: the redundant `__init__` that only called `super().__init__()` was
    # removed; the inherited constructor is used unchanged.

    @staticmethod
    def visualize(
        model: fl.Chain,
        host: str = "localhost",
        port: int = 8080,
    ) -> None:
        """Visualize a refiners model in Model Explorer.

        Args:
            model: the refiners model to visualize.
            host: the host of the server.
            port: the port of the server.
        """
        # construct config: register the converted graphs and point a model
        # source at their index in the graphs list
        config = ModelExplorerConfig()
        graphs = convert_model(model)
        config.graphs_list.append(graphs)
        index = len(config.graphs_list) - 1
        config.model_sources.append({"url": f"graphs://refiners/{index}"})

        # start server (presumably blocks until the server is stopped)
        server.start(
            config=config,
            host=host,
            port=port,
        )

View file

@ -0,0 +1,17 @@
from model_explorer.graph_builder import Graph, GraphNode
def find_node(graph: Graph, node_id: str) -> GraphNode | None:
    """Find a node in a graph.

    Args:
        graph: the graph to search.
        node_id: the id of the node to find.

    Returns:
        The node which matches the node_id, if found, otherwise None.
    """
    for candidate in graph.nodes:
        if candidate.id == node_id:
            return candidate
    return None