46 changes: 27 additions & 19 deletions emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py
@@ -35,8 +35,8 @@ class MuonHyperball(muon.Muon):

        W_{t+1} = R \\cdot \\text{normalize}(W_t - \\text{lr} \\cdot R \\cdot \\text{normalize}(\\text{update}))

-   where :math:`R` is the Frobenius norm of :math:`W_t` (or a user-specified radius). This keeps
-   the weight matrix at constant scale while updating.
+   where :math:`R` is the user-specified Frobenius norm. This keeps the weight matrix at
+   constant scale while updating.

    Warning:
        This optimizer is experimental and may change in future versions.
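For reference, the update rule above amounts to a normalize-step-renormalize cycle. A minimal standalone sketch (hyperball_step is a hypothetical helper, not part of this module):

import torch

def hyperball_step(W: torch.Tensor, update: torch.Tensor, lr: float, R: float, eps: float = 1e-8) -> torch.Tensor:
    # Scale the (already orthogonalized) update to Frobenius norm R.
    step = update * (R / update.norm().clamp_min(eps))
    # Take the step, then project the result back onto the radius-R sphere.
    W_new = W - lr * step
    return W_new * (R / W_new.norm().clamp_min(eps))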
@@ -49,52 +49,60 @@ class MuonHyperball(muon.Muon):
        *args: Arguments passed to Muon.
        hyperball_eps: Epsilon for numerical stability in normalization.
            Default: ``1e-8``.
-       hyperball_radius: Fixed radius for the hyperball. If ``None`` (default),
-           uses each parameter's initial Frobenius norm as its radius. If specified, all
-           parameters will be rescaled to have this radius at initialization.
+       hyperball_radius: Fixed radius for the hyperball. All parameters must
+           already have this Frobenius norm at construction time.
        **kwargs: Keyword arguments passed to Muon.

    Raises:
        ValueError: If any parameter has zero norm, or if a parameter's
            Frobenius norm does not match ``hyperball_radius``.

    """

    def __init__(
        self,
        *args: Any,
        hyperball_eps: float = 1e-8,
-       hyperball_radius: float | None = None,
+       hyperball_radius: float,
        **kwargs: Any,
    ) -> None:
        self.hyperball_eps = hyperball_eps
        self.hyperball_radius = hyperball_radius
        super().__init__(*args, **kwargs)
Contributor (comment on lines 62 to 71):

P2 No validation that hyperball_radius is positive

The parameter is now required but has no guard against non-positive values. If hyperball_radius <= 0 is passed, the p_norm == 0 check at line 77 won't help (parameter norms are always ≥ 0 and isclose with a negative target will always fail), so the user receives a confusing norm-mismatch error instead of a clear message about an invalid radius.

Suggested change

    def __init__(
        self,
        *args: Any,
        hyperball_eps: float = 1e-8,
        hyperball_radius: float,
        **kwargs: Any,
    ) -> None:
        self.hyperball_eps = hyperball_eps
        self.hyperball_radius = hyperball_radius
+       if hyperball_radius <= 0:
+           raise ValueError(
+               f"hyperball_radius must be positive, got {hyperball_radius}."
+           )
        super().__init__(*args, **kwargs)

Contributor:

Add a default value; the negative check is optional.


        # Validate parameters against hyperball_radius.
        with torch.no_grad():
            for group in self.param_groups:
                for p in group["params"]:
                    p_norm = p.norm()
                    # Validate that the parameter has non-zero norm.
-                   if p_norm.item() == 0:
+                   if p_norm == 0:
                        raise ValueError(
                            "MuonHyperball requires all parameters to have non-zero norm. "
                            "Found parameter with zero norm."
                        )
                    if not torch.isclose(
Contributor @skyw, Apr 7, 2026:

To be consistent with the above p_norm == 0 check, this can use torch.equal.

NOTE: this is potentially a regression rather than an optimization, as more host-device syncs can be triggered; in the worst case, a cudaMalloc.
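For illustration, one way to reduce per-parameter host-device syncs is a single batched check over all norms; a sketch assuming all parameters share a device and dtype, with validate_radii as a hypothetical helper:

import torch

def validate_radii(params: list[torch.Tensor], radius: float) -> None:
    # All norms stay on device; the only host-device sync is the bool() call.
    norms = torch.stack([p.norm() for p in params])
    ok = torch.isclose(norms, torch.full_like(norms, radius), rtol=1e-5, atol=1e-8)
    if not bool(ok.all()):
        bad = (~ok).nonzero().flatten().tolist()  # error path only
        raise ValueError(f"Parameters at indices {bad} do not have Frobenius norm {radius}.")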

                        p_norm,
                        torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device),
                        rtol=1e-5,
                        atol=1e-8,
                    ):
                        raise ValueError(
-                           "MuonHyperball requires all parameters to have non-zero norm. Found parameter with zero norm."
+                           f"hyperball_radius={self.hyperball_radius} was specified but a parameter "
+                           f"has Frobenius norm {p_norm.item()}. Rescale your model parameters to the "
+                           f"desired radius before constructing the optimizer."
                        )
-                   # Rescale parameter to have the specified radius if provided.
-                   if self.hyperball_radius is not None:
-                       p.mul_(self.hyperball_radius / p_norm.clamp_min(self.hyperball_eps))

    @override
    def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
-       """Store the original weight norm and normalize the update using Frobenius norm.
+       """Normalize the update using Frobenius norm, scaled by R.

        Args:
            p: The parameter tensor.
            update: The orthogonalized gradient tensor.
        """
-       # Use user-specified radius or compute R = ||W_t||_F (Frobenius norm)
-       R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item()
-       self.state[p]["hyperball_R"] = R
+       if "hyperball_R" not in self.state[p]:
+           self.state[p]["hyperball_R"] = torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device)
+       R = self.state[p]["hyperball_R"]
Contributor:

Probably missed this in the previous review: hyperball_R and hyperball_radius are inconsistent; hyperball_radius should be used for both.
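A possible fix in the suggested-change style used above (a sketch that only renames the state key for consistency):

-       if "hyperball_R" not in self.state[p]:
-           self.state[p]["hyperball_R"] = torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device)
-       R = self.state[p]["hyperball_R"]
+       if "hyperball_radius" not in self.state[p]:
+           self.state[p]["hyperball_radius"] = torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device)
+       R = self.state[p]["hyperball_radius"]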


        # Normalize the update in-place and scale by R.
        # This modifies update to be: R * normalize(update) using Frobenius norm.
        update_norm = update.norm().clamp_min(self.hyperball_eps)
        update.mul_(R / update_norm)
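For context, a minimal usage sketch under the new contract: parameters must be rescaled to hyperball_radius before the optimizer is constructed, since construction now validates rather than rescales. The model, the lr value, and the exact Muon pass-through arguments are assumptions for illustration:

import torch
from emerging_optimizers.orthogonalized_optimizers.muon_hyperball import MuonHyperball

radius = 1.0
model = torch.nn.Linear(128, 128, bias=False)
with torch.no_grad():
    for p in model.parameters():
        # Pre-rescale each parameter onto the radius-R sphere.
        p.mul_(radius / p.norm().clamp_min(1e-8))

# lr and any other Muon keyword arguments here are illustrative assumptions.
opt = MuonHyperball(model.parameters(), lr=0.02, hyperball_radius=radius)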
