diff --git a/datafusion-flight-sql-server/src/service.rs b/datafusion-flight-sql-server/src/service.rs index bc27020..734eff4 100644 --- a/datafusion-flight-sql-server/src/service.rs +++ b/datafusion-flight-sql-server/src/service.rs @@ -2,9 +2,9 @@ use std::{pin::Pin, sync::Arc}; use arrow::{ array::{ArrayRef, RecordBatch, StringArray}, - datatypes::{DataType, Field, SchemaRef}, + datatypes::{DataType, Field, FieldRef, Fields, SchemaRef}, error::ArrowError, - ipc::writer::IpcWriteOptions, + ipc::writer::{DictionaryTracker, IpcWriteOptions}, }; use arrow_flight::sql::{ self, @@ -945,9 +945,131 @@ fn encode_schema(schema: &Schema) -> std::result::Result { /// Return the schema for the specified logical plan fn get_schema_for_plan(logical_plan: LogicalPlan) -> SchemaRef { // gather real schema, but only - let schema = Arc::new(Schema::from(logical_plan.schema().as_ref())) as _; + let schema = Schema::from(logical_plan.schema().as_ref()); - schema + prepare_schema_for_flight(&schema, &mut DictionaryTracker::new(true), false).into() +} + +/// Prepare an arrow Schema for transport over the Arrow Flight protocol +/// +/// Convert dictionary types to underlying types +/// +/// See hydrate_dictionary for more information +/// +/// NOTE: This logic comes from https://github.com/apache/arrow-rs/blob/master/arrow-flight/src/encode.rs#L527 +/// As its the logic used by the [`FlightDataEncoderBuilder`] with [`DictionaryHandling::Hydrate`]. +/// We need to replicate it here so that the schema we report from the flight_info requests matches +/// the actual schema of the data encoded over the wire. +fn prepare_schema_for_flight( + schema: &Schema, + dictionary_tracker: &mut DictionaryTracker, + send_dictionaries: bool, +) -> Schema { + let fields: Fields = schema + .fields() + .iter() + .map(|field| match field.data_type() { + DataType::Dictionary(_, value_type) => { + if !send_dictionaries { + Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()) + } else { + let dict_id = dictionary_tracker.set_dict_id(field.as_ref()); + Field::new_dict( + field.name(), + field.data_type().clone(), + field.is_nullable(), + dict_id, + field.dict_is_ordered().unwrap_or_default(), + ) + .with_metadata(field.metadata().clone()) + } + } + tpe if tpe.is_nested() => { + prepare_field_for_flight(field, dictionary_tracker, send_dictionaries) + } + _ => field.as_ref().clone(), + }) + .collect(); + + Schema::new(fields).with_metadata(schema.metadata().clone()) +} + +fn prepare_field_for_flight( + field: &FieldRef, + dictionary_tracker: &mut DictionaryTracker, + send_dictionaries: bool, +) -> Field { + match field.data_type() { + DataType::List(inner) => Field::new_list( + field.name(), + prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()), + DataType::LargeList(inner) => Field::new_list( + field.name(), + prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()), + DataType::Struct(fields) => { + let new_fields: Vec = fields + .iter() + .map(|f| prepare_field_for_flight(f, dictionary_tracker, send_dictionaries)) + .collect(); + Field::new_struct(field.name(), new_fields, field.is_nullable()) + .with_metadata(field.metadata().clone()) + } + DataType::Union(fields, mode) => { + let (type_ids, new_fields): (Vec, Vec) = fields + .iter() + .map(|(type_id, f)| { + ( + type_id, + prepare_field_for_flight(f, dictionary_tracker, send_dictionaries), + ) + }) + .unzip(); + + Field::new_union(field.name(), type_ids, new_fields, *mode) + } + DataType::Dictionary(_, value_type) => { + if !send_dictionaries { + Field::new( + field.name(), + value_type.as_ref().clone(), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()) + } else { + let dict_id = dictionary_tracker.set_dict_id(field.as_ref()); + + Field::new_dict( + field.name(), + field.data_type().clone(), + field.is_nullable(), + dict_id, + field.dict_is_ordered().unwrap_or_default(), + ) + .with_metadata(field.metadata().clone()) + } + } + DataType::Map(inner, sorted) => Field::new( + field.name(), + DataType::Map( + prepare_field_for_flight(inner, dictionary_tracker, send_dictionaries).into(), + *sorted, + ), + field.is_nullable(), + ) + .with_metadata(field.metadata().clone()), + _ => field.as_ref().clone(), + } } fn arrow_error_to_status(err: ArrowError) -> Status {