Skip to content

Commit

Permalink
Add test_benchmark_link_test_orbit
Browse files Browse the repository at this point in the history
  • Loading branch information
moeyensj committed Nov 2, 2023
1 parent 4a8cc4a commit 916a31b
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 11 deletions.
1 change: 1 addition & 0 deletions recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ requirements:
- scipy
- spiceypy
- pytest
- pytest-benchmark
- pytest-cov
- pre-commit

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ scikit-learn
scipy
spiceypy
pytest
pytest-benchmark
pytest-cov
pre-commit
setuptools >= 45
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ install_requires =
[options.extras_require]
tests =
pytest
pytest-benchmark
pytest-cov
pre-commit

Expand Down
56 changes: 45 additions & 11 deletions thor/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def ray_cluster():
ray.shutdown()


def setup_test_data(object_id, orbits, observations, integration_config, max_arc_length = None):
def setup_test_data(
object_id, orbits, observations, integration_config, max_arc_length=None
):
"""
Selects the observations and orbit for a given object ID and returns the
test orbit, observations, expected observation IDs and the configuration
Expand All @@ -105,7 +107,8 @@ def setup_test_data(object_id, orbits, observations, integration_config, max_arc
time_mask = pc.and_(
pc.greater_equal(detections_i.time.days, pc.min(detections_i.time.days)),
pc.less_equal(
detections_i.time.days, pc.min(detections_i.time.days).as_py() + max_arc_length
detections_i.time.days,
pc.min(detections_i.time.days).as_py() + max_arc_length,
),
)
detections_i = detections_i.apply_mask(time_mask)
Expand Down Expand Up @@ -171,10 +174,16 @@ def test_range_and_transform(object_id, orbits, observations, integration_config
)
assert len(transformed_detections) == 90
assert pc.all(
pc.less_equal(pc.abs(transformed_detections.coordinates.theta_x), integration_config.cell_radius)
pc.less_equal(
pc.abs(transformed_detections.coordinates.theta_x),
integration_config.cell_radius,
)
).as_py()
assert pc.all(
pc.less_equal(pc.abs(transformed_detections.coordinates.theta_y), integration_config.cell_radius)
pc.less_equal(
pc.abs(transformed_detections.coordinates.theta_y),
integration_config.cell_radius,
)
).as_py()

# Ensure we get all the object IDs back that we expect
Expand Down Expand Up @@ -222,19 +231,44 @@ def test_link_test_orbit(
else:
integration_config.max_processes = 1

(
test_orbit,
observations,
obs_ids_expected,
integration_config,
) = setup_test_data(object_id, orbits, observations, integration_config, max_arc_length=14)
(test_orbit, observations, obs_ids_expected, integration_config,) = setup_test_data(
object_id, orbits, observations, integration_config, max_arc_length=14
)

# Run link_test_orbit and make sure we get the correct observations back
recovered_orbits, recovered_orbit_members = run_link_test_orbit(test_orbit, observations, integration_config)
recovered_orbits, recovered_orbit_members = run_link_test_orbit(
test_orbit, observations, integration_config
)
assert len(recovered_orbits) == 1
assert len(recovered_orbit_members) == len(obs_ids_expected)

# Ensure we get all the object IDs back that we expect
obs_ids_actual = recovered_orbit_members["obs_id"].values
assert pc.all(pc.equal(obs_ids_actual, obs_ids_expected))


@pytest.mark.parametrize("parallelized", [True, False])
@pytest.mark.benchmark(group="link_test_orbit", min_rounds=5, warmup=True)
def test_benchmark_link_test_orbit(
orbits, observations, integration_config, parallelized, ray_cluster, benchmark
):

object_id = "202930 Ivezic (1998 SG172)"
if parallelized:
integration_config.max_processes = 4
else:
integration_config.max_processes = 1

(test_orbit, observations, obs_ids_expected, integration_config,) = setup_test_data(
object_id, orbits, observations, integration_config, max_arc_length=14
)

recovered_orbits, recovered_orbit_members = benchmark(
run_link_test_orbit, test_orbit, observations, integration_config
)
assert len(recovered_orbits) == 1
assert len(recovered_orbit_members) == len(obs_ids_expected)

# Ensure we get all the object IDs back that we expect
obs_ids_actual = recovered_orbit_members["obs_id"].values
assert pc.all(pc.equal(obs_ids_actual, obs_ids_expected))

0 comments on commit 916a31b

Please sign in to comment.