diff --git a/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py b/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py index 2a66628..df46a41 100755 --- a/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py +++ b/src/regrid_wrapper/app/chem_regrid/chem_regrid_impl.py @@ -12,6 +12,7 @@ import esmpy import numpy as np +import xarray as xr import pandas as pd from pydantic import BaseModel @@ -38,7 +39,8 @@ # to avoid setting zeroes when a particular hour file is missing. def find_latest_rave_file(input_dir, target_time_str, ebb_dcycle, max_lookback_hours=24): """Return list of files for the latest time <= target_time_str.""" - fmt = "%Y%m%d%H" + fmt = "%Y%m%d%H" #RAVE + fmt2= "%Y%j%H" # GOES target_time = datetime.strptime(target_time_str, fmt) for h in range(max_lookback_hours + 1): @@ -50,13 +52,16 @@ def find_latest_rave_file(input_dir, target_time_str, ebb_dcycle, max_lookback_h _LOGGER.warning("unrecognized ebb_dcycle, reverting to same-day, ebb_dcycle = 1") this_time = target_time + timedelta(hours=h) - this_str = this_time.strftime(fmt) - paths = glob.glob(str(input_dir) + "/RAVE-HrlyEmiss-3km_v2r0_blend_s"+this_str+"*") + if dataset_name == "RAVE": + this_str = this_time.strftime(fmt) + paths = glob.glob(input_dir + "/RAVE-HrlyEmiss-3km_v2r0_blend_s"+this_str+"*") + elif dataset_name == "GOES": + this_str = this_time.strftime(fmt2) + paths = glob.glob(input_dir + "/OR_ABI-L2-AODC-M6_G18_s"+this_str+"*") if paths: if h > 0: - print(f"Missing RAVE file for {target_time_str}, using {this_str} instead") + print(f"Missing {dataset_name} file for {target_time_str}, using {this_str} instead") return paths - # nothing found within lookback window return [] # @@ -170,7 +175,7 @@ def reshape_field_data(self, target: np.ndarray) -> np.ndarray: ... -class RaveField1d(AbstractRaveField): +class RaveField2d(AbstractRaveField): def create_dimension_collection( self, ncells_bounds: tuple[int, int] @@ -183,7 +188,7 @@ def reshape_field_data(self, target: np.ndarray) -> np.ndarray: return target.reshape(-1) -class RaveField2d(AbstractRaveField): +class RaveField2d_plusTime(AbstractRaveField): def create_dimension_collection( self, ncells_bounds: tuple[int, int] @@ -193,8 +198,7 @@ def create_dimension_collection( ) def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - return target.reshape(1, -1) - + return target.reshape(self.time_size, -1) class RaveField3d(AbstractRaveField): @@ -203,33 +207,15 @@ def create_dimension_collection( ) -> DimensionCollection: return DimensionCollection( value=( - self.time_dimension, self.create_ncells_dimension(ncells_bounds), self.nklevel_dimension, ) ) def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - return target.reshape(1, -1, 1) - - -class RaveField2d_plusTime(AbstractRaveField): - - def create_dimension_collection( - self, ncells_bounds: tuple[int, int] - ) -> DimensionCollection: - return DimensionCollection( - value=( - self.create_ncells_dimension(ncells_bounds), - self.time_dimension, - ) - ) - - def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - return target.reshape(-1, 12) + return target.reshape(-1, self.level_out_size) - -class RaveField4d(AbstractRaveField): +class RaveField3d_plusTime(AbstractRaveField): def create_dimension_collection( self, ncells_bounds: tuple[int, int] @@ -241,10 +227,8 @@ def create_dimension_collection( self.time_dimension, ) ) - def reshape_field_data(self, target: np.ndarray) -> np.ndarray: - return target.reshape(-1, self.level_out_size,self.time_size) - + return target.reshape(-1, self.level_out_size, self.time_size) class RaveToMpasRegridContext(BaseModel): dataset_name: str @@ -300,24 +284,16 @@ def rave_fields(self) -> tuple[AbstractRaveField, ...]: "time_size": self.time_size, "num_cells": self.num_cells, } - if field_name in ("clayfrac", "sandfrac", "uthres", "ssm"): - app = RaveField1d.model_validate(init_data) - elif field_name in ("FRE", "FRP_MEAN", "RWC_denominator", "ecoregion_ID", "10h_dead_fuel_moisture_content"): - app = RaveField2d.model_validate(init_data) - elif field_name in ("DBL_POLL", "ENL_POLL", "GRA_POLL", "RAG_POLL"): - app = RaveField3d.model_validate(init_data) - elif self.dataset_name == 'NEMO_RWC' and field_name in ("PEC", "POC", "PMOTHR", "PMC"): - app = RaveField3d.model_validate(init_data) - elif self.dataset_name == 'NEMO_ANTHRO' and field_name in ("PEC", "POC", "PMOTHR", "PMC"): - app = RaveField4d.model_validate(init_data) - elif self.dataset_name == 'RAVE' and field_name in ("PM25", "NH3", "SO2", "TPM", "NOx", "CH4","CO"): - app = RaveField3d.model_validate(init_data) - elif field_name in ("rdrag",): - app = RaveField2d_plusTime.model_validate(init_data) - elif self.dataset_name == 'GRA2PES' and field_name in ("HC01", "PM25-PRI", "PM10-PRI", "h_agl","SO2","NH3","NOX","CO"): - app = RaveField4d.model_validate(init_data) + if self.level_out_size == 0: + if self.time_size == 0: + app = RaveField2d.model_validate(init_data) + else: + app = RaveField2d_plusTime.model_validate(init_data) else: - raise NotImplementedError(field_name) + if self.time_size == 0: + app = RaveField3d.model_validate(init_data) + else: + app = RaveField3d_plusTime.model_validate(init_data) rave_fields.append(app) _LOGGER.debug(f"{rave_fields=}") return tuple(rave_fields) @@ -362,10 +338,15 @@ def initialize(self) -> None: # ) # mpas_desc.to_scrip(str(self.context.scrip_path)) +# JLS - temporary fix for coords not in file + if self.context.dataset_name == "GOES": + pathsrc=workdir+"/goes19_abi_conus_interpolated_lat_lon.nc" + else: + pathsrc=self.context.src_path _LOGGER.info("create source grid") if self.context.x_corner_dim is None: self._src_gwrap = NcToGrid( - path=self.context.src_path, + path=pathsrc, spec=GridSpec( x_center=self.context.x_center, y_center=self.context.y_center, @@ -379,7 +360,7 @@ def initialize(self) -> None: ).create_grid_wrapper() else: self._src_gwrap = NcToGrid( - path=self.context.src_path, + path=pathsrc, spec=GridSpec( x_center=self.context.x_center, y_center=self.context.y_center, @@ -406,24 +387,32 @@ def initialize(self) -> None: dst_mesh = self._dst_mesh # Check for extra dims beyond lat/lon - if self.context.level_out_size > 1 and self.context.time_size > 1: - self._dst_field = esmpy.Field( - dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.level_out_size, self.context.time_size) - ) - elif self.context.level_out_size > 1 and self.context.time_size == 1: - self._dst_field = esmpy.Field( - dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.level_out_size,) - ) - elif self.context.level_out_size == 1 and self.context.time_size > 1: - self._dst_field = esmpy.Field( - dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.time_size,) - ) + _LOGGER.info("create destination field") + if self.context.level_out_size == 0: + #2D + if self.context.time_size == 0: + # 2D, static in Time + self._dst_field = esmpy.Field( + dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, + ) + else: + # 2D + Time + self._dst_field = esmpy.Field( + dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.time_size,) + ) else: - _LOGGER.info("create destination field") - self._dst_field = esmpy.Field( - dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT - ) - + #3D + if self.context.time_size == 0: + # 3D, static in Time + self._dst_field = esmpy.Field( + dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.level_out_size,) + ) + else: + # 3D + Time + self._dst_field = esmpy.Field( + dst_mesh, name="dst", meshloc=esmpy.MeshLoc.ELEMENT, ndbounds=(self.context.level_out_size, self.context.time_size) + ) +# Check for weights _LOGGER.info("create regridder") if self.context.weight_path.exists(): _LOGGER.info("create regridder from file") @@ -535,8 +524,13 @@ def run(self) -> None: [dim.name[0] for dim in dims.value], fill_value=rave_field.fill_value, ) - for k, v in rave_field.attrs.items(): - setattr(var, k, v) +# Don't carry over fill value and datatype + if self.context.dataset_name != 'GOES': + type_to_use = rave_field.dtype + for k, v in rave_field.attrs.items(): + setattr(var, k, v) + else: + type_to_use = np.float32 _LOGGER.info(f"setting variable data {rave_field.name=}") # Multiply FRE/FRP by output area so it is back to W or J*s @@ -691,47 +685,49 @@ def create_desc_stuff(self, targets: Iterable[FileDesc]) -> pd.DataFrame: def create_src_field_wrapper(self, field_name: str) -> FieldWrapper: _LOGGER.info("create source field") - if self.context.dataset_name == "GRA2PES" and field_name in ("PM25-PRI", "PM10-PRI", "HC01", "h_agl","SO2","CO","NH3","NOX"): - if field_name in ("h_agl",): - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=(self.context.time_name,), - dim_level=('bottom_top_stag',), - ).create_field_wrapper() - else: - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=(self.context.time_name,), - dim_level=(self.context.level_in_name,), - ).create_field_wrapper() - elif self.context.dataset_name == "NEMO_ANTHRO": - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=(self.context.time_name,), - dim_level=(self.context.level_in_name,), - ).create_field_wrapper() - - elif field_name in ("clayfrac", "sandfrac", "uthres", "ssm"): - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=None, - dim_level=(self.context.level_in_name,), - ).create_field_wrapper() + if self.context.dataset_name == "GRA2PES" and field_name in ("h_agl",): # Special case for staggered grid + src_fwrap = NcToField( + path=self.context.src_path, + name=field_name, + gwrap=self.get_src_gwrap(), + dim_time=(self.context.time_name,), + dim_level=('bottom_top_stag',), + ).create_field_wrapper() + elif self.context.level_in_name == "None": + if self.context.time_name == "None": + src_fwrap = NcToField( + path=self.context.src_path, + name=field_name, + gwrap=self.get_src_gwrap(), + dim_time=None, + dim_level=None, + ).create_field_wrapper() + else: + src_fwrap = NcToField( + path=self.context.src_path, + name=field_name, + gwrap=self.get_src_gwrap(), + dim_time=(self.context.time_name,), + dim_level=None, + ).create_field_wrapper() else: - src_fwrap = NcToField( - path=self.context.src_path, - name=field_name, - gwrap=self.get_src_gwrap(), - dim_time=(self.context.time_name,), - ).create_field_wrapper() + if self.context.time_name == "None": + src_fwrap = NcToField( + path=self.context.src_path, + name=field_name, + gwrap=self.get_src_gwrap(), + dim_time=None, + dim_level=(self.context.level_in_name,), + ).create_field_wrapper() + else: + src_fwrap = NcToField( + path=self.context.src_path, + name=field_name, + gwrap=self.get_src_gwrap(), + dim_time=(self.context.time_name,), + dim_level=(self.context.level_in_name,), + ).create_field_wrapper() + # Get the area from the RAVE file, need to convert from /grid to /m2 if (self.context.dataset_name == "RAVE" and field_name in ("PM25", "NH3", "SO2", "FRE", "FRP_MEAN", "TPM", "CH4", "CO", "NOx")): area_fwrap = NcToField( @@ -1060,7 +1056,6 @@ def main(ctx: ChemRegridContext) -> None: time_name = "time" time_size = 1 InterpMethod = "CONSERVE" - elif dataset_name == "GRA2PES": field_names = ("PM25-PRI", "PM10-PRI","SO2","CO","NOX","NH3","h_agl") # ,"HC01"=methane BAQMS, summer, 2025 x_center = "XLONG" # "XLONG_M" @@ -1106,8 +1101,8 @@ def main(ctx: ChemRegridContext) -> None: x_corner_dim = "COLC" y_corner_dim = "ROWC" level_in_name = "None" - level_out_name = "nkreswoodcomb" - level_out_size = 1 + level_out_name = "None" + level_out_size = 0 time_name = "Time" time_size = 1 InterpMethod = "CONSERVE" @@ -1155,8 +1150,8 @@ def main(ctx: ChemRegridContext) -> None: x_corner_dim = None y_corner_dim = None level_in_name = "None" - level_out_name = "nkreswoodcomb" - level_out_size = 1 + level_out_name = "None" + level_out_size = 0 time_name = "Time" time_size = 1 InterpMethod = "BILINEAR" @@ -1171,8 +1166,8 @@ def main(ctx: ChemRegridContext) -> None: x_corner_dim = None y_corner_dim = None level_in_name = "None" - level_out_name = "nkemit" - level_out_size = 1 + level_out_name = "None" + level_out_size = 0 time_name = "time" time_size = 0 InterpMethod = "BILINEAR" @@ -1187,8 +1182,8 @@ def main(ctx: ChemRegridContext) -> None: x_corner_dim = None y_corner_dim = None level_in_name = "None" - level_out_name = "nkemit" - level_out_size = 1 + level_out_name = "None" + level_out_size = 0 time_name = "time" time_size = 12 InterpMethod = "BILINEAR" @@ -1213,6 +1208,33 @@ def main(ctx: ChemRegridContext) -> None: time_name = "time" time_size = 1 InterpMethod = "BILINEAR" + elif dataset_name == "GOES": + field_names = ("AOD",) + x_center = "longitude" + y_center = "latitude" + x_dim = "x" + y_dim = "y" + x_corner = None + y_corner = None + x_corner_dim = None + y_corner_dim = None + level_in_name = "None" + level_out_name = "None" + level_out_size = 0 + time_name = "None" + time_size = 0 + InterpMethod = "BILINEAR" + dates_needed = [] + for i in range(25): + if ebb_dcycle == 1: # Same-day emissions + x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) + timedelta(hours=i) + elif ebb_dcycle == -1 or ebb_dcycle == 2: # Persistence (-1) or forecasted (2) needs prev 24 hours + x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) - timedelta(hours=i) + else: + _LOGGER.info("EBB_DCYLE selection not recognized, reverting to same day, ebb_dcycle = 1") + x = datetime(int(YYYY), int(MM), int(DD), int(HH), 0, 0) - timedelta(hours=i) + y = x.strftime("%Y%m%d%H") + dates_needed.append(y) weight_path = ctx.get_weight_path(InterpMethod) @@ -1329,6 +1351,81 @@ def main(ctx: ChemRegridContext) -> None: _LOGGER.info("NGFS success") + elif dataset_name == "GOES": + processor = None + date_to_process = dates_needed[0] + rave_paths = find_latest_rave_file(input_dir, date_to_process, -1, dataset_name, max_lookback_hours=2) + files_to_cat = rave_paths + _LOGGER.info(f"will cat files: {files_to_cat=}") + if COMM.rank == 0: + ds = xr.open_mfdataset(files_to_cat, combine='nested', concat_dim='file') + # 2. Calculate the nanmean across the new 'file' dimension + # skipna=True (default) ensures it behaves like np.nanmean + ds_averaged = ds['AOD'].mean(dim='file', skipna=True) + print(ds_averaged) + ds_averaged.encoding.update({ + 'dtype': 'float32', + '_FillValue': -999 + }) + ds_averaged.to_netcdf(Path(output_dir + '/test_goes_aod_merged.nc')) + + if not rave_paths: + _LOGGER.info(f"No matching GOES files found for {date_to_process} (even after lookback).") + raise ValueError + + print('Reading merged GOES file:', 'test_goes_aod_merged.nc') + #rave_path = rave_paths[0] + rave_path = Path(output_dir + "/test_goes_aod_merged.nc") + new_dst_path = Path(output_dir + "/" + mesh_name + "-GOES-" + date_to_process + ".nc") + # --- OPTIMIZATION START --- + if processor is None: + # FIRST PASS: Full Initialization + # This pays the "expensive" cost of loading weights/grids, but only once. + + context = RaveToMpasRegridContext( + dataset_name=dataset_name, + src_path=rave_path, + dst_path=dst_path, + new_dst_path=new_dst_path, + desc_stats_out=desc_stats_out, + weight_path=weight_path, + InterpMethod=InterpMethod, + scrip_path=scrip_path, + num_cells=num_cells, + mesh_name=mesh_name, + field_names=field_names, + x_center=x_center, + y_center=y_center, + x_dim=x_dim, + y_dim=y_dim, + x_corner=x_corner, + y_corner=y_corner, + x_corner_dim=x_corner_dim, + y_corner_dim=y_corner_dim, + level_in_name=level_in_name, + level_out_name=level_out_name, + level_out_size=level_out_size, + time_name=time_name, + time_size=time_size + + ) + processor = RaveToMpasRegridProcessor(context=context) + processor.initialize() + else: + # SUBSEQUENT PASSES: Hot Swap + # Just update the paths in the existing context. + # The grids and regridder (weights) remain loaded in memory. + processor.context.src_path = rave_path + processor.context.new_dst_path = new_dst_path + # Run the regridding (Fast) + processor.run() + # --- OPTIMIZATION END --- + # Only finalize after ALL files are done + if processor: + processor.finalize() + + _LOGGER.info("success") + elif dataset_name == "FMC": for date_to_process in dates_needed: rave_paths = glob.glob(input_dir + "fmc_" + date_to_process + ".nc") diff --git a/src/regrid_wrapper/app/chem_regrid/context.py b/src/regrid_wrapper/app/chem_regrid/context.py index fbb1337..538068b 100644 --- a/src/regrid_wrapper/app/chem_regrid/context.py +++ b/src/regrid_wrapper/app/chem_regrid/context.py @@ -37,6 +37,7 @@ class DatasetName(StrEnum): FENGSHA_2D = "FENGSHA_2D" FENGSHA_2D_Time = "FENGSHA_2D_Time" NGFS = "NGFS" + GOES = "GOES" class ChemRegridContext(RwBaseModel):