Commit 5b6e5e9

MetaMathQA: upload checkpoint to bucket after training (#3163)
To do post-training analysis on the checkpoints (e.g., when experimenting with different post-training metrics), it is useful to have the checkpoints available. Since we now have buckets, let's use them. For the Makefile, set the `UPLOAD_BUCKET` environment variable to activate this feature for the whole run.
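The gating this commit adds can be restated from the caller's side. The snippet below is a minimal POSIX-sh sketch that mimics the Makefile's `ifdef UPLOAD_BUCKET` logic; the `run.py` invocation is illustrative, borrowed from the README example, and not a literal transcript of what `make` executes:

```shell
# Mimic the Makefile gating: only pass --bucket_name when UPLOAD_BUCKET is set.
OPTIONAL_FLAGS=""
if [ -n "${UPLOAD_BUCKET:-}" ]; then
    OPTIONAL_FLAGS="--bucket_name ${UPLOAD_BUCKET}"
fi
# With UPLOAD_BUCKET unset this prints the unchanged invocation; with it set,
# every experiment run gains the checkpoint-upload flag.
echo python run.py ${OPTIONAL_FLAGS} -v experiments/lora/llama-3.2-3B-rank32/
```

Exporting `UPLOAD_BUCKET` once before `make` therefore switches the behavior for every experiment in the run, without editing any experiment config.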
1 parent cc89731 commit 5b6e5e9

4 files changed: 31 additions & 3 deletions

method_comparison/MetaMathQA/Makefile

Lines changed: 7 additions & 1 deletion
```diff
@@ -6,6 +6,12 @@ RUN_SCRIPT := run.py
 EXPERIMENTS_DIR := experiments
 RESULTS_DIR := results
 
+OPTIONAL_FLAGS =
+
+ifdef UPLOAD_BUCKET
+OPTIONAL_FLAGS += --bucket_name "${UPLOAD_BUCKET}"
+endif
+
 # --- Automatic Experiment and Result Discovery ---
 
 # 1. Find all experiment directories by looking for adapter_config.json files.
@@ -49,7 +55,7 @@ define EXPERIMENT_template
 $(call exp_to_res,$(1)): $(wildcard $(1)/adapter_config.json) $(wildcard $(1)/training_params.json)
 	@echo "---"
 	@echo "Running experiment: $(1)"
-	$(PYTHON) $(RUN_SCRIPT) -v $(1)
+	$(PYTHON) $(RUN_SCRIPT) $(OPTIONAL_FLAGS) -v $(1)
 	@echo "Finished: $$@"
 	@echo "---"
```

method_comparison/MetaMathQA/README.md

Lines changed: 5 additions & 1 deletion
````diff
@@ -47,6 +47,10 @@ without modifying it. For example:
 
 to run the VBLoRA default experiment again.
 
+If you set `UPLOAD_BUCKET="your_user/bucket_name"` as an environment variable prior to starting experiments
+via `make`, all experiments will be called with the `--bucket_name $UPLOAD_BUCKET` parameter and therefore
+store the checkpoints in that bucket.
+
 ### `adapter_config.json`
 
 This must be a valid PEFT configuration. It is easiest to create it programmatically, e.g.:
@@ -94,7 +98,7 @@ From practical experiments, for a batch size of 4, a bucket size of 80 provides
 
 ### Start a run
 
-Once everything is set up properly, start a run by using the `run.py` script. Pass `-v` for verbose output to the console (recommended if observing the progress is desired). As an example, for `experiments/lora/llama-3.2-3B-rank32/` the invocation would be:
+Once everything is set up properly, start a run by using the `run.py` script. Pass `-v` for verbose output to the console (recommended if observing the progress is desired). To save the resulting experiment checkpoints to a huggingface bucket, you can pass the bucket name via the `--bucket_name` parameter (e.g., `"user/my_bucket_name"`). As an example, for `experiments/lora/llama-3.2-3B-rank32/` the invocation would be:
 
 ```sh
 python run.py -v experiments/lora/llama-3.2-3B-rank32/
````

method_comparison/MetaMathQA/run.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -50,6 +50,7 @@
     get_train_config,
     init_accelerator,
     log_results,
+    upload_checkpoint_to_bucket,
     validate_experiment_path,
 )
 
@@ -405,7 +406,7 @@ def train(
     return train_result
 
 
-def main(*, path_experiment: str, experiment_name: str, clean: bool) -> None:
+def main(*, path_experiment: str, experiment_name: str, clean: bool, bucket_name: Optional[str]) -> None:
     tic_total = time.perf_counter()
     start_date = dt.datetime.now(tz=dt.timezone.utc).replace(microsecond=0).isoformat()
 
@@ -477,6 +478,9 @@ def main(*, path_experiment: str, experiment_name: str, clean: bool) -> None:
         print_fn=print_verbose,
     )
 
+    if bucket_name is not None:
+        upload_checkpoint_to_bucket(model, experiment_name, bucket_name)
+
     time_total = time.perf_counter() - tic_total
     # log results: print and save to file
     log_results(
@@ -503,6 +507,7 @@ def main(*, path_experiment: str, experiment_name: str, clean: bool) -> None:
         action="store_true",
         help="Delete training artifacts after run finishes (logs are still saved)",
     )
+    parser.add_argument("--bucket_name", type=str, help="HF bucket to upload checkpoints to.")
     args = parser.parse_args()
 
     experiment_name = validate_experiment_path(args.path_experiment)
@@ -521,4 +526,5 @@ def print_verbose(*args, **kwargs) -> None:
         path_experiment=args.path_experiment,
         experiment_name=experiment_name,
         clean=args.clean,
+        bucket_name=args.bucket_name,
     )
```

method_comparison/MetaMathQA/utils.py

Lines changed: 12 additions & 0 deletions
```diff
@@ -344,6 +344,18 @@ def __iter__(self):
             yield from self._batch_iterator(bucket)
 
 
+def upload_checkpoint_to_bucket(model: nn.Module, experiment_name: str, bucket_name: str):
+    try:
+        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True, delete=True) as tmp_dir:
+            model.save_pretrained(tmp_dir)
+            huggingface_hub.batch_bucket_files(
+                bucket_name,
+                add=[(os.path.join(tmp_dir, fname), f"{experiment_name}/{fname}") for fname in os.listdir(tmp_dir)],
+            )
+    except Exception as exc:
+        print(f"Failed to upload model checkpoint to hub: {exc}")
+
+
 def get_file_size(
     model: nn.Module, *, peft_config: Optional[PeftConfig], clean: bool, print_fn: Callable[..., None]
 ) -> int:
```
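The interesting part of the new helper is the mapping from files in the temporary checkpoint directory to paths inside the bucket. As a sanity check, that list comprehension can be exercised on its own; `checkpoint_upload_pairs` below is a hypothetical helper extracted for illustration (with a `sorted` added for deterministic output), not part of the commit:

```python
import os
import tempfile


def checkpoint_upload_pairs(tmp_dir, experiment_name):
    # Same (local_path, path_in_bucket) pairing as the `add=` argument above.
    return [
        (os.path.join(tmp_dir, fname), f"{experiment_name}/{fname}")
        for fname in sorted(os.listdir(tmp_dir))
    ]


with tempfile.TemporaryDirectory() as tmp_dir:
    # Stand-ins for the files save_pretrained would write.
    for fname in ("adapter_config.json", "adapter_model.safetensors"):
        open(os.path.join(tmp_dir, fname), "w").close()
    for local, remote in checkpoint_upload_pairs(tmp_dir, "lora/llama-3.2-3B-rank32"):
        print(remote)
# prints:
# lora/llama-3.2-3B-rank32/adapter_config.json
# lora/llama-3.2-3B-rank32/adapter_model.safetensors
```

Prefixing every file with the experiment name means one bucket can hold checkpoints from many experiments without collisions, which matches the Makefile-driven batch runs above.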
