
graph_trainer JIT fails with 'Node tangents_2 was invalid, but is output' when module forward returns a DTensor #3013

Description

@fegin

TL;DR

PR #2963 converts all tensors (states, activations, loss) to DTensor (sharded on the TP dimension only). Graph Trainer CI fails on #2963, while the core trainer with --compile.enable passes with both the Inductor and AOTEager backends.

The suspected root cause below was generated by Claude; I don't have enough knowledge to verify whether it is correct.

cc @yiming0416 @SherlockNoMad @xmfan

Summary

After #2963 (config-based tensor parallel that returns DTensors from the module's forward rather than unwrapping to local), the graph_trainer experiment's JIT compile path fails inside the AOTAutograd partitioner with AssertionError: Node tangents_N was invalid, but is output. Reproduces under both the aot_eager and inductor backends; does not reproduce on the regular torchtitan trainer with either backend. Blocks graph_trainer CI on the PR1 stack.
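
For context, the behavior change from #2963 is roughly the following (a minimal sketch with illustrative names, not the actual torchtitan module): before the PR, the TP-parallelized module unwrapped its output to a plain local tensor; after it, the DTensor itself crosses the compile boundary, so AOTAutograd has to desugar a tensor-subclass output.

```python
import torch
import torch.nn as nn
from torch.distributed.tensor import DTensor


class TPBlock(nn.Module):
    """Illustrative module; `unwrap_output=True` mimics the pre-#2963 behavior."""

    def __init__(self, dim: int, unwrap_output: bool):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.unwrap_output = unwrap_output

    def forward(self, x):
        out = self.proj(x)  # with TP-sharded weights, `out` is a DTensor
        if self.unwrap_output and isinstance(out, DTensor):
            return out.to_local()  # pre-#2963: hand the caller a plain local tensor
        return out                 # post-#2963: the DTensor is returned as-is
```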

Reproducer

cd torchtitan  # at pytorch/torchtitan#2963 (or any PR in the stack on top of it)
NGPU=8 MODULE=graph_trainer.llama3 CONFIG=graph_trainer_llama3_debugmodel \
  ./run_train.sh \
    --compile.mode jit \
    --parallelism.data_parallel_shard_degree=4 \
    --parallelism.tensor_parallel_degree=2

PP is not required to reproduce. The same failure occurs with --compile.backend=inductor. The regular trainer (core llama3) with --compile.enable passes on every FSDP/TP/PP/SP/backend combination I tested.
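
As a rough sketch of the backend indirection visible in the stack trace below (an assumption about its shape, not the actual jit_backend.py code), the graph_trainer backend is a thin Dynamo backend that eventually delegates to aot_eager, which is where the AOTAutograd partitioner is reached:

```python
import torch
from torch._dynamo.backends.debugging import aot_eager


def graph_trainer_custom_pass(gm: torch.fx.GraphModule, example_inputs):
    # A real implementation would run custom graph passes / joint_custom_pass
    # hooks here before handing off; this sketch just delegates to aot_eager,
    # which then runs AOTAutograd and its min-cut partitioner.
    return aot_eager(gm, example_inputs)


model = torch.nn.Linear(8, 8)
compiled = torch.compile(model, backend=graph_trainer_custom_pass, fullgraph=True)
out = compiled(torch.randn(2, 8))
```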

Stack trace

File "torchtitan/train.py", line 55, in main
  trainer.train()
File "torchtitan/trainer.py", line 852, in train
  self.train_step(data_iterator)
File "torchtitan/trainer.py", line 762, in train_step
  loss = self.forward_backward_step(...)
File "torchtitan/experiments/graph_trainer/trainer.py", line 78, in forward_backward_step
  return super().forward_backward_step(...)
File "torchtitan/trainer.py", line 679, in forward_backward_step
  self.pp_schedule.step(...)
File "torch/distributed/pipelining/schedules.py", line 1768, in step
  self._step_microbatches(...)
File "torch/distributed/pipelining/schedules.py", line 2181, in _step_microbatches
  self._initialize_stages(...)
File "torch/distributed/pipelining/schedules.py", line 1672, in _initialize_stages
  ) = self._initialize_pp_stages(...)
File "torch/distributed/pipelining/schedules.py", line 431, in _initialize_pp_stages
  next_stage_args = stage._prepare_forward_infra(...)
File "torch/distributed/pipelining/stage.py", line 2156, in _prepare_forward_infra
  fwd_meta_output = self._forward_metadata_inference(...)
File "torch/distributed/pipelining/stage.py", line 1886, in _forward_metadata_inference
  outputs = self._compute_outputs(...)
File "torch/distributed/pipelining/stage.py", line 1771, in _compute_outputs
  return module(*args, **kwargs)
File "torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
  return self._compiled_call_impl(*args, **kwargs)
File "torch/_dynamo/eval_frame.py", line 1062, in compile_wrapper
  raise e.remove_dynamo_frames() from None
File "torch/_dynamo/output_graph.py", line 2975, in _call_user_compiler
  raise BackendCompilerFailed(...)
File "torch/_dynamo/output_graph.py", line 2950, in _call_user_compiler
  compiled_fn = compiler_fn(gm, example_inputs)
File "torchtitan/experiments/graph_trainer/jit_backend.py", line 158, in graph_trainer_custom_pass
  return backend(*args, **kwargs)
File "torch/_dynamo/backends/debugging.py", line 414, in aot_eager
  return aot_autograd(...)
File "torch/_dynamo/backends/common.py", line 123, in __call__
  cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "torch/_functorch/aot_autograd.py", line 1234, in aot_module_simplified
  compiled_fn, _ = aot_stage2_compile(...)
File "torch/_functorch/_aot_autograd/graph_compile.py", line 378, in aot_stage2_compile
  return aot_stage2_autograd(aot_state, aot_graph_capture)
File "torch/_functorch/_aot_autograd/graph_compile.py", line 2274, in aot_stage2_autograd
  ) = _aot_stage2a_partition(...)
File "torch/_functorch/_aot_autograd/graph_compile.py", line 2012, in _aot_stage2a_partition
  _partition_joint_graph_into_fw_bw(...)
File "torch/_functorch/_aot_autograd/graph_compile.py", line 1762, in _partition_joint_graph_into_fw_bw
  fw_module, bw_module = aot_config.partition_fn(...)
File "torch/_functorch/partitioners.py", line 3607, in min_cut_rematerialization_partition
  node_info = classify_nodes(...)
File "torch/_functorch/partitioners.py", line 3517, in classify_nodes
  forward_only_graph = _extract_graph_with_inputs_outputs(
File "torch/_functorch/partitioners.py", line 380, in _extract_graph_with_inputs_outputs
  raise AssertionError(f"Node {x} was invalid, but is output")
torch._dynamo.exc.BackendCompilerFailed: backend='graph_trainer_custom_pass' raised:
AssertionError: Node tangents_2 was invalid, but is output
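
For readers unfamiliar with the partitioner, the failing check is, in simplified form, the one below (a plain-Python reconstruction for illustration, not the actual PyTorch source): when splitting the joint graph, any node that cannot be computed from the declared forward inputs is marked invalid, and an invalid node appearing in the forward output list is a hard error. tangents_N nodes are backward-only placeholders, so they should never land in that list.

```python
class Node:
    def __init__(self, name, deps=()):
        self.name = name
        self.deps = list(deps)  # nodes this node is computed from


def extract_forward_subgraph(all_nodes, inputs, outputs):
    """Keep only nodes computable from `inputs`; every requested output must survive."""
    valid = set(inputs)
    for node in all_nodes:
        if node in valid or (node.deps and all(d in valid for d in node.deps)):
            valid.add(node)
    for out in outputs:
        if out not in valid:
            # The assertion seen in the stack trace: a node unreachable from the
            # forward inputs is nevertheless listed among the forward outputs.
            raise AssertionError(f"Node {out.name} was invalid, but is output")
    return [n for n in all_nodes if n in valid]


primals_1 = Node("primals_1")
mul = Node("mul", deps=[primals_1])
tangents_2 = Node("tangents_2")  # backward placeholder: no forward producer

extract_forward_subgraph(
    [primals_1, mul, tangents_2],
    inputs=[primals_1],
    outputs=[mul, tangents_2],  # raises: "Node tangents_2 was invalid, but is output"
)
```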

Suspected (not yet confirmed) root cause

AOTAutograd's subclass-output desugaring for a DTensor return appears to materialize the DTensor's non-tensor metadata (device_mesh, placements) as extra graph output positions, which inflates num_inner_fwd_outputs and lets backward tangents_N placeholders slide into the forward-output slice. Inductor's subclass runtime wrapper likely hides this on the core-trainer path; graph_trainer's JIT pipeline (whole-model model.compile(fullgraph=True) + simple_fsdp parametrization + functorch_config.joint_custom_pass) appears to exercise a path where the non-tensor positions leak. Full joint-graph dump available on request.
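
If the suspicion above is right, the failure mode is a simple off-by-N in the forward/backward split. A toy illustration of the arithmetic (an assumption about the mechanism, not verified against AOTAutograd internals):

```python
# The joint graph's output list is ordered [forward outputs ..., gradient outputs ...],
# and the partitioner treats the first num_inner_fwd_outputs entries as forward outputs.
joint_output_names = ["local_out", "grad_primals_1", "tangents_2"]
# "tangents_2" stands in for a gradient that is a pass-through of a backward
# tangent placeholder, so the output node is literally the tangents_2 node.

correct_fwd_count = 1  # a DTensor output desugars to a single inner (local) tensor
forward_slice = joint_output_names[:correct_fwd_count]        # ['local_out']

inflated_fwd_count = 3  # inner tensor + device_mesh + placements counted as outputs
forward_slice_bad = joint_output_names[:inflated_fwd_count]   # swallows tangents_2
assert "tangents_2" in forward_slice_bad
# tangents_2 is now classified as a forward output, but it cannot be computed from
# the forward inputs, so the partitioner raises
# "Node tangents_2 was invalid, but is output".
```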

Environment
