diff --git a/core/wren-core-base/manifest-macro/src/lib.rs b/core/wren-core-base/manifest-macro/src/lib.rs index bde439ab9c..95b007031e 100644 --- a/core/wren-core-base/manifest-macro/src/lib.rs +++ b/core/wren-core-base/manifest-macro/src/lib.rs @@ -153,7 +153,7 @@ pub fn model(python_binding: proc_macro::TokenStream) -> proc_macro::TokenStream pub table_reference: Option, pub columns: Vec>, #[serde(default)] - pub primary_key: Option, + pub primary_key: Option, #[serde(default, with = "bool_from_int")] pub cached: bool, #[serde(default)] diff --git a/core/wren-core-base/src/mdl/builder.rs b/core/wren-core-base/src/mdl/builder.rs index 59f8977e75..a315abb6df 100644 --- a/core/wren-core-base/src/mdl/builder.rs +++ b/core/wren-core-base/src/mdl/builder.rs @@ -20,8 +20,8 @@ #![allow(dead_code)] use crate::mdl::manifest::{ - Column, Cube, CubeDimension, DataSource, JoinType, Manifest, Measure, Model, Relationship, - TimeDimension, View, + Column, Cube, CubeDimension, DataSource, JoinType, Manifest, Measure, Model, PrimaryKey, + Relationship, TimeDimension, View, }; use crate::mdl::{ColumnLevelOperator, NormalizedExpr, RowLevelAccessControl, SessionProperty}; use std::collections::BTreeMap; @@ -144,7 +144,31 @@ impl ModelBuilder { } pub fn primary_key(mut self, primary_key: &str) -> Self { - self.model.primary_key = Some(primary_key.to_string()); + assert!( + !primary_key.trim().is_empty(), + "primary_key must be a non-empty column name" + ); + self.model.primary_key = Some(PrimaryKey::Single(primary_key.to_string())); + self + } + + /// Set a composite primary key spanning multiple columns. + /// A single column collapses to [`PrimaryKey::Single`] so the serialized + /// form stays a plain string. + pub fn primary_keys(mut self, primary_keys: &[&str]) -> Self { + assert!( + !primary_keys.is_empty(), + "primary_keys must contain at least one column" + ); + assert!( + primary_keys.iter().all(|k| !k.trim().is_empty()), + "primary_keys cannot contain empty column names" + ); + self.model.primary_key = Some(if let [single] = primary_keys { + PrimaryKey::Single(single.to_string()) + } else { + PrimaryKey::Composite(primary_keys.iter().map(|s| s.to_string()).collect()) + }); self } diff --git a/core/wren-core-base/src/mdl/manifest.rs b/core/wren-core-base/src/mdl/manifest.rs index 15d67eba5d..631241450c 100644 --- a/core/wren-core-base/src/mdl/manifest.rs +++ b/core/wren-core-base/src/mdl/manifest.rs @@ -25,6 +25,7 @@ use std::sync::Arc; mod manifest_impl { use crate::mdl::manifest::bool_from_int; use crate::mdl::manifest::table_reference; + use crate::mdl::manifest::PrimaryKey; use manifest_macro::{ column, column_level_access_control, column_level_operator, cube, cube_dimension, data_source, join_type, manifest, measure, model, normalized_expr, normalized_expr_type, @@ -59,6 +60,7 @@ mod manifest_impl { mod manifest_impl { use crate::mdl::manifest::bool_from_int; use crate::mdl::manifest::table_reference; + use crate::mdl::manifest::PrimaryKey; use manifest_macro::{ column, column_level_access_control, column_level_operator, cube, cube_dimension, data_source, join_type, manifest, measure, model, normalized_expr, normalized_expr_type, @@ -93,7 +95,28 @@ mod manifest_impl { pub use crate::mdl::manifest::manifest_impl::*; -pub const MAX_SUPPORTED_LAYOUT_VERSION: u32 = 2; +/// The primary key of a [Model]. A model may declare either a single column +/// (`"primaryKey": "id"`) or a composite key (`"primaryKey": ["a", "b"]`). +/// The `#[serde(untagged)]` representation keeps the legacy single-string form +/// fully backward compatible. +#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, Hash, Clone)] +#[serde(untagged)] +pub enum PrimaryKey { + Single(String), + Composite(Vec), +} + +impl PrimaryKey { + /// All primary key columns in declaration order. + pub fn columns(&self) -> Vec<&str> { + match self { + PrimaryKey::Single(s) => vec![s.as_str()], + PrimaryKey::Composite(v) => v.iter().map(String::as_str).collect(), + } + } +} + +pub const MAX_SUPPORTED_LAYOUT_VERSION: u32 = 3; impl Manifest { pub fn validate_layout_version(&self) -> Result<(), LayoutVersionError> { @@ -377,9 +400,21 @@ impl Model { .map(Arc::clone) } - /// Return the primary key of the model + /// Return the first primary key column of the model. + /// For a composite key this is the first declared column; use + /// [`Model::primary_keys`] to get every column. pub fn primary_key(&self) -> Option<&str> { - self.primary_key.as_deref() + self.primary_key + .as_ref() + .and_then(|pk| pk.columns().into_iter().next()) + } + + /// Return all primary key columns of the model (empty if none declared). + pub fn primary_keys(&self) -> Vec<&str> { + self.primary_key + .as_ref() + .map(PrimaryKey::columns) + .unwrap_or_default() } /// Return the table reference of the model @@ -557,4 +592,75 @@ mod tests { model = ModelBuilder::new("empty_model").build(); assert!(matches!(model.source(), ModelSource::Invalid(_))); } + + #[test] + fn test_primary_key_serde() { + use crate::mdl::manifest::{Model, PrimaryKey}; + + // Legacy single-column form deserializes to Single and serializes back to a string. + let single: Model = + serde_json::from_str(r#"{"name":"customer","columns":[],"primaryKey":"c_custkey"}"#) + .unwrap(); + assert_eq!( + single.primary_key, + Some(PrimaryKey::Single("c_custkey".into())) + ); + assert_eq!(single.primary_key(), Some("c_custkey")); + assert_eq!(single.primary_keys(), vec!["c_custkey"]); + assert_eq!( + serde_json::to_value(&single.primary_key).unwrap(), + serde_json::json!("c_custkey") + ); + + // Composite form deserializes to Composite and serializes back to an array. + let composite: Model = serde_json::from_str( + r#"{"name":"partsupp","columns":[],"primaryKey":["ps_partkey","ps_suppkey"]}"#, + ) + .unwrap(); + assert_eq!( + composite.primary_key, + Some(PrimaryKey::Composite(vec![ + "ps_partkey".into(), + "ps_suppkey".into() + ])) + ); + assert_eq!(composite.primary_key(), Some("ps_partkey")); + assert_eq!(composite.primary_keys(), vec!["ps_partkey", "ps_suppkey"]); + assert_eq!( + serde_json::to_value(&composite.primary_key).unwrap(), + serde_json::json!(["ps_partkey", "ps_suppkey"]) + ); + + // Absent primary key. + let none: Model = serde_json::from_str(r#"{"name":"m","columns":[]}"#).unwrap(); + assert_eq!(none.primary_key(), None); + assert!(none.primary_keys().is_empty()); + + // Builder produces the composite form. + let model = ModelBuilder::new("partsupp") + .primary_keys(&["ps_partkey", "ps_suppkey"]) + .build(); + assert_eq!(model.primary_keys(), vec!["ps_partkey", "ps_suppkey"]); + + // A single-column `primary_keys` collapses to `Single` (serializes to a string). + let model = ModelBuilder::new("customer") + .primary_keys(&["c_custkey"]) + .build(); + assert_eq!( + model.primary_key, + Some(PrimaryKey::Single("c_custkey".into())) + ); + } + + #[test] + #[should_panic(expected = "non-empty")] + fn test_builder_rejects_empty_primary_key() { + ModelBuilder::new("m").primary_key(" ").build(); + } + + #[test] + #[should_panic(expected = "at least one")] + fn test_builder_rejects_empty_primary_keys() { + ModelBuilder::new("m").primary_keys(&[]).build(); + } } diff --git a/core/wren-core-base/src/mdl/migration.rs b/core/wren-core-base/src/mdl/migration.rs index 45a324122a..a2341861c8 100644 --- a/core/wren-core-base/src/mdl/migration.rs +++ b/core/wren-core-base/src/mdl/migration.rs @@ -82,6 +82,7 @@ pub fn migrate_manifest( for version in current..target_version { match version { 1 => migrate_v1_to_v2(&mut value), + 2 => migrate_v2_to_v3(&mut value), _ => { return Err(MigrationError::UnsupportedTargetVersion { target: target_version, @@ -102,6 +103,14 @@ fn migrate_v1_to_v2(_value: &mut Value) { // so existing manifests deserialize correctly without changes. } +/// v2→v3: No data transformation needed. +/// `primaryKey` accepts a composite array in addition to a single string; +/// existing single-string primary keys remain valid. +fn migrate_v2_to_v3(_value: &mut Value) { + // No-op: `primaryKey` is an untagged `string | array` enum, so existing + // single-column manifests deserialize correctly without changes. +} + #[cfg(test)] mod tests { use super::*; @@ -114,6 +123,26 @@ mod tests { assert_eq!(value["layoutVersion"], 2); } + #[test] + fn test_migrate_v2_to_v3() { + let v2_json = r#"{"layoutVersion":2,"catalog":"wren","schema":"public","models":[]}"#; + let result = migrate_manifest(v2_json, 3).unwrap(); + let value: Value = serde_json::from_str(&result).unwrap(); + assert_eq!(value["layoutVersion"], 3); + } + + #[test] + fn test_migrate_v1_to_v3_preserves_composite_pk() { + let v1_json = r#"{"catalog":"wren","schema":"public","models":[{"name":"partsupp","columns":[],"primaryKey":["ps_partkey","ps_suppkey"]}]}"#; + let result = migrate_manifest(v1_json, 3).unwrap(); + let value: Value = serde_json::from_str(&result).unwrap(); + assert_eq!(value["layoutVersion"], 3); + assert_eq!( + value["models"][0]["primaryKey"], + serde_json::json!(["ps_partkey", "ps_suppkey"]) + ); + } + #[test] fn test_migrate_already_at_target() { let v2_json = r#"{"layoutVersion":2,"catalog":"wren","schema":"public","models":[]}"#; diff --git a/core/wren-core/core/src/logical_plan/analyze/model_generation.rs b/core/wren-core/core/src/logical_plan/analyze/model_generation.rs index 4a33b424aa..f039209cd9 100644 --- a/core/wren-core/core/src/logical_plan/analyze/model_generation.rs +++ b/core/wren-core/core/src/logical_plan/analyze/model_generation.rs @@ -236,15 +236,18 @@ impl ModelGenerationRule { let name = alias.name.clone(); let ident = ident(rebased_measure.to_string()).alias(name.clone()); - let rebased_dimension = - rebase_column(&calculation_plan.dimensions[0], &plan_alias)?; - let project = vec![rebased_dimension.clone(), ident]; + // Group by every primary key dimension so composite-key + // calculations expose all key columns for the join back. + let rebased_dimensions = calculation_plan + .dimensions + .iter() + .map(|dimension| rebase_column(dimension, &plan_alias)) + .collect::>>()?; + let mut project = rebased_dimensions.clone(); + project.push(ident); let result = match source_plan { Some(plan) => LogicalPlanBuilder::from(plan) - .aggregate( - vec![rebased_dimension], - vec![rebased_measure], - )? + .aggregate(rebased_dimensions, vec![rebased_measure])? .project(project)? .build()?, _ => { diff --git a/core/wren-core/core/src/logical_plan/analyze/plan.rs b/core/wren-core/core/src/logical_plan/analyze/plan.rs index f6aa53fbd9..97345c0df7 100644 --- a/core/wren-core/core/src/logical_plan/analyze/plan.rs +++ b/core/wren-core/core/src/logical_plan/analyze/plan.rs @@ -236,31 +236,38 @@ impl ModelPlanNodeBuilder { expr, )?; self.required_calculation.push(calculation); - // insert the primary key to the required fields for join with the calculation - - let Some(pk_column) = model - .primary_key() - .and_then(|pk| model.get_visible_column(pk)) - else { + // insert the primary key(s) to the required fields for join with the calculation + let pk_names = model.primary_keys(); + if pk_names.is_empty() { return plan_err!( "Primary key not found for model {}. To use `TO_MANY` relationship, the primary key is required for the base model.", model.name() ); - }; - self.model_required_fields + } + let required_fields = self + .model_required_fields .entry(TableReference::full( self.analyzed_wren_mdl.wren_mdl().catalog(), self.analyzed_wren_mdl.wren_mdl().schema(), model.name(), )) - .or_default() - .insert(OrdExpr::new(Expr::Column( + .or_default(); + for pk in pk_names { + let Some(pk_column) = model.get_visible_column(pk) else { + return plan_err!( + "Primary key column {} not found for model {}", + pk, + model.name() + ); + }; + required_fields.insert(OrdExpr::new(Expr::Column( DFColumn::from_qualified_name(format!( "{}.{}", quoted(model.name()), quoted(pk_column.name()), )), ))); + } } else { merge_graph(&mut self.directed_graph, column_graph)?; if self.is_contain_calculation_source(&qualified_column) { @@ -389,24 +396,35 @@ impl ModelPlanNodeBuilder { for calculation_plan in calculate_iter { let target_ref = TableReference::bare(calculation_plan.name()); - let Some(join_key) = model.primary_key() else { + let join_keys = model.primary_keys(); + if join_keys.is_empty() { return plan_err!( "Model {} should have primary key for calculation", model.name() ); - }; + } + // Join the calculation on every primary key column. The conjunction is + // re-parsed downstream by `collect_join_keys`, so a composite key emits + // `pk1 = pk1 AND pk2 = pk2`. + let join_condition = join_keys + .iter() + .map(|join_key| { + format!( + "{}.{} = {}.{}", + quoted(model_ref.table()), + quoted(join_key), + quoted(target_ref.table()), + quoted(join_key), + ) + }) + .collect::>() + .join(" AND "); relation_chain = RelationChain::Chain( LogicalPlan::Extension(Extension { node: calculation_plan.as_ref(), }), JoinType::OneToOne, - format!( - "{}.{} = {}.{}", - quoted(model_ref.table()), - quoted(join_key), - quoted(target_ref.table()), - quoted(join_key), - ), + join_condition, Box::new(relation_chain), ); } @@ -1053,35 +1071,46 @@ impl CalculationPlanNode { let Some(model) = calculation.dataset.try_as_model() else { return plan_err!("Only support model as source dataset"); }; - let Some(pk_column) = model - .primary_key() - .and_then(|pk| model.get_visible_column(pk)) - else { + let pk_names = model.primary_keys(); + if pk_names.is_empty() { return plan_err!("Primary key not found"); - }; + } + let mut pk_columns = Vec::with_capacity(pk_names.len()); + for pk in pk_names { + let Some(pk_column) = model.get_visible_column(pk) else { + return plan_err!("Primary key not found"); + }; + pk_columns.push(pk_column); + } - // include calculation column and join key (pk) - let output_field = vec![ - Arc::new(Field::new( - calculation.column.name(), - try_map_data_type(&calculation.column.r#type)?, - calculation.column.not_null, - )), - Arc::new(Field::new( + // include calculation column and join key(s) (pk) + let mut output_field = vec![Arc::new(Field::new( + calculation.column.name(), + try_map_data_type(&calculation.column.r#type)?, + calculation.column.not_null, + ))]; + for pk_column in &pk_columns { + output_field.push(Arc::new(Field::new( pk_column.name(), try_map_data_type(&pk_column.r#type)?, pk_column.not_null, - )), - ] - .into_iter() - .map(|f| (Some(TableReference::bare(quoted(model.name()))), f)) - .collect(); - let dimensions = vec![create_wren_expr_for_model( - &pk_column.name, - Arc::clone(&model), - Arc::clone(&session_state_ref), - )? - .alias(pk_column.name())]; + ))); + } + let output_field = output_field + .into_iter() + .map(|f| (Some(TableReference::bare(quoted(model.name()))), f)) + .collect(); + let dimensions = pk_columns + .iter() + .map(|pk_column| { + Ok(create_wren_expr_for_model( + &pk_column.name, + Arc::clone(&model), + Arc::clone(&session_state_ref), + )? + .alias(pk_column.name())) + }) + .collect::>>()?; let schema_ref = DFSchemaRef::new( DFSchema::new_with_metadata(output_field, HashMap::new()) .expect("create schema failed"), diff --git a/core/wren-core/core/src/mdl/mod.rs b/core/wren-core/core/src/mdl/mod.rs index 157f45189e..3b82f3a320 100644 --- a/core/wren-core/core/src/mdl/mod.rs +++ b/core/wren-core/core/src/mdl/mod.rs @@ -4295,4 +4295,71 @@ mod test { ); Ok(()) } + + /// Verify that a TO_MANY calculation on a model with a *composite* primary key + /// joins the calculation back on EVERY primary key column (`pk1 = pk1 AND + /// pk2 = pk2`), exercising the primary-key consumption sites in `plan.rs`. + #[tokio::test] + async fn test_composite_key_calculation() -> Result<()> { + let ctx = create_wren_ctx(None, None); + let manifest = ManifestBuilder::new() + .catalog("wren") + .schema("test") + .model( + ModelBuilder::new("part") + .table_reference("part") + .column(ColumnBuilder::new("p_partkey", "int").build()) + .column(ColumnBuilder::new("p_suppkey", "int").build()) + .column( + ColumnBuilder::new_relationship( + "partsupp", + "partsupp", + "part_partsupp", + ) + .build(), + ) + .column( + ColumnBuilder::new_calculated("total_availqty", "int") + .expression("sum(partsupp.ps_availqty)") + .build(), + ) + .primary_keys(&["p_partkey", "p_suppkey"]) + .build(), + ) + .model( + ModelBuilder::new("partsupp") + .table_reference("partsupp") + .column(ColumnBuilder::new("ps_partkey", "int").build()) + .column(ColumnBuilder::new("ps_suppkey", "int").build()) + .column(ColumnBuilder::new("ps_availqty", "int").build()) + .primary_keys(&["ps_partkey", "ps_suppkey"]) + .build(), + ) + .relationship( + RelationshipBuilder::new("part_partsupp") + .model("part") + .model("partsupp") + .join_type(JoinType::OneToMany) + .condition( + "part.p_partkey = partsupp.ps_partkey AND \ + part.p_suppkey = partsupp.ps_suppkey", + ) + .build(), + ) + .build(); + let analyzed_mdl = Arc::new(AnalyzedWrenMDL::analyze( + manifest, + Arc::new(HashMap::default()), + Mode::Unparse, + )?); + let sql = "SELECT total_availqty FROM part"; + // The calculation must be joined back to `part` on BOTH primary key + // columns: `... ON totalavailqty.p_partkey = part.p_partkey AND + // totalavailqty.p_suppkey = part.p_suppkey`. + assert_snapshot!( + transform_sql_with_ctx(&ctx, Arc::clone(&analyzed_mdl), &[], Arc::new(HashMap::default()), sql).await?, + @r#"SELECT "part".total_availqty FROM (SELECT __relation__1.total_availqty FROM (SELECT total_availqty.p_partkey, total_availqty.p_suppkey, total_availqty.total_availqty FROM (SELECT __relation__1.p_partkey AS p_partkey, __relation__1.p_suppkey AS p_suppkey, sum(CAST(__relation__1.ps_availqty AS BIGINT)) AS total_availqty FROM (SELECT "part".p_partkey, "part".p_suppkey, partsupp.ps_availqty, partsupp.ps_partkey, partsupp.ps_suppkey FROM (SELECT partsupp.ps_availqty, partsupp.ps_partkey, partsupp.ps_suppkey FROM (SELECT partsupp.ps_availqty, partsupp.ps_partkey, partsupp.ps_suppkey FROM (SELECT __source.ps_availqty AS ps_availqty, __source.ps_partkey AS ps_partkey, __source.ps_suppkey AS ps_suppkey FROM partsupp AS __source) AS partsupp) AS partsupp) AS partsupp RIGHT OUTER JOIN (SELECT __source.p_partkey AS p_partkey, __source.p_suppkey AS p_suppkey FROM "part" AS __source) AS "part" ON partsupp.ps_partkey = "part".p_partkey AND partsupp.ps_suppkey = "part".p_suppkey) AS __relation__1 GROUP BY __relation__1.p_partkey, __relation__1.p_suppkey) AS total_availqty RIGHT OUTER JOIN (SELECT __source.p_partkey AS p_partkey, __source.p_suppkey AS p_suppkey FROM "part" AS __source) AS "part" ON total_availqty.p_partkey = "part".p_partkey AND total_availqty.p_suppkey = "part".p_suppkey) AS __relation__1) AS "part""# + ); + Ok(()) + } } diff --git a/core/wren-mdl/mdl.schema.json b/core/wren-mdl/mdl.schema.json index 33b2749749..078b637e27 100644 --- a/core/wren-mdl/mdl.schema.json +++ b/core/wren-mdl/mdl.schema.json @@ -311,8 +311,21 @@ } }, "primaryKey": { - "description": "the primary key of the model. It's required if the model is the one side of any OEN_TO_MANY or MANY_TO_ONE relationship", - "type": "string" + "description": "the primary key of the model. A single column (string) or a composite key (array of columns). It's required if the model is the one side of any ONE_TO_MANY or MANY_TO_ONE relationship", + "oneOf": [ + { + "type": "string", + "minLength": 1 + }, + { + "type": "array", + "items": { + "type": "string", + "minLength": 1 + }, + "minItems": 1 + } + ] }, "cached": { "description": "(WIP) whether the model is cached or not", diff --git a/core/wren/src/wren/context.py b/core/wren/src/wren/context.py index c5c95e106c..68387745d4 100644 --- a/core/wren/src/wren/context.py +++ b/core/wren/src/wren/context.py @@ -151,9 +151,9 @@ def convert_mdl_to_project(mdl_json: dict) -> list[ProjectFile]: # ── wren_project.yml ────────────────────────────────────── # Map layoutVersion back to schema_version layout_version = mdl_json.get("layoutVersion", 1) - _LAYOUT_TO_SCHEMA = {1: 2, 2: 3} + _LAYOUT_TO_SCHEMA = {1: 2, 2: 3, 3: 4} schema_version = _LAYOUT_TO_SCHEMA.get( - layout_version, 3 if layout_version >= 2 else 2 + layout_version, 4 if layout_version >= 3 else (3 if layout_version >= 2 else 2) ) project_config: dict[str, Any] = {"schema_version": schema_version} if "name" in mdl_json: @@ -419,10 +419,10 @@ def save_project_config(project_path: Path, config: dict) -> None: ) -_SUPPORTED_SCHEMA_VERSIONS = {1, 2, 3} +_SUPPORTED_SCHEMA_VERSIONS = {1, 2, 3, 4} # schema_version → layoutVersion mapping for the engine -_LAYOUT_VERSION_MAP = {1: 1, 2: 1, 3: 2} +_LAYOUT_VERSION_MAP = {1: 1, 2: 1, 3: 2, 4: 3} # Valid dialect values (matches Rust DataSource enum) _VALID_DIALECTS = { @@ -864,12 +864,39 @@ def validate_project(project_path: Path) -> list[ValidationError]: ) pk = model.get("primary_key") - if pk and pk not in col_names: + if pk is None: + pk_cols = [] + elif isinstance(pk, str): + pk_cols = [pk] + elif isinstance(pk, list) and all(isinstance(c, str) and c for c in pk) and pk: + pk_cols = pk + else: errors.append( ValidationError( "error", f"{src_path} > {name}", - f"primary_key '{pk}' not found in columns", + "primary_key must be a non-empty string or list of non-empty strings", + ) + ) + pk_cols = [] + for pk_col in pk_cols: + if pk_col not in col_names: + errors.append( + ValidationError( + "error", + f"{src_path} > {name}", + f"primary_key '{pk_col}' not found in columns", + ) + ) + + # Composite (list-form) primary_key is a layoutVersion 3 / schema_version 4 + # wire format that older engines cannot deserialize. + if isinstance(pk, list) and sv < 4: + errors.append( + ValidationError( + "warning", + f"{src_path} > {name}", + f"composite primary_key requires schema_version >= 4 (current: {sv})", ) ) @@ -1119,7 +1146,8 @@ def plan_upgrade( created, deleted = _plan_v1_to_v2(project_path) files_created.extend(created) files_deleted.extend(deleted) - # v2→v3: no file layout changes needed + # v2→v3 (dialect) and v3→v4 (composite primary_key): no file layout + # changes needed — only wren_project.yml is restamped. return UpgradeResult( from_version=current, diff --git a/core/wren/src/wren/osi.py b/core/wren/src/wren/osi.py index 5c29f7f1aa..5f4e470b1f 100644 --- a/core/wren/src/wren/osi.py +++ b/core/wren/src/wren/osi.py @@ -354,7 +354,7 @@ def _convert_field( *, dialect: str, type_override: str | None, - primary_key_name: str | None, + primary_key_names: set[str], ) -> dict: name = field_obj["name"] expr = _pick_expression(field_obj.get("expression"), dialect) @@ -370,7 +370,7 @@ def _convert_field( } if is_calc: column["expression"] = expr - if primary_key_name == name: + if name in primary_key_names: column["is_primary_key"] = True column["not_null"] = True if desc := _osi_description(field_obj): @@ -394,20 +394,6 @@ def _format_column_types_snippet(dataset_name: str, field_names: list[str]) -> s ) -def _format_composite_pk_snippet( - dataset_name: str, candidates: list[str], picked: str -) -> str: - """Snippet hint when wren falls back to first column of composite PK.""" - return ( - f" Wren picked {picked!r}. To override, add to dataset " - f"'{dataset_name}' in the OSI file:\n\n" - " custom_extensions:\n" - " - vendor_name: WREN\n" - f' data: \'{{"primary_key": "\"}}'" - ) - - def _convert_dataset( ds: dict, *, wren_cfg: WrenConfig ) -> tuple[dict, list[ValidationError]]: @@ -441,34 +427,24 @@ def _convert_dataset( } ds_pk_override = ds_wren.get("primary_key") - # Primary key — OSI allows composite arrays; wren takes a single string. + # Primary key — OSI allows a single string or a composite array; wren MDL + # supports both (a string for single, an array for composite). pk_raw = ds.get("primary_key") - pk_name: str | None = None + pk_names: list[str] = [] if isinstance(pk_raw, list): - if len(pk_raw) == 1: - pk_name = str(pk_raw[0]) - elif len(pk_raw) > 1: - candidates = [str(c) for c in pk_raw] - pick: str | None = None - if isinstance(ds_pk_override, str) and ds_pk_override in candidates: - pick = ds_pk_override - elif name in wren_cfg.primary_key_pick: - want = wren_cfg.primary_key_pick[name] - if want in candidates: - pick = want - pk_name = pick or candidates[0] - if pick is None: - errors.append( - ValidationError( - "warning", - f"dataset '{name}'", - f"composite primary_key {candidates} — Wren MDL " - f"takes one column.\n" - + _format_composite_pk_snippet(name, candidates, pk_name), - ) - ) + candidates = [str(c) for c in pk_raw if c] + # Explicit-narrowing escape hatch: a WREN override / primary_key_pick may + # select a single column out of a composite key. + pick: str | None = None + if isinstance(ds_pk_override, str) and ds_pk_override in candidates: + pick = ds_pk_override + elif name in wren_cfg.primary_key_pick: + want = wren_cfg.primary_key_pick[name] + if want in candidates: + pick = want + pk_names = [pick] if pick else candidates elif isinstance(pk_raw, str) and pk_raw: - pk_name = pk_raw + pk_names = [pk_raw] # Convert fields columns: list[dict] = [] @@ -491,7 +467,7 @@ def _convert_dataset( f, dialect=wren_cfg.dialect, type_override=type_override, - primary_key_name=pk_name, + primary_key_names=set(pk_names), ) columns.append(col) if type_override is None: @@ -519,8 +495,8 @@ def _convert_dataset( model["table_reference"] = table_ref else: model["ref_sql"] = ref_sql - if pk_name: - model["primary_key"] = pk_name + if pk_names: + model["primary_key"] = pk_names[0] if len(pk_names) == 1 else pk_names if desc := _osi_description(ds): model["properties"]["description"] = desc diff --git a/core/wren/tests/unit/test_context.py b/core/wren/tests/unit/test_context.py index 2350395877..eb0cabcb10 100644 --- a/core/wren/tests/unit/test_context.py +++ b/core/wren/tests/unit/test_context.py @@ -483,6 +483,77 @@ def test_validate_pk_not_in_columns(tmp_path): assert any("not found in columns" in e.message for e in errors) +def test_validate_pk_invalid_type(tmp_path): + _make_v2_project(tmp_path) + d = tmp_path / "models" / "orders" + d.mkdir(parents=True) + (d / "metadata.yml").write_text( + "name: orders\n" + "table_reference:\n table: orders\n" + "columns:\n - name: id\n type: INTEGER\n" + "primary_key: 123\n" + ) + # Must not raise (no TypeError) and must flag the malformed shape. + errors = validate_project(tmp_path) + assert any("must be a non-empty string or list" in e.message for e in errors) + + +def test_validate_composite_pk_missing_col(tmp_path): + _make_v2_project(tmp_path) + d = tmp_path / "models" / "orders" + d.mkdir(parents=True) + (d / "metadata.yml").write_text( + "name: orders\n" + "table_reference:\n table: orders\n" + "columns:\n - name: a\n type: INTEGER\n" + "primary_key:\n - a\n - missing_col\n" + ) + errors = validate_project(tmp_path) + assert any("primary_key 'missing_col' not found" in e.message for e in errors) + + +def test_validate_composite_pk_all_present(tmp_path): + _make_v2_project(tmp_path) + d = tmp_path / "models" / "orders" + d.mkdir(parents=True) + (d / "metadata.yml").write_text( + "name: orders\n" + "table_reference:\n table: orders\n" + "columns:\n - name: a\n type: INTEGER\n - name: b\n type: INTEGER\n" + "primary_key:\n - a\n - b\n" + ) + errors = validate_project(tmp_path) + assert not any("not found in columns" in e.message for e in errors) + + +def test_validate_composite_pk_requires_schema_version_4(tmp_path): + _make_v2_project(tmp_path, schema_version=3) + d = tmp_path / "models" / "orders" + d.mkdir(parents=True) + (d / "metadata.yml").write_text( + "name: orders\n" + "table_reference:\n table: orders\n" + "columns:\n - name: a\n type: INTEGER\n - name: b\n type: INTEGER\n" + "primary_key:\n - a\n - b\n" + ) + errors = validate_project(tmp_path) + assert any("requires schema_version >= 4" in e.message for e in errors) + + +def test_validate_composite_pk_ok_at_schema_version_4(tmp_path): + _make_v2_project(tmp_path, schema_version=4) + d = tmp_path / "models" / "orders" + d.mkdir(parents=True) + (d / "metadata.yml").write_text( + "name: orders\n" + "table_reference:\n table: orders\n" + "columns:\n - name: a\n type: INTEGER\n - name: b\n type: INTEGER\n" + "primary_key:\n - a\n - b\n" + ) + errors = validate_project(tmp_path) + assert not any("requires schema_version >= 4" in e.message for e in errors) + + def test_validate_relationship_unknown_model(tmp_path): _make_valid_project(tmp_path) (tmp_path / "relationships.yml").write_text( @@ -1021,7 +1092,7 @@ def test_plan_upgrade_above_target(tmp_path): def test_plan_upgrade_default_to_latest(tmp_path): _make_v1_project(tmp_path) result = plan_upgrade(tmp_path) - assert result.to_version == 3 + assert result.to_version == 4 def test_apply_upgrade_v1_to_v2(tmp_path): diff --git a/core/wren/tests/unit/test_context_cli.py b/core/wren/tests/unit/test_context_cli.py index 07a996b456..603ceff0f1 100644 --- a/core/wren/tests/unit/test_context_cli.py +++ b/core/wren/tests/unit/test_context_cli.py @@ -466,13 +466,13 @@ def _make_v1_project(tmp_path: Path) -> Path: return tmp_path -def test_upgrade_cli_v2_to_v3(tmp_path): +def test_upgrade_cli_default_to_latest(tmp_path): _make_valid_project(tmp_path) result = runner.invoke(app, ["context", "upgrade", "--path", str(tmp_path)]) assert result.exit_code == 0, result.output assert "Upgrade complete" in result.output config = yaml.safe_load((tmp_path / "wren_project.yml").read_text()) - assert config["schema_version"] == 3 + assert config["schema_version"] == 4 def test_upgrade_cli_dry_run(tmp_path): @@ -491,7 +491,7 @@ def test_upgrade_cli_dry_run(tmp_path): def test_upgrade_cli_already_current(tmp_path): (tmp_path / "wren_project.yml").write_text( - "schema_version: 3\nname: test\ndata_source: postgres\n" + "schema_version: 4\nname: test\ndata_source: postgres\n" ) result = runner.invoke(app, ["context", "upgrade", "--path", str(tmp_path)]) assert result.exit_code == 0 diff --git a/core/wren/tests/unit/test_osi.py b/core/wren/tests/unit/test_osi.py index ac579011aa..fbdb912923 100644 --- a/core/wren/tests/unit/test_osi.py +++ b/core/wren/tests/unit/test_osi.py @@ -382,16 +382,24 @@ def test_build_tpcds_full_runs(): assert len(manifest["relationships"]) == 4 -def test_build_tpcds_composite_pk_warning(): - """store_sales has composite PK [ss_item_sk, ss_ticket_number].""" - _, errors = build_manifest_from_osi( +def test_build_tpcds_composite_pk_preserved(): + """store_sales has composite PK [ss_item_sk, ss_ticket_number] — wren MDL now + keeps it as a list instead of downgrading to one column with a warning.""" + manifest, errors = build_manifest_from_osi( _fixture("tpcds_full.yaml"), data_source="postgres" ) + # No downgrade warning any more. composite_warns = [e for e in errors if "composite primary_key" in e.message] - assert len(composite_warns) == 1 - assert "store_sales" in composite_warns[0].path - # Snippet must guide the user to override - assert 'data: \'{"primary_key"' in composite_warns[0].message + assert composite_warns == [] + + store_sales = next(m for m in manifest["models"] if m["name"] == "store_sales") + assert store_sales["primary_key"] == ["ss_item_sk", "ss_ticket_number"] + # Every composite member that is a declared field is flagged on its column. + # (ss_ticket_number is referenced by the PK but not declared as a field in + # this fixture, so it has no column to flag.) + pk_cols = {c["name"] for c in store_sales["columns"] if c.get("is_primary_key")} + assert "ss_item_sk" in pk_cols + assert pk_cols <= {"ss_item_sk", "ss_ticket_number"} def test_build_tpcds_untyped_field_warnings_include_snippets():