Skip to content
3 changes: 1 addition & 2 deletions api/core/app/layers/pause_state_persist_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
"""
if isinstance(session_factory, Engine):
session_factory = sessionmaker(session_factory)
super().__init__()
self._session_maker = session_factory
self._state_owner_user_id = state_owner_user_id
self._generate_entity = generate_entity
Expand Down Expand Up @@ -98,8 +99,6 @@ def on_event(self, event: GraphEngineEvent) -> None:
if not isinstance(event, GraphRunPausedEvent):
return

assert self.graph_runtime_state is not None

entity_wrapper: _GenerateEntityUnion
if isinstance(self._generate_entity, WorkflowAppGenerateEntity):
entity_wrapper = _WorkflowGenerateEntityWrapper(entity=self._generate_entity)
Expand Down
5 changes: 1 addition & 4 deletions api/core/app/layers/trigger_post_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
trigger_log_id: str,
session_maker: sessionmaker[Session],
):
super().__init__()
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
Expand All @@ -57,10 +58,6 @@ def on_event(self, event: GraphEngineEvent):
elapsed_time = (datetime.now(UTC) - self.start_time).total_seconds()

# Extract relevant data from result
if not self.graph_runtime_state:
logger.exception("Graph runtime state is not set")
return

outputs = self.graph_runtime_state.outputs

# BASICLY, workflow_execution_id is the same as workflow_run_id
Expand Down
3 changes: 3 additions & 0 deletions api/core/workflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ engine.layer(DebugLoggingLayer(level="INFO"))
engine.layer(ExecutionLimitsLayer(max_nodes=100))
```

`engine.layer()` binds the read-only runtime state before execution, so layer hooks
can assume `graph_runtime_state` is available.

### Event-Driven Architecture

All node executions emit events for monitoring and integration:
Expand Down
14 changes: 7 additions & 7 deletions api/core/workflow/graph_engine/graph_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,16 @@ def _validate_graph_state_consistency(self) -> None:
if id(node.graph_runtime_state) != expected_state_id:
raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")

def _bind_layer_context(
self,
layer: GraphEngineLayer,
) -> None:
layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)

def layer(self, layer: GraphEngineLayer) -> "GraphEngine":
"""Add a layer for extending functionality."""
self._layers.append(layer)
self._bind_layer_context(layer)
return self

def run(self) -> Generator[GraphEngineEvent, None, None]:
Expand Down Expand Up @@ -301,14 +308,7 @@ def run(self) -> Generator[GraphEngineEvent, None, None]:
def _initialize_layers(self) -> None:
"""Initialize layers with context."""
self._event_manager.set_layers(self._layers)
# Create a read-only wrapper for the runtime state
read_only_state = ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state)
for layer in self._layers:
try:
layer.initialize(read_only_state, self._command_channel)
except Exception as e:
logger.warning("Failed to initialize layer %s: %s", layer.__class__.__name__, e)

try:
layer.on_graph_start()
except Exception as e:
Expand Down
5 changes: 4 additions & 1 deletion api/core/workflow/graph_engine/layers/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Pluggable middleware for engine extensions.

Abstract base class for layers.

- `initialize()` - Receive runtime context
- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks)
- `on_graph_start()` - Execution start hook
- `on_event()` - Process all events
- `on_graph_end()` - Execution end hook
Expand All @@ -34,6 +34,9 @@ engine.layer(debug_layer)
engine.run()
```

`engine.layer()` binds the read-only runtime state before execution, so
`graph_runtime_state` is always available inside layer hooks.

## Custom Layers

```python
Expand Down
25 changes: 19 additions & 6 deletions api/core/workflow/graph_engine/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
from core.workflow.runtime import ReadOnlyGraphRuntimeState


class GraphEngineLayerNotInitializedError(Exception):
"""Raised when a layer's runtime state is accessed before initialization."""

def __init__(self, layer_name: str | None = None) -> None:
name = layer_name or "GraphEngineLayer"
super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.")


class GraphEngineLayer(ABC):
"""
Abstract base class for GraphEngine layers.
Expand All @@ -28,22 +36,27 @@ class GraphEngineLayer(ABC):

def __init__(self) -> None:
"""Initialize the layer. Subclasses can override with custom parameters."""
self.graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None
self.command_channel: CommandChannel | None = None

@property
def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState:
if self._graph_runtime_state is None:
raise GraphEngineLayerNotInitializedError(type(self).__name__)
return self._graph_runtime_state

def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None:
"""
Initialize the layer with engine dependencies.

Called by GraphEngine before execution starts to inject the read-only runtime state
and command channel. This allows layers to observe engine context and send
commands, but prevents direct state modification.

