
feat: register_pytree_node — allow custom classes in mx.compile #3500

Open
st-adam wants to merge 3 commits into ml-explore:main from st-adam:feat/register-pytree-node


@st-adam st-adam commented May 8, 2026

Fixes #3499.

Addresses review feedback from @zcbenz: implementation moved to C++ so mx.compile is natively aware of registered pytree types. The previous Python-side utils.compile wrapper has been removed.

Summary

  • New mx.register_pytree_node(cls, flatten_fn, unflatten_fn) API (mirrors jax.tree_util.register_pytree_node)
  • C++ registry in python/src/trees.cpp, exposed via init_trees() nanobind module
  • PyCompiledFun::call_impl recurse handles registered types: their type-id + aux hash participate in the compile cache key
  • Existing tree traversals (tree_visit, tree_map, tree_visit_update) recurse into registered nodes; this is the same code path tree_unflatten uses on the compile path
  • mlx.utils re-exports register_pytree_node for the natural from mlx.utils import … access pattern

API

import mlx.core as mx

class Pair:
    def __init__(self, a, b):
        self.a = a
        self.b = b

mx.register_pytree_node(
    Pair,
    lambda p: ([p.a, p.b], None),          # flatten: (children, aux_data)
    lambda _aux, children: Pair(*children), # unflatten
)

@mx.compile
def add_pair(p):
    return p.a + p.b

add_pair(Pair(mx.array(3), mx.array(4)))    # → array(7, dtype=int32)
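
To make the (children, aux_data) contract concrete, here is a minimal pure-Python sketch of the round trip that flatten_fn and unflatten_fn are expected to satisfy. It does not import MLX; the helper names flatten_pair and unflatten_pair are hypothetical and stand in for the two lambdas above, with plain integers in place of arrays.

```python
# Pure-Python model of the flatten/unflatten contract; no MLX required.
class Pair:
    def __init__(self, a, b):
        self.a = a
        self.b = b


def flatten_pair(p):
    # children: the leaves a tracer would see; aux_data: static metadata
    return [p.a, p.b], None


def unflatten_pair(aux_data, children):
    # Must accept exactly what flatten_pair produced, in the same order.
    return Pair(*children)


original = Pair(3, 4)
children, aux = flatten_pair(original)

# Rebuilding from untouched children reproduces the original fields.
rebuilt = unflatten_pair(aux, children)

# Rebuilding from transformed children is what a tree_map-style pass does.
doubled = unflatten_pair(aux, [2 * c for c in children])
```

The property that matters is that unflatten_fn(aux, children) inverts flatten_fn, so any transformation applied to the children can be packed back into the original class.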

Test plan

python/tests/test_compile.py::TestCompile::test_compile_registered_pytree_node covers:

  • Pre-registration rejection (unchanged error path)
  • Compiled forward with a registered custom class
  • aux_data participation in cache key (tagged subclass → distinct trace)
  • Malformed flatten_fn return surfaces a clean ValueError
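
The malformed-return case can be illustrated with a small pure-Python sketch. The function checked_flatten and its error messages are hypothetical; the actual check lives in C++ and its wording may differ. The sketch only assumes what the description states: flatten_fn returns a (children, aux_data) pair and children is a list or tuple.

```python
def checked_flatten(flatten_fn, obj):
    """Call flatten_fn and validate that it returned (children, aux_data)."""
    result = flatten_fn(obj)
    if not isinstance(result, (list, tuple)) or len(result) != 2:
        raise ValueError(
            "flatten_fn must return a (children, aux_data) pair, "
            f"got {type(result).__name__}"
        )
    children, aux_data = result
    if not isinstance(children, (list, tuple)):
        raise ValueError(
            f"children must be a list or tuple, got {type(children).__name__}"
        )
    return list(children), aux_data


# A well-formed flatten_fn: tuple children are accepted and normalized.
good = checked_flatten(lambda obj: ((1, 2), "tag"), object())

# A malformed flatten_fn: returns a bare int instead of a pair.
try:
    checked_flatten(lambda obj: 42, object())
    malformed_rejected = False
except ValueError:
    malformed_rejected = True
```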

Regression sweep on the CPU build:

  • python/tests/test_compile.py — 55 passed
  • python/tests/test_tree.py — 4 passed
  • python/tests/test_autograd.py + test_vmap.py — full suites green

Implementation notes

  • The C++ registry is heap-allocated and never freed, the same lifetime trick used by structure_sentinel(). A function-local static std::unordered_map<PyTypeObject*, PytreeNodeDef> would trigger a use-after-finalize segfault at interpreter shutdown because the stored nb::callables outlive Python state.
  • The compile-cache fingerprint mixes id(type) with hash(aux_data) (golden-ratio mixing constant) so two structurally distinct registered instances retrace correctly. Unhashable aux falls back to type-only fingerprinting.
  • flatten_fn may return children as either list or tuple; the registry rejects non-sequence return values up front.
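
As a rough illustration of the fingerprint described above, the following pure-Python sketch mixes id(type) with hash(aux_data). The exact mixing formula is not given in this description, so the boost-style hash_combine step here is an assumption; the function name fingerprint is hypothetical, and only the golden-ratio constant and the type-only fallback come from the notes.

```python
MASK64 = (1 << 64) - 1
GOLDEN_RATIO_64 = 0x9E3779B97F4A7C15  # 64-bit golden-ratio mixing constant


def fingerprint(cls, aux_data):
    """Combine the type identity with the aux hash into one 64-bit value."""
    seed = id(cls) & MASK64
    if aux_data is None:
        return seed
    try:
        h = hash(aux_data) & MASK64
    except TypeError:
        # Unhashable aux (e.g. a list): fall back to the type-only value.
        return seed
    # Assumed mixing step, in the style of boost::hash_combine.
    mixed = (h + GOLDEN_RATIO_64 + ((seed << 6) & MASK64) + (seed >> 2)) & MASK64
    return seed ^ mixed


class Tagged:
    pass


fp_one = fingerprint(Tagged, 1)
fp_two = fingerprint(Tagged, 2)
fp_unhashable = fingerprint(Tagged, [1, 2])
fp_type_only = fingerprint(Tagged, None)
```

With the same type and different aux hashes the two fingerprints differ, which is what forces a retrace; an unhashable aux collapses to the type-only value, so such instances share a cache entry.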

Files

  • python/src/trees.h (+38) — public API + helpers
  • python/src/trees.cpp (+263) — registry, init_trees, tree-traversal hooks
  • python/src/transforms.cpp (+15) — PyCompiledFun::call_impl recurse + error-message hint
  • python/src/mlx.cpp (+2) — wire init_trees() into NB_MODULE
  • python/mlx/utils.py (+1 net) — re-export register_pytree_node
  • python/tests/test_compile.py (+63) — coverage

🤖 Generated with Claude Code

Collaborator

@zcbenz zcbenz left a comment


I think the register_pytree_node API is a good thing to have, but utils.compile would be a bad addition.

The register_pytree_node API should be implemented in the C++ layer so mx.compile can be made aware of it directly. We should probably also remove the Python versions of the tree utils and expose the C++ ones instead, so we don't have to duplicate the register_pytree_node implementation in two languages.

Adds a JAX-style pytree registration mechanism so third-party Python
classes can flow through mx.compile, tree_visit, tree_map, and the
rest of MLX's tree utilities.

Motivation
----------
mx.compile rejects any function argument that is not a plain array,
list, dict, tuple, or scalar constant:

  ValueError: [compile] Function arguments must be trees of arrays or
  constants (floats, ints, strings, or None), but received type
  mlx_lm.models.cache.ArraysCache.

Any model whose forward pass receives a custom cache object — every
hybrid SSM+attention model in mlx-lm (Qwen 3.5/3.6, Llama 4, Gemma 3n,
etc.) — therefore cannot be compiled, even though the computation is
fully expressible as MLX ops.

Implementation
--------------
The registry, the public API, and all tree-traversal hooks live in
C++ (per review feedback: a Python-side compile wrapper would
duplicate the implementation across two languages).

python/src/trees.h, python/src/trees.cpp:

* PytreeNodeDef     — (flatten_fn, unflatten_fn) pair.
* registry()        — heap-allocated map keyed by PyTypeObject*, never
                      freed.  Avoids the use-after-finalize segfault
                      that a function-local static would hit when
                      Python tears down the interpreter while
                      stored nb::callables still hold refs.  Same
                      lifetime pattern used by structure_sentinel().
* register_pytree_node(cls, flatten_fn, unflatten_fn) — exposed to
                      Python as mx.register_pytree_node.
* is_registered_pytree, flatten_registered, unflatten_registered,
  registered_pytree_fingerprint — internal helpers.
* tree_visit / tree_map (multi-tree and single-tree overloads) and
  tree_visit_update now recurse into registered types, so
  tree_unflatten through the compile path reconstructs them.
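
The traversal behavior described above can be sketched in pure Python. This is a simplified model, not the C++ implementation: the module-level _REGISTRY dict and this tree_map are hypothetical stand-ins, and the sketch handles only lists, tuples, dicts, and registered classes.

```python
_REGISTRY = {}  # type -> (flatten_fn, unflatten_fn)


def register_pytree_node(cls, flatten_fn, unflatten_fn):
    _REGISTRY[cls] = (flatten_fn, unflatten_fn)


def tree_map(fn, tree):
    """Apply fn to every leaf, rebuilding containers and registered nodes."""
    entry = _REGISTRY.get(type(tree))
    if entry is not None:
        flatten_fn, unflatten_fn = entry
        children, aux = flatten_fn(tree)
        return unflatten_fn(aux, [tree_map(fn, c) for c in children])
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, c) for c in tree)
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)  # leaf


class Pair:
    def __init__(self, a, b):
        self.a = a
        self.b = b


register_pytree_node(
    Pair,
    lambda p: ([p.a, p.b], None),
    lambda _aux, children: Pair(*children),
)

result = tree_map(lambda x: x + 1, {"p": Pair(1, [2, 3]), "k": (4,)})
```

The registered node is rebuilt through its own unflatten_fn, so a Pair nested inside a dict comes back as a Pair with its children mapped.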

python/src/transforms.cpp:

* PyCompiledFun::call_impl::recurse adds a pytree_identifier branch:
  flattens the registered node into its children and embeds the
  type-id + aux hash in the constants vector, so two structurally
  different registered instances retrace correctly.
* Error message updated to mention mx.register_pytree_node.
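
To show why the type-id and aux hash belong in the cache key, here is a toy pure-Python model of a compile cache. It is only a sketch: toy_compile, Scaled, and the key layout are hypothetical, "tracing" is reduced to a counter, and the real cache also depends on things this model ignores, such as array shapes and dtypes.

```python
def toy_compile(fn, registry):
    """Cache 'traces' of fn per (type id, aux hash, child count) key."""
    cache = {}
    trace_count = 0

    def compiled(arg):
        nonlocal trace_count
        flatten_fn, _unflatten_fn = registry[type(arg)]
        children, aux = flatten_fn(arg)
        key = (id(type(arg)), hash(aux), len(children))
        if key not in cache:
            trace_count += 1  # a real compiler would trace fn here
            cache[key] = fn
        return cache[key](arg)

    compiled.trace_count = lambda: trace_count
    return compiled


class Scaled:
    def __init__(self, value, mode):
        self.value = value
        self.mode = mode  # static metadata, carried as aux_data


registry = {
    Scaled: (
        lambda s: ([s.value], s.mode),
        lambda mode, children: Scaled(children[0], mode),
    )
}


def apply(s):
    return s.value * 2 if s.mode == "double" else s.value


run = toy_compile(apply, registry)
results = [
    run(Scaled(1, "double")),
    run(Scaled(5, "double")),  # same aux: reuses the cached trace
    run(Scaled(5, "keep")),    # different aux: new cache entry
]
```

Because the branch in apply depends on static metadata, reusing the "double" trace for a "keep" instance would give the wrong result; keying on the aux hash avoids that.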

python/src/mlx.cpp:

* Wires init_trees() into NB_MODULE.

python/mlx/utils.py:

* re-exports mlx.core.register_pytree_node so users can do either
  `import mlx.core as mx; mx.register_pytree_node(...)` or
  `from mlx.utils import register_pytree_node`.

Test
----
python/tests/test_compile.py::test_compile_registered_pytree_node:

* mx.compile rejects an unregistered custom class.
* After registration the compiled forward returns the correct value.
* aux_data tagged differently on two subclasses retraces cleanly.
* flatten_fn returning a malformed value surfaces a clear ValueError.

All existing tests still pass:
- python/tests/test_compile.py        — 55 passed
- python/tests/test_tree.py           —  4 passed
- python/tests/test_autograd.py + test_vmap.py — full suite green

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@st-adam st-adam force-pushed the feat/register-pytree-node branch from 59c91c6 to 727264c May 11, 2026 06:36
@st-adam st-adam changed the title from "feat(utils): add register_pytree_node and pytree-aware compile wrapper" to "feat: register_pytree_node — allow custom classes in mx.compile" May 11, 2026
Author

st-adam commented May 11, 2026

@zcbenz Thanks for the review. I've moved the implementation into C++ as requested:

  • Registry, register_pytree_node, and PyCompiledFun::call_impl recurse all live in python/src/trees.{cpp,h} and python/src/transforms.cpp.
  • The Python utils.compile wrapper is gone; mx.compile natively handles registered types now.
  • mlx.utils.register_pytree_node is a thin re-export of the C++ binding, so we keep a single source of truth as you suggested.
  • Existing tree_visit / tree_map / tree_visit_update overloads recurse into registered nodes, which is what tree_unflatten relies on in the compile path, so the trace/unflatten round-trip just works.
  • I left the Python tree_* utilities in python/mlx/utils.py untouched for this PR. Migrating them to C++ bindings feels like a separate, larger change; happy to do that as a follow-up if you'd prefer it bundled.

Verified locally on a CPU build: test_compile.py (55), test_tree.py (4), test_autograd.py + test_vmap.py all green, plus a new test_compile_registered_pytree_node covering the registered-pair, aux-data retrace, and malformed-flatten cases.

Comment thread python/src/trees.h Outdated
// Combines id(type) and hash(aux) so that compile retraces if either changes.
uint64_t registered_pytree_fingerprint(nb::handle obj);

void init_trees(nb::module_& m);
Collaborator

This declaration is not needed, already declared in python/src/mlx.cpp.

Comment thread python/src/trees.cpp Outdated
} // namespace

void register_pytree_node(
nb::object cls,
Collaborator

You can use nb::type_object and let nanobind do the check.

Comment thread python/src/trees.cpp Outdated
throw std::invalid_argument(
"[register_pytree_node] cls must be a Python class object.");
}
PyTypeObject* type = reinterpret_cast<PyTypeObject*>(cls.ptr());
Collaborator

We should probably just use PyObject* as the registry key; force-converting to PyTypeObject* adds code complexity without much benefit.

Comment thread python/src/trees.h Outdated
nb::handle obj);

// Calls the registered unflatten_fn for the given type object.
nb::object unflatten_registered(
Collaborator

flatten_registered and unflatten_registered are helpers that do not need to be exposed in the header.

Comment thread python/src/trees.cpp Outdated
nb::object aux = seq[1];
if (!aux.is_none()) {
try {
auto h = aux.attr("__hash__")();
Collaborator

You can use nb::hash

Comment thread python/src/trees.cpp Outdated
nb::object children_obj = seq[0];
nb::object aux = seq[1];

std::vector<nb::object> children;
Collaborator

Can you just do auto children = nb::cast<std::vector<nb::object>>(seq[0]);?

Comment thread python/src/trees.cpp Outdated
"[flatten_registered] type is not registered as a pytree node");
}
nb::object result = it->second.flatten_fn(obj);
if (!nb::isinstance<nb::tuple>(result) && !nb::isinstance<nb::list>(result)) {
Collaborator

nb::cast<nb::sequence> should already do the type check, so I think this check is redundant.

Comment thread python/src/trees.cpp Outdated
auto seq = nb::cast<nb::sequence>(result);
if (nb::len(seq) == 2) {
nb::object aux = seq[1];
if (!aux.is_none()) {
Collaborator

There is no need to be defensive here since this is inside a try/catch; just do the casts and let them throw if something goes wrong.

Comment thread python/src/trees.cpp Outdated
for (auto& c : children) {
new_children.push_back(recurse(c));
}
return unflatten_registered(type_handle, aux, new_children);
Collaborator

You can just pass subtree here?

Adam Staniszewski added 2 commits May 12, 2026 11:21
- Header surface trimmed: only `register_pytree_node`,
  `is_registered_pytree`, `pytree_children`, `registered_pytree_fingerprint`
  are exposed.  `flatten_registered`/`unflatten_registered` are internal
  helpers in trees.cpp and `init_trees` is no longer redeclared (already
  in mlx.cpp).
- `register_pytree_node` now takes `nb::type_object` so nanobind enforces
  the type check; manual `PyType_Check` is gone.
- Registry keyed by `PyObject*` directly — no `PyTypeObject*` reinterpret
  cast at the boundary.
- Internal `flatten_registered` uses `nb::cast<std::vector<nb::object>>`
  for children and lets `nb::cast<nb::sequence>` enforce the
  list/tuple shape.
- Fingerprint uses `nb::hash` and lets nanobind throw on unhashable aux
  (no extra defensive casting).
- `tree_visit_update` / `tree_map` pass the subtree handle directly to
  `unflatten_registered` instead of fabricating a type handle.
Author

st-adam commented May 12, 2026

@zcbenz Thanks for the detailed review — all nine inline comments addressed in 2f7c5e6:

  • trees.h surface is now register_pytree_node, is_registered_pytree, pytree_children, registered_pytree_fingerprint. flatten_registered / unflatten_registered moved into the anonymous namespace in trees.cpp. Stale init_trees declaration removed.
  • register_pytree_node takes nb::type_object cls so nanobind enforces the type check.
  • Registry is now keyed by PyObject* directly — no PyTypeObject* reinterpret cast.
  • flatten_registered uses nb::cast<std::vector<nb::object>>(seq[0]) and lets nb::cast<nb::sequence> enforce the (children, aux) shape.
  • Fingerprint uses nb::hash and lets nanobind throw on unhashable aux.
  • tree_visit_update / tree_map pass the subtree handle directly to unflatten_registered (no fabricated type handle).
  • transforms.cpp now uses the new pytree_children helper.

Lint also fixed (clang-format + black) in the previous commit. Local CPU run: test_compile.py (56), test_tree.py (4), test_autograd.py + test_vmap.py (64) all green. Net diff: -49 lines.

@angeloskath
Member

Sorry for being late to this, but why do we think that register_pytree_node is a good addition? I find that this option in particular is often abused, turning modules and optimizers into typed PyTrees for very little added benefit. Do we have any real-world use case where this simplifies the code? A good example that I would personally avoid is Equinox. Perhaps the problem there is not the proliferation of typed PyTrees but rather their immutability, but I am not sure I see how

@mx.compile
def add_pair(p):
    return p.a + p.b

is better than

@mx.compile
def _add_pair(a, b):
    return a + b

def add_pair(p):
    return _add_pair(p.a, p.b)

A more real-world example of this is simply the training loop. We don't need the model to be a typed PyTree; we can very quickly make the whole call "pure functional" by passing in a dictionary of parameters:

@mx.compile
def step(params, opt_state, batch):
    model.update(params)
    loss_val, grads = nn.value_and_grad(model, loss_fn)(batch)
    optimizer.state = opt_state
    optimizer.update(model, grads)
    return (loss_val, model.parameters(), optimizer.state)

Setting aside the question of whether we should add this at all, I see two issues to be addressed in the code:

  1. The pytree check should come first; otherwise any custom type that subclasses tuple, list, or dict cannot be a registered pytree.
  2. In tree_visit_update we actually edit the passed-in tree in place. Registered types are now immutable like tuples, which may not be the desired behavior. For instance, if I were to use a Pair from your example above in @partial(mx.compile, outputs=pair), it wouldn't quite work.
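
The first issue can be reproduced in a pure-Python sketch. Everything here is hypothetical (REGISTRY, Sample, and the two leaf-collecting functions); it only assumes that the built-in container check accepts subclasses, as an isinstance-style check does.

```python
REGISTRY = {}  # type -> (flatten_fn, unflatten_fn)


class Sample(tuple):
    """A tuple subclass whose second element is static metadata, not a leaf."""

    def __new__(cls, value, label):
        return super().__new__(cls, (value, label))


REGISTRY[Sample] = (
    lambda s: ([s[0]], s[1]),  # only `value` is a child
    lambda label, children: Sample(children[0], label),
)


def leaves_builtin_first(tree):
    # Built-in check first: a tuple subclass never reaches the registry.
    if isinstance(tree, (list, tuple)):
        return [leaf for c in tree for leaf in leaves_builtin_first(c)]
    if type(tree) in REGISTRY:
        children, _aux = REGISTRY[type(tree)][0](tree)
        return [leaf for c in children for leaf in leaves_builtin_first(c)]
    return [tree]


def leaves_registry_first(tree):
    # Registry check first: the custom flatten_fn decides what a leaf is.
    if type(tree) in REGISTRY:
        children, _aux = REGISTRY[type(tree)][0](tree)
        return [leaf for c in children for leaf in leaves_registry_first(c)]
    if isinstance(tree, (list, tuple)):
        return [leaf for c in tree for leaf in leaves_registry_first(c)]
    return [tree]


s = Sample(1.5, "train")
builtin_first = leaves_builtin_first(s)
registry_first = leaves_registry_first(s)
```

With the built-in check first, the label leaks into the leaves and the registered flatten_fn is never called; with the registry check first, only the declared child is treated as a leaf.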

Development

Successfully merging this pull request may close these issues.

mx.compile rejects custom cache objects — no pytree registration mechanism

3 participants