Skip to content

Commit

Permalink
feat: initial noir support (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
olehmisar authored and ewynx committed Sep 19, 2024
1 parent 5315753 commit 8ded16e
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 2 deletions.
12 changes: 10 additions & 2 deletions packages/compiler/src/bin/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ enum Commands {
Decomposed {
#[arg(short, long)]
decomposed_regex_path: String,
#[arg(short, long)]
#[arg(long)]
halo2_dir_path: Option<String>,
#[arg(short, long)]
circom_file_path: Option<String>,
#[arg(short, long)]
template_name: Option<String>,
#[arg(long)]
noir_file_path: Option<String>,
#[arg(short, long)]
gen_substrs: Option<bool>,
},
Expand All @@ -74,12 +76,14 @@ enum Commands {
raw_regex: String,
#[arg(short, long)]
substrs_json_path: Option<String>,
#[arg(short, long)]
#[arg(long)]
halo2_dir_path: Option<String>,
#[arg(short, long)]
circom_file_path: Option<String>,
#[arg(short, long)]
template_name: Option<String>,
#[arg(long)]
noir_file_path: Option<String>,
#[arg(short, long)]
gen_substrs: Option<bool>,
},
Expand All @@ -99,6 +103,7 @@ fn process_decomposed(cli: Cli) {
halo2_dir_path,
circom_file_path,
template_name,
noir_file_path,
gen_substrs,
} = cli.command
{
Expand All @@ -107,6 +112,7 @@ fn process_decomposed(cli: Cli) {
halo2_dir_path.as_deref(),
circom_file_path.as_deref(),
template_name.as_deref(),
noir_file_path.as_deref(),
gen_substrs,
) {
eprintln!("Error: {}", e);
Expand All @@ -122,6 +128,7 @@ fn process_raw(cli: Cli) {
halo2_dir_path,
circom_file_path,
template_name,
noir_file_path,
gen_substrs,
} = cli.command
{
Expand All @@ -131,6 +138,7 @@ fn process_raw(cli: Cli) {
halo2_dir_path.as_deref(),
circom_file_path.as_deref(),
template_name.as_deref(),
noir_file_path.as_deref(),
gen_substrs,
) {
eprintln!("Error: {}", e);
Expand Down
11 changes: 11 additions & 0 deletions packages/compiler/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod circom;
mod errors;
mod halo2;
mod noir;
mod regex;
mod structs;
mod wasm;
Expand All @@ -9,6 +10,7 @@ use circom::gen_circom_template;
use errors::CompilerError;
use halo2::gen_halo2_tables;
use itertools::Itertools;
use noir::gen_noir_fn;
use regex::{create_regex_and_dfa_from_str_and_defs, get_regex_and_dfa};
use std::{fs::File, path::PathBuf};
use structs::{DecomposedRegexConfig, RegexAndDFA, SubstringDefinitionsJson};
Expand Down Expand Up @@ -55,6 +57,7 @@ fn generate_outputs(
halo2_dir_path: Option<&str>,
circom_file_path: Option<&str>,
circom_template_name: Option<&str>,
noir_file_path: Option<&str>,
num_public_parts: usize,
gen_substrs: bool,
) -> Result<(), CompilerError> {
Expand Down Expand Up @@ -86,6 +89,10 @@ fn generate_outputs(
)?;
}

if let Some(noir_file_path) = noir_file_path {
gen_noir_fn(regex_and_dfa, &PathBuf::from(noir_file_path))?;
}

Ok(())
}

Expand All @@ -107,6 +114,7 @@ pub fn gen_from_decomposed(
halo2_dir_path: Option<&str>,
circom_file_path: Option<&str>,
circom_template_name: Option<&str>,
noir_file_path: Option<&str>,
gen_substrs: Option<bool>,
) -> Result<(), CompilerError> {
let mut decomposed_regex_config: DecomposedRegexConfig =
Expand All @@ -126,6 +134,7 @@ pub fn gen_from_decomposed(
halo2_dir_path,
circom_file_path,
circom_template_name,
noir_file_path,
num_public_parts,
gen_substrs,
)?;
Expand Down Expand Up @@ -153,6 +162,7 @@ pub fn gen_from_raw(
halo2_dir_path: Option<&str>,
circom_file_path: Option<&str>,
template_name: Option<&str>,
noir_file_path: Option<&str>,
gen_substrs: Option<bool>,
) -> Result<(), CompilerError> {
let substrs_defs_json = load_substring_definitions_json(substrs_json_path)?;
Expand All @@ -167,6 +177,7 @@ pub fn gen_from_raw(
halo2_dir_path,
circom_file_path,
template_name,
noir_file_path,
num_public_parts,
gen_substrs,
)?;
Expand Down
115 changes: 115 additions & 0 deletions packages/compiler/src/noir.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use std::{collections::HashSet, fs::File, io::Write, iter::FromIterator, path::Path};

use itertools::Itertools;

use crate::structs::RegexAndDFA;

const ACCEPT_STATE_ID: &str = "accept";

/// Renders the Noir matcher for `regex_and_dfa` and writes it to `path`.
///
/// The file at `path` is created, or truncated if it already exists, and
/// filled with the source text produced by `to_noir_fn`.
///
/// # Errors
///
/// Returns the underlying `std::io::Error` if the file cannot be created,
/// written, or flushed.
pub fn gen_noir_fn(regex_and_dfa: &RegexAndDFA, path: &Path) -> Result<(), std::io::Error> {
    let generated_source = to_noir_fn(regex_and_dfa);
    let mut output = File::create(path)?;
    output.write_all(generated_source.as_bytes())?;
    // Surface any flush failure to the caller instead of losing it on drop.
    output.flush()
}

/// Builds the Noir source code of a DFA-based matcher for `regex_and_dfa`.
///
/// The generated source contains:
/// * `make_lookup_table`, a `comptime` function that fills a flat table
///   indexed by `state * 256 + byte` with the next state. Entries that have
///   no transition keep the initial value `0`.
/// * a global `table` initialised from that function.
/// * `regex_match`, which walks the input bytes through the table starting
///   from state `0` and asserts that the final state is an accept state.
///
/// Accept states get a self-loop for every byte that has no explicit
/// transition, so once such a state is reached the remaining input keeps the
/// automaton there.
///
/// The regex pattern is embedded in a `//` comment of the generated code.
/// Carriage returns and line feeds in the pattern are written as the escape
/// sequences `\r` and `\n`, so a multi-line pattern cannot terminate the
/// comment early and leak into the generated source.
///
/// # Panics
///
/// Panics if the DFA has no state whose type is `ACCEPT_STATE_ID`.
fn to_noir_fn(regex_and_dfa: &RegexAndDFA) -> String {
    let accept_state_ids = {
        let accept_states = regex_and_dfa
            .dfa
            .states
            .iter()
            .filter(|s| s.state_type == ACCEPT_STATE_ID)
            .map(|s| s.state_id)
            .collect_vec();
        assert!(!accept_states.is_empty(), "no accept states");
        accept_states
    };

    const BYTE_SIZE: u32 = 256; // u8 size
    let mut lookup_table_body = String::new();

    // curr_state + char_code -> next_state
    let mut rows: Vec<(usize, u8, usize)> = vec![];

    for state in regex_and_dfa.dfa.states.iter() {
        // Explicit transitions of this state.
        for (&tran_next_state_id, tran) in &state.transitions {
            for &char_code in tran {
                rows.push((state.state_id, char_code, tran_next_state_id));
            }
        }
        // Accept states absorb every byte that has no explicit transition.
        if state.state_type == ACCEPT_STATE_ID {
            let existing_char_codes = &state
                .transitions
                .iter()
                .flat_map(|(_, tran)| tran.iter().copied().collect_vec())
                .collect::<HashSet<_>>();
            let all_char_codes = HashSet::from_iter(0..=255);
            let mut char_codes = all_char_codes.difference(existing_char_codes).collect_vec();
            char_codes.sort(); // to be deterministic
            for &char_code in char_codes {
                rows.push((state.state_id, char_code, state.state_id));
            }
        }
    }

    for (curr_state_id, char_code, next_state_id) in rows {
        lookup_table_body +=
            &format!("table[{curr_state_id} * {BYTE_SIZE} + {char_code}] = {next_state_id};\n",);
    }

    lookup_table_body = indent(&lookup_table_body);
    let table_size = BYTE_SIZE as usize * regex_and_dfa.dfa.states.len();
    let lookup_table = format!(
        r#"
comptime fn make_lookup_table() -> [Field; {table_size}] {{
let mut table = [0; {table_size}];
{lookup_table_body}
table
}}
"#
    );

    let final_states_condition_body = accept_state_ids
        .iter()
        .map(|id| format!("(s == {id})"))
        .collect_vec()
        .join(" | ");

    // The pattern lands inside a `//` line comment of the generated code; a raw
    // CR or LF would end that comment and turn the rest of the pattern into code.
    let regex_pattern_comment = regex_and_dfa
        .regex_pattern
        .replace('\r', "\\r")
        .replace('\n', "\\n");

    let fn_body = format!(
        r#"
global table = comptime {{ make_lookup_table() }};
pub fn regex_match<let N: u32>(input: [u8; N]) {{
// regex: {regex_pattern}
let mut s = 0;
for i in 0..input.len() {{
s = table[s * {BYTE_SIZE} + input[i] as Field];
}}
assert({final_states_condition_body}, f"no match: {{s}}");
}}
"#,
        regex_pattern = regex_pattern_comment,
    );
    format!(
        r#"
{fn_body}
{lookup_table}
"#
    )
    .trim()
    .to_owned()
}

/// Prefixes every non-blank line of `s` with the indentation string.
///
/// Lines are separated on `'\n'` only. Lines that are empty or consist solely
/// of whitespace are copied through unchanged, and the line separators
/// (including a trailing one) are preserved.
fn indent(s: &str) -> String {
    let mut indented = String::with_capacity(s.len());
    for (line_no, line) in s.split('\n').enumerate() {
        if line_no > 0 {
            indented.push('\n');
        }
        if !line.trim().is_empty() {
            indented.push_str(" ");
        }
        indented.push_str(line);
    }
    indented
}

0 comments on commit 8ded16e

Please sign in to comment.