Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplex algorithm #183

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion argmin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ wasm-bindgen = ["instant/wasm-bindgen", "getrandom/js"]
slog-logger = ["slog", "slog-term", "slog-async"]
serde1 = ["serde", "serde_json", "rand/serde1", "bincode", "slog-json"]
ndarrayl = ["argmin-math/ndarray_latest-serde"]
nalgebral = ["argmin-math/nalgebra_latest-serde"]

[badges]
maintenance = { status = "actively-developed" }
Expand Down Expand Up @@ -93,7 +94,7 @@ required-features = ["argmin-math/ndarray_latest-serde", "slog-logger"]

[[example]]
name = "gaussnewton_nalgebra"
required-features = ["argmin-math/nalgebra_latest-serde", "slog-logger"]
required-features = ["nalgebral", "argmin-math/nalgebra_latest-serde", "slog-logger"]

[[example]]
name = "goldensectionsearch"
Expand Down
89 changes: 89 additions & 0 deletions argmin/examples/simplex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright 2018-2022 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// use argmin::core::{ArgminOp, ArgminSlogLogger, Error, Executor, ObserverMode};
use argmin::core::*;
use argmin::solver::simplex::Simplex;

struct Problem {
c: Vec<f64>,
b: Vec<f64>,
a: Vec<Vec<f64>>,
}

impl LinearProgram for Problem {
type Param = Vec<Self::Float>;
type Float = f64;

fn c(&self) -> Result<&[Self::Float], Error> {
Ok(&self.c)
}

fn b(&self) -> Result<&[Self::Float], Error> {
Ok(&self.b)
}

fn A(&self) -> Result<&[Vec<Self::Float>], Error> {
Ok(&self.a)
}
}

fn run() -> Result<(), Error> {
// let problem = Problem {
// c: [-5.0f64, -6.0, -6.0].to_vec(),
// b: [10.0f64, 10.0, 10.0].to_vec(),
// a: [
// [1.0, 2.0, 2.0].to_vec(),
// [2.0, 1.0, 2.0].to_vec(),
// [2.0, 2.0, 1.0].to_vec(),
// ]
// .to_vec(),
// };
// let problem = Problem {
// c: [-3.0f64, -1.0].to_vec(),
// b: [1.0f64, 1.0, 2.0].to_vec(),
// a: [
// [-1.0, 1.0].to_vec(),
// [1.0, -1.0].to_vec(),
// [0.0, 1.0].to_vec(),
// ]
// .to_vec(),
// };
let problem = Problem {
c: [-12.0f64, -8.0].to_vec(),
b: [80.0f64, 100.0, 75.0].to_vec(),
a: [
[4.0, 2.0].to_vec(),
[2.0, 3.0].to_vec(),
[5.0, 1.0].to_vec(),
]
.to_vec(),
};

let solver: Simplex<f64> = Simplex::new();

// let init_param = [1.0, 2.0, 3.0].to_vec();
let init_param = vec![];
let res = Executor::new(problem, solver, init_param)
// .add_observer(ArgminSlogLogger::term(), ObserverMode::Always)
.max_iters(3)
.run()?;

// Wait a second (lets the logger flush everything before printing again)
std::thread::sleep(std::time::Duration::from_secs(1));

// Print result
println!("{}", res);
Ok(())
}

