Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
62e37c1
WIP
samaloney Sep 26, 2024
f44f95f
Scipy fitting works can't get astropy fitting to work
samaloney Oct 4, 2024
007abbf
WIP
samaloney Oct 16, 2024
130ff72
WIP
samaloney Oct 31, 2024
9649a21
Working with astropy modelis with units
samaloney Oct 31, 2024
27fe104
try again
samaloney Oct 31, 2024
8d066a1
git repo dep
samaloney Oct 31, 2024
3025cc9
Try tag
samaloney Oct 31, 2024
794ab84
Update tests
samaloney Oct 31, 2024
2426745
more test fixes
samaloney Oct 31, 2024
58c5496
Fixed fitting_simulated_data conflicts
jajmitchell May 29, 2025
3c4ee07
Fix attempt 2 fitting_simulated_data
jajmitchell May 29, 2025
8c88be1
Compound fitting now works
jajmitchell May 30, 2025
b2fba8c
Tidy and adds changelog
jajmitchell May 30, 2025
e2ff95a
Apply unit handling and binning normalisation
jajmitchell Jun 6, 2025
17fabd9
Fixes tests
jajmitchell Jun 10, 2025
862c47d
Fixes `pyproject.toml` dependencies
jajmitchell Jun 10, 2025
42d9c1c
Removes `astropy` from `pyproject.toml`
jajmitchell Jun 10, 2025
804084f
Fixes example and unit handling in models.py
jajmitchell Jun 11, 2025
f6dad58
Adds correct astropy into pyproject.toml
jajmitchell Jun 11, 2025
34168d2
Fixes dependencies
jajmitchell Jun 11, 2025
f599155
Fix Scipy optimising function
jajmitchell Jun 12, 2025
83dea89
Spectrum llows counts fluxes and photon bins
jajmitchell Jun 12, 2025
64d668e
Fix pre-commit
jajmitchell Jun 12, 2025
c49e2c5
Fixes pre-commit
jajmitchell Jun 12, 2025
f8b7431
Fixes nonthermal test
jajmitchell Jun 12, 2025
32d3270
Update astropy tag
jajmitchell Jun 12, 2025
f42eadf
Fixes pyproject.toml
jajmitchell Jun 12, 2025
f22d7a2
Fix spectrum indexing
jajmitchell Jun 12, 2025
75e301c
Second attempt at fixing spectrum
jajmitchell Jun 12, 2025
1752747
Attempt 3 at fixing spectrum
jajmitchell Jun 12, 2025
030072a
Skips AstropyDepcrecitaionWarning
jajmitchell Jun 12, 2025
f6d6307
Fixes optimising_functions
jajmitchell Jun 12, 2025
45207ac
Second attempt fix optimising_functions
jajmitchell Jun 12, 2025
80fa881
Fixes optimising_functions fingers crossed
jajmitchell Jun 12, 2025
c8dbe3f
Fixes spectrum
jajmitchell Jun 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/163.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Adds unit handling capability to class::MatrixModel to support Astropy fitting.
194 changes: 125 additions & 69 deletions examples/fitting_simulated_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@
import numpy as np
from matplotlib.colors import LogNorm

import astropy.units as u
from astropy.modeling import fitting
from astropy.visualization import quantity_support

from sunkit_spex.data.simulated_data import simulate_square_response_matrix
from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func
from sunkit_spex.fitting.optimizer_tools.minimizer_tools import scipy_minimize
from sunkit_spex.fitting.statistics.gaussian import chi_squared
from sunkit_spex.models.instrument_response import MatrixModel
from sunkit_spex.models.models import GaussianModel, StraightLineModel
from sunkit_spex.spectrum import Spectrum
from sunkit_spex.spectrum.spectrum import SpectralAxis

#####################################################
#
Expand All @@ -37,87 +41,110 @@

start, inc = 1.6, 0.04
stop = 80 + inc / 2
ph_energies = np.arange(start, stop, inc)
ph_energies = np.arange(start, stop, inc) * u.keV
ph_energies_centers = ph_energies[:-1] + 0.5 * np.diff(ph_energies)

#####################################################
#
# Let's start making a simulated photon spectrum

