feat(core): make micropartition streamable over tables #3709

Merged · 1 commit · Jan 21, 2025
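
This PR makes a MicroPartition consumable as an async stream of `Table`s via a new `into_stream` method: already-materialized tables are yielded one at a time, and an unloaded scan task is materialized on a background tokio task the first time the stream is polled. To keep the yielded tables cheap to hand around, `Table`'s column storage also changes from `Vec<Series>` to `Arc<Vec<Series>>`. A minimal consumer sketch, assuming this repo's crates (`daft_micropartition`, `daft_table`, `daft_core`) and mirroring the unit test added in this PR:

use std::sync::Arc;

use common_error::DaftResult;
use daft_core::{
    datatypes::{DataType, Field, Int32Array},
    prelude::Schema,
    series::IntoSeries,
};
use daft_micropartition::MicroPartition;
use daft_table::Table;
use futures::StreamExt;

#[tokio::main]
async fn main() -> DaftResult<()> {
    let columns = vec![Int32Array::from_values("a", vec![1].into_iter()).into_series()];
    let schema = Schema::new(vec![Field::new("a", DataType::Int32)])?;
    let table = Table::from_nonempty_columns(columns)?;

    // A micropartition whose tables are already loaded in memory.
    let mp = Arc::new(MicroPartition::new_loaded(
        Arc::new(schema),
        Arc::new(vec![table]),
        None,
    ));

    // Consume it one Table at a time instead of collecting every table up front.
    let mut stream = mp.into_stream()?;
    while let Some(table) = stream.next().await {
        let table = table?;
        println!("{} rows", table.len());
    }
    Ok(())
}
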
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

7 changes: 4 additions & 3 deletions src/daft-connect/src/execute.rs
@@ -111,9 +111,10 @@ impl Session {
 
         while let Some(result) = result_stream.next().await {
             let result = result?;
-            let tables = result.get_tables()?;
-            for table in tables.as_slice() {
-                let response = res.arrow_batch_response(table)?;
+            let mut tables_stream = result.into_stream()?;
+
+            while let Some(Ok(table)) = tables_stream.next().await {
+                let response = res.arrow_batch_response(&table)?;
                 if tx.send(Ok(response)).await.is_err() {
                     return Ok(());
                 }
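
One behavioral note: `while let Some(Ok(table))` ends the send loop silently on the first `Err` yielded by the stream, whereas the previous `result.get_tables()?` propagated that error to the caller.
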
1 change: 1 addition & 0 deletions src/daft-micropartition/Cargo.toml
@@ -20,6 +20,7 @@ futures = {workspace = true}
 parquet2 = {workspace = true}
 pyo3 = {workspace = true, optional = true}
 snafu = {workspace = true}
+tokio = {workspace = true}
 tracing = {workspace = true}
 
 [features]
114 changes: 112 additions & 2 deletions src/daft-micropartition/src/micropartition.rs
@@ -1,11 +1,13 @@
 use std::{
     collections::{BTreeMap, HashMap, HashSet},
     fmt::Display,
+    pin::Pin,
     sync::{Arc, Mutex},
+    task::{Context, Poll},
 };
 
 use arrow2::io::parquet::read::schema::infer_schema_with_options;
-use common_error::DaftResult;
+use common_error::{DaftError, DaftResult};
 #[cfg(feature = "python")]
 use common_file_formats::DatabaseSourceConfig;
 use common_file_formats::{FileFormatConfig, ParquetSourceConfig};
@@ -22,6 +24,7 @@
 use daft_scan::{storage_config::StorageConfig, ChunkSpec, DataSource, ScanTask};
 use daft_stats::{PartitionSpec, TableMetadata, TableStatistics};
 use daft_table::Table;
+use futures::{Future, Stream};
 use parquet2::metadata::FileMetaData;
 use snafu::ResultExt;
@@ -1187,5 +1190,112 @@
     }
 }
 
+struct MicroPartitionStreamAdapter {
+    state: TableState,
+    current: usize,
+    pending_task: Option<tokio::task::JoinHandle<DaftResult<Vec<Table>>>>,
+}
+
+impl Stream for MicroPartitionStreamAdapter {
+    type Item = DaftResult<Table>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let this = self.get_mut();
+
+        if let Some(handle) = &mut this.pending_task {
+            match Pin::new(handle).poll(cx) {
+                Poll::Ready(Ok(Ok(tables))) => {
+                    let tables = Arc::new(tables);
+                    this.state = TableState::Loaded(tables.clone());
+                    this.current = 1; // index 0 is yielded right below
+                    this.pending_task = None;
+                    return Poll::Ready(tables.first().cloned().map(Ok));
+                }
+                Poll::Ready(Ok(Err(e))) => return Poll::Ready(Some(Err(e))),
+                Poll::Ready(Err(e)) => {
+                    return Poll::Ready(Some(Err(DaftError::InternalError(e.to_string()))))
+                }
+                Poll::Pending => return Poll::Pending,
+            }
+        }
+
+        match &this.state {
+            // if the state is unloaded, we spawn a task to load the tables
+            // and set the state to loaded
+            TableState::Unloaded(scan_task) => {
+                let scan_task = scan_task.clone();
+                let handle = tokio::spawn(async move {
+                    materialize_scan_task(scan_task, None)
+                        .map(|(tables, _)| tables)
+                        .map_err(DaftError::from)
+                });
+                this.pending_task = Some(handle);
+                cx.waker().wake_by_ref();
+                Poll::Pending
+            }
+            TableState::Loaded(tables) => {
+                let current = this.current;
+                if current < tables.len() {
+                    this.current = current + 1;
+                    Poll::Ready(tables.get(current).cloned().map(Ok))
+                } else {
+                    Poll::Ready(None)
+                }
+            }
+        }
+    }
+}
+impl MicroPartition {
+    pub fn into_stream(self: Arc<Self>) -> DaftResult<impl Stream<Item = DaftResult<Table>>> {
+        let state = match &*self.state.lock().unwrap() {
+            TableState::Unloaded(scan_task) => TableState::Unloaded(scan_task.clone()),
+            TableState::Loaded(tables) => TableState::Loaded(tables.clone()),
+        };
+
+        Ok(MicroPartitionStreamAdapter {
+            state,
+            current: 0,
+            pending_task: None,
+        })
+    }
+}
+
 #[cfg(test)]
-mod test {}
+mod tests {
+    use std::sync::Arc;
+
+    use common_error::DaftResult;
+    use daft_core::{
+        datatypes::{DataType, Field, Int32Array},
+        prelude::Schema,
+        series::IntoSeries,
+    };
+    use daft_table::Table;
+    use futures::StreamExt;
+
+    use crate::MicroPartition;
+
+    #[tokio::test]
+    async fn test_mp_stream() -> DaftResult<()> {
+        let columns = vec![Int32Array::from_values("a", vec![1].into_iter()).into_series()];
+        let columns2 = vec![Int32Array::from_values("a", vec![2].into_iter()).into_series()];
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32)])?;
+
+        let table1 = Table::from_nonempty_columns(columns)?;
+        let table2 = Table::from_nonempty_columns(columns2)?;
+
+        let mp = MicroPartition::new_loaded(
+            Arc::new(schema),
+            Arc::new(vec![table1.clone(), table2.clone()]),
+            None,
+        );
+        let mp = Arc::new(mp);
+
+        let mut stream = mp.into_stream()?;
+        let tbl = stream.next().await.unwrap().unwrap();
+        assert_eq!(tbl, table1);
+        let tbl = stream.next().await.unwrap().unwrap();
+        assert_eq!(tbl, table2);
+        Ok(())
+    }
+}
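
The adapter is a hand-rolled state machine: poll_next first drives any in-flight materialization task to completion, then serves tables by index from the loaded state. The same pattern in miniature, detached from Daft's types (every name below is illustrative, not part of this PR):

use std::{
    pin::Pin,
    task::{Context, Poll},
};

use futures::{Future, Stream, StreamExt};

// Lazily "load" a Vec on first poll, then yield one item per poll.
struct LazyVecStream {
    items: Option<Vec<i32>>,
    current: usize,
    pending: Option<tokio::task::JoinHandle<Vec<i32>>>,
}

impl Stream for LazyVecStream {
    type Item = i32;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<i32>> {
        let this = self.get_mut();

        // Drive the background load to completion first, if one is in flight.
        if let Some(handle) = &mut this.pending {
            match Pin::new(handle).poll(cx) {
                Poll::Ready(Ok(items)) => {
                    this.items = Some(items);
                    this.pending = None;
                }
                Poll::Ready(Err(_join_error)) => return Poll::Ready(None),
                Poll::Pending => return Poll::Pending,
            }
        }

        match &this.items {
            // Not loaded yet: kick off the load and ask to be polled again.
            None => {
                this.pending = Some(tokio::spawn(async { vec![1, 2, 3] }));
                cx.waker().wake_by_ref();
                Poll::Pending
            }
            // Loaded: yield one element per poll until exhausted.
            Some(items) => {
                let i = this.current;
                this.current += 1;
                Poll::Ready(items.get(i).copied())
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let mut stream = LazyVecStream { items: None, current: 0, pending: None };
    while let Some(x) = stream.next().await {
        println!("{x}");
    }
}

This sketch falls through to the serving branch once the load finishes; the PR's adapter instead returns the first table directly from the completed task, which is why its index must start past that table.
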
32 changes: 19 additions & 13 deletions src/daft-table/src/lib.rs
@@ -7,6 +7,7 @@
     collections::{HashMap, HashSet},
     fmt::{Display, Formatter, Result},
     hash::{Hash, Hasher},
+    sync::Arc,
 };
 
 use arrow2::array::Array;
@@ -44,14 +45,14 @@
 #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
 pub struct Table {
     pub schema: SchemaRef,
-    columns: Vec<Series>,
+    columns: Arc<Vec<Series>>,
     num_rows: usize,
 }
 
 impl Hash for Table {
     fn hash<H: Hasher>(&self, state: &mut H) {
         self.schema.hash(state);
-        for col in &self.columns {
+        for col in &*self.columns {
             let hashes = col.hash(None).expect("Failed to hash column");
             hashes.into_iter().for_each(|h| h.hash(state));
         }
@@ -159,7 +160,7 @@
     ) -> Self {
         Self {
             schema: schema.into(),
-            columns,
+            columns: Arc::new(columns),
             num_rows,
         }
     }
@@ -181,7 +182,8 @@
     /// # Arguments
     ///
     /// * `columns` - Columns to create a table from as [`Series`] objects
-    pub fn from_nonempty_columns(columns: Vec<Series>) -> DaftResult<Self> {
+    pub fn from_nonempty_columns(columns: impl Into<Arc<Vec<Series>>>) -> DaftResult<Self> {
+        let columns = columns.into();
         assert!(!columns.is_empty(), "Cannot call Table::new() with empty columns. This indicates an internal error, please file an issue.");
 
         let schema = Schema::new(columns.iter().map(|s| s.field().clone()).collect())?;
@@ -199,7 +201,11 @@
             }
         }
 
-        Ok(Self::new_unchecked(schema, columns, num_rows))
+        Ok(Self {
+            schema,
+            columns,
+            num_rows,
+        })
     }
 
     pub fn num_columns(&self) -> usize {
@@ -231,11 +237,11 @@
 
     pub fn head(&self, num: usize) -> DaftResult<Self> {
         if num >= self.len() {
-            return Ok(Self::new_unchecked(
-                self.schema.clone(),
-                self.columns.clone(),
-                self.len(),
-            ));
+            return Ok(Self {
+                schema: self.schema.clone(),
+                columns: self.columns.clone(),
+                num_rows: self.len(),
+            });
         }
         self.slice(0, num)
     }
@@ -769,7 +775,7 @@
             // Begin row.
             res.push_str("<tr>");
 
-            for col in &self.columns {
+            for col in &*self.columns {
                 res.push_str(styled_td);
                 res.push_str(&html_value(col, i));
                 res.push_str("</div></td>");
@@ -781,7 +787,7 @@
 
         if tail_rows != 0 {
             res.push_str("<tr>");
-            for _ in &self.columns {
+            for _ in &*self.columns {
                 res.push_str("<td>...</td>");
             }
             res.push_str("</tr>\n");
@@ -791,7 +797,7 @@
             // Begin row.
             res.push_str("<tr>");
 
-            for col in &self.columns {
+            for col in &*self.columns {
                 res.push_str(styled_td);
                 res.push_str(&html_value(col, i));
                 res.push_str("</td>");
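
With columns behind an Arc, cloning a Table no longer deep-copies its Series. The ops below (explode, the joins, unpivot) need an owned, mutable column list, so they take it back out with `Arc::unwrap_or_clone`: the Vec is moved out without copying when the Arc is uniquely owned, and cloned only when other references exist. A standalone sketch of that semantics (illustrative values only):

use std::sync::Arc;

fn main() {
    // Unique reference: the Vec is moved out, nothing is cloned.
    let sole = Arc::new(vec![1, 2, 3]);
    let owned: Vec<i32> = Arc::unwrap_or_clone(sole);
    assert_eq!(owned, vec![1, 2, 3]);

    // Shared reference: unwrapping fails, so the contents are cloned instead.
    let shared = Arc::new(vec![4, 5]);
    let other = Arc::clone(&shared);
    let cloned: Vec<i32> = Arc::unwrap_or_clone(shared);
    assert_eq!(cloned, *other); // `other` still holds the original allocation
}
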
4 changes: 3 additions & 1 deletion src/daft-table/src/ops/explode.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use common_error::{DaftError, DaftResult};
 use daft_core::{
     array::ops::as_arrow::AsArrow,
@@ -79,7 +81,7 @@ impl Table {
         let capacity_expected = exploded_columns.first().unwrap().len();
         let take_idx = lengths_to_indices(&first_len, capacity_expected)?.into_series();
 
-        let mut new_series = self.columns.clone();
+        let mut new_series = Arc::unwrap_or_clone(self.columns.clone());
 
        for i in 0..self.num_columns() {
            let name = new_series.get(i).unwrap().name();
14 changes: 9 additions & 5 deletions src/daft-table/src/ops/joins/hash_join.rs
@@ -1,4 +1,4 @@
-use std::{cmp, iter::repeat};
+use std::{cmp, iter::repeat, sync::Arc};
 
 use arrow2::{bitmap::MutableBitmap, types::IndexRange};
 use common_error::DaftResult;
@@ -89,11 +89,13 @@ pub(super) fn hash_inner_join(
 
     let common_join_keys: Vec<_> = get_common_join_keys(left_on, right_on).collect();
 
-    let mut join_series = left
+    let join_series = left
         .get_columns(common_join_keys.as_slice())?
         .take(&lidx)?
         .columns;
 
+    let mut join_series = Arc::unwrap_or_clone(join_series);
+
     drop(lkeys);
     drop(rkeys);
 
@@ -198,9 +200,11 @@ pub(super) fn hash_left_right_join(
     let common_join_keys = get_common_join_keys(left_on, right_on);
 
     let mut join_series = if left_side {
-        left.get_columns(common_join_keys.collect::<Vec<_>>().as_slice())?
-            .take(&lidx)?
-            .columns
+        Arc::unwrap_or_clone(
+            left.get_columns(common_join_keys.collect::<Vec<_>>().as_slice())?
+                .take(&lidx)?
+                .columns,
+        )
     } else {
         common_join_keys
             .map(|name| {
10 changes: 6 additions & 4 deletions src/daft-table/src/ops/joins/mod.rs
@@ -1,4 +1,4 @@
-use std::collections::HashSet;
+use std::{collections::HashSet, sync::Arc};
 
 use common_error::{DaftError, DaftResult};
 use daft_core::{
@@ -216,7 +216,7 @@ impl Table {
             Table::concat(&vec![input; outer_len])
         }
 
-        let (left_table, mut right_table) = match outer_loop_side {
+        let (left_table, right_table) = match outer_loop_side {
            JoinSide::Left => (
                create_outer_loop_table(self, right.len())?,
                create_inner_loop_table(right, self.len())?,
@@ -230,8 +230,10 @@
         let num_rows = self.len() * right.len();
 
         let join_schema = self.schema.union(&right.schema)?;
-        let mut join_columns = left_table.columns;
-        join_columns.append(&mut right_table.columns);
+        let mut join_columns = Arc::unwrap_or_clone(left_table.columns);
+        let mut right_columns = Arc::unwrap_or_clone(right_table.columns);
+
+        join_columns.append(&mut right_columns);
 
         Self::new_with_size(join_schema, join_columns, num_rows)
     }
9 changes: 8 additions & 1 deletion src/daft-table/src/ops/unpivot.rs
@@ -1,3 +1,5 @@
+use std::sync::Arc;
+
 use common_error::{DaftError, DaftResult};
 use daft_core::{prelude::*, series::cast_series_to_supertype};
 use daft_dsl::ExprRef;
@@ -52,7 +54,12 @@ impl Table {
             variable_series.field().clone(),
             value_series.field().clone(),
         ])?)?;
-        let unpivot_series = [ids_series, vec![variable_series, value_series]].concat();
+
+        let unpivot_series = [
+            Arc::unwrap_or_clone(ids_series),
+            vec![variable_series, value_series],
+        ]
+        .concat();
 
         Self::new_with_size(unpivot_schema, unpivot_series, unpivoted_len)
     }