Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC: Python bindings #340

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ target/
target/*
*.log
justfile
.vscode
.venv
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
members = [
"argmin",
"argmin-math",
"argmin-py",
]

exclude = [
Expand Down
18 changes: 18 additions & 0 deletions argmin-py/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[package]
name = "argmin-py"
version = "0.1.0"
edition = "2021"

[lib]
name = "argmin"
crate-type = ["cdylib"]

[dependencies]
anyhow = "1.0.70"
argmin_testfunctions = "0.1.1"
argmin = {path="../argmin", default-features=false, features=[]}
argmin-math = {path="../argmin-math", features=["ndarray_latest-serde"]}
ndarray-linalg = { version = "0.16", features = ["netlib"] }
ndarray = { version = "0.15", features = ["serde-1"] }
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This section requires cleanup, I am not sure what's the best configuration of features.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks mostly fine. I think in the long run at least the serde1 feature of argmin can be enabled because that would allow checkpointing. But I guess checkpointing will need more work anyways.
At some point we will have to decide which BLAS backend to use. This is probably mostly a platform issue (only Intel-MKL works on Linux, Windows and Mac) and a licensing issue since the compiled code will be packaged into a python module.

numpy = "0.18.0"
pyo3 = {version="0.18.1", features=["extension-module", "anyhow"]}
26 changes: 26 additions & 0 deletions argmin-py/examples/newton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2018-2023 argmin developers
#
# Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
# http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
# http://opensource.org/licenses/MIT>, at your option. This file may not be
# copied, modified, or distributed except according to those terms.
from argmin import Problem, Solver, Executor
import numpy as np
from scipy.optimize import rosen_der, rosen_hess


def main():
problem = Problem(
gradient=rosen_der,
hessian=rosen_hess,
)
solver = Solver.Newton
executor = Executor(problem, solver)
executor.configure(param=np.array([-1.2, 1.0]), max_iters=8)

result = executor.run()
print(result)


if __name__ == "__main__":
main()
71 changes: 71 additions & 0 deletions argmin-py/src/executor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: docs

use pyo3::{prelude::*, types::PyDict};

use argmin::core;

use crate::problem::Problem;
use crate::solver::{DynamicSolver, Solver};
use crate::types::{IterState, PyArray1};

#[pyclass]
pub struct Executor(Option<core::Executor<Problem, DynamicSolver, IterState>>);

impl Executor {
/// Consumes the inner executor.
///
/// PyObjects do not allow methods that consume the object itself, so this is a workaround
/// for using methods like `configure` and `run`.
fn take(&mut self) -> anyhow::Result<core::Executor<Problem, DynamicSolver, IterState>> {
stefan-k marked this conversation as resolved.
Show resolved Hide resolved
let Some(inner) = self.0.take() else {
return Err(anyhow::anyhow!("Executor was already run."));
};
Ok(inner)
}
}

#[pymethods]
impl Executor {
#[new]
fn new(problem: Problem, solver: Solver) -> Self {
Self(Some(core::Executor::new(problem, solver.into())))
}

#[pyo3(signature = (**kwargs))]
fn configure(&mut self, kwargs: Option<&PyDict>) -> PyResult<()> {
if let Some(kwargs) = kwargs {
let param = kwargs
.get_item("param")
.map(|x| x.extract::<&PyArray1>())
.map_or(Ok(None), |r| r.map(Some))?;
let max_iters = kwargs
.get_item("max_iters")
.map(|x| x.extract())
.map_or(Ok(None), |r| r.map(Some))?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like how you transformed the somewhat weird way one needs to set the initial state in argmin (using a closure) into a very pythonic **kwargs thing. This may however turn into quite a chore given that there are multiple kinds of state (IterState and PopulationState) with lots of methods. However, I would label this problem low priority for now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. I might rewrite this when I add more solvers to the python extension.


self.0 = Some(self.take()?.configure(|mut state| {
if let Some(param) = param {
state = state.param(param.to_owned_array());
}
if let Some(max_iters) = max_iters {
state = state.max_iters(max_iters);
}
state
}));
}
Ok(())
}

fn run(&mut self) -> PyResult<String> {
// TODO: return usable OptimizationResult
let res = self.take()?.run();
Ok(res?.to_string())
}
}
24 changes: 24 additions & 0 deletions argmin-py/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: docs
mod executor;
mod problem;
mod solver;
mod types;

use pyo3::prelude::*;

#[pymodule]
#[pyo3(name = "argmin")]
fn argmin_py(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<executor::Executor>()?;
m.add_class::<problem::Problem>()?;
m.add_class::<solver::Solver>()?;

Ok(())
}
68 changes: 68 additions & 0 deletions argmin-py/src/problem.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: docs

use numpy::ToPyArray;
use pyo3::{prelude::*, types::PyTuple};

use argmin::core;

use crate::types::{Array1, Array2, Scalar};

#[pyclass]
#[derive(Clone)]
pub struct Problem {
gradient: PyObject,
hessian: PyObject,
// TODO: jacobian
}

#[pymethods]
impl Problem {
#[new]
fn new(gradient: PyObject, hessian: PyObject) -> Self {
Self { gradient, hessian }
}
}

impl core::Gradient for Problem {
type Param = Array1;
type Gradient = Array1;

fn gradient(&self, param: &Self::Param) -> Result<Self::Gradient, argmin::core::Error> {
call(&self.gradient, param)
}
}

impl argmin::core::Hessian for Problem {
type Param = Array1;

type Hessian = Array2;

fn hessian(&self, param: &Self::Param) -> Result<Self::Hessian, core::Error> {
call(&self.hessian, param)
}
}

fn call<InputDimension, OutputDimension>(
callable: &PyObject,
param: &ndarray::Array<Scalar, InputDimension>,
) -> Result<ndarray::Array<Scalar, OutputDimension>, argmin::core::Error>
where
InputDimension: ndarray::Dimension,
OutputDimension: ndarray::Dimension,
{
// TODO: prevent dynamic dispatch for every call
Python::with_gil(|py| {
let args = PyTuple::new(py, [param.to_pyarray(py)]);
let pyresult = callable.call(py, args, Default::default())?;
let pyarray = pyresult.extract::<&numpy::PyArray<Scalar, OutputDimension>>(py)?;
// TODO: try to get ownership instead of cloning
Ok(pyarray.to_owned_array())
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I am unsure what's the overhead of calling to_pyarray and extract for every evaluation of the gradient, hessian etc. Probably needs benchmarks.
  2. to_owned_array makes a copy of the data, this should not be necessary.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just had a new idea: You're currently using the ndarray backend which requires transitioning between numpy arrays and ndarray arrays. Instead we could add a new math backend based on PyArray, which would mean that numpy would do all the heavy lifting. I'm not sure whether numpy or ndarray is faster and I haven't really thought this through either.

I somehow thought that it would be possible to use the underlying memory in both Rust and Python without copying but that doesn't seem to be the case.

Regarding point 2: I agree. I assumed that there is also into_owned_array but that does not seem to be the case.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I somehow thought that it would be possible to use the underlying memory in both Rust and Python without copying but that doesn't seem to be the case.

I'll give it a try!

})
}
46 changes: 46 additions & 0 deletions argmin-py/src/solver.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

// TODO: docs

use pyo3::prelude::*;

use argmin::{core, solver};

use crate::{problem::Problem, types::IterState};

#[pyclass]
#[derive(Clone)]
pub enum Solver {
Newton,
}

pub struct DynamicSolver(Box<dyn core::Solver<Problem, IterState> + Send>);

impl From<Solver> for DynamicSolver {
fn from(value: Solver) -> Self {
let inner = match value {
Solver::Newton => solver::newton::Newton::new(),
};
Self(Box::new(inner))
}
}

impl core::Solver<Problem, IterState> for DynamicSolver {
// TODO: make this a trait method so we can return a dynamic
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me! We could have both for backwards compatibility, right? The default impl of the name method would then just return self.NAME.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems I was able to solve two problems at once: When I remove the associated constant, Solver becomes object-safe, so we can create trait objects for it. Let me know if you objects.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! I thought the generics would also be a problem for object safety, but it's great if this isn't the case. Sounds good to me!

fn name(&self) -> &str {
self.0.name()
}

fn next_iter(
&mut self,
problem: &mut core::Problem<Problem>,
state: IterState,
) -> Result<(IterState, Option<core::KV>), core::Error> {
self.0.next_iter(problem, state)
}
}
15 changes: 15 additions & 0 deletions argmin-py/src/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright 2018-2023 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! Base types for the Python extension.

pub type Scalar = f64; // TODO: allow complex numbers
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Complex numbers would be great but I wouldn't give this a high priority.

pub type Array1 = ndarray::Array1<Scalar>;
pub type Array2 = ndarray::Array2<Scalar>;
pub type PyArray1 = numpy::PyArray1<Scalar>;

pub type IterState = argmin::core::IterState<Array1, Array1, (), ndarray::Array2<Scalar>, Scalar>;
2 changes: 1 addition & 1 deletion argmin/src/core/checkpointing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ use std::fmt::Display;
/// }
/// # fn main() {}
/// ```
pub trait Checkpoint<S, I> {
pub trait Checkpoint<S, I>: Send {
/// Save a checkpoint
///
/// Gets a reference to the current `solver` of type `S` and to the current `state` of type
Expand Down
4 changes: 2 additions & 2 deletions argmin/src/core/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ where
}

// Observe after init
self.observers.observe_init(S::NAME, &logs)?;
self.observers.observe_init(self.solver.name(), &logs)?;
}

state.func_counts(&self.problem);
Expand Down Expand Up @@ -591,7 +591,7 @@ mod tests {
P: Clone,
F: ArgminFloat,
{
const NAME: &'static str = "OptimizationAlgorithm";
fn name(&self) -> &str { "OptimizationAlgorithm" }

// Only resets internal_state to 1
fn init(
Expand Down
2 changes: 1 addition & 1 deletion argmin/src/core/observers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ use std::sync::{Arc, Mutex};
/// }
/// }
/// ```
pub trait Observe<I> {
pub trait Observe<I>: Send {
/// Called once after initialization of the solver.
///
/// Has access to the name of the solver via `name` and to a key-value store `kv` with entries
Expand Down
2 changes: 1 addition & 1 deletion argmin/src/core/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ where
{
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
writeln!(f, "OptimizationResult:")?;
writeln!(f, " Solver: {}", S::NAME)?;
writeln!(f, " Solver: {}", self.solver().name())?;
writeln!(
f,
" param (best): {}",
Expand Down
7 changes: 5 additions & 2 deletions argmin/src/core/solver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K
/// P: Clone,
/// F: ArgminFloat
/// {
/// const NAME: &'static str = "OptimizationAlgorithm";
/// fn name(&self) -> &str { "OptimizationAlgorithm" }
///
/// fn init(
/// &mut self,
Expand Down Expand Up @@ -67,7 +67,7 @@ use crate::core::{Error, Problem, State, TerminationReason, TerminationStatus, K
/// ```
pub trait Solver<O, I: State> {
/// Name of the solver. Mainly used in [Observers](`crate::core::observers::Observe`).
const NAME: &'static str;
// const NAME: &'static str;

/// Initializes the algorithm.
///
Expand Down Expand Up @@ -117,4 +117,7 @@ pub trait Solver<O, I: State> {
fn terminate(&mut self, _state: &I) -> TerminationStatus {
TerminationStatus::NotTerminated
}

/// Returns the name of the solver.
fn name(&self) -> &str;
}
4 changes: 3 additions & 1 deletion argmin/src/core/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,9 @@ impl TestSolver {
}

impl<O> Solver<O, IterState<Vec<f64>, (), (), (), f64>> for TestSolver {
const NAME: &'static str = "TestSolver";
fn name(&self) -> &str {
"TestSolver"
}

fn next_iter(
&mut self,
Expand Down
4 changes: 3 additions & 1 deletion argmin/src/solver/brent/brentopt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,9 @@ where
O: CostFunction<Param = F, Output = F>,
F: ArgminFloat,
{
const NAME: &'static str = "BrentOpt";
fn name(&self) -> &str {
"BrentOpt"
}

fn init(
&mut self,
Expand Down
4 changes: 3 additions & 1 deletion argmin/src/solver/brent/brentroot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ where
O: CostFunction<Param = F, Output = F>,
F: ArgminFloat,
{
const NAME: &'static str = "BrentRoot";
fn name(&self) -> &str {
"BrentRoot"
}

fn init(
&mut self,
Expand Down
Loading