Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 116 additions & 88 deletions livekit-agents/livekit/agents/voice/remote_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import time
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

from google.protobuf.timestamp_pb2 import Timestamp
Expand Down Expand Up @@ -55,68 +56,99 @@
TOPIC_SESSION_MESSAGES = "lk.agent.session"


@dataclass(frozen=True, slots=True)
class IncomingMessage:
message: agent_pb.AgentSessionMessage
sender_identity: str | None = None


class SessionTransport(ABC):
@abstractmethod
async def start(self) -> None: ...
@abstractmethod
async def send_message(self, msg: agent_pb.AgentSessionMessage) -> None: ...
async def send_message(
self,
msg: agent_pb.AgentSessionMessage,
*,
destination_identity: str | None = None,
) -> None:
"""Send a message. When ``destination_identity`` is None, broadcast to every authorized peer."""

@abstractmethod
async def close(self) -> None: ...
@abstractmethod
def __aiter__(self) -> AsyncIterator[agent_pb.AgentSessionMessage]: ...
def __aiter__(self) -> AsyncIterator[IncomingMessage]: ...
@abstractmethod
async def __anext__(self) -> agent_pb.AgentSessionMessage: ...
async def __anext__(self) -> IncomingMessage: ...


class RoomSessionTransport(SessionTransport):
def __init__(self, room: rtc.Room, remote_identity: str | None = None) -> None:
def __init__(self, room: rtc.Room) -> None:
self._room = room
self._remote_identity = remote_identity
self._recv_ch: utils.aio.Chan[agent_pb.AgentSessionMessage] = utils.aio.Chan()
self._recv_ch: utils.aio.Chan[IncomingMessage] = utils.aio.Chan()
self._handler_registered = False
self._tasks: set[asyncio.Task[None]] = set()

@property
def remote_identity(self) -> str | None:
return self._remote_identity

@remote_identity.setter
def remote_identity(self, value: str | None) -> None:
self._remote_identity = value

async def start(self) -> None:
if self._handler_registered:
return
self._room.register_byte_stream_handler(TOPIC_SESSION_MESSAGES, self._on_byte_stream)
self._handler_registered = True

def _can_manage(self, identity: str) -> bool:
participant = self._room.remote_participants.get(identity)
return participant is not None and participant.permissions.can_manage_agent_session

def _authorized_identities(self) -> list[str]:
return [
identity
for identity, participant in self._room.remote_participants.items()
if participant.permissions.can_manage_agent_session
]

def _on_byte_stream(self, reader: rtc.ByteStreamReader, participant_identity: str) -> None:
if self._remote_identity and participant_identity != self._remote_identity:
if not self._can_manage(participant_identity):
logger.debug(
"ignoring session message from participant without can_manage_agent_session grant",
extra={"participant": participant_identity},
)
return
task = asyncio.create_task(self._read_stream(reader))
task = asyncio.create_task(self._read_stream(reader, participant_identity))
self._tasks.add(task)
task.add_done_callback(self._tasks.discard)

async def _read_stream(self, reader: rtc.ByteStreamReader) -> None:
async def _read_stream(self, reader: rtc.ByteStreamReader, sender_identity: str) -> None:
try:
chunks: list[bytes] = []
async for chunk in reader:
chunks.append(chunk)
data = b"".join(chunks)
msg = agent_pb.AgentSessionMessage()
msg.ParseFromString(data)
self._recv_ch.send_nowait(msg)
self._recv_ch.send_nowait(IncomingMessage(message=msg, sender_identity=sender_identity))
except utils.aio.ChanClosed:
pass
except Exception as e:
logger.warning("failed to read binary stream message", exc_info=e)

