Skip to content

Commit

Permalink
Refactor some more complicated observation code to be in a 'common' f…
Browse files Browse the repository at this point in the history
…unction library & add some interfaces to make observations more extensible
  • Loading branch information
emilyhunt committed Jan 30, 2025
1 parent a511705 commit b59df8f
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 73 deletions.
7 changes: 6 additions & 1 deletion src/ocelot/model/observation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from ._base import BaseObservation, BaseSelectionFunction # noqa: F401
from ._base import (
BaseObservation, # noqa: F401
BaseSelectionFunction, # noqa: F401
CustomPhotometricMethodObservation, # noqa: F401
CustomAstrometricMethodObservation, # noqa: F401
)
from .subsample_selection import GenericSubsampleSelectionFunction # noqa: F401
from .gaia.gaia_dr3 import GaiaDR3ObservationModel, GaiaDR3SelectionFunction # noqa: F401
63 changes: 29 additions & 34 deletions src/ocelot/model/observation/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from astropy.coordinates import SkyCoord
from astropy.units import Quantity


Expand Down Expand Up @@ -65,40 +64,8 @@ def get_selection_functions(
"""Fetch all selection functions associated with this observation."""
pass

def calculate_resolving_power(
self,
primary: pd.DataFrame,
secondary: pd.DataFrame,
separation: Quantity | None = None,
) -> np.ndarray:
"""Calculates the probability that a given pair of stars would be separately
resolved."""
# Calculate separation manually if not specified
if separation is None:
if "ra" not in primary.columns or "dec" not in primary.columns:
raise ValueError(
"separation not specified, and will instead be calculated manually;"
" however, required columns 'ra' and 'dec' are not in the columns "
"of 'primary'."
)
if "ra" not in secondary.columns or "dec" not in secondary.columns:
raise ValueError(
"separation not specified, and will instead be calculated manually;"
" however, required columns 'ra' and 'dec' are not in the columns "
"of 'secondary'."
)
coord_primary = SkyCoord(
primary["ra"].to_numpy(), primary["dec"].to_numpy(), unit="deg"
)
coord_secondary = SkyCoord(
secondary["ra"].to_numpy(), secondary["dec"].to_numpy(), unit="deg"
)
separation = coord_primary.separation(coord_secondary)

return self._calculate_resolving_power(primary, secondary, separation)

@abstractmethod
def _calculate_resolving_power(
def calculate_resolving_power(
self,
primary: pd.DataFrame,
secondary: pd.DataFrame,
Expand Down Expand Up @@ -158,3 +125,31 @@ def _query(self, observation: pd.DataFrame) -> np.ndarray:
of detecting a given star.
"""
pass


class CustomPhotometricMethodObservation(ABC):
"""Stub abstract base class defining an observation model that implements its own
photometric calculation method. This allows for observations to do more complicated
things than simply defining Gaussian uncertainties on fluxes, for instance.
"""

@abstractmethod
def apply_photometric_errors(
self, cluster: ocelot.simulate.cluster.SimulatedCluster
):
"""Apply photometric errors and save them to the observation."""
pass


class CustomAstrometricMethodObservation(ABC):
"""Stub abstract base class defining an observation model that implements its own
astrometric calculation method. This allows for observations to do more complicated
things than simply defining Gaussian uncertainties on astrometry, for instance.
"""

@abstractmethod
def apply_astrometric_errors(
self, cluster: ocelot.simulate.cluster.SimulatedCluster
):
"""Apply photometric errors and save them to the observation."""
pass
126 changes: 126 additions & 0 deletions src/ocelot/model/observation/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Functions common to many different observation classes."""

from __future__ import annotations
import pandas as pd
from astropy.coordinates import SkyCoord
from astropy.units import Quantity
from ocelot.model.observation import BaseObservation
import ocelot.simulate.cluster


def calculate_separation(primary: pd.DataFrame, secondary: pd.DataFrame) -> Quantity:
"""Calculate the separation between a primary and a secondary list of stars.
Both dataframes must contain the keys 'ra' and 'dec'.
Parameters
----------
primary : pd.DataFrame
The primary dataframe of stars. Must contain 'ra' and 'dec'.
secondary : pd.DataFrame
The secondary dataframe of stars. Must contain 'ra' and 'dec'. Must have the
same length as 'primary'
Returns
-------
separations: Quantity
astropy Quantity array containing separations between stars in the two
specified dataframes.
Raises
------
ValueError
If 'ra' or 'dec' not in the columns of primary or secondary, or if there is a
length mismatch.
"""
# Checks
if "ra" not in primary.columns or "dec" not in primary.columns:
raise ValueError(
"separation not specified, and will instead be calculated manually;"
" however, required columns 'ra' and 'dec' are not in the columns "
"of 'primary'."
)
if "ra" not in secondary.columns or "dec" not in secondary.columns:
raise ValueError(
"separation not specified, and will instead be calculated manually;"
" however, required columns 'ra' and 'dec' are not in the columns "
"of 'secondary'."
)
if len(primary) != len(secondary):
raise ValueError(
"primary and secondary star dataframes must have equal length."
)

# Create skycoords & calculate the sep
coord_primary = SkyCoord(
primary["ra"].to_numpy(), primary["dec"].to_numpy(), unit="deg"
)
coord_secondary = SkyCoord(
secondary["ra"].to_numpy(), secondary["dec"].to_numpy(), unit="deg"
)
return coord_primary.separation(coord_secondary)


def apply_astrometric_errors_simple_gaussian(
cluster: ocelot.simulate.cluster.SimulatedCluster,
model: BaseObservation,
columns: None | list[str] | tuple[str] = None,
):
"""Calculates astrometry sampled from a Gaussian error distribution and adds it
as a column in the relevant observation.
Parameters
----------
cluster : ocelot.simulate.cluster.SimulatedCluster
Simulated cluster to apply to.
model : BaseObservation
Current model being used.
columns : None | list[str] | tuple[str], optional
List or tuple of columns to apply the errors to. Default: None, in which case
proper motion and parallax columns (if present) will have errors applied.
"""
observation = cluster.observations[model.name]

if columns is None:
columns = []
if model.has_parallaxes:
columns.append("parallax")
if model.has_proper_motions:
columns.extend(["pmra", "pmdec"])

for column in columns:
observation[column] = cluster.random_generator.normal(
loc=observation[column].to_numpy(),
scale=observation[f"{column}_error"].to_numpy(),
)


def apply_photometric_errors_simple_gaussian(
cluster: ocelot.simulate.cluster.SimulatedCluster,
model: BaseObservation,
bands: None | list[str] | tuple[str] = None,
):
"""Calculates photometry sampled from a Gaussian error distribution and adds it
as a column in the relevant observation.
Parameters
----------
cluster : ocelot.simulate.cluster.SimulatedCluster
Simulated cluster to apply to.
model : BaseObservation
Current model being used.
bands : None | list[str] | tuple[str], optional
List or tuple of bands to apply the errors to. Default: None, in which case all
bands in model.photometric_band_names have error applied.
"""
if bands is None:
bands = model.photometric_band_names

observation = cluster.observations[model.name]

for band in bands:
new_fluxes = cluster.random_generator.normal(
loc=model.mag_to_flux(observation[band].to_numpy(), band),
scale=observation[f"{band}_flux_error"].to_numpy(),
)
observation[band] = model.flux_to_mag(new_fluxes, band)
36 changes: 24 additions & 12 deletions src/ocelot/model/observation/gaia/gaia_dr3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
"""Main class defining an observation made with Gaia DR3."""

from __future__ import annotations
from ocelot.model.observation._base import BaseObservation, BaseSelectionFunction
from ocelot.model.observation._base import (
BaseObservation,
BaseSelectionFunction,
# CustomPhotometricMethodObservation,
)
import ocelot.simulate.cluster
from scipy.interpolate import interp1d
from gaiaunlimited.selectionfunctions import DR3SelectionFunctionTCG
Expand Down Expand Up @@ -75,34 +79,41 @@ def has_parallaxes(self) -> bool:
def calculate_photometric_errors(
self, cluster: ocelot.simulate.cluster.SimulatedCluster
):
"""Apply photometric errors to a simulated cluster."""
"""Calculate photometric errors for a simulated cluster."""
self._assert_simulated_cluster_not_reused(cluster)
if self.matching_stars is None:
self.matching_stars, self.stars_to_assign = _closest_gaia_star(
cluster.observations["gaia_dr3"], self.representative_stars
)
observation = cluster.observations["gaia_dr3"]

for band in ("g", "bp", "rp"):
observation.loc[self.stars_to_assign, f"gaia_dr3_{band}_flux_error"] = (
self.matching_stars[f"phot_{band}_mean_flux_error"].to_numpy()
)
cluster.observations["gaia_dr3"].loc[
self.stars_to_assign, f"gaia_dr3_{band}_flux_error"
] = self.matching_stars[f"phot_{band}_mean_flux_error"].to_numpy()

def apply_photometric_errors(
self, cluster: ocelot.simulate.cluster.SimulatedCluster
):
"""Custom method to apply photometric errors to a simulated cluster.
Method incorporates the underestimated BP and RP flux measurement issue in DR3.
"""
raise NotImplementedError()

def calculate_astrometric_errors(
self, cluster: ocelot.simulate.cluster.SimulatedCluster
):
"""Apply astrometric errors to a simulated cluster."""
"""Calculate astrometric errors for a simulated cluster."""
self._assert_simulated_cluster_not_reused(cluster)
if self.matching_stars is None:
self.matching_stars, self.stars_to_assign = _closest_gaia_star(
cluster.observations["gaia_dr3"], self.representative_stars
)
observation = cluster.observations["gaia_dr3"]

for column in ("pmra_error", "pmdec_error", "parallax_error"):
observation.loc[self.stars_to_assign, column] = self.matching_stars[
column
].to_numpy()
cluster.observations["gaia_dr3"].loc[self.stars_to_assign, column] = (
self.matching_stars[column].to_numpy()
)

def get_selection_functions(
self, cluster: ocelot.simulate.cluster.SimulatedCluster
Expand All @@ -127,7 +138,7 @@ def calculate_extinction(self, cluster: ocelot.simulate.cluster.SimulatedCluster
observation["extinction"], observation["temperature"]
)

def _calculate_resolving_power(
def calculate_resolving_power(
self,
primary: pd.DataFrame,
secondary: pd.DataFrame,
Expand Down Expand Up @@ -171,6 +182,7 @@ def _check_band_name(self, band: str):
def _assert_simulated_cluster_not_reused(
self, cluster: ocelot.simulate.cluster.SimulatedCluster
):
# Todo this is bad and should be improved lol - model should be cluster-agnostic
if self.simulated_cluster is None:
self.simulated_cluster = cluster
return
Expand Down
Loading

0 comments on commit b59df8f

Please sign in to comment.