Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Broadcast Join should not generate nondeterministic GRF (backport #44111) #45167

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion be/src/exec/exec_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ void ExecNode::push_down_join_runtime_filter(RuntimeState* state, vectorized::Ru
if (_type != TPlanNodeType::AGGREGATION_NODE && _type != TPlanNodeType::ANALYTIC_EVAL_NODE) {
push_down_join_runtime_filter_to_children(state, collector);
}
_runtime_filter_collector.push_down(collector, _tuple_ids, _local_rf_waiting_set);
_runtime_filter_collector.push_down(state, id(), collector, _tuple_ids, _local_rf_waiting_set);
}

void ExecNode::push_down_join_runtime_filter_to_children(RuntimeState* state,
Expand Down
12 changes: 12 additions & 0 deletions be/src/exec/exec_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,18 @@ class ExecNode {
// Names of counters shared by all exec nodes
static const std::string ROW_THROUGHPUT_COUNTER;

<<<<<<< HEAD
=======
static void may_add_chunk_accumulate_operator(OpFactories& ops, pipeline::PipelineBuilderContext* context, int id);

void set_children(std::vector<ExecNode*>&& children) { _children = std::move(children); }

const std::vector<ExecNode*>& children() const { return _children; }

[[nodiscard]] static Status create_vectorized_node(RuntimeState* state, ObjectPool* pool, const TPlanNode& tnode,
const DescriptorTbl& descs, ExecNode** node);

>>>>>>> ecbc7907bb ([BugFix] Broadcast Join should not generate nondetermistic GRF (#44111))
protected:
friend class DataSink;

Expand Down
41 changes: 41 additions & 0 deletions be/src/exec/pipeline/fragment_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,41 @@ int FragmentExecutor::_calc_query_expired_seconds(const UnifiedExecPlanFragmentP
return QueryContext::DEFAULT_EXPIRE_SECONDS;
}

// Visits the plan subtree rooted at `node` (children before the node itself) and adds to
// `filter_ids` the filter id of every build-side runtime filter that belongs to a hash join
// whose distribution mode is SHUFFLE_HASH_BUCKET.
static void collect_shuffle_hash_bucket_rf_ids(const ExecNode* node, std::unordered_set<int32_t>& filter_ids) {
    const auto& sub_nodes = node->children();
    for (const auto* sub_node : sub_nodes) {
        collect_shuffle_hash_bucket_rf_ids(sub_node, filter_ids);
    }

    // Only hash joins own build-side runtime filters of interest here.
    if (node->type() != TPlanNodeType::HASH_JOIN_NODE) {
        return;
    }
    const auto* hash_join = down_cast<const HashJoinNode*>(node);
    if (hash_join->distribution_mode() != TJoinDistributionMode::SHUFFLE_HASH_BUCKET) {
        return;
    }
    for (const auto* build_rf : hash_join->build_runtime_filters()) {
        filter_ids.insert(build_rf->filter_id());
    }
}

// Returns the ids of all plan nodes in the subtree rooted at `node`, including `node` itself.
//
// As a side effect, for every hash join in that subtree whose distribution mode is BROADCAST and
// for which can_generate_global_runtime_filter() is true, the ids of all nodes under the join's
// child at index 1 are added to `broadcast_join_right_offsprings`.
//
// @param node                             root of the subtree to scan.
// @param broadcast_join_right_offsprings  accumulator shared by the whole recursion.
// @return ids of `node` and of every descendant of `node`.
static std::unordered_set<int32_t> collect_broadcast_join_right_offsprings(
        const ExecNode* node, BroadcastJoinRightOffsprings& broadcast_join_right_offsprings) {
    const auto& children = node->children();
    std::vector<std::unordered_set<int32_t>> offsprings_per_child;
    offsprings_per_child.reserve(children.size());
    std::unordered_set<int32_t> offsprings;
    for (const auto* child : children) {
        auto child_offspring = collect_broadcast_join_right_offsprings(child, broadcast_join_right_offsprings);
        offsprings.insert(child_offspring.begin(), child_offspring.end());
        // Keep the per-child set so the index-1 subtree can be looked up below.
        offsprings_per_child.push_back(std::move(child_offspring));
    }
    offsprings.insert(node->id());

    if (node->type() == TPlanNodeType::HASH_JOIN_NODE) {
        const auto* join_node = down_cast<const HashJoinNode*>(node);
        // std::vector::operator[] is unchecked: only index the second child when it exists, so a
        // join node with fewer than two children cannot trigger undefined behavior.
        if (offsprings_per_child.size() > 1 &&
            join_node->distribution_mode() == TJoinDistributionMode::BROADCAST &&
            join_node->can_generate_global_runtime_filter()) {
            const auto& second_child_offsprings = offsprings_per_child[1];
            broadcast_join_right_offsprings.insert(second_child_offsprings.begin(),
                                                   second_child_offsprings.end());
        }
    }
    return offsprings;
}

Status FragmentExecutor::_prepare_exec_plan(ExecEnv* exec_env, const UnifiedExecPlanFragmentParams& request) {
auto* runtime_state = _fragment_ctx->runtime_state();
auto* obj_pool = runtime_state->obj_pool();
Expand All @@ -307,6 +342,12 @@ Status FragmentExecutor::_prepare_exec_plan(ExecEnv* exec_env, const UnifiedExec
// Set up plan
RETURN_IF_ERROR(ExecNode::create_tree(runtime_state, obj_pool, fragment.plan, desc_tbl, &_fragment_ctx->plan()));
ExecNode* plan = _fragment_ctx->plan();
std::unordered_set<int32_t> filter_ids;
collect_shuffle_hash_bucket_rf_ids(plan, filter_ids);
runtime_state->set_shuffle_hash_bucket_rf_ids(std::move(filter_ids));
BroadcastJoinRightOffsprings broadcast_join_right_offsprings_map;
collect_broadcast_join_right_offsprings(plan, broadcast_join_right_offsprings_map);
runtime_state->set_broadcast_join_right_offsprings(std::move(broadcast_join_right_offsprings_map));
plan->push_down_join_runtime_filter_recursively(runtime_state);
std::vector<TupleSlotMapping> empty_mappings;
plan->push_down_tuple_slot_mappings(runtime_state, empty_mappings);
Expand Down
2 changes: 1 addition & 1 deletion be/src/exec/vectorized/aggregate/aggregate_base_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ Status AggregateBaseNode::close(RuntimeState* state) {
void AggregateBaseNode::push_down_join_runtime_filter(RuntimeState* state,
vectorized::RuntimeFilterProbeCollector* collector) {
// accept runtime filters from parent if possible.
_runtime_filter_collector.push_down(collector, _tuple_ids, _local_rf_waiting_set);
_runtime_filter_collector.push_down(state, id(), collector, _tuple_ids, _local_rf_waiting_set);

// check to see if runtime filters can be rewritten
auto& descriptors = _runtime_filter_collector.descriptors();
Expand Down
27 changes: 27 additions & 0 deletions be/src/exec/vectorized/hash_join_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -942,4 +942,31 @@ Status HashJoinNode::_create_implicit_local_join_runtime_filters(RuntimeState* s
return Status::OK();
}

<<<<<<< HEAD:be/src/exec/vectorized/hash_join_node.cpp
} // namespace starrocks::vectorized
=======
bool HashJoinNode::can_generate_global_runtime_filter() const {
return std::any_of(_build_runtime_filters.begin(), _build_runtime_filters.end(),
[](const RuntimeFilterBuildDescriptor* rf) { return rf->has_remote_targets(); });
}

// Accepts runtime filters offered by the parent through `collector`.
// Nothing happens when the collector is empty. For INNER, LEFT SEMI and RIGHT SEMI joins the
// generic ExecNode implementation is used; for every other join type the filters are only
// offered to this node's own runtime filter collector.
void HashJoinNode::push_down_join_runtime_filter(RuntimeState* state, RuntimeFilterProbeCollector* collector) {
    if (collector->empty()) {
        return;
    }

    const bool use_generic_push_down = _join_type == TJoinOp::INNER_JOIN ||
                                       _join_type == TJoinOp::LEFT_SEMI_JOIN ||
                                       _join_type == TJoinOp::RIGHT_SEMI_JOIN;
    if (use_generic_push_down) {
        ExecNode::push_down_join_runtime_filter(state, collector);
    } else {
        _runtime_filter_collector.push_down(state, id(), collector, _tuple_ids, _local_rf_waiting_set);
    }
}

// Returns the join distribution mode stored in this node (_distribution_mode).
TJoinDistributionMode::type HashJoinNode::distribution_mode() const {
return _distribution_mode;
}

// Returns a read-only reference to this join's list of build-side runtime filter descriptors.
// The pointed-to descriptors are not copied.
const std::list<RuntimeFilterBuildDescriptor*>& HashJoinNode::build_runtime_filters() const {
return _build_runtime_filters;
}

} // namespace starrocks
>>>>>>> ecbc7907bb ([BugFix] Broadcast Join should not generate nondetermistic GRF (#44111)):be/src/exec/hash_join_node.cpp
4 changes: 4 additions & 0 deletions be/src/exec/vectorized/hash_join_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class HashJoinNode final : public ExecNode {
Status get_next(RuntimeState* state, ChunkPtr* chunk, bool* eos) override;
Status close(RuntimeState* state) override;
pipeline::OpFactories decompose_to_pipeline(pipeline::PipelineBuilderContext* context) override;
bool can_generate_global_runtime_filter() const;
TJoinDistributionMode::type distribution_mode() const;
const std::list<RuntimeFilterBuildDescriptor*>& build_runtime_filters() const;
void push_down_join_runtime_filter(RuntimeState* state, RuntimeFilterProbeCollector* collector) override;

private:
static bool _has_null(const ColumnPtr& column);
Expand Down
2 changes: 1 addition & 1 deletion be/src/exec/vectorized/project_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ void ProjectNode::push_down_tuple_slot_mappings(RuntimeState* state,
void ProjectNode::push_down_join_runtime_filter(RuntimeState* state,
vectorized::RuntimeFilterProbeCollector* collector) {
// accept runtime filters from parent if possible.
_runtime_filter_collector.push_down(collector, _tuple_ids, _local_rf_waiting_set);
_runtime_filter_collector.push_down(state, id(), collector, _tuple_ids, _local_rf_waiting_set);

// check to see if runtime filters can be rewritten
auto& descriptors = _runtime_filter_collector.descriptors();
Expand Down
6 changes: 4 additions & 2 deletions be/src/exprs/vectorized/runtime_filter_bank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@ void RuntimeFilterProbeCollector::update_selectivity(vectorized::Chunk* chunk,
}
}

void RuntimeFilterProbeCollector::push_down(RuntimeFilterProbeCollector* parent, const std::vector<TupleId>& tuple_ids,
void RuntimeFilterProbeCollector::push_down(const RuntimeState* state, TPlanNodeId target_plan_node_id,
RuntimeFilterProbeCollector* parent, const std::vector<TupleId>& tuple_ids,
std::set<TPlanNodeId>& local_rf_waiting_set) {
if (this == parent) return;
auto iter = parent->_descriptors.begin();
Expand All @@ -524,7 +525,8 @@ void RuntimeFilterProbeCollector::push_down(RuntimeFilterProbeCollector* parent,
++iter;
continue;
}
if (desc->is_bound(tuple_ids)) {
if (desc->is_bound(tuple_ids) && !(state->broadcast_join_right_offsprings().contains(target_plan_node_id) &&
state->shuffle_hash_bucket_rf_ids().contains(desc->filter_id()))) {
add_descriptor(desc);
if (desc->is_local()) {
local_rf_waiting_set.insert(desc->build_plan_node_id());
Expand Down
4 changes: 2 additions & 2 deletions be/src/exprs/vectorized/runtime_filter_bank.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ class RuntimeFilterProbeCollector {
void add_descriptor(RuntimeFilterProbeDescriptor* desc);
// accept RuntimeFilterCollector from parent node
// which means parent node to push down runtime filter.
void push_down(RuntimeFilterProbeCollector* parent, const std::vector<TupleId>& tuple_ids,
std::set<TPlanNodeId>& rf_waiting_set);
void push_down(const RuntimeState* state, TPlanNodeId target_plan_node_id, RuntimeFilterProbeCollector* parent,
const std::vector<TupleId>& tuple_ids, std::set<TPlanNodeId>& rf_waiting_set);
std::map<int32_t, RuntimeFilterProbeDescriptor*>& descriptors() { return _descriptors; }
const std::map<int32_t, RuntimeFilterProbeDescriptor*>& descriptors() const { return _descriptors; }

Expand Down
3 changes: 3 additions & 0 deletions be/src/runtime/runtime_filter_worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ void RuntimeFilterPort::publish_runtime_filters(std::list<vectorized::RuntimeFil
auto* filter = rf_desc->runtime_filter();

if (filter == nullptr || !rf_desc->has_remote_targets()) continue;
// An empty runtime filter generated by a broadcast join cannot be used as a global runtime filter,
// because it may be short-circuited by an empty probe side.
if (rf_desc->join_mode() == TRuntimeFilterBuildJoinMode::BORADCAST && filter->size() == 0) continue;

auto directly_send_broadcast_grf = rf_desc->join_mode() == TRuntimeFilterBuildJoinMode::BORADCAST &&
!rf_desc->broadcast_grf_senders().empty();
Expand Down
19 changes: 18 additions & 1 deletion be/src/runtime/runtime_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class RowDescriptor;
class RuntimeFilterPort;
class QueryStatistics;
class QueryStatisticsRecvr;

using BroadcastJoinRightOffsprings = std::unordered_set<int32_t>;
namespace pipeline {
class QueryContext;
}
Expand Down Expand Up @@ -345,6 +345,20 @@ class RuntimeState {

bool use_page_cache();

// Replaces the stored set of SHUFFLE_HASH_BUCKET runtime filter ids with `filter_ids`,
// moving from the argument rather than copying it.
void set_shuffle_hash_bucket_rf_ids(std::unordered_set<int32_t>&& filter_ids) {
this->_shuffle_hash_bucket_rf_ids = std::move(filter_ids);
}

// Read-only view of the runtime filter ids stored by set_shuffle_hash_bucket_rf_ids().
const std::unordered_set<int32_t>& shuffle_hash_bucket_rf_ids() const { return this->_shuffle_hash_bucket_rf_ids; }

// Replaces the stored set of plan node ids located under the index-1 child of broadcast joins
// with `broadcast_join_right_offsprings`, moving from the argument rather than copying it.
void set_broadcast_join_right_offsprings(BroadcastJoinRightOffsprings&& broadcast_join_right_offsprings) {
this->_broadcast_join_right_offsprings = std::move(broadcast_join_right_offsprings);
}

// Read-only view of the plan node ids stored by set_broadcast_join_right_offsprings().
const BroadcastJoinRightOffsprings& broadcast_join_right_offsprings() const {
return this->_broadcast_join_right_offsprings;
}

private:
// Set per-query state.
void _init(const TUniqueId& fragment_instance_id, const TQueryOptions& query_options,
Expand Down Expand Up @@ -464,6 +478,9 @@ class RuntimeState {
pipeline::FragmentContext* _fragment_ctx = nullptr;

bool _enable_pipeline_engine = false;

std::unordered_set<int32_t> _shuffle_hash_bucket_rf_ids;
BroadcastJoinRightOffsprings _broadcast_join_right_offsprings;
};

#define LIMIT_EXCEEDED(tracker, state, msg) \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ public boolean isLeftAntiJoin() {
return this == LEFT_ANTI_JOIN || this == NULL_AWARE_LEFT_ANTI_JOIN;
}

/**
 * @return true only when this operator is {@code NULL_AWARE_LEFT_ANTI_JOIN}.
 */
public boolean isNullAwareLeftAntiJoin() {
return this == NULL_AWARE_LEFT_ANTI_JOIN;
}

public boolean isRightSemiJoin() {
return this == RIGHT_SEMI_JOIN;
}
Expand Down Expand Up @@ -140,6 +144,10 @@ public static Set<JoinOperator> semiAntiJoinSet() {
public static Set<JoinOperator> innerCrossJoinSet() {
return Sets.newHashSet(INNER_JOIN, CROSS_JOIN);
}

/**
 * @return false for left outer, full outer and left anti joins (including the null-aware
 *         variant, which {@code isLeftAntiJoin()} covers); true for every other operator.
 */
public boolean canGenerateRuntimeFilter() {
    return !isLeftOuterJoin() && !isFullOuterJoin() && !isLeftAntiJoin();
}
}


86 changes: 86 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/planner/JoinNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,15 @@

import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.starrocks.analysis.Analyzer;
import com.starrocks.analysis.BinaryPredicate;
<<<<<<< HEAD
=======
import com.starrocks.analysis.BinaryType;
import com.starrocks.analysis.DescriptorTable;
>>>>>>> ecbc7907bb ([BugFix] Broadcast Join should not generate nondetermistic GRF (#44111))
import com.starrocks.analysis.Expr;
import com.starrocks.analysis.JoinOperator;
import com.starrocks.analysis.SlotId;
Expand Down Expand Up @@ -158,7 +164,12 @@ public List<Expr> getProbePartitionByExprs() {
}

@Override
<<<<<<< HEAD
public void buildRuntimeFilters(IdGenerator<RuntimeFilterId> runtimeFilterIdIdGenerator) {
=======
public void buildRuntimeFilters(IdGenerator<RuntimeFilterId> runtimeFilterIdIdGenerator, DescriptorTable descTbl,
ExecGroupSets execGroupSets) {
>>>>>>> ecbc7907bb ([BugFix] Broadcast Join should not generate nondetermistic GRF (#44111))
SessionVariable sessionVariable = ConnectContext.get().getSessionVariable();
JoinOperator joinOp = getJoinOp();
PlanNode inner = getChild(1);
Expand Down Expand Up @@ -271,13 +282,79 @@ public boolean pushDownRuntimeFiltersForChild(RuntimeFilterDescription descripti
partitionByExprs, candidatesOfSlotExprsForChild(partitionByExprs, childIdx), childIdx, false);
}

/**
 * Tries to push a runtime filter down to both children of this join.
 *
 * <p>This only applies when the join is neither a cross join nor a null-aware left anti join,
 * has at least one equi-join conjunct, and {@code probeExpr} is a {@link SlotRef} whose slot
 * is used by one of the EQ equi-join conjuncts.
 *
 * @return {@link Optional#empty()} when bilateral push-down does not apply; otherwise whether
 *         the filter was pushed to at least one of the two children.
 */
private Optional<Boolean> pushDownRuntimeFilterBilaterally(RuntimeFilterPushDownContext context,
                                                           Expr probeExpr,
                                                           List<Expr> partitionByExprs) {
    boolean joinAllowsBilateralPushDown =
            !joinOp.isCrossJoin() && !joinOp.isNullAwareLeftAntiJoin() && !eqJoinConjuncts.isEmpty();
    if (!joinAllowsBilateralPushDown || !(probeExpr instanceof SlotRef)) {
        return Optional.empty();
    }

    SlotRef probeSlotRef = probeExpr.cast();
    int probeSlotId = probeSlotRef.getSlotId().asInt();
    boolean probeSlotIsJoinColumn = eqJoinConjuncts.stream()
            .filter(conjunct -> conjunct.getOp().equals(BinaryType.EQ))
            .anyMatch(conjunct -> conjunct.getUsedSlotIds().contains(probeSlotId));
    if (!probeSlotIsJoinColumn) {
        return Optional.empty();
    }

    // For join types other than null-aware left anti join and cross join, a runtime-filter
    // probe expr that uses a join column can always be pushed down to both sides of the join.
    // Both sides are attempted unconditionally; success on either one counts.
    boolean pushedToFirstChild = pushDownRuntimeFiltersForChild(context, probeExpr, partitionByExprs, 0);
    boolean pushedToSecondChild = pushDownRuntimeFiltersForChild(context, probeExpr, partitionByExprs, 1);
    return Optional.of(pushedToFirstChild || pushedToSecondChild);
}


/**
 * Tries to push a runtime filter down to a single child of this join.
 *
 * <p>Candidate children depend on the join operator: child 0 for left anti / left outer joins,
 * child 1 for right anti / right outer joins, children 0 then 1 for inner, semi and cross
 * joins, and none otherwise. The first child that accepts the filter wins and no further
 * child is tried.
 *
 * @return always a present Optional: {@code true} when some child accepted the filter,
 *         {@code false} otherwise (including when
 *         {@code canPushDownRuntimeFilterCrossExchange} yields no result).
 */
private Optional<Boolean> pushDownRuntimeFilterUnilaterally(RuntimeFilterPushDownContext context,
                                                            Expr probeExpr,
                                                            List<Expr> partitionByExprs) {
    final List<Integer> candidateSides;
    if (joinOp.isLeftAntiJoin() || joinOp.isLeftOuterJoin()) {
        candidateSides = ImmutableList.of(0);
    } else if (joinOp.isRightAntiJoin() || joinOp.isRightOuterJoin()) {
        candidateSides = ImmutableList.of(1);
    } else if (joinOp.isInnerJoin() || joinOp.isSemiJoin() || joinOp.isCrossJoin()) {
        candidateSides = ImmutableList.of(0, 1);
    } else {
        candidateSides = ImmutableList.of();
    }

    Optional<List<List<Expr>>> crossExchangeCandidates =
            canPushDownRuntimeFilterCrossExchange(partitionByExprs);
    if (!crossExchangeCandidates.isPresent()) {
        return Optional.of(false);
    }
    List<List<Expr>> partitionByCandidates = crossExchangeCandidates.get();

    for (Integer side : candidateSides) {
        boolean pushed = false;
        if (partitionByCandidates.isEmpty()) {
            // No partition-by candidates: push with an empty partition-by list.
            pushed = getChild(side).pushDownRuntimeFilters(context, probeExpr, Lists.newArrayList());
        } else {
            // Try each candidate partition-by list in order until one is accepted.
            for (int i = 0; i < partitionByCandidates.size() && !pushed; i++) {
                pushed = getChild(side).pushDownRuntimeFilters(context, probeExpr, partitionByCandidates.get(i));
            }
        }
        if (pushed) {
            return Optional.of(true);
        }
    }
    return Optional.of(false);
}

@Override
public boolean pushDownRuntimeFilters(RuntimeFilterDescription description, Expr probeExpr, List<Expr> partitionByExprs) {
if (!canPushDownRuntimeFilter()) {
return false;
}

if (probeExpr.isBoundByTupleIds(getTupleIds())) {
<<<<<<< HEAD
boolean hasPushedDown = false;
// If probeExpr is SlotRef(a) and an equalJoinConjunct SlotRef(a)=SlotRef(b) exists in SemiJoin
// or InnerJoin, then the rf also can be pushed down to both sides of HashJoin because SlotRef(a) and
Expand All @@ -290,6 +367,15 @@ public boolean pushDownRuntimeFilters(RuntimeFilterDescription description, Expr
// fall back to PlanNode.pushDownRuntimeFilters for HJ if rf cannot be pushed down via equivalent
// equalJoinConjuncts
if (hasPushedDown || super.pushDownRuntimeFilters(description, probeExpr, partitionByExprs)) {
=======

Optional<Boolean> pushDownResult = pushDownRuntimeFilterBilaterally(context, probeExpr, partitionByExprs);
if (pushDownResult.isEmpty()) {
pushDownResult = pushDownRuntimeFilterUnilaterally(context, probeExpr, partitionByExprs);
}

if (pushDownResult.isPresent() && pushDownResult.get()) {
>>>>>>> ecbc7907bb ([BugFix] Broadcast Join should not generate nondetermistic GRF (#44111))
return true;
}

Expand Down
Loading
Loading