diff --git a/CLAUDE.md b/CLAUDE.md index 057d54e..6ada361 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -93,13 +93,24 @@ Processes the controller impl block. Distinguishes between: * Creates `` stream type and `Args` struct. * Signal methods are NOT exposed in the client API (controller emits them directly). +**Poll methods** (marked with `#[controller(poll_*)]`): +* Methods are called periodically at the specified interval. +* Three time unit attributes are supported: + * `#[controller(poll_seconds = N)]` - Poll every N seconds. + * `#[controller(poll_millis = N)]` - Poll every N milliseconds. + * `#[controller(poll_micros = N)]` - Poll every N microseconds. +* Methods with the same timeout value (same unit and value) are grouped into a single ticker arm. +* All methods in a group are called sequentially when the ticker fires (in declaration order). +* Poll methods are NOT exposed in the client API (internal to the controller). +* Uses `embassy_time::Ticker::every()` for timing. + **Getter/setter methods** (from struct field attributes): * Receives getter/setter field info from struct processing. * Generates client-side getter methods that request current field value. * Generates client-side setter methods that update field value (and broadcast if published). The generated `run()` method contains a `select_biased!` loop that receives method calls from -clients and dispatches them to the user's implementations. +clients, dispatches them to the user's implementations, and handles periodic poll method calls. ### Utilities (`src/util.rs`) Case conversion functions (`pascal_to_snake_case`, `snake_to_pascal_case`) used for generating type and method names. @@ -109,6 +120,7 @@ Case conversion functions (`pascal_to_snake_case`, `snake_to_pascal_case`) used User code must have these dependencies (per README): * `futures` with `async-await` feature. * `embassy-sync` for channels and synchronization. +* `embassy-time` for poll method timing (only required if using poll methods). Dev dependencies include `embassy-executor` and `embassy-time` for testing. diff --git a/Cargo.lock b/Cargo.lock index 6ff7252..f91c989 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,12 +17,6 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" -[[package]] -name = "bitflags" -version = "1.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" - [[package]] name = "byteorder" version = "1.5.0" @@ -76,38 +70,6 @@ dependencies = [ "syn", ] -[[package]] -name = "defmt" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "548d977b6da32fa1d1fda2876453da1e7df63ad0304c8b3dae4dbe7b96f39b78" -dependencies = [ - "bitflags", - "defmt-macros", -] - -[[package]] -name = "defmt-macros" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d4fc12a85bcf441cfe44344c4b72d58493178ce635338a3f3b78943aceb258e" -dependencies = [ - "defmt-parser", - "proc-macro-error2", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "defmt-parser" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10d60334b3b2e7c9d91ef8150abfb6fa4c1c39ebbcf4a81c2e346aad939fee3e" -dependencies = [ - "thiserror", -] - [[package]] name = "document-features" version = "0.2.11" @@ -124,7 +86,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06070468370195e0e86f241c8e5004356d696590a678d47d6676795b2e439c6b" dependencies = [ "critical-section", - "defmt", "document-features", "embassy-executor-macros", "embassy-executor-timer-queue", @@ -170,9 +131,9 @@ checksum = "f4fa65b9284d974dad7a23bb72835c4ec85c0b540d86af7fc4098c88cff51d65" dependencies = [ "cfg-if", "critical-section", - "defmt", "document-features", "embassy-time-driver", + "embassy-time-queue-utils", "embedded-hal 0.2.7", "embedded-hal 1.0.0", "embedded-hal-async", @@ -188,6 +149,16 @@ dependencies = [ "document-features", ] +[[package]] +name = "embassy-time-queue-utils" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80e2ee86063bd028a420a5fb5898c18c87a8898026da1d4c852af2c443d0a454" +dependencies = [ + "embassy-executor-timer-queue", + "heapless 0.8.0", +] + [[package]] name = "embedded-hal" version = "0.2.7" @@ -434,28 +405,6 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" -[[package]] -name = "proc-macro-error-attr2" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" -dependencies = [ - "proc-macro2", - "quote", -] - -[[package]] -name = "proc-macro-error2" -version = "2.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" -dependencies = [ - "proc-macro-error-attr2", - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "proc-macro2" version = "1.0.101" @@ -533,26 +482,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "thiserror" -version = "2.0.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0" -dependencies = [ - "thiserror-impl", -] - -[[package]] -name = "thiserror-impl" -version = "2.0.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "unicode-ident" version = "1.0.19" diff --git a/Cargo.toml b/Cargo.toml index 3dee775..8c5745f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,10 +30,5 @@ embassy-sync = "0.7.2" embassy-executor = { version = "0.9.1", features = [ "arch-std", "executor-thread", - "defmt", -] } -embassy-time = { version = "0.5.0", features = [ - "defmt", - "defmt-timestamp-uptime", - "tick-hz-32_768", ] } +embassy-time = { version = "0.5.0", features = ["mock-driver"] } diff --git a/README.md b/README.md index c457ba9..2b546da 100644 --- a/README.md +++ b/README.md @@ -185,12 +185,72 @@ methods: signal events. The stream yields `Args` structs (e.g., `ControllerPowerErrorArgs`) containing all signal arguments as public fields. +## Poll Methods + +Methods can be marked for periodic execution using poll attributes. These methods are called +automatically by the controller's `run()` loop at the specified interval. + +Three time unit attributes are supported: +* `#[controller(poll_seconds = N)]` - Poll every N seconds. +* `#[controller(poll_millis = N)]` - Poll every N milliseconds. +* `#[controller(poll_micros = N)]` - Poll every N microseconds. + +Example: +```rust,no_run +use firmware_controller::controller; + +#[controller] +mod sensor_controller { + pub struct Controller { + #[controller(publish)] + temperature: f32, + } + + impl Controller { + // Called every 5 seconds. + #[controller(poll_seconds = 5)] + pub async fn read_temperature(&mut self) { + // Read from sensor and update state. + self.set_temperature(42.0).await; + } + + // Called every 100ms. + #[controller(poll_millis = 100)] + pub async fn check_watchdog(&mut self) { + // Pet the watchdog. + } + + // Both called every second (grouped together). + #[controller(poll_seconds = 1)] + pub async fn log_status(&self) { + // Log current status. + } + + #[controller(poll_seconds = 1)] + pub async fn blink_led(&mut self) { + // Toggle LED. + } + } +} + +fn main() {} +``` + +Key characteristics: +* Poll methods are **not** exposed in the client API. They are internal periodic tasks managed + entirely by the controller's `run()` loop. +* Methods with the same timeout value (same unit and value) are grouped into a single timer arm + and called sequentially when the timer fires (in declaration order). +* Uses `embassy_time::Ticker` for timing, which ensures consistent intervals regardless of method + execution time. + ## Dependencies assumed The `controller` macro assumes that you have the following dependencies in your `Cargo.toml`: * `futures` with `async-await` feature enabled. * `embassy-sync` +* `embassy-time` (only required if using poll methods) ## Known limitations & Caveats diff --git a/src/controller/item_impl.rs b/src/controller/item_impl.rs index dbb95fa..279e188 100644 --- a/src/controller/item_impl.rs +++ b/src/controller/item_impl.rs @@ -1,10 +1,12 @@ +use std::collections::BTreeMap; + use proc_macro2::TokenStream; use quote::quote; use syn::{ parse::{Parse, ParseStream}, parse_quote, spanned::Spanned, - Attribute, Ident, ImplItem, ImplItemFn, ItemImpl, Result, Signature, Token, Visibility, + Attribute, Ident, ImplItem, ImplItemFn, ItemImpl, LitInt, Result, Signature, Token, Visibility, }; use crate::controller::item_struct::{GetterFieldInfo, PublishedFieldInfo, SetterFieldInfo}; @@ -26,6 +28,16 @@ pub(crate) fn expand( }); let signal_declarations = signals.clone().map(|s| &s.declarations); + // Extract poll methods and group them by duration. + let poll_methods: Vec<_> = methods + .iter() + .filter_map(|m| match m { + Method::Poll(poll) => Some(poll), + _ => None, + }) + .collect(); + let (poll_ticker_declarations, poll_select_arms) = generate_poll_code(&poll_methods); + let methods = methods.iter().filter_map(|m| match m { Method::Proxied(method) => Some(method), _ => None, @@ -67,12 +79,14 @@ pub(crate) fn expand( #(#args_channels_rx_tx)* #(#pub_setter_rx_tx)* #(#pub_getter_rx_tx)* + #(#poll_ticker_declarations)* loop { futures::select_biased! { #(#select_arms,)* #(#pub_setter_select_arms,)* #(#pub_getter_select_arms,)* + #(#poll_select_arms,)* } } } @@ -154,9 +168,20 @@ fn get_methods(input: &mut ItemImpl, struct_name: &Ident) -> Result> .items .iter_mut() .filter_map(|item| match item { - syn::ImplItem::Fn(m) => Some(ProxiedMethod::parse(m, struct_name).map(Method::Proxied)), + syn::ImplItem::Fn(m) => { + // Check if this is a poll method first. + match PollMethod::parse(m) { + Ok(Some(poll)) => Some(Ok(Method::Poll(poll))), + Ok(None) => { + // Not a poll method, treat as proxied. + Some(ProxiedMethod::parse(m, struct_name).map(Method::Proxied)) + } + Err(e) => Some(Err(e)), + } + } syn::ImplItem::Verbatim(tokens) => { - // … thus parse them ourselves and construct an ImplItemFn from that + // Signal methods have a semicolon at the end instead of a body block, + // thus parse them ourselves and construct an ImplItemFn from that. let ImplItemSignal { attrs, vis, sig } = match syn::parse2::(tokens.clone()) { Ok(decl) => decl, @@ -207,6 +232,8 @@ enum Method { Proxied(ProxiedMethod), /// A signal method. Signal(Signal), + /// A method that will be called periodically. + Poll(PollMethod), } /// Method that will be called by the client. @@ -522,6 +549,164 @@ impl Signal { } } +/// Duration for poll methods. Used as a key for grouping methods with the same timeout. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +enum PollDuration { + Seconds(u64), + Millis(u64), + Micros(u64), +} + +impl PollDuration { + fn to_duration_expr(&self) -> TokenStream { + match self { + PollDuration::Seconds(n) => quote! { embassy_time::Duration::from_secs(#n) }, + PollDuration::Millis(n) => quote! { embassy_time::Duration::from_millis(#n) }, + PollDuration::Micros(n) => quote! { embassy_time::Duration::from_micros(#n) }, + } + } +} + +/// A method that will be called periodically with a timeout. +#[derive(Debug)] +struct PollMethod { + method_name: Ident, + duration: PollDuration, +} + +impl PollMethod { + fn parse(method: &mut ImplItemFn) -> Result> { + let mut duration: Option<(PollDuration, proc_macro2::Span)> = None; + + // Check if method has a poll attribute. + for attr in &method.attrs { + if !attr.path().is_ident("controller") { + continue; + } + + attr.parse_nested_meta(|meta| { + let new_duration = if meta.path.is_ident("poll_seconds") { + meta.input.parse::()?; + let lit: LitInt = meta.input.parse()?; + let value: u64 = lit.base10_parse()?; + if value == 0 { + return Err(syn::Error::new_spanned( + lit, + "poll duration must be greater than zero", + )); + } + Some((PollDuration::Seconds(value), meta.path.span())) + } else if meta.path.is_ident("poll_millis") { + meta.input.parse::()?; + let lit: LitInt = meta.input.parse()?; + let value: u64 = lit.base10_parse()?; + if value == 0 { + return Err(syn::Error::new_spanned( + lit, + "poll duration must be greater than zero", + )); + } + Some((PollDuration::Millis(value), meta.path.span())) + } else if meta.path.is_ident("poll_micros") { + meta.input.parse::()?; + let lit: LitInt = meta.input.parse()?; + let value: u64 = lit.base10_parse()?; + if value == 0 { + return Err(syn::Error::new_spanned( + lit, + "poll duration must be greater than zero", + )); + } + Some((PollDuration::Micros(value), meta.path.span())) + } else { + None + }; + + if let Some((new_dur, span)) = new_duration { + if duration.is_some() { + return Err(syn::Error::new( + span, + "only one poll attribute is allowed per method", + )); + } + duration = Some((new_dur, span)); + } + + Ok(()) + })?; + } + + let Some((duration, _)) = duration else { + return Ok(None); + }; + + // Validate that poll methods have no parameters besides receiver. + let has_non_receiver_params = method + .sig + .inputs + .iter() + .any(|arg| matches!(arg, syn::FnArg::Typed(_))); + if has_non_receiver_params { + return Err(syn::Error::new_spanned( + &method.sig.inputs, + "poll methods cannot have parameters (only `&self` or `&mut self` is allowed)", + )); + } + + // Remove the poll attribute from the method. + remove_poll_attr(method)?; + + let method_name = method.sig.ident.clone(); + Ok(Some(Self { + method_name, + duration, + })) + } +} + +fn remove_poll_attr(method: &mut ImplItemFn) -> syn::Result<()> { + method.attrs = method + .attrs + .iter() + .cloned() + .filter_map(|attr| { + if !attr.path().is_ident("controller") { + return Some(Ok(attr)); + } + + let res = attr.parse_nested_meta(|meta| { + if meta.path.is_ident("poll_seconds") + || meta.path.is_ident("poll_millis") + || meta.path.is_ident("poll_micros") + { + // Consume the `= N` part. + meta.input.parse::()?; + let _: LitInt = meta.input.parse()?; + Ok(()) + } else { + let path = &meta.path; + let found = path + .get_ident() + .map(|ident| ident.to_string()) + .unwrap_or_else(|| quote!(#path).to_string()); + let e = format!( + "poll methods cannot have other `controller` attributes (found `{}`); \ + remove attributes like `getter`, `setter`, `publish`, or `signal`", + found + ); + Err(syn::Error::new_spanned(meta.path, e)) + } + }); + match res { + Err(e) => Some(Err(e)), + Ok(()) => None, + } + }) + .collect::>>()?; + + Ok(()) +} + // Like ImplItemFn, but with a semicolon at the end instead of a body block struct ImplItemSignal { attrs: Vec, @@ -846,3 +1031,42 @@ fn generate_pub_getter(field: &GetterFieldInfo, struct_name: &Ident) -> PubGette client_tx_rx_initializations, } } + +/// Generate ticker declarations and select arms for poll methods, grouped by duration. +/// +/// Returns (ticker_declarations, select_arms) where: +/// - ticker_declarations: Code to create Tickers before the loop. +/// - select_arms: Select arms that wait on ticker.next(). +fn generate_poll_code(poll_methods: &[&PollMethod]) -> (Vec, Vec) { + // Group poll methods by duration. + let mut groups: BTreeMap> = BTreeMap::new(); + for poll in poll_methods { + groups + .entry(poll.duration.clone()) + .or_default() + .push(&poll.method_name); + } + + let mut ticker_declarations = Vec::new(); + let mut select_arms = Vec::new(); + + for (index, (duration, method_names)) in groups.into_iter().enumerate() { + let duration_expr = duration.to_duration_expr(); + let ticker_name = Ident::new( + &format!("__poll_ticker_{index}"), + proc_macro2::Span::call_site(), + ); + + ticker_declarations.push(quote! { + let mut #ticker_name = embassy_time::Ticker::every(#duration_expr); + }); + + select_arms.push(quote! { + _ = futures::FutureExt::fuse(#ticker_name.next()) => { + #(self.#method_names().await;)* + } + }); + } + + (ticker_declarations, select_arms) +} diff --git a/tests/integration.rs b/tests/integration.rs index e0fdbe4..c5cd062 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -273,3 +273,94 @@ fn test_visibility_on_fields() { async fn visibility_controller_task(controller: visibility_test_controller::Controller) { controller.run().await; } + +use std::sync::atomic::{AtomicU32, Ordering}; + +static POLL_A_COUNT: AtomicU32 = AtomicU32::new(0); +static POLL_B_COUNT: AtomicU32 = AtomicU32::new(0); +static POLL_C_COUNT: AtomicU32 = AtomicU32::new(0); + +/// Test poll methods with timeouts. +#[controller] +mod poll_test_controller { + use super::*; + + pub struct Controller { + #[controller(getter)] + pub value: u32, + } + + impl Controller { + // Two methods with the same poll interval (50ms) - should be grouped. + #[controller(poll_millis = 50)] + pub async fn poll_a(&mut self) { + POLL_A_COUNT.fetch_add(1, Ordering::SeqCst); + } + + #[controller(poll_millis = 50)] + pub async fn poll_b(&mut self) { + POLL_B_COUNT.fetch_add(1, Ordering::SeqCst); + } + + // Different poll interval (100ms). + #[controller(poll_millis = 100)] + pub async fn poll_c(&mut self) { + POLL_C_COUNT.fetch_add(1, Ordering::SeqCst); + } + } +} + +/// Test that poll methods are called at the expected intervals. +#[test] +fn poll_methods() { + use embassy_time::{Duration, MockDriver}; + + // Reset mock driver and counters. + let driver = MockDriver::get(); + driver.reset(); + POLL_A_COUNT.store(0, Ordering::SeqCst); + POLL_B_COUNT.store(0, Ordering::SeqCst); + POLL_C_COUNT.store(0, Ordering::SeqCst); + + let controller = poll_test_controller::Controller::new(42); + + // Verify struct fields are accessible. + assert_eq!(controller.value, 42); + + // Run the controller in a background thread. + std::thread::spawn(move || { + let executor = Box::leak(Box::new(embassy_executor::Executor::new())); + executor.run(move |spawner| { + spawner.spawn(poll_controller_task(controller)).unwrap(); + }); + }); + + // Give the executor a moment to start. + std::thread::sleep(std::time::Duration::from_millis(10)); + + // Advance mock time by 50ms - poll_a and poll_b should fire once. + driver.advance(Duration::from_millis(50)); + std::thread::sleep(std::time::Duration::from_millis(10)); + assert_eq!(POLL_A_COUNT.load(Ordering::SeqCst), 1); + assert_eq!(POLL_B_COUNT.load(Ordering::SeqCst), 1); + assert_eq!(POLL_C_COUNT.load(Ordering::SeqCst), 0); + + // Advance another 50ms (total 100ms) - poll_a/poll_b fire again, poll_c fires once. + driver.advance(Duration::from_millis(50)); + std::thread::sleep(std::time::Duration::from_millis(10)); + assert_eq!(POLL_A_COUNT.load(Ordering::SeqCst), 2); + assert_eq!(POLL_B_COUNT.load(Ordering::SeqCst), 2); + assert_eq!(POLL_C_COUNT.load(Ordering::SeqCst), 1); + + // Advance another 100ms (total 200ms) - poll_a/poll_b fire 2 more times, poll_c fires once. + driver.advance(Duration::from_millis(100)); + std::thread::sleep(std::time::Duration::from_millis(10)); + assert_eq!(POLL_A_COUNT.load(Ordering::SeqCst), 4); + assert_eq!(POLL_B_COUNT.load(Ordering::SeqCst), 4); + assert_eq!(POLL_C_COUNT.load(Ordering::SeqCst), 2); +} + +#[embassy_executor::task] +async fn poll_controller_task(controller: poll_test_controller::Controller) { + controller.run().await; +}