Skip to content

Commit

Permalink
test passes yay
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 7, 2024
1 parent 0b5bb26 commit fe92b98
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 70 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ comfy-table = "7.1.1"
common-daft-config = {path = "src/common/daft-config"}
common-display = {path = "src/common/display"}
common-error = {path = "src/common/error", default-features = false}
common-file-formats = {path = "src/common/file-formats"}
daft-connect = {path = "src/daft-connect", default-features = false}
daft-core = {path = "src/daft-core"}
daft-dsl = {path = "src/daft-dsl"}
Expand Down
3 changes: 2 additions & 1 deletion src/daft-connect/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ dashmap = "6.1.0"
eyre = "0.6.12"
futures = "0.3.31"
pyo3 = {workspace = true, optional = true}
#ron = "0.9.0-alpha.0"
# ron = "0.9.0-alpha.0"
tokio = {version = "1.40.0", features = ["full"]}
tokio-stream = "0.1.16"
tonic = "0.12.3"
Expand All @@ -13,6 +13,7 @@ tracing-tracy = "0.11.3"
uuid = {version = "1.10.0", features = ["v4"]}
arrow2.workspace = true
common-daft-config.workspace = true
common-file-formats.workspace = true
daft-core.workspace = true
daft-dsl.workspace = true
daft-local-execution.workspace = true
Expand Down
137 changes: 132 additions & 5 deletions src/daft-connect/src/command.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
// Stream of Result<ExecutePlanResponse, Status>

use std::thread;
use std::{ops::ControlFlow, thread};

use arrow2::io::ipc::write::StreamWriter;
use common_file_formats::FileFormat;
use daft_table::Table;
use eyre::Context;
use futures::TryStreamExt;
use spark_connect::{
execute_plan_response::{ArrowBatch, ResponseType, ResultComplete},
spark_connect_service_server::SparkConnectService,
ExecutePlanResponse, Relation,
write_operation::{SaveMode, SaveType},
ExecutePlanResponse, Relation, WriteOperation,
};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tonic::Status;
use uuid::Uuid;

use crate::{convert::convert_data, DaftSparkConnectService, Session};
use crate::{
convert::{convert_data, run_local, to_logical_plan},
invalid_argument, unimplemented_err, DaftSparkConnectService, Session,
};

type DaftStream = <DaftSparkConnectService as SparkConnectService>::ExecutePlanStream;

