Skip to content

Commit

Permalink
[BugFix] MaxBy/MinBy not filter nulls (#51354)
Browse files Browse the repository at this point in the history
Signed-off-by: satanson <[email protected]>
(cherry picked from commit 9398edd)

# Conflicts:
#	be/src/exprs/agg/factory/aggregate_factory.cpp
#	fe/fe-core/src/main/java/com/starrocks/catalog/Function.java
#	fe/fe-core/src/main/java/com/starrocks/catalog/combinator/AggStateDesc.java
#	test/sql/test_agg_state/R/test_agg_state_table_with_all_functions.sql
  • Loading branch information
satanson authored and mergify[bot] committed Sep 27, 2024
1 parent 4748422 commit 54e84c9
Show file tree
Hide file tree
Showing 15 changed files with 2,689 additions and 211 deletions.
8 changes: 7 additions & 1 deletion be/src/exec/analytor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,13 @@ Status Analytor::prepare(RuntimeState* state, ObjectPool* pool, RuntimeProfile*
real_fn_name += "_in";
_need_partition_materializing = true;
}
func = get_window_function(real_fn_name, arg_type.type, return_type.type, is_input_nullable, fn.binary_type,
const auto& fname = fn.name.function_name;
auto real_arg_type = arg_type.type;
if (fname == "max_by" || fname == "min_by" || fname == "max_by_v2" || fname == "min_by_v2") {
const TypeDescriptor arg1_type = TypeDescriptor::from_thrift(fn.arg_types[1]);
real_arg_type = arg1_type.type;
}
func = get_window_function(real_fn_name, real_arg_type, return_type.type, is_input_nullable, fn.binary_type,
state->func_version());
if (func == nullptr) {
return Status::InternalError(strings::Substitute(
Expand Down
45 changes: 45 additions & 0 deletions be/src/exprs/agg/factory/aggregate_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,4 +168,49 @@ const AggregateFunction* get_window_function(const std::string& name, LogicalTyp
return nullptr;
}

<<<<<<< HEAD
=======
const AggregateFunction* get_aggregate_function(const std::string& agg_func_name, const TypeDescriptor& return_type,
const std::vector<TypeDescriptor>& arg_types, bool is_result_nullable,
TFunctionBinaryType::type binary_type, int func_version) {
// get function
if (agg_func_name == "count") {
return get_aggregate_function("count", TYPE_BIGINT, TYPE_BIGINT, is_result_nullable);
} else {
DCHECK_GE(arg_types.size(), 1);
TypeDescriptor arg_type = arg_types[0];
// Because intersect_count have two input types.
// And intersect_count's first argument's type is alwasy Bitmap,
// so we use its second arguments type as input.
if (agg_func_name == "intersect_count") {
arg_type = arg_types[1];
}

// Because max_by and min_by function have two input types,
// so we use its second arguments type as input.
if (agg_func_name == "max_by" || agg_func_name == "min_by" || agg_func_name == "max_by_v2" ||
agg_func_name == "min_by_v2") {
arg_type = arg_types[1];
}

// Because windowfunnel have more two input types.
// functions registry use 2th args(datetime/date).
if (agg_func_name == "window_funnel") {
arg_type = arg_types[1];
}

// hack for accepting various arguments
if (agg_func_name == "exchange_bytes" || agg_func_name == "exchange_speed") {
arg_type = TypeDescriptor(TYPE_BIGINT);
}

if (agg_func_name == "array_union_agg" || agg_func_name == "array_unique_agg") {
arg_type = arg_type.children[0];
}
return get_aggregate_function(agg_func_name, arg_type.type, return_type.type, is_result_nullable, binary_type,
func_version);
}
}

>>>>>>> 9398edd4af ([BugFix] MaxBy/MinBy not filter nulls (#51354))
} // namespace starrocks
16 changes: 8 additions & 8 deletions be/src/exprs/agg/factory/aggregate_factory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ class AggregateFactory {
template <LogicalType LT>
static auto MakeMaxAggregateFunction();

template <LogicalType LT>
template <LogicalType LT, bool not_filter_nulls>
static auto MakeMaxByAggregateFunction();

template <LogicalType LT>
template <LogicalType LT, bool not_filter_nulls>
static auto MakeMinByAggregateFunction();

template <LogicalType LT>
Expand Down Expand Up @@ -280,16 +280,16 @@ auto AggregateFactory::MakeMaxAggregateFunction() {
return std::make_shared<MaxMinAggregateFunction<LT, MaxAggregateData<LT>, MaxElement<LT, MaxAggregateData<LT>>>>();
}

template <LogicalType LT>
template <LogicalType LT, bool not_filter_nulls>
auto AggregateFactory::MakeMaxByAggregateFunction() {
return std::make_shared<
MaxMinByAggregateFunction<LT, MaxByAggregateData<LT>, MaxByElement<LT, MaxByAggregateData<LT>>>>();
using AggData = MaxByAggregateData<LT, not_filter_nulls>;
return std::make_shared<MaxMinByAggregateFunction<LT, AggData, MaxByElement<LT, AggData>>>();
}

template <LogicalType LT>
template <LogicalType LT, bool not_filter_nulls>
auto AggregateFactory::MakeMinByAggregateFunction() {
return std::make_shared<
MaxMinByAggregateFunction<LT, MinByAggregateData<LT>, MinByElement<LT, MinByAggregateData<LT>>>>();
using AggData = MinByAggregateData<LT, not_filter_nulls>;
return std::make_shared<MaxMinByAggregateFunction<LT, AggData, MinByElement<LT, AggData>>>();
}

template <LogicalType LT>
Expand Down
12 changes: 8 additions & 4 deletions be/src/exprs/agg/factory/aggregate_resolver_minmaxany.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,15 @@ struct MaxMinByDispatcherInner {
if constexpr ((lt_is_aggregate<arg_type> || lt_is_json<arg_type>)&&(lt_is_aggregate<ret_type> ||
lt_is_json<ret_type>)) {
if constexpr (is_max_by) {
resolver->add_aggregate_mapping_variadic<arg_type, ret_type, MaxByAggregateData<arg_type>>(
"max_by", true, AggregateFactory::MakeMaxByAggregateFunction<arg_type>());
resolver->add_aggregate_mapping_notnull<arg_type, ret_type>(
"max_by", true, AggregateFactory::MakeMaxByAggregateFunction<arg_type, false>());
resolver->add_aggregate_mapping_notnull<arg_type, ret_type>(
"max_by_v2", true, AggregateFactory::MakeMaxByAggregateFunction<arg_type, true>());
} else {
resolver->add_aggregate_mapping_variadic<arg_type, ret_type, MinByAggregateData<arg_type>>(
"min_by", true, AggregateFactory::MakeMinByAggregateFunction<arg_type>());
resolver->add_aggregate_mapping_notnull<arg_type, ret_type>(
"min_by", true, AggregateFactory::MakeMinByAggregateFunction<arg_type, false>());
resolver->add_aggregate_mapping_notnull<arg_type, ret_type>(
"min_by_v2", true, AggregateFactory::MakeMinByAggregateFunction<arg_type, true>());
}
}
}
Expand Down
Loading

0 comments on commit 54e84c9

Please sign in to comment.