Summary
Request to investigate adding critical sharpness (λ_c) as a loggable training diagnostic.
Critical sharpness measures loss landscape curvature using only forward passes — no Hessian computation — making it practical at scale. Validated on OLMo-2 up to 77B parameters.
Motivation
Training instabilities (loss spikes, divergence) are common in large-scale runs but difficult to diagnose. Standard metrics (loss, gradient norm) are lagging indicators. Critical sharpness provides a leading signal: it tracks how close the current learning rate is to the stability threshold, and progressive sharpening (curvature increasing over training) predicts instability before it manifests as a loss spike.
Key findings at scale (OLMo-2, 7B–77B):
- Progressive sharpening persists throughout pre-training (4T tokens) and mid-training (50B tokens) — first demonstration at this scale
- Critical sharpness tracks pre-conditioned Hessian sharpness, validating it as a cheap proxy
- Relative critical sharpness can guide data-mixing decisions — it identifies the mixing ratio at which per-task curvatures balance, which permits the largest stable learning rate, without expensive ablations
Relevance to Emerging Optimizers:
- Enables apples-to-apples comparison of how different optimizers (Muon, Shampoo, SOAP, AdamW) interact with loss landscape curvature
- Guides learning rate selection relative to the stability edge for any optimizer
Requested Feature
A periodic diagnostic utility that:
- At a configurable step interval, evaluates ~5-6 trial learning rates along the current update direction, using one extra forward pass per trial
- Computes and logs critical sharpness λ_c (and optionally relative critical sharpness for multi-dataset runs)
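To make the requested utility concrete, here is a minimal sketch of the forward-pass-only estimation loop. It assumes λ_c is recovered as 2/η_c, where η_c is the largest probed learning rate whose step along the update direction does not increase the loss; the function and parameter names (`critical_sharpness`, `loss_fn`, `update_dir`) are hypothetical, not an existing API, and the probe grid and stopping rule are illustrative choices:

```python
import numpy as np

def critical_sharpness(loss_fn, params, update_dir, lrs):
    """Estimate lambda_c = 2 / eta_c, where eta_c is the largest probed
    learning rate whose trial step along `update_dir` does not increase
    the loss. Costs one extra forward pass per probed learning rate."""
    base = loss_fn(params)
    eta_c = None
    for lr in sorted(lrs):                      # probe smallest lr first
        if loss_fn(params - lr * update_dir) <= base:
            eta_c = lr                          # still stable at this lr
        else:
            break                               # crossed the stability edge
    return 2.0 / eta_c if eta_c else float("inf")

# Toy check on a 1-D quadratic with known curvature a:
# lambda_c should come out close to a (up to probe-grid resolution).
a = 4.0
loss = lambda x: 0.5 * a * float(x ** 2)
x0 = np.array(1.0)
grad = a * x0                        # analytic gradient as update direction
lrs = np.geomspace(1e-2, 1.0, 6)     # ~6 probe learning rates
lam = critical_sharpness(loss, x0, grad, lrs)
```

With a geometric probe grid the estimate overshoots the true curvature by at most one grid ratio; a production version would log λ_c at the configured interval and could refine η_c with a short bisection if tighter resolution is needed.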