sim_cont = {"edges": False, "slope": -1, "intercept": 100}
sim_line = {"edges": False, "amplitude": 100, "mean": 30, "stddev": 2}
sim_cont = {"slope": -1 * u.ph / u.keV**2, "intercept": 100 * u.ph / u.keV}
sim_line = {"amplitude": 100 * u.ph / u.keV, "mean": 30 * u.keV, "stddev": 2 * u.keV}
# use a straight line model for a continuum, Gaussian for a line
ph_model = StraightLineModel(**sim_cont) + GaussianModel(**sim_line)

plt.figure()
plt.plot(ph_energies, ph_model(ph_energies))
plt.xlabel("Energy [keV]")
plt.ylabel("ph s$^{-1}$ cm$^{-2}$ keV$^{-1}$")
plt.title("Simulated Photon Spectrum")
plt.show()
with quantity_support():
plt.figure()
plt.plot(ph_energies_centers, ph_model(ph_energies))
plt.xlabel(f"Energy [{ph_energies.unit}]")
plt.title("Simulated Photon Spectrum")
plt.show()

#####################################################
#
# Now want a response matrix

srm = simulate_square_response_matrix(ph_energies.size)
srm_model = MatrixModel(matrix=srm)

plt.figure()
plt.imshow(
srm, origin="lower", extent=[ph_energies[0], ph_energies[-1], ph_energies[0], ph_energies[-1]], norm=LogNorm()
srm = simulate_square_response_matrix(ph_energies.size - 1)
srm_model = MatrixModel(
matrix=srm,
input_axis=SpectralAxis(ph_energies),
output_axis=SpectralAxis(ph_energies),
c=1 * u.ct / u.ph,
_input_units={"x": u.ph * u.keV**-1},
_output_units={"y": u.ct * u.keV**-1},
)
plt.ylabel("Photon Energies [keV]")
plt.xlabel("Count Energies [keV]")
plt.title("Simulated SRM")
plt.show()
# srm_model.input_units = {"x": u.ph}


with quantity_support():
plt.figure()
plt.imshow(
srm_model.matrix,
origin="lower",
extent=(
srm_model.input_axis[0].value,
srm_model.input_axis[-1].value,
srm_model.output_axis[0].value,
srm_model.output_axis[-1].value,
),
norm=LogNorm(),
)
plt.ylabel(f"Photon Energies [{srm_model.input_axis.unit}]")
plt.xlabel(f"Count Energies [{srm_model.output_axis.unit}]")
plt.title("Simulated SRM")
plt.show()

#####################################################
#
# Start work on a count model

sim_gauss = {"edges": False, "amplitude": 70, "mean": 40, "stddev": 2}
sim_gauss = {"amplitude": 70 * u.ct / u.keV, "mean": 40 * u.keV, "stddev": 2 * u.keV}
# the brackets are very necessary
ct_model = (ph_model | srm_model) + GaussianModel(**sim_gauss)
ct_model = ph_model | srm_model

#####################################################
#
# Generate simulated count data to (almost) fit

sim_count_model = ct_model(ph_energies)

sim_count_model = ct_model(SpectralAxis(ph_energies))
#####################################################
#
# Add some noise
np_rand = np.random.default_rng(seed=10)
sim_count_model_wn = sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model)
sim_count_model_wn = (
sim_count_model + (2 * np_rand.random(sim_count_model.size)) * np.sqrt(sim_count_model.value) * u.ct / u.keV
)

obs_spec = Spectrum(sim_count_model_wn.reshape(-1), spectral_axis=ph_energies)


#####################################################
#
# Can plot all the different components in the simulated count spectrum

plt.figure()
plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features")
plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature")
plt.plot(ph_energies, sim_count_model, label="total sim. spectrum")
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise", lw=0.5)
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum")
plt.legend()
with quantity_support():
plt.figure()
plt.plot(ph_energies_centers, (ph_model | srm_model)(ph_energies), label="photon model features")
plt.plot(ph_energies_centers, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature")
plt.plot(ph_energies_centers, sim_count_model, label="total sim. spectrum")
plt.plot(obs_spec._spectral_axis, obs_spec.data, label="total sim. spectrum + noise", lw=0.5)
plt.xlabel(f"Energy [{ph_energies.unit}]")
plt.title("Simulated Count Spectrum")
plt.legend()

plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold")
plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold")
plt.show()
plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold")
plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold")
plt.show()

#####################################################
#
# Now we have the simulated data, let's start setting up to fit it
#
# Get some initial guesses that are off from the simulated data above

guess_cont = {"edges": False, "slope": -0.5, "intercept": 80}
guess_line = {"edges": False, "amplitude": 150, "mean": 32, "stddev": 5}
guess_gauss = {"edges": False, "amplitude": 350, "mean": 39, "stddev": 0.5}
guess_cont = {"slope": -0.5 * u.ph / u.keV**2, "intercept": 80 * u.ph / u.keV}
guess_line = {"amplitude": 150 * u.ph / u.keV, "mean": 32 * u.keV, "stddev": 5 * u.keV}
guess_gauss = {"amplitude": 350 * u.ct / u.keV, "mean": 39 * u.keV, "stddev": 0.5 * u.keV}

#####################################################
#
Expand All @@ -126,22 +153,24 @@
ph_mod_4fit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line)
count_model_4fit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss)

