Skip to content

Commit

Permalink
Migrate to syn v2.0 (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
s-arash authored Mar 31, 2024
1 parent eb65f13 commit 0ee6238
Show file tree
Hide file tree
Showing 11 changed files with 107 additions and 74 deletions.
6 changes: 3 additions & 3 deletions ascent_macro/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ proc-macro = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
syn = { version = "1.0.109", features = ["derive", "full", "extra-traits", "visit-mut"] }
syn = { version = "2.0.57", features = ["derive", "full", "extra-traits", "visit-mut"] }
quote = "1.0"
ascent_base = { workspace = true }
proc-macro2 = "1.0"
itertools = "0.12.0"
petgraph = "0.6.0"
derive-syn-parse = "0.1.5"
derive-syn-parse = "0.2.0"
lazy_static = "1.4.0"
duplicate = "0.4"
duplicate = { version = "1.0.0", default-features = false }

[dev-dependencies]
ascent = { path = "../ascent" }
Expand Down
5 changes: 1 addition & 4 deletions ascent_macro/src/ascent_codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1128,10 +1128,7 @@ fn head_clauses_structs_and_update_code(rule: &MirRule, scc: &MirScc, mir: &Asce
}
}
(
quote!{
// #(#struct_defs)*
// #(#rel_data_vars)*
},
quote!{},
quote!{#(#add_rows)*}
)
}
Expand Down
19 changes: 11 additions & 8 deletions ascent_macro/src/ascent_hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@ impl AscentConfig {
const INTER_RULE_PARALLELISM_ATTR: &'static str = "inter_rule_parallelism";

pub fn new(attrs: Vec<Attribute>, is_parallel: bool) -> syn::Result<AscentConfig> {
let include_rule_times = attrs.iter().any(|attr| attr.path.is_ident(Self::MEASURE_RULE_TIMES_ATTR));
let generate_run_partial = attrs.iter().any(|attr| attr.path.is_ident(Self::GENERATE_RUN_TIMEOUT_ATTR));
let inter_rule_parallelism = attrs.iter().find(|attr| attr.path.is_ident(Self::INTER_RULE_PARALLELISM_ATTR));
let include_rule_times = attrs.iter().find(|attr| attr.meta.path().is_ident(Self::MEASURE_RULE_TIMES_ATTR))
.map(|attr| attr.meta.require_path_only()).transpose()?.is_some();
let generate_run_partial = attrs.iter().find(|attr| attr.meta.path().is_ident(Self::GENERATE_RUN_TIMEOUT_ATTR))
.map(|attr| attr.meta.require_path_only()).transpose()?.is_some();
let inter_rule_parallelism = attrs.iter().find(|attr| attr.meta.path().is_ident(Self::INTER_RULE_PARALLELISM_ATTR))
.map(|attr| attr.meta.require_path_only()).transpose()?;

let recognized_attrs =
[Self::MEASURE_RULE_TIMES_ATTR, Self::GENERATE_RUN_TIMEOUT_ATTR, Self::INTER_RULE_PARALLELISM_ATTR, REL_DS_ATTR];
for attr in attrs.iter() {
if !recognized_attrs.iter().any(|recognized_attr| attr.path.is_ident(recognized_attr)) {
if !recognized_attrs.iter().any(|recognized_attr| attr.meta.path().is_ident(recognized_attr)) {
return Err(Error::new_spanned(attr,
format!("unrecognized attribute. recognized attributes are: {}",
recognized_attrs.join(", "))));
Expand Down Expand Up @@ -236,7 +239,7 @@ pub(crate) fn compile_ascent_program_to_hir(prog: &AscentProgram, is_parallel: b
rel_identity.clone(),
RelationMetadata {
initialization: rel.initialization.clone().map(Rc::new),
attributes: Rc::new(rel.attrs.iter().filter(|attr| attr.path.get_ident().map_or(true, |ident| !RECOGNIIZED_REL_ATTRS.iter().any(|ra| ident == ra))).cloned().collect_vec()),
attributes: Rc::new(rel.attrs.iter().filter(|attr| attr.meta.path().get_ident().map_or(true, |ident| !RECOGNIIZED_REL_ATTRS.iter().any(|ra| ident == ra))).cloned().collect_vec()),
ds_macro_path: ds_attribute.path,
ds_macro_args: ds_attribute.args
}
Expand Down Expand Up @@ -275,15 +278,15 @@ pub(crate) fn compile_ascent_program_to_hir(prog: &AscentProgram, is_parallel: b

fn get_ds_attr(attrs: &[Attribute]) -> syn::Result<Option<DsAttributeContents>> {
let ds_attrs = attrs.iter()
.filter(|attr| attr.path.get_ident().map_or(false, |ident| ident == REL_DS_ATTR))
.filter(|attr| attr.meta.path().get_ident().map_or(false, |ident| ident == REL_DS_ATTR))
.collect_vec();
match &ds_attrs[..] {
[] => Ok(None),
[attr] => {
let res = syn::parse2::<DsAttributeContents>(attr.tokens.clone())?;
let res = syn::parse2::<DsAttributeContents>(attr.meta.require_list()?.tokens.clone())?;
Ok(Some(res))
},
[_attr1, attr2, ..] => Err(Error::new(attr2.bracket_token.span, "multiple `ds` attributes specified")),
[_attr1, attr2, ..] => Err(Error::new(attr2.bracket_token.span.join(), "multiple `ds` attributes specified")),
}
}

Expand Down
25 changes: 15 additions & 10 deletions ascent_macro/src/ascent_syntax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl Parse for RelationNode {
let name : Ident = input.parse()?;
let content;
parenthesized!(content in input);
let field_types = content.parse_terminated(Type::parse)?;
let field_types = content.parse_terminated(Type::parse, Token![,])?;
let initialization = if input.peek(Token![=]) {
input.parse::<Token![=]>()?;
Some(input.parse::<Expr>()?)
Expand Down Expand Up @@ -186,6 +186,7 @@ impl Parse for DisjunctionNode {
#[derive(Parse, Clone)]
pub struct GeneratorNode {
pub for_keyword: Token![for],
#[call(Pat::parse_multi)]
pub pattern: Pat,
pub in_keyword: Token![in],
pub expr: Expr
Expand All @@ -198,7 +199,7 @@ pub struct BodyClauseNode {
pub cond_clauses: Vec<CondClause>
}

#[derive(Parse, Clone, PartialEq, Eq)]
#[derive(Parse, Clone, PartialEq, Eq, Debug)]
pub enum BodyClauseArg {
#[peek(Token![?], name = "Pattern arg")]
Pat(ClauseArgPattern),
Expand Down Expand Up @@ -240,17 +241,19 @@ impl ToTokens for BodyClauseArg {
}
}

#[derive(Parse, Clone, PartialEq, Eq)]
#[derive(Parse, Clone, PartialEq, Eq, Debug)]
pub struct ClauseArgPattern {
pub huh_token: Token![?],
#[call(Pat::parse_multi)]
pub pattern : Pat,
}

#[derive(Parse, Clone, PartialEq, Eq, Hash, Debug)]
pub struct IfLetClause {
pub if_keyword: Token![if],
pub let_keyword: Token![let],
pub pattern: syn::Pat,
#[call(Pat::parse_multi)]
pub pattern: Pat,
pub eq_symbol : Token![=],
pub exp: syn::Expr,
}
Expand All @@ -264,7 +267,8 @@ pub struct IfClause {
#[derive(Parse, Clone, PartialEq, Eq, Hash, Debug)]
pub struct LetClause {
pub let_keyword: Token![let],
pub pattern: syn::Pat,
#[call(Pat::parse_multi)]
pub pattern: Pat,
pub eq_symbol : Token![=],
pub exp: syn::Expr,
}
Expand Down Expand Up @@ -326,7 +330,7 @@ impl Parse for BodyClauseNode{
let rel : Ident = input.parse()?;
let args_content;
parenthesized!(args_content in input);
let args = args_content.parse_terminated(BodyClauseArg::parse)?;
let args = args_content.parse_terminated(BodyClauseArg::parse, Token![,])?;
let mut cond_clauses = vec![];
while let Ok(cl) = input.parse(){
cond_clauses.push(cl);
Expand Down Expand Up @@ -381,14 +385,15 @@ impl Parse for HeadClauseNode{
let rel : Ident = input.parse()?;
let args_content;
parenthesized!(args_content in input);
let args = args_content.parse_terminated(Expr::parse)?;
let args = args_content.parse_terminated(Expr::parse, Token![,])?;
Ok(HeadClauseNode{rel, args})
}
}

#[derive(Clone, Parse)]
pub struct AggClauseNode {
pub agg_kw: kw::agg,
#[call(Pat::parse_multi)]
pub pat: Pat,
pub eq_token: Token![=],
pub aggregator: AggregatorNode,
Expand Down Expand Up @@ -591,8 +596,8 @@ pub(crate) struct DsAttributeContents {

impl Parse for DsAttributeContents {
fn parse(input: ParseStream) -> Result<Self> {
let content;
parenthesized!(content in input);
let content = input;
// parenthesized!(content in input);

let path = syn::Path::parse_mod_style(&content)?;
let args = if content.peek(Token![:]) {
Expand Down Expand Up @@ -972,7 +977,7 @@ fn rule_expand_macro_invocations(rule: RuleNode, macros: &HashMap<Ident, &MacroD
BodyItemNode::Disjunction(disj) => {
let new_disj: Punctuated<Result<_>, _> = punctuated_map(disj.disjuncts, |bis|{
let new_bis = punctuated_map(bis,|bi| {
body_item_expand_macros(bi, macros, gensym, depth - 1, Some(disj.paren.span))
body_item_expand_macros(bi, macros, gensym, depth - 1, Some(disj.paren.span.join()))
});
Ok(flatten_punctuated(punctuated_try_unwrap(new_bis)?))
});
Expand Down
66 changes: 42 additions & 24 deletions ascent_macro/src/syn_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ pub fn block_get_vars(block: &Block) -> Vec<Ident> {
pub fn pattern_get_vars(pat: &Pat) -> Vec<Ident> {
let mut res = vec![];
match pat {
Pat::Box(pat_box) => res.extend(pattern_get_vars(&pat_box.pat)),
Pat::Ident(pat_ident) => {
res.push(pat_ident.ident.clone());
if let Some(subpat) = &pat_ident.subpat {
Expand Down Expand Up @@ -70,7 +69,7 @@ pub fn pattern_get_vars(pat: &Pat) -> Vec<Ident> {
}
}
Pat::TupleStruct(tuple_strcut_pat) => {
for elem_pat in tuple_strcut_pat.pat.elems.iter(){
for elem_pat in tuple_strcut_pat.elems.iter(){
res.extend(pattern_get_vars(elem_pat));
}
},
Expand All @@ -93,7 +92,6 @@ pub fn pattern_visit_vars_mut(pat: &mut Pat, visitor: &mut dyn FnMut(&mut Ident)
};
}
match pat {
Pat::Box(pat_box) => visit!(&mut pat_box.pat),
Pat::Ident(pat_ident) => {
visitor(&mut pat_ident.ident);
if let Some(subpat) = &mut pat_ident.subpat {
Expand Down Expand Up @@ -127,7 +125,7 @@ pub fn pattern_visit_vars_mut(pat: &mut Pat, visitor: &mut dyn FnMut(&mut Ident)
}
}
Pat::TupleStruct(tuple_strcut_pat) => {
for elem_pat in tuple_strcut_pat.pat.elems.iter_mut(){
for elem_pat in tuple_strcut_pat.elems.iter_mut(){
visit!(elem_pat);
}
},
Expand All @@ -142,10 +140,12 @@ pub fn pattern_visit_vars_mut(pat: &mut Pat, visitor: &mut dyn FnMut(&mut Ident)

#[test]
fn test_pattern_get_vars(){
use syn::parse::Parser;

let pattern = quote! {
SomePair(x, (y, z))
};
let pat : syn::Pat = parse2(pattern).unwrap();
let pat = Pat::parse_single.parse2(pattern).unwrap();
assert_eq!(collect_set(["x", "y", "z"].iter().map(ToString::to_string)),
pattern_get_vars(&pat).into_iter().map(|id| id.to_string()).collect());

Expand All @@ -166,34 +166,58 @@ pub fn stmt_get_vars(stmt: &Stmt) -> (Vec<Ident>, Vec<Ident>) {
match stmt {
Stmt::Local(l) => {
bound_vars.extend(pattern_get_vars(&l.pat));
if let Some(init) = &l.init {used_vars.extend(expr_get_vars(&init.1))}
if let Some(init) = &l.init {
used_vars.extend(expr_get_vars(&init.expr));
if let Some(diverge) = &init.diverge {
used_vars.extend(expr_get_let_bound_vars(&diverge.1));
}
}
},
Stmt::Item(_) => {},
Stmt::Expr(e) => used_vars.extend(expr_get_vars(e)),
Stmt::Semi(e, _) => used_vars.extend(expr_get_vars(e))
Stmt::Expr(e, _) => used_vars.extend(expr_get_vars(e)),
Stmt::Macro(m) => {
eprintln!("WARNING: cannot determine variables of macro invocations. macro invocation:\n{}",
m.to_token_stream());
}
}
(bound_vars, used_vars)
}

pub fn stmt_visit_free_vars_mut(stmt: &mut Stmt, visitor: &mut dyn FnMut(&mut Ident)) {
match stmt {
Stmt::Local(l) => {
if let Some(init) = &mut l.init {expr_visit_free_vars_mut(&mut init.1, visitor)}
if let Some(init) = &mut l.init {
expr_visit_free_vars_mut(&mut init.expr, visitor);
if let Some(diverge) = &mut init.diverge {
expr_visit_free_vars_mut(&mut diverge.1, visitor);
}
}
},
Stmt::Item(_) => {},
Stmt::Expr(e) |
Stmt::Semi(e, _) => expr_visit_free_vars_mut(e, visitor)
Stmt::Expr(e, _) => expr_visit_free_vars_mut(e, visitor),
Stmt::Macro(m) => {
eprintln!("WARNING: cannot determine free variables of macro invocations. macro invocation:\n{}",
m.to_token_stream());
}
}
}

pub fn stmt_visit_free_vars(stmt: &Stmt, visitor: &mut dyn FnMut(& Ident)) {
match stmt {
Stmt::Local(l) => {
if let Some(init) = &l.init {expr_visit_free_vars(&init.1, visitor)}
if let Some(init) = &l.init {
expr_visit_free_vars(&init.expr, visitor);
if let Some(diverge) = &init.diverge {
expr_visit_free_vars(&diverge.1, visitor);
}
}
},
Stmt::Item(_) => {},
Stmt::Expr(e) |
Stmt::Semi(e, _) => expr_visit_free_vars(e, visitor)
Stmt::Expr(e, _) => expr_visit_free_vars(e, visitor),
Stmt::Macro(m) => {
eprintln!("WARNING: cannot determine free variables of macro invocations. macro invocation:\n{}",
m.to_token_stream());
}
}
}

Expand Down Expand Up @@ -246,18 +270,13 @@ pub fn expr_visit_free_vars_mbm(expr: reft([Expr]), visitor: &mut dyn FnMut(reft
visit!(assign.left);
visit!(assign.right)
},
Expr::AssignOp(assign_op) => {
visit!(assign_op.left);
visit!(assign_op.right)
},
Expr::Async(a) => block_visit_free_vars_mbm(reft([a.block]), visitor),
Expr::Await(a) => visit!(a.base),
Expr::Binary(b) => {
visit!(b.left);
visit!(b.right)
}
Expr::Block(b) => block_visit_free_vars_mbm(reft([b.block]), visitor),
Expr::Box(e) => expr_visit_free_vars_mbm(e.expr.deref_mbm(), visitor),
Expr::Break(b) => if let Some(b_e) = reft([b.expr]) {expr_visit_free_vars_mbm(b_e, visitor)},
Expr::Call(c) => {
visit!(c.func);
Expand Down Expand Up @@ -320,11 +339,11 @@ pub fn expr_visit_free_vars_mbm(expr: reft([Expr]), visitor: &mut dyn FnMut(reft
}
}
Expr::Range(r) => {
if let Some(from) = reft([r.from]) {
expr_visit_free_vars_mbm(from, visitor)
if let Some(start) = reft([r.start]) {
expr_visit_free_vars_mbm(start, visitor)
};
if let Some(to) = reft([r.to]) {
expr_visit_free_vars_mbm(to, visitor)
if let Some(end) = reft([r.end]) {
expr_visit_free_vars_mbm(end, visitor)
};
}
Expr::Reference(r) => visit!(r.expr),
Expand All @@ -350,7 +369,6 @@ pub fn expr_visit_free_vars_mbm(expr: reft([Expr]), visitor: &mut dyn FnMut(reft
expr_visit_free_vars_mbm(e, visitor)
}
}
Expr::Type(t) => visit!(t.expr),
Expr::Unary(u) => visit!(u.expr),
Expr::Unsafe(u) => block_visit_free_vars_mbm(reft([u.block]), visitor),
Expr::Verbatim(_) => {}
Expand Down
28 changes: 7 additions & 21 deletions ascent_macro/src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#![cfg(test)]
use petgraph::dot::{Config, Dot};
use proc_macro2::TokenStream;
use syn::parse2;

use crate::ascent_impl;

Expand Down Expand Up @@ -188,14 +187,14 @@ fn test_macro3() {

#[test]
fn test_macro_agg() {
let inp = quote!{
relation foo(i32, i32);
lattice bar(i32, i32, i32);
relation baz(i32, i32, i32);
let inp = quote! {
relation foo(i32);
relation bar(i32, i32, i32);
lattice baz(i32, i32);

baz(x, y, min_z) <--
foo(x, y),
agg min_z = min(z) in bar(x, y, z);
baz(x, min_z) <--
foo(x),
agg min_z = min(z) in bar(x, _, z);
};
write_to_scratchpad(inp);
}
Expand Down Expand Up @@ -543,16 +542,3 @@ fn test_macro_in_macro() {

write_to_scratchpad(inp);
}

#[test]
fn test_macro_item() {
let def = quote! {
macro foo($x: expr, $y: expr) {
$x + $y
}
};

let parsed = parse2::<syn::ItemMacro2>(def).unwrap();

println!("rules: {}", parsed.rules);
}
1 change: 1 addition & 0 deletions ascent_macro/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ pub fn pat_to_ident(pat: &Pat) -> Option<Ident> {

pub fn is_wild_card(expr: &Expr) -> bool {
match expr {
Expr::Infer(_) => true,
Expr::Verbatim(ts) => ts.to_string() == "_",
_ => false
}
Expand Down
Loading

0 comments on commit 0ee6238

Please sign in to comment.