refactor(execution): NativeExecutor refactor (#3689)
Small refactor to `NativeExecutor` to make it a bit more flexible for
Spark Connect.


Changes made:

- make the following configurable:
  - runtime
  - progress_bar_manager
  - enable_explain_analyze
- refactor the engine to take in the plan when calling `.run`, not at
creation.
- this is useful because we can now reuse an engine instead of having to
create a new one for every plan.
- Note: we could probably follow up with a refactor of the Python code to
have a single "native" engine instead of creating one every time.
  
The runtime configuration is especially important for Spark Connect: previously
we were using `spawn_blocking` when running a plan, but now we can leverage the
existing compute runtime instead of creating a new one.
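
As an illustration, here is a minimal sketch of the new calling pattern (the `run_plan` helper and its arguments are placeholders, not part of this change; the real call sites are in `native_runner.py` and `pyrunner.py` below):

```python
from daft.daft import PyDaftExecutionConfig
from daft.execution.native_executor import NativeExecutor

# One long-lived engine that can be reused for every plan,
# instead of constructing a new executor per plan.
executor = NativeExecutor()

def run_plan(builder, psets, cfg: PyDaftExecutionConfig, results_buffer_size=None):
    # The (already optimized) LogicalPlanBuilder is now passed to `.run`
    # rather than to the executor's constructor.
    return executor.run(builder, psets, cfg, results_buffer_size)
```

Spark Connect benefits in the same way on the Rust side: the session now holds a single engine built with `with_runtime(...)`, so plans run on the shared compute runtime rather than inside `spawn_blocking`.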
universalmind303 authored Jan 21, 2025
1 parent 4b8397b commit bae106c
Showing 15 changed files with 298 additions and 202 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock


5 changes: 1 addition & 4 deletions daft/daft/__init__.pyi
@@ -1695,10 +1695,7 @@ class LogicalPlanBuilder:
def repr_mermaid(self, options: MermaidOptions) -> str: ...

class NativeExecutor:
@staticmethod
def from_logical_plan_builder(
logical_plan_builder: LogicalPlanBuilder,
) -> NativeExecutor: ...
def __init__(self) -> None: ...
def run(
self, psets: dict[str, list[PartitionT]], cfg: PyDaftExecutionConfig, results_buffer_size: int | None
) -> Iterator[PyMicroPartition]: ...
14 changes: 5 additions & 9 deletions daft/execution/native_executor.py
@@ -5,10 +5,10 @@
from daft.daft import (
NativeExecutor as _NativeExecutor,
)
from daft.daft import PyDaftExecutionConfig
from daft.table import MicroPartition

if TYPE_CHECKING:
from daft.daft import PyDaftExecutionConfig
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.partitioning import (
LocalMaterializedResult,
Expand All @@ -18,16 +18,12 @@


class NativeExecutor:
def __init__(self, executor: _NativeExecutor):
self._executor = executor

@classmethod
def from_logical_plan_builder(cls, builder: LogicalPlanBuilder) -> NativeExecutor:
executor = _NativeExecutor.from_logical_plan_builder(builder._builder)
return cls(executor)
def __init__(self):
self._executor = _NativeExecutor()

def run(
self,
builder: LogicalPlanBuilder,
psets: dict[str, list[MaterializedResult[PartitionT]]],
daft_execution_config: PyDaftExecutionConfig,
results_buffer_size: int | None,
@@ -39,5 +35,5 @@ def run(
}
return (
LocalMaterializedResult(MicroPartition._from_pymicropartition(part))
for part in self._executor.run(psets_mp, daft_execution_config, results_buffer_size)
for part in self._executor.run(builder._builder, psets_mp, daft_execution_config, results_buffer_size)
)
3 changes: 2 additions & 1 deletion daft/runners/native_runner.py
@@ -75,8 +75,9 @@ def run_iter(

# Optimize the logical plan.
builder = builder.optimize()
executor = NativeExecutor.from_logical_plan_builder(builder)
executor = NativeExecutor()
results_gen = executor.run(
builder,
{k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()},
daft_execution_config,
results_buffer_size,
3 changes: 2 additions & 1 deletion daft/runners/pyrunner.py
@@ -380,8 +380,9 @@ def run_iter(
if daft_execution_config.enable_native_executor:
logger.info("Using native executor")

executor = NativeExecutor.from_logical_plan_builder(builder)
executor = NativeExecutor()
results_gen = executor.run(
builder,
{k: v.values() for k, v in self._part_set_cache.get_all_partition_sets().items()},
daft_execution_config,
results_buffer_size,
7 changes: 5 additions & 2 deletions src/common/runtime/src/lib.rs
@@ -69,13 +69,16 @@ impl<T: Send + 'static> Future for RuntimeTask<T> {
}

pub struct Runtime {
runtime: tokio::runtime::Runtime,
pub runtime: Arc<tokio::runtime::Runtime>,
pool_type: PoolType,
}

impl Runtime {
pub(crate) fn new(runtime: tokio::runtime::Runtime, pool_type: PoolType) -> RuntimeRef {
Arc::new(Self { runtime, pool_type })
Arc::new(Self {
runtime: Arc::new(runtime),
pool_type,
})
}

async fn execute_task<F>(future: F, pool_type: PoolType) -> DaftResult<F::Output>
3 changes: 1 addition & 2 deletions src/daft-connect/Cargo.toml
@@ -1,7 +1,6 @@
[dependencies]
arrow2 = {workspace = true, features = ["io_json_integration"]}
async-stream = "0.3.6"
common-daft-config = {workspace = true, optional = true, features = ["python"]}
common-error = {workspace = true, optional = true, features = ["python"]}
common-file-formats = {workspace = true, optional = true, features = ["python"]}
daft-catalog = {path = "../daft-catalog", optional = true, features = ["python"]}
Expand All @@ -27,12 +26,12 @@ tokio = {version = "1.40.0", features = ["full"]}
tonic = "0.12.3"
tracing = {workspace = true}
uuid = {version = "1.10.0", features = ["v4"]}
common-runtime.workspace = true

[features]
default = ["python"]
python = [
"dep:pyo3",
"dep:common-daft-config",
"dep:common-error",
"dep:common-file-formats",
"dep:daft-core",
113 changes: 57 additions & 56 deletions src/daft-connect/src/execute.rs
@@ -1,20 +1,17 @@
use std::{future::ready, sync::Arc};

use common_daft_config::DaftExecutionConfig;
use common_error::{DaftError, DaftResult};
use common_error::DaftResult;
use common_file_formats::FileFormat;
use daft_dsl::LiteralValue;
use daft_local_execution::NativeExecutor;
use daft_logical_plan::LogicalPlanBuilder;
use daft_micropartition::MicroPartition;
use daft_ray_execution::RayEngine;
use daft_table::Table;
use eyre::{bail, Context};
use futures::{
stream::{self, BoxStream},
StreamExt, TryFutureExt, TryStreamExt,
StreamExt, TryStreamExt,
};
use itertools::Itertools;
use pyo3::Python;
use spark_connect::{
relation::RelType,
@@ -63,17 +60,13 @@ impl Session {

Runner::Native => {
let this = self.clone();
let result_stream = tokio::task::spawn_blocking(move || {
let plan = lp.optimize()?;
let cfg = Arc::new(DaftExecutionConfig::default());
let native_executor = NativeExecutor::from_logical_plan_builder(&plan)?;
let results = native_executor.run(&*this.psets, cfg, None)?;
let it = results.into_iter();
Ok::<_, DaftError>(it.collect_vec())
})
.await??;

Ok(Box::pin(stream::iter(result_stream)))
let plan = lp.optimize_async().await?;

let results = this
.engine
.run(&plan, &*this.psets, Default::default(), None)?;
Ok(results.into_stream().boxed())
}
}
}
Expand All @@ -85,14 +78,12 @@ impl Session {
) -> Result<ExecuteStream, Status> {
use futures::{StreamExt, TryStreamExt};

// fallback response
let result_complete = res.result_complete_response();

let (tx, rx) = tokio::sync::mpsc::channel::<eyre::Result<ExecutePlanResponse>>(1);

let this = self.clone();

tokio::spawn(async move {
self.compute_runtime.runtime.spawn(async move {
let execution_fut = async {
let translator = SparkAnalyzer::new(&this);
match command.rel_type {
@@ -144,7 +135,7 @@
pub async fn execute_write_operation(
&self,
operation: WriteOperation,
response_builder: ResponseBuilder<ExecutePlanResponse>,
res: ResponseBuilder<ExecutePlanResponse>,
) -> Result<ExecuteStream, Status> {
fn check_write_operation(write_op: &WriteOperation) -> Result<(), Status> {
if !write_op.sort_column_names.is_empty() {
@@ -179,60 +170,70 @@
}
}

let finished = response_builder.result_complete_response();
let finished = res.result_complete_response();

let (tx, rx) = tokio::sync::mpsc::channel::<eyre::Result<ExecutePlanResponse>>(1);

let this = self.clone();

let result = async move {
check_write_operation(&operation)?;
self.compute_runtime.runtime.spawn(async move {
let result = async {
check_write_operation(&operation)?;

let WriteOperation {
input,
source,
save_type,
..
} = operation;
let WriteOperation {
input,
source,
save_type,
..
} = operation;

let input = input.required("input")?;
let source = source.required("source")?;
let input = input.required("input")?;
let source = source.required("source")?;

let file_format: FileFormat = source.parse()?;
let file_format: FileFormat = source.parse()?;

let Some(save_type) = save_type else {
bail!("Save type is required");
};
let Some(save_type) = save_type else {
bail!("Save type is required");
};

let path = match save_type {
SaveType::Path(path) => path,
SaveType::Table(_) => {
return not_yet_implemented!("write to table").map_err(|e| e.into())
}
};
let path = match save_type {
SaveType::Path(path) => path,
SaveType::Table(_) => {
return not_yet_implemented!("write to table").map_err(|e| e.into())
}
};

let translator = SparkAnalyzer::new(&this);
let translator = SparkAnalyzer::new(&this);

let plan = translator.to_logical_plan(input).await?;
let plan = translator.to_logical_plan(input).await?;

let plan = plan.table_write(&path, file_format, None, None, None)?;
let plan = plan.table_write(&path, file_format, None, None, None)?;

let mut result_stream = this.run_query(plan).await?;
let mut result_stream = this.run_query(plan).await?;

// this is so we make sure the operation is actually done
// before we return
//
// an example where this is important is if we write to a parquet file
// and then read immediately after, we need to wait for the write to finish
while let Some(_result) = result_stream.next().await {}
// this is so we make sure the operation is actually done
// before we return
//
// an example where this is important is if we write to a parquet file
// and then read immediately after, we need to wait for the write to finish
while let Some(_result) = result_stream.next().await {}

Ok(())
};
Ok(())
};

let result = result.map_err(|e| {
Status::internal(textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"))
if let Err(e) = result.await {
let _ = tx.send(Err(e)).await;
}
});
let stream = ReceiverStream::new(rx);

let future = result.and_then(|()| ready(Ok(finished)));
let stream = futures::stream::once(future);
let stream = stream
.map_err(|e| {
Status::internal(
textwrap::wrap(&format!("Error in Daft server: {e}"), 120).join("\n"),
)
})
.chain(stream::once(ready(Ok(finished))));

Ok(Box::pin(stream))
}
4 changes: 2 additions & 2 deletions src/daft-connect/src/lib.rs
@@ -84,10 +84,10 @@ pub fn start(addr: &str) -> eyre::Result<ConnectionHandle> {
shutdown_signal: Some(shutdown_signal),
port,
};
let runtime = common_runtime::get_io_runtime(true);

std::thread::spawn(move || {
let runtime = tokio::runtime::Runtime::new().unwrap();
let result = runtime.block_on(async {
let result = runtime.block_on_current_thread(async {
let incoming = {
let listener = tokio::net::TcpListener::from_std(listener)
.wrap_err("Failed to create TcpListener from std::net::TcpListener")?;
8 changes: 8 additions & 0 deletions src/daft-connect/src/session.rs
@@ -3,7 +3,9 @@ use std::{
sync::{Arc, RwLock},
};

use common_runtime::RuntimeRef;
use daft_catalog::DaftCatalog;
use daft_local_execution::NativeExecutor;
use daft_micropartition::partitioning::InMemoryPartitionSetCache;
use uuid::Uuid;

@@ -19,6 +21,8 @@ pub struct Session {
/// MicroPartitionSet associated with this session
/// this will be filled up as the user runs queries
pub(crate) psets: Arc<InMemoryPartitionSetCache>,
pub(crate) compute_runtime: RuntimeRef,
pub(crate) engine: Arc<NativeExecutor>,
pub(crate) catalog: Arc<RwLock<DaftCatalog>>,
}

@@ -34,11 +38,15 @@ impl Session {
pub fn new(id: String) -> Self {
let server_side_session_id = Uuid::new_v4();
let server_side_session_id = server_side_session_id.to_string();
let rt = common_runtime::get_compute_runtime();

Self {
config_values: Default::default(),
id,
server_side_session_id,
psets: Arc::new(InMemoryPartitionSetCache::empty()),
compute_runtime: rt.clone(),
engine: Arc::new(NativeExecutor::default().with_runtime(rt.runtime.clone())),
catalog: Arc::new(RwLock::new(DaftCatalog::default())),
}
}
6 changes: 3 additions & 3 deletions src/daft-local-execution/src/lib.rs
@@ -26,7 +26,7 @@ use common_runtime::{RuntimeRef, RuntimeTask};
use lazy_static::lazy_static;
use progress_bar::{OperatorProgressBar, ProgressBarColor, ProgressBarManager};
use resource_manager::MemoryManager;
pub use run::{run_local, ExecutionEngineResult, NativeExecutor};
pub use run::{ExecutionEngineResult, NativeExecutor};
use runtime_stats::{RuntimeStatsContext, TimedFuture};
use snafu::{futures::TryFutureExt, ResultExt, Snafu};
use tracing::Instrument;
@@ -124,15 +124,15 @@ pub(crate) struct ExecutionRuntimeContext {
worker_set: TaskSet<crate::Result<()>>,
default_morsel_size: usize,
memory_manager: Arc<MemoryManager>,
progress_bar_manager: Option<Box<dyn ProgressBarManager>>,
progress_bar_manager: Option<Arc<dyn ProgressBarManager>>,
}

impl ExecutionRuntimeContext {
#[must_use]
pub fn new(
default_morsel_size: usize,
memory_manager: Arc<MemoryManager>,
progress_bar_manager: Option<Box<dyn ProgressBarManager>>,
progress_bar_manager: Option<Arc<dyn ProgressBarManager>>,
) -> Self {
Self {
worker_set: TaskSet::new(),