Skip to content

Commit 49ab5f2

Browse files
committed
fix: correct MiniMax M3 partial cache resume (#1888)
1 parent a3c949f commit 49ab5f2

1 file changed

Lines changed: 122 additions & 4 deletions

File tree

omlx/scheduler.py

Lines changed: 122 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5465,6 +5465,12 @@ def add_request(self, request: Request) -> None:
54655465
request.remaining_tokens = request.prompt_token_ids[
54665466
block_table.num_tokens :
54675467
]
5468+
if self._align_minimax_m3_partial_cache_to_prefill_step(request):
5469+
request.cached_tokens = block_table.num_tokens
5470+
request.shared_prefix_blocks = len(block_table.block_ids)
5471+
request.remaining_tokens = request.prompt_token_ids[
5472+
block_table.num_tokens :
5473+
]
54685474
# For exact prefix hits we need cache state at (N-1) and the
54695475
# last prompt token as input to produce the first decode logit.
54705476
# Reusing cache state at N and feeding the last token again
@@ -6132,34 +6138,146 @@ def _cleanup_specprefill(self, request_id: str) -> None:
61326138

61336139
def _trim_prompt_cache_for_generation(self, cache_list: list[Any]) -> bool:
    """Drop the newest token from every cache layer for exact-hit kickoff.

    Thin wrapper that delegates to ``_trim_prompt_cache_by_tokens`` with a
    fixed count of one token.

    Args:
        cache_list: Per-layer prompt cache objects to trim.

    Returns:
        The result of the one-token trim as reported by
        ``_trim_prompt_cache_by_tokens``.
    """
    single_token = 1
    return self._trim_prompt_cache_by_tokens(cache_list, single_token)
6142+
6143+
def _trim_prompt_cache_by_tokens(self, cache_list: list[Any], n: int) -> bool:
    """Trim each cache layer by n tokens.

    Every leaf of every layer is checked for a callable ``trim`` before
    anything is mutated. Without that check, a layer that cannot be trimmed
    would only be discovered after the layers ahead of it had already been
    shortened, leaving the layers at different lengths while the method
    reports failure.

    Args:
        cache_list: Per-layer cache objects. A layer may be a container
            exposing its children through a ``caches`` list or tuple.
        n: Number of tokens to remove from each layer.

    Returns:
        False when ``cache_list`` is empty, when any leaf has no callable
        ``trim`` (nothing is trimmed in that case), or when a layer's trim
        reports failure. True when ``n <= 0`` (nothing to do) or when every
        layer was trimmed.
    """
    if not cache_list:
        return False
    if n <= 0:
        return True

    # Pre-flight walk using the same container rule as the per-layer trim:
    # objects with a list/tuple ``caches`` are containers, the rest are leaves.
    pending = list(cache_list)
    while pending:
        node = pending.pop()
        children = getattr(node, "caches", None)
        if isinstance(children, (list, tuple)):
            pending.extend(children)
        elif not callable(getattr(node, "trim", None)):
            return False

    # Stops at the first layer whose trim reports failure, as before.
    return all(
        self._trim_cache_tree_by_tokens(cache_obj, n) for cache_obj in cache_list
    )
61426154

61436155
def _trim_cache_tree_by_one(self, cache_obj: Any) -> bool:
    """Remove a single token from a cache object, descending into CacheList.

    Thin wrapper that delegates to ``_trim_cache_tree_by_tokens`` with a
    fixed count of one token.

    Args:
        cache_obj: Cache object (or container of cache objects) to trim.

    Returns:
        The result of the one-token trim as reported by
        ``_trim_cache_tree_by_tokens``.
    """
    one_token = 1
    return self._trim_cache_tree_by_tokens(cache_obj, one_token)
6158+
6159+
def _trim_cache_tree_by_tokens(self, cache_obj: Any, n: int) -> bool:
    """Trim n tokens from cache object (recursively for CacheList).

    Args:
        cache_obj: A leaf cache exposing ``trim(n)``, or a container whose
            ``caches`` attribute is a list or tuple of further cache objects.
        n: Number of tokens to remove.

    Returns:
        True when every leaf reached was trimmed by at least ``n`` tokens,
        or when ``n <= 0`` and the leaf is trimmable (nothing to do).
        False when a leaf has no callable ``trim``, when ``trim`` raises, or
        when it reports fewer than ``n`` tokens removed.
    """
    sub_caches = getattr(cache_obj, "caches", None)
    if isinstance(sub_caches, (list, tuple)):
        # Container: succeed only if every child succeeds. ``all`` stops at
        # the first child that fails.
        return all(
            self._trim_cache_tree_by_tokens(sub_cache, n)
            for sub_cache in sub_caches
        )

    trim_fn = getattr(cache_obj, "trim", None)
    if not callable(trim_fn):
        return False

    # A non-positive count means there is nothing to remove. Do not forward
    # it: a negative count is not guaranteed to be a no-op for every trim
    # implementation, and the ``>= n`` check below would still report
    # success for it.
    if n <= 0:
        return True

    try:
        trimmed = trim_fn(n)
        # A trim that returns nothing is taken as having succeeded.
        if trimmed is None:
            return True
        return int(trimmed) >= n
    except Exception:
        # Deliberately best-effort: any failure is reported as "not trimmed"
        # so the caller can fall back instead of crashing.
        return False
61626179

6180+
def _cache_tree_has_class_name(
    self,
    cache_obj: Any,
    class_names: frozenset[str],
) -> bool:
    """Report whether any node of a cache tree has a class named in ``class_names``.

    A node matches when the exact ``__name__`` of its type is in
    ``class_names``. Nodes whose ``caches`` attribute is a list or tuple are
    treated as containers and their children are searched too.

    Args:
        cache_obj: Root of the cache tree to inspect.
        class_names: Class names to look for.

    Returns:
        True as soon as a matching node is found, otherwise False.
    """
    pending = [cache_obj]
    while pending:
        node = pending.pop()
        if type(node).__name__ in class_names:
            return True
        children = getattr(node, "caches", None)
        if isinstance(children, (list, tuple)):
            pending.extend(children)
    return False
6195+
6196+
def _align_minimax_m3_partial_cache_to_prefill_step(
    self,
    request: "Request",
) -> bool:
    """Align MiniMax M3 partial hits to external prefill chunk boundaries.

    When a request's cached prefix covers only part of the prompt and its
    prompt cache contains a ``MiniMaxM3KVCache``, the cached prefix is
    shortened to the largest length that is a multiple of both
    ``prefill_step_size`` and ``paged_cache_block_size``. The prompt cache is
    trimmed by the difference, and the block-table entries past the new
    boundary are freed and dropped.

    The boundary is the least common multiple of the two sizes. Flooring to
    the prefill step and then to the block size would give a length that is
    no longer a multiple of the prefill step whenever the step is not itself
    a multiple of the block size.

    Args:
        request: Request whose ``prompt_cache`` and ``block_table`` are
            adjusted in place on success.

    Returns:
        True when the cache and block table were shortened; the caller must
        then recompute anything derived from ``block_table.num_tokens``.
        False when no alignment applied or it could not be completed. The
        block table is left unchanged in that case.
    """
    import math  # local import: only needed for the gcd below

    cache_list = request.prompt_cache
    block_table = request.block_table
    prompt_tokens = request.prompt_token_ids or []
    if not cache_list or block_table is None or not block_table.block_ids:
        return False

    cached_tokens = block_table.num_tokens
    # Only partial hits qualify: something is cached, but not the whole prompt.
    if cached_tokens <= 0 or cached_tokens >= len(prompt_tokens):
        return False

    minimax_m3_names = frozenset({"MiniMaxM3KVCache"})
    has_minimax_m3 = any(
        self._cache_tree_has_class_name(cache_obj, minimax_m3_names)
        for cache_obj in cache_list
    )
    if not has_minimax_m3:
        return False

    block_size = int(getattr(self.config, "paged_cache_block_size", 0) or 0)
    prefill_step = int(getattr(self.config, "prefill_step_size", 0) or 0)
    if block_size <= 0 or prefill_step <= block_size:
        return False

    # Smallest length that is both a whole number of prefill steps and a
    # whole number of blocks (equals prefill_step when it is a multiple of
    # block_size).
    boundary = prefill_step // math.gcd(prefill_step, block_size) * block_size
    aligned_tokens = (cached_tokens // boundary) * boundary
    if aligned_tokens <= 0 or aligned_tokens >= cached_tokens:
        return False

    # Count the leading blocks whose token totals fit under the boundary.
    # Blocks unknown to the manager, or with a falsy token_count, are
    # counted as holding a full block_size.
    manager = self.paged_cache_manager
    target_block_count = 0
    target_tokens = 0
    for block_id in block_table.block_ids:
        block = (
            manager.allocated_blocks.get(block_id) if manager is not None else None
        )
        token_count = int(getattr(block, "token_count", block_size) or block_size)
        if target_tokens + token_count > aligned_tokens:
            break
        target_tokens += token_count
        target_block_count += 1

    # The kept blocks must end exactly on the boundary, otherwise the block
    # table and the trimmed cache would disagree about the prefix length.
    if target_tokens != aligned_tokens:
        logger.debug(
            "MiniMax M3 partial cache alignment skipped for %s: cannot align "
            "block table from %d to %d tokens",
            request.request_id,
            cached_tokens,
            aligned_tokens,
        )
        return False

    trim_tokens = cached_tokens - aligned_tokens
    # NOTE(review): if the trim fails part-way, some layers may already have
    # been shortened while this method reports "unchanged" - confirm that
    # callers do not keep using such a cache.
    if not self._trim_prompt_cache_by_tokens(cache_list, trim_tokens):
        logger.debug(
            "MiniMax M3 partial cache alignment skipped for %s: cache trim "
            "by %d tokens failed",
            request.request_id,
            trim_tokens,
        )
        return False

    # Free the blocks past the boundary only after the cache trim succeeded.
    dropped_block_ids = block_table.block_ids[target_block_count:]
    if manager is not None:
        for block_id in dropped_block_ids:
            manager.free_block(block_id)
    block_table.block_ids = block_table.block_ids[:target_block_count]
    block_table.num_tokens = aligned_tokens

    logger.info(
        "MiniMax M3 partial cache aligned to prefill step for %s: "
        "%d -> %d tokens, dropped %d block(s)",
        request.request_id,
        cached_tokens,
        aligned_tokens,
        len(dropped_block_ids),
    )
    return True
6280+
61636281
def _remove_uid_from_active_batch(self, uid: int) -> None:
61646282
"""Remove UID from BatchGenerator safely.
61656283

0 commit comments

Comments
 (0)