TL;DR

#2963 converts all tensors (states, activations, loss) to DTensor (on the TP dimension only). Graph Trainer CI fails on #2963. The core trainer with --compile.enable does not fail (with either the Inductor or AOTEager backend). The suspected root cause below was generated by Claude, and I don't have enough knowledge to verify whether it is correct.

cc. @yiming0416 @SherlockNoMad @xmfan

Summary

After #2963 (config-based tensor parallel that returns DTensors from the module's forward rather than unwrapping to local), the graph_trainer experiment's JIT compile path fails inside the AOTAutograd partitioner with AssertionError: Node tangents_N was invalid, but is output. The failure reproduces under both the aot_eager and inductor backends; it does not reproduce on the regular torchtitan trainer with either backend. This blocks graph_trainer CI on the PR stack.
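For context on what the change means for compile, below is a rough, purely illustrative sketch (assuming the public torch.distributed.tensor API; the helper names are mine, not torchtitan's) of unwrapping a TP output to its local tensor versus returning the DTensor itself, which is the behavior #2963 introduces:

import torch
from torch.distributed.tensor import DTensor


def forward_unwrapped(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Pre-#2963 style (conceptually): the TP'd module produces a DTensor, but only
    # its local shard leaves the forward, so the compiled graph outputs a plain
    # torch.Tensor.
    out = module(x)
    return out.to_local() if isinstance(out, DTensor) else out


def forward_dtensor(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Post-#2963 style (conceptually): the DTensor itself is returned, so the
    # compiled graph's output is a tensor subclass that carries non-tensor
    # metadata (device_mesh, placements) alongside its local tensor.
    return module(x)

With the second form, the graph output seen by torch.compile is a tensor subclass carrying non-tensor metadata, which is where the partitioner issue described below appears to originate.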
Reproducer

PP is not required. The same failure occurs with --compile.backend=inductor. The regular trainer (core llama3) with --compile.enable passes on every FSDP/TP/PP/SP/backend combination I tested.

Stack trace
File "torchtitan/train.py", line 55, in main
trainer.train()
File "torchtitan/trainer.py", line 852, in train
self.train_step(data_iterator)
File "torchtitan/trainer.py", line 762, in train_step
loss = self.forward_backward_step(...)
File "torchtitan/experiments/graph_trainer/trainer.py", line 78, in forward_backward_step
return super().forward_backward_step(...)
File "torchtitan/trainer.py", line 679, in forward_backward_step
self.pp_schedule.step(...)
File "torch/distributed/pipelining/schedules.py", line 1768, in step
self._step_microbatches(...)
File "torch/distributed/pipelining/schedules.py", line 2181, in _step_microbatches
self._initialize_stages(...)
File "torch/distributed/pipelining/schedules.py", line 1672, in _initialize_stages
) = self._initialize_pp_stages(...)
File "torch/distributed/pipelining/schedules.py", line 431, in _initialize_pp_stages
next_stage_args = stage._prepare_forward_infra(...)
File "torch/distributed/pipelining/stage.py", line 2156, in _prepare_forward_infra
fwd_meta_output = self._forward_metadata_inference(...)
File "torch/distributed/pipelining/stage.py", line 1886, in _forward_metadata_inference
outputs = self._compute_outputs(...)
File "torch/distributed/pipelining/stage.py", line 1771, in _compute_outputs
return module(*args, **kwargs)
File "torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
return self._compiled_call_impl(*args, **kwargs)
File "torch/_dynamo/eval_frame.py", line 1062, in compile_wrapper
raise e.remove_dynamo_frames() from None
File "torch/_dynamo/output_graph.py", line 2975, in _call_user_compiler
raise BackendCompilerFailed(...)
File "torch/_dynamo/output_graph.py", line 2950, in _call_user_compiler
compiled_fn = compiler_fn(gm, example_inputs)
File "torchtitan/experiments/graph_trainer/jit_backend.py", line 158, in graph_trainer_custom_pass
return backend(*args, **kwargs)
File "torch/_dynamo/backends/debugging.py", line 414, in aot_eager
return aot_autograd(...)
File "torch/_dynamo/backends/common.py", line 123, in __call__
cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
File "torch/_functorch/aot_autograd.py", line 1234, in aot_module_simplified
compiled_fn, _ = aot_stage2_compile(...)
File "torch/_functorch/_aot_autograd/graph_compile.py", line 378, in aot_stage2_compile
return aot_stage2_autograd(aot_state, aot_graph_capture)
File "torch/_functorch/_aot_autograd/graph_compile.py", line 2274, in aot_stage2_autograd
) = _aot_stage2a_partition(...)
File "torch/_functorch/_aot_autograd/graph_compile.py", line 2012, in _aot_stage2a_partition
_partition_joint_graph_into_fw_bw(...)
File "torch/_functorch/_aot_autograd/graph_compile.py", line 1762, in _partition_joint_graph_into_fw_bw
fw_module, bw_module = aot_config.partition_fn(...)
File "torch/_functorch/partitioners.py", line 3607, in min_cut_rematerialization_partition
node_info = classify_nodes(...)
File "torch/_functorch/partitioners.py", line 3517, in classify_nodes
forward_only_graph = _extract_graph_with_inputs_outputs(
File "torch/_functorch/partitioners.py", line 380, in _extract_graph_with_inputs_outputs
raise AssertionError(f"Node {x} was invalid, but is output")
torch._dynamo.exc.BackendCompilerFailed: backend='graph_trainer_custom_pass' raised:
AssertionError: Node tangents_2 was invalid, but is output
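To make the final assertion concrete, here is a tiny, purely hypothetical sketch (plain Python, not PyTorch internals; the node names and counts are made up) of the failure shape that the suspected root cause below describes: an over-counted forward-output number pulls the backward tangents_* placeholders into the slice of nodes treated as forward outputs.

# Hypothetical illustration only. In the real partitioner the slice comes from
# the joint graph's outputs; here a plain list stands in for those nodes.
joint_graph_outputs = [
    "mm_7",         # the real forward tensor output (the DTensor's local tensor)
    "tangents_1",   # backward-only placeholders follow the forward outputs
    "tangents_2",
    "grad_weight",  # backward gradient outputs
    "grad_input",
]

# Suppose the forward really has 1 tensor output, but the DTensor's non-tensor
# meta (device_mesh, placements) is also counted, inflating the number to 3.
num_inner_fwd_outputs = 1 + 2

forward_slice = joint_graph_outputs[:num_inner_fwd_outputs]
print(forward_slice)
# ['mm_7', 'tangents_1', 'tangents_2'] -> tangents_2 is marked as a forward
# output even though nothing in the forward graph produces it, matching
# "Node tangents_2 was invalid, but is output".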
Suspected (not yet confirmed) root cause

AOTAutograd's subclass-output desugaring for a DTensor return appears to materialize the DTensor's non-tensor meta (device_mesh, placements) as extra graph output positions, which inflates num_inner_fwd_outputs and lets backward tangents_N placeholders slide into the forward-output slice. Inductor's subclass runtime wrapper likely hides this on the core-trainer path; graph_trainer's JIT pipeline (whole-model model.compile(fullgraph=True) + simple_fsdp parametrization + functorch_config.joint_custom_pass) appears to exercise a path where the non-tensor positions leak. Full joint-graph dump available on request.

Environment