@@ -20,6 +20,7 @@ limitations under the License.
2020#include < algorithm>
2121#include < cctype>
2222#include < memory>
23+ #include < sstream>
2324
2425#include " common/global_flags.h"
2526#include " common/metrics.h"
@@ -193,6 +194,163 @@ torch::Tensor to_cpu_int_tensor_for_read(const torch::Tensor& values) {
193194 .contiguous ();
194195}
195196
197+ std::string summarize_int32_values (const int32_t * values,
198+ size_t size,
199+ size_t limit) {
200+ std::ostringstream oss;
201+ oss << " [" ;
202+ const size_t print_size = std::min (size, limit);
203+ for (size_t i = 0 ; i < print_size; ++i) {
204+ if (i > 0 ) {
205+ oss << " , " ;
206+ }
207+ oss << values[i];
208+ }
209+ if (size > limit) {
210+ oss << " , ..." ;
211+ }
212+ oss << " ]" ;
213+ return oss.str ();
214+ }
215+
216+ std::string summarize_int32_vector (const std::vector<int32_t >& values,
217+ size_t limit = 32 ) {
218+ return summarize_int32_values (values.data (), values.size (), limit);
219+ }
220+
221+ std::string summarize_string_vector (const std::vector<std::string>& values,
222+ size_t limit = 8 ) {
223+ std::ostringstream oss;
224+ oss << " [" ;
225+ const size_t print_size = std::min (values.size (), limit);
226+ for (size_t i = 0 ; i < print_size; ++i) {
227+ if (i > 0 ) {
228+ oss << " , " ;
229+ }
230+ oss << values[i];
231+ }
232+ if (values.size () > limit) {
233+ oss << " , ..." ;
234+ }
235+ oss << " ]" ;
236+ return oss.str ();
237+ }
238+
239+ std::string summarize_int_tensor (const torch::Tensor& tensor,
240+ size_t limit = 64 ) {
241+ if (!tensor.defined ()) {
242+ return " undefined" ;
243+ }
244+ torch::Tensor flat = to_cpu_int_tensor_for_read (tensor);
245+ std::ostringstream oss;
246+ oss << " sizes=" << tensor.sizes () << " , values="
247+ << summarize_int32_values (flat.const_data_ptr <int32_t >(),
248+ static_cast <size_t >(flat.numel ()),
249+ limit);
250+ return oss.str ();
251+ }
252+
253+ std::string summarize_decode_states (
254+ const std::vector<EmbeddingCache::DecodeState>& states,
255+ size_t limit = 16 ) {
256+ std::ostringstream oss;
257+ oss << " [" ;
258+ const size_t print_size = std::min (states.size (), limit);
259+ for (size_t i = 0 ; i < print_size; ++i) {
260+ if (i > 0 ) {
261+ oss << " , " ;
262+ }
263+ const auto & state = states[i];
264+ oss << " {idx=" << i << " , valid=" << state.valid
265+ << " , token_id=" << state.token_id
266+ << " , position_offset=" << state.position_offset
267+ << " , prev_token_id=" << state.prev_token_id
268+ << " , all_draft_accepted=" << state.all_draft_accepted << " }" ;
269+ }
270+ if (states.size () > limit) {
271+ oss << " , ..." ;
272+ }
273+ oss << " ]" ;
274+ return oss.str ();
275+ }
276+
277+ bool has_accepted_prefix_longer_than_one (const torch::Tensor& accepted_tokens) {
278+ if (!accepted_tokens.defined () || accepted_tokens.dim () != 2 ) {
279+ return false ;
280+ }
281+ torch::Tensor flat = to_cpu_int_tensor_for_read (accepted_tokens);
282+ const int32_t * data = flat.const_data_ptr <int32_t >();
283+ const int64_t rows = accepted_tokens.size (0 );
284+ const int64_t width = accepted_tokens.size (1 );
285+ for (int64_t row = 0 ; row < rows; ++row) {
286+ int32_t accepted_len = 0 ;
287+ for (int64_t col = 0 ; col < width; ++col) {
288+ if (data[row * width + col] < 0 ) {
289+ break ;
290+ }
291+ ++accepted_len;
292+ }
293+ if (accepted_len > 1 ) {
294+ return true ;
295+ }
296+ }
297+ return false ;
298+ }
299+
300+ std::string summarize_accepted_tokens (const torch::Tensor& accepted_tokens,
301+ size_t max_rows = 8 ,
302+ size_t max_width = 8 ) {
303+ if (!accepted_tokens.defined ()) {
304+ return " undefined" ;
305+ }
306+ if (accepted_tokens.dim () != 2 ) {
307+ return summarize_int_tensor (accepted_tokens);
308+ }
309+ torch::Tensor flat = to_cpu_int_tensor_for_read (accepted_tokens);
310+ const int32_t * data = flat.const_data_ptr <int32_t >();
311+ const int64_t rows = accepted_tokens.size (0 );
312+ const int64_t width = accepted_tokens.size (1 );
313+
314+ std::ostringstream oss;
315+ oss << " sizes=" << accepted_tokens.sizes () << " , rows=[" ;
316+ const int64_t print_rows =
317+ std::min<int64_t >(rows, static_cast <int64_t >(max_rows));
318+ for (int64_t row = 0 ; row < print_rows; ++row) {
319+ if (row > 0 ) {
320+ oss << " , " ;
321+ }
322+ int32_t accepted_len = 0 ;
323+ int32_t last_token = -1 ;
324+ for (int64_t col = 0 ; col < width; ++col) {
325+ const int32_t token = data[row * width + col];
326+ if (token < 0 ) {
327+ break ;
328+ }
329+ last_token = token;
330+ ++accepted_len;
331+ }
332+ oss << " {row=" << row << " , accepted_len=" << accepted_len
333+ << " , last_token=" << last_token << " , tokens=[" ;
334+ const int64_t print_width =
335+ std::min<int64_t >(width, static_cast <int64_t >(max_width));
336+ for (int64_t col = 0 ; col < print_width; ++col) {
337+ if (col > 0 ) {
338+ oss << " , " ;
339+ }
340+ oss << data[row * width + col];
341+ }
342+ if (width > print_width) {
343+ oss << " , ..." ;
344+ }
345+ oss << " ]}" ;
346+ }
347+ if (rows > print_rows) {
348+ oss << " , ..." ;
349+ }
350+ oss << " ]" ;
351+ return oss.str ();
352+ }
353+
196354bool has_mtp_prefill_placeholder_extra_token (
197355 const std::vector<int32_t >& extra_token_ids,
198356 int32_t placeholder) {
@@ -225,6 +383,18 @@ void check_mtp_decode_states(
225383 CHECK (state.embedding .defined ())
226384 << " MTP decode target state embedding is undefined, request_id="
227385 << request_ids[i];
386+ if (state.token_id != token_id) {
387+ LOG (ERROR ) << " [MTP_DECODE_STATE_MISMATCH] seq_id=" << i
388+ << " , request_id=" << request_ids[i]
389+ << " , input_token_id=" << token_id
390+ << " , cached_token_id=" << state.token_id
391+ << " , cached_position_offset=" << state.position_offset
392+ << " , cached_prev_token_id=" << state.prev_token_id
393+ << " , all_draft_accepted=" << state.all_draft_accepted
394+ << " , token_ids_host="
395+ << summarize_int_tensor (token_ids_host)
396+ << " , decode_states=" << summarize_decode_states (states);
397+ }
228398 if (token_id < 0 ) {
229399 CHECK (allow_overlap_fake_token)
230400 << " MTP decode fake token is only allowed with schedule overlap, "
@@ -915,6 +1085,23 @@ std::optional<ForwardOutput> MTPWorkerImpl::step_decode(
9151085 CHECK_EQ (last_states.size (),
9161086 input.input_params .embedding .embedding_ids .size ())
9171087 << " decode target state count mismatch" ;
1088+ VLOG (1 ) << " [MTP_DECODE_INPUT] num_sequences="
1089+ << input.input_params .meta .num_sequences
1090+ << " , token_ids_host=" << summarize_int_tensor (input.token_ids_host )
1091+ << " , q_seq_lens="
1092+ << summarize_int32_vector (
1093+ input.input_params .attention .host .q_seq_lens )
1094+ << " , kv_seq_lens="
1095+ << summarize_int32_vector (
1096+ input.input_params .attention .host .kv_seq_lens )
1097+ << " , embedding_ids="
1098+ << summarize_int32_vector (
1099+ input.input_params .embedding .embedding_ids )
1100+ << " , request_ids="
1101+ << summarize_string_vector (
1102+ input.input_params .embedding .request_ids );
1103+ VLOG (1 ) << " [MTP_DECODE_CACHE] states="
1104+ << summarize_decode_states (last_states);
9181105 check_mtp_decode_states (last_states,
9191106 input.input_params .embedding .request_ids ,
9201107 input.token_ids_host ,
@@ -1047,6 +1234,23 @@ void MTPWorkerImpl::write_target_context_to_cache(
10471234 << " embedding_cache_ must be initialized before target cache write" ;
10481235 CHECK (!input.input_params .embedding .embedding_ids .empty ())
10491236 << " target context cache write requires embedding ids" ;
1237+ const bool has_multi_token_accept =
1238+ has_accepted_prefix_longer_than_one (validate_output.next_tokens );
1239+ if (has_multi_token_accept) {
1240+ LOG (INFO ) << " [MTP_ACCEPTED_PREFIX] multi-token accepted prefix detected, "
1241+ << " num_speculative_tokens=" << options_.num_speculative_tokens ()
1242+ << " , embedding_ids="
1243+ << summarize_int32_vector (
1244+ input.input_params .embedding .embedding_ids )
1245+ << " , request_ids="
1246+ << summarize_string_vector (
1247+ input.input_params .embedding .request_ids )
1248+ << " , accepted_tokens="
1249+ << summarize_accepted_tokens (validate_output.next_tokens );
1250+ } else {
1251+ VLOG (1 ) << " [MTP_ACCEPTED_PREFIX] accepted_tokens="
1252+ << summarize_accepted_tokens (validate_output.next_tokens );
1253+ }
10501254 embedding_cache_->write_target_context (
10511255 input.input_params .embedding .embedding_ids ,
10521256 input.input_params .embedding .request_ids ,
@@ -1217,6 +1421,15 @@ void MTPWorkerImpl::prepare_validate_inputs(const ForwardInput& input,
12171421 const bool use_atb_spec_kernel =
12181422 ::xllm::SpeculativeConfig::get_instance ().enable_atb_spec_kernel() ||
12191423 use_qwen3_5_spec_verify_path();
1424+ LOG_FIRST_N (INFO , 4 )
1425+ << " [MTP_VALIDATE_LAYOUT] use_atb_spec_kernel=" << use_atb_spec_kernel
1426+ << " , original_num_sequences=" << num_sequences
1427+ << " , num_speculative_tokens=" << num_speculative_tokens
1428+ << " , num_val_tokens=" << num_val_tokens
1429+ << " , total_num_val_tokens=" << total_num_val_tokens
1430+ << " , layout="
1431+ << (use_atb_spec_kernel ? " chunked_prefill_rows"
1432+ : " expanded_decode_rows" );
12201433 specBuilder::DecodeBuildBuffers buf;
12211434 buf.out_token_ids .reserve (total_num_val_tokens);
12221435 buf.out_positions .reserve (input.positions_host .dim () == 2
0 commit comments