-
Notifications
You must be signed in to change notification settings - Fork 509
perf(join): vectorize high-fanout inner join build takes #7143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,40 +7,45 @@ use daft_recordbatch::{GrowableRecordBatch, ProbeState}; | |
|
|
||
| use crate::join::hash_join::HashJoinParams; | ||
|
|
||
// Passed as the third argument to `GrowableRecordBatch::new` when a growable is built.
// Presumably an initial capacity hint — TODO confirm against `GrowableRecordBatch`.
| const DEFAULT_GROWABLE_SIZE: usize = 20; | ||
// `should_use_vectorized_take` rejects the vectorized path when there are fewer matches than this.
| const MIN_MATCHES_FOR_VECTORIZED_TAKE: usize = 1024; | ||
// Fanout gate: the vectorized path requires matches >= probe_rows * this value.
| const MIN_FANOUT_FOR_VECTORIZED_TAKE: usize = 4; | ||
// Run-length gate: the vectorized path requires matches >= run_count * this value.
| const MIN_AVG_RUN_LEN_FOR_VECTORIZED_TAKE: usize = 8; | ||
|
|
||
// One join match, in the field order `probe_inner` pushes it:
// (probe row index, build record-batch index, row index inside that build batch).
| type BuildMatch = (u64, usize, u64); | ||
|
|
||
| pub(crate) fn probe_inner( | ||
| input: &MicroPartition, | ||
| probe_state: &ProbeState, | ||
| params: &HashJoinParams, | ||
| ) -> DaftResult<MicroPartition> { | ||
| let build_side_tables = probe_state.get_record_batches().iter().collect::<Vec<_>>(); | ||
| const DEFAULT_GROWABLE_SIZE: usize = 20; | ||
|
|
||
| let input_tables = input.record_batches(); | ||
| let result_tables = input_tables | ||
| .iter() | ||
| .map(|input_table| { | ||
| let mut build_side_growable = | ||
| GrowableRecordBatch::new(&build_side_tables, false, DEFAULT_GROWABLE_SIZE)?; | ||
| let mut probe_side_idxs = Vec::new(); | ||
|
|
||
| let join_keys = input_table.eval_expression_list(&params.probe_on)?; | ||
| let idx_iter = probe_state.probe_indices(join_keys)?; | ||
| let mut matches = Vec::new(); | ||
| for (probe_row_idx, inner_iter) in idx_iter.enumerate() { | ||
| if let Some(inner_iter) = inner_iter { | ||
| for (build_rb_idx, build_row_idx) in inner_iter { | ||
| build_side_growable.extend( | ||
| build_rb_idx as usize, | ||
| build_row_idx as usize, | ||
| 1, | ||
| ); | ||
| probe_side_idxs.push(probe_row_idx as u64); | ||
| matches.push((probe_row_idx as u64, build_rb_idx as usize, build_row_idx)); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| let build_side_table = build_side_growable.build()?; | ||
| let build_side_table = | ||
| build_side_for_inner_matches(&build_side_tables, &matches, input_table.len())?; | ||
| let probe_side_table = { | ||
| let indices_arr = UInt64Array::from_vec("", probe_side_idxs); | ||
| let indices_arr = UInt64Array::from_vec( | ||
| "", | ||
| matches | ||
| .iter() | ||
| .map(|(probe_row_idx, _, _)| *probe_row_idx) | ||
| .collect::<Vec<_>>(), | ||
| ); | ||
| input_table.take(&indices_arr)? | ||
| }; | ||
|
|
||
|
|
@@ -83,3 +88,103 @@ pub(crate) fn probe_inner( | |
| None, | ||
| )) | ||
| } | ||
|
|
||
/// Materializes the build-side rows for a list of inner-join matches.
///
/// Dispatches between two strategies that consume the same `matches` slice:
/// `build_side_with_vectorized_take` when `should_use_vectorized_take` approves,
/// otherwise `build_side_with_growable`.
///
/// * `build_side_tables` - build-side record batches, indexed by the second field of each match.
/// * `matches` - `(probe_row_idx, build_rb_idx, build_row_idx)` tuples.
/// * `probe_rows` - forwarded unchanged to the heuristic; the caller in `probe_inner`
///   passes `input_table.len()`.
///
/// Returns whatever the chosen strategy returns, including its error.
| fn build_side_for_inner_matches( | ||
| build_side_tables: &[&daft_recordbatch::RecordBatch], | ||
| matches: &[BuildMatch], | ||
| probe_rows: usize, | ||
| ) -> DaftResult<daft_recordbatch::RecordBatch> { | ||
| if should_use_vectorized_take(matches, probe_rows) { | ||
| build_side_with_vectorized_take(build_side_tables, matches) | ||
| } else { | ||
| build_side_with_growable(build_side_tables, matches) | ||
| } | ||
| } | ||
|
|
||
/// Heuristic deciding whether the vectorized-take path should be used.
///
/// Returns `true` only if all three gates pass:
/// 1. there are at least `MIN_MATCHES_FOR_VECTORIZED_TAKE` matches and `probe_rows` is non-zero;
/// 2. matches >= `probe_rows * MIN_FANOUT_FOR_VECTORIZED_TAKE`;
/// 3. matches >= `run_count * MIN_AVG_RUN_LEN_FOR_VECTORIZED_TAKE`, where `run_count`
///    comes from `count_build_table_runs`.
///
/// NOTE(review): the visible caller passes `input_table.len()` as `probe_rows`, so probe
/// rows with zero matches count towards gate 2 — confirm that this is the intended
/// definition of fanout.
| fn should_use_vectorized_take(matches: &[BuildMatch], probe_rows: usize) -> bool { | ||
// Gate 1: too few matches, or no probe rows at all.
| if matches.len() < MIN_MATCHES_FOR_VECTORIZED_TAKE || probe_rows == 0 { | ||
| return false; | ||
| } | ||
|
|
||
// Gate 2: `saturating_mul` clamps at `usize::MAX` instead of overflowing.
| if matches.len() < probe_rows.saturating_mul(MIN_FANOUT_FOR_VECTORIZED_TAKE) { | ||
| return false; | ||
| } | ||
|
|
||
// Gate 3: expressed as a multiplication, so no division by `run_count` is needed.
| let run_count = count_build_table_runs(matches); | ||
| matches.len() >= run_count.saturating_mul(MIN_AVG_RUN_LEN_FOR_VECTORIZED_TAKE) | ||
| } | ||
|
|
||
/// Counts maximal runs of consecutive matches that share the same build table index
/// (the second tuple field).
///
/// Returns 0 for an empty slice; otherwise 1 plus the number of positions where the
/// build table index differs from the previous match.
| fn count_build_table_runs(matches: &[BuildMatch]) -> usize { | ||
// Guard: `matches[0]` below would panic on an empty slice.
| if matches.is_empty() { | ||
| return 0; | ||
| } | ||
|
|
||
// The first match always opens the first run.
| let mut runs = 1; | ||
| let mut current_build_table = matches[0].1; | ||
| for (_, build_table, _) in matches.iter().skip(1) { | ||
| if *build_table != current_build_table { | ||
| runs += 1; | ||
| current_build_table = *build_table; | ||
| } | ||
| } | ||
| runs | ||
| } | ||
|
|
||
/// Builds the build-side output by appending one row per match to a `GrowableRecordBatch`.
///
/// For every match, calls `extend(build_rb_idx, build_row_idx, 1)` in match order, then
/// returns the result of `build()`. Errors from `GrowableRecordBatch::new` are propagated
/// with `?`.
| fn build_side_with_growable( | ||
| build_side_tables: &[&daft_recordbatch::RecordBatch], | ||
| matches: &[BuildMatch], | ||
| ) -> DaftResult<daft_recordbatch::RecordBatch> { | ||
| let mut build_side_growable = | ||
| GrowableRecordBatch::new(build_side_tables, false, DEFAULT_GROWABLE_SIZE)?; | ||
|
|
||
// The probe row index (first field) is not needed on the build side.
| for (_, build_rb_idx, build_row_idx) in matches { | ||
| build_side_growable.extend(*build_rb_idx, *build_row_idx as usize, 1); | ||
| } | ||
|
|
||
| build_side_growable.build() | ||
| } | ||
|
|
||
/// Builds the build-side output with one `take` per run of matches from the same build batch.
///
/// Walks `matches` in order and accumulates build row indices while the build batch index
/// stays the same. When it changes, the accumulated run is flushed through `push_taken_run`;
/// the final run is flushed after the loop. The per-run batches are then combined with
/// `RecordBatch::concat`.
///
/// An empty `matches` slice is delegated to `build_side_with_growable`.
/// NOTE(review): output row order presumably equals match order, assuming `take` and
/// `concat` both preserve order — confirm.
| fn build_side_with_vectorized_take( | ||
| build_side_tables: &[&daft_recordbatch::RecordBatch], | ||
| matches: &[BuildMatch], | ||
| ) -> DaftResult<daft_recordbatch::RecordBatch> { | ||
// Guard: `matches[0]` below would panic on an empty slice.
| if matches.is_empty() { | ||
| return build_side_with_growable(build_side_tables, matches); | ||
| } | ||
|
|
||
// One taken batch is pushed per run, so the run count is the exact capacity needed.
| let mut taken_tables = Vec::with_capacity(count_build_table_runs(matches)); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time! |
||
| let mut current_build_table = matches[0].1; | ||
| let mut run_row_idxs = Vec::new(); | ||
|
|
||
| for (_, build_rb_idx, build_row_idx) in matches { | ||
| if *build_rb_idx != current_build_table { | ||
| push_taken_run( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wont this just lead to a bunch of small |
||
| build_side_tables, | ||
| &mut taken_tables, | ||
| current_build_table, | ||
| &mut run_row_idxs, | ||
| )?; | ||
| current_build_table = *build_rb_idx; | ||
| } | ||
| run_row_idxs.push(*build_row_idx); | ||
| } | ||
// Flush the last run; it is never empty because `matches` is non-empty here.
| push_taken_run( | ||
| build_side_tables, | ||
| &mut taken_tables, | ||
| current_build_table, | ||
| &mut run_row_idxs, | ||
| )?; | ||
|
|
||
| daft_recordbatch::RecordBatch::concat(taken_tables) | ||
| } | ||
|
|
||
/// Takes the accumulated row indices from one build batch and appends the result to
/// `taken_tables`.
///
/// * `build_table_idx` - index into `build_side_tables`; an out-of-range index panics
///   on the slice access.
/// * `run_row_idxs` - row indices for this run; left empty on return.
///
/// Returns `Ok(())`, or the error from `take`.
| fn push_taken_run( | ||
| build_side_tables: &[&daft_recordbatch::RecordBatch], | ||
| taken_tables: &mut Vec<daft_recordbatch::RecordBatch>, | ||
| build_table_idx: usize, | ||
| run_row_idxs: &mut Vec<u64>, | ||
| ) -> DaftResult<()> { | ||
// `mem::take` moves the indices out and leaves an empty Vec behind for the next run.
| let indices_arr = UInt64Array::from_vec("", std::mem::take(run_row_idxs)); | ||
| taken_tables.push(build_side_tables[build_table_idx].take(&indices_arr)?); | ||
| Ok(()) | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`probe_rows` is `input_table.len()`, the total number of probe-side rows in the current micro-batch, including rows that produced zero matches. In sparse-match scenarios (e.g. 1 000 probe rows, 10 of which each match 200 build rows), `matches.len()` is 2 000 but `probe_rows * MIN_FANOUT_FOR_VECTORIZED_TAKE` is 4 000, so the fanout guard short-circuits and the growable path is always taken — even though average run length easily clears `MIN_AVG_RUN_LEN_FOR_VECTORIZED_TAKE`. This won't produce wrong output, but the optimization won't fire in some high-fanout, low-selectivity workloads where it would help most.