|
20 | 20 |
|
21 | 21 | import base64 |
22 | 22 | import math |
23 | | -import os |
24 | | -import sys |
25 | 23 | import requests |
26 | 24 | from typing import List |
27 | 25 | from unittest import TestCase, main |
|
32 | 30 | from enum import Enum |
33 | 31 | import json |
34 | 32 | from fastavro.schema import load_schema |
| 33 | +from google.protobuf import descriptor_pb2, descriptor_pool, message_factory |
| 34 | + |
| 35 | + |
| 36 | +def _add_protobuf_field(message, name, number, field_type, type_name=None): |
| 37 | + field = message.field.add() |
| 38 | + field.name = name |
| 39 | + field.number = number |
| 40 | + field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL |
| 41 | + field.type = field_type |
| 42 | + if type_name: |
| 43 | + field.type_name = type_name |
| 44 | + |
| 45 | + |
| 46 | +def _get_message_classes(pool, message_names): |
| 47 | + if hasattr(message_factory, 'GetMessageClass'): |
| 48 | + return tuple( |
| 49 | + message_factory.GetMessageClass(pool.FindMessageTypeByName(message_name)) |
| 50 | + for message_name in message_names |
| 51 | + ) |
| 52 | + factory = message_factory.MessageFactory(pool) |
| 53 | + return tuple( |
| 54 | + factory.GetPrototype(pool.FindMessageTypeByName(message_name)) |
| 55 | + for message_name in message_names |
| 56 | + ) |
| 57 | + |
| 58 | + |
| 59 | +def _build_protobuf_test_messages(): |
| 60 | + file_proto = descriptor_pb2.FileDescriptorProto() |
| 61 | + file_proto.name = 'test_schema.proto' |
| 62 | + file_proto.package = 'test' |
| 63 | + file_proto.syntax = 'proto3' |
| 64 | + |
| 65 | + test_message = file_proto.message_type.add() |
| 66 | + test_message.name = 'TestMessage' |
| 67 | + _add_protobuf_field(test_message, 'name', 1, descriptor_pb2.FieldDescriptorProto.TYPE_STRING) |
| 68 | + _add_protobuf_field(test_message, 'value', 2, descriptor_pb2.FieldDescriptorProto.TYPE_INT32) |
| 69 | + |
| 70 | + nested_message = file_proto.message_type.add() |
| 71 | + nested_message.name = 'TestMessageWithNested' |
| 72 | + _add_protobuf_field(nested_message, 'str_field', 1, descriptor_pb2.FieldDescriptorProto.TYPE_STRING) |
| 73 | + _add_protobuf_field(nested_message, 'int_field', 2, descriptor_pb2.FieldDescriptorProto.TYPE_INT32) |
| 74 | + _add_protobuf_field(nested_message, 'double_field', 3, descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE) |
| 75 | + _add_protobuf_field( |
| 76 | + nested_message, 'nested', 4, descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE, '.test.TestInner' |
| 77 | + ) |
| 78 | + |
| 79 | + inner_message = file_proto.message_type.add() |
| 80 | + inner_message.name = 'TestInner' |
| 81 | + _add_protobuf_field(inner_message, 'inner_str', 1, descriptor_pb2.FieldDescriptorProto.TYPE_STRING) |
| 82 | + _add_protobuf_field(inner_message, 'inner_int', 2, descriptor_pb2.FieldDescriptorProto.TYPE_INT64) |
| 83 | + |
| 84 | + pool = descriptor_pool.DescriptorPool() |
| 85 | + pool.AddSerializedFile(file_proto.SerializeToString()) |
| 86 | + return _get_message_classes( |
| 87 | + pool, |
| 88 | + ('test.TestMessage', 'test.TestMessageWithNested', 'test.TestInner'), |
| 89 | + ) |
| 90 | + |
| 91 | + |
| 92 | +TestMessage, TestMessageWithNested, TestInner = _build_protobuf_test_messages() |
35 | 93 |
|
36 | | -# Make generated protobuf test classes importable |
37 | | -sys.path.insert(0, os.path.dirname(__file__)) |
38 | | -from test_schema_pb2 import TestMessage, TestMessageWithNested, TestInner |
39 | 94 |
|
40 | 95 | class ExampleRecord(Record): |
41 | 96 | str_field = String() |
|
0 commit comments