Skip to content

Commit

Permalink
Finished up remaining touches to the implementation. Now I need to fi…
Browse files Browse the repository at this point in the history
…nish unit tests and add a few more
  • Loading branch information
MicroProofs committed Oct 31, 2024
1 parent a50657e commit 6593752
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 84 deletions.
42 changes: 26 additions & 16 deletions crates/aiken-lang/src/gen_uplc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2493,22 +2493,24 @@ impl<'a> CodeGenerator<'a> {
)
});

let y = AirTree::when(
let when_air_clauses = AirTree::when(
test_subject_name,
return_tipo.clone(),
current_tipo.clone(),
AirTree::local_var(current_subject_name, current_tipo.clone()),
clauses,
);

let x = builtins_to_add.to_air(
builtins_to_add.to_air(
// The only reason I pass this in is to ensure I signal
// whether or not constr_fields_exposer was used. I could
// probably optimize this part out to simplify codegen in
// the future
&mut self.special_functions,
prev_subject_name,
prev_tipo,
y,
);

x
when_air_clauses,
)
}
DecisionTree::ListSwitch {
path,
Expand Down Expand Up @@ -2573,10 +2575,9 @@ impl<'a> CodeGenerator<'a> {
tree.1.clone()
};

let last_case = cases.last().unwrap().0.clone();

let builtins_for_pattern =
builtins_path.merge(Builtins::new_from_list_case(last_case.clone()));
let builtins_for_pattern = builtins_path.merge(Builtins::new_from_list_case(
CaseTest::List(longest_pattern),
));

stick_set.diff_union_builtins(builtins_for_pattern.clone());

Expand Down Expand Up @@ -2655,6 +2656,8 @@ impl<'a> CodeGenerator<'a> {
format!("{}_{}", subject_name, builtins_for_pattern.to_string())
};

// TODO: change this in the future to use the Builtins to_string method
// to ensure future changes don't break things
let next_tail_name = Some(format!("{}_tail", tail_name));

let then = self.handle_decision_tree(
Expand All @@ -2675,29 +2678,34 @@ impl<'a> CodeGenerator<'a> {
false,
);

// since we iterate over the list cases in reverse
// We pop off a builtin to make it easier to get the name of
// prev_tested list case since each name is based off the builtins
builtins_for_pattern.pop();

(builtins_for_pattern, acc)
}
},
);

let y = AirTree::when(
let when_list_cases = AirTree::when(
current_subject_name.clone(),
return_tipo.clone(),
current_tipo.clone(),
AirTree::local_var(current_subject_name, current_tipo.clone()),
list_clauses.1,
);

let x = builtins_to_add.to_air(
builtins_to_add.to_air(
// The only reason I pass this in is to ensure I signal
// whether or not constr_fields_exposer was used. I could
// probably optimize this part out to simplify codegen in
// the future
&mut self.special_functions,
prev_subject_name,
prev_tipo,
y,
);

x
when_list_cases,
)
}
DecisionTree::HoistedLeaf(name, args) => {
let air_args = args
Expand Down Expand Up @@ -2742,6 +2750,8 @@ impl<'a> CodeGenerator<'a> {
assign
})
.collect_vec(),
// The one reason we have to pass in mutable self
// So we can build the TypedExpr into Air
self.build(then, module_build_name, &[]),
true,
),
Expand Down
132 changes: 64 additions & 68 deletions crates/aiken-lang/src/gen_uplc/decision_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ impl<'a> DecisionTree<'a> {
.join("\n")
}

// Please help me with this the way nesting works
// has me baffled
fn to_doc(&self) -> RcDoc<()> {
match self {
DecisionTree::Switch {
Expand Down Expand Up @@ -382,7 +384,7 @@ impl<'a> DecisionTree<'a> {
}

/// For fun I decided to do this without recursion
/// It doesn't look to bad lol
/// It doesn't look too bad lol
fn get_hoist_paths<'b>(&self, names: Vec<&'b String>) -> IndexMap<&'b String, Scope> {
let mut prev = vec![];

Expand Down Expand Up @@ -461,35 +463,37 @@ impl<'a> DecisionTree<'a> {
scope_for_name.common_ancestor(&current_path);
}
}
// These are not generated by do_build_tree, but
// added afterwards
DecisionTree::HoistThen { .. } => unreachable!(),
}

