diff --git a/src/lib/reasoners/bitv_rel.ml b/src/lib/reasoners/bitv_rel.ml index ec5f0e152..a666a4ad8 100644 --- a/src/lib/reasoners/bitv_rel.ml +++ b/src/lib/reasoners/bitv_rel.ml @@ -103,6 +103,402 @@ module BitvNormalForm = struct in loop Z.zero [] (Shostak.Bitv.embed r) end +module Product = struct + module type S = sig + include Domains_intf.ComparableType + + val mul : t -> t -> t + (** Product multiplication. *) + + type var + (** The type of variables. A value of type [t] represents a product of + variables with type [var] associated with a nonnegative power. *) + + val var : var -> t + (** Injects a variable into a product. *) + + val as_var : t -> var option + (** If this product is equal to a single variable, return that variable. *) + + val fold : (var -> int -> 'a -> 'a) -> t -> 'a -> 'a + (** Fold over all the variables in the product and the corresponding + exponent. *) + end + + module Make(X : Domains_intf.ComparableType) : S with type var = X.t = struct + type var = X.t + + type t = { variables : int X.Map.t } + + let pp = + Fmt.using (fun { variables } -> variables) @@ + Fmt.iter_bindings ~sep:(Fmt.any "*") X.Map.iter @@ + fun ppf (x, k) -> + if k = 1 then X.pp ppf x + else Fmt.pf ppf "@[%a^%d@]" X.pp x k + + let equal p1 p2 = + X.Map.equal Int.equal p1.variables p2.variables + + let compare p1 p2 = + X.Map.compare Int.compare p1.variables p2.variables + + let hash p = + X.Map.fold (fun x k acc -> + 23 * acc + X.hash x * Hashtbl.hash (k : int) + ) p.variables 0 + + let var x = { variables = X.Map.singleton x 1 } + + let mul p1 p2 = + let variables = + X.Map.union + (fun _ k1 k2 -> Some (k1 + k2)) + p1.variables p2.variables + in + { variables } + + let as_var { variables } = + try + let r = ref None in + X.Map.iter (fun x k -> + match !r with + | None when k = 1 -> r := Some x + | _ -> raise Exit + ) variables; + !r + with Exit -> None + + let fold f { variables } acc = + X.Map.fold f variables acc + + module Set = Set.Make(struct type nonrec t = t let compare = compare end) + module Map = Map.Make(struct type nonrec t = t let compare = compare end) + module Table = Hashtbl.Make(struct + type nonrec t = t + let equal = equal + let hash = hash + end) + end +end + +module type Int_like = sig + type t + + val pp_print : t Fmt.t + + val equal : t -> t -> bool + + val compare : t -> t -> int + + val hash : t -> int + + val zero : t + + val one : t + + val add : t -> t -> t + + val sub : t -> t -> t + + val neg : t -> t + + val mul : t -> t -> t +end + +module Affine = struct + module type S = sig + include Domains_intf.ComparableType + + type var + + type constant + + val const : constant -> t + + val var : var -> t + + val as_var : t -> var option + + val ( ~- ) : t -> t + + val ( + ) : t -> t -> t + + val ( - ) : t -> t -> t + + val ( *$$ ) : t -> constant -> t + + val map_constant : (constant -> constant) -> t -> t + + type monomial + + val fold : + constant:(constant -> 'a) -> + coeff:(monomial -> constant -> 'a -> 'a) -> + t -> 'a + end + + module Make(K : Int_like)(X : Domains_intf.ComparableType) + : S with type var = X.t and type monomial = X.t and type constant = K.t = + struct + type var = X.t + + type monomial = X.t + + type constant = K.t + + type t = { const : K.t ; coeffs : K.t X.Map.t } + + let pp ppf { const ; coeffs } = + if not (K.equal const K.zero) || X.Map.is_empty coeffs then ( + K.pp_print ppf const; + if not (X.Map.is_empty coeffs) then + Fmt.pf ppf " +@ "; + ); + Fmt.pf ppf "%a" ( + Fmt.iter_bindings ~sep:(Fmt.any " +@ ") X.Map.iter @@ + fun ppf (x, k) -> + if K.equal k K.one then X.pp ppf x + else Fmt.pf ppf "@[%a*%a@]" K.pp_print k X.pp x + ) coeffs + + let equal a1 a2 = + K.equal a1.const a2.const && + X.Map.equal K.equal a1.coeffs a2.coeffs + + let compare a1 a2 = + let c = X.Map.compare K.compare a1.coeffs a2.coeffs in + if c <> 0 then c else K.compare a1.const a2.const + + let hash a = + X.Map.fold (fun x k acc -> + 23 * acc + X.hash x * K.hash k + ) a.coeffs (19 * K.hash a.const) + + let const const = { const ; coeffs = X.Map.empty } + + (* 0 -> None /!\ important for normalization *) + let simplify k = + if K.equal k K.zero then None else Some k + + let merge op a1 a2 = + let coeffs = + X.Map.merge (fun _ k1 k2 -> + match k1, k2 with + | None, None -> assert false + | Some k, None -> simplify (op k K.zero) + | None, Some k -> simplify (op K.zero k) + | Some k1, Some k2 -> simplify (op k1 k2) + ) a1.coeffs a2.coeffs + in + { const = op a1.const a2.const + ; coeffs } + + let add a1 a2 = merge K.add a1 a2 + + let ( + ) = add + + let sub a1 a2 = merge K.sub a1 a2 + + let ( - ) = sub + + let map op a = + { const = op a.const + ; coeffs = X.Map.filter_map (fun _ k -> simplify (op k)) a.coeffs } + + let neg a = map K.neg a + + let (~-) = neg + + let scale k a = map (K.mul k) a + + let ( *$$ ) a k = scale k a + + let var x = { const = K.zero ; coeffs = X.Map.singleton x K.one } + + let as_var a = + if K.equal a.const K.zero then + try + let r = ref None in + X.Map.iter (fun x k -> + match !r with + | None when K.equal k K.one -> r := Some x + | _ -> raise Exit + ) a.coeffs; + !r + with Exit -> None + else + None + + let map_constant fn a = + { const = fn a.const + ; coeffs = X.Map.filter_map (fun _ k -> simplify (fn k)) a.coeffs } + + let fold ~constant ~coeff { const ; coeffs } = + X.Map.fold coeff coeffs (constant const) + + module Set = Set.Make(struct type nonrec t = t let compare = compare end) + module Map = Map.Make(struct type nonrec t = t let compare = compare end) + module Table = Hashtbl.Make(struct + type nonrec t = t + let equal = equal + let hash = hash + end) + end +end + +module Polynomial = struct + module type S = sig + include Affine.S + + module Mon : Product.S with type var = var and type t = monomial + + val ( * ) : t -> t -> t + end + + module Make(K : Int_like)(X : Domains_intf.ComparableType) + : S with type var = X.t and type constant = K.t = + struct + module Mon = Product.Make(X) + module Aff = Affine.Make(K)(Mon) + + include Aff + + type var = X.t + + let var x = Aff.var (Mon.var x) + + let as_var p = Option.bind (Aff.as_var p) Mon.as_var + + let mul a1 a2 = + Aff.fold a2 + ~constant:Aff.(fun c2 -> a1 *$$ c2) + ~coeff:(fun x2 k2 acc -> + Aff.fold a1 + ~constant:Aff.(fun c1 -> acc + var x2 *$$ K.mul c1 k2) + ~coeff:Aff.(fun x1 k1 acc -> + acc + var (Mon.mul x1 x2) *$$ K.mul k1 k2)) + + let ( * ) = mul + end +end + +module IntPolynomial = Polynomial.Make(Z)(Rel_utils.XComparable) + +(* Convert a bit-vector term to a polynomial. + + This is essentially a bit-vector version of [Arith.make]. *) +let rec bv2poly uf t = + match E.term_view t with + | { f = Bitv (_, n) ; xs = []; _ } -> + IntPolynomial.const n + | { f = Op BVadd ; xs = [ x ; y ] ; _ } -> + IntPolynomial.(bv2poly uf x + bv2poly uf y) + | { f = Op BVsub ; xs = [ x ; y ] ; _ } -> + IntPolynomial.(bv2poly uf x - bv2poly uf y) + | { f = Op BVmul ; xs = [ x ; y ] ; _ } -> + IntPolynomial.(bv2poly uf x * bv2poly uf y) + | { f = Op BVnot ; xs = [ x ] ; _ } -> + IntPolynomial.(-bv2poly uf x - const Z.one) + | { f = Op Concat ; xs = [ x ; y ] ; _ } -> + let sz = match Expr.type_info y with Tbitv n -> n | _ -> assert false in + IntPolynomial.(bv2poly uf x *$$ Z.(one lsl sz) + bv2poly uf y) + | { f = Op (Repeat n) ; xs = [ x ] ; _ } -> + assert (n > 0); + (* (repeat 1) x = x * 1 + (repeat 2) x = x * (2^sz + 1) + (repeat 3) x = x * (2^(2sz) + 2^sz + 1) + ... + *) + let sz = match Expr.type_info x with Tbitv n -> n | _ -> assert false in + let rec loop n acc = + if n = 1 then acc + else loop (n - 1) Z.(acc lsl sz lor one) + in + IntPolynomial.(bv2poly uf x *$$ loop n Z.one) + | { f = Op BVshl ; xs = [ x ; y ] ; _ } -> ( + match Shostak.Bitv.embed (Uf.make uf y) with + | [ { bv = Cte n ; sz } ] -> ( + match Z.to_int n with + | n -> + if n >= sz then + IntPolynomial.const Z.zero + else + IntPolynomial.(bv2poly uf x *$$ Z.(one lsl n)) + | exception Z.Overflow -> + IntPolynomial.const Z.zero + ) + | _ -> + IntPolynomial.var @@ Uf.make uf t + ) + | _ -> + IntPolynomial.var @@ Uf.make uf t + +let bv2poly uf t = + let p = bv2poly uf t in + let sz = match Expr.type_info t with Tbitv n -> n | _ -> assert false in + let p = IntPolynomial.map_constant (fun n -> Z.extract n 0 sz) p in + p + +module Interner = struct + module H = Ephemeron.K1.Make(struct + type t = X.r + let equal = X.equal + let hash = X.hash + end) + + module EE = Ephemeron.K1.Make(Expr) + + module P = Shostak.Polynome + + type t = + { bvasr : (int, P.t) Hashtbl.t H.t + (** An entry [r -> ofs -> p] in this map indicates that the equality: + + [p = bv2nat(r asr ofs)] + + holds in all contexts. *) + ; bv2poly : X.r EE.t + (** An entry [e -> r] in this map indicates that the equality: + + [bv2nat(e) = bv2poly(e) + r * 2^sz] + + where [e] has type [Tbitv sz]. *) + } + + let create () = + { bvasr = H.create 17 + ; bv2poly = EE.create 17 } + + let find_asr t r ofs = + if ofs >= bitwidth r then + Shostak.Polynome.Ints.zero + else + let table = + try H.find t.bvasr r + with Not_found -> + let table = Hashtbl.create 3 in + H.replace t.bvasr r table; + table + in + try Hashtbl.find table ofs + with Not_found -> + let p = Shostak.Arith.embed (X.term_embed (E.fresh_name Tint)) in + Hashtbl.replace table ofs p; + p + + let find_bv2poly t r = + try EE.find t.bv2poly r + with Not_found -> + let k = X.term_embed (E.fresh_name Tint) in + EE.replace t.bv2poly r k; + k +end + +(* Used to make sure that we create the same variables in the main environment + and the case split environment. *) +let global_interner = Interner.create () + module BV2Nat = struct (* Domain for bv2nat and int2bv conversions @@ -313,7 +709,7 @@ module BV2Nat = struct if ofs >= bitwidth bv then (T.Ints.zero, Ex.empty), t else - let k = T.Ints.of_repr @@ X.term_embed @@ E.fresh_name Tint in + let k = Interner.find_asr global_interner bv ofs in let use = add_use_p k t.use in let bv2nat = add_ext ~ex:Ex.empty bv ext k t.bv2nat in let nat2bv = P.Map.add k (bv, ext, Ex.empty) t.nat2bv in @@ -472,6 +868,18 @@ module BV2Nat = struct let (p', ex'), t = composite x t in { t with eqs = (T.Ints.mkv_eq p p', Ex.union ex ex') :: t.eqs } + let find_bv2nat bv_r t = + match BitvNormalForm.normal_form bv_r with + | Constant cte -> + (T.Ints.of_bigint cte, Ex.empty), t + | Atom (x, ofs) -> + let ext = Extraction.full ~size:(bitwidth x) in + let (p, ex), t = find_or_init_ext x ext t in + (T.Ints.(p +$$ ofs), ex), t + | Composite (x, ofs) -> + let (p, ex), t = composite x t in + (T.Ints.(p +$$ ofs), ex), t + (* Add the equality [bv_r = int2bv(int_r)]. We do not have a table mapping polynomials to their (truncated) bit-vector @@ -2081,11 +2489,13 @@ let rec propagate_all uf eqs bdom idom = type t = { terms : SX.t - ; size_splits : Q.t } + ; size_splits : Q.t + ; poly : Expr.t IntPolynomial.Map.t } let empty uf = { terms = SX.empty - ; size_splits = Q.one }, + ; size_splits = Q.one + ; poly = IntPolynomial.Map.empty }, Uf.GlobalDomains.add (module BV2Nat) BV2Nat.empty @@ Uf.GlobalDomains.add (module Bitlist_domains) Bitlist_domains.empty @@ Uf.GlobalDomains.add (module Interval_domains) Interval_domains.empty @@ @@ -2171,7 +2581,7 @@ let assume env uf la = { Sig_rel.assume = List.rev_append bvconv_assume assume ; remove = [] } in - { size_splits ; terms = env.terms }, + { size_splits ; terms = env.terms ; poly = env.poly }, Uf.GlobalDomains.add (module BV2Nat) bvconv @@ Uf.GlobalDomains.add (module Bitlist_domains) domain @@ Uf.GlobalDomains.add (module Interval_domains) int_domain ds, @@ -2236,6 +2646,11 @@ let case_split env uf ~for_model = [ Uf.LX.mkv_eq lhs zero, true, Th_util.CS (Th_util.Th_bitv, Q.of_int 2) ] | exception Not_found -> [] +let find_poly uf env bp = + match IntPolynomial.as_var bp with + | Some v -> v, Ex.empty + | None -> Uf.find uf (IntPolynomial.Map.find bp env.poly) + let add env uf r t = let ds = Uf.domains uf in match X.type_info r with @@ -2245,7 +2660,53 @@ let add env uf r t = let bvconv = Uf.GlobalDomains.find (module BV2Nat) ds in let rx, ex = Uf.find uf x in let bvconv = BV2Nat.add_bv2nat ~ex r rx bvconv in + let bp = bv2poly uf x in + let px, exp, bvconv = + IntPolynomial.fold bp + ~constant:(fun const -> + (Shostak.Polynome.Ints.(of_bigint const), Ex.empty, bvconv)) + ~coeff:(fun mon k (acc, ex, bvconv) -> + let l, ex, bvconv = + IntPolynomial.Mon.fold + (fun r n (acc, ex, bvconv) -> + let r, ex' = Uf.find_r uf r in + let ex = Ex.union ex ex' in + let (rp, ex'), bvconv = BV2Nat.find_bv2nat r bvconv in + let ex = Ex.union ex ex' in + let rr = Shostak.Arith.is_mine rp in + ((rr, n) :: acc, ex, bvconv)) + mon ([], ex, bvconv) + in + let l = Shostak.Ac.compact l in + let acc = + Shostak.Polynome.add acc @@ + match l with + | [] -> assert false + | [ x, 1 ] -> + Shostak.Polynome.Ints.(Shostak.Arith.embed x *$$ k) + | l -> + let x = + X.color { h = Op Mult ; t = Tint ; l ; distribute = true } + in + Shostak.Polynome.Ints.(Shostak.Arith.embed x *$$ k) + in + (acc, ex, bvconv)) + in + let sz = bitwidth rx in let eqs, bvconv = BV2Nat.flush bvconv in + let eqs = + if X.equal (Shostak.Arith.is_mine px) r then + eqs + else + let k = + BV2Nat.T.Ints.of_repr @@ + Interner.find_bv2poly global_interner x + in + let lit = + BV2Nat.T.Ints.mkv_eq (Shostak.Arith.embed r) + BV2Nat.T.Ints.(px + k *$$ Z.(one lsl sz)) + in + (lit, exp) :: eqs in env, Uf.GlobalDomains.add (module BV2Nat) bvconv ds, eqs @@ -2262,16 +2723,34 @@ let add env uf r t = Uf.GlobalDomains.add (module BV2Nat) bvconv ds, eqs | _ -> + let eqs, poly = + (* Record equality with any existing term that normalizes to the same + bit-vector polynomial. *) + let bp = bv2poly uf t in + match find_poly uf env bp with + | r', ex' -> + let r, ex = Uf.find uf t in + if X.equal r r' then [], env.poly + else + let lit = + BV2Nat.T.BV.mkv_eq + (r, BV2Nat.Extraction.full ~size:(bitwidth r)) + (r', BV2Nat.Extraction.full ~size:(bitwidth r')) + in + [(lit, Ex.union ex ex')], env.poly + | exception Not_found -> + [], IntPolynomial.Map.add bp t env.poly + in let dom = Uf.GlobalDomains.find (module Bitlist_domains) ds in let idom = Uf.GlobalDomains.find (module Interval_domains) ds in let terms, dom, idom = extract_constraints env.terms dom idom uf r t in - { env with terms }, + { env with terms ; poly }, Uf.GlobalDomains.add (module Bitlist_domains) dom @@ Uf.GlobalDomains.add (module Interval_domains) idom @@ ds, - [] + eqs ) | _ -> env, ds, []