Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] Support brainunit.Quantity type value param in creating Variable #712

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ jobs:
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
pip install jax==0.4.30
pip install jaxlib==0.4.30
# pip install jax==0.4.30
# pip install jaxlib==0.4.30
- name: Test with pytest
run: |
cd brainpy
Expand Down
3 changes: 3 additions & 0 deletions brainpy/_src/dyn/others/tests/test_noise_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import brainpy as bp
import brainpy.math as bm
from absl.testing import parameterized
import pytest

pytest.skip("Skip the test due to the jax 0.5.0 version", allow_module_level=True)


class Test_Noise_Group(parameterized.TestCase):
Expand Down
49 changes: 48 additions & 1 deletion brainpy/_src/math/object_transform/tests/test_variable.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import brainpy.math as bm
import brainunit as u
import jax.numpy as jnp
from functools import partial
import unittest


class TestVar(unittest.TestCase):
def test1(self):
def test_ndarray(self):
class A(bm.BrainPyObject):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -46,6 +49,50 @@ def fff(self):

bm.clear_buffer_memory()

def test_state(self):
class B(bm.BrainPyObject):
def __init__(self):
super().__init__()
self.a = bm.Variable([0.,] * u.mV)
self.f1 = bm.jit(self.f)
self.f2 = bm.jit(self.ff)
self.f3 = bm.jit(self.fff)

def f(self):
ones_fun = partial(u.math.ones,unit=u.mV)
b = self.tracing_variable('b', ones_fun, (1,))
self.a += (b * 2)
return self.a.value

def ff(self):
self.b += 1. * u.mV

def fff(self):
self.f()
self.ff()
self.b *= self.a.value.mantissa
return self.b.value

print()
f_jit = bm.jit(B().f)
f_jit()
self.assertTrue(len(f_jit._dyn_vars) == 2)

print()
b = B()
self.assertTrue(u.math.all(b.f1() == [2.,] * u.mV))
self.assertTrue(len(b.f1._dyn_vars) == 2)
print(b.f2())
self.assertTrue(len(b.f2._dyn_vars) == 1)

print()
b = B()
print()
self.assertTrue(u.math.allclose(b.f3(), 4. * u.mV))
self.assertTrue(len(b.f3._dyn_vars) == 2)

bm.clear_buffer_memory()




16 changes: 13 additions & 3 deletions brainpy/_src/math/object_transform/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from jax.tree_util import register_pytree_node_class

from brainpy._src.math.ndarray import Array
from brainstate import State
from brainunit import Quantity
from brainpy._src.math.sharding import BATCH_AXIS
from brainpy.errors import MathError

Expand Down Expand Up @@ -220,7 +222,7 @@ def __add__(self, other: dict):


@register_pytree_node_class
class Variable(Array):
class Variable(Array, State):
"""The pointer to specify the dynamical variable.

Initializing an instance of ``Variable`` by two ways:
Expand Down Expand Up @@ -250,7 +252,8 @@ def __init__(
batch_axis: int = None,
*,
axis_names: Optional[Sequence[str]] = None,
ready_to_trace: bool = None
ready_to_trace: bool = None,
state_mode: bool = False,
):
if isinstance(value_or_size, int):
value = jnp.zeros(value_or_size, dtype=dtype)
Expand All @@ -259,7 +262,14 @@ def __init__(
else:
value = value_or_size

super().__init__(value, dtype=dtype)
if isinstance(value, Quantity):
state_mode = True

if state_mode:
State.__init__(self, value, dtype=dtype)
self._value = value
else:
Array.__init__(self, value, dtype=dtype)

# check batch axis
if isinstance(value, Variable):
Expand Down
53 changes: 39 additions & 14 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,42 @@
numpy
jax
jaxlib
matplotlib
msgpack
tqdm
pathos
braintaichi
numba
brainstate
braintools
setuptools


# test requirements
pytest
absl-py
absl-py<=2.1.0
brainstate<=0.1.0.post20241210
braintaichi<=0.0.4
braintools<=0.0.4.post20241215
brainunit<=0.0.4
colorama<=0.4.6
contourpy<=1.3.1
cycler<=0.12.1
dill<=0.3.9
fonttools<=4.55.3
iniconfig<=2.0.0
kiwisolver<=1.4.7
llvmlite<=0.43.0
markdown-it-py<=3.0.0
matplotlib<=3.10.0
mdurl<=0.1.2
ml_dtypes<=0.5.0
msgpack<=1.1.0
multiprocess<=0.70.17
numba<=0.60.0
numpy<=2.0.2
opt_einsum<=3.4.0
packaging<=24.2
pathos<=0.3.3
pillow<=11.0.0
pluggy<=1.5.0
pox<=0.3.5
ppft<=1.7.6.9
pygments<=2.18.0
pyparsing<=3.2.0
pytest<=8.3.4
python-dateutil<=2.9.0.post0
rich<=13.9.4
scipy<=1.14.1
setuptools<=75.6.0
six<=1.17.0
taichi<=1.7.2
tqdm<=4.67.1
typing-extensions<=4.12.2
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
author_email='[email protected]',
packages=packages,
python_requires='>=3.9',
install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm'],
install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'brainstate', 'brainunit'],
url='https://github.com/brainpy/BrainPy',
project_urls={
"Bug Tracker": "https://github.com/brainpy/BrainPy/issues",
Expand Down
Loading