Skip to content

Commit

Permalink
Uses supertype with coercions in list constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
rchowell committed Jan 31, 2025
1 parent 70c3d20 commit 1881d0a
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 32 deletions.
19 changes: 13 additions & 6 deletions src/daft-core/src/series/ops/zip.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::cmp::{max, min};

use arrow2::offset::Offsets;
use common_error::{DaftError, DaftResult};
use daft_schema::{dtype::DataType, field::Field};
Expand Down Expand Up @@ -38,20 +40,25 @@ impl Series {
}
};

// 0 -> index of child in 'arrays' vector
// 1 -> last index of child
type Child = (usize, usize);

// build a null series mask so we can skip making full_nulls and avoid downcast "Null to T" errors.
let mut mask: Vec<Option<usize>> = vec![];
let mut mask: Vec<Option<Child>> = vec![];
let mut rows = 0;
let mut capacity = 0;
let mut arrays = vec![];

for s in series {
let len = s.len();
if is_null(s) {
mask.push(None);
} else {
mask.push(Some(arrays.len()));
mask.push(Some((arrays.len(), len - 1)));
arrays.push(*s);
}
let len = s.len();
rows = std::cmp::max(rows, len);
rows = max(rows, len);
capacity += len;
}

Expand All @@ -63,8 +70,8 @@ impl Series {
// merge each series based upon the mask
for row in 0..rows {
for i in &mask {
if let Some(i) = *i {
child.extend(i, row, 1);
if let Some((i, end)) = *i {
child.extend(i, min(row, end), 1);
} else {
child.extend_nulls(1);
}
Expand Down
26 changes: 7 additions & 19 deletions src/daft-dsl/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ impl Expr {
}
Self::IsIn(expr, items) => {
// Use the expr's field name, and infer membership op type.
let list_dtype = infer_list_type(items, schema)?.unwrap_or(DataType::Null);
let list_dtype = try_get_collection_supertype(items, schema)?;
let expr_field = expr.to_field(schema)?;
let expr_type = &expr_field.dtype;
let field_name = &expr_field.name;
Expand All @@ -973,7 +973,7 @@ impl Expr {
Self::List(items) => {
// Use "list" as the field name, and infer list type from items.
let field_name = "list";
let field_type = infer_list_type(items, schema)?.unwrap_or(DataType::Null);
let field_type = try_get_collection_supertype(items, schema)?;
Ok(Field::new(field_name, DataType::new_list(field_type)))
}
Self::Between(value, lower, upper) => {
Expand Down Expand Up @@ -1493,25 +1493,13 @@ pub fn deduplicate_expr_names(exprs: &[ExprRef]) -> Vec<ExprRef> {
.collect()
}

/// Asserts an expr slice is homogeneous and returns the type, or None if empty or all nulls.
/// None allows for context-dependent handling such as erroring or defaulting to Null.
fn infer_list_type(exprs: &[ExprRef], schema: &Schema) -> DaftResult<Option<DataType>> {
let mut dtype: Option<DataType> = None;
/// Tries to get the supertype of all exprs in the collection.
fn try_get_collection_supertype(exprs: &[ExprRef], schema: &Schema) -> DaftResult<DataType> {
let mut dtype = DataType::Null;
for expr in exprs {
let other_dtype = expr.get_type(schema)?;
// other is null, continue
if other_dtype == DataType::Null {
continue;
}
// other != null and dtype is unset -> set dtype
if dtype.is_none() {
dtype = Some(other_dtype);
continue;
}
// other != null and dtype is set -> compare or err!
if dtype.as_ref() != Some(&other_dtype) {
return Err(DaftError::TypeError(format!("Expected all arguments to be of the same type {}, but found element with type {other_dtype}", dtype.unwrap())));
}
let super_dtype = try_get_supertype(&dtype, &other_dtype)?;
dtype = super_dtype;
}
Ok(dtype)
}
12 changes: 11 additions & 1 deletion src/daft-table/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,19 @@ impl Table {
.is_in(&s)
}
Expr::List(items) => {
// compute list type to determine each child cast
let field = expr.to_field(&self.schema)?;
let items = items.iter().map(|item| self.eval_expression(item)).collect::<DaftResult<Vec<_>>>()?;
// extract list child type (could be de-duped with zip and moved to impl DataType)
let dtype = if let DataType::List(dtype) = &field.dtype {
dtype
} else {
return Err(DaftError::ComputeError("List expression must be of type List(T)".to_string()))
};
// compute child series with explicit casts to the supertype
let items = items.iter().map(|i| i.clone().cast(dtype)).collect::<Vec<_>>();
let items = items.iter().map(|i| self.eval_expression(i)).collect::<DaftResult<Vec<_>>>()?;
let items = items.iter().collect::<Vec<&Series>>();
// zip the series into a single series of lists
Series::zip(field, items.as_slice())
}
Expr::Between(child, lower, upper) => self
Expand Down
36 changes: 30 additions & 6 deletions tests/expressions/test_list_.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pytest

import daft
from daft import DataType, col, list_, lit
from daft import DataType as dt
from daft import col, list_, lit


def test_list_constructor_empty():
Expand All @@ -10,6 +11,30 @@ def test_list_constructor_empty():
df = df.select(list_())


def test_list_constructor_with_coercions():
df = daft.from_pydict({"v_i32": [1, 2, 3], "v_bool": [True, True, False]})
df = df.select(list_(lit(1), col("v_i32"), col("v_bool")))
assert df.to_pydict() == {"list": [[1, 1, 1], [1, 2, 1], [1, 3, 0]]}


def test_list_constructor_with_lit_first():
df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]})
df = df.select(list_(lit(1), col("x"), col("y")))
assert df.to_pydict() == {"list": [[1, 1, 4], [1, 2, 5], [1, 3, 6]]}


def test_list_constructor_with_lit_mid():
df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]})
df = df.select(list_(col("x"), lit(1), col("y")))
assert df.to_pydict() == {"list": [[1, 1, 4], [2, 1, 5], [3, 1, 6]]}


def test_list_constructor_with_lit_last():
df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]})
df = df.select(list_(col("x"), col("y"), lit(1)))
assert df.to_pydict() == {"list": [[1, 4, 1], [2, 5, 1], [3, 6, 1]]}


def test_list_constructor_multi_column():
df = daft.from_pydict({"x": [1, 2, 3], "y": [4, 5, 6]})
df = df.select(list_("x", "y").alias("fwd"), list_("y", "x").alias("rev"))
Expand All @@ -35,15 +60,14 @@ def test_list_constructor_homogeneous():


def test_list_constructor_heterogeneous():
with pytest.raises(Exception, match="Expected all arguments to be of the same type"):
df = daft.from_pydict({"x": [1, 2, 3], "y": [True, True, False]})
df = df.select(list_("x", "y"))
df.show()
df = daft.from_pydict({"x": [1, 2, 3], "y": [True, True, False]})
df = df.select(list_("x", "y").alias("heterogeneous"))
assert df.to_pydict() == {"heterogeneous": [[1, 1], [2, 1], [3, 0]]}


def test_list_constructor_heterogeneous_with_cast():
df = daft.from_pydict({"x": [1, 2, 3], "y": [True, True, False]})
df = df.select(list_(col("x").cast(DataType.string()), col("y").cast(DataType.string())).alias("strs"))
df = df.select(list_(col("x").cast(dt.string()), col("y").cast(dt.string())).alias("strs"))
assert df.to_pydict() == {"strs": [["1", "1"], ["2", "1"], ["3", "0"]]}


Expand Down

0 comments on commit 1881d0a

Please sign in to comment.