Skip to content

Commit

Permalink
Refactor the postprocessor
Browse files Browse the repository at this point in the history
  • Loading branch information
hirrolot committed Dec 25, 2024
1 parent 3377d0e commit 9801ab7
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 82 deletions.
18 changes: 6 additions & 12 deletions lib/abstract_syntax/renaming.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,12 @@ type t = (Symbol.t Symbol_map.t[@printer pp_as_list]) [@@deriving eq, show]

let not_in_codomain ~env y = Symbol_map.for_all (fun _ y' -> y <> y') env

let fresh_symbol ~gensym ~renaming ~env = function
let fresh_symbol ~gensym ~fresh_to_source_vars ~env = function
(* [x] was generated in the process of driving; we must generate a new symbol that has
not been used yet. *)
| x when (Symbol.to_string x).[0] = '.' ->
let x' =
(* [renaming] maps driver-generated variables to their counterparts from the
original user program. If it contains a mapping for [x], we use it; otherwise,
we generate a new proper symbol to replace [x]. *)
match Symbol_map.find_opt x renaming with
match Symbol_map.find_opt x fresh_to_source_vars with
| Some x' -> x'
| None -> Gensym.emit gensym
in
Expand All @@ -30,27 +27,24 @@ let fresh_symbol ~gensym ~renaming ~env = function
| x -> x
;;

let insert
~(gensym : Gensym.t)
?(renaming = ((Symbol_map.empty : t) [@coverage off]))
((env, x) : t * Symbol.t)
let insert ~(gensym : Gensym.t) ~(fresh_to_source_vars : t) ((env, x) : t * Symbol.t)
: t * Symbol.t
=
let y = fresh_symbol ~gensym ~renaming ~env x in
let y = fresh_symbol ~gensym ~fresh_to_source_vars ~env x in
Symbol_map.add x y env, y
;;

let insert_list
~(gensym : Gensym.t)
?(renaming = (Symbol_map.empty : t))
~(fresh_to_source_vars : t)
((env, list) : t * Symbol.t list)
: t * Symbol.t list
=
let env, list =
list
|> List.fold_left
(fun (env, list) x ->
let env, y = insert ~gensym ~renaming (env, x) in
let env, y = insert ~gensym ~fresh_to_source_vars (env, x) in
env, fun xs -> list (y :: xs))
(env, Fun.id)
in
Expand Down
8 changes: 6 additions & 2 deletions lib/abstract_syntax/renaming.mli
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
type t = Symbol.t Symbol_map.t [@@deriving eq, show]

val insert : gensym:Gensym.t -> ?renaming:t -> t * Symbol.t -> t * Symbol.t
val insert : gensym:Gensym.t -> fresh_to_source_vars:t -> t * Symbol.t -> t * Symbol.t

val insert_list : gensym:Gensym.t -> ?renaming:t -> t * Symbol.t list -> t * Symbol.t list
val insert_list
: gensym:Gensym.t
-> fresh_to_source_vars:t
-> t * Symbol.t list
-> t * Symbol.t list
8 changes: 4 additions & 4 deletions lib/mechanics/driver.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ type 'a step =
and contraction =
{ c : Symbol.t
; fresh_vars : Symbol.t list
; original_vars : Symbol.t list
; source_vars : Symbol.t list
}

and 'a case_body = (Symbol.t * 'a) option * 'a
Expand Down Expand Up @@ -45,7 +45,7 @@ let map ~(f : 'a -> 'b) : 'a step -> 'b step = function
Extract (binding, f u)
;;

