Skip to content

Commit a6f3ffe

Browse files
committed
Only update KV prefix cache on a good cache hit
1 parent 39c39e8 commit a6f3ffe

3 files changed

Lines changed: 38 additions & 4 deletions

File tree

src/exo/worker/engines/mlx/generator/batch_generate.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@
4040
patch_embed_tokens,
4141
prefill,
4242
)
43-
from exo.worker.engines.mlx.utils_mlx import fix_unmatched_think_end_tokens
43+
from exo.worker.engines.mlx.utils_mlx import (
44+
fix_unmatched_think_end_tokens,
45+
system_prompt_token_count,
46+
)
4447
from exo.worker.engines.mlx.vision import (
4548
MediaRegion,
4649
VisionProcessor,
@@ -211,12 +214,16 @@ def submit(
211214
c._idx = c.max_size
212215

213216
if not is_bench:
217+
min_prefix_hit_length = max(
218+
1000, system_prompt_token_count(task_params, self.tokenizer)
219+
)
214220
self._save_prefix_cache(
215221
all_prompt_tokens,
216222
list(cache),
217223
cache_snapshots,
218224
prefix_hit_length,
219225
matched_index,
226+
min_prefix_hit_length,
220227
media_regions,
221228
)
222229

@@ -426,6 +433,7 @@ def _save_prefix_cache(
426433
cache_snapshots: list[CacheSnapshot] | None,
427434
prefix_hit_length: int,
428435
matched_index: int | None,
436+
min_prefix_hit_length: int = 1000,
429437
media_regions: list[MediaRegion] | None = None,
430438
) -> None:
431439
if self.kv_prefix_cache is None:
@@ -438,7 +446,8 @@ def _save_prefix_cache(
438446
else 0.0
439447
)
440448
if matched_index is not None and (
441-
prefix_hit_length > 1000 or hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
449+
prefix_hit_length >= min_prefix_hit_length
450+
and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
442451
):
443452
self.kv_prefix_cache.update_kv_cache(
444453
matched_index,

src/exo/worker/engines/mlx/generator/generate.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
apply_chat_template,
5656
fix_unmatched_think_end_tokens,
5757
mx_barrier,
58+
system_prompt_token_count,
5859
)
5960
from exo.worker.engines.mlx.vision import (
6061
MediaRegion,
@@ -498,6 +499,7 @@ def mlx_generate(
498499
# Encode prompt once at the top and fix unmatched think tags
499500
all_prompt_tokens = encode_prompt(tokenizer, prompt)
500501
all_prompt_tokens = fix_unmatched_think_end_tokens(all_prompt_tokens, tokenizer)
502+
min_prefix_hit_length = max(1000, system_prompt_token_count(task, tokenizer))
501503

502504
vision: VisionResult | None = None
503505
if vision_processor is not None:
@@ -714,8 +716,8 @@ def mlx_generate(
714716
else 0.0
715717
)
716718
if matched_index is not None and (
717-
prefix_hit_length > 1000
718-
or hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
719+
prefix_hit_length >= min_prefix_hit_length
720+
and hit_ratio >= _MIN_PREFIX_HIT_RATIO_TO_UPDATE
719721
):
720722
kv_prefix_cache.update_kv_cache(
721723
matched_index,

src/exo/worker/engines/mlx/utils_mlx.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,29 @@ def apply_chat_template(
633633
return prompt
634634

635635

636+
def system_prompt_token_count(
    task_params: TextGenerationTaskParams,
    tokenizer: TokenizerWrapper,
) -> int:
    """Approximate how many tokens the system-prompt text of a request occupies.

    Text is gathered from messages whose role is ``"system"`` or
    ``"developer"``:

    * When ``task_params.chat_template_messages`` is not ``None``, only those
      messages are consulted, and any content that is not a plain string is
      skipped.
    * Otherwise ``task_params.instructions`` (when truthy) comes first,
      followed by the content of matching messages in ``task_params.input``.

    The gathered pieces are joined with single spaces and encoded without
    special tokens; the number of resulting tokens is returned. If nothing
    was gathered the result is ``0`` and the tokenizer is not called.
    """
    system_roles = ("system", "developer")
    segments: list[str]

    template_messages = task_params.chat_template_messages
    if template_messages is not None:
        candidates = (
            message.get("content", "")  # type: ignore
            for message in template_messages
            if message.get("role") in system_roles
        )
        # Non-string content is left out of the count.
        segments = [text for text in candidates if isinstance(text, str)]
    else:
        segments = [task_params.instructions] if task_params.instructions else []
        segments.extend(
            message.content
            for message in task_params.input
            if message.role in system_roles
        )

    if not segments:
        return 0
    joined = " ".join(segments)
    return len(tokenizer.encode(joined, add_special_tokens=False))
657+
658+
636659
def detect_thinking_prompt_suffix(prompt: str, tokenizer: TokenizerWrapper) -> bool:
637660
"""
638661
Detect if prompt ends with a thinking opening tag that should be

0 commit comments

Comments
 (0)