Async Execution Functions #3

Draft · wants to merge 2 commits into master
5 changes: 3 additions & 2 deletions comfy/utils.py
@@ -992,11 +992,12 @@ def set_progress_bar_global_hook(function):
     PROGRESS_BAR_HOOK = function
 
 class ProgressBar:
-    def __init__(self, total):
+    def __init__(self, total, node_id=None):
         global PROGRESS_BAR_HOOK
         self.total = total
         self.current = 0
         self.hook = PROGRESS_BAR_HOOK
+        self.node_id = node_id
 
     def update_absolute(self, value, total=None, preview=None):
         if total is not None:
@@ -1005,7 +1006,7 @@ def update_absolute(self, value, total=None, preview=None):
             value = self.total
         self.current = value
         if self.hook is not None:
-            self.hook(self.current, self.total, preview)
+            self.hook(self.current, self.total, preview, node_id=self.node_id)
 
     def update(self, value):
         self.update_absolute(self.current + value)
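
With node_id stored on the ProgressBar and forwarded to the hook, progress updates can be attributed to a specific node even when several nodes report at once. A minimal sketch of the consumer side (the print_progress hook is hypothetical; only ProgressBar, set_progress_bar_global_hook, and the new node_id parameter come from this diff):

```python
import comfy.utils

# Hypothetical hook: any registered hook must now accept node_id,
# which update_absolute() forwards as a keyword argument.
def print_progress(value, total, preview, node_id=None):
    print(f"node {node_id}: {value}/{total}")

comfy.utils.set_progress_bar_global_hook(print_progress)

pbar = comfy.utils.ProgressBar(10, node_id="42")
pbar.update(3)   # -> "node 42: 3/10"
```
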
20 changes: 19 additions & 1 deletion comfy_execution/graph.py
@@ -2,6 +2,7 @@
 from typing import Type, Literal
 
 import nodes
+import asyncio
 from comfy_execution.graph_utils import is_link
 from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
 
@@ -100,6 +101,8 @@ def __init__(self, dynprompt):
         self.pendingNodes = {}
         self.blockCount = {} # Number of nodes this node is directly blocked by
         self.blocking = {} # Which nodes are blocked by this node
+        self.externalBlocks = 0
+        self.unblockedEvent = asyncio.Event()
 
     def get_input_info(self, unique_id, input_name):
         class_type = self.dynprompt.get_node(unique_id)["class_type"]
@@ -153,6 +156,16 @@ def add_node(self, node_unique_id, include_lazy=False, subgraph_nodes=None):
         for link in links:
             self.add_strong_link(*link)
 
+    def add_external_block(self, node_id):
+        assert node_id in self.blockCount, "Can't add external block to a node that isn't pending"
+        self.externalBlocks += 1
+        self.blockCount[node_id] += 1
+        def unblock():
+            self.externalBlocks -= 1
+            self.blockCount[node_id] -= 1
+            self.unblockedEvent.set()
+        return unblock
+
     def is_cached(self, node_id):
         return False
 
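A caller holds one external block per closure, and blocks stack: each unblock() releases exactly the one it was handed, decrementing both counters and setting unblockedEvent so a waiting scheduler re-checks readiness. A self-contained sketch of those semantics (the Sort class is a minimal stand-in for TopologicalSort, not the real implementation):

```python
import asyncio

class Sort:  # minimal stand-in for TopologicalSort
    def __init__(self):
        self.blockCount = {"n1": 1}  # n1 is pending, blocked by one real dependency
        self.externalBlocks = 0
        self.unblockedEvent = asyncio.Event()

    def add_external_block(self, node_id):
        assert node_id in self.blockCount
        self.externalBlocks += 1
        self.blockCount[node_id] += 1
        def unblock():
            self.externalBlocks -= 1
            self.blockCount[node_id] -= 1
            self.unblockedEvent.set()
        return unblock

s = Sort()
u1 = s.add_external_block("n1")
u2 = s.add_external_block("n1")  # blocks stack; two holders, two releases
u1()
u2()
print(s.externalBlocks, s.blockCount)  # 0 {'n1': 1}: only the dependency block is left
```
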
@@ -181,11 +194,16 @@ def __init__(self, dynprompt, output_cache):
     def is_cached(self, node_id):
         return self.output_cache.get(node_id) is not None
 
-    def stage_node_execution(self):
+    async def stage_node_execution(self):
         assert self.staged_node_id is None
         if self.is_empty():
             return None, None, None
         available = self.get_ready_nodes()
+        while len(available) == 0 and self.externalBlocks > 0:
+            # Wait for an external block to be released
+            await self.unblockedEvent.wait()
+            self.unblockedEvent.clear()
+            available = self.get_ready_nodes()
         if len(available) == 0:
             cycled_nodes = self.get_nodes_in_cycle()
             # Because cycles composed entirely of static nodes are caught during initial validation,
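
stage_node_execution can now block: when nothing is ready but external work is still in flight, it parks on unblockedEvent and re-checks after each wake-up. Clearing after the wait matters because one Event coalesces any number of unblock() calls. A self-contained toy of the same loop shape (all names here are stand-ins, not the real classes):

```python
import asyncio

async def main():
    ready = []           # stands in for get_ready_nodes()
    external_blocks = 1  # stands in for self.externalBlocks
    event = asyncio.Event()

    async def async_node_finishes():
        nonlocal external_blocks
        await asyncio.sleep(0.1)   # simulated async node work
        ready.append("consumer")   # its downstream node becomes ready
        external_blocks -= 1
        event.set()

    asyncio.create_task(async_node_finishes())

    while not ready and external_blocks > 0:
        await event.wait()         # sleep until some unblock() fires
        event.clear()              # re-arm before re-checking readiness
    print(ready)                   # ['consumer']

asyncio.run(main())
```
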
103 changes: 96 additions & 7 deletions execution.py
@@ -8,6 +8,7 @@
 from enum import Enum
 import inspect
 from typing import List, Literal, NamedTuple, Optional
+import asyncio
 
 import torch
 import nodes
@@ -192,6 +193,63 @@ def process_inputs(inputs, index=None, input_is_list=False):
             process_inputs(input_dict, i)
     return results
 
+async def _async_map_node_over_list(obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None):
+    # check if node wants the lists
+    input_is_list = getattr(obj, "INPUT_IS_LIST", False)
+
+    if len(input_data_all) == 0:
+        max_len_input = 0
+    else:
+        max_len_input = max(len(x) for x in input_data_all.values())
+
+    # get a slice of inputs, repeat last input when list isn't long enough
+    def slice_dict(d, i):
+        return {k: v[i if len(v) > i else -1] for k, v in d.items()}
+
+    results = []
+    async def process_inputs(inputs, index=None, input_is_list=False):
+        if allow_interrupt:
+            nodes.before_node_execution()
+        execution_block = None
+        for k, v in inputs.items():
+            if input_is_list:
+                for e in v:
+                    if isinstance(e, ExecutionBlocker):
+                        v = e
+                        break
+            if isinstance(v, ExecutionBlocker):
+                execution_block = execution_block_cb(v) if execution_block_cb else v
+                break
+        if execution_block is None:
+            if pre_execute_cb is not None and index is not None:
+                pre_execute_cb(index)
+            f = getattr(obj, func)
+            if inspect.iscoroutinefunction(f):
+                task = asyncio.create_task(f(**inputs))
+                # Give the task a chance to execute without yielding
+                await asyncio.sleep(0)
+                if task.done():
+                    result = task.result()
+                    results.append(result)
+                else:
+                    results.append(task)
+            else:
+                result = f(**inputs)
+                results.append(result)
+        else:
+            results.append(execution_block)
+
+    if input_is_list:
+        await process_inputs(input_data_all, 0, input_is_list=input_is_list)
+    elif max_len_input == 0:
+        await process_inputs({})
+    else:
+        for i in range(max_len_input):
+            input_dict = slice_dict(input_data_all, i)
+            await process_inputs(input_dict, i)
+    return results
+
+
 def merge_result_data(results, obj):
     # check which outputs need concatenating
     output = []
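
The interesting part of _async_map_node_over_list is the optimistic fast path: create_task() followed by a single await asyncio.sleep(0) lets a coroutine that never truly suspends run to completion immediately, so it yields a plain value, while genuinely blocking coroutines surface as pending Tasks. A standalone illustration of that behavior:

```python
import asyncio

async def instant():
    return "done"              # no awaits: finishes on its first scheduling step

async def slow():
    await asyncio.sleep(1)     # actually suspends
    return "done later"

async def fast_path(coro):
    task = asyncio.create_task(coro)
    await asyncio.sleep(0)     # yield once so the new task gets to run
    return task.result() if task.done() else task

async def main():
    print(await fast_path(instant()))   # -> done (plain value, no Task)
    pending = await fast_path(slow())
    print(type(pending).__name__)       # -> Task (still pending)
    print(await pending)                # harvest later -> done later

asyncio.run(main())
```
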
@@ -213,11 +271,18 @@ def merge_result_data(results, obj):
             output.append([o[i] for o in results])
     return output
 
-def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
+async def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb=None):
+    return_values = await _async_map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
+    has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
+    if has_pending_task:
+        return return_values, {}, False, has_pending_task
+    output, ui, has_subgraph = get_output_from_returns(return_values, obj)
+    return output, ui, has_subgraph, False
+
+def get_output_from_returns(return_values, obj):
     results = []
     uis = []
     subgraph_results = []
-    return_values = _map_node_over_list(obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
     has_subgraph = False
     for i in range(len(return_values)):
         r = return_values[i]
@@ -251,6 +316,10 @@ def get_output_data(obj, input_data_all, execution_block_cb=None, pre_execute_cb
     else:
         output = []
     ui = dict()
+    # Think there's an existing bug here
+    # If we're performing a subgraph expansion, we probably shouldn't be returning UI values yet.
+    # They'll get cached without the completed subgraphs. It's an edge case and I'm not aware of
+    # any nodes that use both subgraph expansion and custom UI outputs, but might be a problem in the future.
     if len(uis) > 0:
         ui = {k: [y for x in uis for y in x[k]] for k in uis[0].keys()}
     return output, ui, has_subgraph
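
Because the node's FUNCTION is mapped over input slices, one call can produce a mix of plain values, Tasks that already finished on the fast path, and Tasks that are still running. A single live Task is enough to route the node down the pending path, and every Task is later harvested with .result(). A small demo of that detection and harvest (toy coroutines, not real nodes):

```python
import asyncio

async def main():
    async def quick():
        return 10                      # completes on its first step

    done = asyncio.create_task(quick())
    await asyncio.sleep(0)             # fast path: this task is now done
    pending = asyncio.create_task(asyncio.sleep(0.05, result=20))

    return_values = ["plain", done, pending]
    has_pending_task = any(isinstance(r, asyncio.Task) and not r.done()
                           for r in return_values)
    print(has_pending_task)            # True -> execute() would return PENDING

    await asyncio.gather(*[r for r in return_values if isinstance(r, asyncio.Task)])
    print([r.result() if isinstance(r, asyncio.Task) else r
           for r in return_values])    # ['plain', 10, 20]

asyncio.run(main())
```
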
@@ -263,7 +332,7 @@ def format_value(x):
     else:
         return str(x)
 
-def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results):
+async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -279,7 +348,11 @@
 
     input_data_all = None
     try:
-        if unique_id in pending_subgraph_results:
+        if unique_id in pending_async_nodes:
+            results = [r.result() if isinstance(r, asyncio.Task) else r for r in pending_async_nodes[unique_id]]
+            del pending_async_nodes[unique_id]
+            output_data, output_ui, has_subgraph = get_output_from_returns(results, class_def)
+        elif unique_id in pending_subgraph_results:
             cached_results = pending_subgraph_results[unique_id]
             resolved_outputs = []
             for is_subgraph, result in cached_results:
@@ -341,8 +414,18 @@ def execution_block_cb(block):
             else:
                 return block
         def pre_execute_cb(call_index):
+            # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
            GraphBuilder.set_default_prefix(unique_id, call_index, 0)
-        output_data, output_ui, has_subgraph = get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
+        output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb)
+        if has_pending_tasks:
+            pending_async_nodes[unique_id] = output_data
+            unblock = execution_list.add_external_block(unique_id)
+            async def await_completion():
+                tasks = [x for x in output_data if isinstance(x, asyncio.Task)]
+                await asyncio.gather(*tasks)
+                unblock()
+            asyncio.create_task(await_completion())
+            return (ExecutionResult.PENDING, None, None)
         if len(output_ui) > 0:
             caches.ui.set(unique_id, {
                 "meta": {
@@ -481,6 +564,11 @@ def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, e
         self.add_message("execution_error", mes, broadcast=False)
 
     def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
+        asyncio_loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(asyncio_loop)
+        asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
+
+    async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
         nodes.interrupt_processing(False)
 
         if "client_id" in extra_data:
@@ -508,19 +596,20 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
                          { "nodes": cached_nodes, "prompt_id": prompt_id},
                          broadcast=False)
         pending_subgraph_results = {}
+        pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
         executed = set()
         execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
         current_outputs = self.caches.outputs.all_node_ids()
         for node_id in list(execute_outputs):
             execution_list.add_node(node_id)
 
         while not execution_list.is_empty():
-            node_id, error, ex = execution_list.stage_node_execution()
+            node_id, error, ex = await execution_list.stage_node_execution()
             if error is not None:
                 self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
                 break
 
-            result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
+            result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
             self.success = result != ExecutionResult.FAILURE
             if result == ExecutionResult.FAILURE:
                 self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
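
Taken together, the control flow is: run the graph driver on a private event loop (execute() wraps execute_async() with asyncio.run), and whenever a node comes back with live Tasks, park its results in pending_async_nodes, take an external block, and let a background task gather() and unblock; the node is then staged a second time purely to harvest results. A condensed, runnable sketch of that loop with invented stand-ins (ToyList and the node coroutines are not the real classes):

```python
import asyncio

class ToyList:  # stand-in for ExecutionList
    def __init__(self, node_ids):
        self.ready = list(node_ids)
        self.external_blocks = 0
        self.event = asyncio.Event()

    def add_external_block(self, node_id):
        self.external_blocks += 1
        def unblock():
            self.external_blocks -= 1
            self.ready.append(node_id)   # re-stage the node to harvest results
            self.event.set()
        return unblock

    async def stage_node_execution(self):
        while not self.ready and self.external_blocks > 0:
            await self.event.wait()
            self.event.clear()
        return self.ready.pop(0) if self.ready else None

async def execute_async(node_funcs):
    pending_async_nodes = {}
    elist = ToyList(node_funcs)
    while True:
        node_id = await elist.stage_node_execution()
        if node_id is None:
            break
        if node_id in pending_async_nodes:      # second visit: tasks finished
            tasks = pending_async_nodes.pop(node_id)
            print(node_id, "->", [t.result() for t in tasks])
            continue
        task = asyncio.create_task(node_funcs[node_id]())
        await asyncio.sleep(0)                  # fast path for instant coroutines
        if task.done():
            print(node_id, "->", task.result())
            continue
        pending_async_nodes[node_id] = [task]   # park and keep scheduling others
        unblock = elist.add_external_block(node_id)
        async def await_completion(ts=pending_async_nodes[node_id], un=unblock):
            await asyncio.gather(*ts)
            un()
        asyncio.create_task(await_completion())

async def slow_node():
    await asyncio.sleep(0.1)
    return "slept"

async def fast_node():
    return "instant"

# fast_node completes while slow_node is still sleeping:
asyncio.run(execute_async({"slow": slow_node, "fast": fast_node}))
```
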
8 changes: 6 additions & 2 deletions main.py
@@ -227,9 +227,13 @@ async def run(server_instance, address='', port=8188, verbose=True, call_on_star
 
 
 def hijack_progress(server_instance):
-    def hook(value, total, preview_image):
+    def hook(value, total, preview_image, prompt_id=None, node_id=None):
         comfy.model_management.throw_exception_if_processing_interrupted()
-        progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}
+        if prompt_id is None:
+            prompt_id = server_instance.last_prompt_id
+        if node_id is None:
+            node_id = server_instance.last_node_id
+        progress = {"value": value, "max": total, "prompt_id": prompt_id, "node": node_id}
 
         server_instance.send_sync("progress", progress, server_instance.client_id)
         if preview_image is not None:
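
The hook keeps its old behavior for legacy callers: with no explicit ids, progress is attributed to the server's "current" node, which is only trustworthy for synchronous execution. An async node's ProgressBar passes node_id explicitly, so its updates stay correctly attributed after the scheduler has moved on. A minimal sketch of the fallback (ServerStub is invented for illustration):

```python
class ServerStub:
    last_prompt_id = "p-1"
    last_node_id = "n-1"

server_instance = ServerStub()

def hook(value, total, preview_image, prompt_id=None, node_id=None):
    if prompt_id is None:
        prompt_id = server_instance.last_prompt_id   # legacy fallback
    if node_id is None:
        node_id = server_instance.last_node_id
    print({"value": value, "max": total, "prompt_id": prompt_id, "node": node_id})

hook(1, 10, None)                 # old-style call -> attributed to n-1
hook(1, 10, None, node_id="n-7")  # async ProgressBar -> attributed to n-7
```
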
31 changes: 31 additions & 0 deletions nodes.py
@@ -46,6 +46,35 @@ def interrupt_processing(value=True):
 
 MAX_RESOLUTION=16384
 
+class TestSleep(ComfyNodeABC):
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "value": (IO.ANY, {}),
+                "seconds": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 9999.0, "step": 0.01, "tooltip": "The amount of seconds to sleep."}),
+            },
+            "hidden": {
+                "unique_id": "UNIQUE_ID",
+            },
+        }
+    RETURN_TYPES = (IO.ANY,)
+    FUNCTION = "sleep"
+
+    CATEGORY = "_for_testing"
+
+    async def sleep(self, value, seconds, unique_id):
+        pbar = comfy.utils.ProgressBar(seconds, node_id=unique_id)
+        import asyncio
+        start = time.time()
+        expiration = start + seconds
+        now = start
+        while now < expiration:
+            now = time.time()
+            pbar.update_absolute(now - start)
+            await asyncio.sleep(0.01)
+        return (value,)
+
 class CLIPTextEncode(ComfyNodeABC):
     @classmethod
     def INPUT_TYPES(s) -> InputTypeDict:
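
TestSleep doubles as the reference pattern for writing async nodes: FUNCTION may now name an async def, which can await freely while the executor keeps servicing other ready nodes (the 0.01 s polling loop above exists only to drive the ProgressBar). A hypothetical node in the same shape, assuming a module-level asyncio import and using the ComfyNodeABC and IO types already imported by nodes.py:

```python
import asyncio  # assumed at the top of the module rather than inside the method

class AsyncDelay(ComfyNodeABC):  # hypothetical example node, not part of this PR
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "value": (IO.ANY, {}),
            "seconds": ("FLOAT", {"default": 1.0, "min": 0.0}),
        }}
    RETURN_TYPES = (IO.ANY,)
    FUNCTION = "run"
    CATEGORY = "_for_testing"

    async def run(self, value, seconds):
        await asyncio.sleep(seconds)  # a single await; no progress reporting needed
        return (value,)
```
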
@@ -1941,6 +1970,7 @@ def expand_image(self, image, left, top, right, bottom, feathering):
 
 
 NODE_CLASS_MAPPINGS = {
+    "TestSleep": TestSleep,
     "KSampler": KSampler,
     "CheckpointLoaderSimple": CheckpointLoaderSimple,
     "CLIPTextEncode": CLIPTextEncode,
@@ -2011,6 +2041,7 @@ def expand_image(self, image, left, top, right, bottom, feathering):
 }
 
 NODE_DISPLAY_NAME_MAPPINGS = {
+    "TestSleep": "Test Sleep",
     # Sampling
     "KSampler": "KSampler",
     "KSamplerAdvanced": "KSampler (Advanced)",