Skip to content

Commit

Permalink
Support generating vectorized output
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavidberger committed Oct 30, 2022
1 parent 6ff7c2f commit d75677a
Showing 1 changed file with 41 additions and 8 deletions.
49 changes: 41 additions & 8 deletions cnkalman/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,8 +467,8 @@ def arg_str(arg):
def arg_str_py(arg):
    """Format one argument as a Cython parameter declaration: "<type> <name>".

    ``arg`` is an ``(index, argument)`` pair such as ``enumerate(args)``
    yields; only the argument itself (``arg[1]``) is used.  The type string
    returned by ``get_type`` has its ``FLT`` spellings rewritten to the Cython
    fused type ``floating`` so the generated wrapper is not tied to a single
    precision:

    * ``FLT*`` -> ``floating[:]``  (1-D typed memoryview)
    * ``FLT``  -> ``floating``     (scalar)

    Returns the declaration as a string.
    """
    a = arg[1]
    t = get_type(a)
    # The pointer spelling must be rewritten first; replacing the bare "FLT"
    # first would turn "FLT*" into "floating*" instead of a memoryview.
    # Only the fused-type mapping is applied: an earlier np.float32_t mapping
    # would consume every "FLT" and leave these replacements with nothing to
    # match.
    t = t.replace("FLT*", "floating[:]")
    t = t.replace("FLT", "floating")
    return "%s %s" % (t, get_name(a))

def generate_args_string(args, as_call = False):
Expand Down Expand Up @@ -514,7 +514,7 @@ def update_free_symbols(v):
for v1 in v:
update_free_symbols(v1)

type_name = "np.float32_t[:,:]"
type_name = "floating[:,:]"
type = None
if isinstance(flatten, dict):
type_name = flatten["$original"].__class__.__name__
Expand Down Expand Up @@ -640,15 +640,45 @@ def get_row_str():
func_name,
", ".join(map(arg_str_py, enumerate(args)))
))
emit_code(
f"\tcdef np.ndarray[float, ndim=2] {outputs[0][0]} = np.zeros(({abs(current_shape[0])},{current_shape[1]}), dtype=np.float32)")
call_args = ",".join([s[0] for s in outputs]) + ", " + generate_args_string(args, as_call=True)
emit_code(f"\tcdef np.ndarray[double, ndim=2] {outputs[0][0]} = np.zeros(({abs(current_shape[0])},{current_shape[1]}), dtype=np.float64)")
emit_code(f"\tcdef floating[:,:] _{outputs[0][0]} = {outputs[0][0]}")
call_args = ",".join([f"_{s[0]}" for s in outputs]) + ", " + generate_args_string(args, as_call=True)
emit_code(f"\t{prefix}{func_name}_nogil({call_args})")
if current_shape[1] == 1:
emit_code(f"\treturn {', '.join([s[0] for s in outputs])}.reshape(-1)")
else:
emit_code(f"\treturn {', '.join([s[0] for s in outputs])}")

for vectorize_over in argument_specs.get('_vectorize', []):
def get_vec_arg(a):
    """Return the "<type> <name>" declaration of ``a`` for the vectorized wrapper.

    The argument whose name equals ``vectorize_over`` (taken from the
    enclosing loop) gets one extra level of indirection, so a scalar
    ``FLT`` becomes a 1-D memoryview and an ``FLT*`` becomes a 2-D one.
    Every other argument keeps its normal mapping.
    """
    c_type = get_type(a)
    arg_name = get_name(a)
    if arg_name == vectorize_over:
        # Promote the vectorized argument: FLT -> FLT*, FLT* -> FLT**.
        c_type += "*"
    # Longest spelling first, so "FLT**" is not half-consumed by the
    # shorter "FLT*" / "FLT" patterns.
    rewrites = (
        ("FLT**", "floating[:,:]"),
        ("FLT*", "floating[:]"),
        ("FLT", "floating"),
    )
    for c_spelling, cython_spelling in rewrites:
        c_type = c_type.replace(c_spelling, cython_spelling)
    return "%s %s" % (c_type, arg_name)

output_shape = f"{vectorize_over}.shape[0]", abs(current_shape[0]), current_shape[1]
if output_shape[-1] == 1:
output_shape = output_shape[:-1]

dtype = "np.float32 if floating is float else np.float64"
vec_args = ', '.join(map(get_vec_arg, args))
call_args = ",".join([f"_{s[0]}[idx,:,:]" for s in outputs] + [f"{get_name(a)}[idx]" if vectorize_over == get_name(a) else get_name(a) for a in args])
emit_header(f"cpdef np.ndarray {prefix}{func_name}_vectorize_{vectorize_over}({vec_args}) ")
emit_code(f"""
cpdef np.ndarray {prefix}{func_name}_vectorize_{vectorize_over}({vec_args}):
cdef np.ndarray[floating, ndim=3] {outputs[0][0]} = np.zeros(({vectorize_over}.shape[0],{abs(current_shape[0])},{current_shape[1]}), dtype={dtype})
cdef floating[:,:,:] _{outputs[0][0]} = {outputs[0][0]}
cdef int idx
with nogil:
for idx in range({vectorize_over}.shape[0]):
{prefix}{func_name}_nogil({call_args})
return {outputs[0][0]}.reshape({', '.join(map(str,output_shape))})
""".replace(" ", "\t"))

emit_header("cpdef void %s%s(%s, %s)" % (prefix,
func_name,
", ".join([f"{type_name} " + s[0] for s in outputs]),
Expand Down Expand Up @@ -869,7 +899,7 @@ def emit_code(*args, **kwargs):
continue

#emit_code("// Jacobian of", func.__qualname__.replace(".", "_"), "wrt", jac_value)
codegen(this_jac, fname, func_args, suffix=suffix, outputs=[('Hx', jac_shape, jac_value)], input_keys=keys,file=file, prefix=prefix)
codegen(this_jac, fname, func_args, argument_specs=argument_specs, suffix=suffix, outputs=[('Hx', jac_shape, jac_value)], input_keys=keys,file=file, prefix=prefix)

#jac_with_hx = this_jac.reshape(jac_size, 1).col_join(fxm.reshape(fx_size, 1))

Expand Down Expand Up @@ -954,6 +984,7 @@ def get_pyx_file(fn):
if os.path.exists(full_path):
if not force_generate and \
os.path.getmtime(full_path) > os.path.getmtime(fn) and \
os.path.getmtime(full_path) > os.path.getmtime(__file__) and \
os.path.getmtime(f'{path.parent.as_posix()}/{new_stem}.pxd') > os.path.getmtime(fn):
return None

Expand All @@ -970,7 +1001,7 @@ def get_pyx_file(fn):
from libc.math cimport *
from libc.stdint cimport *
from libc cimport *
from cython cimport floating
""")

g = open(f"{path.parent.as_posix()}/{new_stem}.pxd", 'w')
Expand All @@ -980,6 +1011,8 @@ def get_pyx_file(fn):
import numpy as np
cimport numpy as np
from cython cimport floating
""")
generate_code_files[fn] = (f, g)
return generate_code_files[fn]
Expand Down

0 comments on commit d75677a

Please sign in to comment.