diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f82cd4fdd..a4b364abe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -181,11 +181,11 @@ jobs: - name: Install wasm-pack run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh - name: Build target wasm32-unknown-unknown - run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --target wasm32-unknown-unknown --features wasm-bindgen + run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --exclude argmin-py --target wasm32-unknown-unknown --features wasm-bindgen - name: Build target wasm32-wasi with feature wasm-bindgen - run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --target wasm32-wasi --features wasm-bindgen + run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --exclude argmin-py --target wasm32-wasi --features wasm-bindgen - name: Build target wasm32-unknown-emscripten - run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --target wasm32-unknown-emscripten --no-default-features --features wasm-bindgen + run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --exclude argmin-py --target wasm32-unknown-emscripten --no-default-features --features wasm-bindgen cargo-deny: runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 77cf2b7bd..be6745289 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ target/ target/* *.log justfile +.vscode +.venv diff --git a/crates/argmin-py/Cargo.toml b/crates/argmin-py/Cargo.toml new file mode 100644 index 000000000..6eb8cd5ff --- /dev/null +++ b/crates/argmin-py/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "argmin-py" +version = "0.1.0" +authors = ["Joris Bayer or the MIT license , at your option. This file may not be +# copied, modified, or distributed except according to those terms. +from argmin import Problem, Solver, Executor +import numpy as np +from scipy.optimize import rosen_der, rosen_hess + + +def main(): + problem = Problem( + gradient=rosen_der, + hessian=rosen_hess, + ) + solver = Solver.Newton + executor = Executor(problem, solver) + executor.configure(param=np.array([-1.2, 1.0]), max_iters=8) + + result = executor.run() + print(result) + + +if __name__ == "__main__": + main() diff --git a/crates/argmin-py/src/executor.rs b/crates/argmin-py/src/executor.rs new file mode 100644 index 000000000..94b5c3f4e --- /dev/null +++ b/crates/argmin-py/src/executor.rs @@ -0,0 +1,71 @@ +// Copyright 2018-2023 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +// TODO: docs + +use pyo3::{prelude::*, types::PyDict}; + +use argmin::core; + +use crate::problem::Problem; +use crate::solver::{DynamicSolver, Solver}; +use crate::types::{IterState, PyArray1}; + +#[pyclass] +pub struct Executor(Option>); + +impl Executor { + /// Consumes the inner executor. + /// + /// PyObjects do not allow methods that consume the object itself, so this is a workaround + /// for using methods like `configure` and `run`. + fn take(&mut self) -> anyhow::Result> { + let Some(inner) = self.0.take() else { + return Err(anyhow::anyhow!("Executor was already run.")); + }; + Ok(inner) + } +} + +#[pymethods] +impl Executor { + #[new] + fn new(problem: Problem, solver: Solver) -> Self { + Self(Some(core::Executor::new(problem, solver.into()))) + } + + #[pyo3(signature = (**kwargs))] + fn configure(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> { + if let Some(kwargs) = kwargs { + let param = kwargs + .get_item("param")? + .map(|x| x.extract::<&PyArray1>()) + .map_or(Ok(None), |r| r.map(Some))?; + let max_iters = kwargs + .get_item("max_iters")? + .map(|x| x.extract()) + .map_or(Ok(None), |r| r.map(Some))?; + + self.0 = Some(self.take()?.configure(|mut state| { + if let Some(param) = param { + state = state.param(param.to_owned_array()); + } + if let Some(max_iters) = max_iters { + state = state.max_iters(max_iters); + } + state + })); + } + Ok(()) + } + + fn run(&mut self) -> PyResult { + // TODO: return usable OptimizationResult + let res = self.take()?.run(); + Ok(res?.to_string()) + } +} diff --git a/crates/argmin-py/src/lib.rs b/crates/argmin-py/src/lib.rs new file mode 100644 index 000000000..fa3a21e15 --- /dev/null +++ b/crates/argmin-py/src/lib.rs @@ -0,0 +1,24 @@ +// Copyright 2018-2023 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +// TODO: docs +mod executor; +mod problem; +mod solver; +mod types; + +use pyo3::prelude::*; + +#[pymodule] +#[pyo3(name = "argmin")] +fn argmin_py(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + Ok(()) +} diff --git a/crates/argmin-py/src/problem.rs b/crates/argmin-py/src/problem.rs new file mode 100644 index 000000000..93d6aae2a --- /dev/null +++ b/crates/argmin-py/src/problem.rs @@ -0,0 +1,68 @@ +// Copyright 2018-2023 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +// TODO: docs + +use numpy::ToPyArray; +use pyo3::{prelude::*, types::PyTuple}; + +use argmin::core; + +use crate::types::{Array1, Array2, Scalar}; + +#[pyclass] +#[derive(Clone)] +pub struct Problem { + gradient: PyObject, + hessian: PyObject, + // TODO: jacobian +} + +#[pymethods] +impl Problem { + #[new] + fn new(gradient: PyObject, hessian: PyObject) -> Self { + Self { gradient, hessian } + } +} + +impl core::Gradient for Problem { + type Param = Array1; + type Gradient = Array1; + + fn gradient(&self, param: &Self::Param) -> Result { + call(&self.gradient, param) + } +} + +impl argmin::core::Hessian for Problem { + type Param = Array1; + + type Hessian = Array2; + + fn hessian(&self, param: &Self::Param) -> Result { + call(&self.hessian, param) + } +} + +fn call( + callable: &PyObject, + param: &ndarray::Array, +) -> Result, argmin::core::Error> +where + InputDimension: ndarray::Dimension, + OutputDimension: ndarray::Dimension, +{ + // TODO: prevent dynamic dispatch for every call + Python::with_gil(|py| { + let args = PyTuple::new(py, [param.to_pyarray(py)]); + let pyresult = callable.call(py, args, Default::default())?; + let pyarray = pyresult.extract::<&numpy::PyArray>(py)?; + // TODO: try to get ownership instead of cloning + Ok(pyarray.to_owned_array()) + }) +} diff --git a/crates/argmin-py/src/solver.rs b/crates/argmin-py/src/solver.rs new file mode 100644 index 000000000..0cae464b1 --- /dev/null +++ b/crates/argmin-py/src/solver.rs @@ -0,0 +1,46 @@ +// Copyright 2018-2023 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +// TODO: docs + +use pyo3::prelude::*; + +use argmin::{core, solver}; + +use crate::{problem::Problem, types::IterState}; + +#[pyclass] +#[derive(Clone)] +pub enum Solver { + Newton, +} + +pub struct DynamicSolver(Box + Send>); + +impl From for DynamicSolver { + fn from(value: Solver) -> Self { + let inner = match value { + Solver::Newton => solver::newton::Newton::new(), + }; + Self(Box::new(inner)) + } +} + +impl core::Solver for DynamicSolver { + // TODO: make this a trait method so we can return a dynamic + fn name(&self) -> &str { + self.0.name() + } + + fn next_iter( + &mut self, + problem: &mut core::Problem, + state: IterState, + ) -> Result<(IterState, Option), core::Error> { + self.0.next_iter(problem, state) + } +} diff --git a/crates/argmin-py/src/types.rs b/crates/argmin-py/src/types.rs new file mode 100644 index 000000000..57cab8f97 --- /dev/null +++ b/crates/argmin-py/src/types.rs @@ -0,0 +1,16 @@ +// Copyright 2018-2023 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +//! Base types for the Python extension. + +pub type Scalar = f64; // TODO: allow complex numbers +pub type Array1 = ndarray::Array1; +pub type Array2 = ndarray::Array2; +pub type PyArray1 = numpy::PyArray1; + +pub type IterState = + argmin::core::IterState, (), Scalar>; diff --git a/crates/argmin/src/core/checkpointing/mod.rs b/crates/argmin/src/core/checkpointing/mod.rs index e87ae2d73..674a81452 100644 --- a/crates/argmin/src/core/checkpointing/mod.rs +++ b/crates/argmin/src/core/checkpointing/mod.rs @@ -155,7 +155,7 @@ use std::fmt::Display; /// } /// # fn main() {} /// ``` -pub trait Checkpoint { +pub trait Checkpoint: Send { /// Save a checkpoint /// /// Gets a reference to the current `solver` of type `S` and to the current `state` of type diff --git a/crates/argmin/src/core/executor.rs b/crates/argmin/src/core/executor.rs index a4c546504..236ac053d 100644 --- a/crates/argmin/src/core/executor.rs +++ b/crates/argmin/src/core/executor.rs @@ -180,7 +180,8 @@ where let kv = kv.unwrap_or(kv![]); // Observe after init - self.observers.observe_init(S::NAME, &state, &kv)?; + self.observers + .observe_init(self.solver.name(), &state, &kv)?; } state.func_counts(&self.problem); @@ -681,7 +682,9 @@ mod tests { P: Clone, F: ArgminFloat, { - const NAME: &'static str = "OptimizationAlgorithm"; + fn name(&self) -> &str { + "OptimizationAlgorithm" + } // Only resets internal_state to 1 fn init( diff --git a/crates/argmin/src/core/observers/mod.rs b/crates/argmin/src/core/observers/mod.rs index 1d2157ee0..89a1b2e8b 100644 --- a/crates/argmin/src/core/observers/mod.rs +++ b/crates/argmin/src/core/observers/mod.rs @@ -149,7 +149,7 @@ use std::sync::{Arc, Mutex}; /// } /// } /// ``` -pub trait Observe { +pub trait Observe: Send { /// Called once after initialization of the solver. /// /// Has access to the name of the solver via `name`, the initial `state` and to a key-value diff --git a/crates/argmin/src/core/result.rs b/crates/argmin/src/core/result.rs index 10eb309ef..3ca346248 100644 --- a/crates/argmin/src/core/result.rs +++ b/crates/argmin/src/core/result.rs @@ -124,7 +124,7 @@ where { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { writeln!(f, "OptimizationResult:")?; - writeln!(f, " Solver: {}", S::NAME)?; + writeln!(f, " Solver: {}", self.solver().name())?; writeln!( f, " param (best): {}", diff --git a/crates/argmin/src/core/solver.rs b/crates/argmin/src/core/solver.rs index d3a24bde7..394d687d8 100644 --- a/crates/argmin/src/core/solver.rs +++ b/crates/argmin/src/core/solver.rs @@ -36,7 +36,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K /// P: Clone, /// F: ArgminFloat /// { -/// const NAME: &'static str = "OptimizationAlgorithm"; +/// fn name(&self) -> &str { "OptimizationAlgorithm" } /// /// fn init( /// &mut self, @@ -67,7 +67,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K /// ``` pub trait Solver { /// Name of the solver. Mainly used in [Observers](`crate::core::observers::Observe`). - const NAME: &'static str; + // const NAME: &'static str; /// Initializes the algorithm. /// @@ -117,4 +117,7 @@ pub trait Solver { fn terminate(&mut self, _state: &I) -> TerminationStatus { TerminationStatus::NotTerminated } + + /// Returns the name of the solver. + fn name(&self) -> &str; } diff --git a/crates/argmin/src/core/test_utils.rs b/crates/argmin/src/core/test_utils.rs index b73ca1252..1803025e4 100644 --- a/crates/argmin/src/core/test_utils.rs +++ b/crates/argmin/src/core/test_utils.rs @@ -327,7 +327,9 @@ impl TestSolver { } impl Solver, (), (), (), (), f64>> for TestSolver { - const NAME: &'static str = "TestSolver"; + fn name(&self) -> &str { + "TestSolver" + } fn next_iter( &mut self, diff --git a/crates/argmin/src/solver/brent/brentopt.rs b/crates/argmin/src/solver/brent/brentopt.rs index 4a0995f08..9b9f8c574 100644 --- a/crates/argmin/src/solver/brent/brentopt.rs +++ b/crates/argmin/src/solver/brent/brentopt.rs @@ -102,7 +102,9 @@ where O: CostFunction, F: ArgminFloat, { - const NAME: &'static str = "BrentOpt"; + fn name(&self) -> &str { + "BrentOpt" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/brent/brentroot.rs b/crates/argmin/src/solver/brent/brentroot.rs index d20df595e..b8752f56f 100644 --- a/crates/argmin/src/solver/brent/brentroot.rs +++ b/crates/argmin/src/solver/brent/brentroot.rs @@ -80,7 +80,9 @@ where O: CostFunction, F: ArgminFloat, { - const NAME: &'static str = "BrentRoot"; + fn name(&self) -> &str { + "BrentRoot" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/conjugategradient/cg.rs b/crates/argmin/src/solver/conjugategradient/cg.rs index ac517930f..38eb477f9 100644 --- a/crates/argmin/src/solver/conjugategradient/cg.rs +++ b/crates/argmin/src/solver/conjugategradient/cg.rs @@ -91,7 +91,9 @@ where R: ArgminMul + ArgminMul + ArgminConj + ArgminDot + ArgminScaledAdd, F: ArgminFloat + ArgminL2Norm, { - const NAME: &'static str = "Conjugate Gradient"; + fn name(&self) -> &str { + "Conjugate Gradient" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/conjugategradient/nonlinear_cg.rs b/crates/argmin/src/solver/conjugategradient/nonlinear_cg.rs index 2c3574291..2c3b0af4e 100644 --- a/crates/argmin/src/solver/conjugategradient/nonlinear_cg.rs +++ b/crates/argmin/src/solver/conjugategradient/nonlinear_cg.rs @@ -128,7 +128,9 @@ where B: NLCGBetaUpdate, F: ArgminFloat, { - const NAME: &'static str = "Nonlinear Conjugate Gradient"; + fn name(&self) -> &str { + "Nonlinear Conjugate Gradient" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs b/crates/argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs index 975d84b5e..52f6d6bf6 100644 --- a/crates/argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs +++ b/crates/argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs @@ -96,7 +96,9 @@ where F: ArgminFloat, R: Clone, { - const NAME: &'static str = "Gauss-Newton method with line search"; + fn name(&self) -> &str { + "Gauss-Newton method with line search" + } fn next_iter( &mut self, diff --git a/crates/argmin/src/solver/gaussnewton/gaussnewton_method.rs b/crates/argmin/src/solver/gaussnewton/gaussnewton_method.rs index e4f63a3a3..eb2defddf 100644 --- a/crates/argmin/src/solver/gaussnewton/gaussnewton_method.rs +++ b/crates/argmin/src/solver/gaussnewton/gaussnewton_method.rs @@ -122,7 +122,9 @@ where + ArgminDot, F: ArgminFloat, { - const NAME: &'static str = "Gauss-Newton method"; + fn name(&self) -> &str { + "Gauss-Newton method" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/goldensectionsearch/mod.rs b/crates/argmin/src/solver/goldensectionsearch/mod.rs index d5d87fff2..eb0bae871 100644 --- a/crates/argmin/src/solver/goldensectionsearch/mod.rs +++ b/crates/argmin/src/solver/goldensectionsearch/mod.rs @@ -139,7 +139,9 @@ where O: CostFunction, F: ArgminFloat, { - const NAME: &'static str = "Golden-section search"; + fn name(&self) -> &str { + "Golden-section search" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/gradientdescent/steepestdescent.rs b/crates/argmin/src/solver/gradientdescent/steepestdescent.rs index bd27ab8a7..49b6168aa 100644 --- a/crates/argmin/src/solver/gradientdescent/steepestdescent.rs +++ b/crates/argmin/src/solver/gradientdescent/steepestdescent.rs @@ -58,7 +58,9 @@ where L: Clone + LineSearch + Solver>, F: ArgminFloat, { - const NAME: &'static str = "Steepest Descent"; + fn name(&self) -> &str { + "Steepest Descent" + } fn next_iter( &mut self, diff --git a/crates/argmin/src/solver/landweber/mod.rs b/crates/argmin/src/solver/landweber/mod.rs index acd171b09..b0abaf55e 100644 --- a/crates/argmin/src/solver/landweber/mod.rs +++ b/crates/argmin/src/solver/landweber/mod.rs @@ -69,7 +69,9 @@ where P: Clone + ArgminScaledSub, F: ArgminFloat, { - const NAME: &'static str = "Landweber"; + fn name(&self) -> &str { + "Landweber" + } fn next_iter( &mut self, diff --git a/crates/argmin/src/solver/linesearch/backtracking.rs b/crates/argmin/src/solver/linesearch/backtracking.rs index 313341870..2148a7084 100644 --- a/crates/argmin/src/solver/linesearch/backtracking.rs +++ b/crates/argmin/src/solver/linesearch/backtracking.rs @@ -183,7 +183,9 @@ where L: LineSearchCondition, F: ArgminFloat, { - const NAME: &'static str = "Backtracking line search"; + fn name(&self) -> &str { + "Backtracking line search" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/linesearch/hagerzhang.rs b/crates/argmin/src/solver/linesearch/hagerzhang.rs index 6fa4228c7..5db8a77a9 100644 --- a/crates/argmin/src/solver/linesearch/hagerzhang.rs +++ b/crates/argmin/src/solver/linesearch/hagerzhang.rs @@ -502,7 +502,9 @@ where G: Clone + ArgminDot, F: ArgminFloat, { - const NAME: &'static str = "Hager-Zhang line search"; + fn name(&self) -> &str { + "Hager-Zhang line search" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/linesearch/morethuente.rs b/crates/argmin/src/solver/linesearch/morethuente.rs index a2ca305f9..0fd8c924a 100644 --- a/crates/argmin/src/solver/linesearch/morethuente.rs +++ b/crates/argmin/src/solver/linesearch/morethuente.rs @@ -303,7 +303,9 @@ where G: Clone + ArgminDot, F: ArgminFloat, { - const NAME: &'static str = "More-Thuente Line search"; + fn name(&self) -> &str { + "More-Thuente Line search" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/neldermead/mod.rs b/crates/argmin/src/solver/neldermead/mod.rs index 623f9c0ed..c224a31de 100644 --- a/crates/argmin/src/solver/neldermead/mod.rs +++ b/crates/argmin/src/solver/neldermead/mod.rs @@ -322,7 +322,9 @@ where P: Clone + ArgminSub + ArgminAdd + ArgminMul, F: ArgminFloat + std::iter::Sum, { - const NAME: &'static str = "Nelder-Mead method"; + fn name(&self) -> &str { + "Nelder-Mead method" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/newton/newton_cg.rs b/crates/argmin/src/solver/newton/newton_cg.rs index e5e8d3bdb..bcefe0bd9 100644 --- a/crates/argmin/src/solver/newton/newton_cg.rs +++ b/crates/argmin/src/solver/newton/newton_cg.rs @@ -120,7 +120,9 @@ where L: Clone + LineSearch + Solver>, F: ArgminFloat + ArgminL2Norm, { - const NAME: &'static str = "Newton-CG"; + fn name(&self) -> &str { + "Newton-CG" + } fn next_iter( &mut self, diff --git a/crates/argmin/src/solver/newton/newton_method.rs b/crates/argmin/src/solver/newton/newton_method.rs index 47434fa2a..210648d3c 100644 --- a/crates/argmin/src/solver/newton/newton_method.rs +++ b/crates/argmin/src/solver/newton/newton_method.rs @@ -92,7 +92,9 @@ where H: ArgminInv + ArgminDot, F: ArgminFloat, { - const NAME: &'static str = "Newton method"; + fn name(&self) -> &str { + "Newton method" + } fn next_iter( &mut self, diff --git a/crates/argmin/src/solver/particleswarm/mod.rs b/crates/argmin/src/solver/particleswarm/mod.rs index ee5d77734..2c3830a5f 100644 --- a/crates/argmin/src/solver/particleswarm/mod.rs +++ b/crates/argmin/src/solver/particleswarm/mod.rs @@ -289,7 +289,9 @@ where F: ArgminFloat, R: Rng, { - const NAME: &'static str = "Particle Swarm Optimization"; + fn name(&self) -> &str { + "Particle Swarm Optimization" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/quasinewton/bfgs.rs b/crates/argmin/src/solver/quasinewton/bfgs.rs index 5809afb30..38f237b0f 100644 --- a/crates/argmin/src/solver/quasinewton/bfgs.rs +++ b/crates/argmin/src/solver/quasinewton/bfgs.rs @@ -145,7 +145,9 @@ where L: Clone + LineSearch + Solver>, F: ArgminFloat, { - const NAME: &'static str = "BFGS"; + fn name(&self) -> &str { + "BFGS" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/quasinewton/dfp.rs b/crates/argmin/src/solver/quasinewton/dfp.rs index f9b41a222..b606e393d 100644 --- a/crates/argmin/src/solver/quasinewton/dfp.rs +++ b/crates/argmin/src/solver/quasinewton/dfp.rs @@ -104,7 +104,9 @@ where L: Clone + LineSearch + Solver>, F: ArgminFloat, { - const NAME: &'static str = "DFP"; + fn name(&self) -> &str { + "DFP" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/quasinewton/lbfgs.rs b/crates/argmin/src/solver/quasinewton/lbfgs.rs index febf24224..c024b7363 100644 --- a/crates/argmin/src/solver/quasinewton/lbfgs.rs +++ b/crates/argmin/src/solver/quasinewton/lbfgs.rs @@ -333,7 +333,9 @@ where + Solver, IterState>, F: ArgminFloat, { - const NAME: &'static str = "L-BFGS"; + fn name(&self) -> &str { + "L-BFGS" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/quasinewton/sr1.rs b/crates/argmin/src/solver/quasinewton/sr1.rs index 35c9c18c4..6582179de 100644 --- a/crates/argmin/src/solver/quasinewton/sr1.rs +++ b/crates/argmin/src/solver/quasinewton/sr1.rs @@ -159,7 +159,9 @@ where L: Clone + LineSearch + Solver>, F: ArgminFloat, { - const NAME: &'static str = "SR1"; + fn name(&self) -> &str { + "SR1" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/quasinewton/sr1_trustregion.rs b/crates/argmin/src/solver/quasinewton/sr1_trustregion.rs index 39c160323..1dcad05c9 100644 --- a/crates/argmin/src/solver/quasinewton/sr1_trustregion.rs +++ b/crates/argmin/src/solver/quasinewton/sr1_trustregion.rs @@ -196,7 +196,9 @@ where R: Clone + TrustRegionRadius + Solver>, F: ArgminFloat + ArgminL2Norm, { - const NAME: &'static str = "SR1 trust region"; + fn name(&self) -> &str { + "SR1 trust region" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/simulatedannealing/mod.rs b/crates/argmin/src/solver/simulatedannealing/mod.rs index 2e78d1841..adc39fc7c 100644 --- a/crates/argmin/src/solver/simulatedannealing/mod.rs +++ b/crates/argmin/src/solver/simulatedannealing/mod.rs @@ -442,7 +442,9 @@ where F: ArgminFloat, R: Rng, { - const NAME: &'static str = "Simulated Annealing"; + fn name(&self) -> &str { + "Simulated Annealing" + } fn init( &mut self, problem: &mut Problem, diff --git a/crates/argmin/src/solver/trustregion/cauchypoint.rs b/crates/argmin/src/solver/trustregion/cauchypoint.rs index 46c0aa29f..ae1aefd53 100644 --- a/crates/argmin/src/solver/trustregion/cauchypoint.rs +++ b/crates/argmin/src/solver/trustregion/cauchypoint.rs @@ -58,7 +58,9 @@ where G: ArgminMul + ArgminWeightedDot + ArgminL2Norm, F: ArgminFloat, { - const NAME: &'static str = "Cauchy Point"; + fn name(&self) -> &str { + "Cauchy Point" + } fn next_iter( &mut self, diff --git a/crates/argmin/src/solver/trustregion/dogleg.rs b/crates/argmin/src/solver/trustregion/dogleg.rs index 1e624543c..c1d702960 100644 --- a/crates/argmin/src/solver/trustregion/dogleg.rs +++ b/crates/argmin/src/solver/trustregion/dogleg.rs @@ -65,7 +65,9 @@ where H: ArgminInv + ArgminDot, F: ArgminFloat, { - const NAME: &'static str = "Dogleg"; + fn name(&self) -> &str { + "Dogleg" + } fn next_iter( &mut self, diff --git a/crates/argmin/src/solver/trustregion/steihaug.rs b/crates/argmin/src/solver/trustregion/steihaug.rs index d13472959..a6f8234aa 100644 --- a/crates/argmin/src/solver/trustregion/steihaug.rs +++ b/crates/argmin/src/solver/trustregion/steihaug.rs @@ -189,7 +189,9 @@ where H: ArgminDot, F: ArgminFloat, { - const NAME: &'static str = "Steihaug"; + fn name(&self) -> &str { + "Steihaug" + } fn init( &mut self, diff --git a/crates/argmin/src/solver/trustregion/trustregion_method.rs b/crates/argmin/src/solver/trustregion/trustregion_method.rs index 1241d7b20..e80fc2b7f 100644 --- a/crates/argmin/src/solver/trustregion/trustregion_method.rs +++ b/crates/argmin/src/solver/trustregion/trustregion_method.rs @@ -167,7 +167,9 @@ where R: Clone + TrustRegionRadius + Solver>, F: ArgminFloat, { - const NAME: &'static str = "Trust region"; + fn name(&self) -> &str { + "Trust region" + } fn init( &mut self, diff --git a/media/book/src/implementing_solver.md b/media/book/src/implementing_solver.md index 8255c48b0..ec58d9ffb 100644 --- a/media/book/src/implementing_solver.md +++ b/media/book/src/implementing_solver.md @@ -1,7 +1,7 @@ # Implementing a solver In this section we are going to implement the Landweber solver, which essentially is a special form of gradient descent. -In iteration \\( k \\), the new parameter vector \\( x_{k+1} \\) is calculated from the previous parameter vector \\( x_k \\) and the gradient at \\( x_k \\) according to the following update rule: +In iteration \\( k \\), the new parameter vector \\( x\_{k+1} \\) is calculated from the previous parameter vector \\( x_k \\) and the gradient at \\( x_k \\) according to the following update rule: \\[ x_{k+1} = x_k - \omega * \nabla f(x_k) @@ -12,50 +12,50 @@ Then, the [`Solver`](https://docs.rs/argmin/latest/argmin/core/trait.Solver.html The `Solver` trait consists of several methods; however, not all of them need to be implemented since most come with default implementations. -* `NAME`: a `&'static str` which holds the solvers name (mainly needed for the observers). -* `init(...)`: Run before the the actual iterations and initializes the solver. Does nothing by default. -* `next_iter(...)`: One iteration of the solver. Will be executed by the `Executor` until a stopping criterion is met. -* `terminate(...)`: Solver specific stopping criteria. This method is run after every iteration. Note that one can also terminate from within `next_iter` if necessary. -* `terminate_internal(...)`: By default calls `terminate` and in addition checks if the maximum number of iterations was reached or if the best cost function value is below the target cost value. Should only be overwritten if absolutely necessary. +- `NAME`: a `&'static str` which holds the solvers name (mainly needed for the observers). +- `init(...)`: Run before the the actual iterations and initializes the solver. Does nothing by default. +- `next_iter(...)`: One iteration of the solver. Will be executed by the `Executor` until a stopping criterion is met. +- `terminate(...)`: Solver specific stopping criteria. This method is run after every iteration. Note that one can also terminate from within `next_iter` if necessary. +- `terminate_internal(...)`: By default calls `terminate` and in addition checks if the maximum number of iterations was reached or if the best cost function value is below the target cost value. Should only be overwritten if absolutely necessary. Both `init` and `next_iter` have access to the optimization problem (`problem`) as well as the internal state (`state`). -The methods `terminate` and `terminate_internal` only have access to `state`. +The methods `terminate` and `terminate_internal` only have access to `state`. The function parameter `problem` is a wrapped version of the optimization problem and as such gives access to the cost function, gradient, Hessian, Jacobian,...). It also keeps track of how often each of these is called. Via `state` the solver has access to the current parameter vector, the current best parameter vector, gradient, Hessian, Jacobian, population, the current iteration number, and so on. The `state` can be modified (for instance a new parameter vector is set) and is then returned by both `init` and `next_iter`. -The `Executor` then takes care of updating the state properly, for instance by updating the current best parameter vector if the new parameter vector is better than the previous best. +The `Executor` then takes care of updating the state properly, for instance by updating the current best parameter vector if the new parameter vector is better than the previous best. It is advisable to design the solver such that it is generic over the actual type of the parameter vector, gradient, and so on. The current naming convention for generics in argmin is as follows: -* `O`: Optimization problem -* `P`: Parameter vector -* `G`: Gradient -* `J`: Jacobian -* `H`: Hessian -* `F`: Floats (`f32` or `f64`) +- `O`: Optimization problem +- `P`: Parameter vector +- `G`: Gradient +- `J`: Jacobian +- `H`: Hessian +- `F`: Floats (`f32` or `f64`) These individual generic parameters are then constrained by type constraints. For instance, the Landweber iteration requires the problem `O` to implement `Gradient`, therefore a trait bound of the form `O: Gradient` is necessary. From the Landweber update formula, we know that a scaled subtraction of two vectors is required. This must be represented in form of a trait bound as well: `P: ArgminScaledSub`. -`ArgminScaledSub` is a trait from `argmin-math` which represents a scaled subtraction. +`ArgminScaledSub` is a trait from `argmin-math` which represents a scaled subtraction. With this trait bound, we require that it must be possible to subtract a value of type `G` scaled with a value of type `F` from a value of type `P`, resulting in a value of type `P`. The generic type `F` represents floating point value and therefore allows users to choose which precision they want. Implementing the algorithm is straightforward: First we get the current parameter vector `xk` from the state via `state.take_param()`. -Note that `take_param` moves the parameter vector from the `state` into `xk`, therefore one needs to make sure to move the updated parameter vector into `state` at the end of `next_iter` via `state.param(...)`. -Landweber requires the user to provide an initial parameter vector. +Note that `take_param` moves the parameter vector from the `state` into `xk`, therefore one needs to make sure to move the updated parameter vector into `state` at the end of `next_iter` via `state.param(...)`. +Landweber requires the user to provide an initial parameter vector. If this is not the case than we return an error to inform the user. Then the gradient `grad` is computed by calling `problem.gradient(...)` on the parameter vector. -This will return the gradient and internally increase the gradient function evaluation count. -We compute the updated parameter vector `xkp1` by computing `xk.scaled_sub(&self.omega, &grad)` (which is possible because of the `ArgminScaledSub` trait bound introduced before). +This will return the gradient and internally increase the gradient function evaluation count. +We compute the updated parameter vector `xkp1` by computing `xk.scaled_sub(&self.omega, &grad)` (which is possible because of the `ArgminScaledSub` trait bound introduced before). Finally, the state is updated via `state.param(xkp1)` and returned by the function. ```rust @@ -87,7 +87,7 @@ impl Landweber { impl Solver> for Landweber where - // The Landweber solver requires `O` to implement `Gradient`. + // The Landweber solver requires `O` to implement `Gradient`. // `P` and `G` indicate the types of the parameter vector and gradient, // respectively. O: Gradient, @@ -98,7 +98,7 @@ where F: ArgminFloat, { // This gives the solver a name which will be used for logging - const NAME: &'static str = "Landweber"; + fn name(&self) -> &str { "Landweber" } // Defines the computations performed in a single iteration. fn next_iter( @@ -114,20 +114,20 @@ where mut state: IterState, ) -> Result<(IterState, Option), Error> { // First we obtain the current parameter vector from the `state` struct (`x_k`). - // Landweber requires an initial parameter vector. Return an error if this was + // Landweber requires an initial parameter vector. Return an error if this was // not provided by the user. let xk = state.take_param().ok_or_else(argmin_error_closure!( NotInitialized, "Initial parameter vector required!" ))?; - + // Then we compute the gradient at `x_k` (`\nabla f(x_k)`) let grad = problem.gradient(&xk)?; - + // Now subtract `\nabla f(x_k)` scaled by `omega` from `x_k` // to compute `x_{k+1}` let xkp1 = xk.scaled_sub(&self.omega, &grad); - + // Return new the updated `state` Ok((state.param(xkp1), None)) }