Skip to content

Commit 2fac1ab

Browse files
jonasjuckergithub-actions
andauthored
use job queue (#35)
Co-authored-by: github-actions <github-actions@github.com>
1 parent f19a8b0 commit 2fac1ab

3 files changed

Lines changed: 77 additions & 119 deletions

File tree

bot.py

Lines changed: 63 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,22 @@
55
ConversationHandler, CallbackContext, ContextTypes)
66

77
from logger_config import logger
8-
from constants import TIMEOUT_IN_SEC, STATION_SELECT_ONE_TIME, STATION_SELECT_SUBSCRIBE, ONE_TIME, SUBSCRIBE, UNSUBSCRIBE, VALID_SUMMARY_INTERVALS
8+
from constants import TIMEOUT_IN_SEC, STATION_SELECT_ONE_TIME, STATION_SELECT_SUBSCRIBE, ONE_TIME, SUBSCRIBE, UNSUBSCRIBE, VALID_SUMMARY_INTERVALS, JOBQUEUE_DELAY, DEFAULT_USER_ID
99

1010

1111
class PlotBot:
1212

13-
def __init__(self, token, station_config, db=None, admin_id=None):
13+
def __init__(self,
14+
token,
15+
station_config,
16+
db=None,
17+
admin_id=None,
18+
ecmwf=None):
1419

