
Commit 0efa361

Version: 0.10.17-alpha.2
Llama3.2: Add full Mila/HF equivalency validation

- Add detailed Python scripts for prefill and decode equivalency vs HF
- Refactor GQA/RoPE CUDA ops to accept Q, K, V separately in decode

Numerical validation of Mila's Llama3.2 against HuggingFace for both prefill and decode. Marks full-network greedy decode validation as complete.
1 parent 588ff05 commit 0efa361
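
Note on the decode-side refactor above: the GQA/RoPE ops now consume Q, K, and V as separate tensors. As a point of reference for what one separated decode step computes, here is a minimal pure-PyTorch sketch of a GQA attention step over a KV cache. The function name, shapes, and head counts are illustrative assumptions, not Mila's CUDA interface, and RoPE is assumed to have been applied to q and k_cache already.

import torch

def gqa_decode_step(q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor) -> torch.Tensor:
    # q:       [n_heads, head_dim]       query for the single new token
    # k_cache: [n_kv_heads, T, head_dim] keys for all cached positions
    # v_cache: [n_kv_heads, T, head_dim] values for all cached positions
    n_heads, head_dim = q.shape
    n_kv_heads = k_cache.shape[0]
    group = n_heads // n_kv_heads          # query heads per shared KV head
    out = torch.empty_like(q)
    for h in range(n_heads):
        kv = h // group                    # GQA: several Q heads share one KV head
        scores = k_cache[kv] @ q[h] / head_dim ** 0.5  # [T]; no causal mask needed in decode
        out[h] = torch.softmax(scores, dim=-1) @ v_cache[kv]
    return out

# Smoke test with Llama3.2-1B-like head counts (32 Q heads, 8 KV heads, head_dim 64).
q = torch.randn(32, 64)
k = torch.randn(8, 6, 64)   # 6 cached positions
v = torch.randn(8, 6, 64)
print(gqa_decode_step(q, k, v).shape)  # torch.Size([32, 64])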

18 files changed

Lines changed: 3462 additions & 285 deletions

File tree

Data/Scripts/hf_greedy_validation.py

Lines changed: 0 additions & 9 deletions
This file was deleted.
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "meta-llama/Llama-3.2-1B"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

ids = tokenizer.encode("Once upon a time", return_tensors="pt", add_special_tokens=False)
out = model.generate(ids, max_new_tokens=64, do_sample=False)  # greedy
print(tokenizer.decode(out[0]))
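
A small optional follow-up (not part of this commit): persisting the greedy token ids makes the Mila/HF comparison a token-by-token diff rather than an eyeball check of the decoded text. The file name and format here are illustrative assumptions.

# Hypothetical helper: dump prompt + generated ids for token-level diffing
# against Mila's decode output. Path and format are illustrative only.
ref_ids = out[0].tolist()
with open("hf_greedy_ids.txt", "w") as f:
    f.write(" ".join(str(i) for i in ref_ids))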
Lines changed: 359 additions & 0 deletions
@@ -0,0 +1,359 @@
import struct
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B"

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_id)

prompt = "Once upon a time"
input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
print(f"Prompt: {prompt!r}")
print(f"Token ids: {input_ids[0].tolist()}")
print(f"Tokens: {[tokenizer.decode([t]) for t in input_ids[0].tolist()]}")

# ── Constants ────────────────────────────────────────────────────────────────

_FIRST_N = 16
_MAX_ROWS = 4

# ── Capture storage ──────────────────────────────────────────────────────────

captured = {}
captured_slices = {}

# ── Stats helpers ────────────────────────────────────────────────────────────

def fnv1a_checksum_last_token(t: torch.Tensor) -> int:
    FNV_OFFSET = 1469598103934665603
    FNV_PRIME = 1099511628211
    MASK64 = 0xFFFFFFFFFFFFFFFF

    last = t[0, :] if t.dim() == 2 else t[0, -1, :]
    checksum = FNV_OFFSET

    for val in last.detach().cpu().to(torch.float32):
        bits = struct.unpack('<I', struct.pack('<f', float(val)))[0]
        for byte_idx in range(4):
            b = (bits >> (byte_idx * 8)) & 0xFF
            checksum = ((checksum ^ b) * FNV_PRIME) & MASK64

    return checksum

def stats_and_checksum_last_token(t: torch.Tensor) -> dict:
    """
    min/max/mean/std/checksum for the last-token vector.
    std uses the population formula (divide by N) to match Mila's print_stats.
    """
    last = t[0, :] if t.dim() == 2 else t[0, -1, :]
    last_f32 = last.detach().cpu().to(torch.float32)
    n = last_f32.numel()
    mean_val = last_f32.sum().item() / n
    var_val = ((last_f32 - mean_val) ** 2).sum().item() / n

    return {
        "min": last_f32.min().item(),
        "max": last_f32.max().item(),
        "mean": mean_val,
        "std": var_val ** 0.5,
        "checksum": fnv1a_checksum_last_token(t),
    }

def _fmt_mila_stats(stats: dict) -> str:
    return (
        f"min={stats['min']:.6f} max={stats['max']:.6f} "
        f"mean={stats['mean']:.6f} std={stats['std']:.6f} "
        f"checksum=0x{stats['checksum']:016x}"
    )

# ── Hook factories ───────────────────────────────────────────────────────────

def make_stats_hook(name):
    def fn(module, input, output):
        t = output if isinstance(output, torch.Tensor) else output[0]
        captured[name] = stats_and_checksum_last_token(t)
    return fn

def make_store_rows_hook(name, rows=_MAX_ROWS, cols=_FIRST_N):
    def fn(module, input, output):
        t = output if isinstance(output, torch.Tensor) else output[0]
        rows_list = []
        if t.dim() == 2:
            rows_list.append([float(x) for x in t[0, :cols].detach().cpu().tolist()])
        else:
            use_rows = min(rows, t.size(1))
            for r in range(use_rows):
                rows_list.append([float(x) for x in t[0, r, :cols].detach().cpu().tolist()])
        captured_slices[name] = rows_list
    return fn

def make_attn_output_pre_hook(layer_index):
    """Pre-hook on o_proj — matches Mila's attn_->decode() return value."""
    def fn(module, input):
        t = input[0] if isinstance(input, (tuple, list)) else input
        captured[f"layer_{layer_index}.attn_out"] = stats_and_checksum_last_token(t)
        rows_list = []
        use_rows = min(_MAX_ROWS, t.size(1) if t.dim() == 3 else 1)
        for r in range(use_rows):
            vals = t[0, r, :_FIRST_N] if t.dim() == 3 else t[0, :_FIRST_N]
            rows_list.append([float(x) for x in vals.detach().cpu().tolist()])
        captured_slices[f"layer_{layer_index}.attn_out_rows_first{_FIRST_N}"] = rows_list
    return fn

# ── Post-RoPE K/Q capture via apply_rotary_pos_emb monkey-patch ─────────────

def install_rope_capture():
    """
    Monkey-patch apply_rotary_pos_emb in the Llama modeling module to capture
    post-RoPE Q and K for layer 0. Returns a restore function.
    """
    import transformers.models.llama.modeling_llama as llama_mod

    original_rope = llama_mod.apply_rotary_pos_emb

    def capturing_rope(q, k, cos, sin):
        q_rot, k_rot = original_rope(q, k, cos, sin)

        # q_rot: [B, n_heads, T, head_dim]
        # k_rot: [B, n_kv_heads, T, head_dim]
        # Capture last token, head 0, first _FIRST_N elements
        captured["layer_0.k_post_rope"] = k_rot[0, 0, -1, :_FIRST_N].detach().cpu().tolist()
        captured["layer_0.q_post_rope"] = q_rot[0, 0, -1, :_FIRST_N].detach().cpu().tolist()

        # Also capture all KV heads for the last token (first _FIRST_N elements each)
        # so we can compare against Mila's full K cache row at this position
        n_kv_heads = k_rot.shape[1]
        k_all_heads = []
        for h in range(n_kv_heads):
            k_all_heads.append([float(x) for x in k_rot[0, h, -1, :_FIRST_N].detach().cpu().tolist()])
        captured_slices["layer_0.k_post_rope_all_heads"] = k_all_heads

        return q_rot, k_rot

    llama_mod.apply_rotary_pos_emb = capturing_rope

    def restore():
        llama_mod.apply_rotary_pos_emb = original_rope

    return restore

# ── Hook registration — decode only, layer 0 ────────────────────────────────

def register_decode_hooks():
    hooks = []
    layer = model.model.layers[0]

    if hasattr(model.model, "embed_tokens"):
        hooks.append(model.model.embed_tokens.register_forward_hook(
            make_stats_hook("token_embeds")))
        hooks.append(model.model.embed_tokens.register_forward_hook(
            make_store_rows_hook(f"token_embeds_rows_first{_FIRST_N}")))

    if hasattr(layer, "input_layernorm"):
        hooks.append(layer.input_layernorm.register_forward_hook(
            make_stats_hook("layer_0.rmsn_1")))
        hooks.append(layer.input_layernorm.register_forward_hook(
            make_store_rows_hook(f"layer_0.rmsn_1_rows_first{_FIRST_N}")))

    if hasattr(layer.self_attn, "q_proj"):
        hooks.append(layer.self_attn.q_proj.register_forward_hook(
            make_stats_hook("layer_0.q_pre_rope")))
        hooks.append(layer.self_attn.q_proj.register_forward_hook(
            make_store_rows_hook(f"layer_0.q_pre_rope_rows_first{_FIRST_N}")))

    if hasattr(layer.self_attn, "k_proj"):
        hooks.append(layer.self_attn.k_proj.register_forward_hook(
            make_stats_hook("layer_0.k_pre_rope")))
        hooks.append(layer.self_attn.k_proj.register_forward_hook(
            make_store_rows_hook(f"layer_0.k_pre_rope_rows_first{_FIRST_N}")))

    if hasattr(layer.self_attn, "v_proj"):
        hooks.append(layer.self_attn.v_proj.register_forward_hook(
            make_stats_hook("layer_0.v_proj")))
        hooks.append(layer.self_attn.v_proj.register_forward_hook(
            make_store_rows_hook(f"layer_0.v_proj_rows_first{_FIRST_N}")))

    if hasattr(layer.self_attn, "o_proj"):
        hooks.append(layer.self_attn.o_proj.register_forward_pre_hook(
            make_attn_output_pre_hook(0)))
        hooks.append(layer.self_attn.o_proj.register_forward_hook(
            make_stats_hook("layer_0.fc_out_proj")))
        hooks.append(layer.self_attn.o_proj.register_forward_hook(
            make_store_rows_hook(f"layer_0.fc_out_proj_rows_first{_FIRST_N}")))

    if hasattr(layer, "post_attention_layernorm"):
        hooks.append(layer.post_attention_layernorm.register_forward_hook(
            make_stats_hook("layer_0.rmsn_2")))
        hooks.append(layer.post_attention_layernorm.register_forward_hook(
            make_store_rows_hook(f"layer_0.rmsn_2_rows_first{_FIRST_N}")))

    if hasattr(layer, "mlp"):
        if hasattr(layer.mlp, "gate_proj"):
            hooks.append(layer.mlp.gate_proj.register_forward_hook(
                make_stats_hook("layer_0.gate_proj")))
            hooks.append(layer.mlp.gate_proj.register_forward_hook(
                make_store_rows_hook(f"layer_0.gate_proj_rows_first{_FIRST_N}")))
        if hasattr(layer.mlp, "up_proj"):
            hooks.append(layer.mlp.up_proj.register_forward_hook(
                make_stats_hook("layer_0.up_proj")))
            hooks.append(layer.mlp.up_proj.register_forward_hook(
                make_store_rows_hook(f"layer_0.up_proj_rows_first{_FIRST_N}")))
        if hasattr(layer.mlp, "down_proj"):
            hooks.append(layer.mlp.down_proj.register_forward_hook(
                make_stats_hook("layer_0.fc_down")))
            hooks.append(layer.mlp.down_proj.register_forward_hook(
                make_store_rows_hook(f"layer_0.fc_down_rows_first{_FIRST_N}")))

    hooks.append(layer.register_forward_hook(
        make_stats_hook("layer_0.block_out")))
    hooks.append(layer.register_forward_hook(
        make_store_rows_hook(f"layer_0.block_out_rows_first{_FIRST_N}")))

    return hooks

# ── Formatting helpers ───────────────────────────────────────────────────────

def _fmt_num(x: float) -> str:
    return f"{x:.6g}"

def _print_table(name, rows_list):
    if not rows_list:
        print(f"{name}: (no rows captured)")
        return

    cols = min(max(len(r) for r in rows_list), _FIRST_N)

    formatted_rows = []
    for row in rows_list[:_MAX_ROWS]:
        formatted = []
        for j in range(cols):
            val = row[j] if j < len(row) else None
            formatted.append(_fmt_num(val) if val is not None else "")
        formatted_rows.append(formatted)

    col_widths = []
    for j in range(cols):
        max_cell = max((len(r[j]) for r in formatted_rows), default=0)
        col_widths.append(max(len(f"C{j}"), max_cell))

    header_cols = " | ".join(f"{f'C{j}':>{col_widths[j]}}" for j in range(cols))
    header = f"Row | {header_cols}"
    sep = "-" * len(header)

    print(f"\n{name}")
    print(sep)
    print(header)
    print(sep)

    for i, row in enumerate(formatted_rows):
        row_str = " | ".join(f"{row[j]:>{col_widths[j]}}" for j in range(cols))
        print(f"{i:3} | {row_str}")

    if len(rows_list) > _MAX_ROWS:
        print(f"... ({len(rows_list)} rows captured, showing first {_MAX_ROWS})")

# ── Phase 1: Prefill (no hooks — just populate KV cache) ────────────────────

print(f"\n{'='*72}")
print(f" PHASE 1: PREFILL ({input_ids.shape[1]} tokens, hooks disabled)")
print(f"{'='*72}")

with torch.no_grad():
    prefill_out = model(input_ids, use_cache=True)
    past_key_values = prefill_out.past_key_values
    prefill_logits = prefill_out.logits

next_token_id = prefill_logits[0, -1, :].argmax().item()
print(f"Prefill top prediction: {tokenizer.decode([next_token_id])!r} (id={next_token_id})")

top5 = torch.topk(prefill_logits[0, -1, :], 5)
print("Top 5 prefill predictions:")
for v, idx in zip(top5.values, top5.indices):
    print(f" {tokenizer.decode([idx.item()])!r:15} {_fmt_num(v.item())}")

# ── Phase 2: Decode loop — hook step 1 only (position 5, input ' the') ──────

print(f"\n{'='*72}")
print(f" DECODE LOOP — 2 steps, hooking step 1 only (position 5)")
print(f"{'='*72}")

num_decode_steps = 2
current_token_id = next_token_id

for step in range(num_decode_steps):
    position = input_ids.shape[1] + step
    decode_input = torch.tensor([[current_token_id]])

    hooks = []
    restore_rope = None

    if step == 1:
        hooks = register_decode_hooks()
        restore_rope = install_rope_capture()

    with torch.no_grad():
        decode_out = model(decode_input, past_key_values=past_key_values, use_cache=True)
        decode_logits = decode_out.logits
        past_key_values = decode_out.past_key_values

    if step == 1:
        for h in hooks:
            h.remove()
        restore_rope()

    predicted_id = decode_logits[0, -1, :].argmax().item()
    predicted_token = tokenizer.decode([predicted_id])

    print(f" step={step} pos={position} in={tokenizer.decode([current_token_id])!r:10} -> {predicted_token!r}")

    current_token_id = predicted_id

# ── Print layer 0 decode checkpoints ────────────────────────────────────────

print(f"\n{'='*72}")
print(f" DECODE STEP 1 (pos=5, input=' the') — Layer 0 Checkpoints")
print(f" Compare directly against Mila print_stats output")
print(f"{'='*72}")

checkpoint_keys = [
    "token_embeds",
    "layer_0.rmsn_1",
    "layer_0.q_pre_rope",
    "layer_0.k_pre_rope",
    "layer_0.v_proj",
    "layer_0.attn_out",
    "layer_0.fc_out_proj",
    "layer_0.rmsn_2",
    "layer_0.gate_proj",
    "layer_0.up_proj",
    "layer_0.fc_down",
    "layer_0.block_out",
]

for key in checkpoint_keys:
    val = captured.get(key, "not captured")
    if isinstance(val, dict) and "checksum" in val:
        print(f" {key}:")
        print(f" {_fmt_mila_stats(val)}")
    else:
        print(f" {key}: not captured")

# ── Post-RoPE K/Q comparison ─────────────────────────────────────────────────

print(f"\n{'='*72}")
print(f" POST-RoPE K/Q (layer 0, head 0, first {_FIRST_N} elements)")
print(f" Compare k_post_rope[0] against Mila decode.k cache row 5")
print(f"{'='*72}")

k_post = captured.get("layer_0.k_post_rope", "not captured")
q_post = captured.get("layer_0.q_post_rope", "not captured")
print(f" k_post_rope (head 0): {[f'{x:.6f}' for x in k_post] if isinstance(k_post, list) else k_post}")
print(f" q_post_rope (head 0): {[f'{x:.6f}' for x in q_post] if isinstance(q_post, list) else q_post}")

print(f"\n--- First {_FIRST_N} elements (up to {_MAX_ROWS} rows) ---")
for k, v in captured_slices.items():
    rows = [v] if (v and isinstance(v[0], float)) else v
    _print_table(k, rows)
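
A caveat when reading the checkpoint output above: the FNV-1a checksum matches only when Mila and HF produce bit-identical float32 activations, so when kernels reorder reductions the stats can agree closely while checksums differ. A tolerance check on the printed stats is the practical fallback; a minimal sketch, assuming the Mila-side print_stats line has been parsed into the same dict layout:

def stats_close(hf: dict, mila: dict, atol: float = 1e-5) -> bool:
    # Compare min/max/mean/std with an absolute tolerance; the checksum key
    # is deliberately excluded since it requires bit-exact floats.
    return all(abs(hf[k] - mila[k]) <= atol for k in ("min", "max", "mean", "std"))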
File renamed without changes.
