From 23f09588394a9e5d542fb2293ee6ae3e7fef6ec9 Mon Sep 17 00:00:00 2001
From: universalmind303
Date: Fri, 17 Jan 2025 12:11:27 -0600
Subject: [PATCH] feat: make micropartition streamable over tables

---
 Cargo.lock                                    |   1 +
 src/daft-connect/src/execute.rs               |   7 +-
 src/daft-micropartition/Cargo.toml            |   1 +
 src/daft-micropartition/src/micropartition.rs | 114 +++++++++++++++++-
 src/daft-table/src/lib.rs                     |  32 +++--
 src/daft-table/src/ops/explode.rs             |   4 +-
 src/daft-table/src/ops/joins/hash_join.rs     |  14 ++-
 src/daft-table/src/ops/joins/mod.rs           |  10 +-
 src/daft-table/src/ops/unpivot.rs             |   9 +-
 9 files changed, 163 insertions(+), 29 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 1dc39d7bf9..a3993cc8b5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2407,6 +2407,7 @@ dependencies = [
  "parquet2",
  "pyo3",
  "snafu",
+ "tokio",
  "tracing",
 ]
diff --git a/src/daft-connect/src/execute.rs b/src/daft-connect/src/execute.rs
index 23caca66b9..f31e243170 100644
--- a/src/daft-connect/src/execute.rs
+++ b/src/daft-connect/src/execute.rs
@@ -111,9 +111,10 @@ impl Session {
         while let Some(result) = result_stream.next().await {
             let result = result?;
-            let tables = result.get_tables()?;
-            for table in tables.as_slice() {
-                let response = res.arrow_batch_response(table)?;
+            let mut tables_stream = result.into_stream()?;
+
+            while let Some(Ok(table)) = tables_stream.next().await {
+                let response = res.arrow_batch_response(&table)?;
                 if tx.send(Ok(response)).await.is_err() {
                     return Ok(());
                 }
             }
diff --git a/src/daft-micropartition/Cargo.toml b/src/daft-micropartition/Cargo.toml
index c39b0b5b78..b9bbd35ca7 100644
--- a/src/daft-micropartition/Cargo.toml
+++ b/src/daft-micropartition/Cargo.toml
@@ -20,6 +20,7 @@ futures = {workspace = true}
 parquet2 = {workspace = true}
 pyo3 = {workspace = true, optional = true}
 snafu = {workspace = true}
+tokio = {workspace = true}
 tracing = {workspace = true}

 [features]
diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs
index 652ce3e57c..0edc5ea144 100644
--- a/src/daft-micropartition/src/micropartition.rs
+++ b/src/daft-micropartition/src/micropartition.rs
@@ -1,11 +1,13 @@
 use std::{
     collections::{BTreeMap, HashMap, HashSet},
     fmt::Display,
+    pin::Pin,
     sync::{Arc, Mutex},
+    task::{Context, Poll},
 };

 use arrow2::io::parquet::read::schema::infer_schema_with_options;
-use common_error::DaftResult;
+use common_error::{DaftError, DaftResult};
 #[cfg(feature = "python")]
 use common_file_formats::DatabaseSourceConfig;
 use common_file_formats::{FileFormatConfig, ParquetSourceConfig};
@@ -22,6 +24,7 @@ use daft_parquet::read::{
 use daft_scan::{storage_config::StorageConfig, ChunkSpec, DataSource, ScanTask};
 use daft_stats::{PartitionSpec, TableMetadata, TableStatistics};
 use daft_table::Table;
+use futures::{Future, Stream};
 use parquet2::metadata::FileMetaData;
 use snafu::ResultExt;

@@ -1187,5 +1190,112 @@ impl Display for MicroPartition {
     }
 }
+struct MicroPartitionStreamAdapter {
+    state: TableState,
+    current: usize,
+    pending_task: Option<tokio::task::JoinHandle<DaftResult<Vec<Table>>>>,
+}
+
+impl Stream for MicroPartitionStreamAdapter {
+    type Item = DaftResult<Table>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let this = self.get_mut();
+
+        if let Some(handle) = &mut this.pending_task {
+            match Pin::new(handle).poll(cx) {
+                Poll::Ready(Ok(Ok(tables))) => {
+                    let tables = Arc::new(tables);
+                    this.state = TableState::Loaded(tables.clone());
+                    this.current = 1; // the first table is yielded below, so start the cursor at 1
+                    this.pending_task = None;
+                    return Poll::Ready(tables.first().cloned().map(Ok));
+                }
+                Poll::Ready(Ok(Err(e))) => return Poll::Ready(Some(Err(e))),
+                Poll::Ready(Err(e)) => {
+                    return Poll::Ready(Some(Err(DaftError::InternalError(e.to_string()))))
+                }
+                Poll::Pending => return Poll::Pending,
+            }
+        }
+
+        match &this.state {
+            // if the state is unloaded, we spawn a task to load the tables
+            // and set the state to loaded
+            TableState::Unloaded(scan_task) => {
+                let scan_task = scan_task.clone();
+                let handle = tokio::spawn(async move {
+                    materialize_scan_task(scan_task, None)
+                        .map(|(tables, _)| tables)
+                        .map_err(DaftError::from)
+                });
+                this.pending_task = Some(handle);
+                cx.waker().wake_by_ref();
+                Poll::Pending
+            }
+            TableState::Loaded(tables) => {
+                let current = this.current;
+                if current < tables.len() {
+                    this.current = current + 1;
+                    Poll::Ready(tables.get(current).cloned().map(Ok))
+                } else {
+                    Poll::Ready(None)
+                }
+            }
+        }
+    }
+}
+impl MicroPartition {
+    pub fn into_stream(self: Arc<Self>) -> DaftResult<impl Stream<Item = DaftResult<Table>>> {
+        let state = match &*self.state.lock().unwrap() {
+            TableState::Unloaded(scan_task) => TableState::Unloaded(scan_task.clone()),
+            TableState::Loaded(tables) => TableState::Loaded(tables.clone()),
+        };
+
+        Ok(MicroPartitionStreamAdapter {
+            state,
+            current: 0,
+            pending_task: None,
+        })
+    }
+}

 #[cfg(test)]
-mod test {}
+mod tests {
+    use std::sync::Arc;
+
+    use common_error::DaftResult;
+    use daft_core::{
+        datatypes::{DataType, Field, Int32Array},
+        prelude::Schema,
+        series::IntoSeries,
+    };
+    use daft_table::Table;
+    use futures::StreamExt;
+
+    use crate::MicroPartition;
+
+    #[tokio::test]
+    async fn test_mp_stream() -> DaftResult<()> {
+        let columns = vec![Int32Array::from_values("a", vec![1].into_iter()).into_series()];
+        let columns2 = vec![Int32Array::from_values("a", vec![2].into_iter()).into_series()];
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32)])?;
+
+        let table1 = Table::from_nonempty_columns(columns)?;
+        let table2 = Table::from_nonempty_columns(columns2)?;
+
+        let mp = MicroPartition::new_loaded(
+            Arc::new(schema),
+            Arc::new(vec![table1.clone(), table2.clone()]),
+            None,
+        );
+        let mp = Arc::new(mp);
+
+        let mut stream = mp.into_stream()?;
+        let tbl = stream.next().await.unwrap().unwrap();
+        assert_eq!(tbl, table1);
+        let tbl = stream.next().await.unwrap().unwrap();
+        assert_eq!(tbl, table2);
+        Ok(())
+    }
+}
diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs
index b06eae7c36..0bc7c1a4cc 100644
--- a/src/daft-table/src/lib.rs
+++ b/src/daft-table/src/lib.rs
@@ -7,6 +7,7 @@ use std::{
     collections::{HashMap, HashSet},
     fmt::{Display, Formatter, Result},
     hash::{Hash, Hasher},
+    sync::Arc,
 };

 use arrow2::array::Array;
@@ -44,14 +45,14 @@ use repr_html::html_value;
 #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
 pub struct Table {
     pub schema: SchemaRef,
-    columns: Vec<Series>,
+    columns: Arc<Vec<Series>>,
     num_rows: usize,
 }

 impl Hash for Table {
     fn hash<H: Hasher>(&self, state: &mut H) {
         self.schema.hash(state);
-        for col in &self.columns {
+        for col in &*self.columns {
             let hashes = col.hash(None).expect("Failed to hash column");
             hashes.into_iter().for_each(|h| h.hash(state));
         }
@@ -159,7 +160,7 @@
     ) -> Self {
         Self {
             schema: schema.into(),
-            columns,
+            columns: Arc::new(columns),
             num_rows,
         }
     }
@@ -181,7 +182,8 @@
     /// # Arguments
     ///
     /// * `columns` - Columns to create a table from as [`Series`] objects
-    pub fn from_nonempty_columns(columns: Vec<Series>) -> DaftResult<Self> {
+    pub fn from_nonempty_columns(columns: impl Into<Arc<Vec<Series>>>) -> DaftResult<Self> {
+        let columns = columns.into();
         assert!(!columns.is_empty(), "Cannot call Table::new() with empty columns. This indicates an internal error, please file an issue.");

         let schema = Schema::new(columns.iter().map(|s| s.field().clone()).collect())?;
@@ -199,7 +201,11 @@ impl Table {
             }
         }

-        Ok(Self::new_unchecked(schema, columns, num_rows))
+        Ok(Self {
+            schema,
+            columns,
+            num_rows,
+        })
     }

     pub fn num_columns(&self) -> usize {
@@ -231,11 +237,11 @@ impl Table {

     pub fn head(&self, num: usize) -> DaftResult<Self> {
         if num >= self.len() {
-            return Ok(Self::new_unchecked(
-                self.schema.clone(),
-                self.columns.clone(),
-                self.len(),
-            ));
+            return Ok(Self {
+                schema: self.schema.clone(),
+                columns: self.columns.clone(),
+                num_rows: self.len(),
+            });
         }
         self.slice(0, num)
     }
@@ -769,7 +775,7 @@ impl Table {
             // Begin row.
             res.push_str("<tr>");

-            for col in &self.columns {
+            for col in &*self.columns {
                 res.push_str(styled_td);
                 res.push_str(&html_value(col, i));
                 res.push_str("</div></td>");
@@ -781,7 +787,7 @@ impl Table {

         if tail_rows != 0 {
             res.push_str("<tr>");
-            for _ in &self.columns {
+            for _ in &*self.columns {
                 res.push_str("<td>...</td>");
             }
             res.push_str("</tr>\n");
@@ -791,7 +797,7 @@ impl Table {

             // Begin row.
             res.push_str("<tr>");
-            for col in &self.columns {
+            for col in &*self.columns {
                 res.push_str(styled_td);
                 res.push_str(&html_value(col, i));
                 res.push_str("</div></td>");
diff --git a/src/daft-table/src/ops/explode.rs b/src/daft-table/src/ops/explode.rs
index bdd715ac4a..cee00485ac 100644
--- a/src/daft-table/src/ops/explode.rs
+++ b/src/daft-table/src/ops/explode.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use common_error::{DaftError, DaftResult};
 use daft_core::{
     array::ops::as_arrow::AsArrow,
@@ -79,7 +81,7 @@ impl Table {
         let capacity_expected = exploded_columns.first().unwrap().len();
         let take_idx = lengths_to_indices(&first_len, capacity_expected)?.into_series();

-        let mut new_series = self.columns.clone();
+        let mut new_series = Arc::unwrap_or_clone(self.columns.clone());

         for i in 0..self.num_columns() {
             let name = new_series.get(i).unwrap().name();
diff --git a/src/daft-table/src/ops/joins/hash_join.rs b/src/daft-table/src/ops/joins/hash_join.rs
index 7f74666443..3a2dd08a52 100644
--- a/src/daft-table/src/ops/joins/hash_join.rs
+++ b/src/daft-table/src/ops/joins/hash_join.rs
@@ -1,4 +1,4 @@
-use std::{cmp, iter::repeat};
+use std::{cmp, iter::repeat, sync::Arc};

 use arrow2::{bitmap::MutableBitmap, types::IndexRange};
 use common_error::DaftResult;
@@ -89,11 +89,13 @@ pub(super) fn hash_inner_join(

     let common_join_keys: Vec<_> = get_common_join_keys(left_on, right_on).collect();

-    let mut join_series = left
+    let join_series = left
         .get_columns(common_join_keys.as_slice())?
         .take(&lidx)?
         .columns;

+    let mut join_series = Arc::unwrap_or_clone(join_series);
+
     drop(lkeys);
     drop(rkeys);

@@ -198,9 +200,11 @@ pub(super) fn hash_left_right_join(
     let common_join_keys = get_common_join_keys(left_on, right_on);

     let mut join_series = if left_side {
-        left.get_columns(common_join_keys.collect::<Vec<_>>().as_slice())?
-            .take(&lidx)?
-            .columns
+        Arc::unwrap_or_clone(
+            left.get_columns(common_join_keys.collect::<Vec<_>>().as_slice())?
+                .take(&lidx)?
+                .columns,
+        )
     } else {
         common_join_keys
             .map(|name| {
diff --git a/src/daft-table/src/ops/joins/mod.rs b/src/daft-table/src/ops/joins/mod.rs
index d7dd991e5b..5eb22b7cbd 100644
--- a/src/daft-table/src/ops/joins/mod.rs
+++ b/src/daft-table/src/ops/joins/mod.rs
@@ -1,4 +1,4 @@
-use std::collections::HashSet;
+use std::{collections::HashSet, sync::Arc};

 use common_error::{DaftError, DaftResult};
 use daft_core::{
@@ -216,7 +216,7 @@ impl Table {
             Table::concat(&vec![input; outer_len])
         }

-        let (left_table, mut right_table) = match outer_loop_side {
+        let (left_table, right_table) = match outer_loop_side {
             JoinSide::Left => (
                 create_outer_loop_table(self, right.len())?,
                 create_inner_loop_table(right, self.len())?,
@@ -230,8 +230,10 @@ impl Table {

         let num_rows = self.len() * right.len();
         let join_schema = self.schema.union(&right.schema)?;
-        let mut join_columns = left_table.columns;
-        join_columns.append(&mut right_table.columns);
+        let mut join_columns = Arc::unwrap_or_clone(left_table.columns);
+        let mut right_columns = Arc::unwrap_or_clone(right_table.columns);
+
+        join_columns.append(&mut right_columns);

         Self::new_with_size(join_schema, join_columns, num_rows)
     }
diff --git a/src/daft-table/src/ops/unpivot.rs b/src/daft-table/src/ops/unpivot.rs
index 5b43777417..e916e53e5a 100644
--- a/src/daft-table/src/ops/unpivot.rs
+++ b/src/daft-table/src/ops/unpivot.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use common_error::{DaftError, DaftResult};
 use daft_core::{prelude::*, series::cast_series_to_supertype};
 use daft_dsl::ExprRef;
@@ -52,7 +54,12 @@ impl Table {
             variable_series.field().clone(),
             value_series.field().clone(),
         ])?)?;
-        let unpivot_series = [ids_series, vec![variable_series, value_series]].concat();
+
+        let unpivot_series = [
+            Arc::unwrap_or_clone(ids_series),
+            vec![variable_series, value_series],
+        ]
+        .concat();

         Self::new_with_size(unpivot_schema, unpivot_series, unpivoted_len)
     }
...
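
Usage note: a minimal consumer of the new API could look like the sketch below. It is not part of the commit; it assumes a Tokio runtime, and it assumes `MicroPartition` is reachable as `daft_micropartition::MicroPartition` from outside the crate (the in-crate test above imports `crate::MicroPartition`). The constructors mirror those used in `test_mp_stream`.

use std::sync::Arc;

use common_error::DaftResult;
use daft_core::{
    datatypes::{DataType, Field, Int32Array},
    prelude::Schema,
    series::IntoSeries,
};
use daft_micropartition::MicroPartition; // assumed crate-root export
use daft_table::Table;
use futures::StreamExt;

#[tokio::main]
async fn main() -> DaftResult<()> {
    // Build an already-loaded MicroPartition holding one Int32 column "a".
    let schema = Schema::new(vec![Field::new("a", DataType::Int32)])?;
    let table = Table::from_nonempty_columns(vec![
        Int32Array::from_values("a", vec![1, 2, 3].into_iter()).into_series(),
    ])?;
    let mp = Arc::new(MicroPartition::new_loaded(
        Arc::new(schema),
        Arc::new(vec![table]),
        None,
    ));

    // into_stream() takes `self: Arc<Self>` and yields one DaftResult<Table>
    // per table; for an Unloaded partition, the first poll spawns the
    // materialization task instead of blocking the caller.
    let mut tables = mp.into_stream()?;
    while let Some(table) = tables.next().await {
        let table = table?;
        println!("got a table with {} rows", table.len());
    }
    Ok(())
}

Streaming tables out of a MicroPartition is what lets the daft-connect change above forward each Arrow batch as soon as it is materialized instead of blocking on `get_tables()`. The companion change to `columns: Arc<Vec<Series>>` keeps the per-table clones inside the adapter cheap (a reference-count bump), while `Arc::unwrap_or_clone` copies the column vector only when another handle is still alive.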