if let Some(action) = prev.pop() {
match action {
Marker::Pop => {
current_path.pop();
}
Marker::Push(p, dec_tree) => {
current_path.push(p);
match prev.pop() {
Some(Marker::Pop) => {
current_path.pop();
}
Some(Marker::Push(p, dec_tree)) => {
current_path.push(p);

tree = dec_tree;
}
Marker::PopPush(p, dec_tree) => {
current_path.pop();
tree = dec_tree;
}
Some(Marker::PopPush(p, dec_tree)) => {
current_path.pop();

current_path.push(p);
current_path.push(p);

tree = dec_tree;
}
tree = dec_tree;
}
} else {
break;
}
// Break out of loop and return the map with all names properly
// scoped
None => break,
};
}

scope_map
}

// I did recursion here since we need mutable pointers to modify the tree
fn hoist_by_path(
&mut self,
current_path: &mut Scope,
Expand Down Expand Up @@ -542,6 +546,8 @@ impl<'a> DecisionTree<'a> {
}

loop {
// We sorted name_paths before passing it in.
// This ensures we will visit each node in the order we would pop it off
if let Some(name_path) = name_paths.pop() {
if name_path.1 == *current_path {
let (assigns, then) = hoistables.remove(&name_path.0).unwrap();
Expand Down Expand Up @@ -605,6 +611,8 @@ impl<'a, 'b> TreeGen<'a, 'b> {
.iter()
.enumerate()
.map(|(index, clause)| {
// Assigns are split out from patterns so they can be handled
// outside of the tree algorithm
let (assign, row_items) =
self.map_pattern_to_row(&clause.pattern, subject_tipo, vec![]);

Expand Down Expand Up @@ -663,8 +671,8 @@ impl<'a, 'b> TreeGen<'a, 'b> {
.iter()
.all(|row| { row.columns.len() == column_length }));

let occurrence_col = highest_occurrence(&matrix, column_length);
// Find which column has the most important pattern
let occurrence_col = highest_occurrence(&matrix, column_length);

let Some(occurrence_col) = occurrence_col else {
// No more patterns to match on so we grab the first default row and return that
Expand All @@ -676,6 +684,9 @@ impl<'a, 'b> TreeGen<'a, 'b> {
unreachable!()
};

// This is just to prevent repeated assigning clones for the same fallback
// used in multiple places
// So we could just overwrite it everytime too.
if assigns.is_empty() {
*assigns = row.assigns.clone();
}
Expand All @@ -688,6 +699,7 @@ impl<'a, 'b> TreeGen<'a, 'b> {
let mut has_list_pattern = false;

// List patterns are special so we need more information on length
// of the longest pattern with and without a tail
matrix.rows.iter().for_each(|item| {
let col = &item.columns[occurrence_col];

Expand Down Expand Up @@ -722,17 +734,21 @@ impl<'a, 'b> TreeGen<'a, 'b> {
}
});

// Since occurrence_col is Some it means there is a
// pattern to match on so we also must have a path to the object to test
// for that pattern
let path = matrix
.rows
.get(0)
.unwrap()
.columns
.get(occurrence_col)
.map(|col| col.path.clone())
.unwrap_or(vec![]);
.unwrap();

let specialized_tipo = get_tipo_by_path(subject_tipo.clone(), &path);

// Time to split on the matrices based on case to test for or lack of
let (default_matrix, specialized_matrices) = matrix.rows.into_iter().fold(
(vec![], vec![]),
|(mut default_matrix, mut case_matrices): (Vec<Row>, Vec<(CaseTest, Vec<Row>)>),
Expand Down Expand Up @@ -763,6 +779,7 @@ impl<'a, 'b> TreeGen<'a, 'b> {
.map(|elem| match elem {
Position::First((index, element))
| Position::Middle((index, element))
// Impossible to have a list pattern of only tail element
| Position::Only((index, element)) => {
let mut item_path = col.path.clone();

Expand Down Expand Up @@ -822,6 +839,7 @@ impl<'a, 'b> TreeGen<'a, 'b> {
)
}
Pattern::Tuple { .. } | Pattern::Pair { .. } | Pattern::Assign { .. } => {
// These patterns are fully expanded out when mapping pattern to row
unreachable!("{:#?}", col.pattern)
}
};
Expand All @@ -836,6 +854,8 @@ impl<'a, 'b> TreeGen<'a, 'b> {
// Add inner patterns to existing row
let mut new_cols = remaining_patts.into_iter().flat_map(|x| x.1).collect_vec();

// To align number of columns we pop off the tail since it can
// never include a pattern besides wild card
if matches!(case, CaseTest::ListWithTail(_)) {
new_cols.pop();
}
Expand Down Expand Up @@ -866,20 +886,24 @@ impl<'a, 'b> TreeGen<'a, 'b> {
matrix.push(row);
}
});
} else if let CaseTest::ListWithTail(case_length) = case {
} else if let CaseTest::ListWithTail(tail_case_length) = case {
// For lists with tail it's a special case where we also add it to existing patterns
// all the way to the longest element. The reason being that each list size greater
// than the list with tail could also match with could also match depending on the inner pattern.
// all the way to the longest list pattern with no tail. The reason being that each list greater
// than the list with tail pattern could also match with the with list that has x elements + any extra afterwards
// See tests below for an example

let longest_elems_with_tail = longest_elems_with_tail.unwrap();

// You can have a match with all list patterns having a tail or wild card
if let Some(longest_elems_no_tail) = longest_elems_no_tail {
for elem_count in case_length..=longest_elems_no_tail {
for elem_count in tail_case_length..=longest_elems_no_tail {
let case = CaseTest::List(elem_count);

let mut row = row.clone();

for _ in 0..(elem_count - case_length) {
for _ in 0..(elem_count - tail_case_length) {
row.columns
.insert(case_length, self.wild_card_pattern.clone());
.insert(tail_case_length, self.wild_card_pattern.clone());
}

self.insert_case(
Expand All @@ -892,18 +916,15 @@ impl<'a, 'b> TreeGen<'a, 'b> {
}
}

let Some(longest_elems_with_tail) = longest_elems_with_tail else {
unreachable!()
};

for elem_count in case_length..=longest_elems_with_tail {
// Comment above applies here
for elem_count in tail_case_length..=longest_elems_with_tail {
let case = CaseTest::ListWithTail(elem_count);

let mut row = row.clone();

for _ in 0..(elem_count - case_length) {
for _ in 0..(elem_count - tail_case_length) {
row.columns
.insert(case_length, self.wild_card_pattern.clone());
.insert(tail_case_length, self.wild_card_pattern.clone());
}

self.insert_case(
Expand Down Expand Up @@ -932,26 +953,16 @@ impl<'a, 'b> TreeGen<'a, 'b> {
rows: default_matrix,
};

if has_list_pattern {
// Since the list_tail case might cover the rest of the possible matches extensively
// then fallback is optional here
let fallback_option = if default_matrix.rows.is_empty() {
None
} else {
Some(
self.do_build_tree(
subject_name,
subject_tipo,
// Since everything after this point had a wild card on or above
// the row for the selected column in front. Then we ignore the
// cases and continue to check other columns.
default_matrix,
then_map,
)
let fallback_option = if default_matrix.rows.is_empty() {
None
} else {
Some(
self.do_build_tree(subject_name, subject_tipo, default_matrix, then_map)
.into(),
)
};
)
};

if has_list_pattern {
let (tail_cases, cases): (Vec<_>, Vec<_>) = specialized_matrices
.into_iter()
.partition(|(case, _)| matches!(case, CaseTest::ListWithTail(_)));
Expand Down Expand Up @@ -989,23 +1000,6 @@ impl<'a, 'b> TreeGen<'a, 'b> {
default: fallback_option,
}
} else {
let fallback_option = if default_matrix.rows.is_empty() {
None
} else {
Some(
self.do_build_tree(
subject_name,
subject_tipo,
// Since everything after this point had a wild card on or above
// the row for the selected column in front. Then we ignore the
// cases and continue to check other columns.
default_matrix,
then_map,
)
.into(),
)
};

DecisionTree::Switch {
path,
cases: specialized_matrices
Expand Down Expand Up @@ -1277,6 +1271,8 @@ fn highest_occurrence(matrix: &PatternMatrix, column_length: usize) -> Option<us
}
});

// This condition is only true if and only if
// all columns on the top row are wild cards
if highest_occurrence.1 == 0 {
None
} else {
Expand Down

0 comments on commit 6593752

Please sign in to comment.