Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 57 additions & 62 deletions nemo_skills/dataset/asr-leaderboard/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,121 +16,122 @@

Downloads and formats datasets from the official HF Open ASR Leaderboard ESB
test-only sorted dataset (hf-audio/esb-datasets-test-only-sorted). This is the
same data source used by the official leaderboard, ensuring apples-to-apples
WER comparison.

Audio paths in JSONL: /dataset/asr-leaderboard/data/{dataset}/{sample_id}.flac

Usage:
ns prepare_data asr-leaderboard
ns prepare_data asr-leaderboard --datasets librispeech_clean ami
ns prepare_data asr-leaderboard --datasets earnings22
    ns prepare_data asr-leaderboard --no-audio  # skip saving audio files
"""

import argparse
import json
from pathlib import Path

import numpy as np
import soundfile as sf
from datasets import load_dataset
from datasets import Audio, load_dataset
from tqdm import tqdm

HF_REPO = "hf-audio/esb-datasets-test-only-sorted"
SYSTEM_MESSAGE = "You are a helpful assistant. /no_think"
MIN_AUDIO_DURATION = 0.1  # Skip audio shorter than this (causes mel spectrogram errors)
AUDIO_SAMPLE_RATE = 16000  # Target rate all audio is resampled to before saving

# Dataset name -> (HF config, split, text_field, id_field).
# Every config lives in the single HF_REPO repository, so the repo is not
# repeated per entry.
DATASET_CONFIGS = {
    "librispeech_clean": ("librispeech", "test.clean", "text", "id"),
    "librispeech_other": ("librispeech", "test.other", "text", "id"),
    "voxpopuli": ("voxpopuli", "test", "text", "id"),
    "tedlium": ("tedlium", "test", "text", "id"),
    "gigaspeech": ("gigaspeech", "test", "text", "id"),
    "spgispeech": ("spgispeech", "test", "text", "id"),
    "earnings22": ("earnings22", "test", "text", "id"),
    "ami": ("ami", "test", "text", "id"),
}


def extract_audio(audio_info):
    """Extract audio array and sampling rate from a HF dataset audio entry.

    Handles both the legacy dict format ({"array": ..., "sampling_rate": ...})
    and the newer AudioDecoder object from torchcodec-based datasets library.

    Args:
        audio_info: Value of the dataset's "audio" column, or None.

    Returns:
        Tuple of (numpy array, int sampling rate), or (None, None) when the
        entry carries no decodable audio.
    """
    if audio_info is None:
        return None, None
    try:
        # NOTE(review): assumes AudioDecoder also supports ["array"] /
        # ["sampling_rate"] subscripting; if it does not, the TypeError below
        # silently drops the sample — confirm against the datasets version used.
        audio_array = np.array(audio_info["array"])
        sampling_rate = int(audio_info["sampling_rate"])
        return audio_array, sampling_rate
    except (KeyError, TypeError, IndexError):
        # Malformed or unsupported audio payloads are treated as "no audio".
        return None, None


def format_entry(entry, dataset_name, audio_dir, text_field, id_field, with_audio):
    """Format a dataset entry into JSONL and optionally save the audio file.

    Args:
        entry: One HF dataset row (dict-like) with text, id, and "audio" columns.
        dataset_name: Leaderboard dataset key; used for paths and metrics subset.
        audio_dir: Directory where the FLAC file is written when with_audio.
        text_field: Column holding the reference transcript.
        id_field: Column holding the sample id.
        with_audio: When True, write the decoded audio to audio_dir.

    Returns:
        The formatted JSONL-ready dict, or None when the entry has an empty
        transcript or its audio is shorter than MIN_AUDIO_DURATION.
    """
    text = entry[text_field].strip()
    if not text:
        return None

    # Ids may contain "/" (e.g. nested paths); flatten for a flat filename.
    sample_id = str(entry[id_field]).replace("/", "_")
    audio_filename = f"{Path(sample_id).stem}.flac"

    audio_array, sampling_rate = extract_audio(entry.get("audio"))
    duration = None
    if audio_array is not None and sampling_rate is not None:
        duration = len(audio_array) / sampling_rate
        # Very short clips break mel spectrogram computation downstream.
        if duration < MIN_AUDIO_DURATION:
            return None
        if with_audio:
            sf.write(str(audio_dir / audio_filename), audio_array, sampling_rate)

    user_message = {"role": "user", "content": "Transcribe the following audio."}
    # Path is the in-container mount point used at eval time, not a local path.
    audio_meta = {"path": f"/dataset/asr-leaderboard/data/{dataset_name}/{audio_filename}"}
    if duration is not None:
        audio_meta["duration"] = float(duration)
    user_message["audio"] = audio_meta

    formatted = {
        "task_type": "ASR",
        "expected_answer": text,
        "messages": [{"role": "system", "content": SYSTEM_MESSAGE}, user_message],
        "subset_for_metrics": dataset_name,
        "id": entry[id_field],
    }

    if "speaker_id" in entry:
        formatted["speaker_id"] = entry["speaker_id"]

    return formatted


def prepare_dataset(dataset_name, output_dir, with_audio=True):
    """Download, decode, and write a single ASR dataset to JSONL + audio files.

    Args:
        dataset_name: Key into DATASET_CONFIGS.
        output_dir: Path where "{dataset_name}.jsonl" and "data/{dataset_name}/"
            are created.
        with_audio: When True, decode audio at AUDIO_SAMPLE_RATE and save FLACs.

    Returns:
        Number of samples written to the JSONL file.

    Raises:
        ValueError: If dataset_name is not a known DATASET_CONFIGS key.
    """
    if dataset_name not in DATASET_CONFIGS:
        raise ValueError(f"Unknown dataset: {dataset_name}. Available: {list(DATASET_CONFIGS.keys())}")

    hf_config, hf_split, text_field, id_field = DATASET_CONFIGS[dataset_name]

    print(f"Loading {dataset_name} from {HF_REPO} (config={hf_config}, split={hf_split})...")
    dataset = load_dataset(HF_REPO, hf_config, split=hf_split)
    if with_audio and "audio" in dataset.column_names:
        # Decode/resample to a uniform rate so all saved FLACs match.
        dataset = dataset.cast_column("audio", Audio(sampling_rate=AUDIO_SAMPLE_RATE))

    output_file = output_dir / f"{dataset_name}.jsonl"
    audio_dir = output_dir / "data" / dataset_name

    if with_audio:
        audio_dir.mkdir(parents=True, exist_ok=True)
        print(f"Saving audio files to {audio_dir}")

    print(f"Processing {len(dataset)} samples from {dataset_name}...")

    count = 0
    skipped = 0
    with open(output_file, "w", encoding="utf-8") as fout:
        for entry in tqdm(dataset, desc=dataset_name):
            formatted = format_entry(entry, dataset_name, audio_dir, text_field, id_field, with_audio)
            if formatted is None:
                # Empty transcript or audio shorter than MIN_AUDIO_DURATION.
                skipped += 1
                continue
            fout.write(json.dumps(formatted) + "\n")
            count += 1

    if skipped > 0:
        print(f"Skipped {skipped} samples (empty text or audio < {MIN_AUDIO_DURATION}s)")

    print(f"Saved {count} samples to {output_file}")
    return count
Expand All @@ -157,25 +158,19 @@ def main():
output_dir.mkdir(parents=True, exist_ok=True)

with_audio = not args.no_audio

if args.no_audio:
if not with_audio:
print("Running without saving audio files.")
else:
print("Running with audio. Saving to data/{dataset}/")

datasets_to_prepare = list(DATASET_CONFIGS.keys()) if "all" in args.datasets else args.datasets

total_samples = 0
for dataset_name in datasets_to_prepare:
total_samples += prepare_dataset(dataset_name, output_dir, with_audio=with_audio)

# Combine all dataset JSONLs into test.jsonl
combined_file = output_dir / "test.jsonl"
print(f"\nCreating combined file: {combined_file}")

all_jsonl_files = sorted(output_dir.glob("*.jsonl"))
dataset_files = [f for f in all_jsonl_files if f.name != "test.jsonl"]

dataset_files = sorted(f for f in output_dir.glob("*.jsonl") if f.name != "test.jsonl")
combined_count = 0
with open(combined_file, "w", encoding="utf-8") as fout:
for dataset_file in dataset_files:
Expand Down
106 changes: 106 additions & 0 deletions recipes/asr/run_hf_leaderboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""HuggingFace Open ASR Leaderboard evaluation for NeMo ASR models.

Runs the HF Open ASR Leaderboard benchmark (8 datasets, WER metric) on
NeMo ASR models using the unified server with the nemo_asr or salm backend.

Uses 1 GPU for the server (NeMo ASR model). The evaluation client runs
on CPU alongside it.

Example usage::

# Evaluate parakeet-v3 (traditional ASR model -> nemo_asr backend)
python recipes/asr/run_hf_leaderboard.py \\
--model nvidia/parakeet-tdt-0.6b-v3 \\
--cluster oci_iad \\
--output_dir /lustre/.../parakeet-v3-asr-leaderboard

# Evaluate canary-qwen-2.5b (SALM model -> salm backend)
python recipes/asr/run_hf_leaderboard.py \\
--model nvidia/canary-qwen-2.5b \\
--backend salm \\
--cluster oci_iad \\
--output_dir /lustre/.../canary-qwen-asr-leaderboard
"""

