From bc4aba1d08faecedf42e312b91b1c5390d3b7df8 Mon Sep 17 00:00:00 2001 From: Johannes Dertmann Date: Sat, 3 May 2025 21:01:38 +0200 Subject: [PATCH 1/3] Add `LruCache` type alias --- src/allris/oparl.rs | 6 ++---- src/lru_cache.rs | 5 +++++ 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/allris/oparl.rs b/src/allris/oparl.rs index 979d5f4..004b411 100644 --- a/src/allris/oparl.rs +++ b/src/allris/oparl.rs @@ -11,14 +11,12 @@ use url::Url; use super::{AllrisUrl, Error}; use crate::allris::http_request; -use crate::lru_cache::{Cache, Lru}; - -type LruCache = Cache>; +use crate::lru_cache::{Lru, LruCache}; /// Caches calls to the api's `organization` endpoints, as these information will /// rarely change. static ORGANIZATIONS: LazyLock, Organization)>> = - LazyLock::new(|| Cache::new(Lru::new(50))); + LazyLock::new(|| LruCache::new(Lru::new(50))); #[derive(Debug, Clone, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] diff --git a/src/lru_cache.rs b/src/lru_cache.rs index 5acd109..39c6076 100644 --- a/src/lru_cache.rs +++ b/src/lru_cache.rs @@ -115,6 +115,7 @@ impl> CacheInner { } } +#[derive(Debug)] pub struct CacheItem(Arc>); impl Deref for CacheItem { @@ -188,6 +189,10 @@ impl> Cache { } } + + +pub type LruCache = Cache>; + #[cfg(test)] mod tests { use std::time::Duration; From d1f1a7a6fecb533e955febc2a456fbe4d2de6ca4 Mon Sep 17 00:00:00 2001 From: Johannes Dertmann Date: Sat, 3 May 2025 21:02:34 +0200 Subject: [PATCH 2/3] Add `get_dedicated` function to database --- src/database.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/database.rs b/src/database.rs index 56635bb..fe4cd51 100644 --- a/src/database.rs +++ b/src/database.rs @@ -147,6 +147,7 @@ impl DatabaseConnection { pub fn shared(self) -> SharedDatabaseConnection { SharedDatabaseConnection { + client: self.client.clone(), timeout: self.timeout, connection: Mutex::new(self), } @@ -215,6 +216,13 @@ impl DatabaseConnection { pub struct SharedDatabaseConnection { 
connection: Mutex, timeout: Option, + client: redis::Client, +} + +impl SharedDatabaseConnection { + pub fn get_dedicated(&self) -> DatabaseConnection { + DatabaseConnection::new(self.client.clone(), self.timeout) + } } #[derive(Debug, Clone, Copy)] From b7cbb932c5283ea001d322a152343479bf870952 Mon Sep 17 00:00:00 2001 From: Johannes Dertmann Date: Sat, 3 May 2025 21:05:32 +0200 Subject: [PATCH 3/3] Move broadcasting to bot-utils crate --- Cargo.lock | 11 + Cargo.toml | 4 +- README.md | 2 +- bot-utils/Cargo.toml | 10 + bot-utils/src/broadcasting.rs | 515 ++++++++++++++++++++++++++ bot-utils/src/lib.rs | 3 + src/broadcasting.rs | 398 +++++++------------- src/broadcasting/scheduled_message.rs | 232 ------------ src/lru_cache.rs | 2 - src/main.rs | 14 +- src/types.rs | 2 - 11 files changed, 676 insertions(+), 517 deletions(-) create mode 100644 bot-utils/Cargo.toml create mode 100644 bot-utils/src/broadcasting.rs create mode 100644 bot-utils/src/lib.rs delete mode 100644 src/broadcasting/scheduled_message.rs diff --git a/Cargo.lock b/Cargo.lock index d9e2048..f235adb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -30,6 +30,7 @@ dependencies = [ name = "allrisbot" version = "0.1.0" dependencies = [ + "bot-utils", "chrono", "clap", "env_logger", @@ -197,6 +198,16 @@ dependencies = [ "syn", ] +[[package]] +name = "bot-utils" +version = "0.1.0" +dependencies = [ + "frankenstein", + "futures-util", + "log", + "tokio", +] + [[package]] name = "bumpalo" version = "3.17.0" diff --git a/Cargo.toml b/Cargo.toml index b056340..733ab98 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,6 @@ [workspace] -members = [ +members = [ + "bot-utils", "telegram-message-builder", "telegram-message-builder-macro", ] @@ -13,6 +14,7 @@ repository = "https://github.com/jdertmann/AllrisBot" license = "AGPL-3.0-or-later" [dependencies] +bot-utils = { path = "bot-utils" } chrono = { version = "0.4.40", features = ["serde"] } clap = { version = "4.5", features = ["derive", "env"] } env_logger = 
"0.11" diff --git a/README.md b/README.md index b7dad6a..8da2bb6 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ See `./target/release/allrisbot --help` for usage details. If you’d like to make contributions, feel free to open an issue or pull request. -## License (`allrisbot` crate) +## License (`allrisbot` and `bot-utils` crate) Copyright (C) 2025 Johannes Dertmann diff --git a/bot-utils/Cargo.toml b/bot-utils/Cargo.toml new file mode 100644 index 0000000..ee44d96 --- /dev/null +++ b/bot-utils/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "bot-utils" +version = "0.1.0" +edition = "2024" + +[dependencies] +frankenstein = { version = "0.40", features = ["trait-async"] } +futures-util = { version = "0.3", default-features = false, features = ["alloc"] } +log = "0.4" +tokio = { version = "1.44.2", features = ["sync", "time", "rt-multi-thread", "macros"] } diff --git a/bot-utils/src/broadcasting.rs b/bot-utils/src/broadcasting.rs new file mode 100644 index 0000000..eaaafe0 --- /dev/null +++ b/bot-utils/src/broadcasting.rs @@ -0,0 +1,515 @@ +//! Schedules and broadcasts messages to chats in a rate-limited and fault-tolerant way. + +// The `broadcast_task` function forms the main loop of this module and is responsible +// for triggering per-chat processing until each chat has caught up with the most recent +// entry of the `receive_updates` stream. +// +// Processing for each chat consists of: +// 1. Retrieving and preprocessing of the next message from the backend. +// 2. Sending it to the sender task. +// 3. Waiting for the sender task's confirmation that the message was sent. +// 4. Sleeping for a short duration to comply with per-chat rate limits. +// +// The sender task receives filtered messages and handles the actual delivery while enforcing +// a global broadcast rate limit. 
+ +// TODO: if filter was checked a long time ago, check it again before sending +// TODO: allow sending multiple messages per update + +use std::collections::HashMap; +use std::collections::hash_map::Entry; +use std::error::Error; +use std::fmt::Debug; +use std::hash::Hash; +use std::ops::ControlFlow; +use std::pin::pin; +use std::sync::Arc; +use std::time::Duration; + +use frankenstein::Error as RequestError; +use frankenstein::response::{ErrorResponse, ResponseParameters}; +use futures_util::future::{Fuse, FusedFuture}; +use futures_util::stream::{FusedStream, FuturesUnordered, StreamExt as _}; +use futures_util::{FutureExt, Stream}; +use tokio::sync::{mpsc, oneshot, watch}; +use tokio::task::JoinHandle; +use tokio::time::{Instant, MissedTickBehavior, interval, sleep, sleep_until}; + +use super::ChatId; + +const BROADCASTS_PER_SECOND: f32 = 30.; +const MESSAGE_INTERVAL_CHAT: Duration = Duration::from_secs(1); +const MESSAGE_INTERVAL_GROUP: Duration = Duration::from_secs(3); + +/// error messages that imply we're not allowed to send messages +/// to this chat in the future. 
+const TELEGRAM_ERRORS: [&str; 14] = [ + "Bad Request: CHAT_WRITE_FORBIDDEN", + "Bad Request: TOPIC_CLOSED", + "Bad Request: chat not found", + "Bad Request: have no rights to send a message", + "Bad Request: not enough rights to send text messages to the chat", + "Bad Request: need administrator rights in the channel chat", + "Forbidden: bot is not a member of the channel chat", + "Forbidden: bot is not a member of the supergroup chat", + "Forbidden: bot was blocked by the user", + "Forbidden: bot was kicked from the channel chat", + "Forbidden: bot was kicked from the group chat", + "Forbidden: bot was kicked from the supergroup chat", + "Forbidden: the group chat was deleted", + "Forbidden: user is deactivated", +]; + +enum ChatStatus { + Processed(B::UpdateId), + OutOfSync, + Stopped, + ShuttingDown, + MigratedTo(ChatId), +} + +fn delay(chat_id: ChatId) -> Duration { + if chat_id < 0 { + MESSAGE_INTERVAL_GROUP + } else { + MESSAGE_INTERVAL_CHAT + } +} + +type OneshotResponse = (Result, ::Error>, bool); +type SendMessage = (ScheduledMessage, oneshot::Sender>); + +pub enum NextUpdate { + Ready { id: B::UpdateId, msg: B::Message }, + Skipped { id: B::UpdateId }, + Pending { previous: B::UpdateId }, + Migrated { to: ChatId }, + Stopped, +} + +macro_rules! ret_ty { + ($x:ty, $e:ty ) => { + impl Future> + Send + }; + ($x:ty) => { + ret_ty!($x, Self::Error) + }; +} + +pub trait Backend: Send + Sync + Sized + 'static { + type UpdateId: Debug + Hash + Ord + Send + Sync + Copy + 'static; + + type Message: Debug + Send + Sync + 'static; + + type Error: Error + Send + 'static; + + /// Returns a stream that first yields the id of the latest update as soon as possible, and then + /// yields whenever there are new updates with a later UpdateId. When it returns None, a soft shutdown + /// is initiated. 
+ fn receive_updates(&self) + -> impl Stream)> + Send + 'static; + + fn next_update(&self, chat: ChatId) -> ret_ty![NextUpdate]; + + fn send(&self, chat: ChatId, message: &Self::Message) -> ret_ty![(), frankenstein::Error]; + + fn acknowledge(&self, chat: ChatId, update: Self::UpdateId) -> ret_ty![bool]; + + fn unacknowledge(&self, chat: ChatId, update: Self::UpdateId) -> ret_ty![bool]; + + fn migrate_chat(&self, old: ChatId, new: ChatId) -> ret_ty![bool]; + + fn remove_chat(&self, id: ChatId) -> ret_ty![bool]; +} + +struct SharedDependencies { + backend: B, + hard_shutdown: watch::Sender, + sender_tx: mpsc::Sender>, +} + +fn backoff_strategy() -> impl Iterator { + (1..=6).map(|i| { + let millis = 10 * 6_u64.pow(i); + let millis = millis.min(120_000); + Duration::from_millis(millis) + }) +} + +/// A message that is scheduled to be sent to a certain chat +struct ScheduledMessage { + pub chat_id: ChatId, + pub update: B::UpdateId, + pub message: B::Message, +} + +impl ScheduledMessage { + async fn unacknowledge(&self, shared: &SharedDependencies) -> Result { + let r = shared + .backend + .unacknowledge(self.chat_id, self.update) + .await?; + if !r { + log::warn!( + "Failed to unacknowledge message {:?} for chat {}", + self.update, + self.chat_id + ); + } + Ok(r) + } + + async fn handle_response( + &self, + shared: &SharedDependencies, + response: Result<(), RequestError>, + backoff: Option, + ) -> Result, Duration>, B::Error> { + if let Err(e) = response.as_ref() { + log::warn!("Failed to send message: {e}"); + } + + macro_rules! retry_with_backoff { + ($dur:expr) => { + if self.unacknowledge(shared).await? { + return Ok(ControlFlow::Continue($dur)); + } else { + ChatStatus::OutOfSync + } + }; + } + + macro_rules! 
retry { + () => { + if let Some(backoff) = backoff { + retry_with_backoff!(backoff) + } else { + log::error!("Sending failed definitely, not retrying!"); + ChatStatus::Processed(self.update) + } + }; + } + + let result = match response { + Ok(()) => ChatStatus::Processed(self.update), + Err(RequestError::Api(e)) => match e { + ErrorResponse { + error_code: 401 | 404, + .. + } => { + log::error!("Invalid token! Was it revoked?"); + shared.hard_shutdown.send_replace(true); + _ = self.unacknowledge(shared).await?; + ChatStatus::ShuttingDown + } + ErrorResponse { + parameters: + Some(ResponseParameters { + migrate_to_chat_id: Some(new_chat_id), + .. + }), + .. + } => { + _ = self.unacknowledge(shared).await?; + shared + .backend + .migrate_chat(self.chat_id, new_chat_id) + .await?; + ChatStatus::MigratedTo(new_chat_id) + } + ErrorResponse { + parameters: + Some(ResponseParameters { + retry_after: Some(secs), + .. + }), + .. + } => { + retry_with_backoff!(Duration::from_secs(secs as u64)) + } + ErrorResponse { description, .. } + if TELEGRAM_ERRORS.contains(&description.as_str()) => + { + shared.backend.remove_chat(self.chat_id).await?; + ChatStatus::Stopped + } + _ => retry!(), + }, + _ => retry!(), + }; + + Ok(ControlFlow::Break(result)) + } + + /// Sends a message. Will retry a number of times if it fails + async fn send_message( + &self, + shared: &SharedDependencies, + message_sent: &mut bool, + ) -> Result, B::Error> { + let mut backoff = backoff_strategy(); + + loop { + *message_sent = false; + let ack = shared + .backend + .acknowledge(self.chat_id, self.update) + .await?; + if !ack { + return Ok(ChatStatus::OutOfSync); + } + let response = shared.backend.send(self.chat_id, &self.message).await; + *message_sent = true; + + match self + .handle_response(shared, response, backoff.next()) + .await? 
+ { + ControlFlow::Break(result) => return Ok(result), + ControlFlow::Continue(retry_after) => sleep(retry_after).await, + } + } + } +} + +/// Process the next entry of the message stream for a certain chat +async fn try_process_next( + shared: &SharedDependencies, + chat_id: ChatId, +) -> Result, B::Error> { + let started = Instant::now(); + + let (id, message) = match shared.backend.next_update(chat_id).await? { + NextUpdate::Ready { id, msg: next } => (id, next), + NextUpdate::Skipped { id } => { + if shared.backend.acknowledge(chat_id, id).await? { + return Ok(ChatStatus::Processed(id)); + } else { + return Ok(ChatStatus::OutOfSync); + } + } + NextUpdate::Pending { previous: last } => return Ok(ChatStatus::Processed(last)), + NextUpdate::Migrated { to } => return Ok(ChatStatus::MigratedTo(to)), + NextUpdate::Stopped => return Ok(ChatStatus::Stopped), + }; + + // pass the message to the sender task + let scheduled = ScheduledMessage { + chat_id, + message, + update: id, + }; + let (oneshot_tx, oneshot_rx) = oneshot::channel(); + + if shared + .sender_tx + .send((scheduled, oneshot_tx)) + .await + .is_err() + { + return Ok(ChatStatus::ShuttingDown); + } + + match oneshot_rx.await { + Ok((r, true)) => { + // message has been sent, apply a delay for rate limiting + sleep_until(started + delay(chat_id)).await; + r + } + Ok((r, false)) => { + // message has not been sent + r + } + Err(_) => { + // sender task apparently not running anymore + Ok(ChatStatus::ShuttingDown) + } + } +} + +async fn sender_task( + shared: Arc>, + mut sender_rx: mpsc::Receiver>, +) { + let mut shutdown = shared.hard_shutdown.subscribe(); + let mut interval = interval(Duration::from_secs_f32(1. / BROADCASTS_PER_SECOND)); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + let recv = async { + interval.tick().await; + sender_rx.recv().await + }; + + let (sender, result_tx) = tokio::select! 
{ + biased; + _ = shutdown.wait_for(|x| *x) => break, + next = recv => match next { + Some(next) => next, + None => break + } + }; + + let mut message_sent = false; + let result = sender.send_message(&shared, &mut message_sent).await; + let _ = result_tx.send((result, message_sent)); + } +} + +#[derive(Default)] +struct ProcessingState { + triggered_while_running: bool, +} + +enum ShutdownSignal { + Soft, + Hard, +} + +struct BroadcastManager<'a, B: Backend, Fut, F: Fn(&'a SharedDependencies, ChatId) -> Fut> { + shared: &'a SharedDependencies, + latest_entry_id: Option, + states: HashMap, + process_next_message: F, + processing: FuturesUnordered, +} + +impl<'a, B: Backend, Fut, F: Fn(&'a SharedDependencies, ChatId) -> Fut> + BroadcastManager<'a, B, Fut, F> +{ + /// should be called if there's possibly a new message for this chat + fn trigger_chat(&mut self, chat_id: ChatId) { + match self.states.entry(chat_id) { + Entry::Occupied(mut entry) => { + entry.get_mut().triggered_while_running = true; + } + Entry::Vacant(entry) => { + self.processing + .push((self.process_next_message)(self.shared, chat_id)); + entry.insert(ProcessingState::default()); + } + } + } + + /// triggers all active chats after a new message has arrived + fn on_message_scheduled(&mut self, id: B::UpdateId, active_chats: Vec) { + self.latest_entry_id = Some(id); + + for chat_id in active_chats { + self.trigger_chat(chat_id); + } + } + + fn on_processing_finished(&mut self, chat_id: ChatId, result: Result, B::Error>) { + let restart = self + .states + .remove(&chat_id) + .map(|s| s.triggered_while_running) + .unwrap_or_else(|| { + log::warn!("Unexpectedly missing state"); + true // restart task to be on the safe site + }); + match result { + Ok(ChatStatus::Processed(stream_id)) => { + if Some(stream_id) < self.latest_entry_id { + self.trigger_chat(chat_id); + } + } + Ok(ChatStatus::OutOfSync) => self.trigger_chat(chat_id), + Ok(ChatStatus::Stopped) => { + if restart { + // It's possible that a user 
unsubscribes and then quickly re-subscribes. + // In such cases, the previous task might report `ChatStopped`, even though the + // chat is active again. We avoid prematurely stopping processing by checking if the chat + // was re-triggered during that time. Ignoring these `ChatStopped` results is + // harmless and helps avoid missing messages. + + self.trigger_chat(chat_id); + } + } + Ok(ChatStatus::MigratedTo(chat_id)) => self.trigger_chat(chat_id), + Ok(ChatStatus::ShuttingDown) => (), + Err(e) => log::error!("{e}"), + } + } +} + +async fn broadcast_task(backend: impl Backend, mut shutdown_rx: mpsc::Receiver) { + let (sender_tx, sender_rx) = mpsc::channel(3); + let shared = Arc::new(SharedDependencies { + sender_tx, + backend, + hard_shutdown: watch::Sender::new(false), + }); + + let mut sender_handle = tokio::spawn(sender_task(shared.clone(), sender_rx)); + let mut soft_shutdown = false; + let mut updates = pin!(shared.backend.receive_updates().fuse()); + let mut manager = BroadcastManager { + shared: &shared, + latest_entry_id: None, + states: HashMap::new(), + process_next_message: |shared, chat_id| async move { + let result = try_process_next(shared, chat_id).await; + (chat_id, result) + }, + processing: FuturesUnordered::new(), + }; + + while !(soft_shutdown && manager.processing.is_empty()) { + tokio::select! 
{ + biased; + _ = &mut sender_handle => return, + signal = shutdown_rx.recv() => match signal { + Some(ShutdownSignal::Soft) => soft_shutdown = true, + Some(ShutdownSignal::Hard) | None => break + }, + item = updates.next(), if !updates.is_terminated() => { + if let Some((id,active_chats)) = item { + manager.on_message_scheduled(id, active_chats) + } else { + soft_shutdown = true; + } + }, + Some((chat_id, result)) = manager.processing.next(), if !manager.processing.is_empty() => { + manager.on_processing_finished(chat_id, result); + } + } + } + + // notify the sender task to stop after the next message + shared.hard_shutdown.send_replace(true); + let _ = sender_handle.await; +} + +pub struct Broadcaster { + shutdown_tx: mpsc::Sender, + handle: Fuse>, +} + +impl Broadcaster { + pub fn new(backend: impl Backend) -> Self { + let (shutdown_tx, shutdown_rx) = mpsc::channel(2); + let handle = tokio::spawn(broadcast_task(backend, shutdown_rx)).fuse(); + Self { + shutdown_tx, + handle, + } + } + + pub async fn soft_shutdown(&mut self) { + _ = self.shutdown_tx.send(ShutdownSignal::Soft).await; + + if !self.handle.is_terminated() { + _ = (&mut self.handle).await; + } + } + + pub async fn hard_shutdown(self) { + _ = self.shutdown_tx.send(ShutdownSignal::Hard).await; + + if !self.handle.is_terminated() { + _ = self.handle.await; + } + } +} diff --git a/bot-utils/src/lib.rs b/bot-utils/src/lib.rs new file mode 100644 index 0000000..8b0e03a --- /dev/null +++ b/bot-utils/src/lib.rs @@ -0,0 +1,3 @@ +pub mod broadcasting; + +pub type ChatId = i64; diff --git a/src/broadcasting.rs b/src/broadcasting.rs index 46690a2..1e0d152 100644 --- a/src/broadcasting.rs +++ b/src/broadcasting.rs @@ -1,318 +1,174 @@ -//! Schedules and broadcasts messages to chats in a rate-limited and fault-tolerant way. -//! -//! A broadcast is scheduled as a Redis stream entry. The `broadcast_task` function forms the -//! 
main loop of this module and is responsible for triggering per-chat processing until each -//! chat has caught up with the most recent stream entry. -//! -//! Processing for each chat consists of: -//! 1. Retrieving the next message from the database or cache. -//! 2. Checking if the message matches the user's subscription filters. -//! 3. If it does, sending it to the sender task. -//! 4. Sleeping for a short duration to comply with per-chat rate limits. -//! -//! The sender task receives filtered messages and handles the actual delivery while enforcing -//! a global broadcast rate limit. - -mod scheduled_message; - -use std::collections::HashMap; -use std::collections::hash_map::Entry; -use std::pin::pin; -use std::sync::Arc; +use std::future; use std::time::Duration; -use futures_util::stream::{FuturesUnordered, StreamExt as _}; -use tokio::sync::{mpsc, oneshot, watch}; -use tokio::time::{Instant, MissedTickBehavior, interval, sleep, sleep_until}; +use bot_utils::ChatId; +use bot_utils::broadcasting::{Backend, NextUpdate}; +use frankenstein::AsyncTelegramApi as _; +use frankenstein::types::LinkPreviewOptions; +use futures_util::{Stream, StreamExt, stream}; +use regex::Regex; +use tokio::time::sleep; -use self::scheduled_message::ScheduledMessage; use crate::database::{self, ChatState, DatabaseConnection, SharedDatabaseConnection, StreamId}; -use crate::lru_cache::{Cache, Lru}; -use crate::types::{ChatId, Message}; - -const BROADCASTS_PER_SECOND: f32 = 30.; -const MESSAGE_INTERVAL_CHAT: Duration = Duration::from_secs(1); -const MESSAGE_INTERVAL_GROUP: Duration = Duration::from_secs(3); +use crate::lru_cache::{CacheItem, Lru, LruCache}; +use crate::types::{Condition, Filter, Message}; + +impl Condition { + fn matches(&self, message: &Message) -> bool { + let Ok(regex) = Regex::new(&self.pattern) else { + log::warn!("Invalid regex pattern!"); + return false; + }; -enum ProcessNextResult { - Processed(StreamId), - OutOfSync, - ChatStopped, - ShuttingDown, - 
MigratedTo(ChatId), -} + let result = message + .tags + .iter() + .filter(|x| x.0 == self.tag) + .any(|x| regex.is_match(&x.1)); -fn delay(chat_id: ChatId) -> Duration { - if chat_id < 0 { - MESSAGE_INTERVAL_GROUP - } else { - MESSAGE_INTERVAL_CHAT + result ^ self.negate } } -type OneshotResponse = (database::Result, bool); -type SendMessage = (ScheduledMessage, oneshot::Sender); +impl Filter { + fn matches(&self, message: &Message) -> bool { + for condition in &self.conditions { + if !condition.matches(message) { + return false; + } + } -struct SharedDependencies { - bot: crate::Bot, - db: SharedDatabaseConnection, - hard_shutdown: watch::Sender, - next_message_cache: Cache>, - sender_tx: mpsc::Sender, + true + } } -/// Process the next entry of the message stream for a certain chat -async fn try_process_next( - shared: &SharedDependencies, - chat_id: ChatId, -) -> database::Result { - let started = Instant::now(); - - let last_sent = match shared.db.get_chat_state(chat_id).await? { - ChatState::Active { last_sent } => last_sent, - ChatState::Migrated { to } => return Ok(ProcessNextResult::MigratedTo(to)), - ChatState::Stopped => return Ok(ProcessNextResult::ChatStopped), - }; - - let next_message = shared - .next_message_cache - .get_some(last_sent, || shared.db.get_next_message(last_sent)) - .await?; +pub struct RedisBackend { + pub bot: crate::Bot, + pub db: SharedDatabaseConnection, + pub cache: LruCache, +} - let scheduled = match next_message { - Some(next) => ScheduledMessage::new(chat_id, next), - None => return Ok(ProcessNextResult::Processed(last_sent)), - }; +impl RedisBackend { + pub fn new(bot: crate::Bot, db: redis::Client) -> Self { + let db = DatabaseConnection::new(db, None).shared(); + let cache = LruCache::new(Lru::new(30)); - if !scheduled.check_filters(shared).await? { - // message should not be sent, early return - if scheduled.acknowledge_message(shared).await? 
{ - return Ok(ProcessNextResult::Processed(scheduled.message_id())); - } else { - return Ok(ProcessNextResult::OutOfSync); - } + Self { bot, db, cache } } - // pass the message to the sender task - let (oneshot_tx, oneshot_rx) = oneshot::channel(); - - if shared - .sender_tx - .send((scheduled, oneshot_tx)) - .await - .is_err() - { - return Ok(ProcessNextResult::ShuttingDown); + async fn get_next_entry( + &self, + last_sent: StreamId, + ) -> database::Result>> { + self.cache + .get_some(last_sent, || self.db.get_next_message(last_sent)) + .await } - match oneshot_rx.await { - Ok((r, true)) => { - // message has been sent, apply a delay for rate limiting - sleep_until(started + delay(chat_id)).await; - r - } - Ok((r, false)) => { - // message has not been sent - r - } - Err(_) => { - // sender task apparently not running anymore - Ok(ProcessNextResult::ShuttingDown) - } + async fn matches_filter(&self, chat: i64, msg: &Message) -> database::Result { + let filters = self.db.get_filters(chat).await?; + let matches = filters.iter().any(|filter| filter.matches(msg)); + Ok(matches) } } -async fn sender_task(shared: Arc, mut sender_rx: mpsc::Receiver) { - let mut shutdown = shared.hard_shutdown.subscribe(); - let mut interval = interval(Duration::from_secs_f32(1. / BROADCASTS_PER_SECOND)); - interval.set_missed_tick_behavior(MissedTickBehavior::Delay); +impl Backend for RedisBackend { + type UpdateId = StreamId; - loop { - let recv = async { - interval.tick().await; - sender_rx.recv().await - }; + type Message = CacheItem<(StreamId, Message)>; - let (sender, result_tx) = tokio::select! 
{ - biased; - _ = shutdown.wait_for(|x| *x) => break, - next = recv => match next { - Some(next) => next, - None => break - } - }; + type Error = database::Error; - let mut message_sent = false; - let result = sender.send_message(&shared, &mut message_sent).await; - let _ = result_tx.send((result, message_sent)); + async fn acknowledge( + &self, + chat_id: i64, + message_id: Self::UpdateId, + ) -> Result { + self.db.acknowledge_message(chat_id, message_id).await } -} -#[derive(Default)] -struct ProcessingState { - should_restart: bool, -} - -pub enum ShutdownSignal { - Soft, - Hard, -} - -struct BroadcastManager<'a, Fut, F: Fn(&'a SharedDependencies, ChatId) -> Fut> { - shared: &'a SharedDependencies, - latest_entry_id: Option, - states: HashMap, - process_next_message: F, - processing: FuturesUnordered, -} - -impl<'a, Fut, F: Fn(&'a SharedDependencies, ChatId) -> Fut> BroadcastManager<'a, Fut, F> { - /// should be called if there's possibly a new message for this chat - fn trigger_chat(&mut self, chat_id: ChatId) { - match self.states.entry(chat_id) { - Entry::Occupied(mut entry) => { - entry.get_mut().should_restart = true; - } - Entry::Vacant(entry) => { - self.processing - .push((self.process_next_message)(self.shared, chat_id)); - entry.insert(ProcessingState::default()); - } - } + async fn unacknowledge( + &self, + chat_id: i64, + message_id: Self::UpdateId, + ) -> Result { + self.db.unacknowledge_message(chat_id, message_id).await } - /// triggers all active chats after a new message has arrived - fn on_message_scheduled(&mut self, id: StreamId, active_chats: Vec) { - self.latest_entry_id = Some(id); - - for chat_id in active_chats { - self.trigger_chat(chat_id); - } + async fn migrate_chat( + &self, + old_chat_id: ChatId, + new_chat_id: ChatId, + ) -> Result { + self.db.migrate_chat(old_chat_id, new_chat_id).await } - fn on_processing_finished( - &mut self, - chat_id: ChatId, - result: database::Result, - ) { - let restart = 
self.states.remove(&chat_id).unwrap().should_restart; - match result { - Ok(ProcessNextResult::Processed(stream_id)) => { - if Some(stream_id) < self.latest_entry_id { - self.trigger_chat(chat_id); - } - } - Ok(ProcessNextResult::OutOfSync) => self.trigger_chat(chat_id), - Ok(ProcessNextResult::ChatStopped) => { - if restart { - // It's possible that a user unsubscribes and then quickly re-subscribes. - // In such cases, the previous task might report `ChatStopped`, even though the - // chat is active again. We avoid prematurely stopping processing by checking if the chat - // was re-triggered during that time. Ignoring these outdated `ChatStopped` results is - // harmless and helps avoid missing messages. - - self.trigger_chat(chat_id); - } - } - Ok(ProcessNextResult::MigratedTo(chat_id)) => self.trigger_chat(chat_id), - Ok(ProcessNextResult::ShuttingDown) => (), - Err(e) => log::error!("Database error: {e}"), - } + async fn remove_chat(&self, chat_id: ChatId) -> Result { + self.db.remove_subscription(chat_id).await } - /// Returns the next stream id and a list of active chats - /// as soon as a new entry is added to the message stream - fn next_message_ready( - &self, - mut conn: DatabaseConnection, - was_error: bool, - ) -> impl Future< - Output = ( - database::Result<(StreamId, Vec)>, - DatabaseConnection, - ), - > + 'static { - let id = self.latest_entry_id; + async fn next_update(&self, chat: ChatId) -> Result, Self::Error> { + let last_sent = match self.db.get_chat_state(chat).await? { + ChatState::Active { last_sent } => last_sent, + ChatState::Migrated { to } => return Ok(NextUpdate::Migrated { to }), + ChatState::Stopped => return Ok(NextUpdate::Stopped), + }; - async move { - if was_error { - sleep(Duration::from_secs(20)).await; + let update = match self.get_next_entry(last_sent).await? { + Some(msg) if self.matches_filter(chat, &msg.1).await? 
=> { + NextUpdate::Ready { id: msg.0, msg } } - - let result = async { - let next_id = if let Some(id) = id { - conn.next_message_id_blocking(id).await? - } else { - conn.current_message_id().await? - }; - let active_chats = conn.get_active_chats().await?; - Ok((next_id, active_chats)) + Some(msg) => { + self.acknowledge(chat, msg.0).await?; + NextUpdate::Skipped { id: msg.0 } } - .await; + None => NextUpdate::Pending { + previous: last_sent, + }, + }; - (result, conn) - } + Ok(update) } -} - -pub async fn broadcast_task( - bot: crate::Bot, - db: redis::Client, - mut shutdown_rx: mpsc::UnboundedReceiver, -) { - let (sender_tx, sender_rx) = mpsc::channel(3); - let shared = Arc::new(SharedDependencies { - bot, - sender_tx, - db: DatabaseConnection::new(db.clone(), None).shared(), - hard_shutdown: watch::Sender::new(false), - next_message_cache: Cache::new(Lru::new(15)), - }); + fn receive_updates(&self) -> impl Stream)> + 'static { + let db = self.db.get_dedicated(); - let mut sender_handle = tokio::spawn(sender_task(shared.clone(), sender_rx)); - - let mut soft_shutdown = false; - - let mut manager = BroadcastManager { - shared: &shared, - latest_entry_id: None, - states: HashMap::new(), - process_next_message: |shared, chat_id| async move { - let result = try_process_next(shared, chat_id).await; - (chat_id, result) - }, - processing: FuturesUnordered::new(), - }; + stream::unfold( + (None, false, db), + |(last_stream_id, was_error, mut db)| async move { + if was_error { + sleep(Duration::from_secs(20)).await; + } - let conn = DatabaseConnection::new(db, None); - let mut next_message_ready = pin!(manager.next_message_ready(conn, false)); + let result: Result<_, Self::Error> = async { + let next_id = if let Some(id) = last_stream_id { + db.next_message_id_blocking(id).await? + } else { + db.current_message_id().await? 
+ }; + let active_chats = db.get_active_chats().await?; + Ok((next_id, active_chats)) + } + .await; - while !(soft_shutdown && manager.processing.is_empty()) { - tokio::select! { - biased; - _ = &mut sender_handle => return, - signal = shutdown_rx.recv() => match signal { - Some(ShutdownSignal::Soft) => soft_shutdown = true, - Some(ShutdownSignal::Hard) | None => break - }, - (result, conn) = &mut next_message_ready => { let was_error = result.is_err(); + let item = result.ok(); + let stream_id = item.as_ref().map(|item| item.0).or(last_stream_id); - match result { - Ok((id, active_chats)) => manager.on_message_scheduled(id, active_chats), - Err(e) => log::error!("Database error: {e}") - }; - - next_message_ready.set(manager.next_message_ready(conn, was_error)); + Some((item, (stream_id, was_error, db))) }, - Some(result) = manager.processing.next(), if !manager.processing.is_empty() => { - manager.on_processing_finished(result.0, result.1); - } - } + ) + .filter_map(future::ready) } - // notifiy the sender task to stop after the next message - shared.hard_shutdown.send_replace(true); - let _ = sender_handle.await; + async fn send(&self, chat_id: i64, message: &Self::Message) -> Result<(), frankenstein::Error> { + let message = &message.1; + let mut params = message.request.clone(); + params.chat_id = chat_id.into(); + params.link_preview_options = Some(LinkPreviewOptions::builder().is_disabled(true).build()); + + self.bot.send_message(¶ms).await?; + + Ok(()) + } } diff --git a/src/broadcasting/scheduled_message.rs b/src/broadcasting/scheduled_message.rs deleted file mode 100644 index 6b5c2bd..0000000 --- a/src/broadcasting/scheduled_message.rs +++ /dev/null @@ -1,232 +0,0 @@ -use std::ops::ControlFlow; -use std::time::Duration; - -use frankenstein::response::{ErrorResponse, ResponseParameters}; -use frankenstein::types::LinkPreviewOptions; -use frankenstein::{AsyncTelegramApi, Error as RequestError}; -use regex::Regex; -use tokio::time::sleep; -use 
tokio_retry::strategy::{ExponentialBackoff, jitter}; - -use super::{ProcessNextResult, SharedDependencies}; -use crate::database::{self, StreamId}; -use crate::lru_cache::CacheItem; -use crate::types::{ChatId, Condition, Filter, Message}; - -/// error messages that imply we're not allowed to send messages -/// to this chat in the future. -const TELEGRAM_ERRORS: [&str; 14] = [ - "Bad Request: CHAT_WRITE_FORBIDDEN", - "Bad Request: TOPIC_CLOSED", - "Bad Request: chat not found", - "Bad Request: have no rights to send a message", - "Bad Request: not enough rights to send text messages to the chat", - "Bad Request: need administrator rights in the channel chat", - "Forbidden: bot is not a member of the channel chat", - "Forbidden: bot is not a member of the supergroup chat", - "Forbidden: bot was blocked by the user", - "Forbidden: bot was kicked from the channel chat", - "Forbidden: bot was kicked from the group chat", - "Forbidden: bot was kicked from the supergroup chat", - "Forbidden: the group chat was deleted", - "Forbidden: user is deactivated", -]; - -fn backoff_strategy() -> impl Iterator { - ExponentialBackoff::from_millis(10) - .factor(10) - .max_delay(Duration::from_secs(30)) - .map(jitter) - .take(5) -} - -impl Condition { - fn matches(&self, message: &Message) -> Result { - let regex = Regex::new(&self.pattern)?; - let result = message - .tags - .iter() - .filter(|x| x.0 == self.tag) - .any(|x| regex.is_match(&x.1)); - - Ok(result ^ self.negate) - } -} - -impl Filter { - fn matches(&self, message: &Message) -> Result { - for condition in &self.conditions { - if !condition.matches(message)? 
{ - return Ok(false); - } - } - - Ok(true) - } -} - -/// A message that is scheduled to be sent to a certain chat -pub struct ScheduledMessage { - chat_id: ChatId, - entry: CacheItem<(StreamId, Message)>, -} - -impl ScheduledMessage { - pub fn new(chat_id: ChatId, entry: CacheItem<(StreamId, Message)>) -> Self { - Self { chat_id, entry } - } - - pub fn message_id(&self) -> StreamId { - self.entry.0 - } - - pub fn message(&self) -> &Message { - &self.entry.1 - } - - /// checks whether this message should be sent - pub async fn check_filters(&self, shared: &SharedDependencies) -> database::Result { - let filters = shared.db.get_filters(self.chat_id).await?; - for filter in filters { - if filter.matches(self.message())? { - return Ok(true); - } - } - - Ok(false) - } - - /// mark this message as sent in the database - pub async fn acknowledge_message(&self, shared: &SharedDependencies) -> database::Result { - shared - .db - .acknowledge_message(self.chat_id, self.message_id()) - .await - } - - async fn unacknowledge_message(&self, shared: &SharedDependencies) -> database::Result { - shared - .db - .unacknowledge_message(self.chat_id, self.message_id()) - .await - } - - async fn handle_response( - &self, - shared: &SharedDependencies, - response: Result<(), RequestError>, - backoff: Option, - ) -> database::Result> { - let response = response.inspect_err(|e| log::warn!("Failed to send message: {e}")); - - macro_rules! retry { - ($dur:expr) => { - if self.unacknowledge_message(shared).await? { - return Ok(ControlFlow::Continue($dur)); - } else { - ProcessNextResult::OutOfSync - } - }; - } - - let result = match response { - Ok(()) => ProcessNextResult::Processed(self.message_id()), - Err(RequestError::Api(e)) => match e { - ErrorResponse { - error_code: 401 | 404, - .. - } => { - log::error!("Invalid token! 
Was it revoked?"); - shared.hard_shutdown.send_replace(true); - ProcessNextResult::ShuttingDown - } - ErrorResponse { - parameters: - Some(ResponseParameters { - migrate_to_chat_id: Some(new_chat_id), - .. - }), - .. - } => { - self.unacknowledge_message(shared).await?; - shared.db.migrate_chat(self.chat_id, new_chat_id).await?; - ProcessNextResult::MigratedTo(new_chat_id) - } - ErrorResponse { - parameters: - Some(ResponseParameters { - retry_after: Some(secs), - .. - }), - .. - } => { - retry!(Duration::from_secs(secs as u64)) - } - ErrorResponse { description, .. } - if TELEGRAM_ERRORS.contains(&description.as_str()) => - { - shared.db.remove_subscription(self.chat_id).await?; - ProcessNextResult::ChatStopped - } - _ => { - if let Some(backoff) = backoff { - retry!(backoff) - } else { - log::error!("Sending failed definitely, not retrying!"); - ProcessNextResult::Processed(self.message_id()) - } - } - }, - _ => { - if let Some(backoff) = backoff { - retry!(backoff) - } else { - log::error!("Sending failed definitely, not retrying!"); - ProcessNextResult::Processed(self.message_id()) - } - } - }; - - Ok(ControlFlow::Break(result)) - } - - /// Sends a message. Will retry a number of times if it fails - pub async fn send_message( - &self, - shared: &SharedDependencies, - message_sent: &mut bool, - ) -> database::Result { - let mut backoff = backoff_strategy(); - - loop { - *message_sent = false; - - if !self.acknowledge_message(shared).await? { - return Ok(ProcessNextResult::OutOfSync); - } - - let response = self.try_send_message(shared).await; - *message_sent = true; - - match self - .handle_response(shared, response, backoff.next()) - .await? 
- { - ControlFlow::Break(result) => return Ok(result), - ControlFlow::Continue(retry_after) => sleep(retry_after).await, - } - } - } - - async fn try_send_message(&self, shared: &SharedDependencies) -> Result<(), RequestError> { - let message = self.message(); - - let mut params = message.request.clone(); - params.chat_id = self.chat_id.into(); - params.link_preview_options = Some(LinkPreviewOptions::builder().is_disabled(true).build()); - - shared.bot.send_message(&params).await?; - - Ok(()) - } -} diff --git a/src/lru_cache.rs b/src/lru_cache.rs index 39c6076..5458cad 100644 --- a/src/lru_cache.rs +++ b/src/lru_cache.rs @@ -189,8 +189,6 @@ impl> Cache { } } - - pub type LruCache = Cache>; #[cfg(test)] mod tests { use std::time::Duration; diff --git a/src/main.rs b/src/main.rs index 6c8a984..a422256 100644 --- a/src/main.rs +++ b/src/main.rs @@ -19,11 +19,12 @@ use std::error::Error; use std::process::ExitCode; use std::time::Duration; -use broadcasting::broadcast_task; +use bot_utils::broadcasting::Broadcaster; +use broadcasting::RedisBackend; use clap::Parser; use database::DatabaseConnection; use redis::{ConnectionInfo, IntoConnectionInfo}; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::oneshot; use url::Url; use crate::allris::AllrisUrl; @@ -162,8 +163,7 @@ async fn main() -> ExitCode { let scraper_handle = tokio::spawn(scraper_task); // start the broadcasting task - let (ctrl_tx, ctrl_rx) = mpsc::unbounded_channel(); - let mut broadcaster_handle = tokio::spawn(broadcast_task(bot, db_client, ctrl_rx)); + let mut broadcaster = Broadcaster::new(RedisBackend::new(bot, db_client)); // listen for CTRL+C tokio::signal::ctrl_c() @@ -177,17 +177,15 @@ async fn main() -> ExitCode { // wait until message queue is empty, unless CTRL+C is pressed a second time // or 20 seconds have passed - let _ = ctrl_tx.send(broadcasting::ShutdownSignal::Soft); let success = tokio::select! 
{ - _ = &mut broadcaster_handle => true, + _ = broadcaster.soft_shutdown() => true, _ = tokio::signal::ctrl_c() => false, _ = tokio::time::sleep(Duration::from_secs(20)) => false }; if !success { log::warn!("Not all pending messages have been sent, shutting down anyway ..."); - let _ = ctrl_tx.send(broadcasting::ShutdownSignal::Hard); - let _ = broadcaster_handle.await; + broadcaster.hard_shutdown().await; } // We want users to be able to stop broadcasts even if we're in the process of shutting down, diff --git a/src/types.rs b/src/types.rs index 0def40e..4b2298e 100644 --- a/src/types.rs +++ b/src/types.rs @@ -3,8 +3,6 @@ use std::fmt::Display; use frankenstein::methods::SendMessageParams; use serde::{Deserialize, Serialize}; -pub type ChatId = i64; - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Message { pub request: SendMessageParams,