1520
self._admin_id = admin_id
1621
self.app = Application.builder().token(token).build()
1722
self._db = db
23+
self._ecmwf = ecmwf
1824
self._station_names = sorted(
1925
[station["name"] for station in station_config])
2026
self._region_of_stations = {
@@ -24,14 +30,6 @@ def __init__(self, token, station_config, db=None, admin_id=None):
2430
self._station_regions = sorted(
2531
{station["region"]
2632
for station in station_config})
27-
self._subscriptions = {
28-
station: set()
29-
for station in self._station_names
30-
}
31-
self._one_time_forecast_requests = {
32-
station: set()
33-
for station in self._station_names
34-
}
3533
# filter for stations
3634
self._filter_stations = filters.Regex("^(" +
3735
"|".join(self._station_names) +
@@ -109,17 +107,53 @@ def __init__(self, token, station_config, db=None, admin_id=None):
109107
self.app.add_handler(one_time_forecast_handler)
110108
self.app.add_error_handler(self._error)
111109

112-
async def connect(self):
113-
await self.app.initialize()
114-
await self.app.updater.start_polling(allowed_updates=Update.ALL_TYPES)
115-
await self.app.start()
116-
logger.info('Bot connected')
110+
self.app.job_queue.run_once(
111+
self._override_basetime,
112+
when=0,
113+
name='Override basetime',
114+
)
115+
self.app.job_queue.run_repeating(
116+
self._update_basetime,
117+
interval=60,
118+
first=60,
119+
name='Update basetime',
120+
)
121+
self.app.job_queue.run_repeating(
122+
self._cache_plots,
123+
interval=30,
124+
first=30,
125+
name='Cache plots',
126+
)
127+
self.app.job_queue.run_repeating(
128+
self._broadcast_from_queue,
129+
interval=90,
130+
first=60,
131+
name='Broadcast',
132+
)
117133

118-
while True:
119-
await asyncio.sleep(1)
134+
async def _override_basetime(self, context: CallbackContext):
135+
self._ecmwf.override_base_time_from_init()
136+
137+
async def _update_basetime(self, context: CallbackContext):
138+
self._ecmwf.upgrade_basetime_global()
139+
self._ecmwf.upgrade_basetime_stations()
140+
141+
async def _send_plot_from_queue(self, context: CallbackContext):
142+
job = context.job
143+
user_id, station_name = job.data
144+
plots = self._ecmwf.download_plots([station_name])
145+
await self._send_plot_to_user(plots, station_name, user_id)
146+
147+
def start(self):
148+
logger.info('Starting bot')
149+
self.app.run_polling(allowed_updates=Update.ALL_TYPES)
120150

121151
async def _error(self, update: Update, context: CallbackContext):
122-
user_id = update.message.chat_id
152+
153+
if update:
154+
user_id = update.message.chat_id
155+
else:
156+
user_id = DEFAULT_USER_ID
123157
logger.error(f"Exception while handling an update: {context.error}")
124158
self._db.log_activity(
125159
activity_type="bot-error",
@@ -303,9 +337,11 @@ async def _subscribe_for_station(self, update: Update,
303337
reply_markup=ReplyKeyboardRemove(),
304338
)
305339
self._db.add_subscription(msg_text, user.id)
306-
self._subscriptions[msg_text].add(user.id)
307340

308341
logger.info(f' {user.first_name} subscribed for Station {msg_text}')
342+
context.job_queue.run_once(self._send_plot_from_queue,
343+
JOBQUEUE_DELAY,
344+
data=(user.id, msg_text))
309345

310346
self._db.log_activity(
311347
activity_type="subscription",
@@ -324,8 +360,10 @@ async def _request_one_time_forecast_for_station(
324360
reply_text,
325361
reply_markup=ReplyKeyboardRemove(),
326362
)
327-
self._one_time_forecast_requests[msg_text].add(user.id)
328363

364+
context.job_queue.run_once(self._send_plot_from_queue,
365+
JOBQUEUE_DELAY,
366+
data=(user.id, msg_text))
329367
logger.info(
330368
f' {user.first_name} requested forecast for Station {msg_text}')
331369

@@ -346,24 +384,8 @@ async def _cancel(self, update: Update, context: CallbackContext) -> int:
346384

347385
return ConversationHandler.END
348386

349-
def has_new_subscribers_waiting(self):
350-
return any(users for users in self._subscriptions.values())
351-
352-
def has_one_time_forecast_waiting(self):
353-
return any(users
354-
for users in self._one_time_forecast_requests.values())
355-
356-
def stations_of_one_time_request(self):
357-
return [
358-
station
359-
for station, users in self._one_time_forecast_requests.items()
360-
if users
361-
]
362-
363-
def stations_of_new_subscribers(self):
364-
return [
365-
station for station, users in self._subscriptions.items() if users
366-
]
387+
async def _cache_plots(self, context: CallbackContext):
388+
self._ecmwf.cache_plots()
367389

368390
async def _send_plot_to_user(self, plots, station_name, user_id):
369391
logger.debug(f'Send plot to user: {user_id}')
@@ -375,30 +397,9 @@ async def _send_plot_to_user(self, plots, station_name, user_id):
375397
except Exception as e:
376398
logger.error(f'Error sending plot to user {user_id}: {e}')
377399

378-
async def _send_plots(self, plots, requests):
379-
for station_name, users in requests.items():
380-
for user_id in users:
381-
await self._send_plot_to_user(plots, station_name, user_id)
382-
383-
async def send_plots_to_new_subscribers(self, plots):
384-
await self._send_plots(plots, self._subscriptions)
385-
logger.info('plots sent to new subscribers')
386-
387-
self._subscriptions = {
388-
station: set()
389-
for station in self._station_names
390-
}
391-
392-
async def send_one_time_forecast(self, plots):
393-
await self._send_plots(plots, self._one_time_forecast_requests)
394-
logger.info('plots sent to one time forecast requests')
395-
396-
self._one_time_forecast_requests = {
397-
station: set()
398-
for station in self._station_names
399-
}
400-
401-
async def broadcast(self, plots):
400+
async def _broadcast_from_queue(self, context: CallbackContext):
401+
plots = self._ecmwf.download_latest_plots(
402+
self._db.stations_with_subscribers())
402403
if plots:
403404
for station_name in plots:
404405
for user_id in self._db.get_subscriptions_by_station(

constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,7 @@
77
STATION_SELECT_ONE_TIME, STATION_SELECT_SUBSCRIBE, ONE_TIME, SUBSCRIBE, UNSUBSCRIBE = range(
88
5)
99
VALID_SUMMARY_INTERVALS = ['24 HOURS', '7 DAYS', '30 DAYS', '1 YEAR']
10+
11+
JOBQUEUE_DELAY = 10
12+
13+
DEFAULT_USER_ID = 999

main.py

Lines changed: 10 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,14 @@
11
import logging
22
import argparse
3-
import time
43
import yaml
54
import sys
6-
import threading
7-
import asyncio
85

96
from ecmwf import EcmwfApi
107
from bot import PlotBot
118
from logger_config import logger
129
from db import Database
1310

1411

15-
async def await_func(func, *args):
16-
async_func = asyncio.create_task(func(*args))
17-
await async_func
18-
19-
20-
def run_asyncio(func, *args):
21-
asyncio.run(await_func(func, *args))
22-
23-
24-
def run_asyncio_in_thread(func, name, *args):
25-
thread = threading.Thread(target=run_asyncio,
26-
name=name,
27-
daemon=True,
28-
args=[func, *args])
29-
thread.start()
30-
logging.debug(f'Started thread: {name}')
31-
32-
33-
def start_bot(token, station_config, admin_id, db):
34-
bot = PlotBot(token, station_config, admin_id=admin_id, db=db)
35-
run_asyncio_in_thread(bot.connect, 'bot-connect')
36-
return bot
37-
38-
3912
def main():
4013

4114
parser = argparse.ArgumentParser()
@@ -66,39 +39,19 @@ def main():
6639
with open('stations.yaml', 'r') as file:
6740
station_config = yaml.safe_load(file)
6841

42+
ecmwf = EcmwfApi(station_config)
43+
6944
db = Database('config.yml')
7045

71-
bot = start_bot(args.bot_token, station_config, args.admin_id, db)
46+
bot = PlotBot(args.bot_token,
47+
station_config,
48+
admin_id=args.admin_id,
49+
db=db,
50+
ecmwf=ecmwf)
51+
bot.start()
7252

73-
ecmwf = EcmwfApi(station_config)
74-
ecmwf.override_base_time_from_init()
75-
76-
logger.info('Enter infinite loop')
77-
78-
while True:
79-
80-
try:
81-
ecmwf.upgrade_basetime_global()
82-
ecmwf.upgrade_basetime_stations()
83-
if bot.has_new_subscribers_waiting():
84-
run_asyncio_in_thread(
85-
bot.send_plots_to_new_subscribers, 'new-subscribers',
86-
ecmwf.download_plots(bot.stations_of_new_subscribers()))
87-
if bot.has_one_time_forecast_waiting():
88-
run_asyncio_in_thread(
89-
bot.send_one_time_forecast, 'one-time-forecast',
90-
ecmwf.download_plots(bot.stations_of_one_time_request()))
91-
run_asyncio_in_thread(
92-
bot.broadcast, 'broadcast',
93-
ecmwf.download_latest_plots(db.stations_with_subscribers()))
94-
ecmwf.cache_plots()
95-
except Exception as e:
96-
logger.error(f'An error occured: {e}')
97-
sys.exit(1)
98-
99-
snooze = 5
100-
logger.debug(f'snooze {snooze}s ...')
101-
time.sleep(snooze)
53+
# we should not be here
54+
sys.exit(1)
10255

10356

10457
if __name__ == '__main__':

0 commit comments

Comments
 (0)