Skip to content

Commit

Permalink
Performance improvements and remove namespace propagators
Browse files Browse the repository at this point in the history
  • Loading branch information
akoumjian authored Dec 4, 2024
1 parent 51972c9 commit 8fda15b
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 94 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pip-build-lint-test-coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
pip install pdm
- name: Install Testing Dependencies
run: |
pdm install -G dev -G test
pdm install -G test
- name: Lint
run: pdm run lint
- name: Test with coverage
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ import numpy as np
from astropy import units as u

from adam_core.orbits.query import query_horizons
from adam_core.propagator.adam_assist import ASSISTPropagator
from adam_assist import ASSISTPropagator
from adam_core.time import Timestamp

# Get orbits to propagate
Expand Down
17 changes: 5 additions & 12 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = [
]
description = "Core libraries for the ADAM platform"
readme = "README.md"
requires-python = ">=3.10"
requires-python = ">=3.10,<3.13"
classifiers = [
"Operating System :: OS Independent",
"Development Status :: 4 - Beta",
Expand Down Expand Up @@ -60,7 +60,6 @@ build-backend = "pdm.backend"


[dependency-groups]
dev = ["ipython>=8.28.0"]
test = [
"pytest-benchmark",
"pytest-cov",
Expand All @@ -70,7 +69,8 @@ test = [
"isort",
"mypy",
"ruff",
"black"
"black",
"adam-assist>=0.2.0",
]

[tool.black]
Expand All @@ -85,11 +85,6 @@ target-version = "py311"
lint.ignore = []
exclude = ["build"]

[tool.pytest.ini_options]
# In order for namespace packages to work during tests,
# we need to import from the installed modules instead of local source
addopts = ["--pyargs", "adam_core"]

[tool.pdm.build]
includes = ["src/adam_core/"]

Expand All @@ -110,11 +105,9 @@ lint = { composite = [
fix = "ruff ./src/adam_core --fix"
typecheck = "mypy --strict ./src/adam_core"

test = "pytest --benchmark-disable {args}"
test = "pytest --benchmark-disable -m 'not profile' {args}"
doctest = "pytest --doctest-plus --doctest-only"
benchmark = "pytest --benchmark-only"
coverage = "pytest --cov=adam_core --cov-report=xml"
coverage = "pytest --cov=adam_core -m 'not profile' --cov-report=xml"


[tool.pdm.dev-dependencies]
dev = ["-e git+https://github.com/B612-Asteroid-Institute/adam-assist.git@main#egg=adam-assist"]
59 changes: 40 additions & 19 deletions src/adam_core/dynamics/ephemeris.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..orbits.ephemeris import Ephemeris
from ..orbits.orbits import Orbits
from .aberrations import _add_light_time, add_stellar_aberration
from .propagation import process_in_chunks


@jit
Expand Down Expand Up @@ -176,6 +177,11 @@ def generate_ephemeris_2body(
ephemeris : `~adam_core.orbits.ephemeris.Ephemeris` (N)
Topocentric ephemerides for each propagated orbit as observed by the given observers.
"""
num_entries = len(observers)
assert (
len(propagated_orbits) == num_entries
), "Orbits and observers must be paired and orbits must be propagated to observer times."

# Transform both the orbits and observers to the barycenter if they are not already.
propagated_orbits_barycentric = propagated_orbits.set_column(
"coordinates",
Expand All @@ -196,26 +202,41 @@ def generate_ephemeris_2body(
),
)

# Stack the observer coordinates and codes for each orbit in the propagated orbits
num_orbits = len(propagated_orbits_barycentric.orbit_id.unique())
observer_coordinates = np.tile(
observers_barycentric.coordinates.values, (num_orbits, 1)
)
observer_codes = np.tile(observers.code.to_numpy(zero_copy_only=False), num_orbits)
observer_coordinates = observers_barycentric.coordinates.values
observer_codes = observers_barycentric.code.to_numpy(zero_copy_only=False)
mu = observers_barycentric.coordinates.origin.mu()
mu = np.tile(mu, num_orbits)

times = propagated_orbits.coordinates.time.to_astropy()
ephemeris_spherical, light_time = _generate_ephemeris_2body_vmap(
propagated_orbits_barycentric.coordinates.values,
times.mjd,
observer_coordinates,
mu,
lt_tol,
max_iter,
tol,
stellar_aberration,
)
times = propagated_orbits.coordinates.time.mjd().to_numpy(zero_copy_only=False)

# Define chunk size
chunk_size = 50

# Process in chunks
ephemeris_chunks = []
light_time_chunks = []

for orbits_chunk, times_chunk, observer_coords_chunk, mu_chunk in zip(
process_in_chunks(propagated_orbits_barycentric.coordinates.values, chunk_size),
process_in_chunks(times, chunk_size),
process_in_chunks(observer_coordinates, chunk_size),
process_in_chunks(mu, chunk_size),
):
ephemeris_chunk, light_time_chunk = _generate_ephemeris_2body_vmap(
orbits_chunk,
times_chunk,
observer_coords_chunk,
mu_chunk,
lt_tol,
max_iter,
tol,
stellar_aberration,
)
ephemeris_chunks.append(ephemeris_chunk)
light_time_chunks.append(light_time_chunk)

# Concatenate chunks and remove padding
ephemeris_spherical = jnp.concatenate(ephemeris_chunks, axis=0)[:num_entries]
light_time = jnp.concatenate(light_time_chunks, axis=0)[:num_entries]

ephemeris_spherical = np.array(ephemeris_spherical)
light_time = np.array(light_time)

Expand Down
56 changes: 48 additions & 8 deletions src/adam_core/dynamics/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,26 @@ def _propagate_2body(
)


def pad_to_fixed_size(array, target_shape, pad_value=0):
"""
Pad an array to a fixed shape with a specified pad value.
"""
pad_width = [(0, max(0, t - s)) for s, t in zip(array.shape, target_shape)]
return jnp.pad(array, pad_width, constant_values=pad_value)


def process_in_chunks(array, chunk_size):
"""
Yield chunks of the array with a fixed size, padding the last chunk if necessary.
"""
n = array.shape[0]
for i in range(0, n, chunk_size):
chunk = array[i : i + chunk_size]
if chunk.shape[0] < chunk_size:
chunk = pad_to_fixed_size(chunk, (chunk_size,) + chunk.shape[1:])
yield chunk


def propagate_2body(
orbits: Orbits,
times: Timestamp,
Expand All @@ -80,7 +100,7 @@ def propagate_2body(
Parameters
----------
orbits : `~jax.numpy.ndarray` (N, 6)
orbits : `~adam_core.orbits.orbits.Orbits` (N)
Cartesian orbits with position in units of au and velocity in units of au per day.
times : Timestamp (M)
Epochs to which to propagate each orbit. If a single epoch is given, all orbits are propagated to this
Expand All @@ -97,16 +117,19 @@ def propagate_2body(
orbits : `~adam_core.orbits.orbits.Orbits` (N*M)
Orbits propagated to each MJD.
"""
# Lets extract the cartesian orbits and times from the orbits object
# Extract and prepare data
cartesian_orbits = orbits.coordinates.values
t0 = orbits.coordinates.time.rescale("tdb").mjd()
t1 = times.rescale("tdb").mjd()
mu = orbits.coordinates.origin.mu()
orbit_ids = orbits.orbit_id.to_numpy(zero_copy_only=False)
object_ids = orbits.object_id.to_numpy(zero_copy_only=False)

# Lets stack the orbits into a single array shaped by the number of orbits and number of times
# and then pass this to the vectorized map version of _propagate_2body
# Define chunk size
chunk_size = 50 # Example chunk size

# Prepare arrays for chunk processing
# This creates a n x m matrix where n is the number of orbits and m is the number of times
n_orbits = cartesian_orbits.shape[0]
n_times = len(times)
orbit_ids_ = np.repeat(orbit_ids, n_times)
Expand All @@ -116,10 +139,24 @@ def propagate_2body(
t0_ = np.repeat(t0, n_times)
t1_ = np.tile(t1, n_orbits)

orbits_propagated = _propagate_2body_vmap(
orbits_array_, t0_, t1_, mu, max_iter, tol
)
orbits_propagated = np.array(orbits_propagated)
# Process in chunks
orbits_propagated_chunks = []
for orbits_chunk, t0_chunk, t1_chunk, mu_chunk in zip(
process_in_chunks(orbits_array_, chunk_size),
process_in_chunks(t0_, chunk_size),
process_in_chunks(t1_, chunk_size),
process_in_chunks(mu, chunk_size),
):
orbits_propagated_chunk = _propagate_2body_vmap(
orbits_chunk, t0_chunk, t1_chunk, mu_chunk, max_iter, tol
)
orbits_propagated_chunks.append(orbits_propagated_chunk)

# Concatenate all chunks
orbits_propagated = jnp.concatenate(orbits_propagated_chunks, axis=0)

# Remove padding
orbits_propagated = orbits_propagated[: n_orbits * n_times]

if not orbits.coordinates.covariance.is_all_nan():
cartesian_covariances = orbits.coordinates.covariance.to_matrix()
Expand All @@ -145,6 +182,9 @@ def propagate_2body(
origin_code = np.empty(n_orbits * n_times, dtype="object")
origin_code.fill("SUN")

# Convert from the jax array to a numpy array
orbits_propagated = np.asarray(orbits_propagated)

orbits_propagated = Orbits.from_kwargs(
orbit_id=orbit_ids_,
object_id=object_ids_,
Expand Down
49 changes: 49 additions & 0 deletions src/adam_core/dynamics/tests/test_ephemeris.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import cProfile

import jax
import numpy as np
import pyarrow.compute as pc
import pytest
Expand All @@ -7,6 +10,7 @@
from ...observers import Observers
from ...time import Timestamp
from ..ephemeris import generate_ephemeris_2body
from ..propagation import propagate_2body

OBJECT_IDS = [
"594913 'Aylo'chaxnim (2020 AV2)",
Expand Down Expand Up @@ -134,3 +138,48 @@ def test_generate_ephemeris_2body(object_id, propagated_orbits, ephemeris):
assert pc.all(pc.is_null(ephemeris_orbit_2body.aberrated_coordinates.vx)).as_py()
assert pc.all(pc.is_null(ephemeris_orbit_2body.aberrated_coordinates.vy)).as_py()
assert pc.all(pc.is_null(ephemeris_orbit_2body.aberrated_coordinates.vz)).as_py()


@pytest.mark.profile
def test_profile_generate_ephemeris_2body_matrix(propagated_orbits, tmp_path):
"""Profile the generate_ephemeris_2body function with different combinations of orbits,
observers and times. Results are saved to a stats file that can be visualized with snakeviz.
"""
# Clear the jax cache
jax.clear_caches()
# Create profiler
profiler = cProfile.Profile(subcalls=True, builtins=True)
profiler.bias = 0

n_entries = [1, 10, 100, 1000]

# create 1000 times, observers, and propagate orbits to those times
times = Timestamp.from_mjd(
np.arange(60000, 60000 + 1000, 1),
scale="tdb",
)
observers = Observers.from_code(
"X05",
times=times,
)
propagated_orbits = propagate_2body(
propagated_orbits[0],
times,
)

def to_profile():
for n_entries_i in n_entries:
generate_ephemeris_2body(
propagated_orbits[:n_entries_i],
observers[:n_entries_i],
)

# Run profiling
profiler.enable()
to_profile()
profiler.disable()

# Save and print results
stats_file = tmp_path / "ephemeris_profile.prof"
profiler.dump_stats(stats_file)
print(f"Run 'snakeviz {stats_file}' to view the profile results.")
52 changes: 52 additions & 0 deletions src/adam_core/dynamics/tests/test_propagation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import cProfile
import itertools

import jax
import numpy as np
import pytest
import spiceypy as sp
from astropy import units as u

Expand Down Expand Up @@ -439,3 +444,50 @@ def test_benchmark_propagate_2body(benchmark, orbital_elements):
scale="tdb",
)
benchmark(propagate_2body, orbits[0], times=times)


@pytest.mark.benchmark(group="propagate_2body")
def test_benchmark_propagate_2body_matrix(benchmark, propagated_orbits):
# Clear the jax cache
jax.clear_caches()

def benchmark_function():
n_orbits = [1, 5, 20]
n_times = [1, 10, 100]

for n_orbits_i, n_times_i in itertools.product(n_orbits, n_times):
times = Timestamp.from_mjd(
np.arange(0, n_times_i, 1),
scale="tdb",
)
propagate_2body(propagated_orbits[:n_orbits_i], times=times)

benchmark(benchmark_function)


@pytest.mark.profile
def test_profile_propagate_2body_matrix(propagated_orbits, tmp_path):
"""Profile the propagate_2body function with different combinations of orbits and times.
Results are saved to a stats file that can be visualized with snakeviz."""
# Clear the jax cache
jax.clear_caches()

# Create profiler
profiler = cProfile.Profile(subcalls=True, builtins=True)
profiler.bias = 0
# Run profiling
profiler.enable()
n_orbits = [1, 5, 20]
n_times = [1, 10, 100]
for n_orbits_i, n_times_i in itertools.product(n_orbits, n_times):
times = Timestamp.from_mjd(
np.arange(0, n_times_i, 1),
scale="tdb",
)
propagate_2body(propagated_orbits[:n_orbits_i], times=times)
profiler.disable()

# Save and print results
stats_file = tmp_path / "precovery_profile.prof"
profiler.dump_stats(stats_file)
print(f"Run 'snakeviz {stats_file}' to view the profile results.")
Loading

0 comments on commit 8fda15b

Please sign in to comment.