Skip to content

Commit 233e9f3

Browse files
author
Dpbm
committed
fixing data generation
1 parent 3d56feb commit 233e9f3

15 files changed

Lines changed: 312 additions & 151 deletions

Makefile

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,25 @@ TARGET_PATH ?= "."
44
clean-all: clean-dataset clean-pred clean-ghz clean-model clean-checkpoints clean-history clean-gen-checkpoint
55

66
clean-dataset:
7-
rm -rf $(TARGET_PATH)/dataset/ $(TARGET_PATH)/dataset.csv $(TARGET_PATH)/*.h5 $(TARGET_PATH)/dataset-images.zip
7+
sudo rm -rf $(TARGET_PATH)/dataset/ $(TARGET_PATH)/dataset.csv $(TARGET_PATH)/*.h5 $(TARGET_PATH)/dataset-images.zip
88

99
clean-pred:
10-
rm -rf $(TARGET_PATH)/ghz-prediction.pth
10+
sudo rm -rf $(TARGET_PATH)/ghz-prediction.pth
1111

1212
clean-ghz:
13-
rm -rf $(TARGET_PATH)/ghz.pth $(TARGET_PATH)/ghz.jpeg
13+
sudo rm -rf $(TARGET_PATH)/ghz.pth $(TARGET_PATH)/ghz.jpeg
1414

1515
clean-model:
16-
rm -rf $(TARGET_PATH)/model_*
16+
sudo rm -rf $(TARGET_PATH)/model_*
1717

1818
clean-checkpoints:
19-
rm -rf $(TARGET_PATH)/checkpoint_*
19+
sudo rm -rf $(TARGET_PATH)/checkpoint_*
2020

2121
clean-gen-checkpoint:
22-
rm -rf $(TARGET_PATH)/gen_checkpoint.json
22+
sudo rm -rf $(TARGET_PATH)/gen_checkpoint.json
2323

2424
clean-history:
25-
rm -rf $(TARGET_PATH)/history.json
25+
sudo rm -rf $(TARGET_PATH)/history.json
2626

2727
pack:
2828
zip -r $(TARGET_PATH)/dataset-images.zip $(TARGET_PATH)/dataset/

dags/dataset.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
remove_duplicated_files,
1818
transform_images,
1919
start_df,
20+
shuffle_csv,
2021
Checkpoint,
2122
Stages,
2223
)
@@ -42,6 +43,7 @@
4243
GEN_IMAGES_TASK_ID = "gen_images"
4344
REMOVE_DUPLICATES_TASK_ID = "remove_duplicates"
4445
TRANSFORM_TASK_ID = "transform_images"
46+
SHUFFLE_DATASET_ID = "shuffle_df"
4547

4648

4749
def next_step(checkpoint: Checkpoint) -> str:
@@ -55,6 +57,9 @@ def next_step(checkpoint: Checkpoint) -> str:
5557
if checkpoint.stage == Stages.GEN_IMAGES:
5658
return GEN_IMAGES_TASK_ID
5759

60+
if checkpoint.stage == Stages.SHUFFLE:
61+
return SHUFFLE_DATASET_ID
62+
5863
if checkpoint.stage == Stages.DUPLICATES:
5964
return REMOVE_DUPLICATES_TASK_ID
6065

@@ -131,14 +136,35 @@ def update_checkpoint(checkpoint: Checkpoint, stage: Stages):
131136
Qiskit framework.
132137
"""
133138

134-
transtion_gen_to_remove = PythonOperator(
135-
task_id="gen_to_remove",
139+
transtion_gen_to_shuffle = PythonOperator(
140+
task_id="gen_to_shuffle",
141+
python_callable=update_checkpoint,
142+
op_args=[checkpoint, Stages.SHUFFLE],
143+
)
144+
145+
transtion_gen_to_shuffle.doc_md = """
146+
Update checkpoint to start shuffling rows.
147+
"""
148+
149+
shuffle = PythonOperator(
150+
task_id=SHUFFLE_DATASET_ID,
151+
python_callable=shuffle_csv,
152+
op_args=[folder],
153+
trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS
154+
)
155+
156+
shuffle.doc_md = """
157+
Shuffle dataset rows.
158+
"""
159+
160+
transtion_shuffle_to_remove = PythonOperator(
161+
task_id="shuffle_to_remove",
136162
python_callable=update_checkpoint,
137163
op_args=[checkpoint, Stages.DUPLICATES],
138164
)
139165

140-
transtion_gen_to_remove.doc_md = """
141-
Update checkpoint to start removing duplicated files.
166+
transtion_shuffle_to_remove.doc_md = """
167+
Update checkpoint to start deleting duplicated rows.
142168
"""
143169

144170
remove_duplicates = PythonOperator(
@@ -228,11 +254,14 @@ def update_checkpoint(checkpoint: Checkpoint, stage: Stages):
228254
gen_df >> branch_checkpoint
229255

230256
branch_checkpoint >> gen_images
257+
branch_checkpoint >> shuffle
231258
branch_checkpoint >> remove_duplicates
232259
branch_checkpoint >> transform_img
233260

234-
gen_images >> transtion_gen_to_remove
235-
transtion_gen_to_remove >> remove_duplicates
261+
gen_images >> transtion_gen_to_shuffle
262+
transtion_gen_to_shuffle >> shuffle
263+
shuffle >> transtion_shuffle_to_remove
264+
transtion_shuffle_to_remove >> remove_duplicates
236265
remove_duplicates >> transition_remove_to_transform
237266
transition_remove_to_transform >> transform_img
238267
transform_img >> reset_checkpoint

dataset.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
images_h5_file,
3232
images_gen_checkpoint_file,
3333
dataset_file_tmp,
34+
SCALE_CIRCUIT_SIZE
3435
)
3536
from utils.datatypes import FilePath, df_schema, Dimensions
3637
from utils.image import transform_image
@@ -49,6 +50,7 @@ class Stages(Enum):
4950
"""Enum for dataset generation stages"""
5051

5152
GEN_IMAGES = "gen"
53+
SHUFFLE = "shuffle"
5254
DUPLICATES = "duplicates"
5355
TRANSFORM = "transform"
5456

@@ -200,7 +202,7 @@ def generate_circuit_images(
200202
qc_copy.add_register(classical_register)
201203
qc_copy.measure(measurement, list(range(total_measurements)))
202204

203-
drawing = qc_copy.draw("mpl", filename=image_path)
205+
drawing = qc_copy.draw("mpl", filename=image_path, fold=-1, scale=SCALE_CIRCUIT_SIZE)
204206
plt.close(drawing)
205207
del drawing
206208

@@ -405,6 +407,22 @@ def get_duplicated_files_list_by_diff(
405407

406408
return duplicated_files.to_list() # type: ignore
407409

410+
def shuffle_df(df:pl.DataFrame) -> pl.DataFrame:
411+
"""
412+
Shuffle DF rows to ensure no sequential logic is kept.
413+
"""
414+
return df.sample(fraction=1.0, shuffle=True, seed=32)
415+
416+
417+
def shuffle_csv(target_folder:FilePath):
418+
"""
419+
Shuffle CSV rows.
420+
"""
421+
print("%sShuffling DF....%s"%(Colors.GREENBG,Colors.ENDC))
422+
file_path = dataset_file(target_folder)
423+
df = pl.read_csv(file_path)
424+
df = shuffle_df(df)
425+
df.write_csv(file_path)
408426

409427
def transform_images(
410428
target_folder: FilePath, new_dim: Dimensions, checkpoint: Checkpoint
@@ -434,7 +452,7 @@ def transform_images(
434452
with h5py.File(images_h5_file(target_folder), "a") as file:
435453
for image_path in tqdm(collected_rows):
436454
with Image.open(image_path) as img:
437-
tensor = transform_image(img, max_width, max_height)
455+
tensor = transform_image(img, max_width)
438456
file.create_dataset(f"{image_i}", data=tensor)
439457

440458
image_i += 1
@@ -522,8 +540,15 @@ def main(args: Arguments):
522540
checkpoint.stage = Stages.DUPLICATES
523541
checkpoint.index = 0
524542

543+
if checkpoint.stage == Stages.SHUFFLE:
544+
shuffle_df()
545+
546+
checkpoint.stage = Stages.DUPLICATES
547+
checkpoint.index = 0
548+
525549
if checkpoint.stage == Stages.DUPLICATES:
526550
remove_duplicated_files(args.target_folder, checkpoint)
551+
527552
checkpoint.stage = Stages.TRANSFORM
528553
checkpoint.index = 0
529554

generate/fix_default_img_sizes.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import sys
2+
import os
3+
from typing import Optional, Tuple
4+
from pathlib import Path
5+
import re
6+
7+
from PIL import Image
8+
9+
base_path = str(Path(__file__).resolve().parent.parent)
10+
sys.path.append(base_path)
11+
12+
import matplotlib.pyplot as plt
13+
from qiskit import QuantumCircuit
14+
15+
from utils.constants import DEFAULT_MAX_TOTAL_GATES, DEFAULT_NUM_QUBITS
16+
from random_circuit import generate_circuit, generate_circuit_worst_case
17+
18+
CONSTANTS_FILE_PATH = os.path.join(base_path, "utils", "constants.py")
19+
TESTS_FILE_PATH = os.path.join(base_path, "tests")
20+
SCALE_SIZE = 0.5
21+
22+
def draw(qc:QuantumCircuit, ax:Optional[plt.Axes]=None) -> plt.Figure:
23+
"""Draw circuit with tweaks."""
24+
return qc.draw("mpl", scale=SCALE_SIZE, fold=-1, ax=ax)
25+
26+
if __name__ == "__main__":
27+
print("Testing the min size of a circuit image...")
28+
qc1 = generate_circuit(DEFAULT_NUM_QUBITS, 0)
29+
small_fig_path = os.path.join(TESTS_FILE_PATH, "small.png")
30+
img1 = draw(qc1)
31+
img1.savefig(small_fig_path, bbox_inches="tight")
32+
plt.close(img1)
33+
min_width,min_height = Image.open(small_fig_path).size
34+
print(f"Size -> {min_width}x{min_height}px")
35+
36+
print("Testing the max size of a circuit image...")
37+
qc2 = generate_circuit_worst_case(DEFAULT_NUM_QUBITS, DEFAULT_MAX_TOTAL_GATES)
38+
large_fig_path = os.path.join(TESTS_FILE_PATH, "large.png")
39+
img2 = draw(qc2)
40+
img2.savefig(large_fig_path, bbox_inches="tight")
41+
plt.close(img2)
42+
max_width,max_height = Image.open(large_fig_path).size
43+
print(f"Size -> {max_width}x{max_height}px")
44+
45+
# fig,ax = plt.subplots(2)
46+
# draw(qc1,ax[0])
47+
# draw(qc2,ax[1])
48+
# plt.show()
49+
50+
51+
with open(CONSTANTS_FILE_PATH, "r+", encoding="utf-8") as file:
52+
read = file.read()
53+
54+
pattern_dim = r'DEFAULT_NEW_DIM *= *\( *[0-9]{2,} *, *[0-9]{2,} *\)'
55+
pattern_scale = r'SCALE_CIRCUIT_SIZE *= *0\.[1-9]'
56+
57+
new_dim_data = f"DEFAULT_NEW_DIM = ({max_width}, {max_height})"
58+
new_scale_factor = f"SCALE_CIRCUIT_SIZE = {SCALE_SIZE}"
59+
60+
new_data = read
61+
62+
if re.search(pattern_dim, new_data):
63+
new_data = re.sub(pattern_dim, new_dim_data, new_data)
64+
else:
65+
new_data += new_dim_data
66+
67+
if re.search(pattern_scale, new_data):
68+
new_data = re.sub(pattern_scale, new_scale_factor, new_data)
69+
else:
70+
new_data += new_scale_factor
71+
72+
with open(CONSTANTS_FILE_PATH, "w", encoding="utf-8") as file:
73+
file.write(new_data)

generate/random_circuit.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,8 @@ def get_random_gate(cls) -> QiskitGate:
6060
gate = random.choice(cls.gates)
6161
return gate()
6262

63-
64-
def get_random_circuit(n_qubits: int, total_gates: int) -> QuantumCircuit:
65-
"""Generate a random circuit based on the amount of qubits and gates."""
66-
67-
total_gates = random.randint(0, total_gates)
68-
63+
def generate_circuit(n_qubits:int, total_gates:int) -> QuantumCircuit:
64+
"""Generate a circuit based on the number of gates."""
6965
qc = QuantumCircuit(n_qubits)
7066

7167
for _ in range(total_gates):
@@ -85,3 +81,21 @@ def get_random_circuit(n_qubits: int, total_gates: int) -> QuantumCircuit:
8581
qc.barrier()
8682

8783
return qc
84+
85+
def generate_circuit_worst_case(n_qubits: int, total_gates: int) -> QuantumCircuit:
86+
"""Generate the longest possible circuit."""
87+
qc = QuantumCircuit(n_qubits)
88+
89+
for _ in range(total_gates):
90+
gate = SingleQubitGate.get_random_gate()
91+
qc.append(gate, [0])
92+
qc.barrier()
93+
94+
return qc
95+
96+
def get_random_circuit(n_qubits: int, total_gates: int) -> QuantumCircuit:
97+
"""Generate a random circuit based on the amount of qubits and gates."""
98+
99+
total_gates = random.randint(0, total_gates)
100+
return generate_circuit(n_qubits,total_gates)
101+

ghz.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from utils.image import transform_image
88
from args.parser import parse_args
9-
from utils.constants import ghz_image_file, ghz_file
9+
from utils.constants import ghz_image_file, ghz_file, SCALE_CIRCUIT_SIZE
1010
from utils.datatypes import FilePath, Dimensions
1111

1212

@@ -20,14 +20,13 @@ def gen_circuit(n_qubits: int, target_folder: FilePath, new_dim: Dimensions):
2020
qc.measure_all()
2121

2222
ghz_image_path = ghz_image_file(target_folder)
23-
qc.draw("mpl", filename=ghz_image_path)
23+
qc.draw("mpl", filename=ghz_image_path, fold=-1, scale=SCALE_CIRCUIT_SIZE)
2424

2525
with Image.open(ghz_image_path) as file:
2626
width, height = new_dim
27-
tensor = transform_image(file, width, height)
27+
tensor = transform_image(file, width)
2828
torch.save(tensor, ghz_file(target_folder))
2929

30-
3130
if __name__ == "__main__":
3231
args = parse_args()
3332
gen_circuit(args.n_qubits, args.target_folder, args.new_image_dim)

tests/large.png

12.6 KB
Loading

tests/small.png

1.66 KB
Loading

tests/test_dataset_generation.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
save_df,
1212
start_df,
1313
get_duplicated_files_list_by_diff,
14+
shuffle_df
1415
)
1516
from utils.datatypes import df_schema
1617

@@ -233,6 +234,15 @@ def test_remove_duplicates_sequence(self, base_df, tmp_df, tmp_df2):
233234
assert os.path.exists(tmp_df)
234235
assert len(pl.read_csv(tmp_df)) == 8
235236

237+
def test_shuffle_df(self,base_df):
238+
df = pl.read_csv(base_df)
239+
no_shuffled = df["file"].to_list()
240+
shuffled = shuffle_df(df)["file"].to_list()
241+
print(no_shuffled, shuffled)
242+
243+
assert no_shuffled != shuffled
244+
245+
236246
# SINCE SAVING A LAZY FRAME AS CSV IN THE SAME FILE IS NOT STABLE,
237247
# WE ARE GOING TO IGNORE THE TESTS BELOW. FOR THE PRODUCTION CODE, WE ARE GOING TO
238248
# SAVE THE UPDATED LAZY FRAME IN A DIFFERENT FILE AND THEN RENAME IT.

tests/test_model.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import os
2+
3+
from PIL import Image
4+
5+
from utils.image import transform_image
6+
from utils.constants import DEFAULT_NEW_DIM
7+
from train import Model
8+
9+
IMAGES_TESTS_PATH = os.path.join(".", "tests")
10+
11+
class TestModel:
12+
"""
13+
Test the Model itself.
14+
"""
15+
16+
def test_image_on_model(self):
17+
"""Check if the images fit in the network."""
18+
model = Model().half()
19+
20+
width,height = DEFAULT_NEW_DIM
21+
22+
small_img = Image.open(os.path.join(IMAGES_TESTS_PATH, "small.png"))
23+
small_img_tr = transform_image(small_img, width)
24+
model.forward(small_img_tr)
25+
26+
large_img = Image.open(os.path.join(IMAGES_TESTS_PATH, "large.png"))
27+
large_img_tr = transform_image(small_img, width)
28+
model.forward(large_img_tr)

0 commit comments

Comments
 (0)