Skip to content

Commit

Permalink
Reinstates _sanitize_inputs
Browse files Browse the repository at this point in the history
Reinstates _sanitize_inputs method to handle units correctly for astropy fitting.
  • Loading branch information
jajmitchell committed Feb 19, 2025
1 parent 3b0b363 commit 9fd4ce9
Showing 1 changed file with 27 additions and 36 deletions.
63 changes: 27 additions & 36 deletions sunkit_spex/models/physical/thermal.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,6 @@ class ThermalEmission(FittableModel):
fixed=True
)



input_units_equivalencies = {'keV': u.spectral(),'K':u.temperature_energy()}
_input_units_allow_dimensionless = True

Expand All @@ -178,7 +176,6 @@ def __init__(self,
ar=ar,ca=ca,fe=fe,
**kwargs)

# def evaluate(self, spectrum, temperature, emission_measure, observer_distance):
def evaluate(self, energy_edges, temperature, emission_measure, observer_distance,
mg,si,s,ar,ca,fe):

Expand Down Expand Up @@ -467,34 +464,10 @@ def thermal_emission(
To change the file used, see the setup_continuum_parameters() function.
{doc_string_params}"""
# Convert inputs to known units and confirm they are within range.

# energy_edges_keV, temperature_K = _sanitize_inputs(energy_edges, temperature)

if energy_edges.isscalar or len(energy_edges) < 2 or energy_edges.ndim > 1:
raise ValueError("energy_edges must be a 1-D astropy Quantity with length greater than 1.")

if isinstance(energy_edges, Quantity):
energy_edges_keV = energy_edges.to(u.keV)
else:
energy_edges_keV = energy_edges

if isinstance(temperature, Quantity):
temperature_K = temperature.to(u.K)
else:
temperature_K = temperature

if temperature.isscalar:
temperature_K = np.array([temperature_K.value]) * u.K

if isinstance(emission_measure, Quantity):
emission_measure = emission_measure.to(u.cm**-3)

if isinstance(observer_distance, Quantity):
observer_distance = observer_distance.to(u.cm)
else:
observer_distance = (observer_distance*u.AU).to(u.cm).value

# Sanitize inputs
energy_edges_keV, temperature_K, emission_measure, observer_distance = _sanitize_inputs(energy_edges, temperature,
emission_measure,observer_distance)

energy_range = (
min(CONTINUUM_GRID["energy range keV"][0], LINE_GRID["energy range keV"][0]),
Expand Down Expand Up @@ -1045,15 +1018,33 @@ def _weight_emission_bins(
return new_line_intensities, neighbor_intensities, neighbor_iline


def _sanitize_inputs(energy_edges, temperature):
# Convert inputs to known units and confirm they are within range.
def _sanitize_inputs(energy_edges, temperature, emission_measure, observer_distance):

if energy_edges.isscalar or len(energy_edges) < 2 or energy_edges.ndim > 1:
raise ValueError("energy_edges must be a 1-D astropy Quantity with length greater than 1.")
energy_edges_keV = energy_edges.to_value(u.keV)
temperature_K = temperature.to_value(u.K)

if isinstance(energy_edges, Quantity):
energy_edges_keV = energy_edges.to(u.keV)
else:
energy_edges_keV = energy_edges

if isinstance(temperature, Quantity):
temperature_K = temperature.to(u.K)
else:
temperature_K = temperature

if temperature.isscalar:
temperature_K = np.array([temperature_K])
return energy_edges_keV, temperature_K
temperature_K = np.array([temperature_K.value]) * u.K

if isinstance(emission_measure, Quantity):
emission_measure = emission_measure.to(u.cm**-3)

if isinstance(observer_distance, Quantity):
observer_distance = observer_distance.to(u.cm)
else:
observer_distance = (observer_distance*u.AU).to(u.cm).value

return energy_edges_keV, temperature_K, emission_measure, observer_distance


def _error_if_input_outside_valid_range(input_values, grid_range, param_name, param_unit):
Expand Down

0 comments on commit 9fd4ce9

Please sign in to comment.