Skip to content

Commit

Permalink
fix: use new known_schema api of FlightDataEncoder (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielc authored Dec 5, 2024
1 parent 758323a commit c50ad3d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 130 deletions.
6 changes: 3 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ readme = "README.md"
repository = "https://github.com/datafusion-contrib/datafusion-federation"

[workspace.dependencies]
arrow = "53.2"
arrow-flight = { version = "53.2", features = ["flight-sql-experimental"] }
arrow-json = "53.2"
arrow = "53.3"
arrow-flight = { version = "53.3", features = ["flight-sql-experimental"] }
arrow-json = "53.3"
async-stream = "0.3.5"
async-trait = "0.1.83"
datafusion = "43.0.0"
Expand Down
142 changes: 15 additions & 127 deletions datafusion-flight-sql-server/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ use std::{collections::BTreeMap, pin::Pin, sync::Arc};
use arrow::{
array::{ArrayRef, RecordBatch, StringArray},
compute::concat_batches,
datatypes::{DataType, Field, FieldRef, Fields, SchemaBuilder, SchemaRef},
datatypes::{DataType, Field, SchemaBuilder, SchemaRef},
error::ArrowError,
ipc::{
reader::StreamReader,
writer::{DictionaryTracker, IpcWriteOptions, StreamWriter},
writer::{IpcWriteOptions, StreamWriter},
},
};
use arrow_flight::{
Expand Down Expand Up @@ -1022,131 +1022,19 @@ fn encode_schema(schema: &Schema) -> std::result::Result<Bytes, ArrowError> {
/// Return the schema for the specified logical plan
fn get_schema_for_plan(logical_plan: &LogicalPlan) -> SchemaRef {
// gather real schema, but only
let schema = Schema::from(logical_plan.schema().as_ref());

prepare_schema_for_flight(&schema, &mut DictionaryTracker::new(true), false).into()
}

/// Prepare an arrow Schema for transport over the Arrow Flight protocol
///
/// Convert dictionary types to underlying types
///
/// See hydrate_dictionary for more information
///
/// NOTE: This logic comes from https://github.com/apache/arrow-rs/blob/master/arrow-flight/src/encode.rs#L527
/// As its the logic used by the [`FlightDataEncoderBuilder`] with [`DictionaryHandling::Hydrate`].
/// We need to replicate it here so that the schema we report from the flight_info requests matches
/// the actual schema of the data encoded over the wire.
fn prepare_schema_for_flight(
schema: &Schema,
dictionary_tracker: &mut DictionaryTracker,
send_dictionaries: bool,
) -> Schema {
let fields: Fields = schema
.fields()
.iter()
.map(|field| match field.data_type() {
DataType::Dictionary(_, value_type) => {
if !send_dictionaries {
Field::new(
field.name(),
value_type.as_ref().clone(),
field.is_nullable(),
)
.with_metadata(field.metadata().clone())
} else {
let dict_id = dictionary_tracker.set_dict_id(field.as_ref());
Field::new_dict(
field.name(),
field.data_type().clone(),
field.is_nullable(),
dict_id,
field.dict_is_ordered().unwrap_or_default(),
)
.with_metadata(field.metadata().clone())
}
}
tpe if tpe.is_nested() => {
prepare_field_for_flight(field, dictionary_tracker, send_dictionaries)
}
_ => field.as_ref().clone(),
})
.collect();

Schema::new(fields).with_metadata(schema.metadata().clone())
}

fn prepare_field_for_flight(
field: &FieldRef,
dictionary_tracker: &mut DictionaryTracker,
send_dictionaries: bool,
) -> Field {
match field.data_type() {
DataType::List(inner) => Field::new_list(
field.name(),
prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
field.is_nullable(),
)
.with_metadata(field.metadata().clone()),
DataType::LargeList(inner) => Field::new_list(
field.name(),
prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries),
field.is_nullable(),
)
.with_metadata(field.metadata().clone()),
DataType::Struct(fields) => {
let new_fields: Vec<Field> = fields
.iter()
.map(|f| prepare_field_for_flight(f, dictionary_tracker, send_dictionaries))
.collect();
Field::new_struct(field.name(), new_fields, field.is_nullable())
.with_metadata(field.metadata().clone())
}
DataType::Union(fields, mode) => {
let (type_ids, new_fields): (Vec<i8>, Vec<Field>) = fields
.iter()
.map(|(type_id, f)| {
(
type_id,
prepare_field_for_flight(f, dictionary_tracker, send_dictionaries),
)
})
.unzip();

Field::new_union(field.name(), type_ids, new_fields, *mode)
}
DataType::Dictionary(_, value_type) => {
if !send_dictionaries {
Field::new(
field.name(),
value_type.as_ref().clone(),
field.is_nullable(),
)
.with_metadata(field.metadata().clone())
} else {
let dict_id = dictionary_tracker.set_dict_id(field.as_ref());

Field::new_dict(
field.name(),
field.data_type().clone(),
field.is_nullable(),
dict_id,
field.dict_is_ordered().unwrap_or_default(),
)
.with_metadata(field.metadata().clone())
}
}
DataType::Map(inner, sorted) => Field::new(
field.name(),
DataType::Map(
prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries).into(),
*sorted,
),
field.is_nullable(),
)
.with_metadata(field.metadata().clone()),
_ => field.as_ref().clone(),
}
let schema = Schema::from(logical_plan.schema().as_ref()).into();

// Use an empty FlightDataEncoder to determine the schema of the encoded flight data.
// This is necessary as the schema can change based on dictionary hydration behavior.
let flight_data_stream = FlightDataEncoderBuilder::new()
// Inform the builder of the input stream schema
.with_schema(schema)
.build(futures::stream::iter([]));

// Retrieve the schema of the encoded data
flight_data_stream
.known_schema()
.expect("flight data schema should be known when explicitly provided via `with_schema`")
}

fn parameter_schema_for_plan(plan: &LogicalPlan) -> Result<SchemaRef, Status> {
Expand Down

0 comments on commit c50ad3d

Please sign in to comment.