Skip to content

Commit 0106c79

Browse files
author
Dpbm
committed
added Images class
1 parent 616a21c commit 0106c79

4 files changed

Lines changed: 148 additions & 257 deletions

File tree

dataset.py

Lines changed: 0 additions & 232 deletions
Original file line numberDiff line numberDiff line change
@@ -144,201 +144,6 @@ def __str__(self) -> str:
144144
)
145145

146146

147-
class CircuitResult(TypedDict):
148-
"""Type for circuit results"""
149-
150-
index: int
151-
depth: int
152-
file: str
153-
measurements: str # JSON string
154-
result: str # JSON string
155-
hash: str
156-
157-
158-
def get_circuit_results(qc: QuantumCircuit, sampler: Sampler, shots: int) -> Dist:
159-
"""Execute circuit on sampler. Returns its quasi dist"""
160-
return sampler.run([qc], shots=shots).result().quasi_dists[0] # type: ignore
161-
162-
163-
def fix_dist_gaps(dist: Dist, states: States):
164-
"""Auxiliary function to fill the remaining bitstrings with 0"""
165-
for state in states:
166-
result_value = dist.get(state)
167-
if result_value is None:
168-
dist[state] = 0
169-
170-
171-
def generate_circuit_images(
172-
base_index: int,
173-
states: States,
174-
measurements: MeasurementsCombinations,
175-
base_image_path: FilePath,
176-
n_qubits: int,
177-
total_gates: int,
178-
shots: int,
179-
) -> List[CircuitResult]:
180-
"""
181-
Run an experiment, save its images and return its results for different combinations
182-
of measurements.
183-
"""
184-
185-
sim = AerSimulator()
186-
pm = generate_preset_pass_manager(backend=sim, optimization_level=0)
187-
sampler = Sampler()
188-
results: List[CircuitResult] = []
189-
190-
# non-interactive backend
191-
matplotlib.use("Agg")
192-
193-
qc = get_random_circuit(n_qubits, total_gates)
194-
195-
for index, measurement in enumerate(measurements):
196-
image_index = base_index + index
197-
image_path = os.path.join(base_image_path, "%d.png" % image_index)
198-
199-
qc_copy = qc.copy()
200-
total_measurements = len(measurement)
201-
classical_register = ClassicalRegister(total_measurements, name="c")
202-
qc_copy.add_register(classical_register)
203-
qc_copy.measure(measurement, list(range(total_measurements)))
204-
205-
drawing = qc_copy.draw("mpl", filename=image_path, fold=-1, scale=SCALE_CIRCUIT_SIZE)
206-
plt.close(drawing)
207-
del drawing
208-
209-
depth = qc_copy.depth()
210-
isa_qc = pm.run(qc_copy)
211-
del qc_copy
212-
213-
with open(image_path, "rb") as file:
214-
file_hash = hashlib.md5(file.read()).hexdigest()
215-
216-
outcomes = get_circuit_results(isa_qc, sampler, shots)
217-
fix_dist_gaps(outcomes, states)
218-
219-
del isa_qc
220-
gc.collect()
221-
222-
# once we have more than a few combinations, depending on how many threads we
223-
# start, it can use a lot of memory. It also depends on how many states are possible, growing
224-
# exponentially with the number of qubits (2^n).
225-
results.append(
226-
{
227-
"index": image_index,
228-
"depth": depth,
229-
"file": image_path,
230-
"result": json.dumps(list(outcomes.values())),
231-
"hash": file_hash,
232-
"measurements": json.dumps(measurement),
233-
}
234-
)
235-
236-
# clear data
237-
del sim
238-
del pm
239-
del sampler
240-
gc.collect()
241-
242-
return results
243-
244-
245-
def generate_images(
246-
target_folder: FilePath,
247-
n_qubits: int,
248-
total_gates: int,
249-
shots: int,
250-
amount_circuits: int,
251-
total_threads: int,
252-
checkpoint: Checkpoint,
253-
):
254-
"""
255-
Generate multiple images and save a dataframe with information about them.
256-
It runs in multiple threads (processes in this case) to speed things up.
257-
"""
258-
259-
dataset_file_path = dataset_file(target_folder)
260-
261-
bitstrings_to_int = [
262-
int("".join(comb), 2) for comb in product("01", repeat=n_qubits)
263-
]
264-
265-
# get all measurement combinations
266-
# may be expensive with a large number of qubits, but for 5,6,... it's good
267-
qubits_iter = list(range(n_qubits))
268-
measurement_combs: MeasurementsCombinations = [
269-
qubits_iter
270-
] # start with [[0,1,2,3,4,....,n-1]]
271-
for amount in range(1, n_qubits):
272-
measurement_combs = [
273-
*measurement_combs,
274-
*list(combinations(qubits_iter, amount)), # type: ignore
275-
] # type: ignore
276-
total_measurement_combs = len(measurement_combs)
277-
278-
base_dataset_path = dataset_path(target_folder)
279-
280-
index = checkpoint.index
281-
with tqdm(total=amount_circuits, initial=index) as progress:
282-
while index < amount_circuits:
283-
args = []
284-
285-
for _ in range(total_threads):
286-
base_index = index * total_measurement_combs
287-
args.append(
288-
(
289-
base_index,
290-
bitstrings_to_int,
291-
measurement_combs,
292-
base_dataset_path,
293-
n_qubits,
294-
total_gates,
295-
shots,
296-
)
297-
)
298-
index += 1
299-
300-
with ThreadPoolExecutor(max_workers=total_threads) as pool:
301-
threads = [pool.submit(generate_circuit_images, *arg) for arg in args] # type:ignore
302-
303-
# The best would be using the polars scan_csv and sink_csv to
304-
# write memory efficient queries easily.
305-
# However, it's an experimental feature, and for some reason they don't work
306-
# well together.
307-
# https://github.com/pola-rs/polars/issues/22845
308-
# https://github.com/pola-rs/polars/issues/20468
309-
# to solve that, we're going to use Python's built-in csv library
310-
# to append the new lines without loading the whole csv into memory.
311-
312-
# df = open_csv(dataset_file_path)
313-
314-
rows: Rows = []
315-
for future in as_completed(threads): # type: ignore
316-
rows = [
317-
*rows,
318-
*[list(result.values()) for result in future.result()],
319-
]
320-
321-
append_rows_to_df(dataset_file_path, rows)
322-
323-
del rows
324-
del threads
325-
del args
326-
gc.collect()
327-
328-
# save_df(df, dataset_file_path)
329-
330-
# remove df from memory to avoid excessive
331-
# memory usage
332-
# del df
333-
# gc.collect()
334-
335-
progress.update(total_threads)
336-
337-
checkpoint.index = index
338-
checkpoint.save()
339-
340-
341-
342147

