Skip to content

Commit

Permalink
added point_sigma_iqr.eval_from_iterator call and a few pragma statem…
Browse files Browse the repository at this point in the history
…ents to get to full coverage
  • Loading branch information
eacharles committed Feb 21, 2024
1 parent eb1557f commit 6887c8d
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 24 deletions.
30 changes: 15 additions & 15 deletions src/qp/metrics/base_metric_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,24 @@ class PointToPointMetric(BaseMetric):

metric_input_type = MetricInputType.point_to_point

def eval_from_iterator(self, estimate, reference):
self.initialize()
for estimate, reference in zip(estimate, reference):
centroids = self.accumulate(estimate, reference)
return self.finalize([centroids])

def evaluate(self, estimate, reference):
raise NotImplementedError()

def initialize(self): #pragma: no cover
pass

def accumulate(self, estimate, reference): #pragma: no cover
raise NotImplementedError()

def finalize(self): #pragma: no cover
raise NotImplementedError()


class PointToDistMetric(BaseMetric):
"""A base class for metrics that require a point estimate as the estimated
Expand All @@ -160,18 +175,3 @@ class PointToDistMetric(BaseMetric):

def evaluate(self, estimate, reference):
raise NotImplementedError()

def eval_from_iterator(self, estimate, reference):
self.initialize()
for estimate, reference in zip(estimate, reference):
self.accumulate(estimate, reference)
return self.finalize()

def initialize(self):
pass

def accumulate(self, estimate, reference):
raise NotImplementedError()

def finalize(self):
raise NotImplementedError()
2 changes: 1 addition & 1 deletion src/qp/metrics/pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, qp_ens, true_vals, eval_grid=DEFAULT_QUANTS):
# efficiently on line 61 with `data_quants = np.nanquantile(...)`.`
samp_mask = np.isfinite(self._pit_samps)
self._pit_samps[~samp_mask] = 0
if not np.all(samp_mask):
if not np.all(samp_mask): #pragma: no cover
logging.warning(
"Some PIT samples were `NaN`. They have been replacd with 0."
)
Expand Down
9 changes: 1 addition & 8 deletions src/qp/metrics/point_estimate_metric_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,6 @@ def __init__(self, tdigest_compression: int = 1000, **kwargs) -> None:
super().__init__()
self._tdigest_compression = tdigest_compression

# ! Not entirely sure this function will be used, will keep it here for now.
def eval_from_iterator(self, estimate, reference):
self.initialize()
for estimate, reference in zip(estimate, reference):
self.accumulate(estimate, reference)
return self.finalize()

def initialize(self):
pass

Expand Down Expand Up @@ -70,7 +63,7 @@ def finalize(self, centroids: np.ndarray = []):

return self.compute_from_digest(digest)

def compute_from_digest(self, digest):
def compute_from_digest(self, digest): #pragma: no cover
raise NotImplementedError


Expand Down
11 changes: 11 additions & 0 deletions tests/qp/test_point_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def construct_test_ensemble():
return zgrid, true_zs, grid_ens, true_ez


#generator that yields chunks from estimate and reference
def chunker(seq, size):
return (seq[pos:pos + size] for pos in range(0, len(seq), size))


class test_point_metrics(unittest.TestCase):

def test_point_metrics(self):
Expand All @@ -42,6 +47,7 @@ def test_point_metrics(self):

ez = PointStatsEz().evaluate(zb, zspec)
assert np.allclose(ez, true_ez, atol=1.0e-2)

# grid limits ez vals to ~10^-2 tol

sig_iqr = PointSigmaIQR().evaluate(zb, zspec)
Expand Down Expand Up @@ -69,6 +75,11 @@ def test_point_metrics_digest(self):
sig_iqr = point_sigma_iqr.finalize([centroids])
assert np.isclose(sig_iqr, SIGIQR, atol=1.0e-4)

zb_iter = chunker(zb, 100)
zspec_iter = chunker(zspec, 100)

sig_iqr_v2 = point_sigma_iqr.eval_from_iterator(zb_iter, zspec_iter)

point_bias = PointBias(**configuration)
centroids = point_bias.accumulate(zb, zspec)
bias = point_bias.finalize([centroids])
Expand Down

0 comments on commit 6887c8d

Please sign in to comment.