From 7f7517b4dea4c74f607ef7795d9e1b53b174de1b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Mon, 18 Nov 2024 22:41:49 +0100
Subject: [PATCH 01/10] WIP: Annotate suspicious functions

---
 src/covvfit/_quasimultinomial.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/covvfit/_quasimultinomial.py b/src/covvfit/_quasimultinomial.py
index ee3e9f1..a19eb54 100644
--- a/src/covvfit/_quasimultinomial.py
+++ b/src/covvfit/_quasimultinomial.py
@@ -109,7 +109,7 @@ def convert(confidence: float) -> float:
         Example:
             StandardErrorsMultipliers.convert(0.95)  # 1.9599
         """
-        return float(jax.scipy.stats.norm.ppf((1 + confidence) / 2))
+        return float(jax.scipy.stats.norm.ppf((1 + confidence) / 2.0))
 
 
 def get_covariance(
@@ -185,7 +185,7 @@ def get_confidence_intervals(
     Assumes a normal distribution for the estimates.
     """
     # Calculate the multiplier based on the confidence level
-    z_score = jax.scipy.stats.norm.ppf((1 + confidence_level) / 2)
+    z_score = StandardErrorsMultipliers.convert(confidence_level)
 
     # Compute the lower and upper bounds of the confidence intervals
     lower_bound = estimates - z_score * standard_errors
@@ -269,6 +269,7 @@ def get_confidence_bands_logit(
         A list of dictionaries for each city, each with "lower" and "upper" bounds
         for the confidence intervals on the linear scale.
     """
+    # TODO(Pawel): Potentially fix the signature of this function.
 
     y_fit_lst_logit = [
         get_logit_predictions(solution_x, variants_count, i, ts).T[1:, :]
         for i, ts in enumerate(ts_lst_scaled)
     ]
@@ -322,6 +323,8 @@ def get_relative_advantages(theta, n_variants: int):
 def get_softmax_predictions(
     theta: _ThetaType, n_variants: int, city_index: int, ts: Float[Array, " timepoints"]
 ) -> Float[Array, "timepoints variants"]:
+    # TODO(Pawel): Potentially fix the signature of this function.
+
     rel_growths = get_relative_growths(theta, n_variants=n_variants)
 
     growths = _add_first_variant(rel_growths)
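For orientation: the helper touched above converts a confidence level into the corresponding normal quantile, and `get_confidence_intervals` now reuses it instead of calling `norm.ppf` directly. A minimal sketch of the intended usage (assuming the module is importable as `covvfit.quasimultinomial`, as the notebooks below do):

    from covvfit import quasimultinomial as qm

    multiplier = qm.StandardErrorsMultipliers.convert(0.95)  # approx. 1.96
    # Symmetric normal-approximation interval:
    # lower = estimate - multiplier * standard_error
    # upper = estimate + multiplier * standard_error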
From b20b75404386892a253d03e93ba083f2430a6b8a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Tue, 19 Nov 2024 10:17:58 +0100
Subject: [PATCH 02/10] Adjust public API

---
 src/covvfit/__init__.py          | 6 ++----
 src/covvfit/plotting/__init__.py | 6 ++++--
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/covvfit/__init__.py b/src/covvfit/__init__.py
index 826ecec..cf7ed7b 100644
--- a/src/covvfit/__init__.py
+++ b/src/covvfit/__init__.py
@@ -10,8 +10,8 @@
 )
     simulation = None
 
+import covvfit._preprocess_abundances as preprocess
 import covvfit.plotting as plot
-from covvfit._preprocess_abundances import load_data, make_data_list, preprocess_df
 from covvfit._splines import create_spline_matrix
 
 VERSION = "0.1.0"
@@ -19,11 +19,9 @@
 
 __all__ = [
     "create_spline_matrix",
-    "make_data_list",
-    "preprocess_df",
-    "load_data",
     "VERSION",
     "quasimultinomial",
     "plot",
+    "preprocess",
     "simulation",
 ]
diff --git a/src/covvfit/plotting/__init__.py b/src/covvfit/plotting/__init__.py
index f108b7a..738d047 100644
--- a/src/covvfit/plotting/__init__.py
+++ b/src/covvfit/plotting/__init__.py
@@ -1,8 +1,9 @@
 """Plotting functionalities."""
 
+import covvfit.plotting._timeseries as timeseries
 from covvfit.plotting._grid import plot_grid, set_axis_off
 from covvfit.plotting._simplex import plot_on_simplex
-from covvfit.plotting._timeseries import colors_covsp, make_legend, num_to_date
+from covvfit.plotting._timeseries import COLORS_COVSPECTRUM, make_legend, num_to_date
 
 __all__ = [
     "plot_on_simplex",
@@ -10,5 +11,6 @@
     "set_axis_off",
     "make_legend",
     "num_to_date",
-    "colors_covsp",
+    "timeseries",
+    "COLORS_COVSPECTRUM",
 ]
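After this change the package is meant to be consumed through the `preprocess` and `plot` namespaces instead of individually re-exported helpers. A sketch of the new import style (the CSV path is hypothetical):

    from covvfit import plot, preprocess

    df = preprocess.load_data("deconvolved.csv")
    legend_handles = plot.make_legend(colors=["#D16666"], variants=["B.1.1.7"])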
From d592e1cf2a69a6c950801f8e9f5a589b3c6d8358 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Tue, 19 Nov 2024 10:22:53 +0100
Subject: [PATCH 03/10] Refactor preprocessing

---
 src/covvfit/_preprocess_abundances.py | 111 +++++++++++++++++++++++---
 1 file changed, 99 insertions(+), 12 deletions(-)

diff --git a/src/covvfit/_preprocess_abundances.py b/src/covvfit/_preprocess_abundances.py
index 4922676..cabf1ef 100644
--- a/src/covvfit/_preprocess_abundances.py
+++ b/src/covvfit/_preprocess_abundances.py
@@ -1,6 +1,7 @@
 """utilities to preprocess relative abundances"""
 
 import pandas as pd
+from jaxtyping import Array, Float
 
 
 def load_data(file) -> pd.DataFrame:
@@ -13,34 +14,56 @@ def preprocess_df(
     df: pd.DataFrame,
     cities: list[str],
     variants: list[str],
-    undertermined_thresh: float = 0.01,
+    *,
+    undetermined_thresh: float = 0.01,
     zero_date: str = "2023-01-01",
     date_min: str | None = None,
     date_max: str | None = None,
+    time_col: str = "time",
+    city_col: str = "city",
+    undetermined_col: str | None = "undetermined",
 ) -> pd.DataFrame:
-    """Preprocessing."""
+    """Preprocessing function.
+
+    Args:
+        df: data frame with data
+        cities: cities for which the data will be processed
+        variants: variants which will be processed, they should
+            be represented by different columns in `df`
+        undetermined_thresh: threshold of the `undetermined` variant
+            used to remove days with too many missing values
+        zero_date: reference time point
+        date_min: the lower bound of the data to be selected.
+            Set to `None` to not set a bound
+        date_max: see `date_min`
+        time_col: column with dates representing days
+        city_col: column with cities
+        undetermined_col: column with the undetermined variant.
+            Set to `None` to not remove any data
+    """
     df = df.copy()
 
     # Convert the 'time' column to datetime
-    df["time"] = pd.to_datetime(df["time"])
+    df[time_col] = pd.to_datetime(df[time_col])
 
     # Remove days with too high undetermined
-    df = df[df["undetermined"] < undertermined_thresh]  # pyright: ignore
+    if undetermined_col is not None:
+        df = df[df[undetermined_col] < undetermined_thresh]  # pyright: ignore
 
-    # Subset the 'BQ.1.1' column
-    df = df[["time", "city"] + variants]  # pyright: ignore
+    # Subset the columns corresponding to variants
+    df = df[[time_col, city_col] + variants]  # pyright: ignore
 
     # Subset only the specified cities
-    df = df[df["city"].isin(cities)]  # pyright: ignore
+    df = df[df[city_col].isin(cities)]  # pyright: ignore
 
     # Create a new column which is the difference in days between zero_date and the date
-    df["days_from"] = (df["time"] - pd.to_datetime(zero_date)).dt.days
+    df["days_from"] = (df[time_col] - pd.to_datetime(zero_date)).dt.days
 
     # Subset dates
     if date_min is not None:
-        df = df[df["time"] >= pd.to_datetime(date_min)]  # pyright: ignore
+        df = df[df[time_col] >= pd.to_datetime(date_min)]  # pyright: ignore
     if date_max is not None:
-        df = df[df["time"] < pd.to_datetime(date_max)]  # pyright: ignore
+        df = df[df[time_col] < pd.to_datetime(date_max)]  # pyright: ignore
 
     return df
@@ -49,13 +72,77 @@ def make_data_list(
     df: pd.DataFrame,
     cities: list[str],
     variants: list[str],
-) -> tuple[list, list]:
+) -> tuple[
+    list[Float[Array, " timepoints"]], list[Float[Array, "timepoints variants"]]
+]:
     ts_lst = [df[(df.city == city)].days_from.values for city in cities]
     ys_lst = [
-        df[(df.city == city)][variants].values.T for city in cities
+        df[(df.city == city)][variants].values for city in cities
     ]  # pyright: ignore
 
+    # TODO(David, Pawel): How should we handle this case?
+    #   It *implicitly* changes the output data type, depending on the input value.
+    #   Do we even use this feature?
     if "count_sum" in df.columns:
         ns_lst = [df[(df.city == city)].count_sum.values for city in cities]
         return (ts_lst, ys_lst, ns_lst)
     else:
         return (ts_lst, ys_lst)
+
+
+_ListTimeSeries = list[Float[Array, " timeseries"]]
+
+
+class TimeScaler:
+    """Scales a list of time series, so that the values are normalized."""
+
+    def __init__(self):
+        self.t_min = None
+        self.t_max = None
+        self._fitted = False
+
+    def fit(self, ts: _ListTimeSeries) -> None:
+        """Fit the scaler parameters to the provided time series.
+
+        Args:
+            ts: list of timeseries, i.e., `ts[i]` is an array
+                of some length `n_timepoints[i]`.
+        """
+        self.t_min = min([x.min() for x in ts])
+        self.t_max = max([x.max() for x in ts])
+        self._fitted = True
+
+    def transform(self, ts: _ListTimeSeries) -> _ListTimeSeries:
+        """Returns scaled values.
+
+        Args:
+            ts: list of timeseries, i.e., `ts[i]` is an array
+                of some length `n_timepoints[i]`.
+
+        Returns:
+            list of exactly the same format as `ts`
+
+        Note:
+            The model has to be fitted first.
+        """
+        if not self._fitted:
+            raise RuntimeError("You need to fit the model first.")
+
+        denominator = self.t_max - self.t_min
+        return [(x - self.t_min) / denominator for x in ts]
+
+    def fit_transform(self, ts: _ListTimeSeries) -> _ListTimeSeries:
+        """Fits the model and returns scaled values.
+
+        Args:
+            ts: list of timeseries, i.e., `ts[i]` is an array
+                of some length `n_timepoints[i]`.
+
+        Returns:
+            list of exactly the same format as `ts`
+
+        Note:
+            This function is equivalent to calling
+            first the `fit` method and then `transform`.
+        """
+        self.fit(ts)
+        return self.transform(ts)
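The `TimeScaler` added above follows the familiar fit/transform convention: `fit` records the global minimum and maximum over all series, and `transform` maps them onto the unit interval. A minimal sketch with synthetic arrays for two cities:

    import jax.numpy as jnp
    from covvfit import preprocess

    ts = [jnp.array([0.0, 10.0, 20.0]), jnp.array([5.0, 25.0])]
    scaler = preprocess.TimeScaler()
    ts_scaled = scaler.fit_transform(ts)  # global min -> 0.0, global max -> 1.0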
From 109eff426e0e0b4ac101607c4c2b7ebad25689ad Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Tue, 19 Nov 2024 10:23:37 +0100
Subject: [PATCH 04/10] Add type annotations and change accepted shapes

---
 src/covvfit/plotting/_timeseries.py | 92 ++++++++++++++++++++++-------
 1 file changed, 70 insertions(+), 22 deletions(-)

diff --git a/src/covvfit/plotting/_timeseries.py b/src/covvfit/plotting/_timeseries.py
index cf88cf2..2dace61 100644
--- a/src/covvfit/plotting/_timeseries.py
+++ b/src/covvfit/plotting/_timeseries.py
@@ -1,11 +1,17 @@
 """utilities to plot"""
+from typing import Literal
 
 import matplotlib.lines as mlines
 import matplotlib.patches as mpatches
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+from jaxtyping import Array, Float
 
-colors_covsp = {
+Variant = str
+Color = str
+
+COLORS_COVSPECTRUM: dict[Variant, Color] = {
     "B.1.1.7": "#D16666",
     "B.1.351": "#FF6666",
     "P.1": "#FFB3B3",
@@ -31,7 +37,7 @@
 }
 
 
-def make_legend(colors, variants):
+def make_legend(colors: list[Color], variants: list[Variant]) -> list[mpatches.Patch]:
     """make a shared legend for the plot"""
     # Create a patch (i.e., a colored box) for each variant
     variant_patches = [
@@ -59,13 +65,23 @@
     return handles
 
 
-def num_to_date(num, date_min, pos=None, fmt="%b. '%y"):
+def num_to_date(
+    num: pd.Series | Float[Array, " timepoints"], date_min: str, fmt="%b. '%y"
+) -> pd.Series:
     """convert days number into a date format"""
     date = pd.to_datetime(date_min) + pd.to_timedelta(num, "D")
     return date.strftime(fmt)
 
 
-def plot_fit(ax, ts, y_fit, variants, colors, linetype="-", **kwargs):
+def plot_fit(
+    ax: plt.Axes,
+    ts: Float[Array, " timepoints"],
+    y_fit: Float[Array, "timepoints variants"],
+    variants: list[Variant],
+    colors: list[Color],
+    linetype="-",
+    **kwargs,
+) -> None:
     """
     Function to plot fitted values with customizable line type.
 
@@ -81,7 +97,7 @@
     for i, variant in enumerate(variants):
         ax.plot(
             ts[sorted_indices],
-            y_fit[i, :][sorted_indices],
+            y_fit[sorted_indices, i],
            color=colors[i],
             linestyle=linetype,
             label=f"fit {variant}",
@@ -89,48 +105,80 @@
         )
 
 
-def plot_complement(ax, ts, y_fit, variants, color="grey", linetype="-"):
+def plot_complement(
+    ax: plt.Axes,
+    ts: Float[Array, " timepoints"],
+    y_fit: Float[Array, "timepoints variants"],
+    color: str = "grey",
+    linetype: str = "-",
+    **kwargs,
+) -> None:
     ## function to plot 1-sum(fitted_values) i.e., the other variant(s)
     sorted_indices = np.argsort(ts)
     ax.plot(
         ts[sorted_indices],
-        (1 - y_fit.sum(axis=0))[sorted_indices],
+        (1 - y_fit.sum(axis=-1))[sorted_indices],
         color=color,
         linestyle=linetype,
+        **kwargs,
     )
 
 
-def plot_data(ax, ts, ys, variants, colors):
+def plot_data(
+    ax: plt.Axes,
+    ts: Float[Array, " timepoints"],
+    ys: Float[Array, "timepoints variants"],
+    colors: list[Color],
+    size: float | int = 4,
+    alpha: float = 0.5,
+    **kwargs,
+) -> None:
     ## function to plot raw values
-    for i, variant in enumerate(variants):
-        ax.scatter(ts, ys[i, :], label="observed", alpha=0.5, color=colors[i], s=4)
+    for i in range(ys.shape[-1]):
+        ax.scatter(
+            ts,
+            ys[:, i],
+            label="observed",
+            alpha=alpha,
+            color=colors[i],
+            s=size,
+            **kwargs,
+        )
 
 
 def plot_confidence_bands(
-    ax, ts, conf_bands, variants, colors, label="Confidence band", alpha=0.2
-):
+    ax: plt.Axes,
+    ts: Float[Array, " timepoints"],
+    conf_bands: dict[Literal["lower", "upper"], Float[Array, "timepoints variants"]],
+    colors: list[Color],
+    label: str = "Confidence band",
+    alpha: float = 0.2,
+    **kwargs,
+) -> None:
     """
     Plot confidence intervals for fitted values on a given axis with customizable confidence level.
 
     Parameters:
-        ax (matplotlib.axes.Axes): The axis to plot on.
-        ts (array-like): Time series data.
-        y_fit_logit (array-like): Logit-transformed fitted values for each variant.
-        logit_se (array-like): Standard errors for the logit-transformed fitted values.
-        color (str): Color for the confidence interval.
-        confidence (float, optional): Confidence level (e.g., 0.95 for 95%). Default is 0.95.
-        label (str, optional): Label for the confidence band. Default is "Confidence band".
+        ax: The axis to plot on.
+        ts: Time series data.
+        color: Color for the confidence interval.
+        label: Label for the confidence band. Default is "Confidence band".
+        alpha: Alpha level controlling the opacity.
+        **kwargs: Additional keyword arguments for `ax.fill_between`.
     """
     # Sort indices for time series
     sorted_indices = np.argsort(ts)
 
+    n_variants = conf_bands["lower"].shape[-1]
+
     # Plot the confidence interval
-    for i, variant in enumerate(variants):
+    for i in range(n_variants):
         ax.fill_between(
             ts[sorted_indices],
-            conf_bands["lower"][i][sorted_indices],
-            conf_bands["upper"][i][sorted_indices],
+            conf_bands["lower"][sorted_indices][i],
+            conf_bands["upper"][sorted_indices][i],
             color=colors[i],
             alpha=alpha,
             label=label,
+            **kwargs,
         )
+variants_of_interest = ["KP.2", "KP.3", "XEC"] # Variants of interest +variants_other = [ + i for i in variants_full if i not in variants_of_interest +] # Variants not of interest # - -cities = list(data_wide["city"].unique()) +# Apart from the variants of interest, we define the "other" variant, which artificially merges all the other variants into one. This allows us to model the data as a compositional time series, i.e., the sum of abundances of all "variants" is normalized to one. -variants2 = ["other"] + variants -data2 = prec.preprocess_df( +# + +variants_effective = ["other"] + variants_of_interest +data_full = preprocess.preprocess_df( data_wide, cities, variants_full, date_min=start_date, zero_date=start_date ) -# + -data2["other"] = data2[variants_other].sum(axis=1) -data2[variants2] = data2[variants2].div(data2[variants2].sum(axis=1), axis=0) - -ts_lst, ys_lst = prec.make_data_list(data2, cities, variants2) -ts_lst, ys_lst2 = prec.make_data_list(data2, cities, variants) +data_full["other"] = data_full[variants_other].sum(axis=1) +data_full[variants_effective] = data_full[variants_effective].div( + data_full[variants_effective].sum(axis=1), axis=0 +) -t_max = max([x.max() for x in ts_lst]) -t_min = min([x.min() for x in ts_lst]) +# + +_, ys_effective = preprocess.make_data_list(data_full, cities, variants_effective) +ts_lst, ys_of_interest = preprocess.make_data_list( + data_full, cities, variants_of_interest +) -ts_lst_scaled = [(x - t_min) / (t_max - t_min) for x in ts_lst] +# Scale the time for numerical stability +t_scaler = preprocess.TimeScaler() +ts_lst_scaled = t_scaler.fit_transform(ts_lst) # - -# # fit in jax +# ## Fit the quasimultinomial model +# +# Now we fit the quasimultinomial model, which allows us to find the maximum quasilikelihood estimate of the parameters: # + # %%time -# Recall that the input should be (n_timepoints, n_variants) -# TODO(Pawel, David): Resolve Issue https://github.com/cbg-ethz/covvfit/issues/24 -observed_data = [y.T for y in ys_lst] - - # no priors loss = qm.construct_total_loss( - ys=observed_data, + ys=ys_effective, ts=ts_lst_scaled, average_loss=False, # Do not average the loss over the data points, so that the covariance matrix shrinks with more and more data added ) +n_variants_effective = len(variants_effective) + # initial parameters -theta0 = qm.construct_theta0(n_cities=len(cities), n_variants=len(variants2)) +theta0 = qm.construct_theta0(n_cities=len(cities), n_variants=n_variants_effective) # Run the optimization routine solution = qm.jax_multistart_minimize(loss, theta0, n_starts=10) + +theta_star = solution.x # The maximum quasilikelihood estimate + +print( + f"Relative growth rates: \n", + qm.get_relative_growths(theta_star, n_variants=n_variants_effective), +) # - # ## Make fitted values and confidence intervals -# + ## compute fitted values fitted_values = qm.fitted_values( - ts_lst_scaled, theta=solution.x, cities=cities, n_variants=len(variants2) + ts_lst_scaled, theta=theta_star, cities=cities, n_variants=n_variants_effective ) + +# + +# TODO(Pawel): Refactor this out!!!!! # ... 
and because of https://github.com/cbg-ethz/covvfit/issues/24 # we need to transpose again y_fit_lst = [y.T[1:] for y in fitted_values] ## compute covariance matrix -covariance = qm.get_covariance(loss, solution.x) +covariance = qm.get_covariance(loss, theta_star) overdispersion_tuple = qm.compute_overdispersion( observed=observed_data, @@ -171,7 +185,8 @@ def match_date(start_date): ## compute standard errors and confidence intervals of the estimates standard_errors_estimates = qm.get_standard_errors(covariance_scaled) -confints_estimates = qm.get_confidence_intervals(solution.x, standard_errors_estimates) +confints_estimates = qm.get_confidence_intervals(theta_star, standard_errors_estimates) + ## compute confidence intervals of the fitted values on the logit scale and back transform y_fit_lst_confint = qm.get_confidence_bands_logit( @@ -207,7 +222,7 @@ def match_date(start_date): # ## Plot # + -colors_covsp = plot_ts.colors_covsp +colors_covsp = plot_ts.COLORS_COVSPECTRUM colors = [colors_covsp[var] for var in variants] fig, axes_tot = plt.subplots(4, 2, figsize=(15, 10)) axes_flat = axes_tot.flatten() @@ -257,4 +272,3 @@ def format_date(x, pos): fig.tight_layout() fig.show() -# - From c92935dfdfe02b9a083331615341934f893b29ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Tue, 19 Nov 2024 11:01:14 +0100 Subject: [PATCH 06/10] Add more tests --- src/covvfit/_quasimultinomial.py | 53 +++++++++++++++++++++----------- tests/test_quasimultinomial.py | 46 +++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 18 deletions(-) diff --git a/src/covvfit/_quasimultinomial.py b/src/covvfit/_quasimultinomial.py index a19eb54..a253386 100644 --- a/src/covvfit/_quasimultinomial.py +++ b/src/covvfit/_quasimultinomial.py @@ -67,7 +67,7 @@ def loss( return -jnp.sum(n * y * logp, axis=-1) -_ThetaType = Float[Array, "(cities+1)*(variants-1)"] +ModelParameters = Float[Array, "(cities+1)*(variants-1)"] def _add_first_variant(vec: Float[Array, " variants-1"]) -> Float[Array, " variants"]: @@ -78,21 +78,21 @@ def _add_first_variant(vec: Float[Array, " variants-1"]) -> Float[Array, " varia def construct_theta( relative_growths: Float[Array, " variants-1"], relative_midpoints: Float[Array, "cities variants-1"], -) -> _ThetaType: +) -> ModelParameters: flattened_midpoints = relative_midpoints.flatten() theta = jnp.concatenate([relative_growths, flattened_midpoints]) return theta def get_relative_growths( - theta: _ThetaType, + theta: ModelParameters, n_variants: int, ) -> Float[Array, " variants-1"]: return theta[: n_variants - 1] def get_relative_midpoints( - theta: _ThetaType, + theta: ModelParameters, n_variants: int, ) -> Float[Array, "cities variants-1"]: n_cities = theta.shape[0] // (n_variants - 1) - 1 @@ -113,8 +113,8 @@ def convert(confidence: float) -> float: def get_covariance( - loss_fn: Callable[[_ThetaType], _Float], - theta: _ThetaType, + loss_fn: Callable[[ModelParameters], _Float], + theta: ModelParameters, ) -> Float[Array, "n_params n_params"]: """Calculates the covariance matrix of the parameters. 
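The fitting cell above reduces to a small reusable pattern; a condensed sketch, assuming `ys_effective`, `ts_lst_scaled`, `cities` and `variants_effective` are prepared as in the notebook:

    from covvfit import quasimultinomial as qm

    loss = qm.construct_total_loss(ys=ys_effective, ts=ts_lst_scaled, average_loss=False)
    theta0 = qm.construct_theta0(n_cities=len(cities), n_variants=len(variants_effective))
    theta_star = qm.jax_multistart_minimize(loss, theta0, n_starts=10).x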
From c92935dfdfe02b9a083331615341934f893b29ee Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Tue, 19 Nov 2024 11:01:14 +0100
Subject: [PATCH 06/10] Add more tests

---
 src/covvfit/_quasimultinomial.py | 53 +++++++++++++++++++++-----------
 tests/test_quasimultinomial.py   | 46 +++++++++++++++++++++++++++
 2 files changed, 81 insertions(+), 18 deletions(-)

diff --git a/src/covvfit/_quasimultinomial.py b/src/covvfit/_quasimultinomial.py
index a19eb54..a253386 100644
--- a/src/covvfit/_quasimultinomial.py
+++ b/src/covvfit/_quasimultinomial.py
@@ -67,7 +67,7 @@ def loss(
     return -jnp.sum(n * y * logp, axis=-1)
 
 
-_ThetaType = Float[Array, "(cities+1)*(variants-1)"]
+ModelParameters = Float[Array, "(cities+1)*(variants-1)"]
 
 
 def _add_first_variant(vec: Float[Array, " variants-1"]) -> Float[Array, " variants"]:
@@ -78,21 +78,21 @@
 def construct_theta(
     relative_growths: Float[Array, " variants-1"],
     relative_midpoints: Float[Array, "cities variants-1"],
-) -> _ThetaType:
+) -> ModelParameters:
     flattened_midpoints = relative_midpoints.flatten()
     theta = jnp.concatenate([relative_growths, flattened_midpoints])
     return theta
 
 
 def get_relative_growths(
-    theta: _ThetaType,
+    theta: ModelParameters,
     n_variants: int,
 ) -> Float[Array, " variants-1"]:
     return theta[: n_variants - 1]
 
 
 def get_relative_midpoints(
-    theta: _ThetaType,
+    theta: ModelParameters,
     n_variants: int,
 ) -> Float[Array, "cities variants-1"]:
     n_cities = theta.shape[0] // (n_variants - 1) - 1
@@ -113,8 +113,8 @@ def convert(confidence: float) -> float:
 
 
 def get_covariance(
-    loss_fn: Callable[[_ThetaType], _Float],
-    theta: _ThetaType,
+    loss_fn: Callable[[ModelParameters], _Float],
+    theta: ModelParameters,
 ) -> Float[Array, "n_params n_params"]:
     """Calculates the covariance matrix of the parameters.
 
@@ -196,7 +196,7 @@
 def fitted_values(
     times: list[Float[Array, " timepoints"]],
-    theta: _ThetaType,
+    theta: ModelParameters,
     cities: list,
     n_variants: int,
 ) -> list[Float[Array, "timepoints variants"]]:
@@ -239,7 +239,7 @@ def create_logit_predictions_fn(
     """
 
     def logit_predictions_with_fixed_args(
-        theta: _ThetaType,
+        theta: ModelParameters,
     ):
         return get_logit_predictions(
             theta=theta, n_variants=n_variants, city_index=city_index, ts=ts
@@ -249,7 +249,7 @@
 
 
 def get_confidence_bands_logit(
-    solution_x: Float[Array, " (cities+1)*(variants-1)"],
+    theta: ModelParameters,
     variants_count: int,
     ts_lst_scaled: list[Float[Array, " timepoints"]],
     covariance_scaled: Float[Array, "n_params n_params"],
@@ -270,9 +270,10 @@
     for the confidence intervals on the linear scale.
     """
     # TODO(Pawel): Potentially fix the signature of this function.
+    # Issue 24
 
     y_fit_lst_logit = [
-        get_logit_predictions(solution_x, variants_count, i, ts).T[1:, :]
+        get_logit_predictions(theta, variants_count, i, ts).T[1:, :]
         for i, ts in enumerate(ts_lst_scaled)
     ]
 
@@ -280,7 +281,7 @@
     for i, ts in enumerate(ts_lst_scaled):
         # Compute the Jacobian of the transformation and project standard errors
         jacobian = jax.jacobian(create_logit_predictions_fn(variants_count, i, ts))(
-            solution_x
+            theta
         )
         standard_errors = get_standard_errors(
             jacobian=jacobian, covariance=covariance_scaled
@@ -310,7 +311,22 @@ def triangular_mask(n_variants, valid_value: float = 0, masked_value: float = jnp.nan):
     return nan_mask
 
 
-def get_relative_advantages(theta, n_variants: int):
+def get_relative_advantages(
+    theta: ModelParameters, n_variants: int
+) -> Float[Array, "variants variants"]:
+    """Returns a matrix of relative advantages, comparing every two variants.
+
+    Returns:
+        matrix of shape (n_variants, n_variants) with `A[reference, variant]`
+        representing the relative advantage of `variant` over `reference`.
+
+    Note:
+        From the model assumptions it follows that
+        `A[v1, v2] + A[v2, v3] = A[v1, v3]`
+        for every three variants. (I.e., the relative advantage
+        of `v3` over `v1` is the sum of advantages of `v3` over `v2`
+        and `v2` over `v1`.)
+    """
     # Shape (n_variants-1,) describing relative advantages
     # over the 0th variant
     rel_growths = get_relative_growths(theta, n_variants=n_variants)
@@ -321,10 +337,11 @@
 
 
 def get_softmax_predictions(
-    theta: _ThetaType, n_variants: int, city_index: int, ts: Float[Array, " timepoints"]
+    theta: ModelParameters,
+    n_variants: int,
+    city_index: int,
+    ts: Float[Array, " timepoints"],
 ) -> Float[Array, "timepoints variants"]:
-    # TODO(Pawel): Potentially fix the signature of this function.
-
     rel_growths = get_relative_growths(theta, n_variants=n_variants)
 
     growths = _add_first_variant(rel_growths)
@@ -342,7 +359,7 @@
 
 
 def get_logit_predictions(
-    theta: _ThetaType,
+    theta: ModelParameters,
     n_variants: int,
     city_index: int,
     ts: Float[Array, " timepoints"],
@@ -368,7 +385,7 @@ class OptimizeMultiResult:
 def construct_theta0(
     n_cities: int,
     n_variants: int,
-) -> _ThetaType:
+) -> ModelParameters:
     return np.zeros((n_cities * (n_variants - 1) + n_variants - 1,), dtype=float)
 
 
@@ -668,7 +685,7 @@ def construct_total_loss(
     overdispersion: _OverDispersionType = 1.0,
     accept_theta: bool = True,
     average_loss: bool = False,
-) -> Callable[[_ThetaType], _Float] | _RelativeGrowthsAndOffsetsFunction:
+) -> Callable[[ModelParameters], _Float] | _RelativeGrowthsAndOffsetsFunction:
     """Constructs the loss function, suitable e.g., for optimization.
 
     Args:
diff --git a/tests/test_quasimultinomial.py b/tests/test_quasimultinomial.py
index edb4f9d..507d13e 100644
--- a/tests/test_quasimultinomial.py
+++ b/tests/test_quasimultinomial.py
@@ -1,5 +1,6 @@
 import covvfit._quasimultinomial as qm
 import jax
+import jax.numpy as jnp
 import numpy.testing as npt
 import pytest
 
@@ -46,3 +47,48 @@ def test_parameter_conversions_2(seed: int, n_cities: int, n_variants: int) -> None:
         ),
         theta,
     )
+
+
+def test_softmax_predictions(
+    n_cities: int = 2, n_variants: int = 3, n_timepoints: int = 50
+) -> None:
+    theta0 = qm.construct_theta0(n_cities=n_cities, n_variants=n_variants)
+    theta = jax.random.normal(jax.random.PRNGKey(42), shape=theta0.shape)
+
+    ts = jnp.linspace(0, 1, n_timepoints)
+
+    for city in range(n_cities):
+        predictions = qm.get_softmax_predictions(
+            theta,
+            n_variants=n_variants,
+            city_index=city,
+            ts=ts,
+        )
+
+        assert predictions.shape == (n_timepoints, n_variants)
+
+        npt.assert_allclose(
+            predictions.sum(axis=-1),
+            jnp.ones(n_timepoints),
+            atol=1e-6,
+        )
+
+
+def test_get_relative_advantages(n_cities: int = 1, n_variants: int = 5) -> None:
+    theta0 = qm.construct_theta0(n_cities=n_cities, n_variants=n_variants)
+    # The variants are ordered by increasing fitness
+    relative = jnp.arange(1, n_variants)
+    theta = qm.construct_theta(
+        relative_growths=relative,
+        relative_midpoints=qm.get_relative_midpoints(theta0, n_variants=n_variants),
+    )
+
+    A = qm.get_relative_advantages(theta, n_variants=n_variants)
+    for v2 in range(n_variants):
+        for v1 in range(n_variants):
+            assert pytest.approx(A[v1, v2]) == v2 - v1
+
+    for v1 in range(n_variants):
+        for v2 in range(n_variants):
+            for v3 in range(n_variants):
+                assert pytest.approx(A[v1, v3]) == A[v1, v2] + A[v2, v3]
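The additivity property documented and tested above can also be checked directly; a small sketch:

    import jax.numpy as jnp
    from covvfit import quasimultinomial as qm

    n_variants = 4
    theta = qm.construct_theta(
        relative_growths=jnp.array([0.5, 1.0, 1.5]),
        relative_midpoints=jnp.zeros((1, n_variants - 1)),
    )
    A = qm.get_relative_advantages(theta, n_variants=n_variants)
    assert jnp.allclose(A[0, 2] + A[2, 3], A[0, 3])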
From 19be8faf8bf3368ae1670cad7f2b9e576a6ece13 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Tue, 19 Nov 2024 12:04:58 +0100
Subject: [PATCH 07/10] WIP: Notebook refactor out confidence interval
 construction

---
 examples/frequentist_notebook_jax.py | 75 ++++++++++++++++++----------
 src/covvfit/_quasimultinomial.py     | 26 +++++-----
 2 files changed, 62 insertions(+), 39 deletions(-)

diff --git a/examples/frequentist_notebook_jax.py b/examples/frequentist_notebook_jax.py
index 8c2d66f..bc65457 100644
--- a/examples/frequentist_notebook_jax.py
+++ b/examples/frequentist_notebook_jax.py
@@ -36,7 +36,7 @@
 # We start by loading the data:
 
 # +
-_dir_switch = False  # Change this to 0 or 1, depending on the laptop you are on
+_dir_switch = False  # Change this to True or False, depending on the laptop you are on
 if _dir_switch:
     DATA_PATH = "../../LolliPop/lollipop_covvfit/deconvolved.csv"
     VAR_DATES_PATH = "../../LolliPop/lollipop_covvfit/var_dates.yaml"
@@ -48,7 +48,7 @@
 
 data = pd.read_csv(DATA_PATH, sep="\t")
 data.head()
 
-
+# +
 # Load the YAML file
 with open(VAR_DATES_PATH, "r") as file:
     var_dates_data = yaml.safe_load(file)
@@ -113,9 +113,11 @@
 )
 
 # +
-_, ys_effective = preprocess.make_data_list(data_full, cities, variants_effective)
-ts_lst, ys_of_interest = preprocess.make_data_list(
-    data_full, cities, variants_of_interest
+ts_lst, ys_effective = preprocess.make_data_list(
+    data_full, cities=cities, variants=variants_effective
+)
+_, ys_of_interest = preprocess.make_data_list(
+    data_full, cities=cities, variants=variants_of_interest
 )
 
 # Scale the time for numerical stability
@@ -149,81 +151,102 @@
 theta_star = solution.x  # The maximum quasilikelihood estimate
 
 print(
-    f"Relative growth rates: \n",
+    f"Relative growth advantages: \n",
     qm.get_relative_growths(theta_star, n_variants=n_variants_effective),
 )
 # -
 
-# ## Make fitted values and confidence intervals
+# ## Confidence intervals of the growth advantages
+#
+# To obtain confidence intervals, we will take into account overdispersion. To do this, we need to compare the predictions with the observed values. Then, we can use overdispersion to attempt to correct the covariance matrix and obtain the confidence intervals.
 
+# +
 ## compute fitted values
 fitted_values = qm.fitted_values(
     ts_lst_scaled, theta=theta_star, cities=cities, n_variants=n_variants_effective
 )
 
-# +
-# TODO(Pawel): Refactor this out!!!!!
-# ... and because of https://github.com/cbg-ethz/covvfit/issues/24
-# we need to transpose again
-y_fit_lst = [y.T[1:] for y in fitted_values]
-
 ## compute covariance matrix
 covariance = qm.get_covariance(loss, theta_star)
 
 overdispersion_tuple = qm.compute_overdispersion(
-    observed=observed_data,
+    observed=ys_effective,
     predicted=fitted_values,
 )
 
 overdisp_fixed = overdispersion_tuple.overall
 
+print(
+    f"Overdispersion factor: {float(overdisp_fixed):.3f}.\nNote that values lower than 1 signify underdispersion."
+)
+
 
-# +
 ## scale covariance by overdisp
 covariance_scaled = overdisp_fixed * covariance
 
 ## compute standard errors and confidence intervals of the estimates
 standard_errors_estimates = qm.get_standard_errors(covariance_scaled)
-confints_estimates = qm.get_confidence_intervals(theta_star, standard_errors_estimates)
+confints_estimates = qm.get_confidence_intervals(
+    theta_star, standard_errors_estimates, confidence_level=0.95
+)
+
+
+print("\n\nRelative growth advantages:")
+for variant, m, l, u in zip(
+    variants_effective[1:],
+    qm.get_relative_growths(theta_star, n_variants=n_variants_effective),
+    qm.get_relative_growths(confints_estimates[0], n_variants=n_variants_effective),
+    qm.get_relative_growths(confints_estimates[1], n_variants=n_variants_effective),
+):
+    print(f"  {variant}: {float(m):.2f} ({float(l):.2f} – {float(u):.2f})")
+
+
+# +
+# TODO(Pawel): Refactor this out!!!!!
+# ... and because of https://github.com/cbg-ethz/covvfit/issues/24
+# we need to transpose again
+y_fit_lst = [y.T[1:] for y in fitted_values]
 
 ## compute confidence intervals of the fitted values on the logit scale and back transform
 y_fit_lst_confint = qm.get_confidence_bands_logit(
-    solution.x, len(variants2), ts_lst_scaled, covariance_scaled
+    theta_star, len(variants_effective), ts_lst_scaled, covariance_scaled
 )
 
 ## compute predicted values and confidence bands
 horizon = 60
 ts_pred_lst = [jnp.arange(horizon + 1) + tt.max() for tt in ts_lst]
-ts_pred_lst_scaled = [(x - t_min) / (t_max - t_min) for x in ts_pred_lst]
+ts_pred_lst_scaled = t_scaler.transform(ts_pred_lst)
 
 y_pred_lst = qm.fitted_values(
-    ts_pred_lst_scaled, theta=solution.x, cities=cities, n_variants=len(variants2)
+    ts_pred_lst_scaled, theta=theta_star, cities=cities, n_variants=n_variants_effective
 )
 # ... and because of https://github.com/cbg-ethz/covvfit/issues/24
 # we need to transpose again
 y_pred_lst = [y.T[1:] for y in y_pred_lst]
 
 y_pred_lst_confint = qm.get_confidence_bands_logit(
-    solution.x, len(variants2), ts_pred_lst_scaled, covariance_scaled
+    solution.x, n_variants_effective, ts_pred_lst_scaled, covariance_scaled
 )
 # -
 
-# ## Plotting functions
+confints_estimates
+
+y_pred_lst[0].shape
+
+# ## Plot
+
+# +
+colors = [plot_ts.COLORS_COVSPECTRUM[var] for var in variants]
 
 plot_fit = plot_ts.plot_fit
 plot_complement = plot_ts.plot_complement
 plot_data = plot_ts.plot_data
 plot_confidence_bands = plot_ts.plot_confidence_bands
 
-# ## Plot
-
-# +
-colors_covsp = plot_ts.COLORS_COVSPECTRUM
-colors = [colors_covsp[var] for var in variants]
 fig, axes_tot = plt.subplots(4, 2, figsize=(15, 10))
 axes_flat = axes_tot.flatten()
diff --git a/src/covvfit/_quasimultinomial.py b/src/covvfit/_quasimultinomial.py
index a253386..692ba57 100644
--- a/src/covvfit/_quasimultinomial.py
+++ b/src/covvfit/_quasimultinomial.py
@@ -222,7 +222,7 @@
     return y_fit_lst
 
 
-def create_logit_predictions_fn(
+def _create_logit_predictions_fn(
     n_variants: int, city_index: int, ts: Float[Array, " timepoints"]
 ) -> Callable[
     [Float[Array, " (cities+1)*(variants-1)"]], Float[Array, "timepoints variants"]
@@ -250,8 +250,9 @@
 
 def get_confidence_bands_logit(
     theta: ModelParameters,
-    variants_count: int,
-    ts_lst_scaled: list[Float[Array, " timepoints"]],
+    *,
+    n_variants: int,
+    ts: list[Float[Array, " timepoints"]],
     covariance_scaled: Float[Array, "n_params n_params"],
     confidence_level: float = 0.95,
 ) -> list[tuple]:
     """Computes confidence intervals for logit predictions using the Delta method,
     back-transforms them to the linear scale
 
     Args:
-        solution_x: Optimized parameters for the model.
-        variants_count: Number of variants.
-        ts_lst_scaled: List of timepoint arrays for each city.
-        covariance_scaled: Covariance matrix for the parameters.
+        theta: Parameters for the model.
+        n_variants: Number of variants.
+        ts: List of timepoint arrays for each city.
+        covariance: Covariance matrix for the parameters. Note that it should
+            include any overdispersion factors.
         confidence_level: Desired confidence level for intervals (default is 0.95).
 
     Returns:
         A list of dictionaries for each city, each with "lower" and "upper" bounds
         for the confidence intervals on the linear scale.
     """
     # TODO(Pawel): Potentially fix the signature of this function.
     # Issue 24
 
-    y_fit_lst_logit = [
-        get_logit_predictions(solution_x, variants_count, i, ts).T[1:, :]
-        for i, ts in enumerate(ts_lst_scaled)
+    logit_timeseries = [
+        get_logit_predictions(theta, n_variants, i, ts).T[1:, :]
+        for i, ts in enumerate(ts)
     ]
 
-    y_fit_lst_logit_se = []
-    for i, ts in enumerate(ts_lst_scaled):
+    logit_se = []
+    for i, ts in enumerate(ts):
         # Compute the Jacobian of the transformation and project standard errors
-        jacobian = jax.jacobian(create_logit_predictions_fn(variants_count, i, ts))(
-            theta
-        )
+        jacobian = jax.jacobian(_create_logit_predictions_fn(n_variants, i, ts))(theta)
         standard_errors = get_standard_errors(
             jacobian=jacobian, covariance=covariance_scaled
         ).T
-        y_fit_lst_logit_se.append(standard_errors)
+        logit_se.append(standard_errors)
 
     # Compute confidence intervals on the logit scale
-    y_fit_lst_logit_confint = [
+    logit_confint = [
         get_confidence_intervals(fitted, se, confidence_level=confidence_level)
-        for fitted, se in zip(y_fit_lst_logit, y_fit_lst_logit_se)
+        for fitted, se in zip(logit_timeseries, logit_se)
     ]
 
     # Project confidence intervals to the linear scale
     y_fit_lst_logit_confint_expit = [
         (jax.scipy.special.expit(confint[0]), jax.scipy.special.expit(confint[1]))
-        for confint in y_fit_lst_logit_confint
+        for confint in logit_confint
     ]
 
     return y_fit_lst_logit_confint_expit
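For reference, the confidence bands computed above follow the delta method: the parameter covariance is pushed through the Jacobian of the logit predictions, and the resulting normal intervals are mapped back through the sigmoid. A schematic helper illustrating the idea (not part of the package):

    import jax
    import jax.numpy as jnp

    def delta_method_se(f, theta, covariance):
        # Standard errors of f(theta), linearizing f around theta.
        jacobian = jax.jacobian(f)(theta)  # shape (..., n_params)
        return jnp.sqrt(jnp.einsum("...i,ij,...j->...", jacobian, covariance, jacobian))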
From e04f79b7402875d7c5010419928a9b043832a968 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?=
Date: Tue, 19 Nov 2024 12:13:56 +0100
Subject: [PATCH 08/10] WIP: Fixed the notebook

---
 examples/frequentist_notebook_jax.py | 19 ++++++++++++-------
 src/covvfit/_quasimultinomial.py     |  4 ++--
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/examples/frequentist_notebook_jax.py b/examples/frequentist_notebook_jax.py
index bc65457..33a3b24 100644
--- a/examples/frequentist_notebook_jax.py
+++ b/examples/frequentist_notebook_jax.py
@@ -121,8 +121,8 @@
 )
 
 # Scale the time for numerical stability
-t_scaler = preprocess.TimeScaler()
-ts_lst_scaled = t_scaler.fit_transform(ts_lst)
+time_scaler = preprocess.TimeScaler()
+ts_lst_scaled = time_scaler.fit_transform(ts_lst)
 # -
@@ -210,13 +210,17 @@
 
 ## compute confidence intervals of the fitted values on the logit scale and back transform
 y_fit_lst_confint = qm.get_confidence_bands_logit(
-    theta_star, len(variants_effective), ts_lst_scaled, covariance_scaled
+    theta_star,
+    n_variants=n_variants_effective,
+    ts=ts_lst_scaled,
+    covariance=covariance_scaled,
 )
 
+
 ## compute predicted values and confidence bands
 horizon = 60
 ts_pred_lst = [jnp.arange(horizon + 1) + tt.max() for tt in ts_lst]
-ts_pred_lst_scaled = t_scaler.transform(ts_pred_lst)
+ts_pred_lst_scaled = time_scaler.transform(ts_pred_lst)
 
 y_pred_lst = qm.fitted_values(
     ts_pred_lst_scaled, theta=theta_star, cities=cities, n_variants=n_variants_effective
@@ -226,14 +230,15 @@
 y_pred_lst = [y.T[1:] for y in y_pred_lst]
 
 y_pred_lst_confint = qm.get_confidence_bands_logit(
-    solution.x, n_variants_effective, ts_pred_lst_scaled, covariance_scaled
+    theta_star,
+    n_variants=n_variants_effective,
+    ts=ts_pred_lst_scaled,
+    covariance=covariance_scaled,
 )
 # -
 
-confints_estimates
-
 y_pred_lst[0].shape
diff --git a/src/covvfit/_quasimultinomial.py b/src/covvfit/_quasimultinomial.py
index 692ba57..70408d1 100644
--- a/src/covvfit/_quasimultinomial.py
+++ b/src/covvfit/_quasimultinomial.py
@@ -253,7 +253,7 @@ def get_confidence_bands_logit(
     *,
     n_variants: int,
     ts: list[Float[Array, " timepoints"]],
-    covariance_scaled: Float[Array, "n_params n_params"],
+    covariance: Float[Array, "n_params n_params"],
     confidence_level: float = 0.95,
 ) -> list[tuple]:
     """Computes confidence intervals for logit predictions using the Delta method,
@@ -284,7 +284,7 @@ def get_confidence_bands_logit(
         # Compute the Jacobian of the transformation and project standard errors
         jacobian = jax.jacobian(_create_logit_predictions_fn(n_variants, i, ts))(theta)
         standard_errors = get_standard_errors(
-            jacobian=jacobian, covariance=covariance_scaled
+            jacobian=jacobian, covariance=covariance
         ).T
         logit_se.append(standard_errors)
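One detail worth noting in the notebook change above: predictions reuse the scaler fitted on the observed window (`time_scaler.transform`, not `fit_transform`), so future timepoints map beyond 1.0 on the rescaled axis instead of silently redefining the time unit; schematically:

    ts_pred_lst_scaled = time_scaler.transform(ts_pred_lst)  # values > 1.0 expected here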
-) - +print(f"Overdispersion factor: {float(overdisp_fixed):.3f}.") +print("Note that values lower than 1 signify underdispersion.") ## scale covariance by overdisp covariance_scaled = overdisp_fixed * covariance @@ -199,17 +199,13 @@ def match_date(start_date): qm.get_relative_growths(confints_estimates[1], n_variants=n_variants_effective), ): print(f" {variant}: {float(m):.2f} ({float(l):.2f} – {float(u):.2f})") +# - -# + -# TODO(Pawel): Refactor this out!!!!! -# ... and because of https://github.com/cbg-ethz/covvfit/issues/24 -# we need to transpose again -y_fit_lst = [y.T[1:] for y in fitted_values] - +# We can propagate this uncertainty to the observed values. Let's generate confidence bands around the fitted lines and predict the future behaviour. -## compute confidence intervals of the fitted values on the logit scale and back transform -y_fit_lst_confint = qm.get_confidence_bands_logit( +# + +ys_fitted_confint = qm.get_confidence_bands_logit( theta_star, n_variants=n_variants_effective, ts=ts_lst_scaled, @@ -222,67 +218,59 @@ def match_date(start_date): ts_pred_lst = [jnp.arange(horizon + 1) + tt.max() for tt in ts_lst] ts_pred_lst_scaled = time_scaler.transform(ts_pred_lst) -y_pred_lst = qm.fitted_values( +ys_pred = qm.fitted_values( ts_pred_lst_scaled, theta=theta_star, cities=cities, n_variants=n_variants_effective ) -# ... and because of https://github.com/cbg-ethz/covvfit/issues/24 -# we need to transpose again -y_pred_lst = [y.T[1:] for y in y_pred_lst] - -y_pred_lst_confint = qm.get_confidence_bands_logit( +ys_pred_confint = qm.get_confidence_bands_logit( theta_star, n_variants=n_variants_effective, ts=ts_pred_lst_scaled, covariance=covariance_scaled, ) - - # - -y_pred_lst[0].shape - # ## Plot +# +# Finally, we plot the abundance data and the model predictions. Note that the 0th element in each array corresponds to the artificial "other" variant and we decided to plot only the explicitly defined variants. # + -colors = [plot_ts.COLORS_COVSPECTRUM[var] for var in variants] - -plot_fit = plot_ts.plot_fit -plot_complement = plot_ts.plot_complement -plot_data = plot_ts.plot_data -plot_confidence_bands = plot_ts.plot_confidence_bands - - -fig, axes_tot = plt.subplots(4, 2, figsize=(15, 10)) -axes_flat = axes_tot.flatten() - -for i, city in enumerate(cities): - ax = axes_flat[i] - # plot fitted and predicted values - plot_fit(ax, ts_lst[i], y_fit_lst[i], variants, colors) - plot_fit(ax, ts_pred_lst[i], y_pred_lst[i], variants, colors, linetype="--") - - # # plot 1-fitted and predicted values - plot_complement(ax, ts_lst[i], y_fit_lst[i], variants) - # plot_complement(ax, ts_pred_lst[i], y_pred_lst[i], variants, linetype="--") - # plot raw deconvolved values - plot_data(ax, ts_lst[i], ys_lst2[i], variants, colors) - # make confidence bands and plot them - conf_bands = y_fit_lst_confint[i] - plot_confidence_bands( +colors = [plot_ts.COLORS_COVSPECTRUM[var] for var in variants_investigated] + + +figure_spec = plot.arrange_into_grid(len(cities), axsize=(4, 1.5), dpi=350, wspace=1) + + +def plot_city(ax, i: int) -> None: + def remove_0th(arr): + """We don't plot the artificial 0th variant 'other'.""" + return arr[:, 1:] + + # Plot fits in observed and unobserved time intervals. 
+ plot_ts.plot_fit(ax, ts_lst[i], remove_0th(ys_fitted[i]), colors=colors) + plot_ts.plot_fit( + ax, ts_pred_lst[i], remove_0th(ys_pred[i]), colors=colors, linestyle="--" + ) + + plot_ts.plot_confidence_bands( ax, ts_lst[i], - {"lower": conf_bands[0], "upper": conf_bands[1]}, - variants, - colors, + jax.tree.map(remove_0th, ys_fitted_confint[i]), + colors=colors, ) - - pred_bands = y_pred_lst_confint[i] - plot_confidence_bands( + plot_ts.plot_confidence_bands( ax, ts_pred_lst[i], - {"lower": pred_bands[0], "upper": pred_bands[1]}, - variants, - colors, + jax.tree.map(remove_0th, ys_pred_confint[i]), + colors=colors, + ) + + # Plot the data points + plot_ts.plot_data(ax, ts_lst[i], remove_0th(ys_effective[i]), colors=colors) + + # Plot the complements + plot_ts.plot_complement(ax, ts_lst[i], remove_0th(ys_fitted[i]), alpha=0.3) + plot_ts.plot_complement( + ax, ts_pred_lst[i], remove_0th(ys_pred[i]), linestyle="--", alpha=0.3 ) # format axes and title @@ -295,8 +283,8 @@ def format_date(x, pos): tick_labels = ["0%", "50%", "100%"] ax.set_yticks(tick_positions) ax.set_yticklabels(tick_labels) - ax.set_ylabel("relative abundances") - ax.set_title(city) + ax.set_ylabel("Relative abundances") + ax.set_title(cities[i]) + -fig.tight_layout() -fig.show() +figure_spec.map(plot_city, range(len(cities))) diff --git a/pyproject.toml b/pyproject.toml index da2bdbc..b93a257 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ numpy = "==1.24.3" #pymc = "==5.3.0" seaborn = "^0.13.2" numpyro = "^0.14.0" +subplots-from-axsize = "^0.1.9" [tool.poetry.group.dev] diff --git a/src/covvfit/_preprocess_abundances.py b/src/covvfit/_preprocess_abundances.py index cabf1ef..8d6c6df 100644 --- a/src/covvfit/_preprocess_abundances.py +++ b/src/covvfit/_preprocess_abundances.py @@ -146,3 +146,7 @@ def fit_transform(self, ts: _ListTimeSeries) -> _ListTimeSeries: """ self.fit(ts) return self.transform(ts) + + @property + def time_unit(self) -> float: + return self.t_max - self.t_min diff --git a/src/covvfit/_quasimultinomial.py b/src/covvfit/_quasimultinomial.py index 70408d1..339444b 100644 --- a/src/covvfit/_quasimultinomial.py +++ b/src/covvfit/_quasimultinomial.py @@ -243,11 +243,16 @@ def logit_predictions_with_fixed_args( ): return get_logit_predictions( theta=theta, n_variants=n_variants, city_index=city_index, ts=ts - )[:, 1:] + ) return logit_predictions_with_fixed_args +class ConfidenceBand(NamedTuple): + lower: Float[Array, "timepoints variants"] + upper: Float[Array, "timepoints variants"] + + def get_confidence_bands_logit( theta: ModelParameters, *, @@ -255,7 +260,7 @@ def get_confidence_bands_logit( ts: list[Float[Array, " timepoints"]], covariance: Float[Array, "n_params n_params"], confidence_level: float = 0.95, -) -> list[tuple]: +) -> list[ConfidenceBand]: """Computes confidence intervals for logit predictions using the Delta method, back-transforms them to the linear scale @@ -271,33 +276,32 @@ def get_confidence_bands_logit( A list of dictionaries for each city, each with "lower" and "upper" bounds for the confidence intervals on the linear scale. """ - # TODO(Pawel): Potentially fix the signature of this function. 
- # Issue 24 - logit_timeseries = [ - get_logit_predictions(theta, n_variants, i, ts).T[1:, :] - for i, ts in enumerate(ts) + get_logit_predictions(theta, n_variants, i, ts) for i, ts in enumerate(ts) ] - y_fit_lst_logit_se = [] + logit_se = [] for i, ts in enumerate(ts): # Compute the Jacobian of the transformation and project standard errors jacobian = jax.jacobian(_create_logit_predictions_fn(n_variants, i, ts))(theta) standard_errors = get_standard_errors( jacobian=jacobian, covariance=covariance - ).T - y_fit_lst_logit_se.append(standard_errors) + ) # TODO(Pawel): Issue 24, take a look at it. + logit_se.append(standard_errors) # Compute confidence intervals on the logit scale - y_fit_lst_logit_confint = [ + logit_confint = [ get_confidence_intervals(fitted, se, confidence_level=confidence_level) - for fitted, se in zip(logit_timeseries, y_fit_lst_logit_se) + for fitted, se in zip(logit_timeseries, logit_se) ] # Project confidence intervals to the linear scale y_fit_lst_logit_confint_expit = [ - (jax.scipy.special.expit(confint[0]), jax.scipy.special.expit(confint[1])) - for confint in y_fit_lst_logit_confint + ConfidenceBand( + lower=jax.scipy.special.expit(confint[0]), + upper=jax.scipy.special.expit(confint[1]), + ) + for confint in logit_confint ] return y_fit_lst_logit_confint_expit diff --git a/src/covvfit/plotting/__init__.py b/src/covvfit/plotting/__init__.py index 738d047..5883c1e 100644 --- a/src/covvfit/plotting/__init__.py +++ b/src/covvfit/plotting/__init__.py @@ -1,13 +1,20 @@ """Plotting functionalities.""" import covvfit.plotting._timeseries as timeseries -from covvfit.plotting._grid import plot_grid, set_axis_off +from covvfit.plotting._grid import ( + ArrangedGrid, + arrange_into_grid, + plot_on_rectangular_grid, + set_axis_off, +) from covvfit.plotting._simplex import plot_on_simplex from covvfit.plotting._timeseries import COLORS_COVSPECTRUM, make_legend, num_to_date __all__ = [ + "ArrangedGrid", + "arrange_into_grid", "plot_on_simplex", - "plot_grid", + "plot_on_rectangular_grid", "set_axis_off", "make_legend", "num_to_date", diff --git a/src/covvfit/plotting/_grid.py b/src/covvfit/plotting/_grid.py index cb2dca9..cd95c60 100644 --- a/src/covvfit/plotting/_grid.py +++ b/src/covvfit/plotting/_grid.py @@ -1,9 +1,11 @@ +from dataclasses import dataclass from typing import Any, Callable import matplotlib.pyplot as plt import numpy as np from matplotlib.axes import Axes from matplotlib.figure import Figure +from subplots_from_axsize import subplots_from_axsize def set_axis_off(ax: Axes, i: int = 0, j: int = 0) -> None: @@ -11,7 +13,7 @@ def set_axis_off(ax: Axes, i: int = 0, j: int = 0) -> None: ax.set_axis_off() -def plot_grid( +def plot_on_rectangular_grid( nrows: int, diag_func: Callable[[Axes, int], Any], under_diag: Callable[[Axes, int, int], Any], @@ -20,7 +22,7 @@ def plot_grid( axsize: tuple[float, float] = (2.0, 2.0), **subplot_kw, ) -> tuple[Figure, np.ndarray]: - """Creates a grid of subplots. + """Creates a rectangular grid of subplots. Args: nrows: number of rows @@ -58,3 +60,119 @@ def plot_grid( over_diag(ax, i, j) return fig, axes + + +@dataclass(frozen=False) +class ArrangedGrid: + """A two-dimensional grid of axes. + + Attrs: + fig: Matplotlib figure. + axes: one-dimensional array of active axes, + with length equal to the number of active plots + axes_grid: two-dimensional array of all axes. 
+ + + Note: + The number of plots in `axes_grid` is typically + greater than the one in `axes`, as `axes_grid` + contains also the axes which are not active + """ + + fig: Figure + axes: np.ndarray + axes_grid: np.ndarray + + @property + def n_active(self) -> int: + return len(self.axes) + + def map( + self, + func: Callable[[Axes], None] | Callable[[Axes, Any], None], + arguments: list | None = None, + ) -> None: + """Applies a function to each active plotting axis. + + Args: + func: function to be applied. It can have + signature func(ax: plt.Axes) + if `arguments` is None, which modifies + the axis in-place. + + If `arguments` is not None, then the function + should have the signature + func(ax: plt.Axes, argument) + where `argument` is taken from the `arguments` + list + """ + if arguments is None: + for ax in self.axes: + func(ax) + else: + if self.n_active != len(arguments): + raise ValueError( + f"Provide one argument for each active axis, in total {self.n_active}" + ) + for ax, arg in zip(self.axes, arguments): + func(ax, arg) + + def set_titles(self, titles: list[str]) -> None: + for title, ax in zip(titles, self.axes): + ax.set_title(title) + + def set_xlabels(self, labels: list[str]) -> None: + for label, ax in zip(labels, self.axes): + ax.set_xlabel(label) + + def set_ylabels(self, labels: list[str]) -> None: + for label, ax in zip(labels, self.axes): + ax.set_ylabel(label) + + +def _calculate_nrows(n: int, ncols: int): + if ncols < 1: + raise ValueError(f"ncols has to be at least 1, was {ncols}.") + return int(np.ceil(n / ncols)) + + +def arrange_into_grid( + nplots: int, + ncols: int = 2, + axsize: tuple[float, float] = (2.0, 1.0), + **kwargs, +) -> ArrangedGrid: + """Builds an array of plots to accommodate + the axes listed. + + Args: + nplots: number of plots + ncols: number of columns + axsize: axis size + kwargs: keyword arguments to be passed to + `subplots_from_axsize`. For example, + ``` + wspace=0.2, # Changes the horizontal spacing + hspace=0.3, # Changes the vertical spacing + left=0.5, # Changes the left margin + ``` + """ + nrows = _calculate_nrows(nplots, ncols=ncols) + + fig, axs = subplots_from_axsize( + axsize=axsize, + nrows=nrows, + ncols=ncols, + **kwargs, + ) + + # Set not used axes + for i, ax in enumerate(axs.ravel()): + if i >= nplots: + ax.set_axis_off() + + return ArrangedGrid( + fig=fig, + axes=axs.ravel()[:nplots], + axes_grid=axs, + ) diff --git a/src/covvfit/plotting/_timeseries.py b/src/covvfit/plotting/_timeseries.py index 2dace61..2eb18b0 100644 --- a/src/covvfit/plotting/_timeseries.py +++ b/src/covvfit/plotting/_timeseries.py @@ -1,5 +1,4 @@ """utilities to plot""" -from typing import Literal import matplotlib.lines as mlines import matplotlib.patches as mpatches @@ -77,9 +76,10 @@ def plot_fit( ax: plt.Axes, ts: Float[Array, " timepoints"], y_fit: Float[Array, "timepoints variants"], - variants: list[Variant], + *, colors: list[Color], - linetype="-", + variants: list[Variant] | None = None, + linestyle="-", **kwargs, ) -> None: """ @@ -91,16 +91,20 @@ def plot_fit( y_fit (array-like): Fitted values for each variant. variants (list): List of variant names. colors (list): List of colors for each variant. - linetype (str): Line style for plotting (e.g., '-', '--', '-.', ':'). + linestyle (str): Line style for plotting (e.g., '-', '--', '-.', ':'). 
""" sorted_indices = np.argsort(ts) + n_variants = y_fit.shape[-1] + if variants is None: + variants = [""] * n_variants + for i, variant in enumerate(variants): ax.plot( ts[sorted_indices], y_fit[sorted_indices, i], color=colors[i], - linestyle=linetype, - label=f"fit {variant}", + linestyle=linestyle, + label=variant, **kwargs, ) @@ -110,7 +114,7 @@ def plot_complement( ts: Float[Array, " timepoints"], y_fit: Float[Array, "timepoints variants"], color: str = "grey", - linetype: str = "-", + linestyle: str = "-", **kwargs, ) -> None: ## function to plot 1-sum(fitted_values) i.e., the other variant(s) @@ -119,7 +123,7 @@ def plot_complement( ts[sorted_indices], (1 - y_fit.sum(axis=-1))[sorted_indices], color=color, - linestyle=linetype, + linestyle=linestyle, **kwargs, ) @@ -129,7 +133,7 @@ def plot_data( ts: Float[Array, " timepoints"], ys: Float[Array, "timepoints variants"], colors: list[Color], - size: float | int = 4, + size: float = 4.0, alpha: float = 0.5, **kwargs, ) -> None: @@ -138,7 +142,6 @@ def plot_data( ax.scatter( ts, ys[:, i], - label="observed", alpha=alpha, color=colors[i], s=size, @@ -149,7 +152,8 @@ def plot_data( def plot_confidence_bands( ax: plt.Axes, ts: Float[Array, " timepoints"], - conf_bands: dict[Literal["lower", "upper"], Float[Array, "timepoints variants"]], + conf_bands, + *, colors: list[Color], label: str = "Confidence band", alpha: float = 0.2, @@ -161,6 +165,12 @@ def plot_confidence_bands( Parameters: ax: The axis to plot on. ts: Time series data. + conf_bands: confidence bands object. It can be: + 1. A class with attributes `lower` and `upper`, each of which is + an array of shape `(n_timepoints, n_variants)` and represents + the lower and upper confidence bands, respectively. + 2. A tuple of two arrays of the specified shape. + 3. A dictionary with keys "lower" and "upper" color: Color for the confidence interval. label: Label for the confidence band. Default is "Confidence band". alpha: Alpha level controling the opacity. 
@@ -169,14 +179,34 @@ def plot_confidence_bands( # Sort indices for time series sorted_indices = np.argsort(ts) - n_variants = conf_bands["lower"].shape[-1] + lower, upper = None, None + if hasattr(conf_bands, "lower") and hasattr(conf_bands, "upper"): + lower = conf_bands.lower + upper = conf_bands.upper + elif isinstance(conf_bands, dict): + lower = conf_bands["lower"] + upper = conf_bands["upper"] + else: + lower = conf_bands[0] + upper = conf_bands[1] + + if lower is None or upper is None: + raise ValueError("Confidence bands are not in a recognized format.") + + lower = np.asarray(lower) + upper = np.asarray(upper) + + if lower.ndim != 2 or lower.shape != upper.shape: + raise ValueError("The shape is wrong.") + + n_variants = lower.shape[-1] # Plot the confidence interval for i in range(n_variants): ax.fill_between( ts[sorted_indices], - conf_bands["lower"][sorted_indices][i], - conf_bands["upper"][sorted_indices][i], + lower[sorted_indices, i], + upper[sorted_indices, i], color=colors[i], alpha=alpha, label=label, From 4003cbf68bb826dc24453c55bdc958af8683a3b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Czy=C5=BC?= Date: Wed, 20 Nov 2024 13:14:27 +0100 Subject: [PATCH 10/10] Remove a TODO comment --- src/covvfit/_quasimultinomial.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/covvfit/_quasimultinomial.py b/src/covvfit/_quasimultinomial.py index 339444b..61abe71 100644 --- a/src/covvfit/_quasimultinomial.py +++ b/src/covvfit/_quasimultinomial.py @@ -284,9 +284,7 @@ def get_confidence_bands_logit( for i, ts in enumerate(ts): # Compute the Jacobian of the transformation and project standard errors jacobian = jax.jacobian(_create_logit_predictions_fn(n_variants, i, ts))(theta) - standard_errors = get_standard_errors( - jacobian=jacobian, covariance=covariance - ) # TODO(Pawel): Issue 24, take a look at it. + standard_errors = get_standard_errors(jacobian=jacobian, covariance=covariance) logit_se.append(standard_errors) # Compute confidence intervals on the logit scale