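"""Benchmark a mixed RMSNorm + LoRA block: LigerRMSNorm vs. the Hugging Face
LlamaRMSNorm reference, measuring forward/backward/full latency and peak
memory of a norm -> LoRA-projection step. The harness (utils.py,
run_benchmarks) follows the Liger Kernel benchmark script conventions."""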
import math

import torch
import torch.nn as nn
import triton

from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.utils import infer_device

device = infer_device()

class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm:
        y = weight * x / sqrt(mean(x^2, dim=-1) + eps)
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute the variance in fp32 for numerical stability, then cast back.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class LoRALinear(nn.Module):
    """Frozen base linear layer with a trainable rank-r LoRA adapter:
    out = x @ W^T + (x @ A^T) @ B^T * (alpha / r), plus an optional bias."""

    def __init__(self, in_features, out_features, r=8, alpha=16.0, bias=False):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.weight.requires_grad_(False)  # base weight frozen (LoRA)
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.empty(out_features, r))
        self.scaling = alpha / r
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

        # Unlike standard LoRA (which zero-initializes B), initialize both A and
        # B with small random values so the backward benchmark exercises
        # nontrivial gradients through both adapter matrices.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_B, a=math.sqrt(5))

    def forward(self, x):
        base = x @ self.weight.t()
        lora = (x @ self.lora_A.t()) @ self.lora_B.t()
        out = base + lora * self.scaling
        if self.bias is not None:
            out = out + self.bias
        return out


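# Rough cost sketch for LoRALinear (hedged, using the default benchmark config
# below: M = 2048 rows, hidden size N, rank r = 8):
#   base matmul: (M, N) @ (N, N)          -> ~2*M*N^2 FLOPs
#   adapter:     (M, N) @ (N, r) @ (r, N) -> ~4*M*N*r FLOPs
# so the adapter adds roughly 2r/N of the base cost (under 2% at N >= 1024).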
class MixedBlock(nn.Module):
    """RMSNorm followed by a LoRA projection: the mixed block under test."""

    def __init__(self, norm_cls, hidden_size, eps, lora_r, lora_alpha):
        super().__init__()
        self.norm = norm_cls(hidden_size=hidden_size, eps=eps)
        self.proj = LoRALinear(hidden_size, hidden_size, r=lora_r, alpha=lora_alpha)

    def forward(self, x):
        return self.proj(self.norm(x))


def _build_block(provider, hidden_size, eps, dtype, lora_r, lora_alpha, freeze_norm_weight):
    norm_cls = LigerRMSNorm if provider == "liger" else LlamaRMSNorm
    block = MixedBlock(norm_cls, hidden_size=hidden_size, eps=eps, lora_r=lora_r, lora_alpha=lora_alpha)
    block = block.to(device=device, dtype=dtype)
    if freeze_norm_weight:
        block.norm.weight.requires_grad_(False)
    return block


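# triton.testing.do_bench resets .grad to None on every tensor passed via its
# grad_to_none argument between reps, so repeated backward passes do not pay
# for gradient accumulation. Collect the input and all trainable parameters.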
def _grad_to_none_tensors(module, x):
    tensors = [x]
    for p in module.parameters():
        if p.requires_grad:
            tensors.append(p)
    return tensors


def bench_speed_rms_norm_mixed(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    N = input.x
    provider = input.kernel_provider
    mode = input.kernel_operation_mode

    extra = input.extra_benchmark_config
    M = extra["M"]
    eps = extra["eps"]
    dtype = extra["dtype"]
    lora_r = extra["lora_r"]
    lora_alpha = extra["lora_alpha"]
    freeze_norm_weight = extra.get("freeze_norm_weight", True)

    x_shape = (M, N)

    block = _build_block(provider, N, eps, dtype, lora_r, lora_alpha, freeze_norm_weight)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        return block(x)

    grad_to_none = _grad_to_none_tensors(block, x)

    if mode == "forward":
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            y_fwd,
            grad_to_none=grad_to_none,
            rep=500,
            quantiles=QUANTILES,
        )
    elif mode == "backward":
        # Build the graph once, then time repeated backward passes on it.
        y = y_fwd()
        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            lambda: y.backward(dy, retain_graph=True),
            grad_to_none=grad_to_none,
            rep=500,
            quantiles=QUANTILES,
        )
    elif mode == "full":

        def full():
            y = y_fwd()
            y.backward(dy, retain_graph=True)

        ms_50, ms_20, ms_80 = triton.testing.do_bench(
            full,
            grad_to_none=grad_to_none,
            rep=500,
            quantiles=QUANTILES,
        )
    else:
        raise ValueError(f"Unknown kernel operation mode: {mode}")

    return SingleBenchmarkRunOutput(
        y_20=ms_20,
        y_50=ms_50,
        y_80=ms_80,
    )


def bench_memory_rms_norm_mixed(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    N = input.x
    provider = input.kernel_provider

    extra = input.extra_benchmark_config
    M = extra["M"]
    eps = extra["eps"]
    dtype = extra["dtype"]
    lora_r = extra["lora_r"]
    lora_alpha = extra["lora_alpha"]
    freeze_norm_weight = extra.get("freeze_norm_weight", True)

    x_shape = (M, N)

    block = _build_block(provider, N, eps, dtype, lora_r, lora_alpha, freeze_norm_weight)

    x = torch.randn(x_shape, dtype=dtype, device=device)
    dy = torch.randn_like(x)
    x.requires_grad_(True)

    def y_fwd():
        return block(x)

    def full():
        y = y_fwd()
        y.backward(dy, retain_graph=True)

    # Peak memory of a full forward + backward step.
    mem_50, mem_20, mem_80 = _test_memory(full, quantiles=QUANTILES)

    return SingleBenchmarkRunOutput(
        y_20=mem_20,
        y_50=mem_50,
        y_80=mem_80,
    )


if __name__ == "__main__":
    args = parse_benchmark_script_args()

    common_configs = {
        "kernel_name": "rms_norm_mixed",
        "x_name": "H",
        "x_label": "hidden size",
        "x_values": [2**i for i in range(10, 16)],
        "kernel_providers": ["liger", "huggingface"],
        "extra_benchmark_configs": [
            {
                "M": 2048,
                "dtype": torch.bfloat16,
                "eps": 1e-6,
                "lora_r": 8,
                "lora_alpha": 16.0,
                "freeze_norm_weight": True,
            }
        ],
        "overwrite": args.overwrite,
    }

    run_benchmarks(
        bench_test_fn=bench_speed_rms_norm_mixed,
        kernel_operation_modes=["forward", "full", "backward"],
        metric_name="speed",
        metric_unit="ms",
        **common_configs,
    )
    run_benchmarks(
        bench_test_fn=bench_memory_rms_norm_mixed,
        kernel_operation_modes=["full"],
        metric_name="memory",
        metric_unit="MB",
        **common_configs,
    )
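
# Example invocation (the script filename is illustrative):
#   python benchmark_rms_norm_mixed.py --overwrite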