diff --git a/main.py b/main.py index e1edc51..dcb78af 100644 --- a/main.py +++ b/main.py @@ -28,6 +28,7 @@ from minichain import Transaction, Blockchain, Block, State, Mempool, P2PNetwork, mine_block from minichain.validators import is_valid_receiver +from minichain.chain import MAX_BLOCKS_PER_REQUEST logger = logging.getLogger(__name__) @@ -156,6 +157,51 @@ async def handler(data): else: logger.warning("📥 Received Block #%s — rejected", block.index) + elif msg_type == "status": + import json as _json + peer_height = payload["height"] + my_height = chain.height + + if peer_height > my_height: + writer = data.get("_writer") + if writer: + from_h = my_height + 1 + to_h = min(peer_height, from_h + MAX_BLOCKS_PER_REQUEST - 1) + request = _json.dumps({ + "type": "get_blocks", + "data": {"from_height": from_h, "to_height": to_h}, + }) + "\n" + writer.write(request.encode()) + await writer.drain() + logger.info( + "📡 Requesting blocks %d~%d from %s", + from_h, to_h, peer_addr, + ) + + elif msg_type == "get_blocks": + import json as _json + from_h = payload["from_height"] + to_h = payload["to_height"] + blocks = chain.get_blocks_range(from_h, to_h) + + writer = data.get("_writer") + if writer and blocks: + response = _json.dumps({ + "type": "blocks", + "data": {"blocks": blocks} + }) + "\n" + writer.write(response.encode()) + await writer.drain() + logger.info("📤 Sent %d blocks to %s", len(blocks), peer_addr) + + elif msg_type == "blocks": + received = payload["blocks"] + success, count = chain.add_blocks_bulk(received) + if success: + logger.info("✅ Chain synced: added %d blocks", count) + else: + logger.warning("❌ Chain sync failed — batch rejected") + return handler @@ -318,13 +364,24 @@ async def run_node(port: int, host: str, connect_to: str | None, fund: int, data # When a new peer connects, send our state so they can sync async def on_peer_connected(writer): import json as _json + accounts_snapshot, height_snapshot = chain.snapshot_state_and_height() + sync_msg = _json.dumps({ "type": "sync", - "data": {"accounts": chain.state.accounts} + "data": {"accounts": accounts_snapshot}, + }) + "\n" + status_msg = _json.dumps({ + "type": "status", + "data": {"height": height_snapshot}, }) + "\n" + writer.write(sync_msg.encode()) + writer.write(status_msg.encode()) await writer.drain() - logger.info("🔄 Sent state sync to new peer") + logger.info( + "🔄 Sent state sync (%d accounts) and status (height=%d) to new peer", + len(accounts_snapshot), height_snapshot, + ) network.register_on_peer_connected(on_peer_connected) diff --git a/minichain/chain.py b/minichain/chain.py index b65d575..1e6136d 100644 --- a/minichain/chain.py +++ b/minichain/chain.py @@ -4,6 +4,8 @@ import logging import threading +MAX_BLOCKS_PER_REQUEST = 500 + logger = logging.getLogger(__name__) @@ -54,6 +56,12 @@ def last_block(self): with self._lock: # Acquire lock for thread-safe access return self.chain[-1] + @property + def height(self) -> int: + """Returns the current chain height (genesis = 0)""" + with self._lock: + return len(self.chain) - 1 + def add_block(self, block): """ Validates and adds a block to the chain if all transactions succeed. @@ -82,3 +90,64 @@ def add_block(self, block): self.state = temp_state self.chain.append(block) return True + + def get_blocks_range(self, from_height: int, to_height: int) -> list: + """Return serialized blocks in [from_height, to_height], capped at MAX_BLOCKS_PER_REQUEST.""" + with self._lock: + to_height = min( + to_height, + len(self.chain) - 1, + from_height + MAX_BLOCKS_PER_REQUEST - 1, + ) + if from_height > to_height or from_height < 0: + return [] + return [b.to_dict() for b in self.chain[from_height:to_height + 1]] + + def add_blocks_bulk(self, block_dicts: list) -> tuple: + """ + Atomically add a batch of blocks: validate each block's transactions + against a temporary state, and commit chain + state only if every + block passes. Any failure leaves the local chain and state untouched. + + Returns (True, count) on full success, (False, 0) on any failure. + """ + with self._lock: + temp_state = self.state.copy() + prev_block = self.chain[-1] + new_blocks = [] + + for block_dict in block_dicts: + try: + block = Block.from_dict(block_dict) + except (KeyError, TypeError, ValueError) as exc: + logger.warning("Bulk add rejected: malformed block dict: %s", exc) + return False, 0 + + try: + validate_block_link_and_hash(prev_block, block) + except ValueError as exc: + logger.warning("Bulk add rejected at block %s: %s", block.index, exc) + return False, 0 + + for tx in block.transactions: + if not temp_state.validate_and_apply(tx): + logger.warning( + "Bulk add rejected at block %s: transaction failed validation", + block.index, + ) + return False, 0 + + new_blocks.append(block) + prev_block = block + + self.state = temp_state + self.chain.extend(new_blocks) + return True, len(new_blocks) + + def snapshot_state_and_height(self) -> tuple: + """Capture accounts and chain height under a single lock acquisition.""" + with self._lock: + accounts_copy = { + addr: dict(acc) for addr, acc in self.state.accounts.items() + } + return accounts_copy, len(self.chain) - 1 \ No newline at end of file diff --git a/minichain/p2p.py b/minichain/p2p.py index 3271598..787b12e 100644 --- a/minichain/p2p.py +++ b/minichain/p2p.py @@ -11,11 +11,12 @@ from .serialization import canonical_json_hash from .validators import is_valid_receiver +from .chain import MAX_BLOCKS_PER_REQUEST logger = logging.getLogger(__name__) TOPIC = "minichain-global" -SUPPORTED_MESSAGE_TYPES = {"sync", "tx", "block"} +SUPPORTED_MESSAGE_TYPES = {"sync", "tx", "block", "status", "get_blocks", "blocks"} class P2PNetwork: @@ -207,6 +208,39 @@ def _validate_block_payload(self, payload): for tx_payload in payload["transactions"] ) + def _validate_status_payload(self, payload): + if not isinstance(payload, dict): + return False + if set(payload) != {"height"}: + return False + if not isinstance(payload["height"], int) or payload["height"] < 0: + return False + return True + + def _validate_get_blocks_payload(self, payload): + if not isinstance(payload, dict): + return False + if set(payload) != {"from_height", "to_height"}: + return False + fh, th = payload.get("from_height"), payload.get("to_height") + if not isinstance(fh, int) or not isinstance(th, int): + return False + if fh < 0 or fh > th: + return False + return True + + def _validate_blocks_payload(self, payload): + if not isinstance(payload, dict): + return False + if set(payload) != {"blocks"}: + return False + blocks = payload.get("blocks") + if not isinstance(blocks, list): + return False + if len(blocks) > MAX_BLOCKS_PER_REQUEST: + return False + return all(self._validate_block_payload(b) for b in blocks) + def _validate_message(self, message): if not isinstance(message, dict): return False @@ -226,6 +260,9 @@ def _validate_message(self, message): "sync": self._validate_sync_payload, "tx": self._validate_transaction_payload, "block": self._validate_block_payload, + "status": self._validate_status_payload, + "get_blocks": self._validate_get_blocks_payload, + "blocks": self._validate_blocks_payload, } return validators[msg_type](payload) @@ -283,6 +320,7 @@ async def _listen_to_peer( continue self._mark_seen(msg_type, payload) data["_peer_addr"] = addr + data["_writer"] = writer if self._handler_callback: try: