Commit 8777c9f

Skip norm weight/bias gradients when frozen (LoRA/PEFT optimization)
When using LoRA/PEFT, normalization weights are typically frozen but gradients were still being computed. This PR skips dW/dB computation when parameters have requires_grad=False, providing significant speedups at larger hidden sizes (up to 3x faster backward pass at H=32768).

Changes:
- Add compute_dW/compute_dB flags to backward kernels (tl.constexpr)
- Skip gradient buffer allocation when not needed
- Check ctx.needs_input_grad in all norm backward passes
- Add frozen weight/bias test coverage for all norm ops
- Add mixed RMSNorm+LoRA benchmark
- Fix dS_out None check in fused_add_rms_norm_backward

Affected ops: RMSNorm, FusedAddRMSNorm, LayerNorm, GroupNorm, PolyNorm

No public API changes.
1 parent bb88671 commit 8777c9f

12 files changed

Lines changed: 814 additions & 90 deletions
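To make the mechanism concrete, here is a minimal sketch, not the actual Liger kernels, of the gating described in the commit message: a tl.constexpr flag lets Triton compile the weight-gradient path out of the backward kernel, and the Python wrapper consults ctx.needs_input_grad so it can also skip allocating the dW buffer. The op below is a toy elementwise scale (y = x * w); the kernel and class names are hypothetical, and float32 inputs are assumed for simplicity.

import torch
import triton
import triton.language as tl


@triton.jit
def _scale_bwd_kernel(
    dy_ptr, x_ptr, w_ptr, dx_ptr, dw_ptr, n_cols,
    BLOCK_SIZE: tl.constexpr, COMPUTE_DW: tl.constexpr,
):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    dy = tl.load(dy_ptr + row * n_cols + cols, mask=mask, other=0.0)
    w = tl.load(w_ptr + cols, mask=mask, other=0.0)
    # dX is always needed so backprop keeps flowing to earlier layers.
    tl.store(dx_ptr + row * n_cols + cols, dy * w, mask=mask)
    if COMPUTE_DW:  # dead code when the weight is frozen (requires_grad=False)
        x = tl.load(x_ptr + row * n_cols + cols, mask=mask, other=0.0)
        tl.atomic_add(dw_ptr + cols, dy * x, mask=mask)


class _Scale(torch.autograd.Function):
    """Toy y = x * w op standing in for a norm layer's weight path."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, dy):
        x, w = ctx.saved_tensors
        n_rows, n_cols = x.shape
        # needs_input_grad[1] is False when w is frozen, so both the dW buffer
        # allocation and the kernel's dW work are skipped.
        compute_dw = ctx.needs_input_grad[1]
        dx = torch.empty_like(x)
        dw = torch.zeros_like(w) if compute_dw else dx  # dx doubles as a dummy pointer
        _scale_bwd_kernel[(n_rows,)](
            dy.contiguous(), x, w, dx, dw, n_cols,
            BLOCK_SIZE=triton.next_power_of_2(n_cols),
            COMPUTE_DW=compute_dw,
        )
        return dx, (dw if compute_dw else None)

The atomic accumulation of dW above is for brevity only and is not meant to mirror how the real kernels reduce partial gradients; the point is that COMPUTE_DW=False removes the extra loads, math, and stores at compile time, which is where the skipped dW/dB work saves time at large hidden sizes.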

benchmark/scripts/benchmark_rms_norm.py

Lines changed: 23 additions & 4 deletions
@@ -41,6 +41,7 @@ def bench_speed_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     M = extra_benchmark_config["M"]
     eps = extra_benchmark_config["eps"]
     dtype = extra_benchmark_config["dtype"]
+    freeze_weight = extra_benchmark_config.get("freeze_weight", False)
 
     x_shape = (M, N)
 
@@ -51,6 +52,10 @@ def bench_speed_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     dy = torch.randn_like(x)
     x.requires_grad_(True)
 
+    if freeze_weight:
+        triton_rms.weight.requires_grad_(False)
+        llama_rms.weight.requires_grad_(False)
+
     # utility functions
 
     def y_fwd():
@@ -60,18 +65,24 @@ def y_fwd():
         if provider == "huggingface":
             return llama_rms(x)
 
+    grad_to_none = [x]
+    if provider == "liger" and triton_rms.weight.requires_grad:
+        grad_to_none.append(triton_rms.weight)
+    elif provider == "huggingface" and llama_rms.weight.requires_grad:
+        grad_to_none.append(llama_rms.weight)
+
     if mode == "forward":
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
             y_fwd,
-            grad_to_none=[x],
+            grad_to_none=grad_to_none,
             rep=500,
             quantiles=QUANTILES,
         )
     elif mode == "backward":
         y = y_fwd()
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
             lambda: y.backward(dy, retain_graph=True),
-            grad_to_none=[x],
+            grad_to_none=grad_to_none,
             rep=500,
             quantiles=QUANTILES,
         )
@@ -83,7 +94,7 @@ def full():
 
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
             full,
-            grad_to_none=[x],
+            grad_to_none=grad_to_none,
             rep=500,
             quantiles=QUANTILES,
         )
@@ -103,6 +114,7 @@ def bench_memory_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     M = extra_benchmark_config["M"]
     eps = extra_benchmark_config["eps"]
     dtype = extra_benchmark_config["dtype"]
+    freeze_weight = extra_benchmark_config.get("freeze_weight", False)
 
     x_shape = (M, N)
 
@@ -113,6 +125,10 @@ def bench_memory_rms_norm(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
     dy = torch.randn_like(x)
     x.requires_grad_(True)
 
+    if freeze_weight:
+        triton_rms.weight.requires_grad_(False)
+        llama_rms.weight.requires_grad_(False)
+
     # utility functions
     def y_fwd():
         if provider == "liger":
@@ -142,7 +158,10 @@ def full():
         "x_label": "hidden size",
         "x_values": [2**i for i in range(10, 16)],
         "kernel_providers": ["liger", "huggingface"],
-        "extra_benchmark_configs": [{"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6}],
+        "extra_benchmark_configs": [
+            {"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6, "freeze_weight": False},
+            {"M": 2048, "dtype": torch.bfloat16, "eps": 1e-6, "freeze_weight": True},
+        ],
         "overwrite": args.overwrite,
     }
 
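A note on the grad_to_none change above: triton.testing.do_bench resets .grad to None for each listed tensor between timed iterations so gradient accumulation does not skew the measurement, and a frozen parameter never receives a .grad to reset, so only trainable tensors are listed. A standalone illustration of the pattern (not part of the diff; assumes a CUDA device):

import torch
import triton

x = torch.randn(2048, 4096, device="cuda", requires_grad=True)
w = torch.nn.Parameter(torch.ones(4096, device="cuda"))
w.requires_grad_(False)  # frozen, as under LoRA/PEFT
dy = torch.randn_like(x)

y = x * w  # stand-in for the RMSNorm forward

# Mirror the benchmark: only tensors that actually accumulate gradients
# are handed to do_bench for clearing between reps.
grad_to_none = [x]
if w.requires_grad:
    grad_to_none.append(w)

ms = triton.testing.do_bench(
    lambda: y.backward(dy, retain_graph=True),
    grad_to_none=grad_to_none,
    rep=500,
)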

Lines changed: 223 additions & 0 deletions
@@ -0,0 +1,223 @@
+import math
+
+import torch
+import torch.nn as nn
+import triton
+
+from utils import QUANTILES
+from utils import SingleBenchmarkRunInput
+from utils import SingleBenchmarkRunOutput
+from utils import _test_memory
+from utils import parse_benchmark_script_args
+from utils import run_benchmarks
+
+from liger_kernel.transformers.rms_norm import LigerRMSNorm
+from liger_kernel.utils import infer_device
+
+device = infer_device()
+
+
+class LlamaRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        """
+        LlamaRMSNorm is equivalent to T5LayerNorm
+        """
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+
+class LoRALinear(nn.Module):
+    def __init__(self, in_features, out_features, r=8, alpha=16.0, bias=False):
+        super().__init__()
+        self.weight = nn.Parameter(torch.empty(out_features, in_features))
+        self.weight.requires_grad_(False)  # base weight frozen (LoRA)
+        self.lora_A = nn.Parameter(torch.empty(r, in_features))
+        self.lora_B = nn.Parameter(torch.empty(out_features, r))
+        self.scaling = alpha / r
+        if bias:
+            self.bias = nn.Parameter(torch.zeros(out_features))
+        else:
+            self.register_parameter("bias", None)
+
+        # Init with small random values so grads flow through both A and B
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+        nn.init.kaiming_uniform_(self.lora_B, a=math.sqrt(5))
+
+    def forward(self, x):
+        base = x @ self.weight.t()
+        lora = (x @ self.lora_A.t()) @ self.lora_B.t()
+        out = base + lora * self.scaling
+        if self.bias is not None:
+            out = out + self.bias
+        return out
+
+
+class MixedBlock(nn.Module):
+    def __init__(self, norm_cls, hidden_size, eps, lora_r, lora_alpha):
+        super().__init__()
+        self.norm = norm_cls(hidden_size=hidden_size, eps=eps)
+        self.proj = LoRALinear(hidden_size, hidden_size, r=lora_r, alpha=lora_alpha)
+
+    def forward(self, x):
+        return self.proj(self.norm(x))
+
+
+def _build_block(provider, hidden_size, eps, dtype, lora_r, lora_alpha, freeze_norm_weight):
+    norm_cls = LigerRMSNorm if provider == "liger" else LlamaRMSNorm
+    block = MixedBlock(norm_cls, hidden_size=hidden_size, eps=eps, lora_r=lora_r, lora_alpha=lora_alpha)
+    block = block.to(device=device, dtype=dtype)
+    if freeze_norm_weight:
+        block.norm.weight.requires_grad_(False)
+    return block
+
+
+def _grad_to_none_tensors(module, x):
+    tensors = [x]
+    for p in module.parameters():
+        if p.requires_grad:
+            tensors.append(p)
+    return tensors
+
+
+def bench_speed_rms_norm_mixed(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    N = input.x
+    provider = input.kernel_provider
+    mode = input.kernel_operation_mode
+
+    extra = input.extra_benchmark_config
+    M = extra["M"]
+    eps = extra["eps"]
+    dtype = extra["dtype"]
+    lora_r = extra["lora_r"]
+    lora_alpha = extra["lora_alpha"]
+    freeze_norm_weight = extra.get("freeze_norm_weight", True)
+
+    x_shape = (M, N)
+
+    block = _build_block(provider, N, eps, dtype, lora_r, lora_alpha, freeze_norm_weight)
+
+    x = torch.randn(x_shape, dtype=dtype, device=device)
+    dy = torch.randn_like(x)
+    x.requires_grad_(True)
+
+    def y_fwd():
+        return block(x)
+
+    grad_to_none = _grad_to_none_tensors(block, x)
+
+    if mode == "forward":
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            y_fwd,
+            grad_to_none=grad_to_none,
+            rep=500,
+            quantiles=QUANTILES,
+        )
+    elif mode == "backward":
+        y = y_fwd()
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            lambda: y.backward(dy, retain_graph=True),
+            grad_to_none=grad_to_none,
+            rep=500,
+            quantiles=QUANTILES,
+        )
+    elif mode == "full":
+
+        def full():
+            y = y_fwd()
+            y.backward(dy, retain_graph=True)
+
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(
+            full,
+            grad_to_none=grad_to_none,
+            rep=500,
+            quantiles=QUANTILES,
+        )
+
+    return SingleBenchmarkRunOutput(
+        y_20=ms_20,
+        y_50=ms_50,
+        y_80=ms_80,
+    )
+
+
+def bench_memory_rms_norm_mixed(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
+    N = input.x
+    provider = input.kernel_provider
+
+    extra = input.extra_benchmark_config
+    M = extra["M"]
+    eps = extra["eps"]
+    dtype = extra["dtype"]
+    lora_r = extra["lora_r"]
+    lora_alpha = extra["lora_alpha"]
+    freeze_norm_weight = extra.get("freeze_norm_weight", True)
+
+    x_shape = (M, N)
+
+    block = _build_block(provider, N, eps, dtype, lora_r, lora_alpha, freeze_norm_weight)
+
+    x = torch.randn(x_shape, dtype=dtype, device=device)
+    dy = torch.randn_like(x)
+    x.requires_grad_(True)
+
+    def y_fwd():
+        return block(x)
+
+    def full():
+        y = y_fwd()
+        y.backward(dy, retain_graph=True)
+
+    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)
+
+    return SingleBenchmarkRunOutput(
+        y_20=mem_20,
+        y_50=mem_50,
+        y_80=mem_80,
+    )
+
+
+if __name__ == "__main__":
+    args = parse_benchmark_script_args()
+
+    common_configs = {
+        "kernel_name": "rms_norm_mixed",
+        "x_name": "H",
+        "x_label": "hidden size",
+        "x_values": [2**i for i in range(10, 16)],
+        "kernel_providers": ["liger", "huggingface"],
+        "extra_benchmark_configs": [
+            {
+                "M": 2048,
+                "dtype": torch.bfloat16,
+                "eps": 1e-6,
+                "lora_r": 8,
+                "lora_alpha": 16.0,
+                "freeze_norm_weight": True,
+            }
+        ],
+        "overwrite": args.overwrite,
+    }
+
+    run_benchmarks(
+        bench_test_fn=bench_speed_rms_norm_mixed,
+        kernel_operation_modes=["forward", "full", "backward"],
+        metric_name="speed",
+        metric_unit="ms",
+        **common_configs,
+    )
+    run_benchmarks(
+        bench_test_fn=bench_memory_rms_norm_mixed,
+        kernel_operation_modes=["full"],
+        metric_name="memory",
+        metric_unit="MB",
+        **common_configs,
+    )
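As a quick standalone check (not part of the commit) of the behavior the frozen-weight tests and this mixed benchmark rely on: with the norm weight frozen, the backward pass still produces dX but leaves the weight's .grad unset. Assumes a CUDA device.

import torch
from liger_kernel.transformers.rms_norm import LigerRMSNorm

norm = LigerRMSNorm(hidden_size=64, eps=1e-6).to("cuda")
norm.weight.requires_grad_(False)  # frozen, as under LoRA/PEFT

x = torch.randn(4, 64, device="cuda", requires_grad=True)
norm(x).sum().backward()

assert x.grad is not None        # dX still computed
assert norm.weight.grad is None  # no dW for the frozen weight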
