Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 54 additions & 56 deletions scripts/averaging/create_climo_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import numpy as np
import xarray as xr # module-level import so all functions can get to it.

import multiprocessing as mp

def get_time_slice_by_year(time, startyear, endyear):
"""
Expand Down Expand Up @@ -224,34 +223,35 @@ def create_climo_files(adf, clobber=False, search=None):
# end_diag_script(errmsg) # Previously we would kill the run here.
continue

list_of_arguments.append((adf, ts_files, syr, eyr, output_file))


#End of var_list loop
#--------------------
list_of_arguments.append((adf.user, ts_files, syr, eyr, output_file))

# Parallelize the computation using multiprocessing pool:
with mp.Pool(processes=number_of_cpu) as p:
result = p.starmap(process_variable, list_of_arguments)

#End of model case loop
#----------------------

#Notify user that script has ended:
print(f" --> Starting Pool with {number_of_cpu} workers for {len(list_of_arguments)} variables.")
import multiprocessing as mp
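# Use 'spawn' so each worker starts in a fresh interpreter instead of
# inheriting the parent's memory and open file handles. This is safer on HPC
# systems than 'fork' (the default start method on Linux before Python 3.14).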
context = mp.get_context('spawn')
with context.Pool(processes=number_of_cpu) as p:
results = p.starmap(process_variable, list_of_arguments)
# Print results to see if any specific variable failed
for res in results:
if "Failed" in res:
print(f"\t {res}")
print(" ... multiprocessing pool closed.")
print(" ...CAM climatologies have been calculated successfully.")


#
# Local functions
#
def process_variable(adf_user, ts_files, syr, eyr, output_file):
    '''
    Compute and save the monthly climatology file.

    Designed to run as a multiprocessing worker: every argument is a plain,
    picklable value, and failures are reported through the return value
    instead of being raised, so one bad variable does not abort the pool.

    Parameters
    ----------
    adf_user
        The user from the ADF object; stored in the output file's
        ``adf_user`` global attribute.
    ts_files : list
        list of paths to time series files
    syr : str
        first year of the climatology; converted with ``int()`` and passed to
        ``get_time_slice_by_year``
    eyr : str
        last year of the climatology; converted with ``int()`` and passed to
        ``get_time_slice_by_year``
    output_file : str or Path
        file path for output climatology file

    Returns
    -------
    str
        ``"Success: <file name>"`` when the climatology was written, otherwise
        ``"Failed: <file name> with error: <message>"``.
    '''
    # Function-scope imports keep the worker self-contained; `xr` comes from
    # the module-level import.
    import gc
    from pathlib import Path

    import dask

    # Resolve the display name once, outside the try block: `output_file` may
    # be a plain str, and calling `.name` on it inside the except handler
    # would raise a second error that escapes the worker.
    out_name = Path(output_file).name

    dask.config.set(scheduler='synchronous')  # Disable internal dask multi-threading
    try:
        # open_mfdataset returns dask-backed arrays; chunking along time keeps
        # the data lazy instead of loading every file into RAM at once.
        with xr.open_mfdataset(ts_files, decode_times=True, combine='by_coords',
                               chunks={'time': 12}) as raw:
            ds = raw

            #Average time dimension over time bounds, if bounds exist:
            if 'time_bnds' in ds:
                old_time = ds['time']
                # NOTE: force `load` here b/c if dask & time is cftime, throws a NotImplementedError:
                mid_time = ds['time_bnds'].load().mean(dim='nbnd')
                # Rebuild the coordinate with the original attributes so that
                # they are not lost when the time values are replaced.
                new_time = xr.DataArray(mid_time.values, dims=old_time.dims,
                                        attrs=old_time.attrs)
                ds = ds.assign_coords(time=new_time)
                ds = xr.decode_cf(ds)

            #Extract data subset using provided year bounds:
            tslice = get_time_slice_by_year(ds.time, int(syr), int(eyr))
            ds_subset = ds.isel(time=tslice)

            #Group time series values by month, and average those months together:
            climo = ds_subset.groupby('time.month').mean(dim='time')
            #Rename "months" to "time":
            climo = climo.rename({'month': 'time'})

            #Set netCDF encoding method (deal with getting non-nan fill values):
            enc_dv = {xname: {'_FillValue': None, 'zlib': True, 'complevel': 4}
                      for xname in climo.data_vars}
            enc_c = {xname: {'_FillValue': None} for xname in climo.coords}
            enc = {**enc_c, **enc_dv}

            climo.attrs.update({
                "units": ds.attrs.get("units", "--"),
                "adf_user": adf_user,
                "climo_yrs": f"{syr}-{eyr}",
                "time_series_files": ", ".join(str(f) for f in ts_files),
            })

            #Output variable climatology to NetCDF-4 file (while inputs are still open):
            climo.to_netcdf(output_file, format='NETCDF4', encoding=enc)
        return f"Success: {out_name}"
    except Exception as e:
        # Worker boundary: report the failure to the parent as a string so the
        # remaining variables in the pool still get processed.
        return f"Failed: {out_name} with error: {e}"
    finally:
        # Force cleanup of memory
        gc.collect()

def check_averaging_interval(syear_in, eyear_in):
"""
Expand Down
Loading