From 4c94c5145eb5d0d71e0e856dc0bc2abf304191ed Mon Sep 17 00:00:00 2001 From: Joris Bayer Date: Sun, 19 Mar 2023 14:25:47 +0100 Subject: [PATCH 01/13] Python extension: First draft --- .gitignore | 2 + Cargo.toml | 1 + argmin-py/Cargo.toml | 18 +++++++++ argmin-py/examples/newton.py | 26 ++++++++++++ argmin-py/src/executor.rs | 64 ++++++++++++++++++++++++++++++ argmin-py/src/lib.rs | 24 +++++++++++ argmin-py/src/problem.rs | 68 ++++++++++++++++++++++++++++++++ argmin-py/src/solver.rs | 51 ++++++++++++++++++++++++ argmin-py/src/types.rs | 15 +++++++ argmin/src/core/executor.rs | 6 +-- argmin/src/core/observers/mod.rs | 4 +- 11 files changed, 274 insertions(+), 5 deletions(-) create mode 100644 argmin-py/Cargo.toml create mode 100644 argmin-py/examples/newton.py create mode 100644 argmin-py/src/executor.rs create mode 100644 argmin-py/src/lib.rs create mode 100644 argmin-py/src/problem.rs create mode 100644 argmin-py/src/solver.rs create mode 100644 argmin-py/src/types.rs diff --git a/.gitignore b/.gitignore index 77cf2b7bd..be6745289 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ target/ target/* *.log justfile +.vscode +.venv diff --git a/Cargo.toml b/Cargo.toml index cfb85b940..ed39c8b75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ members = [ "argmin", "argmin-math", + "argmin-py", ] exclude = [ diff --git a/argmin-py/Cargo.toml b/argmin-py/Cargo.toml new file mode 100644 index 000000000..1590abc37 --- /dev/null +++ b/argmin-py/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "argmin-py" +version = "0.1.0" +edition = "2021" + +[lib] +name = "argmin" +crate-type = ["cdylib"] + +[dependencies] +anyhow = "1.0.70" +argmin_testfunctions = "0.1.1" +argmin = {path="../argmin", default-features=false, features=[]} +argmin-math = {path="../argmin-math", features=["ndarray_latest-serde"]} +ndarray-linalg = { version = "0.16", features = ["netlib"] } +ndarray = { version = "0.15", features = ["serde-1"] } +numpy = "0.18.0" +pyo3 = {version="0.18.1", features=["extension-module", "anyhow"]} \ No newline at end of file diff --git a/argmin-py/examples/newton.py b/argmin-py/examples/newton.py new file mode 100644 index 000000000..f142282be --- /dev/null +++ b/argmin-py/examples/newton.py @@ -0,0 +1,26 @@ +# Copyright 2018-2023 argmin developers +# +# Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +# copied, modified, or distributed except according to those terms. +from argmin import Problem, Solver, Executor +import numpy as np +from scipy.optimize import rosen_der, rosen_hess + + +def main(): + problem = Problem( + gradient=rosen_der, + hessian=rosen_hess, + ) + solver = Solver.Newton + executor = Executor(problem, solver) + executor.configure(param=np.array([-1.2, 1.0]), max_iters=8) + + result = executor.run() + print(result) + + +if __name__ == "__main__": + main() diff --git a/argmin-py/src/executor.rs b/argmin-py/src/executor.rs new file mode 100644 index 000000000..b60509f63 --- /dev/null +++ b/argmin-py/src/executor.rs @@ -0,0 +1,64 @@ +// Copyright 2018-2023 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +// TODO: docs + +use pyo3::{prelude::*, types::PyDict}; + +use argmin::core; + +use crate::problem::Problem; +use crate::solver::{DynamicSolver, Solver}; +use crate::types::{IterState, PyArray1}; + +#[pyclass] +pub struct Executor(Option>); + +impl Executor { + /// Consumes the inner executor. + /// + /// PyObjects do not allow methods that consume the object itself, so this is a workaround + /// for using methods like `configure` and `run`. + fn take(&mut self) -> anyhow::Result> { + let Some(inner) = self.0.take() else { + return Err(anyhow::anyhow!("Executor was already run.")); + }; + Ok(inner) + } +} + +#[pymethods] +impl Executor { + #[new] + fn new(problem: Problem, solver: Solver) -> Self { + Self(Some(core::Executor::new(problem, solver.into()))) + } + + #[pyo3(signature = (**kwargs))] + fn configure(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> { + if let Some(kwargs) = kwargs { + let new_self = self.take()?.configure(|mut state| { + if let Some(param) = kwargs.get_item("param") { + let param: &PyArray1 = param.extract().unwrap(); + state = state.param(param.to_owned_array()); + } + if let Some(max_iters) = kwargs.get_item("max_iters") { + state = state.max_iters(max_iters.extract().unwrap()); + } + state + }); + self.0 = Some(new_self); + } + Ok(()) + } + + fn run(&mut self) -> PyResult { + // TODO: return usable OptimizationResult + let res = self.take()?.run(); + Ok(res?.to_string()) + } +} diff --git a/argmin-py/src/lib.rs b/argmin-py/src/lib.rs new file mode 100644 index 000000000..fa3a21e15 --- /dev/null +++ b/argmin-py/src/lib.rs @@ -0,0 +1,24 @@ +// Copyright 2018-2023 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +// TODO: docs +mod executor; +mod problem; +mod solver; +mod types; + +use pyo3::prelude::*; + +#[pymodule] +#[pyo3(name = "argmin")] +fn argmin_py(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + + Ok(()) +} diff --git a/argmin-py/src/problem.rs b/argmin-py/src/problem.rs new file mode 100644 index 000000000..93d6aae2a --- /dev/null +++ b/argmin-py/src/problem.rs @@ -0,0 +1,68 @@ +// Copyright 2018-2023 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +// TODO: docs + +use numpy::ToPyArray; +use pyo3::{prelude::*, types::PyTuple}; + +use argmin::core; + +use crate::types::{Array1, Array2, Scalar}; + +#[pyclass] +#[derive(Clone)] +pub struct Problem { + gradient: PyObject, + hessian: PyObject, + // TODO: jacobian +} + +#[pymethods] +impl Problem { + #[new] + fn new(gradient: PyObject, hessian: PyObject) -> Self { + Self { gradient, hessian } + } +} + +impl core::Gradient for Problem { + type Param = Array1; + type Gradient = Array1; + + fn gradient(&self, param: &Self::Param) -> Result { + call(&self.gradient, param) + } +} + +impl argmin::core::Hessian for Problem { + type Param = Array1; + + type Hessian = Array2; + + fn hessian(&self, param: &Self::Param) -> Result { + call(&self.hessian, param) + } +} + +fn call( + callable: &PyObject, + param: &ndarray::Array, +) -> Result, argmin::core::Error> +where + InputDimension: ndarray::Dimension, + OutputDimension: ndarray::Dimension, +{ + // TODO: prevent dynamic dispatch for every call + Python::with_gil(|py| { + let args = PyTuple::new(py, [param.to_pyarray(py)]); + let pyresult = callable.call(py, args, Default::default())?; + let pyarray = pyresult.extract::<&numpy::PyArray>(py)?; + // TODO: try to get ownership instead of cloning + Ok(pyarray.to_owned_array()) + }) +} diff --git a/argmin-py/src/solver.rs b/argmin-py/src/solver.rs new file mode 100644 index 000000000..3ff801f89 --- /dev/null +++ b/argmin-py/src/solver.rs @@ -0,0 +1,51 @@ +// Copyright 2018-2023 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +// TODO: docs + +use pyo3::prelude::*; + +use argmin::{core, solver}; + +use crate::{ + problem::Problem, + types::{IterState, Scalar}, +}; + +#[pyclass] +#[derive(Clone)] +pub enum Solver { + Newton, +} + +pub enum DynamicSolver { + // NOTE: I tried using a Box here, but Solver is not object safe. + Newton(solver::newton::Newton), +} + +impl From for DynamicSolver { + fn from(solver: Solver) -> Self { + match solver { + Solver::Newton => Self::Newton(solver::newton::Newton::new()), + } + } +} + +impl core::Solver for DynamicSolver { + // TODO: make this a trait method so we can return a dynamic + const NAME: &'static str = "Dynamic Solver"; + + fn next_iter( + &mut self, + problem: &mut core::Problem, + state: IterState, + ) -> Result<(IterState, Option), core::Error> { + match self { + DynamicSolver::Newton(inner) => inner.next_iter(problem, state), + } + } +} diff --git a/argmin-py/src/types.rs b/argmin-py/src/types.rs new file mode 100644 index 000000000..511e616ab --- /dev/null +++ b/argmin-py/src/types.rs @@ -0,0 +1,15 @@ +// Copyright 2018-2023 argmin developers +// +// Licensed under the Apache License, Version 2.0 or the MIT license , at your option. This file may not be +// copied, modified, or distributed except according to those terms. + +//! Base types for the Python extension. + +pub type Scalar = f64; // TODO: allow complex numbers +pub type Array1 = ndarray::Array1; +pub type Array2 = ndarray::Array2; +pub type PyArray1 = numpy::PyArray1; + +pub type IterState = argmin::core::IterState, Scalar>; diff --git a/argmin/src/core/executor.rs b/argmin/src/core/executor.rs index 184f9c970..b02319020 100644 --- a/argmin/src/core/executor.rs +++ b/argmin/src/core/executor.rs @@ -26,7 +26,7 @@ pub struct Executor { /// Storage for observers observers: Observers, /// Checkpoint - checkpoint: Option>>, + checkpoint: Option + Send>>, /// Indicates whether Ctrl-C functionality should be active or not ctrlc: bool, /// Indicates whether to time execution or not @@ -298,7 +298,7 @@ where /// # } /// ``` #[must_use] - pub fn add_observer + 'static>( + pub fn add_observer + 'static + Send>( mut self, observer: OBS, mode: ObserverMode, @@ -340,7 +340,7 @@ where /// # } /// ``` #[must_use] - pub fn checkpointing>(mut self, checkpoint: C) -> Self { + pub fn checkpointing + Send>(mut self, checkpoint: C) -> Self { self.checkpoint = Some(Box::new(checkpoint)); self } diff --git a/argmin/src/core/observers/mod.rs b/argmin/src/core/observers/mod.rs index 009adb5e5..d9359932d 100644 --- a/argmin/src/core/observers/mod.rs +++ b/argmin/src/core/observers/mod.rs @@ -188,7 +188,7 @@ pub trait Observe { } } -type ObserversVec = Vec<(Arc>>, ObserverMode)>; +type ObserversVec = Vec<(Arc + Send>>, ObserverMode)>; /// Container for observers. /// @@ -236,7 +236,7 @@ impl Observers { /// # #[cfg(feature = "slog-logger")] /// # assert!(!observers.is_empty()); /// ``` - pub fn push + 'static>( + pub fn push + 'static + Send>( &mut self, observer: OBS, mode: ObserverMode, From bc2e840bb0362daf6b136ba581b36a202254ab50 Mon Sep 17 00:00:00 2001 From: Joris Bayer Date: Sun, 19 Mar 2023 14:53:52 +0100 Subject: [PATCH 02/13] self-review --- argmin-py/Cargo.toml | 2 +- argmin-py/src/executor.rs | 21 ++++++++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/argmin-py/Cargo.toml b/argmin-py/Cargo.toml index 1590abc37..ea4fb9d60 100644 --- a/argmin-py/Cargo.toml +++ b/argmin-py/Cargo.toml @@ -15,4 +15,4 @@ argmin-math = {path="../argmin-math", features=["ndarray_latest-serde"]} ndarray-linalg = { version = "0.16", features = ["netlib"] } ndarray = { version = "0.15", features = ["serde-1"] } numpy = "0.18.0" -pyo3 = {version="0.18.1", features=["extension-module", "anyhow"]} \ No newline at end of file +pyo3 = {version="0.18.1", features=["extension-module", "anyhow"]} diff --git a/argmin-py/src/executor.rs b/argmin-py/src/executor.rs index b60509f63..490ae27f9 100644 --- a/argmin-py/src/executor.rs +++ b/argmin-py/src/executor.rs @@ -41,17 +41,24 @@ impl Executor { #[pyo3(signature = (**kwargs))] fn configure(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> { if let Some(kwargs) = kwargs { - let new_self = self.take()?.configure(|mut state| { - if let Some(param) = kwargs.get_item("param") { - let param: &PyArray1 = param.extract().unwrap(); + let param = kwargs + .get_item("param") + .map(|x| x.extract::<&PyArray1>()) + .map_or(Ok(None), |r| r.map(Some))?; + let max_iters = kwargs + .get_item("max_iters") + .map(|x| x.extract()) + .map_or(Ok(None), |r| r.map(Some))?; + + self.0 = Some(self.take()?.configure(|mut state| { + if let Some(param) = param { state = state.param(param.to_owned_array()); } - if let Some(max_iters) = kwargs.get_item("max_iters") { - state = state.max_iters(max_iters.extract().unwrap()); + if let Some(max_iters) = max_iters { + state = state.max_iters(max_iters); } state - }); - self.0 = Some(new_self); + })); } Ok(()) } From 007a13c52616ace222a3bd6742fc3b8d745d5233 Mon Sep 17 00:00:00 2001 From: Joris Bayer Date: Sun, 26 Mar 2023 13:20:32 +0200 Subject: [PATCH 03/13] Make Checkpoint and Observe Send --- argmin/src/core/checkpointing/mod.rs | 2 +- argmin/src/core/executor.rs | 6 +++--- argmin/src/core/observers/mod.rs | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/argmin/src/core/checkpointing/mod.rs b/argmin/src/core/checkpointing/mod.rs index 4b91c2510..254e6f2d3 100644 --- a/argmin/src/core/checkpointing/mod.rs +++ b/argmin/src/core/checkpointing/mod.rs @@ -157,7 +157,7 @@ use std::fmt::Display; /// } /// # fn main() {} /// ``` -pub trait Checkpoint { +pub trait Checkpoint: Send { /// Save a checkpoint /// /// Gets a reference to the current `solver` of type `S` and to the current `state` of type diff --git a/argmin/src/core/executor.rs b/argmin/src/core/executor.rs index b02319020..184f9c970 100644 --- a/argmin/src/core/executor.rs +++ b/argmin/src/core/executor.rs @@ -26,7 +26,7 @@ pub struct Executor { /// Storage for observers observers: Observers, /// Checkpoint - checkpoint: Option + Send>>, + checkpoint: Option>>, /// Indicates whether Ctrl-C functionality should be active or not ctrlc: bool, /// Indicates whether to time execution or not @@ -298,7 +298,7 @@ where /// # } /// ``` #[must_use] - pub fn add_observer + 'static + Send>( + pub fn add_observer + 'static>( mut self, observer: OBS, mode: ObserverMode, @@ -340,7 +340,7 @@ where /// # } /// ``` #[must_use] - pub fn checkpointing + Send>(mut self, checkpoint: C) -> Self { + pub fn checkpointing>(mut self, checkpoint: C) -> Self { self.checkpoint = Some(Box::new(checkpoint)); self } diff --git a/argmin/src/core/observers/mod.rs b/argmin/src/core/observers/mod.rs index d9359932d..6278c1964 100644 --- a/argmin/src/core/observers/mod.rs +++ b/argmin/src/core/observers/mod.rs @@ -169,7 +169,7 @@ use std::sync::{Arc, Mutex}; /// } /// } /// ``` -pub trait Observe { +pub trait Observe: Send { /// Called once after initialization of the solver. /// /// Has access to the name of the solver via `name` and to a key-value store `kv` with entries @@ -188,7 +188,7 @@ pub trait Observe { } } -type ObserversVec = Vec<(Arc + Send>>, ObserverMode)>; +type ObserversVec = Vec<(Arc>>, ObserverMode)>; /// Container for observers. /// @@ -236,7 +236,7 @@ impl Observers { /// # #[cfg(feature = "slog-logger")] /// # assert!(!observers.is_empty()); /// ``` - pub fn push + 'static + Send>( + pub fn push + 'static>( &mut self, observer: OBS, mode: ObserverMode, From f24ea5afc2f31d8619fbb84707cd3b556e282b01 Mon Sep 17 00:00:00 2001 From: Joris Bayer Date: Sun, 26 Mar 2023 13:33:36 +0200 Subject: [PATCH 04/13] Dynamic solver name --- argmin-py/src/solver.rs | 11 +++++++++++ argmin/src/core/executor.rs | 2 +- argmin/src/core/result.rs | 2 +- argmin/src/core/solver.rs | 7 +++++++ 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/argmin-py/src/solver.rs b/argmin-py/src/solver.rs index 3ff801f89..6d930b510 100644 --- a/argmin-py/src/solver.rs +++ b/argmin-py/src/solver.rs @@ -7,6 +7,8 @@ // TODO: docs +use std::path::Iter; + use pyo3::prelude::*; use argmin::{core, solver}; @@ -39,6 +41,15 @@ impl core::Solver for DynamicSolver { // TODO: make this a trait method so we can return a dynamic const NAME: &'static str = "Dynamic Solver"; + fn name(&self) -> &str { + match self { + DynamicSolver::Newton(inner) => { + as argmin::core::Solver> + ::name(inner) + } + } + } + fn next_iter( &mut self, problem: &mut core::Problem, diff --git a/argmin/src/core/executor.rs b/argmin/src/core/executor.rs index 184f9c970..a0f1ce1ba 100644 --- a/argmin/src/core/executor.rs +++ b/argmin/src/core/executor.rs @@ -181,7 +181,7 @@ where } // Observe after init - self.observers.observe_init(S::NAME, &logs)?; + self.observers.observe_init(self.solver.name(), &logs)?; } state.func_counts(&self.problem); diff --git a/argmin/src/core/result.rs b/argmin/src/core/result.rs index 00cfbe391..0d93d2c67 100644 --- a/argmin/src/core/result.rs +++ b/argmin/src/core/result.rs @@ -124,7 +124,7 @@ where { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { writeln!(f, "OptimizationResult:")?; - writeln!(f, " Solver: {}", S::NAME)?; + writeln!(f, " Solver: {}", self.solver().name())?; writeln!( f, " param (best): {}", diff --git a/argmin/src/core/solver.rs b/argmin/src/core/solver.rs index a81514ec7..4b7cd747d 100644 --- a/argmin/src/core/solver.rs +++ b/argmin/src/core/solver.rs @@ -117,4 +117,11 @@ pub trait Solver { fn terminate(&mut self, _state: &I) -> TerminationStatus { TerminationStatus::NotTerminated } + + /// Returns the name of the solver. + /// + /// Defaults to [`Self::NAME`], but can be overriden in implmentations. + fn name(&self) -> &str { + Self::NAME + } } From b669e48f57e8f0fa2c8ba45817ef8917120f9a05 Mon Sep 17 00:00:00 2001 From: Joris Bayer Date: Sun, 26 Mar 2023 14:15:49 +0200 Subject: [PATCH 05/13] Object-safe Solver --- argmin-py/src/solver.rs | 34 ++++--------- argmin/src/core/executor.rs | 2 +- argmin/src/core/solver.rs | 10 ++-- argmin/src/core/test_utils.rs | 4 +- argmin/src/solver/brent/brentopt.rs | 4 +- argmin/src/solver/brent/brentroot.rs | 4 +- argmin/src/solver/conjugategradient/cg.rs | 4 +- .../solver/conjugategradient/nonlinear_cg.rs | 4 +- .../gaussnewton/gaussnewton_linesearch.rs | 4 +- .../solver/gaussnewton/gaussnewton_method.rs | 4 +- argmin/src/solver/goldensectionsearch/mod.rs | 4 +- .../solver/gradientdescent/steepestdescent.rs | 4 +- argmin/src/solver/landweber/mod.rs | 4 +- argmin/src/solver/linesearch/backtracking.rs | 4 +- argmin/src/solver/linesearch/hagerzhang.rs | 4 +- argmin/src/solver/linesearch/morethuente.rs | 4 +- argmin/src/solver/neldermead/mod.rs | 4 +- argmin/src/solver/newton/newton_cg.rs | 4 +- argmin/src/solver/newton/newton_method.rs | 4 +- argmin/src/solver/particleswarm/mod.rs | 4 +- argmin/src/solver/quasinewton/bfgs.rs | 2 +- argmin/src/solver/quasinewton/dfp.rs | 4 +- argmin/src/solver/quasinewton/lbfgs.rs | 2 +- argmin/src/solver/quasinewton/sr1.rs | 2 +- .../src/solver/quasinewton/sr1_trustregion.rs | 4 +- argmin/src/solver/simulatedannealing/mod.rs | 4 +- argmin/src/solver/trustregion/cauchypoint.rs | 4 +- argmin/src/solver/trustregion/dogleg.rs | 4 +- argmin/src/solver/trustregion/steihaug.rs | 4 +- .../solver/trustregion/trustregion_method.rs | 4 +- media/book/src/implementing_solver.md | 50 +++++++++---------- 31 files changed, 113 insertions(+), 85 deletions(-) diff --git a/argmin-py/src/solver.rs b/argmin-py/src/solver.rs index 6d930b510..0cae464b1 100644 --- a/argmin-py/src/solver.rs +++ b/argmin-py/src/solver.rs @@ -7,16 +7,11 @@ // TODO: docs -use std::path::Iter; - use pyo3::prelude::*; use argmin::{core, solver}; -use crate::{ - problem::Problem, - types::{IterState, Scalar}, -}; +use crate::{problem::Problem, types::IterState}; #[pyclass] #[derive(Clone)] @@ -24,30 +19,21 @@ pub enum Solver { Newton, } -pub enum DynamicSolver { - // NOTE: I tried using a Box here, but Solver is not object safe. - Newton(solver::newton::Newton), -} +pub struct DynamicSolver(Box + Send>); impl From for DynamicSolver { - fn from(solver: Solver) -> Self { - match solver { - Solver::Newton => Self::Newton(solver::newton::Newton::new()), - } + fn from(value: Solver) -> Self { + let inner = match value { + Solver::Newton => solver::newton::Newton::new(), + }; + Self(Box::new(inner)) } } impl core::Solver for DynamicSolver { // TODO: make this a trait method so we can return a dynamic - const NAME: &'static str = "Dynamic Solver"; - fn name(&self) -> &str { - match self { - DynamicSolver::Newton(inner) => { - as argmin::core::Solver> - ::name(inner) - } - } + self.0.name() } fn next_iter( @@ -55,8 +41,6 @@ impl core::Solver for DynamicSolver { problem: &mut core::Problem, state: IterState, ) -> Result<(IterState, Option), core::Error> { - match self { - DynamicSolver::Newton(inner) => inner.next_iter(problem, state), - } + self.0.next_iter(problem, state) } } diff --git a/argmin/src/core/executor.rs b/argmin/src/core/executor.rs index a0f1ce1ba..8d95e945e 100644 --- a/argmin/src/core/executor.rs +++ b/argmin/src/core/executor.rs @@ -591,7 +591,7 @@ mod tests { P: Clone, F: ArgminFloat, { - const NAME: &'static str = "OptimizationAlgorithm"; + fn name(&self) -> &str { "OptimizationAlgorithm" } // Only resets internal_state to 1 fn init( diff --git a/argmin/src/core/solver.rs b/argmin/src/core/solver.rs index 4b7cd747d..399e764dc 100644 --- a/argmin/src/core/solver.rs +++ b/argmin/src/core/solver.rs @@ -36,7 +36,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K /// P: Clone, /// F: ArgminFloat /// { -/// const NAME: &'static str = "OptimizationAlgorithm"; +/// fn name(&self) -> &str { "OptimizationAlgorithm" } /// /// fn init( /// &mut self, @@ -67,7 +67,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K /// ``` pub trait Solver { /// Name of the solver. Mainly used in [Observers](`crate::core::observers::Observe`). - const NAME: &'static str; + // const NAME: &'static str; /// Initializes the algorithm. /// @@ -119,9 +119,5 @@ pub trait Solver { } /// Returns the name of the solver. - /// - /// Defaults to [`Self::NAME`], but can be overriden in implmentations. - fn name(&self) -> &str { - Self::NAME - } + fn name(&self) -> &str; } diff --git a/argmin/src/core/test_utils.rs b/argmin/src/core/test_utils.rs index 6797e181e..74bb27580 100644 --- a/argmin/src/core/test_utils.rs +++ b/argmin/src/core/test_utils.rs @@ -327,7 +327,9 @@ impl TestSolver { } impl Solver, (), (), (), f64>> for TestSolver { - const NAME: &'static str = "TestSolver"; + fn name(&self) -> &str { + "TestSolver" + } fn next_iter( &mut self, diff --git a/argmin/src/solver/brent/brentopt.rs b/argmin/src/solver/brent/brentopt.rs index 17fa40093..172b91168 100644 --- a/argmin/src/solver/brent/brentopt.rs +++ b/argmin/src/solver/brent/brentopt.rs @@ -102,7 +102,9 @@ where O: CostFunction, F: ArgminFloat, { - const NAME: &'static str = "BrentOpt"; + fn name(&self) -> &str { + "BrentOpt" + } fn init( &mut self, diff --git a/argmin/src/solver/brent/brentroot.rs b/argmin/src/solver/brent/brentroot.rs index 5ca00084d..3a027788e 100644 --- a/argmin/src/solver/brent/brentroot.rs +++ b/argmin/src/solver/brent/brentroot.rs @@ -80,7 +80,9 @@ where O: CostFunction, F: ArgminFloat, { - const NAME: &'static str = "BrentRoot"; + fn name(&self) -> &str { + "BrentRoot" + } fn init( &mut self, diff --git a/argmin/src/solver/conjugategradient/cg.rs b/argmin/src/solver/conjugategradient/cg.rs index cb1bb7ccd..90b4ce050 100644 --- a/argmin/src/solver/conjugategradient/cg.rs +++ b/argmin/src/solver/conjugategradient/cg.rs @@ -101,7 +101,9 @@ where + ArgminMul, F: ArgminFloat + ArgminL2Norm, { - const NAME: &'static str = "Conjugate Gradient"; + fn name(&self) -> &str { + "Conjugate Gradient" + } fn init( &mut self, diff --git a/argmin/src/solver/conjugategradient/nonlinear_cg.rs b/argmin/src/solver/conjugategradient/nonlinear_cg.rs index 868a96b03..6180ba0a9 100644 --- a/argmin/src/solver/conjugategradient/nonlinear_cg.rs +++ b/argmin/src/solver/conjugategradient/nonlinear_cg.rs @@ -133,7 +133,9 @@ where B: NLCGBetaUpdate, F: ArgminFloat, { - const NAME: &'static str = "Nonlinear Conjugate Gradient"; + fn name(&self) -> &str { + "Nonlinear Conjugate Gradient" + } fn init( &mut self, diff --git a/argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs b/argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs index 72a85a3d2..2fd52bca4 100644 --- a/argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs +++ b/argmin/src/solver/gaussnewton/gaussnewton_linesearch.rs @@ -98,7 +98,9 @@ where L: Clone + LineSearch + Solver, IterState>, F: ArgminFloat, { - const NAME: &'static str = "Gauss-Newton method with line search"; + fn name(&self) -> &str { + "Gauss-Newton method with line search" + } fn next_iter( &mut self, diff --git a/argmin/src/solver/gaussnewton/gaussnewton_method.rs b/argmin/src/solver/gaussnewton/gaussnewton_method.rs index 167349aea..d042a570f 100644 --- a/argmin/src/solver/gaussnewton/gaussnewton_method.rs +++ b/argmin/src/solver/gaussnewton/gaussnewton_method.rs @@ -122,7 +122,9 @@ where + ArgminDot, F: ArgminFloat, { - const NAME: &'static str = "Gauss-Newton method"; + fn name(&self) -> &str { + "Gauss-Newton method" + } fn next_iter( &mut self, diff --git a/argmin/src/solver/goldensectionsearch/mod.rs b/argmin/src/solver/goldensectionsearch/mod.rs index 0634f97fa..87bd4d112 100644 --- a/argmin/src/solver/goldensectionsearch/mod.rs +++ b/argmin/src/solver/goldensectionsearch/mod.rs @@ -139,7 +139,9 @@ where O: CostFunction, F: ArgminFloat, { - const NAME: &'static str = "Golden-section search"; + fn name(&self) -> &str { + "Golden-section search" + } fn init( &mut self, diff --git a/argmin/src/solver/gradientdescent/steepestdescent.rs b/argmin/src/solver/gradientdescent/steepestdescent.rs index 29406acd5..2f317b2da 100644 --- a/argmin/src/solver/gradientdescent/steepestdescent.rs +++ b/argmin/src/solver/gradientdescent/steepestdescent.rs @@ -58,7 +58,9 @@ where L: Clone + LineSearch + Solver>, F: ArgminFloat, { - const NAME: &'static str = "Steepest Descent"; + fn name(&self) -> &str { + "Steepest Descent" + } fn next_iter( &mut self, diff --git a/argmin/src/solver/landweber/mod.rs b/argmin/src/solver/landweber/mod.rs index fbcf69cba..4619e0cd9 100644 --- a/argmin/src/solver/landweber/mod.rs +++ b/argmin/src/solver/landweber/mod.rs @@ -69,7 +69,9 @@ where P: Clone + ArgminScaledSub, F: ArgminFloat, { - const NAME: &'static str = "Landweber"; + fn name(&self) -> &str { + "Landweber" + } fn next_iter( &mut self, diff --git a/argmin/src/solver/linesearch/backtracking.rs b/argmin/src/solver/linesearch/backtracking.rs index 5ae71be3e..4f1120e01 100644 --- a/argmin/src/solver/linesearch/backtracking.rs +++ b/argmin/src/solver/linesearch/backtracking.rs @@ -182,7 +182,9 @@ where L: LineSearchCondition + SerializeAlias, F: ArgminFloat, { - const NAME: &'static str = "Backtracking line search"; + fn name(&self) -> &str { + "Backtracking line search" + } fn init( &mut self, diff --git a/argmin/src/solver/linesearch/hagerzhang.rs b/argmin/src/solver/linesearch/hagerzhang.rs index b318ed089..fd59620be 100644 --- a/argmin/src/solver/linesearch/hagerzhang.rs +++ b/argmin/src/solver/linesearch/hagerzhang.rs @@ -500,7 +500,9 @@ where G: Clone + SerializeAlias + ArgminDot, F: ArgminFloat, { - const NAME: &'static str = "Hager-Zhang line search"; + fn name(&self) -> &str { + "Hager-Zhang line search" + } fn init( &mut self, diff --git a/argmin/src/solver/linesearch/morethuente.rs b/argmin/src/solver/linesearch/morethuente.rs index 4d99fe610..09f914a42 100644 --- a/argmin/src/solver/linesearch/morethuente.rs +++ b/argmin/src/solver/linesearch/morethuente.rs @@ -303,7 +303,9 @@ where G: Clone + SerializeAlias + ArgminDot, F: ArgminFloat, { - const NAME: &'static str = "More-Thuente Line search"; + fn name(&self) -> &str { + "More-Thuente Line search" + } fn init( &mut self, diff --git a/argmin/src/solver/neldermead/mod.rs b/argmin/src/solver/neldermead/mod.rs index f8899133e..e5846a4b5 100644 --- a/argmin/src/solver/neldermead/mod.rs +++ b/argmin/src/solver/neldermead/mod.rs @@ -322,7 +322,9 @@ where P: Clone + SerializeAlias + ArgminSub + ArgminAdd + ArgminMul, F: ArgminFloat + std::iter::Sum, { - const NAME: &'static str = "Nelder-Mead method"; + fn name(&self) -> &str { + "Nelder-Mead method" + } fn init( &mut self, diff --git a/argmin/src/solver/newton/newton_cg.rs b/argmin/src/solver/newton/newton_cg.rs index 23b56a467..6c1d7c230 100644 --- a/argmin/src/solver/newton/newton_cg.rs +++ b/argmin/src/solver/newton/newton_cg.rs @@ -123,7 +123,9 @@ where L: Clone + LineSearch + Solver>, F: ArgminFloat + ArgminL2Norm, { - const NAME: &'static str = "Newton-CG"; + fn name(&self) -> &str { + "Newton-CG" + } fn next_iter( &mut self, diff --git a/argmin/src/solver/newton/newton_method.rs b/argmin/src/solver/newton/newton_method.rs index 457393212..7ad73aba8 100644 --- a/argmin/src/solver/newton/newton_method.rs +++ b/argmin/src/solver/newton/newton_method.rs @@ -92,7 +92,9 @@ where H: ArgminInv + ArgminDot, F: ArgminFloat, { - const NAME: &'static str = "Newton method"; + fn name(&self) -> &str { + "Newton method" + } fn next_iter( &mut self, diff --git a/argmin/src/solver/particleswarm/mod.rs b/argmin/src/solver/particleswarm/mod.rs index 8230252d6..4c6218340 100644 --- a/argmin/src/solver/particleswarm/mod.rs +++ b/argmin/src/solver/particleswarm/mod.rs @@ -244,7 +244,9 @@ where + ArgminMinMax, F: ArgminFloat, { - const NAME: &'static str = "Particle Swarm Optimization"; + fn name(&self) -> &str { + "Particle Swarm Optimization" + } fn init( &mut self, diff --git a/argmin/src/solver/quasinewton/bfgs.rs b/argmin/src/solver/quasinewton/bfgs.rs index f40a035c5..578467eb0 100644 --- a/argmin/src/solver/quasinewton/bfgs.rs +++ b/argmin/src/solver/quasinewton/bfgs.rs @@ -159,7 +159,7 @@ where L: Clone + LineSearch + Solver>, F: ArgminFloat, { - const NAME: &'static str = "BFGS"; + fn name(&self) -> &str { "BFGS" } fn init( &mut self, diff --git a/argmin/src/solver/quasinewton/dfp.rs b/argmin/src/solver/quasinewton/dfp.rs index bc9abaae2..27162d429 100644 --- a/argmin/src/solver/quasinewton/dfp.rs +++ b/argmin/src/solver/quasinewton/dfp.rs @@ -122,7 +122,9 @@ where L: Clone + LineSearch + Solver>, F: ArgminFloat, { - const NAME: &'static str = "DFP"; + fn name(&self) -> &str { + "DFP" + } fn init( &mut self, diff --git a/argmin/src/solver/quasinewton/lbfgs.rs b/argmin/src/solver/quasinewton/lbfgs.rs index 85a79acab..8aeb6b66f 100644 --- a/argmin/src/solver/quasinewton/lbfgs.rs +++ b/argmin/src/solver/quasinewton/lbfgs.rs @@ -338,7 +338,7 @@ where L: Clone + LineSearch + Solver, IterState>, F: ArgminFloat, { - const NAME: &'static str = "L-BFGS"; + fn name(&self) -> &str { "L-BFGS" } fn init( &mut self, diff --git a/argmin/src/solver/quasinewton/sr1.rs b/argmin/src/solver/quasinewton/sr1.rs index 4955b654f..35d3a5cf0 100644 --- a/argmin/src/solver/quasinewton/sr1.rs +++ b/argmin/src/solver/quasinewton/sr1.rs @@ -172,7 +172,7 @@ where L: Clone + LineSearch + Solver>, F: ArgminFloat, { - const NAME: &'static str = "SR1"; + fn name(&self) -> &str { "SR1" } fn init( &mut self, diff --git a/argmin/src/solver/quasinewton/sr1_trustregion.rs b/argmin/src/solver/quasinewton/sr1_trustregion.rs index 10668d2c4..5400b1afd 100644 --- a/argmin/src/solver/quasinewton/sr1_trustregion.rs +++ b/argmin/src/solver/quasinewton/sr1_trustregion.rs @@ -209,7 +209,9 @@ where R: Clone + TrustRegionRadius + Solver>, F: ArgminFloat + ArgminL2Norm, { - const NAME: &'static str = "SR1 trust region"; + fn name(&self) -> &str { + "SR1 trust region" + } fn init( &mut self, diff --git a/argmin/src/solver/simulatedannealing/mod.rs b/argmin/src/solver/simulatedannealing/mod.rs index 75ae76a9a..952901248 100644 --- a/argmin/src/solver/simulatedannealing/mod.rs +++ b/argmin/src/solver/simulatedannealing/mod.rs @@ -442,7 +442,9 @@ where F: ArgminFloat, R: Rng + SerializeAlias, { - const NAME: &'static str = "Simulated Annealing"; + fn name(&self) -> &str { + "Simulated Annealing" + } fn init( &mut self, problem: &mut Problem, diff --git a/argmin/src/solver/trustregion/cauchypoint.rs b/argmin/src/solver/trustregion/cauchypoint.rs index ce94c3c94..5aae1564c 100644 --- a/argmin/src/solver/trustregion/cauchypoint.rs +++ b/argmin/src/solver/trustregion/cauchypoint.rs @@ -58,7 +58,9 @@ where G: ArgminMul + ArgminWeightedDot + ArgminL2Norm, F: ArgminFloat, { - const NAME: &'static str = "Cauchy Point"; + fn name(&self) -> &str { + "Cauchy Point" + } fn next_iter( &mut self, diff --git a/argmin/src/solver/trustregion/dogleg.rs b/argmin/src/solver/trustregion/dogleg.rs index c5a0b92e6..3c50abf6d 100644 --- a/argmin/src/solver/trustregion/dogleg.rs +++ b/argmin/src/solver/trustregion/dogleg.rs @@ -65,7 +65,9 @@ where H: ArgminInv + ArgminDot, F: ArgminFloat, { - const NAME: &'static str = "Dogleg"; + fn name(&self) -> &str { + "Dogleg" + } fn next_iter( &mut self, diff --git a/argmin/src/solver/trustregion/steihaug.rs b/argmin/src/solver/trustregion/steihaug.rs index d01fad3e1..78451e29d 100644 --- a/argmin/src/solver/trustregion/steihaug.rs +++ b/argmin/src/solver/trustregion/steihaug.rs @@ -190,7 +190,9 @@ where H: ArgminDot, F: ArgminFloat, { - const NAME: &'static str = "Steihaug"; + fn name(&self) -> &str { + "Steihaug" + } fn init( &mut self, diff --git a/argmin/src/solver/trustregion/trustregion_method.rs b/argmin/src/solver/trustregion/trustregion_method.rs index c0ea2bf06..363ecc24e 100644 --- a/argmin/src/solver/trustregion/trustregion_method.rs +++ b/argmin/src/solver/trustregion/trustregion_method.rs @@ -175,7 +175,9 @@ where R: Clone + TrustRegionRadius + Solver>, F: ArgminFloat, { - const NAME: &'static str = "Trust region"; + fn name(&self) -> &str { + "Trust region" + } fn init( &mut self, diff --git a/media/book/src/implementing_solver.md b/media/book/src/implementing_solver.md index 7291eeea7..f00113e9a 100644 --- a/media/book/src/implementing_solver.md +++ b/media/book/src/implementing_solver.md @@ -1,7 +1,7 @@ # Implementing a solver In this section we are going to implement the Landweber solver, which essentially is a special form of gradient descent. -In iteration \\( k \\), the new parameter vector \\( x_{k+1} \\) is calculated from the previous parameter vector \\( x_k \\) and the gradient at \\( x_k \\) according to the following update rule: +In iteration \\( k \\), the new parameter vector \\( x\_{k+1} \\) is calculated from the previous parameter vector \\( x_k \\) and the gradient at \\( x_k \\) according to the following update rule: \\[ x_{k+1} = x_k - \omega * \nabla f(x_k) @@ -12,50 +12,50 @@ Then, the [`Solver`](https://docs.rs/argmin/latest/argmin/core/trait.Solver.html The `Solver` trait consists of several methods; however, not all of them need to be implemented since most come with default implementations. -* `NAME`: a `&'static str` which holds the solvers name (mainly needed for the observers). -* `init(...)`: Run before the the actual iterations and initializes the solver. Does nothing by default. -* `next_iter(...)`: One iteration of the solver. Will be executed by the `Executor` until a stopping criterion is met. -* `terminate(...)`: Solver specific stopping criteria. This method is run after every iteration. Note that one can also terminate from within `next_iter` if necessary. -* `terminate_internal(...)`: By default calls `terminate` and in addition checks if the maximum number of iterations was reached or if the best cost function value is below the target cost value. Should only be overwritten if absolutely necessary. +- `NAME`: a `&'static str` which holds the solvers name (mainly needed for the observers). +- `init(...)`: Run before the the actual iterations and initializes the solver. Does nothing by default. +- `next_iter(...)`: One iteration of the solver. Will be executed by the `Executor` until a stopping criterion is met. +- `terminate(...)`: Solver specific stopping criteria. This method is run after every iteration. Note that one can also terminate from within `next_iter` if necessary. +- `terminate_internal(...)`: By default calls `terminate` and in addition checks if the maximum number of iterations was reached or if the best cost function value is below the target cost value. Should only be overwritten if absolutely necessary. Both `init` and `next_iter` have access to the optimization problem (`problem`) as well as the internal state (`state`). -The methods `terminate` and `terminate_internal` only have access to `state`. +The methods `terminate` and `terminate_internal` only have access to `state`. The function parameter `problem` is a wrapped version of the optimization problem and as such gives access to the cost function, gradient, Hessian, Jacobian,...). It also keeps track of how often each of these is called. Via `state` the solver has access to the current parameter vector, the current best parameter vector, gradient, Hessian, Jacobian, population, the current iteration number, and so on. The `state` can be modified (for instance a new parameter vector is set) and is then returned by both `init` and `next_iter`. -The `Executor` then takes care of updating the state properly, for instance by updating the current best parameter vector if the new parameter vector is better than the previous best. +The `Executor` then takes care of updating the state properly, for instance by updating the current best parameter vector if the new parameter vector is better than the previous best. It is advisable to design the solver such that it is generic over the actual type of the parameter vector, gradient, and so on. The current naming convention for generics in argmin is as follows: -* `O`: Optimization problem -* `P`: Parameter vector -* `G`: Gradient -* `J`: Jacobian -* `H`: Hessian -* `F`: Floats (`f32` or `f64`) +- `O`: Optimization problem +- `P`: Parameter vector +- `G`: Gradient +- `J`: Jacobian +- `H`: Hessian +- `F`: Floats (`f32` or `f64`) These individual generic parameters are then constrained by type constraints. For instance, the Landweber iteration requires the problem `O` to implement `Gradient`, therefore a trait bound of the form `O: Gradient` is necessary. From the Landweber update formula, we know that a scaled subtraction of two vectors is required. This must be represented in form of a trait bound as well: `P: ArgminScaledSub`. -`ArgminScaledSub` is a trait from `argmin-math` which represents a scaled subtraction. +`ArgminScaledSub` is a trait from `argmin-math` which represents a scaled subtraction. With this trait bound, we require that it must be possible to subtract a value of type `G` scaled with a value of type `F` from a value of type `P`, resulting in a value of type `P`. The generic type `F` represents floating point value and therefore allows users to choose which precision they want. Implementing the algorithm is straightforward: First we get the current parameter vector `xk` from the state via `state.take_param()`. -Note that `take_param` moves the parameter vector from the `state` into `xk`, therefore one needs to make sure to move the updated parameter vector into `state` at the end of `next_iter` via `state.param(...)`. -Landweber requires the user to provide an initial parameter vector. +Note that `take_param` moves the parameter vector from the `state` into `xk`, therefore one needs to make sure to move the updated parameter vector into `state` at the end of `next_iter` via `state.param(...)`. +Landweber requires the user to provide an initial parameter vector. If this is not the case than we return an error to inform the user. Then the gradient `grad` is computed by calling `problem.gradient(...)` on the parameter vector. -This will return the gradient and internally increase the gradient function evaluation count. -We compute the updated parameter vector `xkp1` by computing `xk.scaled_sub(&self.omega, &grad)` (which is possible because of the `ArgminScaledSub` trait bound introduced before). +This will return the gradient and internally increase the gradient function evaluation count. +We compute the updated parameter vector `xkp1` by computing `xk.scaled_sub(&self.omega, &grad)` (which is possible because of the `ArgminScaledSub` trait bound introduced before). Finally, the state is updated via `state.param(xkp1)` and returned by the function. ```rust @@ -87,7 +87,7 @@ impl Landweber { impl Solver> for Landweber where - // The Landweber solver requires `O` to implement `Gradient`. + // The Landweber solver requires `O` to implement `Gradient`. // `P` and `G` indicate the types of the parameter vector and gradient, // respectively. O: Gradient, @@ -98,7 +98,7 @@ where F: ArgminFloat, { // This gives the solver a name which will be used for logging - const NAME: &'static str = "Landweber"; + fn name(&self) -> &str { "Landweber" } // Defines the computations performed in a single iteration. fn next_iter( @@ -114,20 +114,20 @@ where mut state: IterState, ) -> Result<(IterState, Option), Error> { // First we obtain the current parameter vector from the `state` struct (`x_k`). - // Landweber requires an initial parameter vector. Return an error if this was + // Landweber requires an initial parameter vector. Return an error if this was // not provided by the user. let xk = state.take_param().ok_or_else(argmin_error_closure!( NotInitialized, "Initial parameter vector required!" ))?; - + // Then we compute the gradient at `x_k` (`\nabla f(x_k)`) let grad = problem.gradient(&xk)?; - + // Now subtract `\nabla f(x_k)` scaled by `omega` from `x_k` // to compute `x_{k+1}` let xkp1 = xk.scaled_sub(&self.omega, &grad); - + // Return new the updated `state` Ok((state.param(xkp1), None)) } From f096301e434e878d492c7c6a76486154817148d9 Mon Sep 17 00:00:00 2001 From: Joris Bayer Date: Sun, 26 Mar 2023 15:14:13 +0200 Subject: [PATCH 06/13] fmt --- argmin/src/core/executor.rs | 4 +++- argmin/src/solver/quasinewton/bfgs.rs | 4 +++- argmin/src/solver/quasinewton/lbfgs.rs | 4 +++- argmin/src/solver/quasinewton/sr1.rs | 4 +++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/argmin/src/core/executor.rs b/argmin/src/core/executor.rs index 8d95e945e..103ca8664 100644 --- a/argmin/src/core/executor.rs +++ b/argmin/src/core/executor.rs @@ -591,7 +591,9 @@ mod tests { P: Clone, F: ArgminFloat, { - fn name(&self) -> &str { "OptimizationAlgorithm" } + fn name(&self) -> &str { + "OptimizationAlgorithm" + } // Only resets internal_state to 1 fn init( diff --git a/argmin/src/solver/quasinewton/bfgs.rs b/argmin/src/solver/quasinewton/bfgs.rs index 578467eb0..05b2706eb 100644 --- a/argmin/src/solver/quasinewton/bfgs.rs +++ b/argmin/src/solver/quasinewton/bfgs.rs @@ -159,7 +159,9 @@ where L: Clone + LineSearch + Solver>, F: ArgminFloat, { - fn name(&self) -> &str { "BFGS" } + fn name(&self) -> &str { + "BFGS" + } fn init( &mut self, diff --git a/argmin/src/solver/quasinewton/lbfgs.rs b/argmin/src/solver/quasinewton/lbfgs.rs index 8aeb6b66f..1dddc35c7 100644 --- a/argmin/src/solver/quasinewton/lbfgs.rs +++ b/argmin/src/solver/quasinewton/lbfgs.rs @@ -338,7 +338,9 @@ where L: Clone + LineSearch + Solver, IterState>, F: ArgminFloat, { - fn name(&self) -> &str { "L-BFGS" } + fn name(&self) -> &str { + "L-BFGS" + } fn init( &mut self, diff --git a/argmin/src/solver/quasinewton/sr1.rs b/argmin/src/solver/quasinewton/sr1.rs index 35d3a5cf0..e185169a2 100644 --- a/argmin/src/solver/quasinewton/sr1.rs +++ b/argmin/src/solver/quasinewton/sr1.rs @@ -172,7 +172,9 @@ where L: Clone + LineSearch + Solver>, F: ArgminFloat, { - fn name(&self) -> &str { "SR1" } + fn name(&self) -> &str { + "SR1" + } fn init( &mut self, From 839ef5fa79853b15900caa6a2ca24862d014de7e Mon Sep 17 00:00:00 2001 From: Joris Bayer Date: Sun, 26 Mar 2023 15:16:49 +0200 Subject: [PATCH 07/13] license --- argmin-py/Cargo.toml | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/argmin-py/Cargo.toml b/argmin-py/Cargo.toml index ea4fb9d60..ccdf37070 100644 --- a/argmin-py/Cargo.toml +++ b/argmin-py/Cargo.toml @@ -1,7 +1,17 @@ [package] name = "argmin-py" version = "0.1.0" +authors = ["Joris Bayer Date: Sat, 27 Jan 2024 12:28:39 +0100 Subject: [PATCH 08/13] repair dependencies --- crates/argmin-py/Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/argmin-py/Cargo.toml b/crates/argmin-py/Cargo.toml index ccdf37070..6975eda42 100644 --- a/crates/argmin-py/Cargo.toml +++ b/crates/argmin-py/Cargo.toml @@ -19,10 +19,10 @@ crate-type = ["cdylib"] [dependencies] anyhow = "1.0.70" -argmin_testfunctions = "0.1.1" +argmin_testfunctions = { version = "0.1.1", path = "../argmin-testfunctions" } argmin = {path="../argmin", default-features=false, features=[]} -argmin-math = {path="../argmin-math", features=["ndarray_latest-serde"]} -ndarray-linalg = { version = "0.16", features = ["netlib"] } +argmin-math = {path="../argmin-math", features=["ndarray_latest"]} ndarray = { version = "0.15", features = ["serde-1"] } +ndarray-linalg = { version = "0.16", features = ["intel-mkl-static"] } numpy = "0.18.0" pyo3 = {version="0.18.1", features=["extension-module", "anyhow"]} From d430569c5fedfd5bd662ea1e7fa78e10dbb6bfb3 Mon Sep 17 00:00:00 2001 From: Joris Bayer Date: Sat, 27 Jan 2024 12:35:12 +0100 Subject: [PATCH 09/13] Upgrade dependencies --- crates/argmin-py/Cargo.toml | 4 ++-- crates/argmin-py/src/executor.rs | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/argmin-py/Cargo.toml b/crates/argmin-py/Cargo.toml index 6975eda42..6eb8cd5ff 100644 --- a/crates/argmin-py/Cargo.toml +++ b/crates/argmin-py/Cargo.toml @@ -24,5 +24,5 @@ argmin = {path="../argmin", default-features=false, features=[]} argmin-math = {path="../argmin-math", features=["ndarray_latest"]} ndarray = { version = "0.15", features = ["serde-1"] } ndarray-linalg = { version = "0.16", features = ["intel-mkl-static"] } -numpy = "0.18.0" -pyo3 = {version="0.18.1", features=["extension-module", "anyhow"]} +numpy = "0.20.0" +pyo3 = {version="0.20.2", features=["extension-module", "anyhow"]} diff --git a/crates/argmin-py/src/executor.rs b/crates/argmin-py/src/executor.rs index 490ae27f9..94b5c3f4e 100644 --- a/crates/argmin-py/src/executor.rs +++ b/crates/argmin-py/src/executor.rs @@ -42,11 +42,11 @@ impl Executor { fn configure(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> { if let Some(kwargs) = kwargs { let param = kwargs - .get_item("param") + .get_item("param")? .map(|x| x.extract::<&PyArray1>()) .map_or(Ok(None), |r| r.map(Some))?; let max_iters = kwargs - .get_item("max_iters") + .get_item("max_iters")? .map(|x| x.extract()) .map_or(Ok(None), |r| r.map(Some))?; From 4cbf7a0ce713838b6778bd2248a87854044779de Mon Sep 17 00:00:00 2001 From: Joris Bayer Date: Sat, 27 Jan 2024 12:39:37 +0100 Subject: [PATCH 10/13] Format --- crates/argmin/src/core/executor.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/argmin/src/core/executor.rs b/crates/argmin/src/core/executor.rs index b1b2bacfb..236ac053d 100644 --- a/crates/argmin/src/core/executor.rs +++ b/crates/argmin/src/core/executor.rs @@ -180,7 +180,8 @@ where let kv = kv.unwrap_or(kv![]); // Observe after init - self.observers.observe_init(self.solver.name(), &state, &kv)?; + self.observers + .observe_init(self.solver.name(), &state, &kv)?; } state.func_counts(&self.problem); From 61f157ef066b62eb4da1d3e06ea460774fc05f48 Mon Sep 17 00:00:00 2001 From: Joris Bayer Date: Sat, 27 Jan 2024 12:44:38 +0100 Subject: [PATCH 11/13] Checkin empty README --- crates/argmin-py/README.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 crates/argmin-py/README.md diff --git a/crates/argmin-py/README.md b/crates/argmin-py/README.md new file mode 100644 index 000000000..e69de29bb From 52030e00ee3be945b57924863ada52646eaf05c1 Mon Sep 17 00:00:00 2001 From: Joris Bayer Date: Sat, 27 Jan 2024 12:47:55 +0100 Subject: [PATCH 12/13] Exclude argmin-py from wasm build --- .github/workflows/ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f82cd4fdd..a4b364abe 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -181,11 +181,11 @@ jobs: - name: Install wasm-pack run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh - name: Build target wasm32-unknown-unknown - run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --target wasm32-unknown-unknown --features wasm-bindgen + run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --exclude argmin-py --target wasm32-unknown-unknown --features wasm-bindgen - name: Build target wasm32-wasi with feature wasm-bindgen - run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --target wasm32-wasi --features wasm-bindgen + run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --exclude argmin-py --target wasm32-wasi --features wasm-bindgen - name: Build target wasm32-unknown-emscripten - run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --target wasm32-unknown-emscripten --no-default-features --features wasm-bindgen + run: cargo build --workspace --exclude argmin-observer-spectator --exclude spectator --exclude argmin-observer-paramwriter --exclude "example-*" --exclude argmin-py --target wasm32-unknown-emscripten --no-default-features --features wasm-bindgen cargo-deny: runs-on: ubuntu-latest From ee48f38c378939d844514aa84930a1e3b69fde61 Mon Sep 17 00:00:00 2001 From: Joris Bayer Date: Sat, 27 Jan 2024 12:52:35 +0100 Subject: [PATCH 13/13] More formatting --- crates/argmin-py/src/types.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/argmin-py/src/types.rs b/crates/argmin-py/src/types.rs index 5bdaa1077..57cab8f97 100644 --- a/crates/argmin-py/src/types.rs +++ b/crates/argmin-py/src/types.rs @@ -12,4 +12,5 @@ pub type Array1 = ndarray::Array1; pub type Array2 = ndarray::Array2; pub type PyArray1 = numpy::PyArray1; -pub type IterState = argmin::core::IterState, (), Scalar>; +pub type IterState = + argmin::core::IterState, (), Scalar>;