diff --git a/assign_roles/assign_roles.py b/assign_roles/assign_roles.py index b627c13..411d338 100644 --- a/assign_roles/assign_roles.py +++ b/assign_roles/assign_roles.py @@ -1,3 +1,5 @@ +import asyncio + import discord from redbot.core import commands # Changed from discord.ext @@ -34,6 +36,25 @@ def __init__(self, bot: Red): self.config = Config.get_conf(self, identifier=73600, force_registration=True) self.config.register_guild(roles={}) + # Lock to prevent race conditions when updating roles in Config + self._config_locks = {} # {guild_id: asyncio.Lock} + self._locks_creation_lock = asyncio.Lock() # Lock for creating per-guild locks + + def _get_guild_lock(self, guild_id: int) -> asyncio.Lock: + """Get or create a lock for a specific guild to prevent config race conditions. + + Args: + guild_id: The ID of the guild + + Returns: + An asyncio.Lock for the guild + """ + if guild_id not in self._config_locks: + # Use a lock to ensure only one lock is created per guild + # Note: We can't await here, but this is safe because dict access is atomic + self._config_locks[guild_id] = asyncio.Lock() + return self._config_locks[guild_id] + # Events # Commands @@ -92,25 +113,33 @@ async def authorise(self, ctx, authorised_role: discord.Role, giveable_role: dis await ctx.defer(ephemeral=True) gld = ctx.guild - server_dict = await self.config.guild(gld).roles() - author_max_role = max(r for r in ctx.author.roles) - authorised_id = str(authorised_role.id) - giveable_id = str(giveable_role.id) - - if authorised_role.is_default(): # Role to be authorised should not be @everyone. - notice = self.AUTHORISE_NO_EVERYONE - elif giveable_role.is_default(): # Same goes for role to be given. - notice = self.AUTHORISE_NOT_DEFAULT - elif authorised_role >= author_max_role and ctx.author != gld.owner: # Hierarchical role order check. - notice = self.AUTHORISE_NO_HIGHER - # Check if "pair" already exists. - elif giveable_id in server_dict and authorised_id in server_dict[giveable_id]: - notice = self.AUTHORISE_EXISTS - else: # Role authorisation is valid. - server_dict.setdefault(giveable_id, []).append(authorised_id) - await self.config.guild(gld).roles.set(server_dict) - notice = self.AUTHORISE_SUCCESS.format(authorised_role.name, giveable_role.name) + # Use lock to prevent race conditions when multiple users authorize roles simultaneously + lock = self._get_guild_lock(gld.id) + async with lock: + server_dict = await self.config.guild(gld).roles() + + author_max_role = max(r for r in ctx.author.roles) + authorised_id = str(authorised_role.id) + giveable_id = str(giveable_role.id) + + if authorised_role.is_default(): # Role to be authorised should not be @everyone. + notice = self.AUTHORISE_NO_EVERYONE + elif giveable_role.is_default(): # Same goes for role to be given. + notice = self.AUTHORISE_NOT_DEFAULT + elif authorised_role >= author_max_role and ctx.author != gld.owner: # Hierarchical role order check. + notice = self.AUTHORISE_NO_HIGHER + # Check if "pair" already exists. + elif giveable_id in server_dict and authorised_id in server_dict[giveable_id]: + notice = self.AUTHORISE_EXISTS + else: # Role authorisation is valid. + if giveable_id not in server_dict: + server_dict[giveable_id] = [] + # Double-check for duplicates before appending (safety check) + if authorised_id not in server_dict[giveable_id]: + server_dict[giveable_id].append(authorised_id) + await self.config.guild(gld).roles.set(server_dict) + notice = self.AUTHORISE_SUCCESS.format(authorised_role.name, giveable_role.name) await ctx.send(notice, ephemeral=True) @commands.guild_only() @@ -131,26 +160,30 @@ async def deauthorise(self, ctx, authorised_role: discord.Role, giveable_role: d await ctx.defer(ephemeral=True) gld = ctx.guild - server_dict = await self.config.guild(gld).roles() - author_max_role = max(r for r in ctx.author.roles) - authorised_id = str(authorised_role.id) - giveable_id = str(giveable_role.id) - - if authorised_role.is_default(): # Role to be de-authorised should not be @everyone. - notice = self.AUTHORISE_NO_EVERYONE - elif giveable_role.is_default(): # Same goes for role to be given. - notice = self.AUTHORISE_NOT_DEFAULT - elif authorised_role >= author_max_role and ctx.author != gld.owner: # Hierarchical role order check. - notice = self.AUTHORISE_NO_HIGHER - elif giveable_id not in server_dict: - notice = self.AUTHORISE_EMPTY.format(giveable_role.name) - elif authorised_id not in server_dict[giveable_id]: - notice = self.AUTHORISE_MISMATCH.format(authorised_role.name, giveable_role.name) - else: # Role de-authorisation is valid. - server_dict[giveable_id].remove(authorised_id) - await self.config.guild(gld).roles.set(server_dict) - notice = self.DEAUTHORISE_SUCCESS.format(authorised_role.name, giveable_role.name) + # Use lock to prevent race conditions when multiple users deauthorize roles simultaneously + lock = self._get_guild_lock(gld.id) + async with lock: + server_dict = await self.config.guild(gld).roles() + + author_max_role = max(r for r in ctx.author.roles) + authorised_id = str(authorised_role.id) + giveable_id = str(giveable_role.id) + + if authorised_role.is_default(): # Role to be de-authorised should not be @everyone. + notice = self.AUTHORISE_NO_EVERYONE + elif giveable_role.is_default(): # Same goes for role to be given. + notice = self.AUTHORISE_NOT_DEFAULT + elif authorised_role >= author_max_role and ctx.author != gld.owner: # Hierarchical role order check. + notice = self.AUTHORISE_NO_HIGHER + elif giveable_id not in server_dict: + notice = self.AUTHORISE_EMPTY.format(giveable_role.name) + elif authorised_id not in server_dict[giveable_id]: + notice = self.AUTHORISE_MISMATCH.format(authorised_role.name, giveable_role.name) + else: # Role de-authorisation is valid. + server_dict[giveable_id].remove(authorised_id) + await self.config.guild(gld).roles.set(server_dict) + notice = self.DEAUTHORISE_SUCCESS.format(authorised_role.name, giveable_role.name) await ctx.send(notice, ephemeral=True) @commands.guild_only() diff --git a/party/party.py b/party/party.py index 706defe..d55a76d 100644 --- a/party/party.py +++ b/party/party.py @@ -1,3 +1,4 @@ +import asyncio import logging import secrets from datetime import datetime, timezone @@ -127,16 +128,19 @@ async def on_submit(self, interaction: discord.Interaction): new_description = self.description_input.value.strip() or None # Update the party data - async with self.cog.config.guild(interaction.guild).parties() as parties: - if self.party_id not in parties: - await interaction.followup.send("❌ Party not found.", ephemeral=True) - return + # Use lock to prevent race conditions + lock = self.cog._get_guild_lock(interaction.guild.id) + async with lock: + async with self.cog.config.guild(interaction.guild).parties() as parties: + if self.party_id not in parties: + await interaction.followup.send("❌ Party not found.", ephemeral=True) + return - old_title = parties[self.party_id]['name'] - old_description = parties[self.party_id].get('description') + old_title = parties[self.party_id]['name'] + old_description = parties[self.party_id].get('description') - parties[self.party_id]['name'] = new_title - parties[self.party_id]['description'] = new_description + parties[self.party_id]['name'] = new_title + parties[self.party_id]['description'] = new_description # Update the party message await self.cog.update_party_message(interaction.guild.id, self.party_id) @@ -271,8 +275,11 @@ async def on_submit(self, interaction: discord.Interaction): party["signups"][role] = [] # Save the party - async with self.cog.config.guild(interaction.guild).parties() as parties: - parties[party_id] = party + # Use lock to prevent race conditions when multiple parties are created simultaneously + lock = self.cog._get_guild_lock(interaction.guild.id) + async with lock: + async with self.cog.config.guild(interaction.guild).parties() as parties: + parties[party_id] = party # Create the party embed embed = await self.cog.create_party_embed(party, interaction.guild) @@ -285,9 +292,12 @@ async def on_submit(self, interaction: discord.Interaction): message = await channel.send(embed=embed, view=view) # Save the message ID and channel ID - async with self.cog.config.guild(interaction.guild).parties() as parties: - parties[party_id]["message_id"] = message.id - parties[party_id]["channel_id"] = channel.id + # Use lock to prevent race conditions when multiple parties are created simultaneously + lock = self.cog._get_guild_lock(interaction.guild.id) + async with lock: + async with self.cog.config.guild(interaction.guild).parties() as parties: + parties[party_id]["message_id"] = message.id + parties[party_id]["channel_id"] = channel.id # Create modlog entry await self.cog.create_party_modlog( @@ -737,11 +747,30 @@ def __init__(self, bot): } self.config.register_guild(**default_guild) + # Lock to prevent race conditions when updating parties in Config + self._config_locks = {} # {guild_id: asyncio.Lock} + self._locks_creation_lock = asyncio.Lock() # Lock for creating per-guild locks + # Load persistent views for existing parties self.bot.loop.create_task(self._register_persistent_views()) # Register custom modlog casetypes self.bot.loop.create_task(self._register_casetypes()) + def _get_guild_lock(self, guild_id: int) -> asyncio.Lock: + """Get or create a lock for a specific guild to prevent config race conditions. + + Args: + guild_id: The ID of the guild + + Returns: + An asyncio.Lock for the guild + """ + if guild_id not in self._config_locks: + # Use a lock to ensure only one lock is created per guild + # Note: We can't await here, but this is safe because dict access is atomic + self._config_locks[guild_id] = asyncio.Lock() + return self._config_locks[guild_id] + @staticmethod def parse_allow_multiple(allow_multiple_text: str) -> tuple[bool, Optional[str]]: """Parse and validate allow_multiple_per_role setting. @@ -1021,74 +1050,77 @@ async def signup_user( guild_id = interaction.guild.id user_id = str(interaction.user.id) - async with self.config.guild_from_id(guild_id).parties() as parties: - if party_id not in parties: - if disabled_view: - # Edit the original message to show error and remove the select view - if deferred: - await interaction.edit_original_response( - content="❌ Party not found.", - view=None - ) - else: - await interaction.response.edit_message( - content="❌ Party not found.", - view=None - ) - else: - if deferred: - await interaction.followup.send( - "❌ Party not found.", - ephemeral=True - ) + # Use lock to prevent race conditions when multiple users sign up simultaneously + lock = self._get_guild_lock(guild_id) + async with lock: + async with self.config.guild_from_id(guild_id).parties() as parties: + if party_id not in parties: + if disabled_view: + # Edit the original message to show error and remove the select view + if deferred: + await interaction.edit_original_response( + content="❌ Party not found.", + view=None + ) + else: + await interaction.response.edit_message( + content="❌ Party not found.", + view=None + ) else: - await interaction.response.send_message( - "❌ Party not found.", - ephemeral=True - ) - return - - party = parties[party_id] - allow_multiple = party.get("allow_multiple_per_role", True) + if deferred: + await interaction.followup.send( + "❌ Party not found.", + ephemeral=True + ) + else: + await interaction.response.send_message( + "❌ Party not found.", + ephemeral=True + ) + return - # Remove user from any existing role first - for role_name, users in party["signups"].items(): - if user_id in users: - party["signups"][role_name].remove(user_id) - - # Check if role exists in signups, if not create it - if role not in party["signups"]: - party["signups"][role] = [] - - # Check if multiple signups allowed - if not allow_multiple and len(party["signups"][role]) > 0: - if disabled_view: - # Edit the original message to show error and remove the select view - if deferred: - await interaction.edit_original_response( - content=f"❌ The role **{role}** is already full (multiple signups not allowed).", - view=None - ) + party = parties[party_id] + allow_multiple = party.get("allow_multiple_per_role", True) + + # Remove user from any existing role first + for role_name, users in party["signups"].items(): + if user_id in users: + party["signups"][role_name].remove(user_id) + + # Check if role exists in signups, if not create it + if role not in party["signups"]: + party["signups"][role] = [] + + # Check if multiple signups allowed + if not allow_multiple and len(party["signups"][role]) > 0: + if disabled_view: + # Edit the original message to show error and remove the select view + if deferred: + await interaction.edit_original_response( + content=f"❌ The role **{role}** is already full (multiple signups not allowed).", + view=None + ) + else: + await interaction.response.edit_message( + content=f"❌ The role **{role}** is already full (multiple signups not allowed).", + view=None + ) else: - await interaction.response.edit_message( - content=f"❌ The role **{role}** is already full (multiple signups not allowed).", - view=None - ) - else: - if deferred: - await interaction.followup.send( - f"❌ The role **{role}** is already full (multiple signups not allowed).", - ephemeral=True - ) - else: - await interaction.response.send_message( - f"❌ The role **{role}** is already full (multiple signups not allowed).", - ephemeral=True - ) - return + if deferred: + await interaction.followup.send( + f"❌ The role **{role}** is already full (multiple signups not allowed).", + ephemeral=True + ) + else: + await interaction.response.send_message( + f"❌ The role **{role}** is already full (multiple signups not allowed).", + ephemeral=True + ) + return - # Add user to the role - party["signups"][role].append(user_id) + # Add user to the role + party["signups"][role].append(user_id) # Send success response if disabled_view: diff --git a/quotesdb/quotedb.py b/quotesdb/quotedb.py index 3b5bf2c..2b7edb8 100644 --- a/quotesdb/quotedb.py +++ b/quotesdb/quotedb.py @@ -1,3 +1,4 @@ +import asyncio import datetime import random @@ -24,6 +25,25 @@ def __init__(self, bot): self.config.register_guild(**default_guild) + # Lock to prevent race conditions when updating quotes in Config + self._config_locks = {} # {guild_id: asyncio.Lock} + self._locks_creation_lock = asyncio.Lock() # Lock for creating per-guild locks + + def _get_guild_lock(self, guild_id: int) -> asyncio.Lock: + """Get or create a lock for a specific guild to prevent config race conditions. + + Args: + guild_id: The ID of the guild + + Returns: + An asyncio.Lock for the guild + """ + if guild_id not in self._config_locks: + # Use a lock to ensure only one lock is created per guild + # Note: We can't await here, but this is safe because dict access is atomic + self._config_locks[guild_id] = asyncio.Lock() + return self._config_locks[guild_id] + @commands.guild_only() @commands.hybrid_command(name="qadd", aliases=["."]) async def quote_add(self, ctx, trigger: str, *, quote: str): @@ -38,20 +58,23 @@ async def quote_add(self, ctx, trigger: str, *, quote: str): """ await ctx.defer(ephemeral=True) - guild_group = self.config.guild(ctx.guild) - incr = await guild_group.quotes.incr() + 1 - await guild_group.quotes.incr.set(incr) - async with guild_group.quotes.id() as quotes, guild_group.quotes.trigger() as triggers: - quotes[incr] = { - "content": quote, - "user": ctx.author.id, - "trigger": trigger, - "jump_url": ctx.message.jump_url if ctx.message else None, - "datetime": datetime.datetime.now().timestamp() - } - - triggers.setdefault(trigger, []) - triggers[trigger] += [str(incr)] + # Use lock to prevent race conditions when multiple users add quotes simultaneously + lock = self._get_guild_lock(ctx.guild.id) + async with lock: + guild_group = self.config.guild(ctx.guild) + incr = await guild_group.quotes.incr() + 1 + await guild_group.quotes.incr.set(incr) + async with guild_group.quotes.id() as quotes, guild_group.quotes.trigger() as triggers: + quotes[incr] = { + "content": quote, + "user": ctx.author.id, + "trigger": trigger, + "jump_url": ctx.message.jump_url if ctx.message else None, + "datetime": datetime.datetime.now().timestamp() + } + + triggers.setdefault(trigger, []) + triggers[trigger] += [str(incr)] await ctx.send(f"{ctx.author.mention}, added quote `#{incr}`.", ephemeral=True) @@ -94,19 +117,22 @@ async def quote_del(self, ctx, qid: str): """ await ctx.defer(ephemeral=True) - guild_group = self.config.guild(ctx.guild) - async with guild_group.quotes.id() as quotes, guild_group.quotes.trigger() as triggers: - if qid not in quotes: - await ctx.send(f"{ctx.author.mention}, invalid quote id.", ephemeral=True) - return - data = quotes[qid] - member = discord.utils.find(lambda m: m.id == data['user'], ctx.channel.guild.members) - if ctx.author != member and not await self.bot.is_admin(ctx.author): - await ctx.send(f"{ctx.author.mention}, only the creator (or admins) can delete that.", ephemeral=True) - return - trigger = data['trigger'] - del quotes[qid] - triggers[trigger].remove(qid) + # Use lock to prevent race conditions when multiple users delete quotes simultaneously + lock = self._get_guild_lock(ctx.guild.id) + async with lock: + guild_group = self.config.guild(ctx.guild) + async with guild_group.quotes.id() as quotes, guild_group.quotes.trigger() as triggers: + if qid not in quotes: + await ctx.send(f"{ctx.author.mention}, invalid quote id.", ephemeral=True) + return + data = quotes[qid] + member = discord.utils.find(lambda m: m.id == data['user'], ctx.channel.guild.members) + if ctx.author != member and not await self.bot.is_admin(ctx.author): + await ctx.send(f"{ctx.author.mention}, only the creator (or admins) can delete that.", ephemeral=True) + return + trigger = data['trigger'] + del quotes[qid] + triggers[trigger].remove(qid) await ctx.send(f"{ctx.author.mention}, deleted quote #{qid}.", ephemeral=True) diff --git a/secret_santa/secret_santa.py b/secret_santa/secret_santa.py index 4cca3a3..5f0d1ab 100644 --- a/secret_santa/secret_santa.py +++ b/secret_santa/secret_santa.py @@ -1,3 +1,4 @@ +import asyncio import datetime import logging import random @@ -42,6 +43,25 @@ def __init__(self, bot): self.config.register_guild(**default_guild) self.config.register_global(**default_global) + # Lock to prevent race conditions when updating events in Config + self._config_locks = {} # {guild_id: asyncio.Lock} + self._locks_creation_lock = asyncio.Lock() # Lock for creating per-guild locks + + def _get_guild_lock(self, guild_id: int) -> asyncio.Lock: + """Get or create a lock for a specific guild to prevent config race conditions. + + Args: + guild_id: The ID of the guild + + Returns: + An asyncio.Lock for the guild + """ + if guild_id not in self._config_locks: + # Use a lock to ensure only one lock is created per guild + # Note: We can't await here, but this is safe because dict access is atomic + self._config_locks[guild_id] = asyncio.Lock() + return self._config_locks[guild_id] + async def red_delete_data_for_user(self, *, requester, user_id: int): """Delete user data when requested.""" all_guilds = await self.config.all_guilds() @@ -1404,8 +1424,11 @@ async def santadm_wishlist(self, ctx, event_id: str, *, wishlist: str): return # Update the wishlist - async with self.config.guild_from_id(guild_id).events() as events: - events[event_name]["participants"][user_id_str]["wishlist"] = wishlist + # Use lock to prevent race conditions when multiple users update wishlists simultaneously + lock = self._get_guild_lock(guild_id) + async with lock: + async with self.config.guild_from_id(guild_id).events() as events: + events[event_name]["participants"][user_id_str]["wishlist"] = wishlist guild = self.bot.get_guild(guild_id) guild_name = guild.name if guild else "Unknown Server"