Skip to content

Commit

Permalink
Merge branch 'main' into mean_clim_upgrade_lee1043_20240812
Browse files Browse the repository at this point in the history
  • Loading branch information
lee1043 authored Feb 26, 2025
2 parents 32e515b + 6768705 commit f662cdb
Show file tree
Hide file tree
Showing 5 changed files with 332 additions and 25 deletions.
2 changes: 1 addition & 1 deletion pcmdi_metrics/sea_ice/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ A [demo parameter file](https://github.com/PCMDI/pcmdi_metrics/blob/405_sic_ao/p

## Postprocessing

A script is provided to create a multi-model bar chart using results from multiple runs of the sea ice driver. This script can be found in `./scripts/sea_ice_figures.py`.
Two postprocessing scripts are provided in `./scripts/sea_ice_figures.py`. The script `sea_ice_figures.py` creates a multi-model bar chart for all sectors using results from a model ensemble. The script `sea_ice_total_errors.py` plots the total errors for the Arctic and Antarctic for a model ensemble.

Example command:
```
Expand Down
29 changes: 29 additions & 0 deletions pcmdi_metrics/sea_ice/lib/sea_ice_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,35 @@ def process_by_region(ds, ds_var, ds_area, pole):
return clims, means


def get_area(data, ds_area):
xvar = find_lon(data)
coord_i, coord_j = get_xy_coords(data, xvar)
total_area = (data * ds_area).sum((coord_i, coord_j), skipna=True)
if isinstance(total_area.data, dask.array.core.Array):
ta_mean = total_area.data.compute().item()
else:
ta_mean = total_area.data.item()
return ta_mean


def get_ocean_area_for_regions(ds, ds_var, area_val, pole):
# ds should have land/sea mask applied
regions_list = ["arctic", "antarctic", "ca", "na", "np", "sa", "sp", "io"]
areas = {}
# Only want spatial slice
if "time" in ds:
ds = ds.isel({"time": 0})
xvar = find_lon(ds)
yvar = find_lat(ds)
for region in regions_list:
data = choose_region(region, ds, ds_var, xvar, yvar, pole)
tmp = get_area(data, area_val)
areas[region] = tmp
print(tmp)
del data
return areas


def find_lon(ds):
for key in ds.coords:
if key in ["lon", "longitude"]:
Expand Down
186 changes: 186 additions & 0 deletions pcmdi_metrics/sea_ice/scripts/sea_ice_total_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
#!/usr/bin/env python
import argparse
import glob
import json
import os

import matplotlib.pyplot as plt
import numpy as np

# ----------------
# Load Metrics
# ----------------
parser = argparse.ArgumentParser(
prog="sea_ice_figures.py", description="Create figure for sea ice metrics"
)
parser.add_argument(
"--filelist",
dest="filelist",
default="sea_ice_metrics.json",
type=str,
help="Filename of sea ice metrics to glob. Permitted to use '*'",
)
parser.add_argument(
"--output_path",
dest="output_path",
default=".",
type=str,
help="The directory at which to write figure file",
)
args = parser.parse_args()

filelist = args.filelist
metrics_output_path = args.output_path

model_list = []
print(filelist)
metrics = {"RESULTS": {}}
for metrics_file in glob.glob(filelist):
with open(metrics_file) as mf:
results = json.load(mf)
for item in results["DIMENSIONS"]["model"]:
model_list.append(item)
metrics["RESULTS"].update(results["RESULTS"])

model_list.sort()
tmp = model_list[0]
reference_data_set = list(metrics["RESULTS"][tmp]["arctic"]["model_mean"].keys())[0]

# ----------------
# Make figure
# ----------------
sector_list = ["Arctic", "Antarctic"]
sector_short = ["arctic", "antarctic"]
fig7, ax7 = plt.subplots(2, 1, figsize=(5, 4))
mlabels = model_list
ind = np.arange(len(mlabels)) # the x locations for the groups
width = 0.7
n = len(ind)
for inds, sector in enumerate(sector_list):
# Assemble data
mse_clim = []
mse_ext = []
reg_clim = []
reg_ext = []
rgn = sector_short[inds]
for nmod, model in enumerate(model_list):
mse_clim.append(
float(
metrics["RESULTS"][model][rgn]["model_mean"][reference_data_set][
"monthly_clim"
]["mse"]
)
)
mse_ext.append(
float(
metrics["RESULTS"][model][rgn]["model_mean"][reference_data_set][
"total_extent"
]["mse"]
)
)
reg_clim.append(
float(
metrics["RESULTS"][model][rgn]["model_mean"][reference_data_set][
"monthly_clim"
]["sector_mse"]
)
)
reg_ext.append(
float(
metrics["RESULTS"][model][rgn]["model_mean"][reference_data_set][
"total_extent"
]["sector_mse"]
)
)

# plot bars
ax7[inds].bar(
ind,
mse_ext,
width,
color="r",
edgecolor="k",
linewidth=0.1,
label="Ann. Mean",
bottom=np.zeros(np.shape(mse_ext)),
)
ax7[inds].bar(
ind,
mse_clim,
width,
color="b",
edgecolor="k",
linewidth=0.1,
label="Ann. Cycle",
bottom=mse_ext,
)
bottom = [mse_ext[x] + mse_clim[x] for x in range(0, len(mse_ext))]
ax7[inds].bar(
ind,
reg_ext,
width,
color="y",
edgecolor="k",
linewidth=0.1,
label="Ann. Mean Reg.",
bottom=bottom,
)
bottom = [mse_ext[x] + mse_clim[x] + reg_ext[x] for x in range(0, len(mse_ext))]
ax7[inds].bar(
ind,
reg_clim,
width,
color="g",
edgecolor="k",
linewidth=0.1,
label="Ann. Cycle Reg.",
bottom=bottom,
)

# X axis label
if inds == len(sector_list) - 1:
ax7[inds].set_xticks(ind, mlabels, rotation=90, size=4, weight="bold")
else:
ax7[inds].set_xticks(ind, labels="")
ax7[inds].set_xlim(-1, len(mse_ext))

# Y axis
tmp = [
mse_ext[x] + mse_clim[x] + reg_ext[x] + reg_clim[x]
for x in range(0, len(mse_ext))
]
datamax = np.nanmax(np.array(tmp))
ymax = (datamax) * 1.05
ax7[inds].set_ylim(0.0, ymax)
ticks = range(0, round(ymax), 10)
labels = [str(round(x, 0)) for x in ticks]
ax7[inds].set_yticks(ticks, labels, fontsize=5)

# subplot frame styling
ax7[inds].tick_params(color=[0.3, 0.3, 0.3])
for spine in ax7[inds].spines.values():
spine.set_edgecolor([0.3, 0.3, 0.3])
spine.set_linewidth(0.5)
# labels etc
ax7[inds].set_ylabel("10${^1}{^2}$km${^4}$", size=6, weight="bold")
ax7[inds].grid(True, linestyle=":", linewidth=0.5)
ax7[inds].annotate(
sector,
(0.35, 0.85),
xycoords="axes fraction",
size=6,
weight="bold",
bbox=dict(facecolor="white", edgecolor="white", pad=1),
)

# Add legend, save figure
leg = ax7[0].legend(loc="upper right", fontsize=5, edgecolor=[0.3, 0.3, 0.3])
leg.get_frame().set_linewidth(0.5) # legend styling
t = plt.suptitle(
"Mean Square Error relative to " + reference_data_set, fontsize=8, y=0.93
)
plt.tight_layout()
figfile = os.path.join(metrics_output_path, "total_MSE_bar_chart.png")
plt.savefig(figfile, dpi=600)
print("Figure written to ", figfile)
print("Done")
Loading

0 comments on commit f662cdb

Please sign in to comment.