Skip to content

Commit

Permalink
Constify read only buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavidberger committed Mar 15, 2023
1 parent 0395d88 commit e0e7b06
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions cnkalman/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def arg_str_py(arg):
t = get_type(a)
t = t.replace("FLT*", "floating[:]")
t = t.replace("FLT", "floating")
return "%s %s" % (t, get_name(a))
return "const %s %s" % (t, get_name(a))

def generate_args_string(args, as_call = False):
return ", ".join(map(lambda x: get_name(x[1]) if as_call else arg_str, enumerate(args)))
Expand All @@ -484,7 +484,7 @@ def emit_header(*args, **kwargs):
if file is not None:
print(*args, **kwargs, file=file[1])
file[1].flush()

print(f"\tGenerating {name}...", file=sys.stderr)
flatten, args = flatten_func(func, name, args, suffix, argument_specs)
if flatten is None:
return None
Expand Down Expand Up @@ -524,12 +524,16 @@ def update_free_symbols(v):
values = [flatten[k] for k in keys]
keys = [k[1] for k in keys]
values = [ a.symengine_type() if hasattr(a, 'symengine_type') else a for a in values ]
print(f"\tCSE for {len(values)} dict values", file=sys.stderr)
cse_output = cse(symengine.Matrix(values))
update_free_symbols(values)

else:
print(f"\tCSE for {len(flatten)} values", file=sys.stderr)
cse_output = cse(symengine.Matrix(flatten))
update_free_symbols(flatten)

print(f"\t{len(free_symbols)} free symbols...", file=sys.stderr)
func_name = name
emit_header("cdef void %s%s_nogil(%s, %s) nogil " % (prefix,
name,
Expand All @@ -552,18 +556,18 @@ def update_free_symbols(v):
name = get_name(a)
for k, v in flatten_args(a()):
if f"{name}{k.strip('[]')}" in free_symbols:
emit_code("\tcdef float %s = %s%s" % (str(v), "("+name+")" if isinstance_namedtuple(a()) else name, k))
emit_code("\tcdef floating %s = %s%s" % (str(v), "("+name+")" if isinstance_namedtuple(a()) else name, k))
elif isinstance(a, WrapTuple):
name = get_name(a)
digits = math.floor(math.log(len(a.t)))
for k, v in flatten_args(a.t):
idx = k.strip('[]')
if f"{name}{str(idx).zfill(digits)}" in free_symbols:
emit_code("\tcdef float %s = %s%s" % (str(v), name, k))
emit_code("\tcdef floating %s = %s%s" % (str(v), name, k))

for item in cse_output[0]:
stripped_line = pyxcode(item[1]).replace("\n", " ").replace("\t", " ")
emit_code(f"\tcdef float {symengine.ccode(item[0])} = {stripped_line}")
emit_code(f"\tcdef floating {symengine.ccode(item[0])} = {stripped_line}")

output_idx = 0
outputs_idx = 0
Expand Down Expand Up @@ -603,7 +607,7 @@ def get_row_str():
return str(current_row)
if hasattr(item, "tolist"):
for item1 in sum(item.tolist(), []):
emit_code("\t%s[%s,%s] = %s" % (outputs[outputs_idx][0], get_row_str(), get_col_str(), output_idx, pyxode(item1).replace("\n", " ").replace("\t", " ")))
emit_code("\t%s[%s,%s] = %s" % (outputs[outputs_idx][0], get_row_str(), get_col_str(), output_idx, pyxcode(item1).replace("\n", " ").replace("\t", " ")))
output_idx += 1
current_row = output_idx / current_shape[1]
current_col = output_idx % current_shape[1]
Expand All @@ -621,6 +625,7 @@ def get_row_str():

emit_code("")

release_gil = len(cse_output[1]) > 200
emit_code("cpdef void %s%s(%s, %s): " % (prefix,
func_name,
", ".join([f"{type_name} " + s[0] for s in outputs]),
Expand All @@ -643,7 +648,7 @@ def get_row_str():
emit_code(f"\tcdef np.ndarray[double, ndim=2] {outputs[0][0]} = np.zeros(({abs(current_shape[0])},{current_shape[1]}), dtype=np.float64)")
emit_code(f"\tcdef floating[:,:] _{outputs[0][0]} = {outputs[0][0]}")
call_args = ",".join([f"_{s[0]}" for s in outputs]) + ", " + generate_args_string(args, as_call=True)
emit_code(f"\t{prefix}{func_name}_nogil({call_args})")
emit_code(f"\t{'with nogil: ' if release_gil else ''}{prefix}{func_name}_nogil({call_args})")
if current_shape[1] == 1:
emit_code(f"\treturn {', '.join([s[0] for s in outputs])}.reshape(-1)")
else:
Expand All @@ -658,7 +663,7 @@ def get_vec_arg(a):
t = t.replace("FLT**", "floating[:,:]")
t = t.replace("FLT*", "floating[:]")
t = t.replace("FLT", "floating")
return "%s %s" % (t, name)
return "const %s %s" % (t, name)

output_shape = f"{vectorize_over}.shape[0]", abs(current_shape[0]), current_shape[1]
if output_shape[-1] == 1:
Expand Down Expand Up @@ -894,9 +899,9 @@ def emit_code(*args, **kwargs):

jac_shape = this_jac.shape
jac_size = this_jac.shape[0] * this_jac.shape[1]

if jac_size == 1:
continue
#
# if jac_size == 1:
# continue

#emit_code("// Jacobian of", func.__qualname__.replace(".", "_"), "wrt", jac_value)
codegen(this_jac, fname, func_args, argument_specs=argument_specs, suffix=suffix, outputs=[('Hx', jac_shape, jac_value)], input_keys=keys,file=file, prefix=prefix)
Expand Down Expand Up @@ -982,9 +987,9 @@ def get_pyx_file(fn):

full_path = f'{path.parent.as_posix()}/{new_stem}.pyx'
if os.path.exists(full_path):
#os.path.getmtime(full_path) > os.path.getmtime(__file__) and \
if not force_generate and \
os.path.getmtime(full_path) > os.path.getmtime(fn) and \
os.path.getmtime(full_path) > os.path.getmtime(__file__) and \
os.path.getmtime(f'{path.parent.as_posix()}/{new_stem}.pxd') > os.path.getmtime(fn):
return None

Expand Down Expand Up @@ -1038,7 +1043,7 @@ def get_file(fn):

import numpy as np
def functionify(args_info, jac):
def f(*args):
def f(*args, force_resolve=True):
subset = {}
for i, info in enumerate(args_info):
if isinstance(info, WrapTuple):
Expand All @@ -1047,8 +1052,10 @@ def f(*args):
subset[s.__str__()] = v
else:
subset[info.__str__()] = args[i]

return np.array(jac.subs(subset)).astype(np.float64)
rtn = np.array(jac.subs(subset))
if force_resolve:
return rtn.astype(np.float64)
return rtn
return f

def expand_hint(v, length):
Expand Down Expand Up @@ -1098,11 +1105,12 @@ def eval(self, reeval=False):
f = self.f

global use_symbolic_eval
old_use_symbolic_eval = use_symbolic_eval
use_symbolic_eval = True
is_pyx = isinstance(f, tuple)
jacs, args = generate_code_and_jacobians(func, argument_specs=self.kwargs, jac_all=self.kwargs.get('_all', False),
file=f, prefix=self.prefix, codegen = generate_pyxcode if is_pyx else generate_ccode)
use_symbolic_eval = False
use_symbolic_eval = old_use_symbolic_eval
self._jacs = jacs
self._args = args

Expand Down

0 comments on commit e0e7b06

Please sign in to comment.