QC unmeasured bundle muts #105

Merged: 19 commits, Jul 21, 2023
10 changes: 7 additions & 3 deletions .github/workflows/build_test_package.yml
@@ -40,10 +40,14 @@ jobs:
run: |
pip install --upgrade pip
pip install -e ".[dev]"

- name: Black Format Check
run: |
black --check .
uses: psf/black@stable
with:
options: "--check"
src: "."
jupyter: false
version: "~= 23.3" # this is the version that ships with the vs code extension, currently

- name: Test
run: |
10 changes: 10 additions & 0 deletions CHANGELOG.rst
@@ -6,6 +6,16 @@ All notable changes to this project will be documented in this file.

The format is based on `Keep a Changelog <https://keepachangelog.com>`_.


Bleeding Edge (main)
--------------------
- Closed [docs test issue](https://github.com/matsengrp/multidms/issues/104), thanks @WSDeWitt!
- Cleaned up the GitHub Actions workflows, again thanks to @WSDeWitt
- Fixed [bug in wildtype predictions](https://github.com/matsengrp/multidms/issues/106)
- Implemented [QC on invalid bundle muts](https://github.com/matsengrp/multidms/issues/84), as pointed out by @Haddox
- A few other minor cleanup tasks


0.1.9
-----
- First release on PyPI
14 changes: 6 additions & 8 deletions CONTRIBUTING.rst
@@ -1,7 +1,8 @@
=====================================
How to contribute to this package
=====================================
============
Contributing
============

We welcome contributions to `multidms <multidms>`!
This document describes how to edit the package, run the tests, build the docs, put tagged versions on PyPI_, etc.

Editing the project
@@ -63,11 +64,8 @@ These can include:
Running the tests locally
++++++++++++++++++++++++++
After you make changes, you should run two sets of tests.
To run the tests, go to the top-level packag directory.
Then make sure that you have installed the packages listed in `test_requirements.txt <test_requirements.txt>`_.
If these are not installed, install them with::

pip install -r test_requirements.txt
To run the tests, go to the top-level package directory,
making sure to install and activate a `development environment <https://matsengrp.github.io/multidms/installation.html>`_.

Then use ruff_ to `lint the code <https://en.wikipedia.org/wiki/Lint_%28software%29>`_ by running::

5 changes: 3 additions & 2 deletions README.md
@@ -1,17 +1,18 @@
# multidms

![License](https://img.shields.io/github/license/matsengrp/multidms)
[![PyPI version](https://badge.fury.io/py/multidms.svg)](https://badge.fury.io/py/multidms)
[![Build](https://github.com/matsengrp/multidms/actions/workflows/build_test_package.yml/badge.svg)](https://github.com/matsengrp/multidms/actions/workflows/build_test_package.yml)
![License](https://img.shields.io/github/license/matsengrp/multidms)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)

`multidms` is a Python package written by the
[Matsen group](https://matsen.fhcrc.org/)
in collaboration with
[William DeWitt](https://wsdewitt.github.io/)
and the
[Bloom Lab](https://research.fhcrc.org/bloom/en.html).
It can be used to fit a single global-epistasis model to one or more deep mutational scanning experiments,
It can be used to jointly fit a global-epistasis model to one or more deep mutational scanning experiments,
with the goal of estimating the effects of individual mutations,
and how much the effects differ between experiments.
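As a rough sketch of what that joint-fitting workflow looks like (a hypothetical example: the data schema, constructor arguments, and method names below are assumptions for illustration, not the exact API; see the documentation):

```python
import pandas as pd
import multidms

# toy functional-score data with an assumed schema:
# one row per variant, per experimental condition
func_score_df = pd.DataFrame(
    {
        "condition": ["Delta", "Delta", "Omicron"],
        "aa_substitutions": ["M1E", "M1E K3R", "K3R"],
        "func_score": [-0.4, -1.2, -0.1],
    }
)

# build the shared dataset and fit one model jointly across conditions
# (argument and method names are illustrative assumptions)
data = multidms.Data(func_score_df, reference="Delta")
model = multidms.Model(data)
model.fit()
print(model.mutations_df.head())  # per-mutation effects and shifts between conditions
```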

1 change: 1 addition & 0 deletions docs/changelog.rst
1 change: 1 addition & 0 deletions docs/contributing.rst
4 changes: 1 addition & 3 deletions docs/index.rst
@@ -40,10 +40,8 @@ See below for information and examples of how to use this package.
fit_delta_BA1_example
multidms
acknowledgments
..
jit model composition
using with GPU's
contributing
changelog


Indices and tables
20 changes: 9 additions & 11 deletions docs/installation.rst
@@ -7,21 +7,19 @@ The source code for ``multidms`` is available on GitHub at https://github.com/ma

The easiest way to install ``multidms`` is using ``pip``:

... code-block::
.. code-block::

pip install multidms

The fitting process can be quite computationally intensive,
and if available, we recommend using a GPU to accelerate the fitting process.
While ``multidms`` ships with ``jax`` and ``jaxlib`` as dependencies,
these packages do not include CUDA support by default.
For this, please update the jax installation in your environment to include CUDA support
by following the instructions in the
`jax documentation <https://github.com/google/jax#pip-installation-gpu-cuda-installed-via-pip-easier>`_.

.. code-block::
.. note::

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
The `multidms.Model` fitting process can be quite computationally intensive,
and if available, we recommend using a GPU to accelerate it.
While ``multidms`` ships with ``jax`` and ``jaxlib`` as dependencies,
these packages do not include CUDA support by default.
For this, please update the jax installation in your environment to include CUDA support
by following the instructions in the
`jax documentation <https://github.com/google/jax#pip-installation-gpu-cuda-installed-via-pip-easier>`_.
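After updating ``jax`` per those instructions, a quick sanity check (a minimal snippet, not part of ``multidms`` itself) is to confirm that jax actually sees the GPU:

```python
import jax

# "gpu" means a CUDA-enabled jaxlib is installed and a device is visible;
# otherwise jax silently falls back to "cpu".
print(jax.default_backend())
print(jax.devices())
```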

Developer install
-----------------
6 changes: 3 additions & 3 deletions multidms/biophysical.py
@@ -181,7 +181,7 @@ def nn_global_epistasis(theta: dict, z_d: jnp.array):
- :math:`w^{l}_{i}` and :math:`w^{o}_{i}` are free parameters representing latent
and output transformations, respectively, associated with unit `i` in the
hidden layer of the network.
- :math:`b^{l}_{i}` is a free parameter, as an added bias term to unit $i$.
- :math:`b^{l}_{i}` is a free parameter, as an added bias term to unit `i`.
- :math:`b^{o}` is a constant, singular free parameter.

.. Note::
@@ -234,8 +234,8 @@ def softplus_activation(d_params, act, lower_bound=-3.5, hinge_scale=0.1, **kwar
with a lower bound at :math:`l + \gamma_{h}`,
as well as a ramping coefficient, :math:`\lambda_{\text{sp}}`.

Concretely, if we let $z' = g(\phi_d(v))$, then the predicted functional score of
our model is given by:
Concretely, if we let :math:`z' = g(\phi_d(v))`, then the predicted functional score
of our model is given by:

.. math::
t(z') = \lambda_{sp}\log(1 + e^{\frac{z' - l}{\lambda_{sp}}}) + l
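As a stand-alone illustration of that clipped-softplus output activation (a minimal sketch of the equation above, not the package's exact implementation), where ``hinge_scale`` plays the role of :math:`\lambda_{sp}` and ``lower_bound`` the role of :math:`l`:

```python
import jax.numpy as jnp

def softplus_lower_bound(z_prime, lower_bound=-3.5, hinge_scale=0.1):
    # smooth ramp that approaches `lower_bound` from above for very negative
    # latent phenotypes and is approximately linear for large ones
    return hinge_scale * jnp.log(1 + jnp.exp((z_prime - lower_bound) / hinge_scale)) + lower_bound
```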
130 changes: 101 additions & 29 deletions multidms/data.py
@@ -272,15 +272,19 @@ def __init__(
for hom, hom_func_df in df.groupby("condition"):
if verbose:
print(f"inferring site map for {hom}")
for idx, row in hom_func_df.iterrows():
for idx, row in tqdm(
hom_func_df.iterrows(), total=len(hom_func_df), disable=not verbose
):
for wt, site in zip(row.wts, row.sites):
site_map.loc[site, hom] = wt

if assert_site_integrity:
if verbose:
print("Asserting site integrity")
for hom, hom_func_df in df.groupby("condition"):
for idx, row in hom_func_df.iterrows():
for idx, row in tqdm(
hom_func_df.iterrows(), total=len(hom_func_df), disable=not verbose
):
for wt, site in zip(row.wts, row.sites):
assert site_map.loc[site, hom] == wt

@@ -290,7 +294,7 @@ def __init__(
sites_to_throw = na_rows[na_rows].index
site_map.dropna(inplace=True)

nb_workers = os.cpu_count() if not nb_workers else nb_workers
nb_workers = min(os.cpu_count(), 4) if nb_workers is None else nb_workers
pandarallel.initialize(progress_bar=verbose, nb_workers=nb_workers)

def flags_invalid_sites(disallowed_sites, sites_list):
@@ -302,47 +306,115 @@ def flags_invalid_sites(disallowed_sites, sites_list):
return False
return True

df["allowed_variant"] = df.sites.apply(
df["allowed_variant"] = df.sites.parallel_apply(
lambda sl: flags_invalid_sites(sites_to_throw, sl)
)
len(df)
df = df[df["allowed_variant"]]
if verbose:
print(
f"unknown cond wildtype at sites: {list(sites_to_throw.values)},"
f"\ndropping: {len(df) - len(df[df['allowed_variant']])} variants"
"which have mutations at those sites."
)

df.query("allowed_variant", inplace=True)
df.drop("allowed_variant", axis=1, inplace=True)
site_map.sort_index(inplace=True)

def get_nis_from_site_map(site_map):
"""Get non-identical sites from a site map"""
non_identical_sites = {}
reference_sequence_conditions = [self._reference]
for condition in self._conditions:
if condition == self._reference:
non_identical_sites[condition] = []
continue

nis = site_map.where(
site_map[self.reference] != site_map[condition],
).dropna()

if len(nis) == 0:
non_identical_sites[condition] = []
reference_sequence_conditions.append(condition)
else:
non_identical_sites[condition] = nis[[self._reference, condition]]
return non_identical_sites, reference_sequence_conditions

self._site_map = site_map.sort_index()
(non_identical_sites, reference_sequence_conditions) = get_nis_from_site_map(
site_map
)

# identify and write site map differences for each condition
non_identical_mutations = {}
non_identical_sites = {}
self._reference_sequence_conditions = [self._reference]
for condition in self._conditions:
if condition == self._reference:
non_identical_mutations[condition] = ""
non_identical_sites[condition] = []
# invalid nis see https://github.com/matsengrp/multidms/issues/84
observed_ref_muts = (
df.query("condition == @self.reference")
.aa_substitutions.str.split()
.explode()
.unique()
)
invalid_nim = []
for condition in self.conditions:
if (
condition == self.reference
or condition in reference_sequence_conditions
):
continue
observed_cond_muts = (
df.query("condition == @condition")
.aa_substitutions.str.split()
.explode()
.unique()
)
for site, cond_wts in non_identical_sites[condition].iterrows():
ref_wt, cond_wt = cond_wts[self.reference], cond_wts[condition]
forward_mut = f"{ref_wt}{site}{cond_wt}"
reversion_mut = f"{cond_wt}{site}{ref_wt}"

condition_1 = forward_mut in observed_ref_muts
condition_2 = reversion_mut in observed_cond_muts
if not (condition_1 and condition_2):
invalid_nim.append(site)

# find variants that contain mutations at invalid sites
df["allowed_variant"] = df.sites.parallel_apply(
lambda sl: flags_invalid_sites(invalid_nim, sl)
)
if verbose:
print(
f"invalid non-identical-sites: {invalid_nim}, dropping "
f"{len(df) - len(df[df['allowed_variant']])} variants"
)

nis = self._site_map.where(
self._site_map[self.reference] != self.site_map[condition],
).dropna()
# drop variants that contain mutations at invalid sites
df.query("allowed_variant", inplace=True)
df.drop("allowed_variant", axis=1, inplace=True)

if len(nis) == 0:
non_identical_mutations[condition] = ""
non_identical_sites[condition] = []
self._reference_sequence_conditions.append(condition)
else:
muts = nis[self._reference] + nis.index.astype(str) + nis[condition]
muts_string = " ".join(muts.values)
non_identical_mutations[condition] = muts_string
non_identical_sites[condition] = nis[[self._reference, condition]]
# drop invalid sites from site map
self._site_map = site_map.drop(invalid_nim, inplace=False)

# recompute non-identical sites for static property
(
self._non_identical_sites,
self._reference_sequence_conditions,
) = get_nis_from_site_map(self._site_map)

# compute the static non_identical_mutations property
non_identical_mutations = {}
for condition in self.conditions:
if condition in self.reference_sequence_conditions:
non_identical_mutations[condition] = ""
continue
nis = self.non_identical_sites[condition]
muts = nis[self.reference] + nis.index.astype(str) + nis[condition]
muts_string = " ".join(muts.values)
non_identical_mutations[condition] = muts_string
self._non_identical_mutations = non_identical_mutations
self._non_identical_sites = non_identical_sites

# compute all substitution conversions for all conditions which
# do not share the reference sequence
df = df.assign(var_wrt_ref=df["aa_substitutions"])
for condition, condition_func_df in df.groupby("condition"):
if verbose:
print(f"Converting mutations for {condition}")

if condition in self.reference_sequence_conditions:
if verbose:
print("is reference, skipping")
6 changes: 3 additions & 3 deletions multidms/model.py
@@ -554,7 +554,7 @@ def add_phenotypes_to_df(
for condition, condition_df in df.groupby(condition_col):
variant_subs = condition_df[substitutions_col]
if condition not in self.data.reference_sequence_conditions:
variant_subs = condition_df.parallel_apply(
variant_subs = condition_df.apply(
lambda x: self.data.convert_subs_wrt_ref_seq(
condition, x[substitutions_col]
),
@@ -683,7 +683,7 @@ def phenotype_fromsubs(self, aa_subs, condition=None):
)
]
)
return self.phenotype_frombinary(X)
return self.phenotype_frombinary(X, condition)

def latent_fromsubs(self, aa_subs, condition=None):
"""
Expand All @@ -699,7 +699,7 @@ def latent_fromsubs(self, aa_subs, condition=None):
)
]
)
return self.latent_frombinary(X)
return self.latent_frombinary(X, condition)

def phenotype_frombinary(self, X, condition=None):
"""
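The fix above passes ``condition`` through to ``phenotype_frombinary`` and ``latent_frombinary``. A hypothetical call on a fitted ``multidms.Model`` (the substitution string and condition name are illustrative):

```python
# aa_subs is a space-delimited substitution string, as elsewhere in multidms
score = model.phenotype_fromsubs("M1E K3R", condition="Omicron")
latent = model.latent_fromsubs("M1E K3R", condition="Omicron")
```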
7 changes: 6 additions & 1 deletion multidms/plot.py
@@ -708,9 +708,14 @@ def mut_shift_plot(
mutations_dfs,
)

# for now, we're simply dropping the functional scores
mut_df.drop(
[c for c in mut_df.columns if "func_score" in c], axis=1, inplace=True
)

# now compute replicate averages
for c in fit.mutations_df.columns:
if c == "mutation" or "times_seen" in c:
if c == "mutation" or "times_seen" in c or "func_score" in c:
continue
cols_to_combine = [f"{replicate}_{c}" for replicate in fit_data.keys()]
if c in ["wts", "sites", "muts"]:
2 changes: 1 addition & 1 deletion multidms/utils.py
@@ -194,7 +194,7 @@ def fit_wrapper(

print(
f"training_step {training_step}/{num_training_steps},"
"Loss: {imodel.loss}, Time: {fit_time} Seconds",
f"Loss: {imodel.loss}, Time: {fit_time} Seconds",
flush=True,
)
