Skip to content

Commit 70a9cbb

Browse files
Pavkazzzclaude
andcommitted
refactor(contracts): add PoolDriver ABC, acquire contexts, balancer refactoring
- Add PoolDriver ABC (hasql/abc.py): formal interface for all DB drivers - Add acquire context classes (hasql/acquire.py): AcquireContext protocol, TimeoutAcquireContext, PoolAcquireContext with deadline-based timeout budget - Add minimal PoolStateProvider protocol (hasql/pool_state.py stub) - Refactor balancer policies to depend on PoolStateProvider protocol instead of concrete BasePoolManager — eliminates circular import - Simplify RandomWeightedBalancerPolicy weight formula (MACHINE_EPSILON removed) Stack 2/4 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c3c6390 commit 70a9cbb

11 files changed

Lines changed: 727 additions & 148 deletions

File tree

hasql/abc.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import warnings
2+
from abc import ABC, abstractmethod
3+
from collections.abc import Sequence
4+
from typing import Any, Generic, TypeVar
5+
6+
from .acquire import AcquireContext
7+
from .metrics import DriverMetrics, PoolStats
8+
from .utils import Dsn
9+
10+
PoolT = TypeVar("PoolT")
11+
ConnT = TypeVar("ConnT")
12+
13+
14+
class PoolDriver(ABC, Generic[PoolT, ConnT]):
15+
"""Database driver interface for pool operations."""
16+
17+
@abstractmethod
18+
def get_pool_freesize(self, pool: PoolT) -> int: ...
19+
20+
@abstractmethod
21+
def acquire_from_pool(
22+
self,
23+
pool: PoolT,
24+
*,
25+
timeout: float | None = None,
26+
**kwargs,
27+
) -> AcquireContext[ConnT]: ...
28+
29+
@abstractmethod
30+
async def release_to_pool(
31+
self,
32+
connection: ConnT,
33+
pool: PoolT,
34+
**kwargs,
35+
) -> None: ...
36+
37+
@abstractmethod
38+
async def is_master(self, connection: ConnT) -> bool: ...
39+
40+
@abstractmethod
41+
async def fetch_scalar(self, connection: ConnT, query: str) -> Any:
42+
"""Execute a query and return a single scalar value."""
43+
...
44+
45+
@abstractmethod
46+
async def pool_factory(self, dsn: Dsn, **kwargs) -> PoolT: ...
47+
48+
@abstractmethod
49+
async def close_pool(self, pool: PoolT) -> None: ...
50+
51+
@abstractmethod
52+
async def terminate_pool(self, pool: PoolT) -> None: ...
53+
54+
@abstractmethod
55+
def is_connection_closed(self, connection: ConnT) -> bool: ...
56+
57+
@abstractmethod
58+
def host(self, pool: PoolT) -> str: ...
59+
60+
@abstractmethod
61+
def pool_stats(self, pool: PoolT) -> PoolStats: ...
62+
63+
def driver_metrics(
64+
self,
65+
pools: Sequence[PoolT | None],
66+
) -> Sequence[DriverMetrics]:
67+
warnings.warn(
68+
"driver_metrics() is deprecated, implement pool_stats() instead",
69+
DeprecationWarning,
70+
stacklevel=2,
71+
)
72+
return [
73+
DriverMetrics(
74+
min=s.min, max=s.max, idle=s.idle, used=s.used,
75+
host=self.host(p),
76+
)
77+
for p in pools if p
78+
for s in [self.pool_stats(p)]
79+
]
80+
81+
def prepare_pool_factory_kwargs(self, kwargs: dict) -> dict:
82+
"""Hook for drivers to adjust pool factory kwargs."""
83+
return kwargs
84+
85+
86+
__all__ = ("PoolDriver",)

