Commit eed95d8
[nemotron_h] respect _no_reinit flag on dt_bias and out_proj.weight (#45591)
* [nemotron_h] respect _no_reinit flag on dt_bias and out_proj.weight
_init_weights() on `NemotronHPreTrainedModel` unconditionally overwrites
`dt_bias` (with a random `inv_softplus(dt)` draw) and `out_proj.weight`
(with `kaiming_uniform` scaled by `1/sqrt(n_layer)`) every time it is
invoked on a Mamba block. It sets `module.dt_bias._no_reinit = True` after
the copy, but the flag is never checked by either code path (only the
`nn.Linear` bias branch reads it).

On transformers>=5.0, `_init_weights` runs a second time after
`from_pretrained()` has loaded the checkpoint (the post-load safety pass
that initializes tensors still left on the `meta` device). For
`NemotronHForCausalLM` this silently overwrites the checkpoint values of
`dt_bias` and `out_proj.weight` with fresh random draws, and the model then
emits repetitive stop-word streams like ` and and and and ,` for any input.
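For orientation, the pre-fix logic has roughly the following shape (a
condensed sketch reconstructed from the description above, not the verbatim
transformers source; `_sample_inv_dt` is a hypothetical helper and the
config attribute names are approximate):

import math
import torch
from torch import nn

def _sample_inv_dt(config):
    # hypothetical helper: draw dt log-uniformly in
    # [time_step_min, time_step_max], then apply the inverse of softplus
    # so that softplus(dt_bias) recovers dt
    lo, hi = math.log(config.time_step_min), math.log(config.time_step_max)
    dt = torch.exp(torch.rand(config.mamba_num_heads) * (hi - lo) + lo)
    return dt + torch.log(-torch.expm1(-dt))

def _init_weights(self, module):
    if isinstance(module, NemotronHMamba2Mixer):
        with torch.no_grad():
            module.dt_bias.copy_(_sample_inv_dt(self.config))  # unconditional overwrite
        module.dt_bias._no_reinit = True  # flag is set here ...
        nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
        with torch.no_grad():
            module.out_proj.weight /= math.sqrt(self.config.num_hidden_layers)  # also unconditional
    elif isinstance(module, nn.Linear) and module.bias is not None:
        if not getattr(module.bias, "_no_reinit", False):  # ... but only read here
            nn.init.zeros_(module.bias)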
Minimal repro with any Nemotron-H checkpoint:

from transformers import AutoConfig, AutoModelForCausalLM
from safetensors.torch import load_file
import json, pathlib

path = ".../NVIDIA-Nemotron-Cascade-2-30B-A3B-BF16"  # or Nano
cfg = AutoConfig.from_pretrained(path)
cfg._attn_implementation = "eager"
m = AutoModelForCausalLM.from_pretrained(path, config=cfg, torch_dtype="bfloat16")

# compare the dt_bias tensor on disk with the one that ends up in memory
idx = json.loads((pathlib.Path(path) / "model.safetensors.index.json").read_text())["weight_map"]
k = "backbone.layers.0.mixer.dt_bias"
on_disk = load_file(f"{path}/{idx[k]}")[k]
in_mem = m.backbone.layers[0].mixer.dt_bias
print((on_disk.float() - in_mem.float().cpu()).abs().max())  # ~26.8, should be 0
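With the patch applied, the same script should print 0: the in-memory
`dt_bias` is exactly the tensor stored in the checkpoint.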
This patch makes `_init_weights` honour `_no_reinit` on both `dt_bias` and
`out_proj.weight` (the only two params that were re-initialized
unconditionally), and sets `_no_reinit = True` on `out_proj.weight` after
the initial kaiming scaling so a second pass is a no-op. Ordinary
fresh-init training is unaffected; only the second invocation becomes
idempotent.
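Condensed, the guard inside the mixer branch of `_init_weights` reads as
follows (a sketch of the intent rather than the verbatim diff, reusing the
hypothetical `_sample_inv_dt` helper from above; this version was later
superseded by the `_is_hf_initialized` variant below):

if not getattr(module.dt_bias, "_no_reinit", False):
    with torch.no_grad():
        module.dt_bias.copy_(_sample_inv_dt(self.config))
    module.dt_bias._no_reinit = True
if not getattr(module.out_proj.weight, "_no_reinit", False):
    nn.init.kaiming_uniform_(module.out_proj.weight, a=math.sqrt(5))
    with torch.no_grad():
        module.out_proj.weight /= math.sqrt(self.config.num_hidden_layers)
    module.out_proj.weight._no_reinit = True  # a second pass is now a no-op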
Signed-off-by: Min Zhou <minzhou@virtueai.com>
* Switch to canonical _is_hf_initialized flag per review
Per @Rocketknight1's review: replace the ad-hoc `_no_reinit` flag with the
existing `_is_hf_initialized` flag that `from_pretrained` already sets on
checkpoint-loaded parameters. Guard each Mamba2 init target
(`A_log` / `D` / `dt_bias`) and the residual-scaled `out_proj.weight`
independently (sketched after this list), so parameters restored from a
checkpoint survive any subsequent `_init_weights` pass.
* Use _is_hf_initialized for nn.Linear.bias check too
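The final guard shape, sketched per parameter (init bodies condensed to
placeholders; `_needs_init` is a hypothetical helper named here only for
the sketch):

def _needs_init(param):
    # from_pretrained sets _is_hf_initialized=True on every parameter it
    # restored from the checkpoint; anything unmarked is genuinely fresh
    return not getattr(param, "_is_hf_initialized", False)

if isinstance(module, NemotronHMamba2Mixer):
    if _needs_init(module.A_log):
        ...  # default A_log init
    if _needs_init(module.D):
        ...  # default D init
    if _needs_init(module.dt_bias):
        ...  # inv_softplus(dt) copy, as sketched above
    if _needs_init(module.out_proj.weight):
        ...  # kaiming_uniform + 1/sqrt(n_layer) rescale
elif isinstance(module, nn.Linear) and module.bias is not None:
    if _needs_init(module.bias):
        nn.init.zeros_(module.bias)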
---------
Signed-off-by: Min Zhou <minzhou@virtueai.com>
2 files changed: 52 additions & 38 deletions (26 additions & 19 deletions in each file)
[diff hunks not captured in this extract; file 1's changes fall around lines 974-1030, file 2's around lines 327-383]