Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 118 additions & 13 deletions src/daft-local-execution/src/join/inner_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,45 @@ use daft_recordbatch::{GrowableRecordBatch, ProbeState};

use crate::join::hash_join::HashJoinParams;

const DEFAULT_GROWABLE_SIZE: usize = 20;
const MIN_MATCHES_FOR_VECTORIZED_TAKE: usize = 1024;
const MIN_FANOUT_FOR_VECTORIZED_TAKE: usize = 4;
const MIN_AVG_RUN_LEN_FOR_VECTORIZED_TAKE: usize = 8;

type BuildMatch = (u64, usize, u64);

pub(crate) fn probe_inner(
input: &MicroPartition,
probe_state: &ProbeState,
params: &HashJoinParams,
) -> DaftResult<MicroPartition> {
let build_side_tables = probe_state.get_record_batches().iter().collect::<Vec<_>>();
const DEFAULT_GROWABLE_SIZE: usize = 20;

let input_tables = input.record_batches();
let result_tables = input_tables
.iter()
.map(|input_table| {
let mut build_side_growable =
GrowableRecordBatch::new(&build_side_tables, false, DEFAULT_GROWABLE_SIZE)?;
let mut probe_side_idxs = Vec::new();

let join_keys = input_table.eval_expression_list(&params.probe_on)?;
let idx_iter = probe_state.probe_indices(join_keys)?;
let mut matches = Vec::new();
for (probe_row_idx, inner_iter) in idx_iter.enumerate() {
if let Some(inner_iter) = inner_iter {
for (build_rb_idx, build_row_idx) in inner_iter {
build_side_growable.extend(
build_rb_idx as usize,
build_row_idx as usize,
1,
);
probe_side_idxs.push(probe_row_idx as u64);
matches.push((probe_row_idx as u64, build_rb_idx as usize, build_row_idx));
}
}
}

let build_side_table = build_side_growable.build()?;
let build_side_table =
build_side_for_inner_matches(&build_side_tables, &matches, input_table.len())?;
let probe_side_table = {
let indices_arr = UInt64Array::from_vec("", probe_side_idxs);
let indices_arr = UInt64Array::from_vec(
"",
matches
.iter()
.map(|(probe_row_idx, _, _)| *probe_row_idx)
.collect::<Vec<_>>(),
);
input_table.take(&indices_arr)?
};

Expand Down Expand Up @@ -83,3 +88,103 @@ pub(crate) fn probe_inner(
None,
))
}

fn build_side_for_inner_matches(
build_side_tables: &[&daft_recordbatch::RecordBatch],
matches: &[BuildMatch],
probe_rows: usize,
) -> DaftResult<daft_recordbatch::RecordBatch> {
if should_use_vectorized_take(matches, probe_rows) {
build_side_with_vectorized_take(build_side_tables, matches)
} else {
build_side_with_growable(build_side_tables, matches)
}
}

fn should_use_vectorized_take(matches: &[BuildMatch], probe_rows: usize) -> bool {
if matches.len() < MIN_MATCHES_FOR_VECTORIZED_TAKE || probe_rows == 0 {
return false;
}

if matches.len() < probe_rows.saturating_mul(MIN_FANOUT_FOR_VECTORIZED_TAKE) {
return false;
}

let run_count = count_build_table_runs(matches);
matches.len() >= run_count.saturating_mul(MIN_AVG_RUN_LEN_FOR_VECTORIZED_TAKE)
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Fanout threshold uses total probe rows, not matched probe rows

probe_rows is input_table.len(), the total number of probe-side rows in the current micro-batch, including rows that produced zero matches. In sparse-match scenarios (e.g. 1 000 probe rows, 10 of which each match 200 build rows), matches.len() is 2 000 but probe_rows * MIN_FANOUT_FOR_VECTORIZED_TAKE is 4 000, so the fanout guard short-circuits and the growable path is always taken — even though average run length easily clears MIN_AVG_RUN_LEN_FOR_VECTORIZED_TAKE. This won't produce wrong output, but the optimization won't fire in some high-fanout, low-selectivity workloads where it would help most.


fn count_build_table_runs(matches: &[BuildMatch]) -> usize {
if matches.is_empty() {
return 0;
}

let mut runs = 1;
let mut current_build_table = matches[0].1;
for (_, build_table, _) in matches.iter().skip(1) {
if *build_table != current_build_table {
runs += 1;
current_build_table = *build_table;
}
}
runs
}

fn build_side_with_growable(
build_side_tables: &[&daft_recordbatch::RecordBatch],
matches: &[BuildMatch],
) -> DaftResult<daft_recordbatch::RecordBatch> {
let mut build_side_growable =
GrowableRecordBatch::new(build_side_tables, false, DEFAULT_GROWABLE_SIZE)?;

for (_, build_rb_idx, build_row_idx) in matches {
build_side_growable.extend(*build_rb_idx, *build_row_idx as usize, 1);
}

build_side_growable.build()
}

fn build_side_with_vectorized_take(
build_side_tables: &[&daft_recordbatch::RecordBatch],
matches: &[BuildMatch],
) -> DaftResult<daft_recordbatch::RecordBatch> {
if matches.is_empty() {
return build_side_with_growable(build_side_tables, matches);
}

let mut taken_tables = Vec::with_capacity(count_build_table_runs(matches));

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Redundant O(n) traversal of matches

count_build_table_runs is called here to size taken_tables, but the same traversal was already performed inside should_use_vectorized_take when deciding to take this path. For large match sets (≥1024 entries, which is when vectorized take is even attempted), this doubles the scan cost just to compute capacity. Passing the pre-computed run count through from the dispatch site would eliminate the redundancy without changing behaviour.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

let mut current_build_table = matches[0].1;
let mut run_row_idxs = Vec::new();

for (_, build_rb_idx, build_row_idx) in matches {
if *build_rb_idx != current_build_table {
push_taken_run(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wont this just lead to a bunch of small .take operations and allocations for run_row_idxs because we need indexes for the same build chunk to be contiguous in the matches list?

build_side_tables,
&mut taken_tables,
current_build_table,
&mut run_row_idxs,
)?;
current_build_table = *build_rb_idx;
}
run_row_idxs.push(*build_row_idx);
}
push_taken_run(
build_side_tables,
&mut taken_tables,
current_build_table,
&mut run_row_idxs,
)?;

daft_recordbatch::RecordBatch::concat(taken_tables)
}

fn push_taken_run(
build_side_tables: &[&daft_recordbatch::RecordBatch],
taken_tables: &mut Vec<daft_recordbatch::RecordBatch>,
build_table_idx: usize,
run_row_idxs: &mut Vec<u64>,
) -> DaftResult<()> {
let indices_arr = UInt64Array::from_vec("", std::mem::take(run_row_idxs));
taken_tables.push(build_side_tables[build_table_idx].take(&indices_arr)?);
Ok(())
}
39 changes: 39 additions & 0 deletions tests/dataframe/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,45 @@ def test_joins_all_same_key(join_strategy, join_type, make_df, n_partitions: int
}


@pytest.mark.parametrize("n_partitions", [1, 4])
def test_inner_join_high_build_side_fanout(make_df, n_partitions: int, with_default_morsel_size):
num_keys = 64
fanout = 32

left = make_df(
{
"A": list(range(num_keys)),
"left_payload": [f"left-{i}" for i in range(num_keys)],
},
repartition=n_partitions,
repartition_columns=["A"],
)
right_keys = [key for key in range(num_keys) for _ in range(fanout)]
right = make_df(
{
"A": right_keys,
"right_payload": [
f"right-{key}-{idx:02d}" for key in range(num_keys) for idx in range(fanout)
],
},
repartition=n_partitions,
repartition_columns=["A"],
)

joined = left.join(right, on="A", strategy="hash", how="inner").sort(
["A", "right_payload"]
)
joined_data = joined.to_pydict()

assert joined_data["A"] == [key for key in range(num_keys) for _ in range(fanout)]
assert joined_data["left_payload"] == [
f"left-{key}" for key in range(num_keys) for _ in range(fanout)
]
assert joined_data["right_payload"] == [
f"right-{key}-{idx:02d}" for key in range(num_keys) for idx in range(fanout)
]


@pytest.mark.parametrize("n_partitions", get_n_partitions())
@pytest.mark.parametrize(
"join_strategy",
Expand Down
39 changes: 39 additions & 0 deletions tests/microbenchmarks/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,45 @@ def bench_join() -> DataFrame:
assert data[-big_factor:] == [str(small_length - 1)] * big_factor


@pytest.mark.benchmark(group="joins")
@pytest.mark.parametrize("num_partitions", [1, 10], ids=["1part", "10part"])
def test_inner_join_high_fanout(benchmark, num_partitions) -> None:
"""Test inner joins where each probe row matches many build-side rows."""
small_length = 1_000
fanout = 32

left_arr = np.arange(small_length)
np.random.shuffle(left_arr)
right_arr = np.repeat(np.arange(small_length), fanout)
np.random.shuffle(right_arr)

left_table = (
daft.from_pydict(
{
"keys": left_arr,
}
)
.into_partitions(num_partitions)
.collect()
)
right_table = daft.from_pydict(
{
"keys": right_arr,
"right_payload": [str(x) for x in right_arr],
}
).collect()

def bench_join() -> DataFrame:
return left_table.join(right_table, on=["keys"], how="inner").collect()

result = benchmark(bench_join)

assert len(result) == small_length * fanout
assert result.groupby("keys").agg(col("right_payload").count()).sort("keys").to_pydict()[
"right_payload"
] == [fanout] * small_length


@pytest.mark.benchmark(group="joins")
@pytest.mark.parametrize(
"num_samples, num_partitions",
Expand Down
Loading