diff --git a/tests/util/test_fof_utils.py b/tests/util/test_fof_utils.py index 9564311..f0e61b1 100644 --- a/tests/util/test_fof_utils.py +++ b/tests/util/test_fof_utils.py @@ -2,6 +2,7 @@ This module contains unit tests for the `util/fof_utils.py` module. """ +import os from unittest.mock import mock_open, patch import numpy as np @@ -11,9 +12,9 @@ clean_value, compare_arrays, compare_var_and_attr_ds, - fill_nans_for_float32, get_observation_variables, get_report_variables, + replace_nan_with_sentinel, split_feedback_dataset, ) from util.log_handler import initialize_detailed_logger @@ -196,8 +197,8 @@ def test_fill_nans_for_float32_nan(arr_nan): Test that if an array containing nan is given, these values are replaced by -9.99999e05. """ - array = fill_nans_for_float32(arr_nan) - expected = np.array([1.0, -9.99999e05, 3.0, 4.0, -9.99999e05], dtype=np.float32) + array = replace_nan_with_sentinel(arr_nan) + expected = np.array([1.0, -9.99999e05, 3.0, 4.0, -9.99999e05], dtype=np.float64) assert np.array_equal(array, expected) @@ -206,7 +207,7 @@ def test_fill_nans_for_float32(arr1): Test that if an array without nan is given, the output of the function is the same as the input. """ - array = fill_nans_for_float32(arr1) + array = replace_nan_with_sentinel(arr1) assert np.array_equal(array, arr1) @@ -254,6 +255,13 @@ def test_compare_var_and_attr_ds(ds1, ds2): assert (total1, equal1) == (103, 102) assert (total2, equal2) == (103, 102) + script_dir = os.path.dirname(os.path.abspath(__file__)) + grandparent_dir = os.path.dirname(os.path.dirname(script_dir)) + + path_name = os.path.join(grandparent_dir, "differences.csv") + if os.path.exists(path_name): + os.remove(path_name) + @pytest.fixture(name="ds3") def fixture_sample_dataset_3(sample_dataset_fof): diff --git a/util/fof_utils.py b/util/fof_utils.py index f9e544d..466c3b7 100644 --- a/util/fof_utils.py +++ b/util/fof_utils.py @@ -104,13 +104,18 @@ def compare_arrays(arr1, arr2, var_name): return total, equal, diff -def fill_nans_for_float32(arr): +def replace_nan_with_sentinel(arr): """ - To make sure nan values are recognised. + If the input array has a floating dtype, it is cast to float64 + and all NaN values are replaced with the sentinel value -999999. + If the array does not have a floating dtype, it is returned unchanged. """ - if arr.dtype == np.float32 and np.isnan(arr).any(): - return np.where(np.isnan(arr), -999999, arr) - return arr + if not np.issubdtype(arr.dtype, np.floating): + return arr + + arr = arr.astype(np.float64, copy=False) + + return np.where(np.isnan(arr), -999999.0, arr) def clean_value(x): @@ -203,8 +208,8 @@ def process_var(ds1, ds2, var, detailed_logger): number of matching elements. """ - arr1 = fill_nans_for_float32(ds1[var].values) - arr2 = fill_nans_for_float32(ds2[var].values) + arr1 = replace_nan_with_sentinel(ds1[var].values) + arr2 = replace_nan_with_sentinel(ds2[var].values) if arr1.size == arr2.size: t, e, diff = compare_arrays(arr1, arr2, var) if diff.size != 0: