From 00a0296c885d997a709989f67e983a2d5ebc134d Mon Sep 17 00:00:00 2001 From: ctrl-z-9000-times Date: Wed, 27 Apr 2022 10:12:15 -0400 Subject: [PATCH] Make "sparse" solver check if equations are linear. If the system is linear, then newtons method always converges in exactly one iteration. When using the sparse solver on linear systems omit the newtons iteration and solve directly. This should make the resulting code run marginally faster by skipping the check for convergence. Currently the check for convergence is implemented as "error = sqrt(|F|^2)". --- nmodl/ode.py | 16 ++++++- src/pybind/pyembed.hpp | 2 + src/pybind/wrapper.cpp | 4 +- src/visitors/sympy_solver_visitor.cpp | 10 ++++- test/unit/visitor/sympy_solver.cpp | 60 ++++++++++----------------- 5 files changed, 50 insertions(+), 42 deletions(-) diff --git a/nmodl/ode.py b/nmodl/ode.py index cdbeca0458..368241afb0 100644 --- a/nmodl/ode.py +++ b/nmodl/ode.py @@ -8,6 +8,7 @@ from importlib import import_module import sympy as sp +import itertools # import known_functions through low-level mechanism because the ccode # module is overwritten in sympy and contents of that submodule cannot be @@ -272,6 +273,8 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls): eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants) + linear = _is_linear(eqs, state_vars, sympy_vars) + custom_fcts = _get_custom_functions(function_calls) jacobian = sp.Matrix(eqs).jacobian(state_vars) @@ -291,7 +294,18 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls): # interweave code = _interweave_eqs(vecFcode, vecJcode) - return code + return code, linear + + +def _is_linear(eqs, state_vars, sympy_vars): + for expr in eqs: + for (x, y) in itertools.combinations_with_replacement(state_vars, 2): + try: + if not sp.Eq(sp.diff(expr, x, y), 0): + return False + except TypeError: + return False + return True def integrate2c(diff_string, dt_var, vars, use_pade_approx=False): diff --git a/src/pybind/pyembed.hpp b/src/pybind/pyembed.hpp index cfd78e73f9..b73697fec4 100644 --- a/src/pybind/pyembed.hpp +++ b/src/pybind/pyembed.hpp @@ -56,6 +56,8 @@ struct SolveNonLinearSystemExecutor: public PythonExecutor { // output // returns a vector of solutions, i.e. new statements to add to block: std::vector solutions; + // returns if the system is linear or not. + bool linear; // may also return a python exception message: std::string exception_message; diff --git a/src/pybind/wrapper.cpp b/src/pybind/wrapper.cpp index 3a97410a70..1fa219060d 100644 --- a/src/pybind/wrapper.cpp +++ b/src/pybind/wrapper.cpp @@ -66,13 +66,14 @@ void SolveNonLinearSystemExecutor::operator()() { from nmodl.ode import solve_non_lin_system exception_message = "" try: - solutions = solve_non_lin_system(equation_strings, + solutions, linear = solve_non_lin_system(equation_strings, state_vars, vars, function_calls) except Exception as e: # if we fail, fail silently and return empty string solutions = [""] + linear = False new_local_vars = [""] exception_message = str(e) )", @@ -80,6 +81,7 @@ void SolveNonLinearSystemExecutor::operator()() { locals); // returns a vector of solutions, i.e. new statements to add to block: solutions = locals["solutions"].cast>(); + linear = locals["linear"].cast(); // may also return a python exception message: exception_message = locals["exception_message"].cast(); } diff --git a/src/visitors/sympy_solver_visitor.cpp b/src/visitors/sympy_solver_visitor.cpp index 58683d601a..66782b24fa 100644 --- a/src/visitors/sympy_solver_visitor.cpp +++ b/src/visitors/sympy_solver_visitor.cpp @@ -356,6 +356,7 @@ void SympySolverVisitor::solve_non_linear_system( (*solver)(); // returns a vector of solutions, i.e. new statements to add to block: auto solutions = solver->solutions; + bool linear = solver->linear; // may also return a python exception message: auto exception_message = solver->exception_message; pywrap::EmbeddedPythonLoader::get_instance().api()->destroy_nsls_executor(solver); @@ -364,8 +365,13 @@ void SympySolverVisitor::solve_non_linear_system( exception_message); return; } - logger->debug("SympySolverVisitor :: Constructing eigen newton solve block"); - construct_eigen_solver_block(pre_solve_statements, solutions, false); + if (!linear) { + logger->debug("SympySolverVisitor :: Constructing eigen newton solve block"); + } + else { + logger->debug("SympySolverVisitor :: Constructing eigen solve block"); + } + construct_eigen_solver_block(pre_solve_statements, solutions, linear); } void SympySolverVisitor::visit_var_name(ast::VarName& node) { diff --git a/test/unit/visitor/sympy_solver.cpp b/test/unit/visitor/sympy_solver.cpp index 5bdfd25478..9b308888de 100644 --- a/test/unit/visitor/sympy_solver.cpp +++ b/test/unit/visitor/sympy_solver.cpp @@ -619,7 +619,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[1]{ + EIGEN_LINEAR_SOLVE[1]{ LOCAL old_m }{ IF (mInf == 1) { @@ -628,7 +628,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", old_m = m }{ nmodl_eigen_x[0] = m - }{ nmodl_eigen_f[0] = (-nmodl_eigen_x[0]*dt+dt*mInf+mTau*(-nmodl_eigen_x[0]+old_m))/mTau nmodl_eigen_j[0] = -(dt+mTau)/mTau }{ @@ -659,7 +658,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL a, b, old_y, old_x }{ old_y = y @@ -667,7 +666,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[1]+a*dt+old_y nmodl_eigen_j[0] = 0 nmodl_eigen_j[2] = -1.0 @@ -703,7 +701,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL a, b, old_M_1, old_M_0 }{ old_M_1 = M[1] @@ -711,7 +709,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = M[0] nmodl_eigen_x[1] = M[1] - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[1]+a*dt+old_M_1 nmodl_eigen_j[0] = 0 nmodl_eigen_j[2] = -1.0 @@ -748,7 +745,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL a, b, old_x, old_y }{ old_x = x @@ -756,7 +753,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]+a*dt+old_x nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[2] = 0 @@ -825,7 +821,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", DERIVATIVE states { LOCAL a, b IF (a == 1) { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL old_x, old_y }{ old_x = x @@ -833,7 +829,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]+a*dt+old_x nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[2] = 0 @@ -875,7 +870,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL a, b, old_x, old_y }{ old_x = x @@ -883,7 +878,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[1]*a*dt+b*dt+old_x nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[2] = a*dt @@ -901,7 +895,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_result_cse = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL a, b, old_x, old_y }{ old_x = x @@ -909,7 +903,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[1]*a*dt+b*dt+old_x nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[2] = a*dt @@ -954,7 +947,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[3]{ + EIGEN_LINEAR_SOLVE[3]{ LOCAL a, b, c, d, h, old_x, old_y, old_z }{ old_x = x @@ -964,7 +957,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y nmodl_eigen_x[2] = z - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[2]*a*dt+b*dt*h+old_x nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[3] = 0 @@ -986,7 +978,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_cse_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[3]{ + EIGEN_LINEAR_SOLVE[3]{ LOCAL a, b, c, d, h, old_x, old_y, old_z }{ old_x = x @@ -996,7 +988,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y nmodl_eigen_x[2] = z - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[2]*a*dt+b*dt*h+old_x nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[3] = 0 @@ -1042,7 +1033,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE scheme1 { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL old_mc, old_m }{ old_mc = mc @@ -1050,7 +1041,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = mc nmodl_eigen_x[1] = m - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc nmodl_eigen_j[0] = -a*dt-1.0 nmodl_eigen_j[2] = b*dt @@ -1086,14 +1076,13 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE scheme1 { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL old_mc }{ old_mc = mc }{ nmodl_eigen_x[0] = mc nmodl_eigen_x[1] = m - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc nmodl_eigen_j[0] = -a*dt-1.0 nmodl_eigen_j[2] = b*dt @@ -1131,7 +1120,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE scheme1 { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL old_mc, old_m }{ old_mc = mc @@ -1139,7 +1128,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = mc nmodl_eigen_x[1] = m - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc nmodl_eigen_j[0] = -a*dt-1.0 nmodl_eigen_j[2] = b*dt @@ -1180,7 +1168,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_result = R"( DERIVATIVE ihkin { - EIGEN_NEWTON_SOLVE[5]{ + EIGEN_LINEAR_SOLVE[5]{ LOCAL alpha, beta, k3p, k4, k1ca, k2, old_c1, old_o1, old_p0 }{ evaluate_fct(v, cai) @@ -1193,7 +1181,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[2] = o2 nmodl_eigen_x[3] = p0 nmodl_eigen_x[4] = p1 - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*alpha*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*beta*dt+old_c1 nmodl_eigen_j[0] = -alpha*dt-1.0 nmodl_eigen_j[5] = beta*dt @@ -1260,13 +1247,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE scheme1 { - EIGEN_NEWTON_SOLVE[1]{ + EIGEN_LINEAR_SOLVE[1]{ LOCAL old_W_0 }{ old_W_0 = W[0] }{ nmodl_eigen_x[0] = W[0] - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]+nmodl_eigen_x[0]*dt*B[0]-nmodl_eigen_x[0]+3.0*dt*A[1]+old_W_0 nmodl_eigen_j[0] = -dt*A[0]+dt*B[0]-1.0 }{ @@ -1300,7 +1286,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE scheme1 { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL old_M_0, old_M_1 }{ old_M_0 = M[0] @@ -1308,7 +1294,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = M[0] nmodl_eigen_x[1] = M[1] - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]-nmodl_eigen_x[0]+nmodl_eigen_x[1]*dt*B[0]+old_M_0 nmodl_eigen_j[0] = -dt*A[0]-1.0 nmodl_eigen_j[2] = dt*B[0] @@ -1346,13 +1331,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE scheme1 { - EIGEN_NEWTON_SOLVE[1]{ + EIGEN_LINEAR_SOLVE[1]{ LOCAL old_W_0 }{ old_W_0 = W[0] }{ nmodl_eigen_x[0] = W[0] - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]+nmodl_eigen_x[0]*dt*B[0]-nmodl_eigen_x[0]+3.0*dt*A[1]+old_W_0 nmodl_eigen_j[0] = -dt*A[0]+dt*B[0]-1.0 }{ @@ -2053,7 +2037,7 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s x } NONLINEAR nonlin { - ~ x = 5 + ~ x * x * x = 5 })"; std::string expected_text = R"( NONLINEAR nonlin { @@ -2062,8 +2046,8 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s }{ nmodl_eigen_x[0] = x }{ - nmodl_eigen_f[0] = 5.0-nmodl_eigen_x[0] - nmodl_eigen_j[0] = -1.0 + nmodl_eigen_f[0] = 5.0-pow(nmodl_eigen_x[0], 3) + nmodl_eigen_j[0] = -3.0 * pow(nmodl_eigen_x[0], 2) }{ x = nmodl_eigen_x[0] }{ @@ -2084,7 +2068,7 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s NONLINEAR nonlin { ~ s[0] = 1 ~ s[1] = 3 - ~ s[2] + s[1] = s[0] + ~ s[2] + s[1] = s[0] * s[0] })"; std::string expected_text = R"( NONLINEAR nonlin { @@ -2097,14 +2081,14 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s }{ nmodl_eigen_f[0] = 1.0-nmodl_eigen_x[0] nmodl_eigen_f[1] = 3.0-nmodl_eigen_x[1] - nmodl_eigen_f[2] = nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2] + nmodl_eigen_f[2] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]-nmodl_eigen_x[2] nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[3] = 0 nmodl_eigen_j[6] = 0 nmodl_eigen_j[1] = 0 nmodl_eigen_j[4] = -1.0 nmodl_eigen_j[7] = 0 - nmodl_eigen_j[2] = 1.0 + nmodl_eigen_j[2] = 2.0 * nmodl_eigen_x[0] nmodl_eigen_j[5] = -1.0 nmodl_eigen_j[8] = -1.0 }{