diff --git a/.gitignore b/.gitignore index ec48ea7f41..f09148e5ba 100644 --- a/.gitignore +++ b/.gitignore @@ -48,8 +48,8 @@ config/ configure config-h.in autom4te.cache -*Makefile.in -*Makefile +build/*Makefile.in +build/*Makefile libtool aclocal.m4 config.log diff --git a/CMakeLists.txt b/CMakeLists.txt index 6f5b754caa..ede776c5c7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -868,7 +868,7 @@ file(GLOB_RECURSE function(add_test_util_lib TYPE) string(TOLOWER ${TYPE} TYPE_LOWER) add_library(noisepage_test_util_${TYPE_LOWER} ${TYPE} ${NOISEPAGE_TEST_UTIL_SRCS}) - add_custom_command(TARGET noisepage_test_util_${TYPE_LOWER} DEPENDS gtest gtest_main gmock gmock_main) + add_custom_command(TARGET noisepage_test_util_${TYPE_LOWER} DEPENDS gtest gtest_main) target_compile_options(noisepage_test_util_${TYPE_LOWER} PRIVATE "-Werror" "-Wall") # Inject the source directory path into the translation units for test utility lib target_compile_definitions(noisepage_test_util_${TYPE_LOWER} PRIVATE NOISEPAGE_BUILD_ROOT=${CMAKE_BINARY_DIR}) @@ -878,7 +878,7 @@ function(add_test_util_lib TYPE) ${CMAKE_BINARY_DIR}/_deps/src/googletest/googletest/include/ ) target_link_libraries(noisepage_test_util_${TYPE_LOWER} PUBLIC - ${CMAKE_BINARY_DIR}/lib/libgtest.a ${CMAKE_BINARY_DIR}/lib/libgmock.a + ${CMAKE_BINARY_DIR}/lib/libgtest.a util_${TYPE_LOWER} pqxx) set_target_properties(noisepage_test_util_${TYPE_LOWER} PROPERTIES CXX_EXTENSIONS OFF UNITY_BUILD ${NOISEPAGE_UNITY_BUILD}) endfunction() @@ -917,7 +917,6 @@ function(add_noisepage_test add_executable(${TEST_NAME} ${EXCLUDE_OPTION} ${TEST_SOURCES}) target_compile_options(${TEST_NAME} PRIVATE "-Werror" "-Wall" "-fvisibility=hidden") - target_link_libraries(${TEST_NAME} PRIVATE ${CMAKE_BINARY_DIR}/lib/libgmock_main.a) if (${NOISEPAGE_ENABLE_SHARED}) target_link_libraries(${TEST_NAME} PRIVATE noisepage_test_util_shared) else () @@ -1119,10 +1118,10 @@ file(GLOB_RECURSE ) add_library(noisepage_benchmark_util STATIC ${NOISEPAGE_BENCHMARK_UTIL_SRCS}) -add_custom_command(TARGET noisepage_benchmark_util DEPENDS gtest gtest_main gmock gmock_main) +add_custom_command(TARGET noisepage_benchmark_util DEPENDS gtest gtest_main) target_compile_options(noisepage_benchmark_util PRIVATE "-Werror" "-Wall") target_include_directories(noisepage_benchmark_util PUBLIC ${PROJECT_SOURCE_DIR}/benchmark/include) -target_link_libraries(noisepage_benchmark_util PUBLIC ${CMAKE_BINARY_DIR}/lib/libgmock_main.a noisepage_test_util_static benchmark) +target_link_libraries(noisepage_benchmark_util PUBLIC noisepage_test_util_static benchmark) set_target_properties(noisepage_benchmark_util PROPERTIES CXX_EXTENSIONS OFF) set(NOISEPAGE_BENCHMARKS "") diff --git a/Jenkinsfile-utils.groovy b/Jenkinsfile-utils.groovy index b21045f911..e6791d06ba 100644 --- a/Jenkinsfile-utils.groovy +++ b/Jenkinsfile-utils.groovy @@ -76,13 +76,13 @@ void stageTest(Boolean runPipelineMetrics, Map args = [:]) { buildType = (args.cmake.toUpperCase().contains("CMAKE_BUILD_TYPE=RELEASE")) ? "release" : "debug" - sh script: "cd build && PYTHONPATH=.. timeout 20m python3 -m script.testing.junit --build-type=$buildType --query-mode=simple", label: 'UnitTest (Simple)' + sh script: "cd build && PYTHONPATH=.. timeout 40m python3 -m script.testing.junit --build-type=$buildType --query-mode=simple", label: 'UnitTest (Simple)' sh script: "cd build && PYTHONPATH=.. timeout 60m python3 -m script.testing.junit --build-type=$buildType --query-mode=simple -a 'compiled_query_execution=True' -a 'bytecode_handlers_path=./bytecode_handlers_ir.bc'", label: 'UnitTest (Simple, Compiled Execution)' - sh script: "cd build && PYTHONPATH=.. timeout 20m python3 -m script.testing.junit --build-type=$buildType --query-mode=extended", label: 'UnitTest (Extended)' + sh script: "cd build && PYTHONPATH=.. timeout 40m python3 -m script.testing.junit --build-type=$buildType --query-mode=extended", label: 'UnitTest (Extended)' sh script: "cd build && PYTHONPATH=.. timeout 60m python3 -m script.testing.junit --build-type=$buildType --query-mode=extended -a 'compiled_query_execution=True' -a 'bytecode_handlers_path=./bytecode_handlers_ir.bc'", label: 'UnitTest (Extended, Compiled Execution)' if (runPipelineMetrics) { - sh script: "cd build && PYTHONPATH=.. timeout 20m python3 -m script.testing.junit --build-type=$buildType --query-mode=extended -a 'pipeline_metrics_enable=True' -a 'pipeline_metrics_sample_rate=100' -a 'counters_enable=True' -a 'query_trace_metrics_enable=True'", label: 'UnitTest (Extended with pipeline metrics, counters, and query trace metrics)' + sh script: "cd build && PYTHONPATH=.. timeout 40m python3 -m script.testing.junit --build-type=$buildType --query-mode=extended -a 'pipeline_metrics_enable=True' -a 'pipeline_metrics_sample_rate=100' -a 'counters_enable=True' -a 'query_trace_metrics_enable=True'", label: 'UnitTest (Extended with pipeline metrics, counters, and query trace metrics)' sh script: "cd build && PYTHONPATH=.. timeout 60m python3 -m script.testing.junit --build-type=$buildType --query-mode=extended -a 'pipeline_metrics_enable=True' -a 'pipeline_metrics_sample_rate=100' -a 'counters_enable=True' -a 'query_trace_metrics_enable=True' -a 'compiled_query_execution=True' -a 'bytecode_handlers_path=./bytecode_handlers_ir.bc'", label: 'UnitTest (Extended, Compiled Execution with pipeline metrics, counters, and query trace metrics)' } diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..145e25be92 --- /dev/null +++ b/Makefile @@ -0,0 +1,20 @@ +# Makefile +# Some scripting shortcuts. + +# Run all regression checks in all execution modes +check-regress: check-regress-interpreted check-regress-compiled + +# Run all regression tests in interpreted mode +check-regress-interpreted: + cd build && PYTHONPATH=.. python -m script.testing.junit --build-type=debug --query-mode=simple + +# Run all regression tests in compiled mode +check-regress-compiled: + PYTHONPATH=.. python -m script.testing.junit --build-type=debug --query-mode=simple -a 'compiled_query_execution=True' -a 'bytecode_handlers_path=./bytecode_handlers_ir.bc' + +check-regress-udf: + cd build && PYTHONPATH=.. python -m script.testing.junit --build-type=debug --query-mode=simple --tracefile-test=udf.test + +# Re-generate the trace file for UDF regression tests +generate-regress-udf: + cd script/testing/junit/ && ant generate-trace -Dpath=sql/udf.sql -Ddb-url=jdbc:postgresql://localhost/test -Ddb-user=postgres -Ddb-password=password -Doutput-name=udf.test diff --git a/benchmark/runner/execution_runners.cpp b/benchmark/runner/execution_runners.cpp index 244760f842..42f3806dce 100644 --- a/benchmark/runner/execution_runners.cpp +++ b/benchmark/runner/execution_runners.cpp @@ -11,6 +11,7 @@ #include "common/scoped_timer.h" #include "execution/compiler/compilation_context.h" #include "execution/compiler/executable_query.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/execution_settings.h" #include "execution/execution_util.h" #include "execution/sql/ddl_executors.h" @@ -448,9 +449,17 @@ class ExecutionRunners : public benchmark::Fixture { exec_settings = *exec_settings_arg; } - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), execution::exec::NoOpResultConsumer(), out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, metrics_manager_, DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(out_plan->GetOutputSchema()) + .WithOutputCallback(execution::exec::NoOpResultConsumer{}) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); execution::compiler::ExecutableQuery::query_identifier.store(ExecutionRunners::query_id++); auto exec_query = execution::compiler::CompilationContext::Compile(*out_plan, exec_settings, accessor.get(), @@ -494,9 +503,18 @@ class ExecutionRunners : public benchmark::Fixture { auto txn = txn_manager->BeginTransaction(); auto accessor = catalog->GetAccessor(common::ManagedPointer(txn), db_oid, DISABLED); auto exec_settings = ExecutionRunners::GetExecutionSettings(); - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), nullptr, nullptr, common::ManagedPointer(accessor), exec_settings, - metrics_manager_, DISABLED, DISABLED); + + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); execution::sql::TableGenerator table_generator(exec_ctx.get(), block_store, accessor->GetDefaultNamespace()); if (is_build) { @@ -529,12 +547,12 @@ class ExecutionRunners : public benchmark::Fixture { } void BenchmarkExecQuery(int64_t num_iters, execution::compiler::ExecutableQuery *exec_query, - planner::OutputSchema *out_schema, bool commit, + const planner::OutputSchema *out_schema, bool commit, std::vector> *params = &empty_params, execution::exec::ExecutionSettings *exec_settings_arg = nullptr) { transaction::TransactionContext *txn = nullptr; std::unique_ptr accessor = nullptr; - std::vector> param_ref = *params; + const auto ¶ms_ref = *params; execution::exec::NoOpResultConsumer consumer; execution::exec::OutputCallback callback = consumer; @@ -553,14 +571,23 @@ class ExecutionRunners : public benchmark::Fixture { exec_settings = *exec_settings_arg; } - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), callback, out_schema, common::ManagedPointer(accessor), exec_settings, - metrics_manager, DISABLED, DISABLED); - - // Attach params to ExecutionContext - if (static_cast(i) < param_ref.size()) { - exec_ctx->SetParams(common::ManagedPointer>(¶m_ref[i])); + // TODO(Kyle): This makes an unnecessary copy of the query parameters + std::vector parameters{}; + if (static_cast(i) < params_ref.size()) { + std::copy(params_ref[i].cbegin(), params_ref[i].cend(), std::back_inserter(parameters)); } + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithQueryParametersFrom(parameters) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(common::ManagedPointer{out_schema}) + .WithOutputCallback(callback) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); exec_query->Run(common::ManagedPointer(exec_ctx), mode); @@ -582,10 +609,17 @@ class ExecutionRunners : public benchmark::Fixture { auto txn = txn_manager_->BeginTransaction(); auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_oid, DISABLED); auto exec_settings = GetExecutionSettings(); - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), nullptr, nullptr, common::ManagedPointer(accessor), exec_settings, - metrics_manager_, DISABLED, DISABLED); - exec_ctx->SetExecutionMode(static_cast(mode)); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); selfdriving::PipelineOperatingUnits units; selfdriving::ExecutionOperatingUnitFeatureVector pipe0_vec; @@ -939,15 +973,24 @@ BENCHMARK_DEFINE_F(ExecutionRunners, SEQ0_OutputRunners)(benchmark::State &state auto txn = txn_manager_->BeginTransaction(); auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_oid, DISABLED); - auto schema = std::make_unique(std::move(cols)); + auto schema = std::make_unique(std::move(cols)); auto exec_settings = GetExecutionSettings(); execution::compiler::ExecutableQuery::query_identifier.store(ExecutionRunners::query_id++); execution::exec::NoOpResultConsumer consumer; execution::exec::OutputCallback callback = consumer; - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), callback, schema.get(), common::ManagedPointer(accessor), exec_settings, - metrics_manager_, DISABLED, DISABLED); + + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(common::ManagedPointer{schema}) + .WithOutputCallback(callback) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); auto exec_query = execution::compiler::ExecutableQuery(output.str(), common::ManagedPointer(exec_ctx), false, 16, exec_settings, txn->StartTime()); @@ -1011,9 +1054,18 @@ void ExecutionRunners::ExecuteIndexOperation(benchmark::State *state, bool is_in auto exec_settings = GetExecutionSettings(); execution::exec::NoOpResultConsumer consumer; execution::exec::OutputCallback callback = consumer; - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), callback, nullptr, common::ManagedPointer(accessor), exec_settings, - metrics_manager, DISABLED, DISABLED); + + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(callback) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // A brief discussion of the features: // NUM_ROWS: size of the index @@ -2062,9 +2114,17 @@ void InitializeRunnersState() { // Load the database auto accessor = catalog->GetAccessor(common::ManagedPointer(txn), db_oid, DISABLED); auto exec_settings = ExecutionRunners::GetExecutionSettings(); - auto exec_ctx = std::make_unique( - db_oid, common::ManagedPointer(txn), nullptr, nullptr, common::ManagedPointer(accessor), exec_settings, - db_main->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); execution::sql::TableGenerator table_gen(exec_ctx.get(), block_store, accessor->GetDefaultNamespace()); table_gen.GenerateExecutionRunnersData(settings, config); diff --git a/benchmark/runner/procbench_runner.cpp b/benchmark/runner/procbench_runner.cpp new file mode 100644 index 0000000000..b479b13999 --- /dev/null +++ b/benchmark/runner/procbench_runner.cpp @@ -0,0 +1,86 @@ +#include "benchmark/benchmark.h" +#include "common/scoped_timer.h" +#include "common/worker_pool.h" +#include "execution/execution_util.h" +#include "execution/vm/module.h" +#include "main/db_main.h" +#include "test_util/fs_util.h" +#include "test_util/procbench/workload.h" + +/** + * The local paths to the data directories. + * https://github.com/malin1993ml/tpl_tables and "bash gen_tpch.sh 0.1". + */ +static constexpr const char PROCBENCH_TABLE_ROOT[] = "/home/turing/dev/tpl-tables/tpcds-tables/"; +static constexpr const char PROCBENCH_DATABASE_NAME[] = "procbench_runner_db"; + +namespace noisepage::runner { + +/** + * ProcbenchRunner runs SQL ProcBench benchmarks. + */ +class ProcbenchRunner : public benchmark::Fixture { + public: + /** The execution mode for the execution engine */ + const execution::vm::ExecutionMode exec_mode_ = execution::vm::ExecutionMode::Interpret; + + /** The main database instance */ + std::unique_ptr db_main_; + + /** The workload with loaded data and queries */ + std::unique_ptr workload_; + + /** Local paths to data */ + const std::string procbench_table_root_{PROCBENCH_TABLE_ROOT}; + const std::string procbench_database_name_{PROCBENCH_DATABASE_NAME}; + + /** Setup the database instance for benchmark. */ + void SetUp(const benchmark::State &state) final { + auto db_main_builder = DBMain::Builder() + .SetUseGC(true) + .SetUseCatalog(true) + .SetUseGCThread(true) + .SetUseMetrics(true) + .SetUseMetricsThread(true) + .SetBlockStoreSize(1000000) + .SetBlockStoreReuse(1000000) + .SetRecordBufferSegmentSize(1000000) + .SetRecordBufferSegmentReuse(1000000) + .SetBytecodeHandlersPath(common::GetBinaryArtifactPath("bytecode_handlers_ir.bc")); + db_main_ = db_main_builder.Build(); + + auto metrics_manager = db_main_->GetMetricsManager(); + metrics_manager->SetMetricSampleRate(metrics::MetricsComponent::EXECUTION_PIPELINE, 100); + metrics_manager->EnableMetric(metrics::MetricsComponent::EXECUTION_PIPELINE); + } + + /** Teardown the database instance after a benchmark */ + void TearDown(const benchmark::State &state) final { + // free db main here so we don't need to use the loggers anymore + db_main_.reset(); + } +}; + +// NOLINTNEXTLINE +BENCHMARK_DEFINE_F(ProcbenchRunner, Runner)(benchmark::State &state) { + // Load the ProcBench tables and compile the queries + workload_ = std::make_unique(common::ManagedPointer(db_main_), procbench_database_name_, + procbench_table_root_); + + const auto start = std::chrono::high_resolution_clock::now(); + + // Execute the workload + workload_->Execute(6, execution::vm::ExecutionMode::Interpret); + + const auto stop = std::chrono::high_resolution_clock::now(); + const auto duration = std::chrono::duration_cast(stop - start).count(); + + state.SetIterationTime(duration); + + // Free the workload here so we don't need to use the loggers anymore + workload_.reset(); +} + +BENCHMARK_REGISTER_F(ProcbenchRunner, Runner)->Unit(benchmark::kMillisecond)->UseManualTime()->Iterations(1); + +} // namespace noisepage::runner diff --git a/benchmark/runner/tpch_runner.cpp b/benchmark/runner/tpch_runner.cpp index 2e9e6bb239..680a10a8bb 100644 --- a/benchmark/runner/tpch_runner.cpp +++ b/benchmark/runner/tpch_runner.cpp @@ -7,25 +7,51 @@ #include "test_util/fs_util.h" #include "test_util/tpch/workload.h" +/** + * The local paths to the data directories. + * https://github.com/malin1993ml/tpl_tables and "bash gen_tpch.sh 0.1". + */ +static constexpr const char TPCH_TABLE_ROOT[] = "/home/turing/dev/tpl-tables/tables/"; +static constexpr const char SSB_TABLE_ROOT[] = "/home/turing/dev/tpl-tables/tables/"; +static constexpr const char TPCH_DATABASE_NAME[] = "tpch_runner_db"; + namespace noisepage::runner { + +/** + * TPCHRunner runs TPCH benchmarks. + */ class TPCHRunner : public benchmark::Fixture { public: - const int8_t total_num_threads_ = 4; // defines the number of terminals (workers threads) - const uint64_t execution_us_per_worker_ = 20000000; // Time (us) to run per terminal (worker thread) + /** Defines the number of terminals (workers threads) */ + const int8_t total_num_threads_ = 4; + + /** Time (us) to run per terminal (worker thread) */ + const uint64_t execution_us_per_worker_ = 20000000; + + /** The average intervals in microseconds */ std::vector avg_interval_us_ = {10, 20, 50, 100, 200, 500, 1000}; + + /** The execution mode for the execution engine */ const execution::vm::ExecutionMode mode_ = execution::vm::ExecutionMode::Interpret; - const bool single_test_run_ = false; + /** Flag indicating if only a single test run should be run */ + const bool single_test_run_ = true; + + /** The main database instance */ std::unique_ptr db_main_; + + /** The workload with loaded data and queries */ std::unique_ptr workload_; - // To get tpl_tables, https://github.com/malin1993ml/tpl_tables and "bash gen_tpch.sh 0.1". - const std::string tpch_table_root_ = "../../../tpl_tables/tables/"; - const std::string ssb_dir_ = "../../../SSB_Table_Generator/ssb_tables/"; - const std::string tpch_database_name_ = "tpch_runner_db"; + /** Local paths to data */ + const std::string tpch_table_root_{TPCH_TABLE_ROOT}; + const std::string ssb_dir_{SSB_TABLE_ROOT}; + const std::string tpch_database_name_{TPCH_DATABASE_NAME}; + /** The benchmark type */ tpch::Workload::BenchmarkType type_ = tpch::Workload::BenchmarkType::TPCH; + /** Setup the database instance for benchmark. */ void SetUp(const benchmark::State &state) final { auto db_main_builder = DBMain::Builder() .SetUseGC(true) @@ -38,7 +64,6 @@ class TPCHRunner : public benchmark::Fixture { .SetRecordBufferSegmentSize(1000000) .SetRecordBufferSegmentReuse(1000000) .SetBytecodeHandlersPath(common::GetBinaryArtifactPath("bytecode_handlers_ir.bc")); - db_main_ = db_main_builder.Build(); auto metrics_manager = db_main_->GetMetricsManager(); @@ -46,6 +71,7 @@ class TPCHRunner : public benchmark::Fixture { metrics_manager->EnableMetric(metrics::MetricsComponent::EXECUTION_PIPELINE); } + /** Teardown the database instance after a benchmark */ void TearDown(const benchmark::State &state) final { // free db main here so we don't need to use the loggers anymore db_main_.reset(); @@ -82,11 +108,13 @@ BENCHMARK_DEFINE_F(TPCHRunner, Runner)(benchmark::State &state) { } auto total_query_num = workload_->GetQueryNum() + 1; - for (uint32_t query_num = query_num_start; query_num < total_query_num; query_num += 4) - for (auto num_threads = num_thread_start; num_threads <= total_num_threads_; num_threads += 3) - for (uint32_t repeat = 0; repeat < repeat_num; ++repeat) + for (uint32_t query_num = query_num_start; query_num < total_query_num; query_num += 4) { + for (auto num_threads = num_thread_start; num_threads <= total_num_threads_; num_threads += 3) { + for (uint32_t repeat = 0; repeat < repeat_num; ++repeat) { for (auto avg_interval_us : avg_interval_us_) { - std::this_thread::sleep_for(std::chrono::seconds(2)); // Let GC clean up + // Let GC clean up + std::this_thread::sleep_for(std::chrono::seconds(2)); + common::WorkerPool thread_pool{static_cast(num_threads), {}}; thread_pool.Startup(); @@ -99,10 +127,14 @@ BENCHMARK_DEFINE_F(TPCHRunner, Runner)(benchmark::State &state) { thread_pool.WaitUntilAllFinished(); thread_pool.Shutdown(); } + } + } + } - // free the workload here so we don't need to use the loggers anymore + // Free the workload here so we don't need to use the loggers anymore workload_.reset(); } BENCHMARK_REGISTER_F(TPCHRunner, Runner)->Unit(benchmark::kMillisecond)->UseManualTime()->Iterations(1); + } // namespace noisepage::runner diff --git a/build-support/run_tpl_tests.py b/build-support/run_tpl_tests.py index ea1fdf9257..c594ee2c73 100755 --- a/build-support/run_tpl_tests.py +++ b/build-support/run_tpl_tests.py @@ -1,15 +1,25 @@ #!/usr/bin/env python3 -import argparse import os -import subprocess import sys +import argparse +import subprocess -VM_TARGET_STRING = 'VM main() returned: ' -ADAPTIVE_TARGET_STRING = 'ADAPTIVE main() returned: ' -JIT_TARGET_STRING = 'JIT main() returned: ' -TARGET_STRINGS = [VM_TARGET_STRING, ADAPTIVE_TARGET_STRING, JIT_TARGET_STRING] +# Exit codes +EXIT_SUCCESS = 0 +EXIT_FAILURE = 1 + +# String prefixed to VM execution tests +VM_TARGET_STRING = "VM main() returned: " +# String prefixed to ADAPTIVE execution tests +ADAPTIVE_TARGET_STRING = "ADAPTIVE main() returned: " + +# String prefixed to JIT execution tests +JIT_TARGET_STRING = "JIT main() returned: " + +# Collection of all target strings +TARGET_STRINGS = [VM_TARGET_STRING, ADAPTIVE_TARGET_STRING, JIT_TARGET_STRING] def run(tpl_bin, tpl_file, is_sql): args = [tpl_bin] @@ -18,11 +28,7 @@ def run(tpl_bin, tpl_file, is_sql): args.append(tpl_file) proc = subprocess.run(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) result = [] - #print("tpl_file stdout:") - #print(proc.stdout.decode('utf-8')) - #print("tpl_file stderr:") - #print(proc.stderr.decode('utf-8')) - for line in reversed(proc.stdout.decode('utf-8').split('\n')): + for line in reversed(proc.stdout.decode("utf-8").split("\n")): if "ERROR" in line or "error" in line: return [] for target_string in TARGET_STRINGS: @@ -31,51 +37,52 @@ def run(tpl_bin, tpl_file, is_sql): result.append(line[idx + len(target_string):]) return result - def check(tpl_bin, tpl_folder, tpl_tests_file, build_dir): os.chdir(build_dir) with open(tpl_tests_file) as tpl_tests: num_tests, failed = 0, set() - print('Tests:') + print("Tests:") for line in tpl_tests: line = line.strip() - if not line or line[0] == '#': + if not line or line[0] == "#": continue - tpl_file, sql, expected_output = [x.strip() for x in line.split(',')] + tpl_file, sql, expected_output = [x.strip() for x in line.split(",")] + is_sql = sql.lower() == "true" res = run(tpl_bin, os.path.join(tpl_folder, tpl_file), is_sql) num_tests += 1 - report = 'PASS' + report = "PASS" if not res: - report = 'ERR' + report = "ERR" failed.add(tpl_file) - elif len(res) != 3 or not all(output == expected_output for output in res): - report = 'FAIL [expect: {}, actual: {}]'.format(expected_output, - res) + elif len(res) != 3 or not all(output == expected_output for output in res): + report = "FAIL [expect: {}, actual: {}]".format(expected_output, res) failed.add(tpl_file) - print('\t{}: {}'.format(tpl_file, report)) - print('{}/{} tests passed.'.format(num_tests - len(failed), num_tests)) + print("\t{}: {}".format(tpl_file, report)) + + print("{}/{} tests passed.".format(num_tests - len(failed), num_tests)) if len(failed) > 0: - print('{} failed:'.format(len(failed))) + print("{} failed:".format(len(failed))) for fail in failed: - print('\t{}'.format(fail)) - sys.exit(-1) + print("\t{}".format(fail)) + return EXIT_FAILURE + return EXIT_SUCCESS def main(): parser = argparse.ArgumentParser() - parser.add_argument('-b', dest='tpl_bin', help='TPL binary.') - parser.add_argument('-f', dest='tpl_tests_file', - help='File containing lines.') - parser.add_argument('-t', dest='tpl_folder', help='TPL tests folder.') - parser.add_argument('-d', dest='build_dir', help='Build Directory.') + parser.add_argument("-b", dest="tpl_bin", help="TPL binary.") + parser.add_argument("-f", dest="tpl_tests_file", + help="File containing lines.") + parser.add_argument("-t", dest="tpl_folder", help="TPL tests folder.") + parser.add_argument("-d", dest="build_dir", help="Build Directory.") args = parser.parse_args() - check(args.tpl_bin, args.tpl_folder, args.tpl_tests_file, args.build_dir) - + + return check(args.tpl_bin, args.tpl_folder, args.tpl_tests_file, args.build_dir) -if __name__ == '__main__': - main() +if __name__ == "__main__": + sys.exit(main()) diff --git a/docs/design_closures.md b/docs/design_closures.md new file mode 100644 index 0000000000..9a52dcb6a0 --- /dev/null +++ b/docs/design_closures.md @@ -0,0 +1,167 @@ +# Design Doc: TPL Closures + +### Overview + +This document describes the implementation of closures in TPL. It includes both a high-level description as well as a complete walkthough of the low-level implementation details. + +### Architecture + +TPL closures are implemented as regular TPL functions with the added ability to capture arbitrary variables. In the same way that return values in TPL are implemented via a "hidden" out-parameter to each function, the variables captured by a TPL closure are represented as a TPL structure that is passed as a second hidden parameter to the function that implements the logic of the closure. + +The closure itself is represented as a stack-allocated structure - a local variable within the frame of the function in which the lambda that produces the closure appears. This structure contains `N` fields. The first `N - 1` fields are the variables captured by the closure. The final field is a pointer to the compiled function that implements the closure's logic - a regular TPL function. + +TPL closures introduce some interesting implementation challenges that manifest during both code generation and during execution of the generated code. These implementation details are explored in further detail in the sections below. + +Closures can be passed like values throughout a TPL program; this allows one to, for instance, construct a closure and pass it to other functions that may then invoke it to perform computations or produce side-effects. However, there is a major limitation to our current closure implementation design: because the TPL structure that implements the closure is allocated in the stack frame of the function in which the lambda that produced the closure appears, the closure cannot escape the lexical scope of this function. In other words, we cannot return a closure from a TPL function and invoke it elsewhere because the structure that backs its implementation would be deallocated the moment the function that creates it returns. This is a major limitation, and most languages that support closures (read: every language implementation that I can find) include the ability for closures the escape the scope in which they are defined. Adding support for this functionality obviously requires some additional engineering and often significantly complicates the implementation. We get away with this implementation for now because our use-cases for closures never require us to generate code that allows the closure to escape the scope in which it is defined, but this is likely an issue we should address in the future. + +In a future refactor of this design, it may be beneficial (from a software-design perspective, at least) to implement closures as their own first-class type. Rather than implementing a closure as a TPL structure containing the closure's captures with the ad-hoc constraint that the final member is _always_ a pointer to the closure's associated function, we might consider adding a dedicated `ClosureType` type to the TPL DSL. At a first approximation, this would incur zero additional cost (compile-time or runtime) and would simplify some of the implementation because we could more easily distinguish between regular function invocations and closure invocations. + +Another limitation of the current implementation of closures is the inability to easily specify their type. Because TPL is statically typed, this makes implementing higher-order functions that accept closures as arguments or return closures more difficult. At present I am unsure of the best way to address this limitation. Languages like C++ and Rust get around this with either of 1) type erasure (e.g. C++'s `std::function` or Rust's `Fn` trait) or 2) generics. Both of these like relatively involved approaches for our purposes here in TPL. Perhaps we might consider some kind of implicit-conversion facility between function pointer types (which can be concisely specified) and closures (even those that capture). + +### Code Generation Details + +Closures and the lambda expressions that produce them introduce some additional complexity to code generation. The general flow of code generation for a function that contains a lambda expression proceeds as follows: + +- Visit the function declaration for the function in which the lambda expression appears +- Visit the body of the function; during visitation of the statement(s) in the function's body, the lambda expression is encountered +- Visit the lambda expression + - Allocate a new local in the frame of the current function for the closure structure + - Emit the bytecode to "capture" local variables; this is performed by loading the address of all captured locals into the fields in the closure (captures) structure + - Emit the bytecode to pass (a pointer to) the locally-allocated captures structure to the function that will implement the closure's logic + - Allocate the TPL function for the body of the closure + - Defer an action for the current function to visit the body of the closure; this deferred action captures (in C++-land, not in TPL!) the TPL function allocated for the closure +- Complete visitation of the function in which the lambda expression appears +- As the final step in visitation of the function, execute the deferred action to visit the body of the closure's function + +### Walkthough #0: Closure Without Captures + +As a first example, we consider the following TPL program: + +``` +fun main() -> int32 { + var addOne = lambda [] (x: int32) -> int32 { + return x + 1 + } + return addOne(1) +} +``` + +The bytecode generated for this program, with annotations, is shown below. + +``` +Data: + Data section size 0 bytes (0 locals) + +Function 0
: + Frame size 32 bytes (1 parameter, 5 locals) + param hiddenRv: offset=0 size=8 align=8 type=*int32 + + // The addOne local captures the closure that results from evaluating the lambda expression + local addOne: offset=8 size=8 align=8 type=lambda[(int32,*int32)->int32] + + // The captures structure for the closure is allocated in the frame + // of the function in which the lambda expression appears + local addOneCaptures: offset=16 size=8 align=8 type=struct{*(int32,*int32)->int32} + local tmp1: offset=24 size=4 align=4 type=int32 + local tmp2: offset=28 size=4 align=4 type=int32 + + // In the current implementation, the closure is synonymous with a + // pointer to the base of the captures structure; the bytecode in + // the body of the function generated for the closure assumes this + 0x00000000 Assign8 local=&addOne local=&addOneCaptures + 0x0000000c AssignImm4 local=&tmp2 i32=1 + + // Invoke the function generated for the closure; the captures structure is + // passed as an implicit final argument to the function call + 0x00000018 Call func= local=&tmp1 local=tmp2 local=addOne + 0x0000002c Assign4 local=hiddenRv local=tmp1 + 0x00000038 Return + +Function 1 : + Frame size 32 bytes (3 parameters, 5 locals) + param hiddenRv: offset=0 size=8 align=8 type=*int32 + param x: offset=8 size=4 align=4 type=int32 + param captures: offset=16 size=8 align=8 type=*int32 + local tmp1: offset=24 size=4 align=4 type=int32 + local tmp2: offset=28 size=4 align=4 type=int32 + + // The lambda that generated this function has no captures, therefore + // we don't need to do anything special here to handle captured variables + + 0x00000000 AssignImm4 local=&tmp2 i32=1 + // Perform the addition + 0x0000000c Add_int32_t local=&tmp1 local=x local=tmp2 + // Set the return value + 0x0000001c Assign4 local=hiddenRv local=tmp1 + 0x00000028 Return +``` + +### Walkthough #1: Closure With Captures + +As a second example, we consider the following TPL program: + +``` +fun main() -> int32 { + var x = 1 + var addValue = lambda [x] (y: int32) -> int32 { + return x + y + } + return addValue(2) +} +``` + +The bytecode generated for this program, with annotations, is shown below. + +``` +Data: + Data section size 0 bytes (0 locals) + +Function 0
: + Frame size 56 bytes (1 parameter, 7 locals) + param hiddenRv: offset=0 size=8 align=8 type=*int32 + local x: offset=8 size=4 align=4 type=int32 + local addValue: offset=16 size=8 align=8 type=lambda[(int32,*int32)->int32] + + // The first member of the captures structure is a pointer to the captured local; + // the second member of the captures structure is a pointer to the associated function + local addValueCaptures: offset=24 size=16 align=8 type=struct{*int32,*(int32,*int32)->int32} + local tmp1: offset=40 size=8 align=8 type=**int32 + local tmp2: offset=48 size=4 align=4 type=int32 + local tmp3: offset=52 size=4 align=4 type=int32 + + 0x00000000 AssignImm4 local=&x i32=1 + + // Capture the variable `x`; load the address of the local variable `x` into the captures structure + 0x0000000c Lea local=&tmp1 local=&addValueCaptures i32=0 + 0x0000001c Assign8 local=tmp1 local=&x + + // Initialize the closure itself as a pointer to the base of the captures structure + 0x00000028 Assign8 local=&addValue local=&addValueCaptures + + 0x00000034 AssignImm4 local=&tmp3 i32=2 + 0x00000040 Call func= local=&tmp2 local=tmp3 local=addValue + 0x00000054 Assign4 local=hiddenRv local=tmp2 + 0x00000060 Return + +Function 1 : + Frame size 52 bytes (3 parameters, 7 locals) + param hiddenRv: offset=0 size=8 align=8 type=*int32 + param y: offset=8 size=4 align=4 type=int32 + param captures: offset=16 size=8 align=8 type=*int32 + local tmp1: offset=24 size=4 align=4 type=int32 + local tmp2: offset=32 size=8 align=8 type=**int32 + local xptr: offset=40 size=8 align=8 type=*int32 + local tmp3: offset=48 size=4 align=4 type=int32 + + // Load the captured `x` pointer to the local `xptr` + 0x00000000 Lea local=&tmp2 local=captures i32=0 + 0x00000010 DerefN local=&xptr local=tmp2 u32=8 + + // Dereference the pointer to the captured `x` to get its value + 0x00000020 DerefN local=&tmp3 local=xptr u32=4 + + // Perform the addition and return the result + 0x00000030 Add_int32_t local=&tmp1 local=tmp3 local=y + 0x00000040 Assign4 local=hiddenRv local=tmp1 + 0x0000004c Return +``` \ No newline at end of file diff --git a/docs/design_codegen.md b/docs/design_codegen.md new file mode 100644 index 0000000000..9898adfca1 --- /dev/null +++ b/docs/design_codegen.md @@ -0,0 +1,131 @@ +# Design Doc: Execution Engine Code Generation + +### Overview + +As described in the _Execution Engine Design Document_, NoisePage utilizes [data-centric code generation](https://15721.courses.cs.cmu.edu/spring2020/papers/14-compilation/p539-neumann.pdf) to compile the query plans produced by the optimizer to a byetcode representation that is then either interpreted or JIT-compiled. This document describes some of the nuances of the code generation process. While it is a strict subset of the process descibed in the _Execution Engine Design Document_, code generation is a complex topic, and giving it its own document allows us to focus in on the details without getting lost in unrelated concerns from the layers of the execution engine above and below it. + +### Data-Centric Code Generation + +Our goal in code generation is to produce a bytecode program that implements a query plan. + +The straightforward and most common way of accomplishingz this is to have each operator in the query plan tree assume responsibility for generating the code that it requires to execute. The complete byetcode program might then be realized by having each operator generate code into a distinct bytecode function and then chaining these functions together via calls from the functions produced by parent operators to those produced by child operators. + +As mentioned above, this approach is straightforward to reason about and to implement. The code generated for each operator is nicely self-contained in a single bytecode function, allowing developers to verify the correctness of the generated code and debug code generation issues. However, the simplicity of this approach comes at the cost of query runtime performance. We now incur function-call overhead in the transition between each operator. More importantly, we leave ourselves open to the same performance issues present in any operator-centric execution model: poor code and data locality resulting from tuple-at-a-time processing among each operator. + +Data-centric code generation is a solution to these performance issues. In this paradigm, code is generated according to the data dependencies between individual operators, rather than along operator boundaries themselves. In practice, this has the effect of _fusing_ multiple operators together into larger units called _pipelines_. When multiple operators are fused into a pipeline, all of the operations required to implement the logic of each operator may be performed in sequence, without incurring function call overhead or even spilling tuple attributes to memory - it is often possible to keep tuple attributes in registers for the duration of a pipeline, dramatically improving CPU efficiency. + +### Pipelines + +Pipelines are the lowest-level unit of code generation in the NoisePage query compilation architecture. Individual operators are assigned to a pipeline (some operators may be part of more than one pipeline e.g. `JOIN` operators). The pipeline defines a set of top-level bytecode functions to generate, and invokes a set of pre-defined member functions on each of its operators to populate the body of each of these functions. The specifics of each of the functions defined by each pipeline are described below. + +**Complications** + +Since the original implementation of code generation, we have introduced several features that have required updates to the pipeline interface. Namely: +- Inductive Common Table Expressions (`WITH RECURSIVE` and `WITH ITERATIVE`) which introduce the concept of _nested pipelines_ +- User-Defined Functions which introduce the concept of an _output callback_ + +Both of these aditions, nested pipelines and output callbacks, slightly complicate the code generation process, and this additional complexity is reflected in the pipeline interface, to which we now turn our attention. + +### The Pipeline Interface + +The are several flavors of pipelines within NoisePage that differ slightly in the signature of their top-level bytecode functions, as well as their semantics. In this section, we explain each of these distinct flavors and provide the signatures of each of these top-level functions. + +#### Serial Pipelines + +We begin the discussion with serial pipelines because they are slightly less complicated than their parallel counterparts. + +**State Initialization** + +The interface for the pipeline state initialization functions is the same across all pipeline variants. + +To initialize the pipeline state, we generate: + +``` +fun Query0_Pipeline1_InitPipelineState(*QueryState, *PipelineState) +``` + +and the teardown the pipeline state, we generate: + +``` +fun Query0_Pipeline1_TeardownPipelineState(*QueryState, *PipelineState) +``` + +**Pipline Initialization** + +The interface for pipeline initialization varies depending on the pipeline variant. + +In the common case, we generate: + +``` +fun Query0_Pipeline1_Init(*QueryState) +``` + +In the case of a nested pipeline, we generate: + +``` +fun Query0_Pipeline1_Init(*QueryState, *PipelineState) +``` + +In the case of a pipeline with an output callback, we generate: + +``` +fun Query0_Pipeline1_Init(*QueryState, *PipelineState) +``` + +A pointer to the pipeline state (`*PipelineState`) is provided to the call for nested pipelines and pipelines with output callbacks because in both of these cases the pipeline state associated with the thread running the pipeline is not owned by the pipeline in question. Instead, this pipeline state structure is allocated on the stack at runtime and passed through the bytecode function invocations. + +**Pipeline Run** + +The interface for the pipeline _Run_ function varies depending on the pipeline variant. + +In the common case, we generate: + +``` +fun Query0_Pipeline1_Run(*QueryState) +``` + +In the case of nested pipelines, we generate: + +``` +fun Query0_Pipeline1_Run(*QueryState, *PipelineState) +``` + +In the case of pipelines with an output callback, we generate: + +``` +fun Query0_Pipeline1_Run(*QueryState, *PipelineState, Closure) +``` + +The distinction between pipelines with output callbacks and nested pipelines manifests here. The output callback (in the form of a TPL closure) is provided as a third parameter to the _Run_ function such that it can be invoked by the operators that utilize it (for now, just the `OutputTranslator`) in the body of the pipeline _Work_ function. + +**Pipeline Teardown** + +The interface for pipeline teardown varies depending on the pipeline variant. + +In the common case, we generate: + +``` +fun Query0_Pipeline1_Teardown(*QueryState) +``` + +In the case of a nested pipeline, we generate: + +``` +fun Query0_Pipeline1_Teardown(*QueryState, *PipelineState) +``` + +In the case of a pipeline with an output callback, we generate: + +``` +fun Query0_Pipeline1_Teardown(*QueryState, *PipelineState) +``` + +The reason that a pointer to the pipeline state is provided to the call in the latter two cases is the same as in the case of pipeline initialization. + +#### Parallel Pipelines + +Parallel pipelines require different semantics from serial pipelines. Despite these differences, only the _Work_ function is affected by the change from a serial to a parallel pipeline. + +### References + +- [Efficiently Compiling Efficient Query Plans for Modern Hardware](https://15721.courses.cs.cmu.edu/spring2020/papers/14-compilation/p539-neumann.pdf) by Thomas Neumann. The paper that introduced the concept of data-centric code generation, among other techniques now considered standard best-practice in compiling query engines. diff --git a/docs/discussion_cte_implementation.md b/docs/design_ctes.md similarity index 98% rename from docs/discussion_cte_implementation.md rename to docs/design_ctes.md index a733bf4988..c1dc33f151 100644 --- a/docs/discussion_cte_implementation.md +++ b/docs/design_ctes.md @@ -1,8 +1,14 @@ -# Discussion Doc: Common Table Expression Implementation +# Design Doc: Common Table Expressions + +### Overview This document provides an overview of some of the important features of our implementation of common table expressions (CTEs). -## Known Limitations +### Design + +TODO(Kyle): Fill this in. + +### Limitations Our current implementation of CTEs suffers from some known limitations which limits the queries we are able to execute. This section provides a comprehensive overview of the queries on which the system currently fails, a best-estimate of the underlying reason for the failure, and what might be required to address it. diff --git a/docs/design_udfs.md b/docs/design_udfs.md new file mode 100644 index 0000000000..184f180373 --- /dev/null +++ b/docs/design_udfs.md @@ -0,0 +1,41 @@ +# Design Doc: User-Defined Functions + +### Overview + +This document describes important aspects of the design and implementation of user-defined functions in NoisePage. + +### Limitations + +This section describes known limitations of our implementation of UDFs. + +**Function Argument Modes** + +Currently, only the implicit `IN` argument mode is supported. `OUT` and `INOUT` argument modes have no effect on the semantics of the function. + +**Parallel Operations** + +There is some data race that occurs when an output callback is used in the context of a parallel pipeline that results in garbage results. + +**Missing `RETURN`** + +In Postgres, a PL/pgSQL function that declares a return type but is missing a `RETURN` statement in the body of the function parses successfully, but results in a runtime error when the function is executed. Currently, we fail to parse such functions (which may be directly related to the issue below). + +**Implicit `RETURN`s** + +Currently, the following control flow is not supported: + +```sql +CREATE FUNCTION fun(x INT) RETURNS INT AS $$ +BEGIN + IF x > 10 THEN + RETURN 0; + ELSE + RETURN 1; + END IF; +END +$$ LANGUAGE PLPGSQL; +``` + +This fails in the parser because the library we use to parse the raw UDF (libpg_query) inserts an implicit empty `RETURN` at the end of the body of the function. This implicit `RETURN` has no associated expression, and therefore it fails when we attempt to parse it. + +Obviously, we can see that this implicit `RETURN` is unreachable code, so we know this UDF body is valid. diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..6dbc5c5462 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,29 @@ +certifi==2021.5.30 +charset-normalizer==2.0.4 +coverage==5.5 +distro==1.6.0 +idna==3.2 +importlib-metadata==4.8.1 +joblib==1.0.1 +lightgbm==3.2.1 +numpy==1.21.2 +pandas==1.1.5 +prettytable==2.2.0 +psutil==5.8.0 +psycopg2==2.9.1 +pyarrow==5.0.0 +python-dateutil==2.8.2 +pytz==2021.1 +pyzmq==22.2.1 +requests==2.26.0 +scikit-learn==0.24.2 +scipy==1.7.1 +six==1.16.0 +sklearn==0.0 +threadpoolctl==2.2.0 +torch==1.9.0 +tqdm==4.62.2 +typing-extensions==3.10.0.2 +urllib3==1.26.6 +wcwidth==0.2.5 +zipp==3.5.0 diff --git a/sample_tpl/closure0.tpl b/sample_tpl/closure0.tpl new file mode 100644 index 0000000000..19d97761a2 --- /dev/null +++ b/sample_tpl/closure0.tpl @@ -0,0 +1,9 @@ +// Expected output: 2 + +fun main() -> int32 { + // Closure without capture + var addOne = lambda [] (x: int32) -> int32 { + return x + 1 + } + return addOne(1) +} \ No newline at end of file diff --git a/sample_tpl/closure1.tpl b/sample_tpl/closure1.tpl new file mode 100644 index 0000000000..ea87cabcdf --- /dev/null +++ b/sample_tpl/closure1.tpl @@ -0,0 +1,11 @@ +// Expected output: 3 + +fun main() -> int32 { + var x = 1 + // Closure that uses capture in computation; + // the closure does not write captured variable + var addValue = lambda [x] (y: int32) -> int32 { + return x + y + } + return addValue(2) +} \ No newline at end of file diff --git a/sample_tpl/closure2.tpl b/sample_tpl/closure2.tpl new file mode 100644 index 0000000000..20d04de8cb --- /dev/null +++ b/sample_tpl/closure2.tpl @@ -0,0 +1,11 @@ +// Expected output: 6 + +fun main() -> int32 { + var x = 1 + var y = 2 + // Closure that uses multiple captures in computation + var addValues = lambda [x, y] (z: int32) -> int32 { + return x + y + z + } + return addValues(3) +} \ No newline at end of file diff --git a/sample_tpl/closure3.tpl b/sample_tpl/closure3.tpl new file mode 100644 index 0000000000..086d9f782e --- /dev/null +++ b/sample_tpl/closure3.tpl @@ -0,0 +1,11 @@ +// Expected output: 2 + +fun main() -> int32 { + var x = 1 + // Closure that writes to the captured variable + var addOne = lambda [x] () -> nil { + x = x + 1 + } + addOne() + return x +} \ No newline at end of file diff --git a/sample_tpl/closure4.tpl b/sample_tpl/closure4.tpl new file mode 100644 index 0000000000..5575d78a6b --- /dev/null +++ b/sample_tpl/closure4.tpl @@ -0,0 +1,12 @@ +// Expected output: 8 + +fun main() -> int32 { + // Lambda expressions may contain other lambda expressions + var timesFour = lambda [] (x: int32) -> int32 { + var timesTwo = lambda [] (y: int32) -> int32 { + return y*2 + } + return timesTwo(x) + timesTwo(x) + } + return timesFour(2) +} diff --git a/sample_tpl/struct-lambda.tpl b/sample_tpl/struct-lambda.tpl new file mode 100644 index 0000000000..d8195af028 --- /dev/null +++ b/sample_tpl/struct-lambda.tpl @@ -0,0 +1,21 @@ +// Expected output: 10 + +struct S { + a: int + b: int + c: (int32) -> int32 +} +struct SDup { + d: int + e: int + f: int +} + +fun sss(x : int32) -> int32 { + return x +} + +fun main() -> int { + var p: S + p.c = sss +} diff --git a/sample_tpl/tpl_tests.txt b/sample_tpl/tpl_tests.txt index eb12cc6412..daf92456bb 100644 --- a/sample_tpl/tpl_tests.txt +++ b/sample_tpl/tpl_tests.txt @@ -10,6 +10,11 @@ array.tpl,false,44 array-iterate.tpl,false,110 array-iterate-2.tpl,false,110 call.tpl,false,70 +closure0.tpl,false,2 +closure1.tpl,false,3 +closure2.tpl,false,6 +closure3.tpl,false,2 +closure4.tpl,false,8 comments.tpl,false,46 compare.tpl,false,200 date-functions.tpl,false,0 diff --git a/script/installation/packages.sh b/script/installation/packages.sh index c3f7c6c0e3..58e5c3dbbb 100755 --- a/script/installation/packages.sh +++ b/script/installation/packages.sh @@ -30,12 +30,12 @@ LINUX_BUILD_PACKAGES=(\ "llvm-8" \ "pkg-config" \ "postgresql-client" \ - "python3-pip" \ "ninja-build" "wget" \ "zlib1g-dev" \ "time" \ ) + LINUX_TEST_PACKAGES=(\ "ant" \ "ccache" \ @@ -44,27 +44,6 @@ LINUX_TEST_PACKAGES=(\ "lsof" \ ) -# Packages to be installed through pip3. -PYTHON_BUILD_PACKAGES=( -) -PYTHON_TEST_PACKAGES=(\ - "distro" \ - "lightgbm" \ - "numpy" \ - "pandas" \ - "prettytable" \ - "psutil" \ - "psycopg2" \ - "pyarrow" \ - "pyzmq" \ - "requests" \ - "sklearn" \ - "torch" \ - "tqdm" \ - "coverage" \ -) - - ## ================================================================= @@ -143,12 +122,6 @@ install() { esac } -install_pip() { - curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py - python get-pip.py - rm get-pip.py -} - install_linux() { # Update apt-get. apt-get -y update @@ -160,17 +133,6 @@ install_linux() { if [ "$INSTALL_TYPE" == "test" ] || [ "$INSTALL_TYPE" = "all" ]; then apt-get -y install $( IFS=$' '; echo "${LINUX_TEST_PACKAGES[*]}" ) fi - - if [ "$INSTALL_TYPE" == "build" ] || [ "$INSTALL_TYPE" = "all" ]; then - for pkg in "${PYTHON_BUILD_PACKAGES[@]}"; do - python3 -m pip show $pkg || python3 -m pip install $pkg - done - fi - if [ "$INSTALL_TYPE" == "test" ] || [ "$INSTALL_TYPE" = "all" ]; then - for pkg in "${PYTHON_TEST_PACKAGES[@]}"; do - python3 -m pip show $pkg || python3 -m pip install $pkg - done - fi } main "$@" diff --git a/script/testing/junit/README.md b/script/testing/junit/README.md index 09951522fc..49ecf077e0 100644 --- a/script/testing/junit/README.md +++ b/script/testing/junit/README.md @@ -96,7 +96,7 @@ The procedure for running the `GenerateTrace.java` program is as follows: 1. Establish a local Postgres database and start the database server. The procedure to accomplish this depends on the particulars of your development environment. If you are using a CMU DB development machine, see the _PostgreSQL on CMU DB Development Machines_ section below. 2. Write your own SQL input file. The format of this file consists of SQL statements, one per line. Comments (denoted by `#`)) are permitted. 3. Compile the test infrastructure: `ant compile` -4. Run the filter trace program: `ant filter-trace`. The program expects 6 arguments: +4. Run the filter trace program: `ant generate-trace`. The program expects 6 arguments: - `path`: The path to the input file - `db-url`: The JDBC URL for the DBMS server - `db-user`: The database username diff --git a/script/testing/junit/sql/udf.sql b/script/testing/junit/sql/udf.sql new file mode 100644 index 0000000000..ef5b5d8a7b --- /dev/null +++ b/script/testing/junit/sql/udf.sql @@ -0,0 +1,725 @@ +-- udf.sql +-- Integration tests for user-defined functions. +-- +-- Currently, these tests rely on the fact that we +-- utilize Postgres as a reference implementation +-- because all user-defined functions are implemented +-- in the Postgres PL/SQL dialect, PL/pgSQL. + +-- Create test tables +CREATE TABLE integers(x INT, y INT); +INSERT INTO integers(x, y) VALUES (1, 1), (2, 2), (3, 3); + +CREATE TABLE strings(s TEXT); +INSERT INTO strings(s) VALUES ('aaa'), ('bbb'), ('ccc'); + +-- ---------------------------------------------------------------------------- +-- return_constant() + +CREATE FUNCTION return_constant() RETURNS INT AS $$ \ +BEGIN \ + RETURN 1; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT return_constant(); + +DROP FUNCTION return_constant(); + +CREATE FUNCTION return_constant_str() RETURNS TEXT AS $$ \ +BEGIN \ + RETURN 'hello, functions'; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT return_constant_str(); + +DROP FUNCTION return_constant_str(); + +-- ---------------------------------------------------------------------------- +-- return_input() + +CREATE FUNCTION return_input(x INT) RETURNS INT AS $$ \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT x, return_input(x) FROM integers; + +DROP FUNCTION return_input(INT); + +CREATE FUNCTION return_input(x TEXT) RETURNS TEXT AS $$ \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT s, return_input(s) FROM strings; + +DROP FUNCTION return_input(TEXT); + +-- ---------------------------------------------------------------------------- +-- return_sum() + +CREATE FUNCTION return_sum(x INT, y INT) RETURNS INT AS $$ \ +BEGIN \ + RETURN x + y; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT x, y, return_sum(x, y) FROM integers; + +DROP FUNCTION return_sum(INT, INT); + +-- ---------------------------------------------------------------------------- +-- return_prod() + +CREATE FUNCTION return_product(x INT, y INT) RETURNS INT AS $$ \ +BEGIN \ + RETURN x * y; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT x, y, return_product(x, y) FROM integers; + +DROP FUNCTION return_product(INT, INT); + +-- ---------------------------------------------------------------------------- +-- integer_decl() + +CREATE FUNCTION integer_decl() RETURNS INT AS $$ \ +DECLARE \ + x INT := 0; \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT integer_decl(); + +DROP FUNCTION integer_decl(); + +-- ---------------------------------------------------------------------------- +-- conditional() +-- +-- TODO(Kyle): The final RETURN 0 is unreachable, but we +-- need this temporary hack to deal with missing logic in parser + +CREATE FUNCTION conditional(x INT) RETURNS INT AS $$ \ +BEGIN \ + IF x > 1 THEN \ + RETURN 1; \ + ELSE \ + RETURN 2; \ + END IF; \ + RETURN 0; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT x, conditional(x) FROM integers; + +DROP FUNCTION conditional(INT); + +-- Nested conditional control flow +CREATE FUNCTION conditional(x INT, y INT) RETURNS INT AS $$ \ +BEGIN \ + IF x > 1 THEN \ + IF y > 1 THEN \ + RETURN 1; \ + ELSE \ + RETURN 2; \ + END IF; \ + ELSE \ + IF y > 1 THEN \ + RETURN 3; \ + ELSE \ + RETURN 4; \ + END IF; \ + END IF; \ + RETURN 0; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT conditional(1, 1); +SELECT conditional(1, 2); +SELECT conditional(2, 1); +SELECT conditional(2, 2); + +DROP FUNCTION conditional(INT, INT); + +-- ---------------------------------------------------------------------------- +-- proc_while() + +CREATE FUNCTION proc_while() RETURNS INT AS $$ \ +DECLARE \ + x INT := 0; \ +BEGIN \ + WHILE x < 10 LOOP \ + x = x + 1; \ + END LOOP; \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_while(); + +DROP FUNCTION proc_while(); + +-- ---------------------------------------------------------------------------- +-- proc_fori() +-- +-- TODO(Kyle): for-loop control flow (integer variant) is not supported + +-- CREATE FUNCTION proc_fori() RETURNS INT AS $$ \ +-- DECLARE \ +-- x INT := 0; \ +-- BEGIN \ +-- FOR i IN 1..10 LOOP \ +-- x = x + 1; \ +-- END LOOP; \ +-- RETURN x; \ +-- END \ +-- $$ LANGUAGE PLPGSQL; + +-- SELECT x, proc_fori() FROM integers; + +-- ---------------------------------------------------------------------------- +-- sql_select_single_constant() + +CREATE FUNCTION sql_select_single_constant() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ +BEGIN \ + SELECT 1 INTO v; \ + RETURN v; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT sql_select_single_constant(); + +DROP FUNCTION sql_select_single_constant(); + +-- ---------------------------------------------------------------------------- +-- sql_select_mutliple_constants() + +CREATE FUNCTION sql_select_multiple_constants() RETURNS INT AS $$ \ +DECLARE \ + x INT; \ + y INT; \ +BEGIN \ + SELECT 1, 2 INTO x, y; \ + RETURN x + y; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT sql_select_multiple_constants(); + +DROP FUNCTION sql_select_multiple_constants(); + +-- ---------------------------------------------------------------------------- +-- sql_select_constant_assignment() + +CREATE FUNCTION sql_select_constant_assignment() RETURNS INT AS $$ \ +DECLARE \ + x INT; \ + y INT; \ +BEGIN \ + x = (SELECT 1); \ + y = (SELECT 2); \ + RETURN x + y; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT sql_select_constant_assignment(); + +DROP FUNCTION sql_select_constant_assignment(); + +-- ---------------------------------------------------------------------------- +-- sql_embedded_agg_count() + +CREATE FUNCTION sql_embedded_agg_count() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ +BEGIN \ + SELECT COUNT(*) FROM integers INTO v; \ + RETURN v; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT sql_embedded_agg_count(); + +DROP FUNCTION sql_embedded_agg_count(); + +-- ---------------------------------------------------------------------------- +-- sql_embedded_agg_min() + +CREATE FUNCTION sql_embedded_agg_min() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ +BEGIN \ + SELECT MIN(x) FROM integers INTO v; \ + RETURN v; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT sql_embedded_agg_min(); + +DROP FUNCTION sql_embedded_agg_min(); + +-- ---------------------------------------------------------------------------- +-- sql_embedded_agg_max() + +CREATE FUNCTION sql_embedded_agg_max() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ +BEGIN \ + SELECT MAX(x) FROM integers INTO v; \ + RETURN v; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT sql_embedded_agg_max(); + +DROP FUNCTION sql_embedded_agg_max(); + +-- ---------------------------------------------------------------------------- +-- sql_embedded_agg_multi() + +CREATE FUNCTION sql_embedded_agg_multi() RETURNS INT AS $$ \ +DECLARE \ + minimum INT; \ + maximum INT; \ +BEGIN \ + minimum = (SELECT MIN(x) FROM integers); \ + maximum = (SELECT MAX(x) FROM integers); \ + RETURN minimum + maximum; \ +END; \ +$$ LANGUAGE PLPGSQL; + +DROP FUNCTION sql_embedded_agg_multi(); + +-- ---------------------------------------------------------------------------- +-- proc_fors_constant_var() + +-- Select constant into a scalar variable +CREATE FUNCTION proc_fors_constant_var() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ + x INT := 0; \ +BEGIN \ + FOR v IN SELECT 1 LOOP \ + x = x + 1; \ + END LOOP; \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_fors_constant_var(); + +DROP FUNCTION proc_fors_constant_var(); + +-- ---------------------------------------------------------------------------- +-- proc_fors_constant_vars() + +-- Select multiple constants in scalar variables +CREATE FUNCTION proc_fors_constant_vars() RETURNS INT AS $$ \ +DECLARE \ + x INT; \ + y INT; \ + z INT := 0; \ +BEGIN \ + FOR x, y IN SELECT 1, 2 LOOP \ + z = z + 1; \ + END LOOP; \ + RETURN z; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_fors_constant_vars(); + +DROP FUNCTION proc_fors_constant_vars(); + +-- ---------------------------------------------------------------------------- +-- proc_fors_rec() +-- +-- TODO(Kyle): RECORD types not supported + +-- -- Bind query result to a RECORD type +-- CREATE FUNCTION proc_fors_rec() RETURNS INT AS $$ \ +-- DECLARE \ +-- x INT := 0; \ +-- v RECORD; \ +-- BEGIN \ +-- FOR v IN (SELECT z FROM temp) LOOP \ +-- x = x + 1; \ +-- END LOOP; \ +-- RETURN x; \ +-- END \ +-- $$ LANGUAGE PLPGSQL; + +-- SELECT proc_fors_rec() FROM integers; + +-- ---------------------------------------------------------------------------- +-- proc_fors_var() + +-- Bind query result directly to INT type +CREATE FUNCTION proc_fors_var() RETURNS INT AS $$ \ +DECLARE \ + c INT := 0; \ + v INT; \ +BEGIN \ + FOR v IN (SELECT x FROM integers) LOOP \ + c = c + 1; \ + END LOOP; \ + RETURN c; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_fors_var(); + +DROP FUNCTION proc_fors_var(); + +-- ---------------------------------------------------------------------------- +-- proc_call_*() + +CREATE FUNCTION proc_call_callee() RETURNS INT AS $$ \ +BEGIN \ + RETURN 1; \ +END \ +$$ LANGUAGE PLPGSQL; + +-- Just RETURN the result of call +CREATE FUNCTION proc_call_ret() RETURNS INT AS $$ \ +BEGIN \ + RETURN proc_call_callee(); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_call_ret(); + +-- Assign the result of call to variable +CREATE FUNCTION proc_call_assign() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ +BEGIN \ + v = proc_call_callee(); \ + RETURN v; \ +END \ +$$ LANGUAGE PLPGSQl; + +SELECT proc_call_assign(); + +-- SELECT the result of call into variable +CREATE FUNCTION proc_call_select() RETURNS INT AS $$ \ +DECLARE \ + v INT; \ +BEGIN \ + SELECT proc_call_callee() INTO v; \ + RETURN v; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_call_select(); + +DROP FUNCTION proc_call_callee(); +DROP FUNCTION proc_call_ret(); +DROP FUNCTION proc_call_assign(); +DROP FUNCTION proc_call_select(); + +-- ---------------------------------------------------------------------------- +-- proc_predicate() + +CREATE FUNCTION proc_predicate(threshold INT) RETURNS INT AS $$ \ +DECLARE \ + c INT; \ +BEGIN \ + SELECT COUNT(x) FROM integers WHERE x > threshold INTO c; \ + RETURN c; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_predicate(0); +SELECT proc_predicate(1); +SELECT proc_predicate(2); + +DROP FUNCTION proc_predicate(INT); + +-- ---------------------------------------------------------------------------- +-- proc_call_args() + +-- Argument to call can be an expression +CREATE FUNCTION proc_call_args() RETURNS INT AS $$ \ +DECLARE \ + x INT := 1; \ + y INT := 2; \ + z INT := 3; \ +BEGIN \ + RETURN ABS(x * y + z); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_call_args(); + +DROP FUNCTION proc_call_args(); + +-- Argument to call can be an identifier +CREATE FUNCTION proc_call_args() RETURNS INT AS $$ \ +DECLARE \ + x INT := 1; \ + y INT := 2; \ + z INT := 3; \ + r INT; \ +BEGIN \ + r = x * y + z; \ + RETURN ABS(r); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_call_args(); + +DROP FUNCTION proc_call_args(); + +-- ---------------------------------------------------------------------------- +-- proc_promotion() + +-- Able to (silently) promote REAL to DOUBLE PRECISION +CREATE FUNCTION proc_promotion() RETURNS REAL AS $$ \ +DECLARE \ + x INT := 1; \ + y REAL := 1.0; \ + t REAL; \ +BEGIN \ + t = x * y; \ + RETURN FLOOR(t); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(); +DROP FUNCTION proc_promotion(); + +-- Able to (silently) promote FLOAT to DOUBLE PRECISION +CREATE FUNCTION proc_promotion() RETURNS FLOAT AS $$ \ +DECLARE \ + x INT := 1; \ + y FLOAT := 1.0; \ + t FLOAT; \ +BEGIN \ + t = x * y; \ + RETURN FLOOR(t); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(); +DROP FUNCTION proc_promotion(); + +-- Promotion does not affect correct operation of DOUBLE PRECISION +CREATE FUNCTION proc_promotion() RETURNS DOUBLE PRECISION AS $$ \ +DECLARE \ + x INT := 1; \ + y DOUBLE PRECISION := 1.0; \ + t DOUBLE PRECISION; \ +BEGIN \ + t = x * y; \ + RETURN FLOOR(t); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(); +DROP FUNCTION proc_promotion(); + +-- Promotion does not affect correct operation of FLOAT8 +CREATE FUNCTION proc_promotion() RETURNS DOUBLE PRECISION AS $$ \ +DECLARE \ + x INT := 1; \ + y DOUBLE PRECISION := 1.0; \ + t DOUBLE PRECISION; \ +BEGIN \ + t = x * y; \ + RETURN FLOOR(t); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(); +DROP FUNCTION proc_promotion(); + +-- Promotion works as expected with UDF arguments +CREATE FUNCTION proc_promotion(x FLOAT) RETURNS FLOAT AS $$ \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(1337.0); +DROP FUNCTION proc_promotion(FLOAT); + +-- Promotion works as expected with UDF arguments +CREATE FUNCTION proc_promotion(x REAL) RETURNS REAL AS $$ \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(1337.0); +DROP FUNCTION proc_promotion(REAL); + +-- Promotion works as expected with UDF arguments +CREATE FUNCTION proc_promotion(x DOUBLE PRECISION) RETURNS DOUBLE PRECISION AS $$ \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(1337.0); +DROP FUNCTION proc_promotion(DOUBLE PRECISION); + +-- Promotion works as expected with UDF arguments +CREATE FUNCTION proc_promotion(x FLOAT8) RETURNS FLOAT8 AS $$ \ +BEGIN \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_promotion(1337.0); +DROP FUNCTION proc_promotion(FLOAT8); + +-- ---------------------------------------------------------------------------- +-- proc_cast() + +-- CAST works in assignment expression +CREATE FUNCTION proc_cast() RETURNS FLOAT AS $$ \ +DECLARE \ + x FLOAT; \ +BEGIN \ + x = CAST(1 AS FLOAT); \ + RETURN x; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_cast(); +DROP FUNCTION proc_cast(); + +-- TODO(Kyle): this is a great example of a function that +-- we can't currently compile because we only resort to a +-- full handoff to the SQL execution infrastructure in the +-- case of assignment expressions. For everything else, in +-- this case a RETURN statement, we don't yet have the +-- ability to defer this to the SQL engine, and we also can't +-- handle it in the "builtin" manner, so we just fail. + +-- CREATE FUNCTION proc_cast() RETURNS FLOAT AS $$ \ +-- BEGIN \ +-- RETURN CAST(1 AS FLOAT); \ +-- END \ +-- $$ LANGUAGE PLPGSQL; + +-- ---------------------------------------------------------------------------- +-- proc_is_null() + +CREATE FUNCTION proc_is_null(x INT) RETURNS INT AS $$ \ +DECLARE \ + r INT; \ +BEGIN \ + IF x IS NULL THEN \ + r = 1; \ + ELSE \ + r = 2; \ + END IF; \ + RETURN r; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_is_null(1); +SELECT proc_is_null(NULL); + +DROP FUNCTION proc_is_null(INT); + +CREATE FUNCTION proc_is_not_null(x INT) RETURNS INT AS $$ \ +DECLARE \ + r INT; \ +BEGIN \ + IF x IS NOT NULL THEN \ + r = 1; \ + ELSE \ + r = 2; \ + END IF; \ + RETURN r; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_is_not_null(1); +SELECT proc_is_not_null(NULL); + +DROP FUNCTION proc_is_not_null(INT); + +-- ---------------------------------------------------------------------------- +-- proc_length() + +-- Assignment of LENGTH to temporary +CREATE FUNCTION proc_length(t VARCHAR) RETURNS INT AS $$ \ +DECLARE \ + r INT; \ +BEGIN \ + r = LENGTH(t); \ + RETURN r; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_length('hello'); +DROP FUNCTION proc_length(VARCHAR); + +-- Direct RETURN of LENGTH +CREATE FUNCTION proc_length(t VARCHAR) RETURNS INT AS $$ \ +BEGIN \ + RETURN LENGTH(t); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_length('hello'); + +DROP FUNCTION proc_length(VARCHAR); + +-- Use of LENGTH() in conditional +CREATE FUNCTION proc_length(t VARCHAR) RETURNS INT AS $$ \ +BEGIN \ + IF LENGTH(t) > 1 THEN \ + RETURN 1; \ + ELSE \ + RETURN 2; \ + END IF; \ + RETURN 0; \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_length('a'); +SELECT proc_length('ab'); +SELECT proc_length('abc'); + +DROP FUNCTION proc_length(VARCHAR); + +-- ---------------------------------------------------------------------------- +-- proc_substr() + +-- Able to pass all arguments through +CREATE FUNCTION proc_substr(t VARCHAR, i INT, l INT) RETURNS VARCHAR AS $$ \ +BEGIN \ + RETURN SUBSTR(t, i, l); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_substr('hello', 1, 1); +SELECT proc_substr('hello', 1, 2); + +DROP FUNCTION proc_substr(VARCHAR, INT, INT); + +-- Able to specify a literal value +CREATE FUNCTION proc_substr(t VARCHAR, i INT) RETURNS VARCHAR AS $$ \ +BEGIN \ + RETURN SUBSTR(t, i, 1); \ +END \ +$$ LANGUAGE PLPGSQL; + +SELECT proc_substr('hello', 1); +SELECT proc_substr('hello', 2); + +DROP FUNCTION proc_substr(VARCHAR, INT); diff --git a/script/testing/junit/src/GenerateTrace.java b/script/testing/junit/src/GenerateTrace.java index 7894398a2d..436ead1656 100644 --- a/script/testing/junit/src/GenerateTrace.java +++ b/script/testing/junit/src/GenerateTrace.java @@ -1,67 +1,190 @@ -import java.io.*; -import java.sql.*; -import java.util.ArrayList; -import java.util.Arrays; +/** + * GenerateTrace.java + */ + +import java.io.File; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.io.BufferedReader; + +import java.sql.ResultSet; +import java.sql.Statement; +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.ResultSetMetaData; + import java.util.List; -import moglib.*; +import moglib.MogDb; +import moglib.MogSqlite; +import moglib.Constants; + +/** + * A generic logger interface. + * (Apparently `Logger` is already taken) + */ +interface ILogger { + /** + * Log an informational message. + * @param message The message + */ + public void info(final String message); + + /** + * Log an error message. + * @param message The message + */ + public void error(final String message); +} /** - * class that convert sql statements to trace format - * first, establish a local postgresql database - * second, start the database server with "pg_ctl -D /usr/local/var/postgres start" - * third, modify the url, user and password string to match the database you set up - * finally, provide path to a file, run generateTrace with the file path as argument - * input file format: sql statements, one per line - * output file: to be tested by TracefileTest + * A dummy logger class that just writes to standard output. + * + * We might want to replace this eventually with an actual + * logger implementation, and this dummy class might(?) make + * that transition slightly less painful. For now, it also + * provides the slight benefit of making logging less verbose. + */ +class StandardLogger implements ILogger { + /** + * Construct a logger instance. + */ + StandardLogger() {} + + /** + * Log an informational message. + * @param message + */ + public void info(final String message) { + System.out.println(message); + } + + /** + * Log an error message. + * @param message + */ + public void error(final String message) { + System.err.println(message); + } +} + +/** + * The GenerateTrace class converts SQL statements to the tracefile + * format used for integration testing. For instructions on how to + * use this program to generate a tracefile, see junit/README. */ public class GenerateTrace { + /** + * Error code for process exit on program success. + */ + private static final int EXIT_SUCCESS = 0; + + /** + * Error code for process exit on program failure. + */ + private static final int EXIT_ERROR = 1; + + /** + * The expected number of commandline arguments. + */ + private static final int EXPECTED_ARGUMENT_COUNT = 5; + + /** + * The character used to delimit multiline statements (e.g. UDF definition). + */ + private static final String MULTILINE_DELIMITER = "\\"; + + /** + * The current working directory. + */ + private static final String WORKING_DRIECTORY = System.getProperty("user.dir"); + + /** + * The logger instance. + */ + private static final ILogger LOGGER = new StandardLogger(); + + /** + * Program entry point. + * @param args Commandline arguments + * @throws Throwable + */ public static void main(String[] args) throws Throwable { - System.out.println("Working Directory = " + System.getProperty("user.dir")); - String path = args[0]; - File file = new File(path); - System.out.println("File path: " + path); - MogSqlite mog = new MogSqlite(file); - // open connection to postgresql database with jdbc - MogDb db = new MogDb(args[1], args[2], args[3]); - Connection conn = db.getDbTest().newConn(); - // remove existing table name - List tab = getAllExistingTableName(mog,conn); - removeExistingTable(tab,conn); + if (args.length < EXPECTED_ARGUMENT_COUNT) { + LOGGER.error("Error: invalid arguments"); + LOGGER.error("Usage: see junit/README.md"); + System.exit(EXIT_ERROR); + } + + LOGGER.info("Working Directory = " + WORKING_DRIECTORY); + + // Parse commandline arguments + final String inputPath = args[0]; + final String jdbcUrl = args[1]; + final String dbUsername = args[2]; + final String dbPassword = args[3]; + final String outputPath = args[4]; + + MogSqlite mog = new MogSqlite(new File(inputPath)); + + // Open connection to Postgre database over JDBC + MogDb db = new MogDb(jdbcUrl, dbUsername, dbPassword); + Connection connection = db.getDbTest().newConn(); + + // Initialize the database + removeAllTables(mog, connection); + removeAllFunctions(mog, connection); + + BufferedReader reader = new BufferedReader(new FileReader(new File(inputPath))); + FileWriter writer = new FileWriter(new File(Constants.DEST_DIR, outputPath)); + System.exit(run(db, mog, connection, reader, writer)); + } + + /** + * Run trace generation. + * @param db The `MogDb` instance + * @param mog The `MogSqlite` instance + * @param connection The database connection + * @param reader The buffered reader for the input file + * @param writer The file writer for the output file + * @return The status code + */ + private static int run(MogDb db, MogSqlite mog, Connection connection, + BufferedReader reader, FileWriter writer) throws SQLException, IOException { String line; String label; Statement statement = null; - BufferedReader br = new BufferedReader(new FileReader(file)); - // create output file - FileWriter writer = new FileWriter(new File(Constants.DEST_DIR, args[4])); + int expected_result_num = -1; boolean include_result = false; - while (null != (line = br.readLine())) { + while (null != (line = readLine(reader, MULTILINE_DELIMITER))) { line = line.trim(); - // execute sql statement - try{ - statement = conn.createStatement(); + + // Execute SQL statement + try { + statement = connection.createStatement(); statement.execute(line); label = Constants.STATEMENT_OK; } catch (SQLException e) { - System.err.println("Error executing SQL Statement: '" + line + "'; " + e.getMessage()); + LOGGER.error("Error executing SQL Statement: '" + line + "'; " + e.getMessage()); label = Constants.STATEMENT_ERROR; } catch (Throwable e) { label = Constants.STATEMENT_ERROR; } - if(line.startsWith("SELECT") || line.toLowerCase().startsWith("with")) { + if (line.startsWith("SELECT") || line.startsWith("WITH")) { ResultSet rs = statement.getResultSet(); - if (line.toLowerCase().startsWith("with") && null == rs) { + if (line.startsWith("WITH") && null == rs) { // We might have a query that begins with `WITH` that has a null result set int updateCount = statement.getUpdateCount(); // check if expected number is equal to update count - if(expected_result_num>=0 && expected_result_num!=updateCount){ + if (expected_result_num >= 0 && expected_result_num != updateCount) { label = Constants.STATEMENT_ERROR; } - writeToFile(writer, label); - writeToFile(writer, line); + writeLine(writer, label); + writeLine(writer, line); writer.write('\n'); expected_result_num = -1; continue; @@ -72,13 +195,13 @@ public static void main(String[] args) throws Throwable { for (int i = 1; i <= rsmd.getColumnCount(); ++i) { String colTypeName = rsmd.getColumnTypeName(i); MogDb.DbColumnType colType = db.getDbTest().getDbColumnType(colTypeName); - if(colType==MogDb.DbColumnType.FLOAT){ + if (colType == MogDb.DbColumnType.FLOAT) { typeString += "R"; - }else if(colType==MogDb.DbColumnType.INTEGER){ + } else if (colType == MogDb.DbColumnType.INTEGER) { typeString += "I"; - }else if(colType==MogDb.DbColumnType.TEXT){ + } else if(colType == MogDb.DbColumnType.TEXT) { typeString += "T"; - }else{ + } else { System.out.println(colTypeName + " column invalid"); } } @@ -93,98 +216,219 @@ public static void main(String[] args) throws Throwable { sortOption = "rowsort"; mog.sortMode = "rowsort"; } - String query_sort = Constants.QUERY + " " + typeString + " " + sortOption; - writeToFile(writer, query_sort); - writeToFile(writer, line); - writeToFile(writer, Constants.SEPARATION); - List res = mog.processResults(rs); - // compute the hash - String hash = TestUtility.getHashFromDb(res); - String queryResult = ""; - // when include_result is true, set queryResult to be exact result instead of hash - if(include_result){ - for(String i:res){ - queryResult += i; - queryResult += "\n"; + final String query_sort = Constants.QUERY + " " + typeString + " " + sortOption; + writeLine(writer, query_sort); + writeLine(writer, line); + writeLine(writer, Constants.SEPARATION); + + final List results = mog.processResults(rs); + final String hash = TestUtility.getHashFromDb(results); + + StringBuilder resultBuilder = new StringBuilder(); + if (include_result) { + for (final String result : results) { + resultBuilder.append(result); + resultBuilder.append('\n'); } - queryResult = queryResult.trim(); - }else{ - // if expected number of results is specified - if(expected_result_num>=0){ - queryResult = "Expected " + expected_result_num + " values hashing to " + hash; - }else{ - if(res.size()>0){ - // set queryResult to format x values hashing to xxx - queryResult = res.size() + " values hashing to " + hash; + } else { + // Expected number of results is specified + if (expected_result_num >= 0) { + resultBuilder.append("Expected " + expected_result_num + " values hashing to " + hash); + } else { + if (results.size() > 0) { + resultBuilder.append(results.size() + " values hashing to " + hash); } - // set queryResult to be exact result instead of hash when - // result size is smaller than Constants.DISPLAY_RESULT_SIZE - if(res.size() < Constants.DISPLAY_RESULT_SIZE){ - queryResult = ""; - for(String i:res){ - queryResult += i; - queryResult += "\n"; + if (results.size() < Constants.DISPLAY_RESULT_SIZE) { + resultBuilder.setLength(0); + for (final String result : results) { + resultBuilder.append(result); + resultBuilder.append('\n'); } - queryResult = queryResult.trim(); } } } - writeToFile(writer, queryResult); - if(res.size()>0){ + + writeLine(writer, resultBuilder.toString()); + if (results.size() > 0) { writer.write('\n'); } + include_result = false; expected_result_num = -1; - } else if(line.startsWith(Constants.HASHTAG)){ - writeToFile(writer, line); - if(line.contains(Constants.NUM_OUTPUT_FLAG)){ - // case for specifying the expected number of outputs - String[] arr = line.split(" "); - expected_result_num = Integer.parseInt(arr[arr.length-1]); - }else if(line.contains(Constants.FAIL_FLAG)){ - // case for expecting the query to fail + } else if (line.startsWith(Constants.HASHTAG)) { + writeLine(writer, line); + if (line.contains(Constants.NUM_OUTPUT_FLAG)) { + // Case for specifying the expected number of outputs + final String[] arr = line.split(" "); + expected_result_num = Integer.parseInt(arr[arr.length - 1]); + } else if (line.contains(Constants.FAIL_FLAG)) { + // Case for expecting the query to fail label = Constants.STATEMENT_ERROR; - } else if(line.contains(Constants.EXPECTED_OUTPUT_FLAG)){ - // case for including exact result in mog.queryResult + } else if (line.contains(Constants.EXPECTED_OUTPUT_FLAG)) { + // Case for including exact result in mog.queryResult include_result = true; } - } else{ - // other sql statements - int rs = statement.getUpdateCount(); + } else { + // Other sql statements + final int updateCount = statement.getUpdateCount(); // check if expected number is equal to update count - if(expected_result_num>=0 && expected_result_num!=rs){ + if (expected_result_num >= 0 && expected_result_num != updateCount){ label = Constants.STATEMENT_ERROR; } - writeToFile(writer, label); - writeToFile(writer, line); + writeLine(writer, label); + writeLine(writer, line); writer.write('\n'); expected_result_num = -1; } } // Prevents tests from erroring out when trace file ends with a comment - writeToFile(writer, Constants.STATEMENT_OK); + writeLine(writer, Constants.STATEMENT_OK); writer.close(); - br.close(); + reader.close(); + + return EXIT_SUCCESS; } - public static void writeToFile(FileWriter writer, String str) throws IOException { - writer.write(str); + /** + * Read a line from the specified `BufferedReader` instance. + * @param reader The instance from which lines are read + * @param delimiter The character used to delimit multiline statements + * @return The input line, or `null` on end of input + */ + private static String readLine(BufferedReader reader, final String delimiter) throws IOException { + StringBuilder builder = new StringBuilder(); + for (;;) { + final String input = reader.readLine(); + if (input == null) { + return null; + } + + if (input.endsWith(delimiter)) { + builder.append( + input.substring(0, input.length() - delimiter.length() - 1) + .trim() + " "); + } else { + builder.append(input); + break; + } + } + return builder.toString(); + } + + /** + * Write the specified line to a file using the provided `FileWriter`. + * @param writer The `FileWriter` instance + * @param line The line to be written + * @throws IOException On IO error + */ + public static void writeLine(FileWriter writer, final String line) throws IOException { + writer.write(line); writer.write('\n'); } - public static void removeExistingTable(List tab, Connection connection) throws SQLException { - for(String i:tab){ - Statement st = connection.createStatement(); - String sql = "DROP TABLE IF EXISTS " + i + " CASCADE"; - st.execute(sql); + /* ------------------------------------------------------------------------ + Table Management + ------------------------------------------------------------------------ */ + + /** + * Remove all existing tables from the database + * @param mog The `MogSqlite` instance + * @param connection The database connection + * @throws SQLException On SQL error + */ + private static void removeAllTables(MogSqlite mog, Connection connection) throws SQLException { + final List tableNames = getExistingTableNames(mog, connection); + removeTables(tableNames, connection); + } + + /** + * Get the names of all existing tables in the database. + * @param mog The `MogSqlite` instance + * @param connection The database connection + * @return A list of all table names + * @throws SQLException On SQL exception + */ + public static List getExistingTableNames(MogSqlite mog, Connection connection) throws SQLException { + final String query = "SELECT TABLENAME FROM pg_tables WHERE schemaname = 'public';"; + Statement statement = connection.createStatement(); + statement.execute(query); + return mog.processResults(statement.getResultSet()); + } + + /** + * Remove all specified tables from the database. + * @param tableNames The collection of table names to remove + * @param connection The database connection + * @throws SQLException On SQL error + */ + private static void removeTables(final List tableNames, Connection connection) throws SQLException { + for (final String tableName : tableNames){ + removeTable(tableName, connection); } } - public static List getAllExistingTableName(MogSqlite mog,Connection connection) throws SQLException { - Statement st = connection.createStatement(); - String getTableName = "SELECT tablename FROM pg_tables WHERE schemaname = 'public';"; - st.execute(getTableName); - ResultSet rs = st.getResultSet(); - List res = mog.processResults(rs); - return res; + + /** + * Remove the specified table from the database. + * @param tableName The name of the table to remove + * @param connection The database connection + * @throws SQLException On SQL error + */ + private static void removeTable(final String tableName, Connection connection) throws SQLException { + final String query = "DROP TABLE IF EXISTS " + tableName + " CASCADE"; + Statement statement = connection.createStatement(); + statement.execute(query); + } + + /* ------------------------------------------------------------------------ + Function Management + ------------------------------------------------------------------------ */ + + /** + * Remove all existing functions from the database. + * @param mog The `MogSqlite` instance. + * @param connection The database connection. + * @throws SQLException On SQL error + */ + private static void removeAllFunctions(MogSqlite mog, Connection connection) throws SQLException { + final List functionNames = getExistingFunctions(mog, connection); + removeFunctions(functionNames, connection); + } + + /** + * Get the names of all existing functions in the database. + * @param mog The MogSqlite instance + * @param connection The databse connection + * @return A collection of the function names + * @throws SQLException On SQL error + */ + private static List getExistingFunctions(MogSqlite mog, Connection connection) throws SQLException { + final String query = "SELECT proname FROM pg_proc WHERE pronamespace = 'public'::regnamespace;"; + Statement statement = connection.createStatement(); + statement.execute(query); + return mog.processResults(statement.getResultSet()); + } + + /** + * Remove all of the functions in `functionNames` from the database. + * @param functionNames The names of the functions to remove + * @param connection The database connection + * @throws SQLException On SQL error + */ + private static void removeFunctions(final List functionNames, Connection connection) throws SQLException { + for (final String functionName : functionNames) { + removeFunction(functionName, connection); + } + } + + /** + * Remove the function identified by `functionName` from the database. + * @param functionName The name of the function to remove + * @param connection The database connection + * @throws SQLException On SQL error + */ + private static void removeFunction(final String functionName, Connection connection) throws SQLException { + final String query = "DROP FUNCTION IF EXISTS " + functionName + " CASCADE;"; + Statement statement = connection.createStatement(); + statement.execute(query); } } diff --git a/script/testing/junit/traces/udf.test b/script/testing/junit/traces/udf.test new file mode 100644 index 0000000000..9321696cb4 --- /dev/null +++ b/script/testing/junit/traces/udf.test @@ -0,0 +1,1379 @@ +statement ok +-- udf.sql + +statement ok +-- Integration tests for user-defined functions. + +statement ok +-- + +statement ok +-- Currently, these tests rely on the fact that we + +statement ok +-- utilize Postgres as a reference implementation + +statement ok +-- because all user-defined functions are implemented + +statement ok +-- in the Postgres PL/SQL dialect, PL/pgSQL. + +statement ok + + +statement ok +-- Create test tables + +statement ok +CREATE TABLE integers(x INT, y INT); + +statement ok +INSERT INTO integers(x, y) VALUES (1, 1), (2, 2), (3, 3); + +statement ok + + +statement ok +CREATE TABLE strings(s TEXT); + +statement ok +INSERT INTO strings(s) VALUES ('aaa'), ('bbb'), ('ccc'); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- return_constant() + +statement ok + + +statement ok +CREATE FUNCTION return_constant() RETURNS INT AS $$ BEGIN RETURN 1; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT return_constant(); +---- +1 + + +statement ok + + +statement ok +DROP FUNCTION return_constant(); + +statement ok + + +statement ok +CREATE FUNCTION return_constant_str() RETURNS TEXT AS $$ BEGIN RETURN 'hello, functions'; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query T rowsort +SELECT return_constant_str(); +---- +hello, functions + + +statement ok + + +statement ok +DROP FUNCTION return_constant_str(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- return_input() + +statement ok + + +statement ok +CREATE FUNCTION return_input(x INT) RETURNS INT AS $$ BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query II rowsort +SELECT x, return_input(x) FROM integers; +---- +1 +1 +2 +2 +3 +3 + + +statement ok + + +statement ok +DROP FUNCTION return_input(INT); + +statement ok + + +statement ok +CREATE FUNCTION return_input(x TEXT) RETURNS TEXT AS $$ BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query TT rowsort +SELECT s, return_input(s) FROM strings; +---- +aaa +aaa +bbb +bbb +ccc +ccc + + +statement ok + + +statement ok +DROP FUNCTION return_input(TEXT); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- return_sum() + +statement ok + + +statement ok +CREATE FUNCTION return_sum(x INT, y INT) RETURNS INT AS $$ BEGIN RETURN x + y; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query III rowsort +SELECT x, y, return_sum(x, y) FROM integers; +---- +1 +1 +2 +2 +2 +4 +3 +3 +6 + + +statement ok + + +statement ok +DROP FUNCTION return_sum(INT, INT); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- return_prod() + +statement ok + + +statement ok +CREATE FUNCTION return_product(x INT, y INT) RETURNS INT AS $$ BEGIN RETURN x * y; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query III rowsort +SELECT x, y, return_product(x, y) FROM integers; +---- +1 +1 +1 +2 +2 +4 +3 +3 +9 + + +statement ok + + +statement ok +DROP FUNCTION return_product(INT, INT); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- integer_decl() + +statement ok + + +statement ok +CREATE FUNCTION integer_decl() RETURNS INT AS $$ DECLARE x INT := 0; BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT integer_decl(); +---- +0 + + +statement ok + + +statement ok +DROP FUNCTION integer_decl(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- conditional() + +statement ok +-- + +statement ok +-- TODO(Kyle): The final RETURN 0 is unreachable, but we + +statement ok +-- need this temporary hack to deal with missing logic in parser + +statement ok + + +statement ok +CREATE FUNCTION conditional(x INT) RETURNS INT AS $$ BEGIN IF x > 1 THEN RETURN 1; ELSE RETURN 2; END IF; RETURN 0; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query II rowsort +SELECT x, conditional(x) FROM integers; +---- +1 +2 +2 +1 +3 +1 + + +statement ok + + +statement ok +DROP FUNCTION conditional(INT); + +statement ok + + +statement ok +-- Nested conditional control flow + +statement ok +CREATE FUNCTION conditional(x INT, y INT) RETURNS INT AS $$ BEGIN IF x > 1 THEN IF y > 1 THEN RETURN 1; ELSE RETURN 2; END IF; ELSE IF y > 1 THEN RETURN 3; ELSE RETURN 4; END IF; END IF; RETURN 0; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT conditional(1, 1); +---- +4 + + +query I rowsort +SELECT conditional(1, 2); +---- +3 + + +query I rowsort +SELECT conditional(2, 1); +---- +2 + + +query I rowsort +SELECT conditional(2, 2); +---- +1 + + +statement ok + + +statement ok +DROP FUNCTION conditional(INT, INT); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_while() + +statement ok + + +statement ok +CREATE FUNCTION proc_while() RETURNS INT AS $$ DECLARE x INT := 0; BEGIN WHILE x < 10 LOOP x = x + 1; END LOOP; RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_while(); +---- +10 + + +statement ok + + +statement ok +DROP FUNCTION proc_while(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_fori() + +statement ok +-- + +statement ok +-- TODO(Kyle): for-loop control flow (integer variant) is not supported + +statement ok + + +statement ok +-- CREATE FUNCTION proc_fori() RETURNS INT AS $$ -- DECLARE -- x INT := 0; -- BEGIN -- FOR i IN 1..10 LOOP -- x = x + 1; -- END LOOP; -- RETURN x; -- END -- $$ LANGUAGE PLPGSQL; + +statement ok + + +statement ok +-- SELECT x, proc_fori() FROM integers; + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_select_single_constant() + +statement ok + + +statement ok +CREATE FUNCTION sql_select_single_constant() RETURNS INT AS $$ DECLARE v INT; BEGIN SELECT 1 INTO v; RETURN v; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT sql_select_single_constant(); +---- +1 + + +statement ok + + +statement ok +DROP FUNCTION sql_select_single_constant(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_select_mutliple_constants() + +statement ok + + +statement ok +CREATE FUNCTION sql_select_multiple_constants() RETURNS INT AS $$ DECLARE x INT; y INT; BEGIN SELECT 1, 2 INTO x, y; RETURN x + y; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT sql_select_multiple_constants(); +---- +3 + + +statement ok + + +statement ok +DROP FUNCTION sql_select_multiple_constants(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_select_constant_assignment() + +statement ok + + +statement ok +CREATE FUNCTION sql_select_constant_assignment() RETURNS INT AS $$ DECLARE x INT; y INT; BEGIN x = (SELECT 1); y = (SELECT 2); RETURN x + y; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT sql_select_constant_assignment(); +---- +3 + + +statement ok + + +statement ok +DROP FUNCTION sql_select_constant_assignment(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_embedded_agg_count() + +statement ok + + +statement ok +CREATE FUNCTION sql_embedded_agg_count() RETURNS INT AS $$ DECLARE v INT; BEGIN SELECT COUNT(*) FROM integers INTO v; RETURN v; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT sql_embedded_agg_count(); +---- +3 + + +statement ok + + +statement ok +DROP FUNCTION sql_embedded_agg_count(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_embedded_agg_min() + +statement ok + + +statement ok +CREATE FUNCTION sql_embedded_agg_min() RETURNS INT AS $$ DECLARE v INT; BEGIN SELECT MIN(x) FROM integers INTO v; RETURN v; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT sql_embedded_agg_min(); +---- +1 + + +statement ok + + +statement ok +DROP FUNCTION sql_embedded_agg_min(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_embedded_agg_max() + +statement ok + + +statement ok +CREATE FUNCTION sql_embedded_agg_max() RETURNS INT AS $$ DECLARE v INT; BEGIN SELECT MAX(x) FROM integers INTO v; RETURN v; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT sql_embedded_agg_max(); +---- +3 + + +statement ok + + +statement ok +DROP FUNCTION sql_embedded_agg_max(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- sql_embedded_agg_multi() + +statement ok + + +statement ok +CREATE FUNCTION sql_embedded_agg_multi() RETURNS INT AS $$ DECLARE minimum INT; maximum INT; BEGIN minimum = (SELECT MIN(x) FROM integers); maximum = (SELECT MAX(x) FROM integers); RETURN minimum + maximum; END; $$ LANGUAGE PLPGSQL; + +statement ok + + +statement ok +DROP FUNCTION sql_embedded_agg_multi(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_fors_constant_var() + +statement ok + + +statement ok +-- Select constant into a scalar variable + +statement ok +CREATE FUNCTION proc_fors_constant_var() RETURNS INT AS $$ DECLARE v INT; x INT := 0; BEGIN FOR v IN SELECT 1 LOOP x = x + 1; END LOOP; RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_fors_constant_var(); +---- +1 + + +statement ok + + +statement ok +DROP FUNCTION proc_fors_constant_var(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_fors_constant_vars() + +statement ok + + +statement ok +-- Select multiple constants in scalar variables + +statement ok +CREATE FUNCTION proc_fors_constant_vars() RETURNS INT AS $$ DECLARE x INT; y INT; z INT := 0; BEGIN FOR x, y IN SELECT 1, 2 LOOP z = z + 1; END LOOP; RETURN z; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_fors_constant_vars(); +---- +1 + + +statement ok + + +statement ok +DROP FUNCTION proc_fors_constant_vars(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_fors_rec() + +statement ok +-- + +statement ok +-- TODO(Kyle): RECORD types not supported + +statement ok + + +statement ok +-- -- Bind query result to a RECORD type + +statement ok +-- CREATE FUNCTION proc_fors_rec() RETURNS INT AS $$ -- DECLARE \ + +statement ok +-- x INT := 0; -- v RECORD; -- BEGIN -- FOR v IN (SELECT z FROM temp) LOOP -- x = x + 1; -- END LOOP; -- RETURN x; -- END -- $$ LANGUAGE PLPGSQL; + +statement ok + + +statement ok +-- SELECT proc_fors_rec() FROM integers; + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_fors_var() + +statement ok + + +statement ok +-- Bind query result directly to INT type + +statement ok +CREATE FUNCTION proc_fors_var() RETURNS INT AS $$ DECLARE c INT := 0; v INT; BEGIN FOR v IN (SELECT x FROM integers) LOOP c = c + 1; END LOOP; RETURN c; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_fors_var(); +---- +3 + + +statement ok + + +statement ok +DROP FUNCTION proc_fors_var(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_call_*() + +statement ok + + +statement ok +CREATE FUNCTION proc_call_callee() RETURNS INT AS $$ BEGIN RETURN 1; END $$ LANGUAGE PLPGSQL; + +statement ok + + +statement ok +-- Just RETURN the result of call + +statement ok +CREATE FUNCTION proc_call_ret() RETURNS INT AS $$ BEGIN RETURN proc_call_callee(); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_call_ret(); +---- +1 + + +statement ok + + +statement ok +-- Assign the result of call to variable + +statement ok +CREATE FUNCTION proc_call_assign() RETURNS INT AS $$ DECLARE v INT; BEGIN v = proc_call_callee(); RETURN v; END $$ LANGUAGE PLPGSQl; + +statement ok + + +query I rowsort +SELECT proc_call_assign(); +---- +1 + + +statement ok + + +statement ok +-- SELECT the result of call into variable + +statement ok +CREATE FUNCTION proc_call_select() RETURNS INT AS $$ DECLARE v INT; BEGIN SELECT proc_call_callee() INTO v; RETURN v; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_call_select(); +---- +1 + + +statement ok + + +statement ok +DROP FUNCTION proc_call_callee(); + +statement ok +DROP FUNCTION proc_call_ret(); + +statement ok +DROP FUNCTION proc_call_assign(); + +statement ok +DROP FUNCTION proc_call_select(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_predicate() + +statement ok + + +statement ok +CREATE FUNCTION proc_predicate(threshold INT) RETURNS INT AS $$ DECLARE c INT; BEGIN SELECT COUNT(x) FROM integers WHERE x > threshold INTO c; RETURN c; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_predicate(0); +---- +3 + + +query I rowsort +SELECT proc_predicate(1); +---- +2 + + +query I rowsort +SELECT proc_predicate(2); +---- +1 + + +statement ok + + +statement ok +DROP FUNCTION proc_predicate(INT); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_call_args() + +statement ok + + +statement ok +-- Argument to call can be an expression + +statement ok +CREATE FUNCTION proc_call_args() RETURNS INT AS $$ DECLARE x INT := 1; y INT := 2; z INT := 3; BEGIN RETURN ABS(x * y + z); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_call_args(); +---- +5 + + +statement ok + + +statement ok +DROP FUNCTION proc_call_args(); + +statement ok + + +statement ok +-- Argument to call can be an identifier + +statement ok +CREATE FUNCTION proc_call_args() RETURNS INT AS $$ DECLARE x INT := 1; y INT := 2; z INT := 3; r INT; BEGIN r = x * y + z; RETURN ABS(r); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_call_args(); +---- +5 + + +statement ok + + +statement ok +DROP FUNCTION proc_call_args(); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_promotion() + +statement ok + + +statement ok +-- Able to (silently) promote REAL to DOUBLE PRECISION + +statement ok +CREATE FUNCTION proc_promotion() RETURNS REAL AS $$ DECLARE x INT := 1; y REAL := 1.0; t REAL; BEGIN t = x * y; RETURN FLOOR(t); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query rowsort +SELECT proc_promotion(); +---- +1 + + +statement ok +DROP FUNCTION proc_promotion(); + +statement ok + + +statement ok +-- Able to (silently) promote FLOAT to DOUBLE PRECISION + +statement ok +CREATE FUNCTION proc_promotion() RETURNS FLOAT AS $$ DECLARE x INT := 1; y FLOAT := 1.0; t FLOAT; BEGIN t = x * y; RETURN FLOOR(t); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_promotion(); +---- +1 + + +statement ok +DROP FUNCTION proc_promotion(); + +statement ok + + +statement ok +-- Promotion does not affect correct operation of DOUBLE PRECISION + +statement ok +CREATE FUNCTION proc_promotion() RETURNS DOUBLE PRECISION AS $$ DECLARE x INT := 1; y DOUBLE PRECISION := 1.0; t DOUBLE PRECISION; BEGIN t = x * y; RETURN FLOOR(t); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_promotion(); +---- +1 + + +statement ok +DROP FUNCTION proc_promotion(); + +statement ok + + +statement ok +-- Promotion does not affect correct operation of FLOAT8 + +statement ok +CREATE FUNCTION proc_promotion() RETURNS DOUBLE PRECISION AS $$ DECLARE x INT := 1; y DOUBLE PRECISION := 1.0; t DOUBLE PRECISION; BEGIN t = x * y; RETURN FLOOR(t); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_promotion(); +---- +1 + + +statement ok +DROP FUNCTION proc_promotion(); + +statement ok + + +statement ok +-- Promotion works as expected with UDF arguments + +statement ok +CREATE FUNCTION proc_promotion(x FLOAT) RETURNS FLOAT AS $$ BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_promotion(1337.0); +---- +1337 + + +statement ok +DROP FUNCTION proc_promotion(FLOAT); + +statement ok + + +statement ok +-- Promotion works as expected with UDF arguments + +statement ok +CREATE FUNCTION proc_promotion(x REAL) RETURNS REAL AS $$ BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query rowsort +SELECT proc_promotion(1337.0); +---- +1337 + + +statement ok +DROP FUNCTION proc_promotion(REAL); + +statement ok + + +statement ok +-- Promotion works as expected with UDF arguments + +statement ok +CREATE FUNCTION proc_promotion(x DOUBLE PRECISION) RETURNS DOUBLE PRECISION AS $$ BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_promotion(1337.0); +---- +1337 + + +statement ok +DROP FUNCTION proc_promotion(DOUBLE PRECISION); + +statement ok + + +statement ok +-- Promotion works as expected with UDF arguments + +statement ok +CREATE FUNCTION proc_promotion(x FLOAT8) RETURNS FLOAT8 AS $$ BEGIN RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_promotion(1337.0); +---- +1337 + + +statement ok +DROP FUNCTION proc_promotion(FLOAT8); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_cast() + +statement ok + + +statement ok +-- CAST works in assignment expression + +statement ok +CREATE FUNCTION proc_cast() RETURNS FLOAT AS $$ DECLARE x FLOAT; BEGIN x = CAST(1 AS FLOAT); RETURN x; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query R rowsort +SELECT proc_cast(); +---- +1 + + +statement ok +DROP FUNCTION proc_cast(); + +statement ok + + +statement ok +-- TODO(Kyle): this is a great example of a function that + +statement ok +-- we can't currently compile because we only resort to a + +statement ok +-- full handoff to the SQL execution infrastructure in the + +statement ok +-- case of assignment expressions. For everything else, in + +statement ok +-- this case a RETURN statement, we don't yet have the + +statement ok +-- ability to defer this to the SQL engine, and we also can't + +statement ok +-- handle it in the "builtin" manner, so we just fail. + +statement ok + + +statement ok +-- CREATE FUNCTION proc_cast() RETURNS FLOAT AS $$ -- BEGIN -- RETURN CAST(1 AS FLOAT); -- END -- $$ LANGUAGE PLPGSQL; + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_is_null() + +statement ok + + +statement ok +CREATE FUNCTION proc_is_null(x INT) RETURNS INT AS $$ DECLARE r INT; BEGIN IF x IS NULL THEN r = 1; ELSE r = 2; END IF; RETURN r; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_is_null(1); +---- +2 + + +query I rowsort +SELECT proc_is_null(NULL); +---- +1 + + +statement ok + + +statement ok +DROP FUNCTION proc_is_null(INT); + +statement ok + + +statement ok +CREATE FUNCTION proc_is_not_null(x INT) RETURNS INT AS $$ DECLARE r INT; BEGIN IF x IS NOT NULL THEN r = 1; ELSE r = 2; END IF; RETURN r; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_is_not_null(1); +---- +1 + + +query I rowsort +SELECT proc_is_not_null(NULL); +---- +2 + + +statement ok + + +statement ok +DROP FUNCTION proc_is_not_null(INT); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_length() + +statement ok + + +statement ok +-- Assignment of LENGTH to temporary + +statement ok +CREATE FUNCTION proc_length(t VARCHAR) RETURNS INT AS $$ DECLARE r INT; BEGIN r = LENGTH(t); RETURN r; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_length('hello'); +---- +5 + + +statement ok +DROP FUNCTION proc_length(VARCHAR); + +statement ok + + +statement ok +-- Direct RETURN of LENGTH + +statement ok +CREATE FUNCTION proc_length(t VARCHAR) RETURNS INT AS $$ BEGIN RETURN LENGTH(t); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_length('hello'); +---- +5 + + +statement ok + + +statement ok +DROP FUNCTION proc_length(VARCHAR); + +statement ok + + +statement ok +-- Use of LENGTH() in conditional + +statement ok +CREATE FUNCTION proc_length(t VARCHAR) RETURNS INT AS $$ BEGIN IF LENGTH(t) > 1 THEN RETURN 1; ELSE RETURN 2; END IF; RETURN 0; END $$ LANGUAGE PLPGSQL; + +statement ok + + +query I rowsort +SELECT proc_length('a'); +---- +2 + + +query I rowsort +SELECT proc_length('ab'); +---- +1 + + +query I rowsort +SELECT proc_length('abc'); +---- +1 + + +statement ok + + +statement ok +DROP FUNCTION proc_length(VARCHAR); + +statement ok + + +statement ok +-- ---------------------------------------------------------------------------- + +statement ok +-- proc_substr() + +statement ok + + +statement ok +-- Able to pass all arguments through + +statement ok +CREATE FUNCTION proc_substr(t VARCHAR, i INT, l INT) RETURNS VARCHAR AS $$ BEGIN RETURN SUBSTR(t, i, l); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query T rowsort +SELECT proc_substr('hello', 1, 1); +---- +h + + +query T rowsort +SELECT proc_substr('hello', 1, 2); +---- +he + + +statement ok + + +statement ok +DROP FUNCTION proc_substr(VARCHAR, INT, INT); + +statement ok + + +statement ok +-- Able to specify a literal value + +statement ok +CREATE FUNCTION proc_substr(t VARCHAR, i INT) RETURNS VARCHAR AS $$ BEGIN RETURN SUBSTR(t, i, 1); END $$ LANGUAGE PLPGSQL; + +statement ok + + +query T rowsort +SELECT proc_substr('hello', 1); +---- +h + + +query T rowsort +SELECT proc_substr('hello', 2); +---- +e + + +statement ok + + +statement ok +DROP FUNCTION proc_substr(VARCHAR, INT); + +statement ok diff --git a/script/testing/util/db_server.py b/script/testing/util/db_server.py index bb500a8e62..88538a600a 100644 --- a/script/testing/util/db_server.py +++ b/script/testing/util/db_server.py @@ -164,7 +164,6 @@ def stop_db(self, is_dry_run=False): finally: unix_socket = os.path.join("/tmp/", f".s.PGSQL.{self.db_port}") if os.path.exists(unix_socket): - os.remove(unix_socket) LOG.info(f"Removing: {unix_socket}") self.print_db_logs() exit_code = self.db_process.returncode @@ -493,8 +492,9 @@ def handle_flags(value: str, meta: Dict) -> str: `-attribute=value` and instead want to format it as `-attribute` alone. This preprocessor encapsulates the logic for this transformation. - TODO(Kyle): Do we actually support any arguments like this? - I can't seem to come up with any actual examples... + NOTE(Kyle): At this time it doesn't appear we actually support + any arguments like this, but keeping it in anyway so I don't + inadvertently break something. Arguments --------- @@ -518,11 +518,6 @@ def apply_all(functions: List, init_obj, meta: Dict): Apply all of the functions in `functions` to object `init_obj` sequentially, supplying metadata object `meta` to each function invocation. - TODO(Kyle): Initially I wanted to implement this with function composition - in terms of functools.reduce() which makes it really beautiful, but there - we run into issues with multi-argument callbacks, and the real solution is - to use partial application, but this seemed like overkill... maybe revisit. - Arguments --------- functions : List[function] diff --git a/src/binder/bind_node_visitor.cpp b/src/binder/bind_node_visitor.cpp index 3a92988a70..42eddbb7b6 100644 --- a/src/binder/bind_node_visitor.cpp +++ b/src/binder/bind_node_visitor.cpp @@ -64,6 +64,19 @@ void BindNodeVisitor::BindNameToNode( BindNodeVisitor::~BindNodeVisitor() = default; +std::vector BindNodeVisitor::BindAndGetUDFVariableRefs( + common::ManagedPointer parse_result, + common::ManagedPointer udf_ast_context) { + NOISEPAGE_ASSERT(parse_result != nullptr, "We shouldn't be trying to bind something without a ParseResult."); + sherpa_ = std::make_unique(parse_result, nullptr, nullptr); + NOISEPAGE_ASSERT(sherpa_->GetParseResult()->GetStatements().size() == 1, "Binder can only bind one at a time."); + udf_ast_context_ = udf_ast_context; + sherpa_->GetParseResult()->GetStatement(0)->Accept( + common::ManagedPointer(this).CastManagedPointerTo()); + // TODO(Kyle): This is strange, why are we returning this member by value? + return udf_variable_refs_; +} + void BindNodeVisitor::Visit(common::ManagedPointer node) { BINDER_LOG_TRACE("Visiting AnalyzeStatement ..."); SqlNodeVisitor::Visit(node); @@ -310,6 +323,18 @@ void BindNodeVisitor::Visit(common::ManagedPointer node) common::ErrorCode::ERRCODE_UNDEFINED_OBJECT); } break; + case parser::DropStatement::DropType::kFunction: { + ValidateDatabaseName(node->GetDatabaseName()); + if (catalog_accessor_->GetProcOid(node->GetFunctionName(), node->GetFunctionArguments()) == + catalog::INVALID_PROC_OID) { + // TODO(Kyle): We have all of the information needed for DROP FUNCTION IF EXISTS, + // but it does not seem that there is a way to communicate a non-error failure + // condition during binding, maybe we need to add an error severity to the exception? + throw BINDER_EXCEPTION(fmt::format("function \"{}\" does not exist", node->GetFunctionName()), + common::ErrorCode::ERRCODE_UNDEFINED_OBJECT); + } + break; + } case parser::DropStatement::DropType::kTrigger: // TODO(Ling): Get Trigger OID in catalog? case parser::DropStatement::DropType::kSchema: @@ -667,8 +692,8 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetDesiredType(expr.CastManagedPointerTo()); + // Before checking with the schema, cache the desired type that expr should have + const auto cached_desired_type = sherpa_->GetDesiredType(expr.CastManagedPointerTo()); // TODO(Ling): consider remove precondition check if the *_oid_ will never be initialized till binder // That is, the object would not be initialized using ColumnValueExpression(database_oid, table_oid, column_oid) @@ -683,13 +708,16 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetColumnOid().UnderlyingValue())), common::ErrorCode::ERRCODE_UNDEFINED_COLUMN); } - // Convert all the names to lower cases + // Convert all the names to lower case std::transform(table_alias_name.begin(), table_alias_name.end(), table_alias_name.begin(), ::tolower); std::transform(col_name.begin(), col_name.end(), col_name.begin(), ::tolower); - // Table name not specified in the expression. Loop through all the table in the binder context. - if (table_alias.Empty()) { - if (context_ == nullptr || !context_->SetColumnPosTuple(expr)) { + // Table name not specified in the expression; loop through all the tables in the binder context + if (table_alias_name.empty()) { + if (BindingForUDF() && IsUDFVariable(expr->GetColumnName())) { + // This expression refers to a PL/pgSQL variable + AddUDFVariableReference(expr, expr->GetColumnName()); + } else if (context_ == nullptr || !context_->SetColumnPosTuple(expr)) { throw BINDER_EXCEPTION(fmt::format("column \"{}\" does not exist", col_name), common::ErrorCode::ERRCODE_UNDEFINED_COLUMN); } @@ -704,6 +732,9 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetTableAlias().GetName())) { + // This expression refers to a structural (RECORD) PL/pgSQL variable + AddUDFVariableReference(expr, expr->GetTableAlias().GetName(), expr->GetColumnName()); } else if (context_ == nullptr || !context_->CheckNestedTableColumn(table_alias, col_name, expr)) { throw BINDER_EXCEPTION(fmt::format("Invalid table reference {}", expr->GetTableAlias().GetName()), common::ErrorCode::ERRCODE_UNDEFINED_TABLE); @@ -713,7 +744,8 @@ void BindNodeVisitor::Visit(common::ManagedPointerGetReturnValueType() : desired_type; + const auto desired_type = + cached_desired_type == execution::sql::SqlTypeId::Invalid ? expr->GetReturnValueType() : cached_desired_type; sherpa_->SetDesiredType(expr.CastManagedPointerTo(), desired_type); sherpa_->CheckDesiredType(expr.CastManagedPointerTo()); } @@ -726,7 +758,7 @@ void BindNodeVisitor::Visit(common::ManagedPointer SqlNodeVisitor::Visit(expr); // If any of the operands are typecasts, the typecast children should have been casted by now. Pull the children up. - for (size_t i = 0; i < expr->GetChildrenSize(); ++i) { + for (std::size_t i = 0; i < expr->GetChildrenSize(); ++i) { auto child = expr->GetChild(i); if (parser::ExpressionType::OPERATOR_CAST == child->GetExpressionType()) { NOISEPAGE_ASSERT(parser::ExpressionType::VALUE_CONSTANT == child->GetChild(0)->GetExpressionType(), @@ -734,6 +766,20 @@ void BindNodeVisitor::Visit(common::ManagedPointer expr->SetChild(i, child->GetChild(0)); } } + + for (auto i = 0UL; i < expr->GetChildrenSize(); ++i) { + auto child = expr->GetChild(i); + if (child->GetExpressionType() == parser::ExpressionType::COLUMN_VALUE) { + const auto index = child.CastManagedPointerTo()->GetParamIdx(); + if (index > parser::ColumnValueExpression::INVALID_PARAM_INDEX) { + // replace with PVE + std::unique_ptr pve = std::make_unique(index); + pve->SetReturnValueType(child->GetReturnValueType()); + expr->SetChild(i, common::ManagedPointer(pve)); + sherpa_->GetParseResult()->AddExpression(std::move(pve)); + } + } + } } void BindNodeVisitor::Visit(common::ManagedPointer expr) { @@ -770,20 +816,42 @@ void BindNodeVisitor::Visit(common::ManagedPointer e BINDER_LOG_TRACE("Visiting FunctionExpression ..."); SqlNodeVisitor::Visit(expr); - std::vector arg_types; auto children = expr->GetChildren(); + std::vector arg_types{}; arg_types.reserve(children.size()); for (const auto &child : children) { arg_types.push_back(catalog_accessor_->GetTypeOidFromTypeId(child->GetReturnValueType())); } - auto proc_oid = catalog_accessor_->GetProcOid(expr->GetFuncName(), arg_types); + // Resolve the argument types to handle the case where an untyped NULL is passed + const auto resolved_types = catalog_accessor_->ResolveProcArgumentTypes(expr->GetFuncName(), arg_types); + if (resolved_types.empty()) { + throw BINDER_EXCEPTION("Procedure not registered", common::ErrorCode::ERRCODE_UNDEFINED_FUNCTION); + } else if (resolved_types.size() > 1) { + throw BINDER_EXCEPTION("Procedure call is ambiguous", common::ErrorCode::ERRCODE_UNDEFINED_FUNCTION); + } + + // This lookup should now always succeed + auto proc_oid = catalog_accessor_->GetProcOid(expr->GetFuncName(), resolved_types.front()); if (proc_oid == catalog::INVALID_PROC_OID) { throw BINDER_EXCEPTION("Procedure not registered", common::ErrorCode::ERRCODE_UNDEFINED_FUNCTION); } - auto func_context = catalog_accessor_->GetFunctionContext(proc_oid); + // The function is now resolved; we need to perform one further substitution + // here to handle the case where a literal untyped NULL is provided as an + // argument to the function call. In this case, the execution engine has no + // way to model the untyped NULL, so we need to replace this with a typed NULL + // from the function call argument that was resolved above + for (std::size_t i = 0; i < children.size(); ++i) { + auto child = children[i]; + if (child->GetExpressionType() == parser::ExpressionType::VALUE_CONSTANT && + child->GetReturnValueType() == execution::sql::SqlTypeId::Invalid) { + auto cve = child.CastManagedPointerTo(); + cve->SetValue(catalog_accessor_->GetTypeIdFromTypeOid(resolved_types.front()[0]), execution::sql::Val(true)); + } + } + auto func_context = catalog_accessor_->GetFunctionContext(proc_oid); expr->SetProcOid(proc_oid); expr->SetReturnValueType(func_context->GetFunctionReturnType()); } @@ -797,6 +865,9 @@ void BindNodeVisitor::Visit(common::ManagedPointer e void BindNodeVisitor::Visit(common::ManagedPointer expr) { BINDER_LOG_TRACE("Visiting ParameterValueExpression ..."); SqlNodeVisitor::Visit(expr); + if (sherpa_ == nullptr || sherpa_->GetParameters() == nullptr) { + return; + } const common::ManagedPointer param = common::ManagedPointer(&((*(sherpa_->GetParameters()))[expr->GetValueIdx()])); const auto desired_type = sherpa_->GetDesiredType(expr.CastManagedPointerTo()); @@ -1098,4 +1169,46 @@ void BindNodeVisitor::SetUniqueTableAlias(common::ManagedPointerAddTableAliasMapping(node->GetAlias().GetName(), node->GetAlias()); } +bool BindNodeVisitor::BindingForUDF() const { return udf_ast_context_ != nullptr; } + +bool BindNodeVisitor::IsUDFVariable(const std::string &identifier) const { + return udf_ast_context_->HasVariable(identifier); +} + +bool BindNodeVisitor::HaveUDFVariableRef(const std::string &identifier) const { + auto it = std::find_if(udf_variable_refs_.cbegin(), udf_variable_refs_.cend(), + [&identifier](const parser::udf::VariableRef &ref) { return ref.ColumnName() == identifier; }); + return it != udf_variable_refs_.cend(); +} + +void BindNodeVisitor::AddUDFVariableReference(common::ManagedPointer expr, + const std::string &table_name, const std::string &column_name) { + NOISEPAGE_ASSERT(udf_ast_context_->GetVariableTypeFailFast(table_name) == execution::sql::SqlTypeId::Invalid, + "Must be a RECORD type"); + + // Locate the column name in the structure + const auto fields = udf_ast_context_->GetRecordTypeFailFast(table_name); + auto field = std::find_if(fields.cbegin(), fields.cend(), [=](auto p) { return p.first == expr->GetColumnName(); }); + if (field == fields.cend()) { + throw BINDER_EXCEPTION(fmt::format("RECORD type field '{}' not found", expr->GetColumnName()), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + if (!HaveUDFVariableRef(column_name)) { + const std::size_t index = udf_variable_refs_.size(); + udf_variable_refs_.emplace_back(table_name, column_name, index); + expr->SetReturnValueType(field->second); + expr->SetParamIdx(index); + } +} + +void BindNodeVisitor::AddUDFVariableReference(common::ManagedPointer expr, + const std::string &column_name) { + if (!HaveUDFVariableRef(column_name)) { + const std::size_t index = udf_variable_refs_.size(); + udf_variable_refs_.emplace_back(column_name, index); + expr->SetReturnValueType(udf_ast_context_->GetVariableTypeFailFast(expr->GetColumnName())); + expr->SetParamIdx(index); + } +} } // namespace noisepage::binder diff --git a/src/catalog/catalog_accessor.cpp b/src/catalog/catalog_accessor.cpp index 5b541e9960..a38f47ce7d 100644 --- a/src/catalog/catalog_accessor.cpp +++ b/src/catalog/catalog_accessor.cpp @@ -192,7 +192,7 @@ proc_oid_t CatalogAccessor::CreateProcedure(const std::string &procname, const l const std::vector &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, + const std::vector &arg_modes, const type_oid_t rettype, const std::string &src, const bool is_aggregate) { return dbc_->CreateProcedure(txn_, procname, language_oid, procns, variadic_type, args, arg_types, all_arg_types, arg_modes, rettype, src, is_aggregate); @@ -200,6 +200,15 @@ proc_oid_t CatalogAccessor::CreateProcedure(const std::string &procname, const l bool CatalogAccessor::DropProcedure(proc_oid_t proc_oid) { return dbc_->DropProcedure(txn_, proc_oid); } +proc_oid_t CatalogAccessor::GetProcOid(const std::string &procname, const std::vector &arg_types) { + // Transform the string type identifiers to internal type IDs + std::vector types{}; + types.reserve(arg_types.size()); + std::transform(arg_types.cbegin(), arg_types.cend(), std::back_inserter(types), + [this](const std::string &name) { return TypeNameToType(name); }); + return GetProcOid(procname, types); +} + proc_oid_t CatalogAccessor::GetProcOid(const std::string &procname, const std::vector &arg_types) { proc_oid_t ret; for (auto ns_oid : search_path_) { @@ -211,15 +220,35 @@ proc_oid_t CatalogAccessor::GetProcOid(const std::string &procname, const std::v return catalog::INVALID_PROC_OID; } -bool CatalogAccessor::SetFunctionContextPointer(proc_oid_t proc_oid, - const execution::functions::FunctionContext *func_context) { - return dbc_->SetFunctionContextPointer(txn_, proc_oid, func_context); +std::vector> CatalogAccessor::ResolveProcArgumentTypes( + const std::string &procname, const std::vector &arg_types) const { + // Transform the string type identifiers to internal type IDs + std::vector types{}; + types.reserve(arg_types.size()); + std::transform(arg_types.cbegin(), arg_types.cend(), std::back_inserter(types), + [this](const std::string &name) { return TypeNameToType(name); }); + return ResolveProcArgumentTypes(procname, arg_types); +} + +std::vector> CatalogAccessor::ResolveProcArgumentTypes( + const std::string &procname, const std::vector &arg_types) const { + std::vector> types{}; + for (auto ns_oid : search_path_) { + const auto resolved = dbc_->ResolveProcArgumentTypes(txn_, ns_oid, procname, arg_types); + types.insert(types.cend(), resolved.cbegin(), resolved.cend()); + } + return types; } common::ManagedPointer CatalogAccessor::GetFunctionContext(proc_oid_t proc_oid) { return dbc_->GetFunctionContext(txn_, proc_oid); } +bool CatalogAccessor::SetFunctionContext(proc_oid_t proc_oid, + const execution::functions::FunctionContext *func_context) { + return dbc_->SetFunctionContext(txn_, proc_oid, func_context); +} + std::unique_ptr CatalogAccessor::GetColumnStatistics(table_oid_t table_oid, col_oid_t col_oid) { return dbc_->GetColumnStatistics(txn_, table_oid, col_oid); @@ -229,10 +258,14 @@ optimizer::TableStats CatalogAccessor::GetTableStatistics(table_oid_t table_oid) return dbc_->GetTableStatistics(txn_, table_oid); } -type_oid_t CatalogAccessor::GetTypeOidFromTypeId(execution::sql::SqlTypeId type) { +type_oid_t CatalogAccessor::GetTypeOidFromTypeId(execution::sql::SqlTypeId type) const { return dbc_->GetTypeOidForType(type); } +execution::sql::SqlTypeId CatalogAccessor::GetTypeIdFromTypeOid(type_oid_t type) const { + return dbc_->GetTypeForTypeOid(type); +} + common::ManagedPointer CatalogAccessor::GetBlockStore() const { // TODO(Matt): at some point we may decide to adjust the source (i.e. each DatabaseCatalog has one), stick it in a // pg_tablespace table, or we may eliminate the concept entirely. This works for now to allow CREATE nodes to bind a @@ -246,4 +279,36 @@ void CatalogAccessor::RegisterTempTable(table_oid_t table_oid, const common::Man temp_schemas_[table_oid] = schema; } +type_oid_t CatalogAccessor::TypeNameToType(const std::string &type_name) const { + type_oid_t type; + if (type_name == "int2") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::SmallInt); + } else if (type_name == "int4") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer); + } else if (type_name == "int8") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::BigInt); + } else if (type_name == "bool") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Boolean); + } else if (type_name == "float4") { + // NOTE(Kyle): The "regular" SQL frontend always promotes + // FLOAT / REAL to DOUBLE PRECISION / FLOAT8, so we do the + // same here to remain consistent + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Double); + } else if (type_name == "float8") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Double); + } else if (type_name == "numeric") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Decimal); + } else if (type_name == "bpchar") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Char); + } else if (type_name == "varchar" || type_name == "text") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Varchar); + } else if (type_name == "varbinary") { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Varbinary); + } else { + type = GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid); + } + + return type; +} + } // namespace noisepage::catalog diff --git a/src/catalog/database_catalog.cpp b/src/catalog/database_catalog.cpp index 74dc38305b..6fd4aa3c33 100644 --- a/src/catalog/database_catalog.cpp +++ b/src/catalog/database_catalog.cpp @@ -312,11 +312,47 @@ std::vector, const Index return pg_core_.GetIndexes(txn, table); } -type_oid_t DatabaseCatalog::GetTypeOidForType(const execution::sql::SqlTypeId type) { +type_oid_t DatabaseCatalog::GetTypeOidForType(const execution::sql::SqlTypeId type) const { // TODO(WAN): WARNING! Do not change this seeing PgCoreImpl::MakeColumn and PgCoreImpl::CreateColumn. return type_oid_t(static_cast(type)); } +execution::sql::SqlTypeId DatabaseCatalog::GetTypeForTypeOid(type_oid_t type) const { + // NOTE(Kyle): This is a disgusting hack + switch (type.UnderlyingValue()) { + case 0: + return execution::sql::SqlTypeId::Boolean; + case 1: + return execution::sql::SqlTypeId::TinyInt; + case 2: + return execution::sql::SqlTypeId::SmallInt; + case 3: + return execution::sql::SqlTypeId::Integer; + case 4: + return execution::sql::SqlTypeId::BigInt; + case 5: + return execution::sql::SqlTypeId::Real; + case 6: + return execution::sql::SqlTypeId::Double; + case 7: + return execution::sql::SqlTypeId::Decimal; + case 8: + return execution::sql::SqlTypeId::Date; + case 9: + return execution::sql::SqlTypeId::Timestamp; + case 10: + return execution::sql::SqlTypeId::Char; + case 11: + return execution::sql::SqlTypeId::Varchar; + case 12: + return execution::sql::SqlTypeId::Varbinary; + case 255: + return execution::sql::SqlTypeId::Invalid; + default: + UNREACHABLE("Impossible type_oid_t"); + } +} + void DatabaseCatalog::BootstrapTable(const common::ManagedPointer txn, const table_oid_t table_oid, const namespace_oid_t ns_oid, const std::string &name, const Schema &schema, const common::ManagedPointer table_ptr) { @@ -363,9 +399,14 @@ bool DatabaseCatalog::CreateIndexEntry(const common::ManagedPointer txn, - proc_oid_t proc_oid, - const execution::functions::FunctionContext *func_context) { +common::ManagedPointer DatabaseCatalog::GetFunctionContext( + common::ManagedPointer txn, proc_oid_t proc_oid) { + return pg_proc_.GetProcCtxPtr(txn, proc_oid); +} + +bool DatabaseCatalog::SetFunctionContext(common::ManagedPointer txn, + proc_oid_t proc_oid, + const execution::functions::FunctionContext *func_context) { NOISEPAGE_ASSERT( write_lock_.load() == txn->FinishTime(), "Setting the object's pointer should only be done after successful DDL change request. i.e. this txn " @@ -379,13 +420,6 @@ bool DatabaseCatalog::SetFunctionContextPointer(common::ManagedPointer DatabaseCatalog::GetFunctionContext( - common::ManagedPointer txn, proc_oid_t proc_oid) { - auto proc_ctx = pg_proc_.GetProcCtxPtr(txn, proc_oid); - NOISEPAGE_ASSERT(proc_ctx != nullptr, "Dynamically added UDFs are currently not supported."); - return proc_ctx; -} - std::unique_ptr DatabaseCatalog::GetColumnStatistics( common::ManagedPointer txn, table_oid_t table_oid, col_oid_t col_oid) { return pg_stat_.GetColumnStatistics(txn, common::ManagedPointer(this), table_oid, col_oid); @@ -451,7 +485,7 @@ proc_oid_t DatabaseCatalog::CreateProcedure(const common::ManagedPointer &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, + const std::vector &arg_modes, const type_oid_t rettype, const std::string &src, bool is_aggregate) { if (!TryLock(txn)) return INVALID_PROC_OID; const proc_oid_t proc_oid = proc_oid_t{next_oid_++}; @@ -470,15 +504,78 @@ bool DatabaseCatalog::DropProcedure(const common::ManagedPointer txn, namespace_oid_t procns, const std::string &procname, const std::vector &arg_types) { + if (ContainsUntypedNull(arg_types)) { + // NOTE(Kyle): Should this be a harder error condition (i.e. assertion failure)? + return INVALID_PROC_OID; + } return pg_proc_.GetProcOid(txn, common::ManagedPointer(this), procns, procname, arg_types); } +std::vector> DatabaseCatalog::ResolveProcArgumentTypes( + common::ManagedPointer txn, namespace_oid_t procns, const std::string &procname, + const std::vector &arg_types) { + std::vector> result{}; + ResolveProcArgumentTypes(txn, procns, procname, arg_types, &result); + return result; +} + +void DatabaseCatalog::ResolveProcArgumentTypes(common::ManagedPointer txn, + namespace_oid_t procns, const std::string &procname, + const std::vector &arg_types, + std::vector> *result) { + // If the provided collection of arguments does not contain + // an untyped NULL, all types are fully resolved, bottom out + if (!ContainsUntypedNull(arg_types)) { + if (pg_proc_.GetProcOid(txn, common::ManagedPointer(this), procns, procname, arg_types) != INVALID_PROC_OID) { + result->push_back(arg_types); + } + return; + } + + // Handle the case where an untyped NULL is passed as an argument to the function; + // in this case, we enumerate all possible combinations of types for the NULL argument + + // TODO(Kyle): This is a brittle hack + for (int8_t type_value = static_cast(execution::sql::SqlTypeId::Boolean); + type_value <= static_cast(execution::sql::SqlTypeId::Varbinary); ++type_value) { + const execution::sql::SqlTypeId type = static_cast(type_value); + // Recursively invoke this function; there may be further untyped NULLs + ResolveProcArgumentTypes(txn, procns, procname, ReplaceFirstUntypedNullWith(arg_types, type), result); + } +} + template bool DatabaseCatalog::SetClassPointer(const common::ManagedPointer txn, const ClassOid oid, const Ptr *const pointer, const col_oid_t class_col) { return pg_core_.SetClassPointer(txn, oid, pointer, class_col); } +bool DatabaseCatalog::ContainsUntypedNull(const std::vector &arg_types) const { + const type_oid_t null_oid = GetTypeOidForType(execution::sql::SqlTypeId::Invalid); + return std::any_of(arg_types.cbegin(), arg_types.cend(), [null_oid](const type_oid_t t) { return t == null_oid; }); +} + +std::vector DatabaseCatalog::ReplaceFirstUntypedNullWith(const std::vector &arg_types, + execution::sql::SqlTypeId type) const { + NOISEPAGE_ASSERT(ContainsUntypedNull(arg_types), "Broken precondition"); + const type_oid_t null_oid = GetTypeOidForType(execution::sql::SqlTypeId::Invalid); + auto it = std::find(arg_types.cbegin(), arg_types.cend(), null_oid); + NOISEPAGE_ASSERT(it != arg_types.cend(), "Broken invariant"); + const std::size_t index = std::distance(arg_types.cbegin(), it); + + // Manually construct the modified vector + std::vector modified{}; + modified.reserve(arg_types.size()); + for (std::size_t i = 0; i < arg_types.size(); ++i) { + if (i == index) { + modified.push_back(GetTypeOidForType(type)); + } else { + modified.push_back(arg_types.at(i)); + } + } + return modified; +} + // Template instantiations. #define DEFINE_SET_CLASS_POINTER(ClassOid, Ptr) \ diff --git a/src/catalog/postgres/pg_proc_impl.cpp b/src/catalog/postgres/pg_proc_impl.cpp index 0e6eb310c7..52b0e87deb 100644 --- a/src/catalog/postgres/pg_proc_impl.cpp +++ b/src/catalog/postgres/pg_proc_impl.cpp @@ -78,13 +78,13 @@ bool PgProcImpl::CreateProcedure(const common::ManagedPointer &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, const type_oid_t rettype, + const std::vector &arg_modes, const type_oid_t rettype, const std::string &src, const bool is_aggregate) { NOISEPAGE_ASSERT(args.size() < UINT16_MAX, "Number of arguments must fit in a SMALLINT"); NOISEPAGE_ASSERT(args.size() == arg_types.size(), "Every input arg needs a type."); NOISEPAGE_ASSERT( arg_modes.empty() || (!(std::all_of(arg_modes.cbegin(), arg_modes.cend(), - [](PgProc::ArgModes mode) { return mode == PgProc::ArgModes::IN; })) && + [](PgProc::ArgMode mode) { return mode == PgProc::ArgMode::IN; })) && arg_modes.size() >= args.size()), "argmodes should be empty unless there are modes other than IN, in which case arg_modes must be at " "least equal to the size of args."); @@ -168,7 +168,6 @@ bool PgProcImpl::CreateProcedure(const common::ManagedPointerSet(name_map.at(indexkeycol_oid_t(2)), name_varlen, false); if (auto result = procs_name_index_->Insert(txn, *name_pr, tuple_slot); !result) { - delete[] buffer; return false; } } @@ -186,7 +185,7 @@ bool PgProcImpl::CreateProcedure(const common::ManagedPointer txn, proc_oid_t proc) { - NOISEPAGE_ASSERT(proc != INVALID_PROC_OID, "Invalid oid passed"); + NOISEPAGE_ASSERT(proc != INVALID_PROC_OID, "DropProcedure called with invalid procedure OID"); const auto &name_pri = procs_name_index_->GetProjectedRowInitializer(); const auto &oid_pri = procs_oid_index_->GetProjectedRowInitializer(); @@ -229,7 +228,9 @@ bool PgProcImpl::DropProcedure(const common::ManagedPointerGet(proc_pm[PgProc::PRONAME.oid_], nullptr); auto proc_ns = *table_pr->Get(proc_pm[PgProc::PRONAMESPACE.oid_], nullptr); - auto ctx_ptr = table_pr->AccessWithNullCheck(proc_pm[PgProc::PRO_CTX_PTR.oid_]); + + // Grab a pointer to the procedure context (if present) + auto *ptr_ptr = reinterpret_cast(table_pr->AccessWithNullCheck(proc_pm[PgProc::PRO_CTX_PTR.oid_])); // Delete from pg_proc_name_index. { @@ -241,7 +242,8 @@ bool PgProcImpl::DropProcedure(const common::ManagedPointer(ptr_ptr); txn->RegisterCommitAction([=](transaction::DeferredActionManager *deferred_action_manager) { deferred_action_manager->RegisterDeferredAction( [=]() { deferred_action_manager->RegisterDeferredAction([=]() { delete ctx_ptr; }); }); @@ -302,8 +304,7 @@ common::ManagedPointer PgProcImpl::GetPro NOISEPAGE_ASSERT(result, "Index already verified visibility. This shouldn't fail."); auto *ptr_ptr = (reinterpret_cast(select_pr->AccessWithNullCheck(0))); - NOISEPAGE_ASSERT(nullptr != ptr_ptr, - "GetFunctionContext called on an invalid OID or before SetFunctionContextPointer."); + NOISEPAGE_ASSERT(nullptr != ptr_ptr, "GetFunctionContext called on an invalid OID or before SetFunctionContext."); execution::functions::FunctionContext *ptr = *reinterpret_cast(ptr_ptr); delete[] buffer; @@ -452,12 +453,12 @@ void PgProcImpl::BootstrapProcs(const common::ManagedPointer &args, const std::vector &arg_types, type_oid_t rettype) { - std::vector arg_modes; + std::vector arg_modes; std::vector all_arg_types; if (variadic_type != INVALID_TYPE_OID) { all_arg_types = arg_types; - arg_modes.resize(arg_types.size(), PgProc::ArgModes::IN); // we dont' support OUT or INOUT args right now - arg_modes.back() = PgProc::ArgModes::VARIADIC; // variadic must be the last arg + arg_modes.resize(arg_types.size(), PgProc::ArgMode::IN); // we dont' support OUT or INOUT args right now + arg_modes.back() = PgProc::ArgMode::VARIADIC; // variadic must be the last arg } CreateProcedure(txn, proc_oid_t{dbc->next_oid_++}, procname, PgLanguage::INTERNAL_LANGUAGE_OID, PgNamespace::NAMESPACE_DEFAULT_NAMESPACE_OID, variadic_type, args, arg_types, all_arg_types, @@ -528,6 +529,7 @@ void PgProcImpl::BootstrapProcs(const common::ManagedPointernext_oid_++}, "nprunnersemitint", PgLanguage::INTERNAL_LANGUAGE_OID, PgNamespace::NAMESPACE_DEFAULT_NAMESPACE_OID, INVALID_TYPE_OID, @@ -561,7 +563,7 @@ void PgProcImpl::BootstrapProcContext(const common::ManagedPointerSetFunctionContextPointer(txn, proc_oid, func_context); + const auto retval UNUSED_ATTRIBUTE = dbc->SetFunctionContext(txn, proc_oid, func_context); NOISEPAGE_ASSERT(retval, "Bootstrap operations should not fail"); } @@ -642,6 +644,7 @@ void PgProcImpl::BootstrapProcContexts(const common::ManagedPointerTypeRepr()), func_(func) {} +FunctionDecl::FunctionDecl(const SourcePosition &pos, Identifier name, FunctionLitExpr *func, bool is_lambda) + : Decl(Kind::FunctionDecl, pos, name, func->TypeRepr()), func_(func), is_lambda_(is_lambda) {} // --------------------------------------------------------- // Structure Declaration @@ -90,8 +90,8 @@ bool ComparisonOpExpr::IsLiteralCompareNil(Expr **result) const { // Function Literal Expressions // --------------------------------------------------------- -FunctionLitExpr::FunctionLitExpr(FunctionTypeRepr *type_repr, BlockStmt *body) - : Expr(Kind::FunctionLitExpr, type_repr->Position()), type_repr_(type_repr), body_(body) {} +FunctionLitExpr::FunctionLitExpr(FunctionTypeRepr *type_repr, BlockStmt *body, bool is_lambda) + : Expr(Kind::FunctionLitExpr, type_repr->Position()), type_repr_(type_repr), body_(body), is_lambda_(is_lambda) {} // --------------------------------------------------------- // Call Expression diff --git a/src/execution/ast/ast_clone.cpp b/src/execution/ast/ast_clone.cpp new file mode 100644 index 0000000000..8d3d0aa2e8 --- /dev/null +++ b/src/execution/ast/ast_clone.cpp @@ -0,0 +1,259 @@ +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallString.h" +#include "llvm/Support/raw_os_ostream.h" +#include "llvm/Support/raw_ostream.h" + +#include "execution/ast/ast.h" +#include "execution/ast/ast_clone.h" +#include "execution/ast/ast_visitor.h" +#include "execution/ast/context.h" +#include "execution/ast/type.h" + +namespace noisepage::execution::ast { + +class AstCloneImpl : public AstVisitor { + public: + explicit AstCloneImpl(AstNode *root, AstNodeFactory *factory, Context *old_context, Context *new_context) + : root_(root), factory_{factory}, old_context_{old_context}, new_context_{new_context} {} + + AstNode *Run() { return Visit(root_); } + + // Declare all node visit methods here +#define DECLARE_VISIT_METHOD(type) AstNode *Visit##type(type *node); + AST_NODES(DECLARE_VISIT_METHOD) +#undef DECLARE_VISIT_METHOD + + Identifier CloneIdentifier(const Identifier &ident) { return new_context_->GetIdentifier(ident.GetData()); } + + Identifier CloneIdentifier(const Identifier &&ident) { + (void)old_context_; + return new_context_->GetIdentifier(ident.GetData()); + } + + private: + /** The root of the AST to clone. */ + AstNode *root_; + + /** The AST node factory used to allocate new nodes. */ + AstNodeFactory *factory_; + + /** The AST context of the source AST. */ + Context *old_context_; + + /** The AST context of the destination AST. */ + Context *new_context_; +}; + +AstNode *AstCloneImpl::VisitFile(File *node) { + util::RegionVector decls(new_context_->GetRegion()); + for (auto *decl : node->Declarations()) { + decls.push_back(reinterpret_cast(Visit(decl))); + } + return factory_->NewFile(node->Position(), std::move(decls)); +} + +AstNode *AstCloneImpl::VisitFieldDecl(FieldDecl *node) { + return factory_->NewFieldDecl(node->Position(), CloneIdentifier(node->Name()), + reinterpret_cast(Visit(node->TypeRepr()))); +} + +AstNode *AstCloneImpl::VisitFunctionDecl(FunctionDecl *node) { + return factory_->NewFunctionDecl(node->Position(), CloneIdentifier(node->Name()), + reinterpret_cast(VisitFunctionLitExpr(node->Function()))); +} + +AstNode *AstCloneImpl::VisitVariableDecl(VariableDecl *node) { + return factory_->NewVariableDecl( + node->Position(), CloneIdentifier(node->Name()), + node->TypeRepr() == nullptr ? nullptr : reinterpret_cast(Visit(node->TypeRepr())), + node->Initial() == nullptr ? nullptr : reinterpret_cast(Visit(node->Initial()))); +} + +AstNode *AstCloneImpl::VisitStructDecl(StructDecl *node) { + return factory_->NewStructDecl( + node->Position(), CloneIdentifier(node->Name()), + reinterpret_cast(VisitStructTypeRepr(reinterpret_cast(node->TypeRepr())))); +} + +AstNode *AstCloneImpl::VisitAssignmentStmt(AssignmentStmt *node) { + return factory_->NewAssignmentStmt(node->Position(), reinterpret_cast(Visit(node->Destination())), + reinterpret_cast(Visit(node->Source()))); +} + +AstNode *AstCloneImpl::VisitBlockStmt(BlockStmt *node) { + util::RegionVector stmts(new_context_->GetRegion()); + for (auto *stmt : node->Statements()) { + stmts.push_back(reinterpret_cast(Visit(stmt))); + } + return factory_->NewBlockStmt(node->Position(), node->RightBracePosition(), std::move(stmts)); +} + +AstNode *AstCloneImpl::VisitDeclStmt(DeclStmt *node) { + return factory_->NewDeclStmt(reinterpret_cast(Visit(node->Declaration()))); +} + +AstNode *AstCloneImpl::VisitExpressionStmt(ExpressionStmt *node) { + return factory_->NewExpressionStmt(reinterpret_cast(Visit(node->Expression()))); +} + +AstNode *AstCloneImpl::VisitForStmt(ForStmt *node) { + auto init = node->Init() == nullptr ? nullptr : reinterpret_cast(Visit(node->Init())); + auto next = node->Next() == nullptr ? nullptr : reinterpret_cast(Visit(node->Next())); + return factory_->NewForStmt(node->Position(), init, reinterpret_cast(Visit(node->Condition())), next, + reinterpret_cast(VisitBlockStmt(node->Body()))); +} + +AstNode *AstCloneImpl::VisitForInStmt(ForInStmt *node) { + return factory_->NewForInStmt(node->Position(), reinterpret_cast(Visit(node->Target())), + reinterpret_cast(Visit(node->Iterable())), + reinterpret_cast(VisitBlockStmt(node->Body()))); +} + +AstNode *AstCloneImpl::VisitIfStmt(IfStmt *node) { + auto *else_stmt = node->ElseStmt() == nullptr ? nullptr : reinterpret_cast(Visit((node->ElseStmt()))); + return factory_->NewIfStmt(node->Position(), reinterpret_cast(Visit(node->Condition())), + reinterpret_cast(VisitBlockStmt(node->ThenStmt())), else_stmt); +} + +AstNode *AstCloneImpl::VisitReturnStmt(ReturnStmt *node) { + if (node->Ret() == nullptr) { + return factory_->NewReturnStmt(node->Position(), nullptr); + } + return factory_->NewReturnStmt(node->Position(), reinterpret_cast(Visit(node->Ret()))); +} + +AstNode *AstCloneImpl::VisitCallExpr(CallExpr *node) { + util::RegionVector args(new_context_->GetRegion()); + + for (auto *arg : node->Arguments()) { + args.push_back(reinterpret_cast(Visit(arg))); + } + if (node->GetCallKind() == CallExpr::CallKind::Builtin) { + return factory_->NewBuiltinCallExpr(reinterpret_cast(Visit(node->Function())), std::move(args)); + } + return factory_->NewCallExpr(reinterpret_cast(Visit(node->Function())), std::move(args)); +} + +AstNode *AstCloneImpl::VisitBinaryOpExpr(BinaryOpExpr *node) { + return factory_->NewBinaryOpExpr(node->Position(), node->Op(), reinterpret_cast(Visit(node->Left())), + reinterpret_cast(Visit(node->Right()))); +} + +AstNode *AstCloneImpl::VisitComparisonOpExpr(ComparisonOpExpr *node) { + return factory_->NewComparisonOpExpr(node->Position(), node->Op(), reinterpret_cast(Visit(node->Left())), + reinterpret_cast(Visit(node->Right()))); +} + +AstNode *AstCloneImpl::VisitFunctionLitExpr(FunctionLitExpr *node) { + return factory_->NewFunctionLitExpr(reinterpret_cast(VisitFunctionTypeRepr(node->TypeRepr())), + reinterpret_cast(VisitBlockStmt(node->Body()))); +} + +AstNode *AstCloneImpl::VisitIdentifierExpr(IdentifierExpr *node) { + return factory_->NewIdentifierExpr(node->Position(), CloneIdentifier(node->Name())); +} + +AstNode *AstCloneImpl::VisitImplicitCastExpr(ImplicitCastExpr *node) { return Visit(node->Input()); } + +AstNode *AstCloneImpl::VisitIndexExpr(IndexExpr *node) { + return factory_->NewIndexExpr(node->Position(), reinterpret_cast(Visit(node->Object())), + reinterpret_cast(Visit(node->Index()))); +} + +AstNode *AstCloneImpl::VisitLambdaExpr(LambdaExpr *node) { + util::RegionVector capture_idents(new_context_->GetRegion()); + for (auto ident : node->GetCaptureIdents()) { + capture_idents.push_back(reinterpret_cast(Visit(ident))); + } + return factory_->NewLambdaExpr(node->Position(), + reinterpret_cast(Visit(node->GetFunctionLiteralExpr())), + std::move(capture_idents)); +} + +AstNode *AstCloneImpl::VisitLitExpr(LitExpr *node) { + AstNode *literal = nullptr; + switch (node->GetLiteralKind()) { + case LitExpr::LitKind::Nil: { + literal = factory_->NewNilLiteral(node->Position()); + break; + } + case LitExpr::LitKind::Boolean: { + literal = factory_->NewBoolLiteral(node->Position(), node->BoolVal()); + break; + } + case LitExpr::LitKind::Int: { + literal = factory_->NewIntLiteral(node->Position(), node->Int64Val()); + break; + } + case LitExpr::LitKind::Float: { + literal = factory_->NewFloatLiteral(node->Position(), node->Float64Val()); + break; + } + case LitExpr::LitKind::String: { + literal = factory_->NewStringLiteral(node->Position(), CloneIdentifier(node->StringVal())); + break; + } + } + NOISEPAGE_ASSERT(literal != nullptr, "Unknown literal kind"); + return literal; +} + +AstNode *AstCloneImpl::VisitBreakStmt(BreakStmt *node) { return factory_->NewBreakStmt(node->Position()); } + +AstNode *AstCloneImpl::VisitMemberExpr(MemberExpr *node) { + return factory_->NewMemberExpr(node->Position(), reinterpret_cast(Visit(node->Object())), + reinterpret_cast(Visit(node->Member()))); +} + +AstNode *AstCloneImpl::VisitUnaryOpExpr(UnaryOpExpr *node) { + return factory_->NewUnaryOpExpr(node->Position(), node->Op(), reinterpret_cast(Visit(node->Input()))); +} + +AstNode *AstCloneImpl::VisitBadExpr(BadExpr *node) { return factory_->NewBadExpr(node->Position()); } + +AstNode *AstCloneImpl::VisitStructTypeRepr(StructTypeRepr *node) { + util::RegionVector field_decls(new_context_->GetRegion()); + field_decls.reserve(node->Fields().size()); + for (auto field : node->Fields()) { + field_decls.push_back(reinterpret_cast((VisitFieldDecl(field)))); + } + return factory_->NewStructType(node->Position(), std::move(field_decls)); +} + +AstNode *AstCloneImpl::VisitPointerTypeRepr(PointerTypeRepr *node) { + return factory_->NewPointerType(node->Position(), reinterpret_cast(Visit(node->Base()))); +} + +AstNode *AstCloneImpl::VisitFunctionTypeRepr(FunctionTypeRepr *node) { + util::RegionVector params(new_context_->GetRegion()); + for (auto *param : node->Parameters()) { + params.push_back(reinterpret_cast(VisitFieldDecl(param))); + } + + return factory_->NewFunctionType(node->Position(), std::move(params), + reinterpret_cast(Visit(node->ReturnType()))); +} + +AstNode *AstCloneImpl::VisitArrayTypeRepr(ArrayTypeRepr *node) { + return factory_->NewArrayType(node->Position(), reinterpret_cast(Visit(node->Length())), + reinterpret_cast(Visit(node->ElementType()))); +} + +AstNode *AstCloneImpl::VisitMapTypeRepr(MapTypeRepr *node) { + return factory_->NewMapType(node->Position(), reinterpret_cast(Visit(node->KeyType())), + reinterpret_cast(Visit(node->ValType()))); +} + +AstNode *AstCloneImpl::VisitLambdaTypeRepr(LambdaTypeRepr *node) { + return factory_->NewLambdaType(node->Position(), reinterpret_cast(Visit(node->FunctionType()))); +} + +AstNode *AstClone::Clone(AstNode *node, AstNodeFactory *factory, Context *old_context, Context *new_context) { + AstCloneImpl cloner{node, factory, old_context, new_context}; + return cloner.Run(); +} + +} // namespace noisepage::execution::ast diff --git a/src/execution/ast/ast_dump.cpp b/src/execution/ast/ast_dump.cpp index 38a48e23cc..c99345d996 100644 --- a/src/execution/ast/ast_dump.cpp +++ b/src/execution/ast/ast_dump.cpp @@ -171,6 +171,11 @@ void AstDumperImpl::VisitFunctionDecl(FunctionDecl *node) { DumpExpr(node->Function()); } +void AstDumperImpl::VisitLambdaExpr(LambdaExpr *node) { + DumpNodeCommon(node); + DumpExpr(node->GetFunctionLiteralExpr()); +} + void AstDumperImpl::VisitVariableDecl(VariableDecl *node) { DumpNodeCommon(node); DumpIdentifier(node->Name()); @@ -203,6 +208,8 @@ void AstDumperImpl::VisitBlockStmt(BlockStmt *node) { } } +void AstDumperImpl::VisitBreakStmt(BreakStmt *node) { DumpNodeCommon(node); } + void AstDumperImpl::VisitDeclStmt(DeclStmt *node) { AstVisitor::Visit(node->Declaration()); } void AstDumperImpl::VisitExpressionStmt(ExpressionStmt *node) { AstVisitor::Visit(node->Expression()); } @@ -257,6 +264,11 @@ void AstDumperImpl::VisitCallExpr(CallExpr *node) { } case CallExpr::CallKind::Regular: { out_ << "Regular"; + break; + } + case CallExpr::CallKind::Lambda: { + out_ << "Lambda"; + break; } } } @@ -375,6 +387,11 @@ void AstDumperImpl::VisitMapTypeRepr(MapTypeRepr *node) { DumpExpr(node->ValType()); } +void AstDumperImpl::VisitLambdaTypeRepr(LambdaTypeRepr *node) { + DumpNodeCommon(node); + DumpExpr(node->FunctionType()); +} + std::string AstDump::Dump(AstNode *node) { llvm::SmallString<256> buffer; llvm::raw_svector_ostream stream(buffer); diff --git a/src/execution/ast/ast_pretty_print.cpp b/src/execution/ast/ast_pretty_print.cpp index 330c119237..f2bf0b9579 100644 --- a/src/execution/ast/ast_pretty_print.cpp +++ b/src/execution/ast/ast_pretty_print.cpp @@ -204,6 +204,12 @@ void AstPrettyPrintImpl::VisitMapTypeRepr(MapTypeRepr *node) { Visit(node->ValType()); } +void AstPrettyPrintImpl::VisitLambdaTypeRepr(LambdaTypeRepr *node) { + os_ << "lambda["; + Visit(node->FunctionType()); + os_ << "]"; +} + void AstPrettyPrintImpl::VisitLitExpr(LitExpr *node) { switch (node->GetLiteralKind()) { case LitExpr::LitKind::Nil: @@ -224,6 +230,8 @@ void AstPrettyPrintImpl::VisitLitExpr(LitExpr *node) { } } +void AstPrettyPrintImpl::VisitBreakStmt(BreakStmt *node) { os_ << "break;\n"; } + void AstPrettyPrintImpl::VisitStructTypeRepr(StructTypeRepr *node) { // We want to ensure all types are aligned. Pre-process the fields to // find longest field names, then align as appropriate. @@ -283,6 +291,11 @@ void AstPrettyPrintImpl::VisitIndexExpr(IndexExpr *node) { os_ << "]"; } +void AstPrettyPrintImpl::VisitLambdaExpr(LambdaExpr *node) { + os_ << "lambda "; + VisitFunctionLitExpr(node->GetFunctionLiteralExpr()); +} + void AstPrettyPrintImpl::VisitFunctionTypeRepr(FunctionTypeRepr *node) { os_ << "("; bool first = true; diff --git a/src/execution/ast/context.cpp b/src/execution/ast/context.cpp index 30d42ee2a1..5633f8b0d6 100644 --- a/src/execution/ast/context.cpp +++ b/src/execution/ast/context.cpp @@ -150,8 +150,10 @@ struct Context::Implementation { llvm::DenseMap builtin_types_; llvm::DenseMap builtin_funcs_; llvm::DenseMap pointer_types_; + llvm::DenseMap reference_types_; llvm::DenseMap, ArrayType *> array_types_; llvm::DenseMap, MapType *> map_types_; + llvm::DenseMap lambda_types_; llvm::DenseSet struct_types_; llvm::DenseSet func_types_; @@ -232,6 +234,8 @@ Identifier Context::GetBuiltinType(BuiltinType::Kind kind) { PointerType *Type::PointerTo() { return PointerType::Get(this); } +ReferenceType *Type::ReferenceTo() { return ReferenceType::Get(this); } + // static BuiltinType *BuiltinType::Get(Context *ctx, BuiltinType::Kind kind) { return ctx->Impl()->builtin_types_list_[kind]; } @@ -251,6 +255,19 @@ PointerType *PointerType::Get(Type *base) { return pointer_type; } +// static +ReferenceType *ReferenceType::Get(Type *base) { + Context *ctx = base->GetContext(); + + ReferenceType *&reference_type = ctx->Impl()->reference_types_[base]; + + if (reference_type == nullptr) { + reference_type = new (ctx->GetRegion()) ReferenceType(base); + } + + return reference_type; +} + // static ArrayType *ArrayType::Get(uint64_t length, Type *elem_type) { Context *ctx = elem_type->GetContext(); @@ -287,6 +304,19 @@ Field CreatePaddingElement(uint32_t id, uint32_t size, Context *ctx) { }; // namespace +// static +LambdaType *LambdaType::Get(FunctionType *fn_type) { + Context *ctx = fn_type->GetContext(); + + LambdaType *&lambda_type = ctx->Impl()->lambda_types_[fn_type]; + + if (lambda_type == nullptr) { + lambda_type = new (ctx->GetRegion()) LambdaType(fn_type); + } + + return lambda_type; +} + // static StructType *StructType::Get(Context *ctx, util::RegionVector &&fields) { // Empty structs get an artificial element @@ -380,7 +410,31 @@ FunctionType *FunctionType::Get(util::RegionVector &¶ms, Type *ret) { if (inserted) { // The function type was not in the cache, create the type now and insert it // into the cache - func_type = new (ctx->GetRegion()) FunctionType(std::move(params), ret); + func_type = new (ctx->GetRegion()) FunctionType(std::move(params), ret, false); + *iter = func_type; + } else { + func_type = *iter; + } + + return func_type; +} + +// static +FunctionType *FunctionType::GetLambda(util::RegionVector &¶ms, Type *ret) { + Context *ctx = ret->GetContext(); + + const FunctionTypeKeyInfo::KeyTy key(ret, params); + + auto insert_res = ctx->Impl()->func_types_.insert_as(nullptr, key); + auto iter = insert_res.first; + auto inserted = insert_res.second; + + FunctionType *func_type = nullptr; + + if (inserted) { + // The function type was not in the cache, create the type now and insert it + // into the cache + func_type = new (ctx->GetRegion()) FunctionType(std::move(params), ret, true); *iter = func_type; } else { func_type = *iter; diff --git a/src/execution/ast/type.cpp b/src/execution/ast/type.cpp index 2ea4e771fc..912c0ebdf5 100644 --- a/src/execution/ast/type.cpp +++ b/src/execution/ast/type.cpp @@ -3,6 +3,7 @@ #include #include +#include "execution/ast/context.h" #include "execution/exec/execution_context.h" #include "execution/sql/aggregation_hash_table.h" #include "execution/sql/aggregators.h" @@ -85,10 +86,30 @@ const bool BuiltinType::SIGNED_FLAGS[] = { // Function Type // --------------------------------------------------------- -FunctionType::FunctionType(util::RegionVector &¶ms, Type *ret) +FunctionType::FunctionType(util::RegionVector &¶ms, Type *ret, bool is_lambda) : Type(ret->GetContext(), sizeof(void *), alignof(void *), TypeId::FunctionType), params_(std::move(params)), - ret_(ret) {} + ret_(ret), + is_lambda_(is_lambda) {} + +bool FunctionType::IsEqual(const FunctionType *other) { + if (other->params_.size() != params_.size()) { + return false; + } + + for (auto i = 0UL; i < params_.size(); i++) { + if (params_[i].type_ != other->params_[i].type_) { + return false; + } + } + + return true; +} + +void FunctionType::RegisterCapture() { + NOISEPAGE_ASSERT(captures_ != nullptr, "No capture given"); + params_.emplace_back(GetContext()->GetIdentifier("captures"), captures_); +} // --------------------------------------------------------- // Map Type @@ -100,6 +121,13 @@ MapType::MapType(Type *key_type, Type *val_type) key_type_(key_type), val_type_(val_type) {} +// --------------------------------------------------------- +// Lambda Type +// --------------------------------------------------------- + +LambdaType::LambdaType(FunctionType *fn_type) + : Type(fn_type->GetContext(), fn_type->GetSize(), fn_type->GetAlignment(), TypeId::LambdaType), fn_type_(fn_type) {} + // --------------------------------------------------------- // Struct Type // --------------------------------------------------------- diff --git a/src/execution/ast/type_printer.cpp b/src/execution/ast/type_printer.cpp index 1662f4fc9f..e262944dc0 100644 --- a/src/execution/ast/type_printer.cpp +++ b/src/execution/ast/type_printer.cpp @@ -54,6 +54,11 @@ void TypePrinter::VisitPointerType(const PointerType *type) { Visit(type->GetBase()); } +void TypePrinter::VisitReferenceType(const ReferenceType *type) { + Os() << "&"; + Visit(type->GetBase()); +} + void TypePrinter::VisitStructType(const StructType *type) { Os() << "struct{"; bool first = true; @@ -85,6 +90,12 @@ void execution::ast::TypePrinter::VisitMapType(const MapType *type) { Visit(type->GetValueType()); } +void execution::ast::TypePrinter::VisitLambdaType(const LambdaType *type) { + Os() << "lambda["; + Visit(type->GetFunctionType()); + Os() << "]"; +} + } // namespace // static diff --git a/src/execution/ast/udf/udf_ast_nodes.cpp b/src/execution/ast/udf/udf_ast_nodes.cpp new file mode 100644 index 0000000000..b4ac41615c --- /dev/null +++ b/src/execution/ast/udf/udf_ast_nodes.cpp @@ -0,0 +1,50 @@ +#include + +#include "common/macros.h" +#include "execution/ast/udf/node_types.h" + +namespace noisepage::execution::ast::udf { + +std::string NodeTypeToShortString(NodeType type) { + switch (type) { + case NodeType::VALUE_EXPR: + return "VALUE_EXPR"; + case NodeType::IS_NULL_EXPR: + return "IS_NULL_EXPR"; + case NodeType::VARIABLE_EXPR: + return "VARIABLE_EXPR"; + case NodeType::MEMBER_EXPR: + return "MEMBER_EXPR"; + case NodeType::BINARY_EXPR: + return "BINARY_EXPR"; + case NodeType::CALL_EXPR: + return "CALL_EXPR"; + case NodeType::SEQ_STMT: + return "SEQ_STMT"; + case NodeType::DECL_STMT: + return "DECL_STMT"; + case NodeType::IF_STMT: + return "IF_STMT"; + case NodeType::FORI_STMT: + return "FORI_STMT"; + case NodeType::FORS_STMT: + return "FORS_STMT"; + case NodeType::WHILE_STMT: + return "WHILE_STMT"; + case NodeType::RET_STMT: + return "RET_STMT"; + case NodeType::ASSIGN_STMT: + return "ASSIGN_STMT"; + case NodeType::SQL_STMT: + return "SQL_STMT"; + case NodeType::DYNAMIC_SQL_STMT: + return "DYNAMIC_SQL_STMT"; + case NodeType::FUNCTION: + return "FUNCTION"; + default: + NOISEPAGE_ASSERT(false, "Impossible node type"); + return "INVALID"; + } +} + +} // namespace noisepage::execution::ast::udf diff --git a/src/execution/compiler/codegen.cpp b/src/execution/compiler/codegen.cpp index 01bf29d060..0fb54ed525 100644 --- a/src/execution/compiler/codegen.cpp +++ b/src/execution/compiler/codegen.cpp @@ -194,6 +194,10 @@ ast::Expr *CodeGen::Float32Type() const { return BuiltinType(ast::BuiltinType::F ast::Expr *CodeGen::Float64Type() const { return BuiltinType(ast::BuiltinType::Float64); } +ast::Expr *CodeGen::LambdaType(ast::Expr *fn_type) { + return context_->GetNodeFactory()->NewLambdaType(position_, fn_type); +} + ast::Expr *CodeGen::PointerType(ast::Expr *base_type_repr) const { // Create the type representation auto *type_repr = context_->GetNodeFactory()->NewPointerType(position_, base_type_repr); @@ -367,6 +371,12 @@ ast::Expr *CodeGen::AccessStructMember(ast::Expr *object, ast::Identifier member return context_->GetNodeFactory()->NewMemberExpr(position_, object, MakeExpr(member)); } +ast::Stmt *CodeGen::Break() { + ast::Stmt *break_stmt = context_->GetNodeFactory()->NewBreakStmt(position_); + NewLine(); + return break_stmt; +} + ast::Stmt *CodeGen::Return() { return Return(nullptr); } ast::Stmt *CodeGen::Return(ast::Expr *ret) { diff --git a/src/execution/compiler/compilation_context.cpp b/src/execution/compiler/compilation_context.cpp index 07ff7eb1f9..af9a3e345b 100644 --- a/src/execution/compiler/compilation_context.cpp +++ b/src/execution/compiler/compilation_context.cpp @@ -84,17 +84,22 @@ std::atomic unique_ids{0}; } // namespace CompilationContext::CompilationContext(ExecutableQuery *query, query_id_t query_id, catalog::CatalogAccessor *accessor, - const CompilationMode mode, const exec::ExecutionSettings &settings) + const CompilationMode mode, const exec::ExecutionSettings &settings, + ast::LambdaExpr *output_callback) : unique_id_(unique_ids++), query_id_(query_id), query_(query), mode_(mode), codegen_(query_->GetContext(), accessor), query_state_var_(codegen_.MakeIdentifier("queryState")), - query_state_type_(codegen_.MakeIdentifier("QueryState")), + query_state_type_(codegen_.MakeIdentifier( + output_callback == nullptr ? "QueryState" : output_callback->GetName().GetString() + "QueryState")), query_state_(query_state_type_, [this](CodeGen *codegen) { return codegen->MakeExpr(query_state_var_); }), + output_callback_(output_callback), counters_enabled_(settings.GetIsCountersEnabled()), - pipeline_metrics_enabled_(settings.GetIsPipelineMetricsEnabled()) {} + pipeline_metrics_enabled_((output_callback != nullptr) ? false : settings.GetIsPipelineMetricsEnabled()) {} + +// TODO(Kyle): Why disable pipeline metrics whenever we have an output callback? ast::FunctionDecl *CompilationContext::GenerateInitFunction() { const auto name = codegen_.MakeIdentifier(GetFunctionPrefix() + "_Init"); @@ -200,17 +205,19 @@ void CompilationContext::GeneratePlan(const planner::AbstractPlanNode &plan, // static std::unique_ptr CompilationContext::Compile( const planner::AbstractPlanNode &plan, const exec::ExecutionSettings &exec_settings, - catalog::CatalogAccessor *accessor, const CompilationMode mode, std::optional override_qid, - common::ManagedPointer plan_meta_data) { - // The query we're generating code for. - auto query = std::make_unique(plan, exec_settings, accessor->GetTxn()->StartTime()); + catalog::CatalogAccessor *accessor, CompilationMode mode, std::optional override_qid, + common::ManagedPointer plan_meta_data, ast::LambdaExpr *output_callback, + common::ManagedPointer context) { + // The query for which we're generating code + auto query = std::make_unique(plan, exec_settings, accessor->GetTxn()->StartTime(), context.Get()); if (override_qid.has_value()) { query->SetQueryId(override_qid.value()); } // Generate the plan for the query - CompilationContext ctx(query.get(), query->GetQueryId(), accessor, mode, exec_settings); + CompilationContext ctx{query.get(), query->GetQueryId(), accessor, mode, exec_settings, output_callback}; ctx.GeneratePlan(plan, plan_meta_data); + query->SetQueryStateType(ctx.query_state_.GetType()); // Done return query; @@ -229,10 +236,8 @@ void CompilationContext::PrepareOut(const planner::AbstractPlanNode &plan, Pipel } void CompilationContext::Prepare(const planner::AbstractPlanNode &plan, Pipeline *pipeline) { - std::unique_ptr translator; - NOISEPAGE_ASSERT(ops_.find(&plan) == ops_.end(), "plan already prepared"); - + std::unique_ptr translator; switch (plan.GetPlanNodeType()) { case planner::PlanNodeType::AGGREGATE: { const auto &aggregation = dynamic_cast(plan); @@ -436,7 +441,14 @@ ExpressionTranslator *CompilationContext::LookupTranslator(const parser::Abstrac return nullptr; } -std::string CompilationContext::GetFunctionPrefix() const { return "Query" + std::to_string(unique_id_); } +std::string CompilationContext::GetFunctionPrefix() const { + // If an output callback is present, we prefix + // each function with the callback name + if (HasOutputCallback()) { + return fmt::format("{}Query{}", output_callback_->GetName().GetString(), std::to_string(unique_id_)); + } + return fmt::format("Query{}", std::to_string(unique_id_)); +} util::RegionVector CompilationContext::QueryParams() const { ast::Expr *state_type = codegen_.PointerType(codegen_.MakeExpr(query_state_type_)); diff --git a/src/execution/compiler/executable_query.cpp b/src/execution/compiler/executable_query.cpp index b8346fcf6c..7e7e92f25a 100644 --- a/src/execution/compiler/executable_query.cpp +++ b/src/execution/compiler/executable_query.cpp @@ -4,10 +4,13 @@ #include "common/error/error_code.h" #include "common/error/exception.h" +#include "execution/ast/ast.h" +#include "execution/ast/ast_dump.h" #include "execution/ast/context.h" #include "execution/compiler/compiler.h" #include "execution/exec/execution_context.h" #include "execution/sema/error_reporter.h" +#include "execution/vm/bytecode_function_info.h" #include "execution/vm/module.h" #include "loggers/execution_logger.h" #include "self_driving/modeling/operating_unit.h" @@ -21,9 +24,16 @@ namespace noisepage::execution::compiler { // //===----------------------------------------------------------------------===// -ExecutableQuery::Fragment::Fragment(std::vector &&functions, std::vector &&teardown_fn, +ExecutableQuery::Fragment::Fragment(std::vector &&functions, std::vector &&teardown_fns, std::unique_ptr module) - : functions_(std::move(functions)), teardown_fn_(std::move(teardown_fn)), module_(std::move(module)) {} + : functions_{std::move(functions)}, teardown_fns_{std::move(teardown_fns)}, module_{std::move(module)} {} + +ExecutableQuery::Fragment::Fragment(std::vector &&functions, std::vector &&teardown_fns, + std::unique_ptr module, ast::File *file) + : functions_{std::move(functions)}, + teardown_fns_{std::move(teardown_fns)}, + module_{std::move(module)}, + file_{file} {} ExecutableQuery::Fragment::~Fragment() = default; @@ -43,7 +53,7 @@ void ExecutableQuery::Fragment::Run(byte query_state[], vm::ExecutionMode mode) try { func(query_state); } catch (const AbortException &e) { - for (const auto &teardown_name : teardown_fn_) { + for (const auto &teardown_name : teardown_fns_) { if (!module_->GetFunction(teardown_name, mode, &func)) { throw EXECUTION_EXCEPTION(fmt::format("Could not find teardown function '{}' in query fragment.", func_name), common::ErrorCode::ERRCODE_INTERNAL_ERROR); @@ -55,6 +65,11 @@ void ExecutableQuery::Fragment::Run(byte query_state[], vm::ExecutionMode mode) } } +std::optional ExecutableQuery::Fragment::GetFunctionMetadata(const std::string &name) const { + const auto *metadata = module_->GetFuncInfoByName(name); + return (metadata == nullptr) ? std::nullopt : std::make_optional(metadata); +} + const vm::ModuleMetadata &ExecutableQuery::Fragment::GetModuleMetadata() const { return module_->GetMetadata(); } //===----------------------------------------------------------------------===// @@ -80,54 +95,66 @@ void ExecutableQuery::SetPipelineOperatingUnits(std::unique_ptr("errors_region")), - context_region_(std::make_unique("context_region")), - errors_(std::make_unique(errors_region_.get())), - ast_context_(std::make_unique(context_region_.get(), errors_.get())), - query_state_size_(0), - pipeline_operating_units_(nullptr), - query_id_(query_identifier++) {} + transaction::timestamp_t timestamp, ast::Context *context) + : plan_{plan}, + exec_settings_{exec_settings}, + timestamp_{timestamp}, + context_region_{std::make_unique("context_region")}, + errors_region_{std::make_unique("errors_region")}, + errors_{std::make_unique(errors_region_.get())}, + ast_context_{context}, + query_state_size_{0}, + pipeline_operating_units_{nullptr}, + query_id_{query_identifier++} { + owns_ast_context_ = (ast_context_ == nullptr); + if (owns_ast_context_) { + ast_context_ = new ast::Context(context_region_.get(), errors_.get()); + } +} ExecutableQuery::ExecutableQuery(const std::string &contents, const common::ManagedPointer exec_ctx, bool is_file, - size_t query_state_size, const exec::ExecutionSettings &exec_settings, - transaction::timestamp_t timestamp) + std::size_t query_state_size, const exec::ExecutionSettings &exec_settings, + transaction::timestamp_t timestamp, ast::Context *context) // TODO(WAN): Giant hack for the plan. The whole point is that you have no plan. - : plan_(reinterpret_cast(exec_settings)), - exec_settings_(exec_settings), - timestamp_(timestamp) { - context_region_ = std::make_unique("context_region"); - errors_region_ = std::make_unique("error_region"); - errors_ = std::make_unique(errors_region_.get()); - ast_context_ = std::make_unique(context_region_.get(), errors_.get()); + : plan_{reinterpret_cast(exec_settings)}, + exec_settings_{exec_settings}, + timestamp_{timestamp}, + context_region_{std::make_unique("context_region")}, + errors_region_{std::make_unique("error_region")}, + errors_{std::make_unique(errors_region_.get())}, + ast_context_{context}, + query_state_size_{0}, + pipeline_operating_units_{nullptr}, + query_id_{query_identifier++} { + owns_ast_context_ = (ast_context_ == nullptr); + if (owns_ast_context_) { + ast_context_ = new ast::Context(context_region_.get(), errors_.get()); + } // Let's scan the source - std::string source; + std::string source{}; if (is_file) { auto file = llvm::MemoryBuffer::getFile(contents); if (std::error_code error = file.getError()) { EXECUTION_LOG_ERROR("There was an error reading file '{}': {}", contents, error.message()); return; } - // Copy the source into a temporary, compile, and run source = (*file)->getBuffer().str(); } else { source = contents; } - auto input = Compiler::Input("tpl_source", ast_context_.get(), &source, exec_settings.GetCompilerSettings()); + auto input = Compiler::Input("tpl_source", ast_context_, &source, exec_settings.GetCompilerSettings()); auto module = compiler::Compiler::RunCompilationSimple(input); std::vector functions{"main"}; - std::vector teardown_functions; + std::vector teardown_functions{}; + auto fragment = std::make_unique(std::move(functions), std::move(teardown_functions), std::move(module)); - std::vector> fragments; + std::vector> fragments{}; fragments.emplace_back(std::move(fragment)); Setup(std::move(fragments), query_state_size, nullptr); @@ -139,8 +166,11 @@ ExecutableQuery::ExecutableQuery(const std::string &contents, } // Needed because we forward-declare classes used as template types to std::unique_ptr<> -ExecutableQuery::~ExecutableQuery() = default; - +ExecutableQuery::~ExecutableQuery() { + if (owns_ast_context_) { + delete ast_context_; + } +} void ExecutableQuery::Setup(std::vector> &&fragments, const std::size_t query_state_size, std::unique_ptr pipeline_operating_units) { NOISEPAGE_ASSERT( @@ -161,9 +191,9 @@ void ExecutableQuery::Run(common::ManagedPointer exec_ct // First, allocate the query state and move the execution context into it. auto query_state = std::make_unique(query_state_size_); *reinterpret_cast(query_state.get()) = exec_ctx.Get(); - exec_ctx->SetQueryState(query_state.get()); - exec_ctx->SetExecutionMode(static_cast(mode)); + exec_ctx->SetExecutionMode(mode); + exec_ctx->SetQueryState(query_state.get()); exec_ctx->SetPipelineOperatingUnits(GetPipelineOperatingUnits()); exec_ctx->SetQueryId(query_id_); @@ -171,10 +201,37 @@ void ExecutableQuery::Run(common::ManagedPointer exec_ct for (const auto &fragment : fragments_) { fragment->Run(query_state.get(), mode); } +} + +std::vector ExecutableQuery::GetFunctionNames() const { + std::vector function_names{}; + for (const auto &f : fragments_) { + const auto &frag_functions = f->GetFunctions(); + function_names.insert(function_names.end(), frag_functions.cbegin(), frag_functions.cend()); + } + return function_names; +} - // We do not currently re-use ExecutionContexts. However, this is unset to help ensure - // we don't *intentionally* retain any dangling pointers. - exec_ctx->SetQueryState(nullptr); +std::vector ExecutableQuery::GetFunctionMetadata() const { + std::vector function_meta{}; + for (const auto &f : fragments_) { + const auto function_names = f->GetFunctions(); + for (const auto &function_name : function_names) { + auto meta = f->GetFunctionMetadata(function_name); + NOISEPAGE_ASSERT(meta.has_value(), "Broken invariant"); + function_meta.push_back(meta.value()); + } + } + return function_meta; +} + +std::vector ExecutableQuery::GetDecls() const { + std::vector decls{}; + for (const auto &f : fragments_) { + const auto &frag_decls = f->GetFile()->Declarations(); + decls.insert(decls.end(), frag_decls.cbegin(), frag_decls.cend()); + } + return decls; } } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/executable_query_builder.cpp b/src/execution/compiler/executable_query_builder.cpp index 6eb80c0e70..edd2a73602 100644 --- a/src/execution/compiler/executable_query_builder.cpp +++ b/src/execution/compiler/executable_query_builder.cpp @@ -69,7 +69,7 @@ std::unique_ptr ExecutableQueryFragmentBuilder::Compi teardown_names.push_back(decl->Name().GetString()); } return std::make_unique(std::move(step_functions_), std::move(teardown_names), - std::move(module)); + std::move(module), generated_file); } } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/expression/expression_translator.cpp b/src/execution/compiler/expression/expression_translator.cpp index 043772dced..3b732435a7 100644 --- a/src/execution/compiler/expression/expression_translator.cpp +++ b/src/execution/compiler/expression/expression_translator.cpp @@ -15,4 +15,16 @@ ast::Expr *ExpressionTranslator::GetExecutionContextPtr() const { return compilation_context_->GetExecutionContextPtrFromQueryState(); } +void ExpressionTranslator::DefineHelperFunctions(util::RegionVector *decls) { + for (auto child : expr_.GetChildren()) { + compilation_context_->LookupTranslator(*child)->DefineHelperFunctions(decls); + } +} + +void ExpressionTranslator::DefineHelperStructs(util::RegionVector *decls) { + for (auto child : expr_.GetChildren()) { + compilation_context_->LookupTranslator(*child)->DefineHelperStructs(decls); + } +} + } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/expression/function_translator.cpp b/src/execution/compiler/expression/function_translator.cpp index 2035a6a7fe..40c432407b 100644 --- a/src/execution/compiler/expression/function_translator.cpp +++ b/src/execution/compiler/expression/function_translator.cpp @@ -1,6 +1,8 @@ #include "execution/compiler/expression/function_translator.h" #include "catalog/catalog_accessor.h" +#include "execution/ast/ast.h" +#include "execution/ast/ast_clone.h" #include "execution/compiler/compilation_context.h" #include "execution/compiler/work_context.h" #include "execution/functions/function_context.h" @@ -21,11 +23,8 @@ ast::Expr *FunctionTranslator::DeriveValue(WorkContext *ctx, const ColumnValuePr const auto &func_expr = GetExpressionAs(); auto proc_oid = func_expr.GetProcOid(); auto func_context = codegen->GetCatalogAccessor()->GetFunctionContext(proc_oid); - if (!func_context->IsBuiltin()) { - UNREACHABLE("User-defined functions are not supported"); - } - std::vector params; + std::vector params{}; if (func_context->IsExecCtxRequired()) { params.push_back(GetExecutionContextPtr()); } @@ -34,7 +33,54 @@ ast::Expr *FunctionTranslator::DeriveValue(WorkContext *ctx, const ColumnValuePr params.push_back(derived_expr); } + if (!func_context->IsBuiltin()) { + const auto identifier_expr = main_fn_; + std::vector args{params.cbegin(), params.cend()}; + return GetCodeGen()->Call(identifier_expr, args); + } + return codegen->CallBuiltin(func_context->GetBuiltin(), params); } +void FunctionTranslator::DefineHelperFunctions(util::RegionVector *decls) { + ExpressionTranslator::DefineHelperFunctions(decls); + auto proc_oid = GetExpressionAs().GetProcOid(); + auto func_context = GetCodeGen()->GetCatalogAccessor()->GetFunctionContext(proc_oid); + if (func_context->IsBuiltin()) { + return; + } + auto *file = reinterpret_cast( + ast::AstClone::Clone(func_context->GetFile(), GetCodeGen()->GetAstContext()->GetNodeFactory(), nullptr, + GetCodeGen()->GetAstContext().Get())); + auto udf_decls = file->Declarations(); + main_fn_ = udf_decls.back()->Name(); + std::size_t num_added = 0; + for (ast::Decl *udf_decl : udf_decls) { + if (udf_decl->IsFunctionDecl()) { + decls->insert(decls->begin() + num_added, udf_decl->As()); + num_added++; + } + } +} + +void FunctionTranslator::DefineHelperStructs(util::RegionVector *decls) { + ExpressionTranslator::DefineHelperStructs(decls); + auto proc_oid = GetExpressionAs().GetProcOid(); + auto func_context = GetCodeGen()->GetCatalogAccessor()->GetFunctionContext(proc_oid); + if (func_context->IsBuiltin()) { + return; + } + auto *file = reinterpret_cast( + ast::AstClone::Clone(func_context->GetFile(), GetCodeGen()->GetAstContext()->GetNodeFactory(), nullptr, + GetCodeGen()->GetAstContext().Get())); + auto udf_decls = file->Declarations(); + std::size_t num_added = 0; + for (ast::Decl *udf_decl : udf_decls) { + if (udf_decl->IsStructDecl()) { + decls->insert(decls->begin() + num_added, udf_decl->As()); + num_added++; + } + } +} + } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/function_builder.cpp b/src/execution/compiler/function_builder.cpp index a517dd62cf..8e1872a591 100644 --- a/src/execution/compiler/function_builder.cpp +++ b/src/execution/compiler/function_builder.cpp @@ -1,29 +1,55 @@ #include "execution/compiler/function_builder.h" #include "execution/ast/ast_node_factory.h" +#include "execution/ast/context.h" #include "execution/compiler/codegen.h" namespace noisepage::execution::compiler { FunctionBuilder::FunctionBuilder(CodeGen *codegen, ast::Identifier name, util::RegionVector &¶ms, - ast::Expr *ret_type) - : codegen_(codegen), - name_(name), - params_(std::move(params)), - ret_type_(ret_type), - start_(codegen->GetPosition()), - statements_(codegen->MakeEmptyBlock()), - decl_(nullptr) {} - -FunctionBuilder::~FunctionBuilder() { Finish(); } - -ast::Expr *FunctionBuilder::GetParameterByPosition(uint32_t param_idx) { + ast::Expr *return_type) + : type_{FunctionType::FUNCTION}, + codegen_{codegen}, + name_{name}, + params_{std::move(params)}, + captures_{codegen_->GetAstContext()->GetRegion()}, + return_type_{return_type}, + start_{codegen->GetPosition()}, + statements_{codegen->MakeEmptyBlock()}, + decl_{std::in_place_type, nullptr} {} + +FunctionBuilder::FunctionBuilder(CodeGen *codegen, util::RegionVector &¶ms, + util::RegionVector &&captures, ast::Expr *return_type) + : type_{FunctionType::CLOSURE}, + codegen_{codegen}, + params_{std::move(params)}, + captures_{std::move(captures)}, + return_type_{return_type}, + start_{codegen->GetPosition()}, + statements_{codegen->MakeEmptyBlock()}, + decl_{std::in_place_type, nullptr} {} + +FunctionBuilder::~FunctionBuilder() { + if (type_ == FunctionType::FUNCTION) { + Finish(); + } +} + +ast::Expr *FunctionBuilder::GetParameterByPosition(const std::size_t param_idx) { if (param_idx < params_.size()) { return codegen_->MakeExpr(params_[param_idx]->Name()); } return nullptr; } +std::vector FunctionBuilder::GetParameters() const { + std::vector parameters{}; + parameters.reserve(params_.size()); + std::transform(params_.cbegin(), params_.cend(), std::back_inserter(parameters), + [this](const ast::FieldDecl *p) -> ast::Expr * { return codegen_->MakeExpr(p->Name()); }); + return parameters; +} + void FunctionBuilder::Append(ast::Stmt *stmt) { // Append the statement to the block. statements_->AppendStatement(stmt); @@ -36,8 +62,12 @@ void FunctionBuilder::Append(ast::Expr *expr) { Append(codegen_->GetFactory()->N void FunctionBuilder::Append(ast::VariableDecl *decl) { Append(codegen_->GetFactory()->NewDeclStmt(decl)); } ast::FunctionDecl *FunctionBuilder::Finish(ast::Expr *ret) { - if (decl_ != nullptr) { - return decl_; + NOISEPAGE_ASSERT(type_ == FunctionType::FUNCTION, + "Attempt to call FunctionBuilder::Finish on non-function-type builder"); + NOISEPAGE_ASSERT(std::holds_alternative(decl_), "Broken invariant"); + auto *declaration = std::get(decl_); + if (declaration != nullptr) { + return declaration; } NOISEPAGE_ASSERT(ret == nullptr || statements_->IsEmpty() || !statements_->GetLast()->IsReturnStmt(), @@ -45,23 +75,49 @@ ast::FunctionDecl *FunctionBuilder::Finish(ast::Expr *ret) { "with an explicit return expression, or use the factory to manually append a return " "statement and call FunctionBuilder::Finish() with a null return."); - // Add the return. + // Add the return if (!statements_->IsEmpty() && !statements_->GetLast()->IsReturnStmt()) { Append(codegen_->GetFactory()->NewReturnStmt(codegen_->GetPosition(), ret)); } - // Finalize everything. + // Finalize everything statements_->SetRightBracePosition(codegen_->GetPosition()); - // Build the function's type. - auto func_type = codegen_->GetFactory()->NewFunctionType(start_, std::move(params_), ret_type_); + // Build the function's type + auto func_type = codegen_->GetFactory()->NewFunctionType(start_, std::move(params_), return_type_); - // Create the declaration. + // Create the declaration auto func_lit = codegen_->GetFactory()->NewFunctionLitExpr(func_type, statements_); decl_ = codegen_->GetFactory()->NewFunctionDecl(start_, name_, func_lit); + return std::get(decl_); +} + +noisepage::execution::ast::LambdaExpr *FunctionBuilder::FinishClosure(ast::Expr *ret) { + NOISEPAGE_ASSERT(type_ == FunctionType::CLOSURE, + "Attempt to call FuncionBuilder::FinishClosure on non-closure-type builder"); + NOISEPAGE_ASSERT(std::holds_alternative(decl_), "Broken invariant"); + auto *declaration = std::get(decl_); + if (declaration != nullptr) { + return declaration; + } - // Done - return decl_; + NOISEPAGE_ASSERT(ret == nullptr || statements_->IsEmpty() || !statements_->GetLast()->IsReturnStmt(), + "Double-return at end of function. You should either call FunctionBuilder::FinishClosure() " + "with an explicit return expression, or use the factory to manually append a return " + "statement and call FunctionBuilder::FinishClosure() with a null return."); + // Add the return + if (!statements_->IsEmpty() && !statements_->GetLast()->IsReturnStmt()) { + Append(codegen_->GetFactory()->NewReturnStmt(codegen_->GetPosition(), ret)); + } + // Finalize everything + statements_->SetRightBracePosition(codegen_->GetPosition()); + // Build the function's type + auto func_type = codegen_->GetFactory()->NewFunctionType(start_, std::move(params_), return_type_); + + // Create the declaration + auto func_lit = codegen_->GetFactory()->NewFunctionLitExpr(func_type, statements_); + decl_ = codegen_->GetFactory()->NewLambdaExpr(start_, func_lit, std::move(captures_)); + return std::get(decl_); } } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/operator/hash_join_translator.cpp b/src/execution/compiler/operator/hash_join_translator.cpp index 99a8942386..d2efb9d2ae 100644 --- a/src/execution/compiler/operator/hash_join_translator.cpp +++ b/src/execution/compiler/operator/hash_join_translator.cpp @@ -390,7 +390,7 @@ void HashJoinTranslator::CheckJoinPredicate(WorkContext *ctx, FunctionBuilder *f FillProbeRow(ctx, function, codegen->MakeExpr(probe_row_var_)); // joinConsumer(queryState, pipelineState, buildRow, probeRow); std::initializer_list args{GetQueryStatePtr(), - codegen->MakeExpr(GetPipeline()->GetPipelineStateVar()), + codegen->MakeExpr(GetPipeline()->GetPipelineStateName()), codegen->MakeExpr(build_row_var_), codegen->AddressOf(probe_row)}; function->Append(codegen->Call(join_consumer_, args)); } else { @@ -462,7 +462,7 @@ void HashJoinTranslator::CollectUnmatchedLeftRows(FunctionBuilder *function) con } // joinConsumer(queryState, pipelineState, buildRow, probeRow); std::initializer_list args{GetQueryStatePtr(), - codegen->MakeExpr(GetPipeline()->GetPipelineStateVar()), + codegen->MakeExpr(GetPipeline()->GetPipelineStateName()), codegen->MakeExpr(build_row_var_), codegen->AddressOf(probe_row)}; function->Append(codegen->Call(join_consumer_, args)); } diff --git a/src/execution/compiler/operator/operator_translator.cpp b/src/execution/compiler/operator/operator_translator.cpp index e433699f28..e4d36d4776 100644 --- a/src/execution/compiler/operator/operator_translator.cpp +++ b/src/execution/compiler/operator/operator_translator.cpp @@ -44,6 +44,18 @@ ast::Expr *OperatorTranslator::GetOutput(WorkContext *context, uint32_t attr_idx return context->DeriveValue(*output_expression, this); } +void OperatorTranslator::DefineHelperFunctions(util::RegionVector *decls) { + for (const auto &output_column : GetPlan().GetOutputSchema()->GetColumns()) { + GetCompilationContext()->LookupTranslator(*output_column.GetExpr())->DefineHelperFunctions(decls); + } +} + +void OperatorTranslator::DefineHelperStructs(util::RegionVector *decls) { + for (const auto &output_column : GetPlan().GetOutputSchema()->GetColumns()) { + GetCompilationContext()->LookupTranslator(*output_column.GetExpr())->DefineHelperStructs(decls); + } +} + ast::Expr *OperatorTranslator::GetChildOutput(WorkContext *context, uint32_t child_idx, uint32_t attr_idx) const { // Check valid child. if (child_idx >= plan_.GetChildrenSize()) { @@ -76,6 +88,18 @@ ast::Expr *OperatorTranslator::GetMemoryPool() const { return GetCodeGen()->ExecCtxGetMemoryPool(GetExecutionContext()); } +ast::Identifier OperatorTranslator::MakeLocalIdentifier(std::string_view name) const { + const auto identifier = fmt::format("{}", name); + return GetCodeGen()->MakeFreshIdentifier(identifier); +} + +ast::Identifier OperatorTranslator::MakeGlobalIdentifier(std::string_view name) const { + const auto identifier = GetCompilationContext()->HasOutputCallback() + ? fmt::format("{}{}", GetCompilationContext()->GetFunctionPrefix(), name) + : fmt::format("{}", name); + return GetCodeGen()->MakeFreshIdentifier(identifier); +} + void OperatorTranslator::GetAllChildOutputFields(const uint32_t child_index, const std::string &field_name_prefix, util::RegionVector *fields) const { auto *codegen = GetCodeGen(); diff --git a/src/execution/compiler/operator/output_translator.cpp b/src/execution/compiler/operator/output_translator.cpp index 6c4ac63df2..990b341755 100644 --- a/src/execution/compiler/operator/output_translator.cpp +++ b/src/execution/compiler/operator/output_translator.cpp @@ -18,7 +18,8 @@ OutputTranslator::OutputTranslator(const planner::AbstractPlanNode &plan, Compil Pipeline *pipeline) : OperatorTranslator(plan, compilation_context, pipeline, selfdriving::ExecutionOperatingUnitType::OUTPUT), output_var_(GetCodeGen()->MakeFreshIdentifier("outRow")), - output_struct_(GetCodeGen()->MakeFreshIdentifier("OutputStruct")) { + output_struct_(GetCodeGen()->MakeFreshIdentifier( + "OutputStruct" + std::to_string(compilation_context->GetQueryId().UnderlyingValue()))) { // Prepare the child. compilation_context->Prepare(plan, pipeline); @@ -28,6 +29,10 @@ OutputTranslator::OutputTranslator(const planner::AbstractPlanNode &plan, Compil } void OutputTranslator::InitializePipelineState(const Pipeline &pipeline, FunctionBuilder *function) const { + if (HasOutputCallback()) { + return; + } + auto exec_ctx = GetExecutionContext(); auto *new_call = GetCodeGen()->CallBuiltin(ast::Builtin::ResultBufferNew, {exec_ctx}); function->Append(GetCodeGen()->Assign(output_buffer_.Get(GetCodeGen()), new_call)); @@ -36,6 +41,10 @@ void OutputTranslator::InitializePipelineState(const Pipeline &pipeline, Functio } void OutputTranslator::TearDownPipelineState(const Pipeline &pipeline, FunctionBuilder *function) const { + if (HasOutputCallback()) { + return; + } + auto out_buffer = output_buffer_.Get(GetCodeGen()); ast::Expr *alloc_call = GetCodeGen()->CallBuiltin(ast::Builtin::ResultBufferFree, {out_buffer}); function->Append(GetCodeGen()->MakeStmt(alloc_call)); @@ -43,20 +52,38 @@ void OutputTranslator::TearDownPipelineState(const Pipeline &pipeline, FunctionB void OutputTranslator::PerformPipelineWork(noisepage::execution::compiler::WorkContext *context, noisepage::execution::compiler::FunctionBuilder *function) const { - // First generate the call @resultBufferAllocRow(execCtx) - auto out_buffer = output_buffer_.Get(GetCodeGen()); - ast::Expr *alloc_call = GetCodeGen()->CallBuiltin(ast::Builtin::ResultBufferAllocOutRow, {out_buffer}); - ast::Expr *cast_call = GetCodeGen()->PtrCast(output_struct_, alloc_call); + ast::Expr *cast_call; + if (HasOutputCallback()) { + auto output = GetCodeGen()->MakeFreshIdentifier("output_row"); + auto *row_alloc = GetCodeGen()->DeclareVarNoInit(output, GetCodeGen()->MakeExpr(output_struct_)); + function->Append(row_alloc); + cast_call = GetCodeGen()->AddressOf(GetCodeGen()->MakeExpr(output)); + } else { + auto out_buffer = output_buffer_.Get(GetCodeGen()); + ast::Expr *alloc_call = GetCodeGen()->CallBuiltin(ast::Builtin::ResultBufferAllocOutRow, {out_buffer}); + cast_call = GetCodeGen()->PtrCast(output_struct_, alloc_call); + } + function->Append(GetCodeGen()->DeclareVar(output_var_, nullptr, cast_call)); const auto child_translator = GetCompilationContext()->LookupTranslator(GetPlan()); // Now fill up the output row // For each column in the output, set out.col_i = col_i + std::vector callback_args{GetExecutionContext()}; for (uint32_t attr_idx = 0; attr_idx < GetPlan().GetOutputSchema()->NumColumns(); attr_idx++) { ast::Identifier attr_name = GetCodeGen()->MakeIdentifier(OUTPUT_COL_PREFIX + std::to_string(attr_idx)); ast::Expr *lhs = GetCodeGen()->AccessStructMember(GetCodeGen()->MakeExpr(output_var_), attr_name); ast::Expr *rhs = child_translator->GetOutput(context, attr_idx); function->Append(GetCodeGen()->Assign(lhs, rhs)); + if (HasOutputCallback()) { + callback_args.push_back(lhs); + } + } + + // If an output callback is present, append the callback invocation + if (HasOutputCallback()) { + auto *callback = GetCompilationContext()->GetOutputCallback(); + function->Append(GetCodeGen()->Call(callback->As()->GetName(), callback_args)); } CounterAdd(function, num_output_, 1); @@ -79,6 +106,10 @@ void OutputTranslator::EndParallelPipelineWork(const Pipeline &pipeline, Functio } void OutputTranslator::FinishPipelineWork(const Pipeline &pipeline, FunctionBuilder *function) const { + if (GetCompilationContext()->GetOutputCallback() != nullptr) { + return; + } + auto out_buffer = output_buffer_.Get(GetCodeGen()); function->Append(GetCodeGen()->CallBuiltin(ast::Builtin::ResultBufferFinalize, {out_buffer})); @@ -105,4 +136,6 @@ void OutputTranslator::DefineHelperStructs(util::RegionVector decls->push_back(codegen->DeclareStruct(output_struct_, std::move(fields))); } +bool OutputTranslator::HasOutputCallback() const { return GetCompilationContext()->HasOutputCallback(); } + } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/operator/static_aggregation_translator.cpp b/src/execution/compiler/operator/static_aggregation_translator.cpp index c5ce108f45..5db69f2e1f 100644 --- a/src/execution/compiler/operator/static_aggregation_translator.cpp +++ b/src/execution/compiler/operator/static_aggregation_translator.cpp @@ -12,15 +12,19 @@ namespace noisepage::execution::compiler { namespace { constexpr char AGG_ATTR_PREFIX[] = "agg_term_attr"; +constexpr char AGG_ROW_VAR[] = "aggRow"; +constexpr char AGG_PAYLOAD_TYPE[] = "AggPayload"; +constexpr char AGG_VALUES_TYPE[] = "AggValues"; +constexpr char AGG_MERGE_FUNC[] = "MergeAggregates"; } // namespace StaticAggregationTranslator::StaticAggregationTranslator(const planner::AggregatePlanNode &plan, CompilationContext *compilation_context, Pipeline *pipeline) : OperatorTranslator(plan, compilation_context, pipeline, selfdriving::ExecutionOperatingUnitType::DUMMY), - agg_row_var_(GetCodeGen()->MakeFreshIdentifier("aggRow")), - agg_payload_type_(GetCodeGen()->MakeFreshIdentifier("AggPayload")), - agg_values_type_(GetCodeGen()->MakeFreshIdentifier("AggValues")), - merge_func_(GetCodeGen()->MakeFreshIdentifier("MergeAggregates")), + agg_row_var_(GetCodeGen()->MakeFreshIdentifier(AGG_ROW_VAR)), + agg_payload_type_(MakeGlobalIdentifier(AGG_PAYLOAD_TYPE)), + agg_values_type_(MakeGlobalIdentifier(AGG_VALUES_TYPE)), + merge_func_(MakeGlobalIdentifier(AGG_MERGE_FUNC)), build_pipeline_(this, Pipeline::Parallelism::Parallel) { NOISEPAGE_ASSERT(plan.GetGroupByTerms().empty(), "Global aggregations shouldn't have grouping keys"); NOISEPAGE_ASSERT(plan.GetChildrenSize() == 1, "Global aggregations should only have one child"); diff --git a/src/execution/compiler/pipeline.cpp b/src/execution/compiler/pipeline.cpp index a788db5917..e20d21afad 100644 --- a/src/execution/compiler/pipeline.cpp +++ b/src/execution/compiler/pipeline.cpp @@ -15,6 +15,7 @@ #include "loggers/execution_logger.h" #include "metrics/metrics_defs.h" #include "planner/plannodes/abstract_plan_node.h" +#include "planner/plannodes/output_schema.h" #include "spdlog/fmt/fmt.h" namespace noisepage::execution::compiler { @@ -25,23 +26,28 @@ Pipeline::Pipeline(CompilationContext *ctx) : id_(ctx->RegisterPipeline(this)), compilation_context_(ctx), codegen_(compilation_context_->GetCodeGen()), - state_var_(codegen_->MakeIdentifier("pipelineState")), - state_(codegen_->MakeIdentifier(fmt::format("P{}_State", id_)), - [this](CodeGen *codegen) { return codegen_->MakeExpr(state_var_); }), + state_(codegen_->MakeIdentifier(fmt::format("{}_Pipeline{}_State", ctx->GetFunctionPrefix(), id_)), + [this](CodeGen *codegen) { return codegen_->MakeExpr(GetPipelineStateName()); }), driver_(nullptr), parallelism_(Parallelism::Parallel), check_parallelism_(true), - nested_(false) {} + nested_(false) { + if (HasOutputCallback()) { + UpdateParallelism(Parallelism::Serial); + } +} Pipeline::Pipeline(OperatorTranslator *op, Pipeline::Parallelism parallelism) : Pipeline(op->GetCompilationContext()) { - UpdateParallelism(parallelism); + if (!HasOutputCallback()) { + UpdateParallelism(parallelism); + } RegisterStep(op); } void Pipeline::RegisterStep(OperatorTranslator *op) { - NOISEPAGE_ASSERT(std::count(steps_.begin(), steps_.end(), op) == 0, + NOISEPAGE_ASSERT(std::count(steps_.cbegin(), steps_.cend(), op) == 0, "Duplicate registration of operator in pipeline."); - auto num_steps = steps_.size(); + const auto num_steps = steps_.size(); if (num_steps > 0) { auto last_step = common::ManagedPointer(steps_[num_steps - 1]); // TODO(WAN): MAYDAY CHECK WITH LIN AND PRASHANTH, did ordering of these change? @@ -75,26 +81,6 @@ StateDescriptor::Entry Pipeline::DeclarePipelineStateEntry(const std::string &na return state.DeclareStateEntry(codegen_, name, type_repr); } -std::string Pipeline::CreatePipelineFunctionName(const std::string &func_name) const { - auto result = fmt::format("{}_Pipeline{}", compilation_context_->GetFunctionPrefix(), id_); - if (!func_name.empty()) { - result += "_" + func_name; - } - return result; -} - -ast::Identifier Pipeline::GetSetupPipelineStateFunctionName() const { - return codegen_->MakeIdentifier(CreatePipelineFunctionName("InitPipelineState")); -} - -ast::Identifier Pipeline::GetTearDownPipelineStateFunctionName() const { - return codegen_->MakeIdentifier(CreatePipelineFunctionName("TearDownPipelineState")); -} - -ast::Identifier Pipeline::GetWorkFunctionName() const { - return codegen_->MakeIdentifier(CreatePipelineFunctionName(IsParallel() ? "ParallelWork" : "SerialWork")); -} - void Pipeline::InjectStartResourceTracker(FunctionBuilder *builder, bool is_hook) const { if (compilation_context_->IsPipelineMetricsEnabled()) { auto *exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); @@ -142,36 +128,34 @@ void Pipeline::InjectEndResourceTracker(FunctionBuilder *builder, bool is_hook) } } -util::RegionVector Pipeline::PipelineParams() const { - // The main query parameters. - util::RegionVector query_params = compilation_context_->QueryParams(); - // Tag on the pipeline state. - auto &state = GetPipelineStateDescriptor(); - ast::Expr *pipeline_state = codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName())); - query_params.push_back(codegen_->MakeField(state_var_, pipeline_state)); - return query_params; -} - void Pipeline::LinkSourcePipeline(Pipeline *dependency) { NOISEPAGE_ASSERT(dependency != nullptr, "Source cannot be null"); + // Add pipeline `dependency` as a nested pipeline dependencies_.push_back(dependency); + // Remove ourselves from the nested pipeline of dependency, if present + // TODO(Kyle): Is this possible? If so, is this a broken invariant? + if (std::find(dependency->nested_pipelines_.begin(), dependency->nested_pipelines_.end(), this) != + dependency->nested_pipelines_.end()) { + dependency->nested_pipelines_.erase( + std::remove(dependency->nested_pipelines_.begin(), dependency->nested_pipelines_.end(), this), + dependency->nested_pipelines_.end()); + } } -void Pipeline::LinkNestedPipeline(Pipeline *nested_pipeline, const OperatorTranslator *op) { - NOISEPAGE_ASSERT(nested_pipeline != nullptr, "Nested pipeline cannot be null"); +void Pipeline::LinkNestedPipeline(Pipeline *pipeline, const OperatorTranslator *op) { + NOISEPAGE_ASSERT(pipeline != nullptr, "Nested pipeline cannot be null"); // if pipeline is in my dependencies let's not do this to avoid circularity - if (std::find(dependencies_.begin(), dependencies_.end(), nested_pipeline) == dependencies_.end()) { - nested_pipeline->nested_pipelines_.push_back(this); + if (std::find(dependencies_.begin(), dependencies_.end(), pipeline) == dependencies_.end()) { + pipeline->nested_pipelines_.push_back(this); } - if (!nested_pipeline->nested_) { - nested_pipeline->nested_ = true; - nested_pipeline->parent_ = this; + if (!pipeline->IsNestedPipeline()) { + pipeline->MarkNested(); // add to pipeline params - size_t i = 0; + std::size_t i = 0; for (auto &col : op->GetPlan().GetOutputSchema()->GetColumns()) { - NOISEPAGE_ASSERT(nested_pipeline->extra_pipeline_params_.empty(), + NOISEPAGE_ASSERT(pipeline->extra_pipeline_params_.empty(), "We do not support two pipelines nesting the same pipeline yet"); - nested_pipeline->extra_pipeline_params_.push_back( + pipeline->extra_pipeline_params_.push_back( codegen_->MakeField(codegen_->MakeIdentifier("row" + std::to_string(i++)), codegen_->PointerType(codegen_->TplType(sql::GetTypeId(col.GetType()))))); } @@ -208,8 +192,7 @@ void Pipeline::Prepare(const exec::ExecutionSettings &exec_settings) { ast::Expr *type = codegen_->BuiltinType(ast::BuiltinType::ExecOUFeatureVector); oufeatures_ = DeclarePipelineStateEntry("execFeatures", type); } - - // if this pipeline is nested, it doesn't own its pipeline state + // If this pipeline is nested, it doesn't own its pipeline state state_.ConstructFinalType(codegen_); // Finalize the execution mode. We choose serial execution if ANY of the below @@ -228,10 +211,12 @@ void Pipeline::Prepare(const exec::ExecutionSettings &exec_settings) { // Pretty print. { - std::string result; + std::string result{}; bool first = true; for (auto iter = Begin(), end = End(); iter != end; ++iter) { - if (!first) result += " --> "; + if (!first) { + result += " --> "; + } first = false; std::string plan_type = planner::PlanNodeTypeToString((*iter)->GetPlan().GetPlanNodeType()); std::transform(plan_type.begin(), plan_type.end(), plan_type.begin(), ::tolower); @@ -244,12 +229,60 @@ void Pipeline::Prepare(const exec::ExecutionSettings &exec_settings) { prepared_ = true; } -ast::FunctionDecl *Pipeline::GenerateSetupPipelineStateFunction() const { - auto name = GetSetupPipelineStateFunctionName(); - FunctionBuilder builder(codegen_, name, PipelineParams(), codegen_->Nil()); +/* ---------------------------------------------------------------------------- + Pipeline Generation: Top-Level +----------------------------------------------------------------------------- */ + +void Pipeline::GeneratePipeline(ExecutableQueryFragmentBuilder *builder) const { + NOISEPAGE_ASSERT(!(IsNestedPipeline() && HasOutputCallback()), + "Single pipeline cannot both be nested and have an output callback"); + + // Declare the pipeline state. + builder->DeclareStruct(state_.GetType()); + // Generate pipeline state initialization and tear-down functions. + builder->DeclareFunction(GenerateInitPipelineStateFunction()); + builder->DeclareFunction(GenerateTearDownPipelineStateFunction()); + + auto teardown = GenerateTearDownPipelineFunction(); + + // Declare top-level functions + builder->DeclareFunction(GenerateInitPipelineFunction()); + builder->DeclareFunction(GeneratePipelineWorkFunction()); + builder->DeclareFunction(GenerateRunPipelineFunction()); + builder->DeclareFunction(teardown); + + if (HasOutputCallback()) { + auto run_all = GeneratePipelineRunAllFunction(); + builder->DeclareFunction(run_all); + builder->RegisterStep(run_all); + } else if (!IsNestedPipeline()) { + // Register the main init, run, tear-down functions as steps, in that order. + builder->RegisterStep(GenerateInitPipelineFunction()); + builder->RegisterStep(GenerateRunPipelineFunction()); + builder->RegisterStep(teardown); + } + + // For nested pipelines, do not register any of the top-level + // pipeline functions as steps in the Fragment builder, and + // instead rely on a call to Pipeline::CallNestedRunPipeline + + builder->AddTeardownFn(teardown); +} + +/* ---------------------------------------------------------------------------- + Pipeline Generation: State Setup + Teardown +----------------------------------------------------------------------------- */ + +ast::FunctionDecl *Pipeline::GenerateInitPipelineStateFunction() const { + /** + * fun QueryX_PipelineY_InitPipelineState(*QueryState, *PipelineState) + */ + + auto name = GetInitPipelineStateFunctionName(); + FunctionBuilder builder{codegen_, name, PipelineParams(), codegen_->Nil()}; { // Request new scope for the function. - CodeGen::CodeScope code_scope(codegen_); + CodeGen::CodeScope code_scope{codegen_}; for (auto *op : steps_) { op->InitializePipelineState(*this, &builder); } @@ -258,11 +291,15 @@ ast::FunctionDecl *Pipeline::GenerateSetupPipelineStateFunction() const { } ast::FunctionDecl *Pipeline::GenerateTearDownPipelineStateFunction() const { + /** + * fun QueryX_PipelineY_TeardownPipelineState(*QueryState, *PipelineState) + */ + auto name = GetTearDownPipelineStateFunctionName(); - FunctionBuilder builder(codegen_, name, PipelineParams(), codegen_->Nil()); + FunctionBuilder builder{codegen_, name, PipelineParams(), codegen_->Nil()}; { // Request new scope for the function. - CodeGen::CodeScope code_scope(codegen_); + CodeGen::CodeScope code_scope{codegen_}; for (auto *op : steps_) { op->TearDownPipelineState(*this, &builder); } @@ -277,146 +314,127 @@ ast::FunctionDecl *Pipeline::GenerateTearDownPipelineStateFunction() const { return builder.Finish(); } -ast::FunctionDecl *Pipeline::GenerateInitPipelineFunction() const { - auto query_state = compilation_context_->GetQueryState(); - auto name = codegen_->MakeIdentifier(CreatePipelineFunctionName("Init")); - auto params = compilation_context_->QueryParams(); - ast::FieldDecl *p_state_ptr = nullptr; - auto &state = GetPipelineStateDescriptor(); - // Need to initialize this to stop compiler from complaining - uint32_t p_state_ind = -1; - if (nested_) { - p_state_ptr = codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), - codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); - params.push_back(p_state_ptr); - p_state_ind = params.size() - 1; - } - FunctionBuilder builder(codegen_, name, std::move(params), codegen_->Nil()); - { - CodeGen::CodeScope code_scope(codegen_); - // var tls = @execCtxGetTLS(exec_ctx) - ast::Expr *exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); - ast::Identifier tls = codegen_->MakeFreshIdentifier("threadStateContainer"); - builder.Append(codegen_->DeclareVarWithInit(tls, codegen_->ExecCtxGetTLS(exec_ctx))); - // @tlsReset(tls, @sizeOf(ThreadState), init, tearDown, queryState) - ast::Expr *state_ptr = query_state->GetStatePointer(codegen_); - auto &state = GetPipelineStateDescriptor(); - if (!nested_) { - // TLS reset if this is not a nested pipeline, i.e: we OWN the TLS - builder.Append(codegen_->TLSReset(codegen_->MakeExpr(tls), state.GetTypeName(), - GetSetupPipelineStateFunctionName(), GetTearDownPipelineStateFunctionName(), - state_ptr)); - } else { - // no TLS reset if pipeline is nested - auto pipeline_state = builder.GetParameterByPosition(p_state_ind); - builder.Append(codegen_->Call(GetSetupPipelineStateFunctionName(), {state_ptr, pipeline_state})); - } - } - return builder.Finish(); -} +/* ---------------------------------------------------------------------------- + Pipeline Generation: RunAll +----------------------------------------------------------------------------- */ -ast::FunctionDecl *Pipeline::GeneratePipelineWorkFunction() const { - auto params = PipelineParams(); - for (auto field : extra_pipeline_params_) { - params.push_back(field); - } +ast::FunctionDecl *Pipeline::GeneratePipelineRunAllFunction() const { + NOISEPAGE_ASSERT(HasOutputCallback(), "Should only generate RunAll function for pipeline with output callback"); + /** + * fun QueryX_PipelineY_RunAll(*QueryState, Closure) + */ - if (IsParallel()) { - auto additional_params = driver_->GetWorkerParams(); - params.insert(params.end(), additional_params.begin(), additional_params.end()); - } + const ast::Identifier name = GetRunAllPipelineFunctionName(); + util::RegionVector params{QueryParams()}; + params.push_back(codegen_->MakeField( + GetOutputCallback()->GetName(), codegen_->LambdaType(GetOutputCallback()->GetFunctionLiteralExpr()->TypeRepr()))); - FunctionBuilder builder(codegen_, GetWorkFunctionName(), std::move(params), codegen_->Nil()); - { - // Begin a new code scope for fresh variables. - CodeGen::CodeScope code_scope(codegen_); - if (IsParallel()) { - for (auto *op : steps_) { - op->BeginParallelPipelineWork(*this, &builder); - } + FunctionBuilder builder{codegen_, name, std::move(params), codegen_->Nil()}; - InjectStartResourceTracker(&builder, false); - } + { + CodeGen::CodeScope code_scope{codegen_}; - // Create the working context and push it through the pipeline. - WorkContext context(compilation_context_, *this); - (*Begin())->PerformPipelineWork(&context, &builder); + ast::Identifier pipeline_state_id = codegen_->MakeFreshIdentifier("pipelineState"); + builder.Append(codegen_->DeclareVarNoInit(pipeline_state_id, state_.GetType()->TypeRepr())); - if (IsParallel()) { - for (auto *op : steps_) { - op->EndParallelPipelineWork(*this, &builder); - } + NOISEPAGE_ASSERT(builder.GetParameterCount() == 2, "Unexpected parameter count for RunAll function"); + auto *query_state = builder.GetParameterByPosition(0); + auto *pipeline_state = codegen_->AddressOf(pipeline_state_id); + auto *callback = builder.GetParameterByPosition(1); - InjectEndResourceTracker(&builder, false); - } + builder.Append(codegen_->Call(GetInitPipelineFunctionName(), {query_state, pipeline_state})); + builder.Append(codegen_->Call(GetRunPipelineFunctionName(), {query_state, pipeline_state, callback})); + builder.Append(codegen_->Call(GetTeardownPipelineFunctionName(), {query_state, pipeline_state})); } - return builder.Finish(); -} -std::vector Pipeline::GenerateSingleRunPipelineFunction() const { - NOISEPAGE_ASSERT(!nested_, "can't call a nested pipeline like this"); - return { - codegen_->Call(GetInitPipelineFunctionName(), {compilation_context_->GetQueryState()->GetStatePointer(codegen_)}), - codegen_->Call(GetRunPipelineFunctionName(), {compilation_context_->GetQueryState()->GetStatePointer(codegen_)}), - codegen_->Call(GetTeardownPipelineFunctionName(), - {compilation_context_->GetQueryState()->GetStatePointer(codegen_)})}; + return builder.Finish(); } -void Pipeline::CallNestedRunPipelineFunction(WorkContext *ctx, const OperatorTranslator *op, - FunctionBuilder *function) const { - std::vector stmts; - auto p_state = codegen_->MakeFreshIdentifier("nested_state"); - auto p_state_ptr = codegen_->AddressOf(p_state); +/* ---------------------------------------------------------------------------- + Pipeline Generation: Steps +----------------------------------------------------------------------------- */ - std::vector params_vec = {compilation_context_->GetQueryState()->GetStatePointer(codegen_)}; - params_vec.push_back(p_state_ptr); +ast::FunctionDecl *Pipeline::GenerateInitPipelineFunction() const { + /** + * Common Case: + * fun QueryX_PipelineY_Init(queryState) + * + * Nested Pipeline: + * fun QueryX_PipelineY_InitPipeline(queryState, pipelineState) + * + * Output Callback: + * fun QueryX_PipelineY_InitPipeline(queryState, pipelineState) + */ - for (size_t i = 0; i < op->GetPlan().GetOutputSchema()->GetColumns().size(); i++) { - params_vec.push_back(codegen_->AddressOf(op->GetOutput(ctx, i))); + auto query_state = compilation_context_->GetQueryState(); + const ast::Identifier name = GetInitPipelineFunctionName(); + + util::RegionVector params{QueryParams()}; + if (IsNestedPipeline() || HasOutputCallback()) { + const auto &state = GetPipelineStateDescriptor(); + ast::FieldDecl *pipeline_state_ptr = codegen_->MakeField( + codegen_->MakeFreshIdentifier("pipelineState"), codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); + params.push_back(pipeline_state_ptr); } - function->Append(codegen_->DeclareVarNoInit(p_state, codegen_->MakeExpr(GetPipelineStateDescriptor().GetTypeName()))); - function->Append(codegen_->Call(GetInitPipelineFunctionName(), - {compilation_context_->GetQueryState()->GetStatePointer(codegen_), p_state_ptr})); - function->Append(codegen_->Call(GetRunPipelineFunctionName(), params_vec)); - function->Append(codegen_->Call(GetTeardownPipelineFunctionName(), - {compilation_context_->GetQueryState()->GetStatePointer(codegen_), p_state_ptr})); -} - -ast::Identifier Pipeline::GetInitPipelineFunctionName() const { - return codegen_->MakeIdentifier(CreatePipelineFunctionName("Init")); -} - -ast::Identifier Pipeline::GetTeardownPipelineFunctionName() const { - return codegen_->MakeIdentifier(CreatePipelineFunctionName("TearDown")); -} - -ast::Identifier Pipeline::GetRunPipelineFunctionName() const { - return codegen_->MakeIdentifier(CreatePipelineFunctionName("Run")); -} + FunctionBuilder builder{codegen_, name, std::move(params), codegen_->Nil()}; + { + CodeGen::CodeScope code_scope{codegen_}; + ast::Expr *state_ptr = query_state->GetStatePointer(codegen_); -ast::Expr *Pipeline::GetNestedInputArg(uint32_t index) const { - NOISEPAGE_ASSERT(nested_, "Asking for input arg on non-nested pipeline"); - NOISEPAGE_ASSERT(index < extra_pipeline_params_.size(), - "Asking for input arg on non-nested pipeline that doesn't exist"); - return codegen_->UnaryOp(parsing::Token::Type::STAR, codegen_->MakeExpr(extra_pipeline_params_[index]->Name())); + if (IsNestedPipeline() || HasOutputCallback()) { + // No TLS reset in nested pipelines + // NOTE: Assumes the pipeline state is always the final parameter + const auto pipeline_state_index = builder.GetParameterCount() - 1; + auto *pipeline_state = builder.GetParameterByPosition(pipeline_state_index); + builder.Append(codegen_->Call(GetInitPipelineStateFunctionName(), {state_ptr, pipeline_state})); + } else { + auto &state = GetPipelineStateDescriptor(); + ast::Expr *exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); + ast::Identifier tls = codegen_->MakeFreshIdentifier("threadStateContainer"); + // var tls = @execCtxGetTLS(exec_ctx) + builder.Append(codegen_->DeclareVarWithInit(tls, codegen_->ExecCtxGetTLS(exec_ctx))); + // @tlsReset(tls, @sizeOf(ThreadState), init, tearDown, queryState) + builder.Append(codegen_->TLSReset(codegen_->MakeExpr(tls), state.GetTypeName(), + GetInitPipelineStateFunctionName(), GetTearDownPipelineStateFunctionName(), + state_ptr)); + } + } + return builder.Finish(); } ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction() const { + /** + * Common Case: + * fun QueryX_PipelineY_Run(queryState) + * + * Nested Pipeline: + * fun QueryX_PipelineY_Run(queryState, pipelineState) + * + * Output Callback: + * fun QueryX_PipelineY_Run(queryState, outputCallback) + */ + bool started_tracker = false; - auto name = codegen_->MakeIdentifier(CreatePipelineFunctionName("Run")); - auto params = compilation_context_->QueryParams(); - if (nested_) { - // if we're nested we also take in the pipeline state as an argument - params.push_back(codegen_->MakeField(state_var_, codegen_->PointerType(state_.GetTypeName()))); + const ast::Identifier name = GetRunPipelineFunctionName(); + + util::RegionVector params{QueryParams()}; + if (IsNestedPipeline() || HasOutputCallback()) { + params.push_back(codegen_->MakeField(GetPipelineStateName(), codegen_->PointerType(state_.GetTypeName()))); } - for (auto field : extra_pipeline_params_) { + for (auto *field : extra_pipeline_params_) { params.push_back(field); } - FunctionBuilder builder(codegen_, name, std::move(params), codegen_->Nil()); + if (HasOutputCallback()) { + params.push_back( + codegen_->MakeField(GetOutputCallback()->GetName(), + codegen_->LambdaType(GetOutputCallback()->GetFunctionLiteralExpr()->TypeRepr()))); + } + + FunctionBuilder builder{codegen_, name, std::move(params), codegen_->Nil()}; { // Begin a new code scope for fresh variables. - CodeGen::CodeScope code_scope(codegen_); + CodeGen::CodeScope code_scope{codegen_}; // TODO(abalakum): This shouldn't actually be dependent on order and the loop can be simplified // after issue #1154 is fixed @@ -425,35 +443,39 @@ ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction() const { (*iter)->BeginPipelineWork(*this, &builder); } - // var pipelineState = @tlsGetCurrentThreadState(...) - auto exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); - auto tls = codegen_->ExecCtxGetTLS(exec_ctx); - auto state_type = GetPipelineStateDescriptor().GetTypeName(); - auto state = codegen_->TLSAccessCurrentThreadState(tls, state_type); - builder.Append(codegen_->DeclareVarWithInit(state_var_, state)); + // Nested pipelines and pipelines with callbacks have their + // pipeline state passed as an argument to this function + if (!IsNestedPipeline() && !HasOutputCallback()) { + // var pipelineState = @tlsGetCurrentThreadState(...) + auto exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); + auto tls = codegen_->ExecCtxGetTLS(exec_ctx); + auto state_type = GetPipelineStateDescriptor().GetTypeName(); + auto state = codegen_->TLSAccessCurrentThreadState(tls, state_type); + builder.Append(codegen_->DeclareVarWithInit(GetPipelineStateName(), state)); + } // Launch pipeline work. if (IsParallel()) { - driver_->LaunchWork(&builder, GetWorkFunctionName()); + driver_->LaunchWork(&builder, GetPipelineWorkFunctionName()); } else { - // SerialWork(queryState, pipelineState) - // if(!nested_) { + // Serial pipeline InjectStartResourceTracker(&builder, false); started_tracker = true; - // } - - std::vector args = {builder.GetParameterByPosition(0), codegen_->MakeExpr(state_var_)}; - if (nested_) { - // if this is a nested pipeline we also pop in all the extra arguments that the nested pipeline takes - // i starts at arg size because the first two args of run pipeline are query state and pipelien state - size_t i = args.size(); - ast::Expr *arg = builder.GetParameterByPosition(i++); - while (arg != nullptr) { - args.push_back(arg); - arg = builder.GetParameterByPosition(i++); - } + + std::vector params{codegen_->MakeExpr(GetQueryStateName()), + codegen_->MakeExpr(GetPipelineStateName())}; + if (IsNestedPipeline()) { + const auto run_params = builder.GetParameters(); + auto begin = run_params.cbegin(); + std::advance(begin, params.size()); + params.insert(params.end(), begin, run_params.cend()); + } + + if (HasOutputCallback()) { + params.push_back(codegen_->MakeExpr(GetOutputCallback()->GetName())); } - builder.Append(codegen_->Call(GetWorkFunctionName(), args)); + + builder.Append(codegen_->Call(GetPipelineWorkFunctionName(), params)); } // TODO(abalakum): This shouldn't actually be dependent on order and the loop can be simplified @@ -471,74 +493,201 @@ ast::FunctionDecl *Pipeline::GenerateRunPipelineFunction() const { return builder.Finish(); } +ast::FunctionDecl *Pipeline::GeneratePipelineWorkFunction() const { + util::RegionVector params{PipelineParams()}; + for (auto *field : extra_pipeline_params_) { + params.push_back(field); + } + + // NOTE(Kyle): This is hacky... + if (IsParallel()) { + auto additional_params = driver_->GetWorkerParams(); + params.insert(params.end(), additional_params.cbegin(), additional_params.cend()); + } else if (HasOutputCallback()) { + params.push_back( + codegen_->MakeField(GetOutputCallback()->GetName(), + codegen_->LambdaType(GetOutputCallback()->GetFunctionLiteralExpr()->TypeRepr()))); + } + + FunctionBuilder builder{codegen_, GetPipelineWorkFunctionName(), std::move(params), codegen_->Nil()}; + { + // Begin a new code scope for fresh variables. + CodeGen::CodeScope code_scope{codegen_}; + if (IsParallel()) { + for (auto *op : steps_) { + op->BeginParallelPipelineWork(*this, &builder); + } + + InjectStartResourceTracker(&builder, false); + } + + // Create the working context and push it through the pipeline. + WorkContext context(compilation_context_, *this); + (*Begin())->PerformPipelineWork(&context, &builder); + + if (IsParallel()) { + for (auto *op : steps_) { + op->EndParallelPipelineWork(*this, &builder); + } + + InjectEndResourceTracker(&builder, false); + } + } + return builder.Finish(); +} + ast::FunctionDecl *Pipeline::GenerateTearDownPipelineFunction() const { - auto name = codegen_->MakeIdentifier(CreatePipelineFunctionName("TearDown")); - auto params = compilation_context_->QueryParams(); - ast::FieldDecl *p_state_ptr = nullptr; - auto &state = GetPipelineStateDescriptor(); - // Need to initialize this to stop compiler from complaining - uint32_t p_state_index = -1; - if (nested_) { - // if we're nested we also take in the pipeline state as an argument - p_state_ptr = codegen_->MakeField(codegen_->MakeFreshIdentifier("pipeline_state"), - codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName()))); - params.push_back(p_state_ptr); - p_state_index = params.size() - 1; + /** + * Common Case: + * QueryX_PipelineY_Teardown(queryState) + * + * Nested Pipeline: + * QueryX_PipelineY_Teardown(queryState, pipelineState) + * + * Output Callback: + * QueryX_PipelineY_Teardown(queryState, pipelineState) + */ + + const ast::Identifier name = GetTeardownPipelineFunctionName(); + + util::RegionVector params{QueryParams()}; + if (IsNestedPipeline() || HasOutputCallback()) { + ast::FieldDecl *pipeline_state = + codegen_->MakeField(codegen_->MakeFreshIdentifier("pipelineState"), + codegen_->PointerType(codegen_->MakeExpr(GetPipelineStateDescriptor().GetTypeName()))); + params.push_back(pipeline_state); } - FunctionBuilder builder(codegen_, name, std::move(params), codegen_->Nil()); + FunctionBuilder builder{codegen_, name, std::move(params), codegen_->Nil()}; { // Begin a new code scope for fresh variables. - CodeGen::CodeScope code_scope(codegen_); - if (!nested_) { - // Tear down thread local state if parallel pipeline - // again this is only applicable if we OWN the TLS (not nested) + CodeGen::CodeScope code_scope{codegen_}; + if (IsNestedPipeline() || HasOutputCallback()) { + // NOTE: Assumes pipeline state is always final parameter to call + const auto pipeline_state_index = builder.GetParameterCount() - 1; + auto query_state = compilation_context_->GetQueryState()->GetStatePointer(codegen_); + auto pipeline_state = builder.GetParameterByPosition(pipeline_state_index); + auto call = codegen_->Call(GetTearDownPipelineStateFunctionName(), {query_state, pipeline_state}); + builder.Append(codegen_->MakeStmt(call)); + } else { + // Tear down thread local state if parallel pipeline. ast::Expr *exec_ctx = compilation_context_->GetExecutionContextPtrFromQueryState(); builder.Append(codegen_->TLSClear(codegen_->ExecCtxGetTLS(exec_ctx))); - auto call = codegen_->CallBuiltin(ast::Builtin::EnsureTrackersStopped, {exec_ctx}); builder.Append(codegen_->MakeStmt(call)); - } else { - // nested pipelines don't own the TLS so we won't - auto query_state = compilation_context_->GetQueryState(); - auto state_ptr = query_state->GetStatePointer(codegen_); - - auto pipeline_state = builder.GetParameterByPosition(p_state_index); - auto call = codegen_->Call(GetTearDownPipelineStateFunctionName(), {state_ptr, pipeline_state}); - builder.Append(codegen_->MakeStmt(call)); } } return builder.Finish(); } -void Pipeline::GeneratePipeline(ExecutableQueryFragmentBuilder *builder) const { - // Declare the pipeline state. - builder->DeclareStruct(state_.GetType()); - // Generate pipeline state initialization and tear-down functions. - builder->DeclareFunction(GenerateSetupPipelineStateFunction()); - builder->DeclareFunction(GenerateTearDownPipelineStateFunction()); +/* ---------------------------------------------------------------------------- + Pipeline Function Parameter Definition +----------------------------------------------------------------------------- */ - // Generate main pipeline logic. - builder->DeclareFunction(GeneratePipelineWorkFunction()); +util::RegionVector Pipeline::QueryParams() const { return compilation_context_->QueryParams(); } - auto init_fn = GenerateInitPipelineFunction(); - auto run_fn = GenerateRunPipelineFunction(); - auto teardown = GenerateTearDownPipelineFunction(); +util::RegionVector Pipeline::PipelineParams() const { + // The main query parameters + util::RegionVector pipeline_params{QueryParams()}; + // Tag on the pipeline state + auto &state = GetPipelineStateDescriptor(); + ast::Expr *pipeline_state = codegen_->PointerType(codegen_->MakeExpr(state.GetTypeName())); + pipeline_params.push_back(codegen_->MakeField(GetPipelineStateName(), pipeline_state)); + return pipeline_params; +} - // Register the main init, run, tear-down functions as steps, in that order only if this isn't nested - // if this is nested, we don't want to register these things as steps as a nested pipeline's functions - // will be called by other pipeline functions - if (!nested_) { - builder->RegisterStep(init_fn); - builder->RegisterStep(run_fn); - builder->RegisterStep(teardown); - } else { - // we have to manually declare these, in the other case, RegisterStep did the equivalent - builder->DeclareFunction(init_fn); - builder->DeclareFunction(run_fn); - builder->DeclareFunction(teardown); +/* ---------------------------------------------------------------------------- + Nested Pipelines +----------------------------------------------------------------------------- */ + +void Pipeline::CallNestedRunPipelineFunction(WorkContext *ctx, const OperatorTranslator *op, + FunctionBuilder *function) const { + std::vector stmts{}; + auto pipeline_state = codegen_->MakeFreshIdentifier("nestedPipelineState"); + auto pipeline_state_ptr = codegen_->AddressOf(pipeline_state); + + // Populate the parameters passed to the Run function for the nested pipeline + std::vector run_parameters{compilation_context_->GetQueryState()->GetStatePointer(codegen_), + pipeline_state_ptr}; + for (std::size_t i = 0; i < op->GetPlan().GetOutputSchema()->GetColumns().size(); i++) { + run_parameters.push_back(codegen_->AddressOf(op->GetOutput(ctx, i))); } - builder->AddTeardownFn(teardown); + + // Declare a local pipeline state variable + function->Append( + codegen_->DeclareVarNoInit(pipeline_state, codegen_->MakeExpr(GetPipelineStateDescriptor().GetTypeName()))); + + // call QueryX_PipelineY_Init(*QueryState, *PipelineState) + function->Append( + codegen_->Call(GetInitPipelineFunctionName(), + {compilation_context_->GetQueryState()->GetStatePointer(codegen_), pipeline_state_ptr})); + // call QueryX_PipelineY_Run(*QueryState, *PipelineState, ...) + function->Append(codegen_->Call(GetRunPipelineFunctionName(), run_parameters)); + // call QueryX_PipelineY_Teardown(*QueryState, *PipelineState) + function->Append( + codegen_->Call(GetTeardownPipelineFunctionName(), + {compilation_context_->GetQueryState()->GetStatePointer(codegen_), pipeline_state_ptr})); } +/* ---------------------------------------------------------------------------- + Variable + Function Identifiers +----------------------------------------------------------------------------- */ + +ast::Identifier Pipeline::GetQueryStateName() const { return compilation_context_->GetQueryStateName(); } + +ast::Identifier Pipeline::GetPipelineStateName() const { return codegen_->MakeIdentifier("pipelineState"); } + +ast::Identifier Pipeline::GetInitPipelineStateFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("InitPipelineState")); +} + +ast::Identifier Pipeline::GetTearDownPipelineStateFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("TearDownPipelineState")); +} + +ast::Identifier Pipeline::GetRunAllPipelineFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("RunAll")); +} + +ast::Identifier Pipeline::GetInitPipelineFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("Init")); +} + +ast::Identifier Pipeline::GetTeardownPipelineFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("TearDown")); +} + +ast::Identifier Pipeline::GetRunPipelineFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName("Run")); +} + +ast::Identifier Pipeline::GetPipelineWorkFunctionName() const { + return codegen_->MakeIdentifier(CreatePipelineFunctionName(IsParallel() ? "ParallelWork" : "SerialWork")); +} + +std::string Pipeline::CreatePipelineFunctionName(const std::string &func_name) const { + auto result = fmt::format("{}_Pipeline{}", compilation_context_->GetFunctionPrefix(), id_); + if (!func_name.empty()) { + result += "_" + func_name; + } + return result; +} + +/* ---------------------------------------------------------------------------- + Additional Helpers +----------------------------------------------------------------------------- */ + +ast::Expr *Pipeline::GetNestedInputArg(const std::size_t index) const { + NOISEPAGE_ASSERT(IsNestedPipeline(), "Requested nested input argument on non-nested pipeline"); + NOISEPAGE_ASSERT(index < extra_pipeline_params_.size(), "Requested nested index argument out of range"); + return codegen_->UnaryOp(parsing::Token::Type::STAR, codegen_->MakeExpr(extra_pipeline_params_[index]->Name())); +} + +ast::LambdaExpr *Pipeline::GetOutputCallback() const { + NOISEPAGE_ASSERT(HasOutputCallback(), "Attempt to get nonexistent output callback"); + return compilation_context_->GetOutputCallback(); +} + +bool Pipeline::HasOutputCallback() const { return compilation_context_->HasOutputCallback(); } + } // namespace noisepage::execution::compiler diff --git a/src/execution/compiler/udf/udf_codegen.cpp b/src/execution/compiler/udf/udf_codegen.cpp new file mode 100644 index 0000000000..dac8820c39 --- /dev/null +++ b/src/execution/compiler/udf/udf_codegen.cpp @@ -0,0 +1,1085 @@ +#include "execution/compiler/udf/udf_codegen.h" + +#include "binder/bind_node_visitor.h" +#include "catalog/catalog_accessor.h" +#include "common/error/error_code.h" +#include "common/error/exception.h" +#include "execution/ast/ast.h" +#include "execution/ast/ast_clone.h" +#include "execution/ast/context.h" +#include "execution/ast/udf/udf_ast_nodes.h" +#include "execution/compiler/compilation_context.h" +#include "execution/compiler/executable_query.h" +#include "execution/compiler/if.h" +#include "execution/compiler/loop.h" +#include "execution/exec/execution_settings.h" +#include "execution/parsing/token.h" +#include "execution/vm/bytecode_function_info.h" +#include "optimizer/cost_model/trivial_cost_model.h" +#include "optimizer/statistics/stats_storage.h" +#include "parser/expression/constant_value_expression.h" +#include "parser/postgresparser.h" +#include "parser/udf/variable_ref.h" +#include "planner/plannodes/abstract_plan_node.h" +#include "planner/plannodes/output_schema.h" +#include "traffic_cop/traffic_cop_util.h" + +namespace noisepage::execution::compiler::udf { + +/** The identifier for the pipeline `RunAll` function */ +constexpr static const char RUN_ALL_IDENTIFIER[] = "RunAll"; + +UdfCodegen::UdfCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, + ast::udf::UdfAstContext *udf_ast_context, CodeGen *codegen, catalog::db_oid_t db_oid) + : accessor_{accessor}, + fb_{fb}, + udf_ast_context_{udf_ast_context}, + codegen_{codegen}, + db_oid_{db_oid}, + aux_decls_{codegen->GetAstContext()->GetRegion()} { + for (auto i = 0UL; fb->GetParameterByPosition(i) != nullptr; ++i) { + auto param = fb->GetParameterByPosition(i); + const auto &name = param->As()->Name(); + SymbolTable()[name.GetString()] = name; + } +} + +// Static +ast::File *UdfCodegen::Run(catalog::CatalogAccessor *accessor, FunctionBuilder *function_builder, + ast::udf::UdfAstContext *ast_context, CodeGen *codegen, catalog::db_oid_t db_oid, + ast::udf::FunctionAST *root) { + UdfCodegen generator{accessor, function_builder, ast_context, codegen, db_oid}; + generator.GenerateUDF(root->Body()); + return generator.Finish(); +} + +// Static +const char *UdfCodegen::GetReturnParamString() { return "return_val"; } + +void UdfCodegen::GenerateUDF(ast::udf::AbstractAST *ast) { ast->Accept(this); } + +catalog::type_oid_t UdfCodegen::GetCatalogTypeOidFromSQLType(sql::SqlTypeId type) { + return accessor_->GetTypeOidFromTypeId(type); +} + +catalog::type_oid_t UdfCodegen::GetCatalogTypeFromBuiltinKind(ast::BuiltinType::Kind type) { + switch (type) { + case ast::BuiltinType::Kind::Boolean: { + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Boolean); + } + case ast::BuiltinType::Kind::Integer: { + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Integer); + } + case ast::BuiltinType::Kind::Real: { + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Real); + } + case ast::BuiltinType::Kind::Decimal: { + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Decimal); + } + case ast::BuiltinType::Kind::StringVal: { + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Varchar); + } + case ast::BuiltinType::Kind::Date: { + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Date); + } + case ast::BuiltinType::Kind::Timestamp: { + return GetCatalogTypeOidFromSQLType(sql::SqlTypeId::Timestamp); + } + default: + NOISEPAGE_ASSERT(false, "Invalid SQL type in function call"); + return accessor_->GetTypeOidFromTypeId(sql::SqlTypeId::Invalid); + } +} + +ast::File *UdfCodegen::Finish() { + ast::FunctionDecl *fn = fb_->Finish(); + util::RegionVector decls{{fn}, codegen_->GetAstContext()->GetRegion()}; + decls.insert(decls.begin(), aux_decls_.cbegin(), aux_decls_.cend()); + auto file = codegen_->GetAstContext()->GetNodeFactory()->NewFile({0, 0}, std::move(decls)); + return file; +} + +/* ---------------------------------------------------------------------------- + Code Generation: "Simple" Constructs +---------------------------------------------------------------------------- */ + +void UdfCodegen::Visit(ast::udf::AbstractAST *ast) { + throw NOT_IMPLEMENTED_EXCEPTION("UdfCodegen::Visit(AbstractAST*)"); +} + +void UdfCodegen::Visit(ast::udf::StmtAST *ast) { UNREACHABLE("Not implemented"); } + +void UdfCodegen::Visit(ast::udf::ExprAST *ast) { UNREACHABLE("Not implemented"); } + +void UdfCodegen::Visit(ast::udf::DynamicSQLStmtAST *ast) { + throw NOT_IMPLEMENTED_EXCEPTION("UdfCodegen::Visit(DynamicSQLStmtAST*)"); +} + +void UdfCodegen::Visit(ast::udf::DeclStmtAST *ast) { + if (ast->Name() == INTERNAL_DECL_ID) { + return; + } + + const ast::Identifier identifier = codegen_->MakeFreshIdentifier(ast->Name()); + SymbolTable()[ast->Name()] = identifier; + + auto prev_type = current_type_; + ast::Expr *tpl_type = nullptr; + if (ast->Type() == sql::SqlTypeId::Invalid) { + // Record type + util::RegionVector fields{codegen_->GetAstContext()->GetRegion()}; + + // TODO(Kyle): Handle unbound record types + const auto record_type = udf_ast_context_->GetRecordType(ast->Name()); + if (!record_type.has_value()) { + // Unbound record type + throw NOT_IMPLEMENTED_EXCEPTION("Unbound RECORD types not supported"); + } + + for (const auto &p : record_type.value()) { + fields.push_back( + codegen_->MakeField(codegen_->MakeIdentifier(p.first), codegen_->TplType(sql::GetTypeId(p.second)))); + } + auto record_decl = codegen_->DeclareStruct(codegen_->MakeFreshIdentifier("rectype"), std::move(fields)); + aux_decls_.push_back(record_decl); + tpl_type = record_decl->TypeRepr(); + } else { + tpl_type = codegen_->TplType(sql::GetTypeId(ast->Type())); + } + current_type_ = ast->Type(); + if (ast->Initial() != nullptr) { + ast::Expr *initializer = EvaluateExpression(ast->Initial()); + fb_->Append(codegen_->DeclareVar(identifier, tpl_type, initializer)); + } else { + fb_->Append(codegen_->DeclareVarNoInit(identifier, tpl_type)); + } + current_type_ = prev_type; +} + +void UdfCodegen::Visit(ast::udf::FunctionAST *ast) { + for (size_t i = 0; i < ast->ParameterTypes().size(); i++) { + SymbolTable()[ast->ParameterNames().at(i)] = codegen_->MakeFreshIdentifier("udf"); + } + ast->Body()->Accept(this); +} + +void UdfCodegen::Visit(ast::udf::VariableExprAST *ast) { + auto it = SymbolTable().find(ast->Name()); + NOISEPAGE_ASSERT(it != SymbolTable().end(), "Variable not declared"); + SetExecutionResult(codegen_->MakeExpr(it->second)); +} + +void UdfCodegen::Visit(ast::udf::ValueExprAST *ast) { + auto val = common::ManagedPointer(ast->Value()).CastManagedPointerTo(); + if (val->IsNull()) { + SetExecutionResult(codegen_->ConstNull(current_type_)); + return; + } + + ast::Expr *expr; + auto type_id = sql::GetTypeId(val->GetReturnValueType()); + switch (type_id) { + case sql::TypeId::Boolean: + expr = codegen_->BoolToSql(val->GetBoolVal().val_); + break; + case sql::TypeId::TinyInt: + case sql::TypeId::SmallInt: + case sql::TypeId::Integer: + case sql::TypeId::BigInt: + expr = codegen_->IntToSql(val->GetInteger().val_); + break; + case sql::TypeId::Float: + case sql::TypeId::Double: + expr = codegen_->FloatToSql(val->GetReal().val_); + break; + case sql::TypeId::Date: + expr = codegen_->DateToSql(val->GetDateVal().val_); + break; + case sql::TypeId::Timestamp: + expr = codegen_->TimestampToSql(val->GetTimestampVal().val_); + break; + case sql::TypeId::Varchar: + expr = codegen_->StringToSql(val->GetStringVal().StringView()); + break; + default: + throw NOT_IMPLEMENTED_EXCEPTION("Unsupported type in UDF codegen"); + } + SetExecutionResult(expr); +} + +void UdfCodegen::Visit(ast::udf::AssignStmtAST *ast) { + const sql::SqlTypeId left_type = GetVariableType(ast->Destination()->Name()); + current_type_ = left_type; + + ast::Expr *rhs_expr = EvaluateExpression(ast->Source()); + + auto it = SymbolTable().find(ast->Destination()->Name()); + NOISEPAGE_ASSERT(it != SymbolTable().end(), "Variable not found"); + auto left_codegen_ident = it->second; + + auto *left_expr = codegen_->MakeExpr(left_codegen_ident); + fb_->Append(codegen_->Assign(left_expr, rhs_expr)); +} + +void UdfCodegen::Visit(ast::udf::BinaryExprAST *ast) { + parsing::Token::Type op_token; + bool compare = false; + switch (ast->Op()) { + case parser::ExpressionType::OPERATOR_DIVIDE: + op_token = parsing::Token::Type::SLASH; + break; + case parser::ExpressionType::OPERATOR_PLUS: + op_token = parsing::Token::Type::PLUS; + break; + case parser::ExpressionType::OPERATOR_MINUS: + op_token = parsing::Token::Type::MINUS; + break; + case parser::ExpressionType::OPERATOR_MULTIPLY: + op_token = parsing::Token::Type::STAR; + break; + case parser::ExpressionType::OPERATOR_MOD: + op_token = parsing::Token::Type::PERCENT; + break; + case parser::ExpressionType::CONJUNCTION_OR: + op_token = parsing::Token::Type::OR; + break; + case parser::ExpressionType::CONJUNCTION_AND: + op_token = parsing::Token::Type::AND; + break; + case parser::ExpressionType::COMPARE_GREATER_THAN: + compare = true; + op_token = parsing::Token::Type::GREATER; + break; + case parser::ExpressionType::COMPARE_GREATER_THAN_OR_EQUAL_TO: + compare = true; + op_token = parsing::Token::Type::GREATER_EQUAL; + break; + case parser::ExpressionType::COMPARE_LESS_THAN_OR_EQUAL_TO: + compare = true; + op_token = parsing::Token::Type::LESS_EQUAL; + break; + case parser::ExpressionType::COMPARE_LESS_THAN: + compare = true; + op_token = parsing::Token::Type::LESS; + break; + case parser::ExpressionType::COMPARE_EQUAL: + compare = true; + op_token = parsing::Token::Type::EQUAL_EQUAL; + break; + default: + // TODO(Kyle): Figure out concatenation operation from expressions? + UNREACHABLE("Unsupported expression"); + } + ast::Expr *lhs_expr = EvaluateExpression(ast->Left()); + ast::Expr *rhs_expr = EvaluateExpression(ast->Right()); + ast::Expr *result = + compare ? codegen_->Compare(op_token, lhs_expr, rhs_expr) : codegen_->BinaryOp(op_token, lhs_expr, rhs_expr); + SetExecutionResult(result); +} + +void UdfCodegen::Visit(ast::udf::IfStmtAST *ast) { + // TODO(Kyle): It would be nice to add support for IF .. ELSIF .. ELSE + // constructs, but the current TPL architecture does not have native + // support for code generation of this type of control flow, so I am + // going to punt on it for now. + ast::Expr *condition = EvaluateExpression(ast->Condition()); + If branch(fb_, condition); + ast->Then()->Accept(this); + if (ast->Else() != nullptr) { + branch.Else(); + ast->Else()->Accept(this); + } + branch.EndIf(); +} + +void UdfCodegen::Visit(ast::udf::IsNullExprAST *ast) { + ast::Expr *child = EvaluateExpression(ast->Child()); + ast::Expr *null_check = codegen_->CallBuiltin(ast::Builtin::IsValNull, {child}); + SetExecutionResult(null_check); + if (!ast->IsNullCheck()) { + SetExecutionResult(codegen_->UnaryOp(parsing::Token::Type::BANG, null_check)); + } +} + +void UdfCodegen::Visit(ast::udf::SeqStmtAST *ast) { + for (auto &stmt : ast->Statements()) { + stmt->Accept(this); + } +} + +void UdfCodegen::Visit(ast::udf::WhileStmtAST *ast) { + ast::Expr *condition = EvaluateExpression(ast->Condition()); + Loop loop(fb_, condition); + ast->Body()->Accept(this); + loop.EndLoop(); +} + +void UdfCodegen::Visit(ast::udf::RetStmtAST *ast) { + // TODO(Kyle): Handle NULL returns + ast::Expr *return_expr = EvaluateExpression(ast->Return()); + fb_->Append(codegen_->Return(return_expr)); +} + +void UdfCodegen::Visit(ast::udf::MemberExprAST *ast) { + ast::Expr *object = EvaluateExpression(ast->Object()); + ast::Expr *access = codegen_->AccessStructMember(object, codegen_->MakeIdentifier(ast->FieldName())); + SetExecutionResult(access); +} + +/* ---------------------------------------------------------------------------- + Code Generation: Function Calls +---------------------------------------------------------------------------- */ + +void UdfCodegen::Visit(ast::udf::CallExprAST *ast) { + const auto &args = ast->Args(); + + // Generate code to evaluate call arguments + std::vector arguments{}; + arguments.reserve(ast->Args().size()); + std::transform(args.cbegin(), args.cend(), std::back_inserter(arguments), + [this](const std::unique_ptr &expr) { return EvaluateExpression(expr.get()); }); + + std::vector argument_types{}; + argument_types.reserve(arguments.size()); + std::transform( + arguments.cbegin(), arguments.cend(), std::back_inserter(argument_types), + [this](const ast::Expr *expr) -> catalog::type_oid_t { return GetCatalogTypeOidFromSQLType(ResolveType(expr)); }); + + // Resolve the argument types to handle the case where an untyped NULL is passed + const auto resolved_types = accessor_->ResolveProcArgumentTypes(ast->Callee(), argument_types); + if (resolved_types.empty()) { + throw BINDER_EXCEPTION(fmt::format("Procedure '{}' not registered", ast->Callee()), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } else if (resolved_types.size() > 1) { + throw BINDER_EXCEPTION(fmt::format("Procedure call '{}' is ambiguous", ast->Callee()), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + // This lookup should now always succeed + const auto proc_oid = accessor_->GetProcOid(ast->Callee(), resolved_types.front()); + if (proc_oid == catalog::INVALID_PROC_OID) { + throw BINDER_EXCEPTION(fmt::format("Invalid function call '{}'", ast->Callee()), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + auto context = accessor_->GetFunctionContext(proc_oid); + if (context->IsBuiltin()) { + if (context->IsExecCtxRequired()) { + // If this builtin requires an execution context, provide it + arguments.insert(arguments.begin(), GetExecutionContext()); + } + ast::Expr *result = codegen_->CallBuiltin(context->GetBuiltin(), arguments); + SetExecutionResult(result); + } else { + // NOTE(Kyle): This is an unfortunate operation because it + // requires shifting all elements in the vector, but we + // don't typically see functions with super-high arity + arguments.insert(arguments.begin(), GetExecutionContext()); + auto it = SymbolTable().find(ast->Callee()); + ast::Identifier ident_expr; + if (it != SymbolTable().end()) { + ident_expr = it->second; + } else { + auto file = reinterpret_cast( + ast::AstClone::Clone(context->GetFile(), codegen_->GetAstContext()->GetNodeFactory(), + context->GetASTContext(), codegen_->GetAstContext().Get())); + for (auto decl : file->Declarations()) { + aux_decls_.push_back(decl); + } + ident_expr = codegen_->MakeFreshIdentifier(file->Declarations().back()->Name().GetString()); + SymbolTable()[file->Declarations().back()->Name().GetString()] = ident_expr; + } + ast::Expr *result = codegen_->Call(ident_expr, arguments); + SetExecutionResult(result); + } +} + +sql::SqlTypeId UdfCodegen::ResolveType(const ast::Expr *expr) const { + const auto t = expr->GetKind(); + (void)t; + switch (expr->GetKind()) { + case ast::AstNode::Kind::LitExpr: + return ResolveTypeForLiteralExpression(expr->SafeAs()); + case ast::AstNode::Kind::BinaryOpExpr: + return ResolveTypeForBinaryExpression(expr->SafeAs()); + case ast::AstNode::Kind::IdentifierExpr: + return ResolveTypeForIdentifierExpression(expr->SafeAs()); + case ast::AstNode::Kind::CallExpr: + return ResolveTypeForCallExpression(expr->SafeAs()); + default: + UNREACHABLE("Function call argument type cannot be resolved"); + } +} + +sql::SqlTypeId UdfCodegen::ResolveTypeForLiteralExpression(const ast::LitExpr *expr) const { + NOISEPAGE_ASSERT(expr->IsLitExpr(), "Broken precondition."); + // TODO(Kyle): What to do about the ambiguity here? + // e.g. a literal might be a float vs double + switch (expr->GetLiteralKind()) { + case ast::LitExpr::LitKind::Boolean: + return sql::SqlTypeId::Boolean; + case ast::LitExpr::LitKind::Float: + return sql::SqlTypeId::Double; + case ast::LitExpr::LitKind::Int: + return sql::SqlTypeId::Integer; + case ast::LitExpr::LitKind::String: + return sql::SqlTypeId::Varchar; + default: + UNREACHABLE("Invalid type"); + } +} + +/** @return `true` if the given type is an integral type */ +static bool IsIntegral(sql::SqlTypeId type) { + return type == sql::SqlTypeId::TinyInt || type == sql::SqlTypeId::SmallInt || type == sql::SqlTypeId::Integer || + type == sql::SqlTypeId::BigInt; +} + +/** @return `true` if the given type is a floating-point type */ +static bool IsFloatingPoint(sql::SqlTypeId type) { + return type == sql::SqlTypeId::Real || type == sql::SqlTypeId::Double; +} + +sql::SqlTypeId UdfCodegen::ResolveTypeForBinaryExpression(const ast::BinaryOpExpr *expr) const { + NOISEPAGE_ASSERT(expr->IsBinaryOpExpr(), "Broken precondition"); + const auto *binary = expr->SafeAs(); + sql::SqlTypeId left = ResolveType(binary->Left()); + sql::SqlTypeId right = ResolveType(binary->Right()); + switch (binary->Op()) { + // Basic arithmetic operators + case parsing::Token::Type::PLUS: + case parsing::Token::Type::MINUS: + case parsing::Token::Type::STAR: + case parsing::Token::Type::SLASH: + if (left == right) { + return left; + } + if (IsFloatingPoint(left) && IsIntegral(right)) { + return left; + } + if (IsIntegral(left) && IsFloatingPoint(right)) { + return right; + } + UNREACHABLE("Unsupported types for arithmetic operations"); + default: + break; + } + UNREACHABLE("Binary operation not supported"); +} + +sql::SqlTypeId UdfCodegen::ResolveTypeForIdentifierExpression(const ast::IdentifierExpr *expr) const { + NOISEPAGE_ASSERT(expr->IsIdentifierExpr(), "Broken precondition."); + // Just lookup the type for the variable with which it was declared + return GetVariableType(expr->Name().GetString()); +} + +sql::SqlTypeId UdfCodegen::ResolveTypeForCallExpression(const ast::CallExpr *expr) const { + const ast::Type *type = expr->GetType(); + NOISEPAGE_ASSERT(type->IsSqlValueType(), "Invalid type"); + const ast::BuiltinType *builtin = type->SafeAs(); + switch (builtin->GetKind()) { + case ast::BuiltinType::Kind::Boolean: + return sql::SqlTypeId::Boolean; + case ast::BuiltinType::Kind::Integer: + return sql::SqlTypeId::Integer; + case ast::BuiltinType::Kind::Real: + return sql::SqlTypeId::Real; + case ast::BuiltinType::Kind::Decimal: + return sql::SqlTypeId::Decimal; + case ast::BuiltinType::Kind::StringVal: + return sql::SqlTypeId::Varchar; + case ast::BuiltinType::Kind::Date: + return sql::SqlTypeId::Date; + case ast::BuiltinType::Kind::Timestamp: + return sql::SqlTypeId::Timestamp; + default: + UNREACHABLE("Invalid type"); + } +} + +/* ---------------------------------------------------------------------------- + Code Generation: Integer-Variant For-Loops +---------------------------------------------------------------------------- */ + +void UdfCodegen::Visit(ast::udf::ForIStmtAST *ast) { throw NOT_IMPLEMENTED_EXCEPTION("ForIStmtAST Not Implemented"); } + +/* ---------------------------------------------------------------------------- + Code Generation: Query-Variant For-Loops +---------------------------------------------------------------------------- */ + +void UdfCodegen::Visit(ast::udf::ForSStmtAST *ast) { + // Executing a SQL query requires an execution context + ast::Expr *exec_ctx = GetExecutionContext(); + + // Bind the embedded query; must do this prior to attempting + // to optimize to ensure correctness + const auto variable_refs = BindQueryAndGetVariableRefs(ast->Query()); + + // Optimize the embedded query + auto optimize_result = OptimizeEmbeddedQuery(ast->Query(), variable_refs); + auto plan = optimize_result->GetPlanNode(); + + // Start construction of the lambda expression + auto builder = StartLambda(plan, ast->Variables()); + + // Generate code for variable initialization + CodegenBoundVariableInit(plan, ast->Variables()); + + // Generate code for the loop body + { + auto cached_builder = fb_; + fb_ = builder.get(); + ast->Body()->Accept(this); + fb_ = cached_builder; + } + + ast::LambdaExpr *lambda_expr = builder->FinishClosure(); + const ast::Identifier lambda_identifier = codegen_->MakeFreshIdentifier("udfLambda"); + lambda_expr->SetName(lambda_identifier); + + // Materialize the lambda into the lambda expression + exec::ExecutionSettings exec_settings{}; + const std::string dummy_query{}; + auto exec_query = compiler::CompilationContext::Compile( + *plan, exec_settings, accessor_, compiler::CompilationMode::OneShot, std::nullopt, + common::ManagedPointer{}, lambda_expr, codegen_->GetAstContext()); + + // Append all of the declarations from the compiled query + auto decls = exec_query->GetDecls(); + aux_decls_.insert(aux_decls_.end(), decls.cbegin(), decls.cend()); + + // Declare the closure and the query state in the current function + auto query_state = codegen_->MakeFreshIdentifier("query_state"); + fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); + fb_->Append(codegen_->DeclareVar( + lambda_identifier, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), lambda_expr)); + + // Set its execution context to whatever execution context was passed in here + fb_->Append(codegen_->CallBuiltin(ast::Builtin::StartNewParams, {exec_ctx})); + + CodegenAddParameters(exec_ctx, variable_refs); + + fb_->Append(codegen_->Assign( + codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); + + // Manually append calls to each function from the compiled + // executable query (implementing the closure) to the builder + CodegenTopLevelCalls(exec_query.get(), query_state, lambda_identifier); + + fb_->Append(codegen_->CallBuiltin(ast::Builtin::FinishNewParams, {exec_ctx})); +} + +std::unique_ptr UdfCodegen::StartLambda(common::ManagedPointer plan, + const std::vector &variables) { + return GetVariableType(variables.front()) == sql::SqlTypeId::Invalid ? StartLambdaBindingToRecord(plan, variables) + : StartLambdaBindingToScalars(plan, variables); +} + +std::unique_ptr UdfCodegen::StartLambdaBindingToRecord( + common::ManagedPointer plan, const std::vector &variables) { + // bind results to a single RECORD variable + NOISEPAGE_ASSERT(variables.size() == 1, "Broken invariant"); + + const std::string &record_name = variables.front(); + const auto record_type = GetRecordType(record_name); + + const auto n_fields = record_type.size(); + const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); + if (n_fields != n_columns) { + throw EXECUTION_EXCEPTION( + fmt::format("Attempt to bind {} query outputs to record type with {} fields", n_columns, n_fields), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + // The lambda accepts all columns of the query output schema as parameters + util::RegionVector parameters{codegen_->GetAstContext()->GetRegion()}; + + // The first parameter is always the execution context + ast::Expr *exec_ctx = GetExecutionContext(); + parameters.push_back( + codegen_->MakeField(exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(ast::BuiltinType::Kind::ExecutionContext)))); + + // The lambda captures all variables in the symbol table + // NOTE(Kyle): It might be possible / preferable to make this more conservative + util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; + for (const auto &[name, identifier] : SymbolTable()) { + if (name != "executionCtx") { + captures.push_back(codegen_->MakeExpr(identifier)); + } + } + + // While the closure only captures a single variable, we still need + // to generate code for an assignment to each field memeber + std::vector assignees{}; + assignees.reserve(n_columns); + + ast::Expr *record = codegen_->MakeExpr(SymbolTable().find(record_name)->second); + for (std::size_t i = 0UL; i < n_columns; ++i) { + const auto &column = plan->GetOutputSchema()->GetColumn(i); + assignees.push_back(codegen_->AccessStructMember(record, codegen_->MakeIdentifier(record_type[i].first))); + parameters.push_back(codegen_->MakeField(codegen_->MakeFreshIdentifier("input"), + codegen_->TplType(sql::GetTypeId(column.GetType())))); + } + + auto builder = std::make_unique(codegen_, std::move(parameters), std::move(captures), + codegen_->BuiltinType(ast::BuiltinType::Nil)); + for (std::size_t i = 0UL; i < assignees.size(); ++i) { + auto *assignee = assignees.at(i); + auto input_parameter = builder->GetParameterByPosition(i + 1); + builder->Append(codegen_->Assign(assignee, input_parameter)); + } + return builder; +} + +std::unique_ptr UdfCodegen::StartLambdaBindingToScalars( + common::ManagedPointer plan, const std::vector &variables) { + // bind results to one or more non-RECORD variables + const auto n_variables = variables.size(); + const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); + if (n_variables != n_columns) { + throw EXECUTION_EXCEPTION(fmt::format("Attempt to bind {} query outputs to {} variables", n_columns, n_variables), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + // The lambda accepts all columns of the query output schema as parameters + util::RegionVector parameters{codegen_->GetAstContext()->GetRegion()}; + + // The lambda captures all variables in the symbol table + // NOTE(Kyle): It might be possible / preferable to make this more conservative + util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; + for (const auto &[name, identifier] : SymbolTable()) { + if (name != "executionCtx") { + captures.push_back(codegen_->MakeExpr(identifier)); + } + } + + // The first parameter is always the execution context + ast::Expr *exec_ctx = GetExecutionContext(); + parameters.push_back( + codegen_->MakeField(exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(ast::BuiltinType::Kind::ExecutionContext)))); + + // Assignees are those captures that are written in the closure + std::vector assignees{}; + assignees.reserve(n_columns); + + // Populate the parameters and capture assignees + for (std::size_t i = 0UL; i < n_columns; ++i) { + const auto &variable = variables.at(i); + const auto &column = plan->GetOutputSchema()->GetColumn(i); + assignees.push_back(codegen_->MakeExpr(SymbolTable().find(variable)->second)); + parameters.push_back(codegen_->MakeField(codegen_->MakeFreshIdentifier("input"), + codegen_->TplType(sql::GetTypeId(column.GetType())))); + } + + // Begin construction of the function that implements the closure + auto builder = std::make_unique(codegen_, std::move(parameters), std::move(captures), + codegen_->BuiltinType(ast::BuiltinType::Nil)); + + // Generate an assignment from each input parameter to the associated capture + for (std::size_t i = 0UL; i < assignees.size(); ++i) { + ast::Expr *capture = assignees.at(i); + auto input_parameter = builder->GetParameterByPosition(i + 1); + builder->Append(codegen_->Assign(capture, input_parameter)); + } + return builder; +} + +/* ---------------------------------------------------------------------------- + Code Generation: SQL Statements +---------------------------------------------------------------------------- */ + +void UdfCodegen::Visit(ast::udf::SQLStmtAST *ast) { + // Executing a SQL query requires an execution context + ast::Expr *exec_ctx = GetExecutionContext(); + + // Bind the embedded query; must do this prior to attempting + // to optimize to ensure correctness + const auto variable_refs = BindQueryAndGetVariableRefs(ast->Query()); + + // Optimize the query and generate get a reference to the plan + auto optimize_result = OptimizeEmbeddedQuery(ast->Query(), variable_refs); + auto plan = optimize_result->GetPlanNode(); + + // Construct a lambda that writes the output of the query + // into the bound variables, as defined by the function body + ast::LambdaExpr *lambda_expr = MakeLambda(plan, ast->Variables()); + const ast::Identifier lambda_identifier = codegen_->MakeFreshIdentifier("udfLambda"); + lambda_expr->SetName(lambda_identifier); + + // Generate code for the embedded query, utilizing the generated closure as the output callback + exec::ExecutionSettings exec_settings{}; + auto exec_query = compiler::CompilationContext::Compile( + *plan, exec_settings, accessor_, compiler::CompilationMode::OneShot, std::nullopt, + common::ManagedPointer{}, lambda_expr, codegen_->GetAstContext()); + + // Append all declarations from the compiled query + auto decls = exec_query->GetDecls(); + aux_decls_.insert(aux_decls_.end(), decls.cbegin(), decls.cend()); + + // Declare the closure and the query state in the current function + auto query_state = codegen_->MakeFreshIdentifier("query_state"); + fb_->Append(codegen_->DeclareVarNoInit(query_state, codegen_->MakeExpr(exec_query->GetQueryStateType()->Name()))); + fb_->Append(codegen_->DeclareVar( + lambda_identifier, codegen_->LambdaType(lambda_expr->GetFunctionLiteralExpr()->TypeRepr()), lambda_expr)); + + // Set its execution context to whatever execution context was passed in here + fb_->Append(codegen_->CallBuiltin(ast::Builtin::StartNewParams, {exec_ctx})); + + // Determine the column references in the query (if any) + // that depend on variables in the UDF definition + CodegenAddParameters(exec_ctx, variable_refs); + + // Load the execution context member of the query state + fb_->Append(codegen_->Assign( + codegen_->AccessStructMember(codegen_->MakeExpr(query_state), codegen_->MakeIdentifier("execCtx")), exec_ctx)); + + // Initialize the captures + CodegenBoundVariableInit(plan, ast->Variables()); + + // Manually append calls to each function from the compiled + // executable query (implementing the closure) to the builder + CodegenTopLevelCalls(exec_query.get(), query_state, lambda_identifier); + + fb_->Append(codegen_->CallBuiltin(ast::Builtin::FinishNewParams, {exec_ctx})); +} + +ast::LambdaExpr *UdfCodegen::MakeLambda(common::ManagedPointer plan, + const std::vector &variables) { + return GetVariableType(variables.front()) == sql::SqlTypeId::Invalid ? MakeLambdaBindingToRecord(plan, variables) + : MakeLambdaBindingToScalars(plan, variables); +} + +ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToRecord(common::ManagedPointer plan, + const std::vector &variables) { + // bind results to a single RECORD variable + NOISEPAGE_ASSERT(variables.size() == 1, "Broken invariant"); + + const std::string &record_name = variables.front(); + const auto record_type = GetRecordType(record_name); + + const auto n_fields = record_type.size(); + const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); + if (n_fields != n_columns) { + throw EXECUTION_EXCEPTION( + fmt::format("Attempt to bind {} query outputs to record type with {} fields", n_columns, n_fields), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + // The lambda accepts all columns of the query output schema as parameters + util::RegionVector parameters{codegen_->GetAstContext()->GetRegion()}; + + ast::Expr *exec_ctx = GetExecutionContext(); + parameters.push_back( + codegen_->MakeField(exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(ast::BuiltinType::Kind::ExecutionContext)))); + + // The lambda only captures the RECORD variable to which all results are bound + ast::Expr *capture = codegen_->MakeExpr(SymbolTable().find(record_name)->second); + util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; + + // While the closure only captures a single variable, we still need + // to generate code for an assignment to each field memeber + std::vector assignees{}; + assignees.reserve(n_columns); + + for (std::size_t i = 0; i < n_columns; ++i) { + const auto &column = plan->GetOutputSchema()->GetColumn(i); + assignees.push_back(codegen_->AccessStructMember(capture, codegen_->MakeIdentifier(record_type[i].first))); + parameters.push_back(codegen_->MakeField(codegen_->MakeFreshIdentifier("input"), + codegen_->TplType(sql::GetTypeId(column.GetType())))); + } + + FunctionBuilder builder{codegen_, std::move(parameters), std::move(captures), + codegen_->BuiltinType(ast::BuiltinType::Nil)}; + for (std::size_t i = 0UL; i < assignees.size(); ++i) { + auto *assignee = assignees.at(i); + auto input_parameter = builder.GetParameterByPosition(i + 1); + builder.Append(codegen_->Assign(assignee, input_parameter)); + } + + return builder.FinishClosure(); +} + +ast::LambdaExpr *UdfCodegen::MakeLambdaBindingToScalars(common::ManagedPointer plan, + const std::vector &variables) { + // bind results to one or more non-RECORD variables + const auto n_variables = variables.size(); + const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); + if (n_variables != n_columns) { + throw EXECUTION_EXCEPTION(fmt::format("Attempt to bind {} query outputs to {} variables", n_columns, n_variables), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + // The lambda accepts all columns of the query output schema as parameters + util::RegionVector parameters{codegen_->GetAstContext()->GetRegion()}; + // The lambda captures the variables to which results are bound from the enclosing scope + util::RegionVector captures{codegen_->GetAstContext()->GetRegion()}; + + ast::Expr *exec_ctx = GetExecutionContext(); + parameters.push_back( + codegen_->MakeField(exec_ctx->As()->Name(), + codegen_->PointerType(codegen_->BuiltinType(ast::BuiltinType::Kind::ExecutionContext)))); + + // Populate the remainder of the parameters and captures + for (std::size_t i = 0; i < n_columns; ++i) { + const auto &variable = variables.at(i); + const auto &column = plan->GetOutputSchema()->GetColumn(i); + captures.push_back(codegen_->MakeExpr(SymbolTable().find(variable)->second)); + parameters.push_back(codegen_->MakeField(codegen_->MakeFreshIdentifier("input"), + codegen_->TplType(sql::GetTypeId(column.GetType())))); + } + + // Clone the captures for assignment within the closure body + const std::vector assignees{captures.cbegin(), captures.cend()}; + + // Begin construction of the function that implements the closure + FunctionBuilder builder{codegen_, std::move(parameters), std::move(captures), + codegen_->BuiltinType(ast::BuiltinType::Nil)}; + + // Generate an assignment from each input parameter to the associated capture + for (std::size_t i = 0UL; i < assignees.size(); ++i) { + ast::Expr *capture = assignees.at(i); + auto input_parameter = builder.GetParameterByPosition(i + 1); + builder.Append(codegen_->Assign(capture, input_parameter)); + } + return builder.FinishClosure(); +} + +/* ---------------------------------------------------------------------------- + Code Gneration Helpers: Add Parameters +---------------------------------------------------------------------------- */ + +void UdfCodegen::CodegenAddParameters(ast::Expr *exec_ctx, const std::vector &variable_refs) { + for (const auto &variable_ref : variable_refs) { + if (variable_ref.IsScalar()) { + CodegenAddScalarParameter(exec_ctx, variable_ref); + } else { + CodegenAddTableParameter(exec_ctx, variable_ref); + } + } +} + +void UdfCodegen::CodegenAddScalarParameter(ast::Expr *exec_ctx, const parser::udf::VariableRef &variable_ref) { + NOISEPAGE_ASSERT(variable_ref.IsScalar(), "Broken invariant"); + const auto &name = variable_ref.ColumnName(); + const sql::SqlTypeId type = GetVariableType(name); + ast::Expr *expr = codegen_->MakeExpr(SymbolTable().at(name)); + fb_->Append(codegen_->CallBuiltin(AddParamBuiltinForParameterType(type), {exec_ctx, expr})); +} + +void UdfCodegen::CodegenAddTableParameter(ast::Expr *exec_ctx, const parser::udf::VariableRef &variable_ref) { + NOISEPAGE_ASSERT(!variable_ref.IsScalar(), "Broken invariant"); + + const auto &record_name = variable_ref.TableName(); + const auto &field_name = variable_ref.ColumnName(); + + const auto fields = GetRecordType(record_name); + auto it = std::find_if( + fields.cbegin(), fields.cend(), + [&field_name](const std::pair &field) -> bool { return field.first == field_name; }); + if (it == fields.cend()) { + throw EXECUTION_EXCEPTION(fmt::format("Field '{}' not found in record '{}'", field_name, record_name), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + const sql::SqlTypeId type = it->second; + ast::Expr *expr = codegen_->AccessStructMember(codegen_->MakeExpr(SymbolTable().at(record_name)), + codegen_->MakeIdentifier(field_name)); + fb_->Append(codegen_->CallBuiltin(AddParamBuiltinForParameterType(type), {exec_ctx, expr})); +} + +/* ---------------------------------------------------------------------------- + Code Gneration Helpers: Bound Variable Initialization +---------------------------------------------------------------------------- */ + +void UdfCodegen::CodegenBoundVariableInit(common::ManagedPointer plan, + const std::vector &bound_variables) { + if (bound_variables.empty()) { + // Nothing to do + return; + } + + if (GetVariableType(bound_variables.front()) == sql::SqlTypeId::Invalid) { + CodegenBoundVariableInitForRecord(plan, bound_variables.front()); + } else { + CodegenBoundVariableInitForScalars(plan, bound_variables); + } +} + +void UdfCodegen::CodegenBoundVariableInitForScalars(common::ManagedPointer plan, + const std::vector &bound_variables) { + const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); + const auto n_variables = bound_variables.size(); + if (n_columns != n_variables) { + throw EXECUTION_EXCEPTION( + fmt::format("Attempt to bind {} query results to {} scalar variables", n_columns, n_variables), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + for (std::size_t i = 0; i < n_columns; ++i) { + const auto &column = plan->GetOutputSchema()->GetColumn(i); + const auto &variable = bound_variables.at(i); + ast::Expr *capture = codegen_->MakeExpr(SymbolTable().find(variable)->second); + fb_->Append(codegen_->Assign(capture, codegen_->ConstNull(column.GetType()))); + } +} + +void UdfCodegen::CodegenBoundVariableInitForRecord(common::ManagedPointer plan, + const std::string &record_name) { + NOISEPAGE_ASSERT(GetVariableType(record_name) == sql::SqlTypeId::Invalid, "Broken invariant"); + const auto n_columns = plan->GetOutputSchema()->GetColumns().size(); + const auto fields = GetRecordType(record_name); + const auto n_fields = fields.size(); + if (n_columns != n_fields) { + // NOTE(Kyle): This should be impossible, the structure of the + // record type is derived from the output schema of the query + throw EXECUTION_EXCEPTION( + fmt::format("Attempt to bind {} query results to record with {} fields", n_columns, n_fields), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + + ast::Expr *record = codegen_->MakeExpr(SymbolTable().find(record_name)->second); + for (std::size_t i = 0; i < n_columns; ++i) { + const auto &column = plan->GetOutputSchema()->GetColumn(i); + const auto &field = fields.at(i); + NOISEPAGE_ASSERT(column.GetName() == field.first, "Broken invariant"); + ast::Expr *capture = codegen_->AccessStructMember(record, codegen_->MakeIdentifier(field.first)); + fb_->Append(codegen_->Assign(capture, codegen_->ConstNull(column.GetType()))); + } +} + +void UdfCodegen::CodegenTopLevelCalls(const ExecutableQuery *exec_query, ast::Identifier query_state_id, + ast::Identifier lambda_id) { + /** + * We don't inject the lambda parameter into every "Run" function, + * and instead only add it as an additional parameter for those + * pipelines that require it. This is parsimonious, but makes the + * process of injecting calls to each function slightly more complex. + * + * Pipelines with output callbacks are wrapped in a top-level `RunAll` + * function which accepts the lambda as a parameter. This `RunAll` function + * then assumes responsibility for calling the other top-level functions + * of which the pipeline is composed, in the proper order. In pipelines + * with output callbacks, this `RunAll` function is the only one registered + * for which a step is registered with the ExecutableQueryFragmentBuilder, + * so it is the only function returned by `GetFunctionMetadata()` for this + * pipeline. + * + * Pipelines without output callbacks are generated as normal, without + * the output callback added as an additional parameter. Therefore, we + * must inject calls to these functions with the regular signature. + */ + + for (const auto *metadata : exec_query->GetFunctionMetadata()) { + const auto &function_name = metadata->GetName(); + if (IsRunAllFunction(function_name)) { + NOISEPAGE_ASSERT(metadata->GetParamsCount() == 2, "Unexpected arity for RunAll function"); + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), + {codegen_->AddressOf(query_state_id), codegen_->MakeExpr(lambda_id)})); + } else { + NOISEPAGE_ASSERT(metadata->GetParamsCount() == 1, "Unexpected arity for top-level pipeline function"); + fb_->Append(codegen_->Call(codegen_->GetAstContext()->GetIdentifier(function_name), + {codegen_->AddressOf(query_state_id)})); + } + } +} + +/* ---------------------------------------------------------------------------- + General Utilities +---------------------------------------------------------------------------- */ + +ast::Expr *UdfCodegen::GetExecutionContext() { + // The execution context is always suppplied to the + // top-level function builder for the function + return fb_->GetParameterByPosition(0); +} + +ast::Expr *UdfCodegen::GetExecutionResult() { return execution_result_; } + +void UdfCodegen::SetExecutionResult(ast::Expr *result) { execution_result_ = result; } + +ast::Expr *UdfCodegen::EvaluateExpression(ast::udf::ExprAST *expr) { + expr->Accept(this); + return GetExecutionResult(); +} + +sql::SqlTypeId UdfCodegen::GetVariableType(const std::string &name) const { + auto type = udf_ast_context_->GetVariableType(name); + if (!type.has_value()) { + throw EXECUTION_EXCEPTION(fmt::format("Failed to resolve type for variable '{}'", name), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + return type.value(); +} + +std::vector> UdfCodegen::GetRecordType(const std::string &name) const { + auto type = udf_ast_context_->GetRecordType(name); + if (!type.has_value()) { + throw EXECUTION_EXCEPTION(fmt::format("Failed to resolve type for record variable '{}'", name), + common::ErrorCode::ERRCODE_PLPGSQL_ERROR); + } + return type.value(); +} + +std::vector UdfCodegen::BindQueryAndGetVariableRefs(parser::ParseResult *query) { + binder::BindNodeVisitor visitor{common::ManagedPointer{accessor_}, db_oid_}; + return visitor.BindAndGetUDFVariableRefs(common::ManagedPointer{query}, common::ManagedPointer{udf_ast_context_}); +} + +std::unique_ptr UdfCodegen::OptimizeEmbeddedQuery( + parser::ParseResult *parsed_query, const std::vector &variable_refs) { + // For each variable reference, we provide a dummy ConstantValueExpression + std::vector parameters{}; + parameters.reserve(variable_refs.size()); + std::transform(variable_refs.cbegin(), variable_refs.cend(), std::back_inserter(parameters), + [](const parser::udf::VariableRef &v) -> parser::ConstantValueExpression { + return parser::ConstantValueExpression{sql::SqlTypeId::Integer, sql::Integer{0}}; + }); + + // Optimize the query + optimizer::StatsStorage stats{}; + const std::uint64_t optimizer_timeout = 1000000; + return trafficcop::TrafficCopUtil::Optimize( + accessor_->GetTxn(), common::ManagedPointer(accessor_), common::ManagedPointer(parsed_query), db_oid_, + common::ManagedPointer(&stats), std::make_unique(), optimizer_timeout, + common::ManagedPointer{¶meters}); +} + +// Static +bool UdfCodegen::IsRunAllFunction(const std::string &name) { + return name.find(RUN_ALL_IDENTIFIER) != std::string::npos; +} + +// Static +ast::Builtin UdfCodegen::AddParamBuiltinForParameterType(sql::SqlTypeId parameter_type) { + switch (parameter_type) { + case sql::SqlTypeId::Boolean: + return ast::Builtin::AddParamBool; + case sql::SqlTypeId::TinyInt: + return ast::Builtin::AddParamTinyInt; + case sql::SqlTypeId::SmallInt: + return ast::Builtin::AddParamSmallInt; + case sql::SqlTypeId::Integer: + return ast::Builtin::AddParamInt; + case sql::SqlTypeId::BigInt: + return ast::Builtin::AddParamBigInt; + case sql::SqlTypeId::Decimal: + return ast::Builtin::AddParamDouble; + case sql::SqlTypeId::Real: + return ast::Builtin::AddParamReal; + case sql::SqlTypeId::Double: + return ast::Builtin::AddParamDouble; + case sql::SqlTypeId::Date: + return ast::Builtin::AddParamDate; + case sql::SqlTypeId::Timestamp: + return ast::Builtin::AddParamTimestamp; + case sql::SqlTypeId::Varchar: + return ast::Builtin::AddParamString; + default: + UNREACHABLE("Unsupported parameter type"); + } +} + +} // namespace noisepage::execution::compiler::udf diff --git a/src/execution/exec/execution_context.cpp b/src/execution/exec/execution_context.cpp index d69d615260..6330e88564 100644 --- a/src/execution/exec/execution_context.cpp +++ b/src/execution/exec/execution_context.cpp @@ -5,7 +5,6 @@ #include "execution/sql/value.h" #include "metrics/metrics_manager.h" #include "metrics/metrics_store.h" -#include "parser/expression/constant_value_expression.h" #include "replication/primary_replication_manager.h" #include "self_driving/modeling/operating_unit.h" #include "self_driving/modeling/operating_unit_util.h" @@ -15,18 +14,19 @@ namespace noisepage::execution::exec { OutputBuffer *ExecutionContext::OutputBufferNew() { - if (schema_ == nullptr) { + if (output_schema_ == nullptr) { return nullptr; } // Use C++ placement new auto size = sizeof(OutputBuffer); auto *buffer = reinterpret_cast(mem_pool_->Allocate(size)); - new (buffer) OutputBuffer(mem_pool_.get(), schema_->GetColumns().size(), ComputeTupleSize(schema_), callback_); + new (buffer) OutputBuffer(mem_pool_.get(), output_schema_->GetColumns().size(), ComputeTupleSize(output_schema_), + output_callback_); return buffer; } -uint32_t ExecutionContext::ComputeTupleSize(const planner::OutputSchema *schema) { +uint32_t ExecutionContext::ComputeTupleSize(common::ManagedPointer schema) { uint32_t tuple_size = 0; for (const auto &col : schema->GetColumns()) { auto alignment = sql::ValUtil::GetSqlAlignment(col.GetType()); @@ -106,7 +106,8 @@ void ExecutionContext::EndResourceTracker(const char *name, uint32_t len) { common::thread_context.resource_tracker_.Stop(); common::thread_context.resource_tracker_.SetMemory(mem_tracker_->GetAllocatedSize()); const auto &resource_metrics = common::thread_context.resource_tracker_.GetMetrics(); - common::thread_context.metrics_store_->RecordExecutionData(name, len, execution_mode_, resource_metrics); + common::thread_context.metrics_store_->RecordExecutionData(name, len, static_cast(execution_mode_), + resource_metrics); } } @@ -147,8 +148,8 @@ void ExecutionContext::EndPipelineTracker(query_id_t query_id, pipeline_id_t pip NOISEPAGE_ASSERT(pipeline_id == ouvec->pipeline_id_, "Incorrect feature vector pipeline id?"); selfdriving::ExecutionOperatingUnitFeatureVector features(ouvec->pipeline_features_->begin(), ouvec->pipeline_features_->end()); - common::thread_context.metrics_store_->RecordPipelineData(query_id, pipeline_id, execution_mode_, - std::move(features), resource_metrics); + common::thread_context.metrics_store_->RecordPipelineData( + query_id, pipeline_id, static_cast(execution_mode_), std::move(features), resource_metrics); } } @@ -219,10 +220,6 @@ void ExecutionContext::InitializeParallelOUFeatureVector(selfdriving::ExecOUFeat } } -const parser::ConstantValueExpression &ExecutionContext::GetParam(const uint32_t param_idx) const { - return (*params_)[param_idx]; -} - void ExecutionContext::RegisterHook(size_t hook_idx, HookFn hook) { NOISEPAGE_ASSERT(hook_idx < hooks_.capacity(), "Incorrect number of reserved hooks"); hooks_[hook_idx] = hook; diff --git a/src/execution/exec/execution_context_builder.cpp b/src/execution/exec/execution_context_builder.cpp new file mode 100644 index 0000000000..318af2f4c5 --- /dev/null +++ b/src/execution/exec/execution_context_builder.cpp @@ -0,0 +1,58 @@ +#include "execution/exec/execution_context_builder.h" + +#include "common/error/error_code.h" +#include "common/error/exception.h" +#include "common/macros.h" +#include "execution/exec/execution_context.h" +#include "parser/expression/constant_value_expression.h" + +namespace noisepage::execution::exec { + +std::unique_ptr ExecutionContextBuilder::Build() { + if (db_oid_ == catalog::INVALID_DATABASE_OID) { + throw EXECUTION_EXCEPTION("Must specify database OID.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!exec_settings_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify exection settings.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!txn_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify a transaction context.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!output_schema_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify output schema.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!output_callback_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify output callback.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!catalog_accessor_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify catalog accessor.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!metrics_manager_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify metrics manager.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!replication_manager_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify replication manager.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + if (!recovery_manager_.has_value()) { + throw EXECUTION_EXCEPTION("Must specify recovery manager.", common::ErrorCode::ERRCODE_INTERNAL_ERROR); + } + + // Query parameters (parameters_) is not validated because default is empty collection + // ExecutionSettings exec_settings = exec_settings_.value(); + return std::unique_ptr{ + new ExecutionContext{db_oid_, std::move(parameters_), exec_settings_.value(), txn_.value(), + output_schema_.value(), std::move(output_callback_.value()), catalog_accessor_.value(), + metrics_manager_.value(), replication_manager_.value(), recovery_manager_.value()}}; +} + +ExecutionContextBuilder &ExecutionContextBuilder::WithQueryParametersFrom( + const std::vector ¶meter_exprs) { + NOISEPAGE_ASSERT(parameters_.empty(), "Attempt to initialize query parameters more than once."); + parameters_.reserve(parameter_exprs.size()); + std::transform(parameter_exprs.cbegin(), parameter_exprs.cend(), std::back_inserter(parameters_), + [](const parser::ConstantValueExpression &expr) -> common::ManagedPointer { + return common::ManagedPointer{expr.SqlValue()}; + }); + return *this; +} +} // namespace noisepage::execution::exec diff --git a/src/execution/parsing/parser.cpp b/src/execution/parsing/parser.cpp index 592c91eaca..549cefa852 100644 --- a/src/execution/parsing/parser.cpp +++ b/src/execution/parsing/parser.cpp @@ -425,11 +425,44 @@ ast::Expr *Parser::ParseUnaryOpExpr() { return ParsePrimaryExpr(); } +ast::Expr *Parser::ParseLambdaExpr() { + Expect(Token::Type::LAMBDA); + + const SourcePosition &position = scanner_->CurrentPosition(); + + util::RegionVector captures(Region()); + + Expect(Token::Type::LEFT_BRACKET); + + while (Peek() != Token::Type::RIGHT_BRACKET) { + if (Matches(Token::Type::IDENTIFIER)) { + auto var = GetSymbol(); + captures.push_back(new (Region()) ast::IdentifierExpr(position, var)); + } + + if (!Matches(Token::Type::COMMA)) { + break; + } + } + + Expect(Token::Type::RIGHT_BRACKET); + + // The function literal + auto *fun = ParseFunctionLitExpr()->As(); + + // Create declaration + auto *lambda = node_factory_->NewLambdaExpr(position, fun, std::move(captures)); + + // Done + return lambda; +} + ast::Expr *Parser::ParsePrimaryExpr() { - // PrimaryExpr = Operand | CallExpr | MemberExpr | IndexExpr ; + // PrimaryExpr = Operand | CallExpr | MemberExpr | IndexExpr | LambdaExpr ; // CallExpr = PrimaryExpr '(' (Expr)* ') ; // MemberExpr = PrimaryExpr '.' Expr // IndexExpr = PrimaryExpr '[' Expr ']' + // LambdaExpr = lambda (FunctionLitExpr) ast::Expr *result = ParseOperand(); @@ -538,6 +571,10 @@ ast::Expr *Parser::ParseOperand() { Expect(Token::Type::RIGHT_PAREN); return expr; } + case Token::Type::LAMBDA: { + return ParseLambdaExpr(); + break; + } default: { break; } @@ -584,6 +621,9 @@ ast::Expr *Parser::ParseType() { case Token::Type::STRUCT: { return ParseStructType(); } + case Token::Type::LAMBDA: { + return ParseLambdaType(); + } default: { break; } @@ -728,4 +768,20 @@ ast::Expr *Parser::ParseMapType() { return node_factory_->NewMapType(position, key_type, value_type); } +ast::Expr *Parser::ParseLambdaType() { + // LambdaType = 'lambda' '[' FunctionExpr ']' ; + + const SourcePosition &position = scanner_->CurrentPosition(); + + Consume(Token::Type::LAMBDA); + + Expect(Token::Type::LEFT_BRACKET); + + ast::Expr *fn_type = ParseFunctionType(); + + Expect(Token::Type::RIGHT_BRACKET); + + return node_factory_->NewLambdaType(position, fn_type); +} + } // namespace noisepage::execution::parsing diff --git a/src/execution/parsing/scanner.cpp b/src/execution/parsing/scanner.cpp index 16c0e18f12..f349f80bb6 100644 --- a/src/execution/parsing/scanner.cpp +++ b/src/execution/parsing/scanner.cpp @@ -298,6 +298,8 @@ Token::Type Scanner::ScanIdentifierOrKeyword() { GROUP_START('i') \ GROUP_ELEM("if", Token::Type::IF) \ GROUP_ELEM("in", Token::Type::IN) \ + GROUP_START('l') \ + GROUP_ELEM("lambda", Token::Type::LAMBDA) \ GROUP_START('m') \ GROUP_ELEM("map", Token::Type::MAP) \ GROUP_START('n') \ diff --git a/src/execution/sema/scope.cpp b/src/execution/sema/scope.cpp index 16baf64cd5..f2357eb56f 100644 --- a/src/execution/sema/scope.cpp +++ b/src/execution/sema/scope.cpp @@ -29,4 +29,18 @@ ast::Type *Scope::LookupLocal(ast::Identifier name) const { return (iter == decls_.end() ? nullptr : iter->second); } +Scope::Kind Scope::GetKind() const { return scope_kind_; } + +std::vector> Scope::GetLocals() const { + std::vector> locals; + auto scope = this; + do { + for (auto it : scope->decls_) { + locals.emplace_back(it.first, it.second); + } + scope = scope->outer_; + } while (scope->scope_kind_ != Scope::Kind::Function); + return locals; +} + } // namespace noisepage::execution::sema diff --git a/src/execution/sema/sema_builtin.cpp b/src/execution/sema/sema_builtin.cpp index 421e7fb89d..923f834010 100644 --- a/src/execution/sema/sema_builtin.cpp +++ b/src/execution/sema/sema_builtin.cpp @@ -1381,10 +1381,11 @@ void Sema::CheckBuiltinTableIterParCall(ast::CallExpr *call) { // Check the type of the scanner function parameters. See TableVectorIterator::ScanFn. const auto tvi_kind = ast::BuiltinType::TableVectorIterator; const auto ¶ms = scan_fn_type->GetParams(); - if (params.size() != 3 // Scan function has 3 arguments. - || !params[0].type_->IsPointerType() // QueryState, must contain execCtx. - || !params[1].type_->IsPointerType() // Thread state. - || !IsPointerToSpecificBuiltin(params[2].type_, tvi_kind)) { // TableVectorIterator. + + if (params.size() != 3 // Call has 3 parameters + || !params[0].GetType()->IsPointerType() // QueryState* + || !params[1].GetType()->IsPointerType() // PipelineState* + || !IsPointerToSpecificBuiltin(params[2].GetType(), tvi_kind)) { // TableVectorIterator* GetErrorReporter()->Report(call->Position(), ErrorMessages::kBadParallelScanFunction, call_args[5]->GetType()); return; } @@ -2079,22 +2080,27 @@ void Sema::CheckBuiltinPtrCastCall(ast::CallExpr *call) { return; } + if (call->Arguments()[0]->GetType() != nullptr && call->Arguments()[1]->GetType() != nullptr && + call->Arguments()[0]->GetType()->IsPointerType() && call->Arguments()[1]->GetType()->IsPointerType()) { + return; + } + // The first argument will be a UnaryOpExpr with the '*' (star) op. This is // because parsing function calls assumes expression arguments, not types. So, // something like '*Type', which would be the first argument to @ptrCast, will // get parsed as a dereference expression before a type expression. // TODO(pmenon): Fix the above to parse correctly - auto unary_op = call->Arguments()[0]->SafeAs(); - if (unary_op == nullptr || unary_op->Op() != parsing::Token::Type::STAR) { - GetErrorReporter()->Report(call->Position(), ErrorMessages::kBadArgToPtrCast, call->Arguments()[0]->GetType(), 1); - return; + if (!call->Arguments()[0]->Is()) { + auto unary_op = call->Arguments()[0]->SafeAs(); + if (unary_op == nullptr || unary_op->Op() != parsing::Token::Type::STAR) { + GetErrorReporter()->Report(call->Position(), ErrorMessages::kBadArgToPtrCast, call->Arguments()[0]->GetType(), 1); + return; + } + call->SetArgument( + 0, GetContext()->GetNodeFactory()->NewPointerType(call->Arguments()[0]->Position(), unary_op->Input())); } - // Replace the unary with a PointerTypeRepr node and resolve it - call->SetArgument( - 0, GetContext()->GetNodeFactory()->NewPointerType(call->Arguments()[0]->Position(), unary_op->Input())); - for (auto *arg : call->Arguments()) { auto *resolved_type = Resolve(arg); if (resolved_type == nullptr) { @@ -3098,10 +3104,6 @@ void Sema::CheckBuiltinAbortCall(ast::CallExpr *call) { } void Sema::CheckBuiltinParamCall(ast::CallExpr *call, ast::Builtin builtin) { - if (!CheckArgCount(call, 2)) { - return; - } - // first argument is an exec ctx auto exec_ctx_kind = ast::BuiltinType::ExecutionContext; if (!IsPointerToSpecificBuiltin(call->Arguments()[0]->GetType(), exec_ctx_kind)) { @@ -3110,48 +3112,98 @@ void Sema::CheckBuiltinParamCall(ast::CallExpr *call, ast::Builtin builtin) { } // second argument is the index of the parameter - if (!call->Arguments()[1]->GetType()->IsIntegerType()) { - ReportIncorrectCallArg(call, 0, GetBuiltinType(ast::BuiltinType::Kind::Uint32)); - return; - } - - // Type output sql value - ast::BuiltinType::Kind sql_type; - switch (builtin) { - case ast::Builtin::GetParamBool: { - sql_type = ast::BuiltinType::Boolean; - break; - } - case ast::Builtin::GetParamTinyInt: - case ast::Builtin::GetParamSmallInt: - case ast::Builtin::GetParamInt: - case ast::Builtin::GetParamBigInt: { - sql_type = ast::BuiltinType::Integer; - break; - } - case ast::Builtin::GetParamReal: - case ast::Builtin::GetParamDouble: { - sql_type = ast::BuiltinType::Real; - break; + if (builtin < ast::Builtin::StartNewParams) { + if (!call->Arguments()[1]->GetType()->IsIntegerType()) { + ReportIncorrectCallArg(call, 1, GetBuiltinType(ast::BuiltinType::Kind::Uint32)); + return; } - case ast::Builtin::GetParamDate: { - sql_type = ast::BuiltinType::Date; - break; + + // Type output sql value + ast::BuiltinType::Kind sql_type; + switch (builtin) { + case ast::Builtin::GetParamBool: { + sql_type = ast::BuiltinType::Boolean; + break; + } + case ast::Builtin::GetParamTinyInt: + case ast::Builtin::GetParamSmallInt: + case ast::Builtin::GetParamInt: + case ast::Builtin::GetParamBigInt: { + sql_type = ast::BuiltinType::Integer; + break; + } + case ast::Builtin::GetParamReal: + case ast::Builtin::GetParamDouble: { + sql_type = ast::BuiltinType::Real; + break; + } + case ast::Builtin::GetParamDate: { + sql_type = ast::BuiltinType::Date; + break; + } + case ast::Builtin::GetParamTimestamp: { + sql_type = ast::BuiltinType::Timestamp; + break; + } + case ast::Builtin::GetParamString: { + sql_type = ast::BuiltinType::StringVal; + break; + } + default: + UNREACHABLE("Undefined parameter call!!"); } - case ast::Builtin::GetParamTimestamp: { - sql_type = ast::BuiltinType::Timestamp; - break; + // Return sql type + call->SetType(ast::BuiltinType::Get(GetContext(), sql_type)); + return; + } + if (builtin > ast::Builtin::FinishNewParams) { + ast::BuiltinType::Kind add_sql_type; + switch (builtin) { + case ast::Builtin::AddParamBool: { + add_sql_type = ast::BuiltinType::Boolean; + break; + } + case ast::Builtin::AddParamTinyInt: + case ast::Builtin::AddParamSmallInt: + case ast::Builtin::AddParamInt: + case ast::Builtin::AddParamBigInt: { + add_sql_type = ast::BuiltinType::Integer; + break; + } + case ast::Builtin::AddParamReal: + case ast::Builtin::AddParamDouble: { + add_sql_type = ast::BuiltinType::Real; + break; + } + case ast::Builtin::AddParamDate: { + add_sql_type = ast::BuiltinType::Date; + break; + } + case ast::Builtin::AddParamTimestamp: { + add_sql_type = ast::BuiltinType::Timestamp; + break; + } + case ast::Builtin::AddParamString: { + add_sql_type = ast::BuiltinType::StringVal; + break; + } + default: { + UNREACHABLE("Undefined parameter call!!"); + } } - case ast::Builtin::GetParamString: { - sql_type = ast::BuiltinType::StringVal; - break; + if (call->Arguments()[1]->GetType() != GetBuiltinType(add_sql_type)) { + ReportIncorrectCallArg(call, 1, GetBuiltinType(add_sql_type)); + return; } - default: - UNREACHABLE("Undefined parameter call!!"); } + call->SetType(ast::BuiltinType::Get(GetContext(), ast::BuiltinType::Nil)); +} - // Return sql type - call->SetType(ast::BuiltinType::Get(GetContext(), sql_type)); +void Sema::CheckBuiltinRandomCall(ast::CallExpr *call, ast::Builtin builtin) { + if (!CheckArgCount(call, 0)) { + return; + } + call->SetType(ast::BuiltinType::Get(GetContext(), ast::BuiltinType::Kind::Real)); } void Sema::CheckBuiltinStringCall(ast::CallExpr *call, ast::Builtin builtin) { @@ -3980,7 +4032,19 @@ void Sema::CheckBuiltinCall(ast::CallExpr *call) { case ast::Builtin::GetParamDouble: case ast::Builtin::GetParamDate: case ast::Builtin::GetParamTimestamp: - case ast::Builtin::GetParamString: { + case ast::Builtin::GetParamString: + case ast::Builtin::AddParamBool: + case ast::Builtin::AddParamTinyInt: + case ast::Builtin::AddParamSmallInt: + case ast::Builtin::AddParamInt: + case ast::Builtin::AddParamBigInt: + case ast::Builtin::AddParamReal: + case ast::Builtin::AddParamDouble: + case ast::Builtin::AddParamDate: + case ast::Builtin::AddParamTimestamp: + case ast::Builtin::AddParamString: + case ast::Builtin::StartNewParams: + case ast::Builtin::FinishNewParams: { CheckBuiltinParamCall(call, builtin); break; } @@ -4079,6 +4143,10 @@ void Sema::CheckBuiltinCall(ast::CallExpr *call) { CheckBuiltinTestCatalogIndexLookup(call); break; } + case ast::Builtin::Random: { + CheckBuiltinRandomCall(call, builtin); + break; + } default: UNREACHABLE("Unhandled builtin!"); } diff --git a/src/execution/sema/sema_checking.cpp b/src/execution/sema/sema_checking.cpp index a97accec57..2f74949bac 100644 --- a/src/execution/sema/sema_checking.cpp +++ b/src/execution/sema/sema_checking.cpp @@ -269,6 +269,13 @@ bool Sema::CheckAssignmentConstraints(ast::Type *target_type, ast::Expr **expr) return true; } + // Lambdas (more accurately, the closures produced by lambda expressions) + if (target_type->IsLambdaType() && (*expr)->GetType()->IsLambdaType()) { + auto expr_fn_type = (*expr)->GetType()->As()->GetFunctionType(); + auto target_fn_type = target_type->As()->GetFunctionType(); + return expr_fn_type->IsEqual(target_fn_type); + } + // Integer expansion if (target_type->IsIntegerType() && (*expr)->GetType()->IsIntegerType()) { if (target_type->GetSize() > (*expr)->GetType()->GetSize()) { diff --git a/src/execution/sema/sema_decl.cpp b/src/execution/sema/sema_decl.cpp index 71beaf7e81..aa79334a09 100644 --- a/src/execution/sema/sema_decl.cpp +++ b/src/execution/sema/sema_decl.cpp @@ -24,6 +24,9 @@ void Sema::VisitVariableDecl(ast::VariableDecl *node) { } if (node->HasInitialValue()) { + if (node->Initial()->GetKind() == ast::AstNode::Kind::LambdaExpr) { + node->Initial()->As()->name_ = node->Name(); + } initializer_type = Resolve(node->Initial()); } diff --git a/src/execution/sema/sema_expr.cpp b/src/execution/sema/sema_expr.cpp index bb48f23bf2..66f99453e8 100644 --- a/src/execution/sema/sema_expr.cpp +++ b/src/execution/sema/sema_expr.cpp @@ -83,6 +83,11 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { return; } + // Type checking already performed + if (node->GetType() != nullptr) { + return; + } + // Resolve the function type ast::Type *type = Resolve(node->Function()); if (type == nullptr) { @@ -91,13 +96,30 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { // Check that the resolved function type is actually a function auto *func_type = type->SafeAs(); + auto *struct_type = type->SafeAs(); + auto lambda_adjustment = 1; if (func_type == nullptr) { - GetErrorReporter()->Report(node->Position(), ErrorMessages::kNonFunction); - return; + if (struct_type != nullptr) { + func_type = struct_type->GetFunctionType(); + // TODO(Kyle): Find a better way to see if sema has processed this already + ast::IdentifierExpr *last_arg = nullptr; + if (!node->Arguments().empty()) { + last_arg = node->Arguments().back()->SafeAs(); + } + if (last_arg != nullptr && last_arg->Name() == node->GetFuncName()) { + // already processed + lambda_adjustment = 0; + } + } else { + GetErrorReporter()->Report(node->Position(), ErrorMessages::kNonFunction); + return; + } } // Check argument count matches - if (!CheckArgCount(node, func_type->GetNumParams())) { + const auto arg_count = + (struct_type != nullptr) ? func_type->GetNumParams() - lambda_adjustment : func_type->GetNumParams(); + if (!CheckArgCount(node, arg_count)) { return; } @@ -133,6 +155,10 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { } } + if (struct_type != nullptr && lambda_adjustment > 0) { + node->PushArgument(GetContext()->GetNodeFactory()->NewIdentifierExpr(SourcePosition(), node->GetFuncName())); + } + if (has_errors) { return; } @@ -141,6 +167,70 @@ void Sema::VisitCallExpr(ast::CallExpr *node) { node->SetType(func_type->GetReturnType()); } +void Sema::VisitLambdaExpr(ast::LambdaExpr *node) { + auto factory = GetContext()->GetNodeFactory(); + + // Resolve the types necessary to get the type representation + // used to implement captures for closures produced by lambdas + + // TODO(Kyle): We perform quite a bit of mutation here during + // semantic analysis because this is where we resolve the type + // of the captures for the closure produced by the lambda expression; + // in the future we might want to revisit this to determine if + // we can perform this resolution during AST construction instead. + + util::RegionVector fields(GetContext()->GetRegion()); + for (auto expr : node->GetCaptureIdents()) { + auto ident = expr->As(); + Resolve(ident); + if (ident->GetType()->SafeAs() != nullptr) { + auto type_repr = factory->NewPointerType( + SourcePosition(), + factory->NewIdentifierExpr( + SourcePosition(), + GetContext()->GetIdentifier( + ast::BuiltinType::Get(GetContext(), ident->GetType()->As()->GetKind()) + ->GetTplName()))); + fields.push_back(factory->NewFieldDecl(SourcePosition(), ident->Name(), type_repr)); + } else { + util::RegionVector nested_fields{GetContext()->GetRegion()}; + for (const auto &field : ident->GetType()->SafeAs()->GetFieldsWithoutPadding()) { + nested_fields.push_back(factory->NewFieldDecl( + SourcePosition(), field.name_, + factory->NewIdentifierExpr( + SourcePosition(), + GetContext()->GetIdentifier( + ast::BuiltinType::Get(GetContext(), field.type_->As()->GetKind()) + ->GetTplName())))); + } + auto *type_repr = + factory->NewPointerType(SourcePosition(), factory->NewStructType(SourcePosition(), std::move(nested_fields))); + fields.push_back(factory->NewFieldDecl(SourcePosition(), ident->Name(), type_repr)); + } + } + + fields.push_back( + factory->NewFieldDecl(SourcePosition(), GetContext()->GetIdentifier("function"), + factory->NewPointerType(SourcePosition(), node->GetFunctionLiteralExpr()->TypeRepr()))); + + ast::StructTypeRepr *struct_type_repr = factory->NewStructType(SourcePosition(), std::move(fields)); + ast::StructDecl *struct_decl = factory->NewStructDecl( + SourcePosition(), GetContext()->GetIdentifier("lambda" + std::to_string(node->Position().line_)), + struct_type_repr); + VisitStructDecl(struct_decl); + node->SetCaptureStructType(Resolve(struct_type_repr)); + node->SetType(ast::LambdaType::Get(Resolve(node->GetFunctionLiteralExpr()->TypeRepr())->As())); + + auto type = Resolve(node->GetFunctionLiteralExpr()->TypeRepr()); + auto fn_type = type->As(); + fn_type->GetParams().emplace_back(GetContext()->GetIdentifier("captures"), + GetBuiltinType(ast::BuiltinType::Kind::Int32)->PointerTo()); + fn_type->SetIsLambda(true); + fn_type->SetCapturesType(node->GetCaptureStructType()->As()); + + VisitFunctionLitExpr(node->GetFunctionLiteralExpr()); +} + void Sema::VisitFunctionLitExpr(ast::FunctionLitExpr *node) { // Resolve the type, if not resolved already if (auto *type = node->TypeRepr()->GetType(); type == nullptr) { @@ -157,6 +247,15 @@ void Sema::VisitFunctionLitExpr(ast::FunctionLitExpr *node) { // The function scope FunctionSemaScope function_scope(this, node); + if (node->IsLambda()) { + auto ¶ms = func_type->GetParams(); + auto captures = params[params.size() - 1]; + auto capture_type = captures.type_->As(); + for (auto field : capture_type->GetFieldsWithoutPadding()) { + GetCurrentScope()->Declare(field.name_, field.type_->GetPointeeType()->ReferenceTo()); + } + } + // Declare function parameters in scope for (const auto ¶m : func_type->GetParams()) { GetCurrentScope()->Declare(param.name_, param.type_); @@ -192,12 +291,18 @@ void Sema::VisitIdentifierExpr(ast::IdentifierExpr *node) { return; } + if (auto *type = GetCurrentScope()->Lookup(node->Name())) { + node->SetType(type); + return; + } + // Error GetErrorReporter()->Report(node->Position(), ErrorMessages::kUndefinedVariable, node->Name()); } void Sema::VisitImplicitCastExpr(ast::ImplicitCastExpr *node) { - throw std::runtime_error("Should never perform semantic checking on implicit cast expressions"); + // TODO(Kyle): Why did we throw here before? + Visit(node->Input()); } void Sema::VisitIndexExpr(ast::IndexExpr *node) { diff --git a/src/execution/sema/sema_stmt.cpp b/src/execution/sema/sema_stmt.cpp index e0972962e9..788207de09 100644 --- a/src/execution/sema/sema_stmt.cpp +++ b/src/execution/sema/sema_stmt.cpp @@ -24,6 +24,11 @@ void Sema::VisitAssignmentStmt(ast::AssignmentStmt *node) { if (source != node->Source()) { node->SetSource(source); } + + if (src_type->IsFunctionType()) { + // this is a lambda function assignment + node->Source()->As()->name_ = node->Destination()->As()->Name(); + } } void Sema::VisitBlockStmt(ast::BlockStmt *node) { @@ -57,6 +62,16 @@ void Sema::VisitForStmt(ast::ForStmt *node) { return; } // If the resolved type isn't a boolean, it's an error + if (cond_type->IsSqlBooleanType()) { + auto context = GetContext(); + auto factory = context->GetNodeFactory(); + auto args = util::RegionVector({node->Condition()}, context->GetRegion()); + node->SetCondition(factory->NewBuiltinCallExpr( + factory->NewIdentifierExpr(node->Position(), + GetContext()->GetBuiltinFunction(execution::ast::Builtin::SqlToBool)), + std::move(args))); + cond_type = Resolve(node->Condition()); + } if (!cond_type->IsBoolType()) { error_reporter_->Report(node->Condition()->Position(), ErrorMessages::kNonBoolForCondition); } @@ -70,6 +85,22 @@ void Sema::VisitForStmt(ast::ForStmt *node) { Visit(node->Body()); } +void Sema::VisitBreakStmt(ast::BreakStmt *node) { + // Look for a loop in my scope stack + auto scope = GetCurrentScope(); + bool found_loop = false; + while (scope != nullptr) { + found_loop |= scope->GetKind() == Scope::Kind::Loop; + if (found_loop) { + break; + } + scope = scope->Outer(); + } + if (!found_loop) { + error_reporter_->Report(node->Position(), ErrorMessages::kNoScopeToBreak); + } +} + void Sema::VisitForInStmt(ast::ForInStmt *node) { NOISEPAGE_ASSERT(false, "Not supported"); } void Sema::VisitExpressionStmt(ast::ExpressionStmt *node) { Visit(node->Expression()); } diff --git a/src/execution/sema/sema_type.cpp b/src/execution/sema/sema_type.cpp index fb5087c715..eb619b188d 100644 --- a/src/execution/sema/sema_type.cpp +++ b/src/execution/sema/sema_type.cpp @@ -90,4 +90,24 @@ void Sema::VisitMapTypeRepr(ast::MapTypeRepr *node) { node->SetType(ast::MapType::Get(key_type, value_type)); } +void Sema::VisitLambdaTypeRepr(ast::LambdaTypeRepr *node) { + auto *fn_type = Resolve(node->FunctionType())->SafeAs(); + if (fn_type == nullptr) { + return; + } + + // Captures are passed to the function that implements the lambda + // by way of the final parameter to the function; the parameter is + // always specified as an Int32 pointer and then we emit the code + // necessary to dereference the pointers within the structure + // (relative to the base pointer) appropriately to extract captures + + // TODO(Kyle): This seems like a potentially-expedient yet needlessly + // confusing (and potentially unsafe?) way to implement the passage + // the captures structure to the function that implements the closure + fn_type->GetParams().emplace_back(GetContext()->GetIdentifier("captures"), + GetBuiltinType(ast::BuiltinType::Kind::Int32)->PointerTo()); + node->SetType(ast::LambdaType::Get(fn_type)); +} + } // namespace noisepage::execution::sema diff --git a/src/execution/sql/ddl_executors.cpp b/src/execution/sql/ddl_executors.cpp index 13ed211cc4..b4e44aaf08 100644 --- a/src/execution/sql/ddl_executors.cpp +++ b/src/execution/sql/ddl_executors.cpp @@ -5,14 +5,26 @@ #include #include "catalog/catalog_accessor.h" +#include "catalog/postgres/pg_language.h" #include "common/macros.h" +#include "execution/ast/ast_pretty_print.h" +#include "execution/ast/context.h" +#include "execution/ast/udf/udf_ast_context.h" +#include "execution/compiler/codegen.h" +#include "execution/compiler/function_builder.h" +#include "execution/compiler/udf/udf_codegen.h" #include "execution/exec/execution_context.h" +#include "execution/sema/sema.h" +#include "loggers/execution_logger.h" #include "parser/expression/column_value_expression.h" +#include "parser/udf/plpgsql_parser.h" #include "planner/plannodes/create_database_plan_node.h" +#include "planner/plannodes/create_function_plan_node.h" #include "planner/plannodes/create_index_plan_node.h" #include "planner/plannodes/create_namespace_plan_node.h" #include "planner/plannodes/create_table_plan_node.h" #include "planner/plannodes/drop_database_plan_node.h" +#include "planner/plannodes/drop_function_plan_node.h" #include "planner/plannodes/drop_index_plan_node.h" #include "planner/plannodes/drop_namespace_plan_node.h" #include "planner/plannodes/drop_table_plan_node.h" @@ -33,6 +45,118 @@ bool DDLExecutors::CreateNamespaceExecutor(const common::ManagedPointerCreateNamespace(node->GetNamespaceName()) != catalog::INVALID_NAMESPACE_OID; } +bool DDLExecutors::CreateFunctionExecutor(const common::ManagedPointer node, + const common::ManagedPointer accessor) { + NOISEPAGE_ASSERT(node->GetUDFLanguage() == parser::PLType::PL_PGSQL, "Unsupported language"); + NOISEPAGE_ASSERT(!node->GetFunctionBody().empty(), "Unsupported function body contents"); + + const auto ¶meter_types = node->GetFunctionParameterTypes(); + + std::vector param_type_ids{}; + std::transform( + parameter_types.cbegin(), parameter_types.cend(), std::back_inserter(param_type_ids), + [](const parser::BaseFunctionParameter::DataType &t) { return parser::FuncParameter::DataTypeToTypeId(t); }); + + std::vector param_types{}; + std::transform(parameter_types.cbegin(), parameter_types.cend(), std::back_inserter(param_types), + [&accessor](const parser::BaseFunctionParameter::DataType &t) { + return accessor->GetTypeOidFromTypeId(parser::FuncParameter::DataTypeToTypeId(t)); + }); + + const auto return_type = accessor->GetTypeOidFromTypeId(parser::ReturnType::DataTypeToTypeId(node->GetReturnType())); + auto proc_id = + accessor->CreateProcedure(node->GetFunctionName(), catalog::postgres::PgLanguage::PLPGSQL_LANGUAGE_OID, + node->GetNamespaceOid(), catalog::INVALID_TYPE_OID, node->GetFunctionParameterNames(), + param_types, {}, {}, return_type, node->GetFunctionBody().front(), false); + if (proc_id == catalog::INVALID_PROC_OID) { + return false; + } + + // Make the context here using the body + ast::udf::UdfAstContext udf_ast_context{}; + + // TODO(Kyle): Revisit this after clearing up what the + // preferred way to report errors is in the system, both + // within components and between components... + parser::udf::PLpgSQLParser udf_parser{common::ManagedPointer{&udf_ast_context}}; + std::unique_ptr ast{}; + try { + ast = udf_parser.Parse(node->GetFunctionParameterNames(), param_type_ids, node->GetFunctionBody().front()); + } catch (const ParserException &parser_error) { + PARSER_LOG_ERROR(parser_error.what()); + return false; + } + + auto region = std::make_unique(node->GetFunctionName()); + sema::ErrorReporter error_reporter{region.get()}; + + auto ast_context = std::make_unique(region.get(), &error_reporter); + + compiler::CodeGen codegen{ast_context.get(), accessor.Get()}; + util::RegionVector fn_params{codegen.GetAstContext()->GetRegion()}; + fn_params.emplace_back( + codegen.MakeField(codegen.MakeFreshIdentifier("executionCtx"), + codegen.PointerType(codegen.BuiltinType(ast::BuiltinType::ExecutionContext)))); + + for (auto i = 0UL; i < node->GetFunctionParameterNames().size(); i++) { + const auto raw = node->GetFunctionParameterTypes()[i]; + (void)raw; + const auto name = node->GetFunctionParameterNames()[i]; + const auto type = parser::BaseFunctionParameter::DataTypeToTypeId(node->GetFunctionParameterTypes()[i]); + fn_params.emplace_back( + codegen.MakeField(ast_context->GetIdentifier(name), codegen.TplType(execution::sql::GetTypeId(type)))); + } + + auto name = node->GetFunctionName(); + compiler::FunctionBuilder fb{ + &codegen, codegen.MakeFreshIdentifier(name), std::move(fn_params), + codegen.TplType(execution::sql::GetTypeId(parser::ReturnType::DataTypeToTypeId(node->GetReturnType())))}; + + // Run UDF code generation + ast::File *file; + try { + file = compiler::udf::UdfCodegen::Run(accessor.Get(), &fb, &udf_ast_context, &codegen, node->GetDatabaseOid(), + ast.get()); + } catch (const BinderException &binder_error) { + EXECUTION_LOG_ERROR(binder_error.what()); + return false; + } catch (const ExecutionException &execution_error) { + EXECUTION_LOG_ERROR(execution_error.what()); + return false; + } + + { + sema::Sema type_check{codegen.GetAstContext().Get()}; + type_check.GetErrorReporter()->Reset(); + if (type_check.Run(file)) { + EXECUTION_LOG_ERROR("Errors: \n {}", type_check.GetErrorReporter()->SerializeErrors()); + return false; + } + } + + // TODO(Kyle): We are recomputing the types here because we lost + // them to a std::move() above when we generate the AST, can we + // avoid duplicating this work? Would need to change the APIs. + + std::vector types{}; + types.reserve(node->GetFunctionParameterTypes().size()); + std::transform(node->GetFunctionParameterTypes().cbegin(), node->GetFunctionParameterTypes().cend(), + std::back_inserter(types), [](const parser::BaseFunctionParameter::DataType &type) -> sql::SqlTypeId { + return parser::FuncParameter::DataTypeToTypeId(type); + }); + + auto udf_context = std::make_unique( + node->GetFunctionName(), parser::ReturnType::DataTypeToTypeId(node->GetReturnType()), std::move(types), + std::move(region), std::move(ast_context), file); + + // TODO(Kyle): We used to manually register an abort action here to destroy the + // function context in the event the transaction aborts, but this is already + // done in the catalog (in the call to CatalogAccessor::SetFunctionContext), is + // this the "ownership model" for transaction abort that we want? + + return accessor->SetFunctionContext(proc_id, udf_context.release()); +} + bool DDLExecutors::CreateTableExecutor(const common::ManagedPointer node, const common::ManagedPointer accessor, const catalog::db_oid_t connection_db) { @@ -162,4 +286,9 @@ bool DDLExecutors::CreateIndex(const common::ManagedPointer node, + common::ManagedPointer accessor) { + return accessor->DropProcedure(node->GetProcedureOid()); +} + } // namespace noisepage::execution::sql diff --git a/src/execution/sql/functions/system_functions.cpp b/src/execution/sql/functions/system_functions.cpp index 4d023ce669..3e71dc4071 100644 --- a/src/execution/sql/functions/system_functions.cpp +++ b/src/execution/sql/functions/system_functions.cpp @@ -1,5 +1,7 @@ #include "execution/sql/functions/system_functions.h" +#include + #include "common/version.h" #include "execution/exec/execution_context.h" @@ -10,4 +12,12 @@ void SystemFunctions::Version(UNUSED_ATTRIBUTE exec::ExecutionContext *ctx, Stri *result = StringVal(version); } +void SystemFunctions::Random(Real *result) { + // TODO(Kyle): Static locals are kind of gross, where + // should state for this type of one-off thing live? + static std::mt19937 generator{std::random_device{}()}; // NOLINT + static std::uniform_real_distribution<> distribution{0, 1}; + *result = Real(distribution(generator)); +} + } // namespace noisepage::execution::sql diff --git a/src/execution/sql/sql.cpp b/src/execution/sql/sql.cpp index 4922c57f1d..60284b80d4 100644 --- a/src/execution/sql/sql.cpp +++ b/src/execution/sql/sql.cpp @@ -272,8 +272,12 @@ std::string SqlTypeIdToString(SqlTypeId type) { return "Integer"; case SqlTypeId::BigInt: return "BigInt"; + case SqlTypeId::Real: + return "Real"; case SqlTypeId::Double: return "Double"; + case SqlTypeId::Decimal: + return "Decimal"; case SqlTypeId::Date: return "Date"; case SqlTypeId::Timestamp: @@ -282,6 +286,8 @@ std::string SqlTypeIdToString(SqlTypeId type) { return "Varchar"; case SqlTypeId::Varbinary: return "Varbinary"; + case SqlTypeId::Invalid: + return "Invalid"; default: // All cases handled UNREACHABLE("Impossible type"); @@ -329,8 +335,7 @@ SqlTypeId SqlTypeIdFromString(const std::string &type_string) { } TypeId GetTypeId(SqlTypeId frontend_type) { - execution::sql::TypeId execution_type_id; - + TypeId execution_type_id; switch (frontend_type) { case SqlTypeId::Boolean: execution_type_id = execution::sql::TypeId::Boolean; @@ -347,6 +352,9 @@ TypeId GetTypeId(SqlTypeId frontend_type) { case SqlTypeId::BigInt: execution_type_id = execution::sql::TypeId::BigInt; break; + case SqlTypeId::Real: + execution_type_id = execution::sql::TypeId::Float; + break; case SqlTypeId::Double: execution_type_id = execution::sql::TypeId::Double; break; diff --git a/src/execution/vm/bytecode_emitter.cpp b/src/execution/vm/bytecode_emitter.cpp index 7678d4610c..08bdcf4f8b 100644 --- a/src/execution/vm/bytecode_emitter.cpp +++ b/src/execution/vm/bytecode_emitter.cpp @@ -25,6 +25,10 @@ void BytecodeEmitter::EmitAssign(Bytecode bytecode, LocalVar dest, LocalVar src) EmitAll(bytecode, dest, src); } +void BytecodeEmitter::EmitAssignN(LocalVar dest, LocalVar src, uint32_t len) { + EmitAll(Bytecode::AssignN, dest, src.AddressOf(), len); +} + void BytecodeEmitter::EmitAssignImm1(LocalVar dest, int8_t val) { EmitAll(Bytecode::AssignImm1, dest, val); } void BytecodeEmitter::EmitAssignImm2(LocalVar dest, int16_t val) { EmitAll(Bytecode::AssignImm2, dest, val); } @@ -62,6 +66,18 @@ void BytecodeEmitter::EmitCall(FunctionId func_id, const std::vector & } } +std::function BytecodeEmitter::DeferredEmitCall(const std::vector ¶ms) { + NOISEPAGE_ASSERT(Bytecodes::GetNthOperandSize(Bytecode::Call, 1) == OperandSize::Short, + "Expected argument count to be 2-byte short"); + NOISEPAGE_ASSERT(params.size() < std::numeric_limits::max(), "Too many parameters!"); + auto bc_insert_index = bytecode_->size() + sizeof(Bytecode); + EmitAll(Bytecode::Call, std::numeric_limits::max(), static_cast(params.size())); + for (LocalVar local : params) { + EmitImpl(local); + } + return [=](FunctionId func_id) { EmitScalarValue(static_cast(func_id), bc_insert_index); }; +} + void BytecodeEmitter::EmitReturn() { EmitImpl(Bytecode::Return); } void BytecodeEmitter::Bind(BytecodeLabel *label) { diff --git a/src/execution/vm/bytecode_generator.cpp b/src/execution/vm/bytecode_generator.cpp index 07d8948d2e..06062bb371 100644 --- a/src/execution/vm/bytecode_generator.cpp +++ b/src/execution/vm/bytecode_generator.cpp @@ -164,7 +164,8 @@ void BytecodeGenerator::VisitIterationStatement(ast::IterationStmt *iteration, L } void BytecodeGenerator::VisitForStmt(ast::ForStmt *node) { - LoopBuilder loop_builder(this); + LoopBuilder *prev = current_loop_; + LoopBuilder loop_builder{this, prev}; if (node->Init() != nullptr) { Visit(node->Init()); @@ -177,7 +178,9 @@ void BytecodeGenerator::VisitForStmt(ast::ForStmt *node) { VisitExpressionForTest(node->Condition(), &loop_body_label, loop_builder.GetBreakLabel(), TestFallthrough::Then); } + current_loop_ = &loop_builder; VisitIterationStatement(node, &loop_builder); + current_loop_ = prev; if (node->Next() != nullptr) { Visit(node->Next()); @@ -186,6 +189,12 @@ void BytecodeGenerator::VisitForStmt(ast::ForStmt *node) { loop_builder.JumpToHeader(); } +void BytecodeGenerator::VisitBreakStmt(ast::BreakStmt *node) { + if (current_loop_ != nullptr && current_loop_->GetPrevLoop() != nullptr) { + current_loop_->GetPrevLoop()->Break(); + } +} + void BytecodeGenerator::VisitForInStmt(UNUSED_ATTRIBUTE ast::ForInStmt *node) { NOISEPAGE_ASSERT(false, "For-in statements not supported"); } @@ -197,7 +206,8 @@ void BytecodeGenerator::VisitFunctionDecl(ast::FunctionDecl *node) { auto *func_type = node->TypeRepr()->GetType()->As(); // Allocate the function - FunctionInfo *func_info = AllocateFunc(node->Name().GetData(), func_type); + auto *func_info = AllocateFunction(node->Name().GetData(), func_type); + EnterFunction(func_info->GetId()); { // Visit the body of the function. We use this handy scope object to track @@ -207,6 +217,68 @@ void BytecodeGenerator::VisitFunctionDecl(ast::FunctionDecl *node) { BytecodePositionScope position_scope(this, func_info); Visit(node->Function()); } + + // Execute the deferred actions for the function; + // in the current implementation, the only functionality + // that relies on deferred actions during code generation + // are TPL lambda expressions that generate closures + for (auto &f : func_info->actions_) { + f(); + } +} + +void BytecodeGenerator::VisitLambdaExpr(ast::LambdaExpr *node) { + // The function's TPL type + auto *func_type = node->GetFunctionLiteralExpr()->GetType()->As(); + + // Elide code generation for lambda expressions that are not stored + if (!GetExecutionResult()->HasDestination()) { + return; + } + + auto captures = + GetCurrentFunction()->NewLocal(node->GetCaptureStructType(), node->GetName().GetString() + "Captures"); + auto fields = node->GetCaptureStructType()->As()->GetFieldsWithoutPadding(); + + // Capture each of the values for the closure by storing the + // current value of the captured local in the captures struct + for (std::size_t i = 0; i < fields.size() - 1; ++i) { + auto field = fields[i]; + ast::IdentifierExpr ident{node->Position(), field.name_}; + ident.SetType(field.type_->GetPointeeType()); + LocalVar local = VisitExpressionForLValue(&ident); + LocalVar fieldvar = GetCurrentFunction()->NewLocal(field.type_->PointerTo(), ""); + GetEmitter()->EmitLea(fieldvar, captures.AddressOf(), + node->GetCaptureStructType()->As()->GetOffsetOfFieldByName(field.name_)); + GetEmitter()->EmitAssign(Bytecode::Assign8, fieldvar.ValueOf(), local); + } + + GetEmitter()->EmitAssign(Bytecode::Assign8, GetExecutionResult()->GetDestination(), captures.AddressOf()); + FunctionInfo *func_info = AllocateFunction(node->GetName().GetString(), func_type); + + // Create a new deferred action for the current function + // that visits the body of the lambda; this action is subsequently + // executed when the function declaration itself is visited + GetCurrentFunction()->DeferAction([=]() { + func_info->captures_ = captures; + func_info->is_lambda_ = true; + { + // Visit the body of the function. We use this handy scope object to track + // the start and end position of this function's bytecode in the module's + // bytecode array. Upon destruction, the scoped class will set the bytecode + // range in the function. + EnterFunction(func_info->GetId()); + BytecodePositionScope position_scope(this, func_info); + Visit(node->GetFunctionLiteralExpr()->Body()); + } + for (auto &f : func_info->actions_) { + f(); + } + }); +} + +void BytecodeGenerator::VisitLambdaTypeRepr(ast::LambdaTypeRepr *node) { + UNREACHABLE("Should not visit type-representation nodes!"); } void BytecodeGenerator::VisitIdentifierExpr(ast::IdentifierExpr *node) { @@ -216,9 +288,58 @@ void BytecodeGenerator::VisitIdentifierExpr(ast::IdentifierExpr *node) { const std::string local_name = node->Name().GetData(); LocalVar local = GetCurrentFunction()->LookupLocal(local_name); + std::string suffix{}; + bool capture = false; + + if (local.IsInvalid() && GetCurrentFunction()->is_lambda_) { + local = GetCurrentFunction()->LookupLocal(local_name + "ptr").ValueOf(); + suffix = "ptr"; + if (!local.IsInvalid()) { + if (GetExecutionResult()->IsRValue()) { + auto local_val = GetCurrentFunction()->NewLocal(node->GetType(), ""); + GetEmitter()->EmitDerefN(local_val, local.ValueOf(), node->GetType()->GetSize()); + local = local_val; + } + } + } + + if (local.IsInvalid()) { + NOISEPAGE_ASSERT(GetCurrentFunction()->is_lambda_, "Not a lambda and variable not found"); + auto params = GetCurrentFunction()->func_type_->GetParams(); + auto captures = GetCurrentFunction()->func_type_->GetCapturesType(); + for (auto field : captures->GetFieldsWithoutPadding()) { + if (field.name_.GetString() == local_name) { + auto captures_local = GetCurrentFunction()->LookupLocal("captures"); + + auto local_ptr = GetCurrentFunction()->NewLocal(field.type_->PointerTo()); + GetEmitter()->EmitLea(local_ptr, captures_local.ValueOf(), captures->GetOffsetOfFieldByName(field.name_)); + + auto local_ptr_2 = GetCurrentFunction()->NewLocal(field.type_, local_name + "ptr"); + GetEmitter()->EmitDerefN(local_ptr_2, local_ptr.ValueOf(), field.type_->GetSize()); + + local = local_ptr_2; + suffix = "ptr"; + + if (GetExecutionResult()->IsRValue()) { + local = GetCurrentFunction()->NewLocal(field.type_->GetPointeeType(), ""); + GetEmitter()->EmitDerefN(local, local_ptr_2.ValueOf(), field.type_->GetPointeeType()->GetSize()); + suffix = "val"; + } + local = local.ValueOf(); + break; + } + } + capture = true; + } + NOISEPAGE_ASSERT(!local.IsInvalid(), "Local not found"); if (GetExecutionResult()->IsLValue()) { - GetExecutionResult()->SetDestination(local); + auto *local_info = GetCurrentFunction()->LookupLocalInfoByOffset(local.GetOffset()); + if (local_info->GetType()->IsPointerType() && local_info->GetType()->GetPointeeType()->IsSqlValueType()) { + GetExecutionResult()->SetDestination(local.ValueOf()); + } else { + GetExecutionResult()->SetDestination(local); + } return; } @@ -236,34 +357,39 @@ void BytecodeGenerator::VisitIdentifierExpr(ast::IdentifierExpr *node) { // If the local we want the R-Value of is a parameter, we can't take its // pointer for the deref, so we use an assignment. Otherwise, a deref is good. - if (auto *local_info = GetCurrentFunction()->LookupLocalInfoByName(local_name); local_info->IsParameter()) { - BuildAssign(dest, local.ValueOf(), node->GetType()); + auto *local_info = GetCurrentFunction()->LookupLocalInfoByOffset(local.GetOffset()); + if (local_info->IsParameter()) { + if (local_info->GetType()->IsPointerType() && local_info->GetType()->GetPointeeType()->IsSqlValueType() && + GetExecutionResult()->IsRValue()) { + BuildDeref(dest, local.ValueOf(), node->GetType()); + } else { + BuildAssign(dest, local.ValueOf(), node->GetType()); + } } else { BuildDeref(dest, local, node->GetType()); } - GetExecutionResult()->SetDestination(dest); + GetExecutionResult()->SetDestination(capture ? dest.ValueOf() : dest); } void BytecodeGenerator::VisitImplicitCastExpr(ast::ImplicitCastExpr *node) { + LocalVar input = VisitExpressionForRValue(node->Input()); + switch (node->GetCastKind()) { case ast::CastKind::SqlBoolToBool: { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(node->GetType()); - LocalVar input = VisitExpressionForSQLValue(node->Input()); GetEmitter()->Emit(Bytecode::ForceBoolTruth, dest, input); GetExecutionResult()->SetDestination(dest.ValueOf()); break; } case ast::CastKind::BoolToSqlBool: { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(node->GetType()); - LocalVar input = VisitExpressionForRValue(node->Input()); GetEmitter()->Emit(Bytecode::InitBool, dest, input); GetExecutionResult()->SetDestination(dest); break; } case ast::CastKind::IntToSqlInt: { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(node->GetType()); - LocalVar input = VisitExpressionForRValue(node->Input()); ast::Expr *arg = node->Input(); Bytecode bytecode = Bytecode::InitInteger; @@ -281,7 +407,6 @@ void BytecodeGenerator::VisitImplicitCastExpr(ast::ImplicitCastExpr *node) { } case ast::CastKind::BitCast: case ast::CastKind::IntegralCast: { - LocalVar input = VisitExpressionForRValue(node->Input()); // As an optimization, we only issue a new assignment if the input and // output types of the cast have different sizes. if (node->Input()->GetType()->GetSize() != node->GetType()->GetSize()) { @@ -295,15 +420,13 @@ void BytecodeGenerator::VisitImplicitCastExpr(ast::ImplicitCastExpr *node) { } case ast::CastKind::FloatToSqlReal: { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(node->GetType()); - LocalVar input = VisitExpressionForRValue(node->Input()); - GetEmitter()->Emit(Bytecode::InitReal, dest, input); + GetEmitter()->Emit(Bytecode::InitReal, dest.AddressOf(), input.AddressOf()); GetExecutionResult()->SetDestination(dest); break; } case ast::CastKind::SqlIntToSqlReal: { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(node->GetType()); - LocalVar input = VisitExpressionForSQLValue(node->Input()); - GetEmitter()->Emit(Bytecode::IntegerToReal, dest, input); + GetEmitter()->Emit(Bytecode::IntegerToReal, dest.AddressOf(), input.AddressOf()); GetExecutionResult()->SetDestination(dest); break; } @@ -466,7 +589,7 @@ void BytecodeGenerator::VisitLogicalNotExpr(ast::UnaryOpExpr *op) { GetEmitter()->EmitUnaryOp(Bytecode::Not, dest, input); GetExecutionResult()->SetDestination(dest.ValueOf()); } else if (op->GetType()->IsSqlBooleanType()) { - input = VisitExpressionForSQLValue(op->Input()); + input = VisitExpressionForLValue(op->Input()); GetEmitter()->EmitUnaryOp(Bytecode::NotSql, dest, input); GetExecutionResult()->SetDestination(dest); } @@ -501,7 +624,7 @@ void BytecodeGenerator::VisitReturnStmt(ast::ReturnStmt *node) { if (node->Ret() != nullptr) { LocalVar rv = GetCurrentFunction()->GetReturnValueLocal(); if (node->Ret()->GetType()->IsSqlValueType()) { - LocalVar result = VisitExpressionForSQLValue(node->Ret()); + LocalVar result = VisitExpressionForLValue(node->Ret()); BuildDeref(rv.ValueOf(), result, node->Ret()->GetType()); } else { LocalVar result = VisitExpressionForRValue(node->Ret()); @@ -574,17 +697,17 @@ void BytecodeGenerator::VisitSqlConversionCall(ast::CallExpr *call, ast::Builtin break; } case ast::Builtin::SqlToBool: { - auto input = VisitExpressionForSQLValue(call->Arguments()[0]); + auto input = VisitExpressionForLValue(call->Arguments()[0]); GetEmitter()->Emit(Bytecode::ForceBoolTruth, dest, input); GetExecutionResult()->SetDestination(dest.ValueOf()); break; } -#define GEN_CASE(Builtin, Bytecode) \ - case Builtin: { \ - auto input = VisitExpressionForSQLValue(call->Arguments()[0]); \ - GetEmitter()->Emit(Bytecode, dest, input); \ - break; \ +#define GEN_CASE(Builtin, Bytecode) \ + case Builtin: { \ + auto input = VisitExpressionForRValue(call->Arguments()[0]); \ + GetEmitter()->Emit(Bytecode, dest, input); \ + break; \ } GEN_CASE(ast::Builtin::ConvertBoolToInteger, Bytecode::BoolToInteger); GEN_CASE(ast::Builtin::ConvertIntegerToReal, Bytecode::IntegerToReal); @@ -606,7 +729,7 @@ void BytecodeGenerator::VisitNullValueCall(ast::CallExpr *call, UNUSED_ATTRIBUTE switch (builtin) { case ast::Builtin::IsValNull: { LocalVar result = GetExecutionResult()->GetOrCreateDestination(call->GetType()); - LocalVar input = VisitExpressionForSQLValue(call->Arguments()[0]); + LocalVar input = VisitExpressionForLValue(call->Arguments()[0]); GetEmitter()->Emit(Bytecode::ValIsNull, result, input); GetExecutionResult()->SetDestination(result.ValueOf()); break; @@ -629,15 +752,15 @@ void BytecodeGenerator::VisitNullValueCall(ast::CallExpr *call, UNUSED_ATTRIBUTE void BytecodeGenerator::VisitSqlStringLikeCall(ast::CallExpr *call) { auto dest = GetExecutionResult()->GetOrCreateDestination(call->GetType()); - auto input = VisitExpressionForSQLValue(call->Arguments()[0]); - auto pattern = VisitExpressionForSQLValue(call->Arguments()[1]); + auto input = VisitExpressionForLValue(call->Arguments()[0]); + auto pattern = VisitExpressionForLValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Like, dest, input, pattern); GetExecutionResult()->SetDestination(dest); } void BytecodeGenerator::VisitBuiltinDateFunctionCall(ast::CallExpr *call, ast::Builtin builtin) { auto dest = GetExecutionResult()->GetOrCreateDestination(call->GetType()); - auto input = VisitExpressionForSQLValue(call->Arguments()[0]); + auto input = VisitExpressionForLValue(call->Arguments()[0]); auto date_type = sql::DatePartType(call->Arguments()[1]->As()->Arguments()[0]->As()->Int64Val()); @@ -651,6 +774,11 @@ void BytecodeGenerator::VisitBuiltinDateFunctionCall(ast::CallExpr *call, ast::B GetExecutionResult()->SetDestination(dest); } +void BytecodeGenerator::VisitBuiltinRandomFunctionCall(ast::CallExpr *call, ast::Builtin builtin) { + LocalVar ret = GetExecutionResult()->GetOrCreateDestination(call->GetType()); + GetEmitter()->Emit(Bytecode::Random, ret); +} + void BytecodeGenerator::VisitBuiltinTableIterCall(ast::CallExpr *call, ast::Builtin builtin) { // The first argument to all calls is a pointer to the TVI LocalVar iter = VisitExpressionForRValue(call->Arguments()[0]); @@ -837,13 +965,13 @@ void BytecodeGenerator::VisitBuiltinVPICall(ast::CallExpr *call, ast::Builtin bu #define GEN_CASE(BuiltinName, Bytecode) \ case ast::Builtin::BuiltinName: { \ - auto input = VisitExpressionForSQLValue(call->Arguments()[1]); \ + auto input = VisitExpressionForLValue(call->Arguments()[1]); \ auto col_idx = call->Arguments()[2]->As()->Int64Val(); \ GetEmitter()->EmitVPISet(Bytecode, vpi, input, col_idx); \ break; \ } \ case ast::Builtin::BuiltinName##Null: { \ - auto input = VisitExpressionForSQLValue(call->Arguments()[1]); \ + auto input = VisitExpressionForLValue(call->Arguments()[1]); \ auto col_idx = call->Arguments()[2]->As()->Int64Val(); \ GetEmitter()->EmitVPISet(Bytecode##Null, vpi, input, col_idx); \ break; \ @@ -881,7 +1009,7 @@ void BytecodeGenerator::VisitBuiltinHashCall(ast::CallExpr *call) { for (uint32_t idx = 0; idx < call->NumArgs(); idx++) { NOISEPAGE_ASSERT(call->Arguments()[idx]->GetType()->IsSqlValueType(), "Input to hash must be a SQL value type"); - LocalVar input = VisitExpressionForSQLValue(call->Arguments()[idx]); + LocalVar input = VisitExpressionForLValue(call->Arguments()[idx]); const auto *type = call->Arguments()[idx]->GetType()->As(); switch (type->GetKind()) { case ast::BuiltinType::Integer: @@ -954,7 +1082,7 @@ void BytecodeGenerator::VisitBuiltinVectorFilterCall(ast::CallExpr *call, ast::B #define GEN_CASE(BYTECODE) \ LocalVar left_col = VisitExpressionForRValue(call->Arguments()[2]); \ if (!call->Arguments()[3]->GetType()->IsIntegerType()) { \ - LocalVar right_val = VisitExpressionForSQLValue(call->Arguments()[3]); \ + LocalVar right_val = VisitExpressionForLValue(call->Arguments()[3]); \ GetEmitter()->Emit(BYTECODE##Val, exec_ctx, vector_projection, left_col, right_val, tid_list); \ } else { \ LocalVar right_col = VisitExpressionForRValue(call->Arguments()[3]); \ @@ -1835,7 +1963,7 @@ void BytecodeGenerator::VisitBuiltinThreadStateContainerCall(ast::CallExpr *call void BytecodeGenerator::VisitBuiltinTrigCall(ast::CallExpr *call, ast::Builtin builtin) { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(call->GetType()); - LocalVar src = VisitExpressionForSQLValue(call->Arguments()[0]); + LocalVar src = VisitExpressionForLValue(call->Arguments()[0]); switch (builtin) { case ast::Builtin::ACos: { @@ -1863,7 +1991,7 @@ void BytecodeGenerator::VisitBuiltinTrigCall(ast::CallExpr *call, ast::Builtin b break; } case ast::Builtin::ATan2: { - LocalVar src2 = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar src2 = VisitExpressionForLValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Atan2, dest, src, src2); break; } @@ -1904,7 +2032,7 @@ void BytecodeGenerator::VisitBuiltinTrigCall(ast::CallExpr *call, ast::Builtin b break; } case ast::Builtin::Exp: { - src = VisitExpressionForSQLValue(call->Arguments()[1]); + src = VisitExpressionForLValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Exp, dest, src); break; } @@ -1921,12 +2049,12 @@ void BytecodeGenerator::VisitBuiltinTrigCall(ast::CallExpr *call, ast::Builtin b break; } case ast::Builtin::Round2: { - LocalVar src2 = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar src2 = VisitExpressionForLValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Round2, dest, src, src2); break; } case ast::Builtin::Pow: { - LocalVar src2 = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar src2 = VisitExpressionForLValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Pow, dest, src, src2); break; } @@ -1944,13 +2072,13 @@ void BytecodeGenerator::VisitBuiltinArithmeticCall(ast::CallExpr *call, ast::Bui switch (builtin) { case ast::Builtin::Abs: { - LocalVar src = VisitExpressionForSQLValue(call->Arguments()[0]); + LocalVar src = VisitExpressionForLValue(call->Arguments()[0]); GetEmitter()->Emit(is_integer_math ? Bytecode::AbsInteger : Bytecode::AbsReal, dest, src); break; } case ast::Builtin::Mod: { - LocalVar first_input = VisitExpressionForSQLValue(call->Arguments()[0]); - LocalVar second_input = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar first_input = VisitExpressionForLValue(call->Arguments()[0]); + LocalVar second_input = VisitExpressionForLValue(call->Arguments()[1]); if (!is_integer_math) { NOISEPAGE_ASSERT(call->Arguments()[0]->GetType()->IsSpecificBuiltin(ast::BuiltinType::Real) && call->Arguments()[1]->GetType()->IsSpecificBuiltin(ast::BuiltinType::Real), @@ -2429,8 +2557,22 @@ void BytecodeGenerator::VisitBuiltinStorageInterfaceCall(ast::CallExpr *call, as void BytecodeGenerator::VisitBuiltinParamCall(ast::CallExpr *call, ast::Builtin builtin) { LocalVar exec_ctx = VisitExpressionForRValue(call->Arguments()[0]); - LocalVar param_idx = VisitExpressionForRValue(call->Arguments()[1]); - LocalVar ret = GetExecutionResult()->GetOrCreateDestination(call->GetType()); + LocalVar param_idx{}; + if (builtin != ast::Builtin::StartNewParams && builtin != ast::Builtin::FinishNewParams) { + param_idx = VisitExpressionForRValue(call->Arguments()[1]); + } + LocalVar ret; + if (builtin < ast::Builtin::StartNewParams) { + ret = GetExecutionResult()->GetOrCreateDestination(call->GetType()); + } else { + if (builtin != ast::Builtin::StartNewParams && builtin != ast::Builtin::FinishNewParams) { + if (call->Arguments()[1]->GetType()->IsPointerType()) { + param_idx = VisitExpressionForRValue(call->Arguments()[1]); + } else { + param_idx = VisitExpressionForLValue(call->Arguments()[1]); + } + } + } switch (builtin) { case ast::Builtin::GetParamBool: GetEmitter()->Emit(Bytecode::GetParamBool, ret, exec_ctx, param_idx); @@ -2462,6 +2604,42 @@ void BytecodeGenerator::VisitBuiltinParamCall(ast::CallExpr *call, ast::Builtin case ast::Builtin::GetParamString: GetEmitter()->Emit(Bytecode::GetParamString, ret, exec_ctx, param_idx); break; + case ast::Builtin::AddParamBool: + GetEmitter()->Emit(Bytecode::AddParamBool, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamTinyInt: + GetEmitter()->Emit(Bytecode::AddParamTinyInt, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamSmallInt: + GetEmitter()->Emit(Bytecode::AddParamSmallInt, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamInt: + GetEmitter()->Emit(Bytecode::AddParamInt, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamBigInt: + GetEmitter()->Emit(Bytecode::AddParamBigInt, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamReal: + GetEmitter()->Emit(Bytecode::AddParamReal, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamDouble: + GetEmitter()->Emit(Bytecode::AddParamDouble, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamDate: + GetEmitter()->Emit(Bytecode::AddParamDateVal, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamTimestamp: + GetEmitter()->Emit(Bytecode::AddParamTimestampVal, exec_ctx, param_idx); + break; + case ast::Builtin::AddParamString: + GetEmitter()->Emit(Bytecode::AddParamString, exec_ctx, param_idx); + break; + case ast::Builtin::StartNewParams: + GetEmitter()->Emit(Bytecode::StartNewParams, exec_ctx); + break; + case ast::Builtin::FinishNewParams: + GetEmitter()->Emit(Bytecode::FinishParams, exec_ctx); + break; default: UNREACHABLE("Impossible parameter call!"); } @@ -2472,35 +2650,35 @@ void BytecodeGenerator::VisitBuiltinStringCall(ast::CallExpr *call, ast::Builtin LocalVar ret = GetExecutionResult()->GetOrCreateDestination(call->GetType()); switch (builtin) { case ast::Builtin::SplitPart: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar delim = VisitExpressionForSQLValue(call->Arguments()[2]); - LocalVar field = VisitExpressionForSQLValue(call->Arguments()[3]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar delim = VisitExpressionForRValue(call->Arguments()[2]); + LocalVar field = VisitExpressionForRValue(call->Arguments()[3]); GetEmitter()->Emit(Bytecode::SplitPart, ret, exec_ctx, input_string, delim, field); break; } case ast::Builtin::Chr: { // input_string here is a integer type number - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Chr, ret, exec_ctx, input_string); break; } case ast::Builtin::CharLength: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::CharLength, ret, exec_ctx, input_string); break; } case ast::Builtin::ASCII: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::ASCII, ret, exec_ctx, input_string); break; } case ast::Builtin::Lower: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Lower, ret, exec_ctx, input_string); break; } case ast::Builtin::Upper: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Upper, ret, exec_ctx, input_string); break; } @@ -2509,73 +2687,73 @@ void BytecodeGenerator::VisitBuiltinStringCall(ast::CallExpr *call, ast::Builtin break; } case ast::Builtin::StartsWith: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar start_str = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar start_str = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::StartsWith, ret, exec_ctx, input_string, start_str); break; } case ast::Builtin::Substring: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar start_ind = VisitExpressionForSQLValue(call->Arguments()[2]); - LocalVar length = VisitExpressionForSQLValue(call->Arguments()[3]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar start_ind = VisitExpressionForRValue(call->Arguments()[2]); + LocalVar length = VisitExpressionForRValue(call->Arguments()[3]); GetEmitter()->Emit(Bytecode::Substring, ret, exec_ctx, input_string, start_ind, length); break; } case ast::Builtin::Reverse: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Reverse, ret, exec_ctx, input_string); break; } case ast::Builtin::Left: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar len = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar len = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::Left, ret, exec_ctx, input_string, len); break; } case ast::Builtin::Right: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar len = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar len = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::Right, ret, exec_ctx, input_string, len); break; } case ast::Builtin::Repeat: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar num_repeat = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar num_repeat = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::Repeat, ret, exec_ctx, input_string, num_repeat); break; } case ast::Builtin::Trim: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Trim, ret, exec_ctx, input_string); break; } case ast::Builtin::Trim2: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar trim_str = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar trim_str = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::Trim2, ret, exec_ctx, input_string, trim_str); break; } case ast::Builtin::Position: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar sub_string = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar sub_string = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::Position, ret, exec_ctx, input_string, sub_string); break; } case ast::Builtin::Length: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::Length, ret, exec_ctx, input_string); break; } case ast::Builtin::InitCap: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); GetEmitter()->Emit(Bytecode::InitCap, ret, exec_ctx, input_string); break; } case ast::Builtin::Lpad: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar len = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar len = VisitExpressionForRValue(call->Arguments()[2]); if (call->NumArgs() == 4) { - LocalVar pad = VisitExpressionForSQLValue(call->Arguments()[3]); + LocalVar pad = VisitExpressionForRValue(call->Arguments()[3]); GetEmitter()->Emit(Bytecode::LPad3Arg, ret, exec_ctx, input_string, len, pad); } else { GetEmitter()->Emit(Bytecode::LPad2Arg, ret, exec_ctx, input_string, len); @@ -2583,10 +2761,10 @@ void BytecodeGenerator::VisitBuiltinStringCall(ast::CallExpr *call, ast::Builtin break; } case ast::Builtin::Rpad: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); - LocalVar len = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); + LocalVar len = VisitExpressionForRValue(call->Arguments()[2]); if (call->NumArgs() == 4) { - LocalVar pad = VisitExpressionForSQLValue(call->Arguments()[3]); + LocalVar pad = VisitExpressionForRValue(call->Arguments()[3]); GetEmitter()->Emit(Bytecode::RPad3Arg, ret, exec_ctx, input_string, len, pad); } else { GetEmitter()->Emit(Bytecode::RPad2Arg, ret, exec_ctx, input_string, len); @@ -2594,21 +2772,21 @@ void BytecodeGenerator::VisitBuiltinStringCall(ast::CallExpr *call, ast::Builtin break; } case ast::Builtin::Ltrim: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); if (call->NumArgs() == 2) { GetEmitter()->Emit(Bytecode::LTrim1Arg, ret, exec_ctx, input_string); } else { - LocalVar chars = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar chars = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::LTrim2Arg, ret, exec_ctx, input_string, chars); } break; } case ast::Builtin::Rtrim: { - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[1]); + LocalVar input_string = VisitExpressionForRValue(call->Arguments()[1]); if (call->NumArgs() == 2) { GetEmitter()->Emit(Bytecode::RTrim1Arg, ret, exec_ctx, input_string); } else { - LocalVar chars = VisitExpressionForSQLValue(call->Arguments()[2]); + LocalVar chars = VisitExpressionForRValue(call->Arguments()[2]); GetEmitter()->Emit(Bytecode::RTrim2Arg, ret, exec_ctx, input_string, chars); } break; @@ -2623,7 +2801,7 @@ void BytecodeGenerator::VisitBuiltinStringCall(ast::CallExpr *call, ast::Builtin auto arr_elem_ptr = GetCurrentFunction()->NewLocal(string_type->PointerTo()->PointerTo()); for (uint32_t i = 0; i < num_inputs; i++) { GetEmitter()->EmitLea(arr_elem_ptr, inputs, i * 8); - LocalVar input_string = VisitExpressionForSQLValue(call->Arguments()[i + 1]); + LocalVar input_string = VisitExpressionForLValue(call->Arguments()[i + 1]); GetEmitter()->EmitAssign(Bytecode::Assign8, arr_elem_ptr.ValueOf(), input_string); } @@ -2688,6 +2866,10 @@ void BytecodeGenerator::VisitBuiltinCallExpr(ast::CallExpr *call) { VisitBuiltinDateFunctionCall(call, builtin); break; } + case ast::Builtin::Random: { + VisitBuiltinRandomFunctionCall(call, builtin); + break; + } case ast::Builtin::RegisterThreadWithMetricsManager: { LocalVar exec_ctx = VisitExpressionForRValue(call->Arguments()[0]); GetEmitter()->Emit(Bytecode::RegisterThreadWithMetricsManager, exec_ctx); @@ -3096,7 +3278,19 @@ void BytecodeGenerator::VisitBuiltinCallExpr(ast::CallExpr *call) { case ast::Builtin::GetParamDouble: case ast::Builtin::GetParamDate: case ast::Builtin::GetParamTimestamp: - case ast::Builtin::GetParamString: { + case ast::Builtin::GetParamString: + case ast::Builtin::AddParamBool: + case ast::Builtin::AddParamTinyInt: + case ast::Builtin::AddParamSmallInt: + case ast::Builtin::AddParamInt: + case ast::Builtin::AddParamBigInt: + case ast::Builtin::AddParamReal: + case ast::Builtin::AddParamDouble: + case ast::Builtin::AddParamDate: + case ast::Builtin::AddParamTimestamp: + case ast::Builtin::AddParamString: + case ast::Builtin::StartNewParams: + case ast::Builtin::FinishNewParams: { VisitBuiltinParamCall(call, builtin); break; } @@ -3260,11 +3454,15 @@ void BytecodeGenerator::VisitBuiltinIndexIteratorCall(ast::CallExpr *call, ast:: void BytecodeGenerator::VisitRegularCallExpr(ast::CallExpr *call) { bool caller_wants_result = GetExecutionResult() != nullptr; - NOISEPAGE_ASSERT(!caller_wants_result || GetExecutionResult()->IsRValue(), "Calls can only be R-Values!"); - + NOISEPAGE_ASSERT(!caller_wants_result || GetExecutionResult()->IsRValue() || + (GetExecutionResult()->IsLValue() && call->GetType()->IsSqlValueType()), + "Calls can only be R-Values!"); std::vector params; - auto *func_type = call->Function()->GetType()->As(); + auto *func_type = call->Function()->GetType()->SafeAs(); + if (func_type == nullptr) { + func_type = call->Function()->GetType()->SafeAs()->GetFunctionType(); + } if (!func_type->GetReturnType()->IsNilType()) { LocalVar ret_val; @@ -3273,6 +3471,9 @@ void BytecodeGenerator::VisitRegularCallExpr(ast::CallExpr *call) { // Let the caller know where the result value is GetExecutionResult()->SetDestination(ret_val.ValueOf()); + if (GetExecutionResult()->IsLValue()) { + GetExecutionResult()->SetDestination(ret_val.AddressOf()); + } } else { ret_val = GetCurrentFunction()->NewLocal(func_type->GetReturnType()); } @@ -3282,12 +3483,21 @@ void BytecodeGenerator::VisitRegularCallExpr(ast::CallExpr *call) { } // Collect non-return-value parameters as usual - for (uint32_t i = 0; i < func_type->GetNumParams(); i++) { - params.push_back(VisitExpressionForRValue(call->Arguments()[i])); + for (uint32_t i = 0; i < call->Arguments().size(); i++) { + if (func_type->GetParams()[i].type_->IsSqlValueType()) { + params.push_back(VisitExpressionForLValue(call->Arguments()[i])); + } else { + params.push_back(VisitExpressionForRValue(call->Arguments()[i])); + } } // Emit call const auto func_id = LookupFuncIdByName(call->GetFuncName().GetData()); + if (func_id == FunctionInfo::K_INVALID_FUNC_ID) { + auto action = GetEmitter()->DeferredEmitCall(params); + deferred_function_create_actions_[call->GetFuncName().GetString()].push_back(action); + return; + } NOISEPAGE_ASSERT(func_id != FunctionInfo::K_INVALID_FUNC_ID, "Function not found!"); GetEmitter()->EmitCall(func_id, params); } @@ -3295,10 +3505,18 @@ void BytecodeGenerator::VisitRegularCallExpr(ast::CallExpr *call) { void BytecodeGenerator::VisitCallExpr(ast::CallExpr *node) { ast::CallExpr::CallKind call_kind = node->GetCallKind(); - if (call_kind == ast::CallExpr::CallKind::Builtin) { - VisitBuiltinCallExpr(node); - } else { - VisitRegularCallExpr(node); + switch (call_kind) { + case ast::CallExpr::CallKind::Builtin: { + VisitBuiltinCallExpr(node); + break; + } + case ast::CallExpr::CallKind::Regular: { + VisitRegularCallExpr(node); + break; + } + default: { + UNREACHABLE("Unknown Call Kind"); + } } } @@ -3314,7 +3532,7 @@ void BytecodeGenerator::VisitFile(ast::File *node) { } void BytecodeGenerator::VisitLitExpr(ast::LitExpr *node) { - NOISEPAGE_ASSERT(GetExecutionResult()->IsRValue(), "Literal expressions cannot be R-Values!"); + NOISEPAGE_ASSERT(GetExecutionResult()->IsRValue(), "Literal expressions cannot be L-Values!"); LocalVar target = GetExecutionResult()->GetOrCreateDestination(node->GetType()); @@ -3469,8 +3687,8 @@ void BytecodeGenerator::VisitPrimitiveArithmeticExpr(ast::BinaryOpExpr *node) { void BytecodeGenerator::VisitSqlArithmeticExpr(ast::BinaryOpExpr *node) { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(node->GetType()); - LocalVar left = VisitExpressionForSQLValue(node->Left()); - LocalVar right = VisitExpressionForSQLValue(node->Right()); + LocalVar left = VisitExpressionForLValue(node->Left()); + LocalVar right = VisitExpressionForLValue(node->Right()); const bool is_integer_math = node->GetType()->IsSpecificBuiltin(ast::BuiltinType::Integer); @@ -3556,8 +3774,8 @@ void BytecodeGenerator::VisitBinaryOpExpr(ast::BinaryOpExpr *node) { void BytecodeGenerator::VisitSqlCompareOpExpr(ast::ComparisonOpExpr *compare) { LocalVar dest = GetExecutionResult()->GetOrCreateDestination(compare->GetType()); - LocalVar left = VisitExpressionForSQLValue(compare->Left()); - LocalVar right = VisitExpressionForSQLValue(compare->Right()); + LocalVar left = VisitExpressionForLValue(compare->Left()); + LocalVar right = VisitExpressionForLValue(compare->Right()); NOISEPAGE_ASSERT(compare->Left()->GetType() == compare->Right()->GetType(), "Left and right input types to comparison are not equal"); @@ -3700,8 +3918,10 @@ void BytecodeGenerator::BuildAssign(LocalVar dest, LocalVar val, ast::Type *dest GetEmitter()->EmitAssign(Bytecode::Assign2, dest, val); } else if (size == 4) { GetEmitter()->EmitAssign(Bytecode::Assign4, dest, val); - } else { + } else if (size == 8 && dest_type != ast::BuiltinType::Get(dest_type->GetContext(), ast::BuiltinType::Date)) { GetEmitter()->EmitAssign(Bytecode::Assign8, dest, val); + } else { + GetEmitter()->EmitAssignN(dest, val, size); } } @@ -3818,25 +4038,35 @@ void BytecodeGenerator::VisitMapTypeRepr(ast::MapTypeRepr *node) { NOISEPAGE_ASSERT(false, "Should not visit type-representation nodes!"); } -FunctionInfo *BytecodeGenerator::AllocateFunc(const std::string &func_name, ast::FunctionType *const func_type) { +FunctionInfo *BytecodeGenerator::AllocateFunction(const std::string &function_name, + ast::FunctionType *const function_type) { // Allocate function const auto func_id = static_cast(functions_.size()); - functions_.emplace_back(func_id, func_name, func_type); - FunctionInfo *func = &functions_.back(); + functions_.push_back(std::make_unique(func_id, function_name, function_type)); + FunctionInfo *func = functions_.back().get(); // Register return type - if (auto *return_type = func_type->GetReturnType(); !return_type->IsNilType()) { + if (auto *return_type = function_type->GetReturnType(); !return_type->IsNilType()) { func->NewParameterLocal(return_type->PointerTo(), "hiddenRv"); } // Register parameters - for (const auto ¶m : func_type->GetParams()) { - func->NewParameterLocal(param.type_, param.name_.GetData()); + for (const auto ¶m : function_type->GetParams()) { + if (param.type_->IsSqlValueType()) { + func->NewParameterLocal(param.type_->PointerTo(), param.name_.GetData()); + } else { + func->NewParameterLocal(param.type_, param.name_.GetData()); + } } - // Cache + // Cache the function func_map_[func->GetName()] = func->GetId(); + // Execute all deferred creation actions for the function + for (const auto &action : deferred_function_create_actions_[func->GetName()]) { + action(func->GetId()); + } + return func; } @@ -3894,12 +4124,6 @@ LocalVar BytecodeGenerator::VisitExpressionForRValue(ast::Expr *expr) { return scope.GetDestination(); } -LocalVar BytecodeGenerator::VisitExpressionForSQLValue(ast::Expr *expr) { return VisitExpressionForLValue(expr); } - -void BytecodeGenerator::VisitExpressionForSQLValue(ast::Expr *expr, LocalVar dest) { - VisitExpressionForRValue(expr, dest); -} - void BytecodeGenerator::VisitExpressionForRValue(ast::Expr *expr, LocalVar dest) { RValueResultScope scope(this, dest); Visit(expr); @@ -3907,7 +4131,6 @@ void BytecodeGenerator::VisitExpressionForRValue(ast::Expr *expr, LocalVar dest) void BytecodeGenerator::VisitExpressionForTest(ast::Expr *expr, BytecodeLabel *then_label, BytecodeLabel *else_label, TestFallthrough fallthrough) { - // Evaluate the expression // Jumps don't expect addresses of locals LocalVar cond = VisitExpressionForRValue(expr).ValueOf(); @@ -3952,6 +4175,7 @@ std::unique_ptr BytecodeGenerator::Compile(ast::AstNode *root, c return std::make_unique(name, std::move(generator.code_), std::move(generator.data_), std::move(generator.functions_), std::move(generator.static_locals_)); } + void BytecodeGenerator::VisitBuiltinCteScanCall(ast::CallExpr *call, ast::Builtin builtin) { LocalVar iterator = VisitExpressionForRValue(call->Arguments()[0]); switch (builtin) { diff --git a/src/execution/vm/bytecode_module.cpp b/src/execution/vm/bytecode_module.cpp index 37e3f33118..c307df5fd8 100644 --- a/src/execution/vm/bytecode_module.cpp +++ b/src/execution/vm/bytecode_module.cpp @@ -13,7 +13,8 @@ namespace noisepage::execution::vm { BytecodeModule::BytecodeModule(std::string name, std::vector &&code, std::vector &&data, - std::vector &&functions, std::vector &&static_locals) + std::vector> &&functions, + std::vector &&static_locals) : name_(std::move(name)), code_(std::move(code)), data_(std::move(data)), @@ -170,6 +171,13 @@ void PrettyPrintFuncCode(std::ostream &os, const BytecodeModule &module, const F break; } case OperandType::FunctionId: { + auto fn_id = iter->GetFunctionIdOperand(i); + if (fn_id == FunctionInfo::K_INVALID_FUNC_ID) { + os << "func=<" + << "unresolved lambda" + << ">"; + break; + } auto target = module.GetFuncInfoById(iter->GetFunctionIdOperand(i)); os << "func=<" << target->GetName() << ">"; break; @@ -202,7 +210,7 @@ void BytecodeModule::Dump(std::ostream &os) const { // Functions for (const auto &func : functions_) { - PrettyPrintFunc(os, *this, func); + PrettyPrintFunc(os, *this, *func); } } diff --git a/src/execution/vm/control_flow_builders.cpp b/src/execution/vm/control_flow_builders.cpp index 340747ffa5..835f03abd0 100644 --- a/src/execution/vm/control_flow_builders.cpp +++ b/src/execution/vm/control_flow_builders.cpp @@ -40,6 +40,11 @@ void LoopBuilder::BindContinueTarget() { GetGenerator()->GetEmitter()->Bind(GetContinueLabel()); } +LoopBuilder *LoopBuilder::GetPrevLoop() const { + NOISEPAGE_ASSERT(prev_loop_ != nullptr, "Attempt to access a non-existent outer loop"); + return prev_loop_; +} + // --------------------------------------------------------- // If-Then-Else Builders // --------------------------------------------------------- diff --git a/src/execution/vm/llvm_engine.cpp b/src/execution/vm/llvm_engine.cpp index 5af6ec3c48..10fc1ac2e8 100644 --- a/src/execution/vm/llvm_engine.cpp +++ b/src/execution/vm/llvm_engine.cpp @@ -31,6 +31,7 @@ #include #include +#include "common/error/exception.h" #include "execution/ast/type.h" #include "execution/vm/bytecode_module.h" #include "execution/vm/bytecode_traits.h" @@ -185,6 +186,10 @@ llvm::Type *LLVMEngine::TypeMap::GetLLVMType(const ast::Type *type) { llvm_type = llvm::PointerType::getUnqual(GetLLVMType(ptr_type->GetBase())); break; } + case ast::Type::TypeId::ReferenceType: { + throw NOT_IMPLEMENTED_EXCEPTION("ReferenceType Not Implemented"); + break; + } case ast::Type::TypeId::ArrayType: { auto *arr_type = type->As(); llvm::Type *elem_type = GetLLVMType(arr_type->GetElementType()); @@ -196,7 +201,8 @@ llvm::Type *LLVMEngine::TypeMap::GetLLVMType(const ast::Type *type) { break; } case ast::Type::TypeId::MapType: { - // TODO(pmenon): me + // TODO(Kyle): Implement this + throw NOT_IMPLEMENTED_EXCEPTION("MapType Not Implemented"); break; } case ast::Type::TypeId::StructType: { @@ -207,6 +213,14 @@ llvm::Type *LLVMEngine::TypeMap::GetLLVMType(const ast::Type *type) { llvm_type = GetLLVMFunctionType(type->As()); break; } + case ast::Type::TypeId::LambdaType: { + llvm_type = Int32Type()->getPointerTo(); + break; + } + default: { + UNREACHABLE("Unknown Type"); + break; + } } // @@ -214,9 +228,7 @@ llvm::Type *LLVMEngine::TypeMap::GetLLVMType(const ast::Type *type) { // NOISEPAGE_ASSERT(llvm_type != nullptr, "No LLVM type found!"); - iter->second = llvm_type; - return llvm_type; } @@ -277,8 +289,12 @@ llvm::FunctionType *LLVMEngine::TypeMap::GetLLVMFunctionType(const ast::Function // for (const auto ¶m_info : func_type->GetParams()) { - llvm::Type *param_type = GetLLVMType(param_info.type_); - param_types.push_back(param_type); + if (param_info.type_->IsSqlValueType()) { + param_types.push_back(GetLLVMType(param_info.type_->PointerTo())); + } else { + llvm::Type *param_type = GetLLVMType(param_info.type_); + param_types.push_back(param_type); + } } return llvm::FunctionType::get(return_type, param_types, false); @@ -307,8 +323,8 @@ class LLVMEngine::FunctionLocalsMap { LLVMEngine::FunctionLocalsMap::FunctionLocalsMap(const FunctionInfo &func_info, llvm::Function *func, TypeMap *type_map, llvm::IRBuilder<> *ir_builder) : ir_builder_(ir_builder) { - uint32_t local_idx = 0; - + // The local variable index used throughout function body + std::size_t local_idx = 0; const auto &func_locals = func_info.GetLocals(); // Make an allocation for the return value, if it's direct. @@ -325,7 +341,15 @@ LLVMEngine::FunctionLocalsMap::FunctionLocalsMap(const FunctionInfo &func_info, params_[param.GetOffset()] = &*arg_iter; } - // Allocate all local variables up front. + if (func_info.IsLambda()) { + auto capture_type = type_map->GetLLVMType(func_info.GetFuncType()->GetCapturesType()->PointerTo()); + auto capture_local = func_locals[local_idx - 1]; + auto capture_param = params_[capture_local.GetOffset()]; + auto new_capture_param = ir_builder->CreateBitCast(capture_param, capture_type); + params_[capture_local.GetOffset()] = new_capture_param; + } + + // Allocate all local variables up front for (; local_idx < func_info.GetLocals().size(); local_idx++) { const LocalInfo &local_info = func_locals[local_idx]; llvm::Type *llvm_type = type_map->GetLLVMType(local_info.GetType()); @@ -336,7 +360,13 @@ LLVMEngine::FunctionLocalsMap::FunctionLocalsMap(const FunctionInfo &func_info, llvm::Value *LLVMEngine::FunctionLocalsMap::GetArgumentById(LocalVar var) { if (auto iter = params_.find(var.GetOffset()); iter != params_.end()) { - return iter->second; + auto val = iter->second; + if ((var.GetAddressMode() == LocalVar::AddressMode::Address) && llvm::isa(val)) { + auto new_val = ir_builder_->CreateAlloca(val->getType()); + ir_builder_->CreateStore(val, new_val); + val = new_val; + } + return val; } if (auto iter = locals_.find(var.GetOffset()); iter != locals_.end()) { @@ -538,9 +568,9 @@ void LLVMEngine::CompiledModuleBuilder::DeclareStaticLocals() { } void LLVMEngine::CompiledModuleBuilder::DeclareFunctions() { - for (const auto &func_info : tpl_module_.GetFunctionsInfo()) { - auto *func_type = llvm::cast(type_map_->GetLLVMType(func_info.GetFuncType())); - llvm_module_->getOrInsertFunction(func_info.GetName(), func_type); + for (const auto *func_info : tpl_module_.GetFunctionsInfo()) { + auto *func_type = llvm::cast(type_map_->GetLLVMType(func_info->GetFuncType())); + llvm_module_->getOrInsertFunction(func_info->GetName(), func_type); } } @@ -626,6 +656,12 @@ void LLVMEngine::CompiledModuleBuilder::BuildSimpleCFG(const FunctionInfo &func_ void LLVMEngine::CompiledModuleBuilder::DefineFunction(const FunctionInfo &func_info, llvm::IRBuilder<> *ir_builder) { llvm::LLVMContext &ctx = ir_builder->getContext(); llvm::Function *func = llvm_module_->getFunction(func_info.GetName()); + // The line below is flagged by `check-censored` target because of 'inline' + if (func->getName().str().find("inline") != std::string::npos) { // NOLINT + func->setLinkage(llvm::Function::LinkOnceAnyLinkage); + func->addFnAttr(llvm::Attribute::AlwaysInline); + } + llvm::BasicBlock *first_bb = llvm::BasicBlock::Create(ctx, "BB0", func); llvm::BasicBlock *entry_bb = llvm::BasicBlock::Create(ctx, "EntryBB", func, first_bb); @@ -925,8 +961,8 @@ void LLVMEngine::CompiledModuleBuilder::DefineFunction(const FunctionInfo &func_ void LLVMEngine::CompiledModuleBuilder::DefineFunctions() { llvm::IRBuilder<> ir_builder(*context_); - for (const auto &func_info : tpl_module_.GetFunctionsInfo()) { - DefineFunction(func_info, &ir_builder); + for (const auto *func_info : tpl_module_.GetFunctionsInfo()) { + DefineFunction(*func_info, &ir_builder); } } @@ -1129,13 +1165,13 @@ void LLVMEngine::CompiledModule::Load(const BytecodeModule &module) { // all module functions into a handy cache. // - for (const auto &func : module.GetFunctionsInfo()) { - auto symbol = loader.getSymbol(func.GetName()); + for (const auto *func : module.GetFunctionsInfo()) { + auto symbol = loader.getSymbol(func->GetName()); if (symbol.getAddress() == 0) { // for Mac portability - symbol = loader.getSymbol("_" + func.GetName()); + symbol = loader.getSymbol("_" + func->GetName()); } - functions_[func.GetName()] = reinterpret_cast(symbol.getAddress()); + functions_[func->GetName()] = reinterpret_cast(symbol.getAddress()); NOISEPAGE_ASSERT(symbol.getAddress() != 0, "symbol came out to be badly defined or missing"); } diff --git a/src/execution/vm/module.cpp b/src/execution/vm/module.cpp index 92f79efc48..e8c5c8fa6e 100644 --- a/src/execution/vm/module.cpp +++ b/src/execution/vm/module.cpp @@ -51,8 +51,8 @@ Module::Module(std::unique_ptr bytecode_module, std::unique_ptr< bytecode_trampolines_(std::make_unique(bytecode_module_->GetFunctionCount())), metadata_(std::move(metadata)) { // Create the trampolines for all bytecode functions - for (const auto &func : bytecode_module_->GetFunctionsInfo()) { - CreateFunctionTrampoline(func.GetId()); + for (const auto *func : bytecode_module_->GetFunctionsInfo()) { + CreateFunctionTrampoline(func->GetId()); } // If a compiled module wasn't provided, all internal function stubs point to @@ -280,10 +280,10 @@ void Module::CompileToMachineCode() { // JIT completed successfully. For each function in the module, pull out its // compiled implementation into the function cache, atomically replacing any // previous implementation. - for (const auto &func_info : bytecode_module_->GetFunctionsInfo()) { - auto *jit_function = jit_module_->GetFunctionPointer(func_info.GetName()); - NOISEPAGE_ASSERT(jit_function != nullptr, "Missing function in compiled module!"); - functions_[func_info.GetId()].store(jit_function, std::memory_order_relaxed); + for (const auto *func_info : bytecode_module_->GetFunctionsInfo()) { + auto *jit_function = jit_module_->GetFunctionPointer(func_info->GetName()); + NOISEPAGE_ASSERT(jit_function != nullptr, "Function not found!"); + functions_[func_info->GetId()].store(jit_function, std::memory_order_relaxed); } }); } diff --git a/src/execution/vm/vm.cpp b/src/execution/vm/vm.cpp index 77028fb9a1..8136a1aee7 100644 --- a/src/execution/vm/vm.cpp +++ b/src/execution/vm/vm.cpp @@ -413,6 +413,14 @@ void VM::Interpret(const uint8_t *ip, Frame *frame) { // NOLINT GEN_ASSIGN(int64_t, 8); #undef GEN_ASSIGN + OP(AssignN) : { + auto *dest = frame->LocalAt(READ_LOCAL_ID()); + auto *src = frame->LocalAt(READ_LOCAL_ID()); + auto len = READ_UIMM4(); + OpAssignN(dest, src, len); + DISPATCH_NEXT(); + } + OP(AssignImm4F) : { auto *dest = frame->LocalAt(READ_LOCAL_ID()); OpAssignImm4F(dest, READ_IMM4F()); @@ -2328,6 +2336,38 @@ void VM::Interpret(const uint8_t *ip, Frame *frame) { // NOLINT GEN_PARAM_GET(String, StringVal) #undef GEN_PARAM_GET +#define GEN_PARAM_ADD(Name, SqlType) \ + OP(AddParam##Name) : { \ + auto *exec_ctx = frame->LocalAt(READ_LOCAL_ID()); \ + auto *ret = frame->LocalAt(READ_LOCAL_ID()); \ + OpAddParam##Name(exec_ctx, ret); \ + DISPATCH_NEXT(); \ + } + + GEN_PARAM_ADD(Bool, BoolVal) + GEN_PARAM_ADD(TinyInt, Integer) + GEN_PARAM_ADD(SmallInt, Integer) + GEN_PARAM_ADD(Int, Integer) + GEN_PARAM_ADD(BigInt, Integer) + GEN_PARAM_ADD(Real, Real) + GEN_PARAM_ADD(Double, Real) + GEN_PARAM_ADD(DateVal, DateVal) + GEN_PARAM_ADD(TimestampVal, TimestampVal) + GEN_PARAM_ADD(String, StringVal) +#undef GEN_PARAM_ADD + + OP(StartNewParams) : { + auto *exec_ctx = frame->LocalAt(READ_LOCAL_ID()); + OpStartNewParams(exec_ctx); + DISPATCH_NEXT(); + } + + OP(FinishParams) : { + auto *exec_ctx = frame->LocalAt(READ_LOCAL_ID()); + OpFinishParams(exec_ctx); + DISPATCH_NEXT(); + } + // ------------------------------------------------------- // Trig functions // ------------------------------------------------------- @@ -2753,6 +2793,12 @@ void VM::Interpret(const uint8_t *ip, Frame *frame) { // NOLINT DISPATCH_NEXT(); } + OP(Random) : { + auto *result = frame->LocalAt(READ_LOCAL_ID()); + OpRandom(result); + DISPATCH_NEXT(); + } + OP(InitCap) : { auto *result = frame->LocalAt(READ_LOCAL_ID()); auto *exec_ctx = frame->LocalAt(READ_LOCAL_ID()); @@ -2794,7 +2840,7 @@ const uint8_t *VM::ExecuteCall(const uint8_t *ip, VM::Frame *caller) { const LocalVar param = LocalVar::Decode(READ_LOCAL_ID()); const void *param_ptr = caller->PtrToLocalAt(param); if (param.GetAddressMode() == LocalVar::AddressMode::Address) { - std::memcpy(raw_frame + param_info.GetOffset(), ¶m_ptr, param_info.GetSize()); + std::memcpy(raw_frame + param_info.GetOffset(), ¶m_ptr, sizeof(void *)); } else { std::memcpy(raw_frame + param_info.GetOffset(), param_ptr, param_info.GetSize()); } diff --git a/src/include/binder/bind_node_visitor.h b/src/include/binder/bind_node_visitor.h index 81ce0712b6..713d1a64a5 100644 --- a/src/include/binder/bind_node_visitor.h +++ b/src/include/binder/bind_node_visitor.h @@ -2,11 +2,17 @@ #include #include +#include +#include #include #include "binder/sql_node_visitor.h" #include "catalog/catalog_defs.h" +#include "execution/ast/udf/udf_ast_context.h" #include "execution/sql/sql.h" +#include "parser/postgresparser.h" +#include "parser/select_statement.h" +#include "parser/udf/variable_ref.h" namespace noisepage { @@ -47,6 +53,18 @@ class BindNodeVisitor final : public SqlNodeVisitor { /** Destructor. Must be defined due to forward declaration. */ ~BindNodeVisitor() final; + /** + * Perform binding for a UDF. + * @param parse_result The result of parsing the UDF. + * @param udf_ast_context The AST context for the UDF. + * @return The map of UDF parameters: + * Column Name -> (Parameter Name, Parameter Index) + * @throws BinderException on failure to bind query + */ + std::vector BindAndGetUDFVariableRefs( + common::ManagedPointer parse_result, + common::ManagedPointer udf_ast_context); + /** * Perform binding on the passed in tree. Bind the relation names to oids * @param parse_result Result generated by the parser. A collection of statements and expressions in the query @@ -101,6 +119,11 @@ class BindNodeVisitor final : public SqlNodeVisitor { /** Current context of the query or subquery */ common::ManagedPointer context_ = nullptr; + /** Context for UDF AST */ + common::ManagedPointer udf_ast_context_{}; + /** Parameters for UDF */ + std::vector udf_variable_refs_; + /** Catalog accessor */ const common::ManagedPointer catalog_accessor_; @@ -137,6 +160,41 @@ class BindNodeVisitor final : public SqlNodeVisitor { std::vector> *values, const catalog::Schema &table_schema); + /** @return `true` if we are binding within the context of a UDF, `false` otherwise */ + bool BindingForUDF() const; + + /** + * Determine if the given identifier names a UDF variable. + * @param identifier The variable identifier + * @return `true` if the variable is declared in the UDF + * for which binding is performed, `false` otherwise + */ + bool IsUDFVariable(const std::string &identifier) const; + + /** + * Determine if the given identifier names a variable + * reference that is already tracked. + * @param identifier The variable identifier + */ + bool HaveUDFVariableRef(const std::string &identifier) const; + + /** + * Add a UDF variable reference to the internal tracker. + * @param expr The expression + * @param table_name The name of the table associated with the reference + * @param column_name The name of the column associated with the reference + */ + void AddUDFVariableReference(common::ManagedPointer expr, + const std::string &table_name, const std::string &column_name); + + /** + * Add a UDF variable reference to the internal tracker. + * @param expr The expression + * @param column_name The name of the column associated with the reference + */ + void AddUDFVariableReference(common::ManagedPointer expr, + const std::string &column_name); + /** * Set the serial number of the table alias to a unique number if it isn't already set * @param node Table Ref to set serial number of diff --git a/src/include/binder/binder_sherpa.h b/src/include/binder/binder_sherpa.h index 2b6b724f6c..27985103a8 100644 --- a/src/include/binder/binder_sherpa.h +++ b/src/include/binder/binder_sherpa.h @@ -47,12 +47,13 @@ class BinderSherpa { common::ManagedPointer GetParseResult() const { return parse_result_; } /** - * @return parameters for the query being bound - * @warning can be nullptr if there are no parameters + * @return The parameters for the query being bound + * @warning May be `nullptr` if there are no parameters */ common::ManagedPointer> GetParameters() const { return parameters_; } /** + * Get the desired type for the expression. * @param expr The expression whose type constraints we want to look up. * @return The previously recorded type constraints, or the expression's current return value type if none exist. */ diff --git a/src/include/catalog/catalog_accessor.h b/src/include/catalog/catalog_accessor.h index f685fdd215..6da5005664 100644 --- a/src/include/catalog/catalog_accessor.h +++ b/src/include/catalog/catalog_accessor.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -331,7 +332,7 @@ class EXPORT CatalogAccessor { proc_oid_t CreateProcedure(const std::string &procname, language_oid_t language_oid, namespace_oid_t procns, type_oid_t variadic_type, const std::vector &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, type_oid_t rettype, + const std::vector &arg_modes, type_oid_t rettype, const std::string &src, bool is_aggregate); /** @@ -342,22 +343,43 @@ class EXPORT CatalogAccessor { bool DropProcedure(proc_oid_t proc_oid); /** - * Gets the oid of a procedure from pg_proc given a requested name and namespace + * Get the OID of the procedure from pg_proc given a requested name and argument + * types as string identifiers. + * This lookup with return the first one found through a sequential scan through + * the current search path. + * @param procname name of the proc to lookup + * @param arg_types vector of type identifiers for the arguments of the procedure + * @return The OID of the resolved procedure if found, else `INVALID_PROC_OID` + */ + proc_oid_t GetProcOid(const std::string &procname, const std::vector &arg_types); + + /** + * Gets the OID of a procedure from pg_proc given a requested name and resolved argument types. * This lookup will return the first one found through a sequential scan through - * the current search path + * the current search path. * @param procname name of the proc to lookup - * @param all_arg_types vector of types of arguments of procedure to look up - * @return the oid of the found proc if found else INVALID_PROC_OID + * @param arg_types vector of types of arguments of procedure to look up + * @return The OID of the resolved procedure if found, else `INVALID_PROC_OID` */ - proc_oid_t GetProcOid(const std::string &procname, const std::vector &all_arg_types); + proc_oid_t GetProcOid(const std::string &procname, const std::vector &arg_types); /** - * Sets the proc context pointer column of proc_oid to func_context - * @param proc_oid The proc_oid whose pointer column we are setting here - * @param func_context The context object to set to - * @return False if the given proc_oid is invalid, True if else + * Resolve procedure argument types. + * @param procname The name of the procedure + * @param arg_types A vector of the string representation of the argument types + * @return A collection of all sets of arguments for which this procedure is resolved + */ + std::vector> ResolveProcArgumentTypes(const std::string &procname, + const std::vector &arg_types) const; + + /** + * Resolve procedure argument types. + * @param procname The name of the procedure + * @param arg_types A vector of the string representation of the argument types + * @return A collection of all sets of arguments for which this procedure is resolved */ - bool SetFunctionContextPointer(proc_oid_t proc_oid, const execution::functions::FunctionContext *func_context); + std::vector> ResolveProcArgumentTypes(const std::string &procname, + const std::vector &arg_types) const; /** * Gets the proc context pointer column of proc_oid @@ -366,6 +388,14 @@ class EXPORT CatalogAccessor { */ common::ManagedPointer GetFunctionContext(proc_oid_t proc_oid); + /** + * Sets the proc context pointer column of proc_oid to func_context + * @param proc_oid The proc_oid whose pointer column we are setting here + * @param func_context The context object to set to + * @return False if the given proc_oid is invalid, True if else + */ + bool SetFunctionContext(proc_oid_t proc_oid, const execution::functions::FunctionContext *func_context); + /** * Gets the statistics of a column from pg_statistic * @param table_oid table oid of table @@ -382,11 +412,18 @@ class EXPORT CatalogAccessor { optimizer::TableStats GetTableStatistics(table_oid_t table_oid); /** - * Returns the type oid of the given TypeId in pg_type - * @param type - * @return type_oid of type in pg_type + * Returns the type oid of the given TypeId in pg_type. + * @param type The queried type + * @return The corresponding type_oid_t */ - type_oid_t GetTypeOidFromTypeId(execution::sql::SqlTypeId type); + type_oid_t GetTypeOidFromTypeId(execution::sql::SqlTypeId type) const; + + /** + * Returns the SQL type ID of the given type_oid_t. + * @param type The queried type + * @return The corresponding SQL type ID + */ + execution::sql::SqlTypeId GetTypeIdFromTypeOid(type_oid_t type) const; /** * @return BlockStore to be used for CREATE operations @@ -458,6 +495,13 @@ class EXPORT CatalogAccessor { static void NormalizeObjectName(std::string *name) { std::transform(name->begin(), name->end(), name->begin(), [](auto &&c) { return std::tolower(c); }); } + + /** + * Resolve a string type name identifier to a catalog type. + * @param type_name The type name + * @return The internal catalog type identifier for the type + */ + type_oid_t TypeNameToType(const std::string &type_name) const; }; } // namespace noisepage::catalog diff --git a/src/include/catalog/database_catalog.h b/src/include/catalog/database_catalog.h index 63a0960f65..dc60153fab 100644 --- a/src/include/catalog/database_catalog.h +++ b/src/include/catalog/database_catalog.h @@ -149,7 +149,10 @@ class DatabaseCatalog { common::ManagedPointer txn, table_oid_t table); /** @return The type_oid_t that corresponds to the internal TypeId. */ - type_oid_t GetTypeOidForType(execution::sql::SqlTypeId type); + type_oid_t GetTypeOidForType(execution::sql::SqlTypeId type) const; + + /** @return The SQL type ID that corresponds to the type_oid_t */ + execution::sql::SqlTypeId GetTypeForTypeOid(type_oid_t type) const; /** @brief Get a list of all of the constraints for the specified table. */ std::vector GetConstraints(common::ManagedPointer txn, @@ -169,20 +172,28 @@ class DatabaseCatalog { language_oid_t language_oid, namespace_oid_t procns, type_oid_t variadic_type, const std::vector &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, type_oid_t rettype, + const std::vector &arg_modes, type_oid_t rettype, const std::string &src, bool is_aggregate); /** @brief Drop the specified procedure. @see PgProcImpl::DropProcedure */ bool DropProcedure(common::ManagedPointer txn, proc_oid_t proc); + /** @brief Get the OID of the specified procedure. @see PgProcImpl::GetProcOid */ proc_oid_t GetProcOid(common::ManagedPointer txn, namespace_oid_t procns, const std::string &procname, const std::vector &all_arg_types); - /** @brief Set the procedure context for the specified procedure. @see PgProcImpl::SetFunctionContextPointer */ - bool SetFunctionContextPointer(common::ManagedPointer txn, proc_oid_t proc_oid, - const execution::functions::FunctionContext *func_context); - /** @brief Get the procedure context for the specified procedure. @see PgProcImpl::GetFunctionContext */ + + /** @brief Resolve all combinations of argument types for the procedure */ + std::vector> ResolveProcArgumentTypes( + common::ManagedPointer txn, namespace_oid_t procns, const std::string &procname, + const std::vector &arg_types); + + /** @brief Get the procedure context for the specified procedure. @see PgProcImpl::GetProcCtxPtr */ common::ManagedPointer GetFunctionContext( common::ManagedPointer txn, proc_oid_t proc_oid); + /** @brief Set the procedure context for the specified procedure. @see PgProcImpl::SetProcCtxPtr */ + bool SetFunctionContext(common::ManagedPointer txn, proc_oid_t proc_oid, + const execution::functions::FunctionContext *func_context); + /** @brief Get the statistics for the specified column. @see PgStatisticImpl::GetColumnStatistics */ std::unique_ptr GetColumnStatistics( common::ManagedPointer txn, table_oid_t table_oid, col_oid_t col_oid); @@ -349,5 +360,37 @@ class DatabaseCatalog { template bool SetClassPointer(common::ManagedPointer txn, ClassOid oid, const Ptr *pointer, col_oid_t class_col); + + /* -------------------------------------------------------------------------- + Function Lookup + -------------------------------------------------------------------------- */ + + /** + * Recursive helper function for procedure argument type resolution. + * @param txn The transaction context + * @param procns The namespace of the procedure + * @param procname The procedure name + * @param arg_types The argument types + * @param result The vector that receives any resolved sets of arguments + */ + void ResolveProcArgumentTypes(common::ManagedPointer txn, namespace_oid_t procns, + const std::string &procname, const std::vector &arg_types, + std::vector> *result); + + /** + * Determine if the vector of argument types contains an untyped NULL. + * @param arg_types The vector of argument types + * @return `true` if the vector contains an untyped NULL type, `false` otherwise + */ + bool ContainsUntypedNull(const std::vector &arg_types) const; + + /** + * Swap the first untyped NULL argument type in `arg_types` with `type`. + * @param arg_types The vector of argument types that is mutated + * @param type The type that is swapped in for the untyped NULL + * @return The modified vector + */ + std::vector ReplaceFirstUntypedNullWith(const std::vector &arg_types, + execution::sql::SqlTypeId type) const; }; } // namespace noisepage::catalog diff --git a/src/include/catalog/postgres/pg_language.h b/src/include/catalog/postgres/pg_language.h index a22ece3c4a..14cce1adb8 100644 --- a/src/include/catalog/postgres/pg_language.h +++ b/src/include/catalog/postgres/pg_language.h @@ -9,6 +9,10 @@ namespace noisepage::storage { class RecoveryManager; } // namespace noisepage::storage +namespace noisepage::execution::sql { +class DDLExecutors; +} // namespace noisepage::execution::sql + namespace noisepage::catalog::postgres { class Builder; class PgLanguageImpl; @@ -17,7 +21,11 @@ class PgProcImpl; /** The OIDs used by the NoisePage version of pg_language. */ class PgLanguage { private: + // TODO(Kyle): Should we come up with a better way of exposting + // these constants rather than simply adding friends for each + // class that needs to access them? This is not scalable. friend class storage::RecoveryManager; + friend class execution::sql::DDLExecutors; friend class Builder; friend class PgLanguageImpl; @@ -38,7 +46,7 @@ class PgLanguage { static constexpr CatalogColumnDef LANNAME{col_oid_t{2}}; // VARCHAR (skey) static constexpr CatalogColumnDef LANISPL{col_oid_t{3}}; // BOOLEAN (skey) static constexpr CatalogColumnDef LANPLTRUSTED{col_oid_t{4}}; // BOOLEAN (skey) - // TODO(tanujnay112): Make these foreign keys when we implement pg_proc + static constexpr CatalogColumnDef LANPLCALLFOID{ col_oid_t{5}}; // INTEGER (skey) (fkey: pg_proc) static constexpr CatalogColumnDef LANINLINE{col_oid_t{6}}; // INTEGER (skey) (fkey: pg_proc) diff --git a/src/include/catalog/postgres/pg_proc.h b/src/include/catalog/postgres/pg_proc.h index b948709152..11a6d2a173 100644 --- a/src/include/catalog/postgres/pg_proc.h +++ b/src/include/catalog/postgres/pg_proc.h @@ -26,7 +26,7 @@ class PgProcImpl; class PgProc { public: /** The type of the argument to the procedure. */ - enum class ArgModes : char { + enum class ArgMode : char { IN = 'i', ///< Input argument. OUT = 'o', ///< Output argument. INOUT = 'b', ///< Both input and output argument. diff --git a/src/include/catalog/postgres/pg_proc_impl.h b/src/include/catalog/postgres/pg_proc_impl.h index 43500ffb81..e045b8722a 100644 --- a/src/include/catalog/postgres/pg_proc_impl.h +++ b/src/include/catalog/postgres/pg_proc_impl.h @@ -105,7 +105,7 @@ class PgProcImpl { const std::string &procname, language_oid_t language_oid, namespace_oid_t procns, type_oid_t variadic_type, const std::vector &args, const std::vector &arg_types, const std::vector &all_arg_types, - const std::vector &arg_modes, type_oid_t rettype, + const std::vector &arg_modes, type_oid_t rettype, const std::string &src, bool is_aggregate); /** diff --git a/src/include/common/strong_typedef.h b/src/include/common/strong_typedef.h index ce9f368041..d6f4fedbdf 100644 --- a/src/include/common/strong_typedef.h +++ b/src/include/common/strong_typedef.h @@ -102,7 +102,6 @@ class StrongTypeAlias { */ constexpr const IntType &UnderlyingValue() const { return val_; } - // TODO(Kyle): perhaps remove ability to static_cast to underlying value altogether. /** * * @return the underlying value diff --git a/src/include/execution/ast/ast.h b/src/include/execution/ast/ast.h index 6c45d7d0a9..d9ed04d7b9 100644 --- a/src/include/execution/ast/ast.h +++ b/src/include/execution/ast/ast.h @@ -48,6 +48,7 @@ namespace ast { T(DeclStmt) \ T(ExpressionStmt) \ T(ForStmt) \ + T(BreakStmt) \ T(ForInStmt) \ T(IfStmt) \ T(ReturnStmt) @@ -66,12 +67,14 @@ namespace ast { T(IdentifierExpr) \ T(ImplicitCastExpr) \ T(IndexExpr) \ + T(LambdaExpr) \ T(LitExpr) \ T(MemberExpr) \ T(UnaryOpExpr) \ /* Type Representation Expressions */ \ T(ArrayTypeRepr) \ T(FunctionTypeRepr) \ + T(LambdaTypeRepr) \ T(MapTypeRepr) \ T(PointerTypeRepr) \ T(StructTypeRepr) @@ -351,14 +354,20 @@ class FunctionDecl : public Decl { * @param pos source position * @param name identifier * @param func function literal (param types, return type, body) + * @param is_lambda `true` if this function is constructed from a lambda expresison */ - FunctionDecl(const SourcePosition &pos, Identifier name, FunctionLitExpr *func); + FunctionDecl(const SourcePosition &pos, Identifier name, FunctionLitExpr *func, bool is_lambda = false); /** * @return The function literal defining the body of the function declaration. */ FunctionLitExpr *Function() const { return func_; } + /** + * @return `true` if this function is a lambda, `false` otherwise. + */ + bool IsLambda() const noexcept { return is_lambda_; } + /** * Is the given node a function declaration? Needed as part of the custom AST RTTI infrastructure. * @param node The node to check. @@ -371,6 +380,8 @@ class FunctionDecl : public Decl { private: // The function definition (signature and body). FunctionLitExpr *func_; + // Is this function generated by a lambda expression. + const bool is_lambda_; }; /** @@ -691,6 +702,26 @@ class IterationStmt : public Stmt { BlockStmt *body_; }; +/** + * A break statement. + */ +class BreakStmt : public Stmt { + public: + /** + * Constructor + * @param pos source position + */ + explicit BreakStmt(const SourcePosition &pos) : Stmt(Kind::BreakStmt, pos) {} + + /** + * Is the given node a break statement? + * Needed as part of the custom AST RTTI infrastructure. + * @param node The node to check. + * @return `true` if the node is a break statement, `false` otherwise. + */ + static bool classof(const AstNode *node) { return node->GetKind() == Kind::BreakStmt; } // NOLINT +}; + /** * A vanilla for-statement. */ @@ -731,6 +762,17 @@ class ForStmt : public IterationStmt { return node->GetKind() == Kind::ForStmt; } + private: + friend class sema::Sema; + + /** + * Set the condition for the for-loop. + */ + void SetCondition(Expr *cond) { + NOISEPAGE_ASSERT(cond != nullptr, "Cannot set null condition"); + cond_ = cond; + } + private: Stmt *init_; Expr *cond_; @@ -831,6 +873,9 @@ class IfStmt : public Stmt { private: friend class sema::Sema; + /** + * Set the condition for the if-statement. + */ void SetCondition(Expr *cond) { NOISEPAGE_ASSERT(cond != nullptr, "Cannot set null condition"); cond_ = cond; @@ -1039,6 +1084,65 @@ class BinaryOpExpr : public Expr { Expr *right_; }; +/** + * A lambda expression. + */ +class LambdaExpr : public Expr { + public: + /** + * Construct a new LambdaExpr instance. + * @param pos source position + * @param function the associated function literal expression + * @param captures a collection of lambda captures + */ + LambdaExpr(const SourcePosition &pos, FunctionLitExpr *function, util::RegionVector &&captures) + : Expr{Kind::LambdaExpr, pos}, function_literal_{function}, capture_idents_{std::move(captures)} {} + + /** @return The identifier for this lambda expression. */ + const Identifier &GetName() const { return name_; } + + /** + * Set the name of this lambda expression. + * @param name The desired name. + */ + void SetName(Identifier name) { name_ = name; } + + /** @return Get the capture struct type for this lambda expression. */ + ast::Type *GetCaptureStructType() const { return capture_type_; } + + /** + * Set the capture struct type for this lambda expression. + * @param capture_type The desired type. + */ + void SetCaptureStructType(ast::Type *capture_type) { capture_type_ = capture_type; } + + /** @return The function literal expression associated with this lambda. */ + FunctionLitExpr *GetFunctionLiteralExpr() const { return function_literal_; } + + /** @return The identifiers for the captures of this lambda expression. */ + const util::RegionVector &GetCaptureIdents() const { return capture_idents_; } + + /** + * Is the given node a lambda expression? Needed as part of the custom AST RTTI infrastructure. + * @param node The node to check. + * @return `true` if the node is a lambda expression; `false` otherwise. + */ + static bool classof(const AstNode *node) { // NOLINT + return node->GetKind() == Kind::LambdaExpr; + } + + private: + friend class sema::Sema; + /** The identifier for the lambda expression. */ + Identifier name_; + /** The type of the lambda captures struct. */ + ast::Type *capture_type_; + /** The associated function literal expression. */ + FunctionLitExpr *function_literal_; + /** The collection of identifers for lambda captures. */ + util::RegionVector capture_idents_; +}; + /** * A function call expression. */ @@ -1047,7 +1151,7 @@ class CallExpr : public Expr { /** * Type of call (builtin call or regular function call) */ - enum class CallKind : uint8_t { Regular, Builtin }; + enum class CallKind : uint8_t { Regular, Builtin, Lambda }; /** * Constructor for regular calls @@ -1085,6 +1189,11 @@ class CallExpr : public Expr { */ uint32_t NumArgs() const { return static_cast(args_.size()); } + /** + * Add an argument to the call. + */ + void PushArgument(Expr *expr) { args_.push_back(expr); } + /** * @return The kind of call, either regular or a call to a builtin function. */ @@ -1201,8 +1310,9 @@ class FunctionLitExpr : public Expr { * Constructor * @param type_repr type representation (param types, return type) * @param body body of the function + * @param is_lambda `true` if the literal is a lambda, `false` otherwise */ - FunctionLitExpr(FunctionTypeRepr *type_repr, BlockStmt *body); + FunctionLitExpr(FunctionTypeRepr *type_repr, BlockStmt *body, bool is_lambda = false); /** * @return The function's signature. @@ -1215,10 +1325,15 @@ class FunctionLitExpr : public Expr { BlockStmt *Body() const { return body_; } /** - * @return True if the function has no statements; false otherwise. + * @return `true` if the function has no statements; `false` otherwise. */ bool IsEmpty() const { return Body()->IsEmpty(); } + /** + * @return `true` if the function is a lambda, `false` otherwise. + */ + bool IsLambda() const { return is_lambda_; } + /** * Is the given node a function literal? Needed as part of the custom AST RTTI infrastructure. * @param node The node to check. @@ -1233,6 +1348,8 @@ class FunctionLitExpr : public Expr { FunctionTypeRepr *type_repr_; // The body of the function. BlockStmt *body_; + // Is this function literal a lambda. + const bool is_lambda_; }; /** @@ -1797,6 +1914,36 @@ class MapTypeRepr : public Expr { Expr *val_; }; +/** + * Lambda type. + */ +class LambdaTypeRepr : public Expr { + public: + /** + * Constructor + * @param pos source position + * @param fn_type function type + */ + LambdaTypeRepr(const SourcePosition &pos, Expr *fn_type) : Expr(Kind::LambdaTypeRepr, pos), fn_type_(fn_type) {} + + /** + * @return The expression for the type. + */ + Expr *FunctionType() const { return fn_type_; } + + /** + * Is the given node a lambda type representation? Needed as part of the custom AST RTTI infrastructure. + * @param node The node to check. + * @return `true` if the node is a lambda type representation; `false` otherwise. + */ + static bool classof(const AstNode *node) { // NOLINT + return node->GetKind() == Kind::LambdaTypeRepr; + } + + private: + Expr *fn_type_; +}; + /** * Pointer type. */ diff --git a/src/include/execution/ast/ast_clone.h b/src/include/execution/ast/ast_clone.h new file mode 100644 index 0000000000..540a881777 --- /dev/null +++ b/src/include/execution/ast/ast_clone.h @@ -0,0 +1,28 @@ +#pragma once + +#include + +#include "execution/ast/ast_node_factory.h" +#include "execution/ast/context.h" + +namespace noisepage::execution::ast { + +class AstNode; + +/** + * The AstClone class encapsulates the logic necessary to clone an AST. + */ +class AstClone { + public: + /** + * Clones an ASTNode and its descendants. + * @param node The root of the AST to clone + * @param factory The AstNodeFactory instance from which AST nodes are allocated + * @param old_context The old AST context + * @param new_context The new AST context + * @return + */ + static AstNode *Clone(AstNode *node, AstNodeFactory *factory, Context *old_context, Context *new_context); +}; + +} // namespace noisepage::execution::ast diff --git a/src/include/execution/ast/ast_fwd.h b/src/include/execution/ast/ast_fwd.h index 77ee5be674..e1d7091cb8 100644 --- a/src/include/execution/ast/ast_fwd.h +++ b/src/include/execution/ast/ast_fwd.h @@ -12,6 +12,7 @@ class Decl; class FieldDecl; class File; // NOLINT it picks up madoka's File class FunctionDecl; +class LambdaExpr; class Stmt; class StructDecl; class VariableDecl; diff --git a/src/include/execution/ast/ast_node_factory.h b/src/include/execution/ast/ast_node_factory.h index cae09e4690..8c10eb11eb 100644 --- a/src/include/execution/ast/ast_node_factory.h +++ b/src/include/execution/ast/ast_node_factory.h @@ -45,6 +45,17 @@ class AstNodeFactory { return new (region_) FunctionDecl(pos, name, fun); } + /** + * @param pos source position + * @param fun function literal (params, return type, body) + * @param captures lambda captures + * @return created LambdaExpr node. + */ + LambdaExpr *NewLambdaExpr(const SourcePosition &pos, FunctionLitExpr *fun, + util::RegionVector &&captures) { + return new (region_) LambdaExpr(pos, fun, std::move(captures)); + } + /** * @param pos source position * @param name struct name @@ -133,6 +144,12 @@ class AstNodeFactory { return new (region_) IfStmt(pos, cond, then_stmt, else_stmt); } + /** + * @param pos source position + * @return created BreakStmt node + */ + BreakStmt *NewBreakStmt(const SourcePosition &pos) { return new (region_) BreakStmt(pos); } + /** * @param pos source position * @param ret returned expression @@ -337,6 +354,15 @@ class AstNodeFactory { return new (region_) MapTypeRepr(pos, key_type, val_type); } + /** + * @param pos source position + * @param fn_type the function type + * @return created LambdaTypeRepr + */ + LambdaTypeRepr *NewLambdaType(const SourcePosition &pos, Expr *fn_type) { + return new (region_) LambdaTypeRepr(pos, fn_type); + } + private: util::Region *region_; }; diff --git a/src/include/execution/ast/ast_traversal_visitor.h b/src/include/execution/ast/ast_traversal_visitor.h index f4e90e5d32..7a0ff7c203 100644 --- a/src/include/execution/ast/ast_traversal_visitor.h +++ b/src/include/execution/ast/ast_traversal_visitor.h @@ -210,6 +210,12 @@ inline void AstTraversalVisitor::VisitForStmt(ForStmt *node) { RECURSE(Visit(node->Body())); } +template +inline void AstTraversalVisitor::VisitBreakStmt(BreakStmt *node) { + PROCESS_NODE(node); + // TODO(Kyle): Implement this?? +} + template inline void AstTraversalVisitor::VisitForInStmt(ForInStmt *node) { PROCESS_NODE(node); @@ -232,6 +238,12 @@ inline void AstTraversalVisitor::VisitMapTypeRepr(MapTypeRepr *node) { RECURSE(Visit(node->ValType())); } +template +inline void AstTraversalVisitor::VisitLambdaTypeRepr(LambdaTypeRepr *node) { + PROCESS_NODE(node); + // TODO(Kyle): Implement this?? +} + template inline void AstTraversalVisitor::VisitLitExpr(LitExpr *node) { PROCESS_NODE(node); @@ -294,6 +306,12 @@ inline void AstTraversalVisitor::VisitIndexExpr(IndexExpr *node) { RECURSE(Visit(node->Index())); } +template +inline void AstTraversalVisitor::VisitLambdaExpr(LambdaExpr *node) { + PROCESS_NODE(node); + // TODO(Kyle): Implement this?? +} + template inline void AstTraversalVisitor::VisitFunctionTypeRepr(FunctionTypeRepr *node) { PROCESS_NODE(node); diff --git a/src/include/execution/ast/builtins.h b/src/include/execution/ast/builtins.h index f0be803815..65034a1472 100644 --- a/src/include/execution/ast/builtins.h +++ b/src/include/execution/ast/builtins.h @@ -34,6 +34,7 @@ namespace noisepage::execution::ast { /* SQL Functions */ \ F(Like, like) \ F(DatePart, datePart) \ + F(Random, random) \ \ /* Thread State Container */ \ F(ExecutionContextAddRowsAffected, execCtxAddRowsAffected) \ @@ -349,6 +350,18 @@ namespace noisepage::execution::ast { F(GetParamDate, getParamDate) \ F(GetParamTimestamp, getParamTimestamp) \ F(GetParamString, getParamString) \ + F(StartNewParams, startNewParams) \ + F(FinishNewParams, finishNewParams) \ + F(AddParamBool, addParamBool) \ + F(AddParamTinyInt, addParamTinyInt) \ + F(AddParamSmallInt, addParamSmallInt) \ + F(AddParamInt, addParamInt) \ + F(AddParamBigInt, addParamBigInt) \ + F(AddParamReal, addParamReal) \ + F(AddParamDouble, addParamDouble) \ + F(AddParamDate, addParamDate) \ + F(AddParamTimestamp, addParamTimestamp) \ + F(AddParamString, addParamString) \ \ /* String functions */ \ F(Lower, lower) \ diff --git a/src/include/execution/ast/type.h b/src/include/execution/ast/type.h index 0cf58eecd7..d9237ff9df 100644 --- a/src/include/execution/ast/type.h +++ b/src/include/execution/ast/type.h @@ -22,7 +22,9 @@ class Context; F(BuiltinType) \ F(StringType) \ F(PointerType) \ + F(ReferenceType) \ F(ArrayType) \ + F(LambdaType) \ F(MapType) \ F(StructType) \ F(FunctionType) @@ -313,6 +315,11 @@ class Type : public util::RegionObject { */ PointerType *PointerTo(); + /** + * @return A new type that is a reference to the current type. + */ + ReferenceType *ReferenceTo(); + /** * @return If this is a pointer type, the type of the element pointed to. Returns null otherwise. */ @@ -506,6 +513,37 @@ class PointerType : public Type { Type *base_; }; +/** + * Reference type. + */ +class ReferenceType : public Type { + public: + /** + * @return base type + */ + Type *GetBase() const { return base_; } + + /** + * Static Constructor + * @param base type + * @return reference to base type + */ + static ReferenceType *Get(Type *base); + + /** + * @param type checked type + * @return whether type is a reference type. + */ + static bool classof(const Type *type) { return type->GetTypeId() == TypeId::ReferenceType; } // NOLINT + + private: + explicit ReferenceType(Type *base) + : Type(base->GetContext(), sizeof(int8_t *), alignof(int8_t *), TypeId::ReferenceType), base_(base) {} + + private: + Type *base_; +}; + /** * Array type. */ @@ -560,8 +598,8 @@ class ArrayType : public Type { }; /** - * A field is a pair containing a name and a type. It is used to represent both fields within a struct, and parameters - * to a function. + * A Field is a pair containing a name and a type. + * It is used to represent both fields within a struct, and parameters to a function. */ struct Field { /** @@ -575,12 +613,18 @@ struct Field { Type *type_; /** - * Constructor + * Construct a new Field instance. * @param name of the field * @param type of the field */ Field(const Identifier &name, Type *type) : name_(name), type_(type) {} + /** @return The name of the field */ + const Identifier &GetName() const { return name_; } + + /** @return The type of the field */ + Type *GetType() const { return type_; } + /** * @param other rhs of the comparison * @return whether this == other @@ -594,10 +638,15 @@ struct Field { class FunctionType : public Type { public: /** - * @return A constant reference to the list of parameters to a function. + * @return An immutable reference to the list of parameters to a function. */ const util::RegionVector &GetParams() const { return params_; } + /** + * @return A mutable reference to the list of parameters to a function. + */ + util::RegionVector &GetParams() { return params_; } + /** * @return The number of parameters to the function. */ @@ -609,13 +658,57 @@ class FunctionType : public Type { Type *GetReturnType() const { return ret_; } /** - * Create a function with parameters @em params and returning types of type @em ret. + * Determine if this function is equivalent to `other`. + * @param other The other function of interest + * @return `true` if the functions are equivalent, `false` otherwise. + */ + bool IsEqual(const FunctionType *other); + + /** @return `true` if this function is a lambda, `false` otherwise. */ + bool IsLambda() const { return is_lambda_; } + + /** + * Set the lambda disposition for this function. + * @param is_lambda `true` if this function is a lambda, `false` otherwise. + */ + void SetIsLambda(bool is_lambda) { is_lambda_ = is_lambda; } + + /** + * Get the type of the lambda captures struct. + * @return The struct type for lambda captures. + */ + ast::StructType *GetCapturesType() const { + NOISEPAGE_ASSERT(is_lambda_, "Getting capture type from not lambda"); + return captures_; + } + + /** + * Set the type of the lambda captures struct. + * @param captures The struct type for lambda captures. + */ + void SetCapturesType(ast::StructType *captures) { captures_ = captures; } + + /** + * Register lambda captures as a parameter to this function. + */ + void RegisterCapture(); + + /** + * Create a function with parameters `params` and returning types of type `ret`. * @param params The parameters to the function. * @param ret The type of the object the function returns. * @return The function type. */ static FunctionType *Get(util::RegionVector &¶ms, Type *ret); + /** + * Create a lambda function with params `params` and returning types of type `ret`. + * @param params The parameters to the function. + * @param ret The type of the object the function returns. + * @return The function type. + */ + static FunctionType *GetLambda(util::RegionVector &¶ms, Type *ret); + /** * @param type type to compare with * @return whether type is of function type @@ -623,11 +716,13 @@ class FunctionType : public Type { static bool classof(const Type *type) { return type->GetTypeId() == TypeId::FunctionType; } // NOLINT private: - explicit FunctionType(util::RegionVector &¶ms, Type *ret); + explicit FunctionType(util::RegionVector &¶ms, Type *ret, bool is_lambda); private: util::RegionVector params_; Type *ret_; + bool is_lambda_; + ast::StructType *captures_; }; /** @@ -655,7 +750,7 @@ class MapType : public Type { /** * @param type to compare with - * @return whether type is of map type. + * @return whether type is of Map type. */ static bool classof(const Type *type) { return type->GetTypeId() == TypeId::MapType; } // NOLINT @@ -667,6 +762,34 @@ class MapType : public Type { Type *val_type_; }; +/** + * Lambda type. + */ +class LambdaType : public Type { + public: + /** + * @return The function type representation. + */ + FunctionType *GetFunctionType() const { return fn_type_; } + + /** + * @return A newly-constructed lambda type. + */ + static LambdaType *Get(FunctionType *fn_type); + + /** + * @param type to compare with + * @return whether type is of Lambda type. + */ + static bool classof(const Type *type) { return type->GetTypeId() == TypeId::LambdaType; } // NOLINT + + private: + explicit LambdaType(FunctionType *fn_type); + + private: + FunctionType *fn_type_; +}; + /** * Struct type. */ diff --git a/src/include/execution/ast/udf/node_types.h b/src/include/execution/ast/udf/node_types.h new file mode 100644 index 0000000000..44cf92d37d --- /dev/null +++ b/src/include/execution/ast/udf/node_types.h @@ -0,0 +1,24 @@ +namespace noisepage::execution::ast::udf { + +/** Enumerates all (instantiable) AST node types */ +enum class NodeType { + VALUE_EXPR, + IS_NULL_EXPR, + VARIABLE_EXPR, + MEMBER_EXPR, + BINARY_EXPR, + CALL_EXPR, + SEQ_STMT, + DECL_STMT, + IF_STMT, + FORI_STMT, + FORS_STMT, + WHILE_STMT, + RET_STMT, + ASSIGN_STMT, + SQL_STMT, + DYNAMIC_SQL_STMT, + FUNCTION +}; + +} // namespace noisepage::execution::ast::udf diff --git a/src/include/execution/ast/udf/udf_ast_context.h b/src/include/execution/ast/udf/udf_ast_context.h new file mode 100644 index 0000000000..009825372c --- /dev/null +++ b/src/include/execution/ast/udf/udf_ast_context.h @@ -0,0 +1,145 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "execution/sql/sql.h" + +namespace noisepage::execution::ast::udf { + +/** + * The UdfAstContext class maintains state that is utilized + * throughout construction of the UDF abstract syntax tree. + */ +class UdfAstContext { + /** An invidual entry for a record type, (name, type ID) */ + using RecordTypeEntry = std::pair; + + /** A full description of a record type */ + using RecordType = std::vector; + + public: + /** + * Construct a new AstContext instance. + */ + UdfAstContext() = default; + + /** + * Push a new local variable. + * @param name The name of the variable + */ + void AddLocal(const std::string &name) { locals_.push_back(name); } + + /** + * Get the local variable at index `index`. + * @param index The index of interest + * @return The name of the variable at the specified index + */ + const std::string &GetLocalAtIndex(const std::size_t index) const { + NOISEPAGE_ASSERT(locals_.size() >= index, "Index out of range"); + // TODO(Kyle): I moved the subtraction to the call site because + // it seems misleading to have a getter for an index but deliver + // a local that does not actually appear at that index... + return locals_.at(index); + } + + /** + * Determine if a variable with name `name` is present in the UDF AST. + * @param name The name of the variable + * @return `true` if the UDF AST context contains a variable + * identified by `name`, `false` otherwise + */ + bool HasVariable(const std::string &name) const { return (symbol_table_.find(name) != symbol_table_.cend()); } + + /** + * Set the type of the variabel identifed by `name`. + * @param name The name of the variable + * @param type The type to which the variable should be set + */ + void SetVariableType(const std::string &name, sql::SqlTypeId type) { symbol_table_[name] = type; } + + /** + * Get the type of the variable identified by `name`. + * @param name The name of the variable + * @return The type ID for the specified variable if present, + * empty optional value otherwise + */ + std::optional GetVariableType(const std::string &name) const { + auto it = symbol_table_.find(name); + return (it == symbol_table_.cend()) ? std::nullopt : std::make_optional(it->second); + } + + /** + * Get the type of the variable identified by `name`. + * @param name The name of the variable + * @return The type ID for the specified variable + * + * NOTE: This function terminates the program in the event + * that the variable is not present; for variable queries + * that may fail, use UdfAstContext::GetVariableType(). + */ + sql::SqlTypeId GetVariableTypeFailFast(const std::string &name) const { + auto it = symbol_table_.find(name); + NOISEPAGE_ASSERT(it != symbol_table_.cend(), "Required variable is not present in UDF AST"); + return it->second; + } + + /** + * Determine if a record variable with name `name` is present in the UDF AST. + * @param name The name of the variable + * @return `true` if the UDF AST context contains a record variable + * identified by `name`, `false` otherwise + */ + bool HasRecord(const std::string &name) const { return (record_types_.find(name) != record_types_.cend()); } + + /** + * Set the record type for the variable identified by `name`. + * @param name The name of the variable + * @param elems The record + */ + void SetRecordType(const std::string &name, std::vector> &&elems) { + record_types_[name] = std::move(elems); + } + + /** + * Get the record type for the variable identified by `name`. + * @param name The name of the variable + * @return The type of the record variable if present, + * empty optional value otherwise + */ + std::optional GetRecordType(const std::string &name) const { + auto it = record_types_.find(name); + // TODO(Kyle): I updated the API for this function to use std::optional, + // I like this more, but it makes it impossible to return a reference to + // the underlying data so this now materializes a copy every time + return (it == record_types_.cend()) ? std::nullopt : std::make_optional(it->second); + } + + /** + * Get the record type for the variable identified by `name`. + * @param name The name of the variable + * @return The type of the record variable + * + * NOTE: This function terminates the program in the event + * that the variable is not present; for variable queries + * that may fail, use UdfAstContext::GetRecordType(). + */ + RecordType GetRecordTypeFailFast(const std::string &name) const { + auto it = record_types_.find(name); + NOISEPAGE_ASSERT(it != record_types_.cend(), "Required record variable is not present in UDF AST"); + return it->second; + } + + private: + /** Collection of local variable names for the UDF. */ + std::vector locals_; + /** The symbol table for the UDF. */ + std::unordered_map symbol_table_; + /** Collection of record types for the UDF. */ + std::unordered_map record_types_; +}; + +} // namespace noisepage::execution::ast::udf diff --git a/src/include/execution/ast/udf/udf_ast_node_visitor.h b/src/include/execution/ast/udf/udf_ast_node_visitor.h new file mode 100644 index 0000000000..9a115c5ec6 --- /dev/null +++ b/src/include/execution/ast/udf/udf_ast_node_visitor.h @@ -0,0 +1,158 @@ +#pragma once + +namespace noisepage::execution::ast::udf { + +class AbstractAST; +class StmtAST; +class ExprAST; +class ValueExprAST; +class IsNullExprAST; +class VariableExprAST; +class BinaryExprAST; +class CallExprAST; +class MemberExprAST; +class SeqStmtAST; +class DeclStmtAST; +class IfStmtAST; +class WhileStmtAST; +class RetStmtAST; +class AssignStmtAST; +class SQLStmtAST; +class DynamicSQLStmtAST; +class ForIStmtAST; +class ForSStmtAST; +class FunctionAST; + +/** + * The ASTNodeVisitor class defines the interface for + * visitors of the UDF abstract syntax tree. + */ +class ASTNodeVisitor { + public: + /** + * Destroy the visitor. + */ + virtual ~ASTNodeVisitor() = default; + + /** + * Visit an AbstractAST node. + * @param ast The node to visit + */ + virtual void Visit(AbstractAST *ast) = 0; + + /** + * Visit an StmtAST node. + * @param ast The node to visit + */ + virtual void Visit(StmtAST *ast) = 0; + + /** + * Visit an ExprAST node. + * @param ast The node to visit + */ + virtual void Visit(ExprAST *ast) = 0; + + /** + * Visit an FunctionAST node. + * @param ast The node to visit + */ + virtual void Visit(FunctionAST *ast) = 0; + + /** + * Visit an ValueExprAST node. + * @param ast The node to visit + */ + virtual void Visit(ValueExprAST *ast) = 0; + + /** + * Visit an VariableExprAST node. + * @param ast The node to visit + */ + virtual void Visit(VariableExprAST *ast) = 0; + + /** + * Visit an BinaryExprAST node. + * @param ast The node to visit + */ + virtual void Visit(BinaryExprAST *ast) = 0; + + /** + * Visit an IsNullExprAST node. + * @param ast The node to visit + */ + virtual void Visit(IsNullExprAST *ast) = 0; + + /** + * Visit an CallExprAST node. + * @param ast The node to visit + */ + virtual void Visit(CallExprAST *ast) = 0; + + /** + * Visit an MemberExprAST node. + * @param ast The node to visit + */ + virtual void Visit(MemberExprAST *ast) = 0; + + /** + * Visit an SeqStmtAST node. + * @param ast The node to visit + */ + virtual void Visit(SeqStmtAST *ast) = 0; + + /** + * Visit an DeclStmtAST node. + * @param ast The node to visit + */ + virtual void Visit(DeclStmtAST *ast) = 0; + + /** + * Visit an IfStmtAST node. + * @param ast The node to visit + */ + virtual void Visit(IfStmtAST *ast) = 0; + + /** + * Visit an WhileStmtAST node. + * @param ast The node to visit + */ + virtual void Visit(WhileStmtAST *ast) = 0; + + /** + * Visit an RetStmtAST node. + * @param ast The node to visit + */ + virtual void Visit(RetStmtAST *ast) = 0; + + /** + * Visit an AssignStmtAST node. + * @param ast The node to visit + */ + virtual void Visit(AssignStmtAST *ast) = 0; + + /** + * Visit a ForIStmtAST node. + * @param ast The node to visit + */ + virtual void Visit(ForIStmtAST *ast) = 0; + + /** + * Visit an ForSStmtAST node. + * @param ast The node to visit + */ + virtual void Visit(ForSStmtAST *ast) = 0; + + /** + * Visit an SQLStmtAST node. + * @param ast The node to visit + */ + virtual void Visit(SQLStmtAST *ast) = 0; + + /** + * Visit an DynamicSQLStmtAST node. + * @param ast The node to visit + */ + virtual void Visit(DynamicSQLStmtAST *ast) = 0; +}; + +} // namespace noisepage::execution::ast::udf diff --git a/src/include/execution/ast/udf/udf_ast_nodes.h b/src/include/execution/ast/udf/udf_ast_nodes.h new file mode 100644 index 0000000000..7fd99d85d4 --- /dev/null +++ b/src/include/execution/ast/udf/udf_ast_nodes.h @@ -0,0 +1,793 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "parser/expression/constant_value_expression.h" +#include "parser/expression_defs.h" +#include "parser/parse_result.h" + +#include "execution/ast/udf/node_types.h" +#include "execution/ast/udf/udf_ast_node_visitor.h" +#include "execution/sql/value.h" + +namespace noisepage::execution::ast::udf { + +/** + * Get the string representation of a node type. + * @param type The node type + * @return The string representation + */ +std::string NodeTypeToShortString(NodeType type); + +/** + * The AbstractAST class serves as a base class for all AST nodes. + */ +class AbstractAST { + public: + /** + * Construct a new AbstractAST node instance. + * @param type The type of the node + */ + explicit AbstractAST(NodeType type) : type_{type} {} + + /** Destroy the AST node. */ + virtual ~AbstractAST() = default; + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + virtual void Accept(ASTNodeVisitor *visitor) { visitor->Visit(this); } + + /** @return The type of the AST node */ + NodeType GetType() const { return type_; } + + private: + /** The type of the AST node */ + NodeType type_; +}; + +/** + * The ExprAST class serves as the base class for all expression nodes. + */ +class ExprAST : public AbstractAST { + public: + /** + * Construct a new ExprAST instance. + * @param type The type of the expression node + */ + explicit ExprAST(NodeType type) : AbstractAST{type} {} + + /** Destroy the AST node. */ + ~ExprAST() override = default; + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } +}; + +/** + * The StmtAST class serves as the base class for all statement nodes. + */ +class StmtAST : public AbstractAST { + public: + /** + * Construct a new StmtAST instance. + * @param type The type of the statement node + */ + explicit StmtAST(NodeType type) : AbstractAST{type} {} + + /** Destroy the AST node. */ + ~StmtAST() override = default; + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } +}; + +/** + * The ValueExprAST class represents literal values. + */ +class ValueExprAST : public ExprAST { + public: + /** + * Construct a new ValueExprAST instance. + * @param value The AbstractExpression that represents the value + */ + explicit ValueExprAST(std::unique_ptr &&value) + : ExprAST{NodeType::VALUE_EXPR}, value_(std::move(value)) {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return A mutable pointer to the value expression */ + parser::AbstractExpression *Value() { return value_.get(); } + + /** @return An immutable pointer to the value expression */ + const parser::AbstractExpression *Value() const { return value_.get(); } + + private: + /** The expression that represents the value */ + std::unique_ptr value_; +}; + +/** + * The IsNullExprAST class represents an expression that performs a NULL check. + */ +class IsNullExprAST : public ExprAST { + public: + /** + * Construct a new IsNullExprAST instance. + * @param is_null_check The NULL check flag + * @param child The child expression + */ + IsNullExprAST(bool is_null_check, std::unique_ptr &&child) + : ExprAST{NodeType::IS_NULL_EXPR}, is_null_check_{is_null_check}, child_{std::move(child)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return `true` if the NULL check is performed, `false` otherwise */ + bool IsNullCheck() const { return is_null_check_; } + + /** @return The child expression */ + ExprAST *Child() { return child_.get(); } + + /** @return The child expression */ + const ExprAST *Child() const { return child_.get(); } + + private: + /** The NULL check flag */ + bool is_null_check_; + + /** The child expression */ + std::unique_ptr child_; +}; + +/** + * The VariableExprAST class represents an expression that references a variable. + */ +class VariableExprAST : public ExprAST { + public: + /** + * Construct a new VariableExprAST instance. + * @param name The name of the variable + */ + explicit VariableExprAST(std::string name) : ExprAST{NodeType::VARIABLE_EXPR}, name_{std::move(name)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return The name of the variable */ + const std::string &Name() const { return name_; } + + private: + /** The name of the variable */ + const std::string name_; +}; + +/** + * The MemberExprAST class represents a structure member expression. + */ +class MemberExprAST : public ExprAST { + public: + /** + * Construct a new MemberExprAST instance. + * @param object The structure + * @param field The name of the field in the structure + */ + MemberExprAST(std::unique_ptr &&object, std::string field) + : ExprAST{NodeType::MEMBER_EXPR}, object_{std::move(object)}, field_(std::move(field)) {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return The object */ + VariableExprAST *Object() { return object_.get(); } + + /** @return The object */ + const VariableExprAST *Object() const { return object_.get(); } + + /** @return The name of the field */ + const std::string &FieldName() const { return field_; } + + private: + /** The expression for the object */ + std::unique_ptr object_; + + /** The identifier for the field in the object */ + std::string field_; +}; + +/** + * The BinaryExprAST class represents a generic binary expression. + */ +class BinaryExprAST : public ExprAST { + public: + /** + * Construct a new BinaryExprAST instance. + * @param op The expression type for the operation + * @param lhs The expression on the left-hande side of the operation + * @param rhs The expression on the right-hand side of the operation + */ + BinaryExprAST(parser::ExpressionType op, std::unique_ptr &&lhs, std::unique_ptr &&rhs) + : ExprAST{NodeType::BINARY_EXPR}, op_{op}, lhs_{std::move(lhs)}, rhs_{std::move(rhs)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return The expression type for the operation */ + parser::ExpressionType Op() const { return op_; } + + /** @return A mutable pointer to the left expression */ + ExprAST *Left() { return lhs_.get(); } + + /** @return An immutable pointer to the left expression */ + const ExprAST *Left() const { return lhs_.get(); } + + /** @return A mutable pointer to the right expression */ + ExprAST *Right() { return rhs_.get(); } + + /** @return An immutable pointer to the right expression */ + const ExprAST *Right() const { return rhs_.get(); } + + private: + /** The expression type for the operation */ + parser::ExpressionType op_; + + /** The expression on the left-hand side of the operation */ + std::unique_ptr lhs_; + + /** The expression on the right-hand side of the operation */ + std::unique_ptr rhs_; +}; + +/** + * The CallExprAST class represents a function call expression. + */ +class CallExprAST : public ExprAST { + public: + /** + * Construct a new CallExprAST instance. + * @param callee The name of the called function + * @param args The arguments to the function call + */ + CallExprAST(std::string callee, std::vector> &&args) + : ExprAST{NodeType::CALL_EXPR}, callee_{std::move(callee)}, args_{std::move(args)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return The name of the called function */ + const std::string &Callee() const { return callee_; } + + /** @return A mutable reference to the function call arguments */ + std::vector> &Args() { return args_; } + + /** @return An immutable reference to the function call arguments */ + const std::vector> &Args() const { return args_; } + + private: + /** The name of the called function */ + const std::string callee_; + + /** The arguments to the function call */ + std::vector> args_; +}; + +/** + * The SeqStmtAST class represents a sequence of statements. + */ +class SeqStmtAST : public StmtAST { + public: + /** + * Construct a new SeqStmtAST instance. + * @param statements The collection of statements in the sequence + */ + explicit SeqStmtAST(std::vector> &&statements) + : StmtAST{NodeType::SEQ_STMT}, statements_(std::move(statements)) {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return A mutable reference to the statements in the sequence */ + std::vector> &Statements() { return statements_; } + + /** @return An immutable reference to the statements in the sequence */ + const std::vector> &Statements() const { return statements_; } + + private: + /** The collection of statements in the sequence */ + std::vector> statements_; +}; + +// DeclStmtAST - Statement class for sequence of statements +/** + * The DeclStmtAST class represents a declaration statement. + */ +class DeclStmtAST : public StmtAST { + public: + /** + * Construct a new DeclStmtAST instance. + * @param name The name of the variable that is declared + * @param type The type of the declared variable + * @param initial The initial value in the declaration + */ + DeclStmtAST(std::string name, sql::SqlTypeId type, std::unique_ptr &&initial) + : StmtAST{NodeType::DECL_STMT}, name_{std::move(name)}, type_(type), initial_{std::move(initial)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + + /** @return The name of the declared variable */ + const std::string &Name() const { return name_; } + + /** @return The type of the declared variable */ + sql::SqlTypeId Type() const { return type_; } + + /** @return A mutable pointer to the initial value expression */ + ExprAST *Initial() { return initial_.get(); } + + /** @return An immutable pointer to the initial value expression */ + const ExprAST *Initial() const { return initial_.get(); } + + private: + /** The name of the variable declared in the statement */ + std::string name_; + + /** The type of the declared variable */ + sql::SqlTypeId type_; + + /** The initial value of the declaration */ + std::unique_ptr initial_; +}; + +/** + * The IfStmtAST class represents an IF/THEN/ELSE construct. + */ +class IfStmtAST : public StmtAST { + public: + /** + * Construct a new IfStmtAST instance. + * @param cond_expr The conditional expression + * @param then_stmt The `then` statement + * @param else_stmt The `else` statement + */ + IfStmtAST(std::unique_ptr &&cond_expr, std::unique_ptr &&then_stmt, + std::unique_ptr &&else_stmt) + : StmtAST{NodeType::IF_STMT}, + cond_expr_{std::move(cond_expr)}, + then_stmt_{std::move(then_stmt)}, + else_stmt_{std::move(else_stmt)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + + /** @return The conditional expression */ + ExprAST *Condition() { return cond_expr_.get(); } + + /** @return The conditional expression */ + const ExprAST *Condition() const { return cond_expr_.get(); } + + /** @return The `then` statement */ + StmtAST *Then() { return then_stmt_.get(); } + + /** @return The `then` statement */ + const StmtAST *Then() const { return then_stmt_.get(); } + + /** @return The `else` statement */ + StmtAST *Else() { return else_stmt_.get(); } + + /** @return The `else` statement */ + const StmtAST *Else() const { return else_stmt_.get(); } + + private: + /** The conditional expression */ + std::unique_ptr cond_expr_; + + /** The `then` statement */ + std::unique_ptr then_stmt_; + + /** The `else` statement */ + std::unique_ptr else_stmt_; +}; + +/** + * The ForIStmtAST class represents a `for`-loop construct. + * + * Ex: FOR i IN 1..10 LOOP... + */ +class ForIStmtAST : public StmtAST { + public: + /** + * The default query that defines the "step" expression. + * + * The PLpgSQL documentation specifies this behavior. + */ + constexpr static const char DEFAULT_STEP_EXPR[] = "SELECT 1"; + + /** + * Construct a new ForIStmtAST instance. + * @param variable The loop induction variable + * @param lower The loop lower bound + * @param upper The loop upper bound + * @param step The loop step + * @param body The body of the loop + */ + ForIStmtAST(std::string variable, std::unique_ptr lower, std::unique_ptr upper, + std::unique_ptr step, std::unique_ptr body) + : StmtAST{NodeType::FORI_STMT}, + variable_{std::move(variable)}, + lower_{std::move(lower)}, + upper_{std::move(upper)}, + step_{std::move(step)}, + body_{std::move(body)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + + /** @return The loop variable */ + const std::string &Variable() const { return variable_; } + + /** @return A mutable pointer to the loop lower-bound expression */ + ExprAST *Lower() { return lower_.get(); } + + /** @return An immutable pointer to the loop lower-bound expression */ + const ExprAST *Lower() const { return lower_.get(); } + + /** @return A mutable pointer to the loop upper-bound expression */ + ExprAST *Upper() { return upper_.get(); } + + /** @return An immutable pointer to the loop upper-bound expression */ + const ExprAST *Upper() const { return upper_.get(); } + + /** @return A mutable pointer to the loop step expression */ + ExprAST *Step() { return step_.get(); } + + /** @return An immutable pointer to the loop step expression */ + const ExprAST *Step() const { return step_.get(); } + + /** @return A mutable pointer to the loop body statement */ + StmtAST *Body() { return body_.get(); } + + /** @return An immutable pointer to the loop body statement */ + const StmtAST *Body() const { return body_.get(); } + + private: + /** The identifier for the loop variable */ + const std::string variable_; + /** The expression that defines the loop lower-bound */ + std::unique_ptr lower_; + /** The expression that defines the loop upper-bound */ + std::unique_ptr upper_; + /** The expression that defines the loop step */ + std::unique_ptr step_; + /** The loop body */ + std::unique_ptr body_; +}; + +/** + * The ForSStmtAST class represents a `for`-loop construct. + * + * Ex: FOR record IN (SELECT * FROM tmp) LOOP ... + */ +class ForSStmtAST : public StmtAST { + public: + /** + * Construct a new ForSStmtAST instance. + * @param variables The collection of variables in the loop + * @param query The associated query + * @param body The body of the loop + */ + ForSStmtAST(std::vector &&variables, std::unique_ptr &&query, + std::unique_ptr body) + : StmtAST{NodeType::FORS_STMT}, + variables_{std::move(variables)}, + query_{std::move(query)}, + body_{std::move(body)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); }; + + /** @return The collection of loop variables */ + const std::vector &Variables() const { return variables_; } + + /** @return The associated query */ + parser::ParseResult *Query() { return query_.get(); } + + /** @return The associated query */ + const parser::ParseResult *Query() const { return query_.get(); } + + /** @return The loop body statement */ + StmtAST *Body() { return body_.get(); } + + /** @return The loop body statement */ + const StmtAST *Body() const { return body_.get(); } + + private: + /** The collection of loop variables */ + std::vector variables_; + + /** The associated query */ + std::unique_ptr query_; + + /** The loop body statement */ + std::unique_ptr body_; +}; + +/** + * The WhileStmtAST represents a `while`-loop construct. + */ +class WhileStmtAST : public StmtAST { + public: + /** + * Construct a new WhileStmtAST instance. + * @param condition The loop condition + * @param body The loop body statement + */ + WhileStmtAST(std::unique_ptr &&condition, std::unique_ptr &&body) + : StmtAST{NodeType::WHILE_STMT}, condition_{std::move(condition)}, body_{std::move(body)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return The loop condition */ + ExprAST *Condition() { return condition_.get(); } + + /** @return The loop condition */ + const ExprAST *Condition() const { return condition_.get(); } + + /** @return The loop body statement */ + StmtAST *Body() { return body_.get(); } + + /** @return The loop body statement */ + const StmtAST *Body() const { return body_.get(); } + + private: + /** The loop condition */ + std::unique_ptr condition_; + + /** The loop body statement */ + std::unique_ptr body_; +}; + +/** + * The RetStmtAST class represents a `return` statement. + */ +class RetStmtAST : public StmtAST { + public: + /** + * Construct a new RetStmtAST instance. + * @param ret_expr The `return` expression + */ + explicit RetStmtAST(std::unique_ptr &&ret_expr) + : StmtAST{NodeType::RET_STMT}, ret_expr_{std::move(ret_expr)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return The `return` expression */ + ExprAST *Return() { return ret_expr_.get(); } + + /** @return The `return` expression */ + const ExprAST *Return() const { return ret_expr_.get(); } + + private: + /** The `return` expression */ + std::unique_ptr ret_expr_; +}; + +/** + * The AssignStmtAST class represents an assignment statement. + */ +class AssignStmtAST : public StmtAST { + public: + /** + * Construct a new AssignStmtAST instance. + * @param dst The variable that represents the destination of the assignment + * @param src The expression that represents the source of the assignment + */ + AssignStmtAST(std::unique_ptr &&dst, std::unique_ptr &&src) + : StmtAST{NodeType::ASSIGN_STMT}, dst_{std::move(dst)}, src_{std::move(src)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return The destination variable of the assignment */ + VariableExprAST *Destination() { return dst_.get(); } + + /** @return The destination variable of the assignment */ + const VariableExprAST *Destination() const { return dst_.get(); } + + /** @return The source expression of the assignment */ + ExprAST *Source() { return src_.get(); } + + /** @return The source expression of the assignment */ + const ExprAST *Source() const { return src_.get(); } + + private: + /** The destination of the assignment */ + std::unique_ptr dst_; + + /** The source of the assignment */ + std::unique_ptr src_; +}; + +/** + * The SQLStmtAST class represents a SQL statement. + */ +class SQLStmtAST : public StmtAST { + public: + /** + * Construct a new SQLStmtAST instance. + * @param query The result of parsing the SQL query + * @param variables The collection of identifiers of variables + * to which results of the query are bound + */ + SQLStmtAST(std::unique_ptr &&query, std::vector &&variables) + : StmtAST{NodeType::SQL_STMT}, query_{std::move(query)}, variables_{std::move(variables)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return The result of parsing the SQL query */ + parser::ParseResult *Query() { return query_.get(); } + + /** @return The result of parsing the SQL query */ + const parser::ParseResult *Query() const { return query_.get(); } + + /** @return The variable names to which results are bound */ + const std::vector &Variables() const { return variables_; } + + private: + /** The result of parsing the SQL query */ + std::unique_ptr query_; + + /** The names of the variables to which results are bound */ + std::vector variables_; +}; + +/** + * The DynamicSQLStmtAST class represents a dynamic SQL statement. + */ +class DynamicSQLStmtAST : public StmtAST { + public: + /** + * Construct a new DynamicSQLStmtAST instance. + * @param query The expression that represents the query + * @param name The name of the variable to which results are bound + */ + DynamicSQLStmtAST(std::unique_ptr &&query, std::string name) + : StmtAST{NodeType::DYNAMIC_SQL_STMT}, query_{std::move(query)}, name_{std::move(name)} {} + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return The expression that represents the query */ + const ExprAST *Query() const { return query_.get(); } + + /** @return The name of the variable to which results are bound */ + const std::string &Name() const { return name_; } + + private: + /** The expression that represents the query */ + std::unique_ptr query_; + + /** The name of the variable to which results are bound */ + std::string name_; +}; + +/** + * The FunctionAST class represents a function definition. + */ +class FunctionAST : public AbstractAST { + public: + /** + * Construct a new FunctionAST instance. + * @param body The body of the function + * @param parameter_names The names of the parameters to the function + * @param parameter_types The types of the parameters to the function + */ + FunctionAST(std::unique_ptr &&body, std::vector parameter_names, + std::vector parameter_types) + : AbstractAST{NodeType::FUNCTION}, + body_{std::move(body)}, + parameter_names_{std::move(parameter_names)}, + parameter_types_{std::move(parameter_types)} { + NOISEPAGE_ASSERT(parameter_names_.size() == parameter_types_.size(), "Parameter Name and Type Mismatch"); + // TODO(Kyle): The copies made in this constructor may not be necessary, + // I need to look more closely at the ownership for this data + } + + /** + * AST visitor pattern. + * @param visitor The visitor + */ + void Accept(ASTNodeVisitor *visitor) override { visitor->Visit(this); } + + /** @return The function body */ + StmtAST *Body() { return body_.get(); } + + /** @return The function body */ + const StmtAST *Body() const { return body_.get(); } + + /** The function parameter names */ + const std::vector &ParameterNames() const { return parameter_names_; } + + /** @return The function parameter types */ + const std::vector &ParameterTypes() const { return parameter_types_; } + + private: + /** The body of the function */ + std::unique_ptr body_; + + /** The names of the parameters to the function */ + std::vector parameter_names_; + + /** The types of the parameters to the function */ + std::vector parameter_types_; +}; + +// ---------------------------------------------------------------------------- +// Error Handling Helpers +// ---------------------------------------------------------------------------- + +std::unique_ptr LogError(const char *str); + +} // namespace noisepage::execution::ast::udf diff --git a/src/include/execution/compiler/ast_fwd.h b/src/include/execution/compiler/ast_fwd.h index 1f94ff4ec4..b12eb753ad 100644 --- a/src/include/execution/compiler/ast_fwd.h +++ b/src/include/execution/compiler/ast_fwd.h @@ -12,6 +12,7 @@ class Decl; class FieldDecl; class File; class FunctionDecl; +class LambdaExpr; class Stmt; class StructDecl; class VariableDecl; diff --git a/src/include/execution/compiler/codegen.h b/src/include/execution/compiler/codegen.h index 87b2fadb5c..066f2fb8d0 100644 --- a/src/include/execution/compiler/codegen.h +++ b/src/include/execution/compiler/codegen.h @@ -195,6 +195,11 @@ class CodeGen { */ [[nodiscard]] ast::Expr *Float64Type() const; + /** + * @return The type representation for a TPL lambda. + */ + [[nodiscard]] ast::Expr *LambdaType(ast::Expr *fn_type); + /** * @return The type representation for the provided builtin type. */ @@ -411,6 +416,12 @@ class CodeGen { */ [[nodiscard]] ast::Expr *AccessStructMember(ast::Expr *object, ast::Identifier member); + /** + * Create a break statement. + * @return The statement. + */ + [[nodiscard]] ast::Stmt *Break(); + /** * Create a return statement without a return value. * @return The statement. diff --git a/src/include/execution/compiler/compilation_context.h b/src/include/execution/compiler/compilation_context.h index 254952af51..1ab178089d 100644 --- a/src/include/execution/compiler/compilation_context.h +++ b/src/include/execution/compiler/compilation_context.h @@ -55,12 +55,15 @@ class CompilationContext { * @param mode The compilation mode. * @param override_qid Optional indicating how to override the plan's query id * @param plan_meta_data Query plan meta data (stores cardinality information) + * @param output_callback The lambda utilized as the output callback for the query + * @param context The AST context for the query */ static std::unique_ptr Compile( const planner::AbstractPlanNode &plan, const exec::ExecutionSettings &exec_settings, catalog::CatalogAccessor *accessor, CompilationMode mode = CompilationMode::Interleaved, std::optional override_qid = std::nullopt, - common::ManagedPointer plan_meta_data = nullptr); + common::ManagedPointer plan_meta_data = nullptr, + ast::LambdaExpr *output_callback = nullptr, common::ManagedPointer context = nullptr); /** * Register a pipeline in this context. @@ -82,16 +85,15 @@ class CompilationContext { */ void Prepare(const parser::AbstractExpression &expression); - /** - * @return The code generator instance. - */ + /** @return The code generator instance. */ CodeGen *GetCodeGen() { return &codegen_; } - /** - * @return The query state. - */ + /** @return The query state. */ StateDescriptor *GetQueryState() { return &query_state_; } + /** @return The identifier for the query state variable */ + ast::Identifier GetQueryStateName() const { return query_state_var_; } + /** * @return The translator for the given relational plan node; null if the provided plan node does * not have a translator registered in this context. @@ -104,26 +106,24 @@ class CompilationContext { */ ExpressionTranslator *LookupTranslator(const parser::AbstractExpression &expr) const; - /** - * @return A common prefix for all functions generated in this module. - */ + /** @return A common prefix for all functions generated in this module. */ std::string GetFunctionPrefix() const; - /** - * @return The list of parameters common to all query functions. For now, just the query state. - */ + /** @return The list of parameters common to all query functions. For now, just the query state. */ util::RegionVector QueryParams() const; - /** - * @return The slot in the query state where the execution context can be found. - */ + /** @return The slot in the query state where the execution context can be found. */ ast::Expr *GetExecutionContextPtrFromQueryState(); - /** - * @return The compilation mode. - */ + /** @return The compilation mode. */ CompilationMode GetCompilationMode() const { return mode_; } + /** @return The output callback. */ + ast::LambdaExpr *GetOutputCallback() const { return output_callback_; } + + /** @return `true` if the compilation context has an output callback, `false` otherwise */ + bool HasOutputCallback() const { return output_callback_ != nullptr; } + /** @return True if we should collect counters in TPL, used for Lin's models. */ bool IsCountersEnabled() const { return counters_enabled_; } @@ -136,7 +136,8 @@ class CompilationContext { private: // Private to force use of static Compile() function. explicit CompilationContext(ExecutableQuery *query, query_id_t query_id_, catalog::CatalogAccessor *accessor, - CompilationMode mode, const exec::ExecutionSettings &exec_settings); + CompilationMode mode, const exec::ExecutionSettings &exec_settings, + ast::LambdaExpr *output_callback = nullptr); // Given a plan node, compile it into a compiled query object. void GeneratePlan(const planner::AbstractPlanNode &plan, @@ -175,6 +176,9 @@ class CompilationContext { StateDescriptor query_state_; StateDescriptor::Entry exec_ctx_; + // The output callback. + ast::LambdaExpr *output_callback_; + // The operator and expression translators. std::unordered_map> ops_; std::unordered_map> expressions_; diff --git a/src/include/execution/compiler/executable_query.h b/src/include/execution/compiler/executable_query.h index 94ab1b5d91..51bfa79a45 100644 --- a/src/include/execution/compiler/executable_query.h +++ b/src/include/execution/compiler/executable_query.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -9,7 +10,7 @@ #include "common/managed_pointer.h" #include "execution/ast/ast_fwd.h" #include "execution/exec_defs.h" -#include "execution/vm/vm_defs.h" +#include "execution/vm/execution_mode.h" #include "transaction/transaction_defs.h" namespace noisepage { @@ -37,6 +38,7 @@ class Region; namespace vm { class Module; class ModuleMetadata; +class FunctionInfo; } // namespace vm } // namespace execution @@ -69,6 +71,9 @@ class ExecutableQuery { public: /** * Construct a fragment composed of the given functions from the given module. + * + * This constructor assumes that no file is present for the fragment. + * * @param functions The name of the functions to execute, in order. * @param teardown_fns The name of the teardown functions in the module, in order. * @param module The module that contains the functions. @@ -76,6 +81,16 @@ class ExecutableQuery { Fragment(std::vector &&functions, std::vector &&teardown_fns, std::unique_ptr module); + /** + * Construct a fragment composed of the given functions from the given module. + * @param functions The name of the functions to execute, in order. + * @param teardown_fns The name of the teardown functions in the module, in order. + * @param module The module that contains the functions. + * @param file The file associated with the fragment + */ + Fragment(std::vector &&functions, std::vector &&teardown_fns, + std::unique_ptr module, ast::File *file); + /** * Destructor. */ @@ -93,28 +108,50 @@ class ExecutableQuery { */ bool IsCompiled() const { return module_ != nullptr; } + /** @return The functions in the fragment, in program execution order*/ + const std::vector &GetFunctions() const { return functions_; } + + /** + * Get the metatdata for the bytecode function identified by `name`. + * @param name The name of the function to query. + * @return The function metadata for the specified function, + * or empty optional in the event that the function is not present + */ + std::optional GetFunctionMetadata(const std::string &name) const; + + /** + * @return The file. + */ + ast::File *GetFile() { return file_; } + /** @return The metadata of this module. */ const vm::ModuleMetadata &GetModuleMetadata() const; private: - // The functions that must be run (in the provided order) to execute this - // query fragment. + // The functions that must be run (in the provided order) + // to execute this query fragment. std::vector functions_; - std::vector teardown_fn_; + // The functions that must be run (in the provided order) + // to tear down this query fragment. + std::vector teardown_fns_; // The module. std::unique_ptr module_; + + // The file. + ast::File *file_; }; /** - * Create a query object. + * Construct a new ExecutableQuery instance. * @param plan The physical plan. - * @param exec_settings The execution settings used for this query. + * @param exec_settings The execution settings used for this query * @param timestamp The start timestamp of the transaction that generates this ExecutableQuery + * @param context The AST context for the executable query; may be nullptr */ ExecutableQuery(const planner::AbstractPlanNode &plan, const exec::ExecutionSettings &exec_settings, - transaction::timestamp_t timestamp); + transaction::timestamp_t timestamp, ast::Context *context = nullptr); /** * This class cannot be copied or moved. @@ -153,7 +190,7 @@ class ExecutableQuery { /** * @return The AST context. */ - ast::Context *GetContext() { return ast_context_.get(); } + ast::Context *GetContext() { return ast_context_; } /** @return The execution settings used for this query. */ const exec::ExecutionSettings &GetExecutionSettings() const { return exec_settings_; } @@ -169,27 +206,60 @@ class ExecutableQuery { /** @return The Query Identifier */ query_id_t GetQueryId() { return query_id_; } + /** @brief Set the query state type */ + void SetQueryStateType(ast::StructDecl *query_state_type) { query_state_type_ = query_state_type; } + + /** @return The query state type */ + ast::StructDecl *GetQueryStateType() const { return query_state_type_; } + + /** @param query_text The SQL string for this query */ + void SetQueryText(common::ManagedPointer query_text) { query_text_ = query_text; } + + /** @return The SQL query string */ + common::ManagedPointer GetQueryText() { return query_text_; } + + /** @return All of the function names in the executable query, in program execution order. */ + std::vector GetFunctionNames() const; + + /** @return The metadata for each TPL function in the executable query, in program execution order. */ + std::vector GetFunctionMetadata() const; + + /** @return All of the declarations in the executable query. */ + std::vector GetDecls() const; + /** @return The query fragments in this module. */ const std::vector> &GetFragments() const { return fragments_; } private: // The plan. const planner::AbstractPlanNode &plan_; + // The execution settings used for code generation. const exec::ExecutionSettings &exec_settings_; // The start timestamp of the transaction that generates this ExecutableQuery const transaction::timestamp_t timestamp_; - std::unique_ptr errors_region_; + + // The regions for context and errors std::unique_ptr context_region_; + std::unique_ptr errors_region_; + // The AST error reporter. std::unique_ptr errors_; + // The AST context used to generate the TPL AST. - std::unique_ptr ast_context_; + ast::Context *ast_context_; + // Denotes whether or not the ExecutableQuery owns the AST context. + bool owns_ast_context_; + // The compiled query fragments that make up the query. std::vector> fragments_; + // The query state size. std::size_t query_state_size_; + // The type of the query state. + ast::StructDecl *query_state_type_; + // The pipeline operating units that were generated as part of this query. std::unique_ptr pipeline_operating_units_; @@ -197,8 +267,8 @@ class ExecutableQuery { /** Legacy constructor that creates a hardcoded fragment with main(ExecutionContext*)->int32. */ ExecutableQuery(const std::string &contents, common::ManagedPointer exec_ctx, bool is_file, - size_t query_state_size, const exec::ExecutionSettings &exec_settings, - transaction::timestamp_t timestamp); + std::size_t query_state_size, const exec::ExecutionSettings &exec_settings, + transaction::timestamp_t timestamp, ast::Context *context = nullptr); /** * Set Pipeline Operating Units for use by mini_runners * @param units Pipeline Operating Units @@ -211,9 +281,14 @@ class ExecutableQuery { */ void SetQueryId(query_id_t query_id) { query_id_ = query_id; } + // The name of the query std::string query_name_; + // The query identitifier query_id_t query_id_; + // TODO(Kyle): What is this for? static std::atomic query_identifier; + // The text of the query + common::ManagedPointer query_text_; // MiniRunners needs to set query_identifier and pipeline_operating_units_. friend class noisepage::runner::ExecutionRunners; diff --git a/src/include/execution/compiler/expression/expression_translator.h b/src/include/execution/compiler/expression/expression_translator.h index 2e120a09d9..a91e5074cc 100644 --- a/src/include/execution/compiler/expression/expression_translator.h +++ b/src/include/execution/compiler/expression/expression_translator.h @@ -5,6 +5,7 @@ #include "common/macros.h" #include "execution/ast/ast_fwd.h" #include "execution/compiler/expression/column_value_provider.h" +#include "execution/util/region_containers.h" namespace noisepage::parser { class AbstractExpression; @@ -47,6 +48,26 @@ class ExpressionTranslator { */ virtual ast::Expr *DeriveValue(WorkContext *ctx, const ColumnValueProvider *provider) const = 0; + /** + * Define all of the helper functions for this expression translator. + * + * The default implementation simply invokes the DefineHelperFunctions() + * method for each child of the current expression translator. + * + * @param decls The collection of function declarations. + */ + virtual void DefineHelperFunctions(util::RegionVector *decls); + + /** + * Define all of the helper structs for this expression translator. + * + * The default implementation simply invokes the DefineHelperStructs() + * method for each child of the current expression translator. + * + * @param decls The collection of struct declarations. + */ + virtual void DefineHelperStructs(util::RegionVector *decls); + /** * @return The expression being translated. */ diff --git a/src/include/execution/compiler/expression/function_translator.h b/src/include/execution/compiler/expression/function_translator.h index 6cdcfe18ea..a4f645a05f 100644 --- a/src/include/execution/compiler/expression/function_translator.h +++ b/src/include/execution/compiler/expression/function_translator.h @@ -1,6 +1,11 @@ #pragma once +#include +#include + #include "execution/compiler/expression/expression_translator.h" +#include "execution/functions/function_context.h" +#include "execution/util/region_containers.h" namespace noisepage::parser { class FunctionExpression; @@ -27,6 +32,22 @@ class FunctionTranslator : public ExpressionTranslator { * @return The value of the expression. */ ast::Expr *DeriveValue(WorkContext *ctx, const ColumnValueProvider *provider) const override; + + /** + * Define the helper functions for this function translator. + * @param decls The collection of helper function declarations + */ + void DefineHelperFunctions(util::RegionVector *decls) override; + + /** + * Define the helper structs for this function translator. + * @param decls The collection of helper struct declarations + */ + void DefineHelperStructs(util::RegionVector *decls) override; + + private: + std::vector params_; + ast::Identifier main_fn_; }; } // namespace noisepage::execution::compiler diff --git a/src/include/execution/compiler/function_builder.h b/src/include/execution/compiler/function_builder.h index 1bfb6cc611..3353977b15 100644 --- a/src/include/execution/compiler/function_builder.h +++ b/src/include/execution/compiler/function_builder.h @@ -2,7 +2,10 @@ #include #include +#include +#include +#include "common/macros.h" #include "execution/ast/identifier.h" #include "execution/compiler/ast_fwd.h" #include "execution/util/region_containers.h" @@ -11,6 +14,9 @@ namespace noisepage::execution::compiler { class CodeGen; +/** Enumerates the function types */ +enum class FunctionType { FUNCTION, CLOSURE }; + /** * Helper class to build TPL functions. */ @@ -20,76 +26,106 @@ class FunctionBuilder { public: /** - * Create a builder for a function with the provided name, return type, and arguments. - * @param codegen The code generation instance. - * @param name The name of the function. - * @param params The parameters to the function. - * @param ret_type The return type representation of the function. + * Construct a new FunctionBuilder instance for a "vanilla" function. + * @param codegen The code generation instance + * @param name The function name + * @param params The function parameters + * @param return_type The return type representation of the function */ FunctionBuilder(CodeGen *codegen, ast::Identifier name, util::RegionVector &¶ms, - ast::Expr *ret_type); + ast::Expr *return_type); /** - * Destructor. + * Construct a new FunctionBuilder instance for a closure. + * @param codegen The code generation instance + * @param params The function parameters + * @param captures The function captures + * @param return_type The return type representation of the function */ + FunctionBuilder(CodeGen *codegen, util::RegionVector &¶ms, + util::RegionVector &&captures, ast::Expr *return_type); + + /** Destructor; invokes FunctionBuilder::Finish() */ ~FunctionBuilder(); - /** - * @return A reference to a function parameter by its ordinal position. - */ - ast::Expr *GetParameterByPosition(uint32_t param_idx); + /** @return The arity of the function */ + std::size_t GetParameterCount() const { return params_.size(); } + + /** @return A reference to a function parameter by its ordinal position */ + ast::Expr *GetParameterByPosition(std::size_t param_idx); + + /** @return The expression representation of the parameters to the function */ + std::vector GetParameters() const; /** * Append a statement to the list of statements in this function. - * @param stmt The statement to append. + * @param stmt The statement to append */ void Append(ast::Stmt *stmt); /** * Append an expression as a statement to the list of statements in this function. - * @param expr The expression to append as a statement. + * @param expr The expression to append as a statement */ void Append(ast::Expr *expr); /** * Append a variable declaration as a statement to the list of statements in this function. - * @param decl The declaration to append to the statement. + * @param decl The declaration to append to the statement */ void Append(ast::VariableDecl *decl); /** - * Finish constructing the function. - * @param ret The value to return from the function. Use a null pointer to return nothing. - * @return The build function declaration. + * Finish construction of the function. + * @param ret The function return value; use `nullptr` for `nil` return + * @return The finished declaration */ ast::FunctionDecl *Finish(ast::Expr *ret = nullptr); /** - * @return The final constructed function; null if the builder hasn't been constructed through - * FunctionBuilder::Finish(). + * Finish construction of the closure. + * @param ret The function return value; use `nullptr` for `nil` return + * @return The finished expression */ - ast::FunctionDecl *GetConstructedFunction() const { return decl_; } + ast::LambdaExpr *FinishClosure(ast::Expr *ret = nullptr); - /** - * @return The code generator instance. - */ + /** @return The final constructed function */ + ast::FunctionDecl *GetFinishedFunction() const { + NOISEPAGE_ASSERT(type_ == FunctionType::FUNCTION, "Attempt to get function from non-function-type builder"); + return std::get(decl_); + } + + /** @return The final constructed closure */ + ast::LambdaExpr *GetFinishedClosure() const { + NOISEPAGE_ASSERT(type_ == FunctionType::CLOSURE, "Attempt to get closure from non-closure-type builder"); + return std::get(decl_); + } + + /** @return The code generator instance. */ CodeGen *GetCodeGen() const { return codegen_; } + /** @return `true` if the function is a closure, `false` otherwise */ + bool IsClosure() const { return type_ == FunctionType::CLOSURE; } + private: - // The code generation instance. + /** The type of the function */ + FunctionType type_; + /** The code generation instance */ CodeGen *codegen_; - // The function's name. + /** The function's name */ ast::Identifier name_; - // The function's arguments. + /** The function's arguments */ util::RegionVector params_; - // The return type of the function. - ast::Expr *ret_type_; - // The start and stop position of statements in the function. + /** The captures for the closure (if applicable) */ + util::RegionVector captures_; + /** The return type of the function */ + ast::Expr *return_type_; + /** The start and stop position of statements in the function */ SourcePosition start_; - // The list of generated statements making up the function. + /** The list of generated statements making up the function */ ast::BlockStmt *statements_; - // The cached function declaration. Constructed once in Finish(). - ast::FunctionDecl *decl_; + /** The cached, completed function; constructed once in Finish() */ + std::variant decl_; }; } // namespace noisepage::execution::compiler diff --git a/src/include/execution/compiler/operator/operator_translator.h b/src/include/execution/compiler/operator/operator_translator.h index ddcc1c07a1..7309cf52b5 100644 --- a/src/include/execution/compiler/operator/operator_translator.h +++ b/src/include/execution/compiler/operator/operator_translator.h @@ -111,14 +111,14 @@ class OperatorTranslator : public ColumnValueProvider { * declaration container. * @param decls Query-level declarations. */ - virtual void DefineHelperStructs(util::RegionVector *decls) {} + virtual void DefineHelperStructs(util::RegionVector *decls); /** * Define any helper functions required for processing. Ensure they're declared in the provided * declaration container. * @param decls Query-level declarations. */ - virtual void DefineHelperFunctions(util::RegionVector *decls) {} + virtual void DefineHelperFunctions(util::RegionVector *decls); /** * Define any helper functions that rely on pipeline's thread local state. @@ -270,9 +270,23 @@ class OperatorTranslator : public ColumnValueProvider { /** Get the memory pool pointer from the execution context stored in the query state. */ ast::Expr *GetMemoryPool() const; - /** The pipeline this translator is a part of. */ + /** @return The pipeline this translator is a part of. */ Pipeline *GetPipeline() const { return pipeline_; } + /** + * Make a local identifier from `name`. + * @param name The base name for the identifier + * @return The identifier + */ + ast::Identifier MakeLocalIdentifier(std::string_view name) const; + + /** + * Make a global identifier from `name`. + * @param name The base name for the identifier + * @return The identifier + */ + ast::Identifier MakeGlobalIdentifier(std::string_view name) const; + /** The plan node for this translator as its concrete type. */ template const T &GetPlanAs() const { @@ -361,7 +375,7 @@ class OperatorTranslator : public ColumnValueProvider { const planner::AbstractPlanNode &plan_; // The compilation context. CompilationContext *compilation_context_; - // The pipeline the operator belongs to. + // The pipeline to which the operator belongs. Pipeline *pipeline_; /** The child operator translator. */ diff --git a/src/include/execution/compiler/operator/output_translator.h b/src/include/execution/compiler/operator/output_translator.h index 358aaa9c8e..f580791e26 100644 --- a/src/include/execution/compiler/operator/output_translator.h +++ b/src/include/execution/compiler/operator/output_translator.h @@ -32,13 +32,13 @@ class OutputTranslator : public OperatorTranslator { */ DISALLOW_COPY_AND_MOVE(OutputTranslator); - /** - * Define the output struct. - */ + /** Define the output struct. */ void DefineHelperStructs(util::RegionVector *decls) override; + /** Initialize pipeline state for the output translator */ void InitializePipelineState(const Pipeline &pipeline, FunctionBuilder *function) const override; + /** Teardown pipeline state for the output translator */ void TearDownPipelineState(const Pipeline &pipeline, FunctionBuilder *function) const override; void InitializeCounters(const Pipeline &pipeline, FunctionBuilder *function) const override; @@ -46,26 +46,27 @@ class OutputTranslator : public OperatorTranslator { void EndParallelPipelineWork(const Pipeline &pipeline, FunctionBuilder *function) const override; void FinishPipelineWork(const Pipeline &pipeline, FunctionBuilder *function) const override; - /** - * Perform the main work of the translator. - */ + /** Perform the main work of the translator. */ void PerformPipelineWork(WorkContext *context, FunctionBuilder *function) const override; - /** - * Does not interact with tables. - */ + /** Does not interact with tables. */ ast::Expr *GetTableColumn(catalog::col_oid_t col_oid) const override { UNREACHABLE("Output does not interact with tables."); } + /** @return `true` if the output translator has an associated output callback, `false` otherwise */ + bool HasOutputCallback() const; + private: + /** The output variable */ ast::Identifier output_var_; + /** The output structure */ ast::Identifier output_struct_; - // The number of rows that are output. + /** The number of rows that are output */ StateDescriptor::Entry num_output_; - // The OutputBuffer to use + /** The OutputBuffer to use */ StateDescriptor::Entry output_buffer_; }; diff --git a/src/include/execution/compiler/pipeline.h b/src/include/execution/compiler/pipeline.h index a298203a00..24c45e689b 100644 --- a/src/include/execution/compiler/pipeline.h +++ b/src/include/execution/compiler/pipeline.h @@ -52,9 +52,7 @@ class Pipeline { */ enum class Parallelism : uint8_t { Serial = 0, Parallel = 2 }; - /** - * Enum class representing whether the pipeline is vectorized. - */ + /** Enum class representing whether the pipeline is vectorized. */ enum class Vectorization : uint8_t { Disabled = 0, Enabled = 1 }; /** @@ -118,10 +116,10 @@ class Pipeline { /** * Registers a nested pipeline. These pipelines are invoked from other pipelines and are not added to the main steps - * @param nested_pipeline The pipeline to nest + * @param pipeline The pipeline to nest * @param op The operator translator that is nesting this pipeline */ - void LinkNestedPipeline(Pipeline *nested_pipeline, const OperatorTranslator *op); + void LinkNestedPipeline(Pipeline *pipeline, const OperatorTranslator *op); /** * Store in the provided output vector the set of all dependencies for this pipeline. In other @@ -131,6 +129,14 @@ class Pipeline { */ void CollectDependencies(std::vector *deps); + /** + * Store in the provided output vector the set of all dependencies for this pipeline. In other + * words, store in the output vector all pipelines that must execute (in order) before this + * pipeline can begin. + * @param[out] deps The sorted list of pipelines to execute before this pipeline can begin. + */ + void CollectDependencies(std::vector *deps) const; + /** * Perform initialization logic before code generation. * @param exec_settings The execution settings used for query compilation. @@ -143,50 +149,32 @@ class Pipeline { */ void GeneratePipeline(ExecutableQueryFragmentBuilder *builder) const; - /** - * @return True if the pipeline is parallel; false otherwise. - */ + /** @return `true` if the pipeline is parallel, `false` otherwise. */ bool IsParallel() const { return parallelism_ == Parallelism ::Parallel; } - /** - * @return True if this pipeline is fully vectorized; false otherwise. - */ + /** @return `true` if this pipeline is fully vectorized, `false` otherwise. */ bool IsVectorized() const { return false; } - /** - * Typedef used to specify an iterator over the steps in a pipeline. - */ + /** Typedef used to specify an iterator over the steps in a pipeline. */ using StepIterator = std::vector::const_reverse_iterator; - /** - * @return An iterator over the operators in the pipeline. - */ + /** @return An iterator over the operators in the pipeline. */ StepIterator Begin() const { return steps_.rbegin(); } - /** - * @return An iterator positioned at the end of the operators steps in the pipeline. - */ + /** @return An iterator positioned at the end of the operators steps in the pipeline. */ StepIterator End() const { return steps_.rend(); } - /** - * @return True if the given operator is the driver for this pipeline; false otherwise. - */ + /** @return True if the given operator is the driver for this pipeline; false otherwise. */ bool IsDriver(const PipelineDriver *driver) const { return driver == driver_; } - /** - * @return Arguments common to all pipeline functions. - */ + /** @return The arguments common to all pipeline functions. */ util::RegionVector PipelineParams() const; - /** - * @return A unique name for a function local to this pipeline. - */ - std::string CreatePipelineFunctionName(const std::string &func_name) const; + /** @return An identifier for the pipeline state variable */ + ast::Identifier GetPipelineStateName() const; - /** - * @return A vector of expressions that initialize, run and teardown a nested pipeline - */ - std::vector GenerateSingleRunPipelineFunction() const; + /** @return A unique name for a function local to this pipeline. */ + std::string CreatePipelineFunctionName(const std::string &func_name) const; /** * Calls a nested pipeline's execution functions @@ -196,11 +184,6 @@ class Pipeline { */ void CallNestedRunPipelineFunction(WorkContext *ctx, const OperatorTranslator *op, FunctionBuilder *function) const; - /** - * @return Pipeline state variable - */ - ast::Identifier GetPipelineStateVar() { return state_var_; } - /** @return The unique ID of this pipeline. */ pipeline_id_t GetPipelineId() const { return pipeline_id_t{id_}; } @@ -218,118 +201,172 @@ class Pipeline { */ void InjectEndResourceTracker(FunctionBuilder *builder, bool is_hook) const; - /** - * @return query identifier of the query that we are codegen-ing - */ + /** @return The identifier for the query that we are codegen-ing */ query_id_t GetQueryId() const; - /** - * @return a pointer to the OUFeatureVector in the pipeline state - */ + /** @return A pointer to the OUFeatureVector in the pipeline state */ ast::Expr *OUFeatureVecPtr() const { return oufeatures_.GetPtr(codegen_); } /** * Gets an argument from the set of "extra" pipeline arguments given to the current pipeline's function * Only applicable if this is a nested pipeline. Extra refers to arguments other than the query state and the - * pipeline state + * pipeline state. * @param index The extra argument index * @return An expression representing the requested argument */ - ast::Expr *GetNestedInputArg(uint32_t index) const; + ast::Expr *GetNestedInputArg(std::size_t index) const; - /** - * @return true iff this pipeline has already been prepared - */ + /** @return `true` if this pipeline is prepared, `false` otherwise */ bool IsPrepared() const { return prepared_; } + /** @return The output callback for the pipeline */ + ast::LambdaExpr *GetOutputCallback() const; + + /** @return `true` if this pipeline has an output callback, `false` otherwise */ + bool HasOutputCallback() const; + private: - // Return the thread-local state initialization and tear-down function names. - // This is needed when we invoke @tlsReset() from the pipeline initialization - // function to setup the thread-local state. - ast::Identifier GetSetupPipelineStateFunctionName() const; - ast::Identifier GetTearDownPipelineStateFunctionName() const; - ast::Identifier GetWorkFunctionName() const; + // Internals which are exposed for minirunners. + friend class compiler::CompilationContext; + friend class selfdriving::OperatingUnitRecorder; + + /* -------------------------------------------------------------------------- + Pipeline Function Generation + -------------------------------------------------------------------------- */ - // Generate the pipeline state initialization logic. - ast::FunctionDecl *GenerateSetupPipelineStateFunction() const; + /** + * Generate code to initialize pipeline state. + * @return The function declaration for the generated function + */ + ast::FunctionDecl *GenerateInitPipelineStateFunction() const; - // Generate the pipeline state cleanup logic. + /** + * Generate code to teardown pipeline state. + * @return The function declaration for the generated function + */ ast::FunctionDecl *GenerateTearDownPipelineStateFunction() const; - // Generate pipeline initialization logic. - // @warning TLS is NOT reset if this is a nested pipeline - ast::FunctionDecl *GenerateInitPipelineFunction() const; + /** + * Generate code to wrap top-level pipeline calls. + * NOTE: Currently only used for pipelines with output callback. + * @return The function declaration for the generated function + */ + ast::FunctionDecl *GeneratePipelineRunAllFunction() const; - // Generate the main pipeline work function. - ast::FunctionDecl *GeneratePipelineWorkFunction() const; + /** + * Generate code to initialize the pipeline. + * @return The function declaration for the generated function + */ + ast::FunctionDecl *GenerateInitPipelineFunction() const; - // Generate the main pipeline logic. + /** + * Generate code to run primary pipeline logic. + * @return The function declaration for the generated function + */ ast::FunctionDecl *GenerateRunPipelineFunction() const; - // Generate pipeline tear-down logic. + /** + * Generate code to perform pipeline work. + * @return The function declaration for the generated function + */ + ast::FunctionDecl *GeneratePipelineWorkFunction() const; + + /** + * Generate code to teardown the pipeline. + * @return The function declaration for the generated function + */ ast::FunctionDecl *GenerateTearDownPipelineFunction() const; - // Marks this pipeline as nested + /* -------------------------------------------------------------------------- + Pipeline Function Parameter Definition + -------------------------------------------------------------------------- */ + + /** @return The arguments common to all query functions */ + util::RegionVector QueryParams() const; + + /* -------------------------------------------------------------------------- + Nested Pipelines + -------------------------------------------------------------------------- */ + + /** @brief Indicate that this pipeline is nested. */ void MarkNested() { nested_ = true; } - private: - // Internals which are exposed for minirunners. - friend class compiler::CompilationContext; - friend class selfdriving::OperatingUnitRecorder; + /** @return `true` if this is a nested pipeline, `false` otherwise */ + bool IsNestedPipeline() const { return nested_; } - /** @return The vector of pipeline operators that make up the pipeline. */ - const std::vector &GetTranslators() const { return steps_; } + /* -------------------------------------------------------------------------- + Pipeline Variable and Function Identifiers + -------------------------------------------------------------------------- */ - void InjectStartPipelineTracker(FunctionBuilder *builder) const; + /** @return An identifier for the query state variable */ + ast::Identifier GetQueryStateName() const; - void InjectEndResourceTracker(FunctionBuilder *builder, query_id_t query_id) const; + /** @return An identifier for the `InitPipelineState` function */ + ast::Identifier GetInitPipelineStateFunctionName() const; - void CollectDependencies(std::vector *deps) const; + /** @return An identifier for the `TeardownPipelineState` function */ + ast::Identifier GetTearDownPipelineStateFunctionName() const; + + /** @return An identifier for the pipeline `RunAll` function */ + ast::Identifier GetRunAllPipelineFunctionName() const; + /** @return An identifier for the pipeline `Init` function */ ast::Identifier GetInitPipelineFunctionName() const; + + /** @return An identifier for the pipeline `Run` function */ ast::Identifier GetRunPipelineFunctionName() const; + + /** @return An identifier for the pipeline `Teardown` function */ ast::Identifier GetTeardownPipelineFunctionName() const; + /** @return An identifier for the pipeline `Work` function (serial or parallel) */ + ast::Identifier GetPipelineWorkFunctionName() const; + + /** @return An immutable reference to the pipeline state descriptor */ const StateDescriptor &GetPipelineStateDescriptor() const { return state_; } + /** @return A mutable reference to the pipeline state descriptor */ StateDescriptor &GetPipelineStateDescriptor() { return state_; } + /* -------------------------------------------------------------------------- + Additional Helpers + -------------------------------------------------------------------------- */ + + /** @return The vector of pipeline operators that make up the pipeline. */ + const std::vector &GetTranslators() const { return steps_; } + private: - // A unique pipeline ID. + /** A unique pipeline ID. */ uint32_t id_; - // The compilation context this pipeline is part of. + /** The compilation context this pipeline is part of. */ CompilationContext *compilation_context_; - // The code generation instance. + /** The code generation instance. */ CodeGen *codegen_; - - // Cache of common identifiers. - ast::Identifier state_var_; - // The pipeline state. + /** The pipeline state. */ StateDescriptor state_; - // The pipeline operating unit feature vector state. + /** The pipeline operating unit feature vector state. */ StateDescriptor::Entry oufeatures_; - // Operators making up the pipeline. + /** Operators making up the pipeline. */ std::vector steps_; - // The driver. + /** The driver. */ PipelineDriver *driver_; - // pointer to parent pipeline (only applicable if this is a nested pipeline) + /** pointer to parent pipeline (only applicable if this is a nested pipeline) */ Pipeline *parent_; - // Expressions participating in the pipeline. + /** Expressions participating in the pipeline. */ std::vector expressions_; - // All unnested pipelines this one depends on completion of. + /** All unnested pipelines this one depends on completion of. */ std::vector dependencies_; - // Vector of pipelines that are nested under this pipeline + /** Vector of pipelines that are nested under this pipeline. */ std::vector nested_pipelines_; - // Extra parameters to pass into pipeline, currently used for nested - // consumer pipeline work functions + /** Extra parameters to passed into pipeline functions; used for nested consumer pipeline work. */ std::vector extra_pipeline_params_; - // Configured parallelism. + /** Configured parallelism. */ Parallelism parallelism_; - // Whether to check for parallelism in new pipeline elements. + /** Whether to check for parallelism in new pipeline elements. */ bool check_parallelism_; - // Whether or not this is a nested pipeline + /** Whether or not this is a nested pipeline. */ bool nested_; - // Whether or not this pipeline has been prepared already + /** Whether or not this pipeline is prepared. */ bool prepared_{false}; }; diff --git a/src/include/execution/compiler/udf/udf_codegen.h b/src/include/execution/compiler/udf/udf_codegen.h new file mode 100644 index 0000000000..368a78b59b --- /dev/null +++ b/src/include/execution/compiler/udf/udf_codegen.h @@ -0,0 +1,526 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "execution/ast/udf/udf_ast_context.h" +#include "execution/ast/udf/udf_ast_node_visitor.h" +#include "execution/compiler/codegen.h" +#include "execution/compiler/function_builder.h" +#include "execution/functions/function_context.h" +#include "planner/plannodes/abstract_join_plan_node.h" + +namespace noisepage::catalog { +class CatalogAccessor; +} // namespace noisepage::catalog + +namespace noisepage::optimizer { +class OptimizeResult; +} // namespace noisepage::optimizer + +namespace noisepage::parser::udf { +class VariableRef; +} // namespace noisepage::parser::udf + +namespace noisepage::execution { + +namespace compiler { +class ExecutableQuery; +} // namespace compiler + +namespace vm { +class FunctionInfo; +} // namespace vm + +// Forward declarations +namespace ast::udf { +class AbstractAST; +class StmtAST; +class ExprAST; +class ValueExprAST; +class VariableExprAST; +class BinaryExprAST; +class CallExprAST; +class MemberExprAST; +class SeqStmtAST; +class DeclStmtAST; +class IfStmtAST; +class WhileStmtAST; +class RetStmtAST; +class AssignStmtAST; +class SQLStmtAST; +class FunctionAST; +class IsNullExprAST; +class DynamicSQLStmtAST; +class ForSStmtAST; +} // namespace ast::udf + +namespace compiler::udf { + +class ExpressionResultScope; + +/** + * The UdfCodegen class implements a visitor for UDF AST + * nodes and encapsulates all of the logic required to generate + * code from the UDF abstract syntax tree. + */ +class UdfCodegen : ast::udf::ASTNodeVisitor { + public: + /** + * Construct a new UdfCodegen instance. + * @param accessor The catalog accessor used in code generation + * @param fb The function builder instance used for the UDF + * @param udf_ast_context The AST context for the UDF + * @param codegen The codegen instance + * @param db_oid The OID for the relevant database + */ + UdfCodegen(catalog::CatalogAccessor *accessor, FunctionBuilder *fb, ast::udf::UdfAstContext *udf_ast_context, + CodeGen *codegen, catalog::db_oid_t db_oid); + + /** Destroy the UDF code generation context. */ + ~UdfCodegen() override = default; + + /** + * Run UDF code generation. + * @param accessor The catalog accessor + * @param function_builder The function builder to use during code generation + * @param ast_context The UDF AST context + * @param codegen The code generation instance + * @param db_oid The database OID + * @param root The root of the UDF AST for which code is generated + * @return The file containing the generated code + */ + static execution::ast::File *Run(catalog::CatalogAccessor *accessor, FunctionBuilder *function_builder, + ast::udf::UdfAstContext *ast_context, CodeGen *codegen, catalog::db_oid_t db_oid, + ast::udf::FunctionAST *root); + + private: + /** + * Generate a UDF from the given abstract syntax tree. + * @param ast The AST from which to generate the UDF + */ + void GenerateUDF(ast::udf::AbstractAST *ast); + + /** + * Visit an AbstractAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::AbstractAST *ast) override; + + /** + * Visit a FunctionAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::FunctionAST *ast) override; + + /** + * Visit a StmtAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::StmtAST *ast) override; + + /** + * Visit an ExprAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::ExprAST *ast) override; + + /** + * Visit a ValueExprAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::ValueExprAST *ast) override; + + /** + * Visit a VariableExprAST node. + */ + void Visit(ast::udf::VariableExprAST *ast) override; + + /** + * Visit a BinaryExprAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::BinaryExprAST *ast) override; + + /** + * Visit a CallExprAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::CallExprAST *ast) override; + + /** + * Visit an IsNullExprAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::IsNullExprAST *ast) override; + + /** + * Visit a SeqStmtAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::SeqStmtAST *ast) override; + + /** + * Visit a DeclStmtNode node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::DeclStmtAST *ast) override; + + /** + * Visit a IfStmtAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::IfStmtAST *ast) override; + + /** + * Visit a WhileStmtAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::WhileStmtAST *ast) override; + + /** + * Visit a RetStmtAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::RetStmtAST *ast) override; + + /** + * Visit an AssignStmtAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::AssignStmtAST *ast) override; + + /** + * Visit a SQLStmtAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::SQLStmtAST *ast) override; + + /** + * Visit a DynamicSQLStmtAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::DynamicSQLStmtAST *ast) override; + + /** + * Visit a ForIStmtAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::ForIStmtAST *ast) override; + + /** + * Visit a ForSStmtAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::ForSStmtAST *ast) override; + + /** + * Visit a MemberExprAST node. + * @param ast The AST node to visit + */ + void Visit(ast::udf::MemberExprAST *ast) override; + + /** + * Complete UDF code generation. + * @return The result of code generation as a file + */ + execution::ast::File *Finish(); + + /** + * Return the string that represents the return value. + * @return The string that represents the return value + */ + static const char *GetReturnParamString(); + + private: + /* -------------------------------------------------------------------------- + Code Generation: Function Calls + -------------------------------------------------------------------------- */ + + /** + * Resolve the type of an expression. + * @param expr The expression + * @return The resolved type + */ + sql::SqlTypeId ResolveType(const ast::Expr *expr) const; + + /** + * Resolve the type of a literal expression in a function call argument. + * @param expr The literal expression + * @return The resolved type of the literal expression + */ + sql::SqlTypeId ResolveTypeForLiteralExpression(const ast::LitExpr *expr) const; + + /** + * Resolve the type of a binary expression in a function call argument. + * @param expr The binary expression + * @return The resolved type of the binary expression + */ + sql::SqlTypeId ResolveTypeForBinaryExpression(const ast::BinaryOpExpr *expr) const; + + /** + * Resolve the type of an identifier expression in a function call argument. + * @param expr The identifier expression + * @return The resolved type of the identifier expression + */ + sql::SqlTypeId ResolveTypeForIdentifierExpression(const ast::IdentifierExpr *expr) const; + + /** + * Resolve the type of a call expression in a function call argument. + * @param expr The call expression + * @return The resolved type of the call expression + */ + sql::SqlTypeId ResolveTypeForCallExpression(const ast::CallExpr *expr) const; + + /* -------------------------------------------------------------------------- + Code Generation: For-S Loops + -------------------------------------------------------------------------- */ + + /** + * Begin construction of a lambda that writes the output of the query + * represented by `plan` into the variables identified by `variables`. + * @param plan The query plan + * @param variables The names of the variables to which results are bound + * @return The unfinished function builder for the lambda + */ + std::unique_ptr StartLambda(common::ManagedPointer plan, + const std::vector &variables); + + /** + * Begin construction of a lambda that writes the output of the query + * represented by `plan` into a single RECORD-type variable. + * @param plan The query plan + * @param variables The names of the variables to which results are bound + * @return The unfinished function builder for the lambda + */ + std::unique_ptr StartLambdaBindingToRecord(common::ManagedPointer plan, + const std::vector &variables); + + /** + * Begin construction of a lambda that writes the output of the query + * represented by `plan` into one or more non-RECORD variables. + * @param plan The query plan + * @param variables The names of the variables to which results are bound + * @return The unfinished function builder for the lambda + */ + std::unique_ptr StartLambdaBindingToScalars(common::ManagedPointer plan, + const std::vector &variables); + + /* -------------------------------------------------------------------------- + Code Generation: SQL Statements + -------------------------------------------------------------------------- */ + + /** + * Construct a lambda expression that writes the output of the query + * represented by `plan` into the variables identified by `variables`. + * @param plan The query plan + * @param variables The names of the variables to which results are bound + * @return The finished lambda expression + */ + ast::LambdaExpr *MakeLambda(common::ManagedPointer plan, + const std::vector &variables); + + /** + * Construct a lambda expression that writes the output of the query + * represented by `plan` into a single RECORD-type variable. + * @param plan The query plan + * @param variables The names of the variables to which results are bound + * @return The finished lambda expression + */ + ast::LambdaExpr *MakeLambdaBindingToRecord(common::ManagedPointer plan, + const std::vector &variables); + + /** + * Construct a lambda expression that writes the output of the query + * represented by `plan` into one or more non-RECORD variables. + * @param plan The query plan + * @param variables The names of the variables to which results are bound + * @return The finished lambda expression + */ + ast::LambdaExpr *MakeLambdaBindingToScalars(common::ManagedPointer plan, + const std::vector &variables); + + /* -------------------------------------------------------------------------- + Code Generation: Common + -------------------------------------------------------------------------- */ + + /** + * Generate code to add query parameters to the execution context. + * @param exec_ctx The execution context expression + * @param variable_refs The collection of variable references + */ + void CodegenAddParameters(ast::Expr *exec_ctx, const std::vector &variable_refs); + + /** + * Generate code to add a scalar parameter to the execution context. + * @param exec_ctx The execution context + * @param variable_ref The variable reference + */ + void CodegenAddScalarParameter(ast::Expr *exec_ctx, const parser::udf::VariableRef &variable_ref); + + /** + * Generate code to add a non-scalar parameter to the execution context. + * @param exec_ctx The execution context + * @param variable_ref The variable reference + */ + void CodegenAddTableParameter(ast::Expr *exec_ctx, const parser::udf::VariableRef &variable_ref); + + /** + * Generate code to initialize bound variables. + * @param plan The query plan + * @param bound_variables The variables to which results of the query are bound + */ + void CodegenBoundVariableInit(common::ManagedPointer plan, + const std::vector &bound_variables); + + /** + * Generate code to initialize bound scalar variables. + * @param plan The query plan + * @param bound_variables The name(s) of the scalar variables to which results of the query are bound + */ + void CodegenBoundVariableInitForScalars(common::ManagedPointer plan, + const std::vector &bound_variables); + + /** + * Generate code to initialize a bound record variable. + * @param plan The query plan + * @param record_name The name of the record variable to which results of the query are bound + */ + void CodegenBoundVariableInitForRecord(common::ManagedPointer plan, + const std::string &record_name); + + /** + * Generate code to invoke each top-level function in the executable query. + * @param exec_query The executable query for which calls are generated + * @param query_state_id The identifier for the query state + * @param lambda_id The identifier for the lambda expression that + * is used as an output callback in the query + */ + void CodegenTopLevelCalls(const ExecutableQuery *exec_query, ast::Identifier query_state_id, + ast::Identifier lambda_id); + + /** + * Translate a SQL type to its corresponding catalog type. + * @param type The SQL type of interest + * @return The corresponding catalog type + */ + catalog::type_oid_t GetCatalogTypeOidFromSQLType(execution::sql::SqlTypeId type); + + /** + * Translate a builtin type Kind to its corresponding catalog type. + * @param type The builtin type of interst + * @return The corresponding catalog type + */ + catalog::type_oid_t GetCatalogTypeFromBuiltinKind(execution::ast::BuiltinType::Kind type); + + /** @return A mutable reference to the symbol table */ + std::unordered_map &SymbolTable() { return symbol_table_; } + + /** @return An immutable reference to the symbol table */ + const std::unordered_map &SymbolTable() const { return symbol_table_; } + + /** + * Get the type of the variable identified by `name`. + * @param name The name of the variable + * @return The type of the variable identified by `name` + * @throw EXECUTION_EXCEPTION on failure to resolve type + */ + sql::SqlTypeId GetVariableType(const std::string &name) const; + + /** + * Get the type of the record variable identified by `name`. + * @param name The name of the variable + * @return The type of the record variable identified by `name` + * @throw EXECUTION_EXCEPTION on failure to resolve type + */ + std::vector> GetRecordType(const std::string &name) const; + + /** + * Bind the query and return the variable references. + * @param query The parsed query + * @return The collection of variable references + */ + std::vector BindQueryAndGetVariableRefs(parser::ParseResult *query); + + /** + * Run the optimizer on an embedded SQL query. + * @param parsed_query The result of parsing the query + * @param variable_refs The vector of variable references within query + * @return The optimized result + */ + std::unique_ptr OptimizeEmbeddedQuery( + parser::ParseResult *parsed_query, const std::vector &variable_refs); + + /** + * Determine if the function described by the given metdata is a + * top-level run function that accepts an output callback argument. + * @param function_metatdata The function metadata + * @return `true` if the function meets the above criteria, `false` otherwise + */ + static bool IsRunAllFunction(const std::string &name); + + /** + * Get the builtin parameter-add function for the specified parameter type. + * @param parameter_type The parameter type + * @return The builtin function to add this parameter + */ + static ast::Builtin AddParamBuiltinForParameterType(sql::SqlTypeId parameter_type); + + /** @return The execution context provided to the function */ + ast::Expr *GetExecutionContext(); + + /** @return The current execution result expression */ + ast::Expr *GetExecutionResult(); + + /** + * Set the current execution result expression. + * @param The execution result expression + */ + void SetExecutionResult(ast::Expr *result); + + /** + * Stage evaluation of the expression `expr` by generating + * code to perform the evaluation (at runtime). + * @param expr The expression to evaluate + * @return The result of evaluating the expression + */ + ast::Expr *EvaluateExpression(ast::udf::ExprAST *expr); + + private: + /** The string identifier for internal declarations */ + constexpr static const char INTERNAL_DECL_ID[] = "*internal*"; + + /** The catalog access used during code generation */ + catalog::CatalogAccessor *accessor_; + + /** The function builder used during code generation */ + FunctionBuilder *fb_; + + /** The AST context for the UDF */ + ast::udf::UdfAstContext *udf_ast_context_; + + /** The code generation instance */ + CodeGen *codegen_; + + /** The OID of the relevant database */ + catalog::db_oid_t db_oid_; + + /** Auxiliary declarations */ + execution::util::RegionVector aux_decls_; + + /** The current type during code generation */ + sql::SqlTypeId current_type_{sql::SqlTypeId::Invalid}; + + /** The current execution result expression */ + execution::ast::Expr *execution_result_; + + /** Map from human-readable string identifier to internal identifier */ + std::unordered_map symbol_table_; +}; + +} // namespace compiler::udf +} // namespace noisepage::execution diff --git a/src/include/execution/exec/execution_context.h b/src/include/execution/exec/execution_context.h index 77537059ba..d7eb1ba5a9 100644 --- a/src/include/execution/exec/execution_context.h +++ b/src/include/execution/exec/execution_context.h @@ -1,9 +1,12 @@ #pragma once #include +#include +#include #include #include +#include "catalog/catalog_defs.h" #include "common/managed_pointer.h" #include "execution/exec/execution_settings.h" #include "execution/exec/output.h" @@ -11,7 +14,9 @@ #include "execution/sql/memory_tracker.h" #include "execution/sql/runtime_types.h" #include "execution/sql/thread_state_container.h" +#include "execution/sql/value.h" #include "execution/util/region.h" +#include "execution/vm/execution_mode.h" #include "metrics/metrics_defs.h" #include "planner/plannodes/output_schema.h" #include "self_driving/modeling/operating_unit.h" @@ -43,10 +48,11 @@ class RecoveryManager; } // namespace noisepage::storage namespace noisepage::execution::exec { + class ExecutionSettings; + /** - * Execution Context: Stores information handed in by upper layers. - * TODO(Amadou): This class will change once we know exactly what we get from upper layers. + * The ExecutionContext class stores information handed in by upper layers. */ class EXPORT ExecutionContext { public: @@ -76,85 +82,113 @@ class EXPORT ExecutionContext { */ using HookFn = void (*)(void *, void *, void *); - /** - * Constructor - * @param db_oid oid of the database - * @param txn transaction used by this query - * @param callback callback function for outputting - * @param schema the schema of the output - * @param accessor the catalog accessor of this query - * @param exec_settings The execution settings to run with. - * @param metrics_manager The metrics manager for recording metrics - * @param replication_manager The replication manager to handle communication between primary and replicas. - * @param recovery_manager The recovery manager that handles both recovery and application of replication records. - */ - ExecutionContext(catalog::db_oid_t db_oid, common::ManagedPointer txn, - const OutputCallback &callback, const planner::OutputSchema *schema, - const common::ManagedPointer accessor, - const exec::ExecutionSettings &exec_settings, - common::ManagedPointer metrics_manager, - common::ManagedPointer replication_manager, - common::ManagedPointer recovery_manager) - : exec_settings_(exec_settings), - db_oid_(db_oid), - txn_(txn), - mem_tracker_(std::make_unique()), - mem_pool_(std::make_unique(common::ManagedPointer(mem_tracker_))), - schema_(schema), - callback_(callback), - thread_state_container_(std::make_unique(mem_pool_.get())), - accessor_(accessor), - metrics_manager_(metrics_manager), - replication_manager_(replication_manager), - recovery_manager_(recovery_manager) {} + /* -------------------------------------------------------------------------- + Getters / Setters + -------------------------------------------------------------------------- */ + + /** @return The identifier for the associated query */ + execution::query_id_t GetQueryId() { return query_id_; } /** - * @return the transaction used by this query + * Set the current executing query identifier. + * @param query_id The query identifier */ + void SetQueryId(execution::query_id_t query_id) { query_id_ = query_id; } + + /** @return The database OID. */ + catalog::db_oid_t DBOid() { return db_oid_; } + + /** @return The transaction associated with this execution context */ common::ManagedPointer GetTxn() { return txn_; } - /** - * Constructs a new Output Buffer for outputting query results to consumers - * @return newly created output buffer - */ - OutputBuffer *OutputBufferNew(); + /** @return The execution mode for the execution context */ + vm::ExecutionMode GetExecutionMode() const { return execution_mode_; } /** - * @return The thread state container. + * Set the execution mode for the execution context. + * @param execution_mode The desired execution mode + * + * NOTE: Most of the time one should avoid calling this + * function directly; the execution mode for the ExecutionContext + * instance is automatically set in ExecutableQuery::Run() to + * the execution mode in which the query is executed. */ + void SetExecutionMode(const vm::ExecutionMode execution_mode) { execution_mode_ = execution_mode; } + + /** @return The execution settings. */ + const exec::ExecutionSettings &GetExecutionSettings() const { return execution_settings_; } + + /** @return The catalog accessor associated with this execution context */ + catalog::CatalogAccessor *GetAccessor() { return accessor_.Get(); } + + /** @return The metrics manager associated with this execution context */ + common::ManagedPointer GetMetricsManager() { return metrics_manager_; } + + /** @return The memory pool for this execution context */ + sql::MemoryPool *GetMemoryPool() { return mem_pool_.get(); } + + /** @return The thread state container */ sql::ThreadStateContainer *GetThreadStateContainer() { return thread_state_container_.get(); } + /** @return The string allocator for this execution context */ + sql::VarlenHeap *GetStringAllocator() { return &string_allocator_; } + + /** @return The pipeline operating units for the execution context */ + common::ManagedPointer GetPipelineOperatingUnits() { + return pipeline_operating_units_; + } + /** - * @return the memory pool + * Set the pipeline operating units for the execution context. + * @param op pipeline operating units for executing the query */ - sql::MemoryPool *GetMemoryPool() { return mem_pool_.get(); } + void SetPipelineOperatingUnits(common::ManagedPointer op) { + pipeline_operating_units_ = op; + } + + /** @return The number of rows affected by the current execution, e.g., INSERT/DELETE/UPDATE. */ + uint32_t GetRowsAffected() const { return rows_affected_; } /** - * @return the string allocator + * Increment or decrement the number of rows affected. + * @param num_rows The delta for the number of rows affected */ - sql::VarlenHeap *GetStringAllocator() { return &string_allocator_; } + void AddRowsAffected(int64_t num_rows) { rows_affected_ += num_rows; } /** - * @param schema the schema of the output - * @return the size of tuple with this final_schema + * Overrides recording from memory tracker. + * NOTE: This should never be used by parallel threads directly + * @param memory_use Correct memory value to record */ - static uint32_t ComputeTupleSize(const planner::OutputSchema *schema); + void SetMemoryUseOverride(uint32_t memory_use) { + memory_use_override_ = true; + memory_use_override_value_ = memory_use; + } /** - * @return The catalog accessor. + * Sets the estimated concurrency of a parallel operation. + * This value is used when initializing an ExecOUFeatureVector + * + * @note this value is reset by setting it to 0. + * @param estimate Estimated number of concurrent tasks */ - catalog::CatalogAccessor *GetAccessor() { return accessor_.Get(); } - - /** @return The execution settings. */ - const exec::ExecutionSettings &GetExecutionSettings() const { return exec_settings_; } + void SetNumConcurrentEstimate(uint32_t estimate) { num_concurrent_estimate_ = estimate; } /** - * Start the resource tracker + * Sets the opaque query state pointer for the current query invocation. + * @param query_state QueryState */ + void SetQueryState(void *query_state) { query_state_ = query_state; } + + /* -------------------------------------------------------------------------- + Resource Metrics Collection + -------------------------------------------------------------------------- */ + + /** Start the resource tracker. */ void StartResourceTracker(metrics::MetricsComponent component); /** - * End the resource tracker and record the metrics + * End the resource tracker and record the metrics. * @param name the string name get printed out with the time * @param len the length of the string name */ @@ -188,63 +222,69 @@ class EXPORT ExecutionContext { */ void InitializeParallelOUFeatureVector(selfdriving::ExecOUFeatureVector *ouvec, pipeline_id_t pipeline_id); - /** - * @return the db oid - */ - catalog::db_oid_t DBOid() { return db_oid_; } - - /** - * Set the mode for this execution. - * This only records the mode and serves the metrics collection purpose, which does not have any impact on the - * actual execution. - * @param mode the integer value of the execution mode to record - */ - void SetExecutionMode(uint8_t mode) { execution_mode_ = mode; } + /* -------------------------------------------------------------------------- + Runtime Parameters (User-Defined Functions) + -------------------------------------------------------------------------- */ - /** - * Set the accessor - * @param accessor The catalog accessor. - */ - void SetAccessor(const common::ManagedPointer accessor) { accessor_ = accessor; } + /** Initialize a new, empty collection of parameters at the top of the parameter stack */ + void StartParams() { runtime_parameters_.emplace(); } - /** - * Set the execution parameters. - * @param params The execution parameters. - */ - void SetParams(common::ManagedPointer> params) { - params_ = params; + /** Remove the topmost collection of parameters from the parameter stack */ + void FinishParams() { + NOISEPAGE_ASSERT(!runtime_parameters_.empty(), "Attempt to pop from empty runtime parameter stack."); + runtime_parameters_.pop(); } /** - * @param param_idx index of parameter to access - * @return immutable parameter at provided index + * Add a runtime parameter to the "top-most" collection of runtime parameters. + * @param val The parameter to be added */ - const parser::ConstantValueExpression &GetParam(uint32_t param_idx) const; + void AddParam(common::ManagedPointer val) { + NOISEPAGE_ASSERT(!runtime_parameters_.empty(), "Must call StartParams() prior to adding runtime parameters."); + runtime_parameters_.top().push_back(val.CastManagedPointerTo()); + } /** - * Set the PipelineOperatingUnits - * @param op PipelineOperatingUnits for executing the given query + * Add a runtime parameter to the "top-most" collection of runtime parameters. + * @param val The parameter to be added */ - void SetPipelineOperatingUnits(common::ManagedPointer op) { - pipeline_operating_units_ = op; + void AddParam(common::ManagedPointer val) { + NOISEPAGE_ASSERT(!runtime_parameters_.empty(), "Must call StartParams() prior to adding runtime parameters."); + runtime_parameters_.top().push_back(val); } /** - * @return PipelineOperatingUnits - */ - common::ManagedPointer GetPipelineOperatingUnits() { - return pipeline_operating_units_; + * Get the parameter at the specified index. + * @param index index of parameter to access + * @return An immutable point to the parameter at specified index + */ + common::ManagedPointer GetParam(uint32_t index) const { + // Always get the query parameter from the "top-most" collection + // of parameters; if the runtime parameters stack is empty, default + // to the "base" set of parameters for the query, otherwise, grab + // the parameter at the specified index from the top of the runtime + // parameters stack. + if (!runtime_parameters_.empty()) { + NOISEPAGE_ASSERT(index < runtime_parameters_.top().size(), "ExecutionContext::GetParam() index out of range."); + return runtime_parameters_.top()[index]; + } + NOISEPAGE_ASSERT(index < parameters_.size(), "ExecutionContext::GetParam() index out of range"); + return parameters_[index]; } - /** @return The number of rows affected by the current execution, e.g., INSERT/DELETE/UPDATE. */ - uint32_t GetRowsAffected() const { return rows_affected_; } + /* -------------------------------------------------------------------------- + Other Functionality + -------------------------------------------------------------------------- */ - /** Increment or decrement the number of rows affected. */ - void AddRowsAffected(int64_t num_rows) { rows_affected_ += num_rows; } + /** + * Constructs a new Output Buffer for outputting query results to consumers. + * @return The newly created output buffer + */ + OutputBuffer *OutputBufferNew(); /** - * @return On the primary, returns the ID of the last txn sent. - * On a replica, returns the ID of the last txn applied. + * @return On the primary, returns the ID of the last txn sent. + * On a replica, returns the ID of the last txn applied. */ uint64_t ReplicationGetLastTransactionId() const; @@ -266,49 +306,33 @@ class EXPORT ExecutionContext { void AggregateMetricsThread(); /** - * Ensures that the trackers for the current thread are stopped + * Ensures that the trackers for the current thread are stopped. */ void EnsureTrackersStopped(); /** - * @return metrics manager used by execution context + * Compute the size of an output tuple based on the provided schema. + * @param schema The output schema + * @return The size of tuple in this schema */ - common::ManagedPointer GetMetricsManager() { return metrics_manager_; } + static uint32_t ComputeTupleSize(common::ManagedPointer schema); - /** - * @return query identifier - */ - execution::query_id_t GetQueryId() { return query_id_; } + /* -------------------------------------------------------------------------- + Hook Function Management + -------------------------------------------------------------------------- */ /** - * Set the current executing query identifier + * Initializes the set of hooks for the execution context to specified capacity. + * @param num_hooks The desired number of hooks */ - void SetQueryId(execution::query_id_t query_id) { query_id_ = query_id; } + void InitHooks(std::size_t num_hooks); /** - * Overrides recording from memory tracker - * This should never be used by parallel threads directly - * @param memory_use Correct memory value to record - */ - void SetMemoryUseOverride(uint32_t memory_use) { - memory_use_override_ = true; - memory_use_override_value_ = memory_use; - } - - /** - * Sets the opaque query state pointer for the current query invocation - * @param query_state QueryState - */ - void SetQueryState(void *query_state) { query_state_ = query_state; } - - /** - * Sets the estimated concurrency of a parallel operation. - * This value is used when initializing an ExecOUFeatureVector - * - * @note this value is reset by setting it to 0. - * @param estimate Estimated number of concurrent tasks + * Registers a hook function + * @param hook_idx Hook index to register function + * @param hook Function to register */ - void SetNumConcurrentEstimate(uint32_t estimate) { num_concurrent_estimate_ = estimate; } + void RegisterHook(std::size_t hook_idx, HookFn hook); /** * Invoke a hook function if a hook function is available @@ -316,56 +340,130 @@ class EXPORT ExecutionContext { * @param tls TLS argument * @param arg Opaque argument to pass */ - void InvokeHook(size_t hook_index, void *tls, void *arg); + void InvokeHook(std::size_t hook_index, void *tls, void *arg); /** - * Registers a hook function - * @param hook_idx Hook index to register function - * @param hook Function to register + * Clear the hooks for the execution context. */ - void RegisterHook(size_t hook_idx, HookFn hook); + void ClearHooks() { hooks_.clear(); } - /** - * Initializes hooks_ to a certain capacity - * @param num_hooks Number of hooks needed - */ - void InitHooks(size_t num_hooks); + public: + /** An empty output schema */ + constexpr static const std::nullptr_t NULL_OUTPUT_SCHEMA{nullptr}; + /** An empty output callback */ + constexpr static const std::nullptr_t NULL_OUTPUT_CALLBACK{nullptr}; + + private: + friend class ExecutionContextBuilder; /** - * Clears hooks_ + * Construct a new ExecutionContext instance. + * + * NOTE: Private access modifier forces use of ExecutionContextBuilder. + * + * @param db_oid The OID of the database + * @param parameters The query parameters + * @param execution_settings The execution settings to run with + * @param txn The transaction used by this query + * @param output_schema The output schema + * @param output_callback The callback function for query output + * @param accessor The catalog accessor of this query + * @param metrics_manager The metrics manager for recording metrics + * @param replication_manager The replication manager to handle communication between primary and replicas. + * @param recovery_manager The recovery manager that handles both recovery and application of replication records. */ - void ClearHooks() { hooks_.clear(); } + ExecutionContext(const catalog::db_oid_t db_oid, std::vector> &¶meters, + exec::ExecutionSettings execution_settings, + const common::ManagedPointer txn, + const common::ManagedPointer output_schema, + OutputCallback &&output_callback, const common::ManagedPointer accessor, + const common::ManagedPointer metrics_manager, + const common::ManagedPointer replication_manager, + const common::ManagedPointer recovery_manager) + : db_oid_{db_oid}, + parameters_{std::move(parameters)}, + execution_settings_{execution_settings}, + txn_{txn}, + output_schema_{output_schema}, + output_callback_{std::move(output_callback)}, + accessor_{accessor}, + metrics_manager_{metrics_manager}, + replication_manager_{replication_manager}, + recovery_manager_{recovery_manager}, + mem_tracker_{std::make_unique()}, + mem_pool_{std::make_unique(common::ManagedPointer(mem_tracker_))}, + thread_state_container_{std::make_unique(mem_pool_.get())} {} private: + /** + * The query identifier + * + * The query identifier is only used in certain situations and is + * set manually after construction of the ExecutionContext via the + * SetQueryId() member function. + */ query_id_t query_id_{execution::query_id_t(0)}; - exec::ExecutionSettings exec_settings_; - catalog::db_oid_t db_oid_; - common::ManagedPointer txn_; + + /** The OID of the database with which the query is associated */ + const catalog::db_oid_t db_oid_; + + /** The query parameters */ + std::vector> parameters_; + /** The query execution mode */ + vm::ExecutionMode execution_mode_; + /** The execution setting for the query */ + const exec::ExecutionSettings execution_settings_; + + /** The associated transaction */ + const common::ManagedPointer txn_; + + /** The query output schema */ + common::ManagedPointer output_schema_{nullptr}; + /** The query output buffer */ + std::unique_ptr buffer_{nullptr}; + /** The query output callback */ + OutputCallback output_callback_; + + /** The query catalog accessor */ + common::ManagedPointer accessor_; + /** The query metrics manager */ + common::ManagedPointer metrics_manager_; + /** The replication manager with which the query is associated */ + common::ManagedPointer replication_manager_; + /** The recovery manager with which the query is associated */ + common::ManagedPointer recovery_manager_; + + /** The memory tracker */ std::unique_ptr mem_tracker_; + /** The memory pool */ std::unique_ptr mem_pool_; - std::unique_ptr buffer_ = nullptr; - const planner::OutputSchema *schema_ = nullptr; - const OutputCallback &callback_; - // Container for thread-local state. - // During parallel processing, execution threads access their thread-local state from this container. + /** The container for thread-local state */ std::unique_ptr thread_state_container_; + + /** The allocator for strings */ // TODO(WAN): EXEC PORT we used to push the memory tracker into the string allocator, do this sql::VarlenHeap string_allocator_; + + /** The pipeline operating units for the query */ common::ManagedPointer pipeline_operating_units_{nullptr}; - common::ManagedPointer accessor_; - common::ManagedPointer metrics_manager_; - common::ManagedPointer> params_; - uint8_t execution_mode_; - uint32_t rows_affected_ = 0; - - common::ManagedPointer replication_manager_; - common::ManagedPointer recovery_manager_; + /** The number of rows affected by the query */ + uint32_t rows_affected_{0}; + /** `true` if memory overrride is used */ bool memory_use_override_ = false; - uint32_t memory_use_override_value_ = 0; - uint32_t num_concurrent_estimate_ = 0; + /** The value to use for memory override */ + uint32_t memory_use_override_value_{0}; + /** The concurrency estimate for query execution */ + uint32_t num_concurrent_estimate_{0}; + + /** The hooks for the query */ std::vector hooks_{}; + + /** The query state object */ void *query_state_; + + /** The runtime parameter stack */ + std::stack>> runtime_parameters_; }; } // namespace noisepage::execution::exec diff --git a/src/include/execution/exec/execution_context_builder.h b/src/include/execution/exec/execution_context_builder.h new file mode 100644 index 0000000000..d522dd1510 --- /dev/null +++ b/src/include/execution/exec/execution_context_builder.h @@ -0,0 +1,193 @@ +#pragma once + +#include +#include +#include +#include + +#include "common/managed_pointer.h" +#include "execution/exec/execution_settings.h" +#include "execution/exec/output.h" + +namespace noisepage::parser { +class ConstantValueExpression; +} // namespace noisepage::parser + +namespace noisepage::planner { +class OutputSchema; +} // namespace noisepage::planner + +namespace noisepage::catalog { +class CatalogAccessor; +} // namespace noisepage::catalog + +namespace noisepage::metrics { +class MetricsManager; +} // namespace noisepage::metrics + +namespace noisepage::replication { +class ReplicationManager; +} // namespace noisepage::replication + +namespace noisepage::storage { +class RecoveryManager; +} // namespace noisepage::storage + +namespace noisepage::execution::sql { +struct Val; +} // namespace noisepage::execution::sql + +namespace noisepage::transaction { +class TransactionContext; +} // namespace noisepage::transaction + +namespace noisepage::execution::exec { + +class ExecutionContext; +class ExecutionSettings; + +/** + * The ExecutionContextBuilder class implements a builder for ExecutionContext. + */ +class ExecutionContextBuilder { + public: + /** + * Construct a new ExecutionContextBuilder. + */ + ExecutionContextBuilder() = default; + + /** @return The completed ExecutionContext instance */ + std::unique_ptr Build(); + + /** + * Set the query parameters for the execution context. + * @param parameters The query parameters + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithQueryParameters(std::vector> &¶meters) { + parameters_ = std::move(parameters); + return *this; + } + + /** + * Set the query parameters for the execution context. + * @param parameter_exprs The collection of expressions from which the query parameters are derived + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithQueryParametersFrom(const std::vector ¶meter_exprs); + + /** + * Set the database OID for the execution context. + * @param db_oid The database OID + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithDatabaseOID(const catalog::db_oid_t db_oid) { + db_oid_ = db_oid; + return *this; + } + + /** + * Set the transaction context for the execution context. + * @param txn The transaction context + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithTxnContext(common::ManagedPointer txn) { + txn_ = txn; + return *this; + } + + /** + * Set the output schema for the execution context. + * @param output_schema The output schema + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithOutputSchema(common::ManagedPointer output_schema) { + output_schema_ = output_schema; + return *this; + } + + /** + * Set the output callback for the execution context. + * @param output_callback The output callback + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithOutputCallback(OutputCallback output_callback) { + output_callback_.emplace(std::move(output_callback)); + return *this; + } + + /** + * Set the catalog accessor for the execution context. + * @param accessor The catalog accessor + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithCatalogAccessor(common::ManagedPointer accessor) { + catalog_accessor_ = accessor; + return *this; + } + + /** + * Set the execution settings for the execution context. + * @param exec_settings The execution settings + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithExecutionSettings(exec::ExecutionSettings exec_settings) { + exec_settings_.emplace(exec_settings); + return *this; + } + + /** + * Set the metrics manager for the execution context. + * @param metrics_manager The metrics manager + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithMetricsManager(common::ManagedPointer metrics_manager) { + metrics_manager_ = metrics_manager; + return *this; + } + + /** + * Set the replication manager for the execution context. + * @param replication_manager The replication manager + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithReplicationManager( + common::ManagedPointer replication_manager) { + replication_manager_ = replication_manager; + return *this; + } + + /** + * Set the recovery manager for the execution context. + * @param recovery_manager The recovery manager + * @return Builder reference for chaining + */ + ExecutionContextBuilder &WithRecoveryManager(common::ManagedPointer recovery_manager) { + recovery_manager_ = recovery_manager; + return *this; + } + + private: + /** The query execution settings */ + std::optional exec_settings_; + /** The query parmeters */ + std::vector> parameters_; + /** The database OID */ + catalog::db_oid_t db_oid_{catalog::INVALID_DATABASE_OID}; + /** The associated transaction */ + std::optional> txn_; + /** The output callback */ + std::optional output_callback_; + /** The output schema */ + std::optional> output_schema_; + /** The catalog accessor */ + std::optional> catalog_accessor_; + /** The metrics manager */ + std::optional> metrics_manager_; + /** The replication manager */ + std::optional> replication_manager_; + /** The recovery manager */ + std::optional> recovery_manager_; +}; + +} // namespace noisepage::execution::exec diff --git a/src/include/execution/exec/execution_settings.h b/src/include/execution/exec/execution_settings.h index a6f28a559f..84dfc0741a 100644 --- a/src/include/execution/exec/execution_settings.h +++ b/src/include/execution/exec/execution_settings.h @@ -32,6 +32,10 @@ namespace noisepage::tpch { class Workload; } // namespace noisepage::tpch +namespace noisepage::procbench { +class Workload; +} // namespace noisepage::procbench + namespace noisepage::selfdriving { namespace pilot { class PilotUtil; @@ -109,6 +113,7 @@ class EXPORT ExecutionSettings { // MiniRunners needs to set query_identifier and pipeline_operating_units_. friend class noisepage::runner::ExecutionRunners; friend class noisepage::tpch::Workload; + friend class noisepage::procbench::Workload; friend class noisepage::execution::SqlBasedTest; friend class noisepage::optimizer::IdxJoinTest_SimpleIdxJoinTest_Test; friend class noisepage::optimizer::IdxJoinTest_MultiPredicateJoin_Test; diff --git a/src/include/execution/exec/output.h b/src/include/execution/exec/output.h index 6882786402..dbb16d4591 100644 --- a/src/include/execution/exec/output.h +++ b/src/include/execution/exec/output.h @@ -89,8 +89,9 @@ class EXPORT OutputBuffer { private: sql::MemoryPool *memory_pool_; - uint32_t num_tuples_; - uint32_t tuple_size_; + // TODO(Kyle): Tanuj made this atomic, does it need to be? + std::uint32_t num_tuples_; + std::uint32_t tuple_size_; byte *tuples_; /** @@ -138,7 +139,7 @@ class OutputWriter { * @param out packet writer to use * @param field_formats reference to the field formats for this query */ - OutputWriter(const common::ManagedPointer schema, + OutputWriter(const common::ManagedPointer schema, const common::ManagedPointer out, const std::vector &field_formats) : schema_(schema), out_(out), field_formats_(field_formats) {} @@ -165,8 +166,11 @@ class OutputWriter { * (parallel scan) */ std::mutex output_synchronization_; - const common::ManagedPointer schema_; + /** The output schema */ + const common::ManagedPointer schema_; + /** The output writer */ const common::ManagedPointer out_; + /** The field formats */ const std::vector &field_formats_; }; diff --git a/src/include/execution/functions/function_context.h b/src/include/execution/functions/function_context.h index 2a5aba2cbe..06f4603f0d 100644 --- a/src/include/execution/functions/function_context.h +++ b/src/include/execution/functions/function_context.h @@ -1,94 +1,145 @@ #pragma once +#include #include #include #include #include "catalog/catalog_defs.h" #include "common/managed_pointer.h" +#include "execution/ast/ast.h" #include "execution/ast/builtins.h" +#include "execution/ast/context.h" +#include "execution/util/region.h" namespace noisepage::execution::functions { /** - * @brief Stores execution and type information about a stored procedure + * @brief Stores execution and type information about a stored procedure. */ class FunctionContext { public: /** - * Creates a FunctionContext object + * Construct a FunctionContext instance. * @param func_name Name of function * @param func_ret_type Return type of function - * @param args_type Vector of argument types + * @param arg_types Vector of argument types */ FunctionContext(std::string func_name, execution::sql::SqlTypeId func_ret_type, - std::vector &&args_type) + std::vector &&arg_types) : func_name_(std::move(func_name)), func_ret_type_(func_ret_type), - args_type_(std::move(args_type)), + arg_types_(std::move(arg_types)), is_builtin_{false}, is_exec_ctx_required_{false} {} + /** - * Creates a FunctionContext object for a builtin function + * Construct a FunctionContext instance for a builtin function. * @param func_name Name of function * @param func_ret_type Return type of function - * @param args_type Vector of argument types + * @param arg_types Vector of argument types * @param builtin Which builtin this context refers to * @param is_exec_ctx_required true if this function requires an execution context var as its first argument */ - FunctionContext(std::string func_name, execution::sql::SqlTypeId func_ret_type, - std::vector &&args_type, ast::Builtin builtin, - bool is_exec_ctx_required = false) + FunctionContext(std::string func_name, sql::SqlTypeId func_ret_type, std::vector &&arg_types, + ast::Builtin builtin, bool is_exec_ctx_required = false) : func_name_(std::move(func_name)), func_ret_type_(func_ret_type), - args_type_(std::move(args_type)), + arg_types_(std::move(arg_types)), is_builtin_{true}, builtin_{builtin}, is_exec_ctx_required_{is_exec_ctx_required} {} + /** - * @return The name of the function represented by this context object + * Construct a FunctionContext instance for a non-builtin function. + * @param func_name Name of function= + * @param func_ret_type Return type of function + * @param arg_types Vector of argument types + * @param ast_region The region associated with the AST context + * @param ast_context The AST context for the function + * @param file The AST file + * @param is_exec_ctx_required Flag indicating whether an + * execution context is required for this function */ + FunctionContext(std::string func_name, sql::SqlTypeId func_ret_type, std::vector &&arg_types, + std::unique_ptr ast_region, std::unique_ptr ast_context, ast::File *file, + bool is_exec_ctx_required = true) + : func_name_(std::move(func_name)), + func_ret_type_(func_ret_type), + arg_types_(std::move(arg_types)), + is_builtin_{false}, + is_exec_ctx_required_{is_exec_ctx_required}, + ast_region_{std::move(ast_region)}, + ast_context_{std::move(ast_context)}, + file_{file} {} + + /** @return The name of the function represented by this context object. */ const std::string &GetFunctionName() const { return func_name_; } - /** - * @return The vector of type arguments of the function represented by this context object - */ - const std::vector &GetFunctionArgsType() const { return args_type_; } + /** @return The vector of type arguments of the function represented by this context object */ + const std::vector &GetFunctionArgTypes() const { return arg_types_; } /** - * Gets the return type of the function represented by this object - * @return return type of this function + * Gets the return type of the function represented by this object. + * @return The return type of this function. */ execution::sql::SqlTypeId GetFunctionReturnType() const { return func_ret_type_; } - /** - * @return true iff this represents a builtin function - */ + /** @return `true` if this represents a builtin function, `false` otherwise. */ bool IsBuiltin() const { return is_builtin_; } - /** - * @return returns what builtin function this represents - */ + /** @return The builtin function this procedure represents. */ ast::Builtin GetBuiltin() const { NOISEPAGE_ASSERT(IsBuiltin(), "Getting a builtin from a non-builtin function"); return builtin_; } - /** - * @return returns if this function requires an execution context - */ + /** @return `true` if this function requires an execution context, `false` otherwise. */ bool IsExecCtxRequired() const { - NOISEPAGE_ASSERT(IsBuiltin(), "IsExecCtxRequired is only valid or a builtin function"); + // TODO(Kyle): Is it valid to query execution context requirement for non-builtins? return is_exec_ctx_required_; } + /** @return The main function declaration of this UDF. */ + common::ManagedPointer GetMainFunctionDecl() const { + NOISEPAGE_ASSERT(!IsBuiltin(), "Getting a non-builtin from a builtin function"); + return common::ManagedPointer( + reinterpret_cast(file_->Declarations().back())); + } + + /** @return The file with the function declaration and supporting declarations. */ + ast::File *GetFile() const { + NOISEPAGE_ASSERT(!IsBuiltin(), "Getting a non-builtin from a builtin function"); + return file_; + } + + /** @return The AST context for this procedure. */ + ast::Context *GetASTContext() const { + NOISEPAGE_ASSERT(!IsBuiltin(), "No AST Context associated with builtin function"); + return ast_context_.get(); + } + private: + /** The function name */ std::string func_name_; - execution::sql::SqlTypeId func_ret_type_; - std::vector args_type_; + /** The function return type */ + sql::SqlTypeId func_ret_type_; + /** The function argument types */ + std::vector arg_types_; + /** `true` if this function is a builtin */ bool is_builtin_; + /** The builtin function, if applicable */ ast::Builtin builtin_; + /** `true` if an execution context is required for this function, `false` otherwise */ bool is_exec_ctx_required_; + + /** The associated AST region */ + std::unique_ptr ast_region_; + /** The associated AST context */ + std::unique_ptr ast_context_; + + /** The associated file */ + ast::File *file_; }; } // namespace noisepage::execution::functions diff --git a/src/include/execution/parsing/parser.h b/src/include/execution/parsing/parser.h index e5a35c1ded..4006034de4 100644 --- a/src/include/execution/parsing/parser.h +++ b/src/include/execution/parsing/parser.h @@ -121,6 +121,8 @@ class Parser { ast::Expr *ParseUnaryOpExpr(); + ast::Expr *ParseLambdaExpr(); + ast::Expr *ParsePrimaryExpr(); ast::Expr *ParseOperand(); @@ -139,6 +141,8 @@ class Parser { ast::Expr *ParseMapType(); + ast::Expr *ParseLambdaType(); + private: // The source code scanner Scanner *scanner_; diff --git a/src/include/execution/parsing/token.h b/src/include/execution/parsing/token.h index 16a8e635dd..91a4129fad 100644 --- a/src/include/execution/parsing/token.h +++ b/src/include/execution/parsing/token.h @@ -64,6 +64,7 @@ namespace noisepage::execution::parsing { K(FUN, "fun", 0) \ K(IF, "if", 0) \ K(IN, "in", 0) \ + K(LAMBDA, "lambda", 0) \ K(MAP, "map", 0) \ K(NIL, "nil", 0) \ K(RETURN, "return", 0) \ diff --git a/src/include/execution/sema/error_message.h b/src/include/execution/sema/error_message.h index 8c0a693384..f874ac5fee 100644 --- a/src/include/execution/sema/error_message.h +++ b/src/include/execution/sema/error_message.h @@ -74,7 +74,7 @@ namespace sema { F(MissingArrayLength, "missing array length (either compile-time number or '*')", ()) \ F(NotASQLAggregate, "'%0' is not a SQL aggregator type", (ast::Type *)) \ F(BadParallelScanFunction, \ - "parallel scan function must have type (*ExecutionContext, *TableVectorIterator)->nil, " \ + "parallel scan function must have type (*QueryState, *PipelineState, *TableVectorIterator)->nil, " \ "received '%0'", \ (ast::Type *)) \ F(BadHookFunction, \ @@ -95,7 +95,8 @@ namespace sema { "indexIteratorFree() expects (*IndexIterator) argument " \ "types. Received type '%0' in position %1", \ (ast::Type *, uint32_t)) \ - F(IsValNullExpectsSqlValue, "@isValNull() expects a SQL value input, received type '%0'", (ast::Type *)) + F(IsValNullExpectsSqlValue, "@isValNull() expects a SQL value input, received type '%0'", (ast::Type *)) \ + F(NoScopeToBreak, "There is no scope to break from in position", ()) /// Define the ErrorMessageId enumeration enum class ErrorMessageId : uint16_t { diff --git a/src/include/execution/sema/scope.h b/src/include/execution/sema/scope.h index 34533825d6..5be9e6676f 100644 --- a/src/include/execution/sema/scope.h +++ b/src/include/execution/sema/scope.h @@ -1,6 +1,8 @@ #pragma once -#include +#include +#include +#include #include "execution/ast/identifier.h" #include "execution/util/execution_common.h" @@ -67,6 +69,18 @@ class Scope { */ ast::Type *LookupLocal(ast::Identifier name) const; + /** + * Get the kind of the scope. + * @return The kind + */ + Kind GetKind() const; + + /** + * Get the local variables for the scope. + * @return A collection of the scope's locals + */ + std::vector> GetLocals() const; + /** * @return the parent scope */ @@ -78,7 +92,7 @@ class Scope { // The scope kind. Kind scope_kind_; // The mapping of identifiers to their types. - llvm::DenseMap decls_; + std::unordered_map decls_; }; } // namespace sema diff --git a/src/include/execution/sema/sema.h b/src/include/execution/sema/sema.h index c07121fed9..31c36598f9 100644 --- a/src/include/execution/sema/sema.h +++ b/src/include/execution/sema/sema.h @@ -166,6 +166,7 @@ class Sema : public ast::AstVisitor { void CheckBuiltinParamCall(ast::CallExpr *call, ast::Builtin builtin); void CheckBuiltinCteScanCall(ast::CallExpr *call, ast::Builtin builtin); void CheckBuiltinStringCall(ast::CallExpr *call, ast::Builtin builtin); + void CheckBuiltinRandomCall(ast::CallExpr *call, ast::Builtin builtin); void CheckBuiltinReplicationCall(ast::CallExpr *call, ast::Builtin builtin); diff --git a/src/include/execution/sql/ddl_executors.h b/src/include/execution/sql/ddl_executors.h index 21931f7a70..8ece11b335 100644 --- a/src/include/execution/sql/ddl_executors.h +++ b/src/include/execution/sql/ddl_executors.h @@ -11,8 +11,10 @@ class CreateNamespacePlanNode; class CreateTablePlanNode; class CreateIndexPlanNode; class CreateViewPlanNode; +class CreateFunctionPlanNode; class DropDatabasePlanNode; class DropNamespacePlanNode; +class DropFunctionPlanNode; class DropTablePlanNode; class DropIndexPlanNode; } // namespace noisepage::planner @@ -32,7 +34,7 @@ class DDLExecutors { DDLExecutors() = delete; /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution * @return true if operation succeeded, false otherwise */ @@ -40,65 +42,81 @@ class DDLExecutors { common::ManagedPointer accessor); /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool CreateNamespaceExecutor(common::ManagedPointer node, common::ManagedPointer accessor); /** - * @param node node to executed + * @param node node to execute + * @param accessor accessor to use for execution + * @return `true` if the operation succeeds, `false` otherwise + */ + static bool CreateFunctionExecutor(common::ManagedPointer node, + common::ManagedPointer accessor); + + /** + * @param node node to execute * @param accessor accessor to use for execution * @param connection_db database for the current connection - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool CreateTableExecutor(common::ManagedPointer node, common::ManagedPointer accessor, catalog::db_oid_t connection_db); /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool CreateIndexExecutor(common::ManagedPointer node, common::ManagedPointer accessor); /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution * @param connection_db database for the current connection - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool DropDatabaseExecutor(common::ManagedPointer node, common::ManagedPointer accessor, catalog::db_oid_t connection_db); /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool DropNamespaceExecutor(common::ManagedPointer node, common::ManagedPointer accessor); /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool DropTableExecutor(common::ManagedPointer node, common::ManagedPointer accessor); /** - * @param node node to executed + * @param node node to execute * @param accessor accessor to use for execution - * @return true if operation succeeded, false otherwise + * @return `true` if operation succeeds, `false` otherwise */ static bool DropIndexExecutor(common::ManagedPointer node, common::ManagedPointer accessor); + /** + * @param node node to execute + * @param accessor accessor to use for execution + * @return `true` if operation succeeds, `false` otherwise + */ + static bool DropFunctionExecutor(common::ManagedPointer node, + common::ManagedPointer accessor); + private: static bool CreateIndex(common::ManagedPointer accessor, catalog::namespace_oid_t ns, const std::string &name, catalog::table_oid_t table, diff --git a/src/include/execution/sql/functions/system_functions.h b/src/include/execution/sql/functions/system_functions.h index 7db0646d20..eee51a4dac 100644 --- a/src/include/execution/sql/functions/system_functions.h +++ b/src/include/execution/sql/functions/system_functions.h @@ -19,9 +19,17 @@ class EXPORT SystemFunctions { SystemFunctions() = delete; /** - * Gets the version of the database + * Get the version of the database. + * @param ctx The execution context + * @param result The out parameter that receives version string */ static void Version(exec::ExecutionContext *ctx, StringVal *result); + + /** + * Generate a random floating point value on [0.0, 1.0). + * @param result The out parameter that receives the result + */ + static void Random(Real *result); }; } // namespace noisepage::execution::sql diff --git a/src/include/execution/sql/sql.h b/src/include/execution/sql/sql.h index 5e6fecffce..346d6aca58 100644 --- a/src/include/execution/sql/sql.h +++ b/src/include/execution/sql/sql.h @@ -129,8 +129,18 @@ uint16_t GetSqlTypeIdSize(SqlTypeId type); */ std::size_t GetTypeIdAlignment(TypeId type); +/** + * Parse a SQL type ID from a string. + * @param type_string The string representation of the type name + * @return The SQL type ID + */ SqlTypeId SqlTypeIdFromString(const std::string &type_string); +/** + * Convert a SQL type ID to a human-readable string. + * @param type The SQL type ID + * @return The string representation of the type + */ std::string SqlTypeIdToString(SqlTypeId type); /** diff --git a/src/include/execution/vm/bytecode_emitter.h b/src/include/execution/vm/bytecode_emitter.h index f5d532b519..51030c1756 100644 --- a/src/include/execution/vm/bytecode_emitter.h +++ b/src/include/execution/vm/bytecode_emitter.h @@ -59,13 +59,21 @@ class BytecodeEmitter { // ------------------------------------------------------- /** - * Emit arbitrary assignment code + * Emit arbitrary assignment code. * @param bytecode assignment bytecode * @param dest destination variable * @param src source variable */ void EmitAssign(Bytecode bytecode, LocalVar dest, LocalVar src); + /** + * Emit arbitrary assignment code. + * @param dest destination variable + * @param src source variable + * @param len length + */ + void EmitAssignN(LocalVar dest, LocalVar src, uint32_t len); + /** * Emit assignment code for 1 byte values. * @param dest destination variable @@ -162,11 +170,19 @@ class BytecodeEmitter { /** * Emit a function call - * @param func_id id of the function to call - * @param params parameters of the function + * @param func_id The ID of the function to call. + * @param params The parameters of the function. */ void EmitCall(FunctionId func_id, const std::vector ¶ms); + /** + * Create a function that emits a function call. + * @param params The parameters of the function. + * @return A new callable that, when invoked with a FunctionID, + * emits a fuinction call into the bytecode stream. + */ + std::function DeferredEmitCall(const std::vector ¶ms); + /** * Emit a return bytecode */ @@ -476,13 +492,26 @@ class BytecodeEmitter { void EmitConcat(LocalVar ret, LocalVar exec_ctx, LocalVar inputs, uint32_t num_inputs); private: - /** Copy a scalar immediate value into the bytecode stream */ + /** + * Copy a scalar immediate value into the bytecode stream. + * @param val The scalar value to emit into the stream. + */ template auto EmitScalarValue(const T val) -> std::enable_if_t> { bytecode_->insert(bytecode_->end(), sizeof(T), 0); *reinterpret_cast(&*(bytecode_->end() - sizeof(T))) = val; } + /** + * Copy a scalar immediate value into the bytecode stream at specified index. + * @param val The scalar value to emit into the stream. + * @param index The index in the stream at which to emit the value. + */ + template + auto EmitScalarValue(const T val, std::size_t index) -> std::enable_if_t> { + *reinterpret_cast(&*(bytecode_->begin() + index)) = val; + } + /** Emit a bytecode */ void EmitImpl(const Bytecode bytecode) { EmitScalarValue(Bytecodes::ToByte(bytecode)); } diff --git a/src/include/execution/vm/bytecode_function_info.h b/src/include/execution/vm/bytecode_function_info.h index 1a169117bc..93eae8bbec 100644 --- a/src/include/execution/vm/bytecode_function_info.h +++ b/src/include/execution/vm/bytecode_function_info.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -287,6 +288,23 @@ class FunctionInfo { */ uint32_t GetParamsCount() const noexcept { return num_params_; } + /** + * @brief Defer an action for the current function. + * + * This functionality is used for TPL lambda expressions. + * When we visit a lambda expression in the nody of the + * current function, we defer an action that in turn visits + * the body of the lambda. This action is evaluated when we + * later visit the declaration for the function itself. + */ + void DeferAction(std::function &&action) { actions_.push_back(std::move(action)); } + + /** + * @return `true` if the TBC function represented by this object + * is generated by a TPL lambda, `false` otherwise. + */ + bool IsLambda() const { return is_lambda_; } + private: friend class BytecodeGenerator; @@ -302,6 +320,15 @@ class FunctionInfo { // Allocate a new local variable in the function. LocalVar NewLocal(ast::Type *type, const std::string &name, LocalInfo::Kind kind); + // The captures in the event this function is a TPL lambda. + LocalVar captures_; + + // Indicates whether this TBC function is generated by a TPL lambda. + bool is_lambda_{false}; + + // The collection of deferred actions if this function is a TPL lambda. + std::vector> actions_; + private: // The ID of the function in the module. IDs are unique within a module. FunctionId id_; diff --git a/src/include/execution/vm/bytecode_generator.h b/src/include/execution/vm/bytecode_generator.h index b460081b9c..a9fc19227c 100644 --- a/src/include/execution/vm/bytecode_generator.h +++ b/src/include/execution/vm/bytecode_generator.h @@ -67,9 +67,15 @@ class BytecodeGenerator final : public ast::AstVisitor { class RValueResultScope; class BytecodePositionScope; - // Allocate a new function ID - FunctionInfo *AllocateFunc(const std::string &func_name, ast::FunctionType *func_type); + /** + * Allocate a new function. + * @param function_name The function name + * @param function_type The function type + * @return A non-owning pointer to the allocated function + */ + FunctionInfo *AllocateFunction(const std::string &function_name, ast::FunctionType *function_type); + // Visit a transaction abort call expression void VisitAbortTxn(ast::CallExpr *call); // ONLY FOR TESTING! @@ -82,6 +88,7 @@ class BytecodeGenerator final : public ast::AstVisitor { void VisitNullValueCall(ast::CallExpr *call, ast::Builtin builtin); void VisitSqlStringLikeCall(ast::CallExpr *call); void VisitBuiltinDateFunctionCall(ast::CallExpr *call, ast::Builtin builtin); + void VisitBuiltinRandomFunctionCall(ast::CallExpr *call, ast::Builtin builtin); void VisitBuiltinTableIterCall(ast::CallExpr *call, ast::Builtin builtin); void VisitBuiltinTableIterParallelCall(ast::CallExpr *call); void VisitBuiltinVPICall(ast::CallExpr *call, ast::Builtin builtin); @@ -156,6 +163,9 @@ class BytecodeGenerator final : public ast::AstVisitor { void VisitExpressionForTest(ast::Expr *expr, BytecodeLabel *then_label, BytecodeLabel *else_label, TestFallthrough fallthrough); + // Visit the body of a break statement + void VisitBreakStatement(ast::BreakStmt *break_stmt); + // Visit the body of an iteration statement void VisitIterationStatement(ast::IterationStmt *iteration, LoopBuilder *loop_builder); @@ -187,7 +197,9 @@ class BytecodeGenerator final : public ast::AstVisitor { void SetExecutionResult(ExpressionResultScope *exec_result) { execution_result_ = exec_result; } // Access the current function that's being generated. May be NULL. - FunctionInfo *GetCurrentFunction() { return &functions_.back(); } + FunctionInfo *GetCurrentFunction() { return functions_[current_fn_].get(); } + + void EnterFunction(FunctionId id) { current_fn_ = id; } private: // The data section of the module @@ -202,16 +214,23 @@ class BytecodeGenerator final : public ast::AstVisitor { std::unordered_map static_string_cache_; // Information about all generated functions - std::vector functions_; + std::vector> functions_; + + // The ID of the current function. + FunctionId current_fn_{0}; // Cache of function names to IDs for faster lookup std::unordered_map func_map_; + std::unordered_map>> deferred_function_create_actions_; // Emitter to write bytecode into the code section BytecodeEmitter emitter_; // RAII struct to capture semantics of expression evaluation ExpressionResultScope *execution_result_{nullptr}; + + // The loop builder for the current loop. + LoopBuilder *current_loop_{nullptr}; }; } // namespace noisepage::execution::vm diff --git a/src/include/execution/vm/bytecode_handlers.h b/src/include/execution/vm/bytecode_handlers.h index aa05cc3f5c..ecbccaf767 100644 --- a/src/include/execution/vm/bytecode_handlers.h +++ b/src/include/execution/vm/bytecode_handlers.h @@ -171,6 +171,10 @@ VM_OP_HOT void OpAssign4(int32_t *dest, int32_t src) { *dest = src; } VM_OP_HOT void OpAssign8(int64_t *dest, int64_t src) { *dest = src; } +VM_OP_HOT void OpAssignN(noisepage::byte *dest, const noisepage::byte *const src, uint32_t len) { + std::memcpy(dest, src, len); +} + VM_OP_HOT void OpAssignImm1(int8_t *dest, int8_t src) { *dest = src; } VM_OP_HOT void OpAssignImm2(int16_t *dest, int16_t src) { *dest = src; } @@ -1461,6 +1465,9 @@ VM_OP_WARM void OpSorterIteratorSkipRows(noisepage::execution::sql::SorterIterat VM_OP void OpSorterIteratorFree(noisepage::execution::sql::SorterIterator *iter); +VM_OP void OpPushParamContext(noisepage::execution::exec::ExecutionContext **new_ctx, + noisepage::execution::exec::ExecutionContext *ctx); + // --------------------------------------------------------- // Output // --------------------------------------------------------- @@ -1896,6 +1903,10 @@ VM_OP_WARM void OpVersion(noisepage::execution::exec::ExecutionContext *ctx, noisepage::execution::sql::SystemFunctions::Version(ctx, result); } +VM_OP_WARM void OpRandom(noisepage::execution::sql::Real *result) { + noisepage::execution::sql::SystemFunctions::Random(result); +} + VM_OP_WARM void OpInitCap(noisepage::execution::sql::StringVal *result, noisepage::execution::exec::ExecutionContext *ctx, const noisepage::execution::sql::StringVal *str) { @@ -2166,21 +2177,26 @@ VM_OP_WARM void OpExtractYearFromDate(noisepage::execution::sql::Integer *result } } +// --------------------------------- +// Transaction Calls +// --------------------------------- + VM_OP_WARM void OpAbortTxn(noisepage::execution::exec::ExecutionContext *exec_ctx) { exec_ctx->GetTxn()->SetMustAbort(); throw noisepage::ABORT_EXCEPTION("transaction aborted"); } -// Parameter calls +// --------------------------------- +// Parameter Calls +// --------------------------------- + +// TODO(Kyle): Is it ever the case that we pass a NULL CVE to call? #define GEN_SCALAR_PARAM_GET(Name, SqlType) \ VM_OP_HOT void OpGetParam##Name(noisepage::execution::sql::SqlType *ret, \ noisepage::execution::exec::ExecutionContext *exec_ctx, uint32_t param_idx) { \ - const auto &cve = exec_ctx->GetParam(param_idx); \ - if (cve.IsNull()) { \ - ret->is_null_ = true; \ - } else { \ - *ret = cve.Get##SqlType(); \ - } \ + const auto &val = \ + *reinterpret_cast(exec_ctx->GetParam(param_idx).Get()); \ + *ret = val; \ } GEN_SCALAR_PARAM_GET(Bool, BoolVal) @@ -2195,6 +2211,29 @@ GEN_SCALAR_PARAM_GET(TimestampVal, TimestampVal) GEN_SCALAR_PARAM_GET(String, StringVal) #undef GEN_SCALAR_PARAM_GET +#define GEN_SCALAR_PARAM_ADD(Name, SqlType, typeId) \ + VM_OP_HOT void OpAddParam##Name(noisepage::execution::exec::ExecutionContext *exec_ctx, \ + noisepage::execution::sql::SqlType *ret) { \ + exec_ctx->AddParam(noisepage::common::ManagedPointer( \ + reinterpret_cast(ret))); \ + } + +GEN_SCALAR_PARAM_ADD(Bool, BoolVal, BOOLEAN) +GEN_SCALAR_PARAM_ADD(TinyInt, Integer, TINYINT) +GEN_SCALAR_PARAM_ADD(SmallInt, Integer, SMALLINT) +GEN_SCALAR_PARAM_ADD(Int, Integer, INTEGER) +GEN_SCALAR_PARAM_ADD(BigInt, Integer, BIGINT) +GEN_SCALAR_PARAM_ADD(Real, Real, DECIMAL) +GEN_SCALAR_PARAM_ADD(Double, Real, DECIMAL) +GEN_SCALAR_PARAM_ADD(DateVal, DateVal, DATE) +GEN_SCALAR_PARAM_ADD(TimestampVal, TimestampVal, TIMESTAMP) +GEN_SCALAR_PARAM_ADD(String, StringVal, VARCHAR) +#undef GEN_SCALAR_PARAM_ADD + +VM_OP_HOT void OpStartNewParams(noisepage::execution::exec::ExecutionContext *exec_ctx) { exec_ctx->StartParams(); } + +VM_OP_HOT void OpFinishParams(noisepage::execution::exec::ExecutionContext *exec_ctx) { exec_ctx->FinishParams(); } + // --------------------------------- // Replication functions // --------------------------------- diff --git a/src/include/execution/vm/bytecode_module.h b/src/include/execution/vm/bytecode_module.h index 3689dc825f..c1d38cad77 100644 --- a/src/include/execution/vm/bytecode_module.h +++ b/src/include/execution/vm/bytecode_module.h @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -28,7 +29,7 @@ class BytecodeModule { * @param static_locals All statically allocated variables in the data section. */ BytecodeModule(std::string name, std::vector &&code, std::vector &&data, - std::vector &&functions, std::vector &&static_locals); + std::vector> &&functions, std::vector &&static_locals); /** * This class cannot be copied or moved. @@ -42,7 +43,7 @@ class BytecodeModule { const FunctionInfo *GetFuncInfoById(const FunctionId func_id) const { // Function IDs are dense, so the given ID must be in the range [0, # functions) NOISEPAGE_ASSERT(func_id < GetFunctionCount(), "Invalid function"); - return &functions_[func_id]; + return functions_[func_id].get(); } /** @@ -50,7 +51,8 @@ class BytecodeModule { * no such function exists, a NULL pointer is returned. */ const FunctionInfo *LookupFuncInfoByName(const std::string &name) const { - for (const FunctionInfo &info : functions_) { + for (const auto &function : functions_) { + const FunctionInfo &info = *function; if (info.GetName() == name) { return &info; } @@ -92,7 +94,14 @@ class BytecodeModule { /** * @return A const-view of the metadata for all functions in this module. */ - const std::vector &GetFunctionsInfo() const { return functions_; } + std::vector GetFunctionsInfo() const { + // TODO(Kyle): Cache these results? + std::vector functions{}; + functions.reserve(functions_.size()); + std::transform(functions_.cbegin(), functions_.cend(), std::back_inserter(functions), + [](const std::unique_ptr &f) { return f.get(); }); + return functions; + } /** * @return A const-view of the metadata for all static-locals in this module. @@ -156,7 +165,7 @@ class BytecodeModule { // The raw static data for ALL static data stored contiguously. const std::vector data_; // Metadata for all functions. - const std::vector functions_; + const std::vector> functions_; // Metadata for all static data. const std::vector static_locals_; }; diff --git a/src/include/execution/vm/bytecodes.h b/src/include/execution/vm/bytecodes.h index 51fda63006..359103b4f2 100644 --- a/src/include/execution/vm/bytecodes.h +++ b/src/include/execution/vm/bytecodes.h @@ -84,6 +84,7 @@ namespace noisepage::execution::vm { F(Assign2, OperandType::Local, OperandType::Local) \ F(Assign4, OperandType::Local, OperandType::Local) \ F(Assign8, OperandType::Local, OperandType::Local) \ + F(AssignN, OperandType::Local, OperandType::Local, OperandType::UImm4) \ F(AssignImm1, OperandType::Local, OperandType::Imm1) \ F(AssignImm2, OperandType::Local, OperandType::Imm2) \ F(AssignImm4, OperandType::Local, OperandType::Imm4) \ @@ -771,6 +772,7 @@ namespace noisepage::execution::vm { \ /* Miscellaneous functions. */ \ F(Version, OperandType::Local, OperandType::Local) \ + F(Random, OperandType::Local) \ \ /* Parameter support. */ \ F(GetParamBool, OperandType::Local, OperandType::Local, OperandType::Local) \ @@ -783,6 +785,18 @@ namespace noisepage::execution::vm { F(GetParamDateVal, OperandType::Local, OperandType::Local, OperandType::Local) \ F(GetParamTimestampVal, OperandType::Local, OperandType::Local, OperandType::Local) \ F(GetParamString, OperandType::Local, OperandType::Local, OperandType::Local) \ + F(AddParamBool, OperandType::Local, OperandType::Local) \ + F(AddParamTinyInt, OperandType::Local, OperandType::Local) \ + F(AddParamSmallInt, OperandType::Local, OperandType::Local) \ + F(AddParamInt, OperandType::Local, OperandType::Local) \ + F(AddParamBigInt, OperandType::Local, OperandType::Local) \ + F(AddParamReal, OperandType::Local, OperandType::Local) \ + F(AddParamDouble, OperandType::Local, OperandType::Local) \ + F(AddParamDateVal, OperandType::Local, OperandType::Local) \ + F(AddParamTimestampVal, OperandType::Local, OperandType::Local) \ + F(AddParamString, OperandType::Local, OperandType::Local) \ + F(StartNewParams, OperandType::Local) \ + F(FinishParams, OperandType::Local) \ \ /* FOR TESTING ONLY */ \ F(TestCatalogLookup, OperandType::Local, OperandType::Local, OperandType::StaticLocal, OperandType::UImm4, \ diff --git a/src/include/execution/vm/control_flow_builders.h b/src/include/execution/vm/control_flow_builders.h index 02479eebe5..ccc68a0e82 100644 --- a/src/include/execution/vm/control_flow_builders.h +++ b/src/include/execution/vm/control_flow_builders.h @@ -79,6 +79,15 @@ class LoopBuilder : public BreakableBlockBuilder { */ explicit LoopBuilder(BytecodeGenerator *generator) : BreakableBlockBuilder(generator) {} + /** + * Construct a loop builder. + * @param generator The generator the loop writes. + * @param prev The previous (outer) loop in the current + * code generation context + */ + explicit LoopBuilder(BytecodeGenerator *generator, LoopBuilder *prev = nullptr) + : BreakableBlockBuilder(generator), prev_loop_(prev) {} + /** * Destructor. */ @@ -104,6 +113,11 @@ class LoopBuilder : public BreakableBlockBuilder { */ void BindContinueTarget(); + /** + * Get the previous (outer) loop. + */ + LoopBuilder *GetPrevLoop() const; + private: /** @return The label associated with the header of the loop. */ BytecodeLabel *GetHeaderLabel() { return &header_label_; } @@ -114,6 +128,7 @@ class LoopBuilder : public BreakableBlockBuilder { private: BytecodeLabel header_label_; BytecodeLabel continue_label_; + LoopBuilder *prev_loop_; }; /** diff --git a/src/include/execution/vm/vm_defs.h b/src/include/execution/vm/execution_mode.h similarity index 100% rename from src/include/execution/vm/vm_defs.h rename to src/include/execution/vm/execution_mode.h diff --git a/src/include/execution/vm/llvm_engine.h b/src/include/execution/vm/llvm_engine.h index 269122f4eb..422d8c17d2 100644 --- a/src/include/execution/vm/llvm_engine.h +++ b/src/include/execution/vm/llvm_engine.h @@ -225,7 +225,7 @@ class LLVMEngine { /** * Process-wide LLVM engine settings. * - * TODO(Kyle): I'm not particularly happy with this setup - an inline + * NOTE(Kyle): I'm not particularly happy with this setup - an inline * static variable (essentially just a global with scoping) for managing * the settings for the LLVM engine. The ownership model should be * relatively simple - the LLVMEngine should own its settings, but diff --git a/src/include/execution/vm/module.h b/src/include/execution/vm/module.h index 44ce8082d0..c9e19a79a6 100644 --- a/src/include/execution/vm/module.h +++ b/src/include/execution/vm/module.h @@ -10,9 +10,9 @@ #include "execution/ast/type.h" #include "execution/vm/bytecode_module.h" +#include "execution/vm/execution_mode.h" #include "execution/vm/llvm_engine.h" #include "execution/vm/module_metadata.h" -#include "execution/vm/vm_defs.h" namespace noisepage::execution::vm { @@ -62,7 +62,7 @@ class Module { /** * Look up a TPL function in this module by its name * @param name The name of the function to lookup - * @return A pointer to the function's info if it exists; null otherwise + * @return A pointer to the function's info if it exists; `nullptr` otherwise */ const FunctionInfo *GetFuncInfoByName(const std::string &name) const { return bytecode_module_->LookupFuncInfoByName(name); diff --git a/src/include/network/network_defs.h b/src/include/network/network_defs.h index 85907c63d1..ed29f2fa7b 100644 --- a/src/include/network/network_defs.h +++ b/src/include/network/network_defs.h @@ -109,9 +109,11 @@ enum class QueryType : uint8_t { QUERY_CREATE_INDEX, QUERY_CREATE_TRIGGER, QUERY_CREATE_SCHEMA, + QUERY_CREATE_FUNCTION, QUERY_CREATE_VIEW, QUERY_DROP_TABLE, QUERY_DROP_DB, + QUERY_DROP_FUNCTION, QUERY_DROP_INDEX, QUERY_DROP_TRIGGER, QUERY_DROP_SCHEMA, diff --git a/src/include/optimizer/child_property_deriver.h b/src/include/optimizer/child_property_deriver.h index a080b92217..358cd3e12d 100644 --- a/src/include/optimizer/child_property_deriver.h +++ b/src/include/optimizer/child_property_deriver.h @@ -276,6 +276,12 @@ class ChildPropertyDeriver : public OperatorVisitor { */ void Visit(const DropView *drop_view) override; + /** + * Visit a DropFunction operator + * @param drop_function operator + */ + void Visit(const DropFunction *drop_function) override; + /** * Visit an Analyze operator * @param analyze analyze operator diff --git a/src/include/optimizer/logical_operators.h b/src/include/optimizer/logical_operators.h index 475cb77a56..8bb1e4cc77 100644 --- a/src/include/optimizer/logical_operators.h +++ b/src/include/optimizer/logical_operators.h @@ -1903,6 +1903,49 @@ class LogicalDropView : public OperatorNodeContents { bool if_exists_; }; +/** + * Logical operator for DropFunction + */ +class LogicalDropFunction : public OperatorNodeContents { + public: + /** + * @param database_oid OID of the database + * @param proc_oid OID of the function to be dropped + * @param if_exists `true` if `IF EXISTS` specified + * @return LogicalDropFunction + */ + static Operator Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid, bool if_exists); + + /** + * Copy + * @returns copy of this + */ + BaseOperatorNodeContents *Copy() const override; + + /** Comparison operator */ + bool operator==(const BaseOperatorNodeContents &r) override; + + /** @return The hash of the instance */ + common::hash_t Hash() const override; + + /** @return The OID of the database */ + catalog::db_oid_t GetDatabaseOid() const { return database_oid_; } + + /** @return The OID of the function to drop */ + catalog::proc_oid_t GetFunctionOid() const { return proc_oid_; } + + /** @return `true` if `IF EXISTS` specified */ + bool GetIfExists() const { return if_exists_; } + + private: + /** OID of the database */ + catalog::db_oid_t database_oid_; + /** OID of the function to drop */ + catalog::proc_oid_t proc_oid_; + /** `true` if `IF EXISTS` specified */ + bool if_exists_; +}; + /** * Logical operator for Analyze */ diff --git a/src/include/optimizer/operator_visitor.h b/src/include/optimizer/operator_visitor.h index 69f36ae655..f5862053ff 100644 --- a/src/include/optimizer/operator_visitor.h +++ b/src/include/optimizer/operator_visitor.h @@ -42,6 +42,7 @@ class DropIndex; class DropNamespace; class DropTrigger; class DropView; +class DropFunction; class Analyze; class LogicalGet; class LogicalExternalFileGet; @@ -77,6 +78,7 @@ class LogicalDropIndex; class LogicalDropNamespace; class LogicalDropTrigger; class LogicalDropView; +class LogicalDropFunction; class LogicalAnalyze; class LogicalCteScan; @@ -320,6 +322,12 @@ class OperatorVisitor { */ virtual void Visit(const DropView *drop_view) {} + /** + * Visit a DropFunction operator + * @param drop_function operator + */ + virtual void Visit(const DropFunction *drop_function) {} + /** * Visit a Analyze operator * @param analyze operator @@ -530,6 +538,12 @@ class OperatorVisitor { */ virtual void Visit(const LogicalDropView *logical_drop_view) {} + /** + * Visit a LogicalDropFunction operator + * @param logical_drop_function + */ + virtual void Visit(const LogicalDropFunction *logical_drop_function) {} + /** * Visit a LogicalAnalyze operator * @param logical_analyze operator diff --git a/src/include/optimizer/physical_operators.h b/src/include/optimizer/physical_operators.h index 928fab6c84..d2e9922a4c 100644 --- a/src/include/optimizer/physical_operators.h +++ b/src/include/optimizer/physical_operators.h @@ -2116,6 +2116,46 @@ class DropView : public OperatorNodeContents { bool if_exists_; }; +/** + * Physical operator for DropFunction + */ +class DropFunction : public OperatorNodeContents { + public: + /** + * @param database_oid OID of database + * @param proc_oid OID of view to drop + * @param if_exists `true` if `IF_EXISTS` specified + * @return + */ + static Operator Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid, bool if_exists); + + /** @return A copy of this */ + BaseOperatorNodeContents *Copy() const override; + + /** Comparison operator */ + bool operator==(const BaseOperatorNodeContents &r) override; + + /** @return The hash of this instance */ + common::hash_t Hash() const override; + + /** @return The OID of the database */ + catalog::db_oid_t GetDatabaseOid() const { return database_oid_; } + + /** @return The OID of the function to drop */ + catalog::proc_oid_t GetFunctionOid() const { return proc_oid_; } + + /** @return `true` if `IF EXISTS` specified */ + bool GetIfExists() const { return if_exists_; } + + private: + /** OID of the database */ + catalog::db_oid_t database_oid_; + /** OID of the view to drop */ + catalog::proc_oid_t proc_oid_; + /** `true` if `IF EXISTS` specified */ + bool if_exists_; +}; + /** * Physical operator for Analyze */ diff --git a/src/include/optimizer/plan_generator.h b/src/include/optimizer/plan_generator.h index 70f14e44ec..6edcb130ab 100644 --- a/src/include/optimizer/plan_generator.h +++ b/src/include/optimizer/plan_generator.h @@ -303,6 +303,12 @@ class PlanGenerator : public OperatorVisitor { */ void Visit(const DropView *drop_view) override; + /** + * Visit a DropFunction operator + * @param drop_function operator + */ + void Visit(const DropFunction *drop_function) override; + /** * Visit a Analyze operator * @param analyze operator diff --git a/src/include/optimizer/rule.h b/src/include/optimizer/rule.h index 0b0d94facb..d1cdcc1bf6 100644 --- a/src/include/optimizer/rule.h +++ b/src/include/optimizer/rule.h @@ -59,6 +59,7 @@ enum class RuleType : uint32_t { DROP_NAMESPACE_TO_PHYSICAL, DROP_TRIGGER_TO_PHYSICAL, DROP_VIEW_TO_PHYSICAL, + DROP_FUNCTION_TO_PHYSICAL, // Don't move this one RewriteDelimiter, diff --git a/src/include/optimizer/rules/implementation_rules.h b/src/include/optimizer/rules/implementation_rules.h index 314dd4d943..805c6a8b09 100644 --- a/src/include/optimizer/rules/implementation_rules.h +++ b/src/include/optimizer/rules/implementation_rules.h @@ -923,6 +923,33 @@ class LogicalDropViewToPhysicalDropView : public Rule { OptimizationContext *context) const override; }; +/** + * Rule transforms Logical DropFunction -> Physical DropFunction + */ +class LogicalDropFunctionToPhysicalDropFunction : public Rule { + public: + /** Constructor */ + LogicalDropFunctionToPhysicalDropFunction(); + + /** + * Checks whether the given rule can be applied + * @param plan AbstractOptimizerNode to check + * @param context Current OptimizationContext executing under + * @returns Whether the input AbstractOptimizerNode passes the check + */ + bool Check(common::ManagedPointer plan, OptimizationContext *context) const override; + + /** + * Transforms the input expression using the given rule + * @param input Input AbstractOptimizerNode to transform + * @param transformed Vector of transformed AbstractOptimizerNodes + * @param context Current OptimizationContext executing under + */ + void Transform(common::ManagedPointer input, + std::vector> *transformed, + OptimizationContext *context) const override; +}; + /** * Rule transforms Logical Analyze -> Physical Analyze */ diff --git a/src/include/parser/create_function_statement.h b/src/include/parser/create_function_statement.h index e20b74b5cc..be620b71c9 100644 --- a/src/include/parser/create_function_statement.h +++ b/src/include/parser/create_function_statement.h @@ -9,13 +9,9 @@ #include "expression/abstract_expression.h" #include "parser/sql_statement.h" -// TODO(WAN): this file is messy -namespace noisepage { -namespace parser { +namespace noisepage::parser { /** Base function parameter. */ struct BaseFunctionParameter { - // TODO(WAN): there used to be a FuncParamMode that was never used? - /** Parameter data types. */ enum class DataType { INT, @@ -30,7 +26,8 @@ struct BaseFunctionParameter { VARCHAR, TEXT, BOOL, - BOOLEAN + BOOLEAN, + DATE }; /** @param datatype data type of the parameter */ @@ -41,6 +38,44 @@ struct BaseFunctionParameter { /** @return data type of the parameter */ DataType GetDataType() { return datatype_; } + /** @return internal type id of the parameter */ + static execution::sql::SqlTypeId DataTypeToTypeId(DataType datatype) { + switch (datatype) { + case DataType::INT: + return execution::sql::SqlTypeId::Integer; + case DataType::INTEGER: + return execution::sql::SqlTypeId::Integer; + case DataType::TINYINT: + return execution::sql::SqlTypeId::TinyInt; + case DataType::SMALLINT: + return execution::sql::SqlTypeId::SmallInt; + case DataType::BIGINT: + return execution::sql::SqlTypeId::BigInt; + case DataType::CHAR: + return execution::sql::SqlTypeId::Invalid; + case DataType::FLOAT: + // NOTE(Kyle): The "regular" SQL frontend automatically + // promotes FLOAT / REAL to DOUBLE PRECISION / FLOAT8; + // we do the same here to remain consistent + return execution::sql::SqlTypeId::Double; + case DataType::DOUBLE: + return execution::sql::SqlTypeId::Double; + case DataType::DECIMAL: + return execution::sql::SqlTypeId::Decimal; + case DataType::VARCHAR: + return execution::sql::SqlTypeId::Varchar; + case DataType::TEXT: + return execution::sql::SqlTypeId::Varchar; + case DataType::BOOL: + return execution::sql::SqlTypeId::Boolean; + case DataType::BOOLEAN: + return execution::sql::SqlTypeId::Boolean; + case DataType::DATE: + return execution::sql::SqlTypeId::Date; + } + return execution::sql::SqlTypeId::Invalid; + } + private: const DataType datatype_; }; @@ -98,29 +133,19 @@ class CreateFunctionStatement : public SQLStatement { void Accept(common::ManagedPointer v) override { v->Visit(common::ManagedPointer(this)); } - /** - * @return true if this function should replace existing definitions - */ + /** @return `true` if this function should replace existing definitions */ bool ShouldReplace() { return replace_; } - /** - * @return function name - */ + /** @return The function name */ std::string GetFuncName() { return func_name_; } - /** - * @return return type - */ + /** @return The function return type */ common::ManagedPointer GetFuncReturnType() { return common::ManagedPointer(return_type_); } - /** - * @return function body - */ + /** @return The function body */ std::vector GetFuncBody() { return func_body_; } - /** - * @return function parameters - */ + /** @return The function parameters */ std::vector> GetFuncParameters() { std::vector> params; params.reserve(func_parameters_.size()); @@ -130,14 +155,10 @@ class CreateFunctionStatement : public SQLStatement { return params; } - /** - * @return programming language type - */ + /** @return The programming language type */ PLType GetPLType() { return pl_type_; } - /** - * @return as type (executable or query string) - */ + /** @return As type (executable or query string) */ AsType GetAsType() { return as_type_; } private: @@ -150,5 +171,4 @@ class CreateFunctionStatement : public SQLStatement { const AsType as_type_; }; -} // namespace parser -} // namespace noisepage +} // namespace noisepage::parser diff --git a/src/include/parser/drop_statement.h b/src/include/parser/drop_statement.h index 946cd6aa29..69e773fbd8 100644 --- a/src/include/parser/drop_statement.h +++ b/src/include/parser/drop_statement.h @@ -3,6 +3,7 @@ #include #include #include +#include #include "binder/sql_node_visitor.h" #include "parser/sql_statement.h" @@ -15,7 +16,7 @@ namespace parser { class DropStatement : public TableRefStatement { public: /** Drop statement type. */ - enum class DropType { kDatabase, kTable, kSchema, kIndex, kView, kPreparedStatement, kTrigger }; + enum class DropType { kDatabase, kTable, kSchema, kIndex, kView, kPreparedStatement, kTrigger, kFunction }; /** * DROP DATABASE, DROP TABLE @@ -36,6 +37,21 @@ class DropStatement : public TableRefStatement { type_(DropType::kIndex), index_name_(std::move(index_name)) {} + /** + * DROP FUNCTION + * @param table_info table information + * @param function_name function name + * @param function_args function argument type identifiers + * @param if_exists `true` if `IF EXISTS` specified, `false` otherwise + */ + DropStatement(std::unique_ptr table_info, std::string function_name, + std::vector &&function_args, bool if_exists) + : TableRefStatement(StatementType::DROP, std::move(table_info)), + type_(DropType::kFunction), + if_exists_(if_exists), + function_name_(std::move(function_name)), + function_args_(std::move(function_args)) {} + /** * DROP SCHEMA * @param table_info table information @@ -79,10 +95,19 @@ class DropStatement : public TableRefStatement { /** @return trigger name for [DROP TRIGGER] */ std::string GetTriggerName() { return trigger_name_; } + /** @return function name for [DROP FUNCTION] */ + std::string GetFunctionName() { return function_name_; } + + /** @return function argument types for [DROP FUNCTION] */ + const std::vector &GetFunctionArguments() const { return function_args_; } + private: const DropType type_; - // DROP DATABASE, SCHEMA + // TODO(Kyle): Maybe use a std::variant here to make + // the overloading of this type less wasteful? + + // DROP DATABASE, SCHEMA, FUNCTION const bool if_exists_ = false; // DROP INDEX @@ -93,6 +118,10 @@ class DropStatement : public TableRefStatement { // DROP TRIGGER const std::string trigger_name_; + + // DROP FUNCTION + const std::string function_name_; + std::vector function_args_; }; } // namespace parser diff --git a/src/include/parser/expression/column_value_expression.h b/src/include/parser/expression/column_value_expression.h index bd445f5682..9429399479 100644 --- a/src/include/parser/expression/column_value_expression.h +++ b/src/include/parser/expression/column_value_expression.h @@ -36,6 +36,9 @@ class ColumnValueExpression : public AbstractExpression { friend class noisepage::TpccPlanTest; public: + /** Denotes an invalid parameter index */ + static constexpr const std::int32_t INVALID_PARAM_INDEX{-1}; + /** * This constructor is called only in postgresparser, setting the column name, * and optionally setting the table name and alias. @@ -146,6 +149,12 @@ class ColumnValueExpression : public AbstractExpression { /** @return column oid */ catalog::col_oid_t GetColumnOid() const { return column_oid_; } + /** @return The parameter index */ + std::int32_t GetParamIdx() const { return param_idx_; } + + /** @brief Set the parameter index */ + void SetParamIdx(const std::size_t param_idx) { param_idx_ = static_cast(param_idx); } + /** * Get Column Full Name [tbl].[col] */ @@ -206,17 +215,22 @@ class ColumnValueExpression : public AbstractExpression { private: friend class binder::BinderContext; friend class execution::sql::TableGenerator; + /** @param database_oid Database OID to be assigned to this expression */ void SetDatabaseOID(catalog::db_oid_t database_oid) { database_oid_ = database_oid; } + /** @param table_oid Table OID to be assigned to this expression */ void SetTableOID(catalog::table_oid_t table_oid) { table_oid_ = table_oid; } + /** @param column_oid Column OID to be assigned to this expression */ void SetColumnOID(catalog::col_oid_t column_oid) { column_oid_ = column_oid; } + /** @param column_oid Column OID to be assigned to this expression */ void SetColumnName(const std::string &col_name) { column_name_ = std::string(col_name); } - /** Table name. */ + /** Table alias. */ AliasType table_alias_; + /** Column name. */ std::string column_name_; @@ -228,6 +242,9 @@ class ColumnValueExpression : public AbstractExpression { /** OID of the column */ catalog::col_oid_t column_oid_ = catalog::INVALID_COLUMN_OID; + + /** parameter index */ + std::int32_t param_idx_{INVALID_PARAM_INDEX}; }; DEFINE_JSON_HEADER_DECLARATIONS(ColumnValueExpression); diff --git a/src/include/parser/expression/constant_value_expression.h b/src/include/parser/expression/constant_value_expression.h index 73edabfdc6..c1f3230d1d 100644 --- a/src/include/parser/expression/constant_value_expression.h +++ b/src/include/parser/expression/constant_value_expression.h @@ -78,13 +78,22 @@ class ConstantValueExpression : public AbstractExpression { */ ConstantValueExpression(const ConstantValueExpression &other); + /** + * Compute a hash for the expression. + * @return The hash value + */ common::hash_t Hash() const override; + /** + * Equality comparison. + * @param other The other ConstantValueExpression instance + * @return `true` if the instances are equivalent, `false` otherwise + */ bool operator==(const AbstractExpression &other) const override; /** * Copies this ConstantValueExpression - * @returns copy of this + * @returns A copy of `this` */ std::unique_ptr Copy() const override { return std::unique_ptr{std::make_unique(*this)}; @@ -102,62 +111,57 @@ class ConstantValueExpression : public AbstractExpression { return Copy(); } + /** Derive the name of the expression if it is not present */ void DeriveExpressionName() override { if (!this->GetAliasName().empty()) { this->SetExpressionName(this->GetAliasName()); } } - /** - * @return copy of the underlying Val - */ + /** @return The expression value as a generic SQL value */ + common::ManagedPointer GetVal() const { + NOISEPAGE_ASSERT(std::holds_alternative(value_), "GetVal() bad variant access"); + return common::ManagedPointer(&std::get(value_)); + } + + /** @return A copy of the underlying Val */ execution::sql::BoolVal GetBoolVal() const { NOISEPAGE_ASSERT(std::holds_alternative(value_), "Invalid variant type for Get."); return std::get(value_); } - /** - * @return copy of the underlying Val - */ + /** @return A copy of the underlying Val */ execution::sql::Integer GetInteger() const { NOISEPAGE_ASSERT(std::holds_alternative(value_), "Invalid variant type for Get."); return std::get(value_); } - /** - * @return copy of the underlying Val - */ + /** @return A copy of the underlying Val */ execution::sql::Real GetReal() const { NOISEPAGE_ASSERT(std::holds_alternative(value_), "Invalid variant type for Get."); return std::get(value_); } - /** - * @return copy of underlying Val - */ + /** @return A copy of underlying Val */ execution::sql::DecimalVal GetDecimalVal() const { NOISEPAGE_ASSERT(std::holds_alternative(value_), "Invalid variant type for Get."); return std::get(value_); } - /** - * @return copy of the underlying Val - */ + /** @return A copy of the underlying Val */ execution::sql::DateVal GetDateVal() const { NOISEPAGE_ASSERT(std::holds_alternative(value_), "Invalid variant type for Get."); return std::get(value_); } - /** - * @return copy of the underlying Val - */ + /** @return A copy of the underlying Val */ execution::sql::TimestampVal GetTimestampVal() const { NOISEPAGE_ASSERT(std::holds_alternative(value_), "Invalid variant type for Get."); return std::get(value_); } /** - * @return copy of the underlying Val + * @return A copy of the underlying Val * @warning StringVal may not have inlined its value, in which case the StringVal returned by this function will hold * a pointer to the buffer in this CVE. In that case, do not destroy this CVE before the copied StringVal */ @@ -194,9 +198,7 @@ class ConstantValueExpression : public AbstractExpression { Validate(); } - /** - * @return true if CVE value represents a NULL - */ + /** @return `true` if CVE value represents a NULL, `false` otherwise */ bool IsNull() const { if (std::holds_alternative(value_) && std::get(value_).is_null_) return true; @@ -229,15 +231,25 @@ class ConstantValueExpression : public AbstractExpression { } /** - * Extracts the underlying execution value as a C++ type + * Extracts the underlying execution value as a C++ type. * @tparam T C++ type to extract * @return copy of the underlying value as the requested type - * @warning std::string_view returned by this function will hold a pointer to the buffer in this CVE. In that case, do - * not destroy this CVE before the std::string_view + * @warning std::string_view returned by this function will hold a pointer to the buffer in this CVE. + * In that case, do not destroy this CVE before the std::string_view */ template T Peek() const; + /** + * Get a pointer to the underlying value as a generic SQL type. + * @return An immutable pointer to the underlying value + */ + const execution::sql::Val *SqlValue() const; + + /** + * Visitor pattern for binder. + * @param v The SqlNodeVisitor + */ void Accept(common::ManagedPointer v) override; /** @return A string representation of this ConstantValueExpression. */ @@ -259,10 +271,13 @@ class ConstantValueExpression : public AbstractExpression { private: friend class binder::BindNodeVisitor; /* value_ may be modified, e.g., when parsing dates. */ void Validate() const; + + // The underlying constant value std::variant value_{execution::sql::Val(true)}; + // Buffer for inlined string values std::unique_ptr buffer_ = nullptr; }; diff --git a/src/include/parser/expression/subquery_expression.h b/src/include/parser/expression/subquery_expression.h index bcd1436382..28431171ba 100644 --- a/src/include/parser/expression/subquery_expression.h +++ b/src/include/parser/expression/subquery_expression.h @@ -42,9 +42,12 @@ class SubqueryExpression : public AbstractExpression { return Copy(); } - /** @return managed pointer to the sub-select */ + /** @return A non-owning pointer to the sub-select */ common::ManagedPointer GetSubselect() { return common::ManagedPointer(subselect_); } + /** @return An owning pointer to the sub-select */ + std::unique_ptr ReleaseSubselect() { return std::move(subselect_); } + void Accept(common::ManagedPointer v) override { v->Visit(common::ManagedPointer(this)); } /** diff --git a/src/include/parser/expression/type_cast_expression.h b/src/include/parser/expression/type_cast_expression.h index 13260ecab4..27f3637b5a 100644 --- a/src/include/parser/expression/type_cast_expression.h +++ b/src/include/parser/expression/type_cast_expression.h @@ -36,6 +36,7 @@ class TypeCastExpression : public AbstractExpression { * @returns copy of this */ std::unique_ptr Copy() const override; + /** * Creates a copy of the current AbstractExpression with new children implanted. * The children should not be owned by any other AbstractExpression. diff --git a/src/include/parser/nodes.h b/src/include/parser/nodes.h index c004a40ef0..f534bf7db4 100644 --- a/src/include/parser/nodes.h +++ b/src/include/parser/nodes.h @@ -15,3 +15,18 @@ using value = struct Value { char *str_; /**< string */ } val_; /**< value */ }; + +/** + * A typename parsenode as produced by the Postgres parser + */ +using typname = struct TypName { + NodeTag type_; + List *names_; + Oid type_oid_; + bool setof_; + bool pct_type_; + List *typmods_; + int32_t typemod_; + List *array_bounds_; + int location_; +}; diff --git a/src/include/parser/parse_result.h b/src/include/parser/parse_result.h index 8a82aa98c2..6013b0f489 100644 --- a/src/include/parser/parse_result.h +++ b/src/include/parser/parse_result.h @@ -52,10 +52,15 @@ class ParseResult { */ uint32_t NumStatements() const { return statements_.size(); } - /** - * @return the statement at a particular index - */ - common::ManagedPointer GetStatement(size_t idx) { return common::ManagedPointer(statements_[idx]); } + /** @return The statement at index `index`*/ + common::ManagedPointer GetStatement(std::size_t idx) { + return common::ManagedPointer(statements_[idx]); + } + + /** @return The statement at a index `index` */ + common::ManagedPointer GetStatement(std::size_t idx) const { + return common::ManagedPointer(statements_.at(idx).get()); + } /** * @return non-owning list of all the expressions contained in this parse result diff --git a/src/include/parser/postgresparser.h b/src/include/parser/postgresparser.h index cac171fd4f..c66116303a 100644 --- a/src/include/parser/postgresparser.h +++ b/src/include/parser/postgresparser.h @@ -77,6 +77,11 @@ class PostgresParser { } } + /** + * Determine if the function identified by `fun_name` is an aggregate function. + * @param fun_name The function name + * @return `true` if the function is an aggregation, `false` otherwise + */ static bool IsAggregateFunction(const std::string &fun_name) { return (fun_name == "min" || fun_name == "max" || fun_name == "count" || fun_name == "avg" || fun_name == "sum"); } @@ -85,16 +90,19 @@ class PostgresParser { * Transforms the entire parsed nodes list into a corresponding SQLStatementList. * @param[in,out] parse_result the current parse result, which will be updated * @param root list of parsed nodes + * @param query_string the query string */ - static void ListTransform(ParseResult *parse_result, List *root); + static void ListTransform(ParseResult *parse_result, List *root, const std::string &query_string); /** * Transforms a single node in the parse list into a noisepage SQLStatement object. * @param[in,out] parse_result the current parse result, which will be updated * @param node parsed node + * @param query_string the query string * @return SQLStatement corresponding to the parsed node */ - static std::unique_ptr NodeTransform(ParseResult *parse_result, Node *node); + static std::unique_ptr NodeTransform(ParseResult *parse_result, Node *node, + const std::string &query_string); static std::unique_ptr ExprTransform(ParseResult *parse_result, Node *node, char *alias); static ExpressionType StringToExpressionType(const std::string &parser_str); @@ -134,7 +142,8 @@ class PostgresParser { // CREATE statements static std::unique_ptr CreateTransform(ParseResult *parse_result, CreateStmt *root); static std::unique_ptr CreateDatabaseTransform(ParseResult *parse_result, CreateDatabaseStmt *root); - static std::unique_ptr CreateFunctionTransform(ParseResult *parse_result, CreateFunctionStmt *root); + static std::unique_ptr CreateFunctionTransform(ParseResult *parse_result, CreateFunctionStmt *root, + const std::string &query_string); static std::unique_ptr CreateIndexTransform(ParseResult *parse_result, IndexStmt *root); static std::unique_ptr CreateSchemaTransform(ParseResult *parse_result, CreateSchemaStmt *root); static std::unique_ptr CreateTriggerTransform(ParseResult *parse_result, CreateTrigStmt *root); @@ -160,6 +169,7 @@ class PostgresParser { // DROP statements static std::unique_ptr DropTransform(ParseResult *parse_result, DropStmt *root); static std::unique_ptr DropDatabaseTransform(ParseResult *parse_result, DropDatabaseStmt *root); + static std::unique_ptr DropFunctionTransform(ParseResult *parse_result, DropStmt *root); static std::unique_ptr DropIndexTransform(ParseResult *parse_result, DropStmt *root); static std::unique_ptr DropSchemaTransform(ParseResult *parse_result, DropStmt *root); static std::unique_ptr DropTableTransform(ParseResult *parse_result, DropStmt *root); @@ -173,7 +183,8 @@ class PostgresParser { List *root); // EXPLAIN statements - static std::unique_ptr ExplainTransform(ParseResult *parse_result, ExplainStmt *root); + static std::unique_ptr ExplainTransform(ParseResult *parse_result, ExplainStmt *root, + const std::string &query_string); // INSERT statements static std::unique_ptr InsertTransform(ParseResult *parse_result, InsertStmt *root); @@ -184,7 +195,8 @@ class PostgresParser { ParseResult *parse_result, List *root); // PREPARE statements - static std::unique_ptr PrepareTransform(ParseResult *parse_result, PrepareStmt *root); + static std::unique_ptr PrepareTransform(ParseResult *parse_result, PrepareStmt *root, + const std::string &query_string); static std::unique_ptr TruncateTransform(ParseResult *parse_result, TruncateStmt *truncate_stmt); diff --git a/src/include/parser/select_statement.h b/src/include/parser/select_statement.h index 21c748a384..9a92885f2a 100644 --- a/src/include/parser/select_statement.h +++ b/src/include/parser/select_statement.h @@ -355,7 +355,7 @@ class SelectStatement : public SQLStatement { std::unique_ptr Copy(); /** @return The columns targeted by SELECT */ - const std::vector> &GetSelectColumns() { return select_; } + const std::vector> &GetSelectColumns() const { return select_; } /** @return `true` if "SELECT DISTINCT", `false` otherwise */ bool IsSelectDistinct() const { return select_distinct_; } @@ -462,7 +462,7 @@ class SelectStatement : public SQLStatement { // The depth of the SELECT statement int depth_{-1}; - // A colletion of the temporary tables (CTEs) available to this SELECT + // A collection of the temporary tables (CTEs) available to this SELECT std::vector> with_table_; /** @param select List of SELECT columns */ diff --git a/src/include/parser/udf/plpgsql_parse_result.h b/src/include/parser/udf/plpgsql_parse_result.h new file mode 100644 index 0000000000..39457ab584 --- /dev/null +++ b/src/include/parser/udf/plpgsql_parse_result.h @@ -0,0 +1,35 @@ +#pragma once + +#include + +#include "libpg_query/pg_query.h" + +namespace noisepage::parser::udf { + +/** + * The PLpgSQLParseResult class is a simple RAII + * wrapper for the parse result returned by libpq_query. + * + * NOTE: Could just do this with a std::unique_ptr with + * a default deleter, but this is more pleasant. + */ +class PLpgSQLParseResult { + public: + /** + * Construct a new PLpgSQLParseResult instance. + * @param result The raw result + */ + explicit PLpgSQLParseResult(PgQueryPlpgsqlParseResult &&result) : result_{result} {} + + /** Release resources from the parse result */ + ~PLpgSQLParseResult() { pg_query_free_plpgsql_parse_result(result_); } + + /** @return An immutable reference to the underlying result */ + const PgQueryPlpgsqlParseResult &operator*() const { return result_; } + + private: + /** The underlying parse result */ + PgQueryPlpgsqlParseResult result_; +}; + +} // namespace noisepage::parser::udf diff --git a/src/include/parser/udf/plpgsql_parser.h b/src/include/parser/udf/plpgsql_parser.h new file mode 100644 index 0000000000..3ff99ec376 --- /dev/null +++ b/src/include/parser/udf/plpgsql_parser.h @@ -0,0 +1,255 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "catalog/catalog_accessor.h" +#include "execution/ast/udf/udf_ast_context.h" +#include "execution/ast/udf/udf_ast_nodes.h" + +#include "parser/expression_util.h" +#include "parser/postgresparser.h" +#include "parser/sql_statement.h" + +namespace noisepage::parser { +class SQLStatement; +} // namespace noisepage::parser + +namespace noisepage::execution::ast::udf { +class FunctionAST; +} // namespace noisepage::execution::ast::udf + +namespace noisepage::parser::udf { + +/** An enumeration over the supported PL/pgSQL statement types */ +enum class StatementType { UNKNOWN, RETURN, IF, ASSIGN, WHILE, FORI, FORS, EXECSQL, DYNEXECUTE }; + +/** + * The PLpgSQLParser class parses source PL/pgSQL to an abstract syntax tree. + * + * Internally, PLpgSQLParser utilizes libpg_query to perform the actual parsing + * of the input PL/pgSQL source, and then maps the representation from libpg_query + * to our our internal representation that then proceeds through code generation. + */ +class PLpgSQLParser { + public: + /** + * Construct a new PLpgSQLParser instance. + * @param udf_ast_context The AST context + */ + explicit PLpgSQLParser(common::ManagedPointer udf_ast_context) + : udf_ast_context_{udf_ast_context} {} + + /** + * Parse source PL/pgSQL to an abstract syntax tree. + * @param param_names The names of the function parameters + * @param param_types The types of the function parameters + * @param func_body The input source for the function + * @return The abstract syntax tree for the source function + */ + std::unique_ptr Parse(const std::vector ¶m_names, + const std::vector ¶m_types, + const std::string &func_body); + + private: + /** + * Parse a block statement. + * @param json The input JSON object + * @return The AST for the block + */ + std::unique_ptr ParseBlock(const nlohmann::json &json); + + /** + * Parse a return statement. + * @param json The input JSON object + * @return The AST for the return statement + */ + std::unique_ptr ParseReturn(const nlohmann::json &json); + + /** + * Parse a function statement. + * @param json The input JSON object + * @return json AST for the function + */ + std::unique_ptr ParseFunction(const nlohmann::json &json); + + /** + * Parse a declaration statement. + * @param json The input JSON object + * @return The AST for the declaration + */ + std::unique_ptr ParseDecl(const nlohmann::json &json); + + /** + * Parse an assignment statement. + * @param json The input JSON object + * @return The AST for the assignment + */ + std::unique_ptr ParseAssign(const nlohmann::json &json); + + /** + * Parse an if-statement. + * @param json The input JSON object + * @return The AST for the if-statement + */ + std::unique_ptr ParseIf(const nlohmann::json &json); + + /** + * Parse a while-statement. + * @param json The input JSON object + * @return The AST for the while-statement + */ + std::unique_ptr ParseWhile(const nlohmann::json &json); + + /** + * Parse a for-statement (integer variant). + * @param json The input JSON object + * @return The AST for the for-statement + */ + std::unique_ptr ParseForI(const nlohmann::json &json); + + /** + * Parse a for-statement (query variant). + * @param json The input JSON object + * @return The AST for the for-statement + */ + std::unique_ptr ParseForS(const nlohmann::json &json); + + /** + * Parse a SQL statement. + * @param json The input JSON object + * @return The AST for the SQL statement + */ + std::unique_ptr ParseExecSQL(const nlohmann::json &json); + + /** + * Parse a SQL statement. + * @param sql The input SQL query text + * @param variables The collection of variables to which results are bound + * @return The AST for the SQL statement + */ + std::unique_ptr ParseExecSQL(const std::string &sql, + std::vector &&variables); + + /** + * Parse a dynamic SQL statement. + * @param json The input JSON object + * @return The AST for the dynamic SQL statement + */ + std::unique_ptr ParseDynamicSQL(const nlohmann::json &json); + + /** + * Parse a SQL expression from a query string. + * @param sql The SQL expression string + * @return The AST for the SQL expression + */ + std::unique_ptr ParseExprFromSQLString(const std::string &sql); + + /** + * Try to parse a SQL expression from a query string. If the expression + * type is not supported, indicate failure with an empty std::optional. + * @param sql The SQL expression string + * @return The AST for the SQL expression on success, empty std::optional on failure + */ + std::optional> TryParseExprFromSQLString( + const std::string &sql) noexcept; + + /** + * Parse a SQL expression from a SQL statement. + * @param statement The SQL statement + * @return The AST for the SQL statement + */ + std::unique_ptr ParseExprFromSQLStatement( + common::ManagedPointer statement); + + /** + * Try to parse an abstract expression from a SQL statement. If the statement + * type is not supported, indicate failure with an empty std::optional. + * @param statement The input SQL statement + * @return The AST for the statement on success, empty std::optional on failure + */ + std::optional> TryParseExprFromSQLStatement( + common::ManagedPointer statement) noexcept; + + /** + * Parse an abstract expression to an expression AST. + * @param expr The abstract expression + * @return The AST for the expression + */ + std::unique_ptr ParseExprFromAbstract( + common::ManagedPointer expr); + + /** + * Try to parse an abstract expression to an expression AST. If the expression + * type is not supported, indicate failure with an empty std::optional. + * @param expr The input expression + * @return The AST for the expression on success, empty std::optional on failure + */ + std::optional> TryParseExprFromAbstract( + common::ManagedPointer expr) noexcept; + + private: + /** + * Determine if all variables in `names` are declared in the function. + * @param names The collection of variable identifiers + * @return `true` if all variables are declared, `false` otherwise + */ + bool AllVariablesDeclared(const std::vector &names) const; + + /** + * Determine if any of the variables in `names` refer to a RECORD type. + * @param names The collection of variable identifiers + * @return `true` if any of the variables in `names` refer + * to a RECORD type previously declared, `false` otherwise + */ + bool ContainsRecordType(const std::vector &names) const; + + /** + * Resolve a PL/pgSQL RECORD type from a SELECT statement. + * @param parse_result The result of parsing the SQL query + * @return The resolved record type + */ + std::vector> ResolveRecordType(const ParseResult *parse_result); + + /** + * Get the StatementType for the provided statement type identifier. + * @param type The identifier for the statement type + * @return The corresponding StatementType + */ + static StatementType GetStatementType(const std::string &type); + + /** + * Strip an enclosing SELECT query from an existing ParseResult. + * @param input The existing ParseResult + * @return A new ParseResult with the enclosing query stripped + */ + static std::unique_ptr StripEnclosingQuery(std::unique_ptr &&input); + + /** + * Determine if the parsed query has an enclosing "wrapper" query + * introduced by the PL/pgSQL parser. + * @param parse_result The parsed query + * @return `true` if the query has an enclosing query, `false` otherwise + */ + static bool HasEnclosingQuery(ParseResult *parse_result); + + /** + * Get the internal type identifier for given type name. + * @param type_name The typename + * @return The type identifier for the type, or empty std::optional + * in the case of an unsupported or unrecognized type + */ + static std::optional TypeNameToType(const std::string &type_name); + + private: + /** The UDF AST context */ + common::ManagedPointer udf_ast_context_; + /** The function symbol table */ + std::unordered_map symbol_table_; +}; + +} // namespace noisepage::parser::udf diff --git a/src/include/parser/udf/string_utils.h b/src/include/parser/udf/string_utils.h new file mode 100644 index 0000000000..94165545c9 --- /dev/null +++ b/src/include/parser/udf/string_utils.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +namespace noisepage::parser::udf { + +/** + * StringUtils is a static class that implements some basic + * string-processing utilities. Eventually, we might want to + * move functionality like this to our own internal algo library. + */ +class StringUtils { + public: + /** + * Convert a non-owned string to lowercase. + * @param string The input string + * @return The lowercased string + */ + static std::string Lower(const std::string &string); + + /** + * Strip whitespace from the start and end of a non-owned string. + * @param string The input string + * @return The stripped string + */ + static std::string Strip(const std::string &string); +}; + +} // namespace noisepage::parser::udf diff --git a/src/include/parser/udf/variable_ref.h b/src/include/parser/udf/variable_ref.h new file mode 100644 index 0000000000..dcd3434e33 --- /dev/null +++ b/src/include/parser/udf/variable_ref.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include + +#include "common/macros.h" + +namespace noisepage::parser::udf { + +/** + * The VariableRefType enumeration defines the + * valid types of variable references. + */ +enum class VariableRefType { SCALAR, TABLE }; + +/** + * The VariableRef type represents a UDF variable reference + * within a SQL query. It is used during binding to identify + * and track the query parameters that must be read from the + * UDF environment prior to query execution. + */ +class VariableRef { + public: + /** + * Construct a new VariableRef instance for a TABLE reference. + * @param table_name The name of the table + * @param column_name The name of the column + * @param index The index + */ + VariableRef(std::string table_name, std::string column_name, std::size_t index) + : type_{VariableRefType::TABLE}, + table_name_{std::move(table_name)}, + column_name_{std::move(column_name)}, + index_{index} {} + + /** + * Construct a new VariableRef instance for a SCALAR reference. + * @param column_name The name of the column + * @param index The index + */ + VariableRef(std::string column_name, std::size_t index) + : type_{VariableRefType::SCALAR}, table_name_{}, column_name_{std::move(column_name)}, index_{index} {} + + /** @return `true` if this is a SCALAR variable reference, `false` otherwise */ + bool IsScalar() const { return type_ == VariableRefType::SCALAR; } + + /** @return The table name of the variable reference */ + const std::string &TableName() const { + NOISEPAGE_ASSERT(!IsScalar(), "SCALAR variable references do not have associated table names"); + return table_name_; + } + + /** @return The column name of the variable reference */ + const std::string &ColumnName() const { return column_name_; } + + /** @return The index of the variable reference */ + std::size_t Index() const { return index_; } + + /** @return A string representation of the variable reference */ + std::string ToString() const { return fmt::format("{} {} {}", table_name_.c_str(), column_name_.c_str(), index_); } + + private: + /** The type of this variable reference */ + const VariableRefType type_; + /** The table name associated with this variable reference (may be empty) */ + const std::string table_name_; + /** The column name associated with this variable reference */ + const std::string column_name_; + /** The index of the reference in the query */ + const std::size_t index_; +}; + +} // namespace noisepage::parser::udf diff --git a/src/include/planner/plannodes/abstract_plan_node.h b/src/include/planner/plannodes/abstract_plan_node.h index 44b487434d..a07c20476f 100644 --- a/src/include/planner/plannodes/abstract_plan_node.h +++ b/src/include/planner/plannodes/abstract_plan_node.h @@ -166,7 +166,9 @@ class AbstractPlanNode { * @return output schema for the node. The output schema contains information on columns of the output of the plan * node operator */ - common::ManagedPointer GetOutputSchema() const { return common::ManagedPointer(output_schema_); } + common::ManagedPointer GetOutputSchema() const { + return common::ManagedPointer(output_schema_.get()); + } //===--------------------------------------------------------------------===// // Add child diff --git a/src/include/planner/plannodes/create_function_plan_node.h b/src/include/planner/plannodes/create_function_plan_node.h index 2a3cac21d8..589872cd2b 100644 --- a/src/include/planner/plannodes/create_function_plan_node.h +++ b/src/include/planner/plannodes/create_function_plan_node.h @@ -13,7 +13,7 @@ namespace noisepage::planner { /** - * Plan node for creating user defined functions + * Plan node for creating user-defined functions */ class CreateFunctionPlanNode : public AbstractPlanNode { public: @@ -243,12 +243,12 @@ class CreateFunctionPlanNode : public AbstractPlanNode { /** * @return parameter names of the user defined function */ - std::vector GetFunctionParameterNames() const { return function_param_names_; } + const std::vector &GetFunctionParameterNames() const { return function_param_names_; } /** * @return parameter types of the user defined function */ - std::vector GetFunctionParameterTypes() const { + const std::vector &GetFunctionParameterTypes() const { return function_param_types_; } diff --git a/src/include/planner/plannodes/drop_function_plan_node.h b/src/include/planner/plannodes/drop_function_plan_node.h new file mode 100644 index 0000000000..1b58493523 --- /dev/null +++ b/src/include/planner/plannodes/drop_function_plan_node.h @@ -0,0 +1,126 @@ +#pragma once + +#include +#include + +#include "parser/drop_statement.h" +#include "parser/parser_defs.h" +#include "planner/plannodes/abstract_plan_node.h" +#include "planner/plannodes/plan_visitor.h" + +namespace noisepage::planner { + +/** + * Plan node for dropping user-defined functions. + */ +class DropFunctionPlanNode : public AbstractPlanNode { + public: + /** + * Builder for an create function plan node + */ + class Builder : public AbstractPlanNode::Builder { + public: + Builder() = default; + + /** + * Don't allow builder to be copied or moved + */ + DISALLOW_COPY_AND_MOVE(Builder); + + /** + * @param database_oid The OID of the database + * @return builder object + */ + Builder &SetDatabaseOid(catalog::db_oid_t database_oid) { + database_oid_ = database_oid; + return *this; + } + + /** + * @param proc_oid The OID of the procedure + * @return builder object + */ + Builder &SetProcedureOid(catalog::proc_oid_t proc_oid) { + proc_oid_ = proc_oid; + return *this; + } + + /** + * @param if_exists `true` if `IF EXISTS` is specified + * @return builder object + */ + Builder &SetIfExists(bool if_exists) { + if_exists_ = if_exists; + return *this; + } + + /** + * Build the drop function plan node + * @return plan node + */ + std::unique_ptr Build(); + + protected: + /** OID of the database */ + catalog::db_oid_t database_oid_; + /** OID of the procedure */ + catalog::proc_oid_t proc_oid_; + /** `true` if `IF EXISTS` specified */ + bool if_exists_; + }; + + private: + /** + * @param children child plan nodes + * @param output_schema Schema representing the structure of the output of this plan node + * @param database_oid OID of the database + * @param proc_oid OID of the procedure + * @param if_exists `true` if `IF EXISTS` specified + * @param plan_node_id Plan node ID + */ + DropFunctionPlanNode(std::vector> &&children, + std::unique_ptr output_schema, catalog::db_oid_t database_oid, + catalog::proc_oid_t proc_oid, bool if_exists, plan_node_id_t plan_node_id); + + public: + /** Default constructor used for deserialization */ + DropFunctionPlanNode() = default; + + DISALLOW_COPY_AND_MOVE(DropFunctionPlanNode) + + /** @return the type of this plan node */ + PlanNodeType GetPlanNodeType() const override { return PlanNodeType::DROP_FUNC; } + + /** @return OID of the database */ + catalog::db_oid_t GetDatabaseOid() const { return database_oid_; } + + /** @return OID of the procedure */ + catalog::proc_oid_t GetProcedureOid() const { return proc_oid_; } + + /** @return `true` if `IF EXISTS` is specified */ + bool GetIfExists() const { return if_exists_; } + + /** @return the hashed value of this plan node */ + common::hash_t Hash() const override; + + bool operator==(const AbstractPlanNode &rhs) const override; + + void Accept(common::ManagedPointer v) const override { v->Visit(this); } + + /** Serialize to JSON representation */ + nlohmann::json ToJson() const override; + /** Deserialize from JSON representation */ + std::vector> FromJson(const nlohmann::json &j) override; + + private: + /** OID of database */ + catalog::db_oid_t database_oid_; + /** OID of procedure */ + catalog::proc_oid_t proc_oid_; + /** `true` if `IF EXISTS` specified */ + bool if_exists_; +}; + +DEFINE_JSON_HEADER_DECLARATIONS(DropFunctionPlanNode); + +} // namespace noisepage::planner diff --git a/src/include/planner/plannodes/plan_node_defs.h b/src/include/planner/plannodes/plan_node_defs.h index 4e71d20da7..f79c807153 100644 --- a/src/include/planner/plannodes/plan_node_defs.h +++ b/src/include/planner/plannodes/plan_node_defs.h @@ -62,6 +62,7 @@ enum class PlanNodeType { DROP_NAMESPACE, DROP_TABLE, DROP_INDEX, + DROP_FUNC, DROP_TRIGGER, DROP_VIEW, ANALYZE, diff --git a/src/include/planner/plannodes/plan_visitor.h b/src/include/planner/plannodes/plan_visitor.h index c055c9587a..0f8fa51ef1 100644 --- a/src/include/planner/plannodes/plan_visitor.h +++ b/src/include/planner/plannodes/plan_visitor.h @@ -19,6 +19,7 @@ class DropNamespacePlanNode; class DropTablePlanNode; class DropTriggerPlanNode; class DropViewPlanNode; +class DropFunctionPlanNode; class ExportExternalFilePlanNode; class HashJoinPlanNode; class IndexJoinPlanNode; @@ -143,6 +144,12 @@ class PlanVisitor { */ virtual void Visit(UNUSED_ATTRIBUTE const DropViewPlanNode *plan) {} + /** + * Visit a DropFunctionPlanNode + * @param plan DropFunctionPlanNode + */ + virtual void Visit(UNUSED_ATTRIBUTE const DropFunctionPlanNode *plan) {} + /** * Visit an ExportExternalFilePlanNode * @param plan ExportExternalFilePlanNode diff --git a/src/include/traffic_cop/traffic_cop.h b/src/include/traffic_cop/traffic_cop.h index 54fea5b7ab..56e2215acd 100644 --- a/src/include/traffic_cop/traffic_cop.h +++ b/src/include/traffic_cop/traffic_cop.h @@ -7,7 +7,7 @@ #include "catalog/catalog_defs.h" #include "common/managed_pointer.h" -#include "execution/vm/vm_defs.h" +#include "execution/vm/execution_mode.h" #include "network/network_defs.h" #include "traffic_cop/traffic_cop_defs.h" #include "transaction/transaction_defs.h" diff --git a/src/include/util/query_exec_util.h b/src/include/util/query_exec_util.h index 5828c6b66b..a89fe42e25 100644 --- a/src/include/util/query_exec_util.h +++ b/src/include/util/query_exec_util.h @@ -11,6 +11,7 @@ #include "execution/compiler/executable_query.h" #include "execution/exec/execution_settings.h" #include "execution/exec_defs.h" +#include "execution/vm/execution_mode.h" #include "planner/plannodes/output_schema.h" namespace noisepage::transaction { @@ -216,20 +217,39 @@ class QueryExecUtil { std::string GetError() { return error_msg_; } private: + /** Reset the error message stored by the instance */ void ResetError(); + + /** + * Set the database OID. + * @param db_oid The database OID + */ void SetDatabase(catalog::db_oid_t db_oid); + /** The transaction manager */ common::ManagedPointer txn_manager_; + /** The catalog accessor */ common::ManagedPointer catalog_; + /** The settings manager */ common::ManagedPointer settings_; + /** The statistics storage */ common::ManagedPointer stats_; + + /** The timeout for query optimizatio */ uint64_t optimizer_timeout_; - /** Database being accessed */ + /** Idenditifer for the database being accessed */ catalog::db_oid_t db_oid_{catalog::INVALID_DATABASE_OID}; + + /** `true` if the QueryExecUtil instance owns the transaction, `false` otherwise */ bool own_txn_ = false; + /** The transaction context */ transaction::TransactionContext *txn_ = nullptr; + /** The query execution mode */ + // TODO(Kyle): Need a way to not just hard-code this value + execution::vm::ExecutionMode execution_mode_{execution::vm::ExecutionMode::Interpret}; + /** * Information about cached executable queries * Assumes that the query string is a unique identifier. @@ -237,9 +257,7 @@ class QueryExecUtil { std::unordered_map> schemas_; std::unordered_map> exec_queries_; - /** - * Stores the most recently encountered error. - */ + /** Stores the most recently encountered error */ std::string error_msg_; }; diff --git a/src/network/noisepage_server.cpp b/src/network/noisepage_server.cpp index de83c1dbd1..b99d456a00 100644 --- a/src/network/noisepage_server.cpp +++ b/src/network/noisepage_server.cpp @@ -139,8 +139,7 @@ void TerrierServer::RunServer() { // Register the network socket. RegisterSocket(); - // Register the Unix domain socket. - RegisterSocket(); + // TODO(Kyle): Removed UNIX domain socket. // Register the ConnectionDispatcherTask. This handles connections to the sockets created above. dispatcher_task_ = thread_registry_->RegisterDedicatedThread( diff --git a/src/network/postgres/postgres_packet_writer.cpp b/src/network/postgres/postgres_packet_writer.cpp index 8f7710f4dd..40b3a133f5 100644 --- a/src/network/postgres/postgres_packet_writer.cpp +++ b/src/network/postgres/postgres_packet_writer.cpp @@ -162,6 +162,9 @@ void PostgresPacketWriter::WriteCommandComplete(const QueryType query_type, cons case QueryType::QUERY_CREATE_SCHEMA: WriteCommandComplete("CREATE SCHEMA"); break; + case QueryType::QUERY_CREATE_FUNCTION: + WriteCommandComplete("CREATE FUNCTION"); + break; case QueryType::QUERY_DROP_DB: WriteCommandComplete("DROP DATABASE"); break; @@ -174,6 +177,9 @@ void PostgresPacketWriter::WriteCommandComplete(const QueryType query_type, cons case QueryType::QUERY_DROP_SCHEMA: WriteCommandComplete("DROP SCHEMA"); break; + case QueryType::QUERY_DROP_FUNCTION: + WriteCommandComplete("DROP FUNCTION"); + break; case QueryType::QUERY_EXPLAIN: WriteCommandComplete("EXPLAIN"); break; diff --git a/src/optimizer/child_property_deriver.cpp b/src/optimizer/child_property_deriver.cpp index 02170aac7e..93664aa220 100644 --- a/src/optimizer/child_property_deriver.cpp +++ b/src/optimizer/child_property_deriver.cpp @@ -268,6 +268,11 @@ void ChildPropertyDeriver::Visit(UNUSED_ATTRIBUTE const DropView *drop_view) { output_.emplace_back(new PropertySet(), std::vector{}); } +void ChildPropertyDeriver::Visit(UNUSED_ATTRIBUTE const DropFunction *drop_function) { + // Operator does not provide any properties + output_.emplace_back(new PropertySet(), std::vector{}); +} + void ChildPropertyDeriver::Visit(UNUSED_ATTRIBUTE const Analyze *analyze) { // Analyze does not provide any properties output_.emplace_back(new PropertySet(), std::vector{new PropertySet()}); diff --git a/src/optimizer/logical_operators.cpp b/src/optimizer/logical_operators.cpp index 69b1e27766..5237e38d76 100644 --- a/src/optimizer/logical_operators.cpp +++ b/src/optimizer/logical_operators.cpp @@ -1212,6 +1212,35 @@ bool LogicalDropView::operator==(const BaseOperatorNodeContents &r) { return if_exists_ == node.if_exists_; } +//===--------------------------------------------------------------------===// +// LogicalDropFunction +//===--------------------------------------------------------------------===// +BaseOperatorNodeContents *LogicalDropFunction::Copy() const { return new LogicalDropFunction(*this); } + +Operator LogicalDropFunction::Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid, bool if_exists) { + auto *op = new LogicalDropFunction(); + op->database_oid_ = database_oid; + op->proc_oid_ = proc_oid; + op->if_exists_ = if_exists; + return Operator(common::ManagedPointer(op)); +} + +common::hash_t LogicalDropFunction::Hash() const { + common::hash_t hash = BaseOperatorNodeContents::Hash(); + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(database_oid_)); + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(proc_oid_)); + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(if_exists_)); + return hash; +} + +bool LogicalDropFunction::operator==(const BaseOperatorNodeContents &r) { + if (r.GetOpType() != OpType::LOGICALDROPFUNCTION) return false; + const LogicalDropFunction &node = *dynamic_cast(&r); + if (database_oid_ != node.database_oid_) return false; + if (proc_oid_ == node.proc_oid_) return false; + return if_exists_ == node.if_exists_; +} + //===--------------------------------------------------------------------===// // LogicalAnalyze //===--------------------------------------------------------------------===// @@ -1385,6 +1414,8 @@ const char *OperatorNodeContents::name = "LogicalDropTrigger template <> const char *OperatorNodeContents::name = "LogicalDropView"; template <> +const char *OperatorNodeContents::name = "LogicalDropFunction"; +template <> const char *OperatorNodeContents::name = "LogicalAnalyze"; template <> const char *OperatorNodeContents::name = "LogicalCteScan"; @@ -1461,6 +1492,8 @@ OpType OperatorNodeContents::type = OpType::LOGICALDROPTRIGG template <> OpType OperatorNodeContents::type = OpType::LOGICALDROPVIEW; template <> +OpType OperatorNodeContents::type = OpType::LOGICALDROPFUNCTION; +template <> OpType OperatorNodeContents::type = OpType::LOGICALANALYZE; template <> OpType OperatorNodeContents::type = OpType::LOGICALCTESCAN; diff --git a/src/optimizer/physical_operators.cpp b/src/optimizer/physical_operators.cpp index e823f0df70..80e52b3858 100644 --- a/src/optimizer/physical_operators.cpp +++ b/src/optimizer/physical_operators.cpp @@ -1332,6 +1332,35 @@ bool DropView::operator==(const BaseOperatorNodeContents &r) { return if_exists_ == node.if_exists_; } +//===--------------------------------------------------------------------===// +// DropFunction +//===--------------------------------------------------------------------===// +BaseOperatorNodeContents *DropFunction::Copy() const { return new DropFunction(*this); } + +Operator DropFunction::Make(catalog::db_oid_t database_oid, catalog::proc_oid_t proc_oid, bool if_exists) { + auto *op = new DropFunction(); + op->database_oid_ = database_oid; + op->proc_oid_ = proc_oid; + op->if_exists_ = if_exists; + return Operator(common::ManagedPointer(op)); +} + +common::hash_t DropFunction::Hash() const { + common::hash_t hash = BaseOperatorNodeContents::Hash(); + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(database_oid_)); + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(proc_oid_)); + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(if_exists_)); + return hash; +} + +bool DropFunction::operator==(const BaseOperatorNodeContents &r) { + if (r.GetOpType() != OpType::DROPFUNCTION) return false; + const DropFunction &node = *dynamic_cast(&r); + if (database_oid_ != node.database_oid_) return false; + if (proc_oid_ != node.proc_oid_) return false; + return if_exists_ == node.if_exists_; +} + //===--------------------------------------------------------------------===// // Analyze //===--------------------------------------------------------------------===// @@ -1473,6 +1502,8 @@ const char *OperatorNodeContents::name = "DropTrigger"; template <> const char *OperatorNodeContents::name = "DropView"; template <> +const char *OperatorNodeContents::name = "DropFunction"; +template <> const char *OperatorNodeContents::name = "Analyze"; template <> const char *OperatorNodeContents::name = "CteScan"; @@ -1555,6 +1586,8 @@ OpType OperatorNodeContents::type = OpType::DROPTRIGGER; template <> OpType OperatorNodeContents::type = OpType::DROPVIEW; template <> +OpType OperatorNodeContents::type = OpType::DROPFUNCTION; +template <> OpType OperatorNodeContents::type = OpType::ANALYZE; template <> OpType OperatorNodeContents::type = OpType::CTESCAN; diff --git a/src/optimizer/plan_generator.cpp b/src/optimizer/plan_generator.cpp index 9ed33e63d6..7cc10642b7 100644 --- a/src/optimizer/plan_generator.cpp +++ b/src/optimizer/plan_generator.cpp @@ -31,6 +31,7 @@ #include "planner/plannodes/cte_scan_plan_node.h" #include "planner/plannodes/delete_plan_node.h" #include "planner/plannodes/drop_database_plan_node.h" +#include "planner/plannodes/drop_function_plan_node.h" #include "planner/plannodes/drop_index_plan_node.h" #include "planner/plannodes/drop_namespace_plan_node.h" #include "planner/plannodes/drop_table_plan_node.h" @@ -1136,6 +1137,15 @@ void PlanGenerator::Visit(const DropView *drop_view) { .Build(); } +void PlanGenerator::Visit(const DropFunction *drop_function) { + output_plan_ = planner::DropFunctionPlanNode::Builder() + .SetPlanNodeId(GetNextPlanNodeID()) + .SetDatabaseOid(drop_function->GetDatabaseOid()) + .SetProcedureOid(drop_function->GetFunctionOid()) + .SetIfExists(drop_function->GetIfExists()) + .Build(); +} + void PlanGenerator::Visit(const Analyze *analyze) { NOISEPAGE_ASSERT(children_plans_.size() == 1, "Analyze should have 1 child plan"); output_plan_ = planner::AnalyzePlanNode::Builder() diff --git a/src/optimizer/query_to_operator_transformer.cpp b/src/optimizer/query_to_operator_transformer.cpp index 1750f6491a..0540e002cc 100644 --- a/src/optimizer/query_to_operator_transformer.cpp +++ b/src/optimizer/query_to_operator_transformer.cpp @@ -640,6 +640,12 @@ void QueryToOperatorTransformer::Visit(common::ManagedPointer>{}, txn_context); break; + case parser::DropStatement::DropType::kFunction: + drop_expr = std::make_unique( + LogicalDropFunction::Make(db_oid_, accessor_->GetProcOid(op->GetFunctionName(), op->GetFunctionArguments()), + op->IsIfExists()) + .RegisterWithTxnContext(txn_context), + std::vector>{}, txn_context); case parser::DropStatement::DropType::kTrigger: case parser::DropStatement::DropType::kView: case parser::DropStatement::DropType::kPreparedStatement: diff --git a/src/optimizer/rule.cpp b/src/optimizer/rule.cpp index a781cadefc..3fd811bc51 100644 --- a/src/optimizer/rule.cpp +++ b/src/optimizer/rule.cpp @@ -52,6 +52,7 @@ RuleSet::RuleSet() { AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalDropNamespaceToPhysicalDropNamespace()); AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalDropTriggerToPhysicalDropTrigger()); AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalDropViewToPhysicalDropView()); + AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalDropFunctionToPhysicalDropFunction()); AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalAnalyzeToPhysicalAnalyze()); AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalCteScanToPhysicalCteScan()); AddRule(RuleSetName::PHYSICAL_IMPLEMENTATION, new LogicalCteScanToPhysicalEmptyCteScan()); diff --git a/src/optimizer/rules/implementation_rules.cpp b/src/optimizer/rules/implementation_rules.cpp index 6b24be16cc..15a1c7aeb7 100644 --- a/src/optimizer/rules/implementation_rules.cpp +++ b/src/optimizer/rules/implementation_rules.cpp @@ -1185,6 +1185,30 @@ void LogicalDropViewToPhysicalDropView::Transform(common::ManagedPointeremplace_back(std::move(op)); } +LogicalDropFunctionToPhysicalDropFunction::LogicalDropFunctionToPhysicalDropFunction() { + type_ = RuleType::DROP_FUNCTION_TO_PHYSICAL; + match_pattern_ = new Pattern(OpType::LOGICALDROPFUNCTION); +} + +bool LogicalDropFunctionToPhysicalDropFunction::Check(common::ManagedPointer plan, + OptimizationContext *context) const { + return true; +} + +void LogicalDropFunctionToPhysicalDropFunction::Transform( + common::ManagedPointer input, + std::vector> *transformed, + UNUSED_ATTRIBUTE OptimizationContext *context) const { + auto df_op = input->Contents()->GetContentsAs(); + NOISEPAGE_ASSERT(input->GetChildren().empty(), "LogicalDropFunction should have 0 children"); + + auto op = std::make_unique( + DropFunction::Make(df_op->GetDatabaseOid(), df_op->GetFunctionOid(), df_op->GetIfExists()) + .RegisterWithTxnContext(context->GetOptimizerContext()->GetTxn()), + std::vector>(), context->GetOptimizerContext()->GetTxn()); + transformed->emplace_back(std::move(op)); +} + LogicalAnalyzeToPhysicalAnalyze::LogicalAnalyzeToPhysicalAnalyze() { type_ = RuleType::ANALYZE_TO_PHYSICAL; match_pattern_ = new Pattern(OpType::LOGICALANALYZE); diff --git a/src/parser/expression/constant_value_expression.cpp b/src/parser/expression/constant_value_expression.cpp index 2cc41d40d1..3942fdae05 100644 --- a/src/parser/expression/constant_value_expression.cpp +++ b/src/parser/expression/constant_value_expression.cpp @@ -70,8 +70,7 @@ T ConstantValueExpression::Peek() const { } // NOLINTNEXTLINE: bugprone-suspicious-semicolon: seems like a false positive because of constexpr if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v) { // NOLINT: bugprone-suspicious-semicolon: seems like a false positive - // because of constexpr + std::is_same_v) { // NOLINT return static_cast(GetInteger().val_); } // NOLINTNEXTLINE: bugprone-suspicious-semicolon: seems like a false positive because of constexpr @@ -103,6 +102,13 @@ T ConstantValueExpression::Peek() const { UNREACHABLE("Invalid type for Peek."); } +const execution::sql::Val *ConstantValueExpression::SqlValue() const { + // TODO(Kyle): This solution is a bit hacky, we might want to + // consider revisiting (no pun intended) the way that we manage + // parameters provided to the execution context to resolve + return std::visit([](auto &&val) { return static_cast(&val); }, value_); +} + ConstantValueExpression &ConstantValueExpression::operator=(const ConstantValueExpression &other) { if (this != &other) { // self-assignment check expected // AbstractExpression fields we need copied over @@ -130,7 +136,6 @@ ConstantValueExpression &ConstantValueExpression::operator=(const ConstantValueE ConstantValueExpression::ConstantValueExpression(const ConstantValueExpression &other) : AbstractExpression(other) { if (std::holds_alternative(other.value_)) { auto string_val = execution::sql::ValueUtil::CreateStringVal(other.GetStringVal()); - value_ = string_val.first; buffer_ = std::move(string_val.second); } else { diff --git a/src/parser/postgresparser.cpp b/src/parser/postgresparser.cpp index f4ad227072..396392df5f 100644 --- a/src/parser/postgresparser.cpp +++ b/src/parser/postgresparser.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -48,6 +49,11 @@ std::unique_ptr PostgresParser::BuildParseTree(const std::s auto result = pg_query_parse(text); // Parse the query string with the Postgres parser. + + // TODO(Kyle): Syntax "DROP FUNCTION fun;" fails in the + // Postgres parser, do we need to update the version to + // add support for the shorthand syntax? + if (result.error != nullptr) { PARSER_LOG_DEBUG("BuildParseTree error: msg {}, curpos {}", result.error->message, result.error->cursorpos); @@ -61,7 +67,7 @@ std::unique_ptr PostgresParser::BuildParseTree(const std::s // Transform the Postgres parse tree to a Terrier representation. auto parse_result = std::make_unique(); try { - ListTransform(parse_result.get(), result.tree); + ListTransform(parse_result.get(), result.tree, query_string); } catch (const Exception &e) { pg_query_parse_finish(ctx); pg_query_free_parse_result(result); @@ -74,16 +80,54 @@ std::unique_ptr PostgresParser::BuildParseTree(const std::s return parse_result; } -void PostgresParser::ListTransform(ParseResult *parse_result, List *root) { +void PostgresParser::ListTransform(ParseResult *parse_result, List *root, const std::string &query_string) { if (root != nullptr) { for (auto cell = root->head; cell != nullptr; cell = cell->next) { auto node = static_cast(cell->data.ptr_value); - parse_result->AddStatement(NodeTransform(parse_result, node)); + parse_result->AddStatement(NodeTransform(parse_result, node, query_string)); } } } -std::unique_ptr PostgresParser::NodeTransform(ParseResult *parse_result, Node *node) { +/** + * Get the data type for the specified type name. + * @param name The type name (as C-style string) + * @return The data type + */ +static std::optional TypeNameToDataType(const char *name) { + BaseFunctionParameter::DataType data_type; + if ((strcmp(name, "int") == 0) || (strcmp(name, "int4") == 0)) { + data_type = BaseFunctionParameter::DataType::INT; + } else if (strcmp(name, "varchar") == 0) { + data_type = BaseFunctionParameter::DataType::VARCHAR; + } else if (strcmp(name, "int8") == 0) { + data_type = BaseFunctionParameter::DataType::BIGINT; + } else if (strcmp(name, "int2") == 0) { + data_type = BaseFunctionParameter::DataType::SMALLINT; + } else if ((strcmp(name, "double") == 0) || (strcmp(name, "float8") == 0)) { + data_type = BaseFunctionParameter::DataType::DOUBLE; + } else if ((strcmp(name, "real") == 0) || (strcmp(name, "float4") == 0)) { + data_type = BaseFunctionParameter::DataType::FLOAT; + } else if ((strcmp(name, "decimal") == 0) || strcmp(name, "numeric") == 0) { + return BaseFunctionParameter::DataType::DECIMAL; + } else if (strcmp(name, "text") == 0) { + data_type = BaseFunctionParameter::DataType::TEXT; + } else if (strcmp(name, "bpchar") == 0) { + data_type = BaseFunctionParameter::DataType::CHAR; + } else if (strcmp(name, "tinyint") == 0) { + data_type = BaseFunctionParameter::DataType::TINYINT; + } else if (strcmp(name, "bool") == 0) { + data_type = BaseFunctionParameter::DataType::BOOL; + } else if (strcmp(name, "date") == 0) { + data_type = BaseFunctionParameter::DataType::DATE; + } else { + return std::nullopt; + } + return std::make_optional(data_type); +} + +std::unique_ptr PostgresParser::NodeTransform(ParseResult *parse_result, Node *node, + const std::string &query_string) { // TODO(WAN): Document what input is parsed to nullptr if (node == nullptr) { return nullptr; @@ -104,7 +148,7 @@ std::unique_ptr PostgresParser::NodeTransform(ParseResult *parse_r break; } case T_CreateFunctionStmt: { - result = CreateFunctionTransform(parse_result, reinterpret_cast(node)); + result = CreateFunctionTransform(parse_result, reinterpret_cast(node), query_string); break; } case T_CreateSchemaStmt: { @@ -128,7 +172,7 @@ std::unique_ptr PostgresParser::NodeTransform(ParseResult *parse_r break; } case T_ExplainStmt: { - result = ExplainTransform(parse_result, reinterpret_cast(node)); + result = ExplainTransform(parse_result, reinterpret_cast(node), query_string); break; } case T_IndexStmt: { @@ -140,7 +184,7 @@ std::unique_ptr PostgresParser::NodeTransform(ParseResult *parse_r break; } case T_PrepareStmt: { - result = PrepareTransform(parse_result, reinterpret_cast(node)); + result = PrepareTransform(parse_result, reinterpret_cast(node), query_string); break; } case T_SelectStmt: { @@ -1296,21 +1340,23 @@ std::unique_ptr PostgresParser::CreateDatabaseTransform(Pa // Postgres.CreateFunctionStmt -> noisepage.CreateFunctionStatement std::unique_ptr PostgresParser::CreateFunctionTransform(ParseResult *parse_result, - CreateFunctionStmt *root) { + CreateFunctionStmt *root, + const std::string &query_string) { bool replace = root->replace_; - std::vector> func_parameters; - - for (auto cell = root->parameters_->head; cell != nullptr; cell = cell->next) { - auto node = reinterpret_cast(cell->data.ptr_value); - switch (node->type) { - case T_FunctionParameter: { - func_parameters.emplace_back( - FunctionParameterTransform(parse_result, reinterpret_cast(node))); - break; - } - default: { - // TODO(WAN): previous code just ignored it, is this right? - break; + std::vector> func_parameters{}; + if (root->parameters_ != nullptr) { + for (auto cell = root->parameters_->head; cell != nullptr; cell = cell->next) { + auto node = reinterpret_cast(cell->data.ptr_value); + switch (node->type) { + case T_FunctionParameter: { + func_parameters.emplace_back( + FunctionParameterTransform(parse_result, reinterpret_cast(node))); + break; + } + default: { + // TODO(WAN): previous code just ignored it, is this right? + break; + } } } } @@ -1320,7 +1366,9 @@ std::unique_ptr PostgresParser::CreateFunctionTransform(ParseResul // TODO(WAN): assumption from old code, can only pass one function name for now std::string func_name = (reinterpret_cast(root->funcname_->tail->data.ptr_value)->val_.str_); - std::vector func_body; + std::vector func_body{}; + func_body.push_back(query_string); + AsType as_type = AsType::INVALID; PLType pl_type = PLType::INVALID; @@ -1334,7 +1382,7 @@ std::unique_ptr PostgresParser::CreateFunctionTransform(ParseResul func_body.push_back(query_string); } - if (func_body.size() > 1) { + if (func_body.size() > 2) { as_type = AsType::EXECUTABLE; } else { as_type = AsType::QUERY_STRING; @@ -1351,11 +1399,9 @@ std::unique_ptr PostgresParser::CreateFunctionTransform(ParseResul } } - auto result = - std::make_unique(replace, std::move(func_name), std::move(func_body), - std::move(return_type), std::move(func_parameters), pl_type, as_type); - - return result; + return std::make_unique(replace, std::move(func_name), std::move(func_body), + std::move(return_type), std::move(func_parameters), pl_type, + as_type); } // Postgres.IndexStmt -> noisepage.CreateStatement @@ -1660,68 +1706,23 @@ std::unique_ptr PostgresParser::FunctionParameterTransform(ParseR FunctionParameter *root) { // TODO(WAN): significant code duplication, refactor out char* -> DataType char *name = (reinterpret_cast(root->arg_type_->names_->tail->data.ptr_value)->val_.str_); - parser::FuncParameter::DataType data_type; - - if ((strcmp(name, "int") == 0) || (strcmp(name, "int4") == 0)) { - data_type = BaseFunctionParameter::DataType::INT; - } else if (strcmp(name, "varchar") == 0) { - data_type = BaseFunctionParameter::DataType::VARCHAR; - } else if (strcmp(name, "int8") == 0) { - data_type = BaseFunctionParameter::DataType::BIGINT; - } else if (strcmp(name, "int2") == 0) { - data_type = BaseFunctionParameter::DataType::SMALLINT; - } else if ((strcmp(name, "double") == 0) || (strcmp(name, "float8") == 0)) { - data_type = BaseFunctionParameter::DataType::DOUBLE; - } else if ((strcmp(name, "real") == 0) || (strcmp(name, "float4") == 0)) { - data_type = BaseFunctionParameter::DataType::FLOAT; - } else if (strcmp(name, "text") == 0) { - data_type = BaseFunctionParameter::DataType::TEXT; - } else if (strcmp(name, "bpchar") == 0) { - data_type = BaseFunctionParameter::DataType::CHAR; - } else if (strcmp(name, "tinyint") == 0) { - data_type = BaseFunctionParameter::DataType::TINYINT; - } else if (strcmp(name, "bool") == 0) { - data_type = BaseFunctionParameter::DataType::BOOL; - } else { + auto data_type = TypeNameToDataType(name); + if (!data_type.has_value()) { PARSER_LOG_AND_THROW("FunctionParameterTransform", "DataType", name); } auto param_name = root->name_ != nullptr ? root->name_ : ""; - auto result = std::make_unique(data_type, param_name); - return result; + return std::make_unique(data_type.value(), param_name); } // Postgres.TypeName -> noisepage.ReturnType std::unique_ptr PostgresParser::ReturnTypeTransform(ParseResult *parse_result, TypeName *root) { char *name = (reinterpret_cast(root->names_->tail->data.ptr_value)->val_.str_); - ReturnType::DataType data_type; - - if ((strcmp(name, "int") == 0) || (strcmp(name, "int4") == 0)) { - data_type = BaseFunctionParameter::DataType::INT; - } else if (strcmp(name, "varchar") == 0) { - data_type = BaseFunctionParameter::DataType::VARCHAR; - } else if (strcmp(name, "int8") == 0) { - data_type = BaseFunctionParameter::DataType::BIGINT; - } else if (strcmp(name, "int2") == 0) { - data_type = BaseFunctionParameter::DataType::SMALLINT; - } else if ((strcmp(name, "double") == 0) || (strcmp(name, "float8") == 0)) { - data_type = BaseFunctionParameter::DataType::DOUBLE; - } else if ((strcmp(name, "real") == 0) || (strcmp(name, "float4") == 0)) { - data_type = BaseFunctionParameter::DataType::FLOAT; - } else if (strcmp(name, "text") == 0) { - data_type = BaseFunctionParameter::DataType::TEXT; - } else if (strcmp(name, "bpchar") == 0) { - data_type = BaseFunctionParameter::DataType::CHAR; - } else if (strcmp(name, "tinyint") == 0) { - data_type = BaseFunctionParameter::DataType::TINYINT; - } else if (strcmp(name, "bool") == 0) { - data_type = BaseFunctionParameter::DataType::BOOL; - } else { + auto data_type = TypeNameToDataType(name); + if (!data_type.has_value()) { PARSER_LOG_AND_THROW("ReturnTypeTransform", "ReturnType", name); } - - auto result = std::make_unique(data_type); - return result; + return std::make_unique(data_type.value()); } // Postgres.Node -> noisepage.AbstractExpression @@ -1758,6 +1759,9 @@ std::unique_ptr PostgresParser::DeleteTransform(ParseResult *pa // Postgres.DropStmt -> noisepage.DropStatement std::unique_ptr PostgresParser::DropTransform(ParseResult *parse_result, DropStmt *root) { switch (root->remove_type_) { + case ObjectType::OBJECT_FUNCTION: { + return DropFunctionTransform(parse_result, root); + } case ObjectType::OBJECT_INDEX: { return DropIndexTransform(parse_result, root); } @@ -1786,6 +1790,37 @@ std::unique_ptr PostgresParser::DropDatabaseTransform(ParseResult return result; } +// Postgres.DropStmt -> noisepage.DropStatement +std::unique_ptr PostgresParser::DropFunctionTransform(ParseResult *parse_result, DropStmt *root) { + // Grab the function name + auto objects = reinterpret_cast(root->objects_->head->data.ptr_value); + std::string function_name = reinterpret_cast(objects->head->data.ptr_value)->val_.str_; + + // Grab the argument types from the function signature + std::vector function_args{}; + + auto *arguments = reinterpret_cast(root->arguments_->head->data.ptr_value); + if (arguments != nullptr) { + function_args.reserve(arguments->length); + for (auto *cell = arguments->head; cell != nullptr; cell = cell->next) { + // The descriptor for some types consists of a head node with + // "pg_catalog" as the string value, so we need to skip over + auto *descriptor = reinterpret_cast(cell->data.ptr_value)->names_; + if (descriptor->length > 1) { + std::string type = reinterpret_cast(descriptor->head->next->data.ptr_value)->val_.str_; + function_args.emplace_back(std::move(type)); + } else { + std::string type = reinterpret_cast(descriptor->head->data.ptr_value)->val_.str_; + function_args.emplace_back(std::move(type)); + } + } + } + + const auto if_exists = root->missing_ok_; + return std::make_unique(std::make_unique("", "", ""), std::move(function_name), + std::move(function_args), if_exists); +} + // Postgres.DropStmt -> noisepage.DropStatement std::unique_ptr PostgresParser::DropIndexTransform(ParseResult *parse_result, DropStmt *root) { // TODO(WAN): There are unimplemented DROP INDEX options. @@ -1920,10 +1955,11 @@ std::vector> PostgresParser::ParamLis return result; } -std::unique_ptr PostgresParser::ExplainTransform(ParseResult *parse_result, ExplainStmt *root) { +std::unique_ptr PostgresParser::ExplainTransform(ParseResult *parse_result, ExplainStmt *root, + const std::string &query_string) { static constexpr char k_format_tok[] = "format"; std::unique_ptr result; - auto query = NodeTransform(parse_result, root->query_); + auto query = NodeTransform(parse_result, root->query_, query_string); result = std::make_unique(std::move(query)); if (root->options_ != nullptr) { @@ -2078,9 +2114,10 @@ std::vector> PostgresParser::UpdateTargetTransform } // Postgres.PrepareStmt -> noisepage.PrepareStatement -std::unique_ptr PostgresParser::PrepareTransform(ParseResult *parse_result, PrepareStmt *root) { +std::unique_ptr PostgresParser::PrepareTransform(ParseResult *parse_result, PrepareStmt *root, + const std::string &query_string) { auto name = root->name_; - auto query = NodeTransform(parse_result, root->query_); + auto query = NodeTransform(parse_result, root->query_, query_string); // TODO(WAN): This should probably be populated? std::vector> placeholders; diff --git a/src/parser/udf/plpgsql_parser.cpp b/src/parser/udf/plpgsql_parser.cpp new file mode 100644 index 0000000000..0c78a2fc42 --- /dev/null +++ b/src/parser/udf/plpgsql_parser.cpp @@ -0,0 +1,621 @@ +#include + +#include "binder/bind_node_visitor.h" +#include "execution/ast/udf/udf_ast_nodes.h" +#include "parser/expression/subquery_expression.h" +#include "parser/udf/plpgsql_parse_result.h" +#include "parser/udf/plpgsql_parser.h" +#include "parser/udf/string_utils.h" + +#include "libpg_query/pg_query.h" +#include "nlohmann/json.hpp" + +namespace noisepage::parser::udf { + +/** The identifiers used as keys in the parse tree */ +static constexpr const char K_FUNCTION_LIST[] = "FunctionList"; +static constexpr const char K_DATUMS[] = "datums"; +static constexpr const char K_PLPGSQL_VAR[] = "PLpgSQL_var"; +static constexpr const char K_REFNAME[] = "refname"; +static constexpr const char K_DATATYPE[] = "datatype"; +static constexpr const char K_DEFAULT_VAL[] = "default_val"; +static constexpr const char K_PLPGSQL_TYPE[] = "PLpgSQL_type"; +static constexpr const char K_TYPENAME[] = "typname"; +static constexpr const char K_ACTION[] = "action"; +static constexpr const char K_PLPGSQL_FUNCTION[] = "PLpgSQL_function"; +static constexpr const char K_BODY[] = "body"; +static constexpr const char K_PLPGSQL_STMT_BLOCK[] = "PLpgSQL_stmt_block"; +static constexpr const char K_PLPGSQL_STMT_RETURN[] = "PLpgSQL_stmt_return"; +static constexpr const char K_PLPGSQL_STMT_IF[] = "PLpgSQL_stmt_if"; +static constexpr const char K_PLPGSQL_STMT_WHILE[] = "PLpgSQL_stmt_while"; +static constexpr const char K_PLPGSQL_STMT_FORS[] = "PLpgSQL_stmt_fors"; +static constexpr const char K_PLPGSQL_STMT_FORI[] = "PLpgSQL_stmt_fori"; +static constexpr const char K_COND[] = "cond"; +static constexpr const char K_THEN_BODY[] = "then_body"; +static constexpr const char K_ELSE_BODY[] = "else_body"; +static constexpr const char K_EXPR[] = "expr"; +static constexpr const char K_QUERY[] = "query"; +static constexpr const char K_PLPGSQL_EXPR[] = "PLpgSQL_expr"; +static constexpr const char K_PLPGSQL_STMT_ASSIGN[] = "PLpgSQL_stmt_assign"; +static constexpr const char K_VARNO[] = "varno"; +static constexpr const char K_PLGPSQL_STMT_EXECSQL[] = "PLpgSQL_stmt_execsql"; +static constexpr const char K_SQLSTMT[] = "sqlstmt"; +static constexpr const char K_ROW[] = "row"; +static constexpr const char K_FIELDS[] = "fields"; +static constexpr const char K_NAME[] = "name"; +static constexpr const char K_PLPGSQL_ROW[] = "PLpgSQL_row"; +static constexpr const char K_PLPGSQL_STMT_DYNEXECUTE[] = "PLpgSQL_stmt_dynexecute"; +static constexpr const char K_LOWER[] = "lower"; +static constexpr const char K_UPPER[] = "upper"; +static constexpr const char K_STEP[] = "step"; +static constexpr const char K_VAR[] = "var"; + +/** Integral types */ +static constexpr const char DECL_TYPE_ID_SMALLINT[] = "smallint"; +static constexpr const char DECL_TYPE_ID_INT[] = "int"; +static constexpr const char DECL_TYPE_ID_INTEGER[] = "integer"; +static constexpr const char DECL_TYPE_ID_BIGINT[] = "bigint"; + +/** Variable-precision floating point */ +static constexpr const char DECL_TYPE_ID_REAL[] = "real"; +static constexpr const char DECL_TYPE_ID_FLOAT[] = "float"; +static constexpr const char DECL_TYPE_ID_DOUBLE[] = "double precision"; +static constexpr const char DECL_TYPE_ID_FLOAT8[] = "float8"; + +/** Arbitrary-precision floating point */ +static constexpr const char DECL_TYPE_ID_NUMERIC[] = "numeric"; +static constexpr const char DECL_TYPE_ID_DECIMAL[] = "decimal"; + +/** Character types */ +static constexpr const char DECL_TYPE_ID_CHAR[] = "char"; +static constexpr const char DECL_TYPE_ID_VARCHAR[] = "varchar"; +static constexpr const char DECL_TYPE_ID_TEXT[] = "text"; + +/** Other */ +static constexpr const char DECL_TYPE_ID_DATE[] = "date"; +static constexpr const char DECL_TYPE_ID_RECORD[] = "record"; + +std::unique_ptr PLpgSQLParser::Parse( + const std::vector ¶m_names, const std::vector ¶m_types, + const std::string &func_body) { + PLpgSQLParseResult result{pg_query_parse_plpgsql(func_body.c_str())}; + if ((*result).error != nullptr) { + throw PARSER_EXCEPTION(fmt::format("PL/pgSQL parser : {}", (*result).error->message)); + } + + // The result is a list, we need to wrap it + const auto ast_json_str = fmt::format("{{ \"{}\" : {} }}", K_FUNCTION_LIST, (*result).plpgsql_funcs); + + const nlohmann::json ast_json = nlohmann::json::parse(ast_json_str); + const auto function_list = ast_json[K_FUNCTION_LIST]; + NOISEPAGE_ASSERT(function_list.is_array(), "Function list is not an array"); + + if (function_list.size() != 1) { + throw PARSER_EXCEPTION("Function list has size other than 1"); + } + + // TODO(Kyle): This is a zip() + std::size_t i = 0; + for (const auto &udf_name : param_names) { + udf_ast_context_->SetVariableType(udf_name, param_types[i++]); + } + const auto function = function_list[0][K_PLPGSQL_FUNCTION]; + return std::make_unique(ParseFunction(function), param_names, param_types); +} + +std::unique_ptr PLpgSQLParser::ParseFunction(const nlohmann::json &json) { + const auto declarations = json[K_DATUMS]; + NOISEPAGE_ASSERT(declarations.is_array(), "Declaration list is not an array"); + + const auto function_body = json[K_ACTION][K_PLPGSQL_STMT_BLOCK][K_BODY]; + + std::vector> statements{}; + // Skip the first declaration in the datums list; parse all declarations + std::transform(declarations.cbegin() + 1, declarations.cend(), std::back_inserter(statements), + [this](const nlohmann::json &declaration) -> std::unique_ptr { + return ParseDecl(declaration); + }); + // Remove the invalid declarations + statements.erase( + std::remove_if(statements.begin(), statements.end(), + [](std::unique_ptr &stmt) { return !static_cast(stmt); }), + statements.end()); + statements.push_back(ParseBlock(function_body)); + return std::make_unique(std::move(statements)); +} + +std::unique_ptr PLpgSQLParser::ParseBlock(const nlohmann::json &json) { + NOISEPAGE_ASSERT(json.is_array(), "Block isn't array"); + if (json.empty()) { + throw PARSER_EXCEPTION("PL/pgSQL parser : Empty block is not supported"); + } + + std::vector> statements{}; + for (const auto &statement : json) { + const StatementType statement_type = GetStatementType(statement.items().begin().key()); + switch (statement_type) { + case StatementType::RETURN: { + statements.push_back(ParseReturn(statement[K_PLPGSQL_STMT_RETURN])); + break; + } + case StatementType::IF: { + statements.push_back(ParseIf(statement[K_PLPGSQL_STMT_IF])); + break; + } + case StatementType::ASSIGN: { + statements.push_back(ParseAssign(statement[K_PLPGSQL_STMT_ASSIGN])); + break; + } + case StatementType::WHILE: { + statements.push_back(ParseWhile(statement[K_PLPGSQL_STMT_WHILE])); + break; + } + case StatementType::FORI: { + statements.push_back(ParseForI(statement[K_PLPGSQL_STMT_FORI])); + break; + } + case StatementType::FORS: { + statements.push_back(ParseForS(statement[K_PLPGSQL_STMT_FORS])); + break; + } + case StatementType::EXECSQL: { + statements.push_back(ParseExecSQL(statement[K_PLGPSQL_STMT_EXECSQL])); + break; + } + case StatementType::DYNEXECUTE: { + statements.push_back(ParseDynamicSQL(statement[K_PLPGSQL_STMT_DYNEXECUTE])); + break; + } + case StatementType::UNKNOWN: { + throw PARSER_EXCEPTION( + fmt::format("PL/pgSQL Parser : statement type '{}' not supported", statement.items().begin().key())); + } + } + } + + return std::make_unique(std::move(statements)); +} + +std::unique_ptr PLpgSQLParser::ParseReturn(const nlohmann::json &json) { + // TODO(Kyle): Handle RETURN without expression + if (json.empty()) { + throw NOT_IMPLEMENTED_EXCEPTION("PL/pgSQL Parser : RETURN without expression not implemented."); + } + auto expr = ParseExprFromSQLString(json[K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get()); + return std::make_unique(std::move(expr)); +} + +std::unique_ptr PLpgSQLParser::ParseDecl(const nlohmann::json &json) { + const auto &declaration_type = json.items().begin().key(); + if (declaration_type == K_PLPGSQL_VAR) { + auto var_name = json[K_PLPGSQL_VAR][K_REFNAME].get(); + + // Track the local variable (for assignment) + udf_ast_context_->AddLocal(var_name); + + // Parse the initializer, if present + std::unique_ptr initial{nullptr}; + if (json[K_PLPGSQL_VAR].find(K_DEFAULT_VAL) != json[K_PLPGSQL_VAR].end()) { + initial = ParseExprFromSQLString(json[K_PLPGSQL_VAR][K_DEFAULT_VAL][K_PLPGSQL_EXPR][K_QUERY].get()); + } + + // Detemine if the variable has already been declared; + // if so, just re-use this type that has already been resolved + const auto resolved_type = udf_ast_context_->GetVariableType(var_name); + if (resolved_type.has_value()) { + return std::make_unique(var_name, resolved_type.value(), std::move(initial)); + } + + // Otherwise, we perform a string comparison with the type identifier + // for the variable to determine the type for the declaration + + // Grab the type identifier from the PL/pgSQL parser + const std::string type_name = StringUtils::Strip( + StringUtils::Lower(json[K_PLPGSQL_VAR][K_DATATYPE][K_PLPGSQL_TYPE][K_TYPENAME].get())); + auto type = TypeNameToType(type_name); + if (!type.has_value()) { + throw PARSER_EXCEPTION( + fmt::format("PL/pgSQL Parser : unsupported type '{}' for variable '{}'", type_name, var_name)); + } + + udf_ast_context_->SetVariableType(var_name, type.value()); + return std::make_unique(var_name, type.value(), std::move(initial)); + } + + if (declaration_type == K_PLPGSQL_ROW && json[K_PLPGSQL_ROW][K_REFNAME].get() == "*internal*") { + // For query-variant for-loop structures (For-S in PL/pgSQL parlance) + // the Postgres parser generates a dummy internal declaration for the + // variable that is a target of the `SELECT INTO`, we can elide this + return std::unique_ptr{}; + } + + // TODO(Kyle): Handle RECORD declarations + // TODO(Kyle): Handle table row declarations + throw PARSER_EXCEPTION(fmt::format("PL/pgSQL Parser : declaration type '{}' not supported", declaration_type)); +} + +std::unique_ptr PLpgSQLParser::ParseAssign(const nlohmann::json &json) { + // Extract the destination of the assignment + const auto var_index = json[K_VARNO].get() - 1; + const auto &var_name = udf_ast_context_->GetLocalAtIndex(var_index); + + // Attempt to parse the SQL expression to an AST expression + const auto &sql = json[K_EXPR][K_PLPGSQL_EXPR][K_QUERY].get(); + auto rhs = TryParseExprFromSQLString(sql); + if (rhs.has_value()) { + auto lhs = std::make_unique(var_name); + return std::make_unique(std::move(lhs), std::move(*rhs)); + } + + // Failed to parse the SQL expression to an AST expression; + // this could be the result of malformed SQL, OR it could + // be that the SQL is sufficiently complex that we need to + // generate code to execute the query. In this latter case, + // we use the existing infrastructure for executing SQL in + // the UDF body, and "desugar" the assignment to a SELECT INTO. + + // TODO(Kyle): Is this semantically correct? We are hacking + // an assignment expression into a SQL execution statement + return ParseExecSQL(sql, std::vector{var_name}); +} + +std::unique_ptr PLpgSQLParser::ParseIf(const nlohmann::json &json) { + auto cond_expr = ParseExprFromSQLString(json[K_COND][K_PLPGSQL_EXPR][K_QUERY].get()); + auto then_stmt = ParseBlock(json[K_THEN_BODY]); + std::unique_ptr else_stmt = + json.contains(K_ELSE_BODY) ? ParseBlock(json[K_ELSE_BODY]) : nullptr; + return std::make_unique(std::move(cond_expr), std::move(then_stmt), + std::move(else_stmt)); +} + +std::unique_ptr PLpgSQLParser::ParseWhile(const nlohmann::json &json) { + auto cond_expr = ParseExprFromSQLString(json[K_COND][K_PLPGSQL_EXPR][K_QUERY].get()); + auto body_stmt = ParseBlock(json[K_BODY]); + return std::make_unique(std::move(cond_expr), std::move(body_stmt)); +} + +std::unique_ptr PLpgSQLParser::ParseForI(const nlohmann::json &json) { + const auto name = json[K_VAR][K_PLPGSQL_VAR][K_REFNAME].get(); + auto lower = ParseExprFromSQLString(json[K_LOWER][K_PLPGSQL_EXPR][K_QUERY]); + auto upper = ParseExprFromSQLString(json[K_UPPER][K_PLPGSQL_EXPR][K_QUERY]); + auto step = json.contains(K_STEP) ? ParseExprFromSQLString(json[K_STEP][K_PLPGSQL_EXPR][K_QUERY]) + : ParseExprFromSQLString(execution::ast::udf::ForIStmtAST::DEFAULT_STEP_EXPR); + auto body = ParseBlock(json[K_BODY]); + return std::make_unique(name, std::move(lower), std::move(upper), std::move(step), + std::move(body)); +} + +std::unique_ptr PLpgSQLParser::ParseForS(const nlohmann::json &json) { + const auto sql_query = json[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get(); + auto parse_result = PostgresParser::BuildParseTree(sql_query); + if (parse_result == nullptr) { + return nullptr; + } + auto body_stmt = ParseBlock(json[K_BODY]); + auto variable_array = json[K_ROW][K_PLPGSQL_ROW][K_FIELDS]; + std::vector variables{}; + variables.reserve(variable_array.size()); + std::transform(variable_array.cbegin(), variable_array.cend(), std::back_inserter(variables), + [](const nlohmann::json &var) { return var[K_NAME].get(); }); + + if (!AllVariablesDeclared(variables)) { + throw PARSER_EXCEPTION("PLpgSQL parser : variable was not declared"); + } + + return std::make_unique(std::move(variables), std::move(parse_result), + std::move(body_stmt)); +} + +std::unique_ptr PLpgSQLParser::ParseExecSQL(const nlohmann::json &json) { + // The query text + const auto sql = json[K_SQLSTMT][K_PLPGSQL_EXPR][K_QUERY].get(); + // The variable(s) to which results are bound + const auto variable_array = json[K_ROW][K_PLPGSQL_ROW][K_FIELDS]; + std::vector variables{}; + variables.reserve(variable_array.size()); + std::transform(variable_array.cbegin(), variable_array.cend(), std::back_inserter(variables), + [](const nlohmann::json &var) -> std::string { return var[K_NAME].get(); }); + + return ParseExecSQL(sql, std::move(variables)); +} + +std::unique_ptr PLpgSQLParser::ParseExecSQL(const std::string &sql, + std::vector &&variables) { + auto parse_result = StripEnclosingQuery(PostgresParser::BuildParseTree(sql)); + if (!parse_result) { + throw PARSER_EXCEPTION(fmt::format("PL/pgSQL parser : failed to parse query '{}'", sql)); + } + + // Ensure all variables to which results are bound are declared + if (!AllVariablesDeclared(variables)) { + throw PARSER_EXCEPTION("PL/pgSQL parser : variable was not declared"); + } + + /** + * Two possibilities for binding of results: + * - Exactly one RECORD variable + * - One or more non-RECORD variables + */ + + if (ContainsRecordType(variables)) { + if (variables.size() > 1) { + throw PARSER_EXCEPTION("Binding of query results is ambiguous"); + } + // There is only a single result variable and it is a RECORD; + // derive the structure of the RECORD from the SELECT columns + const auto &name = variables.front(); + udf_ast_context_->SetRecordType(name, ResolveRecordType(parse_result.get())); + } + + return std::make_unique(std::move(parse_result), std::move(variables)); +} + +std::unique_ptr PLpgSQLParser::ParseDynamicSQL(const nlohmann::json &json) { + auto sql_expr = ParseExprFromSQLString(json[K_QUERY][K_PLPGSQL_EXPR][K_QUERY].get()); + auto var_name = json[K_ROW][K_PLPGSQL_ROW][K_FIELDS][0][K_NAME].get(); + return std::make_unique(std::move(sql_expr), std::move(var_name)); +} + +std::unique_ptr PLpgSQLParser::ParseExprFromSQLString(const std::string &sql) { + auto expr = TryParseExprFromSQLString(sql); + if (!expr.has_value()) { + throw PARSER_EXCEPTION("PL/pgSQL parser : failed to parse SQL query"); + } + return std::move(*expr); +} + +std::optional> PLpgSQLParser::TryParseExprFromSQLString( + const std::string &sql) noexcept { + auto statements = PostgresParser::BuildParseTree(sql); + if (!statements) { + return std::nullopt; + } + + if (statements->GetStatements().size() != 1) { + return std::nullopt; + } + return TryParseExprFromSQLStatement(statements->GetStatement(0)); +} + +std::unique_ptr PLpgSQLParser::ParseExprFromSQLStatement( + common::ManagedPointer statement) { + auto expr = TryParseExprFromSQLStatement(statement); + if (!expr.has_value()) { + throw PARSER_EXCEPTION("PL/pgSQL parser : failed to parse SQL statement"); + } + return std::move(*expr); +} + +std::optional> PLpgSQLParser::TryParseExprFromSQLStatement( + common::ManagedPointer statement) noexcept { + if (statement->GetType() != parser::StatementType::SELECT) { + return std::nullopt; + } + + auto select = statement.CastManagedPointerTo(); + if (select->GetSelectTable() != nullptr || select->GetSelectColumns().size() != 1) { + return std::nullopt; + } + return TryParseExprFromAbstract(select->GetSelectColumns().at(0)); +} + +std::unique_ptr PLpgSQLParser::ParseExprFromAbstract( + common::ManagedPointer expr) { + auto result = TryParseExprFromAbstract(expr); + if (!result.has_value()) { + throw PARSER_EXCEPTION(fmt::format("PL/pgSQL parser : expression type '{}' not supported", + parser::ExpressionTypeToShortString(expr->GetExpressionType()))); + } + return std::move(*result); +} + +std::optional> PLpgSQLParser::TryParseExprFromAbstract( + common::ManagedPointer expr) noexcept { + if ((parser::ExpressionUtil::IsOperatorExpression(expr->GetExpressionType()) && expr->GetChildrenSize() == 2) || + (parser::ExpressionUtil::IsComparisonExpression(expr->GetExpressionType()))) { + auto lhs = TryParseExprFromAbstract(expr->GetChild(0)); + auto rhs = TryParseExprFromAbstract(expr->GetChild(1)); + if (!lhs.has_value() || !rhs.has_value()) { + return std::nullopt; + } + return std::make_optional(std::make_unique(expr->GetExpressionType(), + std::move(*lhs), std::move(*rhs))); + } + + switch (expr->GetExpressionType()) { + case parser::ExpressionType::COLUMN_VALUE: { + auto cve = expr.CastManagedPointerTo(); + if (cve->GetTableAlias().GetName().empty()) { + return std::make_optional(std::make_unique(cve->GetColumnName())); + } + auto vexpr = std::make_unique(cve->GetTableAlias().GetName()); + return std::make_optional( + std::make_unique(std::move(vexpr), cve->GetColumnName())); + } + case parser::ExpressionType::FUNCTION: { + auto func_expr = expr.CastManagedPointerTo(); + std::vector> args{}; + for (auto child : func_expr->GetChildren()) { + auto argument = TryParseExprFromAbstract(child); + if (!argument.has_value()) { + return std::nullopt; + } + args.push_back(std::move(*argument)); + } + return std::make_optional( + std::make_unique(func_expr->GetFuncName(), std::move(args))); + } + case parser::ExpressionType::VALUE_CONSTANT: + return std::make_optional(std::make_unique(expr->Copy())); + case parser::ExpressionType::OPERATOR_IS_NOT_NULL: { + auto target = TryParseExprFromAbstract(expr->GetChild(0)); + if (!target.has_value()) { + return std::nullopt; + } + return std::make_optional(std::make_unique(false, std::move(*target))); + } + case parser::ExpressionType::OPERATOR_IS_NULL: { + auto target = TryParseExprFromAbstract(expr->GetChild(0)); + if (!target.has_value()) { + return std::nullopt; + } + return std::make_optional(std::make_unique(true, std::move(*target))); + } + case parser::ExpressionType::ROW_SUBQUERY: { + // We can handle subqeries, but only in the event + // that they are shallow "wrappers" around simple queries + auto subquery_expr = expr.CastManagedPointerTo(); + return TryParseExprFromSQLStatement(subquery_expr->GetSubselect().CastManagedPointerTo()); + } + default: + return std::nullopt; + } +} + +bool PLpgSQLParser::AllVariablesDeclared(const std::vector &names) const { + return std::all_of(names.cbegin(), names.cend(), + [this](const std::string &name) -> bool { return udf_ast_context_->HasVariable(name); }); +} + +bool PLpgSQLParser::ContainsRecordType(const std::vector &names) const { + return std::any_of(names.cbegin(), names.cend(), [this](const std::string &name) -> bool { + return udf_ast_context_->GetVariableType(name) == execution::sql::SqlTypeId::Invalid; + }); +} + +std::vector> PLpgSQLParser::ResolveRecordType( + const ParseResult *parse_result) { + std::vector> fields{}; + const auto &select_columns = + parse_result->GetStatement(0).CastManagedPointerTo()->GetSelectColumns(); + fields.reserve(select_columns.size()); + std::transform(select_columns.cbegin(), select_columns.cend(), std::back_inserter(fields), + [](const common::ManagedPointer &column) { + return std::make_pair(column->GetAlias().GetName(), column->GetReturnValueType()); + }); + return fields; +} + +StatementType PLpgSQLParser::GetStatementType(const std::string &type) { + StatementType stmt_type; + if (type == K_PLPGSQL_STMT_RETURN) { + stmt_type = StatementType::RETURN; + } else if (type == K_PLPGSQL_STMT_IF) { + stmt_type = StatementType::IF; + } else if (type == K_PLPGSQL_STMT_ASSIGN) { + stmt_type = StatementType::ASSIGN; + } else if (type == K_PLPGSQL_STMT_WHILE) { + stmt_type = StatementType::WHILE; + } else if (type == K_PLPGSQL_STMT_FORI) { + stmt_type = StatementType::FORI; + } else if (type == K_PLPGSQL_STMT_FORS) { + stmt_type = StatementType::FORS; + } else if (type == K_PLGPSQL_STMT_EXECSQL) { + stmt_type = StatementType::EXECSQL; + } else if (type == K_PLPGSQL_STMT_DYNEXECUTE) { + stmt_type = StatementType::DYNEXECUTE; + } else { + stmt_type = StatementType::UNKNOWN; + } + return stmt_type; +} + +// Static +std::unique_ptr PLpgSQLParser::StripEnclosingQuery(std::unique_ptr &&input) { + NOISEPAGE_ASSERT(input->GetStatements().size() == 1, "Must have a single SQL statement"); + + // If the query does not match the target pattern, return unmodified + if (!HasEnclosingQuery(input.get())) { + return std::move(input); + } + + // The query consists of enclosing SELECT around a + // subquery that implements the actual logic we want; + // now we perform some surgery on the ParseResult + + // Grab the SELECT from the subquery expression + auto statement = input->GetStatement(0); + auto select = statement.CastManagedPointerTo(); + auto subquery = select->GetSelectColumns().at(0).CastManagedPointerTo(); + + // Here, we take ownership of the new top-level statement for the query; + // the SELECT does not own its own target expressions, however, so we + // need to ensure that we manually copy these over into the new ParseResult + // such that their data is still alive after the transformation + auto subselect = subquery->ReleaseSubselect(); + + // Take ownership of the expressions we want; it is important + // that we actually take ownership of the existing expressions + // rather than making a copy of the collection because the + // remainder of the statements in the query hold non-owning + // pointers to these existing expressions, copies will result + // in dangling pointers in any number of the query statements + auto expressions = input->TakeExpressionsOwnership(); + expressions.erase(std::remove_if(expressions.begin(), expressions.end(), + [](const std::unique_ptr &expr) { + return expr->GetExpressionType() == parser::ExpressionType::ROW_SUBQUERY; + }), + expressions.end()); + + auto output = std::make_unique(); + output->AddStatement(std::move(subselect)); + for (auto &expression : expressions) { + output->AddExpression(std::move(expression)); + } + + // The input ParseResult is dropped here, so we need to be sure + // that we have extracted all of the data that we want out of it + + return output; +} + +bool PLpgSQLParser::HasEnclosingQuery(ParseResult *parse_result) { + NOISEPAGE_ASSERT(parse_result->GetStatements().size() == 1, "Must have a single SQL statement"); + auto statement = parse_result->GetStatement(0); + if (statement->GetType() != parser::StatementType::SELECT) { + return false; + } + auto select = statement.CastManagedPointerTo(); + if (select->GetSelectColumns().size() > 1) { + return false; + } + auto target = select->GetSelectColumns().at(0); + return (target->GetExpressionType() == parser::ExpressionType::ROW_SUBQUERY); +} + +std::optional PLpgSQLParser::TypeNameToType(const std::string &type_name) { + // TODO(Kyle): This is awkward control flow because we + // model RECORD types with the SqlTypeId::Invalid type + + execution::sql::SqlTypeId type; + if (type_name == DECL_TYPE_ID_SMALLINT) { + type = execution::sql::SqlTypeId::SmallInt; + } else if (type_name == DECL_TYPE_ID_INT || type_name == DECL_TYPE_ID_INTEGER) { + type = execution::sql::SqlTypeId::Integer; + } else if (type_name == DECL_TYPE_ID_BIGINT) { + type = execution::sql::SqlTypeId::BigInt; + } else if (type_name == DECL_TYPE_ID_REAL || type_name == DECL_TYPE_ID_FLOAT) { + // NOTE(Kyle): We perform a sneaky trick here: the "normal" + // SQL frontend automatically promotes all floating-point + // types to DOUBLE PRECISION (FLOAT8); we do the same thing + // here to remain consistent across the entire system. + type = execution::sql::SqlTypeId::Double; + } else if (type_name == DECL_TYPE_ID_DOUBLE || type_name == DECL_TYPE_ID_FLOAT8) { + type = execution::sql::SqlTypeId::Double; + } else if (type_name == DECL_TYPE_ID_NUMERIC || type_name == DECL_TYPE_ID_DECIMAL) { + type = execution::sql::SqlTypeId::Decimal; + } else if (type_name == DECL_TYPE_ID_CHAR) { + type = execution::sql::SqlTypeId::Char; + } else if (type_name == DECL_TYPE_ID_VARCHAR || type_name == DECL_TYPE_ID_TEXT) { + type = execution::sql::SqlTypeId::Varchar; + } else if (type_name == DECL_TYPE_ID_DATE) { + type = execution::sql::SqlTypeId::Date; + } else if (type_name == DECL_TYPE_ID_RECORD) { + type = execution::sql::SqlTypeId::Invalid; + } else { + return std::nullopt; + } + return std::make_optional(type); +} + +} // namespace noisepage::parser::udf diff --git a/src/parser/udf/string_utils.cpp b/src/parser/udf/string_utils.cpp new file mode 100644 index 0000000000..4cd0db11ab --- /dev/null +++ b/src/parser/udf/string_utils.cpp @@ -0,0 +1,32 @@ +#include "parser/udf/string_utils.h" + +#include + +namespace noisepage::parser::udf { + +std::string StringUtils::Lower(const std::string &string) { + std::string result{}; + std::transform(string.cbegin(), string.cend(), std::back_inserter(result), + [](unsigned char c) { return std::tolower(c); }); + return result; +} + +std::string StringUtils::Strip(const std::string &string) { + auto not_whitespace = [](unsigned char c) { return std::isspace(c) == 0; }; + + // Find the first non-whitespace character + auto begin = std::find_if(string.cbegin(), string.cend(), not_whitespace); + if (begin == string.cend()) { + return std::string{}; + } + + // Find the last non whitespace character + auto end = std::find_if(string.rbegin(), string.rend(), not_whitespace); + + // Construct the result + std::string result{}; + std::copy(begin, end.base(), std::back_inserter(result)); + return result; +} + +} // namespace noisepage::parser::udf diff --git a/src/planner/plannodes/abstract_plan_node.cpp b/src/planner/plannodes/abstract_plan_node.cpp index 30e86899dd..282b8e5aeb 100644 --- a/src/planner/plannodes/abstract_plan_node.cpp +++ b/src/planner/plannodes/abstract_plan_node.cpp @@ -18,6 +18,7 @@ #include "planner/plannodes/csv_scan_plan_node.h" #include "planner/plannodes/delete_plan_node.h" #include "planner/plannodes/drop_database_plan_node.h" +#include "planner/plannodes/drop_function_plan_node.h" #include "planner/plannodes/drop_index_plan_node.h" #include "planner/plannodes/drop_namespace_plan_node.h" #include "planner/plannodes/drop_table_plan_node.h" @@ -216,6 +217,10 @@ JSONDeserializeNodeIntermediate DeserializePlanNode(const nlohmann::json &json) plan_node = std::make_unique(); break; } + case PlanNodeType::DROP_FUNC: { + plan_node = std::make_unique(); + break; + } case PlanNodeType::EXPORT_EXTERNAL_FILE: { plan_node = std::make_unique(); break; diff --git a/src/planner/plannodes/drop_function_plan_node.cpp b/src/planner/plannodes/drop_function_plan_node.cpp new file mode 100644 index 0000000000..7cc4ce9a3c --- /dev/null +++ b/src/planner/plannodes/drop_function_plan_node.cpp @@ -0,0 +1,73 @@ +#include "planner/plannodes/drop_function_plan_node.h" + +#include +#include +#include + +#include "common/json.h" +#include "planner/plannodes/output_schema.h" + +namespace noisepage::planner { + +std::unique_ptr DropFunctionPlanNode::Builder::Build() { + return std::unique_ptr(new DropFunctionPlanNode( + std::move(children_), std::move(output_schema_), database_oid_, proc_oid_, if_exists_, plan_node_id_)); +} + +DropFunctionPlanNode::DropFunctionPlanNode(std::vector> &&children, + std::unique_ptr output_schema, catalog::db_oid_t database_oid, + catalog::proc_oid_t proc_oid, bool if_exists, plan_node_id_t plan_node_id) + : AbstractPlanNode(std::move(children), std::move(output_schema), plan_node_id), + database_oid_(database_oid), + proc_oid_(proc_oid), + if_exists_(if_exists) {} + +common::hash_t DropFunctionPlanNode::Hash() const { + common::hash_t hash = AbstractPlanNode::Hash(); + // Hash database_oid + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(database_oid_)); + // Hash procedure oid + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(proc_oid_)); + // Hash `IF EXISTS` + hash = common::HashUtil::CombineHashes(hash, common::HashUtil::Hash(if_exists_)); + return hash; +} + +bool DropFunctionPlanNode::operator==(const AbstractPlanNode &rhs) const { + if (!AbstractPlanNode::operator==(rhs)) return false; + + auto &other = dynamic_cast(rhs); + + // Database OID + if (database_oid_ != other.database_oid_) return false; + + // Procedure OID + if (proc_oid_ != other.proc_oid_) return false; + + // IF EXISTS + if (if_exists_ != other.if_exists_) return false; + + return true; +} + +nlohmann::json DropFunctionPlanNode::ToJson() const { + nlohmann::json j = AbstractPlanNode::ToJson(); + j["database_oid"] = database_oid_; + j["proc_oid"] = proc_oid_; + j["if_exists"] = if_exists_; + return j; +} + +std::vector> DropFunctionPlanNode::FromJson(const nlohmann::json &j) { + std::vector> exprs; + auto e1 = AbstractPlanNode::FromJson(j); + exprs.insert(exprs.end(), std::make_move_iterator(e1.begin()), std::make_move_iterator(e1.end())); + database_oid_ = j.at("database_oid").get(); + proc_oid_ = j.at("proc_oid").get(); + if_exists_ = j.at("if_exists").get(); + return exprs; +} + +DEFINE_JSON_BODY_DECLARATIONS(DropFunctionPlanNode); + +} // namespace noisepage::planner diff --git a/src/planner/plannodes/plan_node_defs.cpp b/src/planner/plannodes/plan_node_defs.cpp index c42ab91451..fd955e35a4 100644 --- a/src/planner/plannodes/plan_node_defs.cpp +++ b/src/planner/plannodes/plan_node_defs.cpp @@ -52,6 +52,8 @@ std::string PlanNodeTypeToString(PlanNodeType type) { return "DropTrigger"; case PlanNodeType::DROP_VIEW: return "DropView"; + case PlanNodeType::DROP_FUNC: + return "DropFunction"; case PlanNodeType::ANALYZE: return "Analyze"; case PlanNodeType::AGGREGATE: diff --git a/src/storage/recovery/recovery_manager.cpp b/src/storage/recovery/recovery_manager.cpp index 22c81fcb73..87dfaeeda9 100644 --- a/src/storage/recovery/recovery_manager.cpp +++ b/src/storage/recovery/recovery_manager.cpp @@ -1111,7 +1111,7 @@ uint32_t RecoveryManager::ProcessSpecialCasePGProcRecord( auto result UNUSED_ATTRIBUTE = catalog_->GetDatabaseCatalog(common::ManagedPointer(txn), redo_record->GetDatabaseOid()) - ->SetFunctionContextPointer(common::ManagedPointer(txn), proc_oid, nullptr); + ->SetFunctionContext(common::ManagedPointer(txn), proc_oid, nullptr); NOISEPAGE_ASSERT(result, "Setting to null did not work"); return 0; // No additional records processed } diff --git a/src/traffic_cop/traffic_cop.cpp b/src/traffic_cop/traffic_cop.cpp index c8c2234a70..d9c533772f 100644 --- a/src/traffic_cop/traffic_cop.cpp +++ b/src/traffic_cop/traffic_cop.cpp @@ -14,7 +14,7 @@ #include "common/error/exception.h" #include "common/thread_context.h" #include "execution/compiler/compilation_context.h" -#include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/execution_settings.h" #include "execution/exec/output.h" #include "execution/sql/ddl_executors.h" @@ -36,10 +36,12 @@ #include "planner/plannodes/abstract_plan_node.h" #include "planner/plannodes/analyze_plan_node.h" #include "planner/plannodes/create_database_plan_node.h" +#include "planner/plannodes/create_function_plan_node.h" #include "planner/plannodes/create_index_plan_node.h" #include "planner/plannodes/create_namespace_plan_node.h" #include "planner/plannodes/create_table_plan_node.h" #include "planner/plannodes/drop_database_plan_node.h" +#include "planner/plannodes/drop_function_plan_node.h" #include "planner/plannodes/drop_index_plan_node.h" #include "planner/plannodes/drop_namespace_plan_node.h" #include "planner/plannodes/drop_table_plan_node.h" @@ -219,7 +221,7 @@ TrafficCopResult TrafficCop::ExecuteSetStatement(common::ManagedPointer connection_ctx, @@ -245,7 +247,7 @@ TrafficCopResult TrafficCop::ExecuteShowStatement(common::ManagedPointerWriteDataRow(reinterpret_cast(&result), cols, {network::FieldFormat::text}); - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } TrafficCopResult TrafficCop::ExecuteCreateStatement( @@ -257,35 +259,44 @@ TrafficCopResult TrafficCop::ExecuteCreateStatement( NOISEPAGE_ASSERT( query_type == network::QueryType::QUERY_CREATE_TABLE || query_type == network::QueryType::QUERY_CREATE_SCHEMA || query_type == network::QueryType::QUERY_CREATE_INDEX || query_type == network::QueryType::QUERY_CREATE_DB || - query_type == network::QueryType::QUERY_CREATE_VIEW || query_type == network::QueryType::QUERY_CREATE_TRIGGER, + query_type == network::QueryType::QUERY_CREATE_VIEW || + query_type == network::QueryType::QUERY_CREATE_TRIGGER || + query_type == network::QueryType::QUERY_CREATE_FUNCTION, "ExecuteCreateStatement called with invalid QueryType."); switch (query_type) { case network::QueryType::QUERY_CREATE_TABLE: { if (execution::sql::DDLExecutors::CreateTableExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor(), connection_ctx->GetDatabaseOid())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } case network::QueryType::QUERY_CREATE_DB: { if (execution::sql::DDLExecutors::CreateDatabaseExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } case network::QueryType::QUERY_CREATE_INDEX: { if (execution::sql::DDLExecutors::CreateIndexExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } case network::QueryType::QUERY_CREATE_SCHEMA: { if (execution::sql::DDLExecutors::CreateNamespaceExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; + } + break; + } + case network::QueryType::QUERY_CREATE_FUNCTION: { + if (execution::sql::DDLExecutors::CreateFunctionExecutor( + physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { + return {ResultType::COMPLETE, 0U}; } break; } @@ -311,13 +322,14 @@ TrafficCopResult TrafficCop::ExecuteDropStatement( NOISEPAGE_ASSERT( query_type == network::QueryType::QUERY_DROP_TABLE || query_type == network::QueryType::QUERY_DROP_SCHEMA || query_type == network::QueryType::QUERY_DROP_INDEX || query_type == network::QueryType::QUERY_DROP_DB || - query_type == network::QueryType::QUERY_DROP_VIEW || query_type == network::QueryType::QUERY_DROP_TRIGGER, + query_type == network::QueryType::QUERY_DROP_VIEW || query_type == network::QueryType::QUERY_DROP_TRIGGER || + query_type == network::QueryType::QUERY_DROP_FUNCTION, "ExecuteDropStatement called with invalid QueryType."); switch (query_type) { case network::QueryType::QUERY_DROP_TABLE: { if (execution::sql::DDLExecutors::DropTableExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } @@ -325,21 +337,28 @@ TrafficCopResult TrafficCop::ExecuteDropStatement( if (execution::sql::DDLExecutors::DropDatabaseExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor(), connection_ctx->GetDatabaseOid())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } case network::QueryType::QUERY_DROP_INDEX: { if (execution::sql::DDLExecutors::DropIndexExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } break; } case network::QueryType::QUERY_DROP_SCHEMA: { if (execution::sql::DDLExecutors::DropNamespaceExecutor( physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; + } + break; + } + case network::QueryType::QUERY_DROP_FUNCTION: { + if (execution::sql::DDLExecutors::DropFunctionExecutor( + physical_plan.CastManagedPointerTo(), connection_ctx->Accessor())) { + return {ResultType::COMPLETE, 0U}; } break; } @@ -399,7 +418,7 @@ TrafficCopResult TrafficCop::ExecuteExplainStatement( out->WriteDataRow(reinterpret_cast(&plan_string_val), output_columns, {network::FieldFormat::text}); - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } std::variant, common::ErrorData> TrafficCop::ParseQuery( @@ -456,7 +475,7 @@ TrafficCopResult TrafficCop::BindQuery( return {ResultType::ERROR, error}; } - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } TrafficCopResult TrafficCop::CodegenPhysicalPlan( @@ -483,7 +502,7 @@ TrafficCopResult TrafficCop::CodegenPhysicalPlan( if (portal->GetStatement()->GetExecutableQuery() != nullptr && use_query_cache_) { // We've already codegen'd this, move on... - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } // TODO(WAN): see #1047 @@ -521,7 +540,7 @@ TrafficCopResult TrafficCop::CodegenPhysicalPlan( portal->GetStatement()->SetExecutableQuery(std::move(exec_query)); - return {ResultType::COMPLETE, 0u}; + return {ResultType::COMPLETE, 0U}; } TrafficCopResult TrafficCop::RunExecutableQuery(const common::ManagedPointer connection_ctx, @@ -584,11 +603,18 @@ TrafficCopResult TrafficCop::RunExecutableQuery(const common::ManagedPointerMetricsManager(); } - auto exec_ctx = std::make_unique( - connection_ctx->GetDatabaseOid(), connection_ctx->Transaction(), callback, physical_plan->GetOutputSchema().Get(), - connection_ctx->Accessor(), exec_settings, metrics, replication_manager_, recovery_manager_); - - exec_ctx->SetParams(portal->Parameters()); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(connection_ctx->GetDatabaseOid()) + .WithExecutionSettings(exec_settings) + .WithTxnContext(connection_ctx->Transaction()) + .WithOutputSchema(physical_plan->GetOutputSchema()) + .WithOutputCallback(std::move(callback)) + .WithCatalogAccessor(connection_ctx->Accessor()) + .WithMetricsManager(metrics) + .WithReplicationManager(replication_manager_) + .WithRecoveryManager(recovery_manager_) + .WithQueryParametersFrom(*portal->Parameters()) + .Build(); const auto exec_query = portal->GetStatement()->GetExecutableQuery(); diff --git a/src/traffic_cop/traffic_cop_util.cpp b/src/traffic_cop/traffic_cop_util.cpp index de00d0e37a..4480d1253b 100644 --- a/src/traffic_cop/traffic_cop_util.cpp +++ b/src/traffic_cop/traffic_cop_util.cpp @@ -143,6 +143,9 @@ network::QueryType TrafficCopUtil::QueryTypeForStatement(const common::ManagedPo return network::QueryType::QUERY_CREATE_VIEW; } } + case parser::StatementType::CREATE_FUNC: { + return network::QueryType::QUERY_CREATE_FUNCTION; + } case parser::StatementType::DROP: { const auto drop_type = statement.CastManagedPointerTo()->GetDropType(); switch (drop_type) { @@ -160,6 +163,8 @@ network::QueryType TrafficCopUtil::QueryTypeForStatement(const common::ManagedPo return network::QueryType::QUERY_DROP_PREPARED_STATEMENT; case parser::DropStatement::DropType::kTrigger: return network::QueryType::QUERY_DROP_TRIGGER; + case parser::DropStatement::DropType::kFunction: + return network::QueryType::QUERY_DROP_FUNCTION; } } case parser::StatementType::VARIABLE_SET: diff --git a/src/util/query_exec_util.cpp b/src/util/query_exec_util.cpp index 81aea8f34c..057f77ca78 100644 --- a/src/util/query_exec_util.cpp +++ b/src/util/query_exec_util.cpp @@ -8,9 +8,9 @@ #include "catalog/catalog_accessor.h" #include "execution/compiler/compilation_context.h" #include "execution/compiler/executable_query.h" -#include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/sql/ddl_executors.h" -#include "execution/vm/vm_defs.h" +#include "execution/vm/execution_mode.h" #include "loggers/common_logger.h" #include "metrics/metrics_manager.h" #include "network/network_defs.h" @@ -190,7 +190,7 @@ bool QueryExecUtil::ExecuteDDL(const std::string &query, bool what_if) { // has run. We can't compile the query before the CreateIndexExecutor because codegen would have // no idea which index to insert into. execution::exec::ExecutionSettings settings{}; - common::ManagedPointer schema = out_plan->GetOutputSchema(); + const auto schema = out_plan->GetOutputSchema(); auto exec_query = execution::compiler::CompilationContext::Compile( *out_plan, settings, accessor.get(), execution::compiler::CompilationMode::OneShot, std::nullopt, statement->OptimizeResult()->GetPlanMetaData()); @@ -236,8 +236,7 @@ bool QueryExecUtil::CompileQuery(const std::string &statement, const common::ManagedPointer out_plan = result->OptimizeResult()->GetPlanNode(); NOISEPAGE_ASSERT(network::NetworkUtil::DMLQueryType(result->GetQueryType()), "ExecuteDML expects DML"); - common::ManagedPointer schema = out_plan->GetOutputSchema(); - + const auto schema = out_plan->GetOutputSchema(); auto exec_query = execution::compiler::CompilationContext::Compile( *out_plan, exec_settings, accessor.get(), execution::compiler::CompilationMode::OneShot, override_qid, result->OptimizeResult()->GetPlanMetaData()); @@ -254,7 +253,7 @@ bool QueryExecUtil::ExecuteQuery(const std::string &statement, TupleFunction tup NOISEPAGE_ASSERT(txn_ != nullptr, "Requires BeginTransaction() or UseTransaction()"); ResetError(); auto txn = common::ManagedPointer(txn_); - planner::OutputSchema *schema = schemas_[statement].get(); + const planner::OutputSchema *schema = schemas_[statement].get(); std::mutex sync_mutex; auto consumer = [&tuple_fn, &sync_mutex, schema](byte *tuples, uint32_t num_tuples, uint32_t tuple_size) { @@ -282,12 +281,27 @@ bool QueryExecUtil::ExecuteQuery(const std::string &statement, TupleFunction tup // TODO(wz2): May want to thread the replication manager or recovery manager through execution::exec::OutputCallback callback = consumer; auto accessor = catalog_->GetAccessor(txn, db_oid_, DISABLED); - auto exec_ctx = std::make_unique( - db_oid_, txn, callback, schema, common::ManagedPointer(accessor), exec_settings, metrics, DISABLED, DISABLED); - exec_ctx->SetParams(common::ManagedPointer>(params.Get())); + // TODO(Kyle): Making this copy is far from ideal... + const std::vector query_parameters = + static_cast(params) ? *params : std::vector{}; + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionSettings(exec_settings) + .WithTxnContext(txn) + .WithOutputSchema(common::ManagedPointer{schema}) + .WithOutputCallback(std::move(callback)) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(metrics) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .WithQueryParametersFrom(query_parameters) + .Build(); NOISEPAGE_ASSERT(!txn->MustAbort(), "Transaction should not be in must-abort state prior to executing"); + // TODO(Kyle): Right now it looks like the QueryExecUtil always runs queries in interpreted + // execution mode, regardless of how the setting is updated throughout the rest of the system, + // is this the intended behavior..? (unlikely) exec_queries_[statement]->Run(common::ManagedPointer(exec_ctx), execution::vm::ExecutionMode::Interpret); if (txn->MustAbort()) { // Return false to indicate that the query encountered a runtime error. diff --git a/test/catalog/catalog_test.cpp b/test/catalog/catalog_test.cpp index a9d032e258..ccf5e41c7b 100644 --- a/test/catalog/catalog_test.cpp +++ b/test/catalog/catalog_test.cpp @@ -115,62 +115,174 @@ TEST_F(CatalogTests, LanguageTest) { txn_manager_->Abort(txn); } -TEST_F(CatalogTests, ProcTest) { +/** User-defined function */ +TEST_F(CatalogTests, ProcTest0) { auto txn = txn_manager_->BeginTransaction(); auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); - // Check visibility to me VerifyCatalogTables(*accessor); - auto lan_oid = accessor->CreateLanguage("test_language"); - auto ns_oid = accessor->GetDefaultNamespace(); - - EXPECT_NE(lan_oid, catalog::INVALID_LANGUAGE_OID); + const auto language_oid = accessor->CreateLanguage("test_language"); + const auto namespace_oid = accessor->GetDefaultNamespace(); + EXPECT_NE(language_oid, catalog::INVALID_LANGUAGE_OID); + EXPECT_NE(namespace_oid, catalog::INVALID_NAMESPACE_OID); txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + // Create the procedure txn = txn_manager_->BeginTransaction(); accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); - // create a sample proc - auto procname = "sample"; - std::vector args = {"arg1", "arg2", "arg3"}; - std::vector arg_types = {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer), - accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Boolean), - accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::SmallInt)}; - auto src = "int sample(arg1, arg2, arg3){return 2;}"; + const std::string procname{"sample"}; + const std::vector args{"arg1", "arg2", "arg3"}; + const std::vector arg_types{accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Boolean), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::SmallInt)}; + const std::string src{"int sample(arg1, arg2, arg3){return 2;}"}; - auto proc_oid = accessor->CreateProcedure( - procname, lan_oid, ns_oid, catalog::INVALID_TYPE_OID, args, arg_types, {}, {}, + const auto proc_oid = accessor->CreateProcedure( + procname, language_oid, namespace_oid, catalog::INVALID_TYPE_OID, args, arg_types, {}, {}, catalog::type_oid_t(static_cast(execution::sql::SqlTypeId::Integer)), src, false); EXPECT_NE(proc_oid, catalog::INVALID_PROC_OID); + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + // Query the catalog for the procedure txn = txn_manager_->BeginTransaction(); accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); - // make sure we didn't find this proc that we never added - auto found_oid = accessor->GetProcOid("bad_proc", arg_types); - EXPECT_EQ(found_oid, catalog::INVALID_PROC_OID); + // Make sure we didn't find this proc that we never added + EXPECT_EQ(accessor->GetProcOid("bad_proc", arg_types), catalog::INVALID_PROC_OID); - // look for proc that we actually added - found_oid = accessor->GetProcOid(procname, arg_types); + // Look for proc that we actually added + EXPECT_EQ(proc_oid, accessor->GetProcOid(procname, arg_types)); + EXPECT_TRUE(accessor->DropProcedure(proc_oid)); - auto sin_oid = accessor->GetProcOid("sin", {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Double)}); + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); +} + +/** Builtin procedure */ +TEST_F(CatalogTests, ProcTest1) { + auto txn = txn_manager_->BeginTransaction(); + auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); + + VerifyCatalogTables(*accessor); + + const auto language_oid = accessor->CreateLanguage("test_language"); + const auto namespace_oid = accessor->GetDefaultNamespace(); + EXPECT_NE(language_oid, catalog::INVALID_LANGUAGE_OID); + EXPECT_NE(namespace_oid, catalog::INVALID_NAMESPACE_OID); + + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + + // The procedure should already exist + txn = txn_manager_->BeginTransaction(); + accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); + + const auto sin_oid = accessor->GetProcOid("sin", {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Double)}); EXPECT_NE(sin_oid, catalog::INVALID_PROC_OID); + // The function context should already exist auto sin_context = accessor->GetFunctionContext(sin_oid); EXPECT_TRUE(sin_context->IsBuiltin()); EXPECT_EQ(sin_context->GetBuiltin(), execution::ast::Builtin::Sin); EXPECT_EQ(sin_context->GetFunctionReturnType(), execution::sql::SqlTypeId::Double); - auto sin_args = sin_context->GetFunctionArgsType(); + + auto sin_args = sin_context->GetFunctionArgTypes(); EXPECT_EQ(sin_args.size(), 1); EXPECT_EQ(sin_args.back(), execution::sql::SqlTypeId::Double); EXPECT_EQ(sin_context->GetFunctionName(), "sin"); - EXPECT_EQ(found_oid, proc_oid); - auto result = accessor->DropProcedure(found_oid); - EXPECT_TRUE(result); + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); +} + +/** Untyped NULL arguments */ +TEST_F(CatalogTests, ProcTest2) { + auto txn = txn_manager_->BeginTransaction(); + auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); + + VerifyCatalogTables(*accessor); + + const auto language_oid = accessor->CreateLanguage("test_language"); + const auto namespace_oid = accessor->GetDefaultNamespace(); + EXPECT_NE(language_oid, catalog::INVALID_LANGUAGE_OID); + EXPECT_NE(namespace_oid, catalog::INVALID_NAMESPACE_OID); + + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + + // Create the procedure + txn = txn_manager_->BeginTransaction(); + accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); + + const std::string procname{"foo"}; + const std::vector args{"a", "b"}; + const std::vector arg_types{accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer)}; + const std::string src{"int foo(a, b){ return 1337; }"}; + + const auto proc_oid = accessor->CreateProcedure( + procname, language_oid, namespace_oid, catalog::INVALID_TYPE_OID, args, arg_types, {}, {}, + catalog::type_oid_t(static_cast(execution::sql::SqlTypeId::Integer)), src, false); + EXPECT_NE(proc_oid, catalog::INVALID_PROC_OID); + + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + + // Query the catalog for the procedure + txn = txn_manager_->BeginTransaction(); + accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); + + // Look for proc that we added, with fully-specified types + EXPECT_EQ(proc_oid, accessor->GetProcOid(procname, arg_types)); + + // Look for the same proc, but with the first type unspecified (should fail) + EXPECT_EQ(catalog::INVALID_PROC_OID, + accessor->GetProcOid(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer)})); + + // Look for the same proc, but with the second type unspecified (should fail) + EXPECT_EQ(catalog::INVALID_PROC_OID, + accessor->GetProcOid(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)})); + + // Look for the same proc, but with both types unspecified (should fail) + EXPECT_EQ(catalog::INVALID_PROC_OID, + accessor->GetProcOid(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)})); + + // Look for the same proc, but with types resolved + const auto r0 = accessor->ResolveProcArgumentTypes( + procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer)}); + EXPECT_EQ(1, r0.size()); + EXPECT_EQ(proc_oid, accessor->GetProcOid(procname, r0.front())); + + // Look for the same proc, but with types resolved + const auto r1 = accessor->ResolveProcArgumentTypes( + procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Integer), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)}); + EXPECT_EQ(1, r1.size()); + EXPECT_EQ(proc_oid, accessor->GetProcOid(procname, r1.front())); + + // Look for the same proc, but with types resolved + const auto r2 = accessor->ResolveProcArgumentTypes( + procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)}); + EXPECT_EQ(1, r2.size()); + EXPECT_EQ(proc_oid, accessor->GetProcOid(procname, r2.front())); + + // Look for a proc with one fixed, incorrect parameter + const auto r3 = + accessor->ResolveProcArgumentTypes(procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Real)}); + EXPECT_TRUE(r3.empty()); + + // Look for a proc with one fixed, incorrect parameter + const auto r4 = accessor->ResolveProcArgumentTypes( + procname, {accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Real), + accessor->GetTypeOidFromTypeId(execution::sql::SqlTypeId::Invalid)}); + EXPECT_TRUE(r4.empty()); + + EXPECT_TRUE(accessor->DropProcedure(proc_oid)); txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); } @@ -903,4 +1015,18 @@ TEST_F(CatalogTests, StatisticTest) { txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); } +TEST_F(CatalogTests, TypeRoundTrip) { + // Ensure that types always round-trip + auto txn = txn_manager_->BeginTransaction(); + auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_, DISABLED); + for (int8_t type_raw = static_cast(execution::sql::SqlTypeId::Boolean); + type_raw <= static_cast(execution::sql::SqlTypeId::Varbinary); ++type_raw) { + const execution::sql::SqlTypeId type = static_cast(type_raw); + const catalog::type_oid_t oid = accessor->GetTypeOidFromTypeId(type); + EXPECT_EQ(type, accessor->GetTypeIdFromTypeOid(oid)); + } + + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); +} + } // namespace noisepage diff --git a/test/execution/ast_test.cpp b/test/execution/ast_test.cpp index e9a6e5449a..23f369ddcc 100644 --- a/test/execution/ast_test.cpp +++ b/test/execution/ast_test.cpp @@ -64,6 +64,11 @@ TEST_F(AstTest, HierarchyTest) { factory.NewCallExpr(factory.NewNilLiteral(EmptyPos()), util::RegionVector(Region())), factory.NewFunctionLitExpr( factory.NewFunctionType(EmptyPos(), util::RegionVector(Region()), nullptr), nullptr), + factory.NewLambdaExpr( + EmptyPos(), + factory.NewFunctionLitExpr( + factory.NewFunctionType(EmptyPos(), util::RegionVector(Region()), nullptr), nullptr), + util::RegionVector(Region())), factory.NewNilLiteral(EmptyPos()), factory.NewUnaryOpExpr(EmptyPos(), parsing::Token::Type::MINUS, nullptr), factory.NewIdentifierExpr(EmptyPos(), Identifier()), @@ -96,6 +101,7 @@ TEST_F(AstTest, HierarchyTest) { factory.NewDeclStmt(factory.NewVariableDecl(EmptyPos(), Identifier(), nullptr, nullptr)), factory.NewExpressionStmt(factory.NewNilLiteral(EmptyPos())), factory.NewForStmt(EmptyPos(), nullptr, nullptr, nullptr, nullptr), + factory.NewBreakStmt(EmptyPos()), factory.NewIfStmt(EmptyPos(), nullptr, nullptr, nullptr), factory.NewReturnStmt(EmptyPos(), nullptr), }; diff --git a/test/execution/atomics_test.cpp b/test/execution/atomics_test.cpp index d3fd2d1d13..cfe5471dd8 100644 --- a/test/execution/atomics_test.cpp +++ b/test/execution/atomics_test.cpp @@ -8,9 +8,9 @@ #include "execution/compiler/compiler_settings.h" #include "execution/sema/error_reporter.h" #include "execution/util/region.h" +#include "execution/vm/execution_mode.h" #include "execution/vm/llvm_engine.h" #include "execution/vm/module.h" -#include "execution/vm/vm_defs.h" #include "spdlog/fmt/fmt.h" #include "test_util/fs_util.h" #include "test_util/multithread_test_util.h" diff --git a/test/execution/compiler_test.cpp b/test/execution/compiler_test.cpp index f9a3636ff1..a7d12c73dc 100644 --- a/test/execution/compiler_test.cpp +++ b/test/execution/compiler_test.cpp @@ -15,7 +15,7 @@ #include "execution/compiler/expression_maker.h" #include "execution/compiler/output_checker.h" #include "execution/compiler/output_schema_util.h" -#include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/output.h" #include "execution/execution_util.h" #include "execution/sema/sema.h" @@ -415,12 +415,12 @@ TEST_F(CompilerTest, SimpleSeqScanWithParamsTest) { exec::OutputPrinter printer(seq_scan->GetOutputSchema().Get()); MultiOutputCallback callback{std::vector{store, printer}}; exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = MakeExecCtx(&callback_fn, seq_scan->GetOutputSchema().Get()); - std::vector params; + + std::vector params{}; params.emplace_back(execution::sql::SqlTypeId::Integer, execution::sql::Integer(100)); params.emplace_back(execution::sql::SqlTypeId::Integer, execution::sql::Integer(500)); params.emplace_back(execution::sql::SqlTypeId::Integer, execution::sql::Integer(3)); - exec_ctx->SetParams(common::ManagedPointer>(¶ms)); + auto exec_ctx = MakeExecCtxWithParameters(params, &callback_fn, seq_scan->GetOutputSchema().Get()); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*seq_scan, exec_ctx->GetExecutionSettings(), @@ -3072,11 +3072,12 @@ TEST_F(CompilerTest, InsertIntoSelectWithParamTest) { // Make Exec Ctx MultiOutputCallback callback{std::vector{}}; exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = MakeExecCtx(&callback_fn, insert->GetOutputSchema().Get()); - std::vector params; + + std::vector params{}; params.emplace_back(execution::sql::SqlTypeId::Integer, execution::sql::Integer(495)); params.emplace_back(execution::sql::SqlTypeId::Integer, execution::sql::Integer(505)); - exec_ctx->SetParams(common::ManagedPointer>(¶ms)); + auto exec_ctx = MakeExecCtxWithParameters(params, &callback_fn, insert->GetOutputSchema().Get()); + auto executable = execution::compiler::CompilationContext::Compile(*insert, exec_ctx->GetExecutionSettings(), exec_ctx->GetAccessor()); executable->Run(common::ManagedPointer(exec_ctx), MODE); @@ -3295,8 +3296,8 @@ TEST_F(CompilerTest, SimpleInsertWithParamsTest) { // Make Exec Ctx MultiOutputCallback callback{std::vector{}}; exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = MakeExecCtx(&callback_fn, insert->GetOutputSchema().Get()); - std::vector params; + std::vector params{}; + // First parameter list auto str1_val = sql::ValueUtil::CreateStringVal(str1); params.emplace_back(execution::sql::SqlTypeId::Varchar, str1_val.first, std::move(str1_val.second)); @@ -3317,7 +3318,8 @@ TEST_F(CompilerTest, SimpleInsertWithParamsTest) { params.emplace_back(execution::sql::SqlTypeId::SmallInt, sql::Integer(smallint2)); params.emplace_back(execution::sql::SqlTypeId::Integer, sql::Integer(int2)); params.emplace_back(execution::sql::SqlTypeId::BigInt, sql::Integer(bigint2)); - exec_ctx->SetParams(common::ManagedPointer>(¶ms)); + + auto exec_ctx = MakeExecCtxWithParameters(params, &callback_fn, insert->GetOutputSchema().Get()); auto executable = execution::compiler::CompilationContext::Compile(*insert, exec_ctx->GetExecutionSettings(), exec_ctx->GetAccessor()); executable->Run(common::ManagedPointer(exec_ctx), MODE); @@ -3497,13 +3499,14 @@ TEST_F(CompilerTest, SimpleInsertWithParamsTest) { exec::OutputPrinter printer(index_scan->GetOutputSchema().Get()); MultiOutputCallback callback{std::vector{store, printer}}; exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = MakeExecCtx(&callback_fn, index_scan->GetOutputSchema().Get()); - std::vector params; + + std::vector params{}; auto str1_val = sql::ValueUtil::CreateStringVal(str1); auto str2_val = sql::ValueUtil::CreateStringVal(str2); params.emplace_back(execution::sql::SqlTypeId::Varchar, str1_val.first, std::move(str1_val.second)); params.emplace_back(execution::sql::SqlTypeId::Varchar, str2_val.first, std::move(str2_val.second)); - exec_ctx->SetParams(common::ManagedPointer>(¶ms)); + + auto exec_ctx = MakeExecCtxWithParameters(params, &callback_fn, index_scan->GetOutputSchema().Get()); auto executable = execution::compiler::CompilationContext::Compile(*index_scan, exec_ctx->GetExecutionSettings(), exec_ctx->GetAccessor()); executable->Run(common::ManagedPointer(exec_ctx), MODE); diff --git a/test/execution/execution_context_builder_test.cpp b/test/execution/execution_context_builder_test.cpp new file mode 100644 index 0000000000..b39c71295d --- /dev/null +++ b/test/execution/execution_context_builder_test.cpp @@ -0,0 +1,169 @@ +#include "execution/exec/execution_context_builder.h" + +#include "execution/compiled_tpl_test.h" +#include "execution/exec/execution_context.h" + +/** A dummy from which we can constuct null ManagedPointers */ +#define DUMMY nullptr + +namespace noisepage::execution::test { + +class ExecutionContextBuilderTest : public TplTest { + /** The OID with which the database OID is initialized */ + constexpr static const uint32_t DB_OID = 15721; + + public: + ExecutionContextBuilderTest() : db_oid_{DB_OID}, output_callback_{[](byte *, uint32_t, uint32_t) {}} {} // NOLINT + + /** @return The dummy database OID */ + catalog::db_oid_t GetDatabaseOID() const { return db_oid_; } + + /** @return The dummy execution settings */ + const exec::ExecutionSettings &GetExecutionSettings() const { return execution_settings_; } + + /** @return The dummy output callback */ + const exec::OutputCallback &GetOutputCallback() const { return output_callback_; } + + private: + /** A dummy database OID */ + catalog::db_oid_t db_oid_; + /** A dummy ExecutionSettings instance */ + exec::ExecutionSettings execution_settings_{}; + /** A dummy output callback */ + const exec::OutputCallback output_callback_; +}; + +TEST_F(ExecutionContextBuilderTest, DoesNotThrowWithAllConfigurationSpecified) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_NO_THROW(builder.Build()); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingDatabaseOID) { + auto builder = exec::ExecutionContextBuilder() + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingTransactionContext) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingExecutionSettings) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingOutputSchema) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingOutputCallback) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingCatalogAccessor) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingMetricsManager) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithReplicationManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingReplicationManager) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithRecoveryManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +TEST_F(ExecutionContextBuilderTest, ThrowsOnMissingRecoveryManager) { + auto builder = exec::ExecutionContextBuilder() + .WithDatabaseOID(GetDatabaseOID()) + .WithTxnContext(DUMMY) + .WithExecutionSettings(GetExecutionSettings()) + .WithOutputSchema(DUMMY) + .WithOutputCallback(GetOutputCallback()) + .WithCatalogAccessor(DUMMY) + .WithMetricsManager(DUMMY) + .WithReplicationManager(DUMMY); + EXPECT_THROW(builder.Build(), ExecutionException); +} + +#undef DUMMY + +} // namespace noisepage::execution::test diff --git a/test/execution/system_functions_test.cpp b/test/execution/system_functions_test.cpp index 1cb6dbbba2..07c7c38bf7 100644 --- a/test/execution/system_functions_test.cpp +++ b/test/execution/system_functions_test.cpp @@ -4,6 +4,7 @@ #include "common/version.h" #include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/execution_settings.h" #include "execution/sql/value.h" #include "execution/tpl_test.h" @@ -12,14 +13,31 @@ namespace noisepage::execution::sql::test { class SystemFunctionsTests : public TplTest { public: - SystemFunctionsTests() - : ctx_(catalog::db_oid_t(0), nullptr, nullptr, nullptr, nullptr, settings_, nullptr, DISABLED, DISABLED) {} - - exec::ExecutionContext *Ctx() { return &ctx_; } + SystemFunctionsTests() { + ctx_ = exec::ExecutionContextBuilder() + .WithDatabaseOID(DATABASE_OID) + .WithTxnContext(nullptr) + .WithExecutionSettings(settings_) + .WithOutputSchema(nullptr) + .WithOutputCallback(nullptr) + .WithCatalogAccessor(nullptr) + .WithMetricsManager(DISABLED) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); + } + + /** @return A non-owning pointer to the execution context */ + exec::ExecutionContext *Ctx() { return ctx_.get(); } private: + /** Dummy database OID */ + constexpr static catalog::db_oid_t DATABASE_OID{15721}; + + /** The execution settings for the test */ exec::ExecutionSettings settings_{}; - exec::ExecutionContext ctx_; + /** The execution context for the test */ + std::unique_ptr ctx_; }; // NOLINTNEXTLINE diff --git a/test/include/execution/compiler/expression_maker.h b/test/include/execution/compiler/expression_maker.h index da629e6b71..95b7e3413d 100644 --- a/test/include/execution/compiler/expression_maker.h +++ b/test/include/execution/compiler/expression_maker.h @@ -95,6 +95,13 @@ class ExpressionMaker { return MakeManaged(std::make_unique(catalog::table_oid_t(0), column_oid, type)); } + /** + * Create a column value expression + */ + ManagedExpression CVE(catalog::table_oid_t table_oid, catalog::col_oid_t column_oid, execution::sql::SqlTypeId type) { + return MakeManaged(std::make_unique(table_oid, column_oid, type)); + } + /** * Create a derived value expression */ diff --git a/test/include/execution/sql_test.h b/test/include/execution/sql_test.h index da7e0a0a7b..1928824416 100644 --- a/test/include/execution/sql_test.h +++ b/test/include/execution/sql_test.h @@ -5,6 +5,7 @@ #include #include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/execution_settings.h" #include "execution/sql/sql.h" #include "execution/sql/vector.h" @@ -53,41 +54,99 @@ class SqlBasedTest : public TplTest { ~SqlBasedTest() override { txn_manager_->Commit(test_txn_, transaction::TransactionUtil::EmptyCallback, nullptr); } + /** @return The namespace OID */ catalog::namespace_oid_t NSOid() { return test_ns_oid_; } + /** @return The block store */ common::ManagedPointer BlockStore() { return block_store_; } + /** + * Construct and return an execution context. + * @param callback[optional] The output callback + * @param schema[optional] the output schema + * @return The execution context + */ std::unique_ptr MakeExecCtx(exec::OutputCallback *callback = nullptr, const planner::OutputSchema *schema = nullptr) { exec::OutputCallback empty = nullptr; const auto &callback_ref = (callback == nullptr) ? empty : *callback; - return std::make_unique(test_db_oid_, common::ManagedPointer(test_txn_), callback_ref, - schema, common::ManagedPointer(accessor_), *exec_settings_, - metrics_manager_, DISABLED, DISABLED); + return exec::ExecutionContextBuilder() + .WithDatabaseOID(test_db_oid_) + .WithExecutionSettings(*exec_settings_) + .WithTxnContext(common::ManagedPointer{test_txn_}) + .WithOutputSchema(common::ManagedPointer{schema}) + .WithOutputCallback(callback_ref) + .WithCatalogAccessor(common::ManagedPointer{accessor_}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); } + /** + * Construct and return an execution context. + * @param parameters The query execution parameters + * @param callback[optional] The output callback + * @param schema[optional] The output schema + * @return The execution context + */ + std::unique_ptr MakeExecCtxWithParameters( + const std::vector ¶meters, exec::OutputCallback *callback = nullptr, + const planner::OutputSchema *schema = nullptr) { + exec::OutputCallback empty = nullptr; + const auto &callback_ref = (callback == nullptr) ? empty : *callback; + return exec::ExecutionContextBuilder() + .WithDatabaseOID(test_db_oid_) + .WithExecutionSettings(*exec_settings_) + .WithTxnContext(common::ManagedPointer{test_txn_}) + .WithOutputSchema(common::ManagedPointer{schema}) + .WithOutputCallback(callback_ref) + .WithCatalogAccessor(common::ManagedPointer{accessor_}) + .WithMetricsManager(metrics_manager_) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .WithQueryParametersFrom(parameters) + .Build(); + } + + /** + * Generate the test tables for SQL tests. + * @param exec_ctx The execution context to use for table generation. + */ void GenerateTestTables(exec::ExecutionContext *exec_ctx) { sql::TableGenerator table_generator{exec_ctx, block_store_, test_ns_oid_}; table_generator.GenerateTestTables(); } + /** @return A new, owned catalog accessor */ std::unique_ptr MakeAccessor() { return catalog_->GetAccessor(common::ManagedPointer(test_txn_), test_db_oid_, DISABLED); } protected: + /** The catalog accessor */ std::unique_ptr accessor_; + /** The identifier for the test database */ catalog::db_oid_t test_db_oid_{0}; + /** The statistics storage */ common::ManagedPointer stats_storage_; + /** The test transaction context */ transaction::TransactionContext *test_txn_; + /** The transaction manager */ common::ManagedPointer txn_manager_; private: + /** The database instance */ std::unique_ptr db_main_; + /** The metrics manager instance */ common::ManagedPointer metrics_manager_; + /** The block store */ common::ManagedPointer block_store_; + /** The catalog instance */ common::ManagedPointer catalog_; + /** The identifier for the test namespace */ catalog::namespace_oid_t test_ns_oid_; + /** The execution settings instance */ std::unique_ptr exec_settings_; }; diff --git a/test/include/test_util/procbench/procbench_query.h b/test/include/test_util/procbench/procbench_query.h new file mode 100644 index 0000000000..f4110f2c50 --- /dev/null +++ b/test/include/test_util/procbench/procbench_query.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +#include "catalog/catalog_accessor.h" +#include "execution/compiler/executable_query.h" + +namespace noisepage::procbench { + +/** ProcbenchQuery defines queries for SQL Procbench benchmarks. */ +class ProcbenchQuery { + public: + // Static functions to generate executable queries for ProcBench benchmark. Query plans are hard coded. + static std::tuple, std::unique_ptr> + MakeExecutableQ6(const std::unique_ptr &accessor, + const execution::exec::ExecutionSettings &exec_settings); +}; +} // namespace noisepage::procbench diff --git a/test/include/test_util/procbench/workload.h b/test/include/test_util/procbench/workload.h new file mode 100644 index 0000000000..695f742fc9 --- /dev/null +++ b/test/include/test_util/procbench/workload.h @@ -0,0 +1,103 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "catalog/catalog_accessor.h" +#include "catalog/catalog_defs.h" +#include "common/managed_pointer.h" +#include "execution/compiler/executable_query.h" +#include "execution/exec/execution_settings.h" +#include "execution/vm/module.h" + +namespace noisepage::execution::exec { +class ExecutionContext; +} + +namespace noisepage::catalog { +class Catalog; +} + +namespace noisepage::transaction { +class TransactionManager; +} + +namespace noisepage { +class DBMain; +} + +namespace noisepage::procbench { + +/** + * Class that can load the ProcBench tables, compile the + * ProcBench queries, and execute the ProcBench workload. + */ +class Workload { + public: + /** + * Construct a new Workload instance. + * @param db_main The database instance + * @param db_name The name of the database + * @param table_root The root of the table data directory + */ + Workload(common::ManagedPointer db_main, const std::string &db_name, const std::string &table_root); + + /** + * Function to invoke for a single worker thread to invoke the ProcBench queries. + * @param exec_mode The execution mode + */ + void Execute(std::size_t query_number, execution::vm::ExecutionMode exec_mode); + + /** @return The number of queries in the workload. */ + uint32_t GetQueryCount() { return query_and_plan_.size(); } + + private: + /** + * Load the tables for the ProcBench benchmark. + * @param exec_ctx The execution context + * @param directory The name of the directory from which tables are loaded + */ + void LoadTables(execution::exec::ExecutionContext *exec_ctx, const std::string &directory); + + /** + * Load the queries for the ProcBench benchmark. + * @param accessor The catalog accessor instance + */ + void LoadQueries(const std::unique_ptr &accessor); + + /** + * Get the index for the specified query number + * @param query_number The query number + * @return The index + */ + std::size_t QueryNumberToIndex(std::size_t query_number) const; + + private: + /** The database server instance */ + common::ManagedPointer db_main_; + /** The block store */ + common::ManagedPointer block_store_; + /** The catalog instance */ + common::ManagedPointer catalog_; + /** The transaction manager */ + common::ManagedPointer txn_manager_; + /** The database OID */ + catalog::db_oid_t db_oid_; + /** The namespace OID */ + catalog::namespace_oid_t ns_oid_; + /** Execution settings for all executed queries */ + execution::exec::ExecutionSettings exec_settings_{}; + /** The catalog accessor */ + std::unique_ptr accessor_; + /** The collection of executable queries and associated plans */ + std::vector< + std::tuple, std::unique_ptr>> + query_and_plan_; + /** Translate a query number of corresponding index */ + std::unordered_map query_number_to_index_; +}; + +} // namespace noisepage::procbench diff --git a/test/optimizer/index_nested_loops_join_test.cpp b/test/optimizer/index_nested_loops_join_test.cpp index 8090ac8ca8..c29e00be2e 100644 --- a/test/optimizer/index_nested_loops_join_test.cpp +++ b/test/optimizer/index_nested_loops_join_test.cpp @@ -8,6 +8,7 @@ #include "execution/compiler/executable_query.h" #include "execution/compiler/output_checker.h" #include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/execution_settings.h" #include "execution/sql/value.h" #include "execution/vm/module.h" @@ -125,11 +126,17 @@ struct IdxJoinTest : public TerrierTest { void TearDown() override { TerrierTest::TearDown(); } + /** The connection context */ network::ConnectionContext context_; + /** The catalog instance */ common::ManagedPointer catalog_; + /** The transaction manager instance */ common::ManagedPointer txn_manager_; + /** The traffic cop instance */ common::ManagedPointer tcop_; + /** The database instance */ std::unique_ptr db_main_; + /** The database OID */ catalog::db_oid_t db_oid_; }; @@ -203,9 +210,18 @@ TEST_F(IdxJoinTest, SimpleIdxJoinTest) { execution::exec::ExecutionSettings exec_settings{}; exec_settings.is_parallel_execution_enabled_ = false; execution::exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), callback_fn, out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, db_main_->GetMetricsManager(), DISABLED, DISABLED); + + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithExecutionSettings(exec_settings) + .WithOutputSchema(common::ManagedPointer{out_plan->GetOutputSchema().Get()}) + .WithOutputCallback(callback_fn) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*out_plan, exec_ctx->GetExecutionSettings(), @@ -326,9 +342,18 @@ TEST_F(IdxJoinTest, MultiPredicateJoin) { execution::exec::ExecutionSettings exec_settings{}; exec_settings.is_parallel_execution_enabled_ = false; execution::exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), callback_fn, out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, db_main_->GetMetricsManager(), DISABLED, DISABLED); + + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithExecutionSettings(exec_settings) + .WithOutputSchema(common::ManagedPointer{out_plan->GetOutputSchema().Get()}) + .WithOutputCallback(callback_fn) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*out_plan, exec_ctx->GetExecutionSettings(), @@ -409,9 +434,17 @@ TEST_F(IdxJoinTest, MultiPredicateJoinWithExtra) { execution::exec::ExecutionSettings exec_settings{}; exec_settings.is_parallel_execution_enabled_ = false; execution::exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), callback_fn, out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, db_main_->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithExecutionSettings(exec_settings) + .WithOutputSchema(common::ManagedPointer{out_plan->GetOutputSchema().Get()}) + .WithOutputCallback(callback_fn) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*out_plan, exec_ctx->GetExecutionSettings(), @@ -478,9 +511,17 @@ TEST_F(IdxJoinTest, FooOnlyScan) { execution::exec::ExecutionSettings exec_settings{}; exec_settings.is_parallel_execution_enabled_ = false; execution::exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), callback_fn, out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, db_main_->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithExecutionSettings(exec_settings) + .WithOutputSchema(common::ManagedPointer{out_plan->GetOutputSchema().Get()}) + .WithOutputCallback(callback_fn) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*out_plan, exec_ctx->GetExecutionSettings(), @@ -547,9 +588,17 @@ TEST_F(IdxJoinTest, BarOnlyScan) { execution::exec::ExecutionSettings exec_settings{}; exec_settings.is_parallel_execution_enabled_ = false; execution::exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), callback_fn, out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, db_main_->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithExecutionSettings(exec_settings) + .WithOutputSchema(common::ManagedPointer{out_plan->GetOutputSchema().Get()}) + .WithOutputCallback(callback_fn) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*out_plan, exec_ctx->GetExecutionSettings(), @@ -629,9 +678,17 @@ TEST_F(IdxJoinTest, IndexToIndexJoin) { execution::exec::ExecutionSettings exec_settings{}; exec_settings.is_parallel_execution_enabled_ = false; execution::exec::OutputCallback callback_fn = callback.ConstructOutputCallback(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), callback_fn, out_plan->GetOutputSchema().Get(), - common::ManagedPointer(accessor), exec_settings, db_main_->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithExecutionSettings(exec_settings) + .WithOutputSchema(common::ManagedPointer{out_plan->GetOutputSchema().Get()}) + .WithOutputCallback(callback_fn) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // Run & Check auto executable = execution::compiler::CompilationContext::Compile(*out_plan, exec_ctx->GetExecutionSettings(), diff --git a/test/parser/plpgsql_parser_test.cpp b/test/parser/plpgsql_parser_test.cpp new file mode 100644 index 0000000000..e1a19f2cd8 --- /dev/null +++ b/test/parser/plpgsql_parser_test.cpp @@ -0,0 +1,24 @@ +#include + +#include "parser/udf/string_utils.h" +#include "test_util/test_harness.h" + +namespace noisepage::parser { + +class PLpgSQLParserTest : public TerrierTest {}; + +TEST_F(PLpgSQLParserTest, LowerTest0) { + const std::string input{"HELLO WORLD"}; + const std::string expected{"hello world"}; + const auto result = udf::StringUtils::Lower(input); + EXPECT_EQ(expected, result); +} + +TEST_F(PLpgSQLParserTest, StripTest0) { + const std::string input{" hello "}; + const std::string expected{"hello"}; + const auto result = udf::StringUtils::Strip(input); + EXPECT_EQ(expected, result); +} + +} // namespace noisepage::parser diff --git a/test/test_util/procbench/procbench_query.cpp b/test/test_util/procbench/procbench_query.cpp new file mode 100644 index 0000000000..336b4cba76 --- /dev/null +++ b/test/test_util/procbench/procbench_query.cpp @@ -0,0 +1,49 @@ +#include "test_util/procbench/procbench_query.h" + +#include "catalog/catalog_accessor.h" +#include "execution/compiler/compilation_context.h" +#include "execution/compiler/expression_maker.h" +#include "execution/compiler/output_schema_util.h" +#include "execution/sql/sql_def.h" +#include "planner/plannodes/aggregate_plan_node.h" +#include "planner/plannodes/hash_join_plan_node.h" +#include "planner/plannodes/nested_loop_join_plan_node.h" +#include "planner/plannodes/order_by_plan_node.h" +#include "planner/plannodes/seq_scan_plan_node.h" + +namespace noisepage::procbench { + +std::tuple, std::unique_ptr> +ProcbenchQuery::MakeExecutableQ6(const std::unique_ptr &accessor, + const execution::exec::ExecutionSettings &exec_settings) { + execution::compiler::test::ExpressionMaker expr_maker; + const auto web_sales_history_oid = accessor->GetTableOid("web_sales_history"); + const auto &web_sales_history_schema = accessor->GetSchema(web_sales_history_oid); + + // Scan the table + std::unique_ptr seq_scan; + execution::compiler::test::OutputSchemaHelper seq_scan_out{0, &expr_maker}; + { + // Read all needed columns + auto ws_sold_date = + expr_maker.CVE(web_sales_history_oid, web_sales_history_schema.GetColumn("ws_sold_date_sk").Oid(), + execution::sql::SqlTypeId::Integer); + std::vector col_oids = {web_sales_history_schema.GetColumn("ws_sold_date_sk").Oid()}; + + // Make the output schema + seq_scan_out.AddOutput("ws_sold_date", ws_sold_date); + auto schema = seq_scan_out.MakeSchema(); + + // Build + planner::SeqScanPlanNode::Builder builder; + seq_scan = builder.SetOutputSchema(std::move(schema)) + .SetScanPredicate(nullptr) + .SetTableOid(web_sales_history_oid) + .SetColumnOids(std::move(col_oids)) + .Build(); + } + auto query = execution::compiler::CompilationContext::Compile(*seq_scan, exec_settings, accessor.get()); + return std::make_tuple(std::move(query), std::move(seq_scan)); +} + +} // namespace noisepage::procbench diff --git a/test/test_util/procbench/workload.cpp b/test/test_util/procbench/workload.cpp new file mode 100644 index 0000000000..8d7d900461 --- /dev/null +++ b/test/test_util/procbench/workload.cpp @@ -0,0 +1,165 @@ +#include "test_util/procbench/workload.h" + +#include +#include +#include + +#include "common/managed_pointer.h" +#include "execution/compiler/output_schema_util.h" +#include "execution/exec/execution_context_builder.h" +#include "execution/sql/value_util.h" +#include "execution/table_generator/table_generator.h" +#include "main/db_main.h" +#include "planner/plannodes/aggregate_plan_node.h" +#include "planner/plannodes/hash_join_plan_node.h" +#include "planner/plannodes/nested_loop_join_plan_node.h" +#include "planner/plannodes/order_by_plan_node.h" +#include "planner/plannodes/seq_scan_plan_node.h" +#include "test_util/procbench/procbench_query.h" + +namespace noisepage::procbench { + +/** Query identifiers */ +static constexpr const std::size_t Q6_ID = 6; + +/** ProcBench table names */ +static const std::vector PROCBENCH_TABLE_NAMES{"call_center", + "catalog_page", + "catalog_returns_history", + "catalog_returns", + "catalog_sales_history", + "catalog_sales", + "customer_address", + "customer_demographics", + "customer", + "date_dim", + "household_demographics", + "income_band", + "inventory_history", + "inventory", + "item", + "promotion", + "reason", + "ship_mode", + "store_returns_history", + "store_returns", + "store_sales_history", + "store_sales", + "store", + "time_dim", + "warehouse", + "web_page", + "web_returns_history", + "web_returns", + "web_sales_history", + "web_sales", + "web_site"}; + +Workload::Workload(common::ManagedPointer db_main, const std::string &db_name, const std::string &table_root) { + // cache db main and members + db_main_ = db_main; + txn_manager_ = db_main_->GetTransactionLayer()->GetTransactionManager(); + block_store_ = db_main_->GetStorageLayer()->GetBlockStore(); + catalog_ = db_main_->GetCatalogLayer()->GetCatalog(); + txn_manager_ = db_main_->GetTransactionLayer()->GetTransactionManager(); + + auto txn = txn_manager_->BeginTransaction(); + + // Create database catalog and namespace + db_oid_ = catalog_->CreateDatabase(common::ManagedPointer(txn), db_name, true); + auto accessor = + catalog_->GetAccessor(common::ManagedPointer(txn), db_oid_, DISABLED); + ns_oid_ = accessor->GetDefaultNamespace(); + + // Enable counters and disable the parallel execution for this workload + exec_settings_.is_parallel_execution_enabled_ = false; + exec_settings_.is_counters_enabled_ = true; + + // Make the execution context + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionSettings(exec_settings_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); + + // Create the ProcBench database + LoadTables(exec_ctx.get(), table_root); + // Compile all queries for the benchmark + LoadQueries(accessor); + + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); +} + +void Workload::LoadTables(execution::exec::ExecutionContext *exec_ctx, const std::string &directory) { + EXECUTION_LOG_INFO("Loading tables for ProcBench benchmark..."); + execution::sql::TableReader table_reader{exec_ctx, block_store_.Get(), ns_oid_}; + for (const auto &table_name : PROCBENCH_TABLE_NAMES) { + const std::string data_path = fmt::format("{}{}.data", directory, table_name); + const std::string schema_path = fmt::format("{}{}.schema", directory, table_name); + const auto num_rows = table_reader.ReadTable(schema_path, data_path); + EXECUTION_LOG_INFO("Wrote {} rows on table {}.", num_rows, table_name); + } + EXECUTION_LOG_INFO("Done."); +} + +void Workload::LoadQueries(const std::unique_ptr &accessor) { + EXECUTION_LOG_INFO("Loading queries for ProcBench benchmark..."); + + // Executable query and plan node are stored as a tuple as the entry of vector + query_and_plan_.emplace_back(ProcbenchQuery::MakeExecutableQ6(accessor, exec_settings_)); + query_number_to_index_[Q6_ID] = query_and_plan_.size() - 1; + + EXECUTION_LOG_INFO("Done."); +} + +void Workload::Execute(std::size_t query_number, execution::vm::ExecutionMode mode) { + // The total number of queries to be executed + const std::size_t query_index = QueryNumberToIndex(query_number); + + // Register to the metrics manager + db_main_->GetMetricsManager()->RegisterThread(); + + // Execute the selected query + auto txn = txn_manager_->BeginTransaction(); + auto accessor = + catalog_->GetAccessor(common::ManagedPointer(txn), db_oid_, DISABLED); + + // Get the output schema for the query + auto *output_schema = std::get<1>(query_and_plan_.at(query_index))->GetOutputSchema().Get(); + + // Construct an execution context for the query + execution::exec::NoOpResultConsumer printer; + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionSettings(exec_settings_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(common::ManagedPointer{output_schema}) + .WithOutputCallback(printer) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); + + // Execute the query + std::cout << "Executing...\n"; + std::get<0>(query_and_plan_.at(query_index)) + ->Run(common::ManagedPointer(exec_ctx), mode); + txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + std::cout << "Done.\n"; + + // Unregister from the metrics manager + db_main_->GetMetricsManager()->UnregisterThread(); +} + +std::size_t Workload::QueryNumberToIndex(std::size_t query_number) const { + return query_number_to_index_.at(query_number); +} + +} // namespace noisepage::procbench diff --git a/test/test_util/tpcc/workload_cached.cpp b/test/test_util/tpcc/workload_cached.cpp index ffd949d5f1..fd82e31da2 100644 --- a/test/test_util/tpcc/workload_cached.cpp +++ b/test/test_util/tpcc/workload_cached.cpp @@ -5,7 +5,7 @@ #include "binder/bind_node_visitor.h" #include "execution/compiler/executable_query.h" -#include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "main/db_main.h" #include "optimizer/cost_model/trivial_cost_model.h" #include "parser/expression/derived_value_expression.h" @@ -78,9 +78,17 @@ void WorkloadCached::LoadTPCCQueries(const std::vector &txn_names) nullptr) ->TakePlanNodeOwnership(); - auto exec_ctx = std::make_unique( - db_oid_, common::ManagedPointer(txn), nullptr, nullptr, common::ManagedPointer(accessor), exec_settings_, - db_main_->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionSettings(exec_settings_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // generate executable query and emplace it into the vector; break down here auto exec_query = std::make_unique( @@ -114,16 +122,18 @@ void WorkloadCached::Execute(int8_t worker_id, uint32_t num_precomputed_txns_per auto accessor = catalog_->GetAccessor(common::ManagedPointer(txn), db_oid_, DISABLED); for (const auto &query : queries_.find(txn_names_[index[counter]])->second) { - execution::exec::ExecutionContext exec_ctx{db_oid_, - common::ManagedPointer(txn), - nullptr, - nullptr, // FIXME: Get the correct output later - common::ManagedPointer(accessor), - exec_settings_, - db_main_->GetMetricsManager(), - DISABLED, - DISABLED}; - query->Run(common::ManagedPointer(&exec_ctx), mode); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionSettings(exec_settings_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); + query->Run(common::ManagedPointer{exec_ctx}, mode); } counter = counter == num_queries - 1 ? 0 : counter + 1; txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); diff --git a/test/test_util/tpch/workload.cpp b/test/test_util/tpch/workload.cpp index 12f83a91df..1244c55e11 100644 --- a/test/test_util/tpch/workload.cpp +++ b/test/test_util/tpch/workload.cpp @@ -5,7 +5,7 @@ #include "common/managed_pointer.h" #include "execution/compiler/output_schema_util.h" -#include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/sql/value_util.h" #include "execution/table_generator/table_generator.h" #include "main/db_main.h" @@ -41,13 +41,20 @@ Workload::Workload(common::ManagedPointer db_main, const std::string &db exec_settings_.is_counters_enabled_ = true; // Make the execution context - auto exec_ctx = - execution::exec::ExecutionContext(db_oid_, common::ManagedPointer(txn), nullptr, - nullptr, common::ManagedPointer(accessor), - exec_settings_, db_main->GetMetricsManager(), DISABLED, DISABLED); + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionSettings(exec_settings_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(execution::exec::ExecutionContext::NULL_OUTPUT_SCHEMA) + .WithOutputCallback(execution::exec::ExecutionContext::NULL_OUTPUT_CALLBACK) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); // create the TPCH database and compile the queries - GenerateTables(&exec_ctx, table_root, type); + GenerateTables(exec_ctx.get(), table_root, type); LoadQueries(accessor, type); txn_manager_->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); @@ -147,13 +154,21 @@ void Workload::Execute(int8_t worker_id, uint64_t execution_us_per_worker, uint6 // Uncomment this line and change output.cpp:90 to EXECUTION_LOG_INFO to print output // execution::exec::OutputPrinter printer(output_schema); execution::exec::NoOpResultConsumer printer; - auto exec_ctx = execution::exec::ExecutionContext( - db_oid_, common::ManagedPointer(txn), printer, output_schema, - common::ManagedPointer(accessor), exec_settings_, db_main_->GetMetricsManager(), - DISABLED, DISABLED); + + auto exec_ctx = execution::exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid_) + .WithExecutionSettings(exec_settings_) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(common::ManagedPointer{output_schema}) + .WithOutputCallback(printer) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main_->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .Build(); std::get<0>(query_and_plan_[index[counter]]) - ->Run(common::ManagedPointer(&exec_ctx), mode); + ->Run(common::ManagedPointer(exec_ctx), mode); // Only execute up to query_num number of queries for this thread in round-robin counter = counter == query_num - 1 ? 0 : counter + 1; diff --git a/third_party/libpg_query/pg_list.h b/third_party/libpg_query/pg_list.h index 21e9a1a31b..25c023c6cf 100644 --- a/third_party/libpg_query/pg_list.h +++ b/third_party/libpg_query/pg_list.h @@ -37,6 +37,7 @@ #ifndef PG_LIST_H #define PG_LIST_H +#include #include "nodes.h" typedef struct ListCell ListCell; @@ -76,30 +77,23 @@ struct ListCell * if supported by the compiler, or as regular functions otherwise. * See STATIC_IF_INLINE in c.h. */ -#ifndef PG_USE_INLINE -extern ListCell *list_head(const List *l); -extern ListCell *list_tail(List *l); -extern int list_length(const List *l); -#endif /* PG_USE_INLINE */ -#if defined(PG_USE_INLINE) || defined(PG_LIST_INCLUDE_DEFINITIONS) -STATIC_IF_INLINE ListCell * +static inline ListCell * list_head(const List *l) { return l ? l->head : NULL; } -STATIC_IF_INLINE ListCell * +static inline ListCell * list_tail(List *l) { return l ? l->tail : NULL; } -STATIC_IF_INLINE int +static inline int list_length(const List *l) { return l ? l->length : 0; } -#endif /*-- PG_USE_INLINE || PG_LIST_INCLUDE_DEFINITIONS */ /* * NB: There is an unfortunate legacy from a previous incarnation of diff --git a/third_party/libpg_query/src/pg_query_parse_plpgsql.c b/third_party/libpg_query/src/pg_query_parse_plpgsql.c index ba102157eb..af6773e04b 100644 --- a/third_party/libpg_query/src/pg_query_parse_plpgsql.c +++ b/third_party/libpg_query/src/pg_query_parse_plpgsql.c @@ -439,6 +439,7 @@ PgQueryPlpgsqlParseResult pg_query_parse_plpgsql(const char* input) result.plpgsql_funcs[strlen(result.plpgsql_funcs) - 2] = '\n'; result.plpgsql_funcs[strlen(result.plpgsql_funcs) - 1] = ']'; + free(parse_result.stderr_buffer); pg_query_exit_memory_context(ctx); return result; diff --git a/util/execution/table_generator/table_reader.cpp b/util/execution/table_generator/table_reader.cpp index ca5306b097..124ce15679 100644 --- a/util/execution/table_generator/table_reader.cpp +++ b/util/execution/table_generator/table_reader.cpp @@ -138,7 +138,7 @@ void TableReader::WriteIndexEntry(IndexInfo *index_info, storage::ProjectedRow * void TableReader::WriteTableCol(storage::ProjectedRow *insert_pr, uint16_t col_offset, execution::sql::SqlTypeId type, csv::CSVField *field) { - if (*field == NULL_STRING) { + if (*field == NULL_STRING || field->is_null()) { insert_pr->SetNull(col_offset); return; } @@ -190,7 +190,7 @@ void TableReader::WriteTableCol(storage::ProjectedRow *insert_pr, uint16_t col_o break; } default: - UNREACHABLE("Unsupported type. Add it here first!!!"); + UNREACHABLE("Unsupported type."); } } diff --git a/util/execution/tpl.cpp b/util/execution/tpl.cpp index 029faf4efc..cc72c17c07 100644 --- a/util/execution/tpl.cpp +++ b/util/execution/tpl.cpp @@ -16,7 +16,7 @@ #include "execution/ast/ast_dump.h" #include "execution/ast/ast_pretty_print.h" -#include "execution/exec/execution_context.h" +#include "execution/exec/execution_context_builder.h" #include "execution/exec/execution_settings.h" #include "execution/parsing/parser.h" #include "execution/parsing/scanner.h" @@ -43,12 +43,24 @@ #include "transaction/deferred_action_manager.h" #include "transaction/timestamp_manager.h" +/** Suppress warnings from unused variables */ +#define SUPPRESS_UNUSED(x) ((void)x) + // --------------------------------------------------------- // CLI options // --------------------------------------------------------- +/** Enumeration for requested execution modes */ +enum ExecuteOn { VM, JIT, ADAPTIVE, ALL }; + // clang-format off llvm::cl::OptionCategory TPL_OPTIONS_CATEGORY("TPL Compiler Options", "Options for controlling the TPL compilation process."); // NOLINT +llvm::cl::opt EXECUTE_ON("execute-on", llvm::cl::desc("The execution mode"), llvm::cl::values( // NOLINT + clEnumVal(VM, ""), + clEnumVal(JIT, ""), + clEnumVal(ADAPTIVE, ""), + clEnumVal(ALL, "") +), llvm::cl::init(ALL), llvm::cl::cat(TPL_OPTIONS_CATEGORY)); llvm::cl::opt PRINT_AST("print-ast", llvm::cl::desc("Print the programs AST"), llvm::cl::cat(TPL_OPTIONS_CATEGORY)); // NOLINT llvm::cl::opt PRINT_TBC("print-tbc", llvm::cl::desc("Print the generated TPL Bytecode"), llvm::cl::cat(TPL_OPTIONS_CATEGORY)); // NOLINT llvm::cl::opt PRETTY_PRINT("pretty-print", llvm::cl::desc("Pretty-print the source from the parsed AST"), llvm::cl::cat(TPL_OPTIONS_CATEGORY)); // NOLINT @@ -66,6 +78,68 @@ namespace noisepage::execution { static constexpr const char *K_EXIT_KEYWORD = ".exit"; +/** + * + */ +static bool ShouldExecuteInMode(vm::ExecutionMode mode) { + auto mode_requested = [mode]() -> bool { + switch (mode) { + case vm::ExecutionMode::Interpret: + return EXECUTE_ON == VM; + case vm::ExecutionMode::Compiled: + return EXECUTE_ON == JIT; + case vm::ExecutionMode::Adaptive: + return EXECUTE_ON == ADAPTIVE; + default: + return false; + } + }; + return EXECUTE_ON == ALL || mode_requested(); +} + +/** + * Execute + */ +static double ExecuteInMode(vm::Module *module, vm::ExecutionMode mode, exec::ExecutionContext *exec_ctx) { + const char *mode_identifier = [mode]() { + switch (mode) { + case vm::ExecutionMode::Interpret: + return "VM"; + case vm::ExecutionMode::Compiled: + return "JIT"; + case vm::ExecutionMode::Adaptive: + return "ADAPTIVE"; + default: + UNREACHABLE("Unknown Execution Mode"); + } + }(); + + double exec_ms{}; + exec_ctx->SetExecutionMode(mode); + { + util::ScopedTimer timer(&exec_ms); + + if (IS_SQL) { + std::function main; + if (!module->GetFunction("main", mode, &main)) { + EXECUTION_LOG_ERROR("Missing 'main' entry function with signature (*ExecutionContext) - >int32"); + return 0.0; + } + EXECUTION_LOG_INFO("{} main() returned: {}", mode_identifier, main(exec_ctx)); + } else { + std::function main; + if (!module->GetFunction("main", mode, &main)) { + EXECUTION_LOG_ERROR("Missing 'main' entry function with signature () -> int32"); + return 0.0; + } + EXECUTION_LOG_INFO("{} main() returned: {}", mode_identifier, main()); + } + } + + SUPPRESS_UNUSED(mode_identifier); + return exec_ms; +} + /** * Compile TPL source code contained in @em source and execute it in all execution modes once. * @param source The TPL source. @@ -94,20 +168,30 @@ static void CompileAndRun(const std::string &source, const std::string &name = " exec::ExecutionSettings exec_settings{}; exec::OutputPrinter printer(output_schema); exec::OutputCallback callback = printer; - exec::ExecutionContext exec_ctx{ - db_oid, common::ManagedPointer(txn), callback, output_schema, common::ManagedPointer(accessor), - exec_settings, db_main->GetMetricsManager(), DISABLED, DISABLED}; + // Add dummy parameters for tests - std::vector params; + std::vector params{}; params.emplace_back(execution::sql::SqlTypeId::Integer, sql::Integer(37)); params.emplace_back(execution::sql::SqlTypeId::Double, sql::Real(37.73)); params.emplace_back(execution::sql::SqlTypeId::Date, sql::DateVal(sql::Date::FromYMD(1937, 3, 7))); auto string_val = sql::ValueUtil::CreateStringVal(std::string_view("37 Strings")); params.emplace_back(execution::sql::SqlTypeId::Varchar, string_val.first, std::move(string_val.second)); - exec_ctx.SetParams(common::ManagedPointer>(¶ms)); + + auto exec_ctx = exec::ExecutionContextBuilder() + .WithDatabaseOID(db_oid) + .WithExecutionSettings(exec_settings) + .WithTxnContext(common::ManagedPointer{txn}) + .WithOutputSchema(common::ManagedPointer{output_schema}) + .WithOutputCallback(callback) + .WithCatalogAccessor(common::ManagedPointer{accessor}) + .WithMetricsManager(db_main->GetMetricsManager()) + .WithReplicationManager(DISABLED) + .WithRecoveryManager(DISABLED) + .WithQueryParametersFrom(params) + .Build(); // Generate test tables - sql::TableGenerator table_generator{&exec_ctx, db_main->GetStorageLayer()->GetBlockStore(), ns_oid}; + sql::TableGenerator table_generator{exec_ctx.get(), db_main->GetStorageLayer()->GetBlockStore(), ns_oid}; table_generator.GenerateTestTables(); // Comment out to make more tables available at runtime // table_generator.GenerateTPCHTables(); @@ -122,12 +206,9 @@ static void CompileAndRun(const std::string &source, const std::string &name = " parsing::Scanner scanner(source.data(), source.length()); parsing::Parser parser(&scanner, &context); - double parse_ms = 0.0, // Time to parse the source - typecheck_ms = 0.0, // Time to perform semantic analysis - codegen_ms = 0.0, // Time to generate TBC - interp_exec_ms = 0.0, // Time to execute the program in fully interpreted mode - adaptive_exec_ms = 0.0, // Time to execute the program in adaptive mode - jit_exec_ms = 0.0; // Time to execute the program in JIT excluding compilation time + double parse_ms = 0.0; // Time to parse the source + double typecheck_ms = 0.0; // Time to perform semantic analysis + double codegen_ms = 0.0; // Time to generate TBC // // Parse @@ -189,87 +270,33 @@ static void CompileAndRun(const std::string &source, const std::string &name = " auto module = std::make_unique(std::move(bytecode_module), std::move(module_metadata)); // - // Interpret - // - - { - exec_ctx.SetExecutionMode(static_cast(vm::ExecutionMode::Interpret)); - util::ScopedTimer timer(&interp_exec_ms); - - if (IS_SQL) { - std::function main; - if (!module->GetFunction("main", vm::ExecutionMode::Interpret, &main)) { - EXECUTION_LOG_ERROR("Missing 'main' entry function with signature (*ExecutionContext)->int32"); - return; - } - EXECUTION_LOG_INFO("VM main() returned: {}", main(&exec_ctx)); - } else { - std::function main; - if (!module->GetFunction("main", vm::ExecutionMode::Interpret, &main)) { - EXECUTION_LOG_ERROR("Missing 'main' entry function with signature ()->int32"); - return; - } - EXECUTION_LOG_INFO("VM main() returned: {}", main()); - } - } - - // - // Adaptive + // Execution // - exec_ctx.SetExecutionMode(static_cast(vm::ExecutionMode::Adaptive)); - util::ScopedTimer timer(&adaptive_exec_ms); - - if (IS_SQL) { - std::function main; - if (!module->GetFunction("main", vm::ExecutionMode::Adaptive, &main)) { - EXECUTION_LOG_ERROR("Missing 'main' entry function with signature (*ExecutionContext)->int32"); - return; - } - EXECUTION_LOG_INFO("ADAPTIVE main() returned: {}", main(&exec_ctx)); - } else { - std::function main; - if (!module->GetFunction("main", vm::ExecutionMode::Adaptive, &main)) { - EXECUTION_LOG_ERROR("Missing 'main' entry function with signature ()->int32"); - return; - } - EXECUTION_LOG_INFO("ADAPTIVE main() returned: {}", main()); - } + const double vm_ms = ShouldExecuteInMode(vm::ExecutionMode::Interpret) + ? ExecuteInMode(module.get(), vm::ExecutionMode::Interpret, exec_ctx.get()) + : 0.0; + const double jit_ms = ShouldExecuteInMode(vm::ExecutionMode::Compiled) + ? ExecuteInMode(module.get(), vm::ExecutionMode::Compiled, exec_ctx.get()) + : 0.0; + const double adaptive_ms = ShouldExecuteInMode(vm::ExecutionMode::Adaptive) + ? ExecuteInMode(module.get(), vm::ExecutionMode::Adaptive, exec_ctx.get()) + : 0.0; // - // JIT + // Dump stats // - { - exec_ctx.SetExecutionMode(static_cast(vm::ExecutionMode::Compiled)); - util::ScopedTimer timer(&jit_exec_ms); - if (IS_SQL) { - std::function main; - if (!module->GetFunction("main", vm::ExecutionMode::Compiled, &main)) { - EXECUTION_LOG_ERROR("Missing 'main' entry function with signature (*ExecutionContext)->int32"); - return; - } - util::Timer x; - x.Start(); - EXECUTION_LOG_INFO("JIT main() returned: {}", main(&exec_ctx)); - x.Stop(); - EXECUTION_LOG_INFO("Jit exec: {} ms", x.GetElapsed()); - } else { - std::function main; - if (!module->GetFunction("main", vm::ExecutionMode::Compiled, &main)) { - EXECUTION_LOG_ERROR("Missing 'main' entry function with signature ()->int32"); - return; - } - EXECUTION_LOG_INFO("JIT main() returned: {}", main()); - } - } - - // Dump stats EXECUTION_LOG_INFO( "Parse: {} ms, Type-check: {} ms, Code-gen: {} ms, Interp. Exec.: {} ms, " - "Adaptive Exec.: {} ms, Jit+Exec.: {} ms", - parse_ms, typecheck_ms, codegen_ms, interp_exec_ms, adaptive_exec_ms, jit_exec_ms); + "JIT Exec.: {} ms, Adaptive Exec.: {} ms", + parse_ms, typecheck_ms, codegen_ms, vm_ms, jit_ms, adaptive_ms); + txn_manager->Commit(txn, transaction::TransactionUtil::EmptyCallback, nullptr); + + SUPPRESS_UNUSED(vm_ms); + SUPPRESS_UNUSED(jit_ms); + SUPPRESS_UNUSED(adaptive_ms); } /** diff --git a/util/include/execution/table_generator/schema_reader.h b/util/include/execution/table_generator/schema_reader.h index d28368b927..78e4fc69fd 100644 --- a/util/include/execution/table_generator/schema_reader.h +++ b/util/include/execution/table_generator/schema_reader.h @@ -104,10 +104,12 @@ class SchemaReader { */ SchemaReader() : type_names_{{"tinyint", execution::sql::SqlTypeId::TinyInt}, {"smallint", execution::sql::SqlTypeId::SmallInt}, - {"int", execution::sql::SqlTypeId::Integer}, {"bigint", execution::sql::SqlTypeId::BigInt}, - {"bool", execution::sql::SqlTypeId::Boolean}, {"real", execution::sql::SqlTypeId::Double}, + {"integer", execution::sql::SqlTypeId::Integer}, {"int", execution::sql::SqlTypeId::Integer}, + {"bigint", execution::sql::SqlTypeId::BigInt}, {"bool", execution::sql::SqlTypeId::Boolean}, + {"real", execution::sql::SqlTypeId::Double}, {"float8", execution::sql::SqlTypeId::Double}, {"decimal", execution::sql::SqlTypeId::Double}, {"varchar", execution::sql::SqlTypeId::Varchar}, - {"varlen", execution::sql::SqlTypeId::Varchar}, {"date", execution::sql::SqlTypeId::Date}} {} + {"char", execution::sql::SqlTypeId::Char}, {"varlen", execution::sql::SqlTypeId::Varchar}, + {"date", execution::sql::SqlTypeId::Date}} {} /** * Reads table metadata