Skip to content

Commit 478ebe2

Browse files
author
Dpbm
committed
added early stopping
1 parent 905c8e4 commit 478ebe2

3 files changed

Lines changed: 138 additions & 52 deletions

File tree

args/parser.py

Lines changed: 61 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
DEFAULT_TEST_PERCENTAGE,
1818
DEFAULT_TARGET_FOLDER,
1919
DEFAULT_CHECKPOINT,
20+
DEFAULT_EARLY_STOP_PATIENCE,
21+
DEFAULT_EARLY_STOP_THRESHOLD
2022
)
2123
from utils.datatypes import Dimensions, FilePath
2224

@@ -37,6 +39,8 @@ class Arguments:
3739
"_target_folder",
3840
"_checkpoint",
3941
"_new_image_dim",
42+
"_es_patience",
43+
"_es_threshold"
4044
]
4145

4246
def __init__(self):
@@ -54,6 +58,8 @@ def __init__(self):
5458
self._target_folder = DEFAULT_TARGET_FOLDER
5559
self._checkpoint = DEFAULT_CHECKPOINT
5660
self._new_image_dim = DEFAULT_NEW_DIM
61+
self._es_patience = DEFAULT_EARLY_STOP_PATIENCE
62+
self._es_threshold = DEFAULT_EARLY_STOP_THRESHOLD
5763

5864
def parse(self, args: argparse.Namespace):
5965
"""Parse arguments from argparse"""
@@ -69,6 +75,8 @@ def parse(self, args: argparse.Namespace):
6975
self._target_folder = args.target_folder
7076
self._checkpoint = args.checkpoint
7177
self._new_image_dim = args.new_image_dim
78+
self._es_patience = args.es_patience
79+
self._es_threshold = args.es_threshold
7280

7381
@property
7482
def epochs(self) -> int:
@@ -189,6 +197,26 @@ def new_image_dim(self) -> Dimensions:
189197
def new_image_dim(self, value: Dimensions):
190198
"""Set new_image_dim data"""
191199
self._new_image_dim = value
200+
201+
@property
202+
def es_patience(self) -> int:
203+
"""Get the early-stopping patience (defaults to DEFAULT_EARLY_STOP_PATIENCE)."""
204+
return self._es_patience # type: ignore
205+
206+
@es_patience.setter
207+
def es_patience(self, value: int):
208+
"""Set the early-stopping patience."""
209+
self._es_patience = value
210+
211+
@property
212+
def es_threshold(self) -> float:
213+
"""Get the early-stopping threshold (defaults to DEFAULT_EARLY_STOP_THRESHOLD)."""
214+
return self._es_threshold # type: ignore
215+
216+
@es_threshold.setter
217+
def es_threshold(self, value: float):
218+
"""Set the early-stopping threshold."""
219+
self._es_threshold = value
192220

