feat: support prepared statement parameters (#81)
nathanielc authored Nov 28, 2024
1 parent 7818648 commit 591f65c
Showing 2 changed files with 248 additions and 47 deletions.
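This commit teaches the Flight SQL server to accept parameter values for a prepared statement via DoPut and to bind them into the logical plan at execution time. A minimal client-side sketch of the resulting flow, assuming the `arrow-flight` crate's `FlightSqlServiceClient` and `PreparedStatement` API (the channel setup, query text, and values below are illustrative, and exact method signatures may vary between `arrow-flight` versions):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array, RecordBatch};
use arrow::datatypes::{DataType, Field, Schema};
use arrow_flight::sql::client::FlightSqlServiceClient;
use tonic::transport::Channel;

async fn query_with_params(channel: Channel) -> Result<(), Box<dyn std::error::Error>> {
    let mut client = FlightSqlServiceClient::new(channel);

    // ActionCreatePreparedStatementRequest: the server now returns both the
    // dataset schema and a parameter schema derived from the plan's
    // placeholder types.
    let mut stmt = client
        .prepare("SELECT a FROM t WHERE b = $1".to_string(), None)
        .await?;

    // Parameters are carried as a single-row RecordBatch; the server
    // rejects parameter streams with more than one row.
    let params = RecordBatch::try_new(
        Arc::new(Schema::new(vec![Field::new("$1", DataType::Int64, false)])),
        vec![Arc::new(Int64Array::from(vec![42_i64])) as ArrayRef],
    )?;
    stmt.set_parameters(params)?;

    // execute() performs the DoPut carrying the parameters (handled by
    // do_put_prepared_statement_query in this diff) and then fetches results.
    let _flight_info = stmt.execute().await?;
    Ok(())
}
```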
244 changes: 210 additions & 34 deletions datafusion-flight-sql-server/src/service.rs
@@ -1,24 +1,32 @@
use std::{pin::Pin, sync::Arc};
use std::{collections::BTreeMap, pin::Pin, sync::Arc};

use arrow::{
array::{ArrayRef, RecordBatch, StringArray},
datatypes::{DataType, Field, FieldRef, Fields, SchemaRef},
compute::concat_batches,
datatypes::{DataType, Field, FieldRef, Fields, SchemaBuilder, SchemaRef},
error::ArrowError,
ipc::writer::{DictionaryTracker, IpcWriteOptions},
ipc::{
reader::StreamReader,
writer::{DictionaryTracker, IpcWriteOptions, StreamWriter},
},
};
use arrow_flight::sql::{
self,
server::{FlightSqlService as ArrowFlightSqlService, PeekableFlightDataStream},
ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs,
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate,
CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate,
DoPutPreparedStatementResult, ProstMessageExt as _, SqlInfo, TicketStatementQuery,
use arrow_flight::{
decode::{DecodedPayload, FlightDataDecoder},
sql::{
self,
server::{FlightSqlService as ArrowFlightSqlService, PeekableFlightDataStream},
ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest,
ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult,
ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest,
ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs,
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes,
CommandGetTables, CommandGetXdbcTypeInfo, CommandPreparedStatementQuery,
CommandPreparedStatementUpdate, CommandStatementQuery, CommandStatementSubstraitPlan,
CommandStatementUpdate, DoPutPreparedStatementResult, ProstMessageExt as _, SqlInfo,
TicketStatementQuery,
},
};
use arrow_flight::{
encode::FlightDataEncoderBuilder,
@@ -28,13 +36,14 @@ use arrow_flight::{
IpcMessage, SchemaAsIpc, Ticket,
};
use datafusion::{
common::arrow::datatypes::Schema,
common::{arrow::datatypes::Schema, ParamValues},
dataframe::DataFrame,
datasource::TableType,
error::{DataFusionError, Result as DataFusionResult},
execution::context::{SQLOptions, SessionContext, SessionState},
logical_expr::LogicalPlan,
physical_plan::SendableRecordBatchStream,
scalar::ScalarValue,
};
use datafusion_substrait::{
logical_plan::consumer::from_substrait_plan, serializer::deserialize_bytes,
@@ -143,6 +152,13 @@ impl FlightSqlSessionContext {

async fn execute_sql(&self, sql: &str) -> DataFusionResult<SendableRecordBatchStream> {
let plan = self.sql_to_logical_plan(sql).await?;
self.execute_logical_plan(plan).await
}

async fn execute_logical_plan(
&self,
plan: LogicalPlan,
) -> DataFusionResult<SendableRecordBatchStream> {
self.inner
.execute_logical_plan(plan)
.await?
@@ -198,10 +214,25 @@ impl ArrowFlightSqlService for FlightSqlService {
sql::Command::CommandPreparedStatementQuery(CommandPreparedStatementQuery {
prepared_statement_handle,
}) => {
let query = std::str::from_utf8(prepared_statement_handle.as_ref()).unwrap();
// print!("Query: {query}\n");
let handle = QueryHandle::try_decode(prepared_statement_handle)?;

let mut plan = ctx
.sql_to_logical_plan(handle.query())
.await
.map_err(df_error_to_status)?;

if let Some(param_values) =
decode_param_values(handle.parameters()).map_err(arrow_error_to_status)?
{
plan = plan
.with_param_values(param_values)
.map_err(df_error_to_status)?;
}

let stream = ctx.execute_sql(query).await.map_err(df_error_to_status)?;
let stream = ctx
.execute_logical_plan(plan)
.await
.map_err(df_error_to_status)?;
let arrow_schema = stream.schema();
let arrow_stream = stream.map(|i| {
let batch = i.map_err(|e| FlightError::ExternalError(e.into()))?;
@@ -272,7 +303,7 @@ impl ArrowFlightSqlService for FlightSqlService {
.await
.map_err(df_error_to_status)?;

let dataset_schema = get_schema_for_plan(plan);
let dataset_schema = get_schema_for_plan(&plan);

// Form the response ticket (that the client will pass back to DoGet)
let ticket = CommandTicket::new(sql::Command::CommandStatementQuery(query))
@@ -311,7 +342,7 @@ impl ArrowFlightSqlService for FlightSqlService {

let flight_descriptor = request.into_inner();

let dataset_schema = get_schema_for_plan(plan);
let dataset_schema = get_schema_for_plan(&plan);

// Form the response ticket (that the client will pass back to DoGet)
let ticket = CommandTicket::new(sql::Command::CommandStatementSubstraitPlan(query))
@@ -350,7 +381,7 @@ impl ArrowFlightSqlService for FlightSqlService {
.await
.map_err(df_error_to_status)?;

let dataset_schema = get_schema_for_plan(plan);
let dataset_schema = get_schema_for_plan(&plan);

// Form the response ticket (that the client will pass back to DoGet)
let ticket = CommandTicket::new(sql::Command::CommandPreparedStatementQuery(cmd))
@@ -754,15 +785,56 @@ impl ArrowFlightSqlService for FlightSqlService {

async fn do_put_prepared_statement_query(
&self,
_query: CommandPreparedStatementQuery,
query: CommandPreparedStatementQuery,
request: Request<PeekableFlightDataStream>,
) -> Result<DoPutPreparedStatementResult, Status> {
info!("do_put_prepared_statement_query");
let (_, _) = self.new_context(request).await?;
let (request, _) = self.new_context(request).await?;

Err(Status::unimplemented(
"Implement do_put_prepared_statement_query",
))
let mut handle = QueryHandle::try_decode(query.prepared_statement_handle)?;

info!(
"do_action_create_prepared_statement query={:?}",
handle.query()
);
// Collect request flight data as parameters
// Decode and encode as a single ipc stream
let mut decoder =
FlightDataDecoder::new(request.into_inner().map_err(status_to_flight_error));
let schema = decode_schema(&mut decoder).await?;
let mut parameters = Vec::new();
let mut encoder =
StreamWriter::try_new(&mut parameters, &schema).map_err(arrow_error_to_status)?;
let mut total_rows = 0;
while let Some(msg) = decoder.try_next().await? {
match msg.payload {
DecodedPayload::None => {}
DecodedPayload::Schema(_) => {
return Err(Status::invalid_argument(
"parameter flight data must contain a single schema",
));
}
DecodedPayload::RecordBatch(record_batch) => {
total_rows += record_batch.num_rows();
encoder
.write(&record_batch)
.map_err(arrow_error_to_status)?;
}
}
}
if total_rows > 1 {
return Err(Status::invalid_argument(
"parameters should contain a single row",
));
}

handle.set_parameters(Some(parameters.into()));

let res = DoPutPreparedStatementResult {
prepared_statement_handle: Some(Bytes::from(handle)),
};

Ok(res)
}

async fn do_put_prepared_statement_update(
@@ -809,15 +881,20 @@ impl ArrowFlightSqlService for FlightSqlService {
.await
.map_err(df_error_to_status)?;

let dataset_schema = get_schema_for_plan(plan);
let dataset_schema = get_schema_for_plan(&plan);
let parameter_schema = parameter_schema_for_plan(&plan)?;

let dataset_schema =
encode_schema(dataset_schema.as_ref()).map_err(arrow_error_to_status)?;
let parameter_schema =
encode_schema(parameter_schema.as_ref()).map_err(arrow_error_to_status)?;

let schema_bytes = encode_schema(dataset_schema.as_ref()).map_err(arrow_error_to_status)?;
let handle = QueryHandle::new(sql);
let handle = QueryHandle::new(sql, None);

let res = ActionCreatePreparedStatementResult {
prepared_statement_handle: Bytes::from(handle),
dataset_schema: schema_bytes,
parameter_schema: Default::default(),
dataset_schema,
parameter_schema,
};

Ok(res)
@@ -943,7 +1020,7 @@ fn encode_schema(schema: &Schema) -> std::result::Result<Bytes, ArrowError> {
}

/// Return the schema for the specified logical plan
fn get_schema_for_plan(logical_plan: LogicalPlan) -> SchemaRef {
fn get_schema_for_plan(logical_plan: &LogicalPlan) -> SchemaRef {
// gather real schema, but only
let schema = Schema::from(logical_plan.schema().as_ref());

@@ -1072,6 +1149,28 @@ fn prepare_field_for_flight(
}
}

fn parameter_schema_for_plan(plan: &LogicalPlan) -> Result<SchemaRef, Status> {
let parameters = plan
.get_parameter_types()
.map_err(df_error_to_status)?
.into_iter()
.map(|(name, dt)| {
dt.map(|dt| (name.clone(), dt)).ok_or_else(|| {
Status::internal(format!(
"unable to determine type of query parameter {name}"
))
})
})
// Collect into BTreeMap so we get a consistent order of the parameters
.collect::<Result<BTreeMap<_, _>, Status>>()?;

let mut builder = SchemaBuilder::new();
parameters
.into_iter()
.for_each(|(name, typ)| builder.push(Field::new(name, typ, false)));
Ok(builder.finish().into())
}

fn arrow_error_to_status(err: ArrowError) -> Status {
Status::internal(format!("{err:?}"))
}
@@ -1083,3 +1182,80 @@ fn flight_error_to_status(err: FlightError) -> Status {
fn df_error_to_status(err: DataFusionError) -> Status {
Status::internal(format!("{err:?}"))
}

fn status_to_flight_error(status: Status) -> FlightError {
FlightError::Tonic(status)
}

async fn decode_schema(decoder: &mut FlightDataDecoder) -> Result<SchemaRef, Status> {
while let Some(msg) = decoder.try_next().await? {
match msg.payload {
DecodedPayload::None => {}
DecodedPayload::Schema(schema) => {
return Ok(schema);
}
DecodedPayload::RecordBatch(_) => {
return Err(Status::invalid_argument(
"parameter flight data must have a known schema",
));
}
}
}

Err(Status::invalid_argument(
"parameter flight data must have a schema",
))
}

// Decode parameter ipc stream as ParamValues
fn decode_param_values(
parameters: Option<&[u8]>,
) -> Result<Option<ParamValues>, arrow::error::ArrowError> {
parameters
.map(|parameters| {
let decoder = StreamReader::try_new(parameters, None)?;
let schema = decoder.schema();
let batches = decoder.into_iter().collect::<Result<Vec<_>, _>>()?;
let batch = concat_batches(&schema, batches.iter())?;
Ok(record_to_param_values(&batch)?)
})
.transpose()
}

// Converts a record batch with a single row into ParamValues
fn record_to_param_values(batch: &RecordBatch) -> Result<ParamValues, DataFusionError> {
let mut param_values: Vec<(String, Option<usize>, ScalarValue)> = Vec::new();

let mut is_list = true;
for col_index in 0..batch.num_columns() {
let array = batch.column(col_index);
let scalar = ScalarValue::try_from_array(array, 0)?;
let name = batch
.schema_ref()
.field(col_index)
.name()
.trim_start_matches('$')
.to_string();
let index = name.parse().ok();
is_list &= index.is_some();
param_values.push((name, index, scalar));
}
if is_list {
let mut values: Vec<(Option<usize>, ScalarValue)> = param_values
.into_iter()
.map(|(_name, index, value)| (index, value))
.collect();
values.sort_by_key(|(index, _value)| *index);
Ok(values
.into_iter()
.map(|(_index, value)| value)
.collect::<Vec<ScalarValue>>()
.into())
} else {
Ok(param_values
.into_iter()
.map(|(name, _index, value)| (name, value))
.collect::<Vec<(String, ScalarValue)>>()
.into())
}
}
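
A note on the binding shapes `record_to_param_values` accepts: after trimming a leading `$` from each column name, a batch whose names all parse as integers is treated as positional (sorted by index into a `ParamValues::List`); otherwise the names are kept and the batch becomes named parameters (a `ParamValues::Map`). A minimal sketch of the two single-row parameter batches a client could send (the queries, names, and values are illustrative):

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int64Array, RecordBatch, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;

// Build the two batch shapes: positional ("$1", "$2", ...) and named ("$min_b").
fn example_parameter_batches() -> Result<(RecordBatch, RecordBatch), ArrowError> {
    // Positional, e.g. "SELECT a FROM t WHERE b = $1 AND c = $2": every
    // column name parses as an integer, so the values are sorted by index
    // and collected into ParamValues::List.
    let positional = RecordBatch::try_new(
        Arc::new(Schema::new(vec![
            Field::new("$1", DataType::Int64, false),
            Field::new("$2", DataType::Utf8, false),
        ])),
        vec![
            Arc::new(Int64Array::from(vec![42_i64])) as ArrayRef,
            Arc::new(StringArray::from(vec!["x"])) as ArrayRef,
        ],
    )?;

    // Named, e.g. "SELECT a FROM t WHERE b = $min_b": at least one name is
    // not an integer, so the (name, value) pairs become ParamValues::Map.
    let named = RecordBatch::try_new(
        Arc::new(Schema::new(vec![Field::new(
            "$min_b",
            DataType::Int64,
            false,
        )])),
        vec![Arc::new(Int64Array::from(vec![0_i64])) as ArrayRef],
    )?;

    Ok((positional, named))
}
```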