From 1a5a03b4ba3bda64154ae44fcb2ff9eb8c3fe755 Mon Sep 17 00:00:00 2001 From: ML Metrics Team Date: Mon, 22 Sep 2025 16:32:19 -0700 Subject: [PATCH] *example to dict* 1. Treat feature values as lists instead of a single value if the feature only contains a single scalar. 2. Support the conversion of a list to an ndarray with the proper type to retain the feature type. *dict to example* 1. Check the type of numeric features, as they might be floats instead of ints, so we cannot rely on the first number of the list to determine the type. *move script to internal codebase* * Unit test requires using an internal module (checked by our Kokoro check). PiperOrigin-RevId: 810197698 --- ml_metrics/_src/utils/proto_utils.py | 116 ++++--- ml_metrics/_src/utils/proto_utils_test.py | 394 ++++++++++++++++++---- 2 files changed, 407 insertions(+), 103 deletions(-) diff --git a/ml_metrics/_src/utils/proto_utils.py b/ml_metrics/_src/utils/proto_utils.py index b607c34f..f69537fc 100644 --- a/ml_metrics/_src/utils/proto_utils.py +++ b/ml_metrics/_src/utils/proto_utils.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Proto utils.""" + +import collections from collections.abc import Iterable from typing import Any -from absl import logging from ml_metrics._src.tools.telemetry import telemetry -import more_itertools as mit import numpy as np from tensorflow.core.example import example_pb2 + _ExampleOrBytes = bytes | example_pb2.Example @@ -32,61 +33,90 @@ def _maybe_deserialize(ex: _ExampleOrBytes) -> example_pb2.Example: @telemetry.function_monitor(api='ml_metrics', category=telemetry.CATEGORY.UTIL) -def tf_examples_to_dict(examples: Iterable[_ExampleOrBytes] | _ExampleOrBytes): - """Parses a serialized tf.train.Example to a dict.""" +def tf_examples_to_dict( + examples: Iterable[_ExampleOrBytes] | _ExampleOrBytes, +) -> dict[ + str, + list[int | float | bytes] | list[list[int | float | bytes]], +]: + """Parses serialized or unserialized tf.train.Examples to a dict. + + The conversion assumes all examples have the same features. If not, a + ValueError will be raised. + + Args: + examples: A single tf.train.Example, serialized tf.train.Example, or an + iterable of tf.train.Examples and/or serialized tf.train.Examples. + + Returns: + A dict mapping feature names to lists of feature values. + + Raises: + ValueError: If the features are not all present in all examples. 
+ """ + single_example = False if isinstance(examples, (bytes, example_pb2.Example)): single_example = True examples = [examples] - examples = (_maybe_deserialize(ex) for ex in examples) - examples = mit.peekable(examples) - if (head := examples.peek(None)) is None: - return {} - result = {k: [] for k in head.features.feature} + result = collections.defaultdict(list) + for ex in examples: - missing = set(result) - for key, feature in ex.features.feature.items(): - missing.remove(key) - value = getattr(feature, feature.WhichOneof('kind')).value - if value and isinstance(value[0], bytes): - try: - value = [v.decode() for v in value] - except UnicodeDecodeError: - logging.info( - 'chainable: %s', - f'Failed to decode for {key}, forward the raw bytes.', - ) - result[key].extend(value) - if missing: + ex = _maybe_deserialize(ex) + features = dict(ex.features.feature) + + if result and result.keys() != features.keys(): raise ValueError( - f'Missing keys: {missing}, expecting {set(result)}, got {ex=}' + 'All examples must have the same features, got %s and %s' + % (result.keys(), features.keys()) ) - result = {k: v for k, v in result.items()} - # Scalar value in a single example will be returned with the scalar directly. 
- if single_example and all(len(v) == 1 for v in result.values()): - result = {k: v[0] for k, v in result.items()} + + for name, values in features.items(): + result[name].append(getattr(values, values.WhichOneof('kind')).value) + + if single_example: + return {k: v[0] for k, v in result.items()} return result @telemetry.function_monitor(api='ml_metrics', category=telemetry.CATEGORY.UTIL) def dict_to_tf_example(data: dict[str, Any]) -> example_pb2.Example: """Creates a tf.Example from a dictionary.""" + example = example_pb2.Example() - for key, value in data.items(): - if isinstance(value, (str, bytes, np.floating, float, int, np.integer)): - value = [value] - feature = example.features.feature - if isinstance(value[0], str): - for v in value: - assert isinstance(v, str), f'bad str type: {value}' - feature[key].bytes_list.value.append(v.encode()) - elif isinstance(value[0], bytes): - feature[key].bytes_list.value.extend(value) - elif isinstance(value[0], (int, np.integer)): - feature[key].int64_list.value.extend(value) - elif isinstance(value[0], (float, np.floating)): - feature[key].float_list.value.extend(value) + for key, values in data.items(): + if isinstance(values, (str, bytes, np.floating, float, int, np.integer)): + values = [values] + + if not values: + # Skip empty features. + continue + + if isinstance(values[0], str): + for v in values: + assert isinstance(v, str), f'bad str type: {values}' + example.features.feature[key].bytes_list.value.append(v.encode()) + continue + + if isinstance(values[0], bytes): + feature_kind = 'bytes_list' + elif isinstance(values[0], (float, np.floating)): + feature_kind = 'float_list' + elif isinstance(values[0], (int, np.integer)): + feature_kind = 'int64_list' + for v in values: + if isinstance(v, (float, np.floating)): + # If a float is encountered in the list, we consider the whole feature + # to be a float_list. 
+ feature_kind = 'float_list' + break + elif not isinstance(v, (int, np.integer)): + break else: - raise TypeError(f'Value for "{key}" is not a supported type.') + raise TypeError(f'Values for "{key}" is not a supported type.') + + feature_list = getattr(example.features.feature[key], feature_kind).value + feature_list.extend(values) + return example diff --git a/ml_metrics/_src/utils/proto_utils_test.py b/ml_metrics/_src/utils/proto_utils_test.py index 66b55536..35c31136 100644 --- a/ml_metrics/_src/utils/proto_utils_test.py +++ b/ml_metrics/_src/utils/proto_utils_test.py @@ -11,91 +11,365 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from google.protobuf import text_format from ml_metrics._src.utils import proto_utils from ml_metrics._src.utils import test_utils -import numpy as np +import tensorflow as tf from absl.testing import absltest from absl.testing import parameterized from tensorflow.core.example import example_pb2 -def _get_tf_example(**kwargs): - example = example_pb2.Example() - for k, v in kwargs.items(): - example.features.feature[k].bytes_list.value.append(v) - return example +EXAMPLE_1 = text_format.Parse( + """ +features { + feature { + key: "bytes" + value { + bytes_list { + value: "ab" + } + } + } + feature { + key: "bytes_arr" + value { + bytes_list { + value: "cd" + value: "ef" + } + } + } + feature { + key: "int64" + value { + int64_list { + value: 1 + } + } + } + feature { + key: "int64_arr" + value { + int64_list { + value: 2 + value: 3 + } + } + } + feature { + key: "float" + value { + float_list { + value: 1.5 + } + } + } + feature { + key: "float_arr" + value { + float_list { + value: 2.5 + value: 3.5 + } + } + } +} +""", + example_pb2.Example(), +) -class TFExampleTest(parameterized.TestCase): +EXAMPLE_2 = text_format.Parse( + """ +features { + feature { + key: "bytes" + value { + bytes_list { + 
value: "mn" + } + } + } + feature { + key: "bytes_arr" + value { + bytes_list { + value: "op" + value: "qr" + } + } + } + feature { + key: "int64" + value { + int64_list { + value: 4 + } + } + } + feature { + key: "int64_arr" + value { + int64_list { + value: 5 + value: 6 + } + } + } + feature { + key: "float" + value { + float_list { + value: 11.5 + } + } + } + feature { + key: "float_arr" + value { + float_list { + value: 12.5 + value: 13.5 + } + } + } +} +""", + example_pb2.Example(), +) - def test_single_example(self): - data = { - 'bytes_key': b'\x80abc', # not utf-8 decodable - 'str_key': 'str_test', - 'init_key': 123, - 'np_int': np.int32(123), - 'float_key': 4.56, - 'np_float': np.float32(123), - } - e = proto_utils.dict_to_tf_example(data).SerializeToString() - actual = proto_utils.tf_examples_to_dict(e) - self.assertDictAlmostEqual(data, actual, places=6) - - def test_batch_example(self): - data = { - 'bytes_key': [b'\x80abc', b'\x80def'], # not utf-8 decodable - 'str_key': ['str_test', 'str_test2'], - 'init_key': [123, 456], - 'np_int': [np.int32(123), np.int32(456)], - 'float_key': [4.56, 7.89], - 'np_float': [np.float32(123), np.float32(456)], + +class TFExampleTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='tf_example', + example=EXAMPLE_1, + ), + dict( + testcase_name='serialized_tf_example', + example=EXAMPLE_1.SerializeToString(), + ), + ) + def test_single_example_to_dict(self, example): + actual = proto_utils.tf_examples_to_dict(example) + expected = { + 'bytes': [b'ab'], + 'bytes_arr': [b'cd', b'ef'], + 'int64': [1], + 'int64_arr': [2, 3], + 'float': [1.5], + 'float_arr': [2.5, 3.5], } - e = proto_utils.dict_to_tf_example(data) - actual = proto_utils.tf_examples_to_dict(e) - test_utils.assert_nested_container_equal(self, data, actual, places=6) + test_utils.assert_nested_container_equal(self, expected, actual, places=6) @parameterized.named_parameters( dict( - 
testcase_name='with_single_example', - num_elems=1, + testcase_name='batch_tf_examples', + examples=[EXAMPLE_1, EXAMPLE_2], ), dict( - testcase_name='multiple_examples', - num_elems=3, + testcase_name='batch_serialized_tf_examples', + examples=[ + EXAMPLE_1.SerializeToString(), + EXAMPLE_2.SerializeToString(), + ], + ), + dict( + testcase_name='batch_mixed_tf_examples', + examples=[EXAMPLE_1.SerializeToString(), EXAMPLE_2], ), ) - def test_multiple_examples_as_batch(self, num_elems): - data = { - 'bytes_key': b'\x80abc', # not utf-8 decodable - 'str_key': 'str_test', - 'init_key': 123, - 'np_int': np.int32(123), - 'float_key': 4.56, - 'np_float': np.float32(123), - } - e = [proto_utils.dict_to_tf_example(data) for _ in range(num_elems)] - actual = proto_utils.tf_examples_to_dict(e) - expected = {k: [v] * num_elems for k, v in data.items()} + def test_batched_examples_to_dict(self, examples): + actual = proto_utils.tf_examples_to_dict(examples) + expected = { + 'bytes': [[b'ab'], [b'mn']], + 'bytes_arr': [[b'cd', b'ef'], [b'op', b'qr']], + 'int64': [[1], [4]], + 'int64_arr': [[2, 3], [5, 6]], + 'float': [[1.5], [11.5]], + 'float_arr': [[2.5, 3.5], [12.5, 13.5]], + } test_utils.assert_nested_container_equal(self, expected, actual, places=6) - def test_empty_example(self): - self.assertEmpty(proto_utils.tf_examples_to_dict([])) + def test_batched_single_example_to_dict(self): + actual = proto_utils.tf_examples_to_dict([EXAMPLE_1]) + expected = { + 'bytes': [[b'ab']], + 'bytes_arr': [[b'cd', b'ef']], + 'int64': [[1]], + 'int64_arr': [[2, 3]], + 'float': [[1.5]], + 'float_arr': [[2.5, 3.5]], + } + test_utils.assert_nested_container_equal(self, expected, actual, places=6) - def test_unsupported_type(self): - with self.assertRaisesRegex(TypeError, 'Unsupported type'): - proto_utils.tf_examples_to_dict('unsupported_type') + def test_missing_features_example_to_dict(self): + example_missing_features = text_format.Parse( + """ + features { + feature { + key: "bytes" + value 
{ + bytes_list { + value: "xy" + } + } + } + } + """, + example_pb2.Example(), + ) - def test_unsupported_value_type(self): with self.assertRaisesRegex( - TypeError, 'Value for "a" is not a supported type' + ValueError, 'All examples must have the same features' ): - proto_utils.dict_to_tf_example({'a': [example_pb2.Example()]}) + _ = proto_utils.tf_examples_to_dict([example_missing_features, EXAMPLE_1]) + + def test_empty_example_to_dict(self): + self.assertEmpty(proto_utils.tf_examples_to_dict([])) + + def test_dict_to_tf_example(self): + data = { + 'bytes_scalar': b'a', + 'str_scalar': 'b', + 'int64_scalar': 1, + 'flaot_scalar': 2.1, + 'bytes_list': [b'cd', b'ef'], + 'str_list': ['gh', 'ij'], + 'int64_list': [2, 3], + 'float_list': [1, 3.5], + } + expected = text_format.Parse( + """ + features { + feature { + key: "bytes_scalar" + value { + bytes_list { + value: "a" + } + } + } + feature { + key: "str_scalar" + value { + bytes_list { + value: "b" + } + } + } + feature { + key: "int64_scalar" + value { + int64_list { + value: 1 + } + } + } + feature { + key: "flaot_scalar" + value { + float_list { + value: 2.1 + } + } + } + feature { + key: "bytes_list" + value { + bytes_list { + value: "cd" + value: "ef" + } + } + } + feature { + key: "str_list" + value { + bytes_list { + value: "gh" + value: "ij" + } + } + } + feature { + key: "int64_list" + value { + int64_list { + value: 2 + value: 3 + } + } + } + feature { + key: "float_list" + value { + float_list { + value: 1.0 + value: 3.5 + } + } + } + } + """, + example_pb2.Example(), + ) + actual = proto_utils.dict_to_tf_example(data) + self.assertProtoEquals(expected, actual) + + def test_dict_to_tf_example_key_with_empty_list(self): + data = {'int64_list': 1, 'empty_list': []} + expected = text_format.Parse( + """ + features { + feature { + key: "int64_list" + value { + int64_list { + value: 1 + } + } + } + } + """, + example_pb2.Example(), + ) + actual = proto_utils.dict_to_tf_example(data) + 
self.assertProtoEquals(expected, actual) - def test_multiple_examples_missing_key(self): - data = [{'a': 'a', 'b': 1}, {'b': 2}] - examples = [proto_utils.dict_to_tf_example(d) for d in data] - with self.assertRaisesRegex(ValueError, 'Missing keys'): - _ = proto_utils.tf_examples_to_dict(examples) + def test_dict_to_tf_example_bad_str_type(self): + data = {'str_arr': ['abc', b'def']} + with self.assertRaisesRegex(AssertionError, 'bad str type'): + _ = proto_utils.dict_to_tf_example(data) + + @parameterized.named_parameters( + dict( + testcase_name='bytes', + data={'bad_type': [b'ab', 'cd']}, + ), + dict( + testcase_name='float', + data={'bad_type': [1.0, 'a']}, + ), + dict( + testcase_name='int64', + data={'bad_type': [1, 'b']}, + ), + ) + def test_dict_to_tf_example_inconsistent_types(self, data): + # This test is required as the logic to determine the type of the feature + # list is based on the first value of the list. + with self.assertRaises(Exception): + _ = proto_utils.dict_to_tf_example(data) + + def test_dict_to_tf_example_unsupported_type(self): + data = {'bad_type': [example_pb2.Example()]} + with self.assertRaisesRegex( + TypeError, 'Values for "bad_type" is not a supported type.' + ): + _ = proto_utils.dict_to_tf_example(data) if __name__ == '__main__':