diff --git a/Cargo.lock b/Cargo.lock index 93a474a8..b2a93a5d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -389,7 +389,6 @@ dependencies = [ "num", "ordered-float", "rustc-hash", - "serde_json", "smallvec", "symbol_table", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index 80ca644b..12f824e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,6 @@ bin = [ "graphviz", "dep:clap", "dep:env_logger", - "dep:serde_json", "dep:chrono", ] serde = ["egraph-serialize/serde"] @@ -42,38 +41,28 @@ wasm-bindgen = ["instant/wasm-bindgen", "dep:getrandom"] nondeterministic = [] [dependencies] +clap = { version = "4", features = ["derive"], optional = true } +egraph-serialize = { version = "0.2.0", default-features = false } +env_logger = { version = "0.10", optional = true } hashbrown = { version = "0.15" } +im-rc = "15.1.0" +im = "15.1.0" indexmap = "2.0" instant = "0.1" +lazy_static = "1.4" log = "0.4" +num = "0.4.3" +ordered-float = { version = "3.7" } rustc-hash = "1.1" +smallvec = "1.11" symbol_table = { version = "0.4.0", features = ["global"] } thiserror = "1" -lazy_static = "1.4" -num = "0.4.3" -smallvec = "1.11" - -egraph-serialize = { version = "0.2.0", default-features = false } - -# binary dependencies -clap = { version = "4", features = ["derive"], optional = true } -env_logger = { version = "0.10", optional = true } -serde_json = { version = "1.0.100", optional = true, features = [ - "preserve_order", -] } - -ordered-float = { version = "3.7" } # Need to add "js" feature for "graphviz-rust" to work in wasm -getrandom = { version = "0.2.10", features = ["js"], optional = true } - -im-rc = "15.1.0" -im = "15.1.0" +getrandom = { version = "0.2.10", optional = true, features = ["js"] } [build-dependencies] -chrono = { version = "0.4", default-features = false, optional = true, features = [ - "now", -] } +chrono = { version = "0.4", default-features = false, features = ["now"], optional = true } [dev-dependencies] codspeed-criterion-compat = "2.7.2" diff --git a/src/cli.rs b/src/cli.rs new file mode 100644 index 00000000..bc7c25cd --- /dev/null +++ b/src/cli.rs @@ -0,0 +1,253 @@ +use crate::*; +use std::io::{self, BufRead, BufReader, Read, Write}; + +#[cfg(feature = "bin")] +pub mod bin { + use super::*; + use clap::Parser; + use std::path::PathBuf; + + #[derive(Debug, Parser)] + #[command(version = env!("FULL_VERSION"), about = env!("CARGO_PKG_DESCRIPTION"))] + struct Args { + /// Directory for files when using `input` and `output` commands + #[clap(short = 'F', long)] + fact_directory: Option, + /// Turns off the seminaive optimization + #[clap(long)] + naive: bool, + /// Prints extra information, which can be useful for debugging + #[clap(long, default_value_t = RunMode::Normal)] + show: RunMode, + /// Changes the prefix of the generated symbols + // TODO why do we support this? + // TODO remove this evil hack + #[clap(long, default_value = "__")] + reserved_symbol: String, + /// The file names for the egglog files to run + inputs: Vec, + /// Serializes the egraph for each egglog file as JSON + #[clap(long)] + to_json: bool, + /// Serializes the egraph for each egglog file as a dot file + #[clap(long)] + to_dot: bool, + /// Serializes the egraph for each egglog file as an SVG + #[clap(long)] + to_svg: bool, + /// Splits the serialized egraph into primitives and non-primitives + #[clap(long)] + serialize_split_primitive_outputs: bool, + /// Maximum number of function nodes to render in dot/svg output + #[clap(long, default_value = "40")] + max_functions: usize, + /// Maximum number of calls per function to render in dot/svg output + #[clap(long, default_value = "40")] + max_calls_per_function: usize, + /// Number of times to inline leaves + #[clap(long, default_value = "0")] + serialize_n_inline_leaves: usize, + /// Prevents egglog from printing messages + #[clap(long)] + no_messages: bool, + } + + #[allow(clippy::disallowed_macros)] + pub fn cli(mut egraph: EGraph) { + env_logger::Builder::new() + .filter_level(log::LevelFilter::Info) + .format_timestamp(None) + .format_target(false) + .parse_default_env() + .init(); + + let args = Args::parse(); + egraph.set_reserved_symbol(args.reserved_symbol.clone().into()); + egraph.fact_directory.clone_from(&args.fact_directory); + egraph.seminaive = !args.naive; + egraph.run_mode = args.show; + if args.no_messages { + egraph.disable_messages(); + } + + if args.inputs.is_empty() { + log::info!("Welcome to Egglog REPL! (build: {})", env!("FULL_VERSION")); + match egraph.repl() { + Ok(()) => std::process::exit(0), + Err(err) => { + log::error!("{err}"); + std::process::exit(1) + } + } + } else { + for input in &args.inputs { + let program = std::fs::read_to_string(input).unwrap_or_else(|_| { + let arg = input.to_string_lossy(); + panic!("Failed to read file {arg}") + }); + + match egraph.parse_and_run_program(Some(input.to_str().unwrap().into()), &program) { + Ok(msgs) => { + for msg in msgs { + println!("{msg}"); + } + } + Err(err) => { + log::error!("{err}"); + std::process::exit(1) + } + } + + if args.to_json || args.to_dot || args.to_svg { + let mut serialized = egraph.serialize(SerializeConfig::default()); + if args.serialize_split_primitive_outputs { + serialized.split_classes(|id, _| egraph.from_node_id(id).is_primitive()) + } + for _ in 0..args.serialize_n_inline_leaves { + serialized.inline_leaves(); + } + + // if we are splitting primitive outputs, add `-split` to the end of the file name + let serialize_filename = if args.serialize_split_primitive_outputs { + input.with_file_name(format!( + "{}-split", + input.file_stem().unwrap().to_str().unwrap() + )) + } else { + input.clone() + }; + if args.to_dot { + let dot_path = serialize_filename.with_extension("dot"); + serialized.to_dot_file(dot_path).unwrap() + } + if args.to_svg { + let svg_path = serialize_filename.with_extension("svg"); + serialized.to_svg_file(svg_path).unwrap() + } + if args.to_json { + let json_path = serialize_filename.with_extension("json"); + serialized.to_json_file(json_path).unwrap(); + } + } + } + } + + // no need to drop the egraph if we are going to exit + std::mem::forget(egraph) + } +} + +impl EGraph { + pub fn repl(&mut self) -> io::Result<()> { + self.repl_with(io::stdin(), io::stdout()) + } + + pub fn repl_with(&mut self, input: R, mut output: W) -> io::Result<()> + where + R: Read, + W: Write, + { + let mut cmd_buffer = String::new(); + + for line in BufReader::new(input).lines() { + let line_str = line?; + cmd_buffer.push_str(&line_str); + cmd_buffer.push('\n'); + // handles multi-line commands + if should_eval(&cmd_buffer) { + run_command_in_scripting(self, &cmd_buffer, &mut output)?; + cmd_buffer = String::new(); + } + } + + if !cmd_buffer.is_empty() { + run_command_in_scripting(self, &cmd_buffer, &mut output)?; + } + + Ok(()) + } +} + +fn should_eval(curr_cmd: &str) -> bool { + all_sexps(Context::new(None, curr_cmd)).is_ok() +} + +fn run_command_in_scripting(egraph: &mut EGraph, command: &str, mut output: W) -> io::Result<()> +where + W: Write, +{ + match egraph.parse_and_run_program(None, command) { + Ok(msgs) => { + for msg in msgs { + writeln!(output, "{msg}")?; + } + } + Err(err) => log::error!("{err}"), + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_should_eval() { + #[rustfmt::skip] + let test_cases = vec![ + vec![ + "(extract", + "\"1", + ")", + "(", + ")))", + "\"", + ";; )", + ")" + ], + vec![ + "(extract 1) (extract", + "2) (", + "extract 3) (extract 4) ;;;; (" + ], + vec![ + "(extract \"\\\")\")" + ]]; + for test in test_cases { + let mut cmd_buffer = String::new(); + for (i, line) in test.iter().enumerate() { + cmd_buffer.push_str(line); + cmd_buffer.push('\n'); + assert_eq!(should_eval(&cmd_buffer), i == test.len() - 1); + } + } + } + + #[test] + fn test_repl() { + let mut egraph = EGraph::default(); + + let input = "(extract 1)"; + let mut output = Vec::new(); + egraph.repl_with(input.as_bytes(), &mut output).unwrap(); + assert_eq!(String::from_utf8(output).unwrap(), "1\n"); + + let input = "\n\n\n"; + let mut output = Vec::new(); + egraph.repl_with(input.as_bytes(), &mut output).unwrap(); + assert_eq!(String::from_utf8(output).unwrap(), ""); + + let input = "(set-option interactive_mode 1)"; + let mut output = Vec::new(); + egraph.repl_with(input.as_bytes(), &mut output).unwrap(); + assert_eq!(String::from_utf8(output).unwrap(), "(done)\n"); + + let input = "(set-option interactive_mode 1)\n(extract 1)(extract 2)\n"; + let mut output = Vec::new(); + egraph.repl_with(input.as_bytes(), &mut output).unwrap(); + assert_eq!( + String::from_utf8(output).unwrap(), + "(done)\n1\n(done)\n2\n(done)\n" + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index dd141c3b..ba212ff7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,12 +13,12 @@ //! mod actions; pub mod ast; +mod cli; pub mod constraint; mod core; pub mod extract; mod function; mod gj; -mod repl; mod serialize; pub mod sort; mod termdag; @@ -33,6 +33,8 @@ use crate::typechecking::TypeError; use actions::Program; use ast::remove_globals::remove_globals; use ast::*; +#[cfg(feature = "bin")] +pub use cli::bin::*; use constraint::{Constraint, SimpleTypeConstraint, TypeConstraint}; use extract::Extractor; pub use function::Function; diff --git a/src/main.rs b/src/main.rs index 3dd93c0f..0dafdfd0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,133 +1,3 @@ -use clap::Parser; -use egglog::{EGraph, RunMode, SerializeConfig}; -use std::path::PathBuf; - -#[derive(Debug, Parser)] -#[command(version = env!("FULL_VERSION"), about = env!("CARGO_PKG_DESCRIPTION"))] -struct Args { - #[clap(short = 'F', long)] - fact_directory: Option, - #[clap(long)] - naive: bool, - #[clap(long)] - desugar: bool, - #[clap(long)] - resugar: bool, - #[clap(long, default_value_t = RunMode::Normal)] - show: RunMode, - // TODO remove this evil hack - #[clap(long, default_value = "__")] - reserved_symbol: String, - inputs: Vec, - #[clap(long)] - to_json: bool, - #[clap(long)] - to_dot: bool, - #[clap(long)] - to_svg: bool, - #[clap(long)] - serialize_split_primitive_outputs: bool, - /// Maximum number of function nodes to render in dot/svg output - #[clap(long, default_value = "40")] - max_functions: usize, - /// Maximum number of calls per function to render in dot/svg output - #[clap(long, default_value = "40")] - max_calls_per_function: usize, - /// Number of times to inline leaves - #[clap(long, default_value = "0")] - serialize_n_inline_leaves: usize, - #[clap(long)] - no_messages: bool, -} - -#[allow(clippy::disallowed_macros)] fn main() { - env_logger::Builder::new() - .filter_level(log::LevelFilter::Info) - .format_timestamp(None) - .format_target(false) - .parse_default_env() - .init(); - - let args = Args::parse(); - - let mk_egraph = || { - let mut egraph = EGraph::default(); - egraph.set_reserved_symbol(args.reserved_symbol.clone().into()); - egraph.fact_directory.clone_from(&args.fact_directory); - egraph.seminaive = !args.naive; - egraph.run_mode = args.show; - if args.no_messages { - egraph.disable_messages(); - } - egraph - }; - - if args.inputs.is_empty() { - let mut egraph = mk_egraph(); - - log::info!("Welcome to Egglog REPL! (build: {})", env!("FULL_VERSION")); - match egraph.repl() { - Ok(()) => std::process::exit(0), - Err(err) => { - log::error!("{err}"); - std::process::exit(1) - } - } - } - - for (idx, input) in args.inputs.iter().enumerate() { - let program = std::fs::read_to_string(input).unwrap_or_else(|_| { - let arg = input.to_string_lossy(); - panic!("Failed to read file {arg}") - }); - let mut egraph = mk_egraph(); - match egraph.parse_and_run_program(Some(input.to_str().unwrap().into()), &program) { - Ok(msgs) => { - for msg in msgs { - println!("{msg}"); - } - } - Err(err) => { - log::error!("{err}"); - std::process::exit(1) - } - } - - if args.to_json || args.to_dot || args.to_svg { - let mut serialized = egraph.serialize(SerializeConfig::default()); - if args.serialize_split_primitive_outputs { - serialized.split_classes(|id, _| egraph.from_node_id(id).is_primitive()) - } - for _ in 0..args.serialize_n_inline_leaves { - serialized.inline_leaves(); - } - - // if we are splitting primitive outputs, add `-split` to the end of the file name - let serialize_filename = if args.serialize_split_primitive_outputs { - input.with_file_name(format!( - "{}-split", - input.file_stem().unwrap().to_str().unwrap() - )) - } else { - input.clone() - }; - if args.to_dot { - let dot_path = serialize_filename.with_extension("dot"); - serialized.to_dot_file(dot_path).unwrap() - } - if args.to_svg { - let svg_path = serialize_filename.with_extension("svg"); - serialized.to_svg_file(svg_path).unwrap() - } - if args.to_json { - let json_path = serialize_filename.with_extension("json"); - serialized.to_json_file(json_path).unwrap(); - } - } - // no need to drop the egraph if we are going to exit - if idx == args.inputs.len() - 1 { - std::mem::forget(egraph) - } - } + egglog::cli(egglog::EGraph::default()) } diff --git a/src/repl.rs b/src/repl.rs deleted file mode 100644 index b04dcc6f..00000000 --- a/src/repl.rs +++ /dev/null @@ -1,119 +0,0 @@ -use std::io::{self, Read}; -use std::io::{BufRead, BufReader, Write}; - -use crate::{all_sexps, Context, EGraph}; - -impl EGraph { - pub fn repl(&mut self) -> io::Result<()> { - self.repl_with(io::stdin(), io::stdout()) - } - - pub fn repl_with(&mut self, input: R, mut output: W) -> io::Result<()> - where - R: Read, - W: Write, - { - let mut cmd_buffer = String::new(); - - for line in BufReader::new(input).lines() { - let line_str = line?; - cmd_buffer.push_str(&line_str); - cmd_buffer.push('\n'); - // handles multi-line commands - if should_eval(&cmd_buffer) { - run_command_in_scripting(self, &cmd_buffer, &mut output)?; - cmd_buffer = String::new(); - } - } - - if !cmd_buffer.is_empty() { - run_command_in_scripting(self, &cmd_buffer, &mut output)?; - } - - Ok(()) - } -} - -fn should_eval(curr_cmd: &str) -> bool { - all_sexps(Context::new(None, curr_cmd)).is_ok() -} - -fn run_command_in_scripting(egraph: &mut EGraph, command: &str, mut output: W) -> io::Result<()> -where - W: Write, -{ - match egraph.parse_and_run_program(None, command) { - Ok(msgs) => { - for msg in msgs { - writeln!(output, "{msg}")?; - } - } - Err(err) => log::error!("{err}"), - } - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_should_eval() { - #[rustfmt::skip] - let test_cases = vec![ - vec![ - "(extract", - "\"1", - ")", - "(", - ")))", - "\"", - ";; )", - ")" - ], - vec![ - "(extract 1) (extract", - "2) (", - "extract 3) (extract 4) ;;;; (" - ], - vec![ - "(extract \"\\\")\")" - ]]; - for test in test_cases { - let mut cmd_buffer = String::new(); - for (i, line) in test.iter().enumerate() { - cmd_buffer.push_str(line); - cmd_buffer.push('\n'); - assert_eq!(should_eval(&cmd_buffer), i == test.len() - 1); - } - } - } - - #[test] - fn test_repl() { - let mut egraph = EGraph::default(); - - let input = "(extract 1)"; - let mut output = Vec::new(); - egraph.repl_with(input.as_bytes(), &mut output).unwrap(); - assert_eq!(String::from_utf8(output).unwrap(), "1\n"); - - let input = "\n\n\n"; - let mut output = Vec::new(); - egraph.repl_with(input.as_bytes(), &mut output).unwrap(); - assert_eq!(String::from_utf8(output).unwrap(), ""); - - let input = "(set-option interactive_mode 1)"; - let mut output = Vec::new(); - egraph.repl_with(input.as_bytes(), &mut output).unwrap(); - assert_eq!(String::from_utf8(output).unwrap(), "(done)\n"); - - let input = "(set-option interactive_mode 1)\n(extract 1)(extract 2)\n"; - let mut output = Vec::new(); - egraph.repl_with(input.as_bytes(), &mut output).unwrap(); - assert_eq!( - String::from_utf8(output).unwrap(), - "(done)\n1\n(done)\n2\n(done)\n" - ); - } -}