Skip to content

Commit 2313b51

Browse files
fix type checkers (#933)
* fix type checkers * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_field_type_annotations.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update test_field_type_annotations.py * fix type hinting * fix duplicate name * typo --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cf9a896 commit 2313b51

8 files changed

Lines changed: 418 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ dev = [
3636
"ipykernel>=6.29.5",
3737
"mlflow>=2.20.0",
3838
"pre-commit>=4.1.0",
39+
"pyright>=1.1.403",
3940
"pytest>=8.3.4",
4041
"pytest-benchmark[histogram]>=5.1.0",
4142
"pytest-cov>=6.0.0",
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
"""Test file to verify type annotations work correctly for zntrack fields.
2+
3+
This test should pass type checking with pyright/mypy.
4+
The main goal is to ensure that patterns like:
5+
field: int = zntrack.params()
6+
field: str = zntrack.outs()
7+
do not raise type errors when no explicit value is provided.
8+
"""
9+
10+
from pathlib import Path
11+
12+
import pandas as pd
13+
14+
import zntrack
15+
16+
17+
class TestFieldAnnotations(zntrack.Node):
18+
"""Test basic field type annotations without explicit values."""
19+
20+
# These should NOT raise type errors (the main fix)
21+
param_int: int = zntrack.params()
22+
param_str: str = zntrack.params()
23+
param_dict: dict = zntrack.params()
24+
25+
out_int: int = zntrack.outs()
26+
out_str: str = zntrack.outs()
27+
out_dict: dict = zntrack.outs()
28+
29+
metric_dict: dict = zntrack.metrics()
30+
31+
dep_int: int = zntrack.deps()
32+
dep_str: str = zntrack.deps()
33+
34+
plots_data: pd.DataFrame = (
35+
zntrack.plots()
36+
) # pandas.DataFrame, but we don't enforce the type
37+
38+
# Path fields
39+
params_path_str: str = zntrack.params_path()
40+
params_path_path: Path = zntrack.params_path()
41+
outs_path_str: str = zntrack.outs_path()
42+
outs_path_path: Path = zntrack.outs_path()
43+
deps_path_str: str = zntrack.deps_path()
44+
deps_path_path: Path = zntrack.deps_path()
45+
plots_path_path: Path = zntrack.plots_path()
46+
metrics_path_path: Path = zntrack.metrics_path()
47+
48+
def run(self):
49+
pass
50+
51+
52+
class TestFieldAnnotationsWithDefaults(zntrack.Node):
53+
"""Test that type safety is maintained when providing explicit values."""
54+
55+
# These should work fine (correct type matching)
56+
param_good_int: int = zntrack.params(42)
57+
param_good_str: str = zntrack.params("hello")
58+
param_good_tuple: str = zntrack.params(("word",))
59+
60+
# Using default_factory
61+
param_good_dict: dict = zntrack.params(default_factory=dict)
62+
param_factory: list = zntrack.params(default_factory=list)
63+
64+
# Path fields with correct types
65+
params_path_str_with_val: str = zntrack.params_path("config.yaml")
66+
params_path_path_with_val: Path = zntrack.params_path(Path("config.yaml"))
67+
68+
# Path fields using zntrack.nwd
69+
outs_path_in_nwd: Path = zntrack.outs_path(zntrack.nwd / "output.txt")
70+
metrics_path_in_nwd: Path = zntrack.metrics_path(zntrack.nwd / "metrics.json")
71+
plots_path_in_nwd: Path = zntrack.plots_path(zntrack.nwd / "plot.png")
72+
73+
# plots_path_list
74+
plots_path_list: list = zntrack.plots_path(
75+
default_factory=lambda: ["plot1.png", "plot2.png"]
76+
)
77+
78+
# path using tuples
79+
params_path_tuple: tuple[Path, ...] = zntrack.params_path(
80+
(zntrack.nwd / "config1.yaml", zntrack.nwd / "config2.yaml")
81+
)
82+
outs_path_tuple: tuple[Path, ...] = zntrack.outs_path(
83+
(zntrack.nwd / "output1.txt", zntrack.nwd / "output2.txt")
84+
)
85+
86+
def run(self):
87+
pass
88+
89+
90+
class TestTypeSafetyErrors(zntrack.Node):
91+
"""These should cause type errors to ensure type safety is maintained."""
92+
93+
# # Type mismatches with explicit values (should error when uncommented)
94+
param_bad_int: int = zntrack.params(
95+
"string"
96+
) # Should error: str not assignable to int
97+
param_bad_str: str = zntrack.params(42) # Should error: int not assignable to str
98+
param_bad_dict: dict = zntrack.params(
99+
"not_a_dict"
100+
) # Should error: str not assignable to dict
101+
102+
# # # Path field type mismatches (should error when uncommented)
103+
params_path_bad: Path = zntrack.params_path(
104+
1234
105+
) # Should error: int not assignable to path types
106+
outs_path_bad: str = zntrack.outs_path(
107+
42
108+
) # Should error: int not assignable to path types
109+
deps_path_bad: Path = zntrack.deps_path(
110+
3.14
111+
) # Should error: float not assignable to path types
112+
113+
# List vs single type mismatches (should error when uncommented)
114+
outs_path_list_mismatch: str = zntrack.outs_path(
115+
("file1.txt", "file2.txt")
116+
) # Should error: list not assignable to str
117+
params_path_list_mismatch: Path = zntrack.params_path(
118+
("config1.yaml", "config2.yaml")
119+
) # Should error: list not assignable to Path
120+
# using default factory with list
121+
params_path_list_mismatch_lambda: Path = zntrack.params_path(
122+
default_factory=lambda: ["config1.yaml", "config2.yaml"]
123+
) # Should error: list not assignable to Path
124+
125+
# Factory function returning wrong type (should error when uncommented)
126+
param_factory_bad: int = zntrack.params(
127+
default_factory=str
128+
) # Should error: str() returns str, not int
129+
param_factory_bad2: dict = zntrack.params(
130+
default_factory=list
131+
) # Should error: list() returns list, not dict
132+
133+
# zntrack.nwd as string
134+
nwd_as_str: str = zntrack.outs_path(
135+
zntrack.nwd / "output.txt"
136+
) # Should error: str not assignable to Path

uv.lock

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

zntrack/fields/deps.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from zntrack.node import Node
1313
from zntrack.utils.filesystem import resolve_state_file_path
1414

15+
_T = t.TypeVar("_T")
16+
1517

1618
def _deps_getter(self: "Node", name: str):
1719
zntrack_path = resolve_state_file_path(
@@ -84,6 +86,14 @@ def _deps_getter(self: "Node", name: str):
8486
return content
8587

8688

89+
@t.overload
90+
def deps() -> t.Any: ...
91+
92+
93+
@t.overload
94+
def deps(default: _T, **kwargs) -> _T: ...
95+
96+
8797
def deps(default=dataclasses.MISSING, **kwargs) -> t.Any:
8898
"""Define dependencies for a node.
8999

zntrack/fields/outs_and_metrics.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ def _metrics_save_func(self: "Node", name: str, suffix: str):
3636
raise TypeError(f"Error while saving {name} to {self.nwd / name}.json") from err
3737

3838

39+
@t.overload
40+
def outs(*, cache: bool = True, independent: bool = False, **kwargs) -> t.Any: ...
41+
42+
3943
def outs(*, cache: bool = True, independent: bool = False, **kwargs) -> t.Any:
4044
"""Define output for a node.
4145
@@ -72,6 +76,12 @@ def outs(*, cache: bool = True, independent: bool = False, **kwargs) -> t.Any:
7276
)
7377

7478

79+
@t.overload
80+
def metrics(
81+
*, cache: bool | None = None, independent: bool = False, **kwargs
82+
) -> t.Any: ...
83+
84+
7585
def metrics(*, cache: bool | None = None, independent: bool = False, **kwargs) -> t.Any:
7686
"""Define metrics for a node.
7787

zntrack/fields/params.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ def _params_getter(self: "Node", name: str):
2121

2222

2323
# Overloads for type checking
24+
@t.overload
25+
def params() -> t.Any: ...
26+
27+
2428
@t.overload
2529
def params(default: _T, **kwargs) -> _T: ...
2630

zntrack/fields/plots.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import typing as t
2+
13
import pandas as pd
24

35
from zntrack.config import NOT_AVAILABLE, ZNTRACK_OPTION_PLOTS_CONFIG, FieldTypes
@@ -24,6 +26,22 @@ def _plots_getter(self: "Node", name: str, suffix: str):
2426
return pd.read_csv(f, index_col=0)
2527

2628

29+
@t.overload
30+
def plots(
31+
*,
32+
y: str | list[str] | None = None,
33+
cache: bool = True,
34+
independent: bool = False,
35+
x: str = "step",
36+
x_label: str | None = None,
37+
y_label: str | None = None,
38+
template: str | None = None,
39+
title: str | None = None,
40+
autosave: bool = False,
41+
**kwargs,
42+
) -> t.Any: ...
43+
44+
2745
def plots(
2846
*,
2947
y: str | list[str] | None = None,

0 commit comments

Comments
 (0)