@@ -34,36 +34,32 @@ def forward(self,
3434 block_offsets = meta .block_offsets
3535 cu_q_seqlens = meta .cu_q_seqlens
3636 kv_seqlens = meta .kv_seqlens
37- is_decoding = meta .is_decoding
3837 q_seqlens = meta .q_seqlens
3938 bsz = kv_seqlens .size (0 )
4039 block_size = self ._block_size
4140
42- # quant query
43- # FP8 quantize Indexer Q (replaces fp4_act_quant for better precision)
44- # we might need to do quant fp4 in the future
45- q_2d = query .reshape (- 1 , query .size (- 1 ) * query .size (- 2 ))
41+ # Reshape to fp8_index expected layout upfront.
42+ # query: [bsz, seqlen, n_heads, head_dim] -> [cum_seqlen, n_heads, head_dim]
43+ # weights: [bsz, seqlen, n_heads] -> [cum_seqlen, n_heads]
44+ q_3d = query .flatten (0 , 1 )
45+ weights_2d = weights .flatten (0 , - 2 )
46+
47+ # FP8 quantize Indexer Q directly on 3D (replaces fp4_act_quant for better precision)
48+ q_2d = q_3d .reshape (- 1 , q_3d .size (- 1 ) * q_3d .size (- 2 ))
4649 q_fp8 , q_scale_2d = quant_fp8 (q_2d , group_size = 128 ,
4750 dtype = torch .float8_e4m3fn , scale_fmt = 'ue8m0' )
48- query = q_fp8 .view_as (query )
49- q_scale = q_scale_2d .view (query .shape [:- 1 ])
50-
51- # reshape q and weights
52- q_3d = query .flatten (0 , 1 )
53- q_scale = q_scale .flatten (0 , - 2 )
54- weights = weights .flatten (0 , - 2 )
55- q_scale_weighted = q_scale * weights # [bsz, n_heads]
51+ q_3d = q_fp8 .view_as (q_3d )
52+ q_scale = q_scale_2d .view (q_3d .shape [:- 1 ]) # [cum_seqlen, n_heads]
53+ q_scale_weighted = q_scale * weights_2d
5654
5755 total_lens = kv_seqlens
5856 num_index = torch .div (total_lens , self .compress_ratio , rounding_mode = 'floor' )
5957 max_kv_seqlen = meta .max_kv_seqlen if meta .max_kv_seqlen is not None else block_offsets .size (1 ) * block_size
6058 max_index = max (max_kv_seqlen // self .compress_ratio , 1 )
6159
6260 if max_index == 0 :
63- if is_decoding :
64- empty = query .new_empty ((1 , bsz , 0 ), dtype = torch .long )
65- else :
66- empty = query .new_empty ((bsz , 1 , 0 ), dtype = torch .long )
61+ total_q = q_3d .size (0 )
62+ empty = query .new_empty ((total_q , 0 ), dtype = torch .long )
6763 return V4IndexerOutput (indices_in_kvcache = empty ,
6864 topk_length = num_index .new_zeros ((bsz ,), dtype = torch .int32 ))
6965
@@ -87,11 +83,8 @@ def forward(self,
8783 else :
8884 topk = scores .topk (topk_width , dim = - 1 )[1 ]
8985
90- if is_decoding :
91- topk = topk .unsqueeze (1 ) # [bsz, 1, topk_width]
92- return V4IndexerOutput (indices_in_kvcache = topk , topk_length = topk_length )
93- else :
94- return V4IndexerOutput (indices_in_kvcache = topk .unsqueeze (0 ), topk_length = topk_length )
86+ # Always return [total_q, topk_width] — caller handles decode/prefill dimension adaptation
87+ return V4IndexerOutput (indices_in_kvcache = topk , topk_length = topk_length )
9588
9689
9790class TritonV4IndexerBuilder (BaseV4IndexerBuilder ):
0 commit comments