#####################################################
#
# Let's fit the simulated data and plot the result

opt_res = scipy_minimize(
minimize_func, count_model_4fit.parameters, (sim_count_model_wn, ph_energies, count_model_4fit, chi_squared)
)
# print(ph_mod_4fit(ph_energies).size)
# print(count_model_4fit(obs_spec.data).size)
# #####################################################
# #
# # Let's fit the simulated data and plot the result

plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies, *opt_res.x), ls=":", label="model fit")
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum Fit with Scipy")
plt.legend()
plt.show()

opt_res = scipy_minimize(minimize_func, count_model_4fit.parameters, (obs_spec, count_model_4fit, chi_squared))

with quantity_support():
plt.figure()
plt.plot(ph_energies_centers, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(ph_energies_centers, count_model_4fit.evaluate(ph_energies.value, *opt_res.x), ls=":", label="model fit")
plt.xlabel(f"Energy [{ph_energies.unit}]")
plt.title("Simulated Count Spectrum Fit with Scipy")
plt.legend()
plt.show()


#####################################################
Expand All @@ -150,18 +179,31 @@
#
# Try and ensure we start fresh with new model definitions

guess_cont = {"slope": -0.5 * u.ph / u.keV**2, "intercept": 80 * u.ph / u.keV}
guess_line = {"amplitude": 150 * u.ph / u.keV, "mean": 32 * u.keV, "stddev": 5 * u.keV}

ph_mod_4astropyfit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line)
count_model_4astropyfit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss)

astropy_fit = fitting.LevMarLSQFitter()
cgauss = GaussianModel(**guess_gauss)


count_model_4astropyfit = (ph_mod_4astropyfit | srm_model) + cgauss


astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, sim_count_model_wn)
astropy_fit = fitting.LevMarLSQFitter()
astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, obs_spec.data << obs_spec.unit)

plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(ph_energies, astropy_fitted_result(ph_energies), ls=":", label="model fit")
plt.plot(ph_energies_centers, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(
ph_energies_centers,
count_model_4astropyfit.evaluate(ph_energies.value, *astropy_fitted_result.parameters),
ls=":",
label="model fit",
)

plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.ylabel("ct keV$^{-1}$")
plt.title("Simulated Count Spectrum Fit with Astropy")
plt.legend()
plt.show()
Expand All @@ -170,24 +212,38 @@
#
# Display a table of the fitted results

plt.figure(layout="constrained")
# plt.figure(layout="constrained")


# row_labels = (
# tuple(sim_cont)[-2:] + tuple(f"{p}1" for p in tuple(sim_line)[-3:]) + tuple(f"{p}2" for p in tuple(sim_gauss)[-3:])
# )
# column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit")
# true_vals = np.array(tuple(sim_cont.values())[-2:] + tuple(sim_line.values())[-3:] + tuple(sim_gauss.values())[-3:])
# guess_vals = np.array(
# tuple(guess_cont.values())[-2:] + tuple(guess_line.values())[-3:] + tuple(guess_gauss.values())[-3:]
# )
# scipy_vals = opt_res.x
# astropy_vals = astropy_fitted_result.parameters

# print(np.shape(scipy_vals))
# print(np.shape(astropy_vals))
# print(np.shape(true_vals))
# print(np.shape(guess_vals))

plt.figure(layout="constrained")

row_labels = (
tuple(sim_cont)[-2:] + tuple(f"{p}1" for p in tuple(sim_line)[-3:]) + tuple(f"{p}2" for p in tuple(sim_gauss)[-3:])
tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + ("C",) + tuple(f"{p}2" for p in tuple(sim_gauss))
)
column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit")
true_vals = np.array(tuple(sim_cont.values())[-2:] + tuple(sim_line.values())[-3:] + tuple(sim_gauss.values())[-3:])
guess_vals = np.array(
tuple(guess_cont.values())[-2:] + tuple(guess_line.values())[-3:] + tuple(guess_gauss.values())[-3:]
)
true_vals = tuple(sim_cont.values()) + tuple(sim_line.values()) + (1 * u.m,) + tuple(sim_gauss.values())
true_vals = [t.value for t in true_vals]
guess_vals = tuple(guess_cont.values()) + tuple(guess_line.values()) + (1 * u.m,) + tuple(guess_gauss.values())
guess_vals = [g.value for g in guess_vals]
scipy_vals = opt_res.x
astropy_vals = astropy_fitted_result.parameters

print(np.shape(scipy_vals))
print(np.shape(astropy_vals))
print(np.shape(true_vals))
print(np.shape(guess_vals))

cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T
cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ authors = [
{ name = "The SunPy Community", email = "sunpy@googlegroups.com" },
]
dependencies = [
"astropy @ git+https://github.com/jajmitchell/astropy.git@astropy_sunkit-spex",
"corner>=2.2",
"emcee>=3.1",
"matplotlib>=3.7",
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,4 @@ filterwarnings =
# Oldestdeps issues
ignore:`finfo.machar` is deprecated:DeprecationWarning
ignore:Please use `convolve1d` from the `scipy.ndimage` namespace, the `scipy.ndimage.filters` namespace is deprecated.:DeprecationWarning
ignore::astropy.utils.exceptions.AstropyDeprecationWarning
12 changes: 9 additions & 3 deletions sunkit_spex/fitting/objective_functions/optimising_functions.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want to keep this file? it seems like it's just wrapping a function call. might be more flexible to just not have it

Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
__all__ = ["minimize_func"]


def minimize_func(params, data_y, model_x, model_func, statistic_func):
def minimize_func(params, obs_spec, model_func, statistic_func):
"""
Minimization function.

Expand All @@ -31,6 +31,12 @@ def minimize_func(params, data_y, model_x, model_func, statistic_func):
-------
`float`
The value to be optimized that compares the model to the data.

"""
model_y = model_func.evaluate(model_x, *params)
return statistic_func(data_y, model_y)

if obs_spec._spectral_axis._bin_edges is not None:
model_y = model_func.evaluate(obs_spec._spectral_axis._bin_edges.value, *params)
else:
model_y = model_func.evaluate(obs_spec._spectral_axis.value, *params)

return statistic_func(obs_spec.data, model_y)
13 changes: 7 additions & 6 deletions sunkit_spex/fitting/tests/test_objective_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,34 @@

import numpy as np

import astropy.units as u

from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func
from sunkit_spex.fitting.statistics.gaussian import chi_squared
from sunkit_spex.models.models import StraightLineModel
from sunkit_spex.spectrum import Spectrum


def test_minimize_func():
"""Test the `minimize_func` function against known outputs."""
sim_x0 = np.arange(3)
sim_x0 = np.arange(3) * u.keV
model_params0 = {"slope": 1, "intercept": 0}
sim_model0 = StraightLineModel(edges=False, **model_params0)
sim_data0 = sim_model0.evaluate(sim_x0, **model_params0)
res0 = minimize_func(
params=tuple(model_params0.values()),
data_y=sim_data0,
model_x=sim_x0,
obs_spec=Spectrum(sim_data0, spectral_axis=sim_x0),
model_func=sim_model0,
statistic_func=chi_squared,
)

sim_x1 = np.arange(3)
sim_x1 = np.arange(3) * u.keV
model_params1 = {"slope": 1, "intercept": 0}
sim_model1 = StraightLineModel(edges=False, **model_params1)
sim_data1 = sim_model1.evaluate(sim_x1, **model_params1)[::-1]
res1 = minimize_func(
params=tuple(model_params1.values()),
data_y=sim_data1,
model_x=sim_x1,
obs_spec=Spectrum(sim_data1, spectral_axis=sim_x1),
model_func=sim_model1,
statistic_func=chi_squared,
)
Expand Down
Loading
Loading