Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
229 changes: 220 additions & 9 deletions itk/agents/go/v03/main.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package main

import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"bytes"
"flag"
"fmt"
"io"
Expand All @@ -14,6 +14,7 @@ import (
"os"
"os/signal"
"strings"
"sync"
"syscall"
"time"

Expand All @@ -34,12 +35,32 @@ import (
"golang.org/x/sync/errgroup"
)

func shouldHold(inst *pb.Instruction) bool {
if inst.GetReturnResponse() != nil && inst.GetReturnResponse().HoldTask {
return true
}
if inst.GetSteps() != nil {
for _, step := range inst.GetSteps().Instructions {
if shouldHold(step) {
return true
}
}
}
return false
}

type V03AgentExecutor struct {
cancels sync.Map
}

func (e *V03AgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestContext, queue eventqueue.Queue) error {
log.Info(ctx, "Executing task", "taskId", reqCtx.Message.ID)

ctx, cancel := context.WithCancel(ctx)
e.cancels.Store(reqCtx.TaskID, cancel)
defer e.cancels.Delete(reqCtx.TaskID)
defer cancel()

if reqCtx.StoredTask == nil {
if err := queue.Write(ctx, a2a.NewSubmittedTask(reqCtx, reqCtx.Message)); err != nil {
return err
Expand Down Expand Up @@ -80,20 +101,102 @@ func (e *V03AgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCo

// 2. Handle Instruction
results, err := e.handleInstruction(ctx, reqCtx, &instruction)
log.Info(ctx, "handleInstruction results", "results", results)
if err != nil {
log.Error(ctx, "Error handling instruction", err)
if err := queue.Write(ctx, a2a.NewStatusUpdateEvent(reqCtx, a2a.TaskStateFailed, nil)); err != nil {
log.Error(ctx, "Failed to write status update", err)
}
return queue.Write(ctx, a2a.NewMessageForTask(a2a.MessageRoleAgent, reqCtx, a2a.TextPart{Text: fmt.Sprintf("Execution Error: %v", err)}))
return err
}

// 3. Return response
response := strings.Join(results, "\n")
msg := a2a.NewMessageForTask(a2a.MessageRoleAgent, reqCtx, a2a.TextPart{Text: response})
event := a2a.NewStatusUpdateEvent(reqCtx, a2a.TaskStateCompleted, msg)
event.Final = true
return queue.Write(ctx, event)

if shouldHold(&instruction) {
log.Info(ctx, "Holding task as requested", "taskId", reqCtx.Message.ID)

// First emitted event: the actual response
// Emitted event: response + task-finished
log.Info(ctx, "Emitting response and task-finished", "taskId", reqCtx.Message.ID)
finnishedMsg := a2a.NewMessageForTask(a2a.MessageRoleAgent, reqCtx, a2a.TextPart{Text: response + "\ntask-finished"})

if err := queue.Write(ctx, a2a.NewStatusUpdateEvent(reqCtx, a2a.TaskStateWorking, finnishedMsg)); err != nil {
return err
}

select {
case <-ctx.Done():
log.Info(ctx, "Task cancelled during sleep", "taskId", reqCtx.Message.ID)
event := a2a.NewStatusUpdateEvent(reqCtx, a2a.TaskStateCanceled, nil)
event.Final = true
queue.Write(context.Background(), event)
return nil
case <-time.After(2 * time.Second):
}

select {
case <-ctx.Done():
log.Info(ctx, "Task cancelled during second sleep", "taskId", reqCtx.Message.ID)
event := a2a.NewStatusUpdateEvent(reqCtx, a2a.TaskStateCanceled, nil)
event.Final = true
queue.Write(context.Background(), event)
return nil
case <-time.After(2 * time.Second):
}

// Continue emitting "task-finished" every 2 seconds
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()

holdCtx, cancelHold := context.WithTimeout(ctx, 10*time.Second)
defer cancelHold()

for {
if holdCtx.Err() != nil {
if holdCtx.Err() == context.DeadlineExceeded {
log.Info(ctx, "Hold timeout, exiting hold loop", "taskId", reqCtx.Message.ID)
event := a2a.NewStatusUpdateEvent(reqCtx, a2a.TaskStateFailed, nil)
event.Final = true
queue.Write(context.Background(), event)
} else {
log.Info(ctx, "Task cancelled, exiting hold loop", "taskId", reqCtx.Message.ID)
event := a2a.NewStatusUpdateEvent(reqCtx, a2a.TaskStateCanceled, nil)
event.Final = true
queue.Write(context.Background(), event)
}
return nil
}
select {
case <-holdCtx.Done():
if holdCtx.Err() == context.DeadlineExceeded {
log.Info(ctx, "Hold timeout, exiting hold loop", "taskId", reqCtx.Message.ID)
} else {
log.Info(ctx, "Task cancelled, exiting hold loop", "taskId", reqCtx.Message.ID)
event := a2a.NewStatusUpdateEvent(reqCtx, a2a.TaskStateCanceled, nil)
event.Final = true
queue.Write(context.Background(), event)
}
return nil
case <-ticker.C:
log.Info(ctx, "Emitting periodic status update with response", "taskId", reqCtx.Message.ID)
bgCtx, cancelWrite := context.WithTimeout(context.Background(), 1*time.Second)
// In v0.3, re-subscribing creates a new child event queue that only receives new events.
// We must re-emit the message along with the status so that any client that
// re-subscribes immediately receives the latest task status and the response text.
err := queue.Write(bgCtx, a2a.NewStatusUpdateEvent(reqCtx, a2a.TaskStateWorking, finnishedMsg))
cancelWrite()
if err != nil {
log.Error(ctx, "Failed to write periodic update to queue", err)
}
}
}
} else {
msg := a2a.NewMessageForTask(a2a.MessageRoleAgent, reqCtx, a2a.TextPart{Text: response})
event := a2a.NewStatusUpdateEvent(reqCtx, a2a.TaskStateCompleted, msg)
event.Final = true
return queue.Write(ctx, event)
}
}

func (e *V03AgentExecutor) handleInstruction(ctx context.Context, reqCtx *a2asrv.RequestContext, inst *pb.Instruction) ([]string, error) {
Expand Down Expand Up @@ -146,7 +249,15 @@ func (e *V03AgentExecutor) handleInstruction(ctx context.Context, reqCtx *a2asrv
}

var responses []string
if call.Streaming {
if call.GetResubscribe() != nil {
if !call.Streaming {
return nil, fmt.Errorf("re-subscription requires streaming to be enabled")
}
responses, err = e.handleCallAgentWithResubscribe(ctx, client, wrappedMsg)
if err != nil {
return nil, err
}
} else if call.Streaming {
events := client.SendStreamingMessage(ctx, wrappedMsg)
for ev, err := range events {
if err != nil {
Expand Down Expand Up @@ -218,12 +329,112 @@ func (e *V03AgentExecutor) handleInstruction(ctx context.Context, reqCtx *a2asrv
}
}

func (e *V03AgentExecutor) handleCallAgentWithResubscribe(ctx context.Context, client *a2aclient.Client, req *a2a.MessageSendParams) ([]string, error) {
var responses []string
log.Info(ctx, "Executing re-subscribe behavior")

events := client.SendStreamingMessage(ctx, req)
var taskID a2a.TaskID
foundTask := false

for ev, err := range events {
if err != nil {
return nil, fmt.Errorf("streaming call failed before disconnect: %w", err)
}
log.Info(ctx, "Event before disconnect", "event", ev)

switch r := ev.(type) {
case *a2a.Task:
taskID = r.ID
foundTask = true
case *a2a.TaskStatusUpdateEvent:
taskID = r.TaskID
foundTask = true
}

if foundTask && taskID != "" {
break // Disconnect!
}
}

log.Info(ctx, "Disconnected from task. Now re-subscribing.", "taskID", taskID)

resubEvents := client.ResubscribeToTask(ctx, &a2a.TaskIDParams{ID: taskID})

var taskObj *a2a.Task
for ev, err := range resubEvents {
if err != nil {
return nil, fmt.Errorf("re-subscribe failed: %w", err)
}
log.Info(ctx, "Event after re-subscribe", "event", ev)

switch r := ev.(type) {
case *a2a.Task:
taskObj = r
case *a2a.TaskStatusUpdateEvent:
if r.Status.Message != nil {
for _, part := range r.Status.Message.Parts {
if textPart, ok := part.(a2a.TextPart); ok {
t := textPart.Text
t = strings.ReplaceAll(t, "task-finished", "")
responses = append(responses, t)

if strings.Contains(textPart.Text, "task-finished") {
log.Info(ctx, "Received task-finished after re-subscribe, breaking loop.")
goto EndLoop
}
}
}
}
}
}
EndLoop:

if len(responses) == 0 && taskObj != nil {
log.Info(ctx, "Responses empty after loop, reading from history.")
for _, msg := range taskObj.History {
if msg.Role == a2a.MessageRoleAgent {
for _, part := range msg.Parts {
if textPart, ok := part.(a2a.TextPart); ok {
t := textPart.Text
t = strings.ReplaceAll(t, "task-finished", "")
responses = append(responses, t)
}
}
}
}
}

log.Info(ctx, "Canceling task after retrieval.", "taskID", taskID)
_, err := client.CancelTask(ctx, &a2a.TaskIDParams{ID: taskID})
if err != nil {
log.Error(ctx, "Failed to cancel task", err, "taskID", taskID)
return responses, err
}

return responses, nil
}

func boolPtr(b bool) *bool {
return &b
}

func (e *V03AgentExecutor) Cancel(ctx context.Context, reqCtx *a2asrv.RequestContext, _ eventqueue.Queue) error {
log.Info(ctx, "Cancel requested", "taskId", reqCtx.Message.ID)
func (e *V03AgentExecutor) Cancel(ctx context.Context, reqCtx *a2asrv.RequestContext, queue eventqueue.Queue) error {
log.Info(ctx, "Cancel requested", "taskId", reqCtx.TaskID)

// Run in background to avoid blocking if queue is full
go func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err := queue.Write(bgCtx, a2a.NewStatusUpdateEvent(reqCtx, a2a.TaskStateCanceled, nil))
if err != nil {
slog.Error("Failed to write cancel status to queue", "error", err)
}
}()

if cancel, ok := e.cancels.Load(reqCtx.TaskID); ok {
cancel.(context.CancelFunc)()
}
return nil
}

Expand Down
Loading