Skip to content
16 changes: 12 additions & 4 deletions tests/util/test_fof_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
This module contains unit tests for the `util/fof_utils.py` module.
"""

import os
from unittest.mock import mock_open, patch

import numpy as np
Expand All @@ -11,9 +12,9 @@
clean_value,
compare_arrays,
compare_var_and_attr_ds,
fill_nans_for_float32,
get_observation_variables,
get_report_variables,
replace_nan_with_sentinel,
split_feedback_dataset,
)
from util.log_handler import initialize_detailed_logger
Expand Down Expand Up @@ -196,8 +197,8 @@ def test_fill_nans_for_float32_nan(arr_nan):
Test that if an array containing nan is given, these values are replaced
by -9.99999e05.
"""
array = fill_nans_for_float32(arr_nan)
expected = np.array([1.0, -9.99999e05, 3.0, 4.0, -9.99999e05], dtype=np.float32)
array = replace_nan_with_sentinel(arr_nan)
expected = np.array([1.0, -9.99999e05, 3.0, 4.0, -9.99999e05], dtype=np.float64)
assert np.array_equal(array, expected)


Expand All @@ -206,7 +207,7 @@ def test_fill_nans_for_float32(arr1):
Test that if an array without nan is given, the output of the function
is the same as the input.
"""
array = fill_nans_for_float32(arr1)
array = replace_nan_with_sentinel(arr1)
assert np.array_equal(array, arr1)


Expand Down Expand Up @@ -254,6 +255,13 @@ def test_compare_var_and_attr_ds(ds1, ds2):
assert (total1, equal1) == (103, 102)
assert (total2, equal2) == (103, 102)

script_dir = os.path.dirname(os.path.abspath(__file__))
grandparent_dir = os.path.dirname(os.path.dirname(script_dir))

path_name = os.path.join(grandparent_dir, "differences.csv")
if os.path.exists(path_name):
os.remove(path_name)


@pytest.fixture(name="ds3")
def fixture_sample_dataset_3(sample_dataset_fof):
Expand Down
19 changes: 12 additions & 7 deletions util/fof_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,18 @@ def compare_arrays(arr1, arr2, var_name):
return total, equal, diff


def replace_nan_with_sentinel(arr, *, sentinel=-999999.0):
    """
    Replace NaN values in a floating-point array with a sentinel value.

    NaN never compares equal to itself, so NaN entries are swapped for a
    fixed number to make sure they are recognised when the result is
    compared element-wise afterwards.

    Parameters
    ----------
    arr : numpy.ndarray
        Input array of any dtype.
    sentinel : float, optional
        Value written in place of every NaN. Defaults to -999999.0.

    Returns
    -------
    numpy.ndarray
        * If ``arr`` has a floating dtype: a new float64 array in which
          every NaN has been replaced by ``sentinel``.
        * Otherwise: ``arr`` itself, unchanged (non-floating dtypes are
          not inspected for NaN).
    """
    # Non-floating dtypes are passed through untouched.
    if not np.issubdtype(arr.dtype, np.floating):
        return arr

    # Promote every floating dtype to float64 so that two arrays which only
    # differ in precision end up with the same dtype. copy=False avoids a
    # redundant copy when the input is already float64.
    arr = arr.astype(np.float64, copy=False)

    # np.where always allocates a new array, so the input is never mutated.
    return np.where(np.isnan(arr), sentinel, arr)


def clean_value(x):
Expand Down Expand Up @@ -203,8 +208,8 @@ def process_var(ds1, ds2, var, detailed_logger):
number of matching elements.
"""

arr1 = fill_nans_for_float32(ds1[var].values)
arr2 = fill_nans_for_float32(ds2[var].values)
arr1 = replace_nan_with_sentinel(ds1[var].values)
arr2 = replace_nan_with_sentinel(ds2[var].values)
if arr1.size == arr2.size:
t, e, diff = compare_arrays(arr1, arr2, var)
if diff.size != 0:
Expand Down