Commit a8c8b5c
[Enhancement] support partition-level cardinality estimation (backport #51513) (#51749)

Co-authored-by: Murphy <[email protected]>
mergify[bot] and murphyatwork authored Oct 11, 2024
1 parent c7a42f3 commit a8c8b5c
Showing 16 changed files with 696 additions and 83 deletions.
43 changes: 43 additions & 0 deletions be/src/runtime/statistic_result_writer.cpp
@@ -29,6 +29,7 @@ const int STATISTIC_HISTOGRAM_VERSION = 2;
const int DICT_STATISTIC_DATA_VERSION = 101;
const int STATISTIC_TABLE_VERSION = 3;
const int STATISTIC_BATCH_VERSION = 4;
const int STATISTIC_PARTITION_VERSION = 11;
const int STATISTIC_EXTERNAL_VERSION = 5;
const int STATISTIC_EXTERNAL_QUERY_VERSION = 6;
const int STATISTIC_EXTERNAL_HISTOGRAM_VERSION = 7;
@@ -148,6 +149,9 @@ StatusOr<TFetchDataResultPtr> StatisticResultWriter::_process_chunk(Chunk* chunk
} else if (version == STATISTIC_TABLE_VERSION) {
RETURN_IF_ERROR_WITH_WARN(_fill_table_statistic_data(version, result_columns, chunk, result.get()),
"Fill table statistic data failed");
} else if (version == STATISTIC_PARTITION_VERSION) {
RETURN_IF_ERROR_WITH_WARN(_fill_partition_statistic_data(version, result_columns, chunk, result.get()),
"Fill partition statistic data failed");
} else if (version == STATISTIC_BATCH_VERSION) {
RETURN_IF_ERROR_WITH_WARN(_fill_full_statistic_data_v4(version, result_columns, chunk, result.get()),
"Fill table statistic data failed");
@@ -380,6 +384,45 @@ Status StatisticResultWriter::_fill_full_statistic_data_v4(int version, const Co
return Status::OK();
}

Status StatisticResultWriter::_fill_partition_statistic_data(int version, const Columns& columns, const Chunk* chunk,
TFetchDataResult* result) {
/*
SQL:
SELECT cast(STATISTIC_PARTITION_VERSION as INT),
       `partition_id`,
       `column_name`,
       hll_cardinality(hll_union(`ndv`)) as distinct_count
*/

SCOPED_TIMER(_serialize_timer);

// columns map to Data.thrift.TStatisticData
DCHECK(columns.size() == 4);

// column 0 holds the statistic version and is skipped here
auto partition_id = ColumnViewer<TYPE_BIGINT>(columns[1]);
auto column_name = ColumnViewer<TYPE_VARCHAR>(columns[2]);
auto distinct_count = ColumnViewer<TYPE_BIGINT>(columns[3]);
std::vector<TStatisticData> data_list;
int num_rows = chunk->num_rows();

data_list.resize(num_rows);
for (int i = 0; i < num_rows; ++i) {
data_list[i].__set_partitionId(partition_id.value(i));
data_list[i].__set_columnName(column_name.value(i).to_string());
data_list[i].__set_countDistinct(distinct_count.value(i));
}

result->result_batch.rows.resize(num_rows);
result->result_batch.__set_statistic_version(version);

ThriftSerializer serializer(true, chunk->memory_usage());
for (int i = 0; i < num_rows; ++i) {
RETURN_IF_ERROR(serializer.serialize(&data_list[i], &result->result_batch.rows[i]));
}
return Status::OK();
}
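
Each row serialized above is decoded on the FE into one TStatisticData struct. A minimal sketch of the three fields populated for STATISTIC_PARTITION_VERSION, assuming standard thrift-generated Java accessors and hypothetical values:

    // Sketch only: the fields written by _fill_partition_statistic_data.
    TStatisticData data = new TStatisticData();
    data.setPartitionId(10001L);     // `partition_id` (hypothetical id)
    data.setColumnName("user_id");   // `column_name` (hypothetical column)
    data.setCountDistinct(123456L);  // hll_cardinality(hll_union(`ndv`))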

/*
FE SQL:
SELECT cast(5 as INT),
2 changes: 2 additions & 0 deletions be/src/runtime/statistic_result_writer.h
@@ -53,6 +53,8 @@ class StatisticResultWriter final : public ResultWriter {

Status _fill_table_statistic_data(int version, const Columns& columns, const Chunk* chunk,
TFetchDataResult* result);
Status _fill_partition_statistic_data(int version, const Columns& columns, const Chunk* chunk,
TFetchDataResult* result);

Status _fill_full_statistic_data_v4(int version, const Columns& columns, const Chunk* chunk,
TFetchDataResult* result);
13 changes: 13 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/qe/SessionVariable.java
@@ -362,6 +362,8 @@ public class SessionVariable implements Serializable, Writable, Cloneable {
public static final String CBO_PRUNE_JSON_SUBFIELD_DEPTH = "cbo_prune_json_subfield_depth";
public static final String ENABLE_OPTIMIZER_REWRITE_GROUPINGSETS_TO_UNION_ALL =
"enable_rewrite_groupingsets_to_union_all";
public static final String ENABLE_PARTITION_LEVEL_CARDINALITY_ESTIMATION =
"enable_partition_level_cardinality_estimation";

public static final String CBO_USE_DB_LOCK = "cbo_use_lock_db";
public static final String CBO_PREDICATE_SUBFIELD_PATH = "cbo_enable_predicate_subfield_path";
@@ -1280,6 +1282,9 @@ public static MaterializedViewRewriteMode parse(String str) {
@VariableMgr.VarAttr(name = ENABLE_OPTIMIZER_REWRITE_GROUPINGSETS_TO_UNION_ALL)
private boolean enableRewriteGroupingSetsToUnionAll = false;

@VariableMgr.VarAttr(name = ENABLE_PARTITION_LEVEL_CARDINALITY_ESTIMATION, flag = VariableMgr.INVISIBLE)
private boolean enablePartitionLevelCardinalityEstimation = true;

// value should be 0~4
// 0 represents automatic selection, and 1, 2, 3, and 4 represent forced selection of AGG of
// corresponding stages respectively. However, stages 3 and 4 can only be generated in
@@ -3213,6 +3218,14 @@ public void setEnableRewriteGroupingSetsToUnionAll(boolean enableRewriteGrouping
this.enableRewriteGroupingSetsToUnionAll = enableRewriteGroupingSetsToUnionAll;
}

public boolean isEnablePartitionLevelCardinalityEstimation() {
return enablePartitionLevelCardinalityEstimation;
}

public void setEnablePartitionLevelCardinalityEstimation(boolean enablePartitionLevelCardinalityEstimation) {
this.enablePartitionLevelCardinalityEstimation = enablePartitionLevelCardinalityEstimation;
}

public void setEnableLowCardinalityOptimize(boolean enableLowCardinalityOptimize) {
this.enableLowCardinalityOptimize = enableLowCardinalityOptimize;
}
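The new switch defaults to true and is flagged VariableMgr.INVISIBLE, so it is not advertised to users but can still be set explicitly. A minimal sketch of toggling it for the current session, assuming a live ConnectContext:

    // Sketch only; equivalent SQL: SET enable_partition_level_cardinality_estimation = false;
    ConnectContext ctx = ConnectContext.get();
    ctx.getSessionVariable().setEnablePartitionLevelCardinalityEstimation(false);
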
fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/CachedStatisticStorage.java
@@ -31,6 +31,7 @@
import com.starrocks.connector.statistics.ConnectorTableColumnStats;
import com.starrocks.server.GlobalStateMgr;
import com.starrocks.statistic.StatisticUtils;
import org.apache.commons.collections4.MapUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

@@ -59,13 +60,20 @@ public class CachedStatisticStorage implements StatisticStorage {
.executor(statsCacheRefresherExecutor)
.buildAsync(new TableStatsCacheLoader());

- AsyncLoadingCache<ColumnStatsCacheKey, Optional<ColumnStatistic>> cachedStatistics = Caffeine.newBuilder()
+ AsyncLoadingCache<ColumnStatsCacheKey, Optional<ColumnStatistic>> columnStatistics = Caffeine.newBuilder()
.expireAfterWrite(Config.statistic_update_interval_sec * 2, TimeUnit.SECONDS)
.refreshAfterWrite(Config.statistic_update_interval_sec, TimeUnit.SECONDS)
.maximumSize(Config.statistic_cache_columns)
.executor(statsCacheRefresherExecutor)
.buildAsync(new ColumnBasicStatsCacheLoader());

AsyncLoadingCache<ColumnStatsCacheKey, Optional<PartitionStats>> partitionStatistics = Caffeine.newBuilder()
.expireAfterWrite(Config.statistic_update_interval_sec * 2, TimeUnit.SECONDS)
.refreshAfterWrite(Config.statistic_update_interval_sec, TimeUnit.SECONDS)
.maximumSize(Config.statistic_cache_columns)
.executor(statsCacheRefresherExecutor)
.buildAsync(new PartitionStatsCacheLoader());

AsyncLoadingCache<ConnectorTableColumnKey, Optional<ConnectorTableColumnStats>> connectorTableCachedStatistics =
Caffeine.newBuilder().expireAfterWrite(Config.statistic_update_interval_sec * 2, TimeUnit.SECONDS)
.refreshAfterWrite(Config.statistic_update_interval_sec, TimeUnit.SECONDS)
@@ -249,7 +257,7 @@ public ColumnStatistic getColumnStatistic(Table table, String column) {
}
try {
CompletableFuture<Optional<ColumnStatistic>> result =
- cachedStatistics.get(new ColumnStatsCacheKey(table.getId(), column));
+ columnStatistics.get(new ColumnStatsCacheKey(table.getId(), column));
if (result.isDone()) {
Optional<ColumnStatistic> realResult;
realResult = result.get();
@@ -284,7 +292,8 @@ public List<ColumnStatistic> getColumnStatistics(Table table, List<String> colum
}

try {
- CompletableFuture<Map<ColumnStatsCacheKey, Optional<ColumnStatistic>>> result = cachedStatistics.getAll(cacheKeys);
+ CompletableFuture<Map<ColumnStatsCacheKey, Optional<ColumnStatistic>>> result =
+         columnStatistics.getAll(cacheKeys);
if (result.isDone()) {
List<ColumnStatistic> columnStatistics = new ArrayList<>();
Map<ColumnStatsCacheKey, Optional<ColumnStatistic>> realResult;
@@ -328,7 +337,8 @@ public List<ColumnStatistic> getColumnStatisticsSync(Table table, List<String> c
}

try {
- Map<ColumnStatsCacheKey, Optional<ColumnStatistic>> result = cachedStatistics.synchronous().getAll(cacheKeys);
+ Map<ColumnStatsCacheKey, Optional<ColumnStatistic>> result =
+         columnStatistics.synchronous().getAll(cacheKeys);
List<ColumnStatistic> columnStatistics = new ArrayList<>();

for (String column : columns) {
@@ -347,6 +357,86 @@ public List<ColumnStatistic> getColumnStatisticsSync(Table table, List<String> c
}
}

/**
 * Load the per-partition NDV of each given column, keyed by column name.
 * NOTE: the partitions argument does not filter the lookup; NDV is loaded
 * for all partitions of each column. Returns null if loading fails.
 */
private Map<String, PartitionStats> getColumnNDVForPartitions(Table table, List<Long> partitions,
List<String> columns) {

List<ColumnStatsCacheKey> cacheKeys = new ArrayList<>();
long tableId = table.getId();
for (String column : columns) {
cacheKeys.add(new ColumnStatsCacheKey(tableId, column));
}

try {
Map<ColumnStatsCacheKey, Optional<PartitionStats>> result =
partitionStatistics.synchronous().getAll(cacheKeys);

Map<String, PartitionStats> columnStatistics = Maps.newHashMap();
for (String column : columns) {
Optional<PartitionStats> columnStatistic = result.get(new ColumnStatsCacheKey(tableId, column));
columnStatistics.put(column, columnStatistic.orElse(null));
}
return columnStatistics;
} catch (Exception e) {
LOG.warn("Get partition NDV fail", e);
return null;
}
}

/**
 * We don't maintain full statistics at the partition level, as most of them are unnecessary.
 * Currently the only partition-level statistic is the distinct count (NDV), which can differ
 * significantly across partitions.
 */
@Override
public Map<Long, List<ColumnStatistic>> getColumnStatisticsOfPartitionLevel(Table table, List<Long> partitions,
List<String> columns) {

Preconditions.checkState(table != null);

// for tables in the statistics blacklist, return null so callers fall back to default column statistics
if (StatisticUtils.statisticTableBlackListCheck(table.getId())) {
return null;
}
if (!StatisticUtils.checkStatisticTableStateNormal()) {
return null;
}

List<ColumnStatistic> columnStatistics = getColumnStatistics(table, columns);
Map<String, PartitionStats> columnNDVForPartitions = getColumnNDVForPartitions(table, partitions, columns);
if (MapUtils.isEmpty(columnNDVForPartitions)) {
return null;
}

Map<Long, List<ColumnStatistic>> result = Maps.newHashMap();
for (long partition : partitions) {
List<ColumnStatistic> newStatistics = Lists.newArrayList();
for (int i = 0; i < columns.size(); i++) {
ColumnStatistic columnStatistic = columnStatistics.get(i);
PartitionStats partitionStats = columnNDVForPartitions.get(columns.get(i));
if (partitionStats == null) {
// some columns are missing statistics
return null;
}
if (!partitionStats.getDistinctCount().containsKey(partition)) {
// some partitions are missing statistics
return null;
}
double distinctCount = partitionStats.getDistinctCount().get(partition);
ColumnStatistic newStats = ColumnStatistic.buildFrom(columnStatistic)
.setDistinctValuesCount(distinctCount).build();
newStatistics.add(newStats);
}
result.put(partition, newStatistics);
}
return result;
}

@Override
public void expireTableAndColumnStatistics(Table table, List<String> columns) {
List<TableStatsCacheKey> tableStatsCacheKeys = Lists.newArrayList();
@@ -363,12 +453,13 @@ public void expireTableAndColumnStatistics(Table table, List<String> columns) {
ColumnStatsCacheKey key = new ColumnStatsCacheKey(table.getId(), column);
allKeys.add(key);
}
- cachedStatistics.synchronous().invalidateAll(allKeys);
+ columnStatistics.synchronous().invalidateAll(allKeys);
}

@Override
public void addColumnStatistic(Table table, String column, ColumnStatistic columnStatistic) {
- this.cachedStatistics.synchronous().put(new ColumnStatsCacheKey(table.getId(), column), Optional.of(columnStatistic));
+ this.columnStatistics.synchronous()
+         .put(new ColumnStatsCacheKey(table.getId(), column), Optional.of(columnStatistic));
}

@Override
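A sketch of how a caller might consume the getColumnStatisticsOfPartitionLevel API added above; it returns null rather than partial data whenever any requested column or partition lacks an NDV record, so callers can fall back to table-level statistics. The identifiers below are hypothetical:

    // Sketch only: per-partition column statistics with NDV adjusted per partition.
    Map<Long, List<ColumnStatistic>> perPartition = statisticStorage.getColumnStatisticsOfPartitionLevel(
            table, Lists.newArrayList(10001L, 10002L), Lists.newArrayList("user_id"));
    if (perPartition == null) {
        // a column or partition is missing statistics: use table-level estimation instead
    } else {
        List<ColumnStatistic> statsOfPartition10001 = perPartition.get(10001L);
    }
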
fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/PartitionStats.java (new file)
@@ -0,0 +1,38 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.starrocks.sql.optimizer.statistics;

import com.google.common.collect.Maps;

import java.util.Map;

/**
* Partition-level statistics
*/
public class PartitionStats {
public final Map<Long, Double> distinctCount;

public PartitionStats() {
this.distinctCount = Maps.newHashMap();
}

public PartitionStats(Map<Long, Double> distinctCount) {
this.distinctCount = distinctCount;
}

public Map<Long, Double> getDistinctCount() {
return distinctCount;
}
}
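
PartitionStats is a thin holder mapping partition id to a statistic value. A minimal usage sketch with hypothetical ids and NDVs:

    // Sketch only: record per-partition NDV for one column.
    Map<Long, Double> ndvByPartition = Maps.newHashMap();
    ndvByPartition.put(10001L, 123456.0);
    ndvByPartition.put(10002L, 98765.0);
    PartitionStats stats = new PartitionStats(ndvByPartition);
    double ndv = stats.getDistinctCount().getOrDefault(10001L, 0.0);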
fe/fe-core/src/main/java/com/starrocks/sql/optimizer/statistics/PartitionStatsCacheLoader.java (new file)
@@ -0,0 +1,80 @@
// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.starrocks.sql.optimizer.statistics;

import com.github.benmanes.caffeine.cache.AsyncCacheLoader;
import com.google.common.collect.Lists;
import com.starrocks.qe.ConnectContext;
import com.starrocks.statistic.StatisticExecutor;
import com.starrocks.statistic.StatisticUtils;
import com.starrocks.thrift.TStatisticData;
import org.checkerframework.checker.nullness.qual.NonNull;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.Executor;

public class PartitionStatsCacheLoader implements AsyncCacheLoader<ColumnStatsCacheKey, Optional<PartitionStats>> {

private final StatisticExecutor statisticExecutor = new StatisticExecutor();

@Override
public @NonNull CompletableFuture<Optional<PartitionStats>> asyncLoad(@NonNull ColumnStatsCacheKey cacheKey,
@NonNull Executor executor) {
return asyncLoadAll(Lists.newArrayList(cacheKey), executor).thenApply(x -> x.get(cacheKey));
}

@Override
public @NonNull CompletableFuture<Map<ColumnStatsCacheKey, Optional<PartitionStats>>>
asyncLoadAll(@NonNull Iterable<? extends @NonNull ColumnStatsCacheKey> cacheKey, @NonNull Executor executor) {
return CompletableFuture.supplyAsync(() -> {
try {
ConnectContext connectContext = StatisticUtils.buildConnectContext();
connectContext.setThreadLocalInfo();

Map<ColumnStatsCacheKey, Optional<PartitionStats>> result = new HashMap<>();
long tableId = -1;
List<String> columns = Lists.newArrayList();
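// all keys in one batch come from the same table (see getColumnNDVForPartitions above),
// so taking tableId from any key is safe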
for (ColumnStatsCacheKey statsCacheKey : cacheKey) {
columns.add(statsCacheKey.column);
tableId = statsCacheKey.tableId;
}
List<TStatisticData> statisticData = statisticExecutor.queryPartitionLevelColumnNDV(connectContext,
tableId, Lists.newArrayList(), columns);
for (TStatisticData data : statisticData) {
ColumnStatsCacheKey key = new ColumnStatsCacheKey(tableId, data.columnName);
result.computeIfAbsent(key, (x) -> Optional.of(new PartitionStats()))
.get().getDistinctCount().put(data.partitionId, (double) data.countDistinct);
}
for (ColumnStatsCacheKey key : cacheKey) {
if (!result.containsKey(key)) {
result.put(key, Optional.empty());
}
}
return result;
} catch (RuntimeException e) {
throw e;
} catch (Exception e) {
throw new CompletionException(e);
} finally {
ConnectContext.remove();
}
}, executor);
}
}
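
A sketch of reading through a cache built with this loader (see the partitionStatistics wiring in CachedStatisticStorage above). Optional.empty() marks columns for which the BE returned no NDV rows, distinguishing "no statistics collected" from a plain cache miss. The tableId and column name are hypothetical:

    // Sketch only: synchronous read through the partition-stats cache.
    Optional<PartitionStats> loaded =
            partitionStatistics.synchronous().get(new ColumnStatsCacheKey(tableId, "user_id"));
    Map<Long, Double> ndvByPartition = loaded
            .map(PartitionStats::getDistinctCount)
            .orElse(Collections.emptyMap());  // empty: fall back to table-level statistics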
