From e9f23d2243b56d313c6777eabec335d29f8d3e3d Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Tue, 28 May 2024 13:58:17 +0800 Subject: [PATCH] support limit in agg exec for ser/deser --- datafusion/proto/proto/datafusion.proto | 6 + datafusion/proto/src/generated/pbjson.rs | 111 ++++++++++++++++++ datafusion/proto/src/generated/prost.rs | 9 ++ datafusion/proto/src/physical_plan/mod.rs | 18 ++- .../tests/cases/roundtrip_physical_plan.rs | 27 +++++ 5 files changed, 169 insertions(+), 2 deletions(-) diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index cb0ae0f551f2..9768eb79cef5 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1110,6 +1110,11 @@ message MaybePhysicalSortExprs { repeated PhysicalSortExprNode sort_expr = 1; } +message AggLimit { + // wrap into a message to make it optional + uint64 limit = 1; +} + message AggregateExecNode { repeated PhysicalExprNode group_expr = 1; repeated PhysicalExprNode aggr_expr = 2; @@ -1122,6 +1127,7 @@ message AggregateExecNode { repeated PhysicalExprNode null_expr = 8; repeated bool groups = 9; repeated MaybeFilter filter_expr = 10; + AggLimit limit = 11; } message GlobalLimitExecNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 2edbae24294b..592c5bdb021f 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -1,3 +1,97 @@ +impl serde::Serialize for AggLimit { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.limit != 0 { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.AggLimit", len)?; + if self.limit != 0 { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("limit", ToString::to_string(&self.limit).as_str())?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for AggLimit { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "limit", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Limit, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "limit" => Ok(GeneratedField::Limit), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = AggLimit; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.AggLimit") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut limit__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Limit => { + if limit__.is_some() { + return Err(serde::de::Error::duplicate_field("limit")); + } + limit__ = + Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) + ; + } + } + } + Ok(AggLimit { + limit: limit__.unwrap_or_default(), + }) + } + } + deserializer.deserialize_struct("datafusion.AggLimit", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for AggregateExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -36,6 +130,9 @@ impl serde::Serialize for AggregateExecNode { if !self.filter_expr.is_empty() { len += 1; } + if self.limit.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.AggregateExecNode", len)?; if !self.group_expr.is_empty() { struct_ser.serialize_field("groupExpr", &self.group_expr)?; @@ -69,6 +166,9 @@ impl serde::Serialize for AggregateExecNode { if !self.filter_expr.is_empty() { struct_ser.serialize_field("filterExpr", &self.filter_expr)?; } + if let Some(v) = self.limit.as_ref() { + struct_ser.serialize_field("limit", v)?; + } struct_ser.end() } } @@ -96,6 +196,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "groups", "filter_expr", "filterExpr", + "limit", ]; #[allow(clippy::enum_variant_names)] @@ -110,6 +211,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { NullExpr, Groups, FilterExpr, + Limit, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -141,6 +243,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { "nullExpr" | "null_expr" => Ok(GeneratedField::NullExpr), "groups" => Ok(GeneratedField::Groups), "filterExpr" | "filter_expr" => Ok(GeneratedField::FilterExpr), + "limit" => Ok(GeneratedField::Limit), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -170,6 +273,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { let mut null_expr__ = None; let mut groups__ = None; let mut filter_expr__ = None; + let mut limit__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::GroupExpr => { @@ -232,6 +336,12 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { } filter_expr__ = Some(map_.next_value()?); } + GeneratedField::Limit => { + if limit__.is_some() { + return Err(serde::de::Error::duplicate_field("limit")); + } + limit__ = map_.next_value()?; + } } } Ok(AggregateExecNode { @@ -245,6 +355,7 @@ impl<'de> serde::Deserialize<'de> for AggregateExecNode { null_expr: null_expr__.unwrap_or_default(), groups: groups__.unwrap_or_default(), filter_expr: filter_expr__.unwrap_or_default(), + limit: limit__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index e9407cc65bb1..8dc6ac5f5168 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1709,6 +1709,13 @@ pub struct MaybePhysicalSortExprs { } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] +pub struct AggLimit { + /// wrap into a message to make it optional + #[prost(uint64, tag = "1")] + pub limit: u64, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct AggregateExecNode { #[prost(message, repeated, tag = "1")] pub group_expr: ::prost::alloc::vec::Vec, @@ -1731,6 +1738,8 @@ pub struct AggregateExecNode { pub groups: ::prost::alloc::vec::Vec, #[prost(message, repeated, tag = "10")] pub filter_expr: ::prost::alloc::vec::Vec, + #[prost(message, optional, tag = "11")] + pub limit: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index a85bfdc89d01..91ed3b7f5e7c 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -539,14 +539,23 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { }) .collect::, _>>()?; - Ok(Arc::new(AggregateExec::try_new( + let limit = hash_agg + .limit + .as_ref() + .map(|lit_value| lit_value.limit as usize); + + let agg = AggregateExec::try_new( agg_mode, PhysicalGroupBy::new(group_expr, null_expr, groups), physical_aggr_expr, physical_filter_expr, input, physical_schema, - )?)) + )?; + + let agg = agg.with_limit(limit); + + Ok(Arc::new(agg)) } PhysicalPlanType::HashJoin(hashjoin) => { let left: Arc = into_physical_plan( @@ -1504,6 +1513,10 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec)) .collect::>>()?; + let limit = exec.limit().map(|value| protobuf::AggLimit { + limit: value as u64, + }); + return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Aggregate(Box::new( protobuf::AggregateExecNode { @@ -1517,6 +1530,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { input_schema: Some(input_schema.as_ref().try_into()?), null_expr, groups, + limit, }, ))), }); diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 79abecf556da..55b346a482d3 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -370,6 +370,33 @@ fn rountrip_aggregate() -> Result<()> { Ok(()) } +#[test] +fn rountrip_aggregate_with_limit() -> Result<()> { + let field_a = Field::new("a", DataType::Int64, false); + let field_b = Field::new("b", DataType::Int64, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + + let groups: Vec<(Arc, String)> = + vec![(col("a", &schema)?, "unused".to_string())]; + + let aggregates: Vec> = vec![Arc::new(Avg::new( + cast(col("b", &schema)?, &schema, DataType::Float64)?, + "AVG(b)".to_string(), + DataType::Float64, + ))]; + + let agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(groups.clone()), + aggregates.clone(), + vec![None], + Arc::new(EmptyExec::new(schema.clone())), + schema, + )?; + let agg = agg.with_limit(Some(12)); + roundtrip_test(Arc::new(agg)) +} + #[test] fn roundtrip_aggregate_udaf() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false);