Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make "sparse" solver check if equations are linear. #860

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion nmodl/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from importlib import import_module

import sympy as sp
import itertools

# import known_functions through low-level mechanism because the ccode
# module is overwritten in sympy and contents of that submodule cannot be
Expand Down Expand Up @@ -272,6 +273,8 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):

eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants)

linear = _is_linear(eqs, state_vars, sympy_vars)

custom_fcts = _get_custom_functions(function_calls)

jacobian = sp.Matrix(eqs).jacobian(state_vars)
Expand All @@ -291,7 +294,18 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls):
# interweave
code = _interweave_eqs(vecFcode, vecJcode)

return code
return code, linear


def _is_linear(eqs, state_vars, sympy_vars):
for expr in eqs:
for (x, y) in itertools.combinations_with_replacement(state_vars, 2):
alkino marked this conversation as resolved.
Show resolved Hide resolved
try:
if not sp.Eq(sp.diff(expr, x, y), 0):
return False
except TypeError:
return False
return True


def integrate2c(diff_string, dt_var, vars, use_pade_approx=False):
Expand Down
2 changes: 2 additions & 0 deletions src/pybind/pyembed.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ struct SolveNonLinearSystemExecutor: public PythonExecutor {
// output
// returns a vector of solutions, i.e. new statements to add to block:
std::vector<std::string> solutions;
// returns if the system is linear or not.
bool linear;
// may also return a python exception message:
std::string exception_message;

Expand Down
4 changes: 3 additions & 1 deletion src/pybind/wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,20 +66,22 @@ void SolveNonLinearSystemExecutor::operator()() {
from nmodl.ode import solve_non_lin_system
exception_message = ""
try:
solutions = solve_non_lin_system(equation_strings,
solutions, linear = solve_non_lin_system(equation_strings,
state_vars,
vars,
function_calls)
except Exception as e:
# if we fail, fail silently and return empty string
solutions = [""]
linear = False
new_local_vars = [""]
exception_message = str(e)
)",
py::globals(),
locals);
// returns a vector of solutions, i.e. new statements to add to block:
solutions = locals["solutions"].cast<std::vector<std::string>>();
linear = locals["linear"].cast<bool>();
// may also return a python exception message:
exception_message = locals["exception_message"].cast<std::string>();
}
Expand Down
10 changes: 8 additions & 2 deletions src/visitors/sympy_solver_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ void SympySolverVisitor::solve_non_linear_system(
(*solver)();
// returns a vector of solutions, i.e. new statements to add to block:
auto solutions = solver->solutions;
bool linear = solver->linear;
// may also return a python exception message:
auto exception_message = solver->exception_message;
pywrap::EmbeddedPythonLoader::get_instance().api()->destroy_nsls_executor(solver);
Expand All @@ -364,8 +365,13 @@ void SympySolverVisitor::solve_non_linear_system(
exception_message);
return;
}
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
construct_eigen_solver_block(pre_solve_statements, solutions, false);
if (!linear) {
logger->debug("SympySolverVisitor :: Constructing eigen newton solve block");
}
else {
logger->debug("SympySolverVisitor :: Constructing eigen solve block");
}
construct_eigen_solver_block(pre_solve_statements, solutions, linear);
}

void SympySolverVisitor::visit_var_name(ast::VarName& node) {
Expand Down
Loading