More cleanup of code
Added a pytest for the delta analysis. Also compute only Sobol indices, not delta moments.
aaschwanden committed Dec 17, 2024
1 parent 470dd65 commit 61314d1
Showing 4 changed files with 259 additions and 50 deletions.
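
The test file itself is not rendered on this page. As a rough sketch of what a pytest for the delta analysis might look like, assuming SALib's delta.analyze (which returns first-order Sobol indices "S1" alongside the delta moment-independent measure) and an illustrative Ishigami setup; none of these names are taken from the diff below:

# Hypothetical sketch, not the committed test: exercises a SALib delta
# analysis but checks only the Sobol indices, mirroring the commit message.
import numpy as np
from SALib.analyze import delta
from SALib.sample import latin
from SALib.test_functions import Ishigami


def test_delta_analysis_recovers_first_order_sobol_indices():
    problem = {
        "num_vars": 3,
        "names": ["x1", "x2", "x3"],
        "bounds": [[-np.pi, np.pi]] * 3,
    }
    X = latin.sample(problem, 2000, seed=42)
    Y = Ishigami.evaluate(X)
    result = delta.analyze(problem, X, Y, seed=42)
    # Analytical first-order Sobol indices for Ishigami: roughly (0.31, 0.44, 0.0)
    np.testing.assert_allclose(result["S1"], [0.31, 0.44, 0.0], atol=0.1)
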
162 changes: 124 additions & 38 deletions analysis/analyze_scalar.py
@@ -22,10 +22,12 @@
 Analyze RAGIS ensemble.
 """
 
+import copy
+import json
 import time
 import warnings
 from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
-from concurrent.futures import ThreadPoolExecutor, as_completed
+from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
 from functools import wraps
 from importlib.resources import files
 from pathlib import Path
@@ -45,7 +47,7 @@
 from tqdm.auto import tqdm
 
 import pism_ragis.processing as prp
-from pism_ragis.analyze import delta_analyze
+from pism_ragis.analyze import delta_analysis, sobol_analysis
 from pism_ragis.decorators import profileit, timeit
 from pism_ragis.filtering import filter_outliers, importance_sampling
 from pism_ragis.likelihood import log_normal
@@ -141,6 +143,7 @@ def run_sampling(
     fudge_factor: float = 3.0,
     fig_dir: Union[str, Path] = "figures",
     params: List[str] = [],
+    config: Dict = {},
 ) -> pd.DataFrame:
     """
     Run sampling to process observed and simulated datasets.
@@ -168,7 +171,7 @@
         The directory where figures will be saved, by default "figures".
     params : List[str], optional
         A list of parameter names to be used for filtering configurations, by default [].
-    ragis_config : Dict, optional
+    config : Dict, optional
         A dictionary containing configuration settings for the RAGIS model, by default {}.
 
     Returns
@@ -236,6 +239,7 @@
         obs_mean_var,
         filter_range=filter_range,
         fig_dir=fig_dir,
+        config=config,
     )
     prior_posterior = pd.concat(prior_posterior_list).reset_index(drop=True)
     prior_posterior = prior_posterior.apply(prp.convert_column_to_numeric)
@@ -355,6 +359,7 @@ def plot_basins(
     >>> plot_basins(observed, prior, posterior, "grounding_line_flux")
     """
     start_time = time.time()
+
     with tqdm(
         desc="Plotting basins",
         total=len(observed.basin),
@@ -379,6 +384,83 @@
             )
             progress_bar.update()
 
+    # with ProcessPoolExecutor(max_workers=options.n_jobs) as executor:
+    #     futures = []
+    #     for basin in observed.basin:
+    #         futures.append(
+    #             executor.submit(
+    #                 plot_obs_sims,
+    #                 observed.sel(basin=basin).sel(
+    #                     {"time": slice(str(plot_range[0]), str(plot_range[1]))}
+    #                 ),
+    #                 prior.sel(basin=basin).sel(
+    #                     {"time": slice(str(plot_range[0]), str(plot_range[1]))}
+    #                 ),
+    #                 posterior.sel(basin=basin).sel(
+    #                     {"time": slice(str(plot_range[0]), str(plot_range[1]))}
+    #                 ),
+    #                 config=config,
+    #                 filtering_var=filtering_var,
+    #                 filter_range=filter_range,
+    #                 fig_dir=fig_dir,
+    #                 obs_alpha=obs_alpha,
+    #                 sim_alpha=sim_alpha,
+    #             )
+    #         )
+    #     for future in tqdm(
+    #         as_completed(futures), total=len(futures), desc="Processing basins"
+    #     ):
+    #         try:
+    #             future.result()
+    #         except Exception as e:
+    #             print(f"An error occurred: {e}")
+
+    # obs_list = []
+    # prior_list = []
+    # posterior_list = []
+    # for basin in observed.basin:
+    #     obs_list.append(
+    #         observed.sel(basin=basin).sel(
+    #             {"time": slice(str(plot_range[0]), str(plot_range[1]))}
+    #         )
+    #     )
+    #     prior_list.append(
+    #         prior.sel(basin=basin).sel(
+    #             {"time": slice(str(plot_range[0]), str(plot_range[1]))}
+    #         )
+    #     )
+
+    #     posterior_list.append(
+    #         posterior.sel(basin=basin).sel(
+    #             {"time": slice(str(plot_range[0]), str(plot_range[1]))}
+    #         )
+    #     )
+
+    # client = Client()
+    # print(f"Open client in browser: {client.dashboard_link}")
+
+    # # obs_scattered = client.scatter(obs_list)
+    # # prior_scattered = client.scatter(prior_list)
+    # # posterior_scattered = client.scatter(posterior_list)
+    # obs_scattered = obs_list
+    # prior_scattered = prior_list
+    # posterior_scattered = posterior_list
+    # futures = client.map(
+    #     plot_obs_sims,
+    #     obs_scattered,
+    #     prior_scattered,
+    #     posterior_scattered,
+    #     config=config,
+    #     filtering_var=filtering_var,
+    #     filter_range=filter_range,
+    #     fig_dir=fig_dir,
+    #     obs_alpha=obs_alpha,
+    #     sim_alpha=sim_alpha,
+    # )
+
+    # progress(futures, notebook=notebook)
+    # client.close()
+
     end_time = time.time()
     elapsed_time = end_time - start_time
     print(f"...took {elapsed_time:.2f}s")
@@ -679,7 +761,7 @@ def plot_outliers(
 
 
 @timeit
-def run_delta_analysis(
+def run_sensitivity_analysis(
     input_df: pd.DataFrame,
     response_ds: xr.Dataset,
     filter_vars: List[str],
@@ -720,7 +802,7 @@
 
     client = Client()
     print(f"Open client in browser: {client.dashboard_link}")
-    all_delta_indices_list = []
+    sensitivity_indices_list = []
     for gdim, df in input_df.groupby(by=group_dim):
         df = df.drop(columns=[group_dim])
         problem = {
@@ -745,28 +827,28 @@
             )
 
             futures = client.map(
-                delta_analyze,
+                delta_analysis,
                 responses_scattered,
-                X=df.to_numpy(),
-                problem=problem,
+                ensemble_df=df,
             )
             progress(futures, notebook=notebook)
             result = client.gather(futures)
 
-            delta_indices = xr.concat(
+            sensitivity_indices = xr.concat(
                 [r.expand_dims(iter_dim) for r in result], dim=iter_dim
             )
-            delta_indices[iter_dim] = responses[iter_dim]
-            delta_indices = delta_indices.expand_dims(group_dim, axis=1)
-            delta_indices[group_dim] = [gdim]
-            delta_indices = delta_indices.expand_dims("filtered_by", axis=2)
-            delta_indices["filtered_by"] = [filter_var]
-            all_delta_indices_list.append(delta_indices)
-
-    all_delta_indices: xr.Dataset = xr.merge(all_delta_indices_list)
+            sensitivity_indices[iter_dim] = responses[iter_dim]
+            sensitivity_indices = sensitivity_indices.expand_dims(group_dim, axis=1)
+            sensitivity_indices[group_dim] = [gdim]
+            sensitivity_indices = sensitivity_indices.expand_dims("filtered_by", axis=2)
+            sensitivity_indices["filtered_by"] = [filter_var]
+            sensitivity_indices_list.append(sensitivity_indices)
+
+    all_sensitivity_indices: xr.Dataset = xr.merge(sensitivity_indices_list)
     client.close()
 
-    return all_delta_indices
+    return all_sensitivity_indices
 
 
 @timeit
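
For reference, the dask pattern inside run_sensitivity_analysis (scatter inputs, map a worker over them, show progress, gather the results) reduces to the following skeleton; the worker function and its inputs are illustrative stand-ins, not the module's own:

# Minimal dask.distributed sketch of the map/progress/gather pattern above.
# `square` and the integer inputs stand in for delta_analysis and the
# scattered response chunks.
from dask.distributed import Client, progress


def square(x):
    return x**2


if __name__ == "__main__":
    client = Client()                       # local cluster with dashboard
    futures = client.map(square, range(8))  # one future per input
    progress(futures)                       # live progress display
    results = client.gather(futures)        # collected in submission order
    client.close()
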
@@ -1124,7 +1206,8 @@ def plot_obs_sims(
     )
     rolling_window = 13
     ragis_config = toml.load(ragis_config_file)
-    params_short_dict = ragis_config["Parameters"]
+    config = json.loads(json.dumps(ragis_config))
+    params_short_dict = config["Parameters"]
     params = list(params_short_dict.keys())
 
     result_dir = Path(options.result_dir)
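
The json.loads(json.dumps(...)) round trip above is a common idiom for deep-copying the parsed TOML configuration while coercing it to plain JSON-compatible Python types; a small illustration with made-up config content:

# Illustration of the deep-copy-via-JSON idiom; the dictionary content is
# invented and not taken from the RAGIS configuration.
import copy
import json

ragis_config = {"Parameters": {"a": 1}, "Flux Variables": {"flux": "Gt year-1"}}
config = json.loads(json.dumps(ragis_config))  # deep copy, JSON-safe types only
assert config == ragis_config and config is not ragis_config
config_alt = copy.deepcopy(ragis_config)       # equivalent for plain dicts
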
@@ -1141,19 +1224,19 @@
 
     plt.rcParams["font.size"] = 6
 
-    flux_vars = ragis_config["Flux Variables"]
+    flux_vars = config["Flux Variables"]
     flux_uncertainty_vars = {
         k + "_uncertainty": v + "_uncertainty" for k, v in flux_vars.items()
     }
 
     simulated_ds = prepare_simulations(
-        basin_files, ragis_config, reference_date, parallel=parallel, engine=engine
+        basin_files, config, reference_date, parallel=parallel, engine=engine
     )
 
     observed_mankoff_ds, observed_grace_ds = prepare_observations(
         options.mankoff_url,
         options.grace_url,
-        ragis_config,
+        config,
         reference_date,
         engine=engine,
     )
@@ -1235,7 +1318,7 @@
         filter_range=filter_range,
         fudge_factor=fudge_factor,
         params=params,
-        ragis_config=ragis_config,
+        config=config,
         fig_dir=fig_dir,
     )

@@ -1248,7 +1331,7 @@
         fudge_factor=10,
         filter_range=filter_range,
         params=params,
-        ragis_config=ragis_config,
+        config=config,
         fig_dir=fig_dir,
     )

@@ -1273,19 +1356,19 @@
         / Path(f"""prior_posterior_{filter_range[0]}-{filter_range[1]}.parquet""")
     )
 
-    bins_dict = ragis_config["Posterior Bins"]
+    bins_dict = config["Posterior Bins"]
     plot_prior_posteriors(
         prior_posterior.rename(columns=params_short_dict),
         bins_dict,
         fig_dir=fig_dir,
-        config=ragis_config,
+        config=config,
     )
 
     prior_config = filter_config(simulated.isel({"time": 0}), params)
     prior_df = config_to_dataframe(prior_config, ensemble="Prior")
     params_df = prepare_input(prior_df)
 
-    all_delta_indices_list = []
+    sensitivity_indices_list = []
     for basin_group, intersection, filtering_vars in zip(
         [simulated_grace_basins_ds, simulated_mankoff_basins_ds],
         [intersection_grace, intersection_mankoff],
@@ -1294,43 +1377,46 @@
         sobol_response_ds = basin_group.sel(time=slice("1980-01-01", "2020-01-01"))
         sobol_input_df = params_df[params_df["basin"].isin(intersection)]
 
-        all_delta_indices_list.append(
-            run_delta_analysis(
+        sensitivity_indices_list.append(
+            run_sensitivity_analysis(
                 sobol_input_df,
                 sobol_response_ds,
                 filtering_vars,
                 notebook=notebook,
             )
         )
 
-    all_delta_indices = xr.concat(all_delta_indices_list, dim="basin")
+    sensitivity_indices = xr.concat(sensitivity_indices_list, dim="basin")
+    si_dir = result_dir / Path("sensitivity_indices")
+    si_dir.mkdir(parents=True, exist_ok=True)
+    sensitivity_indices.to_netcdf(si_dir / Path("sensitivity_indices.nc"))
 
     # Extract the prefix from each coordinate value
     prefixes = [
-        name.split(".")[0] for name in all_delta_indices.pism_config_axis.values
+        name.split(".")[0] for name in sensitivity_indices.pism_config_axis.values
     ]
 
     # Add the prefixes as a new coordinate
-    all_delta_indices = all_delta_indices.assign_coords(
+    sensitivity_indices = sensitivity_indices.assign_coords(
         prefix=("pism_config_axis", prefixes)
     )
 
-    parameter_groups = ragis_config["Parameter Groups"]
-    si_prefixes = [parameter_groups[name] for name in all_delta_indices.prefix.values]
+    parameter_groups = config["Parameter Groups"]
+    si_prefixes = [parameter_groups[name] for name in sensitivity_indices.prefix.values]
 
-    all_delta_indices = all_delta_indices.assign_coords(
+    sensitivity_indices = sensitivity_indices.assign_coords(
         sensitivity_indices_group=("pism_config_axis", si_prefixes)
     )
     # Group by the new coordinate and compute the sum for each group
-    indices_vars = [v for v in all_delta_indices.data_vars if "_conf" not in v]
+    indices_vars = [v for v in sensitivity_indices.data_vars if "_conf" not in v]
     aggregated_indices = (
-        all_delta_indices[indices_vars].groupby("sensitivity_indices_group").sum()
+        sensitivity_indices[indices_vars].groupby("sensitivity_indices_group").sum()
     )
     # Group by the new coordinate, compute the sum of the squares for each
     # group, then take the square root.
-    indices_conf = [v for v in all_delta_indices.data_vars if "_conf" in v]
+    indices_conf = [v for v in sensitivity_indices.data_vars if "_conf" in v]
     aggregated_conf = (
-        all_delta_indices[indices_conf]
+        sensitivity_indices[indices_conf]
         .apply(np.square)
         .groupby("sensitivity_indices_group")
         .sum()
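
The two aggregation steps follow the usual rules for combining per-parameter indices into group totals: the indices themselves add linearly, S_G = sum of S_i over i in G, while confidence intervals add in quadrature, conf_G = sqrt(sum of conf_i**2 over i in G). A toy sketch of the same xarray pattern, with invented names and values:

# Toy sketch of the group aggregation above; the dataset content is invented.
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"S1": ("param", [0.2, 0.1, 0.3]), "S1_conf": ("param", [0.02, 0.01, 0.02])},
    coords={
        "param": ["calving.a", "calving.b", "climate.c"],
        "group": ("param", ["calving", "calving", "climate"]),
    },
)
s1_sum = ds["S1"].groupby("group").sum()  # indices add linearly
s1_conf = np.sqrt((ds["S1_conf"] ** 2).groupby("group").sum())  # quadrature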
