[FEAT] daft-connect support for parquet #3236

Closed · wants to merge 12 commits
968 changes: 668 additions & 300 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions daft/daft/__init__.pyi
@@ -1241,6 +1241,8 @@ def connect_start(addr: str) -> ConnectionHandle: ...
class ConnectionHandle:
def shutdown(self) -> None: ...

def connect_start(addr: str) -> None: ...

# expr numeric ops
def abs(expr: PyExpr) -> PyExpr: ...
def cbrt(expr: PyExpr) -> PyExpr: ...
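A minimal usage sketch of the stubs above, not part of this diff: it assumes the address format `sc://localhost:50051`, that `connect_start` returns the `ConnectionHandle` declared in the stub, and that PySpark with the Spark Connect client is installed.

```python
# Hypothetical usage sketch only; the address format and client wiring are assumptions.
from daft.daft import connect_start
from pyspark.sql import SparkSession

handle = connect_start("sc://localhost:50051")  # returns a ConnectionHandle per the stub above

# Point a Spark Connect client at the daft-connect server and run a trivial query.
spark = SparkSession.builder.remote("sc://localhost:50051").getOrCreate()
print(spark.range(5).collect())

handle.shutdown()
```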
17 changes: 15 additions & 2 deletions src/daft-connect/Cargo.toml
@@ -1,15 +1,28 @@
[dependencies]
arrow2.workspace = true
common-daft-config.workspace = true
common-file-formats.workspace = true
daft-core.workspace = true
daft-dsl.workspace = true
daft-local-execution.workspace = true
daft-local-plan.workspace = true
daft-logical-plan.workspace = true
daft-physical-plan.workspace = true
daft-schema.workspace = true
daft-table.workspace = true
dashmap = "6.1.0"
eyre = "0.6.12"
futures = "0.3.31"
pyo3 = {workspace = true, optional = true}
spark-connect.workspace = true
tokio = {version = "1.40.0", features = ["full"]}
tokio-stream = "0.1.16"
tonic = "0.12.3"
tracing-subscriber = {version = "0.3.18", features = ["env-filter"]}
tracing-tracy = "0.11.3"
uuid = {version = "1.10.0", features = ["v4"]}
spark-connect.workspace = true
common-error.workspace = true
tracing.workspace = true
uuid = {version = "1.10.0", features = ["v4"]}

[features]
python = ["dep:pyo3"]
126 changes: 126 additions & 0 deletions src/daft-connect/src/command.rs
@@ -0,0 +1,126 @@
use std::thread;

use arrow2::io::ipc::write::StreamWriter;
use daft_table::Table;
use eyre::Context;
use futures::TryStreamExt;
use spark_connect::{
execute_plan_response::{ArrowBatch, ResponseType, ResultComplete},
spark_connect_service_server::SparkConnectService,
ExecutePlanResponse, Relation,
};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::Status;
use uuid::Uuid;

use crate::{convert::convert_data, DaftSparkConnectService, Session};

type DaftStream = <DaftSparkConnectService as SparkConnectService>::ExecutePlanStream;

struct ExecutablePlanChannel {
session_id: String,
server_side_session_id: String,
operation_id: String,
tx: tokio::sync::mpsc::UnboundedSender<eyre::Result<ExecutePlanResponse>>,
}

pub trait ConcreteDataChannel {
fn send_table(&mut self, table: &Table) -> eyre::Result<()>;
}

impl ConcreteDataChannel for ExecutablePlanChannel {
fn send_table(&mut self, table: &Table) -> eyre::Result<()> {
let mut data = Vec::new();

let mut writer = StreamWriter::new(
&mut data,
arrow2::io::ipc::write::WriteOptions { compression: None },
);

let row_count = table.num_rows();

let schema = table
.schema
.to_arrow()
.wrap_err("Failed to convert Daft schema to Arrow schema")?;

writer
.start(&schema, None)
.wrap_err("Failed to start Arrow stream writer with schema")?;

let arrays = table.get_inner_arrow_arrays().collect();
let chunk = arrow2::chunk::Chunk::new(arrays);

writer
.write(&chunk, None)
.wrap_err("Failed to write Arrow chunk to stream writer")?;

let response = ExecutePlanResponse {
session_id: self.session_id.to_string(),
server_side_session_id: self.server_side_session_id.to_string(),
operation_id: self.operation_id.to_string(),
response_id: Uuid::new_v4().to_string(), // todo: implement this
metrics: None, // todo: implement this
observed_metrics: vec![],
schema: None,
response_type: Some(ResponseType::ArrowBatch(ArrowBatch {
row_count: row_count as i64,
data,
start_offset: None,
})),
};

self.tx
.send(Ok(response))
.wrap_err("Error sending response to client")?;

Ok(())
}
}

impl Session {
pub async fn handle_root_command(
&self,
command: Relation,
operation_id: String,
) -> Result<DaftStream, Status> {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

let mut channel = ExecutablePlanChannel {
session_id: self.client_side_session_id().to_string(),
server_side_session_id: self.server_side_session_id().to_string(),
operation_id: operation_id.clone(),
tx: tx.clone(),
};

thread::spawn({
let session_id = self.client_side_session_id().to_string();
let server_side_session_id = self.server_side_session_id().to_string();
move || {
let result = convert_data(command, &mut channel);

if let Err(e) = result {
tx.send(Err(e)).unwrap();
} else {
let finished = ExecutePlanResponse {
session_id,
server_side_session_id,
operation_id: operation_id.to_string(),
response_id: Uuid::new_v4().to_string(),
metrics: None,
observed_metrics: vec![],
schema: None,
response_type: Some(ResponseType::ResultComplete(ResultComplete {})),
};

tx.send(Ok(finished)).unwrap();
}
}
});

let recv_stream =
UnboundedReceiverStream::new(rx).map_err(|e| Status::internal(e.to_string()));

Ok(Box::pin(recv_stream))
}
}
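`send_table` above serializes each table as a complete Arrow IPC stream (schema message plus one chunk) and ships it in an `ArrowBatch`. A small client-side sketch of decoding that payload, assuming pyarrow is available; a real Spark Connect client does this internally, so this is only to illustrate the wire format.

```python
# Hypothetical sketch: decode the `data` bytes of an ArrowBatch produced by send_table.
import pyarrow as pa

def decode_arrow_batch(data: bytes) -> pa.Table:
    # `data` is a self-contained Arrow IPC stream: schema message followed by one record batch.
    reader = pa.ipc.open_stream(pa.BufferReader(data))
    return reader.read_all()

# table = decode_arrow_batch(arrow_batch.data)  # `arrow_batch` taken from an ExecutePlanResponse
```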
45 changes: 45 additions & 0 deletions src/daft-connect/src/convert.rs
@@ -0,0 +1,45 @@
mod data_conversion;
mod expression;
mod formatting;
mod plan_conversion;
mod schema_conversion;

use std::{collections::HashMap, pin::Pin, sync::Arc};

use common_daft_config::DaftExecutionConfig;
use common_error::{DaftError, DaftResult};
use daft_logical_plan::LogicalPlanRef;
use daft_table::Table;
pub use data_conversion::convert_data;
use futures::{stream, Stream, StreamExt};
pub use schema_conversion::connect_schema;

pub fn run_local_to_tables(
logical_plan: &LogicalPlanRef,
) -> DaftResult<impl Stream<Item = DaftResult<Table>>> {
let physical_plan = daft_local_plan::translate(logical_plan)?;
let cfg = Arc::new(DaftExecutionConfig::default());
let psets = HashMap::new();

let stream = daft_local_execution::run_local(&physical_plan, psets, cfg, None)?;

let stream = stream
.map(|partition| match partition {
Ok(partition) => partition.get_tables().map_err(DaftError::from),
Err(err) => Err(err),
})
.flat_map(|tables| match tables {
Ok(tables) => {
let tables = Arc::try_unwrap(tables).unwrap();

let tables = tables.into_iter().map(Ok);
let stream: Pin<Box<dyn Stream<Item = DaftResult<Table>>>> =
Box::pin(stream::iter(tables));

stream
}
Err(err) => Box::pin(stream::once(async { Err(err) })),
});

Ok(stream)
}
54 changes: 54 additions & 0 deletions src/daft-connect/src/convert/data_conversion.rs
@@ -0,0 +1,54 @@
//! Relation handling for Spark Connect protocol.
//!
//! A Relation represents a structured dataset or transformation in Spark Connect.
//! It can be either a base relation (direct data source) or derived relation
//! (result of operations on other relations).
//!
//! The protocol represents relations as trees of operations where:
//! - Each node is a Relation with metadata and an operation type
//! - Operations can reference other relations, forming a DAG
//! - The tree describes how to derive the final result
//!
//! Example flow for: SELECT age, COUNT(*) FROM employees WHERE dept='Eng' GROUP BY age
//!
//! ```text
//! Aggregate (grouping by age)
//! ↳ Filter (department = 'Engineering')
//! ↳ Read (employees table)
//! ```
//!
//! Relations abstract away:
//! - Physical storage details
//! - Distributed computation
//! - Query optimization
//! - Data source specifics
//!
//! This allows Spark to optimize and execute queries efficiently across a cluster
//! while providing a consistent API regardless of the underlying data source.
use eyre::{eyre, Context};
use spark_connect::{relation::RelType, Relation};
use tracing::trace;

use crate::{command::ConcreteDataChannel, convert::formatting::RelTypeExt};

mod range;
use range::range;

pub fn convert_data(plan: Relation, encoder: &mut impl ConcreteDataChannel) -> eyre::Result<()> {
// First check common fields if needed
if let Some(common) = &plan.common {
// contains metadata shared across all relation types
// Log or handle common fields if necessary
trace!("Processing relation with plan_id: {:?}", common.plan_id);
}

let rel_type = plan.rel_type.ok_or_else(|| eyre!("rel_type is None"))?;

match rel_type {
RelType::Range(input) => range(input, encoder).wrap_err("parsing Range"),
other => Err(eyre!("Unsupported top-level relation: {}", other.name())),
}
}
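To ground the module docs above: the relation tree the server receives is built implicitly by DataFrame calls on the client. A hedged sketch of the example query from the docs, assuming an existing Spark Connect session named `spark`; note that `convert_data` currently only handles `Range`, so this is purely illustrative of how the tree arises.

```python
# Hypothetical client-side sketch of: SELECT age, COUNT(*) FROM employees WHERE dept='Eng' GROUP BY age
from pyspark.sql import functions as F

employees = spark.read.table("employees")                 # Read (employees table)
engineering = employees.filter(F.col("dept") == "Eng")    # Filter (dept = 'Eng')
by_age = engineering.groupBy("age").agg(F.count("*"))     # Aggregate (grouping by age)
# by_age.collect() ships this relation tree to the server for execution.
```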
42 changes: 42 additions & 0 deletions src/daft-connect/src/convert/data_conversion/range.rs
@@ -0,0 +1,42 @@
use daft_core::prelude::Series;
use daft_schema::prelude::Schema;
use daft_table::Table;
use eyre::{ensure, Context};
use spark_connect::Range;

use crate::command::ConcreteDataChannel;

pub fn range(range: Range, channel: &mut impl ConcreteDataChannel) -> eyre::Result<()> {
let Range {
start,
end,
step,
num_partitions,
} = range;

let start = start.unwrap_or(0);

ensure!(num_partitions.is_none(), "num_partitions is not supported");

let step = usize::try_from(step).wrap_err("step must be a positive integer")?;
ensure!(step > 0, "step must be greater than 0");

let arrow_array: arrow2::array::Int64Array = (start..end).step_by(step).map(Some).collect();
let len = arrow_array.len();

let singleton_series = Series::try_from((
"range",
Box::new(arrow_array) as Box<dyn arrow2::array::Array>,
))
.wrap_err("creating singleton series")?;

let singleton_table = Table::new_with_size(
Schema::new(vec![singleton_series.field().clone()])?,
vec![singleton_series],
len,
)?;

channel.send_table(&singleton_table)?;

Ok(())
}
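An end-to-end sketch for the `Range` handler above, assuming a Spark Connect session `spark` pointed at the daft-connect server; the handler materializes a single Int64 column named "range" and streams it back as one Arrow batch.

```python
# Hypothetical sketch: exercise the Range handler via the Spark Connect client.
df = spark.range(start=0, end=10, step=2)  # sent to the server as a Range relation
print(df.collect())                        # server replies with an ArrowBatch then ResultComplete

# Per the ensure! above, an explicit partition count is expected to be rejected:
# spark.range(0, 10, 2, numPartitions=4)   # "num_partitions is not supported"
```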