From 9d377811cbfc67f9f72f5748f30823e83520f663 Mon Sep 17 00:00:00 2001
From: Erick Matsen
Date: Sat, 11 Jan 2025 04:37:40 -0800
Subject: [PATCH 1/6] implementing DCSM

---
 netam/codon_table.py    |  45 ++++++
 netam/dcsm.py           | 340 ++++++++++++++++++++++++++++++++++++++++
 netam/dnsm.py           |   1 +
 netam/framework.py      |   3 +
 netam/molevol.py        |  40 ++++-
 netam/sequences.py      |  32 ++++
 tests/test_dcsm.py      |  53 +++++++
 tests/test_sequences.py |   9 ++
 8 files changed, 517 insertions(+), 6 deletions(-)
 create mode 100644 netam/codon_table.py
 create mode 100644 netam/dcsm.py
 create mode 100644 tests/test_dcsm.py

diff --git a/netam/codon_table.py b/netam/codon_table.py
new file mode 100644
index 00000000..1f2d353f
--- /dev/null
+++ b/netam/codon_table.py
@@ -0,0 +1,45 @@
+import numpy as np
+
+from Bio.Data import CodonTable
+from netam.sequences import AA_STR_SORTED
+
+
+def single_mutant_aa_indices(codon):
+    """Given a codon, return the amino acid indices for all single-mutant neighbors.
+
+    Args:
+        codon (str): A three-letter codon (e.g., "ATG").
+
+    Returns:
+        list of int: Indices (into AA_STR_SORTED) of the resulting amino acids
+        for single mutants.
+    """
+    standard_table = CodonTable.unambiguous_dna_by_id[1]  # Standard codon table
+    bases = ["A", "C", "G", "T"]
+
+    mutant_aa_indices = set()  # Use a set to avoid duplicates
+
+    # Generate all single-mutant neighbors
+    for pos in range(3):  # Codons have 3 positions
+        for base in bases:
+            if base != codon[pos]:  # Mutate only if it's a different base
+                mutant_codon = codon[:pos] + base + codon[pos + 1 :]
+
+                # Check if the mutant codon translates to a valid amino acid
+                if mutant_codon in standard_table.forward_table:
+                    mutant_aa = standard_table.forward_table[mutant_codon]
+                    mutant_aa_indices.add(AA_STR_SORTED.index(mutant_aa))
+
+    return sorted(mutant_aa_indices)
+
+
+def make_codon_neighbor_indicator(nt_seq):
+    """
+    Create a binary array indicating the single-mutant amino acid neighbors of
+    each codon in a given DNA sequence.
+ """ + neighbor = np.zeros((len(AA_STR_SORTED), len(nt_seq) // 3), dtype=bool) + for i in range(0, len(nt_seq), 3): + codon = nt_seq[i : i + 3] + neighbor[single_mutant_aa_indices(codon), i // 3] = True + return neighbor diff --git a/netam/dcsm.py b/netam/dcsm.py new file mode 100644 index 00000000..1e146583 --- /dev/null +++ b/netam/dcsm.py @@ -0,0 +1,340 @@ +"""Defining the deep natural selection model (DNSM).""" + +import copy + +import pandas as pd +import torch +import torch.nn.functional as F + +from netam.common import ( + assert_pcp_valid, + clamp_probability, + codon_mask_tensor_of, + BIG, +) +from netam.dxsm import DXSMDataset, DXSMBurrito +import netam.molevol as molevol + +from netam.common import aa_idx_tensor_of_str_ambig +from netam.sequences import ( + aa_idx_array_of_str, + aa_subs_indicator_tensor_of, + build_stop_codon_indicator_tensor, + nt_idx_tensor_of_str, + token_mask_of_aa_idxs, + translate_sequence, + translate_sequences, + codon_idx_tensor_of_str_ambig, + AA_AMBIG_IDX, + AMBIGUOUS_CODON_IDX, + CODON_AA_INDICATOR_MATRIX, + RESERVED_TOKEN_REGEX, + MAX_AA_TOKEN_IDX, +) + + +class DCSMDataset(DXSMDataset): + + def __init__( + self, + nt_parents: pd.Series, + nt_children: pd.Series, + nt_ratess: torch.Tensor, + nt_cspss: torch.Tensor, + branch_lengths: torch.Tensor, + multihit_model=None, + ): + self.nt_parents = nt_parents.str.replace(RESERVED_TOKEN_REGEX, "N", regex=True) + # We will replace reserved tokens with Ns but use the unmodified + # originals for codons and mask creation. + self.nt_children = nt_children.str.replace( + RESERVED_TOKEN_REGEX, "N", regex=True + ) + self.nt_ratess = nt_ratess + self.nt_cspss = nt_cspss + self.multihit_model = copy.deepcopy(multihit_model) + if multihit_model is not None: + # We want these parameters to act like fixed data. This is essential + # for multithreaded branch length optimization to work. + self.multihit_model.values.requires_grad_(False) + + assert len(self.nt_parents) == len(self.nt_children) + pcp_count = len(self.nt_parents) + + # Important to use the unmodified versions of nt_parents and + # nt_children so they still contain special tokens. + aa_parents = translate_sequences(nt_parents) + aa_children = translate_sequences(nt_children) + + self.max_codon_seq_len = max(len(seq) for seq in aa_parents) + # We have sequences of varying length, so we start with all tensors set + # to the ambiguous amino acid, and then will fill in the actual values + # below. + self.codon_parents_idxss = torch.full( + (pcp_count, self.max_codon_seq_len), AMBIGUOUS_CODON_IDX + ) + self.codon_children_idxss = self.codon_parents_idxss.clone() + self.aa_parents_idxss = torch.full( + (pcp_count, self.max_codon_seq_len), AA_AMBIG_IDX + ) + self.aa_children_idxss = torch.full( + (pcp_count, self.max_codon_seq_len), AA_AMBIG_IDX + ) + # TODO here we are computing the subs indicators. This is handy for OE plots. + self.aa_subs_indicators = torch.zeros((pcp_count, self.max_codon_seq_len)) + + self.masks = torch.ones((pcp_count, self.max_codon_seq_len), dtype=torch.bool) + + # We are using the modified nt_parents and nt_children here because we + # don't want any funky symbols in our codon indices. 
+        for i, (nt_parent, nt_child, aa_parent, aa_child) in enumerate(
+            zip(self.nt_parents, self.nt_children, aa_parents, aa_children)
+        ):
+            self.masks[i, :] = codon_mask_tensor_of(
+                nt_parent, nt_child, aa_length=self.max_codon_seq_len
+            )
+            assert len(nt_parent) % 3 == 0
+            codon_seq_len = len(nt_parent) // 3
+
+            assert_pcp_valid(nt_parent, nt_child, aa_mask=self.masks[i][:codon_seq_len])
+
+            self.codon_parents_idxss[i, :codon_seq_len] = codon_idx_tensor_of_str_ambig(
+                nt_parent
+            )
+            self.codon_children_idxss[i, :codon_seq_len] = (
+                codon_idx_tensor_of_str_ambig(nt_child)
+            )
+            self.aa_parents_idxss[i, :codon_seq_len] = aa_idx_tensor_of_str_ambig(
+                aa_parent
+            )
+            self.aa_children_idxss[i, :codon_seq_len] = aa_idx_tensor_of_str_ambig(
+                aa_child
+            )
+            self.aa_subs_indicators[i, :codon_seq_len] = aa_subs_indicator_tensor_of(
+                aa_parent, aa_child
+            )
+
+        assert torch.all(self.masks.sum(dim=1) > 0)
+        assert torch.max(self.aa_parents_idxss) <= MAX_AA_TOKEN_IDX
+        assert torch.max(self.aa_children_idxss) <= MAX_AA_TOKEN_IDX
+        assert torch.max(self.codon_parents_idxss) <= AMBIGUOUS_CODON_IDX
+
+        self._branch_lengths = branch_lengths
+        self.update_neutral_probs()
+
+    def update_neutral_probs(self):
+        """Update the neutral mutation probabilities for the dataset.
+
+        This is a somewhat vague name, but that's because it includes all of the various
+        types of neutral mutation probabilities that we might want to compute.
+
+        In this case it's the neutral codon probabilities.
+        """
+        neutral_codon_probs_l = []
+
+        for nt_parent, mask, nt_rates, nt_csps, branch_length in zip(
+            self.nt_parents,
+            self.masks,
+            self.nt_ratess,
+            self.nt_cspss,
+            self._branch_lengths,
+        ):
+            mask = mask.to("cpu")
+            nt_rates = nt_rates.to("cpu")
+            nt_csps = nt_csps.to("cpu")
+            if self.multihit_model is not None:
+                multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
+            else:
+                multihit_model = None
+            # Note we are replacing all Ns with As, which means that we need to be careful
+            # with masking out these positions later. We do this below.
+            parent_idxs = nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
+            parent_len = len(nt_parent)
+
+            mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
+            nt_csps = nt_csps[:parent_len, :]
+            nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
+            molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask])
+
+            neutral_codon_probs = molevol.neutral_codon_probs(
+                parent_idxs.reshape(-1, 3),
+                mut_probs.reshape(-1, 3),
+                nt_csps.reshape(-1, 3, 4),
+                multihit_model=multihit_model,
+            )
+
+            if not torch.isfinite(neutral_codon_probs).all():
+                print("Found a non-finite neutral_codon_probs")
+                print(f"nt_parent: {nt_parent}")
+                print(f"mask: {mask}")
+                print(f"nt_rates: {nt_rates}")
+                print(f"nt_csps: {nt_csps}")
+                print(f"branch_length: {branch_length}")
+                raise ValueError(
+                    f"neutral_codon_probs is not finite: {neutral_codon_probs}"
+                )
+
+            # Ensure that all values are positive before taking the log later
+            neutral_codon_probs = clamp_probability(neutral_codon_probs)
+
+            pad_len = self.max_codon_seq_len - neutral_codon_probs.shape[0]
+            if pad_len > 0:
+                neutral_codon_probs = F.pad(
+                    neutral_codon_probs, (0, 0, 0, pad_len), value=1e-8
+                )
+            # Here we zero out masked positions.
+            neutral_codon_probs *= mask[:, None]
+
+            neutral_codon_probs_l.append(neutral_codon_probs)
+
+        # Note that our masked-out positions will have a non-finite log
+        # probability, which will require us to handle them correctly downstream.
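+        # Hedged shape sketch: each element of neutral_codon_probs_l is
+        # (max_codon_seq_len, 64) after padding, so the stack-and-log below
+        # yields a (pcp_count, max_codon_seq_len, 64) tensor.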
+        self.log_neutral_codon_probss = torch.log(torch.stack(neutral_codon_probs_l))
+
+    def __getitem__(self, idx):
+        return {
+            "codon_parents_idxs": self.codon_parents_idxss[idx],
+            "codon_children_idxs": self.codon_children_idxss[idx],
+            "aa_parents_idxs": self.aa_parents_idxss[idx],
+            "aa_children_idxs": self.aa_children_idxss[idx],
+            "subs_indicator": self.aa_subs_indicators[idx],
+            "mask": self.masks[idx],
+            "log_neutral_codon_probs": self.log_neutral_codon_probss[idx],
+            "nt_rates": self.nt_ratess[idx],
+            "nt_csps": self.nt_cspss[idx],
+        }
+
+    def to(self, device):
+        self.codon_parents_idxss = self.codon_parents_idxss.to(device)
+        self.codon_children_idxss = self.codon_children_idxss.to(device)
+        self.aa_parents_idxss = self.aa_parents_idxss.to(device)
+        self.aa_children_idxss = self.aa_children_idxss.to(device)
+        self.aa_subs_indicators = self.aa_subs_indicators.to(device)
+        self.masks = self.masks.to(device)
+        self.log_neutral_codon_probss = self.log_neutral_codon_probss.to(device)
+        self.nt_ratess = self.nt_ratess.to(device)
+        self.nt_cspss = self.nt_cspss.to(device)
+        if self.multihit_model is not None:
+            self.multihit_model = self.multihit_model.to(device)
+
+
+class DCSMBurrito(DXSMBurrito):
+
+    model_type = "dcsm"
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.xent_loss = torch.nn.CrossEntropyLoss()
+        self.stop_codon_zapper = build_stop_codon_indicator_tensor() * -BIG
+
+    def prediction_pair_of_batch(self, batch):
+        """Get log neutral codon substitution probabilities and log selection factors
+        for a batch of data.
+
+        We don't mask on the output, which will thus contain junk in all of the masked
+        sites.
+        """
+        aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
+        mask = batch["mask"].to(self.device)
+        log_neutral_codon_probs = batch["log_neutral_codon_probs"].to(self.device)
+        if not torch.isfinite(log_neutral_codon_probs[mask]).all():
+            raise ValueError(
+                f"log_neutral_codon_probs has non-finite values at relevant positions: {log_neutral_codon_probs[mask]}"
+            )
+        # We need the model to see special tokens here. For every other purpose
+        # they are masked out.
+        keep_token_mask = mask | token_mask_of_aa_idxs(aa_parents_idxs)
+        log_selection_factors = self.model(aa_parents_idxs, keep_token_mask)
+        return log_neutral_codon_probs, log_selection_factors
+
+    def predictions_of_batch(self, batch):
+        """Make log probability predictions for a batch of data.
+
+        In this case they are log probabilities of codons, which are made to be
+        probabilities by setting the parent codon to 1 - sum(children).
+
+        After all this, we clip the probabilities below to avoid log(0) issues.
+        So, in cases where the sum of the children is > 1, we don't give a
+        normalized probability distribution, but that won't crash the loss
+        calculation because that step uses softmax.
+
+        Note that we make all ambiguous codons NaN in the output, ensuring that
+        they must get properly masked downstream.
+        """
+        log_neutral_codon_probs, log_selection_factors = self.prediction_pair_of_batch(
+            batch
+        )
+
+        # This code block is done in a separate function in other burritos,
+        # but we can't do that here because we need to normalize the
+        # probabilities in a way that is not possible without having the index
+        # of the parent codon. Namely, we need to set the parent codon to 1 -
+        # sum(children).
+
+        # This indicator lifts things up from aa land to codon land.
+        # TODO I guess we could store indicator in self and have everything move with a self.to(device) call.
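+        # Hedged shape sketch: CODON_AA_INDICATOR_MATRIX is (64, n_aa), so its
+        # transpose maps log_selection_factors of shape (B, L, n_aa) into codon
+        # land as (B, L, 64); n_aa is nominally 20.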
+ indicator = CODON_AA_INDICATOR_MATRIX.to(self.device).T + log_preds = ( + log_neutral_codon_probs + + log_selection_factors @ indicator + + self.stop_codon_zapper.to(self.device) + ) + assert torch.isnan(log_preds).sum() == 0 + + parent_indices = batch["codon_parents_idxs"].to(self.device) # Shape: [B, L] + valid_mask = parent_indices != AMBIGUOUS_CODON_IDX # Shape: [B, L] + + # Convert to linear space so we can add probabilities. + preds = torch.exp(log_preds) + + # Zero out the parent indices in preds, while keeping the computation + # graph intact. + preds_zeroer = torch.ones_like(preds) + preds_zeroer[valid_mask, parent_indices[valid_mask]] = 0.0 + preds = preds * preds_zeroer + + # Calculate the non-parent sum after zeroing out the parent indices. + non_parent_sum = preds[valid_mask, :].sum(dim=-1) + + # Add these parent values back in, again keeping the computation graph intact. + preds_parent = torch.zeros_like(preds) + preds_parent[valid_mask, parent_indices[valid_mask]] = 1.0 - non_parent_sum + preds = preds + preds_parent + + # We have to clamp the predictions to avoid log(0) issues. + preds = torch.clamp(preds, min=torch.finfo(preds.dtype).eps) + + log_preds = torch.log(preds) + + # Set ambiguous codons to nan to make sure that we handle them correctly downstream. + log_preds[~valid_mask, :] = float("nan") + + return log_preds + + def loss_of_batch(self, batch): + codon_children_idxs = batch["codon_children_idxs"].to(self.device) + mask = batch["mask"].to(self.device) + + predictions = self.predictions_of_batch(batch)[mask] + assert torch.isnan(predictions).sum() == 0 + codon_children_idxs = codon_children_idxs[mask] + + return self.xent_loss(predictions, codon_children_idxs) + + # TODO copied from dasm.py + def build_selection_matrix_from_parent(self, parent: str): + """Build a selection matrix from a parent amino acid sequence. + + Values at ambiguous sites are meaningless. + """ + # This is simpler than the equivalent in dnsm.py because we get the selection + # matrix directly. Note that selection_factors_of_aa_str does the exponentiation + # so this indeed gives us the selection factors, not the log selection factors. 
+        parent = translate_sequence(parent)
+        per_aa_selection_factors = self.model.selection_factors_of_aa_str(parent)
+
+        parent = parent.replace("X", "A")
+        parent_idxs = aa_idx_array_of_str(parent)
+        per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
+
+        return per_aa_selection_factors
diff --git a/netam/dnsm.py b/netam/dnsm.py
index bc05d479..120eb8c6 100644
--- a/netam/dnsm.py
+++ b/netam/dnsm.py
@@ -56,6 +56,7 @@ def update_neutral_probs(self):
 
             mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
             nt_csps = nt_csps[:parent_len, :]
+            # TODO singular/plural mismatch
             neutral_aa_mut_prob = molevol.neutral_aa_mut_probs(
                 parent_idxs.reshape(-1, 3),
                 mut_probs.reshape(-1, 3),
diff --git a/netam/framework.py b/netam/framework.py
index 1581e035..e0930f16 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -599,6 +599,9 @@ def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None
                     self.optimizer.zero_grad()
                     scalar_loss.backward()
 
+                    if torch.isnan(scalar_loss):
+                        raise ValueError(f"NaN in loss: {scalar_loss.item()}")
+
                     nan_in_gradients = False
                     for name, param in self.model.named_parameters():
                         if torch.isnan(param).any():
diff --git a/netam/molevol.py b/netam/molevol.py
index 2aef1c10..8c03909b 100644
--- a/netam/molevol.py
+++ b/netam/molevol.py
@@ -339,14 +339,14 @@ def build_codon_mutsel(
     return codon_mutsel, sums_too_big
 
 
-def neutral_aa_probs(
+def neutral_codon_probs(
     parent_codon_idxs: Tensor,
     codon_mut_probs: Tensor,
     codon_csps: Tensor,
     multihit_model=None,
 ) -> Tensor:
-    """For every site, what is the probability that the amino acid will mutate to every
-    amino acid?
+    """For every site, what is the probability that the site will mutate to every
+    alternate codon?
 
     Args:
         parent_codon_idxs (torch.Tensor): The parent codons for each sequence. Shape: (codon_count, 3)
@@ -354,8 +354,8 @@ def neutral_aa_probs(
         codon_csps (torch.Tensor): The substitution probabilities for each site in each codon. Shape: (codon_count, 3, 4)
 
     Returns:
-        torch.Tensor: The probability that each site will change to each amino acid.
-        Shape: (codon_count, 20)
+        torch.Tensor: The probability that each site will change to each codon.
+        Shape: (codon_count, 64)
     """
 
     mut_matrices = build_mutation_matrices(
@@ -366,8 +366,36 @@ def neutral_aa_probs(
     if multihit_model is not None:
         codon_probs = multihit_model(parent_codon_idxs, codon_probs)
 
+    return codon_probs.view(-1, 64)
+
+
+def neutral_aa_probs(
+    parent_codon_idxs: Tensor,
+    codon_mut_probs: Tensor,
+    codon_csps: Tensor,
+    multihit_model=None,
+) -> Tensor:
+    """For every site, what is the probability that the site will mutate to every
+    alternate amino acid?
+
+    Args:
+        parent_codon_idxs (torch.Tensor): The parent codons for each sequence. Shape: (codon_count, 3)
+        codon_mut_probs (torch.Tensor): The mutation probabilities for each site in each codon. Shape: (codon_count, 3)
+        codon_csps (torch.Tensor): The substitution probabilities for each site in each codon. Shape: (codon_count, 3, 4)
+
+    Returns:
+        torch.Tensor: The probability that each site will change to each amino acid.
+        Shape: (codon_count, 20)
+    """
+    codon_probs = neutral_codon_probs(
+        parent_codon_idxs,
+        codon_mut_probs,
+        codon_csps,
+        multihit_model=multihit_model,
+    )
+
     # Get the probability of mutating to each amino acid.
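+    # Hedged shape sketch: codon_probs is (codon_count, 64) here and
+    # CODON_AA_INDICATOR_MATRIX is (64, 20), so the matmul below yields
+    # (codon_count, 20).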
-    aa_probs = codon_probs.view(-1, 64) @ CODON_AA_INDICATOR_MATRIX
+    aa_probs = codon_probs @ CODON_AA_INDICATOR_MATRIX
 
     return aa_probs
diff --git a/netam/sequences.py b/netam/sequences.py
index 1d574279..10db7560 100644
--- a/netam/sequences.py
+++ b/netam/sequences.py
@@ -117,6 +117,14 @@ def dataset_inputs_of_pcp_df(pcp_df, known_token_count):
     )
 
 
+def build_stop_codon_indicator_tensor():
+    """Return a tensor indicating the stop codons."""
+    stop_codon_indicator = torch.zeros(len(CODONS))
+    for stop_codon in STOP_CODONS:
+        stop_codon_indicator[CODONS.index(stop_codon)] = 1.0
+    return stop_codon_indicator
+
+
 def nt_idx_array_of_str(nt_str):
     """Return the indices of the nucleotides in a string."""
     try:
@@ -153,6 +161,30 @@ def aa_idx_tensor_of_str(aa_str):
         raise
 
 
+# TODO isolating all this stuff here
+
+AMBIGUOUS_CODON_IDX = len(CODONS)
+
+
+def idx_of_codon_allowing_ambiguous(codon):
+    # if codon contains an N
+    if "N" in codon:
+        return AMBIGUOUS_CODON_IDX
+    else:
+        return CODONS.index(codon)
+
+
+def codon_idx_tensor_of_str_ambig(nt_str):
+    """Return the indices of the codons in a string."""
+    assert len(nt_str) % 3 == 0
+    return torch.tensor(
+        [idx_of_codon_allowing_ambiguous(codon) for codon in iter_codons(nt_str)]
+    )
+
+
+# TODO end isolating new stuff
+
+
 def aa_onehot_tensor_of_str(aa_str):
     aa_onehot = torch.zeros((len(aa_str), 20))
     aa_indices_parent = aa_idx_array_of_str(aa_str)
diff --git a/tests/test_dcsm.py b/tests/test_dcsm.py
new file mode 100644
index 00000000..d78fa5b2
--- /dev/null
+++ b/tests/test_dcsm.py
@@ -0,0 +1,53 @@
+import os
+
+import torch
+import pytest
+
+from netam.common import BIG, force_spawn
+from netam.framework import (
+    crepe_exists,
+    load_crepe,
+)
+from netam.sequences import MAX_AA_TOKEN_IDX
+from netam.models import TransformerBinarySelectionModelWiggleAct
+from netam.dcsm import (
+    DCSMBurrito,
+    DCSMDataset,
+)
+
+
+@pytest.fixture(scope="module")
+def dcsm_burrito(pcp_df):
+    force_spawn()
+    """Fixture that returns the DCSM Burrito object."""
+    pcp_df["in_train"] = True
+    pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
+    train_dataset, val_dataset = DCSMDataset.train_val_datasets_of_pcp_df(pcp_df)
+
+    model = TransformerBinarySelectionModelWiggleAct(
+        nhead=2,
+        d_model_per_head=4,
+        dim_feedforward=256,
+        layer_count=2,
+        output_dim=MAX_AA_TOKEN_IDX + 1,
+    )
+
+    burrito = DCSMBurrito(
+        train_dataset,
+        val_dataset,
+        model,
+        batch_size=32,
+        learning_rate=0.001,
+        min_learning_rate=0.0001,
+    )
+    burrito.joint_train(
+        epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
+    )
+    return burrito
+
+
+def test_parallel_branch_length_optimization(dcsm_burrito):
+    dataset = dcsm_burrito.val_dataset
+    parallel_branch_lengths = dcsm_burrito.find_optimal_branch_lengths(dataset)
+    branch_lengths = dcsm_burrito.serial_find_optimal_branch_lengths(dataset)
+    assert torch.allclose(branch_lengths, parallel_branch_lengths)
diff --git a/tests/test_sequences.py b/tests/test_sequences.py
index 17ce83d1..e1ee8e5e 100644
--- a/tests/test_sequences.py
+++ b/tests/test_sequences.py
@@ -13,7 +13,9 @@
     CODON_AA_INDICATOR_MATRIX,
     MAX_KNOWN_TOKEN_COUNT,
     AA_AMBIG_IDX,
+    AMBIGUOUS_CODON_IDX,
     aa_onehot_tensor_of_str,
+    codon_idx_tensor_of_str,
     nt_idx_array_of_str,
     nt_subs_indicator_tensor_of,
     translate_sequences,
@@ -94,6 +96,13 @@ def test_nucleotide_indices_of_codon():
     assert nt_idx_array_of_str("GCG").tolist() == [2, 1, 2]
 
 
+def test_codon_idx_tensor_of_str():
+    nt_str = "AAAAACTTGTTTNTT"
+    expected_output = torch.tensor([0, 1, 62, 63, 
AMBIGUOUS_CODON_IDX]) + output = codon_idx_tensor_of_str(nt_str) + assert torch.equal(output, expected_output) + + def test_aa_onehot_tensor_of_str(): aa_str = "QY" From 6b1c8d4ec9715019193f7f61a5915bded6e2d873 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Fri, 24 Jan 2025 13:01:22 -0800 Subject: [PATCH 2/6] cleanup --- netam/sequences.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/netam/sequences.py b/netam/sequences.py index 10db7560..4f638b06 100644 --- a/netam/sequences.py +++ b/netam/sequences.py @@ -30,6 +30,8 @@ MAX_AA_TOKEN_IDX = MAX_KNOWN_TOKEN_COUNT - 1 CODONS = ["".join(codon_list) for codon_list in itertools.product(BASES, repeat=3)] STOP_CODONS = ["TAA", "TAG", "TGA"] +AMBIGUOUS_CODON_IDX = len(CODONS) + # Each token in RESERVED_TOKENS will appear once in aa strings, and three times # in nt strings. RESERVED_TOKEN_TRANSLATIONS = {token * 3: token for token in RESERVED_TOKENS} @@ -161,13 +163,7 @@ def aa_idx_tensor_of_str(aa_str): raise -# TODO isolating all this stuff here - -AMBIGUOUS_CODON_IDX = len(CODONS) - - def idx_of_codon_allowing_ambiguous(codon): - # if codon contains an N if "N" in codon: return AMBIGUOUS_CODON_IDX else: @@ -182,9 +178,6 @@ def codon_idx_tensor_of_str_ambig(nt_str): ) -# TODO end isolating new stuff - - def aa_onehot_tensor_of_str(aa_str): aa_onehot = torch.zeros((len(aa_str), 20)) aa_indices_parent = aa_idx_array_of_str(aa_str) From c542b44d783aa0cda3f0ac30badc96e3a39f580f Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Fri, 24 Jan 2025 16:13:13 -0800 Subject: [PATCH 3/6] finish rebase/refactor for Will's PR --- netam/dasm.py | 29 +----------- netam/dcsm.py | 115 ++++++++++----------------------------------- netam/dxsm.py | 27 +++++++++++ netam/sequences.py | 2 +- tests/test_dcsm.py | 6 +-- 5 files changed, 59 insertions(+), 120 deletions(-) diff --git a/netam/dasm.py b/netam/dasm.py index 1095e4a5..938c335f 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -3,11 +3,8 @@ import torch import torch.nn.functional as F -from netam.common import ( - clamp_probability, - BIG, -) -from netam.dxsm import DXSMDataset, DXSMBurrito +from netam.common import clamp_probability +from netam.dxsm import DXSMDataset, DXSMBurrito, zap_predictions_along_diagonal import netam.framework as framework import netam.molevol as molevol import netam.sequences as sequences @@ -100,28 +97,6 @@ def to(self, device): self.multihit_model = self.multihit_model.to(device) -def zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=-BIG): - """Set the diagonal (i.e. 
no amino acid change) of the predictions tensor to -BIG, - except where aa_parents_idxs >= 20, which indicates no update should be done.""" - - device = predictions.device - batch_size, L, _ = predictions.shape - batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L) - sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1) - - # Create a mask for valid positions (where aa_parents_idxs is less than 20) - valid_mask = aa_parents_idxs < 20 - - # Only update the predictions for valid positions - predictions[ - batch_indices[valid_mask], - sequence_indices[valid_mask], - aa_parents_idxs[valid_mask], - ] = fill - - return predictions - - class DASMBurrito(framework.TwoLossMixin, DXSMBurrito): model_type = "dasm" diff --git a/netam/dcsm.py b/netam/dcsm.py index 1e146583..abc57f96 100644 --- a/netam/dcsm.py +++ b/netam/dcsm.py @@ -7,12 +7,10 @@ import torch.nn.functional as F from netam.common import ( - assert_pcp_valid, clamp_probability, - codon_mask_tensor_of, BIG, ) -from netam.dxsm import DXSMDataset, DXSMBurrito +from netam.dxsm import DXSMDataset, DXSMBurrito, zap_predictions_along_diagonal import netam.molevol as molevol from netam.common import aa_idx_tensor_of_str_ambig @@ -37,91 +35,34 @@ class DCSMDataset(DXSMDataset): def __init__( self, - nt_parents: pd.Series, - nt_children: pd.Series, - nt_ratess: torch.Tensor, - nt_cspss: torch.Tensor, - branch_lengths: torch.Tensor, - multihit_model=None, + *args, + **kwargs, ): - self.nt_parents = nt_parents.str.replace(RESERVED_TOKEN_REGEX, "N", regex=True) - # We will replace reserved tokens with Ns but use the unmodified - # originals for codons and mask creation. - self.nt_children = nt_children.str.replace( - RESERVED_TOKEN_REGEX, "N", regex=True - ) - self.nt_ratess = nt_ratess - self.nt_cspss = nt_cspss - self.multihit_model = copy.deepcopy(multihit_model) - if multihit_model is not None: - # We want these parameters to act like fixed data. This is essential - # for multithreaded branch length optimization to work. - self.multihit_model.values.requires_grad_(False) - + super().__init__(*args, **kwargs) assert len(self.nt_parents) == len(self.nt_children) - pcp_count = len(self.nt_parents) - - # Important to use the unmodified versions of nt_parents and - # nt_children so they still contain special tokens. - aa_parents = translate_sequences(nt_parents) - aa_children = translate_sequences(nt_children) - - self.max_codon_seq_len = max(len(seq) for seq in aa_parents) - # We have sequences of varying length, so we start with all tensors set - # to the ambiguous amino acid, and then will fill in the actual values - # below. - self.codon_parents_idxss = torch.full( - (pcp_count, self.max_codon_seq_len), AMBIGUOUS_CODON_IDX + # We need to add codon index tensors to the dataset. + + self.max_codon_seq_len = self.max_aa_seq_len + self.codon_parents_idxss = torch.full_like( + self.aa_parents_idxss, AMBIGUOUS_CODON_IDX ) self.codon_children_idxss = self.codon_parents_idxss.clone() - self.aa_parents_idxss = torch.full( - (pcp_count, self.max_codon_seq_len), AA_AMBIG_IDX - ) - self.aa_children_idxss = torch.full( - (pcp_count, self.max_codon_seq_len), AA_AMBIG_IDX - ) - # TODO here we are computing the subs indicators. This is handy for OE plots. 
- self.aa_subs_indicators = torch.zeros((pcp_count, self.max_codon_seq_len)) - - self.masks = torch.ones((pcp_count, self.max_codon_seq_len), dtype=torch.bool) # We are using the modified nt_parents and nt_children here because we # don't want any funky symbols in our codon indices. - for i, (nt_parent, nt_child, aa_parent, aa_child) in enumerate( - zip(self.nt_parents, self.nt_children, aa_parents, aa_children) + for i, (nt_parent, nt_child) in enumerate( + zip(self.nt_parents, self.nt_children) ): - self.masks[i, :] = codon_mask_tensor_of( - nt_parent, nt_child, aa_length=self.max_codon_seq_len - ) assert len(nt_parent) % 3 == 0 codon_seq_len = len(nt_parent) // 3 - - assert_pcp_valid(nt_parent, nt_child, aa_mask=self.masks[i][:codon_seq_len]) - self.codon_parents_idxss[i, :codon_seq_len] = codon_idx_tensor_of_str_ambig( nt_parent ) self.codon_children_idxss[i, :codon_seq_len] = ( codon_idx_tensor_of_str_ambig(nt_child) ) - self.aa_parents_idxss[i, :codon_seq_len] = aa_idx_tensor_of_str_ambig( - aa_parent - ) - self.aa_children_idxss[i, :codon_seq_len] = aa_idx_tensor_of_str_ambig( - aa_child - ) - self.aa_subs_indicators[i, :codon_seq_len] = aa_subs_indicator_tensor_of( - aa_parent, aa_child - ) - - assert torch.all(self.masks.sum(dim=1) > 0) - assert torch.max(self.aa_parents_idxss) <= MAX_AA_TOKEN_IDX - assert torch.max(self.aa_children_idxss) <= MAX_AA_TOKEN_IDX assert torch.max(self.codon_parents_idxss) <= AMBIGUOUS_CODON_IDX - self._branch_lengths = branch_lengths - self.update_neutral_probs() - def update_neutral_probs(self): """Update the neutral mutation probabilities for the dataset. @@ -177,7 +118,7 @@ def update_neutral_probs(self): # Ensure that all values are positive before taking the log later neutral_codon_probs = clamp_probability(neutral_codon_probs) - pad_len = self.max_codon_seq_len - neutral_codon_probs.shape[0] + pad_len = self.max_aa_seq_len - neutral_codon_probs.shape[0] if pad_len > 0: neutral_codon_probs = F.pad( neutral_codon_probs, (0, 0, 0, pad_len), value=1e-8 @@ -241,10 +182,7 @@ def prediction_pair_of_batch(self, batch): raise ValueError( f"log_neutral_codon_probs has non-finite values at relevant positions: {log_neutral_codon_probs[mask]}" ) - # We need the model to see special tokens here. For every other purpose - # they are masked out. - keep_token_mask = mask | token_mask_of_aa_idxs(aa_parents_idxs) - log_selection_factors = self.model(aa_parents_idxs, keep_token_mask) + log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask) return log_neutral_codon_probs, log_selection_factors def predictions_of_batch(self, batch): @@ -321,20 +259,19 @@ def loss_of_batch(self, batch): return self.xent_loss(predictions, codon_children_idxs) - # TODO copied from dasm.py - def build_selection_matrix_from_parent(self, parent: str): - """Build a selection matrix from a parent amino acid sequence. + # TODO copied from dasm.py (updated for new organization from Will's PR) + def build_selection_matrix_from_parent_aa( + self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor + ): + """Build a selection matrix from a single parent amino acid sequence. Inputs are + expected to be as prepared in the Dataset constructor. Values at ambiguous sites are meaningless. """ - # This is simpler than the equivalent in dnsm.py because we get the selection - # matrix directly. Note that selection_factors_of_aa_str does the exponentiation - # so this indeed gives us the selection factors, not the log selection factors. 
- parent = translate_sequence(parent) - per_aa_selection_factors = self.model.selection_factors_of_aa_str(parent) - - parent = parent.replace("X", "A") - parent_idxs = aa_idx_array_of_str(parent) - per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0 - - return per_aa_selection_factors + with torch.no_grad(): + per_aa_selection_factors = self.selection_factors_of_aa_idxs( + aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0) + ).exp() + return zap_predictions_along_diagonal( + per_aa_selection_factors, aa_parent_idxs.unsqueeze(0), fill=1.0 + ).squeeze(0) diff --git a/netam/dxsm.py b/netam/dxsm.py index 58f13c18..267a9b5d 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -19,6 +19,7 @@ stack_heterogeneous, codon_mask_tensor_of, assert_pcp_valid, + BIG, ) import netam.framework as framework import netam.molevol as molevol @@ -78,6 +79,7 @@ def __init__( assert self.masks.shape[1] * 3 == self.nt_cspss.shape[1] assert torch.all(self.masks.sum(dim=1) > 0) assert torch.max(self.aa_parents_idxss) <= MAX_AA_TOKEN_IDX + assert torch.max(self.aa_children_idxss) <= MAX_AA_TOKEN_IDX self._branch_lengths = branch_lengths self.update_neutral_probs() @@ -430,3 +432,28 @@ def worker_optimize_branch_length(burrito_class, model, dataset, optimization_kw """The worker used for parallel branch length optimization.""" burrito = burrito_class(None, dataset, copy.deepcopy(model)) return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) + + + +def zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=-BIG): + """Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG, + except where aa_parents_idxs >= 20, which indicates no update should be done.""" + + device = predictions.device + batch_size, L, _ = predictions.shape + batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L) + sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1) + + # Create a mask for valid positions (where aa_parents_idxs is less than 20) + valid_mask = aa_parents_idxs < 20 + + # Only update the predictions for valid positions + predictions[ + batch_indices[valid_mask], + sequence_indices[valid_mask], + aa_parents_idxs[valid_mask], + ] = fill + + return predictions + + diff --git a/netam/sequences.py b/netam/sequences.py index 4f638b06..12b5480e 100644 --- a/netam/sequences.py +++ b/netam/sequences.py @@ -32,6 +32,7 @@ STOP_CODONS = ["TAA", "TAG", "TGA"] AMBIGUOUS_CODON_IDX = len(CODONS) + # Each token in RESERVED_TOKENS will appear once in aa strings, and three times # in nt strings. RESERVED_TOKEN_TRANSLATIONS = {token * 3: token for token in RESERVED_TOKENS} @@ -39,7 +40,6 @@ # Create a regex pattern RESERVED_TOKEN_REGEX = f"[{''.join(map(re.escape, list(RESERVED_TOKENS)))}]" - def prepare_heavy_light_pair(heavy_seq, light_seq, known_token_count, is_nt=True): """Prepare a pair of heavy and light chain sequences for model input. 
diff --git a/tests/test_dcsm.py b/tests/test_dcsm.py
index d78fa5b2..e024c836 100644
--- a/tests/test_dcsm.py
+++ b/tests/test_dcsm.py
@@ -8,7 +8,7 @@
     crepe_exists,
     load_crepe,
 )
-from netam.sequences import MAX_AA_TOKEN_IDX
+from netam.sequences import MAX_AA_TOKEN_IDX, MAX_KNOWN_TOKEN_COUNT
 from netam.models import TransformerBinarySelectionModelWiggleAct
 from netam.dcsm import (
     DCSMBurrito,
@@ -22,14 +22,14 @@ def dcsm_burrito(pcp_df):
     """Fixture that returns the DCSM Burrito object."""
     pcp_df["in_train"] = True
    pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
-    train_dataset, val_dataset = DCSMDataset.train_val_datasets_of_pcp_df(pcp_df)
+    train_dataset, val_dataset = DCSMDataset.train_val_datasets_of_pcp_df(pcp_df, MAX_KNOWN_TOKEN_COUNT)
 
     model = TransformerBinarySelectionModelWiggleAct(
         nhead=2,
         d_model_per_head=4,
         dim_feedforward=256,
         layer_count=2,
-        output_dim=MAX_AA_TOKEN_IDX + 1,
+        output_dim=20,
     )
 
     burrito = DCSMBurrito(

From 64996bafb4ee288d4a79496a8090404994773cc6 Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Fri, 24 Jan 2025 16:14:43 -0800
Subject: [PATCH 4/6] format

---
 netam/codon_table.py | 6 ++----
 netam/dxsm.py        | 3 ---
 netam/sequences.py   | 1 +
 tests/test_dcsm.py   | 4 +++-
 4 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/netam/codon_table.py b/netam/codon_table.py
index 1f2d353f..11a17707 100644
--- a/netam/codon_table.py
+++ b/netam/codon_table.py
@@ -34,10 +34,8 @@ def single_mutant_aa_indices(codon):
 
 
 def make_codon_neighbor_indicator(nt_seq):
-    """
-    Create a binary array indicating the single-mutant amino acid neighbors of
-    each codon in a given DNA sequence.
-    """
+    """Create a binary array indicating the single-mutant amino acid neighbors of each
+    codon in a given DNA sequence."""
     neighbor = np.zeros((len(AA_STR_SORTED), len(nt_seq) // 3), dtype=bool)
     for i in range(0, len(nt_seq), 3):
         codon = nt_seq[i : i + 3]
diff --git a/netam/dxsm.py b/netam/dxsm.py
index 267a9b5d..298ffe45 100644
--- a/netam/dxsm.py
+++ b/netam/dxsm.py
@@ -434,7 +434,6 @@ def worker_optimize_branch_length(burrito_class, model, dataset, optimization_kw
     return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
 
 
-
 def zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=-BIG):
     """Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG,
     except where aa_parents_idxs >= 20, which indicates no update should be done."""
@@ -455,5 +454,3 @@ def zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=-BIG):
     ] = fill
 
     return predictions
-
-
diff --git a/netam/sequences.py b/netam/sequences.py
index 12b5480e..76d48219 100644
--- a/netam/sequences.py
+++ b/netam/sequences.py
@@ -40,6 +40,7 @@
 # Create a regex pattern
 RESERVED_TOKEN_REGEX = f"[{''.join(map(re.escape, list(RESERVED_TOKENS)))}]"
 
+
 def prepare_heavy_light_pair(heavy_seq, light_seq, known_token_count, is_nt=True):
     """Prepare a pair of heavy and light chain sequences for model input.
diff --git a/tests/test_dcsm.py b/tests/test_dcsm.py
index e024c836..e3eb6a17 100644
--- a/tests/test_dcsm.py
+++ b/tests/test_dcsm.py
@@ -22,7 +22,9 @@ def dcsm_burrito(pcp_df):
     """Fixture that returns the DCSM Burrito object."""
     pcp_df["in_train"] = True
     pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
-    train_dataset, val_dataset = DCSMDataset.train_val_datasets_of_pcp_df(pcp_df, MAX_KNOWN_TOKEN_COUNT)
+    train_dataset, val_dataset = DCSMDataset.train_val_datasets_of_pcp_df(
+        pcp_df, MAX_KNOWN_TOKEN_COUNT
+    )
 
     model = TransformerBinarySelectionModelWiggleAct(

From c3b4abcc1b96e510fc941314326ba6fd42ba6b8b Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Fri, 24 Jan 2025 16:20:10 -0800
Subject: [PATCH 5/6] fix test

---
 tests/test_sequences.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_sequences.py b/tests/test_sequences.py
index e1ee8e5e..a7df4dfd 100644
--- a/tests/test_sequences.py
+++ b/tests/test_sequences.py
@@ -15,7 +15,7 @@
     AA_AMBIG_IDX,
     AMBIGUOUS_CODON_IDX,
     aa_onehot_tensor_of_str,
-    codon_idx_tensor_of_str,
+    codon_idx_tensor_of_str_ambig,
     nt_idx_array_of_str,
     nt_subs_indicator_tensor_of,
     translate_sequences,
@@ -99,7 +99,7 @@ def test_nucleotide_indices_of_codon():
 def test_codon_idx_tensor_of_str():
     nt_str = "AAAAACTTGTTTNTT"
     expected_output = torch.tensor([0, 1, 62, 63, AMBIGUOUS_CODON_IDX])
-    output = codon_idx_tensor_of_str(nt_str)
+    output = codon_idx_tensor_of_str_ambig(nt_str)
     assert torch.equal(output, expected_output)

From 284c2eec907a5ff1a3b8a0d0cdf43e70a752c322 Mon Sep 17 00:00:00 2001
From: Will Dumm
Date: Tue, 28 Jan 2025 15:40:56 -0800
Subject: [PATCH 6/6] address TODOs, format and lint

---
 netam/dasm.py      | 16 ----------------
 netam/dcsm.py      | 44 ++++++++++----------------------------------
 netam/dnsm.py      | 19 +++++++++----------
 netam/dxsm.py      | 17 +++++++++++++++++
 tests/test_dcsm.py | 10 ++--------
 5 files changed, 38 insertions(+), 68 deletions(-)

diff --git a/netam/dasm.py b/netam/dasm.py
index 938c335f..811e30d6 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -177,22 +177,6 @@ def loss_of_batch(self, batch):
         csp_loss = self.xent_loss(csp_pred, csp_targets)
         return torch.stack([subs_pos_loss, csp_loss])
 
-    def build_selection_matrix_from_parent_aa(
-        self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor
-    ):
-        """Build a selection matrix from a single parent amino acid sequence. Inputs are
-        expected to be as prepared in the Dataset constructor.
-
-        Values at ambiguous sites are meaningless.
-        """
-        with torch.no_grad():
-            per_aa_selection_factors = self.selection_factors_of_aa_idxs(
-                aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)
-            ).exp()
-        return zap_predictions_along_diagonal(
-            per_aa_selection_factors, aa_parent_idxs.unsqueeze(0), fill=1.0
-        ).squeeze(0)
-
     # This is not used anywhere, except for in a few tests. Keeping it around
     # for that reason.
def _build_selection_matrix_from_parent(self, parent: Tuple[str, str]): diff --git a/netam/dcsm.py b/netam/dcsm.py index abc57f96..5cf44b37 100644 --- a/netam/dcsm.py +++ b/netam/dcsm.py @@ -2,7 +2,6 @@ import copy -import pandas as pd import torch import torch.nn.functional as F @@ -10,24 +9,15 @@ clamp_probability, BIG, ) -from netam.dxsm import DXSMDataset, DXSMBurrito, zap_predictions_along_diagonal +from netam.dxsm import DXSMDataset, DXSMBurrito import netam.molevol as molevol -from netam.common import aa_idx_tensor_of_str_ambig from netam.sequences import ( - aa_idx_array_of_str, - aa_subs_indicator_tensor_of, build_stop_codon_indicator_tensor, nt_idx_tensor_of_str, - token_mask_of_aa_idxs, - translate_sequence, - translate_sequences, codon_idx_tensor_of_str_ambig, - AA_AMBIG_IDX, AMBIGUOUS_CODON_IDX, CODON_AA_INDICATOR_MATRIX, - RESERVED_TOKEN_REGEX, - MAX_AA_TOKEN_IDX, ) @@ -146,6 +136,8 @@ def __getitem__(self, idx): } def to(self, device): + self.aa_codon_indicator_matrix = self.aa_codon_indicator_matrix.to(device) + self.stop_codon_zapper = self.stop_codon_zapper.to(device) self.codon_parents_idxss = self.codon_parents_idxss.to(device) self.codon_children_idxss = self.codon_children_idxss.to(device) self.aa_parents_idxss = self.aa_parents_idxss.to(device) @@ -166,7 +158,10 @@ class DCSMBurrito(DXSMBurrito): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.xent_loss = torch.nn.CrossEntropyLoss() - self.stop_codon_zapper = build_stop_codon_indicator_tensor() * -BIG + self.stop_codon_zapper = (build_stop_codon_indicator_tensor() * -BIG).to( + self.device + ) + self.aa_codon_indicator_matrix = CODON_AA_INDICATOR_MATRIX.to(self.device).T def prediction_pair_of_batch(self, batch): """Get log neutral codon substitution probabilities and log selection factors @@ -209,13 +204,11 @@ def predictions_of_batch(self, batch): # of the parent codon. Namely, we need to set the parent codon to 1 - # sum(children). - # This indicator lifts things up from aa land to codon land. - # TODO I guess we could store indicator in self and have everything move with a self.to(device) call. - indicator = CODON_AA_INDICATOR_MATRIX.to(self.device).T + # The aa_codon_indicator_matrix lifts things up from aa land to codon land. log_preds = ( log_neutral_codon_probs - + log_selection_factors @ indicator - + self.stop_codon_zapper.to(self.device) + + log_selection_factors @ self.aa_codon_indicator_matrix + + self.stop_codon_zapper ) assert torch.isnan(log_preds).sum() == 0 @@ -258,20 +251,3 @@ def loss_of_batch(self, batch): codon_children_idxs = codon_children_idxs[mask] return self.xent_loss(predictions, codon_children_idxs) - - # TODO copied from dasm.py (updated for new organization from Will's PR) - def build_selection_matrix_from_parent_aa( - self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor - ): - """Build a selection matrix from a single parent amino acid sequence. Inputs are - expected to be as prepared in the Dataset constructor. - - Values at ambiguous sites are meaningless. 
- """ - with torch.no_grad(): - per_aa_selection_factors = self.selection_factors_of_aa_idxs( - aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0) - ).exp() - return zap_predictions_along_diagonal( - per_aa_selection_factors, aa_parent_idxs.unsqueeze(0), fill=1.0 - ).squeeze(0) diff --git a/netam/dnsm.py b/netam/dnsm.py index 120eb8c6..abce4255 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -56,15 +56,14 @@ def update_neutral_probs(self): mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len]) nt_csps = nt_csps[:parent_len, :] - # TODO singular/plural mismatch - neutral_aa_mut_prob = molevol.neutral_aa_mut_probs( + neutral_aa_mut_probs = molevol.neutral_aa_mut_probs( parent_idxs.reshape(-1, 3), mut_probs.reshape(-1, 3), nt_csps.reshape(-1, 3, 4), multihit_model=multihit_model, ) - if not torch.isfinite(neutral_aa_mut_prob).all(): + if not torch.isfinite(neutral_aa_mut_probs).all(): print(f"Found a non-finite neutral_aa_mut_prob") print(f"nt_parent: {nt_parent}") print(f"mask: {mask}") @@ -72,21 +71,21 @@ def update_neutral_probs(self): print(f"nt_csps: {nt_csps}") print(f"branch_length: {branch_length}") raise ValueError( - f"neutral_aa_mut_prob is not finite: {neutral_aa_mut_prob}" + f"neutral_aa_mut_prob is not finite: {neutral_aa_mut_probs}" ) # Ensure that all values are positive before taking the log later - neutral_aa_mut_prob = clamp_probability(neutral_aa_mut_prob) + neutral_aa_mut_probs = clamp_probability(neutral_aa_mut_probs) - pad_len = self.max_aa_seq_len - neutral_aa_mut_prob.shape[0] + pad_len = self.max_aa_seq_len - neutral_aa_mut_probs.shape[0] if pad_len > 0: - neutral_aa_mut_prob = F.pad( - neutral_aa_mut_prob, (0, pad_len), value=1e-8 + neutral_aa_mut_probs = F.pad( + neutral_aa_mut_probs, (0, pad_len), value=1e-8 ) # Here we zero out masked positions. - neutral_aa_mut_prob *= mask + neutral_aa_mut_probs *= mask - neutral_aa_mut_prob_l.append(neutral_aa_mut_prob) + neutral_aa_mut_prob_l.append(neutral_aa_mut_probs) # Note that our masked out positions will have a nan log probability, # which will require us to handle them correctly downstream. diff --git a/netam/dxsm.py b/netam/dxsm.py index 298ffe45..512d3922 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -423,6 +423,23 @@ def to_crepe(self): encoder = framework.PlaceholderEncoder() return framework.Crepe(encoder, self.model, training_hyperparameters) + # This is overridden in DNSMBurrito + def build_selection_matrix_from_parent_aa( + self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor + ): + """Build a selection matrix from a single parent amino acid sequence. Inputs are + expected to be as prepared in the Dataset constructor. + + Values at ambiguous sites are meaningless. 
+ """ + with torch.no_grad(): + per_aa_selection_factors = self.selection_factors_of_aa_idxs( + aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0) + ).exp() + return zap_predictions_along_diagonal( + per_aa_selection_factors, aa_parent_idxs.unsqueeze(0), fill=1.0 + ).squeeze(0) + @abstractmethod def loss_of_batch(self, batch): pass diff --git a/tests/test_dcsm.py b/tests/test_dcsm.py index e3eb6a17..4a3abbec 100644 --- a/tests/test_dcsm.py +++ b/tests/test_dcsm.py @@ -1,14 +1,8 @@ -import os - import torch import pytest -from netam.common import BIG, force_spawn -from netam.framework import ( - crepe_exists, - load_crepe, -) -from netam.sequences import MAX_AA_TOKEN_IDX, MAX_KNOWN_TOKEN_COUNT +from netam.common import force_spawn +from netam.sequences import MAX_KNOWN_TOKEN_COUNT from netam.models import TransformerBinarySelectionModelWiggleAct from netam.dcsm import ( DCSMBurrito,