
Add model_embedding_dim argument to Dataset constructors #107

Merged · 15 commits into main · Jan 24, 2025

Conversation

@willdumm (Contributor) commented Jan 16, 2025

This PR makes DXSM datasets specific to the embedding dimension of the model that the dataset is intended for.
Minimal changes in dnsm-experiments-1 were required, and these are in https://github.com/matsengrp/dnsm-experiments-1/pull/78.

Here are some discussion points:

  • Should the token insertion steps be moved from the load_pcp_df implementation to the dataset constructor? Right now, we do any token insertion (and heavy/light splicing, etc.) any time we load a pcp dataframe, then the Dataset constructor strips out any unrecognized tokens. (A rough sketch of this stripping behavior follows this list.)
  • Should we do anything to ensure that sequences containing X’s aren’t presented to old models that don’t support them? (models have always had an embedding dimension dedicated to X’s)
  • How do we prevent old models that don’t support them from being presented with paired data? If they were, the new dataset construction code would just strip out the separator token.
  • If we give a crepe an aa sequence directly, it may not have the scaffolding tokens that the model has been trained to expect, unless the sequence has been moved through the load_pcp_df pipeline
  • What does the src_key_padding_mask actually do? Where is the documentation for the torch Embedding forward function? The mask does not seem to exempt sites from being looked up in the embedding, in any case. I just want to make sure that we shouldn't be unmasking ambiguous sites before presenting them to the model, like we do with other tokens.
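For concreteness, here is a minimal sketch of the constructor behavior described above. This is hypothetical, not netam's actual code: the class name, the token ordering, and the treatment of the embedding dimension as a known token count are all illustrative assumptions.

```python
# Hypothetical sketch: a Dataset constructor that is told the target model's
# embedding dimension (equivalently, how many tokens the model knows about)
# and strips any tokens beyond that from the input sequences.

# Assumed token ordering: 20 amino acids, then "X" (ambiguous), then "^" (separator).
TOKEN_ORDER = "ACDEFGHIKLMNPQRSTVWY" + "X" + "^"


class DatasetSketch:
    def __init__(self, aa_sequences, model_embedding_dim):
        # Characters whose token index is >= model_embedding_dim are
        # unrecognized by the target model, so they are stripped here.
        recognized = set(TOKEN_ORDER[:model_embedding_dim])
        self.model_embedding_dim = model_embedding_dim
        self.aa_sequences = [
            "".join(ch for ch in seq if ch in recognized) for seq in aa_sequences
        ]
```

Under these assumptions, a model with an embedding dimension of 21 would silently lose the separator token from paired sequences, which is exactly the concern in the paired-data question above.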

@willdumm (Contributor, Author) commented Jan 16, 2025

To recap our conversation just now about this:

> • If we give a crepe an aa sequence directly, it may not have the scaffolding tokens that the model has been trained to expect, unless the sequence has been moved through the load_pcp_df pipeline
> • Should the token insertion steps be moved from the load_pcp_df implementation to the dataset constructor? Right now, we do any token insertion (and heavy/light splicing, etc.) any time we load a pcp dataframe, then the Dataset constructor strips out any unrecognized tokens.

Yes, we've decided so. I'll back up the current working status of this branch (and the corresponding dnsm-experiments-1 branch) to wd-token-fix-bu, then implement the following here:

  • load_pcp_df and associated functions will load pcp_df without inserting any special token scaffolding into the sequences. We will always maintain separate heavy and light-chain columns wherever applicable.
  • There will be a free function that takes, perhaps, pairs of heavy- and light-chain sequences and scaffolds them with special tokens so they can be presented to the model. This function will take a known_token_count so that it knows how to process the sequences. (A rough sketch appears at the end of this comment.)
  • This free function will be called by the Dataset constructor, which will also still need to accept the known_token_count parameter.
  • Calling the model forward/represent functions directly (or through model.call) will require the user to do any sequence token scaffolding on their own, perhaps using the free function mentioned above.
  • Calling the Crepe.__call__ function on sequences will do the required scaffolding automatically. Perhaps we'll have to strip out the model predictions for special token sites, so the outputs match the input sequence lengths? Erick suggests an input format of something like crepe([(heavy, None), (heavy1, light1), ...]) to allow heavy and light chain sequences to be passed to the crepe for proper scaffolding.
> • Should we do anything to ensure that sequences containing X’s aren’t presented to old models that don’t support them? (models have always had an embedding dimension dedicated to X’s)
> • How do we prevent old models that don’t support them from being presented with paired data? If they were, the new dataset construction code would just strip out the separator token.

We'll just use pretrained crepe names to make this clear.

> • What does the src_key_padding_mask actually do? Where is the documentation for the torch Embedding forward function? The mask does not seem to exempt sites from being looked up in the embedding, in any case. I just want to make sure that we shouldn't be unmasking ambiguous sites before presenting them to the model, like we do with other tokens.

We understand that although all sites in the input are included in the embedding, the model ignores masked sites while keeping track of the fact that they take up space. We may want to try unmasking X's that aren't just padding, as a separate issue.
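To make the plan above concrete, here is a rough sketch of the free scaffolding function and the Crepe.__call__ input format. All names, the token ordering, and the separator character are assumptions for illustration, not the actual netam implementation; the Dataset constructor would call the same helper and still take known_token_count.

```python
from typing import Optional, Sequence, Tuple

# Assumed token layout: 20 amino acids, then "X" (ambiguous), then "^" (separator).
TOKEN_ORDER = "ACDEFGHIKLMNPQRSTVWY" + "X" + "^"


def scaffold_aa_pair(heavy: str, light: Optional[str], known_token_count: int) -> str:
    """Join a heavy/light pair with whatever special tokens the model supports."""
    recognized = set(TOKEN_ORDER[:known_token_count])
    if light and "^" in recognized:
        seq = heavy + "^" + light
    else:
        # A model without a separator token gets a plain concatenation here,
        # which is why pretrained crepe names need to make clear what kind of
        # input each model supports.
        seq = heavy + (light or "")
    # Drop any characters the model's embedding doesn't know about.
    return "".join(ch for ch in seq if ch in recognized)


class CrepeSketch:
    """Only the __call__ interface is sketched; the wrapped model is assumed."""

    def __init__(self, model, known_token_count: int):
        self.model = model
        self.known_token_count = known_token_count

    def __call__(self, pairs: Sequence[Tuple[str, Optional[str]]]):
        # Erick's suggested input format: [(heavy, None), (heavy1, light1), ...]
        scaffolded = [scaffold_aa_pair(h, l, self.known_token_count) for h, l in pairs]
        return self.model(scaffolded)
```

And a small self-contained PyTorch example (standard torch, not netam code) of the src_key_padding_mask behavior described just above: every position is still embedded and keeps its slot in the output, but positions marked True are excluded from being attended to.

```python
import torch
import torch.nn as nn

layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=1)

x = torch.randn(1, 5, 8)  # one already-embedded sequence of length 5
# True marks positions (e.g. padding) that other positions should not attend to.
mask = torch.tensor([[False, False, False, True, True]])

out = encoder(x, src_key_padding_mask=mask)
print(out.shape)  # torch.Size([1, 5, 8]): masked sites still occupy their positions
```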

@matsen (Contributor) left a comment

I was able to take a quick look (half attention, as I'm attending a conference)...

netam/dasm.py Outdated
# just have it take aa strings, but that's not what it did before, so I'm
# keeping the original behavior for now. (although, the old docstring
# claimed incorrectly that it took an aa sequence)
def build_selection_matrix_from_parent(self, parent: Tuple[str, str]):
@matsen (Contributor):

Is it not used for branch length optimization?

If not, let's cut it!

@willdumm (Contributor, Author):

It's used in a couple of tests, so I prepended an underscore to its name instead of removing it.

if chosen_v_families is not None:
    chosen_v_families = set(chosen_v_families)
    # TODO is this the right way to handle this? Or should it be OR?
@matsen (Contributor):

I think that we decided on OR.
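For illustration only, an OR-style filter might look like the sketch below; the dataframe column names and the heavy/light-chain interpretation are guesses, not the actual dnsm-experiments-1 code.

```python
import pandas as pd


def filter_by_v_families(pcp_df: pd.DataFrame, chosen_v_families) -> pd.DataFrame:
    """Keep a row if either chain's V family is among the chosen families (OR)."""
    chosen_v_families = set(chosen_v_families)
    keep = pcp_df["v_family_heavy"].isin(chosen_v_families) | pcp_df[
        "v_family_light"
    ].isin(chosen_v_families)
    return pcp_df[keep]
```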

@willdumm marked this pull request as ready for review on January 23, 2025 at 00:28
@willdumm (Contributor, Author) commented:

I am still waiting for all the notebooks to run, and I expect that changes will be needed to make some of them run. However, all tests pass, and the test snakemake configs run for both dnsm and dasm.

@matsen (Contributor) left a comment

🎉

@willdumm merged commit 09d9274 into main on Jan 24, 2025
2 checks passed