Expand Down Expand Up @@ -89,14 +94,14 @@ impl Session {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

let mut channel = ExecutablePlanChannel {
session_id: self.id().to_string(),
session_id: self.client_side_session_id().to_string(),
server_side_session_id: self.server_side_session_id().to_string(),
operation_id: operation_id.clone(),
tx: tx.clone(),
};

thread::spawn({
let session_id = self.id().to_string();
let session_id = self.client_side_session_id().to_string();
let server_side_session_id = self.server_side_session_id().to_string();
move || {
let result = convert_data(command, &mut channel);
Expand Down Expand Up @@ -125,4 +130,126 @@ impl Session {

Ok(Box::pin(recv_stream))
}

/// Executes a Spark Connect `WriteOperation` against the local engine.
///
/// Only a narrow slice of the protocol is supported for now: writing
/// Parquet to a filesystem path with the default (`Unspecified`) save
/// mode. Every other knob in the proto — non-default save modes, sort
/// columns, partitioning, bucketing, writer options, clustering, and
/// table saves — is rejected with `unimplemented_err!` so clients get an
/// explicit error rather than silently-ignored settings.
///
/// On success the write has already completed, and the returned stream
/// yields a single `ResultComplete` response.
///
/// # Errors
/// * `invalid_argument` when `input` or `save_type` is missing.
/// * `unimplemented` for any unsupported option (see above).
/// * `internal` when planning or executing the write fails.
pub fn handle_write_operation(
    &self,
    operation: WriteOperation,
    operation_id: String,
) -> Result<DaftStream, Status> {
    // `mode()` decodes the raw i32 enum field; read it before destructuring.
    let mode = operation.mode();

    // Destructure exhaustively so that a new field added to the proto
    // becomes a compile error here instead of being silently ignored.
    let WriteOperation {
        input,
        source,
        sort_column_names,
        partitioning_columns,
        bucket_by,
        options,
        clustering_columns,
        save_type,
        mode: _, // already consumed via `operation.mode()` above
    } = operation;

    let input = input.ok_or_else(|| invalid_argument!("input is None"))?;

    // Spark defaults the source format to parquet when the client leaves it unset.
    let source = source.unwrap_or_else(|| "parquet".to_string());
    if source != "parquet" {
        return Err(unimplemented_err!(
            "Only writing parquet is supported for now but got {source}"
        ));
    }

    match mode {
        SaveMode::Unspecified => {}
        SaveMode::Append => {
            return Err(unimplemented_err!("Append mode is not yet supported"));
        }
        SaveMode::Overwrite => {
            return Err(unimplemented_err!("Overwrite mode is not yet supported"));
        }
        SaveMode::ErrorIfExists => {
            return Err(unimplemented_err!(
                "ErrorIfExists mode is not yet supported"
            ));
        }
        SaveMode::Ignore => {
            return Err(unimplemented_err!("Ignore mode is not yet supported"));
        }
    }

    if !sort_column_names.is_empty() {
        return Err(unimplemented_err!("Sort by columns is not yet supported"));
    }

    if !partitioning_columns.is_empty() {
        return Err(unimplemented_err!(
            "Partitioning columns is not yet supported"
        ));
    }

    if bucket_by.is_some() {
        return Err(unimplemented_err!("Bucket by columns is not yet supported"));
    }

    if !options.is_empty() {
        return Err(unimplemented_err!("Options are not yet supported"));
    }

    if !clustering_columns.is_empty() {
        return Err(unimplemented_err!(
            "Clustering columns is not yet supported"
        ));
    }

    let save_type = save_type.ok_or_else(|| invalid_argument!("save_type is required"))?;

    let save_path = match save_type {
        SaveType::Path(path) => path,
        SaveType::Table(_) => {
            return Err(unimplemented_err!("Save type table is not yet supported"));
        }
    };

    // Run the write on a dedicated scoped thread: `run_local` drives the plan
    // to completion synchronously, and doing that directly inside the async
    // gRPC handler would block the runtime's worker thread. The scope lets
    // the worker borrow `save_path` without requiring 'static captures.
    std::thread::scope(|scope| {
        let res = scope.spawn(|| {
            // Preserve the underlying error text — clients only ever see
            // this `Status`, so dropping the cause makes failures opaque.
            let plan = to_logical_plan(input).map_err(|e| {
                Status::internal(format!("Failed to convert to logical plan: {e}"))
            })?;

            // todo: assuming this is parquet
            // todo: is save_path right?
            let plan = plan
                .table_write(&save_path, FileFormat::Parquet, None, None, None)
                .map_err(|e| Status::internal(format!("Failed to write table: {e}")))?;

            let plan = plan.build();

            // A write produces no rows to forward to the client, so every
            // result table is simply drained and discarded.
            run_local(
                &plan,
                |_table| ControlFlow::Continue(()),
                || ControlFlow::Break(()),
            )
            .map_err(|e| Status::internal(format!("Failed to write table: {e}")))?;

            Result::<(), Status>::Ok(())
        });

        // `join` only fails if the worker panicked; surface that as an
        // internal error instead of re-panicking inside the handler.
        res.join()
            .map_err(|_| Status::internal("Write worker thread panicked"))?
    })?;

    let session_id = self.client_side_session_id().to_string();
    let server_side_session_id = self.server_side_session_id().to_string();

    Ok(Box::pin(futures::stream::once(async move {
        Ok(ExecutePlanResponse {
            session_id,
            server_side_session_id,
            operation_id,
            // NOTE(review): placeholder response id — presumably this should
            // be unique per response (e.g. a UUID); confirm against the Spark
            // Connect protocol expectations.
            response_id: "abcxyz".to_string(),
            metrics: None,
            observed_metrics: vec![],
            schema: None,
            response_type: Some(ResponseType::ResultComplete(ResultComplete {})),
        })
    })))
}
}
21 changes: 12 additions & 9 deletions src/daft-connect/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::Session;
impl Session {
fn config_response(&self) -> ConfigResponse {
ConfigResponse {
session_id: self.id().to_string(),
session_id: self.client_side_session_id().to_string(),
server_side_session_id: self.server_side_session_id().to_string(),
pairs: vec![],
warnings: vec![],
Expand All @@ -21,7 +21,8 @@ impl Session {
pub fn set(&mut self, operation: Set) -> Result<ConfigResponse, Status> {
let mut response = self.config_response();

let span = tracing::info_span!("set", session_id = %self.id(), ?operation);
let span =
tracing::info_span!("set", session_id = %self.client_side_session_id(), ?operation);
let _enter = span.enter();

for KeyValue { key, value } in operation.pairs {
Expand All @@ -45,7 +46,7 @@ impl Session {
pub fn get(&self, operation: Get) -> Result<ConfigResponse, Status> {
let mut response = self.config_response();

let span = tracing::info_span!("get", session_id = %self.id());
let span = tracing::info_span!("get", session_id = %self.client_side_session_id());
let _enter = span.enter();

for key in operation.keys {
Expand All @@ -59,7 +60,8 @@ impl Session {
pub fn get_with_default(&self, operation: GetWithDefault) -> Result<ConfigResponse, Status> {
let mut response = self.config_response();

let span = tracing::info_span!("get_with_default", session_id = %self.id());
let span =
tracing::info_span!("get_with_default", session_id = %self.client_side_session_id());
let _enter = span.enter();

for KeyValue {
Expand All @@ -79,7 +81,7 @@ impl Session {
pub fn get_option(&self, operation: GetOption) -> Result<ConfigResponse, Status> {
let mut response = self.config_response();

let span = tracing::info_span!("get_option", session_id = %self.id());
let span = tracing::info_span!("get_option", session_id = %self.client_side_session_id());
let _enter = span.enter();

for key in operation.keys {
Expand All @@ -93,7 +95,7 @@ impl Session {
pub fn get_all(&self, operation: GetAll) -> Result<ConfigResponse, Status> {
let mut response = self.config_response();

let span = tracing::info_span!("get_all", session_id = %self.id());
let span = tracing::info_span!("get_all", session_id = %self.client_side_session_id());
let _enter = span.enter();

let Some(prefix) = operation.prefix else {
Expand All @@ -119,7 +121,7 @@ impl Session {
pub fn unset(&mut self, operation: Unset) -> Result<ConfigResponse, Status> {
let mut response = self.config_response();

let span = tracing::info_span!("unset", session_id = %self.id());
let span = tracing::info_span!("unset", session_id = %self.client_side_session_id());
let _enter = span.enter();

for key in operation.keys {
Expand All @@ -137,10 +139,11 @@ impl Session {
pub fn is_modifiable(&self, _operation: IsModifiable) -> Result<ConfigResponse, Status> {
let response = self.config_response();

let span = tracing::info_span!("is_modifiable", session_id = %self.id());
let span =
tracing::info_span!("is_modifiable", session_id = %self.client_side_session_id());
let _enter = span.enter();

tracing::warn!(session_id = %self.id(), "is_modifiable operation not yet implemented");
tracing::warn!(session_id = %self.client_side_session_id(), "is_modifiable operation not yet implemented");
// todo: need to implement this
Ok(response)
}
Expand Down
2 changes: 1 addition & 1 deletion src/daft-connect/src/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use eyre::Context;
pub use plan_conversion::to_logical_plan;
pub use schema_conversion::connect_schema;

pub fn map_to_tables<T: Try>(
pub fn run_local<T: Try>(
logical_plan: &LogicalPlanRef,
mut f: impl FnMut(&Table) -> T,
default: impl FnOnce() -> T,
Expand Down
4 changes: 2 additions & 2 deletions src/daft-connect/src/convert/data_conversion/show_string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use spark_connect::ShowString;

use crate::{
command::ConcreteDataChannel,
convert::{map_to_tables, plan_conversion::to_logical_plan},
convert::{plan_conversion::to_logical_plan, run_local},
};

pub fn show_string(
Expand All @@ -28,7 +28,7 @@ pub fn show_string(

let logical_plan = to_logical_plan(input)?.build();

map_to_tables(
run_local(
&logical_plan,
|table| -> eyre::Result<()> {
let display = format!("{table}");
Expand Down
Loading

0 comments on commit fe92b98

Please sign in to comment.