From 96e84f732b76f45c61fa4ceecc0cba2d08bf341c Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Tue, 24 Dec 2024 14:50:13 +0100 Subject: [PATCH] Experimenting with algorithm choice and dispatch --- src/algorithms/map.rs | 1 + src/algorithms/mod.rs | 33 ++++++++++++----------------- src/algorithms/npag.rs | 2 +- src/algorithms/npod.rs | 1 + src/algorithms/routines/settings.rs | 11 ++++------ 5 files changed, 20 insertions(+), 28 deletions(-) diff --git a/src/algorithms/map.rs b/src/algorithms/map.rs index 8a336cff..77463b09 100644 --- a/src/algorithms/map.rs +++ b/src/algorithms/map.rs @@ -15,6 +15,7 @@ use super::{initialization, output::CycleLog, NonParametricAlgorithm}; /// Maximum a posteriori (MAP) estimation /// /// Calculate the MAP estimate of the parameters of the model given the data. +#[derive(Debug, Clone)] pub struct MAP { equation: E, psi: Array2, diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 12b41b33..f1fbb0cf 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -3,7 +3,7 @@ use std::path::Path; use crate::prelude::{self, settings::Settings}; -use anyhow::{bail, Result}; +use anyhow::Result; use anyhow::{Context, Error}; use map::MAP; use ndarray::Array2; @@ -23,26 +23,14 @@ pub mod routines; /// Supported algorithms by `PMcore` #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] pub enum Algorithm { - NonParametric(NonParametric), - Parametric(Parametric), -} - -/// Supported non-parametric algorithms -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] -pub enum NonParametric { + // Non-parametric algorithms NPAG, NPOD, MAP, + // Parametric algorithms } -/// Supported parametric algorithms -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] -pub enum Parametric { - FOCE, - NPSA, -} - -/// This traint defines the methods for non-parametric (NP) algorithms +/// This trait defines the methods for non-parametric (NP) algorithms pub trait NonParametricAlgorithm { fn new(config: Settings, equation: E, data: Data) -> Result, Error> where @@ -122,15 +110,20 @@ pub trait NonParametricAlgorithm { fn into_npresult(&self) -> NPResult; } +pub trait ParametricAlgorithm { + fn fit(&mut self) -> Result<()> { + unimplemented!() + } +} + pub fn dispatch_algorithm( settings: Settings, equation: E, data: Data, ) -> Result>> { match settings.config().algorithm { - Algorithm::NonParametric(NonParametric::NPAG) => Ok(NPAG::new(settings, equation, data)?), - Algorithm::NonParametric(NonParametric::NPOD) => Ok(NPOD::new(settings, equation, data)?), - Algorithm::NonParametric(NonParametric::MAP) => Ok(MAP::new(settings, equation, data)?), - _ => bail!("Unsupported algorithm"), + Algorithm::NPAG => Ok(NPAG::new(settings, equation, data)?), + Algorithm::NPOD => Ok(NPOD::new(settings, equation, data)?), + Algorithm::MAP => Ok(MAP::new(settings, equation, data)?), } } diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index 95390fc1..8b0f71cf 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -24,7 +24,7 @@ const THETA_G: f64 = 1e-4; // Objective function convergence criteria const THETA_F: f64 = 1e-2; const THETA_D: f64 = 1e-4; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct NPAG { equation: E, psi: Array2, diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index 3f80a4df..f3817c4d 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -28,6 +28,7 @@ use super::{ const THETA_F: f64 = 1e-2; const THETA_D: f64 = 1e-4; +#[derive(Debug, Clone)] pub struct NPOD { equation: E, psi: Array2, diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index 500dac15..d6ec5fdf 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -227,7 +227,7 @@ impl Default for Config { fn default() -> Self { Config { cycles: 100, - algorithm: Algorithm::NonParametric(crate::algorithms::NonParametric::NPAG), + algorithm: Algorithm::NPAG, cache: true, } } @@ -793,7 +793,7 @@ impl SettingsBuilder { mod tests { use super::*; - use crate::algorithms::{Algorithm, NonParametric}; + use crate::algorithms::Algorithm; use pharmsol::prelude::data::ErrorType; #[test] @@ -805,7 +805,7 @@ mod tests { .unwrap(); let settings = SettingsBuilder::new() - .set_algorithm(Algorithm::NonParametric(NonParametric::NPAG)) // Step 1: Define algorithm + .set_algorithm(Algorithm::NPAG) // Step 1: Define algorithm .set_parameters(parameters) // Step 2: Define parameters .set_error_model(Error { value: 0.1, @@ -814,9 +814,6 @@ mod tests { }) // Step 3: Define error model .build(); // Final step - assert_eq!( - settings.config.algorithm, - Algorithm::NonParametric(NonParametric::NPAG,) - ); + assert_eq!(settings.config.algorithm, Algorithm::NPAG); } }