Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cot-cli/src/migration_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1690,6 +1690,7 @@ mod tests {
foreign_key: Some(ForeignKeySpec {
to_model: parse_quote!(crate::Table4),
}),
many_to_many: None,
}],
},
];
Expand Down
106 changes: 106 additions & 0 deletions cot-codegen/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ pub struct FieldOpts {
pub ty: syn::Type,
pub primary_key: darling::util::Flag,
pub unique: darling::util::Flag,
#[darling(default)]
pub many_to_many: Option<ManyToManyOpts>,
}

impl FieldOpts {
Expand Down Expand Up @@ -213,6 +215,17 @@ impl FieldOpts {
.map(ForeignKeySpec::try_from)
.transpose()?,
);

let many_to_many_spec = match ManyToManySpec::try_from(self.ty.clone()) {
Ok(mut spec) => {
if let Some(attr) = &self.many_to_many {
spec.attr = attr.clone();
}
Some(spec)
}
Err(_) => None,
};

let is_primary_key = self.primary_key.is_present();
let mut resolved_ty = self.ty.clone();
symbol_resolver.resolve(&mut resolved_ty, self_reference);
Expand All @@ -224,6 +237,7 @@ impl FieldOpts {
primary_key: is_primary_key,
foreign_key,
unique: self.unique.is_present(),
many_to_many: many_to_many_spec,
})
}
}
Expand Down Expand Up @@ -261,6 +275,7 @@ pub struct Field {
/// determined not to be a foreign key.
pub foreign_key: Option<ForeignKeySpec>,
pub unique: bool,
pub many_to_many: Option<ManyToManySpec>,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -310,6 +325,97 @@ impl TryFrom<syn::Type> for ForeignKeySpec {
}
}

use syn::{Type, TypeGroup, TypeParen, TypePath, TypeReference};

