# Source code for orkes.graph.unit


from typing import Any, Callable, Dict
import uuid
from abc import ABC, abstractmethod
from orkes.graph.schema import NodePoolItem, NodeTrace, EdgeTrace

class Node:
    """A single executable unit of the computational graph.

    A node wraps a callable, keeps a reference to the graph's state and
    owns a trace record describing itself.

    Attributes:
        name (str): The unique identifier for the node.
        func (Callable): The function to be executed by this node.
        graph_state: A reference to the graph's state.
        id (str): Identifier of this instance, ``"node_"`` followed by a
            freshly generated UUID4.
        description (str): The ``__doc__`` of ``func`` (``None`` when the
            callable has no docstring).
        node_trace (NodeTrace): The trace object for the node.
    """

    def __init__(self, name: str, func: Callable, graph_state):
        """Create a node around ``func``.

        Args:
            name (str): The unique identifier for the node.
            func (Callable): The function this node runs. It is called with
                the input state handed to :meth:`execute`, and its return
                value is what :meth:`execute` returns.
            graph_state: A reference to the graph's state; stored as-is on
                the instance.
        """
        self.name: str = name
        self.func: Callable = func
        self.graph_state = graph_state

        node_id = f"node_{uuid.uuid4()}"
        doc = func.__doc__
        self.id = node_id
        self.description = doc

        # The trace mirrors the identifying attributes set above.
        self.node_trace = NodeTrace(
            node_name=name,
            node_id=node_id,
            node_description=doc,
            meta={"type": "function_node"},
        )

    def execute(self, input_state) -> Any:
        """Run the wrapped function on ``input_state``.

        Args:
            input_state: The input state passed straight to ``func``.

        Returns:
            Any: Whatever ``func`` returns.
        """
        return self.func(input_state)

    def __repr__(self) -> str:
        return "Node({})".format(self.name)
class _StartNode(Node):
    """Special node marking the entry point of the graph.

    Its function returns the state it receives unchanged. The node is
    always named ``"START"`` and its trace is tagged ``"start_node"``.
    """

    def __init__(self, graph_state):
        super().__init__("START", self._start, graph_state)
        trace = self.node_trace
        # Replace the generic "function_node" tag set by Node.__init__.
        trace.meta = {"type": "start_node"}
        trace.node_description = self.func.__doc__

    def _start(self, state):
        """The entry point of the graph."""
        return state


class _EndNode(Node):
    """Special node marking the termination point of the graph.

    Its function returns the state it receives unchanged. The node is
    always named ``"END"`` and its trace is tagged ``"end_node"``.
    """

    def __init__(self, graph_state):
        super().__init__("END", self._end, graph_state)
        trace = self.node_trace
        trace.node_description = self.func.__doc__
        # Replace the generic "function_node" tag set by Node.__init__.
        trace.meta = {"type": "end_node"}

    def _end(self, state):
        """The termination point of the graph."""
        return state
class Edge(ABC):
    """Base class for a directed connection between nodes in the graph.

    Intended to be used only through its subclasses. Note that no abstract
    methods are declared here, so Python itself does not prevent direct
    instantiation.

    Attributes:
        id (str): Identifier of this instance, ``"edge_"`` followed by a
            freshly generated UUID4.
        from_node (NodePoolItem): The node from which the edge originates.
        to_node (NodePoolItem): The node to which the edge points, or
            ``None`` when no fixed destination was given.
        passes (int): Traversal counter; starts at 0.
        max_passes (int): The maximum number of times the edge can be
            traversed.
        edge_type (str): The type of the edge; ``None`` in the base class.
        edge_trace (EdgeTrace): The trace object for the edge.
    """

    def __init__(self, from_node: NodePoolItem, to_node: NodePoolItem = None, max_passes=25):
        """Set up the common edge state and its initial trace record.

        Args:
            from_node (NodePoolItem): The node from which the edge
                originates.
            to_node (NodePoolItem, optional): The node to which the edge
                points. Defaults to None, in which case the trace records
                the destination as ``"N/A"``.
            max_passes (int, optional): The maximum number of times the
                graph execution can traverse this edge. Defaults to 25.
                The limit is only stored here; it is not enforced by this
                class.
        """
        self.id = f"edge_{uuid.uuid4()}"
        self.from_node = from_node
        self.to_node = to_node
        self.passes = 0
        self.max_passes = max_passes
        self.edge_type = None

        source_name = self.from_node.node.name
        if self.to_node:
            target_name = self.to_node.node.name
        else:
            # No fixed destination was supplied for this edge.
            target_name = "N/A"

        self.edge_trace = EdgeTrace(
            edge_id=self.id,
            edge_run_number=0,
            from_node=source_name,
            to_node=target_name,
            passes_left=self.max_passes,
            edge_type=self.edge_type,
            elapsed=0.0,
            meta={},
        )

    def __repr__(self):
        class_name = self.__class__.__name__
        return "{}({})".format(class_name, self.id)
