Skip to content

Commit

Permalink
add uc support for spec extract, support for matched zoom and uc
Browse files Browse the repository at this point in the history
  • Loading branch information
gibsongreen committed Jan 23, 2025
1 parent 289bf10 commit 2745eef
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 17 deletions.
34 changes: 34 additions & 0 deletions jdaviz/configs/mosviz/plugins/tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os

from glue.config import viewer_tool
from astropy import units as u

from jdaviz.configs.mosviz.plugins.viewers import MosvizProfileView, MosvizProfile2DView
from jdaviz.core.tools import _MatchedZoomMixin, HomeZoom, BoxZoom, XRangeZoom, PanZoom, PanZoomX
Expand All @@ -18,12 +19,45 @@ def _is_matched_viewer(self, viewer):
return isinstance(viewer, (MosvizProfile2DView, MosvizProfileView))

def _map_limits(self, from_viewer, to_viewer, limits={}):
components = self.viewer.state.data_collection[0]._components
# Determine cid for spectral axis
cid = None
for key in components.keys():
if 'Wavelength' in str(key):
cid = str(key)
break
elif 'Wave' in str(key):
cid = str(key)
break

if cid is None:
raise ValueError("Neither 'Wavelength' nor 'Wave' component'" +
"found in the data collection.")

native_unit = u.Unit(self.viewer.state.data_collection[0].get_component(cid).units)
current_display_unit = u.Unit(self.viewer.jdaviz_helper.app._get_display_unit('spectral'))

if isinstance(from_viewer, MosvizProfileView) and isinstance(to_viewer, MosvizProfile2DView): # noqa
if native_unit != current_display_unit:
limits['x_min'] = (limits['x_min'] * native_unit).to_value(
current_display_unit, equivalencies=u.spectral()
)

limits['x_max'] = (limits['x_max'] * native_unit).to_value(
current_display_unit, equivalencies=u.spectral()
)
limits['x_min'], limits['x_max'] = to_viewer.world_to_pixel_limits((limits['x_min'],
limits['x_max']))
elif isinstance(from_viewer, MosvizProfile2DView) and isinstance(to_viewer, MosvizProfileView): # noqa
limits['x_min'], limits['x_max'] = from_viewer.pixel_to_world_limits((limits['x_min'],
limits['x_max']))
if native_unit != current_display_unit:
limits['x_min'] = (limits['x_min'] * native_unit).to_value(
current_display_unit, equivalencies=u.spectral()
)
limits['x_max'] = (limits['x_max'] * native_unit).to_value(
current_display_unit, equivalencies=u.spectral()
)
return limits


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from traitlets import Bool, List, Unicode, observe


