1313from .exceptions import SourceNameError
1414from .reader import DataCollection , by_id , by_index
1515from .read_machinery import DataChunk , roi_shape , split_trains
16+ from .utils import default_num_threads
1617from .writer import FileWriter
1718from .write_cxi import XtdfCXIWriter , JUNGFRAUCXIWriter
1819
@@ -1216,7 +1217,56 @@ def _read_chunk(self, chunk: DataChunk, mod_out, roi):
12161217 axis = 0 , out = mod_out [tgt_pulse_sel ]
12171218 )
12181219
1219- def ndarray (self , * , fill_value = None , out = None , roi = (), astype = None , module_gaps = False ):
    def _read_parallel_decompress(self, out, module_gaps, threads: int = 16) -> bool:
        """Fast path: load & decompress per-module image data with a thread pool.

        Fills ``out`` in place and returns True on success. Returns False
        without touching ``out`` when the fast path does not apply — i.e. when
        the HDF5 datasets are not chunked as one frame per chunk, or when no
        suitable fast decompressor is found — so the caller can fall back to
        the regular reading path.

        ``out`` must already have the full buffer shape (modules stacked on
        axis 0); ``module_gaps`` selects whether module numbers map to axis-0
        positions directly or are compacted to consecutive indices.
        """
        from .compression import multi_dataset_decompressor, parallel_decompress_chunks

        # Work on the per-file datasets, not the virtual overview layer,
        # since we need direct access to each file's raw HDF5 chunks.
        modno_to_keydata_no_virtual = {}
        all_datasets = []
        for (m, vkd) in self.modno_to_keydata.items():
            modno_to_keydata_no_virtual[m] = kd = vkd._without_virtual_overview()
            all_datasets.extend([f.file[kd.hdf5_data_path] for f in kd.files])

        if any(d.chunks != (1,) + d.shape[1:] for d in all_datasets):
            return False  # Chunking not as we expect

        decomp_proto = multi_dataset_decompressor(all_datasets)
        if decomp_proto is None:
            return False  # No suitable fast decompression path

        load_tasks = []
        for i, (modno, kd) in enumerate(sorted(modno_to_keydata_no_virtual.items())):
            # Axis-0 position of this module in the output array.
            mod_ix = (modno - self.det._modnos_start_at) if module_gaps else i
            # 'chunk' in the lines below means a range of consecutive indices
            # in one HDF5 dataset, as elsewhere in EXtra-data.
            # We use this to build a list of HDF5 chunks (1 frame per chunk)
            # to be loaded & decompressed. Sorry about that.
            for chunk in kd._data_chunks:
                dset = chunk.dataset

                for tgt_slice, chunk_slice in self.det._split_align_chunk(
                    chunk, self.det.train_ids_perframe,
                ):
                    # Boolean frame-selection mask for this aligned span.
                    inc_pulses_chunk = self._sel_frames[tgt_slice]
                    if inc_pulses_chunk.sum() == 0:  # No data from this chunk selected
                        continue

                    # Indices of the selected frames within the HDF5 dataset.
                    dataset_ixs = np.nonzero(inc_pulses_chunk)[0] + chunk_slice.start

                    # Where does this data go in the target array?
                    # (Count of selected frames before this span.)
                    tgt_start_ix = self._sel_frames[:tgt_slice.start].sum()

                    # Each task is a h5py.h5d.DatasetID, coordinates & array destination.
                    # NB: the comprehension's 'i' is scoped to the comprehension,
                    # so it doesn't clobber the enumerate() counter above.
                    load_tasks.extend([
                        (dset.id, (ds_ix, 0, 0), out[mod_ix, tgt_start_ix + i])
                        for i, ds_ix in enumerate(dataset_ixs)]
                    )

        parallel_decompress_chunks(load_tasks, decomp_proto, threads=threads)

        return True
1267+
1268+ def ndarray (self , * , fill_value = None , out = None , roi = (), astype = None ,
1269+ module_gaps = False , decompress_threads = None ):
12201270 """Get an array of per-pulse data (image.*) for xtdf detector"""
12211271 out_shape = self .buffer_shape (module_gaps = module_gaps , roi = roi )
12221272
@@ -1226,6 +1276,14 @@ def ndarray(self, *, fill_value=None, out=None, roi=(), astype=None, module_gaps
12261276 elif out .shape != out_shape :
12271277 raise ValueError (f'requires output array of shape { out_shape } ' )
12281278
1279+ if roi == () and astype is None :
1280+ if decompress_threads is None :
1281+ decompress_threads = default_num_threads (fixed_limit = 16 )
1282+
1283+ if decompress_threads > 1 :
1284+ if self ._read_parallel_decompress (out , module_gaps , decompress_threads ):
1285+ return out
1286+
12291287 reading_view = out .view ()
12301288 if self ._extraneous_dim :
12311289 reading_view .shape = out .shape [:2 ] + (1 ,) + out .shape [2 :]
@@ -1252,8 +1310,13 @@ def _wrap_xarray(self, arr, subtrain_index='pulseId'):
12521310 })
12531311
12541312 def xarray (self , * , pulses = None , fill_value = None , roi = (), astype = None ,
1255- subtrain_index = 'pulseId' , unstack_pulses = False ):
1256- arr = self .ndarray (fill_value = fill_value , roi = roi , astype = astype )
1313+ subtrain_index = 'pulseId' , unstack_pulses = False , decompress_threads = None ):
1314+ arr = self .ndarray (
1315+ fill_value = fill_value ,
1316+ roi = roi ,
1317+ astype = astype ,
1318+ decompress_threads = decompress_threads ,
1319+ )
12571320 out = self ._wrap_xarray (arr , subtrain_index )
12581321
12591322 if unstack_pulses :
0 commit comments