diff --git a/tests/client_tests/unit/test_ledger.py b/tests/client_tests/unit/test_ledger.py index 0e077b4..1c4cecb 100644 --- a/tests/client_tests/unit/test_ledger.py +++ b/tests/client_tests/unit/test_ledger.py @@ -121,8 +121,10 @@ async def test_update_history(self): ) -class MocHeaderNetwork: +class MocHeaderNetwork(MockNetwork): def __init__(self, responses): + super().__init__(responses.get('history', ''), + responses.get('transaction', '')) self.responses = responses async def get_headers(self, height, blocks): @@ -146,6 +148,51 @@ async def test_1_block_reorganization(self): 'height': 21, 'hex': hexlify(self.make_header(block_height=21)) }]) + async def test_2_block_reorganization(self): + account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba") + address = await account.receiving.get_or_create_usable_address() + address_details = await self.ledger.db.get_address(address=address) + self.assertEqual(address_details['history'], None) + + self.ledger.network = MocHeaderNetwork({ + 20: {'height': 20, 'count': 5, 'hex': hexlify( + self.get_bytes(after=block_bytes(20), upto=block_bytes(5)), + )}, + 25: {'height': 25, 'count': 0, 'hex': b''}, + 'history': [ + {'tx_hash': 'abcd01', 'height': 20}, + {'tx_hash': 'abcd02', 'height': 25}, + ], + 'transaction': { + 'abcd01': hexlify(get_transaction(get_output(1)).raw), + 'abcd02': hexlify(get_transaction(get_output(2)).raw), + } + }) + headers = self.ledger.headers + await headers.connect(0, self.get_bytes(upto=block_bytes(20))) + self.add_header(block_height=len(headers)) + self.assertEqual(headers.height, 20) + await self.ledger.update_history(address, '') + txs_details = await self.ledger.db.get_transactions(account) + self.assertEqual(len(txs_details), 2) + self.assertEqual(txs_details[0].id, get_transaction(get_output(2)).id) + self.assertEqual(txs_details[1].id, get_transaction(get_output(1)).id) + address_details = await self.ledger.db.get_address(address=address) + self.assertEqual( + address_details['history'], + '252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:20:' + 'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9:25:' + ) + await self.ledger.db.rewind_blockchain(21) + txs_details = await self.ledger.db.get_transactions(account) + self.assertEqual(len(txs_details), 1) + self.assertEqual(txs_details[0].id, get_transaction(get_output(1)).id) + address_details = await self.ledger.db.get_address(address=address) + self.assertEqual( + address_details['history'], + '252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:20:' + ) + async def test_3_block_reorganization(self): self.ledger.network = MocHeaderNetwork({ 20: {'height': 20, 'count': 5, 'hex': hexlify( diff --git a/torba/client/basedatabase.py b/torba/client/basedatabase.py index d161c81..d8facf7 100644 --- a/torba/client/basedatabase.py +++ b/torba/client/basedatabase.py @@ -3,6 +3,7 @@ from asyncio import wrap_future from concurrent.futures.thread import ThreadPoolExecutor +from io import StringIO from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable import sqlite3 @@ -99,6 +100,8 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''): col, op = col[:-len('__lte')], '<=' elif key.endswith('__gt'): col, op = col[:-len('__gt')], '>' + elif key.endswith('__gte'): + col, op = col[:-len('__gte')], '>=' elif key.endswith('__like'): col, op = col[:-len('__like')], 'LIKE' elif key.endswith('__not_like'): @@ -218,6 +221,11 @@ def _update_sql(table: str, data: dict, where: str, ) return sql, values + async def _delete_sql(self, table: str, **constraints): + await self.db.execute( + *query("DELETE FROM {}".format(table), **constraints) + ) + class BaseDatabase(SQLiteMixin): @@ -350,11 +358,29 @@ async def reserve_outputs(self, txos, is_reserved=True): async def release_outputs(self, txos): await self.reserve_outputs(txos, is_reserved=False) - async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use - # TODO: + async def rewind_blockchain(self, above_height: int): # 1. delete transactions above_height + txs = await self.select_transactions( + 'txid, raw', height__gte=above_height + ) + if not txs: + return + await self._delete_sql('tx', height__gte=above_height) # 2. update address histories removing deleted TXs - return True + address_history = await self.select_addresses('address, history') + for tx in txs: + txid = tx[0] + for row in address_history: + if not row[1] or txid not in row[1]: + continue + result = StringIO() + hist = row[1].split(':')[:-1] + for x in range(0, len(hist), 2): + if txid != hist[x] and above_height > int(hist[x+1]): + result.write(f'{hist[x]}:{hist[x+1]}:') + await self.set_address_history(row[0], result.getvalue()) + await self._delete_sql('txo', txid__like=txid) + await self._delete_sql('txi', txid__like=txid) async def select_transactions(self, cols, account=None, **constraints): if 'txid' not in constraints and account is not None: