Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read in prior distributions from an external .py file #123

Merged
merged 16 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ ENGINE := docker
CONTAINER_NAME := pyrenew-hew
CONTAINER_REMOTE_NAME := $(ACR_TAG_PREFIX)$(CONTAINER_NAME)":latest"

container_build:
container_build: acr_login
$(ENGINE) build . -t $(CONTAINER_NAME)

container_tag:
Expand Down
29 changes: 17 additions & 12 deletions pipelines/batch/setup_eval_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,41 +97,46 @@ def main(
mount_pairs=[
{
"source": "nssp-etl",
"target": "/pyrenew-hew/nssp_demo/nssp-etl",
"target": "/pyrenew-hew/nssp-etl",
},
{
"source": "nssp-archival-vintages",
"target": "/pyrenew-hew/nssp_demo/nssp-archival-vintages",
"target": "/pyrenew-hew/nssp-archival-vintages",
},
{
"source": "prod-param-estimates",
"target": "/pyrenew-hew/nssp_demo/params",
"target": "/pyrenew-hew/params",
},
{
"source": "pyrenew-test-output",
"target": "/pyrenew-hew/nssp_demo/private_data",
"source": "pyrenew-hew-prod-output",
"target": "/pyrenew-hew/output",
},
{
"source": "pyrenew-hew-config",
"target": "/pyrenew-hew/config",
},
],
)

base_call = (
"/bin/bash -c '"
"python nssp_demo/forecast_state.py "
"python pipelines/forecast_state.py "
"--disease {disease} "
"--state {state} "
"--n-training-days 365 "
"--n-warmup 1000 "
"--n-samples 500 "
"--facility-level-nssp-data-dir nssp_demo/nssp-etl/gold "
"--facility-level-nssp-data-dir nssp-etl/gold "
"--state-level-nssp-data-dir "
"nssp_demo/nssp-archival-vintages/gold "
"--param-data-dir nssp_demo/params "
"--output-data-dir nssp_demo/private_data "
"nssp-archival-vintages/gold "
"--param-data-dir params "
"--output-data-dir output "
"--priors-path config/eval_priors.py "
damonbayer marked this conversation as resolved.
Show resolved Hide resolved
"--report-date {report_date:%Y-%m-%d} "
"--exclude-last-n-days 2 "
"--score "
"--eval-data-path "
"nssp_demo/nssp-archival-vintages/latest_comprehensive.parquet"
"nssp-archival-vintages/latest_comprehensive.parquet"
"'"
)

Expand All @@ -143,7 +148,7 @@ def main(
locations.filter(~pl.col("STUSAB").is_in(excluded_locations))
.get_column("STUSAB")
.to_list()
)
) + ["US"]

report_dates = [
datetime.date(2023, 10, 11) + datetime.timedelta(weeks=x)
Expand Down
3 changes: 2 additions & 1 deletion pipelines/batch/setup_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def main(pool_name: str) -> None:
"nssp-etl",
"nssp-archival-vintages",
"prod-param-estimates",
"pyrenew-test-output",
"pyrenew-hew-prod-output",
"pyrenew-hew-config",
],
account_names=creds.azure_blob_storage_account,
identity_references=node_id_ref,
Expand Down
32 changes: 19 additions & 13 deletions pipelines/batch/setup_prod_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,44 +97,50 @@ def main(
mount_pairs=[
{
"source": "nssp-etl",
"target": "/pyrenew-hew/nssp_demo/nssp-etl",
"target": "/pyrenew-hew/nssp-etl",
},
{
"source": "nssp-archival-vintages",
"target": "/pyrenew-hew/nssp_demo/nssp-archival-vintages",
"target": "/pyrenew-hew/nssp-archival-vintages",
},
{
"source": "prod-param-estimates",
"target": "/pyrenew-hew/nssp_demo/params",
"target": "/pyrenew-hew/params",
},
{
"source": "pyrenew-test-output",
"target": "/pyrenew-hew/nssp_demo/private_data",
"source": "pyrenew-hew-prod-output",
"target": "/pyrenew-hew/output",
},
{
"source": "pyrenew-hew-config",
"target": "/pyrenew-hew/config",
},
],
)