193221
def __str__(self) -> str:
194222
string = f"epochs: {self._epochs}\n"
@@ -202,34 +230,39 @@ def __str__(self) -> str:
202230
string += f"amount circuits: {self._amount_circuits}\n"
203231
string += f"target_folder: {self._target_folder}\n"
204232
string += f"checkpoint: {self._checkpoint}\n"
205-
string += f"new image dim: {self._new_image_dim}\n"
206-
207-
return string
208-
209-
210-
def parse_args() -> Arguments:
211-
"""
212-
Use argparse to parse CLI arguments for all scripts
213-
"""
214-
215-
parser = argparse.ArgumentParser()
216-
parser.add_argument("--epochs", type=int, default=DEFAULT_EPOCHS)
217-
parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE)
218-
parser.add_argument("--train-size", type=float, default=DEFAULT_TRAIN_PERCENTAGE)
219-
parser.add_argument("--test-size", type=float, default=DEFAULT_TEST_PERCENTAGE)
220-
parser.add_argument("--checkpoint", type=str, default=DEFAULT_CHECKPOINT)
221-
222-
parser.add_argument("--threads", type=int, default=DEFAULT_THREADS)
223-
224-
parser.add_argument("--shots", type=int, default=DEFAULT_SHOTS)
225-
parser.add_argument("--n-qubits", type=int, default=DEFAULT_NUM_QUBITS)
226-
parser.add_argument("--max-gates", type=int, default=DEFAULT_MAX_TOTAL_GATES)
227-
228-
parser.add_argument(
229-
"--amount-circuits", type=int, default=DEFAULT_AMOUNT_OF_CIRCUITS
230-
)
231-
parser.add_argument("--target-folder", type=str, default=DEFAULT_TARGET_FOLDER)
232-
parser.add_argument("--new-image-dim", type=int, nargs=2, default=DEFAULT_NEW_DIM)
233+
string += f"new image dim: {self._new_image_dim}\n"
234+
string += f"early stop patience: {self._es_patience}\n"
235+
string += f"early stop threshold: {self._es_threshold}\n"
236+
237+
return string
238+
239+
240+
def parse_args() -> Arguments:
241+
"""
242+
Use argparse to parse CLI arguments for all scripts
243+
"""
244+
245+
parser = argparse.ArgumentParser()
246+
parser.add_argument("--epochs", type=int, default=DEFAULT_EPOCHS)
247+
parser.add_argument("--batch-size", type=int, default=DEFAULT_BATCH_SIZE)
248+
parser.add_argument("--train-size", type=float, default=DEFAULT_TRAIN_PERCENTAGE)
249+
parser.add_argument("--test-size", type=float, default=DEFAULT_TEST_PERCENTAGE)
250+
parser.add_argument("--checkpoint", type=str, default=DEFAULT_CHECKPOINT)
251+
252+
parser.add_argument("--threads", type=int, default=DEFAULT_THREADS)
253+
254+
parser.add_argument("--shots", type=int, default=DEFAULT_SHOTS)
255+
parser.add_argument("--n-qubits", type=int, default=DEFAULT_NUM_QUBITS)
256+
parser.add_argument("--max-gates", type=int, default=DEFAULT_MAX_TOTAL_GATES)
257+
258+
parser.add_argument(
259+
"--amount-circuits", type=int, default=DEFAULT_AMOUNT_OF_CIRCUITS
260+
)
261+
parser.add_argument("--target-folder", type=str, default=DEFAULT_TARGET_FOLDER)
262+
parser.add_argument("--new-image-dim", type=int, nargs=2, default=DEFAULT_NEW_DIM)
263+
264+
parser.add_argument("--es-patience", type=int, default=DEFAULT_EARLY_STOP_PATIENCE)
265+
parser.add_argument("--es-threshold", type=float, default=DEFAULT_EARLY_STOP_THRESHOLD)
233266

234267
args = parser.parse_args(sys.argv[1:])
235268

train.py

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818

1919
from utils.constants import (
2020
DEBUG,
21-
MODEL_FILE_PREFIX,
22-
CHECKPOINT_FILE_PREFIX,
2321
dataset_file,
2422
images_h5_file,
2523
ghz_file,
2624
ghz_pred_file,
2725
output_plot_file,
2826
history_file,
27+
epoch_tracker_file,
28+
checkpoint_file,
29+
model_file
2930
)
3031
from utils.helpers import debug, PlotImages
3132
from utils.colors import Colors
@@ -221,10 +222,9 @@ def forward(self, image: torch.Tensor) -> torch.Tensor:
221222
debug(out.shape)
222223
return out
223224

224-
def save(self):
225+
def save(self, target_folder:FilePath):
225226
"""Save model weights."""
226-
path = "%s%s" % (MODEL_FILE_PREFIX, time.ctime())
227-
torch.save(self.state_dict(), path)
227+
torch.save(self.state_dict(), model_file(target_folder))
228228

229229

230230
class Checkpoint:
@@ -265,24 +265,17 @@ def scheduler(self) -> Optional[StateDict]:
265265
"""Get Scheduler parameters"""
266266
return self._data.get("scheduler")
267267

