From a2fabd069590f118410fd8985cc83b6b8426f8c8 Mon Sep 17 00:00:00 2001 From: Andrew Hearin Date: Thu, 30 Apr 2026 09:24:04 -0500 Subject: [PATCH] Add halopop.central field. Set nhalos_host=1 for host halos --- .../lightcone_generators/mc_lightcone.py | 19 +++++++++++++---- .../tests/test_mc_lightcone.py | 21 +++++++++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/diffhalos/lightcone_generators/mc_lightcone.py b/diffhalos/lightcone_generators/mc_lightcone.py index 4193f83..2c176fc 100644 --- a/diffhalos/lightcone_generators/mc_lightcone.py +++ b/diffhalos/lightcone_generators/mc_lightcone.py @@ -413,9 +413,12 @@ def weighted_lc( the host halos should be weighted by nhalos, but subhalos should be weighted by nhalos*nhalos_host. + central : ndarray of shape (n_halos_tot, ) + Integer equals 1 for central halos and 0 for subhalos + nhalos_host: ndarray of shape (n_halos_tot, ) Multiplicity factor of the host halo - Equals nhalos for central halos + Equals 1 for central halos For subhalos, halopop.nhalos_host = halopop.nhalos[halopop.halo_indx] nsub_per_host: int @@ -513,6 +516,7 @@ def _weighted_lc_from_grid( host_indx = jnp.arange(n_host).astype(int) subhalo_indx = jnp.repeat(host_indx, subpop.nsub_per_host) halo_indx = jnp.concatenate((host_indx, subhalo_indx)).astype(int) + central = jnp.concatenate((jnp.ones(n_host), jnp.zeros(n_sub))).astype(int) z_obs_subs = jnp.repeat(cenpop.z_obs, subpop.nsub_per_host) z_obs_all = jnp.concatenate((cenpop.z_obs, z_obs_subs)) @@ -523,7 +527,7 @@ def _weighted_lc_from_grid( cenpop = cenpop._replace(t_obs=t_obs_all) nhalos_host_subs = jnp.repeat(cenpop.nhalos, subpop.nsub_per_host) - nhalos_host_all = jnp.concatenate((cenpop.nhalos, nhalos_host_subs)) + nhalos_host_all = jnp.concatenate((jnp.ones(n_host), nhalos_host_subs)) logmp_obs_all = jnp.concatenate((cenpop.logmp_obs, subpop.logmp_obs)) cenpop = cenpop._replace(logmp_obs=logmp_obs_all) @@ -559,7 +563,14 @@ def _weighted_lc_from_grid( # the subhalo information and some fields are updated to new shapes halopop = namedtuple( "weighted_lc", - [*cenpop._fields, "nhalos_host", "nsub_per_host", "logmu_obs", "halo_indx"], - )(*cenpop, nhalos_host_all, subpop.nsub_per_host, logmu_obs_all, halo_indx) + [ + *cenpop._fields, + "central", + "nhalos_host", + "nsub_per_host", + "logmu_obs", + "halo_indx", + ], + )(*cenpop, central, nhalos_host_all, subpop.nsub_per_host, logmu_obs_all, halo_indx) return halopop diff --git a/diffhalos/lightcone_generators/tests/test_mc_lightcone.py b/diffhalos/lightcone_generators/tests/test_mc_lightcone.py index 8ebb8d1..e84beb2 100644 --- a/diffhalos/lightcone_generators/tests/test_mc_lightcone.py +++ b/diffhalos/lightcone_generators/tests/test_mc_lightcone.py @@ -253,3 +253,24 @@ def test_weighted_lc_tpeak_clip(): # satellites: logmp_obs == logmp0 assert np.allclose(logmp0_subs, logmsub_obs) + + +def test_weighted_lc_nhalos_host(): + ran_key = jran.key(0) + + n_host_halos = 100 + z_min, z_max = 0.1, 3.1 + sky_area_degsq = 10.0 + lgmp_min, lgmp_max = 10.0, 15.0 + args = (ran_key, n_host_halos, z_min, z_max, lgmp_min, lgmp_max, sky_area_degsq) + halopop = mclc.weighted_lc(*args) + + assert np.allclose(halopop.central[:n_host_halos], 1) + assert np.allclose(halopop.central[n_host_halos:], 0) + + assert np.allclose(halopop.nhalos_host[:n_host_halos], 1) + + assert np.allclose( + halopop.nhalos_host[n_host_halos:], + halopop.nhalos[halopop.halo_indx][n_host_halos:], + )