Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions examples/restate_attribute_demo.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use restate_sdk::prelude::*;

/// This example demonstrates the new #[restate(...)] attribute syntax
/// for configuring handler behavior.
#[restate_sdk::object]
trait ProductInventory {
/// Shared handler with timeout configuration
/// This handler can run concurrently with others and has a 30-second inactivity timeout
#[restate(shared, inactivity_timeout = "30s")]
async fn get_stock() -> Result<u32, TerminalError>;

/// Regular handler with custom timeouts
/// This handler has exclusive access to state and custom timeout settings
#[restate(inactivity_timeout = "1m", abort_timeout = "10s")]
async fn update_stock(quantity: u32) -> Result<(), TerminalError>;

/// Shared handler with lazy state enabled
/// Useful for read-heavy operations where state is loaded on-demand
#[restate(shared = true, lazy_state = true, inactivity_timeout = "45s")]
async fn check_availability() -> Result<bool, TerminalError>;

/// Regular handler without special configuration
async fn reserve_stock(quantity: u32) -> Result<bool, TerminalError>;

/// Handler demonstrating different duration formats
#[restate(inactivity_timeout = "500ms", abort_timeout = "2h")]
async fn quick_check() -> Result<(), TerminalError>;
}

struct ProductInventoryImpl;

const STOCK_KEY: &str = "stock";

impl ProductInventory for ProductInventoryImpl {
async fn get_stock(&self, ctx: SharedObjectContext<'_>) -> Result<u32, TerminalError> {
println!("Getting stock");
Ok(ctx.get::<u32>(STOCK_KEY).await?.unwrap_or(0))
}

async fn update_stock(
&self,
ctx: ObjectContext<'_>,
quantity: u32,
) -> Result<(), TerminalError> {
println!("Updating stock to {}", quantity);
ctx.set(STOCK_KEY, quantity);
Ok(())
}

async fn check_availability(
&self,
ctx: SharedObjectContext<'_>,
) -> Result<bool, TerminalError> {
println!("Checking availability");
let stock = ctx.get::<u32>(STOCK_KEY).await?.unwrap_or(0);
Ok(stock > 0)
}

async fn reserve_stock(
&self,
ctx: ObjectContext<'_>,
quantity: u32,
) -> Result<bool, TerminalError> {
println!("Reserving {} units", quantity);
let current = ctx.get::<u32>(STOCK_KEY).await?.unwrap_or(0);

if current >= quantity {
ctx.set(STOCK_KEY, current - quantity);
Ok(true)
} else {
Ok(false)
}
}

async fn quick_check(&self, _ctx: ObjectContext<'_>) -> Result<(), TerminalError> {
println!("Quick check");
Ok(())
}
}

#[tokio::main]
async fn main() {
tracing_subscriber::fmt::init();

println!("Starting Product Inventory service with new #[restate(...)] attribute syntax");
println!("This example demonstrates:");
println!(" - Shared handlers with configurable timeouts");
println!(" - Lazy state loading");
println!(" - Various duration formats (ms, s, m, h)");
println!(" - Mixing multiple configuration options");

HttpServer::new(
Endpoint::builder()
.bind(ProductInventoryImpl.serve())
.build(),
)
.listen_and_serve("0.0.0.0:9080".parse().unwrap())
.await;
}
1 change: 1 addition & 0 deletions macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ proc-macro = true
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
humantime = "2.3"
syn = { version = "2.0", features = ["full"] }
109 changes: 92 additions & 17 deletions macros/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
// Some parts copied from https://github.com/dtolnay/thiserror/blob/39aaeb00ff270a49e3c254d7b38b10e934d3c7a5/impl/src/ast.rs
// License Apache-2.0 or MIT

use quote::ToTokens;
use syn::ext::IdentExt;
use syn::parse::{Parse, ParseStream};
use syn::spanned::Spanned;
use syn::token::Comma;
use syn::{
Attribute, Error, Expr, ExprLit, FnArg, GenericArgument, Ident, Lit, Pat, PatType, Path,
Attribute, Error, Expr, ExprLit, FnArg, GenericArgument, Ident, Lit, Meta, Pat, PatType, Path,
PathArguments, Result, ReturnType, Token, Type, Visibility, braced, parenthesized, parse_quote,
};

