diff --git a/grain/_src/python/dataset/transformations/testing_util.py b/grain/_src/python/dataset/transformations/testing_util.py index b0def247b..031807435 100644 --- a/grain/_src/python/dataset/transformations/testing_util.py +++ b/grain/_src/python/dataset/transformations/testing_util.py @@ -1045,6 +1045,94 @@ def test_meta_features(self, convert_input_to_np: bool): convert_input_to_np=convert_input_to_np, ) + @parameterized.product( + convert_input_to_np=[True, False], + meta_feature_values=[ + [3.5, 7.5, 5.5], + [np.asarray(3.5), np.asarray(7.5), np.asarray(5.5)], + [ + np.asarray(3.5, dtype=np.int32), + np.asarray(7, dtype=np.int32), + np.asarray(5, dtype=np.int32), + ], + ], + ) + def test_nonint64_meta_feature( + self, + convert_input_to_np: bool, + meta_feature_values: list[np.ndarray | float], + ): + input_elements = [ + { + "inputs": np.asarray([1, 2, 3]), + "targets": np.asarray([10]), + "meta_feature": meta_feature_values[0], + }, + { + "inputs": np.asarray([4, 5]), + "targets": np.asarray([20, 30, 40]), + "meta_feature": meta_feature_values[1], + }, + { + "inputs": np.asarray([6]), + "targets": np.asarray([50, 60]), + "meta_feature": meta_feature_values[2], + }, + ] + length_struct = {"inputs": 3, "targets": 4, "meta_feature": 3} + + expected_elements = [ + { + "inputs": [1, 2, 3], + "targets": [10, 0, 0, 0], + "inputs_segment_ids": [1, 1, 1], + "targets_segment_ids": [1, 0, 0, 0], + "inputs_positions": [0, 1, 2], + "targets_positions": [0, 0, 0, 0], + "meta_feature": [ + np.asarray(meta_feature_values[0]).item(), + 0, + 0, + ], + }, + { + "inputs": [4, 5, 0], + "targets": [20, 30, 40, 0], + "inputs_segment_ids": [1, 1, 0], + "targets_segment_ids": [1, 1, 1, 0], + "inputs_positions": [0, 1, 0], + "targets_positions": [0, 1, 2, 0], + "meta_feature": [ + np.asarray(meta_feature_values[1]).item(), + 0, + 0, + ], + }, + { + "inputs": [6, 0, 0], + "targets": [50, 60, 0, 0], + "inputs_segment_ids": [1, 0, 0], + "targets_segment_ids": [1, 1, 0, 0], + "inputs_positions": [0, 0, 0], + "targets_positions": [0, 1, 0, 0], + "meta_feature": [ + np.asarray(meta_feature_values[2]).item(), + 0, + 0, + ], + }, + ] + _common_test_body( + self.packer_cls, + input_elements, + expected_elements, + length_struct, + kwargs=self.kwargs, + num_packing_bins=3, + meta_features=["meta_feature"], + convert_input_to_np=convert_input_to_np, + ) + @parameterized.parameters( {"restore_at_step": 0}, {"restore_at_step": 1},