268-
@property
269-
def epoch(self) -> int:
270-
"""Get checkpoint epoch"""
271-
return self._data.get("epoch") or 0
272-
273268
@staticmethod
274269
def save(
275270
folder: FilePath,
276-
epoch: int,
277271
model: StateDict,
278272
optimizer: StateDict,
279273
scheduler: StateDict,
280274
):
281275
"""Save checkpoint data"""
282-
path = os.path.join(folder, "%s%s.pth" % (CHECKPOINT_FILE_PREFIX, time.ctime()))
276+
path = checkpoint_file(folder)
283277
print("%sSaving checkpoint at: %s...%s" % (Colors.MAGENTABG, path, Colors.ENDC))
284278
checkpoint = {
285-
"epoch": epoch,
286279
"model": model,
287280
"optimizer": optimizer,
288281
"scheduler": scheduler,
@@ -334,7 +327,45 @@ def load(self):
334327
with open(self._file_path, "r") as file:
335328
self._data = json.load(file)
336329

330+
class EpochTracker:
    """Persist the last epoch number in a small text file.

    The file holds a single integer written as text; ``load`` reads it back
    and ``save`` overwrites it.
    """

    def __init__(self, tracker_file: "FilePath"):
        """Remember the tracker file path; nothing is read or written here."""
        self._file = tracker_file

    def load(self) -> int:
        """Return the epoch stored in the tracker file, or 0 if it is missing.

        Raises:
            ValueError: if the file exists but does not contain an integer.
        """
        try:
            with open(self._file, "r", encoding="utf-8") as tracker:
                content = tracker.read()
        except FileNotFoundError:
            # No tracker file yet means no epoch has been recorded.
            return 0

        return int(content.strip())

    def save(self, epoch: int):
        """Overwrite the tracker file with ``epoch``."""
        with open(self._file, "w", encoding="utf-8") as tracker:
            tracker.write(str(epoch))
348+
349+
class EarlyStop:
    """Signal when the monitored loss has stopped improving.

    A loss counts as an improvement only when it is at least ``threshold``
    below the best loss seen so far. Once ``patience`` consecutive calls to
    ``should_stop`` pass without an improvement, it returns True.
    """

    def __init__(self, patience: int, threshold: float):
        """Store the settings and start with no best loss recorded."""
        self._patience = patience
        self._threshold = threshold
        self._best_loss = float("inf")
        self._no_improve_counter = 0

    def should_stop(self, loss: float) -> bool:
        """Record ``loss`` and report whether training should stop.

        Returns:
            False when ``loss`` improves on the best loss by at least the
            threshold (the counter is reset), or when the number of
            consecutive non-improving calls is still below the patience;
            True once that number reaches the patience.
        """
        if loss <= self._best_loss - self._threshold:
            self._best_loss = loss
            self._no_improve_counter = 0
            print("%sNew best loss%s" % (Colors.GREENBG, Colors.ENDC))
            return False

        self._no_improve_counter += 1

        # Return the comparison itself so every path yields a real bool
        # (never an implicit None).
        return self._no_improve_counter >= self._patience
338369

339370
def one_epoch(
340371
dataset: DataLoader,
@@ -420,6 +451,8 @@ def train(
420451
test_percentage: float,
421452
batch_size: int,
422453
epochs: int,
454+
patience:int,
455+
es_threshold:float
423456
):
424457
"""Train model"""
425458

@@ -471,14 +504,19 @@ def train(
471504
scheduler = torch.optim.lr_scheduler.OneCycleLR(
472505
opt, max_lr=lr, steps_per_epoch=len(data_loader_train), epochs=epochs
473506
)
474-
best_loss = 1_000_000.0
507+
best_loss = float('inf')
475508

476509
if checkpoint.optimizer:
477510
opt.load_state_dict(checkpoint.optimizer)
478511
if checkpoint.scheduler:
479512
scheduler.load_state_dict(checkpoint.scheduler)
513+
514+
early_stopping = EarlyStop(patience=patience, threshold=es_threshold)
480515

481-
for epoch in range(checkpoint.epoch + 1, epochs):
516+
epoch_tracker = EpochTracker(epoch_tracker_file(target_folder))
517+
last_epoch = epoch_tracker.load()
518+
519+
for epoch in range(last_epoch + 1, epochs):
482520
print("%sEpoch: %d%s" % (Colors.YELLOWFG, epoch, Colors.ENDC))
483521
model.train(True)
484522

@@ -521,18 +559,23 @@ def train(
521559
print("%sAVG loss Test: %f%s" % (Colors.MAGENTAFG, avg_loss, Colors.ENDC))
522560
print("%sRMSE: %f%s" % (Colors.MAGENTAFG, rmse, Colors.ENDC))
523561

562+
epoch_tracker.save(epoch)
563+
564+
if(early_stopping.should_stop(avg_loss)):
565+
print("%sStop training early!%s"%(Colors.MAGENTABG,Colors.ENDC))
566+
break
567+
524568
if avg_loss < best_loss:
525569
best_loss = avg_loss
526570
print("%sBest loss: %f%s" % (Colors.GREENBG, best_loss, Colors.ENDC))
527571

528-
# save a checkpoint after every epoch
529-
Checkpoint.save(
530-
target_folder,
531-
epoch,
532-
model.state_dict(),
533-
opt.state_dict(),
534-
scheduler.state_dict(),
535-
)
572+
Checkpoint.save(
573+
target_folder,
574+
model.state_dict(),
575+
opt.state_dict(),
576+
scheduler.state_dict(),
577+
)
578+
536579

537580
eval_loss = 0.0
538581
targets = []
@@ -633,8 +676,10 @@ def setup_and_run_training(args: Arguments):
633676
args.test_size,
634677
args.batch_size,
635678
args.epochs,
679+
args.es_patience,
680+
args.es_threshold
636681
)
637-
model.save() # save best model
682+
model.save(args.target_folder) # save best model
638683
model.eval()
639684

640685
ghz = torch.load(ghz_file(args.target_folder), map_location=device)

utils/constants.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Constant values"""
22

33
import os
4+
import time
45

56
DEBUG = os.environ.get("DEBUG") or False
67

@@ -22,12 +23,16 @@
2223

2324
DEFAULT_TRAIN_PERCENTAGE = 0.7
2425
DEFAULT_TEST_PERCENTAGE = 0.2
26+
# The remaining 0.1 is for Evaluation
2527

2628
DEFAULT_CHECKPOINT = None
2729

2830
MODEL_FILE_PREFIX = "model_"
2931
CHECKPOINT_FILE_PREFIX = "checkpoint_"
3032

33+
DEFAULT_EARLY_STOP_PATIENCE=5
34+
DEFAULT_EARLY_STOP_THRESHOLD=0.1
35+
3136
# ruff: noqa: E731
3237
dataset_path = lambda target_folder: os.path.join(target_folder, "dataset")
3338
dataset_file = lambda target_folder: os.path.join(target_folder, "dataset.csv")
@@ -44,3 +49,6 @@
4449
images_gen_checkpoint_file = lambda target_folder: os.path.join(
4550
target_folder, "gen_checkpoint.json"
4651
)
52+
# Path of the text file that train.py's EpochTracker reads and writes.
epoch_tracker_file = lambda target_folder: os.path.join(target_folder, "epoch.dat")
53+
# time.ctime() is evaluated on every call, so the returned path follows the clock.
checkpoint_file = lambda target_folder: os.path.join(target_folder, "%s%s.pth"%(CHECKPOINT_FILE_PREFIX, time.ctime()))
54+
# Same timestamp-per-call naming as checkpoint_file; no file extension is appended.
model_file = lambda target_folder : os.path.join(target_folder, "%s%s"%(MODEL_FILE_PREFIX, time.ctime()))

0 commit comments

Comments
 (0)