Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/data/base_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def delete(self, model: BaseModel):
sql = text(f"""DELETE FROM {model.table}
WHERE {model.pk_field} = :pk""")

return self.execute(sql, pk=getattr(model, model.pk_field))
with session_scope() as session:
return session.execute(sql, {"pk": getattr(model, model.pk_field)}).rowcount

def execute(self, sql: Union[str, text], **kwargs):
if isinstance(sql, str):
Expand Down
34 changes: 34 additions & 0 deletions src/data/rabbitmq_pending_message_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Optional

from sqlalchemy.sql import text

from data.base_data import BaseModel, BaseData


class RabbitmqPendingMessageModel(BaseModel):
_table = "rabbitmq_pending_messages"
_pk_field = "id"
_columns = ["id", "type", "exchange_name", "queue_name", "json_body", "created_time"]


class RabbitmqPendingMessageData(BaseData):
def get_rabbitmq_pending_message_by_id(self, id: int) -> Optional[RabbitmqPendingMessageModel]:
sql = text("""
SELECT * FROM rabbitmq_pending_messages
WHERE id = :id;
""")

result_rows = self.execute(sql, id=id)
if not result_rows:
return None

return RabbitmqPendingMessageModel(result_rows[0])

def get_rabbitmq_pending_messages(self) -> list[RabbitmqPendingMessageModel]:
sql = text("""
SELECT * FROM rabbitmq_pending_messages
ORDER BY created_time ASC;
""")

result_rows = self.execute(sql)
return [RabbitmqPendingMessageModel(row) for row in result_rows]
62 changes: 47 additions & 15 deletions src/services/rabbit_service.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import json
import pika
import praw
from collections import deque
from json import JSONEncoder
from praw.models.mod_action import ModAction
from praw.models.reddit.comment import Comment
from praw.models.reddit.submission import Submission
from types import FunctionType
from typing import Deque
from uuid import UUID

from services import rabbitmq_pending_message_service
from data.comment_data import CommentModel
from data.mod_action_data import ModActionModel
from data.post_data import PostModel
Expand All @@ -34,8 +33,6 @@ def default(self, obj):


class RabbitService:
messages_to_retry: Deque[dict] = deque()

def __init__(self, config_dict: dict):
self.config = config_dict
self.connection = None
Expand Down Expand Up @@ -79,11 +76,7 @@ def __init__(self, config_dict: dict):

self.queues[key] = {"exchange": exchange_name, "queue": queue_name}

if self.messages_to_retry:
logger.info(f"Retrying {len(self.messages_to_retry)} RabbitMQ messages")
while self.messages_to_retry:
exchange_name, queue_name, json_body = self.messages_to_retry.popleft()
self._publish_message(exchange_name, queue_name, json_body)
self._republish_messages()