async def send_message(self, msg: agent_pb.AgentSessionMessage) -> None:
async def send_message(
self,
msg: agent_pb.AgentSessionMessage,
*,
destination_identity: str | None = None,
) -> None:
if self._recv_ch.closed or not self._room.isconnected():
return
if destination_identity is not None:
if not self._can_manage(destination_identity):
return
dest = [destination_identity]
else:
dest = self._authorized_identities()
if not dest:
return
try:
data = msg.SerializeToString()
dest = [self._remote_identity] if self._remote_identity else None
writer = await self._room.local_participant.stream_bytes(
name=utils.shortuuid("AS_"),
topic=TOPIC_SESSION_MESSAGES,
Expand All @@ -140,10 +172,10 @@ async def close(self) -> None:
pass
self._handler_registered = False

def __aiter__(self) -> AsyncIterator[agent_pb.AgentSessionMessage]:
def __aiter__(self) -> AsyncIterator[IncomingMessage]:
return self._recv_ch.__aiter__()

async def __anext__(self) -> agent_pb.AgentSessionMessage:
async def __anext__(self) -> IncomingMessage:
return await self._recv_ch.__anext__()


Expand Down Expand Up @@ -171,7 +203,12 @@ async def start(self) -> None:
self._writer = writer
self._loop = asyncio.get_running_loop()

async def send_message(self, msg: agent_pb.AgentSessionMessage) -> None:
async def send_message(
self,
msg: agent_pb.AgentSessionMessage,
*,
destination_identity: str | None = None,
) -> None:
if self._closed or self._writer is None:
return
data = msg.SerializeToString()
Expand All @@ -198,10 +235,10 @@ async def close(self) -> None:
except (ConnectionError, OSError):
pass

def __aiter__(self) -> AsyncIterator[agent_pb.AgentSessionMessage]:
def __aiter__(self) -> AsyncIterator[IncomingMessage]:
return self

async def __anext__(self) -> agent_pb.AgentSessionMessage:
async def __anext__(self) -> IncomingMessage:
if self._closed or self._reader is None:
raise StopAsyncIteration

Expand All @@ -222,7 +259,7 @@ async def __anext__(self) -> agent_pb.AgentSessionMessage:

msg = agent_pb.AgentSessionMessage()
msg.ParseFromString(data)
return msg
return IncomingMessage(message=msg)


_AGENT_STATE_MAP: dict[AgentState, agent_pb.AgentState] = {
Expand Down Expand Up @@ -399,10 +436,11 @@ async def aclose(self) -> None:

async def _recv_loop(self) -> None:
try:
async for msg in self._transport:
async for incoming in self._transport:
msg = incoming.message
if msg.HasField("request"):
if self._session is not None:
self._tasks.create_task(self._handle_request_safe(msg.request))
self._tasks.create_task(self._handle_request_safe(incoming))
else:
msg_type = msg.WhichOneof("message")
if msg_type:
Expand Down Expand Up @@ -542,65 +580,64 @@ def _on_error(self, event: ErrorEvent) -> None:
)
)

async def _handle_request_safe(self, req: agent_pb.SessionRequest) -> None:
async def _send_response(
self, incoming: IncomingMessage, response: agent_pb.SessionResponse
) -> None:
response.request_id = incoming.message.request.request_id
await self._transport.send_message(
agent_pb.AgentSessionMessage(response=response),
destination_identity=incoming.sender_identity,
)

