@@ -1523,15 +1523,18 @@ def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mas
15231523
15241524 # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
15251525 # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
1526- if (
1527- attn_mask is not None
1528- and attn_mask .ndim == 2
1529- and attn_mask .shape [0 ] == query .shape [0 ]
1530- and attn_mask .shape [1 ] == key .shape [1 ]
1531- ):
1532- B , Sq , Skv = attn_mask .shape [0 ], query .shape [1 ], key .shape [1 ]
1526+ if attn_mask is not None :
1527+ if (
1528+ attn_mask .ndim == 2
1529+ and attn_mask .shape [0 ] == query .shape [0 ]
1530+ and attn_mask .shape [1 ] == key .shape [1 ]
1531+ ):
1532+ B , Sq , Skv = attn_mask .shape [0 ], query .shape [1 ], key .shape [1 ]
1533+ attn_mask = attn_mask .unsqueeze (1 ).expand (B , Sq , Skv ).unsqueeze (1 ).contiguous ()
1534+ elif attn_mask .ndim == 4 and attn_mask .shape [1 :3 ] == (1 , 1 ):
1535+ attn_mask = attn_mask .expand (- 1 , - 1 , query .shape [1 ], - 1 )
1536+
15331537 attn_mask = ~ attn_mask .to (torch .bool )
1534- attn_mask = attn_mask .unsqueeze (1 ).expand (B , Sq , Skv ).unsqueeze (1 ).contiguous ()
15351538
15361539 return attn_mask
15371540
0 commit comments