def init_connection(self, reconnect: bool = True):
logger.info(f"{"Rec" if reconnect else "C"}onnecting to RabbitMQ...")
Expand All @@ -95,21 +88,21 @@ def publish_post(self, reddit_post: Submission, post: PostModel, status: str = "
logger.info(f"Publishing post to RabbitMQ: {reddit_post.id} ({status})")
queue = self.queues["post"]
body = {"status": status, "reddit": reddit_post, "db": post.to_dict()}
self._publish_message(queue["exchange"], queue["queue"], json.dumps(body, cls=PRAWJSONEncoder))
self._publish_message(queue["exchange"], queue["queue"], json.dumps(body, cls=PRAWJSONEncoder), "post")

def publish_comment(self, reddit_comment: Comment, comment: CommentModel, status: str = "new"):
logger.info(f"Publishing comment to RabbitMQ: {reddit_comment.id} ({status})")
queue = self.queues["comment"]
body = {"status": status, "reddit": reddit_comment, "db": comment.to_dict()}
self._publish_message(queue["exchange"], queue["queue"], json.dumps(body, cls=PRAWJSONEncoder))
self._publish_message(queue["exchange"], queue["queue"], json.dumps(body, cls=PRAWJSONEncoder), "comment")

def publish_mod_action(self, reddit_mod_action: ModAction, mod_action: ModActionModel, status: str = "new"):
logger.info(f"Publishing mod action to RabbitMQ: {reddit_mod_action.id} ({status})")
queue = self.queues["mod_action"]
body = {"status": status, "reddit": reddit_mod_action, "db": mod_action.to_dict()}
self._publish_message(queue["exchange"], queue["queue"], json.dumps(body, cls=PRAWJSONEncoder))
self._publish_message(queue["exchange"], queue["queue"], json.dumps(body, cls=PRAWJSONEncoder), "mod action")

def _publish_message(self, exchange_name: str, queue_name: str, json_body: str):
def _publish_message(self, exchange_name: str, queue_name: str, json_body: str, type: str):
try:
self.channel.basic_publish(
exchange=exchange_name,
Expand All @@ -136,6 +129,45 @@ def _publish_message(self, exchange_name: str, queue_name: str, json_body: str):
)
logger.info("Successfully send message after reconnect")
except Exception:
logger.error("Still couldn't connect to RabbitMQ. Saving message to retry memory list")
self.messages_to_retry.append((exchange_name, queue_name, json_body))
logger.exception("Still couldn't connect to RabbitMQ. Saving message to retry table")
rabbitmq_pending_message_service.insert_pending_message(exchange_name, queue_name, json_body, type)
raise

def _republish_messages(self):
messages_to_retry = rabbitmq_pending_message_service.get_pending_messages()
if messages_to_retry:
logger.info(f"Retrying {len(messages_to_retry)} RabbitMQ messages")
number_of_successful_retries = 0
for pending_message in messages_to_retry:
logger.info(
f"Republishing {pending_message.type} to RabbitMQ"
+ f": {pending_message.json_body["reddit"]["id"]} ({pending_message.json_body["status"]})"
)
try:
json_body = json.dumps(pending_message.json_body, cls=PRAWJSONEncoder)
self._republish_message(pending_message.exchange_name, pending_message.queue_name, json_body)
number_of_successful_retries += rabbitmq_pending_message_service.delete_pending_message(pending_message)
except Exception:
if number_of_successful_retries > 0:
logger.warn(
f"Successfully retried {number_of_successful_retries}"
+ f" / {len(messages_to_retry)} RabbitMQ messages"
)
logger.exception("RabbitMQ is still down. Messages will stay pending.")
raise
logger.info(f"Successfully retried all {number_of_successful_retries} RabbitMQ messages")

def _republish_message(self, exchange_name: str, queue_name: str, json_body: str):
try:
self.channel.basic_publish(
exchange=exchange_name,
routing_key=queue_name,
body=json_body,
properties=pika.BasicProperties(
delivery_mode=pika.DeliveryMode.Persistent,
content_type="application/json",
headers={self.config["retry_attempt_header"]: 1},
),
)
except Exception:
logger.exception("Still couldn't connect to RabbitMQ. Message already saved to retry")
40 changes: 40 additions & 0 deletions src/services/rabbitmq_pending_message_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Optional

from data.rabbitmq_pending_message_data import RabbitmqPendingMessageData, RabbitmqPendingMessageModel

_rabbitmq_pending_message_data = RabbitmqPendingMessageData()


def get_pending_message_by_id(id: int) -> Optional[RabbitmqPendingMessageModel]:
"""
Gets a single rabbitmq_pending_message from the database.
"""

return _rabbitmq_pending_message_data.get_rabbitmq_pending_message_by_id(id)


def get_pending_messages() -> list[RabbitmqPendingMessageModel]:
"""
Get all rabbitmq_pending_messages in the DB. Ordered by created_time ascending
"""

return _rabbitmq_pending_message_data.get_rabbitmq_pending_messages()


def insert_pending_message(
exchange_name: str, queue_name: str, json_body: str, type: str
) -> RabbitmqPendingMessageModel:
"""Adds a new pending rabbitmq message to the database."""

db_model = RabbitmqPendingMessageModel()
db_model.exchange_name = exchange_name
db_model.queue_name = queue_name
db_model.json_body = json_body
db_model.type = type

saved_db_model = _rabbitmq_pending_message_data.insert(db_model)
return saved_db_model


def delete_pending_message(pending_message: RabbitmqPendingMessageModel):
return _rabbitmq_pending_message_data.delete(pending_message)
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Add rabbitmq pending table

Revision ID: a5f2ff2c36f8
Revises: 03f4360f81fc
Create Date: 2025-12-20 02:24:58.431032+00:00

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "a5f2ff2c36f8"
down_revision = "03f4360f81fc"
branch_labels = None
depends_on = None


def upgrade():
op.execute("""
CREATE TABLE rabbitmq_pending_messages (
id BIGSERIAL PRIMARY KEY,
type TEXT NOT NULL,
exchange_name TEXT NOT NULL,
queue_name TEXT NOT NULL,
json_body JSONB NOT NULL,
created_time TIMESTAMPTZ NOT NULL DEFAULT now()
);
CREATE UNIQUE INDEX IF NOT EXISTS
idx_rabbitmq_pending_messages_created_time ON rabbitmq_pending_messages(created_time);
""")


def downgrade():
op.execute("""
DROP TABLE IF EXISTS rabbitmq_pending_messages;
""")