From 0a43806ecf01b909b2ead17582f64495c45e15f4 Mon Sep 17 00:00:00 2001 From: "zihe.liu" Date: Thu, 5 Sep 2024 14:59:47 +0800 Subject: [PATCH] [BugFix] Clear probe RF whose probe expr contains dict mapping expr (backport #50690) (#50728) Signed-off-by: zihe.liu --- .../exprs/vectorized/runtime_filter_bank.cpp | 25 ++++++++++- .../java/com/starrocks/analysis/Expr.java | 11 +++++ .../com/starrocks/planner/PlanFragment.java | 18 ++++++++ .../com/starrocks/planner/ProjectNode.java | 6 ++- .../sql/plan/PlanFragmentBuilder.java | 2 + .../sql/plan/LowCardinalityTest.java | 42 +++++++++++++++++++ 6 files changed, 101 insertions(+), 3 deletions(-) diff --git a/be/src/exprs/vectorized/runtime_filter_bank.cpp b/be/src/exprs/vectorized/runtime_filter_bank.cpp index 26b93c596c841..de0ca8c3097ac 100644 --- a/be/src/exprs/vectorized/runtime_filter_bank.cpp +++ b/be/src/exprs/vectorized/runtime_filter_bank.cpp @@ -6,6 +6,7 @@ #include "column/column.h" #include "exec/pipeline/runtime_filter_types.h" +#include "exprs/vectorized/dictmapping_expr.h" #include "exprs/vectorized/in_const_predicate.hpp" #include "exprs/vectorized/literal.h" #include "exprs/vectorized/runtime_filter.h" @@ -514,6 +515,23 @@ void RuntimeFilterProbeCollector::update_selectivity(vectorized::Chunk* chunk, } } +static bool contains_dict_mapping_expr(Expr* expr) { + if (typeid(*expr) == typeid(DictMappingExpr)) { + return true; + } + + return std::any_of(expr->children().begin(), expr->children().end(), + [](Expr* child) { return contains_dict_mapping_expr(child); }); +} + +static bool contains_dict_mapping_expr(RuntimeFilterProbeDescriptor* probe_desc) { + auto* probe_expr_ctx = probe_desc->probe_expr_ctx(); + if (probe_expr_ctx == nullptr) { + return false; + } + return contains_dict_mapping_expr(probe_expr_ctx->root()); +} + void RuntimeFilterProbeCollector::push_down(const RuntimeState* state, TPlanNodeId target_plan_node_id, RuntimeFilterProbeCollector* parent, const std::vector& tuple_ids, std::set& local_rf_waiting_set) { @@ -525,8 +543,11 @@ void RuntimeFilterProbeCollector::push_down(const RuntimeState* state, TPlanNode ++iter; continue; } - if (desc->is_bound(tuple_ids) && !(state->broadcast_join_right_offsprings().count(target_plan_node_id) && - state->non_broadcast_rf_ids().count(desc->filter_id()))) { + + if (desc->is_bound(tuple_ids) && + !(state->broadcast_join_right_offsprings().count(target_plan_node_id) && + state->non_broadcast_rf_ids().count(desc->filter_id())) && + !contains_dict_mapping_expr(desc)) { add_descriptor(desc); if (desc->is_local()) { local_rf_waiting_set.insert(desc->build_plan_node_id()); diff --git a/fe/fe-core/src/main/java/com/starrocks/analysis/Expr.java b/fe/fe-core/src/main/java/com/starrocks/analysis/Expr.java index 543d288981c47..999cb1d8b5ac5 100644 --- a/fe/fe-core/src/main/java/com/starrocks/analysis/Expr.java +++ b/fe/fe-core/src/main/java/com/starrocks/analysis/Expr.java @@ -1454,4 +1454,15 @@ public List getHints() { return hints; } + public boolean containsDictMappingExpr() { + return containsDictMappingExpr(this); + } + + private static boolean containsDictMappingExpr(Expr expr) { + if (expr instanceof DictMappingExpr) { + return true; + } + return expr.getChildren().stream().anyMatch(child -> containsDictMappingExpr(child)); + } + } diff --git a/fe/fe-core/src/main/java/com/starrocks/planner/PlanFragment.java b/fe/fe-core/src/main/java/com/starrocks/planner/PlanFragment.java index 067bbc52a4f6b..d2a02c45bdc83 100644 --- a/fe/fe-core/src/main/java/com/starrocks/planner/PlanFragment.java +++ b/fe/fe-core/src/main/java/com/starrocks/planner/PlanFragment.java @@ -755,4 +755,22 @@ public void removeRfOnRightOffspringsOfBroadcastJoin() { removeRfOfRightOffspring(getPlanRoot(), localRightOffsprings, filterIds); } + + public void removeDictMappingProbeRuntimeFilters() { + removeDictMappingProbeRuntimeFilters(getPlanRoot()); + } + + private void removeDictMappingProbeRuntimeFilters(PlanNode root) { + root.getProbeRuntimeFilters().removeIf(filter -> { + Expr probExpr = filter.getNodeIdToProbeExpr().get(root.getId().asInt()); + return probExpr.containsDictMappingExpr(); + }); + + for (PlanNode child : root.getChildren()) { + if (child.getFragmentId().equals(root.getFragmentId())) { + removeDictMappingProbeRuntimeFilters(child); + } + } + } + } diff --git a/fe/fe-core/src/main/java/com/starrocks/planner/ProjectNode.java b/fe/fe-core/src/main/java/com/starrocks/planner/ProjectNode.java index b033fcd9a64b6..50a48fa0c8f6b 100644 --- a/fe/fe-core/src/main/java/com/starrocks/planner/ProjectNode.java +++ b/fe/fe-core/src/main/java/com/starrocks/planner/ProjectNode.java @@ -161,7 +161,11 @@ public boolean pushDownRuntimeFilters(RuntimeFilterDescription description, return false; } - return pushdownRuntimeFilterForChildOrAccept(description, probeExpr, candidatesOfSlotExpr(probeExpr), + Optional> optProbeExprCandidates = candidatesOfSlotExpr(probeExpr); + optProbeExprCandidates.ifPresent( + exprs -> exprs.removeIf(probeExprCandidate -> probeExprCandidate.containsDictMappingExpr())); + + return pushdownRuntimeFilterForChildOrAccept(description, probeExpr,optProbeExprCandidates, partitionByExprs, candidatesOfSlotExprs(partitionByExprs), 0, true); } diff --git a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java index 075cbd4b0264b..5d2fa3a7c4a67 100644 --- a/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java +++ b/fe/fe-core/src/main/java/com/starrocks/sql/plan/PlanFragmentBuilder.java @@ -258,6 +258,8 @@ private static ExecPlan finalizeFragments(ExecPlan execPlan, TResultSinkType res fragment.computeLocalRfWaitingSet(fragment.getPlanRoot(), shouldClearRuntimeFilters); } + fragments.forEach(PlanFragment::removeDictMappingProbeRuntimeFilters); + if (useQueryCache(execPlan)) { List fragmentsWithLeftmostOlapScanNode = execPlan.getFragments().stream() .filter(PlanFragment::hasOlapScanNode).collect(Collectors.toList()); diff --git a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java index 37e61a21a763f..a9dc3332e2d12 100644 --- a/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/sql/plan/LowCardinalityTest.java @@ -1892,4 +1892,46 @@ public void testNestedStringFunc() throws Exception { assertContains(plan, "if(DictExpr(10: S_ADDRESS,[ = '']), '', " + "substr(md5(DictExpr(10: S_ADDRESS,[])), 1, 3))"); } + + @Test + public void testRuntimeFilterOnProjectWithDictExpr() throws Exception { + String sql = "WITH \n" + + " w1 AS (\n" + + " SELECT CASE\n" + + " WHEN P_NAME = 'a' THEN 'a1'\n" + + " WHEN P_BRAND = 'b' THEN 'b1'\n" + + " ELSE 'c1'\n" + + " END as P_NAME2, P_NAME from part_v2\n" + + " UNION ALL\n" + + " SELECT P_NAME, P_NAME from part_v2\n" + + ")\n" + + "SELECT count(1) \n" + + "FROM \n" + + " w1 t1 \n" + + " JOIN [broadcast] part_v2 t2 ON t1.P_NAME2 = t2.P_NAME AND t1.P_NAME = t2.P_NAME;"; + String plan = getCostExplain(sql); + assertContains(plan, " 3:Decode\n" + + " | : \n" + + " | cardinality: 1\n" + + " | column statistics: \n" + + " | * P_NAME-->[-Infinity, Infinity, 0.0, 1.0, 1.0] UNKNOWN\n" + + " | * P_BRAND-->[-Infinity, Infinity, 0.0, 1.0, 1.0] UNKNOWN\n" + + " | * cast-->[-Infinity, Infinity, 0.0, 16.0, 3.0] ESTIMATE\n" + + " | \n" + + " 2:Project\n" + + " | output columns:\n" + + " | 36 <-> CASE WHEN DictExpr(62: P_NAME,[ = 'a']) THEN 'a1' WHEN " + + "DictExpr(63: P_BRAND,[ = 'b']) THEN 'b1' ELSE 'c1' END\n" + + " | 62 <-> [62: P_NAME, INT, false]\n" + + " | cardinality: 1\n" + + " | column statistics: \n" + + " | * cast-->[-Infinity, Infinity, 0.0, 16.0, 3.0] ESTIMATE\n" + + " | \n" + + " 1:OlapScanNode\n" + + " table: part_v2, rollup: part_v2\n" + + " preAggregation: on\n" + + " dict_col=P_NAME,P_BRAND"); + System.out.println(plan); + } + }