From 00f5b9b3c36733a98c9f6cba613e597ffede2c0c Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Tue, 2 Jun 2026 17:02:53 -0500 Subject: [PATCH 1/5] upgrade to diffsky's gd_specphot_... and use poisson loss in emline loss --- diffhtwo/experimental/conftest.py | 11 +- .../data/hizels/halpha_LF_z0p4.dat | 41 +++--- .../data/hizels/halpha_LF_z0p84.dat | 20 +-- .../data/hizels/halpha_LF_z1p47.dat | 29 +++-- .../data/hizels/halpha_LF_z2p23.dat | 33 +++-- .../experimental/data_loaders/load_hizels.py | 117 +++++++++++------- diffhtwo/experimental/kernels/N_spec.py | 40 ++++++ diffhtwo/experimental/kernels/lc_spec_kern.py | 41 ++++++ diffhtwo/experimental/kernels/spec_kern.py | 6 +- .../experimental/kernels/tests/test_N_spec.py | 22 ++++ .../experimental/loss_kernels/emline_loss.py | 63 +++++----- .../loss_kernels/emline_loss_mse.py | 86 +++++++++++++ .../loss_kernels/tests/test_emline_loss.py | 26 ++-- .../tests/test_emline_loss_mse.py | 41 ++++++ .../optimizers/Np_specphot_opt.py | 55 ++++---- .../optimizers/tests/test_Np_specphot_opt.py | 14 +-- scripts/config_diagnostics.yaml | 10 +- 17 files changed, 450 insertions(+), 205 deletions(-) create mode 100644 diffhtwo/experimental/kernels/N_spec.py create mode 100644 diffhtwo/experimental/kernels/lc_spec_kern.py create mode 100644 diffhtwo/experimental/kernels/tests/test_N_spec.py create mode 100644 diffhtwo/experimental/loss_kernels/emline_loss_mse.py create mode 100644 diffhtwo/experimental/loss_kernels/tests/test_emline_loss_mse.py diff --git a/diffhtwo/experimental/conftest.py b/diffhtwo/experimental/conftest.py index 28f90611..328d94ef 100644 --- a/diffhtwo/experimental/conftest.py +++ b/diffhtwo/experimental/conftest.py @@ -52,16 +52,13 @@ def feniks(ran_key, fake_subset_ssp_data): @pytest.fixture(scope="session") -def hizels(ran_key, fake_subset_ssp_data, feniks_tcurves): +def hizels_fitting_data(ran_key, fake_subset_ssp_data, feniks_tcurves): ssp_data, emline_wave_aa = fake_subset_ssp_data - hizels = load_hizels.get_hizels_data( - HIZELS_DRN, - ran_key, - ssp_data, - feniks_tcurves, + hizels_fitting_data = load_hizels.get_hizels_data( + HIZELS_DRN, ran_key, ssp_data, feniks_tcurves, emline_wave_aa ) - return hizels + return hizels_fitting_data @pytest.fixture(scope="session") diff --git a/diffhtwo/experimental/data/hizels/halpha_LF_z0p4.dat b/diffhtwo/experimental/data/hizels/halpha_LF_z0p4.dat index 3e244a3c..f3b28d85 100644 --- a/diffhtwo/experimental/data/hizels/halpha_LF_z0p4.dat +++ b/diffhtwo/experimental/data/hizels/halpha_LF_z0p4.dat @@ -1,23 +1,20 @@ # z = 0.40 -# Columns: -# logLHa logLHa_binw_full logphi_obs logphi_obs_err logphi_corr logphi_corr_err -logLHa logLHa_binw_full logphi_obs logphi_obs_err logphi_corr logphi_corr_err -40.50 0.10 -1.84 0.04 -1.66 0.04 -40.60 0.10 -1.78 0.04 -1.70 0.04 -40.70 0.10 -1.87 0.04 -1.81 0.04 -40.80 0.10 -2.01 0.05 -1.93 0.05 -40.90 0.10 -2.20 0.06 -1.96 0.07 -41.00 0.10 -2.21 0.06 -2.03 0.07 -41.10 0.10 -2.41 0.08 -2.12 0.09 -41.20 0.10 -2.39 0.08 -2.27 0.08 -41.30 0.10 -2.43 0.08 -2.29 0.09 -41.40 0.10 -2.55 0.10 -2.42 0.10 -41.50 0.10 -2.55 0.10 -2.46 0.11 -41.60 0.10 -2.71 0.12 -2.57 0.13 -41.70 0.10 -2.94 0.17 -2.69 0.19 -41.80 0.10 -2.90 0.16 -2.73 0.17 -41.90 0.10 -3.04 0.19 -2.88 0.20 -42.00 0.10 -3.34 0.30 -3.03 0.35 -42.20 0.20 -3.45 0.36 -3.56 0.51 -42.50 0.30 -3.64 0.53 -3.71 0.71 - +logLHa logLHa_binw_full nsources logphi_obs logphi_obs_err logphi_corr logphi_corr_err vol_1e4Mpc3 +40.50 0.10 128 -1.84 0.04 -1.66 0.04 8.8 +40.60 0.10 147 -1.78 0.04 -1.70 0.04 8.8 +40.70 0.10 118 -1.87 0.04 -1.81 0.04 8.8 +40.80 0.10 86 -2.01 0.05 -1.93 0.05 8.8 +40.90 0.10 56 -2.20 0.06 -1.96 0.07 8.8 +41.00 0.10 54 -2.21 0.06 -2.03 0.07 8.8 +41.10 0.10 34 -2.41 0.08 -2.12 0.09 8.8 +41.20 0.10 36 -2.39 0.08 -2.27 0.08 8.8 +41.30 0.10 33 -2.43 0.08 -2.29 0.09 8.8 +41.40 0.10 25 -2.55 0.10 -2.42 0.10 8.8 +41.50 0.10 25 -2.55 0.10 -2.46 0.11 8.8 +41.60 0.10 17 -2.71 0.12 -2.57 0.13 8.8 +41.70 0.10 10 -2.94 0.17 -2.69 0.19 8.8 +41.80 0.10 11 -2.90 0.16 -2.73 0.17 8.8 +41.90 0.10 8 -3.04 0.19 -2.88 0.20 8.8 +42.00 0.10 4 -3.34 0.30 -3.03 0.35 8.8 +42.20 0.20 3 -3.45 0.36 -3.56 0.51 8.8 +42.50 0.30 2 -3.64 0.53 -3.71 0.71 8.8 diff --git a/diffhtwo/experimental/data/hizels/halpha_LF_z0p84.dat b/diffhtwo/experimental/data/hizels/halpha_LF_z0p84.dat index b2cf5ef5..74ae341f 100644 --- a/diffhtwo/experimental/data/hizels/halpha_LF_z0p84.dat +++ b/diffhtwo/experimental/data/hizels/halpha_LF_z0p84.dat @@ -1,11 +1,11 @@ # z = 0.84 -logLHa logLHa_binw_full logphi_obs logphi_obs_err logphi_corr logphi_corr_err -41.70 0.15 -2.12 0.03 -1.93 0.03 -41.85 0.15 -2.11 0.03 -2.02 0.03 -42.00 0.15 -2.43 0.04 -2.18 0.04 -42.15 0.15 -2.72 0.06 -2.43 0.06 -42.30 0.15 -3.38 0.15 -2.73 0.17 -42.45 0.15 -3.46 0.17 -3.01 0.17 -42.60 0.15 -3.61 0.21 -3.27 0.21 -42.75 0.15 -4.16 0.53 -3.79 0.55 -42.90 0.15 -4.46 0.90 -4.13 1.51 +logLHa logLHa_binw_full nsources logphi_obs logphi_obs_err logphi_corr logphi_corr_err vol_1e4Mpc3 +41.70 0.15 218 -2.12 0.03 -1.93 0.03 19.1 +41.85 0.15 222 -2.11 0.03 -2.02 0.03 19.1 +42.00 0.15 107 -2.43 0.04 -2.18 0.04 19.1 +42.15 0.15 54 -2.72 0.06 -2.43 0.06 19.1 +42.30 0.15 12 -3.38 0.15 -2.73 0.17 19.1 +42.45 0.15 10 -3.46 0.17 -3.01 0.17 19.1 +42.60 0.15 7 -3.61 0.21 -3.27 0.21 19.1 +42.75 0.15 2 -4.16 0.53 -3.79 0.55 19.1 +42.90 0.15 1 -4.46 0.90 -4.13 1.51 19.1 diff --git a/diffhtwo/experimental/data/hizels/halpha_LF_z1p47.dat b/diffhtwo/experimental/data/hizels/halpha_LF_z1p47.dat index 92def431..e115a6a2 100644 --- a/diffhtwo/experimental/data/hizels/halpha_LF_z1p47.dat +++ b/diffhtwo/experimental/data/hizels/halpha_LF_z1p47.dat @@ -1,16 +1,15 @@ # z = 1.47 -logLHa logLHa_binw_full logphi_obs logphi_obs_err logphi_corr logphi_corr_err -42.10 0.10 -2.20 0.10 -2.13 0.10 -42.20 0.10 -2.37 0.08 -2.25 0.09 -42.30 0.10 -2.55 0.06 -2.34 0.06 -42.40 0.10 -2.67 0.05 -2.47 0.05 -42.50 0.10 -2.78 0.05 -2.62 0.05 -42.60 0.10 -2.83 0.04 -2.73 0.04 -42.70 0.10 -3.23 0.07 -2.91 0.08 -42.80 0.10 -3.50 0.10 -3.18 0.11 -42.90 0.10 -3.91 0.18 -3.55 0.18 -43.00 0.10 -4.17 0.26 -3.81 0.26 -43.10 0.10 -4.39 0.37 -4.22 0.38 -43.20 0.10 -4.57 0.53 -4.55 0.55 -43.40 0.30 -4.57 0.53 -4.86 0.55 - +logLHa logLHa_binw_full nsources logphi_obs logphi_obs_err logphi_corr logphi_corr_err vol_1e4Mpc3 +42.10 0.10 25 -2.20 0.10 -2.13 0.10 4.0 +42.20 0.10 32 -2.37 0.08 -2.25 0.09 7.5 +42.30 0.10 62 -2.55 0.06 -2.34 0.06 22.1 +42.40 0.10 86 -2.67 0.05 -2.47 0.05 40.2 +42.50 0.10 101 -2.78 0.05 -2.62 0.05 60.4 +42.60 0.10 106 -2.83 0.04 -2.73 0.04 71.4 +42.70 0.10 43 -3.23 0.07 -2.91 0.08 73.6 +42.80 0.10 23 -3.50 0.10 -3.18 0.11 73.6 +42.90 0.10 9 -3.91 0.18 -3.55 0.18 73.6 +43.00 0.10 5 -4.17 0.26 -3.81 0.26 73.6 +43.10 0.10 3 -4.39 0.37 -4.22 0.38 73.6 +43.20 0.10 2 -4.57 0.53 -4.55 0.55 73.6 +43.40 0.30 2 -4.57 0.53 -4.86 0.55 73.6 diff --git a/diffhtwo/experimental/data/hizels/halpha_LF_z2p23.dat b/diffhtwo/experimental/data/hizels/halpha_LF_z2p23.dat index 5a0a66bc..cc988fae 100644 --- a/diffhtwo/experimental/data/hizels/halpha_LF_z2p23.dat +++ b/diffhtwo/experimental/data/hizels/halpha_LF_z2p23.dat @@ -1,18 +1,17 @@ # z = 2.23 -logLHa logLHa_binw_full logphi_obs logphi_obs_err logphi_corr logphi_corr_err -42.00 0.15 -2.18 0.19 -1.93 0.19 -42.15 0.15 -2.34 0.16 -2.07 0.16 -42.30 0.10 -2.24 0.07 -2.19 0.07 -42.40 0.10 -2.36 0.05 -2.31 0.05 -42.50 0.10 -2.48 0.04 -2.41 0.05 -42.60 0.10 -2.60 0.04 -2.50 0.04 -42.70 0.10 -2.68 0.04 -2.59 0.05 -42.80 0.10 -2.89 0.05 -2.73 0.06 -42.90 0.10 -3.18 0.07 -2.88 0.14 -43.00 0.10 -3.41 0.09 -3.09 0.17 -43.10 0.10 -3.68 0.12 -3.33 0.22 -43.20 0.10 -4.04 0.21 -3.67 0.31 -43.30 0.10 -4.41 0.37 -4.01 0.51 -43.40 0.10 -4.59 0.53 -4.22 0.68 -43.60 0.30 -4.41 0.37 -4.63 0.41 - +logLHa logLHa_binw_full nsources logphi_obs logphi_obs_err logphi_corr logphi_corr_err vol_1e4Mpc3 +42.00 0.15 8 -2.18 0.19 -1.93 0.19 0.8 +42.15 0.15 11 -2.34 0.16 -2.07 0.16 1.6 +42.30 0.10 47 -2.24 0.07 -2.19 0.07 6.7 +42.40 0.10 91 -2.36 0.05 -2.31 0.05 20.9 +42.50 0.10 107 -2.48 0.04 -2.41 0.05 32.7 +42.60 0.10 158 -2.60 0.04 -2.50 0.04 63.3 +42.70 0.10 163 -2.68 0.04 -2.59 0.05 77.2 +42.80 0.10 100 -2.89 0.05 -2.73 0.06 77.2 +42.90 0.10 51 -3.18 0.07 -2.88 0.14 77.2 +43.00 0.10 30 -3.41 0.09 -3.09 0.17 77.2 +43.10 0.10 16 -3.68 0.12 -3.33 0.22 77.2 +43.20 0.10 7 -4.04 0.21 -3.67 0.31 77.2 +43.30 0.10 3 -4.41 0.37 -4.01 0.51 77.2 +43.40 0.10 2 -4.59 0.53 -4.22 0.68 77.2 +43.60 0.30 3 -4.41 0.37 -4.63 0.41 77.2 diff --git a/diffhtwo/experimental/data_loaders/load_hizels.py b/diffhtwo/experimental/data_loaders/load_hizels.py index a9a6bde1..39c89283 100644 --- a/diffhtwo/experimental/data_loaders/load_hizels.py +++ b/diffhtwo/experimental/data_loaders/load_hizels.py @@ -7,9 +7,9 @@ from ..lightcone_generators import generate_lc_data -HiZELS = namedtuple( - "HiZELS", - ["lg_Lbin_edges", "lg_LF", "z", "dz", "lc_data"], +Hizels = namedtuple( + "Hizels", + ["line_wave_aa", "lg_Lbin_edges", "N_data", "vol_Mpc3_data", "z", "dz", "lc_data"], ) DELTA_L_HALPHA = -0.4 # uncorrect HiZELS h-alpha L for dust (A_halpha = 1 mag) @@ -19,7 +19,8 @@ def get_hizels_data( ran_key, ssp_data, tcurves, - num_halos=500, + halpha_wave_aa, + num_halos=250, lgmp_min=10.0, lgmp_max=15.0, lc_sky_area_degsq=100, @@ -27,15 +28,18 @@ def get_hizels_data( ): ( hizels_lg_halpha_Lbin_edges_data, - hizels_lg_halpha_LF_data, - hizels_halpha_LF_z_data, - hizels_halpha_LF_delta_z_data, + hizels_halpha_N_data, + hizels_halpha_vol_Mpc3, + hizels_halpha_z_data, + hizels_halpha_delta_z_data, ) = get_hizels_halpha(drn) + line_wave_aa = [halpha_wave_aa] lg_Lbin_edges = [hizels_lg_halpha_Lbin_edges_data] - lg_LF = [hizels_lg_halpha_LF_data] - z = [hizels_halpha_LF_z_data] - dz = [hizels_halpha_LF_delta_z_data] + N_data = [hizels_halpha_N_data] + vol_Mpc3_data = [hizels_halpha_vol_Mpc3] + z = [hizels_halpha_z_data] + dz = [hizels_halpha_delta_z_data] lc_data = [] for line in range(0, len(z)): @@ -66,10 +70,10 @@ def get_hizels_data( line_lc_data.append(generate_lc_data(*lc_args)) lc_data.append(line_lc_data) - return HiZELS(lg_Lbin_edges, lg_LF, z, dz, lc_data) + return Hizels(line_wave_aa, lg_Lbin_edges, N_data, vol_Mpc3_data, z, dz, lc_data) -def get_lgL_bin_edges( +def _get_lgL_bin_edges( table, L_colname, bin_width_full_colname, delta_L_halpha=DELTA_L_HALPHA ): edges = [] @@ -109,56 +113,67 @@ def pad_dummy_lg_LF_data(lg_halpha_LF_data, lg_halpha_LF_dummy_err, max_length=1 return jnp.vstack((lg_halpha_LF_data_padded, lg_halpha_LF_err_padded)) -def lg_phi_h0p7_to_hdefault(lg_phi_h0p7): +def _lg_phi_h0p7_to_hdefault(lg_phi_h0p7): phi_h1p0 = (10**lg_phi_h0p7) / (0.7**3) return np.log10(phi_h1p0 * (DEFAULT_COSMOLOGY.h**3)) +def _vol_h0p7_to_hdefault(vol_1e4Mpc3): + vol_Mpc3_h1p0 = 1e4 * vol_1e4Mpc3 * (0.7**3) + vol_Mpc3 = vol_Mpc3_h1p0 / (DEFAULT_COSMOLOGY.h**3) + return vol_Mpc3 + + +def _lg_phi_corr_to_N_corr(lg_phi_corr, vol_1e4Mpc3): + phi_corr = 10**lg_phi_corr + vol_Mpc3 = 1e4 * vol_1e4Mpc3 + N_corr = phi_corr * vol_Mpc3 + return N_corr + + def get_hizels_halpha(drn): HiZELS_halpha_z0p4 = ascii.read(drn / "halpha_LF_z0p4.dat") - lg_halpha_Lbin_edges_z0p4 = get_lgL_bin_edges( + lg_halpha_Lbin_edges_z0p4 = _get_lgL_bin_edges( HiZELS_halpha_z0p4, "logLHa", "logLHa_binw_full" ) - lg_halpha_LF_data_z0p4 = jnp.vstack( - ( - jnp.array(lg_phi_h0p7_to_hdefault(HiZELS_halpha_z0p4["logphi_corr"])), - jnp.array(HiZELS_halpha_z0p4["logphi_corr_err"]), - ) + # lg_halpha_LF_data_z0p4 = jnp.vstack( + # ( + # jnp.array(lg_phi_h0p7_to_hdefault(HiZELS_halpha_z0p4["logphi_corr"])), + # jnp.array(HiZELS_halpha_z0p4["logphi_corr_err"]), + # ) + # ) + halpha_N_data_z0p4 = _lg_phi_corr_to_N_corr( + HiZELS_halpha_z0p4["logphi_corr"], HiZELS_halpha_z0p4["vol_1e4Mpc3"] ) + halpha_vol_Mpc3_z0p4 = _vol_h0p7_to_hdefault(HiZELS_halpha_z0p4["vol_1e4Mpc3"]) HiZELS_halpha_z0p84 = ascii.read(drn / "halpha_LF_z0p84.dat") - lg_halpha_Lbin_edges_z0p84 = get_lgL_bin_edges( + lg_halpha_Lbin_edges_z0p84 = _get_lgL_bin_edges( HiZELS_halpha_z0p84, "logLHa", "logLHa_binw_full" ) - lg_halpha_LF_data_z0p84 = jnp.vstack( - ( - jnp.array(lg_phi_h0p7_to_hdefault(HiZELS_halpha_z0p84["logphi_corr"])), - jnp.array(HiZELS_halpha_z0p84["logphi_corr_err"]), - ) + halpha_N_data_z0p84 = _lg_phi_corr_to_N_corr( + HiZELS_halpha_z0p84["logphi_corr"], HiZELS_halpha_z0p84["vol_1e4Mpc3"] ) + halpha_vol_Mpc3_z0p84 = _vol_h0p7_to_hdefault(HiZELS_halpha_z0p84["vol_1e4Mpc3"]) HiZELS_halpha_z1p47 = ascii.read(drn / "halpha_LF_z1p47.dat") - lg_halpha_Lbin_edges_z1p47 = get_lgL_bin_edges( + lg_halpha_Lbin_edges_z1p47 = _get_lgL_bin_edges( HiZELS_halpha_z1p47, "logLHa", "logLHa_binw_full" ) - lg_halpha_LF_data_z1p47 = jnp.vstack( - ( - jnp.array(lg_phi_h0p7_to_hdefault(HiZELS_halpha_z1p47["logphi_corr"])), - jnp.array(HiZELS_halpha_z1p47["logphi_corr_err"]), - ) + halpha_N_data_z1p47 = _lg_phi_corr_to_N_corr( + HiZELS_halpha_z1p47["logphi_corr"], HiZELS_halpha_z1p47["vol_1e4Mpc3"] ) + halpha_vol_Mpc3_z1p47 = _vol_h0p7_to_hdefault(HiZELS_halpha_z1p47["vol_1e4Mpc3"]) HiZELS_halpha_z2p23 = ascii.read(drn / "halpha_LF_z2p23.dat") - lg_halpha_Lbin_edges_z2p23 = get_lgL_bin_edges( + lg_halpha_Lbin_edges_z2p23 = _get_lgL_bin_edges( HiZELS_halpha_z2p23, "logLHa", "logLHa_binw_full" ) - lg_halpha_LF_data_z2p23 = jnp.vstack( - ( - jnp.array(lg_phi_h0p7_to_hdefault(HiZELS_halpha_z2p23["logphi_corr"])), - jnp.array(HiZELS_halpha_z2p23["logphi_corr_err"]), - ) + halpha_N_data_z2p23 = _lg_phi_corr_to_N_corr( + HiZELS_halpha_z2p23["logphi_corr"], HiZELS_halpha_z2p23["vol_1e4Mpc3"] ) + halpha_vol_Mpc3_z2p23 = _vol_h0p7_to_hdefault(HiZELS_halpha_z2p23["vol_1e4Mpc3"]) hizels_lg_halpha_Lbin_edges_data = [ lg_halpha_Lbin_edges_z0p4, @@ -167,21 +182,28 @@ def get_hizels_halpha(drn): lg_halpha_Lbin_edges_z2p23, ] - hizels_lg_halpha_LF_data = [ - lg_halpha_LF_data_z0p4, - lg_halpha_LF_data_z0p84, - lg_halpha_LF_data_z1p47, - lg_halpha_LF_data_z2p23, + hizels_halpha_N_data = [ + halpha_N_data_z0p4, + halpha_N_data_z0p84, + halpha_N_data_z1p47, + halpha_N_data_z2p23, + ] + + hizels_halpha_vol_Mpc3 = [ + halpha_vol_Mpc3_z0p4, + halpha_vol_Mpc3_z0p84, + halpha_vol_Mpc3_z1p47, + halpha_vol_Mpc3_z2p23, ] - hizels_halpha_LF_z_data = [ + hizels_halpha_z_data = [ jnp.float64(0.40), jnp.float64(0.84), jnp.float64(1.47), jnp.float64(2.23), ] - hizels_halpha_LF_delta_z_data = [ + hizels_halpha_delta_z_data = [ 0.02, 0.03, 0.032, @@ -190,7 +212,8 @@ def get_hizels_halpha(drn): return ( hizels_lg_halpha_Lbin_edges_data, - hizels_lg_halpha_LF_data, - hizels_halpha_LF_z_data, - hizels_halpha_LF_delta_z_data, + hizels_halpha_N_data, + hizels_halpha_vol_Mpc3, + hizels_halpha_z_data, + hizels_halpha_delta_z_data, ) diff --git a/diffhtwo/experimental/kernels/N_spec.py b/diffhtwo/experimental/kernels/N_spec.py new file mode 100644 index 00000000..047c0d69 --- /dev/null +++ b/diffhtwo/experimental/kernels/N_spec.py @@ -0,0 +1,40 @@ +import jax.numpy as jnp +from diffsky import diffndhist_lomem +from jax import jit as jjit + +from .lc_spec_kern import mc_specphot_kern_merging_wrapper + + +@jjit +def N_linelum( + ran_key, + line_wave_table, + lg_Lbin_edges, + lc_data, + param_collection, +): + spec_kern_results = mc_specphot_kern_merging_wrapper( + ran_key, + param_collection, + lc_data, + line_wave_table, + ) + lg_linelum_gal = jnp.log10(spec_kern_results.linelum_gal) + gal_weight = lc_data.cen_weight * lc_data.sat_weight + + sig = jnp.diff(lg_Lbin_edges) / 2 + sig = sig.reshape(sig.size, 1) + lg_Lbin_edges = lg_Lbin_edges.reshape(lg_Lbin_edges.size, 1) + + Lbin_lo = lg_Lbin_edges[:-1] + Lbin_hi = lg_Lbin_edges[1:] + + N_linelum = diffndhist_lomem.tw_ndhist_weighted( + lg_linelum_gal, + sig, + gal_weight, + Lbin_lo, + Lbin_hi, + ) + + return N_linelum diff --git a/diffhtwo/experimental/kernels/lc_spec_kern.py b/diffhtwo/experimental/kernels/lc_spec_kern.py new file mode 100644 index 00000000..03ad17b1 --- /dev/null +++ b/diffhtwo/experimental/kernels/lc_spec_kern.py @@ -0,0 +1,41 @@ +from diffsky.experimental.kernels import gd_specphot_kernels_merging as gspkm +from diffstar.defaults import FB +from dsps.cosmology import DEFAULT_COSMOLOGY +from jax import jit as jjit + + +@jjit +def mc_specphot_kern_merging_wrapper( + ran_key, + param_collection, + lc_data, + line_wave_table, + cosmo_params=DEFAULT_COSMOLOGY, + fb=FB, + mc_merge=0, +): + _res = gspkm._mc_specphot_kern_merging( + ran_key, + lc_data.z_obs, + lc_data.t_obs, + lc_data.mah_params, + lc_data.ssp_data, + lc_data.precomputed_ssp_mag_table, + lc_data.z_phot_table, + lc_data.wave_eff_table, + line_wave_table, + *param_collection, + cosmo_params, + fb, + lc_data.logmp_infall, + lc_data.logmhost_infall, + lc_data.t_infall, + lc_data.is_central, + lc_data.sat_weight, + lc_data.halo_indx, + mc_merge, + ) + + (phot_kern_results, phot_randoms, spec_kern_results) = _res + + return spec_kern_results diff --git a/diffhtwo/experimental/kernels/spec_kern.py b/diffhtwo/experimental/kernels/spec_kern.py index 30a1c4c7..4cb6e0a9 100644 --- a/diffhtwo/experimental/kernels/spec_kern.py +++ b/diffhtwo/experimental/kernels/spec_kern.py @@ -1,8 +1,8 @@ import jax.numpy as jnp from diffsky.burstpop import freqburst_mono from diffsky.experimental import mc_diffstarpop_wrappers as mcdw +from diffsky.experimental.kernels import gd_specphot_kernels_merging as gspkm from diffsky.experimental.kernels import mc_randoms -from diffsky.experimental.kernels import specphot_kernels_merging as spkm from diffstar.defaults import FB from dsps.cosmology import DEFAULT_COSMOLOGY from jax import jit as jjit @@ -22,7 +22,7 @@ def n_spec_kern( fb=FB, mc_merge=0, ): - _res = spkm._mc_specphot_kern_merging( + _res = gspkm._mc_specphot_kern_merging( ran_key, lc_data.z_obs, lc_data.t_obs, @@ -100,7 +100,7 @@ def n_spec_q_ms_burst( mc_is_burst = (mc_is_ms) & (mc_is_burst) mc_is_ms = (mc_is_ms) & (~mc_is_burst) - _res = spkm._mc_specphot_kern_merging( + _res = gspkm._mc_specphot_kern_merging( ran_key, lc_data.z_obs, lc_data.t_obs, diff --git a/diffhtwo/experimental/kernels/tests/test_N_spec.py b/diffhtwo/experimental/kernels/tests/test_N_spec.py new file mode 100644 index 00000000..ac5fad44 --- /dev/null +++ b/diffhtwo/experimental/kernels/tests/test_N_spec.py @@ -0,0 +1,22 @@ +import jax.numpy as jnp +import numpy as np +from diffsky.param_utils.diffsky_param_wrapper_merging import DEFAULT_PARAM_COLLECTION +from jax import random as jran + +from ..N_spec import N_linelum + + +def test_N_linelum(hizels): + ran_key = jran.key(0) + + line_wave_table = jnp.array([hizels.line_wave_aa[0]]) + N = N_linelum( + ran_key, + line_wave_table, + hizels.lg_Lbin_edges[0][0], + hizels.lc_data[0][1], + DEFAULT_PARAM_COLLECTION, + ) + + assert np.isfinite(N).all() + assert (N >= 0.0).all() diff --git a/diffhtwo/experimental/loss_kernels/emline_loss.py b/diffhtwo/experimental/loss_kernels/emline_loss.py index a97979ee..0a002b4e 100644 --- a/diffhtwo/experimental/loss_kernels/emline_loss.py +++ b/diffhtwo/experimental/loss_kernels/emline_loss.py @@ -3,34 +3,33 @@ from dsps.metallicity.umzr import DEFAULT_MZR_U_PARAMS from jax import jit as jjit -from ..kernels.spec_kern import n_spec_kern +from ..kernels.N_spec import N_linelum from ..param_utils import get_param_collection_from_u_theta -from .loss_functions import mse_w +from .loss_functions import poisson_loss @jjit def get_emline_loss( ran_key, - lg_emline_LF_target, - lg_emline_Lbin_edges, - param_collection, - lc_data, line_wave_aa, + lg_Lbin_edges, + N_data, + vol_Mpc3_data, + lc_data, + param_collection, ): line_wave_table = jnp.array([line_wave_aa]) - lg_emline_LF_model = n_spec_kern( + + N_model = N_linelum( ran_key, - param_collection, - lc_data, line_wave_table, - lg_emline_Lbin_edges, + lg_Lbin_edges, + lc_data, + param_collection, ) + N_model = N_model * (vol_Mpc3_data / lc_data.lc_tot_vol_mpc3) - emline_loss = mse_w( - lg_emline_LF_model, - lg_emline_LF_target[0], - lg_emline_LF_target[1], - ) + emline_loss = poisson_loss(N_model, N_data) return emline_loss @@ -38,21 +37,23 @@ def get_emline_loss( def _loss_emline_kern( u_theta, ran_key, - lg_emline_LF_target, - lg_emline_Lbin_edges, - lc_data, line_wave_aa, + lg_Lbin_edges, + N_data, + vol_Mpc3_data, + lc_data, u_mzr_params=DEFAULT_MZR_U_PARAMS, u_scatter_params=DEFAULT_SCATTER_U_PARAMS, ): param_collection = get_param_collection_from_u_theta(u_theta) emline_loss_args = ( ran_key, - lg_emline_LF_target, - lg_emline_Lbin_edges, - param_collection, - lc_data, line_wave_aa, + lg_Lbin_edges, + N_data, + vol_Mpc3_data, + lc_data, + param_collection, ) emline_loss = get_emline_loss(*emline_loss_args) return emline_loss @@ -62,24 +63,22 @@ def _loss_emline_kern( def _loss_emline_kern_multi_line_multi_z( u_theta, ran_key, - lg_emline_LF_target, - lg_emline_Lbin_edges, - emline_lc_data, - emline_wave_table, + fitting_data_multi_line_multi_z, ): emline_loss_multi_line_multi_z = 0.0 - n_line = len(emline_wave_table) + n_line = len(fitting_data_multi_line_multi_z.lg_Lbin_edges) for line in range(0, n_line): - n_z = len(lg_emline_LF_target[line]) + n_z = len(fitting_data_multi_line_multi_z.lg_Lbin_edges[line]) for z in range(0, n_z): emline_loss_args_z = ( u_theta, ran_key, - lg_emline_LF_target[line][z], - lg_emline_Lbin_edges[line][z], - emline_lc_data[line][z], - emline_wave_table[line], + fitting_data_multi_line_multi_z.line_wave_aa[line], + fitting_data_multi_line_multi_z.lg_Lbin_edges[line][z], + fitting_data_multi_line_multi_z.N_data[line][z], + fitting_data_multi_line_multi_z.vol_Mpc3_data[line][z], + fitting_data_multi_line_multi_z.lc_data[line][z], ) emline_loss_multi_line_multi_z += _loss_emline_kern(*emline_loss_args_z) diff --git a/diffhtwo/experimental/loss_kernels/emline_loss_mse.py b/diffhtwo/experimental/loss_kernels/emline_loss_mse.py new file mode 100644 index 00000000..a97979ee --- /dev/null +++ b/diffhtwo/experimental/loss_kernels/emline_loss_mse.py @@ -0,0 +1,86 @@ +import jax.numpy as jnp +from diffsky.experimental.scatter import DEFAULT_SCATTER_U_PARAMS +from dsps.metallicity.umzr import DEFAULT_MZR_U_PARAMS +from jax import jit as jjit + +from ..kernels.spec_kern import n_spec_kern +from ..param_utils import get_param_collection_from_u_theta +from .loss_functions import mse_w + + +@jjit +def get_emline_loss( + ran_key, + lg_emline_LF_target, + lg_emline_Lbin_edges, + param_collection, + lc_data, + line_wave_aa, +): + line_wave_table = jnp.array([line_wave_aa]) + lg_emline_LF_model = n_spec_kern( + ran_key, + param_collection, + lc_data, + line_wave_table, + lg_emline_Lbin_edges, + ) + + emline_loss = mse_w( + lg_emline_LF_model, + lg_emline_LF_target[0], + lg_emline_LF_target[1], + ) + + return emline_loss + + +def _loss_emline_kern( + u_theta, + ran_key, + lg_emline_LF_target, + lg_emline_Lbin_edges, + lc_data, + line_wave_aa, + u_mzr_params=DEFAULT_MZR_U_PARAMS, + u_scatter_params=DEFAULT_SCATTER_U_PARAMS, +): + param_collection = get_param_collection_from_u_theta(u_theta) + emline_loss_args = ( + ran_key, + lg_emline_LF_target, + lg_emline_Lbin_edges, + param_collection, + lc_data, + line_wave_aa, + ) + emline_loss = get_emline_loss(*emline_loss_args) + return emline_loss + + +@jjit +def _loss_emline_kern_multi_line_multi_z( + u_theta, + ran_key, + lg_emline_LF_target, + lg_emline_Lbin_edges, + emline_lc_data, + emline_wave_table, +): + emline_loss_multi_line_multi_z = 0.0 + + n_line = len(emline_wave_table) + for line in range(0, n_line): + n_z = len(lg_emline_LF_target[line]) + for z in range(0, n_z): + emline_loss_args_z = ( + u_theta, + ran_key, + lg_emline_LF_target[line][z], + lg_emline_Lbin_edges[line][z], + emline_lc_data[line][z], + emline_wave_table[line], + ) + emline_loss_multi_line_multi_z += _loss_emline_kern(*emline_loss_args_z) + + return emline_loss_multi_line_multi_z diff --git a/diffhtwo/experimental/loss_kernels/tests/test_emline_loss.py b/diffhtwo/experimental/loss_kernels/tests/test_emline_loss.py index c4f7cb98..e8577175 100644 --- a/diffhtwo/experimental/loss_kernels/tests/test_emline_loss.py +++ b/diffhtwo/experimental/loss_kernels/tests/test_emline_loss.py @@ -6,23 +6,26 @@ from ..emline_loss import _loss_emline_kern, get_emline_loss -def test_emline_loss(fake_subset_ssp_data, hizels): +def test_emline_loss(fake_subset_ssp_data, hizels_fitting_data): ssp_data, emline_wave_aa = fake_subset_ssp_data # pick first line, first zbin - lg_emline_LF_target = hizels.lg_LF[0][0] - lg_emline_Lbin_edges = hizels.lg_Lbin_edges[0][0] - lc_data = hizels.lc_data[0][0] + line_wave_aa = hizels_fitting_data.line_wave_aa[0] + lg_Lbin_edges = hizels_fitting_data.lg_Lbin_edges[0][0] + N_data = hizels_fitting_data.N_data[0][0] + vol_Mpc3_data = hizels_fitting_data.vol_Mpc3_data[0][0] + lc_data = hizels_fitting_data.lc_data[0][0] ran_key = jran.key(0) emline_loss = get_emline_loss( ran_key, - lg_emline_LF_target, - lg_emline_Lbin_edges, - DEFAULT_PARAM_COLLECTION, + line_wave_aa, + lg_Lbin_edges, + N_data, + vol_Mpc3_data, lc_data, - emline_wave_aa, + DEFAULT_PARAM_COLLECTION, ) assert np.isfinite(emline_loss) @@ -31,9 +34,10 @@ def test_emline_loss(fake_subset_ssp_data, hizels): emline_loss_kern = _loss_emline_kern( u_theta, ran_key, - lg_emline_LF_target, - lg_emline_Lbin_edges, + line_wave_aa, + lg_Lbin_edges, + N_data, + vol_Mpc3_data, lc_data, - emline_wave_aa, ) assert np.isfinite(emline_loss_kern) diff --git a/diffhtwo/experimental/loss_kernels/tests/test_emline_loss_mse.py b/diffhtwo/experimental/loss_kernels/tests/test_emline_loss_mse.py new file mode 100644 index 00000000..e631d30a --- /dev/null +++ b/diffhtwo/experimental/loss_kernels/tests/test_emline_loss_mse.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest +from diffsky.param_utils.diffsky_param_wrapper_merging import DEFAULT_PARAM_COLLECTION +from jax import random as jran + +from ... import param_utils as pu +from ..emline_loss import _loss_emline_kern, get_emline_loss + + +@pytest.mark.skip(reason="Currently mse based emline loss code is outdated") +def test_emline_loss(fake_subset_ssp_data, hizels): + ssp_data, emline_wave_aa = fake_subset_ssp_data + + # pick first line, first zbin + lg_emline_LF_target = hizels.lg_LF[0][0] + lg_emline_Lbin_edges = hizels.lg_Lbin_edges[0][0] + lc_data = hizels.lc_data[0][0] + + ran_key = jran.key(0) + + emline_loss = get_emline_loss( + ran_key, + lg_emline_LF_target, + lg_emline_Lbin_edges, + DEFAULT_PARAM_COLLECTION, + lc_data, + emline_wave_aa, + ) + + assert np.isfinite(emline_loss) + + u_theta = pu.get_u_theta_from_param_collection(DEFAULT_PARAM_COLLECTION) + emline_loss_kern = _loss_emline_kern( + u_theta, + ran_key, + lg_emline_LF_target, + lg_emline_Lbin_edges, + lc_data, + emline_wave_aa, + ) + assert np.isfinite(emline_loss_kern) diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index 2179de48..2bfec05f 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -13,7 +13,7 @@ from jax import lax, value_and_grad, vmap from jax.example_libraries import optimizers as jax_opt -# from ..loss_kernels.emline_loss import _loss_emline_kern_multi_line_multi_z +from ..loss_kernels.emline_loss import _loss_emline_kern_multi_line_multi_z from ..loss_kernels.phot_loss import _loss_phot_kern _L_pk = ( @@ -58,10 +58,10 @@ def _opt_update(opt_state, i): ) # clip gradients - # global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in grads)) - # tau = 1.0 - # scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) - # grads = tuple(g * scale for g in grads) + global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in grads)) + tau = 1.0 + scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) + grads = tuple(g * scale for g in grads) opt_state = opt_update(i, grads, opt_state) return opt_state, loss @@ -80,11 +80,10 @@ def _loss_sdss_feniks_hizels( sdss_fitting_data, feniks_meta_data, feniks_fitting_data, - # hizels, - # line_wave_table, - fit_sdss=True, - fit_feniks=True, - # fit_hizels=False, + hizels_fitting_data, + fit_sdss=False, + fit_feniks=False, + fit_hizels=True, ): loss = 0.0 @@ -109,19 +108,16 @@ def _loss_sdss_feniks_hizels( loss += feniks_phot_loss # hizels - # if fit_hizels: - # hizels_emline_multi_line_multi_z_loss_args = ( - # u_theta, - # ran_key, - # hizels.lg_LF, - # hizels.lg_Lbin_edges, - # hizels.lc_data, - # line_wave_table, - # ) - # hizels_emline_loss = _loss_emline_kern_multi_line_multi_z( - # *hizels_emline_multi_line_multi_z_loss_args - # ) - # loss += hizels_emline_loss + if fit_hizels: + hizels_emline_multi_line_multi_z_loss_args = ( + u_theta, + ran_key, + hizels_fitting_data, + ) + hizels_emline_loss = _loss_emline_kern_multi_line_multi_z( + *hizels_emline_multi_line_multi_z_loss_args + ) + loss += hizels_emline_loss return loss @@ -138,8 +134,7 @@ def fit_sdss_feniks_hizels( sdss_fitting_data, feniks_meta_data, feniks_fitting_data, - # hizels, - # line_wave_table, + hizels_fitting_data, n_steps=2, step_size=1e-2, ): @@ -152,8 +147,7 @@ def fit_sdss_feniks_hizels( sdss_fitting_data, feniks_meta_data, feniks_fitting_data, - # hizels, - # line_wave_table, + hizels_fitting_data, ) def _opt_update(opt_state, i): @@ -166,6 +160,13 @@ def _opt_update(opt_state, i): grads = tuple( jnp.where(train, grad, 0.0) for grad, train in zip(grads, trainable) ) + + # clip gradients + global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in grads)) + tau = 1.0 + scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) + grads = tuple(g * scale for g in grads) + opt_state = opt_update(i, grads, opt_state) return opt_state, loss diff --git a/diffhtwo/experimental/optimizers/tests/test_Np_specphot_opt.py b/diffhtwo/experimental/optimizers/tests/test_Np_specphot_opt.py index 86c753b8..c1491baf 100644 --- a/diffhtwo/experimental/optimizers/tests/test_Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/tests/test_Np_specphot_opt.py @@ -126,12 +126,10 @@ def test_phot_opt(ran_key, feniks_multi_z_data): assert check_param_collection_is_ok(param_collection_fit) -@pytest.mark.skip( - reason="This will be enabled when gd_specphot_kern_merging is implemented" -) -def test_specphot_opt(ran_key, fake_subset_ssp_data, feniks_multi_z_data, hizels): +def test_specphot_opt( + ran_key, fake_subset_ssp_data, feniks_multi_z_data, hizels_fitting_data +): ssp_data, emline_wave_aa = fake_subset_ssp_data - emline_wave_table = jnp.array([emline_wave_aa]) feniks_meta_data, feniks_fitting_data = feniks_multi_z_data @@ -146,8 +144,7 @@ def test_specphot_opt(ran_key, fake_subset_ssp_data, feniks_multi_z_data, hizels sdss_fitting_data, feniks_meta_data, feniks_fitting_data, - hizels, - emline_wave_table, + hizels_fitting_data, ) assert np.isfinite(loss) @@ -163,8 +160,7 @@ def test_specphot_opt(ran_key, fake_subset_ssp_data, feniks_multi_z_data, hizels sdss_fitting_data, feniks_meta_data, feniks_fitting_data, - hizels, - emline_wave_table, + hizels_fitting_data, n_steps=2, step_size=0.1, ) diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 8f26e482..fdea774a 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run122 -model_nickname: run122_all -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run122/diagnostic_plots/all +model_drn: /Users/kumail/diffdir/fits/run126 +model_nickname: run126_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run126/diagnostic_plots/diffstarpop+spspop+merging feniks_drn: /Users/kumail/diffdir/feniks sdss_drn: /Users/kumail/diffdir/sdss @@ -8,8 +8,8 @@ cosmos20_drn: /Users/kumail/diffdir/COSMOS20 dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/DSPS_data ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_feniks: True -plot_sdss: False +plot_feniks: False +plot_sdss: True plots: num_halos : 3000 From a8da6095f5cfa7e940b628219d1dfdc38dc93470 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Tue, 2 Jun 2026 18:09:22 -0500 Subject: [PATCH 2/5] plot_halpha --- .../experimental/data_loaders/load_hizels.py | 64 +++++- .../experimental/diagnostics/plot_halpha.py | 118 ++++++++++ diffhtwo/experimental/kernels/spec_kern.py | 64 ++++++ diffhtwo/experimental/utils.py | 68 ------ scripts/config_diffsky.yaml | 30 +++ scripts/fit_diffsky.py | 207 ++++++++++++++++++ 6 files changed, 474 insertions(+), 77 deletions(-) create mode 100644 diffhtwo/experimental/diagnostics/plot_halpha.py create mode 100644 scripts/config_diffsky.yaml create mode 100644 scripts/fit_diffsky.py diff --git a/diffhtwo/experimental/data_loaders/load_hizels.py b/diffhtwo/experimental/data_loaders/load_hizels.py index 39c89283..de3c7bfc 100644 --- a/diffhtwo/experimental/data_loaders/load_hizels.py +++ b/diffhtwo/experimental/data_loaders/load_hizels.py @@ -9,7 +9,16 @@ Hizels = namedtuple( "Hizels", - ["line_wave_aa", "lg_Lbin_edges", "N_data", "vol_Mpc3_data", "z", "dz", "lc_data"], + [ + "line_wave_aa", + "lg_Lbin_edges", + "N_data", + "vol_Mpc3_data", + "lg_phi_data", + "z", + "dz", + "lc_data", + ], ) DELTA_L_HALPHA = -0.4 # uncorrect HiZELS h-alpha L for dust (A_halpha = 1 mag) @@ -30,6 +39,7 @@ def get_hizels_data( hizels_lg_halpha_Lbin_edges_data, hizels_halpha_N_data, hizels_halpha_vol_Mpc3, + hizels_halpha_lg_phi_data, hizels_halpha_z_data, hizels_halpha_delta_z_data, ) = get_hizels_halpha(drn) @@ -38,6 +48,7 @@ def get_hizels_data( lg_Lbin_edges = [hizels_lg_halpha_Lbin_edges_data] N_data = [hizels_halpha_N_data] vol_Mpc3_data = [hizels_halpha_vol_Mpc3] + lg_phi_data = [hizels_halpha_lg_phi_data] z = [hizels_halpha_z_data] dz = [hizels_halpha_delta_z_data] @@ -70,7 +81,9 @@ def get_hizels_data( line_lc_data.append(generate_lc_data(*lc_args)) lc_data.append(line_lc_data) - return Hizels(line_wave_aa, lg_Lbin_edges, N_data, vol_Mpc3_data, z, dz, lc_data) + return Hizels( + line_wave_aa, lg_Lbin_edges, N_data, vol_Mpc3_data, lg_phi_data, z, dz, lc_data + ) def _get_lgL_bin_edges( @@ -132,22 +145,24 @@ def _lg_phi_corr_to_N_corr(lg_phi_corr, vol_1e4Mpc3): def get_hizels_halpha(drn): + """z0p4""" HiZELS_halpha_z0p4 = ascii.read(drn / "halpha_LF_z0p4.dat") - lg_halpha_Lbin_edges_z0p4 = _get_lgL_bin_edges( HiZELS_halpha_z0p4, "logLHa", "logLHa_binw_full" ) - # lg_halpha_LF_data_z0p4 = jnp.vstack( - # ( - # jnp.array(lg_phi_h0p7_to_hdefault(HiZELS_halpha_z0p4["logphi_corr"])), - # jnp.array(HiZELS_halpha_z0p4["logphi_corr_err"]), - # ) - # ) halpha_N_data_z0p4 = _lg_phi_corr_to_N_corr( HiZELS_halpha_z0p4["logphi_corr"], HiZELS_halpha_z0p4["vol_1e4Mpc3"] ) halpha_vol_Mpc3_z0p4 = _vol_h0p7_to_hdefault(HiZELS_halpha_z0p4["vol_1e4Mpc3"]) + halpha_lg_phi_data_z0p4 = jnp.vstack( + ( + jnp.array(_lg_phi_h0p7_to_hdefault(HiZELS_halpha_z0p4["logphi_corr"])), + jnp.array(HiZELS_halpha_z0p4["logphi_corr_err"]), + ) + ) + + """z0p84""" HiZELS_halpha_z0p84 = ascii.read(drn / "halpha_LF_z0p84.dat") lg_halpha_Lbin_edges_z0p84 = _get_lgL_bin_edges( HiZELS_halpha_z0p84, "logLHa", "logLHa_binw_full" @@ -157,6 +172,14 @@ def get_hizels_halpha(drn): ) halpha_vol_Mpc3_z0p84 = _vol_h0p7_to_hdefault(HiZELS_halpha_z0p84["vol_1e4Mpc3"]) + halpha_lg_phi_data_z0p84 = jnp.vstack( + ( + jnp.array(_lg_phi_h0p7_to_hdefault(HiZELS_halpha_z0p84["logphi_corr"])), + jnp.array(HiZELS_halpha_z0p84["logphi_corr_err"]), + ) + ) + + """z1p47""" HiZELS_halpha_z1p47 = ascii.read(drn / "halpha_LF_z1p47.dat") lg_halpha_Lbin_edges_z1p47 = _get_lgL_bin_edges( HiZELS_halpha_z1p47, "logLHa", "logLHa_binw_full" @@ -166,6 +189,14 @@ def get_hizels_halpha(drn): ) halpha_vol_Mpc3_z1p47 = _vol_h0p7_to_hdefault(HiZELS_halpha_z1p47["vol_1e4Mpc3"]) + halpha_lg_phi_data_z1p47 = jnp.vstack( + ( + jnp.array(_lg_phi_h0p7_to_hdefault(HiZELS_halpha_z1p47["logphi_corr"])), + jnp.array(HiZELS_halpha_z1p47["logphi_corr_err"]), + ) + ) + + """z2p23""" HiZELS_halpha_z2p23 = ascii.read(drn / "halpha_LF_z2p23.dat") lg_halpha_Lbin_edges_z2p23 = _get_lgL_bin_edges( HiZELS_halpha_z2p23, "logLHa", "logLHa_binw_full" @@ -175,6 +206,13 @@ def get_hizels_halpha(drn): ) halpha_vol_Mpc3_z2p23 = _vol_h0p7_to_hdefault(HiZELS_halpha_z2p23["vol_1e4Mpc3"]) + halpha_lg_phi_data_z2p23 = jnp.vstack( + ( + jnp.array(_lg_phi_h0p7_to_hdefault(HiZELS_halpha_z2p23["logphi_corr"])), + jnp.array(HiZELS_halpha_z2p23["logphi_corr_err"]), + ) + ) + hizels_lg_halpha_Lbin_edges_data = [ lg_halpha_Lbin_edges_z0p4, lg_halpha_Lbin_edges_z0p84, @@ -196,6 +234,13 @@ def get_hizels_halpha(drn): halpha_vol_Mpc3_z2p23, ] + hizels_halpha_lg_phi_data = [ + halpha_lg_phi_data_z0p4, + halpha_lg_phi_data_z0p84, + halpha_lg_phi_data_z1p47, + halpha_lg_phi_data_z2p23, + ] + hizels_halpha_z_data = [ jnp.float64(0.40), jnp.float64(0.84), @@ -214,6 +259,7 @@ def get_hizels_halpha(drn): hizels_lg_halpha_Lbin_edges_data, hizels_halpha_N_data, hizels_halpha_vol_Mpc3, + hizels_halpha_lg_phi_data, hizels_halpha_z_data, hizels_halpha_delta_z_data, ) diff --git a/diffhtwo/experimental/diagnostics/plot_halpha.py b/diffhtwo/experimental/diagnostics/plot_halpha.py new file mode 100644 index 00000000..3cef2b11 --- /dev/null +++ b/diffhtwo/experimental/diagnostics/plot_halpha.py @@ -0,0 +1,118 @@ +import matplotlib.pyplot as plt + +from diffhtwo.experimental.kernels.spec_kern import get_halpha_LF_q_ms_burst + + +def plot_halpha_ms_q_burst( + ran_key, + hizels, + param_collection, + ssp_data, + tcurves, + halpha_wave_aa, + model_nickname, + savedir, + plt_show=True, +): + alpha = 1 + lw = 2 + + xlims = [] + for i in range(0, 4): + xlims.append( + (hizels.lg_Lbin_edges[0][i].min(), hizels.lg_Lbin_edges[0][i].max()) + ) + ylim = (-5.5, -1.4) + + fig, ax = plt.subplots(1, 4, figsize=(12, 3.2)) + # fig.subplots_adjust(hspace=0.2, left=0.065, right=0.98, bottom=0.17, top=0.88) + + for i in range(0, 4): + ( + lgL_bin_centers, + lg_halpha_LF, + lg_halpha_LF_q, + lg_halpha_LF_ms, + lg_halpha_LF_burst, + ) = get_halpha_LF_q_ms_burst( + ran_key, + param_collection, + hizels.lg_Lbin_edges[0][i], + hizels.z[0][i], + hizels.dz[0][i], + ssp_data, + tcurves, # dummy arg, + halpha_wave_aa, + ) + ax[i].errorbar( + lgL_bin_centers, + hizels.lg_phi_data[0][i][0], + hizels.lg_phi_data[0][i][1], + color="k", + fmt="s", + markersize=5, + alpha=0.5, + label="HiZELS", + ) + + ax[i].plot( + lgL_bin_centers, + lg_halpha_LF, + color="k", + alpha=alpha, + label="diffsky", + lw=lw, + ) + ax[i].plot( + lgL_bin_centers, + lg_halpha_LF_burst, + color="orange", + alpha=alpha, + label="mc_is_burst", + lw=lw, + ) + ax[i].plot( + lgL_bin_centers, + lg_halpha_LF_ms, + color="deepskyblue", + alpha=alpha, + label="mc_is_ms", + lw=lw, + ) + ax[i].plot( + lgL_bin_centers, + lg_halpha_LF_q, + color="darkred", + alpha=alpha, + label="mc_is_q", + lw=lw, + ) + + ax[i].set_xlim(xlims[i]) + ax[i].set_ylim(ylim) + ax[i].set_title(" z = " + str(hizels.z[0][i]), y=0.85) + + ax[i].tick_params( + axis="both", + which="both", # major + minor + direction="in", + top=True, + right=True, + length=3, + width=0.6, + labelsize=10, + ) + + fig.supxlabel("log$_{10}$ (L$_{H\u03b1}$ [erg/s])", fontsize=14) + fig.supylabel("log$_{10}($\u03d5 [Mpc$^{-3}$])", fontsize=14) + plt.rcParams["legend.fontsize"] = 8 + ax[-1].legend(loc="lower left", framealpha=0.5) + + fig.savefig( + savedir + "/" + model_nickname + "_halpha_LF" + ".png", + bbox_inches="tight", + dpi=200, + ) + if plt_show: + plt.show() + plt.close() diff --git a/diffhtwo/experimental/kernels/spec_kern.py b/diffhtwo/experimental/kernels/spec_kern.py index 4cb6e0a9..6236a38c 100644 --- a/diffhtwo/experimental/kernels/spec_kern.py +++ b/diffhtwo/experimental/kernels/spec_kern.py @@ -1,4 +1,5 @@ import jax.numpy as jnp +import numpy as np from diffsky.burstpop import freqburst_mono from diffsky.experimental import mc_diffstarpop_wrappers as mcdw from diffsky.experimental.kernels import gd_specphot_kernels_merging as gspkm @@ -8,6 +9,7 @@ from jax import jit as jjit from .. import emline_luminosity +from ..lightcone_generators import generate_lc_data from .gehrels_err import N_0, N_FLOOR @@ -175,3 +177,65 @@ def n_spec_q_ms_burst( lg_emline_LF_burst = jnp.log10(emline_N_burst / lc_data.lc_tot_vol_mpc3) return lg_emline_LF, lg_emline_LF_q, lg_emline_LF_ms, lg_emline_LF_burst + + +def get_halpha_LF_q_ms_burst( + ran_key, + param_collection, + lgL_bin_edges, + halpha_LF_z, + halpha_LF_delta_z, + ssp_data, + tcurves, + halpha_wave_aa, + lgmp_min=10.0, + lgmp_max=15.0, + num_halos=100, + sky_area_degsq=10000, + n_z_phot_table=15, + cosmo_params=DEFAULT_COSMOLOGY, + fb=FB, +): + halpha_lc_z_min = halpha_LF_z - (halpha_LF_delta_z / 2) + halpha_lc_z_max = halpha_LF_z + (halpha_LF_delta_z / 2) + z_phot_table = 10 ** np.linspace( + np.log10(halpha_lc_z_min), np.log10(halpha_lc_z_max), n_z_phot_table + ) + + lc_args = ( + ran_key, + num_halos, + halpha_lc_z_min, + halpha_lc_z_max, + lgmp_min, + lgmp_max, + sky_area_degsq, + ssp_data, + tcurves, + z_phot_table, + ) + lc_data = generate_lc_data(*lc_args) + + line_wave_table = jnp.array([halpha_wave_aa]) + ( + lg_halpha_LF, + lg_halpha_LF_q, + lg_halpha_LF_ms, + lg_halpha_LF_burst, + ) = n_spec_q_ms_burst( + ran_key, + param_collection, + lc_data, + line_wave_table, + lgL_bin_edges, + ) + + lgL_bin_centers = 0.5 * (lgL_bin_edges[1:] + lgL_bin_edges[:-1]) + + return ( + lgL_bin_centers, + lg_halpha_LF, + lg_halpha_LF_q, + lg_halpha_LF_ms, + lg_halpha_LF_burst, + ) diff --git a/diffhtwo/experimental/utils.py b/diffhtwo/experimental/utils.py index 55b00f07..890f67fa 100644 --- a/diffhtwo/experimental/utils.py +++ b/diffhtwo/experimental/utils.py @@ -2,15 +2,9 @@ import jax.numpy as jnp import numpy as np -from diffsky.mass_functions import mc_hosts -from diffstar.defaults import FB -from dsps.cosmology.defaults import DEFAULT_COSMOLOGY from jax import jit as jjit from jax.tree_util import tree_flatten_with_path -from .kernels.spec_kern import n_spec_q_ms_burst -from .lightcone_generators import generate_lc_data - @jjit def lupton_log10(t, log10_clip, t0=0.0, M0=0.0, alpha=1 / jnp.log(10.0)): @@ -39,68 +33,6 @@ def safe_log10(x, EPS=1e-12): return jnp.log(jnp.clip(x, EPS, jnp.inf)) / jnp.log(10.0) -def get_halpha_LF_q_ms_burst( - ran_key, - param_collection, - lgL_bin_edges, - halpha_LF_z, - halpha_LF_delta_z, - ssp_data, - tcurves, - halpha_wave_aa, - lgmp_min=10.0, - lgmp_max=mc_hosts.LGMH_MAX, - num_halos=10000, - sky_area_degsq=10000, - n_z_phot_table=15, - cosmo_params=DEFAULT_COSMOLOGY, - fb=FB, -): - halpha_lc_z_min = halpha_LF_z - (halpha_LF_delta_z / 2) - halpha_lc_z_max = halpha_LF_z + (halpha_LF_delta_z / 2) - z_phot_table = 10 ** np.linspace( - np.log10(halpha_lc_z_min), np.log10(halpha_lc_z_max), n_z_phot_table - ) - - lc_args = ( - ran_key, - num_halos, - halpha_lc_z_min, - halpha_lc_z_max, - lgmp_min, - lgmp_max, - sky_area_degsq, - ssp_data, - tcurves, - z_phot_table, - ) - lc_data = generate_lc_data(*lc_args) - - line_wave_table = jnp.array([halpha_wave_aa]) - ( - lg_halpha_LF, - lg_halpha_LF_q, - lg_halpha_LF_ms, - lg_halpha_LF_burst, - ) = n_spec_q_ms_burst( - ran_key, - param_collection, - lc_data, - line_wave_table, - lgL_bin_edges, - ) - - lgL_bin_centers = 0.5 * (lgL_bin_edges[1:] + lgL_bin_edges[:-1]) - - return ( - lgL_bin_centers, - lg_halpha_LF, - lg_halpha_LF_q, - lg_halpha_LF_ms, - lg_halpha_LF_burst, - ) - - def load_feniks_tcurve(tcurve_filename): tcurve = np.loadtxt(tcurve_filename) wave = tcurve[:, 0] diff --git a/scripts/config_diffsky.yaml b/scripts/config_diffsky.yaml new file mode 100644 index 00000000..d367eca2 --- /dev/null +++ b/scripts/config_diffsky.yaml @@ -0,0 +1,30 @@ +base_path: "/Users/kumail/diffdir" + +start_runid: "run90" +start_fit_type: "all" + +fit_runid: "runtest" +fit_type: "all" + +sdss: + N_centroids: 100 + num_halos: 100 + +feniks: + lh_d_mag: 0.4 + N_centroids: 100 + num_halos: 100 + +hizels: + num_halos: 100 + +epoch: + n_it: 1 + n_steps: 2 + step_size: 0.1 + +defaults: + diffstarpop: True + spspop: True + ssperr: True + merging: True diff --git a/scripts/fit_diffsky.py b/scripts/fit_diffsky.py new file mode 100644 index 00000000..3a82cfb7 --- /dev/null +++ b/scripts/fit_diffsky.py @@ -0,0 +1,207 @@ +import argparse +import os +import time +from datetime import datetime +from pathlib import Path + +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import yaml +from diffsky.data_loaders.hacc_utils import lc_mock +from diffsky.merging.merging_model import DEFAULT_MERGE_PARAMS +from diffsky.param_utils.spspop_param_utils import DEFAULT_SPSPOP_PARAMS +from diffsky.ssp_err_model.defaults import ZERO_SSPERR_PARAMS +from diffstar.diffstarpop.kernels.params.params_diffstarpopfits_mgash import ( + DiffstarPop_Params_Diffstarpopfits_mgash, +) +from dsps import load_ssp_templates +from dsps.data_loaders import load_emline_info as lemi +from jax import random as jran + +from diffhtwo.experimental import param_utils as pu +from diffhtwo.experimental.data_loaders import load_feniks, load_hizels, load_sdss +from diffhtwo.experimental.defaults import FENIKS_Z_MIN, SDSS_Z_MAX, SDSS_Z_MIN +from diffhtwo.experimental.latin_hypercube import lh_utils as lhu +from diffhtwo.experimental.optimizers import Np_specphot_opt + +DIFFSTARPOP_GALACTICUS_exsitu = DiffstarPop_Params_Diffstarpopfits_mgash[ + "galacticus_in_plus_ex_situ" +] + +if __name__ == "__main__": + p = argparse.ArgumentParser() + p.add_argument("--config", default="config_diffsky.yaml") + args = p.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + sdss_drn = cfg["base_path"] + "/sdss" + feniks_drn = cfg["base_path"] + "/feniks" + hizels_drn = Path(cfg["base_path"] + "/hizels") + ssp_filename = ( + cfg["base_path"] + + "/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5" + ) + + # get ssp data + ssp_data = load_ssp_templates(fn=ssp_filename) + ssp_data = lemi.get_subset_emline_data(ssp_data, ["Ba_alpha_6563"]) + halpha_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) + + # load sdss data + ran_key = jran.key(0) + SDSS = load_sdss.get_sdss_data(sdss_drn, ran_key, ssp_data) + + # load feniks data + ran_key = jran.key(0) + FENIKS = load_feniks.get_feniks_data( + feniks_drn, ran_key, ssp_data, lh_d_mag=cfg["feniks"]["lh_d_mag"] + ) + + # load hizels data + hizels_fitting_data = load_hizels.get_hizels_data( + hizels_drn, + ran_key, + ssp_data, + SDSS.filter_info.tcurves, + halpha_wave_aa, + num_halos=cfg["hizels"]["num_halos"], + ) + + # start fit dirs + fit_start_drn = cfg["base_path"] + "/fits/" + cfg["start_runid"] + "/" + param_collection_fit = lc_mock.load_diffsky_param_collection_merging( + fit_start_drn, + cfg["start_runid"] + "_" + cfg["start_fit_type"], + ) + if cfg["defaults"]["diffstarpop"]: + param_collection_fit = param_collection_fit._replace( + diffstarpop_params=DIFFSTARPOP_GALACTICUS_exsitu + ) + if cfg["defaults"]["spspop"]: + param_collection_fit = param_collection_fit._replace( + spspop_params=DEFAULT_SPSPOP_PARAMS + ) + if cfg["defaults"]["ssperr"]: + param_collection_fit = param_collection_fit._replace( + ssperr_params=ZERO_SSPERR_PARAMS + ) + if cfg["defaults"]["merging"]: + param_collection_fit = param_collection_fit._replace( + merging_params=DEFAULT_MERGE_PARAMS + ) + + u_theta_fit = pu.get_u_theta_from_param_collection(param_collection_fit) + + # fit dirs + trainable_params = pu.get_trainable_params(fit_type=cfg["fit_type"]) + fit_save_drn = cfg["base_path"] + "/fits/" + cfg["fit_runid"] + "/" + fit_diagnostics_save_drn = ( + cfg["base_path"] + + "/fits/" + + cfg["fit_runid"] + + "/diagnostic_plots/" + + cfg["fit_type"] + ) + os.makedirs(fit_diagnostics_save_drn + "/loss", exist_ok=True) + os.makedirs(fit_diagnostics_save_drn + "/lh_N_z", exist_ok=True) + + os.system(f"cp {args.config} {fit_diagnostics_save_drn}") + + # SDSS + sdss_z_min = [SDSS_Z_MIN, 0.08, 0.14] + sdss_z_max = [0.08, 0.14, SDSS_Z_MAX] + + # FENIKS + feniks_z_min = [FENIKS_Z_MIN, 1] + feniks_z_max = [1, 2] + + initial_pts = [] + start = time.time() + for epoch in range(0, cfg["epoch"]["n_it"]): + print(f'Running Epoch {epoch+1}/{cfg["epoch"]["n_it"]}...') + + # SDSS + sdss = load_sdss.refresh_lh_centroids(SDSS) + sdss_meta_data, sdss_fitting_data = lhu.get_zbins_lh_lc( + ran_key, + SDSS, + sdss_z_min, + sdss_z_max, + ssp_data, + cfg["sdss"]["N_centroids"], + lh_N_z_savedir=fit_diagnostics_save_drn + "/lh_N_z", + num_halos=cfg["sdss"]["num_halos"], + ) + + # FENIKS + FENIKS = load_feniks.refresh_lh_centroids(FENIKS, cfg["feniks"]["lh_d_mag"]) + feniks_meta_data, feniks_fitting_data = lhu.get_zbins_lh_lc( + ran_key, + FENIKS, + feniks_z_min, + feniks_z_max, + ssp_data, + cfg["feniks"]["N_centroids"], + lh_N_z_savedir=fit_diagnostics_save_drn + "/lh_N_z", + num_halos=cfg["feniks"]["num_halos"], + ) + + loss_hist, u_theta_fit = Np_specphot_opt.fit_sdss_feniks_hizels( + u_theta_fit, + trainable_params, + ran_key, + sdss_meta_data, + sdss_fitting_data, + feniks_meta_data, + feniks_fitting_data, + hizels_fitting_data, + n_steps=cfg["epoch"]["n_steps"], + step_size=cfg["epoch"]["step_size"], + ) + + jax.clear_caches() + + param_collection_fit = pu.get_param_collection_from_u_theta(u_theta_fit) + lc_mock.write_diffsky_param_collection_merging( + fit_save_drn, + cfg["fit_runid"] + "_" + cfg["fit_type"], + param_collection_fit, + ) + + if epoch == 0: + STEPS = np.arange(1, cfg["epoch"]["n_steps"] + 1, 1) + + LOSS_HIST = loss_hist + + initial_pts.append((STEPS[0], LOSS_HIST[0])) + else: + steps = np.arange(STEPS[-1] + 1, STEPS[-1] + cfg["epoch"]["n_steps"] + 1, 1) + initial_pts.append((steps[0], loss_hist[0])) + + STEPS = np.concatenate((STEPS, steps)) + LOSS_HIST = np.concatenate((LOSS_HIST, loss_hist)) + + end = time.time() + elapsed = end - start + print( + f'Gradient descent took {elapsed/60:.3f} minutes for {cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]} steps.' + ) + print(f'speed: {elapsed/(cfg["epoch"]["n_steps"]*cfg["epoch"]["n_it"]):.3f} s/it') + + # gradient descent figure + fig_loss, ax_loss = plt.subplots(1) + + start_step = [s[0] for s in initial_pts] + start_loss = [s[1] for s in initial_pts] + ax_loss.scatter(start_step, start_loss, s=50, c="deepskyblue") + + ax_loss.plot(STEPS, LOSS_HIST, c="deepskyblue") + ax_loss.set_ylabel("Poisson Loss") + ax_loss.set_xlabel("steps") + ts = datetime.now().strftime("%Y-%m-%d_%H:%M:%S") + plt.savefig(fit_diagnostics_save_drn + "/loss/loss_" + ts + ".png") + plt.close() From 29290740301b4a982baf75f5748184bfa4527b57 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Tue, 2 Jun 2026 19:52:46 -0500 Subject: [PATCH 3/5] fit_feniks=True --- diffhtwo/experimental/diagnostics/plot_halpha.py | 4 +++- diffhtwo/experimental/optimizers/Np_specphot_opt.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/diffhtwo/experimental/diagnostics/plot_halpha.py b/diffhtwo/experimental/diagnostics/plot_halpha.py index 3cef2b11..8fe69299 100644 --- a/diffhtwo/experimental/diagnostics/plot_halpha.py +++ b/diffhtwo/experimental/diagnostics/plot_halpha.py @@ -12,6 +12,7 @@ def plot_halpha_ms_q_burst( halpha_wave_aa, model_nickname, savedir, + num_halos=100, plt_show=True, ): alpha = 1 @@ -25,7 +26,7 @@ def plot_halpha_ms_q_burst( ylim = (-5.5, -1.4) fig, ax = plt.subplots(1, 4, figsize=(12, 3.2)) - # fig.subplots_adjust(hspace=0.2, left=0.065, right=0.98, bottom=0.17, top=0.88) + fig.subplots_adjust(hspace=0.2, left=0.065, right=0.98, bottom=0.17, top=0.88) for i in range(0, 4): ( @@ -43,6 +44,7 @@ def plot_halpha_ms_q_burst( ssp_data, tcurves, # dummy arg, halpha_wave_aa, + num_halos=num_halos, ) ax[i].errorbar( lgL_bin_centers, diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index 2bfec05f..9d5995a1 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -82,7 +82,7 @@ def _loss_sdss_feniks_hizels( feniks_fitting_data, hizels_fitting_data, fit_sdss=False, - fit_feniks=False, + fit_feniks=True, fit_hizels=True, ): loss = 0.0 From 26a270cadf9af773b2a57671e79165c04d5cacab Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Tue, 2 Jun 2026 22:38:13 -0500 Subject: [PATCH 4/5] linelum_gal-->linelum_weighted --- diffhtwo/experimental/kernels/N_spec.py | 4 +-- diffhtwo/experimental/kernels/spec_kern.py | 4 +-- scripts/config_diagnostics.yaml | 16 +++++---- scripts/generate_diagnostic_plots.py | 39 +++++++++++++++++++--- 4 files changed, 49 insertions(+), 14 deletions(-) diff --git a/diffhtwo/experimental/kernels/N_spec.py b/diffhtwo/experimental/kernels/N_spec.py index 047c0d69..414e4b21 100644 --- a/diffhtwo/experimental/kernels/N_spec.py +++ b/diffhtwo/experimental/kernels/N_spec.py @@ -19,7 +19,7 @@ def N_linelum( lc_data, line_wave_table, ) - lg_linelum_gal = jnp.log10(spec_kern_results.linelum_gal) + lg_linelum_weighted = jnp.log10(spec_kern_results.linelum_weighted) gal_weight = lc_data.cen_weight * lc_data.sat_weight sig = jnp.diff(lg_Lbin_edges) / 2 @@ -30,7 +30,7 @@ def N_linelum( Lbin_hi = lg_Lbin_edges[1:] N_linelum = diffndhist_lomem.tw_ndhist_weighted( - lg_linelum_gal, + lg_linelum_weighted, sig, gal_weight, Lbin_lo, diff --git a/diffhtwo/experimental/kernels/spec_kern.py b/diffhtwo/experimental/kernels/spec_kern.py index 6236a38c..7b261de4 100644 --- a/diffhtwo/experimental/kernels/spec_kern.py +++ b/diffhtwo/experimental/kernels/spec_kern.py @@ -47,13 +47,13 @@ def n_spec_kern( ) (phot_kern_results, phot_randoms, spec_kern_results) = _res - linelum_gal = spec_kern_results.linelum_gal + linelum_weighted = spec_kern_results.linelum_weighted gal_weight = lc_data.cen_weight * lc_data.sat_weight sig = jnp.diff(lg_emline_Lbin_edges) / 2 sig = sig.reshape(sig.size, 1) _, emline_N = emline_luminosity.get_emline_luminosity_func( - linelum_gal, + linelum_weighted, gal_weight, sig=sig, lgL_bin_edges=lg_emline_Lbin_edges, diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index fdea774a..a335cf12 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,15 +1,19 @@ -model_drn: /Users/kumail/diffdir/fits/run126 -model_nickname: run126_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run126/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run127 +model_nickname: run127_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run127/diagnostic_plots/diffstarpop+spspop+merging -feniks_drn: /Users/kumail/diffdir/feniks sdss_drn: /Users/kumail/diffdir/sdss +feniks_drn: /Users/kumail/diffdir/feniks +hizels_drn: /Users/kumail/diffdir/hizels + cosmos20_drn: /Users/kumail/diffdir/COSMOS20 dsps_drn: /Users/kumail/diffdir/COSMOS20/portal.nersc.gov/project/hacc/aphearin/DSPS_data ssp_file: /Users/kumail/diffdir/ssp_data/ssp_w_emlines/fsps_v0.4.7_mist_c3k_a_kroupa_wNE_logGasU-2.0_logGasZ0.0.h5 -plot_feniks: False -plot_sdss: True + +plot_sdss: False +plot_feniks: True +plot_hizels: True plots: num_halos : 3000 diff --git a/scripts/generate_diagnostic_plots.py b/scripts/generate_diagnostic_plots.py index 49f43669..ad8530ff 100644 --- a/scripts/generate_diagnostic_plots.py +++ b/scripts/generate_diagnostic_plots.py @@ -1,5 +1,6 @@ import argparse import os +from pathlib import Path import jax.numpy as jnp import numpy as np @@ -13,7 +14,7 @@ from dsps.data_loaders import load_emline_info as lemi from jax import random as jran -from diffhtwo.experimental.data_loaders import load_feniks, load_sdss +from diffhtwo.experimental.data_loaders import load_feniks, load_hizels, load_sdss from diffhtwo.experimental.defaults import ( FENIKS_Z_MAX, FENIKS_Z_MIN, @@ -28,6 +29,7 @@ plot_lgfburst_mh_z, ) from diffhtwo.experimental.diagnostics.plot_cen import plot_massive_cen_colors +from diffhtwo.experimental.diagnostics.plot_halpha import plot_halpha_ms_q_burst from diffhtwo.experimental.diagnostics.plot_insitu_sm import plot_insitu_sm from diffhtwo.experimental.diagnostics.plot_phot import ( plot_app_mag_funcs, @@ -53,8 +55,11 @@ # get directories/files os.environ["COSMOS20_DRN"] = cfg["cosmos20_drn"] os.environ["DSPS_DRN"] = cfg["dsps_drn"] - feniks_drn = cfg["feniks_drn"] + sdss_drn = cfg["sdss_drn"] + feniks_drn = cfg["feniks_drn"] + hizels_drn = Path(cfg["hizels_drn"]) + ssp_filename = cfg["ssp_file"] fit_diagnostics_save_drn = cfg["fit_diagnostics_save_drn"] param_collection_fit = lc_mock.load_diffsky_param_collection_merging( @@ -67,8 +72,7 @@ # get ssp data ssp_data = load_ssp_templates(fn=ssp_filename) ssp_data = lemi.get_subset_emline_data(ssp_data, ["Ba_alpha_6563"]) - emline_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) - emline_wave_table = jnp.array([emline_wave_aa]) + halpha_wave_aa = jnp.array(ssp_data.ssp_emline_wave[0]) ran_key = jran.key(0) if cfg["plots"]["plot_satquench_model"]: @@ -103,6 +107,33 @@ label1="fit", ) + """ + Plot HiZELS + """ + if cfg["plot_hizels"]: + feniks = load_feniks.get_feniks_data(feniks_drn, ran_key, ssp_data) + hizels_drn = Path(hizels_drn) + hizels_label = "hizels" + hizels = load_hizels.get_hizels_data( + hizels_drn, + ran_key, + ssp_data, + feniks.filter_info.tcurves, + halpha_wave_aa, + ) + plot_halpha_ms_q_burst( + ran_key, + hizels, + param_collection_fit, + ssp_data, + feniks.filter_info.tcurves, + halpha_wave_aa, + hizels_label, + fit_diagnostics_save_drn, + num_halos=num_halos, + plt_show=False, + ) + """ Plot FENIKS """ From 9c5ced48db4721506c2825a4671a799c39cb9734 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Tue, 2 Jun 2026 22:42:39 -0500 Subject: [PATCH 5/5] fit_feniks=False --- diffhtwo/experimental/optimizers/Np_specphot_opt.py | 2 +- scripts/config_diagnostics.yaml | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index 9d5995a1..2bfec05f 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -82,7 +82,7 @@ def _loss_sdss_feniks_hizels( feniks_fitting_data, hizels_fitting_data, fit_sdss=False, - fit_feniks=True, + fit_feniks=False, fit_hizels=True, ): loss = 0.0 diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index a335cf12..2b0b76b6 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run127 -model_nickname: run127_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run127/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run128 +model_nickname: run128_diffstarpop+spspop+merging +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run128/diagnostic_plots/diffstarpop+spspop+merging sdss_drn: /Users/kumail/diffdir/sdss feniks_drn: /Users/kumail/diffdir/feniks