Skip to content

Commit

Permalink
Enable the specification of special @leaf captures
Browse files Browse the repository at this point in the history
These captures are treated as leaves for the purpose of the flattening.
  • Loading branch information
Erin van der Veen committed Jun 9, 2022
1 parent 920872f commit cd118e9
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 26 deletions.
5 changes: 5 additions & 0 deletions languages/queries/json.scm
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
; Sometimes we want to indicate that certain parts of our source text should
; not be formatted, but taken as is. We use the leaf capture name to inform the
; tool of this.
(string) @leaf

; We want every object and array to have the { start a newline. So we match on
; the named object/array followed by the first anonymous node { or [.
(object
Expand Down
83 changes: 58 additions & 25 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use pretty::RcDoc;
use std::collections::BTreeSet;
use std::{error::Error, fs, io};
use tree_sitter::{Node, Parser, Query, QueryCursor};
use tree_sitter::{Node, Parser, Query, QueryCursor, QueryMatches};
use tree_sitter_json::language;

static TEST_FILE: &str = "tests/json.json";
Expand All @@ -19,51 +20,62 @@ enum Atom {
}

/// Given a node, returns the id of the first leaf in the subtree.
fn first_leaf_id(node: Node) -> usize {
fn first_leaf(node: Node) -> Node {
if node.child_count() == 0 {
node.id()
node
} else {
first_leaf_id(node.child(0).unwrap())
first_leaf(node.child(0).unwrap())
}
}

/// Given a node, returns the id of the last leaf in the subtree.
fn last_leaf_id(node: Node) -> usize {
fn last_leaf(node: Node) -> Node {
let nr_children = node.child_count();
if nr_children == 0 {
node.id()
node
} else {
last_leaf_id(node.child(nr_children - 1).unwrap())
last_leaf(node.child(nr_children - 1).unwrap())
}
}

fn collect_leafs<'a>(node: Node, atoms: &mut Vec<Atom>, source: &'a [u8]) {
if node.child_count() == 0 {
fn collect_leafs<'a>(
node: Node,
atoms: &mut Vec<Atom>,
source: &'a [u8],
specified_leaf_nodes: &BTreeSet<usize>,
) {
if node.child_count() == 0 || specified_leaf_nodes.contains(&node.id()) {
atoms.push(Atom::Leaf {
content: String::from(node.utf8_text(source).expect("Source file not valid utf8")),
id: node.id(),
});
} else {
for child in node.children(&mut node.walk()) {
collect_leafs(child, atoms, source)
collect_leafs(child, atoms, source, &specified_leaf_nodes)
}
}
}

/// Finds the matching node in the atoms and returns the index
/// TODO: Error
fn find_node(wanted_id: usize, atoms: &mut Vec<Atom>) -> usize {
for (i, node) in atoms.iter().enumerate() {
match node {
Atom::Leaf { id, .. } => {
if *id == wanted_id {
return i;
fn find_node(node: Node, atoms: &mut Vec<Atom>) -> usize {
let mut target_node = node;
loop {
for (i, node) in atoms.iter().enumerate() {
match node {
Atom::Leaf { id, .. } => {
if *id == target_node.id() {
return i;
}
}
_ => continue,
}
_ => continue,
}
target_node = match node.parent() {
Some(p) => p,
None => unreachable!(),
}
}
unreachable!()
}

fn main() -> Result<(), Box<dyn Error>> {
Expand All @@ -85,14 +97,18 @@ fn main() -> Result<(), Box<dyn Error>> {
let source = content.as_bytes();
let query = Query::new(json, query_str).expect("Error parsing query file");

// The Flattening: collects all terminal nodes of the tree-sitter tree in a Vec
let mut atoms: Vec<Atom> = Vec::new();
collect_leafs(root, &mut atoms, source);
// Find the ids of all tree-sitter nodes that were identified as a leaf
// We want to avoid recursing into them in the collect_leafs function.
let specified_leaf_nodes: BTreeSet<usize> = collect_leaf_ids(&query, root, source);

// Match queries
let mut cursor = QueryCursor::new();
let matches = cursor.matches(&query, root, source);

// The Flattening: collects all terminal nodes of the tree-sitter tree in a Vec
let mut atoms: Vec<Atom> = Vec::new();
collect_leafs(root, &mut atoms, source, &specified_leaf_nodes);

// Formatting
for m in matches {
for c in m.captures {
Expand All @@ -109,6 +125,23 @@ fn main() -> Result<(), Box<dyn Error>> {
Ok(())
}

/// Runs `query` over the tree rooted at `root` and gathers the ids of every
/// node captured under the name `leaf`.
///
/// The returned set is empty when no capture of that name matched.
fn collect_leaf_ids<'a>(query: &Query, root: Node, source: &'a [u8]) -> BTreeSet<usize> {
    // TODO: Should probably use the same cursor as above
    let mut cursor = QueryCursor::new();
    let capture_names = query.capture_names();

    cursor
        .matches(query, root, source)
        .flat_map(|found| found.captures.iter())
        .filter(|capture| capture_names[capture.index as usize] == "leaf")
        .map(|capture| capture.node.id())
        .collect()
}

fn atoms_to_doc<'a>(i: &mut usize, atoms: &'a Vec<Atom>) -> RcDoc<'a, ()> {
let mut doc = RcDoc::nil();
while *i < atoms.len() {
Expand Down Expand Up @@ -148,14 +181,14 @@ fn resolve_capture(name: String, atoms: &mut Vec<Atom>, node: Node) {
}

/// Inserts `atom` into `atoms` directly before the atom that `find_node`
/// locates for the first leaf of `node`.
fn atoms_prepend(atom: Atom, node: Node, atoms: &mut Vec<Atom>) {
    // Resolve the position first; `atoms` cannot be borrowed for the lookup
    // and the insertion within the same expression.
    let insert_at = find_node(first_leaf(node), atoms);
    atoms.insert(insert_at, atom);
}

fn atoms_append(atom: Atom, node: Node, atoms: &mut Vec<Atom>) {
let id = last_leaf_id(node);
let index = find_node(id, atoms);
let target_node = last_leaf(node);
let index = find_node(target_node, atoms);
if index > atoms.len() {
atoms.push(atom);
} else {
Expand Down
2 changes: 1 addition & 1 deletion tests/json.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
{
"category": "computing science",
"id": 1,
"question": "What is the type of the function \"f = uncurry . const\"",
"question": "What is the type of the function \"f = uncurry . const\"?",
"options": [
"f :: (b -> c) -> (a, b) -> c",
"f :: (a -> b1 -> c) -> b2 -> (a, b1) -> c"
Expand Down

0 comments on commit cd118e9

Please sign in to comment.