-
Notifications
You must be signed in to change notification settings - Fork 95
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[math] Add taichi customized operators (event csrmv, csrmv, jitconn e…
…vent mv, jitconn mv) (#553) * Add _csr_matvec_taichi.py * Test event csr matvec using taichi custom op * Update _csr_matvec_taichi.py * Add sparse csr matvec using taichi customized op * Test event csr matvec using taichi customized op * Implement autograd of event csr matvec using taichi customized op * Update test of `test_event_csrmv_taichi.py` * Update _csr_mv_taichi.py * Test sparse csr matvec using taichi customized op * Update test_csrmv_taichi.py * Remove test event and sparse csrmv using taichi from pytest * Fix autograd bug and update test_csrmv_taichi.py * Fix autograd bug and update `test_event_csr_matvec_taichi.py` * Fix event csr matvec kernel bug * Fix test bugs * Add taichi.func random generators * Update `test_taichi_random.py` * Implement `mv_prob_homo_taichi` and `mv_prob_uniform_taichi` * Implement jitconn matvec using taichi customized op` and Need to test * Fix bugs in * Remove pytest in 'test_taichi_random.py' * Implement jitconn event matvec using taichi customized op and Need to test * Implement lfsr88 random generator algorithm * Refactor `jitconn/_matvec_taichi.py` with lfsr88 random generator * [csrmv taichi] format codes and redefine JVP rules using `.defjvp` interface * [csrmv taichi] format codes of `brainpy.math.sparse.csrmv` and redefine JVP rules using `.defjvp` interface * [math] depress taichi import logging by forcing using `import_taichi()` utility, move taichi random functions into another file * fix missing file * Optimize event csr matvec with taichi customized op and Add taichi event csr matvec benchmark * Update event_csrmv_taichi_VS_event_csrmv.py * Optimize csr matvec with taichi customized op and Add taichi csr matvec benchmark * Fix bugs * Add more benchmarks * Update benchmarks * Optimized taichi event csr matvec gpu * Update benchmarks * Update benchmarks * Update benchmarks * Update benchmarks * Optimized taichi customized cpu kernels about event csr matvec and csr matvec * Add taichi jitconn matvec benchmark and Optimize taichi jitconn matvec op * Refactor taichi event matvec op * Add taichi jitconn event matvec benchmark * Optimize taichi jitconn matvec op on CPU backend * Update taichi jitconn matvec op * Update test files for taichi jitconn op * Update taichi random generator * Fix bugs * Add new function for taichi random seeds initialization * Update taichi_random_time_test.py * Update taichi_random_time_test.py * Update taichi_random_time_test.py * Fix bugs * Remove taichi_random_time_test.py * Update test_taichi_random.py * [event csr taichi] small upgrade * [csr mv taichi] fix bugs * [math] new module `brainpy.math.tifunc` for taichi functionality * [math] move default environment setting into `defaults.py` * [math] fix and update taichi jitconn operators --------- Co-authored-by: chaoming <[email protected]>
- Loading branch information
1 parent
9662fbb
commit 6368289
Showing
34 changed files
with
8,732 additions
and
410 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import jax.numpy as jnp | ||
from jax import config | ||
|
||
from brainpy._src.dependency_check import import_taichi | ||
from .modes import NonBatchingMode | ||
from .scales import IdScaling | ||
|
||
__all__ = ['mode', 'membrane_scaling', 'dt', 'bool_', 'int_', 'ti_int', 'float_', 'ti_float', 'complex_'] | ||
|
||
ti = import_taichi() | ||
|
||
# Default computation mode. | ||
mode = NonBatchingMode() | ||
|
||
# '''Default computation mode.''' | ||
membrane_scaling = IdScaling() | ||
|
||
# '''Default time step.''' | ||
dt = 0.1 | ||
|
||
# '''Default bool data type.''' | ||
bool_ = jnp.bool_ | ||
|
||
# '''Default integer data type.''' | ||
int_ = jnp.int64 if config.read('jax_enable_x64') else jnp.int32 | ||
|
||
# '''Default integer data type in Taichi.''' | ||
ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 | ||
|
||
# '''Default float data type.''' | ||
float_ = jnp.float64 if config.read('jax_enable_x64') else jnp.float32 | ||
|
||
# '''Default float data type in Taichi.''' | ||
ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32 | ||
|
||
# '''Default complex data type.''' | ||
complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64 | ||
|
||
# redirects | ||
redirects = {'mode': mode, | ||
'membrane_scaling': membrane_scaling, | ||
'dt': dt, | ||
'bool_': bool_, | ||
'int_': int_, | ||
'ti_int': ti_int, | ||
'float_': float_, | ||
'ti_float': ti_float, | ||
'complex_': complex_} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
|
||
from ._info_collection import * | ||
from ._csr_matvec import * | ||
from ._csr_matvec_taichi import * | ||
|
Oops, something went wrong.