Skip to content

Commit e7c5d56

Browse files
davidmcc73 and claude
committed
Add LpB kernel patches for Qwen3.5 dense models (27B, 9B)
Loop-over-B custom GEMV kernels for expanding projections (N > K): gate_proj, up_proj, down_proj, in_proj_qkv, in_proj_z, out_proj, q_proj. These reduce S>1 verification cost from ~7ms/token to ~3ms/token, critical for speculative decoding speedup. Auto-detected for model_type=qwen3_5 (dense models like 27B, 9B). MoE models (qwen3_5_moe) use the existing batched fused patches instead. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent dd71182 commit e7c5d56

4 files changed

Lines changed: 249 additions & 0 deletions

File tree

src/exo/worker/engines/mlx/patches/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,9 @@ def maybe_apply_patches(model: nn.Module, model_path: Path) -> None:
3737

3838
logger.info("Detected Qwen3.5 MoE model, applying batched fused kernel patches")
3939
apply_qwen35_batched_fused_patches(model)
40+
41+
elif model_type == "qwen3_5":
42+
from .qwen3_5.lpb_patch import apply_lpb_patches
43+
44+
logger.info("Detected Qwen3.5 dense model, applying LpB kernel patches")
45+
apply_lpb_patches(model, batch_size=4)

src/exo/worker/engines/mlx/patches/qwen3_5/__init__.py

