Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
bclement-ocp committed Aug 29, 2024
1 parent 1111fb0 commit b00d6c0
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 79 deletions.
250 changes: 175 additions & 75 deletions src/lib/reasoners/bitv_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -474,18 +474,22 @@ module Interner = struct
if ofs >= bitwidth r then
Shostak.Polynome.Ints.zero
else
let table =
try H.find t.bvasr r
match X.term_extract r with
| Some t, true when ofs = 0 ->
Shostak.Arith.embed (X.term_embed (E.BV.bv2nat t))
| _ ->
let table =
try H.find t.bvasr r
with Not_found ->
let table = Hashtbl.create 3 in
H.replace t.bvasr r table;
table
in
try Hashtbl.find table ofs
with Not_found ->
let table = Hashtbl.create 3 in
H.replace t.bvasr r table;
table
in
try Hashtbl.find table ofs
with Not_found ->
let p = Shostak.Arith.embed (X.term_embed (E.fresh_name Tint)) in
Hashtbl.replace table ofs p;
p
let p = Shostak.Arith.embed (X.term_embed (E.fresh_name Tint)) in
Hashtbl.replace table ofs p;
p

let find_bv2poly t r =
try EE.find t.bv2poly r
Expand Down Expand Up @@ -703,10 +707,11 @@ module BV2Nat = struct
(* Returns the polynomial associated with [bv2nat(bv asr ofs)], creating a
fresh variable for it if it does not exist. *)
let find_or_create_asr bv ofs t =
let ext = Extraction.shift_right ~size:(bitwidth bv) ofs in
let sz = bitwidth bv in
let ext = Extraction.shift_right ~size:sz ofs in
try find_ext bv ext t.bv2nat, t
with Not_found ->
if ofs >= bitwidth bv then
if ofs >= sz then
(T.Ints.zero, Ex.empty), t
else
let k = Interner.find_asr global_interner bv ofs in
Expand Down Expand Up @@ -2014,6 +2019,7 @@ type state =
; bitlist_variables : SX.t MX.t
; interval_changed : unit HX.t
; bitlist_changed : unit HX.t
; mutable steps : int
}

type 'a hashable = (module Hashtbl.HashedType with type t = 'a)
Expand All @@ -2033,6 +2039,9 @@ module Any_propagator : sig
val push : t -> 'a propagator -> 'a -> unit

val run1 : state -> t -> bool

val filter_fold :
'a propagator -> ('a -> 'b -> 'b) -> t -> 'b -> 'b
end
end = struct
type 'a t =
Expand Down Expand Up @@ -2076,6 +2085,13 @@ end = struct
match Q.pop q with
| B ({ run ; _ }, p) -> run st p; true
| exception Q.Empty -> false

let filter_fold (type a) (k : a propagator) (f : a -> _ -> _) q acc =
Q.fold (fun (B (k', p)) acc ->
match Compat.Type.Id.provably_equal k.id k'.id with
| Some Equal -> f (p : a) acc
| None -> acc
) q acc
end
end

Expand Down Expand Up @@ -2410,7 +2426,8 @@ let state uf idom bdom =
; bitlists = bdom
; bitlists_uf = uf_bdom
; bitlist_variables = bvars
; bitlist_changed = HX.create 17 }
; bitlist_changed = HX.create 17
; steps = 0}

module Schedule = struct
(* Schedule specifications *)
Expand All @@ -2423,23 +2440,32 @@ module Schedule = struct
(** Repeat the given schedule to completion. *)

(* Returns [false] if no propagation was performed. *)
let rec run state schedule =
let rec run' state schedule =
if state.steps >= 10_000 then
raise Exit;
match schedule with
| Single queue ->
Any_propagator.Queue.run1 state queue
if Any_propagator.Queue.run1 state queue then (
state.steps <- state.steps + 1; true
) else
false
| Sequence schedules ->
Array.fold_left (fun did_run schedule ->
run state schedule || did_run
run' state schedule || did_run
) false schedules
| Repeat schedule ->
run_repeatedly state schedule false

and run_repeatedly state schedule did_run =
if run state schedule then
if run' state schedule then
run_repeatedly state schedule true
else
did_run

