diff --git a/src/openai/_models.py b/src/openai/_models.py index ed4c1f82d6..709b855e19 100644 --- a/src/openai/_models.py +++ b/src/openai/_models.py @@ -657,7 +657,11 @@ def construct_type(*, value: object, type_: object, metadata: Optional[List[Any] if not is_mapping(value): return value - _, items_type = get_args(type_) # Dict[_, items_type] + args = get_args(type_) + if len(args) < 2: + return value + + _, items_type = args # Dict[_, items_type] return {key: construct_type(value=item, type_=items_type) for key, item in value.items()} if ( diff --git a/src/openai/_utils/_transform.py b/src/openai/_utils/_transform.py index 414f38c340..575c609cc2 100644 --- a/src/openai/_utils/_transform.py +++ b/src/openai/_utils/_transform.py @@ -180,7 +180,11 @@ def _transform_recursive( return _transform_typeddict(data, stripped_type) if origin == dict and is_mapping(data): - items_type = get_args(stripped_type)[1] + args = get_args(stripped_type) + if len(args) < 2: + return cast(object, data) + + items_type = args[1] return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()} if ( @@ -346,7 +350,11 @@ async def _async_transform_recursive( return await _async_transform_typeddict(data, stripped_type) if origin == dict and is_mapping(data): - items_type = get_args(stripped_type)[1] + args = get_args(stripped_type) + if len(args) < 2: + return cast(object, data) + + items_type = args[1] return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()} if ( diff --git a/tests/test_models.py b/tests/test_models.py index cc204bac1d..49b5fe6b6a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -133,6 +133,12 @@ class NestedModel(BaseModel): assert cast(Any, m.nested) is False +def test_bare_dictionary_annotation() -> None: + value = {"hello": "world"} + + assert construct_type(value=value, type_=dict) is value + + def test_nested_dictionary_model() -> None: class NestedModel(BaseModel): nested: Dict[str, BasicModel] diff --git a/tests/test_transform.py b/tests/test_transform.py index bece75dfc7..302982f1f2 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -161,6 +161,18 @@ async def test_ignores_invalid_input(use_async: bool) -> None: assert await transform({"foo": ""}, Foo7, use_async) == {"foo": ""} +class BareDictTypedDict(TypedDict, total=False): + metadata: dict + + +@parametrize +@pytest.mark.asyncio +async def test_bare_dict_annotation(use_async: bool) -> None: + data = {"metadata": {"hello": "world"}} + + assert await transform(data, BareDictTypedDict, use_async) == data + + class DatetimeDict(TypedDict, total=False): foo: Annotated[datetime, PropertyInfo(format="iso8601")]