Skip to content

Commit

Permalink
Add setup_test_data
Browse files Browse the repository at this point in the history
  • Loading branch information
moeyensj committed Nov 2, 2023
1 parent bdda0a0 commit 4a8cc4a
Showing 1 changed file with 101 additions and 85 deletions.
186 changes: 101 additions & 85 deletions thor/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,19 @@ def orbits():
return make_real_orbits()


@pytest.fixture
def integration_config():
config = Config(
vx_bins=10,
vy_bins=10,
vx_min=-0.01,
vx_max=0.01,
vy_min=-0.01,
vy_max=0.01,
)
return config


@pytest.fixture
def ray_cluster():
import ray
Expand All @@ -70,47 +83,84 @@ def ray_cluster():
ray.shutdown()


def test_Orbit_generate_ephemeris_from_observations_empty(orbits):
# Test that when passed empty observations, TestOrbit.generate_ephemeris_from_observations
# returns a Value Error
observations = Observations.empty()
test_orbit = THORbit.from_orbits(orbits[0])
with pytest.raises(ValueError, match="Observations must not be empty."):
test_orbit.generate_ephemeris_from_observations(observations)


@pytest.mark.parametrize("object_id", OBJECT_IDS)
def test_range_and_transform(object_id, orbits, observations):

def setup_test_data(object_id, orbits, observations, integration_config, max_arc_length = None):
"""
Selects the observations and orbit for a given object ID and returns the
test orbit, observations, expected observation IDs and the configuration
for the integration test.
"""
orbit = orbits.select("object_id", object_id)
exposures, detections, associations = observations

# Make THOR observations from the detections and exposures
observations = Observations.from_detections_and_exposures(detections, exposures)

# Select the associations that match this object ID
associations_i = associations.select("object_id", object_id)
detections_i = detections.apply_mask(
pc.is_in(detections.id, associations_i.detection_id)
)
exposures_i = exposures.apply_mask(pc.is_in(exposures.id, detections_i.exposure_id))
assert len(associations_i) == 90

if max_arc_length is not None:
# Limit detections max_arc_length days from the first detection
time_mask = pc.and_(
pc.greater_equal(detections_i.time.days, pc.min(detections_i.time.days)),
pc.less_equal(
detections_i.time.days, pc.min(detections_i.time.days).as_py() + max_arc_length
),
)
detections_i = detections_i.apply_mask(time_mask)
exposures_i = exposures_i.apply_mask(
pc.is_in(exposures_i.id, detections_i.exposure_id)
)
associations_i = associations_i.apply_mask(
pc.is_in(associations_i.detection_id, detections_i.id)
)

# Extract the observations that match this object ID
obs_ids_expected = associations_i.detection_id.unique().sort()

# Filter the observations to include only those that match this object
observations = observations.apply_mask(
pc.is_in(observations.detections.id, obs_ids_expected)
)
# Make THOR observations from the detections and exposures
observations = Observations.from_detections_and_exposures(detections_i, exposures_i)

if object_id in TOLERANCES:
tolerance = TOLERANCES[object_id]
integration_config.cell_radius = TOLERANCES[object_id]
else:
tolerance = TOLERANCES["default"]
integration_config.cell_radius = TOLERANCES["default"]

# Create a test orbit for this object
test_orbit = THORbit.from_orbits(orbit)

return test_orbit, observations, obs_ids_expected, integration_config


def test_Orbit_generate_ephemeris_from_observations_empty(orbits):
# Test that when passed empty observations, TestOrbit.generate_ephemeris_from_observations
# returns a Value Error
observations = Observations.empty()
test_orbit = THORbit.from_orbits(orbits[0])
with pytest.raises(ValueError, match="Observations must not be empty."):
test_orbit.generate_ephemeris_from_observations(observations)


@pytest.mark.parametrize("object_id", OBJECT_IDS)
def test_range_and_transform(object_id, orbits, observations, integration_config):

integration_config.max_processes = 1
(
test_orbit,
observations,
obs_ids_expected,
integration_config,
) = setup_test_data(object_id, orbits, observations, integration_config)

if object_id in TOLERANCES:
integration_config.cell_radius = TOLERANCES[object_id]
else:
integration_config.cell_radius = TOLERANCES["default"]

# Set a filter to include observations within 1 arcsecond of the predicted position
# of the test orbit
filters = [TestOrbitRadiusObservationFilter(radius=tolerance)]
filters = [TestOrbitRadiusObservationFilter(radius=integration_config.cell_radius)]
for filter in filters:
observations = filter.apply(observations, test_orbit)