let unify ~x ~contraction:{ c; fresh_vars; original_vars = _ } list =
let unify ~x ~contraction:{ c; fresh_vars; source_vars = _ } list =
match x with
| None -> list
| Some x ->
Expand Down Expand Up @@ -171,7 +171,7 @@ struct
view_g_rules ~program g
|> List.map (fun ((c, (c_params, params, body)), is_productive) ->
let fresh_vars = Gensym.emit_list ~length_list:c_params gensym in
let contraction = { c; fresh_vars; original_vars = c_params } in
let contraction = { c; fresh_vars; source_vars = c_params } in
let args =
try f contraction with
| Uncontractable ->
Expand All @@ -193,7 +193,7 @@ struct
;;

let unfold_g_rules_t_f ~depth ~test:(x, op', unifier) ~args g =
let f { c; fresh_vars; original_vars = _ } =
let f { c; fresh_vars; source_vars = _ } =
match Symbol.(to_string c, to_string op'), fresh_vars with
| ("T", "="), [] | ("F", "!="), [] ->
List.map (Term.subst ~x ~value:unifier) args
Expand Down
2 changes: 1 addition & 1 deletion lib/mechanics/driver.mli
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ type 'a step =
and contraction =
{ c : Symbol.t
; fresh_vars : Symbol.t list
; original_vars : Symbol.t list
; source_vars : Symbol.t list
}

and 'a case_body = (Symbol.t * 'a) option * 'a
Expand Down
103 changes: 62 additions & 41 deletions lib/mechanics/postprocessor.ml
Original file line number Diff line number Diff line change
@@ -1,38 +1,27 @@
[@@@coverage off]

open Raw_term

let symbol = Symbol.of_string

(* TODO: handle other types of restrictions as well. *)
type restriction = NotEqual of Raw_term.t

let handle_term ~(gensym : Gensym.t) ~(env : Renaming.t) ~(renaming : Renaming.t) t =
(* We propagate negative information about primitives at residualization-time. During
driving, we only propagate positive information. *)
let restrictions = Hashtbl.create 128 in
let is_conflict = function
(* We propagate negative information about primitives at residualization-time.
During driving, we only propagate positive information. *)
let propagate_restrictions : env:restriction Symbol_map.t -> Raw_term.t -> Raw_term.t =
let open Raw_term in
let is_conflict ~env = function
| x, ((Var _ | Const _) as t) ->
(match Hashtbl.find_opt restrictions x with
(match Symbol_map.find_opt x env with
| Some (NotEqual t') when equal t t' -> true
| _ -> false)
| _ -> false
in
let hermetic k force =
let gensym_backup = Gensym.clone gensym in
let result = k force in
Gensym.assign ~other:gensym_backup gensym;
result
in
let exception Select of Raw_term.t in
let rec go ~env = function
| Var x -> Var (Symbol_map.find x env)
| Const _ as t -> t
| Call (op, [ Var x; t ])
when op = symbol "=" && is_conflict (x, hermetic (go ~env) t) ->
| (Var _ | Const _) as t -> t
| Call (op, [ Var x; t ]) when op = symbol "=" && is_conflict ~env (x, t) ->
Call (symbol "F", [])
| Call (op, [ t; Var x ])
when op = symbol "=" && is_conflict (x, hermetic (go ~env) t) ->
| Call (op, [ t; Var x ]) when op = symbol "=" && is_conflict ~env (x, t) ->
Call (symbol "F", [])
| Call (op, args) -> Call (op, List.map (go ~env) args)
| Match
Expand All @@ -49,27 +38,16 @@ let handle_term ~(gensym : Gensym.t) ~(env : Renaming.t) ~(renaming : Renaming.t
| Match (t, cases) ->
(try
let t = go_scrutinee ~env ~cases t in
let cases = List.map (go_case ~env) cases in
let cases = List.map (fun (pattern, t) -> pattern, go ~env t) cases in
Match (t, cases)
with
| Select t -> t)
| Let (x, t, u) ->
let t = go ~env t in
let env, y = Renaming.insert ~gensym ~renaming (env, x) in
let u = go ~env u in
Let (y, t, u)
| Let (x, t, u) -> Let (x, go ~env t, go ~env u)
and go_restrict ~env ~x ~negation t =
let negation = go ~env negation in
match negation with
| Var _ | Const _ ->
Hashtbl.add restrictions x (NotEqual negation);
let t = go ~env t in
Hashtbl.remove restrictions x;
t
match go ~env negation with
| (Var _ | Const _) as negation ->
go ~env:(Symbol_map.add x (NotEqual negation) env) t
| _ -> go ~env t
and go_case ~env ((c, c_params), t) =
let env, c_params' = Renaming.insert_list ~gensym ~renaming (env, c_params) in
(c, c_params'), go ~env t
and go_scrutinee ~env ~cases t =
let t = go ~env t in
(match t with
Expand All @@ -81,14 +59,57 @@ let handle_term ~(gensym : Gensym.t) ~(env : Renaming.t) ~(renaming : Renaming.t
| _ -> ());
t
in
go ~env t
go
;;

(* Generate proper symbols (i.e., without a leading dot), recovering as much
original program symbols as possible. *)
let rename ~gensym ~fresh_to_source_vars : env:Renaming.t -> Raw_term.t -> Raw_term.t =
let open Raw_term in
let rec go ~env = function
| Var x -> Var (Symbol_map.find x env)
| Const _ as t -> t
| Call (op, args) -> Call (op, List.map (go ~env) args)
| Match (t, cases) ->
let t = go ~env t in
let cases = List.map (go_case ~env) cases in
Match (t, cases)
| Let (x, t, u) ->
let t = go ~env t in
let env, y = Renaming.insert ~gensym ~fresh_to_source_vars (env, x) in
let u = go ~env u in
Let (y, t, u)
and go_case ~env ((c, c_params), t) =
let env, c_params' =
Renaming.insert_list ~gensym ~fresh_to_source_vars (env, c_params)
in
(c, c_params'), go ~env t
in
go
;;

let handle_rule ~(renaming : Renaming.t) ((attrs, f, params, body) : Raw_program.rule)
let handle_rule
~(fresh_to_source_vars : Renaming.t)
((attrs, f, params, body) : Raw_program.rule)
: Raw_program.rule
=
let body = propagate_restrictions ~env:Symbol_map.empty body in
let gensym = Gensym.create ~prefix:"v" () in
let env = Symbol_map.empty in
let env, params' = Renaming.insert_list ~gensym ~renaming (env, params) in
attrs, f, params', handle_term ~gensym ~env ~renaming body
let env, params =
Renaming.insert_list ~gensym ~fresh_to_source_vars (Symbol_map.empty, params)
in
let body = rename ~gensym ~fresh_to_source_vars ~env body in
attrs, f, params, body
;;

let handle_main_body
~(fresh_to_source_vars : Renaming.t)
~(unknowns : Symbol.t list)
(body : Raw_term.t)
: Raw_term.t
=
let _, _, _, body =
handle_rule ~fresh_to_source_vars ([], symbol "main", unknowns, body)
in
body
;;
11 changes: 5 additions & 6 deletions lib/mechanics/postprocessor.mli
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
val handle_term
: gensym:Gensym.t
-> env:Renaming.t
-> renaming:Renaming.t
val handle_rule : fresh_to_source_vars:Renaming.t -> Raw_program.rule -> Raw_program.rule

val handle_main_body
: fresh_to_source_vars:Renaming.t
-> unknowns:Symbol.t list
-> Raw_term.t
-> Raw_term.t

val handle_rule : renaming:Renaming.t -> Raw_program.rule -> Raw_program.rule
12 changes: 6 additions & 6 deletions lib/mechanics/process_graph.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ and node_id = Symbol.t

type metadata =
{ symbol_table : (Symbol.t * Program.param_list) Symbol_map.t
; fresh_to_original_vars : Renaming.t
; fresh_to_source_vars : Renaming.t
}

let compute_metadata (graph : t) : metadata =
let f_gensym = Gensym.create ~prefix:"f" () in
let metadata =
ref { symbol_table = Symbol_map.empty; fresh_to_original_vars = Symbol_map.empty }
ref { symbol_table = Symbol_map.empty; fresh_to_source_vars = Symbol_map.empty }
in
let rec go = function
| Step step -> go_step step
Expand All @@ -31,16 +31,16 @@ let compute_metadata (graph : t) : metadata =
| Driver.Analyze (_x, graph, variants) ->
go graph;
List.iter
(fun (Driver.{ c = _; fresh_vars; original_vars }, (binding, graph)) ->
(fun (Driver.{ c = _; fresh_vars; source_vars }, (binding, graph)) ->
List.iter2
(fun x y ->
metadata
:= { !metadata with
fresh_to_original_vars =
Symbol_map.add x y !metadata.fresh_to_original_vars
fresh_to_source_vars =
Symbol_map.add x y !metadata.fresh_to_source_vars
})
fresh_vars
original_vars;
source_vars;
match binding with
| Some binding -> go_extract (binding, graph)
| None -> go graph)
Expand Down
2 changes: 1 addition & 1 deletion lib/mechanics/process_graph.mli
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ and node_id = Symbol.t

type metadata =
{ symbol_table : (Symbol.t * Program.param_list) Symbol_map.t
; fresh_to_original_vars : Renaming.t
; fresh_to_source_vars : Renaming.t
}

(* Computes a map from node identifiers to residualized function signatures. *)
Expand Down
13 changes: 5 additions & 8 deletions lib/mechanics/residualizer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ end = struct
(* [params] contains all free variables in [t_res]. *)
f_rules
:= Postprocessor.handle_rule
~renaming:graph_metadata.fresh_to_original_vars
~fresh_to_source_vars:graph_metadata.fresh_to_source_vars
([], f, params, t_res)
:: !f_rules;
(* Some parameters may refer to bound variables; substitute. *)
Expand Down Expand Up @@ -195,7 +195,7 @@ let run ~(unknowns : Symbol.t list) (graph : Process_graph.t) : Raw_term.t * Raw
let cases_res =
variants
|> List.map
(fun (Driver.{ c; fresh_vars; original_vars = _ }, (binding, graph)) ->
(fun (Driver.{ c; fresh_vars; source_vars = _ }, (binding, graph)) ->
match binding with
| Some binding -> (c, fresh_vars), go_extract ~env (binding, graph)
| None -> (c, fresh_vars), go ~env graph)
Expand Down Expand Up @@ -227,12 +227,9 @@ let run ~(unknowns : Symbol.t list) (graph : Process_graph.t) : Raw_term.t * Raw
if is_innocent t_res then Left (x, t_res) else Right (x, t_res))
in
let t_res = go ~env:Symbol_map.empty graph in
let gensym = Gensym.create ~prefix:"v" () in
let env, _ = Renaming.insert_list ~gensym (Symbol_map.empty, unknowns) in
( Postprocessor.handle_term
~gensym
~env
~renaming:graph_metadata.fresh_to_original_vars
( Postprocessor.handle_main_body
~fresh_to_source_vars:graph_metadata.fresh_to_source_vars
~unknowns
t_res
, Memoizer.finalize () )
;;
2 changes: 1 addition & 1 deletion lib/mechanics/visualizer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ let run ~(oc : out_channel Lazy.t) (graph : Process_graph.t) : unit =
go_body ~parent_node graph
and go_variants ~parent_node (x, variants) =
List.iter
(fun (Driver.{ c; fresh_vars; original_vars = _ }, (binding, graph)) ->
(fun (Driver.{ c; fresh_vars; source_vars = _ }, (binding, graph)) ->
let attrs =
Symbol.(
label "%s=%s(%s)" (to_string x) (to_string c) (comma_sep fresh_vars))
Expand Down

0 comments on commit 9801ab7

Please sign in to comment.