Skip to content

Commit

Permalink
Thread-safe EGraph struct (#517)
Browse files Browse the repository at this point in the history
* thread safe

* simplify code
  • Loading branch information
yihozhang authored Feb 3, 2025
1 parent 9e6ecb6 commit 215714e
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 19 deletions.
11 changes: 7 additions & 4 deletions src/ast/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,16 +217,19 @@ fn map_fallible<T>(
.collect::<Result<_, _>>()
}

pub trait Macro<T> {
pub trait Macro<T>: Send + Sync {
fn name(&self) -> Symbol;
fn parse(&self, args: &[Sexp], span: Span, parser: &mut Parser) -> Result<T, ParseError>;
}

pub struct SimpleMacro<T, F: Fn(&[Sexp], Span, &mut Parser) -> Result<T, ParseError>>(Symbol, F);
pub struct SimpleMacro<T, F: Fn(&[Sexp], Span, &mut Parser) -> Result<T, ParseError> + Send + Sync>(
Symbol,
F,
);

impl<T, F> SimpleMacro<T, F>
where
F: Fn(&[Sexp], Span, &mut Parser) -> Result<T, ParseError>,
F: Fn(&[Sexp], Span, &mut Parser) -> Result<T, ParseError> + Send + Sync,
{
pub fn new(head: &str, f: F) -> Self {
Self(head.into(), f)
Expand All @@ -235,7 +238,7 @@ where

impl<T, F> Macro<T> for SimpleMacro<T, F>
where
F: Fn(&[Sexp], Span, &mut Parser) -> Result<T, ParseError>,
F: Fn(&[Sexp], Span, &mut Parser) -> Result<T, ParseError> + Send + Sync,
{
fn name(&self) -> Symbol {
self.0
Expand Down
16 changes: 8 additions & 8 deletions src/function/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct Function {
pub merge: MergeFn,
pub(crate) nodes: table::Table,
sorts: HashSet<Symbol>,
pub(crate) indexes: Vec<Rc<ColumnIndex>>,
pub(crate) indexes: Vec<Arc<ColumnIndex>>,
pub(crate) rebuild_indexes: Vec<Option<CompositeColumnIndex>>,
index_updated_through: usize,
updates: usize,
Expand All @@ -30,7 +30,7 @@ pub enum MergeFn {
Union,
// the rc is make sure it's cheaply clonable, since calling the merge fn
// requires a clone
Expr(Rc<Program>),
Expr(Arc<Program>),
}

/// All information we know determined by the input.
Expand Down Expand Up @@ -125,7 +125,7 @@ impl Function {
let program = egraph
.compile_expr(&binding, &actions, &target)
.map_err(Error::TypeErrors)?;
MergeFn::Expr(Rc::new(program))
MergeFn::Expr(Arc::new(program))
} else if decl.subtype == FunctionSubtype::Constructor {
MergeFn::Union
} else {
Expand All @@ -136,7 +136,7 @@ impl Function {
input
.iter()
.chain(once(&output))
.map(|x| Rc::new(ColumnIndex::new(x.name()))),
.map(|x| Arc::new(ColumnIndex::new(x.name()))),
);

let rebuild_indexes = Vec::from_iter(input.iter().chain(once(&output)).map(|x| {
Expand Down Expand Up @@ -179,7 +179,7 @@ impl Function {
self.nodes.clear();
self.indexes
.iter_mut()
.for_each(|x| Rc::make_mut(x).clear());
.for_each(|x| Arc::make_mut(x).clear());
self.rebuild_indexes.iter_mut().for_each(|x| {
if let Some(x) = x {
x.clear()
Expand Down Expand Up @@ -219,7 +219,7 @@ impl Function {
&self,
col: usize,
timestamps: &Range<u32>,
) -> Option<Rc<ColumnIndex>> {
) -> Option<Arc<ColumnIndex>> {
let range = self.nodes.transform_range(timestamps);
if range.end > self.index_updated_through {
return None;
Expand Down Expand Up @@ -250,7 +250,7 @@ impl Function {
.zip(self.rebuild_indexes.iter_mut())
.enumerate()
{
let as_mut = Rc::make_mut(index);
let as_mut = Arc::make_mut(index);
if col == self.schema.input.len() {
for (slot, _, out) in self.nodes.iter_range(offsets.clone(), true) {
as_mut.add(out.value, slot)
Expand Down Expand Up @@ -295,7 +295,7 @@ impl Function {
for index in &mut self.indexes {
// Everything works if we don't have a unique copy of the indexes,
// but we ought to be able to avoid this copy.
Rc::make_mut(index).clear();
Arc::make_mut(index).clear();
}
for rebuild_index in self.rebuild_indexes.iter_mut().flatten() {
rebuild_index.clear();
Expand Down
4 changes: 2 additions & 2 deletions src/gj.rs
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ type RowIdx = u32;
#[derive(Debug)]
enum LazyTrieInner {
Borrowed {
index: Rc<ColumnIndex>,
index: Arc<ColumnIndex>,
map: HashMap<Value, LazyTrie>,
},
Delayed(SmallVec<[RowIdx; 4]>),
Expand All @@ -822,7 +822,7 @@ impl LazyTrie {
LazyTrieInner::Borrowed { index, .. } => index.len(),
}
}
fn from_column_index(index: Rc<ColumnIndex>) -> LazyTrie {
fn from_column_index(index: Arc<ColumnIndex>) -> LazyTrie {
LazyTrie(UnsafeCell::new(LazyTrieInner::Borrowed {
index,
map: Default::default(),
Expand Down
16 changes: 11 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,16 @@ use indexmap::map::Entry;
use instant::{Duration, Instant};
pub use serialize::{SerializeConfig, SerializedNode};
use sort::*;
use std::fmt::Debug;
use std::fmt::{Display, Formatter};
use std::fs::File;
use std::hash::Hash;
use std::io::Read;
use std::iter::once;
use std::ops::{Deref, Range};
use std::path::PathBuf;
use std::rc::Rc;
use std::str::FromStr;
use std::{fmt::Debug, sync::Arc};
use std::sync::Arc;
pub use termdag::{Term, TermDag, TermId};
use thiserror::Error;
pub use typechecking::TypeInfo;
Expand Down Expand Up @@ -292,7 +292,7 @@ impl RunReport {
}

#[derive(Clone)]
pub struct Primitive(Arc<dyn PrimitiveLike>);
pub struct Primitive(Arc<dyn PrimitiveLike + Send + Sync>);
impl Primitive {
// Takes the full signature of a primitive (including input and output types)
// Returns whether the primitive is compatible with this signature
Expand Down Expand Up @@ -344,7 +344,7 @@ impl Debug for Primitive {
}
}

impl<T: PrimitiveLike + 'static> From<T> for Primitive {
impl<T: PrimitiveLike + 'static + Send + Sync> From<T> for Primitive {
fn from(p: T) -> Self {
Self(Arc::new(p))
}
Expand Down Expand Up @@ -1585,7 +1585,9 @@ pub enum Error {

#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use lazy_static::lazy_static;

use crate::constraint::SimpleTypeConstraint;
use crate::sort::*;
Expand Down Expand Up @@ -1656,4 +1658,8 @@ mod tests {
)
.unwrap();
}

lazy_static! {
pub static ref RT: Mutex<EGraph> = Mutex::new(EGraph::default());
}
}

0 comments on commit 215714e

Please sign in to comment.