@@ -15,6 +15,8 @@ module;
 #include <format>
 #include <memory>
 #include <stdexcept>
+#include <future>
+#include <stop_token>

 export module Mila.Chat;
 export import Chat.Config;
@@ -25,7 +27,6 @@ namespace Mila::ChatApp
     using namespace Mila::Dnn;
     using namespace Mila::Dnn::Compute;
     using namespace Mila::Data;
-    using namespace Mila::Data;

     using LanguageModelType = LanguageModel<DeviceType::Cuda, TensorDataType::FP32>;

@@ -84,11 +85,34 @@ namespace Mila::ChatApp

                 conversation_history.push_back ( "User: " + user_input );

-                std::string response = generateResponse ( conversation_history );
+                const std::string& prompt = conversation_history.back ().substr ( 6 ); // strip "User: "
+                std::vector<TokenId> prompt_tokens = tokenizer_->encode ( prompt );
+                std::vector<int32_t> input_tokens ( prompt_tokens.begin (), prompt_tokens.end () );
+
+                std::string response;
+                response.reserve ( 512 );
+
+                std::cout << "\nMila: ";

-                conversation_history.push_back ( "Mila: " + response );
+                stop_src_ = std::stop_source{};

-                std::cout << "\nMila: " << response << "\n";
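+                // Stream tokens as they arrive: decode each one, append it to response, and echo it immediately.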
+                auto fut = model_->generateAsync (
+                    input_tokens,
+                    [&]( int32_t tok )
+                    {
+                        auto text = tokenizer_->decode ( std::vector<TokenId>{ static_cast<TokenId>(tok) } );
+                        response += text;
+                        std::cout << text << std::flush;
+                    },
+                    config_.max_new_tokens,
+                    config_.temperature,
+                    config_.top_k,
+                    stop_src_.get_token () );
+
+                fut.wait ();
+                std::cout << '\n';
+
+                conversation_history.push_back ( "Mila: " + trimResponse ( response ) );
             }
         }

@@ -123,92 +147,56 @@ namespace Mila::ChatApp

         void loadModel ()
         {
-            // try
-            // {
-                std::cout << "Loading model from: " << config_.model_path << "\n";
+            std::cout << "Loading model from: " << config_.model_path << "\n";

-                switch ( config_.model_type )
-                {
-                    case ModelType::Gpt:
-                        model_ = GptModel<DeviceType::Cuda, TensorDataType::FP32>::fromPretrained (
-                            config_.model_path,
-                            config_.context_length,
-                            DeviceId{ DeviceType::Cuda, 0 },
-                            /*strict=*/ true );
-                        break;
-
-                    case ModelType::Llama:
-                        model_ = LlamaModel<DeviceType::Cuda, TensorDataType::FP32>::fromPretrained (
-                            config_.model_path,
-                            config_.context_length,
-                            DeviceId{ DeviceType::Cuda, 0 },
-                            /*strict=*/ true );
-                        break;
-                }
-
-                std::cout << model_->toString ();
-
-                auto stats = model_->getMemoryStats ();
-                std::cout << stats.toString () << "\n";
-
-                std::cout << "Model loaded successfully!\n";
-            // }
-            // catch ( const std::exception& e )
-            // {
-            //     std::cerr << "Error loading model: " << e.what() << "\n";
-            //     throw;
-            // }
-        }
-
-        std::string generateResponse ( const std::vector<std::string>& history )
-        {
-            /* try
-            {*/
-            if ( !tokenizer_ )
-                return "Tokenizer not loaded.";
-
-            // Both GPT-2 and LLaMA base models are completion models; pass the raw
-            // user text without a chat template to avoid instruction-format mismatch.
-            const std::string& prompt = history.back ().substr ( 6 ); // strip "User: "
-
-            std::vector<TokenId> prompt_tokens = tokenizer_->encode ( prompt );
+            switch ( config_.model_type )
+            {
+                case ModelType::Gpt:
+                    model_ = GptModel<DeviceType::Cuda, TensorDataType::FP32>::fromPretrained (
+                        config_.model_path,
+                        config_.context_length,
+                        DeviceId{ DeviceType::Cuda, 0 },
+                        /*strict=*/ true );
+                    break;

-            std::vector<int32_t> input_tokens ( prompt_tokens.begin (), prompt_tokens.end () );
+                case ModelType::Llama:
+                    model_ = LlamaModel<DeviceType::Cuda, TensorDataType::FP32>::fromPretrained (
+                        config_.model_path,
+                        config_.context_length,
+                        DeviceId{ DeviceType::Cuda, 0 },
+                        /*strict=*/ true );
+                    break;
+            }

-            std::vector<int32_t> generated = model_->generate (
-                std::vector<int32_t>( input_tokens ),
-                config_.max_new_tokens,
-                config_.temperature,
-                config_.top_k );
+            std::cout << model_->toString ();

-            std::string full_text = tokenizer_->decode ( std::vector<TokenId>( generated.begin (), generated.end () ) );
+            auto stats = model_->getMemoryStats ();
+            std::cout << stats.toString () << "\n";

-            return extractResponse ( full_text, prompt );
-            /* }
-            catch ( const std::exception& e )
-            {
-                return "Error: " + std::string( e.what() );
-            }*/
+            std::cout << "Model loaded successfully!\n";
         }

-        std::string extractResponse (
-            const std::string& full_output,
-            const std::string& prompt ) const
+        /**
+         * @brief Strip leading whitespace and truncate at the first paragraph break.
+         *
+         * Applied to the accumulated streaming response before storing in history.
+         * The live printed output is unaffected.
+         */
+        std::string trimResponse ( const std::string& raw ) const
         {
-            if ( full_output.size () <= prompt.size () )
-                return full_output;
+            auto start = raw.find_first_not_of ( " \t\n\r" );
+
+            if ( start == std::string::npos )
+                return {};

-            std::string response = full_output.substr ( prompt.size () );
+            std::string result = raw.substr ( start );

-            auto start = response.find_first_not_of ( " \t\n\r" );
-            if ( start != std::string::npos )
-                response = response.substr ( start );
+            auto end = result.find ( "\n\n" );

-            auto end = response.find ( "\n\n" );
             if ( end != std::string::npos )
-                response = response.substr ( 0, end );
+                result.resize ( end );

-            return response;
+            return result;
         }

         void printWelcome () const
@@ -243,5 +231,6 @@ Just type your message to chat with Mila AI.
         ChatConfig config_;
         std::unique_ptr<LanguageModelType> model_;
         std::shared_ptr<BpeTokenizer> tokenizer_{ nullptr };
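+        // Stop source whose token is passed to generateAsync(), allowing an in-flight generation to be cancelled.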
+        std::stop_source stop_src_;
     };
 }