base_call = (
"/bin/bash -c '"
"python nssp_demo/forecast_state.py "
"python pipelines/forecast_state.py "
"--disease {disease} "
"--state {state} "
"--n-training-days 75 "
"--n-warmup 1000 "
"--n-samples 500 "
"--facility-level-nssp-data-dir nssp_demo/nssp-etl/gold "
"--facility-level-nssp-data-dir nssp-etl/gold "
"--state-level-nssp-data-dir "
"nssp_demo/nssp-archival-vintages/gold "
"--param-data-dir nssp_demo/params "
"--output-data-dir nssp_demo/private_data "
"nssp-archival-vintages/gold "
"--param-data-dir params "
"--output-data-dir output "
"--priors-path config/prod_priors.py "
"--report-date {report_date} "
"--exclude-last-n-days 5 "
"--score "
"--no-score "
damonbayer marked this conversation as resolved.
Show resolved Hide resolved
"--eval-data-path "
"nssp_demo/nssp-archival-vintages/latest_comprehensive.parquet"
"nssp-archival-vintages/latest_comprehensive.parquet"
"'"
)

# TODO: replace this Census lookup with the forecasttools-py location table
locations = pl.read_csv(
"https://www2.census.gov/geo/docs/reference/state.txt", separator="|"
)
Expand All @@ -143,7 +149,7 @@ def main(
locations.filter(~pl.col("STUSAB").is_in(excluded_locations))
.get_column("STUSAB")
.to_list()
)
) + ["US"]

