From 0e03303364f3be41bcedb3255636711299f47036 Mon Sep 17 00:00:00 2001
From: Cory Grinstead
Date: Tue, 14 Jan 2025 10:55:11 -0600
Subject: [PATCH] feat(connect): Rust ray exec (#3666)

## Description

You can now specify the runner you want to use via native Spark config:

```py
from daft.daft import connect_start
from pyspark.sql import SparkSession

server = connect_start()
url = f"sc://localhost:{server.port()}"

daft_spark = SparkSession.builder.appName("DaftConnectExample").remote(url).getOrCreate()
daft_spark.conf.set("daft.runner", "ray")
# or use native
# daft_spark.conf.set("daft.runner", "native")

df1 = daft_spark.read.parquet("~/datasets/tpcds/sf10/customer.parquet")
df1.limit(10).show()
```

## Note for reviewers

I had to do a bit of refactoring to get this to work, mostly in how the show string works. The actual Ray implementation is isolated in the new `daft-ray-execution` lib, and it is just a wrapper around our existing Python code. Putting it in its own lib creates a better abstraction, and if we later want to port more of that code to Rust, it will be a lot easier.

There are also a few small drive-by fixes that were bugging me while working on this:
- changed `warn!`s to `debug!`s, as they were cluttering the output on every command.
- renamed `PlanIds` to `ResponseBuilder`, which reflects what it actually does.
- the error output for unsupported relations was nasty, so I simplified it [here](https://github.com/Eventual-Inc/Daft/pull/3666/files#diff-0f6aee05ac5693372752b1eab7454e80142119479e41093d0b975bb777d83ffdR169) and [here](https://github.com/Eventual-Inc/Daft/pull/3666/files#diff-0f6aee05ac5693372752b1eab7454e80142119479e41093d0b975bb777d83ffdR169).
---
 Cargo.lock                                    |  14 +
 Cargo.toml                                    |   6 +-
 src/daft-connect/Cargo.toml                   |  43 ++-
 src/daft-connect/src/display.rs               |   1 -
 src/daft-connect/src/execute.rs               | 313 ++++++++++++++++++
 src/daft-connect/src/lib.rs                   |  74 +++--
 src/daft-connect/src/op.rs                    |   1 -
 src/daft-connect/src/op/execute/root.rs       |  77 -----
 src/daft-connect/src/op/execute/write.rs      | 144 --------
 .../{op/execute.rs => response_builder.rs}    |  53 +--
 src/daft-connect/src/session.rs               |   7 +-
 src/daft-connect/src/translation/datatype.rs  |   4 +-
 src/daft-connect/src/translation/expr.rs      |  10 +-
 .../src/translation/logical_plan.rs           | 162 +++++++--
 .../src/translation/logical_plan/range.rs     |  94 +++---
 .../src/translation/logical_plan/read.rs      |   4 +-
 .../logical_plan/read/data_source.rs          |   8 +-
 .../{translation.rs => translation/mod.rs}    |   1 -
 src/daft-connect/src/translation/schema.rs    |  72 ++--
 src/daft-logical-plan/Cargo.toml              |   1 +
 src/daft-micropartition/src/partitioning.rs   |   1 +
 src/daft-micropartition/src/python.rs         |   5 +-
 src/daft-ray-execution/Cargo.toml             |  22 ++
 src/daft-ray-execution/src/lib.rs             |  74 +++++
 src/daft-scheduler/src/scheduler.rs           |   1 -
 25 files changed, 769 insertions(+), 423 deletions(-)
 create mode 100644 src/daft-connect/src/execute.rs
 delete mode 100644 src/daft-connect/src/op.rs
 delete mode 100644 src/daft-connect/src/op/execute/root.rs
 delete mode 100644 src/daft-connect/src/op/execute/write.rs
 rename src/daft-connect/src/{op/execute.rs => response_builder.rs} (61%)
 rename src/daft-connect/src/{translation.rs => translation/mod.rs} (86%)
 create mode 100644 src/daft-ray-execution/Cargo.toml
 create mode 100644 src/daft-ray-execution/src/lib.rs

diff --git a/Cargo.lock b/Cargo.lock
index 1b4d338cb3..ca979cdaa9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1941,6 +1941,7 @@ dependencies = [
  "daft-minhash",
  "daft-parquet",
  "daft-physical-plan",
+ "daft-ray-execution",
"daft-scan", "daft-scheduler", "daft-sql", @@ -2006,12 +2007,14 @@ dependencies = [ "arrow2", "async-stream", "common-daft-config", + "common-error", "common-file-formats", "daft-core", "daft-dsl", "daft-local-execution", "daft-logical-plan", "daft-micropartition", + "daft-ray-execution", "daft-scan", "daft-schema", "daft-sql", @@ -2022,6 +2025,7 @@ dependencies = [ "itertools 0.11.0", "pyo3", "spark-connect", + "textwrap", "tokio", "tonic", "tracing", @@ -2472,6 +2476,16 @@ dependencies = [ "serde", ] +[[package]] +name = "daft-ray-execution" +version = "0.3.0-dev0" +dependencies = [ + "common-error", + "daft-logical-plan", + "daft-micropartition", + "pyo3", +] + [[package]] name = "daft-scan" version = "0.3.0-dev0" diff --git a/Cargo.toml b/Cargo.toml index eae95f1264..85be93e4c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ daft-micropartition = {path = "src/daft-micropartition", default-features = fals daft-minhash = {path = "src/daft-minhash", default-features = false} daft-parquet = {path = "src/daft-parquet", default-features = false} daft-physical-plan = {path = "src/daft-physical-plan", default-features = false} +daft-ray-execution = {path = "src/daft-ray-execution", default-features = false} daft-scan = {path = "src/daft-scan", default-features = false} daft-scheduler = {path = "src/daft-scheduler", default-features = false} daft-sql = {path = "src/daft-sql", default-features = false} @@ -56,7 +57,6 @@ python = [ "common-system-info/python", "daft-catalog-python-catalog/python", "daft-catalog/python", - "daft-connect/python", "daft-core/python", "daft-csv/python", "daft-dsl/python", @@ -172,7 +172,8 @@ members = [ "src/parquet2", # "src/spark-connect-script", "src/generated/spark-connect", - "src/common/partitioning" + "src/common/partitioning", + "src/daft-ray-execution" ] [workspace.dependencies] @@ -200,6 +201,7 @@ daft-hash = {path = "src/daft-hash"} daft-local-execution = {path = "src/daft-local-execution"} daft-logical-plan = {path = "src/daft-logical-plan"} daft-micropartition = {path = "src/daft-micropartition"} +daft-ray-execution = {path = "src/daft-ray-execution"} daft-scan = {path = "src/daft-scan"} daft-schema = {path = "src/daft-schema"} daft-sql = {path = "src/daft-sql"} diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index a72d574677..bb846b9c46 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -1,30 +1,49 @@ [dependencies] arrow2 = {workspace = true, features = ["io_json_integration"]} async-stream = "0.3.6" -common-daft-config = {workspace = true} -common-file-formats = {workspace = true} -daft-core = {workspace = true} -daft-dsl = {workspace = true} -daft-local-execution = {workspace = true} -daft-logical-plan = {workspace = true} -daft-micropartition = {workspace = true} -daft-scan = {workspace = true} -daft-schema = {workspace = true} -daft-sql = {workspace = true} -daft-table = {workspace = true} +common-daft-config = {workspace = true, optional = true, features = ["python"]} +common-error = {workspace = true, optional = true, features = ["python"]} +common-file-formats = {workspace = true, optional = true, features = ["python"]} +daft-core = {workspace = true, optional = true, features = ["python"]} +daft-dsl = {workspace = true, optional = true, features = ["python"]} +daft-local-execution = {workspace = true, optional = true, features = ["python"]} +daft-logical-plan = {workspace = true, optional = true, features = ["python"]} +daft-micropartition = {workspace = true, optional = true, 
features = ["python"]} +daft-ray-execution = {workspace = true, optional = true, features = ["python"]} +daft-scan = {workspace = true, optional = true, features = ["python"]} +daft-schema = {workspace = true, optional = true, features = ["python"]} +daft-sql = {workspace = true, optional = true, features = ["python"]} +daft-table = {workspace = true, optional = true, features = ["python"]} dashmap = "6.1.0" eyre = "0.6.12" futures = "0.3.31" itertools = {workspace = true} pyo3 = {workspace = true, optional = true} spark-connect = {workspace = true} +textwrap = "0.16.1" tokio = {version = "1.40.0", features = ["full"]} tonic = "0.12.3" tracing = {workspace = true} uuid = {version = "1.10.0", features = ["v4"]} [features] -python = ["dep:pyo3", "common-daft-config/python", "daft-local-execution/python", "daft-logical-plan/python", "daft-scan/python", "daft-table/python", "daft-dsl/python", "daft-schema/python", "daft-core/python", "daft-micropartition/python"] +default = ["python"] +python = [ + "dep:pyo3", + "dep:common-daft-config", + "dep:common-error", + "dep:common-file-formats", + "dep:daft-core", + "dep:daft-dsl", + "dep:daft-local-execution", + "dep:daft-logical-plan", + "dep:daft-micropartition", + "dep:daft-ray-execution", + "dep:daft-scan", + "dep:daft-schema", + "dep:daft-sql", + "dep:daft-table" +] [lints] workspace = true diff --git a/src/daft-connect/src/display.rs b/src/daft-connect/src/display.rs index 83fce57fb5..8f80402997 100644 --- a/src/daft-connect/src/display.rs +++ b/src/daft-connect/src/display.rs @@ -114,7 +114,6 @@ fn type_to_string(dtype: &DataType) -> String { DataType::FixedShapeTensor(_, _) => "daft.fixed_shape_tensor".to_string(), DataType::SparseTensor(_) => "daft.sparse_tensor".to_string(), DataType::FixedShapeSparseTensor(_, _) => "daft.fixed_shape_sparse_tensor".to_string(), - #[cfg(feature = "python")] DataType::Python => "daft.python".to_string(), DataType::Unknown => "unknown".to_string(), DataType::UInt8 => "arrow.uint8".to_string(), diff --git a/src/daft-connect/src/execute.rs b/src/daft-connect/src/execute.rs new file mode 100644 index 0000000000..e1a0f200c8 --- /dev/null +++ b/src/daft-connect/src/execute.rs @@ -0,0 +1,313 @@ +use std::{future::ready, sync::Arc}; + +use common_daft_config::DaftExecutionConfig; +use common_error::{DaftError, DaftResult}; +use common_file_formats::FileFormat; +use daft_dsl::LiteralValue; +use daft_local_execution::NativeExecutor; +use daft_logical_plan::LogicalPlanBuilder; +use daft_micropartition::MicroPartition; +use daft_ray_execution::RayEngine; +use daft_table::Table; +use eyre::{bail, Context}; +use futures::{ + stream::{self, BoxStream}, + TryFutureExt, TryStreamExt, +}; +use itertools::Itertools; +use pyo3::Python; +use spark_connect::{ + relation::RelType, + write_operation::{SaveMode, SaveType}, + ExecutePlanResponse, Relation, ShowString, WriteOperation, +}; +use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status}; +use tracing::debug; + +use crate::{ + response_builder::ResponseBuilder, session::Session, translation, ExecuteStream, Runner, +}; + +impl Session { + pub fn get_runner(&self) -> eyre::Result { + let runner = match self.config_values().get("daft.runner") { + Some(runner) => match runner.as_str() { + "ray" => Runner::Ray, + "native" => Runner::Native, + _ => bail!("Invalid runner: {}", runner), + }, + None => Runner::Native, + }; + Ok(runner) + } + + pub async fn run_query( + &self, + lp: LogicalPlanBuilder, + ) -> eyre::Result>>> { + match self.get_runner()? 
{ + Runner::Ray => { + let runner_address = self.config_values().get("daft.runner.ray.address"); + let runner_address = runner_address.map(|s| s.to_string()); + + let runner = RayEngine::try_new(runner_address, None, None)?; + let result_set = tokio::task::spawn_blocking(move || { + Python::with_gil(|py| runner.run_iter_impl(py, lp, None)) + }) + .await??; + + Ok(Box::pin(stream::iter(result_set))) + } + + Runner::Native => { + let this = self.clone(); + let result_stream = tokio::task::spawn_blocking(move || { + let plan = lp.optimize()?; + let cfg = Arc::new(DaftExecutionConfig::default()); + let native_executor = NativeExecutor::from_logical_plan_builder(&plan)?; + let results = native_executor.run(&*this.psets, cfg, None)?; + let it = results.into_iter(); + Ok::<_, DaftError>(it.collect_vec()) + }) + .await??; + + Ok(Box::pin(stream::iter(result_stream))) + } + } + } + + pub async fn execute_command( + &self, + command: Relation, + operation_id: String, + ) -> Result { + use futures::{StreamExt, TryStreamExt}; + + let response_builder = ResponseBuilder::new_with_op_id( + self.client_side_session_id(), + self.server_side_session_id(), + operation_id, + ); + + // fallback response + let result_complete = response_builder.result_complete_response(); + + let (tx, rx) = tokio::sync::mpsc::channel::>(1); + + let this = self.clone(); + + tokio::spawn(async move { + let execution_fut = async { + let translator = translation::SparkAnalyzer::new(&this); + match command.rel_type { + Some(RelType::ShowString(ss)) => { + let response = this.show_string(*ss, response_builder.clone()).await?; + if tx.send(Ok(response)).await.is_err() { + return Ok(()); + } + + Ok(()) + } + _ => { + let lp = translator.to_logical_plan(command).await?; + + let mut result_stream = this.run_query(lp).await?; + + while let Some(result) = result_stream.next().await { + let result = result?; + let tables = result.get_tables()?; + for table in tables.as_slice() { + let response = response_builder.arrow_batch_response(table)?; + if tx.send(Ok(response)).await.is_err() { + return Ok(()); + } + } + } + Ok(()) + } + } + }; + if let Err(e) = execution_fut.await { + let _ = tx.send(Err(e)).await; + } + }); + + let stream = ReceiverStream::new(rx); + + let stream = stream + .map_err(|e| { + Status::internal( + textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"), + ) + }) + .chain(stream::once(ready(Ok(result_complete)))); + + Ok(Box::pin(stream)) + } + + pub async fn execute_write_operation( + &self, + operation: WriteOperation, + operation_id: String, + ) -> Result { + use futures::StreamExt; + + let response_builder = ResponseBuilder::new_with_op_id( + self.client_side_session_id(), + self.server_side_session_id(), + operation_id, + ); + + let finished = response_builder.result_complete_response(); + + let this = self.clone(); + + let result = async move { + let WriteOperation { + input, + source, + mode, + sort_column_names, + partitioning_columns, + bucket_by, + options, + clustering_columns, + save_type, + } = operation; + + let Some(input) = input else { + bail!("Input is required"); + }; + + let Some(source) = source else { + bail!("Source is required"); + }; + + let file_format: FileFormat = source.parse()?; + + let Ok(mode) = SaveMode::try_from(mode) else { + bail!("Invalid save mode: {mode}"); + }; + + if !sort_column_names.is_empty() { + // todo(completeness): implement sort + debug!("Ignoring sort_column_names: {sort_column_names:?} (not yet implemented)"); + } + + if !partitioning_columns.is_empty() { 
+ // todo(completeness): implement partitioning + debug!( + "Ignoring partitioning_columns: {partitioning_columns:?} (not yet implemented)" + ); + } + + if let Some(bucket_by) = bucket_by { + // todo(completeness): implement bucketing + debug!("Ignoring bucket_by: {bucket_by:?} (not yet implemented)"); + } + + if !options.is_empty() { + // todo(completeness): implement options + debug!("Ignoring options: {options:?} (not yet implemented)"); + } + + if !clustering_columns.is_empty() { + // todo(completeness): implement clustering + debug!("Ignoring clustering_columns: {clustering_columns:?} (not yet implemented)"); + } + + match mode { + SaveMode::Unspecified => {} + SaveMode::Append => {} + SaveMode::Overwrite => {} + SaveMode::ErrorIfExists => {} + SaveMode::Ignore => {} + } + + let Some(save_type) = save_type else { + bail!("Save type is required"); + }; + + let path = match save_type { + SaveType::Path(path) => path, + SaveType::Table(table) => { + let name = table.table_name; + bail!("Tried to write to table {name} but it is not yet implemented. Try to write to a path instead."); + } + }; + + let translator = translation::SparkAnalyzer::new(&this); + + let plan = translator.to_logical_plan(input).await?; + + let plan = plan + .table_write(&path, file_format, None, None, None) + .wrap_err("Failed to create table write plan")?; + + let mut result_stream = this.run_query(plan).await?; + + // this is so we make sure the operation is actually done + // before we return + // + // an example where this is important is if we write to a parquet file + // and then read immediately after, we need to wait for the write to finish + while let Some(_result) = result_stream.next().await {} + + Ok(()) + }; + + let result = result.map_err(|e| { + Status::internal(textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n")) + }); + + let future = result.and_then(|()| ready(Ok(finished))); + let stream = futures::stream::once(future); + + Ok(Box::pin(stream)) + } + + async fn show_string( + &self, + show_string: ShowString, + response_builder: ResponseBuilder, + ) -> eyre::Result { + let translator = translation::SparkAnalyzer::new(self); + + let ShowString { + input, + num_rows, + truncate: _, + vertical, + } = show_string; + + if vertical { + bail!("Vertical show string is not supported"); + } + + let Some(input) = input else { + bail!("input must be set"); + }; + + let plan = Box::pin(translator.to_logical_plan(*input)).await?; + let plan = plan.limit(num_rows as i64, true)?; + + let results = translator.session.run_query(plan).await?; + let results = results.try_collect::>().await?; + let single_batch = results + .into_iter() + .next() + .ok_or_else(|| eyre::eyre!("No results"))?; + + let tbls = single_batch.get_tables()?; + let tbl = Table::concat(&tbls)?; + let output = tbl.to_comfy_table(None).to_string(); + + let s = LiteralValue::Utf8(output) + .into_single_value_series()? 
+ .rename("show_string"); + + let tbl = Table::from_nonempty_columns(vec![s])?; + let response = response_builder.arrow_batch_response(&tbl)?; + Ok(response) + } +} diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index bd55024825..7a99c5a867 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -6,11 +6,29 @@ #![feature(stmt_expr_attributes)] #![feature(try_trait_v2_residual)] -use daft_micropartition::partitioning::InMemoryPartitionSetCache; +#[cfg(feature = "python")] +mod config; +#[cfg(feature = "python")] +mod display; +#[cfg(feature = "python")] +mod err; +#[cfg(feature = "python")] +mod execute; +#[cfg(feature = "python")] +mod response_builder; +#[cfg(feature = "python")] +mod session; +#[cfg(feature = "python")] +mod translation; +#[cfg(feature = "python")] +pub mod util; +#[cfg(feature = "python")] use dashmap::DashMap; +#[cfg(feature = "python")] use eyre::Context; #[cfg(feature = "python")] use pyo3::types::PyModuleMethods; +#[cfg(feature = "python")] use spark_connect::{ analyze_plan_response, command::CommandType, @@ -22,20 +40,18 @@ use spark_connect::{ InterruptRequest, InterruptResponse, Plan, ReattachExecuteRequest, ReleaseExecuteRequest, ReleaseExecuteResponse, ReleaseSessionRequest, ReleaseSessionResponse, }; +#[cfg(feature = "python")] use tonic::{transport::Server, Request, Response, Status}; -use tracing::{info, warn}; +#[cfg(feature = "python")] +use tracing::{debug, info}; +#[cfg(feature = "python")] use uuid::Uuid; +#[cfg(feature = "python")] use crate::{display::SparkDisplay, session::Session, translation::SparkAnalyzer}; -mod config; -mod display; -mod err; -mod op; - -mod session; -mod translation; -pub mod util; +#[cfg(feature = "python")] +pub type ExecuteStream = ::ExecutePlanStream; #[cfg_attr(feature = "python", pyo3::pyclass)] pub struct ConnectionHandle { @@ -57,6 +73,7 @@ impl ConnectionHandle { } } +#[cfg(feature = "python")] pub fn start(addr: &str) -> eyre::Result { info!("Daft-Connect server listening on {addr}"); let addr = util::parse_spark_connect_address(addr)?; @@ -117,11 +134,13 @@ pub fn start(addr: &str) -> eyre::Result { Ok(handle) } +#[cfg(feature = "python")] #[derive(Default)] pub struct DaftSparkConnectService { client_to_session: DashMap, // To track session data } +#[cfg(feature = "python")] impl DaftSparkConnectService { fn get_session( &self, @@ -142,6 +161,7 @@ impl DaftSparkConnectService { } } +#[cfg(feature = "python")] #[tonic::async_trait] impl SparkConnectService for DaftSparkConnectService { type ExecutePlanStream = std::pin::Pin< @@ -177,7 +197,7 @@ impl SparkConnectService for DaftSparkConnectService { match plan { OpType::Root(relation) => { - let result = session.handle_root_command(relation, operation).await?; + let result = session.execute_command(relation, operation).await?; return Ok(Response::new(result)); } OpType::Command(command) => { @@ -190,7 +210,7 @@ impl SparkConnectService for DaftSparkConnectService { unimplemented_err!("RegisterFunction not implemented") } CommandType::WriteOperation(op) => { - let result = session.handle_write_command(op, operation).await?; + let result = session.execute_write_operation(op, operation).await?; return Ok(Response::new(result)); } CommandType::CreateDataframeView(_) => { @@ -235,7 +255,9 @@ impl SparkConnectService for DaftSparkConnectService { CommandType::MergeIntoTableCommand(_) => { unimplemented_err!("MergeIntoTableCommand not implemented") } - CommandType::Extension(_) => unimplemented_err!("Extension not implemented"), + 
CommandType::Extension(_) => { + unimplemented_err!("Extension not implemented") + } } } }? @@ -305,7 +327,10 @@ impl SparkConnectService for DaftSparkConnectService { return Err(Status::invalid_argument("op_type is required to be root")); }; - let result = match translation::relation_to_spark_schema(relation).await { + let session = self.get_session(&session_id)?; + let translator = SparkAnalyzer::new(&session); + + let result = match translator.relation_to_spark_schema(relation).await { Ok(schema) => schema, Err(e) => { return invalid_argument_err!( @@ -354,7 +379,7 @@ impl SparkConnectService for DaftSparkConnectService { }; if let Some(level) = level { - warn!("ignoring tree string level: {level:?}"); + debug!("ignoring tree string level: {level:?}"); }; let Some(op_type) = plan.op_type else { @@ -367,13 +392,12 @@ impl SparkConnectService for DaftSparkConnectService { if let Some(common) = &input.common { if common.origin.is_some() { - warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); + debug!("Ignoring common metadata for relation: {common:?}; not yet implemented"); } } + let session = self.get_session(&session_id)?; - // We're just checking the schema here, so we don't need to use a persistent cache as it won't be used - let pset = InMemoryPartitionSetCache::empty(); - let translator = SparkAnalyzer::new(&pset); + let translator = SparkAnalyzer::new(&session); let plan = Box::pin(translator.to_logical_plan(input)) .await .unwrap() @@ -392,7 +416,9 @@ impl SparkConnectService for DaftSparkConnectService { Ok(Response::new(response)) } - other => unimplemented_err!("Analyze plan operation is not yet implemented: {other:?}"), + other => { + unimplemented_err!("Analyze plan operation is not yet implemented: {other:?}") + } } } @@ -456,7 +482,13 @@ impl SparkConnectService for DaftSparkConnectService { } #[cfg(feature = "python")] -#[pyo3::pyfunction] +pub enum Runner { + Ray, + Native, +} + +#[cfg(feature = "python")] +#[cfg_attr(feature = "python", pyo3::pyfunction)] #[pyo3(name = "connect_start", signature = (addr = "sc://0.0.0.0:0"))] pub fn py_connect_start(addr: &str) -> pyo3::PyResult { start(addr).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e:?}"))) diff --git a/src/daft-connect/src/op.rs b/src/daft-connect/src/op.rs deleted file mode 100644 index 2e8bdddf98..0000000000 --- a/src/daft-connect/src/op.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod execute; diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs deleted file mode 100644 index 4a9feb4ed7..0000000000 --- a/src/daft-connect/src/op/execute/root.rs +++ /dev/null @@ -1,77 +0,0 @@ -use std::{future::ready, sync::Arc}; - -use common_daft_config::DaftExecutionConfig; -use daft_local_execution::NativeExecutor; -use futures::stream; -use spark_connect::{ExecutePlanResponse, Relation}; -use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status}; - -use crate::{ - op::execute::{ExecuteStream, PlanIds}, - session::Session, - translation, -}; - -impl Session { - pub async fn handle_root_command( - &self, - command: Relation, - operation_id: String, - ) -> Result { - use futures::{StreamExt, TryStreamExt}; - - let context = PlanIds { - session: self.client_side_session_id().to_string(), - server_side_session: self.server_side_session_id().to_string(), - operation: operation_id, - }; - - let finished = context.finished(); - - let (tx, rx) = tokio::sync::mpsc::channel::>(1); - - let pset = self.psets.clone(); - - tokio::spawn(async move { - 
let execution_fut = async { - let translator = translation::SparkAnalyzer::new(&pset); - let lp = translator.to_logical_plan(command).await?; - - // todo: convert optimize to async (looks like A LOT of work)... it touches a lot of API - // I tried and spent about an hour and gave up ~ Andrew Gazelka 🪦 2024-12-09 - let optimized_plan = tokio::task::spawn_blocking(move || lp.optimize()) - .await - .unwrap()?; - - let cfg = Arc::new(DaftExecutionConfig::default()); - let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?; - - let mut result_stream = native_executor.run(&pset, cfg, None)?.into_stream(); - - while let Some(result) = result_stream.next().await { - let result = result?; - let tables = result.get_tables()?; - for table in tables.as_slice() { - let response = context.gen_response(table)?; - if tx.send(Ok(response)).await.is_err() { - return Ok(()); - } - } - } - Ok(()) - }; - - if let Err(e) = execution_fut.await { - let _ = tx.send(Err(e)).await; - } - }); - - let stream = ReceiverStream::new(rx); - - let stream = stream - .map_err(|e| Status::internal(format!("Error in Daft server: {e:?}"))) - .chain(stream::once(ready(Ok(finished)))); - - Ok(Box::pin(stream)) - } -} diff --git a/src/daft-connect/src/op/execute/write.rs b/src/daft-connect/src/op/execute/write.rs deleted file mode 100644 index da06f04887..0000000000 --- a/src/daft-connect/src/op/execute/write.rs +++ /dev/null @@ -1,144 +0,0 @@ -use std::future::ready; - -use common_daft_config::DaftExecutionConfig; -use common_file_formats::FileFormat; -use daft_local_execution::NativeExecutor; -use eyre::{bail, WrapErr}; -use spark_connect::{ - write_operation::{SaveMode, SaveType}, - WriteOperation, -}; -use tonic::Status; -use tracing::warn; - -use crate::{ - op::execute::{ExecuteStream, PlanIds}, - session::Session, - translation, -}; - -impl Session { - pub async fn handle_write_command( - &self, - operation: WriteOperation, - operation_id: String, - ) -> Result { - use futures::StreamExt; - - let context = PlanIds { - session: self.client_side_session_id().to_string(), - server_side_session: self.server_side_session_id().to_string(), - operation: operation_id, - }; - - let finished = context.finished(); - let pset = self.psets.clone(); - - let result = async move { - let WriteOperation { - input, - source, - mode, - sort_column_names, - partitioning_columns, - bucket_by, - options, - clustering_columns, - save_type, - } = operation; - - let Some(input) = input else { - bail!("Input is required"); - }; - - let Some(source) = source else { - bail!("Source is required"); - }; - - let file_format: FileFormat = source.parse()?; - - let Ok(mode) = SaveMode::try_from(mode) else { - bail!("Invalid save mode: {mode}"); - }; - - if !sort_column_names.is_empty() { - // todo(completeness): implement sort - warn!("Ignoring sort_column_names: {sort_column_names:?} (not yet implemented)"); - } - - if !partitioning_columns.is_empty() { - // todo(completeness): implement partitioning - warn!( - "Ignoring partitioning_columns: {partitioning_columns:?} (not yet implemented)" - ); - } - - if let Some(bucket_by) = bucket_by { - // todo(completeness): implement bucketing - warn!("Ignoring bucket_by: {bucket_by:?} (not yet implemented)"); - } - - if !options.is_empty() { - // todo(completeness): implement options - warn!("Ignoring options: {options:?} (not yet implemented)"); - } - - if !clustering_columns.is_empty() { - // todo(completeness): implement clustering - warn!("Ignoring clustering_columns: 
{clustering_columns:?} (not yet implemented)"); - } - - match mode { - SaveMode::Unspecified => {} - SaveMode::Append => {} - SaveMode::Overwrite => {} - SaveMode::ErrorIfExists => {} - SaveMode::Ignore => {} - } - - let Some(save_type) = save_type else { - bail!("Save type is required"); - }; - - let path = match save_type { - SaveType::Path(path) => path, - SaveType::Table(table) => { - let name = table.table_name; - bail!("Tried to write to table {name} but it is not yet implemented. Try to write to a path instead."); - } - }; - - let translator = translation::SparkAnalyzer::new(&pset); - - let plan = translator.to_logical_plan(input).await?; - - let plan = plan - .table_write(&path, file_format, None, None, None) - .wrap_err("Failed to create table write plan")?; - - let optimized_plan = plan.optimize()?; - let cfg = DaftExecutionConfig::default(); - let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?; - - let mut result_stream = native_executor.run(&pset, cfg.into(), None)?.into_stream(); - - // this is so we make sure the operation is actually done - // before we return - // - // an example where this is important is if we write to a parquet file - // and then read immediately after, we need to wait for the write to finish - while let Some(_result) = result_stream.next().await {} - - Ok(()) - }; - - use futures::TryFutureExt; - - let result = result.map_err(|e| Status::internal(format!("Error in Daft server: {e:?}"))); - - let future = result.and_then(|()| ready(Ok(finished))); - let stream = futures::stream::once(future); - - Ok(Box::pin(stream)) - } -} diff --git a/src/daft-connect/src/op/execute.rs b/src/daft-connect/src/response_builder.rs similarity index 61% rename from src/daft-connect/src/op/execute.rs rename to src/daft-connect/src/response_builder.rs index 41baf88b09..f0092577f8 100644 --- a/src/daft-connect/src/op/execute.rs +++ b/src/daft-connect/src/response_builder.rs @@ -3,43 +3,55 @@ use daft_table::Table; use eyre::Context; use spark_connect::{ execute_plan_response::{ArrowBatch, ResponseType, ResultComplete}, - spark_connect_service_server::SparkConnectService, ExecutePlanResponse, }; use uuid::Uuid; -use crate::{DaftSparkConnectService, Session}; +use crate::Session; -mod root; -mod write; - -pub type ExecuteStream = ::ExecutePlanStream; - -pub struct PlanIds { - session: String, - server_side_session: String, - operation: String, +/// spark responses are stateful, so we need to keep track of the session id, operation id, and server side session id +#[derive(Clone)] +pub struct ResponseBuilder { + pub(crate) session: String, + pub(crate) operation_id: String, + pub(crate) server_side_session_id: String, } -impl PlanIds { +impl ResponseBuilder { + /// Create a new response builder pub fn new( client_side_session_id: impl Into, server_side_session_id: impl Into, + ) -> Self { + Self::new_with_op_id( + client_side_session_id, + server_side_session_id, + Uuid::new_v4().to_string(), + ) + } + + pub fn new_with_op_id( + client_side_session_id: impl Into, + server_side_session_id: impl Into, + operation_id: impl Into, ) -> Self { let client_side_session_id = client_side_session_id.into(); let server_side_session_id = server_side_session_id.into(); + let operation_id = operation_id.into(); + Self { session: client_side_session_id, - server_side_session: server_side_session_id, - operation: Uuid::new_v4().to_string(), + server_side_session_id, + operation_id, } } - pub fn finished(&self) -> ExecutePlanResponse { + /// Send a result complete 
response to the client + pub fn result_complete_response(&self) -> ExecutePlanResponse { ExecutePlanResponse { session_id: self.session.to_string(), - server_side_session_id: self.server_side_session.to_string(), - operation_id: self.operation.to_string(), + server_side_session_id: self.server_side_session_id.to_string(), + operation_id: self.operation_id.to_string(), response_id: Uuid::new_v4().to_string(), metrics: None, observed_metrics: vec![], @@ -48,7 +60,8 @@ impl PlanIds { } } - pub fn gen_response(&self, table: &Table) -> eyre::Result { + /// Send an arrow batch response to the client + pub fn arrow_batch_response(&self, table: &Table) -> eyre::Result { let mut data = Vec::new(); let mut writer = StreamWriter::new( @@ -76,8 +89,8 @@ impl PlanIds { let response = ExecutePlanResponse { session_id: self.session.to_string(), - server_side_session_id: self.server_side_session.to_string(), - operation_id: self.operation.to_string(), + server_side_session_id: self.server_side_session_id.to_string(), + operation_id: self.operation_id.to_string(), response_id: Uuid::new_v4().to_string(), // todo: implement this metrics: None, // todo: implement this observed_metrics: vec![], diff --git a/src/daft-connect/src/session.rs b/src/daft-connect/src/session.rs index 7de8d5851b..ae618e4c8b 100644 --- a/src/daft-connect/src/session.rs +++ b/src/daft-connect/src/session.rs @@ -1,8 +1,9 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; use daft_micropartition::partitioning::InMemoryPartitionSetCache; use uuid::Uuid; +#[derive(Clone)] pub struct Session { /// so order is preserved, and so we can efficiently do a prefix search /// @@ -13,7 +14,7 @@ pub struct Session { server_side_session_id: String, /// MicroPartitionSet associated with this session /// this will be filled up as the user runs queries - pub(crate) psets: InMemoryPartitionSetCache, + pub(crate) psets: Arc, } impl Session { @@ -32,7 +33,7 @@ impl Session { config_values: Default::default(), id, server_side_session_id, - psets: InMemoryPartitionSetCache::empty(), + psets: Arc::new(InMemoryPartitionSetCache::empty()), } } diff --git a/src/daft-connect/src/translation/datatype.rs b/src/daft-connect/src/translation/datatype.rs index d6e51250c7..a397a5c0d8 100644 --- a/src/daft-connect/src/translation/datatype.rs +++ b/src/daft-connect/src/translation/datatype.rs @@ -1,7 +1,7 @@ use daft_schema::{dtype::DataType, field::Field, time_unit::TimeUnit}; use eyre::{bail, ensure, WrapErr}; use spark_connect::data_type::Kind; -use tracing::warn; +use tracing::debug; pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType { match datatype { @@ -73,7 +73,7 @@ pub fn to_spark_datatype(datatype: &DataType) -> spark_connect::DataType { })), }, DataType::Timestamp(unit, _) => { - warn!("Ignoring time unit {unit:?} for timestamp type"); + debug!("Ignoring time unit {unit:?} for timestamp type"); spark_connect::DataType { kind: Some(Kind::Timestamp(spark_connect::data_type::Timestamp { type_variation_reference: 0, diff --git a/src/daft-connect/src/translation/expr.rs b/src/daft-connect/src/translation/expr.rs index 0354dc504c..ec3d9a8320 100644 --- a/src/daft-connect/src/translation/expr.rs +++ b/src/daft-connect/src/translation/expr.rs @@ -9,7 +9,7 @@ use spark_connect::{ }, Expression, }; -use tracing::warn; +use tracing::debug; use unresolved_function::unresolved_to_daft_expr; use crate::translation::{to_daft_datatype, to_daft_literal}; @@ -19,7 +19,7 @@ mod unresolved_function; pub fn 
to_daft_expr(expression: &Expression) -> eyre::Result { if let Some(common) = &expression.common { if common.origin.is_some() { - warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); + debug!("Ignoring common metadata for relation: {common:?}; not yet implemented"); } }; @@ -37,11 +37,11 @@ pub fn to_daft_expr(expression: &Expression) -> eyre::Result } = attr; if let Some(plan_id) = plan_id { - warn!("Ignoring plan_id {plan_id} for attribute expressions; not yet implemented"); + debug!("Ignoring plan_id {plan_id} for attribute expressions; not yet implemented"); } if let Some(is_metadata_column) = is_metadata_column { - warn!("Ignoring is_metadata_column {is_metadata_column} for attribute expressions; not yet implemented"); + debug!("Ignoring is_metadata_column {is_metadata_column} for attribute expressions; not yet implemented"); } Ok(daft_dsl::col(unparsed_identifier.as_str())) @@ -109,7 +109,7 @@ pub fn to_daft_expr(expression: &Expression) -> eyre::Result let eval_mode = EvalMode::try_from(*eval_mode) .wrap_err_with(|| format!("Invalid cast eval mode: {eval_mode}"))?; - warn!("Ignoring cast eval mode: {eval_mode:?}"); + debug!("Ignoring cast eval mode: {eval_mode:?}"); Ok(expr.cast(&data_type)) } diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 5bf831756e..eeb380e6df 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -1,22 +1,22 @@ use std::sync::Arc; -use common_daft_config::DaftExecutionConfig; use daft_core::prelude::Schema; use daft_dsl::LiteralValue; -use daft_local_execution::NativeExecutor; -use daft_logical_plan::LogicalPlanBuilder; +use daft_logical_plan::{LogicalPlanBuilder, PyLogicalPlanBuilder}; use daft_micropartition::{ partitioning::{ - InMemoryPartitionSetCache, MicroPartitionSet, PartitionCacheEntry, PartitionMetadata, - PartitionSet, PartitionSetCache, + MicroPartitionSet, PartitionCacheEntry, PartitionMetadata, PartitionSet, PartitionSetCache, }, + python::PyMicroPartition, MicroPartition, }; use daft_table::Table; use eyre::{bail, Context}; use futures::TryStreamExt; use spark_connect::{relation::RelType, Limit, Relation, ShowString}; -use tracing::warn; +use tracing::debug; + +use crate::{session::Session, Runner}; mod aggregate; mod drop; @@ -29,42 +29,74 @@ mod to_df; mod with_columns; mod with_columns_renamed; +use pyo3::{intern, prelude::*}; + +#[derive(Clone)] pub struct SparkAnalyzer<'a> { - pub psets: &'a InMemoryPartitionSetCache, + pub session: &'a Session, } impl SparkAnalyzer<'_> { - pub fn new(pset: &InMemoryPartitionSetCache) -> SparkAnalyzer { - SparkAnalyzer { psets: pset } + pub fn new(session: &Session) -> SparkAnalyzer<'_> { + SparkAnalyzer { session } } + pub fn create_in_memory_scan( &self, plan_id: usize, schema: Arc, tables: Vec, ) -> eyre::Result { - let partition_key = uuid::Uuid::new_v4().to_string(); + let runner = self.session.get_runner()?; - let pset = Arc::new(MicroPartitionSet::from_tables(plan_id, tables)?); + match runner { + Runner::Ray => { + let mp = + MicroPartition::new_loaded(tables[0].schema.clone(), Arc::new(tables), None); + Python::with_gil(|py| { + // Convert MicroPartition to a logical plan using Python interop. + let py_micropartition = py + .import(intern!(py, "daft.table"))? + .getattr(intern!(py, "MicroPartition"))? + .getattr(intern!(py, "_from_pymicropartition"))? 
+ .call1((PyMicroPartition::from(mp),))?; - let PartitionMetadata { - num_rows, - size_bytes, - } = pset.metadata(); - let num_partitions = pset.num_partitions(); + // ERROR: 2: AttributeError: 'daft.daft.PySchema' object has no attribute '_schema' + let py_plan_builder = py + .import(intern!(py, "daft.dataframe.dataframe"))? + .getattr(intern!(py, "to_logical_plan_builder"))? + .call1((py_micropartition,))?; + let py_plan_builder = py_plan_builder.getattr(intern!(py, "_builder"))?; + let plan: PyLogicalPlanBuilder = py_plan_builder.extract()?; - self.psets.put_partition_set(&partition_key, &pset); + Ok::<_, eyre::Error>(dbg!(plan.builder)) + }) + } + Runner::Native => { + let partition_key = uuid::Uuid::new_v4().to_string(); - let cache_entry = PartitionCacheEntry::new_rust(partition_key.clone(), pset); + let pset = Arc::new(MicroPartitionSet::from_tables(plan_id, tables)?); - Ok(LogicalPlanBuilder::in_memory_scan( - &partition_key, - cache_entry, - schema, - num_partitions, - size_bytes, - num_rows, - )?) + let PartitionMetadata { + num_rows, + size_bytes, + } = pset.metadata(); + let num_partitions = pset.num_partitions(); + + self.session.psets.put_partition_set(&partition_key, &pset); + + let cache_entry = PartitionCacheEntry::new_rust(partition_key.clone(), pset); + + Ok(LogicalPlanBuilder::in_memory_scan( + &partition_key, + cache_entry, + schema, + num_partitions, + size_bytes, + num_rows, + )?) + } + } } pub async fn to_logical_plan(&self, relation: Relation) -> eyre::Result { @@ -73,7 +105,7 @@ impl SparkAnalyzer<'_> { }; if common.origin.is_some() { - warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); + debug!("Ignoring common metadata for relation: {common:?}; not yet implemented"); } let Some(rel_type) = relation.rel_type else { @@ -134,7 +166,7 @@ impl SparkAnalyzer<'_> { .await .wrap_err("Failed to show string") } - plan => bail!("Unsupported relation type: {plan:?}"), + plan => bail!("Unsupported relation type: \"{}\"", rel_name(&plan)), } } @@ -176,15 +208,13 @@ impl SparkAnalyzer<'_> { let plan = Box::pin(self.to_logical_plan(*input)).await?; let plan = plan.limit(num_rows as i64, true)?; - let optimized_plan = tokio::task::spawn_blocking(move || plan.optimize()) - .await - .unwrap()?; + let results = self.session.run_query(plan).await?; + let results = results.try_collect::>().await?; + let single_batch = results + .into_iter() + .next() + .ok_or_else(|| eyre::eyre!("No results"))?; - let cfg = Arc::new(DaftExecutionConfig::default()); - let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?; - let result_stream = native_executor.run(self.psets, cfg, None)?.into_stream(); - let batch = result_stream.try_collect::>().await?; - let single_batch = MicroPartition::concat(batch)?; let tbls = single_batch.get_tables()?; let tbl = Table::concat(&tbls)?; let output = tbl.to_comfy_table(None).to_string(); @@ -199,3 +229,63 @@ impl SparkAnalyzer<'_> { self.create_in_memory_scan(plan_id as _, schema, vec![tbl]) } } + +fn rel_name(rel: &RelType) -> &str { + match rel { + RelType::Read(_) => "Read", + RelType::Project(_) => "Project", + RelType::Filter(_) => "Filter", + RelType::Join(_) => "Join", + RelType::SetOp(_) => "SetOp", + RelType::Sort(_) => "Sort", + RelType::Limit(_) => "Limit", + RelType::Aggregate(_) => "Aggregate", + RelType::Sql(_) => "Sql", + RelType::LocalRelation(_) => "LocalRelation", + RelType::Sample(_) => "Sample", + RelType::Offset(_) => "Offset", + RelType::Deduplicate(_) => "Deduplicate", + 
RelType::Range(_) => "Range", + RelType::SubqueryAlias(_) => "SubqueryAlias", + RelType::Repartition(_) => "Repartition", + RelType::ToDf(_) => "ToDf", + RelType::WithColumnsRenamed(_) => "WithColumnsRenamed", + RelType::ShowString(_) => "ShowString", + RelType::Drop(_) => "Drop", + RelType::Tail(_) => "Tail", + RelType::WithColumns(_) => "WithColumns", + RelType::Hint(_) => "Hint", + RelType::Unpivot(_) => "Unpivot", + RelType::ToSchema(_) => "ToSchema", + RelType::RepartitionByExpression(_) => "RepartitionByExpression", + RelType::MapPartitions(_) => "MapPartitions", + RelType::CollectMetrics(_) => "CollectMetrics", + RelType::Parse(_) => "Parse", + RelType::GroupMap(_) => "GroupMap", + RelType::CoGroupMap(_) => "CoGroupMap", + RelType::WithWatermark(_) => "WithWatermark", + RelType::ApplyInPandasWithState(_) => "ApplyInPandasWithState", + RelType::HtmlString(_) => "HtmlString", + RelType::CachedLocalRelation(_) => "CachedLocalRelation", + RelType::CachedRemoteRelation(_) => "CachedRemoteRelation", + RelType::CommonInlineUserDefinedTableFunction(_) => "CommonInlineUserDefinedTableFunction", + RelType::AsOfJoin(_) => "AsOfJoin", + RelType::CommonInlineUserDefinedDataSource(_) => "CommonInlineUserDefinedDataSource", + RelType::WithRelations(_) => "WithRelations", + RelType::Transpose(_) => "Transpose", + RelType::FillNa(_) => "FillNa", + RelType::DropNa(_) => "DropNa", + RelType::Replace(_) => "Replace", + RelType::Summary(_) => "Summary", + RelType::Crosstab(_) => "Crosstab", + RelType::Describe(_) => "Describe", + RelType::Cov(_) => "Cov", + RelType::Corr(_) => "Corr", + RelType::ApproxQuantile(_) => "ApproxQuantile", + RelType::FreqItems(_) => "FreqItems", + RelType::SampleBy(_) => "SampleBy", + RelType::Catalog(_) => "Catalog", + RelType::Extension(_) => "Extension", + RelType::Unknown(_) => "Unknown", + } +} diff --git a/src/daft-connect/src/translation/logical_plan/range.rs b/src/daft-connect/src/translation/logical_plan/range.rs index c1ec7197ad..024f518060 100644 --- a/src/daft-connect/src/translation/logical_plan/range.rs +++ b/src/daft-connect/src/translation/logical_plan/range.rs @@ -6,57 +6,47 @@ use super::SparkAnalyzer; impl SparkAnalyzer<'_> { pub fn range(&self, range: Range) -> eyre::Result { - #[cfg(not(feature = "python"))] - { - use eyre::bail; - bail!("Range operations require Python feature to be enabled"); - } - - #[cfg(feature = "python")] - { - use daft_scan::python::pylib::ScanOperatorHandle; - use pyo3::prelude::*; - let Range { - start, - end, - step, - num_partitions, - } = range; - - let partitions = num_partitions.unwrap_or(1); - - ensure!(partitions > 0, "num_partitions must be greater than 0"); - - let start = start.unwrap_or(0); - - let step = usize::try_from(step).wrap_err("step must be a positive integer")?; - ensure!(step > 0, "step must be greater than 0"); - - let plan = Python::with_gil(|py| { - let range_module = PyModule::import(py, "daft.io._range") - .wrap_err("Failed to import range module")?; - - let range = range_module - .getattr(pyo3::intern!(py, "RangeScanOperator")) - .wrap_err("Failed to get range function")?; - - let range = range - .call1((start, end, step, partitions)) - .wrap_err("Failed to create range scan operator")? 
- .into_pyobject(py) - .unwrap() - .unbind(); - - let scan_operator_handle = - ScanOperatorHandle::from_python_scan_operator(range, py)?; - - let plan = LogicalPlanBuilder::table_scan(scan_operator_handle.into(), None)?; - - eyre::Result::<_>::Ok(plan) - }) - .wrap_err("Failed to create range scan")?; - - Ok(plan) - } + use daft_scan::python::pylib::ScanOperatorHandle; + use pyo3::prelude::*; + let Range { + start, + end, + step, + num_partitions, + } = range; + + let partitions = num_partitions.unwrap_or(1); + + ensure!(partitions > 0, "num_partitions must be greater than 0"); + + let start = start.unwrap_or(0); + + let step = usize::try_from(step).wrap_err("step must be a positive integer")?; + ensure!(step > 0, "step must be greater than 0"); + + let plan = Python::with_gil(|py| { + let range_module = + PyModule::import(py, "daft.io._range").wrap_err("Failed to import range module")?; + + let range = range_module + .getattr(pyo3::intern!(py, "RangeScanOperator")) + .wrap_err("Failed to get range function")?; + + let range = range + .call1((start, end, step, partitions)) + .wrap_err("Failed to create range scan operator")? + .into_pyobject(py) + .unwrap() + .unbind(); + + let scan_operator_handle = ScanOperatorHandle::from_python_scan_operator(range, py)?; + + let plan = LogicalPlanBuilder::table_scan(scan_operator_handle.into(), None)?; + + eyre::Result::<_>::Ok(plan) + }) + .wrap_err("Failed to create range scan")?; + + Ok(plan) } } diff --git a/src/daft-connect/src/translation/logical_plan/read.rs b/src/daft-connect/src/translation/logical_plan/read.rs index 9a73783191..077a826911 100644 --- a/src/daft-connect/src/translation/logical_plan/read.rs +++ b/src/daft-connect/src/translation/logical_plan/read.rs @@ -1,7 +1,7 @@ use daft_logical_plan::LogicalPlanBuilder; use eyre::{bail, WrapErr}; use spark_connect::read::ReadType; -use tracing::warn; +use tracing::debug; mod data_source; @@ -11,7 +11,7 @@ pub async fn read(read: spark_connect::Read) -> eyre::Result read_type, } = read; - warn!("Ignoring is_streaming: {is_streaming}"); + debug!("Ignoring is_streaming: {is_streaming}"); let Some(read_type) = read_type else { bail!("Read type is required"); diff --git a/src/daft-connect/src/translation/logical_plan/read/data_source.rs b/src/daft-connect/src/translation/logical_plan/read/data_source.rs index 863b5e8f1d..4ae758ea33 100644 --- a/src/daft-connect/src/translation/logical_plan/read/data_source.rs +++ b/src/daft-connect/src/translation/logical_plan/read/data_source.rs @@ -1,7 +1,7 @@ use daft_logical_plan::LogicalPlanBuilder; use daft_scan::builder::{CsvScanBuilder, ParquetScanBuilder}; use eyre::{bail, ensure, WrapErr}; -use tracing::warn; +use tracing::debug; pub async fn data_source( data_source: spark_connect::read::DataSource, @@ -21,15 +21,15 @@ pub async fn data_source( ensure!(!paths.is_empty(), "Paths are required"); if let Some(schema) = schema { - warn!("Ignoring schema: {schema:?}; not yet implemented"); + debug!("Ignoring schema: {schema:?}; not yet implemented"); } if !options.is_empty() { - warn!("Ignoring options: {options:?}; not yet implemented"); + debug!("Ignoring options: {options:?}; not yet implemented"); } if !predicates.is_empty() { - warn!("Ignoring predicates: {predicates:?}; not yet implemented"); + debug!("Ignoring predicates: {predicates:?}; not yet implemented"); } let plan = match &*format { diff --git a/src/daft-connect/src/translation.rs b/src/daft-connect/src/translation/mod.rs similarity index 86% rename from src/daft-connect/src/translation.rs 
rename to src/daft-connect/src/translation/mod.rs index 73dc2f998d..cdedc68c63 100644 --- a/src/daft-connect/src/translation.rs +++ b/src/daft-connect/src/translation/mod.rs @@ -10,4 +10,3 @@ pub use datatype::{to_daft_datatype, to_spark_datatype}; pub use expr::to_daft_expr; pub use literal::to_daft_literal; pub use logical_plan::SparkAnalyzer; -pub use schema::relation_to_spark_schema; diff --git a/src/daft-connect/src/translation/schema.rs b/src/daft-connect/src/translation/schema.rs index 0cbd3cd7a1..f71bdb672a 100644 --- a/src/daft-connect/src/translation/schema.rs +++ b/src/daft-connect/src/translation/schema.rs @@ -1,54 +1,52 @@ -use daft_micropartition::partitioning::InMemoryPartitionSetCache; use daft_schema::schema::SchemaRef; use spark_connect::{ data_type::{Kind, Struct, StructField}, DataType, Relation, }; -use tracing::warn; +use tracing::debug; use super::SparkAnalyzer; use crate::translation::to_spark_datatype; -#[tracing::instrument(skip_all)] -pub async fn relation_to_spark_schema(input: Relation) -> eyre::Result { - let result = relation_to_daft_schema(input).await?; - - let fields: eyre::Result> = result - .fields - .iter() - .map(|(name, field)| { - let field_type = to_spark_datatype(&field.dtype); - Ok(StructField { - name: name.clone(), // todo(correctness): name vs field.name... will they always be the same? - data_type: Some(field_type), - nullable: true, // todo(correctness): is this correct? - metadata: None, // todo(completeness): might want to add metadata here +impl SparkAnalyzer<'_> { + #[tracing::instrument(skip_all)] + pub async fn relation_to_spark_schema(&self, input: Relation) -> eyre::Result { + let result = self.relation_to_daft_schema(input).await?; + + let fields: eyre::Result> = result + .fields + .iter() + .map(|(name, field)| { + let field_type = to_spark_datatype(&field.dtype); + Ok(StructField { + name: name.clone(), // todo(correctness): name vs field.name... will they always be the same? + data_type: Some(field_type), + nullable: true, // todo(correctness): is this correct? 
+ metadata: None, // todo(completeness): might want to add metadata here + }) }) + .collect(); + + Ok(DataType { + kind: Some(Kind::Struct(Struct { + fields: fields?, + type_variation_reference: 0, + })), }) - .collect(); - - Ok(DataType { - kind: Some(Kind::Struct(Struct { - fields: fields?, - type_variation_reference: 0, - })), - }) -} + } -#[tracing::instrument(skip_all)] -pub async fn relation_to_daft_schema(input: Relation) -> eyre::Result { - if let Some(common) = &input.common { - if common.origin.is_some() { - warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); + #[tracing::instrument(skip_all)] + pub async fn relation_to_daft_schema(&self, input: Relation) -> eyre::Result { + if let Some(common) = &input.common { + if common.origin.is_some() { + debug!("Ignoring common metadata for relation: {common:?}; not yet implemented"); + } } - } - // We're just checking the schema here, so we don't need to use a persistent cache as it won't be used - let pset = InMemoryPartitionSetCache::empty(); - let translator = SparkAnalyzer::new(&pset); - let plan = Box::pin(translator.to_logical_plan(input)).await?; + let plan = Box::pin(self.to_logical_plan(input)).await?; - let result = plan.schema(); + let result = plan.schema(); - Ok(result) + Ok(result) + } } diff --git a/src/daft-logical-plan/Cargo.toml b/src/daft-logical-plan/Cargo.toml index cf70c38998..6d5f6b2fb0 100644 --- a/src/daft-logical-plan/Cargo.toml +++ b/src/daft-logical-plan/Cargo.toml @@ -39,6 +39,7 @@ python = [ "common-io-config/python", "common-daft-config/python", "common-resource-request/python", + "common-partitioning/python", "common-scan-info/python", "daft-core/python", "daft-dsl/python", diff --git a/src/daft-micropartition/src/partitioning.rs b/src/daft-micropartition/src/partitioning.rs index 76667a8618..a2d8d19c00 100644 --- a/src/daft-micropartition/src/partitioning.rs +++ b/src/daft-micropartition/src/partitioning.rs @@ -25,6 +25,7 @@ impl Partition for MicroPartition { pub struct MicroPartitionSet { pub partitions: DashMap, } + impl From> for MicroPartitionSet { fn from(value: Vec) -> Self { let partitions = value diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 6ed01c7a7a..53a302dc3e 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -1,6 +1,7 @@ use std::sync::{Arc, Mutex}; use common_error::DaftResult; +use common_partitioning::Partition; use daft_core::{ join::JoinSide, prelude::*, @@ -25,9 +26,9 @@ use crate::{ }; #[pyclass(module = "daft.daft", frozen)] -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct PyMicroPartition { - inner: Arc, + pub inner: Arc, } #[pymethods] diff --git a/src/daft-ray-execution/Cargo.toml b/src/daft-ray-execution/Cargo.toml new file mode 100644 index 0000000000..98b689c865 --- /dev/null +++ b/src/daft-ray-execution/Cargo.toml @@ -0,0 +1,22 @@ +[dependencies] +common-error = {workspace = true} +daft-logical-plan = {workspace = true} +daft-micropartition = {workspace = true} +pyo3 = {workspace = true, optional = true} + +[features] +default = ["python"] +python = [ + "dep:pyo3", + "common-error/python", + "daft-logical-plan/python", + "daft-micropartition/python" +] + +[lints] +workspace = true + +[package] +name = "daft-ray-execution" +edition.workspace = true +version.workspace = true diff --git a/src/daft-ray-execution/src/lib.rs b/src/daft-ray-execution/src/lib.rs new file mode 100644 index 0000000000..2180a54e45 --- /dev/null +++ 
b/src/daft-ray-execution/src/lib.rs @@ -0,0 +1,74 @@ +//! Wrapper around the python RayRunner class +#[cfg(feature = "python")] +use common_error::{DaftError, DaftResult}; +#[cfg(feature = "python")] +use daft_logical_plan::{LogicalPlanBuilder, PyLogicalPlanBuilder}; +#[cfg(feature = "python")] +use daft_micropartition::{python::PyMicroPartition, MicroPartitionRef}; +#[cfg(feature = "python")] +use pyo3::{ + intern, + prelude::*, + types::{PyDict, PyIterator}, +}; + +#[cfg(feature = "python")] +pub struct RayEngine { + ray_runner: PyObject, +} + +#[cfg(feature = "python")] +impl RayEngine { + pub fn try_new( + address: Option, + max_task_backlog: Option, + force_client_mode: Option, + ) -> DaftResult { + Python::with_gil(|py| { + let ray_runner_module = py.import(intern!(py, "daft.runners.ray_runner"))?; + let ray_runner = ray_runner_module.getattr(intern!(py, "RayRunner"))?; + let kwargs = PyDict::new(py); + kwargs.set_item(intern!(py, "address"), address)?; + kwargs.set_item(intern!(py, "max_task_backlog"), max_task_backlog)?; + kwargs.set_item(intern!(py, "force_client_mode"), force_client_mode)?; + + let instance = ray_runner.call((), Some(&kwargs))?; + let instance = instance.unbind(); + + Ok(Self { + ray_runner: instance, + }) + }) + } + + pub fn run_iter_impl( + &self, + py: Python<'_>, + lp: LogicalPlanBuilder, + results_buffer_size: Option, + ) -> DaftResult>> { + let py_lp = PyLogicalPlanBuilder::new(lp); + let builder = py.import(intern!(py, "daft.logical.builder"))?; + let builder = builder.getattr(intern!(py, "LogicalPlanBuilder"))?; + let builder = builder.call((py_lp,), None)?; + let result = self.ray_runner.call_method( + py, + intern!(py, "run_iter_tables"), + (builder, results_buffer_size), + None, + )?; + + let result = result.bind(py); + let iter = PyIterator::from_object(result)?; + + let iter = iter.map(|item| { + let item = item?; + let partition = item.getattr(intern!(py, "_micropartition"))?; + let partition = partition.extract::()?; + let partition = partition.inner; + Ok::<_, DaftError>(partition) + }); + + Ok(iter.collect()) + } +} diff --git a/src/daft-scheduler/src/scheduler.rs b/src/daft-scheduler/src/scheduler.rs index 8146950ce4..6e57f9c5b4 100644 --- a/src/daft-scheduler/src/scheduler.rs +++ b/src/daft-scheduler/src/scheduler.rs @@ -268,7 +268,6 @@ fn physical_plan_to_partition_tasks( ) -> PyResult { use daft_dsl::Expr; use daft_physical_plan::ops::{CrossJoin, ShuffleExchange, ShuffleExchangeStrategy}; - match physical_plan { PhysicalPlan::InMemoryScan(InMemoryScan { in_memory_info: InMemoryInfo { cache_key, .. },