
Commit a0ec686

opt indexer
1 parent 3cf0b07

4 files changed: 21 additions & 27 deletions

lmdeploy/pytorch/backends/cuda/attention/v4.py (2 additions & 2 deletions)
```diff
@@ -363,7 +363,7 @@ def _forward_decoding(self, query, kv, attn_sink, attn_metadata: CudaV4Attention
         if self.compress_ratio:
             compressed_cache_fp8 = caches['compressed_kv_fp8']
             if index_out is not None:
-                indices_in_kvcache = index_out.indices_in_kvcache
+                indices_in_kvcache = index_out.indices_in_kvcache.unsqueeze(1)  # [bsz, 1, topk_width]
                 topk_length = index_out.topk_length
             elif self.compress_ratio == 4:
                 indices_in_kvcache = attn_metadata.compress_fallback_indices_r4
@@ -440,7 +440,7 @@ def _select_compress_topk(self, index_out, attn_metadata: CudaV4AttentionMetadat
             return None, None
 
         if index_out is not None:
-            compress_topk = index_out.indices_in_kvcache.squeeze(0)
+            compress_topk = index_out.indices_in_kvcache
             # Offset indexer's logical indices into flat_kv positions
             uncompressed_kv_lens = attn_metadata.prefill_uncompressed_kv_lens
             cu_q_seqlens = attn_metadata.cu_q_seqlens
```
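
These two hunks are the consumer side of a shape-contract change: the indexer now always returns `indices_in_kvcache` as a flat `[total_q, topk_width]` tensor, so the decode path restores its expected `[bsz, 1, topk_width]` layout with a single `unsqueeze(1)`, while the prefill path drops its old `squeeze(0)`. A minimal sketch with made-up shapes (the sizes and random tensors below are illustrative, not from the repo):

```python
import torch

# Hypothetical shapes for illustration: 2 decode requests, top-8 indices each.
bsz, topk_width = 2, 8
indices_in_kvcache = torch.randint(0, 1024, (bsz, topk_width))  # [total_q, topk_width]

# Decode path (_forward_decoding): each request contributes one query token,
# so total_q == bsz and one unsqueeze restores the [bsz, 1, topk_width]
# layout the decoding kernel expects.
decode_indices = indices_in_kvcache.unsqueeze(1)
assert decode_indices.shape == (bsz, 1, topk_width)

# Prefill path (_select_compress_topk): the flat [total_q, topk_width]
# tensor is consumed directly; the old squeeze(0) of a dummy leading dim is gone.
compress_topk = indices_in_kvcache
```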

lmdeploy/pytorch/backends/cuda/v4_indexer.py (15 additions & 22 deletions)
```diff
@@ -34,36 +34,32 @@ def forward(self,
         block_offsets = meta.block_offsets
         cu_q_seqlens = meta.cu_q_seqlens
         kv_seqlens = meta.kv_seqlens
-        is_decoding = meta.is_decoding
         q_seqlens = meta.q_seqlens
         bsz = kv_seqlens.size(0)
         block_size = self._block_size
 
-        # quant query
-        # FP8 quantize Indexer Q (replaces fp4_act_quant for better precision)
-        # we might need to do quant fp4 in the future
-        q_2d = query.reshape(-1, query.size(-1) * query.size(-2))
+        # Reshape to fp8_index expected layout upfront.
+        # query: [bsz, seqlen, n_heads, head_dim] -> [cum_seqlen, n_heads, head_dim]
+        # weights: [bsz, seqlen, n_heads] -> [cum_seqlen, n_heads]
+        q_3d = query.flatten(0, 1)
+        weights_2d = weights.flatten(0, -2)
+
+        # FP8 quantize Indexer Q directly on 3D (replaces fp4_act_quant for better precision)
+        q_2d = q_3d.reshape(-1, q_3d.size(-1) * q_3d.size(-2))
         q_fp8, q_scale_2d = quant_fp8(q_2d, group_size=128,
                                       dtype=torch.float8_e4m3fn, scale_fmt='ue8m0')
-        query = q_fp8.view_as(query)
-        q_scale = q_scale_2d.view(query.shape[:-1])
-
-        # reshape q and weights
-        q_3d = query.flatten(0, 1)
-        q_scale = q_scale.flatten(0, -2)
-        weights = weights.flatten(0, -2)
-        q_scale_weighted = q_scale * weights  # [bsz, n_heads]
+        q_3d = q_fp8.view_as(q_3d)
+        q_scale = q_scale_2d.view(q_3d.shape[:-1])  # [cum_seqlen, n_heads]
+        q_scale_weighted = q_scale * weights_2d
 
         total_lens = kv_seqlens
         num_index = torch.div(total_lens, self.compress_ratio, rounding_mode='floor')
         max_kv_seqlen = meta.max_kv_seqlen if meta.max_kv_seqlen is not None else block_offsets.size(1) * block_size
         max_index = max(max_kv_seqlen // self.compress_ratio, 1)
 
         if max_index == 0:
-            if is_decoding:
-                empty = query.new_empty((1, bsz, 0), dtype=torch.long)
-            else:
-                empty = query.new_empty((bsz, 1, 0), dtype=torch.long)
+            total_q = q_3d.size(0)
+            empty = query.new_empty((total_q, 0), dtype=torch.long)
             return V4IndexerOutput(indices_in_kvcache=empty,
                                    topk_length=num_index.new_zeros((bsz,), dtype=torch.int32))
 
@@ -87,11 +83,8 @@ def forward(self,
         else:
             topk = scores.topk(topk_width, dim=-1)[1]
 
-        if is_decoding:
-            topk = topk.unsqueeze(1)  # [bsz, 1, topk_width]
-            return V4IndexerOutput(indices_in_kvcache=topk, topk_length=topk_length)
-        else:
-            return V4IndexerOutput(indices_in_kvcache=topk.unsqueeze(0), topk_length=topk_length)
+        # Always return [total_q, topk_width] — caller handles decode/prefill dimension adaptation
+        return V4IndexerOutput(indices_in_kvcache=topk, topk_length=topk_length)
 
 
 class TritonV4IndexerBuilder(BaseV4IndexerBuilder):
```
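
The indexer rewrite moves all layout work ahead of quantization: `query` and `weights` are flattened into the `[cum_seqlen, ...]` layout that `fp8_index` expects once, and the fp8 scales are viewed per head directly, instead of quantizing first and then reshaping both the quantized tensor and its scales. The shape flow can be traced with a stand-in quantizer; the absmax scaling below is only a shape-compatible placeholder for the real `quant_fp8` kernel, and `head_dim == 128` is assumed so that each 128-wide scale group maps to exactly one head:

```python
import torch

# Illustrative sizes; head_dim == 128 so one group_size=128 scale per head.
bsz, seqlen, n_heads, head_dim = 2, 4, 8, 128
query = torch.randn(bsz, seqlen, n_heads, head_dim)
weights = torch.randn(bsz, seqlen, n_heads)

# Flatten batch/sequence dims once, up front (the new layout).
q_3d = query.flatten(0, 1)           # [cum_seqlen, n_heads, head_dim]
weights_2d = weights.flatten(0, -2)  # [cum_seqlen, n_heads]

# Stand-in for quant_fp8(..., group_size=128): per-group absmax scales
# (448 is the e4m3 max). Only the shape contract matters here.
q_2d = q_3d.reshape(-1, q_3d.size(-1) * q_3d.size(-2))  # [cum_seqlen, n_heads*head_dim]
q_scale_2d = q_2d.view(q_2d.size(0), -1, 128).abs().amax(-1) / 448.0

q_scale = q_scale_2d.view(q_3d.shape[:-1])  # [cum_seqlen, n_heads]
q_scale_weighted = q_scale * weights_2d     # one weighted scale per (token, head)
```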

lmdeploy/pytorch/models/deepseek_v4.py (0 additions & 1 deletion)
```diff
@@ -237,7 +237,6 @@ def forward(self,
             kv_rope = compressed_kv[..., -rd:].unsqueeze(1)  # [total_flat, 1, rd]
             cos_c, sin_c = compress_pos_emb
             self.apply_rotary.forward_single(kv_rope, cos_c, sin_c, inplace=True, complex_mode=True)
-            compressed_kv[..., -rd:] = kv_rope.squeeze(1)
             if self.rotate:
                 compressed_kv = self.compressor_impl.rotate_activation(compressed_kv)
         else:
```
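
The deleted write-back is redundant rather than wrong: `compressed_kv[..., -rd:]` is a strided view and `unsqueeze(1)` keeps it a view, so applying the rotary embedding in place through `kv_rope` already mutates `compressed_kv`. A tiny aliasing demo, independent of the actual rotary kernel:

```python
import torch

rd = 4
compressed_kv = torch.zeros(3, 10)

# Slicing the last dim and unsqueeze both return views over the same storage.
kv_rope = compressed_kv[..., -rd:].unsqueeze(1)  # [3, 1, rd], aliases compressed_kv

# An in-place update through the view (what forward_single does with inplace=True)...
kv_rope.copy_(torch.ones_like(kv_rope))

# ...is already visible in the parent tensor, so the explicit write-back
# `compressed_kv[..., -rd:] = kv_rope.squeeze(1)` was a no-op copy.
assert compressed_kv[..., -rd:].eq(1).all()
```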

lmdeploy/pytorch/nn/rotary_embedding.py (4 additions & 2 deletions)
```diff
@@ -270,8 +270,10 @@ def forward_single(self, x: Tensor, cos: Tensor, sin: Tensor, inplace: bool = Tr
         dummy_k = x_3d.new_empty(x_3d.size(0), 0, dummy_dim)
         x_3d, _ = self.forward(x_3d, dummy_k, cos, sin, inplace=False,
                                complex_mode=complex_mode)
-        x.copy_(x_3d.reshape(orig_shape))
-        return x
+        if inplace:
+            x.copy_(x_3d.reshape(orig_shape))
+            return x
+        return x_3d.reshape(orig_shape)
 
 
 class FopeRotaryEmbedding(nn.Module):
```
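
Previously `forward_single` always wrote the result back into `x` and returned it, clobbering the caller's tensor even when `inplace=False` was requested; now the flag is honored and the non-inplace path returns a fresh tensor. A stub of the fixed control flow, with the rotary math replaced by a placeholder so only the `inplace` branching mirrors the diff:

```python
import torch
from torch import Tensor

def forward_single(x: Tensor, cos: Tensor, sin: Tensor, inplace: bool = True) -> Tensor:
    """Stub of the fixed pattern; the real rotary kernel is elided."""
    orig_shape = x.shape
    x_3d = x.reshape(x.size(0), -1, x.size(-1))
    x_3d = x_3d * cos + x_3d * sin  # placeholder for the actual rotary math

    if inplace:
        x.copy_(x_3d.reshape(orig_shape))  # mutates x (and any parent it views)
        return x
    return x_3d.reshape(orig_shape)        # leaves the input untouched

x = torch.randn(2, 1, 8)
cos, sin = torch.ones(8), torch.zeros(8)
out = forward_single(x, cos, sin, inplace=False)
assert out.data_ptr() != x.data_ptr()  # non-inplace result no longer aliases x
```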