Whitespace-only changes.
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
#!/usr/bin/env python3
2+
"""Isolated loop-over-B GEMV kernel for quantized matmul.
3+
4+
Extracts the loop-over-B pattern from batched_fused_gdn_projections_8bit
5+
but without any epilogues — pure Y = X @ dequant(W)^T output.
6+
7+
For comparing our GEMV approach against MLX's affine_qmv_fast on
8+
an isolated QuantizedLinear operation (e.g., in_proj_qkv: N=8192, K=2048).
9+
10+
TG: (32, 2, 1) = 64 threads = 2 SGs.
11+
Each SG: 4 output rows.
12+
B loop inside row loop for low register pressure (R = 4B + 5).
13+
14+
Usage:
15+
from custom_qmv_loop_over_b import custom_qmv_loop_over_b
16+
y = custom_qmv_loop_over_b(x, w, scales, biases, M=8, N=8192, K=2048)
17+
"""
18+
19+
import mlx.core as mx
20+
21+
22+
def ceil_div(a, b):
    """Return ceil(a / b) using integer arithmetic.

    The classic ``(a + b - 1) // b`` trick is only correct for positive
    divisors; negating twice around floor division gives the true
    ceiling for any nonzero ``b`` (identical results for the b > 0
    calls made in this file).
    """
    return -(-a // b)
25+
26+
def _gen_custom_qmv_source(M_val, N_val, K_val, group_size=64):
27+
gs = group_size
28+
sc_stride = 256 // gs
29+
slid_div = gs // 8
30+
K_groups = K_val // gs
31+
B = M_val # batch size = M
32+
33+
return f"""
34+
const int RESULTS_PER_SG = 4;
35+
const int VALUES_PER_THREAD = 8;
36+
const int BLOCK_SIZE = 256;
37+
const int K = {K_val};
38+
const int N = {N_val};
39+
const int M = {M_val};
40+
const int K_groups = {K_groups};
41+
const int SC_STRIDE = {sc_stride};
42+
const int SLID_DIV = {slid_div};
43+
44+
uint3 tgid = threadgroup_position_in_grid;
45+
uint sgid = simdgroup_index_in_threadgroup;
46+
uint slid = thread_index_in_simdgroup;
47+
int tg = tgid.y;
48+
49+
int out_row = tg * 8 + sgid * RESULTS_PER_SG;
50+
if (out_row >= N) return;
51+
52+
// Weight pointers
53+
const device uint8_t* ws = (const device uint8_t*)w + (long)out_row * K + slid * VALUES_PER_THREAD;
54+
const device bfloat16_t* sc = (const device bfloat16_t*)scales + (long)out_row * K_groups + slid / SLID_DIV;
55+
const device bfloat16_t* bi = (const device bfloat16_t*)biases + (long)out_row * K_groups + slid / SLID_DIV;
56+
57+
// Result accumulators: 4 rows × B batches
58+
float result[{4 * B}];
59+
for (int i = 0; i < {4 * B}; i++) result[i] = 0;
60+
61+
int x_base = slid * VALUES_PER_THREAD;
62+
63+
// K-loop: loop over B inside row loop
64+
for (int k_off = 0; k_off < K; k_off += BLOCK_SIZE) {{
65+
66+
for (int row = 0; row < RESULTS_PER_SG; row++) {{
67+
const device uint8_t* wl = ws + row * K;
68+
float s_val = float(sc[row * K_groups]);
69+
float b_val = float(bi[row * K_groups]);
70+
71+
for (int b = 0; b < {B}; b++) {{
72+
float accum = 0, xsum = 0;
73+
for (int i = 0; i < VALUES_PER_THREAD; i++) {{
74+
float xi = float(((const device bfloat16_t*)x)[b * K + x_base + i]);
75+
accum += xi * float(wl[i]);
76+
xsum += xi;
77+
}}
78+
result[b * 4 + row] += s_val * accum + xsum * b_val;
79+
}}
80+
}}
81+
82+
ws += BLOCK_SIZE; sc += SC_STRIDE; bi += SC_STRIDE; x_base += BLOCK_SIZE;
83+
}}
84+
85+
// Reduction
86+
for (int i = 0; i < {4 * B}; i++) result[i] = simd_sum(result[i]);
87+
88+
// Write output (bf16)
89+
if (slid < 4u) {{
90+
for (int b = 0; b < {B}; b++) {{
91+
int r = out_row + (int)slid;
92+
if (r < N) {{
93+
y[b * N + r] = static_cast<bfloat16_t>(result[b * 4 + slid]);
94+
}}
95+
}}
96+
}}
97+
"""
98+
99+
100+
# Compiled-kernel cache keyed by (M, N, K, group_size). Kernels are
# shape-specialized (dimensions baked into the Metal source), so each
# distinct problem size is compiled exactly once per process.
_custom_qmv_cache: dict = {}
def custom_qmv_loop_over_b(x, w, scales, biases, M, N, K, group_size=64):
    """Loop-over-B GEMV for 8-bit quantized matmul.

    Args:
        x: (M, K) bfloat16 input
        w: (N, K/4) uint32 packed 8-bit weights
        scales: (N, K/group_size) bfloat16 per-group scales
        biases: (N, K/group_size) bfloat16 per-group biases
        M, N, K: GEMV dimensions
        group_size: quantization group size; must be a multiple of 8 and
            divide 256 to match the generated kernel's pointer arithmetic

    Returns:
        y: (M, N) bfloat16

    Raises:
        ValueError: if K or group_size violate the kernel's preconditions.
            The generated Metal code walks K in fixed 256-element blocks
            with no tail handling, so invalid shapes would silently read
            out of bounds instead of failing loudly.
    """
    # Validate up front: BLOCK_SIZE is hard-coded to 256 in the kernel,
    # and SLID_DIV / SC_STRIDE are derived from group_size by exact division.
    if K % 256 != 0:
        raise ValueError(f"K must be a multiple of 256, got {K}")
    if group_size % 8 != 0 or 256 % group_size != 0:
        raise ValueError(
            f"group_size must be a multiple of 8 and divide 256, got {group_size}"
        )

    key = (M, N, K, group_size)
    if key not in _custom_qmv_cache:
        # NOTE: the cache key includes group_size but the kernel name does
        # not; the name is only a label, so two group sizes with the same
        # (M, N, K) still get distinct compiled kernels via the cache.
        _custom_qmv_cache[key] = mx.fast.metal_kernel(
            name=f"custom_qmv_loop_b_M{M}_N{N}_K{K}",
            input_names=["x", "w", "scales", "biases"],
            output_names=["y"],
            source=_gen_custom_qmv_source(M, N, K, group_size),
        )
    kern = _custom_qmv_cache[key]

    # Each threadgroup (32, 2, 1) covers 8 output rows (2 SIMD groups x
    # 4 rows each); the grid is specified in threads, hence n_tg * 2 in y.
    n_tg = ceil_div(N, 8)

    result = kern(
        inputs=[x, w, scales, biases],
        output_shapes=[(M * N,)],
        output_dtypes=[mx.bfloat16],
        grid=(32, n_tg * 2, 1),
        threadgroup=(32, 2, 1),
    )

    return result[0].reshape(M, N)
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#!/usr/bin/env python3
2+
"""Loop-over-B patches for Qwen3.5-27B dense model.
3+
4+
Replaces vanilla QuantizedLinear calls with custom loop-over-B GEMV
for the targeted projections (chiefly expanding ones where N > K, e.g.
gate_proj, up_proj, in_proj_qkv). Inputs with more than 16 total tokens
fall back to the vanilla module at call time, where the stock kernels
are already efficient.
7+
8+
Usage:
9+
from lpb_patch import apply_lpb_patches
10+
apply_lpb_patches(model, batch_size=4)
11+
"""
12+
13+
import mlx.core as mx
14+
import mlx.nn as nn
15+
16+
from .custom_qmv_loop_over_b import custom_qmv_loop_over_b
17+
18+
19+
def _make_lpb_forward(original_module, N, K, BS, GS=64):
20+
"""Create a patched forward that uses loop-over-B."""
21+
w = original_module.weight
22+
s = original_module.scales
23+
b = original_module.biases
24+
25+
MAX_M = 16 # Max total tokens (B*S) for custom kernel; above this use vanilla
26+
27+
def forward(self_unused, x):
28+
# Use LpB for small M=B*S. Large prefill falls back to vanilla.
29+
M_total = 1
30+
for d in x.shape[:-1]:
31+
M_total *= d
32+
if M_total > MAX_M:
33+
return original_module(x)
34+
orig_shape = x.shape
35+
x_2d = x.reshape(-1, K)
36+
M = x_2d.shape[0]
37+
y = custom_qmv_loop_over_b(x_2d, w, s, b, M, N, K, GS)
38+
return y.reshape(*orig_shape[:-1], N)
39+
40+
return forward
41+
42+
43+
def _lpb_wrap(owner, proj_name, batch_size):
    """Replace ``owner.<proj_name>`` with a loop-over-B wrapper.

    Only swaps the attribute when it exists and is an ``nn.QuantizedLinear``
    (8-bit assumed: 4 values per packed uint32 column). Returns 1 if the
    projection was patched, 0 otherwise.
    """
    proj = getattr(owner, proj_name, None)
    if proj is None or not isinstance(proj, nn.QuantizedLinear):
        return 0
    N = proj.weight.shape[0]       # output dim
    K = proj.weight.shape[1] * 4   # 8-bit: 4 values per uint32
    # Anonymous wrapper object: behaves like the projection when called,
    # while still exposing weight/scales/biases for introspection.
    wrapper = type('LpBLinear', (), {
        '__call__': _make_lpb_forward(proj, N, K, batch_size),
        'weight': proj.weight,
        'scales': proj.scales,
        'biases': proj.biases,
    })()
    setattr(owner, proj_name, wrapper)
    return 1


def apply_lpb_patches(model, batch_size=4):
    """Patch QuantizedLinear projections with loop-over-B GEMV wrappers.

    Per decoder layer the following projections are patched (when present
    and quantized):
      - MLP: gate_proj, up_proj, down_proj
      - linear-attention layers (layer.is_linear): in_proj_qkv,
        in_proj_z, out_proj
      - full-attention layers: q_proj, o_proj

    The wrapper only engages for small token counts; larger inputs fall
    back to the vanilla module inside the wrapped forward (see
    ``_make_lpb_forward``).

    Args:
        model: MLX model; layers are found on ``model.model`` or
            ``model.language_model.model``.
        batch_size: forwarded to ``_make_lpb_forward``.

    Returns:
        The number of projections patched.
    """
    inner = getattr(model, 'model', None) or model.language_model.model
    patched = 0

    for layer in inner.layers:
        # MLP projections.
        for proj_name in ('gate_proj', 'up_proj', 'down_proj'):
            patched += _lpb_wrap(layer.mlp, proj_name, batch_size)

        # Attention projections: layer kind decides which module/names apply.
        if layer.is_linear:
            attn = layer.linear_attn
            attn_names = ('in_proj_qkv', 'in_proj_z', 'out_proj')
        else:
            attn = layer.self_attn
            attn_names = ('q_proj', 'o_proj')
        for proj_name in attn_names:
            patched += _lpb_wrap(attn, proj_name, batch_size)

    print(f" Patched {patched} projections with loop-over-B")
    return patched

0 commit comments

Comments
 (0)