Skip to content

Commit 1ceb6f4

Browse files
committed
feat: Add streaming support to Cohere instrumentation
Extends the Cohere instrumentation with streaming chat support. Adds CohereStreamWrapper and AsyncCohereStreamWrapper for chat_stream operations. Ref #3050
1 parent 323bb89 commit 1ceb6f4

3 files changed

Lines changed: 420 additions & 0 deletions

File tree

instrumentation-genai/opentelemetry-instrumentation-cohere/src/opentelemetry/instrumentation/cohere/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@
5656

5757
from .patch import (
5858
async_chat_create,
59+
async_chat_stream_create,
5960
chat_create,
61+
chat_stream_create,
6062
)
6163

6264

@@ -90,16 +92,32 @@ def _instrument(self, **kwargs):
9092
wrapper=chat_create(handler, content_mode),
9193
)
9294

95+
# Instrument sync V2Client.chat_stream
96+
wrap_function_wrapper(
97+
module="cohere.v2.client",
98+
name="V2Client.chat_stream",
99+
wrapper=chat_stream_create(handler, content_mode),
100+
)
101+
93102
# Instrument async AsyncV2Client.chat
94103
wrap_function_wrapper(
95104
module="cohere.v2.client",
96105
name="AsyncV2Client.chat",
97106
wrapper=async_chat_create(handler, content_mode),
98107
)
99108

109+
# Instrument async AsyncV2Client.chat_stream
110+
wrap_function_wrapper(
111+
module="cohere.v2.client",
112+
name="AsyncV2Client.chat_stream",
113+
wrapper=async_chat_stream_create(handler, content_mode),
114+
)
115+
100116

101117
def _uninstrument(self, **kwargs):
    """Remove the chat and chat_stream wrappers from both Cohere v2 clients."""
    import cohere.v2.client  # pylint: disable=import-outside-toplevel

    client_module = cohere.v2.client
    # Same order as the wrappers are listed: sync client first, then async;
    # for each client ``chat`` before ``chat_stream``.
    for client_name in ("V2Client", "AsyncV2Client"):
        for method_name in ("chat", "chat_stream"):
            unwrap(getattr(client_module, client_name), method_name)

instrumentation-genai/opentelemetry-instrumentation-cohere/src/opentelemetry/instrumentation/cohere/patch.py

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,15 @@
1818
from opentelemetry.util.genai.types import (
1919
ContentCapturingMode,
2020
Error,
21+
LLMInvocation,
22+
OutputMessage,
23+
Text,
2124
)
2225

2326
from .utils import (
27+
COHERE_PROVIDER_NAME,
2428
create_chat_invocation,
29+
map_finish_reason,
2530
set_response_attributes,
2631
)
2732

@@ -74,3 +79,271 @@ async def traced_method(wrapped, instance, args, kwargs):
7479
raise
7580

7681
return traced_method
82+
83+
84+
def chat_stream_create(
    handler: TelemetryHandler,
    content_capturing_mode: ContentCapturingMode,
):
    """Build the wrapper used to instrument ``V2Client.chat_stream``.

    Args:
        handler: Telemetry handler used to start and fail the LLM invocation.
        content_capturing_mode: Content is captured for every mode except
            ``ContentCapturingMode.NO_CONTENT``.

    Returns:
        A ``(wrapped, instance, args, kwargs)`` callable that starts an
        invocation, calls the wrapped method and returns its result wrapped
        in a ``CohereStreamWrapper``. If the call or the wrapping raises, the
        invocation is reported as failed and the exception is re-raised.
    """
    capture_content = content_capturing_mode != ContentCapturingMode.NO_CONTENT

    def traced_method(wrapped, instance, args, kwargs):
        request = create_chat_invocation(
            kwargs, instance, capture_content=capture_content
        )
        llm_invocation = handler.start_llm(request)
        try:
            raw_stream = wrapped(*args, **kwargs)
            return CohereStreamWrapper(
                raw_stream, handler, llm_invocation, capture_content
            )
        except Exception as exc:
            failure = Error(type=type(exc), message=str(exc))
            handler.fail_llm(llm_invocation, failure)
            raise

    return traced_method
