
Commit 2f1123b

Merge pull request #593 from European-XFEL/parallel-decompress-cleanup
Parallel decompression for detector data
2 parents 5967f59 + 2167dcd commit 2f1123b

12 files changed

Lines changed: 324 additions & 18 deletions

.coveragerc

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [run]
 omit = */tests/*
-concurrency = multiprocessing
+concurrency = multiprocessing,thread

 [paths]
 source =
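A note on why this matters for the rest of this PR: setting `concurrency = multiprocessing` replaces coverage.py's default thread tracing, so work done in the `ThreadPool` added below would otherwise go unmeasured; listing `thread` alongside it restores coverage of both.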

.github/workflows/tests.yml

Lines changed: 3 additions & 2 deletions
@@ -40,8 +40,9 @@ jobs:
           MPLBACKEND: agg

       - name: Upload coverage to Codecov
-        # This specific version uses Node 20 - see codecov-action#1293
-        uses: codecov/codecov-action@v3.1.5
+        uses: codecov/codecov-action@v5
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}

   publish:
     runs-on: ubuntu-latest

.readthedocs.yml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 version: 2 # Required

 build:
-  os: ubuntu-20.04
+  os: ubuntu-24.04
   tools:
     python: "3.12"


extra_data/components.py

Lines changed: 66 additions & 3 deletions
@@ -13,6 +13,7 @@
 from .exceptions import SourceNameError
 from .reader import DataCollection, by_id, by_index
 from .read_machinery import DataChunk, roi_shape, split_trains
+from .utils import default_num_threads
 from .writer import FileWriter
 from .write_cxi import XtdfCXIWriter, JUNGFRAUCXIWriter

@@ -1216,7 +1217,56 @@ def _read_chunk(self, chunk: DataChunk, mod_out, roi):
             axis=0, out=mod_out[tgt_pulse_sel]
         )

-    def ndarray(self, *, fill_value=None, out=None, roi=(), astype=None, module_gaps=False):
+    def _read_parallel_decompress(self, out, module_gaps, threads=16):
+        from .compression import multi_dataset_decompressor, parallel_decompress_chunks
+
+        modno_to_keydata_no_virtual = {}
+        all_datasets = []
+        for (m, vkd) in self.modno_to_keydata.items():
+            modno_to_keydata_no_virtual[m] = kd = vkd._without_virtual_overview()
+            all_datasets.extend([f.file[kd.hdf5_data_path] for f in kd.files])
+
+        if any(d.chunks != (1,) + d.shape[1:] for d in all_datasets):
+            return False  # Chunking not as we expect
+
+        decomp_proto = multi_dataset_decompressor(all_datasets)
+        if decomp_proto is None:
+            return False  # No suitable fast decompression path
+
+        load_tasks = []
+        for i, (modno, kd) in enumerate(sorted(modno_to_keydata_no_virtual.items())):
+            mod_ix = (modno - self.det._modnos_start_at) if module_gaps else i
+            # 'chunk' in the lines below means a range of consecutive indices
+            # in one HDF5 dataset, as elsewhere in EXtra-data.
+            # We use this to build a list of HDF5 chunks (1 frame per chunk)
+            # to be loaded & decompressed. Sorry about that.
+            for chunk in kd._data_chunks:
+                dset = chunk.dataset
+
+                for tgt_slice, chunk_slice in self.det._split_align_chunk(
+                    chunk, self.det.train_ids_perframe,
+                ):
+                    inc_pulses_chunk = self._sel_frames[tgt_slice]
+                    if inc_pulses_chunk.sum() == 0:  # No data from this chunk selected
+                        continue
+
+                    dataset_ixs = np.nonzero(inc_pulses_chunk)[0] + chunk_slice.start
+
+                    # Where does this data go in the target array?
+                    tgt_start_ix = self._sel_frames[:tgt_slice.start].sum()
+
+                    # Each task is a h5py.h5d.DatasetID, coordinates & array destination
+                    load_tasks.extend([
+                        (dset.id, (ds_ix, 0, 0), out[mod_ix, tgt_start_ix + i])
+                        for i, ds_ix in enumerate(dataset_ixs)]
+                    )
+
+        parallel_decompress_chunks(load_tasks, decomp_proto, threads=threads)
+
+        return True
+
+    def ndarray(self, *, fill_value=None, out=None, roi=(), astype=None,
+                module_gaps=False, decompress_threads=None):
         """Get an array of per-pulse data (image.*) for xtdf detector"""
         out_shape = self.buffer_shape(module_gaps=module_gaps, roi=roi)

@@ -1226,6 +1276,14 @@ def ndarray(self, *, fill_value=None, out=None, roi=(), astype=None, module_gaps
         elif out.shape != out_shape:
             raise ValueError(f'requires output array of shape {out_shape}')

+        if roi == () and astype is None:
+            if decompress_threads is None:
+                decompress_threads = default_num_threads(fixed_limit=16)
+
+            if decompress_threads > 1:
+                if self._read_parallel_decompress(out, module_gaps, decompress_threads):
+                    return out
+
         reading_view = out.view()
         if self._extraneous_dim:
             reading_view.shape = out.shape[:2] + (1,) + out.shape[2:]
@@ -1252,8 +1310,13 @@ def _wrap_xarray(self, arr, subtrain_index='pulseId'):
         })

     def xarray(self, *, pulses=None, fill_value=None, roi=(), astype=None,
-               subtrain_index='pulseId', unstack_pulses=False):
-        arr = self.ndarray(fill_value=fill_value, roi=roi, astype=astype)
+               subtrain_index='pulseId', unstack_pulses=False, decompress_threads=None):
+        arr = self.ndarray(
+            fill_value=fill_value,
+            roi=roi,
+            astype=astype,
+            decompress_threads=decompress_threads,
+        )
         out = self._wrap_xarray(arr, subtrain_index)

         if unstack_pulses:
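For orientation, the user-visible change above is the new `decompress_threads` argument on `ndarray()` and `xarray()`. A minimal usage sketch, assuming a corrected AGIPD run (the proposal and run numbers here are hypothetical):

    from extra_data import open_run
    from extra_data.components import AGIPD1M

    run = open_run(proposal=700000, run=1, data='proc')  # hypothetical proposal/run
    agipd = AGIPD1M(run)

    # None (the default) picks a thread count via default_num_threads(fixed_limit=16);
    # decompress_threads=1 forces the pre-existing single-threaded reader.
    mask = agipd['image.mask'].ndarray(decompress_threads=8)

Note the guards: the fast path is only attempted when `roi == ()` and `astype is None`, and `_read_parallel_decompress()` returns False, falling back to the ordinary reader, if any dataset is not chunked as exactly one frame per chunk or no suitable decompressor is available.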

extra_data/compression.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+import threading
+from copy import copy
+from multiprocessing.pool import ThreadPool
+
+import h5py
+import numpy as np
+from zlib_into import decompress_into
+
+
+def filter_ids(dset: h5py.Dataset):
+    dcpl = dset.id.get_create_plist()
+    return [dcpl.get_filter(i)[0] for i in range(dcpl.get_nfilters())]
+
+
+class DeflateDecompressor:
+    def __init__(self, deflate_filter_idx=0):
+        self.deflate_filter_bit = 1 << deflate_filter_idx
+
+    @classmethod
+    def for_dataset(cls, dset: h5py.Dataset):
+        filters = filter_ids(dset)
+        if filters == [h5py.h5z.FILTER_DEFLATE]:
+            return cls()
+        if dset.dtype.itemsize == 1 and filters == [
+            h5py.h5z.FILTER_SHUFFLE,
+            h5py.h5z.FILTER_DEFLATE,
+        ]:
+            # The shuffle filter doesn't change single byte values, so we can
+            # skip it.
+            return cls(deflate_filter_idx=1)
+
+        return None
+
+    def clone(self):
+        return copy(self)
+
+    def apply_filters(self, data, filter_mask, out):
+        if filter_mask & self.deflate_filter_bit:
+            # The deflate filter is skipped, so just copy the data
+            memoryview(out)[:] = data
+        else:
+            decompress_into(data, out)
+
+
+class ShuffleDeflateDecompressor:
+    def __init__(self, chunk_shape, dtype):
+        self.chunk_shape = chunk_shape
+        self.dtype = dtype
+        chunk_nbytes = dtype.itemsize
+        for l in chunk_shape:
+            chunk_nbytes *= l
+        # This will hold the decompressed data before shuffling
+        self.chunk_buf = np.zeros(chunk_nbytes, dtype=np.uint8)
+        self.shuffled_view = (                  # E.g. for int32 data with chunks (10, 5):
+            self.chunk_buf                      # (200,) uint8
+                .reshape((dtype.itemsize, -1))  # (4, 50)
+                .transpose()                    # (50, 4)
+        )
+        # Check this is still a view on the buffered data
+        assert self.shuffled_view.base is self.chunk_buf
+
+    @classmethod
+    def for_dataset(cls, dset: h5py.Dataset):
+        if filter_ids(dset) == [h5py.h5z.FILTER_SHUFFLE, h5py.h5z.FILTER_DEFLATE]:
+            return cls(dset.chunks, dset.dtype)
+
+        return None
+
+    def clone(self):
+        return type(self)(self.chunk_shape, self.dtype)
+
+    def apply_filters(self, data, filter_mask, out):
+        if filter_mask & 2:
+            # The deflate filter is skipped
+            memoryview(self.chunk_buf)[:] = data
+        else:
+            decompress_into(data, self.chunk_buf)
+
+        if filter_mask & 1:
+            # The shuffle filter is skipped
+            memoryview(out)[:] = self.chunk_buf
+        else:
+            # Numpy does the shuffling by copying data between views
+            out.reshape((-1, 1)).view(np.uint8)[:] = self.shuffled_view
+
+
+def dataset_decompressor(dset):
+    for cls in [DeflateDecompressor, ShuffleDeflateDecompressor]:
+        if (inst := cls.for_dataset(dset)) is not None:
+            return inst
+
+    return None
+
+
+def multi_dataset_decompressor(dsets):
+    if not dsets:
+        return None
+
+    chunk = dsets[0].chunks
+    dtype = dsets[0].dtype
+    filters = filter_ids(dsets[0])
+    for d in dsets[1:]:
+        if d.chunks != chunk or d.dtype != dtype or filter_ids(d) != filters:
+            return None  # Datasets are not consistent
+
+    return dataset_decompressor(dsets[0])
+
+
+def parallel_decompress_chunks(tasks, decompressor_proto, threads=16):
+    tlocal = threading.local()
+
+    def load_one(dset_id, coord, dest):
+        try:
+            decomp = tlocal.decompressor
+        except AttributeError:
+            tlocal.decompressor = decomp = decompressor_proto.clone()
+
+        if dset_id.get_chunk_info_by_coord(coord).byte_offset is None:
+            return  # Chunk not allocated in file
+
+        filter_mask, compdata = dset_id.read_direct_chunk(coord)
+        decomp.apply_filters(compdata, filter_mask, dest)
+
+    with ThreadPool(threads) as pool:
+        pool.starmap(load_one, tasks)
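Two details above are worth unpacking. `decompress_into()` from the zlib_into package inflates deflate-compressed bytes directly into a caller-supplied buffer, so each worker writes straight into its slice of the output array with no intermediate allocation (and, presumably, releases the GIL while inflating, which is what lets a `ThreadPool` scale here). Separately, the reshape/transpose pair in `ShuffleDeflateDecompressor` inverts HDF5's byte-shuffle filter, which stores the first byte of every element, then all the second bytes, and so on. A self-contained numpy sketch of that inversion (variable names are illustrative, not part of the module):

    import numpy as np

    values = np.arange(50, dtype=np.int32)   # 50 elements of 4 bytes each
    raw = values.view(np.uint8)              # (200,) bytes in element order

    # Byte-shuffle as HDF5 stores it: byte k of every element grouped together
    shuffled = raw.reshape((-1, 4)).transpose().copy().reshape(-1)

    # Un-shuffle the way apply_filters() does: view the shuffled buffer as
    # (itemsize, n_elements), transpose it, and copy back into element order
    out = np.empty_like(values)
    out.reshape((-1, 1)).view(np.uint8)[:] = shuffled.reshape((4, -1)).transpose()

    np.testing.assert_array_equal(out, values)

The final assignment is the `out.reshape((-1, 1)).view(np.uint8)[:] = self.shuffled_view` line from `apply_filters()`, with the `(itemsize, n)` buffer spelled out explicitly.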

extra_data/keydata.py

Lines changed: 15 additions & 0 deletions
@@ -177,6 +177,21 @@ def source_file_paths(self):
         from pathlib import Path
         return [Path(p) for p in paths]

+    def _without_virtual_overview(self):
+        if not self.files[0].file[self.hdf5_data_path].is_virtual:
+            # We're already looking at regular source files
+            return self
+
+        return KeyData(
+            self.source, self.key,
+            train_ids=self.train_ids,
+            files=[FileAccess(p) for p in self.source_file_paths],
+            section=self.section,
+            dtype=self.dtype,
+            eshape=self.entry_shape,
+            inc_suspect_trains=self.inc_suspect_trains,
+        )
+
     def _find_attributes(self, dset):
         """Find Karabo attributes belonging to a dataset."""
         attrs = dict(dset.attrs)
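The virtual-overview detour exists because `read_direct_chunk()` reads physical chunks, and a virtual dataset (as used in EXtra-data's overview files) owns no chunks of its own; it only maps regions of real source datasets, which `source_file_paths` recovers. A self-contained h5py sketch of the distinction (file names are illustrative):

    import h5py
    import numpy as np

    # Build a tiny virtual dataset over one real source file
    with h5py.File('source0.h5', 'w') as f:
        f['data'] = np.arange(10)

    layout = h5py.VirtualLayout(shape=(10,), dtype='i8')
    layout[:] = h5py.VirtualSource('source0.h5', 'data', shape=(10,))
    with h5py.File('overview.h5', 'w') as f:
        f.create_virtual_dataset('data', layout)

    with h5py.File('overview.h5', 'r') as f:
        dset = f['data']
        print(dset.is_virtual)             # True
        for vs in dset.virtual_sources():  # mappings to the real files
            print(vs.file_name, vs.dset_name)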

extra_data/tests/make_examples.py

Lines changed: 5 additions & 0 deletions
@@ -355,6 +355,11 @@ def make_modern_spb_proc_run(dir_path, format_version='1.2'):
                      legacy_name=f'SPB_DET_AGIPD1M-1/DET/{modno}CH0')
     ], ntrains=64, chunksize=32, format_version=format_version)

+    # Ensure one chunk of mask data is actually written
+    with h5py.File(path, 'r+') as f:
+        ds = f['INSTRUMENT/SPB_DET_AGIPD1M-1/CORR/15CH0:output/image/mask']
+        ds[0, 0, 5] = 1
+

 def make_agipd1m_run(
         dir_path,
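Context for this tweak: with HDF5's default incremental allocation, chunks that are never written don't exist in the file at all, and `parallel_decompress_chunks()` skips them via `get_chunk_info_by_coord()`. Writing a single mask value guarantees the mock run contains at least one allocated, compressed chunk to exercise. A self-contained illustration (file and dataset names are hypothetical):

    import h5py
    import numpy as np

    with h5py.File('example.h5', 'w') as f:            # hypothetical file
        ds = f.create_dataset('data', shape=(4, 8, 8), chunks=(1, 8, 8),
                              dtype='u4', compression='gzip')
        ds[1] = np.ones((8, 8), dtype='u4')            # allocate only chunk 1

        print(ds.id.get_chunk_info_by_coord((0, 0, 0)).byte_offset)  # None - never written
        print(ds.id.get_chunk_info_by_coord((1, 0, 0)).byte_offset)  # offset of stored chunk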

extra_data/tests/mockdata/detectors.py

Lines changed: 22 additions & 11 deletions
@@ -38,18 +38,22 @@ def write_control(self, f):
     def image_keys(self):
         if self.raw:
             return [
-                ('data', 'u2', self.image_dims),
-                ('length', 'u4', (1,)),
-                ('status', 'u2', (1,)),
+                ('data', 'u2', self.image_dims, {}),
+                ('length', 'u4', (1,), {}),
+                ('status', 'u2', (1,), {}),
             ]

         else:
             return [
-                ('data', 'f4', self.image_dims),
-                ('mask', 'u4', self.image_dims),
-                ('gain', 'u1', self.image_dims),
-                ('length', 'u4', (1,)),
-                ('status', 'u2', (1,)),
+                ('data', 'f4', self.image_dims, {}),
+                ('mask', 'u4', self.image_dims, {
+                    'compression': 'gzip', 'compression_opts': 1
+                }),
+                ('gain', 'u1', self.image_dims, {
+                    'compression': 'gzip', 'compression_opts': 1
+                }),
+                ('length', 'u4', (1,), {}),
+                ('status', 'u2', (1,), {}),
             ]

     @property
@@ -138,9 +142,16 @@ def write_instrument(self, f):
         )

         max_len = None if self.raw else nframes
-        for (key, datatype, dims) in self.image_keys:
-            f.create_dataset(f'INSTRUMENT/{inst_source}/image/{key}',
-                (nframes,) + dims, datatype, maxshape=((max_len,) + dims))
+        for (key, datatype, dims, kw) in self.image_keys:
+            if dims == self.image_dims and 'chunks' not in kw:
+                kw['chunks'] = (1,) + dims
+            f.create_dataset(
+                f'INSTRUMENT/{inst_source}/image/{key}',
+                shape=(nframes,) + dims,
+                dtype=datatype,
+                maxshape=((max_len,) + dims),
+                **kw
+            )


 # INSTRUMENT (other parts)
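The explicit `chunks=(1,) + dims` is load-bearing: `_read_parallel_decompress()` in components.py bails out unless every dataset stores exactly one frame per HDF5 chunk, so the mock files must match the layout of real corrected detector data. A standalone sketch of that layout in plain h5py (the file name is illustrative):

    import h5py

    with h5py.File('mock_module.h5', 'w') as f:    # hypothetical file
        f.create_dataset(
            'INSTRUMENT/detector/image/mask',      # hypothetical path
            shape=(64, 512, 128),                  # 64 frames
            dtype='u4',
            chunks=(1, 512, 128),                  # one frame per chunk
            compression='gzip', compression_opts=1,
        )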

extra_data/tests/test_components.py

Lines changed: 17 additions & 0 deletions
@@ -592,6 +592,23 @@ def test_modern_corr_sources(mock_modern_spb_proc_run, mock_spb_raw_run_fmt1):
     assert 'image.mask' in agipd_dflt


+def test_decompress_threads(mock_modern_spb_proc_run):
+    run = RunDirectory(mock_modern_spb_proc_run)
+
+    agipd = AGIPD1M(run[:1])
+    # Load a single-threaded reference array first
+    ref = agipd['image.mask'].ndarray(decompress_threads=1)
+    print(ref.shape)
+    print(ref[15, 0, :10, :10])
+    assert ref[15, 0, 0, 0] == 0
+    assert ref[15, 0, 0, 5] == 1
+
+    import h5py
+    h5py._errors.unsilence_errors()
+    arr = agipd['image.mask'].ndarray(decompress_threads=16)
+    np.testing.assert_array_equal(arr, ref)
+
+
 def test_write_virtual_cxi(mock_spb_proc_run, tmpdir):
     run = RunDirectory(mock_spb_proc_run)
     det = AGIPD1M(run)
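The expected values here come from the make_examples.py tweak above: module 15's mask dataset is the one written with a nonzero value, so `ref[15, 0, 0, 5] == 1` confirms a real compressed chunk was read, while `ref[15, 0, 0, 0] == 0` covers the surrounding zeros. Comparing the 16-thread result against a `decompress_threads=1` reference exercises both code paths on the same data; `h5py._errors.unsilence_errors()` is presumably there so any HDF5 error stack raised in the worker threads gets printed rather than swallowed.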
