From f9d442f6e2d8ed3b290ac631358914ad34646fa4 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Mon, 18 Nov 2024 01:19:30 -0800 Subject: [PATCH 1/2] feat: add `limit` and `first` --- src/daft-connect/src/translation/logical_plan.rs | 15 ++++++++++++++- tests/connect/test_range_simple.py | 13 +++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 947e0cd0d3..93c9e9bd4a 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -1,6 +1,6 @@ use daft_logical_plan::LogicalPlanBuilder; use eyre::{bail, Context}; -use spark_connect::{relation::RelType, Relation}; +use spark_connect::{relation::RelType, Limit, Relation}; use tracing::warn; use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range}; @@ -19,6 +19,7 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { }; match rel_type { + RelType::Limit(l) => limit(*l).wrap_err("Failed to apply limit to logical plan"), RelType::Range(r) => range(r).wrap_err("Failed to apply range to logical plan"), RelType::Project(p) => project(*p).wrap_err("Failed to apply project to logical plan"), RelType::Aggregate(a) => { @@ -27,3 +28,15 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { plan => bail!("Unsupported relation type: {plan:?}"), } } + +fn limit(limit: Limit) -> eyre::Result { + let Limit { input, limit } = limit; + + let Some(input) = input else { + bail!("input must be set"); + }; + + let plan = to_logical_plan(*input)?.limit(i64::from(limit), false)?; // todo: eager or no + + Ok(plan) +} diff --git a/tests/connect/test_range_simple.py b/tests/connect/test_range_simple.py index b277d38481..34e82ebbcf 100644 --- a/tests/connect/test_range_simple.py +++ b/tests/connect/test_range_simple.py @@ -12,3 +12,16 @@ def test_range_operation(spark_session): # Verify the DataFrame has expected values assert len(pandas_df) == 10, "DataFrame should have 10 rows" assert list(pandas_df["id"]) == list(range(10)), "DataFrame should contain values 0-9" + + +def test_range_first(spark_session): + spark_range = spark_session.range(10) + first_row = spark_range.first() + assert first_row["id"] == 0, "First row should have id=0" + + +def test_range_limit(spark_session): + spark_range = spark_session.range(10) + limited_df = spark_range.limit(5).toPandas() + assert len(limited_df) == 5, "Limited DataFrame should have 5 rows" + assert list(limited_df["id"]) == list(range(5)), "Limited DataFrame should contain values 0-4" From e099f2e73ce39bd9369271399beecd64f9d4810f Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Mon, 18 Nov 2024 01:21:52 -0800 Subject: [PATCH 2/2] separate tests --- tests/connect/test_limit_simple.py | 14 ++++++++++++++ tests/connect/test_range_simple.py | 13 ------------- 2 files changed, 14 insertions(+), 13 deletions(-) create mode 100644 tests/connect/test_limit_simple.py diff --git a/tests/connect/test_limit_simple.py b/tests/connect/test_limit_simple.py new file mode 100644 index 0000000000..d5f2c97dae --- /dev/null +++ b/tests/connect/test_limit_simple.py @@ -0,0 +1,14 @@ +from __future__ import annotations + + +def test_range_first(spark_session): + spark_range = spark_session.range(10) + first_row = spark_range.first() + assert first_row["id"] == 0, "First row should have id=0" + + +def test_range_limit(spark_session): + spark_range = spark_session.range(10) + limited_df = spark_range.limit(5).toPandas() + assert len(limited_df) == 5, "Limited DataFrame should have 5 rows" + assert list(limited_df["id"]) == list(range(5)), "Limited DataFrame should contain values 0-4" diff --git a/tests/connect/test_range_simple.py b/tests/connect/test_range_simple.py index 34e82ebbcf..b277d38481 100644 --- a/tests/connect/test_range_simple.py +++ b/tests/connect/test_range_simple.py @@ -12,16 +12,3 @@ def test_range_operation(spark_session): # Verify the DataFrame has expected values assert len(pandas_df) == 10, "DataFrame should have 10 rows" assert list(pandas_df["id"]) == list(range(10)), "DataFrame should contain values 0-9" - - -def test_range_first(spark_session): - spark_range = spark_session.range(10) - first_row = spark_range.first() - assert first_row["id"] == 0, "First row should have id=0" - - -def test_range_limit(spark_session): - spark_range = spark_session.range(10) - limited_df = spark_range.limit(5).toPandas() - assert len(limited_df) == 5, "Limited DataFrame should have 5 rows" - assert list(limited_df["id"]) == list(range(5)), "Limited DataFrame should contain values 0-4"