Skip to content

Commit

Permalink
Adding support for TwoPointMeasurement filters. (#479)
Browse files Browse the repository at this point in the history
* Adding support for TwoPointMeasurement filters.
* Adding tests for filters.
* Removed unused import.
* Using two-point pair as filter specification.
* Adding support for serializing and deserializing filters.
* Adding tests for factories.
* Added check for never reached branch.
* TwoPointFilter documentation first draft
* Simplify logic
* If _path is set do not search current directory
* Release cython version restriction
* Do not load our duplicate_code plugin
* Update finding of SACC files in some tests
* Update version tag
* Refactor for improved test coverage and fix missing error case
* Complete branch coverage
* Improving tutorial.
* Correct serialization for TwoPointFactory.
* Added test for TwoPointFactory serialization.

---------

Co-authored-by: paulrogozenski <[email protected]>
Co-authored-by: Marc Paterno <[email protected]>
  • Loading branch information
3 people authored Feb 13, 2025
1 parent da7bb3a commit 9613906
Show file tree
Hide file tree
Showing 21 changed files with 1,578 additions and 115 deletions.
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
author = "LSST DESC Firecrown Contributors"

# The full version, including alpha/beta/rc tags
release = "1.8.0"
release = "1.9.0a0"


# -- General configuration ---------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ dependencies:
- cosmosis >= 3.0
- cosmosis-build-standard-library
- coverage
- cython < 3.0.0
- cython
- dill
- fitsio
- flake8
Expand Down
2 changes: 1 addition & 1 deletion fctools/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

# some global context to be used in the tracing. We are relying on
# 'trace_call' to act as a closure that captures these names.
tracefile = None # the file used for logging
tracefile: TextIO | None = None # the file used for logging
level = 0 # the call nesting level
entry = 0 # sequential entry number for each record

Expand Down
291 changes: 289 additions & 2 deletions firecrown/data_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,35 @@
"""

import hashlib
from typing import Callable, Sequence
from typing import Callable, Sequence, Annotated
from typing_extensions import assert_never

import sacc
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
model_validator,
PrivateAttr,
field_serializer,
)
import numpy as np
import numpy.typing as npt
import sacc

from firecrown.metadata_types import (
TwoPointHarmonic,
TwoPointReal,
Measurement,
)
from firecrown.metadata_functions import (
extract_all_tracers_inferred_galaxy_zdists,
extract_window_function,
extract_all_harmonic_metadata_indices,
extract_all_real_metadata_indices,
make_two_point_xy,
make_measurement,
make_measurement_dict,
)
from firecrown.data_types import TwoPointMeasurement

Expand Down Expand Up @@ -222,3 +237,275 @@ def check_two_point_consistence_real(
) -> None:
"""Check the indices of the real-space two-point functions."""
check_consistence(two_point_reals, lambda m: m.is_real(), "TwoPointReal")


class TwoPointTracerSpec(BaseModel):
"""Class defining a tracer bin specification."""

model_config = ConfigDict(extra="forbid", frozen=True)

name: Annotated[str, Field(description="The name of the tracer bin.")]
measurement: Annotated[
Measurement,
Field(description="The measurement of the tracer bin."),
BeforeValidator(make_measurement),
]

@field_serializer("measurement")
@classmethod
def serialize_measurement(cls, value: Measurement) -> dict[str, str]:
"""Serialize the Measurement."""
return make_measurement_dict(value)


def make_interval_from_list(
values: list[float] | tuple[float, float],
) -> tuple[float, float]:
"""Create an interval from a list of values."""
if isinstance(values, list):
if len(values) != 2:
raise ValueError("The list should have two values.")
if not all(isinstance(v, float) for v in values):
raise ValueError("The list should have two float values.")

return (values[0], values[1])
if isinstance(values, tuple):
return values

raise ValueError("The values should be a list or a tuple.")


class TwoPointBinFilter(BaseModel):
"""Class defining a filter for a bin."""

model_config = ConfigDict(extra="forbid", frozen=True)

spec: Annotated[
list[TwoPointTracerSpec],
Field(
description="The two-point bin specification.",
),
]
interval: Annotated[
tuple[float, float],
BeforeValidator(make_interval_from_list),
Field(description="The range of the bin to filter."),
]

@model_validator(mode="after")
def check_bin_filter(self) -> "TwoPointBinFilter":
"""Check the bin filter."""
if self.interval[0] >= self.interval[1]:
raise ValueError("The bin filter should be a valid range.")
if not 1 <= len(self.spec) <= 2:
raise ValueError("The bin_spec must contain one or two elements.")
return self

@field_serializer("interval")
@classmethod
def serialize_interval(cls, value: tuple[float, float]) -> list[float]:
"""Serialize the Measurement."""
return list(value)

@classmethod
def from_args(
cls,
name1: str,
measurement1: Measurement,
name2: str,
measurement2: Measurement,
lower: float,
upper: float,
) -> "TwoPointBinFilter":
"""Create a TwoPointBinFilter from the arguments."""
return cls(
spec=[
TwoPointTracerSpec(name=name1, measurement=measurement1),
TwoPointTracerSpec(name=name2, measurement=measurement2),
],
interval=(lower, upper),
)

@classmethod
def from_args_auto(
cls, name: str, measurement: Measurement, lower: float, upper: float
) -> "TwoPointBinFilter":
"""Create a TwoPointBinFilter from the arguments."""
return cls(
spec=[
TwoPointTracerSpec(name=name, measurement=measurement),
],
interval=(lower, upper),
)


BinSpec = frozenset[TwoPointTracerSpec]


def bin_spec_from_metadata(metadata: TwoPointReal | TwoPointHarmonic) -> BinSpec:
"""Return the bin spec from the metadata."""
return frozenset(
(
TwoPointTracerSpec(
name=metadata.XY.x.bin_name,
measurement=metadata.XY.x_measurement,
),
TwoPointTracerSpec(
name=metadata.XY.y.bin_name,
measurement=metadata.XY.y_measurement,
),
)
)


class TwoPointBinFilterCollection(BaseModel):
"""Class defining a collection of bin filters."""

model_config = ConfigDict(extra="forbid", frozen=True)

require_filter_for_all: bool = Field(
default=False,
description="If True, all bins should match a filter.",
)
allow_empty: bool = Field(
default=False,
description=(
"When true, objects with no elements remaining after applying "
"the filter will be ignored rather than treated as an error."
),
)
filters: list[TwoPointBinFilter] = Field(
description="The list of bin filters.",
)

_bin_filter_dict: dict[BinSpec, tuple[float, float]] = PrivateAttr()

@model_validator(mode="after")
def check_bin_filters(self) -> "TwoPointBinFilterCollection":
"""Check the bin filters."""
bin_specs = set()
for bin_filter in self.filters:
bin_spec = frozenset(bin_filter.spec)
if bin_spec in bin_specs:
raise ValueError(
f"The bin name {bin_filter.spec} is repeated "
f"in the bin filters."
)
bin_specs.add(bin_spec)

self._bin_filter_dict = {
frozenset(bin_filter.spec): bin_filter.interval
for bin_filter in self.filters
}
return self

@property
def bin_filter_dict(self) -> dict[BinSpec, tuple[float, float]]:
"""Return the bin filter dictionary."""
return self._bin_filter_dict

def filter_match(self, tpm: TwoPointMeasurement) -> bool:
"""Check if the TwoPointMeasurement matches the filter."""
bin_spec_key = bin_spec_from_metadata(tpm.metadata)
return bin_spec_key in self._bin_filter_dict

def run_bin_filter(
self,
bin_filter: tuple[float, float],
vals: npt.NDArray[np.float64] | npt.NDArray[np.int64],
) -> npt.NDArray[np.bool_]:
"""Run the filter merge."""
return (vals >= bin_filter[0]) & (vals <= bin_filter[1])

def apply_filter_single(
self, tpm: TwoPointMeasurement
) -> tuple[npt.NDArray[np.bool_], npt.NDArray[np.bool_]]:
"""Apply the filter to a single TwoPointMeasurement."""
assert self.filter_match(tpm)
bin_spec_key = bin_spec_from_metadata(tpm.metadata)
bin_filter = self._bin_filter_dict[bin_spec_key]
if tpm.is_real():
assert isinstance(tpm.metadata, TwoPointReal)
match_elements = self.run_bin_filter(bin_filter, tpm.metadata.thetas)
return match_elements, match_elements

assert isinstance(tpm.metadata, TwoPointHarmonic)
match_elements = self.run_bin_filter(bin_filter, tpm.metadata.ells)
match_obs = match_elements
if tpm.metadata.window is not None:
# The window function is represented by a matrix where each column
# corresponds to the weights for the ell values of each observation. We
# need to ensure that the window function is filtered correctly. To do this,
# we will check each column of the matrix and verify that all non-zero
# elements are within the filtered set. If any non-zero element falls
# outside the filtered set, the match_elements will be set to False for that
# observation.
non_zero_window = tpm.metadata.window > 0
match_obs = (
np.all(
(non_zero_window & match_elements[:, None]) == non_zero_window,
axis=0,
)
.ravel()
.astype(np.bool_)
)

return match_elements, match_obs

def __call__(
self, tpms: Sequence[TwoPointMeasurement]
) -> list[TwoPointMeasurement]:
"""Filter the two-point measurements."""
result = []

for tpm in tpms:
if not self.filter_match(tpm):
if not self.require_filter_for_all:
result.append(tpm)
continue
raise ValueError(f"The bin name {tpm.metadata} does not have a filter.")

match_elements, match_obs = self.apply_filter_single(tpm)
if not match_obs.any():
if not self.allow_empty:
# If empty results are not allowed, we raise an error
raise ValueError(
f"The TwoPointMeasurement {tpm.metadata} does not "
f"have any elements matching the filter."
)
# If the filter is empty, we skip this measurement
continue

assert isinstance(tpm.metadata, (TwoPointReal, TwoPointHarmonic))
new_metadata: TwoPointReal | TwoPointHarmonic
match tpm.metadata:
case TwoPointReal():
new_metadata = TwoPointReal(
XY=tpm.metadata.XY,
thetas=tpm.metadata.thetas[match_elements],
)
case TwoPointHarmonic():
# If the window function is not None, we need to filter it as well
# and update the metadata accordingly.
new_metadata = TwoPointHarmonic(
XY=tpm.metadata.XY,
window=(
tpm.metadata.window[:, match_obs][match_elements, :]
if tpm.metadata.window is not None
else None
),
ells=tpm.metadata.ells[match_elements],
)
case _ as unreachable:
assert_never(unreachable)

result.append(
TwoPointMeasurement(
data=tpm.data[match_obs],
indices=tpm.indices[match_obs],
covariance_name=tpm.covariance_name,
metadata=new_metadata,
)
)

return result
Loading

0 comments on commit 9613906

Please sign in to comment.