Mkhona/hyperball fix #158
Conversation
Signed-off-by: mikail <mkhona@nvidia.com>
Greptile Summary

This PR removes the auto-rescaling initialization logic from MuonHyperball and instead enforces a strict pre-condition: every parameter must already have norm equal to hyperball_radius. The only remaining finding is a minor hardening gap: hyperball_radius is not validated to be positive.

Confidence Score: 5/5. Safe to merge; all remaining findings are P2 style/hardening suggestions that do not affect correctness. The core logic change is correct: removing auto-rescaling and enforcing a strict pre-condition keeps the invariant clear. The only gap is the missing positive-value guard on hyperball_radius. No files require special attention.

Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant MuonHyperball
    participant OrthogonalizedOptimizer
    User->>MuonHyperball: __init__(hyperball_radius=R)
    MuonHyperball->>MuonHyperball: store hyperball_eps, hyperball_radius
    MuonHyperball->>OrthogonalizedOptimizer: super().__init__()
    MuonHyperball->>MuonHyperball: validate each p_norm == R (torch.isclose)
    note over MuonHyperball: Raises ValueError if p_norm=0 or p_norm≠R
    User->>OrthogonalizedOptimizer: step()
    OrthogonalizedOptimizer->>OrthogonalizedOptimizer: _init_group() — init momentum_buffer
    loop for each parameter p with grad
        OrthogonalizedOptimizer->>OrthogonalizedOptimizer: apply weight decay
        OrthogonalizedOptimizer->>OrthogonalizedOptimizer: update momentum buffer
        OrthogonalizedOptimizer->>OrthogonalizedOptimizer: orthogonalize (Newton-Schulz)
        OrthogonalizedOptimizer->>MuonHyperball: pre_weight_update_fn_inplace(p, update)
        note over MuonHyperball: Lazy-init hyperball_R tensor in state[p]<br/>Normalise update → R·normalize(update)
        OrthogonalizedOptimizer->>OrthogonalizedOptimizer: p -= lr · update
        OrthogonalizedOptimizer->>MuonHyperball: post_weight_update_fn_inplace(p)
        note over MuonHyperball: Re-project p onto hypersphere:<br/>p = R · normalize(p)
    end
```
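The re-projection step shown in the final note of the diagram can be sketched in plain Python. This is a minimal sketch with hypothetical names: `reproject_onto_hypersphere` stands in for `post_weight_update_fn_inplace`, and a hand-rolled Euclidean norm stands in for torch's normalize.

```python
import math

def reproject_onto_hypersphere(p, radius, eps=1e-8):
    # Hypothetical stand-in for post_weight_update_fn_inplace:
    # p <- radius * p / (||p|| + eps)
    norm = math.sqrt(sum(x * x for x in p))
    return [radius * x / (norm + eps) for x in p]

# A vector of norm 5 is pulled back onto the radius-2 hypersphere.
p = reproject_onto_hypersphere([3.0, 4.0], radius=2.0)
new_norm = math.sqrt(sum(x * x for x in p))
print(round(new_norm, 6))  # 2.0
```

The eps term only guards against division by zero; since the optimizer rejects zero-norm parameters up front, it barely perturbs the result in practice.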
Reviews (1): Last reviewed commit: "linting"
```diff
 def __init__(
     self,
     *args: Any,
     hyperball_eps: float = 1e-8,
-    hyperball_radius: float | None = None,
+    hyperball_radius: float,
     **kwargs: Any,
 ) -> None:
     self.hyperball_eps = hyperball_eps
     self.hyperball_radius = hyperball_radius
     super().__init__(*args, **kwargs)
```
No validation that hyperball_radius is positive

The parameter is now required but has no guard against non-positive values. If hyperball_radius <= 0 is passed, the p_norm == 0 check at line 77 won't help (parameter norms are always ≥ 0, and isclose with a negative target always fails), so the user receives a confusing norm-mismatch error instead of a clear message about an invalid radius.
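The failure mode can be demonstrated with a minimal pure-Python sketch, with math.isclose standing in as an analogy for torch.isclose:

```python
import math

radius = -1.0  # invalid, but currently accepted by __init__ unguarded
p_norm = 1.0   # parameter norms are always >= 0

# The zero-norm check passes...
zero_check_ok = p_norm != 0
# ...but no non-negative norm can ever be close to a negative radius,
# so validation fails with a norm-mismatch error rather than a clear
# "radius must be positive" message.
matches_radius = math.isclose(p_norm, radius)
print(zero_check_ok, matches_radius)  # True False
```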
Suggested change:

```diff
 def __init__(
     self,
     *args: Any,
     hyperball_eps: float = 1e-8,
     hyperball_radius: float,
     **kwargs: Any,
 ) -> None:
     self.hyperball_eps = hyperball_eps
     self.hyperball_radius = hyperball_radius
+    if hyperball_radius <= 0:
+        raise ValueError(
+            f"hyperball_radius must be positive, got {hyperball_radius}."
+        )
     super().__init__(*args, **kwargs)
```
Add a default value, negative check is optional.
```diff
-        self.state[p]["hyperball_R"] = R
+        if "hyperball_R" not in self.state[p]:
+            self.state[p]["hyperball_R"] = torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device)
+        R = self.state[p]["hyperball_R"]
```
Probably missed it in the previous review: hyperball_R and hyperball_radius are inconsistent; hyperball_radius should be used for both.
```diff
                 "MuonHyperball requires all parameters to have non-zero norm. "
                 "Found parameter with zero norm."
             )
+        if not torch.isclose(
```
Inconsistent with the p_norm == 0 check above; torch.equal could be used instead.
NOTE: this is potentially a regression rather than an optimization, since more host-device syncs can be triggered. In the worst case, a cudaMalloc.
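The exact-vs-tolerant trade-off behind this comment can be illustrated with plain floats; here math.isclose and == serve only as analogies for torch.isclose and torch.equal respectively:

```python
import math

p_norm = 0.1 + 0.2  # 0.30000000000000004 under IEEE-754 arithmetic
target = 0.3

# Exact comparison (analogous to torch.equal): sensitive to float error.
exact = p_norm == target
# Tolerance-based comparison (analogous to torch.isclose): robust to it.
close = math.isclose(p_norm, target)
print(exact, close)  # False True
```

Exact comparison is cheaper but brittle under floating-point arithmetic, which is why the radius check uses a tolerance while the zero-norm check can compare exactly.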
Addressed #155 (comment) and removed some of the initialization logic.