diff --git a/src/qp/metrics/base_metric_classes.py b/src/qp/metrics/base_metric_classes.py index 49ba655..5e29e80 100644 --- a/src/qp/metrics/base_metric_classes.py +++ b/src/qp/metrics/base_metric_classes.py @@ -147,9 +147,24 @@ class PointToPointMetric(BaseMetric): metric_input_type = MetricInputType.point_to_point + def eval_from_iterator(self, estimate, reference): + self.initialize() + for estimate, reference in zip(estimate, reference): + centroids = self.accumulate(estimate, reference) + return self.finalize([centroids]) + def evaluate(self, estimate, reference): raise NotImplementedError() + def initialize(self): #pragma: no cover + pass + + def accumulate(self, estimate, reference): #pragma: no cover + raise NotImplementedError() + + def finalize(self): #pragma: no cover + raise NotImplementedError() + class PointToDistMetric(BaseMetric): """A base class for metrics that require a point estimate as the estimated @@ -160,18 +175,3 @@ class PointToDistMetric(BaseMetric): def evaluate(self, estimate, reference): raise NotImplementedError() - - def eval_from_iterator(self, estimate, reference): - self.initialize() - for estimate, reference in zip(estimate, reference): - self.accumulate(estimate, reference) - return self.finalize() - - def initialize(self): - pass - - def accumulate(self, estimate, reference): - raise NotImplementedError() - - def finalize(self): - raise NotImplementedError() diff --git a/src/qp/metrics/pit.py b/src/qp/metrics/pit.py index e19cc6e..f434240 100644 --- a/src/qp/metrics/pit.py +++ b/src/qp/metrics/pit.py @@ -57,7 +57,7 @@ def __init__(self, qp_ens, true_vals, eval_grid=DEFAULT_QUANTS): # efficiently on line 61 with `data_quants = np.nanquantile(...)`.` samp_mask = np.isfinite(self._pit_samps) self._pit_samps[~samp_mask] = 0 - if not np.all(samp_mask): + if not np.all(samp_mask): #pragma: no cover logging.warning( "Some PIT samples were `NaN`. They have been replacd with 0." ) diff --git a/src/qp/metrics/point_estimate_metric_classes.py b/src/qp/metrics/point_estimate_metric_classes.py index 7c1df08..7856813 100644 --- a/src/qp/metrics/point_estimate_metric_classes.py +++ b/src/qp/metrics/point_estimate_metric_classes.py @@ -14,13 +14,6 @@ def __init__(self, tdigest_compression: int = 1000, **kwargs) -> None: super().__init__() self._tdigest_compression = tdigest_compression - # ! Not entirely sure this function will be used, will keep it here for now. - def eval_from_iterator(self, estimate, reference): - self.initialize() - for estimate, reference in zip(estimate, reference): - self.accumulate(estimate, reference) - return self.finalize() - def initialize(self): pass @@ -70,7 +63,7 @@ def finalize(self, centroids: np.ndarray = []): return self.compute_from_digest(digest) - def compute_from_digest(self, digest): + def compute_from_digest(self, digest): #pragma: no cover raise NotImplementedError diff --git a/tests/qp/test_point_metrics.py b/tests/qp/test_point_metrics.py index 7672392..abec9db 100644 --- a/tests/qp/test_point_metrics.py +++ b/tests/qp/test_point_metrics.py @@ -33,6 +33,11 @@ def construct_test_ensemble(): return zgrid, true_zs, grid_ens, true_ez +#generator that yields chunks from estimate and reference +def chunker(seq, size): + return (seq[pos:pos + size] for pos in range(0, len(seq), size)) + + class test_point_metrics(unittest.TestCase): def test_point_metrics(self): @@ -42,6 +47,7 @@ def test_point_metrics(self): ez = PointStatsEz().evaluate(zb, zspec) assert np.allclose(ez, true_ez, atol=1.0e-2) + # grid limits ez vals to ~10^-2 tol sig_iqr = PointSigmaIQR().evaluate(zb, zspec) @@ -69,6 +75,11 @@ def test_point_metrics_digest(self): sig_iqr = point_sigma_iqr.finalize([centroids]) assert np.isclose(sig_iqr, SIGIQR, atol=1.0e-4) + zb_iter = chunker(zb, 100) + zspec_iter = chunker(zspec, 100) + + sig_iqr_v2 = point_sigma_iqr.eval_from_iterator(zb_iter, zspec_iter) + point_bias = PointBias(**configuration) centroids = point_bias.accumulate(zb, zspec) bias = point_bias.finalize([centroids])