Skip to content

Commit 782ccf7

Browse files
fix: discriminator-stripped structs in standalone arrays + Default trait
Two fixes for code generation correctness: 1. Array type aliases for structs whose discriminator field was stripped (because they appear in tagged enums) now generate a single-variant wrapper enum that re-adds the tag via serde. This fixes serialization of types like RequestTextBlock in Vec<RequestTextBlock> arrays where the `type` field was missing. 2. Required fields with default values whose types are discriminated unions (which don't derive Default) now generate as Option<T> instead of bare T with #[serde(default)], preventing compile errors. Both fixes have insta-verified snapshot tests. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent db73553 commit 782ccf7

File tree

4 files changed

+371
-6
lines changed

4 files changed

+371
-6
lines changed

src/generator.rs

Lines changed: 85 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ use std::path::PathBuf;
99
struct DiscriminatedVariantInfo {
1010
/// The discriminator field name (e.g., "type")
1111
discriminator_field: String,
12+
/// The const value of the discriminator (e.g., "text")
13+
discriminator_value: String,
1214
/// Whether the parent union is untagged
1315
is_parent_untagged: bool,
1416
}
@@ -206,6 +208,7 @@ impl CodeGenerator {
206208
variant.type_name.clone(),
207209
DiscriminatedVariantInfo {
208210
discriminator_field: discriminator_field.clone(),
211+
discriminator_value: variant.discriminator_value.clone(),
209212
is_parent_untagged,
210213
},
211214
);
@@ -603,8 +606,52 @@ impl CodeGenerator {
603606
}
604607
}
605608
SchemaType::Array { item_type } => {
606-
// Generate type alias for named array schemas
609+
// Generate type alias for named array schemas.
610+
//
611+
// Special case: if the array item is a struct whose discriminator
612+
// field was stripped (because it's used in a tagged enum), the bare
613+
// struct won't serialize the discriminator in standalone contexts.
614+
// Generate a single-variant tagged wrapper enum so the discriminator
615+
// field is re-added by serde's tag attribute.
607616
let array_name = format_ident!("{}", self.to_rust_type_name(&schema.name));
617+
618+
// Check if the item type is a Reference to a discriminator-stripped struct
619+
if let SchemaType::Reference { target } = item_type.as_ref() {
620+
if let Some(info) = discriminated_variant_info.get(target) {
621+
if !info.is_parent_untagged {
622+
// Generate a wrapper enum that re-adds the discriminator tag
623+
let wrapper_name = format_ident!(
624+
"{}Item",
625+
self.to_rust_type_name(&schema.name)
626+
);
627+
let variant_type =
628+
format_ident!("{}", self.to_rust_type_name(target));
629+
let disc_field = &info.discriminator_field;
630+
let disc_value = &info.discriminator_value;
631+
632+
let doc_comment = if let Some(desc) = &schema.description {
633+
quote! { #[doc = #desc] }
634+
} else {
635+
TokenStream::new()
636+
};
637+
638+
return Ok(quote! {
639+
/// Wrapper enum that re-adds the discriminator tag
640+
/// for array contexts where the inner struct had its
641+
/// discriminator field stripped for tagged enum use.
642+
#[derive(Debug, Clone, Deserialize, Serialize)]
643+
#[serde(tag = #disc_field)]
644+
pub enum #wrapper_name {
645+
#[serde(rename = #disc_value)]
646+
#variant_type(#variant_type),
647+
}
648+
#doc_comment
649+
pub type #array_name = Vec<#wrapper_name>;
650+
});
651+
}
652+
}
653+
}
654+
608655
let inner_type = self.generate_array_item_type(item_type, analysis);
609656

610657
let doc_comment = if let Some(desc) = &schema.description {
@@ -856,7 +903,7 @@ impl CodeGenerator {
856903
let field_type =
857904
self.generate_field_type(&schema.name, field_name, prop, is_required, analysis);
858905

859-
let serde_attrs = self.generate_serde_field_attrs(field_name, prop, is_required);
906+
let serde_attrs = self.generate_serde_field_attrs(field_name, prop, is_required, analysis);
860907
let specta_attrs = self.generate_specta_field_attrs(field_name);
861908

862909
let doc_comment = if let Some(desc) = &prop.description {
@@ -1235,7 +1282,13 @@ impl CodeGenerator {
12351282
.unwrap_or(false);
12361283

12371284
if is_required && !prop.nullable && !is_nullable_override {
1238-
base_type
1285+
// If the field has a default value but its type doesn't implement Default,
1286+
// wrap in Option<T> so serde can default to None instead of requiring Default.
1287+
if prop.default.is_some() && self.type_lacks_default(&prop.schema_type, analysis) {
1288+
quote! { Option<#base_type> }
1289+
} else {
1290+
base_type
1291+
}
12391292
} else {
12401293
quote! { Option<#base_type> }
12411294
}
@@ -1246,6 +1299,7 @@ impl CodeGenerator {
12461299
field_name: &str,
12471300
prop: &crate::analysis::PropertyInfo,
12481301
is_required: bool,
1302+
analysis: &crate::analysis::SchemaAnalysis,
12491303
) -> TokenStream {
12501304
let mut attrs = Vec::new();
12511305

@@ -1264,10 +1318,13 @@ impl CodeGenerator {
12641318
attrs.push(quote! { skip_serializing_if = "Option::is_none" });
12651319
}
12661320

1267-
// Only add default attribute for required fields that have default values
1268-
// Optional fields (Option<T>) already default to None, so don't need #[serde(default)]
1321+
// Only add default attribute for required fields that have default values.
1322+
// Skip #[serde(default)] for types that don't implement Default (discriminated
1323+
// unions, union enums) — those fields should be Option<T> instead.
12691324
if prop.default.is_some() && (is_required && !prop.nullable) {
1270-
attrs.push(quote! { default });
1325+
if !self.type_lacks_default(&prop.schema_type, analysis) {
1326+
attrs.push(quote! { default });
1327+
}
12711328
}
12721329

12731330
if attrs.is_empty() {
@@ -1277,6 +1334,28 @@ impl CodeGenerator {
12771334
}
12781335
}
12791336

1337+
/// Check if a schema type resolves to a type that doesn't implement `Default`.
1338+
/// Discriminated unions and union enums don't derive Default, so fields with
1339+
/// these types can't use `#[serde(default)]`.
1340+
fn type_lacks_default(
1341+
&self,
1342+
schema_type: &crate::analysis::SchemaType,
1343+
analysis: &crate::analysis::SchemaAnalysis,
1344+
) -> bool {
1345+
use crate::analysis::SchemaType;
1346+
match schema_type {
1347+
SchemaType::DiscriminatedUnion { .. } | SchemaType::Union { .. } => true,
1348+
SchemaType::Reference { target } => {
1349+
if let Some(schema) = analysis.schemas.get(target) {
1350+
self.type_lacks_default(&schema.schema_type, analysis)
1351+
} else {
1352+
false
1353+
}
1354+
}
1355+
_ => false,
1356+
}
1357+
}
1358+
12801359
fn generate_specta_field_attrs(&self, field_name: &str) -> TokenStream {
12811360
if !self.config.enable_specta {
12821361
return TokenStream::new();
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
---
2+
source: src/test_helpers.rs
3+
expression: "&generated_code"
4+
---
5+
//! Generated types from OpenAPI specification
6+
//!
7+
//! This file contains all the generated types for the API.
8+
//! Do not edit manually - regenerate using the appropriate script.
9+
#![allow(clippy::large_enum_variant)]
10+
#![allow(clippy::format_in_format_args)]
11+
#![allow(clippy::let_unit_value)]
12+
#![allow(unreachable_patterns)]
13+
use serde::{Deserialize, Serialize};
14+
#[derive(Debug, Clone, Deserialize, Serialize)]
15+
pub struct ToolResult {
16+
pub caller: Option<ToolResultCaller>,
17+
pub content: String,
18+
}
19+
#[derive(Debug, Clone, Deserialize, Serialize)]
20+
#[serde(tag = "type")]
21+
pub enum ToolResultCaller {
22+
#[serde(rename = "direct")]
23+
DirectCaller(DirectCaller),
24+
#[serde(rename = "server")]
25+
ServerCaller(ServerCaller),
26+
}
27+
#[derive(Debug, Clone, Deserialize, Serialize)]
28+
#[serde(tag = "type")]
29+
pub enum CallerType {
30+
#[serde(rename = "direct")]
31+
DirectCaller(DirectCaller),
32+
#[serde(rename = "server")]
33+
ServerCaller(ServerCaller),
34+
}
35+
#[derive(Debug, Clone, Deserialize, Serialize)]
36+
pub struct ServerCaller {
37+
#[serde(skip_serializing_if = "Option::is_none")]
38+
pub version: Option<String>,
39+
}
40+
#[derive(Debug, Clone, Deserialize, Serialize)]
41+
pub struct DirectCaller {}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
---
2+
source: src/test_helpers.rs
3+
expression: "&generated_code"
4+
---
5+
//! Generated types from OpenAPI specification
6+
//!
7+
//! This file contains all the generated types for the API.
8+
//! Do not edit manually - regenerate using the appropriate script.
9+
#![allow(clippy::large_enum_variant)]
10+
#![allow(clippy::format_in_format_args)]
11+
#![allow(clippy::let_unit_value)]
12+
#![allow(unreachable_patterns)]
13+
use serde::{Deserialize, Serialize};
14+
#[derive(Debug, Clone, Deserialize, Serialize)]
15+
pub struct CreateMessageParams {
16+
pub messages: Vec<InputContentBlock>,
17+
#[serde(skip_serializing_if = "Option::is_none")]
18+
pub system: Option<CreateMessageParamsSystem>,
19+
}
20+
/// Wrapper enum that re-adds the discriminator tag
21+
/// for array contexts where the inner struct had its
22+
/// discriminator field stripped for tagged enum use.
23+
#[derive(Debug, Clone, Deserialize, Serialize)]
24+
#[serde(tag = "type")]
25+
pub enum RequestTextBlockArrayItem {
26+
#[serde(rename = "text")]
27+
RequestTextBlock(RequestTextBlock),
28+
}
29+
///Array variant in union
30+
pub type RequestTextBlockArray = Vec<RequestTextBlockArrayItem>;
31+
#[derive(Debug, Clone, Deserialize, Serialize)]
32+
#[serde(tag = "type")]
33+
pub enum InputContentBlock {
34+
#[serde(rename = "text")]
35+
RequestTextBlock(RequestTextBlock),
36+
#[serde(rename = "image")]
37+
RequestImageBlock(RequestImageBlock),
38+
}
39+
#[derive(Debug, Clone, Deserialize, Serialize)]
40+
pub struct RequestTextBlock {
41+
#[serde(skip_serializing_if = "Option::is_none")]
42+
pub cache_control: Option<RequestTextBlockCacheControl>,
43+
pub text: String,
44+
}
45+
#[derive(Debug, Clone, Deserialize, Serialize)]
46+
pub struct RequestTextBlockCacheControl {
47+
#[serde(skip_serializing_if = "Option::is_none")]
48+
pub ttl: Option<String>,
49+
#[serde(skip_serializing_if = "Option::is_none")]
50+
pub r#type: Option<serde_json::Value>,
51+
}
52+
#[derive(Debug, Clone, Deserialize, Serialize)]
53+
pub struct RequestImageBlock {
54+
pub source: String,
55+
}
56+
pub type CreateMessageParamsSystemString = String;
57+
#[derive(Debug, Clone, Deserialize, Serialize)]
58+
#[serde(untagged)]
59+
pub enum CreateMessageParamsSystem {
60+
CreateMessageParamsSystemString(CreateMessageParamsSystemString),
61+
RequestTextBlockArray(RequestTextBlockArray),
62+
}

0 commit comments

Comments
 (0)