diff --git a/jdaviz/configs/mosviz/plugins/tools.py b/jdaviz/configs/mosviz/plugins/tools.py index bde5a25f82..c95e3a8479 100644 --- a/jdaviz/configs/mosviz/plugins/tools.py +++ b/jdaviz/configs/mosviz/plugins/tools.py @@ -1,6 +1,7 @@ import os from glue.config import viewer_tool +from astropy import units as u from jdaviz.configs.mosviz.plugins.viewers import MosvizProfileView, MosvizProfile2DView from jdaviz.core.tools import _MatchedZoomMixin, HomeZoom, BoxZoom, XRangeZoom, PanZoom, PanZoomX @@ -18,12 +19,45 @@ def _is_matched_viewer(self, viewer): return isinstance(viewer, (MosvizProfile2DView, MosvizProfileView)) def _map_limits(self, from_viewer, to_viewer, limits={}): + components = self.viewer.state.data_collection[0]._components + # Determine cid for spectral axis + cid = None + for key in components.keys(): + if 'Wavelength' in str(key): + cid = str(key) + break + elif 'Wave' in str(key): + cid = str(key) + break + + if cid is None: + raise ValueError("Neither 'Wavelength' nor 'Wave' component'" + + "found in the data collection.") + + native_unit = u.Unit(self.viewer.state.data_collection[0].get_component(cid).units) + current_display_unit = u.Unit(self.viewer.jdaviz_helper.app._get_display_unit('spectral')) + if isinstance(from_viewer, MosvizProfileView) and isinstance(to_viewer, MosvizProfile2DView): # noqa + if native_unit != current_display_unit: + limits['x_min'] = (limits['x_min'] * native_unit).to_value( + current_display_unit, equivalencies=u.spectral() + ) + + limits['x_max'] = (limits['x_max'] * native_unit).to_value( + current_display_unit, equivalencies=u.spectral() + ) limits['x_min'], limits['x_max'] = to_viewer.world_to_pixel_limits((limits['x_min'], limits['x_max'])) elif isinstance(from_viewer, MosvizProfile2DView) and isinstance(to_viewer, MosvizProfileView): # noqa limits['x_min'], limits['x_max'] = from_viewer.pixel_to_world_limits((limits['x_min'], limits['x_max'])) + if native_unit != current_display_unit: + limits['x_min'] = (limits['x_min'] * native_unit).to_value( + current_display_unit, equivalencies=u.spectral() + ) + limits['x_max'] = (limits['x_max'] * native_unit).to_value( + current_display_unit, equivalencies=u.spectral() + ) return limits diff --git a/jdaviz/configs/specviz2d/plugins/spectral_extraction/spectral_extraction.py b/jdaviz/configs/specviz2d/plugins/spectral_extraction/spectral_extraction.py index 05cad4e168..6e0cfb7ab2 100644 --- a/jdaviz/configs/specviz2d/plugins/spectral_extraction/spectral_extraction.py +++ b/jdaviz/configs/specviz2d/plugins/spectral_extraction/spectral_extraction.py @@ -2,7 +2,6 @@ from traitlets import Bool, List, Unicode, observe - from jdaviz.core.events import SnackbarMessage from jdaviz.core.registries import tray_registry from jdaviz.core.template_mixin import (PluginTemplateMixin, @@ -389,9 +388,10 @@ def _trace_dataset_selected(self, msg=None): # happens when first initializing plugin outside of tray return - width = self.trace_dataset.selected_obj.shape[0] + width = self.trace_dataset.get_selected_spectrum(use_display_units=True).shape[0] # estimate the pixel number by taking the median of the brightest pixel index in each column - brightest_pixel = int(np.median(np.argmax(self.trace_dataset.selected_obj.flux, axis=0))) + trace_flux = self.trace_dataset.get_selected_spectrum(use_display_units=True).flux + brightest_pixel = int(np.median(np.argmax(trace_flux, axis=0))) # do not allow to be an edge pixel if brightest_pixel < 1: brightest_pixel = 1 @@ -708,6 +708,7 @@ def import_trace(self, trace): else: # pragma: no cover raise NotImplementedError(f"trace of type {trace.__class__.__name__} not supported") + # UPDATE HERE @with_spinner('trace_spinner') def export_trace(self, add_data=False, **kwargs): """ @@ -728,21 +729,29 @@ def export_trace(self, add_data=False, **kwargs): # then we're offsetting an existing trace # for FlatTrace, we can keep and expose a new FlatTrace (which has the advantage of # being able to load back into the plugin) - orig_trace = self.trace_trace.selected_obj + orig_trace = self.trace_trace.get_selected_spectrum( + self.trace_trace.selected_obj, use_display_units=True + ) if isinstance(orig_trace, tracing.FlatTrace): - trace = tracing.FlatTrace(self.trace_dataset.selected_obj, + trace = tracing.FlatTrace(self.trace_dataset.get_selected_spectrum( + self.trace_dataset, use_display_units=True), orig_trace.trace_pos+self.trace_offset) else: - trace = tracing.ArrayTrace(self.trace_dataset.selected_obj, - self.trace_trace.selected_obj.trace+self.trace_offset) + trace = tracing.ArrayTrace(self.trace_dataset.get_selected_spectrum( + self.trace_dataset, use_display_units=True), + self.trace_trace.get_selected_spectrum( + self.trace_trace.selected_obj, + use_display_units=True).trace+self.trace_offset) elif self.trace_type_selected == 'Flat': - trace = tracing.FlatTrace(self.trace_dataset.selected_obj, + trace = tracing.FlatTrace(self.trace_dataset.get_selected_spectrum( + use_display_units=True), self.trace_pixel) elif self.trace_type_selected in _model_cls: trace_model = _model_cls[self.trace_type_selected](degree=self.trace_order) - trace = tracing.FitTrace(self.trace_dataset.selected_obj, + trace = tracing.FitTrace(self.trace_dataset.get_selected_spectrum( + use_display_units=True), guess=self.trace_pixel, bins=int(self.trace_bins) if self.trace_do_binning else None, window=self.trace_window, @@ -762,12 +771,13 @@ def vue_create_trace(self, *args): def _get_bg_trace(self): if self.bg_type_selected == 'Manual': - trace = tracing.FlatTrace(self.trace_dataset.selected_obj, + trace = tracing.FlatTrace(self.trace_dataset.get_selected_spectrum( + use_display_units=True), self.bg_trace_pixel) elif self.bg_trace_selected == 'From Plugin': trace = self.export_trace(add_data=False) else: - trace = self.bg_trace.selected_obj + trace = self.bg_trace.get_selected_spectrum(use_disaply_units=True) return trace @@ -825,17 +835,20 @@ def export_bg(self, **kwargs): trace = self._get_bg_trace() if self.bg_type_selected == 'Manual': - bg = background.Background(self.bg_dataset.selected_obj, + bg = background.Background(self.bg_dataset.get_selected_spectrum( + use_display_units=True), [trace], width=self.bg_width, statistic=self.bg_statistic.selected.lower()) elif self.bg_type_selected == 'OneSided': - bg = background.Background.one_sided(self.bg_dataset.selected_obj, + bg = background.Background.one_sided(self.bg_dataset.get_selected_spectrum( + use_display_units=True), trace, self.bg_separation, width=self.bg_width, statistic=self.bg_statistic.selected.lower()) elif self.bg_type_selected == 'TwoSided': - bg = background.Background.two_sided(self.bg_dataset.selected_obj, + bg = background.Background.two_sided(self.bg_dataset.get_selected_spectrum( + use_display_units=True), trace, self.bg_separation, width=self.bg_width, @@ -918,13 +931,13 @@ def _get_ext_trace(self): if self.ext_trace_selected == 'From Plugin': return self.export_trace(add_data=False) else: - return self.ext_trace.selected_obj + return self.ext_trace.get_selected_spectrum(use_display_units=True) def _get_ext_input_spectrum(self): if self.ext_dataset_selected == 'From Plugin': return self.export_bg_sub(add_data=False) else: - return self.ext_dataset.selected_obj + return self.ext_dataset.get_selected_spectrum(use_display_units=True) def import_extract(self, ext): """ diff --git a/jdaviz/configs/specviz2d/plugins/spectral_extraction/tests/test_spectral_extraction.py b/jdaviz/configs/specviz2d/plugins/spectral_extraction/tests/test_spectral_extraction.py index 5f4cecd774..c04515fdfe 100644 --- a/jdaviz/configs/specviz2d/plugins/spectral_extraction/tests/test_spectral_extraction.py +++ b/jdaviz/configs/specviz2d/plugins/spectral_extraction/tests/test_spectral_extraction.py @@ -10,6 +10,8 @@ from specreduce import tracing, background, extract from specutils import Spectrum1D +from jdaviz.core.custom_units_and_equivs import SPEC_PHOTON_FLUX_DENSITY_UNITS + GWCS_LT_0_18_1 = Version(gwcs.__version__) < Version('0.18.1') @@ -265,3 +267,33 @@ def test_horne_extract_self_profile(specviz2d_helper): pext.self_prof_interp_degree_y = 0 with pytest.raises(ValueError, match='`self_prof_interp_degree_y` must be greater than 0.'): sp_ext = pext.export_extract_spectrum() + + +def test_spectral_extraction_flux_unit_conversions(specviz2d_helper, mos_spectrum2d): + specviz2d_helper.load_data(mos_spectrum2d) + + uc = specviz2d_helper.plugins["Unit Conversion"] + pext = specviz2d_helper.plugins['Spectral Extraction'] + + for new_flux_unit in SPEC_PHOTON_FLUX_DENSITY_UNITS: + # iterate through flux units verifying that selected object/spectrum is obtained using + # display units + uc.flux_unit.selected = new_flux_unit + + exported_trace = pext.export_trace() + assert exported_trace.image._unit == specviz2d_helper.app._get_display_unit('flux') + + exported_bg = pext.export_bg() + assert exported_bg.image._unit == specviz2d_helper.app._get_display_unit('flux') + + exported_bg_img = pext.export_bg_img() + assert exported_bg_img._unit == specviz2d_helper.app._get_display_unit('flux') + + exported_bg_sub = pext.export_bg_sub() + assert exported_bg_sub._unit == specviz2d_helper.app._get_display_unit('flux') + + exported_extract_spectrum = pext.export_extract_spectrum() + assert exported_extract_spectrum._unit == specviz2d_helper.app._get_display_unit('flux') + + exported_extract = pext.export_extract() + assert exported_extract.image._unit == specviz2d_helper.app._get_display_unit('flux') diff --git a/jdaviz/conftest.py b/jdaviz/conftest.py index ded57b6337..26d4f37465 100644 --- a/jdaviz/conftest.py +++ b/jdaviz/conftest.py @@ -360,7 +360,7 @@ def _generate_mos_spectrum2d(): 'CRVAL1': 0.0, 'CRVAL2': 5.0, 'RADESYS': 'ICRS', 'SPECSYS': 'BARYCENT'} np.random.seed(42) - data = np.random.sample((1024, 15)) * u.one + data = np.random.sample((1024, 15)) * u.Jy return data, header diff --git a/jdaviz/core/tests/test_tools.py b/jdaviz/core/tests/test_tools.py index 35622814b4..beea25e853 100644 --- a/jdaviz/core/tests/test_tools.py +++ b/jdaviz/core/tests/test_tools.py @@ -141,3 +141,32 @@ def test_stretch_bounds_click_outside_threshold(imviz_helper): stretch_tool.on_mouse_event(outside_threshold_msg) assert po.stretch_vmin.value == initial_vmin assert po.stretch_vmax.value == initial_vmax + + +def test_unit_conversion_limits_update(specviz2d_helper, mos_spectrum2d): + specviz2d_helper.load_data(mos_spectrum2d) + uc = specviz2d_helper.plugins['Unit Conversion'] + + spec_viewer = specviz2d_helper.app.get_viewer( + specviz2d_helper.app._jdaviz_helper._default_spectrum_viewer_reference_name) + spec2d_viewer = specviz2d_helper.app.get_viewer( + specviz2d_helper.app._jdaviz_helper._default_spectrum_2d_viewer_reference_name) + + # ensure spectrum and spectrum2d viewer limits matching updates when spectral_unit + # conversion occurs + uc.spectral_unit = 'Hz' + + spec_viewer_lims_before = spec_viewer.get_limits() + spec2d_viewer_lims_before = spec2d_viewer.get_limits() + + spec_viewer.reset_limits() + + # ensure spectral unit conversion occurs when limits are manually changed + assert_allclose(spec_viewer_lims_before, spec_viewer.get_limits()) + assert_allclose(spec2d_viewer_lims_before, spec2d_viewer.get_limits()) + + spec2d_viewer.reset_limits() + + # test again when matching viewer's limits are reset + assert_allclose(spec_viewer_lims_before, spec_viewer.get_limits()) + assert_allclose(spec2d_viewer_lims_before, spec2d_viewer.get_limits())