import json
import time
import uuid
import os
from typing import Dict, Union, Optional
from orkes.graph.unit import ForwardEdge, ConditionalEdge
from orkes.graph.schema import NodePoolItem, TracesSchema, EdgeTrace
from orkes.graph.unit import _EndNode, _StartNode
from orkes.visualizer.generator import TraceInspector
from orkes.shared.context import trace_var, edge_id_var, edge_trace_var
from datetime import datetime
[docs]
class GraphRunner:
"""Executes a compiled OrkesGraph, managing state and tracing the execution.
The GraphRunner takes a compiled graph and an initial state, then traverses the
graph, executing the nodes and updating the state accordingly. It also records
traces of the execution, which can be saved to a file and visualized.
Attributes:
state_def (type): The TypedDict class that defines the shared state of the graph.
nodes_pool (Dict[str, NodePoolItem]): A dictionary of all nodes in the graph.
graph_state (Dict): The current state of the graph.
run_id (str): A unique identifier for the current run.
graph_name (str): The name of the graph being executed.
graph_description (str): The description of the graph being executed.
trace (TracesSchema): The trace of the current run.
traces_dir (str): The directory where traces are saved.
run_number (int): The current step number in the execution.
auto_save_trace (bool): If True, the trace is automatically saved after execution.
trace_inspector (TraceInspector): An object to generate a visualization of the trace.
"""
[docs]
def __init__(self, graph_name: str, graph_description: str, nodes_pool: Dict[str, NodePoolItem], graph_type: Dict, traces_dir: str = "traces", auto_save_trace: bool = False, traced: bool = True):
"""Initializes the GraphRunner.
Args:
graph_name (str): The name of the graph.
graph_description (str): The description of the graph.
nodes_pool (Dict[str, NodePoolItem]): The dictionary of nodes in the graph.
graph_type (Dict): The TypedDict class that defines the shared state.
traces_dir (str, optional): The directory to save traces. Defaults to "traces".
auto_save_trace (bool, optional): Whether to automatically save traces.
Defaults to False.
traced (bool, optional): Whether to enable tracing. Defaults to True.
"""
self.state_def = graph_type
self.nodes_pool = nodes_pool
self.graph_state: Dict = {}
self.run_id = str(uuid.uuid4())
self.graph_name = graph_name
self.graph_description = graph_description
self.traced = traced
self.trace = None
self.trace_inspector = None
if self.traced:
self.trace = TracesSchema(
run_id=self.run_id,
graph_name=self.graph_name,
graph_description=self.graph_description,
nodes_trace=[v.node.node_trace for k, v in nodes_pool.items()],
edges_trace=[]
)
self.trace_inspector = TraceInspector()
self.traces_dir = traces_dir
self.run_number = 0
self.auto_save_trace = auto_save_trace
def save_run_trace(self):
"""Saves the execution trace to a JSON file."""
if not self.traced:
return
if not os.path.exists(self.traces_dir):
os.makedirs(self.traces_dir)
filename = os.path.join(self.traces_dir, f"trace_{self.run_id}.json")
with open(filename, 'w') as f:
json.dump(self.trace.model_dump(), f, indent=4)
def visualize_trace(self):
"""Generates an HTML visualization of the execution trace."""
if not self.traced:
return
if not os.path.exists(self.traces_dir):
os.makedirs(self.traces_dir)
base_name = f"trace_{self.run_id}_inspector.html"
out_file = os.path.join(self.traces_dir, base_name)
self.trace_inspector.generate_viz(self.trace.model_dump(), out_file)
def run(self, invoke_state: Dict) -> Dict:
"""Runs the graph with a given initial state.
Args:
invoke_state (Dict): The initial state to run the graph with.
Returns:
Dict: The final state of the graph after execution.
Raises:
KeyError: If the invoke_state contains keys not defined in the graph's state.
"""
# Check that all keys in invoke_state exist in graph_state
missing_keys = [key for key in invoke_state if key not in self.state_def.__annotations__]
if missing_keys:
raise KeyError(f"The following items are missing in self.graph_state: {missing_keys}")
# Merge invoke_state into a copy of graph_state (avoid mutating original)
self.graph_state = invoke_state
input_state = self.graph_state.copy()
# Start traversal from the START node
start_pool = self.nodes_pool['START']
start_edges = start_pool.edge
if self.traced:
self.trace.start_time = time.time()
token = trace_var.set(self.trace)
try:
self.traverse_graph(start_edges, input_state)
finally:
trace_var.reset(token)
self.trace.elapsed_time = time.time() - self.trace.start_time
self.trace.status = "FINISHED"
if self.auto_save_trace:
self.save_run_trace()
else:
self.traverse_graph(start_edges, input_state)
return self.graph_state
def traverse_graph(self, current_edge: Union[ForwardEdge, ConditionalEdge], input_state: Dict, stop_node_name: Optional[str] = None):
"""Traverses the graph, either with or without tracing.
Args:
current_edge (Union[ForwardEdge, ConditionalEdge]): The current edge to traverse.
input_state (Dict): The current state of the graph.
"""
if self.traced:
self._traverse_traced(current_edge, input_state, stop_node_name)
else:
self._traverse_untraced(current_edge, input_state, stop_node_name)
def _traverse_untraced(self, current_edge: Union[ForwardEdge, ConditionalEdge], input_state: Dict, stop_node_name: Optional[str] = None):
"""Recursively traverses the graph without tracing.
Args:
current_edge (Union[ForwardEdge, ConditionalEdge]): The current edge to traverse.
input_state (Dict): The current state of the graph.
stop_node_name (Optional[str]): The name of the node at which to stop traversal.
Raises:
RuntimeError: If an edge is traversed more than the maximum allowed times.
"""
current_node = current_edge.from_node.node
if stop_node_name and current_node.name == stop_node_name:
return
if current_edge.passes > current_edge.max_passes:
raise RuntimeError(
f"Edge '{current_edge.id}' has been passed {current_edge.max_passes} times, "
"exceeding the allowed maximum without reaching a stop condition."
)
else:
current_edge.passes += 1
if current_edge.edge_type == "__forward__":
if not isinstance(current_node, _StartNode):
result = current_node.execute(input_state)
self.graph_state.update(result)
next_edge = current_edge.to_node.edge
next_node = current_edge.to_node.node
elif current_edge.edge_type == "__conditional__":
result = current_node.execute(input_state)
self.graph_state.update(result)
gate_function = current_edge.gate_function
condition = current_edge.condition
result_gate = gate_function(self.graph_state)
next_node_name = condition[result_gate]
next_node = self.nodes_pool[next_node_name].node
next_edge = self.nodes_pool[next_node_name].edge
elif current_edge.edge_type == "__parallel__":
if not isinstance(current_node, _StartNode):
result = current_node.execute(input_state)
self.graph_state.update(result)
# Recursively traverse each parallel branch
for to_node_item in current_edge.to_nodes:
branch_start_edge = self.nodes_pool[to_node_item.node.name].edge
self.traverse_graph(branch_start_edge, self.graph_state.copy(), stop_node_name=current_edge.aggregation_node.node.name)
# After all parallel branches are (sequentially) traversed,
# the control flow should move to the aggregation node.
# The 'next_edge' should be the one coming *out* of the aggregation node.
next_node = current_edge.aggregation_node.node
next_edge = current_edge.aggregation_node.edge
if not isinstance(next_node, _EndNode):
next_input = self.graph_state.copy()
self.traverse_graph(next_edge, next_input, stop_node_name)
def _traverse_traced(self, current_edge: Union[ForwardEdge, ConditionalEdge], input_state: Dict, stop_node_name: Optional[str] = None):
"""Recursively traverses the graph with tracing.
Args:
current_edge (Union[ForwardEdge, ConditionalEdge]): The current edge to traverse.
input_state (Dict): The current state of the graph.
stop_node_name (Optional[str]): The name of the node at which to stop traversal.
Raises:
RuntimeError: If an edge is traversed more than the maximum allowed times.
"""
current_node = current_edge.from_node.node
if stop_node_name and current_node.name == stop_node_name:
return
edge_token = edge_id_var.set(current_edge.id)
try:
if current_edge.passes > current_edge.max_passes:
raise RuntimeError(
f"Edge '{current_edge.id}' has been passed {current_edge.max_passes} times, "
"exceeding the allowed maximum without reaching a stop condition."
)
else:
current_edge.passes += 1
self.run_number += 1
edge_trace = current_edge.edge_trace.model_copy()
edge_trace.edge_run_number = self.run_number
edge_trace.passes_left = current_edge.max_passes - current_edge.passes
start = time.time()
edge_trace.state_snapshot = input_state.copy()
edge_trace_token = edge_trace_var.set(edge_trace)
try:
if current_edge.edge_type == "__forward__":
if not isinstance(current_node, _StartNode):
result = current_node.execute(input_state)
self.graph_state.update(result)
next_edge = current_edge.to_node.edge
next_node = current_edge.to_node.node
elif current_edge.edge_type == "__conditional__":
result = current_node.execute(input_state)
self.graph_state.update(result)
gate_function = current_edge.gate_function
condition = current_edge.condition
result_gate = gate_function(self.graph_state)
next_node_name = condition[result_gate]
next_node = self.nodes_pool[next_node_name].node
next_edge = self.nodes_pool[next_node_name].edge
edge_trace.to_node = next_node_name
elif current_edge.edge_type == "__parallel__":
if not isinstance(current_node, _StartNode):
result = current_node.execute(input_state)
self.graph_state.update(result)
# Record the parallel branch starting points in the trace
edge_trace.to_node = [node_item.node.name for node_item in current_edge.to_nodes]
edge_trace.meta["aggregation_node"] = current_edge.aggregation_node.node.name
# Recursively traverse each parallel branch
for to_node_item in current_edge.to_nodes:
branch_start_edge = self.nodes_pool[to_node_item.node.name].edge
# Pass a copy of the state to simulate independent branches for tracing,
# but the graph_state itself is shared.
self.traverse_graph(branch_start_edge, self.graph_state.copy(), stop_node_name=current_edge.aggregation_node.node.name)
# After all parallel branches are (sequentially) traversed,
# the control flow should move to the aggregation node.
next_node = current_edge.aggregation_node.node
next_edge = current_edge.aggregation_node.edge
finally:
edge_trace_var.reset(edge_trace_token)
edge_trace.elapsed = time.time() - start
self.trace.edges_trace.append(edge_trace)
if not isinstance(next_node, _EndNode):
next_input = self.graph_state.copy()
self.traverse_graph(next_edge, next_input, stop_node_name)
finally:
edge_id_var.reset(edge_token)
# Handle Brancing and merging state -> because state update only happen after node process done, no shared mutable object
# FAN IN FAN OUT STRATEGY, EVERY BRANCHING NODE NEED TO BE RETURNED
# In your example:
# A
# |
# B
# / \
# C D
# \
# E
# If E needs data from both C and D, you have two main options:
# Make E a "merge node" that accepts inputs from both C and D ā i.e., edges C -> E and D -> E.
# E will receive two incoming states, merge them internally, then execute.
# Insert an explicit merge node (e.g., M):
# C D
# \ /
# M
# |
# E
# The merge node M merges C and Dās outputs.
# Then E runs with the combined state.