diff --git a/daft/delta_lake/delta_lake_scan.py b/daft/delta_lake/delta_lake_scan.py index f1357c2fa7..a287a90b8c 100644 --- a/daft/delta_lake/delta_lake_scan.py +++ b/daft/delta_lake/delta_lake_scan.py @@ -111,6 +111,9 @@ def __init__( def schema(self) -> Schema: return self._schema + def name(self) -> str: + return "DeltaLakeScanOperator" + def display_name(self) -> str: return f"DeltaLakeScanOperator({self._table.metadata().name})" diff --git a/daft/hudi/hudi_scan.py b/daft/hudi/hudi_scan.py index 3d87f9716a..5c10476947 100644 --- a/daft/hudi/hudi_scan.py +++ b/daft/hudi/hudi_scan.py @@ -37,6 +37,9 @@ def __init__(self, table_uri: str, storage_config: StorageConfig) -> None: def schema(self) -> Schema: return self._schema + def name(self) -> str: + return "HudiScanOperator" + def display_name(self) -> str: return f"HudiScanOperator({self._table.props.name})" diff --git a/daft/iceberg/iceberg_scan.py b/daft/iceberg/iceberg_scan.py index ac47116689..474eba36a4 100644 --- a/daft/iceberg/iceberg_scan.py +++ b/daft/iceberg/iceberg_scan.py @@ -115,6 +115,9 @@ def __init__(self, iceberg_table: Table, snapshot_id: int | None, storage_config def schema(self) -> Schema: return self._schema + def name(self) -> str: + return "IcebergScanOperator" + def display_name(self) -> str: return f"IcebergScanOperator({'.'.join(self._table.name())})" diff --git a/daft/io/_generator.py b/daft/io/_generator.py index d52915a0e7..41781c479b 100644 --- a/daft/io/_generator.py +++ b/daft/io/_generator.py @@ -80,6 +80,9 @@ def __init__( self._generators = generators self._schema = schema + def name(self) -> str: + return self.display_name() + def display_name(self) -> str: return "GeneratorScanOperator" diff --git a/daft/io/_lance.py b/daft/io/_lance.py index fb90865de9..008fc70ded 100644 --- a/daft/io/_lance.py +++ b/daft/io/_lance.py @@ -69,6 +69,9 @@ class LanceDBScanOperator(ScanOperator): def __init__(self, ds: "lance.LanceDataset"): self._ds = ds + def name(self) -> str: + return "LanceDBScanOperator" + def display_name(self) -> str: return f"LanceDBScanOperator({self._ds.uri})" diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index 892475e676..9f014bdbd1 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -69,6 +69,9 @@ def __init__( def schema(self) -> Schema: return self._schema + def name(self) -> str: + return "SQLScanOperator" + def display_name(self) -> str: return f"SQLScanOperator(sql={self.sql}, conn={self.conn})" diff --git a/src/common/scan-info/src/scan_operator.rs b/src/common/scan-info/src/scan_operator.rs index 10c798fe5f..65714cc496 100644 --- a/src/common/scan-info/src/scan_operator.rs +++ b/src/common/scan-info/src/scan_operator.rs @@ -10,6 +10,8 @@ use daft_schema::schema::SchemaRef; use crate::{PartitionField, Pushdowns, ScanTaskLikeRef}; pub trait ScanOperator: Send + Sync + Debug { + fn name(&self) -> &str; + fn schema(&self) -> SchemaRef; fn partitioning_keys(&self) -> &[PartitionField]; fn file_path_column(&self) -> Option<&str>; diff --git a/src/common/scan-info/src/test/mod.rs b/src/common/scan-info/src/test/mod.rs index 2fd717db1b..0da27600e9 100644 --- a/src/common/scan-info/src/test/mod.rs +++ b/src/common/scan-info/src/test/mod.rs @@ -97,6 +97,9 @@ Pushdowns: {pushdowns} } impl ScanOperator for DummyScanOperator { + fn name(&self) -> &'static str { + "dummy" + } fn schema(&self) -> SchemaRef { self.schema.clone() } diff --git a/src/daft-logical-plan/src/optimization/optimizer.rs b/src/daft-logical-plan/src/optimization/optimizer.rs index 7c68274d95..257d7de736 100644 --- a/src/daft-logical-plan/src/optimization/optimizer.rs +++ b/src/daft-logical-plan/src/optimization/optimizer.rs @@ -7,7 +7,8 @@ use super::{ logical_plan_tracker::LogicalPlanTracker, rules::{ DropRepartition, EliminateCrossJoin, EnrichWithStats, LiftProjectFromAgg, MaterializeScans, - OptimizerRule, PushDownFilter, PushDownLimit, PushDownProjection, SplitActorPoolProjects, + OptimizerRule, PushDownFilter, PushDownLimit, PushDownProjection, SimplifyExpressionsRule, + SplitActorPoolProjects, }, }; use crate::LogicalPlan; @@ -97,6 +98,11 @@ impl Optimizer { ], RuleExecutionStrategy::Once, ), + // we want to simplify expressions first to make the rest of the rules easier + RuleBatch::new( + vec![Box::new(SimplifyExpressionsRule::new())], + RuleExecutionStrategy::FixedPoint(Some(3)), + ), // --- Bulk of our rules --- RuleBatch::new( vec![ @@ -129,6 +135,11 @@ impl Optimizer { vec![Box::new(EnrichWithStats::new())], RuleExecutionStrategy::Once, ), + // try to simplify expressions again as other rules could introduce new exprs + RuleBatch::new( + vec![Box::new(SimplifyExpressionsRule::new())], + RuleExecutionStrategy::FixedPoint(Some(3)), + ), ]; Self::with_rule_batches(rule_batches, config) diff --git a/src/daft-logical-plan/src/optimization/rules/mod.rs b/src/daft-logical-plan/src/optimization/rules/mod.rs index 75e0f36c88..d8342fc991 100644 --- a/src/daft-logical-plan/src/optimization/rules/mod.rs +++ b/src/daft-logical-plan/src/optimization/rules/mod.rs @@ -7,6 +7,7 @@ mod push_down_filter; mod push_down_limit; mod push_down_projection; mod rule; +mod simplify_expressions; mod split_actor_pool_projects; pub use drop_repartition::DropRepartition; @@ -18,4 +19,5 @@ pub use push_down_filter::PushDownFilter; pub use push_down_limit::PushDownLimit; pub use push_down_projection::PushDownProjection; pub use rule::OptimizerRule; +pub use simplify_expressions::SimplifyExpressionsRule; pub use split_actor_pool_projects::SplitActorPoolProjects; diff --git a/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs b/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs new file mode 100644 index 0000000000..bb890e2a17 --- /dev/null +++ b/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs @@ -0,0 +1,580 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use common_scan_info::{PhysicalScanInfo, ScanState}; +use common_treenode::{Transformed, TreeNode}; +use daft_core::prelude::SchemaRef; +use daft_dsl::{lit, null_lit, Expr, ExprRef, LiteralValue, Operator}; +use daft_schema::dtype::DataType; + +use super::OptimizerRule; +use crate::LogicalPlan; + +/// Optimization rule for simplifying expressions +#[derive(Default, Debug)] +pub struct SimplifyExpressionsRule {} + +impl SimplifyExpressionsRule { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for SimplifyExpressionsRule { + fn try_optimize(&self, plan: Arc) -> DaftResult>> { + if plan.exists(|p| match p.as_ref() { + LogicalPlan::Source(source) => match source.source_info.as_ref() { + crate::SourceInfo::Physical(PhysicalScanInfo { scan_state: ScanState::Operator(scan_op), .. }) + // TODO: support simplify expressions for SQLScanOperator + if scan_op.0.name() == "SQLScanOperator" => + { + true + } + _ => false, + }, + _ => false, + }) { + return Ok(Transformed::no(plan)); + } + + let schema = plan.schema(); + plan.transform(|plan| { + Ok(Arc::unwrap_or_clone(plan) + .map_expressions(|expr| simplify_expr(Arc::unwrap_or_clone(expr), &schema))? + .update_data(Arc::new)) + }) + } +} + +fn simplify_expr(expr: Expr, schema: &SchemaRef) -> DaftResult> { + Ok(match expr { + // ---------------- + // Eq + // ---------------- + // true = A --> A + // false = A --> !A + Expr::BinaryOp { + op: Operator::Eq, + left, + right, + } + // A = true --> A + // A = false --> !A + | Expr::BinaryOp { + op: Operator::Eq, + left: right, + right: left, + } if is_bool_lit(&left) && is_bool_type(&right, schema) => { + Transformed::yes(match as_bool_lit(&left) { + Some(true) => right, + Some(false) => right.not(), + None => unreachable!(), + }) + } + + // null = A --> null + // A = null --> null + Expr::BinaryOp { + op: Operator::Eq, + left, + right, + } + | Expr::BinaryOp { + op: Operator::Eq, + left: right, + right: left, + } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), + + // ---------------- + // Neq + // ---------------- + // true != A --> !A + // false != A --> A + Expr::BinaryOp { + op: Operator::NotEq, + left, + right, + } + // A != true --> !A + // A != false --> A + | Expr::BinaryOp { + op: Operator::NotEq, + left: right, + right: left, + } if is_bool_lit(&left) && is_bool_type(&right, schema) => { + Transformed::yes(match as_bool_lit(&left) { + Some(true) => right.not(), + Some(false) => right, + None => unreachable!(), + }) + } + + // null != A --> null + // A != null --> null + Expr::BinaryOp { + op: Operator::NotEq, + left, + right, + } + | Expr::BinaryOp { + op: Operator::NotEq, + left: right, + right: left, + } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), + + // ---------------- + // OR + // ---------------- + + // true OR A --> true + Expr::BinaryOp { + op: Operator::Or, + left, + right: _, + } if is_true(&left) => Transformed::yes(left), + // false OR A --> A + Expr::BinaryOp { + op: Operator::Or, + left, + right, + } if is_false(&left) => Transformed::yes(right), + // A OR true --> true + Expr::BinaryOp { + op: Operator::Or, + left: _, + right, + } if is_true(&right) => Transformed::yes(right), + // A OR false --> A + Expr::BinaryOp { + left, + op: Operator::Or, + right, + } if is_false(&right) => Transformed::yes(left), + + // ---------------- + // AND (TODO) + // ---------------- + + // ---------------- + // Multiplication + // ---------------- + + // A * 1 --> A + // 1 * A --> A + Expr::BinaryOp { + op: Operator::Multiply, + left, + right, + }| Expr::BinaryOp { + op: Operator::Multiply, + left: right, + right: left, + } if is_one(&right) => Transformed::yes(left), + + // A * null --> null + Expr::BinaryOp { + op: Operator::Multiply, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + // null * A --> null + Expr::BinaryOp { + op: Operator::Multiply, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + + // TODO: Can't do this one because we don't have a way to determine if an expr potentially contains nulls (nullable) + // A * 0 --> 0 (if A is not null and not floating/decimal) + // 0 * A --> 0 (if A is not null and not floating/decimal) + + // ---------------- + // Division + // ---------------- + // A / 1 --> A + Expr::BinaryOp { + op: Operator::TrueDivide, + left, + right, + } if is_one(&right) => Transformed::yes(left), + // null / A --> null + Expr::BinaryOp { + op: Operator::TrueDivide, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + // A / null --> null + Expr::BinaryOp { + op: Operator::TrueDivide, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + + // ---------------- + // Addition + // ---------------- + // A + 0 --> A + Expr::BinaryOp { + op: Operator::Plus, + left, + right, + } if is_zero(&right) => Transformed::yes(left), + + // 0 + A --> A + Expr::BinaryOp { + op: Operator::Plus, + left, + right, + } if is_zero(&left) => Transformed::yes(right), + + // ---------------- + // Subtraction + // ---------------- + + // A - 0 --> A + Expr::BinaryOp { + op: Operator::Minus, + left, + right, + } if is_zero(&right) => Transformed::yes(left), + + // A - null --> null + Expr::BinaryOp { + op: Operator::Minus, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + // null - A --> null + Expr::BinaryOp { + op: Operator::Minus, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + + // ---------------- + // Modulus + // ---------------- + + // A % null --> null + Expr::BinaryOp { + op: Operator::Modulus, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + + // null % A --> null + Expr::BinaryOp { + op: Operator::Modulus, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + + // A BETWEEN low AND high --> A >= low AND A <= high + Expr::Between(expr, low, high) => { + Transformed::yes(expr.clone().lt_eq(high).and(expr.gt_eq(low))) + } + Expr::Not(expr) => match Arc::unwrap_or_clone(expr) { + // NOT (BETWEEN A AND B) --> A < low OR A > high + Expr::Between(expr, low, high) => { + Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) + } + // expr NOT IN () --> true + Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(true)), + + expr => { + let expr = simplify_expr(expr, schema)?; + if expr.transformed { + Transformed::yes(expr.data.not()) + } else { + Transformed::no(expr.data.not()) + } + } + }, + // expr IN () --> false + Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(false)), + + other => Transformed::no(Arc::new(other)), + }) +} + +fn is_zero(s: &Expr) -> bool { + match s { + Expr::Literal(LiteralValue::Int32(0)) + | Expr::Literal(LiteralValue::Int64(0)) + | Expr::Literal(LiteralValue::UInt32(0)) + | Expr::Literal(LiteralValue::UInt64(0)) + | Expr::Literal(LiteralValue::Float64(0.)) => true, + Expr::Literal(LiteralValue::Decimal(v, _p, _s)) if *v == 0 => true, + _ => false, + } +} + +fn is_one(s: &Expr) -> bool { + match s { + Expr::Literal(LiteralValue::Int32(1)) + | Expr::Literal(LiteralValue::Int64(1)) + | Expr::Literal(LiteralValue::UInt32(1)) + | Expr::Literal(LiteralValue::UInt64(1)) + | Expr::Literal(LiteralValue::Float64(1.)) => true, + + Expr::Literal(LiteralValue::Decimal(v, _p, s)) => { + *s >= 0 && POWS_OF_TEN.get(*s as usize).is_some_and(|pow| v == pow) + } + _ => false, + } +} + +fn is_true(expr: &Expr) -> bool { + match expr { + Expr::Literal(LiteralValue::Boolean(v)) => *v, + _ => false, + } +} +fn is_false(expr: &Expr) -> bool { + match expr { + Expr::Literal(LiteralValue::Boolean(v)) => !*v, + _ => false, + } +} + +/// returns true if expr is a +/// `Expr::Literal(LiteralValue::Boolean(v))` , false otherwise +fn is_bool_lit(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(LiteralValue::Boolean(_))) +} + +fn is_bool_type(expr: &Expr, schema: &SchemaRef) -> bool { + matches!(expr.get_type(schema), Ok(DataType::Boolean)) +} + +fn as_bool_lit(expr: &Expr) -> Option { + expr.as_literal().and_then(|l| l.as_bool()) +} + +fn is_null(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(LiteralValue::Null)) +} + +static POWS_OF_TEN: [i128; 38] = [ + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000, + 10000000000000000000, + 100000000000000000000, + 1000000000000000000000, + 10000000000000000000000, + 100000000000000000000000, + 1000000000000000000000000, + 10000000000000000000000000, + 100000000000000000000000000, + 1000000000000000000000000000, + 10000000000000000000000000000, + 100000000000000000000000000000, + 1000000000000000000000000000000, + 10000000000000000000000000000000, + 100000000000000000000000000000000, + 1000000000000000000000000000000000, + 10000000000000000000000000000000000, + 100000000000000000000000000000000000, + 1000000000000000000000000000000000000, + 10000000000000000000000000000000000000, +]; + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use daft_core::prelude::Schema; + use daft_dsl::{col, lit, null_lit, ExprRef}; + use daft_schema::{dtype::DataType, field::Field}; + use rstest::rstest; + + use super::SimplifyExpressionsRule; + use crate::{ + ops::{Filter, Project, Source}, + optimization::rules::OptimizerRule, + source_info::PlaceHolderInfo, + stats::StatsState, + ClusteringSpec, LogicalPlan, LogicalPlanBuilder, SourceInfo, + }; + + fn make_source() -> LogicalPlanBuilder { + let schema = Arc::new( + Schema::new(vec![ + Field::new("bool", DataType::Boolean), + Field::new("int", DataType::Int32), + ]) + .unwrap(), + ); + LogicalPlanBuilder::from( + LogicalPlan::Source(Source { + output_schema: schema.clone(), + source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + source_schema: schema, + clustering_spec: Arc::new(ClusteringSpec::unknown()), + source_id: 0, + })), + stats_state: StatsState::NotMaterialized, + }) + .arced(), + ) + } + + #[rstest] + // true = A --> A + #[case(col("bool").eq(lit(true)), col("bool"))] + // false = A --> !A + #[case(col("bool").eq(lit(false)), col("bool").not())] + // A == true ---> A + #[case(col("bool").eq(lit(true)), col("bool"))] + // null = A --> null + #[case(null_lit().eq(col("bool")), null_lit())] + // A == false ---> !A + #[case(col("bool").eq(lit(false)), col("bool").not())] + // true != A --> !A + #[case(lit(true).not_eq(col("bool")), col("bool").not())] + // false != A --> A + #[case(lit(false).not_eq(col("bool")), col("bool"))] + // true OR A --> true + #[case(lit(true).or(col("bool")), lit(true))] + // false OR A --> A + #[case(lit(false).or(col("bool")), col("bool"))] + // A OR true --> true + #[case(col("bool").or(lit(true)), lit(true))] + // A OR false --> A + #[case(col("bool").or(lit(false)), col("bool"))] + fn test_simplify_bool_exprs(#[case] input: ExprRef, #[case] expected: ExprRef) { + let source = make_source().filter(input).unwrap().build(); + let optimizer = SimplifyExpressionsRule::new(); + let optimized = optimizer.try_optimize(source).unwrap(); + + let LogicalPlan::Filter(Filter { predicate, .. }) = optimized.data.as_ref() else { + panic!("Expected Filter, got {:?}", optimized.data) + }; + + // make sure the expression is simplified + assert!(optimized.transformed); + + assert_eq!(predicate, &expected); + } + + #[rstest] + // A * 1 --> A + #[case(col("int").mul(lit(1)), col("int"))] + // 1 * A --> A + #[case(lit(1).mul(col("int")), col("int"))] + // A / 1 --> A + #[case(col("int").div(lit(1)), col("int"))] + // A + 0 --> A + #[case(col("int").add(lit(0)), col("int"))] + // A - 0 --> A + #[case(col("int").sub(lit(0)), col("int"))] + fn test_math_exprs(#[case] input: ExprRef, #[case] expected: ExprRef) { + let source = make_source().select(vec![input]).unwrap().build(); + let optimizer = SimplifyExpressionsRule::new(); + let optimized = optimizer.try_optimize(source).unwrap(); + + let LogicalPlan::Project(Project { projection, .. }) = optimized.data.as_ref() else { + panic!("Expected Filter, got {:?}", optimized.data) + }; + + let projection = projection.first().unwrap(); + + // make sure the expression is simplified + assert!(optimized.transformed); + + assert_eq!(projection, &expected); + } + + #[test] + fn test_not_between() { + let source = make_source() + .filter(col("int").between(lit(1), lit(10)).not()) + .unwrap() + .build(); + let optimizer = SimplifyExpressionsRule::new(); + let optimized = optimizer.try_optimize(source).unwrap(); + + let LogicalPlan::Filter(Filter { predicate, .. }) = optimized.data.as_ref() else { + panic!("Expected Filter, got {:?}", optimized.data) + }; + + // make sure the expression is simplified + assert!(optimized.transformed); + + assert_eq!(predicate, &col("int").lt(lit(1)).or(col("int").gt(lit(10)))); + } + + #[test] + fn test_between() { + let source = make_source() + .filter(col("int").between(lit(1), lit(10))) + .unwrap() + .build(); + let optimizer = SimplifyExpressionsRule::new(); + let optimized = optimizer.try_optimize(source).unwrap(); + + let LogicalPlan::Filter(Filter { predicate, .. }) = optimized.data.as_ref() else { + panic!("Expected Filter, got {:?}", optimized.data) + }; + + // make sure the expression is simplified + assert!(optimized.transformed); + + assert_eq!( + predicate, + &col("int").lt_eq(lit(10)).and(col("int").gt_eq(lit(1))) + ); + } + #[test] + fn test_nested_plan() { + let source = make_source() + .filter(col("int").between(lit(1), lit(10))) + .unwrap() + .select(vec![col("int").add(lit(0))]) + .unwrap() + .build(); + let optimizer = SimplifyExpressionsRule::new(); + let optimized = optimizer.try_optimize(source).unwrap(); + + let LogicalPlan::Project(Project { + projection, input, .. + }) = optimized.data.as_ref() + else { + panic!("Expected Filter, got {:?}", optimized.data) + }; + + let LogicalPlan::Filter(Filter { predicate, .. }) = input.as_ref() else { + panic!("Expected Filter, got {:?}", optimized.data) + }; + + let projection = projection.first().unwrap(); + + // make sure the expression is simplified + assert!(optimized.transformed); + + assert_eq!(projection, &col("int")); + + // make sure the predicate is simplified + assert_eq!( + predicate, + &col("int").lt_eq(lit(10)).and(col("int").gt_eq(lit(1))) + ); + } +} diff --git a/src/daft-logical-plan/src/treenode.rs b/src/daft-logical-plan/src/treenode.rs index a5699e556f..151fe7eb4a 100644 --- a/src/daft-logical-plan/src/treenode.rs +++ b/src/daft-logical-plan/src/treenode.rs @@ -1,9 +1,15 @@ use std::sync::Arc; use common_error::DaftResult; -use common_treenode::DynTreeNode; +use common_treenode::{ + map_until_stop_and_collect, DynTreeNode, Transformed, TreeNodeIterator, TreeNodeRecursion, +}; +use daft_dsl::ExprRef; -use crate::LogicalPlan; +use crate::{ + partitioning::{HashRepartitionConfig, RepartitionSpec}, + LogicalPlan, +}; impl DynTreeNode for LogicalPlan { fn arc_children(&self) -> Vec> { @@ -29,3 +35,161 @@ impl DynTreeNode for LogicalPlan { } } } + +impl LogicalPlan { + pub fn map_expressions DaftResult>>( + self, + mut f: F, + ) -> DaftResult> { + use crate::ops::{ActorPoolProject, Explode, Filter, Join, Project, Repartition, Sort}; + + Ok(match self { + Self::Project(Project { + input, + projection, + projected_schema, + stats_state, + }) => projection + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| { + Self::Project(Project { + input, + projection: expr, + projected_schema, + stats_state, + }) + }), + Self::Filter(Filter { + input, + predicate, + stats_state, + }) => f(predicate)?.update_data(|expr| { + Self::Filter(Filter { + input, + predicate: expr, + stats_state, + }) + }), + Self::Repartition(Repartition { + input, + repartition_spec, + stats_state, + }) => match repartition_spec { + RepartitionSpec::Hash(HashRepartitionConfig { num_partitions, by }) => by + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| { + RepartitionSpec::Hash(HashRepartitionConfig { + num_partitions, + by: expr, + }) + }), + repartition_spec => Transformed::no(repartition_spec), + } + .update_data(|repartition_spec| { + Self::Repartition(Repartition { + input, + repartition_spec, + stats_state, + }) + }), + Self::ActorPoolProject(ActorPoolProject { + input, + projection, + projected_schema, + stats_state, + }) => projection + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| { + Self::ActorPoolProject(ActorPoolProject { + input, + projection: expr, + projected_schema, + stats_state, + }) + }), + Self::Sort(Sort { + input, + sort_by, + descending, + nulls_first, + stats_state, + }) => sort_by + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| { + Self::Sort(Sort { + input, + sort_by: expr, + descending, + nulls_first, + stats_state, + }) + }), + Self::Explode(Explode { + input, + to_explode, + exploded_schema, + stats_state, + }) => to_explode + .into_iter() + .map_until_stop_and_collect(f)? + .update_data(|expr| { + Self::Explode(Explode { + input, + to_explode: expr, + exploded_schema, + stats_state, + }) + }), + Self::Join(Join { + left, + right, + left_on, + right_on, + null_equals_nulls, + join_type, + join_strategy, + output_schema, + stats_state, + }) => { + let o = left_on + .into_iter() + .zip(right_on) + .map_until_stop_and_collect(|(l, r)| { + map_until_stop_and_collect!(f(l), r, f(r)) + })?; + let (left_on, right_on) = o.data.into_iter().unzip(); + + if o.transformed { + Transformed::yes(Self::Join(Join { + left, + right, + left_on, + right_on, + null_equals_nulls, + join_type, + join_strategy, + output_schema, + stats_state, + })) + } else { + Transformed::no(Self::Join(Join { + left, + right, + left_on, + right_on, + null_equals_nulls, + join_type, + join_strategy, + output_schema, + stats_state, + })) + } + } + lp => Transformed::no(lp), + }) + } +} diff --git a/src/daft-scan/src/anonymous.rs b/src/daft-scan/src/anonymous.rs index 17f8c6574a..6e3c02a14c 100644 --- a/src/daft-scan/src/anonymous.rs +++ b/src/daft-scan/src/anonymous.rs @@ -32,6 +32,9 @@ impl AnonymousScanOperator { } impl ScanOperator for AnonymousScanOperator { + fn name(&self) -> &str { + "AnonymousScanOperator" + } fn schema(&self) -> SchemaRef { self.schema.clone() } diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index 2f8d0f071f..bf91d4205c 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -297,6 +297,10 @@ impl GlobScanOperator { } impl ScanOperator for GlobScanOperator { + fn name(&self) -> &'static str { + "GlobScanOperator" + } + fn schema(&self) -> SchemaRef { self.schema.clone() } diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index f23c295ad0..e27d68ef3a 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -171,6 +171,7 @@ pub mod pylib { #[pyclass(module = "daft.daft")] #[derive(Debug)] struct PythonScanOperatorBridge { + name: String, operator: PyObject, schema: SchemaRef, partitioning_keys: Vec, @@ -181,6 +182,10 @@ pub mod pylib { } impl PythonScanOperatorBridge { + fn _name(abc: &PyObject, py: Python) -> PyResult { + let result = abc.call_method0(py, pyo3::intern!(py, "name"))?; + result.extract::(py) + } fn _partitioning_keys(abc: &PyObject, py: Python) -> PyResult> { let result = abc.call_method0(py, pyo3::intern!(py, "partitioning_keys"))?; let result = result.extract::<&PyList>(py)?; @@ -223,6 +228,7 @@ pub mod pylib { impl PythonScanOperatorBridge { #[staticmethod] pub fn from_python_abc(abc: PyObject, py: Python) -> PyResult { + let name = Self::_name(&abc, py)?; let partitioning_keys = Self::_partitioning_keys(&abc, py)?; let schema = Self::_schema(&abc, py)?; let can_absorb_filter = Self::_can_absorb_filter(&abc, py)?; @@ -231,6 +237,7 @@ pub mod pylib { let display_name = Self::_display_name(&abc, py)?; Ok(Self { + name, operator: abc, schema, partitioning_keys, @@ -243,6 +250,9 @@ pub mod pylib { } impl ScanOperator for PythonScanOperatorBridge { + fn name(&self) -> &str { + &self.name + } fn partitioning_keys(&self) -> &[PartitionField] { &self.partitioning_keys }