diff --git a/diffstar/data_loaders/load_smah_data.py b/diffstar/data_loaders/load_smah_data.py index 59412f8..48aa091 100644 --- a/diffstar/data_loaders/load_smah_data.py +++ b/diffstar/data_loaders/load_smah_data.py @@ -327,6 +327,7 @@ def load_tng_data(data_drn=BEBOP): log_smahs = np.log10(mstarh) log_mahs = halos["mpeakh"] + log_mahs = log_mahs - np.log10(H_TNG) log_mahs = np.maximum.accumulate(log_mahs, axis=1) logmp0 = log_mahs[:, -1] diff --git a/diffstar/diffstarpop/kernels/params/params_diffstarfits_mgash_tng.py b/diffstar/diffstarpop/kernels/params/params_diffstarfits_mgash_tng.py index 21553f7..bcaee1c 100644 --- a/diffstar/diffstarpop/kernels/params/params_diffstarfits_mgash_tng.py +++ b/diffstar/diffstarpop/kernels/params/params_diffstarfits_mgash_tng.py @@ -9,74 +9,74 @@ SFH_PDF_QUENCH_MU_PDICT = OrderedDict( [ - ("mean_ulgm_mseq_xtp", 11.300), - ("mean_ulgm_mseq_ytp", 11.300), - ("mean_ulgm_mseq_lo", 1.296), - ("mean_ulgm_mseq_hi", 0.336), - ("mean_ulgy_mseq_xtp", 13.509), - ("mean_ulgy_mseq_ytp", -9.297), - ("mean_ulgy_mseq_lo", 0.543), - ("mean_ulgy_mseq_hi", 0.247), - ("mean_ul_mseq_int", -0.358), - ("mean_ul_mseq_slp", 1.187), - ("mean_uh_mseq_int", -0.807), - ("mean_uh_mseq_slp", 0.066), - ("mean_ulgm_qseq_xtp", 12.622), - ("mean_ulgm_qseq_ytp", 11.767), - ("mean_ulgm_qseq_lo", 0.289), - ("mean_ulgm_qseq_hi", 0.297), - ("mean_ulgy_qseq_xtp", 12.334), - ("mean_ulgy_qseq_ytp", -9.655), - ("mean_ulgy_qseq_lo", 0.572), - ("mean_ulgy_qseq_hi", 0.388), - ("mean_ul_qseq_int", -0.354), - ("mean_ul_qseq_slp", 1.694), - ("mean_uh_qseq_int", -0.924), - ("mean_uh_qseq_slp", -0.869), - ("mean_uqt_xtp", 13.226), - ("mean_uqt_ytp", 0.929), + ("mean_ulgm_mseq_xtp", 11.738), + ("mean_ulgm_mseq_ytp", 11.530), + ("mean_ulgm_mseq_lo", 1.144), + ("mean_ulgm_mseq_hi", 0.321), + ("mean_ulgy_mseq_xtp", 13.700), + ("mean_ulgy_mseq_ytp", -9.450), + ("mean_ulgy_mseq_lo", 0.514), + ("mean_ulgy_mseq_hi", -0.689), + ("mean_ul_mseq_int", -0.413), + ("mean_ul_mseq_slp", 0.872), + ("mean_uh_mseq_int", -1.221), + ("mean_uh_mseq_slp", -0.443), + ("mean_ulgm_qseq_xtp", 11.300), + ("mean_ulgm_qseq_ytp", 11.362), + ("mean_ulgm_qseq_lo", 1.505), + ("mean_ulgm_qseq_hi", 0.340), + ("mean_ulgy_qseq_xtp", 12.335), + ("mean_ulgy_qseq_ytp", -9.903), + ("mean_ulgy_qseq_lo", 0.533), + ("mean_ulgy_qseq_hi", 0.400), + ("mean_ul_qseq_int", -0.629), + ("mean_ul_qseq_slp", 1.597), + ("mean_uh_qseq_int", -0.877), + ("mean_uh_qseq_slp", -1.015), + ("mean_uqt_xtp", 12.141), + ("mean_uqt_ytp", 1.035), ("mean_uqt_lo", -0.100), ("mean_uqt_hi", -0.100), - ("mean_uqs_int", 0.144), - ("mean_uqs_slp", -0.081), - ("mean_udrop_int", -2.026), - ("mean_udrop_slp", 0.543), - ("mean_urej_int", -0.932), - ("mean_urej_slp", 0.409), + ("mean_uqs_int", 0.136), + ("mean_uqs_slp", -0.008), + ("mean_udrop_int", -2.077), + ("mean_udrop_slp", 0.489), + ("mean_urej_int", -1.012), + ("mean_urej_slp", 0.420), ] ) SFH_PDF_QUENCH_COV_MS_BLOCK_PDICT = OrderedDict( [ - ("std_ulgm_mseq_int", 0.379), - ("std_ulgm_mseq_slp", 0.025), - ("std_ulgy_mseq_int", 0.283), - ("std_ulgy_mseq_slp", 0.007), - ("std_ul_mseq_int", 1.665), - ("std_ul_mseq_slp", 0.337), - ("std_uh_mseq_int", 1.588), - ("std_uh_mseq_slp", -0.254), - ("std_ulgm_qseq_int", 0.362), - ("std_ulgm_qseq_slp", 0.001), - ("std_ulgy_qseq_int", 0.274), - ("std_ulgy_qseq_slp", -0.008), - ("std_ul_qseq_int", 1.766), - ("std_ul_qseq_slp", -0.476), - ("std_uh_qseq_int", 1.306), - ("std_uh_qseq_slp", -0.208), + ("std_ulgm_mseq_int", 0.383), + ("std_ulgm_mseq_slp", 0.053), + ("std_ulgy_mseq_int", 0.276), + ("std_ulgy_mseq_slp", -0.005), + ("std_ul_mseq_int", 1.677), + ("std_ul_mseq_slp", 0.017), + ("std_uh_mseq_int", 1.673), + ("std_uh_mseq_slp", -0.196), + ("std_ulgm_qseq_int", 0.378), + ("std_ulgm_qseq_slp", -0.091), + ("std_ulgy_qseq_int", 0.281), + ("std_ulgy_qseq_slp", -0.024), + ("std_ul_qseq_int", 1.832), + ("std_ul_qseq_slp", -0.537), + ("std_uh_qseq_int", 1.448), + ("std_uh_qseq_slp", -0.102), ] ) SFH_PDF_QUENCH_COV_Q_BLOCK_PDICT = OrderedDict( [ - ("std_uqt_int", 0.131), - ("std_uqt_slp", 0.071), - ("std_uqs_int", 0.590), - ("std_uqs_slp", 0.184), - ("std_udrop_int", 0.784), - ("std_udrop_slp", -0.065), - ("std_urej_int", 1.221), - ("std_urej_slp", -0.111), + ("std_uqt_int", 0.129), + ("std_uqt_slp", 0.069), + ("std_uqs_int", 0.560), + ("std_uqs_slp", 0.175), + ("std_udrop_int", 0.765), + ("std_udrop_slp", 0.018), + ("std_urej_int", 1.222), + ("std_urej_slp", -0.118), ] ) @@ -85,28 +85,28 @@ ("frac_quench_cen_x0_tpeak", 7.000), ("frac_quench_cen_k_tpeak", 2.000), ("frac_quench_cen_x0_ylotpeak", 11.100), - ("frac_quench_cen_x0_yhitpeak", 12.813), + ("frac_quench_cen_x0_yhitpeak", 12.914), ("frac_quench_cen_ylo_ylotpeak", 0.990), - ("frac_quench_cen_ylo_yhitpeak", 0.196), + ("frac_quench_cen_ylo_yhitpeak", 0.185), ("frac_quench_cen_k", 3.848), ("frac_quench_cen_yhi", 0.971), ("frac_quench_sat_x0_tpeak", 7.000), ("frac_quench_sat_k_tpeak", 2.000), ("frac_quench_sat_x0_ylotpeak", 11.100), - ("frac_quench_sat_x0_yhitpeak", 12.813), + ("frac_quench_sat_x0_yhitpeak", 12.914), ("frac_quench_sat_ylo_ylotpeak", 0.990), - ("frac_quench_sat_ylo_yhitpeak", 0.196), + ("frac_quench_sat_ylo_yhitpeak", 0.185), ("frac_quench_sat_k", 3.848), ("frac_quench_sat_yhi", 0.971), ] ) DELTA_UQT_PDICT = OrderedDict( [ - ("delta_uqt_x0", 3.309), - ("delta_uqt_k", 4.719), - ("delta_uqt_ylo", -0.308), - ("delta_uqt_yhi", 0.029), - ("delta_uqt_slope", 0.009), + ("delta_uqt_x0", 3.364), + ("delta_uqt_k", 4.982), + ("delta_uqt_ylo", -0.295), + ("delta_uqt_yhi", 0.026), + ("delta_uqt_slope", -0.003), ] ) SFH_PDF_QUENCH_PDICT = SFH_PDF_FRAC_QUENCH_PDICT.copy() diff --git a/diffstar/diffstarpop/kernels/params/params_diffstarpopfits_mgash_tng.py b/diffstar/diffstarpop/kernels/params/params_diffstarpopfits_mgash_tng.py index 468dc41..2c8347d 100644 --- a/diffstar/diffstarpop/kernels/params/params_diffstarpopfits_mgash_tng.py +++ b/diffstar/diffstarpop/kernels/params/params_diffstarpopfits_mgash_tng.py @@ -1,113 +1,110 @@ from collections import OrderedDict, namedtuple -from ..defaults_mgash import ( - DiffstarPopParams, - get_unbounded_diffstarpop_params, -) +from ..defaults_mgash import DiffstarPopParams, get_unbounded_diffstarpop_params from ..satquenchpop_model import DEFAULT_SATQUENCHPOP_PARAMS SFH_PDF_QUENCH_MU_PDICT = OrderedDict( [ - ("mean_ulgm_mseq_xtp", 11.934), - ("mean_ulgm_mseq_ytp", 11.392), - ("mean_ulgm_mseq_lo", 0.654), - ("mean_ulgm_mseq_hi", 0.409), - ("mean_ulgy_mseq_xtp", 12.716), - ("mean_ulgy_mseq_ytp", -9.372), - ("mean_ulgy_mseq_lo", 0.691), - ("mean_ulgy_mseq_hi", 0.999), - ("mean_ul_mseq_int", 0.350), - ("mean_ul_mseq_slp", 2.462), - ("mean_uh_mseq_int", -2.144), - ("mean_uh_mseq_slp", -0.700), - ("mean_ulgm_qseq_xtp", 13.590), - ("mean_ulgm_qseq_ytp", 11.911), - ("mean_ulgm_qseq_lo", 0.218), - ("mean_ulgm_qseq_hi", 0.294), - ("mean_ulgy_qseq_xtp", 12.038), - ("mean_ulgy_qseq_ytp", -9.768), - ("mean_ulgy_qseq_lo", 1.350), - ("mean_ulgy_qseq_hi", 0.597), - ("mean_ul_qseq_int", -1.536), - ("mean_ul_qseq_slp", -0.132), - ("mean_uh_qseq_int", -1.129), - ("mean_uh_qseq_slp", -0.216), - ("mean_uqt_xtp", 13.551), - ("mean_uqt_ytp", 0.698), - ("mean_uqt_lo", -0.332), - ("mean_uqt_hi", -0.018), - ("mean_uqs_int", 1.324), - ("mean_uqs_slp", 0.301), - ("mean_udrop_int", -2.997), - ("mean_udrop_slp", 1.093), - ("mean_urej_int", -9.199), - ("mean_urej_slp", -0.854), + ("mean_ulgm_mseq_xtp", 12.449), + ("mean_ulgm_mseq_ytp", 11.897), + ("mean_ulgm_mseq_lo", 0.992), + ("mean_ulgm_mseq_hi", 0.039), + ("mean_ulgy_mseq_xtp", 13.459), + ("mean_ulgy_mseq_ytp", -8.471), + ("mean_ulgy_mseq_lo", 1.249), + ("mean_ulgy_mseq_hi", 0.465), + ("mean_ul_mseq_int", -0.714), + ("mean_ul_mseq_slp", 3.270), + ("mean_uh_mseq_int", -4.088), + ("mean_uh_mseq_slp", -4.347), + ("mean_ulgm_qseq_xtp", 11.073), + ("mean_ulgm_qseq_ytp", 11.002), + ("mean_ulgm_qseq_lo", 3.635), + ("mean_ulgm_qseq_hi", 0.551), + ("mean_ulgy_qseq_xtp", 12.971), + ("mean_ulgy_qseq_ytp", -9.536), + ("mean_ulgy_qseq_lo", 0.638), + ("mean_ulgy_qseq_hi", 0.266), + ("mean_ul_qseq_int", -2.147), + ("mean_ul_qseq_slp", -1.249), + ("mean_uh_qseq_int", -1.392), + ("mean_uh_qseq_slp", -0.790), + ("mean_uqt_xtp", 12.107), + ("mean_uqt_ytp", 1.084), + ("mean_uqt_lo", -0.057), + ("mean_uqt_hi", -0.519), + ("mean_uqs_int", -0.019), + ("mean_uqs_slp", -0.231), + ("mean_udrop_int", -2.999), + ("mean_udrop_slp", 2.113), + ("mean_urej_int", -8.730), + ("mean_urej_slp", -0.280), ] ) SFH_PDF_QUENCH_COV_MS_BLOCK_PDICT = OrderedDict( [ ("std_ulgm_mseq_int", 0.011), - ("std_ulgm_mseq_slp", 0.001), + ("std_ulgm_mseq_slp", 0.049), ("std_ulgy_mseq_int", 0.011), - ("std_ulgy_mseq_slp", -0.184), - ("std_ul_mseq_int", 0.085), - ("std_ul_mseq_slp", 0.994), - ("std_uh_mseq_int", 0.103), + ("std_ulgy_mseq_slp", -0.123), + ("std_ul_mseq_int", 0.768), + ("std_ul_mseq_slp", 0.997), + ("std_uh_mseq_int", 0.457), ("std_uh_mseq_slp", -0.999), - ("std_ulgm_qseq_int", 0.011), - ("std_ulgm_qseq_slp", -0.175), - ("std_ulgy_qseq_int", 0.011), - ("std_ulgy_qseq_slp", -0.122), - ("std_ul_qseq_int", 0.014), - ("std_ul_qseq_slp", -0.002), - ("std_uh_qseq_int", 0.575), - ("std_uh_qseq_slp", -0.469), + ("std_ulgm_qseq_int", 0.108), + ("std_ulgm_qseq_slp", -0.149), + ("std_ulgy_qseq_int", 0.071), + ("std_ulgy_qseq_slp", -0.077), + ("std_ul_qseq_int", 0.392), + ("std_ul_qseq_slp", -0.681), + ("std_uh_qseq_int", 0.013), + ("std_uh_qseq_slp", 0.432), ] ) SFH_PDF_QUENCH_COV_Q_BLOCK_PDICT = OrderedDict( [ - ("std_uqt_int", 0.073), - ("std_uqt_slp", -0.009), - ("std_uqs_int", 0.042), - ("std_uqs_slp", -0.035), - ("std_udrop_int", 0.011), - ("std_udrop_slp", 0.753), - ("std_urej_int", 0.127), - ("std_urej_slp", -0.063), + ("std_uqt_int", 0.028), + ("std_uqt_slp", -0.013), + ("std_uqs_int", 0.075), + ("std_uqs_slp", -0.038), + ("std_udrop_int", 0.463), + ("std_udrop_slp", -0.824), + ("std_urej_int", 0.131), + ("std_urej_slp", 0.013), ] ) SFH_PDF_FRAC_QUENCH_PDICT = OrderedDict( [ - ("frac_quench_cen_x0_tpeak", 13.615), - ("frac_quench_cen_k_tpeak", 9.240), - ("frac_quench_cen_x0_ylotpeak", 13.966), - ("frac_quench_cen_x0_yhitpeak", 11.541), - ("frac_quench_cen_ylo_ylotpeak", 0.041), - ("frac_quench_cen_ylo_yhitpeak", 0.737), - ("frac_quench_cen_k", 4.995), + ("frac_quench_cen_x0_tpeak", 13.080), + ("frac_quench_cen_k_tpeak", 9.936), + ("frac_quench_cen_x0_ylotpeak", 12.066), + ("frac_quench_cen_x0_yhitpeak", 12.401), + ("frac_quench_cen_ylo_ylotpeak", 0.999), + ("frac_quench_cen_ylo_yhitpeak", 0.001), + ("frac_quench_cen_k", 4.999), ("frac_quench_cen_yhi", 0.999), - ("frac_quench_sat_x0_tpeak", 11.905), - ("frac_quench_sat_k_tpeak", 4.158), - ("frac_quench_sat_x0_ylotpeak", 12.469), - ("frac_quench_sat_x0_yhitpeak", 12.456), + ("frac_quench_sat_x0_tpeak", 9.388), + ("frac_quench_sat_k_tpeak", 9.848), + ("frac_quench_sat_x0_ylotpeak", 13.025), + ("frac_quench_sat_x0_yhitpeak", 12.397), ("frac_quench_sat_ylo_ylotpeak", 0.999), ("frac_quench_sat_ylo_yhitpeak", 0.001), - ("frac_quench_sat_k", 4.995), - ("frac_quench_sat_yhi", 0.999), + ("frac_quench_sat_k", 4.999), + ("frac_quench_sat_yhi", 0.848), ] ) DELTA_UQT_PDICT = OrderedDict( [ - ("delta_uqt_x0", 2.532), - ("delta_uqt_k", 0.454), - ("delta_uqt_ylo", -0.977), - ("delta_uqt_yhi", -0.002), - ("delta_uqt_slope", -0.030), + ("delta_uqt_x0", 4.191), + ("delta_uqt_k", 0.577), + ("delta_uqt_ylo", -0.496), + ("delta_uqt_yhi", 0.094), + ("delta_uqt_slope", -0.048), ] ) SFH_PDF_QUENCH_PDICT = SFH_PDF_FRAC_QUENCH_PDICT.copy() diff --git a/scripts/diffstarpop_scripts/calculate_quality_diffstar_fits_mgash.py b/scripts/diffstarpop_scripts/calculate_quality_diffstar_fits_mgash.py new file mode 100644 index 0000000..c2df9cb --- /dev/null +++ b/scripts/diffstarpop_scripts/calculate_quality_diffstar_fits_mgash.py @@ -0,0 +1,708 @@ +import re +import os +import numpy as np +from matplotlib import pyplot as plt +import matplotlib.gridspec as gridspec +from matplotlib.patches import Patch +from matplotlib.lines import Line2D +import warnings +import h5py + +import argparse + +from astropy.cosmology import Planck15, z_at_value + +mred = "#d62728" +morange = "#ff7f0e" +mgreen = "#2ca02c" +mblue = "#1f77b4" +mpurple = "#9467bd" +plt.rc("font", family="serif") +plt.rc("font", size=22) +plt.rc("text", usetex=False) +plt.rc("text.latex", preamble=r"\usepackage{amsmath}") # necessary to use \dfrac + +import smhm_utils_tng_mgash +import smhm_utils_galacticus_mgash +import smhm_utils_smdpl_mgash +from smhm_utils_smdpl_mgash import load_diffstar_sfh_tables + +from diffstar.data_loaders.load_smah_data import ( + FB_SMDPL, + T0_SMDPL, + load_smdpl_diffmah_fits, + load_SMDPL_DR1_data, + load_SMDPL_nomerging_data, + load_tng_data, +) +from diffstar.data_loaders.load_galacticus_sfh import load_galacticus_diffstar_data + +from jax import vmap, jit as jjit, numpy as jnp +from diffstar.defaults import TODAY, LGT0 +from diffstar.utils import cumulative_mstar_formed_galpop +from diffmah.diffmah_kernels import DiffmahParams, mah_halopop, DEFAULT_MAH_PARAMS + + +def _jnp_interp_vmap(x_new, x, y): + return jnp.interp(x_new, x, y) + + +jnp_interp_vmap = jjit(vmap(_jnp_interp_vmap, in_axes=(None, None, 0))) + + +def calculate_plot_smdpl_nomerging(mpeak_bins): + diffmah_drn = smhm_utils_smdpl_mgash.LCRC_NOMERGING_DIFFMAH_DRN + diffstar_drn = smhm_utils_smdpl_mgash.LCRC_NOMERGING_DIFFSTAR_DRN + binaries_drn = smhm_utils_smdpl_mgash.LCRC_NOMERGING_BINARIES_DRN + diffstar_bnpat = smhm_utils_smdpl_mgash.LCRC_NOMERGING_diffstar_bnpat + sim_name = "DR1_nomerging" + + regex_str = re.escape(diffstar_bnpat).replace(r"\{\}", r"(\d{1,3})") + pattern = re.compile(f"^{regex_str}$") + matching_files = [f for f in os.listdir(diffstar_drn) if pattern.match(f)] + subvols = [x.split("_")[-1].split(".")[0] for x in matching_files] + subvols = np.sort(np.array(subvols).astype(int)) + n_subvol_smdpl = len(subvols) + + mpeak_binsc = 0.5 * (mpeak_bins[1:] + mpeak_bins[:-1]) + nt = 117 + + mstar_data_mean = np.zeros((len(mpeak_binsc), nt)) + mstar_fit_mean = np.zeros((len(mpeak_binsc), nt)) + sfr_data_mean = np.zeros((len(mpeak_binsc), nt)) + sfr_fit_mean = np.zeros((len(mpeak_binsc), nt)) + + ngals = np.zeros(len(mpeak_binsc)) + + for subvol in range(576): + + print(subvol) + + out = smhm_utils_smdpl_mgash.load_diffstar_sfh_tables( + subvol, + sim_name, + n_subvol_smdpl, + diffmah_drn, + diffstar_drn, + diffstar_bnpat, + ) + ( + t_table, + log_mah_table, + log_smh_table, + log_ssfrh_table, + mah_params, + sfh_params, + has_fit, + ) = out + + log_sfh_table = log_ssfrh_table + log_smh_table + + out = load_SMDPL_nomerging_data([subvol], binaries_drn) + (halo_ids, log_smahs, sfrh, SMDPL_t, log_mahs, logmp0) = out + log_sfrh = np.where(sfrh > 0.0, np.log10(sfrh), 0.0) + + _log_smahs_data = log_smahs[has_fit] + _log_sfrh_data = log_sfrh[has_fit] + + _log_smahs_fits = jnp_interp_vmap(SMDPL_t, t_table, log_smh_table) + _log_sfrh_fits = jnp_interp_vmap(SMDPL_t, t_table, log_sfh_table) + + smahs_fits = np.where(_log_smahs_fits == 0.0, np.nan, 10**_log_smahs_fits) + sfrh_fits = np.where(_log_sfrh_fits == 0.0, np.nan, 10**_log_sfrh_fits) + smahs_data = np.where(_log_smahs_data == 0.0, np.nan, 10**_log_smahs_data) + sfrh_data = np.where(_log_sfrh_data == 0.0, np.nan, 10**_log_sfrh_data) + + logmp0_data = logmp0[has_fit] + + ssfrh = sfrh_data / smahs_data + ssfrh_fit = sfrh_fits / smahs_fits + ssfrh = np.clip(ssfrh, 1e-12, np.inf) + ssfrh_fit = np.clip(ssfrh_fit, 1e-12, np.inf) + sfrh = np.where(smahs_data > 0.0, ssfrh * smahs_data, sfrh_data) + sfrh_fits = ssfrh_fit * smahs_fits + + for i in range(len(mpeak_bins) - 1): + masksel = (logmp0_data > mpeak_bins[i]) & (logmp0_data < mpeak_bins[i + 1]) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + mstar_data_mean[i] += np.nansum(smahs_data[masksel], axis=0) + mstar_fit_mean[i] += np.nansum(smahs_fits[masksel], axis=0) + sfr_data_mean[i] += np.nansum(sfrh[masksel], axis=0) + sfr_fit_mean[i] += np.nansum(sfrh_fits[masksel], axis=0) + + ngals[i] += masksel.sum() + + mstar_data_mean /= ngals[:, None] + mstar_fit_mean /= ngals[:, None] + sfr_data_mean /= ngals[:, None] + sfr_fit_mean /= ngals[:, None] + + out = ( + mpeak_bins, + mpeak_binsc, + SMDPL_t, + mstar_data_mean, + mstar_fit_mean, + sfr_data_mean, + sfr_fit_mean, + ) + + return out + + +def calculate_plot_smdpl_dr1(mpeak_bins): + diffmah_drn = smhm_utils_smdpl_mgash.LCRC_DR1_DIFFMAH_DRN + diffstar_drn = smhm_utils_smdpl_mgash.LCRC_DR1_DIFFSTAR_DRN + binaries_drn = smhm_utils_smdpl_mgash.LCRC_DR1_BINARIES_DRN + diffstar_bnpat = smhm_utils_smdpl_mgash.LCRC_DR1_diffstar_bnpat + sim_name = "DR1" + + regex_str = re.escape(diffstar_bnpat).replace(r"\{\}", r"(\d{1,3})") + pattern = re.compile(f"^{regex_str}$") + matching_files = [f for f in os.listdir(diffstar_drn) if pattern.match(f)] + subvols = [x.split("_")[-1].split(".")[0] for x in matching_files] + subvols = np.sort(np.array(subvols).astype(int)) + n_subvol_smdpl = len(subvols) + + mpeak_binsc = 0.5 * (mpeak_bins[1:] + mpeak_bins[:-1]) + nt = 117 + + mstar_data_mean = np.zeros((len(mpeak_binsc), nt)) + mstar_fit_mean = np.zeros((len(mpeak_binsc), nt)) + sfr_data_mean = np.zeros((len(mpeak_binsc), nt)) + sfr_fit_mean = np.zeros((len(mpeak_binsc), nt)) + + ngals = np.zeros(len(mpeak_binsc)) + + print(n_subvol_smdpl) + + for subvol in subvols: + + print(subvol) + + out = smhm_utils_smdpl_mgash.load_diffstar_sfh_tables( + subvol, + sim_name, + n_subvol_smdpl, + diffmah_drn, + diffstar_drn, + diffstar_bnpat, + ) + ( + t_table, + log_mah_table, + log_smh_table, + log_ssfrh_table, + mah_params, + sfh_params, + has_fit, + ) = out + + log_sfh_table = log_ssfrh_table + log_smh_table + + out = load_SMDPL_DR1_data([subvol], binaries_drn) + (halo_ids, log_smahs, sfrh, SMDPL_t, log_mahs, logmp0) = out + log_sfrh = np.where(sfrh > 0.0, np.log10(sfrh), 0.0) + + _log_smahs_data = log_smahs[has_fit] + _log_sfrh_data = log_sfrh[has_fit] + + _log_smahs_fits = jnp_interp_vmap(SMDPL_t, t_table, log_smh_table) + _log_sfrh_fits = jnp_interp_vmap(SMDPL_t, t_table, log_sfh_table) + + smahs_fits = np.where(_log_smahs_fits == 0.0, np.nan, 10**_log_smahs_fits) + sfrh_fits = np.where(_log_sfrh_fits == 0.0, np.nan, 10**_log_sfrh_fits) + smahs_data = np.where(_log_smahs_data == 0.0, np.nan, 10**_log_smahs_data) + sfrh_data = np.where(_log_sfrh_data == 0.0, np.nan, 10**_log_sfrh_data) + + logmp0_data = logmp0[has_fit] + + ssfrh = sfrh_data / smahs_data + ssfrh_fit = sfrh_fits / smahs_fits + ssfrh = np.clip(ssfrh, 1e-12, np.inf) + ssfrh_fit = np.clip(ssfrh_fit, 1e-12, np.inf) + sfrh = np.where(smahs_data > 0.0, ssfrh * smahs_data, sfrh_data) + sfrh_fits = ssfrh_fit * smahs_fits + + for i in range(len(mpeak_bins) - 1): + masksel = (logmp0_data > mpeak_bins[i]) & (logmp0_data < mpeak_bins[i + 1]) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + mstar_data_mean[i] += np.nansum(smahs_data[masksel], axis=0) + mstar_fit_mean[i] += np.nansum(smahs_fits[masksel], axis=0) + sfr_data_mean[i] += np.nansum(sfrh[masksel], axis=0) + sfr_fit_mean[i] += np.nansum(sfrh_fits[masksel], axis=0) + + ngals[i] += masksel.sum() + + mstar_data_mean /= ngals[:, None] + mstar_fit_mean /= ngals[:, None] + sfr_data_mean /= ngals[:, None] + sfr_fit_mean /= ngals[:, None] + + out = ( + mpeak_bins, + mpeak_binsc, + SMDPL_t, + mstar_data_mean, + mstar_fit_mean, + sfr_data_mean, + sfr_fit_mean, + ) + + return out + + +def calculate_plot_tng(mpeak_bins): + diffmah_drn = smhm_utils_tng_mgash.BEBOP_TNG_MAH + diffstar_drn = smhm_utils_tng_mgash.BEBOP_TNG_SFH + binaries_drn = smhm_utils_tng_mgash.BEBOP_TNG + + mpeak_binsc = 0.5 * (mpeak_bins[1:] + mpeak_bins[:-1]) + out = load_tng_data(binaries_drn) + (halo_ids, log_smahs, sfrh, tng_t, log_mahs, logmp0) = out + log_sfrh = np.where(sfrh > 0.0, np.log10(sfrh), 0.0) + nt = len(tng_t) + n_subvol_smdpl = 20 + + nhalos_tot = len(halo_ids) + + _a = np.arange(0, nhalos_tot).astype("i8") + + mstar_data_mean = np.zeros((len(mpeak_binsc), nt)) + mstar_fit_mean = np.zeros((len(mpeak_binsc), nt)) + sfr_data_mean = np.zeros((len(mpeak_binsc), nt)) + sfr_fit_mean = np.zeros((len(mpeak_binsc), nt)) + + ngals = np.zeros(len(mpeak_binsc)) + + print(n_subvol_smdpl) + + for subvol in range(20): + + print(subvol) + + indx = np.array_split(_a, n_subvol_smdpl)[subvol] + + out = smhm_utils_tng_mgash.load_diffstar_sfh_tables( + subvol, + diffmah_drn, + diffstar_drn, + ) + ( + t_table, + log_mah_table, + log_smh_table, + log_ssfrh_table, + mah_params, + sfh_params, + has_fit, + ) = out + + log_sfh_table = log_ssfrh_table + log_smh_table + + _log_smahs_data = log_smahs[indx][has_fit] + _log_sfrh_data = log_sfrh[indx][has_fit] + + _log_smahs_fits = jnp_interp_vmap(tng_t, t_table, log_smh_table) + _log_sfrh_fits = jnp_interp_vmap(tng_t, t_table, log_sfh_table) + + smahs_fits = np.where(_log_smahs_fits == 0.0, np.nan, 10**_log_smahs_fits) + sfrh_fits = np.where(_log_sfrh_fits == 0.0, np.nan, 10**_log_sfrh_fits) + smahs_data = np.where(_log_smahs_data == 0.0, np.nan, 10**_log_smahs_data) + sfrh_data = np.where(_log_sfrh_data == 0.0, np.nan, 10**_log_sfrh_data) + + logmp0_data = logmp0[indx][has_fit] + + ssfrh = sfrh_data / smahs_data + ssfrh_fit = sfrh_fits / smahs_fits + ssfrh = np.clip(ssfrh, 1e-12, np.inf) + ssfrh_fit = np.clip(ssfrh_fit, 1e-12, np.inf) + sfrh = np.where(smahs_data > 0.0, ssfrh * smahs_data, sfrh_data) + sfrh_fits = ssfrh_fit * smahs_fits + + for i in range(len(mpeak_bins) - 1): + masksel = (logmp0_data > mpeak_bins[i]) & (logmp0_data < mpeak_bins[i + 1]) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + mstar_data_mean[i] += np.nansum(smahs_data[masksel], axis=0) + mstar_fit_mean[i] += np.nansum(smahs_fits[masksel], axis=0) + sfr_data_mean[i] += np.nansum(sfrh[masksel], axis=0) + sfr_fit_mean[i] += np.nansum(sfrh_fits[masksel], axis=0) + + ngals[i] += masksel.sum() + + mstar_data_mean /= ngals[:, None] + mstar_fit_mean /= ngals[:, None] + sfr_data_mean /= ngals[:, None] + sfr_fit_mean /= ngals[:, None] + + out = ( + mpeak_bins, + mpeak_binsc, + tng_t, + mstar_data_mean, + mstar_fit_mean, + sfr_data_mean, + sfr_fit_mean, + ) + + return out + + +def calculate_plot_galcus_insitu(mpeak_bins): + BEBOP_GALAC = smhm_utils_galacticus_mgash.BEBOP_GALAC + BEBOP_GALAC_SFH = smhm_utils_galacticus_mgash.BEBOP_GALAC_SFH + + mpeak_binsc = 0.5 * (mpeak_bins[1:] + mpeak_bins[:-1]) + + out = load_galacticus_diffstar_data(BEBOP_GALAC) + galcus_t = out.galcus_sfh_data["tarr"] + sfrh = out.galcus_sfh_data["sfh_in_situ"] + diffmah_data = out.diffmah_fit_data + + log_smahs = np.log10(cumulative_mstar_formed_galpop(galcus_t, sfrh)) + + mah_params = DEFAULT_MAH_PARAMS._make( + [diffmah_data[key] for key in DEFAULT_MAH_PARAMS._fields] + ) + + mah_pars_ntuple = DiffmahParams(*mah_params) + dmhdt_fit, log_mah_fit = mah_halopop(mah_pars_ntuple, galcus_t, LGT0) + logmp0 = log_mah_fit[:, -1] + + # (halo_ids, log_smahs, sfrh, tng_t, log_mahs, logmp0) = out + log_sfrh = np.where(sfrh > 0.0, np.log10(sfrh), 0.0) + nt = len(galcus_t) + + mstar_data_mean = np.zeros((len(mpeak_binsc), nt)) + mstar_fit_mean = np.zeros((len(mpeak_binsc), nt)) + sfr_data_mean = np.zeros((len(mpeak_binsc), nt)) + sfr_fit_mean = np.zeros((len(mpeak_binsc), nt)) + + ngals = np.zeros(len(mpeak_binsc)) + + sfh_type = "in_situ" + + out = smhm_utils_galacticus_mgash.load_diffstar_sfh_tables( + sfh_type, + BEBOP_GALAC, + BEBOP_GALAC_SFH, + ) + ( + t_table, + log_mah_table, + log_smh_table, + log_ssfrh_table, + mah_params, + sfh_params, + is_cen, + has_fit, + ) = out + + log_sfh_table = log_ssfrh_table + log_smh_table + + _log_smahs_data = log_smahs[has_fit] + _log_sfrh_data = log_sfrh[has_fit] + + _log_smahs_fits = jnp_interp_vmap(galcus_t, t_table, log_smh_table) + _log_sfrh_fits = jnp_interp_vmap(galcus_t, t_table, log_sfh_table) + + smahs_fits = np.where(_log_smahs_fits == 0.0, np.nan, 10**_log_smahs_fits) + sfrh_fits = np.where(_log_sfrh_fits == 0.0, np.nan, 10**_log_sfrh_fits) + smahs_data = np.where(_log_smahs_data == 0.0, np.nan, 10**_log_smahs_data) + sfrh_data = np.where(_log_sfrh_data == 0.0, np.nan, 10**_log_sfrh_data) + + logmp0_data = logmp0[has_fit] + + ssfrh = sfrh_data / smahs_data + ssfrh_fit = sfrh_fits / smahs_fits + ssfrh = np.clip(ssfrh, 1e-12, np.inf) + ssfrh_fit = np.clip(ssfrh_fit, 1e-12, np.inf) + sfrh = np.where(smahs_data > 0.0, ssfrh * smahs_data, sfrh_data) + sfrh_fits = ssfrh_fit * smahs_fits + + for i in range(len(mpeak_bins) - 1): + masksel = (logmp0_data > mpeak_bins[i]) & (logmp0_data < mpeak_bins[i + 1]) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + mstar_data_mean[i] += np.nansum(smahs_data[masksel], axis=0) + mstar_fit_mean[i] += np.nansum(smahs_fits[masksel], axis=0) + sfr_data_mean[i] += np.nansum(sfrh[masksel], axis=0) + sfr_fit_mean[i] += np.nansum(sfrh_fits[masksel], axis=0) + + ngals[i] += masksel.sum() + + mstar_data_mean /= ngals[:, None] + mstar_fit_mean /= ngals[:, None] + sfr_data_mean /= ngals[:, None] + sfr_fit_mean /= ngals[:, None] + + out = ( + mpeak_bins, + mpeak_binsc, + galcus_t, + mstar_data_mean, + mstar_fit_mean, + sfr_data_mean, + sfr_fit_mean, + ) + + return out + + +def calculate_plot_galcus_inplusexsitu(mpeak_bins): + BEBOP_GALAC = smhm_utils_galacticus_mgash.BEBOP_GALAC + BEBOP_GALAC_SFH = smhm_utils_galacticus_mgash.BEBOP_GALAC_SFH + + mpeak_binsc = 0.5 * (mpeak_bins[1:] + mpeak_bins[:-1]) + + out = load_galacticus_diffstar_data(BEBOP_GALAC) + galcus_t = out.galcus_sfh_data["tarr"] + sfrh = out.galcus_sfh_data["sfh_tot"] + diffmah_data = out.diffmah_fit_data + + log_smahs = np.log10(cumulative_mstar_formed_galpop(galcus_t, sfrh)) + + mah_params = DEFAULT_MAH_PARAMS._make( + [diffmah_data[key] for key in DEFAULT_MAH_PARAMS._fields] + ) + + mah_pars_ntuple = DiffmahParams(*mah_params) + dmhdt_fit, log_mah_fit = mah_halopop(mah_pars_ntuple, galcus_t, LGT0) + logmp0 = log_mah_fit[:, -1] + + # (halo_ids, log_smahs, sfrh, tng_t, log_mahs, logmp0) = out + log_sfrh = np.where(sfrh > 0.0, np.log10(sfrh), 0.0) + nt = len(galcus_t) + + mstar_data_mean = np.zeros((len(mpeak_binsc), nt)) + mstar_fit_mean = np.zeros((len(mpeak_binsc), nt)) + sfr_data_mean = np.zeros((len(mpeak_binsc), nt)) + sfr_fit_mean = np.zeros((len(mpeak_binsc), nt)) + + ngals = np.zeros(len(mpeak_binsc)) + + sfh_type = "in_plus_ex_situ" + + out = smhm_utils_galacticus_mgash.load_diffstar_sfh_tables( + sfh_type, + BEBOP_GALAC, + BEBOP_GALAC_SFH, + ) + ( + t_table, + log_mah_table, + log_smh_table, + log_ssfrh_table, + mah_params, + sfh_params, + is_cen, + has_fit, + ) = out + + log_sfh_table = log_ssfrh_table + log_smh_table + + _log_smahs_data = log_smahs[has_fit] + _log_sfrh_data = log_sfrh[has_fit] + + _log_smahs_fits = jnp_interp_vmap(galcus_t, t_table, log_smh_table) + _log_sfrh_fits = jnp_interp_vmap(galcus_t, t_table, log_sfh_table) + + smahs_fits = np.where(_log_smahs_fits == 0.0, np.nan, 10**_log_smahs_fits) + sfrh_fits = np.where(_log_sfrh_fits == 0.0, np.nan, 10**_log_sfrh_fits) + smahs_data = np.where(_log_smahs_data == 0.0, np.nan, 10**_log_smahs_data) + sfrh_data = np.where(_log_sfrh_data == 0.0, np.nan, 10**_log_sfrh_data) + + logmp0_data = logmp0[has_fit] + + ssfrh = sfrh_data / smahs_data + ssfrh_fit = sfrh_fits / smahs_fits + ssfrh = np.clip(ssfrh, 1e-12, np.inf) + ssfrh_fit = np.clip(ssfrh_fit, 1e-12, np.inf) + sfrh = np.where(smahs_data > 0.0, ssfrh * smahs_data, sfrh_data) + sfrh_fits = ssfrh_fit * smahs_fits + + for i in range(len(mpeak_bins) - 1): + masksel = (logmp0_data > mpeak_bins[i]) & (logmp0_data < mpeak_bins[i + 1]) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + mstar_data_mean[i] += np.nansum(smahs_data[masksel], axis=0) + mstar_fit_mean[i] += np.nansum(smahs_fits[masksel], axis=0) + sfr_data_mean[i] += np.nansum(sfrh[masksel], axis=0) + sfr_fit_mean[i] += np.nansum(sfrh_fits[masksel], axis=0) + + ngals[i] += masksel.sum() + + mstar_data_mean /= ngals[:, None] + mstar_fit_mean /= ngals[:, None] + sfr_data_mean /= ngals[:, None] + sfr_fit_mean /= ngals[:, None] + + out = ( + mpeak_bins, + mpeak_binsc, + galcus_t, + mstar_data_mean, + mstar_fit_mean, + sfr_data_mean, + sfr_fit_mean, + ) + + return out + + +def save_data(outdrn, outname, data): + fnout = os.path.join(outdrn, outname) + + ( + tarr, + log_smahs_fits, + log_sfrh_fits, + log_smahs_data, + log_sfrh_data, + logmp0_data, + ) = data + + with h5py.File(fnout, "w") as hdfout: + hdfout["tarr"] = tarr + hdfout["log_smahs_fits"] = log_smahs_fits + hdfout["log_sfrh_fits"] = log_sfrh_fits + hdfout["log_smahs_data"] = log_smahs_data + hdfout["log_sfrh_data"] = log_sfrh_data + hdfout["logmp0_data"] = logmp0_data + + +def save_data_plot(outdrn, outname, data): + fnout = os.path.join(outdrn, outname) + + ( + mpeak_bins, + mpeak_binsc, + tarr, + mstar_data_mean, + mstar_fit_mean, + sfr_data_mean, + sfr_fit_mean, + ) = data + + with h5py.File(fnout, "w") as hdfout: + hdfout["tarr"] = tarr + hdfout["mpeak_bins"] = mpeak_bins + hdfout["mpeak_binsc"] = mpeak_binsc + hdfout["mstar_data_mean"] = mstar_data_mean + hdfout["mstar_fit_mean"] = mstar_fit_mean + hdfout["sfr_data_mean"] = sfr_data_mean + hdfout["sfr_fit_mean"] = sfr_fit_mean + + +# out_smdpl_nomerging = calculate_smdpl_nomerging() +# outdir = "/lcrc/project/halotools/alarcon/results/diffstar_quality_fits/" +# outname = "diffstar_quality_smdpl.h5" +# save_data(outdir, outname, out_smdpl_nomerging) + +# outdir = "/lcrc/project/halotools/alarcon/results/smdpl_pdf_target_data/" +# sim_name = "SMDPL_UM_Nomerging" +# make_diffstar_fits_plot(outdir, sim_name, *out_smdpl_nomerging) + +# mpeak_bins = np.arange(11.25, 14.5, 0.50) +# out_smdpl_nomerging = calculate_plot_smdpl_nomerging(mpeak_bins) +# outdir = "/lcrc/project/halotools/alarcon/results/diffstar_quality_fits/" +# outname = "diffstar_quality_smdpl.h5" +# save_data_plot(outdir, outname, out_smdpl_nomerging) + +# mpeak_bins = np.arange(11.25, 14.5, 0.50) +# out_smdpl_dr1 = calculate_plot_smdpl_dr1(mpeak_bins) +# outdir = "/lcrc/project/halotools/alarcon/results/diffstar_quality_fits/" +# outname = "diffstar_quality_smdpl_dr1.h5" +# save_data_plot(outdir, outname, out_smdpl_dr1) + +# mpeak_bins = np.arange(11.25, 14.5, 0.50) +# out_tng = calculate_plot_tng(mpeak_bins) +# outdir = "/lcrc/project/halotools/alarcon/results/diffstar_quality_fits/" +# outname = "diffstar_quality_tng.h5" +# save_data_plot(outdir, outname, out_tng) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument( + "-sim_name", + help="simulation name", + type=str, + default="smdpl", + choices=[ + "smdpl", + "smdpl_dr1", + "tng", + "galcus_insitu", + "galcus_inplusexsitu", + "all", + ], + ) + parser.add_argument( + "-outdir", + help="simulation name", + type=str, + default="/lcrc/project/halotools/alarcon/results/mgash/diffstar_quality_fits_mgash/", + ) + + args = parser.parse_args() + sim_name = args.sim_name + outdir = args.outdir + + mpeak_bins = np.arange(11.25, 14.5, 0.50) + + if sim_name == "all": + out_smdpl_nomerging = calculate_plot_smdpl_nomerging(mpeak_bins) + outname = "diffstar_quality_smdpl.h5" + save_data_plot(outdir, outname, out_smdpl_nomerging) + + out_smdpl_dr1 = calculate_plot_smdpl_dr1(mpeak_bins) + outname = "diffstar_quality_smdpl_dr1.h5" + save_data_plot(outdir, outname, out_smdpl_dr1) + + out_tng = calculate_plot_tng(mpeak_bins) + outname = "diffstar_quality_tng.h5" + save_data_plot(outdir, outname, out_tng) + + out_galcus_insitu = calculate_plot_galcus_insitu(mpeak_bins) + outname = "diffstar_quality_galcus_insitu.h5" + save_data_plot(outdir, outname, out_galcus_insitu) + + out_galcus_inplusexsitu = calculate_plot_galcus_inplusexsitu(mpeak_bins) + outname = "diffstar_quality_galcus_inplusexsitu.h5" + save_data_plot(outdir, outname, out_galcus_inplusexsitu) + + elif sim_name == "smdpl": + out_smdpl_nomerging = calculate_plot_smdpl_nomerging(mpeak_bins) + outname = "diffstar_quality_smdpl.h5" + save_data_plot(outdir, outname, out_smdpl_nomerging) + elif sim_name == "smdpl_dr1": + out_smdpl_dr1 = calculate_plot_smdpl_dr1(mpeak_bins) + outname = "diffstar_quality_smdpl_dr1.h5" + save_data_plot(outdir, outname, out_smdpl_dr1) + elif sim_name == "tng": + out_tng = calculate_plot_tng(mpeak_bins) + outname = "diffstar_quality_tng.h5" + save_data_plot(outdir, outname, out_tng) + elif sim_name == "galcus_insitu": + out_galcus_insitu = calculate_plot_galcus_insitu(mpeak_bins) + outname = "diffstar_quality_galcus_insitu.h5" + save_data_plot(outdir, outname, out_galcus_insitu) + elif sim_name == "galcus_inplusexsitu": + out_galcus_inplusexsitu = calculate_plot_galcus_inplusexsitu(mpeak_bins) + outname = "diffstar_quality_galcus_inplusexsitu.h5" + save_data_plot(outdir, outname, out_galcus_inplusexsitu) diff --git a/scripts/diffstarpop_scripts/fit_mstar_ssfr_pdfs_mgash.py b/scripts/diffstarpop_scripts/fit_mstar_ssfr_pdfs_mgash.py index 1cb4d43..ef0732f 100644 --- a/scripts/diffstarpop_scripts/fit_mstar_ssfr_pdfs_mgash.py +++ b/scripts/diffstarpop_scripts/fit_mstar_ssfr_pdfs_mgash.py @@ -18,7 +18,7 @@ from diffstar.defaults import TODAY, LGT0 from diffmah.diffmah_kernels import mah_halopop -from diffstar.diffstarpop.loss_kernels.mstar_ssfr_loss_mgash import ( +from diffstar.diffstarpop.loss_kernels.mstar_ssfr_loss_mgash_anyz import ( loss_mstar_kern_tobs_grad_wrapper, loss_mstar_ssfr_kern_tobs_grad_wrapper, loss_combined_wrapper, diff --git a/scripts/diffstarpop_scripts/measure_smhm_galacticus_script_mpi_mgash.py b/scripts/diffstarpop_scripts/measure_smhm_galacticus_script_mpi_mgash.py index e359d0c..a7daacf 100644 --- a/scripts/diffstarpop_scripts/measure_smhm_galacticus_script_mpi_mgash.py +++ b/scripts/diffstarpop_scripts/measure_smhm_galacticus_script_mpi_mgash.py @@ -207,8 +207,8 @@ hdfout["logmh_id"] = logmh_id hdfout["logmh_val"] = logmh_val hdfout["mah_params_samp"] = mah_params_samp - hdfout["ms_params_samp"] = ms_params_samp - hdfout["q_params_samp"] = q_params_samp + hdfout["ms_params_samp"] = ms_params_samp.T + hdfout["q_params_samp"] = q_params_samp.T hdfout["upid_samp"] = upid_samp hdfout["tobs_id"] = tobs_id hdfout["tobs_val"] = tobs_val diff --git a/scripts/diffstarpop_scripts/measure_smhm_smdpl_script_mpi_mgash.py b/scripts/diffstarpop_scripts/measure_smhm_smdpl_script_mpi_mgash.py index 34a6247..35365d4 100644 --- a/scripts/diffstarpop_scripts/measure_smhm_smdpl_script_mpi_mgash.py +++ b/scripts/diffstarpop_scripts/measure_smhm_smdpl_script_mpi_mgash.py @@ -264,8 +264,8 @@ hdfout["logmh_id"] = logmh_id hdfout["logmh_val"] = logmh_val hdfout["mah_params_samp"] = mah_params_samp - hdfout["ms_params_samp"] = ms_params_samp - hdfout["q_params_samp"] = q_params_samp + hdfout["ms_params_samp"] = ms_params_samp.T + hdfout["q_params_samp"] = q_params_samp.T hdfout["upid_samp"] = upid_samp hdfout["tobs_id"] = tobs_id hdfout["tobs_val"] = tobs_val diff --git a/scripts/diffstarpop_scripts/measure_smhm_tng_script_mpi_mgash.py b/scripts/diffstarpop_scripts/measure_smhm_tng_script_mpi_mgash.py index 97ab9da..fa0fde1 100644 --- a/scripts/diffstarpop_scripts/measure_smhm_tng_script_mpi_mgash.py +++ b/scripts/diffstarpop_scripts/measure_smhm_tng_script_mpi_mgash.py @@ -204,8 +204,8 @@ hdfout["logmh_id"] = logmh_id hdfout["logmh_val"] = logmh_val hdfout["mah_params_samp"] = mah_params_samp - hdfout["ms_params_samp"] = ms_params_samp - hdfout["q_params_samp"] = q_params_samp + hdfout["ms_params_samp"] = ms_params_samp.T + hdfout["q_params_samp"] = q_params_samp.T hdfout["upid_samp"] = upid_samp hdfout["tobs_id"] = tobs_id hdfout["tobs_val"] = tobs_val diff --git a/scripts/diffstarpop_scripts/smhm_utils_galacticus_mgash.py b/scripts/diffstarpop_scripts/smhm_utils_galacticus_mgash.py index 395602f..28fa747 100644 --- a/scripts/diffstarpop_scripts/smhm_utils_galacticus_mgash.py +++ b/scripts/diffstarpop_scripts/smhm_utils_galacticus_mgash.py @@ -6,8 +6,8 @@ import numpy as np from diffmah.diffmah_kernels import DEFAULT_MAH_PARAMS, mah_halopop from diffsky.diffndhist import tw_ndhist_weighted -from diffstar.defaults_mgash_model import DEFAULT_DIFFSTAR_PARAMS, LGT0, T_TABLE_MIN -from diffstar.sfh_model_mgash import calc_sfh_galpop +from diffstar.defaults import DEFAULT_DIFFSTAR_PARAMS, LGT0, T_TABLE_MIN +from diffstar.sfh_model import calc_sfh_galpop from diffstar.data_loaders.load_galacticus_sfh import load_galacticus_diffstar_data from scipy.stats import binned_statistic from astropy.cosmology import Planck13 @@ -61,20 +61,9 @@ def load_diffstar_sfh_tables( [diffmah_data[key][has_fit] for key in DEFAULT_MAH_PARAMS._fields] ) - ms_params = DEFAULT_DIFFSTAR_PARAMS.ms_params._make( - [ - diffstar_data[key][has_fit] - for key in DEFAULT_DIFFSTAR_PARAMS.ms_params._fields - ] + sfh_params = DEFAULT_DIFFSTAR_PARAMS._make( + [diffstar_data[key][has_fit] for key in DEFAULT_DIFFSTAR_PARAMS._fields] ) - q_params = DEFAULT_DIFFSTAR_PARAMS.q_params._make( - [ - diffstar_data[key][has_fit] - for key in DEFAULT_DIFFSTAR_PARAMS.q_params._fields - ] - ) - - sfh_params = DEFAULT_DIFFSTAR_PARAMS._make((ms_params, q_params)) t_0 = 10**lgt0 t_table = np.linspace(T_TABLE_MIN, t_0, n_times) @@ -94,8 +83,7 @@ def load_diffstar_sfh_tables( log_smh_table, log_ssfrh_table, mah_params, - ms_params, - q_params, + sfh_params, is_cen, has_fit, ) @@ -151,8 +139,7 @@ def sample_halos( log_mah, log_smh, mah_params, - ms_params, - q_params, + sfh_params, upid, ): ndbins_lo = logmh_bins[:-1] @@ -166,8 +153,9 @@ def sample_halos( upid_samp = [] mah_params = np.array(mah_params).T - ms_params = np.array(ms_params).T - q_params = np.array(q_params).T + sfh_params = np.array(sfh_params).T + ms_params = sfh_params[:, :4] + q_params = sfh_params[:, 4:] for i in range(len(ndbins_lo)): sel = (log_mah >= ndbins_lo[i]) & (log_mah < ndbins_hi[i]) @@ -190,8 +178,6 @@ def sample_halos( upid_samp = np.concatenate(upid_samp) mah_params_samp = DEFAULT_MAH_PARAMS._make(mah_params_samp.T) - ms_params_samp = DEFAULT_DIFFSTAR_PARAMS.ms_params._make(ms_params_samp.T) - q_params_samp = DEFAULT_DIFFSTAR_PARAMS.q_params._make(q_params_samp.T) out = ( logmh_id, logmh_val, @@ -223,8 +209,7 @@ def create_target_data( log_smh_table, log_ssfrh_table, mah_params, - ms_params, - q_params, + sfh_params, is_cen, has_fit, ) = _res @@ -283,8 +268,7 @@ def create_target_data( log_mah_table[:, tid], log_smh_table[:, tid], mah_params, - ms_params, - q_params, + sfh_params, final_upid, ) data.append( @@ -328,8 +312,8 @@ def concatenate_samples_haloes(data): logmh_id.append(subdata[0]) logmh_val.append(subdata[1]) mah_params_samp.append(np.array(subdata[2]).T) - ms_params_samp.append(np.array(subdata[3]).T) - q_params_samp.append(np.array(subdata[4]).T) + ms_params_samp.append(subdata[3]) + q_params_samp.append(subdata[4]) upid_samp.append(subdata[5]) tobs_id.append(subdata[6]) tobs_val.append(subdata[7]) @@ -346,8 +330,6 @@ def concatenate_samples_haloes(data): redshift_val = np.concatenate(redshift_val) mah_params_samp = DEFAULT_MAH_PARAMS._make(mah_params_samp.T) - ms_params_samp = DEFAULT_DIFFSTAR_PARAMS.ms_params._make(ms_params_samp.T) - q_params_samp = DEFAULT_DIFFSTAR_PARAMS.q_params._make(q_params_samp.T) haloes = ( logmh_id, @@ -436,8 +418,7 @@ def create_pdf_target_data( log_smh_table, log_ssfrh_table, mah_params, - ms_params, - q_params, + sfh_params, is_cen, has_fit, ) = _res diff --git a/scripts/diffstarpop_scripts/smhm_utils_smdpl_mgash.py b/scripts/diffstarpop_scripts/smhm_utils_smdpl_mgash.py index aa4c36f..2ec5211 100644 --- a/scripts/diffstarpop_scripts/smhm_utils_smdpl_mgash.py +++ b/scripts/diffstarpop_scripts/smhm_utils_smdpl_mgash.py @@ -6,8 +6,8 @@ import numpy as np from diffmah.diffmah_kernels import DEFAULT_MAH_PARAMS, mah_halopop from diffsky.diffndhist import tw_ndhist_weighted -from diffstar.defaults_mgash_model import DEFAULT_DIFFSTAR_PARAMS, LGT0, T_TABLE_MIN -from diffstar.sfh_model_mgash import calc_sfh_galpop +from diffstar.defaults import DEFAULT_DIFFSTAR_PARAMS, LGT0, T_TABLE_MIN +from diffstar.sfh_model import calc_sfh_galpop from scipy.stats import binned_statistic from astropy.cosmology import Planck13 from umachine_pyio.load_mock import load_mock_from_binaries @@ -145,19 +145,9 @@ def load_diffstar_sfh_tables( [diffmah_data[key][has_fit] for key in DEFAULT_MAH_PARAMS._fields] ) - ms_params = DEFAULT_DIFFSTAR_PARAMS.ms_params._make( - [ - diffstar_data[key][has_fit] - for key in DEFAULT_DIFFSTAR_PARAMS.ms_params._fields - ] + sfh_params = DEFAULT_DIFFSTAR_PARAMS._make( + [diffstar_data[key][has_fit] for key in DEFAULT_DIFFSTAR_PARAMS._fields] ) - q_params = DEFAULT_DIFFSTAR_PARAMS.q_params._make( - [ - diffstar_data[key][has_fit] - for key in DEFAULT_DIFFSTAR_PARAMS.q_params._fields - ] - ) - sfh_params = DEFAULT_DIFFSTAR_PARAMS._make((ms_params, q_params)) t_0 = 10**lgt0 t_table = np.linspace(T_TABLE_MIN, t_0, n_times) @@ -177,8 +167,7 @@ def load_diffstar_sfh_tables( log_smh_table, log_ssfrh_table, mah_params, - ms_params, - q_params, + sfh_params, has_fit, ) @@ -274,8 +263,7 @@ def sample_halos( log_mah, log_smh, mah_params, - ms_params, - q_params, + sfh_params, upid, ): ndbins_lo = logmh_bins[:-1] @@ -289,8 +277,9 @@ def sample_halos( upid_samp = [] mah_params = np.array(mah_params).T - ms_params = np.array(ms_params).T - q_params = np.array(q_params).T + sfh_params = np.array(sfh_params).T + ms_params = sfh_params[:, :4] + q_params = sfh_params[:, 4:] n_halos_per_subvol = N_HALOS_MAX // n_subvol_smdpl @@ -313,8 +302,6 @@ def sample_halos( upid_samp = np.concatenate(upid_samp) mah_params_samp = DEFAULT_MAH_PARAMS._make(mah_params_samp.T) - ms_params_samp = DEFAULT_DIFFSTAR_PARAMS.ms_params._make(ms_params_samp.T) - q_params_samp = DEFAULT_DIFFSTAR_PARAMS.q_params._make(q_params_samp.T) out = ( logmh_id, logmh_val, @@ -354,8 +341,7 @@ def create_target_data( log_smh_table, log_ssfrh_table, mah_params, - ms_params, - q_params, + sfh_params, has_fit, ) = _res @@ -411,8 +397,7 @@ def create_target_data( log_mah_table[:, tid], log_smh_table[:, tid], mah_params, - ms_params, - q_params, + sfh_params, upid, ) data.append( @@ -456,8 +441,8 @@ def concatenate_samples_haloes(data): logmh_id.append(subdata[0]) logmh_val.append(subdata[1]) mah_params_samp.append(np.array(subdata[2]).T) - ms_params_samp.append(np.array(subdata[3]).T) - q_params_samp.append(np.array(subdata[4]).T) + ms_params_samp.append(subdata[3]) + q_params_samp.append(subdata[4]) upid_samp.append(subdata[5]) tobs_id.append(subdata[6]) tobs_val.append(subdata[7]) @@ -474,8 +459,6 @@ def concatenate_samples_haloes(data): redshift_val = np.concatenate(redshift_val) mah_params_samp = DEFAULT_MAH_PARAMS._make(mah_params_samp.T) - ms_params_samp = DEFAULT_DIFFSTAR_PARAMS.ms_params._make(ms_params_samp.T) - q_params_samp = DEFAULT_DIFFSTAR_PARAMS.q_params._make(q_params_samp.T) haloes = ( logmh_id, @@ -571,8 +554,7 @@ def create_pdf_target_data( log_smh_table, log_ssfrh_table, mah_params, - ms_params, - q_params, + sfh_params, has_fit, ) = _res diff --git a/scripts/diffstarpop_scripts/smhm_utils_tng_mgash.py b/scripts/diffstarpop_scripts/smhm_utils_tng_mgash.py index eba419f..d7712b5 100644 --- a/scripts/diffstarpop_scripts/smhm_utils_tng_mgash.py +++ b/scripts/diffstarpop_scripts/smhm_utils_tng_mgash.py @@ -6,8 +6,8 @@ import numpy as np from diffmah.diffmah_kernels import DEFAULT_MAH_PARAMS, mah_halopop from diffsky.diffndhist import tw_ndhist_weighted -from diffstar.defaults_mgash_model import DEFAULT_DIFFSTAR_PARAMS, LGT0, T_TABLE_MIN -from diffstar.sfh_model_mgash import calc_sfh_galpop +from diffstar.defaults import DEFAULT_DIFFSTAR_PARAMS, LGT0, T_TABLE_MIN, FB +from diffstar.sfh_model import calc_sfh_galpop from scipy.stats import binned_statistic from astropy.cosmology import Planck13 from umachine_pyio.load_mock import load_mock_from_binaries @@ -97,19 +97,9 @@ def load_diffstar_sfh_tables( [diffmah_data[key][indx][has_fit] for key in DEFAULT_MAH_PARAMS._fields] ) - ms_params = DEFAULT_DIFFSTAR_PARAMS.ms_params._make( - [ - diffstar_data[key][indx][has_fit] - for key in DEFAULT_DIFFSTAR_PARAMS.ms_params._fields - ] + sfh_params = DEFAULT_DIFFSTAR_PARAMS._make( + [diffstar_data[key][indx][has_fit] for key in DEFAULT_DIFFSTAR_PARAMS._fields] ) - q_params = DEFAULT_DIFFSTAR_PARAMS.q_params._make( - [ - diffstar_data[key][indx][has_fit] - for key in DEFAULT_DIFFSTAR_PARAMS.q_params._fields - ] - ) - sfh_params = DEFAULT_DIFFSTAR_PARAMS._make((ms_params, q_params)) t_0 = 10**lgt0 t_table = np.linspace(T_TABLE_MIN, t_0, n_times) @@ -117,7 +107,7 @@ def load_diffstar_sfh_tables( __, log_mah_table = mah_halopop(mah_params, t_table, LGT0) sfh_table, smh_table = calc_sfh_galpop( - sfh_params, mah_params, t_table, lgt0=LGT0, return_smh=True + sfh_params, mah_params, t_table, lgt0=LGT0, fb=FB, return_smh=True ) log_sfh_table = np.log10(sfh_table) log_smh_table = np.log10(smh_table) @@ -129,8 +119,7 @@ def load_diffstar_sfh_tables( log_smh_table, log_ssfrh_table, mah_params, - ms_params, - q_params, + sfh_params, has_fit, ) @@ -219,8 +208,7 @@ def sample_halos( log_mah, log_smh, mah_params, - ms_params, - q_params, + sfh_params, upid, ): ndbins_lo = logmh_bins[:-1] @@ -234,8 +222,9 @@ def sample_halos( upid_samp = [] mah_params = np.array(mah_params).T - ms_params = np.array(ms_params).T - q_params = np.array(q_params).T + sfh_params = np.array(sfh_params).T + ms_params = sfh_params[:, :4] + q_params = sfh_params[:, 4:] for i in range(len(ndbins_lo)): sel = (log_mah >= ndbins_lo[i]) & (log_mah < ndbins_hi[i]) @@ -256,8 +245,6 @@ def sample_halos( upid_samp = np.concatenate(upid_samp) mah_params_samp = DEFAULT_MAH_PARAMS._make(mah_params_samp.T) - ms_params_samp = DEFAULT_DIFFSTAR_PARAMS.ms_params._make(ms_params_samp.T) - q_params_samp = DEFAULT_DIFFSTAR_PARAMS.q_params._make(q_params_samp.T) out = ( logmh_id, logmh_val, @@ -289,8 +276,7 @@ def create_target_data( log_smh_table, log_ssfrh_table, mah_params, - ms_params, - q_params, + sfh_params, has_fit, ) = _res @@ -350,8 +336,7 @@ def create_target_data( log_mah_table[:, tid], log_smh_table[:, tid], mah_params, - ms_params, - q_params, + sfh_params, final_upid, ) data.append( @@ -395,8 +380,8 @@ def concatenate_samples_haloes(data): logmh_id.append(subdata[0]) logmh_val.append(subdata[1]) mah_params_samp.append(np.array(subdata[2]).T) - ms_params_samp.append(np.array(subdata[3]).T) - q_params_samp.append(np.array(subdata[4]).T) + ms_params_samp.append(subdata[3]) + q_params_samp.append(subdata[4]) upid_samp.append(subdata[5]) tobs_id.append(subdata[6]) tobs_val.append(subdata[7]) @@ -413,8 +398,6 @@ def concatenate_samples_haloes(data): redshift_val = np.concatenate(redshift_val) mah_params_samp = DEFAULT_MAH_PARAMS._make(mah_params_samp.T) - ms_params_samp = DEFAULT_DIFFSTAR_PARAMS.ms_params._make(ms_params_samp.T) - q_params_samp = DEFAULT_DIFFSTAR_PARAMS.q_params._make(q_params_samp.T) haloes = ( logmh_id, @@ -503,8 +486,7 @@ def create_pdf_target_data( log_smh_table, log_ssfrh_table, mah_params, - ms_params, - q_params, + sfh_params, has_fit, ) = _res