Bug description
Context
RL code on 04/23 started failing with errors like:
File "/home/lucaskabela/torchtitan/torchtitan/experiments/rl/models/vllm_wrapper.py", line 270, in forward
def forward(
File "/home/lucaskabela/pytorch/torch/_dynamo/eval_frame.py", line 1297, in _fn
return fn(*args, **kwargs)
File "/home/lucaskabela/vllm/vllm/compilation/caching.py", line 215, in __call__
return self.optimized_call(*args, **kwargs)
File "<string>", line 316, in execution_fn
NameError: name 'Shard' is not defined
This is because `vllm/compilation/codegen.py::_node_ref` falls back to `repr(arg)` when stringifying non-primitive graph arguments (line 155). The generated source is `exec`'d in a namespace that only has `import torch` (line 101, line 130). For any arg whose `repr()` is not eval-able in that torch-only namespace, the generated `execution_fn` raises `NameError` on first call.
`torch.device` is the most common example: `repr(torch.device('cuda:0'))` returns
the bare string `"device(type='cuda', index=0)"`, which cannot be evaluated with only
`torch` in scope. Any model whose forward lifts a device into a graph arg (e.g. via
`.to(device)`, `torch.zeros(..., device=d)`, or tensor factories inside a traced op)
will hit this. DTensor placement types (`Shard`, `Replicate`, `Partial`) have the same
shape: `repr(Shard(dim=2))` is `'Shard(dim=2)'`.
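A minimal sketch of the failure mode, standalone so it needs neither torch nor vLLM. The `Shard` class here is a stand-in for DTensor's placement type, and `make_execution_fn` is a hypothetical helper that mirrors the codegen pattern (repr-stringify an arg, then `exec` the generated source in a namespace lacking the arg's defining module):

```python
class Shard:
    """Stand-in for DTensor's Shard placement type (hypothetical)."""
    def __init__(self, dim):
        self.dim = dim
    def __repr__(self):
        return f"Shard(dim={self.dim})"

def make_execution_fn(arg):
    # Mirrors the codegen pattern: embed repr(arg) in generated source,
    # then exec it in a namespace that does not import the arg's module
    # (the real code's namespace contains only `import torch`).
    src = f"def execution_fn():\n    return {arg!r}\n"
    namespace = {}
    exec(src, namespace)  # compiles fine; repr is only evaluated on call
    return namespace["execution_fn"]

fn = make_execution_fn(Shard(dim=2))
try:
    fn()  # first call evaluates "Shard(dim=2)" -> NameError
except NameError as e:
    print(f"NameError: {e}")
```

Running this prints `NameError: name 'Shard' is not defined`, matching the traceback above: the generated source is syntactically valid, so the error only surfaces on the first call of `execution_fn`.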
We have worked around this with a monkeypatch, but we need a proper upstream fix, and the monkeypatch should be removed once that fix lands.
Versions
vllm at commit 595562651 (main).