-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathbot.py
More file actions
120 lines (94 loc) · 4.4 KB
/
bot.py
File metadata and controls
120 lines (94 loc) · 4.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# -*- coding: utf-8 -*-
import asyncio
import hashlib
import logging
import traceback
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional

import aiohttp
import aioredis
import asyncpg
import discord
from discord.ext import commands

if TYPE_CHECKING:
    from cogs.db.utils import DBUtils

class FlagBot(commands.Bot):
    """Discord bot holding a Postgres pool, a Redis pool, an HTTP session and auto-loaded cogs."""

    def __init__(self, *args, config=None, **kwargs):
        super().__init__(*args, **kwargs)
        # aioredis pool; created on the first ``on_ready``.
        self.redis = None
        # Settings mapping; keys read in this class include 'database', 'redis',
        # 'stats_channel', 'sanitize_channel' and 'admin_users'.
        self.config = config or {}
        # asyncpg pool; assigned by ``acquire_pool``.
        self.db = None
        # Set once ``self.db`` holds a usable pool.
        self.db_available = asyncio.Event()
        #: OAuth2 application owner.
        self.owner: Optional[discord.User] = None
        #: List of extension names to load. We store this because `self.extensions` is volatile during reload.
        self.to_load: Optional[List[str]] = None
        self.logger = logging.getLogger('flagbot')
        self.remove_command('help')
        self.session = aiohttp.ClientSession(loop=self.loop)
        self.loop.create_task(self.acquire_pool())

    async def acquire_pool(self):
        """Create the asyncpg pool from ``config['database']`` and set ``db_available``.

        The credentials are removed from ``self.config``. If none are configured
        the bot logs out and no pool is created.
        """
        # Default to None so a missing key takes the logged branch instead of raising KeyError.
        credentials = self.config.pop("database", None)
        if not credentials:
            self.logger.critical("Cannot connect to db, no credentials!")
            await self.logout()
            # Without this return, create_pool(**None) would raise TypeError.
            return
        self.db = await asyncpg.create_pool(**credentials)
        self.db_available.set()

    async def on_ready(self):
        """Open the Redis pool and load cogs the first time the bot becomes ready.

        discord.py may dispatch ``on_ready`` again after a reconnect, so the
        one-time setup is guarded: a second Redis pool would leak, and re-loading
        already loaded extensions raises ``ExtensionAlreadyLoaded``.
        """
        if self.redis is None:
            self.redis = await aioredis.create_redis_pool(**self.config['redis'])
        # ``discover_exts`` assigns ``to_load``, so None means cogs were never loaded.
        if self.to_load is None:
            self.discover_exts('cogs')
        self.logger.info('Ready! Logged in as %s (%d)', self.user, self.user.id)

    async def load_cache(self):
        """Load reviewer/scan channel lists into ``config`` and ensure webhooks exist.

        Raises:
            RuntimeError: if the ``DBUtils`` cog is not loaded.
        """
        conn = self.get_db()
        if conn is None:
            raise RuntimeError('Cannot load cache: the cog "DBUtils" is not loaded')
        self.config['reviewer_channels'] = await conn.load_reviewer_channels()
        self.config['scan_channels'] = await conn.load_scan_channels()
        for row in self.config['reviewer_channels']:
            await self._ensure_webhook(row['channel_id'])
        await self._ensure_webhook(self.config['stats_channel'])
        await self._ensure_webhook(self.config['sanitize_channel'])

    async def _ensure_webhook(self, channel_id):
        """Create a webhook named 'FlagBot' in ``channel_id`` if the channel has none."""
        channel = self.get_channel(channel_id) or await self.fetch_channel(channel_id)
        if not await channel.webhooks():
            await channel.create_webhook(name='FlagBot')

    def get_db(self) -> "Optional[DBUtils]":
        """Return the loaded ``DBUtils`` cog, or None (with a warning) if it is not loaded."""
        conn = self.get_cog('DBUtils')
        if conn is None:
            self.logger.warning("The cog \"DBUtils\" is not loaded")
        return conn

    async def on_command_error(self, ctx: commands.Context, exception):
        """Handle command errors.

        Cooldown/not-found/disabled/permission/check failures are ignored. Bad or
        missing arguments get a question-mark reaction. Any other error is logged
        with a SHA-256 hash of its traceback, and the short hash is sent to the channel.
        """
        msg = ctx.message
        if isinstance(exception, (commands.CommandOnCooldown, commands.CommandNotFound,
                                  commands.DisabledCommand, commands.MissingPermissions,
                                  commands.CheckFailure)):
            return  # we don't care about these
        if isinstance(exception, (commands.BadArgument, commands.MissingRequiredArgument)):
            try:
                await msg.add_reaction("\N{BLACK QUESTION MARK ORNAMENT}")
            except discord.HTTPException:
                pass
            return

        error_digest = "".join(traceback.format_exception(
            type(exception), exception, exception.__traceback__, limit=8))
        error_hash = hashlib.sha256(error_digest.encode("utf8")).hexdigest()
        short_hash = error_hash[:8]
        self.logger.error("Encountered command error [%s] (%s):\n%s",
                          error_hash, msg.id, error_digest)
        try:
            await ctx.send(f"Uh-oh, that's an error [{short_hash}...]")
        except discord.HTTPException:
            # Best effort (e.g. no permission to speak here); the error is already logged.
            pass

    async def is_owner(self, user):
        """Treat ids in ``config['admin_users']`` as owners, else defer to discord.py."""
        if user.id in self.config.get("admin_users", []):
            return True
        return await super().is_owner(user)

    def discover_exts(self, directory: str):
        """Load every ``*.py`` module found recursively under ``directory`` as an extension.

        Module names are the path parts joined with dots after stripping the
        suffix (``cogs/db/utils.py`` -> ``cogs.db.utils``), so ``directory`` should
        be a relative package path such as ``'cogs'``. ``__init__`` modules are
        skipped. Afterwards ``self.to_load`` holds the loaded extension names.
        """
        ignore = {'__pycache__', '__init__'}
        exts = [
            # with_suffix('') strips only the trailing ".py"; str.replace would also
            # mangle names that contain ".py" elsewhere (e.g. "cogs/pyramid.py").
            '.'.join(p.with_suffix('').parts)
            for p in Path(directory).glob('**/*.py')
            if p.stem not in ignore
        ]
        self.logger.info('Loading extensions: %s', exts)
        for ext in exts:
            self.load_extension(ext)
        self.to_load = list(self.extensions)
        self.logger.info('To load: %s', self.to_load)

    async def close(self):
        """Close the Discord connection, then release the HTTP session, Redis and DB pools.

        Safe to call more than once (``logout`` also ends up here): each resource
        is cleared after it has been closed.
        """
        try:
            await super().close()
        finally:
            if self.session is not None and not self.session.closed:
                await self.session.close()
            if self.redis is not None:
                redis, self.redis = self.redis, None
                redis.close()
                await redis.wait_closed()
            if self.db is not None:
                db, self.db = self.db, None
                await db.close()