diff --git a/Cargo.lock b/Cargo.lock index 1dc39d7bf9..13144b7b7b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2008,9 +2008,9 @@ version = "0.3.0-dev0" dependencies = [ "arrow2", "async-stream", - "common-daft-config", "common-error", "common-file-formats", + "common-runtime", "daft-catalog", "daft-core", "daft-dsl", @@ -2378,6 +2378,7 @@ dependencies = [ "serde", "snafu", "test-log", + "tokio", "typed-builder 0.20.0", "uuid 1.11.0", ] diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index bb8a1b8cd5..86e3c3b2ca 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1695,10 +1695,7 @@ class LogicalPlanBuilder: def repr_mermaid(self, options: MermaidOptions) -> str: ... class NativeExecutor: - @staticmethod - def from_logical_plan_builder( - logical_plan_builder: LogicalPlanBuilder, - ) -> NativeExecutor: ... + def __init__(self) -> None: ... def run( self, psets: dict[str, list[PartitionT]], cfg: PyDaftExecutionConfig, results_buffer_size: int | None ) -> Iterator[PyMicroPartition]: ... diff --git a/daft/execution/native_executor.py b/daft/execution/native_executor.py index 1958c6b90f..333db5fc4c 100644 --- a/daft/execution/native_executor.py +++ b/daft/execution/native_executor.py @@ -5,10 +5,10 @@ from daft.daft import ( NativeExecutor as _NativeExecutor, ) -from daft.daft import PyDaftExecutionConfig from daft.table import MicroPartition if TYPE_CHECKING: + from daft.daft import PyDaftExecutionConfig from daft.logical.builder import LogicalPlanBuilder from daft.runners.partitioning import ( LocalMaterializedResult, @@ -18,16 +18,12 @@ class NativeExecutor: - def __init__(self, executor: _NativeExecutor): - self._executor = executor - - @classmethod - def from_logical_plan_builder(cls, builder: LogicalPlanBuilder) -> NativeExecutor: - executor = _NativeExecutor.from_logical_plan_builder(builder._builder) - return cls(executor) + def __init__(self): + self._executor = _NativeExecutor() def run( self, + builder: LogicalPlanBuilder, psets: dict[str, list[MaterializedResult[PartitionT]]], daft_execution_config: PyDaftExecutionConfig, results_buffer_size: int | None, @@ -39,5 +35,5 @@ def run( } return ( LocalMaterializedResult(MicroPartition._from_pymicropartition(part)) - for part in self._executor.run(psets_mp, daft_execution_config, results_buffer_size) + for part in self._executor.run(builder._builder, psets_mp, daft_execution_config, results_buffer_size) ) diff --git a/daft/runners/native_runner.py b/daft/runners/native_runner.py index c7e5ce8034..a03e14c93a 100644 --- a/daft/runners/native_runner.py +++ b/daft/runners/native_runner.py @@ -75,8 +75,9 @@ def run_iter( # Optimize the logical plan. builder = builder.optimize() - executor = NativeExecutor.from_logical_plan_builder(builder) + executor = NativeExecutor() results_gen = executor.run( + builder, {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()}, daft_execution_config, results_buffer_size, diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index 450bc4eb57..48be64921b 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -380,8 +380,9 @@ def run_iter( if daft_execution_config.enable_native_executor: logger.info("Using native executor") - executor = NativeExecutor.from_logical_plan_builder(builder) + executor = NativeExecutor() results_gen = executor.run( + builder, {k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()}, daft_execution_config, results_buffer_size, diff --git a/src/common/runtime/src/lib.rs b/src/common/runtime/src/lib.rs index 2c8fc6acdd..df222fcfe9 100644 --- a/src/common/runtime/src/lib.rs +++ b/src/common/runtime/src/lib.rs @@ -69,13 +69,16 @@ impl Future for RuntimeTask { } pub struct Runtime { - runtime: tokio::runtime::Runtime, + pub runtime: Arc, pool_type: PoolType, } impl Runtime { pub(crate) fn new(runtime: tokio::runtime::Runtime, pool_type: PoolType) -> RuntimeRef { - Arc::new(Self { runtime, pool_type }) + Arc::new(Self { + runtime: Arc::new(runtime), + pool_type, + }) } async fn execute_task(future: F, pool_type: PoolType) -> DaftResult diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index 710d5ee472..55c972b219 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -1,7 +1,6 @@ [dependencies] arrow2 = {workspace = true, features = ["io_json_integration"]} async-stream = "0.3.6" -common-daft-config = {workspace = true, optional = true, features = ["python"]} common-error = {workspace = true, optional = true, features = ["python"]} common-file-formats = {workspace = true, optional = true, features = ["python"]} daft-catalog = {path = "../daft-catalog", optional = true, features = ["python"]} @@ -27,12 +26,12 @@ tokio = {version = "1.40.0", features = ["full"]} tonic = "0.12.3" tracing = {workspace = true} uuid = {version = "1.10.0", features = ["v4"]} +common-runtime.workspace = true [features] default = ["python"] python = [ "dep:pyo3", - "dep:common-daft-config", "dep:common-error", "dep:common-file-formats", "dep:daft-core", diff --git a/src/daft-connect/src/execute.rs b/src/daft-connect/src/execute.rs index 23caca66b9..d6b443475b 100644 --- a/src/daft-connect/src/execute.rs +++ b/src/daft-connect/src/execute.rs @@ -1,10 +1,8 @@ use std::{future::ready, sync::Arc}; -use common_daft_config::DaftExecutionConfig; -use common_error::{DaftError, DaftResult}; +use common_error::DaftResult; use common_file_formats::FileFormat; use daft_dsl::LiteralValue; -use daft_local_execution::NativeExecutor; use daft_logical_plan::LogicalPlanBuilder; use daft_micropartition::MicroPartition; use daft_ray_execution::RayEngine; @@ -12,9 +10,8 @@ use daft_table::Table; use eyre::{bail, Context}; use futures::{ stream::{self, BoxStream}, - StreamExt, TryFutureExt, TryStreamExt, + StreamExt, TryStreamExt, }; -use itertools::Itertools; use pyo3::Python; use spark_connect::{ relation::RelType, @@ -63,17 +60,13 @@ impl Session { Runner::Native => { let this = self.clone(); - let result_stream = tokio::task::spawn_blocking(move || { - let plan = lp.optimize()?; - let cfg = Arc::new(DaftExecutionConfig::default()); - let native_executor = NativeExecutor::from_logical_plan_builder(&plan)?; - let results = native_executor.run(&*this.psets, cfg, None)?; - let it = results.into_iter(); - Ok::<_, DaftError>(it.collect_vec()) - }) - .await??; - Ok(Box::pin(stream::iter(result_stream))) + let plan = lp.optimize_async().await?; + + let results = this + .engine + .run(&plan, &*this.psets, Default::default(), None)?; + Ok(results.into_stream().boxed()) } } } @@ -85,14 +78,12 @@ impl Session { ) -> Result { use futures::{StreamExt, TryStreamExt}; - // fallback response let result_complete = res.result_complete_response(); let (tx, rx) = tokio::sync::mpsc::channel::>(1); let this = self.clone(); - - tokio::spawn(async move { + self.compute_runtime.runtime.spawn(async move { let execution_fut = async { let translator = SparkAnalyzer::new(&this); match command.rel_type { @@ -144,7 +135,7 @@ impl Session { pub async fn execute_write_operation( &self, operation: WriteOperation, - response_builder: ResponseBuilder, + res: ResponseBuilder, ) -> Result { fn check_write_operation(write_op: &WriteOperation) -> Result<(), Status> { if !write_op.sort_column_names.is_empty() { @@ -179,60 +170,70 @@ impl Session { } } - let finished = response_builder.result_complete_response(); + let finished = res.result_complete_response(); + + let (tx, rx) = tokio::sync::mpsc::channel::>(1); let this = self.clone(); - let result = async move { - check_write_operation(&operation)?; + self.compute_runtime.runtime.spawn(async move { + let result = async { + check_write_operation(&operation)?; - let WriteOperation { - input, - source, - save_type, - .. - } = operation; + let WriteOperation { + input, + source, + save_type, + .. + } = operation; - let input = input.required("input")?; - let source = source.required("source")?; + let input = input.required("input")?; + let source = source.required("source")?; - let file_format: FileFormat = source.parse()?; + let file_format: FileFormat = source.parse()?; - let Some(save_type) = save_type else { - bail!("Save type is required"); - }; + let Some(save_type) = save_type else { + bail!("Save type is required"); + }; - let path = match save_type { - SaveType::Path(path) => path, - SaveType::Table(_) => { - return not_yet_implemented!("write to table").map_err(|e| e.into()) - } - }; + let path = match save_type { + SaveType::Path(path) => path, + SaveType::Table(_) => { + return not_yet_implemented!("write to table").map_err(|e| e.into()) + } + }; - let translator = SparkAnalyzer::new(&this); + let translator = SparkAnalyzer::new(&this); - let plan = translator.to_logical_plan(input).await?; + let plan = translator.to_logical_plan(input).await?; - let plan = plan.table_write(&path, file_format, None, None, None)?; + let plan = plan.table_write(&path, file_format, None, None, None)?; - let mut result_stream = this.run_query(plan).await?; + let mut result_stream = this.run_query(plan).await?; - // this is so we make sure the operation is actually done - // before we return - // - // an example where this is important is if we write to a parquet file - // and then read immediately after, we need to wait for the write to finish - while let Some(_result) = result_stream.next().await {} + // this is so we make sure the operation is actually done + // before we return + // + // an example where this is important is if we write to a parquet file + // and then read immediately after, we need to wait for the write to finish + while let Some(_result) = result_stream.next().await {} - Ok(()) - }; + Ok(()) + }; - let result = result.map_err(|e| { - Status::internal(textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n")) + if let Err(e) = result.await { + let _ = tx.send(Err(e)).await; + } }); + let stream = ReceiverStream::new(rx); - let future = result.and_then(|()| ready(Ok(finished))); - let stream = futures::stream::once(future); + let stream = stream + .map_err(|e| { + Status::internal( + textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"), + ) + }) + .chain(stream::once(ready(Ok(finished)))); Ok(Box::pin(stream)) } diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index 7e8b84aca9..23a182a271 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -84,10 +84,10 @@ pub fn start(addr: &str) -> eyre::Result { shutdown_signal: Some(shutdown_signal), port, }; + let runtime = common_runtime::get_io_runtime(true); std::thread::spawn(move || { - let runtime = tokio::runtime::Runtime::new().unwrap(); - let result = runtime.block_on(async { + let result = runtime.block_on_current_thread(async { let incoming = { let listener = tokio::net::TcpListener::from_std(listener) .wrap_err("Failed to create TcpListener from std::net::TcpListener")?; diff --git a/src/daft-connect/src/session.rs b/src/daft-connect/src/session.rs index 25a8024f76..234918d946 100644 --- a/src/daft-connect/src/session.rs +++ b/src/daft-connect/src/session.rs @@ -3,7 +3,9 @@ use std::{ sync::{Arc, RwLock}, }; +use common_runtime::RuntimeRef; use daft_catalog::DaftCatalog; +use daft_local_execution::NativeExecutor; use daft_micropartition::partitioning::InMemoryPartitionSetCache; use uuid::Uuid; @@ -19,6 +21,8 @@ pub struct Session { /// MicroPartitionSet associated with this session /// this will be filled up as the user runs queries pub(crate) psets: Arc, + pub(crate) compute_runtime: RuntimeRef, + pub(crate) engine: Arc, pub(crate) catalog: Arc>, } @@ -34,11 +38,15 @@ impl Session { pub fn new(id: String) -> Self { let server_side_session_id = Uuid::new_v4(); let server_side_session_id = server_side_session_id.to_string(); + let rt = common_runtime::get_compute_runtime(); + Self { config_values: Default::default(), id, server_side_session_id, psets: Arc::new(InMemoryPartitionSetCache::empty()), + compute_runtime: rt.clone(), + engine: Arc::new(NativeExecutor::default().with_runtime(rt.runtime.clone())), catalog: Arc::new(RwLock::new(DaftCatalog::default())), } } diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index e1752eac9a..ef6cdbe93b 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -26,7 +26,7 @@ use common_runtime::{RuntimeRef, RuntimeTask}; use lazy_static::lazy_static; use progress_bar::{OperatorProgressBar, ProgressBarColor, ProgressBarManager}; use resource_manager::MemoryManager; -pub use run::{run_local, ExecutionEngineResult, NativeExecutor}; +pub use run::{ExecutionEngineResult, NativeExecutor}; use runtime_stats::{RuntimeStatsContext, TimedFuture}; use snafu::{futures::TryFutureExt, ResultExt, Snafu}; use tracing::Instrument; @@ -124,7 +124,7 @@ pub(crate) struct ExecutionRuntimeContext { worker_set: TaskSet>, default_morsel_size: usize, memory_manager: Arc, - progress_bar_manager: Option>, + progress_bar_manager: Option>, } impl ExecutionRuntimeContext { @@ -132,7 +132,7 @@ impl ExecutionRuntimeContext { pub fn new( default_morsel_size: usize, memory_manager: Arc, - progress_bar_manager: Option>, + progress_bar_manager: Option>, ) -> Self { Self { worker_set: TaskSet::new(), diff --git a/src/daft-local-execution/src/progress_bar.rs b/src/daft-local-execution/src/progress_bar.rs index 3b42333d49..d865826da5 100644 --- a/src/daft-local-execution/src/progress_bar.rs +++ b/src/daft-local-execution/src/progress_bar.rs @@ -16,7 +16,7 @@ pub trait ProgressBar: Send + Sync { fn close(&self) -> DaftResult<()>; } -pub trait ProgressBarManager { +pub trait ProgressBarManager: std::fmt::Debug + Send + Sync { fn make_new_bar( &self, color: ProgressBarColor, @@ -128,6 +128,7 @@ impl ProgressBar for IndicatifProgressBar { } } +#[derive(Debug)] struct IndicatifProgressBarManager { multi_progress: indicatif::MultiProgress, } @@ -168,19 +169,19 @@ impl ProgressBarManager for IndicatifProgressBarManager { } } -pub fn make_progress_bar_manager() -> Box { +pub fn make_progress_bar_manager() -> Arc { #[cfg(feature = "python")] { if python::in_notebook() { - Box::new(python::TqdmProgressBarManager::new()) + Arc::new(python::TqdmProgressBarManager::new()) } else { - Box::new(IndicatifProgressBarManager::new()) + Arc::new(IndicatifProgressBarManager::new()) } } #[cfg(not(feature = "python"))] { - Box::new(IndicatifProgressBarManager::new()) + Arc::new(IndicatifProgressBarManager::new()) } } @@ -215,7 +216,7 @@ mod python { } } - #[derive(Clone)] + #[derive(Clone, Debug)] pub struct TqdmProgressBarManager { inner: Arc, } diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index c3dea72f7a..644914d523 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -9,7 +9,7 @@ use std::{ use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use common_tracing::refresh_chrome_trace; -use daft_local_plan::{translate, LocalPhysicalPlan}; +use daft_local_plan::translate; use daft_logical_plan::LogicalPlanBuilder; use daft_micropartition::{ partitioning::{InMemoryPartitionSetCache, MicroPartitionSet, PartitionSetCache}, @@ -31,7 +31,7 @@ use { use crate::{ channel::{create_channel, Receiver}, pipeline::{physical_plan_to_pipeline, viz_pipeline}, - progress_bar::make_progress_bar_manager, + progress_bar::{make_progress_bar_manager, ProgressBarManager}, resource_manager::get_or_init_memory_manager, Error, ExecutionRuntimeContext, }; @@ -62,25 +62,28 @@ pub struct PyNativeExecutor { executor: NativeExecutor, } +#[cfg(feature = "python")] +impl Default for PyNativeExecutor { + fn default() -> Self { + Self::new() + } +} + #[cfg(feature = "python")] #[pymethods] impl PyNativeExecutor { - #[staticmethod] - pub fn from_logical_plan_builder( - logical_plan_builder: &PyLogicalPlanBuilder, - py: Python, - ) -> PyResult { - py.allow_threads(|| { - Ok(Self { - executor: NativeExecutor::from_logical_plan_builder(&logical_plan_builder.builder)?, - }) - }) + #[new] + pub fn new() -> Self { + Self { + executor: NativeExecutor::new(), + } } - #[pyo3(signature = (psets, cfg, results_buffer_size=None))] + #[pyo3(signature = (logical_plan_builder, psets, cfg, results_buffer_size=None))] pub fn run<'a>( &self, py: Python<'a>, + logical_plan_builder: &PyLogicalPlanBuilder, psets: HashMap>, cfg: PyDaftExecutionConfig, results_buffer_size: Option, @@ -103,7 +106,12 @@ impl PyNativeExecutor { let psets = InMemoryPartitionSetCache::new(&native_psets); let out = py.allow_threads(|| { self.executor - .run(&psets, cfg.config, results_buffer_size) + .run( + &logical_plan_builder.builder, + &psets, + cfg.config, + results_buffer_size, + ) .map(|res| res.into_iter()) })?; let iter = Box::new(out.map(|part| { @@ -119,37 +127,134 @@ impl PyNativeExecutor { } } +#[derive(Debug, Clone)] pub struct NativeExecutor { - local_physical_plan: Arc, cancel: CancellationToken, + runtime: Option>, + pb_manager: Option>, + enable_explain_analyze: bool, +} + +impl Default for NativeExecutor { + fn default() -> Self { + Self { + cancel: CancellationToken::new(), + runtime: None, + pb_manager: should_enable_progress_bar().then(make_progress_bar_manager), + enable_explain_analyze: should_enable_explain_analyze(), + } + } } impl NativeExecutor { - pub fn from_logical_plan_builder( - logical_plan_builder: &LogicalPlanBuilder, - ) -> DaftResult { - let logical_plan = logical_plan_builder.build(); - let local_physical_plan = translate(&logical_plan)?; + pub fn new() -> Self { + Self::default() + } - Ok(Self { - local_physical_plan, - cancel: CancellationToken::new(), - }) + pub fn with_runtime(mut self, runtime: Arc) -> Self { + self.runtime = Some(runtime); + self + } + + pub fn with_progress_bar_manager(mut self, pb_manager: Arc) -> Self { + self.pb_manager = Some(pb_manager); + self + } + + pub fn enable_explain_analyze(mut self, b: bool) -> Self { + self.enable_explain_analyze = b; + self } pub fn run( &self, + logical_plan_builder: &LogicalPlanBuilder, psets: &(impl PartitionSetCache> + ?Sized), cfg: Arc, results_buffer_size: Option, ) -> DaftResult { - run_local( - &self.local_physical_plan, - psets, - cfg, - results_buffer_size, - self.cancel.clone(), - ) + let logical_plan = logical_plan_builder.build(); + let physical_plan = translate(&logical_plan)?; + refresh_chrome_trace(); + let cancel = self.cancel.clone(); + let pipeline = physical_plan_to_pipeline(&physical_plan, psets, &cfg)?; + let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1)); + + let rt = self.runtime.clone(); + let pb_manager = self.pb_manager.clone(); + let enable_explain_analyze = self.enable_explain_analyze; + // todo: split this into a run and run_async method + // the run_async should spawn a task instead of a thread like this + let handle = std::thread::spawn(move || { + let runtime = rt.unwrap_or_else(|| { + Arc::new( + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("Failed to create tokio runtime"), + ) + }); + let execution_task = async { + let memory_manager = get_or_init_memory_manager(); + let mut runtime_handle = ExecutionRuntimeContext::new( + cfg.default_morsel_size, + memory_manager.clone(), + pb_manager, + ); + let receiver = pipeline.start(true, &mut runtime_handle)?; + + while let Some(val) = receiver.recv().await { + if tx.send(val).await.is_err() { + break; + } + } + + while let Some(result) = runtime_handle.join_next().await { + match result { + Ok(Err(e)) => { + runtime_handle.shutdown().await; + return DaftResult::Err(e.into()); + } + Err(e) => { + runtime_handle.shutdown().await; + return DaftResult::Err(Error::JoinError { source: e }.into()); + } + _ => {} + } + } + if enable_explain_analyze { + let curr_ms = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_millis(); + let file_name = format!("explain-analyze-{curr_ms}-mermaid.md"); + let mut file = File::create(file_name)?; + writeln!(file, "```mermaid\n{}\n```", viz_pipeline(pipeline.as_ref()))?; + } + Ok(()) + }; + + let local_set = tokio::task::LocalSet::new(); + local_set.block_on(&runtime, async { + tokio::select! { + biased; + () = cancel.cancelled() => { + log::info!("Execution engine cancelled"); + Ok(()) + } + _ = tokio::signal::ctrl_c() => { + log::info!("Received Ctrl-C, shutting down execution engine"); + Ok(()) + } + result = execution_task => result, + } + }) + }); + + Ok(ExecutionEngineResult { + handle, + receiver: rx, + }) } } @@ -270,82 +375,3 @@ impl IntoIterator for ExecutionEngineResult { } } } - -pub fn run_local( - physical_plan: &LocalPhysicalPlan, - psets: &(impl PartitionSetCache> + ?Sized), - cfg: Arc, - results_buffer_size: Option, - cancel: CancellationToken, -) -> DaftResult { - refresh_chrome_trace(); - let pipeline = physical_plan_to_pipeline(physical_plan, psets, &cfg)?; - let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1)); - let handle = std::thread::spawn(move || { - let pb_manager = should_enable_progress_bar().then(make_progress_bar_manager); - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("Failed to create tokio runtime"); - let execution_task = async { - let memory_manager = get_or_init_memory_manager(); - let mut runtime_handle = ExecutionRuntimeContext::new( - cfg.default_morsel_size, - memory_manager.clone(), - pb_manager, - ); - let receiver = pipeline.start(true, &mut runtime_handle)?; - - while let Some(val) = receiver.recv().await { - if tx.send(val).await.is_err() { - break; - } - } - - while let Some(result) = runtime_handle.join_next().await { - match result { - Ok(Err(e)) => { - runtime_handle.shutdown().await; - return DaftResult::Err(e.into()); - } - Err(e) => { - runtime_handle.shutdown().await; - return DaftResult::Err(Error::JoinError { source: e }.into()); - } - _ => {} - } - } - if should_enable_explain_analyze() { - let curr_ms = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_millis(); - let file_name = format!("explain-analyze-{curr_ms}-mermaid.md"); - let mut file = File::create(file_name)?; - writeln!(file, "```mermaid\n{}\n```", viz_pipeline(pipeline.as_ref()))?; - } - Ok(()) - }; - - let local_set = tokio::task::LocalSet::new(); - local_set.block_on(&runtime, async { - tokio::select! { - biased; - () = cancel.cancelled() => { - log::info!("Execution engine cancelled"); - Ok(()) - } - _ = tokio::signal::ctrl_c() => { - log::info!("Received Ctrl-C, shutting down execution engine"); - Ok(()) - } - result = execution_task => result, - } - }) - }); - - Ok(ExecutionEngineResult { - handle, - receiver: rx, - }) -} diff --git a/src/daft-logical-plan/Cargo.toml b/src/daft-logical-plan/Cargo.toml index b46cf1d222..0ff2d2ac1c 100644 --- a/src/daft-logical-plan/Cargo.toml +++ b/src/daft-logical-plan/Cargo.toml @@ -21,6 +21,7 @@ log = {workspace = true} pyo3 = {workspace = true, optional = true} serde = {workspace = true, features = ["rc"]} snafu = {workspace = true} +tokio = {workspace = true} typed-builder = {workspace = true} uuid = {version = "1", features = ["v4"]} diff --git a/src/daft-logical-plan/src/builder/mod.rs b/src/daft-logical-plan/src/builder/mod.rs index 244a42f933..9006505fca 100644 --- a/src/daft-logical-plan/src/builder/mod.rs +++ b/src/daft-logical-plan/src/builder/mod.rs @@ -4,6 +4,7 @@ mod tests; use std::{ collections::{HashMap, HashSet}, + future::Future, sync::Arc, }; @@ -688,19 +689,79 @@ impl LogicalPlanBuilder { Ok(self.with_new_plan(logical_plan)) } + /// Async equivalent of `optimize` + /// This is safe to call from a tokio runtime + pub fn optimize_async(&self) -> impl Future> { + let cfg = self.config.clone(); + + // Run LogicalPlan optimizations + let unoptimized_plan = self.build(); + let (tx, rx) = tokio::sync::oneshot::channel(); + + std::thread::spawn(move || { + let optimizer = OptimizerBuilder::default() + .when( + cfg.as_ref() + .map_or(false, |conf| conf.enable_join_reordering), + |builder| builder.reorder_joins(), + ) + .simplify_expressions() + .build(); + + let optimized_plan = optimizer.optimize( + unoptimized_plan, + |new_plan, rule_batch, pass, transformed, seen| { + if transformed { + log::debug!( + "Rule batch {:?} transformed plan on pass {}, and produced {} plan:\n{}", + rule_batch, + pass, + if seen { "an already seen" } else { "a new" }, + new_plan.repr_ascii(true), + ); + } else { + log::debug!( + "Rule batch {:?} did NOT transform plan on pass {} for plan:\n{}", + rule_batch, + pass, + new_plan.repr_ascii(true), + ); + } + }, + ); + tx.send(optimized_plan).unwrap(); + }); + + let cfg = self.config.clone(); + async move { + rx.await + .map_err(|e| { + DaftError::InternalError(format!("Error optimizing logical plan: {:?}", e)) + })? + .map(|plan| Self::new(plan, cfg)) + } + } + + /// optimize the logical plan + /// + /// **Important**: Do not call this method from the main thread as there is a `block_on` call deep within this method + /// Calling will result in a runtime panic pub fn optimize(&self) -> DaftResult { + // TODO: remove the `block_on` to make this method safe to call from the main thread + + let cfg = self.config.clone(); + + let unoptimized_plan = self.build(); + let optimizer = OptimizerBuilder::default() .when( - self.config - .as_ref() + cfg.as_ref() .map_or(false, |conf| conf.enable_join_reordering), |builder| builder.reorder_joins(), ) .simplify_expressions() .build(); - // Run LogicalPlan optimizations - let unoptimized_plan = self.build(); let optimized_plan = optimizer.optimize( unoptimized_plan, |new_plan, rule_batch, pass, transformed, seen| { @@ -723,7 +784,7 @@ impl LogicalPlanBuilder { }, )?; - let builder = Self::new(optimized_plan, self.config.clone()); + let builder = Self::new(optimized_plan, cfg); Ok(builder) }