diff --git a/utoipa-gen/src/component/schema/enums.rs b/utoipa-gen/src/component/schema/enums.rs index ded9d36e..c8d722d8 100644 --- a/utoipa-gen/src/component/schema/enums.rs +++ b/utoipa-gen/src/component/schema/enums.rs @@ -450,7 +450,7 @@ impl MixedEnumContent { match &variant.fields { Fields::Named(named) => { - let (variant_tokens, references, _) = + let (variant_tokens, references, disc_variant) = MixedEnumContent::get_named_tokens_with_schema_references( root, MixedEnumVariant { @@ -464,6 +464,7 @@ impl MixedEnumContent { rename_all, )?; schema_references.extend(references); + discriminator_variant = disc_variant; variant_tokens.to_tokens(&mut tokens); } Fields::Unnamed(unnamed) => { @@ -486,14 +487,49 @@ impl MixedEnumContent { variant_tokens.to_tokens(&mut tokens); } Fields::Unit => { - let variant_tokens = MixedEnumContent::get_unit_tokens( - name, - variant_features, - serde_container, - variant_serde_rules, - rename_all, - ); - variant_tokens.to_tokens(&mut tokens); + #[cfg(feature = "tagged_discriminator")] + { + // OpenAPI discriminator mappings can only reference named ($ref) schemas, so + // an inline unit variant can never appear in the mapping — generated clients + // (e.g. Kiota) that dispatch deserialization on the mapping then fail on that + // variant at parse time. Synthesize a named `{Enum}{Variant}` component instead. + if let SerdeEnumRepr::InternallyTagged { tag } = &serde_container.enum_repr { + let enum_name = root.ident.to_string(); + let (variant_tokens, schema_reference, disc_variant) = + MixedEnumContent::get_unit_tokens_with_named_component( + &enum_name, + name, + variant_features, + tag, + serde_container, + variant_serde_rules, + rename_all, + ); + schema_references.push(schema_reference); + discriminator_variant = disc_variant; + variant_tokens.to_tokens(&mut tokens); + } else { + let variant_tokens = MixedEnumContent::get_unit_tokens( + name, + variant_features, + serde_container, + variant_serde_rules, + rename_all, + ); + variant_tokens.to_tokens(&mut tokens); + } + } + #[cfg(not(feature = "tagged_discriminator"))] + { + let variant_tokens = MixedEnumContent::get_unit_tokens( + name, + variant_features, + serde_container, + variant_serde_rules, + rename_all, + ); + variant_tokens.to_tokens(&mut tokens); + } } } @@ -527,6 +563,9 @@ impl MixedEnumContent { ); let name = renamed.unwrap_or(Cow::Owned(name)); + #[cfg(feature = "tagged_discriminator")] + let enum_name = root.ident.to_string(); + let root = &Root { ident: &variant.ident, attributes: &variant.attrs, @@ -554,28 +593,57 @@ impl MixedEnumContent { let schema = NamedStructSchema::new(root, fields, variant_features)?; let mut schema_tokens = schema.to_token_stream(); - ( - if schema.is_all_of { - let object_builder_tokens = - quote! { utoipa::openapi::schema::Object::builder() }; - let enum_schema_tokens = - EnumSchema::::tagged(object_builder_tokens) - .tag(tag, PlainSchema::for_name(name.as_ref())) - .features(enum_features) - .to_token_stream(); - schema_tokens.extend(quote! { - .item(#enum_schema_tokens) - }); - schema_tokens - } else { - EnumSchema::::tagged(schema_tokens) + let is_all_of = schema.is_all_of; + let variant_schema_tokens = if is_all_of { + let object_builder_tokens = + quote! { utoipa::openapi::schema::Object::builder() }; + let enum_schema_tokens = + EnumSchema::::tagged(object_builder_tokens) .tag(tag, PlainSchema::for_name(name.as_ref())) .features(enum_features) - .to_token_stream() - }, - schema.fields_references, - None, - ) + .to_token_stream(); + schema_tokens.extend(quote! { + .item(#enum_schema_tokens) + }); + schema_tokens + } else { + EnumSchema::::tagged(schema_tokens) + .tag(tag, PlainSchema::for_name(name.as_ref())) + .features(enum_features) + .to_token_stream() + }; + + // Same rationale as unit variants: lift the inline tagged object into a named + // `{Enum}{Variant}` component so it can participate in the discriminator + // mapping. Skipped for generic enums where instantiations would collide on the + // synthesized name. + #[cfg(feature = "tagged_discriminator")] + if root.generics.params.is_empty() { + let component_name = format!("{enum_name}{}", variant.ident); + let component_schema = if is_all_of { + quote! { utoipa::openapi::schema::Schema::AllOf(#variant_schema_tokens.build()) } + } else { + quote! { utoipa::openapi::schema::Schema::Object(#variant_schema_tokens.build()) } + }; + let mut references = schema.fields_references; + references.push(SchemaReference { + name: quote! { String::from(#component_name) }, + tokens: quote! { utoipa::openapi::RefOr::T(#component_schema) }, + references: TokenStream::new(), + is_inline: false, + no_recursion: false, + }); + return Ok(( + quote! { + utoipa::openapi::schema::RefBuilder::new() + .ref_location_from_schema_name(#component_name) + }, + references, + Some((name.as_ref().to_string(), quote! { #component_name })), + )); + } + + (variant_schema_tokens, schema.fields_references, None) } SerdeEnumRepr::Untagged => { let schema = NamedStructSchema::new(root, fields, variant_features)?; @@ -737,6 +805,60 @@ impl MixedEnumContent { Ok(tokens_with_schema_reference) } + /// Build a unit variant of an internally tagged enum as a `$ref` to a synthesized named + /// component (`{EnumName}{VariantName}`) instead of an inline object. OpenAPI discriminator + /// mappings can only reference named schemas, so inline unit variants can never appear in + /// the mapping — generated clients (e.g. Kiota) that dispatch deserialization on the mapping + /// then fail on those variants at parse time. + #[cfg(feature = "tagged_discriminator")] + fn get_unit_tokens_with_named_component( + enum_name: &str, + name: String, + mut variant_features: Vec, + tag: &str, + serde_container: &SerdeContainer, + variant_serde_rules: SerdeValue, + rename_all: Option<&RenameAll>, + ) -> (TokenStream, SchemaReference, Option<(String, TokenStream)>) { + let component_name = format!("{enum_name}{name}"); + let renamed = super::rename_enum_variant( + &name, + &mut variant_features, + &variant_serde_rules, + serde_container, + rename_all, + ); + let tag_value = renamed.unwrap_or(Cow::Owned(name)); + + let component_tokens = EnumSchema::::new(tag_value.as_ref()) + .tagged(tag) + .features(variant_features) + .to_token_stream(); + + let schema_reference = SchemaReference { + name: quote! { String::from(#component_name) }, + tokens: quote! { + utoipa::openapi::RefOr::T( + utoipa::openapi::schema::Schema::Object(#component_tokens.build()) + ) + }, + references: TokenStream::new(), + is_inline: false, + no_recursion: false, + }; + + let ref_tokens = quote! { + utoipa::openapi::schema::RefBuilder::new() + .ref_location_from_schema_name(#component_name) + }; + + ( + ref_tokens, + schema_reference, + Some((tag_value.into_owned(), quote! { #component_name })), + ) + } + fn get_unit_tokens( name: String, mut variant_features: Vec, diff --git a/utoipa-gen/tests/tagged_discriminator.rs b/utoipa-gen/tests/tagged_discriminator.rs index 80b4e845..264a4329 100644 --- a/utoipa-gen/tests/tagged_discriminator.rs +++ b/utoipa-gen/tests/tagged_discriminator.rs @@ -70,12 +70,12 @@ fn derive_enum_tagged_discriminator_complex() { let schema = ::schema(); let value = serde_json::to_value(schema).unwrap(); - // Inline variants are NOT added to discriminator mapping, but have the tag injected. - // Ref variants are added to mapping and are bare refs in oneOf. - + // Inline (struct) variants are lifted into synthesized `{Enum}{Variant}` components so they + // can be referenced from the discriminator mapping; ref variants are mapped directly. let expected = serde_json::json!({ "discriminator": { "mapping": { + "inlineVariant": "#/components/schemas/ComplexEnumInlineVariant", "renamed_variant": "#/components/schemas/Item" }, "propertyName": "kind" @@ -85,26 +85,105 @@ fn derive_enum_tagged_discriminator_complex() { "$ref": "#/components/schemas/Item" }, { - "type": "object", - "required": [ - "value", - "kind" - ], - "properties": { - "kind": { - "type": "string", - "enum": [ - "inlineVariant" - ] - }, - "value": { - "type": "integer", - "format": "int32" - } - } + "$ref": "#/components/schemas/ComplexEnumInlineVariant" } ] }); assert_eq!(value, expected); + + // The synthesized component must actually be registered, with the tag injected. + let mut schemas = Vec::new(); + ::schemas(&mut schemas); + let inline_variant = schemas + .iter() + .find(|(name, _)| name == "ComplexEnumInlineVariant") + .expect("synthesized ComplexEnumInlineVariant component must be registered"); + let inline_value = serde_json::to_value(&inline_variant.1).unwrap(); + let expected_component = serde_json::json!({ + "type": "object", + "required": [ + "value", + "kind" + ], + "properties": { + "kind": { + "type": "string", + "enum": [ + "inlineVariant" + ] + }, + "value": { + "type": "integer", + "format": "int32" + } + } + }); + assert_eq!(inline_value, expected_component); +} + +#[test] +#[cfg(feature = "tagged_discriminator")] +fn derive_enum_tagged_discriminator_unit_variant() { + #[derive(ToSchema, Serialize, Deserialize)] + #[serde(rename_all = "camelCase")] + struct ProviderOpenId { + provider_id: String, + } + + // Mirrors the shape that motivated this: a user identity-provider binding where the + // "no provider" case is a unit variant. + #[derive(ToSchema, Serialize, Deserialize)] + #[serde(tag = "type", rename_all = "camelCase")] + enum Provider { + None, + OpenId(ProviderOpenId), + } + + let schema = ::schema(); + let value = serde_json::to_value(schema).unwrap(); + + let expected = serde_json::json!({ + "discriminator": { + "mapping": { + "none": "#/components/schemas/ProviderNone", + "openId": "#/components/schemas/ProviderOpenId" + }, + "propertyName": "type" + }, + "oneOf": [ + { + "$ref": "#/components/schemas/ProviderNone" + }, + { + "$ref": "#/components/schemas/ProviderOpenId" + } + ] + }); + + assert_eq!(value, expected); + + // The synthesized unit-variant component must be registered as a tag-only object. + let mut schemas = Vec::new(); + ::schemas(&mut schemas); + let none_variant = schemas + .iter() + .find(|(name, _)| name == "ProviderNone") + .expect("synthesized ProviderNone component must be registered"); + let none_value = serde_json::to_value(&none_variant.1).unwrap(); + let expected_component = serde_json::json!({ + "type": "object", + "required": [ + "type" + ], + "properties": { + "type": { + "type": "string", + "enum": [ + "none" + ] + } + } + }); + assert_eq!(none_value, expected_component); }