hasql/acquire.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import asyncio
2+
from collections.abc import Callable, Generator
3+
from contextlib import AbstractAsyncContextManager
4+
from types import TracebackType
5+
from typing import (
6+
TYPE_CHECKING,
7+
Any,
8+
Generic,
9+
Protocol,
10+
TypeVar,
11+
)
12+
13+
from .exceptions import NoAvailablePoolError
14+
from .metrics import CalculateMetrics
15+
16+
if TYPE_CHECKING:
17+
from .balancer_policy.base import AbstractBalancerPolicy
18+
from .pool_state import PoolState
19+
20+
PoolT = TypeVar("PoolT")
21+
ConnT = TypeVar("ConnT")
22+
ConnT_co = TypeVar("ConnT_co", covariant=True)
23+
24+
25+
class AcquireContext(Protocol[ConnT_co]):
26+
async def __aenter__(self) -> ConnT_co: ...
27+
async def __aexit__(
28+
self,
29+
exc_type: type[BaseException] | None,
30+
exc_val: BaseException | None,
31+
exc_tb: TracebackType | None,
32+
) -> bool | None: ...
33+
def __await__(self) -> Generator[Any, None, ConnT_co]: ...
34+
35+
36+
class TimeoutAcquireContext(Generic[ConnT]):
37+
__slots__ = ("_context", "_timeout")
38+
39+
def __init__(self, context: AcquireContext[ConnT], timeout: float):
40+
self._context = context
41+
self._timeout = timeout
42+
43+
async def __aenter__(self) -> ConnT:
44+
return await asyncio.wait_for(
45+
self._context.__aenter__(),
46+
timeout=self._timeout,
47+
)
48+
49+
async def __aexit__(self, *exc):
50+
# TODO: consider adding a bounded timeout here. Currently if the
51+
# underlying driver hangs during connection release this will block
52+
# indefinitely. A timeout risks leaking the connection (not returned
53+
# to pool), so this needs careful design.
54+
return await self._context.__aexit__(*exc)
55+
56+
def __await__(self) -> Generator[Any, None, ConnT]:
57+
return asyncio.wait_for(
58+
self._context.__aenter__(),
59+
timeout=self._timeout,
60+
).__await__()
61+
62+
63+
class PoolAcquireContext(
64+
AbstractAsyncContextManager[ConnT],
65+
Generic[PoolT, ConnT],
66+
):
67+
def __init__(
68+
self,
69+
pool_state: "PoolState[PoolT, ConnT]",
70+
balancer: "AbstractBalancerPolicy[PoolT]",
71+
register_connection: Callable[[ConnT, PoolT], None],
72+
unregister_connection: Callable[[ConnT], None],
73+
read_only: bool,
74+
master_as_replica_weight: float | None,
75+
timeout: float,
76+
metrics: CalculateMetrics,
77+
fallback_master: bool = False,
78+
**kwargs,
79+
):
80+
self._pool_state = pool_state
81+
self._balancer = balancer
82+
self._register_connection = register_connection
83+
self._unregister_connection = unregister_connection
84+
self._read_only = read_only
85+
self._fallback_master = fallback_master
86+
self._master_as_replica_weight = master_as_replica_weight
87+
self._timeout = timeout
88+
self._kwargs = kwargs
89+
self._metrics = metrics
90+
self._pool: PoolT | None = None
91+
self._conn: ConnT | None = None
92+
self._context: AcquireContext[ConnT] | None = None
93+
94+
def _deadline(self) -> float:
95+
return asyncio.get_running_loop().time() + self._timeout
96+
97+
def _remaining_timeout(self, deadline: float) -> float:
98+
remaining_timeout = deadline - asyncio.get_running_loop().time()
99+
if remaining_timeout <= 0:
100+
raise asyncio.TimeoutError
101+
return remaining_timeout
102+
103+
async def _get_pool(self, deadline: float) -> PoolT:
104+
async def get_pool() -> PoolT:
105+
with self._metrics.with_get_pool():
106+
pool = await self._balancer.get_pool(
107+
read_only=self._read_only,
108+
fallback_master=self._fallback_master,
109+
master_as_replica_weight=self._master_as_replica_weight,
110+
)
111+
if pool is None:
112+
raise NoAvailablePoolError("No available pool")
113+
return pool
114+
115+
return await asyncio.wait_for(
116+
get_pool(),
117+
timeout=self._remaining_timeout(deadline),
118+
)
119+
120+
async def _resolve_pool_and_acquire_context(
121+
self,
122+
) -> tuple[PoolT, AcquireContext[ConnT]]:
123+
deadline = self._deadline()
124+
pool = await self._get_pool(deadline)
125+
remaining = self._remaining_timeout(deadline)
126+
driver_ctx = self._pool_state.acquire_from_pool(
127+
pool,
128+
timeout=remaining,
129+
**self._kwargs,
130+
)
131+
return pool, driver_ctx
132+
133+
async def _acquire_connection(self) -> ConnT:
134+
pool, driver_ctx = await self._resolve_pool_and_acquire_context()
135+
136+
host = self._pool_state.host(pool)
137+
with self._metrics.with_acquire(host):
138+
conn: ConnT = await driver_ctx
139+
140+
try:
141+
self._metrics.add_connection(host)
142+
self._register_connection(conn, pool)
143+
except BaseException:
144+
await self._pool_state.release_to_pool(conn, pool)
145+
raise
146+
return conn
147+
148+
async def __aenter__(self) -> ConnT:
149+
pool, driver_ctx = await self._resolve_pool_and_acquire_context()
150+
151+
host = self._pool_state.host(pool)
152+
with self._metrics.with_acquire(host):
153+
conn: ConnT = await driver_ctx.__aenter__()
154+
155+
try:
156+
self._metrics.add_connection(host)
157+
self._register_connection(conn, pool)
158+
except BaseException:
159+
await driver_ctx.__aexit__(None, None, None)
160+
raise
161+
162+
self._pool = pool
163+
self._conn = conn
164+
self._context = driver_ctx
165+
return conn
166+
167+
async def __aexit__(self, *exc):
168+
if self._conn is None or self._pool is None or self._context is None:
169+
return
170+
self._unregister_connection(self._conn)
171+
self._metrics.remove_connection(
172+
self._pool_state.host(self._pool),
173+
)
174+
await self._context.__aexit__(*exc)
175+
176+
def __await__(self):
177+
return self._acquire_connection().__await__()
178+
179+
180+
__all__ = (
181+
"AcquireContext",
182+
"TimeoutAcquireContext",
183+
"PoolAcquireContext",
184+
)

