From 2637682c58b2bd682d579e228add93dd9d05dbd3 Mon Sep 17 00:00:00 2001 From: Ypatia Tsavliri Date: Mon, 9 Dec 2024 10:36:05 +0200 Subject: [PATCH] Add classes for Task and SharedTask in threadpool (#5391) Currently `ThreadPool::Task` is a typedef to `std:: future`. This PR: 1) Replaces this with a full class that imposes things like `wait` for our async processes to go through our ThreadPool's wait method. In that way, we never use a `std::future`'s `wait` method directly and we avoid possible deadlocks because only the ThreadPool's `wait` will yield. 2) Adjusts the structure and relationship of tasks and the threadpool so that the caller can rely on the task's internal functions and doesn't need to keep track of the relationship between the ThreadPool and each task manually throughout the codebase. 3) Adds a new first class citizen of our ThreadPool: `SharedTask`, that is encapsulating `std::shared_future` which allows multiple threads to wait on an async operation result to become available. This is needed for the upcoming work on parallelizing IO and compute operations in the codebase even further. This is heavily influenced from similar work by @Shelnutt2. --- TYPE: IMPROVEMENT DESC: Add classes for Task and SharedTask in threadpool --------- Co-authored-by: Seth Shelnutt --- .../thread_pool/test/unit_thread_pool.cc | 6 +- tiledb/common/thread_pool/thread_pool.cc | 50 ++++- tiledb/common/thread_pool/thread_pool.h | 196 +++++++++++++++++- tiledb/sm/query/readers/dense_reader.cc | 14 +- tiledb/sm/query/readers/dense_reader.h | 2 +- tiledb/sm/query/writers/ordered_writer.cc | 8 +- 6 files changed, 244 insertions(+), 32 deletions(-) diff --git a/tiledb/common/thread_pool/test/unit_thread_pool.cc b/tiledb/common/thread_pool/test/unit_thread_pool.cc index 5220f9be8a9..c119d2acd4f 100644 --- a/tiledb/common/thread_pool/test/unit_thread_pool.cc +++ b/tiledb/common/thread_pool/test/unit_thread_pool.cc @@ -100,7 +100,7 @@ void wait_all( ThreadPool& pool, bool use_wait, std::vector& results) { if (use_wait) { for (auto& r : results) { - REQUIRE(pool.wait(r).ok()); + REQUIRE(r.wait().ok()); } } else { REQUIRE(pool.wait_all(results).ok()); @@ -117,7 +117,7 @@ Status wait_all_status( if (use_wait) { Status ret; for (auto& r : results) { - auto st = pool.wait(r); + auto st = r.wait(); if (ret.ok() && !st.ok()) { ret = st; } @@ -139,7 +139,7 @@ uint64_t wait_all_num_status( int num_ok = 0; if (use_wait) { for (auto& r : results) { - num_ok += pool.wait(r).ok() ? 1 : 0; + num_ok += r.wait().ok() ? 1 : 0; } } else { std::vector statuses = pool.wait_all_status(results); diff --git a/tiledb/common/thread_pool/thread_pool.cc b/tiledb/common/thread_pool/thread_pool.cc index b026262a3c3..f4f7f9d1709 100644 --- a/tiledb/common/thread_pool/thread_pool.cc +++ b/tiledb/common/thread_pool/thread_pool.cc @@ -121,7 +121,7 @@ void ThreadPool::shutdown() { threads_.clear(); } -Status ThreadPool::wait_all(std::vector& tasks) { +Status ThreadPool::wait_all(std::vector& tasks) { auto statuses = wait_all_status(tasks); for (auto& st : statuses) { if (!st.ok()) { @@ -131,6 +131,24 @@ Status ThreadPool::wait_all(std::vector& tasks) { return Status::Ok(); } +Status ThreadPool::wait_all(std::vector& tasks) { + std::vector task_ptrs; + for (auto& t : tasks) { + task_ptrs.emplace_back(&t); + } + + return wait_all(task_ptrs); +} + +Status ThreadPool::wait_all(std::vector& tasks) { + std::vector task_ptrs; + for (auto& t : tasks) { + task_ptrs.emplace_back(&t); + } + + return wait_all(task_ptrs); +} + // Return a vector of Status. If any task returns an error value or throws an // exception, we save an error code in the corresponding location in the Status // vector. All tasks are waited on before return. Multiple error statuses may @@ -138,7 +156,8 @@ Status ThreadPool::wait_all(std::vector& tasks) { // context is fully constructed (which will include logger). // Unfortunately, C++ does not have the notion of an aggregate exception, so we // don't throw in the case of errors/exceptions. -std::vector ThreadPool::wait_all_status(std::vector& tasks) { +std::vector ThreadPool::wait_all_status( + std::vector& tasks) { std::vector statuses(tasks.size()); std::queue pending_tasks; @@ -154,17 +173,17 @@ std::vector ThreadPool::wait_all_status(std::vector& tasks) { pending_tasks.pop(); auto& task = tasks[task_id]; - if (!task.valid()) { + if (task && !task->valid()) { statuses[task_id] = Status_ThreadPoolError("Invalid task future"); LOG_STATUS_NO_RETURN_VALUE(statuses[task_id]); } else if ( - task.wait_for(std::chrono::milliseconds(0)) == + task->wait_for(std::chrono::milliseconds(0)) == std::future_status::ready) { // Task is completed, get result, handling possible exceptions Status st = [&task] { try { - return task.get(); + return task->get(); } catch (const std::exception& e) { return Status_TaskError( "Caught std::exception: " + std::string(e.what())); @@ -205,7 +224,26 @@ std::vector ThreadPool::wait_all_status(std::vector& tasks) { return statuses; } -Status ThreadPool::wait(Task& task) { +std::vector ThreadPool::wait_all_status(std::vector& tasks) { + std::vector task_ptrs; + for (auto& t : tasks) { + task_ptrs.emplace_back(&t); + } + + return wait_all_status(task_ptrs); +} + +std::vector ThreadPool::wait_all_status( + std::vector& tasks) { + std::vector task_ptrs; + for (auto& t : tasks) { + task_ptrs.emplace_back(&t); + } + + return wait_all_status(task_ptrs); +} + +Status ThreadPool::wait(ThreadPoolTask& task) { while (true) { if (!task.valid()) { return Status_ThreadPoolError("Invalid task future"); diff --git a/tiledb/common/thread_pool/thread_pool.h b/tiledb/common/thread_pool/thread_pool.h index e0f804c7278..f72ad34f256 100644 --- a/tiledb/common/thread_pool/thread_pool.h +++ b/tiledb/common/thread_pool/thread_pool.h @@ -47,7 +47,175 @@ namespace tiledb::common { class ThreadPool { public: - using Task = std::future; + /** + * @brief Abstract base class for tasks that can run in this threadpool. + */ + class ThreadPoolTask { + public: + ThreadPoolTask() = default; + ThreadPoolTask(ThreadPool* tp) + : tp_(tp){}; + + virtual ~ThreadPoolTask(){}; + + protected: + friend class ThreadPool; + + /* C.67 A polymorphic class should suppress public copy/move to prevent + * slicing */ + ThreadPoolTask(const ThreadPoolTask&) = default; + ThreadPoolTask& operator=(const ThreadPoolTask&) = default; + ThreadPoolTask(ThreadPoolTask&&) = default; + ThreadPoolTask& operator=(ThreadPoolTask&&) = default; + + /** + * Pure virtual functions that tasks need to implement so that they can be + * run in the threadpool wait loop + */ + virtual std::future_status wait_for( + const std::chrono::milliseconds timeout_duration) const = 0; + virtual bool valid() const noexcept = 0; + virtual Status get() = 0; + + ThreadPool* tp_{nullptr}; + }; + + /** + * @brief Task class encapsulating std::future. Like std::future it's shared + * state can only be get once and thus only one thread. It can only be moved + * and not copied. + */ + class Task : public ThreadPoolTask { + public: + /** + * Default constructor + * @brief Default constructed SharedTask is possible but not valid(). + */ + Task() = default; + + /** + * Constructor from std::future + */ + Task(std::future&& f, ThreadPool* tp) + : ThreadPoolTask(tp) + , f_(std::move(f)){}; + + /** + * Wait in the threadpool for this task to be ready. + */ + Status wait() { + if (tp_ == nullptr) { + throw std::runtime_error("Cannot wait, threadpool is not initialized."); + } else if (!f_.valid()) { + throw std::runtime_error("Cannot wait, task is invalid."); + } else { + return tp_->wait(*this); + } + } + + /** + * Is this task valid. Wait can only be called on vaid tasks. + */ + bool valid() const noexcept { + return f_.valid(); + } + + private: + friend class ThreadPool; + + /** + * Wait for input milliseconds for this task to be ready. + */ + std::future_status wait_for( + const std::chrono::milliseconds timeout_duration) const { + return f_.wait_for(timeout_duration); + } + + /** + * Get the result of that task. Can only be used once. Only accessible from + * within the threadpool `wait` loop. + */ + Status get() { + return f_.get(); + } + + /** + * The encapsulated std::shared_future + */ + std::future f_; + }; + + /** + * @brief SharedTask class encapsulating std::shared_future. Like + * std::shared_future multiple threads can wait/get on the shared state + * multiple times. It can be both moved and copied. + */ + class SharedTask : public ThreadPoolTask { + public: + /** + * Default constructor + * @brief Default constructed SharedTask is possible but not valid(). + */ + SharedTask() = default; + + /** + * Constructor from std::future or std::shared_future + */ + SharedTask(auto&& f, ThreadPool* tp) + : ThreadPoolTask(tp) + , f_(std::forward(f)){}; + + /** + * Move constructor from a Task + */ + SharedTask(Task&& t) noexcept + : ThreadPoolTask(t.tp_) + , f_(std::move(t.f_)){}; + + /** + * Wait in the threadpool for this task to be ready. + */ + Status wait() { + if (tp_ == nullptr) { + throw std::runtime_error("Cannot wait, threadpool is not initialized."); + } else if (!f_.valid()) { + throw std::runtime_error("Cannot wait, shared task is invalid."); + } else { + return tp_->wait(*this); + } + } + + /** + * Is this task valid. Wait can only be called on vaid tasks. + */ + bool valid() const noexcept { + return f_.valid(); + } + + private: + friend class ThreadPool; + + /** + * Wait for input milliseconds for this task to be ready. + */ + std::future_status wait_for( + const std::chrono::milliseconds timeout_duration) const { + return f_.wait_for(timeout_duration); + } + + /** + * Get the result of that task. Can be called multiple times from multiple + * threads. Only accessible from within the threadpool `wait` loop. + */ + Status get() { + return f_.get(); + } + + /** + * The encapsulated std::shared_future + */ + std::shared_future f_; + }; /* ********************************* */ /* CONSTRUCTORS & DESTRUCTORS */ @@ -108,7 +276,7 @@ class ThreadPool { return std::apply(std::move(f), std::move(args)); }); - std::future future = task->get_future(); + Task future(task->get_future(), this); task_queue_.push(task); @@ -127,6 +295,19 @@ class ThreadPool { return async(std::forward(f), std::forward(args)...); } + /* Helper functions for lists that consists purely of Tasks */ + Status wait_all(std::vector& tasks); + std::vector wait_all_status(std::vector& tasks); + + /* Helper functions for lists that consists purely of SharedTasks */ + Status wait_all(std::vector& tasks); + std::vector wait_all_status(std::vector& tasks); + + /* ********************************* */ + /* PRIVATE ATTRIBUTES */ + /* ********************************* */ + + private: /** * Wait on all the given tasks to complete. This function is safe to call * recursively and may execute pending tasks on the calling thread while @@ -136,7 +317,7 @@ class ThreadPool { * @return Status::Ok if all tasks returned Status::Ok, otherwise the first * error status is returned */ - Status wait_all(std::vector& tasks); + Status wait_all(std::vector& tasks); /** * Wait on all the given tasks to complete, returning a vector of their return @@ -151,7 +332,7 @@ class ThreadPool { * @param tasks Task list to wait on * @return Vector of each task's Status. */ - std::vector wait_all_status(std::vector& tasks); + std::vector wait_all_status(std::vector& tasks); /** * Wait on a single tasks to complete. This function is safe to call @@ -162,13 +343,8 @@ class ThreadPool { * @return Status::Ok if the task returned Status::Ok, otherwise the error * status is returned */ - Status wait(Task& task); - - /* ********************************* */ - /* PRIVATE ATTRIBUTES */ - /* ********************************* */ + Status wait(ThreadPoolTask& task); - private: /** The worker thread routine */ void worker(); diff --git a/tiledb/sm/query/readers/dense_reader.cc b/tiledb/sm/query/readers/dense_reader.cc index ef29f439ca0..f59c6dbe4d9 100644 --- a/tiledb/sm/query/readers/dense_reader.cc +++ b/tiledb/sm/query/readers/dense_reader.cc @@ -368,7 +368,7 @@ Status DenseReader::dense_read() { // This is as far as we should go before implementing this properly in a task // graph, where the start and end of every piece of work can clearly be // identified. - ThreadPool::Task compute_task; + ThreadPool::SharedTask compute_task; // Allow to disable the parallel read/compute in case the memory budget // doesn't allow it. @@ -432,7 +432,7 @@ Status DenseReader::dense_read() { // prevent using too much memory when the budget is small and doesn't allow // to process more than one batch at a time. if (wait_compute_task_before_read && compute_task.valid()) { - throw_if_not_ok(resources_.compute_tp().wait(compute_task)); + throw_if_not_ok(compute_task.wait()); } // Apply the query condition. @@ -478,7 +478,7 @@ Status DenseReader::dense_read() { // is to prevent using too much memory when the budget is small and // doesn't allow to process more than one batch at a time. if (wait_compute_task_before_read && compute_task.valid()) { - throw_if_not_ok(resources_.compute_tp().wait(compute_task)); + throw_if_not_ok(compute_task.wait()); } // Read and unfilter tiles. @@ -489,7 +489,7 @@ Status DenseReader::dense_read() { } if (compute_task.valid()) { - throw_if_not_ok(resources_.compute_tp().wait(compute_task)); + throw_if_not_ok(compute_task.wait()); if (read_state_.overflowed_) { return Status::Ok(); } @@ -578,7 +578,7 @@ Status DenseReader::dense_read() { } if (compute_task.valid()) { - throw_if_not_ok(resources_.compute_tp().wait(compute_task)); + throw_if_not_ok(compute_task.wait()); } // For `qc_coords_mode` just fill in the coordinates and skip attribute @@ -1038,7 +1038,7 @@ std::vector DenseReader::result_tiles_to_load( */ template Status DenseReader::apply_query_condition( - ThreadPool::Task& compute_task, + ThreadPool::SharedTask& compute_task, Subarray& subarray, const std::unordered_set& condition_names, const std::vector& tile_extents, @@ -1075,7 +1075,7 @@ Status DenseReader::apply_query_condition( NameToLoad::from_string_vec(qc_names), result_tiles)); if (compute_task.valid()) { - throw_if_not_ok(resources_.compute_tp().wait(compute_task)); + throw_if_not_ok(compute_task.wait()); } compute_task = resources_.compute_tp().execute([&, diff --git a/tiledb/sm/query/readers/dense_reader.h b/tiledb/sm/query/readers/dense_reader.h index 2434839157e..58e81607e2f 100644 --- a/tiledb/sm/query/readers/dense_reader.h +++ b/tiledb/sm/query/readers/dense_reader.h @@ -285,7 +285,7 @@ class DenseReader : public ReaderBase, public IQueryStrategy { /** Apply the query condition. */ template Status apply_query_condition( - ThreadPool::Task& compute_task, + ThreadPool::SharedTask& compute_task, Subarray& subarray, const std::unordered_set& condition_names, const std::vector& tile_extents, diff --git a/tiledb/sm/query/writers/ordered_writer.cc b/tiledb/sm/query/writers/ordered_writer.cc index ed20b1f35d4..e4ba47500de 100644 --- a/tiledb/sm/query/writers/ordered_writer.cc +++ b/tiledb/sm/query/writers/ordered_writer.cc @@ -313,7 +313,7 @@ Status OrderedWriter::prepare_filter_and_write_tiles( uint64_t frag_tile_id = 0; bool close_files = false; tile_batches.resize(batch_num); - std::optional write_task = nullopt; + std::optional write_task = nullopt; for (uint64_t b = 0; b < batch_num; ++b) { auto batch_size = (b == batch_num - 1) ? last_batch_size : thread_num; assert(batch_size > 0); @@ -358,8 +358,7 @@ Status OrderedWriter::prepare_filter_and_write_tiles( } if (write_task.has_value()) { - write_task->wait(); - RETURN_NOT_OK(write_task->get()); + RETURN_NOT_OK(write_task->wait()); } write_task = resources_.io_tp().execute([&, b, frag_tile_id]() { @@ -380,8 +379,7 @@ Status OrderedWriter::prepare_filter_and_write_tiles( } if (write_task.has_value()) { - write_task->wait(); - RETURN_NOT_OK(write_task->get()); + RETURN_NOT_OK(write_task->wait()); } return Status::Ok();