Skip to content

Commit

Permalink
Numerically stable logit pred (#35)
Browse files Browse the repository at this point in the history
* evaluation workflow

* fixed data expands

* solve lineage overcounting

* Save current changes for testing older version

* compatibility with refactor #27

* added fitted plots (in progress)

* numerically stable logit predictions

* Generalize the function

---------

Co-authored-by: Paweł Czyż <[email protected]>
  • Loading branch information
dr-david and pawel-czyz authored Nov 27, 2024
1 parent 67b6b30 commit 1cdfe83
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 7 deletions.
54 changes: 47 additions & 7 deletions src/covvfit/_quasimultinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,21 +360,61 @@ def get_softmax_predictions(
return y_softmax


def _logsumexp_excluding_column(
y: Float[Array, "*batch variants"],
axis: int = -1,
) -> Float[Array, "*batch variants"]:
"""
Compute logsumexp across the "variants" dimension for each column,
excluding the current column.
Args:
y_linear: a NumPy array.
axis: the axis representing variants, over which excluded logsumexp
will be performed
Returns:
an array of the same shape, where the `i`th element of the axis
corresponds to the logsum exp over all the other entries except
this one
"""
# Numerical stability by shifting with max_val
max_val = jnp.max(y, axis=axis, keepdims=True)
shifted = y - max_val
# Compute sum exp shifted,
# Substract sum exp shifted for each column
# Take the log and add back the max_val
sum_exp_shifted = jnp.sum(jnp.exp(shifted), axis=axis, keepdims=True)
logsumexp_excl = jnp.log(sum_exp_shifted - jnp.exp(shifted)) + max_val

return logsumexp_excl


def get_logit_predictions(
theta: ModelParameters,
n_variants: int,
city_index: int,
ts: Float[Array, " timepoints"],
) -> Float[Array, "timepoints variants"]:
return jax.scipy.special.logit(
get_softmax_predictions(
theta=theta,
n_variants=n_variants,
city_index=city_index,
ts=ts,
)
"""
Compute predictions on the logit scale.
Compute logit(softmax()) in a numerically stable manner
"""

rel_growths = get_relative_growths(theta, n_variants=n_variants)
growths = _add_first_variant(rel_growths)

rel_midpoints = get_relative_midpoints(theta, n_variants=n_variants)
midpoints = _add_first_variant(rel_midpoints[city_index])

y_linear = calculate_linear(
ts=ts,
midpoints=midpoints,
growths=growths,
)

return y_linear - _logsumexp_excluding_column(y_linear)


@dataclasses.dataclass
class OptimizeMultiResult:
Expand Down
1 change: 1 addition & 0 deletions workflows/compare_clinical/config_ba1.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
api_url: "https://lapis.cov-spectrum.org/open/v2/sample/aggregated"
country: "Switzerland"

wastewater_data_path: "../../data/main/deconvolved.csv"

run_name: "config_ba1"
Expand Down
1 change: 1 addition & 0 deletions workflows/compare_clinical/config_ba1ba2.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
api_url: "https://lapis.cov-spectrum.org/open/v2/sample/aggregated"
country: "Switzerland"

wastewater_data_path: "../../data/main/deconvolved.csv"

run_name: "config_ba1ba2"
Expand Down
2 changes: 2 additions & 0 deletions workflows/compare_clinical/config_ba4ba5.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
api_url: "https://lapis.cov-spectrum.org/open/v2/sample/aggregated"
country: "Switzerland"

wastewater_data_path: "../../data/main/deconvolved.csv"


run_name: "config_ba4ba5"

wastewater_cities:
Expand Down
1 change: 1 addition & 0 deletions workflows/compare_clinical/config_ba4ba5_2.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
api_url: "https://lapis.cov-spectrum.org/open/v2/sample/aggregated"
country: "Switzerland"

wastewater_data_path: "../../data/main/deconvolved.csv"

run_name: "config_ba4ba5_2"
Expand Down

0 comments on commit 1cdfe83

Please sign in to comment.