feat: add Spark-compatible math functions (bround, greatest, least, hex, unhex) #7122
XuQianJin-Stars wants to merge 3 commits into
Greptile Summary
Adds five Spark-compatible scalar functions: bround, greatest, least, hex, and unhex.
Confidence Score: 3/5
The hex/unhex and bround implementations are functionally solid for most inputs, but greatest/least silently produce wrong results for float columns containing NaN, contradicting the Spark-compatible guarantee this PR advertises.
The greatest/least implementation uses raw IEEE 754 comparisons, which return false when either operand is NaN. Spark explicitly documents NaN as the maximum value in greatest/least, so any float column with NaN values will produce different answers here than in Spark. This is a quiet data-correctness divergence in a feature explicitly positioned as a Spark compatibility layer.
src/daft-functions/src/greatest.rs needs the most attention, specifically the compare_inputs function and its treatment of NaN in float comparisons.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
PY["Python: daft.functions\nbround / greatest / least / hex / unhex"]
REG["Global FunctionRegistry\n(src/lib.rs)"]
NUM["NumericFunctions module\n(numeric/mod.rs)"]
SQL["SQL engine\n(auto-resolves by name)"]
PY -->|"_call_builtin_scalar_fn(name, ...)"| REG
REG --> SQL
BRound["BRound ScalarUDF\nnumeric/bround.rs\nf64::round_ties_even"]
HEX["Hex / Unhex ScalarUDF\nnumeric/hex.rs\nbytes_to_hex_upper / decode_hex_padded"]
GL["Greatest / Least ScalarUDF\ngreatest.rs\ncompare_inputs(keep_greater)"]
NUM --> BRound
NUM --> HEX
REG --> GL
BRound -->|"series_bround()"| BOUT["Float output"]
HEX -->|"hex_impl / unhex_impl"| HOUT["Utf8 / Binary output"]
GL -->|"try_get_collection_supertype + if_else"| GLOUT["Supertype output\nNULL if all-NULL row"]
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Greatest;

#[typetag::serde]
impl ScalarUDF for Greatest {
    fn name(&self) -> &'static str {
        "greatest"
    }

    /// Returns the largest of the input values per row, ignoring NULLs.
    /// Returns NULL only when all inputs in a row are NULL. Mirrors Spark's
    /// `Greatest` semantics (`org.apache.spark.sql.catalyst.expressions.Greatest`).
    fn call(
        &self,
        inputs: FunctionArgs<Series>,
        _ctx: &daft_dsl::functions::scalar::EvalContext,
    ) -> DaftResult<Series> {
        compare_inputs(inputs, /* keep_greater = */ true)
    }

    fn get_return_field(
        &self,
        inputs: FunctionArgs<ExprRef>,
        schema: &Schema,
    ) -> DaftResult<Field> {
        common_return_field(self.name(), inputs, schema)
    }

    fn docstring(&self) -> &'static str {
        "Returns the largest value among the inputs, skipping NULL values. \
        Returns NULL only if all inputs are NULL. Requires at least one argument."
    }
}

#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub struct Least;

#[typetag::serde]
impl ScalarUDF for Least {
    fn name(&self) -> &'static str {
        "least"
    }

    /// Returns the smallest of the input values per row, ignoring NULLs.
    /// Returns NULL only when all inputs in a row are NULL. Mirrors Spark's
    /// `Least` semantics (`org.apache.spark.sql.catalyst.expressions.Least`).
    fn call(
        &self,
        inputs: FunctionArgs<Series>,
        _ctx: &daft_dsl::functions::scalar::EvalContext,
    ) -> DaftResult<Series> {
        compare_inputs(inputs, /* keep_greater = */ false)
    }

    fn get_return_field(
        &self,
        inputs: FunctionArgs<ExprRef>,
        schema: &Schema,
    ) -> DaftResult<Field> {
        common_return_field(self.name(), inputs, schema)
    }

    fn docstring(&self) -> &'static str {
        "Returns the smallest value among the inputs, skipping NULL values. \
        Returns NULL only if all inputs are NULL. Requires at least one argument."
    }
}
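To make the NULL-skipping contract and the NaN concern from the summary concrete, here is a minimal sketch over a single row of nullable floats. The names spark_cmp and pick are hypothetical and are not part of the PR or of Daft; the sketch assumes Spark's rule that NaN orders above every other float value.

```rust
use std::cmp::Ordering;

/// Total order on f64 in which NaN compares greater than every other value
/// and equal to itself. Hypothetical helper, not Daft's compare_inputs.
fn spark_cmp(a: f64, b: f64) -> Ordering {
    match (a.is_nan(), b.is_nan()) {
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
        // Neither side is NaN here, so partial_cmp always returns Some.
        (false, false) => a.partial_cmp(&b).unwrap(),
    }
}

/// Picks the greatest (or least) value of one row. NULLs (None) are skipped;
/// the result is None only when every input is None.
fn pick(row: &[Option<f64>], keep_greater: bool) -> Option<f64> {
    row.iter().copied().flatten().reduce(|best, v| {
        let ord = spark_cmp(v, best);
        let replace = if keep_greater {
            ord == Ordering::Greater
        } else {
            ord == Ordering::Less
        };
        if replace { v } else { best }
    })
}
```

With a raw `>` comparison, both `1.0 > NaN` and `NaN > 1.0` are false, so the result would depend on argument order. Routing every comparison through one total order removes that dependence.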
Two separate structs where one parametrised struct would suffice
Greatest and Least are structurally identical — the only difference is which comparison direction is passed to compare_inputs. Per the codebase's convention, similar functionality with different variants should be expressed as a single parametrised type (e.g. struct GreatestLeast { is_greatest: bool }) rather than two separate structs. The internal helper compare_inputs(inputs, keep_greater) already has the right shape; the public structs just add boilerplate.
Rule Used: Prefer single parametrized functions over multiple... (source)
Learned From
Eventual-Inc/Daft#5207
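A rough sketch of the single parametrised type this comment suggests, using plain nullable integers instead of Daft's Series. The struct shape follows the comment's own example (struct GreatestLeast { is_greatest: bool }); the constants and the call signature are assumptions made for illustration.

```rust
/// Hypothetical single type covering both functions; the flag picks the direction.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct GreatestLeast {
    pub is_greatest: bool,
}

impl GreatestLeast {
    pub const GREATEST: Self = Self { is_greatest: true };
    pub const LEAST: Self = Self { is_greatest: false };

    /// Name under which this variant would be registered.
    pub fn name(&self) -> &'static str {
        if self.is_greatest { "greatest" } else { "least" }
    }

    /// Stand-in for compare_inputs over one row of nullable integers:
    /// NULLs (None) are skipped, and an all-NULL row yields None.
    pub fn call(&self, row: &[Option<i64>]) -> Option<i64> {
        let values = row.iter().copied().flatten();
        if self.is_greatest { values.max() } else { values.min() }
    }
}
```

One trade-off to weigh: merging the two structs changes the type that serialization sees, which may matter wherever the existing type names are persisted.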
fn f64_bround(arr: &Float64Array, precision: i32) -> DaftResult<Float64Array> {
    if precision == 0 {
        arr.apply(|v| v.round_ties_even())
    } else {
        let multiplier: f64 = 10.0f64.pow(precision);
        arr.apply(|v| (v * multiplier).round_ties_even() / multiplier)
    }
Negative precision uses a non-representable f64 multiplier
For negative precision (e.g. -2), multiplier = 10.0f64.pow(-2) = 0.01, which is not exactly representable in f64. The round trip (v * 0.01).round_ties_even() / 0.01 can therefore incur two rounding errors, one on the multiply and one on the divide. Dividing and then multiplying by the positive power of ten, (v / 10.0f64.powi(-precision)).round_ties_even() * 10.0f64.powi(-precision), avoids the non-representable reciprocal and is the more numerically stable approach for negative precisions.
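The suggested fix can be sketched on a single scalar as follows. The function name bround_f64 is hypothetical, and the sketch works on a bare f64, not on Daft's Float64Array; it only illustrates the divide-then-multiply ordering for negative precisions.

```rust
/// Banker's (HALF_EVEN) rounding of one f64 to `precision` decimal places.
/// Hypothetical scalar helper, not the PR's f64_bround.
fn bround_f64(v: f64, precision: i32) -> f64 {
    if precision >= 0 {
        // 10^precision is an integer-valued float, so scale up and back down.
        let m = 10.0f64.powi(precision);
        (v * m).round_ties_even() / m
    } else {
        // For negative precision, divide by 10^|precision| first so that
        // no inexact reciprocal such as 0.01 enters the computation.
        let m = 10.0f64.powi(-precision);
        (v / m).round_ties_even() * m
    }
}
```

For example, bround_f64(250.0, -2) computes 250 / 100 = 2.5 exactly, rounds the tie to the even value 2, and returns 200.0, while 350.0 rounds to 400.0.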
fn call_bround(input: Series, decimals: Option<i32>) -> Series {
    let mut args = vec![input.clone()];
    if let Some(d) = decimals {
        args.push(
            Int64Array::from_iter(
                Field::new("d", DataType::Int64),
                vec![Some(d as i64)].into_iter(),
            )
            .into_series(),
        );
    }
    let ctx = EvalContext {
        row_count: input.len(),
    };
    BRound {}
        .call(FunctionArgs::new_unnamed(args), &ctx)
        .unwrap()
}
Test helper passes Int64 for a field declared as Option<i32>
The BRoundArgs struct declares decimals: Option<i32>, yet the test helper pushes an Int64Array as the second argument. It would be worth verifying that calling bround via the SQL engine with an integer literal also works, since SQL may infer the literal as Int64. Adding a test case that passes the decimal as Int32 directly would make the contract explicit.
srilman left a comment
@XuQianJin-Stars I believe we already have greatest and least; they are just named columns_min and columns_max. Maybe not the best names, and it would make sense to have greatest and least as aliases, but we don't need a custom implementation, I believe.
…ex, unhex)
- bround(expr, d): HALF_EVEN (banker's) rounding, supports negative d
- greatest(*exprs) / least(*exprs): N-ary row-wise min/max, skipping NULLs
- hex(expr): integer/string/binary -> uppercase hex (negatives use 64-bit two's complement)
- unhex(expr): hex string -> binary, odd-length left-padded with '0', invalid input returns NULL
Auto-registered as ScalarUDF in NumericFunctions / global registry, exposed via daft.functions, and unit-tested (17 tests).
- numeric.py: use raw docstring for unhex() to satisfy ruff D301 (escape sequence \x0f in a regular docstring)
- bround.rs: avoid multiplying by non-exact powers of 10 for negative precisions. precision >= 0 keeps multiply-then-divide; precision < 0 now divides-then-multiplies by 10^|p|, eliminating the second rounding step that biased ties (e.g. 250 with d=-2). Add regression test plus an Int32 decimals literal test for BRoundArgs<Option<i32>>.
- greatest.rs: deduplicate Greatest/Least via a private GreatestLeastKind trait carrying KEEP_GREATER/NAME/DOCSTRING. Both ScalarUDF impls now share compare_inputs through impl_call::<Self>; public unit structs and typetag::serde tags are preserved for backwards compatibility.
Thanks @srilman! After checking the NULL semantics I confirmed list_min / list_max and the new greatest / least behave identically (skip NULLs row-wise; return NULL only when all inputs are NULL; verified in tests/recordbatch/list/test_list_numeric_aggs.py).
Address review feedback: replace the list-based implementation (to_list().list_min()/list_max()) with direct delegation to the existing greatest/least scalar functions.
Benefits:
- Row-wise NULL skipping: result is NULL only when all inputs in a row are NULL, matching Spark Greatest/Least semantics.
- Works on any comparable dtype (numeric, boolean, string, temporal), not just types supported by list aggregation.
- Avoids the overhead of constructing an intermediate list column per row.
Public surface is preserved: aliases (columns_min/columns_max) and empty-arg error messages remain unchanged. All 80 existing tests/dataframe/test_horizontal.py cases pass.
7d4ffa0 to d719636
Changes Made
This PR adds five Spark-compatible math functions to Daft, callable from
both the Python DataFrame API and Daft SQL. All semantics match Spark
exactly so that workloads migrating from Spark / PySpark can run
unchanged.
Functions added
| Function | Signature | Notes |
| --- | --- | --- |
| bround | bround(expr, d=0) | … d. NULL passthrough. |
| greatest | greatest(*exprs) | |
| least | least(*exprs) | |
| hex | hex(expr) | |
| unhex | unhex(expr) | … '0'. Invalid input → NULL. |

Implementation details
- … src/daft-functions/src/:
  - numeric/bround.rs: uses f64::round_ties_even to bit-exactly match Spark's HALF_EVEN mode; integer inputs short-circuit.
  - numeric/hex.rs: Hex and Unhex ScalarUDFs covering Int*/UInt*/Utf8/Binary, with two's-complement encoding for negative integers.
  - greatest.rs: single ScalarUDF reused for both greatest and least via an is_least flag; computes the supertype of all inputs and skips NULLs row-wise.
- … ScalarUDFs in NumericFunctions and the global function registry, so SQL picks them up for free (SELECT bround(x, 2), SELECT greatest(a, b, c), etc.).
- … daft/functions/numeric.py, exported from daft.functions so users can do from daft.functions import bround, greatest, least, hex, unhex.
- … round_ties_even (already in the project's MSRV).
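The hex/unhex behaviour described above (uppercase output, 64-bit two's complement for negative integers, left-padding of odd-length input, NULL on invalid input) could look roughly like this on plain Rust values. The helpers hex_i64 and unhex are hypothetical stand-ins, not the PR's bytes_to_hex_upper or decode_hex_padded, and None stands in for NULL.

```rust
/// Uppercase hex of a signed 64-bit integer (hypothetical helper).
/// Casting to u64 reinterprets the bits, so negative values come out
/// as 64-bit two's complement.
fn hex_i64(v: i64) -> String {
    format!("{:X}", v as u64)
}

/// Decodes a hex string into bytes (hypothetical helper).
/// Odd-length input is left-padded with '0'; any non-hex character
/// makes the whole result None.
fn unhex(s: &str) -> Option<Vec<u8>> {
    let padded = if s.len() % 2 == 1 {
        format!("0{s}")
    } else {
        s.to_string()
    };
    padded
        .as_bytes()
        .chunks(2)
        .map(|pair| {
            let hi = (pair[0] as char).to_digit(16)?;
            let lo = (pair[1] as char).to_digit(16)?;
            Some((hi * 16 + lo) as u8)
        })
        .collect()
}
```

Collecting an iterator of Option<u8> into Option<Vec<u8>> short-circuits on the first invalid pair, which gives the all-or-nothing NULL behaviour without an explicit loop.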
Behavior examples
Verification
- cargo check (workspace): no errors
- cargo fmt --check (workspace): clean
- cargo clippy -p daft-functions: no errors, no new warnings
- … daft.functions
Related Issues
Closes #7121