55 ConversationHandler , CallbackContext , ContextTypes )
66
77from logger_config import logger
8- from constants import TIMEOUT_IN_SEC , STATION_SELECT_ONE_TIME , STATION_SELECT_SUBSCRIBE , ONE_TIME , SUBSCRIBE , UNSUBSCRIBE , VALID_SUMMARY_INTERVALS
8+ from constants import TIMEOUT_IN_SEC , STATION_SELECT_ONE_TIME , STATION_SELECT_SUBSCRIBE , ONE_TIME , SUBSCRIBE , UNSUBSCRIBE , VALID_SUMMARY_INTERVALS , JOBQUEUE_DELAY , DEFAULT_USER_ID
99
1010
1111class PlotBot :
1212
13- def __init__ (self , token , station_config , db = None , admin_id = None ):
13+ def __init__ (self ,
14+ token ,
15+ station_config ,
16+ db = None ,
17+ admin_id = None ,
18+ ecmwf = None ):
1419
1520 self ._admin_id = admin_id
1621 self .app = Application .builder ().token (token ).build ()
1722 self ._db = db
23+ self ._ecmwf = ecmwf
1824 self ._station_names = sorted (
1925 [station ["name" ] for station in station_config ])
2026 self ._region_of_stations = {
@@ -24,14 +30,6 @@ def __init__(self, token, station_config, db=None, admin_id=None):
2430 self ._station_regions = sorted (
2531 {station ["region" ]
2632 for station in station_config })
27- self ._subscriptions = {
28- station : set ()
29- for station in self ._station_names
30- }
31- self ._one_time_forecast_requests = {
32- station : set ()
33- for station in self ._station_names
34- }
3533 # filter for stations
3634 self ._filter_stations = filters .Regex ("^(" +
3735 "|" .join (self ._station_names ) +
@@ -109,17 +107,53 @@ def __init__(self, token, station_config, db=None, admin_id=None):
109107 self .app .add_handler (one_time_forecast_handler )
110108 self .app .add_error_handler (self ._error )
111109
112- async def connect (self ):
113- await self .app .initialize ()
114- await self .app .updater .start_polling (allowed_updates = Update .ALL_TYPES )
115- await self .app .start ()
116- logger .info ('Bot connected' )
110+ self .app .job_queue .run_once (
111+ self ._override_basetime ,
112+ when = 0 ,
113+ name = 'Override basetime' ,
114+ )
115+ self .app .job_queue .run_repeating (
116+ self ._update_basetime ,
117+ interval = 60 ,
118+ first = 60 ,
119+ name = 'Update basetime' ,
120+ )
121+ self .app .job_queue .run_repeating (
122+ self ._cache_plots ,
123+ interval = 30 ,
124+ first = 30 ,
125+ name = 'Cache plots' ,
126+ )
127+ self .app .job_queue .run_repeating (
128+ self ._broadcast_from_queue ,
129+ interval = 90 ,
130+ first = 60 ,
131+ name = 'Broadcast' ,
132+ )
117133
118- while True :
119- await asyncio .sleep (1 )
134+ async def _override_basetime (self , context : CallbackContext ):
135+ self ._ecmwf .override_base_time_from_init ()
136+
137+ async def _update_basetime (self , context : CallbackContext ):
138+ self ._ecmwf .upgrade_basetime_global ()
139+ self ._ecmwf .upgrade_basetime_stations ()
140+
141+ async def _send_plot_from_queue (self , context : CallbackContext ):
142+ job = context .job
143+ user_id , station_name = job .data
144+ plots = self ._ecmwf .download_plots ([station_name ])
145+ await self ._send_plot_to_user (plots , station_name , user_id )
146+
147+ def start (self ):
148+ logger .info ('Starting bot' )
149+ self .app .run_polling (allowed_updates = Update .ALL_TYPES )
120150
121151 async def _error (self , update : Update , context : CallbackContext ):
122- user_id = update .message .chat_id
152+
153+ if update :
154+ user_id = update .message .chat_id
155+ else :
156+ user_id = DEFAULT_USER_ID
123157 logger .error (f"Exception while handling an update: { context .error } " )
124158 self ._db .log_activity (
125159 activity_type = "bot-error" ,
@@ -303,9 +337,11 @@ async def _subscribe_for_station(self, update: Update,
303337 reply_markup = ReplyKeyboardRemove (),
304338 )
305339 self ._db .add_subscription (msg_text , user .id )
306- self ._subscriptions [msg_text ].add (user .id )
307340
308341 logger .info (f' { user .first_name } subscribed for Station { msg_text } ' )
342+ context .job_queue .run_once (self ._send_plot_from_queue ,
343+ JOBQUEUE_DELAY ,
344+ data = (user .id , msg_text ))
309345
310346 self ._db .log_activity (
311347 activity_type = "subscription" ,
@@ -324,8 +360,10 @@ async def _request_one_time_forecast_for_station(
324360 reply_text ,
325361 reply_markup = ReplyKeyboardRemove (),
326362 )
327- self ._one_time_forecast_requests [msg_text ].add (user .id )
328363
364+ context .job_queue .run_once (self ._send_plot_from_queue ,
365+ JOBQUEUE_DELAY ,
366+ data = (user .id , msg_text ))
329367 logger .info (
330368 f' { user .first_name } requested forecast for Station { msg_text } ' )
331369
@@ -346,24 +384,8 @@ async def _cancel(self, update: Update, context: CallbackContext) -> int:
346384
347385 return ConversationHandler .END
348386
349- def has_new_subscribers_waiting (self ):
350- return any (users for users in self ._subscriptions .values ())
351-
352- def has_one_time_forecast_waiting (self ):
353- return any (users
354- for users in self ._one_time_forecast_requests .values ())
355-
356- def stations_of_one_time_request (self ):
357- return [
358- station
359- for station , users in self ._one_time_forecast_requests .items ()
360- if users
361- ]
362-
363- def stations_of_new_subscribers (self ):
364- return [
365- station for station , users in self ._subscriptions .items () if users
366- ]
387+ async def _cache_plots (self , context : CallbackContext ):
388+ self ._ecmwf .cache_plots ()
367389
368390 async def _send_plot_to_user (self , plots , station_name , user_id ):
369391 logger .debug (f'Send plot to user: { user_id } ' )
@@ -375,30 +397,9 @@ async def _send_plot_to_user(self, plots, station_name, user_id):
375397 except Exception as e :
376398 logger .error (f'Error sending plot to user { user_id } : { e } ' )
377399
378- async def _send_plots (self , plots , requests ):
379- for station_name , users in requests .items ():
380- for user_id in users :
381- await self ._send_plot_to_user (plots , station_name , user_id )
382-
383- async def send_plots_to_new_subscribers (self , plots ):
384- await self ._send_plots (plots , self ._subscriptions )
385- logger .info ('plots sent to new subscribers' )
386-
387- self ._subscriptions = {
388- station : set ()
389- for station in self ._station_names
390- }
391-
392- async def send_one_time_forecast (self , plots ):
393- await self ._send_plots (plots , self ._one_time_forecast_requests )
394- logger .info ('plots sent to one time forecast requests' )
395-
396- self ._one_time_forecast_requests = {
397- station : set ()
398- for station in self ._station_names
399- }
400-
401- async def broadcast (self , plots ):
400+ async def _broadcast_from_queue (self , context : CallbackContext ):
401+ plots = self ._ecmwf .download_latest_plots (
402+ self ._db .stations_with_subscribers ())
402403 if plots :
403404 for station_name in plots :
404405 for user_id in self ._db .get_subscriptions_by_station (
0 commit comments