feat(core): make micropartition streamable over tables (#3709)
### Description

Makes `MicroPartition` streamable over `Table`, and makes `Table` cheaper to clone by wrapping its `columns` field in an `Arc`.


The driving force for these changes is that the existing logic for going from a micropartition to its tables does not work well within a streaming context, since the `get_tables` method returns `Arc<Vec<Table>>`. As a result, you can't easily chain streaming combinators when working with micropartitions and tables. A sketch of the pattern this unlocks follows.
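A minimal sketch of the consumption pattern this enables, assuming a workspace consumer with access to `daft_micropartition`, `futures`, and `common_error` (the `row_count` helper is hypothetical, not part of this commit):

```rust
use std::sync::Arc;

use common_error::DaftResult;
use daft_micropartition::MicroPartition;
use futures::TryStreamExt;

// Hypothetical consumer: sum row counts across a micropartition's tables,
// processing each table as the stream yields it.
async fn row_count(mp: Arc<MicroPartition>) -> DaftResult<usize> {
    mp.into_stream()?
        .try_fold(0usize, |acc, table| async move { Ok(acc + table.len()) })
        .await
}
```

Before this change, the same logic had to go through `get_tables`, which materializes every table behind an `Arc<Vec<Table>>` before any of them can be processed.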
universalmind303 authored Jan 21, 2025
1 parent bae106c commit 03fea9c
Showing 9 changed files with 163 additions and 29 deletions.
1 change: 1 addition & 0 deletions Cargo.lock


7 changes: 4 additions & 3 deletions src/daft-connect/src/execute.rs
@@ -102,9 +102,10 @@ impl Session {

while let Some(result) = result_stream.next().await {
let result = result?;
let tables = result.get_tables()?;
for table in tables.as_slice() {
let response = res.arrow_batch_response(table)?;
let mut tables_stream = result.into_stream()?;

while let Some(Ok(table)) = tables_stream.next().await {
let response = res.arrow_batch_response(&table)?;
if tx.send(Ok(response)).await.is_err() {
return Ok(());
}
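One subtlety in the new loop above: `while let Some(Ok(table))` ends the loop at the first `Err` yielded by the stream without surfacing the error. A sketch of a variant that propagates it instead (a hypothetical alternative reusing the surrounding `tables_stream`, `res`, and `tx`; not what this commit does):

```rust
while let Some(table) = tables_stream.next().await {
    let table = table?; // propagate the DaftError rather than silently stopping
    let response = res.arrow_batch_response(&table)?;
    if tx.send(Ok(response)).await.is_err() {
        return Ok(());
    }
}
```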
1 change: 1 addition & 0 deletions src/daft-micropartition/Cargo.toml
@@ -20,6 +20,7 @@ futures = {workspace = true}
parquet2 = {workspace = true}
pyo3 = {workspace = true, optional = true}
snafu = {workspace = true}
tokio = {workspace = true}
tracing = {workspace = true}

[features]
114 changes: 112 additions & 2 deletions src/daft-micropartition/src/micropartition.rs
@@ -1,11 +1,13 @@
use std::{
collections::{BTreeMap, HashMap, HashSet},
fmt::Display,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};

use arrow2::io::parquet::read::schema::infer_schema_with_options;
use common_error::DaftResult;
use common_error::{DaftError, DaftResult};
#[cfg(feature = "python")]
use common_file_formats::DatabaseSourceConfig;
use common_file_formats::{FileFormatConfig, ParquetSourceConfig};
@@ -22,6 +24,7 @@ use daft_parquet::read::{
use daft_scan::{storage_config::StorageConfig, ChunkSpec, DataSource, ScanTask};
use daft_stats::{PartitionSpec, TableMetadata, TableStatistics};
use daft_table::Table;
use futures::{Future, Stream};
use parquet2::metadata::FileMetaData;
use snafu::ResultExt;

@@ -1187,5 +1190,112 @@ impl Display for MicroPartition {
}
}

struct MicroPartitionStreamAdapter {
state: TableState,
current: usize,
pending_task: Option<tokio::task::JoinHandle<DaftResult<Vec<Table>>>>,
}

impl Stream for MicroPartitionStreamAdapter {
type Item = DaftResult<Table>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();

if let Some(handle) = &mut this.pending_task {
match Pin::new(handle).poll(cx) {
Poll::Ready(Ok(Ok(tables))) => {
let tables = Arc::new(tables);
this.state = TableState::Loaded(tables.clone());
this.current = 1; // index 0 is yielded below; resume from the second table
this.pending_task = None;
return Poll::Ready(tables.first().cloned().map(Ok));
}
Poll::Ready(Ok(Err(e))) => return Poll::Ready(Some(Err(e))),
Poll::Ready(Err(e)) => {
return Poll::Ready(Some(Err(DaftError::InternalError(e.to_string()))))
}
Poll::Pending => return Poll::Pending,
}
}

match &this.state {
// if the state is unloaded, we spawn a task to load the tables
// and set the state to loaded
TableState::Unloaded(scan_task) => {
let scan_task = scan_task.clone();
let handle = tokio::spawn(async move {
materialize_scan_task(scan_task, None)
.map(|(tables, _)| tables)
.map_err(DaftError::from)
});
this.pending_task = Some(handle);
cx.waker().wake_by_ref();
Poll::Pending
}
TableState::Loaded(tables) => {
let current = this.current;
if current < tables.len() {
this.current = current + 1;
Poll::Ready(tables.get(current).cloned().map(Ok))
} else {
Poll::Ready(None)
}
}
}
}
}
impl MicroPartition {
pub fn into_stream(self: Arc<Self>) -> DaftResult<impl Stream<Item = DaftResult<Table>>> {
let state = match &*self.state.lock().unwrap() {
TableState::Unloaded(scan_task) => TableState::Unloaded(scan_task.clone()),
TableState::Loaded(tables) => TableState::Loaded(tables.clone()),
};

Ok(MicroPartitionStreamAdapter {
state,
current: 0,
pending_task: None,
})
}
}

#[cfg(test)]
mod test {}
mod tests {
use std::sync::Arc;

use common_error::DaftResult;
use daft_core::{
datatypes::{DataType, Field, Int32Array},
prelude::Schema,
series::IntoSeries,
};
use daft_table::Table;
use futures::StreamExt;

use crate::MicroPartition;

#[tokio::test]
async fn test_mp_stream() -> DaftResult<()> {
let columns = vec![Int32Array::from_values("a", vec![1].into_iter()).into_series()];
let columns2 = vec![Int32Array::from_values("a", vec![2].into_iter()).into_series()];
let schema = Schema::new(vec![Field::new("a", DataType::Int32)])?;

let table1 = Table::from_nonempty_columns(columns)?;
let table2 = Table::from_nonempty_columns(columns2)?;

let mp = MicroPartition::new_loaded(
Arc::new(schema),
Arc::new(vec![table1.clone(), table2.clone()]),
None,
);
let mp = Arc::new(mp);

let mut stream = mp.into_stream()?;
let tbl = stream.next().await.unwrap().unwrap();
assert_eq!(tbl, table1);
let tbl = stream.next().await.unwrap().unwrap();
assert_eq!(tbl, table2);
Ok(())
}
}
32 changes: 19 additions & 13 deletions src/daft-table/src/lib.rs
@@ -7,6 +7,7 @@ use std::{
collections::{HashMap, HashSet},
fmt::{Display, Formatter, Result},
hash::{Hash, Hasher},
sync::Arc,
};

use arrow2::array::Array;
@@ -44,14 +45,14 @@ use repr_html::html_value;
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct Table {
pub schema: SchemaRef,
columns: Vec<Series>,
columns: Arc<Vec<Series>>,
num_rows: usize,
}

impl Hash for Table {
fn hash<H: Hasher>(&self, state: &mut H) {
self.schema.hash(state);
for col in &self.columns {
for col in &*self.columns {
let hashes = col.hash(None).expect("Failed to hash column");
hashes.into_iter().for_each(|h| h.hash(state));
}
@@ -159,7 +160,7 @@ impl Table {
) -> Self {
Self {
schema: schema.into(),
columns,
columns: Arc::new(columns),
num_rows,
}
}
@@ -181,7 +182,8 @@
/// # Arguments
///
/// * `columns` - Columns to create a table from as [`Series`] objects
pub fn from_nonempty_columns(columns: Vec<Series>) -> DaftResult<Self> {
pub fn from_nonempty_columns(columns: impl Into<Arc<Vec<Series>>>) -> DaftResult<Self> {
let columns = columns.into();
assert!(!columns.is_empty(), "Cannot call Table::new() with empty columns. This indicates an internal error, please file an issue.");

let schema = Schema::new(columns.iter().map(|s| s.field().clone()).collect())?;
@@ -199,7 +201,11 @@
}
}

Ok(Self::new_unchecked(schema, columns, num_rows))
Ok(Self {
schema,
columns,
num_rows,
})
}

pub fn num_columns(&self) -> usize {
@@ -231,11 +237,11 @@

pub fn head(&self, num: usize) -> DaftResult<Self> {
if num >= self.len() {
return Ok(Self::new_unchecked(
self.schema.clone(),
self.columns.clone(),
self.len(),
));
return Ok(Self {
schema: self.schema.clone(),
columns: self.columns.clone(),
num_rows: self.len(),
});
}
self.slice(0, num)
}
@@ -769,7 +775,7 @@ impl Table {
// Begin row.
res.push_str("<tr>");

for col in &self.columns {
for col in &*self.columns {
res.push_str(styled_td);
res.push_str(&html_value(col, i));
res.push_str("</div></td>");
@@ -781,7 +787,7 @@

if tail_rows != 0 {
res.push_str("<tr>");
for _ in &self.columns {
for _ in &*self.columns {
res.push_str("<td>...</td>");
}
res.push_str("</tr>\n");
@@ -791,7 +797,7 @@
// Begin row.
res.push_str("<tr>");

for col in &self.columns {
for col in &*self.columns {
res.push_str(styled_td);
res.push_str(&html_value(col, i));
res.push_str("</td>");
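For context on the cheap-clone claim, a toy sketch with stand-in types (not the crate's real API): after this change, cloning a `Table` bumps one reference count instead of deep-copying every column.

```rust
use std::sync::Arc;

// Toy stand-ins for the real Series/Table, just to illustrate clone cost.
#[derive(Clone)]
struct Series(Vec<i64>);

#[derive(Clone)]
struct Table {
    // Arc-wrapped: Table::clone shares the column storage.
    columns: Arc<Vec<Series>>,
}

fn main() {
    let t = Table {
        columns: Arc::new(vec![Series(vec![1, 2, 3]); 1_000]),
    };
    let t2 = t.clone(); // O(1): no per-column copies
    assert!(Arc::ptr_eq(&t.columns, &t2.columns));
}
```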
4 changes: 3 additions & 1 deletion src/daft-table/src/ops/explode.rs
@@ -1,3 +1,5 @@
use std::sync::Arc;

use common_error::{DaftError, DaftResult};
use daft_core::{
array::ops::as_arrow::AsArrow,
@@ -79,7 +81,7 @@ impl Table {
let capacity_expected = exploded_columns.first().unwrap().len();
let take_idx = lengths_to_indices(&first_len, capacity_expected)?.into_series();

let mut new_series = self.columns.clone();
let mut new_series = Arc::unwrap_or_clone(self.columns.clone());

for i in 0..self.num_columns() {
let name = new_series.get(i).unwrap().name();
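This and the following diffs use `Arc::unwrap_or_clone` to recover an owned `Vec<Series>`: it moves the inner value out when the `Arc` has a single owner and clones only when the data is shared. A standalone illustration (plain std behavior, independent of this codebase):

```rust
use std::sync::Arc;

fn main() {
    // Sole owner: the inner Vec is moved out, nothing is copied.
    let sole = Arc::new(vec![1, 2, 3]);
    assert_eq!(Arc::unwrap_or_clone(sole), vec![1, 2, 3]);

    // Shared: the inner Vec is cloned and the other handle stays valid.
    let shared = Arc::new(vec![4, 5, 6]);
    let other = Arc::clone(&shared);
    assert_eq!(Arc::unwrap_or_clone(shared), *other);
}
```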
14 changes: 9 additions & 5 deletions src/daft-table/src/ops/joins/hash_join.rs
@@ -1,4 +1,4 @@
use std::{cmp, iter::repeat};
use std::{cmp, iter::repeat, sync::Arc};

use arrow2::{bitmap::MutableBitmap, types::IndexRange};
use common_error::DaftResult;
@@ -89,11 +89,13 @@ pub(super) fn hash_inner_join(

let common_join_keys: Vec<_> = get_common_join_keys(left_on, right_on).collect();

let mut join_series = left
let join_series = left
.get_columns(common_join_keys.as_slice())?
.take(&lidx)?
.columns;

let mut join_series = Arc::unwrap_or_clone(join_series);

drop(lkeys);
drop(rkeys);

@@ -198,9 +200,11 @@ pub(super) fn hash_left_right_join(
let common_join_keys = get_common_join_keys(left_on, right_on);

let mut join_series = if left_side {
left.get_columns(common_join_keys.collect::<Vec<_>>().as_slice())?
.take(&lidx)?
.columns
Arc::unwrap_or_clone(
left.get_columns(common_join_keys.collect::<Vec<_>>().as_slice())?
.take(&lidx)?
.columns,
)
} else {
common_join_keys
.map(|name| {
10 changes: 6 additions & 4 deletions src/daft-table/src/ops/joins/mod.rs
@@ -1,4 +1,4 @@
use std::collections::HashSet;
use std::{collections::HashSet, sync::Arc};

use common_error::{DaftError, DaftResult};
use daft_core::{
@@ -216,7 +216,7 @@ impl Table {
Table::concat(&vec![input; outer_len])
}

let (left_table, mut right_table) = match outer_loop_side {
let (left_table, right_table) = match outer_loop_side {
JoinSide::Left => (
create_outer_loop_table(self, right.len())?,
create_inner_loop_table(right, self.len())?,
@@ -230,8 +230,10 @@
let num_rows = self.len() * right.len();

let join_schema = self.schema.union(&right.schema)?;
let mut join_columns = left_table.columns;
join_columns.append(&mut right_table.columns);
let mut join_columns = Arc::unwrap_or_clone(left_table.columns);
let mut right_columns = Arc::unwrap_or_clone(right_table.columns);

join_columns.append(&mut right_columns);

Self::new_with_size(join_schema, join_columns, num_rows)
}
9 changes: 8 additions & 1 deletion src/daft-table/src/ops/unpivot.rs
@@ -1,3 +1,5 @@
use std::sync::Arc;

use common_error::{DaftError, DaftResult};
use daft_core::{prelude::*, series::cast_series_to_supertype};
use daft_dsl::ExprRef;
@@ -52,7 +54,12 @@ impl Table {
variable_series.field().clone(),
value_series.field().clone(),
])?)?;
let unpivot_series = [ids_series, vec![variable_series, value_series]].concat();

let unpivot_series = [
Arc::unwrap_or_clone(ids_series),
vec![variable_series, value_series],
]
.concat();

Self::new_with_size(unpivot_schema, unpivot_series, unpivoted_len)
}