Skip to content

Commit b1efa9c

Browse files
authored
Merge pull request #117 from modern-python/enhance-types-parsing
enhance types parsing
2 parents f65792b + 6ccbda5 commit b1efa9c

8 files changed

Lines changed: 191 additions & 105 deletions

File tree

packages/modern-di/modern_di/helpers/__init__.py

Whitespace-only changes.

packages/modern-di/modern_di/helpers/type_helpers.py

Lines changed: 0 additions & 30 deletions
This file was deleted.

packages/modern-di/modern_di/providers/factory.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import typing
44

55
from modern_di import types
6-
from modern_di.helpers.type_helpers import parse_signature
76
from modern_di.providers.abstract import AbstractProvider
87
from modern_di.scope import Scope
8+
from modern_di.types_parser import parse_creator
99

1010

1111
if typing.TYPE_CHECKING:
@@ -34,16 +34,26 @@ def __init__(
3434
kwargs: dict[str, typing.Any] | None = None,
3535
cache_settings: CacheSettings[types.T_co] | None = None,
3636
) -> None:
37-
dependency_type, self._parsed_kwargs = parse_signature(creator)
38-
super().__init__(scope=scope, bound_type=bound_type if bound_type != types.UNSET else dependency_type)
37+
dependency_type, self._parsed_kwargs = parse_creator(creator)
38+
super().__init__(scope=scope, bound_type=bound_type if bound_type != types.UNSET else dependency_type.arg_type)
3939
self._creator = creator
4040
self.cache_settings = cache_settings
4141
self._kwargs = kwargs
4242

4343
def _compile_kwargs(self, container: "Container") -> dict[str, typing.Any]:
44-
result: dict[str, typing.Any] = self._parsed_kwargs.copy()
45-
for k, v in result.items():
46-
result[k] = container.providers_registry.find_provider(dependency_name=k, dependency_type=v)
44+
result: dict[str, typing.Any] = {}
45+
for k, v in self._parsed_kwargs.items():
46+
provider: AbstractProvider[types.T_co] | None = container.providers_registry.find_provider(
47+
dependency_name=k, dependency_type=v.arg_type
48+
)
49+
if provider:
50+
result[k] = provider
51+
continue
52+
53+
if (not self._kwargs or k not in self._kwargs) and v.default == types.UNSET:
54+
msg = f"Argument {k} cannot be resolved, type={v.arg_type}"
55+
raise RuntimeError(msg)
56+
4757
if self._kwargs:
4858
result.update(self._kwargs)
4959
return result
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import dataclasses
2+
import inspect
3+
import types
4+
import typing
5+
6+
from modern_di.types import UNSET
7+
8+
9+
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
10+
class SignatureItem:
11+
arg_type: type | None = None
12+
args: list[type] = dataclasses.field(default_factory=list)
13+
is_nullable: bool = False
14+
default: object = UNSET
15+
16+
@classmethod
17+
def from_type(cls, type_: type, default: object = UNSET) -> "SignatureItem":
18+
result: dict[str, typing.Any] = {"default": default}
19+
if isinstance(type_, types.GenericAlias):
20+
result["arg_type"] = type_.__origin__
21+
result["args"] = list(type_.__args__)
22+
elif isinstance(type_, (types.UnionType, typing._UnionGenericAlias)): # type: ignore[attr-defined] # noqa: SLF001
23+
args = list(type_.__args__)
24+
if types.NoneType in args:
25+
result["is_nullable"] = True
26+
args.remove(types.NoneType)
27+
if len(args) > 1:
28+
result["args"] = args
29+
elif args:
30+
result["arg_type"] = args[0]
31+
elif isinstance(type_, type):
32+
result["arg_type"] = type_
33+
return cls(**result)
34+
35+
36+
def parse_creator(creator: typing.Callable[..., typing.Any]) -> tuple[SignatureItem, dict[str, SignatureItem]]:
37+
try:
38+
sig = inspect.signature(creator)
39+
except ValueError:
40+
return SignatureItem.from_type(typing.cast(type, creator)), {}
41+
42+
param_hints = {}
43+
for param_name, param in sig.parameters.items():
44+
default = UNSET
45+
if param.default is not param.empty:
46+
default = param.default
47+
if param.annotation is not param.empty:
48+
param_hints[param_name] = SignatureItem.from_type(param.annotation, default=default)
49+
else:
50+
param_hints[param_name] = SignatureItem(default=default)
51+
if sig.return_annotation:
52+
return_sig = SignatureItem.from_type(sig.return_annotation)
53+
elif isinstance(creator, type):
54+
return_sig = SignatureItem.from_type(creator)
55+
else:
56+
return_sig = SignatureItem()
57+
58+
return return_sig, param_hints

packages/modern-di/tests_core/helpers/__init__.py

Whitespace-only changes.

packages/modern-di/tests_core/helpers/test_type_helpers.py

Lines changed: 0 additions & 69 deletions
This file was deleted.

