Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose CLI for egglog-experimental #510

Merged
merged 12 commits into from
Jan 16, 2025
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 11 additions & 22 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ bin = [
"graphviz",
"dep:clap",
"dep:env_logger",
"dep:serde_json",
"dep:chrono",
]
serde = ["egraph-serialize/serde"]
Expand All @@ -42,38 +41,28 @@ wasm-bindgen = ["instant/wasm-bindgen", "dep:getrandom"]
nondeterministic = []

[dependencies]
clap = { version = "4", features = ["derive"], optional = true }
egraph-serialize = { version = "0.2.0", default-features = false }
env_logger = { version = "0.10", optional = true }
hashbrown = { version = "0.15" }
im-rc = "15.1.0"
im = "15.1.0"
indexmap = "2.0"
instant = "0.1"
lazy_static = "1.4"
log = "0.4"
num = "0.4.3"
ordered-float = { version = "3.7" }
rustc-hash = "1.1"
smallvec = "1.11"
symbol_table = { version = "0.4.0", features = ["global"] }
thiserror = "1"
lazy_static = "1.4"
num = "0.4.3"
smallvec = "1.11"

egraph-serialize = { version = "0.2.0", default-features = false }

# binary dependencies
clap = { version = "4", features = ["derive"], optional = true }
env_logger = { version = "0.10", optional = true }
serde_json = { version = "1.0.100", optional = true, features = [
"preserve_order",
] }

ordered-float = { version = "3.7" }

# Need to add "js" feature for "graphviz-rust" to work in wasm
getrandom = { version = "0.2.10", features = ["js"], optional = true }

im-rc = "15.1.0"
im = "15.1.0"
getrandom = { version = "0.2.10", optional = true, features = ["js"] }

[build-dependencies]
chrono = { version = "0.4", default-features = false, optional = true, features = [
"now",
] }
chrono = { version = "0.4", default-features = false, features = ["now"], optional = true }

[dev-dependencies]
codspeed-criterion-compat = "2.7.2"
Expand Down
253 changes: 253 additions & 0 deletions src/cli.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
use crate::*;
use std::io::{self, BufRead, BufReader, Read, Write};

#[cfg(feature = "bin")]
pub mod bin {
use super::*;
use clap::Parser;
use std::path::PathBuf;

#[derive(Debug, Parser)]
#[command(version = env!("FULL_VERSION"), about = env!("CARGO_PKG_DESCRIPTION"))]
struct Args {
/// Directory for files when using `input` and `output` commands
#[clap(short = 'F', long)]
fact_directory: Option<PathBuf>,
/// Turns off the seminaive optimization
#[clap(long)]
naive: bool,
/// Prints extra information, which can be useful for debugging
#[clap(long, default_value_t = RunMode::Normal)]
show: RunMode,
/// Changes the prefix of the generated symbols
// TODO why do we support this?
// TODO remove this evil hack
#[clap(long, default_value = "__")]
reserved_symbol: String,
/// The file names for the egglog files to run
inputs: Vec<PathBuf>,
/// Serializes the egraph for each egglog file as JSON
#[clap(long)]
to_json: bool,
/// Serializes the egraph for each egglog file as a dot file
#[clap(long)]
to_dot: bool,
/// Serializes the egraph for each egglog file as an SVG
#[clap(long)]
to_svg: bool,
/// Splits the serialized egraph into primitives and non-primitives
#[clap(long)]
serialize_split_primitive_outputs: bool,
/// Maximum number of function nodes to render in dot/svg output
#[clap(long, default_value = "40")]
max_functions: usize,
/// Maximum number of calls per function to render in dot/svg output
#[clap(long, default_value = "40")]
max_calls_per_function: usize,
/// Number of times to inline leaves
#[clap(long, default_value = "0")]
serialize_n_inline_leaves: usize,
/// Prevents egglog from printing messages
#[clap(long)]
no_messages: bool,
}

#[allow(clippy::disallowed_macros)]
pub fn cli(mut egraph: EGraph) {
env_logger::Builder::new()
.filter_level(log::LevelFilter::Info)
.format_timestamp(None)
.format_target(false)
.parse_default_env()
.init();

let args = Args::parse();
egraph.set_reserved_symbol(args.reserved_symbol.clone().into());
egraph.fact_directory.clone_from(&args.fact_directory);
egraph.seminaive = !args.naive;
egraph.run_mode = args.show;
if args.no_messages {
egraph.disable_messages();
}

if args.inputs.is_empty() {
log::info!("Welcome to Egglog REPL! (build: {})", env!("FULL_VERSION"));
match egraph.repl() {
Ok(()) => std::process::exit(0),
Err(err) => {
log::error!("{err}");
std::process::exit(1)
}
}
} else {
for input in &args.inputs {
let program = std::fs::read_to_string(input).unwrap_or_else(|_| {
let arg = input.to_string_lossy();
panic!("Failed to read file {arg}")
});

match egraph.parse_and_run_program(Some(input.to_str().unwrap().into()), &program) {
Ok(msgs) => {
for msg in msgs {
println!("{msg}");
}
}
Err(err) => {
log::error!("{err}");
std::process::exit(1)
}
}

if args.to_json || args.to_dot || args.to_svg {
let mut serialized = egraph.serialize(SerializeConfig::default());
if args.serialize_split_primitive_outputs {
serialized.split_classes(|id, _| egraph.from_node_id(id).is_primitive())
}
for _ in 0..args.serialize_n_inline_leaves {
serialized.inline_leaves();
}

// if we are splitting primitive outputs, add `-split` to the end of the file name
let serialize_filename = if args.serialize_split_primitive_outputs {
input.with_file_name(format!(
"{}-split",
input.file_stem().unwrap().to_str().unwrap()
))
} else {
input.clone()
};
if args.to_dot {
let dot_path = serialize_filename.with_extension("dot");
serialized.to_dot_file(dot_path).unwrap()
}
if args.to_svg {
let svg_path = serialize_filename.with_extension("svg");
serialized.to_svg_file(svg_path).unwrap()
}
if args.to_json {
let json_path = serialize_filename.with_extension("json");
serialized.to_json_file(json_path).unwrap();
}
}
}
}

// no need to drop the egraph if we are going to exit
std::mem::forget(egraph)
}
}

impl EGraph {
pub fn repl(&mut self) -> io::Result<()> {
self.repl_with(io::stdin(), io::stdout())
}

pub fn repl_with<R, W>(&mut self, input: R, mut output: W) -> io::Result<()>
where
R: Read,
W: Write,
{
let mut cmd_buffer = String::new();

for line in BufReader::new(input).lines() {
let line_str = line?;
cmd_buffer.push_str(&line_str);
cmd_buffer.push('\n');
// handles multi-line commands
if should_eval(&cmd_buffer) {
run_command_in_scripting(self, &cmd_buffer, &mut output)?;
cmd_buffer = String::new();
}
}

if !cmd_buffer.is_empty() {
run_command_in_scripting(self, &cmd_buffer, &mut output)?;
}

Ok(())
}
}

fn should_eval(curr_cmd: &str) -> bool {
all_sexps(Context::new(None, curr_cmd)).is_ok()
}

fn run_command_in_scripting<W>(egraph: &mut EGraph, command: &str, mut output: W) -> io::Result<()>
where
W: Write,
{
match egraph.parse_and_run_program(None, command) {
Ok(msgs) => {
for msg in msgs {
writeln!(output, "{msg}")?;
}
}
Err(err) => log::error!("{err}"),
}
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_should_eval() {
#[rustfmt::skip]
let test_cases = vec![
vec![
"(extract",
"\"1",
")",
"(",
")))",
"\"",
";; )",
")"
],
vec![
"(extract 1) (extract",
"2) (",
"extract 3) (extract 4) ;;;; ("
],
vec![
"(extract \"\\\")\")"
]];
for test in test_cases {
let mut cmd_buffer = String::new();
for (i, line) in test.iter().enumerate() {
cmd_buffer.push_str(line);
cmd_buffer.push('\n');
assert_eq!(should_eval(&cmd_buffer), i == test.len() - 1);
}
}
}

#[test]
fn test_repl() {
let mut egraph = EGraph::default();

let input = "(extract 1)";
let mut output = Vec::new();
egraph.repl_with(input.as_bytes(), &mut output).unwrap();
assert_eq!(String::from_utf8(output).unwrap(), "1\n");

let input = "\n\n\n";
let mut output = Vec::new();
egraph.repl_with(input.as_bytes(), &mut output).unwrap();
assert_eq!(String::from_utf8(output).unwrap(), "");

let input = "(set-option interactive_mode 1)";
let mut output = Vec::new();
egraph.repl_with(input.as_bytes(), &mut output).unwrap();
assert_eq!(String::from_utf8(output).unwrap(), "(done)\n");

let input = "(set-option interactive_mode 1)\n(extract 1)(extract 2)\n";
let mut output = Vec::new();
egraph.repl_with(input.as_bytes(), &mut output).unwrap();
assert_eq!(
String::from_utf8(output).unwrap(),
"(done)\n1\n(done)\n2\n(done)\n"
);
}
}
4 changes: 3 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
//!
mod actions;
pub mod ast;
mod cli;
pub mod constraint;
mod core;
pub mod extract;
mod function;
mod gj;
mod repl;
mod serialize;
pub mod sort;
mod termdag;
Expand All @@ -33,6 +33,8 @@ use crate::typechecking::TypeError;
use actions::Program;
use ast::remove_globals::remove_globals;
use ast::*;
#[cfg(feature = "bin")]
pub use cli::bin::*;
use constraint::{Constraint, SimpleTypeConstraint, TypeConstraint};
use extract::Extractor;
pub use function::Function;
Expand Down
Loading
Loading