From d92368fea2ca867679f5bdb345a0273f0978f7f1 Mon Sep 17 00:00:00 2001 From: Andrew Frantz Date: Sat, 9 Dec 2023 09:37:12 -0500 Subject: [PATCH] fix: apply some of Clay's performance suggestions --- src/derive/command/endedness.rs | 70 +++++++++---------- src/derive/endedness/compute.rs | 120 ++++++++++++++++---------------- 2 files changed, 93 insertions(+), 97 deletions(-) diff --git a/src/derive/command/endedness.rs b/src/derive/command/endedness.rs index f1180a6..3abe3ec 100644 --- a/src/derive/command/endedness.rs +++ b/src/derive/command/endedness.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::path::PathBuf; -use std::rc::Rc; +use std::sync::Arc; use clap::Args; use noodles::sam::record::data::field::Tag; @@ -62,16 +62,14 @@ pub struct DeriveEndednessArgs { pub fn derive(args: DeriveEndednessArgs) -> anyhow::Result<()> { info!("Starting derive endedness subcommand."); - let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); - ordering_flags.insert(Rc::new(OVERALL.to_string()), OrderingFlagsCounts::new()); - ordering_flags.insert( - Rc::new(UNKNOWN_READ_GROUP.to_string()), - OrderingFlagsCounts::new(), - ); + let mut found_rgs = HashSet::new(); + + let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); + ordering_flags.insert(Arc::clone(&OVERALL), OrderingFlagsCounts::new()); + ordering_flags.insert(Arc::clone(&UNKNOWN_READ_GROUP), OrderingFlagsCounts::new()); // only used if args.calc_rpt is true - let mut found_rgs = HashSet::new(); - let mut read_names = Trie::>>::new(); + let mut read_names = Trie::>>::new(); let ParsedBAMFile { mut reader, header, .. @@ -97,24 +95,28 @@ pub fn derive(args: DeriveEndednessArgs) -> anyhow::Result<()> { } let read_group = match record.data().get(Tag::ReadGroup) { - Some(rg) => Rc::new(rg.as_str().unwrap().to_owned()), - None => Rc::new(UNKNOWN_READ_GROUP.to_string()), + Some(rg) => { + let rg = rg.to_string(); + if !found_rgs.contains(&rg) { + found_rgs.insert(Arc::new(rg.clone())); + } + found_rgs.get(&rg).unwrap().clone() + } + None => Arc::clone(&UNKNOWN_READ_GROUP), }; if args.calc_rpt { - found_rgs.insert(Rc::clone(&read_group)); - match record.read_name() { Some(rn) => { - let rg_vec = read_names.get_mut(&rn.to_string()); + let rn = rn.to_string(); + let rg_vec = read_names.get_mut(&rn); match rg_vec { Some(rg_vec) => { - rg_vec.push(Rc::clone(&read_group)); + rg_vec.push(Arc::clone(&read_group)); } None => { - let rg_vec = vec![(Rc::clone(&read_group))]; - read_names.insert(rn.to_string(), rg_vec); + read_names.insert(rn, vec![(Arc::clone(&read_group))]); } } } @@ -126,12 +128,12 @@ pub fn derive(args: DeriveEndednessArgs) -> anyhow::Result<()> { } } + let overall_rg = Arc::clone(&OVERALL); + if record.flags().is_first_segment() && !record.flags().is_last_segment() { - ordering_flags - .entry(Rc::new(OVERALL.to_string())) - .and_modify(|e| { - e.first += 1; - }); + ordering_flags.entry(overall_rg).and_modify(|e| { + e.first += 1; + }); ordering_flags .entry(read_group) @@ -145,11 +147,9 @@ pub fn derive(args: DeriveEndednessArgs) -> anyhow::Result<()> { neither: 0, }); } else if !record.flags().is_first_segment() && record.flags().is_last_segment() { - ordering_flags - .entry(Rc::new(OVERALL.to_string())) - .and_modify(|e| { - e.last += 1; - }); + ordering_flags.entry(overall_rg).and_modify(|e| { + e.last += 1; + }); ordering_flags .entry(read_group) @@ -163,11 +163,9 @@ pub fn derive(args: DeriveEndednessArgs) -> anyhow::Result<()> { neither: 0, }); } else if record.flags().is_first_segment() && record.flags().is_last_segment() { - ordering_flags - .entry(Rc::new(OVERALL.to_string())) - .and_modify(|e| { - e.both += 1; - }); + ordering_flags.entry(overall_rg).and_modify(|e| { + e.both += 1; + }); ordering_flags .entry(read_group) @@ -181,11 +179,9 @@ pub fn derive(args: DeriveEndednessArgs) -> anyhow::Result<()> { neither: 0, }); } else if !record.flags().is_first_segment() && !record.flags().is_last_segment() { - ordering_flags - .entry(Rc::new(OVERALL.to_string())) - .and_modify(|e| { - e.neither += 1; - }); + ordering_flags.entry(overall_rg).and_modify(|e| { + e.neither += 1; + }); ordering_flags .entry(read_group) diff --git a/src/derive/endedness/compute.rs b/src/derive/endedness/compute.rs index dadac20..d897ab2 100644 --- a/src/derive/endedness/compute.rs +++ b/src/derive/endedness/compute.rs @@ -7,17 +7,17 @@ use radix_trie::TrieCommon; use serde::Serialize; use std::collections::HashMap; use std::collections::HashSet; -use std::rc::Rc; +use std::sync::Arc; use tracing::warn; // Strings used to index into the HashMaps used to store the Read Group ordering flags. // Lazy statics are used to save memory. lazy_static! { /// String used to index into the HashMaps used to store the "overall" ordering flags. - pub static ref OVERALL: String = String::from("overall"); + pub static ref OVERALL: Arc = Arc::new(String::from("overall")); /// String used to index into th HashMaps used to store the "unknown_read_group" ordering flags. - pub static ref UNKNOWN_READ_GROUP: String = String::from("unknown_read_group"); + pub static ref UNKNOWN_READ_GROUP: Arc = Arc::new(String::from("unknown_read_group")); } /// Struct holding the ordering flags for a single read group. @@ -160,30 +160,30 @@ impl DerivedEndednessResult { } fn calculate_reads_per_template( - read_names: Trie>>, -) -> HashMap, f64> { - let mut reads_per_template: HashMap, f64> = HashMap::new(); + read_names: Trie>>, +) -> HashMap, f64> { + let mut reads_per_template: HashMap, f64> = HashMap::new(); let mut total_reads: usize = 0; let mut total_templates: usize = 0; - let mut read_group_reads: HashMap, usize> = HashMap::new(); - let mut read_group_templates: HashMap, usize> = HashMap::new(); + let mut read_group_reads: HashMap, usize> = HashMap::new(); + let mut read_group_templates: HashMap, usize> = HashMap::new(); for (read_name, read_groups) in read_names.iter() { let num_reads = read_groups.len(); total_reads += num_reads; total_templates += 1; - let read_group_set: HashSet> = read_groups.iter().cloned().collect(); + let read_group_set: HashSet> = read_groups.iter().cloned().collect(); if read_group_set.len() == 1 { - let read_group = read_group_set.iter().next().unwrap(); + let read_group = read_group_set.iter().next().unwrap().clone(); read_group_reads .entry(read_group.clone()) .and_modify(|e| *e += num_reads) .or_insert(num_reads); read_group_templates - .entry(read_group.clone()) + .entry(read_group) .and_modify(|e| *e += 1) .or_insert(1); } else { @@ -207,14 +207,14 @@ fn calculate_reads_per_template( } reads_per_template.insert( - Rc::new(OVERALL.to_string()), + Arc::clone(&OVERALL), total_reads as f64 / total_templates as f64, ); for (read_group, num_reads) in read_group_reads.iter() { let num_templates = read_group_templates.get(read_group).unwrap(); let rpt = *num_reads as f64 / *num_templates as f64; - reads_per_template.insert(Rc::clone(read_group), rpt); + reads_per_template.insert(Arc::clone(read_group), rpt); } reads_per_template @@ -313,12 +313,12 @@ fn predict_endedness( /// return a result for the endedness of the file. This may fail, and the /// resulting [`DerivedEndednessResult`] should be evaluated accordingly. pub fn predict( - ordering_flags: HashMap, OrderingFlagsCounts>, - read_names: Trie>>, + ordering_flags: HashMap, OrderingFlagsCounts>, + read_names: Trie>>, paired_deviance: f64, round_rpt: bool, ) -> Result { - let mut rpts: HashMap, f64> = HashMap::new(); + let mut rpts: HashMap, f64> = HashMap::new(); if !read_names.is_empty() { rpts = calculate_reads_per_template(read_names); } @@ -332,7 +332,7 @@ pub fn predict( ); for (read_group, rg_ordering_flags) in ordering_flags.iter() { - if (read_group == &Rc::new(UNKNOWN_READ_GROUP.to_string())) + if (*read_group == *UNKNOWN_READ_GROUP) && (rg_ordering_flags.first == 0 && rg_ordering_flags.last == 0 && rg_ordering_flags.both == 0 @@ -369,17 +369,17 @@ mod tests { #[test] fn test_derive_endedness_from_all_zero_counts() { - let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); - ordering_flags.insert(Rc::new(OVERALL.to_string()), OrderingFlagsCounts::new()); + let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); + ordering_flags.insert(Arc::clone(&OVERALL), OrderingFlagsCounts::new()); let result = predict(ordering_flags, Trie::new(), 0.0, false); assert!(result.is_err()); } #[test] fn test_derive_endedness_from_only_first() { - let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); + let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); ordering_flags.insert( - Rc::new(OVERALL.to_string()), + Arc::clone(&OVERALL), OrderingFlagsCounts { first: 1, last: 0, @@ -402,9 +402,9 @@ mod tests { #[test] fn test_derive_endedness_from_only_last() { - let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); + let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); ordering_flags.insert( - Rc::new(OVERALL.to_string()), + Arc::clone(&OVERALL), OrderingFlagsCounts { first: 0, last: 1, @@ -427,9 +427,9 @@ mod tests { #[test] fn test_derive_endedness_from_only_both() { - let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); + let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); ordering_flags.insert( - Rc::new(OVERALL.to_string()), + Arc::clone(&OVERALL), OrderingFlagsCounts { first: 0, last: 0, @@ -452,9 +452,9 @@ mod tests { #[test] fn test_derive_endedness_from_only_neither() { - let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); + let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); ordering_flags.insert( - Rc::new(OVERALL.to_string()), + Arc::clone(&OVERALL), OrderingFlagsCounts { first: 0, last: 0, @@ -477,9 +477,9 @@ mod tests { #[test] fn test_derive_endedness_from_first_and_last() { - let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); + let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); ordering_flags.insert( - Rc::new(OVERALL.to_string()), + Arc::clone(&OVERALL), OrderingFlagsCounts { first: 1, last: 1, @@ -502,56 +502,56 @@ mod tests { #[test] fn test_calculate_reads_per_template() { - let mut read_names: Trie>> = Trie::new(); + let mut read_names: Trie>> = Trie::new(); read_names.insert( "read1".to_string(), vec![ - Rc::new("rg_paired".to_string()), - Rc::new("rg_paired".to_string()), + Arc::new("rg_paired".to_string()), + Arc::new("rg_paired".to_string()), ], ); read_names.insert( "read2".to_string(), vec![ - Rc::new("rg_paired".to_string()), - Rc::new("rg_paired".to_string()), - Rc::new("rg_single".to_string()), + Arc::new("rg_paired".to_string()), + Arc::new("rg_paired".to_string()), + Arc::new("rg_single".to_string()), ], ); - read_names.insert("read3".to_string(), vec![Rc::new("rg_single".to_string())]); + read_names.insert("read3".to_string(), vec![Arc::new("rg_single".to_string())]); read_names.insert( "read4".to_string(), vec![ - Rc::new("rg_paired".to_string()), - Rc::new("rg_paired".to_string()), + Arc::new("rg_paired".to_string()), + Arc::new("rg_paired".to_string()), ], ); read_names.insert( "read5".to_string(), vec![ - Rc::new("rg_paired".to_string()), - Rc::new("rg_paired".to_string()), - Rc::new("rg_single".to_string()), + Arc::new("rg_paired".to_string()), + Arc::new("rg_paired".to_string()), + Arc::new("rg_single".to_string()), ], ); let results = calculate_reads_per_template(read_names); assert_eq!(results.len(), 3); - assert_eq!(results.get(&Rc::new("overall".to_string())).unwrap(), &2.2); + assert_eq!(results.get(&Arc::new("overall".to_string())).unwrap(), &2.2); assert_eq!( - results.get(&Rc::new("rg_paired".to_string())).unwrap(), + results.get(&Arc::new("rg_paired".to_string())).unwrap(), &2.0 ); assert_eq!( - results.get(&Rc::new("rg_single".to_string())).unwrap(), + results.get(&Arc::new("rg_single".to_string())).unwrap(), &1.0 ); } #[test] fn test_derive_endedness_from_first_and_last_with_rpt() { - let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); + let mut ordering_flags: HashMap, OrderingFlagsCounts> = HashMap::new(); ordering_flags.insert( - Rc::new(OVERALL.to_string()), + Arc::new(OVERALL.to_string()), OrderingFlagsCounts { first: 8, last: 8, @@ -560,7 +560,7 @@ mod tests { }, ); ordering_flags.insert( - Rc::new("rg_paired".to_string()), + Arc::new("rg_paired".to_string()), OrderingFlagsCounts { first: 8, last: 8, @@ -569,7 +569,7 @@ mod tests { }, ); ordering_flags.insert( - Rc::new("rg_single".to_string()), + Arc::new("rg_single".to_string()), OrderingFlagsCounts { first: 0, last: 0, @@ -577,36 +577,36 @@ mod tests { neither: 0, }, ); - let mut read_names: Trie>> = Trie::new(); + let mut read_names: Trie>> = Trie::new(); read_names.insert( "read1".to_string(), vec![ - Rc::new("rg_paired".to_string()), - Rc::new("rg_paired".to_string()), + Arc::new("rg_paired".to_string()), + Arc::new("rg_paired".to_string()), ], ); read_names.insert( "read2".to_string(), vec![ - Rc::new("rg_paired".to_string()), - Rc::new("rg_paired".to_string()), - Rc::new("rg_single".to_string()), + Arc::new("rg_paired".to_string()), + Arc::new("rg_paired".to_string()), + Arc::new("rg_single".to_string()), ], ); - read_names.insert("read3".to_string(), vec![Rc::new("rg_single".to_string())]); + read_names.insert("read3".to_string(), vec![Arc::new("rg_single".to_string())]); read_names.insert( "read4".to_string(), vec![ - Rc::new("rg_paired".to_string()), - Rc::new("rg_paired".to_string()), + Arc::new("rg_paired".to_string()), + Arc::new("rg_paired".to_string()), ], ); read_names.insert( "read5".to_string(), vec![ - Rc::new("rg_paired".to_string()), - Rc::new("rg_paired".to_string()), - Rc::new("rg_single".to_string()), + Arc::new("rg_paired".to_string()), + Arc::new("rg_paired".to_string()), + Arc::new("rg_single".to_string()), ], ); let result = predict(ordering_flags, read_names, 0.0, false);