from jdaviz.core.events import SnackbarMessage
from jdaviz.core.registries import tray_registry
from jdaviz.core.template_mixin import (PluginTemplateMixin,
Expand Down Expand Up @@ -389,9 +388,10 @@ def _trace_dataset_selected(self, msg=None):
# happens when first initializing plugin outside of tray
return

width = self.trace_dataset.selected_obj.shape[0]
width = self.trace_dataset.get_selected_spectrum(use_display_units=True).shape[0]
# estimate the pixel number by taking the median of the brightest pixel index in each column
brightest_pixel = int(np.median(np.argmax(self.trace_dataset.selected_obj.flux, axis=0)))
trace_flux = self.trace_dataset.get_selected_spectrum(use_display_units=True).flux
brightest_pixel = int(np.median(np.argmax(trace_flux, axis=0)))
# do not allow to be an edge pixel
if brightest_pixel < 1:
brightest_pixel = 1
Expand Down Expand Up @@ -708,6 +708,7 @@ def import_trace(self, trace):
else: # pragma: no cover
raise NotImplementedError(f"trace of type {trace.__class__.__name__} not supported")

# UPDATE HERE
@with_spinner('trace_spinner')
def export_trace(self, add_data=False, **kwargs):
"""
Expand All @@ -728,21 +729,29 @@ def export_trace(self, add_data=False, **kwargs):
# then we're offsetting an existing trace
# for FlatTrace, we can keep and expose a new FlatTrace (which has the advantage of
# being able to load back into the plugin)
orig_trace = self.trace_trace.selected_obj
orig_trace = self.trace_trace.get_selected_spectrum(
self.trace_trace.selected_obj, use_display_units=True
)
if isinstance(orig_trace, tracing.FlatTrace):
trace = tracing.FlatTrace(self.trace_dataset.selected_obj,
trace = tracing.FlatTrace(self.trace_dataset.get_selected_spectrum(
self.trace_dataset, use_display_units=True),
orig_trace.trace_pos+self.trace_offset)
else:
trace = tracing.ArrayTrace(self.trace_dataset.selected_obj,
self.trace_trace.selected_obj.trace+self.trace_offset)
trace = tracing.ArrayTrace(self.trace_dataset.get_selected_spectrum(
self.trace_dataset, use_display_units=True),
self.trace_trace.get_selected_spectrum(
self.trace_trace.selected_obj,
use_display_units=True).trace+self.trace_offset)

elif self.trace_type_selected == 'Flat':
trace = tracing.FlatTrace(self.trace_dataset.selected_obj,
trace = tracing.FlatTrace(self.trace_dataset.get_selected_spectrum(
use_display_units=True),
self.trace_pixel)

elif self.trace_type_selected in _model_cls:
trace_model = _model_cls[self.trace_type_selected](degree=self.trace_order)
trace = tracing.FitTrace(self.trace_dataset.selected_obj,
trace = tracing.FitTrace(self.trace_dataset.get_selected_spectrum(
use_display_units=True),
guess=self.trace_pixel,
bins=int(self.trace_bins) if self.trace_do_binning else None,
window=self.trace_window,
Expand All @@ -762,12 +771,13 @@ def vue_create_trace(self, *args):

def _get_bg_trace(self):
if self.bg_type_selected == 'Manual':
trace = tracing.FlatTrace(self.trace_dataset.selected_obj,
trace = tracing.FlatTrace(self.trace_dataset.get_selected_spectrum(
use_display_units=True),
self.bg_trace_pixel)
elif self.bg_trace_selected == 'From Plugin':
trace = self.export_trace(add_data=False)
else:
trace = self.bg_trace.selected_obj
trace = self.bg_trace.get_selected_spectrum(use_disaply_units=True)

return trace

Expand Down Expand Up @@ -825,17 +835,20 @@ def export_bg(self, **kwargs):
trace = self._get_bg_trace()

if self.bg_type_selected == 'Manual':
bg = background.Background(self.bg_dataset.selected_obj,
bg = background.Background(self.bg_dataset.get_selected_spectrum(
use_display_units=True),
[trace], width=self.bg_width,
statistic=self.bg_statistic.selected.lower())
elif self.bg_type_selected == 'OneSided':
bg = background.Background.one_sided(self.bg_dataset.selected_obj,
bg = background.Background.one_sided(self.bg_dataset.get_selected_spectrum(
use_display_units=True),
trace,
self.bg_separation,
width=self.bg_width,
statistic=self.bg_statistic.selected.lower())
elif self.bg_type_selected == 'TwoSided':
bg = background.Background.two_sided(self.bg_dataset.selected_obj,
bg = background.Background.two_sided(self.bg_dataset.get_selected_spectrum(
use_display_units=True),
trace,
self.bg_separation,
width=self.bg_width,
Expand Down Expand Up @@ -918,13 +931,13 @@ def _get_ext_trace(self):
if self.ext_trace_selected == 'From Plugin':
return self.export_trace(add_data=False)
else:
return self.ext_trace.selected_obj
return self.ext_trace.get_selected_spectrum(use_display_units=True)

def _get_ext_input_spectrum(self):
if self.ext_dataset_selected == 'From Plugin':
return self.export_bg_sub(add_data=False)
else:
return self.ext_dataset.selected_obj
return self.ext_dataset.get_selected_spectrum(use_display_units=True)

def import_extract(self, ext):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from specreduce import tracing, background, extract
from specutils import Spectrum1D

from jdaviz.core.custom_units_and_equivs import SPEC_PHOTON_FLUX_DENSITY_UNITS

GWCS_LT_0_18_1 = Version(gwcs.__version__) < Version('0.18.1')


Expand Down Expand Up @@ -265,3 +267,33 @@ def test_horne_extract_self_profile(specviz2d_helper):
pext.self_prof_interp_degree_y = 0
with pytest.raises(ValueError, match='`self_prof_interp_degree_y` must be greater than 0.'):
sp_ext = pext.export_extract_spectrum()


def test_spectral_extraction_flux_unit_conversions(specviz2d_helper, mos_spectrum2d):
specviz2d_helper.load_data(mos_spectrum2d)

uc = specviz2d_helper.plugins["Unit Conversion"]
pext = specviz2d_helper.plugins['Spectral Extraction']

for new_flux_unit in SPEC_PHOTON_FLUX_DENSITY_UNITS:
# iterate through flux units verifying that selected object/spectrum is obtained using
# display units
uc.flux_unit.selected = new_flux_unit

exported_trace = pext.export_trace()
assert exported_trace.image._unit == specviz2d_helper.app._get_display_unit('flux')

exported_bg = pext.export_bg()
assert exported_bg.image._unit == specviz2d_helper.app._get_display_unit('flux')

exported_bg_img = pext.export_bg_img()
assert exported_bg_img._unit == specviz2d_helper.app._get_display_unit('flux')

exported_bg_sub = pext.export_bg_sub()
assert exported_bg_sub._unit == specviz2d_helper.app._get_display_unit('flux')

exported_extract_spectrum = pext.export_extract_spectrum()
assert exported_extract_spectrum._unit == specviz2d_helper.app._get_display_unit('flux')

exported_extract = pext.export_extract()
assert exported_extract.image._unit == specviz2d_helper.app._get_display_unit('flux')
2 changes: 1 addition & 1 deletion jdaviz/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def _generate_mos_spectrum2d():
'CRVAL1': 0.0, 'CRVAL2': 5.0,
'RADESYS': 'ICRS', 'SPECSYS': 'BARYCENT'}
np.random.seed(42)
data = np.random.sample((1024, 15)) * u.one
data = np.random.sample((1024, 15)) * u.Jy
return data, header


Expand Down
29 changes: 29 additions & 0 deletions jdaviz/core/tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,32 @@ def test_stretch_bounds_click_outside_threshold(imviz_helper):
stretch_tool.on_mouse_event(outside_threshold_msg)
assert po.stretch_vmin.value == initial_vmin
assert po.stretch_vmax.value == initial_vmax


def test_unit_conversion_limits_update(specviz2d_helper, mos_spectrum2d):
specviz2d_helper.load_data(mos_spectrum2d)
uc = specviz2d_helper.plugins['Unit Conversion']

spec_viewer = specviz2d_helper.app.get_viewer(
specviz2d_helper.app._jdaviz_helper._default_spectrum_viewer_reference_name)
spec2d_viewer = specviz2d_helper.app.get_viewer(
specviz2d_helper.app._jdaviz_helper._default_spectrum_2d_viewer_reference_name)

# ensure spectrum and spectrum2d viewer limits matching updates when spectral_unit
# conversion occurs
uc.spectral_unit = 'Hz'

spec_viewer_lims_before = spec_viewer.get_limits()
spec2d_viewer_lims_before = spec2d_viewer.get_limits()

spec_viewer.reset_limits()

# ensure spectral unit conversion occurs when limits are manually changed
assert_allclose(spec_viewer_lims_before, spec_viewer.get_limits())
assert_allclose(spec2d_viewer_lims_before, spec2d_viewer.get_limits())

spec2d_viewer.reset_limits()

# test again when matching viewer's limits are reset
assert_allclose(spec_viewer_lims_before, spec_viewer.get_limits())
assert_allclose(spec2d_viewer_lims_before, spec2d_viewer.get_limits())

0 comments on commit 2745eef

Please sign in to comment.