Skip to content

Commit

Permalink
Add standalone serialization function
Browse files Browse the repository at this point in the history
This commit adds a standalone serialization function to the symengine
wrapper module. Previously there was a load_basic() function that could
be used for directly deserializing the payload without going through
pickle, but no corresponding function for serialization. A new function
save_basic() is added here to perform this function which will enable
users to directly roundtrip symengine objects through binary
serialization.

Fixes symengine#449
  • Loading branch information
mtreinish committed Sep 5, 2023
1 parent 8d5d63c commit 02e0462
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
13 changes: 8 additions & 5 deletions symengine/lib/symengine_wrapper.in.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -829,11 +829,6 @@ cdef list vec_pair_to_list(symengine.vec_pair& vec):
result.append((c2py(a), c2py(b)))
return result


def load_basic(bytes s):
return c2py(symengine.wrapper_loads(s))


repr_latex=[False]

cdef class Basic(object):
Expand Down Expand Up @@ -1225,6 +1220,14 @@ cdef class Basic(object):
return d


def load_basic(bytes s):
return c2py(symengine.wrapper_loads(s))


cpdef save_basic(Basic basic):
return symengine.wrapper_dumps(basic)


def series(ex, x=None, x0=0, n=6, as_deg_coef_pair=False):
# TODO: check for x0 an infinity, see sympy/core/expr.py
# TODO: nonzero x0
Expand Down
7 changes: 7 additions & 0 deletions symengine/tests/test_pickling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from symengine import symbols, sin, sinh, have_numpy, have_llvm, cos, Symbol
from symengine.lib.symengine_wrapper import load_basic, save_basic
from symengine.test_utilities import raises
import pickle
import unittest
Expand All @@ -11,6 +12,12 @@ def test_basic():
expr2 = pickle.loads(s)
assert expr == expr2

def test_basic_direct():
x, y, z = symbols('x y z')
expr = sin(cos(x + y)/z)**2
s = save_basic(expr)
expr2 = load_basic(s)
assert expr == expr2

class MySymbolBase(Symbol):
def __init__(self, name, attr):
Expand Down

0 comments on commit 02e0462

Please sign in to comment.