@@ -364,3 +364,129 @@ async def test_sgr_tool_calling_agent_full_execution_cycle():
364364
365365 assert result is not None
366366 _assert_agent_completed (agent )
367+
368+
369+ @pytest .mark .asyncio
370+ async def test_sgr_tool_calling_agent_custom_reasoning_tool_is_used ():
371+ """Custom ReasoningTool is actually passed to OpenAI in _reasoning_phase.
372+
373+ Verifies that self.ReasoningTool is forwarded to
374+ pydantic_function_tool() instead of the hardcoded base
375+ ReasoningTool.
376+ """
377+ from pydantic import Field as PydanticField
378+
379+ class CustomReasoningTool (ReasoningTool ):
380+ confidence : float = PydanticField (default = 0.5 , description = "Confidence in the decision" )
381+
382+ captured_reasoning_tool_names : list [str ] = []
383+
384+ reasoning_instance = CustomReasoningTool (
385+ reasoning_steps = ["Analyze" , "Decide" ],
386+ current_situation = "Test situation" ,
387+ plan_status = "On track" ,
388+ enough_data = False ,
389+ remaining_steps = ["Finalize" ],
390+ task_completed = False ,
391+ confidence = 0.9 ,
392+ )
393+ final_answer_instance = FinalAnswerTool (
394+ reasoning = "Done" ,
395+ completed_steps = ["Step 1" ],
396+ answer = "Final answer to the research task" ,
397+ status = AgentStatesEnum .COMPLETED ,
398+ )
399+
400+ client = Mock (spec = AsyncOpenAI )
401+
402+ def mock_stream (** kwargs ):
403+ tools_param = kwargs .get ("tools" , [])
404+ tool_name = None
405+ if tools_param and isinstance (tools_param , list ) and isinstance (tools_param [0 ], dict ):
406+ tool_name = tools_param [0 ].get ("function" , {}).get ("name" )
407+
408+ if tool_name == CustomReasoningTool .tool_name :
409+ captured_reasoning_tool_names .append (tool_name )
410+ return MockStream ({"content" : None , "tool_calls" : [_create_tool_call (reasoning_instance , "call-r" )]})
411+
412+ return MockStream ({"content" : None , "tool_calls" : [_create_tool_call (final_answer_instance , "call-a" )]})
413+
414+ client .chat .completions .stream = Mock (side_effect = mock_stream )
415+
416+ agent = SGRToolCallingAgent (
417+ task_messages = [{"role" : "user" , "content" : "Test task" }],
418+ openai_client = client ,
419+ agent_config = _create_test_agent_config (),
420+ toolkit = [FinalAnswerTool ],
421+ reasoning_tool_cls = CustomReasoningTool ,
422+ )
423+
424+ result = await agent .execute ()
425+
426+ assert result is not None
427+ assert agent ._context .state == AgentStatesEnum .COMPLETED
428+ assert len (captured_reasoning_tool_names ) >= 1 , "Custom ReasoningTool was never passed to OpenAI"
429+ assert captured_reasoning_tool_names [0 ] == CustomReasoningTool .tool_name
430+
431+
432+ @pytest .mark .asyncio
433+ async def test_sgr_agent_custom_reasoning_tool_is_used ():
434+ """Custom ReasoningTool is used as SO base in SGRAgent._reasoning_phase.
435+
436+ Verifies that response_format passed to OpenAI is built on top of
437+ the custom ReasoningTool subclass rather than the default one.
438+ """
439+ from pydantic import Field as PydanticField
440+
441+ class CustomReasoningTool (ReasoningTool ):
442+ confidence : float = PydanticField (default = 0.5 , description = "Confidence in the decision" )
443+
444+ captured_response_formats : list [type ] = []
445+
446+ client = Mock (spec = AsyncOpenAI )
447+
448+ def mock_stream (** kwargs ):
449+ response_format = kwargs .get ("response_format" )
450+ if response_format is not None :
451+ captured_response_formats .append (response_format )
452+
453+ NextStepTools = NextStepToolsBuilder .build_NextStepTools (
454+ [FinalAnswerTool ],
455+ base_reasoning_cls = CustomReasoningTool ,
456+ )
457+ response = NextStepTools (
458+ reasoning_steps = ["Step 1" , "Step 2" ],
459+ current_situation = "Test" ,
460+ plan_status = "Ok" ,
461+ enough_data = True ,
462+ remaining_steps = ["Finalize" ],
463+ task_completed = True ,
464+ confidence = 0.8 ,
465+ function = {
466+ "tool_name_discriminator" : FinalAnswerTool .tool_name ,
467+ "reasoning" : "Done" ,
468+ "completed_steps" : ["Step 1" ],
469+ "answer" : "Final answer to the research task" ,
470+ "status" : AgentStatesEnum .COMPLETED ,
471+ },
472+ )
473+ return MockStream ({"parsed" : response })
474+
475+ client .chat .completions .stream = Mock (side_effect = mock_stream )
476+
477+ agent = SGRAgent (
478+ task_messages = [{"role" : "user" , "content" : "Test task" }],
479+ openai_client = client ,
480+ agent_config = _create_test_agent_config (),
481+ toolkit = [FinalAnswerTool ],
482+ reasoning_tool_cls = CustomReasoningTool ,
483+ )
484+
485+ result = await agent .execute ()
486+
487+ assert result is not None
488+ assert agent ._context .state == AgentStatesEnum .COMPLETED
489+ assert len (captured_response_formats ) >= 1 , "response_format was never passed to OpenAI"
490+ assert issubclass (
491+ captured_response_formats [0 ], CustomReasoningTool
492+ ), f"response_format { captured_response_formats [0 ]} is not a subclass of CustomReasoningTool"
0 commit comments