Skip to content

Commit 9a05344

Browse files
committed
Exit early if trying to close the peer connections twice
Previously, the repeated close() calls were hanging indefinitely for PublisherPeerConnection and SubscriberPeerConnection. Now, we set `_closed=True` guard after the connection is closed for the first time and exit early on the repeated call.
1 parent f3972a6 commit 9a05344

4 files changed

Lines changed: 77 additions & 4 deletions

File tree

getstream/video/rtc/pc.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __init__(
5050
)
5151
super().__init__(configuration)
5252
self.manager = manager
53+
self._closed = False
5354
self._connected_event = asyncio.Event()
5455

5556
for transceiver in self.getTransceivers():
@@ -113,6 +114,14 @@ async def wait_for_connected(self, timeout: float = 15.0):
113114
logger.error(f"Publisher connection timed out after {timeout}s")
114115
raise TimeoutError(f"Connection timed out after {timeout} seconds")
115116

117+
async def close(self):
118+
# Using self._closed guard here
119+
# to avoid closing RTCPeerConnectionTwice by accident (it freezes on second time)
120+
if self._closed:
121+
return
122+
await super().close()
123+
self._closed = True
124+
116125
async def restartIce(self):
117126
"""Restart ICE connection for reconnection scenarios."""
118127
logger.info("Restarting ICE connection for publisher")
@@ -138,6 +147,7 @@ def __init__(
138147
)
139148
super().__init__(configuration)
140149
self.connection = connection
150+
self._closed = False
141151
self._drain_video_frames = drain_video_frames
142152

143153
self.track_map = {} # track_id -> (MediaRelay, original_track)
@@ -245,6 +255,31 @@ def get_video_frame_tracker(self) -> Optional[Any]:
245255
return next(iter(self.video_frame_trackers.values()))
246256
return None
247257

258+
async def close(self):
259+
# Using self._closed guard here
260+
# to avoid closing RTCPeerConnectionTwice by accident (it freezes on second time)
261+
if self._closed:
262+
return
263+
264+
# Clean up video drains
265+
for blackhole, drain_task, drain_proxy in list(self._video_drains.values()):
266+
drain_task.cancel()
267+
drain_proxy.stop()
268+
await blackhole.stop()
269+
self._video_drains.clear()
270+
271+
# Cancel background tasks
272+
for task in list(self._background_tasks):
273+
task.cancel()
274+
self._background_tasks.clear()
275+
276+
# Clear track maps
277+
self.track_map.clear()
278+
self.video_frame_trackers.clear()
279+
280+
await super().close()
281+
self._closed = True
282+
248283
async def restartIce(self):
249284
"""Restart ICE connection for reconnection scenarios."""
250285
logger.info("Restarting ICE connection for subscriber")

getstream/video/rtc/reconnection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,10 @@ async def _reconnect_migrate(self):
284284
current_publisher = self.connection_manager.publisher_pc
285285
current_subscriber = self.connection_manager.subscriber_pc
286286

287+
# Clear old references so _connect_internal creates fresh PCs
288+
self.connection_manager.publisher_pc = None
289+
self.connection_manager.subscriber_pc = None
290+
287291
self.connection_manager.connection_state = ConnectionState.MIGRATING
288292

289293
if current_publisher and hasattr(current_publisher, "removeListener"):
Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
1+
import asyncio
12
import contextlib
3+
import uuid
4+
from unittest.mock import AsyncMock, MagicMock, patch
25

36
import pytest
4-
from unittest.mock import AsyncMock, patch, MagicMock
7+
from dotenv import load_dotenv
58

9+
from getstream import AsyncStream
10+
from getstream.video import rtc
611
from getstream.video.rtc.connection_manager import ConnectionManager
7-
from getstream.video.rtc.connection_utils import SfuJoinError, SfuConnectionError
12+
from getstream.video.rtc.connection_utils import (
13+
ConnectionState,
14+
SfuConnectionError,
15+
SfuJoinError,
16+
)
817
from getstream.video.rtc.pb.stream.video.sfu.models import models_pb2
918

19+
load_dotenv()
20+
1021

1122
@contextlib.contextmanager
1223
def patched_dependencies():
@@ -45,8 +56,30 @@ def connection_manager(request):
4556
yield cm
4657

4758

48-
class TestConnectRetry:
49-
"""Tests for connect() retry logic when SFU is full."""
59+
@pytest.fixture
60+
def client():
61+
return AsyncStream(timeout=10.0)
62+
63+
64+
class TestConnectionManager:
65+
@pytest.mark.asyncio
66+
@pytest.mark.integration
67+
async def test_leave_twice_does_not_hang(self, client: AsyncStream):
68+
"""Integration test: join a real call and leave twice without hanging."""
69+
call_id = str(uuid.uuid4())
70+
call = client.video.call("default", call_id)
71+
72+
async with await rtc.join(call, "test-user") as connection:
73+
assert connection.connection_state == ConnectionState.JOINED
74+
75+
await asyncio.sleep(2)
76+
77+
await asyncio.wait_for(connection.leave(), timeout=10.0)
78+
assert connection.connection_state == ConnectionState.LEFT
79+
80+
# Second leave must not hang
81+
await asyncio.wait_for(connection.leave(), timeout=10.0)
82+
assert connection.connection_state == ConnectionState.LEFT
5083

5184
@pytest.mark.asyncio
5285
@pytest.mark.parametrize("connection_manager", [2], indirect=True)

tests/rtc/test_subscriber_drain.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def subscriber_pc():
1313
"""Create a SubscriberPeerConnection bypassing heavy parent inits."""
1414
pc = SubscriberPeerConnection.__new__(SubscriberPeerConnection)
1515
pc.connection = Mock()
16+
pc._closed = False
1617
pc._drain_video_frames = True
1718
pc.track_map = {}
1819
pc.video_frame_trackers = {}

0 commit comments

Comments
 (0)