Commit bb3133a

Version: 0.10.18-alpha.2
Add native CUDA BF16 support for SwiGLU op

- Implement highly vectorized BF16 forward/backward CUDA kernels for SwiGLU
- Integrate BF16 into type dispatch, op registry, and op interface
- Refactor FP32 kernel for float4 vectorization and consistency
- Add detailed mixed-precision compute architecture documentation

Enables production-grade BF16 inference and mixed-precision training for SwiGLU on modern NVIDIA GPUs. All arithmetic is performed in FP32 for numerical stability. The backward pass uses FP32 gradients for optimizer compatibility. No breaking changes.
1 parent 0efa361 commit bb3133a

12 files changed

Lines changed: 1613 additions & 561 deletions

File tree

Mila/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -43,6 +43,7 @@ add_library( Mila STATIC
     "Src/Dnn/Compute/Devices/Cuda/Operations/Activations/Gelu/Kernels/Gelu.cuh"

     "Src/Dnn/Compute/Devices/Cuda/Operations/Activations/Swiglu/Kernels/Swiglu.Fp32.cu"
+    "Src/Dnn/Compute/Devices/Cuda/Operations/Activations/Swiglu/Kernels/Swiglu.Bf16.cu"
     "Src/Dnn/Compute/Devices/Cuda/Operations/Activations/Swiglu/Kernels/Swiglu.cuh"

     "Src/Dnn/Compute/Devices/Cuda/Operations/Normalizations/LayerNorm/Kernels/LayerNorm.Fp32.cu"
```

Mila/Specifications/Compute.md

Lines changed: 350 additions & 0 deletions
# Mila Mixed-Precision Compute Architecture Specification

## Overview

This document captures the design rationale and implementation decisions for Mila's
mixed-precision CUDA compute backend. The SwiGLU op is the canonical reference
implementation. All ops follow the same pattern.

---

## 1. Supported Precision Types

Mila supports the following `TensorDataType` values for CUDA compute:

| Abstract Type               | Native CUDA Type | Forward Activations | Gradient Buffer | Notes                         |
|-----------------------------|------------------|---------------------|-----------------|-------------------------------|
| `TensorDataType::FP32`      | `float`          | FP32                | FP32            | Baseline, fully validated     |
| `TensorDataType::BF16`      | `__nv_bfloat16`  | BF16                | FP32            | Primary inference + training  |
| `TensorDataType::FP16`      | `__half`         | FP16                | FP32            | Deferred — no current need    |
| `TensorDataType::FP8_E4M3`  | `__nv_fp8_e4m3`  | FP8                 | FP32            | Future                        |
| `TensorDataType::FP8_E5M2`  | `__nv_fp8_e5m2`  | FP8                 | FP32            | Future                        |

**BF16 is the primary reduced-precision target.** It has the same dynamic range as FP32
(same exponent width), which makes it numerically stable for both inference and training
without loss scaling. The RTX 4070 (Ada Lovelace) has native BF16 Tensor Core support.
FP16 is deferred — there is no current use case that BF16 does not serve better on the
target hardware.

---

## 2. Type Resolution Chain

The dispatch chain from abstract type to native kernel is fully compile-time:

```
TensorDataType (enum)
  └─► TensorDataTypeMap<TPrecision>::native_type    // abstract → native C++ type
        └─► cuda_op_impl<NativeType>                // dispatch struct (per op)
              └─► cuda_op_forward_bf16(...)         // plain C kernel launcher
                    └─► op_bf16_forward_kernel<<<>>>  // __global__ kernel
```

### Key files per op

| File                   | Role                                                       |
|------------------------|------------------------------------------------------------|
| `Op.ixx`               | Hardware-agnostic component. Knows only `TensorDataType`.  |
| `CudaOp.ixx`           | CUDA op. Resolves `NativeType` via `TensorDataTypeMap`.    |
| `CudaOp.Dispatch.ixx`  | Module partition. `cuda_op_impl<NativeType>` structs.      |
| `CudaOp.Registrar.ixx` | Runtime registry. One `registerUnaryOperation` per type.   |
| `Op.Fp32.cu`           | FP32 kernel + launcher.                                    |
| `Op.Bf16.cu`           | BF16 kernel + launcher.                                    |

### `TensorDataTypeMap` is the single source of truth

`CudaTensorDataType-Maps.ixx` maps every `TensorDataType` to its CUDA native type.
No op or kernel file duplicates this mapping. The dispatch struct constraint on the
primary template derives from this map — it is never hand-enumerated per op.

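For orientation, a minimal sketch of the shape of this mapping (the `native_type` member name comes from the dispatch chain above; the enum spelling and everything else here is illustrative, not the actual contents of `CudaTensorDataType-Maps.ixx`):

```cpp
// Sketch only: abstract TensorDataType to CUDA native type, one specialization per entry.
#include <cuda_bf16.h>
#include <cuda_fp16.h>

enum class TensorDataType { FP32, BF16, FP16, FP8_E4M3, FP8_E5M2 };

template <TensorDataType TPrecision>
struct TensorDataTypeMap;

template <> struct TensorDataTypeMap<TensorDataType::FP32> { using native_type = float; };
template <> struct TensorDataTypeMap<TensorDataType::BF16> { using native_type = __nv_bfloat16; };
template <> struct TensorDataTypeMap<TensorDataType::FP16> { using native_type = __half; };
```
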
---

## 3. Dispatch Partition Contract

Each op provides a `CudaOp.Dispatch.ixx` module partition containing:

```cpp
namespace Detail
{
    // Primary template — gates to CUDA float native types only.
    // Constraint is derived from TensorDataTypeMap, never hand-enumerated.
    template <typename TNative>
        requires CudaFloatNativeType<TNative>
    struct cuda_op_impl;

    // One complete specialization per supported native type.
    // If a specialization exists, ALL methods are implemented — no stub throws.
    // If a type is not ready, there is no specialization. The missing
    // specialization is a compile error at CudaOp instantiation — the correct
    // failure mode.
    template <>
    struct cuda_op_impl<float> { ... };

    template <>
    struct cuda_op_impl<__nv_bfloat16> { ... };
}
```

### Rules

- **Complete or absent.** A specialization that throws at runtime for an unimplemented
  method violates the contract. The compile error from a missing specialization is the
  correct diagnostic.
- **No runtime type switching.** The dispatch struct is instantiated at compile time from
  `NativeType`. There are no `if/switch` on `TensorDataType` at runtime inside an op.
- **No state for elementwise ops.** For ops like SwiGLU where the impl carries no
  per-instance data, the struct is empty and compiles to nothing. The layer exists
  for consistency and to accommodate stateful ops (e.g. cuBLASLt plan holders in
  `CudaLinearOp`, `CudaGqaOp`).

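For context, a sketch of the caller side (not the literal `CudaOp.ixx`; the `Tensor` type and its accessors are assumptions): the CUDA op resolves `NativeType` once and forwards to the dispatch struct, so an unsupported type fails at exactly this instantiation point:

```cpp
// Sketch only: compile-time dispatch from the CUDA op into cuda_op_impl<NativeType>.
// Tensor and its data<T>() accessor are illustrative assumptions.
template <TensorDataType TPrecision>
class CudaSwigluOp
{
    using NativeType = typename TensorDataTypeMap<TPrecision>::native_type;

public:
    void forward( const Tensor& input, Tensor& output, cudaStream_t stream )
    {
        // No runtime switch on TensorDataType: a missing cuda_op_impl<NativeType>
        // specialization is a compile error at this instantiation.
        Detail::cuda_op_impl<NativeType>::forward(
            output.data<NativeType>(), input.data<NativeType>(),
            static_cast<int>( output.size() ), stream );
    }
};
```
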
---

## 4. Registrar Contract

Each op provides a `CudaOp.Registrar.ixx` containing a `CudaOpRegistrar` class with
a static `registerOperations()` method. One `registerUnaryOperation` (or equivalent)
call per supported `TensorDataType`.

### Rules

- **Registrar and dispatch specialization set must stay in sync.** Registering a type
  that has no dispatch specialization is a compile error. Providing a dispatch
  specialization without a registrar entry is a silent runtime omission — the op
  compiles but is unreachable via the registry.
- **When adding a new type**, both the dispatch specialization and the registrar entry
  must land together.

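A minimal sketch of the resulting shape, using the `registerUnaryOperation<Cuda, BF16, BF16>` form from the checklist in Section 12 (the registry call signature, op name string, and factory lambda are assumptions, not the real registry API):

```cpp
// Sketch only: one registration call per supported TensorDataType.
class CudaSwigluOpRegistrar
{
public:
    static void registerOperations()
    {
        registerUnaryOperation<DeviceType::Cuda, TensorDataType::FP32, TensorDataType::FP32>(
            "Cuda::SwigluOp",
            []( const auto& config ) { return std::make_shared<CudaSwigluOp<TensorDataType::FP32>>( config ); } );

        registerUnaryOperation<DeviceType::Cuda, TensorDataType::BF16, TensorDataType::BF16>(
            "Cuda::SwigluOp",
            []( const auto& config ) { return std::make_shared<CudaSwigluOp<TensorDataType::BF16>>( config ); } );
    }
};
```
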
---

## 5. Memory Layout

Mila uses **contiguous halves** throughout, not an interleaved per-token layout.

For SwiGLU with input `X` of size `2N` and output `Y` of size `N`:

```
X: [ gate_0, gate_1, ..., gate_N-1 | up_0, up_1, ..., up_N-1 ]
     └────────── first half ──────┘ └────── second half ────┘
Y: [ y_0, y_1, ..., y_N-1 ]
```

**Rationale:** Contiguous halves are easier to reason about, produce simpler vectorized
indexing (no per-element token/column arithmetic), and are consistent with how all other
Mila ops handle split buffers (QKV projection, etc.).

This differs from HuggingFace's interleaved layout, which falls out of fused QKV
projections. Mila uses explicit separate projections, so the contiguous layout is natural.

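For clarity, a scalar sketch of the indexing this layout produces (illustration only; the shipped kernels are vectorized, see Section 7):

```cuda
// Scalar reference for the contiguous-halves layout.
// gate_i lives at X[i], up_i at X[N + i]; Y[i] = silu(gate_i) * up_i.
__global__ void swiglu_forward_reference( float* Y, const float* X, int N )
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if ( i < N )
    {
        float gate = X[i];                              // first half
        float up   = X[N + i];                          // second half
        float silu = gate / ( 1.0f + expf( -gate ) );   // gate * sigmoid(gate)
        Y[i] = silu * up;
    }
}
```
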
**Batch size:** Mila targets B=1 for decode (single-user local inference). Batch > 1
is not a current architectural requirement. The vectorized kernels are correct for B=1
by construction.

---

## 6. Memory Alignment

Every tensor buffer pointer is guaranteed aligned at allocation time by
`get_alignment<TDataType, MR>()`:

```cpp
// CUDA alignment = CUDA_WARP_SIZE (32) * sizeof(element)
FP32: 128 bytes (32 floats)   — supports float4 loads
BF16:  64 bytes (32 bfloat16) — supports uint4 loads (8 BF16 per load)
FP16:  64 bytes (32 halves)   — supports uint4 loads
INT8:  32 bytes (32 int8)     — supports int4 loads
```

**Consequence for kernels:** Vectorized loads are unconditional — no scalar prologue,
no scalar epilogue for alignment. The only remainder handling required is for `N` not
being a multiple of the vector width, which is enforced at the op level (see Section 7).

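A sketch of the rule the table encodes (the real `get_alignment<TDataType, MR>()` also takes the memory resource into account; only the CUDA branch is shown here):

```cpp
// Illustration of the CUDA alignment rule: warp size times element size.
#include <cstddef>
#include <cuda_bf16.h>

inline constexpr std::size_t CUDA_WARP_SIZE = 32;

template <typename TElement>
constexpr std::size_t cuda_alignment()
{
    return CUDA_WARP_SIZE * sizeof( TElement );   // FP32 -> 128, BF16/FP16 -> 64, INT8 -> 32
}

static_assert( cuda_alignment<float>() == 128 );
static_assert( cuda_alignment<__nv_bfloat16>() == 64 );
```
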
---

## 7. Vectorization

All elementwise CUDA kernels use vectorized loads and stores. The kernel is
unconditionally vectorized — no scalar fallback path.

### Vector widths per type

| Type | Vector type | Elements per thread | Bytes per load |
|------|-------------|---------------------|----------------|
| FP32 | `float4`    | 4                   | 16             |
| BF16 | `uint4`     | 8                   | 16             |
| FP16 | `uint4`     | 8                   | 16             |

### Exported vector width constants

Each kernel file exports a `constexpr int kOpTypeVectorWidth` constant. The op's
`forward()` validates against this constant rather than a magic number:

```cpp
// Op.Fp32.cu
constexpr int kSwigluFp32VectorWidth = 4;

// Op.Bf16.cu
constexpr int kSwigluBf16VectorWidth = 8;
```

### Op-level validation

The op `forward()` enforces the vector width precondition before launching the kernel:

```cpp
// Example for BF16 SwiGLU — input size must be multiple of 2 * VectorWidth
// (gate half + up half, each must be a multiple of VectorWidth)
if ( input.size() % ( 2 * kSwigluBf16VectorWidth ) != 0 )
{
    throw std::invalid_argument(
        std::format( "CudaSwigluOp: input size must be a multiple of {} for vectorized BF16.",
            2 * kSwigluBf16VectorWidth )
    );
}
```

### Block size

All forward and backward kernels use **256 threads per block**. Grid size is computed
over the number of vector-width chunks, not scalar elements:

```cpp
int vec_N = N / kVectorWidth;
int grid_size = ( vec_N + 256 - 1 ) / 256;
```

---

## 8. BF16 Arithmetic: FP32 Promotion

BF16 kernels load and store in BF16 but compute in FP32. This applies to both
forward and backward passes.

**Rationale:** Training stability. BF16 has only 7 mantissa bits — insufficient
precision for intermediate values like sigmoid, exp, and gradient chain products.
Promoting to FP32 for arithmetic gives full precision where it matters while
preserving the memory bandwidth and VRAM benefits of BF16 storage. This is
consistent with PyTorch's internal BF16 kernel strategy.

### Pattern for paired BF16 arithmetic

```cuda
// Load 8 BF16 elements as uint4
uint4 packed = reinterpret_cast<const uint4*>(X)[i];

// Reinterpret as four __nv_bfloat162 pairs
__nv_bfloat162 ab = reinterpret_cast<const __nv_bfloat162*>(&packed)[0];
__nv_bfloat162 cd = reinterpret_cast<const __nv_bfloat162*>(&packed)[1];
// ... etc.

// Promote to FP32 for arithmetic
float2 ab_f = __bfloat1622float2( ab );
// ... compute in float ...

// Demote back to BF16 for store
__nv_bfloat162 result = __float22bfloat162_rn( result_f );
```

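Putting the pattern together, a sketch of what a fully vectorized BF16 SwiGLU forward kernel and launcher can look like under the conventions above (illustration only, not the shipped `Swiglu.Bf16.cu`):

```cuda
// Sketch: Y[i] = silu(X[i]) * X[N + i], 8 BF16 elements per thread, FP32 arithmetic.
#include <cuda_bf16.h>

constexpr int kSwigluBf16VectorWidth = 8;   // 8 BF16 elements per 16-byte uint4 load

__global__ void swiglu_bf16_forward_kernel(
    __nv_bfloat16* Y, const __nv_bfloat16* X, int vec_N, int N )
{
    int v = blockIdx.x * blockDim.x + threadIdx.x;
    if ( v >= vec_N ) return;

    // One 16-byte load from each half: 8 gate elements and 8 up elements.
    uint4 gate_packed = reinterpret_cast<const uint4*>( X )[v];
    uint4 up_packed   = reinterpret_cast<const uint4*>( X + N )[v];
    uint4 out_packed;

    const __nv_bfloat162* gate2 = reinterpret_cast<const __nv_bfloat162*>( &gate_packed );
    const __nv_bfloat162* up2   = reinterpret_cast<const __nv_bfloat162*>( &up_packed );
    __nv_bfloat162*       out2  = reinterpret_cast<__nv_bfloat162*>( &out_packed );

    #pragma unroll
    for ( int p = 0; p < 4; ++p )
    {
        // Promote to FP32, compute silu(gate) * up, demote for the store.
        float2 g = __bfloat1622float2( gate2[p] );
        float2 u = __bfloat1622float2( up2[p] );
        float2 r;
        r.x = ( g.x / ( 1.0f + expf( -g.x ) ) ) * u.x;
        r.y = ( g.y / ( 1.0f + expf( -g.y ) ) ) * u.y;
        out2[p] = __float22bfloat162_rn( r );
    }

    reinterpret_cast<uint4*>( Y )[v] = out_packed;
}

// Launcher: grid over vector-width chunks, 256 threads per block (Section 7).
void cuda_swiglu_forward_bf16( __nv_bfloat16* Y, const __nv_bfloat16* X, int N, cudaStream_t stream )
{
    int vec_N = N / kSwigluBf16VectorWidth;
    int grid_size = ( vec_N + 256 - 1 ) / 256;
    swiglu_bf16_forward_kernel<<<grid_size, 256, 0, stream>>>( Y, X, vec_N, N );
}
```
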
---

## 9. Mixed-Precision Training: Forward and Backward Tensor Types

### Forward pass

| Input tensor | Output tensor |
|--------------|---------------|
| BF16         | BF16          |

### Backward pass

| dY (upstream gradient) | X (saved activations) | dX (gradient output) |
|------------------------|-----------------------|----------------------|
| FP32                   | BF16                  | FP32                 |

**Rationale:** FP32 gradients are the canonical format at the optimizer boundary.
Both CUDA Adam (on-device) and CPU Adam (offloaded) consume FP32 gradients.
CPU Adam is the practical path for users who offload optimizer state to host RAM
to extend effective VRAM — a first-class use case for Mila.

The backward kernel signature for all BF16 ops follows this pattern:

```cuda
void cuda_op_backward_bf16(
    float* dX,               // FP32 gradient output — optimizer boundary
    const __nv_bfloat16* X,  // BF16 saved forward activations
    const float* dY,         // FP32 upstream gradient
    int N,
    cudaStream_t stream )
```

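For the SwiGLU case concretely, a scalar sketch of the per-element backward arithmetic behind that signature (illustration only; the shipped kernel vectorizes the same math):

```cuda
// Scalar sketch of SwiGLU backward. X holds [gate | up] as 2N BF16 values,
// dY holds N FP32 upstream gradients, dX is the 2N FP32 gradient buffer
// laid out the same way as X.
__global__ void swiglu_bf16_backward_reference(
    float* dX, const __nv_bfloat16* X, const float* dY, int N )
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if ( i >= N ) return;

    // Promote saved BF16 activations to FP32 for the gradient chain products.
    float g  = __bfloat162float( X[i] );        // gate half
    float u  = __bfloat162float( X[N + i] );    // up half
    float dy = dY[i];

    float sig  = 1.0f / ( 1.0f + expf( -g ) );
    float silu = g * sig;

    // d/dg [ silu(g) * u ] = u * ( sig + g * sig * (1 - sig) )
    dX[i]     = dy * u * ( sig + g * sig * ( 1.0f - sig ) );

    // d/du [ silu(g) * u ] = silu(g)
    dX[N + i] = dy * silu;
}
```
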
---

## 10. Mixed-Precision Training: Weight Strategy

Mila targets the standard Micikevicius et al. (2018) mixed-precision training recipe:

```
FP32 master weights
  → cast to BF16 for forward pass
  → BF16 activations through forward
  → FP32 gradients through backward
  → FP32 Adam optimizer step (CUDA or CPU)
  → update FP32 master weights
  → repeat
```

**Minimum hardware requirement for training:** 16GB VRAM.

**Rationale for FP32 master weights:**
- Adam's first and second moments require FP32 to accumulate small updates correctly
- Weight updates (`w -= lr * grad`) can be vanishingly small relative to `w` in BF16 (see the sketch after this list)
- FP32 master weights make Mila's training results directly comparable to the literature
- At 16GB the standard recipe is viable for Llama 1B with gradient checkpointing

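To make the second bullet concrete, a sketch with values chosen purely for illustration: just below 1.0 the spacing between representable BF16 values is about 0.002, so a typical per-step update rounds away entirely:

```cpp
// Illustration only: a small Adam-sized update vanishes if the weight lives in BF16.
#include <cuda_bf16.h>

float w      = 1.0f;
float update = 1e-4f;                                    // typical lr * grad magnitude

float w_fp32 = w - update;                               // 0.9999, survives in FP32
__nv_bfloat16 w_bf16 = __float2bfloat16( w - update );   // rounds back to exactly 1.0

// __bfloat162float( w_bf16 ) == 1.0f: the step is lost, and every subsequent step of
// similar size is lost too. The FP32 master copy accumulates such updates correctly.
```
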
**CPU Adam offload:** Users with 16GB VRAM and sufficient system RAM can offload
FP32 master weights and Adam moments to CPU, extending effective training capacity
well beyond 1B parameters. This is a first-class supported configuration.

**`REVIEW:`** Stochastic rounding on BF16 weight updates is a future consideration
for training stability without FP32 master weights. Not a current kernel concern but
must not be accidentally designed around in the optimizer interface.

---

## 11. cuBLASLt and BF16

For ops using cuBLASLt (`CudaLinearOp`, `CudaGqaOp`):

- Data type: `CUDA_R_16BF`
- Compute type: `CUBLAS_COMPUTE_32F_FAST_16BF`

This is mixed-precision matmul: data moves in BF16, accumulation is FP32 internally.
`CudaDataTypeMap<__nv_bfloat16>::fp32_compute_type` holds this value.

**Important:** `CudaDataTypeMap<__nv_bfloat16>` has no `compute_type` member
(BF16-native accumulation does not exist in cuBLAS). Plan builders must always use
`fp32_compute_type` for BF16 — never `compute_type`. Any generic plan builder
that falls through to `compute_type` without type checking is a silent bug.

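For reference, a minimal sketch of how these two enums enter a cuBLASLt plan (standalone illustration; Mila's plan builders take both values from `CudaDataTypeMap` rather than hard-coding them, and error handling is omitted):

```cpp
// Sketch: BF16 operand storage, FP32 accumulation and scaling.
#include <cstdint>
#include <cublasLt.h>

void build_bf16_matmul_desc( uint64_t m, uint64_t n, uint64_t k )
{
    cublasLtMatmulDesc_t   matmul_desc;
    cublasLtMatrixLayout_t a_layout, b_layout, c_layout;

    // Compute type CUBLAS_COMPUTE_32F_FAST_16BF; alpha/beta scale type is FP32.
    cublasLtMatmulDescCreate( &matmul_desc, CUBLAS_COMPUTE_32F_FAST_16BF, CUDA_R_32F );

    // Operands and result are stored as BF16 (CUDA_R_16BF), column-major, ld = rows.
    cublasLtMatrixLayoutCreate( &a_layout, CUDA_R_16BF, m, k, m );
    cublasLtMatrixLayoutCreate( &b_layout, CUDA_R_16BF, k, n, k );
    cublasLtMatrixLayoutCreate( &c_layout, CUDA_R_16BF, m, n, m );
}
```
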
Pre-built cuBLASLt plans are cached per op instance. Plan selection for
mixed-precision BF16 matmuls has more heuristic space than FP32 — cuBLASLt
may select a different optimal algorithm. Caching is therefore more valuable
for BF16 than FP32.

---

## 12. Adding a New Type to an Existing Op — Checklist

When adding support for a new `TensorDataType` (e.g. BF16) to an existing op:

- [ ] `Op.Bf16.cu` — kernel + launcher, exports `kOpBf16VectorWidth`
- [ ] `CudaOp.Dispatch.ixx` — add complete `cuda_op_impl<__nv_bfloat16>` specialization
- [ ] `CudaOp.Registrar.ixx` — add `registerUnaryOperation<Cuda, BF16, BF16>` entry
- [ ] `CudaOp.ixx` `forward()` — add vector width validation against `kOpBf16VectorWidth`
- [ ] Validate kernel output against the FP32 reference before enabling vectorization

The first four code changes must land together. Missing the registrar entry is a silent
runtime omission. Missing the dispatch specialization is a compile error.

---

*This document reflects design decisions made through April 2026.*
*Update when new types, ops, or training strategies are added.*
