1515use std:: borrow:: Cow ;
1616use std:: collections:: HashMap ;
1717use std:: fmt:: Write ;
18- use std:: io:: ErrorKind ;
18+ use std:: io:: { ErrorKind , Write as _ } ;
1919use std:: path:: { Path , PathBuf } ;
2020use std:: sync:: Arc ;
2121
@@ -30,6 +30,7 @@ use ort::session::builder::GraphOptimizationLevel;
3030use serde:: Serialize ;
3131use tokio:: fs:: File ;
3232use tokio:: io:: AsyncReadExt ;
33+ use tokio:: task:: JoinSet ;
3334
3435/// Determines file content types using AI.
3536#[ derive( Parser ) ]
@@ -158,6 +159,13 @@ struct Experimental {
158159
159160#[ tokio:: main]
160161async fn main ( ) -> Result < ( ) > {
162+ let mut tasks = JoinSet :: new ( ) ;
163+ let result = start ( & mut tasks) . await ;
164+ while tasks. join_next ( ) . await . is_some ( ) { }
165+ result
166+ }
167+
168+ async fn start ( tasks : & mut JoinSet < ( ) > ) -> Result < ( ) > {
161169 let mut flags = Flags :: parse ( ) ;
162170 ensure ! ( 0 < flags. experimental. batch_size, "--batch-size cannot be zero" ) ;
163171 // If --num-tasks is set, we don't do any guessing.
@@ -183,33 +191,49 @@ async fn main() -> Result<()> {
183191 if flags. colors . disable {
184192 colored:: control:: set_override ( false ) ;
185193 }
186- let ( result_sender, mut result_receiver) =
194+ let ( result_sender, result_receiver) =
187195 tokio:: sync:: mpsc:: channel :: < Result < Response > > ( num_tasks * flags. experimental . batch_size ) ;
188196 let ( batch_sender, batch_receiver) = async_channel:: bounded :: < Batch > ( num_tasks) ;
189- tokio :: spawn ( {
197+ tasks . spawn ( {
190198 let flags = flags. clone ( ) ;
191199 let result_sender = result_sender. clone ( ) ;
192200 async move {
193201 if let Err ( e) = extract_features ( & flags, & batch_sender, & result_sender) . await {
194- result_sender. send ( Err ( e) ) . await . unwrap ( ) ;
202+ let _ = result_sender. send ( Err ( e) ) . await ;
195203 }
196204 }
197205 } ) ;
198206 for _ in 0 ..num_tasks {
199207 let mut magika = build_session ( & flags) ?;
200- tokio :: spawn ( {
208+ tasks . spawn ( {
201209 let batch_receiver = batch_receiver. clone ( ) ;
202210 let result_sender = result_sender. clone ( ) ;
203211 async move {
204212 if let Err ( e) = infer_batch ( & mut magika, & batch_receiver, & result_sender) . await {
205- result_sender. send ( Err ( e) ) . await . unwrap ( ) ;
213+ let _ = result_sender. send ( Err ( e) ) . await ;
206214 }
207215 }
208216 } ) ;
209217 }
210218 drop ( result_sender) ;
219+ match print ( & flags, result_receiver) . await {
220+ Err ( e)
221+ if e. root_cause ( )
222+ . downcast_ref :: < std:: io:: Error > ( )
223+ . is_some_and ( |x| x. kind ( ) == std:: io:: ErrorKind :: BrokenPipe ) =>
224+ {
225+ Ok ( ( ) )
226+ }
227+ x => x,
228+ }
229+ }
230+
231+ async fn print (
232+ flags : & Flags , mut result_receiver : tokio:: sync:: mpsc:: Receiver < Result < Response > > ,
233+ ) -> Result < ( ) > {
234+ let mut stdout = std:: io:: stdout ( ) . lock ( ) ;
211235 if flags. format . json {
212- print ! ( "[" ) ;
236+ write ! ( stdout , "[" ) ? ;
213237 }
214238 let mut reorder = Reorder :: default ( ) ;
215239 let mut errors = false ;
@@ -219,22 +243,22 @@ async fn main() -> Result<()> {
219243 errors |= response. result . is_err ( ) ;
220244 if flags. format . json {
221245 if reorder. next != 1 {
222- print ! ( "," ) ;
246+ write ! ( stdout , "," ) ? ;
223247 }
224248 for line in serde_json:: to_string_pretty ( & response. json ( ) ?) ?. lines ( ) {
225- print ! ( "\n {line}" ) ;
249+ write ! ( stdout , "\n {line}" ) ? ;
226250 }
227251 } else {
228- println ! ( "{}" , response. format( & flags) ?) ;
252+ writeln ! ( stdout , "{}" , response. format( flags) ?) ? ;
229253 }
230254 }
231255 }
232256 debug_assert ! ( reorder. is_empty( ) ) ;
233257 if flags. format . json {
234258 if reorder. next != 0 {
235- println ! ( ) ;
259+ writeln ! ( stdout ) ? ;
236260 }
237- println ! ( "]" ) ;
261+ writeln ! ( stdout , "]" ) ? ;
238262 }
239263 if errors {
240264 std:: process:: exit ( 1 ) ;
0 commit comments