Skip to content

Commit 93e616b

Browse files
committed
Add changes required for deepagent-temporal
1 parent dbbf020 commit 93e616b

11 files changed

Lines changed: 769 additions & 12 deletions

File tree

langgraph/temporal/activities.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from __future__ import annotations
88

9+
import contextvars
910
import itertools
1011
from collections.abc import Sequence
1112
from typing import Any, cast
@@ -35,6 +36,48 @@
3536
)
3637
from langgraph.temporal.converter import GraphRegistry
3738

39+
# Generic extension point: context variable for collecting Child Workflow
40+
# requests during Activity execution. Middleware (e.g., TemporalSubAgentMiddleware)
41+
# appends requests here; the Activity collects them into NodeActivityOutput.
42+
_child_workflow_requests_var: contextvars.ContextVar[list[dict[str, Any]]] = (
43+
contextvars.ContextVar("_child_workflow_requests")
44+
)
45+
46+
# Worker-specific task queue for worker affinity.
47+
# Set at worker startup via `set_worker_task_queue()`. The
48+
# `get_available_task_queue` activity returns this value so the Workflow
49+
# can pin all subsequent Activities to this worker's queue.
50+
_worker_task_queue: str | None = None
51+
52+
53+
def set_worker_task_queue(queue: str) -> None:
54+
"""Set the worker-specific task queue name.
55+
56+
Called at worker startup. The `get_available_task_queue` activity
57+
returns this value to the Workflow for worker affinity.
58+
"""
59+
global _worker_task_queue # noqa: PLW0603
60+
_worker_task_queue = queue
61+
62+
63+
@activity.defn
64+
async def get_available_task_queue() -> str:
65+
"""Return this worker's unique task queue name.
66+
67+
The Workflow calls this activity on the shared distribution queue.
68+
Whichever worker picks it up returns its own unique queue name. The
69+
Workflow then dispatches all subsequent Activities to that queue,
70+
achieving worker affinity (following the Temporal worker-specific
71+
task queues pattern).
72+
"""
73+
if _worker_task_queue is None:
74+
raise ApplicationError(
75+
"Worker task queue not configured. "
76+
"Call set_worker_task_queue() at worker startup.",
77+
non_retryable=True,
78+
)
79+
return _worker_task_queue
80+
3881

