diff --git a/fila/async_client.py b/fila/async_client.py index 5cb8350..6c4f429 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -15,7 +15,7 @@ AsyncFibpConnection, FibpError, decode_ack_nack_response, - decode_consume_message, + decode_consume_push, decode_enqueue_response, encode_ack, encode_consume, @@ -259,11 +259,12 @@ async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]: except FibpError as e: raise _map_fibp_error(e.code, e.message) from e - return self._consume_iter(q) + return self._consume_iter(q, queue) async def _consume_iter( self, q: object, + queue: str, ) -> AsyncIterator[ConsumeMessage]: import asyncio # q is an asyncio.Queue[bytes | None] @@ -273,23 +274,22 @@ async def _consume_iter( if body is None: return try: - msg_id, queue, headers, payload, fairness_key, attempt_count = ( - decode_consume_message(body) - ) + messages = decode_consume_push(body) except Exception: _log.warning( - "failed to decode consume message; skipping frame", + "failed to decode consume push frame; skipping", exc_info=True, ) continue - yield ConsumeMessage( - id=msg_id, - headers=headers, - payload=payload, - fairness_key=fairness_key, - attempt_count=attempt_count, - queue=queue, - ) + for msg_id, headers, payload, fairness_key, attempt_count in messages: + yield ConsumeMessage( + id=msg_id, + headers=headers, + payload=payload, + fairness_key=fairness_key, + attempt_count=attempt_count, + queue=queue, + ) async def ack(self, queue: str, msg_id: str) -> None: """Acknowledge a successfully processed message. 
diff --git a/fila/batcher.py b/fila/batcher.py index 7d393cf..39a6d82 100644 --- a/fila/batcher.py +++ b/fila/batcher.py @@ -7,7 +7,7 @@ from concurrent.futures import Future, ThreadPoolExecutor from typing import TYPE_CHECKING -from fila.errors import _map_enqueue_error_code +from fila.errors import _map_enqueue_error_code, _map_fibp_error from fila.fibp import ( FibpError, decode_enqueue_response, @@ -59,7 +59,7 @@ def _flush_single( else: item.future.set_exception(_map_enqueue_error_code(err_code, err_msg)) except FibpError as e: - item.future.set_exception(_map_enqueue_error_code(e.code, e.message)) + item.future.set_exception(_map_fibp_error(e.code, e.message)) except Exception as e: item.future.set_exception(e) @@ -99,7 +99,7 @@ def _flush_queue_batch( try: body = conn.send_request(frame, corr_id).result() except FibpError as e: - err = _map_enqueue_error_code(e.code, e.message) + err = _map_fibp_error(e.code, e.message) for item in items: item.future.set_exception(err) return diff --git a/fila/client.py b/fila/client.py index 8997d10..d1bd065 100644 --- a/fila/client.py +++ b/fila/client.py @@ -16,7 +16,7 @@ FibpConnection, FibpError, decode_ack_nack_response, - decode_consume_message, + decode_consume_push, decode_enqueue_response, encode_ack, encode_consume, @@ -296,9 +296,9 @@ def consume(self, queue: str) -> Iterator[ConsumeMessage]: except FibpError as e: raise _map_fibp_error(e.code, e.message) from e - return self._consume_iter(cq) + return self._consume_iter(cq, queue) - def _consume_iter(self, cq: object) -> Iterator[ConsumeMessage]: + def _consume_iter(self, cq: object, queue: str) -> Iterator[ConsumeMessage]: from fila.fibp import _ConsumeQueue assert isinstance(cq, _ConsumeQueue) while True: @@ -306,23 +306,22 @@ def _consume_iter(self, cq: object) -> Iterator[ConsumeMessage]: if body is None: return try: - msg_id, queue, headers, payload, fairness_key, attempt_count = ( - decode_consume_message(body) - ) + messages = decode_consume_push(body) 
except Exception: _log.warning( - "failed to decode consume message; skipping frame", + "failed to decode consume push frame; skipping", exc_info=True, ) continue - yield ConsumeMessage( - id=msg_id, - headers=headers, - payload=payload, - fairness_key=fairness_key, - attempt_count=attempt_count, - queue=queue, - ) + for msg_id, headers, payload, fairness_key, attempt_count in messages: + yield ConsumeMessage( + id=msg_id, + headers=headers, + payload=payload, + fairness_key=fairness_key, + attempt_count=attempt_count, + queue=queue, + ) def ack(self, queue: str, msg_id: str) -> None: """Acknowledge a successfully processed message. diff --git a/fila/errors.py b/fila/errors.py index 4f743d4..0ffa3f2 100644 --- a/fila/errors.py +++ b/fila/errors.py @@ -3,6 +3,7 @@ from __future__ import annotations from fila.fibp import ( + ERR_AUTH_REQUIRED, ERR_INTERNAL, ERR_MESSAGE_NOT_FOUND, ERR_PERMISSION_DENIED, @@ -79,4 +80,6 @@ def _map_fibp_error(code: int, message: str) -> FilaError: return QueueNotFoundError(message) if code == ERR_MESSAGE_NOT_FOUND: return MessageNotFoundError(message) + if code in (ERR_AUTH_REQUIRED, ERR_PERMISSION_DENIED): + return TransportError(code, message) return TransportError(code, message) diff --git a/fila/fibp.py b/fila/fibp.py index f2e4fb9..1804056 100644 --- a/fila/fibp.py +++ b/fila/fibp.py @@ -177,8 +177,12 @@ def encode_nack(corr_id: int, items: list[tuple[str, str, str]]) -> bytes: def encode_auth(corr_id: int, api_key: str) -> bytes: - """Encode an AUTH frame carrying the API key.""" - return _encode_frame(0, OP_AUTH, corr_id, _encode_str(api_key)) + """Encode an AUTH frame carrying the API key. + + The server expects the raw UTF-8 bytes of the key as the payload — + no u16 length prefix (unlike most string fields in this protocol). 
+ """ + return _encode_frame(0, OP_AUTH, corr_id, api_key.encode()) def encode_admin(op: int, corr_id: int, proto_body: bytes) -> bytes: @@ -221,37 +225,47 @@ def decode_enqueue_response(body: bytes) -> list[tuple[bool, str, int, str]]: return results -def decode_consume_message(body: bytes) -> tuple[str, str, dict[str, str], bytes, str, int]: - """Decode a single server-pushed consume frame body. +def decode_consume_push( + body: bytes, +) -> list[tuple[str, dict[str, str], bytes, str, int]]: + """Decode a server-pushed consume frame body (batch format). - Returns ``(msg_id, queue, headers, payload, fairness_key, attempt_count)``. + Returns a list of ``(msg_id, headers, payload, fairness_key, attempt_count)`` + tuples. The queue name is *not* included in the push frame — callers must + supply it from the subscribe context. - The consume push wire format is:: + The server wire format is:: - msg_id_len:u16 | msg_id - queue_len:u16 | queue - fairness_key_len:u16 | fairness_key - attempt_count:u32 - header_count:u8 | (key_len:u16 key val_len:u16 val)... - payload_len:u32 | payload + msg_count:u16 + for each message: + msg_id_len:u16 | msg_id + fairness_key_len:u16 | fairness_key + attempt_count:u32 + header_count:u8 | (key_len:u16 key val_len:u16 val)...
+ payload_len:u32 | payload """ offset = 0 - msg_id, offset = _decode_str(body, offset) - queue, offset = _decode_str(body, offset) - fairness_key, offset = _decode_str(body, offset) - (attempt_count,) = struct.unpack_from(">I", body, offset) - offset += 4 - (header_count,) = struct.unpack_from(">B", body, offset) - offset += 1 - headers: dict[str, str] = {} - for _ in range(header_count): - k, offset = _decode_str(body, offset) - v, offset = _decode_str(body, offset) - headers[k] = v - (payload_len,) = struct.unpack_from(">I", body, offset) - offset += 4 - payload = body[offset: offset + payload_len] - return msg_id, queue, headers, payload, fairness_key, attempt_count + (count,) = struct.unpack_from(">H", body, offset) + offset += 2 + results: list[tuple[str, dict[str, str], bytes, str, int]] = [] + for _ in range(count): + msg_id, offset = _decode_str(body, offset) + fairness_key, offset = _decode_str(body, offset) + (attempt_count,) = struct.unpack_from(">I", body, offset) + offset += 4 + (header_count,) = struct.unpack_from(">B", body, offset) + offset += 1 + headers: dict[str, str] = {} + for _ in range(header_count): + k, offset = _decode_str(body, offset) + v, offset = _decode_str(body, offset) + headers[k] = v + (payload_len,) = struct.unpack_from(">I", body, offset) + offset += 4 + payload = body[offset: offset + payload_len] + offset += payload_len + results.append((msg_id, headers, payload, fairness_key, attempt_count)) + return results def decode_ack_nack_response(body: bytes) -> list[tuple[bool, int, str]]: @@ -276,10 +290,28 @@ def decode_ack_nack_response(body: bytes) -> list[tuple[bool, int, str]]: def decode_error_frame(body: bytes) -> tuple[int, str]: - """Decode a 0xFE ERROR frame body. Returns ``(error_code, message)``.""" - (code,) = struct.unpack_from(">H", body, 0) - msg, _ = _decode_str(body, 2) - return code, msg + """Decode a 0xFE ERROR frame body. Returns ``(error_code, message)``. 
+ + The server encodes error frames as raw UTF-8 message bytes with no code + prefix. This function infers the error code from the message content so + that callers can perform type-safe error handling. + """ + msg = body.decode(errors="replace") + # Infer the error code from well-known message prefixes. + lower = msg.lower() + if "queue" in lower and "not found" in lower: + return ERR_QUEUE_NOT_FOUND, msg + if "message" in lower and "not found" in lower: + return ERR_MESSAGE_NOT_FOUND, msg + if "permission denied" in lower or "does not have" in lower: + return ERR_PERMISSION_DENIED, msg + if ( + "authentication required" in lower + or "invalid or missing api key" in lower + or "auth" in lower + ): + return ERR_AUTH_REQUIRED, msg + return ERR_INTERNAL, msg # ------------------------------------------------------------------ @@ -416,10 +448,18 @@ def send_request(self, frame: bytes, corr_id: int) -> Future[bytes]: return fut def open_consume_stream(self, frame: bytes, corr_id: int) -> _ConsumeQueue: - """Register a consume queue, send *frame*, and return the queue.""" + """Register a consume queue, send *frame*, and return the queue. + + The server sends push frames with correlation_id=0 (FLAG_STREAM set), + so the queue is registered under both the original corr_id (to absorb + the initial stream-accepted ack) and 0 (to receive pushed messages). + Only one consume stream per connection is supported. + """ cq = _ConsumeQueue() with self._lock: self._consume_queues[corr_id] = cq + # Push frames always arrive with corr_id=0. + self._consume_queues[0] = cq with self._send_lock: self._sock.sendall(frame) return cq @@ -487,15 +527,12 @@ def _dispatch(self, flags: int, op: int, corr_id: int, body: bytes) -> None: # Resolve a pending future. with self._lock: fut: Future[bytes] | None = self._pending.pop(corr_id, None) - # Also check if this is the "end of consume stream" signal - # (op == OP_CONSUME response with no push flag). 
- cq = self._consume_queues.get(corr_id) - if cq is not None and op == OP_CONSUME: - # Server closed the consume stream. - cq.close() - with self._lock: - self._consume_queues.pop(corr_id, None) + # A non-push OP_CONSUME frame with an empty body is the server's + # "stream accepted" acknowledgment. The consume queue was already + # registered under corr_id=0 in open_consume_stream, so there is + # nothing to do here — just discard the ack frame. + if op == OP_CONSUME and not body: return if fut is not None and not fut.done(): @@ -612,11 +649,19 @@ async def send_request(self, frame: bytes, corr_id: int) -> bytes: async def open_consume_stream( self, frame: bytes, corr_id: int ) -> asyncio.Queue[bytes | None]: - """Send *frame* and return a queue that receives pushed bodies.""" + """Send *frame* and return a queue that receives pushed bodies. + + The server sends push frames with correlation_id=0 (FLAG_STREAM set), + so the queue is registered under both the original corr_id (to absorb + the initial stream-accepted ack) and 0 (to receive pushed messages). + Only one consume stream per connection is supported. + """ assert self._write_lock is not None assert self._writer is not None q: asyncio.Queue[bytes | None] = asyncio.Queue() self._consume_queues[corr_id] = q + # Push frames always arrive with corr_id=0. + self._consume_queues[0] = q async with self._write_lock: self._writer.write(frame) await self._writer.drain() @@ -655,10 +700,11 @@ def _dispatch(self, flags: int, op: int, corr_id: int, body: bytes) -> None: self._wake_all(FibpError(0, "server sent GOAWAY")) return - # End of consume stream (server sends a non-push CONSUME frame to close). - if op == OP_CONSUME and corr_id in self._consume_queues: - q = self._consume_queues.pop(corr_id) - q.put_nowait(None) + # A non-push OP_CONSUME frame with an empty body is the server's + # "stream accepted" acknowledgment. 
The consume queue was already + # registered under corr_id=0 in open_consume_stream, so there is + # nothing to do here — just discard the ack frame. + if op == OP_CONSUME and not body: return fut = self._pending.pop(corr_id, None) diff --git a/tests/conftest.py b/tests/conftest.py index bf136f1..b2969ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,13 +12,12 @@ from pathlib import Path from typing import TYPE_CHECKING -import grpc import pytest if TYPE_CHECKING: from collections.abc import Generator -from fila.v1 import admin_pb2, admin_pb2_grpc +from fila.v1 import admin_pb2 FILA_SERVER_BIN = os.environ.get( "FILA_SERVER_BIN", @@ -169,93 +168,43 @@ def stop(self) -> None: self._process.wait() shutil.rmtree(self._data_dir, ignore_errors=True) - def _make_grpc_channel(self) -> grpc.Channel: - """Create a gRPC channel to this server (TLS-aware) for admin ops.""" + def create_queue(self, name: str) -> None: + """Create a queue on the test server via FIBP admin op.""" + from fila.fibp import ( + OP_CREATE_QUEUE, + FibpConnection, + encode_admin, + make_ssl_context, + parse_addr, + ) + + host, port = parse_addr(self.addr) + ssl_ctx = None if self.tls_paths is not None: with open(self.tls_paths["ca_cert"], "rb") as f: - ca = f.read() + ca_cert = f.read() with open(self.tls_paths["client_cert"], "rb") as f: - cert = f.read() + client_cert = f.read() with open(self.tls_paths["client_key"], "rb") as f: - key = f.read() - creds = grpc.ssl_channel_credentials( - root_certificates=ca, - private_key=key, - certificate_chain=cert, + client_key = f.read() + ssl_ctx = make_ssl_context( + ca_cert=ca_cert, + client_cert=client_cert, + client_key=client_key, ) - channel = grpc.secure_channel(self.addr, creds) - else: - channel = grpc.insecure_channel(self.addr) - - if self.api_key is not None: - # Inject API key via metadata interceptor for admin calls. 
- channel = grpc.intercept_channel(channel, _GrpcApiKeyInterceptor(self.api_key)) - - return channel - def create_queue(self, name: str) -> None: - """Create a queue on the test server via admin gRPC.""" - channel = self._make_grpc_channel() - stub = admin_pb2_grpc.FilaAdminStub(channel) - stub.CreateQueue( - admin_pb2.CreateQueueRequest( + conn = FibpConnection(host, port, ssl_ctx=ssl_ctx, api_key=self.api_key) + try: + corr_id = conn.alloc_corr_id() + proto_body = admin_pb2.CreateQueueRequest( name=name, config=admin_pb2.QueueConfig(), - ) - ) - channel.close() - - -class _GrpcClientCallDetails(grpc.ClientCallDetails): # type: ignore[misc] - """Minimal concrete ClientCallDetails for the API key interceptor.""" - - def __init__( - self, - method: str, - timeout: float | None, - metadata: list[tuple[str, str | bytes]] | None, - credentials: grpc.CallCredentials | None, - wait_for_ready: bool | None, - compression: grpc.Compression | None, - ) -> None: - self.method = method - self.timeout = timeout - self.metadata = metadata - self.credentials = credentials - self.wait_for_ready = wait_for_ready - self.compression = compression - - -class _GrpcApiKeyInterceptor( - grpc.UnaryUnaryClientInterceptor, # type: ignore[misc] - grpc.UnaryStreamClientInterceptor, # type: ignore[misc] -): - """Injects authorization metadata into gRPC admin calls (test fixture only).""" - - def __init__(self, api_key: str) -> None: - self._metadata = (("authorization", f"Bearer {api_key}"),) - - def _inject(self, details: grpc.ClientCallDetails) -> _GrpcClientCallDetails: - metadata = list(details.metadata or []) - metadata.extend(self._metadata) - return _GrpcClientCallDetails( - details.method, - details.timeout, - metadata, - details.credentials, - details.wait_for_ready, - details.compression, - ) - - def intercept_unary_unary( # type: ignore[override] - self, continuation: object, details: grpc.ClientCallDetails, request: object - ) -> object: - return continuation(self._inject(details), 
request) # type: ignore[call-arg] - - def intercept_unary_stream( # type: ignore[override] - self, continuation: object, details: grpc.ClientCallDetails, request: object - ) -> object: - return continuation(self._inject(details), request) # type: ignore[call-arg] + ).SerializeToString() + frame = encode_admin(OP_CREATE_QUEUE, corr_id, proto_body) + fut = conn.send_request(frame, corr_id) + fut.result(timeout=10.0) + finally: + conn.close() @pytest.fixture() @@ -318,9 +267,9 @@ def tls_server() -> Generator[TestServer, None, None]: f'listen_addr = "{addr}"\n' f'\n' f'[tls]\n' - f'ca_cert = "{tls_paths["ca_cert"]}"\n' - f'server_cert = "{tls_paths["server_cert"]}"\n' - f'server_key = "{tls_paths["server_key"]}"\n' + f'ca_file = "{tls_paths["ca_cert"]}"\n' + f'cert_file = "{tls_paths["server_cert"]}"\n' + f'key_file = "{tls_paths["server_key"]}"\n' ) env = {**os.environ, "FILA_DATA_DIR": os.path.join(data_dir, "db")} @@ -372,6 +321,8 @@ def auth_server() -> Generator[TestServer, None, None]: f.write( f'[fibp]\n' f'listen_addr = "{addr}"\n' + f'\n' + f'[auth]\n' f'bootstrap_apikey = "{bootstrap_key}"\n' ) diff --git a/tests/test_batcher.py b/tests/test_batcher.py index d8adfa2..fb5c39b 100644 --- a/tests/test_batcher.py +++ b/tests/test_batcher.py @@ -19,7 +19,7 @@ _flush_many, _flush_single, ) -from fila.errors import EnqueueError, QueueNotFoundError +from fila.errors import QueueNotFoundError, TransportError from fila.fibp import ( ERR_QUEUE_NOT_FOUND, FibpError, @@ -149,7 +149,7 @@ def test_transport_failure_sets_all_futures(self) -> None: _flush_many(conn, items) for item in items: - with pytest.raises(EnqueueError): + with pytest.raises(TransportError): item.future.result(timeout=1.0) def test_multi_queue_batch_sends_per_queue_frames(self) -> None: