Skip to content

Commit

Permalink
Support first_value window function (#7427)
Browse files Browse the repository at this point in the history
ref #7376
  • Loading branch information
xzhangxian1008 authored May 16, 2023
1 parent 14fa0cb commit 0292765
Show file tree
Hide file tree
Showing 10 changed files with 270 additions and 14 deletions.
21 changes: 13 additions & 8 deletions dbms/src/DataStreams/WindowBlockInputStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,14 @@ Block WindowBlockInputStream::readImpl()
}

// Judge whether current_partition_row is end row of partition in current block
// How to judge?
// Compare data in previous partition with the new scanned data.
bool WindowTransformAction::isDifferentFromPrevPartition(UInt64 current_partition_row)
{
// prev_frame_start refers to the data in previous partition
const Columns & reference_columns = inputAt(prev_frame_start);

// partition_end refers to the new scanned data
const Columns & compared_columns = inputAt(partition_end);

for (size_t i = 0; i < partition_column_indices.size(); ++i)
Expand Down Expand Up @@ -299,9 +304,9 @@ void WindowTransformAction::advanceFrameStart()
}
}

bool WindowTransformAction::arePeers(const RowNumber & x, const RowNumber & y) const
bool WindowTransformAction::arePeers(const RowNumber & peer_group_last_row, const RowNumber & current_row) const
{
if (x == y)
if (peer_group_last_row == current_row)
{
// For convenience, a row is always its own peer.
return true;
Expand All @@ -324,18 +329,18 @@ bool WindowTransformAction::arePeers(const RowNumber & x, const RowNumber & y) c

for (size_t i = 0; i < n; ++i)
{
const auto * column_x = inputAt(x)[order_column_indices[i]].get();
const auto * column_y = inputAt(y)[order_column_indices[i]].get();
const auto * column_peer_last = inputAt(peer_group_last_row)[order_column_indices[i]].get();
const auto * column_current = inputAt(current_row)[order_column_indices[i]].get();
if (window_description.order_by[i].collator)
{
if (column_x->compareAt(x.row, y.row, *column_y, 1 /* nan_direction_hint */, *window_description.order_by[i].collator) != 0)
if (column_peer_last->compareAt(peer_group_last_row.row, current_row.row, *column_current, 1 /* nan_direction_hint */, *window_description.order_by[i].collator) != 0)
{
return false;
}
}
else
{
if (column_x->compareAt(x.row, y.row, *column_y, 1 /* nan_direction_hint */) != 0)
if (column_peer_last->compareAt(peer_group_last_row.row, current_row.row, *column_current, 1 /* nan_direction_hint */) != 0)
{
return false;
}
Expand Down Expand Up @@ -607,8 +612,8 @@ void WindowTransformAction::tryCalculate()
partition_start = partition_end;
advanceRowNumber(partition_end);
partition_ended = false;
// We have to reset the frame and other pointers when the new partition
// starts.

// We have to reset the frame and other pointers when the new partition starts.
frame_start = partition_start;
frame_end = partition_start;
prev_frame_start = partition_start;
Expand Down
3 changes: 2 additions & 1 deletion dbms/src/DataStreams/WindowBlockInputStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ struct WindowTransformAction
void advancePartitionEnd();
bool isDifferentFromPrevPartition(UInt64 current_partition_row);

bool arePeers(const RowNumber & x, const RowNumber & y) const;
bool arePeers(const RowNumber & peer_group_last_row, const RowNumber & current_row) const;

void advanceFrameStart();
void advanceFrameEndCurrentRow();
Expand Down Expand Up @@ -202,6 +202,7 @@ struct WindowTransformAction

// The row for which we are now computing the window functions.
RowNumber current_row;

// The start of current peer group, needed for CURRENT ROW frame start.
// For ROWS frame, always equal to the current row, and for RANGE and GROUP
// frames may be earlier.
Expand Down
1 change: 1 addition & 0 deletions dbms/src/Debug/MockExecutor/FuncSigMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,6 @@ std::unordered_map<String, tipb::ExprType> window_func_name_to_sig({
{"DenseRank", tipb::ExprType::DenseRank},
{"Lead", tipb::ExprType::Lead},
{"Lag", tipb::ExprType::Lag},
{"FirstValue", tipb::ExprType::FirstValue},
});
} // namespace DB::tests
14 changes: 14 additions & 0 deletions dbms/src/Debug/MockExecutor/WindowBinder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include <Debug/MockExecutor/FuncSigMap.h>
#include <Debug/MockExecutor/WindowBinder.h>
#include <Parsers/ASTFunction.h>
#include <tipb/expression.pb.h>


namespace DB::mock
{
Expand Down Expand Up @@ -73,6 +75,13 @@ bool WindowBinder::toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collat
ft->set_decimal(first_arg_type.decimal());
break;
}
case tipb::ExprType::FirstValue:
{
assert(window_expr->children_size() == 1);
const auto arg_type = window_expr->children(0).field_type();
(*ft) = arg_type;
break;
}
default:
ft->set_tp(TiDB::TypeLongLong);
ft->set_flag(TiDB::ColumnFlagBinary);
Expand Down Expand Up @@ -202,6 +211,11 @@ ExecutorBinderPtr compileWindow(ExecutorBinderPtr input, size_t & executor_index
}
break;
}
case tipb::ExprType::FirstValue:
{
ci = children_ci[0];
break;
}
default:
throw Exception(fmt::format("Unsupported window function {}", func->name), ErrorCodes::LOGICAL_ERROR);
}
Expand Down
3 changes: 2 additions & 1 deletion dbms/src/Flash/Coprocessor/DAGUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ const std::unordered_map<tipb::ExprType, String> window_func_map({
{tipb::ExprType::RowNumber, "row_number"},
{tipb::ExprType::Lead, "lead"},
{tipb::ExprType::Lag, "lag"},
{tipb::ExprType::FirstValue, "first_value"},
});

const std::unordered_map<tipb::ExprType, String> agg_func_map({
Expand Down Expand Up @@ -1030,10 +1031,10 @@ bool isWindowFunctionExpr(const tipb::Expr & expr)
case tipb::ExprType::DenseRank:
case tipb::ExprType::Lead:
case tipb::ExprType::Lag:
case tipb::ExprType::FirstValue:
// case tipb::ExprType::CumeDist:
// case tipb::ExprType::PercentRank:
// case tipb::ExprType::Ntile:
// case tipb::ExprType::FirstValue:
// case tipb::ExprType::LastValue:
// case tipb::ExprType::NthValue:
return true;
Expand Down
1 change: 1 addition & 0 deletions dbms/src/TestUtils/mockExecutor.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,5 +295,6 @@ MockWindowFrame buildDefaultRowsFrame();
#define Lag1(expr) makeASTFunction("Lag", (expr))
#define Lag2(expr1, expr2) makeASTFunction("Lag", (expr1), (expr2))
#define Lag3(expr1, expr2, expr3) makeASTFunction("Lag", (expr1), (expr2), (expr3))
#define FirstValue(expr) makeASTFunction("FirstValue", (expr))
} // namespace tests
} // namespace DB
39 changes: 39 additions & 0 deletions dbms/src/WindowFunctions/IWindowFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,44 @@ struct WindowFunctionRowNumber final : public IWindowFunction
}
};

struct WindowFunctionFirstValue final : public IWindowFunction
{
public:
static constexpr auto name = "first_value";

explicit WindowFunctionFirstValue(const DataTypes & argument_types_)
: IWindowFunction(argument_types_)
{
RUNTIME_CHECK(argument_types_.size() == 1);
return_type = argument_types_[0];
}

String getName() const override
{
return name;
}

DataTypePtr getReturnType() const override
{
return return_type;
}

void windowInsertResultInto(
WindowTransformAction & action,
size_t function_index,
const ColumnNumbers & arguments) override
{
assert(action.frame_started);
IColumn & to = *action.blockAt(action.current_row).output_columns[function_index];
const auto & value_column = *action.inputAt(action.frame_start)[arguments[0]];
const auto & value_field = value_column[action.frame_start.row];
to.insert(value_field);
}

private:
DataTypePtr return_type;
};

/**
LEAD/LAG(<expression>[,offset[, default_value]]) OVER (
PARTITION BY (expr)
Expand Down Expand Up @@ -319,5 +357,6 @@ void registerWindowFunctions(WindowFunctionFactory & factory)
factory.registerFunction<WindowFunctionRowNumber>();
factory.registerFunction<WindowFunctionLeadLagBase<LeadImpl>>();
factory.registerFunction<WindowFunctionLeadLagBase<LagImpl>>();
factory.registerFunction<WindowFunctionFirstValue>();
}
} // namespace DB
141 changes: 141 additions & 0 deletions dbms/src/WindowFunctions/tests/gtest_first_value.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
// Copyright 2023 PingCAP, Ltd.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <Interpreters/Context.h>
#include <TestUtils/ExecutorTestUtils.h>

namespace DB::tests
{
// TODO Tests with frame should be added
class FirstValue : public DB::tests::ExecutorTest
{
static const size_t max_concurrency_level = 10;

public:
static constexpr auto value_col_name = "first_value";
const ASTPtr value_col = col(value_col_name);

void initializeContext() override
{
ExecutorTest::initializeContext();
}

void executeWithConcurrencyAndBlockSize(const std::shared_ptr<tipb::DAGRequest> & request, const ColumnsWithTypeAndName & expect_columns)
{
std::vector<size_t> block_sizes{1, 2, 3, 4, DEFAULT_BLOCK_SIZE};
for (auto block_size : block_sizes)
{
context.context->setSetting("max_block_size", Field(static_cast<UInt64>(block_size)));
ASSERT_COLUMNS_EQ_R(expect_columns, executeStreams(request));
ASSERT_COLUMNS_EQ_UR(expect_columns, executeStreams(request, 2));
ASSERT_COLUMNS_EQ_UR(expect_columns, executeStreams(request, max_concurrency_level));
}
}

void executeFunctionAndAssert(
const ColumnWithTypeAndName & result,
const ASTPtr & function,
const ColumnsWithTypeAndName & input)
{
ColumnsWithTypeAndName actual_input = input;
assert(actual_input.size() == 3);
TiDB::TP value_tp = dataTypeToTP(actual_input[2].type);

actual_input[0].name = "partition";
actual_input[1].name = "order";
actual_input[2].name = value_col_name;
context.addMockTable(
{"test_db", "test_table_for_first_value"},
{{"partition", TiDB::TP::TypeLongLong, actual_input[0].type->isNullable()},
{"order", TiDB::TP::TypeLongLong, actual_input[1].type->isNullable()},
{value_col_name, value_tp, actual_input[2].type->isNullable()}},
actual_input);

auto request = context
.scan("test_db", "test_table_for_first_value")
.sort({{"partition", false}, {"order", false}}, true)
.window(function, {"order", false}, {"partition", false}, MockWindowFrame{})
.build(context);

ColumnsWithTypeAndName expect = input;
expect.push_back(result);
executeWithConcurrencyAndBlockSize(request, expect);
}

template <typename IntType>
void testInt()
{
executeFunctionAndAssert(
toVec<IntType>({1, 2, 2, 2, 2, 6, 6, 6, 6, 6, 11, 11, 11}),
FirstValue(value_col),
{toVec<Int64>(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3}),
toVec<Int64>(/*order*/ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
toVec<IntType>(/*value*/ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13})});

executeFunctionAndAssert(
toNullableVec<IntType>({{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}}),
FirstValue(value_col),
{toNullableVec<Int64>(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3}),
toNullableVec<Int64>(/*order*/ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
toNullableVec<IntType>(/*value*/ {{}, {}, 3, 4, 5, {}, 7, 8, 9, 10, {}, 12, 13})});
}

template <typename FloatType>
void testFloat()
{
executeFunctionAndAssert(
toVec<FloatType>({1, 2, 2, 2, 2, 6, 6, 6, 6, 6, 11, 11, 11}),
FirstValue(value_col),
{toVec<Int64>(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3}),
toVec<Int64>(/*order*/ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
toVec<FloatType>(/*value*/ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13})});

executeFunctionAndAssert(
toNullableVec<FloatType>({{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}}),
FirstValue(value_col),
{toNullableVec<Int64>(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3}),
toNullableVec<Int64>(/*order*/ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
toNullableVec<FloatType>(/*value*/ {{}, {}, 3, 4, 5, {}, 7, 8, 9, 10, {}, 12, 13})});
}
};

TEST_F(FirstValue, firstValue)
try
{
executeFunctionAndAssert(
toVec<String>({"1", "2", "2", "2", "2", "6", "6", "6", "6", "6", "11", "11", "11"}),
FirstValue(value_col),
{toVec<Int64>(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3}),
toVec<Int64>(/*order*/ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
toVec<String>(/*value*/ {"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13"})});

executeFunctionAndAssert(
toNullableVec<String>({{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}}),
FirstValue(value_col),
{toNullableVec<Int64>(/*partition*/ {0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3}),
toNullableVec<Int64>(/*order*/ {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}),
toNullableVec<String>(/*value*/ {{}, {}, "3", "4", "5", {}, "7", "8", "9", "10", {}, "12", "13"})});

// TODO support unsigned int.
testInt<Int8>();
testInt<Int16>();
testInt<Int32>();
testInt<Int64>();

testFloat<Float32>();
testFloat<Float64>();
}
CATCH

} // namespace DB::tests
7 changes: 4 additions & 3 deletions dbms/src/WindowFunctions/tests/gtest_lead_lag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ template <typename T>
using Limits = std::numeric_limits<T>;

// TODO Support more convenient testing framework for Window Function.
// TODO Tests with frame should be added
class LeadLag : public DB::tests::ExecutorTest
{
static const size_t max_concurrency_level = 10;
Expand Down Expand Up @@ -60,9 +61,9 @@ class LeadLag : public DB::tests::ExecutorTest
actual_input[2].name = value_col_name;
context.addMockTable(
{"test_db", "test_table_for_lead_lag"},
{{"partition", TiDB::TP::TypeLongLong},
{"order", TiDB::TP::TypeLongLong},
{value_col_name, value_tp}},
{{"partition", TiDB::TP::TypeLongLong, actual_input[0].type->isNullable()},
{"order", TiDB::TP::TypeLongLong, actual_input[1].type->isNullable()},
{value_col_name, value_tp, actual_input[2].type->isNullable()}},
actual_input);

auto request = context
Expand Down
Loading

0 comments on commit 0292765

Please sign in to comment.