Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
36223af
Adds SRM units and passes Spectralaxis to thermal
jajmitchell Jul 8, 2025
c8b9b95
Checks input is SpectralAxis object
jajmitchell Jul 8, 2025
c96bae5
Adds check input function
jajmitchell Jul 9, 2025
84b07e8
ThickTarget and ThinTarget use centers
jajmitchell Jul 9, 2025
d9f16b5
Albedo now takes SpectralAxis and centers
jajmitchell Jul 9, 2025
062c0ad
Pre-commit
jajmitchell Jul 9, 2025
403b2ac
scaling.py and models.py now function with centers
jajmitchell Jul 9, 2025
46e3b93
Fixes scaling and albedo tests
jajmitchell Jul 9, 2025
47eef10
Adds __array_finalize__ method to SpectralAxis
jajmitchell Jul 17, 2025
9516b9d
Fixes tests
jajmitchell Jul 17, 2025
0e9103f
Fixes part of fitting_simulated_data.py
jajmitchell Jul 17, 2025
b3be7e9
Reverts SRM model to main version
jajmitchell Jul 17, 2025
ff4c977
Fixes fitting tests
jajmitchell Jul 17, 2025
12c39a0
Fixes test_models.py
jajmitchell Jul 18, 2025
0e49d45
Fixes Albedo example
jajmitchell Jul 18, 2025
56c71c3
Fix Albedo example 2
jajmitchell Jul 18, 2025
dde9e35
Fixes examples
jajmitchell Jul 18, 2025
a0f2db5
Adds changelog
jajmitchell Jul 18, 2025
8d11859
Pre_commit fix
jajmitchell Jul 18, 2025
1b53ad6
Adds custom `__call__` to `ThermalEmission`
jajmitchell Jul 25, 2025
8498325
Fixes test_thermal.py
jajmitchell Jul 28, 2025
8d4c1b1
Thermal model uses photo_axis from meta
jajmitchell Sep 10, 2025
54407a3
Can pass array without initialising
jajmitchell Sep 10, 2025
56b71aa
Fitting works
jajmitchell Sep 15, 2025
1349730
Tidy up
jajmitchell Sep 15, 2025
2e17bd5
More tidy
jajmitchell Sep 15, 2025
b4c9782
Remove unused functions
jajmitchell Sep 15, 2025
3945f91
Merge branch 'main' into pass-energy-centers
jajmitchell Sep 15, 2025
e32ab3d
Fixes continuum and line models
jajmitchell Sep 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/223.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Modifies :class:`sunkit_spex.models.scaling.InverseSquareFluxScaling`, :class:`sunkit_spex.models.scaling.Constant`, :class:`sunkit_spex.models.models.GaussianModel`, :class:`sunkit_spex.models.models.StraightLineModel`, :class:`sunkit_spex.models.physical.thermal.ThermalEmission`, :class:`sunkit_spex.models.physical.thermal.LineEmission`, :class:`sunkit_spex.models.physical.thermal.ContinuumEmission`, :class:`sunkit_spex.models.physical.nonthermal.ThickTarget`, :class:`sunkit_spex.models.physical.nonthermal.ThinTarget` and :class:`sunkit_spex.models.physical.albedo.Albedo` to take :class:`sunkit_spex.spectrum.spectrum.SpectralAxis` as input and fixes tests and examples accordingly.
13 changes: 7 additions & 6 deletions examples/fitting_simulated_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,13 @@
stop = 80 + inc / 2
ph_energies = np.arange(start, stop, inc)


#####################################################
#
# Let's start making a simulated photon spectrum

sim_cont = {"edges": False, "slope": -1, "intercept": 100}
sim_line = {"edges": False, "amplitude": 100, "mean": 30, "stddev": 2}
sim_cont = {"slope": -1, "intercept": 100}
sim_line = {"amplitude": 100, "mean": 30, "stddev": 2}
# use a straight line model for a continuum, Gaussian for a line
ph_model = StraightLineModel(**sim_cont) + GaussianModel(**sim_line)

Expand Down Expand Up @@ -75,7 +76,7 @@
#
# Start work on a count model

sim_gauss = {"edges": False, "amplitude": 70, "mean": 40, "stddev": 2}
sim_gauss = {"amplitude": 70, "mean": 40, "stddev": 2}
# the brackets are very necessary
ct_model = (ph_model | srm_model) + GaussianModel(**sim_gauss)

