From 1881d0aae46d98fc0ded73f3d61323268a9ffe51 Mon Sep 17 00:00:00 2001 From: "R. Conner Howell" Date: Fri, 31 Jan 2025 10:50:20 -0800 Subject: [PATCH] Uses supertype with coercions in list constructor --- src/daft-core/src/series/ops/zip.rs | 19 ++++++++++----- src/daft-dsl/src/expr/mod.rs | 26 ++++++--------------- src/daft-table/src/lib.rs | 12 +++++++++- tests/expressions/test_list_.py | 36 ++++++++++++++++++++++++----- 4 files changed, 61 insertions(+), 32 deletions(-) diff --git a/src/daft-core/src/series/ops/zip.rs b/src/daft-core/src/series/ops/zip.rs index b194bb87fb..c7b0e844a2 100644 --- a/src/daft-core/src/series/ops/zip.rs +++ b/src/daft-core/src/series/ops/zip.rs @@ -1,3 +1,5 @@ +use std::cmp::{max, min}; + use arrow2::offset::Offsets; use common_error::{DaftError, DaftResult}; use daft_schema::{dtype::DataType, field::Field}; @@ -38,20 +40,25 @@ impl Series { } }; + // 0 -> index of child in 'arrays' vector + // 1 -> last index of child + type Child = (usize, usize); + // build a null series mask so we can skip making full_nulls and avoid downcast "Null to T" errors. - let mut mask: Vec> = vec![]; + let mut mask: Vec> = vec![]; let mut rows = 0; let mut capacity = 0; let mut arrays = vec![]; + for s in series { + let len = s.len(); if is_null(s) { mask.push(None); } else { - mask.push(Some(arrays.len())); + mask.push(Some((arrays.len(), len - 1))); arrays.push(*s); } - let len = s.len(); - rows = std::cmp::max(rows, len); + rows = max(rows, len); capacity += len; } @@ -63,8 +70,8 @@ impl Series { // merge each series based upon the mask for row in 0..rows { for i in &mask { - if let Some(i) = *i { - child.extend(i, row, 1); + if let Some((i, end)) = *i { + child.extend(i, min(row, end), 1); } else { child.extend_nulls(1); } diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 619ef5a4f1..487c74fa53 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -961,7 +961,7 @@ impl Expr { } Self::IsIn(expr, items) => { // Use the expr's field name, and infer membership op type. - let list_dtype = infer_list_type(items, schema)?.unwrap_or(DataType::Null); + let list_dtype = try_get_collection_supertype(items, schema)?; let expr_field = expr.to_field(schema)?; let expr_type = &expr_field.dtype; let field_name = &expr_field.name; @@ -973,7 +973,7 @@ impl Expr { Self::List(items) => { // Use "list" as the field name, and infer list type from items. let field_name = "list"; - let field_type = infer_list_type(items, schema)?.unwrap_or(DataType::Null); + let field_type = try_get_collection_supertype(items, schema)?; Ok(Field::new(field_name, DataType::new_list(field_type))) } Self::Between(value, lower, upper) => { @@ -1493,25 +1493,13 @@ pub fn deduplicate_expr_names(exprs: &[ExprRef]) -> Vec { .collect() } -/// Asserts an expr slice is homogeneous and returns the type, or None if empty or all nulls. -/// None allows for context-dependent handling such as erroring or defaulting to Null. -fn infer_list_type(exprs: &[ExprRef], schema: &Schema) -> DaftResult> { - let mut dtype: Option = None; +/// Tries to get the supertype of all exprs in the collection. +fn try_get_collection_supertype(exprs: &[ExprRef], schema: &Schema) -> DaftResult { + let mut dtype = DataType::Null; for expr in exprs { let other_dtype = expr.get_type(schema)?; - // other is null, continue - if other_dtype == DataType::Null { - continue; - } - // other != null and dtype is unset -> set dtype - if dtype.is_none() { - dtype = Some(other_dtype); - continue; - } - // other != null and dtype is set -> compare or err! - if dtype.as_ref() != Some(&other_dtype) { - return Err(DaftError::TypeError(format!("Expected all arguments to be of the same type {}, but found element with type {other_dtype}", dtype.unwrap()))); - } + let super_dtype = try_get_supertype(&dtype, &other_dtype)?; + dtype = super_dtype; } Ok(dtype) } diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 5a4825fe64..2d3082fe52 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -554,9 +554,19 @@ impl Table { .is_in(&s) } Expr::List(items) => { + // compute list type to determine each child cast let field = expr.to_field(&self.schema)?; - let items = items.iter().map(|item| self.eval_expression(item)).collect::>>()?; + // extract list child type (could be de-duped with zip and moved to impl DataType) + let dtype = if let DataType::List(dtype) = &field.dtype { + dtype + } else { + return Err(DaftError::ComputeError("List expression must be of type List(T)".to_string())) + }; + // compute child series with explicit casts to the supertype + let items = items.iter().map(|i| i.clone().cast(dtype)).collect::>(); + let items = items.iter().map(|i| self.eval_expression(i)).collect::>>()?; let items = items.iter().collect::>(); + // zip the series into a single series of lists Series::zip(field, items.as_slice()) } Expr::Between(child, lower, upper) => self diff --git a/tests/expressions/test_list_.py b/tests/expressions/test_list_.py index a6f50952dc..8ee5d08975 100644 --- a/tests/expressions/test_list_.py +++ b/tests/expressions/test_list_.py @@ -1,7 +1,8 @@ import pytest import daft -from daft import DataType, col, list_, lit +from daft import DataType as dt +from daft import col, list_, lit def test_list_constructor_empty(): @@ -10,6 +11,30 @@ def test_list_constructor_empty(): df = df.select(list_()) +def test_list_constructor_with_coercions(): + df = daft.from_pydict({"v_i32": [1, 2, 3], "v_bool": [True, True, False]}) + df = df.select(list_(lit(1), col("v_i32"), col("v_bool"))) + assert df.to_pydict() == {"list": [[1, 1, 1], [1, 2, 1], [1, 3, 0]]} + + +def test_list_constructor_with_lit_first(): + df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]}) + df = df.select(list_(lit(1), col("x"), col("y"))) + assert df.to_pydict() == {"list": [[1, 1, 4], [1, 2, 5], [1, 3, 6]]} + + +def test_list_constructor_with_lit_mid(): + df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]}) + df = df.select(list_(col("x"), lit(1), col("y"))) + assert df.to_pydict() == {"list": [[1, 1, 4], [2, 1, 5], [3, 1, 6]]} + + +def test_list_constructor_with_lit_last(): + df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]}) + df = df.select(list_(col("x"), col("y"), lit(1))) + assert df.to_pydict() == {"list": [[1, 4, 1], [2, 5, 1], [3, 6, 1]]} + + def test_list_constructor_multi_column(): df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]}) df = df.select(list_("x", "y").alias("fwd"), list_("y", "x").alias("rev")) @@ -35,15 +60,14 @@ def test_list_constructor_homogeneous(): def test_list_constructor_heterogeneous(): - with pytest.raises(Exception, match="Expected all arguments to be of the same type"): - df = daft.from_pydict({"x": [1, 2, 3], "y": [True, True, False]}) - df = df.select(list_("x", "y")) - df.show() + df = daft.from_pydict({"x": [1, 2, 3], "y": [True, True, False]}) + df = df.select(list_("x", "y").alias("heterogeneous")) + assert df.to_pydict() == {"heterogeneous": [[1, 1], [2, 1], [3, 0]]} def test_list_constructor_heterogeneous_with_cast(): df = daft.from_pydict({"x": [1, 2, 3], "y": [True, True, False]}) - df = df.select(list_(col("x").cast(DataType.string()), col("y").cast(DataType.string())).alias("strs")) + df = df.select(list_(col("x").cast(dt.string()), col("y").cast(dt.string())).alias("strs")) assert df.to_pydict() == {"strs": [["1", "1"], ["2", "1"], ["3", "0"]]}