55import signal
66import ssl
77import sys
8+ from typing import Any
89
910import aioredis
1011import click
1415from .metrics import Metrics
1516from .receiver import ServiceReceiver
1617from .redis_connection import RedisConnection
18+ from .session import Session
19+ from .types import Config , LogConfig
1720
1821
19- def setup_logging (log_config ) :
22+ def setup_logging (log_config : LogConfig ) -> None :
2023 # Configure root logger if logging level is specified in config
2124 if log_config ['level' ]:
2225 level = getattr (logging , log_config ['level' ])
2326 formatter = logging .Formatter (log_config ['format' ])
24- sh = logging .StreamHandler ()
27+ sh : logging . StreamHandler [ Any ] = logging .StreamHandler ()
2528 sh .setFormatter (formatter )
2629
2730 logger = logging .getLogger ()
@@ -36,7 +39,7 @@ def setup_logging(log_config):
3639 setup_structlog (sys .stdout .isatty ())
3740
3841
39- def setup_structlog (tty = False ):
42+ def setup_structlog (tty : bool = False ) -> None :
4043 processors = [
4144 structlog .stdlib .filter_by_level ,
4245 structlog .stdlib .add_log_level ,
@@ -59,7 +62,7 @@ def setup_structlog(tty=False):
5962 )
6063
6164
62- def load_backend (config ) :
65+ def load_backend (config : Config ) -> Any :
6366 """
6467 Return the backend module from the given SocketShark configuration.
6568 """
@@ -69,30 +72,34 @@ def load_backend(config):
6972
7073
7174class SocketShark :
72- def __init__ (self , config ) :
75+ def __init__ (self , config : Config ) -> None :
7376 self .config = config
7477 backend_module = load_backend (config )
7578 backend_cls = backend_module .Backend
7679 self .backend = backend_cls (self )
7780 self ._init_logging ()
78- self ._task = None
81+ self ._task : asyncio . Task [ None ] | None = None
7982 self ._shutdown = False
80- self .sessions = set ()
83+ self .sessions : set [ Session ] = set ()
8184 self .metrics = Metrics (self )
8285 self .metrics .initialize ()
8386 self .metrics .set_ready (False )
84- self .redis_connections = []
87+ self .redis_connections : list [ RedisConnection ] = []
8588
86- def _init_logging (self ):
89+ def _init_logging (self ) -> None :
8790 logger_name = self .config ['LOG' ]['logger_name' ]
8891 trace_logger_prefix = self .config ['LOG' ]['trace_logger_prefix' ]
8992 trace_logger_name = '{}.{}' .format (trace_logger_prefix , logger_name )
9093 pid = os .getpid ()
91- self .log = structlog .get_logger (logger_name ).bind (pid = pid )
92- self .trace_log = structlog .get_logger (trace_logger_name ).bind (pid = pid )
94+ self .log : structlog .stdlib .BoundLogger = structlog .get_logger (
95+ logger_name
96+ ).bind (pid = pid )
97+ self .trace_log : structlog .stdlib .BoundLogger = structlog .get_logger (
98+ trace_logger_name
99+ ).bind (pid = pid )
93100 self .trace_log .debug ('trace' )
94101
95- def signal_ready (self ):
102+ def signal_ready (self ) -> None :
96103 """
97104 Notify that the backend is ready.
98105 """
@@ -104,14 +111,14 @@ def signal_ready(self):
104111 )
105112 self .metrics .set_ready (True )
106113
107- def signal_shutdown (self ):
114+ def signal_shutdown (self ) -> None :
108115 """
109116 Notify that the backend shut down.
110117 """
111118 self .log .info ('done' )
112119 self .metrics .set_ready (False )
113120
114- async def _redis_connection_handler (self ):
121+ async def _redis_connection_handler (self ) -> None :
115122 """
116123 Handle Redis connection errors.
117124
@@ -132,7 +139,7 @@ async def _redis_connection_handler(self):
132139 # Redis goes down so they can reconnect and restore subscriptions.
133140 asyncio .ensure_future (self .shutdown ())
134141
135- async def prepare (self ):
142+ async def prepare (self ) -> None :
136143 """
137144 Callback called by the backend to prepare SocketShark.
138145
@@ -155,12 +162,12 @@ async def prepare(self):
155162
156163 self .service_receiver = ServiceReceiver (self )
157164
158- def _cleanup (self ):
165+ def _cleanup (self ) -> None :
159166 self ._redis_connection_handler_task .cancel ()
160167 for c in self .redis_connections :
161168 c .redis .close ()
162169
163- async def shutdown (self ):
170+ async def shutdown (self ) -> None :
164171 """
165172 Shut down SocketShark cleanly.
166173 """
@@ -201,20 +208,22 @@ async def shutdown(self):
201208 self ._uninstall_signal_handlers ()
202209 self ._shutdown = False
203210
204- async def run_service_receiver (self , once = False ):
211+ async def run_service_receiver (
212+ self , once : bool = False
213+ ) -> list [bool | None ] | None :
205214 return await self .service_receiver .reader (once = once )
206215
207- def start (self ):
216+ def start (self ) -> None :
208217 """
209218 Start the backend (main entrypoint into SocketShark).
210219 """
211220 self .backend .start ()
212221
213- async def _run (self , once = False ):
222+ async def _run (self , once : bool = False ) -> None :
214223 await self .run_service_receiver ()
215224 asyncio .ensure_future (self .shutdown ())
216225
217- async def run (self , once = False ):
226+ async def run (self , once : bool = False ) -> None :
218227 """
219228 Set up SocketShark signal handlers and run the service receiver.
220229
@@ -223,39 +232,40 @@ async def run(self, once=False):
223232 self ._install_signal_handlers ()
224233 self ._task = asyncio .ensure_future (self ._run ())
225234
226- def _install_signal_handlers (self ):
235+ def _install_signal_handlers (self ) -> None :
227236 """
228237 Set up signal handlers for safely stopping the worker.
229238 """
230239
231- def request_stop ():
240+ def request_stop () -> None :
232241 self .log .info ('stop requested' )
233242 asyncio .ensure_future (self .shutdown ())
234243
235244 loop = asyncio .get_event_loop ()
236245 loop .add_signal_handler (signal .SIGINT , request_stop )
237246 loop .add_signal_handler (signal .SIGTERM , request_stop )
238247
239- def _uninstall_signal_handlers (self ):
248+ def _uninstall_signal_handlers (self ) -> None :
240249 """
241250 Restore default signal handlers.
242251 """
243252 loop = asyncio .get_event_loop ()
244253 loop .remove_signal_handler (signal .SIGINT )
245254 loop .remove_signal_handler (signal .SIGTERM )
246255
247- def get_ssl_context (self ):
256+ def get_ssl_context (self ) -> ssl . SSLContext | None :
248257 ssl_settings = self .config .get ('WS_SSL' )
249258 if ssl_settings :
250259 ssl_context = ssl .create_default_context (ssl .Purpose .CLIENT_AUTH )
251260 ssl_context .load_cert_chain (
252261 certfile = ssl_settings ['cert' ], keyfile = ssl_settings ['key' ]
253262 )
254263 return ssl_context
264+ return None
255265
256266
257- def load_config (config_name ) :
258- config = {}
267+ def load_config (config_name : str ) -> Config :
268+ config : dict [ str , Any ] = {}
259269
260270 # Get config defaults
261271 for key in dir (config_defaults ):
@@ -272,13 +282,13 @@ def load_config(config_name):
272282 else :
273283 config [key ] = value
274284
275- return config
285+ return Config ( config )
276286
277287
278288@click .command ()
279289@click .option ('-c' , '--config' , required = True , help = 'dotted path to config' )
280290@click .pass_context
281- def run (context , config ) :
291+ def run (context : click . Context , config : str ) -> None :
282292 config_obj = load_config (config )
283293
284294 setup_logging (config_obj ['LOG' ])
0 commit comments