-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlibrispeech_asr.py
More file actions
143 lines (115 loc) · 5.53 KB
/
librispeech_asr.py
File metadata and controls
143 lines (115 loc) · 5.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import os
import datasets
import glob
# NOTE: local workaround loader script for the Hugging Face `datasets` library;
# kept vendored here because upstream remote loading has repeatedly broken.
# Example usage:
# train = load_dataset("./local_path_to/librispeech_asr.py", "clean", split="train_clean_100", trust_remote_code=True, storage_options={'client_kwargs': {'timeout': aiohttp.ClientTimeout(total=3600, connect=60)}}).rename_column("text", "transcription")
# test = load_dataset("./local_path_to/librispeech_asr.py", "clean", split="test_clean", trust_remote_code=True, storage_options={'client_kwargs': {'timeout': aiohttp.ClientTimeout(total=3600, connect=60)}}).rename_column("text", "transcription")
# LibriSpeech project homepage (surfaced as the dataset's `homepage` metadata).
_URL = "http://www.openslr.org/12"
# Base URL of the OpenSLR resource-12 download mirror.
_DL_URL = "http://www.openslr.org/resources/12/"
# Archive URLs per builder configuration. Only "clean" is defined; note that
# dev_clean is listed here but not wired into _split_generators below.
_DL_URLS = {
    "clean": {
        "train_clean_100": _DL_URL + "train-clean-100.tar.gz",
        "test_clean": _DL_URL + "test-clean.tar.gz",
        "dev_clean": _DL_URL + "dev-clean.tar.gz",
    },
}
class LibrispeechASRConfig(datasets.BuilderConfig):
    """BuilderConfig for LibriSpeechASR.

    Pins the dataset version to 2.1.0; all other options (``name``,
    ``description``, ...) are forwarded unchanged to
    ``datasets.BuilderConfig``.
    """

    def __init__(self, **kwargs):
        # Zero-argument super() is the idiomatic Python 3 form of the
        # original super(LibrispeechASRConfig, self) call.
        super().__init__(version=datasets.Version("2.1.0", ""), **kwargs)
class LibrispeechASR(datasets.GeneratorBasedBuilder):
    """LibriSpeech ASR dataset builder ("clean" subsets only).

    Downloads train-clean-100 and test-clean from openslr.org, then yields
    one example per ``.flac`` utterance, joining each audio file to its
    transcript via the per-chapter ``*.trans.txt`` files.
    """

    # Flush examples to disk in batches of 256 to bound writer memory.
    DEFAULT_WRITER_BATCH_SIZE = 256
    DEFAULT_CONFIG_NAME = "clean"
    BUILDER_CONFIGS = [
        LibrispeechASRConfig(name="clean", description="'Clean' speech."),
    ]

    def _info(self):
        """Return the dataset schema (features, supervised keys, homepage)."""
        return datasets.DatasetInfo(
            features=datasets.Features(
                {
                    "file": datasets.Value("string"),
                    "audio": datasets.Audio(sampling_rate=16_000),
                    "text": datasets.Value("string"),
                    "speaker_id": datasets.Value("int64"),
                    "chapter_id": datasets.Value("int64"),
                    "id": datasets.Value("string"),
                }
            ),
            supervised_keys=("file", "text"),
            homepage=_URL,
        )

    def _split_generators(self, dl_manager):
        """Download/extract the archives and return one SplitGenerator per split.

        Raises:
            ValueError: if the selected config is anything other than "clean".
        """
        # Guard clause: fail fast on unsupported configs instead of nesting.
        if self.config.name != "clean":
            raise ValueError(f"Configuration '{self.config.name}' not supported in this script version.")
        urls_to_download = {
            "train_clean_100": _DL_URLS["clean"]["train_clean_100"],
            "test_clean": _DL_URLS["clean"]["test_clean"],
        }
        archive_path = dl_manager.download(urls_to_download)
        local_extracted_archive = dl_manager.extract(archive_path)
        # Both splits are built identically; avoid duplicating the construction.
        return [
            datasets.SplitGenerator(
                name=split_name,
                gen_kwargs={
                    "path_to_extracted_archive": local_extracted_archive[split_name],
                },
            )
            for split_name in ("train_clean_100", "test_clean")
        ]

    @staticmethod
    def _load_transcriptions(path_to_extracted_archive):
        """Parse every ``*.trans.txt`` under the archive into {utterance_id: transcript}."""
        transcriptions = {}
        # Sort the glob results: raw glob order is filesystem-dependent, which
        # would make parse order (and any warning output) non-deterministic.
        trans_files = sorted(
            glob.glob(os.path.join(path_to_extracted_archive, "**", "*.trans.txt"), recursive=True)
        )
        for trans_file_path in trans_files:
            if "README.TXT" in os.path.basename(trans_file_path).upper():
                continue
            with open(trans_file_path, "r", encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    # Each line is "<utterance-id> <transcript text...>".
                    parts = line.split(" ", 1)
                    if len(parts) == 2:
                        utt_id, transcript = parts
                        transcriptions[utt_id] = transcript
                    else:
                        print(f"Warning: Skipping malformed line in {trans_file_path}: {line}")
        return transcriptions

    def _generate_examples(self, path_to_extracted_archive):
        """Yield (key, example) pairs for every transcribed ``.flac`` file.

        Utterances without a transcript, or whose id does not look like
        ``<speaker>-<chapter>-<utt>``, are skipped with a warning.
        """
        transcriptions_by_id = self._load_transcriptions(path_to_extracted_archive)
        # Sort so example order and integer keys are reproducible across runs
        # and filesystems (glob alone gives arbitrary order).
        audio_files = sorted(
            glob.glob(os.path.join(path_to_extracted_archive, "**", "*.flac"), recursive=True)
        )
        key = 0
        for audio_file_path in audio_files:
            utt_id_from_file = os.path.basename(audio_file_path)[: -len(".flac")]
            transcript = transcriptions_by_id.get(utt_id_from_file)
            if transcript is None:
                print(f"Warning: No transcription found for audio file {audio_file_path}. Skipping.")
                continue
            parts = utt_id_from_file.split("-")
            if len(parts) < 2:
                print(f"Warning: Skipping utterance {utt_id_from_file} due to unexpected ID format.")
                continue
            try:
                speaker_id = int(parts[0])
                chapter_id = int(parts[1])
            except ValueError:
                print(f"Warning: Skipping utterance {utt_id_from_file} due to non-integer speaker/chapter ID.")
                continue
            yield key, {
                "file": audio_file_path,
                "audio": audio_file_path,
                "text": transcript,
                "speaker_id": speaker_id,
                "chapter_id": chapter_id,
                "id": utt_id_from_file,
            }
            key += 1