343148
def shuffle_csv(target_folder:FilePath):
344149
"""
@@ -350,43 +155,6 @@ def shuffle_csv(target_folder:FilePath):
350155
df = shuffle_df(df)
351156
df.write_csv(file_path)
352157

353-
def transform_images(
354-
target_folder: FilePath, new_dim: Dimensions, checkpoint: Checkpoint
355-
):
356-
"""Normalize images and save them into a h5 file"""
357-
print("%sTransforming images%s" % (Colors.GREENBG, Colors.ENDC))
358-
359-
df = open_csv(dataset_file(target_folder))
360-
361-
current_index = checkpoint.index
362-
amount_of_rows_per_iteration = 500
363-
364-
max_width, max_height = new_dim
365-
366-
while True:
367-
collected_rows: List[FilePath] = (
368-
df.slice(offset=current_index, length=amount_of_rows_per_iteration)
369-
.collect()
370-
.get_column("file")
371-
.to_list()
372-
)
373-
374-
if len(collected_rows) <= 0:
375-
break
376-
377-
image_i = checkpoint.index
378-
with h5py.File(images_h5_file(target_folder), "a") as file:
379-
for image_path in tqdm(collected_rows):
380-
with Image.open(image_path) as img:
381-
tensor = transform_image(img, max_width, max_height)
382-
file.create_dataset(f"{image_i}", data=tensor)
383-
384-
image_i += 1
385-
checkpoint.index = image_i
386-
checkpoint.save()
387-
388-
current_index += amount_of_rows_per_iteration
389-
390158

391159

392160

generate/dataset/dataframe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def df_schema(self) -> Schema:
4545
"file": pl.String,
4646
"result": pl.String,
4747
"hash": pl.String,
48+
"total_meas": pl.UInt8,
4849
"measurements": pl.String,
4950
"img_width": pl.UInt16,
5051
"img_height": pl.UInt16,

0 commit comments

Comments
 (0)