Skip to content

Commit

Permalink
Merge pull request #51 from pstradio/comparison_plots_update
Browse files Browse the repository at this point in the history
Comparison plots update
  • Loading branch information
pstradio authored May 20, 2022
2 parents dcb08a9 + 91e9b77 commit 7d25793
Show file tree
Hide file tree
Showing 15 changed files with 1,095 additions and 838 deletions.
10 changes: 10 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,16 @@ To remove the environment again, run:
conda deactivate
conda env remove -n qa4sm_reader
Code Formatting
---------------
To apply pep8 conform styling to any changed files [we use `yapf`](https://github.com/google/yapf). The correct
settings are already set in `setup.cfg`. Therefore the following command
should be enough:

.. code::
yapf file.py --in-place
Testing
-------

Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,5 @@
print("Error: __version of setuptools is too old (<58.0)!")
sys.exit(1)


if __name__ == "__main__":
setup(use_pyscaffold=True)
1 change: 0 additions & 1 deletion src/qa4sm_reader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-


# __author__ = "Lukas Racbhauer"
# __copyright__ = "2019, TU Wien, Department of Geodesy and Geoinformation"
# __license__ = "mit"
Expand Down
175 changes: 85 additions & 90 deletions src/qa4sm_reader/comparing.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ class QA4SMComparison:
Class that provides comparison plots and table for a list of netCDF files. As initialising a QA4SMImage can
take some time, the class can be updated keeping memory of what has already been initialized
"""

def __init__(self, paths: list or str, extent: tuple = None, get_intersection: bool = True):
def __init__(self,
paths: list or str,
extent: tuple = None,
get_intersection: bool = True):
"""
Initialise the QA4SMImages from the paths to netCDF files specified
Expand All @@ -45,11 +47,14 @@ def __init__(self, paths: list or str, extent: tuple = None, get_intersection: b
self.paths = paths
self.extent = extent

self.compared = self._init_imgs(extent=extent, get_intersection=get_intersection)
self.compared = self._init_imgs(extent=extent,
get_intersection=get_intersection)
self.ref = self._check_ref()
self.union = not get_intersection

def _init_imgs(self, extent: tuple = None, get_intersection: bool = True) -> list:
def _init_imgs(self,
extent: tuple = None,
get_intersection: bool = True) -> list:
"""
Initialize the QA4SMImages for the selected validation results files. If 'extent' is specified, this is used. If
not, by default the intersection of results is taken and the images are initialized with it, unless 'get_union'
Expand Down Expand Up @@ -124,8 +129,7 @@ def _check_ref(self) -> str:
if not ref == previous:
raise ComparisonError(
"The initialized validation results have different reference "
"datasets. This is currently not supported"
)
"datasets. This is currently not supported")
previous = ref

return ref
Expand All @@ -137,7 +141,9 @@ def common_metrics(self) -> dict:
img_metrics = {}
for metric in img.metrics:
# hardcoded because n_obs cannot be compared. todo: exclude empty metrics (problem: the values are not loaded here)
if metric in glob.metric_groups[0] or metric in ["tau", "p_tau"]:
if metric in glob.metric_groups[0] or metric in [
"tau", "p_tau"
]:
continue
img_metrics[metric] = glob._metric_name[metric]
if n == 0:
Expand All @@ -157,8 +163,8 @@ def overlapping(self) -> bool:
polys = []
for img in self.compared: # get names and extents for all images
minlon, maxlon, minlat, maxlat = img.extent
bounds = [(minlon, minlat), (maxlon, minlat),
(maxlon, maxlat), (minlon, maxlat)]
bounds = [(minlon, minlat), (maxlon, minlat), (maxlon, maxlat),
(minlon, maxlat)]
Pol = Polygon(bounds)
polys.append(Pol)

Expand Down Expand Up @@ -186,18 +192,14 @@ def validation_names(self) -> list:
datasets = img.datasets
if len(datasets.others) == 2:
for n, ds_meta in enumerate(datasets.others):
name = template.format(
n, ds_meta["pretty_title"],
datasets.ref["pretty_title"]
)
name = template.format(n, ds_meta["pretty_title"],
datasets.ref["pretty_title"])
names.append(name)
break
else:
other = img.datasets.others[0]
name = template.format(
n, other["pretty_title"],
img.datasets.ref["pretty_title"]
)
name = template.format(n, other["pretty_title"],
img.datasets.ref["pretty_title"])
names.append(name)

return names
Expand All @@ -219,8 +221,7 @@ def _check_pairwise(self) -> Union[bool, ComparisonError]:
if not pairwise:
raise ComparisonError(
"For pairwise comparison methods, only two "
"validation results with two datasets each can be compared"
)
"validation results with two datasets each can be compared")

def get_reference_points(self) -> tuple:
"""
Expand All @@ -246,12 +247,10 @@ def get_reference_points(self) -> tuple:

return ref_points

def _combine_geometry(
self,
imgs: list,
get_intersection: bool = True,
return_polys=False
) -> tuple:
def _combine_geometry(self,
imgs: list,
get_intersection: bool = True,
return_polys=False) -> tuple:
"""
Return the union or the intersection of the spatial extents of the provided validations; in case of intersection,
check that the validations are overlapping
Expand All @@ -274,8 +273,8 @@ def _combine_geometry(

for n, img in enumerate(imgs):
minlon, maxlon, minlat, maxlat = img.extent
bounds = [(minlon, minlat), (maxlon, minlat),
(maxlon, maxlat), (minlon, maxlat)]
bounds = [(minlon, minlat), (maxlon, minlat), (maxlon, maxlat),
(minlon, maxlat)]
Pol = Polygon(bounds)
name = f"Val{n}: " + img.name
polys[name] = Pol
Expand Down Expand Up @@ -306,9 +305,9 @@ def _combine_geometry(
return minlon, maxlon, minlat, maxlat

def visualize_extent(
self,
intersection: bool = True,
plot_points: bool = False,
self,
intersection: bool = True,
plot_points: bool = False,
):
"""
Method to get and visualize the comparison extent including the reference points.
Expand All @@ -334,16 +333,16 @@ def visualize_extent(
ref_grid_stepsize = self.compared[0].ref_dataset_grid_stepsize

ref = self._check_ref()["short_name"]
plm.plot_spatial_extent(
polys=polys,
ref_points=ref_points,
overlapping=self.overlapping,
intersection_extent=extent,
reg_grid=(ref != "ISMN"),
grid_stepsize=ref_grid_stepsize
)

def _get_data(self, metric: str) -> dict: # todo: use new handlers to get metadata for Variable
plm.plot_spatial_extent(polys=polys,
ref_points=ref_points,
overlapping=self.overlapping,
intersection_extent=extent,
reg_grid=(ref != "ISMN"),
grid_stepsize=ref_grid_stepsize)

def _get_data(
self, metric: str
) -> dict: # todo: use new handlers to get metadata for Variable
"""
Get the list of image Variable names from a metric
Expand All @@ -360,10 +359,8 @@ def _get_data(self, metric: str) -> dict: # todo: use new handlers to get metad
varnames = {"varlist": [], "ci_list": []}
n = 0
for i, img in enumerate(self.compared):
for Var in img._iter_vars(
type="metric",
filter_parms={"metric": metric}
):
for Var in img._iter_vars(type="metric",
filter_parms={"metric": metric}):
var_cis = []
id = i
varname = Var.varname
Expand All @@ -372,16 +369,24 @@ def _get_data(self, metric: str) -> dict: # todo: use new handlers to get metad
if self.single_image:
id = n
col_name = "Val{}: {} ".format(
id, QA4SMPlotter._box_caption(Var, tc=Var.g == 3, short_caption=True)
)
id,
QA4SMPlotter._box_caption(Var,
tc=Var.g == 3,
short_caption=True))

# Remove substrings in TC column names
col_name = col_name.replace("Other Data:",
"").replace("\n", "")

data = data.rename(col_name)
varnames["varlist"].append(data)
n += 1
# get CIs too, if present
for CI_Var in img._iter_vars(
type="ci",
filter_parms={"metric": metric, "metric_ds": Var.metric_ds}
):
for CI_Var in img._iter_vars(type="ci",
filter_parms={
"metric": metric,
"metric_ds": Var.metric_ds
}):
# a bit of necessary code repetition
varname = CI_Var.varname
data = img._ds2df(varnames=[varname])[varname]
Expand Down Expand Up @@ -461,12 +466,10 @@ def _handle_multiindex(self, dfs: list) -> pd.DataFrame:

return pair_df

def _get_pairwise(
self,
metric: str,
add_stats: bool = True,
return_cis=False
) -> pd.DataFrame:
def _get_pairwise(self,
metric: str,
add_stats: bool = True,
return_cis=False) -> pd.DataFrame:
"""
Get the data and names for pairwise comparisons, meaning: two validations with one satellite dataset each. Includes
a method to subset the metric values to the selected spatial extent.
Expand Down Expand Up @@ -496,9 +499,7 @@ def _get_pairwise(

if self.overlapping:
diff = pair_df.iloc[:, 0] - pair_df.iloc[:, 1]
diff = diff.rename(
"Val0 - Val1 (common points)"
)
diff = diff.rename("Val0 - Val1 (common points)")
pair_df = pd.concat([pair_df, diff], axis=1)
if add_stats:
pair_df = self.rename_with_stats(pair_df)
Expand Down Expand Up @@ -530,8 +531,7 @@ def perform_checks(self, overlapping=False, union=False, pairwise=False):
if not self.overlapping:
raise SpatialExtentError(
"This method works only in case the initialized "
"validations have overlapping spatial extents."
)
"validations have overlapping spatial extents.")
# todo: check behavior here if union is initialized through init_union
if union and not self.extent:
if self.union:
Expand All @@ -556,15 +556,12 @@ def diff_table(self, metrics: list) -> pd.DataFrame:
for metric in metrics:
ref = self._check_ref()["short_name"]
units = glob._metric_description[metric].format(
glob.get_metric_units(ref)
)
glob.get_metric_units(ref))
description = glob._metric_name[metric] + units
medians = self._get_pairwise(metric).median()
# a bit of a hack here
table[description] = [
medians[0],
medians[1],
medians[0] - medians[1]
medians[0], medians[1], medians[0] - medians[1]
]
columns = self.validation_names
columns.append("Difference of the medians (0 - 1)")
Expand Down Expand Up @@ -593,7 +590,8 @@ def diff_boxplot(self, metric: str, **kwargs):
# prepare axis name
Metric = QA4SMMetric(metric)
ref_ds = self.ref['short_name']
um = glob._metric_description[metric].format(glob.get_metric_units(ref_ds))
um = glob._metric_description[metric].format(
glob.get_metric_units(ref_ds))
figwidth = glob.boxplot_width * (len(df.columns) + 1)
figsize = [figwidth, glob.boxplot_height]
fig, axes = plm.boxplot(
Expand All @@ -604,8 +602,8 @@ def diff_boxplot(self, metric: str, **kwargs):
)
# titles for the plot
fonts = {"fontsize": 12}
title_plot = "Comparison of {} {}\nagainst the reference {}".format(Metric.pretty_name, um,
self.ref["pretty_title"])
title_plot = "Comparison of {} {}\nagainst the reference {}".format(
Metric.pretty_name, um, self.ref["pretty_title"])
axes.set_title(title_plot, pad=glob.title_pad, **fonts)

plm.make_watermark(fig, glob.watermark_pos, offset=0.04)
Expand All @@ -625,17 +623,17 @@ def diff_mapplot(self, metric: str, **kwargs):
self.perform_checks(overlapping=True, union=True, pairwise=True)
df = self._get_pairwise(metric=metric, add_stats=False).dropna()
Metric = QA4SMMetric(metric)
um = glob._metric_description[metric].format(glob.get_metric_units(self.ref['short_name']))
um = glob._metric_description[metric].format(
glob.get_metric_units(self.ref['short_name']))
# make mapplot
cbar_label = "Difference between {} and {}".format(*df.columns) + f"{um}"

fig, axes = plm.mapplot(
df.iloc[:, 2],
metric=metric,
ref_short=self.ref['short_name'],
diff_map=True,
label=cbar_label
)
cbar_label = "Difference between {} and {}".format(
*df.columns) + f"{um}"

fig, axes = plm.mapplot(df.iloc[:, 2],
metric=metric,
ref_short=self.ref['short_name'],
diff_map=True,
label=cbar_label)
fonts = {"fontsize": 12}
title_plot = f"Overview of the difference in {Metric.pretty_name} " \
f"against the reference {self.ref['pretty_title']}"
Expand All @@ -656,23 +654,20 @@ def wrapper(self, method: str, metric=None, **kwargs):
**kwargs : kwargs
plotting keyword arguments
"""
diff_methods_lut = {'boxplot': self.diff_boxplot,
'mapplot': self.diff_mapplot}
diff_methods_lut = {
'boxplot': self.diff_boxplot,
'mapplot': self.diff_mapplot
}
try:
diff_method = diff_methods_lut[method]
except KeyError as e:
warn(
'Difference method not valid. Choose one of %s' % ', '.join(diff_methods_lut.keys())
)
warn('Difference method not valid. Choose one of %s' %
', '.join(diff_methods_lut.keys()))
raise e

if not metric:
raise ComparisonError(
"If you chose '{}' as a method, you should specify"
" a metric (e.g. 'R').".format(method)
)
" a metric (e.g. 'R').".format(method))

return diff_method(
metric=metric,
**kwargs
)
return diff_method(metric=metric, **kwargs)
Loading

0 comments on commit 7d25793

Please sign in to comment.