1+ import struct
2+ import torch
3+ from transformers import AutoModelForCausalLM , AutoTokenizer
4+
# Model under inspection. fp32 so activations compare cleanly against the
# reference implementation (Mila) without dtype-rounding noise.
model_id = "meta-llama/Llama-3.2-1B"

model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_id)

# No special tokens: token positions must line up 1:1 with the other runtime.
prompt = "Once upon a time"
input_ids = tokenizer.encode(prompt, return_tensors="pt", add_special_tokens=False)
prompt_token_ids = input_ids[0].tolist()
print(f"Prompt: {prompt!r}")
print(f"Token ids: {prompt_token_ids}")
print(f"Tokens: {[tokenizer.decode([t]) for t in prompt_token_ids]}")
17+
18+ # ── Constants ────────────────────────────────────────────────────────────────
19+
20+ _FIRST_N = 16
21+ _MAX_ROWS = 4
22+
23+ # ── Capture storage ──────────────────────────────────────────────────────────
24+
25+ captured = {}
26+ captured_slices = {}
27+
28+ # ── Stats helpers ────────────────────────────────────────────────────────────
29+
def fnv1a_checksum_last_token(t: torch.Tensor) -> int:
    """64-bit FNV-1a over the little-endian float32 bytes of the last-token
    vector of batch element 0 (row 0 for 2-D input, last row otherwise)."""
    offset_basis = 1469598103934665603
    prime = 1099511628211
    mask64 = (1 << 64) - 1

    vec = t[0, :] if t.dim() == 2 else t[0, -1, :]
    values = [float(v) for v in vec.detach().cpu().to(torch.float32)]

    # Packing the whole vector at once yields the same byte stream as packing
    # each float separately: 4 little-endian bytes per value, in order.
    digest = offset_basis
    for byte in struct.pack(f"<{len(values)}f", *values):
        digest = ((digest ^ byte) * prime) & mask64
    return digest
45+
def stats_and_checksum_last_token(t: torch.Tensor) -> dict:
    """
    min/max/mean/std/checksum for the last-token vector.
    std uses population formula (divide by N) to match Mila's print_stats.
    """
    vec = (t[0, :] if t.dim() == 2 else t[0, -1, :]).detach().cpu().to(torch.float32)
    count = vec.numel()
    # Mean/variance computed exactly as sum/N so results match bit-for-bit.
    mu = vec.sum().item() / count
    variance = ((vec - mu) ** 2).sum().item() / count

    return {
        "min": vec.min().item(),
        "max": vec.max().item(),
        "mean": mu,
        "std": variance ** 0.5,
        "checksum": fnv1a_checksum_last_token(t),
    }
64+
65+ def _fmt_mila_stats (stats : dict ) -> str :
66+ return (
67+ f"min={ stats ['min' ]:.6f} max={ stats ['max' ]:.6f} "
68+ f"mean={ stats ['mean' ]:.6f} std={ stats ['std' ]:.6f} "
69+ f"checksum=0x{ stats ['checksum' ]:016x} "
70+ )
71+
72+ # ── Hook factories ───────────────────────────────────────────────────────────
73+
def make_stats_hook(name):
    """Forward-hook factory: records last-token stats under *name* in `captured`."""
    def hook(module, inputs, output):
        # Some modules return tuples; the activation tensor is element 0.
        tensor = output if isinstance(output, torch.Tensor) else output[0]
        captured[name] = stats_and_checksum_last_token(tensor)
    return hook
79+
def make_store_rows_hook(name, rows=_MAX_ROWS, cols=_FIRST_N):
    """Forward-hook factory: stores up to *rows* rows of the first *cols*
    activation values (batch element 0) in `captured_slices[name]`."""
    def hook(module, inputs, output):
        tensor = output if isinstance(output, torch.Tensor) else output[0]
        if tensor.dim() == 2:
            picked = [tensor[0, :cols]]
        else:
            picked = [tensor[0, r, :cols] for r in range(min(rows, tensor.size(1)))]
        captured_slices[name] = [
            [float(x) for x in row.detach().cpu().tolist()] for row in picked
        ]
    return hook
92+
def make_attn_output_pre_hook(layer_index):
    """Pre-hook on o_proj — its *input* is the attention context, which matches
    Mila's attn_->decode() return value."""
    def hook(module, inputs):
        tensor = inputs[0] if isinstance(inputs, (tuple, list)) else inputs
        captured[f"layer_{layer_index}.attn_out"] = stats_and_checksum_last_token(tensor)

        if tensor.dim() == 3:
            picked = [tensor[0, r, :_FIRST_N] for r in range(min(_MAX_ROWS, tensor.size(1)))]
        else:
            picked = [tensor[0, :_FIRST_N]]
        captured_slices[f"layer_{layer_index}.attn_out_rows_first{_FIRST_N}"] = [
            [float(x) for x in row.detach().cpu().tolist()] for row in picked
        ]
    return hook
