diff --git a/dabench/__init__.py b/dabench/__init__.py index 8d4cbd2..b1cc0f1 100644 --- a/dabench/__init__.py +++ b/dabench/__init__.py @@ -1 +1 @@ -from . import data, vector, model, observer, obsop, dacycler, _suppl_data +from . import data, model, observer, obsop, dacycler, dasupport, _suppl_data, utils diff --git a/dabench/dacycler/_dacycler.py b/dabench/dacycler/_dacycler.py index 9f24a91..6afc605 100644 --- a/dabench/dacycler/_dacycler.py +++ b/dabench/dacycler/_dacycler.py @@ -1,8 +1,12 @@ """Base class for Data Assimilation Cycler object (DACycler)""" -from dabench import vector import numpy as np +import jax.numpy as jnp +import jax +import xarray as xr +import xarray_jax as xj +import dabench.dacycler._utils as dac_utils class DACycler(): """Base class for DACycler object @@ -37,6 +41,7 @@ def __init__(self, R=None, H=None, h=None, + analysis_time_in_window=None ): self.h = h @@ -48,15 +53,122 @@ def __init__(self, self.system_dim = system_dim self.delta_t = delta_t self.model_obj = model_obj + self.analysis_time_in_window = analysis_time_in_window + + + def _calc_default_H(self, obs_values, obs_loc_indices): + H = jnp.zeros((obs_values.flatten().shape[0], self.system_dim)) + H = H.at[jnp.arange(H.shape[0]), + obs_loc_indices.flatten(), + ].set(1) + return H + + def _calc_default_R(self, obs_values, obs_error_sd): + return jnp.identity(obs_values.flatten().shape[0])*(obs_error_sd**2) + + def _calc_default_B(self): + """If B is not provided, identity matrix with shape (system_dim, system_dim.""" + return jnp.identity(self.system_dim) + + def _step_forecast(self, xa, n_steps=1): + """Perform forecast using model object""" + return self.model_obj.forecast(xa, n_steps=n_steps) + + def _step_cycle(self, xb, obs_vals, obs_locs, obs_time_mask, obs_loc_mask, + H=None, h=None, R=None, B=None, **kwargs): + if H is not None or h is None: + vals = self._cycle_obsop( + xb, obs_vals, obs_locs, obs_time_mask, + obs_loc_mask, H, R, B, **kwargs) + return vals + else: + raise ValueError( + 'Only linear obs operators (H) are supported right now.') + vals = self._cycle_general_obsop( + xb, obs_vals, obs_locs, obs_time_mask, + obs_loc_mask, h, R, B, **kwargs) + return vals + + def _cycle_and_forecast(self, cur_state, filtered_idx): + # 1. Get data + # 1-b. Calculate obs_time_mask and restore filtered_idx to original values + cur_state = cur_state.to_xarray() + cur_time = cur_state['_cur_time'].data + cur_state = cur_state.drop_vars(['_cur_time']) + obs_time_mask = filtered_idx > 0 + filtered_idx = filtered_idx - 1 + + # 2. Calculate analysis + cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_array().data).at[:, filtered_idx].get() + cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get() + cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[:, filtered_idx].get().astype(bool) + cur_obs_time_mask = jnp.repeat(obs_time_mask, cur_obs_vals.shape[-1]) + analysis = self._step_cycle( + cur_state, + cur_obs_vals, + cur_obs_loc_indices, + obs_loc_mask=cur_obs_loc_mask, + obs_time_mask=cur_obs_time_mask + ) + # 3. Forecast next timestep + next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) + next_state = next_state.assign( + _cur_time = cur_time + self.analysis_window + ).assign_coords( + cur_state.coords) + + return xj.from_xarray(next_state), forecast_states + + def _cycle_and_forecast_4d(self, cur_state, filtered_idx): + # 1. Get data + # 1-b. Calculate obs_time_mask and restore filtered_idx to original values + cur_state = cur_state.to_xarray() + cur_time = cur_state['_cur_time'].data + cur_state = cur_state.drop_vars(['_cur_time']) + obs_time_mask = filtered_idx > 0 + filtered_idx = filtered_idx - 1 + + cur_obs_vals = jnp.array(self._obs_vector[self._observed_vars].to_stacked_array('system',['time']).data).at[filtered_idx].get() + cur_obs_times = jnp.array(self._obs_vector.time.data).at[filtered_idx].get() + cur_obs_loc_indices = jnp.array(self._obs_vector.system_index.data).at[:, filtered_idx].get().reshape(filtered_idx.shape[0], -1) + cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[:, filtered_idx].get().astype(bool).reshape(filtered_idx.shape[0], -1) + + # Calculate obs window indices: closest model timesteps that match obs + obs_window_indices =jnp.array([ + jnp.argmin( + jnp.abs(obs_time - (cur_time + self._model_timesteps)) + ) for obs_time in cur_obs_times + ]) + + # 2. Calculate analysis + analysis = self._step_cycle( + cur_state, + cur_obs_vals, + cur_obs_loc_indices, + obs_loc_mask=cur_obs_loc_mask, + obs_time_mask=obs_time_mask, + obs_window_indices=obs_window_indices + ) + + # 3. Forecast forward + next_state, forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) + next_state = next_state.assign( + _cur_time = cur_time + self.analysis_window + ).assign_coords( + cur_state.coords) + + return xj.from_xarray(next_state), forecast_states def cycle(self, input_state, start_time, obs_vector, n_cycles, - analysis_window, + obs_error_sd=None, + analysis_window=0.2, analysis_time_in_window=None, - return_forecast=False): + return_forecast=False + ): """Perform DA cycle repeatedly, including analysis and forecast Args: @@ -79,52 +191,78 @@ def cycle(self, vector.StateVector of analyses and times. """ + # These could be different if observer doesn't observe all variables + # For now, making them the same + self._observed_vars = obs_vector['variable'].values + self._data_vars = list(input_state.data_vars) + + if obs_error_sd is None: + obs_error_sd = obs_vector.error_sd + + self.analysis_window = analysis_window + # If don't specify analysis_time_in_window, is assumed to be middle - if analysis_time_in_window is None: - analysis_time_in_window = analysis_window/2 + if self.analysis_time_in_window is None and analysis_time_in_window is None: + analysis_time_in_window = self.analysis_window/2 + else: + analysis_time_in_window = self.analysis_time_in_window + + # Steps per window + 1 to include start + self.steps_per_window = round(analysis_window/self.delta_t) + 1 + self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t # Time offset from middle of time window, for gathering observations _time_offset = (analysis_window/2) - analysis_time_in_window - # Number of model steps to run per window - steps_per_window = round(analysis_window/self.delta_t) + 1 - - # For storing outputs - all_output_states = [] - all_times = [] - cur_time = start_time - cur_state = input_state - - for i in range(n_cycles): - # 1. Filter observations to inside analysis window - window_middle = cur_time + _time_offset - window_start = window_middle - analysis_window/2 - window_end = window_middle + analysis_window/2 - obs_vec_timefilt = obs_vector.filter_times( - window_start, window_end - ) - - if obs_vec_timefilt.values.shape[0] > 0: - # 2. Calculate analysis - analysis, kh = self._step_cycle(cur_state, obs_vec_timefilt) - # 3. Forecast through analysis window - forecast_states = self._step_forecast(analysis, - n_steps=steps_per_window) - # 4. Save outputs - if return_forecast: - # Append forecast to current state, excluding last step - all_output_states.append(forecast_states.values[:-1]) - all_times.append( - np.arange(steps_per_window-1)*self.delta_t + cur_time - ) - else: - all_output_states.append(analysis.values[np.newaxis]) - all_times.append([cur_time]) - - # Starting point for next cycle is last step of forecast - cur_state = forecast_states[-1] - cur_time += analysis_window - - return vector.StateVector(values=np.concatenate(all_output_states), - times=np.concatenate(all_times)) + # Set up for jax.lax.scan, which is very fast + all_times = dac_utils._get_all_times( + start_time, + analysis_window, + n_cycles) + + + if self.steps_per_window is None: + self.steps_per_window = round(analysis_window/self.delta_t) + 1 + self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t + # Get the obs vectors for each analysis window + all_filtered_idx = dac_utils._get_obs_indices( + obs_times=jnp.array(obs_vector.time.values), + analysis_times=all_times+_time_offset, + start_inclusive=True, + end_inclusive=self.in_4d, + analysis_window=analysis_window + ) + input_state = input_state.assign(_cur_time=start_time) + + all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx, add_one=True) + self._obs_vector=obs_vector + self.obs_error_sd = obs_error_sd + if obs_vector.stationary_observers: + self._obs_loc_masks = jnp.ones( + obs_vector[self._observed_vars].to_array().shape, dtype=bool) + else: + self._obs_loc_masks = ~np.isnan( + obs_vector[self._observed_vars].to_array().data) + self._obs_vector=self._obs_vector.fillna(0) + + if self.in_4d: + cur_state, all_values = jax.lax.scan( + self._cycle_and_forecast_4d, + xj.from_xarray(input_state), + all_filtered_padded) + else: + cur_state, all_values = jax.lax.scan( + self._cycle_and_forecast, + xj.from_xarray(input_state), + all_filtered_padded) + + all_vals_xr = xr.Dataset( + {var: (('cycle',) + tuple(all_values[var].dims), + all_values[var].data) + for var in all_values.data_vars} + ).rename_dims({'time': 'cycle_timestep'}) + if return_forecast: + return all_vals_xr.drop_isel(cycle_timestep=-1) + else: + return all_vals_xr.isel(cycle_timestep=0) diff --git a/dabench/dacycler/_etkf.py b/dabench/dacycler/_etkf.py index ed509e6..b5514f5 100644 --- a/dabench/dacycler/_etkf.py +++ b/dabench/dacycler/_etkf.py @@ -5,9 +5,10 @@ import jax import jax.numpy as jnp from jax.scipy import linalg +import xarray as xr +import xarray_jax as xj -from dabench import dacycler, vector -import dabench.dacycler._utils as dac_utils +from dabench import dacycler class ETKF(dacycler.DACycler): @@ -61,96 +62,34 @@ def __init__(self, ensemble=True, B=B, R=R, H=H, h=h) - def _step_cycle(self, xb, yo, obs_time_mask, obs_loc_mask, - H=None, h=None, R=None, B=None): - if H is not None or h is None: - vals, kh = self._cycle_obsop( - xb.values, yo.values, yo.location_indices, yo.error_sd, obs_time_mask, - obs_loc_mask, H, R, B) - return vector.StateVector(values=vals, store_as_jax=True), kh - else: - return self._cycle_general_obsop(xb, yo, h, R, B) - - def _calc_default_H(self, obs_values, obs_loc_indices): - H = jnp.zeros((obs_values.flatten().shape[0], self.system_dim)) - H = H.at[jnp.arange(H.shape[0]), - obs_loc_indices.flatten() - ].set(1) - return H - - def _calc_default_R(self, obs_values, obs_error_sd): - return jnp.identity(obs_values.flatten().shape[0])*(obs_error_sd**2) - - def _calc_default_B(self): - return jnp.identity(self.system_dim) - - def _cycle_obsop(self, Xbt, obs_values, obs_loc_indices, obs_error_sd, - obs_time_mask, obs_loc_mask, - H=None, h=None, R=None, B=None): - if H is None and h is None: - if self.H is None: - if self.h is None: - H = self._calc_default_H(obs_values, obs_loc_indices) - else: - h = self.h - else: - H = self.H - if R is None: - if self.R is None: - R = self._calc_default_R(obs_values, obs_error_sd) - else: - R = self.R - if B is None: - if self.B is None: - B = self._calc_default_B() - else: - B = self.B - - nr, nc = Xbt.shape - assert nr == self.ensemble_dim, ( - 'cycle:: model_forecast must have dimension {}x{}').format( - self.ensemble_dim, self.system_dim) - - # Apply obs masks to H - H = jnp.where(obs_time_mask, H.T, 0).T - H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T - - # Analysis cycles over all obs in data_obs - Xa = self._compute_analysis(Xb=Xbt.T, - y=obs_values, - H=H, - h=h, - R=R, - rho=self.multiplicative_inflation) - - return Xa.T, 0 - def _step_forecast(self, xa, n_steps): - data_forecast = [] + """Ensemble method needs a slightly different _step_forecast method""" + ensemble_forecasts = [] + ensemble_inputs = [] for i in range(self.ensemble_dim): - new_vals = self.model_obj.forecast( - vector.StateVector(values=xa.values[i], store_as_jax=True), + cur_inputs, cur_forecast = self.model_obj.forecast( + xa.isel(ensemble=i), n_steps=n_steps - ).values - data_forecast.append(new_vals) + ) + ensemble_inputs.append(cur_inputs) + ensemble_forecasts.append(cur_forecast) - out_vals = jnp.moveaxis(jnp.stack(data_forecast), [0,1,2],[1,0,2]) - return vector.StateVector(values=out_vals, - store_as_jax=True) + return (xr.concat(ensemble_inputs, dim='ensemble'), + xr.concat(ensemble_forecasts, dim='ensemble')) - def _apply_obsop(self, Xb, H, h): + def _apply_obsop(self, xb, H, h): if H is not None: - Yb = H @ Xb + yb = H @ xb else: - Yb = h(Xb) + yb = h(xb) - return Yb + return yb - def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None): + def _compute_analysis(self, xb, y, H, h, R, rho=1.0, yb=None): """ETKF analysis algorithm Args: - Xb (ndarray): Forecast/background ensemble with shape + xb (ndarray): Forecast/background ensemble with shape (system_dim, ensemble_dim). y (ndarray): Observation array with shape (observation_dim,) H (ndarray): Observation operator with shape (observation_dim, @@ -161,10 +100,10 @@ def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None): (i.e. no inflation) Returns: - Xa (ndarray): Analysis ensemble [size: (system_dim, ensemble_dim)] + xa (ndarray): Analysis ensemble [size: (system_dim, ensemble_dim)] """ # Number of state variables, ensemble members and observations - system_dim, ensemble_dim = Xb.shape + system_dim, ensemble_dim = xb.shape observation_dim = y.shape[0] # Auxiliary matrices that will ease the computations @@ -172,30 +111,25 @@ def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None): I = jnp.identity(ensemble_dim) # The ensemble is inflated (rho=1.0 is no inflation) - Xb_pert = Xb @ (I-U) - Xb = Xb_pert + Xb @ U + xb_pert = xb @ (I-U) + xb = xb_pert + xb @ U - # Ensemble Transform Kalman Filter - # Initialize the ensemble in observation space - if Yb is None: - Yb = jnp.empty((observation_dim, ensemble_dim)) - - # Map every ensemble member into observation space - Yb = self._apply_obsop(Xb, H, h) + # Map every ensemble member into observation space + yb = self._apply_obsop(xb, H, h) # Get ensemble means and perturbations - xb_bar = jnp.mean(Xb, axis=1) - Xb_pert = Xb @ (I-U) + xb_bar = jnp.mean(xb, axis=1) + xb_pert = xb @ (I-U) - yb_bar = jnp.mean(Yb, axis=1) - Yb_pert = Yb @ (I-U) + yb_bar = jnp.mean(yb, axis=1) + yb_pert = yb @ (I-U) # Compute the analysis if len(R) > 0: Rinv = jnp.linalg.pinv(R, rtol=1e-15) Pa_ens = jnp.linalg.pinv((ensemble_dim-1)/rho*I - + Yb_pert.T @ Rinv @ Yb_pert, + + yb_pert.T @ Rinv @ yb_pert, rtol=1e-15) Wa = linalg.sqrtm((ensemble_dim-1) * Pa_ens) Wa = Wa.real @@ -204,139 +138,55 @@ def _compute_analysis(self, Xb, y, H, h, R, rho=1.0, Yb=None): Pa_ens = jnp.zeros((ensemble_dim, ensemble_dim), dtype=R.dtype) Wa = jnp.zeros((ensemble_dim, ensemble_dim), dtype=R.dtype) - wa = Pa_ens @ Yb_pert.T @ Rinv @ (y.flatten()-yb_bar) + wa = Pa_ens @ yb_pert.T @ Rinv @ (y.flatten()-yb_bar) - Xa_pert = Xb_pert @ Wa + xa_pert = xb_pert @ Wa - xa_bar = xb_bar + jnp.ravel(Xb_pert @ wa) + xa_bar = xb_bar + jnp.ravel(xb_pert @ wa) v = jnp.ones((1, ensemble_dim)) - Xa = Xa_pert + xa_bar[:, None] @ v - - return Xa - - def _cycle_and_forecast(self, state_obs_tuple, filtered_idx): - # 1. Get data - cur_state_vals = state_obs_tuple[0] - obs_vals = state_obs_tuple[1] - obs_times = state_obs_tuple[2] - obs_loc_indices = state_obs_tuple[3] - obs_loc_masks = state_obs_tuple[4] - obs_error_sd = state_obs_tuple[5] - - # 1-b. Calculate obs_time_mask and restore filtered_idx to original values - obs_time_mask = jnp.repeat(filtered_idx > 0, obs_loc_indices.shape[1]) - filtered_idx = filtered_idx - 1 - - # 2. Calculate analysis - new_obs_vals = obs_vals[filtered_idx] - new_obs_loc_indices = obs_loc_indices[filtered_idx] - new_obs_loc_mask = obs_loc_masks[filtered_idx] - analysis, kh = self._step_cycle( - vector.StateVector(values=cur_state_vals, store_as_jax=True), - vector.ObsVector(values=new_obs_vals, - location_indices=new_obs_loc_indices, - error_sd=obs_error_sd, store_as_jax=True), - obs_loc_mask=new_obs_loc_mask, - obs_time_mask=obs_time_mask - ) - # 3. Forecast next timestep - forecast_states = self._step_forecast(analysis, n_steps=self.steps_per_window) - next_state = forecast_states.values[-1] - - return (next_state, obs_vals, obs_times, obs_loc_indices, - obs_loc_masks, obs_error_sd), forecast_states.values[:-1] - - def cycle(self, - input_state, - start_time, - obs_vector, - n_cycles, - obs_error_sd=None, - analysis_window=0.2, - analysis_time_in_window=None, - return_forecast=False): - """Perform DA cycle repeatedly, including analysis and forecast + xa = xa_pert + xa_bar[:, None] @ v - Args: - input_state (vector.StateVector): Input state. - start_time (float or datetime-like): Starting time. - obs_vector (vector.ObsVector): Observations vector. - n_cycles (int): Number of analysis cycles to run, each of length - analysis_window. - analysis_window (float): Time window from which to gather - observations for DA Cycle. - analysis_time_in_window (float): Where within analysis_window - to perform analysis. For example, 0.0 is the start of the - window. Default is None, which selects the middle of the - window. - return_forecast (bool): If True, returns forecast at each model - timestep. If False, returns only analyses, one per analysis - cycle. Default is False. + return xa - Returns: - vector.StateVector of analyses and times. - """ + def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, + obs_time_mask, obs_loc_mask, + H=None, h=None, R=None, B=None): + if H is None and h is None: + if self.H is None: + if self.h is None: + H = self._calc_default_H(obs_values, obs_loc_indices) + else: + h = self.h + else: + H = self.H + if R is None: + if self.R is None: + R = self._calc_default_R(obs_values, self.obs_error_sd) + else: + R = self.R + if B is None: + if self.B is None: + B = self._calc_default_B() + else: + B = self.B - if obs_error_sd is None: - obs_error_sd = obs_vector.error_sd - self.analysis_window = analysis_window - - # If don't specify analysis_time_in_window, is assumed to be middle - if analysis_time_in_window is None: - analysis_time_in_window = analysis_window/2 - - # Steps per window + 1 to include start - self.steps_per_window = round(analysis_window/self.delta_t) + 1 - - # Time offset from middle of time window, for gathering observations - _time_offset = (analysis_window/2) - analysis_time_in_window - - # Set up for jax.lax.scan, which is very fast - all_times = dac_utils._get_all_times( - start_time, - analysis_window, - n_cycles) - - - # Get the obs vectors for each analysis window - all_filtered_idx = dac_utils._get_obs_indices( - obs_times=obs_vector.times, - analysis_times=all_times+_time_offset, - start_inclusive=True, - end_inclusive=False, - analysis_window=analysis_window - ) - - all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx, add_one=True) - - # Padding observations - if obs_vector.stationary_observers: - obs_loc_masks = jnp.ones(obs_vector.values.shape, dtype=bool) - cur_state, all_values = jax.lax.scan( - self._cycle_and_forecast, - (input_state.values, obs_vector.values, obs_vector.times, - obs_vector.location_indices, obs_loc_masks, obs_error_sd), - all_filtered_padded) - else: - obs_vals, obs_locs, obs_loc_masks = dac_utils._pad_obs_locs(obs_vector) - cur_state, all_values = jax.lax.scan( - self._cycle_and_forecast, - (input_state.values, obs_vals, obs_vector.times, - obs_locs, obs_loc_masks, obs_error_sd), - all_filtered_padded) - - - if return_forecast: - all_times_forecast = jnp.arange( - 0, - n_cycles*analysis_window, - self.delta_t - ) + start_time - return vector.StateVector(values=jnp.concatenate(all_values), - times=all_times_forecast) - else: - return vector.StateVector(values=jnp.vstack([ - forecast[0][jnp.newaxis] for forecast in all_values] - ), - times=all_times) + xb = x0_xarray.to_stacked_array('system',['ensemble']).data.T + n_sys, n_ens = xb.shape + assert n_ens == self.ensemble_dim, ( + 'cycle:: model_forecast must have dimension {}x{}').format( + self.ensemble_dim, self.system_dim) + + # Apply obs masks to H + H = jnp.where(obs_time_mask.flatten(), H.T, 0).T + H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T + + # Analysis cycles over all obs in data_obs + xa = self._compute_analysis(xb=xb, + y=obs_values, + H=H, + h=h, + R=R, + rho=self.multiplicative_inflation) + + return x0_xarray.assign(x=(['ensemble','i'], xa.T)) diff --git a/dabench/dacycler/_var3d.py b/dabench/dacycler/_var3d.py index f84fd87..271b969 100644 --- a/dabench/dacycler/_var3d.py +++ b/dabench/dacycler/_var3d.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import jax.scipy as jscipy -from dabench import dacycler, vector +from dabench import dacycler class Var3D(dacycler.DACycler): @@ -48,48 +48,21 @@ def __init__(self, ensemble=False, B=B, R=R, H=H, h=h) - def _step_cycle(self, xb, yo, H=None, h=None, R=None, B=None): - """Perform one step of DA Cycle - - Returns: - vector.StateVector containing analysis results - - """ - if H is not None or h is None: - return self._cycle_linear_obsop(xb, yo, H, R, B) - else: - return self._cycle_general_obsop(xb, yo, h, R, B) - - def _calc_default_H(self, obs_vec): - """If H is not provided, creates identity matrix to serve as H""" - H = jnp.zeros((obs_vec.values.flatten().shape[0], self.system_dim)) - H = H.at[jnp.arange(H.shape[0]), obs_vec.location_indices.flatten() - ].set(1) - return H - - def _calc_default_R(self, obs_vec): - """If R i s not provided, calculates default based on observation error""" - return jnp.identity(obs_vec.values.flatten().shape[0])*obs_vec.error_sd**2 - - def _calc_default_B(self): - """If B is not provided, identity matrix with shape (system_dim, system_dim.""" - - return jnp.identity(self.system_dim) - - def _cycle_general_obsop(self, forecast, obs_vec): - return - - def _cycle_linear_obsop(self, forecast, obs_vec, H=None, R=None, - B=None): + def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, + obs_time_mask, obs_loc_mask, + H=None, h=None, R=None, B=None): """When obsop (H) is linear""" - if H is None: + if H is None and h is None: if self.H is None: - H = self._calc_default_H(obs_vec) + if self.h is None: + H = self._calc_default_H(obs_values, obs_loc_indices) + else: + h = self.h else: H = self.H if R is None: if self.R is None: - R = self._calc_default_R(obs_vec) + R = self._calc_default_R(obs_values, self.obs_error_sd) else: R = self.R if B is None: @@ -98,9 +71,12 @@ def _cycle_linear_obsop(self, forecast, obs_vec, H=None, R=None, else: B = self.B - # make inputs column vectors - xb = jnp.array([forecast.values.flatten()]).T - yo = jnp.array([obs_vec.values.flatten()]).T + xb = x0_xarray.to_stacked_array('system',[]).data.flatten() + yo = obs_values.flatten() + + # Apply masks to H + H = jnp.where(obs_time_mask.flatten(), H.T, 0).T + H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T # Set parameters xdim = xb.size # Size or get one of the shape params? @@ -117,12 +93,4 @@ def _cycle_linear_obsop(self, forecast, obs_vec, H=None, R=None, xa, ierr = jscipy.sparse.linalg.cg(A, b1, x0=xb, tol=1e-05, maxiter=1000) - # Compute KH: - HBHtPlusR_inv = jnp.linalg.inv(H @ BHt + R) - KH = BHt @ HBHtPlusR_inv @ H - - return vector.StateVector(values=xa.T[0], store_as_jax=True), KH - - def _step_forecast(self, xa, n_steps): - """n_steps forward of model forecast""" - return self.model_obj.forecast(xa, n_steps=n_steps) + return x0_xarray.assign(x=(x0_xarray.dims, xa.T)) diff --git a/dabench/dacycler/_var4d.py b/dabench/dacycler/_var4d.py index 7dd6cee..2a5242f 100644 --- a/dabench/dacycler/_var4d.py +++ b/dabench/dacycler/_var4d.py @@ -11,8 +11,10 @@ from jax.scipy.sparse.linalg import bicgstab from copy import deepcopy from functools import partial +import xarray as xr +import xarray_jax as xj -from dabench import dacycler, vector +from dabench import dacycler import dabench.dacycler._utils as dac_utils @@ -64,6 +66,7 @@ def __init__(self, n_outer_loops=1, steps_per_window=1, obs_window_indices=None, + analysis_time_in_window=0, **kwargs ): @@ -81,7 +84,8 @@ def __init__(self, model_obj=model_obj, in_4d=True, ensemble=False, - B=B, R=R, H=H, h=h) + B=B, R=R, H=H, h=h, + analysis_time_in_window=analysis_time_in_window) def _calc_default_H(self, obs_loc_indices): Hs = jnp.zeros((obs_loc_indices.shape[0], obs_loc_indices.shape[1], @@ -92,128 +96,6 @@ def _calc_default_H(self, obs_loc_indices): ].set(1) return Hs - - def _calc_default_R(self, obs_values, obs_error_sd): - return jnp.identity(obs_values[0].shape[0])*(obs_error_sd**2) - - def _calc_default_B(self): - return jnp.identity(self.system_dim) - - def _make_outerloop_4d(self, xb0, Hs, B, Rinv, - obs_values, obs_window_indices, obs_time_mask, - n_steps): - - def _outerloop_4d(x0, _): - # Get TLM and current forecast trajectory - # Based on current best guess for x0 - M, x = self.model_obj.compute_tlm( - n_steps=n_steps, - state_vec=vector.StateVector(values=x0, - store_as_jax=True) - ) - - # 4D-Var inner loop - x0 = self._innerloop_4d(self.system_dim, - x, xb0, obs_values, - Hs, B, Rinv, M, - obs_window_indices, - obs_time_mask) - - return x0, x0 - - return _outerloop_4d - - def _cycle_obsop(self, xb0, obs_values, obs_loc_indices, obs_error_sd, - obs_window_indices, obs_time_mask, obs_loc_mask, - H=None, h=None, R=None, B=None, - n_steps=1): - if H is None and h is None: - if self.H is None: - if self.h is None: - H = self._calc_default_H(obs_loc_indices) - # Apply obs loc mask - # NOTE: nonstationary observer case runs MUCH slower. Not sure why - # Ideally, this conditional would not be necessary, but this is a - # workaround to prevent slowing down stationary observer case. - Hs = jax.lax.cond( - self._obs_vector.stationary_observers, - lambda: H, - lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) - else: - h = self.h - else: - # Assumes self.H is for a single timestep - H = self.H[jnp.newaxis] - Hs = jax.lax.cond( - self._obs_vector.stationary_observers, - lambda: jnp.repeat(H, obs_values.shape[0], axis=0), - lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) - - if R is None: - if self.R is None: - R = self._calc_default_R(obs_values, obs_error_sd) - else: - R = self.R - if B is None: - if self.B is None: - B = self._calc_default_B() - else: - B = self.B - - # Static Variables - Rinv = jscipy.linalg.inv(R) - - # Best guess for x0 starts as background - x0 = deepcopy(xb0) - - outerloop_4d_func = self._make_outerloop_4d( - xb0, Hs, B, Rinv, obs_values, obs_window_indices, - obs_time_mask, n_steps) - - x0, all_x0s = jax.lax.scan(outerloop_4d_func, init=x0, - xs=None, length=self.n_outer_loops) - - # forecast - x = self.step_forecast( - n_steps=n_steps, - x0=vector.StateVector(values=x0, store_as_jax=True) - ).values - - return x - - def step_cycle(self, x0, yo, obs_time_mask, obs_loc_mask, - obs_window_indices, H=None, h=None, R=None, B=None, - n_steps=1): - """Perform one step of DA Cycle""" - if H is not None or h is None: - return self._cycle_obsop( - x0.values, yo.values, yo.location_indices, yo.error_sd, - obs_loc_mask=obs_loc_mask, obs_time_mask=obs_time_mask, - obs_window_indices=obs_window_indices, - H=H, R=R, B=B, - n_steps=n_steps) - else: - return self._cycle_obsop( - x0.values, yo.values, yo.location_indices, yo.error_sd, h=h, - R=R, B=B, obs_window_indices=obs_window_indices, - n_steps=n_steps) - - def step_forecast(self, x0, n_steps=1): - """Perform forecast using model object""" - if 'n_steps' in inspect.getfullargspec(self.model_obj.forecast).args: - return self.model_obj.forecast(x0, n_steps=n_steps) - else: - if n_steps == 1: - return self.model_obj.forecast(x0) - else: - out = [x0] - xi = x0 - for s in range(n_steps): - xi = self.model.forecast(xi) - out.append(xi) - return vector.StateVector(jnp.vstack(xi), store_as_jax=True) - - def _calc_J_term(self, H, M, Rinv, y, x): # The Jb Term (A) HM = H @ M @@ -223,12 +105,12 @@ def _calc_J_term(self, H, M, Rinv, y, x): D = (y - (H @ x)) return MtHtRinv @ HM, MtHtRinv @ D[:, None] - @partial(jax.jit, static_argnums=[0, 1]) def _innerloop_4d(self, system_dim, x, xb0, obs_vals, Hs, B, Rinv, M, obs_window_indices, obs_time_mask): """4DVar innerloop""" - x0_last = x[0] + x0_last = x.isel(time=0) + x = x.to_stacked_array('system',['time']) # Set up Variables SumMtHtRinvHM = jnp.zeros_like(B) # A input @@ -238,15 +120,15 @@ def _innerloop_4d(self, system_dim, x, xb0, obs_vals, Hs, B, Rinv, M, for i, j in enumerate(obs_window_indices): Jb, Jo = jax.lax.cond( obs_time_mask.at[i].get(mode='fill', fill_value=0), - lambda: self._calc_J_term(Hs.at[i].get(mode='clip'), M[j], - Rinv, obs_vals[i], x[j]), + lambda: self._calc_J_term(Hs.at[i].get(mode='clip'), M.data[j], + Rinv, obs_vals[i], x.data[j]), lambda: (jnp.zeros_like(SumMtHtRinvHM), jnp.zeros_like(SumMtHtRinvD)) ) SumMtHtRinvHM += Jb SumMtHtRinvD += Jo # Compute initial departure - db0 = xb0 - x0_last + db0 = (xb0 - x0_last).to_stacked_array('system',[]).data # Solve Ax=b for the initial perturbation dx0 = self._solve(db0, SumMtHtRinvHM, SumMtHtRinvD, B) @@ -256,6 +138,30 @@ def _innerloop_4d(self, system_dim, x, xb0, obs_vals, Hs, B, Rinv, M, return x0_new + def _make_outerloop_4d(self, xb0, Hs, B, Rinv, + obs_values, obs_window_indices, obs_time_mask, + n_steps): + + def _outerloop_4d(x0, _): + # Get TLM and current forecast trajectory + # Based on current best guess for x0 + x0 = x0.to_xarray() + x, M = self.model_obj.compute_tlm( + n_steps=n_steps, + state_vec=x0 + ) + + # 4D-Var inner loop + x0_new = self._innerloop_4d(self.system_dim, + x, xb0, obs_values, + Hs, B, Rinv, M, + obs_window_indices, + obs_time_mask) + + return xj.from_xarray(x0_new.assign_coords(x0.coords)), x0 + + return _outerloop_4d + @partial(jax.jit, static_argnums=0) def _solve(self, db0, SumMtHtRinvHM, SumMtHtRinvD, B): """Solve the 4D-Var linear optimization @@ -286,142 +192,54 @@ def _solve(self, db0, SumMtHtRinvHM, SumMtHtRinvD, B): return dx0 - def _cycle_and_forecast(self, cur_state_vals_time_tuple, filtered_idx): - cur_state_vals, cur_time = cur_state_vals_time_tuple - obs_error_sd = self._obs_error_sd - - # Calculate obs_time_mask and restore filtered_idx to original values - obs_time_mask = filtered_idx > 0 - filtered_idx = filtered_idx - 1 - - cur_obs_vals = jnp.array(self._obs_vector.values).at[filtered_idx].get() - cur_obs_loc_indices = jnp.array(self._obs_vector.location_indices).at[filtered_idx].get() - cur_obs_times = jnp.array(self._obs_vector.times).at[filtered_idx].get() - cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[filtered_idx].get().astype(bool) - - # Calculate obs window indices: closest model timesteps that match obs - obs_window_indices = jax.lax.cond( - self.obs_window_indices is None, - lambda: jnp.array([ - jnp.argmin( - jnp.abs(obs_time - (cur_time + self._model_timesteps)) - ) for obs_time in cur_obs_times - ]), - lambda: jnp.array(self.obs_window_indices) - ) - - analysis = self.step_cycle( - vector.StateVector(values=cur_state_vals, store_as_jax=True), - vector.ObsVector(values=cur_obs_vals, - location_indices=cur_obs_loc_indices, - error_sd=obs_error_sd, - store_as_jax=True), - obs_time_mask=obs_time_mask, - obs_loc_mask=cur_obs_loc_mask, - n_steps=self.steps_per_window, - obs_window_indices=obs_window_indices) - new_time = cur_time + self.analysis_window - - return (analysis[-1], new_time), analysis[:-1] - - def cycle(self, - input_state, - start_time, - obs_vector, - obs_error_sd, - n_cycles, - analysis_window, - analysis_time_in_window=0, - return_forecast=False): - """Perform DA cycle repeatedly, including analysis and forecast - - Args: - input_state (vector.StateVector): Input state. - start_time (float or datetime-like): Starting time. - obs_vector (vector.ObsVector): Observations vector. - obs_error_sd (float): Standard deviation of observation error. - Typically not known, so provide a best-guess. - n_cycles (int): Number of analysis cycles to run, each of length - analysis_window. - analysis_window (float): Length of time window from which to gather - observations for each DA Cycle, in model time units. - analysis_time_in_window (float): At what time within analysis_window - to perform analysis. For example, 0.0 is the start of the - window. Default is 0, the start of the window. - return_forecast (bool): If True, returns forecast at each model - timestep. If False, returns only analyses, one per analysis - cycle. Default is False. - - Returns: - vector.StateVector of analyses and times. - """ - if (not obs_vector.stationary_observers and - (self.H is not None or self.h is not None)): - warnings.warn( - "Provided obs vector has nonstationary observers. When" - " providing a custom obs operator (H/h), the Var4DBackprop" - "DA cycler may not function properly. If you encounter " - "errors, try again with an observer where" - "stationary_observers=True or without specifying H or h (a " - "default H matrix will be used to map observations to system " - "space)." - ) - self.analysis_window = analysis_window - - # If don't specify analysis_time_in_window, is assumed to be middle - if analysis_time_in_window is None: - analysis_time_in_window = self.analysis_window/2 + def _cycle_obsop(self, xb0, obs_values, obs_loc_indices, + obs_time_mask, obs_loc_mask, + H=None, h=None, R=None, B=None, obs_window_indices=None): + if H is None and h is None: + if self.H is None: + if self.h is None: + H = self._calc_default_H(obs_loc_indices) + # Apply obs loc mask + # NOTE: nonstationary observer case runs MUCH slower. Not sure why + # Ideally, this conditional would not be necessary, but this is a + # workaround to prevent slowing down stationary observer case. + Hs = jax.lax.cond( + self._obs_vector.stationary_observers, + lambda: H, + lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) + else: + h = self.h + else: + # Assumes self.H is for a single timestep + H = self.H[jnp.newaxis] + Hs = jax.lax.cond( + self._obs_vector.stationary_observers, + lambda: jnp.repeat(H, obs_values.shape[0], axis=0), + lambda: (obs_loc_mask[:, :, jnp.newaxis] * H)) - # Time offset from middle of time window, for gathering observations - _time_offset = (analysis_window/2) - analysis_time_in_window + if R is None: + if self.R is None: + R = self._calc_default_R(obs_values, self.obs_error_sd) + else: + R = self.R - # Set up for jax.lax.scan, which is very fast - all_times = dac_utils._get_all_times(start_time, analysis_window, - n_cycles) + if B is None: + if self.B is None: + B = self._calc_default_B() + else: + B = self.B - if self.steps_per_window is None: - self.steps_per_window = round(analysis_window/self.delta_t) + 1 - self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t + # Static Variables + Rinv = jscipy.linalg.inv(R) - # Get the obs vectors for each analysis window - all_filtered_idx = dac_utils._get_obs_indices( - obs_times=obs_vector.times, - analysis_times=all_times+_time_offset, - start_inclusive=True, - end_inclusive=True, - analysis_window=analysis_window - ) + # Best guess for x0 starts as background + x0_new = deepcopy(xb0) - all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx) + outerloop_4d_func = self._make_outerloop_4d( + xb0, Hs, B, Rinv, obs_values, obs_window_indices, + obs_time_mask, self.steps_per_window) - self._obs_vector = obs_vector - self._obs_error_sd = obs_error_sd + x0_new, all_x0s = jax.lax.scan(outerloop_4d_func, init=xj.from_xarray(x0_new), + xs=None, length=self.n_outer_loops) - # Padding observations - if obs_vector.stationary_observers: - self._obs_loc_masks = jnp.ones(obs_vector.values.shape, dtype=bool) - else: - obs_vals, obs_locs, obs_loc_masks = dac_utils._pad_obs_locs( - obs_vector) - self._obs_vector.values = obs_vals - self._obs_vector.location_indices = obs_locs - self._obs_loc_masks = jnp.array(obs_loc_masks) - - cur_state, all_values = jax.lax.scan( - self._cycle_and_forecast, - init=(input_state.values, start_time), - xs=all_filtered_padded) - - if return_forecast: - all_times_forecast = jnp.arange( - 0, - n_cycles*analysis_window, - self.delta_t - ) + start_time - return vector.StateVector(values=jnp.concatenate(all_values), - times=all_times_forecast) - else: - return vector.StateVector(values=jnp.vstack([ - forecast[0] for forecast in all_values] - ), - times=all_times) + return x0_new.to_xarray() diff --git a/dabench/dacycler/_var4d_backprop.py b/dabench/dacycler/_var4d_backprop.py index 161f66a..80cfcfc 100644 --- a/dabench/dacycler/_var4d_backprop.py +++ b/dabench/dacycler/_var4d_backprop.py @@ -11,8 +11,10 @@ import jax import optax from functools import partial +import xarray as xr +import xarray_jax as xj -from dabench import dacycler, vector +from dabench import dacycler import dabench.dacycler._utils as dac_utils @@ -72,6 +74,7 @@ def __init__(self, steps_per_window=None, obs_window_indices=None, loss_growth_limit=10, + analysis_time_in_window=0, **kwargs ): @@ -91,8 +94,8 @@ def __init__(self, model_obj=model_obj, in_4d=True, ensemble=False, - B=B, R=R, H=H, h=h) - + B=B, R=R, H=H, h=h, + analysis_time_in_window=analysis_time_in_window) def _calc_default_H(self, obs_loc_indices): Hs = jnp.zeros((obs_loc_indices.shape[0], obs_loc_indices.shape[1], @@ -104,12 +107,6 @@ def _calc_default_H(self, obs_loc_indices): return Hs - def _calc_default_R(self, obs_values, obs_error_sd): - return jnp.identity(obs_values[0].shape[0])*(obs_error_sd**2) - - def _calc_default_B(self): - return jnp.identity(self.system_dim) - def _raise_nan_error(self): raise ValueError('Loss value is nan, exiting optimization') @@ -120,7 +117,7 @@ def _callback_raise_error(self, error_method, loss_val): jax.debug.callback(error_method) return loss_val - @partial(jax.jit, static_argnums=[0]) + # @partial(jax.jit, static_argnums=[0]) def _calc_obs_term(self, pred_x, obs_vals, Ht, Rinv): pred_obs = pred_x @ Ht resid = pred_obs.ravel() - obs_vals.ravel() @@ -132,15 +129,15 @@ def _make_loss(self, xb0, obs_vals, Hs, Binv, Rinv, obs_time_mask, n_steps): """Define loss function based on 4dvar cost""" - @jax.jit + # @jax.jit def loss_4dvarcost(x0): # Get initial departure - db0 = (x0.ravel() - xb0.ravel()) + db0 = (x0.to_array().data.ravel() - xb0.to_array().data.ravel()) # Make new prediction - pred_x = self.step_forecast( - vector.StateVector(values=x0, store_as_jax=True), - n_steps).values + # NOTE: [1] selects the full forecast instead of last timestep only + pred_x = self._step_forecast( + x0, n_steps)[1].to_stacked_array('system',['time']).data # Calculate observation term of J_0 obs_term = 0 @@ -170,11 +167,13 @@ def _make_backprop_epoch(self, loss_func, optimizer, hessian_inv): loss_value_grad = value_and_grad(loss_func, argnums=0) - @jax.jit + # @jax.jit def _backprop_epoch(epoch_state_tuple, i): x0, init_loss, opt_state = epoch_state_tuple + x0 = x0.to_xarray() loss_val, dx0 = loss_value_grad(x0) - dx0_hess = hessian_inv @ dx0 + x0_array = x0.to_stacked_array('system', []) + dx0_hess = hessian_inv @ dx0.to_stacked_array('system',[]).data init_loss = jax.lax.cond( i == 0, lambda: loss_val, @@ -186,17 +185,18 @@ def _backprop_epoch(epoch_state_tuple, i): lambda: loss_val) updates, opt_state = optimizer.update(dx0_hess, opt_state) - x0_new = optax.apply_updates(x0, updates) - - return (x0_new, init_loss, opt_state), loss_val + x0_array.data = optax.apply_updates( + x0_array.data, updates) + x0_new = x0_array.to_unstacked_dataset('system').assign_attrs( + x0.attrs + ) + return (xj.from_xarray(x0_new), init_loss, opt_state), loss_val return _backprop_epoch - - def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, + def _cycle_obsop(self, x0_xarray, obs_values, obs_loc_indices, obs_time_mask, obs_loc_mask, - H=None, h=None, R=None, B=None, obs_window_indices=None, - n_steps=1): + H=None, h=None, R=None, B=None, obs_window_indices=None): if H is None and h is None: if self.H is None: if self.h is None: @@ -221,7 +221,7 @@ def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, if R is None: if self.R is None: - R = self._calc_default_R(obs_values, obs_error_sd) + R = self._calc_default_R(obs_values, self.obs_error_sd) else: R = self.R @@ -231,7 +231,6 @@ def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, else: B = self.B - Rinv = jscipy.linalg.inv(R) Binv = jscipy.linalg.inv(B) @@ -240,205 +239,30 @@ def _cycle_obsop(self, x0, obs_values, obs_loc_indices, obs_error_sd, Binv + Hs.at[0].get().T @ Rinv @ Hs.at[0].get()) loss_func = self._make_loss( - x0, + x0_xarray, obs_values, Hs, Binv, Rinv, obs_window_indices, obs_time_mask, - n_steps=n_steps) + n_steps=self.steps_per_window) lr = optax.exponential_decay( self.learning_rate, 1, self.lr_decay) optimizer = optax.sgd(lr) - opt_state = optimizer.init(x0) + opt_state = optimizer.init(x0_xarray.to_stacked_array('system',[]).data) # Make initial forecast and calculate loss backprop_epoch_func = self._make_backprop_epoch(loss_func, optimizer, hessian_inv) + # epoch_state_tuple, loss_vals = backprop_epoch_func((xj.from_xarray(x0_xarray), 0., opt_state),0) epoch_state_tuple, loss_vals = jax.lax.scan( - backprop_epoch_func, init=(x0, 0., opt_state), + backprop_epoch_func, init=(xj.from_xarray(x0_xarray), 0., opt_state), xs=jnp.arange(self.num_iters)) - x0, init_loss, opt_state = epoch_state_tuple - - xa = self.step_forecast( - vector.StateVector(values=x0, store_as_jax=True), - n_steps=n_steps) - - return xa, loss_vals - - def step_cycle(self, xb, yo, obs_time_mask, obs_loc_mask, - obs_window_indices, H=None, h=None, R=None, B=None, - n_steps=1): - """Perform one step of DA Cycle""" - if H is not None or h is None: - return self._cycle_obsop( - xb.values, yo.values, yo.location_indices, yo.error_sd, - obs_time_mask=obs_time_mask, obs_loc_mask=obs_loc_mask, - H=H, R=R, B=B, - obs_window_indices=obs_window_indices, n_steps=n_steps) - else: - return self._cycle_obsop( - xb, yo, h, R, B, obs_window_indices=obs_window_indices, - n_steps=n_steps) - - def step_forecast(self, xa, n_steps=1): - """Perform forecast using model object""" - if 'n_steps' in inspect.getfullargspec(self.model_obj.forecast).args: - return self.model_obj.forecast(xa, n_steps=n_steps) - else: - if n_steps == 1: - return self.model_obj.forecast(xa) - else: - out = [xa] - xi = xa - for s in range(n_steps): - xi = self.model.forecast(xi) - out.append(xi) - return vector.StateVector(jnp.vstack(xi), store_as_jax=True) - - def _cycle_and_forecast(self, cur_state_vals_time_tuple, filtered_idx): - cur_state_vals, cur_time = cur_state_vals_time_tuple - obs_error_sd = self._obs_error_sd - - # Calculate obs_time_mask and restore filtered_idx to original values - obs_time_mask = filtered_idx > 0 - filtered_idx = filtered_idx - 1 - - cur_obs_vals = jnp.array(self._obs_vector.values).at[filtered_idx].get() - cur_obs_loc_indices = jnp.array(self._obs_vector.location_indices).at[filtered_idx].get() - cur_obs_times = jnp.array(self._obs_vector.times).at[filtered_idx].get() - cur_obs_loc_mask = jnp.array(self._obs_loc_masks).at[filtered_idx].get().astype(bool) - - # Calculate obs window indices: closest model timesteps that match obs - obs_window_indices = jax.lax.cond( - self.obs_window_indices is None, - lambda: jnp.array([ - jnp.argmin( - jnp.abs(obs_time - (cur_time + self._model_timesteps)) - ) for obs_time in cur_obs_times - ]), - lambda: jnp.array(self.obs_window_indices) - ) + x0_new = epoch_state_tuple[0].to_xarray() - analysis, loss_vals = self.step_cycle( - vector.StateVector(values=cur_state_vals, store_as_jax=True), - vector.ObsVector(values=cur_obs_vals, - location_indices=cur_obs_loc_indices, - error_sd=obs_error_sd, - store_as_jax=True), - obs_time_mask=obs_time_mask, - obs_loc_mask=cur_obs_loc_mask, - n_steps=self.steps_per_window, - obs_window_indices=obs_window_indices) - new_time = cur_time + self.analysis_window - - return (analysis.values[-1], new_time), (analysis.values[:-1], loss_vals) - - def cycle(self, - input_state, - start_time, - obs_vector, - obs_error_sd, - n_cycles, - analysis_window, - analysis_time_in_window=0, - return_forecast=False): - """Perform DA cycle repeatedly, including analysis and forecast - - Args: - input_state (vector.StateVector): Input state. - start_time (float or datetime-like): Starting time. - obs_vector (vector.ObsVector): Observations vector. - obs_error_sd (float): Standard deviation of observation error. - Typically not known, so provide a best-guess. - n_cycles (int): Number of analysis cycles to run, each of length - analysis_window. - analysis_window (float): Length of time window from which to gather - observations for each DA Cycle, in model time units. - analysis_time_in_window (float): At what time within analysis_window - to perform analysis. For example, 0.0 is the start of the - window. Default is 0, the start of the window. - return_forecast (bool): If True, returns forecast at each model - timestep. If False, returns only analyses, one per analysis - cycle. Default is False. - - Returns: - vector.StateVector of analyses and times. - """ - if (not obs_vector.stationary_observers and - (self.H is not None or self.h is not None)): - warnings.warn( - "Provided obs vector has nonstationary observers. When" - " providing a custom obs operator (H/h), the Var4DBackprop" - "DA cycler may not function properly. If you encounter " - "errors, try again with an observer where" - "stationary_observers=True or without specifying H or h (a " - "default H matrix will be used to map observations to system " - "space)." - ) - self.analysis_window = analysis_window - - # If don't specify analysis_time_in_window, is assumed to be middle - if analysis_time_in_window is None: - analysis_time_in_window = self.analysis_window/2 - - # Time offset from middle of time window, for gathering observations - _time_offset = (analysis_window/2) - analysis_time_in_window - - # Set up for jax.lax.scan, which is very fast - all_times = dac_utils._get_all_times(start_time, analysis_window, - n_cycles) - - if self.steps_per_window is None: - self.steps_per_window = round(analysis_window/self.delta_t) + 1 - self._model_timesteps = jnp.arange(self.steps_per_window)*self.delta_t - - # Get the obs vectors for each analysis window - all_filtered_idx = dac_utils._get_obs_indices( - obs_times=obs_vector.times, - analysis_times=all_times+_time_offset, - start_inclusive=True, - end_inclusive=True, - analysis_window=analysis_window - ) - - all_filtered_padded = dac_utils._pad_time_indices(all_filtered_idx) - - self._obs_vector = obs_vector - self._obs_error_sd = obs_error_sd - - # Padding observations - if obs_vector.stationary_observers: - self._obs_loc_masks = jnp.ones(obs_vector.values.shape, dtype=bool) - else: - obs_vals, obs_locs, obs_loc_masks = dac_utils._pad_obs_locs( - obs_vector) - self._obs_vector.values = obs_vals - self._obs_vector.location_indices = obs_locs - self._obs_loc_masks = jnp.array(obs_loc_masks) - - cur_state, all_results = jax.lax.scan( - self._cycle_and_forecast, - init=(input_state.values, start_time), - xs=all_filtered_padded) - self.loss_values = all_results[1] - all_values = all_results[0] - - if return_forecast: - all_times_forecast = jnp.arange( - 0, - n_cycles*analysis_window, - self.delta_t - ) + start_time - return vector.StateVector(values=jnp.concatenate(all_values), - times=all_times_forecast) - else: - return vector.StateVector(values=jnp.vstack([ - forecast[0] for forecast in all_values] - ), - times=all_times) + return x0_new diff --git a/dabench/dasupport/__init__.py b/dabench/dasupport/__init__.py new file mode 100644 index 0000000..c234c37 --- /dev/null +++ b/dabench/dasupport/__init__.py @@ -0,0 +1,5 @@ +from .generate_era5_ensemble import GenEra5Ens + +__all__ = [ + 'GenEra5Ens', + ] diff --git a/dabench/dasupport/__pycache__ b/dabench/dasupport/__pycache__ new file mode 100644 index 0000000..e69de29 diff --git a/dabench/dasupport/generate_era5_ensemble.py b/dabench/dasupport/generate_era5_ensemble.py new file mode 100644 index 0000000..b7d9876 --- /dev/null +++ b/dabench/dasupport/generate_era5_ensemble.py @@ -0,0 +1,253 @@ +# Sample a series of initial conditions from era5 in order to generate a test initial ensemble + +import argparse + +# For converting strings into datetime objects +from datetime import datetime, timedelta + +# Interface to Google Cloud Services +import gcsfs +import xarray as xr +from dateutil.relativedelta import relativedelta + +from ..utils.timing import report_timing + +# Selected vars for ERA5 ensemble +# This will reduce the number of model fields processed and stored in the ensemble +# A number of these fields are used, for example, by the Google Research NeuralGCM, +# while additional variables are added to support DA of surface satellite observations. +ERA5_CONTROL_VARIABLES = [ + 'geopotential', + 'temperature', + 'specific_humidity', + 'u_component_of_wind', + 'v_component_of_wind', + 'specific_cloud_ice_water_content', + 'specific_cloud_liquid_water_content', + 'surface_pressure', + 'sea_surface_temperature', + 'sea_ice_cover', + # additional variables for DA support: 10m wind speed, u/v neutral winds at 10m + # (wind speed is precomputed upon ensemble file generation) + '10m_u_component_of_wind', + '10m_v_component_of_wind', + '10m_u_component_of_neutral_wind', + '10m_v_component_of_neutral_wind', + 'significant_height_of_combined_wind_waves_and_swell', + 'mean_wave_direction', + 'mean_wave_period', + 'geopotential_at_surface' +] +# From ECMWF docs (for wave parameters): +# https://codes.ecmwf.int/grib/param-db/140229 +# https://codes.ecmwf.int/grib/param-db/140230 +# https://codes.ecmwf.int/grib/param-db/140232 + + +#%% Parse arguments +def parse_arguments(): + parser = argparse.ArgumentParser(description="Process command line inputs.") + + # Define the arguments + parser.add_argument( + "--atmosphere_ensemble_s3_key", + type=str, + required=True, + default=None, + help="The s3 path for the ensemble zarr store.", + ) + parser.add_argument( + "--date_format", + type=str, + required=False, + default="%Y%m%dZ%H", + help="Date format. Default: %Y%m%dZ%H", + ) + parser.add_argument( + "--target_date", + type=str, + required=True, + default=None, #datetime.strptime(f"{YEAR}{MONTH}{DAY}Z{HOUR}",'%Y%m%dZ%H'), + help="Initialization date. Default format: %Y%m%dZ%H", + ) + parser.add_argument( + "--ensemble_size", + type=int, + required=True, + default=None, + help="Number of ensemble members", + ) + parser.add_argument( + "--sample_strategy", + type=str, + required=False, + default="consecutive_day", + help="{'multi_year'|'multi_month'|'consecutive_day'}", + ) + parser.add_argument( + "--start_date", + type=str, + required=True, + default=None, #datetime.strptime(f"{YEAR-1}{MONTH}{DAY}Z{HOUR}",'%Y%m%dZ%H'), + help="Date to start backwards count for sample strategy", + ) + parser.add_argument( + "--era5_path", + type=str, + required=False, + default="gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3", + help="Cloud-based source of the ERA5 dataset to access as ensemble members.", + ) + # Parse the arguments + args = parser.parse_args() + return args + + +#%% Define the initial ensemble + + +def _define_init_ensemble( + ensemble_size, init_ensemble_start_date, init_ensemble_sample_strategy="multi_year" +): + + if init_ensemble_sample_strategy == "multi_year": + increment = relativedelta(years=1) + elif init_ensemble_sample_strategy == "multi_month": + increment = timedelta(months=1) + elif init_ensemble_sample_strategy == "consecutive_day": + increment = timedelta(days=1) + else: + raise Exception( + f"Not a valid init_ensemble_sampling_strategy = {init_ensemble_sample_strategy}" + ) + + init_ensemble_member_dates = [] + for i in range(ensemble_size): + init_ensemble_member_dates.append(init_ensemble_start_date - i * increment) + + print(f"ensemble member init date list = {init_ensemble_member_dates}") + + return init_ensemble_member_dates + + +def GenEra5Ens( + date_format:str="%Y%m%dZ%H", + atmosphere_ensemble_s3_key:str=None, + target_date:datetime=datetime.strptime("19990101Z00","%Y%m%dZ%H"), + sample_strategy:str="consecutive_day", + start_date:datetime=datetime.strptime("19981231Z00","%Y%m%dZ%H"), + ensemble_size:int=4, + era5_path:str="gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3", + verbose:bool=False, + ): + + #%% Set up the gcp access to era5 + if era5_path[0:2] == "gs": + gcs = gcsfs.GCSFileSystem(token="anon") + ds_era5 = xr.open_zarr(gcs.get_mapper(era5_path), chunks=None) + else: + raise Exception("Non-GCP source for ERA5 not yet supported. EXITING...") + report_timing(timing_label="GenEra5Ens:: access remote zarr store") + + #%% Reorder the latitudes + # Following: + # https://stackoverflow.com/questions/54677161/xarray-reverse-an-array-along-one-coordinate + # (ECMWF latitudes are often stored N to S instead of - to +) + ds_era5 = ds_era5.isel(latitude=slice(None, None, -1)) + print(ds_era5.latitude) + assert ds_era5.latitude[0] < ds_era5.latitude[-1] + + #%% Determine dates for initial ensemble sampling + init_ensemble_member_dates = _define_init_ensemble( + ensemble_size=ensemble_size, + init_ensemble_start_date=start_date, + init_ensemble_sample_strategy=sample_strategy, + ) + + #%% Now sample the selection from era5 and put into new zarr store on s3 + + #%% Sample from era5 + ds_init_ens = ds_era5[ERA5_CONTROL_VARIABLES].sel(time=init_ensemble_member_dates) + report_timing( + timing_label="GenEra5Ens:: select time steps as ensemble members" + ) + if verbose: + print(ds_init_ens) + + #%% Update time to target and add ensemble dimension + ds_init_ens = ds_init_ens.rename_dims(dims_dict={"time": "member"}) + ds_init_ens["member"] = range(ensemble_size) + ds_init_ens = ds_init_ens.drop_vars("time") + report_timing( + timing_label="GenEra5Ens:: add member dimension to replace time" + ) + if verbose: + print(ds_init_ens) + + #%% Select target date from era5 for recentering the ensemble + ds_target = ds_era5[ERA5_CONTROL_VARIABLES].sel(time=target_date) + + #%% Compute the 10m diagnostic wind speed and 10m neutral wind speed + if ('10m_u_component_of_neutral_wind' in ERA5_CONTROL_VARIABLES and + '10m_v_component_of_neutral_wind' in ERA5_CONTROL_VARIABLES): + ds_init_ens['ws10n'] = (ds_init_ens['10m_u_component_of_neutral_wind']**2 + ds_init_ens['10m_v_component_of_neutral_wind']**2)**(0.5) + ds_target['ws10n'] = (ds_target['10m_u_component_of_neutral_wind']**2 + ds_target['10m_v_component_of_neutral_wind']**2)**(0.5) + report_timing( + timing_label="GenEra5Ens:: computing neutral wind speeds at 10m (ws10n)" + ) + if ('10m_u_component_of_wind' in ERA5_CONTROL_VARIABLES and + '10m_v_component_of_wind' in ERA5_CONTROL_VARIABLES): + ds_init_ens['ws10m'] = (ds_init_ens['10m_u_component_of_wind']**2 + ds_init_ens['10m_v_component_of_wind']**2)**(0.5) + ds_target['ws10m'] = (ds_target['10m_u_component_of_wind']**2 + ds_target['10m_v_component_of_wind']**2)**(0.5) + report_timing( + timing_label="GenEra5Ens:: computing diagnostic wind speeds at 10m (ws10m)" + ) + + #%% Recenter ensemble to target date + print(f'GenEra5Ens:: re-centering ensemble with ensemble_size = {ensemble_size} to target_date = {target_date}...') + ds_mean = ds_init_ens.mean(dim="member") + ds_diff = ds_target - ds_mean + ds_init_ens = ds_init_ens + ds_diff + report_timing( + timing_label="GenEra5Ens:: recenter ensemble to target date" + ) + if verbose: + print(ds_init_ens) + + #%% Now add time back on as a singleton dimension + ds_init_ens = ds_init_ens.expand_dims(dim={"time": [target_date]}, axis=0) + report_timing( + timing_label="GenEra5Ens:: add time dimension back on to dataset structure" + ) + if verbose: + print(ds_init_ens) + + #%% Add some checks to make sure dimensions haven't changed + assert ds_era5.sizes['latitude'] == ds_init_ens.sizes['latitude'] + assert ds_era5.sizes['longitude'] == ds_init_ens.sizes['longitude'] + assert ds_era5.sizes['level'] == ds_init_ens.sizes['level'] + + #%% Store to zarr (locally or on e.g. AWS s3) + print('Storing as zarr...') + ds_init_ens.to_zarr(atmosphere_ensemble_s3_key, mode="w") + report_timing( + timing_label="GenEra5Ens:: upload to s3 as a new zarr store" + ) + + +#%% Main access +if __name__ == "__main__": + args = parse_arguments() + + # %% Process input arguments + report_timing(timing_label="GenEra5Ens:: initializing...") + + GenEra5Ens( + date_format=args.date_format, + atmosphere_ensemble_s3_key=args.atmosphere_ensemble_s3_key, + target_date=args.target_date, + sample_strategy=args.sample_strategy, + start_date=args.start_date, + ensemble_size=args.ensemble_size, + era5_path=args.era5_path, + ) diff --git a/dabench/data/__init__.py b/dabench/data/__init__.py index 5c9509e..11e367e 100644 --- a/dabench/data/__init__.py +++ b/dabench/data/__init__.py @@ -9,6 +9,7 @@ from .barotropic import Barotropic from .enso_indices import ENSOIndices from .qgs import QGS +from ._xarray_accessor import DABenchDatasetAccessor, DABenchDataArrayAccessor __all__ = [ 'Data', diff --git a/dabench/data/_data.py b/dabench/data/_data.py index 0c375e1..94e3cd6 100644 --- a/dabench/data/_data.py +++ b/dabench/data/_data.py @@ -22,9 +22,6 @@ class Data(): i.e. 1d. random_seed (int): random seed, defaults to 37 delta_t (float): the timestep of the data (assumed uniform) - values (ndarray): 2d array of data (time_dim, system_dim), - set by generate() method - times (ndarray): 1d array of times (time_dim), set by generate() method store_as_jax (bool): Store values as jax array instead of numpy array. Default is False (store as numpy). """ @@ -35,7 +32,6 @@ def __init__(self, original_dim=None, random_seed=37, delta_t=0.01, - values=None, store_as_jax=False, x0=None, **kwargs): @@ -46,10 +42,12 @@ def __init__(self, self.random_seed = random_seed self.delta_t = delta_t self.store_as_jax = store_as_jax - # values and x0 atts are properties to better convert between jax/numpy - self._values = values + + # Default var and coord names + self.var_names = ['x'] + self.coord_names = ['index'] + # x0 attribute is property to better convert between jax/numpy self._x0 = x0 - self._times = None if original_dim is None: self.original_dim = (system_dim,) @@ -59,69 +57,10 @@ def __init__(self, self._values_gridded = None self._x0_gridded = None - def __getitem__(self, subscript): - if self.values is None: - raise AttributeError('Object does not contain any data values.\n' - 'Run .generate() or .load() and try again') - - if isinstance(subscript, slice): - new_copy = copy.deepcopy(self) - new_copy.values = new_copy.values[ - subscript.start:subscript.stop:subscript.step] - new_copy.times = new_copy.times[ - subscript.start:subscript.stop:subscript.step] - new_copy.time_dim = new_copy.times.shape[0] - return new_copy - else: - new_copy = copy.deepcopy(self) - new_copy.values = new_copy.values[subscript] - new_copy.times = new_copy.times[subscript] - if isinstance(subscript, int): - new_copy.time_dim = 1 - else: - new_copy.time_dim = new_copy.times.shape[0] - return new_copy - - @property - def values(self): - return self._values - - @values.setter - def values(self, vals): - if vals is None: - self._values = None - else: - if self.store_as_jax: - self._values = jnp.asarray(vals) - else: - self._values = np.asarray(vals) - - @values.deleter - def values(self): - del self._values - @property def x0(self): return self._x0 - @property - def times(self): - return self._times - - @times.setter - def times(self, vals): - if vals is None: - self._times = None - else: - if self.store_as_jax: - self._times = jnp.asarray(vals) - else: - self._times = np.asarray(vals) - - @times.deleter - def times(self): - del self._times - @x0.setter def x0(self, x0_vals): if x0_vals is None: @@ -136,13 +75,6 @@ def x0(self, x0_vals): def x0(self): del self._x0 - @property - def values_gridded(self): - if self._values is None: - return None - else: - return self._to_original_dim() - @property def x0_gridded(self): if self._x0 is None: @@ -150,31 +82,9 @@ def x0_gridded(self): else: return self._x0.reshape(self.original_dim) - def _to_original_dim(self): - """Converts 1D representation of system back to original dimensions. - - Returns: - Multidimensional array with shape: - (time_dim, original_dim[0], ..., original_dim[n]) - """ - return jnp.reshape(self.values, (self.time_dim,) + self.original_dim) - - def sample_cells(self, targets): - """Samples values at a list of multidimensional array indices. - - Args: - targets (ndarray): Array of target indices in shape: - (num_of_target_indices, time_dim + original_dim). E.g. - [[0,0], [0,1]] samples the first and second cell values in the - first timestep (in this case original_dim = 1). - """ - tupled_targets = tuple(tuple(targets[:, i]) for - i in range(len(self.original_dim) + 1)) - return self._to_original_dim()[tupled_targets] - def generate(self, n_steps=None, t_final=None, x0=None, M0=None, return_tlm=False, stride=None, **kwargs): - """Generates a dataset and assigns values and times to the data object. + """Generates a dataset and returns xarray state vector. Notes: Either provide n_steps or t_final in order to indicate the length @@ -251,254 +161,48 @@ def generate(self, n_steps=None, t_final=None, x0=None, M0=None, jax_comps=self.store_as_jax, **kwargs) - # The generate method specifically stores data in the object, - # as opposed to the forecast method, which does not. - # Store values and times as part of data object - self.values = y[:, :self.system_dim] - self.times = t - self.time_dim = len(t) + # Convert to JAX if necessary + self.time_dim = t.shape[0] + out_dim = (self.time_dim,) + self.original_dim + if self.store_as_jax: + y_out = jnp.array(y[:,:self.system_dim].reshape(out_dim)) + else: + y_out = np.array(y[:,:self.system_dim].reshape(out_dim)) + # Build Xarray object for output + coord_dict = dict(zip( + ['time'] + self.coord_names, + [t] + [np.arange(dim) for dim in self.original_dim] + )) + out_vec = xr.Dataset( + {self.var_names[0]: (coord_dict.keys(),y_out)}, + coords=coord_dict, + attrs={'store_as_jax':self.store_as_jax, + 'system_dim': self.system_dim, + 'delta_t': self.delta_t + } + ) # Return the data series and associated TLMs if requested if return_tlm: # Reshape M matrix - M = jnp.reshape(y[:, self.system_dim:], - (self.time_dim, - self.system_dim, - self.system_dim) - ) - if self.store_as_jax: - return M - else: - return np.array(M) - - def _import_xarray_ds(self, ds, include_vars=None, exclude_vars=None, - years_select=None, dates_select=None, - lat_sorting=None): - # Convert to numpy background - ds = ds.as_numpy() - - if dates_select is not None: - dates_filter_indices = ds.time.dt.date.isin(dates_select) - # First check to make sure the dates exist in the object - if dates_filter_indices.sum() == 0: - raise ValueError('Dataset does not contain any of the dates' - ' specified in dates_select\n' - 'dates_select = {}\n' - 'NetCDF contains {}'.format( - dates_select, - np.unique(ds.time.dt.date) - ) - ) + M = jnp.reshape(y[:, self.system_dim:], + (self.time_dim, + self.system_dim, + self.system_dim) + ) else: - ds = ds.isel(time=dates_filter_indices) - else: - if years_select is not None: - year_filter_indices = ds.time.dt.year.isin(years_select) - # First check to make sure the years exist in the object - if year_filter_indices.sum() == 0: - raise ValueError('Dataset does not contain any of the ' - 'years specified in years_select\n' - 'years_select = {}\n' - 'NetCDF contains {}'.format( - years_select, - np.unique(ds.time.dt.year) - ) - ) - else: - ds = ds.isel(time=year_filter_indices) - - # Check size before loading - size_gb = ds.nbytes / (1024 ** 3) - if size_gb > 1: - warnings.warn('Trying to load large xarray dataset into memory. \n' - 'Size: {} GB. Operation may take a long time, ' - 'stall, or crash.'.format(size_gb)) - - # Get variable names and shapes - names_list = [] - shapes_list = [] - if exclude_vars is not None: - ds = ds.drop_vars(exclude_vars) - if include_vars is not None: - ds = ds[include_vars] - for data_var in ds.data_vars: - shapes_list.append(ds[data_var].shape) - names_list.append(data_var) - - # Load - ds.load() - - # Get dims - dims = ds.sizes - dims_names = list(ds.sizes) - - # Set times - time_key = None - dims_keys = dims.keys() - if 'time' in dims_keys: - time_key = 'time' - elif 'times' in dims_keys: - time_key = 'times' - elif 'time0' in dims_keys: - time_key = 'time0' - if time_key is not None: - self.times = ds[time_key].values - self.time_dim = self.times.shape[0] - else: - self.times = np.array([0]) - self.time_dim = 1 - - # Find names for key dimensions: lat, lon, level (if it exists) - lat_key = None - lon_key = None - lev_key = None - if 'level' in dims_keys: - lev_key = 'level' - elif 'lev' in dims_keys: - lev_key = 'lev' - if 'latitude' in dims_keys: - lat_key = 'latitude' - elif 'lat' in dims_keys: - lat_key = 'lat' - if 'longitude' in dims_keys: - lon_key = 'longitude' - elif 'lon' in dims_keys: - lon_key = 'lon' - - # Reorder dimensions: time, level, lat, lon, etc. - dim_order = np.array([time_key, lev_key, lat_key, lon_key]) - dim_order = dim_order[dim_order != np.array(None)] - remaining_dims = [d for d in dims_names if d not in dim_order] - full_dim_order = list(dim_order) + remaining_dims - - if len(full_dim_order) > 0: - ds = ds.transpose(*full_dim_order) - - # Orient data vertically - if lat_key is not None: - if lat_sorting is not None: - if lat_sorting == 'ascending': - ds = ds.sortby(lat_key, ascending=True) - elif lat_sorting == 'descending': - ds = ds.sortby(lat_key, ascending=False) - else: - warnings.warn('{} is not a valid value for lat_sorting.\n' - 'Choose one of None, "ascending", or ' - '"descending".\n' - 'Proceeding without sorting.'.format( - lat_sorting) - ) - - # Check if all elements' data shapes are equal - if len(names_list) == 0: - raise ValueError('No valid data_vars were found in dataset.\n' - 'Check your include_vars and exclude_vars args.') - if not shapes_list.count(shapes_list[0]) == len(shapes_list): - # Formatting for showing variable names and shapes - var_shape_warn_list = ['{:<12} {:<15}'.format( - 'Variable', 'Dimensions')] - var_shape_warn_list += ['{:<16} {:<16}'.format( - names_list[i], str(shapes_list[i])) - for i in range(len(shapes_list))] - warnings.warn('data_vars do not all share the same dimensions.\n' - 'Broadcasting variables to same dimensions.\n' - 'To avoid, use include_vars or exclude_vars args.\n' - 'Variable dimensions are:\n' - '{}'.format('\n'.join(var_shape_warn_list)) - ) - - # Gather values and set dimensions - temp_values = np.moveaxis(ds.to_dataarray().values, 0, -1) - self.original_dim = temp_values.shape[1:] - if self.original_dim[-1] == 1 and len(self.original_dim) > 2: - self.original_dim = self.original_dim[:-1] - - self.values = temp_values.reshape( - temp_values.shape[0], -1) - self.var_names = np.array(names_list) - if self.x0 is None: - self.x0 = self.values[0] - self.time_dim = self.values.shape[0] - self.system_dim = self.values.shape[1] - if len(full_dim_order) == 0: - warnings.warn('Unable to find any spatial or level dimensions ' - 'in dataset. Setting original_dim to system_dim: ' - '{}'.format(self.system_dim)) - - def load_netcdf(self, filepath=None, include_vars=None, exclude_vars=None, - years_select=None, dates_select=None, - lat_sorting='descending'): - """Loads values from netCDF file, saves them in values attribute - - Args: - filepath (str): Path to netCDF file to load. If not given, - defaults to loading ERA5 ECMWF SLP data over Japan - from 2018 to 2021. - include_vars (list-like): Data variables to load from NetCDF. If - None (default), loads all variables. Can be used to exclude bad - variables. - exclude_vars (list-like): Data variabes to exclude from NetCDF - loading. If None (default), loads all vars (or only those - specified in include_vars). It's recommended to only specify - include_vars OR exclude_vars (unless you want to do extra - typing). - years_select (list-like): Years to load (ints). If None, loads all - timesteps. - dates_select (list-like): Dates to load. Elements must be - datetime date or datetime objects, depending on type of time - indices in NetCDF. If both years_select and dates_select - are specified, time_stamps overwrites "years" argument. If - None, loads all timesteps. - lat_sorting (str): Orient data by latitude: - descending (default), ascending, or None (uses orientation - from data file). - """ - if filepath is None: - # Use importlib.resources to get the default netCDF from dabench - filepath = resources.files(_suppl_data).joinpath('era5_japan_slp.nc') - with xr.open_dataset(filepath, decode_coords='all') as ds: - self._import_xarray_ds( - ds, include_vars=include_vars, - exclude_vars=exclude_vars, - years_select=years_select, dates_select=dates_select, - lat_sorting=lat_sorting) - - def save_netcdf(self, filename): - """Saves values in values attribute to netCDF file - - Args: - filepath (str): Path to netCDF file to save - """ - - # Set variable names - if not hasattr(self, 'var_names') or self.var_names is None: - var_names = ['var{}'.format(i) for - i in range(self.values.shape[1])] - else: - var_names = self.var_names - - # Set times - if not hasattr(self, 'times') or self.times is None: - times = np.arange(self.values.shape[0]) - else: - times = self.times - - # Get values as list: - values_list = [('time', self.values[:, i]) for i in range( - self.values.shape[1])] - - data_dict = dict(zip(var_names, values_list)) - coords_dict = { - 'time': times, - 'system_dim': range(len(var_names)) - } - ds = xr.Dataset( - data_vars=data_dict, - coords=coords_dict + M = np.reshape(y[:, self.system_dim:], + (self.time_dim, + self.system_dim, + self.system_dim) + ) + M = xr.DataArray( + M, dims=('time','system_0','system_n') ) - - ds.to_netcdf(filename, mode='w') + return out_vec, M + else: + return out_vec def rhs_aux(self, x, t): """The auxiliary model used to compute the TLM. @@ -592,8 +296,8 @@ def calc_lyapunov_exponents_series(self, total_time=None, rescale_time=1, # Loop over rescale time periods for i, (t1, t2) in enumerate(zip(times[:-1], times[1:])): - M = self.generate(t_final=t2-t1, x0=x0, M0=M0, return_tlm=True) - x_t2 = self.values[-1] + x, M = self.generate(t_final=t2-t1, x0=x0, M0=M0, return_tlm=True) + x_t2 = x.isel(time=-1).to_array().data.flatten() M_t2 = M[-1] Q, R = jnp.linalg.qr(M_t2) @@ -647,30 +351,50 @@ def calc_lyapunov_exponents_final(self, total_time=None, rescale_time=1, x0=x0, convergence=convergence)[-1] - def split_train_valid_test(self, train_size, valid_size, test_size): - """Splits data into train, validation, and test sets by time + def load_netcdf(self, filepath=None, include_vars=None, exclude_vars=None, + years_select=None, dates_select=None, + lat_sorting='descending'): + """Loads values from netCDF file, saves them in values attribute Args: - train_size, valid_size, test_size (float or int): Size of sets. - If < 1, represents the fraction of the time series to use. - If > 1, represents the number of timesteps. - - Returns: - (train_obj, valid_obj, test_obj): Data objects + filepath (str): Path to netCDF file to load. If not given, + defaults to loading ERA5 ECMWF SLP data over Japan + from 2018 to 2021. + include_vars (list-like): Data variables to load from NetCDF. If + None (default), loads all variables. Can be used to exclude bad + variables. + exclude_vars (list-like): Data variabes to exclude from NetCDF + loading. If None (default), loads all vars (or only those + specified in include_vars). It's recommended to only specify + include_vars OR exclude_vars (unless you want to do extra + typing). + years_select (list-like): Years to load (ints). If None, loads all + timesteps. + dates_select (list-like): Dates to load. Elements must be + datetime date or datetime objects, depending on type of time + indices in NetCDF. If both years_select and dates_select + are specified, time_stamps overwrites "years" argument. If + None, loads all timesteps. + lat_sorting (str): Orient data by latitude: + descending (default), ascending, or None (uses orientation + from data file). """ + if filepath is None: + # Use importlib.resources to get the default netCDF from dabench + filepath = resources.files(_suppl_data).joinpath('era5_japan_slp.nc') + return xr.open_dataset(filepath, decode_coords='all', engine='scipy').as_numpy() + # self._import_xarray_ds( + # ds, include_vars=include_vars, + # exclude_vars=exclude_vars, + # years_select=years_select, dates_select=dates_select, + # lat_sorting=lat_sorting) + + def save_netcdf(self, ds, filename): + """Saves values in values attribute to netCDF file - if 0 < train_size < 1: - train_size = round(train_size*self.time_dim) - if 0 < valid_size < 1: - valid_size = round(valid_size*self.time_dim) - if 0 < test_size < 1: - test_size = round(test_size*self.time_dim) - - # Round up train_size - if train_size + valid_size + test_size < self.time_dim: - train_size = self.time_dim - valid_size - test_size - - train_end = train_size - valid_end = train_size + valid_size + Args: + ds (Xarray Dataset): Xarray dataset + filepath (str): Path to netCDF file to save + """ - return self[:train_end], self[train_end:valid_end], self[valid_end:] + ds.to_netcdf(filename, mode='w') \ No newline at end of file diff --git a/dabench/data/_utils.py b/dabench/data/_utils.py index f119674..6034237 100644 --- a/dabench/data/_utils.py +++ b/dabench/data/_utils.py @@ -31,10 +31,7 @@ def integrate(function, x0, t_final, delta_t, method='odeint', stride=None, """ if method == 'odeint': # Define timesteps - if jax_comps: - t = jnp.arange(0.0, t_final, delta_t) - else: - t = np.arange(0.0, t_final, delta_t) + t = np.arange(0.0, t_final - delta_t/2, delta_t) # If stride is defined, remove timesteps that are not on stride steps if stride is not None: assert stride > 1 and isinstance(stride, int), \ diff --git a/dabench/data/_xarray_accessor.py b/dabench/data/_xarray_accessor.py new file mode 100644 index 0000000..faa344c --- /dev/null +++ b/dabench/data/_xarray_accessor.py @@ -0,0 +1,62 @@ +import xarray as xr +import numpy as np +import warnings + + +def _check_split_lengths(xr_obj, split_lengths): + total_length = np.sum(split_lengths) + xr_timedim = xr_obj.sizes['time'] + if xr_timedim < total_length: + warnings.warn("Specified split lengths ({}) exceed \n" + "Xarray object's time dimension ({}).".format( + split_lengths, xr_timedim + )) + elif xr_timedim > total_length: + warnings.warn("Specified split lengths ({}) are shorter than " + "Xarray object's time dimension ({}).".format( + split_lengths, xr_timedim + )) + + +@xr.register_dataset_accessor('dab') +class DABenchDatasetAccessor: + """Helper methods for manipulating xarray Datasets""" + def __init__(self, xarray_obj): + self._obj = xarray_obj + + def flatten(self): + if 'time' in self._obj.coords: + remaining_dim = ['time'] + else: + remaining_dim = [] + return self._obj.to_stacked_array('system', remaining_dim) + + def split_train_val_test(self, split_lengths): + _check_split_lengths(self._obj, split_lengths) + out_ds = [] + start_i = 0 + for sl in split_lengths: + end_i = start_i + sl + out_ds.append(self._obj.isel(time=slice(start_i, end_i))) + return tuple(out_ds) + + +@xr.register_dataarray_accessor('dab') +class DABenchDataArrayAccessor: + """Helper methods for manipulating xarray DataArrays""" + def __init__(self, xarray_obj): + self._obj = xarray_obj + + def unflatten(self): + return self._obj.to_unstacked_dataset('system') + + def split_train_val_test(self, split_lengths): + _check_split_lengths(self._obj, split_lengths) + out_ds = [] + start_i = 0 + for sl in split_lengths: + end_i = start_i + sl + out_ds.append(self._obj.isel(time=slice(start_i, end_i))) + return tuple(out_ds) + + diff --git a/dabench/data/gcp.py b/dabench/data/gcp.py index 632d431..3e8727c 100644 --- a/dabench/data/gcp.py +++ b/dabench/data/gcp.py @@ -112,14 +112,17 @@ def _load_gcp_era5(self): # Subset by lon boundaries ds = ds.sel(longitude=slice(subset_min_lon, subset_max_lon)) - self._import_xarray_ds(ds) + # Assign system dimension + ds = ds.assign_attrs(system_dim=ds.to_stacked_array('system',['time']).sizes['system']) + + return ds def generate(self): """Alias for _load_gcp_era5""" warnings.warn('GCP.generate() is an alias for the load() method. ' 'Proceeding with downloading ERA5 data from GCP...') - self._load_gcp_era5() + return self._load_gcp_era5() def load(self): """Alias for _load_gcp_era5""" - self._load_gcp_era5() + return self._load_gcp_era5() diff --git a/dabench/data/sqgturb.py b/dabench/data/sqgturb.py index 5bd1db5..cccff62 100644 --- a/dabench/data/sqgturb.py +++ b/dabench/data/sqgturb.py @@ -113,6 +113,10 @@ def __init__(self, values=values, times=times, delta_t=delta_t, store_as_jax=store_as_jax, **kwargs) + + self.coord_names = ['level','x','y'] + self.var_names=['pv'] + # Fall back on default if no pv if pv is None: with resources.open_binary( @@ -470,8 +474,8 @@ def integrate(self, f, x0, t_final, delta_t=None, include_x0=True, pvspec, values = jax.lax.scan(self._rk4, pvspec, xs=None, length=n_steps) - # Reshape to (time_dim, system_dim) - values = values.reshape((self.time_dim, -1)) + # Apply reverse fft to + values = self.ifft2(values) # Update internal states self.pvspec = pvspec diff --git a/dabench/metrics/_ensemble.py b/dabench/metrics/_ensemble.py new file mode 100644 index 0000000..763b366 --- /dev/null +++ b/dabench/metrics/_ensemble.py @@ -0,0 +1,191 @@ +"""Ensemble forecast metrics""" + +import jax.numpy as jnp +from dabench.metrics import _utils + + +__all__ = [ + 'rank_histogram', + 'crps_ensemble', + ] + + +def rank_histogram(observations, forecasts, dim=None, member_dim="member"): + """JAX array implementation of Rank Histogram + + Description: + (from https://www.cawcr.gov.au/projects/verification/#Methods_for_EPS) + + Answers the question: How well does the ensemble spread of the forecast represent the true variability (uncertainty) of the observations? + + Also known as a "Talagrand diagram", this method checks where the verifying observation usually falls with respect to the ensemble forecast data, which is arranged in increasing order at each grid point. In an ensemble with perfect spread, each member represents an equally likely scenario, so the observation is equally likely to fall between any two members. + + To construct a rank histogram, do the following: + 1. At every observation (or analysis) point rank the N ensemble members from lowest to highest. This represents N+1 possible bins that the observation could fit into, including the two extremes + 2. Identify which bin the observation falls into at each point + 3. Tally over many observations to create a histogram of rank. + + Interpretation: + Flat - ensemble spread about right to represent forecast uncertainty + U-shaped - ensemble spread too small, many observations falling outside the extremes of the ensemble + Dome-shaped - ensemble spread too large, most observations falling near the center of the ensemble + Asymmetric - ensemble contains bias + + Note: A flat rank histogram does not necessarily indicate a good forecast, it only measures whether the observed probability distribution is well represented by the ensemble. + + Args: + predictions (ndarray): Array of predictions + targets (ndarray): Array of targets to compare against. Shape must + be broadcastable to shape of predictions. + + Returns: + [UPDATE] Float, Pearson's R correlation coefficient. + """ + + # RMSD = sqrt( 1/(N+1) * sum(Sk - M/(N+1)^2) ) + + # See: https://github.com/xarray-contrib/xskillscore/blob/64f17fdd1816b64b9e13c3f2febb9800a7e6ed0c/xskillscore/core/probabilistic.py#L830C20-L830C76 + + def _rank_first(x, y): + """Concatenates x and y and returns the rank of the + first element along the last axes""" + xy = jnp.concatenate((x[..., jnp.newaxis], y), axis=-1) + return bn.nanrankdata(xy, axis=-1)[..., 0] + + if dim is not None: + if len(dim) == 0: + raise ValueError( + "At least one dimension must be supplied to compute rank histogram over" + ) + if member_dim in dim: + raise ValueError(f'"{member_dim}" cannot be specified as an input to dim') + + ranks = xr.apply_ufunc( + _rank_first, + observations, + forecasts, + input_core_dims=[[], [member_dim]], + dask="parallelized", + output_dtypes=[int], + ) + + bin_edges = jnp.arange(0.5, len(forecasts[member_dim]) + 2) + return histogram(ranks, bins=[bin_edges], bin_names=["rank"], dim=dim, bin_dim_suffix="") + + +def crps_ensemble(observations, forecasts, axis=-1): + """JAX array implementation of Continuous Ranked Probability Score + + (From: https://confluence.ecmwf.int/display/FUG/Section+12.B+Statistical+Concepts+-+Probabilistic+Data#:~:text=The%20Continuous%20Ranked%20Probability%20Score,the%20forecast%20is%20wholly%20inaccurate.) + + A generalisation of Ranked Probability Score (RPS) is the Continuous Rank Probability Score (CRPSS) where the thresholds are continuous rather than discrete (see Nurmi, 2003; Jollife and Stephenson, 2003; Wilks, 2006). The Continuous Ranked Probability Score (CRPS) is a measure of how good forecasts are in matching observed outcomes. Where: + + CRPS = 0 the forecast is wholly accurate; + CRPS = 1 the forecast is wholly inaccurate. + CRPS is calculated by comparing the Cumulative Distribution Functions (CDF) for the forecast against a reference dataset (observations, or analyses, or climatology) over a given period. + + Args: + predictions (ndarray): Array of predictions + targets (ndarray): Array of targets to compare against. Shape must + be broadcastable to shape of predictions. + + Returns: + [UPDATE] Float, Mean Squared Error + """ + + # Integral from -inf to inf: (1/M) * sum[ S [P_j(x) - H(x - x_oj)]^2 dx ] + # where Pj, H, and x_oj are the predicted cumulative distribution for case j, the Heaviside step function, + # and the observed value, respectively. + # (see: https://www.ecmwf.int/sites/default/files/elibrary/2007/10729-ensemble-forecasting.pdf) + # with M independent cases (e.g. different dates) + + # See: https://github.com/properscoring/properscoring/blob/a465b5578d4b661e662933e84fa7673a70e75e94/properscoring/_crps.py#L244 + + # Manage input quality + observations = jnp.asarray(observations) + forecasts = jnp.asarray(forecasts) + + if axis != -1: + # Move the axis to the end + forecasts = jnp.rollaxis(forecasts, axis, start=forecasts.ndim) + + if observations.shape not in [forecasts.shape, forecasts.shape[:-1]]: + raise ValueError('observations and forecasts must have matching ' + 'shapes or matching shapes except along `axis=%s`' + % axis) + + if observations.shape == forecasts.shape: + if weights is not None: + raise ValueError('cannot supply weights unless you also supply ' + 'an ensemble forecast') + return abs(observations - forecasts) + + # Sort forecast members by target quantity + idx = jnp.argsort(forecasts, axis=-1) + forecasts = forecasts[idx] + weights = jnp.ones_like(forecasts) + + return _crps_ensemble_vectorized(observation, forecasts, weights, result) + +# @guvectorize(["void(float64[:], float64[:], float64[:], float64[:])"], +# "(),(n),(n)->()", nopython=True) + + @partial(jnp.vectorize, signature='(),(n),(n)->()') + def _crps_ensemble_vectorized(observation, forecasts, weights, result): + # beware: forecasts are assumed sorted in NumPy's sort order + + # add asserts here: + + # we index the 0th element to get the scalar value from this 0d array: + # http://numba.pydata.org/numba-doc/0.18.2/user/vectorize.html#the-guvectorize-decorator + obs = observation[0] + + if jnp.isnan(obs): + result[0] = jnp.nan + return + + total_weight = 0.0 + for n, weight in enumerate(weights): + if jnp.isnan(forecasts[n]): + # NumPy sorts NaN to the end + break + if not weight >= 0: + # this catches NaN weights + result[0] = jnp.nan + return + total_weight += weight + + obs_cdf = 0 + forecast_cdf = 0 + prev_forecast = 0 + integral = 0 + + for n, forecast in enumerate(forecasts): + if jnp.isnan(forecast): + # NumPy sorts NaN to the end + if n == 0: + integral = jnp.nan + # reset for the sake of the conditional below + forecast = prev_forecast + break + + if obs_cdf == 0 and obs < forecast: + integral += (obs - prev_forecast) * forecast_cdf ** 2 + integral += (forecast - obs) * (forecast_cdf - 1) ** 2 + obs_cdf = 1 + else: + integral += ((forecast - prev_forecast) + * (forecast_cdf - obs_cdf) ** 2) + + forecast_cdf += weights[n] / total_weight + prev_forecast = forecast + + if obs_cdf == 0: + # forecast can be undefined here if the loop body is never executed + # (because forecasts have size 0), but don't worry about that because + # we want to raise an error in that case, anyways + integral += obs - forecast + + result[0] = integral + + diff --git a/dabench/model/__init__.py b/dabench/model/__init__.py index f05d128..abae58e 100644 --- a/dabench/model/__init__.py +++ b/dabench/model/__init__.py @@ -1,7 +1,9 @@ from ._model import Model from ._rc import RCModel +from ._neuralgcm import NeuralGCM __all__ = [ 'Model', 'RCModel', + 'NeuralGCM', ] diff --git a/dabench/model/_neuralgcm.py b/dabench/model/_neuralgcm.py index 075294a..5b37b3d 100644 --- a/dabench/model/_neuralgcm.py +++ b/dabench/model/_neuralgcm.py @@ -330,6 +330,7 @@ def regrid_input(self, data, fill_nans=False): return eval_data + def flat_to_xarray(self, flat, xr_template): remap_dict = {} coords_order = ['time','level','longitude','latitude'] diff --git a/dabench/model/_rc.py b/dabench/model/_rc.py index e3636a6..e842dd7 100644 --- a/dabench/model/_rc.py +++ b/dabench/model/_rc.py @@ -7,6 +7,7 @@ from scipy import sparse, stats, linalg import numpy as np import jax.numpy as jnp +import xarray as xr from dabench import vector, model @@ -167,7 +168,7 @@ def weights_init(self): self.states = None self.Adense = A.asformat('array') if self.sparse_adj_matrix else A - def generate(self, u, A=None, Win=None, r0=None, save_states=False): + def generate(self, state_vec, A=None, Win=None, r0=None): """generate reservoir time series from input signal u Args: u (array_like): (time_dimension, system_dimension), input signal to @@ -183,6 +184,7 @@ def generate(self, u, A=None, Win=None, r0=None, save_states=False): Returns: r (array_like): (time_dim, reservoir_dim), reservoir state """ + u = state_vec.to_stacked_array('system',['time']).data r = np.zeros((u.shape[0], self.reservoir_dim)) if r0 is not None: @@ -194,11 +196,10 @@ def generate(self, u, A=None, Win=None, r0=None, save_states=False): for t in range(0, u.shape[0]): r[t, :] = self.update(r[t - 1], u[t - 1, :], A, Win) - if save_states: - self.states = r - self.s_last = r[-1] - else: - return r + return xr.Dataset( + {'r': (('time', 'reservoir'), r)}, + coords={'time':state_vec.time} + ) def update(self, r, u, A=None, Win=None): """Update reservoir state with input signal and previous state @@ -379,10 +380,9 @@ def train(self, data_obj, update_Wout=True): Wout (array_like): Trained output weight matrix """ - if self.states is None: - self.generate(data_obj.values, save_states=True) - r = self.states[:, :] - u = data_obj.values[:, :] + r = self.generate(data_obj)['r'].data + # u = data_obj.to_array().transpose(..., 'variable').data.reshape(data_obj.sizes['time'], -1) + u = data_obj.to_array().stack(system=['variable','i']).data self.Wout = self._compute_Wout(r, u, update_Wout=update_Wout, u=u.T) def _compute_Wout(self, rt, y, update_Wout=True, u=None): @@ -480,21 +480,23 @@ def _linsolve_pinv(self, X, Y, beta=None): def forecast(self, state_vec, n_steps=1): if n_steps == 1: - new_vals = self.update(state_vec.values, - self.readout(state_vec.values)) - new_vec = vector.StateVector(values=new_vals, store_as_jax=True) - + new_vals = self.update(state_vec['r'].data, + self.readout(state_vec['r'].data)) + new_vec = xr.Dataset( + {'r':(('time','reservoir'), new_vals)} + ) else: - r = state_vec.values + r = state_vec['r'].data r_full = jnp.zeros((n_steps, self.reservoir_dim)) for i in range(n_steps): r_full = r_full.at[i].set(r) if i < n_steps-1: r = self.update(r, self.readout(r)) - new_vec = vector.StateVector(values=r_full, store_as_jax=True) - - return new_vec + new_vec = xr.Dataset( + {'r':(('time','reservoir'), r_full)} + ) + return new_vec.isel(time=-1), new_vec.drop_isel(time=-1) def save_weights(self, pkl_path): """Save RC reservoir weights as pkl file. diff --git a/dabench/observer/_observer.py b/dabench/observer/_observer.py index 80998e6..5e79b61 100644 --- a/dabench/observer/_observer.py +++ b/dabench/observer/_observer.py @@ -7,8 +7,7 @@ import numpy as np import jax.numpy as jnp - -from dabench.vector import ObsVector +import xarray as xr class Observer(): @@ -72,13 +71,13 @@ class Observer(): """ def __init__(self, - data_obj, + state_vec, random_time_density=1., random_location_density=1., random_time_count=None, random_location_count=None, - time_indices=None, - location_indices=None, + times=None, + locations=None, stationary_observers=True, error_bias=0., error_sd=0., @@ -87,17 +86,37 @@ def __init__(self, store_as_jax=False, ): - self.data_obj = data_obj + self.state_vec = state_vec + self._coord_names = list(self.state_vec.coords.keys()) + self._nontime_coord_names = [coord for coord in self._coord_names + if coord != 'time'] + self.state_vec = self.state_vec.assign_coords( + {'variable': self.state_vec.data_vars} + # 'variable_index': np.arange(len(self.state_vec.data_vars))} + ) + # The system_index corresponds to the points location in a flattened + # array (i.e. state_vec[state_vec.data_vars].to_array().data.flatten()) + self.state_vec = self.state_vec.assign( + {'system_index': ( + ['variable'] + ['time'] + self._nontime_coord_names, + np.tile( + np.arange(self.state_vec.system_dim).reshape( + self.state_vec.sizes['variable'], -1 + ), + self.state_vec.sizes['time'] + ).reshape(self.state_vec.to_array().shape) + ) + } + ) + + if times is not None: + times = np.array(times) + self.times = times - if time_indices is not None: - time_indices = np.array(time_indices) - self.time_indices = time_indices self.random_time_density = random_time_density self.random_time_count = random_time_count - if location_indices is not None: - location_indices = np.array(location_indices) - self.location_indices = location_indices + self.locations = locations self.random_location_density = random_location_density self.random_location_count = random_location_count self.stationary_observers = stationary_observers @@ -117,16 +136,15 @@ def __init__(self, self.error_bias = error_bias self.error_sd = error_sd + if isinstance(self.error_bias, (list, np.ndarray, jnp.ndarray)): if len(self.error_bias) == 1: self._error_bias_is_list = False - elif not len(self.error_bias) == self.data_obj.system_dim: + elif not len(self.error_bias) == self.state_vec.system_dim: raise ValueError( "List of error biases has length {}." - "Must match either system_dim ({}) or " - "number of location indices ({})".format( - len(self.error_bias), self.data_obj.system_dim, - self.location_indices.shape[0])) + "Must match system_dim ({}) or ".format( + len(self.error_bias), self.state_vec.system_dim)) elif isinstance(self.error_bias, list): if self.store_as_jax: self.error_bias = jnp.array(self.error_bias) @@ -139,13 +157,11 @@ def __init__(self, if isinstance(self.error_sd, (list, np.ndarray, jnp.ndarray)): if len(self.error_sd) == 1: self._error_sd_is_list = False - elif not len(self.error_sd) == self.data_obj.system_dim: + elif not len(self.error_sd) == self.state_vec.system_dim: raise ValueError( "List of error sds has length {}." - "Must match either system_dim ({}) or " - "number of location indices ({})".format( - len(self.error_sd), self.data_obj.system_dim, - self.location_indices.shape[0])) + "Must match system_dim ({})".format( + len(self.error_sd), self.state_vec.system_dim)) elif isinstance(self.error_sd, list): if self.store_as_jax: self.error_sd = jnp.array(self.error_sd) @@ -157,117 +173,77 @@ def __init__(self, self.error_positive_only = error_positive_only - def _generate_time_indices(self, rng): + def _generate_times(self, rng): if self.random_time_count is not None: - self.time_indices = np.sort(rng.choice( - self.data_obj.time_dim, + self.times = np.sort(rng.choice( + self.state_vec['time'], size=self.random_time_count, replace=False, shuffle=False)) else: - self.time_indices = np.where( + self.times = self.state_vec.time[np.where( rng.binomial(1, p=self.random_time_density, - size=self.data_obj.time_dim + size=self.state_vec.sizes['time'] ).astype('bool') - )[0] + )[0]] - def _generate_stationary_indices(self, rng): + def _generate_stationary_locs(self, rng): if self.random_location_count is not None: - self.location_indices = rng.choice( - self.data_obj.system_dim, - size=self.random_location_count, - replace=False, - shuffle=False) + location_count = self.random_location_count else: - self.location_indices = np.where( - rng.binomial(1, p=self.random_location_density, - size=self.data_obj.system_dim - ).astype('bool') - )[0] - - def _generate_nonstationary_indices(self, rng): - if self.random_location_count is not None: - self.location_indices = np.array([ - rng.choice( - self.data_obj.system_dim, - size=self.random_location_count, - replace=False, - shuffle=False) - for i in range(self.time_indices.shape[0])]) + location_count = np.sum( + rng.binomial(1, + p=self.random_location_density, + size=self.state_vec.system_dim)) + if len(self._nontime_coord_names) > 1: + sample_w_replace=True else: - self.location_indices = np.array([ - np.where( - rng.binomial(1, p=self.random_location_density, - size=self.data_obj.system_dim - ).astype('bool'))[0] - for i in range(self.time_indices.shape[0]) - ], dtype=object) - - def _generate_stationary_indices_gridded(self, rng): - if self.random_location_count is not None: - arange_list = [np.arange(n) for n in self.data_obj.original_dim] - ind_possibilities = np.array( - np.meshgrid(*arange_list)).T.reshape( - -1, len(self.data_obj.original_dim)) - self.location_indices = rng.choice( - ind_possibilities, - size=self.random_location_count, - replace=False, - shuffle=False) - else: - self.location_indices = np.array(np.where( - rng.binomial(1, p=self.random_location_density, - size=self.data_obj.original_dim - ).astype('bool') - )).T - - def _generate_nonstationary_indices_gridded(self, rng): + sample_w_replace=False + self.locations = { + coord_name: xr.DataArray( + rng.choice( + self.state_vec[coord_name], + size=location_count, + replace=sample_w_replace, + shuffle=False), + dims=['observations']) + for coord_name in self._nontime_coord_names + } + self.location_dim = location_count + + def _generate_nonstationary_locs(self, rng): + """Generate different locations for each observation time""" if self.random_location_count is not None: - arange_list = [np.arange(n) for n in self.data_obj.original_dim] - ind_possibilities = np.array( - np.meshgrid(*arange_list)).T.reshape( - -1, len(self.data_obj.original_dim)) - self.location_indices = np.array([rng.choice( - ind_possibilities, - size=self.random_location_count, - replace=False, - shuffle=False) for i in range(self.time_indices.shape[0])]) + self._location_counts = np.repeat( + self.random_location_count, self.times.shape[0] + ) else: - self.location_indices = np.array([ - np.array(np.where( - rng.binomial(1, p=self.random_location_density, - size=self.data_obj.original_dim - ).astype('bool'))).T - for i in range(self.time_indices.shape[0]) - ], dtype=object) - - def _sample_stationary(self, errors_vector, sample_in_system_dim): - if sample_in_system_dim: - values_vector = ( - self.data_obj.values[self.time_indices][ - :, self.location_indices] - + errors_vector) + # An unequal amount of locations per time + self._location_counts = [np.sum( + rng.binomial(1, + p=self.random_location_density, + size=self.state_vec.system_dim) + ) + for i in range(self.times.shape[0])] + + if len(self._nontime_coord_names) > 1: + sample_w_replace=True else: - values_gridded = self.data_obj.values_gridded - values_vector = np.array([ - values_gridded[t][tuple(self.location_indices.T)] - for t in self.time_indices]) + errors_vector - return values_vector - - def _sample_nonstationary(self, errors_vector, sample_in_system_dim): - if sample_in_system_dim: - values_vector = np.array([ - (self.data_obj.values[self.time_indices[i]] - [self.location_indices[i]] + errors_vector[i]) - for i in range(self.time_dim)], dtype=object) - else: - values_gridded = self.data_obj.values_gridded - values_vector = np.array( - [values_gridded[self.time_indices[i]][ - tuple(self.location_indices[i].T)] - + errors_vector[i] for i in range(self.time_dim)], - dtype=object) - return values_vector + sample_w_replace=False + + self.locations = [{ + coord_name: xr.DataArray( + rng.choice( + self.state_vec[coord_name], + size=lc, + replace=sample_w_replace, + shuffle=False), + dims=['observations']) + for coord_name in self._nontime_coord_names + } + for lc in self._location_counts] + + self.location_dim = np.max(self._location_counts) def observe(self): """Generate observations. @@ -277,153 +253,87 @@ def observe(self): errors """ - if self.data_obj.values is None: - raise ValueError('Data have not been generated/loaded. Run:\n' - 'self.data_obj.generate() to create data for ' - 'observer') - # Define random num generator rng = np.random.default_rng(self.random_seed) # Set time indices - if self.time_indices is None: - self._generate_time_indices(rng) + if self.times is None: + self._generate_times(rng) - self.time_dim = self.time_indices.shape[0] + self.time_dim = self.times.shape[0] # For stationary observers (default) if self.stationary_observers: - # Generate location_indices if not specified - if self.location_indices is None: - # Check if data is in spectral or physical space - if (hasattr(self.data_obj, 'is_spectral') and - self.data_obj.is_spectral): - self._generate_stationary_indices_gridded(rng) - else: - self._generate_stationary_indices(rng) - - # Check that location_indices are in correct dimensions - if self.location_indices.shape[0] == 0: - raise ValueError('location_indices is an empty list') - elif len(self.location_indices.shape) == 1: - sample_in_system_dim = True - elif (self.location_indices.shape[1] == - len(self.data_obj.original_dim)): - sample_in_system_dim = False - else: - raise ValueError('location_indices must be 1D or match\n' - 'len(self.data_obj.original_dim)') - - # Generate errors - self.location_dim = np.repeat(self.location_indices.shape[0], - self.time_dim) - errors_vec_size = (self.time_dim,) + (self.location_dim[0],) - if self._error_bias_is_list: - error_bias = self.error_bias[self.location_indices] - else: - error_bias = self.error_bias - if self._error_sd_is_list: - error_sd = self.error_sd[self.location_indices] + # Generate locations if not specified + if self.locations is None: + self._generate_stationary_locs(rng) else: - error_sd = self.error_sd - errors_vector = rng.normal(loc=error_bias, - scale=error_sd, - size=errors_vec_size) + self.location_dim = next(iter(self.locations.items()))[1]['observations'].size - # Clip errors to positive only - if self.error_positive_only: - errors_vector[errors_vector < 0.] = 0. - # Get values - values_vector = self._sample_stationary( - errors_vector, - sample_in_system_dim) - - # Repeat location indices across time_dim for passing to ObsVector - full_loc_indices = np.array( - [self.location_indices] * self.time_dim) + # Sample + obs_vec = self.state_vec.sel(time=self.times).sel(self.locations) # If NON-stationary observer else: # Generate location_indices if not specified - if self.location_indices is None: - # Check if data is in spectral or physical space - if (hasattr(self.data_obj, 'is_spectral') and - self.data_obj.is_spectral): - self._generate_nonstationary_indices_gridded(rng) - else: - self._generate_nonstationary_indices(rng) - - # Check that location_indices are in correct dimensions - if self.location_indices.shape[0] == 0: - raise ValueError('location_indices is an empty list') - elif len(self.location_indices[0].shape) == 1: - sample_in_system_dim = True - elif (self.location_indices[0].shape[1] == - len(self.data_obj.original_dim)): - sample_in_system_dim = False - else: - raise ValueError('With stationary_observers=False,' - 'location_indices must be 1D array of arrays,' - ' with each element being 1D or matching\n' - 'self.data_obj.original_dim') - self.location_dim = np.array([a.shape[0] for a in - self.location_indices]) - - # Generate errors - if self._error_bias_is_list: - if self._error_sd_is_list: - errors_vector = np.array([ - rng.normal( - loc=self.error_bias[ld], - scale=self.error_sd[ld], - size=ld) - for ld in self.location_dim], dtype=object) - else: - errors_vector = np.array([ - rng.normal( - loc=self.error_bias[ld], - scale=self.error_sd, - size=ld) - for ld in self.location_dim], dtype=object) - else: - if self._error_sd_is_list: - errors_vector = np.array([ - rng.normal( - loc=self.error_bias, - scale=self.error_sd[ld], - size=ld) - for ld in self.location_dim], dtype=object) - else: - errors_vector = np.array([ - rng.normal( - loc=self.error_bias, - scale=self.error_sd, - size=ld) - for ld in self.location_dim], dtype=object) - - if self.error_positive_only: - errors_vector = np.array([ - np.maximum(e, 0.) for e in errors_vector]) - - # Get values from generator - values_vector = self._sample_nonstationary( - errors_vector, - sample_in_system_dim) - - # For passing to ObsVector - full_loc_indices = self.location_indices - - return ObsVector(values=values_vector, - times=self.data_obj.times[self.time_indices], - time_indices=self.time_indices, - location_indices=full_loc_indices, - obs_dims=self.location_dim, - num_obs=values_vector.shape[0], - errors=errors_vector, - error_dist='normal', - error_sd=self.error_sd, - error_bias=self.error_bias, - store_as_jax=self.store_as_jax, - stationary_observers=self.stationary_observers - ) + if self.locations is None: + self._generate_nonstationary_locs(rng) + + # If there's an unequal number of obs, will pad + pad_widths = self.location_dim - np.array(self._location_counts) + + # Sample + obs_vec = xr.concat([ + # Select by time + self.state_vec.sel( + time=t + # Select locations + ).sel( + self.locations[i] + # Pad observations to max number + ).pad( + observations=(0, pad_widths[i]) + ) + for i, t in enumerate(self.times)], + dim='time') + + # Transpose system_index to ensure consistency with flattened data + obs_vec['system_index'] = obs_vec['system_index'].transpose('variable','time','observations').fillna( + 0).astype(int) + + # Generate errors + errors_vec_size = ((self.time_dim,) + + (self.location_dim,) + + (obs_vec.sizes['variable'],)) + errors_vec_size = ((obs_vec.sizes['variable'],) + + (self.time_dim,) + + (self.location_dim,)) + + if self._error_bias_is_list: + error_bias = self.error_bias[obs_vec['system_index'].data] + else: + error_bias = self.error_bias + if self._error_sd_is_list: + error_sd = self.error_sd[obs_vec['system_index'].data] + else: + error_sd = self.error_sd + errors_vector = rng.normal(loc=error_bias, + scale=error_sd, + size=errors_vec_size) + + # Include flag for whether observations are stationary or not + obs_vec = obs_vec.assign_attrs( + stationary_observers=self.stationary_observers) + + # Clip errors to positive only + if self.error_positive_only: + errors_vector[errors_vector < 0.] = 0. + + # Save errors and apply them to observations + obs_vec = obs_vec.assign(errors=(obs_vec['system_index'].dims, errors_vector)) + for data_var in obs_vec['variable'].values: + obs_vec[data_var] = obs_vec[data_var] + obs_vec['errors'].sel(variable=data_var) + + + return obs_vec \ No newline at end of file diff --git a/dabench/utils/__init__.py b/dabench/utils/__init__.py new file mode 100644 index 0000000..b4600fe --- /dev/null +++ b/dabench/utils/__init__.py @@ -0,0 +1,5 @@ +from .timing import report_timing + +__all__ = [ + 'report_timing', + ] diff --git a/dabench/utils/timing.py b/dabench/utils/timing.py new file mode 100644 index 0000000..a7f0fc9 --- /dev/null +++ b/dabench/utils/timing.py @@ -0,0 +1,62 @@ +import datetime +import time + + +def report_timing(timing_label=""): + + if not hasattr(report_timing, "timing_start_time"): + report_timing.timing_start_process_time = time.process_time() + report_timing.timing_start_time = time.time() + report_timing.last_process_time = report_timing.timing_start_process_time + report_timing.last_time = report_timing.timing_start_time + return + + print(f"\n< === {timing_label} ===") + + # Print the current time, helpful for tracking long runs + now = datetime.datetime.now() + print(f"Current datetime is: {now}") + + print(" === ") + + # get process execution time + timing_end_process_time = time.process_time() + seconds = timing_end_process_time - report_timing.last_process_time + minutes = seconds / 60.0 + print(f"CPU Execution time of this step: {seconds} seconds or {minutes} minutes.") + seconds = timing_end_process_time - report_timing.timing_start_process_time + minutes = seconds / 60.0 + print(f"CPU Execution time so far: {seconds} seconds or {minutes} minutes.") + + print(" === ") + + # get wall clock time + timing_end_time = time.time() + seconds = timing_end_time - report_timing.last_time + minutes = seconds / 60.0 + print( + f"Wall Clock Execution time of this step: {seconds} seconds or {minutes} minutes." + ) + seconds = timing_end_time - report_timing.timing_start_time + minutes = seconds / 60.0 + print(f"Wall Clock Execution time so far: {seconds} seconds or {minutes} minutes.") + + print(f" === {timing_label} === >\n") + + # Set up to get estimate of time between calls + report_timing.last_process_time = timing_end_process_time + report_timing.last_time = timing_end_time + + +def _test(): + report_timing(timing_label="initializing...") + time.sleep(3) + report_timing(timing_label="3 second sleep.") + time.sleep(10) + report_timing(timing_label="10 second sleep.") + + +# %% Main access +if __name__ == "__main__": + # main(sys.argv) + _test() diff --git a/pyproject.toml b/pyproject.toml index 77dfabe..6714bc2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ dependencies = [ "scipy", "optax", "xarray", - "cftime" + "cftime", + "xarray_jax@git+https://github.com/kysolvik/xarray_jax_permissible.git" ] [project.optional-dependencies]