diff --git a/bench/bench_skiplist.ml b/bench/bench_skiplist.ml new file mode 100644 index 00000000..2b420570 --- /dev/null +++ b/bench/bench_skiplist.ml @@ -0,0 +1,57 @@ +open Saturn + +let workload num_elems num_threads add remove = + let sl = Skiplist.create ~compare:Int.compare () in + let elems = Array.init num_elems (fun _ -> Random.int 10000) in + let push () = + Domain.spawn (fun () -> + let start_time = Unix.gettimeofday () in + for i = 0 to (num_elems - 1) / num_threads do + Domain.cpu_relax (); + let prob = Random.float 1.0 in + if prob < add then Skiplist.try_add sl (Random.int 10000) () |> ignore + else if prob >= add && prob < add +. remove then + Skiplist.try_remove sl (Random.int 10000) |> ignore + else Skiplist.mem sl elems.(i) |> ignore + done; + start_time) + in + let threads = List.init num_threads (fun _ -> push ()) in + let start_time_threads = + List.map (fun domain -> Domain.join domain) threads + in + let end_time = Unix.gettimeofday () in + let time_diff = end_time -. List.nth start_time_threads 0 in + time_diff + +(* A write heavy workload with threads with 50% adds and 50% removes. *) +let write_heavy_workload num_elems num_threads = + workload num_elems num_threads 0.5 0.5 + +(* A regular workload with 90% reads, 9% adds and 1% removes. *) +let read_heavy_workload num_elems num_threads = + workload num_elems num_threads 0.09 0.01 + +let moderate_heavy_workload num_elems num_threads = + workload num_elems num_threads 0.2 0.1 + +let balanced_heavy_workload num_elems num_threads = + workload num_elems num_threads 0.3 0.2 + +let bench ~workload_type ~num_elems ~num_threads () = + let workload = + if workload_type = "read_heavy" then read_heavy_workload + else if workload_type = "moderate_heavy" then moderate_heavy_workload + else if workload_type = "balanced_heavy" then balanced_heavy_workload + else write_heavy_workload + in + let results = ref [] in + for i = 1 to 10 do + let time = workload num_elems num_threads in + if i > 1 then results := time :: !results + done; + let results = List.sort Float.compare !results in + let median_time = List.nth results 4 in + let median_throughput = Float.of_int num_elems /. median_time in + Benchmark_result.create_generic ~median_time ~median_throughput + ("atomic_skiplist_" ^ workload_type) diff --git a/bench/main.ml b/bench/main.ml index f99a0bfe..7f694502 100644 --- a/bench/main.ml +++ b/bench/main.ml @@ -7,6 +7,10 @@ let benchmark_list = Mpmc_queue.bench ~use_cas:true ~takers:4 ~pushers:4; Mpmc_queue.bench ~use_cas:true ~takers:1 ~pushers:8; Mpmc_queue.bench ~use_cas:true ~takers:8 ~pushers:1; + Bench_skiplist.bench ~workload_type:"read_heavy" ~num_elems:2000000 + ~num_threads:2; + Bench_skiplist.bench ~workload_type:"moderate_heavy" ~num_elems:2000000 + ~num_threads:2; ] let () = diff --git a/src/saturn.ml b/src/saturn.ml index 10c6e7c7..ad9de103 100644 --- a/src/saturn.ml +++ b/src/saturn.ml @@ -35,3 +35,4 @@ module Single_prod_single_cons_queue = module Single_consumer_queue = Saturn_lockfree.Single_consumer_queue module Relaxed_queue = Mpmc_relaxed_queue +module Skiplist = Saturn_lockfree.Skiplist diff --git a/src/saturn.mli b/src/saturn.mli index d1b5eaf0..4e48698d 100644 --- a/src/saturn.mli +++ b/src/saturn.mli @@ -39,3 +39,4 @@ module Single_prod_single_cons_queue = module Single_consumer_queue = Saturn_lockfree.Single_consumer_queue module Relaxed_queue = Mpmc_relaxed_queue +module Skiplist = Saturn_lockfree.Skiplist diff --git a/src_lockfree/saturn_lockfree.ml b/src_lockfree/saturn_lockfree.ml index e5aeb7af..9bd337c9 100644 --- a/src_lockfree/saturn_lockfree.ml +++ b/src_lockfree/saturn_lockfree.ml @@ -33,3 +33,4 @@ module Single_prod_single_cons_queue = Spsc_queue module Single_consumer_queue = Mpsc_queue module Relaxed_queue = Mpmc_relaxed_queue module Size = Size +module Skiplist = Skiplist diff --git a/src_lockfree/saturn_lockfree.mli b/src_lockfree/saturn_lockfree.mli index 44870cd2..5662b6b4 100644 --- a/src_lockfree/saturn_lockfree.mli +++ b/src_lockfree/saturn_lockfree.mli @@ -36,4 +36,5 @@ module Work_stealing_deque = Ws_deque module Single_prod_single_cons_queue = Spsc_queue module Single_consumer_queue = Mpsc_queue module Relaxed_queue = Mpmc_relaxed_queue +module Skiplist = Skiplist module Size = Size diff --git a/src_lockfree/skiplist.ml b/src_lockfree/skiplist.ml new file mode 100644 index 00000000..bd4aabf3 --- /dev/null +++ b/src_lockfree/skiplist.ml @@ -0,0 +1,355 @@ +(* Copyright (c) 2023 Vesa Karvonen + + Permission to use, copy, modify, and/or distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH + REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY + AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, + INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM + LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR + OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR + PERFORMANCE OF THIS SOFTWARE. *) + +(* This implementation has been written from scratch with inspiration from a + lock-free skiplist implementation in PR + + https://github.com/ocaml-multicore/saturn/pull/65 + + by + + Sooraj Srinivasan ( https://github.com/sooraj-srini ) + + including tests and changes by + + Carine Morel ( https://github.com/lyrm ). *) + +(* TODO: Grow and possibly shrink the skiplist or e.g. adjust search and node + generation based on the dynamic number of bindings. *) + +module Atomic = Transparent_atomic + +(* OCaml doesn't allow us to use one of the unused (always 0) bits in pointers + for the marks and an indirection is needed. This representation avoids the + indirection except for marked references in nodes to be removed. A GADT with + polymorphic variants is used to disallow nested [Mark]s. *) +type ('k, 'v, _) node = + | Null : ('k, 'v, [> `Null ]) node + | Node : { + key : 'k; + value : 'v; + next : ('k, 'v) links; + mutable incr : Size.once; + } + -> ('k, 'v, [> `Node ]) node + | Mark : { + node : ('k, 'v, [< `Null | `Node ]) node; + decr : Size.once; + } + -> ('k, 'v, [> `Mark ]) node + +(* The implementation relies on this existential being unboxed. More + specifically, it is assumed that [Link node == Link node] meaning that the + [Link] constructor does not allocate. *) +and ('k, 'v) link = + | Link : ('k, 'v, [< `Null | `Node | `Mark ]) node -> ('k, 'v) link +[@@unboxed] + +and ('k, 'v) links = ('k, 'v) link Atomic.t array + +type 'k compare = 'k -> 'k -> int +(* Encoding the [compare] function using an algebraic type would allow the + overhead of calling a closure to be avoided for selected primitive types like + [int]. *) + +type ('k, 'v) t = { compare : 'k compare; root : ('k, 'v) links; size : Size.t } + +(* *) + +(** [get_random_height max_height] gives a random value [n] in the range from + [1] to [max_height] with the desired distribution such that [n] is twice as + likely as [n + 1]. *) +let rec get_random_height max_height = + let m = (1 lsl max_height) - 1 in + let x = Random.bits () land m in + if x = 1 then + (* We reject [1] to get the desired distribution. *) + get_random_height max_height + else + (* We do a binary search for the highest 1 bit. Techniques in + + Using de Bruijn Sequences to Index a 1 in a Computer Word + by Leiserson, Prokop, and Randall + + could perhaps speed this up a bit, but this is likely not performance + critical. *) + let n = 0 in + let n, x = if 0xFFFF < x then (n + 0x10, x lsr 0x10) else (n, x) in + let n, x = if 0x00FF < x then (n + 0x08, x lsr 0x08) else (n, x) in + let n, x = if 0x000F < x then (n + 0x04, x lsr 0x04) else (n, x) in + let n, x = if 0x0003 < x then (n + 0x02, x lsr 0x02) else (n, x) in + let n, _ = if 0x0001 < x then (n + 0x01, x lsr 0x01) else (n, x) in + max_height - n + +(* *) + +let[@inline] is_marked = function + | Link (Mark _) -> true + | Link (Null | Node _) -> false + +(* *) + +(** [find_path t key preds succs lowest] tries to find the node with the specified + [key], updating [preds] and [succs] and removing nodes with marked + references along the way, and always descending down to [lowest] level. The + boolean return value is only meaningful when [lowest] is given as [0]. *) +let rec find_path t key preds succs lowest = + let prev = t.root in + let level = Array.length prev - 1 in + let prev_at_level = Array.unsafe_get prev level in + find_path_rec t key prev prev_at_level preds succs level lowest + (Atomic.get prev_at_level) + +and find_path_rec t key prev prev_at_level preds succs level lowest = function + | Link Null -> + if level < Array.length preds then begin + Array.unsafe_set preds level prev_at_level; + Array.unsafe_set succs level Null + end; + lowest < level + && + let level = level - 1 in + let prev_at_level = Array.unsafe_get prev level in + find_path_rec t key prev prev_at_level preds succs level lowest + (Atomic.get prev_at_level) + | Link (Node r as curr) -> begin + let next_at_level = Array.unsafe_get r.next level in + match Atomic.get next_at_level with + | Link (Null | Node _) as next -> + let c = t.compare key r.key in + if 0 < c then + find_path_rec t key r.next next_at_level preds succs level lowest + next + else begin + if level < Array.length preds then begin + Array.unsafe_set preds level (Array.unsafe_get prev level); + Array.unsafe_set succs level curr + end; + if lowest < level then + let level = level - 1 in + let prev_at_level = Array.unsafe_get prev level in + find_path_rec t key prev prev_at_level preds succs level lowest + (Atomic.get prev_at_level) + else begin + if level = 0 && r.incr != Size.used_once then begin + Size.update_once t.size r.incr; + r.incr <- Size.used_once + end; + 0 = c + end + end + | Link (Mark r) -> + (* The [curr_node] is being removed from the skiplist and we help with + that. *) + if level = 0 then Size.update_once t.size r.decr; + find_path_rec t key prev prev_at_level preds succs level lowest + (let after = Link r.node in + if Atomic.compare_and_set prev_at_level (Link curr) after then + after + else Atomic.get prev_at_level) + end + | Link (Mark _) -> + (* The node corresponding to [prev] is being removed from the skiplist. + This means we might no longer have an up-to-date view of the skiplist + and so we must restart the search. *) + find_path t key preds succs lowest + +(* *) + +(** [find_node t key] tries to find the node with the specified [key], removing + nodes with marked references along the way, and stopping as soon as the node + is found. *) +let rec find_node t key = + let prev = t.root in + let level = Array.length prev - 1 in + let prev_at_level = Array.unsafe_get prev level in + find_node_rec t key prev prev_at_level level (Atomic.get prev_at_level) + +and find_node_rec t key prev prev_at_level level : + _ -> (_, _, [< `Null | `Node ]) node = function + | Link Null -> + if 0 < level then + let level = level - 1 in + let prev_at_level = Array.unsafe_get prev level in + find_node_rec t key prev prev_at_level level (Atomic.get prev_at_level) + else Null + | Link (Node r as curr) -> begin + let next_at_level = Array.unsafe_get r.next level in + match Atomic.get next_at_level with + | Link (Null | Node _) as next -> + let c = t.compare key r.key in + if 0 < c then find_node_rec t key r.next next_at_level level next + else if 0 = c then begin + (* At this point we know the node was not removed, because removal + is done in order of descending levels. *) + if r.incr != Size.used_once then begin + Size.update_once t.size r.incr; + r.incr <- Size.used_once + end; + curr + end + else if 0 < level then + let level = level - 1 in + let prev_at_level = Array.unsafe_get prev level in + find_node_rec t key prev prev_at_level level + (Atomic.get prev_at_level) + else Null + | Link (Mark r) -> + if level = 0 then Size.update_once t.size r.decr; + find_node_rec t key prev prev_at_level level + (let after = Link r.node in + if Atomic.compare_and_set prev_at_level (Link curr) after then + after + else Atomic.get prev_at_level) + end + | Link (Mark _) -> find_node t key + +(* *) + +let create ?(max_height = 10) ~compare () = + (* The upper limit of [30] comes from [Random.bits ()] as well as from + limitations with 32-bit implementations. It should not be a problem in + practice. *) + if max_height < 1 || 30 < max_height then + invalid_arg "Skiplist: max_height must be in the range [1, 30]"; + let root = Array.init max_height @@ fun _ -> Atomic.make (Link Null) in + let size = Size.create () in + { compare; root; size } + +let max_height_of t = Array.length t.root + +(* *) + +let find_opt t key = + match find_node t key with Null -> None | Node r -> Some r.value + +(* *) + +let mem t key = match find_node t key with Null -> false | Node _ -> true + +(* *) + +let rec try_add t key value preds succs = + (not (find_path t key preds succs 0)) + && + let (Node r as node : (_, _, [ `Node ]) node) = + let next = Array.map (fun succ -> Atomic.make (Link succ)) succs in + let incr = Size.new_once t.size Size.incr in + Node { key; value; incr; next } + in + if + let succ = Link (Array.unsafe_get succs 0) in + Atomic.compare_and_set (Array.unsafe_get preds 0) succ (Link node) + then begin + if r.incr != Size.used_once then begin + Size.update_once t.size r.incr; + r.incr <- Size.used_once + end; + (* The node is now considered as added to the skiplist. *) + let rec update_levels level = + if Array.length r.next = level then begin + if is_marked (Atomic.get (Array.unsafe_get r.next (level - 1))) then begin + (* The node we finished adding has been removed concurrently. To + ensure that no references we added to the node remain, we call + [find_node] which will remove nodes with marked references along + the way. *) + find_node t key |> ignore + end; + true + end + else if + let succ = Link (Array.unsafe_get succs level) in + Atomic.compare_and_set (Array.unsafe_get preds level) succ (Link node) + then update_levels (level + 1) + else + let _found = find_path t key preds succs level in + let rec update_nexts level' = + if level' < level then update_levels level + else + let next = Array.unsafe_get r.next level' in + match Atomic.get next with + | Link (Null | Node _) as before -> + let succ = Link (Array.unsafe_get succs level') in + if before != succ then + (* It is possible for a concurrent remove operation to have + marked the link. *) + if Atomic.compare_and_set next before succ then + update_nexts (level' - 1) + else update_levels level + else update_nexts (level' - 1) + | Link (Mark _) -> + (* The node we were trying to add has been removed concurrently. + To ensure that no references we added to the node remain, we + call [find_node] which will remove nodes with marked + references along the way. *) + find_node t key |> ignore; + true + in + update_nexts (Array.length r.next - 1) + in + update_levels 1 + end + else try_add t key value preds succs + +let try_add t key value = + let height = get_random_height (Array.length t.root) in + let preds = + (* Init with [Obj.magic ()] is safe as the array is fully overwritten by + [find_path] called at the start of the recursive [try_add]. *) + Array.make height (Obj.magic ()) + in + let succs = Array.make height Null in + try_add t key value preds succs + +(* *) + +let rec try_remove t key next level link = function + | Link (Mark r) -> + if level = 0 then begin + Size.update_once t.size r.decr; + false + end + else + let level = level - 1 in + let link = Array.unsafe_get next level in + try_remove t key next level link (Atomic.get link) + | Link ((Null | Node _) as succ) -> + let decr = + if level = 0 then Size.new_once t.size Size.decr else Size.used_once + in + let marked_succ = Mark { node = succ; decr } in + if Atomic.compare_and_set link (Link succ) (Link marked_succ) then + if level = 0 then + (* We have finished marking references on the node. To ensure that no + references to the node remain, we call [find_node] which will + remove nodes with marked references along the way. *) + let _node = find_node t key in + true + else + let level = level - 1 in + let link = Array.unsafe_get next level in + try_remove t key next level link (Atomic.get link) + else try_remove t key next level link (Atomic.get link) + +let try_remove t key = + match find_node t key with + | Null -> false + | Node { next; _ } -> + let level = Array.length next - 1 in + let link = Array.unsafe_get next level in + try_remove t key next level link (Atomic.get link) + +(* *) + +let length t = Size.get t.size diff --git a/src_lockfree/skiplist.mli b/src_lockfree/skiplist.mli new file mode 100644 index 00000000..174b781a --- /dev/null +++ b/src_lockfree/skiplist.mli @@ -0,0 +1,42 @@ +(** A lock-free skiplist. *) + +type (!'k, !'v) t +(** The type of a lock-free skiplist containing bindings of keys of type ['k] to + values of type ['v]. *) + +val create : ?max_height:int -> compare:('k -> 'k -> int) -> unit -> ('k, 'v) t +(** [create ~compare ()] creates a new empty skiplist where keys are ordered + based on the given [compare] function. + + Note that the polymorphic [Stdlib.compare] function has relatively high + overhead and it is usually better to use a type specific [compare] function + such as [Int.compare] or [String.compare]. + + The optional [max_height] argument determines the maximum height of nodes in + the skiplist and directly affects the performance of the skiplist. The + current implementation does not adjust height automatically. *) + +val max_height_of : ('k, 'v) t -> int +(** [max_height_of s] returns the maximum height of nodes of the skiplist [s] as + specified to {!create}. *) + +val find_opt : ('k, 'v) t -> 'k -> 'v option +(** [find_opt s k] tries to find a binding of [k] to [v] from the skiplist [s] + and returns [Some v] in case such a binding was found or return [None] in + case no such binding was found. *) + +val mem : ('k, 'v) t -> 'k -> bool +(** [mem s k] determines whether the skiplist [s] contained a binding of [k]. *) + +val try_add : ('k, 'v) t -> 'k -> 'v -> bool +(** [try_add s k v] tries to add a new binding of [k] to [v] into the skiplist + [s] and returns [true] on success. Otherwise the skiplist already contained + a binding of [k] and [false] is returned. *) + +val try_remove : ('k, 'v) t -> 'k -> bool +(** [try_remove s k] tries to remove a binding of [k] from the skiplist and + returns [true] on success. Otherwise the skiplist did not contain a binding + of [k] and [false] is returned. *) + +val length : ('k, 'v) t -> int +(** [length s] computes the number of bindings in the skiplist [s]. *) diff --git a/test/skiplist/dscheck_skiplist.ml b/test/skiplist/dscheck_skiplist.ml new file mode 100644 index 00000000..3ae8d4a5 --- /dev/null +++ b/test/skiplist/dscheck_skiplist.ml @@ -0,0 +1,108 @@ +open Skiplist + +let test_max_height_of () = + let s = create ~max_height:1 ~compare () in + assert (max_height_of s = 1); + let s = create ~max_height:10 ~compare () in + assert (max_height_of s = 10); + let s = create ~max_height:30 ~compare () in + assert (max_height_of s = 30) + +let try_add s k = try_add s k () + +let _two_mem () = + Atomic.trace (fun () -> + Random.init 0; + let sl = create ~max_height:2 ~compare:Int.compare () in + let added1 = ref false in + let found1 = ref false in + let found2 = ref false in + + Atomic.spawn (fun () -> + added1 := try_add sl 1; + found1 := mem sl 1); + + Atomic.spawn (fun () -> found2 := mem sl 2); + + Atomic.final (fun () -> + Atomic.check (fun () -> !added1 && !found1 && not !found2))) + +let _two_add () = + Atomic.trace (fun () -> + Random.init 0; + let sl = create ~max_height:3 ~compare:Int.compare () in + let added1 = ref false in + let added2 = ref false in + + Atomic.spawn (fun () -> added1 := try_add sl 1); + Atomic.spawn (fun () -> added2 := try_add sl 2); + + Atomic.final (fun () -> + Atomic.check (fun () -> !added1 && !added2 && mem sl 1 && mem sl 2))) + +let _two_add_same () = + Atomic.trace (fun () -> + Random.init 0; + let sl = create ~max_height:3 ~compare:Int.compare () in + let added1 = ref false in + let added2 = ref false in + + Atomic.spawn (fun () -> added1 := try_add sl 1); + Atomic.spawn (fun () -> added2 := try_add sl 1); + + Atomic.final (fun () -> + Atomic.check (fun () -> + (!added1 && not !added2) + || (((not !added1) && !added2) && mem sl 1)))) + +let _two_remove_same () = + Atomic.trace (fun () -> + Random.init 0; + let sl = create ~max_height:2 ~compare:Int.compare () in + let added1 = ref false in + let removed1 = ref false in + let removed2 = ref false in + + Atomic.spawn (fun () -> + added1 := try_add sl 1; + removed1 := try_remove sl 1); + Atomic.spawn (fun () -> removed2 := try_remove sl 1); + + Atomic.final (fun () -> + Atomic.check (fun () -> + !added1 + && ((!removed1 && not !removed2) || ((not !removed1) && !removed2)) + && not (mem sl 1)))) + +let _two_remove () = + Atomic.trace (fun () -> + Random.init 0; + let sl = create ~max_height:2 ~compare:Int.compare () in + let added1 = ref false in + let removed1 = ref false in + let removed2 = ref false in + + Atomic.spawn (fun () -> + added1 := try_add sl 1; + removed1 := try_remove sl 1); + Atomic.spawn (fun () -> removed2 := try_remove sl 2); + + Atomic.final (fun () -> + Atomic.check (fun () -> + let found1 = mem sl 1 in + !added1 && !removed1 && (not !removed2) && not found1))) + +let () = + let open Alcotest in + run "DSCheck Skiplist" + [ + ( "basic", + [ + test_case "max_height_of" `Quick test_max_height_of; + test_case "2-mem" `Slow _two_mem; + test_case "2-add-same" `Slow _two_add_same; + test_case "2-add" `Slow _two_add; + test_case "2-remove-same" `Slow _two_remove_same; + test_case "2-remove" `Slow _two_remove; + ] ); + ] diff --git a/test/skiplist/dune b/test/skiplist/dune new file mode 100644 index 00000000..4fd90697 --- /dev/null +++ b/test/skiplist/dune @@ -0,0 +1,32 @@ +(rule + (action + (progn + (copy ../../src_lockfree/skiplist.ml skiplist.ml) + (copy ../../src_lockfree/size.ml size.ml))) + (package saturn_lockfree)) + +(test + (package saturn_lockfree) + (name dscheck_skiplist) + (modules skiplist size dscheck_skiplist) + (libraries atomic transparent_atomic dscheck alcotest multicore-magic)) + +(test + (package saturn_lockfree) + (name qcheck_skiplist) + (modules qcheck_skiplist) + (libraries saturn qcheck qcheck-core qcheck-alcotest alcotest)) + +(test + (package saturn_lockfree) + (name stm_skiplist) + (modules stm_skiplist) + (libraries + saturn + qcheck-core + qcheck-core.runner + qcheck-stm.stm + qcheck-stm.sequential + qcheck-stm.domain) + (action + (run %{test} --verbose))) diff --git a/test/skiplist/qcheck_skiplist.ml b/test/skiplist/qcheck_skiplist.ml new file mode 100644 index 00000000..dce2d6bb --- /dev/null +++ b/test/skiplist/qcheck_skiplist.ml @@ -0,0 +1,195 @@ +module Skiplist = struct + include Saturn.Skiplist + + let try_add s k = try_add s k () +end + +module IntSet = Set.Make (Int) + +let[@tail_mod_cons] rec uniq ?(seen = IntSet.empty) = function + | [] -> [] + | x :: xs -> + if IntSet.mem x seen then uniq ~seen xs + else x :: uniq ~seen:(IntSet.add x seen) xs + +let tests_sequential = + QCheck. + [ + (* TEST 1: add*) + Test.make ~name:"add" (list int) (fun lpush -> + let sl = Skiplist.create ~compare:Int.compare () in + let rec add_all_elems seen l = + match l with + | h :: t -> + if Skiplist.try_add sl h <> IntSet.mem h seen then + add_all_elems (IntSet.add h seen) t + else false + | [] -> true + in + add_all_elems IntSet.empty lpush); + (*TEST 2: add_remove*) + Test.make ~name:"add_remove" (list int) (fun lpush -> + let lpush = uniq lpush in + let sl = Skiplist.create ~compare:Int.compare () in + List.iter (fun key -> ignore (Skiplist.try_add sl key)) lpush; + let rec remove_all_elems l = + match l with + | h :: t -> + if Skiplist.try_remove sl h then remove_all_elems t else false + | [] -> true + in + remove_all_elems lpush); + (*TEST 3: add_find*) + Test.make ~name:"add_find" (list int) (fun lpush -> + let lpush = uniq lpush in + let lpush = Array.of_list lpush in + let sl = Skiplist.create ~compare:Int.compare () in + let len = Array.length lpush in + let pos = Array.sub lpush 0 (len / 2) in + let neg = Array.sub lpush (len / 2) (len / 2) in + Array.iter (fun key -> ignore @@ Skiplist.try_add sl key) pos; + let rec check_pos index = + if index < len / 2 then + if Skiplist.mem sl pos.(index) then check_pos (index + 1) + else false + else true + in + let rec check_neg index = + if index < len / 2 then + if not @@ Skiplist.mem sl neg.(index) then check_neg (index + 1) + else false + else true + in + check_pos 0 && check_neg 0); + (* TEST 4: add_remove_find *) + Test.make ~name:"add_remove_find" (list int) (fun lpush -> + let lpush = uniq lpush in + let sl = Skiplist.create ~compare:Int.compare () in + List.iter (fun key -> ignore @@ Skiplist.try_add sl key) lpush; + List.iter (fun key -> ignore @@ Skiplist.try_remove sl key) lpush; + let rec not_find_all_elems l = + match l with + | h :: t -> + if not @@ Skiplist.mem sl h then not_find_all_elems t else false + | [] -> true + in + + not_find_all_elems lpush); + ] + +let tests_two_domains = + QCheck. + [ + (* TEST 1: Two domains doing multiple adds *) + Test.make ~name:"parallel_add" (pair small_nat small_nat) + (fun (npush1, npush2) -> + let sl = Skiplist.create ~compare:Int.compare () in + let sema = Semaphore.Binary.make false in + let lpush1 = List.init npush1 (fun i -> i) in + let lpush2 = List.init npush2 (fun i -> i + npush1) in + let work lpush = + List.map + (fun elt -> + let completed = Skiplist.try_add sl elt in + Domain.cpu_relax (); + completed) + lpush + in + + let domain1 = + Domain.spawn (fun () -> + Semaphore.Binary.release sema; + work lpush1) + in + let popped2 = + while not (Semaphore.Binary.try_acquire sema) do + Domain.cpu_relax () + done; + work lpush2 + in + let popped1 = Domain.join domain1 in + let rec compare_all_true l = + match l with + | true :: t -> compare_all_true t + | false :: _ -> false + | [] -> true + in + compare_all_true popped1 && compare_all_true popped2); + (* TEST 2: Two domains doing multiple one push and one pop in parallel *) + Test.make ~count:10000 ~name:"parallel_add_remove" + (pair small_nat small_nat) (fun (npush1, npush2) -> + let sl = Skiplist.create ~compare:Int.compare () in + let sema = Semaphore.Binary.make false in + + let lpush1 = List.init npush1 (fun i -> i) in + let lpush2 = List.init npush2 (fun i -> i + npush1) in + + let work lpush = + List.map + (fun elt -> + ignore @@ Skiplist.try_add sl elt; + Domain.cpu_relax (); + Skiplist.try_remove sl elt) + lpush + in + + let domain1 = + Domain.spawn (fun () -> + Semaphore.Binary.release sema; + work lpush1) + in + let _ = + while not (Semaphore.Binary.try_acquire sema) do + Domain.cpu_relax () + done; + work lpush2 + in + let _ = Domain.join domain1 in + + let rec check_none_present l = + match l with + | h :: t -> + if Skiplist.mem sl h then false else check_none_present t + | [] -> true + in + check_none_present lpush1 && check_none_present lpush2); + (* TEST 3: Parallel push and pop using the same elements in two domains *) + Test.make ~name:"parallel_add_remove_same_list" (list int) (fun lpush -> + let sl = Skiplist.create ~compare:Int.compare () in + let sema = Semaphore.Binary.make false in + let add_all_elems l = List.map (Skiplist.try_add sl) l in + let remove_all_elems l = List.map (Skiplist.try_remove sl) l in + + let domain1 = + Domain.spawn (fun () -> + Semaphore.Binary.release sema; + Domain.cpu_relax (); + let add1 = add_all_elems lpush in + let remove1 = remove_all_elems lpush in + (add1, remove1)) + in + let _, _ = + while not (Semaphore.Binary.try_acquire sema) do + Domain.cpu_relax () + done; + let add2 = add_all_elems lpush in + let remove2 = remove_all_elems lpush in + (add2, remove2) + in + let _, _ = Domain.join domain1 in + let rec check_none_present l = + match l with + | h :: t -> + if Skiplist.mem sl h then false else check_none_present t + | [] -> true + in + check_none_present lpush); + ] + +let () = + let to_alcotest = List.map QCheck_alcotest.to_alcotest in + Alcotest.run "QCheck Skiplist" + [ + ("test_sequential", to_alcotest tests_sequential); + ("tests_two_domains", to_alcotest tests_two_domains); + ] diff --git a/test/skiplist/stm_skiplist.ml b/test/skiplist/stm_skiplist.ml new file mode 100644 index 00000000..0aab1c47 --- /dev/null +++ b/test/skiplist/stm_skiplist.ml @@ -0,0 +1,76 @@ +open QCheck +open STM + +module Skiplist = struct + include Saturn.Skiplist + + type nonrec 'a t = ('a, unit) t + + let try_add s k = try_add s k () +end + +module WSDConf = struct + type cmd = Mem of int | Add of int | Remove of int | Length + + let show_cmd c = + match c with + | Mem i -> "Mem " ^ string_of_int i + | Add i -> "Add " ^ string_of_int i + | Remove i -> "Remove " ^ string_of_int i + | Length -> "Length" + + module Sint = Set.Make (Int) + + type state = Sint.t + type sut = int Skiplist.t + + let arb_cmd _s = + let int_gen = Gen.nat in + QCheck.make ~print:show_cmd + (Gen.oneof + [ + Gen.map (fun i -> Add i) int_gen; + Gen.map (fun i -> Mem i) int_gen; + Gen.map (fun i -> Remove i) int_gen; + Gen.return Length; + ]) + + let init_state = Sint.empty + let init_sut () = Skiplist.create ~compare:Int.compare () + let cleanup _ = () + + let next_state c s = + match c with + | Add i -> Sint.add i s + | Remove i -> Sint.remove i s + | Mem _ -> s + | Length -> s + + let precond _ _ = true + + let run c d = + match c with + | Add i -> Res (bool, Skiplist.try_add d i) + | Remove i -> Res (bool, Skiplist.try_remove d i) + | Mem i -> Res (bool, Skiplist.mem d i) + | Length -> Res (int, Skiplist.length d) + + let postcond c (s : state) res = + match (c, res) with + | Add i, Res ((Bool, _), res) -> Sint.mem i s = not res + | Remove i, Res ((Bool, _), res) -> Sint.mem i s = res + | Mem i, Res ((Bool, _), res) -> Sint.mem i s = res + | Length, Res ((Int, _), res) -> Sint.cardinal s = res + | _, _ -> false +end + +module WSDT_seq = STM_sequential.Make (WSDConf) +module WSDT_dom = STM_domain.Make (WSDConf) + +let () = + let count = 1000 in + QCheck_base_runner.run_tests_main + [ + WSDT_seq.agree_test ~count ~name:"STM Lockfree.Skiplist test sequential"; + WSDT_dom.agree_test_par ~count ~name:"STM Lockfree.Skiplist test parallel"; + ]