Skip to content

Commit

Permalink
Merge pull request #65 from anh-tong/fix-tree-map
Browse files Browse the repository at this point in the history
Fix deprecated tree_map in JAX
  • Loading branch information
anh-tong authored Aug 28, 2024
2 parents db14f58 + e15a808 commit 56bad4e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 6 deletions.
6 changes: 1 addition & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,11 @@ repos:
args: ["--fix", "--show-fixes"]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.2.0"
rev: "v1.11.0"
hooks:
- id: mypy
files: src
args: []
additional_dependencies:
- pytest
- jax
- jaxlib

- repo: https://github.com/shellcheck-py/shellcheck-py
rev: "v0.9.0.2"
Expand Down
2 changes: 1 addition & 1 deletion src/signax/signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,6 @@ def multi_signature_combine(signatures: list[Array]) -> list[Array]:
elems=signatures,
)
# return the last index after associative scan
result = jax.tree_map(lambda x: x[-1], result)
result = jax.tree_util.tree_map(lambda x: x[-1], result)

return result

0 comments on commit 56bad4e

Please sign in to comment.