diff --git a/main.py b/main.py index fe1c1ce..1f7d1ba 100644 --- a/main.py +++ b/main.py @@ -27,7 +27,7 @@ from minichain import Transaction, Blockchain, Block, State, Mempool, P2PNetwork, mine_block from minichain.rpc import JSONRPCServer -from minichain.validators import is_valid_receiver +from minichain.validators import is_valid_receiver, ValidationStatus from minichain.block import calculate_receipt_root @@ -97,7 +97,7 @@ def mine_and_process_block(chain, mempool, miner_pk): mined_block = mine_block(block) - if chain.add_block(mined_block): + if chain.add_block(mined_block) == ValidationStatus.VALID: logger.info("✅ Block #%d mined and added (%d txs)", mined_block.index, len(mineable_txs)) mempool.remove_transactions(mineable_txs) return mined_block @@ -117,6 +117,7 @@ def mine_and_process_block(chain, mempool, miner_pk): def make_network_handler(chain, mempool, network): """Return an async callback that processes incoming P2P messages.""" + from minichain.validators import ValidationStatus async def handler(data): msg_type = data.get("type") @@ -148,24 +149,30 @@ async def handler(data): elif msg_type == "tx": try: tx = Transaction.from_dict(payload) - if getattr(tx, "chain_id", None) != chain.chain_id: - logger.warning("Invalid chain_id in tx from %s", peer_addr) - return - if mempool.add_transaction(tx): - logger.info("📥 Received tx from %s... (amount=%s)", tx.sender[:8], tx.amount) except Exception as e: logger.warning("Invalid tx payload from %s: %s", peer_addr, e) + return ValidationStatus.MALFORMED + + if getattr(tx, "chain_id", None) != chain.chain_id: + logger.warning("Invalid chain_id in tx from %s", peer_addr) + return ValidationStatus.INVALID + + if mempool.add_transaction(tx): + logger.info("📥 Received tx from %s... (amount=%s)", tx.sender[:8], tx.amount) + return ValidationStatus.VALID + else: + return ValidationStatus.FAILED elif msg_type == "block": try: block = Block.from_dict(payload) except Exception as e: logger.warning("Invalid block payload from %s: %s", peer_addr, e) - return + return ValidationStatus.MALFORMED - if chain.add_block(block): + status = chain.add_block(block) + if status == ValidationStatus.VALID: logger.info("📥 Received Block #%d — added to chain", block.index) - # Drop only confirmed transactions so higher nonces can remain queued. mempool.remove_transactions(block.transactions) else: @@ -178,6 +185,7 @@ async def handler(data): # For a fork, request the full chain to use resolve_conflicts req = {"type": "chain_request", "data": {"start_index": 0, "limit": 1000000}} # Request full chain for reorg asyncio.create_task(network._broadcast_raw(req)) + return status elif msg_type == "chain_request": start_index = payload.get("start_index", 0) @@ -221,7 +229,7 @@ async def handler(data): for block in new_chain: if block.index <= chain.last_block.index: continue # Ignore already known blocks - if chain.add_block(block): + if chain.add_block(block) == ValidationStatus.VALID: logger.info("📥 Synced Block #%d", block.index) mempool.remove_transactions(block.transactions) else: @@ -265,7 +273,7 @@ async def handler(data): """ -async def cli_loop(sk, pk, chain, mempool, network): +async def cli_loop(sk, pk, chain, mempool, network, datadir: str | None = None): """Read commands from stdin asynchronously.""" loop = asyncio.get_event_loop() print(HELP_TEXT) @@ -426,7 +434,7 @@ async def cli_loop(sk, pk, chain, mempool, network): # ── list-banned ── elif cmd == "list-banned": from minichain.persistence import get_banned_peers - banned = get_banned_peers() + banned = get_banned_peers(path=datadir or ".") if not banned: print(" No peers are currently banned.") else: @@ -441,7 +449,8 @@ async def cli_loop(sk, pk, chain, mempool, network): continue peer_id = parts[1] from minichain.persistence import ban_peer - ban_peer(peer_id, reason="Manual ban via CLI") + ban_peer(peer_id, reason="Manual ban via CLI", path=datadir or ".") + asyncio.create_task(network.disconnect_peer(f"peer:{peer_id}")) print(f" ✅ Peer {peer_id} banned.") # ── unban ── @@ -451,7 +460,7 @@ async def cli_loop(sk, pk, chain, mempool, network): continue peer_id = parts[1] from minichain.persistence import unban_peer - unban_peer(peer_id) + unban_peer(peer_id, path=datadir or ".") print(f" ✅ Peer {peer_id} unbanned.") # ── help ── @@ -493,7 +502,7 @@ async def run_node(port: int, host: str, connect_to: str | None, fund: int, data chain = Blockchain() mempool = Mempool() - network = P2PNetwork() + network = P2PNetwork(data_path=datadir or ".") handler = make_network_handler(chain, mempool, network) network.register_handler(handler) @@ -534,7 +543,7 @@ async def on_peer_connected(writer): await network.connect_to_peer(connect_to) try: - await cli_loop(sk, pk, chain, mempool, network) + await cli_loop(sk, pk, chain, mempool, network, datadir) finally: # Save chain to disk on shutdown if datadir: diff --git a/minichain/chain.py b/minichain/chain.py index 1ed9b84..8aacd17 100644 --- a/minichain/chain.py +++ b/minichain/chain.py @@ -253,6 +253,10 @@ def resolve_conflicts(self, new_chain_list) -> tuple[bool, list]: logger.warning("Reorg failed: Invalid receipt root at block %s. Expected %s, got %s", block.index, computed_receipt_root, block.receipt_root) return False, [] + if [r.to_dict() for r in block.receipts] != [r.to_dict() for r in receipts]: + logger.warning("Reorg failed: Receipt payload mismatch at block %s", block.index) + return False, [] + if block.state_root != temp_state.state_root(): logger.warning("Reorg failed: Invalid state root at block %s", block.index) return False, [] diff --git a/minichain/p2p.py b/minichain/p2p.py index 28efe38..4962e5a 100644 --- a/minichain/p2p.py +++ b/minichain/p2p.py @@ -7,6 +7,7 @@ import json import logging import threading +import time import trio import queue @@ -15,16 +16,33 @@ from libp2p.peer.peerinfo import info_from_p2p_addr from multiaddr import Multiaddr from .serialization import canonical_json_hash, canonical_json_dumps +from .validators import ValidationStatus +from .persistence import ban_peer, is_peer_banned logger = logging.getLogger(__name__) SUPPORTED_MESSAGE_TYPES = {"hello", "tx", "block", "chain_request", "chain_response"} PROTOCOL_ID = TProtocol("/minichain/1.0.0") +# Misbehavior thresholds — all four are overridable per P2PNetwork instance. +MALFORMED_THRESHOLD = 15 # N: accumulated malformed messages before ban +FAILED_THRESHOLD = 15 # M: accumulated failed messages before ban +INVALID_THRESHOLD = 1 # L: accumulated invalid messages before ban (1 = immediate) +DECAY_INTERVAL_MINUTES = 10 # T: counter half-life period in minutes + + class P2PNetwork: """Lightweight peer-to-peer networking using libp2p.""" - def __init__(self, handler_callback=None): + def __init__( + self, + handler_callback=None, + data_path: str = ".", + malformed_threshold: int = MALFORMED_THRESHOLD, + failed_threshold: int = FAILED_THRESHOLD, + invalid_threshold: int = INVALID_THRESHOLD, + decay_interval_minutes: float = DECAY_INTERVAL_MINUTES, + ): self._handler_callback = handler_callback self._on_peer_connected = None self._seen_tx_ids = set() @@ -34,6 +52,24 @@ def __init__(self, handler_callback=None): self._peer_count = 0 self._peer_count_lock = threading.Lock() + # Misbehavior tracking + self.data_path = data_path + self.malformed_threshold = malformed_threshold + self.failed_threshold = failed_threshold + self.invalid_threshold = invalid_threshold + self.decay_interval_minutes = decay_interval_minutes + # { peer_id_str -> {"malformed": int, "failed": int, "invalid": int} } + self._peer_counters: dict = {} + + if self.decay_interval_minutes <= 0: + raise ValueError(f"decay_interval_minutes must be positive, got {self.decay_interval_minutes}") + if self.malformed_threshold <= 0: + raise ValueError(f"malformed_threshold must be positive, got {self.malformed_threshold}") + if self.failed_threshold <= 0: + raise ValueError(f"failed_threshold must be positive, got {self.failed_threshold}") + if self.invalid_threshold <= 0: + raise ValueError(f"invalid_threshold must be positive, got {self.invalid_threshold}") + def register_handler(self, handler_callback): self._handler_callback = handler_callback @@ -44,9 +80,10 @@ async def start(self, port: int = 9000, host: str = "127.0.0.1"): self.port = port self.host_addr = host self.loop = asyncio.get_running_loop() - + threading.Thread(target=trio.run, args=(self._trio_main,), daemon=True).start() asyncio.create_task(self._asyncio_reader()) + asyncio.create_task(self._decay_counters()) logger.info(f"Network: Starting libp2p on port {port}") async def stop(self): @@ -101,17 +138,117 @@ def peer_count(self) -> int: with self._peer_count_lock: return self._peer_count + # ── misbehavior helpers ────────────────────────────────────────────────── + + def _increment_counter(self, peer_id: str, category: str) -> bool: + """ + Increment the named counter (malformed/failed/invalid) for peer_id. + Returns True if any counter now meets or exceeds its threshold. + Called only from the asyncio thread — no lock needed. + """ + if peer_id not in self._peer_counters: + self._peer_counters[peer_id] = {"malformed": 0, "failed": 0, "invalid": 0} + self._peer_counters[peer_id][category] += 1 + counts = self._peer_counters[peer_id] + return ( + counts["malformed"] >= self.malformed_threshold + or counts["failed"] >= self.failed_threshold + or counts["invalid"] >= self.invalid_threshold + ) + + async def _handle_validation_status( + self, peer_id: str, peer_addr: str, status: ValidationStatus + ): + """ + Apply misbehavior policy for a single ValidationStatus event: + MALFORMED → always disconnect; ban if counter >= N + FAILED → drop silently; ban + disconnect if counter >= M + INVALID → always ban + disconnect (L=1 means first occurrence triggers) + """ + if status == ValidationStatus.MALFORMED: + await self.disconnect_peer(peer_addr) + if self._increment_counter(peer_id, "malformed"): + ban_peer(peer_id, reason="malformed_threshold_exceeded", path=self.data_path) + logger.warning( + "Banned peer %s: malformed message threshold (%d) exceeded", + peer_id, self.malformed_threshold, + ) + + elif status == ValidationStatus.FAILED: + if self._increment_counter(peer_id, "failed"): + ban_peer(peer_id, reason="failed_threshold_exceeded", path=self.data_path) + await self.disconnect_peer(peer_addr) + logger.warning( + "Banned and disconnected peer %s: failed message threshold (%d) exceeded", + peer_id, self.failed_threshold, + ) + + elif status == ValidationStatus.INVALID: + if self._increment_counter(peer_id, "invalid"): + ban_peer(peer_id, reason="invalid_threshold_exceeded", path=self.data_path) + await self.disconnect_peer(peer_addr) + logger.warning( + "Banned and disconnected peer %s: invalid message threshold (%d) exceeded", + peer_id, self.invalid_threshold, + ) + + async def _decay_counters(self): + """ + Half-life decay: every decay_interval_minutes minutes divide all per-peer + counters by 2 (integer floor division). Runs for the lifetime of the node. + """ + interval_seconds = self.decay_interval_minutes * 60 + while True: + await asyncio.sleep(interval_seconds) + for counts in self._peer_counters.values(): + counts["malformed"] //= 2 + counts["failed"] //= 2 + counts["invalid"] //= 2 + + # ── asyncio reader ─────────────────────────────────────────────────────── + async def _asyncio_reader(self): while True: - try: msg = await self.loop.run_in_executor(None, self._to_asyncio.get) - except Exception: continue - + try: + msg = await self.loop.run_in_executor(None, self._to_asyncio.get) + except Exception: + continue + if msg[0] == "MSG": data = msg[1] - msg_type, payload = data.get("type"), data.get("data") - if msg_type not in SUPPORTED_MESSAGE_TYPES or self._is_duplicate(msg_type, payload): continue - self._mark_seen(msg_type, payload) - if self._handler_callback: await self._handler_callback(data) + msg_type = data.get("type") + payload = data.get("data") + peer_addr = data.get("_peer_addr", "") + peer_id = ( + peer_addr[len("peer:"):] if peer_addr.startswith("peer:") else peer_addr + ) + + if msg_type not in SUPPORTED_MESSAGE_TYPES: + continue + try: + if self._is_duplicate(msg_type, payload): + continue + self._mark_seen(msg_type, payload) + except Exception: + await self._handle_validation_status(peer_id, peer_addr, ValidationStatus.MALFORMED) + continue + + status = None + if self._handler_callback: + status = await self._handler_callback(data) + + # Only apply interception for content-bearing message types. + if msg_type in ("tx", "block") and status is not None: + await self._handle_validation_status(peer_id, peer_addr, status) + + elif msg[0] == "MALFORMED": + # JSON parse failure signalled from the Trio thread. + peer_addr = msg[1] + peer_id = ( + peer_addr[len("peer:"):] if peer_addr.startswith("peer:") else peer_addr + ) + await self._handle_validation_status(peer_id, peer_addr, ValidationStatus.MALFORMED) + elif msg[0] == "PEER_CONNECTED": class MockWriter: def write(self, data): self.data = data @@ -119,44 +256,67 @@ async def drain(self): pass if self._on_peer_connected: writer = MockWriter() await self._on_peer_connected(writer) - if hasattr(writer, 'data'): + if hasattr(writer, "data"): try: req = json.loads(writer.data.decode().strip()) await self._broadcast_raw(req) - except Exception: pass + except Exception: + pass + + # ── trio main ──────────────────────────────────────────────────────────── async def _trio_main(self): host = new_host() listen_addr = Multiaddr(f"/ip4/{self.host_addr}/tcp/{self.port}") await host.get_network().listen(listen_addr) print(f" Network Multiaddr: {listen_addr}/p2p/{host.get_id().to_string()}") - + streams = [] async def stream_handler(stream): + peer_id = str(stream.muxed_conn.peer_id) + addr = f"peer:{peer_id}" + + # Reject banned peers before doing anything else. + if is_peer_banned(peer_id, path=self.data_path): + logger.warning("Rejected connection from banned peer %s", peer_id) + try: + await stream.reset() + except Exception: + pass + return + streams.append(stream) with self._peer_count_lock: self._peer_count += 1 - peer_id = stream.muxed_conn.peer_id - addr = f"peer:{peer_id}" self._to_asyncio.put(("PEER_CONNECTED", None)) + try: + buffer = b"" while True: data = await stream.read(4096) - if not data: break - for line in data.split(b'\n'): - if not line: continue + if not data: + break + buffer += data + *lines, buffer = buffer.split(b"\n") + for line in lines: + if not line.strip(): + continue try: - msg = json.loads(line.decode().strip()) - msg["_peer_addr"] = addr - self._to_asyncio.put(("MSG", msg)) - except Exception: pass - except Exception: pass + parsed = json.loads(line.decode().strip()) + parsed["_peer_addr"] = addr + self._to_asyncio.put(("MSG", parsed)) + except Exception: + # Signal the asyncio side to apply MALFORMED policy. + self._to_asyncio.put(("MALFORMED", addr)) + except Exception: + pass + if stream in streams: streams.remove(stream) with self._peer_count_lock: self._peer_count -= 1 - + host.set_stream_handler(PROTOCOL_ID, stream_handler) async def check_queue(): @@ -164,7 +324,8 @@ async def check_queue(): try: while not self._to_trio.empty(): cmd, arg = self._to_trio.get_nowait() - if cmd == "STOP": return True + if cmd == "STOP": + return True elif cmd == "CONNECT": try: maddr = Multiaddr(arg) @@ -177,27 +338,34 @@ async def check_queue(): elif cmd == "BROADCAST": msg = (canonical_json_dumps(arg) + "\n").encode() for s in list(streams): - try: await s.write(msg) - except Exception: pass + try: + await s.write(msg) + except Exception: + pass elif cmd == "UNICAST": target_addr, payload = arg msg = (canonical_json_dumps(payload) + "\n").encode() for s in list(streams): - addr = f"peer:{s.muxed_conn.peer_id}" - if addr == target_addr: - try: await s.write(msg) - except Exception: pass + s_addr = f"peer:{s.muxed_conn.peer_id}" + if s_addr == target_addr: + try: + await s.write(msg) + except Exception: + pass elif cmd == "DISCONNECT": for s in list(streams): - addr = f"peer:{s.muxed_conn.peer_id}" - if addr == arg: - try: await s.reset() - except Exception: pass + s_addr = f"peer:{s.muxed_conn.peer_id}" + if s_addr == arg: + try: + await s.reset() + except Exception: + pass if s in streams: streams.remove(s) with self._peer_count_lock: self._peer_count -= 1 - except Exception: pass + except Exception: + pass await trio.sleep(0.1) async with trio.open_nursery() as nursery: diff --git a/minichain/state.py b/minichain/state.py index 413fec5..bbe3ad6 100644 --- a/minichain/state.py +++ b/minichain/state.py @@ -91,9 +91,12 @@ def validate_and_apply(self, tx): Validate and apply a transaction. Returns: Receipt|None """ - # Semantic validation: amount must be an integer and non-negative + # Semantic validation: amount and fee must be non-negative integers if not isinstance(tx.amount, int) or tx.amount < 0: return None + fee = getattr(tx, "fee", 0) + if not isinstance(fee, int) or fee < 0: + return None return self.apply_transaction(tx) def validate_and_apply_with_status(self, tx): @@ -104,24 +107,38 @@ def validate_and_apply_with_status(self, tx): from .validators import ValidationStatus if not isinstance(tx.amount, int) or tx.amount < 0: return ValidationStatus.MALFORMED, None - + fee = getattr(tx, "fee", 0) + if not isinstance(fee, int) or fee < 0: + return ValidationStatus.MALFORMED, None + status = self.verify_transaction_logic(tx) if status != ValidationStatus.VALID: return status, None - - # We know it's valid, so apply_transaction will succeed and return a Receipt - return ValidationStatus.VALID, self.apply_transaction(tx) + + # verify_transaction_logic already passed — skip the second call inside apply_transaction. + return ValidationStatus.VALID, self._apply_validated_tx(tx) def apply_transaction(self, tx): """ - Applies transaction and mutates state. - Returns: Receipt object if mathematically valid, None if invalid. + Validates and applies a transaction. + Returns: Receipt object if valid, None if invalid. """ + if not isinstance(tx.amount, int) or tx.amount < 0: + return None + fee = getattr(tx, "fee", 0) + if not isinstance(fee, int) or fee < 0: + return None from .validators import ValidationStatus - status = self.verify_transaction_logic(tx) - if status != ValidationStatus.VALID: + if self.verify_transaction_logic(tx) != ValidationStatus.VALID: return None + return self._apply_validated_tx(tx) + def _apply_validated_tx(self, tx): + """ + Apply a transaction that has already passed verify_transaction_logic. + Mutates state and returns a Receipt. Never call this directly — use + apply_transaction() or validate_and_apply_with_status() instead. + """ sender = self.accounts[tx.sender] total_cost = tx.amount + getattr(tx, 'fee', 0) diff --git a/tests/test_difficulty.py b/tests/test_difficulty.py index 0176f9b..d15c853 100644 --- a/tests/test_difficulty.py +++ b/tests/test_difficulty.py @@ -1,6 +1,7 @@ import unittest from minichain import Blockchain, Block from minichain.pow import mine_block +from minichain.validators import ValidationStatus class TestEMADifficulty(unittest.TestCase): def test_difficulty_adjustment(self): @@ -16,7 +17,7 @@ def test_difficulty_adjustment(self): ts = chain.last_block.timestamp + 1 block1 = Block(index=1, previous_hash=chain.last_block.hash, transactions=[], timestamp=ts, difficulty=chain.current_difficulty, state_root=chain.state.state_root()) mined_block1 = mine_block(block1) - self.assertTrue(chain.add_block(mined_block1)) + self.assertEqual(chain.add_block(mined_block1), ValidationStatus.VALID) self.assertEqual(chain.current_difficulty, 4) # Slow mining: timestamp 5000ms apart @@ -24,7 +25,7 @@ def test_difficulty_adjustment(self): ts = chain.last_block.timestamp + 5000 block2 = Block(index=2, previous_hash=chain.last_block.hash, transactions=[], timestamp=ts, difficulty=chain.current_difficulty, state_root=chain.state.state_root()) mined_block2 = mine_block(block2) - self.assertTrue(chain.add_block(mined_block2)) + self.assertEqual(chain.add_block(mined_block2), ValidationStatus.VALID) self.assertEqual(chain.current_difficulty, 3) def test_reorg_difficulty_validation(self): diff --git a/tests/test_persistence_runtime.py b/tests/test_persistence_runtime.py index 894ccca..c8b2121 100644 --- a/tests/test_persistence_runtime.py +++ b/tests/test_persistence_runtime.py @@ -12,7 +12,7 @@ class FakeNetwork: - def __init__(self): + def __init__(self, **kwargs): self.handler = None self.peer_count = 0 self._on_peer_connected = None @@ -84,7 +84,7 @@ async def test_run_node_loads_existing_sqlite_snapshot(self): chain = self._chain_with_tx() save(chain, self.tmpdir) - async def fake_cli_loop(sk, pk, loaded_chain, mempool, network): + async def fake_cli_loop(sk, pk, loaded_chain, mempool, network, datadir=None): self.assertEqual(len(loaded_chain.chain), len(chain.chain)) self.assertEqual(loaded_chain.last_block.hash, chain.last_block.hash) self.assertEqual(loaded_chain.state.accounts, chain.state.accounts) @@ -103,7 +103,7 @@ async def fake_cli_loop(sk, pk, loaded_chain, mempool, network): async def test_run_node_saves_sqlite_snapshot_on_shutdown(self): fixed_sk, fixed_pk = _make_keypair() - async def fake_cli_loop(sk, pk, chain, mempool, network): + async def fake_cli_loop(sk, pk, chain, mempool, network, datadir=None): self.assertEqual(pk, fixed_pk) self.assertEqual(chain.state.get_account(pk)["balance"], 25)