#[derive(Debug, Clone, FromMeta, Default, PartialEq, Eq, Hash)]
pub struct ManyToManyOpts {
#[darling(default)]
pub table: Option<String>,
#[darling(default)]
pub owner_field: Option<String>,
#[darling(default)]
pub target_field: Option<String>,
#[darling(default)]
pub through: Option<String>,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ManyToManySpec {
pub to_model: syn::Type,
pub target_table_name: String,
pub attr: ManyToManyOpts,
}

impl TryFrom<syn::Type> for ManyToManySpec {
type Error = syn::Error;

fn try_from(ty: syn::Type) -> Result<Self, Self::Error> {
let syn::Type::Path(type_path) = &ty else {
return Err(syn::Error::new(
ty.span(),
"expected a path type for ManyToMany",
));
};

let seg = type_path
.path
.segments
.last()
.expect("type path must have at least one segment");

let ident_str = seg.ident.to_string();
if ident_str != "ManyToMany" && !ident_str.ends_with("::ManyToMany") {
return Err(syn::Error::new(ty.span(), "expected ManyToMany<T>"));
}

let syn::PathArguments::AngleBracketed(args) = &seg.arguments else {
return Err(syn::Error::new(
ty.span(),
"expected ManyToMany to have angle-bracketed generic arguments",
));
};

if args.args.len() != 1 {
return Err(syn::Error::new(
ty.span(),
"expected ManyToMany to have only one generic parameter",
));
}

let inner = &args.args[0];
let inner_ty = if let syn::GenericArgument::Type(inner_ty) = inner {
inner_ty
} else {
return Err(syn::Error::new(
ty.span(),
"expected a type generic argument",
));
};

fn type_to_snake_name(ty: &Type) -> Option<String> {
match ty {
Type::Path(TypePath { path, .. }) => path
.segments
.last()
.map(|seg| seg.ident.to_string().to_snake_case()),
Type::Reference(TypeReference { elem, .. }) => type_to_snake_name(&*elem),
Type::Paren(TypeParen { elem, .. }) => type_to_snake_name(&*elem),
Type::Group(TypeGroup { elem, .. }) => type_to_snake_name(&*elem),
_ => None,
}
}

let target_table_name = type_to_snake_name(inner_ty)
.expect("Could not determine target table name from ManyToMany inner type");

Ok(ManyToManySpec {
to_model: inner_ty.clone(),
target_table_name,
attr: ManyToManyOpts::default(),
})
}
}

#[cfg(test)]
mod tests {
use syn::parse_quote;
Expand Down
72 changes: 71 additions & 1 deletion cot-macros/src/model.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use cot_codegen::model::{Field, Model, ModelArgs, ModelOpts, ModelType};
use cot_codegen::model::{Field, ManyToManySpec, Model, ModelArgs, ModelOpts, ModelType};
use cot_codegen::symbol_resolver::{SymbolResolver, VisibleSymbol, VisibleSymbolKind};
use darling::FromMeta;
use darling::ast::NestedMeta;
use heck::ToSnakeCase;
use proc_macro2::Literal;
use proc_macro2::{Ident, TokenStream};
use quote::{ToTokens, TokenStreamExt, format_ident, quote};
use syn::Token;
Expand Down Expand Up @@ -82,6 +83,7 @@ struct ModelBuilder {
fields_as_update_from_db: Vec<TokenStream>,
fields_as_get_values: Vec<TokenStream>,
fields_as_field_refs: Vec<TokenStream>,
fields_as_m2m_consts: Vec<TokenStream>,
}

impl ToTokens for ModelBuilder {
Expand Down Expand Up @@ -116,6 +118,7 @@ impl ModelBuilder {
fields_as_update_from_db: Vec::with_capacity(field_count),
fields_as_get_values: Vec::with_capacity(field_count),
fields_as_field_refs: Vec::with_capacity(field_count),
fields_as_m2m_consts: Vec::with_capacity(field_count),
};
for field in &model.fields {
model_builder.push_field(field);
Expand All @@ -132,6 +135,39 @@ impl ModelBuilder {
let index = self.fields_as_columns.len();
let column_name = &field.column_name;

if let Some(m2m_spec) = &field.many_to_many {
let target_ty = &m2m_spec.to_model;
self.fields_as_from_db.push(quote!(
#name: #orm_ident::ManyToMany::<#target_ty>::default()
));

self.fields_as_update_from_db.push(quote!(
_ => { /* many-to-many relation is not present in this row (stored in a join table) */ }
));

let (join_table, left_col, right_col) = self.infer_m2m_names(field, m2m_spec);
let join_table_lit = Literal::string(&join_table);
let left_col_lit = Literal::string(&left_col);
let right_col_lit = Literal::string(&right_col);

let const_ident = format_ident!("{}_M2M", name.to_string().to_uppercase());

let owner_ty = &self.name;

let m2m_const = quote!(
#[doc = concat!("Many-to-many metadata for the `", stringify!(#name), "` field.")]
pub const #const_ident: #orm_ident::ManyToManyField<#target_ty, #owner_ty> =
#orm_ident::ManyToManyField::new(
#join_table_lit,
#left_col_lit,
#right_col_lit,
);
);

self.fields_as_m2m_consts.push(m2m_const);
return;
}

{
let field_as_column = quote!(#orm_ident::Column::new(
#orm_ident::Identifier::new(#column_name)
Expand All @@ -158,6 +194,37 @@ impl ModelBuilder {
));
}

fn infer_m2m_names(&self, field: &Field, m2m: &ManyToManySpec) -> (String, String, String) {
// owner table as available in ModelBuilder
let owner_table = self.table_name.clone(); // already snake-cased + app namespace if needed
let owner_pk_col = self.pk_field.column_name.clone();

let target_table = m2m.target_table_name.clone();
let target_pk_col = "id".to_string();

let join_table = if let Some(t) = &m2m.attr.table {
t.clone()
} else {
format!("{}_{}", owner_table, field.name.to_string().to_snake_case())
};

let left_col = if let Some(l) = &m2m.attr.owner_field {
l.clone()
} else {
format!("{}_{}", owner_table, owner_pk_col)
};

let right_col = if let Some(r) = &m2m.attr.target_field {
r.clone()
} else if owner_table == target_table {
format!("{}_id", field.name.to_string().to_snake_case())
} else {
format!("{}_{}", target_table, target_pk_col)
};

(join_table, left_col, right_col)
}

#[must_use]
fn build_model_impl(&self) -> TokenStream {
let crate_ident = cot_ident();
Expand Down Expand Up @@ -242,6 +309,7 @@ impl ModelBuilder {
let vis = &self.vis;
let fields_struct_name = &self.fields_struct_name;
let fields_as_field_refs = &self.fields_as_field_refs;
let m2m_consts = &self.fields_as_m2m_consts;

quote! {
#[doc = concat!("Fields of the model [`", stringify!(#name), "`].")]
Expand All @@ -251,6 +319,8 @@ impl ModelBuilder {
#[expect(non_upper_case_globals)]
impl #fields_struct_name {
#(#fields_as_field_refs)*

#(#m2m_consts)*
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion cot/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ use derive_more::{Debug, Deref, Display};
#[cfg(test)]
use mockall::automock;
use query::Query;
pub use relations::{ForeignKey, ForeignKeyOnDeletePolicy, ForeignKeyOnUpdatePolicy};
pub use relations::{
ForeignKey, ForeignKeyOnDeletePolicy, ForeignKeyOnUpdatePolicy, ManyToMany, ManyToManyField,
};
use sea_query::{
ColumnRef, Iden, IntoColumnRef, OnConflict, ReturningClause, SchemaStatementBuilder, SimpleExpr,
};
Expand Down
2 changes: 1 addition & 1 deletion cot/src/db/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::db::impl_postgres::PostgresValueRef;
use crate::db::impl_sqlite::SqliteValueRef;
use crate::db::{
Auto, ColumnType, DatabaseError, DatabaseField, DbFieldValue, DbValue, ForeignKey, FromDbValue,
LimitedString, Model, PrimaryKey, Result, SqlxValueRef, ToDbFieldValue, ToDbValue,
LimitedString, ManyToMany, Model, PrimaryKey, Result, SqlxValueRef, ToDbFieldValue, ToDbValue,
};

mod chrono_wrapper;
Expand Down
74 changes: 74 additions & 0 deletions cot/src/db/relations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,80 @@ impl From<ForeignKeyOnUpdatePolicy> for sea_query::ForeignKeyAction {
}
}

#[derive(Debug, Clone)]
pub enum ManyToMany<T: Model> {
PrimaryKeys(Vec<T::PrimaryKey>),
Models(Vec<T>),
}

impl<T: Model> Default for ManyToMany<T> {
fn default() -> Self {
ManyToMany::PrimaryKeys(Vec::new())
}
}
impl<T: Model> ManyToMany<T> {
pub fn primary_keys(&self) -> Vec<&T::PrimaryKey> {
match self {
Self::PrimaryKeys(pks) => pks.iter().collect(),
Self::Models(models) => models.iter().map(|m| m.primary_key()).collect(),
}
}

pub fn models(&self) -> Option<&Vec<T>> {
match self {
Self::Models(models) => Some(models),
Self::PrimaryKeys(_) => None,
}
}
}

#[derive(Debug, Clone)]
pub struct ManyToManyField<T: Model, Owner: Model> {
pub join_table: &'static str,
pub owner_field: &'static str,
pub target_field: &'static str,
pub phantom: std::marker::PhantomData<(T, Owner)>,
}

impl<T: Model, Owner: Model> ManyToManyField<T, Owner> {
pub const fn new(
join_table: &'static str,
owner_field: &'static str,
target_field: &'static str,
) -> Self {
Self {
join_table,
owner_field,
target_field,
phantom: std::marker::PhantomData,
}
}

pub async fn get<'a, DB: DatabaseBackend>(
&self,
relation: &'a mut ManyToMany<T>,
db: &DB,
) -> Result<&'a Vec<T>> {
match relation {
ManyToMany::Models(m) => Ok(m),
ManyToMany::PrimaryKeys(pks) => {
let mut models = Vec::with_capacity(pks.len());
for pk in pks {
let model = T::get_by_primary_key(db, pk.clone())
.await?
.ok_or(DatabaseError::ForeignKeyNotFound)?;
models.push(model);
}
*relation = ManyToMany::Models(models);
match relation {
ManyToMany::Models(m) => Ok(m),
ManyToMany::PrimaryKeys(_) => unreachable!("models were just set"),
}
}
}
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading