Add ability to only run relax for the best unrelaxed model.
PiperOrigin-RevId: 501851892
Change-Id: I7484c2aa7ac30af611d88cfa8d632096c262824a
Htomlinson14 authored and copybara-github committed Jan 13, 2023
1 parent 0d9a24b commit 684ffa1
Showing 3 changed files with 83 additions and 56 deletions.
17 changes: 10 additions & 7 deletions docker/run_docker.py
@@ -28,12 +28,15 @@

flags.DEFINE_bool(
'use_gpu', True, 'Enable NVIDIA runtime to run with GPUs.')
flags.DEFINE_boolean(
'run_relax', True,
'Whether to run the final relaxation step on the predicted models. Turning '
'relax off might result in predictions with distracting stereochemical '
'violations but might help in case you are having issues with the '
'relaxation stage.')
flags.DEFINE_enum('models_to_relax', 'best', ['best', 'all', 'none'],
'The models to run the final relaxation step on. '
'If `all`, all models are relaxed, which may be time '
'consuming. If `best`, only the most confident model is '
'relaxed. If `none`, relaxation is not run. Turning off '
'relaxation might result in predictions with '
'distracting stereochemical violations but might help '
'in case you are having issues with the relaxation '
'stage.')
flags.DEFINE_bool(
'enable_gpu_relax', True, 'Run relax on GPU if GPU is enabled.')
flags.DEFINE_string(
@@ -221,7 +224,7 @@ def main(argv):
f'--benchmark={FLAGS.benchmark}',
f'--use_precomputed_msas={FLAGS.use_precomputed_msas}',
f'--num_multimer_predictions_per_model={FLAGS.num_multimer_predictions_per_model}',
f'--run_relax={FLAGS.run_relax}',
f'--models_to_relax={FLAGS.models_to_relax}',
f'--use_gpu_relax={use_gpu_relax}',
'--logtostderr',
])
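With this launcher change, relaxation becomes a three-way choice rather than a boolean. A hypothetical invocation (the paths and date are placeholders; --fasta_paths, --data_dir and --max_template_date are the launcher's existing flags):

# Relax only the top-ranked prediction (the new default); 'all' and 'none' are also accepted.
python3 docker/run_docker.py \
  --fasta_paths=/path/to/target.fasta \
  --data_dir=/path/to/alphafold_databases \
  --max_template_date=2022-01-01 \
  --models_to_relax=best

The old --run_relax boolean no longer exists, so callers that passed it should switch to --models_to_relax.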
108 changes: 65 additions & 43 deletions run_alphafold.py
@@ -13,6 +13,7 @@
# limitations under the License.

"""Full AlphaFold protein structure prediction script."""
import enum
import json
import os
import pathlib
@@ -43,6 +44,13 @@

logging.set_verbosity(logging.INFO)


@enum.unique
class ModelsToRelax(enum.Enum):
ALL = 0
BEST = 1
NONE = 2

flags.DEFINE_list(
'fasta_paths', None, 'Paths to FASTA files, each containing a prediction '
'target that will be folded one after another. If a FASTA file contains '
Expand Down Expand Up @@ -119,11 +127,15 @@
'runs that are to reuse the MSAs. WARNING: This will not '
'check if the sequence, database or configuration have '
'changed.')
flags.DEFINE_boolean('run_relax', True, 'Whether to run the final relaxation '
'step on the predicted models. Turning relax off might '
'result in predictions with distracting stereochemical '
'violations but might help in case you are having issues '
'with the relaxation stage.')
flags.DEFINE_enum_class('models_to_relax', ModelsToRelax.BEST, ModelsToRelax,
'The models to run the final relaxation step on. '
'If `all`, all models are relaxed, which may be time '
'consuming. If `best`, only the most confident model '
'is relaxed. If `none`, relaxation is not run. Turning '
'off relaxation might result in predictions with '
'distracting stereochemical violations but might help '
'in case you are having issues with the relaxation '
'stage.')
flags.DEFINE_boolean('use_gpu_relax', None, 'Whether to relax on GPU. '
'Relax on GPU can be much faster than CPU, so it is '
'recommended to enable if possible. GPUs must be available'
@@ -156,7 +168,8 @@ def predict_structure(
model_runners: Dict[str, model.RunModel],
amber_relaxer: relax.AmberRelaxation,
benchmark: bool,
random_seed: int):
random_seed: int,
models_to_relax: ModelsToRelax):
"""Predicts structure using AlphaFold for the given sequence."""
logging.info('Predicting %s', fasta_name)
timings = {}
@@ -180,6 +193,7 @@ def predict_structure(
pickle.dump(feature_dict, f, protocol=4)

unrelaxed_pdbs = {}
unrelaxed_proteins = {}
relaxed_pdbs = {}
relax_metrics = {}
ranking_confidences = {}
@@ -232,38 +246,48 @@ def predict_structure(
b_factors=plddt_b_factors,
remove_leading_feature_dimension=not model_runner.multimer_mode)

unrelaxed_proteins[model_name] = unrelaxed_protein
unrelaxed_pdbs[model_name] = protein.to_pdb(unrelaxed_protein)
unrelaxed_pdb_path = os.path.join(output_dir, f'unrelaxed_{model_name}.pdb')
with open(unrelaxed_pdb_path, 'w') as f:
f.write(unrelaxed_pdbs[model_name])

if amber_relaxer:
# Relax the prediction.
t_0 = time.time()
relaxed_pdb_str, _, violations = amber_relaxer.process(
prot=unrelaxed_protein)
relax_metrics[model_name] = {
'remaining_violations': violations,
'remaining_violations_count': sum(violations)
}
timings[f'relax_{model_name}'] = time.time() - t_0

relaxed_pdbs[model_name] = relaxed_pdb_str

# Save the relaxed PDB.
relaxed_output_path = os.path.join(
output_dir, f'relaxed_{model_name}.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)

# Rank by model confidence and write out relaxed PDBs in rank order.
ranked_order = []
for idx, (model_name, _) in enumerate(
sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)):
ranked_order.append(model_name)
# Rank by model confidence.
ranked_order = [
model_name for model_name, confidence in
sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)]

# Relax predictions.
if models_to_relax == ModelsToRelax.BEST:
to_relax = [ranked_order[0]]
elif models_to_relax == ModelsToRelax.ALL:
to_relax = ranked_order
elif models_to_relax == ModelsToRelax.NONE:
to_relax = []

for model_name in to_relax:
t_0 = time.time()
relaxed_pdb_str, _, violations = amber_relaxer.process(
prot=unrelaxed_proteins[model_name])
relax_metrics[model_name] = {
'remaining_violations': violations,
'remaining_violations_count': sum(violations)
}
timings[f'relax_{model_name}'] = time.time() - t_0

relaxed_pdbs[model_name] = relaxed_pdb_str

# Save the relaxed PDB.
relaxed_output_path = os.path.join(
output_dir, f'relaxed_{model_name}.pdb')
with open(relaxed_output_path, 'w') as f:
f.write(relaxed_pdb_str)

# Write out relaxed PDBs in rank order.
for idx, model_name in enumerate(ranked_order):
ranked_output_path = os.path.join(output_dir, f'ranked_{idx}.pdb')
with open(ranked_output_path, 'w') as f:
if amber_relaxer:
if model_name in relaxed_pdbs:
f.write(relaxed_pdbs[model_name])
else:
f.write(unrelaxed_pdbs[model_name])
@@ -279,7 +303,7 @@ def predict_structure(
timings_output_path = os.path.join(output_dir, 'timings.json')
with open(timings_output_path, 'w') as f:
f.write(json.dumps(timings, indent=4))
if amber_relaxer:
if models_to_relax != ModelsToRelax.NONE:
relax_metrics_path = os.path.join(output_dir, 'relax_metrics.json')
with open(relax_metrics_path, 'w') as f:
f.write(json.dumps(relax_metrics, indent=4))
@@ -386,16 +410,13 @@ def main(argv):
logging.info('Have %d models: %s', len(model_runners),
list(model_runners.keys()))

if FLAGS.run_relax:
amber_relaxer = relax.AmberRelaxation(
max_iterations=RELAX_MAX_ITERATIONS,
tolerance=RELAX_ENERGY_TOLERANCE,
stiffness=RELAX_STIFFNESS,
exclude_residues=RELAX_EXCLUDE_RESIDUES,
max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
use_gpu=FLAGS.use_gpu_relax)
else:
amber_relaxer = None
amber_relaxer = relax.AmberRelaxation(
max_iterations=RELAX_MAX_ITERATIONS,
tolerance=RELAX_ENERGY_TOLERANCE,
stiffness=RELAX_STIFFNESS,
exclude_residues=RELAX_EXCLUDE_RESIDUES,
max_outer_iterations=RELAX_MAX_OUTER_ITERATIONS,
use_gpu=FLAGS.use_gpu_relax)

random_seed = FLAGS.random_seed
if random_seed is None:
@@ -413,7 +434,8 @@ def main(argv):
model_runners=model_runners,
amber_relaxer=amber_relaxer,
benchmark=FLAGS.benchmark,
random_seed=random_seed)
random_seed=random_seed,
models_to_relax=FLAGS.models_to_relax)


if __name__ == '__main__':
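The substantive change in run_alphafold.py is that ranking and relaxation are now decoupled: every prediction is ranked by confidence first, and the ModelsToRelax value then selects which ranked models are sent through the Amber relaxer. Below is a self-contained sketch of just that selection step; the helper name and the confidence values are made up for illustration, while the sorting and branching mirror the diff above.

import enum


@enum.unique
class ModelsToRelax(enum.Enum):
  ALL = 0
  BEST = 1
  NONE = 2


def select_models_to_relax(ranking_confidences, models_to_relax):
  """Returns the model names to relax, most confident first."""
  # Rank model names by confidence, highest first.
  ranked_order = [
      model_name for model_name, _ in
      sorted(ranking_confidences.items(), key=lambda x: x[1], reverse=True)]
  if models_to_relax == ModelsToRelax.BEST:
    return ranked_order[:1]
  if models_to_relax == ModelsToRelax.ALL:
    return ranked_order
  return []  # ModelsToRelax.NONE


# Illustrative confidences only.
confidences = {'model_1': 0.82, 'model_2': 0.91, 'model_3': 0.77}
print(select_models_to_relax(confidences, ModelsToRelax.BEST))  # ['model_2']
print(select_models_to_relax(confidences, ModelsToRelax.ALL))   # ['model_2', 'model_1', 'model_3']
print(select_models_to_relax(confidences, ModelsToRelax.NONE))  # []

One consequence visible in the ranked output loop: under the 'best' setting, ranked_0.pdb is written from the relaxed structure while the remaining ranked_*.pdb files come from the unrelaxed predictions.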
14 changes: 8 additions & 6 deletions run_alphafold_test.py
@@ -28,10 +28,10 @@
class RunAlphafoldTest(parameterized.TestCase):

@parameterized.named_parameters(
('relax', True),
('no_relax', False),
('relax', run_alphafold.ModelsToRelax.ALL),
('no_relax', run_alphafold.ModelsToRelax.NONE),
)
def test_end_to_end(self, do_relax):
def test_end_to_end(self, models_to_relax):

data_pipeline_mock = mock.Mock()
model_runner_mock = mock.Mock()
@@ -72,9 +72,11 @@ def test_end_to_end(self, do_relax):
output_dir_base=out_dir,
data_pipeline=data_pipeline_mock,
model_runners={'model1': model_runner_mock},
amber_relaxer=amber_relaxer_mock if do_relax else None,
amber_relaxer=amber_relaxer_mock,
benchmark=False,
random_seed=0)
random_seed=0,
models_to_relax=models_to_relax,
)

base_output_files = os.listdir(out_dir)
self.assertIn('target.fasta', base_output_files)
@@ -85,7 +87,7 @@ def test_end_to_end(self, do_relax):
'features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json',
'result_model1.pkl', 'timings.json', 'unrelaxed_model1.pdb',
]
if do_relax:
if models_to_relax == run_alphafold.ModelsToRelax.ALL:
expected_files.extend(['relaxed_model1.pdb', 'relax_metrics.json'])
with open(os.path.join(out_dir, 'test', 'relax_metrics.json')) as f:
relax_metrics = json.loads(f.read())
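End to end, the new option travels as a plain string: docker/run_docker.py validates it with DEFINE_enum and forwards --models_to_relax=<value> into the container, where run_alphafold.py's DEFINE_enum_class turns it into a ModelsToRelax member (absl's enum-class flags match member names case-insensitively by default, so 'best' becomes ModelsToRelax.BEST), and the test passes the enum members directly. A small standalone sketch of that string-to-enum conversion, not taken from the repository:

import enum

from absl import flags


class ModelsToRelax(enum.Enum):
  ALL = 0
  BEST = 1
  NONE = 2


FLAGS = flags.FLAGS
flags.DEFINE_enum_class('models_to_relax', ModelsToRelax.BEST, ModelsToRelax,
                        'Which models to relax.')

# Parse a fake argv; the first element stands in for the program name.
FLAGS(['demo', '--models_to_relax=best'])
assert FLAGS.models_to_relax is ModelsToRelax.BEST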
