diff --git a/DESCRIPTION b/DESCRIPTION index a318284..ff7f72f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: orbweaver Title: Fast and Efficient Graph Data Structures -Version: 0.17.1 +Version: 0.18.0 Authors@R: c(person(given = "ixpantia, SRL", role = "cph", diff --git a/R/extendr-wrappers.R b/R/extendr-wrappers.R index 5a115b3..8f6f7ba 100644 --- a/R/extendr-wrappers.R +++ b/R/extendr-wrappers.R @@ -34,10 +34,10 @@ DirectedGraph$get_all_roots <- function() .Call(wrap__DirectedGraph__get_all_roo DirectedGraph$get_roots_over <- function(node_ids) .Call(wrap__DirectedGraph__get_roots_over, self, node_ids) -DirectedGraph$subset <- function(node_id) .Call(wrap__DirectedGraph__subset, self, node_id) - DirectedGraph$subset_multi <- function(node_ids) .Call(wrap__DirectedGraph__subset_multi, self, node_ids) +DirectedGraph$subset_multi_with_limit <- function(node_ids, limit) .Call(wrap__DirectedGraph__subset_multi_with_limit, self, node_ids, limit) + DirectedGraph$print <- function() invisible(.Call(wrap__DirectedGraph__print, self)) DirectedGraph$to_bin_disk <- function(path) .Call(wrap__DirectedGraph__to_bin_disk, self, path) @@ -84,10 +84,10 @@ DirectedAcyclicGraph$get_all_roots <- function() .Call(wrap__DirectedAcyclicGrap DirectedAcyclicGraph$get_roots_over <- function(node_ids) .Call(wrap__DirectedAcyclicGraph__get_roots_over, self, node_ids) -DirectedAcyclicGraph$subset <- function(node_id) .Call(wrap__DirectedAcyclicGraph__subset, self, node_id) - DirectedAcyclicGraph$subset_multi <- function(node_ids) .Call(wrap__DirectedAcyclicGraph__subset_multi, self, node_ids) +DirectedAcyclicGraph$subset_multi_with_limit <- function(node_ids, limit) .Call(wrap__DirectedAcyclicGraph__subset_multi_with_limit, self, node_ids, limit) + DirectedAcyclicGraph$print <- function() invisible(.Call(wrap__DirectedAcyclicGraph__print, self)) DirectedAcyclicGraph$to_bin_disk <- function(path) .Call(wrap__DirectedAcyclicGraph__to_bin_disk, self, path) diff --git a/R/subset.R b/R/subset.R index aa445ab..3903c9e 100644 --- a/R/subset.R +++ b/R/subset.R @@ -1,17 +1,19 @@ +NO_LIMIT <- -1L + #' @export -subset.DirectedGraph <- function(x, ...) { +subset.DirectedGraph <- function(x, ..., limit = NO_LIMIT) { arguments <- c(...) - if (length(arguments) > 1) { + if (limit == NO_LIMIT) { return(throw_if_error(x$subset_multi(arguments))) } - throw_if_error(x$subset(arguments)) + return(throw_if_error(x$subset_multi_with_limit(arguments, limit))) } #' @export -subset.DirectedAcyclicGraph <- function(x, ...) { +subset.DirectedAcyclicGraph <- function(x, ..., limit = NO_LIMIT) { arguments <- c(...) - if (length(arguments) > 1) { + if (limit == NO_LIMIT) { return(throw_if_error(x$subset_multi(arguments))) } - throw_if_error(x$subset(arguments)) + return(throw_if_error(x$subset_multi_with_limit(arguments, limit))) } diff --git a/src/rust/Cargo.lock b/src/rust/Cargo.lock index 443f20b..04ba0fd 100644 --- a/src/rust/Cargo.lock +++ b/src/rust/Cargo.lock @@ -132,9 +132,8 @@ checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" [[package]] name = "orbweaver" -version = "0.17.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7074c2d5f33aa5ad4d66fbdcc114dd7652e24cc99f21e7118fe61e11f495efa" +version = "0.18.0" +source = "git+https://github.com/andyquinterom/orbweaver-rs.git?branch=T27#5c12db46cdf1911be39d132a8c1d469190f04a7a" dependencies = [ "flate2", "fxhash", @@ -145,7 +144,7 @@ dependencies = [ [[package]] name = "orbweaver_r" -version = "0.16.0" +version = "0.18.0" dependencies = [ "extendr-api", "orbweaver", diff --git a/src/rust/Cargo.toml b/src/rust/Cargo.toml index 8da1e29..bee06a0 100644 --- a/src/rust/Cargo.toml +++ b/src/rust/Cargo.toml @@ -1,6 +1,6 @@ [package] name = 'orbweaver_r' -version = '0.16.0' +version = '0.18.0' edition = '2021' [lib] @@ -8,7 +8,8 @@ crate-type = [ 'staticlib', 'lib' ] name = 'orbweaver_r' [dependencies] -orbweaver = { version = "0.17.1" } +# orbweaver = { version = "0.17.1" } +orbweaver = { git = "https://github.com/andyquinterom/orbweaver-rs.git", branch = "T27" } extendr-api = { version = "0.7", features = ["serde", "result_condition"] } # This will help us filter the platforms diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index 3b3e998..4c5fdf7 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -144,8 +144,8 @@ pub trait RImplDirectedGraph: Sized { fn get_leaves_under(&self, nodes: RNodesIn) -> Result; fn get_all_roots(&self) -> NodeVec; fn get_roots_over(&self, node_ids: RNodesIn) -> Result; - fn subset(&self, node_id: &str) -> Result; fn subset_multi(&self, node_id: RNodesIn) -> Result; + fn subset_multi_with_limit(&self, node_id: RNodesIn, limit: i32) -> Result; fn print(&self); fn find_all_paths(&self, from: &str, to: &str) -> Result; fn to_bin_disk(&self, path: &str) -> Result<()>; diff --git a/src/rust/src/macros.rs b/src/rust/src/macros.rs index 715be86..a24a68d 100644 --- a/src/rust/src/macros.rs +++ b/src/rust/src/macros.rs @@ -48,14 +48,25 @@ macro_rules! impl_directed_graph { .map_err(to_r_error) .map(NodeVec) } - fn subset(&self, node_id: &str) -> Result { - Ok(Self(self.0.subset(node_id).map_err(to_r_error)?)) - } fn subset_multi(&self, node_ids: RNodesIn) -> Result { Ok(Self( self.0.subset_multi(node_ids.iter()).map_err(to_r_error)?, )) } + fn subset_multi_with_limit(&self, node_ids: RNodesIn, limit: i32) -> Result { + if limit <= 0 { + return Err("Limit cannot be negative".into()); + } + + // This is safe because we checked right before + let limit = unsafe { std::num::NonZeroUsize::new_unchecked(limit as usize) }; + + Ok(Self( + self.0 + .subset_multi_with_limit(node_ids.iter(), limit) + .map_err(to_r_error)?, + )) + } fn print(&self) { println!("{:?}", self.0) }