Commit 587ec72

Author: Open Source Contributor
Fix IndexError in sdpa_mask and flex_attention_mask for 0-D tensors during ONNX export

Fix for Issue #45735. When torch.onnx.export is called with ModernBERT, cache_position can be passed as a 0-dimensional (scalar) tensor, causing an IndexError when accessing cache_position.shape[0] or cache_position[0]. This fix handles the 0-D case by unsqueezing to 1-D before extracting the shape and offset information.
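The failure mode is easy to reproduce in isolation. The sketch below (the tensor value is illustrative, not taken from a real ModernBERT export trace) shows why a 0-D cache_position breaks the indexing and how the unsqueeze fix restores a usable shape and offset:

```python
import torch

# During ONNX export, cache_position can arrive as a 0-D scalar tensor.
cache_position = torch.tensor(7)
assert cache_position.ndim == 0  # shape is torch.Size([]), i.e. empty

# A 0-D tensor has an empty shape, so indexing its shape raises IndexError.
try:
    _ = cache_position.shape[0]
except IndexError:
    pass  # this is the crash the commit fixes

# The fix: promote to 1-D before extracting length and offset.
if cache_position.ndim == 0:
    cache_position = cache_position.unsqueeze(0)

q_length, q_offset = cache_position.shape[0], cache_position[0]
assert q_length == 1 and q_offset.item() == 7
```

For an already 1-D cache_position the `ndim == 0` guard is a no-op, so the patch does not change behavior on the normal eager-mode path.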
1 parent 5c1c72b commit 587ec72

1 file changed: src/transformers/masking_utils.py (4 additions, 0 deletions)
````diff
@@ -477,6 +477,8 @@ def sdpa_mask(
     ```
 
     """
+    if cache_position.ndim == 0:
+        cache_position = cache_position.unsqueeze(0)
     q_length = cache_position.shape[0]
 
     # Potentially pad the 2D mask
@@ -660,6 +662,8 @@ def flex_attention_mask(
         attention_mask (`torch.Tensor`, optional):
             The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
     """
+    if cache_position.ndim == 0:
+        cache_position = cache_position.unsqueeze(0)
     q_length, q_offset = cache_position.shape[0], cache_position[0]
 
     # Potentially add the padding 2D mask
````