Expand Down Expand Up @@ -83,7 +84,7 @@ impl ServiceInner {
while !content.is_empty() {
let h: Handler = content.parse()?;

if h.is_shared && service_type == ServiceType::Service {
if h.config.shared && service_type == ServiceType::Service {
return Err(Error::new(
h.ident.span(),
"Service handlers cannot be annotated with #[shared]",
Expand Down Expand Up @@ -139,10 +140,78 @@ impl ServiceInner {
}
}

/// Parsed configuration from #[restate(...)] attribute
#[derive(Default)]
pub(crate) struct HandlerConfig {
pub(crate) shared: bool,
pub(crate) lazy_state: bool,
pub(crate) inactivity_timeout: Option<u64>,
pub(crate) abort_timeout: Option<u64>,
}

/// Parse a duration string like "30s", "5m", "1h" into milliseconds
fn parse_duration_ms(s: &str, span: proc_macro2::Span) -> Result<u64> {
let duration: humantime::Duration = s
.trim()
.parse()
.map_err(|err| Error::new(span, format!("Failed to parse duration: {err}")))?;

u64::try_from(duration.as_millis()).map_err(|_| Error::new(span, "Duration overflows u64"))
}

/// Parse #[restate(...)] attribute
fn parse_restate_attr(attr: &Attribute) -> Result<HandlerConfig> {
let mut result = HandlerConfig::default();

let meta_list = match &attr.meta {
Meta::List(list) => list,
_ => return Err(Error::new_spanned(attr, "Expected #[restate(...)]")),
};

// Parse the nested meta items
meta_list.parse_nested_meta(|meta| {
let path = &meta.path;

// Check if this is a boolean flag (e.g., "shared") or named value (e.g., "shared=true")
if path.is_ident("shared") {
if meta.input.peek(Token![=]) {
// Parse as "shared = true" or "shared = false"
let value: syn::LitBool = meta.value()?.parse()?;
result.shared = value.value;
} else {
// Parse as just "shared" (flag syntax)
result.shared = true;
}
} else if path.is_ident("lazy_state") {
if meta.input.peek(Token![=]) {
let value: syn::LitBool = meta.value()?.parse()?;
result.lazy_state = value.value;
} else {
result.lazy_state = true;
}
} else if path.is_ident("inactivity_timeout") {
let value: syn::LitStr = meta.value()?.parse()?;
result.inactivity_timeout = Some(parse_duration_ms(&value.value(), value.span())?);
} else if path.is_ident("abort_timeout") {
let value: syn::LitStr = meta.value()?.parse()?;
result.abort_timeout = Some(parse_duration_ms(&value.value(), value.span())?);
} else {
return Err(meta.error(format!(
"Unknown restate attribute: {}. Supported attributes are: \
shared, lazy_state, inactivity_timeout, abort_timeout",
path.to_token_stream()
)));
}

Ok(())
})?;

Ok(result)
}

pub(crate) struct Handler {
pub(crate) attrs: Vec<Attribute>,
pub(crate) is_shared: bool,
pub(crate) is_lazy_state: bool,
pub(crate) config: HandlerConfig,
pub(crate) restate_name: String,
pub(crate) ident: Ident,
pub(crate) arg: Option<PatType>,
Expand Down Expand Up @@ -198,7 +267,8 @@ impl Parse for Handler {
ReturnType::Default => {
return Err(Error::new(
return_type.span(),
"The return type cannot be empty, only Result or restate_sdk::prelude::HandlerResult is supported as return type",
"The return type cannot be empty, only Result or \
restate_sdk::prelude::HandlerResult is supported as return type",
));
}
ReturnType::Type(_, ty) => {
Expand All @@ -214,15 +284,16 @@ impl Parse for Handler {
};

// Process attributes
let mut is_shared = false;
let mut is_lazy_state = false;
let mut restate_name = ident.to_string();
let mut attrs = vec![];
let mut config = HandlerConfig::default();
let mut has_legacy_shared = false;
for attr in parsed_attrs {
if is_shared_attr(&attr) {
is_shared = true;
} else if is_lazy_state_attr(&attr) {
is_lazy_state = true;
// support deprecated shared
has_legacy_shared = true;
} else if is_restate_attr(&attr) {
config = parse_restate_attr(&attr)?;
} else if let Some(name) = read_literal_attribute_name(&attr)? {
restate_name = name;
} else {
Expand All @@ -231,10 +302,13 @@ impl Parse for Handler {
}
}

if has_legacy_shared {
config.shared = true;
}

Ok(Self {
attrs,
is_shared,
is_lazy_state,
config,
restate_name,
ident,
arg: args.pop(),
Expand All @@ -251,11 +325,12 @@ fn is_shared_attr(attr: &Attribute) -> bool {
.is_ok_and(|i| i == "shared")
}

fn is_lazy_state_attr(attr: &Attribute) -> bool {
attr.meta
.require_path_only()
.and_then(Path::require_ident)
.is_ok_and(|i| i == "lazy_state")
fn is_restate_attr(attr: &Attribute) -> bool {
if let Meta::List(list) = &attr.meta {
list.path.is_ident("restate")
} else {
false
}
}

fn read_literal_attribute_name(attr: &Attribute) -> Result<Option<String>> {
Expand Down
22 changes: 15 additions & 7 deletions macros/src/generator.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::ast::{Handler, Object, Service, ServiceInner, ServiceType, Workflow};
use crate::ast::{Handler, HandlerConfig, Object, Service, ServiceInner, ServiceType, Workflow};
use proc_macro2::TokenStream as TokenStream2;
use proc_macro2::{Ident, Literal};
use quote::{ToTokens, format_ident, quote};
Expand Down Expand Up @@ -55,10 +55,10 @@ impl<'a> ServiceGenerator<'a> {
let handler_fns = handlers
.iter()
.map(
|Handler { attrs, ident, arg, is_shared, output_ok, output_err, .. }| {
|Handler { attrs, ident, arg, config: HandlerConfig{shared, ..}, output_ok, output_err, .. }| {
let args = arg.iter();

let ctx = match (&service_ty, is_shared) {
let ctx = match (&service_ty, shared) {
(ServiceType::Service, _) => quote! { ::restate_sdk::prelude::Context },
(ServiceType::Object, true) => quote! { ::restate_sdk::prelude::SharedObjectContext },
(ServiceType::Object, false) => quote! { ::restate_sdk::prelude::ObjectContext },
Expand Down Expand Up @@ -185,7 +185,7 @@ impl<'a> ServiceGenerator<'a> {
let handlers = handlers.iter().map(|handler| {
let handler_literal = Literal::string(&handler.restate_name);

let handler_ty = if handler.is_shared {
let handler_ty = if handler.config.shared {
quote! { Some(::restate_sdk::discovery::HandlerType::Shared) }
} else if *service_ty == ServiceType::Workflow {
quote! { Some(::restate_sdk::discovery::HandlerType::Workflow) }
Expand All @@ -194,12 +194,20 @@ impl<'a> ServiceGenerator<'a> {
quote! { None }
};

let lazy_state = if handler.is_lazy_state {
let lazy_state = if handler.config.lazy_state {
quote! { Some(true) }
} else {
quote! { None}
};

let inactivity_timeout = handler.config.inactivity_timeout.map(|timeout| {
quote! { Some(#timeout) }
}).unwrap_or_else(|| quote! { None });

let abort_timeout = handler.config.abort_timeout.map(|timeout| {
quote! { Some(#timeout) }
}).unwrap_or_else(|| quote! { None });

let input_schema = match &handler.arg {
Some(PatType { ty, .. }) => {
quote! {
Expand Down Expand Up @@ -229,8 +237,8 @@ impl<'a> ServiceGenerator<'a> {
ty: #handler_ty,
documentation: None,
metadata: Default::default(),
abort_timeout: None,
inactivity_timeout: None,
abort_timeout: #abort_timeout,
inactivity_timeout: #inactivity_timeout,
journal_retention: None,
idempotency_retention: None,
workflow_completion_retention: None,
Expand Down
40 changes: 40 additions & 0 deletions tests/backward_compat_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use restate_sdk::prelude::*;

// Test backward compatibility with the old #[shared] attribute
#[restate_sdk::object]
trait BackwardCompatObject {
// Old syntax should still work
#[shared]
async fn legacy_shared() -> HandlerResult<String>;

// Mix old and new syntax (as long as they don't conflict)
#[shared]
#[restate(inactivity_timeout = "30s")]
async fn mixed_syntax() -> HandlerResult<String>;

// New syntax
#[restate(shared)]
async fn new_shared() -> HandlerResult<String>;
}

struct BackwardCompatObjectImpl;

impl BackwardCompatObject for BackwardCompatObjectImpl {
async fn legacy_shared(&self, _ctx: SharedObjectContext<'_>) -> HandlerResult<String> {
Ok("legacy".to_string())
}

async fn mixed_syntax(&self, _ctx: SharedObjectContext<'_>) -> HandlerResult<String> {
Ok("mixed".to_string())
}

async fn new_shared(&self, _ctx: SharedObjectContext<'_>) -> HandlerResult<String> {
Ok("new".to_string())
}
}

#[test]
fn test_backward_compatibility() {
// This test verifies backward compatibility with old attributes
let _service = BackwardCompatObjectImpl.serve();
}
Loading