From ee492bf3e0f92e7dc8d72dccdc4ff326128ada0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Zabalza?= Date: Sun, 16 Jul 2017 12:08:27 +0100 Subject: [PATCH 1/8] Add proton synchrotron --- naima/models.py | 10 +- naima/radiative.py | 385 ++++++++++++++++++++++++++++--------- naima/tests/test_models.py | 62 +++++- setup.py | 4 +- 4 files changed, 365 insertions(+), 96 deletions(-) diff --git a/naima/models.py b/naima/models.py index 3071d000..7df27f91 100644 --- a/naima/models.py +++ b/naima/models.py @@ -8,12 +8,16 @@ from astropy.utils.data import get_pkg_data_filename from .extern.validator import (validate_scalar, validate_array, validate_physical_type) -from .radiative import Synchrotron, InverseCompton, PionDecay, Bremsstrahlung +from .radiative import ( + Synchrotron, ElectronSynchrotron, ProtonSynchrotron, + InverseCompton, PionDecay, Bremsstrahlung +) from .model_utils import memoize __all__ = [ - 'Synchrotron', 'InverseCompton', 'PionDecay', 'Bremsstrahlung', - 'BrokenPowerLaw', 'ExponentialCutoffPowerLaw', 'PowerLaw', 'LogParabola', + 'Synchrotron', 'ElectronSynchrotron', 'ProtonSynchrotron', + 'InverseCompton', 'PionDecay', 'Bremsstrahlung', 'BrokenPowerLaw', + 'ExponentialCutoffPowerLaw', 'PowerLaw', 'LogParabola', 'ExponentialCutoffBrokenPowerLaw', 'TableModel', 'EblAbsorptionModel' ] diff --git a/naima/radiative.py b/naima/radiative.py index f7cdd50d..48280aed 100644 --- a/naima/radiative.py +++ b/naima/radiative.py @@ -3,6 +3,10 @@ from __future__ import (absolute_import, division, print_function, unicode_literals) import numpy as np + +from numtraits import NumericalTrait +from traitlets import HasTraits, Int, observe + from .extern.validator import (validate_scalar, validate_array, validate_physical_type) @@ -75,16 +79,23 @@ def __init__(self, particle_distribution): physical_type='differential energy') except (AttributeError, TypeError): # otherwise check the output - pd = self.particle_distribution([ - 0.1, - 1, - 10, - ] * u.TeV) + pd = self.particle_distribution([0.1, 1, 10] * u.TeV) validate_physical_type( 'Particle distribution', pd, physical_type='differential energy') + def _spectrum(self, photon_energy): + """ + Compute photon spectrum. Implemented in subclasses + + Parameters + ---------- + photon_energy : :class:`~astropy.units.Quantity` instance + Photon energy array. + """ + raise NotImplementedError + @memoize def flux(self, photon_energy, distance=1 * u.kpc): """Differential flux at a given distance from the source. 
@@ -136,13 +147,18 @@ def sed(self, photon_energy, distance=1 * u.kpc):
         return sed
 
 
-class BaseElectron(BaseRadiative):
-    """Implements gam and nelec properties in addition to the BaseRadiative methods
+class BaseLorentzFactor(BaseRadiative):
+    """
+    Implements gam and npart properties in addition to the BaseRadiative
+    methods
     """
 
-    def __init__(self, particle_distribution):
-        super(BaseElectron, self).__init__(particle_distribution)
-        self.param_names = ['Eemin', 'Eemax', 'nEed']
+    def __init__(self, particle_distribution, mass):
+        super(BaseLorentzFactor, self).__init__(particle_distribution)
+        self.param_names = ['gmin', 'gmax', 'ngd']
+        mass = validate_scalar('mass', mass, physical_type='mass')
+        self.mc2 = (mass * c**2).cgs
+        self.mc2_unit = u.Unit(self.mc2)
         self._memoize = True
         self._cache = {}
         self._queue = []
@@ -151,68 +167,70 @@ def __init__(self, particle_distribution):
     def _gam(self):
         """ Lorentz factor array
        """
-        log10gmin = np.log10(self.Eemin / mec2).value
-        log10gmax = np.log10(self.Eemax / mec2).value
+        log10gmin = np.log10(self.gmin)
+        log10gmax = np.log10(self.gmax)
         return np.logspace(log10gmin, log10gmax,
-                           self.nEed * (log10gmax - log10gmin))
+                           self.ngd * (log10gmax - log10gmin))
 
     @property
-    def _nelec(self):
+    def _npart(self):
         """ Particles per unit lorentz factor
         """
-        pd = self.particle_distribution(self._gam * mec2)
-        return pd.to(1 / mec2_unit).value
+        pd = self.particle_distribution(self._gam * self.mc2)
+        return pd.to(1 / self.mc2_unit).value
 
     @property
-    def We(self):
-        """ Total energy in electrons used for the radiative calculation
+    def _W(self):
+        """ Total energy in particles used for the radiative calculation
         """
-        We = trapz_loglog(self._gam * self._nelec, self._gam * mec2)
-        return We
+        gam = self._gam
+        return trapz_loglog(gam * self._npart, gam * self.mc2)
 
-    def compute_We(self, Eemin=None, Eemax=None):
-        """ Total energy in electrons between energies Eemin and Eemax
+    def _compute_W(self, Emin=None, Emax=None):
+        """ Total energy in particles between energies Emin and Emax
 
         Parameters
         ----------
-        Eemin : :class:`~astropy.units.Quantity` float, optional
-            Minimum electron energy for energy content calculation.
+        Emin : :class:`~astropy.units.Quantity` float, optional
+            Minimum particle energy for energy content calculation.
 
-        Eemax : :class:`~astropy.units.Quantity` float, optional
-            Maximum electron energy for energy content calculation.
+        Emax : :class:`~astropy.units.Quantity` float, optional
+            Maximum particle energy for energy content calculation.
""" - if Eemin is None and Eemax is None: - We = self.We + if Emin is None and Emax is None: + W = self.W else: - if Eemax is None: - Eemax = self.Eemax - if Eemin is None: - Eemin = self.Eemin + if Emin is None: + Emin = self.gmin * self.mc2 + if Emax is None: + Emax = self.gmax * self.mc2 - log10gmin = np.log10(Eemin / mec2) - log10gmax = np.log10(Eemax / mec2) + log10gmin = np.log10(Emin / self.mc2).value + log10gmax = np.log10(Emax / self.mc2).value gam = np.logspace(log10gmin, log10gmax, - self.nEed * (log10gmax - log10gmin)) - nelec = self.particle_distribution(gam * mec2).to(1 / - mec2_unit).value - We = trapz_loglog(gam * nelec, gam * mec2) + self.ngd * (log10gmax - log10gmin)) - return We + pd = self.particle_distribution(self._gam * self.mc2) + npart = pd.to(1 / self.mc2_unit).value - def set_We(self, We, Eemin=None, Eemax=None, amplitude_name=None): - """ Normalize particle distribution so that the total energy in electrons - between Eemin and Eemax is We + W = trapz_loglog(gam * npart, gam * self.mc2) + + return W + + def _set_W(self, W, Emin=None, Emax=None, amplitude_name=None): + """ Normalize particle distribution so that the total energy in + particles between Emin and Emax is W Parameters ---------- - We : :class:`~astropy.units.Quantity` float - Desired energy in electrons. + W : :class:`~astropy.units.Quantity` float + Desired energy in particles. - Eemin : :class:`~astropy.units.Quantity` float, optional - Minimum electron energy for energy content calculation. + Emin : :class:`~astropy.units.Quantity` float, optional + Minimum particle energy for energy content calculation. - Eemax : :class:`~astropy.units.Quantity` float, optional - Maximum electron energy for energy content calculation. + Emax : :class:`~astropy.units.Quantity` float, optional + Maximum particle energy for energy content calculation. amplitude_name : str, optional Name of the amplitude parameter of the particle distribution. It @@ -220,67 +238,161 @@ def set_We(self, We, Eemin=None, Eemax=None, amplitude_name=None): Defaults to ``amplitude``. """ - We = validate_scalar('We', We, physical_type='energy') - oldWe = self.compute_We(Eemin=Eemin, Eemax=Eemax) + W = validate_scalar('W', W, physical_type='energy') + oldW = self._compute_W(Emin=Emin, Emax=Emax) + factor = (W / oldW).decompose() if amplitude_name is None: try: - self.particle_distribution.amplitude *= ( - We / oldWe).decompose() + self.particle_distribution.amplitude *= factor except AttributeError: log.error( 'The particle distribution does not have an attribute' ' called amplitude to modify its normalization: you can' - ' set the name with the amplitude_name parameter of set_We' + ' set the name with the amplitude_name parameter of set_W' ) else: oldampl = getattr(self.particle_distribution, amplitude_name) setattr(self.particle_distribution, amplitude_name, - oldampl * (We / oldWe).decompose()) + oldampl * factor) -class Synchrotron(BaseElectron): - """Synchrotron emission from an electron population. +class BaseElectron(BaseLorentzFactor, HasTraits): + """ + Sets particle mass of BaseLorentzFactor to the electron mass + """ - This class uses the approximation of the synchrotron emissivity in a - random magnetic field of Aharonian, Kelner, and Prosekin 2010, PhysRev D - 82, 3002 (`arXiv:1006.1045 `_). 
+    def __init__(self, particle_distribution):
+        super(BaseElectron, self).__init__(particle_distribution, mass=m_e)
 
-    Parameters
-    ----------
-    particle_distribution : function
-        Particle distribution function, taking electron energies as a
-        `~astropy.units.Quantity` array or float, and returning the particle
-        energy density in units of number of electrons per unit energy as a
-        `~astropy.units.Quantity` array or float.
+    Eemin = NumericalTrait(convertible_to=u.erg)
+    Eemax = NumericalTrait(convertible_to=u.erg)
+    nEed = Int()
 
-    B : :class:`~astropy.units.Quantity` float instance, optional
-        Isotropic magnetic field strength. Default: equipartition
-        with CMB (3.24e-6 G)
+    @observe('Eemin')
+    def _handle_Eemin(self, change):
+        self.gmin = float(change['new'] / self.mc2)
 
-    Other parameters
-    ----------------
-    Eemin : :class:`~astropy.units.Quantity` float instance, optional
-        Minimum electron energy for the electron distribution. Default is 1
-        GeV.
+    @observe('Eemax')
+    def _handle_Eemax(self, change):
+        self.gmax = float(change['new'] / self.mc2)
 
-    Eemax : :class:`~astropy.units.Quantity` float instance, optional
-        Maximum electron energy for the electron distribution. Default is 510
-        TeV.
+    @observe('nEed')
+    def _handle_nEed(self, change):
+        self.ngd = change['new']
 
-    nEed : scalar
-        Number of points per decade in energy for the electron energy and
-        distribution arrays. Default is 100.
+    @property
+    def We(self):
+        """ Total energy in electrons used for the radiative calculation
+        """
+        return self._W
+
+    def set_We(self, We, Eemin=None, Eemax=None, amplitude_name=None):
+        """ Normalize particle distribution so that the total energy in
+        electrons between `Eemin` and `Eemax` is `We`
+
+        Parameters
+        ----------
+        We : :class:`~astropy.units.Quantity` float
+            Desired energy in electrons.
+
+        Eemin : :class:`~astropy.units.Quantity` float, optional
+            Minimum electron energy for energy content calculation.
+
+        Eemax : :class:`~astropy.units.Quantity` float, optional
+            Maximum electron energy for energy content calculation.
+
+        amplitude_name : str, optional
+            Name of the amplitude parameter of the particle distribution. It
+            must be accessible as an attribute of the distribution function.
+            Defaults to ``amplitude``.
+        """
+        return self._set_W(We, Emin=Eemin, Emax=Eemax,
+                           amplitude_name=amplitude_name)
+
+    def compute_We(self, Eemin=None, Eemax=None):
+        """ Total energy in electrons between energies Eemin and Eemax
+
+        Parameters
+        ----------
+        Eemin : :class:`~astropy.units.Quantity` float, optional
+            Minimum electron energy for energy content calculation.
+
+        Eemax : :class:`~astropy.units.Quantity` float, optional
+            Maximum electron energy for energy content calculation.
+        """
+        return self._compute_W(Emin=Eemin, Emax=Eemax)
+
+
+class BaseLorentzProton(BaseLorentzFactor, HasTraits):
+    """
+    Sets particle mass of BaseLorentzFactor to the proton mass
     """
 
-    def __init__(self, particle_distribution, B=3.24e-6 * u.G, **kwargs):
-        super(Synchrotron, self).__init__(particle_distribution)
-        self.B = validate_scalar('B', B, physical_type='magnetic flux density')
-        self.Eemin = 1 * u.GeV
-        self.Eemax = 1e9 * mec2
-        self.nEed = 100
-        self.param_names += ['B']
-        self.__dict__.update(**kwargs)
+    def __init__(self, particle_distribution):
+        super(BaseLorentzProton, self).__init__(particle_distribution,
+                                                mass=m_p)
+
+    Epmin = NumericalTrait(convertible_to=u.erg)
+    Epmax = NumericalTrait(convertible_to=u.erg)
+    nEpd = Int()
+
+    @observe('Epmin')
+    def _handle_Epmin(self, change):
+        self.gmin = float(change['new'] / self.mc2)
+
+    @observe('Epmax')
+    def _handle_Epmax(self, change):
+        self.gmax = float(change['new'] / self.mc2)
+
+    @observe('nEpd')
+    def _handle_nEpd(self, change):
+        self.ngd = change['new']
+
+    @property
+    def Wp(self):
+        """ Total energy in protons used for the radiative calculation
+        """
+        return self._W
+
+    def set_Wp(self, Wp, Epmin=None, Epmax=None, amplitude_name=None):
+        """ Normalize particle distribution so that the total energy in
+        protons between `Epmin` and `Epmax` is `Wp`
+
+        Parameters
+        ----------
+        Wp : :class:`~astropy.units.Quantity` float
+            Desired energy in protons.
+
+        Epmin : :class:`~astropy.units.Quantity` float, optional
+            Minimum proton energy for energy content calculation.
+
+        Epmax : :class:`~astropy.units.Quantity` float, optional
+            Maximum proton energy for energy content calculation.
+
+        amplitude_name : str, optional
+            Name of the amplitude parameter of the particle distribution. It
+            must be accessible as an attribute of the distribution function.
+            Defaults to ``amplitude``.
+        """
+        return self._set_W(Wp, Emin=Epmin, Emax=Epmax,
+                           amplitude_name=amplitude_name)
+
+    def compute_Wp(self, Epmin=None, Epmax=None):
+        """ Total energy in protons between energies Epmin and Epmax
+
+        Parameters
+        ----------
+        Epmin : :class:`~astropy.units.Quantity` float, optional
+            Minimum proton energy for energy content calculation.
+
+        Epmax : :class:`~astropy.units.Quantity` float, optional
+            Maximum proton energy for energy content calculation.
+        """
+        return self._compute_W(Emin=Epmin, Emax=Epmax)
+
+
+class BaseSynchrotron(BaseLorentzFactor):
 
     def _spectrum(self, photon_energy):
         """Compute intrinsic synchrotron differential spectrum for energies in
@@ -320,23 +432,115 @@ def Gtilde(x):
         # when using cgs (SI is fine, see
         # https://github.com/astropy/astropy/issues/1687)
         CS1_0 = np.sqrt(3) * e.value**3 * self.B.to('G').value
-        CS1_1 = (2 * np.pi * m_e.cgs.value * c.cgs.value
-                 ** 2 * hbar.cgs.value * outspecene.to('erg').value)
+        CS1_1 = (2 * np.pi * self.mc2.cgs.value
+                 * hbar.cgs.value * outspecene.to('erg').value)
         CS1 = CS1_0 / CS1_1
 
         # Critical energy, erg
         Ec = 3 * e.value * hbar.cgs.value * self.B.to('G').value * self._gam**2
-        Ec /= 2 * (m_e * c).cgs.value
+        Ec /= 2 * (self.mc2 / c).cgs.value
 
         EgEc = outspecene.to('erg').value / np.vstack(Ec)
         dNdE = CS1 * Gtilde(EgEc)
         # return units
         spec = trapz_loglog(
-            np.vstack(self._nelec) * dNdE, self._gam, axis=0) / u.s / u.erg
+            np.vstack(self._npart) * dNdE, self._gam, axis=0) / u.s / u.erg
         spec = spec.to('1/(s eV)')
 
         return spec
 
 
+class ElectronSynchrotron(BaseElectron, BaseSynchrotron):
+    """Synchrotron emission from an electron population.
+ + This class uses the approximation of the synchrotron emissivity in a + random magnetic field of Aharonian, Kelner, and Prosekin 2010, PhysRev D + 82, 3002 (`arXiv:1006.1045 `_). + + Parameters + ---------- + particle_distribution : function + Particle distribution function, taking electron energies as a + `~astropy.units.Quantity` array or float, and returning the particle + energy density in units of number of electrons per unit energy as a + `~astropy.units.Quantity` array or float. + + B : :class:`~astropy.units.Quantity` float instance, optional + Isotropic magnetic field strength. Default: equipartition + with CMB (3.24e-6 G) + + Other parameters + ---------------- + Eemin : :class:`~astropy.units.Quantity` float instance, optional + Minimum electron energy for the electron distribution. Default is 1 + GeV. + + Eemax : :class:`~astropy.units.Quantity` float instance, optional + Maximum electron energy for the electron distribution. Default is 510 + TeV. + + nEed : scalar + Number of points per decade in energy for the electron energy and + distribution arrays. Default is 100. + """ + + def __init__(self, particle_distribution, B=3.24e-6 * u.G, **kwargs): + super(ElectronSynchrotron, self).__init__(particle_distribution) + self.B = validate_scalar('B', B, physical_type='magnetic flux density') + self.Eemin = 1 * u.GeV + self.Eemax = (1e9 * m_e * c ** 2).to(u.TeV) + self.nEed = 100 + self.param_names += ['B'] + for key, value in kwargs.items(): + setattr(self, key, value) + + +class Synchrotron(ElectronSynchrotron): + pass + + +class ProtonSynchrotron(BaseLorentzProton, BaseSynchrotron): + """Synchrotron emission from a proton population. + + This class uses the approximation of the synchrotron emissivity in a + random magnetic field of Aharonian, Kelner, and Prosekin 2010, PhysRev D + 82, 3002 (`arXiv:1006.1045 `_). + + Parameters + ---------- + particle_distribution : function + Particle distribution function, taking proton energies as a + `~astropy.units.Quantity` array or float, and returning the particle + energy density in units of number of protons per unit energy as a + `~astropy.units.Quantity` array or float. + + B : :class:`~astropy.units.Quantity` float instance, optional + Isotropic magnetic field strength. Default: equipartition + with CMB (3.24e-6 G) + + Other parameters + ---------------- + Epmin : :class:`~astropy.units.Quantity` float instance, optional + Minimum proton energy for the proton distribution. Default is 1 + GeV. + + Epmax : :class:`~astropy.units.Quantity` float instance, optional + Maximum proton energy for the proton distribution. Default is 1 PeV. + + nEpd : scalar + Number of points per decade in energy for the proton energy and + distribution arrays. Default is 100. + """ + + def __init__(self, particle_distribution, B=3.24e-6 * u.G, **kwargs): + super(ProtonSynchrotron, self).__init__(particle_distribution) + self.B = validate_scalar('B', B, physical_type='magnetic flux density') + self.Epmin = 1 * u.GeV + self.Epmax = 1 * u.PeV + self.nEpd = 100 + self.param_names += ['B'] + for key, value in kwargs.items(): + setattr(self, key, value) + def G12(x, a): """ @@ -1089,7 +1293,6 @@ class PionDecay(BaseProton): 123014 `_. 
-
     Parameters
     ----------
     particle_distribution : function
diff --git a/naima/tests/test_models.py b/naima/tests/test_models.py
index d9f9a3fc..6c0b6e87 100644
--- a/naima/tests/test_models.py
+++ b/naima/tests/test_models.py
@@ -5,6 +5,8 @@
 from astropy.tests.helper import pytest
 from astropy.extern import six
 
+from traitlets import TraitError
+
 from ..utils import trapz_loglog
 
 try:
@@ -52,7 +54,7 @@ def particle_dists():
 
 
 @pytest.mark.skipif('not HAS_SCIPY')
-def test_synchrotron_lum(particle_dists):
+def test_electron_synchrotron_lum(particle_dists):
     """
     test sync calculation
     """
@@ -89,6 +91,64 @@ def test_synchrotron_lum(particle_dists):
     assert_allclose(lsy.value, 31374131.90312505)
 
 
+@pytest.mark.skipif('not HAS_SCIPY')
+def test_proton_synchrotron_lum(particle_dists):
+    """
+    test proton synchrotron calculation
+    """
+    from ..models import ProtonSynchrotron
+
+    ECPL, PL, BPL = particle_dists
+
+    # lum_ref = [0.00025231296225663107, 0.03316715765695228,
+    #            0.00044597089198025806]
+    # Wp_ref = [5064124672.902273, 11551172166.866821, 926633861.2898524]
+
+    Wps = []
+    lsys = []
+    for pdist in particle_dists:
+        sy = ProtonSynchrotron(pdist, **proton_properties)
+
+        Wps.append(sy.Wp.to('erg').value)
+
+        lsy = trapz_loglog(sy.flux(energy, 0) * energy, energy).to('erg/s')
+        assert lsy.unit == u.erg / u.s
+        lsys.append(lsy.value)
+
+    print(lsys)
+    print(Wps)
+    # assert_allclose(lsys, lum_ref)
+    # assert_allclose(Wps, Wp_ref)
+
+    sy = ProtonSynchrotron(ECPL, B=1 * u.G, **proton_properties)
+    sy.flux(data)
+    sy.flux(data2)
+
+    lsy = trapz_loglog(sy.flux(energy, 0) * energy, energy).to('erg/s')
+    assert (lsy.unit == u.erg / u.s)
+    print(lsy)
+    # assert_allclose(lsy.value, 31374131.90312505)
+
+
+@pytest.mark.skipif('not HAS_SCIPY')
+def test_synchrotron_traits(particle_dists):
+    from ..models import Synchrotron
+    ECPL, _, _ = particle_dists
+    sy = Synchrotron(ECPL, Eemin=1 * u.GeV, Eemax=1 * u.PeV)
+
+    sy.Eemin = 1 * u.TeV
+    assert sy.gmin == float(1 * u.TeV / (m_e * c **2))
+
+    sy.Eemax = 100 * u.TeV
+    assert sy.gmax == float(100 * u.TeV / (m_e * c **2))
+
+    sy.nEed = 10
+    assert sy.ngd == 10
+
+    with pytest.raises(TraitError):
+        sy.Eemin = 10 * u.m
+
+
 @pytest.mark.skipif('not HAS_SCIPY')
 def test_bolometric_luminosity(particle_dists):
     """
diff --git a/setup.py b/setup.py
index 7f379dea..9fdf9824 100644
--- a/setup.py
+++ b/setup.py
@@ -97,7 +97,9 @@
           'corner',
           'matplotlib',
          'scipy',
-          'h5py'],
+          'h5py',
+          'numtraits',
+          'traitlets'],
 
 setup(name=PACKAGENAME,

From 87b2c78ca7c09212340eaae5c36a35e150e066d0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=ADctor=20Zabalza?=
Date: Sun, 16 Jul 2017 12:20:21 +0100
Subject: [PATCH 2/8] add numtraits and traitlets to test deps

---
 .travis.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.travis.yml b/.travis.yml
index dd9a1a9b..2b2a81e7 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -26,8 +26,8 @@ env:
         - PYTHON_VERSION=3.6
         - NUMPY_VERSION=1.13
         - ASTROPY_VERSION=stable
-        - CONDA_DEPENDENCIES='pytest pip Cython jinja2 pyyaml scipy matplotlib h5py mock sphinx_rtd_theme'
-        - PIP_DEPENDENCIES='emcee corner'
+        - CONDA_DEPENDENCIES='pytest pip Cython jinja2 pyyaml scipy matplotlib h5py mock sphinx_rtd_theme traitlets'
+        - PIP_DEPENDENCIES='emcee corner numtraits'
 
     matrix:
        - PYTHON_VERSION=2.7 SETUP_CMD='egg_info'

From 197c0e3c9216968a262dce77f0b95c4a9467140e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=ADctor=20Zabalza?=
Date: Sun, 16 Jul 2017 16:50:01 +0100
Subject: [PATCH 3/8] Fix errors in BaseLorentzFactor

---
 naima/radiative.py | 4 ++--
 1 file
changed, 2 insertions(+), 2 deletions(-) diff --git a/naima/radiative.py b/naima/radiative.py index 48280aed..d5050df4 100644 --- a/naima/radiative.py +++ b/naima/radiative.py @@ -198,7 +198,7 @@ def _compute_W(self, Emin=None, Emax=None): Maximum particle energy for energy content calculation. """ if Emin is None and Emax is None: - W = self.W + W = self._W else: if Emin is None: Emin = self.gmin * self.mc2 @@ -210,7 +210,7 @@ def _compute_W(self, Emin=None, Emax=None): gam = np.logspace(log10gmin, log10gmax, self.ngd * (log10gmax - log10gmin)) - pd = self.particle_distribution(self._gam * self.mc2) + pd = self.particle_distribution(gam * self.mc2) npart = pd.to(1 / self.mc2_unit).value W = trapz_loglog(gam * npart, gam * self.mc2) From 6967a4979aa9f65aabe6280c9acc9a8c837bfab3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Zabalza?= Date: Sun, 16 Jul 2017 16:50:39 +0100 Subject: [PATCH 4/8] Fix Bremsstrahlung and IC --- naima/radiative.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/naima/radiative.py b/naima/radiative.py index d5050df4..67f565ed 100644 --- a/naima/radiative.py +++ b/naima/radiative.py @@ -630,7 +630,8 @@ def __init__(self, self.Eemax = 1e9 * mec2 self.nEed = 100 self.param_names += ['seed_photon_fields'] - self.__dict__.update(**kwargs) + for key, value in kwargs.items(): + setattr(self, key, value) @staticmethod def _process_input_seed(seed_photon_fields): @@ -874,7 +875,7 @@ def _calc_specic(self, seed, outspecene): self._gam, self.seed_photon_fields[seed]['energy'], self.seed_photon_fields[seed]['photon_density'], Eph) - lum = uf * Eph * trapz_loglog(self._nelec * gamint, self._gam) + lum = uf * Eph * trapz_loglog(self._npart * gamint, self._gam) lum = lum * u.Unit('1/s') return lum / outspecene # return differential spectrum in 1/s/eV @@ -1026,7 +1027,8 @@ def __init__(self, particle_distribution, n0=1 / u.cm**3, **kwargs): self.weight_ee = np.sum(Z * X) self.weight_ep = np.sum(Z**2 * X) self.param_names += ['n0', 'weight_ee', 'weight_ep'] - self.__dict__.update(**kwargs) + for key, value in kwargs.items(): + setattr(self, key, value) @staticmethod def _sigma_1(gam, eps): @@ -1140,7 +1142,7 @@ def _emiss_ee(self, Eph): gam = np.vstack(self._gam) # compute integral with electron distribution emiss = c.cgs * trapz_loglog( - np.vstack(self._nelec) * self._sigma_ee(gam, Eph), + np.vstack(self._npart) * self._sigma_ee(gam, Eph), self._gam, axis=0) return emiss @@ -1156,7 +1158,7 @@ def _emiss_ep(self, Eph): eps = (Eph / mec2).decompose().value # compute integral with electron distribution emiss = c.cgs * trapz_loglog( - np.vstack(self._nelec) * self._sigma_ep(gam, eps), + np.vstack(self._npart) * self._sigma_ep(gam, eps), self._gam, axis=0).to(u.cm**2 / Eph.unit) return emiss From a7a6c52de2cf50c04d741246140cf7a46032c419 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Zabalza?= Date: Sun, 16 Jul 2017 16:50:53 +0100 Subject: [PATCH 5/8] parametrize set_We tests --- naima/tests/test_models.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/naima/tests/test_models.py b/naima/tests/test_models.py index 6c0b6e87..e1744b4e 100644 --- a/naima/tests/test_models.py +++ b/naima/tests/test_models.py @@ -175,9 +175,7 @@ def test_compute_We(particle_dists): ECPL, PL, BPL = particle_dists sy = Synchrotron(ECPL, B=1 * u.G, **electron_properties) - Eemin, Eemax = 10 * u.GeV, 100 * u.TeV - sy.compute_We() sy.compute_We(Eemin=Eemin) sy.compute_We(Eemax=Eemax) @@ -193,7 +191,9 @@ def 
test_compute_We(particle_dists):
 
 
 @pytest.mark.skipif('not HAS_SCIPY')
-def test_set_We(particle_dists):
+@pytest.mark.parametrize("Eemin", [1 * u.GeV, 10 * u.GeV, None])
+@pytest.mark.parametrize("Eemax", [100 * u.TeV, None])
+def test_set_We(particle_dists, Eemin, Eemax):
     """
     test sync calculation
     """
@@ -206,18 +206,15 @@ def test_set_We(particle_dists):
 
     W = 1e49 * u.erg
 
-    Eemax = 100 * u.TeV
-    for Eemin in [1 * u.GeV, 10 * u.GeV, None]:
-        for Eemax in [100 * u.TeV, None]:
-            sy.set_We(W, Eemin, Eemax)
-            assert_allclose(W, sy.compute_We(Eemin, Eemax))
-            sy.set_We(W, Eemin, Eemax, amplitude_name='amplitude')
-            assert_allclose(W, sy.compute_We(Eemin, Eemax))
+    sy.set_We(W, Eemin, Eemax)
+    assert_allclose(W, sy.compute_We(Eemin, Eemax))
+    sy.set_We(W, Eemin, Eemax, amplitude_name='amplitude')
+    assert_allclose(W, sy.compute_We(Eemin, Eemax))
 
-            pp.set_Wp(W, Eemin, Eemax)
-            assert_allclose(W, pp.compute_Wp(Eemin, Eemax))
-            pp.set_Wp(W, Eemin, Eemax, amplitude_name='amplitude')
-            assert_allclose(W, pp.compute_Wp(Eemin, Eemax))
+    pp.set_Wp(W, Eemin, Eemax)
+    assert_allclose(W, pp.compute_Wp(Eemin, Eemax))
+    pp.set_Wp(W, Eemin, Eemax, amplitude_name='amplitude')
+    assert_allclose(W, pp.compute_Wp(Eemin, Eemax))
 
     with pytest.raises(AttributeError):
         sy.set_We(W, amplitude_name='norm')
@@ -238,8 +235,9 @@ def test_bremsstrahlung_lum(particle_dists):
     # avoid low-energy (E<2MeV) where there are problems with cross-section
     energy2 = np.logspace(8, 14, 100) * u.eV
 
-    brems = Bremsstrahlung(ECPL, n0=1 * u.cm** -3, Eemin=m_e * c**2)
-    lbrems = trapz_loglog(brems.flux(energy2, 0) * energy2, energy2).to('erg/s')
+    brems = Bremsstrahlung(ECPL, n0=1 * u.cm ** -3, Eemin=m_e * c ** 2)
+    lbrems = trapz_loglog(brems.flux(energy2, 0) * energy2,
+                          energy2).to('erg/s')
 
     lum_ref = 2.3064095039069847e-05
     assert_allclose(lbrems.value, lum_ref)

From e9601c9c3c06733afa5b8e84568089431902e420 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=ADctor=20Zabalza?=
Date: Sun, 16 Jul 2017 17:11:51 +0100
Subject: [PATCH 6/8] update CHANGES

---
 CHANGES.rst | 11 ++++++++++-
 setup.py    |  2 +-
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/CHANGES.rst b/CHANGES.rst
index 55bbdcd8..a1a352a3 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,13 @@
+0.9 (unreleased)
+----------------
+
+- Added the `ProtonSynchrotron` class to compute synchrotron radiation from
+  proton populations.
+
+Bug fixes
+^^^^^^^^^
+- Updated deprecated numpy usage.
+
 0.8 (2016-12-21)
 ----------------
 
@@ -140,7 +150,6 @@ API Changes
 
 - module sherpamod is now sherpa_modules.
 
-
 0.1 (2015-02-02)
 ----------------
 
diff --git a/setup.py b/setup.py
index 9fdf9824..451a8da3 100644
--- a/setup.py
+++ b/setup.py
@@ -45,7 +45,7 @@
 builtins._ASTROPY_PACKAGE_NAME_ = PACKAGENAME
 
 # VERSION should be PEP386 compatible (http://www.python.org/dev/peps/pep-0386)
-VERSION = '0.8'
+VERSION = '0.9.dev'
 
 # Indicates if this version is a release version
 RELEASE = 'dev' not in VERSION

From 7a2597215b2f66463f057ea5b6206378da4d2c06 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?V=C3=ADctor=20Zabalza?=
Date: Sun, 16 Jul 2017 18:16:19 +0100
Subject: [PATCH 7/8] add deps to appveyor [skip travis]

---
 appveyor.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/appveyor.yml b/appveyor.yml
index 70a3485d..a41500d0 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -14,8 +14,8 @@ environment:
     PYTHON_ARCH: "64" # needs to be set for CMD_IN_ENV to succeed. If a mix
                       # of 32 bit and 64 bit builds are needed, move this
                       # to the matrix section.
- CONDA_DEPENDENCIES: "pytest pip Cython jinja2 pyyaml scipy matplotlib h5py mock sphinx_rtd_theme" - PIP_DEPENDENCIES: "emcee" + CONDA_DEPENDENCIES: "pytest pip Cython jinja2 pyyaml scipy matplotlib h5py mock sphinx_rtd_theme traitlets" + PIP_DEPENDENCIES: "emcee numtraits" ASTROPY_VERSION: "stable" matrix: From 4cf74b96778a17591b1160205f6e1ffcc87fca64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?V=C3=ADctor=20Zabalza?= Date: Wed, 28 Nov 2018 13:08:50 +0000 Subject: [PATCH 8/8] black formatting --- ah_bootstrap.py | 322 ++++++----- docs/_static/RXJ1713_IC.py | 109 ++-- docs/conf.py | 101 ++-- examples/CrabNebula_SynSSC.py | 72 ++- examples/RXJ1713_IC.py | 64 ++- examples/RXJ1713_IC_minimal.py | 43 +- examples/RXJ1713_SynIC.py | 77 +-- examples/absorbed_SynIC.py | 46 +- examples/model_examples.py | 58 +- ez_setup.py | 134 +++-- naima/__init__.py | 1 + naima/_astropy_init.py | 65 ++- naima/analysis.py | 297 +++++----- naima/core.py | 313 ++++++----- naima/extern/interruptible_pool.py | 19 +- naima/extern/minimize.py | 94 +++- naima/extern/validator.py | 59 +- naima/model_fitter.py | 156 +++--- naima/model_utils.py | 23 +- naima/models.py | 217 +++++--- naima/plot.py | 842 ++++++++++++++++------------ naima/radiative.py | 853 +++++++++++++++++------------ naima/sherpa_models.py | 211 ++++--- naima/tests/fixtures.py | 98 +++- naima/tests/setup_package.py | 3 +- naima/tests/test_functionfit.py | 284 +++++++--- naima/tests/test_interactive.py | 60 +- naima/tests/test_models.py | 336 +++++++----- naima/tests/test_plotting.py | 115 ++-- naima/tests/test_saveread.py | 62 ++- naima/tests/test_sherpamod.py | 17 +- naima/tests/test_utils.py | 108 ++-- naima/utils.py | 336 +++++++----- pyproject.toml | 15 + setup.py | 131 +++-- 35 files changed, 3446 insertions(+), 2295 deletions(-) create mode 100644 pyproject.toml diff --git a/ah_bootstrap.py b/ah_bootstrap.py index 786b8b14..a194b93a 100644 --- a/ah_bootstrap.py +++ b/ah_bootstrap.py @@ -69,14 +69,15 @@ # setuptools_bootstrap.py; but it was combined into ah_bootstrap.py try: import pkg_resources - _setuptools_req = pkg_resources.Requirement.parse('setuptools>=0.7') + + _setuptools_req = pkg_resources.Requirement.parse("setuptools>=0.7") # This may raise a DistributionNotFound in which case no version of # setuptools or distribute is properly installed - _setuptools = pkg_resources.get_distribution('setuptools') + _setuptools = pkg_resources.get_distribution("setuptools") if _setuptools not in _setuptools_req: # Older version of setuptools; check if we have distribute; again if # this results in DistributionNotFound we want to give up - _distribute = pkg_resources.get_distribution('distribute') + _distribute = pkg_resources.get_distribution("distribute") if _setuptools != _distribute: # It's possible on some pathological systems to have an old version # of setuptools and distribute on sys.path simultaneously; make @@ -88,6 +89,7 @@ # There are several types of exceptions that can occur here; if all else # fails bootstrap and use the bootstrapped version from ez_setup import use_setuptools + use_setuptools() @@ -97,7 +99,7 @@ # https://github.com/astropy/astropy-helpers/issues/302 try: - import typing # noqa + import typing # noqa except ImportError: pass @@ -108,7 +110,7 @@ # later cause the TemporaryDirectory class defined in it to stop working when # used later on by setuptools try: - import setuptools.py31compat # noqa + import setuptools.py31compat # noqa except ImportError: pass @@ -122,7 +124,8 @@ # issue) try: import matplotlib - 
matplotlib.use('Agg') + + matplotlib.use("Agg") import matplotlib.pyplot except: # Ignore if this fails for *any* reason* @@ -144,21 +147,25 @@ # TODO: Maybe enable checking for a specific version of astropy_helpers? -DIST_NAME = 'astropy-helpers' -PACKAGE_NAME = 'astropy_helpers' +DIST_NAME = "astropy-helpers" +PACKAGE_NAME = "astropy_helpers" # Defaults for other options DOWNLOAD_IF_NEEDED = True -INDEX_URL = 'https://pypi.python.org/simple' +INDEX_URL = "https://pypi.python.org/simple" USE_GIT = True OFFLINE = False AUTO_UPGRADE = True # A list of all the configuration options and their required types CFG_OPTIONS = [ - ('auto_use', bool), ('path', str), ('download_if_needed', bool), - ('index_url', str), ('use_git', bool), ('offline', bool), - ('auto_upgrade', bool) + ("auto_use", bool), + ("path", str), + ("download_if_needed", bool), + ("index_url", str), + ("use_git", bool), + ("offline", bool), + ("auto_upgrade", bool), ] @@ -168,14 +175,21 @@ class _Bootstrapper(object): documentation. """ - def __init__(self, path=None, index_url=None, use_git=None, offline=None, - download_if_needed=None, auto_upgrade=None): + def __init__( + self, + path=None, + index_url=None, + use_git=None, + offline=None, + download_if_needed=None, + auto_upgrade=None, + ): if path is None: path = PACKAGE_NAME if not (isinstance(path, _str_types) or path is False): - raise TypeError('path must be a string or False') + raise TypeError("path must be a string or False") if PY3 and not isinstance(path, _text_type): fs_encoding = sys.getfilesystemencoding() @@ -192,15 +206,20 @@ def __init__(self, path=None, index_url=None, use_git=None, offline=None, download_if_needed = False auto_upgrade = False - self.download = (download_if_needed - if download_if_needed is not None - else DOWNLOAD_IF_NEEDED) - self.auto_upgrade = (auto_upgrade - if auto_upgrade is not None else AUTO_UPGRADE) + self.download = ( + download_if_needed + if download_if_needed is not None + else DOWNLOAD_IF_NEEDED + ) + self.auto_upgrade = ( + auto_upgrade if auto_upgrade is not None else AUTO_UPGRADE + ) # If this is a release then the .git directory will not exist so we # should not use git. - git_dir_exists = os.path.exists(os.path.join(os.path.dirname(__file__), '.git')) + git_dir_exists = os.path.exists( + os.path.join(os.path.dirname(__file__), ".git") + ) if use_git is None and not git_dir_exists: use_git = False @@ -218,7 +237,7 @@ def main(cls, argv=None): config = cls.parse_config() config.update(cls.parse_command_line(argv)) - auto_use = config.pop('auto_use', False) + auto_use = config.pop("auto_use", False) bootstrapper = cls(**config) if auto_use: @@ -231,13 +250,13 @@ def main(cls, argv=None): @classmethod def parse_config(cls): - if not os.path.exists('setup.cfg'): + if not os.path.exists("setup.cfg"): return {} cfg = ConfigParser() try: - cfg.read('setup.cfg') + cfg.read("setup.cfg") except Exception as e: if DEBUG: raise @@ -245,22 +264,23 @@ def parse_config(cls): log.error( "Error reading setup.cfg: {0!r}\n{1} will not be " "automatically bootstrapped and package installation may fail." 
- "\n{2}".format(e, PACKAGE_NAME, _err_help_msg)) + "\n{2}".format(e, PACKAGE_NAME, _err_help_msg) + ) return {} - if not cfg.has_section('ah_bootstrap'): + if not cfg.has_section("ah_bootstrap"): return {} config = {} for option, type_ in CFG_OPTIONS: - if not cfg.has_option('ah_bootstrap', option): + if not cfg.has_option("ah_bootstrap", option): continue if type_ is bool: - value = cfg.getboolean('ah_bootstrap', option) + value = cfg.getboolean("ah_bootstrap", option) else: - value = cfg.get('ah_bootstrap', option) + value = cfg.get("ah_bootstrap", option) config[option] = value @@ -279,18 +299,18 @@ def parse_command_line(cls, argv=None): # of the same name then we will break that. However there's a catch22 # here that we can't just do full argument parsing right here, because # we don't yet know *how* to parse all possible command-line arguments. - if '--no-git' in argv: - config['use_git'] = False - argv.remove('--no-git') + if "--no-git" in argv: + config["use_git"] = False + argv.remove("--no-git") - if '--offline' in argv: - config['offline'] = True - argv.remove('--offline') + if "--offline" in argv: + config["offline"] = True + argv.remove("--offline") return config def run(self): - strategies = ['local_directory', 'local_file', 'index'] + strategies = ["local_directory", "local_file", "index"] dist = None # First, remove any previously imported versions of astropy_helpers; @@ -299,7 +319,7 @@ def run(self): # the case of setup_requires for key in list(sys.modules): try: - if key == PACKAGE_NAME or key.startswith(PACKAGE_NAME + '.'): + if key == PACKAGE_NAME or key.startswith(PACKAGE_NAME + "."): del sys.modules[key] except AttributeError: # Sometimes mysterious non-string things can turn up in @@ -310,7 +330,7 @@ def run(self): self.is_submodule = self._check_submodule() for strategy in strategies: - method = getattr(self, 'get_{0}_dist'.format(strategy)) + method = getattr(self, "get_{0}_dist".format(strategy)) dist = method() if dist is not None: break @@ -318,7 +338,8 @@ def run(self): raise _AHBootstrapSystemExit( "No source found for the {0!r} package; {0} must be " "available and importable as a prerequisite to building " - "or installing this package.".format(PACKAGE_NAME)) + "or installing this package.".format(PACKAGE_NAME) + ) # This is a bit hacky, but if astropy_helpers was loaded from a # directory/submodule its Distribution object gets a "precedence" of @@ -353,8 +374,11 @@ def config(self): with. 
""" - return dict((optname, getattr(self, optname)) - for optname, _ in CFG_OPTIONS if hasattr(self, optname)) + return dict( + (optname, getattr(self, optname)) + for optname, _ in CFG_OPTIONS + if hasattr(self, optname) + ) def get_local_directory_dist(self): """ @@ -365,17 +389,20 @@ def get_local_directory_dist(self): if not os.path.isdir(self.path): return - log.info('Attempting to import astropy_helpers from {0} {1!r}'.format( - 'submodule' if self.is_submodule else 'directory', - self.path)) + log.info( + "Attempting to import astropy_helpers from {0} {1!r}".format( + "submodule" if self.is_submodule else "directory", self.path + ) + ) dist = self._directory_import() if dist is None: log.warn( - 'The requested path {0!r} for importing {1} does not ' - 'exist, or does not contain a copy of the {1} ' - 'package.'.format(self.path, PACKAGE_NAME)) + "The requested path {0!r} for importing {1} does not " + "exist, or does not contain a copy of the {1} " + "package.".format(self.path, PACKAGE_NAME) + ) elif self.auto_upgrade and not self.is_submodule: # A version of astropy-helpers was found on the available path, but # check to see if a bugfix release is available on PyPI @@ -394,8 +421,10 @@ def get_local_file_dist(self): if not os.path.isfile(self.path): return - log.info('Attempting to unpack and import astropy_helpers from ' - '{0!r}'.format(self.path)) + log.info( + "Attempting to unpack and import astropy_helpers from " + "{0!r}".format(self.path) + ) try: dist = self._do_download(find_links=[self.path]) @@ -404,8 +433,9 @@ def get_local_file_dist(self): raise log.warn( - 'Failed to import {0} from the specified archive {1!r}: ' - '{2}'.format(PACKAGE_NAME, self.path, str(e))) + "Failed to import {0} from the specified archive {1!r}: " + "{2}".format(PACKAGE_NAME, self.path, str(e)) + ) dist = None if dist is not None and self.auto_upgrade: @@ -419,12 +449,13 @@ def get_local_file_dist(self): def get_index_dist(self): if not self.download: - log.warn('Downloading {0!r} disabled.'.format(DIST_NAME)) + log.warn("Downloading {0!r} disabled.".format(DIST_NAME)) return None log.warn( "Downloading {0!r}; run setup.py with the --offline option to " - "force offline installation.".format(DIST_NAME)) + "force offline installation.".format(DIST_NAME) + ) try: dist = self._do_download() @@ -432,8 +463,9 @@ def get_index_dist(self): if DEBUG: raise log.warn( - 'Failed to download and/or install {0!r} from {1!r}:\n' - '{2}'.format(DIST_NAME, self.index_url, str(e))) + "Failed to download and/or install {0!r} from {1!r}:\n" + "{2}".format(DIST_NAME, self.index_url, str(e)) + ) dist = None # No need to run auto-upgrade here since we've already presumably @@ -462,11 +494,10 @@ def _directory_import(self): if dist is None: # We didn't find an egg-info/dist-info in the given path, but if a # setup.py exists we can generate it - setup_py = os.path.join(path, 'setup.py') + setup_py = os.path.join(path, "setup.py") if os.path.isfile(setup_py): with _silence(): - run_setup(os.path.join(path, 'setup.py'), - ['egg_info']) + run_setup(os.path.join(path, "setup.py"), ["egg_info"]) for dist in pkg_resources.find_distributions(path, True): # There should be only one... 
@@ -474,9 +505,9 @@ def _directory_import(self): return dist - def _do_download(self, version='', find_links=None): + def _do_download(self, version="", find_links=None): if find_links: - allow_hosts = '' + allow_hosts = "" index_url = None else: allow_hosts = None @@ -489,21 +520,21 @@ def _do_download(self, version='', find_links=None): class _Distribution(Distribution): def get_option_dict(self, command_name): opts = Distribution.get_option_dict(self, command_name) - if command_name == 'easy_install': + if command_name == "easy_install": if find_links is not None: - opts['find_links'] = ('setup script', find_links) + opts["find_links"] = ("setup script", find_links) if index_url is not None: - opts['index_url'] = ('setup script', index_url) + opts["index_url"] = ("setup script", index_url) if allow_hosts is not None: - opts['allow_hosts'] = ('setup script', allow_hosts) + opts["allow_hosts"] = ("setup script", allow_hosts) return opts if version: - req = '{0}=={1}'.format(DIST_NAME, version) + req = "{0}=={1}".format(DIST_NAME, version) else: req = DIST_NAME - attrs = {'setup_requires': [req]} + attrs = {"setup_requires": [req]} try: if DEBUG: @@ -519,13 +550,13 @@ def get_option_dict(self, command_name): if DEBUG: raise - msg = 'Error retrieving {0} from {1}:\n{2}' + msg = "Error retrieving {0} from {1}:\n{2}" if find_links: source = find_links[0] elif index_url != INDEX_URL: source = index_url else: - source = 'PyPI' + source = "PyPI" raise Exception(msg.format(DIST_NAME, source, repr(e))) @@ -535,7 +566,8 @@ def _do_upgrade(self, dist): next_version = _next_version(dist.parsed_version) req = pkg_resources.Requirement.parse( - '{0}>{1},<{2}'.format(DIST_NAME, dist.version, next_version)) + "{0}>{1},<{2}".format(DIST_NAME, dist.version, next_version) + ) package_index = PackageIndex(index_url=self.index_url) @@ -552,8 +584,9 @@ def _check_submodule(self): ``_check_submodule_no_git`` for further details. """ - if (self.path is None or - (os.path.exists(self.path) and not os.path.isdir(self.path))): + if self.path is None or ( + os.path.exists(self.path) and not os.path.isdir(self.path) + ): return False if self.use_git: @@ -571,11 +604,13 @@ def _check_submodule_using_git(self): path looks like a git submodule, but it cannot perform updates. """ - cmd = ['git', 'submodule', 'status', '--', self.path] + cmd = ["git", "submodule", "status", "--", self.path] try: - log.info('Running `{0}`; use the --no-git option to disable git ' - 'commands'.format(' '.join(cmd))) + log.info( + "Running `{0}`; use the --no-git option to disable git " + "commands".format(" ".join(cmd)) + ) returncode, stdout, stderr = run_cmd(cmd) except _CommandNotFound: # The git command simply wasn't found; this is most likely the @@ -596,12 +631,15 @@ def _check_submodule_using_git(self): # which only occurs with a malformatted locale setting which can # happen sometimes on OSX. See again # https://github.com/astropy/astropy/issues/2749 - perl_warning = ('perl: warning: Falling back to the standard locale ' - '("C").') + perl_warning = ( + "perl: warning: Falling back to the standard locale " '("C").' 
+ ) if not stderr.strip().endswith(perl_warning): # Some other unknown error condition occurred - log.warn('git submodule command failed ' - 'unexpectedly:\n{0}'.format(stderr)) + log.warn( + "git submodule command failed " + "unexpectedly:\n{0}".format(stderr) + ) return False # Output of `git submodule status` is as follows: @@ -620,21 +658,24 @@ def _check_submodule_using_git(self): # only if the submodule is initialized. We ignore this information for # now _git_submodule_status_re = re.compile( - '^(?P[+-U ])(?P[0-9a-f]{40}) ' - '(?P\S+)( .*)?$') + "^(?P[+-U ])(?P[0-9a-f]{40}) " + "(?P\S+)( .*)?$" + ) # The stdout should only contain one line--the status of the # requested submodule m = _git_submodule_status_re.match(stdout) if m: # Yes, the path *is* a git submodule - self._update_submodule(m.group('submodule'), m.group('status')) + self._update_submodule(m.group("submodule"), m.group("status")) return True else: log.warn( - 'Unexpected output from `git submodule status`:\n{0}\n' - 'Will attempt import from {1!r} regardless.'.format( - stdout, self.path)) + "Unexpected output from `git submodule status`:\n{0}\n" + "Will attempt import from {1!r} regardless.".format( + stdout, self.path + ) + ) return False def _check_submodule_no_git(self): @@ -648,7 +689,7 @@ def _check_submodule_no_git(self): .gitmodules file is changed between git versions. """ - gitmodules_path = os.path.abspath('.gitmodules') + gitmodules_path = os.path.abspath(".gitmodules") if not os.path.isfile(gitmodules_path): return False @@ -667,7 +708,7 @@ def _check_submodule_no_git(self): line = line.lstrip() # comments can start with either # or ; - if line and line[0] in (':', ';'): + if line and line[0] in (":", ";"): continue gitmodules_fileobj.write(line) @@ -679,16 +720,19 @@ def _check_submodule_no_git(self): try: cfg.readfp(gitmodules_fileobj) except Exception as exc: - log.warn('Malformatted .gitmodules file: {0}\n' - '{1} cannot be assumed to be a git submodule.'.format( - exc, self.path)) + log.warn( + "Malformatted .gitmodules file: {0}\n" + "{1} cannot be assumed to be a git submodule.".format( + exc, self.path + ) + ) return False for section in cfg.sections(): - if not cfg.has_option(section, 'path'): + if not cfg.has_option(section, "path"): continue - submodule_path = cfg.get(section, 'path').rstrip(os.sep) + submodule_path = cfg.get(section, "path").rstrip(os.sep) if submodule_path == self.path.rstrip(os.sep): return True @@ -696,43 +740,53 @@ def _check_submodule_no_git(self): return False def _update_submodule(self, submodule, status): - if status == ' ': + if status == " ": # The submodule is up to date; no action necessary return - elif status == '-': + elif status == "-": if self.offline: raise _AHBootstrapSystemExit( "Cannot initialize the {0} submodule in --offline mode; " "this requires being able to clone the submodule from an " - "online repository.".format(submodule)) - cmd = ['update', '--init'] - action = 'Initializing' - elif status == '+': - cmd = ['update'] - action = 'Updating' + "online repository.".format(submodule) + ) + cmd = ["update", "--init"] + action = "Initializing" + elif status == "+": + cmd = ["update"] + action = "Updating" if self.offline: - cmd.append('--no-fetch') - elif status == 'U': + cmd.append("--no-fetch") + elif status == "U": raise _AHBootstrapSystemExit( - 'Error: Submodule {0} contains unresolved merge conflicts. 
' - 'Please complete or abandon any changes in the submodule so that ' - 'it is in a usable state, then try again.'.format(submodule)) + "Error: Submodule {0} contains unresolved merge conflicts. " + "Please complete or abandon any changes in the submodule so that " + "it is in a usable state, then try again.".format(submodule) + ) else: - log.warn('Unknown status {0!r} for git submodule {1!r}. Will ' - 'attempt to use the submodule as-is, but try to ensure ' - 'that the submodule is in a clean state and contains no ' - 'conflicts or errors.\n{2}'.format(status, submodule, - _err_help_msg)) + log.warn( + "Unknown status {0!r} for git submodule {1!r}. Will " + "attempt to use the submodule as-is, but try to ensure " + "that the submodule is in a clean state and contains no " + "conflicts or errors.\n{2}".format( + status, submodule, _err_help_msg + ) + ) return err_msg = None - cmd = ['git', 'submodule'] + cmd + ['--', submodule] - log.warn('{0} {1} submodule with: `{2}`'.format( - action, submodule, ' '.join(cmd))) + cmd = ["git", "submodule"] + cmd + ["--", submodule] + log.warn( + "{0} {1} submodule with: `{2}`".format( + action, submodule, " ".join(cmd) + ) + ) try: - log.info('Running `{0}`; use the --no-git option to disable git ' - 'commands'.format(' '.join(cmd))) + log.info( + "Running `{0}`; use the --no-git option to disable git " + "commands".format(" ".join(cmd)) + ) returncode, stdout, stderr = run_cmd(cmd) except OSError as e: err_msg = str(e) @@ -741,9 +795,11 @@ def _update_submodule(self, submodule, status): err_msg = stderr if err_msg is not None: - log.warn('An unexpected error occurred updating the git submodule ' - '{0!r}:\n{1}\n{2}'.format(submodule, err_msg, - _err_help_msg)) + log.warn( + "An unexpected error occurred updating the git submodule " + "{0!r}:\n{1}\n{2}".format(submodule, err_msg, _err_help_msg) + ) + class _CommandNotFound(OSError): """ @@ -771,30 +827,30 @@ def run_cmd(cmd): raise if e.errno == errno.ENOENT: - msg = 'Command not found: `{0}`'.format(' '.join(cmd)) + msg = "Command not found: `{0}`".format(" ".join(cmd)) raise _CommandNotFound(msg, cmd) else: raise _AHBootstrapSystemExit( - 'An unexpected error occurred when running the ' - '`{0}` command:\n{1}'.format(' '.join(cmd), str(e))) - + "An unexpected error occurred when running the " + "`{0}` command:\n{1}".format(" ".join(cmd), str(e)) + ) # Can fail of the default locale is not configured properly. See # https://github.com/astropy/astropy/issues/2749. For the purposes under # consideration 'latin1' is an acceptable fallback. try: - stdio_encoding = locale.getdefaultlocale()[1] or 'latin1' + stdio_encoding = locale.getdefaultlocale()[1] or "latin1" except ValueError: # Due to an OSX oddity locale.getdefaultlocale() can also crash # depending on the user's locale/language settings. 
See: # http://bugs.python.org/issue18378 - stdio_encoding = 'latin1' + stdio_encoding = "latin1" # Unlikely to fail at this point but even then let's be flexible if not isinstance(stdout, _text_type): - stdout = stdout.decode(stdio_encoding, 'replace') + stdout = stdout.decode(stdio_encoding, "replace") if not isinstance(stderr, _text_type): - stderr = stderr.decode(stdio_encoding, 'replace') + stderr = stderr.decode(stdio_encoding, "replace") return (p.returncode, stdout, stderr) @@ -810,16 +866,16 @@ def _next_version(version): '1.3.0' """ - if hasattr(version, 'base_version'): + if hasattr(version, "base_version"): # New version parsing from setuptools >= 8.0 if version.base_version: - parts = version.base_version.split('.') + parts = version.base_version.split(".") else: parts = [] else: parts = [] for part in version: - if part.startswith('*'): + if part.startswith("*"): break parts.append(part) @@ -830,14 +886,14 @@ def _next_version(version): major, minor, micro = parts[:3] - return '{0}.{1}.{2}'.format(major, minor + 1, 0) + return "{0}.{1}.{2}".format(major, minor + 1, 0) class _DummyFile(object): """A noop writeable object.""" - errors = '' # Required for Python 3.x - encoding = 'utf-8' + errors = "" # Required for Python 3.x + encoding = "utf-8" def write(self, s): pass @@ -880,11 +936,11 @@ def _silence(): class _AHBootstrapSystemExit(SystemExit): def __init__(self, *args): if not args: - msg = 'An unknown problem occurred bootstrapping astropy_helpers.' + msg = "An unknown problem occurred bootstrapping astropy_helpers." else: msg = args[0] - msg += '\n' + _err_help_msg + msg += "\n" + _err_help_msg super(_AHBootstrapSystemExit, self).__init__(msg, *args[1:]) diff --git a/docs/_static/RXJ1713_IC.py b/docs/_static/RXJ1713_IC.py index c199977d..72ab938b 100644 --- a/docs/_static/RXJ1713_IC.py +++ b/docs/_static/RXJ1713_IC.py @@ -16,14 +16,17 @@ def ElectronIC(pars, data): # Match parameters to ECPL properties, and give them the appropriate units amplitude = pars[0] / u.eV alpha = pars[1] - e_cutoff = (10**pars[2]) * u.TeV + e_cutoff = (10 ** pars[2]) * u.TeV # Initialize instances of the particle distribution and radiative model - ECPL = ExponentialCutoffPowerLaw(amplitude, 10. 
* u.TeV, alpha, e_cutoff) + ECPL = ExponentialCutoffPowerLaw(amplitude, 10.0 * u.TeV, alpha, e_cutoff) IC = InverseCompton( ECPL, - seed_photon_fields=['CMB', - ['FIR', 26.5 * u.K, 0.415 * u.eV / u.cm**3]]) + seed_photon_fields=[ + "CMB", + ["FIR", 26.5 * u.K, 0.415 * u.eV / u.cm ** 3], + ], + ) # compute flux at the energies given in data['energy'], and convert to # units of flux data @@ -51,28 +54,26 @@ def lnprior(pars): Parameter limits should be done here through uniform prior ditributions """ - logprob = (naima.uniform_prior(pars[0], 0., np.inf) + - naima.uniform_prior(pars[1], -1, 5)) + logprob = naima.uniform_prior(pars[0], 0.0, np.inf) + naima.uniform_prior( + pars[1], -1, 5 + ) return logprob -if __name__ == '__main__': +if __name__ == "__main__": # Set initial parameters and labels - p0 = np.array(( - 1e30, - 3.0, - np.log10(30))) - labels = ['norm', 'index', 'log10(cutoff)'] + p0 = np.array((1e30, 3.0, np.log10(30))) + labels = ["norm", "index", "log10(cutoff)"] - samplerf = 'RXJ1713_IC_sampler.hdf5' - if os.path.exists(samplerf) and 'onlyplot' in sys.argv: + samplerf = "RXJ1713_IC_sampler.hdf5" + if os.path.exists(samplerf) and "onlyplot" in sys.argv: sampler = naima.read_run(samplerf, modelfn=ElectronIC) else: # Read data - data = ascii.read('../../examples/RXJ1713_HESS_2007.dat') + data = ascii.read("../../examples/RXJ1713_HESS_2007.dat") # Run sampler sampler, pos = naima.run_sampler( data_table=data, @@ -85,83 +86,89 @@ def lnprior(pars): nrun=100, threads=4, prefit=True, - interactive=True) + interactive=True, + ) # Save sampler - naima.save_run('RXJ1713_IC_sampler.hdf5', sampler) + naima.save_run("RXJ1713_IC_sampler.hdf5", sampler) # Diagnostic plots - naima.save_results_table('RXJ1713_IC', sampler) + naima.save_results_table("RXJ1713_IC", sampler) from astropy.io import ascii - results = ascii.read('RXJ1713_IC_results.ecsv') + + results = ascii.read("RXJ1713_IC_results.ecsv") results.remove_row(-1) # remove blob2 - for col in ['median', 'unc_lo', 'unc_hi']: - results[col].format = '.3g' + for col in ["median", "unc_lo", "unc_hi"]: + results[col].format = ".3g" - with open('RXJ1713_IC_results_table.txt', 'w') as f: + with open("RXJ1713_IC_results_table.txt", "w") as f: info = [] - for key in ['n_walkers', 'n_run', 'p0', 'ML_pars', 'MaxLogLikelihood']: - info.append('{0:<18}: {1}\n'.format(key, str(results.meta[key]))) + for key in ["n_walkers", "n_run", "p0", "ML_pars", "MaxLogLikelihood"]: + info.append("{0:<18}: {1}\n".format(key, str(results.meta[key]))) f.writelines(info) - f.write('\n') - f.write('------------- ------- ------- --------\n') - results.write(f, format='ascii.fixed_width_two_line') + f.write("\n") + f.write("------------- ------- ------- --------\n") + results.write(f, format="ascii.fixed_width_two_line") alabaster_width = 660 alabaster_dpi = 100 * alabaster_width / 800 - print('Plotting chains...') + print("Plotting chains...") f = naima.plot_chain(sampler, 1) - f.savefig('RXJ1713_IC_chain_index.png', dpi=alabaster_dpi) + f.savefig("RXJ1713_IC_chain_index.png", dpi=alabaster_dpi) f = naima.plot_chain(sampler, 2) - f.savefig('RXJ1713_IC_chain_cutoff.png', dpi=alabaster_dpi) + f.savefig("RXJ1713_IC_chain_cutoff.png", dpi=alabaster_dpi) e_range = [100 * u.GeV, 500 * u.TeV] # with samples - print('Plotting samples...') + print("Plotting samples...") f = naima.plot_fit(sampler, 0, ML_info=False) f.axes[0].set_ylim(1e-13, 2e-10) f.tight_layout() f.subplots_adjust(hspace=0) - f.savefig('RXJ1713_IC_model_samples.png', dpi=alabaster_dpi) - print('Plotting 
samples with e_range...') + f.savefig("RXJ1713_IC_model_samples.png", dpi=alabaster_dpi) + print("Plotting samples with e_range...") f = naima.plot_fit( - sampler, 0, e_range=e_range, ML_info=False, n_samples=500) + sampler, 0, e_range=e_range, ML_info=False, n_samples=500 + ) f.axes[0].set_ylim(1e-13, 2e-10) f.tight_layout() f.subplots_adjust(hspace=0) - f.savefig('RXJ1713_IC_model_samples_erange.png', dpi=alabaster_dpi) + f.savefig("RXJ1713_IC_model_samples_erange.png", dpi=alabaster_dpi) # with confs - print('Plotting confs...') - f = naima.plot_fit(sampler, 0, ML_info=False, confs=[3, 1], - last_step=False) + print("Plotting confs...") + f = naima.plot_fit( + sampler, 0, ML_info=False, confs=[3, 1], last_step=False + ) f.axes[0].set_ylim(1e-13, 2e-10) f.tight_layout() f.subplots_adjust(hspace=0) - f.savefig('RXJ1713_IC_model_confs.png', dpi=alabaster_dpi) - print('Plotting confs with e_range...') - f = naima.plot_fit(sampler, 0, e_range=e_range, ML_info=False, - confs=[3, 1]) + f.savefig("RXJ1713_IC_model_confs.png", dpi=alabaster_dpi) + print("Plotting confs with e_range...") + f = naima.plot_fit( + sampler, 0, e_range=e_range, ML_info=False, confs=[3, 1] + ) f.axes[0].set_ylim(1e-13, 2e-10) f.tight_layout() f.subplots_adjust(hspace=0) - f.savefig('RXJ1713_IC_model_confs_erange.png', dpi=alabaster_dpi) + f.savefig("RXJ1713_IC_model_confs_erange.png", dpi=alabaster_dpi) - print('Plotting corner...') + print("Plotting corner...") f = naima.plot_corner(sampler, bins=40) w = f.get_size_inches()[0] - f.savefig('RXJ1713_IC_corner.png', dpi=alabaster_width / w) + f.savefig("RXJ1713_IC_corner.png", dpi=alabaster_width / w) - print('Plotting blobs...') + print("Plotting blobs...") f = naima.plot_blob( sampler, 1, ML_info=False, - label='Electron energy distribution', - xlabel=r'Electron energy [$\mathrm{TeV}$]') + label="Electron energy distribution", + xlabel=r"Electron energy [$\mathrm{TeV}$]", + ) f.tight_layout() - f.savefig('RXJ1713_IC_pdist.png', dpi=alabaster_dpi) - f = naima.plot_blob(sampler, 2, label=r'$W_e(E_e>1\,\mathrm{TeV})$') - f.savefig('RXJ1713_IC_We.png', dpi=alabaster_dpi) + f.savefig("RXJ1713_IC_pdist.png", dpi=alabaster_dpi) + f = naima.plot_blob(sampler, 2, label=r"$W_e(E_e>1\,\mathrm{TeV})$") + f.savefig("RXJ1713_IC_We.png", dpi=alabaster_dpi) diff --git a/docs/conf.py b/docs/conf.py index ec673fde..b2cfcdec 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -33,15 +33,15 @@ import astropy_helpers except ImportError: # Building from inside the docs/ directory? 
-    if os.path.basename(os.getcwd()) == 'docs':
-        a_h_path = os.path.abspath(os.path.join('..', 'astropy_helpers'))
+    if os.path.basename(os.getcwd()) == "docs":
+        a_h_path = os.path.abspath(os.path.join("..", "astropy_helpers"))
         if os.path.isdir(a_h_path):
             sys.path.insert(1, a_h_path)
 
 # If that doesn't work, trying to import from astropy_helpers below will
 # still blow up
 
-on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
+on_rtd = os.environ.get("READTHEDOCS", None) == "True"
 
 if on_rtd:
     if sys.version_info.major > 2:
@@ -52,9 +52,9 @@ class Mock(MagicMock):
         @classmethod
         def __getattr__(cls, name):
-                return Mock()
+            return Mock()
 
-    MOCK_MODULES = ['h5py',]
+    MOCK_MODULES = ["h5py"]
     sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
 
 # Load all of the global Astropy configuration
@@ -66,22 +66,22 @@ def __getattr__(cls, name):
 except ImportError:
     from configparser import ConfigParser
 conf = ConfigParser()
-conf.read([os.path.join(os.path.dirname(__file__), '..', 'setup.cfg')])
-setup_cfg = dict(conf.items('metadata'))
+conf.read([os.path.join(os.path.dirname(__file__), "..", "setup.cfg")])
+setup_cfg = dict(conf.items("metadata"))
 
 # -- General configuration ----------------------------------------------------
 
 # If your documentation needs a minimal Sphinx version, state it here.
-needs_sphinx = '1.3'
+needs_sphinx = "1.3"
 
 # del intersphinx_mapping['h5py']
-intersphinx_mapping['emcee'] = ('http://dan.iel.fm/emcee/current/', None)
+intersphinx_mapping["emcee"] = ("http://dan.iel.fm/emcee/current/", None)
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
-exclude_patterns.append('_templates')
+exclude_patterns.append("_templates")
 
 # This is added to the end of RST files - a good place to put substitutions to
 # be used globally.
@@ -91,20 +91,21 @@ def __getattr__(cls, name):
 # -- Project information ------------------------------------------------------
 
 # This does not *have* to match the package name, but typically does
-project = setup_cfg['package_name']
-author = setup_cfg['author']
-copyright = '{0}, {1}'.format(
-    datetime.datetime.now().year, setup_cfg['author'])
+project = setup_cfg["package_name"]
+author = setup_cfg["author"]
+copyright = "{0}, {1}".format(
+    datetime.datetime.now().year, setup_cfg["author"]
+)
 
 # The version info for the project you're documenting, acts as replacement for
 # |version| and |release|, also used in various other places throughout the
 # built documents.
-__import__(setup_cfg['package_name'])
-package = sys.modules[setup_cfg['package_name']]
+__import__(setup_cfg["package_name"])
+package = sys.modules[setup_cfg["package_name"]]
 
 # The short X.Y version.
-version = package.__version__.split('-', 1)[0]
+version = package.__version__.split("-", 1)[0]
 # The full version, including alpha/beta/rc tags.
 release = package.__version__
 
@@ -120,80 +121,77 @@ def __getattr__(cls, name):
 
 # Add any paths that contain custom themes here, relative to this directory.
 # To use a different custom theme, add the directory containing the theme.
-#html_theme_path = []
+# html_theme_path = []
 
 # The theme to use for HTML and HTML Help pages.  See the documentation for
 # a list of builtin themes. To override the custom theme, set this to the
 # name of a builtin theme or the name of a custom theme in html_theme_path.
-html_theme = 'alabaster'
+html_theme = "alabaster"
 
 html_theme_options = {
-    'description': '''Python package for computation of non-thermal
+    "description": """Python package for computation of non-thermal
                    radiation from relativistic particle populations and
-                   MCMC fitting to observed spectra''',
-    'github_user': 'zblz',
-    'github_repo': 'naima',
-    'github_banner': True,
-    'github_button': False,
-    'travis_button': False,
+                   MCMC fitting to observed spectra""",
+    "github_user": "zblz",
+    "github_repo": "naima",
+    "github_banner": True,
+    "github_button": False,
+    "travis_button": False,
     # use sans-serif fonts
-    'font_family': "'Myriad Pro', Calibri, Helvetica, Arial, sans-serif",
-    'head_font_family': "'Lucida Grande', 'Calibri', Helvetica, Arial, sans-serif",
-    'show_powered_by': False,
-    'show_related': True,
-    'code_font_size': '0.7em',
-    }
+    "font_family": "'Myriad Pro', Calibri, Helvetica, Arial, sans-serif",
+    "head_font_family": "'Lucida Grande', 'Calibri', Helvetica, Arial, sans-serif",
+    "show_powered_by": False,
+    "show_related": True,
+    "code_font_size": "0.7em",
+}
 
 # Custom sidebar templates, maps document names to template names.
-#html_sidebars = {}
+# html_sidebars = {}
 html_sidebars = {
-    '**': [
-        'about.html',
-        'navigation.html',
-        'relations.html',
-        'searchbox.html',
-    ]
+    "**": ["about.html", "navigation.html", "relations.html", "searchbox.html"]
 }
 
 # The name of an image file (within the static path) to use as favicon of the
 # docs.  This file should be a Windows icon file (.ico) being 16x16 or 32x32
 # pixels large.
-#html_favicon = ''
+# html_favicon = ''
 
 # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
 # using the given strftime format.
-#html_last_updated_fmt = ''
+# html_last_updated_fmt = ''
 
 # The name for this set of Sphinx documents.  If None, it defaults to
 # "<project> v<release> documentation".
-html_title = '{0} v{1}'.format(project, release)
+html_title = "{0} v{1}".format(project, release)
 
 # Output file base name for HTML help builder.
-htmlhelp_basename = project + 'doc'
+htmlhelp_basename = project + "doc"
 
 
 # -- Options for LaTeX output --------------------------------------------------
 
 # Grouping the document tree into LaTeX files. List of tuples
 # (source start file, target name, title, author, documentclass [howto/manual]).
-latex_documents = [('index', project + '.tex', project + u' Documentation',
-                    author, 'manual')]
+latex_documents = [
+    ("index", project + ".tex", project + u" Documentation", author, "manual")
+]
 
 
 # -- Options for manual page output --------------------------------------------
 
 # One entry per manual page. List of tuples
 # (source start file, name, description, authors, manual section).
-man_pages = [('index', project.lower(), project + u' Documentation',
-              [author], 1)]
+man_pages = [
+    ("index", project.lower(), project + u" Documentation", [author], 1)
+]
 
 
 ## -- Options for the edit_on_github extension ----------------------------------------
 
-if eval(setup_cfg.get('edit_on_github')):
-    extensions += ['astropy_helpers.sphinx.ext.edit_on_github']
+if eval(setup_cfg.get("edit_on_github")):
+    extensions += ["astropy_helpers.sphinx.ext.edit_on_github"]
 
-    versionmod = __import__(setup_cfg['package_name'] + '.version')
-    edit_on_github_project = setup_cfg['github_project']
+    versionmod = __import__(setup_cfg["package_name"] + ".version")
+    edit_on_github_project = setup_cfg["github_project"]
     if versionmod.version.release:
         edit_on_github_branch = "v" + versionmod.version.version
     else:
@@ -201,4 +199,3 @@ def __getattr__(cls, name):
 
     edit_on_github_source_root = ""
     edit_on_github_doc_root = "docs"
-
diff --git a/examples/CrabNebula_SynSSC.py b/examples/CrabNebula_SynSSC.py
index 41ca9955..f372f302 100644
--- a/examples/CrabNebula_SynSSC.py
+++ b/examples/CrabNebula_SynSSC.py
@@ -3,18 +3,23 @@
 from astropy.constants import c
 import astropy.units as u
 import naima
-from naima.models import (ExponentialCutoffBrokenPowerLaw, Synchrotron,
-                          InverseCompton)
+from naima.models import (
+    ExponentialCutoffBrokenPowerLaw,
+    Synchrotron,
+    InverseCompton,
+)
 
-ECBPL = ExponentialCutoffBrokenPowerLaw(amplitude=3.699e36 / u.eV,
-                                        e_0=1 * u.TeV,
-                                        e_break=0.265 * u.TeV,
-                                        alpha_1=1.5,
-                                        alpha_2=3.233,
-                                        e_cutoff=1863 * u.TeV,
-                                        beta=2.)
+ECBPL = ExponentialCutoffBrokenPowerLaw(
+    amplitude=3.699e36 / u.eV,
+    e_0=1 * u.TeV,
+    e_break=0.265 * u.TeV,
+    alpha_1=1.5,
+    alpha_2=3.233,
+    e_cutoff=1863 * u.TeV,
+    beta=2.0,
+)
 
-eopts = {'Eemax': 50 * u.PeV, 'Eemin': 0.1 * u.GeV}
+eopts = {"Eemax": 50 * u.PeV, "Eemin": 0.1 * u.GeV}
 
 SYN = Synchrotron(ECBPL, B=125 * u.uG, Eemax=50 * u.PeV, Eemin=0.1 * u.GeV)
 
@@ -22,31 +27,48 @@
 Rpwn = 2.1 * u.pc
 Esy = np.logspace(-7, 9, 100) * u.eV
 Lsy = SYN.flux(Esy, distance=0 * u.cm)  # use distance 0 to get luminosity
-phn_sy = Lsy / (4 * np.pi * Rpwn**2 * c) * 2.24
+phn_sy = Lsy / (4 * np.pi * Rpwn ** 2 * c) * 2.24
 
-IC = InverseCompton(ECBPL,
-                    seed_photon_fields=['CMB',
-                                        ['FIR', 70 * u.K, 0.5 * u.eV / u.cm**3],
-                                        ['NIR', 5000 * u.K, 1 * u.eV / u.cm**3],
-                                        ['SSC', Esy, phn_sy]],
-                    Eemax=50 * u.PeV, Eemin=0.1 * u.GeV)
+IC = InverseCompton(
+    ECBPL,
+    seed_photon_fields=[
+        "CMB",
+        ["FIR", 70 * u.K, 0.5 * u.eV / u.cm ** 3],
+        ["NIR", 5000 * u.K, 1 * u.eV / u.cm ** 3],
+        ["SSC", Esy, phn_sy],
+    ],
+    Eemax=50 * u.PeV,
+    Eemin=0.1 * u.GeV,
+)
 
 # Use plot_data from naima to plot the observed spectra
-data = ascii.read('CrabNebula_spectrum.ecsv')
+data = ascii.read("CrabNebula_spectrum.ecsv")
 figure = naima.plot_data(data, e_unit=u.eV)
 ax = figure.axes[0]
 
 # Plot the computed model emission
 energy = np.logspace(-7, 15, 100) * u.eV
-ax.loglog(energy, IC.sed(energy, 2 * u.kpc) + SYN.sed(energy, 2 * u.kpc),
-          lw=3, c='k', label='Total')
+ax.loglog(
+    energy,
+    IC.sed(energy, 2 * u.kpc) + SYN.sed(energy, 2 * u.kpc),
+    lw=3,
+    c="k",
+    label="Total",
+)
 for i, seed, ls in zip(
-        range(4), ['CMB', 'FIR', 'NIR', 'SSC'], ['--', '-.', ':', '-']):
-    ax.loglog(energy, IC.sed(energy, 2 * u.kpc, seed=seed),
-              lw=2, c=naima.plot.color_cycle[i + 1], label=seed, ls=ls)
+    range(4), ["CMB", "FIR", "NIR", "SSC"], ["--", "-.", ":", "-"]
+):
+    ax.loglog(
+        energy,
+        IC.sed(energy, 2 * u.kpc, seed=seed),
+        lw=2,
+        c=naima.plot.color_cycle[i + 1],
+        label=seed,
+        ls=ls,
+    )
 ax.set_ylim(1e-12, 1e-7)
-ax.legend(loc='upper right', frameon=False)
+ax.legend(loc="upper right", frameon=False)
 figure.tight_layout()
-figure.savefig('CrabNebula_SynSSC.png')
+figure.savefig("CrabNebula_SynSSC.png")
diff --git a/examples/RXJ1713_IC.py b/examples/RXJ1713_IC.py
index 8e354d35..4d6ca63c 100644
--- a/examples/RXJ1713_IC.py
+++ b/examples/RXJ1713_IC.py
@@ -6,7 +6,7 @@
 
 ## Read data
 
-data = ascii.read('RXJ1713_HESS_2007.dat')
+data = ascii.read("RXJ1713_HESS_2007.dat")
 
 ## Model definition
 
@@ -18,20 +18,24 @@ def ElectronIC(pars, data):
     # Match parameters to ECPL properties, and give them the appropriate units
     amplitude = pars[0] / u.eV
     alpha = pars[1]
-    e_cutoff = (10**pars[2]) * u.TeV
+    e_cutoff = (10 ** pars[2]) * u.TeV
 
     # Initialize instances of the particle distribution and radiative model
-    ECPL = ExponentialCutoffPowerLaw(amplitude, 10. * u.TeV, alpha, e_cutoff)
+    ECPL = ExponentialCutoffPowerLaw(amplitude, 10.0 * u.TeV, alpha, e_cutoff)
     # Compute IC on CMB and on a FIR component with values from GALPROP for the
     # position of RXJ1713
     IC = InverseCompton(
         ECPL,
-        seed_photon_fields=['CMB', ['FIR', 26.5 * u.K, 0.415 * u.eV / u.cm**3]],
-        Eemin=100 * u.GeV)
+        seed_photon_fields=[
+            "CMB",
+            ["FIR", 26.5 * u.K, 0.415 * u.eV / u.cm ** 3],
+        ],
+        Eemin=100 * u.GeV,
+    )
 
     # compute flux at the energies given in data['energy'], and convert to units
     # of flux data
-    model = IC.flux(data, distance=1.0 * u.kpc).to(data['flux'].unit)
+    model = IC.flux(data, distance=1.0 * u.kpc).to(data["flux"].unit)
 
     # Save this realization of the particle distribution function
     elec_energy = np.logspace(11, 15, 100) * u.eV
@@ -45,6 +49,7 @@ def ElectronIC(pars, data):
     # blobs.
     return model, (elec_energy, nelec), We
 
+
 ## Prior definition
 
@@ -54,39 +59,46 @@ def lnprior(pars):
     Parameter limits should be done here through uniform prior distributions
     """
-    logprob = naima.uniform_prior(pars[0], 0., np.inf) \
-        + naima.uniform_prior(pars[1], -1, 5)
+    logprob = naima.uniform_prior(pars[0], 0.0, np.inf) + naima.uniform_prior(
+        pars[1], -1, 5
+    )
 
     return logprob
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
 
     ## Set initial parameters and labels
-    p0 = np.array((1e30, 3.0, np.log10(30),))
-    labels = ['norm', 'index', 'log10(cutoff)']
+    p0 = np.array((1e30, 3.0, np.log10(30)))
+    labels = ["norm", "index", "log10(cutoff)"]
 
     ## Run sampler
-    sampler, pos = naima.run_sampler(data_table=data,
-                                     p0=p0,
-                                     labels=labels,
-                                     model=ElectronIC,
-                                     prior=lnprior,
-                                     nwalkers=32,
-                                     nburn=100,
-                                     nrun=20,
-                                     threads=4,
-                                     prefit=True)
+    sampler, pos = naima.run_sampler(
+        data_table=data,
+        p0=p0,
+        labels=labels,
+        model=ElectronIC,
+        prior=lnprior,
+        nwalkers=32,
+        nburn=100,
+        nrun=20,
+        threads=4,
+        prefit=True,
+    )
 
     ## Save run results to HDF5 file (can be read later with naima.read_run)
-    naima.save_run('RXJ1713_IC_run.hdf5', sampler)
+    naima.save_run("RXJ1713_IC_run.hdf5", sampler)
 
     ## Diagnostic plots with labels for the metadata blobs
     naima.save_diagnostic_plots(
-        'RXJ1713_IC',
+        "RXJ1713_IC",
         sampler,
         sed=True,
         last_step=False,
-        blob_labels=['Spectrum', 'Electron energy distribution',
-                     '$W_e (E_e>1\, \mathrm{TeV})$'])
-    naima.save_results_table('RXJ1713_IC', sampler)
+        blob_labels=[
+            "Spectrum",
+            "Electron energy distribution",
+            "$W_e (E_e>1\, \mathrm{TeV})$",
+        ],
+    )
+    naima.save_results_table("RXJ1713_IC", sampler)
diff --git a/examples/RXJ1713_IC_minimal.py b/examples/RXJ1713_IC_minimal.py
index 2a3b60a9..c0abb319 100644
--- a/examples/RXJ1713_IC_minimal.py
+++ b/examples/RXJ1713_IC_minimal.py
@@ -7,7 +7,7 @@
 
 ## Read data
-data = ascii.read('RXJ1713_HESS_2007.dat')
+data = ascii.read("RXJ1713_HESS_2007.dat")
 
 
 def ElectronIC(pars, data):
@@ -16,39 +16,42 @@ def ElectronIC(pars, data):
     at data energy values
     """
 
-    ECPL = ExponentialCutoffPowerLaw(pars[0] / u.eV, 10. * u.TeV, pars[1],
-                                     10**pars[2] * u.TeV)
-    IC = InverseCompton(ECPL, seed_photon_fields=['CMB'])
+    ECPL = ExponentialCutoffPowerLaw(
+        pars[0] / u.eV, 10.0 * u.TeV, pars[1], 10 ** pars[2] * u.TeV
+    )
+    IC = InverseCompton(ECPL, seed_photon_fields=["CMB"])
 
     return IC.flux(data, distance=1.0 * u.kpc)
 
 
 def lnprior(pars):
     # Limit amplitude to positive domain
-    logprob = naima.uniform_prior(pars[0], 0., np.inf)
+    logprob = naima.uniform_prior(pars[0], 0.0, np.inf)
     return logprob
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
 
     ## Set initial parameters and labels
-    p0 = np.array((1e30, 3.0, np.log10(30),))
-    labels = ['norm', 'index', 'log10(cutoff)']
+    p0 = np.array((1e30, 3.0, np.log10(30)))
+    labels = ["norm", "index", "log10(cutoff)"]
 
     ## Run sampler
-    sampler, pos = naima.run_sampler(data_table=data,
-                                     p0=p0,
-                                     labels=labels,
-                                     model=ElectronIC,
-                                     prior=lnprior,
-                                     nwalkers=32,
-                                     nburn=100,
-                                     nrun=20,
-                                     threads=4,
-                                     prefit=True,
-                                     interactive=False)
+    sampler, pos = naima.run_sampler(
+        data_table=data,
+        p0=p0,
+        labels=labels,
+        model=ElectronIC,
+        prior=lnprior,
+        nwalkers=32,
+        nburn=100,
+        nrun=20,
+        threads=4,
+        prefit=True,
+        interactive=False,
+    )
 
     ## Save run results
-    out_root = 'RXJ1713_IC_minimal'
+    out_root = "RXJ1713_IC_minimal"
     naima.save_run(out_root, sampler)
 
     ## Save diagnostic plots and results table
diff --git a/examples/RXJ1713_SynIC.py b/examples/RXJ1713_SynIC.py
index 4e7e5108..43e591d2 100644
--- a/examples/RXJ1713_SynIC.py
+++ b/examples/RXJ1713_SynIC.py
@@ -9,8 +9,8 @@
 # We only consider every fifth X-ray spectral point to speed up calculations for this example
 # DO NOT do this for a final analysis!
-soft_xray = ascii.read('RXJ1713_Suzaku-XIS.dat')[::5]
-vhe = ascii.read('RXJ1713_HESS_2007.dat')
+soft_xray = ascii.read("RXJ1713_Suzaku-XIS.dat")[::5]
+vhe = ascii.read("RXJ1713_HESS_2007.dat")
 
 ## Model definition
 
@@ -20,31 +20,36 @@ def ElectronSynIC(pars, data):
 
     # Match parameters to ECPL properties, and give them the appropriate units
-    amplitude = 10**pars[0] / u.eV
+    amplitude = 10 ** pars[0] / u.eV
     alpha = pars[1]
-    e_cutoff = (10**pars[2]) * u.TeV
+    e_cutoff = (10 ** pars[2]) * u.TeV
     B = pars[3] * u.uG
 
     # Initialize instances of the particle distribution and radiative models
-    ECPL = ExponentialCutoffPowerLaw(amplitude, 10. * u.TeV, alpha, e_cutoff)
+    ECPL = ExponentialCutoffPowerLaw(amplitude, 10.0 * u.TeV, alpha, e_cutoff)
     # Compute IC on CMB and on a FIR component with values from GALPROP for the
     # position of RXJ1713
     IC = InverseCompton(
         ECPL,
-        seed_photon_fields=['CMB', ['FIR', 26.5 * u.K, 0.415 * u.eV / u.cm**3]],
-        Eemin=100 * u.GeV)
+        seed_photon_fields=[
+            "CMB",
+            ["FIR", 26.5 * u.K, 0.415 * u.eV / u.cm ** 3],
+        ],
+        Eemin=100 * u.GeV,
+    )
     SYN = Synchrotron(ECPL, B=B)
 
     # compute flux at the energies given in data['energy']
-    model = (IC.flux(data,
-                     distance=1.0 * u.kpc) + SYN.flux(data,
-                                                      distance=1.0 * u.kpc))
+    model = IC.flux(data, distance=1.0 * u.kpc) + SYN.flux(
+        data, distance=1.0 * u.kpc
+    )
 
     # The first array returned will be compared to the observed spectrum for
     # fitting. All subsequent objects will be stored in the sampler metadata
     # blobs.
     return model, IC.compute_We(Eemin=1 * u.TeV)
 
+
 ## Prior definition
 
@@ -54,41 +59,47 @@ def lnprior(pars):
     Parameter limits should be done here through uniform prior distributions
     """
     # Limit norm and B to be positive
-    logprob = naima.uniform_prior(pars[0], 0., np.inf) \
-        + naima.uniform_prior(pars[1], -1, 5) \
-        + naima.uniform_prior(pars[3], 0, np.inf)
+    logprob = (
+        naima.uniform_prior(pars[0], 0.0, np.inf)
+        + naima.uniform_prior(pars[1], -1, 5)
+        + naima.uniform_prior(pars[3], 0, np.inf)
+    )
 
     return logprob
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
 
     ## Set initial parameters and labels
     # Estimate initial magnetic field and get value in uG
-    B0 = 2 * naima.estimate_B(soft_xray, vhe).to('uG').value
+    B0 = 2 * naima.estimate_B(soft_xray, vhe).to("uG").value
 
     p0 = np.array((33, 2.5, np.log10(48.0), B0))
-    labels = ['log10(norm)', 'index', 'log10(cutoff)', 'B']
+    labels = ["log10(norm)", "index", "log10(cutoff)", "B"]
 
     ## Run sampler
-    sampler, pos = naima.run_sampler(data_table=[soft_xray, vhe],
-                                     p0=p0,
-                                     labels=labels,
-                                     model=ElectronSynIC,
-                                     prior=lnprior,
-                                     nwalkers=32,
-                                     nburn=100,
-                                     nrun=20,
-                                     threads=4,
-                                     prefit=True,
-                                     interactive=False)
+    sampler, pos = naima.run_sampler(
+        data_table=[soft_xray, vhe],
+        p0=p0,
+        labels=labels,
+        model=ElectronSynIC,
+        prior=lnprior,
+        nwalkers=32,
+        nburn=100,
+        nrun=20,
+        threads=4,
+        prefit=True,
+        interactive=False,
+    )
 
     ## Save run results to HDF5 file (can be read later with naima.read_run)
-    naima.save_run('RXJ1713_SynIC', sampler)
+    naima.save_run("RXJ1713_SynIC", sampler)
 
     ## Diagnostic plots
-    naima.save_diagnostic_plots('RXJ1713_SynIC',
-                                sampler,
-                                sed=True,
-                                blob_labels=['Spectrum', '$W_e$($E_e>1$ TeV)'])
-    naima.save_results_table('RXJ1713_SynIC', sampler)
+    naima.save_diagnostic_plots(
+        "RXJ1713_SynIC",
+        sampler,
+        sed=True,
+        blob_labels=["Spectrum", "$W_e$($E_e>1$ TeV)"],
+    )
+    naima.save_results_table("RXJ1713_SynIC", sampler)
diff --git a/examples/absorbed_SynIC.py b/examples/absorbed_SynIC.py
index 9d473811..ea368e2a 100644
--- a/examples/absorbed_SynIC.py
+++ b/examples/absorbed_SynIC.py
@@ -7,35 +7,39 @@
 
 # Model definition
 
-from naima.models import InverseCompton, Synchrotron, ExponentialCutoffPowerLaw, BrokenPowerLaw, EblAbsorptionModel
+from naima.models import (
+    InverseCompton,
+    Synchrotron,
+    ExponentialCutoffPowerLaw,
+    BrokenPowerLaw,
+    EblAbsorptionModel,
+)
 
 
 def ElectronEblAbsorbedSynIC(pars, data):
 
     # Match parameters to ECPL properties, and give them the appropriate units
-    amplitude = 10**pars[0] / u.eV
-    e_break = (10**pars[1]) * u.TeV
+    amplitude = 10 ** pars[0] / u.eV
+    e_break = (10 ** pars[1]) * u.TeV
     alpha1 = pars[2]
     alpha2 = pars[3]
     B = pars[4] * u.uG
 
     # Define the redshift of the source, and absorption model
     redshift = pars[5] * u.dimensionless_unscaled
-    EBL_transmitance = EblAbsorptionModel(redshift, 'Dominguez')
+    EBL_transmitance = EblAbsorptionModel(redshift, "Dominguez")
 
     # Initialize instances of the particle distribution and radiative models
-    BPL = BrokenPowerLaw(amplitude, 1. * u.TeV, e_break, alpha1, alpha2)
+    BPL = BrokenPowerLaw(amplitude, 1.0 * u.TeV, e_break, alpha1, alpha2)
 
     # Compute IC on a CMB component
-    IC = InverseCompton(
-        BPL,
-        seed_photon_fields=['CMB'],
-        Eemin=10 * u.GeV)
+    IC = InverseCompton(BPL, seed_photon_fields=["CMB"], Eemin=10 * u.GeV)
     SYN = Synchrotron(BPL, B=B)
 
     # compute flux at the energies given in data['energy']
-    model = (EBL_transmitance.transmission(data) * IC.flux(data, distance=1.0 * u.kpc) +
-             SYN.flux(data, distance=1.0 * u.kpc))
+    model = EBL_transmitance.transmission(data) * IC.flux(
+        data, distance=1.0 * u.kpc
+    ) + SYN.flux(data, distance=1.0 * u.kpc)
 
     # The first array returned will be compared to the observed spectrum for
     # fitting. All subsequent objects will be stored in the sampler metadata
@@ -43,12 +47,24 @@ def ElectronEblAbsorbedSynIC(pars, data):
     return model, IC.compute_We(Eemin=1 * u.TeV)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
 
     # Some random values for a "beautiful double peak" structure
-    p0 = np.array((31., 1., 0.35, 1.5, 2.3, 0.06))
+    p0 = np.array((31.0, 1.0, 0.35, 1.5, 2.3, 0.06))
 
-    labels = ['log10(norm)', 'log10(Energy_Break)', 'index1', 'index2', 'B', 'redshift']
+    labels = [
+        "log10(norm)",
+        "log10(Energy_Break)",
+        "index1",
+        "index2",
+        "B",
+        "redshift",
+    ]
 
     # Run interactive fitter, to show the very high energy absorption
-    imf = naima.InteractiveModelFitter(ElectronEblAbsorbedSynIC, p0, labels=labels, e_range=[1e-09*u.GeV, 1e05*u.GeV])
+    imf = naima.InteractiveModelFitter(
+        ElectronEblAbsorbedSynIC,
+        p0,
+        labels=labels,
+        e_range=[1e-09 * u.GeV, 1e05 * u.GeV],
+    )
diff --git a/examples/model_examples.py b/examples/model_examples.py
index 2599eb5b..7e002d3c 100644
--- a/examples/model_examples.py
+++ b/examples/model_examples.py
@@ -15,21 +15,22 @@
 # Pion decay
 # ==========
 
-PionDecay_ECPL_p0 = np.array((46, 2.34, np.log10(80.)))
-PionDecay_ECPL_labels = ['log10(norm)', 'index', 'log10(cutoff)']
+PionDecay_ECPL_p0 = np.array((46, 2.34, np.log10(80.0)))
+PionDecay_ECPL_labels = ["log10(norm)", "index", "log10(cutoff)"]
 
 # Prepare an energy array for saving the particle distribution
 proton_energy = np.logspace(-3, 2, 50) * u.TeV
 
 
 def PionDecay_ECPL(pars, data):
-    amplitude = 10**pars[0] / u.TeV
+    amplitude = 10 ** pars[0] / u.TeV
     alpha = pars[1]
-    e_cutoff = 10**pars[2] * u.TeV
+    e_cutoff = 10 ** pars[2] * u.TeV
 
-    ECPL = naima.models.ExponentialCutoffPowerLaw(amplitude, 30 * u.TeV, alpha,
-                                                  e_cutoff)
-    PP = naima.models.PionDecay(ECPL, nh=1.0 * u.cm** -3)
+    ECPL = naima.models.ExponentialCutoffPowerLaw(
+        amplitude, 30 * u.TeV, alpha, e_cutoff
+    )
+    PP = naima.models.PionDecay(ECPL, nh=1.0 * u.cm ** -3)
 
     model = PP.flux(data, distance=1.0 * u.kpc)
 
     # Save a realization of the particle distribution to the metadata blob
@@ -46,26 +47,28 @@ def PionDecay_ECPL_lnprior(pars):
     logprob = naima.uniform_prior(pars[1], -1, 5)
     return logprob
 
+
 # Inverse Compton with the energy in electrons as the normalization parameter
 # ===========================================================================
 
 IC_We_p0 = np.array((40, 3.0, np.log10(30)))
-IC_We_labels = ['log10(We)', 'index', 'log10(cutoff)']
+IC_We_labels = ["log10(We)", "index", "log10(cutoff)"]
 
 
 def IC_We(pars, data):
     # Example of a model that is normalized through the total energy in electrons
 
     # Match parameters to ECPL properties, and give them the appropriate units
-    We = 10**pars[0] * u.erg
+    We = 10 ** pars[0] * u.erg
     alpha = pars[1]
-    e_cutoff = 10**pars[2] * u.TeV
+    e_cutoff = 10 ** pars[2] * u.TeV
 
     # Initialize instances of the particle distribution and radiative model
     # set a bogus normalization that will be changed in the third line
-    ECPL = naima.models.ExponentialCutoffPowerLaw(1 / u.eV, 10. * u.TeV,
-                                                  alpha, e_cutoff)
-    IC = naima.models.InverseCompton(ECPL, seed_photon_fields=['CMB'])
+    ECPL = naima.models.ExponentialCutoffPowerLaw(
+        1 / u.eV, 10.0 * u.TeV, alpha, e_cutoff
+    )
+    IC = naima.models.InverseCompton(ECPL, seed_photon_fields=["CMB"])
     IC.set_We(We, Eemin=1 * u.TeV)
 
     # compute flux at the energies given in data['energy']
@@ -82,6 +85,7 @@ def IC_We_lnprior(pars):
     logprob = naima.uniform_prior(pars[1], -1, 5)
     return logprob
 
+
 #
 # FUNCTIONAL MODELS
 #
@@ -89,34 +93,37 @@
 # Exponential cutoff powerlaw
 # ===========================
 
 ECPL_p0 = np.array((1e-12, 2.4, np.log10(15.0)))
-ECPL_labels = ['norm', 'index', 'log10(cutoff)']
+ECPL_labels = ["norm", "index", "log10(cutoff)"]
 
 
 def ECPL(pars, data):
     # Get the units of the flux data and match them in the model amplitude
-    amplitude = pars[0] * data['flux'].unit
+    amplitude = pars[0] * data["flux"].unit
     alpha = pars[1]
-    e_cutoff = (10**pars[2]) * u.TeV
-    ECPL = naima.models.ExponentialCutoffPowerLaw(amplitude, 1 * u.TeV, alpha,
-                                                  e_cutoff)
+    e_cutoff = (10 ** pars[2]) * u.TeV
+    ECPL = naima.models.ExponentialCutoffPowerLaw(
+        amplitude, 1 * u.TeV, alpha, e_cutoff
+    )
     return ECPL(data)
 
 
 def ECPL_lnprior(pars):
-    logprob = naima.uniform_prior(pars[0], 0., np.inf) \
-        + naima.uniform_prior(pars[1], -1, 5)
+    logprob = naima.uniform_prior(pars[0], 0.0, np.inf) + naima.uniform_prior(
+        pars[1], -1, 5
+    )
     return logprob
 
+
 # Log-Parabola or Curved Powerlaw
 # ===============================
 
-LP_p0 = np.array((1.5e-12, 2.7, 0.12,))
-LP_labels = ['norm', 'alpha', 'beta']
+LP_p0 = np.array((1.5e-12, 2.7, 0.12))
+LP_labels = ["norm", "alpha", "beta"]
 
 
 def LP(pars, data):
-    amplitude = pars[0] * data['flux'].unit
+    amplitude = pars[0] * data["flux"].unit
     alpha = pars[1]
     beta = pars[2]
     LP = naima.models.LogParabola(amplitude, 1 * u.TeV, alpha, beta)
@@ -124,6 +131,7 @@ def LP(pars, data):
 
 
 def LP_lnprior(pars):
-    logprob = naima.uniform_prior(pars[0], 0., np.inf) \
-        + naima.uniform_prior(pars[1], -1, 5)
+    logprob = naima.uniform_prior(pars[0], 0.0, np.inf) + naima.uniform_prior(
+        pars[1], -1, 5
+    )
     return logprob
diff --git a/ez_setup.py b/ez_setup.py
index 800c31ef..5bfacfd1 100644
--- a/ez_setup.py
+++ b/ez_setup.py
@@ -39,7 +39,9 @@
 DEFAULT_SAVE_DIR = os.curdir
 DEFAULT_DEPRECATION_MESSAGE = "ez_setup.py is deprecated and when using it setuptools will be pinned to {0} since it's the last version that supports setuptools self upgrade/installation, check https://github.com/pypa/setuptools/issues/581 for more info; use pip to install setuptools"
 
-MEANINGFUL_INVALID_ZIP_ERR_MSG = 'Maybe {0} is corrupted, delete it and try again.'
+MEANINGFUL_INVALID_ZIP_ERR_MSG = (
+    "Maybe {0} is corrupted, delete it and try again."
+)
 
 log.warn(DEFAULT_DEPRECATION_MESSAGE.format(DEFAULT_VERSION))
@@ -58,10 +60,10 @@ def _install(archive_filename, install_args=()):
     """Install Setuptools."""
     with archive_context(archive_filename):
         # installing
-        log.warn('Installing Setuptools')
-        if not _python_cmd('setup.py', 'install', *install_args):
-            log.warn('Something went wrong during the installation.')
-            log.warn('See the error message above.')
+        log.warn("Installing Setuptools")
+        if not _python_cmd("setup.py", "install", *install_args):
+            log.warn("Something went wrong during the installation.")
+            log.warn("See the error message above.")
             # exitcode will be 2
             return 2
 
@@ -70,12 +72,12 @@ def _build_egg(egg, archive_filename, to_dir):
     """Build Setuptools egg."""
     with archive_context(archive_filename):
         # building an egg
-        log.warn('Building a Setuptools egg in %s', to_dir)
-        _python_cmd('setup.py', '-q', 'bdist_egg', '--dist-dir', to_dir)
+        log.warn("Building a Setuptools egg in %s", to_dir)
+        _python_cmd("setup.py", "-q", "bdist_egg", "--dist-dir", to_dir)
     # returning the result
     log.warn(egg)
     if not os.path.exists(egg):
-        raise IOError('Could not build the egg.')
+        raise IOError("Could not build the egg.")
 
 
 class ContextualZipFile(zipfile.ZipFile):
@@ -90,7 +92,7 @@ def __exit__(self, type, value, traceback):
 
     def __new__(cls, *args, **kwargs):
         """Construct a ZipFile or ContextualZipFile as appropriate."""
-        if hasattr(zipfile.ZipFile, '__exit__'):
+        if hasattr(zipfile.ZipFile, "__exit__"):
             return zipfile.ZipFile(*args, **kwargs)
         return super(ContextualZipFile, cls).__new__(cls)
 
@@ -103,7 +105,7 @@ def archive_context(filename):
     The unzipped target is cleaned up after.
     """
     tmpdir = tempfile.mkdtemp()
-    log.warn('Extracting in %s', tmpdir)
+    log.warn("Extracting in %s", tmpdir)
     old_wd = os.getcwd()
     try:
         os.chdir(tmpdir)
@@ -112,7 +114,7 @@ def archive_context(filename):
                 archive.extractall()
         except zipfile.BadZipfile as err:
             if not err.args:
-                err.args = ('', )
+                err.args = ("",)
             err.args = err.args + (
                 MEANINGFUL_INVALID_ZIP_ERR_MSG.format(filename),
             )
@@ -121,7 +123,7 @@ def archive_context(filename):
         # going in the directory
         subdir = os.path.join(tmpdir, os.listdir(tmpdir)[0])
         os.chdir(subdir)
-        log.warn('Now working in %s', subdir)
+        log.warn("Now working in %s", subdir)
         yield
 
     finally:
@@ -131,27 +133,32 @@ def archive_context(filename):
 
 def _do_download(version, download_base, to_dir, download_delay):
     """Download Setuptools."""
-    py_desig = 'py{sys.version_info[0]}.{sys.version_info[1]}'.format(sys=sys)
-    tp = 'setuptools-{version}-{py_desig}.egg'
+    py_desig = "py{sys.version_info[0]}.{sys.version_info[1]}".format(sys=sys)
+    tp = "setuptools-{version}-{py_desig}.egg"
     egg = os.path.join(to_dir, tp.format(**locals()))
     if not os.path.exists(egg):
-        archive = download_setuptools(version, download_base,
-                                      to_dir, download_delay)
+        archive = download_setuptools(
+            version, download_base, to_dir, download_delay
+        )
         _build_egg(egg, archive, to_dir)
     sys.path.insert(0, egg)
 
     # Remove previously-imported pkg_resources if present (see
     # https://bitbucket.org/pypa/setuptools/pull-request/7/ for details).
-    if 'pkg_resources' in sys.modules:
+    if "pkg_resources" in sys.modules:
         _unload_pkg_resources()
 
     import setuptools
+
     setuptools.bootstrap_install_from = egg
 
 
 def use_setuptools(
-        version=DEFAULT_VERSION, download_base=DEFAULT_URL,
-        to_dir=DEFAULT_SAVE_DIR, download_delay=15):
+    version=DEFAULT_VERSION,
+    download_base=DEFAULT_URL,
+    to_dir=DEFAULT_SAVE_DIR,
+    download_delay=15,
+):
     """
     Ensure that a setuptools version is installed.
@@ -162,11 +169,12 @@ def use_setuptools(
     # prior to importing, capture the module state for
     # representative modules.
-    rep_modules = 'pkg_resources', 'setuptools'
+    rep_modules = "pkg_resources", "setuptools"
     imported = set(sys.modules).intersection(rep_modules)
 
     try:
         import pkg_resources
+
         pkg_resources.require("setuptools>=" + version)
         # a suitable version is already installed
         return
@@ -193,14 +201,16 @@ def _conflict_bail(VC_err, version):
     Setuptools was imported prior to invocation, so it is
     unsafe to unload it. Bail out.
     """
-    conflict_tmpl = textwrap.dedent("""
+    conflict_tmpl = textwrap.dedent(
+        """
         The required version of setuptools (>={version}) is not available,
         and can't be installed while this script is running. Please
         install a more recent version first, using
         'easy_install -U setuptools'.
 
         (Currently using {VC_err.args[0]!r})
-        """)
+        """
+    )
     msg = conflict_tmpl.format(**locals())
     sys.stderr.write(msg)
     sys.exit(2)
@@ -210,11 +220,10 @@ def _unload_pkg_resources():
     sys.meta_path = [
         importer
         for importer in sys.meta_path
-        if importer.__class__.__module__ != 'pkg_resources.extern'
+        if importer.__class__.__module__ != "pkg_resources.extern"
     ]
     del_modules = [
-        name for name in sys.modules
-        if name.startswith('pkg_resources')
+        name for name in sys.modules if name.startswith("pkg_resources")
     ]
     for mod_name in del_modules:
         del sys.modules[mod_name]
@@ -248,57 +257,59 @@ def download_file_powershell(url, target):
         '(new-object System.Net.WebClient).DownloadFile("%(url)s", "%(target)s")'
         % locals()
     )
-    cmd = [
-        'powershell',
-        '-Command',
-        ps_cmd,
-    ]
+    cmd = ["powershell", "-Command", ps_cmd]
     _clean_check(cmd, target)
 
 
 def has_powershell():
     """Determine if Powershell is available."""
-    if platform.system() != 'Windows':
+    if platform.system() != "Windows":
         return False
-    cmd = ['powershell', '-Command', 'echo test']
-    with open(os.path.devnull, 'wb') as devnull:
+    cmd = ["powershell", "-Command", "echo test"]
+    with open(os.path.devnull, "wb") as devnull:
         try:
             subprocess.check_call(cmd, stdout=devnull, stderr=devnull)
         except Exception:
             return False
     return True
+
+
 download_file_powershell.viable = has_powershell
 
 
 def download_file_curl(url, target):
-    cmd = ['curl', url, '--location', '--silent', '--output', target]
+    cmd = ["curl", url, "--location", "--silent", "--output", target]
     _clean_check(cmd, target)
 
 
 def has_curl():
-    cmd = ['curl', '--version']
-    with open(os.path.devnull, 'wb') as devnull:
+    cmd = ["curl", "--version"]
+    with open(os.path.devnull, "wb") as devnull:
         try:
             subprocess.check_call(cmd, stdout=devnull, stderr=devnull)
         except Exception:
             return False
     return True
+
+
 download_file_curl.viable = has_curl
 
 
 def download_file_wget(url, target):
-    cmd = ['wget', url, '--quiet', '--output-document', target]
+    cmd = ["wget", url, "--quiet", "--output-document", target]
    _clean_check(cmd, target)
 
 
 def has_wget():
-    cmd = ['wget', '--version']
-    with open(os.path.devnull, 'wb') as devnull:
+    cmd = ["wget", "--version"]
+    with open(os.path.devnull, "wb") as devnull:
        try:
            subprocess.check_call(cmd, stdout=devnull, stderr=devnull)
        except Exception:
            return False
    return True
+
+
 download_file_wget.viable = has_wget
 
@@ -314,6 +325,8 @@ def download_file_insecure(url, target):
     # Write all the data in one block to avoid creating a partial file.
with open(target, "wb") as dst: dst.write(data) + + download_file_insecure.viable = lambda: True @@ -329,9 +342,12 @@ def get_best_downloader(): def download_setuptools( - version=DEFAULT_VERSION, download_base=DEFAULT_URL, - to_dir=DEFAULT_SAVE_DIR, delay=15, - downloader_factory=get_best_downloader): + version=DEFAULT_VERSION, + download_base=DEFAULT_URL, + to_dir=DEFAULT_SAVE_DIR, + delay=15, + downloader_factory=get_best_downloader, +): """ Download setuptools from a specified location and return its filename. @@ -362,30 +378,41 @@ def _build_install_args(options): Returns list of command line arguments. """ - return ['--user'] if options.user_install else [] + return ["--user"] if options.user_install else [] def _parse_args(): """Parse the command line for options.""" parser = optparse.OptionParser() parser.add_option( - '--user', dest='user_install', action='store_true', default=False, - help='install in user site package') + "--user", + dest="user_install", + action="store_true", + default=False, + help="install in user site package", + ) parser.add_option( - '--download-base', dest='download_base', metavar="URL", + "--download-base", + dest="download_base", + metavar="URL", default=DEFAULT_URL, - help='alternative URL from where to download the setuptools package') + help="alternative URL from where to download the setuptools package", + ) parser.add_option( - '--insecure', dest='downloader_factory', action='store_const', - const=lambda: download_file_insecure, default=get_best_downloader, - help='Use internal, non-validating downloader' + "--insecure", + dest="downloader_factory", + action="store_const", + const=lambda: download_file_insecure, + default=get_best_downloader, + help="Use internal, non-validating downloader", ) parser.add_option( - '--version', help="Specify which version to download", + "--version", + help="Specify which version to download", default=DEFAULT_VERSION, ) parser.add_option( - '--to-dir', + "--to-dir", help="Directory to save (and re-use) package", default=DEFAULT_SAVE_DIR, ) @@ -410,5 +437,6 @@ def main(): archive = download_setuptools(**_download_args(options)) return _install(archive, _build_install_args(options)) -if __name__ == '__main__': + +if __name__ == "__main__": sys.exit(main()) diff --git a/naima/__init__.py b/naima/__init__.py index 0a2e5f7b..9066a4ba 100644 --- a/naima/__init__.py +++ b/naima/__init__.py @@ -13,6 +13,7 @@ # should keep this content at the top. 
 # ----------------------------------------------------------------------------
 from ._astropy_init import *
+
 # ----------------------------------------------------------------------------
 
 from .core import *
diff --git a/naima/_astropy_init.py b/naima/_astropy_init.py
index f6533863..9e682fe1 100644
--- a/naima/_astropy_init.py
+++ b/naima/_astropy_init.py
@@ -1,12 +1,13 @@
 # Licensed under a 3-clause BSD style license - see LICENSE.rst
-__all__ = ['__version__', '__githash__', 'test']
+__all__ = ["__version__", "__githash__", "test"]
 
 # this indicates whether or not we are in the package's setup.py
 try:
     _ASTROPY_SETUP_
 except NameError:
     from sys import version_info
+
     if version_info[0] >= 3:
         import builtins
     else:
@@ -16,21 +17,34 @@
 try:
     from .version import version as __version__
 except ImportError:
-    __version__ = ''
+    __version__ = ""
 try:
     from .version import githash as __githash__
 except ImportError:
-    __githash__ = ''
+    __githash__ = ""
 
 
 # set up the test command
 def _get_test_runner():
     import os
     from astropy.tests.helper import TestRunner
+
     return TestRunner(os.path.dirname(__file__))
 
 
-def test(package=None, test_path=None, args=None, plugins=None,
-         verbose=False, pastebin=None, remote_data=False, pep8=False,
-         pdb=False, coverage=False, open_files=False, **kwargs):
+def test(
+    package=None,
+    test_path=None,
+    args=None,
+    plugins=None,
+    verbose=False,
+    pastebin=None,
+    remote_data=False,
+    pep8=False,
+    pdb=False,
+    coverage=False,
+    open_files=False,
+    **kwargs
+):
     """
     Run the tests using `py.test <http://pytest.org/latest>`__. A proper set
     of arguments is constructed and passed to `pytest.main`_.
@@ -105,10 +119,20 @@ def test(package=None, test_path=None, args=None, plugins=None,
     """
     test_runner = _get_test_runner()
     return test_runner.run_tests(
-        package=package, test_path=test_path, args=args,
-        plugins=plugins, verbose=verbose, pastebin=pastebin,
-        remote_data=remote_data, pep8=pep8, pdb=pdb,
-        coverage=coverage, open_files=open_files, **kwargs)
+        package=package,
+        test_path=test_path,
+        args=args,
+        plugins=plugins,
+        verbose=verbose,
+        pastebin=pastebin,
+        remote_data=remote_data,
+        pep8=pep8,
+        pdb=pdb,
+        coverage=coverage,
+        open_files=open_files,
+        **kwargs
+    )
+
 
 if not _ASTROPY_SETUP_:
     import os
@@ -118,21 +142,30 @@ def test(package=None, test_path=None, args=None, plugins=None,
 
     # add these here so we only need to cleanup the namespace at the end
     config_dir = None
-    if not os.environ.get('ASTROPY_SKIP_CONFIG_UPDATE', False):
+    if not os.environ.get("ASTROPY_SKIP_CONFIG_UPDATE", False):
         config_dir = os.path.dirname(__file__)
         config_template = os.path.join(config_dir, __package__ + ".cfg")
         if os.path.isfile(config_template):
             try:
                 config.configuration.update_default_config(
-                    __package__, config_dir, version=__version__)
+                    __package__, config_dir, version=__version__
+                )
             except TypeError as orig_error:
                 try:
                     config.configuration.update_default_config(
-                        __package__, config_dir)
+                        __package__, config_dir
+                    )
                 except config.configuration.ConfigurationDefaultMissingError as e:
-                    wmsg = (e.args[0] + " Cannot install default profile. If you are "
-                            "importing from source, this is expected.")
-                    warn(config.configuration.ConfigurationDefaultMissingWarning(wmsg))
+                    wmsg = (
+                        e.args[0]
+                        + " Cannot install default profile. If you are "
+                        "importing from source, this is expected."
+                    )
+                    warn(
+                        config.configuration.ConfigurationDefaultMissingWarning(
+                            wmsg
+                        )
+                    )
                     del e
                 except:
                     raise orig_error
diff --git a/naima/analysis.py b/naima/analysis.py
index d96f4bcc..571294d0 100644
--- a/naima/analysis.py
+++ b/naima/analysis.py
@@ -1,6 +1,10 @@
 # Licensed under a 3-clause BSD style license - see LICENSE.rst
-from __future__ import (absolute_import, division, print_function,
-                        unicode_literals)
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
 
 import numpy as np
 import astropy.units as u
@@ -16,23 +20,29 @@
 
 try:
     import yaml
+
     HAS_PYYAML = True
 except ImportError:
     HAS_PYYAML = False
 
 __all__ = [
-    "save_diagnostic_plots", "save_results_table", "save_run", "read_run"
+    "save_diagnostic_plots",
+    "save_results_table",
+    "save_run",
+    "read_run",
 ]
 
 
-def save_diagnostic_plots(outname,
-                          sampler,
-                          modelidxs=None,
-                          pdf=False,
-                          sed=True,
-                          blob_labels=None,
-                          last_step=False,
-                          dpi=100):
+def save_diagnostic_plots(
+    outname,
+    sampler,
+    modelidxs=None,
+    pdf=False,
+    sed=True,
+    blob_labels=None,
+    last_step=False,
+    dpi=100,
+):
     """
     Generate diagnostic plots.
 
@@ -69,46 +79,52 @@ def save_diagnostic_plots(outname,
 
     from .plot import plot_chain, plot_blob, plot_corner
     from matplotlib import pyplot as plt
+
     # This function should never be interactive
-    old_interactive = plt.rcParams['interactive']
-    plt.rcParams['interactive'] = False
+    old_interactive = plt.rcParams["interactive"]
+    plt.rcParams["interactive"] = False
 
     if pdf:
-        plt.rc('pdf', fonttype=42)
-        log.info('Saving diagnostic plots in file '
-                 '{0}_plots.pdf'.format(outname))
+        plt.rc("pdf", fonttype=42)
+        log.info(
+            "Saving diagnostic plots in file " "{0}_plots.pdf".format(outname)
+        )
         from matplotlib.backends.backend_pdf import PdfPages
-        outpdf = PdfPages('{0}_plots.pdf'.format(outname))
+
+        outpdf = PdfPages("{0}_plots.pdf".format(outname))
 
     # Chains
     for par, label in six.moves.zip(
-            six.moves.range(sampler.chain.shape[-1]), sampler.labels):
+        six.moves.range(sampler.chain.shape[-1]), sampler.labels
+    ):
         try:
-            log.info('Plotting chain of parameter {0}...'.format(label))
+            log.info("Plotting chain of parameter {0}...".format(label))
             f = plot_chain(sampler, par, last_step=last_step)
             if pdf:
-                f.savefig(outpdf, format='pdf', dpi=dpi)
+                f.savefig(outpdf, format="pdf", dpi=dpi)
             else:
-                if 'log(' in label or 'log10(' in label:
-                    label = label.split('(')[-1].split(')')[0]
-                f.savefig('{0}_chain_{1}.png'.format(outname, label), dpi=dpi)
+                if "log(" in label or "log10(" in label:
+                    label = label.split("(")[-1].split(")")[0]
+                f.savefig("{0}_chain_{1}.png".format(outname, label), dpi=dpi)
             f.clf()
             plt.close(f)
         except Exception as e:
-            log.warning('plot_chain failed for parameter'
-                        ' {0} ({1}): {2}'.format(label, par, e))
+            log.warning(
+                "plot_chain failed for parameter"
+                " {0} ({1}): {2}".format(label, par, e)
+            )
 
     # Corner plot
-    log.info('Plotting corner plot...')
+    log.info("Plotting corner plot...")
     f = plot_corner(sampler)
     if f is not None:
         if pdf:
-            f.savefig(outpdf, format='pdf', dpi=dpi)
+            f.savefig(outpdf, format="pdf", dpi=dpi)
         else:
-            f.savefig('{0}_corner.png'.format(outname), dpi=dpi)
+            f.savefig("{0}_corner.png".format(outname), dpi=dpi)
         f.clf()
         plt.close(f)
 
@@ -122,50 +138,56 @@ def save_diagnostic_plots(outname,
         sed = [sed for idx in modelidxs]
 
     if blob_labels is None:
-        blob_labels = ['Model output {0}'.format(idx) for idx in modelidxs]
+        blob_labels = ["Model output {0}".format(idx) for idx in modelidxs]
     elif len(modelidxs) == 1 and isinstance(blob_labels, str):
         blob_labels = [blob_labels]
     elif len(blob_labels) < len(modelidxs):
         # Add labels
         n = len(blob_labels)
-        blob_labels += ['Model output {0}'.format(idx)
-                        for idx in modelidxs[n:]]
+        blob_labels += [
+            "Model output {0}".format(idx) for idx in modelidxs[n:]
+        ]
 
     for modelidx, plot_sed, label in six.moves.zip(
-            modelidxs, sed, blob_labels):
+        modelidxs, sed, blob_labels
+    ):
 
         try:
-            log.info('Plotting {0}...'.format(label))
+            log.info("Plotting {0}...".format(label))
             f = plot_blob(
                 sampler,
                 blobidx=modelidx,
                 label=label,
                 sed=plot_sed,
                 n_samples=100,
-                last_step=last_step)
+                last_step=last_step,
+            )
             if pdf:
-                f.savefig(outpdf, format='pdf', dpi=dpi)
+                f.savefig(outpdf, format="pdf", dpi=dpi)
             else:
-                f.savefig('{0}_model{1}.png'.format(outname, modelidx),
-                          dpi=dpi)
+                f.savefig(
+                    "{0}_model{1}.png".format(outname, modelidx), dpi=dpi
+                )
             f.clf()
             plt.close(f)
         except Exception as e:
-            log.warning('plot_blob failed for {0}: {1}'.format(label, e))
+            log.warning("plot_blob failed for {0}: {1}".format(label, e))
 
     if pdf:
         outpdf.close()
 
     # set interactive back to original
-    plt.rcParams['interactive'] = old_interactive
+    plt.rcParams["interactive"] = old_interactive
 
 
-def save_results_table(outname,
-                       sampler,
-                       format='ascii.ecsv',
-                       convert_log=True,
-                       last_step=False,
-                       include_blobs=True):
+def save_results_table(
+    outname,
+    sampler,
+    format="ascii.ecsv",
+    convert_log=True,
+    last_step=False,
+    include_blobs=True,
+):
     """
     Save an ASCII table with the results stored in the
     `~emcee.EnsembleSampler`.
 
@@ -211,21 +233,28 @@ def save_results_table(outname,
         Table with the results.
     """
 
-    if not HAS_PYYAML and format == 'ascii.ecsv':
-        format = 'ascii.ipac'
-        log.warning('PyYAML package is required for ECSV format,'
-                    ' falling back to {0}...'.format(format))
-    elif format not in ['ascii.ecsv', 'ascii.ipac']:
-        log.warning('The chosen table format does not have an astropy'
-                    ' writer that supports metadata writing, no run info'
-                    ' will be saved to the file!')
-
-    file_extension = 'dat'
-    if format == 'ascii.ecsv':
-        file_extension = 'ecsv'
-
-    log.info('Saving results table in {0}_results.{1}'.format(outname,
-                                                              file_extension))
+    if not HAS_PYYAML and format == "ascii.ecsv":
+        format = "ascii.ipac"
+        log.warning(
+            "PyYAML package is required for ECSV format,"
+            " falling back to {0}...".format(format)
+        )
+    elif format not in ["ascii.ecsv", "ascii.ipac"]:
+        log.warning(
+            "The chosen table format does not have an astropy"
+            " writer that supports metadata writing, no run info"
+            " will be saved to the file!"
+        )
+
+    file_extension = "dat"
+    if format == "ascii.ecsv":
+        file_extension = "ecsv"
+
+    log.info(
+        "Saving results table in {0}_results.{1}".format(
+            outname, file_extension
+        )
+    )
 
     labels = sampler.labels
 
@@ -237,34 +266,38 @@ def save_results_table(outname,
     quant = [16, 50, 84]  # Do we need more info on the distributions?
     t = Table(
-        names=['label', 'median', 'unc_lo', 'unc_hi'],
-        dtype=['S72', 'f8', 'f8', 'f8'])
-    t['label'].description = 'Name of the parameter'
-    t['median'].description = 'Median of the posterior distribution function'
-    t['unc_lo'].description = (
-        'Difference between the median and the'
-        ' {0}th percentile of the pdf, ~1sigma lower uncertainty'.format(quant[
-            0]))
-    t['unc_hi'].description = (
-        'Difference between the {0}th percentile'
-        ' and the median of the pdf, ~1sigma upper uncertainty'
-        .format(quant[2])
+        names=["label", "median", "unc_lo", "unc_hi"],
+        dtype=["S72", "f8", "f8", "f8"],
+    )
+    t["label"].description = "Name of the parameter"
+    t["median"].description = "Median of the posterior distribution function"
+    t["unc_lo"].description = (
+        "Difference between the median and the"
+        " {0}th percentile of the pdf, ~1sigma lower uncertainty".format(
+            quant[0]
+        )
+    )
+    t["unc_hi"].description = (
+        "Difference between the {0}th percentile"
+        " and the median of the pdf, ~1sigma upper uncertainty".format(
+            quant[2]
+        )
     )
 
     metadata = {}
     # Start with info from the distributions used for storing the results
-    metadata['n_samples'] = dists.shape[0]
+    metadata["n_samples"] = dists.shape[0]
 
     # save ML parameter vector and best/median loglikelihood
     ML, MLp, MLerr, _ = find_ML(sampler, None)
-    metadata['ML_pars'] = [float(p) for p in MLp]
-    metadata['MaxLogLikelihood'] = float(ML)
+    metadata["ML_pars"] = [float(p) for p in MLp]
+    metadata["MaxLogLikelihood"] = float(ML)
 
     # compute and save BIC
     BIC = len(MLp) * np.log(len(sampler.data)) - 2 * ML
-    metadata['BIC'] = BIC
+    metadata["BIC"] = BIC
 
     # And add all info stored in the sampler.run_info dict
-    if hasattr(sampler, 'run_info'):
+    if hasattr(sampler, "run_info"):
         metadata.update(sampler.run_info)
 
     for p, label in enumerate(labels):
@@ -276,16 +309,17 @@ def save_results_table(outname,
         t.add_row((label, med, lo, hi))
 
-        if convert_log and ('log10(' in label or 'log(' in label):
-            nlabel = label.split('(')[-1].split(')')[0]
-            ltype = label.split('(')[0]
-            if ltype == 'log10':
-                new_dist = 10**dist
-            elif ltype == 'log':
+        if convert_log and ("log10(" in label or "log(" in label):
+            nlabel = label.split("(")[-1].split(")")[0]
+            ltype = label.split("(")[0]
+            if ltype == "log10":
+                new_dist = 10 ** dist
+            elif ltype == "log":
                 new_dist = np.exp(dist)
 
             quantiles = dict(
-                six.moves.zip(quant, np.percentile(new_dist, quant)))
+                six.moves.zip(quant, np.percentile(new_dist, quant))
+            )
             med = quantiles[50]
             lo, hi = med - quantiles[16], quantiles[84] - med
 
@@ -315,24 +349,25 @@ def save_results_table(outname,
                     blobl.append(walkerblob[idx])
                 if unit:
                     dist = np.array([b.value for b in blobl])
-                    metadata['blob{0}_unit'.format(idx)] = unit.to_string()
+                    metadata["blob{0}_unit".format(idx)] = unit.to_string()
                 else:
                     dist = np.array(blobl)
 
                 quantiles = dict(
-                    six.moves.zip(quant, np.percentile(dist, quant)))
+                    six.moves.zip(quant, np.percentile(dist, quant))
+                )
                 med = quantiles[50]
                 lo, hi = med - quantiles[16], quantiles[84] - med
 
-                t.add_row(('blob{0}'.format(idx), med, lo, hi))
+                t.add_row(("blob{0}".format(idx), med, lo, hi))
 
-    if format == 'ascii.ipac':
+    if format == "ascii.ipac":
         # Only keywords are written to IPAC tables
-        t.meta['keywords'] = {}
+        t.meta["keywords"] = {}
         for di in metadata.items():
-            t.meta['keywords'][di[0]] = {'value': di[1]}
+            t.meta["keywords"][di[0]] = {"value": di[1]}
     else:
-        if format == 'ascii.ecsv':
+        if format == "ascii.ecsv":
             # there can be no numpy arrays in the metadata (YAML doesn't like
             # them)
             for di in list(metadata.items()):
@@ -346,7 +381,7 @@ def save_results_table(outname,
         # Save it directly in meta for readability in ECSV
         t.meta.update(metadata)
 
-    t.write('{0}_results.{1}'.format(outname, file_extension), format=format)
+    t.write("{0}_results.{1}".format(outname, file_extension), format=format)
 
     return t
 
@@ -376,20 +411,21 @@ def save_run(filename, sampler, compression=True, clobber=False):
         Whether to overwrite the output filename if it exists.
     """
 
-    if filename.split('.')[-1] not in ['h5', 'hdf5']:
-        filename += '_chain.h5'
+    if filename.split(".")[-1] not in ["h5", "hdf5"]:
+        filename += "_chain.h5"
 
     if os.path.exists(filename) and not clobber:
         log.warning(
-            'Not writing file because file exists and clobber is False')
+            "Not writing file because file exists and clobber is False"
+        )
         return
 
-    f = h5py.File(filename, 'w')
-    group = f.create_group('sampler')
-    group.create_dataset(
-        'chain', data=sampler.chain, compression=compression)
+    f = h5py.File(filename, "w")
+    group = f.create_group("sampler")
+    group.create_dataset("chain", data=sampler.chain, compression=compression)
     group.create_dataset(
-        'lnprobability', data=sampler.lnprobability, compression=compression)
+        "lnprobability", data=sampler.lnprobability, compression=compression
+    )
 
     # blobs
     blob = sampler.blobs[-1][0]
@@ -398,7 +434,7 @@ def save_run(filename, sampler, compression=True, clobber=False):
             # scalar or array quantity
             units = [item.unit.to_string()]
         elif isinstance(item, float):
-            units = ['']
+            units = [""]
         elif isinstance(item, tuple) or isinstance(item, list):
             arearrs = np.all([isinstance(x, np.ndarray) for x in item])
             if arearrs:
@@ -407,11 +443,11 @@ def save_run(filename, sampler, compression=True, clobber=False):
                 if isinstance(x, u.Quantity):
                     units.append(x.unit.to_string())
                 else:
-                    units.append('')
+                    units.append("")
         else:
             log.warning(
-                'blob number {0} has unknown format and cannot be saved '
-                'in HDF5 file'
+                "blob number {0} has unknown format and cannot be saved "
+                "in HDF5 file"
             )
             continue
 
@@ -424,21 +460,23 @@ def save_run(filename, sampler, compression=True, clobber=False):
             blob = u.Quantity(blob).value
 
         blobdataset = group.create_dataset(
-            'blob{0}'.format(idx), data=blob, compression=compression)
+            "blob{0}".format(idx), data=blob, compression=compression
+        )
         if len(units) > 1:
             for j, unit in enumerate(units):
-                blobdataset.attrs['unit{0}'.format(j)] = unit
+                blobdataset.attrs["unit{0}".format(j)] = unit
         else:
-            blobdataset.attrs['unit'] = units[0]
+            blobdataset.attrs["unit"] = units[0]
 
-    if hasattr(sampler, 'data'):
+    if hasattr(sampler, "data"):
         data = group.create_dataset(
-            'data',
+            "data",
             data=Table(sampler.data).as_array(),
-            compression=compression)
+            compression=compression,
+        )
         for col in sampler.data.colnames:
-            f['sampler/data'].attrs[col + 'unit'] = str(sampler.data[col].unit)
+            f["sampler/data"].attrs[col + "unit"] = str(sampler.data[col].unit)
 
         for key in sampler.data.meta:
             val = sampler.data.meta[key]
@@ -451,11 +489,12 @@ def save_run(filename, sampler, compression=True, clobber=False):
                 warnings.warn(
                     "Attribute `{0}` of type {1} of the data table"
                     " of the sampler cannot be written to HDF5 files"
-                    "- skipping".format(key, type(val)), AstropyUserWarning
+                    "- skipping".format(key, type(val)),
+                    AstropyUserWarning,
                 )
 
     # add all run info to group attributes
-    if hasattr(sampler, 'run_info'):
+    if hasattr(sampler, "run_info"):
         for key in sampler.run_info.keys():
             val = sampler.run_info[key]
             try:
@@ -464,11 +503,11 @@
                 group.attrs[key] = str(val)
 
     # add other sampler info to the attrs
-    group.attrs['acceptance_fraction'] = np.mean(sampler.acceptance_fraction)
+    group.attrs["acceptance_fraction"] = np.mean(sampler.acceptance_fraction)
 
     # add labels as individual attrs (there might be a better way)
     for i, label in enumerate(sampler.labels):
-        group.attrs['label{0}'.format(i)] = label
+        group.attrs["label{0}".format(i)] = label
 
     f.close()
 
@@ -521,10 +560,10 @@ def read_run(filename, modelfn=None):
     result.modelfn = modelfn
     result.run_info = {}
 
-    f = h5py.File(filename, 'r')
+    f = h5py.File(filename, "r")
     # chain and lnprobability
-    result.chain = np.array(f['sampler/chain'])
-    result.lnprobability = np.array(f['sampler/lnprobability'])
+    result.chain = np.array(f["sampler/chain"])
+    result.lnprobability = np.array(f["sampler/lnprobability"])
 
     # blobs
     result.blobs = []
@@ -534,18 +573,20 @@ def read_run(filename, modelfn=None):
     for i in range(100):
         # first read each of the blobs and convert to Quantities
         try:
-            ds = f['sampler/blob{0}'.format(i)]
+            ds = f["sampler/blob{0}".format(i)]
             rank = np.ndim(ds[0])
             blobrank.append(rank)
             if rank <= 1:
-                blobs.append(u.Quantity(ds.value, unit=ds.attrs['unit']))
+                blobs.append(u.Quantity(ds.value, unit=ds.attrs["unit"]))
             else:
                 blob = []
                 for j in range(np.ndim(ds[0])):
                     blob.append(
                         u.Quantity(
                             ds.value[:, j, :],
-                            unit=ds.attrs['unit{0}'.format(j)]))
+                            unit=ds.attrs["unit{0}".format(j)],
+                        )
+                    )
                 blobs.append(blob)
         except KeyError:
             break
@@ -568,19 +609,19 @@ def read_run(filename, modelfn=None):
         result.blobs.append(steplist)
 
     # run info
-    result.run_info = dict(f['sampler'].attrs)
-    result.acceptance_fraction = f['sampler'].attrs['acceptance_fraction']
+    result.run_info = dict(f["sampler"].attrs)
+    result.acceptance_fraction = f["sampler"].attrs["acceptance_fraction"]
 
     # labels
     result.labels = []
     for i in range(result.chain.shape[2]):
-        result.labels.append(f['sampler'].attrs['label{0}'.format(i)])
+        result.labels.append(f["sampler"].attrs["label{0}".format(i)])
 
     # data
-    data = Table(np.array(f['sampler/data']))
-    data.meta.update(f['sampler/data'].attrs)
+    data = Table(np.array(f["sampler/data"]))
+    data.meta.update(f["sampler/data"].attrs)
     for col in data.colnames:
-        if f['sampler/data'].attrs[col + 'unit'] != 'None':
-            data[col].unit = f['sampler/data'].attrs[col + 'unit']
+        if f["sampler/data"].attrs[col + "unit"] != "None":
+            data[col].unit = f["sampler/data"].attrs[col + "unit"]
 
     result.data = QTable(data)
     return result
diff --git a/naima/core.py b/naima/core.py
index 68d441b8..14fb03ca 100644
--- a/naima/core.py
+++ b/naima/core.py
@@ -1,7 +1,11 @@
 # -*- coding: utf-8 -*-
 # Licensed under a 3-clause BSD style license - see LICENSE.rst
-from __future__ import (absolute_import, division, print_function,
-                        unicode_literals)
+from __future__ import (
+    absolute_import,
+    division,
+    print_function,
+    unicode_literals,
+)
 import numpy as np
 from astropy import log
 import astropy.units as u
@@ -10,17 +14,20 @@
 from .utils import validate_data_table, sed_conversion
 
 __all__ = [
-    "normal_prior", "uniform_prior", "log_uniform_prior", "get_sampler",
-    "run_sampler"
+    "normal_prior",
+    "uniform_prior",
+    "log_uniform_prior",
+    "get_sampler",
+    "run_sampler",
 ]
 
 # Define physical types used in plot and utils.validate_data_table
-u.def_physical_type(u.erg / u.cm**2 / u.s, 'flux')
-u.def_physical_type(u.Unit('1/(s cm2 erg)'), 'differential flux')
-u.def_physical_type(u.Unit('1/(s erg)'), 'differential power')
-u.def_physical_type(u.Unit('1/TeV'), 'differential energy')
-u.def_physical_type(u.Unit('1/cm3'), 'number density')
-u.def_physical_type(u.Unit('1/(eV cm3)'), 'differential number density')
+u.def_physical_type(u.erg / u.cm ** 2 / u.s, "flux")
+u.def_physical_type(u.Unit("1/(s cm2 erg)"), "differential flux")
+u.def_physical_type(u.Unit("1/(s erg)"), "differential power")
+u.def_physical_type(u.Unit("1/TeV"), "differential energy")
+u.def_physical_type(u.Unit("1/cm3"), "number density")
+u.def_physical_type(u.Unit("1/(eV cm3)"), "differential number density")
 
 
 # Prior functions
@@ -37,7 +44,7 @@ def uniform_prior(value, umin, umax):
 
 def normal_prior(value, mean, sigma):
     """Normal prior distribution.
     """
-    return -0.5 * (2 * np.pi * sigma) - (value - mean)**2 / (2. * sigma)
+    return -0.5 * (2 * np.pi * sigma) - (value - mean) ** 2 / (2.0 * sigma)
 
 
 def log_uniform_prior(value, umin=0, umax=None):
@@ -61,31 +68,38 @@ def lnprobmodel(model, data):
 
     # Check if conversion is required
-    model_is_sed = model.unit.physical_type in ['power', 'flux']
-    data_is_sed = data['flux'].unit.physical_type in ['power', 'flux']
+    model_is_sed = model.unit.physical_type in ["power", "flux"]
+    data_is_sed = data["flux"].unit.physical_type in ["power", "flux"]
 
     if model_is_sed != data_is_sed:
-        unit, sed_factor = sed_conversion(data['energy'], model.unit,
-                                          data_is_sed)
-        model = (model * sed_factor).to(data['flux'].unit)
+        unit, sed_factor = sed_conversion(
+            data["energy"], model.unit, data_is_sed
+        )
+        model = (model * sed_factor).to(data["flux"].unit)
 
-    ul = data['ul']
+    ul = data["ul"]
     notul = ~ul
 
-    difference = model[notul] - data['flux'][notul]
+    difference = model[notul] - data["flux"][notul]
 
     # use different errors for model above or below data
     sign = difference > 0
     loerr, hierr = 1 * ~sign, 1 * sign
-    logprob = -difference**2 / (2. * (loerr * data['flux_error_lo'][notul] +
-                                      hierr * data['flux_error_hi'][notul])**2)
+    logprob = -difference ** 2 / (
+        2.0
+        * (
+            loerr * data["flux_error_lo"][notul]
+            + hierr * data["flux_error_hi"][notul]
+        )
+        ** 2
+    )
 
     totallogprob = np.sum(logprob)
 
     if np.sum(ul) > 0:
         # deal with upper limits at CL set by data['cl']
-        violated_uls = np.sum(model[ul] > data['flux'][ul])
-        totallogprob += violated_uls * np.log(1. - data['cl'][violated_uls])
+        violated_uls = np.sum(model[ul] > data["flux"][ul])
+        totallogprob += violated_uls * np.log(1.0 - data["cl"][violated_uls])
 
     return totallogprob
 
@@ -97,15 +111,16 @@ def lnprob(pars, data, modelfunc, priorfunc):
     else:
         lnprob_priors = priorfunc(pars)
 
-# If prior is -np.inf, avoid calling the function as invalid calls may be made,
-# and the result will be discarded anyway
+    # If prior is -np.inf, avoid calling the function as invalid calls may be made,
+    # and the result will be discarded anyway
     if not np.isinf(lnprob_priors):
         modelout = modelfunc(pars, data)
 
         # Save blobs or save model if no blobs given
         # If model is not in blobs, save model+blobs
-        if ((type(modelout) == tuple or type(modelout) == list) and
-                (type(modelout) != np.ndarray)):
+        if (type(modelout) == tuple or type(modelout) == list) and (
+            type(modelout) != np.ndarray
+        ):
             model = modelout[0]
             MODEL_IN_BLOB = False
@@ -136,24 +151,40 @@ def lnprob(pars, data, modelfunc, priorfunc):
 
 def _run_mcmc(sampler, pos, nrun):
     for i, out in enumerate(sampler.sample(pos, iterations=nrun)):
-        progress = (100. * float(i) / float(nrun))
-        if progress % 5 < (5. / float(nrun)):
-            print("\nProgress of the run: {0:.0f} percent"
-                  " ({1} of {2} steps)".format(int(progress), i, nrun))
+        progress = 100.0 * float(i) / float(nrun)
+        if progress % 5 < (5.0 / float(nrun)):
+            print(
+                "\nProgress of the run: {0:.0f} percent"
+                " ({1} of {2} steps)".format(int(progress), i, nrun)
+            )
             npars = out[0].shape[-1]
             paravg, parstd = [], []
             for npar in range(npars):
                 paravg.append(np.median(out[0][:, npar]))
                 parstd.append(np.std(out[0][:, npar]))
-            print(" " + (" ".join([
-                "{%i:-^15}" % i for i in range(npars)
-            ])).format(*sampler.labels))
-            print(" Last ensemble median : " + (" ".join(
-                ["{%i:^15.3g}" % i for i in range(npars)])).format(*paravg))
-            print(" Last ensemble std : " + (" ".join(
-                ["{%i:^15.3g}" % i for i in range(npars)])).format(*parstd))
-            print(" Last ensemble lnprob : avg: {0:.3f}, max: {1:.3f}"
-                  .format(np.average(out[1]), np.max(out[1])))
+            print(
+                " "
+                + (" ".join(["{%i:-^15}" % i for i in range(npars)])).format(
+                    *sampler.labels
+                )
+            )
+            print(
+                " Last ensemble median : "
+                + (" ".join(["{%i:^15.3g}" % i for i in range(npars)])).format(
+                    *paravg
+                )
+            )
+            print(
+                " Last ensemble std : "
+                + (" ".join(["{%i:^15.3g}" % i for i in range(npars)])).format(
+                    *parstd
+                )
+            )
+            print(
+                " Last ensemble lnprob : avg: {0:.3f}, max: {1:.3f}".format(
+                    np.average(out[1]), np.max(out[1])
+                )
+            )
 
     return sampler, out[0]
 
@@ -171,60 +202,67 @@ def nll(*args):
         return -lnprob(*args)[0]
 
     log.info(
-        'Finding Maximum Likelihood parameters through Nelder-Mead fitting...')
-    log.info(' Initial parameters: {0}'.format(p0))
-    log.info(' Initial lnprob(p0): {0:.3f}'.format(-nll(p0, data, model,
-                                                        prior)))
+        "Finding Maximum Likelihood parameters through Nelder-Mead fitting..."
+    )
+    log.info(" Initial parameters: {0}".format(p0))
+    log.info(
+        " Initial lnprob(p0): {0:.3f}".format(-nll(p0, data, model, prior))
+    )
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         result = minimize(
             nll,
             p0,
             args=(data, model, flat_prior),
-            method='Nelder-Mead',
-            options={'maxfev': 500,
-                     'xtol': 1e-1,
-                     'ftol': 1e-3})
-        ll_prior = lnprob(result['x'], data, model, prior)[0]
+            method="Nelder-Mead",
+            options={"maxfev": 500, "xtol": 1e-1, "ftol": 1e-3},
+        )
+        ll_prior = lnprob(result["x"], data, model, prior)[0]
 
-    if (result['success'] or result['status'] == 1) and not np.isinf(ll_prior):
+    if (result["success"] or result["status"] == 1) and not np.isinf(ll_prior):
         # also keep result if we have reached maxiter, it is likely
         # better than p0
-        if result['status'] == 1:
-            log.info(' Maximum number of function evaluations reached!')
-        if result['status'] == 1:
-            log.info(' New parameters : {0}'.format(result['x']))
+        if result["status"] == 1:
+            log.info(" Maximum number of function evaluations reached!")
+        if result["status"] == 1:
+            log.info(" New parameters : {0}".format(result["x"]))
         else:
-            log.info(' New ML parameters : {0}'.format(result['x']))
+            log.info(" New ML parameters : {0}".format(result["x"]))
         P0_IS_ML = True
-        if -result['fun'] == ll_prior:
-            log.info(' Maximum lnprob(p0): {0:.3f}'.format(-result['fun']))
+        if -result["fun"] == ll_prior:
+            log.info(" Maximum lnprob(p0): {0:.3f}".format(-result["fun"]))
        else:
-            log.info('flat prior lnprob(p0): {0:.3f}'.format(-result['fun']))
-            log.info('full prior lnprob(p0): {0:.3f}'.format(ll_prior))
-        p0 = result['x']
+            log.info("flat prior lnprob(p0): {0:.3f}".format(-result["fun"]))
+            log.info("full prior lnprob(p0): {0:.3f}".format(ll_prior))
+        p0 = result["x"]
     elif np.isinf(ll_prior):
-        log.warning('Maximum Likelihood procedure converged on a parameter'
-                    ' vector forbidden by prior,'
-                    ' using original parameters for MCMC')
+        log.warning(
+            "Maximum Likelihood procedure converged on a parameter"
+            " vector forbidden by prior,"
+            " using original parameters for MCMC"
+        )
     else:
-        log.warning('Maximum Likelihood procedure failed to converge,'
-                    ' using original parameters for MCMC')
+        log.warning(
+            "Maximum Likelihood procedure failed to converge,"
+            " using original parameters for MCMC"
+        )
 
     return p0, P0_IS_ML
 
 
-def get_sampler(data_table=None,
-                p0=None,
-                model=None,
-                prior=None,
-                nwalkers=500,
-                nburn=100,
-                guess=True,
-                interactive=False,
-                prefit=False,
-                labels=None,
-                threads=4,
-                data_sed=None):
+def get_sampler(
+    data_table=None,
+    p0=None,
+    model=None,
+    prior=None,
+    nwalkers=500,
+    nburn=100,
+    guess=True,
+    interactive=False,
+    prefit=False,
+    labels=None,
+    threads=4,
+    data_sed=None,
+):
     """Generate a new MCMC sampler.
 
     Parameters
@@ -329,24 +367,25 @@ def get_sampler(data_table=None,
     import emcee
 
     if data_table is None:
-        raise TypeError('Data table is missing!')
+        raise TypeError("Data table is missing!")
     else:
         data = validate_data_table(data_table, sed=data_sed)
 
     if model is None:
-        raise TypeError('Model function is missing!')
+        raise TypeError("Model function is missing!")
 
     # Add parameter labels if not provided or too short
     if labels is None:
         # First is normalization
-        labels = ['norm'] + ['par{0}'.format(i) for i in range(1, len(p0))]
+        labels = ["norm"] + ["par{0}".format(i) for i in range(1, len(p0))]
     elif len(labels) < len(p0):
-        labels += ['par{0}'.format(i) for i in range(len(labels), len(p0))]
+        labels += ["par{0}".format(i) for i in range(len(labels), len(p0))]
 
     # Check that the model returns fluxes in same physical type as data
     modelout = model(p0, data)
-    if ((type(modelout) == tuple or type(modelout) == list) and
-            (type(modelout) != np.ndarray)):
+    if (type(modelout) == tuple or type(modelout) == list) and (
+        type(modelout) != np.ndarray
+    ):
         spec = modelout[0]
     else:
         spec = modelout
@@ -356,20 +395,24 @@ def get_sampler(data_table=None,
     try:
         # If both can be converted to differential flux, they can be compared
         # Otherwise, sed_conversion will raise a u.UnitsError
-        sed_conversion(data['energy'], spec.unit, False)
-        sed_conversion(data['energy'], data['flux'].unit, False)
+        sed_conversion(data["energy"], spec.unit, False)
+        sed_conversion(data["energy"], data["flux"].unit, False)
     except u.UnitsError:
         raise u.UnitsError(
-            'The physical type of the model and data units are not compatible,'
-            ' please modify your model or data so they match:\n'
-            ' Model units: {0} [{1}]\n Data units: {2} [{3}]\n'.format(
-                spec.unit, spec.unit.physical_type, data['flux'].unit, data[
-                    'flux'].unit.physical_type))
+            "The physical type of the model and data units are not compatible,"
+            " please modify your model or data so they match:\n"
+            " Model units: {0} [{1}]\n Data units: {2} [{3}]\n".format(
+                spec.unit,
+                spec.unit.physical_type,
+                data["flux"].unit,
+                data["flux"].unit.physical_type,
+            )
+        )
 
     if guess:
-        normNames = ['norm', 'Norm', 'ampl', 'Ampl', 'We', 'Wp']
-        normNameslog = ['log({0}'.format(name) for name in normNames]
-        normNameslog10 = ['log10({0}'.format(name) for name in normNames]
+        normNames = ["norm", "Norm", "ampl", "Ampl", "We", "Wp"]
+        normNameslog = ["log({0}".format(name) for name in normNames]
+        normNameslog10 = ["log10({0}".format(name) for name in normNames]
         normNames += normNameslog + normNameslog10
         idxs = []
         for l in normNames:
@@ -381,47 +424,58 @@
def get_sampler(data_table=None, if len(idxs) == 1: - nunit, sedf = sed_conversion(data['energy'], spec.unit, False) - currFlux = np.trapz(data['energy'] * (spec * sedf).to(nunit), - data['energy']) - nunit, sedf = sed_conversion(data['energy'], data['flux'].unit, - False) - dataFlux = np.trapz(data['energy'] * - (data['flux'] * sedf).to(nunit), - data['energy']) - ratio = (dataFlux / currFlux) - if labels[idxs[0]].startswith('log('): + nunit, sedf = sed_conversion(data["energy"], spec.unit, False) + currFlux = np.trapz( + data["energy"] * (spec * sedf).to(nunit), data["energy"] + ) + nunit, sedf = sed_conversion( + data["energy"], data["flux"].unit, False + ) + dataFlux = np.trapz( + data["energy"] * (data["flux"] * sedf).to(nunit), + data["energy"], + ) + ratio = dataFlux / currFlux + if labels[idxs[0]].startswith("log("): p0[idxs[0]] += np.log(ratio) - elif labels[idxs[0]].startswith('log10('): + elif labels[idxs[0]].startswith("log10("): p0[idxs[0]] += np.log10(ratio) else: p0[idxs[0]] *= ratio elif len(idxs) == 0: - log.warning('No label starting with [{0}] found: not applying' - ' normalization guess.'.format(','.join(normNames))) + log.warning( + "No label starting with [{0}] found: not applying" + " normalization guess.".format(",".join(normNames)) + ) elif len(idxs) > 1: - log.warning('More than one label starting with [{0}] found:' - ' not applying normalization guess.'.format(','.join( - normNames))) + log.warning( + "More than one label starting with [{0}] found:" + " not applying normalization guess.".format( + ",".join(normNames) + ) + ) P0_IS_ML = False if interactive: try: - log.info( - 'Launching interactive model fitter, close when finished') + log.info("Launching interactive model fitter, close when finished") from .model_fitter import InteractiveModelFitter import matplotlib.pyplot as plt - iprev = plt.rcParams['interactive'] - plt.rcParams['interactive'] = False + + iprev = plt.rcParams["interactive"] + plt.rcParams["interactive"] = False imf = InteractiveModelFitter( - model, p0, data, labels=labels, sed=True) + model, p0, data, labels=labels, sed=True + ) p0 = imf.pars P0_IS_ML = imf.P0_IS_ML - plt.rcParams['interactive'] = iprev + plt.rcParams["interactive"] = iprev except ImportError as e: - log.warning('Interactive fitting is not available because' - ' matplotlib is not installed: {0}'.format(e)) + log.warning( + "Interactive fitting is not available because" + " matplotlib is not installed: {0}".format(e) + ) # If we already did the prefit call in ModelWidget (and didn't modify the # parameters afterwards), avoid doing it here @@ -429,7 +483,8 @@ def get_sampler(data_table=None, p0, P0_IS_ML = _prefit(p0, data, model, prior) sampler = emcee.EnsembleSampler( - nwalkers, len(p0), lnprob, args=[data, model, prior], threads=threads) + nwalkers, len(p0), lnprob, args=[data, model, prior], threads=threads + ) # Add data and parameters properties to sampler sampler.data_table = data_table @@ -439,11 +494,11 @@ def get_sampler(data_table=None, sampler.modelfn = model # Add run_info dict sampler.run_info = { - 'n_walkers': nwalkers, - 'n_burn': nburn, + "n_walkers": nwalkers, + "n_burn": nburn, # convert from np.float to regular float - 'p0': [float(p) for p in p0], - 'guess': guess, + "p0": [float(p) for p in p0], + "guess": guess, } # Initialize walkers in a ball of relative size 0.5% in all dimensions if @@ -453,15 +508,17 @@ def get_sampler(data_table=None, p0 = emcee.utils.sample_ball(p0, p0var, nwalkers) if nburn > 0: - print('Burning in the {0} walkers with {1} 
steps...'.format(nwalkers, - nburn)) + print( + "Burning in the {0} walkers with {1} steps...".format( + nwalkers, nburn + ) + ) sampler, pos = _run_mcmc(sampler, p0, nburn) else: pos = p0 - sampler.run_info['p0_burn_median'] = [ - float(p) for p in np.median( - pos, axis=0) + sampler.run_info["p0_burn_median"] = [ + float(p) for p in np.median(pos, axis=0) ] return sampler, pos @@ -497,9 +554,9 @@ def run_sampler(nrun=100, sampler=None, pos=None, **kwargs): if sampler is None or pos is None: sampler, pos = get_sampler(**kwargs) - sampler.run_info['n_run'] = nrun + sampler.run_info["n_run"] = nrun - print('\nWalker burn in finished, running {0} steps...'.format(nrun)) + print("\nWalker burn in finished, running {0} steps...".format(nrun)) sampler.reset() sampler, pos = _run_mcmc(sampler, pos, nrun) diff --git a/naima/extern/interruptible_pool.py b/naima/extern/interruptible_pool.py index 22250ede..58b521db 100644 --- a/naima/extern/interruptible_pool.py +++ b/naima/extern/interruptible_pool.py @@ -22,8 +22,12 @@ """ -from __future__ import (division, print_function, absolute_import, - unicode_literals) +from __future__ import ( + division, + print_function, + absolute_import, + unicode_literals, +) __all__ = ["InterruptiblePool"] @@ -65,13 +69,16 @@ class InterruptiblePool(Pool): Extra arguments. Python 2.7 supports a ``maxtasksperchild`` parameter. """ + wait_timeout = 3600 - def __init__(self, processes=None, initializer=None, initargs=(), - **kwargs): + def __init__( + self, processes=None, initializer=None, initargs=(), **kwargs + ): new_initializer = functools.partial(_initializer_wrapper, initializer) - super(InterruptiblePool, self).__init__(processes, new_initializer, - initargs, **kwargs) + super(InterruptiblePool, self).__init__( + processes, new_initializer, initargs, **kwargs + ) def map(self, func, iterable, chunksize=None): """ diff --git a/naima/extern/minimize.py b/naima/extern/minimize.py index 25768948..f30c1b3f 100644 --- a/naima/extern/minimize.py +++ b/naima/extern/minimize.py @@ -7,20 +7,37 @@ # minimize is a thin wrapper that behaves like scipy.optimize.minimize -from __future__ import (absolute_import, division, print_function, - unicode_literals) +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) import numpy -from numpy import (atleast_1d, eye, mgrid, argmin, zeros, shape, squeeze, - vectorize, asarray, sqrt, Inf, asfarray, isinf) +from numpy import ( + atleast_1d, + eye, + mgrid, + argmin, + zeros, + shape, + squeeze, + vectorize, + asarray, + sqrt, + Inf, + asfarray, + isinf, +) # standard status messages of optimizers -_status_message = {'success': 'Optimization terminated successfully.', - 'maxfev': 'Maximum number of function evaluations has ' - 'been exceeded.', - 'maxiter': 'Maximum number of iterations has been ' - 'exceeded.', - 'pr_loss': 'Desired error not necessarily achieved due ' - 'to precision loss.'} +_status_message = { + "success": "Optimization terminated successfully.", + "maxfev": "Maximum number of function evaluations has " "been exceeded.", + "maxiter": "Maximum number of iterations has been " "exceeded.", + "pr_loss": "Desired error not necessarily achieved due " + "to precision loss.", +} def wrap_function(function, args): @@ -34,20 +51,34 @@ def function_wrapper(*wrapper_args): return ncalls, function_wrapper + class OptimizeResult(dict): """ Represents the optimization result. 
""" + pass + class OptimizeWarning(UserWarning): pass -def minimize(func,x0,args=(),options={},method=None): +def minimize(func, x0, args=(), options={}, method=None): return _minimize_neldermead(func, x0, args=args, **options) -def _minimize_neldermead(func, x0, args=(), callback=None, xtol=1e-4, ftol=1e-4, - maxiter=None, maxfev=None, disp=False, return_all=False): # pragma: no cover + +def _minimize_neldermead( + func, + x0, + args=(), + callback=None, + xtol=1e-4, + ftol=1e-4, + maxiter=None, + maxfev=None, + disp=False, + return_all=False, +): # pragma: no cover """ Minimization of scalar function of one or more variables using the Nelder-Mead algorithm. @@ -98,7 +129,7 @@ def _minimize_neldermead(func, x0, args=(), callback=None, xtol=1e-4, ftol=1e-4, for k in range(0, N): y = numpy.array(x0, copy=True) if y[k] != 0: - y[k] = (1 + nonzdelt)*y[k] + y[k] = (1 + nonzdelt) * y[k] else: y[k] = zdelt @@ -113,9 +144,12 @@ def _minimize_neldermead(func, x0, args=(), callback=None, xtol=1e-4, ftol=1e-4, iterations = 1 - while (fcalls[0] < maxfun and iterations < maxiter): - if (numpy.max(numpy.ravel(numpy.abs((sim[1:] - sim[0]) / sim[0]))) <= xtol and - numpy.max(numpy.abs((fsim[0] - fsim[1:]) / fsim[0])) <= ftol): + while fcalls[0] < maxfun and iterations < maxiter: + if ( + numpy.max(numpy.ravel(numpy.abs((sim[1:] - sim[0]) / sim[0]))) + <= xtol + and numpy.max(numpy.abs((fsim[0] - fsim[1:]) / fsim[0])) <= ftol + ): break xbar = numpy.add.reduce(sim[:-1], 0) / N @@ -179,25 +213,31 @@ def _minimize_neldermead(func, x0, args=(), callback=None, xtol=1e-4, ftol=1e-4, if fcalls[0] >= maxfun: warnflag = 1 - msg = _status_message['maxfev'] + msg = _status_message["maxfev"] if disp: - print('Warning: ' + msg) + print("Warning: " + msg) elif iterations >= maxiter: warnflag = 2 - msg = _status_message['maxiter'] + msg = _status_message["maxiter"] if disp: - print('Warning: ' + msg) + print("Warning: " + msg) else: - msg = _status_message['success'] + msg = _status_message["success"] if disp: print(msg) print(" Current function value: %f" % fval) print(" Iterations: %d" % iterations) print(" Function evaluations: %d" % fcalls[0]) - result = OptimizeResult(fun=fval, nit=iterations, nfev=fcalls[0], - status=warnflag, success=(warnflag == 0), - message=msg, x=x) + result = OptimizeResult( + fun=fval, + nit=iterations, + nfev=fcalls[0], + status=warnflag, + success=(warnflag == 0), + message=msg, + x=x, + ) if retall: - result['allvecs'] = allvecs + result["allvecs"] = allvecs return result diff --git a/naima/extern/validator.py b/naima/extern/validator.py index 9ed8f44a..fb307561 100644 --- a/naima/extern/validator.py +++ b/naima/extern/validator.py @@ -5,16 +5,27 @@ from astropy import units as u from astropy.extern import six + def validate_physical_type(name, value, physical_type): if physical_type is not None: if not isinstance(value, u.Quantity): - raise TypeError("{0} should be given as a Quantity object".format(name)) + raise TypeError( + "{0} should be given as a Quantity object".format(name) + ) if isinstance(physical_type, six.string_types): if value.unit.physical_type != physical_type: - raise TypeError("{0} should be given in units of {1}".format(name, physical_type)) + raise TypeError( + "{0} should be given in units of {1}".format( + name, physical_type + ) + ) else: if not value.unit.physical_type in physical_type: - raise TypeError("{0} should be given in units of {1}".format(name, ', '.join(physical_type))) + raise TypeError( + "{0} should be given in units of {1}".format( + name, ", 
".join(physical_type) + ) + ) def validate_scalar(name, value, domain=None, physical_type=None): @@ -23,28 +34,36 @@ def validate_scalar(name, value, domain=None, physical_type=None): if not physical_type: if not np.isscalar(value) or not np.isreal(value): - raise TypeError("{0} should be a scalar floating point value".format(name)) + raise TypeError( + "{0} should be a scalar floating point value".format(name) + ) - if domain == 'positive': - if value < 0.: + if domain == "positive": + if value < 0.0: raise ValueError("{0} should be positive".format(name)) - elif domain == 'strictly-positive': - if value <= 0.: + elif domain == "strictly-positive": + if value <= 0.0: raise ValueError("{0} should be strictly positive".format(name)) - elif domain == 'negative': - if value > 0.: + elif domain == "negative": + if value > 0.0: raise ValueError("{0} should be negative".format(name)) - elif domain == 'strictly-negative': - if value >= 0.: + elif domain == "strictly-negative": + if value >= 0.0: raise ValueError("{0} should be strictly negative".format(name)) elif type(domain) in [tuple, list] and len(domain) == 2: if value < domain[0] or value > domain[-1]: - raise ValueError("{0} should be in the range [{1}:{2}]".format(name, domain[0], domain[-1])) + raise ValueError( + "{0} should be in the range [{1}:{2}]".format( + name, domain[0], domain[-1] + ) + ) return value -def validate_array(name, value, domain=None, ndim=1, shape=None, physical_type=None): +def validate_array( + name, value, domain=None, ndim=1, shape=None, physical_type=None +): validate_physical_type(name, value, physical_type) @@ -62,8 +81,16 @@ def validate_array(name, value, domain=None, ndim=1, shape=None, physical_type=N # Check that the shape matches that expected if shape is not None and value.shape != shape: if ndim == 1: - raise ValueError("{0} has incorrect length (expected {1} but found {2})".format(name, shape[0], value.shape[0])) + raise ValueError( + "{0} has incorrect length (expected {1} but found {2})".format( + name, shape[0], value.shape[0] + ) + ) else: - raise ValueError("{0} has incorrect shape (expected {1} but found {2})".format(name, shape, value.shape)) + raise ValueError( + "{0} has incorrect shape (expected {1} but found {2})".format( + name, shape, value.shape + ) + ) return value diff --git a/naima/model_fitter.py b/naima/model_fitter.py index f564e7c9..86bd37b3 100644 --- a/naima/model_fitter.py +++ b/naima/model_fitter.py @@ -1,6 +1,10 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst -from __future__ import (absolute_import, division, print_function, - unicode_literals) +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) import numpy as np import astropy.units as u @@ -11,12 +15,13 @@ from .utils import sed_conversion, validate_data_table from .extern.validator import validate_array -__all__ = ['InteractiveModelFitter'] +__all__ = ["InteractiveModelFitter"] def _process_model(model): - if ((isinstance(model, tuple) or isinstance(model, list)) and - not isinstance(model, np.ndarray)): + if ( + isinstance(model, tuple) or isinstance(model, list) + ) and not isinstance(model, np.ndarray): return model[0] else: return model @@ -57,15 +62,17 @@ class InteractiveModelFitter(object): Default is True and can also be changed through the GUI. 
""" - def __init__(self, - modelfn, - p0, - data=None, - e_range=None, - e_npoints=100, - labels=None, - sed=True, - auto_update=True): + def __init__( + self, + modelfn, + p0, + data=None, + e_range=None, + e_npoints=100, + labels=None, + sed=True, + auto_update=True, + ): import matplotlib.pyplot as plt from matplotlib.widgets import Button, Slider, CheckButtons @@ -74,9 +81,9 @@ def __init__(self, self.P0_IS_ML = False npars = len(p0) if labels is None: - labels = ['par{0}'.format(i) for i in range(npars)] + labels = ["par{0}".format(i) for i in range(npars)] elif len(labels) < npars: - labels += ['par{0}'.format(i) for i in range(len(labels), npars)] + labels += ["par{0}".format(i) for i in range(len(labels), npars)] self.hasdata = data is not None self.data = None @@ -89,12 +96,18 @@ def __init__(self, if e_range: e_range = validate_array( - 'e_range', u.Quantity(e_range), physical_type='energy') - energy = np.logspace( - np.log10(e_range[0].value), - np.log10(e_range[1].value), e_npoints) * e_range.unit + "e_range", u.Quantity(e_range), physical_type="energy" + ) + energy = ( + np.logspace( + np.log10(e_range[0].value), + np.log10(e_range[1].value), + e_npoints, + ) + * e_range.unit + ) if self.hasdata: - energy = energy.to(self.data['energy'].unit) + energy = energy.to(self.data["energy"].unit) else: e_unit = e_range.unit else: @@ -103,28 +116,30 @@ def __init__(self, # Bogus flux array to send to model if not using data if sed: - flux = np.zeros(e_npoints) * u.Unit('erg/(cm2 s)') + flux = np.zeros(e_npoints) * u.Unit("erg/(cm2 s)") else: - flux = np.zeros(e_npoints) * u.Unit('1/(TeV cm2 s)') + flux = np.zeros(e_npoints) * u.Unit("1/(TeV cm2 s)") if self.hasdata: - e_unit = self.data['energy'].unit + e_unit = self.data["energy"].unit _plot_data_to_ax(self.data, modelax, sed=sed, e_unit=e_unit) if not e_range: # use data for model - energy = self.data['energy'] - flux = self.data['flux'] + energy = self.data["energy"] + flux = self.data["flux"] - self.data_for_model = {'energy': energy, 'flux': flux} + self.data_for_model = {"energy": energy, "flux": flux} model = _process_model(self.modelfn(p0, self.data_for_model)) if self.hasdata: if not np.all( - self.data_for_model['energy'] == self.data['energy']): + self.data_for_model["energy"] == self.data["energy"] + ): # this will be slow, maybe interpolate already computed model? 
model_for_lnprob = _process_model( - self.modelfn(self.pars, self.data)) + self.modelfn(self.pars, self.data) + ) else: model_for_lnprob = model lnprob = lnprobmodel(model_for_lnprob, self.data) @@ -133,65 +148,71 @@ def __init__(self, self.lnprobtxt = modelax.text( 0.05, 0.05, - r'', - ha='left', - va='bottom', + r"", + ha="left", + va="bottom", transform=modelax.transAxes, - size=20) - self.lnprobtxt.set_text(r'$\ln\mathcal{{L}} = {0:.1f}$'.format( - lnprob)) + size=20, + ) + self.lnprobtxt.set_text( + r"$\ln\mathcal{{L}} = {0:.1f}$".format(lnprob) + ) self.f_unit, self.sedf = sed_conversion(energy, model.unit, sed) if self.hasdata: - datamin = (self.data['energy'][0] - self.data['energy_error_lo'][0] - ).to(e_unit).value / 3 + datamin = ( + self.data["energy"][0] - self.data["energy_error_lo"][0] + ).to(e_unit).value / 3 xmin = min(datamin, energy[0].to(e_unit).value) datamax = ( - self.data['energy'][-1] + self.data['energy_error_hi'][-1] + self.data["energy"][-1] + self.data["energy_error_hi"][-1] ).to(e_unit).value * 3 xmax = max(datamax, energy[-1].to(e_unit).value) modelax.set_xlim(xmin, xmax) else: # plot_data_to_ax has not set ylabel - unit = self.f_unit.to_string('latex_inline') + unit = self.f_unit.to_string("latex_inline") if sed: - modelax.set_ylabel(r'$E^2 dN/dE$ [{0}]'.format(unit)) + modelax.set_ylabel(r"$E^2 dN/dE$ [{0}]".format(unit)) else: - modelax.set_ylabel(r'$dN/dE$ [{0}]'.format(unit)) + modelax.set_ylabel(r"$dN/dE$ [{0}]".format(unit)) modelax.set_xlim(energy[0].value, energy[-1].value) self.line, = modelax.loglog( - energy.to(e_unit), (model * self.sedf).to(self.f_unit), + energy.to(e_unit), + (model * self.sedf).to(self.f_unit), lw=2, - c='k', - zorder=10) + c="k", + zorder=10, + ) - modelax.set_xlabel('Energy [{0}]'.format( - energy.unit.to_string('latex_inline'))) + modelax.set_xlabel( + "Energy [{0}]".format(energy.unit.to_string("latex_inline")) + ) paraxes = [] for n in range(npars): paraxes.append( - plt.subplot2grid( - (2 * npars, 10), (npars + n, 0), colspan=7)) + plt.subplot2grid((2 * npars, 10), (npars + n, 0), colspan=7) + ) self.parsliders = [] - slider_props = {'facecolor': color_cycle[-1], 'alpha': 0.5} + slider_props = {"facecolor": color_cycle[-1], "alpha": 0.5} for label, parax, valinit in six.moves.zip(labels, paraxes, p0): # Attempt to estimate reasonable parameter ranges from label pmin, pmax = valinit / 10, valinit * 3 - if 'log' in label: + if "log" in label: span = 2 - if 'norm' in label or 'amplitude' in label: + if "norm" in label or "amplitude" in label: # give more range for normalization parameters span *= 2 pmin, pmax = valinit - span, valinit + span - elif ('index' in label) or ('alpha' in label): - if valinit > 0.: + elif ("index" in label) or ("alpha" in label): + if valinit > 0.0: pmin, pmax = 0, 5 else: pmin, pmax = -5, 0 - elif 'norm' in label or 'amplitude' in label: + elif "norm" in label or "amplitude" in label: # norm without log, it will not be pretty because sliders are # only linear pmin, pmax = valinit / 100, valinit * 100 @@ -202,28 +223,30 @@ def __init__(self, pmin, pmax, valinit=valinit, - valfmt='%.4g', - **slider_props) + valfmt="%.4g", + **slider_props + ) slider.on_changed(self.update_if_auto) self.parsliders.append(slider) autoupdateax = plt.subplot2grid((8, 4), (4, 3), colspan=1, rowspan=1) - auto_update_check = CheckButtons(autoupdateax, ('Auto update',), - (auto_update,)) + auto_update_check = CheckButtons( + autoupdateax, ("Auto update",), (auto_update,) + ) 
auto_update_check.on_clicked(self.update_autoupdate) self.autoupdate = auto_update updateax = plt.subplot2grid((8, 4), (5, 3), colspan=1, rowspan=1) - update_button = Button(updateax, 'Update model') + update_button = Button(updateax, "Update model") update_button.on_clicked(self.update) if self.hasdata: fitax = plt.subplot2grid((8, 4), (6, 3), colspan=1, rowspan=1) - fit_button = Button(fitax, 'Do Nelder-Mead fit') + fit_button = Button(fitax, "Do Nelder-Mead fit") fit_button.on_clicked(self.do_fit) closeax = plt.subplot2grid((8, 4), (7, 3), colspan=1, rowspan=1) - close_button = Button(closeax, 'Close window') + close_button = Button(closeax, "Close window") close_button.on_clicked(self.close_fig) self.fig.subplots_adjust(top=0.98, right=0.98, bottom=0.02, hspace=0.2) @@ -245,14 +268,16 @@ def update(self, event): self.line.set_ydata((model * self.sedf).to(self.f_unit)) if self.hasdata: if not np.all( - self.data_for_model['energy'] == self.data['energy']): + self.data_for_model["energy"] == self.data["energy"] + ): # this will be slow, maybe interpolate already computed model? model = _process_model(self.modelfn(self.pars, self.data)) lnprob = lnprobmodel(model, self.data) if isinstance(lnprob, u.Quantity): lnprob = lnprob.decompose().value - self.lnprobtxt.set_text(r'$\ln\mathcal{{L}} = {0:.1f}$'.format( - lnprob)) + self.lnprobtxt.set_text( + r"$\ln\mathcal{{L}} = {0:.1f}$".format(lnprob) + ) self.fig.canvas.draw_idle() def do_fit(self, event): @@ -263,10 +288,11 @@ def do_fit(self, event): if P0_IS_ML: for slider, val in zip(self.parsliders, self.pars): slider.set_val(val) - self.update('after_fit') + self.update("after_fit") self.autoupdate = autoupdate self.P0_IS_ML = P0_IS_ML def close_fig(self, event): import matplotlib.pyplot as plt + plt.close(self.fig) diff --git a/naima/model_utils.py b/naima/model_utils.py index 68fb6fee..410c16cb 100644 --- a/naima/model_utils.py +++ b/naima/model_utils.py @@ -1,7 +1,11 @@ # -*- coding: utf-8 -*- # Licensed under a 3-clause BSD style license - see LICENSE.rst -from __future__ import (absolute_import, division, print_function, - unicode_literals) +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) import numpy as np import hashlib from astropy import units as u @@ -25,9 +29,10 @@ def model(cls, energy, *args, **kwargs): try: with warnings.catch_warnings(): warnings.simplefilter( - 'ignore', - getattr(np, 'VisibleDeprecationWarning', None)) - energy = u.Quantity(energy['energy']) + "ignore", + getattr(np, "VisibleDeprecationWarning", None), + ) + energy = u.Quantity(energy["energy"]) except (TypeError, ValueError, IndexError): pass @@ -41,20 +46,20 @@ def model(cls, energy, *args, **kwargs): data = [hashlib.sha256(bstr).hexdigest()] data.append(energy.unit.to_string()) - data.append(str(kwargs.get('distance', 0))) + data.append(str(kwargs.get("distance", 0))) if args: data.append(str(args)) - if hasattr(cls, 'particle_distribution'): + if hasattr(cls, "particle_distribution"): models = [cls, cls.particle_distribution] else: models = [cls] for model in models: - if hasattr(model, 'param_names'): + if hasattr(model, "param_names"): for par in model.param_names: data.append(str(getattr(model, par))) - token = u''.join(data) + token = "".join(data) digest = hashlib.sha256(token.encode()).hexdigest() if digest in cache: diff --git a/naima/models.py b/naima/models.py index 7df27f91..3b10a8cb 100644 --- a/naima/models.py +++ b/naima/models.py @@ -1,24 +1,44 @@ # -*- coding: utf-8 -*- # Licensed under a 
3-clause BSD style license - see LICENSE.rst -from __future__ import (absolute_import, division, print_function, - unicode_literals) +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) import os import numpy as np import astropy.units as u from astropy.utils.data import get_pkg_data_filename -from .extern.validator import (validate_scalar, validate_array, - validate_physical_type) +from .extern.validator import ( + validate_scalar, + validate_array, + validate_physical_type, +) from .radiative import ( - Synchrotron, ElectronSynchrotron, ProtonSynchrotron, - InverseCompton, PionDecay, Bremsstrahlung + Synchrotron, + ElectronSynchrotron, + ProtonSynchrotron, + InverseCompton, + PionDecay, + Bremsstrahlung, ) from .model_utils import memoize __all__ = [ - 'Synchrotron', 'ElectronSynchrotron', 'ProtonSynchrotron', - 'InverseCompton', 'PionDecay', 'Bremsstrahlung', 'BrokenPowerLaw', - 'ExponentialCutoffPowerLaw', 'PowerLaw', 'LogParabola', - 'ExponentialCutoffBrokenPowerLaw', 'TableModel', 'EblAbsorptionModel' + "Synchrotron", + "ElectronSynchrotron", + "ProtonSynchrotron", + "InverseCompton", + "PionDecay", + "Bremsstrahlung", + "BrokenPowerLaw", + "ExponentialCutoffPowerLaw", + "PowerLaw", + "LogParabola", + "ExponentialCutoffBrokenPowerLaw", + "TableModel", + "EblAbsorptionModel", ] @@ -28,13 +48,14 @@ def _validate_ene(ene): if isinstance(ene, dict) or isinstance(ene, Table): try: ene = validate_array( - 'energy', u.Quantity(ene['energy']), physical_type='energy') + "energy", u.Quantity(ene["energy"]), physical_type="energy" + ) except KeyError: - raise TypeError('Table or dict does not have \'ene\' column') + raise TypeError("Table or dict does not have 'ene' column") else: if not isinstance(ene, u.Quantity): ene = u.Quantity(ene) - validate_physical_type('energy', ene, physical_type='energy') + validate_physical_type("energy", ene, physical_type="energy") return ene @@ -65,7 +86,7 @@ class PowerLaw(object): """ - param_names = ['amplitude', 'e_0', 'alpha'] + param_names = ["amplitude", "e_0", "alpha"] _memoize = False _cache = {} _queue = [] @@ -73,7 +94,8 @@ class PowerLaw(object): def __init__(self, amplitude, e_0, alpha): self.amplitude = amplitude self.e_0 = validate_scalar( - 'e_0', e_0, domain='positive', physical_type='energy') + "e_0", e_0, domain="positive", physical_type="energy" + ) self.alpha = alpha @staticmethod @@ -81,13 +103,16 @@ def eval(e, amplitude, e_0, alpha): """One dimensional power law model function""" xx = e / e_0 - return amplitude * xx**(-alpha) + return amplitude * xx ** (-alpha) @memoize def _calc(self, e): return self.eval( - e.to('eV').value, self.amplitude, - self.e_0.to('eV').value, self.alpha) + e.to("eV").value, + self.amplitude, + self.e_0.to("eV").value, + self.alpha, + ) def __call__(self, e): """One dimensional power law model function""" @@ -126,7 +151,7 @@ class ExponentialCutoffPowerLaw(object): """ - param_names = ['amplitude', 'e_0', 'alpha', 'e_cutoff', 'beta'] + param_names = ["amplitude", "e_0", "alpha", "e_cutoff", "beta"] _memoize = False _cache = {} _queue = [] @@ -134,10 +159,12 @@ class ExponentialCutoffPowerLaw(object): def __init__(self, amplitude, e_0, alpha, e_cutoff, beta=1.0): self.amplitude = amplitude self.e_0 = validate_scalar( - 'e_0', e_0, domain='positive', physical_type='energy') + "e_0", e_0, domain="positive", physical_type="energy" + ) self.alpha = alpha self.e_cutoff = validate_scalar( - 'e_cutoff', e_cutoff, domain='positive', physical_type='energy') + "e_cutoff", 
e_cutoff, domain="positive", physical_type="energy" + ) self.beta = beta @staticmethod @@ -146,14 +173,18 @@ def eval(e, amplitude, e_0, alpha, e_cutoff, beta): """ xx = e / e_0 - return amplitude * xx**(-alpha) * np.exp(-(e / e_cutoff)**beta) + return amplitude * xx ** (-alpha) * np.exp(-(e / e_cutoff) ** beta) @memoize def _calc(self, e): return self.eval( - e.to('eV').value, self.amplitude, - self.e_0.to('eV').value, self.alpha, - self.e_cutoff.to('eV').value, self.beta) + e.to("eV").value, + self.amplitude, + self.e_0.to("eV").value, + self.alpha, + self.e_cutoff.to("eV").value, + self.beta, + ) def __call__(self, e): """One dimensional power law with an exponential cutoff model function @@ -199,7 +230,7 @@ class BrokenPowerLaw(object): \\right. """ - param_names = ['amplitude', 'e_0', 'e_break', 'alpha_1', 'alpha_2'] + param_names = ["amplitude", "e_0", "e_break", "alpha_1", "alpha_2"] _memoize = False _cache = {} _queue = [] @@ -207,25 +238,31 @@ class BrokenPowerLaw(object): def __init__(self, amplitude, e_0, e_break, alpha_1, alpha_2): self.amplitude = amplitude self.e_0 = validate_scalar( - 'e_0', e_0, domain='positive', physical_type='energy') + "e_0", e_0, domain="positive", physical_type="energy" + ) self.e_break = validate_scalar( - 'e_break', e_break, domain='positive', physical_type='energy') + "e_break", e_break, domain="positive", physical_type="energy" + ) self.alpha_1 = alpha_1 self.alpha_2 = alpha_2 @staticmethod def eval(e, amplitude, e_0, e_break, alpha_1, alpha_2): """One dimensional broken power law model function""" - K = np.where(e < e_break, 1, (e_break / e_0)**(alpha_2 - alpha_1)) + K = np.where(e < e_break, 1, (e_break / e_0) ** (alpha_2 - alpha_1)) alpha = np.where(e < e_break, alpha_1, alpha_2) - return amplitude * K * (e / e_0)**-alpha + return amplitude * K * (e / e_0) ** -alpha @memoize def _calc(self, e): return self.eval( - e.to('eV').value, self.amplitude, - self.e_0.to('eV').value, - self.e_break.to('eV').value, self.alpha_1, self.alpha_2) + e.to("eV").value, + self.amplitude, + self.e_0.to("eV").value, + self.e_break.to("eV").value, + self.alpha_1, + self.alpha_2, + ) def __call__(self, e): """One dimensional broken power law model function""" @@ -278,46 +315,55 @@ class ExponentialCutoffBrokenPowerLaw(object): """ param_names = [ - 'amplitude', 'e_0', 'e_break', 'alpha_1', 'alpha_2', 'e_cutoff', 'beta' + "amplitude", + "e_0", + "e_break", + "alpha_1", + "alpha_2", + "e_cutoff", + "beta", ] _memoize = False _cache = {} _queue = [] - def __init__(self, - amplitude, - e_0, - e_break, - alpha_1, - alpha_2, - e_cutoff, - beta=1.0): + def __init__( + self, amplitude, e_0, e_break, alpha_1, alpha_2, e_cutoff, beta=1.0 + ): self.amplitude = amplitude self.e_0 = validate_scalar( - 'e_0', e_0, domain='positive', physical_type='energy') + "e_0", e_0, domain="positive", physical_type="energy" + ) self.e_break = validate_scalar( - 'e_break', e_break, domain='positive', physical_type='energy') + "e_break", e_break, domain="positive", physical_type="energy" + ) self.alpha_1 = alpha_1 self.alpha_2 = alpha_2 self.e_cutoff = validate_scalar( - 'e_cutoff', e_cutoff, domain='positive', physical_type='energy') + "e_cutoff", e_cutoff, domain="positive", physical_type="energy" + ) self.beta = beta @staticmethod def eval(e, amplitude, e_0, e_break, alpha_1, alpha_2, e_cutoff, beta): """One dimensional broken power law model function""" - K = np.where(e < e_break, 1, (e_break / e_0)**(alpha_2 - alpha_1)) + K = np.where(e < e_break, 1, (e_break / e_0) ** (alpha_2 - alpha_1)) 
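# Note: the factor K above is what keeps the spectrum continuous at e_break:
# below the break the model is amplitude * (e / e_0) ** -alpha_1, and
# multiplying the alpha_2 branch by (e_break / e_0) ** (alpha_2 - alpha_1)
# makes both branches agree at e = e_break. A quick standalone check with
# hypothetical parameter values:
#
#     import numpy as np
#
#     amplitude, e_0, e_break, a1, a2 = 1.0, 1.0, 10.0, 2.0, 3.0
#
#     def bpl(e):
#         K = np.where(e < e_break, 1, (e_break / e_0) ** (a2 - a1))
#         alpha = np.where(e < e_break, a1, a2)
#         return amplitude * K * (e / e_0) ** -alpha
#
#     assert np.isclose(bpl(e_break - 1e-6), bpl(e_break + 1e-6), rtol=1e-3)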
alpha = np.where(e < e_break, alpha_1, alpha_2) ee2 = e / e_cutoff - return amplitude * K * (e / e_0)**-alpha * np.exp(-(ee2**beta)) + return amplitude * K * (e / e_0) ** -alpha * np.exp(-(ee2 ** beta)) @memoize def _calc(self, e): return self.eval( - e.to('eV').value, self.amplitude, - self.e_0.to('eV').value, - self.e_break.to('eV').value, self.alpha_1, self.alpha_2, - self.e_cutoff.to('eV').value, self.beta) + e.to("eV").value, + self.amplitude, + self.e_0.to("eV").value, + self.e_break.to("eV").value, + self.alpha_1, + self.alpha_2, + self.e_cutoff.to("eV").value, + self.beta, + ) def __call__(self, e): """One dimensional broken power law model with exponential cutoff @@ -357,7 +403,7 @@ class LogParabola(object): """ - param_names = ['amplitude', 'e_0', 'alpha', 'beta'] + param_names = ["amplitude", "e_0", "alpha", "beta"] _memoize = False _cache = {} _queue = [] @@ -365,7 +411,8 @@ class LogParabola(object): def __init__(self, amplitude, e_0, alpha, beta): self.amplitude = amplitude self.e_0 = validate_scalar( - 'e_0', e_0, domain='positive', physical_type='energy') + "e_0", e_0, domain="positive", physical_type="energy" + ) self.alpha = alpha self.beta = beta @@ -375,13 +422,17 @@ def eval(e, amplitude, e_0, alpha, beta): ee = e / e_0 eeponent = -alpha - beta * np.log(ee) - return amplitude * ee**eeponent + return amplitude * ee ** eeponent @memoize def _calc(self, e): return self.eval( - e.to('eV').value, self.amplitude, - self.e_0.to('eV').value, self.alpha, self.beta) + e.to("eV").value, + self.amplitude, + self.e_0.to("eV").value, + self.alpha, + self.beta, + ) def __call__(self, e): """One dimensional curved power law function""" @@ -411,25 +462,28 @@ class TableModel(object): def __init__(self, energy, values, amplitude=1): from scipy.interpolate import interp1d + self._energy = validate_array( - 'energy', energy, domain='positive', physical_type='energy') + "energy", energy, domain="positive", physical_type="energy" + ) self._values = values self.amplitude = amplitude - loge = np.log10(self._energy.to('eV').value) + loge = np.log10(self._energy.to("eV").value) try: self.unit = self._values.unit logy = np.log10(self._values.value) except AttributeError: - self.unit = u.Unit('') + self.unit = u.Unit("") logy = np.log10(self._values) self._interplogy = interp1d( - loge, logy, fill_value=-np.Inf, bounds_error=False, kind='cubic') + loge, logy, fill_value=-np.Inf, bounds_error=False, kind="cubic" + ) def __call__(self, e): e = _validate_ene(e) - interpy = np.power(10, self._interplogy(np.log10(e.to('eV').value))) + interpy = np.power(10, self._interplogy(np.log10(e.to("eV").value))) return self.amplitude * interpy * self.unit @@ -460,41 +514,46 @@ class EblAbsorptionModel(TableModel): TableModel """ - def __init__(self, redshift, ebl_absorption_model='Dominguez'): + def __init__(self, redshift, ebl_absorption_model="Dominguez"): # check that the redshift is a positive scalar if not isinstance(redshift, u.Quantity): redshift *= u.dimensionless_unscaled self.redshift = validate_scalar( - 'redshift', + "redshift", redshift, - domain='positive', - physical_type='dimensionless') + domain="positive", + physical_type="dimensionless", + ) self.model = ebl_absorption_model - if self.model == 'Dominguez': + if self.model == "Dominguez": """Table generated by Alberto Dominguez containing tau vs energy [TeV] vs redshift. Energy is defined between 1 GeV and 100 TeV, in 500 bins uniform in log(E). Redshift is defined between 0.01 and 4, in steps of 0.01. 
""" filename = get_pkg_data_filename( - os.path.join('data', 'tau_dominguez11.npz')) - taus_table = np.load(filename)['arr_0'] + os.path.join("data", "tau_dominguez11.npz") + ) + taus_table = np.load(filename)["arr_0"] redshift_list = np.arange(0.01, 4, 0.01) - energy = taus_table['energy'] * u.TeV + energy = taus_table["energy"] * u.TeV if self.redshift >= 0.01: - colname = 'col%s' % ( - 2 + (np.abs(redshift_list - self.redshift)).argmin()) + colname = "col%s" % ( + 2 + (np.abs(redshift_list - self.redshift)).argmin() + ) table_values = taus_table[colname] # Set maximum value of the log(Tau) to 150, as it is high # enough. This solves later overflow problems. - table_values[table_values > 150.] = 150. - taus = 10**table_values * u.dimensionless_unscaled + table_values[table_values > 150.0] = 150.0 + taus = 10 ** table_values * u.dimensionless_unscaled elif self.redshift < 0.01: - taus = 10**np.zeros(len(taus_table[ - 'energy'])) * u.dimensionless_unscaled + taus = ( + 10 ** np.zeros(len(taus_table["energy"])) + * u.dimensionless_unscaled + ) else: raise ValueError('Model should be one of: ["Dominguez"]') @@ -504,10 +563,10 @@ def transmission(self, e): e = _validate_ene(e) taus = np.zeros(len(e)) for i in range(0, len(e)): - if e[i].to('GeV').value < 1.: - taus[i] = 0. - elif e[i].to('TeV').value > 100.: - taus[i] = np.log10(6000.) + if e[i].to("GeV").value < 1.0: + taus[i] = 0.0 + elif e[i].to("TeV").value > 100.0: + taus[i] = np.log10(6000.0) else: taus[i] = np.log10(self(e[i])) return np.exp(-taus) diff --git a/naima/plot.py b/naima/plot.py index f39a133d..56bcfbaa 100644 --- a/naima/plot.py +++ b/naima/plot.py @@ -1,6 +1,10 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst -from __future__ import (absolute_import, division, print_function, - unicode_literals) +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) import numpy as np import astropy.units as u @@ -15,14 +19,16 @@ __all__ = ["plot_chain", "plot_fit", "plot_data", "plot_blob", "plot_corner"] -marker_cycle = ['o', 's', 'd', 'p', '*'] +marker_cycle = ["o", "s", "d", "p", "*"] # from seaborn: sns.color_palette('deep',6) -color_cycle = [(0.2980392156862745, 0.4470588235294118, 0.6901960784313725), - (0.3333333333333333, 0.6588235294117647, 0.40784313725490196), - (0.7686274509803922, 0.3058823529411765, 0.3215686274509804), - (0.5058823529411764, 0.4470588235294118, 0.6980392156862745), - (0.8, 0.7254901960784313, 0.4549019607843137), - (0.39215686274509803, 0.7098039215686275, 0.803921568627451)] +color_cycle = [ + (0.2980392156862745, 0.4470588235294118, 0.6901960784313725), + (0.3333333333333333, 0.6588235294117647, 0.40784313725490196), + (0.7686274509803922, 0.3058823529411765, 0.3215686274509804), + (0.5058823529411764, 0.4470588235294118, 0.6980392156862745), + (0.8, 0.7254901960784313, 0.4549019607843137), + (0.39215686274509803, 0.7098039215686275, 0.803921568627451), +] def plot_chain(sampler, p=None, **kwargs): @@ -71,16 +77,16 @@ def round2(x, n): y = str(int(y)) else: # preserve trailing zeroes - y = ('{{0:.{0}f}}'.format(n)).format(x) + y = ("{{0:.{0}f}}".format(n)).format(x) return y def _latex_value_error(val, elo, ehi=0, tol=0.25): order = int(np.log10(abs(val))) if order > 2 or order < -2: - val /= 10**order - elo /= 10**order - ehi /= 10**order + val /= 10 ** order + elo /= 10 ** order + ehi /= 10 ** order else: order = 0 nlo = -int(np.floor(np.log10(elo))) @@ -92,23 +98,26 @@ def _latex_value_error(val, elo, ehi=0, tol=0.25): if ehi * 10 ** nhi < 
2: nhi += 1 # ehi = round(ehi,nhi) - if np.abs(elo - ehi) / ((elo + ehi) / 2.) > tol: + if np.abs(elo - ehi) / ((elo + ehi) / 2.0) > tol: n = max(nlo, nhi) - string = '{0}^{{+{1}}}_{{-{2}}}'.format(* [ - round2(x, nn) for x, nn in zip([val, ehi, elo], [n, nhi, nlo]) - ]) + string = "{0}^{{+{1}}}_{{-{2}}}".format( + *[ + round2(x, nn) + for x, nn in zip([val, ehi, elo], [n, nhi, nlo]) + ] + ) else: - e = (elo + ehi) / 2. + e = (elo + ehi) / 2.0 n = -int(np.floor(np.log10(e))) - if e * 10**n < 2: + if e * 10 ** n < 2: n += 1 - string = '{0} \pm {1}'.format(* [round2(x, n) for x in [val, e]]) + string = "{0} \pm {1}".format(*[round2(x, n) for x in [val, e]]) else: - string = '{0} \pm {1}'.format(* [round2(x, nlo) for x in [val, elo]]) + string = "{0} \pm {1}".format(*[round2(x, nlo) for x in [val, elo]]) if order != 0: - string = '(' + string + r')\times10^{{{0}}}'.format(order) + string = "(" + string + r")\times10^{{{0}}}".format(order) - return '$' + string + '$' + return "$" + string + "$" def _plot_chain_func(sampler, p, last_step=False): @@ -118,6 +127,7 @@ def _plot_chain_func(sampler, p, last_step=False): import matplotlib.pyplot as plt from scipy import stats + if len(chain.shape) > 2: traces = chain[:, :, p] if last_step: @@ -128,7 +138,8 @@ def _plot_chain_func(sampler, p, last_step=False): dist = traces.flatten() else: log.warning( - 'we need the full chain to plot the traces, not a flatchain!') + "we need the full chain to plot the traces, not a flatchain!" + ) return None nwalkers = traces.shape[0] @@ -144,7 +155,7 @@ def _plot_chain_func(sampler, p, last_step=False): # plot five percent of the traces darker if nwalkers < 60: - thresh = 1 - 3. / nwalkers + thresh = 1 - 3.0 / nwalkers else: thresh = 0.95 red = np.arange(nwalkers) / float(nwalkers) >= thresh @@ -154,47 +165,50 @@ def _plot_chain_func(sampler, p, last_step=False): ax1.plot(t, color=(0.1,) * 3, lw=1.0, alpha=0.25, zorder=0) for t in traces[red]: ax1.plot(t, color=color_cycle[0], lw=1.5, alpha=0.75, zorder=0) - ax1.set_xlabel('step number') + ax1.set_xlabel("step number") # [l.set_rotation(45) for l in ax1.get_yticklabels()] ax1.set_ylabel(label) ax1.yaxis.set_label_coords(-0.15, 0.5) - ax1.set_title('Walker traces') + ax1.set_title("Walker traces") - nbins = min(max(25, int(len(dist) / 100.)), 100) + nbins = min(max(25, int(len(dist) / 100.0)), 100) xlabel = label n, x, _ = ax2.hist( dist, nbins, - histtype='stepfilled', + histtype="stepfilled", color=color_cycle[0], lw=0, - normed=1) + normed=1, + ) kde = stats.kde.gaussian_kde(dist) - ax2.plot(x, kde(x), color='k', label='KDE') + ax2.plot(x, kde(x), color="k", label="KDE") quant = [16, 50, 84] xquant = np.percentile(dist, quant) quantiles = dict(six.moves.zip(quant, xquant)) ax2.axvline( quantiles[50], - ls='--', - color='k', + ls="--", + color="k", alpha=0.5, lw=2, - label='50% quantile') + label="50% quantile", + ) ax2.axvspan( quantiles[16], quantiles[84], color=(0.5,) * 3, alpha=0.25, - label='68% CI', - lw=0) + label="68% CI", + lw=0, + ) # ax2.legend() for l in ax2.get_xticklabels(): l.set_rotation(45) ax2.set_xlabel(xlabel) ax2.xaxis.set_label_coords(0.5, -0.1) - ax2.set_title('posterior distribution') + ax2.set_title("posterior distribution") ax2.set_ylim(top=n.max() * 1.05) # Print distribution parameters on lower-left @@ -204,61 +218,76 @@ def _plot_chain_func(sampler, p, last_step=False): ac = sampler.get_autocorr_time()[p] except AttributeError: ac = autocorr.integrated_time( - np.mean( - chain, axis=0), axis=0, fast=False)[p] - autocorr_message = 
'{0:.1f}'.format(ac) + np.mean(chain, axis=0), axis=0, fast=False + )[p] + autocorr_message = "{0:.1f}".format(ac) except autocorr.AutocorrError: # Raised when chain is too short for meaningful auto-correlation # estimation autocorr_message = None if last_step: - clen = 'last ensemble' + clen = "last ensemble" else: - clen = 'whole chain' + clen = "whole chain" - chain_props = 'Walkers: {0} \nSteps in chain: {1} \n'.format(nwalkers, - nsteps) + chain_props = "Walkers: {0} \nSteps in chain: {1} \n".format( + nwalkers, nsteps + ) if autocorr_message is not None: - chain_props += 'Autocorrelation time: {0}\n'.format(autocorr_message) - chain_props += 'Mean acceptance fraction: {0:.3f}\n'.format( - np.mean(sampler.acceptance_fraction)) +\ - 'Distribution properties for the {clen}:\n \ + chain_props += "Autocorrelation time: {0}\n".format(autocorr_message) + chain_props += ( + "Mean acceptance fraction: {0:.3f}\n".format( + np.mean(sampler.acceptance_fraction) + ) + + "Distribution properties for the {clen}:\n \ $-$ median: ${median}$, std: ${std}$ \n \ $-$ median with uncertainties based on \n \ - the 16th and 84th percentiles ($\sim$1$\sigma$):\n'.format( - median=_latex_float(quantiles[50]), - std=_latex_float(np.std(dist)), clen=clen) - - info_line = ' ' * 10 + label + ' = ' + _latex_value_error( - quantiles[50], quantiles[50] - quantiles[16], quantiles[84] - - quantiles[50]) + the 16th and 84th percentiles ($\sim$1$\sigma$):\n".format( + median=_latex_float(quantiles[50]), + std=_latex_float(np.std(dist)), + clen=clen, + ) + ) + + info_line = ( + " " * 10 + + label + + " = " + + _latex_value_error( + quantiles[50], + quantiles[50] - quantiles[16], + quantiles[84] - quantiles[50], + ) + ) chain_props += info_line - if 'log10(' in label or 'log(' in label: - nlabel = label.split('(')[-1].split(')')[0] - ltype = label.split('(')[0] - if ltype == 'log10': - new_dist = 10**dist - elif ltype == 'log': + if "log10(" in label or "log(" in label: + nlabel = label.split("(")[-1].split(")")[0] + ltype = label.split("(")[0] + if ltype == "log10": + new_dist = 10 ** dist + elif ltype == "log": new_dist = np.exp(dist) quant = [16, 50, 84] quantiles = dict(six.moves.zip(quant, np.percentile(new_dist, quant))) - label_template = '\n' + ' ' * 10 + '{{label:>{0}}}'.format(len(label)) + label_template = "\n" + " " * 10 + "{{label:>{0}}}".format(len(label)) new_line = label_template.format(label=nlabel) - new_line += ' = ' + _latex_value_error(quantiles[50], quantiles[50] - - quantiles[16], quantiles[84] - - quantiles[50]) + new_line += " = " + _latex_value_error( + quantiles[50], + quantiles[50] - quantiles[16], + quantiles[84] - quantiles[50], + ) chain_props += new_line info_line += new_line - log.info('{0:-^50}\n'.format(label) + info_line) - f.text(0.05, 0.45, chain_props, ha='left', va='top') + log.info("{0:-^50}\n".format(label) + info_line) + f.text(0.05, 0.45, chain_props, ha="left", va="top") return f @@ -275,10 +304,10 @@ def _process_blob(sampler, modelidx, last_step=False, energy=None): # Allow process blob to be used by _calc_samples and _calc_ML by sending # only blobs, not full sampler - if hasattr(sampler, 'blobs'): + if hasattr(sampler, "blobs"): blob0 = sampler.blobs[-1][0][modelidx] blobs = sampler.blobs - energy = sampler.data['energy'] + energy = sampler.data["energy"] else: blobs = [sampler] blob0 = sampler[0][modelidx] @@ -310,9 +339,12 @@ def _process_blob(sampler, modelidx, last_step=False, energy=None): for walkerblob in step: model.append(walkerblob[modelidx]) model = 
u.Quantity(model) - elif (isinstance(blob0, list) or isinstance(blob0, tuple)): - if (len(blob0) == 2 and isinstance(blob0[0], u.Quantity) and - isinstance(blob0[1], u.Quantity)): + elif isinstance(blob0, list) or isinstance(blob0, tuple): + if ( + len(blob0) == 2 + and isinstance(blob0[0], u.Quantity) + and isinstance(blob0[1], u.Quantity) + ): # Energy array for model is item 0 in blob, model flux is item 1 modelx = blob0[0] @@ -325,21 +357,23 @@ def _process_blob(sampler, modelidx, last_step=False, energy=None): model.append(walkerblob[modelidx][1]) model = u.Quantity(model) else: - raise TypeError('Model {0} has wrong blob format'.format(modelidx)) + raise TypeError("Model {0} has wrong blob format".format(modelidx)) else: - raise TypeError('Model {0} has wrong blob format'.format(modelidx)) + raise TypeError("Model {0} has wrong blob format".format(modelidx)) return modelx, model -def _read_or_calc_samples(sampler, - modelidx=0, - n_samples=100, - last_step=False, - e_range=None, - e_npoints=100, - threads=None): +def _read_or_calc_samples( + sampler, + modelidx=0, + n_samples=100, + last_step=False, + e_range=None, + e_npoints=100, + threads=None, +): """Get samples from blob or compute them from chain and sampler.modelfn """ @@ -349,14 +383,20 @@ def _read_or_calc_samples(sampler, else: # prepare bogus data for calculation e_range = validate_array( - 'e_range', u.Quantity(e_range), physical_type='energy') + "e_range", u.Quantity(e_range), physical_type="energy" + ) e_unit = e_range.unit - energy = np.logspace( - np.log10(e_range[0].value), np.log10(e_range[1].value), - e_npoints) * e_unit + energy = ( + np.logspace( + np.log10(e_range[0].value), + np.log10(e_range[1].value), + e_npoints, + ) + * e_unit + ) data = { - 'energy': energy, - 'flux': np.zeros(energy.shape) * sampler.data['flux'].unit + "energy": energy, + "flux": np.zeros(energy.shape) * sampler.data["flux"].unit, } # init pool and select parameters chain = sampler.chain[-1] if last_step else sampler.flatchain @@ -375,7 +415,8 @@ def _read_or_calc_samples(sampler, blobs.append(modelout) modelx, model = _process_blob( - blobs, modelidx=modelidx, energy=data['energy']) + blobs, modelidx=modelidx, energy=data["energy"] + ) return modelx, model @@ -389,14 +430,20 @@ def _calc_ML(sampler, modelidx=0, e_range=None, e_npoints=100): if e_range is not None: # prepare bogus data for calculation e_range = validate_array( - 'e_range', u.Quantity(e_range), physical_type='energy') + "e_range", u.Quantity(e_range), physical_type="energy" + ) e_unit = e_range.unit - energy = np.logspace( - np.log10(e_range[0].value), np.log10(e_range[1].value), - e_npoints) * e_unit + energy = ( + np.logspace( + np.log10(e_range[0].value), + np.log10(e_range[1].value), + e_npoints, + ) + * e_unit + ) data = { - 'energy': energy, - 'flux': np.zeros(energy.shape) * sampler.data['flux'].unit + "energy": energy, + "flux": np.zeros(energy.shape) * sampler.data["flux"].unit, } modelout = sampler.modelfn(MLp, data) @@ -406,26 +453,28 @@ def _calc_ML(sampler, modelidx=0, e_range=None, e_npoints=100): blob = modelout[modelidx] if isinstance(blob, u.Quantity): - modelx = data['energy'].copy() + modelx = data["energy"].copy() model_ML = blob.copy() elif len(blob) == 2: modelx = blob[0].copy() model_ML = blob[1].copy() else: - raise TypeError('Model {0} has wrong blob format'.format(modelidx)) + raise TypeError("Model {0} has wrong blob format".format(modelidx)) ML_model = (modelx, model_ML) return ML, MLp, MLerr, ML_model -def _calc_CI(sampler, - modelidx=0, - 
confs=[3, 1], - last_step=False, - e_range=None, - e_npoints=100, - threads=None): +def _calc_CI( + sampler, + modelidx=0, + confs=[3, 1], + last_step=False, + e_range=None, + e_npoints=100, + threads=None, +): """Calculate confidence interval. """ from scipy import stats @@ -443,14 +492,15 @@ def _calc_CI(sampler, minsamples = min(100, int(1 / stats.norm.cdf(-maxconf) + 1)) if minsamples > 1000: log.warning( - 'In order to sample the confidence band for {0} sigma,' - ' {1} new samples need to be computed, but we are limiting' - ' it to 1000 samples, so the confidence band might not be' - ' well constrained.' - ' Consider reducing the maximum' - ' confidence significance or using the samples stored in' - ' the sampler by setting e_range' - ' to None'.format(maxconf, minsamples)) + "In order to sample the confidence band for {0} sigma," + " {1} new samples need to be computed, but we are limiting" + " it to 1000 samples, so the confidence band might not be" + " well constrained." + " Consider reducing the maximum" + " confidence significance or using the samples stored in" + " the sampler by setting e_range" + " to None".format(maxconf, minsamples) + ) minsamples = 1000 else: minsamples = None @@ -462,7 +512,8 @@ def _calc_CI(sampler, e_range=e_range, e_npoints=e_npoints, n_samples=minsamples, - threads=threads) + threads=threads, + ) nwalkers = len(model) - 1 CI = [] @@ -486,27 +537,32 @@ def _plot_MLmodel(ax, sampler, modelidx, e_range, e_npoints, e_unit, sed): """compute and plot ML model""" ML, MLp, MLerr, ML_model = _calc_ML( - sampler, modelidx, e_range=e_range, e_npoints=e_npoints) + sampler, modelidx, e_range=e_range, e_npoints=e_npoints + ) f_unit, sedf = sed_conversion(ML_model[0], ML_model[1].unit, sed) ax.loglog( - ML_model[0].to(e_unit).value, (ML_model[1] * sedf).to(f_unit).value, - color='k', + ML_model[0].to(e_unit).value, + (ML_model[1] * sedf).to(f_unit).value, + color="k", lw=2, - alpha=0.8) - - -def plot_CI(ax, - sampler, - modelidx=0, - sed=True, - confs=[3, 1, 0.5], - e_unit=u.eV, - label=None, - e_range=None, - e_npoints=100, - threads=None, - last_step=False): + alpha=0.8, + ) + + +def plot_CI( + ax, + sampler, + modelidx=0, + sed=True, + confs=[3, 1, 0.5], + e_unit=u.eV, + label=None, + e_range=None, + e_npoints=100, + threads=None, + last_step=False, +): """Plot confidence interval. 
Parameters @@ -544,38 +600,44 @@ def plot_CI(ax, e_range=e_range, e_npoints=e_npoints, last_step=last_step, - threads=threads) + threads=threads, + ) # pick first confidence interval curve for units f_unit, sedf = sed_conversion(modelx, CI[0][0].unit, sed) for (ymin, ymax), conf in zip(CI, confs): color = np.log(conf) / np.log(20) + 0.4 ax.fill_between( - modelx.to(e_unit).value, (ymax * sedf).to(f_unit).value, + modelx.to(e_unit).value, + (ymax * sedf).to(f_unit).value, (ymin * sedf).to(f_unit).value, lw=0.001, color=(color,) * 3, alpha=0.6, - zorder=-10) + zorder=-10, + ) _plot_MLmodel(ax, sampler, modelidx, e_range, e_npoints, e_unit, sed) if label is not None: - ax.set_ylabel('{0} [{1}]'.format(label, f_unit.to_string( - 'latex_inline'))) - - -def plot_samples(ax, - sampler, - modelidx=0, - sed=True, - n_samples=100, - e_unit=u.eV, - e_range=None, - e_npoints=100, - threads=None, - label=None, - last_step=False): + ax.set_ylabel( + "{0} [{1}]".format(label, f_unit.to_string("latex_inline")) + ) + + +def plot_samples( + ax, + sampler, + modelidx=0, + sed=True, + n_samples=100, + e_unit=u.eV, + e_range=None, + e_npoints=100, + threads=None, + label=None, + last_step=False, +): """Plot a number of samples from the sampler chain. Parameters @@ -615,23 +677,27 @@ def plot_samples(ax, last_step=last_step, e_range=e_range, e_npoints=e_npoints, - threads=threads) + threads=threads, + ) # pick first model sample for units f_unit, sedf = sed_conversion(modelx, model[0].unit, sed) - sample_alpha = min(5. / n_samples, 0.5) + sample_alpha = min(5.0 / n_samples, 0.5) for my in model[np.random.randint(len(model), size=n_samples)]: ax.loglog( - modelx.to(e_unit).value, (my * sedf).to(f_unit).value, + modelx.to(e_unit).value, + (my * sedf).to(f_unit).value, color=(0.1,) * 3, alpha=sample_alpha, - lw=1.0) + lw=1.0, + ) _plot_MLmodel(ax, sampler, modelidx, e_range, e_npoints, e_unit, sed) if label is not None: - ax.set_ylabel('{0} [{1}]'.format(label, f_unit.to_string( - 'latex_inline'))) + ax.set_ylabel( + "{0} [{1}]".format(label, f_unit.to_string("latex_inline")) + ) def find_ML(sampler, modelidx): @@ -640,42 +706,41 @@ def find_ML(sampler, modelidx): probability. """ index = np.unravel_index( - np.argmax(sampler.lnprobability), sampler.lnprobability.shape) + np.argmax(sampler.lnprobability), sampler.lnprobability.shape + ) MLp = sampler.chain[index] - if modelidx is not None and hasattr(sampler, 'blobs'): + if modelidx is not None and hasattr(sampler, "blobs"): blob = sampler.blobs[index[1]][index[0]][modelidx] if isinstance(blob, u.Quantity): - modelx = sampler.data['energy'].copy() + modelx = sampler.data["energy"].copy() model_ML = blob.copy() elif len(blob) == 2: modelx = blob[0].copy() model_ML = blob[1].copy() else: - raise TypeError('Model {0} has wrong blob format'.format(modelidx)) - elif modelidx is not None and hasattr(sampler, 'modelfn'): + raise TypeError("Model {0} has wrong blob format".format(modelidx)) + elif modelidx is not None and hasattr(sampler, "modelfn"): blob = _process_blob( [sampler.modelfn(MLp, sampler.data)], modelidx, - energy=sampler.data['energy']) + energy=sampler.data["energy"], + ) modelx, model_ML = blob[0], blob[1][0] else: modelx, model_ML = None, None MLerr = [] for dist in sampler.flatchain.T: - hilo = np.percentile(dist, [16., 84.]) - MLerr.append((hilo[1] - hilo[0]) / 2.) 
+ hilo = np.percentile(dist, [16.0, 84.0]) + MLerr.append((hilo[1] - hilo[0]) / 2.0) ML = sampler.lnprobability[index] return ML, MLp, MLerr, (modelx, model_ML) -def plot_blob(sampler, - blobidx=0, - label=None, - last_step=False, - figure=None, - **kwargs): +def plot_blob( + sampler, blobidx=0, label=None, last_step=False, figure=None, **kwargs +): """ Plot a metadata blob as a fit to spectral data or value distribution @@ -699,7 +764,7 @@ def plot_blob(sampler, modelx, model = _process_blob(sampler, blobidx, last_step) if label is None: - label = 'Model output {0}'.format(blobidx) + label = "Model output {0}".format(blobidx) if modelx is None: # Blob is scalar, plot distribution @@ -711,30 +776,33 @@ def plot_blob(sampler, last_step=last_step, label=label, figure=figure, - **kwargs) + **kwargs + ) return f -def plot_fit(sampler, - modelidx=0, - label=None, - sed=True, - last_step=False, - n_samples=100, - confs=None, - ML_info=False, - figure=None, - plotdata=None, - plotresiduals=None, - e_unit=None, - e_range=None, - e_npoints=100, - threads=None, - xlabel=None, - ylabel=None, - ulim_opts={}, - errorbar_opts={}): +def plot_fit( + sampler, + modelidx=0, + label=None, + sed=True, + last_step=False, + n_samples=100, + confs=None, + ML_info=False, + figure=None, + plotdata=None, + plotresiduals=None, + e_unit=None, + e_range=None, + e_npoints=100, + threads=None, + xlabel=None, + ylabel=None, + ulim_opts={}, + errorbar_opts={}, +): """ Plot data with fit confidence regions. @@ -799,10 +867,10 @@ def plot_fit(sampler, import matplotlib.pyplot as plt ML, MLp, MLerr, model_ML = find_ML(sampler, modelidx) - infostr = 'Maximum log probability: {0:.3g}\n'.format(ML) - infostr += 'Maximum Likelihood values:\n' + infostr = "Maximum log probability: {0:.3g}\n".format(ML) + infostr += "Maximum Likelihood values:\n" maxlen = np.max([len(ilabel) for ilabel in sampler.labels]) - vartemplate = '{{2:>{0}}}: {{0:>8.3g}} +/- {{1:<8.3g}}\n'.format(maxlen) + vartemplate = "{{2:>{0}}}: {{0:>8.3g}} +/- {{1:<8.3g}}\n".format(maxlen) for p, v, ilabel in zip(MLp, MLerr, sampler.labels): infostr += vartemplate.format(p, v, ilabel) @@ -810,10 +878,10 @@ def plot_fit(sampler, data = sampler.data - if e_range is None and not hasattr(sampler, 'blobs'): - e_range = data['energy'][[0, -1]] * np.array((1. 
/ 3., 3.)) + if e_range is None and not hasattr(sampler, "blobs"): + e_range = data["energy"][[0, -1]] * np.array((1.0 / 3.0, 3.0)) - if len(model_ML[0]) == len(data['energy']) and plotdata is None: + if len(model_ML[0]) == len(data["energy"]) and plotdata is None: plotdata = True elif plotdata is None: plotdata = False @@ -831,7 +899,8 @@ def plot_fit(sampler, figure=figure, e_unit=e_unit, ulim_opts=ulim_opts, - errorbar_opts=errorbar_opts) + errorbar_opts=errorbar_opts, + ) if figure is None: f = plt.figure() @@ -847,7 +916,7 @@ def plot_fit(sampler, ax1 = f.add_subplot(111) if e_unit is None: - e_unit = data['energy'].unit + e_unit = data["energy"].unit if confs is not None: plot_CI( @@ -861,7 +930,8 @@ def plot_fit(sampler, e_range=e_range, e_npoints=e_npoints, last_step=last_step, - threads=threads) + threads=threads, + ) elif n_samples: plot_samples( ax1, @@ -874,7 +944,8 @@ def plot_fit(sampler, e_range=e_range, e_npoints=e_npoints, last_step=last_step, - threads=threads) + threads=threads, + ) else: # plot only ML model _plot_MLmodel(ax1, sampler, modelidx, e_range, e_npoints, e_unit, sed) @@ -888,7 +959,8 @@ def plot_fit(sampler, sed=sed, ylabel=ylabel, ulim_opts=ulim_opts, - errorbar_opts=errorbar_opts) + errorbar_opts=errorbar_opts, + ) if plotresiduals: _plot_residuals_to_ax( data, @@ -896,36 +968,43 @@ def plot_fit(sampler, ax2, e_unit=e_unit, sed=sed, - errorbar_opts=errorbar_opts) + errorbar_opts=errorbar_opts, + ) xlaxis = ax2 for tl in ax1.get_xticklabels(): tl.set_visible(False) - xmin = 10**np.floor( + xmin = 10 ** np.floor( np.log10( - np.min(data['energy'] - data['energy_error_lo']).to(e_unit) - .value)) - xmax = 10**np.ceil( + np.min(data["energy"] - data["energy_error_lo"]) + .to(e_unit) + .value + ) + ) + xmax = 10 ** np.ceil( np.log10( - np.max(data['energy'] + data['energy_error_hi']).to(e_unit) - .value)) + np.max(data["energy"] + data["energy_error_hi"]) + .to(e_unit) + .value + ) + ) ax1.set_xlim(xmin, xmax) else: - ax1.set_xscale('log') - ax1.set_yscale('log') + ax1.set_xscale("log") + ax1.set_yscale("log") if sed: ndecades = 10 else: ndecades = 20 # restrict y axis to ndecades to avoid autoscaling deep exponentials xmin, xmax, ymin, ymax = ax1.axis() - ymin = max(ymin, ymax / 10**ndecades) + ymin = max(ymin, ymax / 10 ** ndecades) ax1.set_ylim(bottom=ymin) # scale x axis to largest model_ML x point within ndecades decades of # maximum f_unit, sedf = sed_conversion(model_ML[0], model_ML[1].unit, sed) hi = np.where((model_ML[1] * sedf).to(f_unit).value > ymin) xmax = np.max(model_ML[0][hi]) - ax1.set_xlim(right=10**np.ceil(np.log10(xmax.to(e_unit).value))) + ax1.set_xlim(right=10 ** np.ceil(np.log10(xmax.to(e_unit).value))) if e_range: # ensure that xmin/xmax contains e_range @@ -939,17 +1018,19 @@ def plot_fit(sampler, 0.05, 0.05, infostr, - ha='left', - va='bottom', + ha="left", + va="bottom", transform=ax1.transAxes, - family='monospace') + family="monospace", + ) if label is not None: ax1.set_title(label) if xlabel is None: - xlaxis.set_xlabel('Energy [{0}]'.format( - e_unit.to_string('latex_inline'))) + xlaxis.set_xlabel( + "Energy [{0}]".format(e_unit.to_string("latex_inline")) + ) else: xlaxis.set_xlabel(xlabel) @@ -958,186 +1039,202 @@ def plot_fit(sampler, return f -def _plot_ulims(ax, - x, - y, - xerr, - color, - capsize=5, - height_fraction=0.25, - elinewidth=2): +def _plot_ulims( + ax, x, y, xerr, color, capsize=5, height_fraction=0.25, elinewidth=2 +): """ Plot upper limits as arrows with cap at value of upper limit. 
uplim behaviour has been fixed in matplotlib 1.4 """ ax.errorbar( - x, y, xerr=xerr, ls='', color=color, elinewidth=elinewidth, capsize=0) + x, y, xerr=xerr, ls="", color=color, elinewidth=elinewidth, capsize=0 + ) from distutils.version import LooseVersion import matplotlib + mpl_version = LooseVersion(matplotlib.__version__) - if mpl_version >= LooseVersion('1.4.0'): + if mpl_version >= LooseVersion("1.4.0"): ax.errorbar( x, y, yerr=height_fraction * y, - ls='', + ls="", uplims=True, color=color, elinewidth=elinewidth, capsize=capsize, - zorder=10) + zorder=10, + ) else: ax.errorbar( - x, (1 - height_fraction) * y, + x, + (1 - height_fraction) * y, yerr=height_fraction * y, - ls='', + ls="", lolims=True, color=color, elinewidth=elinewidth, capsize=capsize, - zorder=10) - - -def _plot_data_to_ax(data_all, - ax1, - e_unit=None, - sed=True, - ylabel=None, - ulim_opts={}, - errorbar_opts={}): + zorder=10, + ) + + +def _plot_data_to_ax( + data_all, + ax1, + e_unit=None, + sed=True, + ylabel=None, + ulim_opts={}, + errorbar_opts={}, +): """ Plots data errorbars and upper limits onto ax. X label is left to plot_data and plot_fit because they depend on whether residuals are plotted. """ if e_unit is None: - e_unit = data_all['energy'].unit + e_unit = data_all["energy"].unit - f_unit, sedf = sed_conversion(data_all['energy'], data_all['flux'].unit, - sed) + f_unit, sedf = sed_conversion( + data_all["energy"], data_all["flux"].unit, sed + ) - if 'group' not in data_all.keys(): - data_all['group'] = np.zeros(len(data_all)) + if "group" not in data_all.keys(): + data_all["group"] = np.zeros(len(data_all)) - groups = np.unique(data_all['group']) + groups = np.unique(data_all["group"]) for g in groups: - data = data_all[np.where(data_all['group'] == g)] - _, sedfg = sed_conversion(data['energy'], data['flux'].unit, sed) + data = data_all[np.where(data_all["group"] == g)] + _, sedfg = sed_conversion(data["energy"], data["flux"].unit, sed) # wrap around color and marker cycles color = color_cycle[int(g) % len(color_cycle)] marker = marker_cycle[int(g) % len(marker_cycle)] - ul = data['ul'] + ul = data["ul"] notul = ~ul # Hack to show y errors compatible with 0 in loglog plot - yerr_lo = data['flux_error_lo'][notul] - y = data['flux'][notul].to(yerr_lo.unit) - bad_err = np.where((y - yerr_lo) <= 0.) - yerr_lo[bad_err] = y[bad_err] * (1. 
- 1e-7) - yerr = u.Quantity((yerr_lo, data['flux_error_hi'][notul])) - xerr = u.Quantity((data['energy_error_lo'], data['energy_error_hi'])) + yerr_lo = data["flux_error_lo"][notul] + y = data["flux"][notul].to(yerr_lo.unit) + bad_err = np.where((y - yerr_lo) <= 0.0) + yerr_lo[bad_err] = y[bad_err] * (1.0 - 1e-7) + yerr = u.Quantity((yerr_lo, data["flux_error_hi"][notul])) + xerr = u.Quantity((data["energy_error_lo"], data["energy_error_hi"])) opts = dict( zorder=100, marker=marker, - ls='', + ls="", elinewidth=2, capsize=0, mec=color, mew=0.1, ms=5, - color=color) + color=color, + ) opts.update(**errorbar_opts) ax1.errorbar( - data['energy'][notul].to(e_unit).value, - (data['flux'][notul] * sedfg[notul]).to(f_unit).value, + data["energy"][notul].to(e_unit).value, + (data["flux"][notul] * sedfg[notul]).to(f_unit).value, yerr=(yerr * sedfg[notul]).to(f_unit).value, xerr=xerr[:, notul].to(e_unit).value, - **opts) + **opts + ) if np.any(ul): - if 'elinewidth' in errorbar_opts: - ulim_opts['elinewidth'] = errorbar_opts['elinewidth'] - - _plot_ulims(ax1, data['energy'][ul].to(e_unit).value, - (data['flux'][ul] * sedfg[ul]).to(f_unit).value, - (xerr[:, ul]).to(e_unit).value, color, **ulim_opts) + if "elinewidth" in errorbar_opts: + ulim_opts["elinewidth"] = errorbar_opts["elinewidth"] + + _plot_ulims( + ax1, + data["energy"][ul].to(e_unit).value, + (data["flux"][ul] * sedfg[ul]).to(f_unit).value, + (xerr[:, ul]).to(e_unit).value, + color, + **ulim_opts + ) - ax1.set_xscale('log') - ax1.set_yscale('log') - xmin = 10**np.floor( + ax1.set_xscale("log") + ax1.set_yscale("log") + xmin = 10 ** np.floor( np.log10( - np.min(data['energy'] - data['energy_error_lo']).to(e_unit).value)) - xmax = 10**np.ceil( + np.min(data["energy"] - data["energy_error_lo"]).to(e_unit).value + ) + ) + xmax = 10 ** np.ceil( np.log10( - np.max(data['energy'] + data['energy_error_hi']).to(e_unit).value)) + np.max(data["energy"] + data["energy_error_hi"]).to(e_unit).value + ) + ) ax1.set_xlim(xmin, xmax) # avoid autoscaling to errorbars to 0 - notul = ~data_all['ul'] - if np.any(data_all['flux_error_lo'][notul] >= data_all['flux'][notul]): - elo = ( - (data_all['flux'][notul] * sedf[notul]).to(f_unit).value - - (data_all['flux_error_lo'][notul] * sedf[notul]).to(f_unit).value) + notul = ~data_all["ul"] + if np.any(data_all["flux_error_lo"][notul] >= data_all["flux"][notul]): + elo = (data_all["flux"][notul] * sedf[notul]).to(f_unit).value - ( + data_all["flux_error_lo"][notul] * sedf[notul] + ).to(f_unit).value gooderr = np.where( - data_all['flux_error_lo'][notul] < data_all['flux'][notul]) - ymin = 10**np.floor(np.log10(np.min(elo[gooderr]))) + data_all["flux_error_lo"][notul] < data_all["flux"][notul] + ) + ymin = 10 ** np.floor(np.log10(np.min(elo[gooderr]))) ax1.set_ylim(bottom=ymin) if ylabel is None: if sed: - ax1.set_ylabel(r'$E^2\mathrm{{d}}N/\mathrm{{d}}E$' - ' [{0}]'.format( - u.Unit(f_unit).to_string('latex_inline'))) + ax1.set_ylabel( + r"$E^2\mathrm{{d}}N/\mathrm{{d}}E$" + " [{0}]".format(u.Unit(f_unit).to_string("latex_inline")) + ) else: - ax1.set_ylabel(r'$\mathrm{{d}}N/\mathrm{{d}}E$' - ' [{0}]'.format( - u.Unit(f_unit).to_string('latex_inline'))) + ax1.set_ylabel( + r"$\mathrm{{d}}N/\mathrm{{d}}E$" + " [{0}]".format(u.Unit(f_unit).to_string("latex_inline")) + ) else: ax1.set_ylabel(ylabel) -def _plot_residuals_to_ax(data_all, - model_ML, - ax, - e_unit=u.eV, - sed=True, - errorbar_opts={}): +def _plot_residuals_to_ax( + data_all, model_ML, ax, e_unit=u.eV, sed=True, errorbar_opts={} +): """Function to 
compute and plot residuals in units of the uncertainty""" - if 'group' not in data_all.keys(): - data_all['group'] = np.zeros(len(data_all)) + if "group" not in data_all.keys(): + data_all["group"] = np.zeros(len(data_all)) - groups = np.unique(data_all['group']) + groups = np.unique(data_all["group"]) MLf_unit, MLsedf = sed_conversion(model_ML[0], model_ML[1].unit, sed) MLene = model_ML[0].to(e_unit) MLflux = (model_ML[1] * MLsedf).to(MLf_unit) - ax.axhline(0, color='k', lw=1, ls='--') + ax.axhline(0, color="k", lw=1, ls="--") interp = False - if (data_all['energy'].size != MLene.size or - not np.allclose(data_all['energy'].value, MLene.value)): + if data_all["energy"].size != MLene.size or not np.allclose( + data_all["energy"].value, MLene.value + ): interp = True from scipy.interpolate import interp1d + modelfunc = interp1d(MLene.value, MLflux.value, bounds_error=False) for g in groups: - groupidx = np.where(data_all['group'] == g) + groupidx = np.where(data_all["group"] == g) data = data_all[groupidx] - notul = ~data['ul'] - df_unit, dsedf = sed_conversion(data['energy'], data['flux'].unit, sed) - ene = data['energy'].to(e_unit) - xerr = u.Quantity((data['energy_error_lo'], data['energy_error_hi'])) - flux = (data['flux'] * dsedf).to(df_unit) - dflux = (data['flux_error_lo'] + data['flux_error_hi']) / 2. + notul = ~data["ul"] + df_unit, dsedf = sed_conversion(data["energy"], data["flux"].unit, sed) + ene = data["energy"].to(e_unit) + xerr = u.Quantity((data["energy_error_lo"], data["energy_error_hi"])) + flux = (data["flux"] * dsedf).to(df_unit) + dflux = (data["flux_error_lo"] + data["flux_error_hi"]) / 2.0 dflux = (dflux * dsedf).to(df_unit)[notul] if interp: @@ -1152,38 +1249,44 @@ def _plot_residuals_to_ax(data_all, opts = dict( zorder=100, marker=marker, - ls='', + ls="", elinewidth=2, capsize=0, mec=color, mew=0.1, ms=6, - color=color) + color=color, + ) opts.update(errorbar_opts) ax.errorbar( - ene[notul].value, (difference / dflux).decompose().value, + ene[notul].value, + (difference / dflux).decompose().value, yerr=(dflux / dflux).decompose().value, xerr=xerr[:, notul].to(e_unit).value, - **opts) + **opts + ) from matplotlib.ticker import MaxNLocator - ax.yaxis.set_major_locator( - MaxNLocator( - 5, integer='True', prune='upper', symmetric=True)) - - ax.set_ylabel(r'$\Delta\sigma$') - ax.set_xscale('log') - -def plot_data(input_data, - xlabel=None, - ylabel=None, - sed=True, - figure=None, - e_unit=None, - ulim_opts={}, - errorbar_opts={}): + ax.yaxis.set_major_locator( + MaxNLocator(5, integer="True", prune="upper", symmetric=True) + ) + + ax.set_ylabel(r"$\Delta\sigma$") + ax.set_xscale("log") + + +def plot_data( + input_data, + xlabel=None, + ylabel=None, + sed=True, + figure=None, + e_unit=None, + ulim_opts={}, + errorbar_opts={}, +): """ Plot spectral data. 
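
As a usage reference for the plotting entry points being reformatted here, a
minimal sketch of ``plot_data``; the data file name below is illustrative
only, not part of this patch:

    import astropy.units as u
    import naima
    from astropy.io import ascii

    # read a spectral flux table and plot it as an SED (hypothetical file)
    data = ascii.read("RXJ1713_HESS_2007.dat")
    fig = naima.plot_data(data, sed=True, e_unit=u.TeV)
    fig.savefig("data_sed.png")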

@@ -1216,12 +1319,12 @@ def plot_data(input_data,
     try:
         data = validate_data_table(input_data)
     except TypeError:
-        if hasattr(input_data, 'data'):
+        if hasattr(input_data, "data"):
             data = input_data.data
-        elif isinstance(input_data, dict) and 'energy' in input_data.keys():
+        elif isinstance(input_data, dict) and "energy" in input_data.keys():
             data = input_data
         else:
-            log.warning('input_data format not know, no plotting data!')
+            log.warning("input_data format not known, no plotting data!")
             return None
 
     if figure is None:
@@ -1236,14 +1339,14 @@ def plot_data(input_data,
 
     # try to get units from previous plot in figure
    try:
-        old_e_unit = u.Unit(ax1.get_xlabel().split('[')[-1].split(']')[0])
+        old_e_unit = u.Unit(ax1.get_xlabel().split("[")[-1].split("]")[0])
     except ValueError:
-        old_e_unit = u.Unit('')
+        old_e_unit = u.Unit("")
 
-    if e_unit is None and old_e_unit.physical_type == 'energy':
+    if e_unit is None and old_e_unit.physical_type == "energy":
         e_unit = old_e_unit
     elif e_unit is None:
-        e_unit = data['energy'].unit
+        e_unit = data["energy"].unit
 
     _plot_data_to_ax(
         data,
@@ -1252,13 +1355,16 @@ def plot_data(input_data,
         sed=sed,
         ylabel=ylabel,
         ulim_opts=ulim_opts,
-        errorbar_opts=errorbar_opts)
+        errorbar_opts=errorbar_opts,
+    )
 
     if xlabel is not None:
         ax1.set_xlabel(xlabel)
-    elif xlabel is None and ax1.get_xlabel() == '':
-        ax1.set_xlabel(r'$\mathrm{Energy}$' + ' [{0}]'.format(
-            e_unit.to_string('latex_inline')))
+    elif xlabel is None and ax1.get_xlabel() == "":
+        ax1.set_xlabel(
+            r"$\mathrm{Energy}$"
+            + " [{0}]".format(e_unit.to_string("latex_inline"))
+        )
 
     ax1.autoscale()
 
@@ -1278,23 +1384,26 @@ def plot_distribution(samples, label, figure=None):
 
     if isinstance(samples[0], u.Quantity):
         unit = samples[0].unit
     else:
-        unit = ''
+        unit = ""
 
     if isinstance(std, u.Quantity):
         std = std.value
 
-    dist_props = '{label} distribution properties:\n \
+    dist_props = "{label} distribution properties:\n \
$-$ median: ${median}$ {unit}, std: ${std}$ {unit}\n \
$-$ Median with uncertainties based on \n \
the 16th and 84th percentiles ($\sim$1$\sigma$):\n\
-            {label} = {value_error} {unit}'.format(
+            {label} = {value_error} {unit}".format(
         label=label,
         median=_latex_float(quantiles[50]),
         std=_latex_float(std),
-        value_error=_latex_value_error(quantiles[50],
-                                       quantiles[50] - quantiles[16],
-                                       quantiles[84] - quantiles[50]),
-        unit=unit)
+        value_error=_latex_value_error(
+            quantiles[50],
+            quantiles[50] - quantiles[16],
+            quantiles[84] - quantiles[50],
+        ),
+        unit=unit,
+    )
 
     if figure is None:
         f = plt.figure()
@@ -1304,47 +1413,50 @@ def plot_distribution(samples, label, figure=None):
         ax = f.add_subplot(111)
 
     f.subplots_adjust(bottom=0.40, top=0.93, left=0.06, right=0.95)
-    f.text(0.2, 0.27, dist_props, ha='left', va='top')
 
-    histnbins = min(max(25, int(len(samples) / 100.)), 100)
-    xlabel = '' if label is None else label
+    f.text(0.2, 0.27, dist_props, ha="left", va="top")
+
+    histnbins = min(max(25, int(len(samples) / 100.0)), 100)
+    xlabel = "" if label is None else label
     n, x, _ = ax.hist(
         samples,
         histnbins,
-        histtype='stepfilled',
+        histtype="stepfilled",
         color=color_cycle[0],
         lw=0,
-        normed=1)
+        normed=1,
+    )
     if isinstance(samples, u.Quantity):
         samples_nounit = samples.value
     else:
         samples_nounit = samples
     kde = stats.kde.gaussian_kde(samples_nounit)
-    ax.plot(x, kde(x), color='k', label='KDE')
 
+    ax.plot(x, kde(x), color="k", label="KDE")
 
     ax.axvline(
         quantiles[50],
-        ls='--',
-        color='k',
+        ls="--",
+        color="k",
         alpha=0.5,
         lw=2,
-        label='50% quantile')
+        label="50% quantile",
+    )
     ax.axvspan(
quantiles[16], quantiles[84], color=(0.5,) * 3, alpha=0.25, - label='68% CI', - lw=0) + label="68% CI", + lw=0, + ) # ax.legend() for l in ax.get_xticklabels(): l.set_rotation(45) # [l.set_rotation(45) for l in ax.get_yticklabels()] - if unit != '': - xlabel += ' [{0}]'.format(unit) + if unit != "": + xlabel += " [{0}]".format(unit) ax.set_xlabel(xlabel) - ax.set_title('posterior distribution of {0}'.format(label)) + ax.set_title("posterior distribution of {0}".format(label)) ax.set_ylim(top=n.max() * 1.05) return f @@ -1368,8 +1480,9 @@ def plot_corner(sampler, show_ML=True, **kwargs): the 2D histograms. """ import matplotlib.pyplot as plt - oldlw = plt.rcParams['lines.linewidth'] - plt.rcParams['lines.linewidth'] = 0.7 + + oldlw = plt.rcParams["lines.linewidth"] + plt.rcParams["lines.linewidth"] = 0.7 try: from corner import corner @@ -1379,21 +1492,22 @@ def plot_corner(sampler, show_ML=True, **kwargs): MLp = None corner_opts = { - 'labels': sampler.labels, - 'truths': MLp, - 'quantiles': [0.16, 0.5, 0.84], - 'verbose': False, - 'truth_color': color_cycle[0], + "labels": sampler.labels, + "truths": MLp, + "quantiles": [0.16, 0.5, 0.84], + "verbose": False, + "truth_color": color_cycle[0], } corner_opts.update(kwargs) f = corner(sampler.flatchain, **corner_opts) except ImportError: - log.warning('The corner package is not installed;' - ' corner plot not available') + log.warning( + "The corner package is not installed;" " corner plot not available" + ) f = None - plt.rcParams['lines.linewidth'] = oldlw + plt.rcParams["lines.linewidth"] = oldlw return f diff --git a/naima/radiative.py b/naima/radiative.py index 67f565ed..575351b7 100644 --- a/naima/radiative.py +++ b/naima/radiative.py @@ -1,14 +1,21 @@ # -*- coding: utf-8 -*- # Licensed under a 3-clause BSD style license - see LICENSE.rst -from __future__ import (absolute_import, division, print_function, - unicode_literals) +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) import numpy as np from numtraits import NumericalTrait from traitlets import HasTraits, Int, observe -from .extern.validator import (validate_scalar, validate_array, - validate_physical_type) +from .extern.validator import ( + validate_scalar, + validate_array, + validate_physical_type, +) from .utils import trapz_loglog from .model_utils import memoize @@ -22,25 +29,29 @@ # Constants and units from astropy import units as u + # import constant values from astropy.constants from astropy.constants import c, m_e, hbar, sigma_sb, e, m_p, alpha __all__ = [ - 'Synchrotron', 'InverseCompton', 'PionDecay', 'Bremsstrahlung', - 'PionDecayKelner06' + "Synchrotron", + "InverseCompton", + "PionDecay", + "Bremsstrahlung", + "PionDecayKelner06", ] # Get a new logger to avoid changing the level of the astropy logger -log = logging.getLogger('naima.radiative') +log = logging.getLogger("naima.radiative") log.setLevel(logging.INFO) e = e.gauss -mec2 = (m_e * c**2).cgs +mec2 = (m_e * c ** 2).cgs mec2_unit = u.Unit(mec2) -ar = (4 * sigma_sb / c).to('erg/(cm3 K4)') -r0 = (e**2 / mec2).to('cm') +ar = (4 * sigma_sb / c).to("erg/(cm3 K4)") +r0 = (e ** 2 / mec2).to("cm") def _validate_ene(ene): @@ -49,13 +60,14 @@ def _validate_ene(ene): if isinstance(ene, dict) or isinstance(ene, Table): try: ene = validate_array( - 'energy', u.Quantity(ene['energy']), physical_type='energy') + "energy", u.Quantity(ene["energy"]), physical_type="energy" + ) except KeyError: - raise TypeError('Table or dict does not have \'energy\' column') + raise 
TypeError("Table or dict does not have 'energy' column") else: if not isinstance(ene, u.Quantity): ene = u.Quantity(ene) - validate_physical_type('energy', ene, physical_type='energy') + validate_physical_type("energy", ene, physical_type="energy") return ene @@ -74,16 +86,18 @@ def __init__(self, particle_distribution): # the particle distribution is a function from naima.models pd = self.particle_distribution.amplitude validate_physical_type( - 'Particle distribution', + "Particle distribution", pd, - physical_type='differential energy') + physical_type="differential energy", + ) except (AttributeError, TypeError): # otherwise check the output pd = self.particle_distribution([0.1, 1, 10] * u.TeV) validate_physical_type( - 'Particle distribution', + "Particle distribution", pd, - physical_type='differential energy') + physical_type="differential energy", + ) def _spectrum(self, photon_energy): """ @@ -114,11 +128,12 @@ def flux(self, photon_energy, distance=1 * u.kpc): if distance != 0: distance = validate_scalar( - 'distance', distance, physical_type='length') - spec /= 4 * np.pi * distance.to('cm')**2 - out_unit = '1/(s cm2 eV)' + "distance", distance, physical_type="length" + ) + spec /= 4 * np.pi * distance.to("cm") ** 2 + out_unit = "1/(s cm2 eV)" else: - out_unit = '1/(s eV)' + out_unit = "1/(s eV)" return spec.to(out_unit) @@ -135,14 +150,15 @@ def sed(self, photon_energy, distance=1 * u.kpc): be returned. Default is 1 kpc. """ if distance != 0: - out_unit = 'erg/(cm2 s)' + out_unit = "erg/(cm2 s)" else: - out_unit = 'erg/s' + out_unit = "erg/s" photon_energy = _validate_ene(photon_energy) - sed = (self.flux(photon_energy, distance) * photon_energy ** 2.).to( - out_unit) + sed = (self.flux(photon_energy, distance) * photon_energy ** 2.0).to( + out_unit + ) return sed @@ -155,9 +171,9 @@ class BaseLorentzFactor(BaseRadiative): def __init__(self, particle_distribution, mass): super(BaseLorentzFactor, self).__init__(particle_distribution) - self.param_names = ['gmin', 'gmax', 'ngd'] - mass = validate_scalar('mass', mass, physical_type='mass') - self.mc2 = (mass * c**2).cgs + self.param_names = ["gmin", "gmax", "ngd"] + mass = validate_scalar("mass", mass, physical_type="mass") + self.mc2 = (mass * c ** 2).cgs self.mc2_unit = u.Unit(self.mc2) self._memoize = True self._cache = {} @@ -169,8 +185,9 @@ def _gam(self): """ log10gmin = np.log10(self.gmin) log10gmax = np.log10(self.gmax) - return np.logspace(log10gmin, log10gmax, - self.ngd * (log10gmax - log10gmin)) + return np.logspace( + log10gmin, log10gmax, self.ngd * (log10gmax - log10gmin) + ) @property def _npart(self): @@ -207,8 +224,9 @@ def _compute_W(self, Emin=None, Emax=None): log10gmin = np.log10(Emin / self.mc2).value log10gmax = np.log10(Emax / self.mc2).value - gam = np.logspace(log10gmin, log10gmax, - self.ngd * (log10gmax - log10gmin)) + gam = np.logspace( + log10gmin, log10gmax, self.ngd * (log10gmax - log10gmin) + ) pd = self.particle_distribution(gam * self.mc2) npart = pd.to(1 / self.mc2_unit).value @@ -238,7 +256,7 @@ def _set_W(self, W, Emin=None, Emax=None, amplitude_name=None): Defaults to ``amplitude``. 
""" - W = validate_scalar('W', W, physical_type='energy') + W = validate_scalar("W", W, physical_type="energy") oldW = self._compute_W(Emin=Emin, Emax=Emax) factor = (W / oldW).decompose() @@ -247,14 +265,15 @@ def _set_W(self, W, Emin=None, Emax=None, amplitude_name=None): self.particle_distribution.amplitude *= factor except AttributeError: log.error( - 'The particle distribution does not have an attribute' - ' called amplitude to modify its normalization: you can' - ' set the name with the amplitude_name parameter of set_W' + "The particle distribution does not have an attribute" + " called amplitude to modify its normalization: you can" + " set the name with the amplitude_name parameter of set_W" ) else: oldampl = getattr(self.particle_distribution, amplitude_name) - setattr(self.particle_distribution, amplitude_name, - oldampl * factor) + setattr( + self.particle_distribution, amplitude_name, oldampl * factor + ) class BaseElectron(BaseLorentzFactor, HasTraits): @@ -269,17 +288,17 @@ def __init__(self, particle_distribution): Eemax = NumericalTrait(convertible_to=u.erg) nEed = Int() - @observe('Eemin') + @observe("Eemin") def _handle_Eemin(self, change): - self.gmin = float(change['new'] / self.mc2) + self.gmin = float(change["new"] / self.mc2) - @observe('Eemax') + @observe("Eemax") def _handle_Eemax(self, change): - self.gmax = float(change['new'] / self.mc2) + self.gmax = float(change["new"] / self.mc2) - @observe('nEed') + @observe("nEed") def _handle_nEed(self, change): - self.ngd = change['new'] + self.ngd = change["new"] @property def We(self): @@ -307,8 +326,9 @@ def set_We(self, We, Eemin=None, Eemax=None, amplitude_name=None): must be accesible as an attribute of the distribution function. Defaults to ``amplitude``. """ - return self._set_W(We, Emin=Eemin, Emax=Eemax, - amplitude_name=amplitude_name) + return self._set_W( + We, Emin=Eemin, Emax=Eemax, amplitude_name=amplitude_name + ) def compute_We(self, Eemin=None, Eemax=None): """ Total energy in electrons between energies Emin and Emax @@ -330,24 +350,25 @@ class BaseLorentzProton(BaseLorentzFactor, HasTraits): """ def __init__(self, particle_distribution): - super(BaseLorentzProton, self).__init__(particle_distribution, - mass=m_p) + super(BaseLorentzProton, self).__init__( + particle_distribution, mass=m_p + ) Epmin = NumericalTrait(convertible_to=u.erg) Epmax = NumericalTrait(convertible_to=u.erg) nEpd = Int() - @observe('Epmin') + @observe("Epmin") def _handle_Epmin(self, change): - self.gmin = float(change['new'] / self.mc2) + self.gmin = float(change["new"] / self.mc2) - @observe('Epmax') + @observe("Epmax") def _handle_Epmax(self, change): - self.gmax = float(change['new'] / self.mc2) + self.gmax = float(change["new"] / self.mc2) - @observe('nEpd') + @observe("nEpd") def _handle_nEpd(self, change): - self.ngd = change['new'] + self.ngd = change["new"] @property def Wp(self): @@ -375,8 +396,9 @@ def set_Wp(self, Wp, Epmin=None, Epmax=None, amplitude_name=None): must be accesible as an attribute of the distribution function. Defaults to ``amplitude``. 
""" - return self._set_W(Wp, Emin=Epmin, Emax=Epmax, - amplitude_name=amplitude_name) + return self._set_W( + Wp, Emin=Epmin, Emax=Epmax, amplitude_name=amplitude_name + ) def compute_Wp(self, Epmin=None, Epmax=None): """ Total energy in protons between energies Emin and Emax @@ -393,7 +415,6 @@ def compute_Wp(self, Epmin=None, Epmax=None): class BaseSynchrotron(BaseLorentzFactor): - def _spectrum(self, photon_energy): """Compute intrinsic synchrotron differential spectrum for energies in ``photon_energy`` @@ -420,35 +441,50 @@ def Gtilde(x): Invoking crbt only once reduced time by ~40% """ cb = cbrt(x) - gt1 = 1.808 * cb / np.sqrt(1 + 3.4 * cb**2.) - gt2 = 1 + 2.210 * cb**2. + 0.347 * cb**4. - gt3 = 1 + 1.353 * cb**2. + 0.217 * cb**4. + gt1 = 1.808 * cb / np.sqrt(1 + 3.4 * cb ** 2.0) + gt2 = 1 + 2.210 * cb ** 2.0 + 0.347 * cb ** 4.0 + gt3 = 1 + 1.353 * cb ** 2.0 + 0.217 * cb ** 4.0 return gt1 * (gt2 / gt3) * np.exp(-x) - log.debug('calc_sy: Starting synchrotron computation with AKB2010...') + log.debug("calc_sy: Starting synchrotron computation with AKB2010...") # strip units, ensuring correct conversion # astropy units do not convert correctly for gyroradius calculation # when using cgs (SI is fine, see # https://github.com/astropy/astropy/issues/1687) - CS1_0 = np.sqrt(3) * e.value**3 * self.B.to('G').value - CS1_1 = (2 * np.pi * self.mc2.cgs.value - * hbar.cgs.value * outspecene.to('erg').value) + CS1_0 = np.sqrt(3) * e.value ** 3 * self.B.to("G").value + CS1_1 = ( + 2 + * np.pi + * self.mc2.cgs.value + * hbar.cgs.value + * outspecene.to("erg").value + ) CS1 = CS1_0 / CS1_1 # Critical energy, erg - Ec = 3 * e.value * hbar.cgs.value * self.B.to('G').value * self._gam**2 + Ec = ( + 3 + * e.value + * hbar.cgs.value + * self.B.to("G").value + * self._gam ** 2 + ) Ec /= 2 * (self.mc2 / c).cgs.value - EgEc = outspecene.to('erg').value / np.vstack(Ec) + EgEc = outspecene.to("erg").value / np.vstack(Ec) dNdE = CS1 * Gtilde(EgEc) # return units - spec = trapz_loglog( - np.vstack(self._npart) * dNdE, self._gam, axis=0) / u.s / u.erg - spec = spec.to('1/(s eV)') + spec = ( + trapz_loglog(np.vstack(self._npart) * dNdE, self._gam, axis=0) + / u.s + / u.erg + ) + spec = spec.to("1/(s eV)") return spec + class ElectronSynchrotron(BaseElectron, BaseSynchrotron): """Synchrotron emission from an electron population. 
@@ -485,11 +521,11 @@ class ElectronSynchrotron(BaseElectron, BaseSynchrotron): def __init__(self, particle_distribution, B=3.24e-6 * u.G, **kwargs): super(ElectronSynchrotron, self).__init__(particle_distribution) - self.B = validate_scalar('B', B, physical_type='magnetic flux density') + self.B = validate_scalar("B", B, physical_type="magnetic flux density") self.Eemin = 1 * u.GeV self.Eemax = (1e9 * m_e * c ** 2).to(u.TeV) self.nEed = 100 - self.param_names += ['B'] + self.param_names += ["B"] for key, value in kwargs.items(): setattr(self, key, value) @@ -533,11 +569,11 @@ class ProtonSynchrotron(BaseLorentzProton, BaseSynchrotron): def __init__(self, particle_distribution, B=3.24e-6 * u.G, **kwargs): super(ProtonSynchrotron, self).__init__(particle_distribution) - self.B = validate_scalar('B', B, physical_type='magnetic flux density') + self.B = validate_scalar("B", B, physical_type="magnetic flux density") self.Epmin = 1 * u.GeV self.Epmax = 1 * u.PeV self.nEpd = 100 - self.param_names += ['B'] + self.param_names += ["B"] for key, value in kwargs.items(): setattr(self, key, value) @@ -547,10 +583,10 @@ def G12(x, a): Eqs 20, 24, 25 of Khangulyan et al (2014) """ alpha, a, beta, b = a - pi26 = np.pi**2 / 6.0 + pi26 = np.pi ** 2 / 6.0 G = (pi26 + x) * np.exp(-x) - tmp = 1 + b * x**beta - g = 1. / (a * x**alpha / tmp + 1.) + tmp = 1 + b * x ** beta + g = 1.0 / (a * x ** alpha / tmp + 1.0) return G * g @@ -559,11 +595,11 @@ def G34(x, a): Eqs 20, 24, 25 of Khangulyan et al (2014) """ alpha, a, beta, b, c = a - pi26 = np.pi**2 / 6.0 + pi26 = np.pi ** 2 / 6.0 tmp = (1 + c * x) / (1 + pi26 * c * x) G = pi26 * tmp * np.exp(-x) - tmp = 1 + b * x**beta - g = 1. / (a * x**alpha / tmp + 1.) + tmp = 1 + b * x ** beta + g = 1.0 / (a * x ** alpha / tmp + 1.0) return G * g @@ -620,16 +656,15 @@ class InverseCompton(BaseElectron): distribution arrays. Default is 300. 
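+
+    Examples
+    --------
+    A minimal sketch; the electron distribution parameters below are
+    illustrative, not part of this patch:
+
+    >>> import astropy.units as u
+    >>> from naima.models import ExponentialCutoffPowerLaw, InverseCompton
+    >>> electrons = ExponentialCutoffPowerLaw(
+    ...     1e36 / u.eV, 1 * u.TeV, 2.5, 50 * u.TeV
+    ... )
+    >>> ic = InverseCompton(electrons, seed_photon_fields=["CMB", "FIR"])
+    >>> sed = ic.sed([0.1, 1, 10] * u.TeV, distance=1.5 * u.kpc)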
""" - def __init__(self, - particle_distribution, - seed_photon_fields=['CMB'], - **kwargs): + def __init__( + self, particle_distribution, seed_photon_fields=["CMB"], **kwargs + ): super(InverseCompton, self).__init__(particle_distribution) self.seed_photon_fields = self._process_input_seed(seed_photon_fields) self.Eemin = 1 * u.GeV self.Eemax = 1e9 * mec2 self.nEed = 100 - self.param_names += ['seed_photon_fields'] + self.param_names += ["seed_photon_fields"] for key, value in kwargs.items(): setattr(self, key, value) @@ -641,14 +676,14 @@ def _process_input_seed(seed_photon_fields): Tcmb = 2.72548 * u.K # 0.00057 K Tfir = 30 * u.K - ufir = 0.5 * u.eV / u.cm**3 + ufir = 0.5 * u.eV / u.cm ** 3 Tnir = 3000 * u.K - unir = 1.0 * u.eV / u.cm**3 + unir = 1.0 * u.eV / u.cm ** 3 # Allow for seed_photon_fields definitions of the type 'CMB-NIR-FIR' or # 'CMB' if type(seed_photon_fields) != list: - seed_photon_fields = seed_photon_fields.split('-') + seed_photon_fields = seed_photon_fields.split("-") result = OrderedDict() @@ -656,93 +691,105 @@ def _process_input_seed(seed_photon_fields): seed = {} if isinstance(inseed, six.string_types): name = inseed - seed['type'] = 'thermal' - if inseed == 'CMB': - seed['T'] = Tcmb - seed['u'] = ar * Tcmb**4 - seed['isotropic'] = True - elif inseed == 'FIR': - seed['T'] = Tfir - seed['u'] = ufir - seed['isotropic'] = True - elif inseed == 'NIR': - seed['T'] = Tnir - seed['u'] = unir - seed['isotropic'] = True + seed["type"] = "thermal" + if inseed == "CMB": + seed["T"] = Tcmb + seed["u"] = ar * Tcmb ** 4 + seed["isotropic"] = True + elif inseed == "FIR": + seed["T"] = Tfir + seed["u"] = ufir + seed["isotropic"] = True + elif inseed == "NIR": + seed["T"] = Tnir + seed["u"] = unir + seed["isotropic"] = True else: - log.warning('Will not use seed {0} because it is not ' - 'CMB, FIR or NIR'.format(inseed)) + log.warning( + "Will not use seed {0} because it is not " + "CMB, FIR or NIR".format(inseed) + ) raise TypeError - elif type(inseed) == list and (len(inseed) == 3 or - len(inseed) == 4): + elif type(inseed) == list and ( + len(inseed) == 3 or len(inseed) == 4 + ): isotropic = len(inseed) == 3 if isotropic: name, T, uu = inseed - seed['isotropic'] = True + seed["isotropic"] = True else: name, T, uu, theta = inseed - seed['isotropic'] = False - seed['theta'] = validate_scalar( - '{0}-theta'.format(name), theta, physical_type='angle') + seed["isotropic"] = False + seed["theta"] = validate_scalar( + "{0}-theta".format(name), theta, physical_type="angle" + ) - thermal = T.unit.physical_type == 'temperature' + thermal = T.unit.physical_type == "temperature" if thermal: - seed['type'] = 'thermal' + seed["type"] = "thermal" validate_scalar( - '{0}-T'.format(name), + "{0}-T".format(name), T, - domain='positive', - physical_type='temperature') - seed['T'] = T + domain="positive", + physical_type="temperature", + ) + seed["T"] = T if uu == 0: - seed['u'] = ar * T**4 + seed["u"] = ar * T ** 4 else: # pressure has same physical type as energy density validate_scalar( - '{0}-u'.format(name), + "{0}-u".format(name), uu, - domain='positive', - physical_type='pressure') - seed['u'] = uu + domain="positive", + physical_type="pressure", + ) + seed["u"] = uu else: - seed['type'] = 'array' + seed["type"] = "array" # Ensure everything is in arrays T = u.Quantity((T,)).flatten() uu = u.Quantity((uu,)).flatten() - seed['energy'] = validate_array( - '{0}-energy'.format(name), + seed["energy"] = validate_array( + "{0}-energy".format(name), T, - domain='positive', - 
physical_type='energy') + domain="positive", + physical_type="energy", + ) - if np.isscalar(seed['energy']) or seed['energy'].size == 1: - seed['photon_density'] = validate_scalar( - '{0}-density'.format(name), + if np.isscalar(seed["energy"]) or seed["energy"].size == 1: + seed["photon_density"] = validate_scalar( + "{0}-density".format(name), uu, - domain='positive', - physical_type='pressure') + domain="positive", + physical_type="pressure", + ) else: - if uu.unit.physical_type == 'pressure': - uu /= seed['energy']**2 - seed['photon_density'] = validate_array( - '{0}-density'.format(name), + if uu.unit.physical_type == "pressure": + uu /= seed["energy"] ** 2 + seed["photon_density"] = validate_array( + "{0}-density".format(name), uu, - domain='positive', - physical_type='differential number density') + domain="positive", + physical_type="differential number density", + ) else: - raise TypeError('Unable to process seed photon' - ' field: {0}'.format(inseed)) + raise TypeError( + "Unable to process seed photon" + " field: {0}".format(inseed) + ) result[name] = seed return result @staticmethod - def _iso_ic_on_planck(electron_energy, soft_photon_temperature, - gamma_energy): + def _iso_ic_on_planck( + electron_energy, soft_photon_temperature, gamma_energy + ): """ IC cross-section for isotropic interaction with a blackbody photon spectrum following Eq. 14 of Khangulyan, Aharonian, and Kelner 2014, @@ -759,20 +806,21 @@ def _iso_ic_on_planck(electron_energy, soft_photon_temperature, a3 = [0.606, 0.443, 1.481, 0.540, 0.319] a4 = [0.461, 0.726, 1.457, 0.382, 6.620] z = gamma_energy / electron_energy - x = z / (1 - z) / (4. * electron_energy * soft_photon_temperature) + x = z / (1 - z) / (4.0 * electron_energy * soft_photon_temperature) # Eq. 14 - cross_section = z**2 / (2 * (1 - z)) * G34(x, a3) + G34(x, a4) - tmp = (soft_photon_temperature / electron_energy)**2 + cross_section = z ** 2 / (2 * (1 - z)) * G34(x, a3) + G34(x, a4) + tmp = (soft_photon_temperature / electron_energy) ** 2 # r0 = (e**2 / m_e / c**2).to('cm') # (2 * r0 ** 2 * m_e ** 3 * c ** 4 / (pi * hbar ** 3)).cgs - tmp *= 2.6318735743809104e+16 + tmp *= 2.6318735743809104e16 cross_section = tmp * cross_section - cc = ((gamma_energy < electron_energy) * (electron_energy > 1)) + cc = (gamma_energy < electron_energy) * (electron_energy > 1) return np.where(cc, cross_section, np.zeros_like(cross_section)) @staticmethod - def _ani_ic_on_planck(electron_energy, soft_photon_temperature, - gamma_energy, theta): + def _ani_ic_on_planck( + electron_energy, soft_photon_temperature, gamma_energy, theta + ): """ IC cross-section for anisotropic interaction with a blackbody photon spectrum following Eq. 11 of Khangulyan, Aharonian, and Kelner 2014, @@ -790,22 +838,27 @@ def _ani_ic_on_planck(electron_energy, soft_photon_temperature, a1 = [0.857, 0.153, 1.840, 0.254] a2 = [0.691, 1.330, 1.668, 0.534] z = gamma_energy / electron_energy - ttheta = 2. * electron_energy * soft_photon_temperature * ( - 1. - np.cos(theta)) + ttheta = ( + 2.0 + * electron_energy + * soft_photon_temperature + * (1.0 - np.cos(theta)) + ) x = z / (1 - z) / ttheta # Eq. 
11 - cross_section = z**2 / (2 * (1 - z)) * G12(x, a1) + G12(x, a2) - tmp = (soft_photon_temperature / electron_energy)**2 + cross_section = z ** 2 / (2 * (1 - z)) * G12(x, a1) + G12(x, a2) + tmp = (soft_photon_temperature / electron_energy) ** 2 # r0 = (e**2 / m_e / c**2).to('cm') # (2 * r0 ** 2 * m_e ** 3 * c ** 4 / (pi * hbar ** 3)).cgs - tmp *= 2.6318735743809104e+16 + tmp *= 2.6318735743809104e16 cross_section = tmp * cross_section - cc = ((gamma_energy < electron_energy) * (electron_energy > 1)) + cc = (gamma_energy < electron_energy) * (electron_energy > 1) return np.where(cc, cross_section, np.zeros_like(cross_section)) @staticmethod - def _iso_ic_on_monochromatic(electron_energy, seed_energy, seed_edensity, - gamma_energy): + def _iso_ic_on_monochromatic( + electron_energy, seed_energy, seed_edensity, gamma_energy + ): """ IC cross-section for an isotropic interaction with a monochromatic photon spectrum following Eq. 22 of Aharonian & Atoyan 1981, Ap&SS 79, @@ -822,19 +875,25 @@ def _iso_ic_on_monochromatic(electron_energy, seed_energy, seed_edensity, b = 4 * photE0 * electron_energy w = gamma_energy / electron_energy q = w / (b * (1 - w)) - fic = (2 * q * np.log(q) + (1 + 2 * q) * (1 - q) + (1. / 2.) * - (b * q)**2 * (1 - q) / (1 + b * q)) - - gamint = (fic * heaviside(1 - q) * - heaviside(q - 1. / (4 * electron_energy**2))) - gamint[np.isnan(gamint)] = 0. + fic = ( + 2 * q * np.log(q) + + (1 + 2 * q) * (1 - q) + + (1.0 / 2.0) * (b * q) ** 2 * (1 - q) / (1 + b * q) + ) + + gamint = ( + fic + * heaviside(1 - q) + * heaviside(q - 1.0 / (4 * electron_energy ** 2)) + ) + gamint[np.isnan(gamint)] = 0.0 if phn.size > 1: - phn = phn.to(1 / (mec2_unit * u.cm**3)).value + phn = phn.to(1 / (mec2_unit * u.cm ** 3)).value gamint = trapz_loglog(gamint * phn / photE0, photE0, axis=0) # 1/s else: - phn = phn.to(mec2_unit / u.cm**3).value - gamint *= phn / photE0**2 + phn = phn.to(mec2_unit / u.cm ** 3).value + gamint *= phn / photE0 ** 2 gamint = gamint.squeeze() # gamint /= mec2.to('erg').value @@ -844,39 +903,47 @@ def _iso_ic_on_monochromatic(electron_energy, seed_energy, seed_edensity, sigt = 6.652458734983284e-25 c = 29979245800.0 - gamint *= (3. / 4.) 
* sigt * c / electron_energy**2 + gamint *= (3.0 / 4.0) * sigt * c / electron_energy ** 2 return gamint def _calc_specic(self, seed, outspecene): - log.debug('_calc_specic: Computing IC on {0} seed photons...'.format( - seed)) + log.debug( + "_calc_specic: Computing IC on {0} seed photons...".format(seed) + ) Eph = (outspecene / mec2).decompose().value # Catch numpy RuntimeWarnings of overflowing exp (which are then # discarded anyway) with warnings.catch_warnings(): warnings.simplefilter("ignore") - if self.seed_photon_fields[seed]['type'] == 'thermal': - T = self.seed_photon_fields[seed]['T'] - uf = (self.seed_photon_fields[seed]['u'] / - (ar * T**4)).decompose() - if self.seed_photon_fields[seed]['isotropic']: - gamint = self._iso_ic_on_planck(self._gam, - T.to('K').value, Eph) + if self.seed_photon_fields[seed]["type"] == "thermal": + T = self.seed_photon_fields[seed]["T"] + uf = ( + self.seed_photon_fields[seed]["u"] / (ar * T ** 4) + ).decompose() + if self.seed_photon_fields[seed]["isotropic"]: + gamint = self._iso_ic_on_planck( + self._gam, T.to("K").value, Eph + ) else: - theta = self.seed_photon_fields[seed]['theta'].to( - 'rad').value + theta = ( + self.seed_photon_fields[seed]["theta"].to("rad").value + ) gamint = self._ani_ic_on_planck( - self._gam, T.to('K').value, Eph, theta) + self._gam, T.to("K").value, Eph, theta + ) else: uf = 1 gamint = self._iso_ic_on_monochromatic( - self._gam, self.seed_photon_fields[seed]['energy'], - self.seed_photon_fields[seed]['photon_density'], Eph) + self._gam, + self.seed_photon_fields[seed]["energy"], + self.seed_photon_fields[seed]["photon_density"], + Eph, + ) lum = uf * Eph * trapz_loglog(self._npart * gamint, self._gam) - lum = lum * u.Unit('1/s') + lum = lum * u.Unit("1/s") return lum / outspecene # return differential spectrum in 1/s/eV @@ -900,7 +967,8 @@ def _spectrum(self, photon_energy): for seed in self.seed_photon_fields: # Call actual computation, detached to allow changes in subclasses self.specic.append( - self._calc_specic(seed, outspecene).to('1/(s eV)')) + self._calc_specic(seed, outspecene).to("1/(s eV)") + ) return np.sum(u.Quantity(self.specic), axis=0) @@ -923,31 +991,35 @@ def flux(self, photon_energy, distance=1 * u.kpc, seed=None): contributions (default). 
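+
+        Examples
+        --------
+        A sketch of per-seed fluxes, assuming an ``ic`` instance defined
+        with "CMB" and "FIR" seed photon fields as in the class example:
+
+        >>> total = ic.flux(1 * u.TeV, distance=2 * u.kpc)
+        >>> cmb_only = ic.flux(1 * u.TeV, distance=2 * u.kpc, seed="CMB")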
""" model = super(InverseCompton, self).flux( - photon_energy, distance=distance) + photon_energy, distance=distance + ) if seed is not None: # Test seed argument if not isinstance(seed, int): if seed not in self.seed_photon_fields: raise ValueError( - 'Provided seed photon field name is not in' - ' the definition of the InverseCompton instance') + "Provided seed photon field name is not in" + " the definition of the InverseCompton instance" + ) else: seed = list(self.seed_photon_fields.keys()).index(seed) elif seed > len(self.seed_photon_fields): raise ValueError( - 'Provided seed photon field number is larger' - ' than the number of seed photon fields defined in the' - ' InverseCompton instance') + "Provided seed photon field number is larger" + " than the number of seed photon fields defined in the" + " InverseCompton instance" + ) if distance != 0: distance = validate_scalar( - 'distance', distance, physical_type='length') - dfac = 4 * np.pi * distance.to('cm')**2 - out_unit = '1/(s cm2 eV)' + "distance", distance, physical_type="length" + ) + dfac = 4 * np.pi * distance.to("cm") ** 2 + out_unit = "1/(s cm2 eV)" else: dfac = 1 - out_unit = '1/(s eV)' + out_unit = "1/(s eV)" model = (self.specic[seed] / dfac).to(out_unit) @@ -974,13 +1046,14 @@ def sed(self, photon_energy, distance=1 * u.kpc, seed=None): if seed is not None: if distance != 0: - out_unit = 'erg/(cm2 s)' + out_unit = "erg/(cm2 s)" else: - out_unit = 'erg/s' + out_unit = "erg/s" - sed = (self.flux( - photon_energy, distance=distance, seed=seed) * photon_energy - ** 2.).to(out_unit) + sed = ( + self.flux(photon_energy, distance=distance, seed=seed) + * photon_energy ** 2.0 + ).to(out_unit) return sed @@ -1012,7 +1085,7 @@ class Bremsstrahlung(BaseElectron): Z_i^2 X_i`, default is 1.263. """ - def __init__(self, particle_distribution, n0=1 / u.cm**3, **kwargs): + def __init__(self, particle_distribution, n0=1 / u.cm ** 3, **kwargs): super(Bremsstrahlung, self).__init__(particle_distribution) self.n0 = n0 self.Eemin = 100 * u.MeV @@ -1020,13 +1093,13 @@ def __init__(self, particle_distribution, n0=1 / u.cm**3, **kwargs): self.nEed = 300 # compute ee and ep weights from H and He abundances in ISM assumin # ionized medium - Y = np.array([1., 9.59e-2]) + Y = np.array([1.0, 9.59e-2]) Z = np.array([1, 2]) N = np.sum(Y) X = Y / N self.weight_ee = np.sum(Z * X) - self.weight_ep = np.sum(Z**2 * X) - self.param_names += ['n0', 'weight_ee', 'weight_ep'] + self.weight_ep = np.sum(Z ** 2 * X) + self.param_names += ["n0", "weight_ee", "weight_ep"] for key, value in kwargs.items(): setattr(self, key, value) @@ -1037,9 +1110,9 @@ def _sigma_1(gam, eps): Eq. A2 of Baring et al. (1999) Return in units of cm2 / mec2 """ - s1 = 4 * r0**2 * alpha / eps / mec2_unit - s2 = 1 + (1. / 3. - eps / gam) * (1 - eps / gam) - s3 = np.log(2 * gam * (gam - eps) / eps) - 1. / 2. + s1 = 4 * r0 ** 2 * alpha / eps / mec2_unit + s2 = 1 + (1.0 / 3.0 - eps / gam) * (1 - eps / gam) + s3 = np.log(2 * gam * (gam - eps) / eps) - 1.0 / 2.0 s3[np.where(gam < eps)] = 0.0 return s1 * s2 * s3 @@ -1050,17 +1123,17 @@ def _sigma_2(gam, eps): Eq. A3 of Baring et al. 
(1999) Return in units of cm2 / mec2 """ - s0 = r0**2 * alpha / (3 * eps) / mec2_unit + s0 = r0 ** 2 * alpha / (3 * eps) / mec2_unit - s1_1 = 16 * (1 - eps + eps**2) * np.log(gam / eps) - s1_2 = -1 / eps**2 + 3 / eps - 4 - 4 * eps - 8 * eps**2 + s1_1 = 16 * (1 - eps + eps ** 2) * np.log(gam / eps) + s1_2 = -1 / eps ** 2 + 3 / eps - 4 - 4 * eps - 8 * eps ** 2 s1_3 = -2 * (1 - 2 * eps) * np.log(1 - 2 * eps) - s1_4 = 1 / (4 * eps**3) - 1 / (2 * eps**2) + 3 / eps - 2 + 4 * eps + s1_4 = 1 / (4 * eps ** 3) - 1 / (2 * eps ** 2) + 3 / eps - 2 + 4 * eps s1 = s1_1 + s1_2 + s1_3 * s1_4 s2_1 = 2 / eps - s2_2 = (4 - 1 / eps + 1 / (4 * eps**2)) * np.log(2 * gam) - s2_3 = -2 + 2 / eps - 5 / (8 * eps**2) + s2_2 = (4 - 1 / eps + 1 / (4 * eps ** 2)) * np.log(2 * gam) + s2_3 = -2 + 2 / eps - 5 / (8 * eps ** 2) s2 = s2_1 * (s2_2 + s2_3) return s0 * np.where(eps <= 0.5, s1, s2) * heaviside(gam - eps) @@ -1070,7 +1143,9 @@ def _sigma_ee_rel(self, gam, eps): Eq. A1, A4 of Baring et al. (1999) Use for Ee > 2 MeV """ - A = 1 - 8 / 3 * (gam - 1)**0.2 / (gam + 1) * (eps / gam)**(1. / 3.) + A = 1 - 8 / 3 * (gam - 1) ** 0.2 / (gam + 1) * (eps / gam) ** ( + 1.0 / 3.0 + ) return (self._sigma_1(gam, eps) + self._sigma_2(gam, eps)) * A @@ -1079,13 +1154,13 @@ def _F(x, gam): """ Eqs. A6, A7 of Baring et al. (1999) """ - beta = np.sqrt(1 - gam**-2) - B = 1 + 0.5 * (gam**2 - 1) + beta = np.sqrt(1 - gam ** -2) + B = 1 + 0.5 * (gam ** 2 - 1) C = 10 * x * gam * beta * (2 + gam * beta) - C /= 1 + x**2 * (gam**2 - 1) + C /= 1 + x ** 2 * (gam ** 2 - 1) - F_1 = (17 - 3 * x**2 / (2 - x)**2 - C) * np.sqrt(1 - x) - F_2 = 12 * (2 - x) - 7 * x**2 / (2 - x) - 3 * x**4 / (2 - x)**3 + F_1 = (17 - 3 * x ** 2 / (2 - x) ** 2 - C) * np.sqrt(1 - x) + F_2 = 12 * (2 - x) - 7 * x ** 2 / (2 - x) - 3 * x ** 4 / (2 - x) ** 3 F_3 = np.log((1 + np.sqrt(1 - x)) / np.sqrt(x)) return B * F_1 + F_2 * F_3 @@ -1095,17 +1170,17 @@ def _sigma_ee_nonrel(self, gam, eps): Eq. A5 of Baring et al. 
(1999) Use for Ee < 2 MeV """ - s0 = 4 * r0**2 * alpha / (15 * eps) - x = 4 * eps / (gam**2 - 1) + s0 = 4 * r0 ** 2 * alpha / (15 * eps) + x = 4 * eps / (gam ** 2 - 1) sigma_nonrel = s0 * self._F(x, gam) - sigma_nonrel[np.where(eps >= 0.25 * (gam**2 - 1.))] = 0.0 + sigma_nonrel[np.where(eps >= 0.25 * (gam ** 2 - 1.0))] = 0.0 sigma_nonrel[np.where(gam * np.ones_like(eps) < 1.0)] = 0.0 return sigma_nonrel / mec2_unit def _sigma_ee(self, gam, Eph): eps = (Eph / mec2).decompose().value # initialize shape and units of cross section - sigma = np.zeros_like(gam * eps) * u.Unit(u.cm**2 / Eph.unit) + sigma = np.zeros_like(gam * eps) * u.Unit(u.cm ** 2 / Eph.unit) gam_trans = (2 * u.MeV / mec2).decompose().value # Non relativistic below 2 MeV if np.any(gam <= gam_trans): @@ -1120,7 +1195,7 @@ def _sigma_ee(self, gam, Eph): warnings.simplefilter("ignore") sigma[rel_matrix] = self._sigma_ee_rel(gam, eps)[rel_matrix] - return sigma.to(u.cm**2 / Eph.unit) + return sigma.to(u.cm ** 2 / Eph.unit) def _sigma_ep(self, gam, eps): """ @@ -1144,7 +1219,8 @@ def _emiss_ee(self, Eph): emiss = c.cgs * trapz_loglog( np.vstack(self._npart) * self._sigma_ee(gam, Eph), self._gam, - axis=0) + axis=0, + ) return emiss def _emiss_ep(self, Eph): @@ -1160,7 +1236,8 @@ def _emiss_ep(self, Eph): emiss = c.cgs * trapz_loglog( np.vstack(self._npart) * self._sigma_ep(gam, eps), self._gam, - axis=0).to(u.cm**2 / Eph.unit) + axis=0, + ).to(u.cm ** 2 / Eph.unit) return emiss def _spectrum(self, photon_energy): @@ -1175,8 +1252,10 @@ def _spectrum(self, photon_energy): Eph = _validate_ene(photon_energy) - spec = self.n0 * (self.weight_ee * self._emiss_ee(Eph) + self.weight_ep - * self._emiss_ep(Eph)) + spec = self.n0 * ( + self.weight_ee * self._emiss_ee(Eph) + + self.weight_ep * self._emiss_ep(Eph) + ) return spec @@ -1187,7 +1266,7 @@ class BaseProton(BaseRadiative): def __init__(self, particle_distribution): super(BaseProton, self).__init__(particle_distribution) - self.param_names = ['Epmin', 'Epmax', 'nEpd'] + self.param_names = ["Epmin", "Epmax", "nEpd"] self._memoize = True self._cache = {} self._queue = [] @@ -1197,23 +1276,24 @@ def _Ep(self): """ Proton energy array in GeV """ return np.logspace( - np.log10(self.Epmin.to('GeV').value), - np.log10(self.Epmax.to('GeV').value), - self.nEpd * (np.log10(self.Epmax / self.Epmin))) + np.log10(self.Epmin.to("GeV").value), + np.log10(self.Epmax.to("GeV").value), + self.nEpd * (np.log10(self.Epmax / self.Epmin)), + ) @property def _J(self): """ Particles per unit proton energy in particles per GeV """ pd = self.particle_distribution(self._Ep * u.GeV) - return pd.to('1/GeV').value + return pd.to("1/GeV").value @property def Wp(self): """Total energy in protons """ Wp = trapz_loglog(self._Ep * self._J, self._Ep) * u.GeV - return Wp.to('erg') + return Wp.to("erg") def compute_Wp(self, Epmin=None, Epmax=None): """ Total energy in protons between energies Epmin and Epmax @@ -1234,12 +1314,18 @@ def compute_Wp(self, Epmin=None, Epmax=None): if Epmin is None: Epmin = self.Epmin - log10Epmin = np.log10(Epmin.to('GeV').value) - log10Epmax = np.log10(Epmax.to('GeV').value) - Ep = np.logspace(log10Epmin, log10Epmax, - self.nEpd * (log10Epmax - log10Epmin)) * u.GeV + log10Epmin = np.log10(Epmin.to("GeV").value) + log10Epmax = np.log10(Epmax.to("GeV").value) + Ep = ( + np.logspace( + log10Epmin, + log10Epmax, + self.nEpd * (log10Epmax - log10Epmin), + ) + * u.GeV + ) pdist = self.particle_distribution(Ep) - Wp = trapz_loglog(Ep * pdist, Ep).to('erg') + Wp = trapz_loglog(Ep * pdist, 
Ep).to("erg") return Wp @@ -1264,23 +1350,27 @@ def set_Wp(self, Wp, Epmin=None, Epmax=None, amplitude_name=None): Defaults to ``amplitude``. """ - Wp = validate_scalar('Wp', Wp, physical_type='energy') + Wp = validate_scalar("Wp", Wp, physical_type="energy") oldWp = self.compute_Wp(Epmin=Epmin, Epmax=Epmax) if amplitude_name is None: try: self.particle_distribution.amplitude *= ( - Wp / oldWp).decompose() + Wp / oldWp + ).decompose() except AttributeError: log.error( - 'The particle distribution does not have an attribute' - ' called amplitude to modify its normalization: you can' - ' set the name with the amplitude_name parameter of set_Wp' + "The particle distribution does not have an attribute" + " called amplitude to modify its normalization: you can" + " set the name with the amplitude_name parameter of set_Wp" ) else: oldampl = getattr(self.particle_distribution, amplitude_name) - setattr(self.particle_distribution, amplitude_name, - oldampl * (Wp / oldWp).decompose()) + setattr( + self.particle_distribution, + amplitude_name, + oldampl * (Wp / oldWp).decompose(), + ) class PionDecay(BaseProton): @@ -1339,21 +1429,24 @@ class PionDecay(BaseProton): ISM nuclear enhancement factor. """ - def __init__(self, - particle_distribution, - nh=1.0 / u.cm**3, - nuclear_enhancement=True, - **kwargs): + def __init__( + self, + particle_distribution, + nh=1.0 / u.cm ** 3, + nuclear_enhancement=True, + **kwargs + ): super(PionDecay, self).__init__(particle_distribution) - self.nh = validate_scalar('nh', nh, physical_type='number density') + self.nh = validate_scalar("nh", nh, physical_type="number density") self.nuclear_enhancement = nuclear_enhancement self.useLUT = True - self.hiEmodel = 'Pythia8' + self.hiEmodel = "Pythia8" self.Epmin = ( - self._m_p + self._Tth + 1e-4) * u.GeV # Threshold energy ~1.22 GeV + self._m_p + self._Tth + 1e-4 + ) * u.GeV # Threshold energy ~1.22 GeV self.Epmax = 10 * u.PeV # 10 PeV self.nEpd = 100 - self.param_names += ['nh', 'nuclear_enhancement', 'useLUT', 'hiEmodel'] + self.param_names += ["nh", "nuclear_enhancement", "useLUT", "hiEmodel"] self.__dict__.update(**kwargs) # define model parameters from tables @@ -1390,9 +1483,9 @@ def __init__(self, # yapf: enable # energy at which each of the hiE models start being valid - _Etrans = {'Pythia8': 50, 'SIBYLL': 100, 'QGSJET': 100, 'Geant4': 100} + _Etrans = {"Pythia8": 50, "SIBYLL": 100, "QGSJET": 100, "Geant4": 100} # - _m_p = (m_p * c**2).to('GeV').value + _m_p = (m_p * c ** 2).to("GeV").value _m_pi = 0.1349766 # GeV/c2 _Tth = 0.27966184 @@ -1412,8 +1505,8 @@ def _sigma_inel(self, Tp): """ L = np.log(Tp / self._Tth) - sigma = 30.7 - 0.96 * L + 0.18 * L**2 - sigma *= (1 - (self._Tth / Tp)**1.9)**3 + sigma = 30.7 - 0.96 * L + 0.18 * L ** 2 + sigma *= (1 - (self._Tth / Tp) ** 1.9) ** 3 return sigma * 1e-27 # convert from mbarn to cm-2 def _sigma_pi_loE(self, Tp): @@ -1426,26 +1519,30 @@ def _sigma_pi_loE(self, Tp): Mres = 1.1883 # GeV Gres = 0.2264 # GeV s = 2 * m_p * (Tp + 2 * m_p) # center of mass energy - gamma = np.sqrt(Mres**2 * (Mres**2 + Gres**2)) + gamma = np.sqrt(Mres ** 2 * (Mres ** 2 + Gres ** 2)) K = np.sqrt(8) * Mres * Gres * gamma - K /= np.pi * np.sqrt(Mres**2 + gamma) + K /= np.pi * np.sqrt(Mres ** 2 + gamma) fBW = m_p * K - fBW /= ((np.sqrt(s) - m_p)**2 - Mres**2)**2 + Mres**2 * Gres**2 + fBW /= ( + (np.sqrt(s) - m_p) ** 2 - Mres ** 2 + ) ** 2 + Mres ** 2 * Gres ** 2 - mu = np.sqrt((s - m_pi**2 - 4 * m_p**2)**2 - 16 * m_pi**2 * m_p**2) + mu = np.sqrt( + (s - m_pi ** 2 - 4 * m_p ** 2) ** 2 - 16 * m_pi 
** 2 * m_p ** 2
+        )
         mu /= 2 * m_pi * np.sqrt(s)

         sigma0 = 7.66e-3  # mb
-        sigma1pi = sigma0 * mu**1.95 * (1 + mu + mu**5) * fBW**1.86
+        sigma1pi = sigma0 * mu ** 1.95 * (1 + mu + mu ** 5) * fBW ** 1.86

         # two pion production
         sigma2pi = 5.7  # mb
         sigma2pi /= 1 + np.exp(-9.3 * (Tp - 1.4))

         E2pith = 0.56  # GeV
-        sigma2pi[np.where(Tp < E2pith)] = 0.
+        sigma2pi[np.where(Tp < E2pith)] = 0.0

         return (sigma1pi + sigma2pi) * 1e-27  # return in cm-2

@@ -1455,7 +1552,7 @@ def _sigma_pi_midE(self, Tp):
         """
         m_p = self._m_p
         Qp = (Tp - self._Tth) / m_p
-        multip = -6e-3 + 0.237 * Qp - 0.023 * Qp**2
+        multip = -6e-3 + 0.237 * Qp - 0.023 * Qp ** 2
         return self._sigma_inel(Tp) * multip

     def _sigma_pi_hiE(self, Tp, a):
@@ -1464,8 +1561,8 @@
         """
         m_p = self._m_p
         csip = (Tp - 3.0) / m_p
-        m1 = a[0] * csip**a[3] * (1 + np.exp(-a[1] * csip**a[4]))
-        m2 = 1 - np.exp(-a[2] * csip**0.25)
+        m1 = a[0] * csip ** a[3] * (1 + np.exp(-a[1] * csip ** a[4]))
+        m2 = 1 - np.exp(-a[2] * csip ** 0.25)
         multip = m1 * m2
         return self._sigma_inel(Tp) * multip

@@ -1480,7 +1577,7 @@ def _sigma_pi(self, Tp):
         sigma[idx2] = self._sigma_pi_midE(Tp[idx2])
         # for 5GeV<=E<Etrans
         idx3 = np.where((Tp >= 5.0) * (Tp < self._Etrans[self.hiEmodel]))
-        sigma[idx3] = self._sigma_pi_hiE(Tp[idx3], self._a['Geant4'])
+        sigma[idx3] = self._sigma_pi_hiE(Tp[idx3], self._a["Geant4"])
         # for E>=Etrans
         idx4 = np.where((Tp >= self._Etrans[self.hiEmodel]))
         sigma[idx4] = self._sigma_pi_hiE(Tp[idx4], self._a[self.hiEmodel])
@@ -1496,10 +1593,10 @@ def _b_params(self, Tp):
         b3 = np.zeros(TphiE.size)

         idx = np.where(TphiE < 5.0)
-        b1[idx], b2[idx], b3[idx] = self._b['Geant4_0']
+        b1[idx], b2[idx], b3[idx] = self._b["Geant4_0"]

         idx = np.where(TphiE >= 5.0)
-        b1[idx], b2[idx], b3[idx] = self._b['Geant4']
+        b1[idx], b2[idx], b3[idx] = self._b["Geant4"]

         idx = np.where(TphiE >= self._Etrans[self.hiEmodel])
         b1[idx], b2[idx], b3[idx] = self._b[self.hiEmodel]
@@ -1511,10 +1608,10 @@ def _calc_EpimaxLAB(self, Tp):
         m_pi = self._m_pi
         # Eq 10
         s = 2 * m_p * (Tp + 2 * m_p)  # center of mass energy
-        EpiCM = (s - 4 * m_p**2 + m_pi**2) / (2 * np.sqrt(s))
-        PpiCM = np.sqrt(EpiCM**2 - m_pi**2)
+        EpiCM = (s - 4 * m_p ** 2 + m_pi ** 2) / (2 * np.sqrt(s))
+        PpiCM = np.sqrt(EpiCM ** 2 - m_pi ** 2)
         gCM = (Tp + 2 * m_p) / np.sqrt(s)
-        betaCM = np.sqrt(1 - gCM**-2)
+        betaCM = np.sqrt(1 - gCM ** -2)
         EpimaxLAB = gCM * (EpiCM + PpiCM * betaCM)
         return EpimaxLAB

@@ -1523,7 +1620,7 @@ def _calc_Egmax(self, Tp):
         m_pi = self._m_pi
         EpimaxLAB = self._calc_EpimaxLAB(Tp)
         gpiLAB = EpimaxLAB / m_pi
-        betapiLAB = np.sqrt(1 - gpiLAB**-2)
+        betapiLAB = np.sqrt(1 - gpiLAB ** -2)
         Egmax = (m_pi / 2) * gpiLAB * (1 + betapiLAB)
         return Egmax

@@ -1540,9 +1637,13 @@ def _Amax(self, Tp):
         EpimaxLAB = self._calc_EpimaxLAB(Tp)
         Amax[loE] = b[0] * self._sigma_pi(Tp[loE]) / EpimaxLAB[loE]
         thetap = Tp / m_p
-        Amax[hiE] = (b[1] * thetap[hiE]
-                     ** -b[2] * np.exp(b[3] * np.log(thetap[hiE])**2) *
-                     self._sigma_pi(Tp[hiE]) / m_p)
+        Amax[hiE] = (
+            b[1]
+            * thetap[hiE] ** -b[2]
+            * np.exp(b[3] * np.log(thetap[hiE]) ** 2)
+            * self._sigma_pi(Tp[hiE])
+            / m_p
+        )

         return Amax

@@ -1551,26 +1652,26 @@ def _F_func(self, Tp, Egamma, modelparams):
         m_pi = self._m_pi
         # Eq 9
         Egmax = self._calc_Egmax(Tp)
-        Yg = Egamma + m_pi**2 / (4 * Egamma)
-        Ygmax = Egmax + m_pi**2 / (4 * Egmax)
+        Yg = Egamma + m_pi ** 2 / (4 * Egamma)
+        Ygmax = Egmax + m_pi ** 2 / (4 * Egmax)
         Xg = (Yg - m_pi) / (Ygmax - m_pi)
         # zero out invalid fields (Egamma > Egmax -> Xg > 1)
         Xg[np.where(Xg > 1)] = 1.0
         # Eq 11
         C = lamb * m_pi / Ygmax
-        F = (1 - Xg**alpha)**beta
-        F /= (1 + Xg / 
C)**gamma + F = (1 - Xg ** alpha) ** beta + F /= (1 + Xg / C) ** gamma # return F def _kappa(self, Tp): thetap = Tp / self._m_p - return 3.29 - thetap**-1.5 / 5. + return 3.29 - thetap ** -1.5 / 5.0 def _mu(self, Tp): q = (Tp - 1.0) / self._m_p - x = 5. / 4. - return x * q**x * np.exp(-x * q) + x = 5.0 / 4.0 + return x * q ** x * np.exp(-x * q) def _F(self, Tp, Egamma): F = np.zeros_like(Tp) @@ -1581,14 +1682,14 @@ def _F(self, Tp, Egamma): idx = np.where((Tp >= self._Tth) * (Tp <= 1.0)) if idx[0].size > 0: kappa = self._kappa(Tp[idx]) - mp = self._F_mp['ExpData'] + mp = self._F_mp["ExpData"] mp[2] = kappa F[idx] = self._F_func(Tp[idx], Egamma, mp) # 1GeV < Tp < 4 GeV: Geant4 model 0 idx = np.where((Tp > 1.0) * (Tp <= 4.0)) if idx[0].size > 0: - mp = self._F_mp['Geant4_0'] + mp = self._F_mp["Geant4_0"] mu = self._mu(Tp[idx]) mp[2] = mu + 2.45 mp[3] = mu + 1.45 @@ -1597,7 +1698,7 @@ def _F(self, Tp, Egamma): # 4 GeV < Tp < 20 GeV idx = np.where((Tp > 4.0) * (Tp <= 20.0)) if idx[0].size > 0: - mp = self._F_mp['Geant4_1'] + mp = self._F_mp["Geant4_1"] mu = self._mu(Tp[idx]) mp[2] = 1.5 * mu + 4.95 mp[3] = mu + 1.50 @@ -1606,7 +1707,7 @@ def _F(self, Tp, Egamma): # 20 GeV < Tp < 100 GeV idx = np.where((Tp > 20.0) * (Tp <= 100.0)) if idx[0].size > 0: - mp = self._F_mp['Geant4_2'] + mp = self._F_mp["Geant4_2"] F[idx] = self._F_func(Tp[idx], Egamma, mp) # Tp > Etrans @@ -1647,8 +1748,11 @@ def _nuclear_factor(self, Tp): eps1 = 0.29 eps2 = 0.1 - epstotal = np.where(Tp > self._Tth, epsC + - (eps1 + eps2) * sigmaRpp * G / sigmainel, 0.0) + epstotal = np.where( + Tp > self._Tth, + epsC + (eps1 + eps2) * sigmaRpp * G / sigmainel, + 0.0, + ) if np.any(Tp < 1.0): # nuclear enhancement factor diverges towards Tp = Tth, fix Tp<1 to @@ -1660,11 +1764,14 @@ def _nuclear_factor(self, Tp): def _loadLUT(self, LUT_fname): try: - filename = get_pkg_data_filename(os.path.join('data', LUT_fname)) + filename = get_pkg_data_filename(os.path.join("data", LUT_fname)) self.diffsigma = LookupTable(filename) except IOError: - warnings.warn('LUT {0} not found, reverting to useLUT = False'. - format(LUT_fname)) + warnings.warn( + "LUT {0} not found, reverting to useLUT = False".format( + LUT_fname + ) + ) self.diffsigma = self._diffsigma self.useLUT = False @@ -1683,10 +1790,10 @@ def _spectrum(self, photon_energy): # Load LUT if available, otherwise use self._diffsigma if self.useLUT: - LUT_base = 'PionDecayKafexhiu14_LUT_' + LUT_base = "PionDecayKafexhiu14_LUT_" if self.nuclear_enhancement: - LUT_base += 'NucEnh_' - LUT_fname = LUT_base + '{0}.npz'.format(self.hiEmodel) + LUT_base += "NucEnh_" + LUT_fname = LUT_base + "{0}.npz".format(self.hiEmodel) # only reload LUT if it has changed or hasn't been loaded yet try: if os.path.basename(self.diffsigma.fname) != LUT_fname: @@ -1696,24 +1803,24 @@ def _spectrum(self, photon_energy): else: self.diffsigma = self._diffsigma - Egamma = _validate_ene(photon_energy).to('GeV') + Egamma = _validate_ene(photon_energy).to("GeV") Ep = self._Ep * u.GeV - J = self._J * u.Unit('1/GeV') + J = self._J * u.Unit("1/GeV") specpp = [] for Eg in Egamma: - diffsigma = self.diffsigma(Ep.value, Eg.value) * u.Unit('cm2/GeV') + diffsigma = self.diffsigma(Ep.value, Eg.value) * u.Unit("cm2/GeV") specpp.append(trapz_loglog(diffsigma * J, Ep)) self.specpp = u.Quantity(specpp) self.specpp *= self.nh * c.cgs - return self.specpp.to('1/(s eV)') + return self.specpp.to("1/(s eV)") def heaviside(x): - return (np.sign(x) + 1) / 2. 
+ return (np.sign(x) + 1) / 2.0 class PionDecayKelner06(BaseRadiative): @@ -1748,25 +1855,28 @@ class PionDecayKelner06(BaseRadiative): """ # This class doesn't inherit from BaseProton - param_names = ['nh', 'Etrans'] + param_names = ["nh", "Etrans"] _memoize = True _cache = {} _queue = [] - def __init__(self, - particle_distribution, - nh=1.0 / u.cm**3, - Etrans=0.1 * u.TeV, - **kwargs): + def __init__( + self, + particle_distribution, + nh=1.0 / u.cm ** 3, + Etrans=0.1 * u.TeV, + **kwargs + ): self.particle_distribution = particle_distribution - self.nh = validate_scalar('nh', nh, physical_type='number density') + self.nh = validate_scalar("nh", nh, physical_type="number density") self.Etrans = validate_scalar( - 'Etrans', Etrans, domain='positive', physical_type='energy') + "Etrans", Etrans, domain="positive", physical_type="energy" + ) self.__dict__.update(**kwargs) def _particle_distribution(self, E): - return self.particle_distribution(E * u.TeV).to('1/TeV').value + return self.particle_distribution(E * u.TeV).to("1/TeV").value def _Fgamma(self, x, Ep): """ @@ -1782,14 +1892,17 @@ def _Fgamma(self, x, Ep): Eprot [TeV] """ L = np.log(Ep) - B = 1.30 + 0.14 * L + 0.011 * L**2 # Eq59 - beta = (1.79 + 0.11 * L + 0.008 * L**2)**-1 # Eq60 - k = (0.801 + 0.049 * L + 0.014 * L**2)**-1 # Eq61 - xb = x**beta - - F1 = B * (np.log(x) / x) * ((1 - xb) / (1 + k * xb * (1 - xb)))**4 - F2 = 1. / np.log(x) - (4 * beta * xb) / (1 - xb) - ( - 4 * k * beta * xb * (1 - 2 * xb)) / (1 + k * xb * (1 - xb)) + B = 1.30 + 0.14 * L + 0.011 * L ** 2 # Eq59 + beta = (1.79 + 0.11 * L + 0.008 * L ** 2) ** -1 # Eq60 + k = (0.801 + 0.049 * L + 0.014 * L ** 2) ** -1 # Eq61 + xb = x ** beta + + F1 = B * (np.log(x) / x) * ((1 - xb) / (1 + k * xb * (1 - xb))) ** 4 + F2 = ( + 1.0 / np.log(x) + - (4 * beta * xb) / (1 - xb) + - (4 * k * beta * xb * (1 - 2 * xb)) / (1 + k * xb * (1 - xb)) + ) return F1 * F2 @@ -1811,10 +1924,10 @@ def _sigma_inel(self, Ep): """ L = np.log(Ep) - sigma = 34.3 + 1.88 * L + 0.25 * L**2 + sigma = 34.3 + 1.88 * L + 0.25 * L ** 2 if Ep <= 0.1: Eth = 1.22e-3 - sigma *= (1 - (Eth / Ep)**4)**2 * heaviside(Ep - Eth) + sigma *= (1 - (Eth / Ep) ** 4) ** 2 * heaviside(Ep - Eth) return sigma * 1e-27 # convert from mbarn to cm2 def _photon_integrand(self, x, Egamma): @@ -1822,9 +1935,12 @@ def _photon_integrand(self, x, Egamma): Integrand of Eq. 
72 """ try: - return (self._sigma_inel(Egamma / x) * - self._particle_distribution((Egamma / x)) * - self._Fgamma(x, Egamma / x) / x) + return ( + self._sigma_inel(Egamma / x) + * self._particle_distribution((Egamma / x)) + * self._Fgamma(x, Egamma / x) + / x + ) except ZeroDivisionError: return np.nan @@ -1840,52 +1956,71 @@ def _calc_specpp_hiE(self, Egamma): # result=c*fixed_quad(self._photon_integrand, 0., 1., args = [Egamma, # ], n = 40)[0] from scipy.integrate import quad - Egamma = Egamma.to('TeV').value - specpp = c.cgs.value * quad( - self._photon_integrand, 0., 1., args=Egamma, epsrel=1e-3, - epsabs=0)[0] - return specpp * u.Unit('1/(s TeV)') + Egamma = Egamma.to("TeV").value + specpp = ( + c.cgs.value + * quad( + self._photon_integrand, + 0.0, + 1.0, + args=Egamma, + epsrel=1e-3, + epsabs=0, + )[0] + ) + + return specpp * u.Unit("1/(s TeV)") # variables for delta integrand _c = c.cgs.value _Kpi = 0.17 - _mp = (m_p * c**2).to('TeV').value + _mp = (m_p * c ** 2).to("TeV").value _m_pi = 1.349766e-4 # TeV/c2 def _delta_integrand(self, Epi): Ep0 = self._mp + Epi / self._Kpi - qpi = (self._c * - (self.nhat / self._Kpi) * self._sigma_inel(Ep0) * - self._particle_distribution(Ep0)) - return qpi / np.sqrt(Epi**2 + self._m_pi**2) + qpi = ( + self._c + * (self.nhat / self._Kpi) + * self._sigma_inel(Ep0) + * self._particle_distribution(Ep0) + ) + return qpi / np.sqrt(Epi ** 2 + self._m_pi ** 2) def _calc_specpp_loE(self, Egamma): """ Delta-functional approximation for low energies Egamma < 0.1 TeV """ from scipy.integrate import quad - Egamma = Egamma.to('TeV').value - Epimin = Egamma + self._m_pi**2 / (4 * Egamma) - result = 2 * quad( - self._delta_integrand, Epimin, np.inf, epsrel=1e-3, epsabs=0)[0] + Egamma = Egamma.to("TeV").value + Epimin = Egamma + self._m_pi ** 2 / (4 * Egamma) - return result * u.Unit('1/(s TeV)') + result = ( + 2 + * quad( + self._delta_integrand, Epimin, np.inf, epsrel=1e-3, epsabs=0 + )[0] + ) + + return result * u.Unit("1/(s TeV)") @property def Wp(self): """Total energy in protons above 1.22 GeV threshold (erg). """ from scipy.integrate import quad + Eth = 1.22e-3 with warnings.catch_warnings(): warnings.simplefilter("ignore") - Wp = quad(lambda x: x * self._particle_distribution(x), Eth, - np.Inf)[0] + Wp = quad( + lambda x: x * self._particle_distribution(x), Eth, np.Inf + )[0] - return (Wp * u.TeV).to('erg') + return (Wp * u.TeV).to("erg") def _spectrum(self, photon_energy): """ @@ -1904,16 +2039,17 @@ def _spectrum(self, photon_energy): with warnings.catch_warnings(): warnings.simplefilter("ignore") - self.nhat = 1. 
# initial value, works for index~2.1 + self.nhat = 1.0 # initial value, works for index~2.1 if np.any(outspecene < self.Etrans) and np.any( - outspecene >= self.Etrans): + outspecene >= self.Etrans + ): # compute value of nhat so that delta functional matches # accurate calculation at 0.1TeV full = self._calc_specpp_hiE(self.Etrans) delta = self._calc_specpp_loE(self.Etrans) self.nhat *= (full / delta).decompose().value - self.specpp = np.zeros(len(outspecene)) * u.Unit('1/(s TeV)') + self.specpp = np.zeros(len(outspecene)) * u.Unit("1/(s TeV)") for i, Egamma in enumerate(outspecene): if Egamma >= self.Etrans: @@ -1921,9 +2057,9 @@ def _spectrum(self, photon_energy): else: self.specpp[i] = self._calc_specpp_loE(Egamma) - density_factor = (self.nh / (1 * u.Unit('1/cm3'))).decompose().value + density_factor = (self.nh / (1 * u.Unit("1/cm3"))).decompose().value - return density_factor * self.specpp.to('1/(s eV)') + return density_factor * self.specpp.to("1/(s eV)") class LookupTable(object): @@ -1944,11 +2080,12 @@ class LookupTable(object): def __init__(self, filename): from scipy.interpolate import RectBivariateSpline + f_lut = np.load(filename) X = f_lut.f.X Y = f_lut.f.Y lut = f_lut.f.lut - self.int_lut = RectBivariateSpline(X, Y, 10**lut, kx=3, ky=3, s=0) + self.int_lut = RectBivariateSpline(X, Y, 10 ** lut, kx=3, ky=3, s=0) self.fname = filename def __call__(self, X, Y): @@ -1958,33 +2095,36 @@ def __call__(self, X, Y): def _calc_lut_pp(args): # pragma: no cover epr, eph, hiEmodel, nuc = args from .models import PowerLaw + pl = PowerLaw(1 / u.eV, 1 * u.TeV, 0.0) pp = PionDecay(pl, hiEmodel=hiEmodel, nuclear_enhancement=nuc) - diffsigma = pp._diffsigma(epr.to('GeV').value, eph.to('GeV').value) + diffsigma = pp._diffsigma(epr.to("GeV").value, eph.to("GeV").value) return diffsigma -def generate_lut_pp(Ep=np.logspace(0.085623713910610105, 7, 800) * u.GeV, - Eg=np.logspace(-5, 3, 1024) * u.TeV, - out_base='PionDecayKafexhiu14_LUT_', - hiEmodel=None, - nuclear_enhancement=True): # pragma: no cover +def generate_lut_pp( + Ep=np.logspace(0.085623713910610105, 7, 800) * u.GeV, + Eg=np.logspace(-5, 3, 1024) * u.TeV, + out_base="PionDecayKafexhiu14_LUT_", + hiEmodel=None, + nuclear_enhancement=True, +): # pragma: no cover from emcee.interruptible_pool import InterruptiblePool as Pool pool = Pool() if hiEmodel is None: - hiEmodel = ['Geant4', 'Pythia8', 'SIBYLL', 'QGSJET'] + hiEmodel = ["Geant4", "Pythia8", "SIBYLL", "QGSJET"] elif type(hiEmodel) is str: hiEmodel = [hiEmodel] if nuclear_enhancement: - out_base += 'NucEnh_' + out_base += "NucEnh_" for model in hiEmodel: - out_file = out_base + model + '.npz' - print('Saving LUT for model {0} in {1}...'.format(model, out_file)) + out_file = out_base + model + ".npz" + print("Saving LUT for model {0} in {1}...".format(model, out_file)) args = [(Ep, eg, model, nuclear_enhancement) for eg in Eg] diffsigma_list = pool.map(_calc_lut_pp, args) @@ -1992,6 +2132,7 @@ def generate_lut_pp(Ep=np.logspace(0.085623713910610105, 7, 800) * u.GeV, np.savez_compressed( out_file, - X=np.log10(Ep.to('GeV').value), - Y=np.log10(Eg.to('GeV').value), - lut=np.log10(diffsigma)) + X=np.log10(Ep.to("GeV").value), + Y=np.log10(Eg.to("GeV").value), + lut=np.log10(diffsigma), + ) diff --git a/naima/sherpa_models.py b/naima/sherpa_models.py index 051ee172..f02fbb2f 100644 --- a/naima/sherpa_models.py +++ b/naima/sherpa_models.py @@ -1,7 +1,11 @@ # -*- coding: utf-8 -*- # Licensed under a 3-clause BSD style license - see LICENSE.rst -from __future__ import (absolute_import, 
division, print_function, - unicode_literals) +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) import numpy as np import astropy.units as u @@ -12,7 +16,7 @@ from . import models from .utils import trapz_loglog -__all__ = ['InverseCompton', 'Synchrotron', 'PionDecay', 'Bremsstrahlung'] +__all__ = ["InverseCompton", "Synchrotron", "PionDecay", "Bremsstrahlung"] def _mergex(xlo, xhi): @@ -59,14 +63,17 @@ def calc(self, p, x, xhi=None): # Do a trapz integration to obtain the photons per bin if xhi is None: - photons = (model * Eph).to('1/(s cm2)').value + photons = (model * Eph).to("1/(s cm2)").value else: - photons = trapz_loglog( - model, Eph, intervals=True).to('1/(s cm2)').value + photons = ( + trapz_loglog(model, Eph, intervals=True).to("1/(s cm2)").value + ) if p[-1]: - print(self.thawedpars, - trapz_loglog(Eph * model, Eph).to('erg/(s cm2)')) + print( + self.thawedpars, + trapz_loglog(Eph * model, Eph).to("erg/(s cm2)"), + ) return photons @@ -75,35 +82,39 @@ class SherpaModelECPL(SherpaModel): """ Base class for Sherpa models with a PL or ECPL particle distribution """ - def __init__(self, name='Model'): + def __init__(self, name="Model"): self.name = name # Initialize ECPL parameters - self.index = Parameter(name, 'index', 2.1, min=-10, max=10) - self.ref = Parameter(name, 'ref', 60, min=0, frozen=True, units='TeV') + self.index = Parameter(name, "index", 2.1, min=-10, max=10) + self.ref = Parameter(name, "ref", 60, min=0, frozen=True, units="TeV") self.ampl = Parameter( - name, 'ampl', 100, min=0, max=1e60, hard_max=1e100, units='1e30/eV' + name, "ampl", 100, min=0, max=1e60, hard_max=1e100, units="1e30/eV" ) self.cutoff = Parameter( - name, 'cutoff', 0, min=0, frozen=True, units='TeV') - self.beta = Parameter(name, 'beta', 1, min=0, max=10, frozen=True) + name, "cutoff", 0, min=0, frozen=True, units="TeV" + ) + self.beta = Parameter(name, "beta", 1, min=0, max=10, frozen=True) self.distance = Parameter( - name, 'distance', 1, min=0, max=1e6, frozen=True, units='kpc') - self.verbose = Parameter(name, 'verbose', 0, min=0, max=1, frozen=True) + name, "distance", 1, min=0, max=1e6, frozen=True, units="kpc" + ) + self.verbose = Parameter(name, "verbose", 0, min=0, max=1, frozen=True) @staticmethod def _pdist(p): """ Return PL or ECPL instance based on parameters p """ index, ref, ampl, cutoff, beta = p[:5] if cutoff == 0.0: - pdist = models.PowerLaw(ampl * 1e30 * u.Unit('1/eV'), ref * u.TeV, - index) + pdist = models.PowerLaw( + ampl * 1e30 * u.Unit("1/eV"), ref * u.TeV, index + ) else: pdist = models.ExponentialCutoffPowerLaw( - ampl * 1e30 * u.Unit('1/eV'), + ampl * 1e30 * u.Unit("1/eV"), ref * u.TeV, index, cutoff * u.TeV, - beta=beta) + beta=beta, + ) return pdist @@ -115,45 +126,72 @@ class InverseCompton(SherpaModelECPL): `naima.models.ExponentialCutoffPowerLaw` documentation. 
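    A minimal usage sketch in a Sherpa session follows; the dataset name
    and the ``load_data``/``set_source``/``fit`` calls stand in for
    whatever session setup is already in place, and the parameter values
    are illustrative only::

        from sherpa.astro.ui import load_data, set_source, fit
        from naima.sherpa_models import InverseCompton

        ic = InverseCompton("IC")
        ic.distance = 2.0  # kpc
        ic.uFIR = 0.2  # eV/cm3; a nonzero density switches on the FIR seed

        load_data("source_spectrum.pha")  # hypothetical dataset
        set_source(ic)
        fit()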
""" - def __init__(self, name='IC'): + def __init__(self, name="IC"): self.name = name - self.TFIR = Parameter(name, 'TFIR', 30, min=0, frozen=True, units='K') + self.TFIR = Parameter(name, "TFIR", 30, min=0, frozen=True, units="K") self.uFIR = Parameter( - name, 'uFIR', 0.0, min=0, frozen=True, - units='eV/cm3') # , 0.2eV/cm3 typical in outer disk + name, "uFIR", 0.0, min=0, frozen=True, units="eV/cm3" + ) # , 0.2eV/cm3 typical in outer disk self.TNIR = Parameter( - name, 'TNIR', 3000, min=0, frozen=True, units='K') + name, "TNIR", 3000, min=0, frozen=True, units="K" + ) self.uNIR = Parameter( - name, 'uNIR', 0.0, min=0, frozen=True, - units='eV/cm3') # , 0.2eV/cm3 typical in outer disk + name, "uNIR", 0.0, min=0, frozen=True, units="eV/cm3" + ) # , 0.2eV/cm3 typical in outer disk # add ECPL params super(InverseCompton, self).__init__(name=name) # Initialize model - ArithmeticModel.__init__(self, name, ( - self.index, self.ref, self.ampl, self.cutoff, self.beta, self.TFIR, - self.uFIR, self.TNIR, self.uNIR, self.distance, self.verbose)) + ArithmeticModel.__init__( + self, + name, + ( + self.index, + self.ref, + self.ampl, + self.cutoff, + self.beta, + self.TFIR, + self.uFIR, + self.TNIR, + self.uNIR, + self.distance, + self.verbose, + ), + ) self._use_caching = True self.cache = 10 def flux(self, p, Eph): - (index, ref, ampl, cutoff, beta, TFIR, uFIR, TNIR, uNIR, distance, - verbose) = p + ( + index, + ref, + ampl, + cutoff, + beta, + TFIR, + uFIR, + TNIR, + uNIR, + distance, + verbose, + ) = p # Build seedspec definition - seedspec = ['CMB'] + seedspec = ["CMB"] if uFIR > 0.0: - seedspec.append(['FIR', TFIR * u.K, uFIR * u.eV / u.cm**3]) + seedspec.append(["FIR", TFIR * u.K, uFIR * u.eV / u.cm ** 3]) if uNIR > 0.0: - seedspec.append(['NIR', TNIR * u.K, uNIR * u.eV / u.cm**3]) + seedspec.append(["NIR", TNIR * u.K, uNIR * u.eV / u.cm ** 3]) ic = models.InverseCompton( self._pdist(p), seed_photon_fields=seedspec, Eemin=1 * u.GeV, Eemax=100 * u.TeV, - Eed=100) + Eed=100, + ) - return ic.flux(Eph, distance=distance * u.kpc).to('1/(s cm2 keV)') + return ic.flux(Eph, distance=distance * u.kpc).to("1/(s cm2 keV)") class Synchrotron(SherpaModelECPL): @@ -164,15 +202,26 @@ class Synchrotron(SherpaModelECPL): `naima.models.ExponentialCutoffPowerLaw` documentation. """ - def __init__(self, name='Sync'): + def __init__(self, name="Sync"): self.name = name - self.B = Parameter(name, 'B', 1, min=0, max=10, frozen=True, units='G') + self.B = Parameter(name, "B", 1, min=0, max=10, frozen=True, units="G") # add ECPL params super(Synchrotron, self).__init__(name=name) # Initialize model - ArithmeticModel.__init__(self, name, (self.index, self.ref, self.ampl, - self.cutoff, self.beta, self.B, - self.distance, self.verbose)) + ArithmeticModel.__init__( + self, + name, + ( + self.index, + self.ref, + self.ampl, + self.cutoff, + self.beta, + self.B, + self.distance, + self.verbose, + ), + ) self._use_caching = True self.cache = 10 @@ -180,7 +229,7 @@ def flux(self, p, Eph): index, ref, ampl, cutoff, beta, B, distance, verbose = p sy = models.Synchrotron(self._pdist(p), B=B * u.G) - return sy.flux(Eph, distance=distance * u.kpc).to('1/(s cm2 keV)') + return sy.flux(Eph, distance=distance * u.kpc).to("1/(s cm2 keV)") class Bremsstrahlung(SherpaModelECPL): @@ -191,33 +240,60 @@ class Bremsstrahlung(SherpaModelECPL): `naima.models.ExponentialCutoffPowerLaw` documentation. 
""" - def __init__(self, name='Bremsstrahlung'): + def __init__(self, name="Bremsstrahlung"): self.name = name self.n0 = Parameter( - name, 'n0', 1, min=0, max=1e20, frozen=True, units='1/cm3') + name, "n0", 1, min=0, max=1e20, frozen=True, units="1/cm3" + ) self.weight_ee = Parameter( - name, 'weight_ee', 1.088, min=0, max=10, frozen=True) + name, "weight_ee", 1.088, min=0, max=10, frozen=True + ) self.weight_ep = Parameter( - name, 'weight_ep', 1.263, min=0, max=10, frozen=True) + name, "weight_ep", 1.263, min=0, max=10, frozen=True + ) # add ECPL params super(Bremsstrahlung, self).__init__(name=name) # Initialize model - ArithmeticModel.__init__(self, name, ( - self.index, self.ref, self.ampl, self.cutoff, self.beta, self.n0, - self.weight_ee, self.weight_ep, self.distance, self.verbose)) + ArithmeticModel.__init__( + self, + name, + ( + self.index, + self.ref, + self.ampl, + self.cutoff, + self.beta, + self.n0, + self.weight_ee, + self.weight_ep, + self.distance, + self.verbose, + ), + ) self._use_caching = True self.cache = 10 def flux(self, p, Eph): - (index, ref, ampl, cutoff, beta, n0, weight_ee, weight_ep, distance, - verbose) = p + ( + index, + ref, + ampl, + cutoff, + beta, + n0, + weight_ee, + weight_ep, + distance, + verbose, + ) = p brems = models.Bremsstrahlung( self._pdist(p), - n0=n0 / u.cm**3, + n0=n0 / u.cm ** 3, weight_ee=weight_ee, - weight_ep=weight_ep) + weight_ep=weight_ep, + ) - return brems.flux(Eph, distance=distance * u.kpc).to('1/(s cm2 keV)') + return brems.flux(Eph, distance=distance * u.kpc).to("1/(s cm2 keV)") class PionDecay(SherpaModelECPL): @@ -228,20 +304,31 @@ class PionDecay(SherpaModelECPL): `naima.models.ExponentialCutoffPowerLaw` documentation. """ - def __init__(self, name='pp'): + def __init__(self, name="pp"): self.name = name - self.nh = Parameter(name, 'nH', 1, min=0, frozen=True, units='1/cm3') + self.nh = Parameter(name, "nH", 1, min=0, frozen=True, units="1/cm3") # add ECPL params super(PionDecay, self).__init__(name=name) # Initialize model - ArithmeticModel.__init__(self, name, (self.index, self.ref, self.ampl, - self.cutoff, self.beta, self.nh, - self.distance, self.verbose)) + ArithmeticModel.__init__( + self, + name, + ( + self.index, + self.ref, + self.ampl, + self.cutoff, + self.beta, + self.nh, + self.distance, + self.verbose, + ), + ) self._use_caching = True self.cache = 10 def flux(self, p, Eph): index, ref, ampl, cutoff, beta, nh, distance, verbose = p - pp = models.PionDecay(self._pdist(p), nh=nh * u.Unit('1/cm3')) + pp = models.PionDecay(self._pdist(p), nh=nh * u.Unit("1/cm3")) - return pp.flux(Eph, distance=distance * u.kpc).to('1/(s cm2 keV)') + return pp.flux(Eph, distance=distance * u.kpc).to("1/(s cm2 keV)") diff --git a/naima/tests/fixtures.py b/naima/tests/fixtures.py index 92f5e9f1..4a8d6463 100644 --- a/naima/tests/fixtures.py +++ b/naima/tests/fixtures.py @@ -7,8 +7,9 @@ from astropy.tests.helper import pytest from ..core import run_sampler, uniform_prior + # Read data -fname = get_pkg_data_filename('data/CrabNebula_HESS_ipac.dat') +fname = get_pkg_data_filename("data/CrabNebula_HESS_ipac.dat") data_table = ascii.read(fname) # Model definition @@ -23,51 +24,72 @@ def cutoffexp(pars, data): - 3: cutoff exponent (beta) """ - x = data['energy'].copy() + x = data["energy"].copy() # take logarithmic mean of first and last data points as normalization # energy x0 = np.sqrt(x[0] * x[-1]) N = np.exp(pars[0]) gamma = pars[1] - ecut = (10**pars[2]) * u.TeV + ecut = (10 ** pars[2]) * u.TeV # beta = pars[3] - beta = 1. 
+ beta = 1.0 - flux = N * (x / x0) ** -gamma * np.exp( - -(x / ecut) ** beta) * u.Unit('1/(cm2 s TeV)') + flux = ( + N + * (x / x0) ** -gamma + * np.exp(-(x / ecut) ** beta) + * u.Unit("1/(cm2 s TeV)") + ) # save a model with different energies than the data - ene = np.logspace(np.log10(x[0].value) - 1, - np.log10(x[-1].value) + 1, 100) * x.unit - model = (N * (ene / x0) ** -gamma * - np.exp(-(ene / ecut) ** beta)) * u.Unit('1/(cm2 s TeV)') + ene = ( + np.logspace(np.log10(x[0].value) - 1, np.log10(x[-1].value) + 1, 100) + * x.unit + ) + model = ( + N * (ene / x0) ** -gamma * np.exp(-(ene / ecut) ** beta) + ) * u.Unit("1/(cm2 s TeV)") # save a particle energy distribution - model_part = (N * (ene / x0) ** -gamma * - np.exp(-(ene / ecut) ** beta)) * u.Unit('1/(TeV)') + model_part = ( + N * (ene / x0) ** -gamma * np.exp(-(ene / ecut) ** beta) + ) * u.Unit("1/(TeV)") # save a broken powerlaw in luminosity units - _model1 = N * np.where(x <= x0, - (x / x0) ** -(gamma - 0.5), - (x / x0) ** -(gamma + 0.5) - ) * u.Unit('1/(cm2 s TeV)') + _model1 = ( + N + * np.where( + x <= x0, (x / x0) ** -(gamma - 0.5), (x / x0) ** -(gamma + 0.5) + ) + * u.Unit("1/(cm2 s TeV)") + ) - model1 = (_model1 * (x ** 2) * 4 * np.pi * (2 * u.kpc) ** 2).to('erg/s') + model1 = (_model1 * (x ** 2) * 4 * np.pi * (2 * u.kpc) ** 2).to("erg/s") # save a model with no units to check that it is dealt with gracefully model2 = 1e-10 * np.ones(len(x)) # save a model with wrong length to check that it is dealt with gracefully - model3 = 1e-10 * np.ones(len(x) * 2) * u.Unit('erg/s') + model3 = 1e-10 * np.ones(len(x) * 2) * u.Unit("erg/s") # add a scalar value to test plot_distribution - model4 = np.trapz(model, ene).to('1/(cm2 s)') + model4 = np.trapz(model, ene).to("1/(cm2 s)") # and without units model5 = model4.value # save flux model as tuple with energies and without - return (flux, (x, flux), (ene, model), (ene, model_part), model1, model2, - model3, (x, model3), model4, model5) + return ( + flux, + (x, flux), + (ene, model), + (ene, model_part), + model1, + model2, + model3, + (x, model3), + model4, + model5, + ) def simple_cutoffexp(pars, data): @@ -87,21 +109,39 @@ def lnprior(pars): return logprob + # Run sampler @pytest.fixture def sampler(): - p0=np.array((np.log(1.8e-12),2.4,np.log10(15.0),)) - labels=['log(norm)','index','log10(cutoff)'] + p0 = np.array((np.log(1.8e-12), 2.4, np.log10(15.0))) + labels = ["log(norm)", "index", "log10(cutoff)"] sampler, pos = run_sampler( - data_table=data_table, p0=p0, labels=labels, model=cutoffexp, - prior=lnprior, nwalkers=10, nburn=2, nrun=2, threads=1) + data_table=data_table, + p0=p0, + labels=labels, + model=cutoffexp, + prior=lnprior, + nwalkers=10, + nburn=2, + nrun=2, + threads=1, + ) return sampler + @pytest.fixture def simple_sampler(): - p0=np.array((np.log(1.8e-12),2.4,np.log10(15.0),)) - labels=['log(norm)','index','log10(cutoff)'] + p0 = np.array((np.log(1.8e-12), 2.4, np.log10(15.0))) + labels = ["log(norm)", "index", "log10(cutoff)"] sampler, pos = run_sampler( - data_table=data_table, p0=p0, labels=labels, model=simple_cutoffexp, - prior=lnprior, nwalkers=10, nburn=2, nrun=2, threads=1) + data_table=data_table, + p0=p0, + labels=labels, + model=simple_cutoffexp, + prior=lnprior, + nwalkers=10, + nburn=2, + nrun=2, + threads=1, + ) return sampler diff --git a/naima/tests/setup_package.py b/naima/tests/setup_package.py index f34fe9ed..08f73eaf 100644 --- a/naima/tests/setup_package.py +++ b/naima/tests/setup_package.py @@ -1,3 +1,2 @@ def get_package_data(): - return { - 
_ASTROPY_PACKAGE_NAME_ + '.tests': ['coveragerc', 'data/*.dat']} + return {_ASTROPY_PACKAGE_NAME_ + ".tests": ["coveragerc", "data/*.dat"]} diff --git a/naima/tests/test_functionfit.py b/naima/tests/test_functionfit.py index f035a076..65de3ae8 100644 --- a/naima/tests/test_functionfit.py +++ b/naima/tests/test_functionfit.py @@ -5,42 +5,51 @@ import astropy.units as u from astropy.io import ascii -from ..core import (run_sampler, get_sampler, uniform_prior, normal_prior, - lnprob) +from ..core import ( + run_sampler, + get_sampler, + uniform_prior, + normal_prior, + lnprob, +) try: import emcee + HAS_EMCEE = True except ImportError: HAS_EMCEE = False try: import scipy + HAS_SCIPY = True except ImportError: HAS_SCIPY = False try: import matplotlib - matplotlib.use('Agg') + + matplotlib.use("Agg") HAS_MATPLOTLIB = True except ImportError: HAS_MATPLOTLIB = False # Read data -fname = get_pkg_data_filename('data/CrabNebula_HESS_ipac.dat') +fname = get_pkg_data_filename("data/CrabNebula_HESS_ipac.dat") data_table = ascii.read(fname) # Read fake SED -fname0 = get_pkg_data_filename('data/Fake_ipac_sed.dat') +fname0 = get_pkg_data_filename("data/Fake_ipac_sed.dat") data_table_sed = ascii.read(fname0) # Read spectrum with symmetric flux errors -fname2 = get_pkg_data_filename('data/CrabNebula_HESS_ipac_symmetric.dat') +fname2 = get_pkg_data_filename("data/CrabNebula_HESS_ipac_symmetric.dat") data_table2 = ascii.read(fname2) # Model definition + def cutoffexp(pars, data): """ Powerlaw with exponential cutoff @@ -52,7 +61,7 @@ def cutoffexp(pars, data): - 3: cutoff exponent (beta) """ - x = data['energy'] + x = data["energy"] # take logarithmic mean of first and last data points as normalization # energy x0 = np.sqrt(x[0] * x[-1]) @@ -61,147 +70,262 @@ def cutoffexp(pars, data): gamma = pars[1] ecut = pars[2] * u.TeV # beta = pars[3] - beta = 1. + beta = 1.0 + + return ( + N + * (x / x0) ** -gamma + * np.exp(-(x / ecut) ** beta) + * u.Unit("1/(cm2 s TeV)") + ) - return N * (x / x0) ** -gamma * np.exp(-(x / ecut) ** beta) * u.Unit('1/(cm2 s TeV)') def cutoffexp_sed(pars, data): - x = data['energy'] + x = data["energy"] x0 = np.sqrt(x[0] * x[-1]) N = pars[0] gamma = pars[1] ecut = pars[2] * u.TeV - return N * (x / x0) ** -gamma * np.exp(-(x / ecut)) * u.Unit('erg/(cm2 s)') + return N * (x / x0) ** -gamma * np.exp(-(x / ecut)) * u.Unit("erg/(cm2 s)") + def cutoffexp_wrong(pars, data): - return data['energy'] * u.m + return data["energy"] * u.m + # Prior definition + def lnprior(pars): """ Return probability of parameter values according to prior knowledge. 
Parameter limits should be done here through uniform prior ditributions """ - logprob = uniform_prior(pars[0], 0., np.inf) \ - + normal_prior(pars[1], 1.4, 0.5) \ - + uniform_prior(pars[2], 0., np.inf) + logprob = ( + uniform_prior(pars[0], 0.0, np.inf) + + normal_prior(pars[1], 1.4, 0.5) + + uniform_prior(pars[2], 0.0, np.inf) + ) return logprob + # Set initial parameters -p0 = np.array((1e-9, 1.4, 14.0,)) -labels = ['norm', 'index', 'cutoff'] +p0 = np.array((1e-9, 1.4, 14.0)) +labels = ["norm", "index", "cutoff"] # Initialize in different ways to test argument validation -@pytest.mark.skipif('not HAS_EMCEE') +@pytest.mark.skipif("not HAS_EMCEE") def test_init(): sampler, pos = get_sampler( - data_table=data_table, p0=p0, labels=labels, model=cutoffexp, - prior=lnprior, nwalkers=10, nburn=2, threads=1) + data_table=data_table, + p0=p0, + labels=labels, + model=cutoffexp, + prior=lnprior, + nwalkers=10, + nburn=2, + threads=1, + ) sampler, pos = run_sampler( - data_table=data_table, p0=p0, labels=labels, model=cutoffexp, - prior=lnprior, nwalkers=10, nburn=2, nrun=2, threads=1) + data_table=data_table, + p0=p0, + labels=labels, + model=cutoffexp, + prior=lnprior, + nwalkers=10, + nburn=2, + nrun=2, + threads=1, + ) # test that the CL keyword has been correctly read - assert np.all(sampler.data['cl'] == 0.99) + assert np.all(sampler.data["cl"] == 0.99) -@pytest.mark.skipif('not HAS_EMCEE') +@pytest.mark.skipif("not HAS_EMCEE") def test_inf_prior(): pars = p0 pars[0] = -1e-9 _ = lnprob(pars, data_table, cutoffexp, lnprior) -@pytest.mark.skipif('not HAS_EMCEE') +@pytest.mark.skipif("not HAS_EMCEE") def test_sed_conversion_in_lnprobmodel(): sampler, pos = get_sampler( - data_table=data_table, p0=p0, labels=labels, model=cutoffexp_sed, - prior=lnprior, nwalkers=10, nburn=2, threads=1) - - -@pytest.mark.skipif('not HAS_EMCEE') + data_table=data_table, + p0=p0, + labels=labels, + model=cutoffexp_sed, + prior=lnprior, + nwalkers=10, + nburn=2, + threads=1, + ) + + +@pytest.mark.skipif("not HAS_EMCEE") def test_wrong_model_units(): # test exception raised when model and data spectra cannot be compared with pytest.raises(u.UnitsError): sampler, pos = get_sampler( - data_table=data_table, p0=p0, labels=labels, model=cutoffexp_wrong, - prior=lnprior, nwalkers=10, nburn=2, threads=1) - - -@pytest.mark.skipif('not HAS_EMCEE or not HAS_SCIPY') + data_table=data_table, + p0=p0, + labels=labels, + model=cutoffexp_wrong, + prior=lnprior, + nwalkers=10, + nburn=2, + threads=1, + ) + + +@pytest.mark.skipif("not HAS_EMCEE or not HAS_SCIPY") def test_prefit(): sampler, pos = get_sampler( - data_table=data_table, p0=p0, labels=labels, model=cutoffexp, - prior=lnprior, nwalkers=10, nburn=5, threads=1, prefit=True) - - -@pytest.mark.skipif('not HAS_EMCEE or not HAS_SCIPY or not HAS_MATPLOTLIB') + data_table=data_table, + p0=p0, + labels=labels, + model=cutoffexp, + prior=lnprior, + nwalkers=10, + nburn=5, + threads=1, + prefit=True, + ) + + +@pytest.mark.skipif("not HAS_EMCEE or not HAS_SCIPY or not HAS_MATPLOTLIB") def test_interactive(): sampler, pos = get_sampler( - data_table=data_table, p0=p0, labels=labels, model=cutoffexp, - prior=lnprior, nwalkers=10, nburn=5, threads=1, interactive=True) - - -@pytest.mark.skipif('not HAS_EMCEE') + data_table=data_table, + p0=p0, + labels=labels, + model=cutoffexp, + prior=lnprior, + nwalkers=10, + nburn=5, + threads=1, + interactive=True, + ) + + +@pytest.mark.skipif("not HAS_EMCEE") def test_init_symmetric_dflux(): # symmetric data_table errors sampler, pos = 
run_sampler( - data_table=data_table2, p0=p0, labels=labels, model=cutoffexp, - prior=lnprior, nwalkers=10, nburn=2, nrun=2, threads=1) - - -@pytest.mark.skipif('not HAS_EMCEE') + data_table=data_table2, + p0=p0, + labels=labels, + model=cutoffexp, + prior=lnprior, + nwalkers=10, + nburn=2, + nrun=2, + threads=1, + ) + + +@pytest.mark.skipif("not HAS_EMCEE") def test_init_labels(): # labels - sampler, pos = run_sampler(data_table=data_table, p0=p0, labels=None, - model=cutoffexp, prior=lnprior, nwalkers=10, - nrun=2, nburn=2, threads=1) sampler, pos = run_sampler( - data_table=data_table, p0=p0, labels=labels[:2], model=cutoffexp, - prior=lnprior, nwalkers=10, nrun=2, nburn=2, threads=1) - - -@pytest.mark.skipif('not HAS_EMCEE') + data_table=data_table, + p0=p0, + labels=None, + model=cutoffexp, + prior=lnprior, + nwalkers=10, + nrun=2, + nburn=2, + threads=1, + ) + sampler, pos = run_sampler( + data_table=data_table, + p0=p0, + labels=labels[:2], + model=cutoffexp, + prior=lnprior, + nwalkers=10, + nrun=2, + nburn=2, + threads=1, + ) + + +@pytest.mark.skipif("not HAS_EMCEE") def test_init_prior(): # no prior sampler, pos = run_sampler( - data_table=data_table, p0=p0, labels=labels, model=cutoffexp, - prior=None, nwalkers=10, nrun=2, nburn=2, threads=1) - - -@pytest.mark.skipif('not HAS_EMCEE') + data_table=data_table, + p0=p0, + labels=labels, + model=cutoffexp, + prior=None, + nwalkers=10, + nrun=2, + nburn=2, + threads=1, + ) + + +@pytest.mark.skipif("not HAS_EMCEE") def test_init_exception_model(): # test exception raised when no model or data_table are provided with pytest.raises(TypeError): - sampler, pos = get_sampler(data_table=data_table, p0=p0, labels=labels, - prior=lnprior, nwalkers=10, nburn=2, - threads=1) + sampler, pos = get_sampler( + data_table=data_table, + p0=p0, + labels=labels, + prior=lnprior, + nwalkers=10, + nburn=2, + threads=1, + ) -@pytest.mark.skipif('not HAS_EMCEE') +@pytest.mark.skipif("not HAS_EMCEE") def test_init_exception_data(): with pytest.raises(TypeError): - sampler, pos = get_sampler(p0=p0, labels=labels, model=cutoffexp, - prior=lnprior, nwalkers=10, nburn=2, - threads=1) + sampler, pos = get_sampler( + p0=p0, + labels=labels, + model=cutoffexp, + prior=lnprior, + nwalkers=10, + nburn=2, + threads=1, + ) -@pytest.mark.skipif('not HAS_EMCEE') +@pytest.mark.skipif("not HAS_EMCEE") def test_multiple_data_tables(): - sampler, pos = get_sampler(data_table=[data_table_sed, data_table], p0=p0, - labels=labels, model=cutoffexp, prior=lnprior, - nwalkers=10, nburn=2, threads=1) - - -@pytest.mark.skipif('not HAS_EMCEE') + sampler, pos = get_sampler( + data_table=[data_table_sed, data_table], + p0=p0, + labels=labels, + model=cutoffexp, + prior=lnprior, + nwalkers=10, + nburn=2, + threads=1, + ) + + +@pytest.mark.skipif("not HAS_EMCEE") def test_data_table_in_list(): - sampler, pos = get_sampler(data_table=[data_table], p0=p0, labels=labels, - model=cutoffexp, prior=lnprior, nwalkers=10, - nburn=2, threads=1) + sampler, pos = get_sampler( + data_table=[data_table], + p0=p0, + labels=labels, + model=cutoffexp, + prior=lnprior, + nwalkers=10, + nburn=2, + threads=1, + ) diff --git a/naima/tests/test_interactive.py b/naima/tests/test_interactive.py index a10fb512..ed541acf 100644 --- a/naima/tests/test_interactive.py +++ b/naima/tests/test_interactive.py @@ -9,8 +9,10 @@ try: import matplotlib - matplotlib.use('Agg') + + matplotlib.use("Agg") import matplotlib.pyplot as plt + HAS_MATPLOTLIB = True except: HAS_MATPLOTLIB = False @@ -19,52 +21,66 @@ from 
..model_fitter import InteractiveModelFitter # Read data -fname = get_pkg_data_filename('data/CrabNebula_HESS_ipac.dat') +fname = get_pkg_data_filename("data/CrabNebula_HESS_ipac.dat") data = ascii.read(fname) -def modelfn(pars,data): - ECPL = ExponentialCutoffPowerLaw(10**pars[0] * u.Unit('1/(cm2 s TeV)'), 1*u.TeV, - pars[1], 10**pars[2] * u.TeV) + +def modelfn(pars, data): + ECPL = ExponentialCutoffPowerLaw( + 10 ** pars[0] * u.Unit("1/(cm2 s TeV)"), + 1 * u.TeV, + pars[1], + 10 ** pars[2] * u.TeV, + ) return ECPL(data) + def modelfn2(pars, data): - return modelfn(pars, data), (1,2,3)*u.m + return modelfn(pars, data), (1, 2, 3) * u.m -labels = ['log10(norm)', 'index', 'log10(cutoff)'] + +labels = ["log10(norm)", "index", "log10(cutoff)"] p0 = np.array((-12, 2.7, np.log10(14))) -e_range = [100*u.GeV, 100*u.TeV] +e_range = [100 * u.GeV, 100 * u.TeV] + -@pytest.mark.skipif('not HAS_MATPLOTLIB') +@pytest.mark.skipif("not HAS_MATPLOTLIB") def test_modelwidget_inputs(): for dt in [data, None]: for er in [e_range, None]: for model in [modelfn, modelfn2]: - imf = InteractiveModelFitter(model, p0, labels=labels, - data=dt, e_range=er) - imf.update('test') + imf = InteractiveModelFitter( + model, p0, labels=labels, data=dt, e_range=er + ) + imf.update("test") for labs in [labels, labels[:2], None]: imf = InteractiveModelFitter(model, p0, labels=labs) for sed in [True, False]: for dt in [data, None]: - imf = InteractiveModelFitter(model, p0, data=dt, labels=labels, sed=sed) + imf = InteractiveModelFitter( + model, p0, data=dt, labels=labels, sed=sed + ) p0[1] = -2.7 imf = InteractiveModelFitter(model, p0, labels=labels) - labels[0] = 'norm' + labels[0] = "norm" imf = InteractiveModelFitter(model, p0, labels=labels) - plt.close('all') + plt.close("all") + -@pytest.mark.skipif('not HAS_MATPLOTLIB') +@pytest.mark.skipif("not HAS_MATPLOTLIB") def test_modelwidget_funcs(): - imf = InteractiveModelFitter(modelfn, p0, data=data, labels=labels, auto_update=False) + imf = InteractiveModelFitter( + modelfn, p0, data=data, labels=labels, auto_update=False + ) assert imf.autoupdate is False - imf.update_autoupdate('test') + imf.update_autoupdate("test") assert imf.autoupdate is True imf.parsliders[0].val *= 2 - imf.update_if_auto('test') - imf.close_fig('test') + imf.update_if_auto("test") + imf.close_fig("test") imf = InteractiveModelFitter(modelfn, p0, labels=labels, auto_update=False) - imf.update('test') - plt.close('all') + imf.update("test") + plt.close("all") diff --git a/naima/tests/test_models.py b/naima/tests/test_models.py index e1744b4e..6c0f13d6 100644 --- a/naima/tests/test_models.py +++ b/naima/tests/test_models.py @@ -11,6 +11,7 @@ try: import scipy + HAS_SCIPY = True except ImportError: HAS_SCIPY = False @@ -22,38 +23,42 @@ alpha_1 = 1.5 alpha_2 = 2.5 -electron_properties = {'Eemin': 100 * u.GeV, 'Eemax': 1 * u.PeV} -proton_properties = {'Epmax': 1 * u.PeV} +electron_properties = {"Eemin": 100 * u.GeV, "Eemax": 1 * u.PeV} +proton_properties = {"Epmax": 1 * u.PeV} energy = np.logspace(0, 15, 1000) * u.eV from astropy.table import QTable, Table + data = QTable() -data['energy'] = energy +data["energy"] = energy data2 = Table() -data2['energy'] = energy +data2["energy"] = energy from astropy.constants import m_e, c, sigma_sb, hbar -pdist_unit = 1 / u.Unit(m_e * c**2) + +pdist_unit = 1 / u.Unit(m_e * c ** 2) @pytest.fixture def particle_dists(): from ..models import ExponentialCutoffPowerLaw, PowerLaw, BrokenPowerLaw - ECPL = ExponentialCutoffPowerLaw(amplitude=1 * pdist_unit, - e_0=e_0, - 
alpha=alpha, - e_cutoff=e_cutoff) + + ECPL = ExponentialCutoffPowerLaw( + amplitude=1 * pdist_unit, e_0=e_0, alpha=alpha, e_cutoff=e_cutoff + ) PL = PowerLaw(amplitude=1 * pdist_unit, e_0=e_0, alpha=alpha) - BPL = BrokenPowerLaw(amplitude=1 * pdist_unit, - e_0=e_0, - e_break=e_break, - alpha_1=alpha_1, - alpha_2=alpha_2) + BPL = BrokenPowerLaw( + amplitude=1 * pdist_unit, + e_0=e_0, + e_break=e_break, + alpha_1=alpha_1, + alpha_2=alpha_2, + ) return ECPL, PL, BPL -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_electron_synchrotron_lum(particle_dists): """ test sync calculation @@ -62,8 +67,11 @@ def test_electron_synchrotron_lum(particle_dists): ECPL, PL, BPL = particle_dists - lum_ref = [0.00025231296225663107, 0.03316715765695228, - 0.00044597089198025806] + lum_ref = [ + 0.00025231296225663107, + 0.03316715765695228, + 0.00044597089198025806, + ] We_ref = [5064124672.902273, 11551172166.866821, 926633861.2898524] Wes = [] @@ -71,9 +79,9 @@ def test_electron_synchrotron_lum(particle_dists): for pdist in particle_dists: sy = Synchrotron(pdist, **electron_properties) - Wes.append(sy.We.to('erg').value) + Wes.append(sy.We.to("erg").value) - lsy = trapz_loglog(sy.flux(energy, 0) * energy, energy).to('erg/s') + lsy = trapz_loglog(sy.flux(energy, 0) * energy, energy).to("erg/s") assert lsy.unit == u.erg / u.s lsys.append(lsy.value) @@ -86,12 +94,12 @@ def test_electron_synchrotron_lum(particle_dists): sy.flux(data) sy.flux(data2) - lsy = trapz_loglog(sy.flux(energy, 0) * energy, energy).to('erg/s') - assert (lsy.unit == u.erg / u.s) + lsy = trapz_loglog(sy.flux(energy, 0) * energy, energy).to("erg/s") + assert lsy.unit == u.erg / u.s assert_allclose(lsy.value, 31374131.90312505) -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_proton_synchrotron_lum(particle_dists): """ test sync calculation @@ -109,9 +117,9 @@ def test_proton_synchrotron_lum(particle_dists): for pdist in particle_dists: sy = ProtonSynchrotron(pdist, **proton_properties) - Wps.append(sy.Wp.to('erg').value) + Wps.append(sy.Wp.to("erg").value) - lsy = trapz_loglog(sy.flux(energy, 0) * energy, energy).to('erg/s') + lsy = trapz_loglog(sy.flux(energy, 0) * energy, energy).to("erg/s") assert lsy.unit == u.erg / u.s lsys.append(lsy.value) @@ -124,23 +132,24 @@ def test_proton_synchrotron_lum(particle_dists): sy.flux(data) sy.flux(data2) - lsy = trapz_loglog(sy.flux(energy, 0) * energy, energy).to('erg/s') - assert (lsy.unit == u.erg / u.s) + lsy = trapz_loglog(sy.flux(energy, 0) * energy, energy).to("erg/s") + assert lsy.unit == u.erg / u.s print(lsy) # assert_allclose(lsy.value, 31374131.90312505) -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_synchrotron_traits(particle_dists): from ..models import Synchrotron + ECPL, _, _ = particle_dists sy = Synchrotron(ECPL, Eemin=1 * u.GeV, Eemax=1 * u.PeV) sy.Eemin = 1 * u.TeV - assert sy.gmin == float(1 * u.TeV / (m_e * c **2)) + assert sy.gmin == float(1 * u.TeV / (m_e * c ** 2)) sy.Eemax = 100 * u.TeV - assert sy.gmax == float(100 * u.TeV / (m_e * c **2)) + assert sy.gmax == float(100 * u.TeV / (m_e * c ** 2)) sy.nEed = 10 assert sy.ngd == 10 @@ -149,7 +158,7 @@ def test_synchrotron_traits(particle_dists): sy.Eemin = 10 * u.m -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_bolometric_luminosity(particle_dists): """ test sync calculation @@ -165,7 +174,7 @@ def test_bolometric_luminosity(particle_dists): sy.sed(energy, distance=0) 
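The energy-content helpers exercised by the next tests follow one pattern:
compute_We (compute_Wp for protons) integrates the particle distribution
between optional bounds, and set_We rescales the distribution amplitude so
that the integral matches a requested energy. A minimal sketch, with an
illustrative field strength and energies:

    import astropy.units as u
    from naima.models import PowerLaw, Synchrotron

    pdist = PowerLaw(1e33 * u.Unit("1/eV"), 1 * u.TeV, 2.0)
    sy = Synchrotron(pdist, B=100 * u.uG)

    We = sy.compute_We(Eemin=1 * u.TeV)  # energy in electrons above 1 TeV
    sy.set_We(1e48 * u.erg, Eemin=1 * u.TeV)  # rescale amplitude to match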
-@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_compute_We(particle_dists): """ test sync calculation @@ -190,7 +199,7 @@ def test_compute_We(particle_dists): pp.compute_Wp(Epmin=Epmin, Epmax=Epmax) -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") @pytest.mark.parametrize("Eemin", [1 * u.GeV, 10 * u.GeV, None]) @pytest.mark.parametrize("Eemax", [100 * u.TeV, None]) def test_set_We(particle_dists, Eemin, Eemax): @@ -208,22 +217,22 @@ def test_set_We(particle_dists, Eemin, Eemax): sy.set_We(W, Eemin, Eemax) assert_allclose(W, sy.compute_We(Eemin, Eemax)) - sy.set_We(W, Eemin, Eemax, amplitude_name='amplitude') + sy.set_We(W, Eemin, Eemax, amplitude_name="amplitude") assert_allclose(W, sy.compute_We(Eemin, Eemax)) pp.set_Wp(W, Eemin, Eemax) assert_allclose(W, pp.compute_Wp(Eemin, Eemax)) - pp.set_Wp(W, Eemin, Eemax, amplitude_name='amplitude') + pp.set_Wp(W, Eemin, Eemax, amplitude_name="amplitude") assert_allclose(W, pp.compute_Wp(Eemin, Eemax)) with pytest.raises(AttributeError): - sy.set_We(W, amplitude_name='norm') + sy.set_We(W, amplitude_name="norm") with pytest.raises(AttributeError): - pp.set_Wp(W, amplitude_name='norm') + pp.set_Wp(W, amplitude_name="norm") -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_bremsstrahlung_lum(particle_dists): """ test sync calculation @@ -236,14 +245,15 @@ def test_bremsstrahlung_lum(particle_dists): energy2 = np.logspace(8, 14, 100) * u.eV brems = Bremsstrahlung(ECPL, n0=1 * u.cm ** -3, Eemin=m_e * c ** 2) - lbrems = trapz_loglog(brems.flux(energy2, 0) * energy2, - energy2).to('erg/s') + lbrems = trapz_loglog(brems.flux(energy2, 0) * energy2, energy2).to( + "erg/s" + ) lum_ref = 2.3064095039069847e-05 assert_allclose(lbrems.value, lum_ref) -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_inverse_compton_lum(particle_dists): """ test IC calculation @@ -252,27 +262,30 @@ def test_inverse_compton_lum(particle_dists): ECPL, PL, BPL = particle_dists - lum_ref = [0.00027822017772343816, 0.004821189282097695, - 0.00012916583207749083] + lum_ref = [ + 0.00027822017772343816, + 0.004821189282097695, + 0.00012916583207749083, + ] lums = [] for pdist in particle_dists: ic = InverseCompton(pdist, **electron_properties) - lic = trapz_loglog(ic.flux(energy, 0) * energy, energy).to('erg/s') - assert (lic.unit == u.erg / u.s) + lic = trapz_loglog(ic.flux(energy, 0) * energy, energy).to("erg/s") + assert lic.unit == u.erg / u.s lums.append(lic.value) assert_allclose(lums, lum_ref) - ic = InverseCompton(ECPL, seed_photon_fields=['CMB', 'FIR', 'NIR']) + ic = InverseCompton(ECPL, seed_photon_fields=["CMB", "FIR", "NIR"]) ic.flux(data) ic.flux(data2) - lic = trapz_loglog(ic.flux(energy, 0) * energy, energy).to('erg/s') + lic = trapz_loglog(ic.flux(energy, 0) * energy, energy).to("erg/s") assert_allclose(lic.value, 0.0005833034007064158) -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_anisotropic_inverse_compton_lum(particle_dists): """ test IC calculation @@ -287,18 +300,21 @@ def test_anisotropic_inverse_compton_lum(particle_dists): lums = [] for angle in angles: - ic = InverseCompton(PL, - seed_photon_fields=[['Star', 20000 * u.K, 0.1 * - u.erg / u.cm**3, angle],], - **electron_properties) - lic = trapz_loglog(ic.flux(energy, 0) * energy, energy).to('erg/s') - assert (lic.unit == u.erg / u.s) + ic = InverseCompton( + PL, + seed_photon_fields=[ + ["Star", 20000 * u.K, 0.1 * u.erg / u.cm ** 3, angle] + 
], + **electron_properties + ) + lic = trapz_loglog(ic.flux(energy, 0) * energy, energy).to("erg/s") + assert lic.unit == u.erg / u.s lums.append(lic.value) assert_allclose(lums, lum_ref) -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_monochromatic_inverse_compton(particle_dists): """ test IC monochromatic against khangulyan et al. @@ -309,35 +325,41 @@ def test_monochromatic_inverse_compton(particle_dists): # compute a blackbody spectrum with 1 eV/cm3 at 30K from astropy.analytic_functions import blackbody_nu + Ephbb = np.logspace(-3.5, -1.5, 100) * u.eV - lambdabb = Ephbb.to('AA', equivalencies=u.equivalencies.spectral()) + lambdabb = Ephbb.to("AA", equivalencies=u.equivalencies.spectral()) T = 30 * u.K - w = 1 * u.eV / u.cm**3 - bb = (blackbody_nu(lambdabb, T) * 2 * u.sr / c.cgs - / Ephbb / hbar).to('1/(cm3 eV)') - Ebbmax = Ephbb[np.argmax(Ephbb**2 * bb)] - - ar = (4 * sigma_sb / c).to('erg/(cm3 K4)') - bb *= (w / (ar * T**4)).decompose() - - eopts = {'Eemax': 10000 * u.GeV, 'Eemin': 10 * u.GeV, 'nEed': 1000} - IC_khang = InverseCompton(PL, seed_photon_fields=[['bb', T, w]], **eopts) - IC_mono = InverseCompton(PL, - seed_photon_fields=[['mono', Ebbmax, w]], - **eopts) - IC_bb = InverseCompton(PL, seed_photon_fields=[['bb2', Ephbb, bb]], **eopts) - IC_bb_ene = InverseCompton(PL, - seed_photon_fields=[['bb2', Ephbb, Ephbb**2 * bb]], **eopts) + w = 1 * u.eV / u.cm ** 3 + bb = (blackbody_nu(lambdabb, T) * 2 * u.sr / c.cgs / Ephbb / hbar).to( + "1/(cm3 eV)" + ) + Ebbmax = Ephbb[np.argmax(Ephbb ** 2 * bb)] + + ar = (4 * sigma_sb / c).to("erg/(cm3 K4)") + bb *= (w / (ar * T ** 4)).decompose() + + eopts = {"Eemax": 10000 * u.GeV, "Eemin": 10 * u.GeV, "nEed": 1000} + IC_khang = InverseCompton(PL, seed_photon_fields=[["bb", T, w]], **eopts) + IC_mono = InverseCompton( + PL, seed_photon_fields=[["mono", Ebbmax, w]], **eopts + ) + IC_bb = InverseCompton( + PL, seed_photon_fields=[["bb2", Ephbb, bb]], **eopts + ) + IC_bb_ene = InverseCompton( + PL, seed_photon_fields=[["bb2", Ephbb, Ephbb ** 2 * bb]], **eopts + ) Eph = np.logspace(-1, 1, 3) * u.GeV assert_allclose(IC_khang.sed(Eph).value, IC_mono.sed(Eph).value, rtol=1e-2) assert_allclose(IC_khang.sed(Eph).value, IC_bb.sed(Eph).value, rtol=1e-2) - assert_allclose(IC_khang.sed(Eph).value, IC_bb_ene.sed(Eph).value, - rtol=1e-2) + assert_allclose( + IC_khang.sed(Eph).value, IC_bb_ene.sed(Eph).value, rtol=1e-2 + ) -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_flux_sed(particle_dists): """ test IC calculation @@ -347,35 +369,43 @@ def test_flux_sed(particle_dists): ECPL, PL, BPL = particle_dists d1 = 2.5 * u.kpc - d2 = 10. 
* u.kpc - - ic = InverseCompton(ECPL, - seed_photon_fields=['CMB', 'FIR', 'NIR'], - **electron_properties) - - luminosity = trapz_loglog( - ic.flux(energy, 0) * energy, energy).to('erg/s').value + d2 = 10.0 * u.kpc - int_flux1 = trapz_loglog( - ic.flux(energy, d1) * energy, energy).to('erg/(s cm2)').value - int_flux2 = trapz_loglog( - ic.flux(energy, d2) * energy, energy).to('erg/(s cm2)').value + ic = InverseCompton( + ECPL, seed_photon_fields=["CMB", "FIR", "NIR"], **electron_properties + ) + + luminosity = ( + trapz_loglog(ic.flux(energy, 0) * energy, energy).to("erg/s").value + ) + + int_flux1 = ( + trapz_loglog(ic.flux(energy, d1) * energy, energy) + .to("erg/(s cm2)") + .value + ) + int_flux2 = ( + trapz_loglog(ic.flux(energy, d2) * energy, energy) + .to("erg/(s cm2)") + .value + ) # check distance scaling - assert_allclose(int_flux1 / int_flux2, (d2 / d1).value**2.) + assert_allclose(int_flux1 / int_flux2, (d2 / d1).value ** 2.0) # check values - assert_allclose(int_flux1, luminosity / (4 * np.pi * - (d1.to('cm').value)**2)) + assert_allclose( + int_flux1, luminosity / (4 * np.pi * (d1.to("cm").value) ** 2) + ) # check SED - sed1 = ic.sed(energy, d1).to('erg/(s cm2)').value - sed0 = (ic.flux(energy, 0) * energy**2).to('erg/s').value + sed1 = ic.sed(energy, d1).to("erg/(s cm2)").value + sed0 = (ic.flux(energy, 0) * energy ** 2).to("erg/s").value - assert_allclose(sed1, sed0 / (4 * np.pi * (d1.to('cm').value)**2)) + assert_allclose(sed1, sed0 / (4 * np.pi * (d1.to("cm").value) ** 2)) -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_ic_seed_input(particle_dists): """ test initialization of different input formats for seed photon fields @@ -384,28 +414,30 @@ def test_ic_seed_input(particle_dists): ECPL, PL, BPL = particle_dists - ic = InverseCompton(PL, seed_photon_fields='CMB') + ic = InverseCompton(PL, seed_photon_fields="CMB") - ic = InverseCompton(PL, seed_photon_fields=['CMB', 'FIR', 'NIR'],) + ic = InverseCompton(PL, seed_photon_fields=["CMB", "FIR", "NIR"]) Eph = (1, 10) * u.eV - phn = (3, 1) * u.Unit('1/(eV cm3)') - test_seeds = [['test', 5000 * u.K, 0], - ['array', Eph, phn], - ['array-energy', Eph, Eph**2 * phn], - ['mono', Eph[0], phn[0] * Eph[0]**2], - ['mono-array', Eph[:1], phn[:1] * Eph[:1]**2], - # from docs: - ['NIR', 50 * u.K, 1.5 * u.eV / u.cm**3], - ['star', 25000 * u.K, 3 * u.erg / u.cm**3, 120 * u.deg], - ['X-ray', [1, 10] * u.keV, [1, 1e-2] * 1 / (u.eV * u.cm**3)], - ['UV', 50 * u.eV, 15 * u.eV / u.cm**3],] + phn = (3, 1) * u.Unit("1/(eV cm3)") + test_seeds = [ + ["test", 5000 * u.K, 0], + ["array", Eph, phn], + ["array-energy", Eph, Eph ** 2 * phn], + ["mono", Eph[0], phn[0] * Eph[0] ** 2], + ["mono-array", Eph[:1], phn[:1] * Eph[:1] ** 2], + # from docs: + ["NIR", 50 * u.K, 1.5 * u.eV / u.cm ** 3], + ["star", 25000 * u.K, 3 * u.erg / u.cm ** 3, 120 * u.deg], + ["X-ray", [1, 10] * u.keV, [1, 1e-2] * 1 / (u.eV * u.cm ** 3)], + ["UV", 50 * u.eV, 15 * u.eV / u.cm ** 3], + ] for seed in test_seeds: - ic = InverseCompton(PL, seed_photon_fields=['CMB', seed]) + ic = InverseCompton(PL, seed_photon_fields=["CMB", seed]) -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_ic_seed_fluxes(particle_dists): """ test per seed flux computation @@ -416,27 +448,29 @@ def test_ic_seed_fluxes(particle_dists): ic = InverseCompton( PL, - seed_photon_fields=['CMB', - ['test', 5000 * u.K, 0], - ['test2', 5000 * u.K, 10 * u.eV / u.cm**3], - ['test3', 5000 * u.K, 10 * u.eV / u.cm**3, 90 * - u.deg],],) + 
seed_photon_fields=[ + "CMB", + ["test", 5000 * u.K, 0], + ["test2", 5000 * u.K, 10 * u.eV / u.cm ** 3], + ["test3", 5000 * u.K, 10 * u.eV / u.cm ** 3, 90 * u.deg], + ], + ) ene = np.logspace(-3, 0, 5) * u.TeV - for idx, name in enumerate(['CMB', 'test', 'test2', 'test3',]): + for idx, name in enumerate(["CMB", "test", "test2", "test3"]): icname = ic.sed(ene, seed=name) icnumber = ic.sed(ene, seed=idx) assert_allclose(icname, icnumber) with pytest.raises(ValueError): - _ = ic.sed(ene, seed='FIR') + _ = ic.sed(ene, seed="FIR") with pytest.raises(ValueError): _ = ic.sed(ene, seed=10) -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_pion_decay(particle_dists): """ test ProtonOZM @@ -460,12 +494,12 @@ def test_pion_decay(particle_dists): lpps_noLUT = [] for pdist in particle_dists: pp = PionDecay(pdist, useLUT=True, **proton_properties) - Wps.append(pp.Wp.to('erg').value) - lpp = trapz_loglog(pp.flux(energy, 0) * energy, energy).to('erg/s') - assert (lpp.unit == u.erg / u.s) + Wps.append(pp.Wp.to("erg").value) + lpp = trapz_loglog(pp.flux(energy, 0) * energy, energy).to("erg/s") + assert lpp.unit == u.erg / u.s lpps_LUT.append(lpp.value) pp.useLUT = False - lpp = trapz_loglog(pp.flux(energy, 0) * energy, energy).to('erg/s') + lpp = trapz_loglog(pp.flux(energy, 0) * energy, energy).to("erg/s") lpps_noLUT.append(lpp.value) assert_allclose(lpps_LUT, lum_ref_LUT) @@ -473,11 +507,11 @@ def test_pion_decay(particle_dists): assert_allclose(Wps, Wp_ref) # test LUT not found - pp = PionDecay(PL, useLUT=True, hiEmodel='Geant4', **proton_properties) + pp = PionDecay(PL, useLUT=True, hiEmodel="Geant4", **proton_properties) pp.flux(energy, 0) -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_pion_decay_no_nuc_enh(particle_dists): """ test PionDecayKelner06 @@ -489,21 +523,20 @@ def test_pion_decay_no_nuc_enh(particle_dists): for pdist in [ECPL, PL, BPL]: pdist.amplitude = 1 * (1 / u.TeV) - lum_ref = [5.693100769654807e-13,] + lum_ref = [5.693100769654807e-13] energy = np.logspace(9, 13, 20) * u.eV - pp = PionDecay(ECPL, - nuclear_enhancement=False, - useLUT=False, - **proton_properties) - Wp = pp.Wp.to('erg').value - lpp = trapz_loglog(pp.flux(energy, 0) * energy, energy).to('erg/s') - assert (lpp.unit == u.erg / u.s) + pp = PionDecay( + ECPL, nuclear_enhancement=False, useLUT=False, **proton_properties + ) + Wp = pp.Wp.to("erg").value + lpp = trapz_loglog(pp.flux(energy, 0) * energy, energy).to("erg/s") + assert lpp.unit == u.erg / u.s assert_allclose(lpp.value, lum_ref[0]) -@pytest.mark.skipif('not HAS_SCIPY') +@pytest.mark.skipif("not HAS_SCIPY") def test_pion_decay_kelner(particle_dists): """ test PionDecayKelner06 @@ -519,9 +552,9 @@ def test_pion_decay_kelner(particle_dists): energy = np.logspace(9, 13, 20) * u.eV pp = PionDecay(ECPL, **proton_properties) - Wp = pp.Wp.to('erg').value - lpp = trapz_loglog(pp.flux(energy, 0) * energy, energy).to('erg/s') - assert (lpp.unit == u.erg / u.s) + Wp = pp.Wp.to("erg").value + lpp = trapz_loglog(pp.flux(energy, 0) * energy, energy).to("erg/s") + assert lpp.unit == u.erg / u.s assert_allclose(lpp.value, lum_ref[0]) @@ -532,7 +565,7 @@ def test_inputs(): from ..models import LogParabola, ExponentialCutoffBrokenPowerLaw - LP = LogParabola(1., e_0, 1.7, 0.2) + LP = LogParabola(1.0, e_0, 1.7, 0.2) LP._memoize = True # do twice for memoize @@ -541,13 +574,14 @@ def test_inputs(): LP(10 * u.TeV) LP(10 * u.TeV) - ECBPL = ExponentialCutoffBrokenPowerLaw(1., e_0, e_break, 1.5, 2.5, - e_cutoff, 2.0) + 
ECBPL = ExponentialCutoffBrokenPowerLaw( + 1.0, e_0, e_break, 1.5, 2.5, e_cutoff, 2.0 + ) ECBPL._memoize = True ECBPL(np.logspace(1, 10, 10) * u.TeV) with pytest.raises(TypeError): - data = {'flux': [1, 2, 4]} + data = {"flux": [1, 2, 4]} LP(data) @@ -557,14 +591,14 @@ def test_tablemodel(): lemin, lemax = -4, 2 # test an exponential cutoff PL with index 2, cutoff at 10 TeV e = np.logspace(lemin, lemax, 50) * u.TeV - n = (e.value)** -2 * np.exp(-e.value / 10) / u.eV + n = (e.value) ** -2 * np.exp(-e.value / 10) / u.eV tm = TableModel(e, n, amplitude=1) - assert_allclose(n.to('1/eV').value, tm(e).to('1/eV').value) + assert_allclose(n.to("1/eV").value, tm(e).to("1/eV").value) # test interpolation at low tolerance e2 = np.logspace(lemin, lemax, 1000) * u.TeV - n2 = (e2.value)** -2 * np.exp(-e2.value / 10) / u.eV - assert_allclose(n2.to('1/eV').value, tm(e2).to('1/eV').value, rtol=1e-1) + n2 = (e2.value) ** -2 * np.exp(-e2.value / 10) / u.eV + assert_allclose(n2.to("1/eV").value, tm(e2).to("1/eV").value, rtol=1e-1) # test TableModel without units in y tm2 = TableModel(e, n.value) @@ -576,6 +610,7 @@ def test_tablemodel(): # use tablemodel as pdist from ..radiative import Synchrotron, InverseCompton, PionDecay + SY = Synchrotron(tm) _ = SY.flux(e / 10) IC = InverseCompton(tm) @@ -592,14 +627,17 @@ def test_eblabsorptionmodel(): lemin, lemax = -4, 2 - EBL_zero = EblAbsorptionModel(0., 'Dominguez') - EBL_moderate = EblAbsorptionModel(0.5, 'Dominguez') + EBL_zero = EblAbsorptionModel(0.0, "Dominguez") + EBL_moderate = EblAbsorptionModel(0.5, "Dominguez") e = np.logspace(lemin, lemax, 50) * u.TeV -# Test if the EBL absorption at z = 0 changes the test array filled with ones - assert_allclose(np.ones_like(e).value, np.ones_like(e).value * - EBL_zero.transmission(e), rtol=1e-1) -# Make sure the transmission at z = 0. is always larger than the one at z = 0.5 + # Test if the EBL absorption at z = 0 changes the test array filled with ones + assert_allclose( + np.ones_like(e).value, + np.ones_like(e).value * EBL_zero.transmission(e), + rtol=1e-1, + ) + # Make sure the transmission at z = 0. 
is always larger than the one at z = 0.5 difference = EBL_zero.transmission(e) - EBL_moderate.transmission(e) - assert(np.all(difference > -1E-10)) + assert np.all(difference > -1e-10) diff --git a/naima/tests/test_plotting.py b/naima/tests/test_plotting.py index bba51811..f33aa7c6 100644 --- a/naima/tests/test_plotting.py +++ b/naima/tests/test_plotting.py @@ -10,14 +10,17 @@ try: import matplotlib - matplotlib.use('Agg') + + matplotlib.use("Agg") import matplotlib.pyplot as plt + HAS_MATPLOTLIB = True except ImportError: HAS_MATPLOTLIB = False try: import emcee + HAS_EMCEE = True except ImportError: HAS_EMCEE = False @@ -29,39 +32,43 @@ from .fixtures import sampler # Read data -fname = get_pkg_data_filename('data/CrabNebula_HESS_ipac.dat') +fname = get_pkg_data_filename("data/CrabNebula_HESS_ipac.dat") data_table = ascii.read(fname) -fname2 = get_pkg_data_filename('data/CrabNebula_Fake_Xray.dat') +fname2 = get_pkg_data_filename("data/CrabNebula_Fake_Xray.dat") data_table2 = ascii.read(fname2) data_list = [data_table2, data_table] -@pytest.mark.skipif('not HAS_MATPLOTLIB or not HAS_EMCEE') +@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE") @pytest.mark.parametrize("last_step", [True, False]) @pytest.mark.parametrize("convert_log", [True, False]) @pytest.mark.parametrize("include_blobs", [True, False]) -@pytest.mark.parametrize("format", ['ascii.ipac', 'ascii.ecsv', 'ascii']) +@pytest.mark.parametrize("format", ["ascii.ipac", "ascii.ecsv", "ascii"]) def test_results_table(sampler, last_step, convert_log, include_blobs, format): # set one keyword to a numpy array to try an break ecsv - sampler.run_info['test'] = np.random.randn(3) + sampler.run_info["test"] = np.random.randn(3) save_results_table( - 'test_table', sampler, - convert_log=convert_log, last_step=last_step, - format=format, include_blobs=include_blobs) + "test_table", + sampler, + convert_log=convert_log, + last_step=last_step, + format=format, + include_blobs=include_blobs, + ) - os.unlink(glob('test_table_results*')[0]) + os.unlink(glob("test_table_results*")[0]) -@pytest.mark.skipif('not HAS_MATPLOTLIB or not HAS_EMCEE') +@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE") @pytest.mark.parametrize("last_step", [True, False]) @pytest.mark.parametrize("p", [None, 1]) def test_chain_plots(sampler, last_step, p): plot_chain(sampler, last_step=last_step, p=p) - plt.close('all') + plt.close("all") -@pytest.mark.skipif('not HAS_MATPLOTLIB or not HAS_EMCEE') +@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE") @pytest.mark.parametrize("idx", range(4)) @pytest.mark.parametrize("sed", [True, False]) @pytest.mark.parametrize("last_step", [True, False]) @@ -70,64 +77,82 @@ def test_chain_plots(sampler, last_step, p): @pytest.mark.parametrize("e_range", [[1 * u.GeV, 100 * u.TeV], None]) def test_fit_plots(sampler, idx, sed, last_step, confs, n_samples, e_range): # plot models with correct format - plot_fit(sampler, modelidx=idx, sed=sed, - last_step=last_step, plotdata=True, - confs=confs, n_samples=n_samples, - e_range=e_range) - plt.close('all') - - -@pytest.mark.skipif('not HAS_MATPLOTLIB or not HAS_EMCEE') + plot_fit( + sampler, + modelidx=idx, + sed=sed, + last_step=last_step, + plotdata=True, + confs=confs, + n_samples=n_samples, + e_range=e_range, + ) + plt.close("all") + + +@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE") @pytest.mark.parametrize("threads", [None, 1, 4]) def test_threads_in_samples(sampler, threads): - plot_fit(sampler, - n_samples=100, - threads=threads, - e_range=[1 * u.GeV, 100 
* u.TeV], - e_npoints=20) - plt.close('all') + plot_fit( + sampler, + n_samples=100, + threads=threads, + e_range=[1 * u.GeV, 100 * u.TeV], + e_npoints=20, + ) + plt.close("all") -@pytest.mark.skipif('not HAS_MATPLOTLIB or not HAS_EMCEE') +@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE") @pytest.mark.parametrize("sed", [True, False]) def test_plot_data(sampler, sed): plot_data(sampler, sed=sed) - plt.close('all') + plt.close("all") -@pytest.mark.skipif('not HAS_MATPLOTLIB or not HAS_EMCEE') +@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE") def test_plot_data_reuse_fig(sampler): # change the energy units between calls data = sampler.data fig = plot_data(data, sed=True) - data['energy'] = (data['energy']/1000).to('keV') + data["energy"] = (data["energy"] / 1000).to("keV") plot_data(data, sed=True, figure=fig) - plt.close('all') + plt.close("all") -@pytest.mark.skipif('not HAS_MATPLOTLIB or not HAS_EMCEE') +@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE") @pytest.mark.parametrize("data_tables", [data_table, data_table2, data_list]) def test_plot_data_tables(sampler, data_tables): plot_data(data_tables) - plt.close('all') + plt.close("all") -@pytest.mark.skipif('not HAS_MATPLOTLIB or not HAS_EMCEE') +@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE") def test_fit_data_units(sampler): plot_fit(sampler, modelidx=0, sed=None) - plt.close('all') + plt.close("all") -@pytest.mark.skipif('not HAS_MATPLOTLIB or not HAS_EMCEE') +@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE") def test_diagnostic_plots(sampler): # Diagnostic plots # try to plot all models, including those with wrong format/units - blob_labels = ['Model', 'Flux', 'Model', 'Particle Distribution', - 'Broken PL', 'Wrong', 'Wrong', 'Wrong', 'Scalar', - 'Scalar without units'] - - save_diagnostic_plots('test_function_1', sampler, blob_labels=blob_labels) - save_diagnostic_plots('test_function_2', sampler, sed=True, - blob_labels=blob_labels[:4]) - save_diagnostic_plots('test_function_3', sampler, pdf=True) + blob_labels = [ + "Model", + "Flux", + "Model", + "Particle Distribution", + "Broken PL", + "Wrong", + "Wrong", + "Wrong", + "Scalar", + "Scalar without units", + ] + + save_diagnostic_plots("test_function_1", sampler, blob_labels=blob_labels) + save_diagnostic_plots( + "test_function_2", sampler, sed=True, blob_labels=blob_labels[:4] + ) + save_diagnostic_plots("test_function_3", sampler, pdf=True) diff --git a/naima/tests/test_saveread.py b/naima/tests/test_saveread.py index b5f24c15..3fd437eb 100644 --- a/naima/tests/test_saveread.py +++ b/naima/tests/test_saveread.py @@ -10,14 +10,17 @@ try: import matplotlib - matplotlib.use('Agg') + + matplotlib.use("Agg") import matplotlib.pyplot as plt + HAS_MATPLOTLIB = True except: HAS_MATPLOTLIB = False try: import emcee + HAS_EMCEE = True except: HAS_EMCEE = False @@ -31,15 +34,15 @@ from .fixtures import simple_sampler as sampler -fname = get_pkg_data_filename('data/CrabNebula_HESS_ipac.dat') +fname = get_pkg_data_filename("data/CrabNebula_HESS_ipac.dat") data_table = ascii.read(fname) -@pytest.mark.skipif('not HAS_EMCEE') +@pytest.mark.skipif("not HAS_EMCEE") def test_roundtrip(sampler): - save_run('test_chain.h5', sampler, clobber=True) - assert os.path.exists('test_chain.h5') - nresult = read_run('test_chain.h5') + save_run("test_chain.h5", sampler, clobber=True) + assert os.path.exists("test_chain.h5") + nresult = read_run("test_chain.h5") assert np.allclose(sampler.chain, nresult.chain) assert np.allclose(sampler.flatchain, 
nresult.flatchain) @@ -47,7 +50,7 @@ def test_roundtrip(sampler): assert np.allclose(sampler.flatlnprobability, nresult.flatlnprobability) nwalkers, nsteps = sampler.chain.shape[:2] - j, k = int(nsteps/2), int(nwalkers/2) + j, k = int(nsteps / 2), int(nwalkers / 2) for l in range(len(sampler.blobs[j][k])): b0 = sampler.blobs[j][k][l] b1 = nresult.blobs[j][k][l] @@ -69,45 +72,50 @@ def test_roundtrip(sampler): assert sampler.labels[i] == nresult.labels[i] for col in sampler.data.colnames: - assert np.allclose(u.Quantity(sampler.data[col]).value, - u.Quantity(nresult.data[col]).value) + assert np.allclose( + u.Quantity(sampler.data[col]).value, + u.Quantity(nresult.data[col]).value, + ) assert str(sampler.data[col].unit) == str(nresult.data[col].unit) validate_data_table(nresult.data) - assert np.allclose(np.mean(sampler.acceptance_fraction), - nresult.acceptance_fraction) + assert np.allclose( + np.mean(sampler.acceptance_fraction), nresult.acceptance_fraction + ) -@pytest.mark.skipif('not HAS_MATPLOTLIB or not HAS_EMCEE') +@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE") def test_plot_fit(sampler): - save_run('test_chain.h5', sampler, clobber=True) - nresult = read_run('test_chain.h5', modelfn=sampler.modelfn) + save_run("test_chain.h5", sampler, clobber=True) + nresult = read_run("test_chain.h5", modelfn=sampler.modelfn) plot_data(nresult) plot_fit(nresult, 0) plot_fit(nresult, 0, e_range=[0.1, 10] * u.TeV) plot_fit(nresult, 0, sed=False) - plt.close('all') + plt.close("all") -@pytest.mark.skipif('not HAS_MATPLOTLIB or not HAS_EMCEE') +@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE") def test_plot_chain(sampler): - save_run('test_chain.h5', sampler, clobber=True) - nresult = read_run('test_chain.h5', modelfn=sampler.modelfn) + save_run("test_chain.h5", sampler, clobber=True) + nresult = read_run("test_chain.h5", modelfn=sampler.modelfn) for i in range(nresult.chain.shape[2]): plot_chain(nresult, i) - plt.close('all') + plt.close("all") -@pytest.mark.skipif('not HAS_MATPLOTLIB or not HAS_EMCEE') +@pytest.mark.skipif("not HAS_MATPLOTLIB or not HAS_EMCEE") def test_imf(sampler): - save_run('test_chain.h5', sampler, clobber=True) - nresult = read_run('test_chain.h5', modelfn=sampler.modelfn) + save_run("test_chain.h5", sampler, clobber=True) + nresult = read_run("test_chain.h5", modelfn=sampler.modelfn) - imf = InteractiveModelFitter(nresult.modelfn, nresult.chain[-1][-1], - nresult.data) - imf.do_fit('test') + imf = InteractiveModelFitter( + nresult.modelfn, nresult.chain[-1][-1], nresult.data + ) + imf.do_fit("test") from naima.core import lnprobmodel - lnprobmodel(nresult.modelfn(imf.pars,nresult.data)[0], nresult.data) - plt.close('all') + + lnprobmodel(nresult.modelfn(imf.pars, nresult.data)[0], nresult.data) + plt.close("all") diff --git a/naima/tests/test_sherpamod.py b/naima/tests/test_sherpamod.py index fbfec2f0..dfb8e014 100644 --- a/naima/tests/test_sherpamod.py +++ b/naima/tests/test_sherpamod.py @@ -5,13 +5,17 @@ try: from sherpa import ui + HAS_SHERPA = True except ImportError: HAS_SHERPA = False energies = np.logspace(8, 10, 10) # 0.1 to 10 TeV in keV -test_spec_points = (1e-20 * (energies / 1e9) ** -0.7 * - (1 + 0.2 * np.random.randn(energies.size))) +test_spec_points = ( + 1e-20 + * (energies / 1e9) ** -0.7 + * (1 + 0.2 * np.random.randn(energies.size)) +) test_err_points = 0.2 * test_spec_points elo = energies[:-1] @@ -20,7 +24,7 @@ test_err_int = 0.2 * test_spec_int -@pytest.mark.skipif('not HAS_SHERPA') +@pytest.mark.skipif("not HAS_SHERPA") def 
test_electron_models(): """ test import @@ -58,14 +62,15 @@ def test_electron_models(): model.verbose.set(1) # test with integrated data - ui.load_arrays(1, elo, ehi, test_spec_int, test_err_int, - ui.Data1DInt) + ui.load_arrays( + 1, elo, ehi, test_spec_int, test_err_int, ui.Data1DInt + ) ui.set_model(model) ui.guess() ui.fit() -@pytest.mark.skipif('not HAS_SHERPA') +@pytest.mark.skipif("not HAS_SHERPA") def test_proton_model(): """ test import diff --git a/naima/tests/test_utils.py b/naima/tests/test_utils.py index 4321ab56..50568e81 100644 --- a/naima/tests/test_utils.py +++ b/naima/tests/test_utils.py @@ -6,92 +6,106 @@ import astropy.units as u from astropy.io import ascii -from ..utils import (validate_data_table, generate_energy_edges, - build_data_table, estimate_B) +from ..utils import ( + validate_data_table, + generate_energy_edges, + build_data_table, + estimate_B, +) # Read data -fname = get_pkg_data_filename('data/CrabNebula_HESS_ipac.dat') +fname = get_pkg_data_filename("data/CrabNebula_HESS_ipac.dat") data_table = ascii.read(fname) # Read spectrum with symmetric flux errors -fname_sym = get_pkg_data_filename('data/CrabNebula_HESS_ipac_symmetric.dat') +fname_sym = get_pkg_data_filename("data/CrabNebula_HESS_ipac_symmetric.dat") data_table_sym = ascii.read(fname_sym) + def test_validate_energy_error_types(): - for etype in ['edges','error','width','errors']: + for etype in ["edges", "error", "width", "errors"]: fname = get_pkg_data_filename( - 'data/CrabNebula_HESS_ipac_energy_{0}.dat'.format(etype)) + "data/CrabNebula_HESS_ipac_energy_{0}.dat".format(etype) + ) dt = ascii.read(fname) validate_data_table(dt) + def test_sed(): - fname = get_pkg_data_filename('data/Fake_ipac_sed.dat') + fname = get_pkg_data_filename("data/Fake_ipac_sed.dat") validate_data_table(ascii.read(fname)) validate_data_table([ascii.read(fname)]) + def test_concatenation(): - fname0 = get_pkg_data_filename('data/Fake_ipac_sed.dat') + fname0 = get_pkg_data_filename("data/Fake_ipac_sed.dat") dt0 = ascii.read(fname0) for sed in [True, False]: - validate_data_table([dt0,data_table],sed=sed) - validate_data_table([data_table,dt0],sed=sed) - validate_data_table([dt0,dt0],sed=sed) + validate_data_table([dt0, data_table], sed=sed) + validate_data_table([data_table, dt0], sed=sed) + validate_data_table([dt0, dt0], sed=sed) + def test_validate_data_types(): data_table2 = data_table.copy() - data_table2['energy'].unit = '' + data_table2["energy"].unit = "" with pytest.raises(TypeError): validate_data_table(data_table2) + def test_validate_missing_column(): data_table2 = data_table.copy() - data_table2.remove_column('energy') + data_table2.remove_column("energy") with pytest.raises(TypeError): validate_data_table(data_table2) data_table2 = data_table_sym.copy() - data_table2.remove_column('flux_error') + data_table2.remove_column("flux_error") with pytest.raises(TypeError): validate_data_table(data_table2) + def test_validate_string_uls(): from astropy.table import Column + data_table2 = data_table.copy() # replace uls column with valid strings - data_table2.remove_column('ul') + data_table2.remove_column("ul") data_table2.add_column( - Column(name='ul', dtype=str, data=['False']*len(data_table2)) + Column(name="ul", dtype=str, data=["False"] * len(data_table2)) ) - data_table2['ul'][1] = 'True' + data_table2["ul"][1] = "True" data = validate_data_table(data_table2) - assert np.sum(data['ul']) == 1 - assert np.sum(~data['ul']) == len(data_table2)-1 + assert np.sum(data["ul"]) == 1 + assert np.sum(~data["ul"]) == 
len(data_table2) - 1 # put an invalid value - data_table2['ul'][2] = 'foo' + data_table2["ul"][2] = "foo" with pytest.raises(TypeError): validate_data_table(data_table2) + def test_validate_cl(): data_table2 = data_table.copy() # use invalid value - data_table2.meta['keywords']['cl']['value'] = 'test' + data_table2.meta["keywords"]["cl"]["value"] = "test" with pytest.raises(TypeError): data = validate_data_table(data_table2) # remove cl - data_table2.meta['keywords'].pop('cl') + data_table2.meta["keywords"].pop("cl") data = validate_data_table(data_table2) - assert np.all(data['cl'] == 0.9) + assert np.all(data["cl"] == 0.9) + def test_build_data_table(): - ene = np.logspace(-2,2,20) * u.TeV - flux = (ene / (1 * u.TeV)) ** -2 * u.Unit('1/(cm2 s TeV)') + ene = np.logspace(-2, 2, 20) * u.TeV + flux = (ene / (1 * u.TeV)) ** -2 * u.Unit("1/(cm2 s TeV)") flux_error_hi = 0.2 * flux flux_error_lo = 0.1 * flux ul = np.zeros(len(ene)) @@ -99,15 +113,31 @@ def test_build_data_table(): dene = generate_energy_edges(ene) - table = build_data_table(ene, flux, flux_error_hi=flux_error_hi, - flux_error_lo=flux_error_lo, ul=ul) - table = build_data_table(ene, flux, flux_error_hi=flux_error_hi, - flux_error_lo=flux_error_lo, ul=ul, cl=0.99) - table = build_data_table(ene, flux, flux_error=flux_error_hi, - energy_width=dene[0]) - table = build_data_table(ene, flux, flux_error=flux_error_hi, - energy_lo=(ene - dene[0]), - energy_hi=(ene + dene[1])) + table = build_data_table( + ene, + flux, + flux_error_hi=flux_error_hi, + flux_error_lo=flux_error_lo, + ul=ul, + ) + table = build_data_table( + ene, + flux, + flux_error_hi=flux_error_hi, + flux_error_lo=flux_error_lo, + ul=ul, + cl=0.99, + ) + table = build_data_table( + ene, flux, flux_error=flux_error_hi, energy_width=dene[0] + ) + table = build_data_table( + ene, + flux, + flux_error=flux_error_hi, + energy_lo=(ene - dene[0]), + energy_hi=(ene + dene[1]), + ) # no flux_error with pytest.raises(TypeError): @@ -118,14 +148,16 @@ def test_build_data_table(): build_data_table(ene.value, flux, flux_error=flux_error_hi) with pytest.raises(TypeError): - build_data_table(ene.value*u.Unit('erg/(cm2 s)'), flux, - flux_error=flux_error_hi) + build_data_table( + ene.value * u.Unit("erg/(cm2 s)"), flux, flux_error=flux_error_hi + ) + def test_estimate_B(): - fname = get_pkg_data_filename('data/CrabNebula_Fake_Xray.dat') + fname = get_pkg_data_filename("data/CrabNebula_Fake_Xray.dat") xray = ascii.read(fname) B = estimate_B(xray, data_table) - assert_allclose(B.to('uG'), 0.4848756912803697 * u.uG) + assert_allclose(B.to("uG"), 0.4848756912803697 * u.uG) diff --git a/naima/utils.py b/naima/utils.py index 69c060bb..1b2964ed 100644 --- a/naima/utils.py +++ b/naima/utils.py @@ -1,6 +1,10 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst -from __future__ import (absolute_import, division, print_function, - unicode_literals) +from __future__ import ( + absolute_import, + division, + print_function, + unicode_literals, +) import numpy as np import astropy.units as u @@ -11,24 +15,28 @@ from .extern.validator import validate_array, validate_scalar __all__ = [ - "generate_energy_edges", "sed_conversion", "build_data_table", "estimate_B" + "generate_energy_edges", + "sed_conversion", + "build_data_table", + "estimate_B", ] # Input validation tools -def validate_column(data_table, key, pt, domain='positive'): +def validate_column(data_table, key, pt, domain="positive"): try: column = data_table[key] array = validate_array( key, - u.Quantity( - column, 
unit=column.unit), + u.Quantity(column, unit=column.unit), physical_type=pt, - domain=domain) + domain=domain, + ) except KeyError: raise TypeError( - 'Data table does not contain required column "{0}"'.format(key)) + 'Data table does not contain required column "{0}"'.format(key) + ) return array @@ -54,24 +62,25 @@ def validate_data_table(data_table, sed=None): for dt in data_table: if not isinstance(dt, Table) and not isinstance(dt, QTable): raise TypeError( - 'An object passed as data_table is not an astropy Table!') + "An object passed as data_table is not an astropy Table!" + ) except TypeError: raise TypeError( - 'Argument passed to validate_data_table is not a table and ' - 'not a list' + "Argument passed to validate_data_table is not a table and " + "not a list" ) def dt_sed_conversion(dt, sed): - f_unit, sedf = sed_conversion(dt['energy'], dt['flux'].unit, sed) + f_unit, sedf = sed_conversion(dt["energy"], dt["flux"].unit, sed) # roundtrip to Table to change the units t = Table(dt) - for col in ['flux', 'flux_error_lo', 'flux_error_hi']: + for col in ["flux", "flux_error_lo", "flux_error_hi"]: t[col].unit = f_unit ndt = QTable(t) - ndt['flux'] = (dt['flux'] * sedf).to(f_unit) - ndt['flux_error_lo'] = (dt['flux_error_lo'] * sedf).to(f_unit) - ndt['flux_error_hi'] = (dt['flux_error_hi'] * sedf).to(f_unit) + ndt["flux"] = (dt["flux"] * sedf).to(f_unit) + ndt["flux_error_lo"] = (dt["flux_error_lo"] * sedf).to(f_unit) + ndt["flux_error_hi"] = (dt["flux_error_hi"] * sedf).to(f_unit) return ndt @@ -82,19 +91,21 @@ def dt_sed_conversion(dt, sed): # concatenate input data tables data_new = data_list[0].copy() - f_pt = data_new['flux'].unit.physical_type + f_pt = data_new["flux"].unit.physical_type if sed is None: - sed = f_pt in ['flux', 'power'] + sed = f_pt in ["flux", "power"] data_new = dt_sed_conversion(data_new, sed) for dt in data_list[1:]: - nf_pt = dt['flux'].unit.physical_type - if (('flux' in nf_pt and 'power' in f_pt) or - ('power' in nf_pt and 'flux' in f_pt)): + nf_pt = dt["flux"].unit.physical_type + if ("flux" in nf_pt and "power" in f_pt) or ( + "power" in nf_pt and "flux" in f_pt + ): raise TypeError( - 'The physical types of the data tables could not be ' - 'matched: Some are in flux and others in luminosity units') + "The physical types of the data tables could not be " + "matched: Some are in flux and others in luminosity units" + ) dt = dt_sed_conversion(dt, sed) @@ -108,100 +119,114 @@ def _validate_single_data_table(data_table, group=0): data = QTable() - flux_types = ['flux', 'differential flux', 'power', 'differential power'] + flux_types = ["flux", "differential flux", "power", "differential power"] # Energy and flux arrays - data['energy'] = validate_column(data_table, 'energy', 'energy') - data['flux'] = validate_column(data_table, 'flux', flux_types) + data["energy"] = validate_column(data_table, "energy", "energy") + data["flux"] = validate_column(data_table, "flux", flux_types) # Flux uncertainties - if 'flux_error' in data_table.keys(): - dflux = validate_column(data_table, 'flux_error', flux_types) - data['flux_error_lo'] = dflux - data['flux_error_hi'] = dflux - elif 'flux_error_lo' in data_table.keys( - ) and 'flux_error_hi' in data_table.keys(): - data['flux_error_lo'] = validate_column(data_table, 'flux_error_lo', - flux_types) - data['flux_error_hi'] = validate_column(data_table, 'flux_error_hi', - flux_types) + if "flux_error" in data_table.keys(): + dflux = validate_column(data_table, "flux_error", flux_types) + data["flux_error_lo"] = dflux + 
data["flux_error_hi"] = dflux + elif ( + "flux_error_lo" in data_table.keys() + and "flux_error_hi" in data_table.keys() + ): + data["flux_error_lo"] = validate_column( + data_table, "flux_error_lo", flux_types + ) + data["flux_error_hi"] = validate_column( + data_table, "flux_error_hi", flux_types + ) else: - raise TypeError('Data table does not contain required column' - ' "flux_error" or columns "flux_error_lo"' - ' and "flux_error_hi"') + raise TypeError( + "Data table does not contain required column" + ' "flux_error" or columns "flux_error_lo"' + ' and "flux_error_hi"' + ) - if 'group' in data_table.colnames: + if "group" in data_table.colnames: # avoid overwriting groups - data['group'] = data_table['group'] + data["group"] = data_table["group"] else: - data['group'] = [group] * len(data['energy']) + data["group"] = [group] * len(data["energy"]) # Energy bin edges - if 'energy_width' in data_table.keys(): - energy_width = validate_column(data_table, 'energy_width', 'energy') - data['energy_error_lo'] = energy_width / 2. - data['energy_error_hi'] = energy_width / 2. - elif 'energy_error' in data_table.keys(): - energy_error = validate_column(data_table, 'energy_error', 'energy') - data['energy_error_lo'] = energy_error - data['energy_error_hi'] = energy_error - elif ('energy_error_lo' in data_table.keys() and - 'energy_error_hi' in data_table.keys()): - energy_error_lo = validate_column(data_table, 'energy_error_lo', - 'energy') - data['energy_error_lo'] = energy_error_lo - energy_error_hi = validate_column(data_table, 'energy_error_hi', - 'energy') - data['energy_error_hi'] = energy_error_hi - elif 'energy_lo' in data_table.keys() and 'energy_hi' in data_table.keys(): - energy_lo = validate_column(data_table, 'energy_lo', 'energy') - data['energy_error_lo'] = (data['energy'] - energy_lo) - energy_hi = validate_column(data_table, 'energy_hi', 'energy') - data['energy_error_hi'] = (energy_hi - data['energy']) + if "energy_width" in data_table.keys(): + energy_width = validate_column(data_table, "energy_width", "energy") + data["energy_error_lo"] = energy_width / 2.0 + data["energy_error_hi"] = energy_width / 2.0 + elif "energy_error" in data_table.keys(): + energy_error = validate_column(data_table, "energy_error", "energy") + data["energy_error_lo"] = energy_error + data["energy_error_hi"] = energy_error + elif ( + "energy_error_lo" in data_table.keys() + and "energy_error_hi" in data_table.keys() + ): + energy_error_lo = validate_column( + data_table, "energy_error_lo", "energy" + ) + data["energy_error_lo"] = energy_error_lo + energy_error_hi = validate_column( + data_table, "energy_error_hi", "energy" + ) + data["energy_error_hi"] = energy_error_hi + elif "energy_lo" in data_table.keys() and "energy_hi" in data_table.keys(): + energy_lo = validate_column(data_table, "energy_lo", "energy") + data["energy_error_lo"] = data["energy"] - energy_lo + energy_hi = validate_column(data_table, "energy_hi", "energy") + data["energy_error_hi"] = energy_hi - data["energy"] else: - data['energy_error_lo'], data[ - 'energy_error_hi'] = generate_energy_edges( - data['energy'], groups=data['group']) + data["energy_error_lo"], data[ + "energy_error_hi" + ] = generate_energy_edges(data["energy"], groups=data["group"]) # Upper limit flags - if 'ul' in data_table.keys(): + if "ul" in data_table.keys(): # Check if it is a integer or boolean flag - ul_col = data_table['ul'] + ul_col = data_table["ul"] if ul_col.dtype.type is np.int_ or ul_col.dtype.type is np.bool_: - data['ul'] = np.array(ul_col, 
dtype=np.bool) + data["ul"] = np.array(ul_col, dtype=np.bool) elif ul_col.dtype.type is np.str_: strbool = True for ul in ul_col: - if ul != 'True' and ul != 'False': + if ul != "True" and ul != "False": strbool = False if strbool: - data['ul'] = np.array( - [ast.literal_eval(ul) for ul in ul_col], dtype=np.bool) + data["ul"] = np.array( + [ast.literal_eval(ul) for ul in ul_col], dtype=np.bool + ) else: - raise TypeError('UL column is in wrong format') + raise TypeError("UL column is in wrong format") else: - raise TypeError('UL column is in wrong format') + raise TypeError("UL column is in wrong format") else: - data['ul'] = np.array([False] * len(data['energy'])) + data["ul"] = np.array([False] * len(data["energy"])) - if 'flux_ul' in data_table.keys(): - data['flux'][data['ul']] = u.Quantity( - data_table['flux_ul'])[data['ul']] + if "flux_ul" in data_table.keys(): + data["flux"][data["ul"]] = u.Quantity(data_table["flux_ul"])[ + data["ul"] + ] HAS_CL = False - if 'keywords' in data_table.meta.keys(): - if 'cl' in data_table.meta['keywords'].keys(): + if "keywords" in data_table.meta.keys(): + if "cl" in data_table.meta["keywords"].keys(): HAS_CL = True - CL = validate_scalar('cl', - data_table.meta['keywords']['cl']['value']) - data['cl'] = [CL] * len(data['energy']) + CL = validate_scalar( + "cl", data_table.meta["keywords"]["cl"]["value"] + ) + data["cl"] = [CL] * len(data["energy"]) if not HAS_CL: - data['cl'] = [0.9] * len(data['energy']) - if np.sum(data['ul']) > 0: + data["cl"] = [0.9] * len(data["energy"]) + if np.sum(data["ul"]) > 0: log.warning( '"cl" keyword not provided in input data table, upper limits' - ' will be assumed to be at 90% confidence level') + " will be assumed to be at 90% confidence level" + ) return data @@ -220,44 +245,49 @@ def sed_conversion(energy, model_unit, sed): if sed: # SED - f_unit = u.Unit('erg/s') - if model_pt == 'power' or model_pt == 'flux' or model_pt == 'energy': + f_unit = u.Unit("erg/s") + if model_pt == "power" or model_pt == "flux" or model_pt == "energy": sedf = ones - elif 'differential' in model_pt: - sedf = (energy**2) + elif "differential" in model_pt: + sedf = energy ** 2 else: raise u.UnitsError( - 'Model physical type ({0}) is not supported'.format(model_pt), - 'Supported physical types are: power, flux, differential' - ' power, differential flux') - - if 'flux' in model_pt: - f_unit /= u.cm**2 - elif 'energy' in model_pt: + "Model physical type ({0}) is not supported".format(model_pt), + "Supported physical types are: power, flux, differential" + " power, differential flux", + ) + + if "flux" in model_pt: + f_unit /= u.cm ** 2 + elif "energy" in model_pt: # particle energy distributions f_unit = u.erg else: # Differential spectrum - f_unit = u.Unit('1/(s TeV)') - if 'differential' in model_pt: + f_unit = u.Unit("1/(s TeV)") + if "differential" in model_pt: sedf = ones - elif model_pt == 'power' or model_pt == 'flux' or model_pt == 'energy': + elif model_pt == "power" or model_pt == "flux" or model_pt == "energy": # From SED to differential - sedf = 1 / (energy**2) + sedf = 1 / (energy ** 2) else: raise u.UnitsError( - 'Model physical type ({0}) is not supported'.format(model_pt), - 'Supported physical types are: power, flux, differential' - ' power, differential flux') - - if 'flux' in model_pt: - f_unit /= u.cm**2 - elif 'energy' in model_pt: + "Model physical type ({0}) is not supported".format(model_pt), + "Supported physical types are: power, flux, differential" + " power, differential flux", + ) + + if "flux" in model_pt: 
+ f_unit /= u.cm ** 2 + elif "energy" in model_pt: # particle energy distributions - f_unit = u.Unit('1/TeV') + f_unit = u.Unit("1/TeV") - log.debug('Converted from {0} ({1}) into {2} ({3}) for sed={4}'.format( - model_unit, model_pt, f_unit, f_unit.physical_type, sed)) + log.debug( + "Converted from {0} ({1}) into {2} ({3}) for sed={4}".format( + model_unit, model_pt, f_unit, f_unit.physical_type, sed + ) + ) return f_unit, sedf @@ -287,12 +317,12 @@ def trapz_loglog(y, x, axis=-1, intervals=False): y_unit = y.unit y = y.value except AttributeError: - y_unit = 1. + y_unit = 1.0 try: x_unit = x.unit x = x.value except AttributeError: - x_unit = 1. + x_unit = 1.0 y = np.asanyarray(y) x = np.asanyarray(x) @@ -315,12 +345,17 @@ def trapz_loglog(y, x, axis=-1, intervals=False): # if local powerlaw index is -1, use \int 1/x = log(x); otherwise use # normal powerlaw integration trapzs = np.where( - np.abs(b + 1.) > 1e-10, (y[slice1] * ( - x[slice2] * (x[slice2] / x[slice1])**b - x[slice1])) / (b + 1), - x[slice1] * y[slice1] * np.log(x[slice2] / x[slice1])) + np.abs(b + 1.0) > 1e-10, + ( + y[slice1] + * (x[slice2] * (x[slice2] / x[slice1]) ** b - x[slice1]) + ) + / (b + 1), + x[slice1] * y[slice1] * np.log(x[slice2] / x[slice1]), + ) - tozero = (y[slice1] == 0.) + (y[slice2] == 0.) + (x[slice1] == x[slice2]) - trapzs[tozero] = 0. + tozero = (y[slice1] == 0.0) + (y[slice2] == 0.0) + (x[slice1] == x[slice2]) + trapzs[tozero] = 0.0 if intervals: return trapzs * x_unit * y_unit @@ -370,16 +405,18 @@ def generate_energy_edges(ene, groups=None): return eloehi -def build_data_table(energy, - flux, - flux_error=None, - flux_error_lo=None, - flux_error_hi=None, - energy_width=None, - energy_lo=None, - energy_hi=None, - ul=None, - cl=None): +def build_data_table( + energy, + flux, + flux_error=None, + flux_error_lo=None, + flux_error_hi=None, + energy_width=None, + energy_lo=None, + energy_hi=None, + ul=None, + cl=None, +): """ Read data into data dict. 
@@ -419,32 +456,32 @@ def build_data_table(energy, table = QTable() if cl is not None: - cl = validate_scalar('cl', cl) - table.meta['keywords'] = {'cl': {'value': cl}} + cl = validate_scalar("cl", cl) + table.meta["keywords"] = {"cl": {"value": cl}} - table['energy'] = energy + table["energy"] = energy if energy_width is not None: - table['energy_width'] = energy_width + table["energy_width"] = energy_width elif energy_lo is not None and energy_hi is not None: - table['energy_lo'] = energy_lo - table['energy_hi'] = energy_hi + table["energy_lo"] = energy_lo + table["energy_hi"] = energy_hi - table['flux'] = flux + table["flux"] = flux if flux_error is not None: - table['flux_error'] = flux_error + table["flux_error"] = flux_error elif flux_error_lo is not None and flux_error_hi is not None: - table['flux_error_lo'] = flux_error_lo - table['flux_error_hi'] = flux_error_hi + table["flux_error_lo"] = flux_error_lo + table["flux_error_hi"] = flux_error_hi else: - raise TypeError('Flux error not provided!') + raise TypeError("Flux error not provided!") if ul is not None: ul = np.array(ul, dtype=np.int) - table['ul'] = ul + table["ul"] = ul - table.meta['comments'] = ['Table generated with naima.build_data_table'] + table.meta["comments"] = ["Table generated with naima.build_data_table"] # test table units, format, etc validate_data_table(table) @@ -452,9 +489,9 @@ def build_data_table(energy, return table -def estimate_B(xray_table, - vhe_table, - photon_energy_density=0.261 * u.eV / u.cm**3): +def estimate_B( + xray_table, vhe_table, photon_energy_density=0.261 * u.eV / u.cm ** 3 +): """ Estimate magnetic field from synchrotron to Inverse Compton luminosity ratio @@ -503,12 +540,13 @@ def estimate_B(xray_table, xray = validate_data_table(xray_table, sed=False) vhe = validate_data_table(vhe_table, sed=False) - xray_lum = trapz_loglog(xray['flux'] * xray['energy'], xray['energy']) - vhe_lum = trapz_loglog(vhe['flux'] * vhe['energy'], vhe['energy']) + xray_lum = trapz_loglog(xray["flux"] * xray["energy"], xray["energy"]) + vhe_lum = trapz_loglog(vhe["flux"] * vhe["energy"], vhe["energy"]) - uph = (photon_energy_density.to('erg/cm3')).value + uph = (photon_energy_density.to("erg/cm3")).value - B0 = (np.sqrt((xray_lum / vhe_lum).decompose().value * 8 * np.pi * uph) * - u.G).to('uG') + B0 = ( + np.sqrt((xray_lum / vhe_lum).decompose().value * 8 * np.pi * uph) * u.G + ).to("uG") return B0 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..32347df5 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,15 @@ +[tool.black] +line-length = 79 +exclude = ''' +/( + \.git + | \.mypy_cache + | \.tox + | \.venv + | _build + | build + | dist + | \.eggs + | astropy_helpers +)/ +''' diff --git a/setup.py b/setup.py index 451a8da3..51db513c 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ import ah_bootstrap from setuptools import setup -#A dirty hack to get around some early import/configurations ambiguities +# A dirty hack to get around some early import/configurations ambiguities if sys.version_info[0] >= 3: import builtins else: @@ -16,7 +16,10 @@ builtins._ASTROPY_SETUP_ = True from astropy_helpers.setup_helpers import ( - register_commands, get_debug_option, get_package_info) + register_commands, + get_debug_option, + get_package_info, +) from astropy_helpers.git_helpers import get_git_devstr from astropy_helpers.version_helpers import generate_version_py @@ -26,18 +29,18 @@ except ImportError: from configparser import ConfigParser conf = ConfigParser() -conf.read(['setup.cfg']) 
-metadata = dict(conf.items('metadata')) +conf.read(["setup.cfg"]) +metadata = dict(conf.items("metadata")) -PACKAGENAME = str(metadata.get('package_name', 'packagename')) -DESCRIPTION = metadata.get('description', 'Astropy affiliated package') -AUTHOR = metadata.get('author', '') -AUTHOR_EMAIL = metadata.get('author_email', '') -LICENSE = metadata.get('license', 'unknown') -URL = metadata.get('url', 'http://astropy.org') +PACKAGENAME = str(metadata.get("package_name", "packagename")) +DESCRIPTION = metadata.get("description", "Astropy affiliated package") +AUTHOR = metadata.get("author", "") +AUTHOR_EMAIL = metadata.get("author_email", "") +LICENSE = metadata.get("license", "unknown") +URL = metadata.get("url", "http://astropy.org") # Get the long description from the README.rst file -with open('README.rst', 'rt') as f: +with open("README.rst", "rt") as f: LONG_DESCRIPTION = f.read() # Store the package name in a built-in variable so it's easy @@ -45,10 +48,10 @@ builtins._ASTROPY_PACKAGE_NAME_ = PACKAGENAME # VERSION should be PEP386 compatible (http://www.python.org/dev/peps/pep-0386) -VERSION = '0.8.dev' +VERSION = "0.8.dev" # Indicates if this version is a release version -RELEASE = 'dev' not in VERSION +RELEASE = "dev" not in VERSION if not RELEASE: VERSION += get_git_devstr(False) @@ -59,12 +62,16 @@ cmdclassd = register_commands(PACKAGENAME, VERSION, RELEASE) # Freeze build information in version.py -generate_version_py(PACKAGENAME, VERSION, RELEASE, - get_debug_option(PACKAGENAME)) +generate_version_py( + PACKAGENAME, VERSION, RELEASE, get_debug_option(PACKAGENAME) +) # Treat everything in scripts except README.rst as a script to be installed -scripts = [fname for fname in glob.glob(os.path.join('scripts', '*')) - if os.path.basename(fname) != 'README.rst'] +scripts = [ + fname + for fname in glob.glob(os.path.join("scripts", "*")) + if os.path.basename(fname) != "README.rst" +] # Get configuration information from all of the various subpackages. 
# See the docstring for setup_helpers.update_package_files for more
@@ -72,8 +79,8 @@
 package_info = get_package_info()
 
 # Add the project-global data
-package_info['package_data'].setdefault(PACKAGENAME, [])
-package_info['package_data'][PACKAGENAME].append('data/*')
+package_info["package_data"].setdefault(PACKAGENAME, [])
+package_info["package_data"][PACKAGENAME].append("data/*")
 
 # Include all .c files, recursively, including those generated by
 # Cython, since we can not do this in MANIFEST.in with a "dynamic"
@@ -81,50 +88,54 @@
 c_files = []
 for root, dirs, files in os.walk(PACKAGENAME):
     for filename in files:
-        if filename.endswith('.c'):
+        if filename.endswith(".c"):
             c_files.append(
-                os.path.join(
-                    os.path.relpath(root, PACKAGENAME), filename))
-package_info['package_data'][PACKAGENAME].extend(c_files)
+                os.path.join(os.path.relpath(root, PACKAGENAME), filename)
+            )
+package_info["package_data"][PACKAGENAME].extend(c_files)
 
 # Some dependencies with C extensions cannot be built on readthedocs
-on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
+on_rtd = os.environ.get("READTHEDOCS", None) == "True"
 
 if on_rtd:
     install_requires = []
 else:
-    install_requires=['astropy>=1.0.2',
-                      'emcee>=2.2.0',
-                      'corner',
-                      'matplotlib',
-                      'scipy',
-                      'h5py',
-                      'numtraits',
-                      'traitlets'],
-
-
-setup(name=PACKAGENAME,
-      version=VERSION,
-      description=DESCRIPTION,
-      scripts=scripts,
-      install_requires=install_requires,
-      author=AUTHOR,
-      author_email=AUTHOR_EMAIL,
-      license=LICENSE,
-      url=URL,
-      long_description=LONG_DESCRIPTION,
-      classifiers = [ 'Programming Language :: Python :: 3',
-                      'Programming Language :: Python :: 2.7',
-                      'Programming Language :: Python :: 3.4',
-                      'Programming Language :: Python :: 3.5',
-                      'Development Status :: 5 - Production/Stable',
-                      'Intended Audience :: Science/Research',
-                      'License :: OSI Approved :: BSD License',
-                      'Operating System :: OS Independent',
-                      'Topic :: Scientific/Engineering :: Astronomy',
-                      'Topic :: Scientific/Engineering :: Physics',
-                      ],
-      cmdclass=cmdclassd,
-      zip_safe=False,
-      use_2to3=False,
-      **package_info
+    install_requires = [
+        "astropy>=1.0.2",
+        "emcee>=2.2.0",
+        "corner",
+        "matplotlib",
+        "scipy",
+        "h5py",
+        "numtraits",
+        "traitlets",
+    ]
+
+
+setup(
+    name=PACKAGENAME,
+    version=VERSION,
+    description=DESCRIPTION,
+    scripts=scripts,
+    install_requires=install_requires,
+    author=AUTHOR,
+    author_email=AUTHOR_EMAIL,
+    license=LICENSE,
+    url=URL,
+    long_description=LONG_DESCRIPTION,
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 2.7",
+        "Programming Language :: Python :: 3.4",
+        "Programming Language :: Python :: 3.5",
+        "Development Status :: 5 - Production/Stable",
+        "Intended Audience :: Science/Research",
+        "License :: OSI Approved :: BSD License",
+        "Operating System :: OS Independent",
+        "Topic :: Scientific/Engineering :: Astronomy",
+        "Topic :: Scientific/Engineering :: Physics",
+    ],
+    cmdclass=cmdclassd,
+    zip_safe=False,
+    use_2to3=False,
+    **package_info
 )
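
For orientation, a short usage sketch of the new ProtonSynchrotron model exported from naima.models; the flux()/sed() calls follow the same pattern as the InverseCompton and PionDecay tests above. The constructor signature is not shown in this part of the series, so the B keyword and the amplitude normalization below are assumptions by analogy with the electron Synchrotron model, not the definitive API:

    import numpy as np
    import astropy.units as u

    from naima.models import ExponentialCutoffPowerLaw, ProtonSynchrotron

    # Proton distribution: exponential-cutoff power law, normalized as a
    # differential spectrum (particles per unit energy)
    protons = ExponentialCutoffPowerLaw(
        amplitude=1e36 / u.eV, e_0=1 * u.TeV, alpha=2.0, e_cutoff=1 * u.PeV
    )

    # B keyword assumed by analogy with Synchrotron(particle_distribution, B=...)
    psync = ProtonSynchrotron(protons, B=100 * u.uG)

    # SED at 1 kpc over a broad photon-energy grid, as in the tests above
    photon_energy = np.logspace(-7, 2, 100) * u.TeV
    sed = psync.sed(photon_energy, distance=1 * u.kpc)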