class ForwardEdge(Edge):
    """Edge with a fixed source node and a fixed destination node.

    Its ``edge_type`` is ``"__forward__"`` and its trace is tagged
    ``"forward_edge"``.
    """

    def __init__(self, from_node: NodePoolItem, to_node: NodePoolItem, max_passes: int = 25):
        """Create a forward edge from ``from_node`` to ``to_node``.

        Args:
            from_node (NodePoolItem): The node from which the edge
                originates.
            to_node (NodePoolItem): The node to which the edge points.
            max_passes (int, optional): The maximum number of times the
                graph execution can traverse this edge. Defaults to 25.
        """
        super().__init__(from_node, to_node, max_passes)
        self.edge_type = "__forward__"

        # Rebuild the trace so it carries the concrete edge type and tag
        # instead of the base-class defaults.
        source_name = self.from_node.node.name
        if self.to_node:
            target_name = self.to_node.node.name
        else:
            target_name = "N/A"

        trace_fields = dict(
            edge_id=self.id,
            edge_run_number=0,
            from_node=source_name,
            to_node=target_name,
            passes_left=self.max_passes,
            edge_type=self.edge_type,
            elapsed=0.0,
            meta={"type": "forward_edge"},
        )
        self.edge_trace = EdgeTrace(**trace_fields)
class ConditionalEdge(Edge):
    """Edge whose destination is selected by a gate function.

    Its ``edge_type`` is ``"__conditional__"`` and its trace is tagged
    ``"conditional_edge"``. The gate function and the condition mapping are
    only stored on the instance here.
    """

    def __init__(
        self,
        from_node: NodePoolItem,
        gate_function: Callable,
        condition: Dict[str, str],
        max_passes=25
    ):
        """Create a conditional edge leaving ``from_node``.

        Args:
            from_node (NodePoolItem): The node from which the edge
                originates.
            gate_function (Callable): A function that is executed to decide
                which path to take. Its return value is matched against the
                keys of ``condition`` to determine the next node.
            condition (Dict[str, str]): Maps the possible outcomes of
                ``gate_function`` to the names of the next nodes.
            max_passes (int, optional): The maximum number of times the
                graph execution can traverse this edge. Defaults to 25.
        """
        # The parent is initialised without a fixed destination node.
        super().__init__(from_node, to_node=None, max_passes=max_passes)
        self.gate_function = gate_function
        self.condition = condition
        self.edge_type = "__conditional__"

        source_name = self.from_node.node.name
        if self.to_node:
            target_name = self.to_node.node.name
        else:
            target_name = "N/A"

        trace_fields = dict(
            edge_id=self.id,
            edge_run_number=0,
            from_node=source_name,
            to_node=target_name,
            passes_left=self.max_passes,
            edge_type=self.edge_type,
            elapsed=0.0,
            meta={"type": "conditional_edge"},
        )
        self.edge_trace = EdgeTrace(**trace_fields)
class ParallelEdge(Edge):
    """Edge that fans out into several parallel branches.

    Each branch starts from a node in ``to_nodes``; all branches are
    expected to eventually converge at ``aggregation_node``. The inherited
    ``to_node`` attribute stays ``None``. The ``edge_type`` is
    ``"__parallel__"`` and the trace is tagged ``"parallel_edge"``.
    """

    def __init__(self, from_node: NodePoolItem, to_nodes: list[NodePoolItem], aggregation_node: NodePoolItem, max_passes: int = 25):
        """Create a parallel edge leaving ``from_node``.

        Args:
            from_node (NodePoolItem): The node from which the edge
                originates.
            to_nodes (list[NodePoolItem]): The nodes where each parallel
                branch starts.
            aggregation_node (NodePoolItem): The node where all parallel
                branches are expected to converge.
            max_passes (int, optional): The maximum number of times this
                edge can be traversed. Defaults to 25.
        """
        super().__init__(from_node, to_node=None, max_passes=max_passes)
        self.to_nodes = to_nodes
        self.aggregation_node = aggregation_node
        self.edge_type = "__parallel__"

        source_name = self.from_node.node.name
        # Unlike the other edge kinds, the trace destination is a list of
        # names, one per branch entry node.
        branch_names = [item.node.name for item in self.to_nodes]
        trace_meta = {
            "type": "parallel_edge",
            "aggregation_node": self.aggregation_node.node.name,
        }

        self.edge_trace = EdgeTrace(
            edge_id=self.id,
            edge_run_number=0,
            from_node=source_name,
            to_node=branch_names,
            passes_left=self.max_passes,
            edge_type=self.edge_type,
            elapsed=0.0,
            meta=trace_meta,
        )
# Module-level call, executed once at import time after the classes above
# have been defined.
# NOTE(review): presumably NodePoolItem is a pydantic model holding a forward
# reference to Node that has to be resolved here — confirm against
# orkes.graph.schema.
NodePoolItem.model_rebuild()