From bd8611f6f61c31fca3d24641eb300d6a3acc281e Mon Sep 17 00:00:00 2001 From: jordanschnell Date: Mon, 13 Apr 2026 19:55:13 +0000 Subject: [PATCH 1/3] add goes interpolation; \n clean up dimensions selections/remove dependency on dataset --- .../app/chem_regrid/chem_regrid_impl.py | 343 ++++++++++++------ 1 file changed, 224 insertions(+), 119 deletions(-) diff --git a/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py b/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py index 2a66628..ef4c7b2 100755 --- a/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py +++ b/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py @@ -12,6 +12,7 @@ import esmpy import numpy as np +import xarray as xr import pandas as pd from pydantic import BaseModel @@ -38,7 +39,8 @@ # to avoid setting zeroes when a particular hour file is missing. def find_latest_rave_file(input_dir, target_time_str, ebb_dcycle, max_lookback_hours=24): """Return list of files for the latest time <= target_time_str.""" - fmt = "%Y%m%d%H" + fmt = "%Y%m%d%H" #RAVE + fmt2= "%Y%j%H" # GOES target_time = datetime.strptime(target_time_str, fmt) for h in range(max_lookback_hours + 1): @@ -50,13 +52,16 @@ def find_latest_rave_file(input_dir, target_time_str, ebb_dcycle, max_lookback_h _LOGGER.warning("unrecognized ebb_dcycle, reverting to same-day, ebb_dcycle = 1") this_time = target_time + timedelta(hours=h) - this_str = this_time.strftime(fmt) - paths = glob.glob(str(input_dir) + "/RAVE-HrlyEmiss-3km_v2r0_blend_s"+this_str+"*") + if dataset_name == "RAVE": + this_str = this_time.strftime(fmt) + paths = glob.glob(input_dir + "/RAVE-HrlyEmiss-3km_v2r0_blend_s"+this_str+"*") + elif dataset_name == "GOES": + this_str = this_time.strftime(fmt2) + paths = glob.glob(input_dir + "/OR_ABI-L2-AODC-M6_G18_s"+this_str+"*") if paths: if h > 0: - print(f"Missing RAVE file for {target_time_str}, using {this_str} instead") + print(f"Missing {dataset_name} file for {target_time_str}, using {this_str} instead") return paths - # nothing found within lookback window return [] # @@ -170,7 +175,7 @@ def reshape_field_data(self, target: np.ndarray) -> np.ndarray: ... -class RaveField1d(AbstractRaveField): +class RaveField2d(AbstractRaveField): def create_dimension_collection( self, ncells_bounds: tuple[int, int] @@ -183,7 +188,7 @@ def reshape_field_data(self, target: np.ndarray) -> np.ndarray: return target.reshape(-1) -class RaveField2d(AbstractRaveField): +class RaveField2d_plusTime(AbstractRaveField): def create_dimension_collection( self, ncells_bounds: tuple[int, int] @@ -193,8 +198,7 @@ def create_dimension_collection( ) def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - return target.reshape(1, -1) - + return target.reshape(self.time_size, -1) class RaveField3d(AbstractRaveField): @@ -203,33 +207,15 @@ def create_dimension_collection( ) -> DimensionCollection: return DimensionCollection( value=( - self.time_dimension, self.create_ncells_dimension(ncells_bounds), self.nklevel_dimension, ) ) def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - return target.reshape(1, -1, 1) - - -class RaveField2d_plusTime(AbstractRaveField): - - def create_dimension_collection( - self, ncells_bounds: tuple[int, int] - ) -> DimensionCollection: - return DimensionCollection( - value=( - self.create_ncells_dimension(ncells_bounds), - self.time_dimension, - ) - ) - - def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - return target.reshape(-1, 12) + return target.reshape(-1, self.level_out_size) - -class RaveField4d(AbstractRaveField): +class RaveField3d_plusTime(AbstractRaveField): def create_dimension_collection( self, ncells_bounds: tuple[int, int] @@ -241,10 +227,8 @@ def create_dimension_collection( self.time_dimension, ) ) - def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - return target.reshape(-1, self.level_out_size,self.time_size) - + return target.reshape(-1, self.level_out_size, self.time_size) class RaveToMpasRegridContext(BaseModel): dataset_name: str @@ -300,24 +284,16 @@ def rave_fields(self) -> tuple[AbstractRaveField, ...]: "time_size": self.time_size, "num_cells": self.num_cells, } - if field_name in ("clayfrac", "sandfrac", "uthres", "ssm"): - app = RaveField1d.model_validate(init_data) - elif field_name in ("FRE", "FRP_MEAN", "RWC_denominator", "ecoregion_ID", "10h_dead_fuel_moisture_content"): - app = RaveField2d.model_validate(init_data) - elif field_name in ("DBL_POLL", "ENL_POLL", "GRA_POLL", "RAG_POLL"): - app = RaveField3d.model_validate(init_data) - elif self.dataset_name == 'NEMO_RWC' and field_name in ("PEC", "POC", "PMOTHR", "PMC"): - app = RaveField3d.model_validate(init_data) - elif self.dataset_name == 'NEMO_ANTHRO' and field_name in ("PEC", "POC", "PMOTHR", "PMC"): - app = RaveField4d.model_validate(init_data) - elif self.dataset_name == 'RAVE' and field_name in ("PM25", "NH3", "SO2", "TPM", "NOx", "CH4","CO"): - app = RaveField3d.model_validate(init_data) - elif field_name in ("rdrag",): - app = RaveField2d_plusTime.model_validate(init_data) - elif self.dataset_name == 'GRA2PES' and field_name in ("HC01", "PM25-PRI", "PM10-PRI", "h_agl","SO2","NH3","NOX","CO"): - app = RaveField4d.model_validate(init_data) + if self.level_out_size == 0: + if self.time_size == 0: + app = RaveField2d.model_validate(init_data) + else: + app = RaveField2d_plusTime.model_validate(init_data) else: - raise NotImplementedError(field_name) + if self.time_size == 0: + app = RaveField3d.model_validate(init_data) + else: + app = RaveField3d_plusTime.model_validate(init_data) rave_fields.append(app) _LOGGER.debug(f"{rave_fields=}") return tuple(rave_fields) @@ -362,10 +338,15 @@ def initialize(self) -> None: # ) # mpas_desc.to_scrip(str(self.context.scrip_path)) +# JLS - temporary fix for coords not in file + if self.context.dataset_name == "GOES": + pathsrc="/scratch4/BMC/acomp/cheMPAS-Fire/input/grids/domain_latlons/goes19_abi_conus_interpolated_lat_lon.nc" + else: + pathsrc=self.context.src_path _LOGGER.info("create source grid") if self.context.x_corner_dim is None: self._src_gwrap = NcToGrid( - path=self.context.src_path, + path=pathsrc, spec=GridSpec( x_center=self.context.x_center, y_center=self.context.y_center, @@ -379,7 +360,7 @@ def initialize(self) -> None: ).create_grid_wrapper() else: self._src_gwrap = NcToGrid( - path=self.context.src_path, + path=pathsrc, spec=GridSpec( x_center=self.context.x_center, y_center=self.context.y_center, @@ -406,24 +387,32 @@ def initialize(self) -> None: dst_mesh = self._dst_mesh # Check for extra dims beyond lat/lon - if self.context.level_out_size > 1 and self.context.time_size > 1: - self._dst_field = esmpy.Field( - dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.level_out_size, self.context.time_size) - ) - elif self.context.level_out_size > 1 and self.context.time_size == 1: - self._dst_field = esmpy.Field( - dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.level_out_size,) - ) - elif self.context.level_out_size == 1 and self.context.time_size > 1: - self._dst_field = esmpy.Field( - dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.time_size,) - ) + _LOGGER.info("create destination field") + if self.context.level_out_size == 0: + #2D + if self.context.time_size == 0: + # 2D, static in Time + self._dst_field = esmpy.Field( + dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, + ) + else: + # 2D + Time + self._dst_field = esmpy.Field( + dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.time_size,) + ) else: - _LOGGER.info("create destination field") - self._dst_field = esmpy.Field( - dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT - ) - + #3D + if self.context.time_size == 0: + # 3D, static in Time + self._dst_field = esmpy.Field( + dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.level_out_size,) + ) + else: + # 3D + Time + self._dst_field = esmpy.Field( + dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.level_out_size, self.context.time_size) + ) +# Check for weights _LOGGER.info("create regridder") if self.context.weight_path.exists(): _LOGGER.info("create regridder from file") @@ -535,8 +524,13 @@ def run(self) -> None: [dim.name[0] for dim in dims.value], fill_value=rave_field.fill_value, ) - for k, v in rave_field.attrs.items(): - setattr(var, k, v) +# Don't carry over fill value and datatype + if self.context.dataset_name != 'GOES': + type_to_use = rave_field.dtype + for k, v in rave_field.attrs.items(): + setattr(var, k, v) + else: + type_to_use = np.float32 _LOGGER.info(f"setting variable data {rave_field.name=}") # Multiply FRE/FRP by output area so it is back to W or J*s @@ -691,47 +685,49 @@ def create_desc_stuff(self, targets: Iterable[FileDesc]) -> pd.DataFrame: def create_src_field_wrapper(self, field_name: str) -> FieldWrapper: _LOGGER.info("create source field") - if self.context.dataset_name == "GRA2PES" and field_name in ("PM25-PRI", "PM10-PRI", "HC01", "h_agl","SO2","CO","NH3","NOX"): - if field_name in ("h_agl",): - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=(self.context.time_name,), - dim_level=('bottom_top_stag',), - ).create_field_wrapper() - else: - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=(self.context.time_name,), - dim_level=(self.context.level_in_name,), - ).create_field_wrapper() - elif self.context.dataset_name == "NEMO_ANTHRO": - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=(self.context.time_name,), - dim_level=(self.context.level_in_name,), - ).create_field_wrapper() - - elif field_name in ("clayfrac", "sandfrac", "uthres", "ssm"): - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=None, - dim_level=(self.context.level_in_name,), - ).create_field_wrapper() + if self.context.dataset_name == "GRA2PES" and field_name in ("h_agl",): # Special case for staggered grid + src_fwrap = NcToField( + path=self.context.src_path, + name=field_name, + gwrap=self.get_src_gwrap(), + dim_time=(self.context.time_name,), + dim_level=('bottom_top_stag',), + ).create_field_wrapper() + elif self.context.level_in_name == "None": + if self.context.time_name == "None": + src_fwrap = NcToField( + path=self.context.src_path, + name=field_name, + gwrap=self.get_src_gwrap(), + dim_time=None, + dim_level=None, + ).create_field_wrapper() + else: + src_fwrap = NcToField( + path=self.context.src_path, + name=field_name, + gwrap=self.get_src_gwrap(), + dim_time=(self.context.time_name,), + dim_level=None, + ).create_field_wrapper() else: - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=(self.context.time_name,), - ).create_field_wrapper() + if self.context.time_name == "None": + src_fwrap = NcToField( + path=self.context.src_path, + name=field_name, + gwrap=self.get_src_gwrap(), + dim_time=None, + dim_level=(self.context.level_in_name,), + ).create_field_wrapper() + else: + src_fwrap = NcToField( + path=self.context.src_path, + name=field_name, + gwrap=self.get_src_gwrap(), + dim_time=(self.context.time_name,), + dim_level=(self.context.level_in_name,), + ).create_field_wrapper() + # Get the area from the RAVE file, need to convert from /grid to /m2 if (self.context.dataset_name == "RAVE" and field_name in ("PM25", "NH3", "SO2", "FRE", "FRP_MEAN", "TPM", "CH4", "CO", "NOx")): area_fwrap = NcToField( @@ -1060,7 +1056,6 @@ def main(ctx: ChemRegridContext) -> None: time_name = "time" time_size = 1 InterpMethod = "CONSERVE" - elif dataset_name == "GRA2PES": field_names = ("PM25-PRI", "PM10-PRI","SO2","CO","NOX","NH3","h_agl") # ,"HC01"=methane BAQMS, summer, 2025 x_center = "XLONG" # "XLONG_M" @@ -1106,8 +1101,8 @@ def main(ctx: ChemRegridContext) -> None: x_corner_dim = "COLC" y_corner_dim = "ROWC" level_in_name = "None" - level_out_name = "nkreswoodcomb" - level_out_size = 1 + level_out_name = "None" + level_out_size = 0 time_name = "Time" time_size = 1 InterpMethod = "CONSERVE" @@ -1155,8 +1150,8 @@ def main(ctx: ChemRegridContext) -> None: x_corner_dim = None y_corner_dim = None level_in_name = "None" - level_out_name = "nkreswoodcomb" - level_out_size = 1 + level_out_name = "None" + level_out_size = 0 time_name = "Time" time_size = 1 InterpMethod = "BILINEAR" @@ -1171,8 +1166,8 @@ def main(ctx: ChemRegridContext) -> None: x_corner_dim = None y_corner_dim = None level_in_name = "None" - level_out_name = "nkemit" - level_out_size = 1 + level_out_name = "None" + level_out_size = 0 time_name = "time" time_size = 0 InterpMethod = "BILINEAR" @@ -1187,8 +1182,8 @@ def main(ctx: ChemRegridContext) -> None: x_corner_dim = None y_corner_dim = None level_in_name = "None" - level_out_name = "nkemit" - level_out_size = 1 + level_out_name = "None" + level_out_size = 0 time_name = "time" time_size = 12 InterpMethod = "BILINEAR" @@ -1213,6 +1208,34 @@ def main(ctx: ChemRegridContext) -> None: time_name = "time" time_size = 1 InterpMethod = "BILINEAR" + InterpMethod = "CONSERVE" + elif dataset_name == "GOES": + field_names = ("AOD",) + x_center = "longitude" + y_center = "latitude" + x_dim = "x" + y_dim = "y" + x_corner = None + y_corner = None + x_corner_dim = None + y_corner_dim = None + level_in_name = "None" + level_out_name = "None" + level_out_size = 0 + time_name = "None" + time_size = 0 + InterpMethod = "BILINEAR" + dates_needed = [] + for i in range(25): + if ebb_dcycle == 1: # Same-day emissions + x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) + timedelta(hours=i) + elif ebb_dcycle == -1 or ebb_dcycle == 2: # Persistence (-1) or forecasted (2) needs prev 24 hours + x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) - timedelta(hours=i) + else: + _LOGGER.info("EBB_DCYLE selection not recognized, reverting to same day, ebb_dcycle = 1") + x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) - timedelta(hours=i) + y = x.strftime("%Y%m%d%H") + dates_needed.append(y) weight_path = ctx.get_weight_path(InterpMethod) @@ -1329,6 +1352,88 @@ def main(ctx: ChemRegridContext) -> None: _LOGGER.info("NGFS success") + elif dataset_name == "GOES": + processor = None + #files_to_cat = [] + #for date_to_process in dates_needed: + # print("date to process = ") + # print(date_to_process) + date_to_process = dates_needed[0] + rave_paths = find_latest_rave_file(input_dir, date_to_process, -1, dataset_name, max_lookback_hours=2) + # files_to_cat.append(rave_paths) + # Unique + files_to_cat = rave_paths + print("will cat files: ") + print(files_to_cat) + if COMM.rank == 0: + ds = xr.open_mfdataset(files_to_cat, combine='nested', concat_dim='file') + # 2. Calculate the nanmean across the new 'file' dimension + # skipna=True (default) ensures it behaves like np.nanmean + ds_averaged = ds['AOD'].mean(dim='file', skipna=True) + print(ds_averaged) + ds_averaged.encoding.update({ + 'dtype': 'float32', + '_FillValue': -999 + }) + ds_averaged.to_netcdf(Path(output_dir + '/test_goes_aod_merged.nc')) + + if not rave_paths: + print(f"No matching GOES files found for {date_to_process} (even after lookback).") + exit() + + print('Reading merged GOES file:', 'test_goes_aod_merged.nc') + #rave_path = rave_paths[0] + rave_path = Path(output_dir + "/test_goes_aod_merged.nc") + new_dst_path = Path(output_dir + "/" + mesh_name + "-GOES-" + date_to_process + ".nc") + # --- OPTIMIZATION START --- + if processor is None: + # FIRST PASS: Full Initialization + # This pays the "expensive" cost of loading weights/grids, but only once. + + context = RaveToMpasRegridContext( + dataset_name=dataset_name, + src_path=rave_path, + dst_path=dst_path, + new_dst_path=new_dst_path, + desc_stats_out=desc_stats_out, + weight_path=weight_path, + InterpMethod=InterpMethod, + scrip_path=scrip_path, + num_cells=num_cells, + mesh_name=mesh_name, + field_names=field_names, + x_center=x_center, + y_center=y_center, + x_dim=x_dim, + y_dim=y_dim, + x_corner=x_corner, + y_corner=y_corner, + x_corner_dim=x_corner_dim, + y_corner_dim=y_corner_dim, + level_in_name=level_in_name, + level_out_name=level_out_name, + level_out_size=level_out_size, + time_name=time_name, + time_size=time_size + + ) + processor = RaveToMpasRegridProcessor(context=context) + processor.initialize() + else: + # SUBSEQUENT PASSES: Hot Swap + # Just update the paths in the existing context. + # The grids and regridder (weights) remain loaded in memory. + processor.context.src_path = rave_path + processor.context.new_dst_path = new_dst_path + # Run the regridding (Fast) + processor.run() + # --- OPTIMIZATION END --- + # Only finalize after ALL files are done + if processor: + processor.finalize() + + _LOGGER.info("success") + elif dataset_name == "FMC": for date_to_process in dates_needed: rave_paths = glob.glob(input_dir + "fmc_" + date_to_process + ".nc") From 142329a40fac4446a387f6392af30901f1e74640 Mon Sep 17 00:00:00 2001 From: jordanschnell Date: Mon, 13 Apr 2026 20:02:49 +0000 Subject: [PATCH 2/3] bug --- src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py b/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py index ef4c7b2..adf27dc 100755 --- a/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py +++ b/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py @@ -1208,7 +1208,6 @@ def main(ctx: ChemRegridContext) -> None: time_name = "time" time_size = 1 InterpMethod = "BILINEAR" - InterpMethod = "CONSERVE" elif dataset_name == "GOES": field_names = ("AOD",) x_center = "longitude" From 292fdee5764291890c1c29b155a44094b9202369 Mon Sep 17 00:00:00 2001 From: jordanschnell Date: Wed, 15 Apr 2026 13:30:27 +0000 Subject: [PATCH 3/3] update prints and hardcoded path on Ursa --- .../app/chem_regrid/chem_regrid_impl.py | 15 ++++----------- src/regrid_wrapper/app/chem_regrid/context.py | 1 + 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py b/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py index adf27dc..df46a41 100755 --- a/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py +++ b/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py @@ -340,7 +340,7 @@ def initialize(self) -> None: # JLS - temporary fix for coords not in file if self.context.dataset_name == "GOES": - pathsrc="/scratch4/BMC/acomp/cheMPAS-Fire/input/grids/domain_latlons/goes19_abi_conus_interpolated_lat_lon.nc" + pathsrc=workdir+"/goes19_abi_conus_interpolated_lat_lon.nc" else: pathsrc=self.context.src_path _LOGGER.info("create source grid") @@ -1353,17 +1353,10 @@ def main(ctx: ChemRegridContext) -> None: elif dataset_name == "GOES": processor = None - #files_to_cat = [] - #for date_to_process in dates_needed: - # print("date to process = ") - # print(date_to_process) date_to_process = dates_needed[0] rave_paths = find_latest_rave_file(input_dir, date_to_process, -1, dataset_name, max_lookback_hours=2) - # files_to_cat.append(rave_paths) - # Unique files_to_cat = rave_paths - print("will cat files: ") - print(files_to_cat) + _LOGGER.info(f"will cat files: {files_to_cat=}") if COMM.rank == 0: ds = xr.open_mfdataset(files_to_cat, combine='nested', concat_dim='file') # 2. Calculate the nanmean across the new 'file' dimension @@ -1377,8 +1370,8 @@ def main(ctx: ChemRegridContext) -> None: ds_averaged.to_netcdf(Path(output_dir + '/test_goes_aod_merged.nc')) if not rave_paths: - print(f"No matching GOES files found for {date_to_process} (even after lookback).") - exit() + _LOGGER.info(f"No matching GOES files found for {date_to_process} (even after lookback).") + raise ValueError print('Reading merged GOES file:', 'test_goes_aod_merged.nc') #rave_path = rave_paths[0] diff --git a/src/regrid_wrapper/app/chem_regrid/context.py b/src/regrid_wrapper/app/chem_regrid/context.py index fbb1337..538068b 100644 --- a/src/regrid_wrapper/app/chem_regrid/context.py +++ b/src/regrid_wrapper/app/chem_regrid/context.py @@ -37,6 +37,7 @@ class DatasetName(StrEnum): FENGSHA_2D = "FENGSHA_2D" FENGSHA_2D_Time = "FENGSHA_2D_Time" NGFS = "NGFS" + GOES = "GOES" class ChemRegridContext(RwBaseModel):