diff --git a/.gitignore b/.gitignore index 67617f0e0..8655d3d21 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,4 @@ meta*.csv stop .vscode *.f90 +settings.json diff --git a/Cargo.toml b/Cargo.toml index 2b2c286e1..57e0dbe0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ csv = "1.2.1" ndarray = { version = "0.15.6", features = ["rayon"] } serde = "1.0.188" serde_derive = "1.0.188" +serde_json = "1.0.66" sobol_burley = "0.5.0" toml = { version = "0.8.1", features = ["preserve_order"] } ode_solvers = "0.3.7" @@ -46,6 +47,7 @@ faer = { version = "0.15.0", features = ["nalgebra", "ndarray"] } tracing = "0.1.40" tracing-subscriber = { version = "0.3.17", features = ["env-filter", "fmt", "time"] } chrono = "0.4" +config = "0.13" [dev-dependencies] criterion = "0.5" diff --git a/examples/bimodal_ke/config.toml b/examples/bimodal_ke/config.toml index 6f758a00f..9e2574a93 100644 --- a/examples/bimodal_ke/config.toml +++ b/examples/bimodal_ke/config.toml @@ -1,18 +1,19 @@ [paths] data = "examples/data/bimodal_ke.csv" -log_out = "log/bimodal_ke.log" -#prior_dist = "theta_bimodal_ke.csv" +log = "log/bimodal_ke.log" +#prior = "theta_bimodal_ke.csv" [config] cycles = 1024 engine = "NPAG" -init_points = 10000 +init_points = 2129 seed = 347 tui = true -pmetrics_outputs = true +output = true cache = true idelta = 0.1 -log_level = "debug" +log_level = "info" + [random] Ke = [0.001, 3.0] diff --git a/src/algorithms.rs b/src/algorithms.rs index a6597aba5..25155f7c1 100644 --- a/src/algorithms.rs +++ b/src/algorithms.rs @@ -1,4 +1,4 @@ -use crate::prelude::{self, settings::run::Settings}; +use crate::prelude::{self, settings::Settings}; use output::NPResult; use prelude::{datafile::Scenario, *}; @@ -9,12 +9,6 @@ mod npag; mod npod; mod postprob; -// pub enum Type { -// NPAG, -// NPOD, -// POSTPROB, -// } - pub trait Algorithm { fn fit(&mut self) -> NPResult; fn to_npresult(&self) -> NPResult; @@ -35,17 +29,17 @@ where Err(err) => panic!("Unable to remove previous stop file: {}", err), } } - let ranges = settings.computed.random.ranges.clone(); + let ranges = settings.random.ranges(); let theta = initialization::sample_space(&settings, &ranges); //This should be a macro, so it can automatically expands as soon as we add a new option in the Type Enum - match settings.parsed.config.engine.as_str() { + match settings.config.engine.as_str() { "NPAG" => Box::new(npag::NPAG::new( engine, ranges, theta, scenarios, - settings.parsed.error.poly, + settings.error.poly, tx, settings, )), @@ -54,7 +48,7 @@ where ranges, theta, scenarios, - settings.parsed.error.poly, + settings.error.poly, tx, settings, )), @@ -62,7 +56,7 @@ where engine, theta, scenarios, - settings.parsed.error.poly, + settings.error.poly, tx, settings, )), diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index 30e69f675..81f690a86 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -7,7 +7,7 @@ use crate::{ output::NPResult, output::{CycleLog, NPCycle}, prob, qr, - settings::run::Settings, + settings::Settings, simulation::predict::Engine, simulation::predict::{sim_obs, Predict}, }, @@ -118,15 +118,15 @@ where f1: f64::default(), cycle: 1, gamma_delta: 0.1, - gamma: settings.parsed.error.value, - error_type: match settings.parsed.error.class.to_lowercase().as_str() { + gamma: settings.error.value, + error_type: match settings.error.class.to_lowercase().as_str() { "additive" => ErrorType::Add, "proportional" => ErrorType::Prop, _ => panic!("Error type not supported"), }, converged: false, - cycle_log: 
CycleLog::new(&settings.computed.random.names), - cache: settings.parsed.config.cache.unwrap_or(false), + cycle_log: CycleLog::new(&settings.random.names()), + cache: settings.config.cache, tx, settings, scenarios, @@ -277,8 +277,6 @@ where gamlam: self.gamma, }; self.tx.send(Comm::NPCycle(state.clone())).unwrap(); - self.cycle_log - .push_and_write(state, self.settings.parsed.config.pmetrics_outputs.unwrap()); // Increasing objf signals instability or model misspecification. if self.last_objf > self.objf { @@ -292,6 +290,9 @@ where self.w = self.lambda.clone(); let pyl = self.psi.dot(&self.w); + self.cycle_log + .push_and_write(state, self.settings.config.output); + // Stop if we have reached convergence criteria if (self.last_objf - self.objf).abs() <= THETA_G && self.eps > THETA_E { self.eps /= 2.; @@ -309,7 +310,7 @@ where } // Stop if we have reached maximum number of cycles - if self.cycle >= self.settings.parsed.config.cycles { + if self.cycle >= self.settings.config.cycles { tracing::warn!("Maximum number of cycles reached"); break; } diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index 9144a8c2f..46949bdb2 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -9,7 +9,7 @@ use crate::{ output::NPResult, output::{CycleLog, NPCycle}, prob, qr, - settings::run::Settings, + settings::Settings, simulation::predict::Engine, simulation::predict::{sim_obs, Predict}, }, @@ -111,15 +111,15 @@ where objf: f64::INFINITY, cycle: 1, gamma_delta: 0.1, - gamma: settings.parsed.error.value, - error_type: match settings.parsed.error.class.as_str() { + gamma: settings.error.value, + error_type: match settings.error.class.as_str() { "additive" => ErrorType::Add, "proportional" => ErrorType::Prop, _ => panic!("Error type not supported"), }, converged: false, - cycle_log: CycleLog::new(&settings.computed.random.names), - cache: settings.parsed.config.cache.unwrap_or(false), + cycle_log: CycleLog::new(&settings.random.names()), + cache: settings.config.cache, tx, settings, scenarios, @@ -296,7 +296,7 @@ where } // Stop if we have reached maximum number of cycles - if self.cycle >= self.settings.parsed.config.cycles { + if self.cycle >= self.settings.config.cycles { tracing::warn!("Maximum number of cycles reached"); break; } @@ -308,7 +308,7 @@ where } //TODO: the cycle migh break before reaching this point self.cycle_log - .push_and_write(state, self.settings.parsed.config.pmetrics_outputs.unwrap()); + .push_and_write(state, self.settings.config.output); self.cycle += 1; diff --git a/src/algorithms/postprob.rs b/src/algorithms/postprob.rs index 714c48904..85e6b1353 100644 --- a/src/algorithms/postprob.rs +++ b/src/algorithms/postprob.rs @@ -6,7 +6,7 @@ use crate::{ ipm, output::NPResult, prob, - settings::run::Settings, + settings::Settings, simulation::predict::Engine, simulation::predict::{sim_obs, Predict}, }, @@ -82,8 +82,8 @@ where objf: f64::INFINITY, cycle: 0, converged: false, - gamma: settings.parsed.error.value, - error_type: match settings.parsed.error.class.as_str() { + gamma: settings.error.value, + error_type: match settings.error.class.as_str() { "additive" => ErrorType::Add, "proportional" => ErrorType::Prop, _ => panic!("Error type not supported"), diff --git a/src/entrypoints.rs b/src/entrypoints.rs index fc55d3b3d..885e9dde7 100644 --- a/src/entrypoints.rs +++ b/src/entrypoints.rs @@ -5,7 +5,7 @@ use crate::prelude::{ *, }; use crate::routines::datafile::Scenario; -use crate::routines::settings::run::Settings; +use crate::routines::settings::*; use 
csv::{ReaderBuilder, WriterBuilder}; use eyre::Result; @@ -35,16 +35,16 @@ pub fn simulate(engine: Engine, settings_path: String) -> Result<()> where S: Predict<'static> + std::marker::Sync + std::marker::Send + 'static + Clone, { - let settings = settings::simulator::read(settings_path); - let theta_file = File::open(settings.paths.theta).unwrap(); + let settings: Settings = read_settings(settings_path).unwrap(); + let theta_file = File::open(settings.paths.prior.unwrap()).unwrap(); let mut reader = ReaderBuilder::new() .has_headers(true) .from_reader(theta_file); let theta: Array2 = reader.deserialize_array2_dynamic().unwrap(); // Expand data - let idelta = settings.config.idelta.unwrap_or(0.0); - let tad = settings.config.tad.unwrap_or(0.0); + let idelta = settings.config.idelta; + let tad = settings.config.tad; let mut scenarios = datafile::parse(&settings.paths.data).unwrap(); scenarios.iter_mut().for_each(|scenario| { *scenario = scenario.add_event_interval(idelta, tad); @@ -88,17 +88,23 @@ where S: Predict<'static> + std::marker::Sync + std::marker::Send + 'static + Clone, { let now = Instant::now(); - let settings = settings::run::read(settings_path); + let settings = match read_settings(settings_path) { + Ok(s) => s, + Err(e) => { + eprintln!("Error reading settings: {:?}", e); + std::process::exit(-1); + } + }; let (tx, rx) = mpsc::unbounded_channel::(); let maintx = tx.clone(); logger::setup_log(&settings, tx.clone()); tracing::info!("Starting NPcore"); // Read input data and remove excluded scenarios (if any) - let mut scenarios = datafile::parse(&settings.parsed.paths.data).unwrap(); - if let Some(exclude) = &settings.parsed.config.exclude { + let mut scenarios = datafile::parse(&settings.paths.data).unwrap(); + if let Some(exclude) = &settings.config.exclude { for val in exclude { - scenarios.remove(val.as_integer().unwrap() as usize); + scenarios.remove(val.as_ptr() as usize); } } @@ -111,7 +117,7 @@ where // Spawn new thread for TUI let settings_tui = settings.clone(); - let handle = if settings.parsed.config.tui { + let handle = if settings.config.tui { spawn(move || { start_ui(rx, settings_tui).expect("Failed to start TUI"); }) @@ -128,10 +134,10 @@ where tracing::info!("Total time: {:.2?}", now.elapsed()); // Write output files (if configured) - if let Some(write) = &settings.parsed.config.pmetrics_outputs { - let idelta = settings.parsed.config.idelta.unwrap_or(0.0); - let tad = settings.parsed.config.tad.unwrap_or(0.0); - result.write_outputs(*write, &engine, idelta, tad); + if settings.config.output { + let idelta = settings.config.idelta; + let tad = settings.config.tad; + result.write_outputs(true, &engine, idelta, tad); } tracing::info!("Program complete"); diff --git a/src/lib.rs b/src/lib.rs index 93dddf25a..8aa99a42e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,10 +16,7 @@ pub mod routines { pub mod adaptative_grid; } - pub mod settings { - pub mod run; - pub mod simulator; - } + pub mod settings; pub mod evaluation { pub mod ipm; diff --git a/src/logger.rs b/src/logger.rs index e14f3f6f7..755fb17ea 100644 --- a/src/logger.rs +++ b/src/logger.rs @@ -1,4 +1,4 @@ -use crate::routines::settings::run::Settings; +use crate::routines::settings::Settings; use crate::tui::ui::Comm; use std::io::{self, Write}; use tokio::sync::mpsc::UnboundedSender; @@ -22,14 +22,10 @@ use tracing_subscriber::EnvFilter; /// If not, the log messages are written to stdout. 
pub fn setup_log(settings: &Settings, ui_tx: UnboundedSender) { // Use the log level defined in configuration file, or default to info - let log_level = settings - .parsed - .config - .log_level - .as_ref() - .map(|level| level.as_str()) - .unwrap_or("info") - .to_lowercase(); + let log_level = settings.config.log_level.as_str(); + + // Use the log file defined in configuration file, or default to npcore.log + let log_path = settings.paths.log.as_ref().unwrap(); let env_filter = EnvFilter::new(&log_level); @@ -41,7 +37,7 @@ pub fn setup_log(settings: &Settings, ui_tx: UnboundedSender) { .create(true) .write(true) .truncate(true) - .open(&settings.parsed.paths.log_out) + .open(log_path) .expect("Failed to open log file - does the directory exist?"); let file_layer = fmt::layer() @@ -50,7 +46,7 @@ pub fn setup_log(settings: &Settings, ui_tx: UnboundedSender) { .with_timer(CompactTimestamp); // Define layer for stdout - let stdout_layer = if !settings.parsed.config.tui { + let stdout_layer = if !settings.config.tui { let layer = fmt::layer() .with_writer(std::io::stdout) .with_ansi(true) @@ -66,7 +62,7 @@ pub fn setup_log(settings: &Settings, ui_tx: UnboundedSender) { ui_tx: ui_tx.clone(), }; - let tui_layer = if settings.parsed.config.tui { + let tui_layer = if settings.config.tui { let layer = fmt::layer() .with_writer(tui_writer_closure) .with_ansi(false) diff --git a/src/routines/evaluation/ipm.rs b/src/routines/evaluation/ipm.rs index 38c6d2816..6a8d0b735 100644 --- a/src/routines/evaluation/ipm.rs +++ b/src/routines/evaluation/ipm.rs @@ -41,7 +41,7 @@ pub fn burke( // if row>col { // return Err("The matrix PSI has row>col".into()); // } - if psi.min().unwrap() < &0.0 { + if psi.min()? < &0.0 { return Err("PSI contains negative elements".into()); } let ecol: ArrayBase, Dim<[usize; 1]>> = Array::ones(col); @@ -55,7 +55,7 @@ pub fn burke( let mut lam = ecol.clone(); let mut w = 1. / &plam; let mut ptw = psi.t().dot(&w); - let shrink = 2. * *ptw.max().unwrap(); + let shrink = 2. * *ptw.max()?; lam *= shrink; plam *= shrink; w /= shrink; @@ -88,10 +88,10 @@ pub fn burke( let dw = dw_aux.column(0); let dy = -psi.t().dot(&dw); let dlam = smuyinv - &lam - inner * &dy; - let mut alfpri = -1. / ((&dlam / &lam).min().unwrap().min(-0.5)); + let mut alfpri = -1. / ((&dlam / &lam).min()?.min(-0.5)); alfpri = (0.99995 * alfpri).min(1.0); - let mut alfdual = -1. / ((&dy / &y).min().unwrap().min(-0.5)); - alfdual = alfdual.min(-1. / (&dw / &w).min().unwrap().min(-0.5)); + let mut alfdual = -1. / ((&dy / &y).min()?.min(-0.5)); + alfdual = alfdual.min(-1. / (&dw / &w).min()?.min(-0.5)); alfdual = (0.99995 * alfdual).min(1.0); lam = lam + alfpri * dlam; w = w + alfdual * &dw; @@ -111,8 +111,7 @@ pub fn burke( (1. - alfdual).powi(2), (norm_r - mu) / (norm_r + 100. * mu) ]] - .max() - .unwrap() + .max()? 
.min(0.3); } } diff --git a/src/routines/initialization.rs b/src/routines/initialization.rs index c8ac91d95..0f0ab7698 100644 --- a/src/routines/initialization.rs +++ b/src/routines/initialization.rs @@ -2,13 +2,14 @@ use std::fs::File; use ndarray::Array2; -use crate::prelude::settings::run::Settings; +use crate::prelude::settings::Settings; pub mod sobol; pub fn sample_space(settings: &Settings, ranges: &Vec<(f64, f64)>) -> Array2 { - match &settings.parsed.paths.prior_dist { + match &settings.paths.prior { Some(prior_path) => { + tracing::info!("Reading prior from {}", prior_path); let file = File::open(prior_path).unwrap(); let mut reader = csv::ReaderBuilder::new() .has_headers(true) @@ -28,12 +29,7 @@ pub fn sample_space(settings: &Settings, ranges: &Vec<(f64, f64)>) -> Array2 = settings - .parsed - .random - .iter() - .map(|(name, _)| name.clone()) - .collect(); + let random_names: Vec = settings.random.names(); let mut reordered_indices: Vec = Vec::new(); for random_name in &random_names { @@ -76,10 +72,6 @@ pub fn sample_space(settings: &Settings, ranges: &Vec<(f64, f64)>) -> Array2 sobol::generate( - settings.parsed.config.init_points, - ranges, - settings.parsed.config.seed, - ), + None => sobol::generate(settings.config.init_points, ranges, settings.config.seed), } } diff --git a/src/routines/initialization/sobol.rs b/src/routines/initialization/sobol.rs index e2ad7b7c0..1eb7e0a00 100644 --- a/src/routines/initialization/sobol.rs +++ b/src/routines/initialization/sobol.rs @@ -8,7 +8,7 @@ use sobol_burley::sample; pub fn generate( n_points: usize, range_params: &Vec<(f64, f64)>, - seed: u32, + seed: usize, ) -> ArrayBase, Dim<[usize; 2]>> { let n_params = range_params.len(); let mut seq = Array::::zeros((n_points, n_params).f()); @@ -16,7 +16,7 @@ pub fn generate( let mut row = seq.slice_mut(s![i, ..]); let mut point: Vec = Vec::new(); for j in 0..n_params { - point.push(sample(i.try_into().unwrap(), j.try_into().unwrap(), seed) as f64) + point.push(sample(i.try_into().unwrap(), j.try_into().unwrap(), seed as u32) as f64) } row.assign(&Array::from(point)); } diff --git a/src/routines/output.rs b/src/routines/output.rs index 3252e72b2..f45fbc067 100644 --- a/src/routines/output.rs +++ b/src/routines/output.rs @@ -4,7 +4,7 @@ use datafile::Scenario; use ndarray::parallel::prelude::*; use ndarray::{Array, Array1, Array2, Axis}; use predict::{post_predictions, sim_obs, Engine, Predict}; -use settings::run::Settings; +use settings::Settings; use std::fs::File; /// Defines the result objects from an NPAG run @@ -36,12 +36,8 @@ impl NPResult { ) -> Self { // TODO: Add support for fixed and constant parameters - let par_names = settings - .parsed - .random - .iter() - .map(|(name, _)| name.clone()) - .collect(); + let par_names = settings.random.names(); + Self { scenarios, theta, @@ -358,7 +354,7 @@ impl CycleWriter { pub fn write(&mut self, cycle: usize, objf: f64, gamma: f64, theta: &Array2) { self.writer.write_field(format!("{}", cycle)).unwrap(); - self.writer.write_field(format!("{}", objf)).unwrap(); + self.writer.write_field(format!("{}", -2. 
* objf)).unwrap(); self.writer.write_field(format!("{}", gamma)).unwrap(); self.writer .write_field(format!("{}", theta.nrows())) diff --git a/src/routines/settings.rs b/src/routines/settings.rs new file mode 100644 index 000000000..49b9c3da2 --- /dev/null +++ b/src/routines/settings.rs @@ -0,0 +1,233 @@ +#![allow(dead_code)] + +use config::Config as eConfig; +use serde::Deserialize; +use serde_derive::Serialize; +use serde_json; +use std::collections::HashMap; + +/// Contains all settings NPcore +#[derive(Debug, Deserialize, Clone, Serialize)] +#[serde(deny_unknown_fields)] +pub struct Settings { + pub paths: Paths, + pub config: Config, + pub random: Random, + pub fixed: Option, + pub constant: Option, + pub error: Error, +} + +/// This struct contains the paths to the data, log and prior files. +#[derive(Debug, Deserialize, Clone, Serialize)] +#[serde(deny_unknown_fields)] +pub struct Paths { + /// Path to the data file, see `datafile::parse` for details. + pub data: String, + /// If provided, the log file will be written to this path. + pub log: Option, + /// If provided, NPcore will use this prior instead of a "uniform" prior, see `sobol::generate` for details. + pub prior: Option, +} + +/// General configuration settings +#[derive(Debug, Deserialize, Clone, Serialize)] +#[serde(deny_unknown_fields)] +pub struct Config { + pub cycles: usize, + pub engine: String, + #[serde(default = "default_seed")] + pub seed: usize, + #[serde(default = "default_10k")] + pub init_points: usize, + #[serde(default = "default_false")] + pub tui: bool, + #[serde(default = "default_true")] + pub output: bool, + #[serde(default = "default_true")] + pub cache: bool, + #[serde(default = "default_idelta")] + pub idelta: f64, + #[serde(default = "default_log_level")] + pub log_level: String, + pub exclude: Option>, + #[serde(default = "default_tad")] + pub tad: f64, +} + +/// Random parameters to be estimated +/// +/// This struct contains the random parameters to be estimated. The parameters are specified as a hashmap, where the key is the name of the parameter, and the value is a tuple containing the upper and lower bounds of the parameter. 
+/// +/// # Example +/// +/// ```toml +/// [random] +/// alpha = [0.0, 1.0] +/// beta = [0.0, 1.0] +/// ``` +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct Random { + #[serde(flatten)] + pub parameters: HashMap, +} + +impl Random { + /// Get the upper and lower bounds of a random parameter from its key + pub fn get(&self, key: &str) -> Option<&(f64, f64)> { + self.parameters.get(key) + } + + /// Returns a vector of tuples containing the names and ranges of the random parameters + pub fn names_and_ranges(&self) -> Vec<(String, (f64, f64))> { + let mut pairs: Vec<(String, (f64, f64))> = self + .parameters + .iter() + .map(|(key, &(upper, lower))| (key.clone(), (upper, lower))) + .collect(); + + // Sorting alphabetically by name + pairs.sort_by(|a, b| a.0.cmp(&b.0)); + + pairs + } + /// Returns a vector of the names of the random parameters + pub fn names(&self) -> Vec { + self.names_and_ranges() + .into_iter() + .map(|(name, _)| name) + .collect() + } + + /// Returns a vector of the upper and lower bounds of the random parameters + pub fn ranges(&self) -> Vec<(f64, f64)> { + self.names_and_ranges() + .into_iter() + .map(|(_, range)| range) + .collect() + } + + /// Validate the boundaries of the random parameters + pub fn validate(&self) -> Result<(), String> { + for (key, &(lower, upper)) in &self.parameters { + if lower >= upper { + return Err(format!( + "In key '{}', lower bound ({}) is not less than upper bound ({})", + key, lower, upper + )); + } + } + Ok(()) + } +} + +/// Parameters which are estimated, but fixed for the population +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct Fixed { + #[serde(flatten)] + pub parameters: HashMap, +} + +/// Parameters which are held constant +#[derive(Debug, Deserialize, Clone, Serialize)] +pub struct Constant { + #[serde(flatten)] + pub parameters: HashMap, +} + +/// Defines the error model and polynomial to be used +#[derive(Debug, Deserialize, Clone, Serialize)] +#[serde(deny_unknown_fields)] +pub struct Error { + pub value: f64, + pub class: String, + pub poly: (f64, f64, f64, f64), +} + +impl Error { + pub fn validate(&self) -> Result<(), String> { + if self.value < 0.0 { + return Err(format!( + "Error value must be non-negative, got {}", + self.value + )); + } + Ok(()) + } +} + +/// Parses the settings from a TOML configuration file +/// +/// This function parses the settings from a TOML configuration file. The settings are validated, and a copy of the settings is written to file. +/// +/// Entries in the TOML file may be overridden by environment variables. The environment variables must be prefixed with `NPCORE_`, and the TOML entry must be in uppercase. For example, the TUI may be disabled by setting the environment variable `NPCORE_TUI=false`. 
+pub fn read_settings(path: String) -> Result { + let settings_path = path; + + let parsed = eConfig::builder() + .add_source(config::File::with_name(&settings_path).format(config::FileFormat::Toml)) + .add_source(config::Environment::with_prefix("NPCORE").separator("_")) + .build()?; + + // Deserialize settings to the Settings struct + let settings: Settings = parsed.try_deserialize()?; + + // Validate entries + settings + .random + .validate() + .map_err(config::ConfigError::Message)?; + settings + .error + .validate() + .map_err(config::ConfigError::Message)?; + + // Write a copy of the settings to file + write_settings_to_file(&settings).expect("Could not write settings to file"); + + Ok(settings) // Return the settings wrapped in Ok +} + +/// Writes a copy of the parsed settings to file +/// +/// This function writes a copy of the parsed settings to file. The file is written to the current working directory, and is named `settings.json`. +pub fn write_settings_to_file(settings: &Settings) -> Result<(), std::io::Error> { + let serialized = serde_json::to_string_pretty(settings) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + + let file_path = "settings.json"; + let mut file = std::fs::File::create(file_path)?; + std::io::Write::write_all(&mut file, serialized.as_bytes())?; + Ok(()) +} + +// ********************************* +// Default values for deserializing +// ********************************* +fn default_true() -> bool { + true +} + +fn default_false() -> bool { + false +} + +fn default_log_level() -> String { + "info".to_string() +} + +fn default_seed() -> usize { + 347 +} + +fn default_idelta() -> f64 { + 0.12 +} + +fn default_tad() -> f64 { + 0.0 +} + +fn default_10k() -> usize { + 10_000 +} diff --git a/src/routines/settings/run.rs b/src/routines/settings/run.rs deleted file mode 100644 index 56ad18a22..000000000 --- a/src/routines/settings/run.rs +++ /dev/null @@ -1,152 +0,0 @@ -use serde_derive::Deserialize; -use std::fs; -use std::process::exit; -use toml::value::Array; -use toml::{self, Table}; - -/// Settings used for algorithm execution -/// -/// The user can specify the desired settings in a TOML configuration file, see `routines::settings::run` for details. -#[derive(Deserialize, Clone, Debug)] -pub struct Settings { - pub computed: Computed, - pub parsed: Parsed, -} - -#[derive(Deserialize, Clone, Debug)] -pub struct Computed { - pub random: Range, - pub constant: Single, - pub fixed: Single, -} - -/// The `Error` struct is used to specify the error model -/// - `value`: the value of the error -/// - `class`: the class of the error, can be either `additive` or `proportional` -/// - `poly`: the polynomial coefficients of the error model -/// -/// For more information see `routines::evaluation::sigma` -#[derive(Deserialize, Clone, Debug)] -pub struct Error { - pub value: f64, - pub class: String, - pub poly: (f64, f64, f64, f64), -} - -#[derive(Deserialize, Clone, Debug)] -pub struct Range { - pub names: Vec, - pub ranges: Vec<(f64, f64)>, -} - -#[derive(Deserialize, Clone, Debug)] -pub struct Single { - pub names: Vec, - pub values: Vec, -} - -#[derive(Deserialize, Clone, Debug)] -pub struct Parsed { - pub paths: Paths, - pub config: Config, - pub random: Table, - pub fixed: Option, - pub constant: Option
, - pub error: Error, -} - -#[derive(Deserialize, Clone, Debug)] -pub struct Paths { - pub data: String, - pub log_out: String, - pub prior_dist: Option, -} - -#[derive(Deserialize, Clone, Debug)] -pub struct Config { - pub cycles: usize, - pub engine: String, - pub init_points: usize, - pub seed: u32, - pub tui: bool, - pub pmetrics_outputs: Option, - pub exclude: Option, - pub cache: Option, - pub idelta: Option, - pub tad: Option, - pub log_level: Option, -} - -/// Read and parse settings from a TOML configuration file -pub fn read(filename: String) -> Settings { - let contents = match fs::read_to_string(&filename) { - Ok(c) => c, - Err(e) => { - eprintln!("{}", e); - eprintln!("ERROR: Could not read file {}", &filename); - exit(1); - } - }; - - let parsed: Parsed = match toml::from_str(&contents) { - Ok(d) => d, - Err(e) => { - eprintln!("{}", e); - eprintln!("ERROR: Unable to load data from {}", &filename); - exit(1); - } - }; - //Pri - let mut pr = vec![]; - let mut pn = vec![]; - for (name, range) in &parsed.random { - let range = range.as_array().unwrap(); - if range.len() != 2 { - eprintln!( - "ERROR: Ranges can only have 2 elements, {} found", - range.len() - ); - eprintln!("ERROR: In {:?}: {:?}", name, range); - exit(1); - } - pn.push(name.clone()); - pr.push((range[0].as_float().unwrap(), range[1].as_float().unwrap())); - } - //Constant - let mut cn = vec![]; - let mut cv = vec![]; - if let Some(constant) = &parsed.constant { - for (name, value) in constant { - cn.push(name.clone()); - cv.push(value.as_float().unwrap()); - } - } - - //Randfix - let mut rn = vec![]; - let mut rv = vec![]; - if let Some(randfix) = &parsed.fixed { - for (name, value) in randfix { - rn.push(name.clone()); - rv.push(value.as_float().unwrap()); - } - } - - Settings { - computed: Computed { - random: Range { - names: pn, - ranges: pr, - }, - constant: Single { - names: cn, - values: cv, - }, - fixed: Single { - names: rn, - values: rv, - }, - }, - parsed, - } -} diff --git a/src/routines/settings/simulator.rs b/src/routines/settings/simulator.rs deleted file mode 100644 index 3742f43a4..000000000 --- a/src/routines/settings/simulator.rs +++ /dev/null @@ -1,51 +0,0 @@ -use serde_derive::Deserialize; -use std::fs; -use std::process::exit; -use toml; - -#[derive(Deserialize, Clone, Debug)] -pub struct Settings { - pub paths: Paths, - pub config: Config, -} - -#[derive(Deserialize, Clone, Debug)] -pub struct Paths { - pub data: String, - pub theta: String, -} - -#[derive(Deserialize, Clone, Debug)] -pub struct Config { - pub idelta: Option, - pub tad: Option, -} - -pub fn read(filename: String) -> Settings { - let contents = match fs::read_to_string(&filename) { - Ok(c) => c, - Err(e) => { - eprintln!("{}", e); - eprintln!("ERROR: Could not read file {}", &filename); - exit(1); - } - }; - let parse: Settings = match toml::from_str(&contents) { - Ok(d) => d, - Err(e) => { - eprintln!("{}", e); - eprintln!("ERROR: Unable to load data from {}", &filename); - exit(1); - } - }; - Settings { - paths: Paths { - data: parse.paths.data, - theta: parse.paths.theta, - }, - config: Config { - idelta: parse.config.idelta, - tad: parse.config.tad, - }, - } -} diff --git a/src/tests/config.toml b/src/tests/config.toml index e68034516..eb1bf733d 100644 --- a/src/tests/config.toml +++ b/src/tests/config.toml @@ -1,6 +1,6 @@ [paths] data = "data.csv" -log_out = "test.log" +log = "test.log" [config] cycles = 1024 @@ -8,7 +8,7 @@ engine = "NPAG" init_points = 500 seed = 347 tui = false -pmetrics_outputs = true +output = true 
[random] ka = [0.1, 0.9] diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 74c5a45e1..b60eb689d 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -33,45 +33,6 @@ fn scaled_sobol() { ) } -#[test] -fn read_mandatory_settings() { - let settings = settings::run::read("src/tests/config.toml".to_string()); - assert_eq!(settings.parsed.paths.data, "data.csv"); - assert_eq!(settings.parsed.config.cycles, 1024); - assert_eq!(settings.parsed.config.engine, "NPAG"); -} - -#[test] -fn read_parameter_names() { - let settings = settings::run::read("src/tests/config.toml".to_string()); - assert_eq!(settings.computed.random.names, vec!["ka", "ke", "v"]); -} - -#[test] -fn read_parameter_ranges() { - let settings = settings::run::read("src/tests/config.toml".to_string()); - - assert_eq!( - settings.computed.random.ranges, - vec![(0.1, 0.9), (0.001, 0.1), (30.0, 120.0)] - ); -} - -#[test] -fn read_randfix() { - let settings = settings::run::read("src/tests/config.toml".to_string()); - assert_eq!(settings.computed.fixed.names, vec!["KCP", "KPC"]); - assert_eq!(settings.computed.fixed.values, vec![5.1, 2.0]); -} - -#[test] -fn read_error() { - let settings = settings::run::read("src/tests/config.toml".to_string()); - assert_eq!(settings.parsed.error.value, 0.5); - assert_eq!(settings.parsed.error.class, "additive"); - assert_eq!(settings.parsed.error.poly, (0.0, 0.5, 0.0, 0.0)) -} - #[test] fn read_test_datafile() { let scenarios = datafile::parse(&"src/tests/test.csv".to_string()); diff --git a/src/tui/components.rs b/src/tui/components.rs index 567be651f..ebe1e103d 100644 --- a/src/tui/components.rs +++ b/src/tui/components.rs @@ -15,7 +15,7 @@ use ratatui::{ use super::App; -use crate::prelude::settings::run::Settings; +use crate::prelude::settings::Settings; pub fn draw_title<'a>() -> Paragraph<'a> { Paragraph::new("NPcore Execution") @@ -78,17 +78,16 @@ pub fn draw_status<'a>(app: &App, elapsed_time: Duration) -> Table<'a> { pub fn draw_options<'a>(settings: &Settings) -> Table<'a> { // Define the table data - let cycles = settings.parsed.config.cycles.to_string(); - let engine = settings.parsed.config.engine.to_string(); + let cycles = settings.config.cycles.to_string(); + let engine = settings.config.engine.to_string(); let conv_crit = "Placeholder".to_string(); - let indpts = settings.parsed.config.init_points.to_string(); - let error = settings.parsed.error.class.to_string(); - let cache = match settings.parsed.config.cache { - Some(true) => "Yes".to_string(), - Some(false) => "No".to_string(), - None => "Not set".to_string(), + let indpts = settings.config.init_points.to_string(); + let error = settings.error.class.to_string(); + let cache = match settings.config.cache { + true => "Enabled".to_string(), + false => "Disabled".to_string(), }; - let seed = settings.parsed.config.seed.to_string(); + let seed = settings.config.seed.to_string(); let data = vec![ ("Maximum cycles", cycles), @@ -255,64 +254,6 @@ pub fn draw_tabs<'a>(app: &App) -> Tabs<'a> { tabs } -fn get_computed_settings(settings: &Settings) -> Vec { - let computed = settings.computed.clone(); - let mut rows = Vec::new(); - let key_style = Style::default().fg(Color::LightCyan); - let help_style = Style::default().fg(Color::Gray); - - // Iterate over the random ranges - for (name, &(start, end)) in computed.random.names.iter().zip(&computed.random.ranges) { - let row = Row::new(vec![ - Cell::from(Span::styled(name.to_string(), key_style)), - Cell::from(Span::styled( - format!("{:.2} - {:.2}", start, end), - help_style, - )), - 
]); - rows.push(row); - } - - // Iterate over the constant values - for (name, &value) in computed - .constant - .names - .iter() - .zip(&computed.constant.values) - { - let row = Row::new(vec![ - Cell::from(Span::styled(name.to_string(), key_style)), - Cell::from(Span::styled(format!("{:.2} (Constant)", value), help_style)), - ]); - rows.push(row); - } - - // Iterate over the fixed values - for (name, &value) in computed.fixed.names.iter().zip(&computed.fixed.values) { - let row = Row::new(vec![ - Cell::from(Span::styled(name.to_string(), key_style)), - Cell::from(Span::styled(format!("{:.2} (Fixed)", value), help_style)), - ]); - rows.push(row); - } - - rows -} - -pub fn draw_parameter_bounds(settings: &Settings) -> Table { - let rows = get_computed_settings(&settings); - Table::default() - .rows(rows) - .block( - Block::default() - .borders(Borders::ALL) - .border_type(BorderType::Plain) - .title(" Parameters "), - ) - .widths(&[Constraint::Percentage(20), Constraint::Percentage(80)]) // Set percentage widths for columns - .column_spacing(1) -} - fn format_time(elapsed_time: std::time::Duration) -> String { let elapsed_seconds = elapsed_time.as_secs(); let (elapsed, unit) = if elapsed_seconds < 60 { diff --git a/src/tui/ui.rs b/src/tui/ui.rs index 5b323800b..0fcd5b092 100644 --- a/src/tui/ui.rs +++ b/src/tui/ui.rs @@ -28,10 +28,11 @@ pub enum Comm { LogMessage(String), } -use crate::prelude::{output::NPCycle, settings::run::Settings}; +use crate::prelude::{output::NPCycle, settings::Settings}; use crate::tui::components::*; pub fn start_ui(mut rx: UnboundedReceiver, settings: Settings) -> Result<()> { + initialize_panic_handler(); let mut stdout = stdout(); execute!(stdout, crossterm::terminal::EnterAlternateScreen)?; crossterm::terminal::enable_raw_mode()?; @@ -228,9 +229,21 @@ pub fn draw( rect.render_widget(plot, tab_layout[1]); } 2 => { - let par_bounds = draw_parameter_bounds(&settings); - rect.render_widget(par_bounds, tab_layout[1]); + // TODO: Return this to show the parameter boundaries + let plot = draw_plot(&mut norm_data); + rect.render_widget(plot, tab_layout[1]); } _ => unreachable!(), }; } + +// From https://ratatui.rs/how-to/develop-apps/panic-hooks/ +pub fn initialize_panic_handler() { + let original_hook = std::panic::take_hook(); + std::panic::set_hook(Box::new(move |panic_info| { + crossterm::execute!(std::io::stderr(), crossterm::terminal::LeaveAlternateScreen).unwrap(); + crossterm::terminal::disable_raw_mode().unwrap(); + crossterm::terminal::Clear(crossterm::terminal::ClearType::All); + original_hook(panic_info); + })); +}
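
The diff above collapses the old `settings::run` / `settings::simulator` split into a single flat `Settings` struct read by `read_settings`. A minimal usage sketch of that new API follows; the crate/module path (`npcore::routines::settings`) and the config path are assumptions for illustration and are not part of the patch.

```rust
// Hypothetical usage sketch of the consolidated settings API from this diff.
// Assumed: the crate is importable as `npcore` and re-exports `routines::settings`.
use npcore::routines::settings::{read_settings, Settings};

fn main() {
    // read_settings parses the TOML file, applies NPCORE_* environment overrides,
    // validates [random] bounds and [error], and writes a settings.json copy.
    let settings: Settings = read_settings("examples/bimodal_ke/config.toml".to_string())
        .expect("failed to read settings");

    // Names and ranges are derived from the flattened [random] table,
    // sorted alphabetically by parameter name.
    let names = settings.random.names();
    let ranges = settings.random.ranges();
    println!("random parameters: {:?} -> {:?}", names, ranges);

    // Keys omitted from the TOML fall back to the serde defaults
    // (e.g. seed = 347, init_points = 10_000).
    println!("cycles = {}, output = {}", settings.config.cycles, settings.config.output);
}
```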