Skip to content

Commit

Permalink
Add more benchmarks (#387)
Browse files Browse the repository at this point in the history
* Try slower plot again to see how it does now with performance
improvements from #370

* Don't need both fixtures in this test any more

* Don't need this old stuff either

* Use more points for pixel_to_world benchmark

* Add some more benchmarks

* Need units on parameters for updating vct some of the time

Not sure how this snuck past the tests in the first place

* Benchmarks are taking too long again

* Nope too slow

* Benchmark dataset slicing

* Add changelog

* Try slower plot again to see how it does now with performance
improvements from #370

* Don't need both fixtures in this test any more

* Don't need this old stuff either

* Use more points for pixel_to_world benchmark

* Add some more benchmarks

* Need units on parameters for updating vct some of the time

Not sure how this snuck past the tests in the first place

* Benchmarks are taking too long again

* Nope too slow

* Benchmark dataset slicing

* Add changelog

* Update dkist/wcs/models.py

Co-authored-by: Stuart Mumford <[email protected]>

* Add tests to hit update_celestial_transform and make sure the units are
working sensibly

* Correct test comparison value

---------

Co-authored-by: Stuart Mumford <[email protected]>
  • Loading branch information
SolarDrew and Cadair authored Jul 18, 2024
1 parent f2d8e79 commit dab245f
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 6 deletions.
1 change: 1 addition & 0 deletions changelog/387.trivial.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add some more benchmarks to track performance of more parts of the user tools.
76 changes: 72 additions & 4 deletions dkist/tests/test_benchmarks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
from numpy.random import default_rng

import astropy.units as u
from astropy.modeling.models import Tabular1D

from dkist import load_dataset
from dkist.wcs.models import (Ravel, generate_celestial_transform,
update_celestial_transform_parameters)


@pytest.mark.benchmark
Expand All @@ -11,11 +17,8 @@ def test_load_asdf(benchmark, large_visp_dataset_file):


@pytest.mark.benchmark
def test_pixel_to_world(benchmark, visp_dataset_no_headers, large_visp_dataset):
def test_pixel_to_world(benchmark, visp_dataset_no_headers):
ds = visp_dataset_no_headers
# pxcoords2 = []
# for size in ds2.wcs.pixel_shape:
# pxcoords2.append(np.arange(size))

pxcoords = np.mgrid[:ds.wcs.pixel_shape[0]:50,
:ds.wcs.pixel_shape[1]:50,
Expand All @@ -35,3 +38,68 @@ def plot_and_save_fig(ds=visp_dataset_no_headers, axes=axes):
ds.plot(plot_axes=axes)
plt.savefig("tmpplot")
plt.close()


@pytest.mark.benchmark
def test_generate_celestial(benchmark):
benchmark(generate_celestial_transform,
crpix=[0, 0] * u.pix,
crval=[0, 0] * u.arcsec,
cdelt=[1, 1] * u.arcsec/u.pix,
pc=np.identity(2) * u.pix,
)


@pytest.mark.benchmark
def test_update_celestial(benchmark):
trsfm = generate_celestial_transform(
crpix=[0, 0] * u.pix,
crval=[0, 0] * u.arcsec,
cdelt=[1, 1] * u.arcsec/u.pix,
pc=np.identity(2) * u.pix)

benchmark(update_celestial_transform_parameters,
trsfm,
[1, 1] * u.pix,
[0.5, 0.5] * u.arcsec/u.pix,
np.identity(2) * u.pix,
[1, 1] * u.arcsec,
180 * u.deg,
)


@pytest.mark.benchmark
def test_raveled_tab1d_model(benchmark):
ndim = 3
rng = default_rng()
array_shape = rng.integers(1, 21, ndim)
array_bounds = array_shape - 1
ravel = Ravel(array_shape)
nelem = np.prod(array_shape)
units = u.pix
values = np.arange(nelem) * units
lut_values = values
tabular = Tabular1D(
values,
lut_values,
bounds_error=False,
fill_value=np.nan,
method="linear",
)
raveled_tab = ravel | tabular
# adding the new axis onto array_bounds makes broadcasting work below
array_bounds = array_bounds[:, np.newaxis]
# use 5 as an arbitrary number of inputs
random_number_shape = len(array_shape), 5
random_numbers = rng.random(random_number_shape)
raw_inputs = random_numbers * array_bounds
inputs = tuple(raw_inputs * units)

benchmark(raveled_tab, *inputs)


@pytest.mark.benchmark
def test_slice_dataset(benchmark, large_visp_dataset):
@benchmark
def slice_dataset(dataset=large_visp_dataset, idx = np.s_[:2, 10:15, 0]):
sliced = dataset[idx]
2 changes: 1 addition & 1 deletion dkist/wcs/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def update_celestial_transform_parameters(
-crpix[0],
-crpix[1],
pc,
transform[2].translation.value,
transform[2].translation.quantity if hasattr(pc, "unit") else transform[2].translation.value,
cdelt[0],
cdelt[1],
crval[0],
Expand Down
40 changes: 39 additions & 1 deletion dkist/wcs/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from dkist.wcs.models import (AsymmetricMapping, Ravel, Unravel, VaryingCelestialTransform,
VaryingCelestialTransform2D, VaryingCelestialTransform3D,
generate_celestial_transform,
generate_celestial_transform, update_celestial_transform_parameters,
varying_celestial_transform_from_tables)


Expand Down Expand Up @@ -52,6 +52,44 @@ def test_generate_celestial_unitless():
assert u.allclose(shift1.offset, 0)


def test_update_celestial():
trsfm = generate_celestial_transform(
crpix=[0, 0] * u.pix,
crval=[0, 0] * u.arcsec,
cdelt=[1, 1] * u.arcsec/u.pix,
pc=np.identity(2) * u.pix)

update_celestial_transform_parameters(
trsfm,
[1, 1] * u.pix,
[0.5, 0.5] * u.arcsec/u.pix,
np.identity(2) * u.pix,
[1, 1] * u.arcsec,
180 * u.deg)

# Copout and only test that one parameter is correct
shift1 = trsfm.left.left.left.left.right
assert u.allclose(shift1.offset.quantity, -1 * u.pix)

def test_update_celestial_unitless():
trsfm = generate_celestial_transform(
crpix=[0, 0],
crval=[0, 0],
cdelt=[1, 1],
pc=np.identity(2))

update_celestial_transform_parameters(
trsfm,
[1, 1],
[0.5, 0.5],
np.identity(2),
[1, 1],
180)

shift1 = trsfm.left.left.left.left.right
assert u.allclose(shift1.offset.value, -1)


def test_varying_transform_no_lon_pole_unit():
varying_matrix_lt = [rotation_matrix(a)[:2, :2] for a in np.linspace(0, 90, 10)] * u.pix
# Without a lon_pole passed, the transform was originally setting
Expand Down

0 comments on commit dab245f

Please sign in to comment.