Commit 491c15c
Version: 0.11.1-alpha.3
BF16 backend WIP: full CUDA support for ops

- Adds BF16 CUDA kernels for TokenEmbedding and RoPE; removes all FP16 code
- Refactors LanguageModel API: streaming/callback-based generation, async support
- Updates Chat CLI to stream tokens live and support cancellation
- Declares BF16 as primary reduced-precision target; updates docs and build

BREAKING CHANGE: FP16 is no longer supported; all reduced-precision ops use BF16. Generation API is now streaming/callback-based.

1 parent bb3133a

21 files changed: 1123 additions & 3437 deletions
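The BREAKING CHANGE is easiest to read against the Chat sample diff below, the only call site shown in this view. Reconstructed from that call site alone (the actual declaration lives in Src/Dnn/Core/LanguageModel.ixx, which is not among the diffs displayed), the new streaming entry point plausibly looks like the sketch below; the return type, parameter types, and callback signature beyond what the caller passes are assumptions:

// Hypothetical shape of the new generation API, inferred from the Chat call
// site. Only the name generateAsync, the argument order, and the use of a
// std::stop_token are visible in this commit; the rest is a guess.
std::future<void> generateAsync(
    std::vector<int32_t> tokens,              // prompt token ids
    std::function<void( int32_t )> on_token,  // invoked once per generated token
    std::size_t max_new_tokens,
    float temperature,
    int top_k,
    std::stop_token stop );                   // cooperative cancellation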

Mila/CMakeLists.txt

Lines changed: 4 additions & 3 deletions
@@ -10,13 +10,14 @@ add_library( Mila STATIC
     # Cuda Operations Kernels
     #--------------------------------------------------------------------------
     "Src/Dnn/Compute/Devices/Cuda/Operations/Embeddings/Kernels/TokenEmbedding.cuh"
-    "Src/Dnn/Compute/Devices/Cuda/Operations/Embeddings/Kernels/TokenEmbedding.cu"
+    "Src/Dnn/Compute/Devices/Cuda/Operations/Embeddings/Kernels/TokenEmbedding.Fp32.cu"
+    "Src/Dnn/Compute/Devices/Cuda/Operations/Embeddings/Kernels/TokenEmbedding.Bf16.cu"
 
     "Src/Dnn/Compute/Devices/Cuda/Operations/Encodings/Lpe/Kernels/Lpe.Fp32.cu"
     "Src/Dnn/Compute/Devices/Cuda/Operations/Encodings/Lpe/Kernels/Lpe.Fp16.cu"
     "Src/Dnn/Compute/Devices/Cuda/Operations/Encodings/Lpe/Kernels/Lpe.cuh"
     "Src/Dnn/Compute/Devices/Cuda/Operations/Encodings/Rope/Kernels/Rope.Fp32.cu"
-    #"Src/Dnn/Compute/Devices/Cuda/Operations/Encodings/Rope/Kernels/Rope.Fp16.cu"
+    "Src/Dnn/Compute/Devices/Cuda/Operations/Encodings/Rope/Kernels/Rope.Bf16.cu"
     "Src/Dnn/Compute/Devices/Cuda/Operations/Encodings/Rope/Kernels/Rope.cuh"
 
     "Src/Dnn/Compute/Devices/Cuda/Operations/Attention/Common/Kernels/CudaAttention.cuh"

@@ -446,7 +447,7 @@ PUBLIC
     "Src/Dnn/Components/Attention/GQA/GroupedQueryAttention.Config.ixx"
 
     "Src/Dnn/Compute/Operations/PairedOperation.ixx"
-    "Src/Dnn/Core/Component.MemoryStats.ixx" "Src/Dnn/Core/Model.RuntimeMode.ixx" "Src/Dnn/Core/Model.ixx" "Src/Dnn/Core/LanguageModel.ixx" "Src/Dnn/Core/Comonent.TrainingMode.ixx" "Src/Dnn/Compute/Operations/IPositionalPairedOp.ixx" "Src/Dnn/Tensors/Operations/TensorOps.Random.ixx" "Src/Dnn/Tensors/Operations/TensorOps.Structural.ixx" "Src/Dnn/Compute/Operations/IPositionalDecode.ixx" "Src/Dnn/Compute/Operations/IKvInference.ixx" "Src/Dnn/Compute/Operations/IPackedKvInference.ixx")
+    "Src/Dnn/Core/Component.MemoryStats.ixx" "Src/Dnn/Core/Model.RuntimeMode.ixx" "Src/Dnn/Core/Model.ixx" "Src/Dnn/Core/LanguageModel.ixx" "Src/Dnn/Core/Comonent.TrainingMode.ixx" "Src/Dnn/Compute/Operations/IPositionalPairedOp.ixx" "Src/Dnn/Tensors/Operations/TensorOps.Random.ixx" "Src/Dnn/Tensors/Operations/TensorOps.Structural.ixx" "Src/Dnn/Compute/Operations/IPositionalDecode.ixx" "Src/Dnn/Compute/Operations/IKvInference.ixx" "Src/Dnn/Compute/Operations/IPackedKvInference.ixx" "Src/Dnn/Core/TokenStreamer.ixx")
 
 set(MILA_INSTALL_FILE_SET_ARGS FILE_SET module_files DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/mila/modules)

Mila/Samples/Chat/Src/Chat.ixx

Lines changed: 66 additions & 77 deletions
@@ -15,6 +15,8 @@ module;
 #include <format>
 #include <memory>
 #include <stdexcept>
+#include <future>
+#include <stop_token>
 
 export module Mila.Chat;
 export import Chat.Config;

@@ -25,7 +27,6 @@ namespace Mila::ChatApp
     using namespace Mila::Dnn;
     using namespace Mila::Dnn::Compute;
     using namespace Mila::Data;
-    using namespace Mila::Data;
 
     using LanguageModelType = LanguageModel<DeviceType::Cuda, TensorDataType::FP32>;
 

@@ -84,11 +85,34 @@ namespace Mila::ChatApp
 
             conversation_history.push_back( "User: " + user_input );
 
-            std::string response = generateResponse( conversation_history );
+            const std::string& prompt = conversation_history.back().substr( 6 );
+            std::vector<TokenId> prompt_tokens = tokenizer_->encode( prompt );
+            std::vector<int32_t> input_tokens( prompt_tokens.begin(), prompt_tokens.end() );
+
+            std::string response;
+            response.reserve( 512 );
+
+            std::cout << "\nMila: ";
 
-            conversation_history.push_back( "Mila: " + response );
+            stop_src_ = std::stop_source{};
 
-            std::cout << "\nMila: " << response << "\n";
+            auto fut = model_->generateAsync(
+                input_tokens,
+                [&]( int32_t tok )
+                {
+                    auto text = tokenizer_->decode( std::vector<TokenId>{ static_cast<TokenId>(tok) } );
+                    response += text;
+                    std::cout << text << std::flush;
+                },
+                config_.max_new_tokens,
+                config_.temperature,
+                config_.top_k,
+                stop_src_.get_token() );
+
+            fut.wait();
+            std::cout << '\n';
+
+            conversation_history.push_back( "Mila: " + trimResponse( response ) );
         }
     }
 

@@ -123,92 +147,56 @@ namespace Mila::ChatApp
 
     void loadModel()
     {
-        //try
-        //{
-        std::cout << "Loading model from: " << config_.model_path << "\n";
+        std::cout << "Loading model from: " << config_.model_path << "\n";
 
-        switch ( config_.model_type )
-        {
-        case ModelType::Gpt:
-            model_ = GptModel<DeviceType::Cuda, TensorDataType::FP32>::fromPretrained(
-                config_.model_path,
-                config_.context_length,
-                DeviceId{ DeviceType::Cuda, 0 },
-                /*strict=*/true );
-            break;
-
-        case ModelType::Llama:
-            model_ = LlamaModel<DeviceType::Cuda, TensorDataType::FP32>::fromPretrained(
-                config_.model_path,
-                config_.context_length,
-                DeviceId{ DeviceType::Cuda, 0 },
-                /*strict=*/true );
-            break;
-        }
-
-        std::cout << model_->toString();
-
-        auto stats = model_->getMemoryStats();
-        std::cout << stats.toString() << "\n";
-
-        std::cout << "Model loaded successfully!\n";
-        //}
-        //catch ( const std::exception& e )
-        //{
-        //    std::cerr << "Error loading model: " << e.what() << "\n";
-        //    throw;
-        //}
-    }
-
-    std::string generateResponse( const std::vector<std::string>& history )
-    {
-        /*try
-        {*/
-        if ( !tokenizer_ )
-            return "Tokenizer not loaded.";
-
-        // Both GPT-2 and LLaMA base models are completion models; pass the raw
-        // user text without a chat template to avoid instruction-format mismatch.
-        const std::string& prompt = history.back().substr( 6 ); // strip "User: "
-
-        std::vector<TokenId> prompt_tokens = tokenizer_->encode( prompt );
+        switch ( config_.model_type )
+        {
+        case ModelType::Gpt:
+            model_ = GptModel<DeviceType::Cuda, TensorDataType::FP32>::fromPretrained(
+                config_.model_path,
+                config_.context_length,
+                DeviceId{ DeviceType::Cuda, 0 },
+                /*strict=*/true );
+            break;
 
-        std::vector<int32_t> input_tokens( prompt_tokens.begin(), prompt_tokens.end() );
+        case ModelType::Llama:
+            model_ = LlamaModel<DeviceType::Cuda, TensorDataType::FP32>::fromPretrained(
+                config_.model_path,
+                config_.context_length,
+                DeviceId{ DeviceType::Cuda, 0 },
+                /*strict=*/true );
+            break;
+        }
 
-        std::vector<int32_t> generated = model_->generate(
-            std::vector<int32_t>( input_tokens ),
-            config_.max_new_tokens,
-            config_.temperature,
-            config_.top_k );
+        std::cout << model_->toString();
 
-        std::string full_text = tokenizer_->decode( std::vector<TokenId>( generated.begin(), generated.end() ) );
+        auto stats = model_->getMemoryStats();
+        std::cout << stats.toString() << "\n";
 
-        return extractResponse( full_text, prompt );
-        /*}
-        catch ( const std::exception& e )
-        {
-            return "Error: " + std::string( e.what() );
-        }*/
+        std::cout << "Model loaded successfully!\n";
     }
 
-    std::string extractResponse(
-        const std::string& full_output,
-        const std::string& prompt ) const
+    /**
+     * @brief Strip leading whitespace and truncate at the first paragraph break.
+     *
+     * Applied to the accumulated streaming response before storing in history.
+     * The live printed output is unaffected.
+     */
+    std::string trimResponse( const std::string& raw ) const
    {
-        if ( full_output.size() <= prompt.size() )
-            return full_output;
+        auto start = raw.find_first_not_of( " \t\n\r" );
+
+        if ( start == std::string::npos )
+            return {};
 
-        std::string response = full_output.substr( prompt.size() );
+        std::string result = raw.substr( start );
 
-        auto start = response.find_first_not_of( " \t\n\r" );
-        if ( start != std::string::npos )
-            response = response.substr( start );
+        auto end = result.find( "\n\n" );
 
-        auto end = response.find( "\n\n" );
        if ( end != std::string::npos )
-            response = response.substr( 0, end );
+            result.resize( end );
 
-        return response;
+        return result;
    }
 
    void printWelcome() const

@@ -243,5 +231,6 @@ Just type your message to chat with Mila AI.
        ChatConfig config_;
        std::unique_ptr<LanguageModelType> model_;
        std::shared_ptr<BpeTokenizer> tokenizer_{ nullptr };
+       std::stop_source stop_src_;
    };
}
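The new stop_src_ member is reset at the start of each turn and its token handed to generateAsync, so any thread can cancel an in-flight reply. As a minimal sketch of how that wiring could be exercised (not part of this commit; it assumes the generation loop polls stop_requested() between tokens):

#include <chrono>
#include <future>
#include <stop_token>
#include <thread>

// Watchdog pattern: request_stop() fires if generation outlives the deadline;
// fut becomes ready as soon as the model's loop observes the stop request.
void waitWithDeadline( std::future<void>& fut, std::stop_source& src,
                       std::chrono::seconds limit )
{
    std::jthread watchdog( [&]
    {
        // Returns early if generation finishes before the deadline.
        if ( fut.wait_for( limit ) == std::future_status::timeout )
            src.request_stop();
    } );

    fut.wait();  // unblocks on normal completion or after cancellation
}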

Mila/Src/Dnn/Compute/Devices/Cuda/Operations/Embeddings/CudaTokenEmbeddingOp.Dispatch.ixx

Lines changed: 9 additions & 9 deletions
@@ -15,7 +15,7 @@ export module Compute.CudaTokenEmbeddingOp:Dispatch;
 namespace Mila::Dnn::Compute::Cuda::TokenEmbedding::Detail
 {
     template <typename TNative>
-        requires std::is_same_v<TNative, float> || std::is_same_v<TNative, half>
+        requires std::is_same_v<TNative, float> || std::is_same_v<TNative, __nv_bfloat16>
     struct cuda_token_embedding_impl;
 
     // ========================================================================

@@ -48,31 +48,31 @@ namespace Mila::Dnn::Compute::Cuda::TokenEmbedding::Detail
     };
 
     // ========================================================================
-    // FP16 (stubs)
+    // BF16
     // ========================================================================
 
     template <>
-    struct cuda_token_embedding_impl<half>
+    struct cuda_token_embedding_impl<__nv_bfloat16>
     {
         static void forward(
-            half* Y, const int* X, const half* wte,
+            __nv_bfloat16* Y, const int* X, const __nv_bfloat16* wte,
             int B, int T, int C, cudaStream_t stream )
         {
-            // TODO: cuda_token_embedding_forward_fp16(...)
+            cuda_token_embedding_forward_bf16( Y, X, wte, B, T, C, stream );
         }
 
         static void backward(
-            half* dwte, const half* dY, const int* X,
+            __nv_bfloat16* dwte, const __nv_bfloat16* dY, const int* X,
             int B, int T, int C, cudaStream_t stream )
         {
-            // TODO: cuda_token_embedding_backward_fp16(...)
+            cuda_token_embedding_backward_bf16( dwte, dY, X, B, T, C, stream );
         }
 
         static void decode(
-            half* Y, const int* X, const half* wte,
+            __nv_bfloat16* Y, const int* X, const __nv_bfloat16* wte,
             int B, int C, cudaStream_t stream )
         {
-            // TODO: cuda_token_embedding_decode_fp16(...)
+            cuda_token_embedding_decode_bf16( Y, X, wte, B, C, stream );
         }
     };
 }
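The three *_bf16 entry points dispatched above are implemented in TokenEmbedding.Bf16.cu, which this commit adds but which is not among the diffs displayed. For orientation, a minimal sketch of what the forward gather typically looks like; the kernel body and launch geometry here are assumptions, not the shipped code:

#include <cuda_bf16.h>

// One thread per output element: Y[b,t,:] = wte[X[b,t],:] with shapes (B,T,C).
__global__ void token_embedding_forward_bf16_kernel(
    __nv_bfloat16* Y, const int* X, const __nv_bfloat16* wte,
    int B, int T, int C )
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if ( idx >= B * T * C ) return;

    int c  = idx % C;   // channel within the embedding row
    int bt = idx / C;   // flattened (batch, position) index
    Y[idx] = wte[X[bt] * C + c];
}

void cuda_token_embedding_forward_bf16(
    __nv_bfloat16* Y, const int* X, const __nv_bfloat16* wte,
    int B, int T, int C, cudaStream_t stream )
{
    int total   = B * T * C;
    int threads = 256;
    int blocks  = ( total + threads - 1 ) / threads;
    token_embedding_forward_bf16_kernel<<<blocks, threads, 0, stream>>>( Y, X, wte, B, T, C );
}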

Mila/Src/Dnn/Compute/Devices/Cuda/Operations/Embeddings/CudaTokenEmbeddingOp.ixx

Lines changed: 2 additions & 2 deletions
@@ -280,8 +280,8 @@ namespace Mila::Dnn::Compute::Cuda::TokenEmbedding
             TensorDataType::INT32, TensorDataType::FP32>( "TokenEmbeddingOp" );
 
         registerUnaryOpType<DeviceType::Cuda,
-            CudaTokenEmbeddingOp<TensorDataType::INT32, TensorDataType::FP16>,
-            TensorDataType::INT32, TensorDataType::FP16>( "TokenEmbeddingOp" );
+            CudaTokenEmbeddingOp<TensorDataType::INT32, TensorDataType::BF16>,
+            TensorDataType::INT32, TensorDataType::BF16>( "TokenEmbeddingOp" );
     }
 };
 }
