Skip to content

Commit b2353d3

Browse files
authored
Type-annotate all of SocketShark (#328)
* Type-annotate all of SocketShark. * More accurate annotation for `ping_timeout_handler`. Previous annotation was too loose. The `ping_timeout_handler` is only ever called with the result of `await self.websocket.ping()`, which has the following signature: ``` async def ping(self, data: Optional[Data] = None) -> Awaitable[None]: ``` * More precise `client` annotation. * Type-annotate the re-added `LogMetrics`, too. I briefly removed it from master before realizing it's still used (in a very counter-intuitive fashion).
1 parent 5ed4e27 commit b2353d3

15 files changed

Lines changed: 577 additions & 329 deletions

setup.cfg

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -109,40 +109,3 @@ check_untyped_defs = true
109109
disallow_any_generics = true
110110
disallow_untyped_defs = true
111111
disallow_incomplete_defs = true
112-
113-
# The following modules have not yet been fully type-annotated. Allow untyped
114-
# defs in them until they get typed, then remove the corresponding override.
115-
[mypy-socketshark]
116-
allow_untyped_defs = True
117-
118-
[mypy-socketshark.backend.websockets]
119-
allow_untyped_defs = True
120-
121-
[mypy-socketshark.events]
122-
allow_untyped_defs = True
123-
124-
[mypy-socketshark.exceptions]
125-
allow_untyped_defs = True
126-
127-
[mypy-socketshark.metrics]
128-
allow_untyped_defs = True
129-
130-
[mypy-socketshark.metrics.prometheus]
131-
allow_untyped_defs = True
132-
133-
[mypy-socketshark.receiver]
134-
allow_untyped_defs = True
135-
136-
[mypy-socketshark.redis_connection]
137-
allow_untyped_defs = True
138-
allow_incomplete_defs = True
139-
140-
[mypy-socketshark.session]
141-
allow_untyped_defs = True
142-
allow_incomplete_defs = True
143-
144-
[mypy-socketshark.subscription]
145-
allow_untyped_defs = True
146-
147-
[mypy-socketshark.utils]
148-
allow_untyped_defs = True

socketshark/__init__.py

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import signal
66
import ssl
77
import sys
8+
from typing import Any
89

910
import aioredis
1011
import click
@@ -14,14 +15,16 @@
1415
from .metrics import Metrics
1516
from .receiver import ServiceReceiver
1617
from .redis_connection import RedisConnection
18+
from .session import Session
19+
from .types import Config, LogConfig
1720

1821

19-
def setup_logging(log_config):
22+
def setup_logging(log_config: LogConfig) -> None:
2023
# Configure root logger if logging level is specified in config
2124
if log_config['level']:
2225
level = getattr(logging, log_config['level'])
2326
formatter = logging.Formatter(log_config['format'])
24-
sh = logging.StreamHandler()
27+
sh: logging.StreamHandler[Any] = logging.StreamHandler()
2528
sh.setFormatter(formatter)
2629

2730
logger = logging.getLogger()
@@ -36,7 +39,7 @@ def setup_logging(log_config):
3639
setup_structlog(sys.stdout.isatty())
3740

3841

39-
def setup_structlog(tty=False):
42+
def setup_structlog(tty: bool = False) -> None:
4043
processors = [
4144
structlog.stdlib.filter_by_level,
4245
structlog.stdlib.add_log_level,
@@ -59,7 +62,7 @@ def setup_structlog(tty=False):
5962
)
6063

6164

62-
def load_backend(config):
65+
def load_backend(config: Config) -> Any:
6366
"""
6467
Return the backend module from the given SocketShark configuration.
6568
"""
@@ -69,30 +72,34 @@ def load_backend(config):
6972

7073

7174
class SocketShark:
72-
def __init__(self, config):
75+
def __init__(self, config: Config) -> None:
7376
self.config = config
7477
backend_module = load_backend(config)
7578
backend_cls = backend_module.Backend
7679
self.backend = backend_cls(self)
7780
self._init_logging()
78-
self._task = None
81+
self._task: asyncio.Task[None] | None = None
7982
self._shutdown = False
80-
self.sessions = set()
83+
self.sessions: set[Session] = set()
8184
self.metrics = Metrics(self)
8285
self.metrics.initialize()
8386
self.metrics.set_ready(False)
84-
self.redis_connections = []
87+
self.redis_connections: list[RedisConnection] = []
8588

86-
def _init_logging(self):
89+
def _init_logging(self) -> None:
8790
logger_name = self.config['LOG']['logger_name']
8891
trace_logger_prefix = self.config['LOG']['trace_logger_prefix']
8992
trace_logger_name = '{}.{}'.format(trace_logger_prefix, logger_name)
9093
pid = os.getpid()
91-
self.log = structlog.get_logger(logger_name).bind(pid=pid)
92-
self.trace_log = structlog.get_logger(trace_logger_name).bind(pid=pid)
94+
self.log: structlog.stdlib.BoundLogger = structlog.get_logger(
95+
logger_name
96+
).bind(pid=pid)
97+
self.trace_log: structlog.stdlib.BoundLogger = structlog.get_logger(
98+
trace_logger_name
99+
).bind(pid=pid)
93100
self.trace_log.debug('trace')
94101

95-
def signal_ready(self):
102+
def signal_ready(self) -> None:
96103
"""
97104
Notify that the backend is ready.
98105
"""
@@ -104,14 +111,14 @@ def signal_ready(self):
104111
)
105112
self.metrics.set_ready(True)
106113

