Skip to content

Commit c4bd19e

Browse files
author
Anthony
committed
fix: move constants outside import block, remove whitespace-only changes
1 parent a660486 commit c4bd19e

2 files changed

Lines changed: 262 additions & 12 deletions

File tree

node/rewards_implementation_rip200.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def jsonify(obj):
8181
except ImportError:
8282
ANTI_DOUBLE_MINING_AVAILABLE = False
8383
print("[WARN] anti_double_mining.py not available - using standard rewards")
84+
# Constants for API responses
85+
RTC_DECIMAL_PRECISION = 8
86+
DATABASE_LOCKED_ERROR_MESSAGE = "Service unavailable due to database issues"
87+
UNEXPECTED_DATABASE_ERROR_MESSAGE = "An unexpected database error occurred"
8488

8589
# Constants
8690
UNIT = 1_000_000 # uRTC per 1 RTC
@@ -270,18 +274,25 @@ def get_balance():
270274
if not miner_id:
271275
return jsonify({"error": "miner_id required"}), 400
272276

273-
with sqlite3.connect(DB_PATH) as db:
274-
row = db.execute(
275-
"SELECT amount_i64 FROM balances WHERE miner_id = ?",
276-
(miner_id,)
277-
).fetchone()
278-
279-
amount_i64 = int(row[0]) if row else 0
280-
return jsonify({
281-
"miner_id": miner_id,
282-
"amount_i64": amount_i64,
283-
"amount_rtc": amount_i64 / UNIT
284-
})
277+
try:
278+
with sqlite3.connect(DB_PATH) as db:
279+
row = db.execute(
280+
"SELECT amount_i64 FROM balances WHERE miner_id = ?",
281+
(miner_id,)
282+
).fetchone()
283+
284+
amount_i64 = int(row[0]) if row else 0
285+
return jsonify({
286+
"miner_id": miner_id,
287+
"amount_i64": amount_i64,
288+
"amount_rtc": round(amount_i64 / UNIT, RTC_DECIMAL_PRECISION)
289+
})
290+
except sqlite3.OperationalError as e:
291+
print(f"Database operational error in get_balance for miner_id {miner_id}: {e}")
292+
return jsonify({"error": DATABASE_LOCKED_ERROR_MESSAGE}), 503
293+
except sqlite3.Error as e:
294+
print(f"Unexpected database error in get_balance for miner_id {miner_id}: {e}")
295+
return jsonify({"error": UNEXPECTED_DATABASE_ERROR_MESSAGE}), 500
285296

286297
@app.route('/wallet/balances/all', methods=['GET'])
287298
def get_all_balances():
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
"""
2+
Comprehensive tests for GET /wallet/balance endpoint (Issue #305).
3+
4+
Tests cover:
5+
- Success cases for existing and zero balances.
6+
- Error handling for missing/invalid miner_id.
7+
- Database operational errors (e.g., locked database).
8+
- General unexpected database errors.
9+
- Correct response format and RTC conversion.
10+
"""
11+
12+
import importlib.util
13+
import os
14+
import sys
15+
import tempfile
16+
import unittest
17+
from unittest.mock import patch, MagicMock
18+
import sqlite3
19+
20+
# Define the path to the node directory and the integrated module.
21+
NODE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
22+
MODULE_PATH = os.path.join(NODE_DIR, "rustchain_v2_integrated_v2.2.1_rip200.py")
23+
24+
# Constants for test scenarios
25+
TEST_DB_PATH = os.path.join(tempfile.gettempdir(), "test_rustchain_balance.db")
26+
MINER_ID_ALICE = "alice"
27+
MINER_ID_BOB = "bob"
28+
MINER_ID_CHARLIE = "charlie"
29+
ALICE_BALANCE_I64 = 150_000_000
30+
BOB_BALANCE_I64 = 0
31+
UNIT = 1_000_000 # uRTC per 1 RTC, from rewards_implementation_rip200.py
32+
RTC_DECIMAL_PRECISION = 8
33+
DATABASE_LOCKED_ERROR_MESSAGE = "Service unavailable due to database issues"
34+
UNEXPECTED_DATABASE_ERROR_MESSAGE = "An unexpected database error occurred"
35+
36+
37+
class TestWalletBalanceEndpoint(unittest.TestCase):
38+
"""Comprehensive tests for the /wallet/balance endpoint."""
39+
40+
@classmethod
41+
def setUpClass(cls):
42+
"""Set up for all tests in this class."""
43+
# Ensure NODE_DIR is in sys.path for module import
44+
if NODE_DIR not in sys.path:
45+
sys.path.insert(0, NODE_DIR)
46+
47+
# Import the module containing the Flask app
48+
spec = importlib.util.spec_from_file_location(
49+
"rustchain_integrated_rewards_test", MODULE_PATH
50+
)
51+
cls.mod = importlib.util.module_from_spec(spec)
52+
spec.loader.exec_module(cls.mod)
53+
54+
# Override DB_PATH within the module for testing purposes
55+
cls.original_db_path = cls.mod.DB_PATH
56+
cls.mod.DB_PATH = TEST_DB_PATH
57+
58+
# Initialize Flask test client
59+
cls.client = cls.mod.app.test_client()
60+
61+
# Create a temporary database for setup and ensure it's clean
62+
cls._init_db()
63+
64+
@classmethod
65+
def tearDownClass(cls):
66+
"""Clean up after all tests in this class."""
67+
# Restore original DB_PATH
68+
cls.mod.DB_PATH = cls.original_db_path
69+
# Clean up temporary database file
70+
if os.path.exists(TEST_DB_PATH):
71+
os.remove(TEST_DB_PATH)
72+
73+
@classmethod
74+
def _init_db(cls):
75+
"""Initialize and populate the test database."""
76+
if os.path.exists(TEST_DB_PATH):
77+
os.remove(TEST_DB_PATH)
78+
79+
conn = sqlite3.connect(TEST_DB_PATH)
80+
cursor = conn.cursor()
81+
cursor.execute(
82+
"""
83+
CREATE TABLE IF NOT EXISTS balances (
84+
miner_id TEXT PRIMARY KEY,
85+
amount_i64 INTEGER NOT NULL
86+
);
87+
"""
88+
)
89+
cursor.execute(
90+
"INSERT INTO balances (miner_id, amount_i64) VALUES (?, ?) ON CONFLICT(miner_id) DO UPDATE SET amount_i64 = excluded.amount_i64;",
91+
(MINER_ID_ALICE, ALICE_BALANCE_I64)
92+
)
93+
conn.commit()
94+
conn.close()
95+
96+
def setUp(self):
97+
"""Reset the database for each test to ensure isolation."""
98+
self._init_db() # Re-initialize the DB before each test
99+
100+
# --- Success Cases ---
101+
102+
def test_get_balance_success_existing_miner(self):
103+
"""Test fetching balance for an existing miner with funds."""
104+
resp = self.client.get(f"/wallet/balance?miner_id={MINER_ID_ALICE}")
105+
self.assertEqual(resp.status_code, 200)
106+
data = resp.get_json()
107+
108+
self.assertIsNotNone(data)
109+
self.assertEqual(data["miner_id"], MINER_ID_ALICE)
110+
self.assertEqual(data["amount_i64"], ALICE_BALANCE_I64)
111+
self.assertAlmostEqual(data["amount_rtc"], round(ALICE_BALANCE_I64 / UNIT, RTC_DECIMAL_PRECISION))
112+
self.assertIsInstance(data["amount_rtc"], float)
113+
114+
def test_get_balance_success_non_existent_miner(self):
115+
"""Test fetching balance for a miner not in the database."""
116+
resp = self.client.get(f"/wallet/balance?miner_id={MINER_ID_BOB}")
117+
self.assertEqual(resp.status_code, 200)
118+
data = resp.get_json()
119+
120+
self.assertIsNotNone(data)
121+
self.assertEqual(data["miner_id"], MINER_ID_BOB)
122+
self.assertEqual(data["amount_i64"], BOB_BALANCE_I64)
123+
self.assertEqual(data["amount_rtc"], 0.0)
124+
125+
# --- Error Cases: miner_id parameter ---
126+
127+
def test_get_balance_missing_miner_id(self):
128+
"""Test request without 'miner_id' parameter."""
129+
resp = self.client.get("/wallet/balance")
130+
self.assertEqual(resp.status_code, 400)
131+
data = resp.get_json()
132+
self.assertEqual(data["error"], "miner_id required")
133+
134+
def test_get_balance_empty_miner_id(self):
135+
"""Test request with an empty 'miner_id' parameter."""
136+
resp = self.client.get("/wallet/balance?miner_id=")
137+
self.assertEqual(resp.status_code, 400)
138+
data = resp.get_json()
139+
self.assertEqual(data["error"], "miner_id required")
140+
141+
# --- Error Cases: Database Issues ---
142+
143+
def test_get_balance_operational_error(self):
144+
"""Test database operational error (e.g., locked DB)."""
145+
with patch.object(self.mod.sqlite3, "connect", side_effect=sqlite3.OperationalError("database is locked")):
146+
resp = self.client.get(f"/wallet/balance?miner_id={MINER_ID_ALICE}")
147+
self.assertEqual(resp.status_code, 503)
148+
data = resp.get_json()
149+
self.assertEqual(data["error"], DATABASE_LOCKED_ERROR_MESSAGE)
150+
151+
def test_get_balance_general_sqlite_error(self):
152+
"""Test a general unexpected sqlite3.Error."""
153+
with patch.object(self.mod.sqlite3, "connect", side_effect=sqlite3.Error("disk I/O error")):
154+
resp = self.client.get(f"/wallet/balance?miner_id={MINER_ID_ALICE}")
155+
self.assertEqual(resp.status_code, 500)
156+
data = resp.get_json()
157+
self.assertEqual(data["error"], UNEXPECTED_DATABASE_ERROR_MESSAGE)
158+
159+
def test_get_balance_operational_error_during_execute(self):
160+
"""Test database operational error during query execution."""
161+
mock_cursor = MagicMock()
162+
mock_cursor.fetchone.side_effect = sqlite3.OperationalError("database table locked")
163+
mock_db = MagicMock()
164+
mock_db.execute.return_value = mock_cursor
165+
mock_db.__enter__.return_value = mock_db
166+
mock_db.__exit__.return_value = None
167+
168+
with patch.object(self.mod.sqlite3, "connect", return_value=mock_db):
169+
resp = self.client.get(f"/wallet/balance?miner_id={MINER_ID_ALICE}")
170+
self.assertEqual(resp.status_code, 503)
171+
data = resp.get_json()
172+
self.assertEqual(data["error"], DATABASE_LOCKED_ERROR_MESSAGE)
173+
mock_db.execute.assert_called_once_with(
174+
"SELECT amount_i64 FROM balances WHERE miner_id = ?",
175+
(MINER_ID_ALICE,)
176+
)
177+
178+
def test_get_balance_general_sqlite_error_during_execute(self):
179+
"""Test a general unexpected sqlite3.Error during query execution."""
180+
mock_cursor = MagicMock()
181+
mock_cursor.fetchone.side_effect = sqlite3.Error("malformed database schema")
182+
mock_db = MagicMock()
183+
mock_db.execute.return_value = mock_cursor
184+
mock_db.__enter__.return_value = mock_db
185+
mock_db.__exit__.return_value = None
186+
187+
with patch.object(self.mod.sqlite3, "connect", return_value=mock_db):
188+
resp = self.client.get(f"/wallet/balance?miner_id={MINER_ID_ALICE}")
189+
self.assertEqual(resp.status_code, 500)
190+
data = resp.get_json()
191+
self.assertEqual(data["error"], UNEXPECTED_DATABASE_ERROR_MESSAGE)
192+
mock_db.execute.assert_called_once_with(
193+
"SELECT amount_i64 FROM balances WHERE miner_id = ?",
194+
(MINER_ID_ALICE,)
195+
)
196+
197+
# --- Response Format Validation ---
198+
199+
def test_get_balance_response_schema(self):
200+
"""Verify the response matches the expected schema."""
201+
resp = self.client.get(f"/wallet/balance?miner_id={MINER_ID_ALICE}")
202+
self.assertEqual(resp.status_code, 200)
203+
data = resp.get_json()
204+
205+
self.assertIn("miner_id", data)
206+
self.assertIn("amount_i64", data)
207+
self.assertIn("amount_rtc", data)
208+
self.assertIsInstance(data["miner_id"], str)
209+
self.assertIsInstance(data["amount_i64"], int)
210+
self.assertIsInstance(data["amount_rtc"], float)
211+
212+
def test_get_balance_rtc_precision(self):
213+
"""Test that amount_rtc is rounded to the specified precision."""
214+
# Assume UNIT and RTC_DECIMAL_PRECISION are accessible from the module or hardcoded for test
215+
balance_i64_complex = 123_456_789
216+
expected_rtc = round(balance_i64_complex / UNIT, RTC_DECIMAL_PRECISION)
217+
218+
conn = sqlite3.connect(TEST_DB_PATH)
219+
cursor = conn.cursor()
220+
cursor.execute(
221+
"INSERT INTO balances (miner_id, amount_i64) VALUES (?, ?) ON CONFLICT(miner_id) DO UPDATE SET amount_i64 = excluded.amount_i64;",
222+
(MINER_ID_CHARLIE, balance_i64_complex)
223+
)
224+
conn.commit()
225+
conn.close()
226+
227+
resp = self.client.get(f"/wallet/balance?miner_id={MINER_ID_CHARLIE}")
228+
self.assertEqual(resp.status_code, 200)
229+
data = resp.get_json()
230+
self.assertAlmostEqual(data["amount_rtc"], expected_rtc)
231+
# Verify the number of decimal places for amount_rtc
232+
rtc_str = str(data["amount_rtc"])
233+
if '.' in rtc_str:
234+
actual_precision = len(rtc_str.split('.')[-1])
235+
self.assertLessEqual(actual_precision, RTC_DECIMAL_PRECISION)
236+
237+
238+
if __name__ == "__main__":
239+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)