diff --git a/include/cuco/detail/equal_wrapper.cuh b/include/cuco/detail/equal_wrapper.cuh index 9dc6b030b..4fcb47bfd 100644 --- a/include/cuco/detail/equal_wrapper.cuh +++ b/include/cuco/detail/equal_wrapper.cuh @@ -26,7 +26,7 @@ namespace detail { * @brief Enum of equality comparison results */ // ENUM VALUE MATTERS, DO NOT CHANGE -enum class equal_result : int32_t { UNEQUAL = 0, EQUAL = 1, EMPTY = 2, AVAILABLE = 3 }; +enum class equal_result : int32_t { UNEQUAL = 0, EQUAL = 1, EMPTY = 2, ERASED = 3 }; enum class is_insert : bool { YES, NO }; @@ -97,10 +97,13 @@ struct equal_wrapper { __device__ constexpr equal_result operator()(LHS const& lhs, RHS const& rhs) const noexcept { if constexpr (IsInsert == is_insert::YES) { - return (cuco::detail::bitwise_compare(rhs, empty_sentinel_) or - cuco::detail::bitwise_compare(rhs, erased_sentinel_)) - ? equal_result::AVAILABLE - : this->equal_to(lhs, rhs); + if (cuco::detail::bitwise_compare(rhs, empty_sentinel_)) { + return equal_result::EMPTY; + } else if (cuco::detail::bitwise_compare(rhs, erased_sentinel_)) { + return equal_result::ERASED; + } else { + return this->equal_to(lhs, rhs); + } } else { return cuco::detail::bitwise_compare(rhs, empty_sentinel_) ? equal_result::EMPTY : this->equal_to(lhs, rhs); diff --git a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh index 1676fb816..8100a84cf 100644 --- a/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh +++ b/include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh @@ -383,7 +383,12 @@ class open_addressing_ref_impl { auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent()); auto const init_idx = *probing_iter; + [[maybe_unused]] auto probing_iter_copy = probing_iter; + [[maybe_unused]] bool erased = false; + [[maybe_unused]] bool empty_after_erased = false; + while (true) { + [[maybe_unused]] continue_after_erased: auto const bucket_slots = storage_ref_[*probing_iter]; for (auto& slot_content : bucket_slots) { @@ -393,23 +398,54 @@ class open_addressing_ref_impl { if constexpr (not allows_duplicates) { // If the key is already in the container, return false if (eq_res == detail::equal_result::EQUAL) { return false; } + if (eq_res == detail::equal_result::ERASED and not erased and not empty_after_erased) { + erased = true; + probing_iter_copy = probing_iter; + } + if (eq_res == detail::equal_result::EMPTY and erased and not empty_after_erased) { + empty_after_erased = true; + probing_iter = probing_iter_copy; + goto continue_after_erased; + } } - if (eq_res == detail::equal_result::AVAILABLE) { - auto const intra_bucket_index = thrust::distance(bucket_slots.begin(), &slot_content); - switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_bucket_index, - slot_content, - val)) { - case insert_result::DUPLICATE: { - if constexpr (allows_duplicates) { - [[fallthrough]]; - } else { - return false; + + if (not erased or empty_after_erased) { + if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) { + auto const intra_bucket_index = thrust::distance(bucket_slots.begin(), &slot_content); + switch ( + attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_bucket_index, + slot_content, + val)) { + case insert_result::DUPLICATE: { + if constexpr (allows_duplicates) { + [[fallthrough]]; + } else { + return false; + } } + case insert_result::CONTINUE: continue; + case insert_result::SUCCESS: return true; } - case insert_result::CONTINUE: continue; - case insert_result::SUCCESS: return true; } } + + // if (eq_res == detail::equal_result::AVAILABLE) { + // auto const intra_bucket_index = thrust::distance(bucket_slots.begin(), &slot_content); + // switch (attempt_insert((storage_ref_.data() + *probing_iter)->data() + + // intra_bucket_index, + // slot_content, + // val)) { + // case insert_result::DUPLICATE: { + // if constexpr (allows_duplicates) { + // [[fallthrough]]; + // } else { + // return false; + // } + // } + // case insert_result::CONTINUE: continue; + // case insert_result::SUCCESS: return true; + // } + // } } ++probing_iter; if (*probing_iter == init_idx) { return false; } @@ -442,8 +478,10 @@ class open_addressing_ref_impl { for (auto i = 0; i < bucket_size; ++i) { switch ( this->predicate_.operator()(key, this->extract_key(bucket_slots[i]))) { - case detail::equal_result::AVAILABLE: - return bucket_probing_results{detail::equal_result::AVAILABLE, i}; + case detail::equal_result::EMPTY: + return bucket_probing_results{detail::equal_result::EMPTY, i}; + case detail::equal_result::ERASED: + return bucket_probing_results{detail::equal_result::ERASED, i}; case detail::equal_result::EQUAL: { if constexpr (allows_duplicates) { continue; @@ -463,7 +501,8 @@ class open_addressing_ref_impl { if (group.any(state == detail::equal_result::EQUAL)) { return false; } } - auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE); + auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or + (state == detail::equal_result::ERASED)); if (group_contains_available) { auto const src_lane = __ffs(group_contains_available) - 1; auto const status = @@ -538,7 +577,7 @@ class open_addressing_ref_impl { } return {iterator{&bucket_ptr[i]}, false}; } - if (eq_res == detail::equal_result::AVAILABLE) { + if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) { switch (this->attempt_insert_stable(bucket_ptr + i, bucket_slots[i], val)) { case insert_result::SUCCESS: { if constexpr (has_payload) { @@ -626,7 +665,8 @@ class open_addressing_ref_impl { return {iterator{reinterpret_cast(res)}, false}; } - auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE); + auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or + (state == detail::equal_result::ERASED)); if (group_contains_available) { auto const src_lane = __ffs(group_contains_available) - 1; auto const res = group.shfl(reinterpret_cast(slot_ptr), src_lane); diff --git a/include/cuco/detail/static_map/static_map_ref.inl b/include/cuco/detail/static_map/static_map_ref.inl index 2deacaae3..cfd1ae035 100644 --- a/include/cuco/detail/static_map/static_map_ref.inl +++ b/include/cuco/detail/static_map/static_map_ref.inl @@ -510,7 +510,7 @@ class operator_impl< payload_ref.store(val.second, cuda::memory_order_relaxed); return; } - if (eq_res == detail::equal_result::AVAILABLE) { + if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) { if (attempt_insert_or_assign(slot_ptr, val)) { return; } } } @@ -571,7 +571,8 @@ class operator_impl< return; } - auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE); + auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or + (state == detail::equal_result::ERASED)); if (group_contains_available) { auto const src_lane = __ffs(group_contains_available) - 1; auto const status = @@ -883,7 +884,7 @@ class operator_impl< op(cuda::atomic_ref{slot_ptr->second}, val.second); return false; } - if (eq_res == detail::equal_result::AVAILABLE) { + if ((eq_res == detail::equal_result::EMPTY) or (eq_res == detail::equal_result::ERASED)) { switch (ref_.attempt_insert_or_apply(slot_ptr, slot_content, val, op)) { case insert_result::SUCCESS: return true; case insert_result::DUPLICATE: { @@ -970,7 +971,8 @@ class operator_impl< return false; } - auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE); + auto const group_contains_available = group.ballot((state == detail::equal_result::EMPTY) or + (state == detail::equal_result::ERASED)); if (group_contains_available) { auto const src_lane = __ffs(group_contains_available) - 1; auto const status = [&, target_idx = intra_bucket_index]() { diff --git a/tests/static_map/erase_test.cu b/tests/static_map/erase_test.cu index 4d68a680d..7025bc94b 100644 --- a/tests/static_map/erase_test.cu +++ b/tests/static_map/erase_test.cu @@ -75,6 +75,10 @@ void test_erase(Map& map, size_type num_keys) REQUIRE(cuco::test::all_of( d_keys_exist.begin() + num_keys / 2, d_keys_exist.end(), thrust::identity{})); + // tests #606 + map.insert(pairs_begin + num_keys / 2, pairs_begin + num_keys); + // TODO insert_and_find, insert_or_assign, insert_or_apply + map.erase(keys_begin + num_keys / 2, keys_begin + num_keys); REQUIRE(map.size() == 0); }