Skip to content

Commit

Permalink
Support thread local object iteration (#2632)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenBright authored Jun 3, 2024
1 parent 59a68dd commit 02fd47c
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 61 deletions.
81 changes: 29 additions & 52 deletions src/butil/thread_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ pthread_mutex_t g_thread_key_mutex = PTHREAD_MUTEX_INITIALIZER;
static size_t g_id = 0;
static std::deque<size_t>* g_free_ids = NULL;
static std::vector<ThreadKeyInfo>* g_thread_keys = NULL;
static __thread std::vector<ThreadKeyTLS>* g_tls_data = NULL;
static __thread std::vector<ThreadKeyTLS>* thread_key_tls_data = NULL;

ThreadKey& ThreadKey::operator=(ThreadKey&& other) noexcept {
if (this == &other) {
Expand All @@ -56,58 +56,42 @@ bool ThreadKey::Valid() const {
}

static void DestroyTlsData() {
if (!g_tls_data) {
if (!thread_key_tls_data) {
return;
}
std::vector<ThreadKeyInfo> dummy_keys;
{
BAIDU_SCOPED_LOCK(g_thread_key_mutex);
if (BAIDU_LIKELY(g_thread_keys)) {
dummy_keys.insert(dummy_keys.end(), g_thread_keys->begin(), g_thread_keys->end());
}
dummy_keys.insert(dummy_keys.end(),
g_thread_keys->begin(),
g_thread_keys->end());
}
for (size_t i = 0; i < g_tls_data->size(); ++i) {
for (size_t i = 0; i < thread_key_tls_data->size(); ++i) {
if (!KEY_UNUSED(dummy_keys[i].seq) && dummy_keys[i].dtor) {
dummy_keys[i].dtor((*g_tls_data)[i].data);
dummy_keys[i].dtor((*thread_key_tls_data)[i].data);
}
}
delete g_tls_data;
g_tls_data = NULL;
}

static std::deque<size_t>* GetGlobalFreeIds() {
if (BAIDU_UNLIKELY(!g_free_ids)) {
g_free_ids = new (std::nothrow) std::deque<size_t>();
if (BAIDU_UNLIKELY(!g_free_ids)) {
abort();
}
}

return g_free_ids;
delete thread_key_tls_data;
thread_key_tls_data = NULL;
}

int thread_key_create(ThreadKey& thread_key, DtorFunction dtor) {
BAIDU_SCOPED_LOCK(g_thread_key_mutex);
size_t id;
auto free_ids = GetGlobalFreeIds();
if (!free_ids) {
return ENOMEM;
if (BAIDU_UNLIKELY(!g_free_ids)) {
g_free_ids = new std::deque<size_t>;
}

if (!free_ids->empty()) {
id = free_ids->back();
free_ids->pop_back();
size_t id;
if (!g_free_ids->empty()) {
id = g_free_ids->back();
g_free_ids->pop_back();
} else {
if (g_id >= ThreadKey::InvalidID) {
// No more available ids.
return EAGAIN;
}
id = g_id++;
if(BAIDU_UNLIKELY(!g_thread_keys)) {
g_thread_keys = new (std::nothrow) std::vector<ThreadKeyInfo>;
if(BAIDU_UNLIKELY(!g_thread_keys)) {
return ENOMEM;
}
if (BAIDU_UNLIKELY(!g_thread_keys)) {
g_thread_keys = new std::vector<ThreadKeyInfo>;
g_thread_keys->reserve(THREAD_KEY_RESERVE);
}
g_thread_keys->resize(id + 1);
Expand Down Expand Up @@ -136,14 +120,10 @@ int thread_key_delete(ThreadKey& thread_key) {
return EINVAL;
}

if (BAIDU_UNLIKELY(!GetGlobalFreeIds())) {
return ENOMEM;
}

++((*g_thread_keys)[id].seq);
// Collect the usable key id for reuse.
if (KEY_USABLE((*g_thread_keys)[id].seq)) {
GetGlobalFreeIds()->push_back(id);
g_free_ids->push_back(id);
}
thread_key.Reset();

Expand All @@ -156,22 +136,19 @@ int thread_setspecific(ThreadKey& thread_key, void* data) {
}
size_t id = thread_key._id;
size_t seq = thread_key._seq;
if (BAIDU_UNLIKELY(!g_tls_data)) {
g_tls_data = new (std::nothrow) std::vector<ThreadKeyTLS>;
if (BAIDU_UNLIKELY(!g_tls_data)) {
return ENOMEM;
}
g_tls_data->reserve(THREAD_KEY_RESERVE);
if (BAIDU_UNLIKELY(!thread_key_tls_data)) {
thread_key_tls_data = new std::vector<ThreadKeyTLS>;
thread_key_tls_data->reserve(THREAD_KEY_RESERVE);
// Register the destructor of tls_data in this thread.
butil::thread_atexit(DestroyTlsData);
}

if (id >= g_tls_data->size()) {
g_tls_data->resize(id + 1);
if (id >= thread_key_tls_data->size()) {
thread_key_tls_data->resize(id + 1);
}

(*g_tls_data)[id].seq = seq;
(*g_tls_data)[id].data = data;
(*thread_key_tls_data)[id].seq = seq;
(*thread_key_tls_data)[id].data = data;

return 0;
}
Expand All @@ -182,13 +159,13 @@ void* thread_getspecific(ThreadKey& thread_key) {
}
size_t id = thread_key._id;
size_t seq = thread_key._seq;
if (BAIDU_UNLIKELY(!g_tls_data ||
id >= g_tls_data->size() ||
(*g_tls_data)[id].seq != seq)){
if (BAIDU_UNLIKELY(!thread_key_tls_data ||
id >= thread_key_tls_data->size() ||
(*thread_key_tls_data)[id].seq != seq)){
return NULL;
}

return (*g_tls_data)[id].data;
return (*thread_key_tls_data)[id].data;
}

} // namespace butil
28 changes: 23 additions & 5 deletions src/butil/thread_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <stdlib.h>
#include <vector>
#include "butil/scoped_lock.h"
#include "butil/type_traits.h"

namespace butil {

Expand All @@ -38,7 +39,7 @@ class ThreadKey {
static constexpr size_t InvalidID = std::numeric_limits<size_t>::max();
static constexpr size_t InitSeq = 0;

constexpr ThreadKey() :_id(InvalidID), _seq(InitSeq) {}
constexpr ThreadKey() : _id(InvalidID), _seq(InitSeq) {}

~ThreadKey() {
Reset();
Expand All @@ -62,7 +63,7 @@ class ThreadKey {
_seq = InitSeq;
}

private:
private:
size_t _id; // Key id.
// Sequence number form g_thread_keys set in thread_key_create.
size_t _seq;
Expand Down Expand Up @@ -111,6 +112,20 @@ class ThreadLocal {

T& operator*() const { return *get(); }

// Iterate through all thread local objects.
// Callback, which must accept Args params and return void,
// will be called under a thread lock.
template <typename Callback>
void for_each(Callback&& callback) {
BAIDU_CASSERT(
(is_result_void<Callback, T*>::value),
"Callback must accept Args params and return void");
BAIDU_SCOPED_LOCK(_mutex);
for (auto ptr : ptrs) {
callback(ptr);
}
}

void reset(T* ptr);

void reset() {
Expand Down Expand Up @@ -177,6 +192,9 @@ T* ThreadLocal<T>::get() {
template <typename T>
void ThreadLocal<T>::reset(T* ptr) {
T* old_ptr = get();
if (ptr == old_ptr) {
return;
}
if (thread_setspecific(_key, ptr) != 0) {
return;
}
Expand All @@ -187,9 +205,9 @@ void ThreadLocal<T>::reset(T* ptr) {
}
// Remove and delete old_ptr.
if (old_ptr) {
auto iter = std::find(ptrs.begin(), ptrs.end(), old_ptr);
if (iter!=ptrs.end()) {
ptrs.erase(iter);
auto iter = std::remove(ptrs.begin(), ptrs.end(), old_ptr);
if (iter != ptrs.end()) {
ptrs.erase(iter, ptrs.end());
}
DefaultDtor(old_ptr);
}
Expand Down
4 changes: 2 additions & 2 deletions test/endpoint_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -486,11 +486,11 @@ TEST(EndPointTest, tcp_connect) {
ASSERT_EQ(0, butil::hostname2endpoint(g_hostname, 80, &ep));
{
butil::fd_guard sockfd(butil::tcp_connect(ep, NULL));
ASSERT_LE(0, sockfd);
ASSERT_LE(0, sockfd) << "errno=" << errno;
}
{
butil::fd_guard sockfd(butil::tcp_connect(ep, NULL, 1000));
ASSERT_LE(0, sockfd);
ASSERT_LE(0, sockfd) << "errno=" << errno;
}
{
butil::fd_guard sockfd(butil::tcp_connect(ep, NULL, 1));
Expand Down
44 changes: 42 additions & 2 deletions test/thread_key_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ TEST(ThreadLocalTest, thread_key_seq) {
}
}

void* THreadKeyCreateAndDeleteFunc(void* arg) {
void* THreadKeyCreateAndDeleteFunc(void*) {
while (!g_stopped) {
ThreadKey key;
EXPECT_EQ(0, butil::thread_key_create(key, NULL));
Expand Down Expand Up @@ -162,7 +162,7 @@ TEST(ThreadLocalTest, thread_local_multi_thread) {
ASSERT_EQ(0, pthread_create(&threads[i], NULL, ThreadLocalFunc, &args));
}

sleep(5);
sleep(2);
g_stopped = true;
for (const auto& thread : threads) {
pthread_join(thread, NULL);
Expand All @@ -172,6 +172,46 @@ TEST(ThreadLocalTest, thread_local_multi_thread) {
}
}

butil::atomic<int> g_counter(0);

void* ThreadLocalForEachFunc(void* arg) {
auto counter = static_cast<ThreadLocal<butil::atomic<int>>*>(arg);
auto local_counter = counter->get();
EXPECT_NE(nullptr, local_counter);
while (!g_stopped) {
local_counter->fetch_add(1, butil::memory_order_relaxed);
g_counter.fetch_add(1, butil::memory_order_relaxed);
if (butil::fast_rand_less_than(100) + 1 > 80) {
local_counter = new butil::atomic<int>(
local_counter->load(butil::memory_order_relaxed));
counter->reset(local_counter);
}
}
return NULL;
}

TEST(ThreadLocalTest, thread_local_for_each) {
g_stopped = false;
ThreadLocal<butil::atomic<int>> counter(false);
const int thread_num = 8;
pthread_t threads[thread_num];
for (int i = 0; i < thread_num; ++i) {
ASSERT_EQ(0, pthread_create(
&threads[i], NULL, ThreadLocalForEachFunc, &counter));
}

sleep(2);
g_stopped = true;
for (const auto& thread : threads) {
pthread_join(thread, NULL);
}
int count = 0;
counter.for_each([&count](butil::atomic<int>* c) {
count += c->load(butil::memory_order_relaxed);
});
ASSERT_EQ(count, g_counter.load(butil::memory_order_relaxed));
}

struct BAIDU_CACHELINE_ALIGNMENT ThreadKeyArg {
std::vector<ThreadKey*> thread_keys;
bool ready_delete = false;
Expand Down

0 comments on commit 02fd47c

Please sign in to comment.