diff --git a/cot-cli/src/migration_generator.rs b/cot-cli/src/migration_generator.rs index da3b538d..30292fb8 100644 --- a/cot-cli/src/migration_generator.rs +++ b/cot-cli/src/migration_generator.rs @@ -1690,6 +1690,7 @@ mod tests { foreign_key: Some(ForeignKeySpec { to_model: parse_quote!(crate::Table4), }), + many_to_many: None, }], }, ]; diff --git a/cot-codegen/src/model.rs b/cot-codegen/src/model.rs index 6610b070..5a91867b 100644 --- a/cot-codegen/src/model.rs +++ b/cot-codegen/src/model.rs @@ -145,6 +145,8 @@ pub struct FieldOpts { pub ty: syn::Type, pub primary_key: darling::util::Flag, pub unique: darling::util::Flag, + #[darling(default)] + pub many_to_many: Option, } impl FieldOpts { @@ -213,6 +215,17 @@ impl FieldOpts { .map(ForeignKeySpec::try_from) .transpose()?, ); + + let many_to_many_spec = match ManyToManySpec::try_from(self.ty.clone()) { + Ok(mut spec) => { + if let Some(attr) = &self.many_to_many { + spec.attr = attr.clone(); + } + Some(spec) + } + Err(_) => None, + }; + let is_primary_key = self.primary_key.is_present(); let mut resolved_ty = self.ty.clone(); symbol_resolver.resolve(&mut resolved_ty, self_reference); @@ -224,6 +237,7 @@ impl FieldOpts { primary_key: is_primary_key, foreign_key, unique: self.unique.is_present(), + many_to_many: many_to_many_spec, }) } } @@ -261,6 +275,7 @@ pub struct Field { /// determined not to be a foreign key. pub foreign_key: Option, pub unique: bool, + pub many_to_many: Option, } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -310,6 +325,97 @@ impl TryFrom for ForeignKeySpec { } } +use syn::{Type, TypeGroup, TypeParen, TypePath, TypeReference}; + +#[derive(Debug, Clone, FromMeta, Default, PartialEq, Eq, Hash)] +pub struct ManyToManyOpts { + #[darling(default)] + pub table: Option, + #[darling(default)] + pub owner_field: Option, + #[darling(default)] + pub target_field: Option, + #[darling(default)] + pub through: Option, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ManyToManySpec { + pub to_model: syn::Type, + pub target_table_name: String, + pub attr: ManyToManyOpts, +} + +impl TryFrom for ManyToManySpec { + type Error = syn::Error; + + fn try_from(ty: syn::Type) -> Result { + let syn::Type::Path(type_path) = &ty else { + return Err(syn::Error::new( + ty.span(), + "expected a path type for ManyToMany", + )); + }; + + let seg = type_path + .path + .segments + .last() + .expect("type path must have at least one segment"); + + let ident_str = seg.ident.to_string(); + if ident_str != "ManyToMany" && !ident_str.ends_with("::ManyToMany") { + return Err(syn::Error::new(ty.span(), "expected ManyToMany")); + } + + let syn::PathArguments::AngleBracketed(args) = &seg.arguments else { + return Err(syn::Error::new( + ty.span(), + "expected ManyToMany to have angle-bracketed generic arguments", + )); + }; + + if args.args.len() != 1 { + return Err(syn::Error::new( + ty.span(), + "expected ManyToMany to have only one generic parameter", + )); + } + + let inner = &args.args[0]; + let inner_ty = if let syn::GenericArgument::Type(inner_ty) = inner { + inner_ty + } else { + return Err(syn::Error::new( + ty.span(), + "expected a type generic argument", + )); + }; + + fn type_to_snake_name(ty: &Type) -> Option { + match ty { + Type::Path(TypePath { path, .. }) => path + .segments + .last() + .map(|seg| seg.ident.to_string().to_snake_case()), + Type::Reference(TypeReference { elem, .. }) => type_to_snake_name(&*elem), + Type::Paren(TypeParen { elem, .. }) => type_to_snake_name(&*elem), + Type::Group(TypeGroup { elem, .. }) => type_to_snake_name(&*elem), + _ => None, + } + } + + let target_table_name = type_to_snake_name(inner_ty) + .expect("Could not determine target table name from ManyToMany inner type"); + + Ok(ManyToManySpec { + to_model: inner_ty.clone(), + target_table_name, + attr: ManyToManyOpts::default(), + }) + } +} + #[cfg(test)] mod tests { use syn::parse_quote; diff --git a/cot-macros/src/model.rs b/cot-macros/src/model.rs index 91006c40..bd7711e8 100644 --- a/cot-macros/src/model.rs +++ b/cot-macros/src/model.rs @@ -1,8 +1,9 @@ -use cot_codegen::model::{Field, Model, ModelArgs, ModelOpts, ModelType}; +use cot_codegen::model::{Field, ManyToManySpec, Model, ModelArgs, ModelOpts, ModelType}; use cot_codegen::symbol_resolver::{SymbolResolver, VisibleSymbol, VisibleSymbolKind}; use darling::FromMeta; use darling::ast::NestedMeta; use heck::ToSnakeCase; +use proc_macro2::Literal; use proc_macro2::{Ident, TokenStream}; use quote::{ToTokens, TokenStreamExt, format_ident, quote}; use syn::Token; @@ -82,6 +83,7 @@ struct ModelBuilder { fields_as_update_from_db: Vec, fields_as_get_values: Vec, fields_as_field_refs: Vec, + fields_as_m2m_consts: Vec, } impl ToTokens for ModelBuilder { @@ -116,6 +118,7 @@ impl ModelBuilder { fields_as_update_from_db: Vec::with_capacity(field_count), fields_as_get_values: Vec::with_capacity(field_count), fields_as_field_refs: Vec::with_capacity(field_count), + fields_as_m2m_consts: Vec::with_capacity(field_count), }; for field in &model.fields { model_builder.push_field(field); @@ -132,6 +135,39 @@ impl ModelBuilder { let index = self.fields_as_columns.len(); let column_name = &field.column_name; + if let Some(m2m_spec) = &field.many_to_many { + let target_ty = &m2m_spec.to_model; + self.fields_as_from_db.push(quote!( + #name: #orm_ident::ManyToMany::<#target_ty>::default() + )); + + self.fields_as_update_from_db.push(quote!( + _ => { /* many-to-many relation is not present in this row (stored in a join table) */ } + )); + + let (join_table, left_col, right_col) = self.infer_m2m_names(field, m2m_spec); + let join_table_lit = Literal::string(&join_table); + let left_col_lit = Literal::string(&left_col); + let right_col_lit = Literal::string(&right_col); + + let const_ident = format_ident!("{}_M2M", name.to_string().to_uppercase()); + + let owner_ty = &self.name; + + let m2m_const = quote!( + #[doc = concat!("Many-to-many metadata for the `", stringify!(#name), "` field.")] + pub const #const_ident: #orm_ident::ManyToManyField<#target_ty, #owner_ty> = + #orm_ident::ManyToManyField::new( + #join_table_lit, + #left_col_lit, + #right_col_lit, + ); + ); + + self.fields_as_m2m_consts.push(m2m_const); + return; + } + { let field_as_column = quote!(#orm_ident::Column::new( #orm_ident::Identifier::new(#column_name) @@ -158,6 +194,37 @@ impl ModelBuilder { )); } + fn infer_m2m_names(&self, field: &Field, m2m: &ManyToManySpec) -> (String, String, String) { + // owner table as available in ModelBuilder + let owner_table = self.table_name.clone(); // already snake-cased + app namespace if needed + let owner_pk_col = self.pk_field.column_name.clone(); + + let target_table = m2m.target_table_name.clone(); + let target_pk_col = "id".to_string(); + + let join_table = if let Some(t) = &m2m.attr.table { + t.clone() + } else { + format!("{}_{}", owner_table, field.name.to_string().to_snake_case()) + }; + + let left_col = if let Some(l) = &m2m.attr.owner_field { + l.clone() + } else { + format!("{}_{}", owner_table, owner_pk_col) + }; + + let right_col = if let Some(r) = &m2m.attr.target_field { + r.clone() + } else if owner_table == target_table { + format!("{}_id", field.name.to_string().to_snake_case()) + } else { + format!("{}_{}", target_table, target_pk_col) + }; + + (join_table, left_col, right_col) + } + #[must_use] fn build_model_impl(&self) -> TokenStream { let crate_ident = cot_ident(); @@ -242,6 +309,7 @@ impl ModelBuilder { let vis = &self.vis; let fields_struct_name = &self.fields_struct_name; let fields_as_field_refs = &self.fields_as_field_refs; + let m2m_consts = &self.fields_as_m2m_consts; quote! { #[doc = concat!("Fields of the model [`", stringify!(#name), "`].")] @@ -251,6 +319,8 @@ impl ModelBuilder { #[expect(non_upper_case_globals)] impl #fields_struct_name { #(#fields_as_field_refs)* + + #(#m2m_consts)* } } } diff --git a/cot/src/db.rs b/cot/src/db.rs index 445dd972..78e383d0 100644 --- a/cot/src/db.rs +++ b/cot/src/db.rs @@ -27,7 +27,9 @@ use derive_more::{Debug, Deref, Display}; #[cfg(test)] use mockall::automock; use query::Query; -pub use relations::{ForeignKey, ForeignKeyOnDeletePolicy, ForeignKeyOnUpdatePolicy}; +pub use relations::{ + ForeignKey, ForeignKeyOnDeletePolicy, ForeignKeyOnUpdatePolicy, ManyToMany, ManyToManyField, +}; use sea_query::{ ColumnRef, Iden, IntoColumnRef, OnConflict, ReturningClause, SchemaStatementBuilder, SimpleExpr, }; diff --git a/cot/src/db/fields.rs b/cot/src/db/fields.rs index e273b8d8..4925e927 100644 --- a/cot/src/db/fields.rs +++ b/cot/src/db/fields.rs @@ -8,7 +8,7 @@ use crate::db::impl_postgres::PostgresValueRef; use crate::db::impl_sqlite::SqliteValueRef; use crate::db::{ Auto, ColumnType, DatabaseError, DatabaseField, DbFieldValue, DbValue, ForeignKey, FromDbValue, - LimitedString, Model, PrimaryKey, Result, SqlxValueRef, ToDbFieldValue, ToDbValue, + LimitedString, ManyToMany, Model, PrimaryKey, Result, SqlxValueRef, ToDbFieldValue, ToDbValue, }; mod chrono_wrapper; diff --git a/cot/src/db/relations.rs b/cot/src/db/relations.rs index 5ab6685f..6d14d335 100644 --- a/cot/src/db/relations.rs +++ b/cot/src/db/relations.rs @@ -197,6 +197,80 @@ impl From for sea_query::ForeignKeyAction { } } +#[derive(Debug, Clone)] +pub enum ManyToMany { + PrimaryKeys(Vec), + Models(Vec), +} + +impl Default for ManyToMany { + fn default() -> Self { + ManyToMany::PrimaryKeys(Vec::new()) + } +} +impl ManyToMany { + pub fn primary_keys(&self) -> Vec<&T::PrimaryKey> { + match self { + Self::PrimaryKeys(pks) => pks.iter().collect(), + Self::Models(models) => models.iter().map(|m| m.primary_key()).collect(), + } + } + + pub fn models(&self) -> Option<&Vec> { + match self { + Self::Models(models) => Some(models), + Self::PrimaryKeys(_) => None, + } + } +} + +#[derive(Debug, Clone)] +pub struct ManyToManyField { + pub join_table: &'static str, + pub owner_field: &'static str, + pub target_field: &'static str, + pub phantom: std::marker::PhantomData<(T, Owner)>, +} + +impl ManyToManyField { + pub const fn new( + join_table: &'static str, + owner_field: &'static str, + target_field: &'static str, + ) -> Self { + Self { + join_table, + owner_field, + target_field, + phantom: std::marker::PhantomData, + } + } + + pub async fn get<'a, DB: DatabaseBackend>( + &self, + relation: &'a mut ManyToMany, + db: &DB, + ) -> Result<&'a Vec> { + match relation { + ManyToMany::Models(m) => Ok(m), + ManyToMany::PrimaryKeys(pks) => { + let mut models = Vec::with_capacity(pks.len()); + for pk in pks { + let model = T::get_by_primary_key(db, pk.clone()) + .await? + .ok_or(DatabaseError::ForeignKeyNotFound)?; + models.push(model); + } + *relation = ManyToMany::Models(models); + match relation { + ManyToMany::Models(m) => Ok(m), + ManyToMany::PrimaryKeys(_) => unreachable!("models were just set"), + } + } + } + } +} + #[cfg(test)] mod tests { use super::*;