107-
def signal_shutdown(self):
114+
def signal_shutdown(self) -> None:
108115
"""
109116
Notify that the backend shut down.
110117
"""
111118
self.log.info('done')
112119
self.metrics.set_ready(False)
113120

114-
async def _redis_connection_handler(self):
121+
async def _redis_connection_handler(self) -> None:
115122
"""
116123
Handle Redis connection errors.
117124
@@ -132,7 +139,7 @@ async def _redis_connection_handler(self):
132139
# Redis goes down so they can reconnect and restore subscriptions.
133140
asyncio.ensure_future(self.shutdown())
134141

135-
async def prepare(self):
142+
async def prepare(self) -> None:
136143
"""
137144
Callback called by the backend to prepare SocketShark.
138145
@@ -155,12 +162,12 @@ async def prepare(self):
155162

156163
self.service_receiver = ServiceReceiver(self)
157164

158-
def _cleanup(self):
165+
def _cleanup(self) -> None:
159166
self._redis_connection_handler_task.cancel()
160167
for c in self.redis_connections:
161168
c.redis.close()
162169

163-
async def shutdown(self):
170+
async def shutdown(self) -> None:
164171
"""
165172
Shut down SocketShark cleanly.
166173
"""
@@ -201,20 +208,22 @@ async def shutdown(self):
201208
self._uninstall_signal_handlers()
202209
self._shutdown = False
203210

204-
async def run_service_receiver(self, once=False):
211+
async def run_service_receiver(
212+
self, once: bool = False
213+
) -> list[bool | None] | None:
205214
return await self.service_receiver.reader(once=once)
206215

207-
def start(self):
216+
def start(self) -> None:
208217
"""
209218
Start the backend (main entrypoint into SocketShark).
210219
"""
211220
self.backend.start()
212221

213-
async def _run(self, once=False):
222+
async def _run(self, once: bool = False) -> None:
214223
await self.run_service_receiver()
215224
asyncio.ensure_future(self.shutdown())
216225

