# Mila Mixed-Precision Compute Architecture Specification

## Overview

This document captures the design rationale and implementation decisions for Mila's
mixed-precision CUDA compute backend. The SwiGLU op is the canonical reference
implementation. All ops follow the same pattern.

---

## 1. Supported Precision Types

Mila supports the following `TensorDataType` values for CUDA compute:

| Abstract Type | Native CUDA Type | Forward Activations | Gradient Buffer | Notes |
|----------------------|---------------------|---------------------|-----------------|------------------------------|
| `TensorDataType::FP32` | `float` | FP32 | FP32 | Baseline, fully validated |
| `TensorDataType::BF16` | `__nv_bfloat16` | BF16 | FP32 | Primary inference + training |
| `TensorDataType::FP16` | `__half` | FP16 | FP32 | Deferred — no current need |
| `TensorDataType::FP8_E4M3` | `__nv_fp8_e4m3` | FP8 | FP32 | Future |
| `TensorDataType::FP8_E5M2` | `__nv_fp8_e5m2` | FP8 | FP32 | Future |

**BF16 is the primary reduced-precision target.** It has the same dynamic range as FP32
(same exponent width), which makes it numerically stable for both inference and training
without loss scaling. The RTX 4070 (Ada Lovelace) has native BF16 Tensor Core support.
FP16 is deferred — there is no current use case that BF16 does not serve better on the
target hardware.

---

## 2. Type Resolution Chain

The dispatch chain from abstract type to native kernel is fully compile-time:

```
TensorDataType (enum)
  └─► TensorDataTypeMap<TPrecision>::native_type      // abstract → native C++ type
        └─► cuda_op_impl<NativeType>                   // dispatch struct (per op)
              └─► cuda_op_forward_bf16(...)            // plain C kernel launcher
                    └─► op_bf16_forward_kernel<<<>>>   // __global__ kernel
```

### Key files per op

| File | Role |
|-----------------------------|-------------------------------------------------------------|
| `Op.ixx` | Hardware-agnostic component. Knows only `TensorDataType`. |
| `CudaOp.ixx` | CUDA op. Resolves `NativeType` via `TensorDataTypeMap`. |
| `CudaOp.Dispatch.ixx` | Module partition. `cuda_op_impl<NativeType>` structs. |
| `CudaOp.Registrar.ixx` | Runtime registry. One `registerUnaryOperation` per type. |
| `Op.Fp32.cu` | FP32 kernel + launcher. |
| `Op.Bf16.cu` | BF16 kernel + launcher. |

### `TensorDataTypeMap` is the single source of truth

`CudaTensorDataType-Maps.ixx` maps every `TensorDataType` to its CUDA native type.
No op or kernel file duplicates this mapping. The dispatch struct constraint on the
primary template derives from this map — it is never hand-enumerated per op.
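
For orientation, a minimal sketch of the shape this map takes. Member and concept names
other than `native_type` (and the enum values listed in Section 1) are illustrative here,
not the actual contents of `CudaTensorDataType-Maps.ixx`:

```cpp
// Sketch: one specialization per TensorDataType, each naming its CUDA native type.
template <TensorDataType TPrecision>
struct TensorDataTypeMap;

template <>
struct TensorDataTypeMap<TensorDataType::FP32>
{
    using native_type = float;
};

template <>
struct TensorDataTypeMap<TensorDataType::BF16>
{
    using native_type = __nv_bfloat16;
};

// A concept such as CudaFloatNativeType (Section 3) can then be defined over the set of
// native_type aliases, so per-op dispatch constraints never hand-enumerate CUDA types.
```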

---

## 3. Dispatch Partition Contract

Each op provides a `CudaOp.Dispatch.ixx` module partition containing:

```cpp
namespace Detail
{
    // Primary template — gates to CUDA float native types only.
    // Constraint is derived from TensorDataTypeMap, never hand-enumerated.
    template <typename TNative>
        requires CudaFloatNativeType<TNative>
    struct cuda_op_impl;

    // One complete specialization per supported native type.
    // If a specialization exists, ALL methods are implemented — no stub throws.
    // If a type is not ready, there is no specialization. The missing
    // specialization is a compile error at CudaOp instantiation — the correct
    // failure mode.
    template <>
    struct cuda_op_impl<float> { ... };

    template <>
    struct cuda_op_impl<__nv_bfloat16> { ... };
}
```

### Rules

- **Complete or absent.** A specialization that throws at runtime for an unimplemented
  method violates the contract. The compile error from a missing specialization is the
  correct diagnostic.
- **No runtime type switching.** The dispatch struct is instantiated at compile time from
  `NativeType`. There are no `if`/`switch` branches on `TensorDataType` at runtime inside
  an op (see the sketch after this list).
- **No state for elementwise ops.** For ops like SwiGLU where the impl carries no
  per-instance data, the struct is empty and compiles to nothing. The layer exists
  for consistency and to accommodate stateful ops (e.g. cuBLASLt plan holders in
  `CudaLinearOp`, `CudaGqaOp`).
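
A minimal sketch of how the CUDA op might resolve and use the dispatch struct at compile
time. The tensor API and the dispatch struct's member names are assumptions; the
resolution path itself follows the chain in Section 2:

```cpp
// Sketch: inside CudaOp.ixx, the abstract precision is resolved to a native type once, at compile time.
template <TensorDataType TPrecision>
class CudaOp
{
    // Abstract → native, taken from the single source of truth (Section 2).
    using NativeType = typename TensorDataTypeMap<TPrecision>::native_type;

    // If no specialization exists for NativeType, this fails to compile here,
    // which is the intended failure mode for an unsupported type.
    using Impl = Detail::cuda_op_impl<NativeType>;

public:
    void forward( const Tensor& input, Tensor& output, cudaStream_t stream ) const
    {
        // No runtime branch on TensorDataType: Impl is fixed at instantiation.
        Impl::forward( output.data<NativeType>(), input.data<NativeType>(),
                       static_cast<int>( output.size() ), stream );
    }
};
```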

---

## 4. Registrar Contract

Each op provides a `CudaOp.Registrar.ixx` containing a `CudaOpRegistrar` class with
a static `registerOperations()` method. One `registerUnaryOperation` (or equivalent)
call per supported `TensorDataType`.
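
A minimal sketch of the shape such a registrar takes. The registry template arguments and
call signature here are illustrative (Section 12 uses the shorthand
`registerUnaryOperation<Cuda, BF16, BF16>`), not the actual Mila API:

```cpp
// Sketch: one registration call per supported TensorDataType.
export class CudaOpRegistrar
{
public:
    static void registerOperations()
    {
        // Each call pairs an abstract precision with a dispatch specialization that
        // must already exist in CudaOp.Dispatch.ixx (Section 3).
        registerUnaryOperation<DeviceType::Cuda, TensorDataType::FP32, TensorDataType::FP32>();
        registerUnaryOperation<DeviceType::Cuda, TensorDataType::BF16, TensorDataType::BF16>();
    }
};
```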

### Rules

- **Registrar and dispatch specialization set must stay in sync.** Registering a type
  that has no dispatch specialization is a compile error. Providing a dispatch
  specialization without a registrar entry is a silent runtime omission — the op
  compiles but is unreachable via the registry.
- **When adding a new type**, both the dispatch specialization and the registrar entry
  must land together.

---

## 5. Memory Layout

Mila uses **contiguous halves** throughout, not an interleaved per-token layout.

For SwiGLU with input `X` of size `2N` and output `Y` of size `N`:

```
X: [ gate_0, gate_1, ..., gate_N-1 | up_0, up_1, ..., up_N-1 ]
    └───────── first half ────────┘ └───── second half ─────┘
Y: [ y_0, y_1, ..., y_N-1 ]
```

**Rationale:** The contiguous-halves layout is easier to reason about, produces simpler
vectorized indexing (no per-element token/col arithmetic), and is consistent with how all
other Mila ops handle split buffers (QKV projection, etc.).

This differs from HuggingFace's interleaved layout, which falls out of fused QKV
projections. Mila uses explicit separate projections, so the contiguous layout is natural.
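
A minimal device-side sketch of the indexing this layout gives (scalar FP32 view for
clarity; `silu` is shorthand for `x * sigmoid(x)`, not an actual Mila helper, and the
`y = silu(gate) * up` convention is an assumption):

```cuda
// Sketch: with contiguous halves, gate and up are two fixed offsets into X.
int i = blockIdx.x * blockDim.x + threadIdx.x;   // output element index in [0, N)
float gate = X[i];                                // first half:  X[0 .. N)
float up   = X[N + i];                            // second half: X[N .. 2N)
Y[i]       = silu( gate ) * up;                   // no per-element token/col arithmetic

// An interleaved per-token layout would instead require recovering (token, col) from i
// before gate and up could be addressed.
```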

**Batch size:** Mila targets B=1 for decode (single-user local inference). Batch > 1
is not a current architectural requirement. The vectorized kernels are correct for B=1
by construction.

---

## 6. Memory Alignment

Every tensor buffer pointer is guaranteed to be aligned at allocation time by
`get_alignment<TDataType, MR>()`:

```cpp
// CUDA alignment = CUDA_WARP_SIZE (32) * sizeof(element)
FP32: 128 bytes (32 floats)   — supports float4 loads
BF16:  64 bytes (32 bfloat16) — supports uint4 loads (8 BF16 per load)
FP16:  64 bytes (32 halves)   — supports uint4 loads
INT8:  32 bytes (32 int8)     — supports int4 loads
```

**Consequence for kernels:** Vectorized loads are unconditional — no scalar prologue,
no scalar epilogue for alignment. The only remainder handling required is for `N` not
being a multiple of the vector width, which is enforced at the op level (see Section 7).
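
A minimal sketch of what `get_alignment<TDataType, MR>()` could compute for CUDA-resident
buffers. Only the signature appears above; the body and the assumption that `MR` selects
the CUDA memory resource are illustrative:

```cpp
// Sketch: one warp's worth of elements per allocation boundary.
template <TensorDataType TDataType, typename MR>
constexpr std::size_t get_alignment()
{
    using Native = typename TensorDataTypeMap<TDataType>::native_type;

    // FP32: 32 * 4 = 128 bytes; BF16/FP16: 32 * 2 = 64 bytes; INT8: 32 * 1 = 32 bytes.
    return CUDA_WARP_SIZE * sizeof( Native );
}
```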

---

## 7. Vectorization

All elementwise CUDA kernels use vectorized loads and stores. The kernel is
unconditionally vectorized — no scalar fallback path.

### Vector widths per type

| Type | Vector type | Elements per thread | Bytes per load |
|--------|--------------|---------------------|----------------|
| FP32 | `float4` | 4 | 16 |
| BF16 | `uint4` | 8 | 16 |
| FP16 | `uint4` | 8 | 16 |

### Exported vector width constants

Each kernel file exports a `constexpr int kOpTypeVectorWidth` constant. The op's
`forward()` validates against this constant rather than a magic number:

```cpp
// Op.Fp32.cu
constexpr int kSwigluFp32VectorWidth = 4;

// Op.Bf16.cu
constexpr int kSwigluBf16VectorWidth = 8;
```

### Op-level validation

The op `forward()` enforces the vector width precondition before launching the kernel:

```cpp
// Example for BF16 SwiGLU — input size must be multiple of 2 * VectorWidth
// (gate half + up half, each must be a multiple of VectorWidth)
if ( input.size() % ( 2 * kSwigluBf16VectorWidth ) != 0 )
{
    throw std::invalid_argument(
        std::format( "CudaSwigluOp: input size must be a multiple of {} for vectorized BF16.",
                     2 * kSwigluBf16VectorWidth )
    );
}
```

### Block size

All forward and backward kernels use **256 threads per block**. Grid size is computed
over the number of vector-width chunks, not scalar elements:

```cpp
int vec_N = N / kVectorWidth;
int grid_size = ( vec_N + 256 - 1 ) / 256;
```
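
Putting the pieces together, a hedged sketch of what the BF16 launcher could look like.
The launcher and kernel names follow the naming pattern in Section 2 but are illustrative,
as is the exact parameter list:

```cuda
// Sketch: Op.Bf16.cu launcher — grid computed over 8-element chunks, 256-thread blocks.
void cuda_swiglu_forward_bf16( __nv_bfloat16* Y, const __nv_bfloat16* X,
                               int N, cudaStream_t stream )
{
    constexpr int kBlockSize = 256;

    // N is the output (half) size; divisibility was already validated by the op's forward().
    int vec_N     = N / kSwigluBf16VectorWidth;             // number of uint4 chunks
    int grid_size = ( vec_N + kBlockSize - 1 ) / kBlockSize;

    swiglu_bf16_forward_kernel<<<grid_size, kBlockSize, 0, stream>>>( Y, X, N );
}
```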

---

## 8. BF16 Arithmetic: FP32 Promotion

BF16 kernels load and store in BF16 but compute in FP32. This applies to both
forward and backward passes.

**Rationale:** Training stability. BF16 has only 7 mantissa bits — insufficient
precision for intermediate values like sigmoid, exp, and gradient chain products.
Promoting to FP32 for arithmetic gives full precision where it matters while
preserving the memory bandwidth and VRAM benefits of BF16 storage. This is
consistent with PyTorch's internal BF16 kernel strategy.

### Pattern for paired BF16 arithmetic

```cuda
// Load 8 BF16 elements as uint4
uint4 packed = reinterpret_cast<const uint4*>(X)[i];

// Reinterpret as four __nv_bfloat162 pairs
__nv_bfloat162 ab = reinterpret_cast<const __nv_bfloat162*>(&packed)[0];
__nv_bfloat162 cd = reinterpret_cast<const __nv_bfloat162*>(&packed)[1];
// ... etc

// Promote to FP32 for arithmetic
float2 ab_f = __bfloat1622float2( ab );
// ... compute in float ...
// Demote back to BF16 for store
__nv_bfloat162 result = __float22bfloat162_rn( result_f );
```
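
For concreteness, a hedged sketch of a complete vectorized BF16 SwiGLU forward kernel
built from this pattern. It assumes the contiguous-halves layout of Section 5 and the
`y = silu(gate) * up` convention; the kernel name matches the illustrative launcher in
Section 7, not necessarily the actual Mila kernel:

```cuda
// Sketch: each thread handles one uint4 chunk — 8 BF16 gate elements and 8 BF16 up elements.
__global__ void swiglu_bf16_forward_kernel( __nv_bfloat16* Y, const __nv_bfloat16* X, int N )
{
    constexpr int kVec = 8;                               // kSwigluBf16VectorWidth
    int chunk = blockIdx.x * blockDim.x + threadIdx.x;
    if ( chunk >= N / kVec ) return;

    // Contiguous halves: gate at X[0 .. N), up at X[N .. 2N).
    uint4 gate_packed = reinterpret_cast<const uint4*>( X )[chunk];
    uint4 up_packed   = reinterpret_cast<const uint4*>( X + N )[chunk];
    uint4 out_packed;

    const __nv_bfloat162* gate2 = reinterpret_cast<const __nv_bfloat162*>( &gate_packed );
    const __nv_bfloat162* up2   = reinterpret_cast<const __nv_bfloat162*>( &up_packed );
    __nv_bfloat162*       out2  = reinterpret_cast<__nv_bfloat162*>( &out_packed );

    #pragma unroll
    for ( int p = 0; p < 4; ++p )
    {
        // Promote each BF16 pair to FP32, compute silu(gate) * up, demote for the store.
        float2 g = __bfloat1622float2( gate2[p] );
        float2 u = __bfloat1622float2( up2[p] );

        float2 r;
        r.x = ( g.x / ( 1.0f + expf( -g.x ) ) ) * u.x;     // silu(g) = g * sigmoid(g)
        r.y = ( g.y / ( 1.0f + expf( -g.y ) ) ) * u.y;

        out2[p] = __float22bfloat162_rn( r );
    }

    reinterpret_cast<uint4*>( Y )[chunk] = out_packed;
}
```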

---

## 9. Mixed-Precision Training: Forward and Backward Tensor Types

### Forward pass

| Input tensor | Output tensor |
|---------------|---------------|
| BF16 | BF16 |

### Backward pass

| dY (upstream gradient) | X (saved activations) | dX (gradient output) |
|------------------------|-----------------------|----------------------|
| FP32 | BF16 | FP32 |

**Rationale:** FP32 gradients are the canonical format at the optimizer boundary.
Both CUDA Adam (on-device) and CPU Adam (offloaded) consume FP32 gradients.
CPU Adam is the practical path for users who offload optimizer state to host RAM
to extend effective VRAM — a first-class use case for Mila.

The backward kernel signature for all BF16 ops follows this pattern:

```cuda
void cuda_op_backward_bf16(
    float* dX,               // FP32 gradient output — optimizer boundary
    const __nv_bfloat16* X,  // BF16 saved forward activations
    const float* dY,         // FP32 upstream gradient
    int N,
    cudaStream_t stream )
```
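
A hedged scalar sketch of the corresponding backward body for SwiGLU, showing the
load-BF16 / compute-FP32 / store-FP32 flow. It assumes `y = silu(gate) * up`, that `X`
holds both halves (size `2N`) while `dY` has size `N`, and that `dX` mirrors `X`; the
real kernel is vectorized as in Section 7:

```cuda
// Sketch: per-element SwiGLU backward — BF16 activations in, FP32 gradients out.
__global__ void swiglu_bf16_backward_kernel( float* dX, const __nv_bfloat16* X,
                                              const float* dY, int N )
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if ( i >= N ) return;

    // Promote saved BF16 activations to FP32 before any arithmetic (Section 8).
    float gate = __bfloat162float( X[i] );
    float up   = __bfloat162float( X[N + i] );
    float dy   = dY[i];

    float sig    = 1.0f / ( 1.0f + expf( -gate ) );
    float silu_g = gate * sig;

    // d(silu)/d(gate) = sig * (1 + gate * (1 - sig))
    dX[i]     = dy * up * sig * ( 1.0f + gate * ( 1.0f - sig ) );  // gradient w.r.t. gate
    dX[N + i] = dy * silu_g;                                       // gradient w.r.t. up
}
```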

---

## 10. Mixed-Precision Training: Weight Strategy

Mila targets the standard Micikevicius et al. (2018) mixed-precision training recipe:

```
FP32 master weights
    → cast to BF16 for forward pass
    → BF16 activations through forward
    → FP32 gradients through backward
    → FP32 Adam optimizer step (CUDA or CPU)
    → update FP32 master weights
    → repeat
```

**Minimum hardware requirement for training:** 16GB VRAM.

**Rationale for FP32 master weights:**
- Adam's first and second moments require FP32 to accumulate small updates correctly
- Weight updates (`w -= lr * grad`) can be vanishingly small relative to `w` in BF16
- FP32 master weights make Mila's training results directly comparable to the literature
- At 16GB the standard recipe is viable for Llama 1B with gradient checkpointing

**CPU Adam offload:** Users with 16GB VRAM and sufficient system RAM can offload
FP32 master weights and Adam moments to CPU, extending effective training capacity
well beyond 1B parameters. This is a first-class supported configuration.
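
A hedged sketch of the per-parameter Adam step this recipe implies, with moments and
master weights in FP32. It follows the standard Adam formulation, not a specific Mila
API, and shows why the update would be lost to BF16 rounding if applied to BF16 weights:

```cpp
#include <cmath>
#include <cstddef>

// Sketch: FP32 master-weight Adam step (works identically on CPU for the offload path).
void adam_step_fp32( float* master_w, float* m, float* v, const float* grad,
                     std::size_t n, float lr, float beta1, float beta2, float eps, int t )
{
    for ( std::size_t i = 0; i < n; ++i )
    {
        m[i] = beta1 * m[i] + ( 1.0f - beta1 ) * grad[i];
        v[i] = beta2 * v[i] + ( 1.0f - beta2 ) * grad[i] * grad[i];

        float m_hat = m[i] / ( 1.0f - std::pow( beta1, t ) );
        float v_hat = v[i] / ( 1.0f - std::pow( beta2, t ) );

        // This delta is often orders of magnitude smaller than master_w[i]; accumulating
        // it in BF16 would round it away, which is why master weights stay FP32.
        master_w[i] -= lr * m_hat / ( std::sqrt( v_hat ) + eps );
    }
    // The BF16 weights used by the forward pass are re-cast from master_w after the step.
}
```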

**`REVIEW:`** Stochastic rounding on BF16 weight updates is a future consideration
for training stability without FP32 master weights. It is not a current kernel concern,
but the optimizer interface must not accidentally rule it out.

---

## 11. cuBLASLt and BF16

For ops using cuBLASLt (`CudaLinearOp`, `CudaGqaOp`):

- Data type: `CUDA_R_16BF`
- Compute type: `CUBLAS_COMPUTE_32F_FAST_16BF`

This is a mixed-precision matmul: data moves in BF16, accumulation is FP32 internally.
`CudaDataTypeMap<__nv_bfloat16>::fp32_compute_type` holds this value.

**Important:** `CudaDataTypeMap<__nv_bfloat16>` has no `compute_type` member
(BF16-native accumulation does not exist in cuBLAS). Plan builders must always use
`fp32_compute_type` for BF16 — never `compute_type`. Any generic plan builder
that falls through to `compute_type` without type checking is a silent bug.
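
A hedged sketch of how a plan builder might apply these choices with the cuBLASLt API.
Error handling is omitted and the helper functions are illustrative; only the data-type
and compute-type selections are the point:

```cpp
#include <cublasLt.h>
#include <cstdint>

// Sketch: BF16 operand data, FP32 accumulation. For __nv_bfloat16 the compute type must
// come from fp32_compute_type; there is no BF16-native compute_type to fall back to.
void build_bf16_matmul_desc( cublasLtMatmulDesc_t* desc )
{
    // CUBLAS_COMPUTE_32F_FAST_16BF: FP32 accumulation with BF16 Tensor Core paths.
    // The scale type (alpha/beta) stays CUDA_R_32F.
    cublasLtMatmulDescCreate( desc, CUBLAS_COMPUTE_32F_FAST_16BF, CUDA_R_32F );
}

void build_bf16_layout( cublasLtMatrixLayout_t* layout,
                        uint64_t rows, uint64_t cols, int64_t ld )
{
    // Operand layouts use CUDA_R_16BF; the data itself moves in BF16.
    cublasLtMatrixLayoutCreate( layout, CUDA_R_16BF, rows, cols, ld );
}
```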

Pre-built cuBLASLt plans are cached per op instance. Plan selection for
mixed-precision BF16 matmuls has a larger heuristic search space than FP32 — cuBLASLt
may select a different optimal algorithm. Caching is therefore more valuable
for BF16 than FP32.

---

## 12. Adding a New Type to an Existing Op — Checklist

When adding support for a new `TensorDataType` (e.g. BF16) to an existing op:

- [ ] `Op.Bf16.cu` — kernel + launcher, exports `kOpBf16VectorWidth`
- [ ] `CudaOp.Dispatch.ixx` — add complete `cuda_op_impl<__nv_bfloat16>` specialization
- [ ] `CudaOp.Registrar.ixx` — add `registerUnaryOperation<Cuda, BF16, BF16>` entry
- [ ] `CudaOp.ixx` `forward()` — add vector width validation against `kOpBf16VectorWidth`
- [ ] Validate kernel output against FP32 reference before enabling vectorization

The first four items must land together. Missing the registrar entry is a silent runtime
omission; missing the dispatch specialization is a compile error.

---

*This document reflects design decisions made through April 2026.*
*Update when new types, ops, or training strategies are added.*