import argparse

from nemo_skills.pipeline.cli import eval as run_eval
from nemo_skills.pipeline.cli import wrap_arguments

DEFAULT_SERVER_CONTAINER = "nvcr.io/nvidia/nemo:25.11"
DEFAULT_INSTALLATION_COMMAND = "pip install -r requirements/audio.txt"


def main():
    """Parse CLI arguments and launch the asr-leaderboard eval pipeline."""
    parser = argparse.ArgumentParser(description="Run HF Open ASR Leaderboard evaluation on a NeMo ASR model")
    parser.add_argument("--model", required=True, help="NeMo ASR model name or path (e.g. nvidia/canary-qwen-2.5b)")
    parser.add_argument("--cluster", required=True, help="Cluster name (e.g. oci_iad)")
    parser.add_argument("--output_dir", required=True, help="Directory for evaluation outputs")
    parser.add_argument("--data_dir", default="/dataset", help="Root data directory (must contain asr-leaderboard/)")
    parser.add_argument("--server_container", default=DEFAULT_SERVER_CONTAINER, help="NeMo container image")
    parser.add_argument("--server_gpus", type=int, default=1, help="Number of GPUs for the ASR server")
    parser.add_argument(
        "--backend",
        default="nemo_asr",
        choices=["nemo_asr", "salm"],
        help="Server backend: nemo_asr for traditional ASR models (parakeet), salm for SALM models (canary-qwen)",
    )
    parser.add_argument("--batch_size", type=int, default=16, help="NeMo ASR transcription batch size")
    parser.add_argument(
        "--num_chunks", type=int, default=None, help="Split dataset into N chunks for data parallelism"
    )
    parser.add_argument("--expname", default="asr-leaderboard", help="Experiment name")
    parser.add_argument("--partition", default=None, help="Slurm partition (e.g. interactive)")
    parser.add_argument("--config_dir", default=None, help="Directory containing cluster config YAMLs")
    parser.add_argument("--split", default=None, help="Dataset split to evaluate (default: test = all datasets)")

    args = parser.parse_args()

    run_eval(
        ctx=wrap_arguments(
            "++prompt_format=openai "
            "++prompt_config=null "
            "++enable_audio=true "
            "++server.server_type=vllm_multimodal "
            "++max_concurrent_requests=16 "
            "++inference.tokens_to_generate=256"
        ),
        cluster=args.cluster,
        output_dir=args.output_dir,
        benchmarks="asr-leaderboard",
        model=args.model,
        server_type="generic",
        server_gpus=args.server_gpus,
        # MKL env vars work around threading-layer conflicts when NeMo starts
        # inside the container; the unified server hosts the chosen backend.
        server_entrypoint=(
            "MKL_SERVICE_FORCE_INTEL=1 MKL_THREADING_LAYER=GNU "
            "MKL_NUM_THREADS=1 VML_NUM_THREADS=1 "
            "python -m nemo_skills.inference.server.serve_unified"
        ),
        server_args=f"--backend {args.backend} --batch_size {args.batch_size}",
        server_container=args.server_container,
        num_chunks=args.num_chunks,
        data_dir=args.data_dir,
        config_dir=args.config_dir,
        partition=args.partition,
        split=args.split,
        installation_command=DEFAULT_INSTALLATION_COMMAND,
        expname=args.expname,
    )


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions recipes/multimodal/server/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Available backends:
- magpie_tts: MagpieTTS text-to-speech (audio output from text input)
- nemo_asr: NeMo ASR speech-to-text (text output from audio input)
- salm: SALM speech-to-text using chat-style generate API (e.g. canary-qwen-2.5b)