Expand Down Expand Up @@ -115,9 +116,9 @@
#
# Get some initial guesses that are off from the simulated data above

guess_cont = {"edges": False, "slope": -0.5, "intercept": 80}
guess_line = {"edges": False, "amplitude": 150, "mean": 32, "stddev": 5}
guess_gauss = {"edges": False, "amplitude": 350, "mean": 39, "stddev": 0.5}
guess_cont = {"slope": -0.5, "intercept": 80}
guess_line = {"amplitude": 150, "mean": 32, "stddev": 5}
guess_gauss = {"amplitude": 350, "mean": 39, "stddev": 0.5}

#####################################################
#
Expand Down
4 changes: 2 additions & 2 deletions sunkit_spex/fitting/tests/test_objective_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_minimize_func():
"""Test the `minimize_func` function against known outputs."""
sim_x0 = np.arange(3)
model_params0 = {"slope": 1, "intercept": 0}
sim_model0 = StraightLineModel(edges=False, **model_params0)
sim_model0 = StraightLineModel(**model_params0)
sim_data0 = sim_model0.evaluate(sim_x0, **model_params0)
res0 = minimize_func(
params=tuple(model_params0.values()),
Expand All @@ -25,7 +25,7 @@ def test_minimize_func():

sim_x1 = np.arange(3)
model_params1 = {"slope": 1, "intercept": 0}
sim_model1 = StraightLineModel(edges=False, **model_params1)
sim_model1 = StraightLineModel(**model_params1)
sim_data1 = sim_model1.evaluate(sim_x1, **model_params1)[::-1]
res1 = minimize_func(
params=tuple(model_params1.values()),
Expand Down
4 changes: 2 additions & 2 deletions sunkit_spex/fitting/tests/test_optimizer_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ def test_scipy_minimize():
sim_x0 = np.arange(3)
model_params0 = {"slope": 1, "intercept": 0}
model_param_values0 = tuple(model_params0.values())
sim_model0 = StraightLineModel(edges=False, **model_params0)
sim_model0 = StraightLineModel(**model_params0)
sim_data0 = sim_model0.evaluate(sim_x0, **model_params0)
opt_res0 = scipy_minimize(minimize_func, model_param_values0, (sim_data0, sim_x0, sim_model0, chi_squared))

sim_x1 = np.arange(3)
model_params1 = {"slope": 8, "intercept": 5}
model_param_values1 = tuple(model_params1.values())
sim_model1 = StraightLineModel(edges=False, **model_params1)
sim_model1 = StraightLineModel(**model_params1)
sim_data1 = sim_model1.evaluate(sim_x1, **model_params1)
opt_res1 = scipy_minimize(minimize_func, model_param_values1, (sim_data1, sim_x1, sim_model1, chi_squared))

Expand Down
14 changes: 2 additions & 12 deletions sunkit_spex/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,10 @@ class StraightLineModel(FittableModel):
slope = Parameter(default=1, description="Gradient of a straight line model.")
intercept = Parameter(default=0, description="Y-intercept of a straight line model.")

def __init__(self, slope=slope, intercept=intercept, edges=True, **kwargs):
self.edges = edges

def __init__(self, slope=slope, intercept=intercept, **kwargs):
super().__init__(slope, intercept, **kwargs)

def evaluate(self, x, slope, intercept):
if self.edges:
x = x[:-1] + 0.5 * np.diff(x)

"""Evaluate the straight line model at `x` with parameters `slope` and `intercept`."""
return slope * x + intercept

Expand Down Expand Up @@ -58,17 +53,12 @@ class GaussianModel(FittableModel):
mean = Parameter(default=0, min=0, description="X-offset for Gaussian.")
stddev = Parameter(default=1, description="Sigma for Gaussian.")

def __init__(self, amplitude=amplitude, mean=mean, stddev=stddev, edges=True, **kwargs):
self.edges = edges

def __init__(self, amplitude=amplitude, mean=mean, stddev=stddev, **kwargs):
super().__init__(amplitude, mean, stddev, **kwargs)

def evaluate(self, x, amplitude, mean, stddev):
"""Evaluate the Gaussian model at `x` with parameters `amplitude`, `mean`, and `stddev`."""

if self.edges:
x = x[:-1] + 0.5 * np.diff(x)

return amplitude * np.e ** (-((x - mean) ** 2) / (2 * stddev**2))

@property
Expand Down
29 changes: 24 additions & 5 deletions sunkit_spex/models/physical/albedo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from functools import lru_cache

import numpy as np
Expand All @@ -11,6 +12,8 @@

from sunpy.data import cache

from sunkit_spex.spectrum.spectrum import SpectralAxis

__all__ = ["Albedo", "get_albedo_matrix"]


Expand Down Expand Up @@ -45,11 +48,11 @@ class Albedo(FittableModel):
from astropy.visualization import quantity_support

from sunkit_spex.models.physical.albedo import Albedo
from sunkit_spex.spectrum.spectrum import SpectralAxis

e_edges = np.linspace(5, 550, 600) * u.keV
e_centers = e_edges[0:-1] + (0.5 * np.diff(e_edges))
e_centers = SpectralAxis(np.linspace(5, 550, 600) * u.keV, bin_specification='edges')
source = PowerLaw1D(amplitude=1*u.ph/(u.cm*u.s), x_0=5*u.keV, alpha=3)
albedo = Albedo(energy_edges=e_edges)
albedo = Albedo(spectral_axis=e_centers)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change variable name

observed = source | albedo

with quantity_support():
Expand Down Expand Up @@ -89,15 +92,17 @@ class Albedo(FittableModel):
_input_units_allow_dimensionless = True

def __init__(self, *args, **kwargs):
self.energy_edges = kwargs.pop("energy_edges")
self.spectral_axis = kwargs.pop("spectral_axis")

super().__init__(*args, **kwargs)

def evaluate(self, spectrum, theta, anisotropy):
if not isinstance(theta, Quantity):
theta = theta * u.deg

albedo_matrix = get_albedo_matrix(self.energy_edges, theta, anisotropy)
energy_edges = _check_input_type(self.spectral_axis)

albedo_matrix = get_albedo_matrix(energy_edges, theta, anisotropy)

return spectrum + spectrum @ albedo_matrix

Expand Down Expand Up @@ -230,3 +235,17 @@ def get_albedo_matrix(energy_edges: Quantity[u.keV], theta: Quantity[u.deg], ani
anisotropy = np.array(anisotropy).squeeze()

return _calculate_albedo_matrix(tuple(energy_edges.to_value(u.keV)), theta.to_value(u.deg), anisotropy.item())


def _check_input_type(spectral_axis):
if isinstance(spectral_axis, SpectralAxis):
energy_edges = spectral_axis.bin_edges
else:
warnings.warn(
"As a SpectralAxis object was not passed, bin edges will be calculated as averages from the centers given.",
UserWarning,
)
spectral_axis = SpectralAxis(spectral_axis, bin_specification="centers")
energy_edges = spectral_axis.bin_edges

return energy_edges
Comment on lines +240 to +251
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should throw an error if spectral axis not passed, not a warning. that's the API change right?

22 changes: 18 additions & 4 deletions sunkit_spex/models/physical/nonthermal.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def __init__(
**kwargs,
)

def evaluate(self, energy_edges, p, break_energy, q, low_e_cutoff, high_e_cutoff, total_eflux):
energy_centers = energy_edges[:-1] + 0.5 * np.diff(energy_edges)
def evaluate(self, spectral_axis, p, break_energy, q, low_e_cutoff, high_e_cutoff, total_eflux):
energy_centers = spectral_axis
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

variable name...i'm gonna stop these comments ;P


if (
hasattr(break_energy, "unit")
Expand Down Expand Up @@ -253,8 +253,8 @@ def __init__(
**kwargs,
)

def evaluate(self, energy_edges, p, break_energy, q, low_e_cutoff, high_e_cutoff, total_eflux):
energy_centers = energy_edges[:-1] + 0.5 * np.diff(energy_edges)
def evaluate(self, spectral_axis, p, break_energy, q, low_e_cutoff, high_e_cutoff, total_eflux):
energy_centers = spectral_axis

if (
hasattr(break_energy, "unit")
Expand Down Expand Up @@ -1238,3 +1238,17 @@ def bremsstrahlung_thick_target(photon_energies, p, break_energy, q, low_e_cutof
return (fcoeff / decoeff) * flux

raise Warning("The photon energies are higher than the highest electron energy or not greater than zero")


# def _check_input_type(spectral_axis):
# if isinstance(spectral_axis, SpectralAxis):
# energy_edges = spectral_axis.bin_edges
# else:
# warnings.warn(
# "As a SpectralAxis object was not passed, bin edges will be calculated as averages from the centers given.",
# UserWarning,
# )
# spectral_axis = SpectralAxis(spectral_axis, bin_specification="centers")
# energy_edges = spectral_axis.bin_edges

# return energy_edges
Comment on lines +1243 to +1254
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

repeated code from before, so delete

14 changes: 8 additions & 6 deletions sunkit_spex/models/physical/tests/test_albedo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from astropy.units import UnitsError

from sunkit_spex.models.physical.albedo import Albedo, get_albedo_matrix
from sunkit_spex.spectrum.spectrum import SpectralAxis


def test_get_albedo_matrix():
Expand Down Expand Up @@ -34,9 +35,9 @@ def test_get_albedo_matrix_bad_angle():

def test_albedo_model():
e_edges = np.linspace(10, 300, 10) * u.keV
e_centers = e_edges[0:-1] + (0.5 * np.diff(e_edges))
e_centers = SpectralAxis(e_edges, bin_specification="edges")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable name should change (here and elsewhere) to reflect the updated type

source = PowerLaw1D(amplitude=100 * u.ph, x_0=10 * u.keV, alpha=4)
observed = source | Albedo(energy_edges=e_edges)
observed = source | Albedo(spectral_axis=e_centers)
observed(e_centers)


Expand Down Expand Up @@ -76,10 +77,11 @@ def test_albedo_idl():
0.0013136881009379996,
]

e_ph = np.arange(11) * 2 + 10
albedo = Albedo(energy_edges=e_ph * u.keV, theta=45 * u.deg)
e_c = e_ph[:-1] + 0.5 * np.diff(e_ph)
spec_in = e_c**-2
e_ph = (np.arange(11) * 2 + 10) * u.keV
e_c = SpectralAxis(e_ph, bin_specification="edges")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Change variable name

albedo = Albedo(spectral_axis=e_c, theta=45 * u.deg)

spec_in = e_c.value**-2
spec_out = albedo(spec_in[:])
assert_allclose(idl_spec_in, spec_in)
assert_allclose(idl_spec_out, spec_out)
17 changes: 9 additions & 8 deletions sunkit_spex/models/physical/tests/test_nonthermal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import astropy.units as u

from sunkit_spex.models.physical import nonthermal
from sunkit_spex.spectrum.spectrum import SpectralAxis

SSW_INTENSITY_UNIT = u.ph / u.cm**2 / u.s / u.keV

Expand Down Expand Up @@ -41,7 +42,7 @@ def thick_target():

Ensure you are using the same .sav file as used here.
"""
energy_edges = np.arange(25, 100.5, 0.5) * u.keV
spectral_axis = SpectralAxis(np.arange(25, 100.5, 0.5) * u.keV, bin_specification="edges")
observer_distance = (1 * u.AU).to(u.cm)
# fmt: off
ssw_output = (
Expand All @@ -65,7 +66,7 @@ def thick_target():
* SSW_INTENSITY_UNIT * (4 * np.pi * observer_distance**2)
)
# fmt: on
return energy_edges, ssw_output
return spectral_axis, ssw_output


def thin_target():
Expand Down Expand Up @@ -101,7 +102,7 @@ def thin_target():

Ensure you are using the same .sav file as used here.
"""
energy_edges = np.arange(25, 100.5, 0.5) * u.keV
spectral_axis = SpectralAxis(np.arange(25, 100.5, 0.5) * u.keV, bin_specification="edges")
observer_distance = (1 * u.AU).to(u.cm)
# fmt: off
ssw_output = (
Expand All @@ -125,22 +126,22 @@ def thin_target():
* SSW_INTENSITY_UNIT * (4 * np.pi * observer_distance**2)
)
# fmt: on
return energy_edges, ssw_output
return spectral_axis, ssw_output


@pytest.mark.parametrize("ssw", [thick_target])
def test_thick_target_against_ssw(ssw):
energy_edges, expected = ssw()
spectral_axis, expected = ssw()
model = nonthermal.ThickTarget()
output = model(energy_edges)
output = model(spectral_axis)
expected_value = expected.to_value(output.unit)
np.testing.assert_allclose(output.value, expected_value, rtol=0.035)


@pytest.mark.parametrize("ssw", [thin_target])
def test_thin_target_against_ssw(ssw):
energy_edges, expected = ssw()
spectral_axis, expected = ssw()
model = nonthermal.ThinTarget()
output = model(energy_edges)
output = model(spectral_axis)
expected_value = expected.to_value(output.unit)
np.testing.assert_allclose(output.value, expected_value, rtol=0.035)
Loading
Loading