Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lagrange [WIP] #31

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open

Lagrange [WIP] #31

wants to merge 3 commits into from

Conversation

NicholasCowie
Copy link
Contributor

Summary

Steady-state solving is difficult and errors are problematic. This approach is not necessarily better than the method posed in diffrax but it has the potential to fail in a detectable way; this may be desirable for mcmc. Preliminary tests suggest that is may be faster than standard ODE solving, up to 3 times faster for the methionine cycle “BAD_GUESS” test case, and it achieves steady state.

Checklist:

  • tests pass
  • README.md up to date
  • docs up to date
  • link to any relevant issues

Comment on lines +25 to +26
S = model.structure.S
dG = (S.T @ model.parameters.dgf + 2.4788191*[email protected](conc))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can these lines be deleted?

Comment on lines 77 to 79
lagrangian,
solver,
jnp.concat([jnp.log(guess), lambda_guess]),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is more readable and can avoid annoying copy/paste bugs later:

Suggested change
lagrangian,
solver,
jnp.concat([jnp.log(guess), lambda_guess]),
fn=lagrangian,
solver=solver,
y0=jnp.concat([jnp.log(guess), lambda_guess]),

def lagrangian(
z: Float[Array, " n_balanced*2"],
model: RateEquationModel,
) -> Float[Array, " n_balanced*2"]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a docstring for this function?

) -> Float:
S = model.structure.S
dG = (S.T @ model.parameters.dgf + 2.4788191*[email protected](conc))
return sum(jnp.square(model.dcdt(0, x)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe either jnp.square(model.dcdt(0, x)).sum() or jnp.sum(jnp.square(model.dcdt(0, x))) would be better here - I'm not sure exactly what happens when you use the standard sum

Copy link
Contributor

@teddygroves teddygroves left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, I think there are just a few lint things to fix before merging.

I think we should also add unit tests for steady state functions but that can wait

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants