From eedc7dc34ed911fe7fe386cf47a84777a7e157a5 Mon Sep 17 00:00:00 2001 From: Kumail Zaidi Date: Sun, 31 May 2026 18:54:58 -0500 Subject: [PATCH] don't clip grads --- diffhtwo/experimental/data_loaders/load_feniks.py | 2 +- diffhtwo/experimental/defaults.py | 1 - diffhtwo/experimental/optimizers/Np_specphot_opt.py | 10 +++++----- scripts/config_diagnostics.yaml | 6 +++--- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/diffhtwo/experimental/data_loaders/load_feniks.py b/diffhtwo/experimental/data_loaders/load_feniks.py index 6dd0ad94..d632a6c4 100644 --- a/diffhtwo/experimental/data_loaders/load_feniks.py +++ b/diffhtwo/experimental/data_loaders/load_feniks.py @@ -121,7 +121,7 @@ def get_lh_centroids(dataset, lh_d_mag): lh_centroids[:, -1] < (FENIKS_Z_MAX - (LH_D_Z / 2)) ) k_mask = lh_centroids[:, -2] < FENIKS_MAGK_THRESH - u_mask = lh_centroids[:, -3] < 25 + u_mask = lh_centroids[:, -3] < 24.9 lh_centroids = lh_centroids[redshift_mask & k_mask & u_mask] redshift_centers = [0.45, 0.95, 1.45, 1.95, 2.45, 2.95, 3.45, 3.95] diff --git a/diffhtwo/experimental/defaults.py b/diffhtwo/experimental/defaults.py index 5f0344a2..1c83c063 100644 --- a/diffhtwo/experimental/defaults.py +++ b/diffhtwo/experimental/defaults.py @@ -21,7 +21,6 @@ FENIKS_Z_MIN = 0.2 FENIKS_Z_MAX = 3.0 FENIKS_MAGK_THRESH = 24.3 # col mag -FENIKS_MAGOTHER_THRESH = 27.0 SDSS_AREA_DEG2 = 7199 SDSS_Z_MIN = 0.02 diff --git a/diffhtwo/experimental/optimizers/Np_specphot_opt.py b/diffhtwo/experimental/optimizers/Np_specphot_opt.py index ff6a3efd..2179de48 100644 --- a/diffhtwo/experimental/optimizers/Np_specphot_opt.py +++ b/diffhtwo/experimental/optimizers/Np_specphot_opt.py @@ -13,7 +13,7 @@ from jax import lax, value_and_grad, vmap from jax.example_libraries import optimizers as jax_opt -from ..loss_kernels.emline_loss import _loss_emline_kern_multi_line_multi_z +# from ..loss_kernels.emline_loss import _loss_emline_kern_multi_line_multi_z from ..loss_kernels.phot_loss import _loss_phot_kern _L_pk = ( @@ -58,10 +58,10 @@ def _opt_update(opt_state, i): ) # clip gradients - global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in grads)) - tau = 1.0 - scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) - grads = tuple(g * scale for g in grads) + # global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in grads)) + # tau = 1.0 + # scale = jnp.minimum(1.0, tau / (global_norm + 1e-6)) + # grads = tuple(g * scale for g in grads) opt_state = opt_update(i, grads, opt_state) return opt_state, loss diff --git a/scripts/config_diagnostics.yaml b/scripts/config_diagnostics.yaml index 7476fbd0..8f26e482 100644 --- a/scripts/config_diagnostics.yaml +++ b/scripts/config_diagnostics.yaml @@ -1,6 +1,6 @@ -model_drn: /Users/kumail/diffdir/fits/run121 -model_nickname: run121_diffstarpop+spspop+merging -fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run121/diagnostic_plots/diffstarpop+spspop+merging +model_drn: /Users/kumail/diffdir/fits/run122 +model_nickname: run122_all +fit_diagnostics_save_drn: /Users/kumail/diffdir/fits/run122/diagnostic_plots/all feniks_drn: /Users/kumail/diffdir/feniks sdss_drn: /Users/kumail/diffdir/sdss