Skip to content

Commit

Permalink
Update Jax and Haiku versions
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 474776295
Change-Id: I8f5add84f45069f2c48ac905801a611cb298b319
  • Loading branch information
Htomlinson14 authored and copybara-github committed Sep 16, 2022
1 parent 2eb6c0c commit 5e702ce
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion alphafold/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def final_init(config):

def batched_gather(params, indices, axis=0, batch_dims=0):
"""Implements a JAX equivalent of `tf.gather` with `axis` and `batch_dims`."""
take_fn = lambda p, i: jnp.take(p, i, axis=axis)
take_fn = lambda p, i: jnp.take(p, i, axis=axis, mode='clip')
for _ in range(batch_dims):
take_fn = jax.vmap(take_fn)
return take_fn(params, indices)
Expand Down
4 changes: 2 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ RUN wget -q -P /app/alphafold/alphafold/common/ \
RUN pip3 install --upgrade pip \
&& pip3 install -r /app/alphafold/requirements.txt \
&& pip3 install --upgrade \
jax==0.2.14 \
jaxlib==0.1.69+cuda$(cut -f1,2 -d. <<< ${CUDA} | sed 's/\.//g') \
jax==0.3.17 \
jaxlib==0.3.15+cuda11.cudnn805 \
-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Apply OpenMM patch.
Expand Down
2 changes: 1 addition & 1 deletion docker/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Dependencies necessary to execute run_docker.py
absl-py==0.13.0
absl-py==1.0.0
docker==5.0.0
10 changes: 5 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
absl-py==0.13.0
absl-py==1.0.0
biopython==1.79
chex==0.0.7
dm-haiku==0.0.4
dm-haiku==0.0.7
dm-tree==0.1.6
docker==5.0.0
immutabledict==2.0.0
jax==0.2.14
jax==0.3.17
ml-collections==0.1.0
numpy==1.19.5
numpy==1.21.6
pandas==1.3.4
protobuf==3.20.1
scipy==1.7.0
tensorflow-cpu==2.5.0
tensorflow-cpu==2.9.0

0 comments on commit 5e702ce

Please sign in to comment.