diff --git a/nmodl/ode.py b/nmodl/ode.py index cdbeca0458..368241afb0 100644 --- a/nmodl/ode.py +++ b/nmodl/ode.py @@ -8,6 +8,7 @@ from importlib import import_module import sympy as sp +import itertools # import known_functions through low-level mechanism because the ccode # module is overwritten in sympy and contents of that submodule cannot be @@ -272,6 +273,8 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls): eqs, state_vars, sympy_vars = _sympify_eqs(eq_strings, vars, constants) + linear = _is_linear(eqs, state_vars, sympy_vars) + custom_fcts = _get_custom_functions(function_calls) jacobian = sp.Matrix(eqs).jacobian(state_vars) @@ -291,7 +294,18 @@ def solve_non_lin_system(eq_strings, vars, constants, function_calls): # interweave code = _interweave_eqs(vecFcode, vecJcode) - return code + return code, linear + + +def _is_linear(eqs, state_vars, sympy_vars): + for expr in eqs: + for (x, y) in itertools.combinations_with_replacement(state_vars, 2): + try: + if not sp.Eq(sp.diff(expr, x, y), 0): + return False + except TypeError: + return False + return True def integrate2c(diff_string, dt_var, vars, use_pade_approx=False): diff --git a/src/pybind/pyembed.hpp b/src/pybind/pyembed.hpp index cfd78e73f9..b73697fec4 100644 --- a/src/pybind/pyembed.hpp +++ b/src/pybind/pyembed.hpp @@ -56,6 +56,8 @@ struct SolveNonLinearSystemExecutor: public PythonExecutor { // output // returns a vector of solutions, i.e. new statements to add to block: std::vector solutions; + // returns if the system is linear or not. + bool linear; // may also return a python exception message: std::string exception_message; diff --git a/src/pybind/wrapper.cpp b/src/pybind/wrapper.cpp index 3a97410a70..1fa219060d 100644 --- a/src/pybind/wrapper.cpp +++ b/src/pybind/wrapper.cpp @@ -66,13 +66,14 @@ void SolveNonLinearSystemExecutor::operator()() { from nmodl.ode import solve_non_lin_system exception_message = "" try: - solutions = solve_non_lin_system(equation_strings, + solutions, linear = solve_non_lin_system(equation_strings, state_vars, vars, function_calls) except Exception as e: # if we fail, fail silently and return empty string solutions = [""] + linear = False new_local_vars = [""] exception_message = str(e) )", @@ -80,6 +81,7 @@ void SolveNonLinearSystemExecutor::operator()() { locals); // returns a vector of solutions, i.e. new statements to add to block: solutions = locals["solutions"].cast>(); + linear = locals["linear"].cast(); // may also return a python exception message: exception_message = locals["exception_message"].cast(); } diff --git a/src/visitors/sympy_solver_visitor.cpp b/src/visitors/sympy_solver_visitor.cpp index 58683d601a..66782b24fa 100644 --- a/src/visitors/sympy_solver_visitor.cpp +++ b/src/visitors/sympy_solver_visitor.cpp @@ -356,6 +356,7 @@ void SympySolverVisitor::solve_non_linear_system( (*solver)(); // returns a vector of solutions, i.e. new statements to add to block: auto solutions = solver->solutions; + bool linear = solver->linear; // may also return a python exception message: auto exception_message = solver->exception_message; pywrap::EmbeddedPythonLoader::get_instance().api()->destroy_nsls_executor(solver); @@ -364,8 +365,13 @@ void SympySolverVisitor::solve_non_linear_system( exception_message); return; } - logger->debug("SympySolverVisitor :: Constructing eigen newton solve block"); - construct_eigen_solver_block(pre_solve_statements, solutions, false); + if (!linear) { + logger->debug("SympySolverVisitor :: Constructing eigen newton solve block"); + } + else { + logger->debug("SympySolverVisitor :: Constructing eigen solve block"); + } + construct_eigen_solver_block(pre_solve_statements, solutions, linear); } void SympySolverVisitor::visit_var_name(ast::VarName& node) { diff --git a/test/unit/visitor/sympy_solver.cpp b/test/unit/visitor/sympy_solver.cpp index 5bdfd25478..9b308888de 100644 --- a/test/unit/visitor/sympy_solver.cpp +++ b/test/unit/visitor/sympy_solver.cpp @@ -619,7 +619,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[1]{ + EIGEN_LINEAR_SOLVE[1]{ LOCAL old_m }{ IF (mInf == 1) { @@ -628,7 +628,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", old_m = m }{ nmodl_eigen_x[0] = m - }{ nmodl_eigen_f[0] = (-nmodl_eigen_x[0]*dt+dt*mInf+mTau*(-nmodl_eigen_x[0]+old_m))/mTau nmodl_eigen_j[0] = -(dt+mTau)/mTau }{ @@ -659,7 +658,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL a, b, old_y, old_x }{ old_y = y @@ -667,7 +666,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[1]+a*dt+old_y nmodl_eigen_j[0] = 0 nmodl_eigen_j[2] = -1.0 @@ -703,7 +701,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL a, b, old_M_1, old_M_0 }{ old_M_1 = M[1] @@ -711,7 +709,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = M[0] nmodl_eigen_x[1] = M[1] - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[1]+a*dt+old_M_1 nmodl_eigen_j[0] = 0 nmodl_eigen_j[2] = -1.0 @@ -748,7 +745,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL a, b, old_x, old_y }{ old_x = x @@ -756,7 +753,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]+a*dt+old_x nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[2] = 0 @@ -825,7 +821,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", DERIVATIVE states { LOCAL a, b IF (a == 1) { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL old_x, old_y }{ old_x = x @@ -833,7 +829,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]+a*dt+old_x nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[2] = 0 @@ -875,7 +870,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL a, b, old_x, old_y }{ old_x = x @@ -883,7 +878,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[1]*a*dt+b*dt+old_x nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[2] = a*dt @@ -901,7 +895,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_result_cse = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL a, b, old_x, old_y }{ old_x = x @@ -909,7 +903,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[1]*a*dt+b*dt+old_x nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[2] = a*dt @@ -954,7 +947,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[3]{ + EIGEN_LINEAR_SOLVE[3]{ LOCAL a, b, c, d, h, old_x, old_y, old_z }{ old_x = x @@ -964,7 +957,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y nmodl_eigen_x[2] = z - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[2]*a*dt+b*dt*h+old_x nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[3] = 0 @@ -986,7 +978,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_cse_result = R"( DERIVATIVE states { - EIGEN_NEWTON_SOLVE[3]{ + EIGEN_LINEAR_SOLVE[3]{ LOCAL a, b, c, d, h, old_x, old_y, old_z }{ old_x = x @@ -996,7 +988,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[0] = x nmodl_eigen_x[1] = y nmodl_eigen_x[2] = z - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]+nmodl_eigen_x[2]*a*dt+b*dt*h+old_x nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[3] = 0 @@ -1042,7 +1033,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE scheme1 { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL old_mc, old_m }{ old_mc = mc @@ -1050,7 +1041,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = mc nmodl_eigen_x[1] = m - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc nmodl_eigen_j[0] = -a*dt-1.0 nmodl_eigen_j[2] = b*dt @@ -1086,14 +1076,13 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE scheme1 { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL old_mc }{ old_mc = mc }{ nmodl_eigen_x[0] = mc nmodl_eigen_x[1] = m - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc nmodl_eigen_j[0] = -a*dt-1.0 nmodl_eigen_j[2] = b*dt @@ -1131,7 +1120,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE scheme1 { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL old_mc, old_m }{ old_mc = mc @@ -1139,7 +1128,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = mc nmodl_eigen_x[1] = m - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*a*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*b*dt+old_mc nmodl_eigen_j[0] = -a*dt-1.0 nmodl_eigen_j[2] = b*dt @@ -1180,7 +1168,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", })"; std::string expected_result = R"( DERIVATIVE ihkin { - EIGEN_NEWTON_SOLVE[5]{ + EIGEN_LINEAR_SOLVE[5]{ LOCAL alpha, beta, k3p, k4, k1ca, k2, old_c1, old_o1, old_p0 }{ evaluate_fct(v, cai) @@ -1193,7 +1181,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", nmodl_eigen_x[2] = o2 nmodl_eigen_x[3] = p0 nmodl_eigen_x[4] = p1 - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*alpha*dt-nmodl_eigen_x[0]+nmodl_eigen_x[1]*beta*dt+old_c1 nmodl_eigen_j[0] = -alpha*dt-1.0 nmodl_eigen_j[5] = beta*dt @@ -1260,13 +1247,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE scheme1 { - EIGEN_NEWTON_SOLVE[1]{ + EIGEN_LINEAR_SOLVE[1]{ LOCAL old_W_0 }{ old_W_0 = W[0] }{ nmodl_eigen_x[0] = W[0] - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]+nmodl_eigen_x[0]*dt*B[0]-nmodl_eigen_x[0]+3.0*dt*A[1]+old_W_0 nmodl_eigen_j[0] = -dt*A[0]+dt*B[0]-1.0 }{ @@ -1300,7 +1286,7 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE scheme1 { - EIGEN_NEWTON_SOLVE[2]{ + EIGEN_LINEAR_SOLVE[2]{ LOCAL old_M_0, old_M_1 }{ old_M_0 = M[0] @@ -1308,7 +1294,6 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", }{ nmodl_eigen_x[0] = M[0] nmodl_eigen_x[1] = M[1] - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]-nmodl_eigen_x[0]+nmodl_eigen_x[1]*dt*B[0]+old_M_0 nmodl_eigen_j[0] = -dt*A[0]-1.0 nmodl_eigen_j[2] = dt*B[0] @@ -1346,13 +1331,12 @@ SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor", )"; std::string expected_result = R"( DERIVATIVE scheme1 { - EIGEN_NEWTON_SOLVE[1]{ + EIGEN_LINEAR_SOLVE[1]{ LOCAL old_W_0 }{ old_W_0 = W[0] }{ nmodl_eigen_x[0] = W[0] - }{ nmodl_eigen_f[0] = -nmodl_eigen_x[0]*dt*A[0]+nmodl_eigen_x[0]*dt*B[0]-nmodl_eigen_x[0]+3.0*dt*A[1]+old_W_0 nmodl_eigen_j[0] = -dt*A[0]+dt*B[0]-1.0 }{ @@ -2053,7 +2037,7 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s x } NONLINEAR nonlin { - ~ x = 5 + ~ x * x * x = 5 })"; std::string expected_text = R"( NONLINEAR nonlin { @@ -2062,8 +2046,8 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s }{ nmodl_eigen_x[0] = x }{ - nmodl_eigen_f[0] = 5.0-nmodl_eigen_x[0] - nmodl_eigen_j[0] = -1.0 + nmodl_eigen_f[0] = 5.0-pow(nmodl_eigen_x[0], 3) + nmodl_eigen_j[0] = -3.0 * pow(nmodl_eigen_x[0], 2) }{ x = nmodl_eigen_x[0] }{ @@ -2084,7 +2068,7 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s NONLINEAR nonlin { ~ s[0] = 1 ~ s[1] = 3 - ~ s[2] + s[1] = s[0] + ~ s[2] + s[1] = s[0] * s[0] })"; std::string expected_text = R"( NONLINEAR nonlin { @@ -2097,14 +2081,14 @@ SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][s }{ nmodl_eigen_f[0] = 1.0-nmodl_eigen_x[0] nmodl_eigen_f[1] = 3.0-nmodl_eigen_x[1] - nmodl_eigen_f[2] = nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2] + nmodl_eigen_f[2] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]-nmodl_eigen_x[2] nmodl_eigen_j[0] = -1.0 nmodl_eigen_j[3] = 0 nmodl_eigen_j[6] = 0 nmodl_eigen_j[1] = 0 nmodl_eigen_j[4] = -1.0 nmodl_eigen_j[7] = 0 - nmodl_eigen_j[2] = 1.0 + nmodl_eigen_j[2] = 2.0 * nmodl_eigen_x[0] nmodl_eigen_j[5] = -1.0 nmodl_eigen_j[8] = -1.0 }{