diff --git a/examples/frequentist_notebook_jax.py b/examples/frequentist_notebook_jax.py
index be24789..da60c05 100644
--- a/examples/frequentist_notebook_jax.py
+++ b/examples/frequentist_notebook_jax.py
@@ -13,68 +13,70 @@
 #     name: python3
 # ---

+# # Quasilikelihood data analysis notebook
+#
+# This notebook shows how to estimate growth advantages by fitting the model within the quasimultinomial framework.
+
 # +
 import jax
 import jax.numpy as jnp
-
-import pandas as pd
-
-import numpy as np
-
-import matplotlib.ticker as ticker
 import matplotlib.pyplot as plt
-import seaborn as sns
-
-from scipy.special import expit
-from scipy.stats import norm
-
+import matplotlib.ticker as ticker
+import pandas as pd
 import yaml

-import covvfit._preprocess_abundances as prec
-import covvfit.plotting._timeseries as plot_ts
-
+from covvfit import plot, preprocess
 from covvfit import quasimultinomial as qm

-import numpyro
-
+plot_ts = plot.timeseries
 # -

-# # Load and preprocess data
+# ## Load and preprocess data
+#
+# We start by loading the data:

 # +
-DATA_PATH = "../../LolliPop/lollipop_covvfit/deconvolved.csv"
-VAR_DATES_PATH = "../../LolliPop/lollipop_covvfit/var_dates.yaml"
-
-DATA_PATH = "../new_data/deconvolved.csv"
-VAR_DATES_PATH = "../new_data/var_dates.yaml"
+_dir_switch = False  # Set to True or False, depending on the machine you are working on
+if _dir_switch:
+    DATA_PATH = "../../LolliPop/lollipop_covvfit/deconvolved.csv"
+    VAR_DATES_PATH = "../../LolliPop/lollipop_covvfit/var_dates.yaml"
+else:
+    DATA_PATH = "../new_data/deconvolved.csv"
+    VAR_DATES_PATH = "../new_data/var_dates.yaml"

 data = pd.read_csv(DATA_PATH, sep="\t")
 data.head()

-
+# +
 # Load the YAML file
 with open(VAR_DATES_PATH, "r") as file:
     var_dates_data = yaml.safe_load(file)

 # Access the var_dates data
 var_dates = var_dates_data["var_dates"]
-# -


 data_wide = data.pivot_table(
     index=["date", "location"], columns="variant", values="proportion", fill_value=0
 ).reset_index()
 data_wide = data_wide.rename(columns={"date": "time", "location": "city"})
-data_wide.head()

-# +
+# Define the list with cities:
+cities = list(data_wide["city"].unique())
+
 ## Set limit times for modeling
 max_date = pd.to_datetime(data_wide["time"]).max()
 delta_time = pd.Timedelta(days=240)
 start_date = max_date - delta_time

+# Print the data frame
+data_wide.head()
+# -
+
+# Now we look at the variants in the data and define the variants of interest:

 # +
 # Convert the keys to datetime objects for comparison
@@ -90,156 +92,185 @@ def match_date(start_date):
     return closest_date, var_dates_parsed[closest_date]


-variants_full = match_date(start_date + delta_time)[1]
-
-variants = ["KP.2", "KP.3", "XEC"]
+variants_full = match_date(start_date + delta_time)[1]  # All the variants in this range

-variants_other = [i for i in variants_full if i not in variants]
+variants_investigated = [
+    "KP.2",
+    "KP.3",
+    "XEC",
+]  # Variants found in the data, which we focus on in this analysis
+variants_other = [
+    i for i in variants_full if i not in variants_investigated
+]  # Variants not of interest
 # -

-cities = list(data_wide["city"].unique())
+# Apart from the variants of interest, we define the "other" variant, which artificially merges all the remaining variants into one. This allows us to model the data as a compositional time series, i.e., the sum of the abundances of all "variants" is normalized to one.
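+#
+# For intuition, here is a minimal sketch of this merging step on a toy data frame (hypothetical variant names and numbers, not the real data):
+#
+# ```python
+# toy = pd.DataFrame({"KP.2": [0.5, 0.2], "XEC": [0.3, 0.1], "BA.2": [0.1, 0.3]})
+# toy["other"] = toy[["BA.2"]].sum(axis=1)  # merge the variants not of interest
+# cols = ["other", "KP.2", "XEC"]
+# toy[cols] = toy[cols].div(toy[cols].sum(axis=1), axis=0)  # each row now sums to one
+# ```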
-variants2 = ["other"] + variants -data2 = prec.preprocess_df( +# + +variants_effective = ["other"] + variants_investigated +data_full = preprocess.preprocess_df( data_wide, cities, variants_full, date_min=start_date, zero_date=start_date ) -# + -data2["other"] = data2[variants_other].sum(axis=1) -data2[variants2] = data2[variants2].div(data2[variants2].sum(axis=1), axis=0) - -ts_lst, ys_lst = prec.make_data_list(data2, cities, variants2) -ts_lst, ys_lst2 = prec.make_data_list(data2, cities, variants) +data_full["other"] = data_full[variants_other].sum(axis=1) +data_full[variants_effective] = data_full[variants_effective].div( + data_full[variants_effective].sum(axis=1), axis=0 +) -t_max = max([x.max() for x in ts_lst]) -t_min = min([x.min() for x in ts_lst]) +# + +ts_lst, ys_effective = preprocess.make_data_list( + data_full, cities=cities, variants=variants_effective +) -ts_lst_scaled = [(x - t_min) / (t_max - t_min) for x in ts_lst] +# Scale the time for numerical stability +time_scaler = preprocess.TimeScaler() +ts_lst_scaled = time_scaler.fit_transform(ts_lst) # - -# # fit in jax +# ## Fit the quasimultinomial model +# +# Now we fit the quasimultinomial model, which allows us to find the maximum quasilikelihood estimate of the parameters: # + # %%time -# Recall that the input should be (n_timepoints, n_variants) -# TODO(Pawel, David): Resolve Issue https://github.com/cbg-ethz/covvfit/issues/24 -observed_data = [y.T for y in ys_lst] - - # no priors loss = qm.construct_total_loss( - ys=observed_data, + ys=ys_effective, ts=ts_lst_scaled, average_loss=False, # Do not average the loss over the data points, so that the covariance matrix shrinks with more and more data added ) +n_variants_effective = len(variants_effective) + # initial parameters -theta0 = qm.construct_theta0(n_cities=len(cities), n_variants=len(variants2)) +theta0 = qm.construct_theta0(n_cities=len(cities), n_variants=n_variants_effective) # Run the optimization routine solution = qm.jax_multistart_minimize(loss, theta0, n_starts=10) + +theta_star = solution.x # The maximum quasilikelihood estimate + +print( + f"Relative growth advantages: \n", + qm.get_relative_growths(theta_star, n_variants=n_variants_effective), +) # - -# ## Make fitted values and confidence intervals +# ## Confidence intervals of the growth advantages +# +# To obtain confidence intervals, we will take into account overdispersion. To do this, we need to compare the predictions with the observed values. Then, we can use overdispersion to attempt to correct the covariance matrix and obtain the confidence intervals. # + ## compute fitted values -fitted_values = qm.fitted_values( - ts_lst_scaled, theta=solution.x, cities=cities, n_variants=len(variants2) +ys_fitted = qm.fitted_values( + ts_lst_scaled, theta=theta_star, cities=cities, n_variants=n_variants_effective ) -# ... 

-# ... and because of https://github.com/cbg-ethz/covvfit/issues/24
-# we need to transpose again
-y_fit_lst = [y.T[1:] for y in fitted_values]
-
 ## compute covariance matrix
-covariance = qm.get_covariance(loss, solution.x)
+covariance = qm.get_covariance(loss, theta_star)

 overdispersion_tuple = qm.compute_overdispersion(
-    observed=observed_data,
-    predicted=fitted_values,
+    observed=ys_effective,
+    predicted=ys_fitted,
 )

 overdisp_fixed = overdispersion_tuple.overall
+print(f"Overdispersion factor: {float(overdisp_fixed):.3f}.")
+print("Note that values lower than 1 signify underdispersion.")

-# +
 ## scale covariance by overdisp
 covariance_scaled = overdisp_fixed * covariance

 ## compute standard errors and confidence intervals of the estimates
 standard_errors_estimates = qm.get_standard_errors(covariance_scaled)
-confints_estimates = qm.get_confidence_intervals(solution.x, standard_errors_estimates)
+confints_estimates = qm.get_confidence_intervals(
+    theta_star, standard_errors_estimates, confidence_level=0.95
+)

-## compute confidence intervals of the fitted values on the logit scale and back transform
-y_fit_lst_confint = qm.get_confidence_bands_logit(
-    solution.x, len(variants2), ts_lst_scaled, covariance_scaled
+
+print("\n\nRelative growth advantages:")
+for variant, m, l, u in zip(
+    variants_effective[1:],
+    qm.get_relative_growths(theta_star, n_variants=n_variants_effective),
+    qm.get_relative_growths(confints_estimates[0], n_variants=n_variants_effective),
+    qm.get_relative_growths(confints_estimates[1], n_variants=n_variants_effective),
+):
+    print(f"  {variant}: {float(m):.2f} ({float(l):.2f} – {float(u):.2f})")
+# -
+
+
+# We can propagate this uncertainty to the fitted and predicted abundances. Let's generate confidence bands around the fitted lines and predict the future behaviour.
+
+# +
+ys_fitted_confint = qm.get_confidence_bands_logit(
+    theta_star,
+    n_variants=n_variants_effective,
+    ts=ts_lst_scaled,
+    covariance=covariance_scaled,
 )

+
 ## compute predicted values and confidence bands
 horizon = 60
 ts_pred_lst = [jnp.arange(horizon + 1) + tt.max() for tt in ts_lst]
-ts_pred_lst_scaled = [(x - t_min) / (t_max - t_min) for x in ts_pred_lst]
+ts_pred_lst_scaled = time_scaler.transform(ts_pred_lst)

-y_pred_lst = qm.fitted_values(
-    ts_pred_lst_scaled, theta=solution.x, cities=cities, n_variants=len(variants2)
+ys_pred = qm.fitted_values(
+    ts_pred_lst_scaled, theta=theta_star, cities=cities, n_variants=n_variants_effective
 )

-# ... and because of https://github.com/cbg-ethz/covvfit/issues/24
-# we need to transpose again
-y_pred_lst = [y.T[1:] for y in y_pred_lst]
-
-y_pred_lst_confint = qm.get_confidence_bands_logit(
-    solution.x, len(variants2), ts_pred_lst_scaled, covariance_scaled
+ys_pred_confint = qm.get_confidence_bands_logit(
+    theta_star,
+    n_variants=n_variants_effective,
+    ts=ts_pred_lst_scaled,
+    covariance=covariance_scaled,
 )
+# -

+# ## Plot
+#
+# Finally, we plot the abundance data and the model predictions. Note that the 0th element in each array corresponds to the artificial "other" variant, so we plot only the explicitly defined variants.
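+#
+# The confidence bands are `ConfidenceBand` named tuples of arrays, which JAX treats as pytrees, so we can trim the "other" column from both bounds at once with `jax.tree.map`. A minimal sketch of this idiom (toy arrays, not the model output):
+#
+# ```python
+# from typing import NamedTuple
+#
+# class Band(NamedTuple):
+#     lower: jnp.ndarray
+#     upper: jnp.ndarray
+#
+# band = Band(lower=jnp.zeros((5, 3)), upper=jnp.ones((5, 3)))
+# trimmed = jax.tree.map(lambda a: a[:, 1:], band)  # drops the 0th column
+# assert trimmed.lower.shape == (5, 2)
+# ```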
-# -
+# +
+colors = [plot_ts.COLORS_COVSPECTRUM[var] for var in variants_investigated]

-# ## Plotting functions
-plot_fit = plot_ts.plot_fit
-plot_complement = plot_ts.plot_complement
-plot_data = plot_ts.plot_data
-plot_confidence_bands = plot_ts.plot_confidence_bands
+figure_spec = plot.arrange_into_grid(len(cities), axsize=(4, 1.5), dpi=350, wspace=1)

-# ## Plot

-# +
-colors_covsp = plot_ts.colors_covsp
-colors = [colors_covsp[var] for var in variants]

-fig, axes_tot = plt.subplots(4, 2, figsize=(15, 10))
-axes_flat = axes_tot.flatten()
-
-for i, city in enumerate(cities):
-    ax = axes_flat[i]
-    # plot fitted and predicted values
-    plot_fit(ax, ts_lst[i], y_fit_lst[i], variants, colors)
-    plot_fit(ax, ts_pred_lst[i], y_pred_lst[i], variants, colors, linetype="--")
-
-    # # plot 1-fitted and predicted values
-    plot_complement(ax, ts_lst[i], y_fit_lst[i], variants)
-    # plot_complement(ax, ts_pred_lst[i], y_pred_lst[i], variants, linetype="--")
-    # plot raw deconvolved values
-    plot_data(ax, ts_lst[i], ys_lst2[i], variants, colors)
-    # make confidence bands and plot them
-    conf_bands = y_fit_lst_confint[i]
-    plot_confidence_bands(
+def plot_city(ax, i: int) -> None:
+    def remove_0th(arr):
+        """We don't plot the artificial 0th variant 'other'."""
+        return arr[:, 1:]
+
+    # Plot fits in observed and unobserved time intervals.
+    plot_ts.plot_fit(ax, ts_lst[i], remove_0th(ys_fitted[i]), colors=colors)
+    plot_ts.plot_fit(
+        ax, ts_pred_lst[i], remove_0th(ys_pred[i]), colors=colors, linestyle="--"
+    )
+
+    plot_ts.plot_confidence_bands(
         ax,
         ts_lst[i],
-        {"lower": conf_bands[0], "upper": conf_bands[1]},
-        variants,
-        colors,
+        jax.tree.map(remove_0th, ys_fitted_confint[i]),
+        colors=colors,
     )
-
-    pred_bands = y_pred_lst_confint[i]
-    plot_confidence_bands(
+    plot_ts.plot_confidence_bands(
         ax,
         ts_pred_lst[i],
-        {"lower": pred_bands[0], "upper": pred_bands[1]},
-        variants,
-        colors,
+        jax.tree.map(remove_0th, ys_pred_confint[i]),
+        colors=colors,
+    )
+
+    # Plot the data points
+    plot_ts.plot_data(ax, ts_lst[i], remove_0th(ys_effective[i]), colors=colors)
+
+    # Plot the complements
+    plot_ts.plot_complement(ax, ts_lst[i], remove_0th(ys_fitted[i]), alpha=0.3)
+    plot_ts.plot_complement(
+        ax, ts_pred_lst[i], remove_0th(ys_pred[i]), linestyle="--", alpha=0.3
     )

     # format axes and title
@@ -252,9 +283,8 @@ def format_date(x, pos):
     tick_labels = ["0%", "50%", "100%"]
     ax.set_yticks(tick_positions)
     ax.set_yticklabels(tick_labels)
-    ax.set_ylabel("relative abundances")
-    ax.set_title(city)
+    ax.set_ylabel("Relative abundances")
+    ax.set_title(cities[i])

-fig.tight_layout()
-fig.show()
-# -
+
+figure_spec.map(plot_city, range(len(cities)))
diff --git a/pyproject.toml b/pyproject.toml
index da2bdbc..b93a257 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -19,6 +19,7 @@ numpy = "==1.24.3"
 #pymc = "==5.3.0"
 seaborn = "^0.13.2"
 numpyro = "^0.14.0"
+subplots-from-axsize = "^0.1.9"


 [tool.poetry.group.dev]
diff --git a/src/covvfit/__init__.py b/src/covvfit/__init__.py
index 826ecec..cf7ed7b 100644
--- a/src/covvfit/__init__.py
+++ b/src/covvfit/__init__.py
@@ -10,8 +10,8 @@
 )
 simulation = None

+import covvfit._preprocess_abundances as preprocess
 import covvfit.plotting as plot
-from covvfit._preprocess_abundances import load_data, make_data_list, preprocess_df
 from covvfit._splines import create_spline_matrix

 VERSION = "0.1.0"
@@ -19,11 +19,10 @@

 __all__ = [
     "create_spline_matrix",
-    "make_data_list",
-    "preprocess_df",
-    "load_data",
     "VERSION",
     "quasimultinomial",
     "plot",
+    "preprocess",
     "simulation",
 ]
diff --git a/src/covvfit/_preprocess_abundances.py b/src/covvfit/_preprocess_abundances.py
index 4922676..8d6c6df 100644
--- a/src/covvfit/_preprocess_abundances.py
+++ b/src/covvfit/_preprocess_abundances.py
@@ -1,6 +1,7 @@
 """utilities to preprocess relative abundances"""

 import pandas as pd
+from jaxtyping import Array, Float


 def load_data(file) -> pd.DataFrame:
@@ -13,34 +14,56 @@ def preprocess_df(
     df: pd.DataFrame,
     cities: list[str],
     variants: list[str],
-    undertermined_thresh: float = 0.01,
+    *,
+    undetermined_thresh: float = 0.01,
     zero_date: str = "2023-01-01",
     date_min: str | None = None,
     date_max: str | None = None,
+    time_col: str = "time",
+    city_col: str = "city",
+    undetermined_col: str | None = "undetermined",
 ) -> pd.DataFrame:
-    """Preprocessing."""
+    """Preprocessing function.
+
+    Args:
+        df: data frame with data
+        cities: cities for which the data will be processed
+        variants: variants which will be processed; they should
+            be represented by different columns in `df`
+        undetermined_thresh: threshold of the `undetermined` variant
+            used to remove days with too many missing values.
+            Use `None` to not remove any data
+        zero_date: reference time point
+        date_min: the lower bound of the data to be selected.
+            Set to `None` to not set a bound
+        date_max: see `date_min`
+        time_col: column with dates representing days
+        city_col: column with cities
+        undetermined_col: column with the undetermined variant
+    """
     df = df.copy()

     # Convert the time column to datetime
-    df["time"] = pd.to_datetime(df["time"])
+    df[time_col] = pd.to_datetime(df[time_col])

     # Remove days with too high undetermined
-    df = df[df["undetermined"] < undertermined_thresh]  # pyright: ignore
+    if undetermined_col is not None:
+        df = df[df[undetermined_col] < undetermined_thresh]  # pyright: ignore

-    # Subset the 'BQ.1.1' column
-    df = df[["time", "city"] + variants]  # pyright: ignore
+    # Subset the columns corresponding to variants
+    df = df[[time_col, city_col] + variants]  # pyright: ignore

     # Subset only the specified cities
-    df = df[df["city"].isin(cities)]  # pyright: ignore
+    df = df[df[city_col].isin(cities)]  # pyright: ignore

     # Create a new column which is the difference in days between zero_date and the date
-    df["days_from"] = (df["time"] - pd.to_datetime(zero_date)).dt.days
+    df["days_from"] = (df[time_col] - pd.to_datetime(zero_date)).dt.days

     # Subset dates
     if date_min is not None:
-        df = df[df["time"] >= pd.to_datetime(date_min)]  # pyright: ignore
+        df = df[df[time_col] >= pd.to_datetime(date_min)]  # pyright: ignore
     if date_max is not None:
-        df = df[df["time"] < pd.to_datetime(date_max)]  # pyright: ignore
+        df = df[df[time_col] < pd.to_datetime(date_max)]  # pyright: ignore

     return df
@@ -49,13 +72,81 @@ def make_data_list(
     df: pd.DataFrame,
     cities: list[str],
     variants: list[str],
-) -> tuple[list, list]:
+) -> tuple[
+    list[Float[Array, " timepoints"]], list[Float[Array, "timepoints variants"]]
+]:
     ts_lst = [df[(df.city == city)].days_from.values for city in cities]
     ys_lst = [
-        df[(df.city == city)][variants].values.T for city in cities
+        df[(df.city == city)][variants].values for city in cities
     ]  # pyright: ignore

+    # TODO(David, Pawel): How should we handle this case?
+    #   It *implicitly* changes the output data type, based on the input value.
+    #   Do we even use this feature?
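+    #   (For now: a third element with the `count_sum` values is returned
+    #   whenever that column is present.)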
if "count_sum" in df.columns: ns_lst = [df[(df.city == city)].count_sum.values for city in cities] return (ts_lst, ys_lst, ns_lst) else: return (ts_lst, ys_lst) + + +_ListTimeSeries = list[Float[Array, " timeseries"]] + + +class TimeScaler: + """Scales a list of time series, so that the values are normalized.""" + + def __init__(self): + self.t_min = None + self.t_max = None + self._fitted = False + + def fit(self, ts: _ListTimeSeries) -> None: + """Fit the scaler parameters to the provided time series. + + Args: + ts: list of timeseries, i.e., `ts[i]` is an array + of some length `n_timepoints[i]`. + """ + self.t_min = min([x.min() for x in ts]) + self.t_max = max([x.max() for x in ts]) + self._fitted = True + + def transform(self, ts: _ListTimeSeries) -> _ListTimeSeries: + """Returns scaled values. + + Args: + ts: list of timeseries, i.e., `ts[i]` is an array + of some length `n_timepoints[i]`. + + Returns: + list of exactly the same format as `ts` + + Note: + The model has to be fitted first. + """ + if not self._fitted: + raise RuntimeError("You need to fit the model first.") + + denominator = self.t_max - self.t_min + return [(x - self.t_min) / denominator for x in ts] + + def fit_transform(self, ts: _ListTimeSeries) -> _ListTimeSeries: + """Fits the model and returns scaled values. + + Args: + ts: list of timeseries, i.e., `ts[i]` is an array + of some length `n_timepoints[i]`. + + Returns: + list of exactly the same format as `ts` + + Note: + This function is equivalent to calling + first `fit` method and then `transform`. + """ + self.fit(ts) + return self.transform(ts) + + @property + def time_unit(self) -> float: + return self.t_max - self.t_min diff --git a/src/covvfit/_quasimultinomial.py b/src/covvfit/_quasimultinomial.py index ee3e9f1..61abe71 100644 --- a/src/covvfit/_quasimultinomial.py +++ b/src/covvfit/_quasimultinomial.py @@ -67,7 +67,7 @@ def loss( return -jnp.sum(n * y * logp, axis=-1) -_ThetaType = Float[Array, "(cities+1)*(variants-1)"] +ModelParameters = Float[Array, "(cities+1)*(variants-1)"] def _add_first_variant(vec: Float[Array, " variants-1"]) -> Float[Array, " variants"]: @@ -78,21 +78,21 @@ def _add_first_variant(vec: Float[Array, " variants-1"]) -> Float[Array, " varia def construct_theta( relative_growths: Float[Array, " variants-1"], relative_midpoints: Float[Array, "cities variants-1"], -) -> _ThetaType: +) -> ModelParameters: flattened_midpoints = relative_midpoints.flatten() theta = jnp.concatenate([relative_growths, flattened_midpoints]) return theta def get_relative_growths( - theta: _ThetaType, + theta: ModelParameters, n_variants: int, ) -> Float[Array, " variants-1"]: return theta[: n_variants - 1] def get_relative_midpoints( - theta: _ThetaType, + theta: ModelParameters, n_variants: int, ) -> Float[Array, "cities variants-1"]: n_cities = theta.shape[0] // (n_variants - 1) - 1 @@ -109,12 +109,12 @@ def convert(confidence: float) -> float: Example: StandardErrorsMultipliers.convert(0.95) # 1.9599 """ - return float(jax.scipy.stats.norm.ppf((1 + confidence) / 2)) + return float(jax.scipy.stats.norm.ppf((1 + confidence) / 2.0)) def get_covariance( - loss_fn: Callable[[_ThetaType], _Float], - theta: _ThetaType, + loss_fn: Callable[[ModelParameters], _Float], + theta: ModelParameters, ) -> Float[Array, "n_params n_params"]: """Calculates the covariance matrix of the parameters. @@ -185,7 +185,7 @@ def get_confidence_intervals( Assumes a normal distribution for the estimates. 
""" # Calculate the multiplier based on the confidence level - z_score = jax.scipy.stats.norm.ppf((1 + confidence_level) / 2) + z_score = StandardErrorsMultipliers.convert(confidence_level) # Compute the lower and upper bounds of the confidence intervals lower_bound = estimates - z_score * standard_errors @@ -196,7 +196,7 @@ def get_confidence_intervals( def fitted_values( times: list[Float[Array, " timepoints"]], - theta: _ThetaType, + theta: ModelParameters, cities: list, n_variants: int, ) -> list[Float[Array, "timepoints variants"]]: @@ -222,7 +222,7 @@ def fitted_values( return y_fit_lst -def create_logit_predictions_fn( +def _create_logit_predictions_fn( n_variants: int, city_index: int, ts: Float[Array, " timepoints"] ) -> Callable[ [Float[Array, " (cities+1)*(variants-1)"]], Float[Array, "timepoints variants"] @@ -239,63 +239,67 @@ def create_logit_predictions_fn( """ def logit_predictions_with_fixed_args( - theta: _ThetaType, + theta: ModelParameters, ): return get_logit_predictions( theta=theta, n_variants=n_variants, city_index=city_index, ts=ts - )[:, 1:] + ) return logit_predictions_with_fixed_args +class ConfidenceBand(NamedTuple): + lower: Float[Array, "timepoints variants"] + upper: Float[Array, "timepoints variants"] + + def get_confidence_bands_logit( - solution_x: Float[Array, " (cities+1)*(variants-1)"], - variants_count: int, - ts_lst_scaled: list[Float[Array, " timepoints"]], - covariance_scaled: Float[Array, "n_params n_params"], + theta: ModelParameters, + *, + n_variants: int, + ts: list[Float[Array, " timepoints"]], + covariance: Float[Array, "n_params n_params"], confidence_level: float = 0.95, -) -> list[tuple]: +) -> list[ConfidenceBand]: """Computes confidence intervals for logit predictions using the Delta method, back-transforms them to the linear scale Args: - solution_x: Optimized parameters for the model. + theta: Parameters for the model. variants_count: Number of variants. ts_lst_scaled: List of timepoint arrays for each city. - covariance_scaled: Covariance matrix for the parameters. + covariance: Covariance matrix for the parameters. Note that it should + include any overdispersion factors. confidence_level: Desired confidence level for intervals (default is 0.95). Returns: A list of dictionaries for each city, each with "lower" and "upper" bounds for the confidence intervals on the linear scale. 
""" - - y_fit_lst_logit = [ - get_logit_predictions(solution_x, variants_count, i, ts).T[1:, :] - for i, ts in enumerate(ts_lst_scaled) + logit_timeseries = [ + get_logit_predictions(theta, n_variants, i, ts) for i, ts in enumerate(ts) ] - y_fit_lst_logit_se = [] - for i, ts in enumerate(ts_lst_scaled): + logit_se = [] + for i, ts in enumerate(ts): # Compute the Jacobian of the transformation and project standard errors - jacobian = jax.jacobian(create_logit_predictions_fn(variants_count, i, ts))( - solution_x - ) - standard_errors = get_standard_errors( - jacobian=jacobian, covariance=covariance_scaled - ).T - y_fit_lst_logit_se.append(standard_errors) + jacobian = jax.jacobian(_create_logit_predictions_fn(n_variants, i, ts))(theta) + standard_errors = get_standard_errors(jacobian=jacobian, covariance=covariance) + logit_se.append(standard_errors) # Compute confidence intervals on the logit scale - y_fit_lst_logit_confint = [ + logit_confint = [ get_confidence_intervals(fitted, se, confidence_level=confidence_level) - for fitted, se in zip(y_fit_lst_logit, y_fit_lst_logit_se) + for fitted, se in zip(logit_timeseries, logit_se) ] # Project confidence intervals to the linear scale y_fit_lst_logit_confint_expit = [ - (jax.scipy.special.expit(confint[0]), jax.scipy.special.expit(confint[1])) - for confint in y_fit_lst_logit_confint + ConfidenceBand( + lower=jax.scipy.special.expit(confint[0]), + upper=jax.scipy.special.expit(confint[1]), + ) + for confint in logit_confint ] return y_fit_lst_logit_confint_expit @@ -309,7 +313,22 @@ def triangular_mask(n_variants, valid_value: float = 0, masked_value: float = jn return nan_mask -def get_relative_advantages(theta, n_variants: int): +def get_relative_advantages( + theta: ModelParameters, n_variants: int +) -> Float[Array, "variants variants"]: + """Returns a matrix of relative advantages, comparing every two variants. + + Returns: + matrix of shape (n_variants, n_variants) with `A[reference, variant]` + representing the relative advantage of `variant` over `reference`. + + Note: + From the model assumptions it follows that + `A[v1, v2] + A[v2, v3] = A[v1, v3]` + for every three variants. 
+        (I.e., the relative advantage of `v3` over `v1` is the sum of the
+        advantages of `v3` over `v2` and of `v2` over `v1`.)
+    """
     # Shape (n_variants-1,) describing relative advantages
     # over the 0th variant
     rel_growths = get_relative_growths(theta, n_variants=n_variants)
@@ -320,7 +339,10 @@ def get_relative_advantages(theta, n_variants: int):


 def get_softmax_predictions(
-    theta: _ThetaType, n_variants: int, city_index: int, ts: Float[Array, " timepoints"]
+    theta: ModelParameters,
+    n_variants: int,
+    city_index: int,
+    ts: Float[Array, " timepoints"],
 ) -> Float[Array, "timepoints variants"]:
     rel_growths = get_relative_growths(theta, n_variants=n_variants)
     growths = _add_first_variant(rel_growths)
@@ -339,7 +361,7 @@ def get_softmax_predictions(


 def get_logit_predictions(
-    theta: _ThetaType,
+    theta: ModelParameters,
     n_variants: int,
     city_index: int,
     ts: Float[Array, " timepoints"],
@@ -365,7 +387,7 @@ class OptimizeMultiResult:

 def construct_theta0(
     n_cities: int,
     n_variants: int,
-) -> _ThetaType:
+) -> ModelParameters:
     return np.zeros((n_cities * (n_variants - 1) + n_variants - 1,), dtype=float)
@@ -665,7 +687,7 @@ def construct_total_loss(
     overdispersion: _OverDispersionType = 1.0,
     accept_theta: bool = True,
     average_loss: bool = False,
-) -> Callable[[_ThetaType], _Float] | _RelativeGrowthsAndOffsetsFunction:
+) -> Callable[[ModelParameters], _Float] | _RelativeGrowthsAndOffsetsFunction:
     """Constructs the loss function, suitable e.g., for optimization.

     Args:
diff --git a/src/covvfit/plotting/__init__.py b/src/covvfit/plotting/__init__.py
index f108b7a..5883c1e 100644
--- a/src/covvfit/plotting/__init__.py
+++ b/src/covvfit/plotting/__init__.py
@@ -1,14 +1,23 @@
 """Plotting functionalities."""

-from covvfit.plotting._grid import plot_grid, set_axis_off
+import covvfit.plotting._timeseries as timeseries
+from covvfit.plotting._grid import (
+    ArrangedGrid,
+    arrange_into_grid,
+    plot_on_rectangular_grid,
+    set_axis_off,
+)
 from covvfit.plotting._simplex import plot_on_simplex
-from covvfit.plotting._timeseries import colors_covsp, make_legend, num_to_date
+from covvfit.plotting._timeseries import COLORS_COVSPECTRUM, make_legend, num_to_date

 __all__ = [
+    "ArrangedGrid",
+    "arrange_into_grid",
     "plot_on_simplex",
-    "plot_grid",
+    "plot_on_rectangular_grid",
     "set_axis_off",
     "make_legend",
     "num_to_date",
-    "colors_covsp",
+    "timeseries",
+    "COLORS_COVSPECTRUM",
 ]
diff --git a/src/covvfit/plotting/_grid.py b/src/covvfit/plotting/_grid.py
index cb2dca9..cd95c60 100644
--- a/src/covvfit/plotting/_grid.py
+++ b/src/covvfit/plotting/_grid.py
@@ -1,9 +1,11 @@
+from dataclasses import dataclass
 from typing import Any, Callable

 import matplotlib.pyplot as plt
 import numpy as np
 from matplotlib.axes import Axes
 from matplotlib.figure import Figure
+from subplots_from_axsize import subplots_from_axsize


 def set_axis_off(ax: Axes, i: int = 0, j: int = 0) -> None:
@@ -11,7 +13,7 @@ def set_axis_off(ax: Axes, i: int = 0, j: int = 0) -> None:
     ax.set_axis_off()


-def plot_grid(
+def plot_on_rectangular_grid(
     nrows: int,
     diag_func: Callable[[Axes, int], Any],
     under_diag: Callable[[Axes, int, int], Any],
     over_diag: Callable[[Axes, int, int], Any],
     axsize: tuple[float, float] = (2.0, 2.0),
     **subplot_kw,
 ) -> tuple[Figure, np.ndarray]:
-    """Creates a grid of subplots.
+    """Creates a rectangular grid of subplots.

     Args:
         nrows: number of rows
@@ -58,3 +60,119 @@
                 over_diag(ax, i, j)

     return fig, axes
+
+
+@dataclass(frozen=False)
+class ArrangedGrid:
+    """A two-dimensional grid of axes.
+
+    Attrs:
+        fig: Matplotlib figure.
+        axes: one-dimensional array of active axes,
+            with length equal to the number of active plots
+        axes_grid: two-dimensional array of all axes.
+
+    Note:
+        The number of plots in `axes_grid` is typically
+        greater than the one in `axes`, as `axes_grid`
+        also contains the axes which are not active.
+    """
+
+    fig: Figure
+    axes: np.ndarray
+    axes_grid: np.ndarray
+
+    @property
+    def n_active(self) -> int:
+        return len(self.axes)
+
+    def map(
+        self,
+        func: Callable[[Axes], None] | Callable[[Axes, Any], None],
+        arguments: list | None = None,
+    ) -> None:
+        """Applies a function to each active plotting axis.
+
+        Args:
+            func: function to be applied. If `arguments` is None,
+                it can have the signature
+                func(ax: plt.Axes)
+                and it modifies the axis in-place.
+
+                If `arguments` is not None, then the function
+                should have the signature
+                func(ax: plt.Axes, argument)
+                where `argument` is taken from the `arguments`
+                list.
+        """
+        if arguments is None:
+            for ax in self.axes:
+                func(ax)
+        else:
+            if self.n_active != len(arguments):
+                raise ValueError(
+                    f"Provide one argument for each active axis, in total {self.n_active}"
+                )
+            for ax, arg in zip(self.axes, arguments):
+                func(ax, arg)
+
+    def set_titles(self, titles: list[str]) -> None:
+        for title, ax in zip(titles, self.axes):
+            ax.set_title(title)
+
+    def set_xlabels(self, labels: list[str]) -> None:
+        for label, ax in zip(labels, self.axes):
+            ax.set_xlabel(label)
+
+    def set_ylabels(self, labels: list[str]) -> None:
+        for label, ax in zip(labels, self.axes):
+            ax.set_ylabel(label)
+
+
+def _calculate_nrows(n: int, ncols: int):
+    if ncols < 1:
+        raise ValueError(f"ncols has to be at least 1, was {ncols}.")
+    return int(np.ceil(n / ncols))
+
+
+def arrange_into_grid(
+    nplots: int,
+    ncols: int = 2,
+    axsize: tuple[float, float] = (2.0, 1.0),
+    **kwargs,
+) -> ArrangedGrid:
+    """Builds a grid of axes accommodating the requested number of plots.
+
+    Args:
+        nplots: number of plots
+        ncols: number of columns
+        axsize: axis size
+        kwargs: keyword arguments to be passed to
+            `subplots_from_axsize`. For example,
+            ```
+            wspace=0.2,  # Changes the horizontal spacing
+            hspace=0.3,  # Changes the vertical spacing
+            left=0.5,  # Changes the left margin
+            ```
+    """
+    nrows = _calculate_nrows(nplots, ncols=ncols)
+
+    fig, axs = subplots_from_axsize(
+        axsize=axsize,
+        nrows=nrows,
+        ncols=ncols,
+        **kwargs,
+    )
+
+    # Turn off the axes that are not used by any plot
+    for i, ax in enumerate(axs.ravel()):
+        if i >= nplots:
+            ax.set_axis_off()
+
+    return ArrangedGrid(
+        fig=fig,
+        axes=axs.ravel()[:nplots],
+        axes_grid=axs,
+    )
diff --git a/src/covvfit/plotting/_timeseries.py b/src/covvfit/plotting/_timeseries.py
index cf88cf2..2eb18b0 100644
--- a/src/covvfit/plotting/_timeseries.py
+++ b/src/covvfit/plotting/_timeseries.py
@@ -2,10 +2,15 @@

 import matplotlib.lines as mlines
 import matplotlib.patches as mpatches
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+from jaxtyping import Array, Float

-colors_covsp = {
+Variant = str
+Color = str
+
+COLORS_COVSPECTRUM: dict[Variant, Color] = {
     "B.1.1.7": "#D16666",
     "B.1.351": "#FF6666",
     "P.1": "#FFB3B3",
@@ -31,7 +36,7 @@
 }


-def make_legend(colors, variants):
+def make_legend(colors: list[Color], variants: list[Variant]) -> list[mpatches.Patch]:
     """make a shared legend for the plot"""
     # Create a patch (i.e., a colored box) for each variant
     variant_patches = [
@@ -59,13 +64,24 @@
     return handles


-def num_to_date(num, date_min, pos=None, fmt="%b. '%y"):
'%y"): +def num_to_date( + num: pd.Series | Float[Array, " timepoints"], date_min: str, fmt="%b. '%y" +) -> pd.Series: """convert days number into a date format""" date = pd.to_datetime(date_min) + pd.to_timedelta(num, "D") return date.strftime(fmt) -def plot_fit(ax, ts, y_fit, variants, colors, linetype="-", **kwargs): +def plot_fit( + ax: plt.Axes, + ts: Float[Array, " timepoints"], + y_fit: Float[Array, "timepoints variants"], + *, + colors: list[Color], + variants: list[Variant] | None = None, + linestyle="-", + **kwargs, +) -> None: """ Function to plot fitted values with customizable line type. @@ -75,62 +91,124 @@ def plot_fit(ax, ts, y_fit, variants, colors, linetype="-", **kwargs): y_fit (array-like): Fitted values for each variant. variants (list): List of variant names. colors (list): List of colors for each variant. - linetype (str): Line style for plotting (e.g., '-', '--', '-.', ':'). + linestyle (str): Line style for plotting (e.g., '-', '--', '-.', ':'). """ sorted_indices = np.argsort(ts) + n_variants = y_fit.shape[-1] + if variants is None: + variants = [""] * n_variants + for i, variant in enumerate(variants): ax.plot( ts[sorted_indices], - y_fit[i, :][sorted_indices], + y_fit[sorted_indices, i], color=colors[i], - linestyle=linetype, - label=f"fit {variant}", + linestyle=linestyle, + label=variant, **kwargs, ) -def plot_complement(ax, ts, y_fit, variants, color="grey", linetype="-"): +def plot_complement( + ax: plt.Axes, + ts: Float[Array, " timepoints"], + y_fit: Float[Array, "timepoints variants"], + color: str = "grey", + linestyle: str = "-", + **kwargs, +) -> None: ## function to plot 1-sum(fitted_values) i.e., the other variant(s) sorted_indices = np.argsort(ts) ax.plot( ts[sorted_indices], - (1 - y_fit.sum(axis=0))[sorted_indices], + (1 - y_fit.sum(axis=-1))[sorted_indices], color=color, - linestyle=linetype, + linestyle=linestyle, + **kwargs, ) -def plot_data(ax, ts, ys, variants, colors): +def plot_data( + ax: plt.Axes, + ts: Float[Array, " timepoints"], + ys: Float[Array, "timepoints variants"], + colors: list[Color], + size: float = 4.0, + alpha: float = 0.5, + **kwargs, +) -> None: ## function to plot raw values - for i, variant in enumerate(variants): - ax.scatter(ts, ys[i, :], label="observed", alpha=0.5, color=colors[i], s=4) + for i in range(ys.shape[-1]): + ax.scatter( + ts, + ys[:, i], + alpha=alpha, + color=colors[i], + s=size, + **kwargs, + ) def plot_confidence_bands( - ax, ts, conf_bands, variants, colors, label="Confidence band", alpha=0.2 -): + ax: plt.Axes, + ts: Float[Array, " timepoints"], + conf_bands, + *, + colors: list[Color], + label: str = "Confidence band", + alpha: float = 0.2, + **kwargs, +) -> None: """ Plot confidence intervals for fitted values on a given axis with customizable confidence level. Parameters: - ax (matplotlib.axes.Axes): The axis to plot on. - ts (array-like): Time series data. - y_fit_logit (array-like): Logit-transformed fitted values for each variant. - logit_se (array-like): Standard errors for the logit-transformed fitted values. - color (str): Color for the confidence interval. - confidence (float, optional): Confidence level (e.g., 0.95 for 95%). Default is 0.95. - label (str, optional): Label for the confidence band. Default is "Confidence band". + ax: The axis to plot on. + ts: Time series data. + conf_bands: confidence bands object. It can be: + 1. 
+            1. A class with attributes `lower` and `upper`, each of which is
+               an array of shape `(n_timepoints, n_variants)` and represents
+               the lower and upper confidence bands, respectively.
+            2. A tuple of two arrays of the specified shape.
+            3. A dictionary with keys "lower" and "upper".
+        colors: List of colors, one per variant.
+        label: Label for the confidence band. Default is "Confidence band".
+        alpha: Alpha level controlling the opacity.
+        **kwargs: Additional keyword arguments for `ax.fill_between`.
     """
     # Sort indices for time series
     sorted_indices = np.argsort(ts)

+    lower, upper = None, None
+    if hasattr(conf_bands, "lower") and hasattr(conf_bands, "upper"):
+        lower = conf_bands.lower
+        upper = conf_bands.upper
+    elif isinstance(conf_bands, dict):
+        lower = conf_bands["lower"]
+        upper = conf_bands["upper"]
+    else:
+        lower = conf_bands[0]
+        upper = conf_bands[1]
+
+    if lower is None or upper is None:
+        raise ValueError("Confidence bands are not in a recognized format.")
+
+    lower = np.asarray(lower)
+    upper = np.asarray(upper)
+
+    if lower.ndim != 2 or lower.shape != upper.shape:
+        raise ValueError(
+            "Expected two arrays of shape (n_timepoints, n_variants), "
+            f"got lower {lower.shape} and upper {upper.shape}."
+        )
+
+    n_variants = lower.shape[-1]
+
     # Plot the confidence interval
-    for i, variant in enumerate(variants):
+    for i in range(n_variants):
         ax.fill_between(
             ts[sorted_indices],
-            conf_bands["lower"][i][sorted_indices],
-            conf_bands["upper"][i][sorted_indices],
+            lower[sorted_indices, i],
+            upper[sorted_indices, i],
             color=colors[i],
             alpha=alpha,
             label=label,
+            **kwargs,
         )
diff --git a/tests/test_quasimultinomial.py b/tests/test_quasimultinomial.py
index edb4f9d..507d13e 100644
--- a/tests/test_quasimultinomial.py
+++ b/tests/test_quasimultinomial.py
@@ -1,5 +1,6 @@
 import covvfit._quasimultinomial as qm
 import jax
+import jax.numpy as jnp
 import numpy.testing as npt
 import pytest

@@ -46,3 +47,48 @@ def test_parameter_conversions_2(seed: int, n_cities: int, n_variants: int) -> None:
         ),
         theta,
     )
+
+
+def test_softmax_predictions(
+    n_cities: int = 2, n_variants: int = 3, n_timepoints: int = 50
+) -> None:
+    theta0 = qm.construct_theta0(n_cities=n_cities, n_variants=n_variants)
+    theta = jax.random.normal(jax.random.PRNGKey(42), shape=theta0.shape)
+
+    ts = jnp.linspace(0, 1, n_timepoints)
+
+    for city in range(n_cities):
+        predictions = qm.get_softmax_predictions(
+            theta,
+            n_variants=n_variants,
+            city_index=city,
+            ts=ts,
+        )
+
+        assert predictions.shape == (n_timepoints, n_variants)
+
+        npt.assert_allclose(
+            predictions.sum(axis=-1),
+            jnp.ones(n_timepoints),
+            atol=1e-6,
+        )
+
+
+def test_get_relative_advantages(n_cities: int = 1, n_variants: int = 5) -> None:
+    theta0 = qm.construct_theta0(n_cities=n_cities, n_variants=n_variants)
+    # The variants are ordered by increasing fitness
+    relative = jnp.arange(1, n_variants)
+    theta = qm.construct_theta(
+        relative_growths=relative,
+        relative_midpoints=qm.get_relative_midpoints(theta0, n_variants=n_variants),
+    )
+
+    A = qm.get_relative_advantages(theta, n_variants=n_variants)
+    for v2 in range(n_variants):
+        for v1 in range(n_variants):
+            assert pytest.approx(A[v1, v2]) == v2 - v1
+
+    for v1 in range(n_variants):
+        for v2 in range(n_variants):
+            for v3 in range(n_variants):
+                assert pytest.approx(A[v1, v3]) == A[v1, v2] + A[v2, v3]
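+
+
+def test_time_scaler() -> None:
+    # Added for illustration: a minimal sketch checking that the new TimeScaler
+    # maps the pooled time range onto [0, 1] (toy data, t_min=0 and t_max=20).
+    import covvfit._preprocess_abundances as preprocess
+
+    ts = [jnp.asarray([0.0, 5.0, 10.0]), jnp.asarray([2.0, 20.0])]
+    scaler = preprocess.TimeScaler()
+    scaled = scaler.fit_transform(ts)
+
+    assert scaler.time_unit == pytest.approx(20.0)
+    npt.assert_allclose(scaled[0], jnp.asarray([0.0, 0.25, 0.5]))
+    npt.assert_allclose(scaled[1], jnp.asarray([0.1, 1.0]))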