let run state schedule =
try ignore (run' state schedule); `Stable
with Exit -> `Unstable

let queue q = Single q

let sequence scheds = Sequence scheds
Expand All @@ -2457,13 +2483,15 @@ let rec propagate_all uf eqs bdom idom =

(* We run all propagations over a single domain to completion, then run
the consistency propagators to perform cross-domain propagations. *)
ignore Schedule.(
let outcome =
Schedule.(
run state @@ repeat @@
sequence
[| repeat @@ queue queues.propagation_queue
; repeat @@ queue queues.consistency_queue
|]
);
)
in

let bdom = state.bitlists in
let idom = state.intervals in
Expand All @@ -2483,7 +2511,23 @@ let rec propagate_all uf eqs bdom idom =
in

(* Propagate again in case constraints were simplified. *)
propagate_all uf eqs bdom idom
match outcome with
| `Stable -> propagate_all uf eqs bdom idom
| `Unstable ->
let bdom =
Any_propagator.Queue.filter_fold bitlist_constraint_propagator
(fun c bdom ->
Bitlist_domains.trigger c bdom
) queues.propagation_queue bdom
in
let idom =
Any_propagator.Queue.filter_fold interval_constraint_propagator
(fun c idom ->
Interval_domains.trigger c idom
) queues.propagation_queue idom
in
(* TODO: mark remaining propagations!! *)
eqs, bdom, idom
else
eqs, bdom, idom

Expand All @@ -2501,18 +2545,54 @@ let empty uf =
Uf.GlobalDomains.add (module Interval_domains) Interval_domains.empty @@
Uf.domains uf

let find_poly uf env bp =
match IntPolynomial.as_var bp with
| Some v -> v, Ex.empty
| None -> Uf.find uf (IntPolynomial.Map.find bp env.poly)

let poly2nat uf bvconv bp =
IntPolynomial.fold bp
~constant:(fun const ->
(Shostak.Polynome.Ints.(of_bigint const), Ex.empty, bvconv))
~coeff:(fun mon k (acc, ex, bvconv) ->
let l, ex, bvconv =
IntPolynomial.Mon.fold
(fun r n (acc, ex, bvconv) ->
let r, ex' = Uf.find_r uf r in
let ex = Ex.union ex ex' in
let (rp, ex'), bvconv = BV2Nat.find_bv2nat r bvconv in
let ex = Ex.union ex ex' in
let rr = Shostak.Arith.is_mine rp in
((rr, n) :: acc, ex, bvconv))
mon ([], ex, bvconv)
in
let l = Shostak.Ac.compact l in
let acc =
Shostak.Polynome.add acc @@
match l with
| [] -> assert false
| [ x, 1 ] ->
Shostak.Polynome.Ints.(Shostak.Arith.embed x *$$ k)
| l ->
let x =
X.color { h = Op Mult ; t = Tint ; l ; distribute = true }
in
Shostak.Polynome.Ints.(Shostak.Arith.embed x *$$ k)
in
(acc, ex, bvconv))

let assume env uf la =
let ds = Uf.domains uf in
let bvconv = Uf.GlobalDomains.find (module BV2Nat) ds in
let domain = Uf.GlobalDomains.find (module Bitlist_domains) ds in
let int_domain =
Uf.GlobalDomains.find (module Interval_domains) ds
in
let (domain, int_domain, eqs, size_splits) =
let (domain, int_domain, bvconv, eqs, size_splits) =
try
let (domain, int_domain, eqs, size_splits) =
let (domain, int_domain, bvconv, eqs, size_splits) =
List.fold_left
(fun (domain, int_domain, eqs, ss) (a, _root, ex, orig) ->
(fun (domain, int_domain, bvconv, eqs, ss) (a, _root, ex, orig) ->
let ss =
match orig with
| Th_util.CS (Th_bitv, n) -> Q.(ss * n)
Expand All @@ -2528,20 +2608,74 @@ let assume env uf la =
let x, exx = Uf.find_r uf x in
let y, exy = Uf.find_r uf y in
let ex = Ex.union ex @@ Ex.union exx exy in
let c =
if is_true then
Constraint.bvule x y
else
Constraint.bvugt x y

let (x2nat, exx'), bvconv = BV2Nat.find_bv2nat x bvconv in
let (y2nat, exy'), bvconv = BV2Nat.find_bv2nat y bvconv in
let ex = Ex.union ex @@ Ex.union exx' exy' in

let xcl = Uf.rclass_of uf x in
let ycl = Uf.rclass_of uf y in

let sz = bitwidth x in

let eqs, bvconv =
Expr.Set.fold (fun xt ((eqs, bvconv) as acc) ->
let bp = bv2poly uf xt in
let px, exp, bvconv = poly2nat uf bvconv bp in
if Shostak.Polynome.equal px x2nat then acc
else
let k =
BV2Nat.T.Ints.of_repr @@
Interner.find_bv2poly global_interner xt
in
let lit =
BV2Nat.T.Ints.mkv_eq x2nat
BV2Nat.T.Ints.(px + k *$$ Z.(one lsl sz))
in
(lit, exp) :: eqs, bvconv) xcl (eqs, bvconv)
in

let eqs, bvconv =
Expr.Set.fold (fun yt ((eqs, bvconv) as acc) ->
let bp = bv2poly uf yt in
let py, eyp, bvconv = poly2nat uf bvconv bp in
if Shostak.Polynome.equal py y2nat then acc
else
let k =
BV2Nat.T.Ints.of_repr @@
Interner.find_bv2poly global_interner yt
in
let lit =
BV2Nat.T.Ints.mkv_eq y2nat
BV2Nat.T.Ints.(py + k *$$ Z.(one lsl sz))
in
(lit, eyp) :: eqs, bvconv) ycl (eqs, bvconv)
in

let lit =
Uf.LX.mkv_builtin is_true LE
BV2Nat.T.Ints.[to_repr x2nat; to_repr y2nat]
in
(* Only watch comparisons on the interval domain since we don't
propagate them in the bitlist domain. . *)

let eqs = (lit, ex) :: eqs in

let int_domain =
Interval_domains.watch (explained ~ex c) x @@
Interval_domains.watch (explained ~ex c) y @@
int_domain
if false then
let c =
if is_true then
Constraint.bvule x y
else
Constraint.bvugt x y
in
(* Only watch comparisons on the interval domain since we
don't
propagate them in the bitlist domain. . *)
Interval_domains.watch (explained ~ex c) x @@
Interval_domains.watch (explained ~ex c) y @@
int_domain
else int_domain
in
(domain, int_domain, eqs, ss)
(domain, int_domain, bvconv, eqs, ss)
| L.Distinct (false, [rr; nrr]), _ when is_1bit rr ->
(* We don't (yet) support [distinct] in general, but we must
support it for case splits to avoid looping.
Expand All @@ -2551,10 +2685,11 @@ let assume env uf la =
let not_nrr =
Shostak.Bitv.is_mine (Bitv.lognot (Shostak.Bitv.embed nrr))
in
(domain, int_domain, (Uf.LX.mkv_eq rr not_nrr, ex) :: eqs, ss)
| _ -> (domain, int_domain, eqs, ss)
(domain, int_domain, bvconv,
(Uf.LX.mkv_eq rr not_nrr, ex) :: eqs, ss)
| _ -> (domain, int_domain, bvconv, eqs, ss)
)
(domain, int_domain, [], env.size_splits)
(domain, int_domain, bvconv, [], env.size_splits)
la
in
let eqs, domain, int_domain = propagate_all uf eqs domain int_domain in
Expand All @@ -2566,7 +2701,7 @@ let assume env uf la =
~module_name:"Bitv_rel" ~function_name:"assume"
"interval domain: @[%a@]" Interval_domains.pp int_domain;
);
(domain, int_domain, eqs, size_splits)
(domain, int_domain, bvconv, eqs, size_splits)
with Bitlist.Inconsistent ex | Interval_domain.Inconsistent ex ->
raise @@ Ex.Inconsistent (ex, Uf.cl_extract uf)
in
Expand Down Expand Up @@ -2646,11 +2781,6 @@ let case_split env uf ~for_model =
[ Uf.LX.mkv_eq lhs zero, true, Th_util.CS (Th_util.Th_bitv, Q.of_int 2) ]
| exception Not_found -> []

let find_poly uf env bp =
match IntPolynomial.as_var bp with
| Some v -> v, Ex.empty
| None -> Uf.find uf (IntPolynomial.Map.find bp env.poly)

let add env uf r t =
let ds = Uf.domains uf in
match X.type_info r with
Expand All @@ -2661,37 +2791,7 @@ let add env uf r t =
let rx, ex = Uf.find uf x in
let bvconv = BV2Nat.add_bv2nat ~ex r rx bvconv in
let bp = bv2poly uf x in
let px, exp, bvconv =
IntPolynomial.fold bp
~constant:(fun const ->
(Shostak.Polynome.Ints.(of_bigint const), Ex.empty, bvconv))
~coeff:(fun mon k (acc, ex, bvconv) ->
let l, ex, bvconv =
IntPolynomial.Mon.fold
(fun r n (acc, ex, bvconv) ->
let r, ex' = Uf.find_r uf r in
let ex = Ex.union ex ex' in
let (rp, ex'), bvconv = BV2Nat.find_bv2nat r bvconv in
let ex = Ex.union ex ex' in
let rr = Shostak.Arith.is_mine rp in
((rr, n) :: acc, ex, bvconv))
mon ([], ex, bvconv)
in
let l = Shostak.Ac.compact l in
let acc =
Shostak.Polynome.add acc @@
match l with
| [] -> assert false
| [ x, 1 ] ->
Shostak.Polynome.Ints.(Shostak.Arith.embed x *$$ k)
| l ->
let x =
X.color { h = Op Mult ; t = Tint ; l ; distribute = true }
in
Shostak.Polynome.Ints.(Shostak.Arith.embed x *$$ k)
in
(acc, ex, bvconv))
in
let px, exp, bvconv = poly2nat uf bvconv bp in
let sz = bitwidth rx in
let eqs, bvconv = BV2Nat.flush bvconv in
let eqs =
Expand Down
Loading

0 comments on commit b00d6c0

Please sign in to comment.