@@ -5465,6 +5465,12 @@ def add_request(self, request: Request) -> None:
54655465 request .remaining_tokens = request .prompt_token_ids [
54665466 block_table .num_tokens :
54675467 ]
5468+ if self ._align_minimax_m3_partial_cache_to_prefill_step (request ):
5469+ request .cached_tokens = block_table .num_tokens
5470+ request .shared_prefix_blocks = len (block_table .block_ids )
5471+ request .remaining_tokens = request .prompt_token_ids [
5472+ block_table .num_tokens :
5473+ ]
54685474 # For exact prefix hits we need cache state at (N-1) and the
54695475 # last prompt token as input to produce the first decode logit.
54705476 # Reusing cache state at N and feeding the last token again
@@ -6132,34 +6138,146 @@ def _cleanup_specprefill(self, request_id: str) -> None:
61326138
61336139 def _trim_prompt_cache_for_generation (self , cache_list : list [Any ]) -> bool :
61346140 """Trim each cache layer by one token for exact-hit generation kickoff."""
6141+ return self ._trim_prompt_cache_by_tokens (cache_list , 1 )
6142+
6143+ def _trim_prompt_cache_by_tokens (self , cache_list : list [Any ], n : int ) -> bool :
6144+ """Trim each cache layer by n tokens."""
61356145 if not cache_list :
61366146 return False
6147+ if n <= 0 :
6148+ return True
61376149
61386150 for cache_obj in cache_list :
6139- if not self ._trim_cache_tree_by_one (cache_obj ):
6151+ if not self ._trim_cache_tree_by_tokens (cache_obj , n ):
61406152 return False
61416153 return True
61426154
61436155 def _trim_cache_tree_by_one (self , cache_obj : Any ) -> bool :
61446156 """Trim one token from cache object (recursively for CacheList)."""
6157+ return self ._trim_cache_tree_by_tokens (cache_obj , 1 )
6158+
6159+ def _trim_cache_tree_by_tokens (self , cache_obj : Any , n : int ) -> bool :
6160+ """Trim n tokens from cache object (recursively for CacheList)."""
61456161 sub_caches = getattr (cache_obj , "caches" , None )
61466162 if isinstance (sub_caches , (list , tuple )):
61476163 return all (
6148- self ._trim_cache_tree_by_one (sub_cache ) for sub_cache in sub_caches
6164+ self ._trim_cache_tree_by_tokens (sub_cache , n )
6165+ for sub_cache in sub_caches
61496166 )
61506167
61516168 trim_fn = getattr (cache_obj , "trim" , None )
61526169 if not callable (trim_fn ):
61536170 return False
61546171
61556172 try :
6156- trimmed = trim_fn (1 )
6173+ trimmed = trim_fn (n )
61576174 if trimmed is None :
61586175 return True
6159- return int (trimmed ) >= 1
6176+ return int (trimmed ) >= n
61606177 except Exception :
61616178 return False
61626179
6180+ def _cache_tree_has_class_name (
6181+ self ,
6182+ cache_obj : Any ,
6183+ class_names : frozenset [str ],
6184+ ) -> bool :
6185+ """Return True when a cache tree contains one of the named cache classes."""
6186+ if type (cache_obj ).__name__ in class_names :
6187+ return True
6188+ sub_caches = getattr (cache_obj , "caches" , None )
6189+ if isinstance (sub_caches , (list , tuple )):
6190+ return any (
6191+ self ._cache_tree_has_class_name (sub_cache , class_names )
6192+ for sub_cache in sub_caches
6193+ )
6194+ return False
6195+
6196+ def _align_minimax_m3_partial_cache_to_prefill_step (
6197+ self ,
6198+ request : "Request" ,
6199+ ) -> bool :
6200+ """Align MiniMax M3 partial hits to external prefill chunk boundaries."""
6201+ cache_list = request .prompt_cache
6202+ block_table = request .block_table
6203+ prompt_tokens = request .prompt_token_ids or []
6204+ if not cache_list or block_table is None or not block_table .block_ids :
6205+ return False
6206+ if (
6207+ block_table .num_tokens <= 0
6208+ or block_table .num_tokens >= len (prompt_tokens )
6209+ ):
6210+ return False
6211+
6212+ minimax_m3_names = frozenset ({"MiniMaxM3KVCache" })
6213+ has_minimax_m3 = any (
6214+ self ._cache_tree_has_class_name (cache_obj , minimax_m3_names )
6215+ for cache_obj in cache_list
6216+ )
6217+ if not has_minimax_m3 :
6218+ return False
6219+
6220+ block_size = int (getattr (self .config , "paged_cache_block_size" , 0 ) or 0 )
6221+ prefill_step = int (getattr (self .config , "prefill_step_size" , 0 ) or 0 )
6222+ if block_size <= 0 or prefill_step <= block_size :
6223+ return False
6224+
6225+ aligned_tokens = (block_table .num_tokens // prefill_step ) * prefill_step
6226+ aligned_tokens = (aligned_tokens // block_size ) * block_size
6227+ if aligned_tokens <= 0 or aligned_tokens >= block_table .num_tokens :
6228+ return False
6229+
6230+ target_block_count = 0
6231+ target_tokens = 0
6232+ for block_id in block_table .block_ids :
6233+ block = (
6234+ self .paged_cache_manager .allocated_blocks .get (block_id )
6235+ if self .paged_cache_manager is not None
6236+ else None
6237+ )
6238+ token_count = int (getattr (block , "token_count" , block_size ) or block_size )
6239+ if target_tokens + token_count > aligned_tokens :
6240+ break
6241+ target_tokens += token_count
6242+ target_block_count += 1
6243+
6244+ if target_tokens != aligned_tokens :
6245+ logger .debug (
6246+ "MiniMax M3 partial cache alignment skipped for %s: cannot align "
6247+ "block table from %d to %d tokens" ,
6248+ request .request_id ,
6249+ block_table .num_tokens ,
6250+ aligned_tokens ,
6251+ )
6252+ return False
6253+
6254+ trim_tokens = block_table .num_tokens - aligned_tokens
6255+ if not self ._trim_prompt_cache_by_tokens (cache_list , trim_tokens ):
6256+ logger .debug (
6257+ "MiniMax M3 partial cache alignment skipped for %s: cache trim "
6258+ "by %d tokens failed" ,
6259+ request .request_id ,
6260+ trim_tokens ,
6261+ )
6262+ return False
6263+
6264+ dropped_block_ids = block_table .block_ids [target_block_count :]
6265+ if self .paged_cache_manager is not None :
6266+ for block_id in dropped_block_ids :
6267+ self .paged_cache_manager .free_block (block_id )
6268+ block_table .block_ids = block_table .block_ids [:target_block_count ]
6269+ block_table .num_tokens = aligned_tokens
6270+
6271+ logger .info (
6272+ "MiniMax M3 partial cache aligned to prefill step for %s: "
6273+ "%d -> %d tokens, dropped %d block(s)" ,
6274+ request .request_id ,
6275+ aligned_tokens + trim_tokens ,
6276+ aligned_tokens ,
6277+ len (dropped_block_ids ),
6278+ )
6279+ return True
6280+
61636281 def _remove_uid_from_active_batch (self , uid : int ) -> None :
61646282 """Remove UID from BatchGenerator safely.
61656283
0 commit comments