Skip to content

Commit 856aa2e

Browse files
temp_fix
1 parent 31eb406 commit 856aa2e

1 file changed

Lines changed: 213 additions & 0 deletions

File tree

xllm/core/runtime/mtp_worker_impl.cpp

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
#include <algorithm>
2121
#include <cctype>
2222
#include <memory>
23+
#include <sstream>
2324

2425
#include "common/global_flags.h"
2526
#include "common/metrics.h"
@@ -193,6 +194,163 @@ torch::Tensor to_cpu_int_tensor_for_read(const torch::Tensor& values) {
193194
.contiguous();
194195
}
195196

197+
// Renders the leading entries of an int32 buffer as "[a, b, c]" for logging.
//
// values: pointer to at least `size` readable int32 elements (it is only
//         dereferenced for indices below min(size, limit)).
// size:   number of elements available behind `values`.
// limit:  maximum number of elements to print.
//
// Returns the bracketed, comma-separated text. When `size` exceeds `limit`
// the list is terminated with "..." to signal truncation.
std::string summarize_int32_values(const int32_t* values,
                                   size_t size,
                                   size_t limit) {
  const size_t print_size = std::min(size, limit);
  std::ostringstream oss;
  oss << "[";
  for (size_t i = 0; i < print_size; ++i) {
    if (i > 0) {
      oss << ", ";
    }
    oss << values[i];
  }
  if (size > print_size) {
    // Emit the separator only when an element precedes the ellipsis, so a
    // zero limit yields "[...]" instead of the malformed "[, ...]".
    oss << (print_size > 0 ? ", ..." : "...");
  }
  oss << "]";
  return oss.str();
}
215+
216+
std::string summarize_int32_vector(const std::vector<int32_t>& values,
217+
size_t limit = 32) {
218+
return summarize_int32_values(values.data(), values.size(), limit);
219+
}
220+
221+
// Renders the leading entries of a string list as "[a, b, c]" for logging.
//
// values: the strings to render; they are streamed verbatim (no quoting or
//         escaping is applied).
// limit:  maximum number of entries to print (defaults to 8).
//
// Returns the bracketed, comma-separated text. When the list holds more than
// `limit` entries it is terminated with "..." to signal truncation.
std::string summarize_string_vector(const std::vector<std::string>& values,
                                    size_t limit = 8) {
  const size_t print_size = std::min(values.size(), limit);
  std::ostringstream oss;
  oss << "[";
  for (size_t i = 0; i < print_size; ++i) {
    if (i > 0) {
      oss << ", ";
    }
    oss << values[i];
  }
  if (values.size() > print_size) {
    // Emit the separator only when an element precedes the ellipsis, so a
    // zero limit yields "[...]" instead of the malformed "[, ...]".
    oss << (print_size > 0 ? ", ..." : "...");
  }
  oss << "]";
  return oss.str();
}
238+
239+
// Describes an integer tensor as "sizes=<shape>, values=[...]" for logging.
//
// tensor: the tensor to describe; an undefined tensor yields "undefined".
// limit:  maximum number of elements to print (defaults to 64).
//
// The values are read through to_cpu_int_tensor_for_read, and the reported
// shape is that of the original `tensor`, not of the converted copy.
// NOTE(review): the int32 read assumes the helper returns a contiguous,
// host-readable int32 tensor - confirm against its definition.
std::string summarize_int_tensor(const torch::Tensor& tensor,
                                 size_t limit = 64) {
  if (!tensor.defined()) {
    return "undefined";
  }

  const torch::Tensor host_values = to_cpu_int_tensor_for_read(tensor);
  const size_t value_count = static_cast<size_t>(host_values.numel());
  const std::string values_text = summarize_int32_values(
      host_values.const_data_ptr<int32_t>(), value_count, limit);

  std::ostringstream out;
  out << "sizes=" << tensor.sizes() << ", values=" << values_text;
  return out.str();
}
252+
253+
std::string summarize_decode_states(
254+
const std::vector<EmbeddingCache::DecodeState>& states,
255+
size_t limit = 16) {
256+
std::ostringstream oss;
257+
oss << "[";
258+
const size_t print_size = std::min(states.size(), limit);
259+
for (size_t i = 0; i < print_size; ++i) {
260+
if (i > 0) {
261+
oss << ", ";
262+
}
263+
const auto& state = states[i];
264+
oss << "{idx=" << i << ", valid=" << state.valid
265+
<< ", token_id=" << state.token_id
266+
<< ", position_offset=" << state.position_offset
267+
<< ", prev_token_id=" << state.prev_token_id
268+
<< ", all_draft_accepted=" << state.all_draft_accepted << "}";
269+
}
270+
if (states.size() > limit) {
271+
oss << ", ...";
272+
}
273+
oss << "]";
274+
return oss.str();
275+
}
276+
277+
// Reports whether any row of a 2-D accepted-token tensor starts with at least
// two non-negative entries.
//
// accepted_tokens: tensor indexed as [row, column]; a negative entry ends the
//                  accepted prefix of its row.
//
// Returns false for an undefined tensor, a tensor that is not 2-D, or a
// tensor with no rows or fewer than two columns; otherwise true as soon as
// one row has non-negative values in both column 0 and column 1.
bool has_accepted_prefix_longer_than_one(const torch::Tensor& accepted_tokens) {
  if (!accepted_tokens.defined() || accepted_tokens.dim() != 2) {
    return false;
  }
  const int64_t rows = accepted_tokens.size(0);
  const int64_t width = accepted_tokens.size(1);
  // No row can hold a prefix longer than one in these shapes, so skip the
  // host-side conversion entirely.
  if (rows == 0 || width < 2) {
    return false;
  }
  const torch::Tensor flat = to_cpu_int_tensor_for_read(accepted_tokens);
  const int32_t* data = flat.const_data_ptr<int32_t>();
  for (int64_t row = 0; row < rows; ++row) {
    // A prefix longer than one exists exactly when the first two entries are
    // both non-negative; later columns cannot change the answer.
    const int32_t* row_data = data + row * width;
    if (row_data[0] >= 0 && row_data[1] >= 0) {
      return true;
    }
  }
  return false;
}
299+
300+
// Describes a 2-D accepted-token tensor row by row for logging.
//
// accepted_tokens: tensor indexed as [row, column]; a negative entry ends the
//                  accepted prefix of its row.
// max_rows:        maximum number of rows to print (defaults to 8).
// max_width:       maximum number of columns to print per row (defaults to 8).
//
// Returns "undefined" for an undefined tensor and falls back to
// summarize_int_tensor for tensors that are not 2-D. Otherwise returns
//   sizes=<shape>, rows=[{row=R, accepted_len=N, last_token=T, tokens=[...]}, ...]
// where accepted_len counts the leading non-negative entries of the whole row
// (not only the printed columns) and last_token is the final entry of that
// prefix, or -1 when the prefix is empty. Truncated rows or columns are
// marked with "...".
std::string summarize_accepted_tokens(const torch::Tensor& accepted_tokens,
                                      size_t max_rows = 8,
                                      size_t max_width = 8) {
  if (!accepted_tokens.defined()) {
    return "undefined";
  }
  if (accepted_tokens.dim() != 2) {
    return summarize_int_tensor(accepted_tokens);
  }
  const torch::Tensor flat = to_cpu_int_tensor_for_read(accepted_tokens);
  const int32_t* data = flat.const_data_ptr<int32_t>();
  const int64_t rows = accepted_tokens.size(0);
  const int64_t width = accepted_tokens.size(1);

  // Clamp a non-negative extent to an unsigned limit. The comparison is done
  // in the unsigned domain so a limit above INT64_MAX (e.g. SIZE_MAX) cannot
  // wrap to a negative count when narrowed to int64_t.
  const auto clamp_to_limit = [](int64_t extent, size_t limit) -> int64_t {
    if (static_cast<uint64_t>(extent) <= static_cast<uint64_t>(limit)) {
      return extent;
    }
    return static_cast<int64_t>(limit);
  };
  const int64_t print_rows = clamp_to_limit(rows, max_rows);
  const int64_t print_width = clamp_to_limit(width, max_width);

  std::ostringstream oss;
  oss << "sizes=" << accepted_tokens.sizes() << ", rows=[";
  for (int64_t row = 0; row < print_rows; ++row) {
    if (row > 0) {
      oss << ", ";
    }
    const int32_t* row_data = data + row * width;

    // Measure the accepted prefix over the full row, independent of how many
    // columns are printed below.
    int32_t accepted_len = 0;
    int32_t last_token = -1;
    for (int64_t col = 0; col < width; ++col) {
      const int32_t token = row_data[col];
      if (token < 0) {
        break;
      }
      last_token = token;
      ++accepted_len;
    }

    oss << "{row=" << row << ", accepted_len=" << accepted_len
        << ", last_token=" << last_token << ", tokens=[";
    for (int64_t col = 0; col < print_width; ++col) {
      if (col > 0) {
        oss << ", ";
      }
      oss << row_data[col];
    }
    if (width > print_width) {
      // Separator only when a token precedes the ellipsis, so max_width == 0
      // yields "tokens=[...]" instead of "tokens=[, ...]".
      oss << (print_width > 0 ? ", ..." : "...");
    }
    oss << "]}";
  }
  if (rows > print_rows) {
    // Same rule for the row list when max_rows == 0.
    oss << (print_rows > 0 ? ", ..." : "...");
  }
  oss << "]";
  return oss.str();
}
353+
196354
bool has_mtp_prefill_placeholder_extra_token(
197355
const std::vector<int32_t>& extra_token_ids,
198356
int32_t placeholder) {
@@ -225,6 +383,18 @@ void check_mtp_decode_states(
225383
CHECK(state.embedding.defined())
226384
<< "MTP decode target state embedding is undefined, request_id="
227385
<< request_ids[i];
386+
if (state.token_id != token_id) {
387+
LOG(ERROR) << "[MTP_DECODE_STATE_MISMATCH] seq_id=" << i
388+
<< ", request_id=" << request_ids[i]
389+
<< ", input_token_id=" << token_id
390+
<< ", cached_token_id=" << state.token_id
391+
<< ", cached_position_offset=" << state.position_offset
392+
<< ", cached_prev_token_id=" << state.prev_token_id
393+
<< ", all_draft_accepted=" << state.all_draft_accepted
394+
<< ", token_ids_host="
395+
<< summarize_int_tensor(token_ids_host)
396+
<< ", decode_states=" << summarize_decode_states(states);
397+
}
228398
if (token_id < 0) {
229399
CHECK(allow_overlap_fake_token)
230400
<< "MTP decode fake token is only allowed with schedule overlap, "
@@ -915,6 +1085,23 @@ std::optional<ForwardOutput> MTPWorkerImpl::step_decode(
9151085
CHECK_EQ(last_states.size(),
9161086
input.input_params.embedding.embedding_ids.size())
9171087
<< "decode target state count mismatch";
1088+
VLOG(1) << "[MTP_DECODE_INPUT] num_sequences="
1089+
<< input.input_params.meta.num_sequences
1090+
<< ", token_ids_host=" << summarize_int_tensor(input.token_ids_host)
1091+
<< ", q_seq_lens="
1092+
<< summarize_int32_vector(
1093+
input.input_params.attention.host.q_seq_lens)
1094+
<< ", kv_seq_lens="
1095+
<< summarize_int32_vector(
1096+
input.input_params.attention.host.kv_seq_lens)
1097+
<< ", embedding_ids="
1098+
<< summarize_int32_vector(
1099+
input.input_params.embedding.embedding_ids)
1100+
<< ", request_ids="
1101+
<< summarize_string_vector(
1102+
input.input_params.embedding.request_ids);
1103+
VLOG(1) << "[MTP_DECODE_CACHE] states="
1104+
<< summarize_decode_states(last_states);
9181105
check_mtp_decode_states(last_states,
9191106
input.input_params.embedding.request_ids,
9201107
input.token_ids_host,
@@ -1047,6 +1234,23 @@ void MTPWorkerImpl::write_target_context_to_cache(
10471234
<< "embedding_cache_ must be initialized before target cache write";
10481235
CHECK(!input.input_params.embedding.embedding_ids.empty())
10491236
<< "target context cache write requires embedding ids";
1237+
const bool has_multi_token_accept =
1238+
has_accepted_prefix_longer_than_one(validate_output.next_tokens);
1239+
if (has_multi_token_accept) {
1240+
LOG(INFO) << "[MTP_ACCEPTED_PREFIX] multi-token accepted prefix detected, "
1241+
<< "num_speculative_tokens=" << options_.num_speculative_tokens()
1242+
<< ", embedding_ids="
1243+
<< summarize_int32_vector(
1244+
input.input_params.embedding.embedding_ids)
1245+
<< ", request_ids="
1246+
<< summarize_string_vector(
1247+
input.input_params.embedding.request_ids)
1248+
<< ", accepted_tokens="
1249+
<< summarize_accepted_tokens(validate_output.next_tokens);
1250+
} else {
1251+
VLOG(1) << "[MTP_ACCEPTED_PREFIX] accepted_tokens="
1252+
<< summarize_accepted_tokens(validate_output.next_tokens);
1253+
}
10501254
embedding_cache_->write_target_context(
10511255
input.input_params.embedding.embedding_ids,
10521256
input.input_params.embedding.request_ids,
@@ -1217,6 +1421,15 @@ void MTPWorkerImpl::prepare_validate_inputs(const ForwardInput& input,
12171421
const bool use_atb_spec_kernel =
12181422
::xllm::SpeculativeConfig::get_instance().enable_atb_spec_kernel() ||
12191423
use_qwen3_5_spec_verify_path();
1424+
LOG_FIRST_N(INFO, 4)
1425+
<< "[MTP_VALIDATE_LAYOUT] use_atb_spec_kernel=" << use_atb_spec_kernel
1426+
<< ", original_num_sequences=" << num_sequences
1427+
<< ", num_speculative_tokens=" << num_speculative_tokens
1428+
<< ", num_val_tokens=" << num_val_tokens
1429+
<< ", total_num_val_tokens=" << total_num_val_tokens
1430+
<< ", layout="
1431+
<< (use_atb_spec_kernel ? "chunked_prefill_rows"
1432+
: "expanded_decode_rows");
12201433
specBuilder::DecodeBuildBuffers buf;
12211434
buf.out_token_ids.reserve(total_num_val_tokens);
12221435
buf.out_positions.reserve(input.positions_host.dim() == 2

0 commit comments

Comments
 (0)