diff --git a/data_server/src/main.rs b/data_server/src/main.rs index 9f8619f6..6ab4dd85 100644 --- a/data_server/src/main.rs +++ b/data_server/src/main.rs @@ -1,9 +1,8 @@ use clap::Parser; use log::info; use prost::Message; -use rand::prelude::IteratorRandom; use rand::seq::SliceRandom; -use rand::thread_rng; +use rand::{thread_rng, Rng}; use std::fs::File; use std::io::{self, BufReader, Read, Result as IoResult}; use std::vec; @@ -21,6 +20,7 @@ use text_data::{ #[derive(Default)] pub struct MyDataService { groups: Vec, + causual_sampling: bool, weights: Vec, } @@ -56,7 +56,7 @@ fn read_pb_stream(mut reader: BufReader) -> io::Result } impl MyDataService { - pub fn new(files: Vec) -> IoResult { + pub fn new(files: Vec, causual_sampling: bool) -> IoResult { let mut groups = Vec::new(); let mut weights = Vec::new(); @@ -73,7 +73,7 @@ impl MyDataService { info!("Loaded {} groups", groups.len()); - Ok(MyDataService { groups, weights }) + Ok(MyDataService { groups, weights, causual_sampling }) } } @@ -90,15 +90,36 @@ impl DataService for MyDataService { .groups .choose_weighted(&mut rng, |item| item.sentences.len() as f32); - if group.is_ok() { - let group = group.unwrap(); + if group.is_err() { + return Err(Status::internal("Failed to select a group")); + } + + let group = group.unwrap(); + + if self.causual_sampling { if num_samples > group.sentences.len() { num_samples = group.sentences.len(); } + // Random number between 0 and group.sentences.len() - num_samples + let max = group.sentences.len() - num_samples; + if max <= 0 { + return Ok(Response::new(SampledData { + name: group.name.clone(), + source: group.source.clone(), + samples: group.sentences.clone(), + })); + } + + let start = rng.gen_range(0..max); + Ok(Response::new(SampledData { + name: group.name.clone(), + source: group.source.clone(), + samples: group.sentences[start..start + num_samples].to_vec(), + })) + } else { let sentences_ref = group .sentences - .iter() .choose_multiple(&mut rng, num_samples); let sentences: Vec = sentences_ref @@ -109,10 +130,8 @@ impl DataService for MyDataService { Ok(Response::new(SampledData { name: group.name.clone(), source: group.source.clone(), - samples: sentences + samples: sentences, })) - } else { - Err(Status::internal("Failed to select a group")) } } } @@ -124,6 +143,10 @@ struct Args { /// Files to process #[clap(short, long, value_name = "FILE", required = true)] files: Vec, + + /// Causual sampling + #[clap(short, long, default_value = "false")] + causal: bool } #[tokio::main] @@ -132,9 +155,10 @@ async fn main() -> Result<(), Box> { // Parse command-line arguments let args = Args::parse(); + info!("Arguments: {:?}", args); let addr = "127.0.0.1:50051".parse()?; - let data_service = MyDataService::new(args.files)?; + let data_service = MyDataService::new(args.files, args.causal)?; info!("Starting server at {}", addr);