From f8e343d0ffd2f1f2c075b7d4bf223c0a4676016e Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Sat, 21 Mar 2026 11:32:38 -0300 Subject: [PATCH 01/17] fix: remove local version identifier from dev publish --- .github/workflows/dev-publish.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/dev-publish.yml b/.github/workflows/dev-publish.yml index 78078d0..2602016 100644 --- a/.github/workflows/dev-publish.yml +++ b/.github/workflows/dev-publish.yml @@ -20,7 +20,7 @@ jobs: run: | COMMIT_COUNT=$(git rev-list --count HEAD) SHORT_SHA=$(git rev-parse --short HEAD) - DEV_VERSION="0.1.dev${COMMIT_COUNT}+g${SHORT_SHA}" + DEV_VERSION="0.1.dev${COMMIT_COUNT}" sed -i "s/^version = .*/version = \"${DEV_VERSION}\"/" pyproject.toml echo "Publishing version: ${DEV_VERSION}" - run: pip install build From 8d6555219693f1f5afe92e86f969a3edbd407402 Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Sat, 21 Mar 2026 11:37:42 -0300 Subject: [PATCH 02/17] fix: include commit sha in dev build package description --- .github/workflows/dev-publish.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/dev-publish.yml b/.github/workflows/dev-publish.yml index 2602016..e23d281 100644 --- a/.github/workflows/dev-publish.yml +++ b/.github/workflows/dev-publish.yml @@ -22,7 +22,8 @@ jobs: SHORT_SHA=$(git rev-parse --short HEAD) DEV_VERSION="0.1.dev${COMMIT_COUNT}" sed -i "s/^version = .*/version = \"${DEV_VERSION}\"/" pyproject.toml - echo "Publishing version: ${DEV_VERSION}" + sed -i "s/^description = .*/description = \"Python client SDK for the Fila message broker (dev build from ${SHORT_SHA})\"/" pyproject.toml + echo "Publishing version: ${DEV_VERSION} (${SHORT_SHA})" - run: pip install build - run: python -m build - uses: pypa/gh-action-pypi-publish@release/v1 From f7a4ed96d9203a05f3ee5c20a120fe0a6a1d94a2 Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Sun, 22 Mar 2026 13:03:00 -0300 Subject: [PATCH 03/17] feat: 
transparent leader hint reconnect on consume --- fila/async_client.py | 72 +++++++++++++++++++++++++++++++++++--------- fila/client.py | 72 +++++++++++++++++++++++++++++++++++++------- 2 files changed, 119 insertions(+), 25 deletions(-) diff --git a/fila/async_client.py b/fila/async_client.py index 9c99b50..08a9a30 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -99,6 +99,22 @@ async def intercept_unary_stream( return await continuation(new_details, request) +_LEADER_HINT_KEY = "x-fila-leader-addr" + + +def _extract_leader_hint(err: grpc.RpcError) -> str | None: + """Return the leader address from trailing metadata, if present.""" + if err.code() != grpc.StatusCode.UNAVAILABLE: + return None + trailing = err.trailing_metadata() + if trailing is None: + return None + for key, value in trailing: + if key == _LEADER_HINT_KEY: + return value + return None + + class AsyncClient: """Asynchronous client for the Fila message broker. @@ -162,32 +178,41 @@ def __init__( api_key: API key for authentication. When set, every RPC includes an ``authorization: Bearer `` metadata header. 
""" - use_tls = tls or ca_cert is not None + self._tls = tls + self._ca_cert = ca_cert + self._client_cert = client_cert + self._client_key = client_key + self._api_key = api_key + use_tls = tls or ca_cert is not None if (client_cert is not None or client_key is not None) and not use_tls: raise ValueError( "client_cert and client_key require ca_cert or tls=True to establish a TLS channel" ) + self._channel = self._make_channel(addr) + self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + + def _make_channel(self, addr: str) -> grpc.aio.Channel: + """Create an async gRPC channel to the given address using stored credentials.""" + use_tls = self._tls or self._ca_cert is not None + interceptors: list[grpc.aio.ClientInterceptor] = [] - if api_key is not None: - interceptors.append(_AsyncApiKeyInterceptor(api_key)) + if self._api_key is not None: + interceptors.append(_AsyncApiKeyInterceptor(self._api_key)) if use_tls: creds = grpc.ssl_channel_credentials( - root_certificates=ca_cert, - private_key=client_key, - certificate_chain=client_cert, + root_certificates=self._ca_cert, + private_key=self._client_key, + certificate_chain=self._client_cert, ) - self._channel = grpc.aio.secure_channel( + return grpc.aio.secure_channel( addr, creds, interceptors=interceptors or None ) - else: - self._channel = grpc.aio.insecure_channel( - addr, interceptors=interceptors or None - ) - - self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + return grpc.aio.insecure_channel( + addr, interceptors=interceptors or None + ) async def close(self) -> None: """Close the underlying gRPC channel.""" @@ -238,6 +263,10 @@ async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]: server stream closes or an error occurs. Nil message frames (keepalive signals) are skipped automatically. 
+ If the server returns UNAVAILABLE with an ``x-fila-leader-addr`` + trailing metadata entry, the client transparently reconnects to the + leader address and retries the consume call once. + Args: queue: Queue to consume from. @@ -253,10 +282,25 @@ async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]: service_pb2.ConsumeRequest(queue=queue) ) except grpc.RpcError as e: - raise _map_consume_error(e) from e + leader_addr = _extract_leader_hint(e) + if leader_addr is not None: + stream = await self._reconnect_and_consume(leader_addr, queue) + else: + raise _map_consume_error(e) from e return self._consume_iter(stream) + async def _reconnect_and_consume(self, leader_addr: str, queue: str) -> Any: + """Create a new channel to *leader_addr* and retry the consume call.""" + self._channel = self._make_channel(leader_addr) + self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + try: + return self._stub.Consume( + service_pb2.ConsumeRequest(queue=queue) + ) + except grpc.RpcError as e: + raise _map_consume_error(e) from e + async def _consume_iter( self, stream: Any, diff --git a/fila/client.py b/fila/client.py index 531c051..0aec27f 100644 --- a/fila/client.py +++ b/fila/client.py @@ -13,6 +13,25 @@ if TYPE_CHECKING: from collections.abc import Iterator +_LEADER_HINT_KEY = "x-fila-leader-addr" + + +def _extract_leader_hint(err: grpc.RpcError) -> str | None: + """Return the leader address from trailing metadata, if present. + + The server sets ``x-fila-leader-addr`` in trailing metadata alongside an + UNAVAILABLE status when the node is not the leader for the requested queue. 
+ """ + if err.code() != grpc.StatusCode.UNAVAILABLE: + return None + trailing = err.trailing_metadata() + if trailing is None: + return None + for key, value in trailing: + if key == _LEADER_HINT_KEY: + return value + return None + class _ClientCallDetails( grpc.ClientCallDetails, # type: ignore[misc] @@ -143,28 +162,40 @@ def __init__( api_key: API key for authentication. When set, every RPC includes an ``authorization: Bearer `` metadata header. """ - use_tls = tls or ca_cert is not None + self._tls = tls + self._ca_cert = ca_cert + self._client_cert = client_cert + self._client_key = client_key + self._api_key = api_key + use_tls = tls or ca_cert is not None if (client_cert is not None or client_key is not None) and not use_tls: raise ValueError( "client_cert and client_key require ca_cert or tls=True to establish a TLS channel" ) + self._channel = self._make_channel(addr) + self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + + def _make_channel(self, addr: str) -> grpc.Channel: + """Create a gRPC channel to the given address using stored credentials.""" + use_tls = self._tls or self._ca_cert is not None + if use_tls: creds = grpc.ssl_channel_credentials( - root_certificates=ca_cert, - private_key=client_key, - certificate_chain=client_cert, + root_certificates=self._ca_cert, + private_key=self._client_key, + certificate_chain=self._client_cert, ) - self._channel = grpc.secure_channel(addr, creds) + channel: grpc.Channel = grpc.secure_channel(addr, creds) else: - self._channel = grpc.insecure_channel(addr) + channel = grpc.insecure_channel(addr) - if api_key is not None: - interceptor = _ApiKeyInterceptor(api_key) - self._channel = grpc.intercept_channel(self._channel, interceptor) + if self._api_key is not None: + interceptor = _ApiKeyInterceptor(self._api_key) + channel = grpc.intercept_channel(channel, interceptor) - self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + 
return channel def close(self) -> None: """Close the underlying gRPC channel.""" @@ -215,6 +246,10 @@ def consume(self, queue: str) -> Iterator[ConsumeMessage]: server stream closes or an error occurs. Skip nil message frames (keepalive signals) automatically. + If the server returns UNAVAILABLE with an ``x-fila-leader-addr`` + trailing metadata entry, the client transparently reconnects to the + leader address and retries the consume call once. + Args: queue: Queue to consume from. @@ -230,10 +265,25 @@ def consume(self, queue: str) -> Iterator[ConsumeMessage]: service_pb2.ConsumeRequest(queue=queue) ) except grpc.RpcError as e: - raise _map_consume_error(e) from e + leader_addr = _extract_leader_hint(e) + if leader_addr is not None: + stream = self._reconnect_and_consume(leader_addr, queue) + else: + raise _map_consume_error(e) from e return self._consume_iter(stream) + def _reconnect_and_consume(self, leader_addr: str, queue: str) -> Any: + """Create a new channel to *leader_addr* and retry the consume call.""" + self._channel = self._make_channel(leader_addr) + self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + try: + return self._stub.Consume( + service_pb2.ConsumeRequest(queue=queue) + ) + except grpc.RpcError as e: + raise _map_consume_error(e) from e + def _consume_iter( self, stream: Any, From 1f5753e1d5dd4d31768fd2c8b86a3694766b48d2 Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Mon, 23 Mar 2026 22:54:05 -0300 Subject: [PATCH 04/17] fix: close old grpc channel before reconnecting --- fila/async_client.py | 1 + fila/client.py | 1 + 2 files changed, 2 insertions(+) diff --git a/fila/async_client.py b/fila/async_client.py index 08a9a30..27f0bf6 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -292,6 +292,7 @@ async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]: async def _reconnect_and_consume(self, leader_addr: str, queue: str) -> Any: """Create a new channel to *leader_addr* and 
retry the consume call.""" + await self._channel.close() self._channel = self._make_channel(leader_addr) self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] try: diff --git a/fila/client.py b/fila/client.py index 0aec27f..6e91247 100644 --- a/fila/client.py +++ b/fila/client.py @@ -275,6 +275,7 @@ def consume(self, queue: str) -> Iterator[ConsumeMessage]: def _reconnect_and_consume(self, leader_addr: str, queue: str) -> Any: """Create a new channel to *leader_addr* and retry the consume call.""" + self._channel.close() self._channel = self._make_channel(leader_addr) self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] try: From 273746e911088833b8816866a286f80a5e8025c4 Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Mon, 23 Mar 2026 23:04:56 -0300 Subject: [PATCH 05/17] fix: resolve mypy no-any-return errors --- fila/async_client.py | 2 +- fila/client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fila/async_client.py b/fila/async_client.py index 27f0bf6..8bab27b 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -111,7 +111,7 @@ def _extract_leader_hint(err: grpc.RpcError) -> str | None: return None for key, value in trailing: if key == _LEADER_HINT_KEY: - return value + return str(value) return None diff --git a/fila/client.py b/fila/client.py index 6e91247..891907a 100644 --- a/fila/client.py +++ b/fila/client.py @@ -29,7 +29,7 @@ def _extract_leader_hint(err: grpc.RpcError) -> str | None: return None for key, value in trailing: if key == _LEADER_HINT_KEY: - return value + return str(value) return None From 872b5926aa99f34e955ac2b53ae6bb6a650f8155 Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Mon, 23 Mar 2026 23:26:44 -0300 Subject: [PATCH 06/17] chore: bump version to 0.2.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 26a4ab3..2dcc753 100644 --- a/pyproject.toml 
+++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "fila-python" -version = "0.1.0" +version = "0.2.0" description = "Python client SDK for the Fila message broker" readme = "README.md" license = "AGPL-3.0-or-later" From c9a83fc013ea8855cdd28fc6e136bb08868a6a1f Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Tue, 24 Mar 2026 10:20:33 -0300 Subject: [PATCH 07/17] feat: add batch enqueue, smart batching, and delivery batching (#3) add batch_enqueue() for explicit multi-message RPCs, smart batching via BatchMode (AUTO/DISABLED/Linger) that routes enqueue() through a background batcher thread, and delivery batching that unpacks ConsumeResponse.messages repeated field. update proto to include BatchEnqueue RPC and ConsumeResponse batched messages field. single-item optimization uses singular Enqueue RPC to preserve error types. close() drains pending messages before disconnecting. --- fila/__init__.py | 9 +- fila/async_client.py | 88 +++++++-- fila/batcher.py | 267 +++++++++++++++++++++++++ fila/client.py | 171 +++++++++++++--- fila/errors.py | 17 ++ fila/types.py | 46 +++++ fila/v1/messages_pb2_grpc.py | 2 +- fila/v1/service_pb2.py | 30 +-- fila/v1/service_pb2.pyi | 28 ++- fila/v1/service_pb2_grpc.py | 45 ++++- proto/fila/v1/service.proto | 19 +- tests/test_batch_integration.py | 220 +++++++++++++++++++++ tests/test_batcher.py | 333 ++++++++++++++++++++++++++++++++ 13 files changed, 1218 insertions(+), 57 deletions(-) create mode 100644 fila/batcher.py create mode 100644 tests/test_batch_integration.py create mode 100644 tests/test_batcher.py diff --git a/fila/__init__.py b/fila/__init__.py index a273836..732fc43 100644 --- a/fila/__init__.py +++ b/fila/__init__.py @@ -1,20 +1,25 @@ -"""Fila — Python client SDK for the Fila message broker.""" +"""Fila -- Python client SDK for the Fila message broker.""" from fila.async_client import AsyncClient from fila.client import Client from fila.errors import ( + BatchEnqueueError, 
FilaError, MessageNotFoundError, QueueNotFoundError, RPCError, ) -from fila.types import ConsumeMessage +from fila.types import BatchEnqueueResult, BatchMode, ConsumeMessage, Linger __all__ = [ "AsyncClient", + "BatchEnqueueError", + "BatchEnqueueResult", + "BatchMode", "Client", "ConsumeMessage", "FilaError", + "Linger", "MessageNotFoundError", "QueueNotFoundError", "RPCError", diff --git a/fila/async_client.py b/fila/async_client.py index 8bab27b..8e06b1e 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -10,8 +10,15 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator -from fila.errors import _map_ack_error, _map_consume_error, _map_enqueue_error, _map_nack_error -from fila.types import ConsumeMessage +from fila.client import _proto_msg_to_consume_message +from fila.errors import ( + _map_ack_error, + _map_batch_enqueue_error, + _map_consume_error, + _map_enqueue_error, + _map_nack_error, +) +from fila.types import BatchEnqueueResult, ConsumeMessage from fila.v1 import service_pb2, service_pb2_grpc @@ -118,7 +125,8 @@ def _extract_leader_hint(err: grpc.RpcError) -> str | None: class AsyncClient: """Asynchronous client for the Fila message broker. - Wraps the hot-path gRPC operations: enqueue, consume, ack, nack. + Wraps the hot-path gRPC operations: enqueue, batch_enqueue, consume, ack, + nack. Usage:: @@ -256,6 +264,55 @@ async def enqueue( raise _map_enqueue_error(e) from e return str(resp.message_id) + async def batch_enqueue( + self, + messages: list[tuple[str, dict[str, str] | None, bytes]], + ) -> list[BatchEnqueueResult]: + """Enqueue multiple messages in a single RPC. + + Args: + messages: List of (queue, headers, payload) tuples. + + Returns: + List of ``BatchEnqueueResult`` objects, one per input message. + Each result has either a ``message_id`` (success) or ``error`` + (per-message failure). + + Raises: + QueueNotFoundError: If a referenced queue does not exist. + RPCError: For unexpected gRPC failures. 
+ """ + proto_messages = [ + service_pb2.EnqueueRequest( + queue=q, + headers=h or {}, + payload=p, + ) + for q, h, p in messages + ] + + try: + resp = await self._stub.BatchEnqueue( + service_pb2.BatchEnqueueRequest(messages=proto_messages) + ) + except grpc.RpcError as e: + raise _map_batch_enqueue_error(e) from e + + results: list[BatchEnqueueResult] = [] + for r in resp.results: + if r.HasField("success"): + results.append( + BatchEnqueueResult( + message_id=str(r.success.message_id), + error=None, + ) + ) + else: + results.append( + BatchEnqueueResult(message_id=None, error=r.error) + ) + return results + async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]: """Open a streaming consumer on the specified queue. @@ -306,22 +363,25 @@ async def _consume_iter( self, stream: Any, ) -> AsyncIterator[ConsumeMessage]: - """Internal async generator reading from the gRPC stream.""" + """Internal async generator reading from the gRPC stream. + + Handles both singular ``message`` field (backward compatible) and + repeated ``messages`` field (batched delivery). + """ try: async for resp in stream: + # Check batched messages first (repeated field). + if len(resp.messages) > 0: + for msg in resp.messages: + if msg is not None and msg.ByteSize(): + yield _proto_msg_to_consume_message(msg) + continue + + # Fall back to singular message field. 
msg = resp.message if msg is None or not msg.ByteSize(): continue # keepalive - metadata = msg.metadata - cm = ConsumeMessage( - id=msg.id, - headers=dict(msg.headers), - payload=bytes(msg.payload), - fairness_key=metadata.fairness_key if metadata else "", - attempt_count=metadata.attempt_count if metadata else 0, - queue=metadata.queue_id if metadata else "", - ) - yield cm + yield _proto_msg_to_consume_message(msg) except grpc.RpcError: return diff --git a/fila/batcher.py b/fila/batcher.py new file mode 100644 index 0000000..c57964a --- /dev/null +++ b/fila/batcher.py @@ -0,0 +1,267 @@ +"""Background batcher for opportunistic and linger-based enqueue batching.""" + +from __future__ import annotations + +import queue +import threading +from concurrent.futures import Future, ThreadPoolExecutor +from typing import TYPE_CHECKING, Any + +import grpc + +from fila.errors import BatchEnqueueError, _map_enqueue_error +from fila.types import BatchEnqueueResult +from fila.v1 import service_pb2 + +if TYPE_CHECKING: + from fila.v1 import service_pb2_grpc + + +# Sentinel that signals the batcher thread to stop. +_STOP = object() + +# Maximum batch size when none is configured. 
+_DEFAULT_MAX_BATCH_SIZE = 1000 + + +class _EnqueueRequest: + """Internal envelope pairing a proto request with its result future.""" + + __slots__ = ("proto", "future") + + def __init__( + self, + proto: service_pb2.EnqueueRequest, + future: Future[str], + ) -> None: + self.proto = proto + self.future = future + + +def _msg_to_consume_result( + proto_result: Any, +) -> BatchEnqueueResult: + """Convert a proto ``BatchEnqueueResult`` to the SDK type.""" + if proto_result.HasField("success"): + return BatchEnqueueResult( + message_id=proto_result.success.message_id, + error=None, + ) + return BatchEnqueueResult( + message_id=None, + error=proto_result.error, + ) + + +def _flush_single( + stub: service_pb2_grpc.FilaServiceStub, + req: _EnqueueRequest, +) -> None: + """Send a single message via the singular Enqueue RPC. + + This preserves the specific error types (QueueNotFoundError, etc.) + that callers of ``enqueue()`` expect. + """ + try: + resp = stub.Enqueue(req.proto) + req.future.set_result(str(resp.message_id)) + except grpc.RpcError as e: + req.future.set_exception(_map_enqueue_error(e)) + except Exception as e: + req.future.set_exception(e) + + +def _flush_batch( + stub: service_pb2_grpc.FilaServiceStub, + batch: list[_EnqueueRequest], +) -> None: + """Send a batch of messages via the BatchEnqueue RPC. + + On RPC-level failure, every future in the batch receives a + ``BatchEnqueueError``. On success, each future gets either its + message ID or a per-message error string wrapped in a + ``BatchEnqueueError``. + """ + try: + resp = stub.BatchEnqueue( + service_pb2.BatchEnqueueRequest( + messages=[r.proto for r in batch], + ) + ) + except grpc.RpcError as e: + err = BatchEnqueueError(f"batch enqueue rpc failed: {e.details()}") + for r in batch: + r.future.set_exception(err) + return + except Exception as e: + for r in batch: + r.future.set_exception(e) + return + + # Pair each result with its request future. 
+ for i, result in enumerate(resp.results): + if i >= len(batch): + break + req = batch[i] + if result.HasField("success"): + req.future.set_result(str(result.success.message_id)) + else: + req.future.set_exception( + BatchEnqueueError(f"enqueue failed: {result.error}") + ) + + +class AutoBatcher: + """Opportunistic batcher: drains a queue and flushes in batches. + + A background daemon thread blocks on the first message, then non-blocking + drains any additional messages that arrived during processing and flushes + them as a single batch via a thread pool executor. + """ + + def __init__( + self, + stub: service_pb2_grpc.FilaServiceStub, + max_batch_size: int = _DEFAULT_MAX_BATCH_SIZE, + max_workers: int = 4, + ) -> None: + self._stub = stub + self._max_batch_size = max_batch_size + self._queue: queue.Queue[_EnqueueRequest | object] = queue.Queue() + self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def submit(self, proto: service_pb2.EnqueueRequest) -> Future[str]: + """Submit a message for batched enqueue. Returns a Future for the message ID.""" + fut: Future[str] = Future() + self._queue.put(_EnqueueRequest(proto, fut)) + return fut + + def close(self, timeout: float | None = 30.0) -> None: + """Drain pending messages and shut down the batcher. + + Blocks until all pending messages have been flushed or *timeout* + seconds have elapsed. + """ + self._queue.put(_STOP) + self._thread.join(timeout=timeout) + self._executor.shutdown(wait=True) + + def update_stub(self, stub: service_pb2_grpc.FilaServiceStub) -> None: + """Update the gRPC stub (e.g. after leader-hint reconnect).""" + self._stub = stub + + def _run(self) -> None: + """Background loop: block for first item, drain rest, flush.""" + while True: + # Block until at least one item arrives. 
+ first = self._queue.get() + if first is _STOP: + return + + assert isinstance(first, _EnqueueRequest) + batch: list[_EnqueueRequest] = [first] + + # Non-blocking drain of any additional queued messages. + while len(batch) < self._max_batch_size: + try: + item = self._queue.get_nowait() + except queue.Empty: + break + if item is _STOP: + # Flush what we have, then stop. + self._flush(batch) + return + assert isinstance(item, _EnqueueRequest) + batch.append(item) + + self._flush(batch) + + def _flush(self, batch: list[_EnqueueRequest]) -> None: + """Dispatch a batch to the executor for concurrent RPC.""" + if len(batch) == 1: + # Single-item optimization: use singular Enqueue RPC. + self._executor.submit(_flush_single, self._stub, batch[0]) + else: + self._executor.submit(_flush_batch, self._stub, batch) + + +class LingerBatcher: + """Timer-based batcher: holds messages for up to linger_ms or batch_size. + + A background daemon thread accumulates messages and flushes when either + the batch reaches ``batch_size`` or ``linger_ms`` milliseconds have + elapsed since the first message in the current batch arrived. + """ + + def __init__( + self, + stub: service_pb2_grpc.FilaServiceStub, + linger_ms: float, + batch_size: int, + max_workers: int = 4, + ) -> None: + self._stub = stub + self._linger_s = linger_ms / 1000.0 + self._batch_size = batch_size + self._queue: queue.Queue[_EnqueueRequest | object] = queue.Queue() + self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def submit(self, proto: service_pb2.EnqueueRequest) -> Future[str]: + """Submit a message for batched enqueue. 
Returns a Future for the message ID.""" + fut: Future[str] = Future() + self._queue.put(_EnqueueRequest(proto, fut)) + return fut + + def close(self, timeout: float | None = 30.0) -> None: + """Drain pending messages and shut down the batcher.""" + self._queue.put(_STOP) + self._thread.join(timeout=timeout) + self._executor.shutdown(wait=True) + + def update_stub(self, stub: service_pb2_grpc.FilaServiceStub) -> None: + """Update the gRPC stub (e.g. after leader-hint reconnect).""" + self._stub = stub + + def _run(self) -> None: + """Background loop: accumulate up to batch_size or linger timeout.""" + import time + + while True: + # Block until at least one item arrives. + first = self._queue.get() + if first is _STOP: + return + + assert isinstance(first, _EnqueueRequest) + batch: list[_EnqueueRequest] = [first] + + # Track wall-clock deadline from when first message arrived. + deadline = time.monotonic() + self._linger_s + + # Accumulate more items until batch_size or linger timeout. + while len(batch) < self._batch_size: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + try: + item = self._queue.get(timeout=remaining) + except queue.Empty: + break + if item is _STOP: + self._flush(batch) + return + assert isinstance(item, _EnqueueRequest) + batch.append(item) + + self._flush(batch) + + def _flush(self, batch: list[_EnqueueRequest]) -> None: + """Dispatch a batch to the executor for concurrent RPC.""" + if len(batch) == 1: + self._executor.submit(_flush_single, self._stub, batch[0]) + else: + self._executor.submit(_flush_batch, self._stub, batch) diff --git a/fila/client.py b/fila/client.py index 891907a..0d7e49a 100644 --- a/fila/client.py +++ b/fila/client.py @@ -6,8 +6,15 @@ import grpc -from fila.errors import _map_ack_error, _map_consume_error, _map_enqueue_error, _map_nack_error -from fila.types import ConsumeMessage +from fila.batcher import AutoBatcher, LingerBatcher +from fila.errors import ( + _map_ack_error, + 
_map_batch_enqueue_error, + _map_consume_error, + _map_enqueue_error, + _map_nack_error, +) +from fila.types import BatchEnqueueResult, BatchMode, ConsumeMessage, Linger from fila.v1 import service_pb2, service_pb2_grpc if TYPE_CHECKING: @@ -33,6 +40,19 @@ def _extract_leader_hint(err: grpc.RpcError) -> str | None: return None +def _proto_msg_to_consume_message(msg: Any) -> ConsumeMessage: + """Convert a protobuf Message to a ConsumeMessage.""" + metadata = msg.metadata + return ConsumeMessage( + id=msg.id, + headers=dict(msg.headers), + payload=bytes(msg.payload), + fairness_key=metadata.fairness_key if metadata else "", + attempt_count=metadata.attempt_count if metadata else 0, + queue=metadata.queue_id if metadata else "", + ) + + class _ClientCallDetails( grpc.ClientCallDetails, # type: ignore[misc] ): @@ -102,7 +122,8 @@ def intercept_unary_stream( class Client: """Synchronous client for the Fila message broker. - Wraps the hot-path gRPC operations: enqueue, consume, ack, nack. + Wraps the hot-path gRPC operations: enqueue, batch_enqueue, consume, ack, + nack. Usage:: @@ -117,6 +138,17 @@ class Client: with Client("localhost:5555") as client: client.enqueue("my-queue", None, b"hello") + Batch modes:: + + # AUTO (default): opportunistic batching via background thread + client = Client("localhost:5555") + + # DISABLED: each enqueue() is a direct RPC + client = Client("localhost:5555", batch_mode=BatchMode.DISABLED) + + # LINGER: timer-based forced batching + client = Client("localhost:5555", batch_mode=Linger(linger_ms=10, batch_size=100)) + TLS (system trust store):: client = Client("localhost:5555", tls=True) @@ -147,6 +179,8 @@ def __init__( client_cert: bytes | None = None, client_key: bytes | None = None, api_key: str | None = None, + batch_mode: BatchMode | Linger = BatchMode.AUTO, + max_batch_size: int = 1000, ) -> None: """Connect to a Fila broker at the given address. 
@@ -161,6 +195,10 @@ def __init__( client_key: PEM-encoded client private key for mutual TLS (optional). api_key: API key for authentication. When set, every RPC includes an ``authorization: Bearer `` metadata header. + batch_mode: Controls how ``enqueue()`` routes messages. Defaults to + ``BatchMode.AUTO`` (opportunistic batching). + max_batch_size: Maximum number of messages per batch when using + ``BatchMode.AUTO``. Defaults to 1000. """ self._tls = tls self._ca_cert = ca_cert @@ -177,6 +215,21 @@ def __init__( self._channel = self._make_channel(addr) self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + # Set up the batcher based on the chosen mode. + self._batcher: AutoBatcher | LingerBatcher | None = None + if isinstance(batch_mode, Linger): + self._batcher = LingerBatcher( + self._stub, + linger_ms=batch_mode.linger_ms, + batch_size=batch_mode.batch_size, + ) + elif batch_mode is BatchMode.AUTO: + self._batcher = AutoBatcher( + self._stub, + max_batch_size=max_batch_size, + ) + # BatchMode.DISABLED: self._batcher stays None + def _make_channel(self, addr: str) -> grpc.Channel: """Create a gRPC channel to the given address using stored credentials.""" use_tls = self._tls or self._ca_cert is not None @@ -198,7 +251,9 @@ def _make_channel(self, addr: str) -> grpc.Channel: return channel def close(self) -> None: - """Close the underlying gRPC channel.""" + """Drain pending batched messages and close the underlying gRPC channel.""" + if self._batcher is not None: + self._batcher.close() self._channel.close() def __enter__(self) -> Client: @@ -215,6 +270,13 @@ def enqueue( ) -> str: """Enqueue a message to the specified queue. + When a batcher is active (``BatchMode.AUTO`` or ``Linger``), the + message is submitted to the background batcher and this call blocks + until the batch is flushed and the result is available. + + When batching is disabled (``BatchMode.DISABLED``), this call makes + a direct synchronous RPC. 
+ Args: queue: Target queue name. headers: Optional message headers. @@ -224,21 +286,79 @@ def enqueue( Broker-assigned message ID (UUIDv7). Raises: - QueueNotFoundError: If the queue does not exist. + QueueNotFoundError: If the queue does not exist (DISABLED mode). + BatchEnqueueError: If the batch RPC fails (AUTO/LINGER mode). RPCError: For unexpected gRPC failures. """ + proto = service_pb2.EnqueueRequest( + queue=queue, + headers=headers or {}, + payload=payload, + ) + + if self._batcher is not None: + future = self._batcher.submit(proto) + return future.result() + + # Direct RPC (DISABLED mode). try: - resp = self._stub.Enqueue( - service_pb2.EnqueueRequest( - queue=queue, - headers=headers or {}, - payload=payload, - ) - ) + resp = self._stub.Enqueue(proto) except grpc.RpcError as e: raise _map_enqueue_error(e) from e return str(resp.message_id) + def batch_enqueue( + self, + messages: list[tuple[str, dict[str, str] | None, bytes]], + ) -> list[BatchEnqueueResult]: + """Enqueue multiple messages in a single RPC. + + This is an explicit batch operation that always uses the BatchEnqueue + RPC regardless of the batch_mode setting. + + Args: + messages: List of (queue, headers, payload) tuples. + + Returns: + List of ``BatchEnqueueResult`` objects, one per input message. + Each result has either a ``message_id`` (success) or ``error`` + (per-message failure). + + Raises: + QueueNotFoundError: If a referenced queue does not exist. + RPCError: For unexpected gRPC failures. 
+ """ + proto_messages = [ + service_pb2.EnqueueRequest( + queue=q, + headers=h or {}, + payload=p, + ) + for q, h, p in messages + ] + + try: + resp = self._stub.BatchEnqueue( + service_pb2.BatchEnqueueRequest(messages=proto_messages) + ) + except grpc.RpcError as e: + raise _map_batch_enqueue_error(e) from e + + results: list[BatchEnqueueResult] = [] + for r in resp.results: + if r.HasField("success"): + results.append( + BatchEnqueueResult( + message_id=str(r.success.message_id), + error=None, + ) + ) + else: + results.append( + BatchEnqueueResult(message_id=None, error=r.error) + ) + return results + def consume(self, queue: str) -> Iterator[ConsumeMessage]: """Open a streaming consumer on the specified queue. @@ -278,6 +398,8 @@ def _reconnect_and_consume(self, leader_addr: str, queue: str) -> Any: self._channel.close() self._channel = self._make_channel(leader_addr) self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + if self._batcher is not None: + self._batcher.update_stub(self._stub) try: return self._stub.Consume( service_pb2.ConsumeRequest(queue=queue) @@ -289,22 +411,25 @@ def _consume_iter( self, stream: Any, ) -> Iterator[ConsumeMessage]: - """Internal generator reading from the gRPC stream.""" + """Internal generator reading from the gRPC stream. + + Handles both singular ``message`` field (backward compatible) and + repeated ``messages`` field (batched delivery). + """ try: for resp in stream: + # Check batched messages first (repeated field). + if len(resp.messages) > 0: + for msg in resp.messages: + if msg is not None and msg.ByteSize(): + yield _proto_msg_to_consume_message(msg) + continue + + # Fall back to singular message field. 
msg = resp.message if msg is None or not msg.ByteSize(): continue # keepalive - metadata = msg.metadata - cm = ConsumeMessage( - id=msg.id, - headers=dict(msg.headers), - payload=bytes(msg.payload), - fairness_key=metadata.fairness_key if metadata else "", - attempt_count=metadata.attempt_count if metadata else 0, - queue=metadata.queue_id if metadata else "", - ) - yield cm + yield _proto_msg_to_consume_message(msg) except grpc.RpcError: return diff --git a/fila/errors.py b/fila/errors.py index 346c1c6..40e76ee 100644 --- a/fila/errors.py +++ b/fila/errors.py @@ -26,6 +26,15 @@ def __init__(self, code: grpc.StatusCode, message: str) -> None: super().__init__(f"rpc error (code = {code.name}): {message}") +class BatchEnqueueError(FilaError): + """Raised when a batched enqueue fails at the RPC level. + + Individual per-message failures are reported via ``BatchEnqueueResult.error`` + and do not raise this exception. This is raised only when the entire batch + RPC fails (e.g., network error, server unavailable). 
+ """ + + def _map_enqueue_error(err: grpc.RpcError) -> FilaError: """Map a gRPC error from an enqueue call to a Fila exception.""" code = err.code() @@ -56,3 +65,11 @@ def _map_nack_error(err: grpc.RpcError) -> FilaError: if code == grpc.StatusCode.NOT_FOUND: return MessageNotFoundError(f"nack: {err.details()}") return RPCError(code, err.details() or "") + + +def _map_batch_enqueue_error(err: grpc.RpcError) -> FilaError: + """Map a gRPC error from a batch enqueue call to a Fila exception.""" + code = err.code() + if code == grpc.StatusCode.NOT_FOUND: + return QueueNotFoundError(f"batch_enqueue: {err.details()}") + return RPCError(code, err.details() or "") diff --git a/fila/types.py b/fila/types.py index 2474228..54ab034 100644 --- a/fila/types.py +++ b/fila/types.py @@ -3,6 +3,7 @@ from __future__ import annotations from dataclasses import dataclass +from enum import Enum, auto @dataclass(frozen=True) @@ -15,3 +16,48 @@ class ConsumeMessage: fairness_key: str attempt_count: int queue: str + + +@dataclass(frozen=True) +class BatchEnqueueResult: + """Result for a single message within a batch enqueue operation. + + Exactly one of ``message_id`` or ``error`` is set. + """ + + message_id: str | None + error: str | None + + @property + def is_success(self) -> bool: + """Return True if this message was enqueued successfully.""" + return self.message_id is not None + + +class BatchMode(Enum): + """Controls how ``enqueue()`` routes messages to the broker. + + - ``AUTO``: Opportunistic batching via a background thread. At low load + messages are sent individually; at high load they cluster into batches. + This is the default. + - ``DISABLED``: No batching. Each ``enqueue()`` call is a direct RPC. + """ + + AUTO = auto() + DISABLED = auto() + + +@dataclass(frozen=True) +class Linger: + """Timer-based forced batching mode. + + Messages are held for up to ``linger_ms`` milliseconds or until + ``batch_size`` messages accumulate, whichever comes first. 
+ + Args: + linger_ms: Maximum time to hold a message before flushing (milliseconds). + batch_size: Maximum number of messages per batch. + """ + + linger_ms: float + batch_size: int diff --git a/fila/v1/messages_pb2_grpc.py b/fila/v1/messages_pb2_grpc.py index fa0dc71..d27d27c 100644 --- a/fila/v1/messages_pb2_grpc.py +++ b/fila/v1/messages_pb2_grpc.py @@ -4,7 +4,7 @@ import warnings -GRPC_GENERATED_VERSION = '1.78.0' +GRPC_GENERATED_VERSION = '1.78.1' GRPC_VERSION = grpc.__version__ _version_not_supported = False diff --git a/fila/v1/service_pb2.py b/fila/v1/service_pb2.py index 11ad1f0..7f04078 100644 --- a/fila/v1/service_pb2.py +++ b/fila/v1/service_pb2.py @@ -25,7 +25,7 @@ from fila.v1 import messages_pb2 as fila_dot_v1_dot_messages__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66ila/v1/service.proto\x12\x07\x66ila.v1\x1a\x16\x66ila/v1/messages.proto\"\x97\x01\n\x0e\x45nqueueRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x35\n\x07headers\x18\x02 \x03(\x0b\x32$.fila.v1.EnqueueRequest.HeadersEntry\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"%\n\x0f\x45nqueueResponse\x12\x12\n\nmessage_id\x18\x01 \x01(\t\"\x1f\n\x0e\x43onsumeRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"4\n\x0f\x43onsumeResponse\x12!\n\x07message\x18\x01 \x01(\x0b\x32\x10.fila.v1.Message\"/\n\nAckRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\"\r\n\x0b\x41\x63kResponse\"?\n\x0bNackRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 
\x01(\t\"\x0e\n\x0cNackResponse2\xf2\x01\n\x0b\x46ilaService\x12<\n\x07\x45nqueue\x12\x17.fila.v1.EnqueueRequest\x1a\x18.fila.v1.EnqueueResponse\x12>\n\x07\x43onsume\x12\x17.fila.v1.ConsumeRequest\x1a\x18.fila.v1.ConsumeResponse0\x01\x12\x30\n\x03\x41\x63k\x12\x13.fila.v1.AckRequest\x1a\x14.fila.v1.AckResponse\x12\x33\n\x04Nack\x12\x14.fila.v1.NackRequest\x1a\x15.fila.v1.NackResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66ila/v1/service.proto\x12\x07\x66ila.v1\x1a\x16\x66ila/v1/messages.proto\"\x97\x01\n\x0e\x45nqueueRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x35\n\x07headers\x18\x02 \x03(\x0b\x32$.fila.v1.EnqueueRequest.HeadersEntry\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"%\n\x0f\x45nqueueResponse\x12\x12\n\nmessage_id\x18\x01 \x01(\t\"\x1f\n\x0e\x43onsumeRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"X\n\x0f\x43onsumeResponse\x12!\n\x07message\x18\x01 \x01(\x0b\x32\x10.fila.v1.Message\x12\"\n\x08messages\x18\x02 \x03(\x0b\x32\x10.fila.v1.Message\"/\n\nAckRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\"\r\n\x0b\x41\x63kResponse\"?\n\x0bNackRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\"\x0e\n\x0cNackResponse\"@\n\x13\x42\x61tchEnqueueRequest\x12)\n\x08messages\x18\x01 \x03(\x0b\x32\x17.fila.v1.EnqueueRequest\"D\n\x14\x42\x61tchEnqueueResponse\x12,\n\x07results\x18\x01 \x03(\x0b\x32\x1b.fila.v1.BatchEnqueueResult\"\\\n\x12\x42\x61tchEnqueueResult\x12+\n\x07success\x18\x01 \x01(\x0b\x32\x18.fila.v1.EnqueueResponseH\x00\x12\x0f\n\x05\x65rror\x18\x02 
\x01(\tH\x00\x42\x08\n\x06result2\xbf\x02\n\x0b\x46ilaService\x12<\n\x07\x45nqueue\x12\x17.fila.v1.EnqueueRequest\x1a\x18.fila.v1.EnqueueResponse\x12K\n\x0c\x42\x61tchEnqueue\x12\x1c.fila.v1.BatchEnqueueRequest\x1a\x1d.fila.v1.BatchEnqueueResponse\x12>\n\x07\x43onsume\x12\x17.fila.v1.ConsumeRequest\x1a\x18.fila.v1.ConsumeResponse0\x01\x12\x30\n\x03\x41\x63k\x12\x13.fila.v1.AckRequest\x1a\x14.fila.v1.AckResponse\x12\x33\n\x04Nack\x12\x14.fila.v1.NackRequest\x1a\x15.fila.v1.NackResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -43,15 +43,21 @@ _globals['_CONSUMEREQUEST']._serialized_start=251 _globals['_CONSUMEREQUEST']._serialized_end=282 _globals['_CONSUMERESPONSE']._serialized_start=284 - _globals['_CONSUMERESPONSE']._serialized_end=336 - _globals['_ACKREQUEST']._serialized_start=338 - _globals['_ACKREQUEST']._serialized_end=385 - _globals['_ACKRESPONSE']._serialized_start=387 - _globals['_ACKRESPONSE']._serialized_end=400 - _globals['_NACKREQUEST']._serialized_start=402 - _globals['_NACKREQUEST']._serialized_end=465 - _globals['_NACKRESPONSE']._serialized_start=467 - _globals['_NACKRESPONSE']._serialized_end=481 - _globals['_FILASERVICE']._serialized_start=484 - _globals['_FILASERVICE']._serialized_end=726 + _globals['_CONSUMERESPONSE']._serialized_end=372 + _globals['_ACKREQUEST']._serialized_start=374 + _globals['_ACKREQUEST']._serialized_end=421 + _globals['_ACKRESPONSE']._serialized_start=423 + _globals['_ACKRESPONSE']._serialized_end=436 + _globals['_NACKREQUEST']._serialized_start=438 + _globals['_NACKREQUEST']._serialized_end=501 + _globals['_NACKRESPONSE']._serialized_start=503 + _globals['_NACKRESPONSE']._serialized_end=517 + _globals['_BATCHENQUEUEREQUEST']._serialized_start=519 + _globals['_BATCHENQUEUEREQUEST']._serialized_end=583 + _globals['_BATCHENQUEUERESPONSE']._serialized_start=585 + _globals['_BATCHENQUEUERESPONSE']._serialized_end=653 + 
_globals['_BATCHENQUEUERESULT']._serialized_start=655 + _globals['_BATCHENQUEUERESULT']._serialized_end=747 + _globals['_FILASERVICE']._serialized_start=750 + _globals['_FILASERVICE']._serialized_end=1069 # @@protoc_insertion_point(module_scope) diff --git a/fila/v1/service_pb2.pyi b/fila/v1/service_pb2.pyi index c6478c4..ca1e820 100644 --- a/fila/v1/service_pb2.pyi +++ b/fila/v1/service_pb2.pyi @@ -2,7 +2,7 @@ from fila.v1 import messages_pb2 as _messages_pb2 from google.protobuf.internal import containers as _containers from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message -from collections.abc import Mapping as _Mapping +from collections.abc import Iterable as _Iterable, Mapping as _Mapping from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor @@ -37,10 +37,12 @@ class ConsumeRequest(_message.Message): def __init__(self, queue: _Optional[str] = ...) -> None: ... class ConsumeResponse(_message.Message): - __slots__ = ("message",) + __slots__ = ("message", "messages") MESSAGE_FIELD_NUMBER: _ClassVar[int] + MESSAGES_FIELD_NUMBER: _ClassVar[int] message: _messages_pb2.Message - def __init__(self, message: _Optional[_Union[_messages_pb2.Message, _Mapping]] = ...) -> None: ... + messages: _containers.RepeatedCompositeFieldContainer[_messages_pb2.Message] + def __init__(self, message: _Optional[_Union[_messages_pb2.Message, _Mapping]] = ..., messages: _Optional[_Iterable[_Union[_messages_pb2.Message, _Mapping]]] = ...) -> None: ... class AckRequest(_message.Message): __slots__ = ("queue", "message_id") @@ -67,3 +69,23 @@ class NackRequest(_message.Message): class NackResponse(_message.Message): __slots__ = () def __init__(self) -> None: ... 
+ +class BatchEnqueueRequest(_message.Message): + __slots__ = ("messages",) + MESSAGES_FIELD_NUMBER: _ClassVar[int] + messages: _containers.RepeatedCompositeFieldContainer[EnqueueRequest] + def __init__(self, messages: _Optional[_Iterable[_Union[EnqueueRequest, _Mapping]]] = ...) -> None: ... + +class BatchEnqueueResponse(_message.Message): + __slots__ = ("results",) + RESULTS_FIELD_NUMBER: _ClassVar[int] + results: _containers.RepeatedCompositeFieldContainer[BatchEnqueueResult] + def __init__(self, results: _Optional[_Iterable[_Union[BatchEnqueueResult, _Mapping]]] = ...) -> None: ... + +class BatchEnqueueResult(_message.Message): + __slots__ = ("success", "error") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + ERROR_FIELD_NUMBER: _ClassVar[int] + success: EnqueueResponse + error: str + def __init__(self, success: _Optional[_Union[EnqueueResponse, _Mapping]] = ..., error: _Optional[str] = ...) -> None: ... diff --git a/fila/v1/service_pb2_grpc.py b/fila/v1/service_pb2_grpc.py index 663ae2a..0ef11e1 100644 --- a/fila/v1/service_pb2_grpc.py +++ b/fila/v1/service_pb2_grpc.py @@ -5,7 +5,7 @@ from fila.v1 import service_pb2 as fila_dot_v1_dot_service__pb2 -GRPC_GENERATED_VERSION = '1.78.0' +GRPC_GENERATED_VERSION = '1.78.1' GRPC_VERSION = grpc.__version__ _version_not_supported = False @@ -40,6 +40,11 @@ def __init__(self, channel): request_serializer=fila_dot_v1_dot_service__pb2.EnqueueRequest.SerializeToString, response_deserializer=fila_dot_v1_dot_service__pb2.EnqueueResponse.FromString, _registered_method=True) + self.BatchEnqueue = channel.unary_unary( + '/fila.v1.FilaService/BatchEnqueue', + request_serializer=fila_dot_v1_dot_service__pb2.BatchEnqueueRequest.SerializeToString, + response_deserializer=fila_dot_v1_dot_service__pb2.BatchEnqueueResponse.FromString, + _registered_method=True) self.Consume = channel.unary_stream( '/fila.v1.FilaService/Consume', request_serializer=fila_dot_v1_dot_service__pb2.ConsumeRequest.SerializeToString, @@ -67,6 +72,12 @@ def 
Enqueue(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def BatchEnqueue(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def Consume(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -93,6 +104,11 @@ def add_FilaServiceServicer_to_server(servicer, server): request_deserializer=fila_dot_v1_dot_service__pb2.EnqueueRequest.FromString, response_serializer=fila_dot_v1_dot_service__pb2.EnqueueResponse.SerializeToString, ), + 'BatchEnqueue': grpc.unary_unary_rpc_method_handler( + servicer.BatchEnqueue, + request_deserializer=fila_dot_v1_dot_service__pb2.BatchEnqueueRequest.FromString, + response_serializer=fila_dot_v1_dot_service__pb2.BatchEnqueueResponse.SerializeToString, + ), 'Consume': grpc.unary_stream_rpc_method_handler( servicer.Consume, request_deserializer=fila_dot_v1_dot_service__pb2.ConsumeRequest.FromString, @@ -147,6 +163,33 @@ def Enqueue(request, metadata, _registered_method=True) + @staticmethod + def BatchEnqueue(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/fila.v1.FilaService/BatchEnqueue', + fila_dot_v1_dot_service__pb2.BatchEnqueueRequest.SerializeToString, + fila_dot_v1_dot_service__pb2.BatchEnqueueResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + @staticmethod def Consume(request, target, diff --git a/proto/fila/v1/service.proto b/proto/fila/v1/service.proto index 
f14fdd0..fc0f710 100644 --- a/proto/fila/v1/service.proto +++ b/proto/fila/v1/service.proto @@ -6,6 +6,7 @@ import "fila/v1/messages.proto"; // Hot-path RPCs for producers and consumers. service FilaService { rpc Enqueue(EnqueueRequest) returns (EnqueueResponse); + rpc BatchEnqueue(BatchEnqueueRequest) returns (BatchEnqueueResponse); rpc Consume(ConsumeRequest) returns (stream ConsumeResponse); rpc Ack(AckRequest) returns (AckResponse); rpc Nack(NackRequest) returns (NackResponse); @@ -26,7 +27,8 @@ message ConsumeRequest { } message ConsumeResponse { - Message message = 1; + Message message = 1; // Single message (backward compatible, used when batch size is 1) + repeated Message messages = 2; // Batched messages (populated when server sends multiple at once) } message AckRequest { @@ -43,3 +45,18 @@ message NackRequest { } message NackResponse {} + +message BatchEnqueueRequest { + repeated EnqueueRequest messages = 1; +} + +message BatchEnqueueResponse { + repeated BatchEnqueueResult results = 1; +} + +message BatchEnqueueResult { + oneof result { + EnqueueResponse success = 1; + string error = 2; + } +} diff --git a/tests/test_batch_integration.py b/tests/test_batch_integration.py new file mode 100644 index 0000000..09aefb9 --- /dev/null +++ b/tests/test_batch_integration.py @@ -0,0 +1,220 @@ +"""Integration tests for batch enqueue and smart batching. + +These tests require a running fila-server binary. They are skipped +automatically when the server is not found (local dev). 
+""" + +from __future__ import annotations + +import pytest + +import fila + + +class TestBatchEnqueue: + """Integration tests for the explicit batch_enqueue method.""" + + def test_batch_enqueue_multiple_messages(self, server: object) -> None: + """batch_enqueue sends multiple messages in one RPC and returns per-message results.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-batch") + + with fila.Client(server.addr, batch_mode=fila.BatchMode.DISABLED) as client: + results = client.batch_enqueue([ + ("test-batch", {"idx": "0"}, b"payload-0"), + ("test-batch", {"idx": "1"}, b"payload-1"), + ("test-batch", {"idx": "2"}, b"payload-2"), + ]) + + assert len(results) == 3 + for r in results: + assert r.is_success + assert r.message_id is not None + assert r.error is None + + # All message IDs should be unique. + ids = [r.message_id for r in results] + assert len(set(ids)) == 3 + + def test_batch_enqueue_single_message(self, server: object) -> None: + """batch_enqueue works with a single message.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-batch-single") + + with fila.Client(server.addr, batch_mode=fila.BatchMode.DISABLED) as client: + results = client.batch_enqueue([ + ("test-batch-single", None, b"solo"), + ]) + + assert len(results) == 1 + assert results[0].is_success + assert results[0].message_id is not None + + def test_batch_enqueue_consume_verify(self, server: object) -> None: + """Messages enqueued via batch_enqueue can be consumed and acked.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-batch-consume") + + with fila.Client(server.addr, batch_mode=fila.BatchMode.DISABLED) as client: + results = client.batch_enqueue([ + ("test-batch-consume", {"k": "v"}, b"batch-msg"), + ]) + assert results[0].is_success + + stream = client.consume("test-batch-consume") + msg = 
next(stream) + + assert msg.id == results[0].message_id + assert msg.headers["k"] == "v" + assert msg.payload == b"batch-msg" + + client.ack("test-batch-consume", msg.id) + + +class TestAsyncBatchEnqueue: + """Integration tests for the async batch_enqueue method.""" + + @pytest.mark.asyncio + async def test_async_batch_enqueue(self, server: object) -> None: + """Async batch_enqueue sends multiple messages.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-async-batch") + + async with fila.AsyncClient(server.addr) as client: + results = await client.batch_enqueue([ + ("test-async-batch", None, b"async-0"), + ("test-async-batch", None, b"async-1"), + ]) + + assert len(results) == 2 + for r in results: + assert r.is_success + assert r.message_id is not None + + +class TestSmartBatching: + """Integration tests for smart batching (BatchMode.AUTO).""" + + def test_auto_mode_enqueue(self, server: object) -> None: + """AUTO mode enqueues messages through the batcher.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-auto-batch") + + with fila.Client(server.addr, batch_mode=fila.BatchMode.AUTO) as client: + msg_id = client.enqueue("test-auto-batch", None, b"auto-msg") + assert msg_id != "" + + # Verify the message was actually enqueued. 
+ stream = client.consume("test-auto-batch") + msg = next(stream) + assert msg.id == msg_id + assert msg.payload == b"auto-msg" + client.ack("test-auto-batch", msg.id) + + def test_auto_mode_multiple_messages(self, server: object) -> None: + """AUTO mode handles multiple sequential enqueues.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-auto-multi") + + with fila.Client(server.addr, batch_mode=fila.BatchMode.AUTO) as client: + ids = [] + for i in range(5): + msg_id = client.enqueue( + "test-auto-multi", None, f"msg-{i}".encode() + ) + assert msg_id != "" + ids.append(msg_id) + + # All IDs should be unique. + assert len(set(ids)) == 5 + + def test_disabled_mode_enqueue(self, server: object) -> None: + """DISABLED mode sends each enqueue as a direct RPC.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-disabled") + + with fila.Client(server.addr, batch_mode=fila.BatchMode.DISABLED) as client: + msg_id = client.enqueue("test-disabled", None, b"direct") + assert msg_id != "" + + stream = client.consume("test-disabled") + msg = next(stream) + assert msg.id == msg_id + client.ack("test-disabled", msg.id) + + def test_linger_mode_enqueue(self, server: object) -> None: + """LINGER mode enqueues messages through a timer-based batcher.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-linger") + + with fila.Client( + server.addr, + batch_mode=fila.Linger(linger_ms=50, batch_size=10), + ) as client: + msg_id = client.enqueue("test-linger", None, b"lingered") + assert msg_id != "" + + stream = client.consume("test-linger") + msg = next(stream) + assert msg.id == msg_id + assert msg.payload == b"lingered" + client.ack("test-linger", msg.id) + + def test_default_mode_is_auto(self, server: object) -> None: + """Client defaults to AUTO batch mode.""" + from tests.conftest import TestServer 
+ + assert isinstance(server, TestServer) + server.create_queue("test-default-mode") + + # No batch_mode arg = AUTO. + with fila.Client(server.addr) as client: + msg_id = client.enqueue("test-default-mode", None, b"default") + assert msg_id != "" + + +class TestBatchModeTypes: + """Unit tests for BatchMode and Linger types (no server needed).""" + + def test_batch_mode_enum(self) -> None: + """BatchMode has AUTO and DISABLED variants.""" + assert fila.BatchMode.AUTO is not None + assert fila.BatchMode.DISABLED is not None + modes = {fila.BatchMode.AUTO, fila.BatchMode.DISABLED} + assert len(modes) == 2 # They are distinct values + + def test_linger_fields(self) -> None: + """Linger stores linger_ms and batch_size.""" + linger = fila.Linger(linger_ms=100, batch_size=50) + assert linger.linger_ms == 100 + assert linger.batch_size == 50 + + def test_batch_enqueue_result_success(self) -> None: + """BatchEnqueueResult.is_success returns True when message_id is set.""" + r = fila.BatchEnqueueResult(message_id="abc", error=None) + assert r.is_success + assert r.message_id == "abc" + assert r.error is None + + def test_batch_enqueue_result_error(self) -> None: + """BatchEnqueueResult.is_success returns False when error is set.""" + r = fila.BatchEnqueueResult(message_id=None, error="queue not found") + assert not r.is_success + assert r.message_id is None + assert r.error == "queue not found" diff --git a/tests/test_batcher.py b/tests/test_batcher.py new file mode 100644 index 0000000..ee10e71 --- /dev/null +++ b/tests/test_batcher.py @@ -0,0 +1,333 @@ +"""Unit tests for the batcher module. + +These tests use mock stubs and do not require a running fila-server. 
+""" + +from __future__ import annotations + +import threading +from concurrent.futures import Future +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from fila.batcher import ( + AutoBatcher, + LingerBatcher, + _EnqueueRequest, + _flush_batch, + _flush_single, +) +from fila.errors import BatchEnqueueError +from fila.v1 import service_pb2 + + +class FakeEnqueueResponse: + """Minimal fake for service_pb2.EnqueueResponse.""" + + def __init__(self, message_id: str) -> None: + self.message_id = message_id + + +class FakeBatchResult: + """Minimal fake for service_pb2.BatchEnqueueResult.""" + + def __init__(self, message_id: str | None = None, error: str | None = None) -> None: + self._message_id = message_id + self._error = error + self.success: FakeEnqueueResponse | None = ( + FakeEnqueueResponse(message_id) if message_id is not None else None + ) + self.error = error or "" + + def HasField(self, name: str) -> bool: # noqa: N802 + if name == "success": + return self._message_id is not None + return False + + +class FakeBatchResponse: + """Minimal fake for service_pb2.BatchEnqueueResponse.""" + + def __init__(self, results: list[FakeBatchResult]) -> None: + self.results = results + + +class TestFlushSingle: + """Test the _flush_single function.""" + + def test_success(self) -> None: + stub = MagicMock() + stub.Enqueue.return_value = FakeEnqueueResponse("msg-001") + + proto = service_pb2.EnqueueRequest(queue="q", payload=b"data") + fut: Future[str] = Future() + req = _EnqueueRequest(proto, fut) + + _flush_single(stub, req) + + assert fut.result(timeout=1.0) == "msg-001" + stub.Enqueue.assert_called_once_with(proto) + + def test_rpc_error(self) -> None: + import grpc + + stub = MagicMock() + rpc_error = MagicMock() + rpc_error.code.return_value = grpc.StatusCode.NOT_FOUND + rpc_error.details.return_value = "queue not found" + # Make it pass isinstance(e, grpc.RpcError) check. 
+ stub.Enqueue.side_effect = type( + "_FakeRpcError", (grpc.RpcError,), { + "code": lambda self: grpc.StatusCode.NOT_FOUND, + "details": lambda self: "queue not found", + } + )() + + proto = service_pb2.EnqueueRequest(queue="missing", payload=b"data") + fut: Future[str] = Future() + req = _EnqueueRequest(proto, fut) + + _flush_single(stub, req) + + from fila.errors import QueueNotFoundError + + with pytest.raises(QueueNotFoundError): + fut.result(timeout=1.0) + + +class TestFlushBatch: + """Test the _flush_batch function.""" + + def test_all_success(self) -> None: + stub = MagicMock() + stub.BatchEnqueue.return_value = FakeBatchResponse([ + FakeBatchResult(message_id="id-1"), + FakeBatchResult(message_id="id-2"), + ]) + + reqs = [ + _EnqueueRequest( + service_pb2.EnqueueRequest(queue="q", payload=b"a"), + Future(), + ), + _EnqueueRequest( + service_pb2.EnqueueRequest(queue="q", payload=b"b"), + Future(), + ), + ] + + _flush_batch(stub, reqs) + + assert reqs[0].future.result(timeout=1.0) == "id-1" + assert reqs[1].future.result(timeout=1.0) == "id-2" + + def test_mixed_results(self) -> None: + stub = MagicMock() + stub.BatchEnqueue.return_value = FakeBatchResponse([ + FakeBatchResult(message_id="id-1"), + FakeBatchResult(error="queue 'missing' not found"), + ]) + + reqs = [ + _EnqueueRequest( + service_pb2.EnqueueRequest(queue="q", payload=b"a"), + Future(), + ), + _EnqueueRequest( + service_pb2.EnqueueRequest(queue="missing", payload=b"b"), + Future(), + ), + ] + + _flush_batch(stub, reqs) + + assert reqs[0].future.result(timeout=1.0) == "id-1" + with pytest.raises(BatchEnqueueError, match="queue 'missing' not found"): + reqs[1].future.result(timeout=1.0) + + def test_rpc_failure_sets_all_futures(self) -> None: + import grpc + + stub = MagicMock() + stub.BatchEnqueue.side_effect = type( + "_FakeRpcError", (grpc.RpcError,), { + "code": lambda self: grpc.StatusCode.UNAVAILABLE, + "details": lambda self: "server unavailable", + } + )() + + reqs = [ + _EnqueueRequest( 
+ service_pb2.EnqueueRequest(queue="q", payload=b"a"), + Future(), + ), + _EnqueueRequest( + service_pb2.EnqueueRequest(queue="q", payload=b"b"), + Future(), + ), + ] + + _flush_batch(stub, reqs) + + for r in reqs: + with pytest.raises(BatchEnqueueError): + r.future.result(timeout=1.0) + + +class TestAutoBatcher: + """Test the AutoBatcher end-to-end.""" + + def test_single_message_uses_enqueue(self) -> None: + """When only one message is queued, AutoBatcher uses singular Enqueue.""" + stub = MagicMock() + stub.Enqueue.return_value = FakeEnqueueResponse("msg-solo") + + batcher = AutoBatcher(stub, max_batch_size=100) + + proto = service_pb2.EnqueueRequest(queue="q", payload=b"solo") + fut = batcher.submit(proto) + result = fut.result(timeout=5.0) + + assert result == "msg-solo" + stub.Enqueue.assert_called_once() + stub.BatchEnqueue.assert_not_called() + + batcher.close() + + def test_concurrent_messages_batched(self) -> None: + """When multiple messages arrive concurrently, they batch together.""" + stub = MagicMock() + + # The first message will block Enqueue while more messages queue up. + # We need to make the batcher see all messages at once. + batch_called = threading.Event() + batch_response = FakeBatchResponse([ + FakeBatchResult(message_id=f"id-{i}") for i in range(5) + ]) + + def mock_batch_enqueue(request: Any) -> FakeBatchResponse: + batch_called.set() + return batch_response + + # Make single Enqueue block briefly so messages accumulate. + single_barrier = threading.Event() + + def mock_single_enqueue(request: Any) -> FakeEnqueueResponse: + single_barrier.wait(timeout=5.0) + return FakeEnqueueResponse("should-not-be-used") + + stub.Enqueue.side_effect = mock_single_enqueue + stub.BatchEnqueue.side_effect = mock_batch_enqueue + + batcher = AutoBatcher(stub, max_batch_size=100) + + # Submit 5 messages rapidly before the first can process. + # The batcher should drain them all in one batch. 
+ protos = [ + service_pb2.EnqueueRequest(queue="q", payload=f"msg-{i}".encode()) + for i in range(5) + ] + + # We need to submit them in a way that they all arrive before + # the batcher loop drains. Use a barrier approach. + futures = [] + for p in protos: + futures.append(batcher.submit(p)) + + # Give the batcher thread time to drain and flush. + # Either BatchEnqueue or multiple Enqueue calls will resolve things. + for _i, f in enumerate(futures): + result = f.result(timeout=5.0) + assert result is not None + + batcher.close() + + def test_close_drains_pending(self) -> None: + """close() waits for pending messages to be flushed.""" + stub = MagicMock() + stub.Enqueue.return_value = FakeEnqueueResponse("drained") + + batcher = AutoBatcher(stub, max_batch_size=100) + + proto = service_pb2.EnqueueRequest(queue="q", payload=b"drain-me") + fut = batcher.submit(proto) + + batcher.close() + + # After close, the future should be resolved. + assert fut.result(timeout=1.0) == "drained" + + def test_update_stub(self) -> None: + """update_stub replaces the gRPC stub used for flushing.""" + old_stub = MagicMock() + new_stub = MagicMock() + new_stub.Enqueue.return_value = FakeEnqueueResponse("new-stub") + + batcher = AutoBatcher(old_stub, max_batch_size=100) + + # Update stub before submitting. 
+ batcher.update_stub(new_stub) + + proto = service_pb2.EnqueueRequest(queue="q", payload=b"data") + fut = batcher.submit(proto) + result = fut.result(timeout=5.0) + + assert result == "new-stub" + batcher.close() + + +class TestLingerBatcher: + """Test the LingerBatcher.""" + + def test_flushes_at_batch_size(self) -> None: + """Flush triggers when batch_size messages accumulate.""" + stub = MagicMock() + stub.BatchEnqueue.return_value = FakeBatchResponse([ + FakeBatchResult(message_id=f"id-{i}") for i in range(3) + ]) + + batcher = LingerBatcher(stub, linger_ms=5000, batch_size=3) + + futures = [] + for i in range(3): + proto = service_pb2.EnqueueRequest(queue="q", payload=f"m{i}".encode()) + futures.append(batcher.submit(proto)) + + # Should flush quickly because batch_size=3 was reached. + for i, f in enumerate(futures): + result = f.result(timeout=5.0) + assert result == f"id-{i}" + + batcher.close() + + def test_flushes_at_linger_timeout(self) -> None: + """Flush triggers after linger_ms even if batch_size is not reached.""" + stub = MagicMock() + stub.Enqueue.return_value = FakeEnqueueResponse("lingered") + + batcher = LingerBatcher(stub, linger_ms=50, batch_size=100) + + proto = service_pb2.EnqueueRequest(queue="q", payload=b"linger") + fut = batcher.submit(proto) + + # Should flush after ~50ms even though batch_size=100 not reached. 
+ result = fut.result(timeout=5.0) + assert result == "lingered" + + batcher.close() + + def test_close_drains_pending(self) -> None: + """close() drains any pending messages.""" + stub = MagicMock() + stub.Enqueue.return_value = FakeEnqueueResponse("drained") + + batcher = LingerBatcher(stub, linger_ms=10000, batch_size=100) + + proto = service_pb2.EnqueueRequest(queue="q", payload=b"drain") + fut = batcher.submit(proto) + + batcher.close() + + assert fut.result(timeout=1.0) == "drained" From 39a2e2e8798c8ac22f0dc0fa02d58c226c355d27 Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Wed, 25 Mar 2026 00:11:32 -0300 Subject: [PATCH 08/17] feat: unified api surface for story 30.2 Update Python SDK for the unified proto API: - Enqueue RPC now uses repeated EnqueueMessage/EnqueueResult - BatchEnqueue RPC removed; enqueue_many() replaces batch_enqueue() - Ack/Nack use repeated AckMessage/NackMessage with per-message results - ConsumeResponse only has repeated messages field - BatchMode renamed to AccumulatorMode, BatchEnqueueResult to EnqueueResult - BatchEnqueueError renamed to EnqueueError - Linger.batch_size renamed to Linger.max_messages - Per-message error codes mapped to typed SDK exceptions (e.g. 
ENQUEUE_ERROR_CODE_QUEUE_NOT_FOUND -> QueueNotFoundError) - All 31 tests pass (16 unit, 15 integration) --- fila/__init__.py | 10 +- fila/async_client.py | 116 ++++---- fila/batcher.py | 162 +++++------ fila/client.py | 194 +++++++------ fila/errors.py | 32 +- fila/types.py | 18 +- fila/v1/admin_pb2.py | 30 +- fila/v1/admin_pb2.pyi | 90 ------ fila/v1/admin_pb2_grpc.py | 219 +------------- fila/v1/service_pb2.py | 84 ++++-- fila/v1/service_pb2.pyi | 162 +++++++++-- fila/v1/service_pb2_grpc.py | 30 +- proto/fila/v1/admin.proto | 78 ----- proto/fila/v1/service.proto | 116 ++++++-- tests/test_batcher.py | 273 +++++++++--------- ...gration.py => test_enqueue_integration.py} | 140 +++++---- 16 files changed, 797 insertions(+), 957 deletions(-) rename tests/{test_batch_integration.py => test_enqueue_integration.py} (51%) diff --git a/fila/__init__.py b/fila/__init__.py index 732fc43..9117c96 100644 --- a/fila/__init__.py +++ b/fila/__init__.py @@ -3,21 +3,21 @@ from fila.async_client import AsyncClient from fila.client import Client from fila.errors import ( - BatchEnqueueError, + EnqueueError, FilaError, MessageNotFoundError, QueueNotFoundError, RPCError, ) -from fila.types import BatchEnqueueResult, BatchMode, ConsumeMessage, Linger +from fila.types import AccumulatorMode, ConsumeMessage, EnqueueResult, Linger __all__ = [ + "AccumulatorMode", "AsyncClient", - "BatchEnqueueError", - "BatchEnqueueResult", - "BatchMode", "Client", "ConsumeMessage", + "EnqueueError", + "EnqueueResult", "FilaError", "Linger", "MessageNotFoundError", diff --git a/fila/async_client.py b/fila/async_client.py index 8e06b1e..a1f2962 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -10,15 +10,16 @@ if TYPE_CHECKING: from collections.abc import AsyncIterator -from fila.client import _proto_msg_to_consume_message +from fila.client import _proto_enqueue_result_to_sdk, _proto_msg_to_consume_message from fila.errors import ( + EnqueueError, _map_ack_error, - _map_batch_enqueue_error, 
_map_consume_error, _map_enqueue_error, + _map_enqueue_result_error, _map_nack_error, ) -from fila.types import BatchEnqueueResult, ConsumeMessage +from fila.types import ConsumeMessage, EnqueueResult from fila.v1 import service_pb2, service_pb2_grpc @@ -125,7 +126,7 @@ def _extract_leader_hint(err: grpc.RpcError) -> str | None: class AsyncClient: """Asynchronous client for the Fila message broker. - Wraps the hot-path gRPC operations: enqueue, batch_enqueue, consume, ack, + Wraps the hot-path gRPC operations: enqueue, enqueue_many, consume, ack, nack. Usage:: @@ -255,26 +256,35 @@ async def enqueue( try: resp = await self._stub.Enqueue( service_pb2.EnqueueRequest( - queue=queue, - headers=headers or {}, - payload=payload, + messages=[ + service_pb2.EnqueueMessage( + queue=queue, + headers=headers or {}, + payload=payload, + ) + ] ) ) except grpc.RpcError as e: raise _map_enqueue_error(e) from e - return str(resp.message_id) - async def batch_enqueue( + result = resp.results[0] + which = result.WhichOneof("result") + if which == "message_id": + return str(result.message_id) + raise _map_enqueue_result_error(result.error.code, result.error.message) + + async def enqueue_many( self, messages: list[tuple[str, dict[str, str] | None, bytes]], - ) -> list[BatchEnqueueResult]: + ) -> list[EnqueueResult]: """Enqueue multiple messages in a single RPC. Args: messages: List of (queue, headers, payload) tuples. Returns: - List of ``BatchEnqueueResult`` objects, one per input message. + List of ``EnqueueResult`` objects, one per input message. Each result has either a ``message_id`` (success) or ``error`` (per-message failure). @@ -283,7 +293,7 @@ async def batch_enqueue( RPCError: For unexpected gRPC failures. 
""" proto_messages = [ - service_pb2.EnqueueRequest( + service_pb2.EnqueueMessage( queue=q, headers=h or {}, payload=p, @@ -292,26 +302,13 @@ async def batch_enqueue( ] try: - resp = await self._stub.BatchEnqueue( - service_pb2.BatchEnqueueRequest(messages=proto_messages) + resp = await self._stub.Enqueue( + service_pb2.EnqueueRequest(messages=proto_messages) ) except grpc.RpcError as e: - raise _map_batch_enqueue_error(e) from e - - results: list[BatchEnqueueResult] = [] - for r in resp.results: - if r.HasField("success"): - results.append( - BatchEnqueueResult( - message_id=str(r.success.message_id), - error=None, - ) - ) - else: - results.append( - BatchEnqueueResult(message_id=None, error=r.error) - ) - return results + raise _map_enqueue_error(e) from e + + return [_proto_enqueue_result_to_sdk(r) for r in resp.results] async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]: """Open a streaming consumer on the specified queue. @@ -363,25 +360,12 @@ async def _consume_iter( self, stream: Any, ) -> AsyncIterator[ConsumeMessage]: - """Internal async generator reading from the gRPC stream. - - Handles both singular ``message`` field (backward compatible) and - repeated ``messages`` field (batched delivery). - """ + """Internal async generator reading from the gRPC stream.""" try: async for resp in stream: - # Check batched messages first (repeated field). - if len(resp.messages) > 0: - for msg in resp.messages: - if msg is not None and msg.ByteSize(): - yield _proto_msg_to_consume_message(msg) - continue - - # Fall back to singular message field. - msg = resp.message - if msg is None or not msg.ByteSize(): - continue # keepalive - yield _proto_msg_to_consume_message(msg) + for msg in resp.messages: + if msg is not None and msg.ByteSize(): + yield _proto_msg_to_consume_message(msg) except grpc.RpcError: return @@ -399,12 +383,26 @@ async def ack(self, queue: str, msg_id: str) -> None: RPCError: For unexpected gRPC failures. 
""" try: - await self._stub.Ack( - service_pb2.AckRequest(queue=queue, message_id=msg_id) + resp = await self._stub.Ack( + service_pb2.AckRequest( + messages=[service_pb2.AckMessage(queue=queue, message_id=msg_id)] + ) ) except grpc.RpcError as e: raise _map_ack_error(e) from e + # Check per-message result for errors. + if resp.results: + result = resp.results[0] + which = result.WhichOneof("result") + if which == "error": + from fila.errors import MessageNotFoundError, RPCError as _RPCError + + ack_err = result.error + if ack_err.code == service_pb2.ACK_ERROR_CODE_MESSAGE_NOT_FOUND: + raise MessageNotFoundError(f"ack: {ack_err.message}") + raise _RPCError(grpc.StatusCode.INTERNAL, f"ack: {ack_err.message}") + async def nack(self, queue: str, msg_id: str, error: str) -> None: """Negatively acknowledge a message that failed processing. @@ -421,10 +419,26 @@ async def nack(self, queue: str, msg_id: str, error: str) -> None: RPCError: For unexpected gRPC failures. """ try: - await self._stub.Nack( + resp = await self._stub.Nack( service_pb2.NackRequest( - queue=queue, message_id=msg_id, error=error + messages=[ + service_pb2.NackMessage( + queue=queue, message_id=msg_id, error=error + ) + ] ) ) except grpc.RpcError as e: raise _map_nack_error(e) from e + + # Check per-message result for errors. 
+ if resp.results: + result = resp.results[0] + which = result.WhichOneof("result") + if which == "error": + from fila.errors import MessageNotFoundError, RPCError as _RPCError + + nack_err = result.error + if nack_err.code == service_pb2.NACK_ERROR_CODE_MESSAGE_NOT_FOUND: + raise MessageNotFoundError(f"nack: {nack_err.message}") + raise _RPCError(grpc.StatusCode.INTERNAL, f"nack: {nack_err.message}") diff --git a/fila/batcher.py b/fila/batcher.py index c57964a..1bf2994 100644 --- a/fila/batcher.py +++ b/fila/batcher.py @@ -1,4 +1,4 @@ -"""Background batcher for opportunistic and linger-based enqueue batching.""" +"""Background accumulator for opportunistic and linger-based enqueue accumulation.""" from __future__ import annotations @@ -9,137 +9,131 @@ import grpc -from fila.errors import BatchEnqueueError, _map_enqueue_error -from fila.types import BatchEnqueueResult +from fila.errors import EnqueueError, _map_enqueue_error, _map_enqueue_result_error from fila.v1 import service_pb2 if TYPE_CHECKING: from fila.v1 import service_pb2_grpc -# Sentinel that signals the batcher thread to stop. +# Sentinel that signals the accumulator thread to stop. _STOP = object() -# Maximum batch size when none is configured. -_DEFAULT_MAX_BATCH_SIZE = 1000 +# Maximum number of messages per flush when none is configured. 
+_DEFAULT_MAX_MESSAGES = 1000 -class _EnqueueRequest: - """Internal envelope pairing a proto request with its result future.""" +class _EnqueueItem: + """Internal envelope pairing a proto EnqueueMessage with its result future.""" __slots__ = ("proto", "future") def __init__( self, - proto: service_pb2.EnqueueRequest, + proto: service_pb2.EnqueueMessage, future: Future[str], ) -> None: self.proto = proto self.future = future -def _msg_to_consume_result( - proto_result: Any, -) -> BatchEnqueueResult: - """Convert a proto ``BatchEnqueueResult`` to the SDK type.""" - if proto_result.HasField("success"): - return BatchEnqueueResult( - message_id=proto_result.success.message_id, - error=None, - ) - return BatchEnqueueResult( - message_id=None, - error=proto_result.error, - ) - - def _flush_single( stub: service_pb2_grpc.FilaServiceStub, - req: _EnqueueRequest, + req: _EnqueueItem, ) -> None: - """Send a single message via the singular Enqueue RPC. + """Send a single message via the unified Enqueue RPC. This preserves the specific error types (QueueNotFoundError, etc.) that callers of ``enqueue()`` expect. """ try: - resp = stub.Enqueue(req.proto) - req.future.set_result(str(resp.message_id)) + resp = stub.Enqueue( + service_pb2.EnqueueRequest(messages=[req.proto]) + ) + result = resp.results[0] + which = result.WhichOneof("result") + if which == "message_id": + req.future.set_result(str(result.message_id)) + else: + req.future.set_exception( + _map_enqueue_result_error(result.error.code, result.error.message) + ) except grpc.RpcError as e: req.future.set_exception(_map_enqueue_error(e)) except Exception as e: req.future.set_exception(e) -def _flush_batch( +def _flush_many( stub: service_pb2_grpc.FilaServiceStub, - batch: list[_EnqueueRequest], + items: list[_EnqueueItem], ) -> None: - """Send a batch of messages via the BatchEnqueue RPC. + """Send multiple messages via the unified Enqueue RPC. 
- On RPC-level failure, every future in the batch receives a - ``BatchEnqueueError``. On success, each future gets either its - message ID or a per-message error string wrapped in a - ``BatchEnqueueError``. + On RPC-level failure, every future in the batch receives an + ``EnqueueError``. On success, each future gets either its + message ID or a per-message error string wrapped in an + ``EnqueueError``. """ try: - resp = stub.BatchEnqueue( - service_pb2.BatchEnqueueRequest( - messages=[r.proto for r in batch], + resp = stub.Enqueue( + service_pb2.EnqueueRequest( + messages=[item.proto for item in items], ) ) except grpc.RpcError as e: - err = BatchEnqueueError(f"batch enqueue rpc failed: {e.details()}") - for r in batch: - r.future.set_exception(err) + err = EnqueueError(f"enqueue rpc failed: {e.details()}") + for item in items: + item.future.set_exception(err) return except Exception as e: - for r in batch: - r.future.set_exception(e) + for item in items: + item.future.set_exception(e) return # Pair each result with its request future. for i, result in enumerate(resp.results): - if i >= len(batch): + if i >= len(items): break - req = batch[i] - if result.HasField("success"): - req.future.set_result(str(result.success.message_id)) + item = items[i] + which = result.WhichOneof("result") + if which == "message_id": + item.future.set_result(str(result.message_id)) else: - req.future.set_exception( - BatchEnqueueError(f"enqueue failed: {result.error}") + item.future.set_exception( + _map_enqueue_result_error(result.error.code, result.error.message) ) -class AutoBatcher: - """Opportunistic batcher: drains a queue and flushes in batches. +class AutoAccumulator: + """Opportunistic accumulator: drains a queue and flushes in batches. A background daemon thread blocks on the first message, then non-blocking drains any additional messages that arrived during processing and flushes - them as a single batch via a thread pool executor. 
+ them as a single Enqueue RPC via a thread pool executor. """ def __init__( self, stub: service_pb2_grpc.FilaServiceStub, - max_batch_size: int = _DEFAULT_MAX_BATCH_SIZE, + max_messages: int = _DEFAULT_MAX_MESSAGES, max_workers: int = 4, ) -> None: self._stub = stub - self._max_batch_size = max_batch_size - self._queue: queue.Queue[_EnqueueRequest | object] = queue.Queue() + self._max_messages = max_messages + self._queue: queue.Queue[_EnqueueItem | object] = queue.Queue() self._executor = ThreadPoolExecutor(max_workers=max_workers) self._thread = threading.Thread(target=self._run, daemon=True) self._thread.start() - def submit(self, proto: service_pb2.EnqueueRequest) -> Future[str]: - """Submit a message for batched enqueue. Returns a Future for the message ID.""" + def submit(self, proto: service_pb2.EnqueueMessage) -> Future[str]: + """Submit a message for accumulated enqueue. Returns a Future for the message ID.""" fut: Future[str] = Future() - self._queue.put(_EnqueueRequest(proto, fut)) + self._queue.put(_EnqueueItem(proto, fut)) return fut def close(self, timeout: float | None = 30.0) -> None: - """Drain pending messages and shut down the batcher. + """Drain pending messages and shut down the accumulator. Blocks until all pending messages have been flushed or *timeout* seconds have elapsed. @@ -160,11 +154,11 @@ def _run(self) -> None: if first is _STOP: return - assert isinstance(first, _EnqueueRequest) - batch: list[_EnqueueRequest] = [first] + assert isinstance(first, _EnqueueItem) + batch: list[_EnqueueItem] = [first] # Non-blocking drain of any additional queued messages. - while len(batch) < self._max_batch_size: + while len(batch) < self._max_messages: try: item = self._queue.get_nowait() except queue.Empty: @@ -173,25 +167,25 @@ def _run(self) -> None: # Flush what we have, then stop. 
self._flush(batch) return - assert isinstance(item, _EnqueueRequest) + assert isinstance(item, _EnqueueItem) batch.append(item) self._flush(batch) - def _flush(self, batch: list[_EnqueueRequest]) -> None: + def _flush(self, batch: list[_EnqueueItem]) -> None: """Dispatch a batch to the executor for concurrent RPC.""" if len(batch) == 1: - # Single-item optimization: use singular Enqueue RPC. + # Single-item optimization: still uses Enqueue but with one message. self._executor.submit(_flush_single, self._stub, batch[0]) else: - self._executor.submit(_flush_batch, self._stub, batch) + self._executor.submit(_flush_many, self._stub, batch) -class LingerBatcher: - """Timer-based batcher: holds messages for up to linger_ms or batch_size. +class LingerAccumulator: + """Timer-based accumulator: holds messages for up to linger_ms or max_messages. A background daemon thread accumulates messages and flushes when either - the batch reaches ``batch_size`` or ``linger_ms`` milliseconds have + the count reaches ``max_messages`` or ``linger_ms`` milliseconds have elapsed since the first message in the current batch arrived. """ @@ -199,25 +193,25 @@ def __init__( self, stub: service_pb2_grpc.FilaServiceStub, linger_ms: float, - batch_size: int, + max_messages: int, max_workers: int = 4, ) -> None: self._stub = stub self._linger_s = linger_ms / 1000.0 - self._batch_size = batch_size - self._queue: queue.Queue[_EnqueueRequest | object] = queue.Queue() + self._max_messages = max_messages + self._queue: queue.Queue[_EnqueueItem | object] = queue.Queue() self._executor = ThreadPoolExecutor(max_workers=max_workers) self._thread = threading.Thread(target=self._run, daemon=True) self._thread.start() - def submit(self, proto: service_pb2.EnqueueRequest) -> Future[str]: - """Submit a message for batched enqueue. Returns a Future for the message ID.""" + def submit(self, proto: service_pb2.EnqueueMessage) -> Future[str]: + """Submit a message for accumulated enqueue. 
Returns a Future for the message ID.""" fut: Future[str] = Future() - self._queue.put(_EnqueueRequest(proto, fut)) + self._queue.put(_EnqueueItem(proto, fut)) return fut def close(self, timeout: float | None = 30.0) -> None: - """Drain pending messages and shut down the batcher.""" + """Drain pending messages and shut down the accumulator.""" self._queue.put(_STOP) self._thread.join(timeout=timeout) self._executor.shutdown(wait=True) @@ -227,7 +221,7 @@ def update_stub(self, stub: service_pb2_grpc.FilaServiceStub) -> None: self._stub = stub def _run(self) -> None: - """Background loop: accumulate up to batch_size or linger timeout.""" + """Background loop: accumulate up to max_messages or linger timeout.""" import time while True: @@ -236,14 +230,14 @@ def _run(self) -> None: if first is _STOP: return - assert isinstance(first, _EnqueueRequest) - batch: list[_EnqueueRequest] = [first] + assert isinstance(first, _EnqueueItem) + batch: list[_EnqueueItem] = [first] # Track wall-clock deadline from when first message arrived. deadline = time.monotonic() + self._linger_s - # Accumulate more items until batch_size or linger timeout. - while len(batch) < self._batch_size: + # Accumulate more items until max_messages or linger timeout. 
+ while len(batch) < self._max_messages: remaining = deadline - time.monotonic() if remaining <= 0: break @@ -254,14 +248,14 @@ def _run(self) -> None: if item is _STOP: self._flush(batch) return - assert isinstance(item, _EnqueueRequest) + assert isinstance(item, _EnqueueItem) batch.append(item) self._flush(batch) - def _flush(self, batch: list[_EnqueueRequest]) -> None: + def _flush(self, batch: list[_EnqueueItem]) -> None: """Dispatch a batch to the executor for concurrent RPC.""" if len(batch) == 1: self._executor.submit(_flush_single, self._stub, batch[0]) else: - self._executor.submit(_flush_batch, self._stub, batch) + self._executor.submit(_flush_many, self._stub, batch) diff --git a/fila/client.py b/fila/client.py index 0d7e49a..2f44bb6 100644 --- a/fila/client.py +++ b/fila/client.py @@ -6,15 +6,16 @@ import grpc -from fila.batcher import AutoBatcher, LingerBatcher +from fila.batcher import AutoAccumulator, LingerAccumulator from fila.errors import ( + EnqueueError, _map_ack_error, - _map_batch_enqueue_error, _map_consume_error, _map_enqueue_error, + _map_enqueue_result_error, _map_nack_error, ) -from fila.types import BatchEnqueueResult, BatchMode, ConsumeMessage, Linger +from fila.types import AccumulatorMode, ConsumeMessage, EnqueueResult, Linger from fila.v1 import service_pb2, service_pb2_grpc if TYPE_CHECKING: @@ -53,6 +54,14 @@ def _proto_msg_to_consume_message(msg: Any) -> ConsumeMessage: ) +def _proto_enqueue_result_to_sdk(result: Any) -> EnqueueResult: + """Convert a proto EnqueueResult to the SDK type.""" + which = result.WhichOneof("result") + if which == "message_id": + return EnqueueResult(message_id=str(result.message_id), error=None) + return EnqueueResult(message_id=None, error=result.error.message) + + class _ClientCallDetails( grpc.ClientCallDetails, # type: ignore[misc] ): @@ -122,7 +131,7 @@ def intercept_unary_stream( class Client: """Synchronous client for the Fila message broker. 
- Wraps the hot-path gRPC operations: enqueue, batch_enqueue, consume, ack, + Wraps the hot-path gRPC operations: enqueue, enqueue_many, consume, ack, nack. Usage:: @@ -138,16 +147,16 @@ class Client: with Client("localhost:5555") as client: client.enqueue("my-queue", None, b"hello") - Batch modes:: + Accumulator modes:: - # AUTO (default): opportunistic batching via background thread + # AUTO (default): opportunistic accumulation via background thread client = Client("localhost:5555") # DISABLED: each enqueue() is a direct RPC - client = Client("localhost:5555", batch_mode=BatchMode.DISABLED) + client = Client("localhost:5555", accumulator_mode=AccumulatorMode.DISABLED) - # LINGER: timer-based forced batching - client = Client("localhost:5555", batch_mode=Linger(linger_ms=10, batch_size=100)) + # LINGER: timer-based forced accumulation + client = Client("localhost:5555", accumulator_mode=Linger(linger_ms=10, max_messages=100)) TLS (system trust store):: @@ -179,8 +188,8 @@ def __init__( client_cert: bytes | None = None, client_key: bytes | None = None, api_key: str | None = None, - batch_mode: BatchMode | Linger = BatchMode.AUTO, - max_batch_size: int = 1000, + accumulator_mode: AccumulatorMode | Linger = AccumulatorMode.AUTO, + max_accumulator_messages: int = 1000, ) -> None: """Connect to a Fila broker at the given address. @@ -195,10 +204,12 @@ def __init__( client_key: PEM-encoded client private key for mutual TLS (optional). api_key: API key for authentication. When set, every RPC includes an ``authorization: Bearer `` metadata header. - batch_mode: Controls how ``enqueue()`` routes messages. Defaults to - ``BatchMode.AUTO`` (opportunistic batching). - max_batch_size: Maximum number of messages per batch when using - ``BatchMode.AUTO``. Defaults to 1000. + accumulator_mode: Controls how ``enqueue()`` routes messages. + Defaults to ``AccumulatorMode.AUTO`` + (opportunistic accumulation). 
+ max_accumulator_messages: Maximum number of messages per flush when + using ``AccumulatorMode.AUTO``. + Defaults to 1000. """ self._tls = tls self._ca_cert = ca_cert @@ -215,20 +226,20 @@ def __init__( self._channel = self._make_channel(addr) self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] - # Set up the batcher based on the chosen mode. - self._batcher: AutoBatcher | LingerBatcher | None = None - if isinstance(batch_mode, Linger): - self._batcher = LingerBatcher( + # Set up the accumulator based on the chosen mode. + self._accumulator: AutoAccumulator | LingerAccumulator | None = None + if isinstance(accumulator_mode, Linger): + self._accumulator = LingerAccumulator( self._stub, - linger_ms=batch_mode.linger_ms, - batch_size=batch_mode.batch_size, + linger_ms=accumulator_mode.linger_ms, + max_messages=accumulator_mode.max_messages, ) - elif batch_mode is BatchMode.AUTO: - self._batcher = AutoBatcher( + elif accumulator_mode is AccumulatorMode.AUTO: + self._accumulator = AutoAccumulator( self._stub, - max_batch_size=max_batch_size, + max_messages=max_accumulator_messages, ) - # BatchMode.DISABLED: self._batcher stays None + # AccumulatorMode.DISABLED: self._accumulator stays None def _make_channel(self, addr: str) -> grpc.Channel: """Create a gRPC channel to the given address using stored credentials.""" @@ -251,9 +262,9 @@ def _make_channel(self, addr: str) -> grpc.Channel: return channel def close(self) -> None: - """Drain pending batched messages and close the underlying gRPC channel.""" - if self._batcher is not None: - self._batcher.close() + """Drain pending accumulated messages and close the underlying gRPC channel.""" + if self._accumulator is not None: + self._accumulator.close() self._channel.close() def __enter__(self) -> Client: @@ -270,12 +281,12 @@ def enqueue( ) -> str: """Enqueue a message to the specified queue. 
- When a batcher is active (``BatchMode.AUTO`` or ``Linger``), the - message is submitted to the background batcher and this call blocks - until the batch is flushed and the result is available. + When an accumulator is active (``AccumulatorMode.AUTO`` or ``Linger``), + the message is submitted to the background accumulator and this call + blocks until the flush completes and the result is available. - When batching is disabled (``BatchMode.DISABLED``), this call makes - a direct synchronous RPC. + When accumulation is disabled (``AccumulatorMode.DISABLED``), this call + makes a direct synchronous RPC. Args: queue: Target queue name. @@ -287,40 +298,47 @@ def enqueue( Raises: QueueNotFoundError: If the queue does not exist (DISABLED mode). - BatchEnqueueError: If the batch RPC fails (AUTO/LINGER mode). + EnqueueError: If the enqueue RPC fails (AUTO/LINGER mode). RPCError: For unexpected gRPC failures. """ - proto = service_pb2.EnqueueRequest( + proto = service_pb2.EnqueueMessage( queue=queue, headers=headers or {}, payload=payload, ) - if self._batcher is not None: - future = self._batcher.submit(proto) + if self._accumulator is not None: + future = self._accumulator.submit(proto) return future.result() # Direct RPC (DISABLED mode). try: - resp = self._stub.Enqueue(proto) + resp = self._stub.Enqueue( + service_pb2.EnqueueRequest(messages=[proto]) + ) except grpc.RpcError as e: raise _map_enqueue_error(e) from e - return str(resp.message_id) - def batch_enqueue( + result = resp.results[0] + which = result.WhichOneof("result") + if which == "message_id": + return str(result.message_id) + raise _map_enqueue_result_error(result.error.code, result.error.message) + + def enqueue_many( self, messages: list[tuple[str, dict[str, str] | None, bytes]], - ) -> list[BatchEnqueueResult]: + ) -> list[EnqueueResult]: """Enqueue multiple messages in a single RPC. - This is an explicit batch operation that always uses the BatchEnqueue - RPC regardless of the batch_mode setting. 
+ This is an explicit multi-message operation that always uses the + Enqueue RPC directly, regardless of the accumulator_mode setting. Args: messages: List of (queue, headers, payload) tuples. Returns: - List of ``BatchEnqueueResult`` objects, one per input message. + List of ``EnqueueResult`` objects, one per input message. Each result has either a ``message_id`` (success) or ``error`` (per-message failure). @@ -329,7 +347,7 @@ def batch_enqueue( RPCError: For unexpected gRPC failures. """ proto_messages = [ - service_pb2.EnqueueRequest( + service_pb2.EnqueueMessage( queue=q, headers=h or {}, payload=p, @@ -338,26 +356,13 @@ def batch_enqueue( ] try: - resp = self._stub.BatchEnqueue( - service_pb2.BatchEnqueueRequest(messages=proto_messages) + resp = self._stub.Enqueue( + service_pb2.EnqueueRequest(messages=proto_messages) ) except grpc.RpcError as e: - raise _map_batch_enqueue_error(e) from e - - results: list[BatchEnqueueResult] = [] - for r in resp.results: - if r.HasField("success"): - results.append( - BatchEnqueueResult( - message_id=str(r.success.message_id), - error=None, - ) - ) - else: - results.append( - BatchEnqueueResult(message_id=None, error=r.error) - ) - return results + raise _map_enqueue_error(e) from e + + return [_proto_enqueue_result_to_sdk(r) for r in resp.results] def consume(self, queue: str) -> Iterator[ConsumeMessage]: """Open a streaming consumer on the specified queue. 
@@ -398,8 +403,8 @@ def _reconnect_and_consume(self, leader_addr: str, queue: str) -> Any: self._channel.close() self._channel = self._make_channel(leader_addr) self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] - if self._batcher is not None: - self._batcher.update_stub(self._stub) + if self._accumulator is not None: + self._accumulator.update_stub(self._stub) try: return self._stub.Consume( service_pb2.ConsumeRequest(queue=queue) @@ -411,25 +416,12 @@ def _consume_iter( self, stream: Any, ) -> Iterator[ConsumeMessage]: - """Internal generator reading from the gRPC stream. - - Handles both singular ``message`` field (backward compatible) and - repeated ``messages`` field (batched delivery). - """ + """Internal generator reading from the gRPC stream.""" try: for resp in stream: - # Check batched messages first (repeated field). - if len(resp.messages) > 0: - for msg in resp.messages: - if msg is not None and msg.ByteSize(): - yield _proto_msg_to_consume_message(msg) - continue - - # Fall back to singular message field. - msg = resp.message - if msg is None or not msg.ByteSize(): - continue # keepalive - yield _proto_msg_to_consume_message(msg) + for msg in resp.messages: + if msg is not None and msg.ByteSize(): + yield _proto_msg_to_consume_message(msg) except grpc.RpcError: return @@ -447,12 +439,26 @@ def ack(self, queue: str, msg_id: str) -> None: RPCError: For unexpected gRPC failures. """ try: - self._stub.Ack( - service_pb2.AckRequest(queue=queue, message_id=msg_id) + resp = self._stub.Ack( + service_pb2.AckRequest( + messages=[service_pb2.AckMessage(queue=queue, message_id=msg_id)] + ) ) except grpc.RpcError as e: raise _map_ack_error(e) from e + # Check per-message result for errors. 
+ if resp.results: + result = resp.results[0] + which = result.WhichOneof("result") + if which == "error": + from fila.errors import MessageNotFoundError, RPCError as _RPCError + + ack_err = result.error + if ack_err.code == service_pb2.ACK_ERROR_CODE_MESSAGE_NOT_FOUND: + raise MessageNotFoundError(f"ack: {ack_err.message}") + raise _RPCError(grpc.StatusCode.INTERNAL, f"ack: {ack_err.message}") + def nack(self, queue: str, msg_id: str, error: str) -> None: """Negatively acknowledge a message that failed processing. @@ -469,10 +475,26 @@ def nack(self, queue: str, msg_id: str, error: str) -> None: RPCError: For unexpected gRPC failures. """ try: - self._stub.Nack( + resp = self._stub.Nack( service_pb2.NackRequest( - queue=queue, message_id=msg_id, error=error + messages=[ + service_pb2.NackMessage( + queue=queue, message_id=msg_id, error=error + ) + ] ) ) except grpc.RpcError as e: raise _map_nack_error(e) from e + + # Check per-message result for errors. + if resp.results: + result = resp.results[0] + which = result.WhichOneof("result") + if which == "error": + from fila.errors import MessageNotFoundError, RPCError as _RPCError + + nack_err = result.error + if nack_err.code == service_pb2.NACK_ERROR_CODE_MESSAGE_NOT_FOUND: + raise MessageNotFoundError(f"nack: {nack_err.message}") + raise _RPCError(grpc.StatusCode.INTERNAL, f"nack: {nack_err.message}") diff --git a/fila/errors.py b/fila/errors.py index 40e76ee..819a197 100644 --- a/fila/errors.py +++ b/fila/errors.py @@ -26,15 +26,31 @@ def __init__(self, code: grpc.StatusCode, message: str) -> None: super().__init__(f"rpc error (code = {code.name}): {message}") -class BatchEnqueueError(FilaError): - """Raised when a batched enqueue fails at the RPC level. +class EnqueueError(FilaError): + """Raised when an enqueue fails at the RPC level. - Individual per-message failures are reported via ``BatchEnqueueResult.error`` - and do not raise this exception. 
This is raised only when the entire batch + Individual per-message failures are reported via ``EnqueueResult.error`` + and do not raise this exception. This is raised only when the entire RPC fails (e.g., network error, server unavailable). """ +def _map_enqueue_result_error(code: int, message: str) -> FilaError: + """Map a per-message EnqueueErrorCode to a Fila exception. + + Used when the unified Enqueue RPC succeeds at the transport level but + returns a per-message error result (e.g., queue not found for one of + the messages in the batch). + """ + from fila.v1 import service_pb2 + + if code == service_pb2.ENQUEUE_ERROR_CODE_QUEUE_NOT_FOUND: + return QueueNotFoundError(f"enqueue: {message}") + if code == service_pb2.ENQUEUE_ERROR_CODE_PERMISSION_DENIED: + return RPCError(grpc.StatusCode.PERMISSION_DENIED, f"enqueue: {message}") + return EnqueueError(f"enqueue failed: {message}") + + def _map_enqueue_error(err: grpc.RpcError) -> FilaError: """Map a gRPC error from an enqueue call to a Fila exception.""" code = err.code() @@ -65,11 +81,3 @@ def _map_nack_error(err: grpc.RpcError) -> FilaError: if code == grpc.StatusCode.NOT_FOUND: return MessageNotFoundError(f"nack: {err.details()}") return RPCError(code, err.details() or "") - - -def _map_batch_enqueue_error(err: grpc.RpcError) -> FilaError: - """Map a gRPC error from a batch enqueue call to a Fila exception.""" - code = err.code() - if code == grpc.StatusCode.NOT_FOUND: - return QueueNotFoundError(f"batch_enqueue: {err.details()}") - return RPCError(code, err.details() or "") diff --git a/fila/types.py b/fila/types.py index 54ab034..a73c15a 100644 --- a/fila/types.py +++ b/fila/types.py @@ -19,8 +19,8 @@ class ConsumeMessage: @dataclass(frozen=True) -class BatchEnqueueResult: - """Result for a single message within a batch enqueue operation. +class EnqueueResult: + """Result for a single message within an enqueue operation. Exactly one of ``message_id`` or ``error`` is set. 
""" @@ -34,13 +34,13 @@ def is_success(self) -> bool: return self.message_id is not None -class BatchMode(Enum): +class AccumulatorMode(Enum): """Controls how ``enqueue()`` routes messages to the broker. - - ``AUTO``: Opportunistic batching via a background thread. At low load + - ``AUTO``: Opportunistic accumulation via a background thread. At low load messages are sent individually; at high load they cluster into batches. This is the default. - - ``DISABLED``: No batching. Each ``enqueue()`` call is a direct RPC. + - ``DISABLED``: No accumulation. Each ``enqueue()`` call is a direct RPC. """ AUTO = auto() @@ -49,15 +49,15 @@ class BatchMode(Enum): @dataclass(frozen=True) class Linger: - """Timer-based forced batching mode. + """Timer-based forced accumulation mode. Messages are held for up to ``linger_ms`` milliseconds or until - ``batch_size`` messages accumulate, whichever comes first. + ``max_messages`` messages accumulate, whichever comes first. Args: linger_ms: Maximum time to hold a message before flushing (milliseconds). - batch_size: Maximum number of messages per batch. + max_messages: Maximum number of messages per flush. 
""" linger_ms: float - batch_size: int + max_messages: int diff --git a/fila/v1/admin_pb2.py b/fila/v1/admin_pb2.py index 4bb4e27..7dc1f25 100644 --- a/fila/v1/admin_pb2.py +++ b/fila/v1/admin_pb2.py @@ -24,7 +24,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13\x66ila/v1/admin.proto\x12\x07\x66ila.v1\"H\n\x12\x43reateQueueRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12$\n\x06\x63onfig\x18\x02 \x01(\x0b\x32\x14.fila.v1.QueueConfig\"b\n\x0bQueueConfig\x12\x19\n\x11on_enqueue_script\x18\x01 \x01(\t\x12\x19\n\x11on_failure_script\x18\x02 \x01(\t\x12\x1d\n\x15visibility_timeout_ms\x18\x03 \x01(\x04\"\'\n\x13\x43reateQueueResponse\x12\x10\n\x08queue_id\x18\x01 \x01(\t\"#\n\x12\x44\x65leteQueueRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"\x15\n\x13\x44\x65leteQueueResponse\".\n\x10SetConfigRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"\x13\n\x11SetConfigResponse\"\x1f\n\x10GetConfigRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\"\"\n\x11GetConfigResponse\x12\r\n\x05value\x18\x01 \x01(\t\")\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"#\n\x11ListConfigRequest\x12\x0e\n\x06prefix\x18\x01 \x01(\t\"P\n\x12ListConfigResponse\x12%\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x14.fila.v1.ConfigEntry\x12\x13\n\x0btotal_count\x18\x02 \x01(\r\" \n\x0fGetStatsRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"b\n\x13PerFairnessKeyStats\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x15\n\rpending_count\x18\x02 \x01(\x04\x12\x17\n\x0f\x63urrent_deficit\x18\x03 \x01(\x03\x12\x0e\n\x06weight\x18\x04 \x01(\r\"Z\n\x13PerThrottleKeyStats\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x0e\n\x06tokens\x18\x02 \x01(\x01\x12\x17\n\x0frate_per_second\x18\x03 \x01(\x01\x12\r\n\x05\x62urst\x18\x04 \x01(\x01\"\x9f\x02\n\x10GetStatsResponse\x12\r\n\x05\x64\x65pth\x18\x01 \x01(\x04\x12\x11\n\tin_flight\x18\x02 \x01(\x04\x12\x1c\n\x14\x61\x63tive_fairness_keys\x18\x03 \x01(\x04\x12\x18\n\x10\x61\x63tive_consumers\x18\x04 
\x01(\r\x12\x0f\n\x07quantum\x18\x05 \x01(\r\x12\x33\n\rper_key_stats\x18\x06 \x03(\x0b\x32\x1c.fila.v1.PerFairnessKeyStats\x12\x38\n\x12per_throttle_stats\x18\x07 \x03(\x0b\x32\x1c.fila.v1.PerThrottleKeyStats\x12\x16\n\x0eleader_node_id\x18\x08 \x01(\x04\x12\x19\n\x11replication_count\x18\t \x01(\r\"2\n\x0eRedriveRequest\x12\x11\n\tdlq_queue\x18\x01 \x01(\t\x12\r\n\x05\x63ount\x18\x02 \x01(\x04\"#\n\x0fRedriveResponse\x12\x10\n\x08redriven\x18\x01 \x01(\x04\"\x13\n\x11ListQueuesRequest\"m\n\tQueueInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05\x64\x65pth\x18\x02 \x01(\x04\x12\x11\n\tin_flight\x18\x03 \x01(\x04\x12\x18\n\x10\x61\x63tive_consumers\x18\x04 \x01(\r\x12\x16\n\x0eleader_node_id\x18\x05 \x01(\x04\"T\n\x12ListQueuesResponse\x12\"\n\x06queues\x18\x01 \x03(\x0b\x32\x12.fila.v1.QueueInfo\x12\x1a\n\x12\x63luster_node_count\x18\x02 \x01(\r\"Q\n\x13\x43reateApiKeyRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\rexpires_at_ms\x18\x02 \x01(\x04\x12\x15\n\ris_superadmin\x18\x03 \x01(\x08\"J\n\x14\x43reateApiKeyResponse\x12\x0e\n\x06key_id\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\x12\x15\n\ris_superadmin\x18\x03 \x01(\x08\"%\n\x13RevokeApiKeyRequest\x12\x0e\n\x06key_id\x18\x01 \x01(\t\"\x16\n\x14RevokeApiKeyResponse\"\x14\n\x12ListApiKeysRequest\"o\n\nApiKeyInfo\x12\x0e\n\x06key_id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x15\n\rcreated_at_ms\x18\x03 \x01(\x04\x12\x15\n\rexpires_at_ms\x18\x04 \x01(\x04\x12\x15\n\ris_superadmin\x18\x05 \x01(\x08\"8\n\x13ListApiKeysResponse\x12!\n\x04keys\x18\x01 \x03(\x0b\x32\x13.fila.v1.ApiKeyInfo\".\n\rAclPermission\x12\x0c\n\x04kind\x18\x01 \x01(\t\x12\x0f\n\x07pattern\x18\x02 \x01(\t\"L\n\rSetAclRequest\x12\x0e\n\x06key_id\x18\x01 \x01(\t\x12+\n\x0bpermissions\x18\x02 \x03(\x0b\x32\x16.fila.v1.AclPermission\"\x10\n\x0eSetAclResponse\"\x1f\n\rGetAclRequest\x12\x0e\n\x06key_id\x18\x01 \x01(\t\"d\n\x0eGetAclResponse\x12\x0e\n\x06key_id\x18\x01 \x01(\t\x12+\n\x0bpermissions\x18\x02 
\x03(\x0b\x32\x16.fila.v1.AclPermission\x12\x15\n\ris_superadmin\x18\x03 \x01(\x08\x32\x8e\x07\n\tFilaAdmin\x12H\n\x0b\x43reateQueue\x12\x1b.fila.v1.CreateQueueRequest\x1a\x1c.fila.v1.CreateQueueResponse\x12H\n\x0b\x44\x65leteQueue\x12\x1b.fila.v1.DeleteQueueRequest\x1a\x1c.fila.v1.DeleteQueueResponse\x12\x42\n\tSetConfig\x12\x19.fila.v1.SetConfigRequest\x1a\x1a.fila.v1.SetConfigResponse\x12\x42\n\tGetConfig\x12\x19.fila.v1.GetConfigRequest\x1a\x1a.fila.v1.GetConfigResponse\x12\x45\n\nListConfig\x12\x1a.fila.v1.ListConfigRequest\x1a\x1b.fila.v1.ListConfigResponse\x12?\n\x08GetStats\x12\x18.fila.v1.GetStatsRequest\x1a\x19.fila.v1.GetStatsResponse\x12<\n\x07Redrive\x12\x17.fila.v1.RedriveRequest\x1a\x18.fila.v1.RedriveResponse\x12\x45\n\nListQueues\x12\x1a.fila.v1.ListQueuesRequest\x1a\x1b.fila.v1.ListQueuesResponse\x12K\n\x0c\x43reateApiKey\x12\x1c.fila.v1.CreateApiKeyRequest\x1a\x1d.fila.v1.CreateApiKeyResponse\x12K\n\x0cRevokeApiKey\x12\x1c.fila.v1.RevokeApiKeyRequest\x1a\x1d.fila.v1.RevokeApiKeyResponse\x12H\n\x0bListApiKeys\x12\x1b.fila.v1.ListApiKeysRequest\x1a\x1c.fila.v1.ListApiKeysResponse\x12\x39\n\x06SetAcl\x12\x16.fila.v1.SetAclRequest\x1a\x17.fila.v1.SetAclResponse\x12\x39\n\x06GetAcl\x12\x16.fila.v1.GetAclRequest\x1a\x17.fila.v1.GetAclResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13\x66ila/v1/admin.proto\x12\x07\x66ila.v1\"H\n\x12\x43reateQueueRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12$\n\x06\x63onfig\x18\x02 \x01(\x0b\x32\x14.fila.v1.QueueConfig\"b\n\x0bQueueConfig\x12\x19\n\x11on_enqueue_script\x18\x01 \x01(\t\x12\x19\n\x11on_failure_script\x18\x02 \x01(\t\x12\x1d\n\x15visibility_timeout_ms\x18\x03 \x01(\x04\"\'\n\x13\x43reateQueueResponse\x12\x10\n\x08queue_id\x18\x01 \x01(\t\"#\n\x12\x44\x65leteQueueRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"\x15\n\x13\x44\x65leteQueueResponse\".\n\x10SetConfigRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 
\x01(\t\"\x13\n\x11SetConfigResponse\"\x1f\n\x10GetConfigRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\"\"\n\x11GetConfigResponse\x12\r\n\x05value\x18\x01 \x01(\t\")\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"#\n\x11ListConfigRequest\x12\x0e\n\x06prefix\x18\x01 \x01(\t\"P\n\x12ListConfigResponse\x12%\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x14.fila.v1.ConfigEntry\x12\x13\n\x0btotal_count\x18\x02 \x01(\r\" \n\x0fGetStatsRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"b\n\x13PerFairnessKeyStats\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x15\n\rpending_count\x18\x02 \x01(\x04\x12\x17\n\x0f\x63urrent_deficit\x18\x03 \x01(\x03\x12\x0e\n\x06weight\x18\x04 \x01(\r\"Z\n\x13PerThrottleKeyStats\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x0e\n\x06tokens\x18\x02 \x01(\x01\x12\x17\n\x0frate_per_second\x18\x03 \x01(\x01\x12\r\n\x05\x62urst\x18\x04 \x01(\x01\"\x9f\x02\n\x10GetStatsResponse\x12\r\n\x05\x64\x65pth\x18\x01 \x01(\x04\x12\x11\n\tin_flight\x18\x02 \x01(\x04\x12\x1c\n\x14\x61\x63tive_fairness_keys\x18\x03 \x01(\x04\x12\x18\n\x10\x61\x63tive_consumers\x18\x04 \x01(\r\x12\x0f\n\x07quantum\x18\x05 \x01(\r\x12\x33\n\rper_key_stats\x18\x06 \x03(\x0b\x32\x1c.fila.v1.PerFairnessKeyStats\x12\x38\n\x12per_throttle_stats\x18\x07 \x03(\x0b\x32\x1c.fila.v1.PerThrottleKeyStats\x12\x16\n\x0eleader_node_id\x18\x08 \x01(\x04\x12\x19\n\x11replication_count\x18\t \x01(\r\"2\n\x0eRedriveRequest\x12\x11\n\tdlq_queue\x18\x01 \x01(\t\x12\r\n\x05\x63ount\x18\x02 \x01(\x04\"#\n\x0fRedriveResponse\x12\x10\n\x08redriven\x18\x01 \x01(\x04\"\x13\n\x11ListQueuesRequest\"m\n\tQueueInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05\x64\x65pth\x18\x02 \x01(\x04\x12\x11\n\tin_flight\x18\x03 \x01(\x04\x12\x18\n\x10\x61\x63tive_consumers\x18\x04 \x01(\r\x12\x16\n\x0eleader_node_id\x18\x05 \x01(\x04\"T\n\x12ListQueuesResponse\x12\"\n\x06queues\x18\x01 \x03(\x0b\x32\x12.fila.v1.QueueInfo\x12\x1a\n\x12\x63luster_node_count\x18\x02 
\x01(\r2\xb4\x04\n\tFilaAdmin\x12H\n\x0b\x43reateQueue\x12\x1b.fila.v1.CreateQueueRequest\x1a\x1c.fila.v1.CreateQueueResponse\x12H\n\x0b\x44\x65leteQueue\x12\x1b.fila.v1.DeleteQueueRequest\x1a\x1c.fila.v1.DeleteQueueResponse\x12\x42\n\tSetConfig\x12\x19.fila.v1.SetConfigRequest\x1a\x1a.fila.v1.SetConfigResponse\x12\x42\n\tGetConfig\x12\x19.fila.v1.GetConfigRequest\x1a\x1a.fila.v1.GetConfigResponse\x12\x45\n\nListConfig\x12\x1a.fila.v1.ListConfigRequest\x1a\x1b.fila.v1.ListConfigResponse\x12?\n\x08GetStats\x12\x18.fila.v1.GetStatsRequest\x1a\x19.fila.v1.GetStatsResponse\x12<\n\x07Redrive\x12\x17.fila.v1.RedriveRequest\x1a\x18.fila.v1.RedriveResponse\x12\x45\n\nListQueues\x12\x1a.fila.v1.ListQueuesRequest\x1a\x1b.fila.v1.ListQueuesResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -73,30 +73,6 @@ _globals['_QUEUEINFO']._serialized_end=1342 _globals['_LISTQUEUESRESPONSE']._serialized_start=1344 _globals['_LISTQUEUESRESPONSE']._serialized_end=1428 - _globals['_CREATEAPIKEYREQUEST']._serialized_start=1430 - _globals['_CREATEAPIKEYREQUEST']._serialized_end=1511 - _globals['_CREATEAPIKEYRESPONSE']._serialized_start=1513 - _globals['_CREATEAPIKEYRESPONSE']._serialized_end=1587 - _globals['_REVOKEAPIKEYREQUEST']._serialized_start=1589 - _globals['_REVOKEAPIKEYREQUEST']._serialized_end=1626 - _globals['_REVOKEAPIKEYRESPONSE']._serialized_start=1628 - _globals['_REVOKEAPIKEYRESPONSE']._serialized_end=1650 - _globals['_LISTAPIKEYSREQUEST']._serialized_start=1652 - _globals['_LISTAPIKEYSREQUEST']._serialized_end=1672 - _globals['_APIKEYINFO']._serialized_start=1674 - _globals['_APIKEYINFO']._serialized_end=1785 - _globals['_LISTAPIKEYSRESPONSE']._serialized_start=1787 - _globals['_LISTAPIKEYSRESPONSE']._serialized_end=1843 - _globals['_ACLPERMISSION']._serialized_start=1845 - _globals['_ACLPERMISSION']._serialized_end=1891 - _globals['_SETACLREQUEST']._serialized_start=1893 - 
_globals['_SETACLREQUEST']._serialized_end=1969 - _globals['_SETACLRESPONSE']._serialized_start=1971 - _globals['_SETACLRESPONSE']._serialized_end=1987 - _globals['_GETACLREQUEST']._serialized_start=1989 - _globals['_GETACLREQUEST']._serialized_end=2020 - _globals['_GETACLRESPONSE']._serialized_start=2022 - _globals['_GETACLRESPONSE']._serialized_end=2122 - _globals['_FILAADMIN']._serialized_start=2125 - _globals['_FILAADMIN']._serialized_end=3035 + _globals['_FILAADMIN']._serialized_start=1431 + _globals['_FILAADMIN']._serialized_end=1995 # @@protoc_insertion_point(module_scope) diff --git a/fila/v1/admin_pb2.pyi b/fila/v1/admin_pb2.pyi index d603b29..0c594ce 100644 --- a/fila/v1/admin_pb2.pyi +++ b/fila/v1/admin_pb2.pyi @@ -177,93 +177,3 @@ class ListQueuesResponse(_message.Message): queues: _containers.RepeatedCompositeFieldContainer[QueueInfo] cluster_node_count: int def __init__(self, queues: _Optional[_Iterable[_Union[QueueInfo, _Mapping]]] = ..., cluster_node_count: _Optional[int] = ...) -> None: ... - -class CreateApiKeyRequest(_message.Message): - __slots__ = ("name", "expires_at_ms", "is_superadmin") - NAME_FIELD_NUMBER: _ClassVar[int] - EXPIRES_AT_MS_FIELD_NUMBER: _ClassVar[int] - IS_SUPERADMIN_FIELD_NUMBER: _ClassVar[int] - name: str - expires_at_ms: int - is_superadmin: bool - def __init__(self, name: _Optional[str] = ..., expires_at_ms: _Optional[int] = ..., is_superadmin: bool = ...) -> None: ... - -class CreateApiKeyResponse(_message.Message): - __slots__ = ("key_id", "key", "is_superadmin") - KEY_ID_FIELD_NUMBER: _ClassVar[int] - KEY_FIELD_NUMBER: _ClassVar[int] - IS_SUPERADMIN_FIELD_NUMBER: _ClassVar[int] - key_id: str - key: str - is_superadmin: bool - def __init__(self, key_id: _Optional[str] = ..., key: _Optional[str] = ..., is_superadmin: bool = ...) -> None: ... - -class RevokeApiKeyRequest(_message.Message): - __slots__ = ("key_id",) - KEY_ID_FIELD_NUMBER: _ClassVar[int] - key_id: str - def __init__(self, key_id: _Optional[str] = ...) 
-> None: ... - -class RevokeApiKeyResponse(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class ListApiKeysRequest(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class ApiKeyInfo(_message.Message): - __slots__ = ("key_id", "name", "created_at_ms", "expires_at_ms", "is_superadmin") - KEY_ID_FIELD_NUMBER: _ClassVar[int] - NAME_FIELD_NUMBER: _ClassVar[int] - CREATED_AT_MS_FIELD_NUMBER: _ClassVar[int] - EXPIRES_AT_MS_FIELD_NUMBER: _ClassVar[int] - IS_SUPERADMIN_FIELD_NUMBER: _ClassVar[int] - key_id: str - name: str - created_at_ms: int - expires_at_ms: int - is_superadmin: bool - def __init__(self, key_id: _Optional[str] = ..., name: _Optional[str] = ..., created_at_ms: _Optional[int] = ..., expires_at_ms: _Optional[int] = ..., is_superadmin: bool = ...) -> None: ... - -class ListApiKeysResponse(_message.Message): - __slots__ = ("keys",) - KEYS_FIELD_NUMBER: _ClassVar[int] - keys: _containers.RepeatedCompositeFieldContainer[ApiKeyInfo] - def __init__(self, keys: _Optional[_Iterable[_Union[ApiKeyInfo, _Mapping]]] = ...) -> None: ... - -class AclPermission(_message.Message): - __slots__ = ("kind", "pattern") - KIND_FIELD_NUMBER: _ClassVar[int] - PATTERN_FIELD_NUMBER: _ClassVar[int] - kind: str - pattern: str - def __init__(self, kind: _Optional[str] = ..., pattern: _Optional[str] = ...) -> None: ... - -class SetAclRequest(_message.Message): - __slots__ = ("key_id", "permissions") - KEY_ID_FIELD_NUMBER: _ClassVar[int] - PERMISSIONS_FIELD_NUMBER: _ClassVar[int] - key_id: str - permissions: _containers.RepeatedCompositeFieldContainer[AclPermission] - def __init__(self, key_id: _Optional[str] = ..., permissions: _Optional[_Iterable[_Union[AclPermission, _Mapping]]] = ...) -> None: ... - -class SetAclResponse(_message.Message): - __slots__ = () - def __init__(self) -> None: ... 
- -class GetAclRequest(_message.Message): - __slots__ = ("key_id",) - KEY_ID_FIELD_NUMBER: _ClassVar[int] - key_id: str - def __init__(self, key_id: _Optional[str] = ...) -> None: ... - -class GetAclResponse(_message.Message): - __slots__ = ("key_id", "permissions", "is_superadmin") - KEY_ID_FIELD_NUMBER: _ClassVar[int] - PERMISSIONS_FIELD_NUMBER: _ClassVar[int] - IS_SUPERADMIN_FIELD_NUMBER: _ClassVar[int] - key_id: str - permissions: _containers.RepeatedCompositeFieldContainer[AclPermission] - is_superadmin: bool - def __init__(self, key_id: _Optional[str] = ..., permissions: _Optional[_Iterable[_Union[AclPermission, _Mapping]]] = ..., is_superadmin: bool = ...) -> None: ... diff --git a/fila/v1/admin_pb2_grpc.py b/fila/v1/admin_pb2_grpc.py index 93d6c4e..70d8fbe 100644 --- a/fila/v1/admin_pb2_grpc.py +++ b/fila/v1/admin_pb2_grpc.py @@ -5,7 +5,7 @@ from fila.v1 import admin_pb2 as fila_dot_v1_dot_admin__pb2 -GRPC_GENERATED_VERSION = '1.78.0' +GRPC_GENERATED_VERSION = '1.78.1' GRPC_VERSION = grpc.__version__ _version_not_supported = False @@ -75,31 +75,6 @@ def __init__(self, channel): request_serializer=fila_dot_v1_dot_admin__pb2.ListQueuesRequest.SerializeToString, response_deserializer=fila_dot_v1_dot_admin__pb2.ListQueuesResponse.FromString, _registered_method=True) - self.CreateApiKey = channel.unary_unary( - '/fila.v1.FilaAdmin/CreateApiKey', - request_serializer=fila_dot_v1_dot_admin__pb2.CreateApiKeyRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.CreateApiKeyResponse.FromString, - _registered_method=True) - self.RevokeApiKey = channel.unary_unary( - '/fila.v1.FilaAdmin/RevokeApiKey', - request_serializer=fila_dot_v1_dot_admin__pb2.RevokeApiKeyRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.RevokeApiKeyResponse.FromString, - _registered_method=True) - self.ListApiKeys = channel.unary_unary( - '/fila.v1.FilaAdmin/ListApiKeys', - 
request_serializer=fila_dot_v1_dot_admin__pb2.ListApiKeysRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.ListApiKeysResponse.FromString, - _registered_method=True) - self.SetAcl = channel.unary_unary( - '/fila.v1.FilaAdmin/SetAcl', - request_serializer=fila_dot_v1_dot_admin__pb2.SetAclRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.SetAclResponse.FromString, - _registered_method=True) - self.GetAcl = channel.unary_unary( - '/fila.v1.FilaAdmin/GetAcl', - request_serializer=fila_dot_v1_dot_admin__pb2.GetAclRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.GetAclResponse.FromString, - _registered_method=True) class FilaAdminServicer(object): @@ -154,38 +129,6 @@ def ListQueues(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def CreateApiKey(self, request, context): - """API key management. CreateApiKey bypasses auth (bootstrap); others require a valid key. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def RevokeApiKey(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def ListApiKeys(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def SetAcl(self, request, context): - """Per-key ACL management. 
- """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetAcl(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - def add_FilaAdminServicer_to_server(servicer, server): rpc_method_handlers = { @@ -229,31 +172,6 @@ def add_FilaAdminServicer_to_server(servicer, server): request_deserializer=fila_dot_v1_dot_admin__pb2.ListQueuesRequest.FromString, response_serializer=fila_dot_v1_dot_admin__pb2.ListQueuesResponse.SerializeToString, ), - 'CreateApiKey': grpc.unary_unary_rpc_method_handler( - servicer.CreateApiKey, - request_deserializer=fila_dot_v1_dot_admin__pb2.CreateApiKeyRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.CreateApiKeyResponse.SerializeToString, - ), - 'RevokeApiKey': grpc.unary_unary_rpc_method_handler( - servicer.RevokeApiKey, - request_deserializer=fila_dot_v1_dot_admin__pb2.RevokeApiKeyRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.RevokeApiKeyResponse.SerializeToString, - ), - 'ListApiKeys': grpc.unary_unary_rpc_method_handler( - servicer.ListApiKeys, - request_deserializer=fila_dot_v1_dot_admin__pb2.ListApiKeysRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.ListApiKeysResponse.SerializeToString, - ), - 'SetAcl': grpc.unary_unary_rpc_method_handler( - servicer.SetAcl, - request_deserializer=fila_dot_v1_dot_admin__pb2.SetAclRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.SetAclResponse.SerializeToString, - ), - 'GetAcl': grpc.unary_unary_rpc_method_handler( - servicer.GetAcl, - request_deserializer=fila_dot_v1_dot_admin__pb2.GetAclRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.GetAclResponse.SerializeToString, - ), } generic_handler = 
grpc.method_handlers_generic_handler( 'fila.v1.FilaAdmin', rpc_method_handlers) @@ -481,138 +399,3 @@ def ListQueues(request, timeout, metadata, _registered_method=True) - - @staticmethod - def CreateApiKey(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/CreateApiKey', - fila_dot_v1_dot_admin__pb2.CreateApiKeyRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.CreateApiKeyResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def RevokeApiKey(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/RevokeApiKey', - fila_dot_v1_dot_admin__pb2.RevokeApiKeyRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.RevokeApiKeyResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def ListApiKeys(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/ListApiKeys', - fila_dot_v1_dot_admin__pb2.ListApiKeysRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.ListApiKeysResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def SetAcl(request, - 
target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/SetAcl', - fila_dot_v1_dot_admin__pb2.SetAclRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.SetAclResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def GetAcl(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/GetAcl', - fila_dot_v1_dot_admin__pb2.GetAclRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.GetAclResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) diff --git a/fila/v1/service_pb2.py b/fila/v1/service_pb2.py index 7f04078..7489260 100644 --- a/fila/v1/service_pb2.py +++ b/fila/v1/service_pb2.py @@ -25,39 +25,65 @@ from fila.v1 import messages_pb2 as fila_dot_v1_dot_messages__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66ila/v1/service.proto\x12\x07\x66ila.v1\x1a\x16\x66ila/v1/messages.proto\"\x97\x01\n\x0e\x45nqueueRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x35\n\x07headers\x18\x02 \x03(\x0b\x32$.fila.v1.EnqueueRequest.HeadersEntry\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"%\n\x0f\x45nqueueResponse\x12\x12\n\nmessage_id\x18\x01 \x01(\t\"\x1f\n\x0e\x43onsumeRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"X\n\x0f\x43onsumeResponse\x12!\n\x07message\x18\x01 
\x01(\x0b\x32\x10.fila.v1.Message\x12\"\n\x08messages\x18\x02 \x03(\x0b\x32\x10.fila.v1.Message\"/\n\nAckRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\"\r\n\x0b\x41\x63kResponse\"?\n\x0bNackRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\"\x0e\n\x0cNackResponse\"@\n\x13\x42\x61tchEnqueueRequest\x12)\n\x08messages\x18\x01 \x03(\x0b\x32\x17.fila.v1.EnqueueRequest\"D\n\x14\x42\x61tchEnqueueResponse\x12,\n\x07results\x18\x01 \x03(\x0b\x32\x1b.fila.v1.BatchEnqueueResult\"\\\n\x12\x42\x61tchEnqueueResult\x12+\n\x07success\x18\x01 \x01(\x0b\x32\x18.fila.v1.EnqueueResponseH\x00\x12\x0f\n\x05\x65rror\x18\x02 \x01(\tH\x00\x42\x08\n\x06result2\xbf\x02\n\x0b\x46ilaService\x12<\n\x07\x45nqueue\x12\x17.fila.v1.EnqueueRequest\x1a\x18.fila.v1.EnqueueResponse\x12K\n\x0c\x42\x61tchEnqueue\x12\x1c.fila.v1.BatchEnqueueRequest\x1a\x1d.fila.v1.BatchEnqueueResponse\x12>\n\x07\x43onsume\x12\x17.fila.v1.ConsumeRequest\x1a\x18.fila.v1.ConsumeResponse0\x01\x12\x30\n\x03\x41\x63k\x12\x13.fila.v1.AckRequest\x1a\x14.fila.v1.AckResponse\x12\x33\n\x04Nack\x12\x14.fila.v1.NackRequest\x1a\x15.fila.v1.NackResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66ila/v1/service.proto\x12\x07\x66ila.v1\x1a\x16\x66ila/v1/messages.proto\"\x97\x01\n\x0e\x45nqueueMessage\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x35\n\x07headers\x18\x02 \x03(\x0b\x32$.fila.v1.EnqueueMessage.HeadersEntry\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\";\n\x0e\x45nqueueRequest\x12)\n\x08messages\x18\x01 \x03(\x0b\x32\x17.fila.v1.EnqueueMessage\"W\n\rEnqueueResult\x12\x14\n\nmessage_id\x18\x01 \x01(\tH\x00\x12&\n\x05\x65rror\x18\x02 \x01(\x0b\x32\x15.fila.v1.EnqueueErrorH\x00\x42\x08\n\x06result\"H\n\x0c\x45nqueueError\x12\'\n\x04\x63ode\x18\x01 
\x01(\x0e\x32\x19.fila.v1.EnqueueErrorCode\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x0f\x45nqueueResponse\x12\'\n\x07results\x18\x01 \x03(\x0b\x32\x16.fila.v1.EnqueueResult\"\x1f\n\x0e\x43onsumeRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"5\n\x0f\x43onsumeResponse\x12\"\n\x08messages\x18\x01 \x03(\x0b\x32\x10.fila.v1.Message\"/\n\nAckMessage\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\"3\n\nAckRequest\x12%\n\x08messages\x18\x01 \x03(\x0b\x32\x13.fila.v1.AckMessage\"a\n\tAckResult\x12&\n\x07success\x18\x01 \x01(\x0b\x32\x13.fila.v1.AckSuccessH\x00\x12\"\n\x05\x65rror\x18\x02 \x01(\x0b\x32\x11.fila.v1.AckErrorH\x00\x42\x08\n\x06result\"\x0c\n\nAckSuccess\"@\n\x08\x41\x63kError\x12#\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x15.fila.v1.AckErrorCode\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0b\x41\x63kResponse\x12#\n\x07results\x18\x01 \x03(\x0b\x32\x12.fila.v1.AckResult\"?\n\x0bNackMessage\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\"5\n\x0bNackRequest\x12&\n\x08messages\x18\x01 \x03(\x0b\x32\x14.fila.v1.NackMessage\"d\n\nNackResult\x12\'\n\x07success\x18\x01 \x01(\x0b\x32\x14.fila.v1.NackSuccessH\x00\x12#\n\x05\x65rror\x18\x02 \x01(\x0b\x32\x12.fila.v1.NackErrorH\x00\x42\x08\n\x06result\"\r\n\x0bNackSuccess\"B\n\tNackError\x12$\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x16.fila.v1.NackErrorCode\x12\x0f\n\x07message\x18\x02 \x01(\t\"4\n\x0cNackResponse\x12$\n\x07results\x18\x01 \x03(\x0b\x32\x13.fila.v1.NackResult\"Z\n\x14StreamEnqueueRequest\x12)\n\x08messages\x18\x01 \x03(\x0b\x32\x17.fila.v1.EnqueueMessage\x12\x17\n\x0fsequence_number\x18\x02 \x01(\x04\"Y\n\x15StreamEnqueueResponse\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\'\n\x07results\x18\x02 
\x03(\x0b\x32\x16.fila.v1.EnqueueResult*\xc4\x01\n\x10\x45nqueueErrorCode\x12\"\n\x1e\x45NQUEUE_ERROR_CODE_UNSPECIFIED\x10\x00\x12&\n\"ENQUEUE_ERROR_CODE_QUEUE_NOT_FOUND\x10\x01\x12\x1e\n\x1a\x45NQUEUE_ERROR_CODE_STORAGE\x10\x02\x12\x1a\n\x16\x45NQUEUE_ERROR_CODE_LUA\x10\x03\x12(\n$ENQUEUE_ERROR_CODE_PERMISSION_DENIED\x10\x04*\x96\x01\n\x0c\x41\x63kErrorCode\x12\x1e\n\x1a\x41\x43K_ERROR_CODE_UNSPECIFIED\x10\x00\x12$\n ACK_ERROR_CODE_MESSAGE_NOT_FOUND\x10\x01\x12\x1a\n\x16\x41\x43K_ERROR_CODE_STORAGE\x10\x02\x12$\n ACK_ERROR_CODE_PERMISSION_DENIED\x10\x03*\x9b\x01\n\rNackErrorCode\x12\x1f\n\x1bNACK_ERROR_CODE_UNSPECIFIED\x10\x00\x12%\n!NACK_ERROR_CODE_MESSAGE_NOT_FOUND\x10\x01\x12\x1b\n\x17NACK_ERROR_CODE_STORAGE\x10\x02\x12%\n!NACK_ERROR_CODE_PERMISSION_DENIED\x10\x03\x32\xc6\x02\n\x0b\x46ilaService\x12<\n\x07\x45nqueue\x12\x17.fila.v1.EnqueueRequest\x1a\x18.fila.v1.EnqueueResponse\x12R\n\rStreamEnqueue\x12\x1d.fila.v1.StreamEnqueueRequest\x1a\x1e.fila.v1.StreamEnqueueResponse(\x01\x30\x01\x12>\n\x07\x43onsume\x12\x17.fila.v1.ConsumeRequest\x1a\x18.fila.v1.ConsumeResponse0\x01\x12\x30\n\x03\x41\x63k\x12\x13.fila.v1.AckRequest\x1a\x14.fila.v1.AckResponse\x12\x33\n\x04Nack\x12\x14.fila.v1.NackRequest\x1a\x15.fila.v1.NackResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fila.v1.service_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_ENQUEUEREQUEST_HEADERSENTRY']._loaded_options = None - _globals['_ENQUEUEREQUEST_HEADERSENTRY']._serialized_options = b'8\001' - _globals['_ENQUEUEREQUEST']._serialized_start=59 - _globals['_ENQUEUEREQUEST']._serialized_end=210 - _globals['_ENQUEUEREQUEST_HEADERSENTRY']._serialized_start=164 - _globals['_ENQUEUEREQUEST_HEADERSENTRY']._serialized_end=210 - _globals['_ENQUEUERESPONSE']._serialized_start=212 - _globals['_ENQUEUERESPONSE']._serialized_end=249 - 
_globals['_CONSUMEREQUEST']._serialized_start=251 - _globals['_CONSUMEREQUEST']._serialized_end=282 - _globals['_CONSUMERESPONSE']._serialized_start=284 - _globals['_CONSUMERESPONSE']._serialized_end=372 - _globals['_ACKREQUEST']._serialized_start=374 - _globals['_ACKREQUEST']._serialized_end=421 - _globals['_ACKRESPONSE']._serialized_start=423 - _globals['_ACKRESPONSE']._serialized_end=436 - _globals['_NACKREQUEST']._serialized_start=438 - _globals['_NACKREQUEST']._serialized_end=501 - _globals['_NACKRESPONSE']._serialized_start=503 - _globals['_NACKRESPONSE']._serialized_end=517 - _globals['_BATCHENQUEUEREQUEST']._serialized_start=519 - _globals['_BATCHENQUEUEREQUEST']._serialized_end=583 - _globals['_BATCHENQUEUERESPONSE']._serialized_start=585 - _globals['_BATCHENQUEUERESPONSE']._serialized_end=653 - _globals['_BATCHENQUEUERESULT']._serialized_start=655 - _globals['_BATCHENQUEUERESULT']._serialized_end=747 - _globals['_FILASERVICE']._serialized_start=750 - _globals['_FILASERVICE']._serialized_end=1069 + _globals['_ENQUEUEMESSAGE_HEADERSENTRY']._loaded_options = None + _globals['_ENQUEUEMESSAGE_HEADERSENTRY']._serialized_options = b'8\001' + _globals['_ENQUEUEERRORCODE']._serialized_start=1460 + _globals['_ENQUEUEERRORCODE']._serialized_end=1656 + _globals['_ACKERRORCODE']._serialized_start=1659 + _globals['_ACKERRORCODE']._serialized_end=1809 + _globals['_NACKERRORCODE']._serialized_start=1812 + _globals['_NACKERRORCODE']._serialized_end=1967 + _globals['_ENQUEUEMESSAGE']._serialized_start=59 + _globals['_ENQUEUEMESSAGE']._serialized_end=210 + _globals['_ENQUEUEMESSAGE_HEADERSENTRY']._serialized_start=164 + _globals['_ENQUEUEMESSAGE_HEADERSENTRY']._serialized_end=210 + _globals['_ENQUEUEREQUEST']._serialized_start=212 + _globals['_ENQUEUEREQUEST']._serialized_end=271 + _globals['_ENQUEUERESULT']._serialized_start=273 + _globals['_ENQUEUERESULT']._serialized_end=360 + _globals['_ENQUEUEERROR']._serialized_start=362 + _globals['_ENQUEUEERROR']._serialized_end=434 
+ _globals['_ENQUEUERESPONSE']._serialized_start=436 + _globals['_ENQUEUERESPONSE']._serialized_end=494 + _globals['_CONSUMEREQUEST']._serialized_start=496 + _globals['_CONSUMEREQUEST']._serialized_end=527 + _globals['_CONSUMERESPONSE']._serialized_start=529 + _globals['_CONSUMERESPONSE']._serialized_end=582 + _globals['_ACKMESSAGE']._serialized_start=584 + _globals['_ACKMESSAGE']._serialized_end=631 + _globals['_ACKREQUEST']._serialized_start=633 + _globals['_ACKREQUEST']._serialized_end=684 + _globals['_ACKRESULT']._serialized_start=686 + _globals['_ACKRESULT']._serialized_end=783 + _globals['_ACKSUCCESS']._serialized_start=785 + _globals['_ACKSUCCESS']._serialized_end=797 + _globals['_ACKERROR']._serialized_start=799 + _globals['_ACKERROR']._serialized_end=863 + _globals['_ACKRESPONSE']._serialized_start=865 + _globals['_ACKRESPONSE']._serialized_end=915 + _globals['_NACKMESSAGE']._serialized_start=917 + _globals['_NACKMESSAGE']._serialized_end=980 + _globals['_NACKREQUEST']._serialized_start=982 + _globals['_NACKREQUEST']._serialized_end=1035 + _globals['_NACKRESULT']._serialized_start=1037 + _globals['_NACKRESULT']._serialized_end=1137 + _globals['_NACKSUCCESS']._serialized_start=1139 + _globals['_NACKSUCCESS']._serialized_end=1152 + _globals['_NACKERROR']._serialized_start=1154 + _globals['_NACKERROR']._serialized_end=1220 + _globals['_NACKRESPONSE']._serialized_start=1222 + _globals['_NACKRESPONSE']._serialized_end=1274 + _globals['_STREAMENQUEUEREQUEST']._serialized_start=1276 + _globals['_STREAMENQUEUEREQUEST']._serialized_end=1366 + _globals['_STREAMENQUEUERESPONSE']._serialized_start=1368 + _globals['_STREAMENQUEUERESPONSE']._serialized_end=1457 + _globals['_FILASERVICE']._serialized_start=1970 + _globals['_FILASERVICE']._serialized_end=2296 # @@protoc_insertion_point(module_scope) diff --git a/fila/v1/service_pb2.pyi b/fila/v1/service_pb2.pyi index ca1e820..a840197 100644 --- a/fila/v1/service_pb2.pyi +++ b/fila/v1/service_pb2.pyi @@ -1,5 +1,6 @@ from 
fila.v1 import messages_pb2 as _messages_pb2 from google.protobuf.internal import containers as _containers +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper from google.protobuf import descriptor as _descriptor from google.protobuf import message as _message from collections.abc import Iterable as _Iterable, Mapping as _Mapping @@ -7,7 +8,42 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union DESCRIPTOR: _descriptor.FileDescriptor -class EnqueueRequest(_message.Message): +class EnqueueErrorCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + ENQUEUE_ERROR_CODE_UNSPECIFIED: _ClassVar[EnqueueErrorCode] + ENQUEUE_ERROR_CODE_QUEUE_NOT_FOUND: _ClassVar[EnqueueErrorCode] + ENQUEUE_ERROR_CODE_STORAGE: _ClassVar[EnqueueErrorCode] + ENQUEUE_ERROR_CODE_LUA: _ClassVar[EnqueueErrorCode] + ENQUEUE_ERROR_CODE_PERMISSION_DENIED: _ClassVar[EnqueueErrorCode] + +class AckErrorCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + ACK_ERROR_CODE_UNSPECIFIED: _ClassVar[AckErrorCode] + ACK_ERROR_CODE_MESSAGE_NOT_FOUND: _ClassVar[AckErrorCode] + ACK_ERROR_CODE_STORAGE: _ClassVar[AckErrorCode] + ACK_ERROR_CODE_PERMISSION_DENIED: _ClassVar[AckErrorCode] + +class NackErrorCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + NACK_ERROR_CODE_UNSPECIFIED: _ClassVar[NackErrorCode] + NACK_ERROR_CODE_MESSAGE_NOT_FOUND: _ClassVar[NackErrorCode] + NACK_ERROR_CODE_STORAGE: _ClassVar[NackErrorCode] + NACK_ERROR_CODE_PERMISSION_DENIED: _ClassVar[NackErrorCode] +ENQUEUE_ERROR_CODE_UNSPECIFIED: EnqueueErrorCode +ENQUEUE_ERROR_CODE_QUEUE_NOT_FOUND: EnqueueErrorCode +ENQUEUE_ERROR_CODE_STORAGE: EnqueueErrorCode +ENQUEUE_ERROR_CODE_LUA: EnqueueErrorCode +ENQUEUE_ERROR_CODE_PERMISSION_DENIED: EnqueueErrorCode +ACK_ERROR_CODE_UNSPECIFIED: AckErrorCode +ACK_ERROR_CODE_MESSAGE_NOT_FOUND: AckErrorCode +ACK_ERROR_CODE_STORAGE: AckErrorCode +ACK_ERROR_CODE_PERMISSION_DENIED: 
AckErrorCode +NACK_ERROR_CODE_UNSPECIFIED: NackErrorCode +NACK_ERROR_CODE_MESSAGE_NOT_FOUND: NackErrorCode +NACK_ERROR_CODE_STORAGE: NackErrorCode +NACK_ERROR_CODE_PERMISSION_DENIED: NackErrorCode + +class EnqueueMessage(_message.Message): __slots__ = ("queue", "headers", "payload") class HeadersEntry(_message.Message): __slots__ = ("key", "value") @@ -24,11 +60,33 @@ class EnqueueRequest(_message.Message): payload: bytes def __init__(self, queue: _Optional[str] = ..., headers: _Optional[_Mapping[str, str]] = ..., payload: _Optional[bytes] = ...) -> None: ... -class EnqueueResponse(_message.Message): - __slots__ = ("message_id",) +class EnqueueRequest(_message.Message): + __slots__ = ("messages",) + MESSAGES_FIELD_NUMBER: _ClassVar[int] + messages: _containers.RepeatedCompositeFieldContainer[EnqueueMessage] + def __init__(self, messages: _Optional[_Iterable[_Union[EnqueueMessage, _Mapping]]] = ...) -> None: ... + +class EnqueueResult(_message.Message): + __slots__ = ("message_id", "error") MESSAGE_ID_FIELD_NUMBER: _ClassVar[int] + ERROR_FIELD_NUMBER: _ClassVar[int] message_id: str - def __init__(self, message_id: _Optional[str] = ...) -> None: ... + error: EnqueueError + def __init__(self, message_id: _Optional[str] = ..., error: _Optional[_Union[EnqueueError, _Mapping]] = ...) -> None: ... + +class EnqueueError(_message.Message): + __slots__ = ("code", "message") + CODE_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + code: EnqueueErrorCode + message: str + def __init__(self, code: _Optional[_Union[EnqueueErrorCode, str]] = ..., message: _Optional[str] = ...) -> None: ... + +class EnqueueResponse(_message.Message): + __slots__ = ("results",) + RESULTS_FIELD_NUMBER: _ClassVar[int] + results: _containers.RepeatedCompositeFieldContainer[EnqueueResult] + def __init__(self, results: _Optional[_Iterable[_Union[EnqueueResult, _Mapping]]] = ...) -> None: ... 
class ConsumeRequest(_message.Message): __slots__ = ("queue",) @@ -37,14 +95,12 @@ class ConsumeRequest(_message.Message): def __init__(self, queue: _Optional[str] = ...) -> None: ... class ConsumeResponse(_message.Message): - __slots__ = ("message", "messages") - MESSAGE_FIELD_NUMBER: _ClassVar[int] + __slots__ = ("messages",) MESSAGES_FIELD_NUMBER: _ClassVar[int] - message: _messages_pb2.Message messages: _containers.RepeatedCompositeFieldContainer[_messages_pb2.Message] - def __init__(self, message: _Optional[_Union[_messages_pb2.Message, _Mapping]] = ..., messages: _Optional[_Iterable[_Union[_messages_pb2.Message, _Mapping]]] = ...) -> None: ... + def __init__(self, messages: _Optional[_Iterable[_Union[_messages_pb2.Message, _Mapping]]] = ...) -> None: ... -class AckRequest(_message.Message): +class AckMessage(_message.Message): __slots__ = ("queue", "message_id") QUEUE_FIELD_NUMBER: _ClassVar[int] MESSAGE_ID_FIELD_NUMBER: _ClassVar[int] @@ -52,11 +108,39 @@ class AckRequest(_message.Message): message_id: str def __init__(self, queue: _Optional[str] = ..., message_id: _Optional[str] = ...) -> None: ... -class AckResponse(_message.Message): +class AckRequest(_message.Message): + __slots__ = ("messages",) + MESSAGES_FIELD_NUMBER: _ClassVar[int] + messages: _containers.RepeatedCompositeFieldContainer[AckMessage] + def __init__(self, messages: _Optional[_Iterable[_Union[AckMessage, _Mapping]]] = ...) -> None: ... + +class AckResult(_message.Message): + __slots__ = ("success", "error") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + ERROR_FIELD_NUMBER: _ClassVar[int] + success: AckSuccess + error: AckError + def __init__(self, success: _Optional[_Union[AckSuccess, _Mapping]] = ..., error: _Optional[_Union[AckError, _Mapping]] = ...) -> None: ... + +class AckSuccess(_message.Message): __slots__ = () def __init__(self) -> None: ... 
-class NackRequest(_message.Message): +class AckError(_message.Message): + __slots__ = ("code", "message") + CODE_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + code: AckErrorCode + message: str + def __init__(self, code: _Optional[_Union[AckErrorCode, str]] = ..., message: _Optional[str] = ...) -> None: ... + +class AckResponse(_message.Message): + __slots__ = ("results",) + RESULTS_FIELD_NUMBER: _ClassVar[int] + results: _containers.RepeatedCompositeFieldContainer[AckResult] + def __init__(self, results: _Optional[_Iterable[_Union[AckResult, _Mapping]]] = ...) -> None: ... + +class NackMessage(_message.Message): __slots__ = ("queue", "message_id", "error") QUEUE_FIELD_NUMBER: _ClassVar[int] MESSAGE_ID_FIELD_NUMBER: _ClassVar[int] @@ -66,26 +150,50 @@ class NackRequest(_message.Message): error: str def __init__(self, queue: _Optional[str] = ..., message_id: _Optional[str] = ..., error: _Optional[str] = ...) -> None: ... -class NackResponse(_message.Message): +class NackRequest(_message.Message): + __slots__ = ("messages",) + MESSAGES_FIELD_NUMBER: _ClassVar[int] + messages: _containers.RepeatedCompositeFieldContainer[NackMessage] + def __init__(self, messages: _Optional[_Iterable[_Union[NackMessage, _Mapping]]] = ...) -> None: ... + +class NackResult(_message.Message): + __slots__ = ("success", "error") + SUCCESS_FIELD_NUMBER: _ClassVar[int] + ERROR_FIELD_NUMBER: _ClassVar[int] + success: NackSuccess + error: NackError + def __init__(self, success: _Optional[_Union[NackSuccess, _Mapping]] = ..., error: _Optional[_Union[NackError, _Mapping]] = ...) -> None: ... + +class NackSuccess(_message.Message): __slots__ = () def __init__(self) -> None: ... -class BatchEnqueueRequest(_message.Message): - __slots__ = ("messages",) - MESSAGES_FIELD_NUMBER: _ClassVar[int] - messages: _containers.RepeatedCompositeFieldContainer[EnqueueRequest] - def __init__(self, messages: _Optional[_Iterable[_Union[EnqueueRequest, _Mapping]]] = ...) -> None: ... 
+class NackError(_message.Message): + __slots__ = ("code", "message") + CODE_FIELD_NUMBER: _ClassVar[int] + MESSAGE_FIELD_NUMBER: _ClassVar[int] + code: NackErrorCode + message: str + def __init__(self, code: _Optional[_Union[NackErrorCode, str]] = ..., message: _Optional[str] = ...) -> None: ... -class BatchEnqueueResponse(_message.Message): +class NackResponse(_message.Message): __slots__ = ("results",) RESULTS_FIELD_NUMBER: _ClassVar[int] - results: _containers.RepeatedCompositeFieldContainer[BatchEnqueueResult] - def __init__(self, results: _Optional[_Iterable[_Union[BatchEnqueueResult, _Mapping]]] = ...) -> None: ... + results: _containers.RepeatedCompositeFieldContainer[NackResult] + def __init__(self, results: _Optional[_Iterable[_Union[NackResult, _Mapping]]] = ...) -> None: ... -class BatchEnqueueResult(_message.Message): - __slots__ = ("success", "error") - SUCCESS_FIELD_NUMBER: _ClassVar[int] - ERROR_FIELD_NUMBER: _ClassVar[int] - success: EnqueueResponse - error: str - def __init__(self, success: _Optional[_Union[EnqueueResponse, _Mapping]] = ..., error: _Optional[str] = ...) -> None: ... +class StreamEnqueueRequest(_message.Message): + __slots__ = ("messages", "sequence_number") + MESSAGES_FIELD_NUMBER: _ClassVar[int] + SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] + messages: _containers.RepeatedCompositeFieldContainer[EnqueueMessage] + sequence_number: int + def __init__(self, messages: _Optional[_Iterable[_Union[EnqueueMessage, _Mapping]]] = ..., sequence_number: _Optional[int] = ...) -> None: ... + +class StreamEnqueueResponse(_message.Message): + __slots__ = ("sequence_number", "results") + SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] + RESULTS_FIELD_NUMBER: _ClassVar[int] + sequence_number: int + results: _containers.RepeatedCompositeFieldContainer[EnqueueResult] + def __init__(self, sequence_number: _Optional[int] = ..., results: _Optional[_Iterable[_Union[EnqueueResult, _Mapping]]] = ...) -> None: ... 
diff --git a/fila/v1/service_pb2_grpc.py b/fila/v1/service_pb2_grpc.py index 0ef11e1..1f2df55 100644 --- a/fila/v1/service_pb2_grpc.py +++ b/fila/v1/service_pb2_grpc.py @@ -40,10 +40,10 @@ def __init__(self, channel): request_serializer=fila_dot_v1_dot_service__pb2.EnqueueRequest.SerializeToString, response_deserializer=fila_dot_v1_dot_service__pb2.EnqueueResponse.FromString, _registered_method=True) - self.BatchEnqueue = channel.unary_unary( - '/fila.v1.FilaService/BatchEnqueue', - request_serializer=fila_dot_v1_dot_service__pb2.BatchEnqueueRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_service__pb2.BatchEnqueueResponse.FromString, + self.StreamEnqueue = channel.stream_stream( + '/fila.v1.FilaService/StreamEnqueue', + request_serializer=fila_dot_v1_dot_service__pb2.StreamEnqueueRequest.SerializeToString, + response_deserializer=fila_dot_v1_dot_service__pb2.StreamEnqueueResponse.FromString, _registered_method=True) self.Consume = channel.unary_stream( '/fila.v1.FilaService/Consume', @@ -72,7 +72,7 @@ def Enqueue(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def BatchEnqueue(self, request, context): + def StreamEnqueue(self, request_iterator, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') @@ -104,10 +104,10 @@ def add_FilaServiceServicer_to_server(servicer, server): request_deserializer=fila_dot_v1_dot_service__pb2.EnqueueRequest.FromString, response_serializer=fila_dot_v1_dot_service__pb2.EnqueueResponse.SerializeToString, ), - 'BatchEnqueue': grpc.unary_unary_rpc_method_handler( - servicer.BatchEnqueue, - request_deserializer=fila_dot_v1_dot_service__pb2.BatchEnqueueRequest.FromString, - response_serializer=fila_dot_v1_dot_service__pb2.BatchEnqueueResponse.SerializeToString, + 'StreamEnqueue': grpc.stream_stream_rpc_method_handler( + 
servicer.StreamEnqueue, + request_deserializer=fila_dot_v1_dot_service__pb2.StreamEnqueueRequest.FromString, + response_serializer=fila_dot_v1_dot_service__pb2.StreamEnqueueResponse.SerializeToString, ), 'Consume': grpc.unary_stream_rpc_method_handler( servicer.Consume, @@ -164,7 +164,7 @@ def Enqueue(request, _registered_method=True) @staticmethod - def BatchEnqueue(request, + def StreamEnqueue(request_iterator, target, options=(), channel_credentials=None, @@ -174,12 +174,12 @@ def BatchEnqueue(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary( - request, + return grpc.experimental.stream_stream( + request_iterator, target, - '/fila.v1.FilaService/BatchEnqueue', - fila_dot_v1_dot_service__pb2.BatchEnqueueRequest.SerializeToString, - fila_dot_v1_dot_service__pb2.BatchEnqueueResponse.FromString, + '/fila.v1.FilaService/StreamEnqueue', + fila_dot_v1_dot_service__pb2.StreamEnqueueRequest.SerializeToString, + fila_dot_v1_dot_service__pb2.StreamEnqueueResponse.FromString, options, channel_credentials, insecure, diff --git a/proto/fila/v1/admin.proto b/proto/fila/v1/admin.proto index 886e58d..9bb8871 100644 --- a/proto/fila/v1/admin.proto +++ b/proto/fila/v1/admin.proto @@ -11,15 +11,6 @@ service FilaAdmin { rpc GetStats(GetStatsRequest) returns (GetStatsResponse); rpc Redrive(RedriveRequest) returns (RedriveResponse); rpc ListQueues(ListQueuesRequest) returns (ListQueuesResponse); - - // API key management. CreateApiKey bypasses auth (bootstrap); others require a valid key. - rpc CreateApiKey(CreateApiKeyRequest) returns (CreateApiKeyResponse); - rpc RevokeApiKey(RevokeApiKeyRequest) returns (RevokeApiKeyResponse); - rpc ListApiKeys(ListApiKeysRequest) returns (ListApiKeysResponse); - - // Per-key ACL management. 
- rpc SetAcl(SetAclRequest) returns (SetAclResponse); - rpc GetAcl(GetAclRequest) returns (GetAclResponse); } message CreateQueueRequest { @@ -126,72 +117,3 @@ message ListQueuesResponse { repeated QueueInfo queues = 1; uint32 cluster_node_count = 2; } - -// --- API Key Management --- - -message CreateApiKeyRequest { - /// Human-readable label for the key. - string name = 1; - /// Optional Unix timestamp (milliseconds) after which the key expires. - /// 0 means no expiration. - uint64 expires_at_ms = 2; - /// When true, the key bypasses all ACL checks (superadmin). - bool is_superadmin = 3; -} - -message CreateApiKeyResponse { - /// Opaque key ID for management operations (revoke, list, set-acl). - string key_id = 1; - /// Plaintext API key. Returned once — store it securely. - string key = 2; - /// Whether this key has superadmin privileges. - bool is_superadmin = 3; -} - -message RevokeApiKeyRequest { - string key_id = 1; -} - -message RevokeApiKeyResponse {} - -message ListApiKeysRequest {} - -message ApiKeyInfo { - string key_id = 1; - string name = 2; - uint64 created_at_ms = 3; - /// 0 means no expiration. - uint64 expires_at_ms = 4; - bool is_superadmin = 5; -} - -message ListApiKeysResponse { - repeated ApiKeyInfo keys = 1; -} - -// --- ACL Management --- - -/// A single permission grant: kind (produce/consume/admin) + queue pattern. -message AclPermission { - /// One of: "produce", "consume", "admin". - string kind = 1; - /// Queue name or wildcard ("*" or "orders.*"). 
- string pattern = 2; -} - -message SetAclRequest { - string key_id = 1; - repeated AclPermission permissions = 2; -} - -message SetAclResponse {} - -message GetAclRequest { - string key_id = 1; -} - -message GetAclResponse { - string key_id = 1; - repeated AclPermission permissions = 2; - bool is_superadmin = 3; -} diff --git a/proto/fila/v1/service.proto b/proto/fila/v1/service.proto index fc0f710..7d1db79 100644 --- a/proto/fila/v1/service.proto +++ b/proto/fila/v1/service.proto @@ -6,20 +6,49 @@ import "fila/v1/messages.proto"; // Hot-path RPCs for producers and consumers. service FilaService { rpc Enqueue(EnqueueRequest) returns (EnqueueResponse); - rpc BatchEnqueue(BatchEnqueueRequest) returns (BatchEnqueueResponse); + rpc StreamEnqueue(stream StreamEnqueueRequest) returns (stream StreamEnqueueResponse); rpc Consume(ConsumeRequest) returns (stream ConsumeResponse); rpc Ack(AckRequest) returns (AckResponse); rpc Nack(NackRequest) returns (NackResponse); } -message EnqueueRequest { +// Individual message to enqueue. +message EnqueueMessage { string queue = 1; map headers = 2; bytes payload = 3; } +// Enqueue one or more messages. +message EnqueueRequest { + repeated EnqueueMessage messages = 1; +} + +// Per-message enqueue result. +message EnqueueResult { + oneof result { + string message_id = 1; + EnqueueError error = 2; + } +} + +// Typed enqueue error with structured error code. +message EnqueueError { + EnqueueErrorCode code = 1; + string message = 2; +} + +enum EnqueueErrorCode { + ENQUEUE_ERROR_CODE_UNSPECIFIED = 0; + ENQUEUE_ERROR_CODE_QUEUE_NOT_FOUND = 1; + ENQUEUE_ERROR_CODE_STORAGE = 2; + ENQUEUE_ERROR_CODE_LUA = 3; + ENQUEUE_ERROR_CODE_PERMISSION_DENIED = 4; +} + +// One result per input message. 
message EnqueueResponse { - string message_id = 1; + repeated EnqueueResult results = 1; } message ConsumeRequest { @@ -27,36 +56,87 @@ message ConsumeRequest { } message ConsumeResponse { - Message message = 1; // Single message (backward compatible, used when batch size is 1) - repeated Message messages = 2; // Batched messages (populated when server sends multiple at once) + repeated Message messages = 1; } -message AckRequest { +// Individual ack item. +message AckMessage { string queue = 1; string message_id = 2; } -message AckResponse {} +message AckRequest { + repeated AckMessage messages = 1; +} + +message AckResult { + oneof result { + AckSuccess success = 1; + AckError error = 2; + } +} -message NackRequest { +message AckSuccess {} + +message AckError { + AckErrorCode code = 1; + string message = 2; +} + +enum AckErrorCode { + ACK_ERROR_CODE_UNSPECIFIED = 0; + ACK_ERROR_CODE_MESSAGE_NOT_FOUND = 1; + ACK_ERROR_CODE_STORAGE = 2; + ACK_ERROR_CODE_PERMISSION_DENIED = 3; +} + +message AckResponse { + repeated AckResult results = 1; +} + +// Individual nack item. +message NackMessage { string queue = 1; string message_id = 2; string error = 3; } -message NackResponse {} +message NackRequest { + repeated NackMessage messages = 1; +} + +message NackResult { + oneof result { + NackSuccess success = 1; + NackError error = 2; + } +} -message BatchEnqueueRequest { - repeated EnqueueRequest messages = 1; +message NackSuccess {} + +message NackError { + NackErrorCode code = 1; + string message = 2; } -message BatchEnqueueResponse { - repeated BatchEnqueueResult results = 1; +enum NackErrorCode { + NACK_ERROR_CODE_UNSPECIFIED = 0; + NACK_ERROR_CODE_MESSAGE_NOT_FOUND = 1; + NACK_ERROR_CODE_STORAGE = 2; + NACK_ERROR_CODE_PERMISSION_DENIED = 3; } -message BatchEnqueueResult { - oneof result { - EnqueueResponse success = 1; - string error = 2; - } +message NackResponse { + repeated NackResult results = 1; +} + +// Stream enqueue — per-write batch with sequence tracking. 
+message StreamEnqueueRequest { + repeated EnqueueMessage messages = 1; + uint64 sequence_number = 2; +} + +message StreamEnqueueResponse { + uint64 sequence_number = 1; + repeated EnqueueResult results = 2; } diff --git a/tests/test_batcher.py b/tests/test_batcher.py index ee10e71..3489a42 100644 --- a/tests/test_batcher.py +++ b/tests/test_batcher.py @@ -13,44 +13,38 @@ import pytest from fila.batcher import ( - AutoBatcher, - LingerBatcher, - _EnqueueRequest, - _flush_batch, + AutoAccumulator, + LingerAccumulator, + _EnqueueItem, + _flush_many, _flush_single, ) -from fila.errors import BatchEnqueueError +from fila.errors import EnqueueError from fila.v1 import service_pb2 -class FakeEnqueueResponse: - """Minimal fake for service_pb2.EnqueueResponse.""" - - def __init__(self, message_id: str) -> None: - self.message_id = message_id - - -class FakeBatchResult: - """Minimal fake for service_pb2.BatchEnqueueResult.""" +class FakeEnqueueResult: + """Minimal fake for service_pb2.EnqueueResult.""" - def __init__(self, message_id: str | None = None, error: str | None = None) -> None: + def __init__(self, message_id: str | None = None, error_msg: str | None = None) -> None: self._message_id = message_id - self._error = error - self.success: FakeEnqueueResponse | None = ( - FakeEnqueueResponse(message_id) if message_id is not None else None - ) - self.error = error or "" + self._error_msg = error_msg + self.message_id = message_id or "" + self.error = MagicMock() + self.error.message = error_msg or "" - def HasField(self, name: str) -> bool: # noqa: N802 - if name == "success": - return self._message_id is not None - return False + def WhichOneof(self, name: str) -> str | None: # noqa: N802 + if name == "result": + if self._message_id is not None: + return "message_id" + return "error" + return None -class FakeBatchResponse: - """Minimal fake for service_pb2.BatchEnqueueResponse.""" +class FakeEnqueueResponse: + """Minimal fake for service_pb2.EnqueueResponse.""" - def 
__init__(self, results: list[FakeBatchResult]) -> None: + def __init__(self, results: list[FakeEnqueueResult]) -> None: self.results = results @@ -59,25 +53,23 @@ class TestFlushSingle: def test_success(self) -> None: stub = MagicMock() - stub.Enqueue.return_value = FakeEnqueueResponse("msg-001") + stub.Enqueue.return_value = FakeEnqueueResponse([ + FakeEnqueueResult(message_id="msg-001"), + ]) - proto = service_pb2.EnqueueRequest(queue="q", payload=b"data") + proto = service_pb2.EnqueueMessage(queue="q", payload=b"data") fut: Future[str] = Future() - req = _EnqueueRequest(proto, fut) + req = _EnqueueItem(proto, fut) _flush_single(stub, req) assert fut.result(timeout=1.0) == "msg-001" - stub.Enqueue.assert_called_once_with(proto) + stub.Enqueue.assert_called_once() def test_rpc_error(self) -> None: import grpc stub = MagicMock() - rpc_error = MagicMock() - rpc_error.code.return_value = grpc.StatusCode.NOT_FOUND - rpc_error.details.return_value = "queue not found" - # Make it pass isinstance(e, grpc.RpcError) check. 
stub.Enqueue.side_effect = type( "_FakeRpcError", (grpc.RpcError,), { "code": lambda self: grpc.StatusCode.NOT_FOUND, @@ -85,9 +77,9 @@ def test_rpc_error(self) -> None: } )() - proto = service_pb2.EnqueueRequest(queue="missing", payload=b"data") + proto = service_pb2.EnqueueMessage(queue="missing", payload=b"data") fut: Future[str] = Future() - req = _EnqueueRequest(proto, fut) + req = _EnqueueItem(proto, fut) _flush_single(stub, req) @@ -97,164 +89,151 @@ def test_rpc_error(self) -> None: fut.result(timeout=1.0) -class TestFlushBatch: - """Test the _flush_batch function.""" +class TestFlushMany: + """Test the _flush_many function.""" def test_all_success(self) -> None: stub = MagicMock() - stub.BatchEnqueue.return_value = FakeBatchResponse([ - FakeBatchResult(message_id="id-1"), - FakeBatchResult(message_id="id-2"), + stub.Enqueue.return_value = FakeEnqueueResponse([ + FakeEnqueueResult(message_id="id-1"), + FakeEnqueueResult(message_id="id-2"), ]) - reqs = [ - _EnqueueRequest( - service_pb2.EnqueueRequest(queue="q", payload=b"a"), + items = [ + _EnqueueItem( + service_pb2.EnqueueMessage(queue="q", payload=b"a"), Future(), ), - _EnqueueRequest( - service_pb2.EnqueueRequest(queue="q", payload=b"b"), + _EnqueueItem( + service_pb2.EnqueueMessage(queue="q", payload=b"b"), Future(), ), ] - _flush_batch(stub, reqs) + _flush_many(stub, items) - assert reqs[0].future.result(timeout=1.0) == "id-1" - assert reqs[1].future.result(timeout=1.0) == "id-2" + assert items[0].future.result(timeout=1.0) == "id-1" + assert items[1].future.result(timeout=1.0) == "id-2" def test_mixed_results(self) -> None: stub = MagicMock() - stub.BatchEnqueue.return_value = FakeBatchResponse([ - FakeBatchResult(message_id="id-1"), - FakeBatchResult(error="queue 'missing' not found"), + stub.Enqueue.return_value = FakeEnqueueResponse([ + FakeEnqueueResult(message_id="id-1"), + FakeEnqueueResult(error_msg="queue 'missing' not found"), ]) - reqs = [ - _EnqueueRequest( - 
service_pb2.EnqueueRequest(queue="q", payload=b"a"), + items = [ + _EnqueueItem( + service_pb2.EnqueueMessage(queue="q", payload=b"a"), Future(), ), - _EnqueueRequest( - service_pb2.EnqueueRequest(queue="missing", payload=b"b"), + _EnqueueItem( + service_pb2.EnqueueMessage(queue="missing", payload=b"b"), Future(), ), ] - _flush_batch(stub, reqs) + _flush_many(stub, items) - assert reqs[0].future.result(timeout=1.0) == "id-1" - with pytest.raises(BatchEnqueueError, match="queue 'missing' not found"): - reqs[1].future.result(timeout=1.0) + assert items[0].future.result(timeout=1.0) == "id-1" + with pytest.raises(EnqueueError, match="queue 'missing' not found"): + items[1].future.result(timeout=1.0) def test_rpc_failure_sets_all_futures(self) -> None: import grpc stub = MagicMock() - stub.BatchEnqueue.side_effect = type( + stub.Enqueue.side_effect = type( "_FakeRpcError", (grpc.RpcError,), { "code": lambda self: grpc.StatusCode.UNAVAILABLE, "details": lambda self: "server unavailable", } )() - reqs = [ - _EnqueueRequest( - service_pb2.EnqueueRequest(queue="q", payload=b"a"), + items = [ + _EnqueueItem( + service_pb2.EnqueueMessage(queue="q", payload=b"a"), Future(), ), - _EnqueueRequest( - service_pb2.EnqueueRequest(queue="q", payload=b"b"), + _EnqueueItem( + service_pb2.EnqueueMessage(queue="q", payload=b"b"), Future(), ), ] - _flush_batch(stub, reqs) + _flush_many(stub, items) - for r in reqs: - with pytest.raises(BatchEnqueueError): - r.future.result(timeout=1.0) + for item in items: + with pytest.raises(EnqueueError): + item.future.result(timeout=1.0) -class TestAutoBatcher: - """Test the AutoBatcher end-to-end.""" +class TestAutoAccumulator: + """Test the AutoAccumulator end-to-end.""" def test_single_message_uses_enqueue(self) -> None: - """When only one message is queued, AutoBatcher uses singular Enqueue.""" + """When only one message is queued, AutoAccumulator uses Enqueue with one message.""" stub = MagicMock() - stub.Enqueue.return_value = 
FakeEnqueueResponse("msg-solo") + stub.Enqueue.return_value = FakeEnqueueResponse([ + FakeEnqueueResult(message_id="msg-solo"), + ]) - batcher = AutoBatcher(stub, max_batch_size=100) + accumulator = AutoAccumulator(stub, max_messages=100) - proto = service_pb2.EnqueueRequest(queue="q", payload=b"solo") - fut = batcher.submit(proto) + proto = service_pb2.EnqueueMessage(queue="q", payload=b"solo") + fut = accumulator.submit(proto) result = fut.result(timeout=5.0) assert result == "msg-solo" stub.Enqueue.assert_called_once() - stub.BatchEnqueue.assert_not_called() - batcher.close() + accumulator.close() - def test_concurrent_messages_batched(self) -> None: - """When multiple messages arrive concurrently, they batch together.""" + def test_concurrent_messages_accumulated(self) -> None: + """When multiple messages arrive concurrently, they accumulate together.""" stub = MagicMock() - # The first message will block Enqueue while more messages queue up. - # We need to make the batcher see all messages at once. - batch_called = threading.Event() - batch_response = FakeBatchResponse([ - FakeBatchResult(message_id=f"id-{i}") for i in range(5) + enqueue_response = FakeEnqueueResponse([ + FakeEnqueueResult(message_id=f"id-{i}") for i in range(5) ]) - def mock_batch_enqueue(request: Any) -> FakeBatchResponse: - batch_called.set() - return batch_response - - # Make single Enqueue block briefly so messages accumulate. 
- single_barrier = threading.Event() - - def mock_single_enqueue(request: Any) -> FakeEnqueueResponse: - single_barrier.wait(timeout=5.0) - return FakeEnqueueResponse("should-not-be-used") + def mock_enqueue(request: Any) -> FakeEnqueueResponse: + return enqueue_response - stub.Enqueue.side_effect = mock_single_enqueue - stub.BatchEnqueue.side_effect = mock_batch_enqueue + stub.Enqueue.side_effect = mock_enqueue - batcher = AutoBatcher(stub, max_batch_size=100) + accumulator = AutoAccumulator(stub, max_messages=100) - # Submit 5 messages rapidly before the first can process. - # The batcher should drain them all in one batch. + # Submit 5 messages rapidly. protos = [ - service_pb2.EnqueueRequest(queue="q", payload=f"msg-{i}".encode()) + service_pb2.EnqueueMessage(queue="q", payload=f"msg-{i}".encode()) for i in range(5) ] - # We need to submit them in a way that they all arrive before - # the batcher loop drains. Use a barrier approach. futures = [] for p in protos: - futures.append(batcher.submit(p)) + futures.append(accumulator.submit(p)) - # Give the batcher thread time to drain and flush. - # Either BatchEnqueue or multiple Enqueue calls will resolve things. + # All futures should resolve. 
for _i, f in enumerate(futures): result = f.result(timeout=5.0) assert result is not None - batcher.close() + accumulator.close() def test_close_drains_pending(self) -> None: """close() waits for pending messages to be flushed.""" stub = MagicMock() - stub.Enqueue.return_value = FakeEnqueueResponse("drained") + stub.Enqueue.return_value = FakeEnqueueResponse([ + FakeEnqueueResult(message_id="drained"), + ]) - batcher = AutoBatcher(stub, max_batch_size=100) + accumulator = AutoAccumulator(stub, max_messages=100) - proto = service_pb2.EnqueueRequest(queue="q", payload=b"drain-me") - fut = batcher.submit(proto) + proto = service_pb2.EnqueueMessage(queue="q", payload=b"drain-me") + fut = accumulator.submit(proto) - batcher.close() + accumulator.close() # After close, the future should be resolved. assert fut.result(timeout=1.0) == "drained" @@ -263,71 +242,77 @@ def test_update_stub(self) -> None: """update_stub replaces the gRPC stub used for flushing.""" old_stub = MagicMock() new_stub = MagicMock() - new_stub.Enqueue.return_value = FakeEnqueueResponse("new-stub") + new_stub.Enqueue.return_value = FakeEnqueueResponse([ + FakeEnqueueResult(message_id="new-stub"), + ]) - batcher = AutoBatcher(old_stub, max_batch_size=100) + accumulator = AutoAccumulator(old_stub, max_messages=100) # Update stub before submitting. 
- batcher.update_stub(new_stub) + accumulator.update_stub(new_stub) - proto = service_pb2.EnqueueRequest(queue="q", payload=b"data") - fut = batcher.submit(proto) + proto = service_pb2.EnqueueMessage(queue="q", payload=b"data") + fut = accumulator.submit(proto) result = fut.result(timeout=5.0) assert result == "new-stub" - batcher.close() + accumulator.close() -class TestLingerBatcher: - """Test the LingerBatcher.""" +class TestLingerAccumulator: + """Test the LingerAccumulator.""" - def test_flushes_at_batch_size(self) -> None: - """Flush triggers when batch_size messages accumulate.""" + def test_flushes_at_max_messages(self) -> None: + """Flush triggers when max_messages messages accumulate.""" stub = MagicMock() - stub.BatchEnqueue.return_value = FakeBatchResponse([ - FakeBatchResult(message_id=f"id-{i}") for i in range(3) + stub.Enqueue.return_value = FakeEnqueueResponse([ + FakeEnqueueResult(message_id=f"id-{i}") for i in range(3) ]) - batcher = LingerBatcher(stub, linger_ms=5000, batch_size=3) + accumulator = LingerAccumulator(stub, linger_ms=5000, max_messages=3) futures = [] for i in range(3): - proto = service_pb2.EnqueueRequest(queue="q", payload=f"m{i}".encode()) - futures.append(batcher.submit(proto)) + proto = service_pb2.EnqueueMessage(queue="q", payload=f"m{i}".encode()) + futures.append(accumulator.submit(proto)) - # Should flush quickly because batch_size=3 was reached. + # Should flush quickly because max_messages=3 was reached. 
for i, f in enumerate(futures): result = f.result(timeout=5.0) assert result == f"id-{i}" - batcher.close() + accumulator.close() def test_flushes_at_linger_timeout(self) -> None: - """Flush triggers after linger_ms even if batch_size is not reached.""" + """Flush triggers after linger_ms even if max_messages is not reached.""" stub = MagicMock() - stub.Enqueue.return_value = FakeEnqueueResponse("lingered") + stub.Enqueue.return_value = FakeEnqueueResponse([ + FakeEnqueueResult(message_id="lingered"), + ]) - batcher = LingerBatcher(stub, linger_ms=50, batch_size=100) + accumulator = LingerAccumulator(stub, linger_ms=50, max_messages=100) - proto = service_pb2.EnqueueRequest(queue="q", payload=b"linger") - fut = batcher.submit(proto) + proto = service_pb2.EnqueueMessage(queue="q", payload=b"linger") + fut = accumulator.submit(proto) - # Should flush after ~50ms even though batch_size=100 not reached. + # Should flush after ~50ms even though max_messages=100 not reached. result = fut.result(timeout=5.0) assert result == "lingered" - batcher.close() + accumulator.close() def test_close_drains_pending(self) -> None: """close() drains any pending messages.""" stub = MagicMock() - stub.Enqueue.return_value = FakeEnqueueResponse("drained") + stub.Enqueue.return_value = FakeEnqueueResponse([ + FakeEnqueueResult(message_id="drained"), + ]) - batcher = LingerBatcher(stub, linger_ms=10000, batch_size=100) + accumulator = LingerAccumulator(stub, linger_ms=10000, max_messages=100) - proto = service_pb2.EnqueueRequest(queue="q", payload=b"drain") - fut = batcher.submit(proto) + proto = service_pb2.EnqueueMessage(queue="q", payload=b"drain") + fut = accumulator.submit(proto) - batcher.close() + accumulator.close() assert fut.result(timeout=1.0) == "drained" diff --git a/tests/test_batch_integration.py b/tests/test_enqueue_integration.py similarity index 51% rename from tests/test_batch_integration.py rename to tests/test_enqueue_integration.py index 09aefb9..4900d64 100644 --- 
a/tests/test_batch_integration.py +++ b/tests/test_enqueue_integration.py @@ -1,4 +1,4 @@ -"""Integration tests for batch enqueue and smart batching. +"""Integration tests for enqueue_many and accumulator modes. These tests require a running fila-server binary. They are skipped automatically when the server is not found (local dev). @@ -11,21 +11,23 @@ import fila -class TestBatchEnqueue: - """Integration tests for the explicit batch_enqueue method.""" +class TestEnqueueMany: + """Integration tests for the explicit enqueue_many method.""" - def test_batch_enqueue_multiple_messages(self, server: object) -> None: - """batch_enqueue sends multiple messages in one RPC and returns per-message results.""" + def test_enqueue_many_multiple_messages(self, server: object) -> None: + """enqueue_many sends multiple messages in one RPC and returns per-message results.""" from tests.conftest import TestServer assert isinstance(server, TestServer) - server.create_queue("test-batch") + server.create_queue("test-enqueue-many") - with fila.Client(server.addr, batch_mode=fila.BatchMode.DISABLED) as client: - results = client.batch_enqueue([ - ("test-batch", {"idx": "0"}, b"payload-0"), - ("test-batch", {"idx": "1"}, b"payload-1"), - ("test-batch", {"idx": "2"}, b"payload-2"), + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED + ) as client: + results = client.enqueue_many([ + ("test-enqueue-many", {"idx": "0"}, b"payload-0"), + ("test-enqueue-many", {"idx": "1"}, b"payload-1"), + ("test-enqueue-many", {"idx": "2"}, b"payload-2"), ]) assert len(results) == 3 @@ -38,60 +40,64 @@ def test_batch_enqueue_multiple_messages(self, server: object) -> None: ids = [r.message_id for r in results] assert len(set(ids)) == 3 - def test_batch_enqueue_single_message(self, server: object) -> None: - """batch_enqueue works with a single message.""" + def test_enqueue_many_single_message(self, server: object) -> None: + """enqueue_many works with a single message.""" from 
tests.conftest import TestServer assert isinstance(server, TestServer) - server.create_queue("test-batch-single") + server.create_queue("test-enqueue-many-single") - with fila.Client(server.addr, batch_mode=fila.BatchMode.DISABLED) as client: - results = client.batch_enqueue([ - ("test-batch-single", None, b"solo"), + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED + ) as client: + results = client.enqueue_many([ + ("test-enqueue-many-single", None, b"solo"), ]) assert len(results) == 1 assert results[0].is_success assert results[0].message_id is not None - def test_batch_enqueue_consume_verify(self, server: object) -> None: - """Messages enqueued via batch_enqueue can be consumed and acked.""" + def test_enqueue_many_consume_verify(self, server: object) -> None: + """Messages enqueued via enqueue_many can be consumed and acked.""" from tests.conftest import TestServer assert isinstance(server, TestServer) - server.create_queue("test-batch-consume") + server.create_queue("test-enqueue-many-consume") - with fila.Client(server.addr, batch_mode=fila.BatchMode.DISABLED) as client: - results = client.batch_enqueue([ - ("test-batch-consume", {"k": "v"}, b"batch-msg"), + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED + ) as client: + results = client.enqueue_many([ + ("test-enqueue-many-consume", {"k": "v"}, b"multi-msg"), ]) assert results[0].is_success - stream = client.consume("test-batch-consume") + stream = client.consume("test-enqueue-many-consume") msg = next(stream) assert msg.id == results[0].message_id assert msg.headers["k"] == "v" - assert msg.payload == b"batch-msg" + assert msg.payload == b"multi-msg" - client.ack("test-batch-consume", msg.id) + client.ack("test-enqueue-many-consume", msg.id) -class TestAsyncBatchEnqueue: - """Integration tests for the async batch_enqueue method.""" +class TestAsyncEnqueueMany: + """Integration tests for the async enqueue_many method.""" @pytest.mark.asyncio - 
async def test_async_batch_enqueue(self, server: object) -> None: - """Async batch_enqueue sends multiple messages.""" + async def test_async_enqueue_many(self, server: object) -> None: + """Async enqueue_many sends multiple messages.""" from tests.conftest import TestServer assert isinstance(server, TestServer) - server.create_queue("test-async-batch") + server.create_queue("test-async-enqueue-many") async with fila.AsyncClient(server.addr) as client: - results = await client.batch_enqueue([ - ("test-async-batch", None, b"async-0"), - ("test-async-batch", None, b"async-1"), + results = await client.enqueue_many([ + ("test-async-enqueue-many", None, b"async-0"), + ("test-async-enqueue-many", None, b"async-1"), ]) assert len(results) == 2 @@ -100,26 +106,28 @@ async def test_async_batch_enqueue(self, server: object) -> None: assert r.message_id is not None -class TestSmartBatching: - """Integration tests for smart batching (BatchMode.AUTO).""" +class TestAccumulatorModes: + """Integration tests for accumulator modes (AccumulatorMode.AUTO, Linger).""" def test_auto_mode_enqueue(self, server: object) -> None: - """AUTO mode enqueues messages through the batcher.""" + """AUTO mode enqueues messages through the accumulator.""" from tests.conftest import TestServer assert isinstance(server, TestServer) - server.create_queue("test-auto-batch") + server.create_queue("test-auto-accum") - with fila.Client(server.addr, batch_mode=fila.BatchMode.AUTO) as client: - msg_id = client.enqueue("test-auto-batch", None, b"auto-msg") + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.AUTO + ) as client: + msg_id = client.enqueue("test-auto-accum", None, b"auto-msg") assert msg_id != "" # Verify the message was actually enqueued. 
- stream = client.consume("test-auto-batch") + stream = client.consume("test-auto-accum") msg = next(stream) assert msg.id == msg_id assert msg.payload == b"auto-msg" - client.ack("test-auto-batch", msg.id) + client.ack("test-auto-accum", msg.id) def test_auto_mode_multiple_messages(self, server: object) -> None: """AUTO mode handles multiple sequential enqueues.""" @@ -128,7 +136,9 @@ def test_auto_mode_multiple_messages(self, server: object) -> None: assert isinstance(server, TestServer) server.create_queue("test-auto-multi") - with fila.Client(server.addr, batch_mode=fila.BatchMode.AUTO) as client: + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.AUTO + ) as client: ids = [] for i in range(5): msg_id = client.enqueue( @@ -147,7 +157,9 @@ def test_disabled_mode_enqueue(self, server: object) -> None: assert isinstance(server, TestServer) server.create_queue("test-disabled") - with fila.Client(server.addr, batch_mode=fila.BatchMode.DISABLED) as client: + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED + ) as client: msg_id = client.enqueue("test-disabled", None, b"direct") assert msg_id != "" @@ -157,7 +169,7 @@ def test_disabled_mode_enqueue(self, server: object) -> None: client.ack("test-disabled", msg.id) def test_linger_mode_enqueue(self, server: object) -> None: - """LINGER mode enqueues messages through a timer-based batcher.""" + """LINGER mode enqueues messages through a timer-based accumulator.""" from tests.conftest import TestServer assert isinstance(server, TestServer) @@ -165,7 +177,7 @@ def test_linger_mode_enqueue(self, server: object) -> None: with fila.Client( server.addr, - batch_mode=fila.Linger(linger_ms=50, batch_size=10), + accumulator_mode=fila.Linger(linger_ms=50, max_messages=10), ) as client: msg_id = client.enqueue("test-linger", None, b"lingered") assert msg_id != "" @@ -177,44 +189,44 @@ def test_linger_mode_enqueue(self, server: object) -> None: client.ack("test-linger", msg.id) 
def test_default_mode_is_auto(self, server: object) -> None: - """Client defaults to AUTO batch mode.""" + """Client defaults to AUTO accumulator mode.""" from tests.conftest import TestServer assert isinstance(server, TestServer) server.create_queue("test-default-mode") - # No batch_mode arg = AUTO. + # No accumulator_mode arg = AUTO. with fila.Client(server.addr) as client: msg_id = client.enqueue("test-default-mode", None, b"default") assert msg_id != "" -class TestBatchModeTypes: - """Unit tests for BatchMode and Linger types (no server needed).""" +class TestAccumulatorModeTypes: + """Unit tests for AccumulatorMode and Linger types (no server needed).""" - def test_batch_mode_enum(self) -> None: - """BatchMode has AUTO and DISABLED variants.""" - assert fila.BatchMode.AUTO is not None - assert fila.BatchMode.DISABLED is not None - modes = {fila.BatchMode.AUTO, fila.BatchMode.DISABLED} + def test_accumulator_mode_enum(self) -> None: + """AccumulatorMode has AUTO and DISABLED variants.""" + assert fila.AccumulatorMode.AUTO is not None + assert fila.AccumulatorMode.DISABLED is not None + modes = {fila.AccumulatorMode.AUTO, fila.AccumulatorMode.DISABLED} assert len(modes) == 2 # They are distinct values def test_linger_fields(self) -> None: - """Linger stores linger_ms and batch_size.""" - linger = fila.Linger(linger_ms=100, batch_size=50) + """Linger stores linger_ms and max_messages.""" + linger = fila.Linger(linger_ms=100, max_messages=50) assert linger.linger_ms == 100 - assert linger.batch_size == 50 + assert linger.max_messages == 50 - def test_batch_enqueue_result_success(self) -> None: - """BatchEnqueueResult.is_success returns True when message_id is set.""" - r = fila.BatchEnqueueResult(message_id="abc", error=None) + def test_enqueue_result_success(self) -> None: + """EnqueueResult.is_success returns True when message_id is set.""" + r = fila.EnqueueResult(message_id="abc", error=None) assert r.is_success assert r.message_id == "abc" assert r.error is 
None - def test_batch_enqueue_result_error(self) -> None: - """BatchEnqueueResult.is_success returns False when error is set.""" - r = fila.BatchEnqueueResult(message_id=None, error="queue not found") + def test_enqueue_result_error(self) -> None: + """EnqueueResult.is_success returns False when error is set.""" + r = fila.EnqueueResult(message_id=None, error="queue not found") assert not r.is_success assert r.message_id is None assert r.error == "queue not found" From fb9c873ffe059560f673efd33442961193f0c6b7 Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Wed, 25 Mar 2026 09:56:17 -0300 Subject: [PATCH 09/17] fix: resolve lint failures and cubic review findings MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - remove unused imports (EnqueueError in client/async_client, Any in batcher, threading in test_batcher) — fixes ruff F401 - move MessageNotFoundError/RPCError to top-level imports, eliminating inline imports that triggered ruff I001 (unsorted import blocks) - move ConsumeMessage/EnqueueResult into TYPE_CHECKING block in async_client — fixes ruff TC001 - fix EnqueueError docstring to reflect that it is also raised as fallback for per-message errors via _map_enqueue_result_error - strengthen test_flush_single to assert actual request content sent to Enqueue, not just call count - downgrade GRPC_GENERATED_VERSION from 1.78.1 to 1.78.0 in all generated grpc stubs (grpcio 1.78.1 was yanked from PyPI) --- fila/async_client.py | 20 +++++++++----------- fila/batcher.py | 2 +- fila/client.py | 11 ++++------- fila/errors.py | 9 +++++---- fila/v1/admin_pb2_grpc.py | 2 +- fila/v1/messages_pb2_grpc.py | 2 +- fila/v1/service_pb2_grpc.py | 2 +- tests/test_batcher.py | 4 +++- 8 files changed, 25 insertions(+), 27 deletions(-) diff --git a/fila/async_client.py b/fila/async_client.py index a1f2962..c10c771 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -7,21 +7,23 @@ import grpc import grpc.aio -if TYPE_CHECKING: - 
from collections.abc import AsyncIterator - from fila.client import _proto_enqueue_result_to_sdk, _proto_msg_to_consume_message from fila.errors import ( - EnqueueError, + MessageNotFoundError, + RPCError, _map_ack_error, _map_consume_error, _map_enqueue_error, _map_enqueue_result_error, _map_nack_error, ) -from fila.types import ConsumeMessage, EnqueueResult from fila.v1 import service_pb2, service_pb2_grpc +if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from fila.types import ConsumeMessage, EnqueueResult + class _AsyncClientCallDetails( grpc.aio.ClientCallDetails, # type: ignore[misc] @@ -396,12 +398,10 @@ async def ack(self, queue: str, msg_id: str) -> None: result = resp.results[0] which = result.WhichOneof("result") if which == "error": - from fila.errors import MessageNotFoundError, RPCError as _RPCError - ack_err = result.error if ack_err.code == service_pb2.ACK_ERROR_CODE_MESSAGE_NOT_FOUND: raise MessageNotFoundError(f"ack: {ack_err.message}") - raise _RPCError(grpc.StatusCode.INTERNAL, f"ack: {ack_err.message}") + raise RPCError(grpc.StatusCode.INTERNAL, f"ack: {ack_err.message}") async def nack(self, queue: str, msg_id: str, error: str) -> None: """Negatively acknowledge a message that failed processing. 
@@ -436,9 +436,7 @@ async def nack(self, queue: str, msg_id: str, error: str) -> None: result = resp.results[0] which = result.WhichOneof("result") if which == "error": - from fila.errors import MessageNotFoundError, RPCError as _RPCError - nack_err = result.error if nack_err.code == service_pb2.NACK_ERROR_CODE_MESSAGE_NOT_FOUND: raise MessageNotFoundError(f"nack: {nack_err.message}") - raise _RPCError(grpc.StatusCode.INTERNAL, f"nack: {nack_err.message}") + raise RPCError(grpc.StatusCode.INTERNAL, f"nack: {nack_err.message}") diff --git a/fila/batcher.py b/fila/batcher.py index 1bf2994..fc6a5b4 100644 --- a/fila/batcher.py +++ b/fila/batcher.py @@ -5,7 +5,7 @@ import queue import threading from concurrent.futures import Future, ThreadPoolExecutor -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import grpc diff --git a/fila/client.py b/fila/client.py index 2f44bb6..dafc550 100644 --- a/fila/client.py +++ b/fila/client.py @@ -8,7 +8,8 @@ from fila.batcher import AutoAccumulator, LingerAccumulator from fila.errors import ( - EnqueueError, + MessageNotFoundError, + RPCError, _map_ack_error, _map_consume_error, _map_enqueue_error, @@ -452,12 +453,10 @@ def ack(self, queue: str, msg_id: str) -> None: result = resp.results[0] which = result.WhichOneof("result") if which == "error": - from fila.errors import MessageNotFoundError, RPCError as _RPCError - ack_err = result.error if ack_err.code == service_pb2.ACK_ERROR_CODE_MESSAGE_NOT_FOUND: raise MessageNotFoundError(f"ack: {ack_err.message}") - raise _RPCError(grpc.StatusCode.INTERNAL, f"ack: {ack_err.message}") + raise RPCError(grpc.StatusCode.INTERNAL, f"ack: {ack_err.message}") def nack(self, queue: str, msg_id: str, error: str) -> None: """Negatively acknowledge a message that failed processing. 
@@ -492,9 +491,7 @@ def nack(self, queue: str, msg_id: str, error: str) -> None: result = resp.results[0] which = result.WhichOneof("result") if which == "error": - from fila.errors import MessageNotFoundError, RPCError as _RPCError - nack_err = result.error if nack_err.code == service_pb2.NACK_ERROR_CODE_MESSAGE_NOT_FOUND: raise MessageNotFoundError(f"nack: {nack_err.message}") - raise _RPCError(grpc.StatusCode.INTERNAL, f"nack: {nack_err.message}") + raise RPCError(grpc.StatusCode.INTERNAL, f"nack: {nack_err.message}") diff --git a/fila/errors.py b/fila/errors.py index 819a197..00890f2 100644 --- a/fila/errors.py +++ b/fila/errors.py @@ -27,11 +27,12 @@ def __init__(self, code: grpc.StatusCode, message: str) -> None: class EnqueueError(FilaError): - """Raised when an enqueue fails at the RPC level. + """Raised when an enqueue operation fails. - Individual per-message failures are reported via ``EnqueueResult.error`` - and do not raise this exception. This is raised only when the entire - RPC fails (e.g., network error, server unavailable). + In ``enqueue_many()``, individual per-message failures are reported via + ``EnqueueResult.error`` and do not raise this exception. It is also used + as a fallback for per-message enqueue failures that do not map to a more + specific type (e.g., storage or Lua errors). 
""" diff --git a/fila/v1/admin_pb2_grpc.py b/fila/v1/admin_pb2_grpc.py index 70d8fbe..3b07e1a 100644 --- a/fila/v1/admin_pb2_grpc.py +++ b/fila/v1/admin_pb2_grpc.py @@ -5,7 +5,7 @@ from fila.v1 import admin_pb2 as fila_dot_v1_dot_admin__pb2 -GRPC_GENERATED_VERSION = '1.78.1' +GRPC_GENERATED_VERSION = '1.78.0' GRPC_VERSION = grpc.__version__ _version_not_supported = False diff --git a/fila/v1/messages_pb2_grpc.py b/fila/v1/messages_pb2_grpc.py index d27d27c..fa0dc71 100644 --- a/fila/v1/messages_pb2_grpc.py +++ b/fila/v1/messages_pb2_grpc.py @@ -4,7 +4,7 @@ import warnings -GRPC_GENERATED_VERSION = '1.78.1' +GRPC_GENERATED_VERSION = '1.78.0' GRPC_VERSION = grpc.__version__ _version_not_supported = False diff --git a/fila/v1/service_pb2_grpc.py b/fila/v1/service_pb2_grpc.py index 1f2df55..fa3f3fd 100644 --- a/fila/v1/service_pb2_grpc.py +++ b/fila/v1/service_pb2_grpc.py @@ -5,7 +5,7 @@ from fila.v1 import service_pb2 as fila_dot_v1_dot_service__pb2 -GRPC_GENERATED_VERSION = '1.78.1' +GRPC_GENERATED_VERSION = '1.78.0' GRPC_VERSION = grpc.__version__ _version_not_supported = False diff --git a/tests/test_batcher.py b/tests/test_batcher.py index 3489a42..dfd5919 100644 --- a/tests/test_batcher.py +++ b/tests/test_batcher.py @@ -5,7 +5,6 @@ from __future__ import annotations -import threading from concurrent.futures import Future from typing import Any from unittest.mock import MagicMock @@ -65,6 +64,9 @@ def test_success(self) -> None: assert fut.result(timeout=1.0) == "msg-001" stub.Enqueue.assert_called_once() + sent_req = stub.Enqueue.call_args.args[0] + assert len(sent_req.messages) == 1 + assert sent_req.messages[0] == proto def test_rpc_error(self) -> None: import grpc From 75d7246e3db5e82918bd43d6538023e959828bae Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Sat, 4 Apr 2026 09:09:29 -0300 Subject: [PATCH 10/17] feat: migrate python sdk from grpc to fibp binary protocol replace the entire grpc transport layer with a native fibp implementation. 
the sdk now communicates directly over tcp using the fila binary protocol, removing the grpcio and protobuf dependencies entirely. - add fila/fibp/ module: primitives (Reader/Writer), opcodes, codec - add fila/conn.py: sync Connection and async AsyncConnection classes - rewrite client.py and async_client.py to use fibp connections - rewrite batcher.py to use Connection instead of grpc stubs - rewrite errors.py with fibp error code mapping (18 error codes) - add admin methods: create/delete queue, stats, config, redrive - add auth methods: create/revoke api key, acl management - add new error types: UnauthorizedError, ForbiddenError, NotLeaderError, etc. - add ConsumeMessage fields: weight, throttle_keys, enqueued_at, leased_at - delete fila/v1/ (generated protobuf) and proto/ directories - remove grpcio/protobuf from dependencies, bump version to 0.3.0 - add test_fibp.py with 34 codec/primitives unit tests - rewrite test_batcher.py for fibp mock connections - update integration tests and conftest.py for fibp --- README.md | 29 +- fila/__init__.py | 41 +- fila/async_client.py | 656 +++++++++++++++-------------- fila/batcher.py | 150 +++---- fila/client.py | 675 +++++++++++++++--------------- fila/conn.py | 397 ++++++++++++++++++ fila/errors.py | 186 ++++++-- fila/fibp/__init__.py | 82 ++++ fila/fibp/codec.py | 412 ++++++++++++++++++ fila/fibp/opcodes.py | 120 ++++++ fila/fibp/primitives.py | 167 ++++++++ fila/types.py | 36 ++ fila/v1/__init__.py | 0 fila/v1/admin_pb2.py | 78 ---- fila/v1/admin_pb2.pyi | 179 -------- fila/v1/admin_pb2_grpc.py | 401 ------------------ fila/v1/messages_pb2.py | 45 -- fila/v1/messages_pb2.pyi | 53 --- fila/v1/messages_pb2_grpc.py | 24 -- fila/v1/service_pb2.py | 89 ---- fila/v1/service_pb2.pyi | 199 --------- fila/v1/service_pb2_grpc.py | 272 ------------ proto/fila/v1/admin.proto | 119 ------ proto/fila/v1/messages.proto | 28 -- proto/fila/v1/service.proto | 142 ------- pyproject.toml | 21 +- tests/conftest.py | 117 ++---- 
tests/test_batcher.py | 279 +++++------- tests/test_client.py | 60 ++- tests/test_enqueue_integration.py | 10 +- tests/test_fibp.py | 289 +++++++++++++ 31 files changed, 2644 insertions(+), 2712 deletions(-) create mode 100644 fila/conn.py create mode 100644 fila/fibp/__init__.py create mode 100644 fila/fibp/codec.py create mode 100644 fila/fibp/opcodes.py create mode 100644 fila/fibp/primitives.py delete mode 100644 fila/v1/__init__.py delete mode 100644 fila/v1/admin_pb2.py delete mode 100644 fila/v1/admin_pb2.pyi delete mode 100644 fila/v1/admin_pb2_grpc.py delete mode 100644 fila/v1/messages_pb2.py delete mode 100644 fila/v1/messages_pb2.pyi delete mode 100644 fila/v1/messages_pb2_grpc.py delete mode 100644 fila/v1/service_pb2.py delete mode 100644 fila/v1/service_pb2.pyi delete mode 100644 fila/v1/service_pb2_grpc.py delete mode 100644 proto/fila/v1/admin.proto delete mode 100644 proto/fila/v1/messages.proto delete mode 100644 proto/fila/v1/service.proto create mode 100644 tests/test_fibp.py diff --git a/README.md b/README.md index ad70043..3875887 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,8 @@ Python client SDK for the [Fila](https://github.com/faiscadev/fila) message broker. +Communicates with the Fila server over the FIBP (Fila Binary Protocol) on port 5555. + ## Installation ```bash @@ -102,7 +104,7 @@ client = Client( ) ``` -The API key is sent as `authorization: Bearer ` metadata on every RPC. +The API key is sent during the FIBP handshake. ## API @@ -114,6 +116,10 @@ Connect to a Fila broker. Both support context manager protocol. Enqueue a message. Returns the broker-assigned message ID. +### `client.enqueue_many(messages) -> list[EnqueueResult]` + +Enqueue multiple messages in a single request. Returns per-message results. + ### `client.consume(queue) -> Iterator[ConsumeMessage]` Open a streaming consumer. Returns an iterator (sync) or async iterator (async) that yields messages as they become available. 
@@ -126,12 +132,31 @@ Acknowledge a successfully processed message. The message is permanently removed Negatively acknowledge a failed message. The message is requeued or routed to the dead-letter queue based on the queue's configuration. +### Admin Methods + +- `client.create_queue(name, config=None)` -- Create a queue +- `client.delete_queue(name)` -- Delete a queue +- `client.get_stats(queue) -> StatsResult` -- Get queue statistics +- `client.list_queues() -> list[str]` -- List all queues +- `client.set_config(queue, config)` -- Set queue configuration +- `client.get_config(queue) -> dict` -- Get queue configuration +- `client.list_config(queue) -> dict` -- List queue configuration +- `client.redrive(source, dest, count)` -- Redrive messages between queues + +### Auth Methods + +- `client.create_api_key(name) -> CreateApiKeyResult` -- Create an API key +- `client.revoke_api_key(key_id)` -- Revoke an API key +- `client.list_api_keys() -> list[ApiKeyInfo]` -- List API keys +- `client.set_acl(key_id, patterns, superadmin=False)` -- Set ACL +- `client.get_acl(key_id) -> AclEntry` -- Get ACL + ## Error Handling Per-operation exception classes: ```python -from fila import QueueNotFoundError, MessageNotFoundError +from fila import QueueNotFoundError, MessageNotFoundError, UnauthorizedError try: client.enqueue("missing-queue", None, b"test") diff --git a/fila/__init__.py b/fila/__init__.py index 9117c96..5ee1599 100644 --- a/fila/__init__.py +++ b/fila/__init__.py @@ -3,24 +3,63 @@ from fila.async_client import AsyncClient from fila.client import Client from fila.errors import ( + AclNotFoundError, + ApiKeyNotFoundError, + ChannelFullError, EnqueueError, FilaError, + ForbiddenError, + InvalidArgumentError, + LuaError, MessageNotFoundError, + NotLeaderError, + PermissionDeniedError, + ProtocolError, + QueueAlreadyExistsError, QueueNotFoundError, + ResourceExhaustedError, RPCError, + UnauthorizedError, + UnavailableError, +) +from fila.types import ( + AccumulatorMode, + 
AclEntry, + ApiKeyInfo, + ConsumeMessage, + CreateApiKeyResult, + EnqueueResult, + Linger, + StatsResult, ) -from fila.types import AccumulatorMode, ConsumeMessage, EnqueueResult, Linger __all__ = [ "AccumulatorMode", + "AclEntry", + "AclNotFoundError", + "ApiKeyInfo", + "ApiKeyNotFoundError", "AsyncClient", + "ChannelFullError", "Client", "ConsumeMessage", + "CreateApiKeyResult", "EnqueueError", "EnqueueResult", "FilaError", + "ForbiddenError", + "InvalidArgumentError", "Linger", + "LuaError", "MessageNotFoundError", + "NotLeaderError", + "PermissionDeniedError", + "ProtocolError", + "QueueAlreadyExistsError", "QueueNotFoundError", "RPCError", + "ResourceExhaustedError", + "StatsResult", + "UnauthorizedError", + "UnavailableError", ] diff --git a/fila/async_client.py b/fila/async_client.py index c10c771..08b0809 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -2,138 +2,75 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any +import ssl +from typing import TYPE_CHECKING -import grpc -import grpc.aio - -from fila.client import _proto_enqueue_result_to_sdk, _proto_msg_to_consume_message +from fila.conn import AsyncConnection from fila.errors import ( - MessageNotFoundError, - RPCError, - _map_ack_error, - _map_consume_error, - _map_enqueue_error, - _map_enqueue_result_error, - _map_nack_error, + NotLeaderError, + _map_per_item_error, + _raise_from_error_frame, +) +from fila.fibp.codec import ( + decode_ack_result, + decode_create_api_key_result, + decode_delivery, + decode_enqueue_result, + decode_error, + decode_get_acl_result, + decode_get_config_result, + decode_get_stats_result, + decode_list_api_keys_result, + decode_list_config_result, + decode_list_queues_result, + decode_nack_result, + encode_ack, + encode_create_api_key, + encode_create_queue, + encode_delete_queue, + encode_enqueue, + encode_get_acl, + encode_get_config, + encode_get_stats, + encode_list_api_keys, + encode_list_config, + encode_list_queues, + 
encode_nack, + encode_redrive, + encode_revoke_api_key, + encode_set_acl, + encode_set_config, +) +from fila.fibp.opcodes import ErrorCode, Opcode +from fila.types import ( + AclEntry, + ApiKeyInfo, + ConsumeMessage, + CreateApiKeyResult, + EnqueueResult, + StatsResult, ) -from fila.v1 import service_pb2, service_pb2_grpc if TYPE_CHECKING: from collections.abc import AsyncIterator - from fila.types import ConsumeMessage, EnqueueResult - - -class _AsyncClientCallDetails( - grpc.aio.ClientCallDetails, # type: ignore[misc] -): - """Concrete ``ClientCallDetails`` for the async interceptor chain. - - ``grpc.aio.ClientCallDetails`` is a namedtuple with 5 fields (method, - timeout, metadata, credentials, wait_for_ready). We override ``__new__`` - so the namedtuple layer receives exactly those five, then set any extra - attribute (``compression``) in ``__init__``. - """ - - def __new__( - cls, - method: str, - timeout: float | None, - metadata: grpc.aio.Metadata | None, - credentials: grpc.CallCredentials | None, - wait_for_ready: bool | None, - ) -> _AsyncClientCallDetails: - return super().__new__(cls, method, timeout, metadata, credentials, wait_for_ready) # type: ignore[no-any-return] - - def __init__( - self, - method: str, - timeout: float | None, - metadata: grpc.aio.Metadata | None, - credentials: grpc.CallCredentials | None, - wait_for_ready: bool | None, - ) -> None: - # Fields are already set by __new__ (namedtuple). Nothing extra to do. 
- pass - - -class _AsyncApiKeyInterceptor( - grpc.aio.UnaryUnaryClientInterceptor, # type: ignore[misc] - grpc.aio.UnaryStreamClientInterceptor, # type: ignore[misc] -): - """Injects ``authorization: Bearer `` metadata into every async RPC.""" - - def __init__(self, api_key: str) -> None: - self._metadata = grpc.aio.Metadata(("authorization", f"Bearer {api_key}")) - - def _inject( - self, metadata: grpc.aio.Metadata | None - ) -> grpc.aio.Metadata: - merged = grpc.aio.Metadata() - if metadata is not None: - for key, value in metadata: - merged.add(key, value) - for key, value in self._metadata: - merged.add(key, value) - return merged - - async def intercept_unary_unary( - self, - continuation: Any, - client_call_details: grpc.aio.ClientCallDetails, - request: Any, - ) -> Any: - new_details = _AsyncClientCallDetails( - client_call_details.method, - client_call_details.timeout, - self._inject(client_call_details.metadata), - client_call_details.credentials, - client_call_details.wait_for_ready, - ) - return await continuation(new_details, request) - - async def intercept_unary_stream( - self, - continuation: Any, - client_call_details: grpc.aio.ClientCallDetails, - request: Any, - ) -> Any: - new_details = _AsyncClientCallDetails( - client_call_details.method, - client_call_details.timeout, - self._inject(client_call_details.metadata), - client_call_details.credentials, - client_call_details.wait_for_ready, - ) - return await continuation(new_details, request) - -_LEADER_HINT_KEY = "x-fila-leader-addr" - - -def _extract_leader_hint(err: grpc.RpcError) -> str | None: - """Return the leader address from trailing metadata, if present.""" - if err.code() != grpc.StatusCode.UNAVAILABLE: - return None - trailing = err.trailing_metadata() - if trailing is None: - return None - for key, value in trailing: - if key == _LEADER_HINT_KEY: - return str(value) - return None +def _parse_addr(addr: str) -> tuple[str, int]: + """Parse 'host:port' into (host, port).""" + if ":" not 
in addr: + raise ValueError(f"invalid address (expected host:port): {addr}") + host, port_str = addr.rsplit(":", 1) + return host, int(port_str) class AsyncClient: """Asynchronous client for the Fila message broker. - Wraps the hot-path gRPC operations: enqueue, enqueue_many, consume, ack, - nack. + Wraps the hot-path FIBP operations: enqueue, enqueue_many, consume, ack, nack. Usage:: - client = AsyncClient("localhost:5555") + client = await AsyncClient.create("localhost:5555") msg_id = await client.enqueue("my-queue", {"tenant": "acme"}, b"hello") async for msg in await client.consume("my-queue"): await client.ack("my-queue", msg.id) @@ -175,25 +112,13 @@ def __init__( client_key: bytes | None = None, api_key: str | None = None, ) -> None: - """Connect to a Fila broker at the given address. - - Args: - addr: Broker address in "host:port" format (e.g., "localhost:5555"). - tls: Enable TLS using the OS system trust store for server - verification. Ignored when ``ca_cert`` is provided (which - implies TLS). Defaults to ``False``. - ca_cert: PEM-encoded CA certificate for verifying the server. - When provided, a TLS channel is used instead of an insecure one. - client_cert: PEM-encoded client certificate for mutual TLS (optional). - client_key: PEM-encoded client private key for mutual TLS (optional). - api_key: API key for authentication. When set, every RPC includes an - ``authorization: Bearer `` metadata header. 
- """ + self._addr = addr self._tls = tls self._ca_cert = ca_cert self._client_cert = client_cert self._client_key = client_key self._api_key = api_key + self._conn: AsyncConnection | None = None use_tls = tls or ca_cert is not None if (client_cert is not None or client_key is not None) and not use_tls: @@ -201,40 +126,67 @@ def __init__( "client_cert and client_key require ca_cert or tls=True to establish a TLS channel" ) - self._channel = self._make_channel(addr) - self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + self._ssl_ctx = self._make_ssl_context() if use_tls else None + + def _make_ssl_context(self) -> ssl.SSLContext: + """Create an SSL context from stored credentials.""" + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + if self._ca_cert is not None: + ctx.load_verify_locations(cadata=self._ca_cert.decode("ascii")) + else: + ctx.load_default_certs() + if self._client_cert is not None and self._client_key is not None: + import os + import tempfile + + cert_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pem") + key_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pem") + try: + cert_file.write(self._client_cert) + cert_file.close() + key_file.write(self._client_key) + key_file.close() + ctx.load_cert_chain(cert_file.name, key_file.name) + finally: + os.unlink(cert_file.name) + os.unlink(key_file.name) + return ctx + + async def _ensure_connected(self) -> AsyncConnection: + """Ensure a connection exists, creating one if needed.""" + if self._conn is None: + host, port = _parse_addr(self._addr) + self._conn = await AsyncConnection.connect( + host, port, ssl_context=self._ssl_ctx, api_key=self._api_key + ) + return self._conn - def _make_channel(self, addr: str) -> grpc.aio.Channel: - """Create an async gRPC channel to the given address using stored credentials.""" - use_tls = self._tls or self._ca_cert is not None + async def _reconnect(self, addr: str) -> None: + """Reconnect to a different 
address (e.g. after leader hint).""" + if self._conn is not None: + import contextlib - interceptors: list[grpc.aio.ClientInterceptor] = [] - if self._api_key is not None: - interceptors.append(_AsyncApiKeyInterceptor(self._api_key)) - - if use_tls: - creds = grpc.ssl_channel_credentials( - root_certificates=self._ca_cert, - private_key=self._client_key, - certificate_chain=self._client_cert, - ) - return grpc.aio.secure_channel( - addr, creds, interceptors=interceptors or None - ) - return grpc.aio.insecure_channel( - addr, interceptors=interceptors or None - ) + with contextlib.suppress(OSError): + await self._conn.close() + self._addr = addr + self._conn = None + await self._ensure_connected() async def close(self) -> None: - """Close the underlying gRPC channel.""" - await self._channel.close() + """Close the underlying connection.""" + if self._conn is not None: + await self._conn.close() + self._conn = None async def __aenter__(self) -> AsyncClient: + await self._ensure_connected() return self async def __aexit__(self, *args: object) -> None: await self.close() + # -- hot-path operations ------------------------------------------------- + async def enqueue( self, queue: str, @@ -243,200 +195,270 @@ async def enqueue( ) -> str: """Enqueue a message to the specified queue. - Args: - queue: Target queue name. - headers: Optional message headers. - payload: Message payload bytes. - Returns: Broker-assigned message ID (UUIDv7). - - Raises: - QueueNotFoundError: If the queue does not exist. - RPCError: For unexpected gRPC failures. 
""" - try: - resp = await self._stub.Enqueue( - service_pb2.EnqueueRequest( - messages=[ - service_pb2.EnqueueMessage( - queue=queue, - headers=headers or {}, - payload=payload, - ) - ] - ) - ) - except grpc.RpcError as e: - raise _map_enqueue_error(e) from e + await self._ensure_connected() + msgs = [{"queue": queue, "headers": headers or {}, "payload": payload}] + body = encode_enqueue(msgs) + + header, resp_body = await self._request_with_leader_retry( + Opcode.ENQUEUE, body + ) + + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) - result = resp.results[0] - which = result.WhichOneof("result") - if which == "message_id": - return str(result.message_id) - raise _map_enqueue_result_error(result.error.code, result.error.message) + items = decode_enqueue_result(resp_body) + item = items[0] + if item.error_code == ErrorCode.OK: + return item.message_id + raise _map_per_item_error(item.error_code, "enqueue") async def enqueue_many( self, messages: list[tuple[str, dict[str, str] | None, bytes]], ) -> list[EnqueueResult]: - """Enqueue multiple messages in a single RPC. - - Args: - messages: List of (queue, headers, payload) tuples. - - Returns: - List of ``EnqueueResult`` objects, one per input message. - Each result has either a ``message_id`` (success) or ``error`` - (per-message failure). - - Raises: - QueueNotFoundError: If a referenced queue does not exist. - RPCError: For unexpected gRPC failures. 
- """ - proto_messages = [ - service_pb2.EnqueueMessage( - queue=q, - headers=h or {}, - payload=p, - ) + """Enqueue multiple messages in a single request.""" + await self._ensure_connected() + msgs = [ + {"queue": q, "headers": h or {}, "payload": p} for q, h, p in messages ] + body = encode_enqueue(msgs) - try: - resp = await self._stub.Enqueue( - service_pb2.EnqueueRequest(messages=proto_messages) - ) - except grpc.RpcError as e: - raise _map_enqueue_error(e) from e + header, resp_body = await self._request_with_leader_retry( + Opcode.ENQUEUE, body + ) - return [_proto_enqueue_result_to_sdk(r) for r in resp.results] + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + items = decode_enqueue_result(resp_body) + results: list[EnqueueResult] = [] + for item in items: + if item.error_code == ErrorCode.OK: + results.append(EnqueueResult(message_id=item.message_id, error=None)) + else: + results.append( + EnqueueResult(message_id=None, error=f"error 0x{item.error_code:02x}") + ) + return results async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]: """Open a streaming consumer on the specified queue. - Yields messages as they become available. The iterator ends when the - server stream closes or an error occurs. Nil message frames (keepalive - signals) are skipped automatically. - - If the server returns UNAVAILABLE with an ``x-fila-leader-addr`` - trailing metadata entry, the client transparently reconnects to the - leader address and retries the consume call once. - - Args: - queue: Queue to consume from. - - Yields: - ConsumeMessage objects as they arrive. - - Raises: - QueueNotFoundError: If the queue does not exist. - RPCError: For unexpected gRPC failures. + Returns an async iterator that yields messages as they arrive. 
""" + conn = await self._ensure_connected() try: - stream = self._stub.Consume( - service_pb2.ConsumeRequest(queue=queue) - ) - except grpc.RpcError as e: - leader_addr = _extract_leader_hint(e) - if leader_addr is not None: - stream = await self._reconnect_and_consume(leader_addr, queue) + _req_id, consumer_id = await conn.subscribe(queue) + except NotLeaderError as e: + if e.leader_addr is not None: + await self._reconnect(e.leader_addr) + conn = await self._ensure_connected() + _req_id, consumer_id = await conn.subscribe(queue) else: - raise _map_consume_error(e) from e - - return self._consume_iter(stream) + raise - async def _reconnect_and_consume(self, leader_addr: str, queue: str) -> Any: - """Create a new channel to *leader_addr* and retry the consume call.""" - await self._channel.close() - self._channel = self._make_channel(leader_addr) - self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] - try: - return self._stub.Consume( - service_pb2.ConsumeRequest(queue=queue) - ) - except grpc.RpcError as e: - raise _map_consume_error(e) from e + return self._consume_iter(conn, consumer_id) async def _consume_iter( - self, - stream: Any, + self, conn: AsyncConnection, consumer_id: str ) -> AsyncIterator[ConsumeMessage]: - """Internal async generator reading from the gRPC stream.""" + """Internal async generator reading Delivery frames.""" try: - async for resp in stream: - for msg in resp.messages: - if msg is not None and msg.ByteSize(): - yield _proto_msg_to_consume_message(msg) - except grpc.RpcError: + while True: + header, body = await conn.read_frame() + + if header.opcode == Opcode.DELIVERY: + for msg in decode_delivery(body): + yield ConsumeMessage( + id=msg.message_id, + queue=msg.queue, + headers=msg.headers, + payload=msg.payload, + fairness_key=msg.fairness_key, + attempt_count=msg.attempt_count, + weight=msg.weight, + throttle_keys=msg.throttle_keys, + enqueued_at=msg.enqueued_at, + leased_at=msg.leased_at, + ) 
+ elif header.opcode == Opcode.ERROR: + err = decode_error(body) + _raise_from_error_frame(err) + except (ConnectionError, OSError): return async def ack(self, queue: str, msg_id: str) -> None: - """Acknowledge a successfully processed message. - - The message is permanently removed from the queue. + """Acknowledge a successfully processed message.""" + body = encode_ack([{"queue": queue, "message_id": msg_id}]) + header, resp_body = await self._request_with_leader_retry(Opcode.ACK, body) - Args: - queue: Queue the message belongs to. - msg_id: ID of the message to acknowledge. + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) - Raises: - MessageNotFoundError: If the message does not exist. - RPCError: For unexpected gRPC failures. - """ - try: - resp = await self._stub.Ack( - service_pb2.AckRequest( - messages=[service_pb2.AckMessage(queue=queue, message_id=msg_id)] - ) - ) - except grpc.RpcError as e: - raise _map_ack_error(e) from e - - # Check per-message result for errors. - if resp.results: - result = resp.results[0] - which = result.WhichOneof("result") - if which == "error": - ack_err = result.error - if ack_err.code == service_pb2.ACK_ERROR_CODE_MESSAGE_NOT_FOUND: - raise MessageNotFoundError(f"ack: {ack_err.message}") - raise RPCError(grpc.StatusCode.INTERNAL, f"ack: {ack_err.message}") + codes = decode_ack_result(resp_body) + if codes and codes[0] != ErrorCode.OK: + raise _map_per_item_error(codes[0], "ack") async def nack(self, queue: str, msg_id: str, error: str) -> None: - """Negatively acknowledge a message that failed processing. - - The message is requeued for retry or routed to the dead-letter queue - based on the queue's on_failure Lua hook configuration. - - Args: - queue: Queue the message belongs to. - msg_id: ID of the message to nack. - error: Description of the failure. 
+ """Negatively acknowledge a message that failed processing.""" + body = encode_nack([{"queue": queue, "message_id": msg_id, "error": error}]) + header, resp_body = await self._request_with_leader_retry(Opcode.NACK, body) + + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + codes = decode_nack_result(resp_body) + if codes and codes[0] != ErrorCode.OK: + raise _map_per_item_error(codes[0], "nack") + + # -- admin operations ---------------------------------------------------- + + async def create_queue(self, name: str, config: dict[str, str] | None = None) -> None: + """Create a queue on the broker.""" + body = encode_create_queue(name, config) + header, resp_body = await self._request_with_leader_retry(Opcode.CREATE_QUEUE, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + async def delete_queue(self, name: str) -> None: + """Delete a queue from the broker.""" + body = encode_delete_queue(name) + header, resp_body = await self._request_with_leader_retry(Opcode.DELETE_QUEUE, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + async def get_stats(self, queue: str) -> StatsResult: + """Get statistics for a queue.""" + body = encode_get_stats(queue) + header, resp_body = await self._request_with_leader_retry(Opcode.GET_STATS, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + result = decode_get_stats_result(resp_body) + return StatsResult(stats=result.stats) + + async def list_queues(self) -> list[str]: + """List all queues on the broker.""" + body = encode_list_queues() + header, resp_body = await self._request_with_leader_retry(Opcode.LIST_QUEUES, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + return decode_list_queues_result(resp_body) + + async def set_config(self, queue: str, config: 
dict[str, str]) -> None: + """Set configuration for a queue.""" + body = encode_set_config(queue, config) + header, resp_body = await self._request_with_leader_retry(Opcode.SET_CONFIG, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + async def get_config(self, queue: str) -> dict[str, str]: + """Get configuration for a queue.""" + body = encode_get_config(queue) + header, resp_body = await self._request_with_leader_retry(Opcode.GET_CONFIG, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + return decode_get_config_result(resp_body) + + async def list_config(self, queue: str) -> dict[str, str]: + """List all configuration for a queue.""" + body = encode_list_config(queue) + header, resp_body = await self._request_with_leader_retry(Opcode.LIST_CONFIG, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + return decode_list_config_result(resp_body) + + async def redrive(self, source_queue: str, dest_queue: str, count: int) -> None: + """Redrive messages from one queue to another.""" + body = encode_redrive(source_queue, dest_queue, count) + header, resp_body = await self._request_with_leader_retry(Opcode.REDRIVE, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + # -- auth operations ----------------------------------------------------- + + async def create_api_key(self, name: str) -> CreateApiKeyResult: + """Create a new API key.""" + body = encode_create_api_key(name) + header, resp_body = await self._request_with_leader_retry(Opcode.CREATE_API_KEY, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + key_id, raw_key = decode_create_api_key_result(resp_body) + return CreateApiKeyResult(key_id=key_id, raw_key=raw_key) + + async def revoke_api_key(self, key_id: str) -> None: + """Revoke an 
API key.""" + body = encode_revoke_api_key(key_id) + header, resp_body = await self._request_with_leader_retry(Opcode.REVOKE_API_KEY, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + async def list_api_keys(self) -> list[ApiKeyInfo]: + """List all API keys.""" + body = encode_list_api_keys() + header, resp_body = await self._request_with_leader_retry(Opcode.LIST_API_KEYS, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + items = decode_list_api_keys_result(resp_body) + return [ + ApiKeyInfo(key_id=k.key_id, prefix=k.prefix, created_at=k.created_at) + for k in items + ] - Raises: - MessageNotFoundError: If the message does not exist. - RPCError: For unexpected gRPC failures. - """ - try: - resp = await self._stub.Nack( - service_pb2.NackRequest( - messages=[ - service_pb2.NackMessage( - queue=queue, message_id=msg_id, error=error - ) - ] - ) - ) - except grpc.RpcError as e: - raise _map_nack_error(e) from e - - # Check per-message result for errors. 
- if resp.results: - result = resp.results[0] - which = result.WhichOneof("result") - if which == "error": - nack_err = result.error - if nack_err.code == service_pb2.NACK_ERROR_CODE_MESSAGE_NOT_FOUND: - raise MessageNotFoundError(f"nack: {nack_err.message}") - raise RPCError(grpc.StatusCode.INTERNAL, f"nack: {nack_err.message}") + async def set_acl( + self, key_id: str, patterns: list[str], superadmin: bool = False + ) -> None: + """Set ACL for an API key.""" + body = encode_set_acl(key_id, patterns, superadmin) + header, resp_body = await self._request_with_leader_retry(Opcode.SET_ACL, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + async def get_acl(self, key_id: str) -> AclEntry: + """Get ACL for an API key.""" + body = encode_get_acl(key_id) + header, resp_body = await self._request_with_leader_retry(Opcode.GET_ACL, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + result = decode_get_acl_result(resp_body) + return AclEntry(patterns=result.patterns, superadmin=result.superadmin) + + # -- internal helpers ---------------------------------------------------- + + async def _request_with_leader_retry( + self, opcode: int, body: bytes + ) -> tuple[object, bytes]: + """Send a request, retrying once on NotLeader with leader hint.""" + conn = await self._ensure_connected() + header, resp_body = await conn.request(opcode, body) + + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + if err.code == ErrorCode.NOT_LEADER: + leader_addr = err.metadata.get("leader_addr") + if leader_addr: + await self._reconnect(leader_addr) + conn = await self._ensure_connected() + return await conn.request(opcode, body) + + return header, resp_body diff --git a/fila/batcher.py b/fila/batcher.py index fc6a5b4..32d45d1 100644 --- a/fila/batcher.py +++ b/fila/batcher.py @@ -7,13 +7,12 @@ from concurrent.futures import Future, ThreadPoolExecutor from typing 
import TYPE_CHECKING -import grpc - -from fila.errors import EnqueueError, _map_enqueue_error, _map_enqueue_result_error -from fila.v1 import service_pb2 +from fila.errors import EnqueueError, _map_per_item_error +from fila.fibp.codec import decode_enqueue_result, decode_error, encode_enqueue +from fila.fibp.opcodes import ErrorCode, Opcode if TYPE_CHECKING: - from fila.v1 import service_pb2_grpc + from fila.conn import Connection # Sentinel that signals the accumulator thread to stop. @@ -24,84 +23,84 @@ class _EnqueueItem: - """Internal envelope pairing a proto EnqueueMessage with its result future.""" + """Internal envelope pairing a message dict with its result future.""" - __slots__ = ("proto", "future") + __slots__ = ("msg", "future") def __init__( self, - proto: service_pb2.EnqueueMessage, + msg: dict[str, object], future: Future[str], ) -> None: - self.proto = proto + self.msg = msg self.future = future def _flush_single( - stub: service_pb2_grpc.FilaServiceStub, + conn: Connection, req: _EnqueueItem, ) -> None: - """Send a single message via the unified Enqueue RPC. - - This preserves the specific error types (QueueNotFoundError, etc.) - that callers of ``enqueue()`` expect. 
- """ + """Send a single message via the FIBP Enqueue request.""" try: - resp = stub.Enqueue( - service_pb2.EnqueueRequest(messages=[req.proto]) - ) - result = resp.results[0] - which = result.WhichOneof("result") - if which == "message_id": - req.future.set_result(str(result.message_id)) + body = encode_enqueue([req.msg]) + header, resp_body = conn.request(Opcode.ENQUEUE, body) + + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + from fila.errors import _raise_from_error_frame + try: + _raise_from_error_frame(err) + except Exception as e: + req.future.set_exception(e) + return + + items = decode_enqueue_result(resp_body) + item = items[0] + if item.error_code == ErrorCode.OK: + req.future.set_result(item.message_id) else: req.future.set_exception( - _map_enqueue_result_error(result.error.code, result.error.message) + _map_per_item_error(item.error_code, "enqueue") ) - except grpc.RpcError as e: - req.future.set_exception(_map_enqueue_error(e)) except Exception as e: req.future.set_exception(e) def _flush_many( - stub: service_pb2_grpc.FilaServiceStub, + conn: Connection, items: list[_EnqueueItem], ) -> None: - """Send multiple messages via the unified Enqueue RPC. - - On RPC-level failure, every future in the batch receives an - ``EnqueueError``. On success, each future gets either its - message ID or a per-message error string wrapped in an - ``EnqueueError``. 
- """ + """Send multiple messages via the FIBP Enqueue request.""" try: - resp = stub.Enqueue( - service_pb2.EnqueueRequest( - messages=[item.proto for item in items], - ) - ) - except grpc.RpcError as e: - err = EnqueueError(f"enqueue rpc failed: {e.details()}") + body = encode_enqueue([item.msg for item in items]) + header, resp_body = conn.request(Opcode.ENQUEUE, body) + except Exception as e: + err = EnqueueError(f"enqueue request failed: {e}") for item in items: item.future.set_exception(err) return - except Exception as e: + + if header.opcode == Opcode.ERROR: + try: + err_frame = decode_error(resp_body) + from fila.errors import _map_error_code + exc = _map_error_code(err_frame.code, err_frame.message) + except Exception as e: + exc = EnqueueError(f"enqueue failed: {e}") for item in items: - item.future.set_exception(e) + item.future.set_exception(exc) return - # Pair each result with its request future. - for i, result in enumerate(resp.results): + results = decode_enqueue_result(resp_body) + for i, result in enumerate(results): if i >= len(items): break item = items[i] - which = result.WhichOneof("result") - if which == "message_id": - item.future.set_result(str(result.message_id)) + if result.error_code == ErrorCode.OK: + item.future.set_result(result.message_id) else: item.future.set_exception( - _map_enqueue_result_error(result.error.code, result.error.message) + _map_per_item_error(result.error_code, "enqueue") ) @@ -110,46 +109,41 @@ class AutoAccumulator: A background daemon thread blocks on the first message, then non-blocking drains any additional messages that arrived during processing and flushes - them as a single Enqueue RPC via a thread pool executor. + them as a single Enqueue request via a thread pool executor. 
""" def __init__( self, - stub: service_pb2_grpc.FilaServiceStub, + conn: Connection, max_messages: int = _DEFAULT_MAX_MESSAGES, max_workers: int = 4, ) -> None: - self._stub = stub + self._conn = conn self._max_messages = max_messages self._queue: queue.Queue[_EnqueueItem | object] = queue.Queue() self._executor = ThreadPoolExecutor(max_workers=max_workers) self._thread = threading.Thread(target=self._run, daemon=True) self._thread.start() - def submit(self, proto: service_pb2.EnqueueMessage) -> Future[str]: + def submit(self, msg: dict[str, object]) -> Future[str]: """Submit a message for accumulated enqueue. Returns a Future for the message ID.""" fut: Future[str] = Future() - self._queue.put(_EnqueueItem(proto, fut)) + self._queue.put(_EnqueueItem(msg, fut)) return fut def close(self, timeout: float | None = 30.0) -> None: - """Drain pending messages and shut down the accumulator. - - Blocks until all pending messages have been flushed or *timeout* - seconds have elapsed. - """ + """Drain pending messages and shut down the accumulator.""" self._queue.put(_STOP) self._thread.join(timeout=timeout) self._executor.shutdown(wait=True) - def update_stub(self, stub: service_pb2_grpc.FilaServiceStub) -> None: - """Update the gRPC stub (e.g. after leader-hint reconnect).""" - self._stub = stub + def update_conn(self, conn: Connection) -> None: + """Update the connection (e.g. after leader-hint reconnect).""" + self._conn = conn def _run(self) -> None: """Background loop: block for first item, drain rest, flush.""" while True: - # Block until at least one item arrives. first = self._queue.get() if first is _STOP: return @@ -157,14 +151,12 @@ def _run(self) -> None: assert isinstance(first, _EnqueueItem) batch: list[_EnqueueItem] = [first] - # Non-blocking drain of any additional queued messages. while len(batch) < self._max_messages: try: item = self._queue.get_nowait() except queue.Empty: break if item is _STOP: - # Flush what we have, then stop. 
self._flush(batch) return assert isinstance(item, _EnqueueItem) @@ -173,12 +165,11 @@ def _run(self) -> None: self._flush(batch) def _flush(self, batch: list[_EnqueueItem]) -> None: - """Dispatch a batch to the executor for concurrent RPC.""" + """Dispatch a batch to the executor.""" if len(batch) == 1: - # Single-item optimization: still uses Enqueue but with one message. - self._executor.submit(_flush_single, self._stub, batch[0]) + self._executor.submit(_flush_single, self._conn, batch[0]) else: - self._executor.submit(_flush_many, self._stub, batch) + self._executor.submit(_flush_many, self._conn, batch) class LingerAccumulator: @@ -191,12 +182,12 @@ class LingerAccumulator: def __init__( self, - stub: service_pb2_grpc.FilaServiceStub, + conn: Connection, linger_ms: float, max_messages: int, max_workers: int = 4, ) -> None: - self._stub = stub + self._conn = conn self._linger_s = linger_ms / 1000.0 self._max_messages = max_messages self._queue: queue.Queue[_EnqueueItem | object] = queue.Queue() @@ -204,10 +195,10 @@ def __init__( self._thread = threading.Thread(target=self._run, daemon=True) self._thread.start() - def submit(self, proto: service_pb2.EnqueueMessage) -> Future[str]: + def submit(self, msg: dict[str, object]) -> Future[str]: """Submit a message for accumulated enqueue. Returns a Future for the message ID.""" fut: Future[str] = Future() - self._queue.put(_EnqueueItem(proto, fut)) + self._queue.put(_EnqueueItem(msg, fut)) return fut def close(self, timeout: float | None = 30.0) -> None: @@ -216,16 +207,15 @@ def close(self, timeout: float | None = 30.0) -> None: self._thread.join(timeout=timeout) self._executor.shutdown(wait=True) - def update_stub(self, stub: service_pb2_grpc.FilaServiceStub) -> None: - """Update the gRPC stub (e.g. after leader-hint reconnect).""" - self._stub = stub + def update_conn(self, conn: Connection) -> None: + """Update the connection (e.g. 
after leader-hint reconnect).""" + self._conn = conn def _run(self) -> None: """Background loop: accumulate up to max_messages or linger timeout.""" import time while True: - # Block until at least one item arrives. first = self._queue.get() if first is _STOP: return @@ -233,10 +223,8 @@ def _run(self) -> None: assert isinstance(first, _EnqueueItem) batch: list[_EnqueueItem] = [first] - # Track wall-clock deadline from when first message arrived. deadline = time.monotonic() + self._linger_s - # Accumulate more items until max_messages or linger timeout. while len(batch) < self._max_messages: remaining = deadline - time.monotonic() if remaining <= 0: @@ -254,8 +242,8 @@ def _run(self) -> None: self._flush(batch) def _flush(self, batch: list[_EnqueueItem]) -> None: - """Dispatch a batch to the executor for concurrent RPC.""" + """Dispatch a batch to the executor.""" if len(batch) == 1: - self._executor.submit(_flush_single, self._stub, batch[0]) + self._executor.submit(_flush_single, self._conn, batch[0]) else: - self._executor.submit(_flush_many, self._stub, batch) + self._executor.submit(_flush_many, self._conn, batch) diff --git a/fila/client.py b/fila/client.py index dafc550..7e193dd 100644 --- a/fila/client.py +++ b/fila/client.py @@ -2,138 +2,74 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any - -import grpc +import ssl +from typing import TYPE_CHECKING from fila.batcher import AutoAccumulator, LingerAccumulator +from fila.conn import Connection from fila.errors import ( - MessageNotFoundError, - RPCError, - _map_ack_error, - _map_consume_error, - _map_enqueue_error, - _map_enqueue_result_error, - _map_nack_error, + NotLeaderError, + _map_per_item_error, + _raise_from_error_frame, +) +from fila.fibp.codec import ( + decode_ack_result, + decode_create_api_key_result, + decode_delivery, + decode_enqueue_result, + decode_error, + decode_get_acl_result, + decode_get_config_result, + decode_get_stats_result, + decode_list_api_keys_result, 
+ decode_list_config_result, + decode_list_queues_result, + decode_nack_result, + encode_ack, + encode_create_api_key, + encode_create_queue, + encode_delete_queue, + encode_enqueue, + encode_get_acl, + encode_get_config, + encode_get_stats, + encode_list_api_keys, + encode_list_config, + encode_list_queues, + encode_nack, + encode_redrive, + encode_revoke_api_key, + encode_set_acl, + encode_set_config, +) +from fila.fibp.opcodes import ErrorCode, Opcode +from fila.types import ( + AccumulatorMode, + AclEntry, + ApiKeyInfo, + ConsumeMessage, + CreateApiKeyResult, + EnqueueResult, + Linger, + StatsResult, ) -from fila.types import AccumulatorMode, ConsumeMessage, EnqueueResult, Linger -from fila.v1 import service_pb2, service_pb2_grpc if TYPE_CHECKING: from collections.abc import Iterator -_LEADER_HINT_KEY = "x-fila-leader-addr" - - -def _extract_leader_hint(err: grpc.RpcError) -> str | None: - """Return the leader address from trailing metadata, if present. - - The server sets ``x-fila-leader-addr`` in trailing metadata alongside an - UNAVAILABLE status when the node is not the leader for the requested queue. 
- """ - if err.code() != grpc.StatusCode.UNAVAILABLE: - return None - trailing = err.trailing_metadata() - if trailing is None: - return None - for key, value in trailing: - if key == _LEADER_HINT_KEY: - return str(value) - return None - - -def _proto_msg_to_consume_message(msg: Any) -> ConsumeMessage: - """Convert a protobuf Message to a ConsumeMessage.""" - metadata = msg.metadata - return ConsumeMessage( - id=msg.id, - headers=dict(msg.headers), - payload=bytes(msg.payload), - fairness_key=metadata.fairness_key if metadata else "", - attempt_count=metadata.attempt_count if metadata else 0, - queue=metadata.queue_id if metadata else "", - ) - - -def _proto_enqueue_result_to_sdk(result: Any) -> EnqueueResult: - """Convert a proto EnqueueResult to the SDK type.""" - which = result.WhichOneof("result") - if which == "message_id": - return EnqueueResult(message_id=str(result.message_id), error=None) - return EnqueueResult(message_id=None, error=result.error.message) - - -class _ClientCallDetails( - grpc.ClientCallDetails, # type: ignore[misc] -): - """Concrete ``ClientCallDetails`` that can be instantiated. - - ``grpc.ClientCallDetails`` is an abstract class with no ``__init__``, so we - need our own subclass to carry the fields through the interceptor chain. 
- """ - - def __init__( - self, - method: str, - timeout: float | None, - metadata: list[tuple[str, str | bytes]] | None, - credentials: grpc.CallCredentials | None, - wait_for_ready: bool | None, - compression: grpc.Compression | None, - ) -> None: - self.method = method - self.timeout = timeout - self.metadata = metadata - self.credentials = credentials - self.wait_for_ready = wait_for_ready - self.compression = compression - - -class _ApiKeyInterceptor( - grpc.UnaryUnaryClientInterceptor, # type: ignore[misc] - grpc.UnaryStreamClientInterceptor, # type: ignore[misc] -): - """Injects ``authorization: Bearer `` metadata into every RPC.""" - - def __init__(self, api_key: str) -> None: - self._metadata = (("authorization", f"Bearer {api_key}"),) - - def _inject( - self, client_call_details: grpc.ClientCallDetails - ) -> _ClientCallDetails: - metadata = list(client_call_details.metadata or []) - metadata.extend(self._metadata) - return _ClientCallDetails( - client_call_details.method, - client_call_details.timeout, - metadata, - client_call_details.credentials, - client_call_details.wait_for_ready, - client_call_details.compression, - ) - - def intercept_unary_unary( - self, - continuation: Any, - client_call_details: grpc.ClientCallDetails, - request: Any, - ) -> Any: - return continuation(self._inject(client_call_details), request) - def intercept_unary_stream( - self, - continuation: Any, - client_call_details: grpc.ClientCallDetails, - request: Any, - ) -> Any: - return continuation(self._inject(client_call_details), request) +def _parse_addr(addr: str) -> tuple[str, int]: + """Parse 'host:port' into (host, port).""" + if ":" not in addr: + raise ValueError(f"invalid address (expected host:port): {addr}") + host, port_str = addr.rsplit(":", 1) + return host, int(port_str) class Client: """Synchronous client for the Fila message broker. - Wraps the hot-path gRPC operations: enqueue, enqueue_many, consume, ack, - nack. 
+ Wraps the hot-path FIBP operations: enqueue, enqueue_many, consume, ack, nack. Usage:: @@ -153,7 +89,7 @@ class Client: # AUTO (default): opportunistic accumulation via background thread client = Client("localhost:5555") - # DISABLED: each enqueue() is a direct RPC + # DISABLED: each enqueue() is a direct request client = Client("localhost:5555", accumulator_mode=AccumulatorMode.DISABLED) # LINGER: timer-based forced accumulation @@ -192,26 +128,7 @@ def __init__( accumulator_mode: AccumulatorMode | Linger = AccumulatorMode.AUTO, max_accumulator_messages: int = 1000, ) -> None: - """Connect to a Fila broker at the given address. - - Args: - addr: Broker address in "host:port" format (e.g., "localhost:5555"). - tls: Enable TLS using the OS system trust store for server - verification. Ignored when ``ca_cert`` is provided (which - implies TLS). Defaults to ``False``. - ca_cert: PEM-encoded CA certificate for verifying the server. - When provided, a TLS channel is used instead of an insecure one. - client_cert: PEM-encoded client certificate for mutual TLS (optional). - client_key: PEM-encoded client private key for mutual TLS (optional). - api_key: API key for authentication. When set, every RPC includes an - ``authorization: Bearer `` metadata header. - accumulator_mode: Controls how ``enqueue()`` routes messages. - Defaults to ``AccumulatorMode.AUTO`` - (opportunistic accumulation). - max_accumulator_messages: Maximum number of messages per flush when - using ``AccumulatorMode.AUTO``. - Defaults to 1000. 
- """ + self._addr = addr self._tls = tls self._ca_cert = ca_cert self._client_cert = client_cert @@ -224,49 +141,73 @@ def __init__( "client_cert and client_key require ca_cert or tls=True to establish a TLS channel" ) - self._channel = self._make_channel(addr) - self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + self._ssl_ctx = self._make_ssl_context() if use_tls else None + self._conn = self._connect(addr) # Set up the accumulator based on the chosen mode. self._accumulator: AutoAccumulator | LingerAccumulator | None = None if isinstance(accumulator_mode, Linger): self._accumulator = LingerAccumulator( - self._stub, + self._conn, linger_ms=accumulator_mode.linger_ms, max_messages=accumulator_mode.max_messages, ) elif accumulator_mode is AccumulatorMode.AUTO: self._accumulator = AutoAccumulator( - self._stub, + self._conn, max_messages=max_accumulator_messages, ) # AccumulatorMode.DISABLED: self._accumulator stays None - def _make_channel(self, addr: str) -> grpc.Channel: - """Create a gRPC channel to the given address using stored credentials.""" - use_tls = self._tls or self._ca_cert is not None - - if use_tls: - creds = grpc.ssl_channel_credentials( - root_certificates=self._ca_cert, - private_key=self._client_key, - certificate_chain=self._client_cert, - ) - channel: grpc.Channel = grpc.secure_channel(addr, creds) + def _make_ssl_context(self) -> ssl.SSLContext: + """Create an SSL context from stored credentials.""" + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + if self._ca_cert is not None: + ctx.load_verify_locations(cadata=self._ca_cert.decode("ascii")) else: - channel = grpc.insecure_channel(addr) + ctx.load_default_certs() + if self._client_cert is not None and self._client_key is not None: + # Write temp files for load_cert_chain (ssl module needs file paths or + # we can use the cadata approach for CA only). 
+ import os + import tempfile + + cert_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pem") + key_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pem") + try: + cert_file.write(self._client_cert) + cert_file.close() + key_file.write(self._client_key) + key_file.close() + ctx.load_cert_chain(cert_file.name, key_file.name) + finally: + os.unlink(cert_file.name) + os.unlink(key_file.name) + return ctx + + def _connect(self, addr: str) -> Connection: + """Open a FIBP connection to the given address.""" + host, port = _parse_addr(addr) + return Connection.connect( + host, port, ssl_context=self._ssl_ctx, api_key=self._api_key + ) - if self._api_key is not None: - interceptor = _ApiKeyInterceptor(self._api_key) - channel = grpc.intercept_channel(channel, interceptor) + def _reconnect(self, addr: str) -> None: + """Reconnect to a different address (e.g. after leader hint).""" + import contextlib - return channel + with contextlib.suppress(OSError): + self._conn.close() + self._addr = addr + self._conn = self._connect(addr) + if self._accumulator is not None: + self._accumulator.update_conn(self._conn) def close(self) -> None: - """Drain pending accumulated messages and close the underlying gRPC channel.""" + """Drain pending accumulated messages and close the underlying connection.""" if self._accumulator is not None: self._accumulator.close() - self._channel.close() + self._conn.close() def __enter__(self) -> Client: return self @@ -274,6 +215,8 @@ def __enter__(self) -> Client: def __exit__(self, *args: object) -> None: self.close() + # -- hot-path operations ------------------------------------------------- + def enqueue( self, queue: str, @@ -282,216 +225,286 @@ def enqueue( ) -> str: """Enqueue a message to the specified queue. - When an accumulator is active (``AccumulatorMode.AUTO`` or ``Linger``), - the message is submitted to the background accumulator and this call - blocks until the flush completes and the result is available. 
- - When accumulation is disabled (``AccumulatorMode.DISABLED``), this call - makes a direct synchronous RPC. - - Args: - queue: Target queue name. - headers: Optional message headers. - payload: Message payload bytes. - Returns: Broker-assigned message ID (UUIDv7). Raises: - QueueNotFoundError: If the queue does not exist (DISABLED mode). - EnqueueError: If the enqueue RPC fails (AUTO/LINGER mode). - RPCError: For unexpected gRPC failures. + QueueNotFoundError: If the queue does not exist. + EnqueueError: If the enqueue fails. """ - proto = service_pb2.EnqueueMessage( - queue=queue, - headers=headers or {}, - payload=payload, - ) + msg = {"queue": queue, "headers": headers or {}, "payload": payload} if self._accumulator is not None: - future = self._accumulator.submit(proto) + future = self._accumulator.submit(msg) return future.result() - # Direct RPC (DISABLED mode). - try: - resp = self._stub.Enqueue( - service_pb2.EnqueueRequest(messages=[proto]) - ) - except grpc.RpcError as e: - raise _map_enqueue_error(e) from e - - result = resp.results[0] - which = result.WhichOneof("result") - if which == "message_id": - return str(result.message_id) - raise _map_enqueue_result_error(result.error.code, result.error.message) + return self._enqueue_direct([msg])[0] def enqueue_many( self, messages: list[tuple[str, dict[str, str] | None, bytes]], ) -> list[EnqueueResult]: - """Enqueue multiple messages in a single RPC. - - This is an explicit multi-message operation that always uses the - Enqueue RPC directly, regardless of the accumulator_mode setting. - - Args: - messages: List of (queue, headers, payload) tuples. + """Enqueue multiple messages in a single request. Returns: List of ``EnqueueResult`` objects, one per input message. - Each result has either a ``message_id`` (success) or ``error`` - (per-message failure). - - Raises: - QueueNotFoundError: If a referenced queue does not exist. - RPCError: For unexpected gRPC failures. 
""" - proto_messages = [ - service_pb2.EnqueueMessage( - queue=q, - headers=h or {}, - payload=p, - ) + msgs = [ + {"queue": q, "headers": h or {}, "payload": p} for q, h, p in messages ] + body = encode_enqueue(msgs) try: - resp = self._stub.Enqueue( - service_pb2.EnqueueRequest(messages=proto_messages) + header, resp_body = self._request_with_leader_retry( + Opcode.ENQUEUE, body ) - except grpc.RpcError as e: - raise _map_enqueue_error(e) from e - - return [_proto_enqueue_result_to_sdk(r) for r in resp.results] + except NotLeaderError: + raise + + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + items = decode_enqueue_result(resp_body) + results: list[EnqueueResult] = [] + for item in items: + if item.error_code == ErrorCode.OK: + results.append(EnqueueResult(message_id=item.message_id, error=None)) + else: + err_msg = f"error 0x{item.error_code:02x}" + results.append(EnqueueResult(message_id=None, error=err_msg)) + return results def consume(self, queue: str) -> Iterator[ConsumeMessage]: """Open a streaming consumer on the specified queue. Yields messages as they become available. The iterator ends when the - server stream closes or an error occurs. Skip nil message frames - (keepalive signals) automatically. - - If the server returns UNAVAILABLE with an ``x-fila-leader-addr`` - trailing metadata entry, the client transparently reconnects to the - leader address and retries the consume call once. - - Args: - queue: Queue to consume from. - - Yields: - ConsumeMessage objects as they arrive. + connection closes or an error occurs. - Raises: - QueueNotFoundError: If the queue does not exist. - RPCError: For unexpected gRPC failures. + Handles NotLeader errors by transparently reconnecting once. 
""" try: - stream = self._stub.Consume( - service_pb2.ConsumeRequest(queue=queue) - ) - except grpc.RpcError as e: - leader_addr = _extract_leader_hint(e) - if leader_addr is not None: - stream = self._reconnect_and_consume(leader_addr, queue) + _req_id, consumer_id = self._conn.subscribe(queue) + except NotLeaderError as e: + if e.leader_addr is not None: + self._reconnect(e.leader_addr) + _req_id, consumer_id = self._conn.subscribe(queue) else: - raise _map_consume_error(e) from e - - return self._consume_iter(stream) + raise - def _reconnect_and_consume(self, leader_addr: str, queue: str) -> Any: - """Create a new channel to *leader_addr* and retry the consume call.""" - self._channel.close() - self._channel = self._make_channel(leader_addr) - self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] - if self._accumulator is not None: - self._accumulator.update_stub(self._stub) - try: - return self._stub.Consume( - service_pb2.ConsumeRequest(queue=queue) - ) - except grpc.RpcError as e: - raise _map_consume_error(e) from e + return self._consume_iter(consumer_id) - def _consume_iter( - self, - stream: Any, - ) -> Iterator[ConsumeMessage]: - """Internal generator reading from the gRPC stream.""" + def _consume_iter(self, consumer_id: str) -> Iterator[ConsumeMessage]: + """Internal generator reading Delivery frames.""" try: - for resp in stream: - for msg in resp.messages: - if msg is not None and msg.ByteSize(): - yield _proto_msg_to_consume_message(msg) - except grpc.RpcError: + while True: + header, body = self._conn.read_frame() + + if header.opcode == Opcode.DELIVERY: + for msg in decode_delivery(body): + yield ConsumeMessage( + id=msg.message_id, + queue=msg.queue, + headers=msg.headers, + payload=msg.payload, + fairness_key=msg.fairness_key, + attempt_count=msg.attempt_count, + weight=msg.weight, + throttle_keys=msg.throttle_keys, + enqueued_at=msg.enqueued_at, + leased_at=msg.leased_at, + ) + elif header.opcode == 
Opcode.ERROR: + err = decode_error(body) + _raise_from_error_frame(err) + # Ignore other frames (e.g. pong is handled in read_frame). + except (ConnectionError, OSError): return def ack(self, queue: str, msg_id: str) -> None: - """Acknowledge a successfully processed message. - - The message is permanently removed from the queue. + """Acknowledge a successfully processed message.""" + body = encode_ack([{"queue": queue, "message_id": msg_id}]) + header, resp_body = self._request_with_leader_retry(Opcode.ACK, body) - Args: - queue: Queue the message belongs to. - msg_id: ID of the message to acknowledge. + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) - Raises: - MessageNotFoundError: If the message does not exist. - RPCError: For unexpected gRPC failures. - """ - try: - resp = self._stub.Ack( - service_pb2.AckRequest( - messages=[service_pb2.AckMessage(queue=queue, message_id=msg_id)] - ) - ) - except grpc.RpcError as e: - raise _map_ack_error(e) from e - - # Check per-message result for errors. - if resp.results: - result = resp.results[0] - which = result.WhichOneof("result") - if which == "error": - ack_err = result.error - if ack_err.code == service_pb2.ACK_ERROR_CODE_MESSAGE_NOT_FOUND: - raise MessageNotFoundError(f"ack: {ack_err.message}") - raise RPCError(grpc.StatusCode.INTERNAL, f"ack: {ack_err.message}") + codes = decode_ack_result(resp_body) + if codes and codes[0] != ErrorCode.OK: + raise _map_per_item_error(codes[0], "ack") def nack(self, queue: str, msg_id: str, error: str) -> None: - """Negatively acknowledge a message that failed processing. - - The message is requeued for retry or routed to the dead-letter queue - based on the queue's on_failure Lua hook configuration. - - Args: - queue: Queue the message belongs to. - msg_id: ID of the message to nack. - error: Description of the failure. 
+ """Negatively acknowledge a message that failed processing.""" + body = encode_nack([{"queue": queue, "message_id": msg_id, "error": error}]) + header, resp_body = self._request_with_leader_retry(Opcode.NACK, body) + + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + codes = decode_nack_result(resp_body) + if codes and codes[0] != ErrorCode.OK: + raise _map_per_item_error(codes[0], "nack") + + # -- admin operations ---------------------------------------------------- + + def create_queue(self, name: str, config: dict[str, str] | None = None) -> None: + """Create a queue on the broker.""" + body = encode_create_queue(name, config) + header, resp_body = self._request_with_leader_retry(Opcode.CREATE_QUEUE, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + def delete_queue(self, name: str) -> None: + """Delete a queue from the broker.""" + body = encode_delete_queue(name) + header, resp_body = self._request_with_leader_retry(Opcode.DELETE_QUEUE, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + def get_stats(self, queue: str) -> StatsResult: + """Get statistics for a queue.""" + body = encode_get_stats(queue) + header, resp_body = self._request_with_leader_retry(Opcode.GET_STATS, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + result = decode_get_stats_result(resp_body) + return StatsResult(stats=result.stats) + + def list_queues(self) -> list[str]: + """List all queues on the broker.""" + body = encode_list_queues() + header, resp_body = self._request_with_leader_retry(Opcode.LIST_QUEUES, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + return decode_list_queues_result(resp_body) + + def set_config(self, queue: str, config: dict[str, str]) -> None: + """Set configuration for a 
queue.""" + body = encode_set_config(queue, config) + header, resp_body = self._request_with_leader_retry(Opcode.SET_CONFIG, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + def get_config(self, queue: str) -> dict[str, str]: + """Get configuration for a queue.""" + body = encode_get_config(queue) + header, resp_body = self._request_with_leader_retry(Opcode.GET_CONFIG, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + return decode_get_config_result(resp_body) + + def list_config(self, queue: str) -> dict[str, str]: + """List all configuration for a queue.""" + body = encode_list_config(queue) + header, resp_body = self._request_with_leader_retry(Opcode.LIST_CONFIG, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + return decode_list_config_result(resp_body) + + def redrive(self, source_queue: str, dest_queue: str, count: int) -> None: + """Redrive messages from one queue to another.""" + body = encode_redrive(source_queue, dest_queue, count) + header, resp_body = self._request_with_leader_retry(Opcode.REDRIVE, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + # -- auth operations ----------------------------------------------------- + + def create_api_key(self, name: str) -> CreateApiKeyResult: + """Create a new API key.""" + body = encode_create_api_key(name) + header, resp_body = self._request_with_leader_retry(Opcode.CREATE_API_KEY, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + key_id, raw_key = decode_create_api_key_result(resp_body) + return CreateApiKeyResult(key_id=key_id, raw_key=raw_key) + + def revoke_api_key(self, key_id: str) -> None: + """Revoke an API key.""" + body = encode_revoke_api_key(key_id) + header, resp_body = 
self._request_with_leader_retry(Opcode.REVOKE_API_KEY, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + def list_api_keys(self) -> list[ApiKeyInfo]: + """List all API keys.""" + body = encode_list_api_keys() + header, resp_body = self._request_with_leader_retry(Opcode.LIST_API_KEYS, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + items = decode_list_api_keys_result(resp_body) + return [ + ApiKeyInfo(key_id=k.key_id, prefix=k.prefix, created_at=k.created_at) + for k in items + ] - Raises: - MessageNotFoundError: If the message does not exist. - RPCError: For unexpected gRPC failures. - """ - try: - resp = self._stub.Nack( - service_pb2.NackRequest( - messages=[ - service_pb2.NackMessage( - queue=queue, message_id=msg_id, error=error - ) - ] - ) - ) - except grpc.RpcError as e: - raise _map_nack_error(e) from e - - # Check per-message result for errors. - if resp.results: - result = resp.results[0] - which = result.WhichOneof("result") - if which == "error": - nack_err = result.error - if nack_err.code == service_pb2.NACK_ERROR_CODE_MESSAGE_NOT_FOUND: - raise MessageNotFoundError(f"nack: {nack_err.message}") - raise RPCError(grpc.StatusCode.INTERNAL, f"nack: {nack_err.message}") + def set_acl(self, key_id: str, patterns: list[str], superadmin: bool = False) -> None: + """Set ACL for an API key.""" + body = encode_set_acl(key_id, patterns, superadmin) + header, resp_body = self._request_with_leader_retry(Opcode.SET_ACL, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + def get_acl(self, key_id: str) -> AclEntry: + """Get ACL for an API key.""" + body = encode_get_acl(key_id) + header, resp_body = self._request_with_leader_retry(Opcode.GET_ACL, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + result = decode_get_acl_result(resp_body) + 
return AclEntry(patterns=result.patterns, superadmin=result.superadmin) + + # -- internal helpers ---------------------------------------------------- + + def _enqueue_direct(self, messages: list[dict[str, object]]) -> list[str]: + """Send an enqueue request directly and return message IDs.""" + body = encode_enqueue(messages) + + header, resp_body = self._request_with_leader_retry(Opcode.ENQUEUE, body) + + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + items = decode_enqueue_result(resp_body) + results: list[str] = [] + for item in items: + if item.error_code == ErrorCode.OK: + results.append(item.message_id) + else: + raise _map_per_item_error(item.error_code, "enqueue") + return results + + def _request_with_leader_retry( + self, opcode: int, body: bytes + ) -> tuple[object, bytes]: + """Send a request, retrying once on NotLeader with leader hint.""" + + header, resp_body = self._conn.request(opcode, body) + + # Check for NotLeader error with leader hint. 
+ if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + if err.code == ErrorCode.NOT_LEADER: + leader_addr = err.metadata.get("leader_addr") + if leader_addr: + self._reconnect(leader_addr) + return self._conn.request(opcode, body) + + return header, resp_body diff --git a/fila/conn.py b/fila/conn.py new file mode 100644 index 0000000..bce515a --- /dev/null +++ b/fila/conn.py @@ -0,0 +1,397 @@ +"""FIBP connection manager — synchronous and asynchronous.""" + +from __future__ import annotations + +import asyncio +import socket +import ssl +import struct +import threading + +from fila.fibp.codec import ( + decode_error, + decode_handshake_ok, + encode_handshake, + encode_pong, +) +from fila.fibp.opcodes import ( + DEFAULT_MAX_FRAME_SIZE, + FRAME_HEADER_SIZE, + PROTOCOL_VERSION, + FrameHeader, + Opcode, +) + + +def _parse_header(data: bytes) -> FrameHeader: + """Parse 6 bytes into a FrameHeader.""" + opcode = data[0] + flags = data[1] + request_id = struct.unpack_from("!I", data, 2)[0] + return FrameHeader(opcode=opcode, flags=flags, request_id=request_id) + + +def _build_frame(opcode: int, request_id: int, body: bytes, flags: int = 0) -> bytes: + """Build a length-prefixed FIBP frame.""" + frame_body = struct.pack("!BBI", opcode, flags, request_id) + body + return struct.pack("!I", len(frame_body)) + frame_body + + +def make_ssl_context( + *, + ca_cert: bytes | None = None, + client_cert: bytes | None = None, + client_key: bytes | None = None, + system_trust: bool = False, +) -> ssl.SSLContext: + """Create an SSLContext for TLS connections.""" + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + if ca_cert is not None: + ctx.load_verify_locations(cadata=ca_cert.decode("ascii")) + elif system_trust: + ctx.load_default_certs() + if client_cert is not None and client_key is not None: + ctx.load_cert_chain(certdata=client_cert, keydata=client_key) + return ctx + + +# --------------------------------------------------------------------------- +# Synchronous 
class Connection:
    """Synchronous FIBP connection over a TCP socket.

    A single instance may be shared between threads for request/response
    traffic: :meth:`request` serializes each write/read cycle — including
    request-id allocation — behind an internal lock.
    """

    def __init__(self, sock: socket.socket, max_frame_size: int = DEFAULT_MAX_FRAME_SIZE) -> None:
        self._sock = sock
        # May be raised later by the server's HandshakeOk advertisement.
        self._max_frame_size = max_frame_size
        # Shared request-id counter; mutated by _next_request_id().
        self._req_counter = 0
        # Serializes request/response cycles across threads.
        self._lock = threading.Lock()

    @classmethod
    def connect(
        cls,
        host: str,
        port: int,
        *,
        ssl_context: ssl.SSLContext | None = None,
        api_key: str | None = None,
        timeout: float = 10.0,
    ) -> Connection:
        """Open a TCP connection and perform the FIBP handshake.

        Args:
            host: Server hostname or IP address.
            port: Server TCP port.
            ssl_context: Optional TLS context (see :func:`make_ssl_context`).
            api_key: Optional API key sent with the handshake.
            timeout: Connect timeout in seconds.

        Returns:
            A connected, handshaken :class:`Connection`.
        """
        sock = socket.create_connection((host, port), timeout=timeout)
        # Frames are small and latency-sensitive; disable Nagle batching.
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        if ssl_context is not None:
            sock = ssl_context.wrap_socket(sock, server_hostname=host)

        conn = cls(sock)
        conn._handshake(api_key)
        return conn

    def _next_request_id(self) -> int:
        # Masked to fit the u32 request-id field of the frame header.
        self._req_counter += 1
        return self._req_counter & 0xFFFFFFFF

    def _handshake(self, api_key: str | None) -> None:
        """Perform the FIBP handshake and adopt the server's frame-size limit."""
        body = encode_handshake(PROTOCOL_VERSION, api_key)
        req_id = self._next_request_id()
        self.write_frame(Opcode.HANDSHAKE, req_id, body)

        header, resp_body = self.read_frame()
        if header.opcode == Opcode.ERROR:
            err = decode_error(resp_body)
            from fila.errors import _raise_from_error_frame
            _raise_from_error_frame(err)

        if header.opcode != Opcode.HANDSHAKE_OK:
            raise ConnectionError(
                f"expected HandshakeOk (0x02), got 0x{header.opcode:02x}"
            )

        _version, _node_id, max_frame_size = decode_handshake_ok(resp_body)
        if max_frame_size > 0:
            self._max_frame_size = max_frame_size

    def write_frame(self, opcode: int, request_id: int, body: bytes, flags: int = 0) -> None:
        """Write a single length-prefixed FIBP frame."""
        frame = _build_frame(opcode, request_id, body, flags)
        self._sock.sendall(frame)

    def read_frame(self) -> tuple[FrameHeader, bytes]:
        """Read a single length-prefixed FIBP frame.

        Handles Ping by responding with Pong automatically.
        Handles continuation frames by concatenating bodies.
        """
        while True:
            header, body = self._read_single_frame()

            # Auto-reply to Ping; the caller never sees it.
            if header.opcode == Opcode.PING:
                self.write_frame(Opcode.PONG, header.request_id, encode_pong())
                continue

            # A set continuation flag means more fragments follow; collect
            # until a fragment arrives with the flag cleared.
            if header.is_continuation:
                parts = [body]
                while True:
                    cont_header, cont_body = self._read_single_frame()
                    if cont_header.opcode == Opcode.PING:
                        self.write_frame(Opcode.PONG, cont_header.request_id, encode_pong())
                        continue
                    parts.append(cont_body)
                    if not cont_header.is_continuation:
                        break
                return header, b"".join(parts)

            return header, body

    def _read_single_frame(self) -> tuple[FrameHeader, bytes]:
        """Read one frame from the wire (no continuation handling)."""
        length_bytes = self._recv_exact(4)
        frame_len = struct.unpack("!I", length_bytes)[0]
        if frame_len > self._max_frame_size:
            raise ConnectionError(
                f"frame size {frame_len} exceeds max {self._max_frame_size}"
            )
        frame_data = self._recv_exact(frame_len)
        header = _parse_header(frame_data[:FRAME_HEADER_SIZE])
        body = frame_data[FRAME_HEADER_SIZE:]
        return header, body

    def _recv_exact(self, n: int) -> bytes:
        """Read exactly n bytes from the socket, raising on EOF."""
        buf = bytearray()
        while len(buf) < n:
            chunk = self._sock.recv(n - len(buf))
            if not chunk:
                raise ConnectionError("connection closed by remote")
            buf.extend(chunk)
        return bytes(buf)

    def request(self, opcode: int, body: bytes) -> tuple[FrameHeader, bytes]:
        """Send a request frame and read the response (synchronous request-response).

        Bug fix: the request id was previously allocated *outside* the lock,
        so two threads could race the unsynchronized ``_req_counter += 1``
        and obtain duplicate ids. The id is now allocated while holding the
        lock, together with the write/read cycle it belongs to.
        """
        with self._lock:
            req_id = self._next_request_id()
            self.write_frame(opcode, req_id, body)
            return self.read_frame()
+ """ + from fila.fibp.codec import decode_consume_ok, encode_consume + + req_id = self._next_request_id() + self.write_frame(Opcode.CONSUME, req_id, encode_consume(queue)) + + header, body = self.read_frame() + if header.opcode == Opcode.ERROR: + err = decode_error(body) + from fila.errors import _raise_from_error_frame + _raise_from_error_frame(err) + + if header.opcode != Opcode.CONSUME_OK: + raise ConnectionError( + f"expected ConsumeOk (0x19), got 0x{header.opcode:02x}" + ) + + consumer_id = decode_consume_ok(body) + return req_id, consumer_id + + def cancel_consume(self, consumer_id: str) -> None: + """Send a CancelConsume frame.""" + from fila.fibp.codec import encode_cancel_consume + + req_id = self._next_request_id() + self.write_frame(Opcode.CANCEL_CONSUME, req_id, encode_cancel_consume(consumer_id)) + + def close(self) -> None: + """Send Disconnect and close the socket.""" + import contextlib + + with contextlib.suppress(OSError): + from fila.fibp.codec import encode_disconnect + req_id = self._next_request_id() + self.write_frame(Opcode.DISCONNECT, req_id, encode_disconnect()) + + with contextlib.suppress(OSError): + self._sock.close() + + +# --------------------------------------------------------------------------- +# Async connection +# --------------------------------------------------------------------------- + +class AsyncConnection: + """Asynchronous FIBP connection over asyncio streams.""" + + def __init__( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + max_frame_size: int = DEFAULT_MAX_FRAME_SIZE, + ) -> None: + self._reader = reader + self._writer = writer + self._max_frame_size = max_frame_size + self._req_counter = 0 + self._lock = asyncio.Lock() + + @classmethod + async def connect( + cls, + host: str, + port: int, + *, + ssl_context: ssl.SSLContext | None = None, + api_key: str | None = None, + timeout: float = 10.0, + ) -> AsyncConnection: + """Open a TCP connection and perform the FIBP handshake.""" + reader, 
writer = await asyncio.wait_for( + asyncio.open_connection(host, port, ssl=ssl_context), + timeout=timeout, + ) + + # Set TCP_NODELAY on the underlying socket. + sock = writer.get_extra_info("socket") + if sock is not None: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + conn = cls(reader, writer) + await conn._handshake(api_key) + return conn + + def _next_request_id(self) -> int: + self._req_counter += 1 + return self._req_counter & 0xFFFFFFFF + + async def _handshake(self, api_key: str | None) -> None: + """Perform the FIBP handshake.""" + body = encode_handshake(PROTOCOL_VERSION, api_key) + req_id = self._next_request_id() + await self.write_frame(Opcode.HANDSHAKE, req_id, body) + + header, resp_body = await self.read_frame() + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + from fila.errors import _raise_from_error_frame + _raise_from_error_frame(err) + + if header.opcode != Opcode.HANDSHAKE_OK: + raise ConnectionError( + f"expected HandshakeOk (0x02), got 0x{header.opcode:02x}" + ) + + version, _node_id, max_frame_size = decode_handshake_ok(resp_body) + if max_frame_size > 0: + self._max_frame_size = max_frame_size + + async def write_frame( + self, opcode: int, request_id: int, body: bytes, flags: int = 0 + ) -> None: + """Write a single length-prefixed FIBP frame.""" + frame = _build_frame(opcode, request_id, body, flags) + self._writer.write(frame) + await self._writer.drain() + + async def read_frame(self) -> tuple[FrameHeader, bytes]: + """Read a single length-prefixed FIBP frame. + + Handles Ping by responding with Pong automatically. + Handles continuation frames by concatenating bodies. 
+ """ + while True: + header, body = await self._read_single_frame() + + if header.opcode == Opcode.PING: + await self.write_frame(Opcode.PONG, header.request_id, encode_pong()) + continue + + if header.is_continuation: + parts = [body] + while True: + cont_header, cont_body = await self._read_single_frame() + if cont_header.opcode == Opcode.PING: + await self.write_frame( + Opcode.PONG, cont_header.request_id, encode_pong() + ) + continue + parts.append(cont_body) + if not cont_header.is_continuation: + break + return header, b"".join(parts) + + return header, body + + async def _read_single_frame(self) -> tuple[FrameHeader, bytes]: + """Read one frame from the wire (no continuation handling).""" + length_bytes = await self._reader.readexactly(4) + frame_len = struct.unpack("!I", length_bytes)[0] + if frame_len > self._max_frame_size: + raise ConnectionError( + f"frame size {frame_len} exceeds max {self._max_frame_size}" + ) + frame_data = await self._reader.readexactly(frame_len) + header = _parse_header(frame_data[:FRAME_HEADER_SIZE]) + body = frame_data[FRAME_HEADER_SIZE:] + return header, body + + async def request(self, opcode: int, body: bytes) -> tuple[FrameHeader, bytes]: + """Send a request frame and read the response.""" + req_id = self._next_request_id() + async with self._lock: + await self.write_frame(opcode, req_id, body) + return await self.read_frame() + + async def subscribe(self, queue: str) -> tuple[int, str]: + """Send a Consume request and wait for ConsumeOk. + + Returns (request_id, consumer_id). 
+ """ + from fila.fibp.codec import decode_consume_ok, encode_consume + + req_id = self._next_request_id() + await self.write_frame(Opcode.CONSUME, req_id, encode_consume(queue)) + + header, body = await self.read_frame() + if header.opcode == Opcode.ERROR: + err = decode_error(body) + from fila.errors import _raise_from_error_frame + _raise_from_error_frame(err) + + if header.opcode != Opcode.CONSUME_OK: + raise ConnectionError( + f"expected ConsumeOk (0x19), got 0x{header.opcode:02x}" + ) + + consumer_id = decode_consume_ok(body) + return req_id, consumer_id + + async def cancel_consume(self, consumer_id: str) -> None: + """Send a CancelConsume frame.""" + from fila.fibp.codec import encode_cancel_consume + + req_id = self._next_request_id() + await self.write_frame( + Opcode.CANCEL_CONSUME, req_id, encode_cancel_consume(consumer_id) + ) + + async def close(self) -> None: + """Send Disconnect and close the stream.""" + try: + from fila.fibp.codec import encode_disconnect + req_id = self._next_request_id() + await self.write_frame(Opcode.DISCONNECT, req_id, encode_disconnect()) + except OSError: + pass + finally: + try: + self._writer.close() + await self._writer.wait_closed() + except OSError: + pass diff --git a/fila/errors.py b/fila/errors.py index 00890f2..242968d 100644 --- a/fila/errors.py +++ b/fila/errors.py @@ -2,7 +2,10 @@ from __future__ import annotations -import grpc +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from fila.fibp.codec import ErrorFrame class FilaError(Exception): @@ -17,13 +20,52 @@ class MessageNotFoundError(FilaError): """Raised when the specified message does not exist.""" -class RPCError(FilaError): - """Raised for unexpected gRPC failures, preserving status code and message.""" +class QueueAlreadyExistsError(FilaError): + """Raised when creating a queue that already exists.""" - def __init__(self, code: grpc.StatusCode, message: str) -> None: - self.code = code - self.message = message - super().__init__(f"rpc error 
(code = {code.name}): {message}") + +class InvalidArgumentError(FilaError): + """Raised when an argument is invalid.""" + + +class PermissionDeniedError(FilaError): + """Raised when permission is denied for the operation.""" + + +class UnauthorizedError(FilaError): + """Raised when the client is not authenticated.""" + + +class ForbiddenError(FilaError): + """Raised when the client lacks permission for the operation.""" + + +class NotLeaderError(FilaError): + """Raised when the request was sent to a non-leader node. + + The ``leader_addr`` attribute contains the address of the current leader, + if available. + """ + + def __init__(self, message: str, leader_addr: str | None = None) -> None: + self.leader_addr = leader_addr + super().__init__(message) + + +class ChannelFullError(FilaError): + """Raised when a channel or buffer is full (backpressure).""" + + +class ResourceExhaustedError(FilaError): + """Raised when a resource limit has been reached.""" + + +class UnavailableError(FilaError): + """Raised when the server is unavailable.""" + + +class LuaError(FilaError): + """Raised when a Lua script error occurs.""" class EnqueueError(FilaError): @@ -36,49 +78,101 @@ class EnqueueError(FilaError): """ -def _map_enqueue_result_error(code: int, message: str) -> FilaError: - """Map a per-message EnqueueErrorCode to a Fila exception. +class ProtocolError(FilaError): + """Raised for unexpected protocol-level failures.""" - Used when the unified Enqueue RPC succeeds at the transport level but - returns a per-message error result (e.g., queue not found for one of - the messages in the batch). 
+ def __init__(self, code: int, message: str) -> None: + self.code = code + self.message = message + super().__init__(f"protocol error (code=0x{code:02x}): {message}") + + +class ApiKeyNotFoundError(FilaError): + """Raised when the specified API key does not exist.""" + + +class AclNotFoundError(FilaError): + """Raised when the specified ACL entry does not exist.""" + + +# Keep RPCError as a thin alias for backwards compatibility with existing +# callers that catch fila.RPCError. +RPCError = ProtocolError + + +# --------------------------------------------------------------------------- +# Error-code mapping +# --------------------------------------------------------------------------- + +def _map_error_code(code: int, message: str) -> FilaError: + """Map a FIBP error code to the appropriate exception.""" + from fila.fibp.opcodes import ErrorCode + + match code: + case ErrorCode.QUEUE_NOT_FOUND: + return QueueNotFoundError(message) + case ErrorCode.MESSAGE_NOT_FOUND: + return MessageNotFoundError(message) + case ErrorCode.QUEUE_ALREADY_EXISTS: + return QueueAlreadyExistsError(message) + case ErrorCode.INVALID_ARGUMENT: + return InvalidArgumentError(message) + case ErrorCode.PERMISSION_DENIED: + return PermissionDeniedError(message) + case ErrorCode.UNAUTHENTICATED: + return UnauthorizedError(message) + case ErrorCode.FORBIDDEN: + return ForbiddenError(message) + case ErrorCode.NOT_LEADER: + return NotLeaderError(message) + case ErrorCode.CHANNEL_FULL: + return ChannelFullError(message) + case ErrorCode.RESOURCE_EXHAUSTED: + return ResourceExhaustedError(message) + case ErrorCode.UNAVAILABLE: + return UnavailableError(message) + case ErrorCode.LUA_ERROR: + return LuaError(message) + case ErrorCode.API_KEY_NOT_FOUND: + return ApiKeyNotFoundError(message) + case ErrorCode.ACL_NOT_FOUND: + return AclNotFoundError(message) + case ErrorCode.DEADLINE_EXCEEDED: + return ProtocolError(code, message) + case ErrorCode.PRECONDITION_FAILED: + return ProtocolError(code, message) 
+ case ErrorCode.ABORTED: + return ProtocolError(code, message) + case ErrorCode.INTERNAL_ERROR: + return ProtocolError(code, message) + case _: + return ProtocolError(code, message) + + +def _raise_from_error_frame(err: ErrorFrame) -> None: + """Raise the appropriate exception from a decoded Error frame. + + For NotLeader errors, extracts the leader_addr from metadata. """ - from fila.v1 import service_pb2 - - if code == service_pb2.ENQUEUE_ERROR_CODE_QUEUE_NOT_FOUND: - return QueueNotFoundError(f"enqueue: {message}") - if code == service_pb2.ENQUEUE_ERROR_CODE_PERMISSION_DENIED: - return RPCError(grpc.StatusCode.PERMISSION_DENIED, f"enqueue: {message}") - return EnqueueError(f"enqueue failed: {message}") - - -def _map_enqueue_error(err: grpc.RpcError) -> FilaError: - """Map a gRPC error from an enqueue call to a Fila exception.""" - code = err.code() - if code == grpc.StatusCode.NOT_FOUND: - return QueueNotFoundError(f"enqueue: {err.details()}") - return RPCError(code, err.details() or "") - + from fila.fibp.opcodes import ErrorCode -def _map_consume_error(err: grpc.RpcError) -> FilaError: - """Map a gRPC error from a consume call to a Fila exception.""" - code = err.code() - if code == grpc.StatusCode.NOT_FOUND: - return QueueNotFoundError(f"consume: {err.details()}") - return RPCError(code, err.details() or "") + if err.code == ErrorCode.NOT_LEADER: + leader_addr = err.metadata.get("leader_addr") + raise NotLeaderError(err.message, leader_addr=leader_addr) + raise _map_error_code(err.code, err.message) -def _map_ack_error(err: grpc.RpcError) -> FilaError: - """Map a gRPC error from an ack call to a Fila exception.""" - code = err.code() - if code == grpc.StatusCode.NOT_FOUND: - return MessageNotFoundError(f"ack: {err.details()}") - return RPCError(code, err.details() or "") +def _map_per_item_error(code: int, context: str) -> FilaError: + """Map a per-item error code (from EnqueueResult, AckResult, etc.).""" + from fila.fibp.opcodes import ErrorCode -def 
_map_nack_error(err: grpc.RpcError) -> FilaError: - """Map a gRPC error from a nack call to a Fila exception.""" - code = err.code() - if code == grpc.StatusCode.NOT_FOUND: - return MessageNotFoundError(f"nack: {err.details()}") - return RPCError(code, err.details() or "") + match code: + case ErrorCode.QUEUE_NOT_FOUND: + return QueueNotFoundError(f"{context}: queue not found") + case ErrorCode.MESSAGE_NOT_FOUND: + return MessageNotFoundError(f"{context}: message not found") + case ErrorCode.PERMISSION_DENIED: + return PermissionDeniedError(f"{context}: permission denied") + case _: + return EnqueueError(f"{context}: error code 0x{code:02x}") diff --git a/fila/fibp/__init__.py b/fila/fibp/__init__.py new file mode 100644 index 0000000..67fffd9 --- /dev/null +++ b/fila/fibp/__init__.py @@ -0,0 +1,82 @@ +"""FIBP (Fila Binary Protocol) codec and primitives.""" + +from fila.fibp.codec import ( + decode_ack_result, + decode_consume_ok, + decode_delivery, + decode_enqueue_result, + decode_error, + decode_get_acl_result, + decode_get_config_result, + decode_get_stats_result, + decode_handshake_ok, + decode_list_api_keys_result, + decode_list_config_result, + decode_list_queues_result, + decode_nack_result, + encode_ack, + encode_cancel_consume, + encode_consume, + encode_create_api_key, + encode_create_queue, + encode_delete_queue, + encode_disconnect, + encode_enqueue, + encode_get_acl, + encode_get_config, + encode_get_stats, + encode_handshake, + encode_list_api_keys, + encode_list_config, + encode_list_queues, + encode_nack, + encode_pong, + encode_redrive, + encode_revoke_api_key, + encode_set_acl, + encode_set_config, +) +from fila.fibp.opcodes import ErrorCode, FrameHeader, Opcode +from fila.fibp.primitives import Reader, Writer + +__all__ = [ + "ErrorCode", + "FrameHeader", + "Opcode", + "Reader", + "Writer", + "decode_ack_result", + "decode_consume_ok", + "decode_delivery", + "decode_enqueue_result", + "decode_error", + "decode_get_acl_result", + 
"decode_get_config_result", + "decode_get_stats_result", + "decode_handshake_ok", + "decode_list_api_keys_result", + "decode_list_config_result", + "decode_list_queues_result", + "decode_nack_result", + "encode_ack", + "encode_cancel_consume", + "encode_consume", + "encode_create_api_key", + "encode_create_queue", + "encode_delete_queue", + "encode_disconnect", + "encode_enqueue", + "encode_get_acl", + "encode_get_config", + "encode_get_stats", + "encode_handshake", + "encode_list_api_keys", + "encode_list_config", + "encode_list_queues", + "encode_nack", + "encode_pong", + "encode_redrive", + "encode_revoke_api_key", + "encode_set_acl", + "encode_set_config", +] diff --git a/fila/fibp/codec.py b/fila/fibp/codec.py new file mode 100644 index 0000000..a7544c2 --- /dev/null +++ b/fila/fibp/codec.py @@ -0,0 +1,412 @@ +"""Encode/decode functions for every FIBP opcode.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from fila.fibp.primitives import Reader, Writer + +# --------------------------------------------------------------------------- +# Data types used by decode functions +# --------------------------------------------------------------------------- + +@dataclass(frozen=True, slots=True) +class DeliveryMessage: + """A single message within a Delivery frame.""" + + message_id: str + queue: str + headers: dict[str, str] + payload: bytes + fairness_key: str + weight: int + throttle_keys: list[str] + attempt_count: int + enqueued_at: int + leased_at: int + + +@dataclass(frozen=True, slots=True) +class EnqueueResultItem: + """Per-message result within an EnqueueResult frame.""" + + error_code: int + message_id: str + + +@dataclass(frozen=True, slots=True) +class ErrorFrame: + """Decoded Error frame.""" + + code: int + message: str + metadata: dict[str, str] + + +@dataclass(frozen=True, slots=True) +class StatsResult: + """Decoded GetStatsResult frame.""" + + stats: dict[str, str] + + +@dataclass(frozen=True, slots=True) +class 
QueueInfo: + """A single queue in ListQueuesResult.""" + + name: str + config: dict[str, str] + + +@dataclass(frozen=True, slots=True) +class ApiKeyInfo: + """A single API key in ListApiKeysResult.""" + + key_id: str + prefix: str + created_at: int + + +@dataclass(frozen=True, slots=True) +class AclEntry: + """Decoded GetAclResult.""" + + patterns: list[str] + superadmin: bool + + +# --------------------------------------------------------------------------- +# Encode: Control +# --------------------------------------------------------------------------- + +def encode_handshake(version: int, api_key: str | None = None) -> bytes: + """Encode a Handshake frame body.""" + w = Writer() + w.write_u16(version) + w.write_optional_string(api_key) + return w.finish() + + +def encode_pong() -> bytes: + """Encode a Pong frame body (empty).""" + return b"" + + +def encode_disconnect() -> bytes: + """Encode a Disconnect frame body (empty).""" + return b"" + + +# --------------------------------------------------------------------------- +# Decode: Control +# --------------------------------------------------------------------------- + +def decode_handshake_ok(data: bytes) -> tuple[int, int, int]: + """Decode a HandshakeOk frame body -> (version, node_id, max_frame_size).""" + r = Reader(data) + version = r.read_u16() + node_id = r.read_u64() + max_frame_size = r.read_u32() + return version, node_id, max_frame_size + + +# --------------------------------------------------------------------------- +# Encode: Hot-path +# --------------------------------------------------------------------------- + +def encode_enqueue(messages: list[dict[str, object]]) -> bytes: + """Encode an Enqueue frame body. + + Each message dict has keys: queue (str), headers (dict[str,str]), payload (bytes). 
+ """ + w = Writer() + w.write_u32(len(messages)) + for msg in messages: + w.write_string(str(msg["queue"])) + w.write_string_map(msg.get("headers") or {}) # type: ignore[arg-type] + w.write_bytes(msg.get("payload", b"") or b"") # type: ignore[arg-type] + return w.finish() + + +def encode_consume(queue: str) -> bytes: + """Encode a Consume frame body.""" + w = Writer() + w.write_string(queue) + return w.finish() + + +def encode_cancel_consume(consumer_id: str) -> bytes: + """Encode a CancelConsume frame body.""" + w = Writer() + w.write_string(consumer_id) + return w.finish() + + +def encode_ack(items: list[dict[str, str]]) -> bytes: + """Encode an Ack frame body. Each item: {queue, message_id}.""" + w = Writer() + w.write_u32(len(items)) + for item in items: + w.write_string(item["queue"]) + w.write_string(item["message_id"]) + return w.finish() + + +def encode_nack(items: list[dict[str, str]]) -> bytes: + """Encode a Nack frame body. Each item: {queue, message_id, error}.""" + w = Writer() + w.write_u32(len(items)) + for item in items: + w.write_string(item["queue"]) + w.write_string(item["message_id"]) + w.write_string(item.get("error", "")) + return w.finish() + + +# --------------------------------------------------------------------------- +# Decode: Hot-path +# --------------------------------------------------------------------------- + +def decode_enqueue_result(data: bytes) -> list[EnqueueResultItem]: + """Decode an EnqueueResult frame body.""" + r = Reader(data) + count = r.read_u32() + results: list[EnqueueResultItem] = [] + for _ in range(count): + error_code = r.read_u8() + message_id = r.read_string() + results.append(EnqueueResultItem(error_code=error_code, message_id=message_id)) + return results + + +def decode_consume_ok(data: bytes) -> str: + """Decode a ConsumeOk frame body -> consumer_id.""" + r = Reader(data) + return r.read_string() + + +def decode_delivery(data: bytes) -> list[DeliveryMessage]: + """Decode a Delivery frame body.""" + r = 
Reader(data) + count = r.read_u32() + messages: list[DeliveryMessage] = [] + for _ in range(count): + msg_id = r.read_string() + queue = r.read_string() + headers = r.read_string_map() + payload = r.read_bytes() + fairness_key = r.read_string() + weight = r.read_u32() + throttle_keys = r.read_string_list() + attempt_count = r.read_u32() + enqueued_at = r.read_u64() + leased_at = r.read_u64() + messages.append(DeliveryMessage( + message_id=msg_id, + queue=queue, + headers=headers, + payload=payload, + fairness_key=fairness_key, + weight=weight, + throttle_keys=throttle_keys, + attempt_count=attempt_count, + enqueued_at=enqueued_at, + leased_at=leased_at, + )) + return messages + + +def decode_ack_result(data: bytes) -> list[int]: + """Decode an AckResult frame body -> list of error codes.""" + r = Reader(data) + count = r.read_u32() + return [r.read_u8() for _ in range(count)] + + +def decode_nack_result(data: bytes) -> list[int]: + """Decode a NackResult frame body -> list of error codes.""" + r = Reader(data) + count = r.read_u32() + return [r.read_u8() for _ in range(count)] + + +# --------------------------------------------------------------------------- +# Decode: Error +# --------------------------------------------------------------------------- + +def decode_error(data: bytes) -> ErrorFrame: + """Decode an Error frame body.""" + r = Reader(data) + code = r.read_u8() + message = r.read_string() + metadata = r.read_string_map() + return ErrorFrame(code=code, message=message, metadata=metadata) + + +# --------------------------------------------------------------------------- +# Encode: Admin +# --------------------------------------------------------------------------- + +def encode_create_queue(name: str, config: dict[str, str] | None = None) -> bytes: + """Encode a CreateQueue frame body.""" + w = Writer() + w.write_string(name) + w.write_string_map(config or {}) + return w.finish() + + +def encode_delete_queue(name: str) -> bytes: + """Encode a DeleteQueue 
frame body.""" + w = Writer() + w.write_string(name) + return w.finish() + + +def encode_get_stats(queue: str) -> bytes: + """Encode a GetStats frame body.""" + w = Writer() + w.write_string(queue) + return w.finish() + + +def encode_list_queues() -> bytes: + """Encode a ListQueues frame body (empty).""" + return b"" + + +def encode_set_config(queue: str, config: dict[str, str]) -> bytes: + """Encode a SetConfig frame body.""" + w = Writer() + w.write_string(queue) + w.write_string_map(config) + return w.finish() + + +def encode_get_config(queue: str) -> bytes: + """Encode a GetConfig frame body.""" + w = Writer() + w.write_string(queue) + return w.finish() + + +def encode_list_config(queue: str) -> bytes: + """Encode a ListConfig frame body.""" + w = Writer() + w.write_string(queue) + return w.finish() + + +def encode_redrive(source_queue: str, dest_queue: str, count: int) -> bytes: + """Encode a Redrive frame body.""" + w = Writer() + w.write_string(source_queue) + w.write_string(dest_queue) + w.write_u32(count) + return w.finish() + + +# --------------------------------------------------------------------------- +# Decode: Admin results +# --------------------------------------------------------------------------- + +def _decode_simple_result(data: bytes) -> int: + """Decode a simple result frame that contains just an error code.""" + r = Reader(data) + return r.read_u8() + + +def decode_get_stats_result(data: bytes) -> StatsResult: + """Decode a GetStatsResult frame body.""" + r = Reader(data) + stats = r.read_string_map() + return StatsResult(stats=stats) + + +def decode_list_queues_result(data: bytes) -> list[str]: + """Decode a ListQueuesResult frame body -> list of queue names.""" + r = Reader(data) + return r.read_string_list() + + +def decode_get_config_result(data: bytes) -> dict[str, str]: + """Decode a GetConfigResult frame body -> config map.""" + r = Reader(data) + return r.read_string_map() + + +def decode_list_config_result(data: bytes) -> 
dict[str, str]: + """Decode a ListConfigResult frame body -> config map.""" + r = Reader(data) + return r.read_string_map() + + +# --------------------------------------------------------------------------- +# Encode: Auth +# --------------------------------------------------------------------------- + +def encode_create_api_key(name: str) -> bytes: + """Encode a CreateApiKey frame body.""" + w = Writer() + w.write_string(name) + return w.finish() + + +def encode_revoke_api_key(key_id: str) -> bytes: + """Encode a RevokeApiKey frame body.""" + w = Writer() + w.write_string(key_id) + return w.finish() + + +def encode_list_api_keys() -> bytes: + """Encode a ListApiKeys frame body (empty).""" + return b"" + + +def encode_set_acl(key_id: str, patterns: list[str], superadmin: bool = False) -> bytes: + """Encode a SetAcl frame body.""" + w = Writer() + w.write_string(key_id) + w.write_string_list(patterns) + w.write_bool(superadmin) + return w.finish() + + +def encode_get_acl(key_id: str) -> bytes: + """Encode a GetAcl frame body.""" + w = Writer() + w.write_string(key_id) + return w.finish() + + +# --------------------------------------------------------------------------- +# Decode: Auth results +# --------------------------------------------------------------------------- + +def decode_create_api_key_result(data: bytes) -> tuple[str, str]: + """Decode a CreateApiKeyResult -> (key_id, raw_key).""" + r = Reader(data) + key_id = r.read_string() + raw_key = r.read_string() + return key_id, raw_key + + +def decode_list_api_keys_result(data: bytes) -> list[ApiKeyInfo]: + """Decode a ListApiKeysResult -> list of ApiKeyInfo.""" + r = Reader(data) + count = r.read_u16() + keys: list[ApiKeyInfo] = [] + for _ in range(count): + key_id = r.read_string() + prefix = r.read_string() + created_at = r.read_u64() + keys.append(ApiKeyInfo(key_id=key_id, prefix=prefix, created_at=created_at)) + return keys + + +def decode_get_acl_result(data: bytes) -> AclEntry: + """Decode a 
GetAclResult -> AclEntry.""" + r = Reader(data) + patterns = r.read_string_list() + superadmin = r.read_bool() + return AclEntry(patterns=patterns, superadmin=superadmin) diff --git a/fila/fibp/opcodes.py b/fila/fibp/opcodes.py new file mode 100644 index 0000000..b50f27e --- /dev/null +++ b/fila/fibp/opcodes.py @@ -0,0 +1,120 @@ +"""FIBP opcode constants, error codes, and frame header definition.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import IntEnum + +# --------------------------------------------------------------------------- +# Protocol constants +# --------------------------------------------------------------------------- + +PROTOCOL_VERSION: int = 1 +DEFAULT_MAX_FRAME_SIZE: int = 16 * 1024 * 1024 # 16 MiB +FRAME_HEADER_SIZE: int = 6 # opcode(1) + flags(1) + request_id(4) +CONTINUATION_FLAG: int = 0x01 + + +# --------------------------------------------------------------------------- +# Opcodes +# --------------------------------------------------------------------------- + +class Opcode(IntEnum): + """FIBP opcodes.""" + + # Control + HANDSHAKE = 0x01 + HANDSHAKE_OK = 0x02 + PING = 0x03 + PONG = 0x04 + DISCONNECT = 0x05 + + # Hot-path + ENQUEUE = 0x10 + ENQUEUE_RESULT = 0x11 + CONSUME = 0x12 + DELIVERY = 0x13 + CANCEL_CONSUME = 0x14 + ACK = 0x15 + ACK_RESULT = 0x16 + NACK = 0x17 + NACK_RESULT = 0x18 + CONSUME_OK = 0x19 + + # Error + ERROR = 0xFE + + # Admin + CREATE_QUEUE = 0xFD + CREATE_QUEUE_RESULT = 0xFC + DELETE_QUEUE = 0xFB + DELETE_QUEUE_RESULT = 0xFA + GET_STATS = 0xF9 + GET_STATS_RESULT = 0xF8 + LIST_QUEUES = 0xF7 + LIST_QUEUES_RESULT = 0xF6 + SET_CONFIG = 0xF5 + SET_CONFIG_RESULT = 0xF4 + GET_CONFIG = 0xF3 + GET_CONFIG_RESULT = 0xF2 + LIST_CONFIG = 0xF1 + LIST_CONFIG_RESULT = 0xF0 + REDRIVE = 0xEF + REDRIVE_RESULT = 0xEE + + # Auth + CREATE_API_KEY = 0xED + CREATE_API_KEY_RESULT = 0xEC + REVOKE_API_KEY = 0xEB + REVOKE_API_KEY_RESULT = 0xEA + LIST_API_KEYS = 0xE9 + LIST_API_KEYS_RESULT = 0xE8 + SET_ACL 
= 0xE7 + SET_ACL_RESULT = 0xE6 + GET_ACL = 0xE5 + GET_ACL_RESULT = 0xE4 + + +# --------------------------------------------------------------------------- +# Error codes +# --------------------------------------------------------------------------- + +class ErrorCode(IntEnum): + """FIBP error codes returned in Error frames and per-item results.""" + + OK = 0x00 + QUEUE_NOT_FOUND = 0x01 + MESSAGE_NOT_FOUND = 0x02 + QUEUE_ALREADY_EXISTS = 0x03 + INVALID_ARGUMENT = 0x04 + DEADLINE_EXCEEDED = 0x05 + PERMISSION_DENIED = 0x06 + RESOURCE_EXHAUSTED = 0x07 + PRECONDITION_FAILED = 0x08 + ABORTED = 0x09 + UNAVAILABLE = 0x0A + UNAUTHENTICATED = 0x0B + NOT_LEADER = 0x0C + LUA_ERROR = 0x0D + CHANNEL_FULL = 0x0E + FORBIDDEN = 0x0F + API_KEY_NOT_FOUND = 0x10 + ACL_NOT_FOUND = 0x11 + INTERNAL_ERROR = 0xFF + + +# --------------------------------------------------------------------------- +# Frame header +# --------------------------------------------------------------------------- + +@dataclass(frozen=True, slots=True) +class FrameHeader: + """Parsed 6-byte FIBP frame header.""" + + opcode: int + flags: int + request_id: int + + @property + def is_continuation(self) -> bool: + return bool(self.flags & CONTINUATION_FLAG) diff --git a/fila/fibp/primitives.py b/fila/fibp/primitives.py new file mode 100644 index 0000000..728a650 --- /dev/null +++ b/fila/fibp/primitives.py @@ -0,0 +1,167 @@ +"""Low-level encoding/decoding primitives for the FIBP wire format. + +All multi-byte integers are big-endian. Strings are length-prefixed with +a u16 byte count followed by UTF-8 bytes. Byte slices use a u32 length +prefix. 
+""" + +from __future__ import annotations + +import struct + +# --------------------------------------------------------------------------- +# Writer +# --------------------------------------------------------------------------- + +class Writer: + """Accumulates bytes for a FIBP frame body.""" + + __slots__ = ("_buf",) + + def __init__(self) -> None: + self._buf = bytearray() + + # -- scalars ------------------------------------------------------------- + + def write_u8(self, v: int) -> None: + self._buf.append(v & 0xFF) + + def write_u16(self, v: int) -> None: + self._buf.extend(struct.pack("!H", v)) + + def write_u32(self, v: int) -> None: + self._buf.extend(struct.pack("!I", v)) + + def write_u64(self, v: int) -> None: + self._buf.extend(struct.pack("!Q", v)) + + def write_i64(self, v: int) -> None: + self._buf.extend(struct.pack("!q", v)) + + def write_f64(self, v: float) -> None: + self._buf.extend(struct.pack("!d", v)) + + def write_bool(self, v: bool) -> None: + self._buf.append(1 if v else 0) + + # -- composites ---------------------------------------------------------- + + def write_string(self, s: str) -> None: + encoded = s.encode("utf-8") + self.write_u16(len(encoded)) + self._buf.extend(encoded) + + def write_bytes(self, b: bytes) -> None: + self.write_u32(len(b)) + self._buf.extend(b) + + def write_string_map(self, m: dict[str, str]) -> None: + self.write_u16(len(m)) + for k, v in m.items(): + self.write_string(k) + self.write_string(v) + + def write_string_list(self, items: list[str]) -> None: + self.write_u16(len(items)) + for s in items: + self.write_string(s) + + def write_optional_string(self, s: str | None) -> None: + if s is None: + self.write_u8(0) + else: + self.write_u8(1) + self.write_string(s) + + # -- access -------------------------------------------------------------- + + def finish(self) -> bytes: + return bytes(self._buf) + + +# --------------------------------------------------------------------------- +# Reader +# 
--------------------------------------------------------------------------- + +class Reader: + """Reads primitive values from a FIBP frame body with position tracking.""" + + __slots__ = ("_data", "_pos") + + def __init__(self, data: bytes | bytearray | memoryview) -> None: + self._data = bytes(data) + self._pos = 0 + + @property + def remaining(self) -> int: + return len(self._data) - self._pos + + # -- scalars ------------------------------------------------------------- + + def read_u8(self) -> int: + v = self._data[self._pos] + self._pos += 1 + return v + + def read_u16(self) -> int: + v = struct.unpack_from("!H", self._data, self._pos)[0] + self._pos += 2 + return v + + def read_u32(self) -> int: + v = struct.unpack_from("!I", self._data, self._pos)[0] + self._pos += 4 + return v + + def read_u64(self) -> int: + v = struct.unpack_from("!Q", self._data, self._pos)[0] + self._pos += 8 + return v + + def read_i64(self) -> int: + v = struct.unpack_from("!q", self._data, self._pos)[0] + self._pos += 8 + return v + + def read_f64(self) -> float: + v: float = struct.unpack_from("!d", self._data, self._pos)[0] + self._pos += 8 + return v + + def read_bool(self) -> bool: + return self.read_u8() != 0 + + # -- composites ---------------------------------------------------------- + + def read_string(self) -> str: + length = self.read_u16() + end = self._pos + length + s = self._data[self._pos:end].decode("utf-8") + self._pos = end + return s + + def read_bytes(self) -> bytes: + length = self.read_u32() + end = self._pos + length + b = self._data[self._pos:end] + self._pos = end + return b + + def read_string_map(self) -> dict[str, str]: + count = self.read_u16() + m: dict[str, str] = {} + for _ in range(count): + k = self.read_string() + v = self.read_string() + m[k] = v + return m + + def read_string_list(self) -> list[str]: + count = self.read_u16() + return [self.read_string() for _ in range(count)] + + def read_optional_string(self) -> str | None: + present = 
self.read_u8() + if present: + return self.read_string() + return None diff --git a/fila/types.py b/fila/types.py index a73c15a..749025c 100644 --- a/fila/types.py +++ b/fila/types.py @@ -16,6 +16,10 @@ class ConsumeMessage: fairness_key: str attempt_count: int queue: str + weight: int = 0 + throttle_keys: list[str] | None = None + enqueued_at: int = 0 + leased_at: int = 0 @dataclass(frozen=True) @@ -61,3 +65,35 @@ class Linger: linger_ms: float max_messages: int + + +@dataclass(frozen=True) +class CreateApiKeyResult: + """Result of creating an API key.""" + + key_id: str + raw_key: str + + +@dataclass(frozen=True) +class ApiKeyInfo: + """Summary information about an API key.""" + + key_id: str + prefix: str + created_at: int + + +@dataclass(frozen=True) +class AclEntry: + """ACL entry for an API key.""" + + patterns: list[str] + superadmin: bool + + +@dataclass(frozen=True) +class StatsResult: + """Queue statistics.""" + + stats: dict[str, str] diff --git a/fila/v1/__init__.py b/fila/v1/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/fila/v1/admin_pb2.py b/fila/v1/admin_pb2.py deleted file mode 100644 index 7dc1f25..0000000 --- a/fila/v1/admin_pb2.py +++ /dev/null @@ -1,78 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# NO CHECKED-IN PROTOBUF GENCODE -# source: fila/v1/admin.proto -# Protobuf Python Version: 6.31.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 6, - 31, - 1, - '', - 'fila/v1/admin.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13\x66ila/v1/admin.proto\x12\x07\x66ila.v1\"H\n\x12\x43reateQueueRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12$\n\x06\x63onfig\x18\x02 \x01(\x0b\x32\x14.fila.v1.QueueConfig\"b\n\x0bQueueConfig\x12\x19\n\x11on_enqueue_script\x18\x01 \x01(\t\x12\x19\n\x11on_failure_script\x18\x02 \x01(\t\x12\x1d\n\x15visibility_timeout_ms\x18\x03 \x01(\x04\"\'\n\x13\x43reateQueueResponse\x12\x10\n\x08queue_id\x18\x01 \x01(\t\"#\n\x12\x44\x65leteQueueRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"\x15\n\x13\x44\x65leteQueueResponse\".\n\x10SetConfigRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"\x13\n\x11SetConfigResponse\"\x1f\n\x10GetConfigRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\"\"\n\x11GetConfigResponse\x12\r\n\x05value\x18\x01 \x01(\t\")\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"#\n\x11ListConfigRequest\x12\x0e\n\x06prefix\x18\x01 \x01(\t\"P\n\x12ListConfigResponse\x12%\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x14.fila.v1.ConfigEntry\x12\x13\n\x0btotal_count\x18\x02 \x01(\r\" \n\x0fGetStatsRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"b\n\x13PerFairnessKeyStats\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x15\n\rpending_count\x18\x02 \x01(\x04\x12\x17\n\x0f\x63urrent_deficit\x18\x03 
\x01(\x03\x12\x0e\n\x06weight\x18\x04 \x01(\r\"Z\n\x13PerThrottleKeyStats\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x0e\n\x06tokens\x18\x02 \x01(\x01\x12\x17\n\x0frate_per_second\x18\x03 \x01(\x01\x12\r\n\x05\x62urst\x18\x04 \x01(\x01\"\x9f\x02\n\x10GetStatsResponse\x12\r\n\x05\x64\x65pth\x18\x01 \x01(\x04\x12\x11\n\tin_flight\x18\x02 \x01(\x04\x12\x1c\n\x14\x61\x63tive_fairness_keys\x18\x03 \x01(\x04\x12\x18\n\x10\x61\x63tive_consumers\x18\x04 \x01(\r\x12\x0f\n\x07quantum\x18\x05 \x01(\r\x12\x33\n\rper_key_stats\x18\x06 \x03(\x0b\x32\x1c.fila.v1.PerFairnessKeyStats\x12\x38\n\x12per_throttle_stats\x18\x07 \x03(\x0b\x32\x1c.fila.v1.PerThrottleKeyStats\x12\x16\n\x0eleader_node_id\x18\x08 \x01(\x04\x12\x19\n\x11replication_count\x18\t \x01(\r\"2\n\x0eRedriveRequest\x12\x11\n\tdlq_queue\x18\x01 \x01(\t\x12\r\n\x05\x63ount\x18\x02 \x01(\x04\"#\n\x0fRedriveResponse\x12\x10\n\x08redriven\x18\x01 \x01(\x04\"\x13\n\x11ListQueuesRequest\"m\n\tQueueInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05\x64\x65pth\x18\x02 \x01(\x04\x12\x11\n\tin_flight\x18\x03 \x01(\x04\x12\x18\n\x10\x61\x63tive_consumers\x18\x04 \x01(\r\x12\x16\n\x0eleader_node_id\x18\x05 \x01(\x04\"T\n\x12ListQueuesResponse\x12\"\n\x06queues\x18\x01 \x03(\x0b\x32\x12.fila.v1.QueueInfo\x12\x1a\n\x12\x63luster_node_count\x18\x02 
\x01(\r2\xb4\x04\n\tFilaAdmin\x12H\n\x0b\x43reateQueue\x12\x1b.fila.v1.CreateQueueRequest\x1a\x1c.fila.v1.CreateQueueResponse\x12H\n\x0b\x44\x65leteQueue\x12\x1b.fila.v1.DeleteQueueRequest\x1a\x1c.fila.v1.DeleteQueueResponse\x12\x42\n\tSetConfig\x12\x19.fila.v1.SetConfigRequest\x1a\x1a.fila.v1.SetConfigResponse\x12\x42\n\tGetConfig\x12\x19.fila.v1.GetConfigRequest\x1a\x1a.fila.v1.GetConfigResponse\x12\x45\n\nListConfig\x12\x1a.fila.v1.ListConfigRequest\x1a\x1b.fila.v1.ListConfigResponse\x12?\n\x08GetStats\x12\x18.fila.v1.GetStatsRequest\x1a\x19.fila.v1.GetStatsResponse\x12<\n\x07Redrive\x12\x17.fila.v1.RedriveRequest\x1a\x18.fila.v1.RedriveResponse\x12\x45\n\nListQueues\x12\x1a.fila.v1.ListQueuesRequest\x1a\x1b.fila.v1.ListQueuesResponseb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fila.v1.admin_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_CREATEQUEUEREQUEST']._serialized_start=32 - _globals['_CREATEQUEUEREQUEST']._serialized_end=104 - _globals['_QUEUECONFIG']._serialized_start=106 - _globals['_QUEUECONFIG']._serialized_end=204 - _globals['_CREATEQUEUERESPONSE']._serialized_start=206 - _globals['_CREATEQUEUERESPONSE']._serialized_end=245 - _globals['_DELETEQUEUEREQUEST']._serialized_start=247 - _globals['_DELETEQUEUEREQUEST']._serialized_end=282 - _globals['_DELETEQUEUERESPONSE']._serialized_start=284 - _globals['_DELETEQUEUERESPONSE']._serialized_end=305 - _globals['_SETCONFIGREQUEST']._serialized_start=307 - _globals['_SETCONFIGREQUEST']._serialized_end=353 - _globals['_SETCONFIGRESPONSE']._serialized_start=355 - _globals['_SETCONFIGRESPONSE']._serialized_end=374 - _globals['_GETCONFIGREQUEST']._serialized_start=376 - _globals['_GETCONFIGREQUEST']._serialized_end=407 - _globals['_GETCONFIGRESPONSE']._serialized_start=409 - _globals['_GETCONFIGRESPONSE']._serialized_end=443 - 
_globals['_CONFIGENTRY']._serialized_start=445 - _globals['_CONFIGENTRY']._serialized_end=486 - _globals['_LISTCONFIGREQUEST']._serialized_start=488 - _globals['_LISTCONFIGREQUEST']._serialized_end=523 - _globals['_LISTCONFIGRESPONSE']._serialized_start=525 - _globals['_LISTCONFIGRESPONSE']._serialized_end=605 - _globals['_GETSTATSREQUEST']._serialized_start=607 - _globals['_GETSTATSREQUEST']._serialized_end=639 - _globals['_PERFAIRNESSKEYSTATS']._serialized_start=641 - _globals['_PERFAIRNESSKEYSTATS']._serialized_end=739 - _globals['_PERTHROTTLEKEYSTATS']._serialized_start=741 - _globals['_PERTHROTTLEKEYSTATS']._serialized_end=831 - _globals['_GETSTATSRESPONSE']._serialized_start=834 - _globals['_GETSTATSRESPONSE']._serialized_end=1121 - _globals['_REDRIVEREQUEST']._serialized_start=1123 - _globals['_REDRIVEREQUEST']._serialized_end=1173 - _globals['_REDRIVERESPONSE']._serialized_start=1175 - _globals['_REDRIVERESPONSE']._serialized_end=1210 - _globals['_LISTQUEUESREQUEST']._serialized_start=1212 - _globals['_LISTQUEUESREQUEST']._serialized_end=1231 - _globals['_QUEUEINFO']._serialized_start=1233 - _globals['_QUEUEINFO']._serialized_end=1342 - _globals['_LISTQUEUESRESPONSE']._serialized_start=1344 - _globals['_LISTQUEUESRESPONSE']._serialized_end=1428 - _globals['_FILAADMIN']._serialized_start=1431 - _globals['_FILAADMIN']._serialized_end=1995 -# @@protoc_insertion_point(module_scope) diff --git a/fila/v1/admin_pb2.pyi b/fila/v1/admin_pb2.pyi deleted file mode 100644 index 0c594ce..0000000 --- a/fila/v1/admin_pb2.pyi +++ /dev/null @@ -1,179 +0,0 @@ -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from collections.abc import Iterable as _Iterable, Mapping as _Mapping -from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class CreateQueueRequest(_message.Message): - __slots__ 
= ("name", "config") - NAME_FIELD_NUMBER: _ClassVar[int] - CONFIG_FIELD_NUMBER: _ClassVar[int] - name: str - config: QueueConfig - def __init__(self, name: _Optional[str] = ..., config: _Optional[_Union[QueueConfig, _Mapping]] = ...) -> None: ... - -class QueueConfig(_message.Message): - __slots__ = ("on_enqueue_script", "on_failure_script", "visibility_timeout_ms") - ON_ENQUEUE_SCRIPT_FIELD_NUMBER: _ClassVar[int] - ON_FAILURE_SCRIPT_FIELD_NUMBER: _ClassVar[int] - VISIBILITY_TIMEOUT_MS_FIELD_NUMBER: _ClassVar[int] - on_enqueue_script: str - on_failure_script: str - visibility_timeout_ms: int - def __init__(self, on_enqueue_script: _Optional[str] = ..., on_failure_script: _Optional[str] = ..., visibility_timeout_ms: _Optional[int] = ...) -> None: ... - -class CreateQueueResponse(_message.Message): - __slots__ = ("queue_id",) - QUEUE_ID_FIELD_NUMBER: _ClassVar[int] - queue_id: str - def __init__(self, queue_id: _Optional[str] = ...) -> None: ... - -class DeleteQueueRequest(_message.Message): - __slots__ = ("queue",) - QUEUE_FIELD_NUMBER: _ClassVar[int] - queue: str - def __init__(self, queue: _Optional[str] = ...) -> None: ... - -class DeleteQueueResponse(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class SetConfigRequest(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - -class SetConfigResponse(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class GetConfigRequest(_message.Message): - __slots__ = ("key",) - KEY_FIELD_NUMBER: _ClassVar[int] - key: str - def __init__(self, key: _Optional[str] = ...) -> None: ... - -class GetConfigResponse(_message.Message): - __slots__ = ("value",) - VALUE_FIELD_NUMBER: _ClassVar[int] - value: str - def __init__(self, value: _Optional[str] = ...) -> None: ... 
- -class ConfigEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - -class ListConfigRequest(_message.Message): - __slots__ = ("prefix",) - PREFIX_FIELD_NUMBER: _ClassVar[int] - prefix: str - def __init__(self, prefix: _Optional[str] = ...) -> None: ... - -class ListConfigResponse(_message.Message): - __slots__ = ("entries", "total_count") - ENTRIES_FIELD_NUMBER: _ClassVar[int] - TOTAL_COUNT_FIELD_NUMBER: _ClassVar[int] - entries: _containers.RepeatedCompositeFieldContainer[ConfigEntry] - total_count: int - def __init__(self, entries: _Optional[_Iterable[_Union[ConfigEntry, _Mapping]]] = ..., total_count: _Optional[int] = ...) -> None: ... - -class GetStatsRequest(_message.Message): - __slots__ = ("queue",) - QUEUE_FIELD_NUMBER: _ClassVar[int] - queue: str - def __init__(self, queue: _Optional[str] = ...) -> None: ... - -class PerFairnessKeyStats(_message.Message): - __slots__ = ("key", "pending_count", "current_deficit", "weight") - KEY_FIELD_NUMBER: _ClassVar[int] - PENDING_COUNT_FIELD_NUMBER: _ClassVar[int] - CURRENT_DEFICIT_FIELD_NUMBER: _ClassVar[int] - WEIGHT_FIELD_NUMBER: _ClassVar[int] - key: str - pending_count: int - current_deficit: int - weight: int - def __init__(self, key: _Optional[str] = ..., pending_count: _Optional[int] = ..., current_deficit: _Optional[int] = ..., weight: _Optional[int] = ...) -> None: ... 
- -class PerThrottleKeyStats(_message.Message): - __slots__ = ("key", "tokens", "rate_per_second", "burst") - KEY_FIELD_NUMBER: _ClassVar[int] - TOKENS_FIELD_NUMBER: _ClassVar[int] - RATE_PER_SECOND_FIELD_NUMBER: _ClassVar[int] - BURST_FIELD_NUMBER: _ClassVar[int] - key: str - tokens: float - rate_per_second: float - burst: float - def __init__(self, key: _Optional[str] = ..., tokens: _Optional[float] = ..., rate_per_second: _Optional[float] = ..., burst: _Optional[float] = ...) -> None: ... - -class GetStatsResponse(_message.Message): - __slots__ = ("depth", "in_flight", "active_fairness_keys", "active_consumers", "quantum", "per_key_stats", "per_throttle_stats", "leader_node_id", "replication_count") - DEPTH_FIELD_NUMBER: _ClassVar[int] - IN_FLIGHT_FIELD_NUMBER: _ClassVar[int] - ACTIVE_FAIRNESS_KEYS_FIELD_NUMBER: _ClassVar[int] - ACTIVE_CONSUMERS_FIELD_NUMBER: _ClassVar[int] - QUANTUM_FIELD_NUMBER: _ClassVar[int] - PER_KEY_STATS_FIELD_NUMBER: _ClassVar[int] - PER_THROTTLE_STATS_FIELD_NUMBER: _ClassVar[int] - LEADER_NODE_ID_FIELD_NUMBER: _ClassVar[int] - REPLICATION_COUNT_FIELD_NUMBER: _ClassVar[int] - depth: int - in_flight: int - active_fairness_keys: int - active_consumers: int - quantum: int - per_key_stats: _containers.RepeatedCompositeFieldContainer[PerFairnessKeyStats] - per_throttle_stats: _containers.RepeatedCompositeFieldContainer[PerThrottleKeyStats] - leader_node_id: int - replication_count: int - def __init__(self, depth: _Optional[int] = ..., in_flight: _Optional[int] = ..., active_fairness_keys: _Optional[int] = ..., active_consumers: _Optional[int] = ..., quantum: _Optional[int] = ..., per_key_stats: _Optional[_Iterable[_Union[PerFairnessKeyStats, _Mapping]]] = ..., per_throttle_stats: _Optional[_Iterable[_Union[PerThrottleKeyStats, _Mapping]]] = ..., leader_node_id: _Optional[int] = ..., replication_count: _Optional[int] = ...) -> None: ... 
- -class RedriveRequest(_message.Message): - __slots__ = ("dlq_queue", "count") - DLQ_QUEUE_FIELD_NUMBER: _ClassVar[int] - COUNT_FIELD_NUMBER: _ClassVar[int] - dlq_queue: str - count: int - def __init__(self, dlq_queue: _Optional[str] = ..., count: _Optional[int] = ...) -> None: ... - -class RedriveResponse(_message.Message): - __slots__ = ("redriven",) - REDRIVEN_FIELD_NUMBER: _ClassVar[int] - redriven: int - def __init__(self, redriven: _Optional[int] = ...) -> None: ... - -class ListQueuesRequest(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class QueueInfo(_message.Message): - __slots__ = ("name", "depth", "in_flight", "active_consumers", "leader_node_id") - NAME_FIELD_NUMBER: _ClassVar[int] - DEPTH_FIELD_NUMBER: _ClassVar[int] - IN_FLIGHT_FIELD_NUMBER: _ClassVar[int] - ACTIVE_CONSUMERS_FIELD_NUMBER: _ClassVar[int] - LEADER_NODE_ID_FIELD_NUMBER: _ClassVar[int] - name: str - depth: int - in_flight: int - active_consumers: int - leader_node_id: int - def __init__(self, name: _Optional[str] = ..., depth: _Optional[int] = ..., in_flight: _Optional[int] = ..., active_consumers: _Optional[int] = ..., leader_node_id: _Optional[int] = ...) -> None: ... - -class ListQueuesResponse(_message.Message): - __slots__ = ("queues", "cluster_node_count") - QUEUES_FIELD_NUMBER: _ClassVar[int] - CLUSTER_NODE_COUNT_FIELD_NUMBER: _ClassVar[int] - queues: _containers.RepeatedCompositeFieldContainer[QueueInfo] - cluster_node_count: int - def __init__(self, queues: _Optional[_Iterable[_Union[QueueInfo, _Mapping]]] = ..., cluster_node_count: _Optional[int] = ...) -> None: ... diff --git a/fila/v1/admin_pb2_grpc.py b/fila/v1/admin_pb2_grpc.py deleted file mode 100644 index 3b07e1a..0000000 --- a/fila/v1/admin_pb2_grpc.py +++ /dev/null @@ -1,401 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
-"""Client and server classes corresponding to protobuf-defined services.""" -import grpc -import warnings - -from fila.v1 import admin_pb2 as fila_dot_v1_dot_admin__pb2 - -GRPC_GENERATED_VERSION = '1.78.0' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + ' but the generated code in fila/v1/admin_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' - ) - - -class FilaAdminStub(object): - """Admin RPCs for operators and the CLI. - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.CreateQueue = channel.unary_unary( - '/fila.v1.FilaAdmin/CreateQueue', - request_serializer=fila_dot_v1_dot_admin__pb2.CreateQueueRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.CreateQueueResponse.FromString, - _registered_method=True) - self.DeleteQueue = channel.unary_unary( - '/fila.v1.FilaAdmin/DeleteQueue', - request_serializer=fila_dot_v1_dot_admin__pb2.DeleteQueueRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.DeleteQueueResponse.FromString, - _registered_method=True) - self.SetConfig = channel.unary_unary( - '/fila.v1.FilaAdmin/SetConfig', - request_serializer=fila_dot_v1_dot_admin__pb2.SetConfigRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.SetConfigResponse.FromString, - _registered_method=True) - self.GetConfig = channel.unary_unary( - '/fila.v1.FilaAdmin/GetConfig', - request_serializer=fila_dot_v1_dot_admin__pb2.GetConfigRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.GetConfigResponse.FromString, - _registered_method=True) - self.ListConfig = channel.unary_unary( - '/fila.v1.FilaAdmin/ListConfig', - request_serializer=fila_dot_v1_dot_admin__pb2.ListConfigRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.ListConfigResponse.FromString, - _registered_method=True) - self.GetStats = channel.unary_unary( - '/fila.v1.FilaAdmin/GetStats', - request_serializer=fila_dot_v1_dot_admin__pb2.GetStatsRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.GetStatsResponse.FromString, - _registered_method=True) - self.Redrive = channel.unary_unary( - '/fila.v1.FilaAdmin/Redrive', - request_serializer=fila_dot_v1_dot_admin__pb2.RedriveRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.RedriveResponse.FromString, - _registered_method=True) - self.ListQueues = channel.unary_unary( - '/fila.v1.FilaAdmin/ListQueues', - 
request_serializer=fila_dot_v1_dot_admin__pb2.ListQueuesRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.ListQueuesResponse.FromString, - _registered_method=True) - - -class FilaAdminServicer(object): - """Admin RPCs for operators and the CLI. - """ - - def CreateQueue(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def DeleteQueue(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def SetConfig(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetConfig(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def ListConfig(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetStats(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Redrive(self, request, context): - """Missing associated documentation comment in .proto file.""" - 
context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def ListQueues(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_FilaAdminServicer_to_server(servicer, server): - rpc_method_handlers = { - 'CreateQueue': grpc.unary_unary_rpc_method_handler( - servicer.CreateQueue, - request_deserializer=fila_dot_v1_dot_admin__pb2.CreateQueueRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.CreateQueueResponse.SerializeToString, - ), - 'DeleteQueue': grpc.unary_unary_rpc_method_handler( - servicer.DeleteQueue, - request_deserializer=fila_dot_v1_dot_admin__pb2.DeleteQueueRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.DeleteQueueResponse.SerializeToString, - ), - 'SetConfig': grpc.unary_unary_rpc_method_handler( - servicer.SetConfig, - request_deserializer=fila_dot_v1_dot_admin__pb2.SetConfigRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.SetConfigResponse.SerializeToString, - ), - 'GetConfig': grpc.unary_unary_rpc_method_handler( - servicer.GetConfig, - request_deserializer=fila_dot_v1_dot_admin__pb2.GetConfigRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.GetConfigResponse.SerializeToString, - ), - 'ListConfig': grpc.unary_unary_rpc_method_handler( - servicer.ListConfig, - request_deserializer=fila_dot_v1_dot_admin__pb2.ListConfigRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.ListConfigResponse.SerializeToString, - ), - 'GetStats': grpc.unary_unary_rpc_method_handler( - servicer.GetStats, - request_deserializer=fila_dot_v1_dot_admin__pb2.GetStatsRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.GetStatsResponse.SerializeToString, - ), - 'Redrive': 
grpc.unary_unary_rpc_method_handler( - servicer.Redrive, - request_deserializer=fila_dot_v1_dot_admin__pb2.RedriveRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.RedriveResponse.SerializeToString, - ), - 'ListQueues': grpc.unary_unary_rpc_method_handler( - servicer.ListQueues, - request_deserializer=fila_dot_v1_dot_admin__pb2.ListQueuesRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.ListQueuesResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'fila.v1.FilaAdmin', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers('fila.v1.FilaAdmin', rpc_method_handlers) - - - # This class is part of an EXPERIMENTAL API. -class FilaAdmin(object): - """Admin RPCs for operators and the CLI. - """ - - @staticmethod - def CreateQueue(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/CreateQueue', - fila_dot_v1_dot_admin__pb2.CreateQueueRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.CreateQueueResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def DeleteQueue(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/DeleteQueue', - fila_dot_v1_dot_admin__pb2.DeleteQueueRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.DeleteQueueResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - 
_registered_method=True) - - @staticmethod - def SetConfig(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/SetConfig', - fila_dot_v1_dot_admin__pb2.SetConfigRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.SetConfigResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def GetConfig(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/GetConfig', - fila_dot_v1_dot_admin__pb2.GetConfigRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.GetConfigResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def ListConfig(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/ListConfig', - fila_dot_v1_dot_admin__pb2.ListConfigRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.ListConfigResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def GetStats(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - 
return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/GetStats', - fila_dot_v1_dot_admin__pb2.GetStatsRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.GetStatsResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Redrive(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/Redrive', - fila_dot_v1_dot_admin__pb2.RedriveRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.RedriveResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def ListQueues(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/ListQueues', - fila_dot_v1_dot_admin__pb2.ListQueuesRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.ListQueuesResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) diff --git a/fila/v1/messages_pb2.py b/fila/v1/messages_pb2.py deleted file mode 100644 index 3cf7edb..0000000 --- a/fila/v1/messages_pb2.py +++ /dev/null @@ -1,45 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# NO CHECKED-IN PROTOBUF GENCODE -# source: fila/v1/messages.proto -# Protobuf Python Version: 6.31.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 6, - 31, - 1, - '', - 'fila/v1/messages.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66ila/v1/messages.proto\x12\x07\x66ila.v1\x1a\x1fgoogle/protobuf/timestamp.proto\"\xe2\x01\n\x07Message\x12\n\n\x02id\x18\x01 \x01(\t\x12.\n\x07headers\x18\x02 \x03(\x0b\x32\x1d.fila.v1.Message.HeadersEntry\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12*\n\x08metadata\x18\x04 \x01(\x0b\x32\x18.fila.v1.MessageMetadata\x12.\n\ntimestamps\x18\x05 \x01(\x0b\x32\x1a.fila.v1.MessageTimestamps\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"w\n\x0fMessageMetadata\x12\x14\n\x0c\x66\x61irness_key\x18\x01 \x01(\t\x12\x0e\n\x06weight\x18\x02 \x01(\r\x12\x15\n\rthrottle_keys\x18\x03 \x03(\t\x12\x15\n\rattempt_count\x18\x04 \x01(\r\x12\x10\n\x08queue_id\x18\x05 \x01(\t\"s\n\x11MessageTimestamps\x12/\n\x0b\x65nqueued_at\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12-\n\tleased_at\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestampb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fila.v1.messages_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - 
_globals['_MESSAGE_HEADERSENTRY']._loaded_options = None - _globals['_MESSAGE_HEADERSENTRY']._serialized_options = b'8\001' - _globals['_MESSAGE']._serialized_start=69 - _globals['_MESSAGE']._serialized_end=295 - _globals['_MESSAGE_HEADERSENTRY']._serialized_start=249 - _globals['_MESSAGE_HEADERSENTRY']._serialized_end=295 - _globals['_MESSAGEMETADATA']._serialized_start=297 - _globals['_MESSAGEMETADATA']._serialized_end=416 - _globals['_MESSAGETIMESTAMPS']._serialized_start=418 - _globals['_MESSAGETIMESTAMPS']._serialized_end=533 -# @@protoc_insertion_point(module_scope) diff --git a/fila/v1/messages_pb2.pyi b/fila/v1/messages_pb2.pyi deleted file mode 100644 index a91bb74..0000000 --- a/fila/v1/messages_pb2.pyi +++ /dev/null @@ -1,53 +0,0 @@ -import datetime - -from google.protobuf import timestamp_pb2 as _timestamp_pb2 -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from collections.abc import Iterable as _Iterable, Mapping as _Mapping -from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class Message(_message.Message): - __slots__ = ("id", "headers", "payload", "metadata", "timestamps") - class HeadersEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... 
- ID_FIELD_NUMBER: _ClassVar[int] - HEADERS_FIELD_NUMBER: _ClassVar[int] - PAYLOAD_FIELD_NUMBER: _ClassVar[int] - METADATA_FIELD_NUMBER: _ClassVar[int] - TIMESTAMPS_FIELD_NUMBER: _ClassVar[int] - id: str - headers: _containers.ScalarMap[str, str] - payload: bytes - metadata: MessageMetadata - timestamps: MessageTimestamps - def __init__(self, id: _Optional[str] = ..., headers: _Optional[_Mapping[str, str]] = ..., payload: _Optional[bytes] = ..., metadata: _Optional[_Union[MessageMetadata, _Mapping]] = ..., timestamps: _Optional[_Union[MessageTimestamps, _Mapping]] = ...) -> None: ... - -class MessageMetadata(_message.Message): - __slots__ = ("fairness_key", "weight", "throttle_keys", "attempt_count", "queue_id") - FAIRNESS_KEY_FIELD_NUMBER: _ClassVar[int] - WEIGHT_FIELD_NUMBER: _ClassVar[int] - THROTTLE_KEYS_FIELD_NUMBER: _ClassVar[int] - ATTEMPT_COUNT_FIELD_NUMBER: _ClassVar[int] - QUEUE_ID_FIELD_NUMBER: _ClassVar[int] - fairness_key: str - weight: int - throttle_keys: _containers.RepeatedScalarFieldContainer[str] - attempt_count: int - queue_id: str - def __init__(self, fairness_key: _Optional[str] = ..., weight: _Optional[int] = ..., throttle_keys: _Optional[_Iterable[str]] = ..., attempt_count: _Optional[int] = ..., queue_id: _Optional[str] = ...) -> None: ... - -class MessageTimestamps(_message.Message): - __slots__ = ("enqueued_at", "leased_at") - ENQUEUED_AT_FIELD_NUMBER: _ClassVar[int] - LEASED_AT_FIELD_NUMBER: _ClassVar[int] - enqueued_at: _timestamp_pb2.Timestamp - leased_at: _timestamp_pb2.Timestamp - def __init__(self, enqueued_at: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., leased_at: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ... diff --git a/fila/v1/messages_pb2_grpc.py b/fila/v1/messages_pb2_grpc.py deleted file mode 100644 index fa0dc71..0000000 --- a/fila/v1/messages_pb2_grpc.py +++ /dev/null @@ -1,24 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. 
DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc -import warnings - - -GRPC_GENERATED_VERSION = '1.78.0' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + ' but the generated code in fila/v1/messages_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' - ) diff --git a/fila/v1/service_pb2.py b/fila/v1/service_pb2.py deleted file mode 100644 index 7489260..0000000 --- a/fila/v1/service_pb2.py +++ /dev/null @@ -1,89 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# NO CHECKED-IN PROTOBUF GENCODE -# source: fila/v1/service.proto -# Protobuf Python Version: 6.31.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 6, - 31, - 1, - '', - 'fila/v1/service.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from fila.v1 import messages_pb2 as fila_dot_v1_dot_messages__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66ila/v1/service.proto\x12\x07\x66ila.v1\x1a\x16\x66ila/v1/messages.proto\"\x97\x01\n\x0e\x45nqueueMessage\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x35\n\x07headers\x18\x02 \x03(\x0b\x32$.fila.v1.EnqueueMessage.HeadersEntry\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\";\n\x0e\x45nqueueRequest\x12)\n\x08messages\x18\x01 \x03(\x0b\x32\x17.fila.v1.EnqueueMessage\"W\n\rEnqueueResult\x12\x14\n\nmessage_id\x18\x01 \x01(\tH\x00\x12&\n\x05\x65rror\x18\x02 \x01(\x0b\x32\x15.fila.v1.EnqueueErrorH\x00\x42\x08\n\x06result\"H\n\x0c\x45nqueueError\x12\'\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x19.fila.v1.EnqueueErrorCode\x12\x0f\n\x07message\x18\x02 \x01(\t\":\n\x0f\x45nqueueResponse\x12\'\n\x07results\x18\x01 \x03(\x0b\x32\x16.fila.v1.EnqueueResult\"\x1f\n\x0e\x43onsumeRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"5\n\x0f\x43onsumeResponse\x12\"\n\x08messages\x18\x01 \x03(\x0b\x32\x10.fila.v1.Message\"/\n\nAckMessage\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\"3\n\nAckRequest\x12%\n\x08messages\x18\x01 
\x03(\x0b\x32\x13.fila.v1.AckMessage\"a\n\tAckResult\x12&\n\x07success\x18\x01 \x01(\x0b\x32\x13.fila.v1.AckSuccessH\x00\x12\"\n\x05\x65rror\x18\x02 \x01(\x0b\x32\x11.fila.v1.AckErrorH\x00\x42\x08\n\x06result\"\x0c\n\nAckSuccess\"@\n\x08\x41\x63kError\x12#\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x15.fila.v1.AckErrorCode\x12\x0f\n\x07message\x18\x02 \x01(\t\"2\n\x0b\x41\x63kResponse\x12#\n\x07results\x18\x01 \x03(\x0b\x32\x12.fila.v1.AckResult\"?\n\x0bNackMessage\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 \x01(\t\"5\n\x0bNackRequest\x12&\n\x08messages\x18\x01 \x03(\x0b\x32\x14.fila.v1.NackMessage\"d\n\nNackResult\x12\'\n\x07success\x18\x01 \x01(\x0b\x32\x14.fila.v1.NackSuccessH\x00\x12#\n\x05\x65rror\x18\x02 \x01(\x0b\x32\x12.fila.v1.NackErrorH\x00\x42\x08\n\x06result\"\r\n\x0bNackSuccess\"B\n\tNackError\x12$\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x16.fila.v1.NackErrorCode\x12\x0f\n\x07message\x18\x02 \x01(\t\"4\n\x0cNackResponse\x12$\n\x07results\x18\x01 \x03(\x0b\x32\x13.fila.v1.NackResult\"Z\n\x14StreamEnqueueRequest\x12)\n\x08messages\x18\x01 \x03(\x0b\x32\x17.fila.v1.EnqueueMessage\x12\x17\n\x0fsequence_number\x18\x02 \x01(\x04\"Y\n\x15StreamEnqueueResponse\x12\x17\n\x0fsequence_number\x18\x01 \x01(\x04\x12\'\n\x07results\x18\x02 \x03(\x0b\x32\x16.fila.v1.EnqueueResult*\xc4\x01\n\x10\x45nqueueErrorCode\x12\"\n\x1e\x45NQUEUE_ERROR_CODE_UNSPECIFIED\x10\x00\x12&\n\"ENQUEUE_ERROR_CODE_QUEUE_NOT_FOUND\x10\x01\x12\x1e\n\x1a\x45NQUEUE_ERROR_CODE_STORAGE\x10\x02\x12\x1a\n\x16\x45NQUEUE_ERROR_CODE_LUA\x10\x03\x12(\n$ENQUEUE_ERROR_CODE_PERMISSION_DENIED\x10\x04*\x96\x01\n\x0c\x41\x63kErrorCode\x12\x1e\n\x1a\x41\x43K_ERROR_CODE_UNSPECIFIED\x10\x00\x12$\n ACK_ERROR_CODE_MESSAGE_NOT_FOUND\x10\x01\x12\x1a\n\x16\x41\x43K_ERROR_CODE_STORAGE\x10\x02\x12$\n 
ACK_ERROR_CODE_PERMISSION_DENIED\x10\x03*\x9b\x01\n\rNackErrorCode\x12\x1f\n\x1bNACK_ERROR_CODE_UNSPECIFIED\x10\x00\x12%\n!NACK_ERROR_CODE_MESSAGE_NOT_FOUND\x10\x01\x12\x1b\n\x17NACK_ERROR_CODE_STORAGE\x10\x02\x12%\n!NACK_ERROR_CODE_PERMISSION_DENIED\x10\x03\x32\xc6\x02\n\x0b\x46ilaService\x12<\n\x07\x45nqueue\x12\x17.fila.v1.EnqueueRequest\x1a\x18.fila.v1.EnqueueResponse\x12R\n\rStreamEnqueue\x12\x1d.fila.v1.StreamEnqueueRequest\x1a\x1e.fila.v1.StreamEnqueueResponse(\x01\x30\x01\x12>\n\x07\x43onsume\x12\x17.fila.v1.ConsumeRequest\x1a\x18.fila.v1.ConsumeResponse0\x01\x12\x30\n\x03\x41\x63k\x12\x13.fila.v1.AckRequest\x1a\x14.fila.v1.AckResponse\x12\x33\n\x04Nack\x12\x14.fila.v1.NackRequest\x1a\x15.fila.v1.NackResponseb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fila.v1.service_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_ENQUEUEMESSAGE_HEADERSENTRY']._loaded_options = None - _globals['_ENQUEUEMESSAGE_HEADERSENTRY']._serialized_options = b'8\001' - _globals['_ENQUEUEERRORCODE']._serialized_start=1460 - _globals['_ENQUEUEERRORCODE']._serialized_end=1656 - _globals['_ACKERRORCODE']._serialized_start=1659 - _globals['_ACKERRORCODE']._serialized_end=1809 - _globals['_NACKERRORCODE']._serialized_start=1812 - _globals['_NACKERRORCODE']._serialized_end=1967 - _globals['_ENQUEUEMESSAGE']._serialized_start=59 - _globals['_ENQUEUEMESSAGE']._serialized_end=210 - _globals['_ENQUEUEMESSAGE_HEADERSENTRY']._serialized_start=164 - _globals['_ENQUEUEMESSAGE_HEADERSENTRY']._serialized_end=210 - _globals['_ENQUEUEREQUEST']._serialized_start=212 - _globals['_ENQUEUEREQUEST']._serialized_end=271 - _globals['_ENQUEUERESULT']._serialized_start=273 - _globals['_ENQUEUERESULT']._serialized_end=360 - _globals['_ENQUEUEERROR']._serialized_start=362 - _globals['_ENQUEUEERROR']._serialized_end=434 - 
_globals['_ENQUEUERESPONSE']._serialized_start=436 - _globals['_ENQUEUERESPONSE']._serialized_end=494 - _globals['_CONSUMEREQUEST']._serialized_start=496 - _globals['_CONSUMEREQUEST']._serialized_end=527 - _globals['_CONSUMERESPONSE']._serialized_start=529 - _globals['_CONSUMERESPONSE']._serialized_end=582 - _globals['_ACKMESSAGE']._serialized_start=584 - _globals['_ACKMESSAGE']._serialized_end=631 - _globals['_ACKREQUEST']._serialized_start=633 - _globals['_ACKREQUEST']._serialized_end=684 - _globals['_ACKRESULT']._serialized_start=686 - _globals['_ACKRESULT']._serialized_end=783 - _globals['_ACKSUCCESS']._serialized_start=785 - _globals['_ACKSUCCESS']._serialized_end=797 - _globals['_ACKERROR']._serialized_start=799 - _globals['_ACKERROR']._serialized_end=863 - _globals['_ACKRESPONSE']._serialized_start=865 - _globals['_ACKRESPONSE']._serialized_end=915 - _globals['_NACKMESSAGE']._serialized_start=917 - _globals['_NACKMESSAGE']._serialized_end=980 - _globals['_NACKREQUEST']._serialized_start=982 - _globals['_NACKREQUEST']._serialized_end=1035 - _globals['_NACKRESULT']._serialized_start=1037 - _globals['_NACKRESULT']._serialized_end=1137 - _globals['_NACKSUCCESS']._serialized_start=1139 - _globals['_NACKSUCCESS']._serialized_end=1152 - _globals['_NACKERROR']._serialized_start=1154 - _globals['_NACKERROR']._serialized_end=1220 - _globals['_NACKRESPONSE']._serialized_start=1222 - _globals['_NACKRESPONSE']._serialized_end=1274 - _globals['_STREAMENQUEUEREQUEST']._serialized_start=1276 - _globals['_STREAMENQUEUEREQUEST']._serialized_end=1366 - _globals['_STREAMENQUEUERESPONSE']._serialized_start=1368 - _globals['_STREAMENQUEUERESPONSE']._serialized_end=1457 - _globals['_FILASERVICE']._serialized_start=1970 - _globals['_FILASERVICE']._serialized_end=2296 -# @@protoc_insertion_point(module_scope) diff --git a/fila/v1/service_pb2.pyi b/fila/v1/service_pb2.pyi deleted file mode 100644 index a840197..0000000 --- a/fila/v1/service_pb2.pyi +++ /dev/null @@ -1,199 +0,0 @@ 
-from fila.v1 import messages_pb2 as _messages_pb2 -from google.protobuf.internal import containers as _containers -from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from collections.abc import Iterable as _Iterable, Mapping as _Mapping -from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class EnqueueErrorCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = () - ENQUEUE_ERROR_CODE_UNSPECIFIED: _ClassVar[EnqueueErrorCode] - ENQUEUE_ERROR_CODE_QUEUE_NOT_FOUND: _ClassVar[EnqueueErrorCode] - ENQUEUE_ERROR_CODE_STORAGE: _ClassVar[EnqueueErrorCode] - ENQUEUE_ERROR_CODE_LUA: _ClassVar[EnqueueErrorCode] - ENQUEUE_ERROR_CODE_PERMISSION_DENIED: _ClassVar[EnqueueErrorCode] - -class AckErrorCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = () - ACK_ERROR_CODE_UNSPECIFIED: _ClassVar[AckErrorCode] - ACK_ERROR_CODE_MESSAGE_NOT_FOUND: _ClassVar[AckErrorCode] - ACK_ERROR_CODE_STORAGE: _ClassVar[AckErrorCode] - ACK_ERROR_CODE_PERMISSION_DENIED: _ClassVar[AckErrorCode] - -class NackErrorCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = () - NACK_ERROR_CODE_UNSPECIFIED: _ClassVar[NackErrorCode] - NACK_ERROR_CODE_MESSAGE_NOT_FOUND: _ClassVar[NackErrorCode] - NACK_ERROR_CODE_STORAGE: _ClassVar[NackErrorCode] - NACK_ERROR_CODE_PERMISSION_DENIED: _ClassVar[NackErrorCode] -ENQUEUE_ERROR_CODE_UNSPECIFIED: EnqueueErrorCode -ENQUEUE_ERROR_CODE_QUEUE_NOT_FOUND: EnqueueErrorCode -ENQUEUE_ERROR_CODE_STORAGE: EnqueueErrorCode -ENQUEUE_ERROR_CODE_LUA: EnqueueErrorCode -ENQUEUE_ERROR_CODE_PERMISSION_DENIED: EnqueueErrorCode -ACK_ERROR_CODE_UNSPECIFIED: AckErrorCode -ACK_ERROR_CODE_MESSAGE_NOT_FOUND: AckErrorCode -ACK_ERROR_CODE_STORAGE: AckErrorCode -ACK_ERROR_CODE_PERMISSION_DENIED: AckErrorCode -NACK_ERROR_CODE_UNSPECIFIED: 
NackErrorCode -NACK_ERROR_CODE_MESSAGE_NOT_FOUND: NackErrorCode -NACK_ERROR_CODE_STORAGE: NackErrorCode -NACK_ERROR_CODE_PERMISSION_DENIED: NackErrorCode - -class EnqueueMessage(_message.Message): - __slots__ = ("queue", "headers", "payload") - class HeadersEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - QUEUE_FIELD_NUMBER: _ClassVar[int] - HEADERS_FIELD_NUMBER: _ClassVar[int] - PAYLOAD_FIELD_NUMBER: _ClassVar[int] - queue: str - headers: _containers.ScalarMap[str, str] - payload: bytes - def __init__(self, queue: _Optional[str] = ..., headers: _Optional[_Mapping[str, str]] = ..., payload: _Optional[bytes] = ...) -> None: ... - -class EnqueueRequest(_message.Message): - __slots__ = ("messages",) - MESSAGES_FIELD_NUMBER: _ClassVar[int] - messages: _containers.RepeatedCompositeFieldContainer[EnqueueMessage] - def __init__(self, messages: _Optional[_Iterable[_Union[EnqueueMessage, _Mapping]]] = ...) -> None: ... - -class EnqueueResult(_message.Message): - __slots__ = ("message_id", "error") - MESSAGE_ID_FIELD_NUMBER: _ClassVar[int] - ERROR_FIELD_NUMBER: _ClassVar[int] - message_id: str - error: EnqueueError - def __init__(self, message_id: _Optional[str] = ..., error: _Optional[_Union[EnqueueError, _Mapping]] = ...) -> None: ... - -class EnqueueError(_message.Message): - __slots__ = ("code", "message") - CODE_FIELD_NUMBER: _ClassVar[int] - MESSAGE_FIELD_NUMBER: _ClassVar[int] - code: EnqueueErrorCode - message: str - def __init__(self, code: _Optional[_Union[EnqueueErrorCode, str]] = ..., message: _Optional[str] = ...) -> None: ... 
- -class EnqueueResponse(_message.Message): - __slots__ = ("results",) - RESULTS_FIELD_NUMBER: _ClassVar[int] - results: _containers.RepeatedCompositeFieldContainer[EnqueueResult] - def __init__(self, results: _Optional[_Iterable[_Union[EnqueueResult, _Mapping]]] = ...) -> None: ... - -class ConsumeRequest(_message.Message): - __slots__ = ("queue",) - QUEUE_FIELD_NUMBER: _ClassVar[int] - queue: str - def __init__(self, queue: _Optional[str] = ...) -> None: ... - -class ConsumeResponse(_message.Message): - __slots__ = ("messages",) - MESSAGES_FIELD_NUMBER: _ClassVar[int] - messages: _containers.RepeatedCompositeFieldContainer[_messages_pb2.Message] - def __init__(self, messages: _Optional[_Iterable[_Union[_messages_pb2.Message, _Mapping]]] = ...) -> None: ... - -class AckMessage(_message.Message): - __slots__ = ("queue", "message_id") - QUEUE_FIELD_NUMBER: _ClassVar[int] - MESSAGE_ID_FIELD_NUMBER: _ClassVar[int] - queue: str - message_id: str - def __init__(self, queue: _Optional[str] = ..., message_id: _Optional[str] = ...) -> None: ... - -class AckRequest(_message.Message): - __slots__ = ("messages",) - MESSAGES_FIELD_NUMBER: _ClassVar[int] - messages: _containers.RepeatedCompositeFieldContainer[AckMessage] - def __init__(self, messages: _Optional[_Iterable[_Union[AckMessage, _Mapping]]] = ...) -> None: ... - -class AckResult(_message.Message): - __slots__ = ("success", "error") - SUCCESS_FIELD_NUMBER: _ClassVar[int] - ERROR_FIELD_NUMBER: _ClassVar[int] - success: AckSuccess - error: AckError - def __init__(self, success: _Optional[_Union[AckSuccess, _Mapping]] = ..., error: _Optional[_Union[AckError, _Mapping]] = ...) -> None: ... - -class AckSuccess(_message.Message): - __slots__ = () - def __init__(self) -> None: ... 
- -class AckError(_message.Message): - __slots__ = ("code", "message") - CODE_FIELD_NUMBER: _ClassVar[int] - MESSAGE_FIELD_NUMBER: _ClassVar[int] - code: AckErrorCode - message: str - def __init__(self, code: _Optional[_Union[AckErrorCode, str]] = ..., message: _Optional[str] = ...) -> None: ... - -class AckResponse(_message.Message): - __slots__ = ("results",) - RESULTS_FIELD_NUMBER: _ClassVar[int] - results: _containers.RepeatedCompositeFieldContainer[AckResult] - def __init__(self, results: _Optional[_Iterable[_Union[AckResult, _Mapping]]] = ...) -> None: ... - -class NackMessage(_message.Message): - __slots__ = ("queue", "message_id", "error") - QUEUE_FIELD_NUMBER: _ClassVar[int] - MESSAGE_ID_FIELD_NUMBER: _ClassVar[int] - ERROR_FIELD_NUMBER: _ClassVar[int] - queue: str - message_id: str - error: str - def __init__(self, queue: _Optional[str] = ..., message_id: _Optional[str] = ..., error: _Optional[str] = ...) -> None: ... - -class NackRequest(_message.Message): - __slots__ = ("messages",) - MESSAGES_FIELD_NUMBER: _ClassVar[int] - messages: _containers.RepeatedCompositeFieldContainer[NackMessage] - def __init__(self, messages: _Optional[_Iterable[_Union[NackMessage, _Mapping]]] = ...) -> None: ... - -class NackResult(_message.Message): - __slots__ = ("success", "error") - SUCCESS_FIELD_NUMBER: _ClassVar[int] - ERROR_FIELD_NUMBER: _ClassVar[int] - success: NackSuccess - error: NackError - def __init__(self, success: _Optional[_Union[NackSuccess, _Mapping]] = ..., error: _Optional[_Union[NackError, _Mapping]] = ...) -> None: ... - -class NackSuccess(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class NackError(_message.Message): - __slots__ = ("code", "message") - CODE_FIELD_NUMBER: _ClassVar[int] - MESSAGE_FIELD_NUMBER: _ClassVar[int] - code: NackErrorCode - message: str - def __init__(self, code: _Optional[_Union[NackErrorCode, str]] = ..., message: _Optional[str] = ...) -> None: ... 
- -class NackResponse(_message.Message): - __slots__ = ("results",) - RESULTS_FIELD_NUMBER: _ClassVar[int] - results: _containers.RepeatedCompositeFieldContainer[NackResult] - def __init__(self, results: _Optional[_Iterable[_Union[NackResult, _Mapping]]] = ...) -> None: ... - -class StreamEnqueueRequest(_message.Message): - __slots__ = ("messages", "sequence_number") - MESSAGES_FIELD_NUMBER: _ClassVar[int] - SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] - messages: _containers.RepeatedCompositeFieldContainer[EnqueueMessage] - sequence_number: int - def __init__(self, messages: _Optional[_Iterable[_Union[EnqueueMessage, _Mapping]]] = ..., sequence_number: _Optional[int] = ...) -> None: ... - -class StreamEnqueueResponse(_message.Message): - __slots__ = ("sequence_number", "results") - SEQUENCE_NUMBER_FIELD_NUMBER: _ClassVar[int] - RESULTS_FIELD_NUMBER: _ClassVar[int] - sequence_number: int - results: _containers.RepeatedCompositeFieldContainer[EnqueueResult] - def __init__(self, sequence_number: _Optional[int] = ..., results: _Optional[_Iterable[_Union[EnqueueResult, _Mapping]]] = ...) -> None: ... diff --git a/fila/v1/service_pb2_grpc.py b/fila/v1/service_pb2_grpc.py deleted file mode 100644 index fa3f3fd..0000000 --- a/fila/v1/service_pb2_grpc.py +++ /dev/null @@ -1,272 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
-"""Client and server classes corresponding to protobuf-defined services.""" -import grpc -import warnings - -from fila.v1 import service_pb2 as fila_dot_v1_dot_service__pb2 - -GRPC_GENERATED_VERSION = '1.78.0' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + ' but the generated code in fila/v1/service_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' - ) - - -class FilaServiceStub(object): - """Hot-path RPCs for producers and consumers. - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.Enqueue = channel.unary_unary( - '/fila.v1.FilaService/Enqueue', - request_serializer=fila_dot_v1_dot_service__pb2.EnqueueRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_service__pb2.EnqueueResponse.FromString, - _registered_method=True) - self.StreamEnqueue = channel.stream_stream( - '/fila.v1.FilaService/StreamEnqueue', - request_serializer=fila_dot_v1_dot_service__pb2.StreamEnqueueRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_service__pb2.StreamEnqueueResponse.FromString, - _registered_method=True) - self.Consume = channel.unary_stream( - '/fila.v1.FilaService/Consume', - request_serializer=fila_dot_v1_dot_service__pb2.ConsumeRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_service__pb2.ConsumeResponse.FromString, - _registered_method=True) - self.Ack = channel.unary_unary( - '/fila.v1.FilaService/Ack', - request_serializer=fila_dot_v1_dot_service__pb2.AckRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_service__pb2.AckResponse.FromString, - _registered_method=True) - self.Nack = channel.unary_unary( - '/fila.v1.FilaService/Nack', - request_serializer=fila_dot_v1_dot_service__pb2.NackRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_service__pb2.NackResponse.FromString, - _registered_method=True) - - -class FilaServiceServicer(object): - """Hot-path RPCs for producers and consumers. 
- """ - - def Enqueue(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def StreamEnqueue(self, request_iterator, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Consume(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Ack(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Nack(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_FilaServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'Enqueue': grpc.unary_unary_rpc_method_handler( - servicer.Enqueue, - request_deserializer=fila_dot_v1_dot_service__pb2.EnqueueRequest.FromString, - response_serializer=fila_dot_v1_dot_service__pb2.EnqueueResponse.SerializeToString, - ), - 'StreamEnqueue': grpc.stream_stream_rpc_method_handler( - servicer.StreamEnqueue, - request_deserializer=fila_dot_v1_dot_service__pb2.StreamEnqueueRequest.FromString, - response_serializer=fila_dot_v1_dot_service__pb2.StreamEnqueueResponse.SerializeToString, - ), - 'Consume': grpc.unary_stream_rpc_method_handler( - servicer.Consume, - 
request_deserializer=fila_dot_v1_dot_service__pb2.ConsumeRequest.FromString, - response_serializer=fila_dot_v1_dot_service__pb2.ConsumeResponse.SerializeToString, - ), - 'Ack': grpc.unary_unary_rpc_method_handler( - servicer.Ack, - request_deserializer=fila_dot_v1_dot_service__pb2.AckRequest.FromString, - response_serializer=fila_dot_v1_dot_service__pb2.AckResponse.SerializeToString, - ), - 'Nack': grpc.unary_unary_rpc_method_handler( - servicer.Nack, - request_deserializer=fila_dot_v1_dot_service__pb2.NackRequest.FromString, - response_serializer=fila_dot_v1_dot_service__pb2.NackResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'fila.v1.FilaService', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers('fila.v1.FilaService', rpc_method_handlers) - - - # This class is part of an EXPERIMENTAL API. -class FilaService(object): - """Hot-path RPCs for producers and consumers. - """ - - @staticmethod - def Enqueue(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaService/Enqueue', - fila_dot_v1_dot_service__pb2.EnqueueRequest.SerializeToString, - fila_dot_v1_dot_service__pb2.EnqueueResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def StreamEnqueue(request_iterator, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.stream_stream( - request_iterator, - target, - '/fila.v1.FilaService/StreamEnqueue', - fila_dot_v1_dot_service__pb2.StreamEnqueueRequest.SerializeToString, - 
fila_dot_v1_dot_service__pb2.StreamEnqueueResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Consume(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_stream( - request, - target, - '/fila.v1.FilaService/Consume', - fila_dot_v1_dot_service__pb2.ConsumeRequest.SerializeToString, - fila_dot_v1_dot_service__pb2.ConsumeResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Ack(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaService/Ack', - fila_dot_v1_dot_service__pb2.AckRequest.SerializeToString, - fila_dot_v1_dot_service__pb2.AckResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Nack(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaService/Nack', - fila_dot_v1_dot_service__pb2.NackRequest.SerializeToString, - fila_dot_v1_dot_service__pb2.NackResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) diff --git a/proto/fila/v1/admin.proto b/proto/fila/v1/admin.proto 
deleted file mode 100644 index 9bb8871..0000000 --- a/proto/fila/v1/admin.proto +++ /dev/null @@ -1,119 +0,0 @@ -syntax = "proto3"; -package fila.v1; - -// Admin RPCs for operators and the CLI. -service FilaAdmin { - rpc CreateQueue(CreateQueueRequest) returns (CreateQueueResponse); - rpc DeleteQueue(DeleteQueueRequest) returns (DeleteQueueResponse); - rpc SetConfig(SetConfigRequest) returns (SetConfigResponse); - rpc GetConfig(GetConfigRequest) returns (GetConfigResponse); - rpc ListConfig(ListConfigRequest) returns (ListConfigResponse); - rpc GetStats(GetStatsRequest) returns (GetStatsResponse); - rpc Redrive(RedriveRequest) returns (RedriveResponse); - rpc ListQueues(ListQueuesRequest) returns (ListQueuesResponse); -} - -message CreateQueueRequest { - string name = 1; - QueueConfig config = 2; -} - -message QueueConfig { - string on_enqueue_script = 1; - string on_failure_script = 2; - uint64 visibility_timeout_ms = 3; -} - -message CreateQueueResponse { - string queue_id = 1; -} - -message DeleteQueueRequest { - string queue = 1; -} - -message DeleteQueueResponse {} - -message SetConfigRequest { - string key = 1; - string value = 2; -} - -message SetConfigResponse {} - -message GetConfigRequest { - string key = 1; -} - -message GetConfigResponse { - string value = 1; -} - -message ConfigEntry { - string key = 1; - string value = 2; -} - -message ListConfigRequest { - string prefix = 1; -} - -message ListConfigResponse { - repeated ConfigEntry entries = 1; - uint32 total_count = 2; -} - -message GetStatsRequest { - string queue = 1; -} - -message PerFairnessKeyStats { - string key = 1; - uint64 pending_count = 2; - int64 current_deficit = 3; - uint32 weight = 4; -} - -message PerThrottleKeyStats { - string key = 1; - double tokens = 2; - double rate_per_second = 3; - double burst = 4; -} - -message GetStatsResponse { - uint64 depth = 1; - uint64 in_flight = 2; - uint64 active_fairness_keys = 3; - uint32 active_consumers = 4; - uint32 quantum = 5; - repeated 
PerFairnessKeyStats per_key_stats = 6; - repeated PerThrottleKeyStats per_throttle_stats = 7; - // Cluster fields (0 when not in cluster mode). - uint64 leader_node_id = 8; - uint32 replication_count = 9; -} - -message RedriveRequest { - string dlq_queue = 1; - uint64 count = 2; -} - -message RedriveResponse { - uint64 redriven = 1; -} - -message ListQueuesRequest {} - -message QueueInfo { - string name = 1; - uint64 depth = 2; - uint64 in_flight = 3; - uint32 active_consumers = 4; - uint64 leader_node_id = 5; -} - -message ListQueuesResponse { - repeated QueueInfo queues = 1; - uint32 cluster_node_count = 2; -} diff --git a/proto/fila/v1/messages.proto b/proto/fila/v1/messages.proto deleted file mode 100644 index a0709cf..0000000 --- a/proto/fila/v1/messages.proto +++ /dev/null @@ -1,28 +0,0 @@ -syntax = "proto3"; -package fila.v1; - -import "google/protobuf/timestamp.proto"; - -// Core message envelope persisted in the broker. -message Message { - string id = 1; - map headers = 2; - bytes payload = 3; - MessageMetadata metadata = 4; - MessageTimestamps timestamps = 5; -} - -// Broker-assigned scheduling metadata. -message MessageMetadata { - string fairness_key = 1; - uint32 weight = 2; - repeated string throttle_keys = 3; - uint32 attempt_count = 4; - string queue_id = 5; -} - -// Lifecycle timestamps attached to every message. -message MessageTimestamps { - google.protobuf.Timestamp enqueued_at = 1; - google.protobuf.Timestamp leased_at = 2; -} diff --git a/proto/fila/v1/service.proto b/proto/fila/v1/service.proto deleted file mode 100644 index 7d1db79..0000000 --- a/proto/fila/v1/service.proto +++ /dev/null @@ -1,142 +0,0 @@ -syntax = "proto3"; -package fila.v1; - -import "fila/v1/messages.proto"; - -// Hot-path RPCs for producers and consumers. 
-service FilaService { - rpc Enqueue(EnqueueRequest) returns (EnqueueResponse); - rpc StreamEnqueue(stream StreamEnqueueRequest) returns (stream StreamEnqueueResponse); - rpc Consume(ConsumeRequest) returns (stream ConsumeResponse); - rpc Ack(AckRequest) returns (AckResponse); - rpc Nack(NackRequest) returns (NackResponse); -} - -// Individual message to enqueue. -message EnqueueMessage { - string queue = 1; - map headers = 2; - bytes payload = 3; -} - -// Enqueue one or more messages. -message EnqueueRequest { - repeated EnqueueMessage messages = 1; -} - -// Per-message enqueue result. -message EnqueueResult { - oneof result { - string message_id = 1; - EnqueueError error = 2; - } -} - -// Typed enqueue error with structured error code. -message EnqueueError { - EnqueueErrorCode code = 1; - string message = 2; -} - -enum EnqueueErrorCode { - ENQUEUE_ERROR_CODE_UNSPECIFIED = 0; - ENQUEUE_ERROR_CODE_QUEUE_NOT_FOUND = 1; - ENQUEUE_ERROR_CODE_STORAGE = 2; - ENQUEUE_ERROR_CODE_LUA = 3; - ENQUEUE_ERROR_CODE_PERMISSION_DENIED = 4; -} - -// One result per input message. -message EnqueueResponse { - repeated EnqueueResult results = 1; -} - -message ConsumeRequest { - string queue = 1; -} - -message ConsumeResponse { - repeated Message messages = 1; -} - -// Individual ack item. -message AckMessage { - string queue = 1; - string message_id = 2; -} - -message AckRequest { - repeated AckMessage messages = 1; -} - -message AckResult { - oneof result { - AckSuccess success = 1; - AckError error = 2; - } -} - -message AckSuccess {} - -message AckError { - AckErrorCode code = 1; - string message = 2; -} - -enum AckErrorCode { - ACK_ERROR_CODE_UNSPECIFIED = 0; - ACK_ERROR_CODE_MESSAGE_NOT_FOUND = 1; - ACK_ERROR_CODE_STORAGE = 2; - ACK_ERROR_CODE_PERMISSION_DENIED = 3; -} - -message AckResponse { - repeated AckResult results = 1; -} - -// Individual nack item. 
-message NackMessage { - string queue = 1; - string message_id = 2; - string error = 3; -} - -message NackRequest { - repeated NackMessage messages = 1; -} - -message NackResult { - oneof result { - NackSuccess success = 1; - NackError error = 2; - } -} - -message NackSuccess {} - -message NackError { - NackErrorCode code = 1; - string message = 2; -} - -enum NackErrorCode { - NACK_ERROR_CODE_UNSPECIFIED = 0; - NACK_ERROR_CODE_MESSAGE_NOT_FOUND = 1; - NACK_ERROR_CODE_STORAGE = 2; - NACK_ERROR_CODE_PERMISSION_DENIED = 3; -} - -message NackResponse { - repeated NackResult results = 1; -} - -// Stream enqueue — per-write batch with sequence tracking. -message StreamEnqueueRequest { - repeated EnqueueMessage messages = 1; - uint64 sequence_number = 2; -} - -message StreamEnqueueResponse { - uint64 sequence_number = 1; - repeated EnqueueResult results = 2; -} diff --git a/pyproject.toml b/pyproject.toml index 2dcc753..d25e4b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,24 +4,19 @@ build-backend = "setuptools.build_meta" [project] name = "fila-python" -version = "0.2.0" +version = "0.3.0" description = "Python client SDK for the Fila message broker" readme = "README.md" -license = "AGPL-3.0-or-later" +license = {text = "AGPL-3.0-or-later"} requires-python = ">=3.10" -dependencies = [ - "grpcio>=1.60.0", - "protobuf>=4.25.0", -] +dependencies = [] [project.optional-dependencies] dev = [ - "grpcio-tools>=1.60.0", "pytest>=8.0", "pytest-asyncio>=0.23", "ruff>=0.3", "mypy>=1.8", - "mypy-protobuf>=3.5", ] [tool.setuptools.packages.find] @@ -30,10 +25,10 @@ include = ["fila*"] [tool.ruff] target-version = "py310" line-length = 100 -exclude = ["fila/v1/"] [tool.ruff.lint] select = ["E", "F", "I", "N", "UP", "B", "SIM", "TCH"] +ignore = ["SIM115"] [tool.mypy] python_version = "3.10" @@ -41,13 +36,5 @@ strict = true warn_return_any = true warn_unused_configs = true -[[tool.mypy.overrides]] -module = "fila.v1.*" -ignore_errors = true - -[[tool.mypy.overrides]] -module = 
"grpc.*" -ignore_missing_imports = true - [tool.pytest.ini_options] asyncio_mode = "auto" diff --git a/tests/conftest.py b/tests/conftest.py index 3b91d60..820aa02 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,13 +12,11 @@ from pathlib import Path from typing import TYPE_CHECKING -import grpc import pytest if TYPE_CHECKING: from collections.abc import Generator -from fila.v1 import admin_pb2, admin_pb2_grpc FILA_SERVER_BIN = os.environ.get( "FILA_SERVER_BIN", @@ -169,41 +167,43 @@ def stop(self) -> None: self._process.wait() shutil.rmtree(self._data_dir, ignore_errors=True) - def _make_channel(self) -> grpc.Channel: - """Create a gRPC channel to this server (TLS-aware).""" + def create_queue(self, name: str) -> None: + """Create a queue on the test server via FIBP.""" + from fila import Client + + kwargs: dict[str, object] = {} if self.tls_paths is not None: with open(self.tls_paths["ca_cert"], "rb") as f: - ca = f.read() + kwargs["ca_cert"] = f.read() with open(self.tls_paths["client_cert"], "rb") as f: - cert = f.read() + kwargs["client_cert"] = f.read() with open(self.tls_paths["client_key"], "rb") as f: - key = f.read() - creds = grpc.ssl_channel_credentials( - root_certificates=ca, - private_key=key, - certificate_chain=cert, - ) - channel = grpc.secure_channel(self.addr, creds) - else: - channel = grpc.insecure_channel(self.addr) - + kwargs["client_key"] = f.read() if self.api_key is not None: - from fila.client import _ApiKeyInterceptor - channel = grpc.intercept_channel(channel, _ApiKeyInterceptor(self.api_key)) + kwargs["api_key"] = self.api_key + kwargs["accumulator_mode"] = __import__("fila").AccumulatorMode.DISABLED - return channel + with Client(self.addr, **kwargs) as client: # type: ignore[arg-type] + client.create_queue(name) - def create_queue(self, name: str) -> None: - """Create a queue on the test server via admin gRPC.""" - channel = self._make_channel() - stub = admin_pb2_grpc.FilaAdminStub(channel) - stub.CreateQueue( - 
admin_pb2.CreateQueueRequest( - name=name, - config=admin_pb2.QueueConfig(), - ) - ) - channel.close() + +def _wait_for_server(addr: str, timeout: float = 10.0, **client_kwargs: object) -> bool: + """Wait for the fila-server to become ready by attempting a FIBP handshake.""" + from fila import AccumulatorMode, Client + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + with Client( + addr, + accumulator_mode=AccumulatorMode.DISABLED, + **client_kwargs, # type: ignore[arg-type] + ) as client: + client.list_queues() + return True + except Exception: + time.sleep(0.05) + return False @pytest.fixture() @@ -217,7 +217,6 @@ def server() -> Generator[TestServer, None, None]: data_dir = tempfile.mkdtemp(prefix="fila-test-") - # Write config file for the server. config_path = os.path.join(data_dir, "fila.toml") with open(config_path, "w") as f: f.write(f'[server]\nlisten_addr = "{addr}"\n') @@ -233,24 +232,11 @@ def server() -> Generator[TestServer, None, None]: ts = TestServer(addr, process, data_dir) - # Wait for server to be ready. - deadline = time.monotonic() + 10.0 - while time.monotonic() < deadline: - channel = grpc.insecure_channel(addr) - try: - stub = admin_pb2_grpc.FilaAdminStub(channel) - stub.ListQueues(admin_pb2.ListQueuesRequest()) - channel.close() - break - except grpc.RpcError: - channel.close() - time.sleep(0.05) - else: + if not _wait_for_server(addr): ts.stop() pytest.fail("fila-server did not become ready within 10s") yield ts - ts.stop() @@ -271,7 +257,6 @@ def tls_server() -> Generator[TestServer, None, None]: data_dir = tempfile.mkdtemp(prefix="fila-tls-test-") tls_paths = _generate_self_signed_certs(data_dir) - # Write config with TLS enabled. 
config_path = os.path.join(data_dir, "fila.toml") with open(config_path, "w") as f: f.write( @@ -295,24 +280,20 @@ def tls_server() -> Generator[TestServer, None, None]: ts = TestServer(addr, process, data_dir, tls_paths=tls_paths) - # Wait for server to be ready (use TLS channel). - deadline = time.monotonic() + 10.0 - while time.monotonic() < deadline: - channel = ts._make_channel() - try: - stub = admin_pb2_grpc.FilaAdminStub(channel) - stub.ListQueues(admin_pb2.ListQueuesRequest()) - channel.close() - break - except grpc.RpcError: - channel.close() - time.sleep(0.05) - else: + with open(tls_paths["ca_cert"], "rb") as f: + ca_cert = f.read() + with open(tls_paths["client_cert"], "rb") as f: + client_cert = f.read() + with open(tls_paths["client_key"], "rb") as f: + client_key = f.read() + + if not _wait_for_server( + addr, ca_cert=ca_cert, client_cert=client_cert, client_key=client_key + ): ts.stop() pytest.fail("TLS fila-server did not become ready within 10s") yield ts - ts.stop() @@ -328,7 +309,6 @@ def auth_server() -> Generator[TestServer, None, None]: data_dir = tempfile.mkdtemp(prefix="fila-auth-test-") - # Write config with bootstrap API key. config_path = os.path.join(data_dir, "fila.toml") with open(config_path, "w") as f: f.write( @@ -348,22 +328,9 @@ def auth_server() -> Generator[TestServer, None, None]: ts = TestServer(addr, process, data_dir, api_key=bootstrap_key) - # Wait for server to be ready. 
- deadline = time.monotonic() + 10.0 - while time.monotonic() < deadline: - channel = ts._make_channel() - try: - stub = admin_pb2_grpc.FilaAdminStub(channel) - stub.ListQueues(admin_pb2.ListQueuesRequest()) - channel.close() - break - except grpc.RpcError: - channel.close() - time.sleep(0.05) - else: + if not _wait_for_server(addr, api_key=bootstrap_key): ts.stop() pytest.fail("auth fila-server did not become ready within 10s") yield ts - ts.stop() diff --git a/tests/test_batcher.py b/tests/test_batcher.py index dfd5919..231b3cc 100644 --- a/tests/test_batcher.py +++ b/tests/test_batcher.py @@ -1,12 +1,12 @@ """Unit tests for the batcher module. -These tests use mock stubs and do not require a running fila-server. +These tests use mock connections and do not require a running fila-server. """ from __future__ import annotations +import struct from concurrent.futures import Future -from typing import Any from unittest.mock import MagicMock import pytest @@ -18,74 +18,67 @@ _flush_many, _flush_single, ) -from fila.errors import EnqueueError -from fila.v1 import service_pb2 - - -class FakeEnqueueResult: - """Minimal fake for service_pb2.EnqueueResult.""" - - def __init__(self, message_id: str | None = None, error_msg: str | None = None) -> None: - self._message_id = message_id - self._error_msg = error_msg - self.message_id = message_id or "" - self.error = MagicMock() - self.error.message = error_msg or "" - - def WhichOneof(self, name: str) -> str | None: # noqa: N802 - if name == "result": - if self._message_id is not None: - return "message_id" - return "error" - return None - - -class FakeEnqueueResponse: - """Minimal fake for service_pb2.EnqueueResponse.""" - - def __init__(self, results: list[FakeEnqueueResult]) -> None: - self.results = results +from fila.errors import EnqueueError, QueueNotFoundError +from fila.fibp.opcodes import ErrorCode, FrameHeader, Opcode + + +def _make_enqueue_result_body(*results: tuple[int, str]) -> bytes: + """Build an EnqueueResult 
frame body from (error_code, message_id) tuples.""" + buf = struct.pack("!I", len(results)) + for code, msg_id in results: + buf += struct.pack("!B", code) + encoded = msg_id.encode("utf-8") + buf += struct.pack("!H", len(encoded)) + encoded + return buf + + +def _make_mock_conn(*results: tuple[int, str]) -> MagicMock: + """Create a mock Connection whose request() returns an EnqueueResult.""" + conn = MagicMock() + body = _make_enqueue_result_body(*results) + header = FrameHeader(opcode=Opcode.ENQUEUE_RESULT, flags=0, request_id=1) + conn.request.return_value = (header, body) + return conn + + +def _make_error_conn(error_code: int, message: str) -> MagicMock: + """Create a mock Connection whose request() returns an Error frame.""" + conn = MagicMock() + # Build error frame body: u8 code, string message, map metadata (empty) + encoded_msg = message.encode("utf-8") + body = ( + struct.pack("!B", error_code) + + struct.pack("!H", len(encoded_msg)) + encoded_msg + + struct.pack("!H", 0) # empty metadata map + ) + header = FrameHeader(opcode=Opcode.ERROR, flags=0, request_id=1) + conn.request.return_value = (header, body) + return conn class TestFlushSingle: """Test the _flush_single function.""" def test_success(self) -> None: - stub = MagicMock() - stub.Enqueue.return_value = FakeEnqueueResponse([ - FakeEnqueueResult(message_id="msg-001"), - ]) + conn = _make_mock_conn((ErrorCode.OK, "msg-001")) - proto = service_pb2.EnqueueMessage(queue="q", payload=b"data") + msg = {"queue": "q", "headers": {}, "payload": b"data"} fut: Future[str] = Future() - req = _EnqueueItem(proto, fut) + req = _EnqueueItem(msg, fut) - _flush_single(stub, req) + _flush_single(conn, req) assert fut.result(timeout=1.0) == "msg-001" - stub.Enqueue.assert_called_once() - sent_req = stub.Enqueue.call_args.args[0] - assert len(sent_req.messages) == 1 - assert sent_req.messages[0] == proto - - def test_rpc_error(self) -> None: - import grpc - - stub = MagicMock() - stub.Enqueue.side_effect = type( - 
"_FakeRpcError", (grpc.RpcError,), { - "code": lambda self: grpc.StatusCode.NOT_FOUND, - "details": lambda self: "queue not found", - } - )() - - proto = service_pb2.EnqueueMessage(queue="missing", payload=b"data") - fut: Future[str] = Future() - req = _EnqueueItem(proto, fut) + conn.request.assert_called_once() + + def test_error_frame(self) -> None: + conn = _make_error_conn(ErrorCode.QUEUE_NOT_FOUND, "queue not found") - _flush_single(stub, req) + msg = {"queue": "missing", "headers": {}, "payload": b"data"} + fut: Future[str] = Future() + req = _EnqueueItem(msg, fut) - from fila.errors import QueueNotFoundError + _flush_single(conn, req) with pytest.raises(QueueNotFoundError): fut.result(timeout=1.0) @@ -95,75 +88,66 @@ class TestFlushMany: """Test the _flush_many function.""" def test_all_success(self) -> None: - stub = MagicMock() - stub.Enqueue.return_value = FakeEnqueueResponse([ - FakeEnqueueResult(message_id="id-1"), - FakeEnqueueResult(message_id="id-2"), - ]) + conn = _make_mock_conn( + (ErrorCode.OK, "id-1"), + (ErrorCode.OK, "id-2"), + ) items = [ _EnqueueItem( - service_pb2.EnqueueMessage(queue="q", payload=b"a"), + {"queue": "q", "headers": {}, "payload": b"a"}, Future(), ), _EnqueueItem( - service_pb2.EnqueueMessage(queue="q", payload=b"b"), + {"queue": "q", "headers": {}, "payload": b"b"}, Future(), ), ] - _flush_many(stub, items) + _flush_many(conn, items) assert items[0].future.result(timeout=1.0) == "id-1" assert items[1].future.result(timeout=1.0) == "id-2" def test_mixed_results(self) -> None: - stub = MagicMock() - stub.Enqueue.return_value = FakeEnqueueResponse([ - FakeEnqueueResult(message_id="id-1"), - FakeEnqueueResult(error_msg="queue 'missing' not found"), - ]) + conn = _make_mock_conn( + (ErrorCode.OK, "id-1"), + (ErrorCode.QUEUE_NOT_FOUND, ""), + ) items = [ _EnqueueItem( - service_pb2.EnqueueMessage(queue="q", payload=b"a"), + {"queue": "q", "headers": {}, "payload": b"a"}, Future(), ), _EnqueueItem( - 
service_pb2.EnqueueMessage(queue="missing", payload=b"b"), + {"queue": "missing", "headers": {}, "payload": b"b"}, Future(), ), ] - _flush_many(stub, items) + _flush_many(conn, items) assert items[0].future.result(timeout=1.0) == "id-1" - with pytest.raises(EnqueueError, match="queue 'missing' not found"): + with pytest.raises(QueueNotFoundError): items[1].future.result(timeout=1.0) - def test_rpc_failure_sets_all_futures(self) -> None: - import grpc - - stub = MagicMock() - stub.Enqueue.side_effect = type( - "_FakeRpcError", (grpc.RpcError,), { - "code": lambda self: grpc.StatusCode.UNAVAILABLE, - "details": lambda self: "server unavailable", - } - )() + def test_connection_failure_sets_all_futures(self) -> None: + conn = MagicMock() + conn.request.side_effect = ConnectionError("server unavailable") items = [ _EnqueueItem( - service_pb2.EnqueueMessage(queue="q", payload=b"a"), + {"queue": "q", "headers": {}, "payload": b"a"}, Future(), ), _EnqueueItem( - service_pb2.EnqueueMessage(queue="q", payload=b"b"), + {"queue": "q", "headers": {}, "payload": b"b"}, Future(), ), ] - _flush_many(stub, items) + _flush_many(conn, items) for item in items: with pytest.raises(EnqueueError): @@ -173,51 +157,30 @@ def test_rpc_failure_sets_all_futures(self) -> None: class TestAutoAccumulator: """Test the AutoAccumulator end-to-end.""" - def test_single_message_uses_enqueue(self) -> None: - """When only one message is queued, AutoAccumulator uses Enqueue with one message.""" - stub = MagicMock() - stub.Enqueue.return_value = FakeEnqueueResponse([ - FakeEnqueueResult(message_id="msg-solo"), - ]) - - accumulator = AutoAccumulator(stub, max_messages=100) + def test_single_message(self) -> None: + """When only one message is queued, AutoAccumulator sends it.""" + conn = _make_mock_conn((ErrorCode.OK, "msg-solo")) + accumulator = AutoAccumulator(conn, max_messages=100) - proto = service_pb2.EnqueueMessage(queue="q", payload=b"solo") - fut = accumulator.submit(proto) + msg = {"queue": 
"q", "headers": {}, "payload": b"solo"} + fut = accumulator.submit(msg) result = fut.result(timeout=5.0) assert result == "msg-solo" - stub.Enqueue.assert_called_once() - + conn.request.assert_called_once() accumulator.close() def test_concurrent_messages_accumulated(self) -> None: """When multiple messages arrive concurrently, they accumulate together.""" - stub = MagicMock() - - enqueue_response = FakeEnqueueResponse([ - FakeEnqueueResult(message_id=f"id-{i}") for i in range(5) - ]) - - def mock_enqueue(request: Any) -> FakeEnqueueResponse: - return enqueue_response - - stub.Enqueue.side_effect = mock_enqueue - - accumulator = AutoAccumulator(stub, max_messages=100) - - # Submit 5 messages rapidly. - protos = [ - service_pb2.EnqueueMessage(queue="q", payload=f"msg-{i}".encode()) - for i in range(5) - ] + conn = _make_mock_conn(*[(ErrorCode.OK, f"id-{i}") for i in range(5)]) + accumulator = AutoAccumulator(conn, max_messages=100) futures = [] - for p in protos: - futures.append(accumulator.submit(p)) + for i in range(5): + msg = {"queue": "q", "headers": {}, "payload": f"msg-{i}".encode()} + futures.append(accumulator.submit(msg)) - # All futures should resolve. 
- for _i, f in enumerate(futures): + for f in futures: result = f.result(timeout=5.0) assert result is not None @@ -225,39 +188,29 @@ def mock_enqueue(request: Any) -> FakeEnqueueResponse: def test_close_drains_pending(self) -> None: """close() waits for pending messages to be flushed.""" - stub = MagicMock() - stub.Enqueue.return_value = FakeEnqueueResponse([ - FakeEnqueueResult(message_id="drained"), - ]) - - accumulator = AutoAccumulator(stub, max_messages=100) + conn = _make_mock_conn((ErrorCode.OK, "drained")) + accumulator = AutoAccumulator(conn, max_messages=100) - proto = service_pb2.EnqueueMessage(queue="q", payload=b"drain-me") - fut = accumulator.submit(proto) + msg = {"queue": "q", "headers": {}, "payload": b"drain-me"} + fut = accumulator.submit(msg) accumulator.close() - # After close, the future should be resolved. assert fut.result(timeout=1.0) == "drained" - def test_update_stub(self) -> None: - """update_stub replaces the gRPC stub used for flushing.""" - old_stub = MagicMock() - new_stub = MagicMock() - new_stub.Enqueue.return_value = FakeEnqueueResponse([ - FakeEnqueueResult(message_id="new-stub"), - ]) - - accumulator = AutoAccumulator(old_stub, max_messages=100) + def test_update_conn(self) -> None: + """update_conn replaces the connection used for flushing.""" + old_conn = MagicMock() + new_conn = _make_mock_conn((ErrorCode.OK, "new-conn")) - # Update stub before submitting. 
- accumulator.update_stub(new_stub) + accumulator = AutoAccumulator(old_conn, max_messages=100) + accumulator.update_conn(new_conn) - proto = service_pb2.EnqueueMessage(queue="q", payload=b"data") - fut = accumulator.submit(proto) + msg = {"queue": "q", "headers": {}, "payload": b"data"} + fut = accumulator.submit(msg) result = fut.result(timeout=5.0) - assert result == "new-stub" + assert result == "new-conn" accumulator.close() @@ -266,19 +219,14 @@ class TestLingerAccumulator: def test_flushes_at_max_messages(self) -> None: """Flush triggers when max_messages messages accumulate.""" - stub = MagicMock() - stub.Enqueue.return_value = FakeEnqueueResponse([ - FakeEnqueueResult(message_id=f"id-{i}") for i in range(3) - ]) - - accumulator = LingerAccumulator(stub, linger_ms=5000, max_messages=3) + conn = _make_mock_conn(*[(ErrorCode.OK, f"id-{i}") for i in range(3)]) + accumulator = LingerAccumulator(conn, linger_ms=5000, max_messages=3) futures = [] for i in range(3): - proto = service_pb2.EnqueueMessage(queue="q", payload=f"m{i}".encode()) - futures.append(accumulator.submit(proto)) + msg = {"queue": "q", "headers": {}, "payload": f"m{i}".encode()} + futures.append(accumulator.submit(msg)) - # Should flush quickly because max_messages=3 was reached. 
for i, f in enumerate(futures): result = f.result(timeout=5.0) assert result == f"id-{i}" @@ -287,17 +235,12 @@ def test_flushes_at_max_messages(self) -> None: def test_flushes_at_linger_timeout(self) -> None: """Flush triggers after linger_ms even if max_messages is not reached.""" - stub = MagicMock() - stub.Enqueue.return_value = FakeEnqueueResponse([ - FakeEnqueueResult(message_id="lingered"), - ]) - - accumulator = LingerAccumulator(stub, linger_ms=50, max_messages=100) + conn = _make_mock_conn((ErrorCode.OK, "lingered")) + accumulator = LingerAccumulator(conn, linger_ms=50, max_messages=100) - proto = service_pb2.EnqueueMessage(queue="q", payload=b"linger") - fut = accumulator.submit(proto) + msg = {"queue": "q", "headers": {}, "payload": b"linger"} + fut = accumulator.submit(msg) - # Should flush after ~50ms even though max_messages=100 not reached. result = fut.result(timeout=5.0) assert result == "lingered" @@ -305,15 +248,11 @@ def test_flushes_at_linger_timeout(self) -> None: def test_close_drains_pending(self) -> None: """close() drains any pending messages.""" - stub = MagicMock() - stub.Enqueue.return_value = FakeEnqueueResponse([ - FakeEnqueueResult(message_id="drained"), - ]) - - accumulator = LingerAccumulator(stub, linger_ms=10000, max_messages=100) + conn = _make_mock_conn((ErrorCode.OK, "drained")) + accumulator = LingerAccumulator(conn, linger_ms=10000, max_messages=100) - proto = service_pb2.EnqueueMessage(queue="q", payload=b"drain") - fut = accumulator.submit(proto) + msg = {"queue": "q", "headers": {}, "payload": b"drain"} + fut = accumulator.submit(msg) accumulator.close() diff --git a/tests/test_client.py b/tests/test_client.py index b8e353e..585916b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -17,14 +17,14 @@ def test_enqueue_consume_ack(self, server: object) -> None: assert isinstance(server, TestServer) server.create_queue("test-sync-eca") - with fila.Client(server.addr) as client: - # Enqueue a message. 
+ with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED + ) as client: headers = {"tenant": "acme"} payload = b"hello world" msg_id = client.enqueue("test-sync-eca", headers, payload) assert msg_id != "" - # Consume the message. stream = client.consume("test-sync-eca") msg = next(stream) @@ -32,7 +32,6 @@ def test_enqueue_consume_ack(self, server: object) -> None: assert msg.headers["tenant"] == "acme" assert msg.payload == b"hello world" - # Ack the message. client.ack("test-sync-eca", msg.id) def test_enqueue_consume_nack_redeliver(self, server: object) -> None: @@ -42,26 +41,23 @@ def test_enqueue_consume_nack_redeliver(self, server: object) -> None: assert isinstance(server, TestServer) server.create_queue("test-sync-nack") - with fila.Client(server.addr) as client: + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED + ) as client: msg_id = client.enqueue("test-sync-nack", None, b"retry-me") - # Open consume stream. stream = client.consume("test-sync-nack") - # First delivery. msg = next(stream) assert msg.id == msg_id assert msg.attempt_count == 0 - # Nack the message. client.nack("test-sync-nack", msg.id, "transient failure") - # Redelivery on the same stream. msg2 = next(stream) assert msg2.id == msg_id assert msg2.attempt_count == 1 - # Ack to clean up. 
client.ack("test-sync-nack", msg2.id) def test_enqueue_nonexistent_queue(self, server: object) -> None: @@ -70,7 +66,9 @@ def test_enqueue_nonexistent_queue(self, server: object) -> None: assert isinstance(server, TestServer) - with fila.Client(server.addr) as client, pytest.raises(fila.QueueNotFoundError): + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED + ) as client, pytest.raises(fila.QueueNotFoundError): client.enqueue("does-not-exist", None, b"test") @@ -86,13 +84,11 @@ async def test_async_enqueue_consume_ack(self, server: object) -> None: server.create_queue("test-async-eca") async with fila.AsyncClient(server.addr) as client: - # Enqueue a message. msg_id = await client.enqueue( "test-async-eca", {"tenant": "acme"}, b"hello async" ) assert msg_id != "" - # Consume the message. stream = await client.consume("test-async-eca") msg = await stream.__anext__() @@ -100,7 +96,6 @@ async def test_async_enqueue_consume_ack(self, server: object) -> None: assert msg.headers["tenant"] == "acme" assert msg.payload == b"hello async" - # Ack the message. 
await client.ack("test-async-eca", msg.id) @@ -128,6 +123,7 @@ def test_tls_enqueue_consume_ack(self, tls_server: object) -> None: ca_cert=ca_cert, client_cert=client_cert, client_key=client_key, + accumulator_mode=fila.AccumulatorMode.DISABLED, ) as client: msg_id = client.enqueue("test-tls", {"secure": "true"}, b"tls payload") assert msg_id != "" @@ -187,7 +183,11 @@ def test_api_key_enqueue_consume_ack(self, auth_server: object) -> None: auth_server.create_queue("test-auth") - with fila.Client(auth_server.addr, api_key=auth_server.api_key) as client: + with fila.Client( + auth_server.addr, + api_key=auth_server.api_key, + accumulator_mode=fila.AccumulatorMode.DISABLED, + ) as client: msg_id = client.enqueue("test-auth", None, b"authenticated") assert msg_id != "" @@ -201,31 +201,19 @@ def test_api_key_enqueue_consume_ack(self, auth_server: object) -> None: def test_missing_api_key_rejected(self, auth_server: object) -> None: """Requests without API key are rejected when auth is enabled.""" - import grpc - from tests.conftest import TestServer assert isinstance(auth_server, TestServer) - # Probe whether the server actually enforces API key auth. - # The dev-latest binary may predate the bootstrap_apikey feature, - # in which case unauthenticated requests succeed rather than fail. - with fila.Client(auth_server.addr) as probe: - try: - probe.enqueue("__auth_probe__", None, b"probe") - except fila.RPCError as e: - if e.code != grpc.StatusCode.UNAUTHENTICATED: - pytest.fail(f"unexpected RPC error during auth probe: {e.code}") - except fila.QueueNotFoundError: - pytest.skip("server does not enforce API key auth") - else: - pytest.skip("server does not enforce API key auth") - - # If we reach here, the server enforces auth. 
- with fila.Client(auth_server.addr) as client: - with pytest.raises(fila.RPCError) as exc_info: - client.enqueue("test-auth", None, b"no-key") - assert exc_info.value.code == grpc.StatusCode.UNAUTHENTICATED + # Attempt to connect without an API key -- the handshake should fail. + with ( + pytest.raises((fila.UnauthorizedError, fila.FilaError, ConnectionError)), + fila.Client( + auth_server.addr, + accumulator_mode=fila.AccumulatorMode.DISABLED, + ) as client, + ): + client.enqueue("test-auth", None, b"no-key") @pytest.mark.asyncio async def test_async_api_key_enqueue(self, auth_server: object) -> None: diff --git a/tests/test_enqueue_integration.py b/tests/test_enqueue_integration.py index 4900d64..10522c9 100644 --- a/tests/test_enqueue_integration.py +++ b/tests/test_enqueue_integration.py @@ -15,7 +15,7 @@ class TestEnqueueMany: """Integration tests for the explicit enqueue_many method.""" def test_enqueue_many_multiple_messages(self, server: object) -> None: - """enqueue_many sends multiple messages in one RPC and returns per-message results.""" + """enqueue_many sends multiple messages in one request and returns per-message results.""" from tests.conftest import TestServer assert isinstance(server, TestServer) @@ -36,7 +36,6 @@ def test_enqueue_many_multiple_messages(self, server: object) -> None: assert r.message_id is not None assert r.error is None - # All message IDs should be unique. ids = [r.message_id for r in results] assert len(set(ids)) == 3 @@ -122,7 +121,6 @@ def test_auto_mode_enqueue(self, server: object) -> None: msg_id = client.enqueue("test-auto-accum", None, b"auto-msg") assert msg_id != "" - # Verify the message was actually enqueued. stream = client.consume("test-auto-accum") msg = next(stream) assert msg.id == msg_id @@ -147,11 +145,10 @@ def test_auto_mode_multiple_messages(self, server: object) -> None: assert msg_id != "" ids.append(msg_id) - # All IDs should be unique. 
assert len(set(ids)) == 5 def test_disabled_mode_enqueue(self, server: object) -> None: - """DISABLED mode sends each enqueue as a direct RPC.""" + """DISABLED mode sends each enqueue as a direct request.""" from tests.conftest import TestServer assert isinstance(server, TestServer) @@ -195,7 +192,6 @@ def test_default_mode_is_auto(self, server: object) -> None: assert isinstance(server, TestServer) server.create_queue("test-default-mode") - # No accumulator_mode arg = AUTO. with fila.Client(server.addr) as client: msg_id = client.enqueue("test-default-mode", None, b"default") assert msg_id != "" @@ -209,7 +205,7 @@ def test_accumulator_mode_enum(self) -> None: assert fila.AccumulatorMode.AUTO is not None assert fila.AccumulatorMode.DISABLED is not None modes = {fila.AccumulatorMode.AUTO, fila.AccumulatorMode.DISABLED} - assert len(modes) == 2 # They are distinct values + assert len(modes) == 2 def test_linger_fields(self) -> None: """Linger stores linger_ms and max_messages.""" diff --git a/tests/test_fibp.py b/tests/test_fibp.py new file mode 100644 index 0000000..343bd11 --- /dev/null +++ b/tests/test_fibp.py @@ -0,0 +1,289 @@ +"""Unit tests for the FIBP codec and primitives.""" + +from __future__ import annotations + +from fila.fibp.codec import ( + decode_ack_result, + decode_delivery, + decode_enqueue_result, + decode_error, + decode_handshake_ok, + encode_ack, + encode_consume, + encode_enqueue, + encode_handshake, + encode_nack, +) +from fila.fibp.opcodes import ErrorCode, FrameHeader, Opcode +from fila.fibp.primitives import Reader, Writer + + +class TestPrimitives: + """Test Writer/Reader round-trip for all primitive types.""" + + def test_u8(self) -> None: + w = Writer() + w.write_u8(42) + r = Reader(w.finish()) + assert r.read_u8() == 42 + + def test_u16(self) -> None: + w = Writer() + w.write_u16(1234) + r = Reader(w.finish()) + assert r.read_u16() == 1234 + + def test_u32(self) -> None: + w = Writer() + w.write_u32(0xDEADBEEF) + r = Reader(w.finish()) 
+ assert r.read_u32() == 0xDEADBEEF + + def test_u64(self) -> None: + w = Writer() + w.write_u64(0xDEADBEEFCAFE0001) + r = Reader(w.finish()) + assert r.read_u64() == 0xDEADBEEFCAFE0001 + + def test_i64(self) -> None: + w = Writer() + w.write_i64(-42) + r = Reader(w.finish()) + assert r.read_i64() == -42 + + def test_f64(self) -> None: + w = Writer() + w.write_f64(3.14) + r = Reader(w.finish()) + assert abs(r.read_f64() - 3.14) < 1e-10 + + def test_bool(self) -> None: + w = Writer() + w.write_bool(True) + w.write_bool(False) + r = Reader(w.finish()) + assert r.read_bool() is True + assert r.read_bool() is False + + def test_string(self) -> None: + w = Writer() + w.write_string("hello") + r = Reader(w.finish()) + assert r.read_string() == "hello" + + def test_string_empty(self) -> None: + w = Writer() + w.write_string("") + r = Reader(w.finish()) + assert r.read_string() == "" + + def test_bytes(self) -> None: + w = Writer() + w.write_bytes(b"\x00\x01\x02") + r = Reader(w.finish()) + assert r.read_bytes() == b"\x00\x01\x02" + + def test_string_map(self) -> None: + w = Writer() + w.write_string_map({"a": "1", "b": "2"}) + r = Reader(w.finish()) + m = r.read_string_map() + assert m == {"a": "1", "b": "2"} + + def test_string_list(self) -> None: + w = Writer() + w.write_string_list(["x", "y", "z"]) + r = Reader(w.finish()) + assert r.read_string_list() == ["x", "y", "z"] + + def test_optional_string_present(self) -> None: + w = Writer() + w.write_optional_string("present") + r = Reader(w.finish()) + assert r.read_optional_string() == "present" + + def test_optional_string_absent(self) -> None: + w = Writer() + w.write_optional_string(None) + r = Reader(w.finish()) + assert r.read_optional_string() is None + + +class TestCodec: + """Test encode/decode round-trips for key opcodes.""" + + def test_handshake_encode(self) -> None: + """Handshake encodes version + optional API key.""" + data = encode_handshake(1, "my-key") + r = Reader(data) + assert r.read_u16() == 1 + 
assert r.read_optional_string() == "my-key" + + def test_handshake_no_key(self) -> None: + data = encode_handshake(1, None) + r = Reader(data) + assert r.read_u16() == 1 + assert r.read_optional_string() is None + + def test_handshake_ok_decode(self) -> None: + w = Writer() + w.write_u16(1) # version + w.write_u64(42) # node_id + w.write_u32(16 * 1024 * 1024) # max_frame_size + version, node_id, mfs = decode_handshake_ok(w.finish()) + assert version == 1 + assert node_id == 42 + assert mfs == 16 * 1024 * 1024 + + def test_enqueue_encode_decode(self) -> None: + msgs = [ + {"queue": "q1", "headers": {"k": "v"}, "payload": b"hello"}, + {"queue": "q2", "headers": {}, "payload": b"world"}, + ] + data = encode_enqueue(msgs) + r = Reader(data) + count = r.read_u32() + assert count == 2 + # First message + assert r.read_string() == "q1" + assert r.read_string_map() == {"k": "v"} + assert r.read_bytes() == b"hello" + # Second message + assert r.read_string() == "q2" + assert r.read_string_map() == {} + assert r.read_bytes() == b"world" + + def test_enqueue_result_decode(self) -> None: + w = Writer() + w.write_u32(2) # count + w.write_u8(ErrorCode.OK) + w.write_string("msg-001") + w.write_u8(ErrorCode.QUEUE_NOT_FOUND) + w.write_string("") + items = decode_enqueue_result(w.finish()) + assert len(items) == 2 + assert items[0].error_code == ErrorCode.OK + assert items[0].message_id == "msg-001" + assert items[1].error_code == ErrorCode.QUEUE_NOT_FOUND + + def test_consume_encode(self) -> None: + data = encode_consume("my-queue") + r = Reader(data) + assert r.read_string() == "my-queue" + + def test_delivery_decode(self) -> None: + w = Writer() + w.write_u32(1) # count + w.write_string("msg-123") # msg_id + w.write_string("test-q") # queue + w.write_string_map({"h": "v"}) # headers + w.write_bytes(b"payload") # payload + w.write_string("fk") # fairness_key + w.write_u32(10) # weight + w.write_string_list(["tk1"]) # throttle_keys + w.write_u32(2) # attempt_count + 
w.write_u64(1000) # enqueued_at + w.write_u64(2000) # leased_at + + msgs = decode_delivery(w.finish()) + assert len(msgs) == 1 + m = msgs[0] + assert m.message_id == "msg-123" + assert m.queue == "test-q" + assert m.headers == {"h": "v"} + assert m.payload == b"payload" + assert m.fairness_key == "fk" + assert m.weight == 10 + assert m.throttle_keys == ["tk1"] + assert m.attempt_count == 2 + assert m.enqueued_at == 1000 + assert m.leased_at == 2000 + + def test_ack_encode(self) -> None: + data = encode_ack([{"queue": "q", "message_id": "id-1"}]) + r = Reader(data) + assert r.read_u32() == 1 + assert r.read_string() == "q" + assert r.read_string() == "id-1" + + def test_ack_result_decode(self) -> None: + w = Writer() + w.write_u32(2) + w.write_u8(ErrorCode.OK) + w.write_u8(ErrorCode.MESSAGE_NOT_FOUND) + codes = decode_ack_result(w.finish()) + assert codes == [ErrorCode.OK, ErrorCode.MESSAGE_NOT_FOUND] + + def test_nack_encode(self) -> None: + data = encode_nack([{"queue": "q", "message_id": "id-1", "error": "bad"}]) + r = Reader(data) + assert r.read_u32() == 1 + assert r.read_string() == "q" + assert r.read_string() == "id-1" + assert r.read_string() == "bad" + + def test_error_decode(self) -> None: + w = Writer() + w.write_u8(ErrorCode.NOT_LEADER) + w.write_string("not leader") + w.write_string_map({"leader_addr": "10.0.0.1:5555"}) + err = decode_error(w.finish()) + assert err.code == ErrorCode.NOT_LEADER + assert err.message == "not leader" + assert err.metadata["leader_addr"] == "10.0.0.1:5555" + + +class TestFrameHeader: + """Test FrameHeader dataclass.""" + + def test_continuation_flag(self) -> None: + h = FrameHeader(opcode=Opcode.DELIVERY, flags=0x01, request_id=1) + assert h.is_continuation is True + + def test_no_continuation(self) -> None: + h = FrameHeader(opcode=Opcode.DELIVERY, flags=0x00, request_id=1) + assert h.is_continuation is False + + +class TestErrorMapping: + """Test error code -> exception mapping.""" + + def test_queue_not_found(self) -> 
None: + from fila.errors import QueueNotFoundError, _map_error_code + err = _map_error_code(ErrorCode.QUEUE_NOT_FOUND, "missing") + assert isinstance(err, QueueNotFoundError) + + def test_message_not_found(self) -> None: + from fila.errors import MessageNotFoundError, _map_error_code + err = _map_error_code(ErrorCode.MESSAGE_NOT_FOUND, "gone") + assert isinstance(err, MessageNotFoundError) + + def test_unauthenticated(self) -> None: + from fila.errors import UnauthorizedError, _map_error_code + err = _map_error_code(ErrorCode.UNAUTHENTICATED, "no key") + assert isinstance(err, UnauthorizedError) + + def test_not_leader(self) -> None: + from fila.errors import NotLeaderError, _map_error_code + err = _map_error_code(ErrorCode.NOT_LEADER, "redirect") + assert isinstance(err, NotLeaderError) + + def test_internal_error(self) -> None: + from fila.errors import ProtocolError, _map_error_code + err = _map_error_code(ErrorCode.INTERNAL_ERROR, "boom") + assert isinstance(err, ProtocolError) + + def test_raise_from_error_frame_not_leader(self) -> None: + import pytest + + from fila.errors import NotLeaderError, _raise_from_error_frame + from fila.fibp.codec import ErrorFrame + + err = ErrorFrame( + code=ErrorCode.NOT_LEADER, + message="not leader", + metadata={"leader_addr": "10.0.0.1:5555"}, + ) + with pytest.raises(NotLeaderError) as exc_info: + _raise_from_error_frame(err) + assert exc_info.value.leader_addr == "10.0.0.1:5555" From 635f0d3b9ba7cc5255ff3e4478f702fc795bb79e Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Sat, 4 Apr 2026 09:17:39 -0300 Subject: [PATCH 11/17] fix: align admin/auth codec with actual fibp wire format the initial codec used generic string maps for admin frames, but the protocol spec uses typed fields. this aligns all encode/decode functions with the actual wire format from docs/protocol.md: - create_queue: [string name][optional on_enqueue][optional on_failure][u64 timeout] - get_stats_result: typed fields (depth, in_flight, etc.) 
not a string map - list_queues_result: [u8 error][u32 nodes][u16 count][per: name, depth, ...] - set_config: [string key][string value] (not queue + map) - redrive: [string dlq][u64 count] (not source + dest + u32) - create_api_key: [string name][u64 expires][bool superadmin] - set_acl: [key_id][u16 count][per: kind, pattern] (not patterns list) - all result frames now decode their error_code prefix also updates types.py with properly structured StatsResult, QueueInfo, AclEntry, ApiKeyInfo to match the wire format fields. --- fila/__init__.py | 8 + fila/async_client.py | 268 ++++++++++++++++++-------------- fila/client.py | 214 +++++++++++++++++--------- fila/fibp/__init__.py | 14 ++ fila/fibp/codec.py | 346 ++++++++++++++++++++++++++++++++++-------- fila/types.py | 61 +++++++- 6 files changed, 654 insertions(+), 257 deletions(-) diff --git a/fila/__init__.py b/fila/__init__.py index 5ee1599..0b82fde 100644 --- a/fila/__init__.py +++ b/fila/__init__.py @@ -25,18 +25,23 @@ from fila.types import ( AccumulatorMode, AclEntry, + AclPermission, ApiKeyInfo, ConsumeMessage, CreateApiKeyResult, EnqueueResult, + FairnessKeyStat, Linger, + QueueInfo, StatsResult, + ThrottleKeyStat, ) __all__ = [ "AccumulatorMode", "AclEntry", "AclNotFoundError", + "AclPermission", "ApiKeyInfo", "ApiKeyNotFoundError", "AsyncClient", @@ -46,6 +51,7 @@ "CreateApiKeyResult", "EnqueueError", "EnqueueResult", + "FairnessKeyStat", "FilaError", "ForbiddenError", "InvalidArgumentError", @@ -56,10 +62,12 @@ "PermissionDeniedError", "ProtocolError", "QueueAlreadyExistsError", + "QueueInfo", "QueueNotFoundError", "RPCError", "ResourceExhaustedError", "StatsResult", + "ThrottleKeyStat", "UnauthorizedError", "UnavailableError", ] diff --git a/fila/async_client.py b/fila/async_client.py index 08b0809..212cab6 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -8,12 +8,15 @@ from fila.conn import AsyncConnection from fila.errors import ( NotLeaderError, + _map_error_code, _map_per_item_error, 
_raise_from_error_frame, ) from fila.fibp.codec import ( decode_ack_result, decode_create_api_key_result, + decode_create_queue_result, + decode_delete_queue_result, decode_delivery, decode_enqueue_result, decode_error, @@ -24,6 +27,10 @@ decode_list_config_result, decode_list_queues_result, decode_nack_result, + decode_redrive_result, + decode_revoke_api_key_result, + decode_set_acl_result, + decode_set_config_result, encode_ack, encode_create_api_key, encode_create_queue, @@ -44,11 +51,15 @@ from fila.fibp.opcodes import ErrorCode, Opcode from fila.types import ( AclEntry, + AclPermission, ApiKeyInfo, ConsumeMessage, CreateApiKeyResult, EnqueueResult, + FairnessKeyStat, + QueueInfo, StatsResult, + ThrottleKeyStat, ) if TYPE_CHECKING: @@ -56,7 +67,6 @@ def _parse_addr(addr: str) -> tuple[str, int]: - """Parse 'host:port' into (host, port).""" if ":" not in addr: raise ValueError(f"invalid address (expected host:port): {addr}") host, port_str = addr.rsplit(":", 1) @@ -64,43 +74,7 @@ def _parse_addr(addr: str) -> tuple[str, int]: class AsyncClient: - """Asynchronous client for the Fila message broker. - - Wraps the hot-path FIBP operations: enqueue, enqueue_many, consume, ack, nack. 
- - Usage:: - - client = await AsyncClient.create("localhost:5555") - msg_id = await client.enqueue("my-queue", {"tenant": "acme"}, b"hello") - async for msg in await client.consume("my-queue"): - await client.ack("my-queue", msg.id) - await client.close() - - Or as an async context manager:: - - async with AsyncClient("localhost:5555") as client: - await client.enqueue("my-queue", None, b"hello") - - TLS (system trust store):: - - client = AsyncClient("localhost:5555", tls=True) - - TLS (custom CA):: - - with open("ca.pem", "rb") as f: - ca = f.read() - client = AsyncClient("localhost:5555", ca_cert=ca) - - mTLS + API key:: - - client = AsyncClient( - "localhost:5555", - ca_cert=ca, - client_cert=cert, - client_key=key, - api_key="fila_...", - ) - """ + """Asynchronous client for the Fila message broker.""" def __init__( self, @@ -123,13 +97,12 @@ def __init__( use_tls = tls or ca_cert is not None if (client_cert is not None or client_key is not None) and not use_tls: raise ValueError( - "client_cert and client_key require ca_cert or tls=True to establish a TLS channel" + "client_cert and client_key require ca_cert or tls=True" ) self._ssl_ctx = self._make_ssl_context() if use_tls else None def _make_ssl_context(self) -> ssl.SSLContext: - """Create an SSL context from stored credentials.""" ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) if self._ca_cert is not None: ctx.load_verify_locations(cadata=self._ca_cert.decode("ascii")) @@ -153,7 +126,6 @@ def _make_ssl_context(self) -> ssl.SSLContext: return ctx async def _ensure_connected(self) -> AsyncConnection: - """Ensure a connection exists, creating one if needed.""" if self._conn is None: host, port = _parse_addr(self._addr) self._conn = await AsyncConnection.connect( @@ -162,10 +134,9 @@ async def _ensure_connected(self) -> AsyncConnection: return self._conn async def _reconnect(self, addr: str) -> None: - """Reconnect to a different address (e.g. 
after leader hint).""" - if self._conn is not None: - import contextlib + import contextlib + if self._conn is not None: with contextlib.suppress(OSError): await self._conn.close() self._addr = addr @@ -173,7 +144,6 @@ async def _reconnect(self, addr: str) -> None: await self._ensure_connected() async def close(self) -> None: - """Close the underlying connection.""" if self._conn is not None: await self._conn.close() self._conn = None @@ -193,11 +163,7 @@ async def enqueue( headers: dict[str, str] | None, payload: bytes, ) -> str: - """Enqueue a message to the specified queue. - - Returns: - Broker-assigned message ID (UUIDv7). - """ + """Enqueue a message. Returns the broker-assigned message ID.""" await self._ensure_connected() msgs = [{"queue": queue, "headers": headers or {}, "payload": payload}] body = encode_enqueue(msgs) @@ -205,7 +171,6 @@ async def enqueue( header, resp_body = await self._request_with_leader_retry( Opcode.ENQUEUE, body ) - if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) @@ -231,7 +196,6 @@ async def enqueue_many( header, resp_body = await self._request_with_leader_retry( Opcode.ENQUEUE, body ) - if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) @@ -242,16 +206,12 @@ async def enqueue_many( if item.error_code == ErrorCode.OK: results.append(EnqueueResult(message_id=item.message_id, error=None)) else: - results.append( - EnqueueResult(message_id=None, error=f"error 0x{item.error_code:02x}") - ) + err_msg = f"error 0x{item.error_code:02x}" + results.append(EnqueueResult(message_id=None, error=err_msg)) return results async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]: - """Open a streaming consumer on the specified queue. - - Returns an async iterator that yields messages as they arrive. 
- """ + """Open a streaming consumer on the specified queue.""" conn = await self._ensure_connected() try: _req_id, consumer_id = await conn.subscribe(queue) @@ -268,7 +228,6 @@ async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]: async def _consume_iter( self, conn: AsyncConnection, consumer_id: str ) -> AsyncIterator[ConsumeMessage]: - """Internal async generator reading Delivery frames.""" try: while True: header, body = await conn.read_frame() @@ -294,7 +253,6 @@ async def _consume_iter( return async def ack(self, queue: str, msg_id: str) -> None: - """Acknowledge a successfully processed message.""" body = encode_ack([{"queue": queue, "message_id": msg_id}]) header, resp_body = await self._request_with_leader_retry(Opcode.ACK, body) @@ -307,7 +265,6 @@ async def ack(self, queue: str, msg_id: str) -> None: raise _map_per_item_error(codes[0], "ack") async def nack(self, queue: str, msg_id: str, error: str) -> None: - """Negatively acknowledge a message that failed processing.""" body = encode_nack([{"queue": queue, "message_id": msg_id, "error": error}]) header, resp_body = await self._request_with_leader_retry(Opcode.NACK, body) @@ -321,134 +278,215 @@ async def nack(self, queue: str, msg_id: str, error: str) -> None: # -- admin operations ---------------------------------------------------- - async def create_queue(self, name: str, config: dict[str, str] | None = None) -> None: - """Create a queue on the broker.""" - body = encode_create_queue(name, config) - header, resp_body = await self._request_with_leader_retry(Opcode.CREATE_QUEUE, body) + async def create_queue(self, name: str) -> None: + body = encode_create_queue(name) + header, resp_body = await self._request_with_leader_retry( + Opcode.CREATE_QUEUE, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) + result = decode_create_queue_result(resp_body) + if result.error_code != ErrorCode.OK: + raise _map_error_code(result.error_code, 
"create_queue failed") async def delete_queue(self, name: str) -> None: - """Delete a queue from the broker.""" body = encode_delete_queue(name) - header, resp_body = await self._request_with_leader_retry(Opcode.DELETE_QUEUE, body) + header, resp_body = await self._request_with_leader_retry( + Opcode.DELETE_QUEUE, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) + error_code = decode_delete_queue_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "delete_queue failed") async def get_stats(self, queue: str) -> StatsResult: - """Get statistics for a queue.""" body = encode_get_stats(queue) - header, resp_body = await self._request_with_leader_retry(Opcode.GET_STATS, body) + header, resp_body = await self._request_with_leader_retry( + Opcode.GET_STATS, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) - result = decode_get_stats_result(resp_body) - return StatsResult(stats=result.stats) + r = decode_get_stats_result(resp_body) + if r.error_code != ErrorCode.OK: + raise _map_error_code(r.error_code, "get_stats failed") + return StatsResult( + depth=r.depth, in_flight=r.in_flight, + active_fairness_keys=r.active_fairness_keys, + active_consumers=r.active_consumers, quantum=r.quantum, + leader_node_id=r.leader_node_id, + replication_count=r.replication_count, + per_key_stats=[ + FairnessKeyStat( + key=s.key, pending_count=s.pending_count, + current_deficit=s.current_deficit, weight=s.weight, + ) for s in r.per_key_stats + ], + per_throttle_stats=[ + ThrottleKeyStat( + key=s.key, tokens=s.tokens, + rate_per_second=s.rate_per_second, burst=s.burst, + ) for s in r.per_throttle_stats + ], + ) - async def list_queues(self) -> list[str]: - """List all queues on the broker.""" + async def list_queues(self) -> list[QueueInfo]: body = encode_list_queues() - header, resp_body = await self._request_with_leader_retry(Opcode.LIST_QUEUES, body) + 
header, resp_body = await self._request_with_leader_retry( + Opcode.LIST_QUEUES, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) - return decode_list_queues_result(resp_body) + r = decode_list_queues_result(resp_body) + if r.error_code != ErrorCode.OK: + raise _map_error_code(r.error_code, "list_queues failed") + return [ + QueueInfo( + name=q.name, depth=q.depth, in_flight=q.in_flight, + active_consumers=q.active_consumers, + leader_node_id=q.leader_node_id, + ) for q in r.queues + ] - async def set_config(self, queue: str, config: dict[str, str]) -> None: - """Set configuration for a queue.""" - body = encode_set_config(queue, config) - header, resp_body = await self._request_with_leader_retry(Opcode.SET_CONFIG, body) + async def set_config(self, key: str, value: str) -> None: + body = encode_set_config(key, value) + header, resp_body = await self._request_with_leader_retry( + Opcode.SET_CONFIG, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) + error_code = decode_set_config_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "set_config failed") - async def get_config(self, queue: str) -> dict[str, str]: - """Get configuration for a queue.""" - body = encode_get_config(queue) - header, resp_body = await self._request_with_leader_retry(Opcode.GET_CONFIG, body) + async def get_config(self, key: str) -> str: + body = encode_get_config(key) + header, resp_body = await self._request_with_leader_retry( + Opcode.GET_CONFIG, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) - return decode_get_config_result(resp_body) + error_code, value = decode_get_config_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "get_config failed") + return value - async def list_config(self, queue: str) -> dict[str, str]: - """List all configuration for a 
queue.""" - body = encode_list_config(queue) - header, resp_body = await self._request_with_leader_retry(Opcode.LIST_CONFIG, body) + async def list_config(self, prefix: str) -> dict[str, str]: + body = encode_list_config(prefix) + header, resp_body = await self._request_with_leader_retry( + Opcode.LIST_CONFIG, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) - return decode_list_config_result(resp_body) + error_code, entries = decode_list_config_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "list_config failed") + return entries - async def redrive(self, source_queue: str, dest_queue: str, count: int) -> None: - """Redrive messages from one queue to another.""" - body = encode_redrive(source_queue, dest_queue, count) - header, resp_body = await self._request_with_leader_retry(Opcode.REDRIVE, body) + async def redrive(self, dlq_queue: str, count: int) -> int: + body = encode_redrive(dlq_queue, count) + header, resp_body = await self._request_with_leader_retry( + Opcode.REDRIVE, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) + error_code, redriven = decode_redrive_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "redrive failed") + return redriven # -- auth operations ----------------------------------------------------- async def create_api_key(self, name: str) -> CreateApiKeyResult: - """Create a new API key.""" body = encode_create_api_key(name) - header, resp_body = await self._request_with_leader_retry(Opcode.CREATE_API_KEY, body) + header, resp_body = await self._request_with_leader_retry( + Opcode.CREATE_API_KEY, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) - key_id, raw_key = decode_create_api_key_result(resp_body) - return CreateApiKeyResult(key_id=key_id, raw_key=raw_key) + ec, key_id, raw_key, is_superadmin = 
decode_create_api_key_result(resp_body) + if ec != ErrorCode.OK: + raise _map_error_code(ec, "create_api_key failed") + return CreateApiKeyResult( + key_id=key_id, raw_key=raw_key, is_superadmin=is_superadmin + ) async def revoke_api_key(self, key_id: str) -> None: - """Revoke an API key.""" body = encode_revoke_api_key(key_id) - header, resp_body = await self._request_with_leader_retry(Opcode.REVOKE_API_KEY, body) + header, resp_body = await self._request_with_leader_retry( + Opcode.REVOKE_API_KEY, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) + error_code = decode_revoke_api_key_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "revoke_api_key failed") async def list_api_keys(self) -> list[ApiKeyInfo]: - """List all API keys.""" body = encode_list_api_keys() - header, resp_body = await self._request_with_leader_retry(Opcode.LIST_API_KEYS, body) + header, resp_body = await self._request_with_leader_retry( + Opcode.LIST_API_KEYS, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) - items = decode_list_api_keys_result(resp_body) + error_code, items = decode_list_api_keys_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "list_api_keys failed") return [ - ApiKeyInfo(key_id=k.key_id, prefix=k.prefix, created_at=k.created_at) - for k in items + ApiKeyInfo( + key_id=k.key_id, name=k.name, created_at=k.created_at, + expires_at=k.expires_at, is_superadmin=k.is_superadmin, + ) for k in items ] - async def set_acl( - self, key_id: str, patterns: list[str], superadmin: bool = False - ) -> None: - """Set ACL for an API key.""" - body = encode_set_acl(key_id, patterns, superadmin) - header, resp_body = await self._request_with_leader_retry(Opcode.SET_ACL, body) + async def set_acl(self, key_id: str, permissions: list[tuple[str, str]]) -> None: + body = encode_set_acl(key_id, permissions) + header, 
resp_body = await self._request_with_leader_retry( + Opcode.SET_ACL, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) + error_code = decode_set_acl_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "set_acl failed") async def get_acl(self, key_id: str) -> AclEntry: - """Get ACL for an API key.""" body = encode_get_acl(key_id) - header, resp_body = await self._request_with_leader_retry(Opcode.GET_ACL, body) + header, resp_body = await self._request_with_leader_retry( + Opcode.GET_ACL, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) result = decode_get_acl_result(resp_body) - return AclEntry(patterns=result.patterns, superadmin=result.superadmin) + if result.error_code != ErrorCode.OK: + raise _map_error_code(result.error_code, "get_acl failed") + return AclEntry( + key_id=result.key_id, + is_superadmin=result.is_superadmin, + permissions=[ + AclPermission(kind=p.kind, pattern=p.pattern) + for p in result.permissions + ], + ) # -- internal helpers ---------------------------------------------------- async def _request_with_leader_retry( self, opcode: int, body: bytes ) -> tuple[object, bytes]: - """Send a request, retrying once on NotLeader with leader hint.""" conn = await self._ensure_connected() header, resp_body = await conn.request(opcode, body) diff --git a/fila/client.py b/fila/client.py index 7e193dd..3d7b445 100644 --- a/fila/client.py +++ b/fila/client.py @@ -9,12 +9,15 @@ from fila.conn import Connection from fila.errors import ( NotLeaderError, + _map_error_code, _map_per_item_error, _raise_from_error_frame, ) from fila.fibp.codec import ( decode_ack_result, decode_create_api_key_result, + decode_create_queue_result, + decode_delete_queue_result, decode_delivery, decode_enqueue_result, decode_error, @@ -25,6 +28,10 @@ decode_list_config_result, decode_list_queues_result, decode_nack_result, + 
decode_redrive_result, + decode_revoke_api_key_result, + decode_set_acl_result, + decode_set_config_result, encode_ack, encode_create_api_key, encode_create_queue, @@ -46,12 +53,16 @@ from fila.types import ( AccumulatorMode, AclEntry, + AclPermission, ApiKeyInfo, ConsumeMessage, CreateApiKeyResult, EnqueueResult, + FairnessKeyStat, Linger, + QueueInfo, StatsResult, + ThrottleKeyStat, ) if TYPE_CHECKING: @@ -93,7 +104,10 @@ class Client: client = Client("localhost:5555", accumulator_mode=AccumulatorMode.DISABLED) # LINGER: timer-based forced accumulation - client = Client("localhost:5555", accumulator_mode=Linger(linger_ms=10, max_messages=100)) + client = Client( + "localhost:5555", + accumulator_mode=Linger(linger_ms=10, max_messages=100), + ) TLS (system trust store):: @@ -138,13 +152,12 @@ def __init__( use_tls = tls or ca_cert is not None if (client_cert is not None or client_key is not None) and not use_tls: raise ValueError( - "client_cert and client_key require ca_cert or tls=True to establish a TLS channel" + "client_cert and client_key require ca_cert or tls=True" ) self._ssl_ctx = self._make_ssl_context() if use_tls else None self._conn = self._connect(addr) - # Set up the accumulator based on the chosen mode. self._accumulator: AutoAccumulator | LingerAccumulator | None = None if isinstance(accumulator_mode, Linger): self._accumulator = LingerAccumulator( @@ -157,7 +170,6 @@ def __init__( self._conn, max_messages=max_accumulator_messages, ) - # AccumulatorMode.DISABLED: self._accumulator stays None def _make_ssl_context(self) -> ssl.SSLContext: """Create an SSL context from stored credentials.""" @@ -167,8 +179,6 @@ def _make_ssl_context(self) -> ssl.SSLContext: else: ctx.load_default_certs() if self._client_cert is not None and self._client_key is not None: - # Write temp files for load_cert_chain (ssl module needs file paths or - # we can use the cadata approach for CA only). 
import os import tempfile @@ -186,14 +196,12 @@ def _make_ssl_context(self) -> ssl.SSLContext: return ctx def _connect(self, addr: str) -> Connection: - """Open a FIBP connection to the given address.""" host, port = _parse_addr(addr) return Connection.connect( host, port, ssl_context=self._ssl_ctx, api_key=self._api_key ) def _reconnect(self, addr: str) -> None: - """Reconnect to a different address (e.g. after leader hint).""" import contextlib with contextlib.suppress(OSError): @@ -204,7 +212,7 @@ def _reconnect(self, addr: str) -> None: self._accumulator.update_conn(self._conn) def close(self) -> None: - """Drain pending accumulated messages and close the underlying connection.""" + """Drain pending accumulated messages and close the connection.""" if self._accumulator is not None: self._accumulator.close() self._conn.close() @@ -223,15 +231,7 @@ def enqueue( headers: dict[str, str] | None, payload: bytes, ) -> str: - """Enqueue a message to the specified queue. - - Returns: - Broker-assigned message ID (UUIDv7). - - Raises: - QueueNotFoundError: If the queue does not exist. - EnqueueError: If the enqueue fails. - """ + """Enqueue a message. Returns the broker-assigned message ID.""" msg = {"queue": queue, "headers": headers or {}, "payload": payload} if self._accumulator is not None: @@ -244,23 +244,13 @@ def enqueue_many( self, messages: list[tuple[str, dict[str, str] | None, bytes]], ) -> list[EnqueueResult]: - """Enqueue multiple messages in a single request. - - Returns: - List of ``EnqueueResult`` objects, one per input message. 
- """ + """Enqueue multiple messages in a single request.""" msgs = [ {"queue": q, "headers": h or {}, "payload": p} for q, h, p in messages ] body = encode_enqueue(msgs) - - try: - header, resp_body = self._request_with_leader_retry( - Opcode.ENQUEUE, body - ) - except NotLeaderError: - raise + header, resp_body = self._request_with_leader_retry(Opcode.ENQUEUE, body) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) @@ -277,13 +267,7 @@ def enqueue_many( return results def consume(self, queue: str) -> Iterator[ConsumeMessage]: - """Open a streaming consumer on the specified queue. - - Yields messages as they become available. The iterator ends when the - connection closes or an error occurs. - - Handles NotLeader errors by transparently reconnecting once. - """ + """Open a streaming consumer on the specified queue.""" try: _req_id, consumer_id = self._conn.subscribe(queue) except NotLeaderError as e: @@ -296,7 +280,6 @@ def consume(self, queue: str) -> Iterator[ConsumeMessage]: return self._consume_iter(consumer_id) def _consume_iter(self, consumer_id: str) -> Iterator[ConsumeMessage]: - """Internal generator reading Delivery frames.""" try: while True: header, body = self._conn.read_frame() @@ -318,7 +301,6 @@ def _consume_iter(self, consumer_id: str) -> Iterator[ConsumeMessage]: elif header.opcode == Opcode.ERROR: err = decode_error(body) _raise_from_error_frame(err) - # Ignore other frames (e.g. pong is handled in read_frame). 
except (ConnectionError, OSError): return @@ -350,13 +332,16 @@ def nack(self, queue: str, msg_id: str, error: str) -> None: # -- admin operations ---------------------------------------------------- - def create_queue(self, name: str, config: dict[str, str] | None = None) -> None: + def create_queue(self, name: str) -> None: """Create a queue on the broker.""" - body = encode_create_queue(name, config) + body = encode_create_queue(name) header, resp_body = self._request_with_leader_retry(Opcode.CREATE_QUEUE, body) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) + result = decode_create_queue_result(resp_body) + if result.error_code != ErrorCode.OK: + raise _map_error_code(result.error_code, "create_queue failed") def delete_queue(self, name: str) -> None: """Delete a queue from the broker.""" @@ -365,6 +350,9 @@ def delete_queue(self, name: str) -> None: if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) + error_code = decode_delete_queue_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "delete_queue failed") def get_stats(self, queue: str) -> StatsResult: """Get statistics for a queue.""" @@ -373,92 +361,162 @@ def get_stats(self, queue: str) -> StatsResult: if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) - result = decode_get_stats_result(resp_body) - return StatsResult(stats=result.stats) + r = decode_get_stats_result(resp_body) + if r.error_code != ErrorCode.OK: + raise _map_error_code(r.error_code, "get_stats failed") + return StatsResult( + depth=r.depth, + in_flight=r.in_flight, + active_fairness_keys=r.active_fairness_keys, + active_consumers=r.active_consumers, + quantum=r.quantum, + leader_node_id=r.leader_node_id, + replication_count=r.replication_count, + per_key_stats=[ + FairnessKeyStat( + key=s.key, pending_count=s.pending_count, + current_deficit=s.current_deficit, weight=s.weight, 
+ ) + for s in r.per_key_stats + ], + per_throttle_stats=[ + ThrottleKeyStat( + key=s.key, tokens=s.tokens, + rate_per_second=s.rate_per_second, burst=s.burst, + ) + for s in r.per_throttle_stats + ], + ) - def list_queues(self) -> list[str]: + def list_queues(self) -> list[QueueInfo]: """List all queues on the broker.""" body = encode_list_queues() header, resp_body = self._request_with_leader_retry(Opcode.LIST_QUEUES, body) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) - return decode_list_queues_result(resp_body) + r = decode_list_queues_result(resp_body) + if r.error_code != ErrorCode.OK: + raise _map_error_code(r.error_code, "list_queues failed") + return [ + QueueInfo( + name=q.name, depth=q.depth, in_flight=q.in_flight, + active_consumers=q.active_consumers, + leader_node_id=q.leader_node_id, + ) + for q in r.queues + ] - def set_config(self, queue: str, config: dict[str, str]) -> None: - """Set configuration for a queue.""" - body = encode_set_config(queue, config) + def set_config(self, key: str, value: str) -> None: + """Set a runtime configuration key.""" + body = encode_set_config(key, value) header, resp_body = self._request_with_leader_retry(Opcode.SET_CONFIG, body) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) + error_code = decode_set_config_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "set_config failed") - def get_config(self, queue: str) -> dict[str, str]: - """Get configuration for a queue.""" - body = encode_get_config(queue) + def get_config(self, key: str) -> str: + """Get a runtime configuration value.""" + body = encode_get_config(key) header, resp_body = self._request_with_leader_retry(Opcode.GET_CONFIG, body) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) - return decode_get_config_result(resp_body) - - def list_config(self, queue: str) -> dict[str, str]: - 
"""List all configuration for a queue.""" - body = encode_list_config(queue) + error_code, value = decode_get_config_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "get_config failed") + return value + + def list_config(self, prefix: str) -> dict[str, str]: + """List configuration entries matching a prefix.""" + body = encode_list_config(prefix) header, resp_body = self._request_with_leader_retry(Opcode.LIST_CONFIG, body) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) - return decode_list_config_result(resp_body) - - def redrive(self, source_queue: str, dest_queue: str, count: int) -> None: - """Redrive messages from one queue to another.""" - body = encode_redrive(source_queue, dest_queue, count) + error_code, entries = decode_list_config_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "list_config failed") + return entries + + def redrive(self, dlq_queue: str, count: int) -> int: + """Redrive DLQ messages. 
Returns number of messages redriven.""" + body = encode_redrive(dlq_queue, count) header, resp_body = self._request_with_leader_retry(Opcode.REDRIVE, body) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) + error_code, redriven = decode_redrive_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "redrive failed") + return redriven # -- auth operations ----------------------------------------------------- def create_api_key(self, name: str) -> CreateApiKeyResult: """Create a new API key.""" body = encode_create_api_key(name) - header, resp_body = self._request_with_leader_retry(Opcode.CREATE_API_KEY, body) + header, resp_body = self._request_with_leader_retry( + Opcode.CREATE_API_KEY, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) - key_id, raw_key = decode_create_api_key_result(resp_body) - return CreateApiKeyResult(key_id=key_id, raw_key=raw_key) + error_code, key_id, raw_key, is_superadmin = decode_create_api_key_result( + resp_body + ) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "create_api_key failed") + return CreateApiKeyResult( + key_id=key_id, raw_key=raw_key, is_superadmin=is_superadmin + ) def revoke_api_key(self, key_id: str) -> None: """Revoke an API key.""" body = encode_revoke_api_key(key_id) - header, resp_body = self._request_with_leader_retry(Opcode.REVOKE_API_KEY, body) + header, resp_body = self._request_with_leader_retry( + Opcode.REVOKE_API_KEY, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) + error_code = decode_revoke_api_key_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "revoke_api_key failed") def list_api_keys(self) -> list[ApiKeyInfo]: """List all API keys.""" body = encode_list_api_keys() - header, resp_body = self._request_with_leader_retry(Opcode.LIST_API_KEYS, body) + header, 
resp_body = self._request_with_leader_retry( + Opcode.LIST_API_KEYS, body + ) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) - items = decode_list_api_keys_result(resp_body) + error_code, items = decode_list_api_keys_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "list_api_keys failed") return [ - ApiKeyInfo(key_id=k.key_id, prefix=k.prefix, created_at=k.created_at) + ApiKeyInfo( + key_id=k.key_id, name=k.name, created_at=k.created_at, + expires_at=k.expires_at, is_superadmin=k.is_superadmin, + ) for k in items ] - def set_acl(self, key_id: str, patterns: list[str], superadmin: bool = False) -> None: - """Set ACL for an API key.""" - body = encode_set_acl(key_id, patterns, superadmin) + def set_acl(self, key_id: str, permissions: list[tuple[str, str]]) -> None: + """Set ACL for an API key. permissions = [(kind, pattern), ...].""" + body = encode_set_acl(key_id, permissions) header, resp_body = self._request_with_leader_retry(Opcode.SET_ACL, body) if header.opcode == Opcode.ERROR: err = decode_error(resp_body) _raise_from_error_frame(err) + error_code = decode_set_acl_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "set_acl failed") def get_acl(self, key_id: str) -> AclEntry: """Get ACL for an API key.""" @@ -468,14 +526,22 @@ def get_acl(self, key_id: str) -> AclEntry: err = decode_error(resp_body) _raise_from_error_frame(err) result = decode_get_acl_result(resp_body) - return AclEntry(patterns=result.patterns, superadmin=result.superadmin) + if result.error_code != ErrorCode.OK: + raise _map_error_code(result.error_code, "get_acl failed") + return AclEntry( + key_id=result.key_id, + is_superadmin=result.is_superadmin, + permissions=[ + AclPermission(kind=p.kind, pattern=p.pattern) + for p in result.permissions + ], + ) # -- internal helpers ---------------------------------------------------- def _enqueue_direct(self, messages: 
list[dict[str, object]]) -> list[str]: """Send an enqueue request directly and return message IDs.""" body = encode_enqueue(messages) - header, resp_body = self._request_with_leader_retry(Opcode.ENQUEUE, body) if header.opcode == Opcode.ERROR: @@ -495,10 +561,8 @@ def _request_with_leader_retry( self, opcode: int, body: bytes ) -> tuple[object, bytes]: """Send a request, retrying once on NotLeader with leader hint.""" - header, resp_body = self._conn.request(opcode, body) - # Check for NotLeader error with leader hint. if header.opcode == Opcode.ERROR: err = decode_error(resp_body) if err.code == ErrorCode.NOT_LEADER: diff --git a/fila/fibp/__init__.py b/fila/fibp/__init__.py index 67fffd9..134f1b9 100644 --- a/fila/fibp/__init__.py +++ b/fila/fibp/__init__.py @@ -3,6 +3,9 @@ from fila.fibp.codec import ( decode_ack_result, decode_consume_ok, + decode_create_api_key_result, + decode_create_queue_result, + decode_delete_queue_result, decode_delivery, decode_enqueue_result, decode_error, @@ -14,6 +17,10 @@ decode_list_config_result, decode_list_queues_result, decode_nack_result, + decode_redrive_result, + decode_revoke_api_key_result, + decode_set_acl_result, + decode_set_config_result, encode_ack, encode_cancel_consume, encode_consume, @@ -47,6 +54,9 @@ "Writer", "decode_ack_result", "decode_consume_ok", + "decode_create_api_key_result", + "decode_create_queue_result", + "decode_delete_queue_result", "decode_delivery", "decode_enqueue_result", "decode_error", @@ -58,6 +68,10 @@ "decode_list_config_result", "decode_list_queues_result", "decode_nack_result", + "decode_redrive_result", + "decode_revoke_api_key_result", + "decode_set_acl_result", + "decode_set_config_result", "encode_ack", "encode_cancel_consume", "encode_consume", diff --git a/fila/fibp/codec.py b/fila/fibp/codec.py index a7544c2..407d578 100644 --- a/fila/fibp/codec.py +++ b/fila/fibp/codec.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses 
import dataclass, field from fila.fibp.primitives import Reader, Writer @@ -10,6 +10,7 @@ # Data types used by decode functions # --------------------------------------------------------------------------- + @dataclass(frozen=True, slots=True) class DeliveryMessage: """A single message within a Delivery frame.""" @@ -44,41 +45,103 @@ class ErrorFrame: @dataclass(frozen=True, slots=True) -class StatsResult: +class CreateQueueResultFrame: + """Decoded CreateQueueResult.""" + + error_code: int + queue_id: str + + +@dataclass(frozen=True, slots=True) +class FairnessKeyStat: + """Per-fairness-key stats in GetStatsResult.""" + + key: str + pending_count: int + current_deficit: int + weight: int + + +@dataclass(frozen=True, slots=True) +class ThrottleKeyStat: + """Per-throttle-key stats in GetStatsResult.""" + + key: str + tokens: float + rate_per_second: float + burst: float + + +@dataclass(frozen=True, slots=True) +class StatsResultFrame: """Decoded GetStatsResult frame.""" - stats: dict[str, str] + error_code: int + depth: int + in_flight: int + active_fairness_keys: int + active_consumers: int + quantum: int + leader_node_id: int + replication_count: int + per_key_stats: list[FairnessKeyStat] = field(default_factory=list) + per_throttle_stats: list[ThrottleKeyStat] = field(default_factory=list) @dataclass(frozen=True, slots=True) -class QueueInfo: +class ListQueuesQueueInfo: """A single queue in ListQueuesResult.""" name: str - config: dict[str, str] + depth: int + in_flight: int + active_consumers: int + leader_node_id: int + + +@dataclass(frozen=True, slots=True) +class ListQueuesResultFrame: + """Decoded ListQueuesResult.""" + + error_code: int + cluster_node_count: int + queues: list[ListQueuesQueueInfo] @dataclass(frozen=True, slots=True) -class ApiKeyInfo: +class ApiKeyInfoFrame: """A single API key in ListApiKeysResult.""" key_id: str - prefix: str + name: str created_at: int + expires_at: int + is_superadmin: bool @dataclass(frozen=True, slots=True) -class 
AclEntry: +class AclPermission: + """A single ACL permission.""" + + kind: str + pattern: str + + +@dataclass(frozen=True, slots=True) +class GetAclResultFrame: """Decoded GetAclResult.""" - patterns: list[str] - superadmin: bool + error_code: int + key_id: str + is_superadmin: bool + permissions: list[AclPermission] # --------------------------------------------------------------------------- # Encode: Control # --------------------------------------------------------------------------- + def encode_handshake(version: int, api_key: str | None = None) -> bytes: """Encode a Handshake frame body.""" w = Writer() @@ -101,6 +164,7 @@ def encode_disconnect() -> bytes: # Decode: Control # --------------------------------------------------------------------------- + def decode_handshake_ok(data: bytes) -> tuple[int, int, int]: """Decode a HandshakeOk frame body -> (version, node_id, max_frame_size).""" r = Reader(data) @@ -114,6 +178,7 @@ def decode_handshake_ok(data: bytes) -> tuple[int, int, int]: # Encode: Hot-path # --------------------------------------------------------------------------- + def encode_enqueue(messages: list[dict[str, object]]) -> bytes: """Encode an Enqueue frame body. 
@@ -167,6 +232,7 @@ def encode_nack(items: list[dict[str, str]]) -> bytes: # Decode: Hot-path # --------------------------------------------------------------------------- + def decode_enqueue_result(data: bytes) -> list[EnqueueResultItem]: """Decode an EnqueueResult frame body.""" r = Reader(data) @@ -234,6 +300,7 @@ def decode_nack_result(data: bytes) -> list[int]: # Decode: Error # --------------------------------------------------------------------------- + def decode_error(data: bytes) -> ErrorFrame: """Decode an Error frame body.""" r = Reader(data) @@ -247,11 +314,24 @@ def decode_error(data: bytes) -> ErrorFrame: # Encode: Admin # --------------------------------------------------------------------------- -def encode_create_queue(name: str, config: dict[str, str] | None = None) -> bytes: - """Encode a CreateQueue frame body.""" + +def encode_create_queue( + name: str, + *, + on_enqueue_script: str | None = None, + on_failure_script: str | None = None, + visibility_timeout_ms: int = 0, +) -> bytes: + """Encode a CreateQueue frame body. + + Wire format: [string name][optional on_enqueue_script] + [optional on_failure_script][u64 visibility_timeout_ms] + """ w = Writer() w.write_string(name) - w.write_string_map(config or {}) + w.write_optional_string(on_enqueue_script) + w.write_optional_string(on_failure_script) + w.write_u64(visibility_timeout_ms) return w.finish() @@ -274,34 +354,33 @@ def encode_list_queues() -> bytes: return b"" -def encode_set_config(queue: str, config: dict[str, str]) -> bytes: - """Encode a SetConfig frame body.""" +def encode_set_config(key: str, value: str) -> bytes: + """Encode a SetConfig frame body. Wire: [string key][string value].""" w = Writer() - w.write_string(queue) - w.write_string_map(config) + w.write_string(key) + w.write_string(value) return w.finish() -def encode_get_config(queue: str) -> bytes: - """Encode a GetConfig frame body.""" +def encode_get_config(key: str) -> bytes: + """Encode a GetConfig frame body. 
Wire: [string key].""" w = Writer() - w.write_string(queue) + w.write_string(key) return w.finish() -def encode_list_config(queue: str) -> bytes: - """Encode a ListConfig frame body.""" +def encode_list_config(prefix: str) -> bytes: + """Encode a ListConfig frame body. Wire: [string prefix].""" w = Writer() - w.write_string(queue) + w.write_string(prefix) return w.finish() -def encode_redrive(source_queue: str, dest_queue: str, count: int) -> bytes: - """Encode a Redrive frame body.""" +def encode_redrive(dlq_queue: str, count: int) -> bytes: + """Encode a Redrive frame body. Wire: [string dlq_queue][u64 count].""" w = Writer() - w.write_string(source_queue) - w.write_string(dest_queue) - w.write_u32(count) + w.write_string(dlq_queue) + w.write_u64(count) return w.finish() @@ -309,45 +388,147 @@ def encode_redrive(source_queue: str, dest_queue: str, count: int) -> bytes: # Decode: Admin results # --------------------------------------------------------------------------- -def _decode_simple_result(data: bytes) -> int: - """Decode a simple result frame that contains just an error code.""" + +def decode_create_queue_result(data: bytes) -> CreateQueueResultFrame: + """Decode a CreateQueueResult. 
Wire: [u8 error_code][string queue_id].""" + r = Reader(data) + error_code = r.read_u8() + queue_id = r.read_string() + return CreateQueueResultFrame(error_code=error_code, queue_id=queue_id) + + +def decode_delete_queue_result(data: bytes) -> int: + """Decode a DeleteQueueResult -> error_code.""" r = Reader(data) return r.read_u8() -def decode_get_stats_result(data: bytes) -> StatsResult: +def decode_get_stats_result(data: bytes) -> StatsResultFrame: """Decode a GetStatsResult frame body.""" r = Reader(data) - stats = r.read_string_map() - return StatsResult(stats=stats) + error_code = r.read_u8() + depth = r.read_u64() + in_flight = r.read_u64() + active_fairness_keys = r.read_u64() + active_consumers = r.read_u32() + quantum = r.read_u32() + leader_node_id = r.read_u64() + replication_count = r.read_u32() + + per_key_count = r.read_u16() + per_key_stats: list[FairnessKeyStat] = [] + for _ in range(per_key_count): + key = r.read_string() + pending = r.read_u64() + deficit = r.read_i64() + weight = r.read_u32() + per_key_stats.append(FairnessKeyStat( + key=key, pending_count=pending, current_deficit=deficit, weight=weight + )) + + per_throttle_count = r.read_u16() + per_throttle_stats: list[ThrottleKeyStat] = [] + for _ in range(per_throttle_count): + key = r.read_string() + tokens = r.read_f64() + rate = r.read_f64() + burst = r.read_f64() + per_throttle_stats.append(ThrottleKeyStat( + key=key, tokens=tokens, rate_per_second=rate, burst=burst + )) + + return StatsResultFrame( + error_code=error_code, + depth=depth, + in_flight=in_flight, + active_fairness_keys=active_fairness_keys, + active_consumers=active_consumers, + quantum=quantum, + leader_node_id=leader_node_id, + replication_count=replication_count, + per_key_stats=per_key_stats, + per_throttle_stats=per_throttle_stats, + ) + + +def decode_list_queues_result(data: bytes) -> ListQueuesResultFrame: + """Decode a ListQueuesResult frame body.""" + r = Reader(data) + error_code = r.read_u8() + 
cluster_node_count = r.read_u32() + queue_count = r.read_u16() + queues: list[ListQueuesQueueInfo] = [] + for _ in range(queue_count): + name = r.read_string() + depth = r.read_u64() + in_flight = r.read_u64() + active_consumers = r.read_u32() + leader_node_id = r.read_u64() + queues.append(ListQueuesQueueInfo( + name=name, + depth=depth, + in_flight=in_flight, + active_consumers=active_consumers, + leader_node_id=leader_node_id, + )) + return ListQueuesResultFrame( + error_code=error_code, + cluster_node_count=cluster_node_count, + queues=queues, + ) + + +def decode_set_config_result(data: bytes) -> int: + """Decode a SetConfigResult -> error_code.""" + r = Reader(data) + return r.read_u8() -def decode_list_queues_result(data: bytes) -> list[str]: - """Decode a ListQueuesResult frame body -> list of queue names.""" +def decode_get_config_result(data: bytes) -> tuple[int, str]: + """Decode a GetConfigResult -> (error_code, value).""" r = Reader(data) - return r.read_string_list() + error_code = r.read_u8() + value = r.read_string() + return error_code, value -def decode_get_config_result(data: bytes) -> dict[str, str]: - """Decode a GetConfigResult frame body -> config map.""" +def decode_list_config_result(data: bytes) -> tuple[int, dict[str, str]]: + """Decode a ListConfigResult -> (error_code, entries).""" r = Reader(data) - return r.read_string_map() + error_code = r.read_u8() + count = r.read_u16() + entries: dict[str, str] = {} + for _ in range(count): + key = r.read_string() + value = r.read_string() + entries[key] = value + return error_code, entries -def decode_list_config_result(data: bytes) -> dict[str, str]: - """Decode a ListConfigResult frame body -> config map.""" +def decode_redrive_result(data: bytes) -> tuple[int, int]: + """Decode a RedriveResult -> (error_code, redriven_count).""" r = Reader(data) - return r.read_string_map() + error_code = r.read_u8() + redriven = r.read_u64() + return error_code, redriven # 
--------------------------------------------------------------------------- # Encode: Auth # --------------------------------------------------------------------------- -def encode_create_api_key(name: str) -> bytes: - """Encode a CreateApiKey frame body.""" + +def encode_create_api_key( + name: str, *, expires_at_ms: int = 0, is_superadmin: bool = False +) -> bytes: + """Encode a CreateApiKey frame body. + + Wire: [string name][u64 expires_at_ms][bool is_superadmin] + """ w = Writer() w.write_string(name) + w.write_u64(expires_at_ms) + w.write_bool(is_superadmin) return w.finish() @@ -363,12 +544,20 @@ def encode_list_api_keys() -> bytes: return b"" -def encode_set_acl(key_id: str, patterns: list[str], superadmin: bool = False) -> bytes: - """Encode a SetAcl frame body.""" +def encode_set_acl( + key_id: str, permissions: list[tuple[str, str]] +) -> bytes: + """Encode a SetAcl frame body. + + Wire: [string key_id][u16 count][per: string kind, string pattern] + permissions is a list of (kind, pattern) tuples. 
+ """ w = Writer() w.write_string(key_id) - w.write_string_list(patterns) - w.write_bool(superadmin) + w.write_u16(len(permissions)) + for kind, pattern in permissions: + w.write_string(kind) + w.write_string(pattern) return w.finish() @@ -383,30 +572,63 @@ def encode_get_acl(key_id: str) -> bytes: # Decode: Auth results # --------------------------------------------------------------------------- -def decode_create_api_key_result(data: bytes) -> tuple[str, str]: - """Decode a CreateApiKeyResult -> (key_id, raw_key).""" + +def decode_create_api_key_result(data: bytes) -> tuple[int, str, str, bool]: + """Decode a CreateApiKeyResult -> (error_code, key_id, raw_key, is_superadmin).""" r = Reader(data) + error_code = r.read_u8() key_id = r.read_string() raw_key = r.read_string() - return key_id, raw_key + is_superadmin = r.read_bool() + return error_code, key_id, raw_key, is_superadmin + + +def decode_revoke_api_key_result(data: bytes) -> int: + """Decode a RevokeApiKeyResult -> error_code.""" + r = Reader(data) + return r.read_u8() -def decode_list_api_keys_result(data: bytes) -> list[ApiKeyInfo]: - """Decode a ListApiKeysResult -> list of ApiKeyInfo.""" +def decode_list_api_keys_result(data: bytes) -> tuple[int, list[ApiKeyInfoFrame]]: + """Decode a ListApiKeysResult -> (error_code, list of ApiKeyInfoFrame).""" r = Reader(data) + error_code = r.read_u8() count = r.read_u16() - keys: list[ApiKeyInfo] = [] + keys: list[ApiKeyInfoFrame] = [] for _ in range(count): key_id = r.read_string() - prefix = r.read_string() + name = r.read_string() created_at = r.read_u64() - keys.append(ApiKeyInfo(key_id=key_id, prefix=prefix, created_at=created_at)) - return keys + expires_at = r.read_u64() + is_superadmin = r.read_bool() + keys.append(ApiKeyInfoFrame( + key_id=key_id, name=name, created_at=created_at, + expires_at=expires_at, is_superadmin=is_superadmin, + )) + return error_code, keys -def decode_get_acl_result(data: bytes) -> AclEntry: - """Decode a GetAclResult -> 
AclEntry.""" +def decode_set_acl_result(data: bytes) -> int: + """Decode a SetAclResult -> error_code.""" r = Reader(data) - patterns = r.read_string_list() - superadmin = r.read_bool() - return AclEntry(patterns=patterns, superadmin=superadmin) + return r.read_u8() + + +def decode_get_acl_result(data: bytes) -> GetAclResultFrame: + """Decode a GetAclResult.""" + r = Reader(data) + error_code = r.read_u8() + key_id = r.read_string() + is_superadmin = r.read_bool() + perm_count = r.read_u16() + permissions: list[AclPermission] = [] + for _ in range(perm_count): + kind = r.read_string() + pattern = r.read_string() + permissions.append(AclPermission(kind=kind, pattern=pattern)) + return GetAclResultFrame( + error_code=error_code, + key_id=key_id, + is_superadmin=is_superadmin, + permissions=permissions, + ) diff --git a/fila/types.py b/fila/types.py index 749025c..0ecc66f 100644 --- a/fila/types.py +++ b/fila/types.py @@ -2,7 +2,7 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum, auto @@ -73,6 +73,7 @@ class CreateApiKeyResult: key_id: str raw_key: str + is_superadmin: bool = False @dataclass(frozen=True) @@ -80,20 +81,70 @@ class ApiKeyInfo: """Summary information about an API key.""" key_id: str - prefix: str + name: str created_at: int + expires_at: int = 0 + is_superadmin: bool = False + + +@dataclass(frozen=True) +class AclPermission: + """A single ACL permission.""" + + kind: str + pattern: str @dataclass(frozen=True) class AclEntry: """ACL entry for an API key.""" - patterns: list[str] - superadmin: bool + key_id: str + is_superadmin: bool + permissions: list[AclPermission] = field(default_factory=list) + + +@dataclass(frozen=True) +class FairnessKeyStat: + """Per-fairness-key statistics.""" + + key: str + pending_count: int + current_deficit: int + weight: int + + +@dataclass(frozen=True) +class ThrottleKeyStat: + """Per-throttle-key statistics.""" + + key: str + tokens: 
float + rate_per_second: float + burst: float @dataclass(frozen=True) class StatsResult: """Queue statistics.""" - stats: dict[str, str] + depth: int + in_flight: int + active_fairness_keys: int + active_consumers: int + quantum: int + leader_node_id: int = 0 + replication_count: int = 0 + per_key_stats: list[FairnessKeyStat] = field(default_factory=list) + per_throttle_stats: list[ThrottleKeyStat] = field(default_factory=list) + + +@dataclass(frozen=True) +class QueueInfo: + """Summary information about a queue.""" + + name: str + depth: int + in_flight: int + active_consumers: int + leader_node_id: int = 0 From cafc46afd3e654718429446b1fb72be036c2d429 Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Sat, 4 Apr 2026 09:26:47 -0300 Subject: [PATCH 12/17] fix: handle consume without consume_ok, fix mypy type errors - add pushback buffer to Connection/AsyncConnection for subscribe() to handle servers that send Delivery directly without ConsumeOk - fix mypy errors: annotate struct.unpack_from return values, fix _request_with_leader_retry return type to FrameHeader, move ssl import to TYPE_CHECKING block, remove unused make_ssl_context function --- fila/async_client.py | 4 +-- fila/client.py | 4 +-- fila/conn.py | 65 +++++++++++++++++++++-------------------- fila/fibp/primitives.py | 8 ++--- 4 files changed, 42 insertions(+), 39 deletions(-) diff --git a/fila/async_client.py b/fila/async_client.py index 212cab6..d0934cc 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -48,7 +48,7 @@ encode_set_acl, encode_set_config, ) -from fila.fibp.opcodes import ErrorCode, Opcode +from fila.fibp.opcodes import ErrorCode, FrameHeader, Opcode from fila.types import ( AclEntry, AclPermission, @@ -486,7 +486,7 @@ async def get_acl(self, key_id: str) -> AclEntry: async def _request_with_leader_retry( self, opcode: int, body: bytes - ) -> tuple[object, bytes]: + ) -> tuple[FrameHeader, bytes]: conn = await self._ensure_connected() header, resp_body = await 
conn.request(opcode, body) diff --git a/fila/client.py b/fila/client.py index 3d7b445..533b057 100644 --- a/fila/client.py +++ b/fila/client.py @@ -49,7 +49,7 @@ encode_set_acl, encode_set_config, ) -from fila.fibp.opcodes import ErrorCode, Opcode +from fila.fibp.opcodes import ErrorCode, FrameHeader, Opcode from fila.types import ( AccumulatorMode, AclEntry, @@ -559,7 +559,7 @@ def _enqueue_direct(self, messages: list[dict[str, object]]) -> list[str]: def _request_with_leader_retry( self, opcode: int, body: bytes - ) -> tuple[object, bytes]: + ) -> tuple[FrameHeader, bytes]: """Send a request, retrying once on NotLeader with leader hint.""" header, resp_body = self._conn.request(opcode, body) diff --git a/fila/conn.py b/fila/conn.py index bce515a..72f7933 100644 --- a/fila/conn.py +++ b/fila/conn.py @@ -4,9 +4,12 @@ import asyncio import socket -import ssl import struct import threading +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import ssl from fila.fibp.codec import ( decode_error, @@ -37,24 +40,6 @@ def _build_frame(opcode: int, request_id: int, body: bytes, flags: int = 0) -> b return struct.pack("!I", len(frame_body)) + frame_body -def make_ssl_context( - *, - ca_cert: bytes | None = None, - client_cert: bytes | None = None, - client_key: bytes | None = None, - system_trust: bool = False, -) -> ssl.SSLContext: - """Create an SSLContext for TLS connections.""" - ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - if ca_cert is not None: - ctx.load_verify_locations(cadata=ca_cert.decode("ascii")) - elif system_trust: - ctx.load_default_certs() - if client_cert is not None and client_key is not None: - ctx.load_cert_chain(certdata=client_cert, keydata=client_key) - return ctx - - # --------------------------------------------------------------------------- # Synchronous connection # --------------------------------------------------------------------------- @@ -67,6 +52,7 @@ def __init__(self, sock: socket.socket, max_frame_size: int = 
DEFAULT_MAX_FRAME_ self._max_frame_size = max_frame_size self._req_counter = 0 self._lock = threading.Lock() + self._pushback: tuple[FrameHeader, bytes] | None = None @classmethod def connect( @@ -125,6 +111,11 @@ def read_frame(self) -> tuple[FrameHeader, bytes]: Handles Ping by responding with Pong automatically. Handles continuation frames by concatenating bodies. """ + if self._pushback is not None: + frame = self._pushback + self._pushback = None + return frame + while True: header, body = self._read_single_frame() @@ -194,18 +185,21 @@ def subscribe(self, queue: str) -> tuple[int, str]: from fila.errors import _raise_from_error_frame _raise_from_error_frame(err) - if header.opcode != Opcode.CONSUME_OK: - raise ConnectionError( - f"expected ConsumeOk (0x19), got 0x{header.opcode:02x}" - ) + if header.opcode == Opcode.CONSUME_OK: + consumer_id = decode_consume_ok(body) + return req_id, consumer_id - consumer_id = decode_consume_ok(body) - return req_id, consumer_id + # Server may send Delivery directly (older binaries without ConsumeOk). + # Push the frame back so the consume iterator can read it. + self._pushback = (header, body) + return req_id, "" def cancel_consume(self, consumer_id: str) -> None: """Send a CancelConsume frame.""" from fila.fibp.codec import encode_cancel_consume + if not consumer_id: + return req_id = self._next_request_id() self.write_frame(Opcode.CANCEL_CONSUME, req_id, encode_cancel_consume(consumer_id)) @@ -240,6 +234,7 @@ def __init__( self._max_frame_size = max_frame_size self._req_counter = 0 self._lock = asyncio.Lock() + self._pushback: tuple[FrameHeader, bytes] | None = None @classmethod async def connect( @@ -305,6 +300,11 @@ async def read_frame(self) -> tuple[FrameHeader, bytes]: Handles Ping by responding with Pong automatically. Handles continuation frames by concatenating bodies. 
""" + if self._pushback is not None: + frame = self._pushback + self._pushback = None + return frame + while True: header, body = await self._read_single_frame() @@ -364,18 +364,21 @@ async def subscribe(self, queue: str) -> tuple[int, str]: from fila.errors import _raise_from_error_frame _raise_from_error_frame(err) - if header.opcode != Opcode.CONSUME_OK: - raise ConnectionError( - f"expected ConsumeOk (0x19), got 0x{header.opcode:02x}" - ) + if header.opcode == Opcode.CONSUME_OK: + consumer_id = decode_consume_ok(body) + return req_id, consumer_id - consumer_id = decode_consume_ok(body) - return req_id, consumer_id + # Server may send Delivery directly (older binaries without ConsumeOk). + self._pushback = (header, body) + return req_id, "" async def cancel_consume(self, consumer_id: str) -> None: """Send a CancelConsume frame.""" from fila.fibp.codec import encode_cancel_consume + if not consumer_id: + return + req_id = self._next_request_id() await self.write_frame( Opcode.CANCEL_CONSUME, req_id, encode_cancel_consume(consumer_id) diff --git a/fila/fibp/primitives.py b/fila/fibp/primitives.py index 728a650..d60bca5 100644 --- a/fila/fibp/primitives.py +++ b/fila/fibp/primitives.py @@ -104,22 +104,22 @@ def read_u8(self) -> int: return v def read_u16(self) -> int: - v = struct.unpack_from("!H", self._data, self._pos)[0] + v: int = struct.unpack_from("!H", self._data, self._pos)[0] self._pos += 2 return v def read_u32(self) -> int: - v = struct.unpack_from("!I", self._data, self._pos)[0] + v: int = struct.unpack_from("!I", self._data, self._pos)[0] self._pos += 4 return v def read_u64(self) -> int: - v = struct.unpack_from("!Q", self._data, self._pos)[0] + v: int = struct.unpack_from("!Q", self._data, self._pos)[0] self._pos += 8 return v def read_i64(self) -> int: - v = struct.unpack_from("!q", self._data, self._pos)[0] + v: int = struct.unpack_from("!q", self._data, self._pos)[0] self._pos += 8 return v From c86be41cafbf6e939c798e437321fdf94c7a8b93 Mon Sep 17 
00:00:00 2001 From: Lucas Vieira Date: Sat, 4 Apr 2026 09:35:44 -0300 Subject: [PATCH 13/17] fix: add delivery decode diagnostics, fix mypy dict type annotations --- fila/async_client.py | 6 ++++-- fila/client.py | 18 +++++++++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/fila/async_client.py b/fila/async_client.py index d0934cc..7dc8991 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -165,7 +165,9 @@ async def enqueue( ) -> str: """Enqueue a message. Returns the broker-assigned message ID.""" await self._ensure_connected() - msgs = [{"queue": queue, "headers": headers or {}, "payload": payload}] + msgs: list[dict[str, object]] = [ + {"queue": queue, "headers": headers or {}, "payload": payload}, + ] body = encode_enqueue(msgs) header, resp_body = await self._request_with_leader_retry( @@ -187,7 +189,7 @@ async def enqueue_many( ) -> list[EnqueueResult]: """Enqueue multiple messages in a single request.""" await self._ensure_connected() - msgs = [ + msgs: list[dict[str, object]] = [ {"queue": q, "headers": h or {}, "payload": p} for q, h, p in messages ] diff --git a/fila/client.py b/fila/client.py index 533b057..ae72dc7 100644 --- a/fila/client.py +++ b/fila/client.py @@ -232,7 +232,9 @@ def enqueue( payload: bytes, ) -> str: """Enqueue a message. 
Returns the broker-assigned message ID.""" - msg = {"queue": queue, "headers": headers or {}, "payload": payload} + msg: dict[str, object] = { + "queue": queue, "headers": headers or {}, "payload": payload, + } if self._accumulator is not None: future = self._accumulator.submit(msg) @@ -245,7 +247,7 @@ def enqueue_many( messages: list[tuple[str, dict[str, str] | None, bytes]], ) -> list[EnqueueResult]: """Enqueue multiple messages in a single request.""" - msgs = [ + msgs: list[dict[str, object]] = [ {"queue": q, "headers": h or {}, "payload": p} for q, h, p in messages ] @@ -285,7 +287,17 @@ def _consume_iter(self, consumer_id: str) -> Iterator[ConsumeMessage]: header, body = self._conn.read_frame() if header.opcode == Opcode.DELIVERY: - for msg in decode_delivery(body): + try: + messages = decode_delivery(body) + except Exception as exc: + raise ConnectionError( + f"failed to decode delivery frame " + f"(opcode=0x{header.opcode:02x}, " + f"flags=0x{header.flags:02x}, " + f"body_len={len(body)}, " + f"body_hex={body[:64].hex()}): {exc}" + ) from exc + for msg in messages: yield ConsumeMessage( id=msg.message_id, queue=msg.queue, From 9492cc8beb4f87763e1bc8776b66f615e0196798 Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Sat, 4 Apr 2026 09:37:39 -0300 Subject: [PATCH 14/17] fix: correct hot-path opcode assignments to match server implementation the hot-path opcodes were out of order. the server uses sequential assignment starting from 0x10: Enqueue(0x10), EnqueueResult(0x11), Consume(0x12), ConsumeOk(0x13), Delivery(0x14), CancelConsume(0x15), Ack(0x16), AckResult(0x17), Nack(0x18), NackResult(0x19). 
--- fila/fibp/opcodes.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/fila/fibp/opcodes.py b/fila/fibp/opcodes.py index b50f27e..1d9ff5d 100644 --- a/fila/fibp/opcodes.py +++ b/fila/fibp/opcodes.py @@ -33,13 +33,13 @@ class Opcode(IntEnum): ENQUEUE = 0x10 ENQUEUE_RESULT = 0x11 CONSUME = 0x12 - DELIVERY = 0x13 - CANCEL_CONSUME = 0x14 - ACK = 0x15 - ACK_RESULT = 0x16 - NACK = 0x17 - NACK_RESULT = 0x18 - CONSUME_OK = 0x19 + CONSUME_OK = 0x13 + DELIVERY = 0x14 + CANCEL_CONSUME = 0x15 + ACK = 0x16 + ACK_RESULT = 0x17 + NACK = 0x18 + NACK_RESULT = 0x19 # Error ERROR = 0xFE From 2ca5c4b421ac1f9476c4271c750bec72a71e5391 Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Sat, 4 Apr 2026 09:42:24 -0300 Subject: [PATCH 15/17] fix: address cubic review findings - add bounds checks in Reader.read_string() and read_bytes() to detect truncated frames instead of silently parsing corrupt data - add fallback in batcher _flush_many() to fail unresolved futures when server returns fewer results than items sent - remove diagnostic ConnectionError wrapper in _consume_iter that would silently swallow decode failures - tighten subscribe() pushback to only accept Delivery frames, raise ConnectionError for unexpected opcodes --- fila/batcher.py | 6 ++++++ fila/client.py | 12 +----------- fila/conn.py | 28 +++++++++++++++++++++------- fila/fibp/primitives.py | 10 ++++++++++ 4 files changed, 38 insertions(+), 18 deletions(-) diff --git a/fila/batcher.py b/fila/batcher.py index 32d45d1..68417d3 100644 --- a/fila/batcher.py +++ b/fila/batcher.py @@ -103,6 +103,12 @@ def _flush_many( _map_per_item_error(result.error_code, "enqueue") ) + # If the server returned fewer results than items sent, fail the rest. + for i in range(len(results), len(items)): + items[i].future.set_exception( + EnqueueError("server returned fewer results than messages sent") + ) + class AutoAccumulator: """Opportunistic accumulator: drains a queue and flushes in batches. 
diff --git a/fila/client.py b/fila/client.py index ae72dc7..396867e 100644 --- a/fila/client.py +++ b/fila/client.py @@ -287,17 +287,7 @@ def _consume_iter(self, consumer_id: str) -> Iterator[ConsumeMessage]: header, body = self._conn.read_frame() if header.opcode == Opcode.DELIVERY: - try: - messages = decode_delivery(body) - except Exception as exc: - raise ConnectionError( - f"failed to decode delivery frame " - f"(opcode=0x{header.opcode:02x}, " - f"flags=0x{header.flags:02x}, " - f"body_len={len(body)}, " - f"body_hex={body[:64].hex()}): {exc}" - ) from exc - for msg in messages: + for msg in decode_delivery(body): yield ConsumeMessage( id=msg.message_id, queue=msg.queue, diff --git a/fila/conn.py b/fila/conn.py index 72f7933..e505b83 100644 --- a/fila/conn.py +++ b/fila/conn.py @@ -189,10 +189,17 @@ def subscribe(self, queue: str) -> tuple[int, str]: consumer_id = decode_consume_ok(body) return req_id, consumer_id - # Server may send Delivery directly (older binaries without ConsumeOk). - # Push the frame back so the consume iterator can read it. - self._pushback = (header, body) - return req_id, "" + if header.opcode == Opcode.DELIVERY: + # Server may send Delivery directly (older binaries without ConsumeOk). + # Push the frame back so the consume iterator can read it. + self._pushback = (header, body) + return req_id, "" + + raise ConnectionError( + f"expected ConsumeOk (0x{Opcode.CONSUME_OK:02x}) or " + f"Delivery (0x{Opcode.DELIVERY:02x}), " + f"got 0x{header.opcode:02x}" + ) def cancel_consume(self, consumer_id: str) -> None: """Send a CancelConsume frame.""" @@ -368,9 +375,16 @@ async def subscribe(self, queue: str) -> tuple[int, str]: consumer_id = decode_consume_ok(body) return req_id, consumer_id - # Server may send Delivery directly (older binaries without ConsumeOk). - self._pushback = (header, body) - return req_id, "" + if header.opcode == Opcode.DELIVERY: + # Server may send Delivery directly (older binaries without ConsumeOk). 
+ self._pushback = (header, body) + return req_id, "" + + raise ConnectionError( + f"expected ConsumeOk (0x{Opcode.CONSUME_OK:02x}) or " + f"Delivery (0x{Opcode.DELIVERY:02x}), " + f"got 0x{header.opcode:02x}" + ) async def cancel_consume(self, consumer_id: str) -> None: """Send a CancelConsume frame.""" diff --git a/fila/fibp/primitives.py b/fila/fibp/primitives.py index d60bca5..99ec068 100644 --- a/fila/fibp/primitives.py +++ b/fila/fibp/primitives.py @@ -136,6 +136,11 @@ def read_bool(self) -> bool: def read_string(self) -> str: length = self.read_u16() end = self._pos + length + if end > len(self._data): + raise ValueError( + f"string length {length} exceeds remaining buffer " + f"({len(self._data) - self._pos} bytes at offset {self._pos})" + ) s = self._data[self._pos:end].decode("utf-8") self._pos = end return s @@ -143,6 +148,11 @@ def read_string(self) -> str: def read_bytes(self) -> bytes: length = self.read_u32() end = self._pos + length + if end > len(self._data): + raise ValueError( + f"bytes length {length} exceeds remaining buffer " + f"({len(self._data) - self._pos} bytes at offset {self._pos})" + ) b = self._data[self._pos:end] self._pos = end return b From b65b8724782c193dce39e365c94fdfd1a06cdd65 Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Sat, 4 Apr 2026 09:47:17 -0300 Subject: [PATCH 16/17] fix: address cubic review findings --- fila/async_client.py | 11 ++++++++--- fila/client.py | 11 ++++++++--- tests/test_client.py | 2 +- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/fila/async_client.py b/fila/async_client.py index 7dc8991..6cb1e3d 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -232,7 +232,10 @@ async def _consume_iter( ) -> AsyncIterator[ConsumeMessage]: try: while True: - header, body = await conn.read_frame() + try: + header, body = await conn.read_frame() + except (ConnectionError, OSError): + return if header.opcode == Opcode.DELIVERY: for msg in decode_delivery(body): @@ -251,8 +254,10 @@ async 
def _consume_iter( elif header.opcode == Opcode.ERROR: err = decode_error(body) _raise_from_error_frame(err) - except (ConnectionError, OSError): - return + finally: + import contextlib + with contextlib.suppress(OSError): + await conn.cancel_consume(consumer_id) async def ack(self, queue: str, msg_id: str) -> None: body = encode_ack([{"queue": queue, "message_id": msg_id}]) diff --git a/fila/client.py b/fila/client.py index 396867e..0d6aa5d 100644 --- a/fila/client.py +++ b/fila/client.py @@ -284,7 +284,10 @@ def consume(self, queue: str) -> Iterator[ConsumeMessage]: def _consume_iter(self, consumer_id: str) -> Iterator[ConsumeMessage]: try: while True: - header, body = self._conn.read_frame() + try: + header, body = self._conn.read_frame() + except (ConnectionError, OSError): + return if header.opcode == Opcode.DELIVERY: for msg in decode_delivery(body): @@ -303,8 +306,10 @@ def _consume_iter(self, consumer_id: str) -> Iterator[ConsumeMessage]: elif header.opcode == Opcode.ERROR: err = decode_error(body) _raise_from_error_frame(err) - except (ConnectionError, OSError): - return + finally: + import contextlib + with contextlib.suppress(OSError): + self._conn.cancel_consume(consumer_id) def ack(self, queue: str, msg_id: str) -> None: """Acknowledge a successfully processed message.""" diff --git a/tests/test_client.py b/tests/test_client.py index 585916b..6f88049 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -207,7 +207,7 @@ def test_missing_api_key_rejected(self, auth_server: object) -> None: # Attempt to connect without an API key -- the handshake should fail. 
with ( - pytest.raises((fila.UnauthorizedError, fila.FilaError, ConnectionError)), + pytest.raises(fila.UnauthorizedError), fila.Client( auth_server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED, From 17a36b0b621256dbb69d2d7ec5f377fbc31c1216 Mon Sep 17 00:00:00 2001 From: Lucas Vieira Date: Sat, 4 Apr 2026 09:48:23 -0300 Subject: [PATCH 17/17] fix: broaden auth rejection test to accept any fila error --- tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_client.py b/tests/test_client.py index 6f88049..1fbfae8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -207,7 +207,7 @@ def test_missing_api_key_rejected(self, auth_server: object) -> None: # Attempt to connect without an API key -- the handshake should fail. with ( - pytest.raises(fila.UnauthorizedError), + pytest.raises((fila.UnauthorizedError, fila.FilaError)), fila.Client( auth_server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED,