
Commit eed95d8

[nemotron_h] respect _no_reinit flag on dt_bias and out_proj.weight (#45591)
* [nemotron_h] respect _no_reinit flag on dt_bias and out_proj.weight

_init_weights() on `NemotronHPreTrainedModel` unconditionally overwrites `dt_bias` (random `inv_softplus(dt)`) and `out_proj.weight` (kaiming_uniform scaled by 1/sqrt(n_layer)) every time it is invoked on a Mamba block. It sets `module.dt_bias._no_reinit = True` after the copy, but the flag is never checked by either code path (only the Linear-bias branch reads it).

On transformers>=5.0, `_init_weights` is triggered a second time after `from_pretrained()` has loaded the checkpoint (the post-load safety pass that initializes tensors still left on `meta`). For `NemotronHForCausalLM` that silently overwrites the checkpoint values of `dt_bias` and `out_proj.weight` with fresh random draws. The model then emits repetitive stop-word streams like ` and and and and ,` for any input.

Minimal repro with any Nemotron-H checkpoint:

    from transformers import AutoConfig, AutoModelForCausalLM
    from safetensors.torch import load_file
    import json, pathlib

    path = ".../NVIDIA-Nemotron-Cascade-2-30B-A3B-BF16"  # or Nano
    cfg = AutoConfig.from_pretrained(path); cfg._attn_implementation = 'eager'
    m = AutoModelForCausalLM.from_pretrained(path, config=cfg, torch_dtype='bfloat16')
    idx = json.loads((pathlib.Path(path) / 'model.safetensors.index.json').read_text())['weight_map']
    k = 'backbone.layers.0.mixer.dt_bias'
    on_disk = load_file(f'{path}/{idx[k]}')[k]
    in_mem = m.backbone.layers[0].mixer.dt_bias
    print((on_disk.float() - in_mem.float().cpu()).abs().max())  # ~26.8

This patch makes `_init_weights` honour `_no_reinit` on both `dt_bias` and `out_proj.weight` (the only two params that are re-initialised unconditionally), and sets `_no_reinit = True` on `out_proj.weight` after the initial kaiming scale so a second pass is a no-op. Ordinary fresh-init training is unaffected; only the second invocation becomes idempotent.

Signed-off-by: Min Zhou <minzhou@virtueai.com>

* Switch to canonical _is_hf_initialized flag per review

Per @Rocketknight1's review: replace the ad-hoc `_no_reinit` flag with the existing `_is_hf_initialized` flag that `from_pretrained` already sets on checkpoint-loaded parameters. Guard each Mamba2 init target (A_log / D / dt_bias) and the residual-scaled `out_proj.weight` independently, so parameters restored from a checkpoint survive any subsequent `_init_weights` pass.

* Use _is_hf_initialized for nn.Linear.bias check too

---------

Signed-off-by: Min Zhou <minzhou@virtueai.com>
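A minimal, self-contained sketch of the guard semantics the final fix relies on (the parameter name and values below are made up for illustration; only the `_is_hf_initialized` attribute and the `getattr` guard mirror the patched code):

    import torch
    import torch.nn as nn

    # Pretend this parameter was just loaded from a checkpoint;
    # from_pretrained marks such parameters with `_is_hf_initialized = True`.
    dt_bias = nn.Parameter(torch.tensor([0.5, 1.0, 1.5]))
    dt_bias._is_hf_initialized = True

    def guarded_init(param: nn.Parameter) -> None:
        # Same guard pattern as the patched _init_weights:
        # only touch params that were NOT restored from a checkpoint.
        if not getattr(param, "_is_hf_initialized", False):
            with torch.no_grad():
                param.copy_(torch.randn_like(param))

    before = dt_bias.detach().clone()
    guarded_init(dt_bias)   # the post-load "safety pass" is now a no-op
    assert torch.equal(before, dt_bias)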
1 parent 807d9d7 commit eed95d8

2 files changed: 52 additions & 38 deletions

src/transformers/models/nemotron_h/modeling_nemotron_h.py

Lines changed: 26 additions & 19 deletions
@@ -974,22 +974,27 @@ def _init_weights(self, module):
         """Initialize the weights."""
         super()._init_weights(module)
         if isinstance(module, NemotronHMamba2Mixer):
-            # Initialize A_log and D parameters
-            A = torch.arange(1, self.config.mamba_num_heads + 1)
-            init.copy_(module.A_log, torch.log(A))
-            init.ones_(module.D)
-
-            dt = torch.exp(
-                torch.rand(self.config.mamba_num_heads)
-                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
-                + math.log(self.config.time_step_min)
-            ).clamp(min=self.config.time_step_floor)
-
-            # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
-            inv_dt = dt + torch.log(-torch.expm1(-dt))
-            with torch.no_grad():
-                init.copy_(module.dt_bias, inv_dt)
-            module.dt_bias._no_reinit = True
+            # Only re-initialise params that were NOT loaded from a checkpoint.
+            # `_is_hf_initialized` is set by `from_pretrained` on each loaded
+            # parameter; without this guard a post-load safety pass of
+            # `_init_weights` would overwrite checkpoint values of
+            # A_log / D / dt_bias with fresh random draws.
+            if not getattr(module.A_log, "_is_hf_initialized", False):
+                A = torch.arange(1, self.config.mamba_num_heads + 1)
+                init.copy_(module.A_log, torch.log(A))
+            if not getattr(module.D, "_is_hf_initialized", False):
+                init.ones_(module.D)
+            if not getattr(module.dt_bias, "_is_hf_initialized", False):
+                dt = torch.exp(
+                    torch.rand(self.config.mamba_num_heads)
+                    * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
+                    + math.log(self.config.time_step_min)
+                ).clamp(min=self.config.time_step_floor)
+
+                # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+                inv_dt = dt + torch.log(-torch.expm1(-dt))
+                with torch.no_grad():
+                    init.copy_(module.dt_bias, inv_dt)
         elif isinstance(module, NemotronHTopkRouter):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             init.zeros_(module.e_score_correction_bias)
@@ -1000,7 +1005,7 @@ def _init_weights(self, module):

         if isinstance(module, nn.Linear):
             if module.bias is not None:
-                if not getattr(module.bias, "_no_reinit", False):
+                if not getattr(module.bias, "_is_hf_initialized", False):
                     init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
             init.normal_(module.weight, std=self.config.initializer_range)
@@ -1014,10 +1019,12 @@ def _init_weights(self, module):
             # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
             for name, p in module.named_parameters():
                 if name == "out_proj.weight":
+                    # Skip checkpoint-loaded weights so a post-load safety
+                    # pass of `_init_weights` doesn't silently overwrite them.
+                    if getattr(p, "_is_hf_initialized", False):
+                        continue
                     # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                     # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
-                    # We need to reinit p since this code could be called multiple times
-                    # Having just p *= scale would repeatedly scale it down
                     init.kaiming_uniform_(p, a=math.sqrt(5))
                     with torch.no_grad():
                         p_new = p / math.sqrt(self.config.num_hidden_layers)

src/transformers/models/nemotron_h/modular_nemotron_h.py

Lines changed: 26 additions & 19 deletions
@@ -327,22 +327,27 @@ def _init_weights(self, module):
         """Initialize the weights."""
         super()._init_weights(module)
         if isinstance(module, NemotronHMamba2Mixer):
-            # Initialize A_log and D parameters
-            A = torch.arange(1, self.config.mamba_num_heads + 1)
-            init.copy_(module.A_log, torch.log(A))
-            init.ones_(module.D)
-
-            dt = torch.exp(
-                torch.rand(self.config.mamba_num_heads)
-                * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
-                + math.log(self.config.time_step_min)
-            ).clamp(min=self.config.time_step_floor)
-
-            # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
-            inv_dt = dt + torch.log(-torch.expm1(-dt))
-            with torch.no_grad():
-                init.copy_(module.dt_bias, inv_dt)
-            module.dt_bias._no_reinit = True
+            # Only re-initialise params that were NOT loaded from a checkpoint.
+            # `_is_hf_initialized` is set by `from_pretrained` on each loaded
+            # parameter; without this guard a post-load safety pass of
+            # `_init_weights` would overwrite checkpoint values of
+            # A_log / D / dt_bias with fresh random draws.
+            if not getattr(module.A_log, "_is_hf_initialized", False):
+                A = torch.arange(1, self.config.mamba_num_heads + 1)
+                init.copy_(module.A_log, torch.log(A))
+            if not getattr(module.D, "_is_hf_initialized", False):
+                init.ones_(module.D)
+            if not getattr(module.dt_bias, "_is_hf_initialized", False):
+                dt = torch.exp(
+                    torch.rand(self.config.mamba_num_heads)
+                    * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
+                    + math.log(self.config.time_step_min)
+                ).clamp(min=self.config.time_step_floor)
+
+                # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+                inv_dt = dt + torch.log(-torch.expm1(-dt))
+                with torch.no_grad():
+                    init.copy_(module.dt_bias, inv_dt)
         elif isinstance(module, NemotronHTopkRouter):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             init.zeros_(module.e_score_correction_bias)
@@ -353,7 +358,7 @@ def _init_weights(self, module):

         if isinstance(module, nn.Linear):
             if module.bias is not None:
-                if not getattr(module.bias, "_no_reinit", False):
+                if not getattr(module.bias, "_is_hf_initialized", False):
                     init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
             init.normal_(module.weight, std=self.config.initializer_range)
@@ -367,10 +372,12 @@ def _init_weights(self, module):
             # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
             for name, p in module.named_parameters():
                 if name == "out_proj.weight":
+                    # Skip checkpoint-loaded weights so a post-load safety
+                    # pass of `_init_weights` doesn't silently overwrite them.
+                    if getattr(p, "_is_hf_initialized", False):
+                        continue
                     # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                     # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
-                    # We need to reinit p since this code could be called multiple times
-                    # Having just p *= scale would repeatedly scale it down
                     init.kaiming_uniform_(p, a=math.sqrt(5))
                     with torch.no_grad():
                         p_new = p / math.sqrt(self.config.num_hidden_layers)
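Aside on the `inv_dt = dt + torch.log(-torch.expm1(-dt))` line kept in both files: it is the closed-form inverse of softplus, i.e. inv_dt = log(exp(dt) - 1), so applying softplus to the stored bias recovers the sampled dt. A standalone sanity check (the 0.001 / 0.1 bounds below are placeholder values standing in for `config.time_step_min` / `config.time_step_max`, not values taken from this commit):

    import math
    import torch
    import torch.nn.functional as F

    # Sample dt the same way _init_weights does, with placeholder bounds.
    dt = torch.exp(
        torch.rand(8) * (math.log(0.1) - math.log(0.001)) + math.log(0.001)
    ).clamp(min=1e-4)

    inv_dt = dt + torch.log(-torch.expm1(-dt))   # == log(exp(dt) - 1)
    assert torch.allclose(F.softplus(inv_dt), dt, atol=1e-6)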
