Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Allow vectors to be used with .data_set() #574

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

ntolley
Copy link
Contributor

@ntolley ntolley commented Jan 29, 2025

Closes #573. Only change is swapping the dimensions when squeezing the params used for the .set()
@michaeldeistler @jnsbck is there anything I'm missing? Hopefully this isn't breaking something somewhere else!

@ntolley
Copy link
Contributor Author

ntolley commented Jan 29, 2025

Also while working on this I realized the the regular net.set() function has some peculiar behavior as well. For instance the following code passes without any error when running the tests:

net = SimpleNet(2, 2, 4)
net.set("radius", np.repeat(1.0, 16))  # no error
net.cell(range(2)).set("radius", np.repeat(1.0, 16))  # no error, which is very weird!
net.cell(range(2)).set("radius", np.repeat(1.0, 2))  # error, which should pass

I've included tests for this which is why the PR is failing, they should pass if I limit the tests to just the data_set() functionality.

I may need some guidance on where to look to proceed. I imagine this is possible since net.cell(range(2)).make_trainable(...) has no issues with views

@michaeldeistler
Copy link
Contributor

michaeldeistler commented Jan 29, 2025

Hi Nick! Thanks for the PR!

I think set behaves as I would have expected. Currently, set must either get a scalar value or an array with as many entries as there are compartments being set.

net = SimpleNet(2, 2, 4)

creates a net of two neurons with 2 branches each, and each branch having 4 compartments. This makes a total of 16 compartments. Then, net.cell(range(2)) selects both cells => 16 compartments. This is why the first line passes.

I agree that it would be nice to support the second line. It's definitely possible, but it would require extra logic for the set method to understand that it should assign the two values to all compartments in the two cells.

That being said: since we currently do not support this yet, please remove the test.

I agree that there is something off with data_set (your example in the issue should be working).

@ntolley
Copy link
Contributor Author

ntolley commented Jan 29, 2025

Will do! Also I spoke to soon, this does indeed break a few things, I'll keep at it...

@michaeldeistler
Copy link
Contributor

michaeldeistler commented Jan 29, 2025

Yes, I also just found a small example that breaks after the change:

import jaxley as jx
import numpy as np
from jaxley.channels import Na

comp = jx.Compartment()
branch = jx.Branch(comp, ncomp=4)
cell1 = jx.Cell(branch, parents=[-1])
cell2 = jx.Cell(branch, parents=[-1, 0, 0])
net = jx.Network([cell1, cell2])

net.cell("all").make_trainable("radius")
params = net.get_parameters()
net.record('v')
s = jx.integrate(net, t_max=10, params=params)

I think the best place to fix it would be here, but it requires quite a bit of logic: If the number of compartments matches the number of values, then inds should be made into shape (num_inds, 1) here, not into shape (1, num_inds) (which happens with the np.atleast2d)

See this comment which explains how inds should be shaped.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: .data_set() function does not work with vectors
2 participants