fn main() {
if let Err(ref e) = run() {
println!("{}", e);
std::process::exit(1);
}
}
54 changes: 27 additions & 27 deletions argmin/src/core/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
#[cfg(feature = "serde1")]
use crate::core::{serialization::*, ArgminCheckpoint, DeserializeOwnedAlias};
use crate::core::{
ArgminIterData, ArgminKV, ArgminOp, ArgminResult, Error, IterState, Observe, Observer,
ObserverMode, OpWrapper, Solver, TerminationReason,
ArgminIterData, ArgminKV, ArgminResult, Error, Observe, Observer, ObserverMode, OpWrapper,
Solver, State, TerminationReason,
};
use instant;
use num_traits::Float;
Expand All @@ -25,18 +25,18 @@ use std::sync::Arc;
/// Executes a solver
#[derive(Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Executor<O: ArgminOp, S> {
pub struct Executor<O, S, I> {
/// solver
solver: S,
/// operator
#[cfg_attr(feature = "serde1", serde(skip))]
pub op: OpWrapper<O>,
/// State
#[cfg_attr(feature = "serde1", serde(bound = "IterState<O>: Serialize"))]
state: IterState<O>,
// #[cfg_attr(feature = "serde1", serde(bound = "IterState<O>: Serialize"))]
state: I,
/// Storage for observers
#[cfg_attr(feature = "serde1", serde(skip))]
observers: Observer<O>,
observers: Observer<I>,
/// Checkpoint
#[cfg(feature = "serde1")]
checkpoint: ArgminCheckpoint,
Expand All @@ -46,14 +46,14 @@ pub struct Executor<O: ArgminOp, S> {
timer: bool,
}

impl<O, S> Executor<O, S>
impl<O, S, I> Executor<O, S, I>
where
O: ArgminOp,
S: Solver<O>,
S: Solver<I>,
I: State<Operator = O>,
{
/// Create a new executor with a `solver` and an initial parameter `init_param`
pub fn new(op: O, solver: S, init_param: O::Param) -> Self {
let state = IterState::new(init_param);
pub fn new(op: O, solver: S, init_param: I::Param) -> Self {
let state = I::new(init_param);
Executor {
solver,
op: OpWrapper::new(op),
Expand All @@ -77,7 +77,7 @@ where
Ok(executor)
}

fn update(&mut self, data: &ArgminIterData<O>) -> Result<(), Error> {
fn update(&mut self, data: &ArgminIterData<I>) -> Result<(), Error> {
if let Some(cur_param) = data.get_param() {
self.state.param(cur_param);
}
Expand Down Expand Up @@ -122,7 +122,7 @@ where
}

/// Run the executor
pub fn run(mut self) -> Result<ArgminResult<O>, Error> {
pub fn run(mut self) -> Result<ArgminResult<I>, Error> {
let total_time = if self.timer {
Some(instant::Instant::now())
} else {
Expand Down Expand Up @@ -247,7 +247,7 @@ where

/// Attaches a observer which implements `ArgminLog` to the solver.
#[must_use]
pub fn add_observer<OBS: Observe<O> + 'static>(
pub fn add_observer<OBS: Observe<I> + 'static>(
mut self,
observer: OBS,
mode: ObserverMode,
Expand All @@ -265,35 +265,35 @@ where

/// Set target cost value
#[must_use]
pub fn target_cost(mut self, cost: O::Float) -> Self {
pub fn target_cost(mut self, cost: I::Float) -> Self {
self.state.target_cost(cost);
self
}

/// Set cost value
#[must_use]
pub fn cost(mut self, cost: O::Float) -> Self {
pub fn cost(mut self, cost: I::Float) -> Self {
self.state.cost(cost);
self
}

/// Set Gradient
#[must_use]
pub fn grad(mut self, grad: O::Param) -> Self {
pub fn grad(mut self, grad: I::Param) -> Self {
self.state.grad(grad);
self
}

/// Set Hessian
#[must_use]
pub fn hessian(mut self, hessian: O::Hessian) -> Self {
pub fn hessian(mut self, hessian: I::Hessian) -> Self {
self.state.hessian(hessian);
self
}

/// Set Jacobian
#[must_use]
pub fn jacobian(mut self, jacobian: O::Jacobian) -> Self {
pub fn jacobian(mut self, jacobian: I::Jacobian) -> Self {
self.state.jacobian(jacobian);
self
}
Expand Down Expand Up @@ -340,7 +340,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::core::MinimalNoOperator;
use crate::core::{ArgminOp, IterState, MinimalNoOperator, State};
use approx::assert_relative_eq;

#[test]
Expand All @@ -349,15 +349,15 @@ mod tests {
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct TestSolver {}

impl<O> Solver<O> for TestSolver
impl<O> Solver<IterState<O>> for TestSolver
where
O: ArgminOp,
{
fn next_iter(
&mut self,
_op: &mut OpWrapper<O>,
_state: &mut IterState<O>,
) -> Result<ArgminIterData<O>, Error> {
) -> Result<ArgminIterData<IterState<O>>, Error> {
Ok(ArgminIterData::new())
}
}
Expand All @@ -369,7 +369,7 @@ mod tests {

// 1) Parameter vector changes, but not cost (continues to be `Inf`)
let new_param = vec![1.0, 1.0];
let new_iterdata: ArgminIterData<MinimalNoOperator> =
let new_iterdata: ArgminIterData<IterState<MinimalNoOperator>> =
ArgminIterData::new().param(new_param.clone());
executor.update(&new_iterdata).unwrap();
assert_eq!(executor.state.get_best_param().unwrap(), new_param);
Expand All @@ -379,7 +379,7 @@ mod tests {
// 2) Parameter vector and cost changes to something better
let new_param = vec![2.0, 2.0];
let new_cost = 10.0;
let new_iterdata: ArgminIterData<MinimalNoOperator> = ArgminIterData::new()
let new_iterdata: ArgminIterData<IterState<MinimalNoOperator>> = ArgminIterData::new()
.param(new_param.clone())
.cost(new_cost);
executor.update(&new_iterdata).unwrap();
Expand All @@ -395,7 +395,7 @@ mod tests {
let new_param = vec![3.0, 3.0];
let old_cost = executor.state.get_best_cost();
let new_cost = old_cost + 1.0;
let new_iterdata: ArgminIterData<MinimalNoOperator> =
let new_iterdata: ArgminIterData<IterState<MinimalNoOperator>> =
ArgminIterData::new().param(new_param).cost(new_cost);
executor.update(&new_iterdata).unwrap();
assert_eq!(executor.state.get_best_param(), old_param);
Expand All @@ -411,7 +411,7 @@ mod tests {

let new_param = vec![1.0, 1.0];
let new_cost = std::f64::NEG_INFINITY;
let new_iterdata: ArgminIterData<MinimalNoOperator> = ArgminIterData::new()
let new_iterdata: ArgminIterData<IterState<MinimalNoOperator>> = ArgminIterData::new()
.param(new_param.clone())
.cost(new_cost);
executor.update(&new_iterdata).unwrap();
Expand All @@ -423,7 +423,7 @@ mod tests {
let old_param = executor.state.get_best_param().unwrap();
let new_param = vec![6.0, 6.0];
let new_cost = std::f64::INFINITY;
let new_iterdata: ArgminIterData<MinimalNoOperator> =
let new_iterdata: ArgminIterData<IterState<MinimalNoOperator>> =
ArgminIterData::new().param(new_param).cost(new_cost);
executor.update(&new_iterdata).unwrap();
assert_eq!(executor.state.get_best_param().unwrap(), old_param);
Expand Down
Loading