Called by GraphEngine to inject the read-only runtime state and command channel.
This is invoked when the layer is registered with a `GraphEngine` instance.
Implementations should be idempotent.
Args:
graph_runtime_state: Read-only view of the runtime state
command_channel: Channel for sending commands to the engine
"""
self.graph_runtime_state = graph_runtime_state
self._graph_runtime_state = graph_runtime_state
self.command_channel = command_channel

@abstractmethod
Expand Down
11 changes: 4 additions & 7 deletions api/core/workflow/graph_engine/layers/debug_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,8 @@ def on_graph_start(self) -> None:
self.logger.info("=" * 80)
self.logger.info("🚀 GRAPH EXECUTION STARTED")
self.logger.info("=" * 80)

if self.graph_runtime_state:
# Log initial state
self.logger.info("Initial State:")
# Log initial state
self.logger.info("Initial State:")

@override
def on_event(self, event: GraphEngineEvent) -> None:
Expand Down Expand Up @@ -243,8 +241,7 @@ def on_graph_end(self, error: Exception | None) -> None:
self.logger.info(" Node retries: %s", self.retry_count)

# Log final state if available
if self.graph_runtime_state and self.include_outputs:
if self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))
if self.include_outputs and self.graph_runtime_state.outputs:
self.logger.info("Final outputs: %s", self._format_dict(self.graph_runtime_state.outputs))

self.logger.info("=" * 80)
4 changes: 0 additions & 4 deletions api/core/workflow/graph_engine/layers/persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,6 @@ def _populate_completion_statistics(self, execution: WorkflowExecution, *, updat
if update_finished:
execution.finished_at = naive_utc_now()
runtime_state = self.graph_runtime_state
if runtime_state is None:
return
execution.total_tokens = runtime_state.total_tokens
execution.total_steps = runtime_state.node_run_steps
execution.outputs = execution.outputs or runtime_state.outputs
Expand Down Expand Up @@ -404,6 +402,4 @@ def _enqueue_trace_task(self, execution: WorkflowExecution) -> None:

def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state
if runtime_state is None:
return {}
return runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
from core.workflow.graph_events.graph import GraphRunPausedEvent
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
Expand Down Expand Up @@ -569,10 +570,10 @@ def test_layer_requires_initialization(self, db_session_with_containers):
"""Test that layer requires proper initialization before handling events."""
# Arrange
layer = self._create_pause_state_persistence_layer()
# Don't initialize - graph_runtime_state should not be set
# Don't initialize - graph_runtime_state should be uninitialized

event = GraphRunPausedEvent(reasons=[SchedulingPause(message="test pause")])

# Act & Assert - Should raise AttributeError
with pytest.raises(AttributeError):
# Act & Assert - Should raise GraphEngineLayerNotInitializedError
with pytest.raises(GraphEngineLayerNotInitializedError):
layer.on_event(event)
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from core.variables.segments import Segment
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph_engine.entities.commands import GraphEngineCommand
from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError
from core.workflow.graph_events.graph import (
GraphRunFailedEvent,
GraphRunPausedEvent,
Expand Down Expand Up @@ -209,8 +210,9 @@ def test_init_with_dependency_injection(self):

assert layer._session_maker is session_factory
assert layer._state_owner_user_id == state_owner_user_id
assert not hasattr(layer, "graph_runtime_state")
assert not hasattr(layer, "command_channel")
with pytest.raises(GraphEngineLayerNotInitializedError):
_ = layer.graph_runtime_state
assert layer.command_channel is None

def test_initialize_sets_dependencies(self):
session_factory = Mock(name="session_factory")
Expand Down Expand Up @@ -295,7 +297,7 @@ def test_on_event_ignores_non_paused_events(self, monkeypatch: pytest.MonkeyPatc
mock_factory.assert_not_called()
mock_repo.create_workflow_pause.assert_not_called()

def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):
def test_on_event_raises_when_graph_runtime_state_is_uninitialized(self):
session_factory = Mock(name="session_factory")
layer = PauseStatePersistenceLayer(
session_factory=session_factory,
Expand All @@ -305,7 +307,7 @@ def test_on_event_raises_attribute_error_when_graph_runtime_state_is_none(self):

event = TestDataFactory.create_graph_run_paused_event()

with pytest.raises(AttributeError):
with pytest.raises(GraphEngineLayerNotInitializedError):
layer.on_event(event)

def test_on_event_asserts_when_workflow_execution_id_missing(self, monkeypatch: pytest.MonkeyPatch):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

import pytest

from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.layers.base import (
GraphEngineLayer,
GraphEngineLayerNotInitializedError,
)
from core.workflow.graph_events import GraphEngineEvent

from ..test_table_runner import WorkflowRunner


class LayerForTest(GraphEngineLayer):
def on_graph_start(self) -> None:
pass

def on_event(self, event: GraphEngineEvent) -> None:
pass

def on_graph_end(self, error: Exception | None) -> None:
pass


def test_layer_runtime_state_raises_when_uninitialized() -> None:
layer = LayerForTest()

with pytest.raises(GraphEngineLayerNotInitializedError):
_ = layer.graph_runtime_state


def test_layer_runtime_state_available_after_engine_layer() -> None:
runner = WorkflowRunner()
fixture_data = runner.load_fixture("simple_passthrough_workflow")
graph, graph_runtime_state = runner.create_graph_from_fixture(
fixture_data,
inputs={"query": "test layer state"},
)
engine = GraphEngine(
workflow_id="test_workflow",
graph=graph,
graph_runtime_state=graph_runtime_state,
command_channel=InMemoryChannel(),
)

layer = LayerForTest()
engine.layer(layer)

outputs = layer.graph_runtime_state.outputs
ready_queue_size = layer.graph_runtime_state.ready_queue_size

assert outputs == {}
assert isinstance(ready_queue_size, int)
assert ready_queue_size >= 0