Skip to content

Commit 4060f17

Browse files
committed
feat(structured-outputs): typed runAs[T] in agent abstraction
1 parent 97340ee commit 4060f17

8 files changed

Lines changed: 183 additions & 13 deletions

File tree

claude/src/main/scala/sttp/ai/claude/ClaudeSyncClient.scala

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,6 @@ class ClaudeSyncClient(config: ClaudeConfig, backend: SyncBackend = DefaultSyncB
1818
case Right(response) => response
1919
}
2020

21-
/** Creates a typed message response. The response schema is derived from `T` via Tapir and set on the request as a structured-output
22-
* format (unless one is already set), and the response's text content is parsed back into `T` via uPickle.
23-
*
24-
* @param request
25-
* Message request. If [[MessageRequest.usesStructuredOutput]] is false, a JSON schema for `T` is set automatically.
26-
* @tparam T
27-
* The return type, which must have both a [[TapirSchema]] and a [[SnakePickle.Reader]] available.
28-
*/
2921
def createMessageAs[T: TapirSchema: SnakePickle.Reader](request: MessageRequest): T = {
3022
val withSchema =
3123
if (request.usesStructuredOutput) request

claude/src/test/scala/sttp/ai/claude/integration/ClaudeAgentIntegrationSpec.scala

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class ClaudeAgentIntegrationSpec extends AgentIntegrationSpecBase {
1717
val config = ClaudeConfig.fromEnv
1818
val client = ClaudeClient(config)
1919
val agentConfig = AgentConfig(maxIterations = maxIterations, userTools = tools).right.get
20-
val allTools = agentConfig.userTools ++ AgentConfig.systemTools
20+
val allTools = agentConfig.userTools ++ AgentConfig.systemTools(agentConfig)
2121
val agentBackend = new ClaudeAgentBackend[Identity](
2222
client,
2323
"claude-haiku-4-5-20251001",
@@ -26,4 +26,17 @@ class ClaudeAgentIntegrationSpec extends AgentIntegrationSpecBase {
2626
)(IdentityMonad)
2727
Agent(agentBackend, agentConfig)(IdentityMonad)
2828
}
29+
30+
override def createTypedAgent[T](
31+
maxIterations: Int,
32+
tools: Seq[AgentTool[_]],
33+
responseSchema: ResponseSchema[T]
34+
): Agent[Identity] = {
35+
val agentConfig = AgentConfig(
36+
maxIterations = maxIterations,
37+
userTools = tools,
38+
responseSchema = Some(responseSchema)
39+
).right.get
40+
ClaudeAgent.synchronous(ClaudeConfig.fromEnv, "claude-haiku-4-5-20251001", agentConfig)
41+
}
2942
}

core/src/main/scala/sttp/ai/core/agent/Agent.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
package sttp.ai.core.agent
22

3+
import sttp.ai.core.json.SnakePickle
34
import sttp.client4.Backend
45
import sttp.monad.MonadError
5-
import sttp.ai.core.json.SnakePickle
66

77
class Agent[F[_]](
88
agentBackend: AgentBackend[F],
@@ -96,6 +96,16 @@ class Agent[F[_]](
9696
loop(initialHistory, 0, Seq.empty)
9797
}
9898

99+
def runAs[T](
100+
initialPrompt: String
101+
)(backend: Backend[F])(implicit r: SnakePickle.Reader[T]): F[AgentResult[Either[AgentParseError, T]]] =
102+
monad.map(run(initialPrompt)(backend)) { res =>
103+
val parsed: Either[AgentParseError, T] =
104+
try Right(SnakePickle.read[T](res.finalAnswer))
105+
catch { case e: Exception => Left(AgentParseError(res.finalAnswer, e)) }
106+
AgentResult(parsed, res.iterations, res.toolCalls, res.finishReason)
107+
}
108+
99109
private def executeTool[T](tool: AgentTool[T], toolCall: ToolCall): Either[String, String] = {
100110
val parseResult: Either[Exception, T] =
101111
try

core/src/main/scala/sttp/ai/core/agent/AgentResult.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,21 @@ object FinishReason {
1010
case class Error(message: String) extends FinishReason
1111
}
1212

13-
case class ToolCallRecord(
13+
final case class ToolCallRecord(
1414
toolName: String,
1515
input: String,
1616
output: String,
1717
iteration: Int
1818
)
1919

20-
case class AgentResult[T](
20+
final case class AgentResult[T](
2121
finalAnswer: T,
2222
iterations: Int,
2323
toolCalls: Seq[ToolCallRecord],
2424
finishReason: FinishReason
2525
)
26+
27+
final case class AgentParseError(
28+
rawAnswer: String,
29+
cause: Throwable
30+
)

core/src/test/scala/sttp/ai/core/agent/AgentSpec.scala

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,63 @@ class AgentSpec extends AnyFlatSpec with Matchers {
604604
cfg.systemPrompt.getOrElse("") should include("JSON object matching the provided input schema")
605605
}
606606

607+
"Agent.runAs[T]" should "return Right(T) when the model emits a well-formed structured payload" in {
608+
val cfg = AgentConfig(
609+
maxIterations = 3,
610+
responseSchema = Some(ResponseSchema.derived[WeatherSummary]())
611+
).toOption.get
612+
613+
val (loop, _) = createLoop(
614+
Seq(
615+
AgentResponse(
616+
"",
617+
Seq(
618+
ToolCall(
619+
id = "call_1",
620+
toolName = "finish",
621+
input = """{"city":"Krakow","temp_c":12.0,"conditions":"sunny"}"""
622+
)
623+
),
624+
StopReason.ToolUse
625+
)
626+
),
627+
cfg
628+
)
629+
630+
val result = loop.runAs[WeatherSummary]("What's the weather?")(backend)
631+
632+
result.finishReason shouldBe FinishReason.ToolFinish: Unit
633+
result.iterations shouldBe 1: Unit
634+
result.finalAnswer shouldBe Right(WeatherSummary("Krakow", 12.0, "sunny"))
635+
}
636+
637+
it should "return Left(AgentParseError) preserving the trace when the answer can't be parsed as T" in {
638+
val cfg = AgentConfig(
639+
maxIterations = 3,
640+
responseSchema = Some(ResponseSchema.derived[WeatherSummary]())
641+
).toOption.get
642+
643+
val (loop, _) = createLoop(
644+
Seq(
645+
AgentResponse(
646+
"",
647+
Seq(ToolCall(id = "call_1", toolName = "finish", input = """{"wrong":"shape"}""")),
648+
StopReason.ToolUse
649+
)
650+
),
651+
cfg
652+
)
653+
654+
val result = loop.runAs[WeatherSummary]("What's the weather?")(backend)
655+
656+
result.iterations shouldBe 1: Unit
657+
result.finishReason shouldBe FinishReason.ToolFinish: Unit
658+
result.finalAnswer.isLeft shouldBe true: Unit
659+
val err = result.finalAnswer.left.toOption.get
660+
err.rawAnswer should include("Invalid arguments"): Unit
661+
err.cause should not be null
662+
}
663+
607664
it should "terminate with the parse-error message as finalAnswer when finish is called with a malformed payload" in {
608665
val cfg = AgentConfig(
609666
maxIterations = 3,

core/src/test/scala/sttp/ai/core/agent/integration/AgentIntegrationSpecBase.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@ abstract class AgentIntegrationSpecBase extends AnyFlatSpec with Matchers {
1414
def apiKeyEnvVar: String
1515
def createAgent(maxIterations: Int, tools: Seq[AgentTool[_]]): Agent[Identity]
1616

17+
def createTypedAgent[T](
18+
maxIterations: Int,
19+
tools: Seq[AgentTool[_]],
20+
responseSchema: ResponseSchema[T]
21+
): Agent[Identity] =
22+
cancel(s"$providerName typed agent factory not implemented for this spec")
23+
1724
protected val maybeApiKey: Option[String] = sys.env.get(apiKeyEnvVar)
1825

1926
case class CalculatorInput(operation: String, a: Double, b: Double)
@@ -151,4 +158,28 @@ abstract class AgentIntegrationSpecBase extends AnyFlatSpec with Matchers {
151158
result.iterations shouldBe 1
152159
()
153160
}
161+
162+
case class TripSummary(weather: String, calculation: String, conclusion: String)
163+
implicit val tripSummaryRW: SnakePickle.ReadWriter[TripSummary] = SnakePickle.macroRW
164+
implicit val tripSummarySchema: Schema[TripSummary] = Schema.derived
165+
166+
it should "return a typed structured answer via runAs[T]" in {
167+
if (maybeApiKey.isEmpty) cancel(s"$apiKeyEnvVar not defined - skipping integration test")
168+
val backend = DefaultSyncBackend()
169+
try {
170+
val rs = ResponseSchema.derived[TripSummary]()
171+
val agent = createTypedAgent(maxIterations = 6, tools = Seq(weatherTool, calculatorTool), responseSchema = rs)
172+
173+
val result = agent.runAs[TripSummary](
174+
"What's the weather in Paris? Also, what is 15 multiplied by 3? Then summarise both into the structured response."
175+
)(backend)
176+
177+
result.finishReason shouldBe FinishReason.ToolFinish: Unit
178+
result.finalAnswer.isRight shouldBe true: Unit
179+
val summary = result.finalAnswer.toOption.get
180+
summary.weather should not be empty: Unit
181+
summary.calculation should not be empty: Unit
182+
summary.conclusion should not be empty
183+
} finally backend.close()
184+
}
154185
}

examples/src/main/scala/examples/AgentLoopExample.scala

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,51 @@ object AgentLoopExample extends App {
7474
}
7575
finally backend.close()
7676
}
77+
78+
object TypedAgentLoopExample extends App {
79+
80+
case class TripSummary(weatherSummary: String, calculation: String, conclusion: String) derives SnakePickle.ReadWriter, Schema
81+
82+
case class WeatherInput(location: String, unit: Option[String]) derives SnakePickle.ReadWriter, Schema
83+
84+
val weatherTool = AgentTool.fromFunction("get_weather", "Get the current weather for a location") { (input: WeatherInput) =>
85+
val unit = input.unit.getOrElse("celsius")
86+
s"The weather in ${input.location} is 22°${if (unit == "celsius") "C" else "F"}, sunny"
87+
}
88+
89+
case class CalculatorInput(operation: String, a: Double, b: Double) derives SnakePickle.ReadWriter, Schema
90+
91+
val calculatorTool = AgentTool.fromFunction("calculate", "Perform a mathematical calculation") { (input: CalculatorInput) =>
92+
val result = input.operation match {
93+
case "add" => input.a + input.b
94+
case "subtract" => input.a - input.b
95+
case "multiply" => input.a * input.b
96+
case "divide" => if (input.b != 0) input.a / input.b else Double.NaN
97+
case _ => 0.0
98+
}
99+
s"${input.a} ${input.operation} ${input.b} = $result"
100+
}
101+
102+
val cfg = AgentConfig(
103+
maxIterations = 8,
104+
userTools = Seq(weatherTool, calculatorTool),
105+
responseSchema = Some(ResponseSchema.derived[TripSummary]())
106+
).toOption.get
107+
108+
val openai = OpenAI.fromEnv
109+
val backend = DefaultSyncBackend()
110+
try {
111+
val agent = OpenAIAgent.synchronous(openai, "gpt-4o-mini", cfg)
112+
val prompt = "What's the weather in Paris? Also, what is 15 multiplied by 23? Provide a complete answer."
113+
114+
agent.runAs[TripSummary](prompt)(backend).finalAnswer match {
115+
case Right(summary) =>
116+
println(s"weather: ${summary.weatherSummary}")
117+
println(s"calculation: ${summary.calculation}")
118+
println(s"conclusion: ${summary.conclusion}")
119+
case Left(err) =>
120+
println(s"Failed to parse structured answer: ${err.cause.getMessage}")
121+
println(s"Raw answer was: ${err.rawAnswer}")
122+
}
123+
} finally backend.close()
124+
}

openai/src/test/scala/sttp/ai/openai/integration/OpenAIAgentIntegrationSpec.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class OpenAIAgentIntegrationSpec extends AgentIntegrationSpecBase {
1515
override def createAgent(maxIterations: Int, tools: Seq[AgentTool[_]]): Agent[Identity] = {
1616
val openai = OpenAI.fromEnv
1717
val agentConfig = AgentConfig(maxIterations = maxIterations, userTools = tools).right.get
18-
val allTools = agentConfig.userTools ++ AgentConfig.systemTools
18+
val allTools = agentConfig.userTools ++ AgentConfig.systemTools(agentConfig)
1919
val agentBackend = new OpenAIAgentBackend[Identity](
2020
openai,
2121
"gpt-4o-mini",
@@ -24,4 +24,18 @@ class OpenAIAgentIntegrationSpec extends AgentIntegrationSpecBase {
2424
)(IdentityMonad)
2525
Agent(agentBackend, agentConfig)(IdentityMonad)
2626
}
27+
28+
override def createTypedAgent[T](
29+
maxIterations: Int,
30+
tools: Seq[AgentTool[_]],
31+
responseSchema: ResponseSchema[T]
32+
): Agent[Identity] = {
33+
val openai = OpenAI.fromEnv
34+
val agentConfig = AgentConfig(
35+
maxIterations = maxIterations,
36+
userTools = tools,
37+
responseSchema = Some(responseSchema)
38+
).right.get
39+
OpenAIAgent.synchronous(openai, "gpt-4o-mini", agentConfig)
40+
}
2741
}

0 commit comments

Comments
 (0)