Expand All @@ -121,17 +171,28 @@ def test_range_and_transform(object_id, orbits, observations):
)
assert len(transformed_detections) == 90
assert pc.all(
pc.less_equal(pc.abs(transformed_detections.coordinates.theta_x), tolerance)
pc.less_equal(pc.abs(transformed_detections.coordinates.theta_x), integration_config.cell_radius)
).as_py()
assert pc.all(
pc.less_equal(pc.abs(transformed_detections.coordinates.theta_y), tolerance)
pc.less_equal(pc.abs(transformed_detections.coordinates.theta_y), integration_config.cell_radius)
).as_py()

# Ensure we get all the object IDs back that we expect
obs_ids_actual = transformed_detections.id.unique().sort()
assert pc.all(pc.equal(obs_ids_actual, obs_ids_expected))


def run_link_test_orbit(test_orbit, observations, config):
for i, results in enumerate(
link_test_orbit(test_orbit, observations, config=config)
):
if i == 4:
recovered_orbits, recovered_orbit_members = results
else:
continue
return recovered_orbits, recovered_orbit_members


@pytest.mark.parametrize(
"object_id",
[
Expand All @@ -152,73 +213,28 @@ def test_range_and_transform(object_id, orbits, observations):
)
@pytest.mark.parametrize("parallelized", [True, False])
@pytest.mark.integration
def test_link_test_orbit(object_id, orbits, observations, parallelized, ray_cluster):
def test_link_test_orbit(
object_id, orbits, observations, parallelized, integration_config, ray_cluster
):

config = Config()
if parallelized:
config.max_processes = 4
else:
config.max_processes = 1

# Reduce the clustering grid size to speed up the test
config.vx_bins = 10
config.vy_bins = 10
config.vx_min = -0.01
config.vx_max = 0.01
config.vy_min = -0.01
config.vy_max = 0.01

orbit = orbits.select("object_id", object_id)
exposures, detections, associations = observations

# Select the associations that match this object ID
associations_i = associations.select("object_id", object_id)
detections_i = detections.apply_mask(
pc.is_in(detections.id, associations_i.detection_id)
)
exposures_i = exposures.apply_mask(pc.is_in(exposures.id, detections_i.exposure_id))
assert len(associations_i) == 90

# Limit detections to first two weeks
time_mask = pc.and_(
pc.greater_equal(detections_i.time.days, pc.min(detections_i.time.days)),
pc.less_equal(
detections_i.time.days, pc.min(detections_i.time.days).as_py() + 14
),
)
detections_i = detections_i.apply_mask(time_mask)
exposures_i = exposures_i.apply_mask(
pc.is_in(exposures_i.id, detections_i.exposure_id)
)
associations_i = associations_i.apply_mask(
pc.is_in(associations_i.detection_id, detections_i.id)
)

# Extract the observations that match this object ID
obs_ids_expected = associations_i.detection_id.unique().sort()

# Make THOR observations from the detections and exposures
observations = Observations.from_detections_and_exposures(detections_i, exposures_i)

if object_id in TOLERANCES:
config.cell_radius = TOLERANCES[object_id]
integration_config.max_processes = 4
else:
config.cell_radius = TOLERANCES["default"]
integration_config.max_processes = 1

# Create a test orbit for this object
test_orbit = THORbit.from_orbits(orbit)
(
test_orbit,
observations,
obs_ids_expected,
integration_config,
) = setup_test_data(object_id, orbits, observations, integration_config, max_arc_length=14)

# Run link_test_orbit and make sure we get the correct observations back
for i, results in enumerate(
link_test_orbit(test_orbit, observations, config=config)
):
if i == 4:
od_orbits, od_orbit_members = results
else:
continue

assert len(od_orbit_members) == len(obs_ids_expected)
recovered_orbits, recovered_orbit_members = run_link_test_orbit(test_orbit, observations, integration_config)
assert len(recovered_orbits) == 1
assert len(recovered_orbit_members) == len(obs_ids_expected)

# Ensure we get all the object IDs back that we expect
obs_ids_actual = od_orbit_members["obs_id"].values
obs_ids_actual = recovered_orbit_members["obs_id"].values
assert pc.all(pc.equal(obs_ids_actual, obs_ids_expected))

0 comments on commit 4a8cc4a

Please sign in to comment.