
fix: expose alpha param through LigerFusedLinearDPO public API#1194

Open
micdoh wants to merge 2 commits into linkedin:main from micdoh:expose_dpo_alpha

Conversation


@micdoh micdoh commented Apr 16, 2026

Summary

  • `LigerFusedLinearDPOFunction.forward()` was missing the `alpha` parameter entirely, so the NLL loss scaling weight silently defaulted to `1.0` no matter what callers passed — the parameter exists in `LigerFusedLinearPreferenceBase` but was never plumbed through `LigerFusedLinearDPOFunction`.
  • `LigerFusedLinearDPOLoss` (the `nn.Module` wrapper) similarly had no `alpha` argument.
  • Added one extra `None` to `backward()` to match the new positional count.
  • Fixed positional-arg ordering in the two existing functional tests that called `LigerFusedLinearDPOFunction.apply(...)` directly.
  • Added `test_alpha_scales_nll_loss` to verify `alpha` actually reaches the loss computation.

Changes

`src/liger_kernel/chunked_loss/dpo_loss.py`

  • Add `alpha: float = 1.0` to `LigerFusedLinearDPOFunction.forward()` and forward it to `super().forward()`
  • Add one `None` to `backward()` return (was 11, now 12 — one per non-tensor arg)
  • Add `alpha: float = 1.0` to `LigerFusedLinearDPOLoss.__init__()` and a `self.alpha` assignment
  • Pass `self.alpha` to `LigerFusedLinearDPOFunction.apply()` in `LigerFusedLinearDPOLoss.forward()`
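The plumbing fix above can be sketched as follows. This is a hypothetical, heavily simplified stand-in: the real classes subclass `torch.autograd.Function` inside liger_kernel, and the placeholder loss values (`pref`, `nll`) and the reduced argument list are illustrative only.

```python
# Minimal sketch of the alpha plumbing (hypothetical simplification;
# the real classes live in src/liger_kernel/chunked_loss/dpo_loss.py).

class LigerFusedLinearPreferenceBase:
    @staticmethod
    def forward(ctx, _input, weight, target, bias=None,
                beta=0.1, alpha=1.0, compute_nll_loss=True):
        # alpha scales the NLL term that is added to the preference loss.
        pref, nll = 3.0, 2.0  # stand-ins for the real chunked losses
        return pref + (alpha * nll if compute_nll_loss else 0.0)


class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
    @staticmethod
    def forward(ctx, _input, weight, target, bias=None,
                beta=0.1, alpha=1.0, compute_nll_loss=True):
        # Before the fix, alpha was missing from this signature, so the
        # base class always saw its default of 1.0.
        return LigerFusedLinearPreferenceBase.forward(
            ctx, _input, weight, target, bias, beta, alpha, compute_nll_loss
        )

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_weight = grad_bias = None  # stand-in gradients
        # One return slot per forward argument after ctx; non-tensor args
        # get None, and the new alpha slot adds one more None
        # (11 -> 12 in the real code).
        return grad_input, grad_weight, None, grad_bias, None, None, None
```

The key point is that `forward()` now accepts `alpha` and passes it through positionally, and `backward()` grows by exactly one `None` to stay aligned with the new argument count.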

`test/chunked_loss/test_dpo_loss.py`

  • Insert `1.0` (alpha) in the correct positional slot in `test_correctness_functional` and `test_correctness_functional_apo_loss_types`
  • Add `test_alpha_scales_nll_loss` regression test
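The shape of the regression test is roughly the following. This is a hedged sketch, not the actual test code: `dpo_loss` here is a hypothetical stand-in for the real fused loss, and the constants are placeholders.

```python
# Hypothetical shape of the alpha regression test: changing alpha must
# change the loss when the NLL term is active, and must be a no-op when
# compute_nll_loss=False.

def dpo_loss(alpha=1.0, compute_nll_loss=True):
    pref, nll = 3.0, 2.0  # stand-ins for the real computed losses
    return pref + (alpha * nll if compute_nll_loss else 0.0)

def test_alpha_scales_nll_loss():
    # alpha reaches the loss computation: different alpha, different loss.
    assert dpo_loss(alpha=1.0) != dpo_loss(alpha=0.5)
    # With the NLL term disabled, alpha has no effect on the loss.
    assert dpo_loss(alpha=1.0, compute_nll_loss=False) == \
           dpo_loss(alpha=0.5, compute_nll_loss=False)
```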

Test plan

  • Existing `test_correctness`, `test_correctness_functional`, `test_correctness_apo_loss_types`, `test_correctness_functional_apo_loss_types` all pass
  • New `test_alpha_scales_nll_loss` passes and confirms `alpha != 1.0` changes the loss value
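For callers, the fix means `alpha` can now be set on the `nn.Module` wrapper and is actually honored. A hedged sketch of the intended usage, with the real `torch.nn.Module` base and the `LigerFusedLinearDPOFunction.apply(...)` call replaced by placeholders:

```python
# Hypothetical simplification of the fixed wrapper; the real class
# subclasses torch.nn.Module and dispatches to
# LigerFusedLinearDPOFunction.apply(...).

class LigerFusedLinearDPOLoss:
    def __init__(self, beta=0.1, alpha=1.0, compute_nll_loss=True):
        self.beta = beta
        self.alpha = alpha  # newly stored and forwarded by this PR
        self.compute_nll_loss = compute_nll_loss

    def forward(self, lin_weight, _input, target, bias=None):
        # Stand-in for the fused Function call; the key point is that
        # self.alpha is now passed in the correct positional slot.
        pref, nll = 3.0, 2.0  # placeholder loss values
        return pref + (self.alpha * nll if self.compute_nll_loss else 0.0)

# alpha != 1.0 now changes the loss instead of being silently ignored.
loss_fn = LigerFusedLinearDPOLoss(alpha=0.5)
```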

micdoh and others added 2 commits April 16, 2026 12:03
LigerFusedLinearDPOFunction.forward() accepted every base-class parameter
except alpha, so the NLL scaling weight silently defaulted to 1.0 regardless
of what callers passed. This adds alpha to both the Function and the
LigerFusedLinearDPOLoss module, fixes the positional-arg order in the
existing functional tests, and adds a regression test that verifies alpha
actually affects the loss value when compute_nll_loss=True.
