Skip to content

Commit

Permalink
Added python bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavidberger committed Oct 31, 2022
1 parent d75677a commit 2d45e25
Show file tree
Hide file tree
Showing 16 changed files with 636 additions and 185 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:

- name: Build wheel
run: |
python -m build
python -m build --wheel
git status
- uses: actions/upload-artifact@v2
Expand Down
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ set(CMAKE_CXX_STANDARD 17)

option(USE_ASAN "Use address sanitizer" OFF)

if(NOT PYTHON_EXECUTABLE)
find_package (Python 3.8 COMPONENTS Interpreter Development REQUIRED)
endif()

if(UNIX)
add_compile_options(-fPIC -Wall -Wno-unused-variable -Wno-switch -Wno-parentheses -Wno-missing-braces -Werror=return-type -fvisibility=hidden -Werror=vla -fno-math-errno -Werror=pointer-arith)
SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${SHARED_FLAGS} -std=gnu99 -Werror=incompatible-pointer-types -Werror=implicit-function-declaration -Werror=missing-field-initializers ")
Expand Down
19 changes: 19 additions & 0 deletions include/cnkalman/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,24 @@ namespace cnkalman {
cnkalman_meas_model meas_mdl = {};

KalmanMeasurementModel(KalmanModel* kalmanModel, const std::string& name, size_t meas_cnt);
KalmanMeasurementModel(KalmanModel* kalmanModel, size_t meas_cnt) : KalmanMeasurementModel(kalmanModel, "meas", meas_cnt) {}
virtual ~KalmanMeasurementModel() = default;

/***
* @param x current state
* @param z measurement prediction
* @param h measurement jacobian wrt x
* @return Whether the residual / jacobian is valid
*/
virtual bool predict_measurement(const CnMat& x, CnMat* z, CnMat* h) = 0;

/***
* @param Z observed measurement
* @param x current state
* @param z measurement prediction
* @param h measurement jacobian wrt x
* @return Whether the residual / jacobian is valid
*/
virtual bool residual(const CnMat& Z, const CnMat& x, CnMat* y, CnMat* h);

cnkalman_update_extended_stats_t update(FLT t, const struct CnMat& Z, CnMat& R);
Expand Down Expand Up @@ -47,6 +63,7 @@ namespace cnkalman {
virtual std::ostream& write(std::ostream&) const;

KalmanModel(const std::string& name, size_t state_cnt);
KalmanModel(size_t state_cnt) : KalmanModel("mdl", state_cnt) {}
virtual void reset();
virtual ~KalmanModel();

Expand All @@ -68,13 +85,15 @@ namespace cnkalman {
struct CN_EXPORT_CLASS KalmanLinearPredictionModel : public KalmanModel {
virtual const CnMat& F() const = 0;
KalmanLinearPredictionModel(const std::string &name, size_t stateCnt);
KalmanLinearPredictionModel(size_t stateCnt) : KalmanLinearPredictionModel("mdl", stateCnt) {}
void predict(FLT dt, const CnMat& x0, CnMat* x1, CnMat* cF) override;
};

struct CN_EXPORT_CLASS KalmanLinearMeasurementModel : public KalmanMeasurementModel {
CnMat H;

KalmanLinearMeasurementModel(KalmanModel* kalmanModel, const std::string& name, const CnMat& H);
KalmanLinearMeasurementModel(KalmanModel* kalmanModel, const CnMat& H) : KalmanLinearMeasurementModel(kalmanModel, "meas", H) {}
~KalmanLinearMeasurementModel() override;
bool predict_measurement(const CnMat &x, CnMat *z, CnMat *h) override;
};
Expand Down
73 changes: 73 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import glob
import os
import pathlib
import shutil
import sys

from setuptools import setup, Extension
from setuptools.command.build_ext import build_ext as build_ext_orig

cmake_path = os.environ.get('CMAKE_PATH', "cmake")
class CMakeExtension(Extension):

def __init__(self, name):
# don't invoke the original build_ext for this special extension
super().__init__(name, sources=[])


class build_ext(build_ext_orig):

def run(self):
for ext in self.extensions:
self.build_cmake(ext)
super().run()

def build_cmake(self, ext):
cwd = pathlib.Path().absolute()

# these dirs will be created in build_py, so if you don't have
# any python sources to bundle, the dirs will be missing
build_temp = pathlib.Path(self.build_temp)
build_temp.mkdir(parents=True, exist_ok=True)
extdir = pathlib.Path(self.get_ext_fullpath(ext.name))
extdir.mkdir(parents=True, exist_ok=True)

config = 'Release'
cmake_args = [
'-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + str(extdir.parent.absolute()),
'-DCMAKE_BUILD_TYPE=' + config,
'-DCMAKE_INSTALL_PREFIX=' + str(build_temp.absolute()),
'-DPYTHON_EXECUTABLE=' + sys.executable
]

# example of build args
build_args = [
'--config', config,
'--', '-j4'
]

build_temp_dir = str(build_temp.absolute())
ext_dir_parent = str(extdir.parent.absolute())
os.chdir(str(build_temp.absolute()))
self.spawn([cmake_path, str(cwd)] + cmake_args)
if not self.dry_run:
self.spawn([cmake_path, '--build', '.', '--target', 'install'] + build_args)
print(build_temp_dir + "/lib/*.so")
for file in glob.glob(build_temp_dir + "/lib/*.so"):
print(file)
shutil.copy(file, ext_dir_parent + "/cnkalman")

# Troubleshooting: if fail on line above then delete all possible
# temporary CMake files including "CMakeCache.txt" in top level dir.
os.chdir(str(cwd))


setup(
name='cnkalman',
packages=['cnkalman'],
ext_modules=[CMakeExtension('.')],
cmdclass={
'build_ext': build_ext,
},
setup_requires=["setuptools-git-versioning"],
)
7 changes: 7 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,10 @@ target_include_directories(cnkalman PUBLIC
target_link_libraries(cnkalman cnmatrix)

install(TARGETS cnkalman DESTINATION lib)

find_package(pybind11)
if(pybind11_FOUND)
pybind11_add_module(filter cnkalman_python_bindings.cpp)
target_link_libraries(filter PUBLIC cnkalman)
install(TARGETS filter DESTINATION lib)
endif()
Loading

0 comments on commit 2d45e25

Please sign in to comment.