105+
106+ # ── Post-RoPE K/Q capture via apply_rotary_pos_emb monkey-patch ─────────────
107+
def install_rope_capture():
    """
    Monkey-patch apply_rotary_pos_emb in the Llama modeling module to capture
    post-RoPE Q and K for layer 0. Returns a restore function.

    Fix: the wrapper forwards *args/**kwargs to the original. Depending on the
    transformers version, apply_rotary_pos_emb is invoked with extra arguments
    (e.g. position_ids, unsqueeze_dim); a fixed 4-parameter signature would
    raise TypeError inside the forward pass.
    """
    import transformers.models.llama.modeling_llama as llama_mod

    original_rope = llama_mod.apply_rotary_pos_emb

    def capturing_rope(q, k, cos, sin, *args, **kwargs):
        q_rot, k_rot = original_rope(q, k, cos, sin, *args, **kwargs)

        # q_rot: [B, n_heads, T, head_dim]
        # k_rot: [B, n_kv_heads, T, head_dim]
        # Capture last token, head 0, first _FIRST_N elements.
        captured["layer_0.k_post_rope"] = k_rot[0, 0, -1, :_FIRST_N].detach().cpu().tolist()
        captured["layer_0.q_post_rope"] = q_rot[0, 0, -1, :_FIRST_N].detach().cpu().tolist()

        # Also capture all KV heads for the last token (first _FIRST_N elements
        # each) so we can compare against Mila's full K cache row at position.
        n_kv_heads = k_rot.shape[1]
        captured_slices["layer_0.k_post_rope_all_heads"] = [
            [float(x) for x in k_rot[0, h, -1, :_FIRST_N].detach().cpu().tolist()]
            for h in range(n_kv_heads)
        ]

        return q_rot, k_rot

    llama_mod.apply_rotary_pos_emb = capturing_rope

    def restore():
        llama_mod.apply_rotary_pos_emb = original_rope

    return restore
142+
143+ # ── Hook registration — decode only, layer 0 ────────────────────────────────
144+
def register_decode_hooks():
    """Attach stats + first-N-values hooks to the embedding layer and every
    interesting sub-module of decoder layer 0. Returns the hook handles so the
    caller can remove them after one instrumented decode step."""
    hooks = []
    layer = model.model.layers[0]

    def watch(module, key):
        # Stats hook first, then the row-slice hook (same registration order as
        # the original so `captured` / `captured_slices` fill identically).
        hooks.append(module.register_forward_hook(make_stats_hook(key)))
        hooks.append(module.register_forward_hook(
            make_store_rows_hook(f"{key}_rows_first{_FIRST_N}")))

    if hasattr(model.model, "embed_tokens"):
        watch(model.model.embed_tokens, "token_embeds")

    if hasattr(layer, "input_layernorm"):
        watch(layer.input_layernorm, "layer_0.rmsn_1")

    if hasattr(layer.self_attn, "q_proj"):
        watch(layer.self_attn.q_proj, "layer_0.q_pre_rope")

    if hasattr(layer.self_attn, "k_proj"):
        watch(layer.self_attn.k_proj, "layer_0.k_pre_rope")

    if hasattr(layer.self_attn, "v_proj"):
        watch(layer.self_attn.v_proj, "layer_0.v_proj")

    if hasattr(layer.self_attn, "o_proj"):
        # Pre-hook captures o_proj's input (the attention context);
        # the forward hooks below capture its output.
        hooks.append(layer.self_attn.o_proj.register_forward_pre_hook(
            make_attn_output_pre_hook(0)))
        watch(layer.self_attn.o_proj, "layer_0.fc_out_proj")

    if hasattr(layer, "post_attention_layernorm"):
        watch(layer.post_attention_layernorm, "layer_0.rmsn_2")

    if hasattr(layer, "mlp"):
        for attr, key in (
            ("gate_proj", "layer_0.gate_proj"),
            ("up_proj", "layer_0.up_proj"),
            ("down_proj", "layer_0.fc_down"),
        ):
            if hasattr(layer.mlp, attr):
                watch(getattr(layer.mlp, attr), key)

    # Whole decoder-block output (hidden state after both residual adds).
    watch(layer, "layer_0.block_out")

    return hooks
216+
217+ # ── Formatting helpers ───────────────────────────────────────────────────────
218+
219+ def _fmt_num (x : float ) -> str :
220+ return f"{ x :.6g} "
221+
def _print_table(name, rows_list):
    """Pretty-print up to _MAX_ROWS rows x _FIRST_N columns of captured values
    as an aligned table with C0..Cn column headers."""
    if not rows_list:
        print(f"{name}: (no rows captured)")
        return

    n_cols = min(max(len(r) for r in rows_list), _FIRST_N)

    # Format every shown cell up front so column widths can be computed.
    shown = []
    for row in rows_list[:_MAX_ROWS]:
        cells = []
        for j in range(n_cols):
            cell = row[j] if j < len(row) else None
            cells.append("" if cell is None else _fmt_num(cell))
        shown.append(cells)

    widths = [
        max([len(f"C{j}")] + [cells[j] and len(cells[j]) or 0 for cells in shown])
        for j in range(n_cols)
    ]

    header_cells = " | ".join(f"{f'C{j}':>{widths[j]}}" for j in range(n_cols))
    header = f"Row | {header_cells}"
    rule = "-" * len(header)

    print(f"\n{name}")
    print(rule)
    print(header)
    print(rule)

    for i, cells in enumerate(shown):
        body = " | ".join(f"{cells[j]:>{widths[j]}}" for j in range(n_cols))
        print(f"{i:3} | {body}")

    if len(rows_list) > _MAX_ROWS:
        print(f"... ({len(rows_list)} rows captured, showing first {_MAX_ROWS})")
