batch serialize
Signed-off-by: guo-shaoge <[email protected]>
guo-shaoge committed Jan 14, 2025
1 parent 1dd0fe7 commit a2a6116
Showing 3 changed files with 189 additions and 36 deletions.
180 changes: 157 additions & 23 deletions dbms/src/Common/ColumnsHashing.h
@@ -332,36 +332,149 @@ struct KeyDescStringBinPadding : KeyDescStringBin
}
};

template <size_t batch_size>
struct KeySerializedBatchHandlerBase
{
static constexpr size_t min_batch_size = 256;
static_assert(batch_size == 0 || batch_size >= min_batch_size);

size_t batch_row_idx = 0;
PaddedPODArray<size_t> byte_size{};
String sort_key_container{};
PaddedPODArray<char *> pos{};
PaddedPODArray<char *> ori_pos{};
PaddedPODArray<size_t> real_byte_size{};

void init(const ColumnRawPtrs & key_columns, const TiDB::TiDBCollators & collators, size_t start_row)
{
assert(batch_size >= min_batch_size);

byte_size.resize_fill_zero(key_columns[0]->size());
for (size_t i = 0; i < key_columns.size(); ++i)
key_columns[i]->countSerializeByteSizeUnique(byte_size, collators.empty() ? nullptr : collators[i]);

pos.resize_fill_zero(batch_size);
ori_pos.resize_fill_zero(batch_size);
real_byte_size.resize_fill_zero(batch_size);

batch_row_idx = (start_row / batch_size) * batch_size;
}

ALWAYS_INLINE inline void assertInit() const
{
assert(batch_size == pos.size() && ori_pos.size() == pos.size() && real_byte_size.size() == pos.size());
}

size_t prepareNextBatch(const ColumnRawPtrs & key_columns, Arena * pool, const TiDB::TiDBCollators & collators)
{
assertInit();

const auto len = std::min(batch_size, key_columns[0]->size() - batch_row_idx);
if unlikely (len <= 0)
return 0;

size_t mem_size = 0;
for (size_t i = batch_row_idx; i < batch_row_idx + len; ++i)
mem_size += byte_size[i];

auto * ptr = static_cast<char *>(pool->alignedAlloc(mem_size, 16));
for (size_t i = 0; i < len; ++i)
{
pos[i] = ptr;
ori_pos[i] = ptr;
ptr += byte_size[i + batch_row_idx];
}

for (size_t i = 0; i < key_columns.size(); ++i)
key_columns[i]->serializeToPosUnique(
pos,
batch_row_idx,
len,
false,
collators.empty() ? nullptr : collators[i],
&sort_key_container);

for (size_t i = 0; i < len; ++i)
real_byte_size[i] = pos[i] - ori_pos[i];

batch_row_idx += len;

return mem_size;
}

ALWAYS_INLINE inline ArenaKeyHolder getKeyHolderBatch(ssize_t row, Arena * pool) const
{
assertInit();
assert(static_cast<size_t>(row) < batch_row_idx + batch_size);

const auto idx = row % batch_size;
return ArenaKeyHolder{StringRef{ori_pos[idx], real_byte_size[idx]}, pool};
}
};

/// For the case when there are 2 keys.
template <typename Key1Desc, typename Key2Desc, typename Value, typename Mapped>
template <typename Key1Desc, typename Key2Desc, typename Value, typename Mapped, size_t batch_size = 0>
struct HashMethodFastPathTwoKeysSerialized
: public columns_hashing_impl::
HashMethodBase<HashMethodFastPathTwoKeysSerialized<Key1Desc, Key2Desc, Value, Mapped>, Value, Mapped, false>
: public columns_hashing_impl::HashMethodBase<
HashMethodFastPathTwoKeysSerialized<Key1Desc, Key2Desc, Value, Mapped, batch_size>,
Value,
Mapped,
false>
, KeySerializedBatchHandlerBase<batch_size>
{
using Self = HashMethodFastPathTwoKeysSerialized<Key1Desc, Key2Desc, Value, Mapped>;
using Self = HashMethodFastPathTwoKeysSerialized<Key1Desc, Key2Desc, Value, Mapped, batch_size>;
using Base = columns_hashing_impl::HashMethodBase<Self, Value, Mapped, false>;
using KeyHolderType = SerializedKeyHolder;
using BatchHandlerBase = KeySerializedBatchHandlerBase<batch_size>;
static constexpr bool enable_batch = batch_size > 0;
using KeyHolderType = typename std::conditional<enable_batch, ArenaKeyHolder, SerializedKeyHolder>::type;

static constexpr bool is_serialized_key = true;

ColumnRawPtrs key_columns;
Key1Desc key_1_desc;
Key2Desc key_2_desc;
TiDB::TiDBCollators collators;

HashMethodFastPathTwoKeysSerialized(const ColumnRawPtrs & key_columns, const Sizes &, const TiDB::TiDBCollators &)
: key_1_desc(key_columns[0])
, key_2_desc(key_columns[1])
HashMethodFastPathTwoKeysSerialized(
const ColumnRawPtrs & key_columns_,
const Sizes &,
const TiDB::TiDBCollators & collators_)
: key_columns(key_columns_)
, key_1_desc(key_columns_[0])
, key_2_desc(key_columns_[1])
, collators(collators_)
{}

void initBatchHandler(size_t start_row)
{
if constexpr (enable_batch)
BatchHandlerBase::init(key_columns, collators, start_row);
}

size_t prepareNextBatch(Arena * pool)
{
if constexpr (enable_batch)
return BatchHandlerBase::prepareNextBatch(key_columns, pool, collators);
return 0;
}

ALWAYS_INLINE inline KeyHolderType getKeyHolder(ssize_t row, Arena * pool, std::vector<String> &) const
{
StringRef key1;
StringRef key2;
size_t alloc_size = key_1_desc.getKey(row, key1) + key_2_desc.getKey(row, key2);
char * start = pool->alloc(alloc_size);
SerializedKeyHolder ret{{start, alloc_size}, pool};
Key1Desc::serializeKey(start, key1);
Key2Desc::serializeKey(start, key2);
return ret;
if constexpr (enable_batch)
{
return BatchHandlerBase::getKeyHolderBatch(row, pool);
}
else
{
StringRef key1;
StringRef key2;
size_t alloc_size = key_1_desc.getKey(row, key1) + key_2_desc.getKey(row, key2);
char * start = pool->alloc(alloc_size);
SerializedKeyHolder ret{{start, alloc_size}, pool};
Key1Desc::serializeKey(start, key1);
Key2Desc::serializeKey(start, key2);
return ret;
}
}

protected:
@@ -576,13 +689,16 @@ struct HashMethodKeysFixed
* That is, for example, for strings, it contains first the serialized length of the string, and then the bytes.
* Therefore, when aggregating by several strings, there is no ambiguity.
*/
template <typename Value, typename Mapped>
template <typename Value, typename Mapped, size_t batch_size = 0>
struct HashMethodSerialized
: public columns_hashing_impl::HashMethodBase<HashMethodSerialized<Value, Mapped>, Value, Mapped, false>
: public columns_hashing_impl::HashMethodBase<HashMethodSerialized<Value, Mapped, batch_size>, Value, Mapped, false>
, KeySerializedBatchHandlerBase<batch_size>
{
using Self = HashMethodSerialized<Value, Mapped>;
using Self = HashMethodSerialized<Value, Mapped, batch_size>;
using Base = columns_hashing_impl::HashMethodBase<Self, Value, Mapped, false>;
using KeyHolderType = SerializedKeyHolder;
using BatchHandlerBase = KeySerializedBatchHandlerBase<batch_size>;
static constexpr bool enable_batch = batch_size > 0;
using KeyHolderType = typename std::conditional<enable_batch, ArenaKeyHolder, SerializedKeyHolder>::type;

static constexpr bool is_serialized_key = true;

@@ -599,12 +715,30 @@ struct HashMethodSerialized
, collators(collators_)
{}

void initBatchHandler(size_t start_row)
{
if constexpr (enable_batch)
BatchHandlerBase::init(key_columns, collators, start_row);
}

size_t prepareNextBatch(Arena * pool)
{
if constexpr (enable_batch)
return BatchHandlerBase::prepareNextBatch(key_columns, pool, collators);
return 0;
}

ALWAYS_INLINE inline KeyHolderType getKeyHolder(size_t row, Arena * pool, std::vector<String> & sort_key_containers)
const
{
return SerializedKeyHolder{
serializeKeysToPoolContiguous(row, keys_size, key_columns, collators, sort_key_containers, *pool),
pool};
if constexpr (enable_batch)
return BatchHandlerBase::getKeyHolderBatch(row, pool);
else
{
return SerializedKeyHolder{
serializeKeysToPoolContiguous(row, keys_size, key_columns, collators, sort_key_containers, *pool),
pool};
}
}

protected:
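The core of this change is KeySerializedBatchHandlerBase above: instead of serializing and allocating each key individually, it precomputes the serialized byte size of every row once (countSerializeByteSizeUnique), then, per mini batch, allocates one contiguous chunk and serializes all keys of the batch into it (serializeToPosUnique), handing out per-row StringRef views afterwards. Below is a simplified, self-contained sketch of that flow, not the TiFlash class: it uses std::string/std::string_view in place of IColumn/Arena/StringRef, serializes row by row instead of column by column, omits collation and length encoding, and the name ToyBatchSerializer is made up for illustration.

#include <algorithm>
#include <cstring>
#include <string>
#include <string_view>
#include <vector>

// Toy model of the batch-serialize flow, not the TiFlash implementation.
struct ToyBatchSerializer
{
    static constexpr size_t batch_size = 256;

    std::vector<size_t> byte_size;       // per-row serialized size, computed once up front
    std::vector<char> buffer;            // one contiguous allocation per mini batch
    std::vector<std::string_view> keys;  // per-row views into `buffer`, valid for the current batch only
    size_t batch_row_idx = 0;

    // Mirrors init(): sum the byte size of every key column for each row.
    void init(const std::vector<std::vector<std::string>> & key_columns)
    {
        byte_size.assign(key_columns[0].size(), 0);
        for (const auto & col : key_columns)
            for (size_t r = 0; r < col.size(); ++r)
                byte_size[r] += col[r].size();
    }

    // Mirrors prepareNextBatch(): one allocation for the whole batch, then
    // serialize every row of the batch into it and remember each row's extent.
    // Returns the memory used, so the caller can release it afterwards.
    size_t prepareNextBatch(const std::vector<std::vector<std::string>> & key_columns)
    {
        const size_t rows = key_columns[0].size();
        const size_t len = std::min(batch_size, rows - batch_row_idx);
        if (len == 0)
            return 0;

        size_t mem_size = 0;
        for (size_t r = batch_row_idx; r < batch_row_idx + len; ++r)
            mem_size += byte_size[r];
        buffer.resize(mem_size);
        keys.resize(len);

        char * pos = buffer.data();
        for (size_t i = 0; i < len; ++i)
        {
            char * row_begin = pos;
            for (const auto & col : key_columns)  // the real code goes column by column for all rows
            {
                const std::string & cell = col[batch_row_idx + i];
                std::memcpy(pos, cell.data(), cell.size());
                pos += cell.size();
            }
            keys[i] = std::string_view(row_begin, static_cast<size_t>(pos - row_begin));
        }
        batch_row_idx += len;
        return mem_size;
    }

    // Mirrors getKeyHolderBatch(): return the pre-serialized key of one row.
    std::string_view getKey(size_t row) const { return keys[row % batch_size]; }
};

int main()
{
    std::vector<std::vector<std::string>> key_columns = {{"a", "bb", "ccc"}, {"x", "yy", "zzz"}};
    ToyBatchSerializer s;
    s.init(key_columns);
    s.prepareNextBatch(key_columns);
    return s.getKey(1) == "bbyy" ? 0 : 1;  // row 1 serializes to "bb" + "yy"
}

A caller initializes the handler once per block and then alternates prepareNextBatch() with per-row key lookups, which is the pattern the Aggregator.cpp changes below follow with the real handler.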
37 changes: 27 additions & 10 deletions dbms/src/Interpreters/Aggregator.cpp
@@ -46,9 +46,6 @@ extern const char random_fail_in_resize_callback[];
extern const char force_agg_prefetch[];
} // namespace FailPoints

static constexpr size_t agg_prefetch_step = 16;
static constexpr size_t agg_mini_batch = 256;

#define AggregationMethodName(NAME) AggregatedDataVariants::AggregationMethod_##NAME
#define AggregationMethodNameTwoLevel(NAME) AggregatedDataVariants::AggregationMethod_##NAME##_two_level
#define AggregationMethodType(NAME) AggregatedDataVariants::Type::NAME
@@ -680,7 +677,18 @@ void NO_INLINE Aggregator::executeImpl(

if constexpr (Method::State::is_serialized_key)
{
executeImplMiniBatch<collect_hit_rate, only_lookup, false>(method, state, aggregates_pool, agg_process_info);
// For serialized keys, memory allocation and key serialization are done batch-wise,
// so the batch handler needs to be initialized first.
state.initBatchHandler(agg_process_info.start_row);

if (disable_prefetch)
executeImplMiniBatch<collect_hit_rate, only_lookup, false>(
method,
state,
aggregates_pool,
agg_process_info);
else
executeImplMiniBatch<collect_hit_rate, only_lookup, true>(method, state, aggregates_pool, agg_process_info);
}
else if constexpr (Method::Data::is_string_hash_map)
{
@@ -742,7 +750,7 @@ std::optional<typename Method::template EmplaceOrFindKeyResult<only_lookup>::Res
}

template <typename Method>
ALWAYS_INLINE inline void prepareBatch(
ALWAYS_INLINE inline void setupHashVals(
size_t row_idx,
size_t end_row,
std::vector<size_t> & hashvals,
@@ -867,17 +875,23 @@ void Aggregator::handleMiniBatchImpl(
key_holders.resize(agg_mini_batch);
}

Arena temp_batch_pool;
// i is the begin row index of each mini batch.
while (i < end)
{
if constexpr (enable_prefetch)
{
if unlikely (i + mini_batch_size > end)
mini_batch_size = end - i;
if unlikely (i + mini_batch_size > end)
mini_batch_size = end - i;

prepareBatch(i, end, hashvals, key_holders, aggregates_pool, sort_key_containers, method, state);
size_t batch_mem_size = 0;
if constexpr (Method::State::is_serialized_key)
{
assert(hashvals.size() == state.getBatchSize());
batch_mem_size = state.prepareNextBatch(&temp_batch_pool);
}

if constexpr (enable_prefetch)
setupHashVals(i, end, hashvals, key_holders, aggregates_pool, sort_key_containers, method, state);

const auto cur_batch_end = i + mini_batch_size;
// j is the row index of Column.
// k is the index of hashvals/key_holders.
@@ -961,6 +975,9 @@ void Aggregator::handleMiniBatchImpl(
processed_rows = j;
}

if constexpr (Method::State::is_serialized_key)
temp_batch_pool.rollback(batch_mem_size);

if unlikely (!processed_rows.has_value())
break;

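The temp_batch_pool / rollback(batch_mem_size) pair above makes the per-batch allocation pure scratch memory: prepareNextBatch serializes the batch's keys into the temporary arena, the batch is emplaced into the hash table, and the same number of bytes is then handed back so the next batch reuses them (keys that actually stay in the table are presumably copied out via the key holder before the rollback). A minimal, self-contained sketch of that bump-allocate-then-rollback pattern, with a made-up ToyArena standing in for the real Arena:

#include <cstddef>
#include <vector>

// ToyArena is a made-up stand-in for the real Arena; it only supports the two
// operations the batch loop needs: bump allocation and rolling the bump pointer back.
class ToyArena
{
public:
    explicit ToyArena(size_t capacity) : buf(capacity) {}

    char * alloc(size_t n)
    {
        char * p = buf.data() + used;
        used += n;  // no growth or overflow handling, illustration only
        return p;
    }

    void rollback(size_t n) { used -= n; }

    size_t bytesUsed() const { return used; }

private:
    std::vector<char> buf;
    size_t used = 0;
};

int main()
{
    ToyArena scratch(1 << 20);
    for (int batch = 0; batch < 4; ++batch)
    {
        const size_t batch_mem_size = 8 * 1024;             // what prepareNextBatch() would return
        char * batch_keys = scratch.alloc(batch_mem_size);  // serialize the batch's keys here
        (void)batch_keys;                                   // ... emplace rows into the hash table ...
        scratch.rollback(batch_mem_size);                   // scratch is reused by the next batch
    }
    return scratch.bytesUsed() == 0 ? 0 : 1;  // all scratch memory was returned
}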
8 changes: 5 additions & 3 deletions dbms/src/Interpreters/Aggregator.h
@@ -52,6 +52,8 @@ class IBlockOutputStream;
template <typename Method>
class AggHashTableToBlocksBlockInputStream;

static constexpr size_t agg_prefetch_step = 16;
static constexpr size_t agg_mini_batch = 256;

/** Different data structures that can be used for aggregation
* For efficiency, the aggregation data itself is put into the pool.
@@ -359,8 +361,8 @@ struct AggregationMethodFastPathTwoKeysNoCache
: data(other.data)
{}

using State
= ColumnsHashing::HashMethodFastPathTwoKeysSerialized<Key1Desc, Key2Desc, typename Data::value_type, Mapped>;
using State = ColumnsHashing::
HashMethodFastPathTwoKeysSerialized<Key1Desc, Key2Desc, typename Data::value_type, Mapped, agg_mini_batch>;
template <bool only_lookup>
struct EmplaceOrFindKeyResult
{
@@ -683,7 +685,7 @@ struct AggregationMethodSerialized
: data(other.data)
{}

using State = ColumnsHashing::HashMethodSerialized<typename Data::value_type, Mapped>;
using State = ColumnsHashing::HashMethodSerialized<typename Data::value_type, Mapped, agg_mini_batch>;
template <bool only_lookup>
struct EmplaceOrFindKeyResult
{
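In the header, agg_mini_batch (256) is now passed as the batch_size template argument of the two serialized-key hash methods, while batch_size == 0 keeps the previous single-row path; the switch is resolved at compile time via enable_batch and std::conditional. A minimal, self-contained sketch of that pattern (HashMethodSketch and the empty holder structs are placeholders, not the real types):

#include <type_traits>

struct ArenaKeyHolder {};       // placeholder: batch path hands back a pre-serialized key
struct SerializedKeyHolder {};  // placeholder: single-row path serializes on demand

template <size_t batch_size = 0>
struct HashMethodSketch
{
    static constexpr bool enable_batch = batch_size > 0;
    using KeyHolderType = typename std::conditional<enable_batch, ArenaKeyHolder, SerializedKeyHolder>::type;

    KeyHolderType getKeyHolder() const
    {
        if constexpr (enable_batch)
            return ArenaKeyHolder{};      // fetch the key prepared by prepareNextBatch()
        else
            return SerializedKeyHolder{}; // old behaviour: serialize this single row now
    }
};

// batch_size == 0 -> old single-row key holder; batch_size == 256 -> batched key holder.
static_assert(std::is_same_v<HashMethodSketch<>::KeyHolderType, SerializedKeyHolder>);
static_assert(std::is_same_v<HashMethodSketch<256>::KeyHolderType, ArenaKeyHolder>);

int main() { return 0; }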
