Skip to content

Commit

Permalink
compile-time match literals and options (#3903)
Browse files Browse the repository at this point in the history
This PR makes constant expressions fully matchable. We now also cover
- `LitP` and
- `OptP`

in the matching machinery. Thus there is no need to differentiate if the pattern is irrefutable.

Simplifies the logic quite a bit (as it should!)

-----------
TODO
- [x] tests for const `Opt` and `Lit`
- [x] bottoming cases too
- [x] check that `Opt` shortcutting works as intended — already in `iter-no-alloc.mo` for kernel present and `aardvark.mo` for kernel missing.
  • Loading branch information
ggreif authored Mar 30, 2023
1 parent 7903d21 commit efb8180
Show file tree
Hide file tree
Showing 22 changed files with 181 additions and 67 deletions.
154 changes: 88 additions & 66 deletions src/codegen/compile.ml
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,18 @@ module Const = struct
| Word64 of int64
| Float64 of Numerics.Float.t
| Blob of string

| Null

let lit_eq = function
| Vanilla i, Vanilla j -> i = j
| BigInt i, BigInt j -> Big_int.eq_big_int i j
| Word32 i, Word32 j -> i = j
| Word64 i, Word64 j -> i = j
| Float64 i, Float64 j -> i = j
| Bool i, Bool j -> i = j
| Blob s, Blob t -> s = t
| Null, Null -> true
| _ -> false

(* Inlineable functions
Expand Down Expand Up @@ -139,6 +150,7 @@ module Const = struct
| Unit
| Array of t list (* also tuples, but not nullary *)
| Tag of (string * t)
| Opt of t
| Lit of lit

(* A constant known value together with a vanilla pointer.
Expand Down Expand Up @@ -1678,7 +1690,13 @@ module Opt = struct
let null_lit env =
compile_unboxed_const (null_vanilla_lit env)

let is_some env =
let vanilla_lit env ptr : int32 =
E.add_static env StaticBytes.[
I32 Tagged.(int_of_tag Some);
I32 ptr
]

let is_some env =
null_lit env ^^
G.i (Compare (Wasm.Values.I32 I32Op.Ne))

Expand All @@ -1692,10 +1710,7 @@ module Opt = struct
[ Tagged.Null,
(* NB: even ?null does not require allocation: We use a static
singleton for that: *)
compile_unboxed_const (E.add_static env StaticBytes.[
I32 Tagged.(int_of_tag Some);
I32 (null_vanilla_lit env)
])
compile_unboxed_const (vanilla_lit env (null_vanilla_lit env))
; Tagged.Some,
Tagged.obj env Tagged.Some [get_x]
]
Expand All @@ -1706,7 +1721,6 @@ module Opt = struct
we know for sure that it wouldn’t do anything anyways *)
let inject_noop _env e = e


let project env =
Func.share_code1 env "opt_project" ("x", I32Type) [I32Type] (fun env get_x ->
get_x ^^ BitTagged.if_tagged_scalar env [I32Type]
Expand Down Expand Up @@ -7125,13 +7139,14 @@ module StackRep = struct
*)
let materialize_lit env (lit : Const.lit) : int32 =
match lit with
| Const.Vanilla n -> n
| Const.Bool n -> Bool.vanilla_lit n
| Const.BigInt n -> BigNum.vanilla_lit env n
| Const.Word32 n -> BoxedSmallWord.vanilla_lit env n
| Const.Word64 n -> BoxedWord64.vanilla_lit env n
| Const.Float64 f -> Float.vanilla_lit env f
| Const.Blob t -> Blob.vanilla_lit env t
| Const.Vanilla n -> n
| Const.Bool n -> Bool.vanilla_lit n
| Const.BigInt n -> BigNum.vanilla_lit env n
| Const.Word32 n -> BoxedSmallWord.vanilla_lit env n
| Const.Word64 n -> BoxedWord64.vanilla_lit env n
| Const.Float64 f -> Float.vanilla_lit env f
| Const.Blob t -> Blob.vanilla_lit env t
| Const.Null -> Opt.null_vanilla_lit env

let rec materialize_const_t env (p, cv) : int32 =
Lib.Promise.lazy_value p (fun () -> materialize_const_v env cv)
Expand All @@ -7150,6 +7165,14 @@ module StackRep = struct
let ptr = materialize_const_t env c in
Variant.vanilla_lit env i ptr
| Const.Lit l -> materialize_lit env l
| Const.Opt c ->
let rec kernel = Const.(function
| (_, Lit Null) -> None
| (_, Opt c) -> kernel c
| (_, other) -> Some (materialize_const_v env other)) in
match kernel c with
| Some ptr -> ptr
| None -> Opt.vanilla_lit env (materialize_const_t env c)

let adjust env (sr_in : t) sr_out =
if eq sr_in sr_out
Expand Down Expand Up @@ -8148,7 +8171,7 @@ let nat64_to_int64 n =
then sub_big_int n (power_int_positive_int 2 64)
else n

let const_lit_of_lit env : Ir.lit -> Const.lit = function
let const_lit_of_lit : Ir.lit -> Const.lit = function
| BoolLit b -> Const.Bool b
| IntLit n
| NatLit n -> Const.BigInt (Numerics.Nat.to_big_int n)
Expand All @@ -8161,22 +8184,21 @@ let const_lit_of_lit env : Ir.lit -> Const.lit = function
| Int64Lit n -> Const.Word64 (Big_int.int64_of_big_int (Numerics.Int_64.to_big_int n))
| Nat64Lit n -> Const.Word64 (Big_int.int64_of_big_int (nat64_to_int64 (Numerics.Nat64.to_big_int n)))
| CharLit c -> Const.Vanilla Int32.(shift_left (of_int c) 8)
| NullLit -> Const.Vanilla (Opt.null_vanilla_lit env)
| NullLit -> Const.Null
| TextLit t
| BlobLit t -> Const.Blob t
| FloatLit f -> Const.Float64 f

let const_of_lit env lit =
Const.t_of_v (Const.Lit (const_lit_of_lit env lit))
let const_of_lit lit =
Const.t_of_v (Const.Lit (const_lit_of_lit lit))

let compile_lit env lit =
SR.Const (const_of_lit env lit), G.nop
let compile_lit lit =
SR.Const (const_of_lit lit), G.nop

let compile_lit_as env sr_out lit =
let sr_in, code = compile_lit env lit in
let sr_in, code = compile_lit lit in
code ^^ StackRep.adjust env sr_in sr_out


(* helper, traps with message *)
let then_arithmetic_overflow env =
E.then_trap_with env "arithmetic overflow"
Expand Down Expand Up @@ -9808,7 +9830,7 @@ and compile_exp_with_hint (env : E.t) ae sr_hint exp =
compile_exp_as env ae sr e2 ^^
store_code
| LitE l ->
compile_lit env l
compile_lit l
| IfE (scrut, e1, e2) ->
let code_scrut = compile_exp_as_test env ae scrut in
let sr1, code1 = compile_exp_with_hint env ae sr_hint e1 in
Expand Down Expand Up @@ -10197,16 +10219,6 @@ and compile_dec env pre_ae how v2en dec : VarEnv.t * G.t * (VarEnv.t -> scope_wr
G.(pre_ae, with_region dec.at alloc_code, fun ae body_code ->
with_region dec.at (mk_code ae) ^^ wrap body_code)) @@

(* Set up some helpers, for exclusive usage by the "constant expressions" special case below *)
let const_exp_helper =
lazy begin
let[@warning "-8"] LetD (p, e) = dec.it in
const_exp_matches_pat env p e
end in
let is_compile_time_matchable () =
let lazy const_exp_matches = const_exp_helper in
Option.is_some const_exp_matches in

match dec.it with
(* A special case for public methods *)
(* This relies on the fact that in the top-level mutually recursive group, no shadowing happens. *)
Expand All @@ -10219,15 +10231,13 @@ and compile_dec env pre_ae how v2en dec : VarEnv.t * G.t * (VarEnv.t -> scope_wr
G.( pre_ae1, nop, (fun ae -> fill env ae; nop), unmodified)

(* A special case for constant expressions *)
| LetD (p, e) when e.note.Note.const && is_compile_time_matchable () ->
let is_compile_time_bottom =
let lazy const_exp_matches = const_exp_helper in
not (Option.get const_exp_matches) in
if is_compile_time_bottom then (* refuted *)
(pre_ae, G.nop, (fun _ -> PatCode.patternFailTrap env), unmodified)
else (* not refuted *)
| LetD (p, e) when e.note.Note.const ->
(* constant expression matching with patterns is fully decidable *)
if const_exp_matches_pat env pre_ae p e then (* not refuted *)
let extend, fill = compile_const_dec env pre_ae dec in
G.(extend pre_ae, nop, (fun ae -> fill env ae; nop), unmodified)
else (* refuted *)
(pre_ae, G.nop, (fun _ -> PatCode.patternFailTrap env), unmodified)

| LetD (p, e) ->
let (pre_ae1, alloc_code, pre_code, sr, fill_code) = compile_unboxed_pat env pre_ae how p in
Expand Down Expand Up @@ -10354,7 +10364,7 @@ and compile_const_exp env pre_ae exp : Const.t * (E.t -> VarEnv.t -> unit) =
| _, Const.Array cs -> cs
| _ -> fatal "compile_const_exp/ProjE: not a static tuple" in
(List.nth cs i, fill)
| LitE l -> Const.(t_of_v (Lit (const_lit_of_lit env l))), (fun _ _ -> ())
| LitE l -> Const.(t_of_v (Lit (const_lit_of_lit l))), (fun _ _ -> ())
| PrimE (TupPrim, []) -> Const.t_of_v Const.Unit, (fun _ _ -> ())
| PrimE (ArrayPrim (Const, _), es)
| PrimE (TupPrim, es) ->
Expand All @@ -10365,6 +10375,10 @@ and compile_const_exp env pre_ae exp : Const.t * (E.t -> VarEnv.t -> unit) =
let (arg_ct, fill) = compile_const_exp env pre_ae e in
Const.(t_of_v (Tag (i, arg_ct))),
fill
| PrimE (OptPrim, [e]) ->
let (arg_ct, fill) = compile_const_exp env pre_ae e in
Const.(t_of_v (Opt arg_ct)),
fill

| _ -> assert false

Expand All @@ -10380,39 +10394,47 @@ and compile_const_decs env pre_ae decs : (VarEnv.t -> VarEnv.t) * (E.t -> VarEnv
(fun env ae -> fill1 env ae; fill2 env ae) in
go pre_ae decs

and const_exp_matches_pat env pat exp : bool option =
and const_exp_matches_pat env ae pat exp : bool =
assert exp.note.Note.const;
match exp.it with
| _ when Ir_utils.is_irrefutable pat -> Some true
| PrimE (TagPrim _, _) ->
let c, _ = compile_const_exp env VarEnv.empty_ae exp in
(try ignore (destruct_const_pat VarEnv.empty_ae pat c); Some true with Invalid_argument _ -> Some false)
| _ -> None

and destruct_const_pat ae pat const : VarEnv.t = match pat.it with
| WildP -> ae
| VarP v -> VarEnv.add_local_const ae v const
let c, _ = compile_const_exp env ae exp in
match destruct_const_pat VarEnv.empty_ae pat c with Some _ -> true | _ -> false

and destruct_const_pat ae pat const : VarEnv.t option = match pat.it with
| WildP -> Some ae
| VarP v -> Some (VarEnv.add_local_const ae v const)
| ObjP pfs ->
let fs = match const with (_, Const.Obj fs) -> fs | _ -> assert false in
List.fold_left (fun ae (pf : pat_field) ->
match List.find_opt (fun (n, _) -> pf.it.name = n) fs with
| Some (_, c) -> destruct_const_pat ae pf.it.pat c
| None -> assert false
) ae pfs
match ae, List.find_opt (fun (n, _) -> pf.it.name = n) fs with
| None, _ -> None
| Some ae, Some (_, c) -> destruct_const_pat ae pf.it.pat c
| _, None -> assert false
) (Some ae) pfs
| AltP (p1, p2) ->
begin
try destruct_const_pat ae p1 const with
Invalid_argument _ -> destruct_const_pat ae p2 const
end
let l = destruct_const_pat ae p1 const in
if l = None then destruct_const_pat ae p2 const
else l
| TupP ps ->
let cs = match const with (_ , Const.Array cs) -> cs | (_, Const.Unit) -> [] | _ -> assert false in
List.fold_left2 destruct_const_pat ae ps cs
| LitP _ -> raise (Invalid_argument "LitP in static irrefutable pattern")
| OptP _ -> raise (Invalid_argument "OptP in static irrefutable pattern")
let cs = match const with (_, Const.Array cs) -> cs | (_, Const.Unit) -> [] | _ -> assert false in
let go ae p c = match ae with
| Some ae -> destruct_const_pat ae p c
| _ -> None in
List.fold_left2 go (Some ae) ps cs
| LitP lp ->
begin match const with
| (_, Const.Lit lc) when Const.lit_eq (const_lit_of_lit lp, lc) -> Some ae
| _ -> None
end
| OptP p ->
begin match const with
| (_, Const.Opt c) -> destruct_const_pat ae p c
| (_, Const.(Lit Null)) -> None
| _ -> assert false
end
| TagP (i, p) ->
match const with
| (_, Const.Tag (ic, c)) when i = ic -> destruct_const_pat ae p c
| (_, Const.Tag _) -> raise (Invalid_argument "TagP mismatch")
| (_, Const.Tag _) -> None
| _ -> assert false

and compile_const_dec env pre_ae dec : (VarEnv.t -> VarEnv.t) * (E.t -> VarEnv.t -> unit) =
Expand All @@ -10424,7 +10446,7 @@ and compile_const_dec env pre_ae dec : (VarEnv.t -> VarEnv.t) * (E.t -> VarEnv.t
(* This should only contain constants (cf. is_const_exp) *)
| LetD (p, e) ->
let (const, fill) = compile_const_exp env pre_ae e in
(fun ae -> destruct_const_pat ae p const),
(fun ae -> match destruct_const_pat ae p const with Some ae -> ae | _ -> assert false),
(fun env ae -> fill env ae)
| VarD _ | RefD _ -> fatal "compile_const_dec: Unexpected VarD/RefD"

Expand Down
2 changes: 2 additions & 0 deletions src/ir_def/check_ir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -867,6 +867,8 @@ let rec check_exp env (exp:Ir.exp) : unit =
check e1.note.Note.const "constant DotPrim on non-constant subexpression"
| PrimE (ProjPrim _, [e1]) ->
check e1.note.Note.const "constant ProjPrim on non-constant subexpression"
| PrimE (OptPrim, [e1]) ->
check e1.note.Note.const "constant OptPrim with non-constant subexpression"
| PrimE (TagPrim _, [e1]) ->
check e1.note.Note.const "constant TagPrim with non-constant subexpression"
| BlockE (ds, e) ->
Expand Down
2 changes: 1 addition & 1 deletion src/ir_passes/const.ml
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ let rec exp lvl (env : env) e : Lbool.t =
| PrimE (TupPrim, es)
| PrimE (ArrayPrim (Const, _), es) ->
all (List.map (fun e -> exp lvl env e) es)
| PrimE (DotPrim _, [e1] | ProjPrim _, [e1] | TagPrim _, [e1]) ->
| PrimE (DotPrim _, [e1] | ProjPrim _, [e1] | OptPrim, [e1] | TagPrim _, [e1]) ->
exp lvl env e1
| LitE _ ->
surely_true
Expand Down
1 change: 1 addition & 0 deletions test/run/ok/refuted-const-float.run.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
refuted-const-float.mo:2.5-2.9: execution error, value 3.140_000_000_000_000_1 does not match pattern
1 change: 1 addition & 0 deletions test/run/ok/refuted-const-float.run.ret.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Return code 1
4 changes: 4 additions & 0 deletions test/run/ok/refuted-const-float.tc.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
refuted-const-float.mo:2.5-2.9: warning [M0145], this pattern of type
Float
does not cover value
_
10 changes: 10 additions & 0 deletions test/run/ok/refuted-const-float.wasm-run.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
pattern failed
Error: failed to run main module `_out/refuted-const-float.wasm`

Caused by:
0: failed to invoke command default
1: wasm trap: unreachable
wasm backtrace:
0: init
1: _start

1 change: 1 addition & 0 deletions test/run/ok/refuted-const-float.wasm-run.ret.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Return code 134
1 change: 1 addition & 0 deletions test/run/ok/refuted-const-option-null.run.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
refuted-const-option-null.mo:2.5-2.7: execution error, value null does not match pattern
1 change: 1 addition & 0 deletions test/run/ok/refuted-const-option-null.run.ret.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Return code 1
4 changes: 4 additions & 0 deletions test/run/ok/refuted-const-option-null.tc.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
refuted-const-option-null.mo:2.5-2.7: warning [M0145], this pattern of type
Null
does not cover value
null
10 changes: 10 additions & 0 deletions test/run/ok/refuted-const-option-null.wasm-run.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
pattern failed
Error: failed to run main module `_out/refuted-const-option-null.wasm`

Caused by:
0: failed to invoke command default
1: wasm trap: unreachable
wasm backtrace:
0: init
1: _start

1 change: 1 addition & 0 deletions test/run/ok/refuted-const-option-null.wasm-run.ret.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Return code 134
1 change: 1 addition & 0 deletions test/run/ok/refuted-const-option.run.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
refuted-const-option.mo:2.5-2.9: execution error, value ?42 does not match pattern
1 change: 1 addition & 0 deletions test/run/ok/refuted-const-option.run.ret.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Return code 1
4 changes: 4 additions & 0 deletions test/run/ok/refuted-const-option.tc.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
refuted-const-option.mo:2.5-2.9: warning [M0145], this pattern of type
?Nat
does not cover value
?(_)
10 changes: 10 additions & 0 deletions test/run/ok/refuted-const-option.wasm-run.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
pattern failed
Error: failed to run main module `_out/refuted-const-option.wasm`

Caused by:
0: failed to invoke command default
1: wasm trap: unreachable
wasm backtrace:
0: init
1: _start

1 change: 1 addition & 0 deletions test/run/ok/refuted-const-option.wasm-run.ret.ok
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Return code 134
11 changes: 11 additions & 0 deletions test/run/refuted-const-float.mo
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// a failing pattern match that can be compiled to a trap
let 0.67 = 3.14;

// CHECK: (func $init (type
// CHECK: call $blob_of_principal
// CHECK: i32.const 14
// CHECK-NEXT: call $print_ptr
// CHECK-NEXT: unreachable)

//SKIP run-low
//SKIP run-ir
11 changes: 11 additions & 0 deletions test/run/refuted-const-option-null.mo
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// a failing pattern match that can be compiled to a trap
let ?b = null;

// CHECK: (func $init (type
// CHECK: call $blob_of_principal
// CHECK: i32.const 14
// CHECK-NEXT: call $print_ptr
// CHECK-NEXT: unreachable)

//SKIP run-low
//SKIP run-ir
Loading

0 comments on commit efb8180

Please sign in to comment.