diff --git a/examples/restate_attribute_demo.rs b/examples/restate_attribute_demo.rs new file mode 100644 index 0000000..f1b9b09 --- /dev/null +++ b/examples/restate_attribute_demo.rs @@ -0,0 +1,99 @@ +use restate_sdk::prelude::*; + +/// This example demonstrates the new #[restate(...)] attribute syntax +/// for configuring handler behavior. +#[restate_sdk::object] +trait ProductInventory { + /// Shared handler with timeout configuration + /// This handler can run concurrently with others and has a 30-second inactivity timeout + #[restate(shared, inactivity_timeout = "30s")] + async fn get_stock() -> Result; + + /// Regular handler with custom timeouts + /// This handler has exclusive access to state and custom timeout settings + #[restate(inactivity_timeout = "1m", abort_timeout = "10s")] + async fn update_stock(quantity: u32) -> Result<(), TerminalError>; + + /// Shared handler with lazy state enabled + /// Useful for read-heavy operations where state is loaded on-demand + #[restate(shared = true, lazy_state = true, inactivity_timeout = "45s")] + async fn check_availability() -> Result; + + /// Regular handler without special configuration + async fn reserve_stock(quantity: u32) -> Result; + + /// Handler demonstrating different duration formats + #[restate(inactivity_timeout = "500ms", abort_timeout = "2h")] + async fn quick_check() -> Result<(), TerminalError>; +} + +struct ProductInventoryImpl; + +const STOCK_KEY: &str = "stock"; + +impl ProductInventory for ProductInventoryImpl { + async fn get_stock(&self, ctx: SharedObjectContext<'_>) -> Result { + println!("Getting stock"); + Ok(ctx.get::(STOCK_KEY).await?.unwrap_or(0)) + } + + async fn update_stock( + &self, + ctx: ObjectContext<'_>, + quantity: u32, + ) -> Result<(), TerminalError> { + println!("Updating stock to {}", quantity); + ctx.set(STOCK_KEY, quantity); + Ok(()) + } + + async fn check_availability( + &self, + ctx: SharedObjectContext<'_>, + ) -> Result { + println!("Checking availability"); + let stock = ctx.get::(STOCK_KEY).await?.unwrap_or(0); + Ok(stock > 0) + } + + async fn reserve_stock( + &self, + ctx: ObjectContext<'_>, + quantity: u32, + ) -> Result { + println!("Reserving {} units", quantity); + let current = ctx.get::(STOCK_KEY).await?.unwrap_or(0); + + if current >= quantity { + ctx.set(STOCK_KEY, current - quantity); + Ok(true) + } else { + Ok(false) + } + } + + async fn quick_check(&self, _ctx: ObjectContext<'_>) -> Result<(), TerminalError> { + println!("Quick check"); + Ok(()) + } +} + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + println!("Starting Product Inventory service with new #[restate(...)] attribute syntax"); + println!("This example demonstrates:"); + println!(" - Shared handlers with configurable timeouts"); + println!(" - Lazy state loading"); + println!(" - Various duration formats (ms, s, m, h)"); + println!(" - Mixing multiple configuration options"); + + HttpServer::new( + Endpoint::builder() + .bind(ProductInventoryImpl.serve()) + .build(), + ) + .listen_and_serve("0.0.0.0:9080".parse().unwrap()) + .await; +} diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 6a26bc8..ab7ae07 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -13,4 +13,5 @@ proc-macro = true [dependencies] proc-macro2 = "1.0" quote = "1.0" +humantime = "2.3" syn = { version = "2.0", features = ["full"] } diff --git a/macros/src/ast.rs b/macros/src/ast.rs index 010de80..7cbe4b6 100644 --- a/macros/src/ast.rs +++ b/macros/src/ast.rs @@ -11,12 +11,13 @@ // Some parts copied from https://github.com/dtolnay/thiserror/blob/39aaeb00ff270a49e3c254d7b38b10e934d3c7a5/impl/src/ast.rs // License Apache-2.0 or MIT +use quote::ToTokens; use syn::ext::IdentExt; use syn::parse::{Parse, ParseStream}; use syn::spanned::Spanned; use syn::token::Comma; use syn::{ - Attribute, Error, Expr, ExprLit, FnArg, GenericArgument, Ident, Lit, Pat, PatType, Path, + Attribute, Error, Expr, ExprLit, FnArg, GenericArgument, Ident, Lit, Meta, Pat, PatType, Path, PathArguments, Result, ReturnType, Token, Type, Visibility, braced, parenthesized, parse_quote, }; @@ -83,7 +84,7 @@ impl ServiceInner { while !content.is_empty() { let h: Handler = content.parse()?; - if h.is_shared && service_type == ServiceType::Service { + if h.config.shared && service_type == ServiceType::Service { return Err(Error::new( h.ident.span(), "Service handlers cannot be annotated with #[shared]", @@ -139,10 +140,78 @@ impl ServiceInner { } } +/// Parsed configuration from #[restate(...)] attribute +#[derive(Default)] +pub(crate) struct HandlerConfig { + pub(crate) shared: bool, + pub(crate) lazy_state: bool, + pub(crate) inactivity_timeout: Option, + pub(crate) abort_timeout: Option, +} + +/// Parse a duration string like "30s", "5m", "1h" into milliseconds +fn parse_duration_ms(s: &str, span: proc_macro2::Span) -> Result { + let duration: humantime::Duration = s + .trim() + .parse() + .map_err(|err| Error::new(span, format!("Failed to parse duration: {err}")))?; + + u64::try_from(duration.as_millis()).map_err(|_| Error::new(span, "Duration overflows u64")) +} + +/// Parse #[restate(...)] attribute +fn parse_restate_attr(attr: &Attribute) -> Result { + let mut result = HandlerConfig::default(); + + let meta_list = match &attr.meta { + Meta::List(list) => list, + _ => return Err(Error::new_spanned(attr, "Expected #[restate(...)]")), + }; + + // Parse the nested meta items + meta_list.parse_nested_meta(|meta| { + let path = &meta.path; + + // Check if this is a boolean flag (e.g., "shared") or named value (e.g., "shared=true") + if path.is_ident("shared") { + if meta.input.peek(Token![=]) { + // Parse as "shared = true" or "shared = false" + let value: syn::LitBool = meta.value()?.parse()?; + result.shared = value.value; + } else { + // Parse as just "shared" (flag syntax) + result.shared = true; + } + } else if path.is_ident("lazy_state") { + if meta.input.peek(Token![=]) { + let value: syn::LitBool = meta.value()?.parse()?; + result.lazy_state = value.value; + } else { + result.lazy_state = true; + } + } else if path.is_ident("inactivity_timeout") { + let value: syn::LitStr = meta.value()?.parse()?; + result.inactivity_timeout = Some(parse_duration_ms(&value.value(), value.span())?); + } else if path.is_ident("abort_timeout") { + let value: syn::LitStr = meta.value()?.parse()?; + result.abort_timeout = Some(parse_duration_ms(&value.value(), value.span())?); + } else { + return Err(meta.error(format!( + "Unknown restate attribute: {}. Supported attributes are: \ + shared, lazy_state, inactivity_timeout, abort_timeout", + path.to_token_stream() + ))); + } + + Ok(()) + })?; + + Ok(result) +} + pub(crate) struct Handler { pub(crate) attrs: Vec, - pub(crate) is_shared: bool, - pub(crate) is_lazy_state: bool, + pub(crate) config: HandlerConfig, pub(crate) restate_name: String, pub(crate) ident: Ident, pub(crate) arg: Option, @@ -198,7 +267,8 @@ impl Parse for Handler { ReturnType::Default => { return Err(Error::new( return_type.span(), - "The return type cannot be empty, only Result or restate_sdk::prelude::HandlerResult is supported as return type", + "The return type cannot be empty, only Result or \ + restate_sdk::prelude::HandlerResult is supported as return type", )); } ReturnType::Type(_, ty) => { @@ -214,15 +284,16 @@ impl Parse for Handler { }; // Process attributes - let mut is_shared = false; - let mut is_lazy_state = false; let mut restate_name = ident.to_string(); let mut attrs = vec![]; + let mut config = HandlerConfig::default(); + let mut has_legacy_shared = false; for attr in parsed_attrs { if is_shared_attr(&attr) { - is_shared = true; - } else if is_lazy_state_attr(&attr) { - is_lazy_state = true; + // support deprecated shared + has_legacy_shared = true; + } else if is_restate_attr(&attr) { + config = parse_restate_attr(&attr)?; } else if let Some(name) = read_literal_attribute_name(&attr)? { restate_name = name; } else { @@ -231,10 +302,13 @@ impl Parse for Handler { } } + if has_legacy_shared { + config.shared = true; + } + Ok(Self { attrs, - is_shared, - is_lazy_state, + config, restate_name, ident, arg: args.pop(), @@ -251,11 +325,12 @@ fn is_shared_attr(attr: &Attribute) -> bool { .is_ok_and(|i| i == "shared") } -fn is_lazy_state_attr(attr: &Attribute) -> bool { - attr.meta - .require_path_only() - .and_then(Path::require_ident) - .is_ok_and(|i| i == "lazy_state") +fn is_restate_attr(attr: &Attribute) -> bool { + if let Meta::List(list) = &attr.meta { + list.path.is_ident("restate") + } else { + false + } } fn read_literal_attribute_name(attr: &Attribute) -> Result> { diff --git a/macros/src/generator.rs b/macros/src/generator.rs index 929f021..f396681 100644 --- a/macros/src/generator.rs +++ b/macros/src/generator.rs @@ -1,4 +1,4 @@ -use crate::ast::{Handler, Object, Service, ServiceInner, ServiceType, Workflow}; +use crate::ast::{Handler, HandlerConfig, Object, Service, ServiceInner, ServiceType, Workflow}; use proc_macro2::TokenStream as TokenStream2; use proc_macro2::{Ident, Literal}; use quote::{ToTokens, format_ident, quote}; @@ -55,10 +55,10 @@ impl<'a> ServiceGenerator<'a> { let handler_fns = handlers .iter() .map( - |Handler { attrs, ident, arg, is_shared, output_ok, output_err, .. }| { + |Handler { attrs, ident, arg, config: HandlerConfig{shared, ..}, output_ok, output_err, .. }| { let args = arg.iter(); - let ctx = match (&service_ty, is_shared) { + let ctx = match (&service_ty, shared) { (ServiceType::Service, _) => quote! { ::restate_sdk::prelude::Context }, (ServiceType::Object, true) => quote! { ::restate_sdk::prelude::SharedObjectContext }, (ServiceType::Object, false) => quote! { ::restate_sdk::prelude::ObjectContext }, @@ -185,7 +185,7 @@ impl<'a> ServiceGenerator<'a> { let handlers = handlers.iter().map(|handler| { let handler_literal = Literal::string(&handler.restate_name); - let handler_ty = if handler.is_shared { + let handler_ty = if handler.config.shared { quote! { Some(::restate_sdk::discovery::HandlerType::Shared) } } else if *service_ty == ServiceType::Workflow { quote! { Some(::restate_sdk::discovery::HandlerType::Workflow) } @@ -194,12 +194,20 @@ impl<'a> ServiceGenerator<'a> { quote! { None } }; - let lazy_state = if handler.is_lazy_state { + let lazy_state = if handler.config.lazy_state { quote! { Some(true) } } else { quote! { None} }; + let inactivity_timeout = handler.config.inactivity_timeout.map(|timeout| { + quote! { Some(#timeout) } + }).unwrap_or_else(|| quote! { None }); + + let abort_timeout = handler.config.abort_timeout.map(|timeout| { + quote! { Some(#timeout) } + }).unwrap_or_else(|| quote! { None }); + let input_schema = match &handler.arg { Some(PatType { ty, .. }) => { quote! { @@ -229,8 +237,8 @@ impl<'a> ServiceGenerator<'a> { ty: #handler_ty, documentation: None, metadata: Default::default(), - abort_timeout: None, - inactivity_timeout: None, + abort_timeout: #abort_timeout, + inactivity_timeout: #inactivity_timeout, journal_retention: None, idempotency_retention: None, workflow_completion_retention: None, diff --git a/tests/backward_compat_test.rs b/tests/backward_compat_test.rs new file mode 100644 index 0000000..cab3f8e --- /dev/null +++ b/tests/backward_compat_test.rs @@ -0,0 +1,40 @@ +use restate_sdk::prelude::*; + +// Test backward compatibility with the old #[shared] attribute +#[restate_sdk::object] +trait BackwardCompatObject { + // Old syntax should still work + #[shared] + async fn legacy_shared() -> HandlerResult; + + // Mix old and new syntax (as long as they don't conflict) + #[shared] + #[restate(inactivity_timeout = "30s")] + async fn mixed_syntax() -> HandlerResult; + + // New syntax + #[restate(shared)] + async fn new_shared() -> HandlerResult; +} + +struct BackwardCompatObjectImpl; + +impl BackwardCompatObject for BackwardCompatObjectImpl { + async fn legacy_shared(&self, _ctx: SharedObjectContext<'_>) -> HandlerResult { + Ok("legacy".to_string()) + } + + async fn mixed_syntax(&self, _ctx: SharedObjectContext<'_>) -> HandlerResult { + Ok("mixed".to_string()) + } + + async fn new_shared(&self, _ctx: SharedObjectContext<'_>) -> HandlerResult { + Ok("new".to_string()) + } +} + +#[test] +fn test_backward_compatibility() { + // This test verifies backward compatibility with old attributes + let _service = BackwardCompatObjectImpl.serve(); +} diff --git a/tests/restate_attribute_test.rs b/tests/restate_attribute_test.rs new file mode 100644 index 0000000..c7c76a0 --- /dev/null +++ b/tests/restate_attribute_test.rs @@ -0,0 +1,86 @@ +use restate_sdk::prelude::*; + +// Test the new #[restate(...)] attribute with various configurations +#[restate_sdk::object] +trait TestObject { + // Test using just the shared flag + #[restate(shared)] + async fn test_shared_flag() -> HandlerResult; + + // Test using named boolean syntax + #[restate(shared = true)] + async fn test_shared_named() -> HandlerResult; + + // Test using lazy_state flag + #[restate(lazy_state)] + async fn test_lazy_state_flag() -> HandlerResult; + + // Test combining multiple attributes + #[restate(shared, lazy_state = false)] + async fn test_combined() -> HandlerResult; + + // Test with timeout configuration + #[restate(inactivity_timeout = "30s")] + async fn test_inactivity_timeout() -> HandlerResult<()>; + + // Test with multiple timeout configurations + #[restate(shared, inactivity_timeout = "30s", abort_timeout = "5m")] + async fn test_multiple_timeouts() -> HandlerResult; + + // Test with various duration formats + #[restate(inactivity_timeout = "100ms")] + async fn test_milliseconds() -> HandlerResult<()>; + + #[restate(abort_timeout = "2h")] + async fn test_hours() -> HandlerResult<()>; + + // Regular handler without any special config + async fn regular_handler() -> HandlerResult<()>; +} + +struct TestObjectImpl; + +impl TestObject for TestObjectImpl { + async fn test_shared_flag(&self, _ctx: SharedObjectContext<'_>) -> HandlerResult { + Ok("shared flag".to_string()) + } + + async fn test_shared_named(&self, _ctx: SharedObjectContext<'_>) -> HandlerResult { + Ok("shared named".to_string()) + } + + async fn test_lazy_state_flag(&self, _ctx: ObjectContext<'_>) -> HandlerResult { + Ok("lazy state flag".to_string()) + } + + async fn test_combined(&self, _ctx: SharedObjectContext<'_>) -> HandlerResult { + Ok("combined".to_string()) + } + + async fn test_inactivity_timeout(&self, _ctx: ObjectContext<'_>) -> HandlerResult<()> { + Ok(()) + } + + async fn test_multiple_timeouts(&self, _ctx: SharedObjectContext<'_>) -> HandlerResult { + Ok("multiple timeouts".to_string()) + } + + async fn test_milliseconds(&self, _ctx: ObjectContext<'_>) -> HandlerResult<()> { + Ok(()) + } + + async fn test_hours(&self, _ctx: ObjectContext<'_>) -> HandlerResult<()> { + Ok(()) + } + + async fn regular_handler(&self, _ctx: ObjectContext<'_>) -> HandlerResult<()> { + Ok(()) + } +} + +#[test] +fn test_restate_attribute_compiles() { + // This test just verifies that the code compiles correctly + // The actual functionality is tested through the discovery mechanism + let _service = TestObjectImpl.serve(); +}