Skip to content
This repository was archived by the owner on Jul 31, 2019. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion tests/client_tests/unit/test_ledger.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ async def test_update_history(self):
)


class MocHeaderNetwork:
class MocHeaderNetwork(MockNetwork):
def __init__(self, responses):
super().__init__(responses.get('history', ''),
responses.get('transaction', ''))
self.responses = responses

async def get_headers(self, height, blocks):
Expand All @@ -146,6 +148,51 @@ async def test_1_block_reorganization(self):
'height': 21, 'hex': hexlify(self.make_header(block_height=21))
}])

async def test_2_block_reorganization(self):
account = self.ledger.account_class.generate(self.ledger, Wallet(), "torba")
address = await account.receiving.get_or_create_usable_address()
address_details = await self.ledger.db.get_address(address=address)
self.assertEqual(address_details['history'], None)

self.ledger.network = MocHeaderNetwork({
20: {'height': 20, 'count': 5, 'hex': hexlify(
self.get_bytes(after=block_bytes(20), upto=block_bytes(5)),
)},
25: {'height': 25, 'count': 0, 'hex': b''},
'history': [
{'tx_hash': 'abcd01', 'height': 20},
{'tx_hash': 'abcd02', 'height': 25},
],
'transaction': {
'abcd01': hexlify(get_transaction(get_output(1)).raw),
'abcd02': hexlify(get_transaction(get_output(2)).raw),
}
})
headers = self.ledger.headers
await headers.connect(0, self.get_bytes(upto=block_bytes(20)))
self.add_header(block_height=len(headers))
self.assertEqual(headers.height, 20)
await self.ledger.update_history(address, '')
txs_details = await self.ledger.db.get_transactions(account)
self.assertEqual(len(txs_details), 2)
self.assertEqual(txs_details[0].id, get_transaction(get_output(2)).id)
self.assertEqual(txs_details[1].id, get_transaction(get_output(1)).id)
address_details = await self.ledger.db.get_address(address=address)
self.assertEqual(
address_details['history'],
'252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:20:'
'ab9c0654dd484ac20437030f2034e25dcb29fc507e84b91138f80adc3af738f9:25:'
)
await self.ledger.db.rewind_blockchain(21)
txs_details = await self.ledger.db.get_transactions(account)
self.assertEqual(len(txs_details), 1)
self.assertEqual(txs_details[0].id, get_transaction(get_output(1)).id)
address_details = await self.ledger.db.get_address(address=address)
self.assertEqual(
address_details['history'],
'252bda9b22cc902ca2aa2de3548ee8baf06b8501ff7bfb3b0b7d980dbd1bf792:20:'
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test should use self.ledger.receive_header() instead of self.ledger.db.rewind_blockchain()

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By calling rewind_blockchain i ensure that it done work that it should, also calling receive_header does not trigger reorg thus rewind_blockchain is not called.


async def test_3_block_reorganization(self):
self.ledger.network = MocHeaderNetwork({
20: {'height': 20, 'count': 5, 'hex': hexlify(
Expand Down
32 changes: 29 additions & 3 deletions torba/client/basedatabase.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from asyncio import wrap_future
from concurrent.futures.thread import ThreadPoolExecutor

from io import StringIO
from typing import Tuple, List, Union, Callable, Any, Awaitable, Iterable

import sqlite3
Expand Down Expand Up @@ -99,6 +100,8 @@ def constraints_to_sql(constraints, joiner=' AND ', prepend_key=''):
col, op = col[:-len('__lte')], '<='
elif key.endswith('__gt'):
col, op = col[:-len('__gt')], '>'
elif key.endswith('__gte'):
col, op = col[:-len('__gte')], '>='
elif key.endswith('__like'):
col, op = col[:-len('__like')], 'LIKE'
elif key.endswith('__not_like'):
Expand Down Expand Up @@ -218,6 +221,11 @@ def _update_sql(table: str, data: dict, where: str,
)
return sql, values

async def _delete_sql(self, table: str, **constraints):
await self.db.execute(
*query("DELETE FROM {}".format(table), **constraints)
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is wrong, see the other methods. _*_sql is supposed to return SQL not, execute it.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not wrong, other functions need one level abstraction in plus, here it's pointless.


class BaseDatabase(SQLiteMixin):

Expand Down Expand Up @@ -350,11 +358,29 @@ async def reserve_outputs(self, txos, is_reserved=True):
async def release_outputs(self, txos):
await self.reserve_outputs(txos, is_reserved=False)

async def rewind_blockchain(self, above_height): # pylint: disable=no-self-use
# TODO:
async def rewind_blockchain(self, above_height: int):
# 1. delete transactions above_height
txs = await self.select_transactions(
'txid, raw', height__gte=above_height
)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

        txs = await self.select_transactions(
            'txid, raw', height__gte=above_height
        )

if not txs:
return
await self._delete_sql('tx', height__gte=above_height)
# 2. update address histories removing deleted TXs
return True
address_history = await self.select_addresses('address, history')
for tx in txs:
txid = tx[0]
for row in address_history:
if not row[1] or txid not in row[1]:
continue
result = StringIO()
hist = row[1].split(':')[:-1]
for x in range(0, len(hist), 2):
if txid != hist[x] and above_height > int(hist[x+1]):
result.write(f'{hist[x]}:{hist[x+1]}:')
await self.set_address_history(row[0], result.getvalue())
await self._delete_sql('txo', txid__like=txid)
await self._delete_sql('txi', txid__like=txid)

async def select_transactions(self, cols, account=None, **constraints):
if 'txid' not in constraints and account is not None:
Expand Down