for disease, state in itertools.product(disease_list, all_locations):
task = get_task_config(
Expand Down
42 changes: 15 additions & 27 deletions pipelines/build_model.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,15 @@
import json
import runpy

import jax.numpy as jnp

# load priors
# have to run this from the right directory
from priors import ( # noqa: E402
autoreg_p_hosp_rv,
autoreg_rt_rv,
eta_sd_rv,
hosp_wday_effect_rv,
i0_first_obs_n_rv,
inf_feedback_strength_rv,
initialization_rate_rv,
log_r_mu_intercept_rv,
p_hosp_mean_rv,
p_hosp_w_sd_rv,
phi_rv,
)
from pyrenew.deterministic import DeterministicVariable

from pyrenew_hew.hosp_only_ww_model import hosp_only_ww_model


def build_model_from_dir(model_dir):
data_path = model_dir / "data_for_model_fit.json"
prior_path = model_dir / "priors.py"

with open(
data_path,
Expand Down Expand Up @@ -62,24 +48,26 @@ def build_model_from_dir(model_dir):
- 1
)

priors = runpy.run_path(prior_path)

right_truncation_offset = model_data["right_truncation_offset"]

my_model = hosp_only_ww_model(
state_pop=state_pop,
i0_first_obs_n_rv=i0_first_obs_n_rv,
initialization_rate_rv=initialization_rate_rv,
log_r_mu_intercept_rv=log_r_mu_intercept_rv,
autoreg_rt_rv=autoreg_rt_rv,
eta_sd_rv=eta_sd_rv, # sd of random walk for ar process,
i0_first_obs_n_rv=priors["i0_first_obs_n_rv"],
initialization_rate_rv=priors["initialization_rate_rv"],
log_r_mu_intercept_rv=priors["log_r_mu_intercept_rv"],
autoreg_rt_rv=priors["autoreg_rt_rv"],
eta_sd_rv=priors["eta_sd_rv"], # standard deviation of the random walk for the AR process
generation_interval_pmf_rv=generation_interval_pmf_rv,
infection_feedback_strength_rv=inf_feedback_strength_rv,
infection_feedback_strength_rv=priors["inf_feedback_strength_rv"],
infection_feedback_pmf_rv=infection_feedback_pmf_rv,
p_hosp_mean_rv=p_hosp_mean_rv,
p_hosp_w_sd_rv=p_hosp_w_sd_rv,
autoreg_p_hosp_rv=autoreg_p_hosp_rv,
hosp_wday_effect_rv=hosp_wday_effect_rv,
p_hosp_mean_rv=priors["p_ed_visit_mean_rv"],
p_hosp_w_sd_rv=priors["p_ed_visit_w_sd_rv"],
damonbayer marked this conversation as resolved.
Show resolved Hide resolved
autoreg_p_hosp_rv=priors["autoreg_p_ed_visit_rv"],
hosp_wday_effect_rv=priors["ed_visit_wday_effect_rv"],
inf_to_hosp_rv=inf_to_hosp_rv,
phi_rv=phi_rv,
phi_rv=priors["phi_rv"],
right_truncation_pmf_rv=right_truncation_pmf_rv,
n_initialization_points=uot,
)
Expand Down
23 changes: 18 additions & 5 deletions pipelines/create_hubverse_table.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,18 @@ draws_to_quantiles <- function(forecast_dir,
return(epiweekly_quantiles)
}

create_hubverse_table <- function(model_run_dir) {
create_hubverse_table <- function(model_run_dir,
exclude = NULL) {
locations_to_process <- fs::dir_ls(model_run_dir,
type = "directory"
)

if (!is.null(exclude)) {
locations_to_process <- locations_to_process[
!(fs::path_file(locations_to_process) %in% exclude)
]
}

report_date <- stringr::str_match(
model_run_dir,
"r_(([0-9]|-)+)_f"
Expand Down Expand Up @@ -131,8 +138,9 @@ create_hubverse_table <- function(model_run_dir) {


main <- function(model_run_dir,
output_path) {
create_hubverse_table(model_run_dir) |>
output_path,
exclude = NULL) {
create_hubverse_table(model_run_dir, exclude = exclude) |>
readr::write_tsv(output_path)
}

Expand All @@ -151,11 +159,16 @@ p <- argparser::arg_parser(
argparser::add_argument(
"output_path",
help = "path to which to save the table"
) |>
argparser::add_argument(
"--exclude",
help = "locations to exclude, as a whitespace-separated string",
default = ""
)

argv <- argparser::parse_args(p)

main(
argv$model_run_dir,
argv$output_path
argv$output_path,
stringr::str_split_1(argv$exclude, " ")
)
24 changes: 10 additions & 14 deletions pipelines/priors.py → pipelines/default_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
from numpyro.infer.reparam import LocScaleReparam
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable

# many of these should probably be different depending
# on if we are modeling flu
# or covid

i0_first_obs_n_rv = DistributionalVariable(
"i0_first_obs_n_rv",
dist.Beta(1, 10),
Expand Down Expand Up @@ -41,28 +37,28 @@
)
# Could be reparameterized?

# Note: multiplied by 1/2 from hosp model
# this actually represents ed visits
p_hosp_mean_rv = DistributionalVariable(
"p_hosp_mean",
p_ed_visit_mean_rv = DistributionalVariable(
"p_ed_visit_mean",
dist.Normal(
transformation.SigmoidTransform().inv(0.005),
0.3,
),
) # logit scale


p_hosp_w_sd_rv = DistributionalVariable(
"p_hosp_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0)
p_ed_visit_w_sd_rv = DistributionalVariable(
"p_ed_visit_w_sd_sd", dist.TruncatedNormal(0, 0.01, low=0)
)


autoreg_p_hosp_rv = DistributionalVariable("autoreg_p_hosp", dist.Beta(1, 100))
autoreg_p_ed_visit_rv = DistributionalVariable(
"autoreg_p_ed_visit_rv", dist.Beta(1, 100)
)

hosp_wday_effect_rv = TransformedVariable(
"hosp_wday_effect",
ed_visit_wday_effect_rv = TransformedVariable(
"ed_visit_wday_effect",
DistributionalVariable(
"hosp_wday_effect_raw",
"ed_visit_wday_effect_raw",
dist.Dirichlet(jnp.array([5, 5, 5, 5, 5, 5, 5])),
),
transformation.AffineTransform(loc=0, scale=7),
Expand Down
Loading
Loading