feat: register_pytree_node — allow custom classes in mx.compile#3500
st-adam wants to merge 3 commits into
zcbenz
left a comment
I think the register_pytree_node API is a good thing to have, but utils.compile would be a bad addition.
The register_pytree_node API should be implemented in the C++ layer so that mx.compile is aware of it directly. We should probably also remove the Python versions of the tree utils and expose the C++ ones instead, so we don't have to duplicate the register_pytree_node implementation in two languages.
Adds a JAX-style pytree registration mechanism so third-party Python
classes can flow through mx.compile, tree_visit, tree_map, and the
rest of MLX's tree utilities.
Motivation
----------
mx.compile rejects any function argument that is not a plain array,
list, dict, tuple, or scalar constant:
ValueError: [compile] Function arguments must be trees of arrays or
constants (floats, ints, strings, or None), but received type
mlx_lm.models.cache.ArraysCache.
Any model whose forward pass receives a custom cache object — every
hybrid SSM+attention model in mlx-lm (Qwen 3.5/3.6, Llama 4, Gemma 3n,
etc.) — therefore cannot be compiled, even though the computation is
fully expressible as MLX ops.
Implementation
--------------
The registry, the public API, and all tree-traversal hooks live in
C++ (per review feedback: a Python-side compile wrapper would
duplicate the implementation across two languages).
python/src/trees.h, python/src/trees.cpp:
* PytreeNodeDef — (flatten_fn, unflatten_fn) pair.
* registry() — heap-allocated map keyed by PyTypeObject*, never
freed. Avoids the use-after-finalize segfault
that a function-local static would hit when
Python tears down the interpreter while
stored nb::callables still hold refs. Same
lifetime pattern used by structure_sentinel().
* register_pytree_node(cls, flatten_fn, unflatten_fn) — exposed to
Python as mx.register_pytree_node.
* is_registered_pytree, flatten_registered, unflatten_registered,
registered_pytree_fingerprint — internal helpers.
* tree_visit / tree_map (multi-tree and single-tree overloads) and
tree_visit_update now recurse into registered types, so
tree_unflatten through the compile path reconstructs them.
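The registry and the recursion into registered types can be modelled in a few lines of pure Python. This is only a sketch of the idea, not the PR's C++ code: `_REGISTRY`, this `tree_map`, and the `Cache` class are hypothetical stand-ins, and the `unflatten_fn(aux_data, children)` argument order is assumed from the JAX convention the API mirrors.

```python
# Pure-Python model of a pytree registry; not the PR's C++ implementation.
_REGISTRY = {}  # maps a class to its (flatten_fn, unflatten_fn) pair


def register_pytree_node(cls, flatten_fn, unflatten_fn):
    """Record how to take instances of cls apart and put them back together."""
    if not isinstance(cls, type):
        raise TypeError("cls must be a class")
    _REGISTRY[cls] = (flatten_fn, unflatten_fn)


def tree_map(fn, tree):
    """Apply fn to every leaf, rebuilding containers and registered nodes."""
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, c) for c in tree)
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    entry = _REGISTRY.get(type(tree))
    if entry is not None:
        flatten_fn, unflatten_fn = entry
        children, aux = flatten_fn(tree)
        # Recurse into the children, then rebuild the original type.
        return unflatten_fn(aux, [tree_map(fn, c) for c in children])
    return fn(tree)  # anything else is treated as a leaf


class Cache:  # hypothetical stand-in for a model cache object
    def __init__(self, keys, values, offset):
        self.keys, self.values, self.offset = keys, values, offset


register_pytree_node(
    Cache,
    lambda c: ([c.keys, c.values], c.offset),  # (children, aux_data)
    lambda aux, children: Cache(children[0], children[1], aux),
)

doubled = tree_map(lambda x: x * 2, {"cache": Cache(1, 2, offset=7), "w": [3]})
```

Here `offset` travels as aux data, so it is carried through unchanged while the two children are mapped like any other leaf.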
python/src/transforms.cpp:
* PyCompiledFun::call_impl::recurse adds a pytree_identifier branch:
flattens the registered node into its children and embeds the
type-id + aux hash in the constants vector, so two structurally
different registered instances retrace correctly.
* Error message updated to mention mx.register_pytree_node.
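The effect of putting a type-id + aux hash into the constants can be illustrated with a toy cache. The exact mixing function is not shown in this thread, so the sketch below assumes a boost-style `hash_combine` with the 64-bit golden-ratio constant; `fingerprint`, `Tagged`, and `compiled_call` are hypothetical names used only for illustration.

```python
MASK64 = (1 << 64) - 1
GOLDEN = 0x9E3779B97F4A7C15  # 2**64 divided by the golden ratio


def fingerprint(obj, aux):
    """Mix the identity of obj's type with hash(aux) into one 64-bit value."""
    seed = id(type(obj)) & MASK64
    if aux is None:
        return seed
    h = hash(aux) & MASK64  # raises TypeError if aux is unhashable
    # Assumed mixing step (boost-style hash_combine), truncated to 64 bits.
    mixed = (h + GOLDEN + (seed << 6) + (seed >> 2)) & MASK64
    return seed ^ mixed


class Tagged:  # hypothetical registered type whose aux data is `tag`
    def __init__(self, tag):
        self.tag = tag


traces = {}  # toy compile cache: fingerprint -> trace label


def compiled_call(obj):
    """Reuse a trace when the fingerprint matches, otherwise 'retrace'."""
    key = fingerprint(obj, obj.tag)
    if key not in traces:
        traces[key] = f"trace {len(traces)}"
    return traces[key]


first = compiled_call(Tagged(1))
again = compiled_call(Tagged(1))  # same type and aux: cached trace reused
other = compiled_call(Tagged(2))  # different aux: a new trace is recorded
```

Because `id(type)` is only stable within one process, a fingerprint like this is suitable as an in-memory cache key but not as anything persisted.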
python/src/mlx.cpp:
* Wires init_trees() into NB_MODULE.
python/mlx/utils.py:
* re-exports mlx.core.register_pytree_node so users can do either
`import mlx.core as mx; mx.register_pytree_node(...)` or
`from mlx.utils import register_pytree_node`.
Test
----
python/tests/test_compile.py::test_compile_registered_pytree_node:
* mx.compile rejects an unregistered custom class.
* After registration the compiled forward returns the correct value.
* aux_data tagged differently on two subclasses retraces cleanly.
* flatten_fn returning a malformed value surfaces a clear ValueError.
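The malformed-`flatten_fn` case can be pictured with a small validator. This is a Python sketch under the assumption that `flatten_fn` returns a `(children, aux_data)` pair; `checked_flatten` is a hypothetical helper, and the error text is illustrative rather than the message the PR emits.

```python
def checked_flatten(flatten_fn, obj):
    """Call flatten_fn and normalise its result to (list_of_children, aux)."""
    result = flatten_fn(obj)
    try:
        children, aux = result  # must unpack into exactly two items
    except (TypeError, ValueError) as exc:
        raise ValueError("flatten_fn must return (children, aux_data)") from exc
    if not isinstance(children, (list, tuple)):
        raise ValueError("flatten_fn children must be a list or tuple")
    return list(children), aux


# A well-formed result passes through with children normalised to a list.
ok = checked_flatten(lambda o: ((1, 2), "meta"), None)
```

Converting every failure into a single `ValueError` keeps the message the user sees independent of how the result happened to be malformed.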
All existing tests still pass:
- python/tests/test_compile.py — 55 passed
- python/tests/test_tree.py — 4 passed
- python/tests/test_autograd.py + test_vmap.py — full suite green
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
59c91c6 to
727264c
@zcbenz Thanks for the review. I've moved the implementation into C++ as requested:
Verified locally on a CPU build:
// Combines id(type) and hash(aux) so that compile retraces if either changes.
uint64_t registered_pytree_fingerprint(nb::handle obj);

void init_trees(nb::module_& m);
This declaration is not needed, already declared in python/src/mlx.cpp.
} // namespace

void register_pytree_node(
    nb::object cls,
You can use nb::type_object and let nanobind do the check.
  throw std::invalid_argument(
      "[register_pytree_node] cls must be a Python class object.");
}
PyTypeObject* type = reinterpret_cast<PyTypeObject*>(cls.ptr());
We should probably just use PyObject* as the registry key; force-converting to PyTypeObject* adds code complexity without much benefit.
    nb::handle obj);

// Calls the registered unflatten_fn for the given type object.
nb::object unflatten_registered(
flatten_registered and unflatten_registered are helpers that do not need to be exposed in the header.
nb::object aux = seq[1];
if (!aux.is_none()) {
  try {
    auto h = aux.attr("__hash__")();

nb::object children_obj = seq[0];
nb::object aux = seq[1];

std::vector<nb::object> children;
Can you just do auto children = nb::cast<std::vector<nb::object>>(seq[0]);?
      "[flatten_registered] type is not registered as a pytree node");
}
nb::object result = it->second.flatten_fn(obj);
if (!nb::isinstance<nb::tuple>(result) && !nb::isinstance<nb::list>(result)) {
The nb::cast<nb::sequence> should already do the type check, so I think this check is redundant.
auto seq = nb::cast<nb::sequence>(result);
if (nb::len(seq) == 2) {
  nb::object aux = seq[1];
  if (!aux.is_none()) {
There is no need to be defensive here since it is inside a try/catch; just do the casts and let them throw when something goes wrong.
for (auto& c : children) {
  new_children.push_back(recurse(c));
}
return unflatten_registered(type_handle, aux, new_children);
You can just pass subtree here?
- Header surface trimmed: only `register_pytree_node`, `is_registered_pytree`, `pytree_children`, `registered_pytree_fingerprint` are exposed. `flatten_registered`/`unflatten_registered` are internal helpers in trees.cpp and `init_trees` is no longer redeclared (already in mlx.cpp).
- `register_pytree_node` now takes `nb::type_object` so nanobind enforces the type check; manual `PyType_Check` is gone.
- Registry keyed by `PyObject*` directly; no `PyTypeObject*` reinterpret cast at the boundary.
- Internal `flatten_registered` uses `nb::cast<std::vector<nb::object>>` for children and lets `nb::cast<nb::sequence>` enforce the list/tuple shape.
- Fingerprint uses `nb::hash` and lets nanobind throw on unhashable aux (no extra defensive casting).
- `tree_visit_update` / `tree_map` pass the subtree handle directly to `unflatten_registered` instead of fabricating a type handle.
@zcbenz Thanks for the detailed review — all nine inline comments addressed in 2f7c5e6:
Lint also fixed (clang-format + black) in the previous commit. Local CPU run:
Sorry for being late to this, but why do we think that is better than
A more real-world example of this is simply the training loop. We don't need the model to be a typed PyTree; we can very quickly make the whole call "pure functional" by passing in a dictionary of parameters
Changing gears after discussing whether we should add this at all, I see two issues to be addressed in the code
Fixes #3499.
Addresses review feedback from @zcbenz: implementation moved to C++ so
`mx.compile` is natively aware of registered pytree types. The previous Python-side `utils.compile` wrapper has been removed.
Summary
- `mx.register_pytree_node(cls, flatten_fn, unflatten_fn)` API (mirrors `jax.tree_util.register_pytree_node`)
- Registry in `python/src/trees.cpp`, exposed via the `init_trees()` nanobind module
- `PyCompiledFun::call_impl` recurse handles registered types: their type-id + aux hash participate in the compile cache key
- Tree utilities (`tree_visit`, `tree_map`, `tree_visit_update`) recurse into registered nodes, the same code path used by `tree_unflatten` on the compile path
- `mlx.utils` re-exports `register_pytree_node` for the natural `from mlx.utils import …` access pattern
API
Test plan
`python/tests/test_compile.py::TestCompile::test_compile_registered_pytree_node` covers:
- malformed `flatten_fn` return surfaces a clean `ValueError`
Regression sweep on the CPU build:
- `python/tests/test_compile.py`: 55 passed
- `python/tests/test_tree.py`: 4 passed
- `python/tests/test_autograd.py` + `test_vmap.py`: full suites green
Implementation notes
- A function-local static `std::unordered_map<PyTypeObject*, PytreeNodeDef>` triggers a use-after-finalize segfault at interpreter shutdown because the stored `nb::callable`s outlive Python state, so the registry is heap-allocated and never freed. This matches the lifetime trick used by `structure_sentinel()`.
- The fingerprint combines `id(type)` with `hash(aux_data)` (golden-ratio mixing constant) so two structurally distinct registered instances retrace correctly. Unhashable aux falls back to type-only fingerprinting.
- `flatten_fn` may return children as either `list` or `tuple`; the registry rejects non-sequence return values up front.
Files
- `python/src/trees.h` (+38): public API + helpers
- `python/src/trees.cpp` (+263): registry, `init_trees`, tree-traversal hooks
- `python/src/transforms.cpp` (+15): `PyCompiledFun::call_impl` recurse + error-message hint
- `python/src/mlx.cpp` (+2): wire `init_trees()` into `NB_MODULE`
- `python/mlx/utils.py` (+1 net): re-export `register_pytree_node`
- `python/tests/test_compile.py` (+63): coverage
🤖 Generated with Claude Code