Skip to content

Commit

Permalink
Import for basic comfy UI workflow files (might not work when using b…
Browse files Browse the repository at this point in the history
…ypass and similar)
  • Loading branch information
Acly committed Oct 5, 2024
1 parent 5356b9d commit b57cd94
Show file tree
Hide file tree
Showing 7 changed files with 856 additions and 27 deletions.
2 changes: 1 addition & 1 deletion ai_diffusion/comfy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def connect(url=default_url, access_token=""):

# Check for required and optional model resources
models = client.models
models.node_inputs = {name: nodes[name]["input"].get("required", None) for name in nodes}
models.node_inputs = {name: nodes[name]["input"] for name in nodes}
available_resources = client.models.resources = {}

clip_models = nodes["DualCLIPLoader"]["input"]["required"]["clip_name1"][0]
Expand Down
69 changes: 65 additions & 4 deletions ai_diffusion/comfy_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,22 @@ def __init__(self, node_inputs: dict | None = None, run_mode=ComfyRunMode.server
self.node_count = 0
self.sample_count = 0
self._cache: dict[str, Output | Output2 | Output3 | Output4] = {}
self._nodes_required_inputs: dict[str, dict[str, Any]] = node_inputs or {}
self._nodes_inputs: dict[str, dict[str, Any]] = node_inputs or {}
self._run_mode: ComfyRunMode = run_mode

@staticmethod
def import_graph(existing: dict):
w = ComfyWorkflow()
def import_graph(existing: dict, node_inputs: dict):
w = ComfyWorkflow(node_inputs)
existing = _convert_ui_workflow(existing, node_inputs)
node_map: dict[str, str] = {}
queue = list(existing.keys())
while queue:
id = queue.pop(0)
node = deepcopy(existing[id])
if node_inputs and node["class_type"] not in node_inputs:
raise ValueError(
f"Workflow contains a node of type {node['class_type']} which is not installed on the ComfyUI server."
)
edges = [e for e in node["inputs"].values() if isinstance(e, list)]
if any(e[0] not in node_map for e in edges):
queue.append(id) # requeue node if an input is not yet mapped
Expand All @@ -94,7 +99,7 @@ def from_dict(existing: dict):
return w

def add_default_values(self, node_name: str, args: dict):
if node_inputs := self._nodes_required_inputs.get(node_name, None):
if node_inputs := _inputs_for_node(self._nodes_inputs, node_name, "required"):
for k, v in node_inputs.items():
if k not in args:
if len(v) == 1 and isinstance(v[0], list) and len(v[0]) > 0:
Expand Down Expand Up @@ -834,3 +839,59 @@ def estimate_pose(self, image: Output, resolution: int):
# use smaller model, but it requires onnxruntime, see #630
mdls["bbox_detector"] = "yolo_nas_l_fp16.onnx"
return self.add("DWPreprocessor", 1, image=image, resolution=resolution, **feat, **mdls)


def _inputs_for_node(node_inputs: dict[str, dict[str, Any]], node_name: str, filter=""):
inputs = node_inputs.get(node_name)
if inputs is None:
return None
if filter:
return inputs.get(filter)
result = inputs.get("required", {})
result.update(inputs.get("optional", {}))
return result


def _convert_ui_workflow(w: dict, node_inputs: dict):
version = w.get("version")
nodes = w.get("nodes")
links = w.get("links")
if not (version and nodes and links):
return w

primitives = {}
for node in nodes:
if node["type"] == "PrimitiveNode":
primitives[node["id"]] = node["widgets_values"][0]

r = {}
for node in nodes:
id = node["id"]
type = node["type"]
if type == "PrimitiveNode":
continue

inputs = {}
fields = _inputs_for_node(node_inputs, type)
if fields is None:
raise ValueError(
f"Workflow uses node type {type}, but it is not installed on the ComfyUI server."
)
widget_count = 0
for field_name, field in fields.items():
field_type = field[0]
if field_type in ["INT", "FLOAT", "BOOL", "STRING"] or isinstance(field_type, list):
inputs[field_name] = node["widgets_values"][widget_count]
widget_count += 1
for connection in node["inputs"]:
if connection["name"] == field_name and connection["link"] is not None:
link = next(l for l in links if l[0] == connection["link"])
prim = primitives.get(link[1])
if prim is not None:
inputs[field_name] = prim
else:
inputs[field_name] = [link[1], link[2]]
break
r[id] = {"class_type": type, "inputs": inputs}

return r
31 changes: 22 additions & 9 deletions ai_diffusion/custom_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,18 @@ class CustomWorkflow:
id: str
source: WorkflowSource
graph: dict
workflow: ComfyWorkflow
path: Path | None = None

@staticmethod
def from_api(id: str, source: WorkflowSource, graph: dict, path: Path | None = None):
# doesn't work for UI workflow export (API workflow only)
return CustomWorkflow(id, source, graph, ComfyWorkflow.import_graph(graph, {}), path)

@property
def name(self):
return self.id.removesuffix(".json")

@property
def workflow(self):
return ComfyWorkflow.import_graph(self.graph)


class WorkflowCollection(QAbstractListModel):

Expand All @@ -58,12 +60,20 @@ def __init__(self, connection: Connection, folder: Path | None = None):
for wf in self._connection.workflows.keys():
self._process_remote_workflow(wf)

def _create_workflow(
self, id: str, source: WorkflowSource, graph: dict, path: Path | None = None
):
wf = ComfyWorkflow.import_graph(graph, self._connection.client.models.node_inputs)
return CustomWorkflow(id, source, graph, wf, path)

def _process_remote_workflow(self, id: str):
self._process(CustomWorkflow(id, WorkflowSource.remote, self._connection.workflows[id]))
graph = self._connection.workflows[id]
self._process(self._create_workflow(id, WorkflowSource.remote, graph))

def _process_file(self, file: Path):
with file.open("r") as f:
self._process(CustomWorkflow(file.stem, WorkflowSource.local, json.load(f), file))
graph = json.load(f)
self._process(self._create_workflow(file.stem, WorkflowSource.local, graph, file))

def _process(self, workflow: CustomWorkflow):
idx = self.find_index(workflow.id)
Expand Down Expand Up @@ -94,6 +104,9 @@ def append(self, item: CustomWorkflow):
self._workflows.append(item)
self.endInsertRows()

def add_from_document(self, id: str, graph: dict):
self.append(self._create_workflow(id, WorkflowSource.document, graph))

def remove(self, id: str):
idx = self.find_index(id)
if idx.isValid():
Expand All @@ -118,15 +131,15 @@ def save_as(self, id: str, graph: dict):
self._folder.mkdir(exist_ok=True)
path = self._folder / f"{id}.json"
path.write_text(json.dumps(graph, indent=2))
self.append(CustomWorkflow(id, WorkflowSource.local, graph, path))
self.append(self._create_workflow(id, WorkflowSource.local, graph, path))
return id

def import_file(self, filepath: Path):
try:
with filepath.open("r") as f:
graph = json.load(f)
try:
ComfyWorkflow.import_graph(graph)
ComfyWorkflow.import_graph(graph, self._connection.client.models.node_inputs)
except Exception as e:
raise RuntimeError(f"This is not a supported workflow file ({e})")
return self.save_as(filepath.stem, graph)
Expand Down Expand Up @@ -279,7 +292,7 @@ def _set_workflow_id(self, id: str):

def set_graph(self, id: str, graph: dict):
if self._workflows.find(id) is None:
self._workflows.append(CustomWorkflow(id, WorkflowSource.document, graph))
self._workflows.add_from_document(id, graph)
self.workflow_id = id

def import_file(self, filepath: Path):
Expand Down
Loading

0 comments on commit b57cd94

Please sign in to comment.