diff --git a/dbms/src/Columns/ColumnAggregateFunction.h b/dbms/src/Columns/ColumnAggregateFunction.h index f5e7963ef4d..61b2d8d515a 100644 --- a/dbms/src/Columns/ColumnAggregateFunction.h +++ b/dbms/src/Columns/ColumnAggregateFunction.h @@ -199,7 +199,7 @@ class ColumnAggregateFunction final : public COWPtrHelper & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */, + const NullMap * /* nullmap */, const TiDB::TiDBCollatorPtr & /* collator */, String * /* sort_key_container */) const override { @@ -217,7 +217,7 @@ class ColumnAggregateFunction final : public COWPtrHelper & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */, + const NullMap * /* nullmap */, const IColumn::Offsets & /* offsets */, const TiDB::TiDBCollatorPtr & /* collator */, String * /* sort_key_container */) const override diff --git a/dbms/src/Columns/ColumnArray.cpp b/dbms/src/Columns/ColumnArray.cpp index 0ba742e092c..ad20ed2a43b 100644 --- a/dbms/src/Columns/ColumnArray.cpp +++ b/dbms/src/Columns/ColumnArray.cpp @@ -230,7 +230,7 @@ void ColumnArray::countSerializeByteSize(PaddedPODArray & byte_size) con countSerializeByteSizeImpl(byte_size, nullptr); } -template +template void ColumnArray::countSerializeByteSizeImpl(PaddedPODArray & byte_size, const TiDB::TiDBCollatorPtr & collator) const { @@ -251,7 +251,7 @@ void ColumnArray::countSerializeByteSizeImpl(PaddedPODArray & byte_size, for (size_t i = 0; i < size; ++i) byte_size[i] += sizeof(UInt32); - if constexpr (for_compare) + if constexpr (compare_semantics) getData().countSerializeByteSizeForCmpColumnArray(byte_size, getOffsets(), collator); else getData().countSerializeByteSizeForColumnArray(byte_size, getOffsets()); @@ -261,35 +261,63 @@ void ColumnArray::serializeToPosForCmp( PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const { - if (has_null) - serializeToPosImpl(pos, start, length, collator, sort_key_container); + if (nullmap != nullptr) + serializeToPosImpl( + pos, + start, + length, + collator, + sort_key_container, + nullmap); else - serializeToPosImpl(pos, start, length, collator, sort_key_container); + serializeToPosImpl( + pos, + start, + length, + collator, + sort_key_container, + nullptr); } void ColumnArray::serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const { if (has_null) - serializeToPosImpl(pos, start, length, nullptr, nullptr); + serializeToPosImpl( + pos, + start, + length, + nullptr, + nullptr, + nullptr); else - serializeToPosImpl(pos, start, length, nullptr, nullptr); + serializeToPosImpl( + pos, + start, + length, + nullptr, + nullptr, + nullptr); } -template +template void ColumnArray::serializeToPosImpl( PaddedPODArray & pos, size_t start, size_t length, const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container) const + String * sort_key_container, + const NullMap * nullmap) const { RUNTIME_CHECK_MSG(length <= pos.size(), "length({}) > size of pos({})", length, pos.size()); RUNTIME_CHECK_MSG(start + length <= size(), "start({}) + length({}) > size of column({})", start, length, size()); + static_assert(!(has_null && has_nullmap)); + RUNTIME_CHECK(!has_nullmap || (nullmap && nullmap->size() == size())); + /// countSerializeByteSize has already checked that the size of one element is not greater than UINT32_MAX for (size_t i = 0; i < length; ++i) { @@ -298,14 +326,20 @@ void ColumnArray::serializeToPosImpl( if (pos[i] == nullptr) continue; } + UInt32 len = sizeAt(start + i); + if constexpr (has_nullmap) + { + if (DB::isNullAt(*nullmap, start + i)) + len = 0; + } tiflash_compiler_builtin_memcpy(pos[i], &len, sizeof(UInt32)); pos[i] += sizeof(UInt32); } - if constexpr (for_compare) + if constexpr (compare_semantics) getData() - .serializeToPosForCmpColumnArray(pos, start, length, has_null, getOffsets(), collator, sort_key_container); + .serializeToPosForCmpColumnArray(pos, start, length, nullmap, getOffsets(), collator, sort_key_container); else getData().serializeToPosForColumnArray(pos, start, length, has_null, getOffsets()); } @@ -320,7 +354,7 @@ void ColumnArray::deserializeAndInsertFromPos(PaddedPODArray & pos, bool deserializeAndInsertFromPosImpl(pos, use_nt_align_buffer); } -template +template void ColumnArray::deserializeAndInsertFromPosImpl(PaddedPODArray & pos, bool use_nt_align_buffer) { auto & offsets = getOffsets(); @@ -336,7 +370,7 @@ void ColumnArray::deserializeAndInsertFromPosImpl(PaddedPODArray & pos, pos[i] += sizeof(UInt32); } - if constexpr (for_compare) + if constexpr (compare_semantics) getData().deserializeForCmpAndInsertFromPosColumnArray(pos, offsets, use_nt_align_buffer); else getData().deserializeAndInsertFromPosForColumnArray(pos, offsets, use_nt_align_buffer); diff --git a/dbms/src/Columns/ColumnArray.h b/dbms/src/Columns/ColumnArray.h index 7eb91feb028..9052890d1fe 100644 --- a/dbms/src/Columns/ColumnArray.h +++ b/dbms/src/Columns/ColumnArray.h @@ -44,18 +44,19 @@ class ColumnArray final : public COWPtrHelper ColumnArray(const ColumnArray &) = default; - template + template void countSerializeByteSizeImpl(PaddedPODArray & byte_size, const TiDB::TiDBCollatorPtr & collator) const; - template + template void serializeToPosImpl( PaddedPODArray & pos, size_t start, size_t length, const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container) const; + String * sort_key_container, + const NullMap * nullmap) const; - template + template void deserializeAndInsertFromPosImpl(PaddedPODArray & pos, bool use_nt_align_buffer); public: @@ -122,7 +123,7 @@ class ColumnArray final : public COWPtrHelper PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override; void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; @@ -131,7 +132,7 @@ class ColumnArray final : public COWPtrHelper PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */, + const NullMap * /* nullmap */, const IColumn::Offsets & /* array_offsets */, const TiDB::TiDBCollatorPtr & /* collator */, String * /* sort_key_container */) const override diff --git a/dbms/src/Columns/ColumnConst.h b/dbms/src/Columns/ColumnConst.h index c17ba694ead..a2a46c04754 100644 --- a/dbms/src/Columns/ColumnConst.h +++ b/dbms/src/Columns/ColumnConst.h @@ -147,7 +147,7 @@ class ColumnConst final : public COWPtrHelper PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */, + const NullMap * /* nullmap */, const TiDB::TiDBCollatorPtr & /* collator */, String * /* sort_key_container */) const override { @@ -166,7 +166,7 @@ class ColumnConst final : public COWPtrHelper PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */, + const NullMap * /* nullmap */, const IColumn::Offsets & /* array_offsets */, const TiDB::TiDBCollatorPtr & /* collator */, String * /* sort_key_container */) const override diff --git a/dbms/src/Columns/ColumnDecimal.cpp b/dbms/src/Columns/ColumnDecimal.cpp index 36a774975ca..610739d15c9 100644 --- a/dbms/src/Columns/ColumnDecimal.cpp +++ b/dbms/src/Columns/ColumnDecimal.cpp @@ -140,13 +140,13 @@ const char * ColumnDecimal::deserializeAndInsertFromArena(const char * pos, c } template -template -void ColumnDecimal::countSerializeByteSizeImpl(PaddedPODArray & byte_size) const +template +void ColumnDecimal::countSerializeByteSizeImpl(PaddedPODArray & byte_size, const NullMap *) const { RUNTIME_CHECK_MSG(byte_size.size() == size(), "size of byte_size({}) != column size({})", byte_size.size(), size()); size_t size = byte_size.size(); - if constexpr (for_compare && is_Decimal256) + if constexpr (compare_semantics && is_Decimal256) { for (size_t i = 0; i < size; ++i) { @@ -160,9 +160,8 @@ void ColumnDecimal::countSerializeByteSizeImpl(PaddedPODArray & byte_ } } -// TODO add unit test template -template +template void ColumnDecimal::countSerializeByteSizeForColumnArrayImpl( PaddedPODArray & byte_size, const IColumn::Offsets & array_offsets) const @@ -173,7 +172,7 @@ void ColumnDecimal::countSerializeByteSizeForColumnArrayImpl( byte_size.size(), array_offsets.size()); - if constexpr (for_compare && is_Decimal256) + if constexpr (compare_semantics && is_Decimal256) { size_t size = array_offsets.size(); for (size_t i = 0; i < size; ++i) @@ -194,12 +193,20 @@ void ColumnDecimal::countSerializeByteSizeForColumnArrayImpl( } template -template -void ColumnDecimal::serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length) const +template +void ColumnDecimal::serializeToPosImpl( + PaddedPODArray & pos, + size_t start, + size_t length, + const NullMap * nullmap) const { RUNTIME_CHECK_MSG(length <= pos.size(), "length({}) > size of pos({})", length, pos.size()); RUNTIME_CHECK_MSG(start + length <= size(), "start({}) + length({}) > size of column({})", start, length, size()); + static_assert(!(has_null && has_nullmap)); + RUNTIME_CHECK(!has_nullmap || (nullmap && nullmap->size() == size())); + + static constexpr T def_val{}; for (size_t i = 0; i < length; ++i) { if constexpr (has_null) @@ -207,8 +214,14 @@ void ColumnDecimal::serializeToPosImpl(PaddedPODArray & pos, size_t s if (pos[i] == nullptr) continue; } + if constexpr (has_nullmap) + { + if (DB::isNullAt(*nullmap, start + i)) + pos[i] = serializeDecimal256Helper(pos[i], def_val); + continue; + } - if constexpr (for_compare && is_Decimal256) + if constexpr (compare_semantics && is_Decimal256) { pos[i] = serializeDecimal256Helper(pos[i], data[start + i]); } @@ -221,12 +234,13 @@ void ColumnDecimal::serializeToPosImpl(PaddedPODArray & pos, size_t s } template -template +template void ColumnDecimal::serializeToPosForColumnArrayImpl( PaddedPODArray & pos, size_t start, size_t length, - const IColumn::Offsets & array_offsets) const + const IColumn::Offsets & array_offsets, + const NullMap * nullmap) const { RUNTIME_CHECK_MSG(length <= pos.size(), "length({}) > size of pos({})", length, pos.size()); RUNTIME_CHECK_MSG( @@ -241,6 +255,9 @@ void ColumnDecimal::serializeToPosForColumnArrayImpl( array_offsets.back(), size()); + static_assert(!(has_null && has_nullmap)); + RUNTIME_CHECK(!has_nullmap || (nullmap && nullmap->size() == array_offsets.size())); + for (size_t i = 0; i < length; ++i) { if constexpr (has_null) @@ -248,9 +265,14 @@ void ColumnDecimal::serializeToPosForColumnArrayImpl( if (pos[i] == nullptr) continue; } + if constexpr (has_nullmap) + { + if (DB::isNullAt(*nullmap, start + i)) + continue; + } size_t len = array_offsets[start + i] - array_offsets[start + i - 1]; - if constexpr (for_compare && is_Decimal256) + if constexpr (compare_semantics && is_Decimal256) { for (size_t j = 0; j < len; ++j) pos[i] = serializeDecimal256Helper(pos[i], data[array_offsets[start + i - 1] + j]); @@ -275,7 +297,7 @@ void ColumnDecimal::serializeToPosForColumnArrayImpl( } template -template +template void ColumnDecimal::deserializeAndInsertFromPosImpl( PaddedPODArray & pos, bool use_nt_align_buffer [[maybe_unused]]) @@ -285,7 +307,7 @@ void ColumnDecimal::deserializeAndInsertFromPosImpl( // is_complex_decimal256 is true means Decimal256 is serialized by [bool, limb_count, n * limb]. // NT optimization is not implemented for simplicity. - static const bool is_complex_decimal256 = (for_compare && is_Decimal256); + static const bool is_complex_decimal256 = (compare_semantics && is_Decimal256); #ifdef TIFLASH_ENABLE_AVX_SUPPORT if (use_nt_align_buffer) @@ -383,7 +405,7 @@ void ColumnDecimal::deserializeAndInsertFromPosImpl( } template -template +template void ColumnDecimal::deserializeAndInsertFromPosForColumnArrayImpl( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, @@ -410,7 +432,7 @@ void ColumnDecimal::deserializeAndInsertFromPosForColumnArrayImpl( for (size_t i = 0; i < size; ++i) { size_t len = array_offsets[start_point + i] - array_offsets[start_point + i - 1]; - if constexpr (for_compare && is_Decimal256) + if constexpr (compare_semantics && is_Decimal256) { for (size_t j = 0; j < len; ++j) pos[i] = const_cast( diff --git a/dbms/src/Columns/ColumnDecimal.h b/dbms/src/Columns/ColumnDecimal.h index 18211b13177..ddfacb98bae 100644 --- a/dbms/src/Columns/ColumnDecimal.h +++ b/dbms/src/Columns/ColumnDecimal.h @@ -101,26 +101,27 @@ class ColumnDecimal final : public COWPtrHelper - void countSerializeByteSizeImpl(PaddedPODArray & byte_size) const; - template + template + void countSerializeByteSizeImpl(PaddedPODArray & byte_size, const NullMap *) const; + template void countSerializeByteSizeForColumnArrayImpl( PaddedPODArray & byte_size, const IColumn::Offsets & array_offsets) const; - template - void serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length) const; - template + template + void serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length, const NullMap * nullmap) const; + template void serializeToPosForColumnArrayImpl( PaddedPODArray & pos, size_t start, size_t length, - const IColumn::Offsets & array_offsets) const; + const IColumn::Offsets & array_offsets, + const NullMap * nullmap) const; - template + template void deserializeAndInsertFromPosImpl(PaddedPODArray & pos, bool use_nt_align_buffer [[maybe_unused]]); - template + template void deserializeAndInsertFromPosForColumnArrayImpl( PaddedPODArray & pos, const IColumn::Offsets & array_offsets, @@ -175,11 +176,11 @@ class ColumnDecimal final : public COWPtrHelper & byte_size, const TiDB::TiDBCollatorPtr &) const override { - countSerializeByteSizeImpl(byte_size); + countSerializeByteSizeImpl(byte_size, nullptr); } void countSerializeByteSize(PaddedPODArray & byte_size) const override { - countSerializeByteSizeImpl(byte_size); + countSerializeByteSizeImpl(byte_size, nullptr); } void countSerializeByteSizeForCmpColumnArray( @@ -200,44 +201,62 @@ class ColumnDecimal final : public COWPtrHelper & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const TiDB::TiDBCollatorPtr &, String *) const override { - if (has_null) - serializeToPosImpl(pos, start, length); + if (nullmap != nullptr) + serializeToPosImpl( + pos, + start, + length, + nullmap); else - serializeToPosImpl(pos, start, length); + serializeToPosImpl( + pos, + start, + length, + nullptr); } void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override { if (has_null) - serializeToPosImpl(pos, start, length); + serializeToPosImpl( + pos, + start, + length, + nullptr); else - serializeToPosImpl(pos, start, length); + serializeToPosImpl( + pos, + start, + length, + nullptr); } void serializeToPosForCmpColumnArray( PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr &, String *) const override { - if (has_null) - serializeToPosForColumnArrayImpl( + if (nullmap != nullptr) + serializeToPosForColumnArrayImpl( pos, start, length, - array_offsets); + array_offsets, + nullmap); else - serializeToPosForColumnArrayImpl( + serializeToPosForColumnArrayImpl( pos, start, length, - array_offsets); + array_offsets, + nullptr); } void serializeToPosForColumnArray( PaddedPODArray & pos, @@ -247,17 +266,19 @@ class ColumnDecimal final : public COWPtrHelper( + serializeToPosForColumnArrayImpl( pos, start, length, - array_offsets); + array_offsets, + nullptr); else - serializeToPosForColumnArrayImpl( + serializeToPosForColumnArrayImpl( pos, start, length, - array_offsets); + array_offsets, + nullptr); } void deserializeForCmpAndInsertFromPos(PaddedPODArray & pos, bool use_nt_align_buffer) override diff --git a/dbms/src/Columns/ColumnFixedString.cpp b/dbms/src/Columns/ColumnFixedString.cpp index b15e420a1ce..217b19cf913 100644 --- a/dbms/src/Columns/ColumnFixedString.cpp +++ b/dbms/src/Columns/ColumnFixedString.cpp @@ -134,7 +134,7 @@ const char * ColumnFixedString::deserializeAndInsertFromArena(const char * pos, return pos + n; } -void ColumnFixedString::countSerializeByteSize(PaddedPODArray & byte_size) const +void ColumnFixedString::countSerializeByteSizeImpl(PaddedPODArray & byte_size) const { RUNTIME_CHECK_MSG(byte_size.size() == size(), "size of byte_size({}) != column size({})", byte_size.size(), size()); @@ -143,7 +143,7 @@ void ColumnFixedString::countSerializeByteSize(PaddedPODArray & byte_siz byte_size[i] += n; } -void ColumnFixedString::countSerializeByteSizeForColumnArray( +void ColumnFixedString::countSerializeByteSizeForColumnArrayImpl( PaddedPODArray & byte_size, const IColumn::Offsets & array_offsets) const { @@ -161,17 +161,24 @@ void ColumnFixedString::countSerializeByteSizeForColumnArray( void ColumnFixedString::serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const { if (has_null) - serializeToPosImpl(pos, start, length); + serializeToPosImpl(pos, start, length, nullptr); else - serializeToPosImpl(pos, start, length); + serializeToPosImpl(pos, start, length, nullptr); } -template -void ColumnFixedString::serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length) const +template +void ColumnFixedString::serializeToPosImpl( + PaddedPODArray & pos, + size_t start, + size_t length, + const NullMap * nullmap) const { RUNTIME_CHECK_MSG(length <= pos.size(), "length({}) > size of pos({})", length, pos.size()); RUNTIME_CHECK_MSG(start + length <= size(), "start({}) + length({}) > size of column({})", start, length, size()); + static_assert(!(has_null && has_nullmap)); + RUNTIME_CHECK(!has_nullmap || (nullmap && nullmap->size() == size())); + for (size_t i = 0; i < length; ++i) { if constexpr (has_null) @@ -179,6 +186,15 @@ void ColumnFixedString::serializeToPosImpl(PaddedPODArray & pos, size_t if (pos[i] == nullptr) continue; } + if constexpr (has_nullmap) + { + if (DB::isNullAt(*nullmap, start + i)) + { + memset(pos[i], '\0', n); + pos[i] += n; + continue; + } + } inline_memcpy(pos[i], &chars[n * (start + i)], n); pos[i] += n; } @@ -192,17 +208,18 @@ void ColumnFixedString::serializeToPosForColumnArray( const IColumn::Offsets & array_offsets) const { if (has_null) - serializeToPosForColumnArrayImpl(pos, start, length, array_offsets); + serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullptr); else - serializeToPosForColumnArrayImpl(pos, start, length, array_offsets); + serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullptr); } -template +template void ColumnFixedString::serializeToPosForColumnArrayImpl( PaddedPODArray & pos, size_t start, size_t length, - const IColumn::Offsets & array_offsets) const + const IColumn::Offsets & array_offsets, + const NullMap * nullmap) const { RUNTIME_CHECK_MSG(length <= pos.size(), "length({}) > size of pos({})", length, pos.size()); RUNTIME_CHECK_MSG( @@ -217,6 +234,9 @@ void ColumnFixedString::serializeToPosForColumnArrayImpl( array_offsets.back(), size()); + static_assert(!(has_null && has_nullmap)); + RUNTIME_CHECK(!has_nullmap || (nullmap && nullmap->size() == array_offsets.size())); + for (size_t i = 0; i < length; ++i) { if constexpr (has_null) @@ -224,7 +244,11 @@ void ColumnFixedString::serializeToPosForColumnArrayImpl( if (pos[i] == nullptr) continue; } - + if constexpr (has_nullmap) + { + if (DB::isNullAt(*nullmap, start + i)) + continue; + } size_t len = array_offsets[start + i] - array_offsets[start + i - 1]; inline_memcpy(pos[i], &chars[n * array_offsets[start + i - 1]], n * len); pos[i] += n * len; diff --git a/dbms/src/Columns/ColumnFixedString.h b/dbms/src/Columns/ColumnFixedString.h index 161cef5cf8a..b909cf37927 100644 --- a/dbms/src/Columns/ColumnFixedString.h +++ b/dbms/src/Columns/ColumnFixedString.h @@ -54,16 +54,22 @@ class ColumnFixedString final : public COWPtrHelper , chars(src.chars.begin(), src.chars.end()) , n(src.n){}; + void countSerializeByteSizeImpl(PaddedPODArray & byte_size) const; - template - void serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length) const; + void countSerializeByteSizeForColumnArrayImpl( + PaddedPODArray & byte_size, + const IColumn::Offsets & array_offsets) const; + + template + void serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length, const NullMap * nullmap) const; - template + template void serializeToPosForColumnArrayImpl( PaddedPODArray & pos, size_t start, size_t length, - const IColumn::Offsets & array_offsets) const; + const IColumn::Offsets & array_offsets, + const NullMap * nullmap) const; public: std::string getName() const override { return "FixedString(" + std::to_string(n) + ")"; } @@ -123,9 +129,12 @@ class ColumnFixedString final : public COWPtrHelper !collator, "{} doesn't support countSerializeByteSizeForCmp when collator is not null", getName()); - countSerializeByteSize(byte_size); + countSerializeByteSizeImpl(byte_size); + } + void countSerializeByteSize(PaddedPODArray & byte_size) const override + { + countSerializeByteSizeImpl(byte_size); } - void countSerializeByteSize(PaddedPODArray & byte_size) const override; void countSerializeByteSizeForCmpColumnArray( PaddedPODArray & byte_size, @@ -136,22 +145,28 @@ class ColumnFixedString final : public COWPtrHelper !collator, "{} doesn't support countSerializeByteSizeForCmpColumnArray when collator is not null", getName()); - countSerializeByteSizeForColumnArray(byte_size, array_offsets); + countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets); } void countSerializeByteSizeForColumnArray( PaddedPODArray & byte_size, - const IColumn::Offsets & array_offsets) const override; + const IColumn::Offsets & array_offsets) const override + { + countSerializeByteSizeForColumnArrayImpl(byte_size, array_offsets); + } void serializeToPosForCmp( PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String *) const override { RUNTIME_CHECK_MSG(!collator, "{} doesn't support serializeToPosForCmp when collator is not null", getName()); - serializeToPos(pos, start, length, has_null); + if (nullmap != nullptr) + serializeToPosImpl(pos, start, length, nullmap); + else + serializeToPosImpl(pos, start, length, nullptr); } void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; @@ -159,7 +174,7 @@ class ColumnFixedString final : public COWPtrHelper PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, String *) const override @@ -168,7 +183,10 @@ class ColumnFixedString final : public COWPtrHelper !collator, "{} doesn't support serializeToPosForCmpColumnArray when collator is not null", getName()); - serializeToPosForColumnArray(pos, start, length, has_null, array_offsets); + if (nullmap != nullptr) + serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullmap); + else + serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullptr); } void serializeToPosForColumnArray( PaddedPODArray & pos, diff --git a/dbms/src/Columns/ColumnFunction.h b/dbms/src/Columns/ColumnFunction.h index fd43650d255..8343ea4d6c5 100644 --- a/dbms/src/Columns/ColumnFunction.h +++ b/dbms/src/Columns/ColumnFunction.h @@ -155,7 +155,7 @@ class ColumnFunction final : public COWPtrHelper PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */, + const NullMap * /* nullmap */, const TiDB::TiDBCollatorPtr & /* collator */, String * /*sort_key_container */) const override { @@ -174,7 +174,7 @@ class ColumnFunction final : public COWPtrHelper PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */, + const NullMap * /* nullmap */, const IColumn::Offsets & /* array_offsets */, const TiDB::TiDBCollatorPtr & /* collator */, String * /* sort_key_container */) const override diff --git a/dbms/src/Columns/ColumnNullable.cpp b/dbms/src/Columns/ColumnNullable.cpp index acfa920ca7d..c9c7e6e82ee 100644 --- a/dbms/src/Columns/ColumnNullable.cpp +++ b/dbms/src/Columns/ColumnNullable.cpp @@ -315,13 +315,16 @@ void ColumnNullable::serializeToPosForCmp( PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const { - getNullMapColumn().serializeToPosForCmp(pos, start, length, has_null, collator, sort_key_container); - getNestedColumn().serializeToPosForCmp(pos, start, length, has_null, collator, sort_key_container); + // Nested ColumnNullable like ColumnNullable(ColumnArray(ColumnNullable(ColumnXXX))) not support. + RUNTIME_CHECK_MSG(!nullmap, "serializeToPosForCmp cannot handle nested nullable"); + getNullMapColumn().serializeToPosForCmp(pos, start, length, nullptr, collator, sort_key_container); + getNestedColumn().serializeToPosForCmp(pos, start, length, &getNullMapData(), collator, sort_key_container); } + void ColumnNullable::serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const { getNullMapColumn().serializeToPos(pos, start, length, has_null); @@ -329,18 +332,20 @@ void ColumnNullable::serializeToPos(PaddedPODArray & pos, size_t start, } void ColumnNullable::serializeToPosForCmpColumnArray( - PaddedPODArray & pos, - size_t start, - size_t length, - bool has_null, - const IColumn::Offsets & array_offsets, - const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container) const -{ - getNullMapColumn() - .serializeToPosForCmpColumnArray(pos, start, length, has_null, array_offsets, collator, sort_key_container); - getNestedColumn() - .serializeToPosForCmpColumnArray(pos, start, length, has_null, array_offsets, collator, sort_key_container); + PaddedPODArray & /* pos */, + size_t /* start */, + size_t /* length */, + const NullMap * /* nullmap */, + const IColumn::Offsets & /* array_offsets */, + const TiDB::TiDBCollatorPtr & /* collator */, + String * /* sort_key_container */) const +{ + // Unable to handle ColumnArray(ColumnNullable(ColumnXXX)). Because the pos vector corresponds to the rows of ColumnArray, + // while ColumnNullable::nullmap corresponds to the rows of ColumnNullable. + // This means it's not easy to correctly serialize the row in ColumnNullable to the corresponding position in pos. + throw Exception( + "serializeToPosForCmpColumnArray cannot handle ColumnArray(" + getName() + ")", + ErrorCodes::NOT_IMPLEMENTED); } void ColumnNullable::serializeToPosForColumnArray( PaddedPODArray & pos, diff --git a/dbms/src/Columns/ColumnNullable.h b/dbms/src/Columns/ColumnNullable.h index 76b5a6708b5..5b216572e66 100644 --- a/dbms/src/Columns/ColumnNullable.h +++ b/dbms/src/Columns/ColumnNullable.h @@ -21,8 +21,7 @@ namespace DB { -using NullMap = ColumnUInt8::Container; -using ConstNullMapPtr = const NullMap *; +static_assert(std::is_same_v); /// Class that specifies nullable columns. A nullable column represents /// a column, which may have any type, provided with the possibility of @@ -61,7 +60,10 @@ class ColumnNullable final : public COWPtrHelper std::string getName() const override { return "Nullable(" + nested_column->getName() + ")"; } MutableColumnPtr cloneResized(size_t size) const override; size_t size() const override { return static_cast(*null_map).size(); } - bool isNullAt(size_t n) const override { return static_cast(*null_map).getData()[n] != 0; } + bool isNullAt(size_t n) const override + { + return DB::isNullAt(static_cast(*null_map).getData(), n); + } Field operator[](size_t n) const override; void get(size_t n, Field & res) const override; UInt64 get64(size_t n) const override { return nested_column->get64(n); } @@ -94,7 +96,7 @@ class ColumnNullable final : public COWPtrHelper PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override; void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; @@ -103,7 +105,7 @@ class ColumnNullable final : public COWPtrHelper PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override; diff --git a/dbms/src/Columns/ColumnString.cpp b/dbms/src/Columns/ColumnString.cpp index 537ea145a59..d2e1ab47bec 100644 --- a/dbms/src/Columns/ColumnString.cpp +++ b/dbms/src/Columns/ColumnString.cpp @@ -35,6 +35,18 @@ extern const int PARAMETER_OUT_OF_BOUND; extern const int SIZES_OF_COLUMNS_DOESNT_MATCH; } // namespace ErrorCodes +struct ColumnStringDefaultValue +{ + char mem[sizeof(UInt32) + 1] = {0}; + ColumnStringDefaultValue() + { + UInt32 str_size = 1; + tiflash_compiler_builtin_memcpy(&mem[0], &str_size, sizeof(str_size)); + } +}; + +static ColumnStringDefaultValue col_str_def_val; + MutableColumnPtr ColumnString::cloneResized(size_t to_size) const { auto res = ColumnString::create(); @@ -659,61 +671,94 @@ void ColumnString::serializeToPosForCmp( PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const { - if (has_null) + if (nullmap != nullptr) { if likely (collator != nullptr) - serializeToPosImplType( + serializeToPosImplType( pos, start, length, collator, - sort_key_container); + sort_key_container, + nullmap); else - serializeToPosImplType(pos, start, length, nullptr, nullptr); + serializeToPosImplType( + pos, + start, + length, + nullptr, + nullptr, + nullmap); } else { if likely (collator != nullptr) - serializeToPosImplType( + serializeToPosImplType( pos, start, length, collator, - sort_key_container); + sort_key_container, + nullptr); else - serializeToPosImplType(pos, start, length, nullptr, nullptr); + serializeToPosImplType( + pos, + start, + length, + nullptr, + nullptr, + nullptr); } } void ColumnString::serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const { if (has_null) - serializeToPosImplType(pos, start, length, nullptr, nullptr); + serializeToPosImplType( + pos, + start, + length, + nullptr, + nullptr, + nullptr); else - serializeToPosImplType(pos, start, length, nullptr, nullptr); + serializeToPosImplType( + pos, + start, + length, + nullptr, + nullptr, + nullptr); } -template +template void ColumnString::serializeToPosImplType( PaddedPODArray & pos, size_t start, size_t length, const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container) const + String * sort_key_container, + const NullMap * nullmap) const { - if constexpr (has_collator) + if constexpr (compare_semantics) { RUNTIME_CHECK(collator && sort_key_container); -#define M(VAR_PREFIX, COLLATOR_NAME, IMPL_TYPE, COLLATOR_ID) \ - case (COLLATOR_ID): \ - { \ - serializeToPosImpl(pos, start, length, collator, sort_key_container); \ - break; \ +#define M(VAR_PREFIX, COLLATOR_NAME, IMPL_TYPE, COLLATOR_ID) \ + case (COLLATOR_ID): \ + { \ + serializeToPosImpl( \ + pos, \ + start, \ + length, \ + collator, \ + sort_key_container, \ + nullmap); \ + break; \ } switch (collator->getCollatorId()) @@ -728,41 +773,50 @@ void ColumnString::serializeToPosImplType( } else { - serializeToPosImpl( + assert(!nullmap); + serializeToPosImpl( pos, start, length, collator, - sort_key_container); + sort_key_container, + nullptr); } } -template +template void ColumnString::serializeToPosImpl( PaddedPODArray & pos, size_t start, size_t length, const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container) const + String * sort_key_container, + const NullMap * nullmap) const { RUNTIME_CHECK_MSG(length <= pos.size(), "length({}) > size of pos({})", length, pos.size()); RUNTIME_CHECK_MSG(start + length <= size(), "start({}) + length({}) > size of column({})", start, length, size()); + static_assert(!(has_null && has_nullmap)); + RUNTIME_CHECK(!has_nullmap || (nullmap && nullmap->size() == size())); + /// To avoid virtual function call of sortKey(). const auto * derived_collator = static_cast(collator); /// countSerializeByteSizeImpl has already checked that the size of one element is not greater than UINT32_MAX for (size_t i = 0; i < length; ++i) { - if constexpr (has_null) - { - if (pos[i] == nullptr) - continue; - } - - UInt32 str_size = sizeAt(start + i); - const void * src = &chars[offsetAt(start + i)]; - if constexpr (has_collator) + if constexpr (compare_semantics) { + if constexpr (has_nullmap) + { + if (DB::isNullAt(*nullmap, start + i)) + { + tiflash_compiler_builtin_memcpy(pos[i], &col_str_def_val.mem[0], sizeof(col_str_def_val.mem)); + pos[i] += sizeof(col_str_def_val.mem); + continue; + } + } + UInt32 str_size = sizeAt(start + i); + const void * src = &chars[offsetAt(start + i)]; auto sort_key = derived_collator->sortKey(reinterpret_cast(src), str_size - 1, *sort_key_container); // For terminating zero. @@ -777,6 +831,15 @@ void ColumnString::serializeToPosImpl( } else { + if constexpr (has_null) + { + if (pos[i] == nullptr) + continue; + } + + UInt32 str_size = sizeAt(start + i); + const void * src = &chars[offsetAt(start + i)]; + tiflash_compiler_builtin_memcpy(pos[i], &str_size, sizeof(UInt32)); pos[i] += sizeof(UInt32); inline_memcpy(pos[i], src, str_size); @@ -789,48 +852,48 @@ void ColumnString::serializeToPosForCmpColumnArray( PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const { - if (has_null) + if (nullmap != nullptr) { if likely (collator != nullptr) - serializeToPosForColumnArrayImplType( + serializeToPosForColumnArrayImplType( pos, start, length, array_offsets, collator, - sort_key_container); + sort_key_container, + nullmap); else - serializeToPosForColumnArrayImplType( + serializeToPosForColumnArrayImplType( pos, start, length, array_offsets, nullptr, - nullptr); + nullptr, + nullmap); } else { if likely (collator != nullptr) - serializeToPosForColumnArrayImplType( + serializeToPosForColumnArrayImplType( pos, start, length, array_offsets, collator, - sort_key_container); - else - serializeToPosForColumnArrayImplType( - pos, - start, - length, - array_offsets, - nullptr, + sort_key_container, nullptr); + else + serializeToPosForColumnArrayImplType< + /*has_null=*/false, + /*compare_semantics=*/false, + /*has_nullmap=*/false>(pos, start, length, array_offsets, nullptr, nullptr, nullptr); } } @@ -842,47 +905,51 @@ void ColumnString::serializeToPosForColumnArray( const IColumn::Offsets & array_offsets) const { if (has_null) - serializeToPosForColumnArrayImplType( + serializeToPosForColumnArrayImplType( pos, start, length, array_offsets, nullptr, + nullptr, nullptr); else - serializeToPosForColumnArrayImplType( + serializeToPosForColumnArrayImplType( pos, start, length, array_offsets, nullptr, + nullptr, nullptr); } -template +template void ColumnString::serializeToPosForColumnArrayImplType( PaddedPODArray & pos, size_t start, size_t length, const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container) const + String * sort_key_container, + const NullMap * nullmap) const { - if constexpr (has_collator) + if constexpr (compare_semantics) { RUNTIME_CHECK(collator && sort_key_container); -#define M(VAR_PREFIX, COLLATOR_NAME, IMPL_TYPE, COLLATOR_ID) \ - case (COLLATOR_ID): \ - { \ - serializeToPosForColumnArrayImpl( \ - pos, \ - start, \ - length, \ - array_offsets, \ - collator, \ - sort_key_container); \ - break; \ +#define M(VAR_PREFIX, COLLATOR_NAME, IMPL_TYPE, COLLATOR_ID) \ + case (COLLATOR_ID): \ + { \ + serializeToPosForColumnArrayImpl( \ + pos, \ + start, \ + length, \ + array_offsets, \ + collator, \ + sort_key_container, \ + nullmap); \ + break; \ } switch (collator->getCollatorId()) @@ -897,24 +964,26 @@ void ColumnString::serializeToPosForColumnArrayImplType( } else { - serializeToPosForColumnArrayImpl( + serializeToPosForColumnArrayImpl( pos, start, length, array_offsets, collator, - sort_key_container); + sort_key_container, + nullptr); } } -template +template void ColumnString::serializeToPosForColumnArrayImpl( PaddedPODArray & pos, size_t start, size_t length, const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container) const + String * sort_key_container, + const NullMap * nullmap) const { RUNTIME_CHECK_MSG(length <= pos.size(), "length({}) > size of pos({})", length, pos.size()); RUNTIME_CHECK_MSG( @@ -929,16 +998,19 @@ void ColumnString::serializeToPosForColumnArrayImpl( array_offsets.back(), size()); + static_assert(!(has_null && has_nullmap)); + RUNTIME_CHECK(!has_nullmap || (nullmap && nullmap->size() == array_offsets.size())); + /// countSerializeByteSizeForCmpColumnArray has already checked that the size of one element is not greater than UINT32_MAX - if constexpr (has_collator) + if constexpr (compare_semantics) { /// To avoid virtual function call of sortKey(). const auto * derived_collator = static_cast(collator); for (size_t i = 0; i < length; ++i) { - if constexpr (has_null) + if constexpr (has_nullmap) { - if (pos[i] == nullptr) + if (DB::isNullAt(*nullmap, start + i)) continue; } for (size_t j = array_offsets[start + i - 1]; j < array_offsets[start + i]; ++j) diff --git a/dbms/src/Columns/ColumnString.h b/dbms/src/Columns/ColumnString.h index da9eec75cb8..dd4f3c1f70f 100644 --- a/dbms/src/Columns/ColumnString.h +++ b/dbms/src/Columns/ColumnString.h @@ -115,37 +115,41 @@ class ColumnString final : public COWPtrHelper const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator) const; - template + template void serializeToPosImplType( PaddedPODArray & pos, size_t start, size_t length, const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container) const; - template + String * sort_key_container, + const NullMap * nullmap) const; + template void serializeToPosImpl( PaddedPODArray & pos, size_t start, size_t length, const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container) const; + String * sort_key_container, + const NullMap * nullmap) const; - template + template void serializeToPosForColumnArrayImplType( PaddedPODArray & pos, size_t start, size_t length, const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container) const; - template + String * sort_key_container, + const NullMap * nullmap) const; + template void serializeToPosForColumnArrayImpl( PaddedPODArray & pos, size_t start, size_t length, const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, - String * sort_key_container) const; + String * sort_key_container, + const NullMap * nullmap) const; void deserializeAndInsertFromPosImpl(PaddedPODArray & pos, bool use_nt_align_buffer); template @@ -313,7 +317,7 @@ class ColumnString final : public COWPtrHelper PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override; void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; @@ -322,7 +326,7 @@ class ColumnString final : public COWPtrHelper PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override; diff --git a/dbms/src/Columns/ColumnTuple.h b/dbms/src/Columns/ColumnTuple.h index 9e89ff7b937..d049fc74ebf 100644 --- a/dbms/src/Columns/ColumnTuple.h +++ b/dbms/src/Columns/ColumnTuple.h @@ -127,12 +127,12 @@ class ColumnTuple final : public COWPtrHelper PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override { for (const auto & column : columns) - column->serializeToPosForCmp(pos, start, length, has_null, collator, sort_key_container); + column->serializeToPosForCmp(pos, start, length, nullmap, collator, sort_key_container); } void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override { @@ -144,7 +144,7 @@ class ColumnTuple final : public COWPtrHelper PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr & collator, String * sort_key_container) const override @@ -154,7 +154,7 @@ class ColumnTuple final : public COWPtrHelper pos, start, length, - has_null, + nullmap, array_offsets, collator, sort_key_container); diff --git a/dbms/src/Columns/ColumnVector.cpp b/dbms/src/Columns/ColumnVector.cpp index 20a823124ed..be92ec62c5f 100644 --- a/dbms/src/Columns/ColumnVector.cpp +++ b/dbms/src/Columns/ColumnVector.cpp @@ -88,18 +88,41 @@ template void ColumnVector::serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const { if (has_null) - serializeToPosImpl(pos, start, length); + serializeToPosImpl(pos, start, length, nullptr); else - serializeToPosImpl(pos, start, length); + serializeToPosImpl(pos, start, length, nullptr); } template -template -void ColumnVector::serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length) const +void ColumnVector::serializeToPosForCmp( + PaddedPODArray & pos, + size_t start, + size_t length, + const NullMap * nullmap, + const TiDB::TiDBCollatorPtr &, + String *) const +{ + if (nullmap != nullptr) + serializeToPosImpl(pos, start, length, nullmap); + else + serializeToPosImpl(pos, start, length, nullptr); +} + +template +template +void ColumnVector::serializeToPosImpl( + PaddedPODArray & pos, + size_t start, + size_t length, + const NullMap * nullmap) const { RUNTIME_CHECK_MSG(length <= pos.size(), "length({}) > size of pos({})", length, pos.size()); RUNTIME_CHECK_MSG(start + length <= size(), "start({}) + length({}) > size of column({})", start, length, size()); + static_assert(!(has_null && has_nullmap)); + RUNTIME_CHECK(!has_nullmap || (nullmap && nullmap->size() == size())); + + static constexpr T def_val{}; for (size_t i = 0; i < length; ++i) { if constexpr (has_null) @@ -107,6 +130,15 @@ void ColumnVector::serializeToPosImpl(PaddedPODArray & pos, size_t st if (pos[i] == nullptr) continue; } + if constexpr (has_nullmap) + { + if (DB::isNullAt(*nullmap, start + i)) + { + tiflash_compiler_builtin_memcpy(pos[i], &def_val, sizeof(T)); + pos[i] += sizeof(T); + continue; + } + } tiflash_compiler_builtin_memcpy(pos[i], &data[start + i], sizeof(T)); pos[i] += sizeof(T); } @@ -121,18 +153,35 @@ void ColumnVector::serializeToPosForColumnArray( const IColumn::Offsets & array_offsets) const { if (has_null) - serializeToPosForColumnArrayImpl(pos, start, length, array_offsets); + serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullptr); else - serializeToPosForColumnArrayImpl(pos, start, length, array_offsets); + serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullptr); } template -template +void ColumnVector::serializeToPosForCmpColumnArray( + PaddedPODArray & pos, + size_t start, + size_t length, + const NullMap * nullmap, + const IColumn::Offsets & array_offsets, + const TiDB::TiDBCollatorPtr &, + String *) const +{ + if (nullmap != nullptr) + serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullmap); + else + serializeToPosForColumnArrayImpl(pos, start, length, array_offsets, nullptr); +} + +template +template void ColumnVector::serializeToPosForColumnArrayImpl( PaddedPODArray & pos, size_t start, size_t length, - const IColumn::Offsets & array_offsets) const + const IColumn::Offsets & array_offsets, + const NullMap * nullmap) const { RUNTIME_CHECK_MSG(length <= pos.size(), "length({}) > size of pos({})", length, pos.size()); RUNTIME_CHECK_MSG( @@ -147,6 +196,9 @@ void ColumnVector::serializeToPosForColumnArrayImpl( array_offsets.back(), size()); + static_assert(!(has_null && has_nullmap)); + RUNTIME_CHECK(!has_nullmap || (nullmap && nullmap->size() == array_offsets.size())); + for (size_t i = 0; i < length; ++i) { if constexpr (has_null) @@ -154,6 +206,11 @@ void ColumnVector::serializeToPosForColumnArrayImpl( if (pos[i] == nullptr) continue; } + if constexpr (has_nullmap) + { + if (DB::isNullAt(*nullmap, start + i)) + continue; + } size_t len = array_offsets[start + i] - array_offsets[start + i - 1]; if (len <= 4) { diff --git a/dbms/src/Columns/ColumnVector.h b/dbms/src/Columns/ColumnVector.h index 57275b6905a..fd49028bf70 100644 --- a/dbms/src/Columns/ColumnVector.h +++ b/dbms/src/Columns/ColumnVector.h @@ -198,15 +198,16 @@ class ColumnVector final : public COWPtrHelper - void serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length) const; + template + void serializeToPosImpl(PaddedPODArray & pos, size_t start, size_t length, const NullMap * nullmap) const; - template + template void serializeToPosForColumnArrayImpl( PaddedPODArray & pos, size_t start, size_t length, - const IColumn::Offsets & array_offsets) const; + const IColumn::Offsets & array_offsets, + const NullMap * nullmap) const; public: bool isNumeric() const override { return is_arithmetic_v; } @@ -348,25 +349,19 @@ class ColumnVector final : public COWPtrHelper & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const TiDB::TiDBCollatorPtr &, - String *) const override - { - serializeToPos(pos, start, length, has_null); - } + String *) const override; void serializeToPos(PaddedPODArray & pos, size_t start, size_t length, bool has_null) const override; void serializeToPosForCmpColumnArray( PaddedPODArray & pos, size_t start, size_t length, - bool has_null, + const NullMap * nullmap, const IColumn::Offsets & array_offsets, const TiDB::TiDBCollatorPtr &, - String *) const override - { - serializeToPosForColumnArray(pos, start, length, has_null, array_offsets); - } + String *) const override; void serializeToPosForColumnArray( PaddedPODArray & pos, size_t start, diff --git a/dbms/src/Columns/IColumn.h b/dbms/src/Columns/IColumn.h index 4321ed0755a..309c0c5dcac 100644 --- a/dbms/src/Columns/IColumn.h +++ b/dbms/src/Columns/IColumn.h @@ -38,6 +38,14 @@ extern const int SIZES_OF_COLUMNS_DOESNT_MATCH; class Arena; class ColumnGathererStream; +using NullMap = PaddedPODArray; +using ConstNullMapPtr = const NullMap *; + +inline bool isNullAt(const NullMap & nullmap, size_t n) +{ + return nullmap[n] != 0; +} + /// Declares interface to store columns in memory. class IColumn : public COWPtr { @@ -262,39 +270,43 @@ class IColumn : public COWPtr /// Note: /// 1. The pos.size() must be greater than or equal to length. /// 2. If has_null is true, then the pos[i] could be nullptr, which means the i-th element does not need to be serialized. - virtual void serializeToPosForCmp( + virtual void serializeToPos( PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */, - const TiDB::TiDBCollatorPtr & /* collator */, - String * /* sort_key_container */) const + bool /* has_null */) const = 0; - virtual void serializeToPos( + /// Similar to serializeToPos, but there are two changes to make sure compare semantics is kept: + /// 1. For ColumnString with collator, this method first decode collator and then serialize to pos. + /// 2. For ColumnNullable(ColumnXXX), a default value of the nested column will be serialized if this row is null. + virtual void serializeToPosForCmp( PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */) const + const NullMap * /*nullmap*/, + const TiDB::TiDBCollatorPtr & /* collator */, + String * /* sort_key_container */) const = 0; /// Serialize data of column from start to start + length into pointer of pos and forward each pos[i] to the end of /// serialized data. /// Only called by ColumnArray. - virtual void serializeToPosForCmpColumnArray( + virtual void serializeToPosForColumnArray( PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, bool /* has_null */, - const Offsets & /* array_offsets */, - const TiDB::TiDBCollatorPtr & /* collator */, - String * /* sort_key_container */) const + const Offsets & /* array_offsets */) const = 0; - virtual void serializeToPosForColumnArray( + /// Similary to serializeToPosForCmp, but only called by ColumnArray. + virtual void serializeToPosForCmpColumnArray( PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */, - const Offsets & /* array_offsets */) const + const NullMap * /*nullmap*/, + const Offsets & /* array_offsets */, + const TiDB::TiDBCollatorPtr & /* collator */, + String * /* sort_key_container */) const = 0; /// Deserialize and insert data from pos and forward each pos[i] to the end of serialized data. diff --git a/dbms/src/Columns/IColumnDummy.h b/dbms/src/Columns/IColumnDummy.h index 3882cb4080f..250f4bedc14 100644 --- a/dbms/src/Columns/IColumnDummy.h +++ b/dbms/src/Columns/IColumnDummy.h @@ -123,7 +123,7 @@ class IColumnDummy : public IColumn PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */, + const NullMap * /* nullmap */, const TiDB::TiDBCollatorPtr & /* collator */, String * /* sort_key_container */) const override { @@ -142,7 +142,7 @@ class IColumnDummy : public IColumn PaddedPODArray & /* pos */, size_t /* start */, size_t /* length */, - bool /* has_null */, + const NullMap * /* nullmap */, const IColumn::Offsets & /* array_offsets */, const TiDB::TiDBCollatorPtr & /* collator */, String * /* sort_key_container */) const override diff --git a/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp b/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp index 6cf5742aba3..26eabc0b616 100644 --- a/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp +++ b/dbms/src/Columns/tests/gtest_column_serialize_deserialize.cpp @@ -31,14 +31,14 @@ class TestColumnSerializeDeserialize : public ::testing::Test static void testCountSerializeByteSize( const ColumnPtr & column_ptr, const PaddedPODArray & result_byte_size, - bool for_compare = false, + bool compare_semantics = false, const TiDB::TiDBCollatorPtr & collator = nullptr) { PaddedPODArray byte_size; byte_size.resize_fill_zero(column_ptr->size()); for (size_t i = 0; i < column_ptr->size(); ++i) byte_size[i] = i; - if (!for_compare) + if (!compare_semantics) column_ptr->countSerializeByteSize(byte_size); else column_ptr->countSerializeByteSizeForCmp(byte_size, collator); @@ -51,7 +51,7 @@ class TestColumnSerializeDeserialize : public ::testing::Test const ColumnPtr & column_ptr, const ColumnPtr & offsets, const PaddedPODArray & result_byte_size, - bool for_compare = false, + bool compare_semantics = false, const TiDB::TiDBCollatorPtr & collator = nullptr) { auto column_array = ColumnArray::create(column_ptr->cloneFullColumn(), offsets->cloneFullColumn()); @@ -59,7 +59,7 @@ class TestColumnSerializeDeserialize : public ::testing::Test byte_size.resize_fill_zero(column_array->size()); for (size_t i = 0; i < column_array->size(); ++i) byte_size[i] = i; - if (!for_compare) + if (!compare_semantics) column_array->countSerializeByteSize(byte_size); else column_array->countSerializeByteSizeForCmp(byte_size, collator); @@ -156,37 +156,37 @@ class TestColumnSerializeDeserialize : public ::testing::Test if (result_col_ptr->isNullAt(i)) continue; auto res = result_col_ptr->getDataAt(i); - auto sort_key = collator->sortKey(res.data, res.size, sort_key_container); - ASSERT_TRUE(sort_key == new_col_ptr->getDataAt(i)); + auto res_sort_key = collator->sortKey(res.data, res.size, sort_key_container); + auto act = new_col_ptr->getDataAt(i); + ASSERT_TRUE(res_sort_key == act); } } } static void testSerializeAndDeserialize( const ColumnPtr & column_ptr, - bool for_compare = false, + bool compare_semantics = false, const TiDB::TiDBCollatorPtr & collator = nullptr, String * sort_key_container = nullptr) { - doTestSerializeAndDeserialize(column_ptr, false, for_compare, collator, sort_key_container); - doTestSerializeAndDeserialize2(column_ptr, false, for_compare, collator, sort_key_container); - doTestSerializeAndDeserialize(column_ptr, true, for_compare, collator, sort_key_container); - doTestSerializeAndDeserialize2(column_ptr, true, for_compare, collator, sort_key_container); + if (compare_semantics) + { + doTestSerializeAndDeserializeForCmp(column_ptr, compare_semantics, collator, sort_key_container); + } + else + { + doTestSerializeAndDeserialize(column_ptr, false); + doTestSerializeAndDeserialize2(column_ptr, false); + doTestSerializeAndDeserialize(column_ptr, true); + doTestSerializeAndDeserialize2(column_ptr, true); + } } - static void doTestSerializeAndDeserialize( - const ColumnPtr & column_ptr, - bool use_nt_align_buffer, - bool for_compare = false, - const TiDB::TiDBCollatorPtr & collator = nullptr, - String * sort_key_container = nullptr) + static void doTestSerializeAndDeserialize(const ColumnPtr & column_ptr, bool use_nt_align_buffer) { PaddedPODArray byte_size; byte_size.resize_fill_zero(column_ptr->size()); - if (!for_compare) - column_ptr->countSerializeByteSize(byte_size); - else - column_ptr->countSerializeByteSizeForCmp(byte_size, collator); + column_ptr->countSerializeByteSize(byte_size); size_t total_size = 0; for (const auto size : byte_size) total_size += size; @@ -201,18 +201,12 @@ class TestColumnSerializeDeserialize : public ::testing::Test PaddedPODArray ori_pos; for (auto * ptr : pos) ori_pos.push_back(ptr); - if (!for_compare) - column_ptr->serializeToPos(pos, 0, byte_size.size() / 2, false); - else - column_ptr->serializeToPosForCmp(pos, 0, byte_size.size() / 2, false, collator, sort_key_container); + column_ptr->serializeToPos(pos, 0, byte_size.size() / 2, false); auto new_col_ptr = column_ptr->cloneEmpty(); if (use_nt_align_buffer) new_col_ptr->reserveAlign(byte_size.size(), FULL_VECTOR_SIZE_AVX2); - if (!for_compare) - new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); - else - new_col_ptr->deserializeForCmpAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); current_size = 0; pos.clear(); @@ -225,23 +219,11 @@ class TestColumnSerializeDeserialize : public ::testing::Test pos.push_back(nullptr); for (auto * ptr : pos) ori_pos.push_back(ptr); - if (!for_compare) - column_ptr->serializeToPos(pos, byte_size.size() / 2, byte_size.size() - byte_size.size() / 2, true); - else - column_ptr->serializeToPosForCmp( - pos, - byte_size.size() / 2, - byte_size.size() - byte_size.size() / 2, - true, - collator, - sort_key_container); + column_ptr->serializeToPos(pos, byte_size.size() / 2, byte_size.size() - byte_size.size() / 2, true); pos.resize(pos.size() - 1); ori_pos.resize(ori_pos.size() - 1); - if (!for_compare) - new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); - else - new_col_ptr->deserializeForCmpAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); current_size = 0; pos.clear(); @@ -253,15 +235,9 @@ class TestColumnSerializeDeserialize : public ::testing::Test } for (auto * ptr : pos) ori_pos.push_back(ptr); - if (!for_compare) - column_ptr->serializeToPos(pos, 0, byte_size.size(), true); - else - column_ptr->serializeToPosForCmp(pos, 0, byte_size.size(), true, collator, sort_key_container); + column_ptr->serializeToPos(pos, 0, byte_size.size(), true); - if (!for_compare) - new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); - else - new_col_ptr->deserializeForCmpAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); if (use_nt_align_buffer) new_col_ptr->flushNTAlignBuffer(); @@ -270,27 +246,16 @@ class TestColumnSerializeDeserialize : public ::testing::Test for (size_t i = 0; i < column_ptr->size(); ++i) result_col_ptr->insertFrom(*column_ptr, i); - if (collator != nullptr) - checkForColumnWithCollator(std::move(result_col_ptr), std::move(new_col_ptr), collator); - else - ASSERT_COLUMN_EQ(std::move(result_col_ptr), std::move(new_col_ptr)); + ASSERT_COLUMN_EQ(std::move(result_col_ptr), std::move(new_col_ptr)); } - static void doTestSerializeAndDeserialize2( - const ColumnPtr & column_ptr, - bool use_nt_align_buffer, - bool for_compare = false, - const TiDB::TiDBCollatorPtr & collator = nullptr, - String * sort_key_container = nullptr) + static void doTestSerializeAndDeserialize2(const ColumnPtr & column_ptr, bool use_nt_align_buffer) { if (column_ptr->size() < 2) return; PaddedPODArray byte_size; byte_size.resize_fill_zero(column_ptr->size()); - if (!for_compare) - column_ptr->countSerializeByteSize(byte_size); - else - column_ptr->countSerializeByteSizeForCmp(byte_size, collator); + column_ptr->countSerializeByteSize(byte_size); size_t total_size = 0; for (const auto size : byte_size) total_size += size; @@ -306,20 +271,14 @@ class TestColumnSerializeDeserialize : public ::testing::Test pos.push_back(nullptr); for (auto * ptr : pos) ori_pos.push_back(ptr); - if (!for_compare) - column_ptr->serializeToPos(pos, 0, byte_size.size() / 2, true); - else - column_ptr->serializeToPosForCmp(pos, 0, byte_size.size() / 2, true, collator, sort_key_container); + column_ptr->serializeToPos(pos, 0, byte_size.size() / 2, true); pos.resize(pos.size() - 1); ori_pos.resize(ori_pos.size() - 1); auto new_col_ptr = column_ptr->cloneEmpty(); if (use_nt_align_buffer) new_col_ptr->reserveAlign(byte_size.size(), FULL_VECTOR_SIZE_AVX2); - if (!for_compare) - new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); - else - new_col_ptr->deserializeForCmpAndInsertFromPos(ori_pos, use_nt_align_buffer); + new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); current_size = 0; pos.clear(); @@ -331,21 +290,8 @@ class TestColumnSerializeDeserialize : public ::testing::Test } for (auto * ptr : pos) ori_pos.push_back(ptr); - if (!for_compare) - column_ptr - ->serializeToPos(pos, byte_size.size() / 2 - 1, byte_size.size() - byte_size.size() / 2 + 1, false); - else - column_ptr->serializeToPosForCmp( - pos, - byte_size.size() / 2 - 1, - byte_size.size() - byte_size.size() / 2 + 1, - false, - collator, - sort_key_container); - if (!for_compare) - new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); - else - new_col_ptr->deserializeForCmpAndInsertFromPos(ori_pos, use_nt_align_buffer); + column_ptr->serializeToPos(pos, byte_size.size() / 2 - 1, byte_size.size() - byte_size.size() / 2 + 1, false); + new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); current_size = 0; pos.clear(); @@ -357,15 +303,83 @@ class TestColumnSerializeDeserialize : public ::testing::Test } for (auto * ptr : pos) ori_pos.push_back(ptr); - if (!for_compare) - column_ptr->serializeToPos(pos, 0, byte_size.size(), true); - else - column_ptr->serializeToPosForCmp(pos, 0, byte_size.size(), true, collator, sort_key_container); + column_ptr->serializeToPos(pos, 0, byte_size.size(), true); + + new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); + if (use_nt_align_buffer) + new_col_ptr->flushNTAlignBuffer(); + + auto result_col_ptr = column_ptr->cloneFullColumn(); + for (size_t i = 0; i < column_ptr->size(); ++i) + result_col_ptr->insertFrom(*column_ptr, i); + + ASSERT_COLUMN_EQ(std::move(result_col_ptr), std::move(new_col_ptr)); + } + + static void doTestSerializeAndDeserializeForCmp( + const ColumnPtr & column_ptr, + bool use_nt_align_buffer, + const TiDB::TiDBCollatorPtr & collator = nullptr, + String * sort_key_container = nullptr) + { + PaddedPODArray byte_size; + byte_size.resize_fill_zero(column_ptr->size()); + column_ptr->countSerializeByteSizeForCmp(byte_size, collator); + size_t total_size = 0; + for (const auto size : byte_size) + total_size += size; + PaddedPODArray memory(total_size); + PaddedPODArray pos; + size_t current_size = 0; + for (size_t i = 0; i < byte_size.size() / 2; ++i) + { + pos.push_back(memory.data() + current_size); + current_size += byte_size[i]; + } + PaddedPODArray ori_pos; + for (auto * ptr : pos) + ori_pos.push_back(ptr); + column_ptr->serializeToPosForCmp(pos, 0, byte_size.size() / 2, nullptr, collator, sort_key_container); + + auto new_col_ptr = column_ptr->cloneEmpty(); + if (use_nt_align_buffer) + new_col_ptr->reserveAlign(byte_size.size(), FULL_VECTOR_SIZE_AVX2); + new_col_ptr->deserializeForCmpAndInsertFromPos(ori_pos, use_nt_align_buffer); + + current_size = 0; + pos.clear(); + ori_pos.clear(); + for (size_t i = byte_size.size() / 2; i < byte_size.size(); ++i) + { + pos.push_back(memory.data() + current_size); + current_size += byte_size[i]; + } + for (auto * ptr : pos) + ori_pos.push_back(ptr); + column_ptr->serializeToPosForCmp( + pos, + byte_size.size() / 2, + byte_size.size() - byte_size.size() / 2, + nullptr, + collator, + sort_key_container); + + new_col_ptr->deserializeForCmpAndInsertFromPos(ori_pos, use_nt_align_buffer); + + current_size = 0; + pos.clear(); + ori_pos.clear(); + for (const auto size : byte_size) + { + pos.push_back(memory.data() + current_size); + current_size += size; + } + for (auto * ptr : pos) + ori_pos.push_back(ptr); + + column_ptr->serializeToPosForCmp(pos, 0, byte_size.size(), nullptr, collator, sort_key_container); + new_col_ptr->deserializeForCmpAndInsertFromPos(ori_pos, use_nt_align_buffer); - if (!for_compare) - new_col_ptr->deserializeAndInsertFromPos(ori_pos, use_nt_align_buffer); - else - new_col_ptr->deserializeForCmpAndInsertFromPos(ori_pos, use_nt_align_buffer); if (use_nt_align_buffer) new_col_ptr->flushNTAlignBuffer(); @@ -590,6 +604,28 @@ try testCountSerializeByteSize(col_nullable_array_vec, {1 + 4 + 4, 1 + 4 + 8, 1 + 4 + 12}); testSerializeAndDeserialize(col_nullable_array_vec); testSerializeAndDeserialize(col_nullable_array_vec, true, nullptr, nullptr); + + // ColumnNullable(ColumnArray(ColumnString)) + auto col_string = createColumn({"123", "2", "34", "456", "5678", "6"}).column; + auto col_array_string = ColumnArray::create(col_vector, col_offsets); + auto col_nullable_array_string = ColumnNullable::create(col_array_vec, createColumn({1, 0, 1}).column); + testSerializeAndDeserialize(col_nullable_array_vec); + testSerializeAndDeserialize(col_nullable_array_vec, true, nullptr, nullptr); + + // ColumnArray(ColumnNullable(ColumnVector)) not support. + + // Nested ColumnNullable like ColumnNullable(ColumnArray(ColumnNullable(ColumnString))) not support. + // auto col_offsets_1 = createColumn({1, 3, 6}).column; + // auto col_array_string = ColumnArray::create(col_nullable_string, col_offsets_1); + // auto col_nullable_array_string = ColumnNullable::create(col_array_string, createColumn({0, 1, 0}).column); + // testCountSerializeByteSize(col_nullable_array_string, + // {1 + 4 + 1 + 4 + 4, + // 1 + 4 + 2 + 8 + 4, + // 1 + 4 + 3 + 12 + 7}, true, nullptr); + // testSerializeAndDeserialize(col_nullable_array_string); + // testSerializeAndDeserialize(col_nullable_array_string, true, collator_utf8_bin, &sort_key_container); + // testSerializeAndDeserialize(col_nullable_array_string, true, collator_utf8_general_ci, &sort_key_container); + // testSerializeAndDeserialize(col_nullable_array_string, true, collator_utf8_unicode_ci, &sort_key_container); } CATCH @@ -625,9 +661,10 @@ try auto col_array_nullable_string = ColumnArray::create(col_nullable_string, col_offsets); testCountSerializeByteSize(col_array_nullable_string, {4 + 5 + 4, 4 + 10 + 4, 4 + 15 + 7}); testSerializeAndDeserialize(col_array_nullable_string); - testSerializeAndDeserialize(col_array_nullable_string, true, collator_utf8_bin, &sort_key_container); - testSerializeAndDeserialize(col_array_nullable_string, true, collator_utf8_general_ci, &sort_key_container); - testSerializeAndDeserialize(col_array_nullable_string, true, collator_utf8_unicode_ci, &sort_key_container); + // compare semantics not support ColumnArray(ColumnNullable(ColumnString)). + // testSerializeAndDeserialize(col_array_nullable_string, true, collator_utf8_bin, &sort_key_container); + // testSerializeAndDeserialize(col_array_nullable_string, true, collator_utf8_general_ci, &sort_key_container); + // testSerializeAndDeserialize(col_array_nullable_string, true, collator_utf8_unicode_ci, &sort_key_container); // ColumnArray(ColumnDecimal) auto col_decimal_256 = createColumn( @@ -778,6 +815,7 @@ try true, collator_utf8_unicode_ci); + // ColumnString String sort_key_container; testSerializeAndDeserialize(col_string, true, collator_utf8_bin, &sort_key_container); testSerializeAndDeserialize(col_string, true, collator_utf8_general_ci, &sort_key_container); diff --git a/dbms/src/Common/ColumnsHashing.h b/dbms/src/Common/ColumnsHashing.h index a03136bbed8..b3d18ad9319 100644 --- a/dbms/src/Common/ColumnsHashing.h +++ b/dbms/src/Common/ColumnsHashing.h @@ -48,6 +48,7 @@ struct HashMethodOneNumber using Self = HashMethodOneNumber; using Base = columns_hashing_impl::HashMethodBase; using KeyHolderType = FieldType; + using BatchKeyHolderType = KeyHolderType; static constexpr bool is_serialized_key = false; @@ -85,17 +86,113 @@ struct HashMethodOneNumber } const FieldType * getKeyData() const { return vec; } + + bool batchGetKeyHolder() const { return false; } }; +class KeyStringBatchHandlerBase +{ +private: + size_t batch_row_idx = 0; + std::vector sort_key_containers{}; + std::vector batch_rows{}; + + template + size_t prepareNextBatchType( + const UInt8 * chars, + const IColumn::Offsets & offsets, + size_t cur_batch_size, + const TiDB::TiDBCollatorPtr & collator) + { + if (cur_batch_size <= 0) + return 0; + + const auto * derived_collator = static_cast(collator); + for (size_t i = 0; i < cur_batch_size; ++i) + { + const auto row = batch_row_idx + i; + const auto last_offset = offsets[row - 1]; + // Remove last zero byte. + StringRef key(chars + last_offset, offsets[row] - offsets[row - 1] - 1); + if constexpr (has_collator) + key = derived_collator->sortKey(key.data, key.size, sort_key_containers[i]); + + batch_rows[i] = key; + } + return 0; + } + + void santityCheck() const + { + // Make sure init() has called. + assert(sort_key_containers.size() == batch_rows.size() && !sort_key_containers.empty()); + } + +protected: + bool inited() const { return !sort_key_containers.empty(); } + + void init(size_t start_row, size_t max_batch_size) + { + RUNTIME_CHECK(max_batch_size >= 256); + batch_row_idx = start_row; + sort_key_containers.resize(max_batch_size); + batch_rows.reserve(max_batch_size); + } + + size_t prepareNextBatch( + const UInt8 * chars, + const IColumn::Offsets & offsets, + size_t cur_batch_size, + const TiDB::TiDBCollatorPtr & collator) + { + batch_rows.resize(cur_batch_size); + + if likely (collator) + { +#define M(VAR_PREFIX, COLLATOR_NAME, IMPL_TYPE, COLLATOR_ID) \ + case (COLLATOR_ID): \ + { \ + return prepareNextBatchType(chars, offsets, cur_batch_size, collator); \ + break; \ + } + + switch (collator->getCollatorId()) + { + APPLY_FOR_COLLATOR_TYPES(M) + default: + { + throw Exception(fmt::format("unexpected collator: {}", collator->getCollatorId())); + } + }; +#undef M + } + else + { + return prepareNextBatchType(chars, offsets, cur_batch_size, collator); + } + } + + // NOTE: i is the index of mini batch, it's not the row index of Column. + ALWAYS_INLINE inline ArenaKeyHolder getKeyHolderBatch(size_t i, Arena * pool) const + { + santityCheck(); + assert(i < batch_rows.size()); + return ArenaKeyHolder{batch_rows[i], pool}; + } +}; /// For the case when there is one string key. template struct HashMethodString : public columns_hashing_impl::HashMethodBase, Value, Mapped, use_cache> + , KeyStringBatchHandlerBase { using Self = HashMethodString; using Base = columns_hashing_impl::HashMethodBase; using KeyHolderType = ArenaKeyHolder; + using BatchKeyHolderType = KeyHolderType; + + using BatchHandlerBase = KeyStringBatchHandlerBase; static constexpr bool is_serialized_key = false; @@ -116,11 +213,27 @@ struct HashMethodString collator = collators[0]; } + bool batchGetKeyHolder() const { return BatchHandlerBase::inited(); } + + void initBatchHandler(size_t start_row, size_t max_batch_size) + { + assert(!BatchHandlerBase::inited()); + BatchHandlerBase::init(start_row, max_batch_size); + } + + size_t prepareNextBatch(Arena *, size_t cur_batch_size) + { + assert(BatchHandlerBase::inited()); + return BatchHandlerBase::prepareNextBatch(chars, *offsets, cur_batch_size, collator); + } + ALWAYS_INLINE inline KeyHolderType getKeyHolder( ssize_t row, [[maybe_unused]] Arena * pool, std::vector & sort_key_containers) const { + assert(!BatchHandlerBase::inited()); + auto last_offset = row == 0 ? 0 : offsets[row - 1]; // Remove last zero byte. StringRef key(chars + last_offset, offsets[row] - last_offset - 1); @@ -141,6 +254,7 @@ struct HashMethodStringBin using Self = HashMethodStringBin; using Base = columns_hashing_impl::HashMethodBase; using KeyHolderType = ArenaKeyHolder; + using BatchKeyHolderType = KeyHolderType; static constexpr bool is_serialized_key = false; @@ -155,6 +269,8 @@ struct HashMethodStringBin chars = column_string.getChars().data(); } + bool batchGetKeyHolder() const { return false; } + ALWAYS_INLINE inline KeyHolderType getKeyHolder(ssize_t row, Arena * pool, std::vector &) const { auto last_offset = row == 0 ? 0 : offsets[row - 1]; @@ -341,6 +457,7 @@ struct HashMethodFastPathTwoKeysSerialized using Self = HashMethodFastPathTwoKeysSerialized; using Base = columns_hashing_impl::HashMethodBase; using KeyHolderType = SerializedKeyHolder; + using BatchKeyHolderType = KeyHolderType; static constexpr bool is_serialized_key = true; @@ -352,6 +469,8 @@ struct HashMethodFastPathTwoKeysSerialized , key_2_desc(key_columns[1]) {} + bool batchGetKeyHolder() const { return false; } + ALWAYS_INLINE inline KeyHolderType getKeyHolder(ssize_t row, Arena * pool, std::vector &) const { StringRef key1; @@ -378,6 +497,7 @@ struct HashMethodFixedString using Self = HashMethodFixedString; using Base = columns_hashing_impl::HashMethodBase; using KeyHolderType = ArenaKeyHolder; + using BatchKeyHolderType = KeyHolderType; static constexpr bool is_serialized_key = false; @@ -398,6 +518,8 @@ struct HashMethodFixedString collator = collators[0]; } + bool batchGetKeyHolder() const { return false; } + ALWAYS_INLINE inline KeyHolderType getKeyHolder(size_t row, Arena * pool, std::vector & sort_key_containers) const { @@ -426,6 +548,7 @@ struct HashMethodKeysFixed using BaseHashed = columns_hashing_impl::HashMethodBase; using Base = columns_hashing_impl::BaseStateKeysFixed; using KeyHolderType = Key; + using BatchKeyHolderType = KeyHolderType; static constexpr bool is_serialized_key = false; static constexpr bool has_nullable_keys = has_nullable_keys_; @@ -516,6 +639,8 @@ struct HashMethodKeysFixed #endif } + bool batchGetKeyHolder() const { return false; } + ALWAYS_INLINE inline KeyHolderType getKeyHolder(size_t row, Arena *, std::vector &) const { if constexpr (has_nullable_keys) @@ -570,6 +695,89 @@ struct HashMethodKeysFixed } }; +class KeySerializedBatchHandlerBase +{ +private: + size_t batch_row_idx = 0; + String sort_key_container{}; + PaddedPODArray byte_size{}; + PaddedPODArray pos{}; + PaddedPODArray ori_pos{}; + PaddedPODArray real_byte_size{}; + + ALWAYS_INLINE inline void santityCheck() const + { + assert(ori_pos.size() == pos.size() && real_byte_size.size() == pos.size()); + } + + ALWAYS_INLINE inline void resize(size_t batch_size) + { + pos.resize(batch_size); + ori_pos.resize(batch_size); + real_byte_size.resize(batch_size); + } + +protected: + bool inited() const { return !byte_size.empty(); } + + void init(size_t start_row, const ColumnRawPtrs & key_columns, const TiDB::TiDBCollators & collators) + { + batch_row_idx = start_row; + byte_size.resize_fill_zero(key_columns[0]->size()); + RUNTIME_CHECK(!byte_size.empty()); + for (size_t i = 0; i < key_columns.size(); ++i) + key_columns[i]->countSerializeByteSizeForCmp(byte_size, collators.empty() ? nullptr : collators[i]); + } + + size_t prepareNextBatch( + const ColumnRawPtrs & key_columns, + Arena * pool, + size_t cur_batch_size, + const TiDB::TiDBCollators & collators) + { + santityCheck(); + resize(cur_batch_size); + + if unlikely (cur_batch_size <= 0) + return 0; + + size_t mem_size = 0; + for (size_t i = batch_row_idx; i < batch_row_idx + cur_batch_size; ++i) + mem_size += byte_size[i]; + + auto * ptr = static_cast(pool->alignedAlloc(mem_size, 16)); + for (size_t i = 0; i < cur_batch_size; ++i) + { + pos[i] = ptr; + ori_pos[i] = ptr; + ptr += byte_size[i + batch_row_idx]; + } + + for (size_t i = 0; i < key_columns.size(); ++i) + key_columns[i]->serializeToPosForCmp( + pos, + batch_row_idx, + cur_batch_size, + nullptr, + collators.empty() ? nullptr : collators[i], + &sort_key_container); + + for (size_t i = 0; i < cur_batch_size; ++i) + real_byte_size[i] = pos[i] - ori_pos[i]; + + batch_row_idx += cur_batch_size; + + return mem_size; + } + + // NOTE: i is the index of mini batch, it's not the row index of Column. + ALWAYS_INLINE inline ArenaKeyHolder getKeyHolderBatch(size_t i, Arena * pool) const + { + santityCheck(); + assert(i < ori_pos.size()); + return ArenaKeyHolder{StringRef{ori_pos[i], real_byte_size[i]}, pool}; + } +}; /** Hash by concatenating serialized key values. * The serialized value differs in that it uniquely allows to deserialize it, having only the position with which it starts. @@ -579,12 +787,14 @@ struct HashMethodKeysFixed template struct HashMethodSerialized : public columns_hashing_impl::HashMethodBase, Value, Mapped, false> + , KeySerializedBatchHandlerBase { using Self = HashMethodSerialized; using Base = columns_hashing_impl::HashMethodBase; - using KeyHolderType = SerializedKeyHolder; - + using BatchHandlerBase = KeySerializedBatchHandlerBase; static constexpr bool is_serialized_key = true; + using KeyHolderType = SerializedKeyHolder; + using BatchKeyHolderType = ArenaKeyHolder; ColumnRawPtrs key_columns; size_t keys_size; @@ -599,9 +809,24 @@ struct HashMethodSerialized , collators(collators_) {} + bool batchGetKeyHolder() const { return BatchHandlerBase::inited(); } + + void initBatchHandler(size_t start_row, size_t /* max_batch_size */) + { + assert(!BatchHandlerBase::inited()); + BatchHandlerBase::init(start_row, key_columns, collators); + } + + size_t prepareNextBatch(Arena * pool, size_t cur_batch_size) + { + assert(BatchHandlerBase::inited()); + return BatchHandlerBase::prepareNextBatch(key_columns, pool, cur_batch_size, collators); + } + ALWAYS_INLINE inline KeyHolderType getKeyHolder(size_t row, Arena * pool, std::vector & sort_key_containers) const { + assert(!BatchHandlerBase::inited()); return SerializedKeyHolder{ serializeKeysToPoolContiguous(row, keys_size, key_columns, collators, sort_key_containers, *pool), pool}; @@ -620,6 +845,7 @@ struct HashMethodHashed using Self = HashMethodHashed; using Base = columns_hashing_impl::HashMethodBase; using KeyHolderType = Key; + using BatchKeyHolderType = KeyHolderType; static constexpr bool is_serialized_key = false; @@ -631,6 +857,8 @@ struct HashMethodHashed , collators(collators_) {} + bool batchGetKeyHolder() const { return false; } + ALWAYS_INLINE inline KeyHolderType getKeyHolder(size_t row, Arena *, std::vector & sort_key_containers) const { diff --git a/dbms/src/Interpreters/Aggregator.cpp b/dbms/src/Interpreters/Aggregator.cpp index 9af4ed8d466..0c68d3fac2b 100644 --- a/dbms/src/Interpreters/Aggregator.cpp +++ b/dbms/src/Interpreters/Aggregator.cpp @@ -655,6 +655,13 @@ void Aggregator::createAggregateStates(AggregateDataPtr & aggregate_data) const } } +template +concept HasBatchGetKeyHolderMemberFunc = requires { + // todo also require initBatchHandler() + { + std::declval().getKeyHolderBatch(std::declval(), std::declval()) + } -> std::same_as; +}; /** It's interesting - if you remove `noinline`, then gcc for some reason will inline this function, and the performance decreases (~ 10%). * (Probably because after the inline of this function, more internal functions no longer be inlined.) @@ -663,10 +670,11 @@ void Aggregator::createAggregateStates(AggregateDataPtr & aggregate_data) const template void NO_INLINE Aggregator::executeImpl( Method & method, - Arena * aggregates_pool, + AggregatedDataVariants & result, AggProcessInfo & agg_process_info, TiDB::TiDBCollators & collators) const { + auto * aggregates_pool = result.aggregates_pool; typename Method::State state(agg_process_info.key_columns, key_sizes, collators); // 2MB as prefetch threshold, because normally server L2 cache is 1MB. @@ -678,21 +686,70 @@ void NO_INLINE Aggregator::executeImpl( const bool disable_prefetch = (method.data.getBufferSizeInBytes() < prefetch_threshold); #endif - if constexpr (Method::State::is_serialized_key) + // For key_serialized, memory allocation and key serialization will be batch-wise. + // For key_string, collation decode will be batch-wise. + static constexpr bool batch_get_key_holder = HasBatchGetKeyHolderMemberFunc; + if constexpr (batch_get_key_holder) { - executeImplBatch(method, state, aggregates_pool, agg_process_info); + state.initBatchHandler(agg_process_info.start_row); + result.batch_get_key_holder = true; } - else if constexpr (Method::Data::is_string_hash_map) + // todo remove batchGetKeyHolder? + + if constexpr (Method::Data::is_string_hash_map) { // StringHashMap doesn't support prefetch. - executeImplBatch(method, state, aggregates_pool, agg_process_info); + if constexpr (batch_get_key_holder) + executeImplBatch< + collect_hit_rate, + only_lookup, + /*enable_prefetch=*/false, + /*batch_get_key_holder=*/true, + typename Method::State::BatchKeyHolderType>(method, state, aggregates_pool, agg_process_info); + else + executeImplBatch< + collect_hit_rate, + only_lookup, + /*enable_prefetch=*/false, + /*batch_get_key_holder=*/false, + typename Method::State::KeyHolderType>(method, state, aggregates_pool, agg_process_info); } else { if (disable_prefetch) - executeImplBatch(method, state, aggregates_pool, agg_process_info); + { + if constexpr (batch_get_key_holder) + executeImplBatch< + collect_hit_rate, + only_lookup, + /*enable_prefetch=*/false, + /*batch_get_key_holder=*/true, + typename Method::State::BatchKeyHolderType>(method, state, aggregates_pool, agg_process_info); + else + executeImplBatch< + collect_hit_rate, + only_lookup, + /*enable_prefetch=*/false, + /*batch_get_key_holder=*/false, + typename Method::State::KeyHolderType>(method, state, aggregates_pool, agg_process_info); + } else - executeImplBatch(method, state, aggregates_pool, agg_process_info); + { + if constexpr (batch_get_key_holder) + executeImplBatch< + collect_hit_rate, + only_lookup, + /*enable_prefetch=*/true, + /*batch_get_key_holder=*/true, + typename Method::State::BatchKeyHolderType>(method, state, aggregates_pool, agg_process_info); + else + executeImplBatch< + collect_hit_rate, + only_lookup, + /*enable_prefetch=*/true, + /*batch_get_key_holder=*/false, + typename Method::State::KeyHolderType>(method, state, aggregates_pool, agg_process_info); + } } } @@ -737,30 +794,40 @@ std::optional::Res } } -template -ALWAYS_INLINE inline void prepareBatch( +template +ALWAYS_INLINE inline void setupHashVals( size_t row_idx, - size_t end_row, + size_t batch_size, std::vector & hashvals, - std::vector & key_holders, + std::vector & key_holders, Arena * aggregates_pool, std::vector & sort_key_containers, Method & method, typename Method::State & state) { - assert(hashvals.size() == key_holders.size()); + assert(hashvals.size() == key_holders.size() && hashvals.size() == batch_size); - for (size_t i = row_idx, j = 0; i < row_idx + hashvals.size() && i < end_row; ++i, ++j) + for (size_t i = row_idx, j = 0; i < row_idx + batch_size; ++i, ++j) { - key_holders[j] = static_cast(&state)->getKeyHolder( - i, - aggregates_pool, - sort_key_containers); + if constexpr (batch_get_key_holder) + key_holders[j] + = static_cast(&state)->getKeyHolderBatch(j, aggregates_pool); + else + key_holders[j] = static_cast(&state)->getKeyHolder( + i, + aggregates_pool, + sort_key_containers); hashvals[j] = method.data.hash(keyHolderGetKey(key_holders[j])); } } -template +template < + bool collect_hit_rate, + bool only_lookup, + bool enable_prefetch, + bool batch_get_key_holder, + typename KeyHolderType, + typename Method> ALWAYS_INLINE void Aggregator::executeImplBatch( Method & method, typename Method::State & state, @@ -772,11 +839,13 @@ ALWAYS_INLINE void Aggregator::executeImplBatch( /// Optimization for special case when there are no aggregate functions. if (params.aggregates_size == 0) - return handleOneBatch( - method, - state, - agg_process_info, - aggregates_pool); + return handleOneBatch< + collect_hit_rate, + only_lookup, + enable_prefetch, + batch_get_key_holder, + /*compute_agg_data=*/false, + KeyHolderType>(method, state, agg_process_info, aggregates_pool); /// Optimization for special case when aggregating by 8bit key. if constexpr (std::is_same_v) @@ -818,14 +887,23 @@ ALWAYS_INLINE void Aggregator::executeImplBatch( } /// Generic case. - return handleOneBatch( - method, - state, - agg_process_info, - aggregates_pool); -} - -template + return handleOneBatch< + collect_hit_rate, + only_lookup, + enable_prefetch, + batch_get_key_holder, + /*compute_agg_data=*/true, + KeyHolderType>(method, state, agg_process_info, aggregates_pool); +} + +template < + bool collect_hit_rate, + bool only_lookup, + bool enable_prefetch, + bool batch_get_key_holder, + bool compute_agg_data, + typename KeyHolderType, + typename Method> void Aggregator::handleOneBatch( Method & method, typename Method::State & state, @@ -851,10 +929,11 @@ void Aggregator::handleOneBatch( size_t i = agg_process_info.start_row; const size_t end = agg_process_info.start_row + rows; + Arena temp_batch_pool; size_t mini_batch_size = rows; std::vector hashvals; - std::vector key_holders; + std::vector key_holders; if constexpr (enable_prefetch) { // mini batch will only be used when HashTable is big(a.k.a enable_prefetch is true), @@ -867,12 +946,24 @@ void Aggregator::handleOneBatch( // i is the begin row index of each mini batch. while (i < end) { + size_t batch_mem_size = 0; + if constexpr (batch_get_key_holder) + batch_mem_size = state.prepareNextBatch(mini_batch_size, &temp_batch_pool); + if constexpr (enable_prefetch) { if unlikely (i + mini_batch_size > end) mini_batch_size = end - i; - prepareBatch(i, end, hashvals, key_holders, aggregates_pool, sort_key_containers, method, state); + setupHashVals( + i, + mini_batch_size, + hashvals, + key_holders, + aggregates_pool, + sort_key_containers, + method, + state); } const auto cur_batch_end = i + mini_batch_size; @@ -890,6 +981,11 @@ void Aggregator::handleOneBatch( emplace_result_holder = emplaceOrFindKey(method, state, std::move(key_holders[k]), hashvals[k]); } + else if constexpr (batch_get_key_holder) + { + emplace_result_holder + = emplaceOrFindKey(method, state, std::move(key_holders[k]), hashvals[k]); + } else { emplace_result_holder @@ -908,13 +1004,9 @@ void Aggregator::handleOneBatch( if constexpr (compute_agg_data) { if (emplace_result.isFound()) - { aggregate_data = emplace_result.getMapped(); - } else - { agg_process_info.not_found_rows.push_back(j); - } } else { @@ -958,6 +1050,9 @@ void Aggregator::handleOneBatch( processed_rows = j; } + if constexpr (batch_get_key_holder) + temp_batch_pool.rollback(batch_mem_size); + if unlikely (!processed_rows.has_value()) break; @@ -1161,7 +1256,7 @@ bool Aggregator::executeOnBlockImpl( { \ executeImpl( \ *ToAggregationMethodPtr(NAME, result.aggregation_method_impl), \ - result.aggregates_pool, \ + result, \ agg_process_info, \ params.collators); \ break; \ @@ -1268,24 +1363,53 @@ Block Aggregator::convertOneBucketToBlock( bool final, size_t bucket) const { -#define FILLER_DEFINE(name, skip_convert_key) \ - auto filler_##name = [bucket, &method, arena, this]( \ - const Sizes & key_sizes, \ - MutableColumns & key_columns, \ - AggregateColumnsData & aggregate_columns, \ - MutableColumns & final_aggregate_columns, \ - bool final_) { \ - using METHOD_TYPE = std::decay_t; \ - using DATA_TYPE = std::decay_t; \ - convertToBlockImpl( \ - method, \ - method.data.impls[bucket], \ - key_sizes, \ - key_columns, \ - aggregate_columns, \ - final_aggregate_columns, \ - arena, \ - final_); \ + const bool batch_get_key_holder = data_variants.batch_get_key_holder; +#define FILLER_DEFINE(name, skip_convert_key) \ + auto filler_##name = [bucket, &method, arena, this, batch_get_key_holder]( \ + const Sizes & key_sizes, \ + MutableColumns & key_columns, \ + AggregateColumnsData & aggregate_columns, \ + MutableColumns & final_aggregate_columns, \ + bool final_) { \ + (void)batch_get_key_holder; \ + using METHOD_TYPE = std::decay_t; \ + using DATA_TYPE = std::decay_t; \ + if constexpr ( \ + METHOD_TYPE::State::is_serialized_key && HasBatchGetKeyHolderMemberFunc) \ + { \ + if (batch_get_key_holder) \ + convertToBlockImpl( \ + method, \ + method.data.impls[bucket], \ + key_sizes, \ + key_columns, \ + aggregate_columns, \ + final_aggregate_columns, \ + arena, \ + final_); \ + else \ + convertToBlockImpl( \ + method, \ + method.data.impls[bucket], \ + key_sizes, \ + key_columns, \ + aggregate_columns, \ + final_aggregate_columns, \ + arena, \ + final_); \ + } \ + else \ + { \ + convertToBlockImpl( \ + method, \ + method.data.impls[bucket], \ + key_sizes, \ + key_columns, \ + aggregate_columns, \ + final_aggregate_columns, \ + arena, \ + final_); \ + } \ } FILLER_DEFINE(convert_key, false); @@ -1328,22 +1452,50 @@ BlocksList Aggregator::convertOneBucketToBlocks( bool final, size_t bucket) const { -#define FILLER_DEFINE(name, skip_convert_key) \ - auto filler_##name = [bucket, &method, arena, this]( \ - const Sizes & key_sizes, \ - std::vector & key_columns_vec, \ - std::vector & aggregate_columns_vec, \ - std::vector & final_aggregate_columns_vec, \ - bool final_) { \ - convertToBlocksImpl( \ - method, \ - method.data.impls[bucket], \ - key_sizes, \ - key_columns_vec, \ - aggregate_columns_vec, \ - final_aggregate_columns_vec, \ - arena, \ - final_); \ + const auto batch_get_key_holder = data_variants.batch_get_key_holder; +#define FILLER_DEFINE(name, skip_convert_key) \ + auto filler_##name = [bucket, &method, arena, this, batch_get_key_holder]( \ + const Sizes & key_sizes, \ + std::vector & key_columns_vec, \ + std::vector & aggregate_columns_vec, \ + std::vector & final_aggregate_columns_vec, \ + bool final_) { \ + (void)batch_get_key_holder; \ + if constexpr (Method::State::is_serialized_key && HasBatchGetKeyHolderMemberFunc) \ + { \ + if (batch_get_key_holder) \ + convertToBlocksImpl( \ + method, \ + method.data.impls[bucket], \ + key_sizes, \ + key_columns_vec, \ + aggregate_columns_vec, \ + final_aggregate_columns_vec, \ + arena, \ + final_); \ + else \ + convertToBlocksImpl( \ + method, \ + method.data.impls[bucket], \ + key_sizes, \ + key_columns_vec, \ + aggregate_columns_vec, \ + final_aggregate_columns_vec, \ + arena, \ + final_); \ + } \ + else \ + { \ + convertToBlocksImpl( \ + method, \ + method.data.impls[bucket], \ + key_sizes, \ + key_columns_vec, \ + aggregate_columns_vec, \ + final_aggregate_columns_vec, \ + arena, \ + final_); \ + } \ }; FILLER_DEFINE(convert_key, false); @@ -1469,7 +1621,7 @@ void Aggregator::execute(const BlockInputStreamPtr & stream, AggregatedDataVaria src_bytes / elapsed_seconds / 1048576.0); } -template +template void Aggregator::convertToBlockImpl( Method & method, Table & data, @@ -1489,7 +1641,7 @@ void Aggregator::convertToBlockImpl( raw_key_columns.push_back(column.get()); if (final) - convertToBlockImplFinal( + convertToBlockImplFinal( method, data, key_sizes, @@ -1497,7 +1649,7 @@ void Aggregator::convertToBlockImpl( final_aggregate_columns, arena); else - convertToBlockImplNotFinal( + convertToBlockImplNotFinal( method, data, key_sizes, @@ -1508,7 +1660,7 @@ void Aggregator::convertToBlockImpl( data.clearAndShrink(); } -template +template void Aggregator::convertToBlocksImpl( Method & method, Table & data, @@ -1537,7 +1689,7 @@ void Aggregator::convertToBlocksImpl( } if (final) - convertToBlocksImplFinal( + convertToBlocksImplFinal( method, data, key_sizes, @@ -1545,7 +1697,7 @@ void Aggregator::convertToBlocksImpl( final_aggregate_columns_vec, arena); else - convertToBlocksImplNotFinal( + convertToBlocksImplNotFinal( method, data, key_sizes, @@ -1718,7 +1870,7 @@ struct AggregatorMethodInitKeyColumnHelper +template void NO_INLINE Aggregator::convertToBlockImplFinal( Method & method, Table & data, @@ -1743,15 +1895,34 @@ void NO_INLINE Aggregator::convertToBlockImplFinal( agg_keys_helper.initAggKeys(data.size(), key_columns); } + // For key_serialized, deserialize will be batch-wise if it's serialized batch-wise. + PaddedPODArray key_places; + if constexpr (batch_deserialize_key) + key_places.reserve(params.max_block_size); + // Doesn't prefetch agg data, because places[data.size()] is needed, which can be very large. data.forEachValue([&](const auto & key [[maybe_unused]], auto & mapped) { if constexpr (!skip_convert_key) { - agg_keys_helper.insertKeyIntoColumns(key, key_columns, key_sizes_ref, params.collators); + if constexpr (batch_deserialize_key) + { + // Assume key is StringRef, because only key_string and key_serialize can be batch-wise. + key_places.push_back(key.data); + } + else + { + agg_keys_helper.insertKeyIntoColumns(key, key_columns, key_sizes_ref, params.collators); + } } insertAggregatesIntoColumns(mapped, final_aggregate_columns, arena); }); + + if constexpr (!skip_convert_key && batch_deserialize_key) + { + if (!key_places.empty()) + method.insertKeyIntoColumnsBatch(key_places, key_columns); + } } namespace @@ -1791,7 +1962,7 @@ std::vector>> initAg } } // namespace -template +template void NO_INLINE Aggregator::convertToBlocksImplFinal( Method & method, Table & data, @@ -1824,25 +1995,53 @@ void NO_INLINE Aggregator::convertToBlocksImplFinal( const auto rows = data.size(); std::unique_ptr places(new AggregateDataPtr[rows]); + PaddedPODArray key_places; + // For key_serialized, deserialize will be batch-wise if it's serialized batch-wise. + if constexpr (batch_deserialize_key) + key_places.reserve(params.max_block_size); + size_t current_bound = params.max_block_size; size_t key_columns_vec_index = 0; data.forEachValue([&](const auto & key [[maybe_unused]], auto & mapped) { if constexpr (!skip_convert_key) { - agg_keys_helpers[key_columns_vec_index] - ->insertKeyIntoColumns(key, key_columns_vec[key_columns_vec_index], key_sizes_ref, params.collators); + if constexpr (batch_deserialize_key) + { + // Assume key is StringRef, because only key_string and key_serialize can be batch-wise. + key_places.push_back(key.data); + } + else + { + agg_keys_helpers[key_columns_vec_index]->insertKeyIntoColumns( + key, + key_columns_vec[key_columns_vec_index], + key_sizes_ref, + params.collators); + } } places[data_index] = mapped; ++data_index; if unlikely (data_index == current_bound) { + if constexpr (!skip_convert_key && batch_deserialize_key) + { + method.insertKeyIntoColumnsBatch(key_places, key_columns_vec[key_columns_vec_index]); + key_places.clear(); + } + ++key_columns_vec_index; current_bound += params.max_block_size; } }); + if constexpr (!skip_convert_key && batch_deserialize_key) + { + if (!key_places.empty()) + method.insertKeyIntoColumnsBatch(key_places, key_columns_vec[key_columns_vec_index]); + } + data_index = 0; current_bound = params.max_block_size; key_columns_vec_index = 0; @@ -1862,7 +2061,7 @@ void NO_INLINE Aggregator::convertToBlocksImplFinal( } } -template +template void NO_INLINE Aggregator::convertToBlockImplNotFinal( Method & method, Table & data, @@ -1884,10 +2083,23 @@ void NO_INLINE Aggregator::convertToBlockImplNotFinal( agg_keys_helper.initAggKeys(data.size(), key_columns); } + // For key_serialized, deserialize will be batch-wise if it's serialized batch-wise. + PaddedPODArray key_places; + if constexpr (batch_deserialize_key) + key_places.reserve(params.max_block_size); + data.forEachValue([&](const auto & key [[maybe_unused]], auto & mapped) { if constexpr (!skip_convert_key) { - agg_keys_helper.insertKeyIntoColumns(key, key_columns, key_sizes_ref, params.collators); + if constexpr (batch_deserialize_key) + { + // Assume key is StringRef. + key_places.push_back(key.data); + } + else + { + agg_keys_helper.insertKeyIntoColumns(key, key_columns, key_sizes_ref, params.collators); + } } /// reserved, so push_back does not throw exceptions @@ -1896,9 +2108,15 @@ void NO_INLINE Aggregator::convertToBlockImplNotFinal( mapped = nullptr; }); + + if constexpr (!skip_convert_key && batch_deserialize_key) + { + if (!key_places.empty()) + method.insertKeyIntoColumnsBatch(key_places, key_columns); + } } -template +template void NO_INLINE Aggregator::convertToBlocksImplNotFinal( Method & method, Table & data, @@ -1925,13 +2143,30 @@ void NO_INLINE Aggregator::convertToBlocksImplNotFinal( agg_keys_helpers = initAggKeysForKeyColumnsVec(method, key_columns_vec, params.max_block_size, data.size()); } + PaddedPODArray key_places; + // For key_serialized, deserialize will be batch-wise if it's serialized batch-wise. + if constexpr (batch_deserialize_key) + key_places.reserve(params.max_block_size); + size_t data_index = 0; + size_t current_bound = params.max_block_size; + size_t key_columns_vec_index = 0; data.forEachValue([&](const auto & key [[maybe_unused]], auto & mapped) { - size_t key_columns_vec_index = data_index / params.max_block_size; if constexpr (!skip_convert_key) { - agg_keys_helpers[key_columns_vec_index] - ->insertKeyIntoColumns(key, key_columns_vec[key_columns_vec_index], key_sizes_ref, params.collators); + if constexpr (batch_deserialize_key) + { + // Assume key is StringRef, because only key_string and key_serialize can be batch-wise. + key_places.push_back(key.data); + } + else + { + agg_keys_helpers[key_columns_vec_index]->insertKeyIntoColumns( + key, + key_columns_vec[key_columns_vec_index], + key_sizes_ref, + params.collators); + } } /// reserved, so push_back does not throw exceptions @@ -1940,7 +2175,25 @@ void NO_INLINE Aggregator::convertToBlocksImplNotFinal( ++data_index; mapped = nullptr; + + if unlikely (data_index == current_bound) + { + if constexpr (!skip_convert_key && batch_deserialize_key) + { + method.insertKeyIntoColumnsBatch(key_places, key_columns_vec[++key_columns_vec_index]); + key_places.clear(); + } + + ++key_columns_vec_index; + current_bound += params.max_block_size; + } }); + + if constexpr (!skip_convert_key && batch_deserialize_key) + { + if (!key_places.empty()) + method.insertKeyIntoColumnsBatch(key_places, key_columns_vec[++key_columns_vec_index]); + } } template @@ -2148,7 +2401,7 @@ BlocksList Aggregator::prepareBlocksAndFill( return res_list; } - +// todo check if ok BlocksList Aggregator::prepareBlocksAndFillWithoutKey(AggregatedDataVariants & data_variants, bool final) const { size_t rows = 1; @@ -2190,33 +2443,63 @@ BlocksList Aggregator::prepareBlocksAndFillWithoutKey(AggregatedDataVariants & d BlocksList Aggregator::prepareBlocksAndFillSingleLevel(AggregatedDataVariants & data_variants, bool final) const { size_t rows = data_variants.size(); -#define M(NAME, skip_convert_key) \ - case AggregationMethodType(NAME): \ - { \ - auto & tmp_method = *ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl); \ - auto & tmp_data = ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl) -> data; \ - convertToBlocksImpl( \ - tmp_method, \ - tmp_data, \ - key_sizes, \ - key_columns_vec, \ - aggregate_columns_vec, \ - final_aggregate_columns_vec, \ - data_variants.aggregates_pool, \ - final_); \ - break; \ + const bool batch_get_key_holder = data_variants.batch_get_key_holder; +#define M(NAME, skip_convert_key) \ + case AggregationMethodType(NAME): \ + { \ + auto & tmp_method = *ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl); \ + auto & tmp_data = ToAggregationMethodPtr(NAME, data_variants.aggregation_method_impl) -> data; \ + using MethodType = std::decay_t; \ + if constexpr ( \ + MethodType::State::is_serialized_key && HasBatchGetKeyHolderMemberFunc) \ + { \ + if (batch_get_key_holder) \ + convertToBlocksImpl( \ + tmp_method, \ + tmp_data, \ + key_sizes, \ + key_columns_vec, \ + aggregate_columns_vec, \ + final_aggregate_columns_vec, \ + data_variants.aggregates_pool, \ + final_); \ + else \ + convertToBlocksImpl( \ + tmp_method, \ + tmp_data, \ + key_sizes, \ + key_columns_vec, \ + aggregate_columns_vec, \ + final_aggregate_columns_vec, \ + data_variants.aggregates_pool, \ + final_); \ + } \ + else \ + { \ + convertToBlocksImpl( \ + tmp_method, \ + tmp_data, \ + key_sizes, \ + key_columns_vec, \ + aggregate_columns_vec, \ + final_aggregate_columns_vec, \ + data_variants.aggregates_pool, \ + final_); \ + } \ + break; \ } #define M_skip_convert_key(NAME) M(NAME, true) #define M_convert_key(NAME) M(NAME, false) #define FILLER_DEFINE(name, M_tmp) \ - auto filler_##name = [&data_variants, this]( \ + auto filler_##name = [&data_variants, this, batch_get_key_holder]( \ const Sizes & key_sizes, \ std::vector & key_columns_vec, \ std::vector & aggregate_columns_vec, \ std::vector & final_aggregate_columns_vec, \ bool final_) { \ + (void)batch_get_key_holder; \ switch (data_variants.type) \ { \ APPLY_FOR_VARIANTS_SINGLE_LEVEL(M_tmp) \ @@ -2406,6 +2689,9 @@ MergingBucketsPtr Aggregator::mergeAndConvertToBlocks( non_empty_data[i]->aggregates_pools.end()); } + for (auto & data : non_empty_data) + RUNTIME_CHECK(non_empty_data[0]->batch_get_key_holder == data->batch_get_key_holder); + // for single level merge, concurrency must be 1. size_t merge_concurrency = has_at_least_one_two_level ? std::max(max_threads, 1) : 1; return std::make_shared(*this, non_empty_data, final, merge_concurrency); diff --git a/dbms/src/Interpreters/Aggregator.h b/dbms/src/Interpreters/Aggregator.h index b9c1c18f484..d939348a1e0 100644 --- a/dbms/src/Interpreters/Aggregator.h +++ b/dbms/src/Interpreters/Aggregator.h @@ -52,7 +52,6 @@ class IBlockOutputStream; template class AggHashTableToBlocksBlockInputStream; - /** Different data structures that can be used for aggregation * For efficiency, the aggregation data itself is put into the pool. * Data and pool ownership (states of aggregate functions) @@ -714,6 +713,12 @@ struct AggregationMethodSerialized for (size_t i = 0; i < key_columns.size(); ++i) pos = key_columns[i]->deserializeAndInsertFromArena(pos, collators.empty() ? nullptr : collators[i]); } + + static void insertKeyIntoColumnsBatch(PaddedPODArray & key_places, std::vector & key_columns) + { + for (auto * key_column : key_columns) + key_column->deserializeForCmpAndInsertFromPos(key_places, true); + } }; @@ -755,6 +760,8 @@ struct AggregatedDataVariants : private boost::noncopyable */ AggregatedDataWithoutKey without_key = nullptr; + bool batch_get_key_holder = false; + using AggregationMethod_key8 = AggregationMethodOneNumber; using AggregationMethod_key16 = AggregationMethodOneNumber; using AggregationMethod_key32 = AggregationMethodOneNumber; @@ -1449,18 +1456,31 @@ class Aggregator template void executeImpl( Method & method, - Arena * aggregates_pool, + AggregatedDataVariants & result, AggProcessInfo & agg_process_info, TiDB::TiDBCollators & collators) const; - template + template < + bool collect_hit_rate, + bool only_loopup, + bool enable_prefetch, + bool batch_get_key_holder, + typename KeyHolderType, + typename Method> void executeImplBatch( Method & method, typename Method::State & state, Arena * aggregates_pool, AggProcessInfo & agg_process_info) const; - template + template < + bool collect_hit_rate, + bool only_lookup, + bool enable_prefetch, + bool batch_get_key_holder, + bool compute_agg_data, + typename KeyHolderType, + typename Method> void handleOneBatch( Method & method, typename Method::State & state, @@ -1498,7 +1518,7 @@ class Aggregator template void mergeSingleLevelDataImpl(ManyAggregatedDataVariants & non_empty_data) const; - template + template void convertToBlockImpl( Method & method, Table & data, @@ -1512,7 +1532,7 @@ class Aggregator // The template parameter skip_convert_key indicates whether we can skip deserializing the keys in the HashMap. // For example, select first_row(c1) from t group by c1, where c1 is a string column with collator, // only the result of first_row(c1) needs to be constructed. The key c1 only needs to reference to first_row(c1). - template + template void convertToBlocksImpl( Method & method, Table & data, @@ -1523,7 +1543,7 @@ class Aggregator Arena * arena, bool final) const; - template + template void convertToBlockImplFinal( Method & method, Table & data, @@ -1532,7 +1552,7 @@ class Aggregator MutableColumns & final_aggregate_columns, Arena * arena) const; - template + template void convertToBlocksImplFinal( Method & method, Table & data, @@ -1541,7 +1561,7 @@ class Aggregator std::vector & final_aggregate_columns_vec, Arena * arena) const; - template + template void convertToBlockImplNotFinal( Method & method, Table & data, @@ -1549,7 +1569,7 @@ class Aggregator std::vector key_columns, AggregateColumnsData & aggregate_columns) const; - template + template void convertToBlocksImplNotFinal( Method & method, Table & data,