Skip to content

Commit 6f4f93b

Browse files
committed
addressing review comments
Signed-off-by: mmkrtchyan <mmkrtchyan@nvidia.com>
1 parent 0f4a275 commit 6f4f93b

3 files changed

Lines changed: 11 additions & 7 deletions

File tree

recipes/asr/run_hf_leaderboard.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838

3939
import argparse
4040

41-
from nemo_skills.pipeline.cli import eval, wrap_arguments
41+
from nemo_skills.pipeline.cli import eval as run_eval
42+
from nemo_skills.pipeline.cli import wrap_arguments
4243

4344
DEFAULT_SERVER_CONTAINER = "nvcr.io/nvidia/nemo:25.11"
4445
DEFAULT_INSTALLATION_COMMAND = "pip install -r requirements/audio.txt"
@@ -69,7 +70,7 @@ def main():
6970

7071
args = parser.parse_args()
7172

72-
eval(
73+
run_eval(
7374
ctx=wrap_arguments(
7475
"++prompt_format=openai "
7576
"++prompt_config=null "

recipes/multimodal/server/backends/nemo_asr_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,9 @@ def _parse_single_hypothesis(self, hyp: Any) -> tuple[str, List[Dict[str, Any]]]
221221
if text is None:
222222
text = hyp.get("pred_text")
223223
if text is None:
224-
text = hyp.get("transcript", "")
224+
text = hyp.get("transcript")
225+
if text is None:
226+
text = ""
225227
words = hyp.get("words")
226228
if words is None:
227229
ts = hyp.get("timestamp")

recipes/multimodal/server/backends/salm_backend.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ class SALMConfig(BackendConfig):
3333
"""Configuration for SALM backend."""
3434

3535
model_name: Optional[str] = None
36-
batch_size: int = 16
3736
warmup: bool = True
3837
user_prompt: str = DEFAULT_ASR_PROMPT
3938

@@ -50,7 +49,6 @@ def from_dict(cls, d: Dict[str, Any]) -> "SALMConfig":
5049
"temperature",
5150
"top_p",
5251
"top_k",
53-
"batch_size",
5452
"warmup",
5553
"user_prompt",
5654
}
@@ -144,6 +142,8 @@ def validate_request(self, request: GenerationRequest) -> Optional[str]:
144142
)
145143
if not has_audio:
146144
return "SALM backend requires audio input"
145+
if request.audio_bytes_list is not None and len(request.audio_bytes_list) > 1:
146+
return "SALM backend currently supports one audio input per request"
147147
return None
148148

149149
def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]:
@@ -170,9 +170,10 @@ def generate(self, requests: List[GenerationRequest]) -> List[GenerationResult]:
170170
results[idx] = GenerationResult(error=str(e), request_id=req.request_id)
171171

172172
if temp_paths:
173-
first_extra = requests[valid_indices[0]].extra_params or {}
173+
first_req = requests[valid_indices[0]]
174+
first_extra = first_req.extra_params or {}
174175
user_prompt = first_extra.get("user_prompt", self.salm_config.user_prompt)
175-
max_new_tokens = int(first_extra.get("max_new_tokens", self.config.max_new_tokens))
176+
max_new_tokens = first_req.max_new_tokens or self.config.max_new_tokens
176177
audio_tag = self._model.audio_locator_tag
177178

178179
prompts = []

0 commit comments

Comments
 (0)