Skip to content

Commit

Permalink
Merge pull request #67 from LAPKB/typestate_settings_extended
Browse files Browse the repository at this point in the history
Minor changes to algorithm dispatch and choice
  • Loading branch information
mhovd authored Dec 24, 2024
2 parents 9b0d809 + 96e84f7 commit 4f807b4
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 28 deletions.
1 change: 1 addition & 0 deletions src/algorithms/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use super::{initialization, output::CycleLog, NonParametricAlgorithm};
/// Maximum a posteriori (MAP) estimation
///
/// Calculate the MAP estimate of the parameters of the model given the data.
#[derive(Debug, Clone)]
pub struct MAP<E: Equation> {
equation: E,
psi: Array2<f64>,
Expand Down
33 changes: 13 additions & 20 deletions src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::path::Path;

use crate::prelude::{self, settings::Settings};

use anyhow::{bail, Result};
use anyhow::Result;
use anyhow::{Context, Error};
use map::MAP;
use ndarray::Array2;
Expand All @@ -23,26 +23,14 @@ pub mod routines;
/// Supported algorithms by `PMcore`
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
pub enum Algorithm {
NonParametric(NonParametric),
Parametric(Parametric),
}

/// Supported non-parametric algorithms
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
pub enum NonParametric {
// Non-parametric algorithms
NPAG,
NPOD,
MAP,
// Parametric algorithms
}

/// Supported parametric algorithms
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
pub enum Parametric {
FOCE,
NPSA,
}

/// This traint defines the methods for non-parametric (NP) algorithms
/// This trait defines the methods for non-parametric (NP) algorithms
pub trait NonParametricAlgorithm<E: Equation> {
fn new(config: Settings, equation: E, data: Data) -> Result<Box<Self>, Error>
where
Expand Down Expand Up @@ -122,15 +110,20 @@ pub trait NonParametricAlgorithm<E: Equation> {
fn into_npresult(&self) -> NPResult<E>;
}

pub trait ParametricAlgorithm<E: Equation> {
fn fit(&mut self) -> Result<()> {
unimplemented!()
}
}

pub fn dispatch_algorithm<E: Equation>(
settings: Settings,
equation: E,
data: Data,
) -> Result<Box<dyn NonParametricAlgorithm<E>>> {
match settings.config().algorithm {
Algorithm::NonParametric(NonParametric::NPAG) => Ok(NPAG::new(settings, equation, data)?),
Algorithm::NonParametric(NonParametric::NPOD) => Ok(NPOD::new(settings, equation, data)?),
Algorithm::NonParametric(NonParametric::MAP) => Ok(MAP::new(settings, equation, data)?),
_ => bail!("Unsupported algorithm"),
Algorithm::NPAG => Ok(NPAG::new(settings, equation, data)?),
Algorithm::NPOD => Ok(NPOD::new(settings, equation, data)?),
Algorithm::MAP => Ok(MAP::new(settings, equation, data)?),
}
}
2 changes: 1 addition & 1 deletion src/algorithms/npag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const THETA_G: f64 = 1e-4; // Objective function convergence criteria
const THETA_F: f64 = 1e-2;
const THETA_D: f64 = 1e-4;

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct NPAG<E: Equation> {
equation: E,
psi: Array2<f64>,
Expand Down
1 change: 1 addition & 0 deletions src/algorithms/npod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use super::{
const THETA_F: f64 = 1e-2;
const THETA_D: f64 = 1e-4;

#[derive(Debug, Clone)]
pub struct NPOD<E: Equation> {
equation: E,
psi: Array2<f64>,
Expand Down
11 changes: 4 additions & 7 deletions src/algorithms/routines/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ impl Default for Config {
fn default() -> Self {
Config {
cycles: 100,
algorithm: Algorithm::NonParametric(crate::algorithms::NonParametric::NPAG),
algorithm: Algorithm::NPAG,
cache: true,
}
}
Expand Down Expand Up @@ -793,7 +793,7 @@ impl SettingsBuilder<ErrorSet> {

mod tests {
use super::*;
use crate::algorithms::{Algorithm, NonParametric};
use crate::algorithms::Algorithm;
use pharmsol::prelude::data::ErrorType;

#[test]
Expand All @@ -805,7 +805,7 @@ mod tests {
.unwrap();

let settings = SettingsBuilder::new()
.set_algorithm(Algorithm::NonParametric(NonParametric::NPAG)) // Step 1: Define algorithm
.set_algorithm(Algorithm::NPAG) // Step 1: Define algorithm
.set_parameters(parameters) // Step 2: Define parameters
.set_error_model(Error {
value: 0.1,
Expand All @@ -814,9 +814,6 @@ mod tests {
}) // Step 3: Define error model
.build(); // Final step

assert_eq!(
settings.config.algorithm,
Algorithm::NonParametric(NonParametric::NPAG,)
);
assert_eq!(settings.config.algorithm, Algorithm::NPAG);
}
}

0 comments on commit 4f807b4

Please sign in to comment.