
Mkhona/hyperball fix#158

Open
mkhona-nvidia wants to merge 4 commits into NVIDIA-NeMo:main from mkhona-nvidia:mkhona/hyperball_fix

Conversation

@mkhona-nvidia
Contributor

Addressed #155 (comment) and removed some logic about initialization.

Signed-off-by: mikail <mkhona@nvidia.com>
@mkhona-nvidia mkhona-nvidia requested a review from skyw April 7, 2026 20:34
@copy-pr-bot

copy-pr-bot Bot commented Apr 7, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps
Contributor

greptile-apps Bot commented Apr 7, 2026

Greptile Summary

This PR removes the auto-rescaling initialization logic from MuonHyperball and makes hyperball_radius a required argument. Instead of silently rescaling parameters to the target radius at construction time, the optimizer now validates that all parameters already match the specified radius and raises a ValueError otherwise.

The only remaining finding is a minor hardening gap: hyperball_radius is not validated to be strictly positive, which could produce a confusing error message for invalid inputs.
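The strict pre-condition the summary describes can be sketched as follows. This is a hypothetical illustration, not the library's API; the function name and error messages are assumptions, with only the zero-norm check wording taken from the snippets quoted later in this thread.

```python
import torch

def validate_hyperball_radius(params, hyperball_radius):
    """Sketch of the pre-condition: every parameter must already have the
    requested norm, otherwise a ValueError is raised (no silent rescaling)."""
    for p in params:
        p_norm = torch.linalg.norm(p)
        if p_norm == 0:
            # A zero-norm parameter cannot sit on any hypersphere.
            raise ValueError(
                "MuonHyperball requires all parameters to have non-zero norm. "
                "Found parameter with zero norm."
            )
        if not torch.isclose(p_norm, torch.tensor(hyperball_radius, dtype=p_norm.dtype)):
            raise ValueError(
                f"Parameter norm {p_norm.item():.6f} does not match "
                f"hyperball_radius {hyperball_radius}."
            )
```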

Confidence Score: 5/5

Safe to merge; all remaining findings are P2 style/hardening suggestions that do not affect correctness.

The core logic change is correct — removing auto-rescaling and enforcing a strict pre-condition keeps the invariant clear. The only gap is a missing positive-value guard on hyperball_radius, which is a best-practice hardening rather than a present defect.

No files require special attention.

Important Files Changed

Filename Overview
emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py Removes auto-rescaling init logic; enforces strict norm-matching pre-condition; lazy-initialises hyperball_R tensor in optimizer state on first step.

Sequence Diagram

sequenceDiagram
    participant User
    participant MuonHyperball
    participant OrthogonalizedOptimizer

    User->>MuonHyperball: __init__(hyperball_radius=R)
    MuonHyperball->>MuonHyperball: store hyperball_eps, hyperball_radius
    MuonHyperball->>OrthogonalizedOptimizer: super().__init__()
    MuonHyperball->>MuonHyperball: validate each p_norm == R (torch.isclose)
    note over MuonHyperball: Raises ValueError if p_norm=0 or p_norm≠R

    User->>OrthogonalizedOptimizer: step()
    OrthogonalizedOptimizer->>OrthogonalizedOptimizer: _init_group() — init momentum_buffer
    loop for each parameter p with grad
        OrthogonalizedOptimizer->>OrthogonalizedOptimizer: apply weight decay
        OrthogonalizedOptimizer->>OrthogonalizedOptimizer: update momentum buffer
        OrthogonalizedOptimizer->>OrthogonalizedOptimizer: orthogonalize (Newton-Schulz)
        OrthogonalizedOptimizer->>MuonHyperball: pre_weight_update_fn_inplace(p, update)
        note over MuonHyperball: Lazy-init hyperball_R tensor in state[p]\nNormalise update → R·normalize(update)
        OrthogonalizedOptimizer->>OrthogonalizedOptimizer: p -= lr · update
        OrthogonalizedOptimizer->>MuonHyperball: post_weight_update_fn_inplace(p)
        note over MuonHyperball: Re-project p onto hypersphere:\np = R · normalize(p)
    end
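The post-update re-projection step in the diagram can be sketched as a standalone helper. The function name and the placement of the eps term are assumptions for illustration.

```python
import torch

def project_to_hypersphere(p, R, eps=1e-8):
    # Re-project p onto the hypersphere of radius R: p <- R * p / (||p|| + eps).
    # This mirrors the post-weight-update step shown in the diagram above;
    # eps guards against division by zero for a degenerate parameter.
    with torch.no_grad():
        p.mul_(R / (torch.linalg.norm(p) + eps))
    return p
```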

Reviews (1): Last reviewed commit: "linting"

Comment on lines 62 to 71
    def __init__(
        self,
        *args: Any,
        hyperball_eps: float = 1e-8,
-       hyperball_radius: float | None = None,
+       hyperball_radius: float,
        **kwargs: Any,
    ) -> None:
        self.hyperball_eps = hyperball_eps
        self.hyperball_radius = hyperball_radius
        super().__init__(*args, **kwargs)
Contributor


P2 No validation that hyperball_radius is positive

The parameter is now required but has no guard against non-positive values. If hyperball_radius <= 0 is passed, the p_norm == 0 check at line 77 won't help (parameter norms are always ≥ 0 and isclose with a negative target will always fail), so the user receives a confusing norm-mismatch error instead of a clear message about an invalid radius.

Suggested change

        self.hyperball_eps = hyperball_eps
        self.hyperball_radius = hyperball_radius
        if hyperball_radius <= 0:
            raise ValueError(
                f"hyperball_radius must be positive, got {hyperball_radius}."
            )
        super().__init__(*args, **kwargs)
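The confusing failure mode this finding describes can be demonstrated directly: a parameter norm is always non-negative, so isclose against a negative target can never succeed, and the user gets a norm-mismatch error rather than a clear "radius must be positive" message.

```python
import torch

# With a negative target radius, the strict norm check can never pass:
# torch.linalg.norm is always >= 0, so isclose against a negative value
# is False for every parameter, regardless of its actual norm.
p = torch.ones(3, 3)                      # Frobenius norm = 3.0
p_norm = torch.linalg.norm(p)
matches = torch.isclose(p_norm, torch.tensor(-1.0))
```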

Contributor

Add a default value; the negative check is optional.

Contributor

@skyw skyw left a comment

The code itself is fine, but it doesn't serve the purpose of fixing #155.


-   self.state[p]["hyperball_R"] = R
+   if "hyperball_R" not in self.state[p]:
+       self.state[p]["hyperball_R"] = torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device)
+   R = self.state[p]["hyperball_R"]
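A runnable sketch of the lazy-init pattern in the snippet above, with a plain dict standing in for the optimizer's self.state; the helper name is illustrative.

```python
import torch

def get_radius_tensor(state, p, hyperball_radius):
    # Lazily cache the radius as a tensor matching the parameter's dtype and
    # device on first use, so later steps reuse it without re-allocating.
    if p not in state:
        state[p] = {}
    if "hyperball_R" not in state[p]:
        state[p]["hyperball_R"] = torch.tensor(
            hyperball_radius, dtype=p.dtype, device=p.device
        )
    return state[p]["hyperball_R"]
```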
Contributor

Probably missed it in the previous review: hyperball_R and hyperball_radius are inconsistent; hyperball_radius should be used for both.

"MuonHyperball requires all parameters to have non-zero norm. "
"Found parameter with zero norm."
)
if not torch.isclose(
Contributor

@skyw skyw Apr 7, 2026


To be consistent with the p_norm == 0 check above, torch.equal can be used.

NOTE: this is potentially a regression rather than an optimization, since more host-device syncs can be triggered; in the worst case, a cudaMalloc.
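A small illustration of the two checks being weighed here (CPU tensors for portability; the sync cost the comment flags only arises when branching on a CUDA tensor's result):

```python
import torch

# torch.equal: exact equality of shape and values, no tolerance.
# torch.isclose: element-wise comparison within rtol/atol.
# Converting either result from a GPU tensor into a Python bool forces a
# host-device synchronization, which is the potential regression noted above.
p_norm = torch.linalg.norm(torch.ones(4))   # exactly 2.0 in float32
target = torch.tensor(2.0)

exact = torch.equal(p_norm, target)
approx = bool(torch.isclose(p_norm, target))
```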
