
Commit 2623481

Fix _native_npu_attention: add inversion for 4D attn_mask and expand when dim2/dim3 == 1
1 parent 80451b9 · commit 2623481

1 file changed: 11 additions & 8 deletions

src/diffusers/models/attention_dispatch.py
@@ -1523,15 +1523,18 @@ def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mask
 
     # Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k]
     # https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
-    if (
-        attn_mask is not None
-        and attn_mask.ndim == 2
-        and attn_mask.shape[0] == query.shape[0]
-        and attn_mask.shape[1] == key.shape[1]
-    ):
-        B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+    if attn_mask is not None:
+        if (
+            attn_mask.ndim == 2
+            and attn_mask.shape[0] == query.shape[0]
+            and attn_mask.shape[1] == key.shape[1]
+        ):
+            B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
+            attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
+        elif attn_mask.ndim == 4 and attn_mask.shape[1:3] == (1, 1):
+            attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1)
+
         attn_mask = ~attn_mask.to(torch.bool)
-        attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
 
     return attn_mask
 
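For context, the following is a minimal CPU-only sketch of the mask handling after this commit. The standalone helper name prepare_npu_attn_mask and the [batch, seq, dim] tensor layout are illustrative assumptions; in the repository the logic lives inside _maybe_modify_attn_mask_npu in attention_dispatch.py and feeds torch_npu.npu_fusion_attention, whose mask argument (judging from the inversion in this code) treats True as "do not attend", so the incoming keep-mask must be flipped.

import torch


def prepare_npu_attn_mask(query, key, attn_mask):
    # Hypothetical standalone version of the logic in _maybe_modify_attn_mask_npu.
    # Assumes query/key are [B, S, ...] so that shape[1] is the sequence length,
    # and that attn_mask arrives as a keep-mask (1/True = attend).
    if attn_mask is not None:
        if (
            attn_mask.ndim == 2
            and attn_mask.shape[0] == query.shape[0]
            and attn_mask.shape[1] == key.shape[1]
        ):
            # [B, Skv] -> [B, 1, Sq, Skv]: broadcast the key mask over every query.
            B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
            attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
        elif attn_mask.ndim == 4 and attn_mask.shape[1:3] == (1, 1):
            # [B, 1, 1, Skv] -> [B, 1, Sq, Skv]: materialize the query dimension.
            attn_mask = attn_mask.expand(-1, -1, query.shape[1], -1)

        # Inversion now runs for every non-None mask, including the 4D case
        # that previously bypassed it entirely.
        attn_mask = ~attn_mask.to(torch.bool)

    return attn_mask


q = torch.randn(2, 4, 8)  # [B, Sq, D]
k = torch.randn(2, 6, 8)  # [B, Skv, D]
keep_2d = torch.tensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 0]])
keep_4d = keep_2d[:, None, None, :]  # [B, 1, 1, Skv]

m2 = prepare_npu_attn_mask(q, k, keep_2d)
m4 = prepare_npu_attn_mask(q, k, keep_4d)
print(m2.shape, m4.shape)   # torch.Size([2, 1, 4, 6]) torch.Size([2, 1, 4, 6])
print(torch.equal(m2, m4))  # True: both layouts normalize to the same mask

Both entry layouts now converge on the same inverted [B, 1, Sq, Skv] boolean mask, which the pre-fix code only guaranteed for the 2D path.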
