Skip to content

Commit

Permalink
Use visitor for various Jib rewrites
Browse files Browse the repository at this point in the history
Add a simple heuristic for detecting constant folds that would produce large literals

Make is_union_constructor in type_env more efficient
  • Loading branch information
Alasdair committed Oct 27, 2023
1 parent 0b040ce commit e4f9665
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 238 deletions.
40 changes: 27 additions & 13 deletions src/lib/constant_fold.ml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,17 @@ and exp_of_value =
| V_attempted_read str -> mk_exp (E_id (mk_id str))
| _ -> failwith "No expression for value"

(* A simple heuristic to avoid generating overly large literals. Note
that we avoid traversing through every element of vectors and
lists, so a list of large lists could still sneak through *)
let rec is_too_large =
let open Value in
function
| V_int _ | V_bit _ | V_bool _ | V_string _ | V_unit | V_attempted_read _ | V_real _ | V_ref _ -> false
| V_vector vs | V_tuple vs | V_list vs -> List.compare_length_with vs 256 > 0
| V_record fields -> StringMap.exists (fun _ v -> is_too_large v) fields
| V_ctor (_, vs) -> List.exists is_too_large vs

(* We want to avoid evaluating things like print statements at compile
time, so we remove them from this list of primops we can use when
constant folding. *)
Expand Down Expand Up @@ -205,19 +216,22 @@ let rw_exp fixed target ok not_ok istate =
try
begin
let v = run (Interpreter.Step (lazy "", istate, initial_monad, [])) in
let exp = exp_of_value v in
try
ok ();
Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot)
with Type_error.Type_error (l, err) ->
(* A type error here would be unexpected, so don't ignore it! *)
Reporting.warn "" l
("Type error when folding constants in "
^ string_of_exp (E_aux (e_aux, annot))
^ "\n" ^ Type_error.string_of_type_error err
);
not_ok ();
E_aux (e_aux, annot)
if not (is_too_large v) then (
let exp = exp_of_value v in
try
ok ();
Type_check.check_exp (env_of_annot annot) exp (typ_of_annot annot)
with Type_error.Type_error (l, err) ->
(* A type error here would be unexpected, so don't ignore it! *)
Reporting.warn "" l
("Type error when folding constants in "
^ string_of_exp (E_aux (e_aux, annot))
^ "\n" ^ Type_error.string_of_type_error err
);
not_ok ();
E_aux (e_aux, annot)
)
else E_aux (e_aux, annot)
end
with
(* Otherwise if anything goes wrong when trying to constant
Expand Down
121 changes: 62 additions & 59 deletions src/lib/jib_compile.ml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ open Ast_defs
open Ast_util
open Jib
open Jib_util
open Jib_visitor
open Type_check
open Value2

Expand Down Expand Up @@ -1695,30 +1696,12 @@ module Make (C : CONFIG) = struct

let is_variant id = function CT_variant (id', _) -> Id.compare id id' = 0 | _ -> false

let map_structs_and_variants f = function
| ( CT_lint | CT_fint _ | CT_constant _ | CT_lbits | CT_fbits _ | CT_sbits _ | CT_bit | CT_unit | CT_bool | CT_real
| CT_string | CT_poly _ | CT_enum _ | CT_float _ | CT_rounding_mode ) as ctyp ->
ctyp
| CT_tup ctyps -> CT_tup (List.map (map_ctyp f) ctyps)
| CT_ref ctyp -> CT_ref (map_ctyp f ctyp)
| CT_vector ctyp -> CT_vector (map_ctyp f ctyp)
| CT_fvector (n, ctyp) -> CT_fvector (n, map_ctyp f ctyp)
| CT_list ctyp -> CT_list (map_ctyp f ctyp)
| CT_struct (id, fields) -> begin
match f (CT_struct (id, fields)) with
| CT_struct (id, fields) -> CT_struct (id, List.map (fun (id, ctyp) -> (id, map_ctyp f ctyp)) fields)
| _ -> Reporting.unreachable (id_loc id) __POS__ "Struct mapped to non-struct"
end
| CT_variant (id, ctors) -> begin
match f (CT_variant (id, ctors)) with
| CT_variant (id, ctors) -> CT_variant (id, List.map (fun (id, ctyp) -> (id, map_ctyp f ctyp)) ctors)
| _ -> Reporting.unreachable (id_loc id) __POS__ "Variant mapped to non-variant"
end
class fix_variants_visitor ctx var_id =
object
inherit empty_jib_visitor

let rec specialize_variants ctx prior =
let instantiations = ref CTListSet.empty in
let fix_variants ctx var_id =
map_structs_and_variants (function
method vctyp =
function
| CT_variant (id, ctors) when Id.compare var_id id = 0 ->
let generic_ctors = Bindings.find id ctx.variants |> snd |> Bindings.bindings in
let unifiers =
Expand All @@ -1729,47 +1712,71 @@ module Make (C : CONFIG) = struct
( mangle_mono_id id ctx unifiers,
List.map (fun (ctor_id, ctyp) -> (mangle_mono_id ctor_id ctx unifiers, ctyp)) ctors
)
|> change_do_children
| CT_struct (id, fields) when Id.compare var_id id = 0 ->
let generic_fields = Bindings.find id ctx.records |> snd |> Bindings.bindings in
let unifiers =
ctyp_unify (id_loc id) (CT_struct (id, generic_fields)) (CT_struct (id, fields))
|> KBindings.bindings |> List.map snd
in
CT_struct (mangle_mono_id id ctx unifiers, List.map (fun (field_id, ctyp) -> (field_id, ctyp)) fields)
| ctyp -> ctyp
)
in
|> change_do_children
| _ -> DoChildren
end

class specialize_constructor_visitor instantiations ctx ctor_id =
object
inherit empty_jib_visitor

method vctyp _ = SkipChildren
method vclexp _ = SkipChildren

method vcval =
function
| V_ctor_kind (cval, (id, unifiers), pat_ctyp) when Id.compare id ctor_id = 0 ->
change_do_children (V_ctor_kind (cval, (mangle_mono_id id ctx unifiers, []), pat_ctyp))
| V_ctor_unwrap (cval, (id, unifiers), ctor_ctyp) when Id.compare id ctor_id = 0 ->
change_do_children (V_ctor_unwrap (cval, (mangle_mono_id id ctx unifiers, []), ctor_ctyp))
| _ -> DoChildren

method vinstr =
function
| I_aux (I_funcall (clexp, extern, (id, ctyp_args), args), aux) when Id.compare id ctor_id = 0 ->
instantiations := CTListSet.add ctyp_args !instantiations;
I_aux (I_funcall (clexp, extern, (mangle_mono_id id ctx ctyp_args, []), args), aux) |> change_do_children
| _ -> DoChildren
end

class specialize_field_visitor instantiations ctx struct_id =
object
inherit empty_jib_visitor

method vctyp _ = SkipChildren
method vclexp _ = SkipChildren
method vcval _ = SkipChildren

method vinstr =
function
| I_aux (I_decl (CT_struct (struct_id', fields), _), (_, l)) when Id.compare struct_id struct_id' = 0 ->
let generic_fields = Bindings.find struct_id ctx.records |> snd |> Bindings.bindings in
let unifiers =
ctyp_unify l (CT_struct (struct_id, generic_fields)) (CT_struct (struct_id, fields))
|> KBindings.bindings |> List.map snd
in
instantiations := CTListSet.add unifiers !instantiations;
DoChildren
| _ -> DoChildren
end

let specialize_cval ctx ctor_id = function
| V_ctor_kind (cval, (id, unifiers), pat_ctyp) when Id.compare id ctor_id = 0 ->
V_ctor_kind (cval, (mangle_mono_id id ctx unifiers, []), pat_ctyp)
| V_ctor_unwrap (cval, (id, unifiers), ctor_ctyp) when Id.compare id ctor_id = 0 ->
V_ctor_unwrap (cval, (mangle_mono_id id ctx unifiers, []), ctor_ctyp)
| cval -> cval
in
let rec specialize_variants ctx prior =
let instantiations = ref CTListSet.empty in
let fix_variants ctx var_id = visit_ctyp (new fix_variants_visitor ctx var_id) in

let specialize_constructor ctx var_id ctor_id ctyp = function
| I_aux (I_funcall (clexp, extern, (id, ctyp_args), [cval]), aux) when Id.compare id ctor_id = 0 ->
instantiations := CTListSet.add ctyp_args !instantiations;
I_aux
( I_funcall
(clexp, extern, (mangle_mono_id id ctx ctyp_args, []), [map_cval (specialize_cval ctx ctor_id) cval]),
aux
)
| instr -> map_instr_cval (map_cval (specialize_cval ctx ctor_id)) instr
let specialize_constructor ctx ctor_id =
visit_cdefs (new specialize_constructor_visitor instantiations ctx ctor_id)
in

let specialize_field ctx struct_id = function
| I_aux (I_decl (CT_struct (struct_id', fields), _), (_, l)) as instr when Id.compare struct_id struct_id' = 0 ->
let generic_fields = Bindings.find struct_id ctx.records |> snd |> Bindings.bindings in
let unifiers =
ctyp_unify l (CT_struct (struct_id, generic_fields)) (CT_struct (struct_id, fields))
|> KBindings.bindings |> List.map snd
in
instantiations := CTListSet.add unifiers !instantiations;
instr
| instr -> instr
in
let specialize_field ctx struct_id = visit_cdefs (new specialize_field_visitor instantiations ctx struct_id) in

let mangled_pragma orig_id mangled_id =
CDEF_pragma
Expand Down Expand Up @@ -1808,11 +1815,7 @@ module Make (C : CONFIG) = struct
cdefs;

let cdefs =
List.fold_left
(fun cdefs (ctor_id, ctyp) ->
List.map (cdef_map_instr (specialize_constructor ctx var_id ctor_id ctyp)) cdefs
)
cdefs ctors
List.fold_left (fun cdefs (ctor_id, ctyp) -> specialize_constructor ctx ctor_id cdefs) cdefs ctors
in

let monomorphized_variants =
Expand Down Expand Up @@ -1878,7 +1881,7 @@ module Make (C : CONFIG) = struct
->
let typ_params = List.fold_left (fun set (_, ctyp) -> KidSet.union (ctyp_vars ctyp) set) KidSet.empty fields in

let cdefs = List.map (cdef_map_instr (specialize_field ctx struct_id)) cdefs in
let cdefs = specialize_field ctx struct_id cdefs in
let monomorphized_structs =
List.map
(fun inst ->
Expand Down
Loading

0 comments on commit e4f9665

Please sign in to comment.