hasql/balancer_policy/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
from .base import AbstractBalancerPolicy
12
from .greedy import GreedyBalancerPolicy
23
from .random_weighted import RandomWeightedBalancerPolicy
34
from .round_robin import RoundRobinBalancerPolicy
45

5-
66
__all__ = (
7+
"AbstractBalancerPolicy",
78
"GreedyBalancerPolicy",
89
"RandomWeightedBalancerPolicy",
910
"RoundRobinBalancerPolicy",

hasql/balancer_policy/base.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
import random
2-
from abc import abstractmethod
3-
from typing import Any, Optional
2+
from abc import ABC, abstractmethod
3+
from typing import Generic, TypeVar
44

5-
from ..base import AbstractBalancerPolicy, BasePoolManager
5+
from ..pool_state import PoolStateProvider
66

7+
PoolT = TypeVar("PoolT")
78

8-
class BaseBalancerPolicy(AbstractBalancerPolicy):
9-
def __init__(self, pool_manager: BasePoolManager):
10-
self._pool_manager = pool_manager
9+
10+
class AbstractBalancerPolicy(ABC, Generic[PoolT]):
11+
def __init__(self, pool_state: PoolStateProvider[PoolT]):
12+
self._pool_state = pool_state
1113

1214
async def get_pool(
1315
self,
1416
read_only: bool,
1517
fallback_master: bool = False,
16-
master_as_replica_weight: Optional[float] = None,
17-
) -> Any:
18+
master_as_replica_weight: float | None = None,
19+
) -> PoolT | None:
1820
if not read_only and master_as_replica_weight is not None:
1921
raise ValueError(
2022
"Field master_as_replica_weight is used only when "
@@ -23,23 +25,51 @@ async def get_pool(
2325

2426
choose_master_as_replica = False
2527
if master_as_replica_weight is not None:
26-
rand = random.random()
27-
choose_master_as_replica = 0 < rand <= master_as_replica_weight
28+
choose_master_as_replica = (
29+
random.random() < master_as_replica_weight
30+
)
2831

2932
return await self._get_pool(
3033
read_only=read_only,
3134
fallback_master=fallback_master or choose_master_as_replica,
3235
choose_master_as_replica=choose_master_as_replica,
3336
)
3437

38+
async def _get_candidates(
39+
self,
40+
read_only: bool,
41+
fallback_master: bool = False,
42+
choose_master_as_replica: bool = False,
43+
) -> list[PoolT]:
44+
candidates: list[PoolT] = []
45+
46+
if read_only:
47+
candidates.extend(
48+
await self._pool_state.get_replica_pools(
49+
fallback_master=fallback_master,
50+
),
51+
)
52+
53+
if not read_only or (
54+
choose_master_as_replica
55+
and self._pool_state.master_pool_count > 0
56+
and self._pool_state.replica_pool_count > 0
57+
):
58+
candidates.extend(await self._pool_state.get_master_pools())
59+
60+
return candidates
61+
3562
@abstractmethod
3663
async def _get_pool(
3764
self,
3865
read_only: bool,
3966
fallback_master: bool = False,
4067
choose_master_as_replica: bool = False,
41-
):
68+
) -> PoolT | None:
4269
pass
4370

4471

45-
__all__ = ["BaseBalancerPolicy"]
72+
# Backward-compatible alias
73+
BaseBalancerPolicy = AbstractBalancerPolicy
74+
75+
__all__ = ["AbstractBalancerPolicy", "BaseBalancerPolicy"]

0 commit comments

Comments
 (0)