Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Fig 8 metrics to sea ice package #1168

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion pcmdi_metrics/sea_ice/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ A [demo parameter file](https://github.com/PCMDI/pcmdi_metrics/blob/405_sic_ao/p

## Postprocessing

A script is provided to create a multi-model bar chart using results from multiple runs of the sea ice driver. This script can be found in `./scripts/sea_ice_figures.py`.
Two postprocessing scripts are provided in `./scripts/sea_ice_figures.py`. The script `sea_ice_figures.py` creates a multi-model bar chart for all sectors using results from a model ensemble. The script `sea_ice_total_errors.py` plots the total errors for the Arctic and Antarctic for a model ensemble.

Example command:
```
Expand Down
29 changes: 29 additions & 0 deletions pcmdi_metrics/sea_ice/lib/sea_ice_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,35 @@ def process_by_region(ds, ds_var, ds_area, pole):
return clims, means


def get_area(data, ds_area):
xvar = find_lon(data)
coord_i, coord_j = get_xy_coords(data, xvar)
total_area = (data * ds_area).sum((coord_i, coord_j), skipna=True)
if isinstance(total_area.data, dask.array.core.Array):
ta_mean = total_area.data.compute().item()
else:
ta_mean = total_area.data.item()
return ta_mean


def get_ocean_area_for_regions(ds, ds_var, area_val, pole):
# ds should have land/sea mask applied
regions_list = ["arctic", "antarctic", "ca", "na", "np", "sa", "sp", "io"]
areas = {}
# Only want spatial slice
if "time" in ds:
ds = ds.isel({"time": 0})
xvar = find_lon(ds)
yvar = find_lat(ds)
for region in regions_list:
data = choose_region(region, ds, ds_var, xvar, yvar, pole)
tmp = get_area(data, area_val)
areas[region] = tmp
print(tmp)
del data
return areas


def find_lon(ds):
for key in ds.coords:
if key in ["lon", "longitude"]:
Expand Down
186 changes: 186 additions & 0 deletions pcmdi_metrics/sea_ice/scripts/sea_ice_total_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
#!/usr/bin/env python
import argparse
import glob
import json
import os

import matplotlib.pyplot as plt
import numpy as np

# ----------------
# Load Metrics
# ----------------
parser = argparse.ArgumentParser(
prog="sea_ice_figures.py", description="Create figure for sea ice metrics"
)
parser.add_argument(
"--filelist",
dest="filelist",
default="sea_ice_metrics.json",
type=str,
help="Filename of sea ice metrics to glob. Permitted to use '*'",
)
parser.add_argument(
"--output_path",
dest="output_path",
default=".",
type=str,
help="The directory at which to write figure file",
)
args = parser.parse_args()

filelist = args.filelist
metrics_output_path = args.output_path

model_list = []
print(filelist)
metrics = {"RESULTS": {}}
for metrics_file in glob.glob(filelist):
with open(metrics_file) as mf:
results = json.load(mf)
for item in results["DIMENSIONS"]["model"]:
model_list.append(item)
metrics["RESULTS"].update(results["RESULTS"])

model_list.sort()
tmp = model_list[0]
reference_data_set = list(metrics["RESULTS"][tmp]["arctic"]["model_mean"].keys())[0]

# ----------------
# Make figure
# ----------------
sector_list = ["Arctic", "Antarctic"]
sector_short = ["arctic", "antarctic"]
fig7, ax7 = plt.subplots(2, 1, figsize=(5, 4))
mlabels = model_list
ind = np.arange(len(mlabels)) # the x locations for the groups
width = 0.7
n = len(ind)
for inds, sector in enumerate(sector_list):
# Assemble data
mse_clim = []
mse_ext = []
reg_clim = []
reg_ext = []
rgn = sector_short[inds]
for nmod, model in enumerate(model_list):
mse_clim.append(
float(
metrics["RESULTS"][model][rgn]["model_mean"][reference_data_set][
"monthly_clim"
]["mse"]
)
)
mse_ext.append(
float(
metrics["RESULTS"][model][rgn]["model_mean"][reference_data_set][
"total_extent"
]["mse"]
)
)
reg_clim.append(
float(
metrics["RESULTS"][model][rgn]["model_mean"][reference_data_set][
"monthly_clim"
]["sector_mse"]
)
)
reg_ext.append(
float(
metrics["RESULTS"][model][rgn]["model_mean"][reference_data_set][
"total_extent"
]["sector_mse"]
)
)

# plot bars
ax7[inds].bar(
ind,
mse_ext,
width,
color="r",
edgecolor="k",
linewidth=0.1,
label="Ann. Mean",
bottom=np.zeros(np.shape(mse_ext)),
)
ax7[inds].bar(
ind,
mse_clim,
width,
color="b",
edgecolor="k",
linewidth=0.1,
label="Ann. Cycle",
bottom=mse_ext,
)
bottom = [mse_ext[x] + mse_clim[x] for x in range(0, len(mse_ext))]
ax7[inds].bar(
ind,
reg_ext,
width,
color="y",
edgecolor="k",
linewidth=0.1,
label="Ann. Mean Reg.",
bottom=bottom,
)
bottom = [mse_ext[x] + mse_clim[x] + reg_ext[x] for x in range(0, len(mse_ext))]
ax7[inds].bar(
ind,
reg_clim,
width,
color="g",
edgecolor="k",
linewidth=0.1,
label="Ann. Cycle Reg.",
bottom=bottom,
)

# X axis label
if inds == len(sector_list) - 1:
ax7[inds].set_xticks(ind, mlabels, rotation=90, size=4, weight="bold")
else:
ax7[inds].set_xticks(ind, labels="")
ax7[inds].set_xlim(-1, len(mse_ext))

# Y axis
tmp = [
mse_ext[x] + mse_clim[x] + reg_ext[x] + reg_clim[x]
for x in range(0, len(mse_ext))
]
datamax = np.nanmax(np.array(tmp))
ymax = (datamax) * 1.05
ax7[inds].set_ylim(0.0, ymax)
ticks = range(0, round(ymax), 10)
labels = [str(round(x, 0)) for x in ticks]
ax7[inds].set_yticks(ticks, labels, fontsize=5)

# subplot frame styling
ax7[inds].tick_params(color=[0.3, 0.3, 0.3])
for spine in ax7[inds].spines.values():
spine.set_edgecolor([0.3, 0.3, 0.3])
spine.set_linewidth(0.5)
# labels etc
ax7[inds].set_ylabel("10${^1}{^2}$km${^4}$", size=6, weight="bold")
ax7[inds].grid(True, linestyle=":", linewidth=0.5)
ax7[inds].annotate(
sector,
(0.35, 0.85),
xycoords="axes fraction",
size=6,
weight="bold",
bbox=dict(facecolor="white", edgecolor="white", pad=1),
)

# Add legend, save figure
leg = ax7[0].legend(loc="upper right", fontsize=5, edgecolor=[0.3, 0.3, 0.3])
leg.get_frame().set_linewidth(0.5) # legend styling
t = plt.suptitle(
"Mean Square Error relative to " + reference_data_set, fontsize=8, y=0.93
)
plt.tight_layout()
figfile = os.path.join(metrics_output_path, "total_MSE_bar_chart.png")
plt.savefig(figfile, dpi=600)
print("Figure written to ", figfile)
print("Done")
Loading