Mkhona/hyperball fix #158
base: main
```diff
@@ -35,8 +35,8 @@ class MuonHyperball(muon.Muon):

         W_{t+1} = R \\cdot \\text{normalize}(W_t - \\text{lr} \\cdot R \\cdot \\text{normalize}(\\text{update}))

-    where :math:`R` is the Frobenius norm of :math:`W_t` (or a user-specified radius). This keeps
-    the weight matrix at constant scale while updating.
+    where :math:`R` is the user-specified Frobenius norm. This keeps the weight matrix at
+    constant scale while updating.

     Warning:
         This optimizer is experimental and may change in future versions.
```
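For reference, the update rule from the docstring can be sketched as a standalone function. This is a hypothetical illustration of the math only, not the optimizer's actual code path; `hyperball_step` and its signature are invented for this sketch.

```python
import torch

def hyperball_step(W: torch.Tensor, update: torch.Tensor,
                   lr: float, R: float, eps: float = 1e-8) -> torch.Tensor:
    """One hyperball update: W_{t+1} = R * normalize(W - lr * R * normalize(update)).

    All norms are Frobenius norms; `eps` guards against division by zero.
    """
    u = update * (R / update.norm().clamp_min(eps))    # R * normalize(update)
    W_new = W - lr * u
    return W_new * (R / W_new.norm().clamp_min(eps))   # project back onto radius R

W = torch.randn(4, 4)
R = W.norm().item()
W_next = hyperball_step(W, torch.randn(4, 4), lr=0.1, R=R)
```

Whatever the step size, the result lands back on the sphere of Frobenius radius `R`, which is the "constant scale" property the docstring describes.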
```diff
@@ -49,52 +49,60 @@ class MuonHyperball(muon.Muon):
         *args: Arguments passed to Muon.
         hyperball_eps: Epsilon for numerical stability in normalization.
             Default: ``1e-8``.
-        hyperball_radius: Fixed radius for the hyperball. If ``None`` (default),
-            uses each parameter's initial Frobenius norm as its radius. If specified, all
-            parameters will be rescaled to have this radius at initialization.
+        hyperball_radius: Fixed radius for the hyperball. All parameters must
+            already have this Frobenius norm at construction time.
         **kwargs: Keyword arguments passed to Muon.

+    Raises:
+        ValueError: If any parameter has zero norm, or if a parameter's
+            Frobenius norm does not match ``hyperball_radius``.
+
     """

     def __init__(
         self,
         *args: Any,
         hyperball_eps: float = 1e-8,
-        hyperball_radius: float | None = None,
+        hyperball_radius: float,
         **kwargs: Any,
     ) -> None:
         self.hyperball_eps = hyperball_eps
         self.hyperball_radius = hyperball_radius
         super().__init__(*args, **kwargs)

         # Validate and optionally rescale parameters based on hyperball_radius.
         with torch.no_grad():
             for group in self.param_groups:
                 for p in group["params"]:
                     p_norm = p.norm()
                     # Validate that parameter has non-zero norm.
-                    if p_norm.item() == 0:
+                    if p_norm == 0:
                         raise ValueError(
+                            "MuonHyperball requires all parameters to have non-zero norm. "
+                            "Found parameter with zero norm."
+                        )
+                    if not torch.isclose(
```
Contributor:
Inconsistent with the NOTE above: this is potentially a regression rather than an optimization, since more host-device syncs can be triggered. In the worst case, a cudaMalloc.
```diff
+                        p_norm,
+                        torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device),
+                        rtol=1e-5,
+                        atol=1e-8,
+                    ):
+                        raise ValueError(
-                            "MuonHyperball requires all parameters to have non-zero norm. Found parameter with zero norm."
+                            f"hyperball_radius={self.hyperball_radius} was specified but a parameter "
+                            f"has Frobenius norm {p_norm.item()}. Rescale your model parameters to the "
+                            f"desired radius before constructing the optimizer."
                         )
-                    # Rescale parameter to have the specified radius if provided.
-                    if self.hyperball_radius is not None:
-                        p.mul_(self.hyperball_radius / p_norm.clamp_min(self.hyperball_eps))
```
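Since the constructor no longer rescales, callers now have to put their parameters on the hyperball themselves before building the optimizer. A minimal sketch of such a pre-pass; the `rescale_to_radius` helper is hypothetical and not part of this PR:

```python
import torch

def rescale_to_radius(params, radius: float, eps: float = 1e-8) -> None:
    """Rescale each parameter in-place so its Frobenius norm equals `radius`.

    Run this before constructing the optimizer, which (after this PR) only
    validates the norms instead of rescaling them for you.
    """
    with torch.no_grad():
        for p in params:
            p.mul_(radius / p.norm().clamp_min(eps))

params = [torch.randn(3, 3), torch.randn(2, 5)]
rescale_to_radius(params, radius=1.0)
```

After this pre-pass every parameter passes the `torch.isclose(p_norm, radius)` check in `__init__`.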
```diff
     @override
     def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
-        """Store the original weight norm and normalize the update using Frobenius norm.
+        """Normalize the update using Frobenius norm, scaled by R.

         Args:
             p: The parameter tensor.
             update: The orthogonalized gradient tensor.
         """
-        # Use user-specified radius or compute R = ||W_t||_F (Frobenius norm)
-        R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item()
-        self.state[p]["hyperball_R"] = R
+        if "hyperball_R" not in self.state[p]:
+            self.state[p]["hyperball_R"] = torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device)
+        R = self.state[p]["hyperball_R"]
```
Contributor:
Probably missed it in a previous review,
```diff
         # Normalize the update in-place and scale by R.
         # This modifies update to be: R * normalize(update) using Frobenius norm.
         update_norm = update.norm().clamp_min(self.hyperball_eps)
         update.mul_(R / update_norm)
```
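The in-place scaling above amounts to `update <- R * update / max(||update||_F, eps)`; keeping `R` as a tensor on the parameter's device (as the new code does) avoids a host-device sync on every step. A small self-contained sketch of just that operation, with illustrative values:

```python
import torch

eps = 1e-8
R = torch.tensor(2.0)               # cached radius kept as a tensor (no .item() sync)
update = torch.randn(4, 4)
direction = update / update.norm()  # saved only to check direction is preserved

# In-place: update <- R * update / max(||update||_F, eps)
update_norm = update.norm().clamp_min(eps)
update.mul_(R / update_norm)
```

The update keeps its direction but its Frobenius norm becomes exactly `R` (up to floating-point error).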
Reviewer:
`hyperball_radius` is positive — the parameter is now required but has no guard against non-positive values. If `hyperball_radius <= 0` is passed, the `p_norm == 0` check at line 77 won't help (parameter norms are always ≥ 0, and `isclose` against a negative target will always fail), so the user receives a confusing norm-mismatch error instead of a clear message about an invalid radius.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a default value, negative check is optional.