Skip to content

Commit d2386f7

Browse files
authored
reasoning tool as parameter (#180)
1 parent 8d4a9a5 commit d2386f7

10 files changed

Lines changed: 390 additions & 27 deletions

File tree

sgr_agent_core/agents/iron_agent.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sgr_agent_core.next_step_tool import NextStepToolsBuilder
1212
from sgr_agent_core.services.registry import ToolRegistry
1313
from sgr_agent_core.services.tool_instantiator import ToolInstantiator
14-
from sgr_agent_core.tools import BaseTool, ReasoningTool, ToolNameSelectorStub
14+
from sgr_agent_core.tools import BaseTool, ReasoningTool, SystemBaseTool, ToolNameSelectorStub
1515

1616

1717
class IronAgent(BaseAgent):
@@ -34,7 +34,9 @@ def __init__(
3434
openai_client: AsyncOpenAI,
3535
agent_config: AgentConfig,
3636
toolkit: list[Type[BaseTool]],
37+
*,
3738
def_name: str | None = None,
39+
reasoning_tool_cls: type[SystemBaseTool] = ReasoningTool,
3840
**kwargs: dict,
3941
):
4042
super().__init__(
@@ -45,6 +47,7 @@ def __init__(
4547
def_name=def_name,
4648
**kwargs,
4749
)
50+
self.ReasoningTool: type[SystemBaseTool] = reasoning_tool_cls
4851

4952
def _log_tool_instantiator(
5053
self,
@@ -144,7 +147,7 @@ async def _prepare_tools(self) -> Type[ToolNameSelectorStub]:
144147
"""Prepare available tools for the current agent state and progress."""
145148
if self._context.iteration >= self.config.execution.max_iterations:
146149
raise RuntimeError("Max iterations reached")
147-
return NextStepToolsBuilder.build_NextStepToolSelector(self.toolkit)
150+
return NextStepToolsBuilder.build_NextStepToolSelector(self.toolkit, base_reasoning_cls=self.ReasoningTool)
148151

149152
async def _reasoning_phase(self) -> ReasoningTool:
150153
"""Call LLM to get ReasoningTool with selected tool name."""
@@ -153,8 +156,8 @@ async def _reasoning_phase(self) -> ReasoningTool:
153156
tool_selector_model = await self._prepare_tools()
154157
reasoning = await self._generate_tool(tool_selector_model, messages)
155158

156-
if not isinstance(reasoning, ReasoningTool):
157-
raise ValueError("Expected ReasoningTool instance")
159+
if not isinstance(reasoning, self.ReasoningTool):
160+
raise ValueError(f"Expected {self.ReasoningTool.__name__} instance")
158161

159162
# Log reasoning
160163
self._log_reasoning(reasoning)

sgr_agent_core/agents/sgr_agent.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from sgr_agent_core.tools import (
99
BaseTool,
1010
NextStepToolStub,
11+
ReasoningTool,
12+
SystemBaseTool,
1113
)
1214

1315

@@ -22,7 +24,9 @@ def __init__(
2224
openai_client: AsyncOpenAI,
2325
agent_config: AgentConfig,
2426
toolkit: list[Type[BaseTool]],
27+
*,
2528
def_name: str | None = None,
29+
reasoning_tool_cls: type[SystemBaseTool] = ReasoningTool,
2630
**kwargs: dict,
2731
):
2832
super().__init__(
@@ -33,11 +37,12 @@ def __init__(
3337
def_name=def_name,
3438
**kwargs,
3539
)
40+
self.ReasoningTool: type[SystemBaseTool] = reasoning_tool_cls
3641

3742
async def _prepare_tools(self) -> Type[NextStepToolStub]:
3843
"""Prepare available tools for the current agent state and progress."""
3944
tools = set(self.toolkit)
40-
return NextStepToolsBuilder.build_NextStepTools(list(tools))
45+
return NextStepToolsBuilder.build_NextStepTools(list(tools), base_reasoning_cls=self.ReasoningTool)
4146

4247
async def _reasoning_phase(self) -> NextStepToolStub:
4348
phase_id = f"{self._context.iteration}-reasoning"

sgr_agent_core/agents/sgr_tool_calling_agent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
BaseTool,
1010
FinalAnswerTool,
1111
ReasoningTool,
12+
SystemBaseTool,
1213
)
1314

1415

@@ -24,7 +25,9 @@ def __init__(
2425
openai_client: AsyncOpenAI,
2526
agent_config: AgentConfig,
2627
toolkit: list[Type[BaseTool]],
28+
*,
2729
def_name: str | None = None,
30+
reasoning_tool_cls: type[SystemBaseTool] = ReasoningTool,
2831
**kwargs: dict,
2932
):
3033
super().__init__(
@@ -36,12 +39,13 @@ def __init__(
3639
**kwargs,
3740
)
3841
self.tool_choice: Literal["required"] = "required"
42+
self.ReasoningTool: type[SystemBaseTool] = reasoning_tool_cls
3943

4044
async def _reasoning_phase(self) -> ReasoningTool:
4145
phase_id = f"{self._context.iteration}-reasoning"
4246
async with self.openai_client.chat.completions.stream(
4347
messages=await self._prepare_context(),
44-
tools=[pydantic_function_tool(ReasoningTool, name=ReasoningTool.tool_name)],
48+
tools=[pydantic_function_tool(self.ReasoningTool, name=self.ReasoningTool.tool_name)],
4549
tool_choice=self.tool_choice,
4650
**self.config.llm.to_openai_client_kwargs(),
4751
) as stream:

sgr_agent_core/base_tool.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44
import logging
5-
from typing import TYPE_CHECKING, ClassVar, Self
5+
from typing import TYPE_CHECKING, ClassVar, Self, TypeVar
66

77
from fastmcp import Client
88
from pydantic import BaseModel
@@ -54,6 +54,9 @@ class SystemBaseTool(BaseTool):
5454
isSystemTool: ClassVar[bool] = True
5555

5656

57+
ReasoningToolStubType = TypeVar("ReasoningToolStubType", bound=SystemBaseTool)
58+
59+
5760
class MCPBaseTool(BaseTool):
5861
"""Base model for MCP Tool schema."""
5962

sgr_agent_core/next_step_tool.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,32 @@
88

99
from pydantic import BaseModel, Field, create_model
1010

11-
from sgr_agent_core.base_tool import BaseTool
11+
from sgr_agent_core.base_tool import BaseTool, SystemBaseTool
1212
from sgr_agent_core.tools.reasoning_tool import ReasoningTool
1313

1414
logger = logging.getLogger(__name__)
1515

1616
T = TypeVar("T", bound=BaseTool)
1717

1818

19-
class NextStepToolStub(ReasoningTool, ABC):
20-
"""SGR Core - Determines the next reasoning step with adaptive planning, choosing appropriate tool
21-
(!) Stub class for correct autocomplete. Use NextStepToolsBuilder"""
19+
class NextStepToolStub(SystemBaseTool, ABC):
20+
"""SGR Core - Determines the next reasoning step with adaptive planning, choosing appropriate tool.
21+
22+
(!) Stub class for correct autocomplete. Use NextStepToolsBuilder.
23+
The actual base reasoning class is injected at build time.
24+
"""
2225

2326
function: T = Field(description="Select the appropriate tool for the next step")
2427

2528

26-
class ToolNameSelectorStub(ReasoningTool, ABC):
27-
"""Stub class for tool name selection that inherits from ReasoningTool.
29+
class ToolNameSelectorStub(SystemBaseTool, ABC):
30+
"""Stub class for tool name selection.
2831
2932
Used by IronAgent to select tool name as part of reasoning phase.
3033
(!) Stub class for correct autocomplete. Use
31-
NextStepToolsBuilder.build_NextStepToolSelector
34+
NextStepToolsBuilder.build_NextStepToolSelector with
35+
base_reasoning_cls. The actual base reasoning class is injected at
36+
build time.
3237
"""
3338

3439
function_name_choice: str = Field(description="Select the name of the tool to use")
@@ -70,21 +75,43 @@ def _create_tool_types_union(cls, tools_list: list[Type[T]]) -> Type:
7075
return Annotated[union, Field()]
7176

7277
@classmethod
73-
def build_NextStepTools(cls, tools_list: list[Type[T]]) -> Type[NextStepToolStub]: # noqa
74-
"""Build a model with all NextStepTool args."""
78+
def build_NextStepTools( # noqa
79+
cls,
80+
tools_list: list[Type[T]],
81+
base_reasoning_cls: type[ReasoningTool] = ReasoningTool,
82+
) -> Type[NextStepToolStub]:
83+
"""Build a model with all NextStepTool args.
84+
85+
Args:
86+
tools_list: List of tool classes to include in the union.
87+
base_reasoning_cls: Pydantic model class used as the base for the
88+
reasoning schema sent to the LLM via Structured Output. Defaults
89+
to ReasoningTool. Pass a subclass to extend or override the
90+
reasoning schema.
91+
"""
7592
return create_model(
7693
"NextStepTools",
77-
__base__=NextStepToolStub,
94+
__base__=base_reasoning_cls,
7895
function=(
7996
cls._create_tool_types_union(tools_list),
8097
Field(description="Select and fill parameters of the appropriate tool for the next step"),
8198
),
8299
)
83100

84101
@classmethod
85-
def build_NextStepToolSelector(cls, tools_list: list[Type[T]]) -> Type[ToolNameSelectorStub]:
86-
"""Build a model for selecting tool name."""
87-
# Extract tool names and descriptions
102+
def build_NextStepToolSelector( # noqa
103+
cls,
104+
tools_list: list[Type[T]],
105+
base_reasoning_cls: type[SystemBaseTool] = ReasoningTool,
106+
) -> Type[ToolNameSelectorStub]:
107+
"""Build a model for selecting tool name.
108+
109+
Args:
110+
tools_list: List of tool classes whose names form the allowed choices.
111+
base_reasoning_cls: Pydantic model class used as the base for the
112+
reasoning schema. Defaults to ReasoningTool. Pass a subclass to
113+
extend or override the reasoning schema.
114+
"""
88115
tool_names = [tool.tool_name for tool in tools_list]
89116

90117
if len(tool_names) == 1:
@@ -98,7 +125,7 @@ def build_NextStepToolSelector(cls, tools_list: list[Type[T]]) -> Type[ToolNameS
98125
# Create model dynamically, inheriting from ToolNameSelectorStub (which inherits from ReasoningTool)
99126
model_class = create_model(
100127
"NextStepToolSelector",
101-
__base__=ToolNameSelectorStub,
128+
__base__=base_reasoning_cls,
102129
function_name_choice=(literal_type, Field(description="Choose the name for the best tool to use")),
103130
)
104131
model_class.tool_name = "nextsteptoolselector" # type: ignore

sgr_agent_core/tools/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from sgr_agent_core.base_tool import BaseTool, MCPBaseTool, SystemBaseTool
1+
from sgr_agent_core.base_tool import BaseTool, MCPBaseTool, ReasoningToolStubType, SystemBaseTool
22
from sgr_agent_core.next_step_tool import (
33
NextStepToolsBuilder,
44
NextStepToolStub,
@@ -19,6 +19,7 @@
1919
"BaseTool",
2020
"MCPBaseTool",
2121
"SystemBaseTool",
22+
"ReasoningToolStubType",
2223
"NextStepToolStub",
2324
"ToolNameSelectorStub",
2425
"NextStepToolsBuilder",

tests/test_agent_e2e.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,3 +364,129 @@ async def test_sgr_tool_calling_agent_full_execution_cycle():
364364

365365
assert result is not None
366366
_assert_agent_completed(agent)
367+
368+
369+
@pytest.mark.asyncio
370+
async def test_sgr_tool_calling_agent_custom_reasoning_tool_is_used():
371+
"""Custom ReasoningTool is actually passed to OpenAI in _reasoning_phase.
372+
373+
Verifies that self.ReasoningTool is forwarded to
374+
pydantic_function_tool() instead of the hardcoded base
375+
ReasoningTool.
376+
"""
377+
from pydantic import Field as PydanticField
378+
379+
class CustomReasoningTool(ReasoningTool):
380+
confidence: float = PydanticField(default=0.5, description="Confidence in the decision")
381+
382+
captured_reasoning_tool_names: list[str] = []
383+
384+
reasoning_instance = CustomReasoningTool(
385+
reasoning_steps=["Analyze", "Decide"],
386+
current_situation="Test situation",
387+
plan_status="On track",
388+
enough_data=False,
389+
remaining_steps=["Finalize"],
390+
task_completed=False,
391+
confidence=0.9,
392+
)
393+
final_answer_instance = FinalAnswerTool(
394+
reasoning="Done",
395+
completed_steps=["Step 1"],
396+
answer="Final answer to the research task",
397+
status=AgentStatesEnum.COMPLETED,
398+
)
399+
400+
client = Mock(spec=AsyncOpenAI)
401+
402+
def mock_stream(**kwargs):
403+
tools_param = kwargs.get("tools", [])
404+
tool_name = None
405+
if tools_param and isinstance(tools_param, list) and isinstance(tools_param[0], dict):
406+
tool_name = tools_param[0].get("function", {}).get("name")
407+
408+
if tool_name == CustomReasoningTool.tool_name:
409+
captured_reasoning_tool_names.append(tool_name)
410+
return MockStream({"content": None, "tool_calls": [_create_tool_call(reasoning_instance, "call-r")]})
411+
412+
return MockStream({"content": None, "tool_calls": [_create_tool_call(final_answer_instance, "call-a")]})
413+
414+
client.chat.completions.stream = Mock(side_effect=mock_stream)
415+
416+
agent = SGRToolCallingAgent(
417+
task_messages=[{"role": "user", "content": "Test task"}],
418+
openai_client=client,
419+
agent_config=_create_test_agent_config(),
420+
toolkit=[FinalAnswerTool],
421+
reasoning_tool_cls=CustomReasoningTool,
422+
)
423+
424+
result = await agent.execute()
425+
426+
assert result is not None
427+
assert agent._context.state == AgentStatesEnum.COMPLETED
428+
assert len(captured_reasoning_tool_names) >= 1, "Custom ReasoningTool was never passed to OpenAI"
429+
assert captured_reasoning_tool_names[0] == CustomReasoningTool.tool_name
430+
431+
432+
@pytest.mark.asyncio
433+
async def test_sgr_agent_custom_reasoning_tool_is_used():
434+
"""Custom ReasoningTool is used as SO base in SGRAgent._reasoning_phase.
435+
436+
Verifies that response_format passed to OpenAI is built on top of
437+
the custom ReasoningTool subclass rather than the default one.
438+
"""
439+
from pydantic import Field as PydanticField
440+
441+
class CustomReasoningTool(ReasoningTool):
442+
confidence: float = PydanticField(default=0.5, description="Confidence in the decision")
443+
444+
captured_response_formats: list[type] = []
445+
446+
client = Mock(spec=AsyncOpenAI)
447+
448+
def mock_stream(**kwargs):
449+
response_format = kwargs.get("response_format")
450+
if response_format is not None:
451+
captured_response_formats.append(response_format)
452+
453+
NextStepTools = NextStepToolsBuilder.build_NextStepTools(
454+
[FinalAnswerTool],
455+
base_reasoning_cls=CustomReasoningTool,
456+
)
457+
response = NextStepTools(
458+
reasoning_steps=["Step 1", "Step 2"],
459+
current_situation="Test",
460+
plan_status="Ok",
461+
enough_data=True,
462+
remaining_steps=["Finalize"],
463+
task_completed=True,
464+
confidence=0.8,
465+
function={
466+
"tool_name_discriminator": FinalAnswerTool.tool_name,
467+
"reasoning": "Done",
468+
"completed_steps": ["Step 1"],
469+
"answer": "Final answer to the research task",
470+
"status": AgentStatesEnum.COMPLETED,
471+
},
472+
)
473+
return MockStream({"parsed": response})
474+
475+
client.chat.completions.stream = Mock(side_effect=mock_stream)
476+
477+
agent = SGRAgent(
478+
task_messages=[{"role": "user", "content": "Test task"}],
479+
openai_client=client,
480+
agent_config=_create_test_agent_config(),
481+
toolkit=[FinalAnswerTool],
482+
reasoning_tool_cls=CustomReasoningTool,
483+
)
484+
485+
result = await agent.execute()
486+
487+
assert result is not None
488+
assert agent._context.state == AgentStatesEnum.COMPLETED
489+
assert len(captured_response_formats) >= 1, "response_format was never passed to OpenAI"
490+
assert issubclass(
491+
captured_response_formats[0], CustomReasoningTool
492+
), f"response_format {captured_response_formats[0]} is not a subclass of CustomReasoningTool"

0 commit comments

Comments
 (0)