77from dataclasses import dataclass , field
88from queue import Queue , Empty
99from typing import List , Mapping , Callable , Union
10- import types
1110import trio
11+ import logging
1212from trio_websocket import (
1313 open_websocket_url ,
1414 ConnectionClosed ,
2424 SubscriptionType ,
2525 to_camel_case ,
2626)
27+ from blocknative import __version__ as API_VERSION
2728
2829PING_INTERVAL = 15
2930PING_TIMEOUT = 10
3031MESSAGE_SEND_INTERVAL = 0.021 # 21ms
3132
32- Callback = Callable [[dict , Callable ], None ]
33+ BN_BASE_URL = 'wss://api.blocknative.com/v0'
34+ BN_ETHEREUM = 'ethereum'
3335
36+ Callback = Callable [[dict , Callable ], None ]
3437
3538@dataclass
3639class Subscription :
@@ -85,18 +88,12 @@ def as_dict(self) -> dict:
8588class Stream :
8689 """Stream class used to connect to Blocknative's WebSocket API."""
8790
88- # - Public instance variables -
89-
9091 api_key : str
91- base_url : str = 'wss://api.blocknative.com/v0'
92- blockchain : str = 'ethereum'
92+ blockchain : str = BN_ETHEREUM
9393 network_id : int = 1
9494 version : str = '1'
9595 global_filters : List [dict ] = None
9696 valid_session : bool = True
97-
98- # - Private instance variables -
99-
10097 _ws : WebSocketConnection = field (default = None , init = False )
10198 _message_queue : Queue = field (default = Queue (), init = False )
10299
@@ -128,7 +125,7 @@ async def txn_handler(txn)
128125 stream.subscribe('0x7a250d5630b4cf539739df2c5dacb4c659f2488d', txn_handler)
129126 """
130127
131- if self .blockchain == 'ethereum' :
128+ if self .blockchain == BN_ETHEREUM :
132129 address = address .lower ()
133130
134131 # Add this subscription to the registry
@@ -159,10 +156,10 @@ def subscribe_txn(self, tx_hash: str, callback: Callback, status: str = 'sent'):
159156 if self ._is_connected ():
160157 self ._send_txn_watch_message (tx_hash , status )
161158
162- def connect (self ):
159+ def connect (self , base_url : str = BN_BASE_URL ):
163160 """Initializes the connection to the WebSocket server."""
164161 try :
165- return trio .run (self ._connect )
162+ return trio .run (self ._connect , base_url )
166163 except KeyboardInterrupt :
167164 print ('keyboard interrupt' )
168165 return None
@@ -173,6 +170,7 @@ def send_message(self, message: str):
173170 Args:
174171 message: The message to send.
175172 """
173+ logging .debug ('Sending: {}' % message )
176174 self ._message_queue .put (message )
177175
178176 async def _message_dispatcher (self ):
@@ -220,34 +218,29 @@ async def _message_handler(self, message: dict):
220218 # Raises an exception if the status of the message is an error
221219 raise_error_on_status (message )
222220
223- if 'event' in message and 'transaction' in message ['event' ]:
221+ if 'event' in message :
222+ event = message ['event' ]
224223 # Ignore server echo and unsubscribe messages
225- if is_server_echo (message [ ' event' ] ['eventCode' ]):
224+ if is_server_echo (event ['eventCode' ]):
226225 return
227226
228- # Checks if the messsage is for a transaction subscription
229- if subscription_type (message ) == SubscriptionType .TRANSACTION :
230-
231- # Find the matching subscription and run it's callback
232- if (
233- message ['event' ]['transaction' ]['hash' ]
234- in self ._subscription_registry
235- ):
236- await self ._subscription_registry [
237- message ['event' ]['transaction' ]['hash' ]
238- ].callback (message ['event' ]['transaction' ])
239-
240- # Checks if the messsage is for an address subscription
241- elif subscription_type (message ) == SubscriptionType .ADDRESS :
242- watched_address = message ['event' ]['transaction' ]['watchedAddress' ]
243- if watched_address in self ._subscription_registry and watched_address is not None :
227+ if 'transaction' in event :
228+ event_transaction = event ['transaction' ]
229+ # Checks if the messsage is for a transaction subscription
230+ if subscription_type (message ) == SubscriptionType .TRANSACTION :
244231 # Find the matching subscription and run it's callback
245- if 'transaction' in message ['event' ]:
246- transaction = message ['event' ]['transaction' ]
247- await self ._subscription_registry [watched_address ].callback (
248- transaction ,
249- (lambda : self .unsubscribe (watched_address )),
250- )
232+ transaction_hash = event_transaction ['hash' ]
233+ if transaction_hash in self ._subscription_registry :
234+ transaction = self ._flatten_event_to_transaction (event )
235+ await self ._subscription_registry [transaction_hash ].callback (transaction )
236+
237+ # Checks if the messsage is for an address subscription
238+ elif subscription_type (message ) == SubscriptionType .ADDRESS :
239+ watched_address = event_transaction ['watchedAddress' ]
240+ if watched_address in self ._subscription_registry and watched_address is not None :
241+ # Find the matching subscription and run it's callback
242+ transaction = self ._flatten_event_to_transaction (event )
243+ await self ._subscription_registry [watched_address ].callback (transaction ,(lambda : self .unsubscribe (watched_address )))
251244
252245 def unsubscribe (self , watched_address ):
253246 # remove this subscription from the registry so that we don't execute the callback
@@ -284,7 +277,7 @@ async def _heartbeat(self):
284277 await self ._ws .ping ()
285278 await trio .sleep (PING_INTERVAL )
286279
287- async def _handle_connection (self ):
280+ async def _handle_connection (self , base_url : str ):
288281 """Handles the setup once the websocket connection is established, as well as,
289282 handles reconnect if the websocket closes for any reason.
290283
@@ -315,14 +308,16 @@ async def _handle_connection(self):
315308 nursery .start_soon (self ._message_dispatcher )
316309 except ConnectionClosed as cc :
317310 # If server times the connection out or drops, reconnect
318- await self ._connect ()
311+ await trio .sleep (0.5 )
312+ await self ._connect (base_url )
319313
320- async def _connect (self ):
314+ async def _connect (self , base_url ):
321315 try :
322- async with open_websocket_url (self . base_url ) as ws :
316+ async with open_websocket_url (base_url ) as ws :
323317 self ._ws = ws
324- await self ._handle_connection ()
318+ await self ._handle_connection (base_url )
325319 except HandshakeError as e :
320+ logging .exception ('Handshake failed' )
326321 return False
327322
328323 def _is_connected (self ) -> bool :
@@ -398,7 +393,7 @@ def _build_payload(
398393 return {
399394 'timeStamp' : datetime .now ().isoformat (),
400395 'dappId' : self .api_key ,
401- 'version' : self . version ,
396+ 'version' : API_VERSION ,
402397 'blockchain' : {
403398 'system' : self .blockchain ,
404399 'network' : network_id_to_name (self .network_id ),
@@ -413,3 +408,25 @@ def _queue_init_message(self):
413408 self .send_message (
414409 self ._build_payload (category_code = 'initialize' , event_code = 'checkDappId' )
415410 )
411+
412+ def _flatten_event_to_transaction (self , event :dict ):
413+ transaction = {}
414+ eventcopy = dict (event )
415+ del eventcopy ['dappId' ]
416+ if 'transaction' in eventcopy :
417+ txn = eventcopy ['transaction' ]
418+ for k in txn .keys ():
419+ transaction [k ] = txn [k ]
420+ del eventcopy ['transaction' ]
421+ if 'blockchain' in eventcopy :
422+ bc = eventcopy ['blockchain' ]
423+ for k in bc .keys ():
424+ transaction [k ] = bc [k ]
425+ del eventcopy ['blockchain' ]
426+ if 'contractCall' in eventcopy :
427+ transaction ['contractCall' ] = eventcopy ['contractCall' ]
428+ del eventcopy ['contractCall' ]
429+ for k in eventcopy :
430+ if not isinstance (k , dict ) and not isinstance (k , list ):
431+ transaction [k ] = eventcopy [k ]
432+ return transaction
0 commit comments