3982
def _make_counter() -> Any:
4083
"""Create a simple callable counter using itertools.count."""
@@ -123,6 +166,11 @@ def read_handler(
123166
try:
124167
activity.heartbeat(f"Starting node {input.node_name}")
125168

169+
# Initialize the child workflow requests context variable so that
170+
# middleware (e.g., TemporalSubAgentMiddleware) can append requests
171+
# during node execution.
172+
_child_workflow_requests_var.set([])
173+
126174
# Set the runnable config context so that interrupt() and other
127175
# functions that call get_config() can find it. This is normally
128176
# done by the Runnable invoke/ainvoke tracing path, but Activities
@@ -161,6 +209,9 @@ def read_handler(
161209
# Filter out UntrackedValue writes
162210
serializable_writes = _filter_untracked_writes(all_writes, graph)
163211

212+
# Collect any Child Workflow requests stored during execution
213+
child_wf_requests = _child_workflow_requests_var.get([])
214+
164215
return NodeActivityOutput(
165216
node_name=input.node_name,
166217
writes=serializable_writes,
@@ -169,6 +220,7 @@ def read_handler(
169220
push_sends=push_sends if push_sends else None,
170221
command=command_data,
171222
custom_data=custom_data if custom_data else None,
223+
child_workflow_requests=(child_wf_requests if child_wf_requests else None),
172224
)
173225
except GraphInterrupt as exc:
174226
# Node called interrupt() - propagate to Workflow

langgraph/temporal/config.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,36 @@ class ActivityOptions:
2626
schedule_to_close_timeout: timedelta | None = None
2727

2828

29+
@dataclass
30+
class SubAgentConfig:
31+
"""Configuration for sub-agent Child Workflow dispatch.
32+
33+
Attributes:
34+
task_queue: Task queue for sub-agent Child Workflows.
35+
sticky_task_queue: Sticky task queue for sub-agent worker affinity.
36+
execution_timeout_seconds: Maximum execution time for sub-agent
37+
Child Workflows in seconds.
38+
"""
39+
40+
task_queue: str | None = None
41+
sticky_task_queue: str | None = None
42+
execution_timeout_seconds: float = 1800.0
43+
44+
2945
@dataclass
3046
class RestoredState:
3147
"""State carried across continue-as-new boundaries.
3248
3349
Attributes:
3450
checkpoint: The serialized Checkpoint TypedDict.
3551
step: The current step counter.
52+
sticky_task_queue: Sticky task queue name for worker affinity,
53+
preserved across continue-as-new.
3654
"""
3755

3856
checkpoint: dict[str, Any]
3957
step: int
58+
sticky_task_queue: str | None = None
4059

4160

4261
@dataclass
@@ -74,6 +93,16 @@ class WorkflowInput:
7493
node_task_queues: Per-node task queue overrides.
7594
node_activity_options: Per-node Activity configuration (serialized).
7695
node_retry_policies: Per-node retry policy configuration.
96+
sticky_task_queue: Sticky task queue for worker affinity. When set,
97+
overrides per-node task queue routing for all Activities.
98+
Can be set explicitly or discovered at runtime via
99+
`get_available_task_queue` when `use_worker_affinity` is True.
100+
use_worker_affinity: When True, the Workflow calls the
101+
`get_available_task_queue` activity at startup to discover a
102+
worker-specific queue. All subsequent Activities are dispatched
103+
to that queue. The discovered queue is forwarded on
104+
continue-as-new via `sticky_task_queue`.
105+
subagent_config: Configuration for sub-agent Child Workflow dispatch.
77106
"""
78107

79108
graph_definition_ref: str
@@ -85,6 +114,9 @@ class WorkflowInput:
85114
node_task_queues: dict[str, str] | None = None
86115
node_activity_options: dict[str, ActivityOptions] | None = None
87116
node_retry_policies: dict[str, RetryPolicyConfig] | None = None
117+
sticky_task_queue: str | None = None
118+
use_worker_affinity: bool = False
119+
subagent_config: SubAgentConfig | None = None
88120

89121

90122
@dataclass
@@ -138,6 +170,8 @@ class NodeActivityOutput:
138170
interrupts: Interrupt payloads if node called interrupt().
139171
push_sends: Dynamic Send objects emitted during execution.
140172
command: Command metadata if node returned a Command.
173+
child_workflow_requests: Requests for Child Workflow dispatch
174+
(e.g., sub-agent invocations collected via context variable).
141175
"""
142176

143177
node_name: str
@@ -148,6 +182,7 @@ class NodeActivityOutput:
148182
push_sends: list[dict[str, Any]] | None = None
149183
command: dict[str, Any] | None = None
150184
custom_data: list[Any] | None = None
185+
child_workflow_requests: list[dict[str, Any]] | None = None
151186

152187

153188
@dataclass

langgraph/temporal/graph.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ActivityOptions,
2121
RetryPolicyConfig,
2222
StateUpdatePayload,
23+
SubAgentConfig,
2324
WorkflowInput,
2425
WorkflowOutput,
2526
)
@@ -122,6 +123,27 @@ def _build_workflow_input(
122123
max_attempts=getattr(policy, "max_attempts", 0),
123124
)
124125

126+
# Read sticky_task_queue, use_worker_affinity, and subagent_config
127+
sticky_task_queue: str | None = None
128+
use_worker_affinity: bool = False
129+
subagent_config: SubAgentConfig | None = None
130+
if config and "configurable" in config:
131+
sticky_task_queue = config["configurable"].get("sticky_task_queue")
132+
use_worker_affinity = bool(
133+
config["configurable"].get("use_worker_affinity", False)
134+
)
135+
raw_subagent = config["configurable"].get("subagent_config")
136+
if isinstance(raw_subagent, SubAgentConfig):
137+
subagent_config = raw_subagent
138+
elif isinstance(raw_subagent, dict):
139+
subagent_config = SubAgentConfig(
140+
task_queue=raw_subagent.get("task_queue"),
141+
sticky_task_queue=raw_subagent.get("sticky_task_queue"),
142+
execution_timeout_seconds=raw_subagent.get(
143+
"execution_timeout_seconds", 1800.0
144+
),
145+
)
146+
125147
return WorkflowInput(
126148
graph_definition_ref=self._graph_ref,
127149
input_data=input if isinstance(input, dict) else {"__root__": input},
@@ -131,6 +153,9 @@ def _build_workflow_input(
131153
node_task_queues=self.node_task_queues if self.node_task_queues else None,
132154
node_activity_options=serialized_options,
133155
node_retry_policies=serialized_retry,
156+
sticky_task_queue=sticky_task_queue,
157+
use_worker_affinity=use_worker_affinity,
158+
subagent_config=subagent_config,
134159
)
135160

136161
async def ainvoke(

langgraph/temporal/worker.py

Lines changed: 112 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
from __future__ import annotations
88

9+
import uuid
10+
from pathlib import Path
11+
from types import TracebackType
912
from typing import Any
1013

1114
from langgraph.pregel import Pregel
@@ -16,40 +19,145 @@
1619
dynamic_execute_node,
1720
evaluate_conditional_edge,
1821
execute_node,
22+
get_available_task_queue,
23+
set_worker_task_queue,
1924
)
2025
from langgraph.temporal.converter import GraphRegistry
2126
from langgraph.temporal.workflow import LangGraphWorkflow
2227

2328

29+
class WorkerGroup:
30+
"""Manages multiple Temporal Workers as a single async context manager.
31+
32+
Used for worker-affinity mode where two workers are needed:
33+
one on the shared queue (Workflows + discovery Activity) and one on
34+
a worker-specific queue (node execution Activities).
35+
"""
36+
37+
def __init__(self, workers: list[Worker]) -> None:
38+
self._workers = workers
39+
40+
async def __aenter__(self) -> WorkerGroup:
41+
for w in self._workers:
42+
await w.__aenter__()
43+
return self
44+
45+
async def __aexit__(
46+
self,
47+
exc_type: type[BaseException] | None,
48+
exc_val: BaseException | None,
49+
exc_tb: TracebackType | None,
50+
) -> None:
51+
for w in reversed(self._workers):
52+
await w.__aexit__(exc_type, exc_val, exc_tb)
53+
54+
async def run(self) -> None:
55+
"""Run all workers. Blocks until shutdown."""
56+
import asyncio
57+
58+
await asyncio.gather(*[w.run() for w in self._workers])
59+
60+
61+
def _resolve_worker_queue(
62+
task_queue: str,
63+
queue_file: Path | str | None,
64+
) -> str:
65+
"""Resolve the worker-specific queue name.
66+
67+
If `queue_file` is provided and exists, read the persisted queue name
68+
(so a restarted worker re-registers on the same queue). Otherwise
69+
generate a new one and persist it if a path was given.
70+
"""
71+
if queue_file is not None:
72+
path = Path(queue_file)
73+
if path.exists():
74+
stored = path.read_text().strip()
75+
if stored:
76+
return stored
77+
# Generate and persist
78+
queue = f"{task_queue}-worker-{uuid.uuid4().hex[:12]}"
79+
path.parent.mkdir(parents=True, exist_ok=True)
80+
path.write_text(queue)
81+
return queue
82+
83+
return f"{task_queue}-worker-{uuid.uuid4().hex[:12]}"
84+
85+
2486
def create_worker(
2587
graph: Pregel,
2688
client: TemporalClient,
2789
task_queue: str = "langgraph-default",
90+
*,
91+
use_worker_affinity: bool = False,
92+
worker_queue_file: Path | str | None = None,
2893
**kwargs: Any,
29-
) -> Worker:
94+
) -> Worker | WorkerGroup:
3095
"""Create a Temporal Worker configured for a LangGraph graph.
3196
3297
Registers the `LangGraphWorkflow` as a Workflow and `execute_node` /
3398
`evaluate_conditional_edge` as Activities. The graph is registered in
3499
the `GraphRegistry` for Activity-side lookup.
35100
101+
When `use_worker_affinity` is True, returns a `WorkerGroup` with two
102+
workers following the Temporal worker-specific task queues pattern:
103+
- A shared worker on `task_queue` (Workflows + discovery Activity)
104+
- A worker-specific worker on a unique queue (node Activities)
105+
36106
Args:
37107
graph: A compiled Pregel graph instance.
38108
client: A Temporal client instance.
39109
task_queue: The task queue to listen on.
110+
use_worker_affinity: When True, create a dual-worker setup for
111+
worker-specific task queue affinity.
112+
worker_queue_file: Path to persist the worker-specific queue name.
113+
On restart, the worker re-registers on the same queue so that
114+
in-flight Activities resume on this worker. If None, a new
115+
queue name is generated each time (no restart recovery).
40116
**kwargs: Additional Worker configuration (e.g.,
41117
`max_concurrent_activities`, `max_concurrent_workflow_tasks`).
42118
43119
Returns:
44-
A configured Temporal Worker instance.
120+
A configured Temporal Worker or WorkerGroup instance.
45121
"""
46122
# Ensure graph is registered
47123
GraphRegistry.get_instance().register(graph)
48124

49-
return Worker(
125+
if not use_worker_affinity:
126+
return Worker(
127+
client,
128+
task_queue=task_queue,
129+
workflows=[LangGraphWorkflow],
130+
activities=[
131+
execute_node,
132+
dynamic_execute_node,
133+
evaluate_conditional_edge,
134+
],
135+
**kwargs,
136+
)
137+
138+
# Worker-affinity mode: two workers following the Temporal
139+
# worker-specific task queues pattern.
140+
worker_specific_queue = _resolve_worker_queue(task_queue, worker_queue_file)
141+
set_worker_task_queue(worker_specific_queue)
142+
143+
# Worker 1: shared queue — Workflows + get_available_task_queue
144+
shared_worker = Worker(
50145
client,
51146
task_queue=task_queue,
52147
workflows=[LangGraphWorkflow],
53-
activities=[execute_node, dynamic_execute_node, evaluate_conditional_edge],
148+
activities=[get_available_task_queue],
54149
**kwargs,
55150
)
151+
152+
# Worker 2: worker-specific queue — node execution Activities
153+
specific_worker = Worker(
154+
client,
155+
task_queue=worker_specific_queue,
156+
activities=[
157+
execute_node,
158+
dynamic_execute_node,
159+
evaluate_conditional_edge,
160+
],
161+
)
162+
163+
return WorkerGroup([shared_worker, specific_worker])

0 commit comments

Comments
 (0)