Backends are lazily loaded to avoid importing heavy dependencies upfront.
"""
Expand All @@ -40,6 +41,7 @@
# Backend name -> (module name, class name). Classes are loaded lazily to
# avoid importing heavy dependencies upfront.
BACKEND_REGISTRY = {
    "magpie_tts": ("magpie_tts_backend", "MagpieTTSBackend"),
    "nemo_asr": ("nemo_asr_backend", "NeMoASRBackend"),
    "salm": ("salm_backend", "SALMBackend"),
}


Expand Down
14 changes: 12 additions & 2 deletions recipes/multimodal/server/backends/nemo_asr_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,13 @@ def _parse_single_hypothesis(self, hyp: Any) -> tuple[str, List[Dict[str, Any]]]
return hyp, []

if isinstance(hyp, dict):
text = hyp.get("text") or hyp.get("pred_text") or hyp.get("transcript") or ""
text = hyp.get("text")
if text is None:
text = hyp.get("pred_text")
if text is None:
text = hyp.get("transcript")
if text is None:
text = ""
words = hyp.get("words")
if words is None:
ts = hyp.get("timestamp")
Expand All @@ -229,7 +235,11 @@ def _parse_single_hypothesis(self, hyp: Any) -> tuple[str, List[Dict[str, Any]]]
words = ts["word"]
return text, self._normalize_words(words)

text = getattr(hyp, "text", None) or getattr(hyp, "pred_text", None) or str(hyp)
text = getattr(hyp, "text", None)
if text is None:
text = getattr(hyp, "pred_text", None)
if text is None:
text = ""
Comment thread
melllinia marked this conversation as resolved.
words = getattr(hyp, "words", None)
if words is None:
ts = getattr(hyp, "timestamp", None)
Expand Down
Loading
Loading