257+
# ── Phase 1: Prefill (no hooks — just populate KV cache) ────────────────────

print(f"\n{'=' * 72}")
print(f" PHASE 1: PREFILL ({input_ids.shape[1]} tokens, hooks disabled)")
print(f"{'=' * 72}")

# Run the whole prompt once to fill the KV cache; nothing is instrumented yet.
with torch.no_grad():
    prefill_out = model(input_ids, use_cache=True)
    past_key_values = prefill_out.past_key_values
    prefill_logits = prefill_out.logits

next_token_id = prefill_logits[0, -1, :].argmax().item()
print(f"Prefill top prediction: {tokenizer.decode([next_token_id])!r} (id={next_token_id})")

top5 = torch.topk(prefill_logits[0, -1, :], 5)
print("Top 5 prefill predictions:")
for v, idx in zip(top5.values, top5.indices):
    print(f" {tokenizer.decode([idx.item()])!r:15} {_fmt_num(v.item())}")
276+
# ── Phase 2: Decode loop — hook step 1 only (position 5, input ' the') ──────

print(f"\n{'=' * 72}")
print(f" DECODE LOOP — 2 steps, hooking step 1 only (position 5)")
print(f"{'=' * 72}")

num_decode_steps = 2
current_token_id = next_token_id

for step in range(num_decode_steps):
    position = input_ids.shape[1] + step
    decode_input = torch.tensor([[current_token_id]])

    # Only the second decode step is instrumented, so the capture dicts end up
    # holding exactly one step's worth of activations.
    hooks = []
    restore_rope = None
    if step == 1:
        hooks = register_decode_hooks()
        restore_rope = install_rope_capture()

    with torch.no_grad():
        decode_out = model(decode_input, past_key_values=past_key_values, use_cache=True)
        decode_logits = decode_out.logits
        past_key_values = decode_out.past_key_values

    if step == 1:
        for handle in hooks:
            handle.remove()
        restore_rope()

    predicted_id = decode_logits[0, -1, :].argmax().item()
    predicted_token = tokenizer.decode([predicted_id])

    print(f" step={step} pos={position} in={tokenizer.decode([current_token_id])!r:10} -> {predicted_token!r}")

    current_token_id = predicted_id
313+
# ── Print layer 0 decode checkpoints ────────────────────────────────────────

print(f"\n{'=' * 72}")
print(f" DECODE STEP 1 (pos=5, input=' the') — Layer 0 Checkpoints")
print(f" Compare directly against Mila print_stats output")
print(f"{'=' * 72}")

# Order mirrors the data flow through one decoder block.
checkpoint_keys = [
    "token_embeds",
    "layer_0.rmsn_1",
    "layer_0.q_pre_rope",
    "layer_0.k_pre_rope",
    "layer_0.v_proj",
    "layer_0.attn_out",
    "layer_0.fc_out_proj",
    "layer_0.rmsn_2",
    "layer_0.gate_proj",
    "layer_0.up_proj",
    "layer_0.fc_down",
    "layer_0.block_out",
]

for key in checkpoint_keys:
    val = captured.get(key, "not captured")
    if isinstance(val, dict) and "checksum" in val:
        print(f" {key}:")
        print(f" {_fmt_mila_stats(val)}")
    else:
        print(f" {key}: not captured")
343+
# ── Post-RoPE K/Q comparison ─────────────────────────────────────────────────

print(f"\n{'=' * 72}")
print(f" POST-RoPE K/Q (layer 0, head 0, first {_FIRST_N} elements)")
print(f" Compare k_post_rope[0] against Mila decode.k cache row 5")
print(f"{'=' * 72}")

k_post = captured.get("layer_0.k_post_rope", "not captured")
q_post = captured.get("layer_0.q_post_rope", "not captured")
# These keys hold plain lists (not stats dicts); fall back to the sentinel
# string when the RoPE patch never fired.
print(f" k_post_rope (head 0): {[f'{x:.6f}' for x in k_post] if isinstance(k_post, list) else k_post}")
print(f" q_post_rope (head 0): {[f'{x:.6f}' for x in q_post] if isinstance(q_post, list) else q_post}")

print(f"\n--- First {_FIRST_N} elements (up to {_MAX_ROWS} rows) ---")
for key, value in captured_slices.items():
    # A flat list of floats is wrapped so _print_table always sees rows.
    table_rows = [value] if (value and isinstance(value[0], float)) else value
    _print_table(key, table_rows)