Skip to content

Commit 2fa1f31

Browse files
store lockfile in node-meta.json (#913)
* store lockfile in `node-meta.json` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test lockfile * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * split line * typo * test `mp_start_stage_lock` and `mp_join_stage_lock` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * mock `mp.Queue` * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a680daa commit 2fa1f31

5 files changed

Lines changed: 117 additions & 1 deletion

File tree

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import json
2+
from pathlib import Path
3+
from unittest.mock import MagicMock
4+
5+
import pytest
6+
7+
import zntrack
8+
from zntrack.utils.lockfile import get_stage_lock
9+
10+
11+
class ReadFileContent(zntrack.Node):
    """Minimal node declaring one file dependency and one parameter.

    ``run`` intentionally does nothing; the tests in this module only
    inspect the lockfile recorded for the node's inputs.
    """

    deps_file: Path = zntrack.deps_path()
    params: str = zntrack.params()

    def run(self):
        """Do nothing; the node produces no outputs."""
        return None
17+
18+
19+
@pytest.fixture()
def lockfile_01():
    """Return the lockfile content the tests expect for ``ReadFileContent``.

    The mapping holds the stage command, the single ``data.txt`` dependency
    (md5 hash and size) and the parameters read from ``params.yaml``.
    """
    data_dependency = {
        "hash": "md5",
        "md5": "6dbd01b4309de2c22b027eb35a3ce18b",
        "path": "data.txt",
        "size": 11,
    }
    stage_params = {"ReadFileContent": {"params": "test"}}
    return dict(
        cmd="zntrack run test_node_meta.ReadFileContent --name ReadFileContent",
        deps=[data_dependency],
        params={"params.yaml": stage_params},
    )
33+
34+
35+
def test_node_meta_lock(proj_path, lockfile_01):
    """Check that a node loaded with ``from_rev`` exposes the stored lockfile."""
    project = zntrack.Project()

    data_file = Path("data.txt")
    data_file.write_text("Lorem Ipsum")

    with project:
        # The instance itself is not needed; creating it inside the project
        # context is what matters here.
        ReadFileContent(deps_file=data_file, params="test")

    project.repro()
    # TODO: do we want the node to require from_rev()?
    #  to update lockfile and other node-meta states?
    loaded_node = ReadFileContent.from_rev()
    assert loaded_node.state.lockfile == lockfile_01
49+
50+
51+
def test_node_meta_lock_mp(proj_path, lockfile_01):
    """Check that ``get_stage_lock`` puts the expected stage lock on the queue."""
    project = zntrack.Project()

    data_file = Path("data.txt")
    data_file.write_text("Lorem Ipsum")

    with project:
        node = ReadFileContent(deps_file=data_file, params="test")

    project.repro()

    result_queue = MagicMock()  # mock multiprocessing.Queue
    get_stage_lock(node.name, result_queue)
    assert result_queue.put.call_count == 1

    # First positional argument of the single ``put`` call.
    payload = result_queue.put.call_args.args[0]
    # Round-trip through JSON before comparing -- presumably to normalise the
    # payload to plain JSON types (TODO confirm).
    lockfile = json.loads(json.dumps(payload))
    for key in ("cmd", "deps", "params"):
        assert lockfile[key] == lockfile_01[key]

zntrack/cli/cli.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from zntrack import Node, utils
1515
from zntrack.state import PLUGIN_LIST
1616
from zntrack.utils.import_handler import import_handler
17+
from zntrack.utils.lockfile import mp_join_stage_lock, mp_start_stage_lock
1718
from zntrack.utils.misc import load_env_vars
1819

1920
load_env_vars()
@@ -54,7 +55,11 @@ def main(
5455

5556
@app.command()
5657
def run(
57-
node_path: str, name: str = None, meta_only: bool = False, method: str = "run"
58+
node_path: str,
59+
name: str = None,
60+
meta_only: bool = False,
61+
method: str = "run",
62+
save_lockfile: bool = True,
5863
) -> None:
5964
"""Execute a ZnTrack Node.
6065
@@ -70,6 +75,8 @@ def run(
7075
Save only the metadata.
7176
method : str, default 'run'
7277
The method to run on the node.
78+
save_lockfile : bool, default True
79+
Save the lockfile for the inputs into the node-meta.json file.
7380
7481
"""
7582
start_time = datetime.datetime.now()
@@ -78,13 +85,19 @@ def run(
7885

7986
cls: Node = utils.import_handler.import_handler(node_path)
8087
node: Node = cls.from_rev(name=name, running=True)
88+
if save_lockfile:
89+
queue, proc = mp_start_stage_lock(name)
8190
node.state.increment_run_count()
8291
node.state.save_node_meta()
8392
# dynamic version of node.run()
8493
getattr(node, method)()
8594
node.save()
95+
if save_lockfile:
96+
stage_lock = mp_join_stage_lock(queue, proc)
97+
node.state.set_lockfile(stage_lock)
8698
run_time = datetime.datetime.now() - start_time
8799
node.state.add_run_time(run_time)
100+
88101
node.state.save_node_meta()
89102

90103

zntrack/node.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,12 +242,14 @@ def from_rev(
242242
content = json.load(f)
243243
run_count = content.get("run_count", 0)
244244
run_time = content.get("run_time", 0)
245+
lockfile = content.get("lockfile", None)
245246
if node_uuid := content.get("uuid", None):
246247
instance._uuid = uuid.UUID(node_uuid)
247248
instance.__dict__["state"]["run_count"] = run_count
248249
instance.__dict__["state"]["run_time"] = datetime.timedelta(
249250
seconds=run_time
250251
)
252+
instance.__dict__["state"]["lockfile"] = lockfile
251253
if not instance.state.lazy_evaluation:
252254
for field in dataclasses.fields(cls):
253255
_ = getattr(instance, field.name)

zntrack/state.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class NodeStatus:
8080
group: Group | None = None
8181
run_time: datetime.timedelta | None = None
8282
path: pathlib.Path = dataclasses.field(default_factory=pathlib.Path)
83+
lockfile: dict | None = None
8384
# TODO: move node name and nwd to here as well
8485

8586
@property
@@ -258,6 +259,10 @@ def add_run_time(self, run_time: datetime.timedelta) -> None:
258259
def increment_run_count(self) -> None:
259260
self.node.__dict__["state"]["run_count"] = self.run_count + 1
260261

262+
def set_lockfile(self, lockfile: dict) -> None:
    """Store ``lockfile`` in the node's state.

    The value is written under the ``"lockfile"`` key of the ``"state"``
    entry in ``self.node.__dict__``.
    """
    node_state = self.node.__dict__["state"]
    node_state["lockfile"] = lockfile
265+
261266
def save_node_meta(self) -> None:
262267
node_meta_content = {
263268
"uuid": str(self.node.uuid),
@@ -267,6 +272,8 @@ def save_node_meta(self) -> None:
267272

268273
if self.run_time is not None:
269274
node_meta_content["run_time"] = self.run_time.total_seconds()
275+
if self.lockfile is not None:
276+
node_meta_content["lockfile"] = self.lockfile
270277

271278
with contextlib.suppress(importlib.metadata.PackageNotFoundError):
272279
module = self.node.__module__.split(".")[0]

zntrack/utils/lockfile.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import typing as t
2+
from multiprocessing import Process, Queue
3+
4+
from dvc.api import DVCFileSystem
5+
from dvc.stage.serialize import to_single_stage_lockfile
6+
7+
8+
def get_stage_lock(name: str, queue: Queue) -> None:
    """Compute the lock entry of a DVC stage and put it on ``queue``.

    Parameters
    ----------
    name : str
        Target handed to ``repo.stage.collect``; the first collected stage
        is used.
    queue : Queue
        Queue that receives the output of ``to_single_stage_lockfile``.
    """
    repo = DVCFileSystem().repo
    collected_stages = repo.stage.collect(name)
    stage = next(iter(collected_stages))
    stage.save_deps(allow_missing=False)
    # Send the result back to the main process
    queue.put(to_single_stage_lockfile(stage))
14+
15+
16+
def mp_start_stage_lock(name: str) -> t.Tuple[Queue, Process]:
    """Start a process that runs ``get_stage_lock`` for ``name``.

    Parameters
    ----------
    name : str
        Forwarded unchanged to ``get_stage_lock``.

    Returns
    -------
    tuple of (Queue, Process)
        The queue the process puts its result on, and the already started
        process.
    """
    result_queue = Queue()
    worker = Process(target=get_stage_lock, args=(name, result_queue))
    worker.start()
    return result_queue, worker
21+
22+
23+
def mp_join_stage_lock(queue: Queue, p: Process) -> dict:
    """Collect the stage lock produced by the process and join it.

    The result is fetched from ``queue`` *before* joining ``p``: joining a
    process that still has unflushed items on a multiprocessing queue can
    deadlock. The queue is polled with a timeout so that a process that
    exited without putting a result is reported instead of blocking forever.

    Parameters
    ----------
    queue : Queue
        Queue on which the process puts the stage lock mapping.
    p : Process
        The started process; only ``is_alive`` and ``join`` are required.

    Returns
    -------
    dict
        The ``cmd``, ``deps`` and ``params`` entries of the stage lock
        (whichever of them are present).

    Raises
    ------
    RuntimeError
        If the process is no longer alive and the queue is still empty.
    """
    # Local import: the ``queue`` parameter shadows the stdlib module name.
    from queue import Empty

    poll_interval = 0.1  # seconds between liveness checks
    while True:
        try:
            stage_lock = queue.get(timeout=poll_interval)
            break
        except Empty:
            if p.is_alive():
                continue
            # The process may have exited right after flushing its result,
            # so try once more before declaring failure.
            try:
                stage_lock = queue.get(timeout=poll_interval)
                break
            except Empty:
                raise RuntimeError(
                    "stage lock process exited "
                    f"(exitcode={getattr(p, 'exitcode', None)}) "
                    "without returning a result"
                ) from None

    # Safe to join now that the queue has been drained.
    p.join()
    return {k: v for k, v in stage_lock.items() if k in ("cmd", "deps", "params")}

0 commit comments

Comments
 (0)