[FEAT] connect: support basic column operations
andrewgazelka committed Nov 20, 2024
1 parent 863ac08 commit c0f0a0a
Showing 2 changed files with 61 additions and 2 deletions.
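
For context, a minimal sketch of the column operations this commit enables when a PySpark client talks to a Daft Connect server. The endpoint address is illustrative, not part of the commit:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import StringType

# Hypothetical endpoint; the actual host/port depends on how the server is started.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

df = spark.range(10)
df.select(col("id").alias("my_number")).show()               # rename a column
df.select(col("id").cast(StringType())).show()               # cast to string
df.select(col("id").isNotNull(), col("id").isNull()).show()  # null checks
```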
27 changes: 25 additions & 2 deletions src/daft-connect/src/translation/expr.rs
@@ -1,7 +1,11 @@
 use std::sync::Arc;
 
 use eyre::{bail, Context};
-use spark_connect::{expression as spark_expr, Expression};
+use spark_connect::{
+    expression as spark_expr,
+    expression::sort_order::{NullOrdering, SortDirection},
+    Expression,
+};
 use tracing::warn;
 use unresolved_function::unresolved_to_daft_expr;
 
@@ -73,7 +77,26 @@ pub fn to_daft_expr(expression: &Expression) -> eyre::Result<daft_dsl::ExprRef>
         spark_expr::ExprType::UnresolvedRegex(_) => {
             bail!("Unresolved regex expressions not yet supported")
         }
-        spark_expr::ExprType::SortOrder(_) => bail!("Sort order expressions not yet supported"),
+        spark_expr::ExprType::SortOrder(s) => {
+            let spark_expr::SortOrder {
+                child,
+                direction,
+                null_ordering,
+            } = &**s;
+
+            let Some(child) = child else {
+                bail!("Sort order child is required");
+            };
+
+            // Decode and validate the prost enums even though sort orders are not yet translated.
+            let sort_direction = SortDirection::try_from(*direction)
+                .wrap_err_with(|| format!("Invalid sort direction: {direction}"))?;
+
+            let sort_nulls = NullOrdering::try_from(*null_ordering)
+                .wrap_err_with(|| format!("Invalid sort nulls: {null_ordering}"))?;
+
+            bail!("Sort order expressions not yet supported: {sort_direction:?}, {sort_nulls:?}");
+        }
         spark_expr::ExprType::LambdaFunction(_) => {
             bail!("Lambda function expressions not yet supported")
         }
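Aside: a self-contained sketch of the validate-then-bail error pattern the new SortOrder arm relies on, using eyre's `wrap_err_with` for lazy context and `bail!` for early exit. The function and values below are illustrative only, not part of this commit:

```rust
use eyre::{bail, Context};

// Illustrative only: mirrors the validate-then-bail shape of the SortOrder arm.
fn check_direction(direction: i32) -> eyre::Result<()> {
    // try_from returns a Result; wrap_err_with attaches context lazily on error.
    let parsed = u8::try_from(direction)
        .wrap_err_with(|| format!("Invalid sort direction: {direction}"))?;
    if parsed > 2 {
        bail!("Unknown sort direction discriminant: {parsed}");
    }
    Ok(())
}

fn main() {
    assert!(check_direction(1).is_ok());
    assert!(check_direction(-1).is_err()); // negative i32 fails the u8 conversion
}
```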
36 changes: 36 additions & 0 deletions tests/connect/test_basic_column.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+from pyspark.sql.functions import col
+from pyspark.sql.types import StringType
+
+
+def test_column_operations(spark_session):
+    # Create a DataFrame from range(10)
+    df = spark_session.range(10)
+
+    # Test desc() ordering
+    df_attr = df.select(col("id").desc())
+    assert df_attr.toPandas()["id"].iloc[0] == 9, "desc should sort in descending order"
+
+    # Test __getitem__
+    df_item = df.select(col("id")[0])
+    assert df_item.toPandas()["id"].iloc[0] == 0, "getitem should return first element"
+
+    # Test alias
+    df_alias = df.select(col("id").alias("my_number"))
+    assert "my_number" in df_alias.columns, "alias should rename column"
+    assert df_alias.toPandas()["my_number"].equals(df.toPandas()["id"]), "data should be unchanged"
+
+    # Test cast
+    df_cast = df.select(col("id").cast(StringType()))
+    assert df_cast.schema.fields[0].dataType == StringType(), "cast should change data type"
+
+    # Test isNotNull/isNull
+    df_null = df.select(col("id").isNotNull(), col("id").isNull())
+    assert df_null.toPandas().iloc[0, 0], "isNotNull should be True for non-null values"
+    assert not df_null.toPandas().iloc[0, 1], "isNull should be False for non-null values"
+
+    # Test name
+    df_name = df.select(col("id").name("renamed_id"))
+    assert "renamed_id" in df_name.columns, "name should rename column"
+    assert df_name.toPandas()["renamed_id"].equals(df.toPandas()["id"]), "data should be unchanged"
