Skip to content

Commit fb3c19c

Browse files
dsfacciniDouweM
andauthored
Resolve dynamic toolset get_instructions in an activity under Temporal (#5925)
Co-authored-by: Douwe Maan <hi@douwe.me>
1 parent 961387c commit fb3c19c

2 files changed

Lines changed: 274 additions & 9 deletions

File tree

pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_dynamic_toolset.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

3-
from collections.abc import Callable
3+
import copy
4+
from collections.abc import Callable, Sequence
45
from dataclasses import dataclass
56
from typing import TYPE_CHECKING, Any, Literal
67

@@ -9,7 +10,9 @@
910

1011
from pydantic_ai import ToolsetTool
1112
from pydantic_ai.exceptions import UserError
13+
from pydantic_ai.messages import InstructionPart
1214
from pydantic_ai.tools import AgentDepsT, RunContext, ToolDefinition
15+
from pydantic_ai.toolsets import AbstractToolset
1316
from pydantic_ai.toolsets._dynamic import DynamicToolset
1417
from pydantic_ai.toolsets.external import TOOL_SCHEMA_VALIDATOR
1518

@@ -33,6 +36,20 @@ class _ToolInfo:
3336
max_retries: int
3437

3538

39+
@dataclass
40+
class _GetToolsResult:
41+
"""Serializable result of `get_tools_activity`: the resolved toolset's tools and its instructions.
42+
43+
Instructions are collected in the same activity (and thus the same single resolution and entry of
44+
the inner toolset) as the tools. For an MCP-backed dynamic toolset this means the server is entered
45+
once per run step instead of once for tools and again for instructions; the second entry would add a
46+
redundant `initialize` round-trip whose `notifications/initialized` races teardown.
47+
"""
48+
49+
tools: dict[str, _ToolInfo]
50+
instructions: str | InstructionPart | Sequence[str | InstructionPart] | None
51+
52+
3653
class TemporalDynamicToolset(TemporalWrapperToolset[AgentDepsT]):
3754
"""Temporal wrapper for DynamicToolset.
3855
@@ -57,8 +74,11 @@ def __init__(
5774
self.tool_activity_config = tool_activity_config
5875
self.run_context_type = run_context_type
5976

60-
async def get_tools_activity(params: GetToolsParams, deps: AgentDepsT) -> dict[str, _ToolInfo]:
61-
"""Activity that calls the dynamic function and returns tool definitions."""
77+
# Set by `get_tools`, read by `get_instructions`; lives on the per-run `for_run` copy (no run-id key).
78+
self._run_instructions: str | InstructionPart | Sequence[str | InstructionPart] | None = None
79+
80+
async def get_tools_activity(params: GetToolsParams, deps: AgentDepsT) -> _GetToolsResult:
81+
"""Activity that resolves the dynamic toolset and returns its tools and instructions."""
6282
ctx = deserialize_run_context(
6383
self.run_context_type, params.serialized_run_context, deps=deps, agent=self._agent
6484
)
@@ -67,10 +87,14 @@ async def get_tools_activity(params: GetToolsParams, deps: AgentDepsT) -> dict[s
6787
async with run_toolset:
6888
run_toolset = await run_toolset.for_run_step(ctx)
6989
tools = await run_toolset.get_tools(ctx)
70-
return {
71-
name: _ToolInfo(tool_def=tool.tool_def, max_retries=tool.max_retries)
72-
for name, tool in tools.items()
73-
}
90+
instructions = await run_toolset.get_instructions(ctx)
91+
return _GetToolsResult(
92+
tools={
93+
name: _ToolInfo(tool_def=tool.tool_def, max_retries=tool.max_retries)
94+
for name, tool in tools.items()
95+
},
96+
instructions=instructions,
97+
)
7498

7599
get_tools_activity.__annotations__['deps'] = deps_type
76100

@@ -107,21 +131,42 @@ async def call_tool_activity(params: CallToolParams, deps: AgentDepsT) -> CallTo
107131
def temporal_activities(self) -> list[Callable[..., Any]]:
108132
return [self.get_tools_activity, self.call_tool_activity]
109133

134+
async def for_run(self, ctx: RunContext[AgentDepsT]) -> AbstractToolset[AgentDepsT]:
135+
if not workflow.in_workflow(): # pragma: no cover
136+
return await super().for_run(ctx)
137+
138+
# Per-run copy isolates `_run_instructions` from the process-shared, module-level instance.
139+
# Shallow, so the copy shares the worker-registered activities (`execute_activity` resolves by
140+
# name). The base `return self` is about wrapped-toolset lifecycle; this is only state isolation.
141+
run_copy = copy.copy(self)
142+
run_copy._run_instructions = None
143+
return run_copy
144+
145+
async def get_instructions(
146+
self, ctx: RunContext[AgentDepsT]
147+
) -> str | InstructionPart | Sequence[str | InstructionPart] | None:
148+
if not workflow.in_workflow(): # pragma: no cover
149+
return await super().get_instructions(ctx)
150+
151+
# Set by `get_tools`, which the framework runs (via `for_run_step`) earlier in each step.
152+
return self._run_instructions
153+
110154
async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]:
111155
if not workflow.in_workflow(): # pragma: no cover
112156
return await super().get_tools(ctx)
113157

114158
serialized_run_context = self.run_context_type.serialize_run_context(ctx)
115159
activity_config: ActivityConfig = {'summary': f'get tools: {self.id}', **self.activity_config}
116-
tool_infos = await workflow.execute_activity(
160+
result = await workflow.execute_activity(
117161
activity=self.get_tools_activity,
118162
args=[
119163
GetToolsParams(serialized_run_context=serialized_run_context),
120164
ctx.deps,
121165
],
122166
**activity_config,
123167
)
124-
return {name: self._tool_for_tool_info(tool_info) for name, tool_info in tool_infos.items()}
168+
self._run_instructions = result.instructions
169+
return {name: self._tool_for_tool_info(tool_info) for name, tool_info in result.tools.items()}
125170

126171
async def call_tool(
127172
self,

tests/test_temporal.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pydantic import BaseModel
1515

1616
from pydantic_ai import (
17+
AbstractToolset,
1718
Agent,
1819
AgentRunResultEvent,
1920
AgentStreamEvent,
@@ -1450,6 +1451,225 @@ async def test_dynamic_toolset_outside_workflow():
14501451
assert result.output == snapshot('{"get_dynamic_weather":"Weather in a for Bob: sunny."}')
14511452

14521453

1454+
# --- DynamicToolset.get_instructions test (issue #5282) ---
1455+
# A dynamic toolset whose resolved toolset implements `get_instructions()` must contribute those
1456+
# instructions under `TemporalAgent`, resolved inside an activity like `get_tools`.
1457+
1458+
1459+
def _echo_instructions(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
1460+
request = messages[-1]
1461+
assert isinstance(request, ModelRequest)
1462+
return ModelResponse(parts=[TextPart(request.instructions or '<no instructions>')])
1463+
1464+
1465+
dynamic_instructions_agent = Agent(FunctionModel(_echo_instructions), name='dynamic_instructions_agent')
1466+
1467+
1468+
@dynamic_instructions_agent.toolset(id='dynamic_instruction_toolset', per_run_step=False)
1469+
def dynamic_instruction_toolset(ctx: RunContext[object]) -> AbstractToolset[object]:
1470+
# A toolset that only contributes instructions, no tools.
1471+
return FunctionToolset(instructions='SENTINEL_INSTRUCTION_FROM_DYNAMIC_TOOLSET', id='instruction-only-toolset')
1472+
1473+
1474+
dynamic_instructions_temporal_agent = TemporalAgent(
1475+
dynamic_instructions_agent,
1476+
activity_config=BASE_ACTIVITY_CONFIG,
1477+
)
1478+
1479+
1480+
@workflow.defn
1481+
class DynamicInstructionsAgentWorkflow:
1482+
@workflow.run
1483+
async def run(self, prompt: str) -> str:
1484+
result = await dynamic_instructions_temporal_agent.run(prompt)
1485+
return result.output
1486+
1487+
1488+
async def test_dynamic_toolset_instructions_in_workflow(allow_model_requests: None, client: Client):
1489+
"""A dynamic toolset's `get_instructions()` reaches the model under `TemporalAgent` (issue #5282).
1490+
1491+
The model echoes the request's instructions back as its output, so the sentinel in the output
1492+
proves the resolved dynamic toolset's instructions were collected via the new activity.
1493+
"""
1494+
async with Worker(
1495+
client,
1496+
task_queue=TASK_QUEUE,
1497+
workflows=[DynamicInstructionsAgentWorkflow],
1498+
plugins=[AgentPlugin(dynamic_instructions_temporal_agent)],
1499+
):
1500+
output = await client.execute_workflow(
1501+
DynamicInstructionsAgentWorkflow.run,
1502+
args=['hello'],
1503+
id='test_dynamic_toolset_instructions_workflow',
1504+
task_queue=TASK_QUEUE,
1505+
)
1506+
assert output == snapshot('SENTINEL_INSTRUCTION_FROM_DYNAMIC_TOOLSET')
1507+
1508+
1509+
def test_dynamic_toolset_temporal_activities():
1510+
"""`TemporalDynamicToolset` collects instructions inside `get_tools`, so it has no separate `get_instructions` activity."""
1511+
activity_names = {
1512+
ActivityDefinition.must_from_callable(activity).name # pyright: ignore[reportUnknownMemberType]
1513+
for activity in dynamic_instructions_temporal_agent.temporal_activities
1514+
}
1515+
prefix = 'agent__dynamic_instructions_agent__dynamic_toolset__dynamic_instruction_toolset'
1516+
assert {f'{prefix}__get_tools', f'{prefix}__call_tool'} <= activity_names
1517+
assert f'{prefix}__get_instructions' not in activity_names
1518+
1519+
1520+
# --- DynamicToolset instructions refresh across run steps (issue #5282 follow-up) ---
1521+
# The per-run instructions cache is written by `get_tools` and read by `get_instructions` each
1522+
# step; this guards against it serving a stale step-1 value on a later step.
1523+
1524+
1525+
def _echo_instructions_after_tool_call(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
1526+
# First request: call a tool to force a second model-request step.
1527+
# Second request (carrying the tool return): echo the instructions, which by then must
1528+
# reflect the current step — proving the cache is repopulated by `get_tools` each step.
1529+
request = messages[-1]
1530+
assert isinstance(request, ModelRequest)
1531+
if any(isinstance(part, ToolReturnPart) for part in request.parts):
1532+
return ModelResponse(parts=[TextPart(request.instructions or '<no instructions>')])
1533+
return ModelResponse(parts=[ToolCallPart('noop', {})])
1534+
1535+
1536+
multi_step_instructions_agent = Agent(
1537+
FunctionModel(_echo_instructions_after_tool_call), name='multi_step_instructions_agent'
1538+
)
1539+
1540+
1541+
@multi_step_instructions_agent.toolset(id='multi_step_instruction_toolset')
1542+
def multi_step_instruction_toolset(ctx: RunContext[object]) -> AbstractToolset[object]:
1543+
# Instructions encode the run step, so a stale step-1 cached value read at step 2 would
1544+
# surface as the wrong sentinel in the model output.
1545+
toolset = FunctionToolset[object](
1546+
instructions=f'INSTRUCTIONS_FOR_STEP_{ctx.run_step}', id='step-instruction-toolset'
1547+
)
1548+
1549+
@toolset.tool_plain
1550+
def noop() -> str:
1551+
return 'noop'
1552+
1553+
return toolset
1554+
1555+
1556+
multi_step_instructions_temporal_agent = TemporalAgent(
1557+
multi_step_instructions_agent,
1558+
activity_config=BASE_ACTIVITY_CONFIG,
1559+
)
1560+
1561+
1562+
@workflow.defn
1563+
class MultiStepInstructionsAgentWorkflow:
1564+
@workflow.run
1565+
async def run(self, prompt: str) -> str:
1566+
result = await multi_step_instructions_temporal_agent.run(prompt)
1567+
return result.output
1568+
1569+
1570+
async def test_dynamic_toolset_instructions_refresh_across_steps_in_workflow(
1571+
allow_model_requests: None, client: Client
1572+
):
1573+
"""A dynamic toolset's instructions are refreshed each run step under `TemporalAgent` (issue #5282).
1574+
1575+
The toolset encodes the run step in its instructions; the model calls a tool on the first request to
1576+
force a second step, then echoes the instructions on the second request. The output being the step-2
1577+
sentinel (not the step-1 one) proves `get_tools` repopulates the per-run instructions cache each step
1578+
rather than serving a stale step-1 value.
1579+
"""
1580+
async with Worker(
1581+
client,
1582+
task_queue=TASK_QUEUE,
1583+
workflows=[MultiStepInstructionsAgentWorkflow],
1584+
plugins=[AgentPlugin(multi_step_instructions_temporal_agent)],
1585+
):
1586+
output = await client.execute_workflow(
1587+
MultiStepInstructionsAgentWorkflow.run,
1588+
args=['hello'],
1589+
id='test_dynamic_toolset_instructions_refresh_workflow',
1590+
task_queue=TASK_QUEUE,
1591+
)
1592+
assert output == snapshot('INSTRUCTIONS_FOR_STEP_2')
1593+
1594+
1595+
# --- DynamicToolset instructions replay determinism (issue #5282) ---
1596+
# The per-run instructions cache lives on a `for_run` copy of the wrapper rather than on the
1597+
# process-shared, module-level instance. A history recorded on one worker must replay on a
1598+
# freshly-constructed (cold) one, proving the `for_run` override reconstructs identically and
1599+
# introduces no `TMPRL1100` nondeterminism.
1600+
1601+
# A holder lets the replay step swap in a freshly-constructed (cold-process) instance.
1602+
dynamic_instructions_replay_holder: dict[str, TemporalAgent[object, str]] = {}
1603+
1604+
1605+
def _make_dynamic_instructions_replay_agent() -> TemporalAgent[object, str]:
1606+
agent = Agent(FunctionModel(_echo_instructions_after_tool_call), name='dynamic_instructions_replay_agent')
1607+
1608+
@agent.toolset(id='replay_instruction_toolset')
1609+
def _replay_toolset(ctx: RunContext[object]) -> AbstractToolset[object]:
1610+
toolset = FunctionToolset[object](
1611+
instructions=f'INSTRUCTIONS_FOR_STEP_{ctx.run_step}', id='step-instruction-toolset'
1612+
)
1613+
1614+
@toolset.tool_plain
1615+
def noop() -> str:
1616+
return 'noop'
1617+
1618+
return toolset
1619+
1620+
return TemporalAgent(agent, activity_config=BASE_ACTIVITY_CONFIG)
1621+
1622+
1623+
dynamic_instructions_replay_holder['agent'] = _make_dynamic_instructions_replay_agent()
1624+
1625+
1626+
@workflow.defn
1627+
class DynamicInstructionsReplayWorkflow:
1628+
@workflow.run
1629+
async def run(self, prompt: str) -> str:
1630+
result = await dynamic_instructions_replay_holder['agent'].run(prompt)
1631+
return result.output
1632+
1633+
1634+
async def test_dynamic_toolset_instructions_replay_deterministic(allow_model_requests: None, client: Client):
1635+
"""The per-run `for_run` instructions cache must be replay-deterministic (issue #5282).
1636+
1637+
Instructions resolved by `get_tools` are held on a per-run `for_run` copy of the wrapper, not
1638+
on the module-level instance. This records a two-step workflow (instructions differ per step)
1639+
and replays its history on a freshly-constructed cold instance — the worker-restart scenario —
1640+
asserting no nondeterminism, so the `for_run` copy is reconstructed identically on replay.
1641+
"""
1642+
warm = _make_dynamic_instructions_replay_agent()
1643+
dynamic_instructions_replay_holder['agent'] = warm
1644+
1645+
# Unsandboxed so the module-level instance is shared across the run exactly as a long-running
1646+
# worker process shares it in production.
1647+
async with Worker(
1648+
client,
1649+
task_queue=TASK_QUEUE,
1650+
workflows=[DynamicInstructionsReplayWorkflow],
1651+
activities=warm.temporal_activities,
1652+
workflow_runner=UnsandboxedWorkflowRunner(),
1653+
):
1654+
wf_id = DynamicInstructionsReplayWorkflow.__name__
1655+
output = await client.execute_workflow(
1656+
DynamicInstructionsReplayWorkflow.run, args=['hello'], id=wf_id, task_queue=TASK_QUEUE
1657+
)
1658+
assert output == snapshot('INSTRUCTIONS_FOR_STEP_2')
1659+
history = await client.get_workflow_handle(wf_id).fetch_history()
1660+
1661+
# Warm-recorded history replayed on a freshly-constructed cold instance (worker-restart trigger).
1662+
dynamic_instructions_replay_holder['agent'] = _make_dynamic_instructions_replay_agent()
1663+
try:
1664+
await Replayer(
1665+
workflows=[DynamicInstructionsReplayWorkflow],
1666+
workflow_runner=UnsandboxedWorkflowRunner(),
1667+
data_converter=pydantic_data_converter,
1668+
).replay_workflow(history)
1669+
finally:
1670+
dynamic_instructions_replay_holder['agent'] = warm
1671+
1672+
14531673
# --- MCP-based DynamicToolset test ---
14541674
# Tests that @agent.toolset returning an MCPToolset works with Temporal workflows.
14551675
# Uses an HTTP-based MCP server rather than subprocess-based since the subprocess transports

0 commit comments

Comments
 (0)