async def _handle_request_safe(self, incoming: IncomingMessage) -> None:
req = incoming.message.request
try:
await self._handle_request(req)
await self._handle_request(incoming)
except Exception:
logger.warning(
"error handling session request",
exc_info=True,
extra={"request_id": req.request_id},
)
try:
resp = agent_pb.AgentSessionMessage(
response=agent_pb.SessionResponse(
request_id=req.request_id,
error="internal error",
)
await self._send_response(
incoming, agent_pb.SessionResponse(error="internal error")
)
await self._transport.send_message(resp)
except Exception:
pass

async def _handle_request(self, req: agent_pb.SessionRequest) -> None:
async def _handle_request(self, incoming: IncomingMessage) -> None:
assert self._session is not None
req = incoming.message.request

if req.HasField("ping"):
resp = agent_pb.AgentSessionMessage(
response=agent_pb.SessionResponse(
request_id=req.request_id,
pong=agent_pb.SessionResponse.Pong(),
)
await self._send_response(
incoming, agent_pb.SessionResponse(pong=agent_pb.SessionResponse.Pong())
)
await self._transport.send_message(resp)

elif req.HasField("get_chat_history"):
items = [_chat_item_to_proto(item) for item in self._session.history.items]
resp = agent_pb.AgentSessionMessage(
response=agent_pb.SessionResponse(
request_id=req.request_id,
get_chat_history=agent_pb.SessionResponse.GetChatHistoryResponse(
items=items,
),
)
await self._send_response(
incoming,
agent_pb.SessionResponse(
get_chat_history=agent_pb.SessionResponse.GetChatHistoryResponse(items=items),
),
)
await self._transport.send_message(resp)

elif req.HasField("get_agent_info"):
agent = self._session.current_agent
items = [_chat_item_to_proto(item) for item in agent.chat_ctx.items]
resp = agent_pb.AgentSessionMessage(
response=agent_pb.SessionResponse(
request_id=req.request_id,
await self._send_response(
incoming,
agent_pb.SessionResponse(
get_agent_info=agent_pb.SessionResponse.GetAgentInfoResponse(
id=agent.id,
instructions=agent.instructions,
tools=_tool_names(agent.tools),
chat_ctx=items,
),
)
),
)
await self._transport.send_message(resp)

elif req.HasField("run_input"):
items_list: list[agent_pb.ChatContext.ChatItem] = []
Expand Down Expand Up @@ -629,42 +666,36 @@ async def _handle_request(self, req: agent_pb.SessionRequest) -> None:
except Exception as e:
error = str(e)

resp = agent_pb.AgentSessionMessage(
response=agent_pb.SessionResponse(
request_id=req.request_id,
await self._send_response(
incoming,
agent_pb.SessionResponse(
error=error,
run_input=agent_pb.SessionResponse.RunInputResponse(
items=items_list,
),
)
run_input=agent_pb.SessionResponse.RunInputResponse(items=items_list),
),
)
await self._transport.send_message(resp)

elif req.HasField("get_session_state"):
agent = self._session.current_agent
created_at = Timestamp()
started_at = self._session._started_at or time.time()
created_at.FromNanoseconds(int(started_at * 1e9))

resp = agent_pb.AgentSessionMessage(
response=agent_pb.SessionResponse(
request_id=req.request_id,
await self._send_response(
incoming,
agent_pb.SessionResponse(
get_session_state=agent_pb.SessionResponse.GetSessionStateResponse(
agent_state=_AGENT_STATE_MAP.get(
self._session.agent_state,
agent_pb.AS_IDLE,
self._session.agent_state, agent_pb.AS_IDLE
),
user_state=_USER_STATE_MAP.get(
self._session.user_state,
agent_pb.US_LISTENING,
self._session.user_state, agent_pb.US_LISTENING
),
agent_id=agent.id,
options=_serialize_options(self._session.options),
created_at=created_at,
),
)
),
)
await self._transport.send_message(resp)

elif req.HasField("get_rtc_stats"):
from google.protobuf.struct_pb2 import Struct
Expand All @@ -690,43 +721,40 @@ async def _handle_request(self, req: agent_pb.SessionRequest) -> None:
st.update(d)
subscriber_stats.append(st)

resp = agent_pb.AgentSessionMessage(
response=agent_pb.SessionResponse(
request_id=req.request_id,
await self._send_response(
incoming,
agent_pb.SessionResponse(
get_rtc_stats=agent_pb.SessionResponse.GetRTCStatsResponse(
publisher_stats=publisher_stats,
subscriber_stats=subscriber_stats,
),
)
),
)
await self._transport.send_message(resp)

elif req.HasField("get_session_usage"):
created_at = Timestamp()
created_at.FromNanoseconds(int(time.time() * 1e9))

resp = agent_pb.AgentSessionMessage(
response=agent_pb.SessionResponse(
request_id=req.request_id,
await self._send_response(
incoming,
agent_pb.SessionResponse(
get_session_usage=agent_pb.SessionResponse.GetSessionUsageResponse(
usage=_session_usage_to_proto(self._session.usage),
created_at=created_at,
),
)
),
)
await self._transport.send_message(resp)

elif req.HasField("get_framework_info"):
resp = agent_pb.AgentSessionMessage(
response=agent_pb.SessionResponse(
request_id=req.request_id,
await self._send_response(
incoming,
agent_pb.SessionResponse(
get_framework_info=agent_pb.SessionResponse.GetFrameworkInfoResponse(
sdk="python",
sdk_version=__version__,
),
)
),
)
await self._transport.send_message(resp)


def _session_usage_to_proto(usage: AgentSessionUsage) -> agent_pb.AgentSessionUsage:
Expand Down
Loading
Loading