105+
106+
107+
def async_chat_stream_create(
    handler: TelemetryHandler,
    content_capturing_mode: ContentCapturingMode,
):
    """Build the wrapper used to instrument ``AsyncV2Client.chat_stream``.

    Args:
        handler: Telemetry handler used to start and fail the LLM invocation.
        content_capturing_mode: Content is captured for every mode except
            ``ContentCapturingMode.NO_CONTENT``.

    Returns:
        A ``(wrapped, instance, args, kwargs)`` callable that starts an
        invocation, calls the wrapped method and returns its result wrapped
        in an ``AsyncCohereStreamWrapper``. If the call or the wrapping
        raises, the invocation is reported as failed and the exception is
        re-raised.
    """
    capture_content = content_capturing_mode != ContentCapturingMode.NO_CONTENT

    # This must be a plain ``def``, not ``async def``: the result of
    # ``wrapped(...)`` is never awaited here and is consumed through
    # ``__anext__`` by the stream wrapper, i.e. it is an async iterator rather
    # than a coroutine. An ``async def`` wrapper would make the patched method
    # return a coroutine, so ``async for event in client.chat_stream(...)``
    # would fail because a coroutine is not async-iterable.
    def traced_method(wrapped, instance, args, kwargs):
        invocation = handler.start_llm(
            create_chat_invocation(
                kwargs, instance, capture_content=capture_content
            )
        )
        try:
            result = wrapped(*args, **kwargs)
            return AsyncCohereStreamWrapper(
                result, handler, invocation, capture_content
            )
        except Exception as error:
            handler.fail_llm(
                invocation, Error(type=type(error), message=str(error))
            )
            raise

    return traced_method
128+
129+
130+
class CohereStreamWrapper:
    """Wraps a synchronous Cohere chat_stream iterator to capture telemetry.

    Events are passed through unchanged. While iterating, the wrapper records
    the response id, the message role, the streamed text parts, the finish
    reason and the usage. The invocation is ended exactly once: with
    ``stop_llm`` when the stream is exhausted, closed, or its ``with`` block
    exits cleanly, or with ``fail_llm`` when iteration or the ``with`` block
    raises.
    """

    def __init__(
        self,
        stream,
        handler: TelemetryHandler,
        invocation: LLMInvocation,
        capture_content: bool,
    ):
        self._stream = stream
        self._handler = handler
        self._invocation = invocation
        self._capture_content = capture_content
        self._content_parts: list[str] = []
        self._finish_reason = None
        self._response_id = None
        # Kept on the wrapper instead of in ``invocation.attributes`` so that
        # a private bookkeeping key can never be left on a failed invocation.
        self._role = None
        # Guards against ending the invocation more than once.
        self._finalized = False

    def __iter__(self):
        return self

    def __next__(self):
        try:
            event = next(self._stream)
            self._process_event(event)
            return event
        except StopIteration:
            self._finalize()
            raise
        except Exception as error:
            self._fail(type(error), str(error))
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            self._fail(exc_type, str(exc_val))
        # Ends the invocation when the consumer left the block before the
        # stream was exhausted; a no-op for telemetry if already ended.
        self.close()
        return False

    def close(self):
        """End the invocation if still open and close the underlying stream."""
        self._finalize()
        closer = getattr(self._stream, "close", None)
        if callable(closer):
            closer()

    def _fail(self, error_type, message):
        """Report the invocation as failed, unless it was already ended."""
        if self._finalized:
            return
        self._finalized = True
        self._handler.fail_llm(
            self._invocation,
            Error(type=error_type, message=message),
        )

    def _process_event(self, event):
        """Record the telemetry-relevant fields of one stream event."""
        event_type = getattr(event, "type", None)
        delta = getattr(event, "delta", None)

        if event_type == "message-start":
            message = getattr(delta, "message", None) if delta else None
            role = getattr(message, "role", None) if message else None
            if role:
                self._role = role
            event_id = getattr(event, "id", None)
            if event_id:
                self._response_id = event_id

        elif event_type == "content-delta":
            message = getattr(delta, "message", None) if delta else None
            content = getattr(message, "content", None) if message else None
            text = getattr(content, "text", None) if content else None
            if text:
                self._content_parts.append(text)

        elif event_type == "message-end":
            if delta:
                self._finish_reason = getattr(delta, "finish_reason", None)
                usage = getattr(delta, "usage", None)
                if usage:
                    from .utils import _set_usage

                    _set_usage(self._invocation, usage)
            event_id = getattr(event, "id", None)
            if event_id:
                self._response_id = event_id

    def _finalize(self):
        """Copy the collected data onto the invocation and stop it once."""
        if self._finalized:
            return
        self._finalized = True

        if self._response_id:
            self._invocation.response_id = self._response_id

        if self._finish_reason is not None:
            self._invocation.finish_reasons = [
                map_finish_reason(self._finish_reason)
            ]

        if self._capture_content and self._content_parts:
            self._invocation.output_messages = [
                OutputMessage(
                    # "assistant" is the fallback when no role was streamed.
                    role=self._role or "assistant",
                    parts=[Text(content="".join(self._content_parts))],
                    finish_reason=map_finish_reason(self._finish_reason),
                )
            ]

        self._handler.stop_llm(self._invocation)
239+
240+
241+
class AsyncCohereStreamWrapper:
    """Wraps an async Cohere chat_stream iterator to capture telemetry.

    Events are passed through unchanged. While iterating, the wrapper records
    the response id, the message role, the streamed text parts, the finish
    reason and the usage. The invocation is ended exactly once: with
    ``stop_llm`` when the stream is exhausted, closed, or its ``async with``
    block exits cleanly, or with ``fail_llm`` when iteration or the
    ``async with`` block raises.
    """

    def __init__(
        self,
        stream,
        handler: TelemetryHandler,
        invocation: LLMInvocation,
        capture_content: bool,
    ):
        self._stream = stream
        self._handler = handler
        self._invocation = invocation
        self._capture_content = capture_content
        self._content_parts: list[str] = []
        self._finish_reason = None
        self._response_id = None
        # Kept on the wrapper instead of in ``invocation.attributes`` so that
        # a private bookkeeping key can never be left on a failed invocation.
        self._role = None
        # Guards against ending the invocation more than once.
        self._finalized = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            event = await self._stream.__anext__()
            self._process_event(event)
            return event
        except StopAsyncIteration:
            self._finalize()
            raise
        except Exception as error:
            self._fail(type(error), str(error))
            raise

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            self._fail(exc_type, str(exc_val))
        # Ends the invocation when the consumer left the block before the
        # stream was exhausted; a no-op for telemetry if already ended.
        await self.aclose()
        return False

    async def aclose(self):
        """End the invocation if still open and close the underlying stream."""
        self._finalize()
        closer = getattr(self._stream, "aclose", None)
        if callable(closer):
            await closer()

    def _fail(self, error_type, message):
        """Report the invocation as failed, unless it was already ended."""
        if self._finalized:
            return
        self._finalized = True
        self._handler.fail_llm(
            self._invocation,
            Error(type=error_type, message=message),
        )

    def _process_event(self, event):
        """Record the telemetry-relevant fields of one stream event."""
        event_type = getattr(event, "type", None)
        delta = getattr(event, "delta", None)

        if event_type == "message-start":
            message = getattr(delta, "message", None) if delta else None
            role = getattr(message, "role", None) if message else None
            if role:
                self._role = role
            event_id = getattr(event, "id", None)
            if event_id:
                self._response_id = event_id

        elif event_type == "content-delta":
            message = getattr(delta, "message", None) if delta else None
            content = getattr(message, "content", None) if message else None
            text = getattr(content, "text", None) if content else None
            if text:
                self._content_parts.append(text)

        elif event_type == "message-end":
            if delta:
                self._finish_reason = getattr(delta, "finish_reason", None)
                usage = getattr(delta, "usage", None)
                if usage:
                    from .utils import _set_usage

                    _set_usage(self._invocation, usage)
            event_id = getattr(event, "id", None)
            if event_id:
                self._response_id = event_id

    def _finalize(self):
        """Copy the collected data onto the invocation and stop it once."""
        if self._finalized:
            return
        self._finalized = True

        if self._response_id:
            self._invocation.response_id = self._response_id

        if self._finish_reason is not None:
            self._invocation.finish_reasons = [
                map_finish_reason(self._finish_reason)
            ]

        if self._capture_content and self._content_parts:
            self._invocation.output_messages = [
                OutputMessage(
                    # "assistant" is the fallback when no role was streamed.
                    role=self._role or "assistant",
                    parts=[Text(content="".join(self._content_parts))],
                    finish_reason=map_finish_reason(self._finish_reason),
                )
            ]

        self._handler.stop_llm(self._invocation)

0 commit comments

Comments
 (0)