217-
async def run(self, once=False):
226+
async def run(self, once: bool = False) -> None:
218227
"""
219228
Set up SocketShark signal handlers and run the service receiver.
220229
@@ -223,39 +232,40 @@ async def run(self, once=False):
223232
self._install_signal_handlers()
224233
self._task = asyncio.ensure_future(self._run())
225234

226-
def _install_signal_handlers(self):
235+
def _install_signal_handlers(self) -> None:
227236
"""
228237
Set up signal handlers for safely stopping the worker.
229238
"""
230239

231-
def request_stop():
240+
def request_stop() -> None:
232241
self.log.info('stop requested')
233242
asyncio.ensure_future(self.shutdown())
234243

235244
loop = asyncio.get_event_loop()
236245
loop.add_signal_handler(signal.SIGINT, request_stop)
237246
loop.add_signal_handler(signal.SIGTERM, request_stop)
238247

239-
def _uninstall_signal_handlers(self):
248+
def _uninstall_signal_handlers(self) -> None:
240249
"""
241250
Restore default signal handlers.
242251
"""
243252
loop = asyncio.get_event_loop()
244253
loop.remove_signal_handler(signal.SIGINT)
245254
loop.remove_signal_handler(signal.SIGTERM)
246255

247-
def get_ssl_context(self):
256+
def get_ssl_context(self) -> ssl.SSLContext | None:
248257
ssl_settings = self.config.get('WS_SSL')
249258
if ssl_settings:
250259
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
251260
ssl_context.load_cert_chain(
252261
certfile=ssl_settings['cert'], keyfile=ssl_settings['key']
253262
)
254263
return ssl_context
264+
return None
255265

256266

257-
def load_config(config_name):
258-
config = {}
267+
def load_config(config_name: str) -> Config:
268+
config: dict[str, Any] = {}
259269

260270
# Get config defaults
261271
for key in dir(config_defaults):
@@ -272,13 +282,13 @@ def load_config(config_name):
272282
else:
273283
config[key] = value
274284

275-
return config
285+
return Config(config)
276286

277287

278288
@click.command()
279289
@click.option('-c', '--config', required=True, help='dotted path to config')
280290
@click.pass_context
281-
def run(context, config):
291+
def run(context: click.Context, config: str) -> None:
282292
config_obj = load_config(config)
283293

284294
setup_logging(config_obj['LOG'])

socketshark/backend/websockets.py

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
11
import asyncio
22
import json
33
import time
4+
from typing import TYPE_CHECKING
45

56
import websockets
67

78
from .. import constants as c
89
from ..session import Session
10+
from ..types import ClientMessage
11+
12+
if TYPE_CHECKING:
13+
from .. import SocketShark
914

1015

1116
class Client:
12-
def __init__(self, shark, websocket):
17+
def __init__(
18+
self,
19+
shark: 'SocketShark',
20+
websocket: websockets.WebSocketServerProtocol,
21+
) -> None:
1322
self.websocket = websocket
1423
self.session = Session(
1524
shark,
@@ -20,7 +29,7 @@ def __init__(self, shark, websocket):
2029
)
2130
self.shark = shark
2231

23-
async def ping_timeout_handler(self, ping):
32+
async def ping_timeout_handler(self, ping: asyncio.Future[None]) -> bool:
2433
ping_timeout = self.shark.config['WS_PING']['timeout']
2534
await asyncio.sleep(ping_timeout)
2635

@@ -33,7 +42,7 @@ async def ping_timeout_handler(self, ping):
3342

3443
return False
3544

36-
async def ping_handler(self):
45+
async def ping_handler(self) -> None:
3746
ping_interval = self.shark.config['WS_PING']['interval']
3847
if not ping_interval:
3948
return
@@ -62,7 +71,7 @@ async def ping_handler(self):
6271
if not timeout_handler.cancel() and timeout_handler.result():
6372
return
6473

65-
async def consumer_handler(self):
74+
async def consumer_handler(self) -> None:
6675
try:
6776
ping_handler = asyncio.ensure_future(self.ping_handler())
6877
try:
@@ -73,10 +82,12 @@ async def consumer_handler(self):
7382
except json.decoder.JSONDecodeError:
7483
self.session.log.warn('received invalid json')
7584
await self.send(
76-
{
77-
"status": "error",
78-
"error": c.ERR_INVALID_EVENT,
79-
}
85+
ClientMessage(
86+
{
87+
"status": "error",
88+
"error": c.ERR_INVALID_EVENT,
89+
}
90+
)
8091
)
8192
else:
8293
await self.session.on_client_event(data)
@@ -87,23 +98,23 @@ async def consumer_handler(self):
8798
except Exception:
8899
self.session.log.exception('unhandled error in consumer handler')
89100

90-
async def send(self, event):
101+
async def send(self, event: ClientMessage) -> None:
91102
try:
92103
await self.websocket.send(json.dumps(event))
93104
except websockets.ConnectionClosed:
94105
self.session.log.warn('attempted to send to closed socket')
95106

96-
async def close(self):
107+
async def close(self) -> None:
97108
await self.websocket.close()
98109

99110

100111
class Backend:
101-
def __init__(self, shark):
112+
def __init__(self, shark: 'SocketShark') -> None:
102113
self.shark = shark
103-
self.server = None
104-
self._closed = False
114+
self.server: websockets.WebSocketServer | None = None
115+
self._closed: bool = False
105116

106-
def close(self):
117+
def close(self) -> None:
107118
"""
108119
Called by SocketShark to make the backend stop accepting connections.
109120
"""
@@ -112,20 +123,22 @@ def close(self):
112123
self._closed = True
113124
self.server.server.close()
114125

115-
async def shutdown(self):
126+
async def shutdown(self) -> None:
116127
"""
117128
Called by SocketShark to close any open connections.
118129
"""
119130
if self.server:
120131
self.server.close()
121132
await self.server.wait_closed()
122133

123-
def start(self):
134+
def start(self) -> None:
124135
"""
125136
Called by SocketShark to initialize the server, prepare & run.
126137
"""
127138

128-
async def serve(websocket, path):
139+
async def serve(
140+
websocket: websockets.WebSocketServerProtocol, path: str
141+
) -> None:
129142
# If there are any pending connections that were established after
130143
# calling close() but before this callback was executed, close
131144
# them immediately.

0 commit comments

Comments
 (0)