packages/modern-di/tests_core/providers/test_factory.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class AnotherCreator:
2222

2323
class MyGroup(Group):
2424
app_factory = providers.Factory(creator=SimpleCreator, kwargs={"dep1": "original"})
25+
app_factory_unresolvable = providers.Factory(creator=SimpleCreator, bound_type=None)
2526
request_factory = providers.Factory(scope=Scope.REQUEST, creator=DependentCreator)
2627
request_factory_with_di_container = providers.Factory(scope=Scope.REQUEST, creator=AnotherCreator)
2728

@@ -39,6 +40,12 @@ def test_app_factory() -> None:
3940
assert instance2 is not instance3
4041

4142

43+
def test_app_factory_unresolvable() -> None:
44+
app_container = Container(groups=[MyGroup])
45+
with pytest.raises(RuntimeError, match="Argument dep1 cannot be resolved, type=<class 'str'"):
46+
app_container.resolve_provider(MyGroup.app_factory_unresolvable)
47+
48+
4249
def test_request_factory() -> None:
4350
app_container = Container(groups=[MyGroup])
4451
request_container = app_container.build_child_container(scope=Scope.REQUEST)
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import dataclasses
2+
import typing
3+
4+
import pytest
5+
from modern_di.types_parser import SignatureItem, parse_creator
6+
7+
8+
@pytest.mark.parametrize(
9+
("type_", "result"),
10+
[
11+
(int, SignatureItem(arg_type=int)),
12+
(list[int], SignatureItem(arg_type=list, args=[int])),
13+
(dict[str, typing.Any], SignatureItem(arg_type=dict, args=[str, typing.Any])),
14+
(typing.Optional[str], SignatureItem(arg_type=str, is_nullable=True)), # noqa: UP045
15+
(str | None, SignatureItem(arg_type=str, is_nullable=True)),
16+
(str | int, SignatureItem(args=[str, int])),
17+
],
18+
)
19+
def test_signature_item_parser(type_: type, result: SignatureItem) -> None:
20+
assert SignatureItem.from_type(type_) == result
21+
22+
23+
def simple_func(arg1: int, arg2: str | None = None) -> int: ... # type: ignore[empty-body]
24+
def none_func(arg1: int, arg2: str | None = None) -> None: ...
25+
async def async_func(arg1: int = 1, arg2="str") -> int: ... # type: ignore[no-untyped-def,empty-body] # noqa: ANN001
26+
27+
28+
@dataclasses.dataclass(kw_only=True, slots=True, frozen=True)
29+
class SomeDataClass:
30+
arg1: str
31+
arg2: int
32+
33+
34+
@dataclasses.dataclass(kw_only=True, slots=True)
35+
class DataClassInitFalse:
36+
arg1: str
37+
arg2: int = dataclasses.field(init=False)
38+
39+
40+
class SomeRegularClass:
41+
def __init__(self, arg1: str, arg2: int) -> None: ...
42+
43+
44+
@pytest.mark.parametrize(
45+
("creator", "result"),
46+
[
47+
(
48+
simple_func,
49+
(
50+
SignatureItem(arg_type=int),
51+
{
52+
"arg1": SignatureItem(arg_type=int),
53+
"arg2": SignatureItem(arg_type=str, is_nullable=True, default=None),
54+
},
55+
),
56+
),
57+
(
58+
none_func,
59+
(
60+
SignatureItem(),
61+
{
62+
"arg1": SignatureItem(arg_type=int),
63+
"arg2": SignatureItem(arg_type=str, is_nullable=True, default=None),
64+
},
65+
),
66+
),
67+
(
68+
async_func,
69+
(
70+
SignatureItem(arg_type=int),
71+
{
72+
"arg1": SignatureItem(arg_type=int, default=1),
73+
"arg2": SignatureItem(default="str"),
74+
},
75+
),
76+
),
77+
(
78+
SomeDataClass,
79+
(
80+
SignatureItem(arg_type=SomeDataClass),
81+
{
82+
"arg1": SignatureItem(arg_type=str),
83+
"arg2": SignatureItem(arg_type=int),
84+
},
85+
),
86+
),
87+
(
88+
DataClassInitFalse,
89+
(
90+
SignatureItem(arg_type=DataClassInitFalse),
91+
{
92+
"arg1": SignatureItem(arg_type=str),
93+
},
94+
),
95+
),
96+
(
97+
SomeRegularClass,
98+
(
99+
SignatureItem(arg_type=SomeRegularClass),
100+
{
101+
"arg1": SignatureItem(arg_type=str),
102+
"arg2": SignatureItem(arg_type=int),
103+
},
104+
),
105+
),
106+
(int, (SignatureItem(arg_type=int), {})),
107+
],
108+
)
109+
def test_parse_creator(creator: type, result: tuple[SignatureItem | None, dict[str, SignatureItem]]) -> None:
110+
assert parse_creator(creator) == result

0 commit comments

Comments
 (0)