diff --git a/.github/workflows/dev-publish.yml b/.github/workflows/dev-publish.yml index 78078d0..e23d281 100644 --- a/.github/workflows/dev-publish.yml +++ b/.github/workflows/dev-publish.yml @@ -20,9 +20,10 @@ jobs: run: | COMMIT_COUNT=$(git rev-list --count HEAD) SHORT_SHA=$(git rev-parse --short HEAD) - DEV_VERSION="0.1.dev${COMMIT_COUNT}+g${SHORT_SHA}" + DEV_VERSION="0.1.dev${COMMIT_COUNT}" sed -i "s/^version = .*/version = \"${DEV_VERSION}\"/" pyproject.toml - echo "Publishing version: ${DEV_VERSION}" + sed -i "s/^description = .*/description = \"Python client SDK for the Fila message broker (dev build from ${SHORT_SHA})\"/" pyproject.toml + echo "Publishing version: ${DEV_VERSION} (${SHORT_SHA})" - run: pip install build - run: python -m build - uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/README.md b/README.md index ad70043..3875887 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,8 @@ Python client SDK for the [Fila](https://github.com/faiscadev/fila) message broker. +Communicates with the Fila server over the FIBP (Fila Binary Protocol) on port 5555. + ## Installation ```bash @@ -102,7 +104,7 @@ client = Client( ) ``` -The API key is sent as `authorization: Bearer ` metadata on every RPC. +The API key is sent during the FIBP handshake. ## API @@ -114,6 +116,10 @@ Connect to a Fila broker. Both support context manager protocol. Enqueue a message. Returns the broker-assigned message ID. +### `client.enqueue_many(messages) -> list[EnqueueResult]` + +Enqueue multiple messages in a single request. Returns per-message results. + ### `client.consume(queue) -> Iterator[ConsumeMessage]` Open a streaming consumer. Returns an iterator (sync) or async iterator (async) that yields messages as they become available. @@ -126,12 +132,31 @@ Acknowledge a successfully processed message. The message is permanently removed Negatively acknowledge a failed message. The message is requeued or routed to the dead-letter queue based on the queue's configuration. 
+### Admin Methods + +- `client.create_queue(name, config=None)` -- Create a queue +- `client.delete_queue(name)` -- Delete a queue +- `client.get_stats(queue) -> StatsResult` -- Get queue statistics +- `client.list_queues() -> list[str]` -- List all queues +- `client.set_config(queue, config)` -- Set queue configuration +- `client.get_config(queue) -> dict` -- Get queue configuration +- `client.list_config(queue) -> dict` -- List queue configuration +- `client.redrive(source, dest, count)` -- Redrive messages between queues + +### Auth Methods + +- `client.create_api_key(name) -> CreateApiKeyResult` -- Create an API key +- `client.revoke_api_key(key_id)` -- Revoke an API key +- `client.list_api_keys() -> list[ApiKeyInfo]` -- List API keys +- `client.set_acl(key_id, patterns, superadmin=False)` -- Set ACL +- `client.get_acl(key_id) -> AclEntry` -- Get ACL + ## Error Handling Per-operation exception classes: ```python -from fila import QueueNotFoundError, MessageNotFoundError +from fila import QueueNotFoundError, MessageNotFoundError, UnauthorizedError try: client.enqueue("missing-queue", None, b"test") diff --git a/fila/__init__.py b/fila/__init__.py index a273836..0b82fde 100644 --- a/fila/__init__.py +++ b/fila/__init__.py @@ -1,21 +1,73 @@ -"""Fila — Python client SDK for the Fila message broker.""" +"""Fila -- Python client SDK for the Fila message broker.""" from fila.async_client import AsyncClient from fila.client import Client from fila.errors import ( + AclNotFoundError, + ApiKeyNotFoundError, + ChannelFullError, + EnqueueError, FilaError, + ForbiddenError, + InvalidArgumentError, + LuaError, MessageNotFoundError, + NotLeaderError, + PermissionDeniedError, + ProtocolError, + QueueAlreadyExistsError, QueueNotFoundError, + ResourceExhaustedError, RPCError, + UnauthorizedError, + UnavailableError, +) +from fila.types import ( + AccumulatorMode, + AclEntry, + AclPermission, + ApiKeyInfo, + ConsumeMessage, + CreateApiKeyResult, + EnqueueResult, + 
FairnessKeyStat, + Linger, + QueueInfo, + StatsResult, + ThrottleKeyStat, ) -from fila.types import ConsumeMessage __all__ = [ + "AccumulatorMode", + "AclEntry", + "AclNotFoundError", + "AclPermission", + "ApiKeyInfo", + "ApiKeyNotFoundError", "AsyncClient", + "ChannelFullError", "Client", "ConsumeMessage", + "CreateApiKeyResult", + "EnqueueError", + "EnqueueResult", + "FairnessKeyStat", "FilaError", + "ForbiddenError", + "InvalidArgumentError", + "Linger", + "LuaError", "MessageNotFoundError", + "NotLeaderError", + "PermissionDeniedError", + "ProtocolError", + "QueueAlreadyExistsError", + "QueueInfo", "QueueNotFoundError", "RPCError", + "ResourceExhaustedError", + "StatsResult", + "ThrottleKeyStat", + "UnauthorizedError", + "UnavailableError", ] diff --git a/fila/async_client.py b/fila/async_client.py index 9c99b50..6cb1e3d 100644 --- a/fila/async_client.py +++ b/fila/async_client.py @@ -2,141 +2,79 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any - -import grpc -import grpc.aio +import ssl +from typing import TYPE_CHECKING + +from fila.conn import AsyncConnection +from fila.errors import ( + NotLeaderError, + _map_error_code, + _map_per_item_error, + _raise_from_error_frame, +) +from fila.fibp.codec import ( + decode_ack_result, + decode_create_api_key_result, + decode_create_queue_result, + decode_delete_queue_result, + decode_delivery, + decode_enqueue_result, + decode_error, + decode_get_acl_result, + decode_get_config_result, + decode_get_stats_result, + decode_list_api_keys_result, + decode_list_config_result, + decode_list_queues_result, + decode_nack_result, + decode_redrive_result, + decode_revoke_api_key_result, + decode_set_acl_result, + decode_set_config_result, + encode_ack, + encode_create_api_key, + encode_create_queue, + encode_delete_queue, + encode_enqueue, + encode_get_acl, + encode_get_config, + encode_get_stats, + encode_list_api_keys, + encode_list_config, + encode_list_queues, + encode_nack, + encode_redrive, + 
encode_revoke_api_key, + encode_set_acl, + encode_set_config, +) +from fila.fibp.opcodes import ErrorCode, FrameHeader, Opcode +from fila.types import ( + AclEntry, + AclPermission, + ApiKeyInfo, + ConsumeMessage, + CreateApiKeyResult, + EnqueueResult, + FairnessKeyStat, + QueueInfo, + StatsResult, + ThrottleKeyStat, +) if TYPE_CHECKING: from collections.abc import AsyncIterator -from fila.errors import _map_ack_error, _map_consume_error, _map_enqueue_error, _map_nack_error -from fila.types import ConsumeMessage -from fila.v1 import service_pb2, service_pb2_grpc - - -class _AsyncClientCallDetails( - grpc.aio.ClientCallDetails, # type: ignore[misc] -): - """Concrete ``ClientCallDetails`` for the async interceptor chain. - - ``grpc.aio.ClientCallDetails`` is a namedtuple with 5 fields (method, - timeout, metadata, credentials, wait_for_ready). We override ``__new__`` - so the namedtuple layer receives exactly those five, then set any extra - attribute (``compression``) in ``__init__``. - """ - - def __new__( - cls, - method: str, - timeout: float | None, - metadata: grpc.aio.Metadata | None, - credentials: grpc.CallCredentials | None, - wait_for_ready: bool | None, - ) -> _AsyncClientCallDetails: - return super().__new__(cls, method, timeout, metadata, credentials, wait_for_ready) # type: ignore[no-any-return] - - def __init__( - self, - method: str, - timeout: float | None, - metadata: grpc.aio.Metadata | None, - credentials: grpc.CallCredentials | None, - wait_for_ready: bool | None, - ) -> None: - # Fields are already set by __new__ (namedtuple). Nothing extra to do. 
- pass - - -class _AsyncApiKeyInterceptor( - grpc.aio.UnaryUnaryClientInterceptor, # type: ignore[misc] - grpc.aio.UnaryStreamClientInterceptor, # type: ignore[misc] -): - """Injects ``authorization: Bearer `` metadata into every async RPC.""" - - def __init__(self, api_key: str) -> None: - self._metadata = grpc.aio.Metadata(("authorization", f"Bearer {api_key}")) - - def _inject( - self, metadata: grpc.aio.Metadata | None - ) -> grpc.aio.Metadata: - merged = grpc.aio.Metadata() - if metadata is not None: - for key, value in metadata: - merged.add(key, value) - for key, value in self._metadata: - merged.add(key, value) - return merged - - async def intercept_unary_unary( - self, - continuation: Any, - client_call_details: grpc.aio.ClientCallDetails, - request: Any, - ) -> Any: - new_details = _AsyncClientCallDetails( - client_call_details.method, - client_call_details.timeout, - self._inject(client_call_details.metadata), - client_call_details.credentials, - client_call_details.wait_for_ready, - ) - return await continuation(new_details, request) - async def intercept_unary_stream( - self, - continuation: Any, - client_call_details: grpc.aio.ClientCallDetails, - request: Any, - ) -> Any: - new_details = _AsyncClientCallDetails( - client_call_details.method, - client_call_details.timeout, - self._inject(client_call_details.metadata), - client_call_details.credentials, - client_call_details.wait_for_ready, - ) - return await continuation(new_details, request) +def _parse_addr(addr: str) -> tuple[str, int]: + if ":" not in addr: + raise ValueError(f"invalid address (expected host:port): {addr}") + host, port_str = addr.rsplit(":", 1) + return host, int(port_str) class AsyncClient: - """Asynchronous client for the Fila message broker. - - Wraps the hot-path gRPC operations: enqueue, consume, ack, nack. 
- - Usage:: - - client = AsyncClient("localhost:5555") - msg_id = await client.enqueue("my-queue", {"tenant": "acme"}, b"hello") - async for msg in await client.consume("my-queue"): - await client.ack("my-queue", msg.id) - await client.close() - - Or as an async context manager:: - - async with AsyncClient("localhost:5555") as client: - await client.enqueue("my-queue", None, b"hello") - - TLS (system trust store):: - - client = AsyncClient("localhost:5555", tls=True) - - TLS (custom CA):: - - with open("ca.pem", "rb") as f: - ca = f.read() - client = AsyncClient("localhost:5555", ca_cert=ca) - - mTLS + API key:: - - client = AsyncClient( - "localhost:5555", - ca_cert=ca, - client_cert=cert, - client_key=key, - api_key="fila_...", - ) - """ + """Asynchronous client for the Fila message broker.""" def __init__( self, @@ -148,178 +86,424 @@ def __init__( client_key: bytes | None = None, api_key: str | None = None, ) -> None: - """Connect to a Fila broker at the given address. - - Args: - addr: Broker address in "host:port" format (e.g., "localhost:5555"). - tls: Enable TLS using the OS system trust store for server - verification. Ignored when ``ca_cert`` is provided (which - implies TLS). Defaults to ``False``. - ca_cert: PEM-encoded CA certificate for verifying the server. - When provided, a TLS channel is used instead of an insecure one. - client_cert: PEM-encoded client certificate for mutual TLS (optional). - client_key: PEM-encoded client private key for mutual TLS (optional). - api_key: API key for authentication. When set, every RPC includes an - ``authorization: Bearer `` metadata header. 
- """ - use_tls = tls or ca_cert is not None + self._addr = addr + self._tls = tls + self._ca_cert = ca_cert + self._client_cert = client_cert + self._client_key = client_key + self._api_key = api_key + self._conn: AsyncConnection | None = None + use_tls = tls or ca_cert is not None if (client_cert is not None or client_key is not None) and not use_tls: raise ValueError( - "client_cert and client_key require ca_cert or tls=True to establish a TLS channel" + "client_cert and client_key require ca_cert or tls=True" ) - interceptors: list[grpc.aio.ClientInterceptor] = [] - if api_key is not None: - interceptors.append(_AsyncApiKeyInterceptor(api_key)) + self._ssl_ctx = self._make_ssl_context() if use_tls else None - if use_tls: - creds = grpc.ssl_channel_credentials( - root_certificates=ca_cert, - private_key=client_key, - certificate_chain=client_cert, - ) - self._channel = grpc.aio.secure_channel( - addr, creds, interceptors=interceptors or None - ) + def _make_ssl_context(self) -> ssl.SSLContext: + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + if self._ca_cert is not None: + ctx.load_verify_locations(cadata=self._ca_cert.decode("ascii")) else: - self._channel = grpc.aio.insecure_channel( - addr, interceptors=interceptors or None + ctx.load_default_certs() + if self._client_cert is not None and self._client_key is not None: + import os + import tempfile + + cert_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pem") + key_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pem") + try: + cert_file.write(self._client_cert) + cert_file.close() + key_file.write(self._client_key) + key_file.close() + ctx.load_cert_chain(cert_file.name, key_file.name) + finally: + os.unlink(cert_file.name) + os.unlink(key_file.name) + return ctx + + async def _ensure_connected(self) -> AsyncConnection: + if self._conn is None: + host, port = _parse_addr(self._addr) + self._conn = await AsyncConnection.connect( + host, port, ssl_context=self._ssl_ctx, 
api_key=self._api_key ) + return self._conn + + async def _reconnect(self, addr: str) -> None: + import contextlib - self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + if self._conn is not None: + with contextlib.suppress(OSError): + await self._conn.close() + self._addr = addr + self._conn = None + await self._ensure_connected() async def close(self) -> None: - """Close the underlying gRPC channel.""" - await self._channel.close() + if self._conn is not None: + await self._conn.close() + self._conn = None async def __aenter__(self) -> AsyncClient: + await self._ensure_connected() return self async def __aexit__(self, *args: object) -> None: await self.close() + # -- hot-path operations ------------------------------------------------- + async def enqueue( self, queue: str, headers: dict[str, str] | None, payload: bytes, ) -> str: - """Enqueue a message to the specified queue. - - Args: - queue: Target queue name. - headers: Optional message headers. - payload: Message payload bytes. + """Enqueue a message. Returns the broker-assigned message ID.""" + await self._ensure_connected() + msgs: list[dict[str, object]] = [ + {"queue": queue, "headers": headers or {}, "payload": payload}, + ] + body = encode_enqueue(msgs) + + header, resp_body = await self._request_with_leader_retry( + Opcode.ENQUEUE, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) - Returns: - Broker-assigned message ID (UUIDv7). + items = decode_enqueue_result(resp_body) + item = items[0] + if item.error_code == ErrorCode.OK: + return item.message_id + raise _map_per_item_error(item.error_code, "enqueue") - Raises: - QueueNotFoundError: If the queue does not exist. - RPCError: For unexpected gRPC failures. 
- """ - try: - resp = await self._stub.Enqueue( - service_pb2.EnqueueRequest( - queue=queue, - headers=headers or {}, - payload=payload, - ) - ) - except grpc.RpcError as e: - raise _map_enqueue_error(e) from e - return str(resp.message_id) + async def enqueue_many( + self, + messages: list[tuple[str, dict[str, str] | None, bytes]], + ) -> list[EnqueueResult]: + """Enqueue multiple messages in a single request.""" + await self._ensure_connected() + msgs: list[dict[str, object]] = [ + {"queue": q, "headers": h or {}, "payload": p} + for q, h, p in messages + ] + body = encode_enqueue(msgs) + + header, resp_body = await self._request_with_leader_retry( + Opcode.ENQUEUE, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + items = decode_enqueue_result(resp_body) + results: list[EnqueueResult] = [] + for item in items: + if item.error_code == ErrorCode.OK: + results.append(EnqueueResult(message_id=item.message_id, error=None)) + else: + err_msg = f"error 0x{item.error_code:02x}" + results.append(EnqueueResult(message_id=None, error=err_msg)) + return results async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]: - """Open a streaming consumer on the specified queue. - - Yields messages as they become available. The iterator ends when the - server stream closes or an error occurs. Nil message frames (keepalive - signals) are skipped automatically. - - Args: - queue: Queue to consume from. - - Yields: - ConsumeMessage objects as they arrive. - - Raises: - QueueNotFoundError: If the queue does not exist. - RPCError: For unexpected gRPC failures. 
- """ + """Open a streaming consumer on the specified queue.""" + conn = await self._ensure_connected() try: - stream = self._stub.Consume( - service_pb2.ConsumeRequest(queue=queue) - ) - except grpc.RpcError as e: - raise _map_consume_error(e) from e + _req_id, consumer_id = await conn.subscribe(queue) + except NotLeaderError as e: + if e.leader_addr is not None: + await self._reconnect(e.leader_addr) + conn = await self._ensure_connected() + _req_id, consumer_id = await conn.subscribe(queue) + else: + raise - return self._consume_iter(stream) + return self._consume_iter(conn, consumer_id) async def _consume_iter( - self, - stream: Any, + self, conn: AsyncConnection, consumer_id: str ) -> AsyncIterator[ConsumeMessage]: - """Internal async generator reading from the gRPC stream.""" try: - async for resp in stream: - msg = resp.message - if msg is None or not msg.ByteSize(): - continue # keepalive - metadata = msg.metadata - cm = ConsumeMessage( - id=msg.id, - headers=dict(msg.headers), - payload=bytes(msg.payload), - fairness_key=metadata.fairness_key if metadata else "", - attempt_count=metadata.attempt_count if metadata else 0, - queue=metadata.queue_id if metadata else "", - ) - yield cm - except grpc.RpcError: - return + while True: + try: + header, body = await conn.read_frame() + except (ConnectionError, OSError): + return + + if header.opcode == Opcode.DELIVERY: + for msg in decode_delivery(body): + yield ConsumeMessage( + id=msg.message_id, + queue=msg.queue, + headers=msg.headers, + payload=msg.payload, + fairness_key=msg.fairness_key, + attempt_count=msg.attempt_count, + weight=msg.weight, + throttle_keys=msg.throttle_keys, + enqueued_at=msg.enqueued_at, + leased_at=msg.leased_at, + ) + elif header.opcode == Opcode.ERROR: + err = decode_error(body) + _raise_from_error_frame(err) + finally: + import contextlib + with contextlib.suppress(OSError): + await conn.cancel_consume(consumer_id) async def ack(self, queue: str, msg_id: str) -> None: - """Acknowledge 
a successfully processed message. + body = encode_ack([{"queue": queue, "message_id": msg_id}]) + header, resp_body = await self._request_with_leader_retry(Opcode.ACK, body) - The message is permanently removed from the queue. + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) - Args: - queue: Queue the message belongs to. - msg_id: ID of the message to acknowledge. - - Raises: - MessageNotFoundError: If the message does not exist. - RPCError: For unexpected gRPC failures. - """ - try: - await self._stub.Ack( - service_pb2.AckRequest(queue=queue, message_id=msg_id) - ) - except grpc.RpcError as e: - raise _map_ack_error(e) from e + codes = decode_ack_result(resp_body) + if codes and codes[0] != ErrorCode.OK: + raise _map_per_item_error(codes[0], "ack") async def nack(self, queue: str, msg_id: str, error: str) -> None: - """Negatively acknowledge a message that failed processing. + body = encode_nack([{"queue": queue, "message_id": msg_id, "error": error}]) + header, resp_body = await self._request_with_leader_retry(Opcode.NACK, body) - The message is requeued for retry or routed to the dead-letter queue - based on the queue's on_failure Lua hook configuration. + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) - Args: - queue: Queue the message belongs to. - msg_id: ID of the message to nack. - error: Description of the failure. + codes = decode_nack_result(resp_body) + if codes and codes[0] != ErrorCode.OK: + raise _map_per_item_error(codes[0], "nack") - Raises: - MessageNotFoundError: If the message does not exist. - RPCError: For unexpected gRPC failures. 
- """ - try: - await self._stub.Nack( - service_pb2.NackRequest( - queue=queue, message_id=msg_id, error=error - ) - ) - except grpc.RpcError as e: - raise _map_nack_error(e) from e + # -- admin operations ---------------------------------------------------- + + async def create_queue(self, name: str) -> None: + body = encode_create_queue(name) + header, resp_body = await self._request_with_leader_retry( + Opcode.CREATE_QUEUE, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + result = decode_create_queue_result(resp_body) + if result.error_code != ErrorCode.OK: + raise _map_error_code(result.error_code, "create_queue failed") + + async def delete_queue(self, name: str) -> None: + body = encode_delete_queue(name) + header, resp_body = await self._request_with_leader_retry( + Opcode.DELETE_QUEUE, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code = decode_delete_queue_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "delete_queue failed") + + async def get_stats(self, queue: str) -> StatsResult: + body = encode_get_stats(queue) + header, resp_body = await self._request_with_leader_retry( + Opcode.GET_STATS, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + r = decode_get_stats_result(resp_body) + if r.error_code != ErrorCode.OK: + raise _map_error_code(r.error_code, "get_stats failed") + return StatsResult( + depth=r.depth, in_flight=r.in_flight, + active_fairness_keys=r.active_fairness_keys, + active_consumers=r.active_consumers, quantum=r.quantum, + leader_node_id=r.leader_node_id, + replication_count=r.replication_count, + per_key_stats=[ + FairnessKeyStat( + key=s.key, pending_count=s.pending_count, + current_deficit=s.current_deficit, weight=s.weight, + ) for s in r.per_key_stats + ], + per_throttle_stats=[ + ThrottleKeyStat( + key=s.key, 
tokens=s.tokens, + rate_per_second=s.rate_per_second, burst=s.burst, + ) for s in r.per_throttle_stats + ], + ) + + async def list_queues(self) -> list[QueueInfo]: + body = encode_list_queues() + header, resp_body = await self._request_with_leader_retry( + Opcode.LIST_QUEUES, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + r = decode_list_queues_result(resp_body) + if r.error_code != ErrorCode.OK: + raise _map_error_code(r.error_code, "list_queues failed") + return [ + QueueInfo( + name=q.name, depth=q.depth, in_flight=q.in_flight, + active_consumers=q.active_consumers, + leader_node_id=q.leader_node_id, + ) for q in r.queues + ] + + async def set_config(self, key: str, value: str) -> None: + body = encode_set_config(key, value) + header, resp_body = await self._request_with_leader_retry( + Opcode.SET_CONFIG, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code = decode_set_config_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "set_config failed") + + async def get_config(self, key: str) -> str: + body = encode_get_config(key) + header, resp_body = await self._request_with_leader_retry( + Opcode.GET_CONFIG, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code, value = decode_get_config_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "get_config failed") + return value + + async def list_config(self, prefix: str) -> dict[str, str]: + body = encode_list_config(prefix) + header, resp_body = await self._request_with_leader_retry( + Opcode.LIST_CONFIG, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code, entries = decode_list_config_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "list_config 
failed") + return entries + + async def redrive(self, dlq_queue: str, count: int) -> int: + body = encode_redrive(dlq_queue, count) + header, resp_body = await self._request_with_leader_retry( + Opcode.REDRIVE, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code, redriven = decode_redrive_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "redrive failed") + return redriven + + # -- auth operations ----------------------------------------------------- + + async def create_api_key(self, name: str) -> CreateApiKeyResult: + body = encode_create_api_key(name) + header, resp_body = await self._request_with_leader_retry( + Opcode.CREATE_API_KEY, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + ec, key_id, raw_key, is_superadmin = decode_create_api_key_result(resp_body) + if ec != ErrorCode.OK: + raise _map_error_code(ec, "create_api_key failed") + return CreateApiKeyResult( + key_id=key_id, raw_key=raw_key, is_superadmin=is_superadmin + ) + + async def revoke_api_key(self, key_id: str) -> None: + body = encode_revoke_api_key(key_id) + header, resp_body = await self._request_with_leader_retry( + Opcode.REVOKE_API_KEY, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code = decode_revoke_api_key_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "revoke_api_key failed") + + async def list_api_keys(self) -> list[ApiKeyInfo]: + body = encode_list_api_keys() + header, resp_body = await self._request_with_leader_retry( + Opcode.LIST_API_KEYS, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code, items = decode_list_api_keys_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "list_api_keys failed") + return [ + 
ApiKeyInfo( + key_id=k.key_id, name=k.name, created_at=k.created_at, + expires_at=k.expires_at, is_superadmin=k.is_superadmin, + ) for k in items + ] + + async def set_acl(self, key_id: str, permissions: list[tuple[str, str]]) -> None: + body = encode_set_acl(key_id, permissions) + header, resp_body = await self._request_with_leader_retry( + Opcode.SET_ACL, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code = decode_set_acl_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "set_acl failed") + + async def get_acl(self, key_id: str) -> AclEntry: + body = encode_get_acl(key_id) + header, resp_body = await self._request_with_leader_retry( + Opcode.GET_ACL, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + result = decode_get_acl_result(resp_body) + if result.error_code != ErrorCode.OK: + raise _map_error_code(result.error_code, "get_acl failed") + return AclEntry( + key_id=result.key_id, + is_superadmin=result.is_superadmin, + permissions=[ + AclPermission(kind=p.kind, pattern=p.pattern) + for p in result.permissions + ], + ) + + # -- internal helpers ---------------------------------------------------- + + async def _request_with_leader_retry( + self, opcode: int, body: bytes + ) -> tuple[FrameHeader, bytes]: + conn = await self._ensure_connected() + header, resp_body = await conn.request(opcode, body) + + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + if err.code == ErrorCode.NOT_LEADER: + leader_addr = err.metadata.get("leader_addr") + if leader_addr: + await self._reconnect(leader_addr) + conn = await self._ensure_connected() + return await conn.request(opcode, body) + + return header, resp_body diff --git a/fila/batcher.py b/fila/batcher.py new file mode 100644 index 0000000..68417d3 --- /dev/null +++ b/fila/batcher.py @@ -0,0 +1,255 @@ +"""Background accumulator for 
opportunistic and linger-based enqueue accumulation.""" + +from __future__ import annotations + +import queue +import threading +from concurrent.futures import Future, ThreadPoolExecutor +from typing import TYPE_CHECKING + +from fila.errors import EnqueueError, _map_per_item_error +from fila.fibp.codec import decode_enqueue_result, decode_error, encode_enqueue +from fila.fibp.opcodes import ErrorCode, Opcode + +if TYPE_CHECKING: + from fila.conn import Connection + + +# Sentinel that signals the accumulator thread to stop. +_STOP = object() + +# Maximum number of messages per flush when none is configured. +_DEFAULT_MAX_MESSAGES = 1000 + + +class _EnqueueItem: + """Internal envelope pairing a message dict with its result future.""" + + __slots__ = ("msg", "future") + + def __init__( + self, + msg: dict[str, object], + future: Future[str], + ) -> None: + self.msg = msg + self.future = future + + +def _flush_single( + conn: Connection, + req: _EnqueueItem, +) -> None: + """Send a single message via the FIBP Enqueue request.""" + try: + body = encode_enqueue([req.msg]) + header, resp_body = conn.request(Opcode.ENQUEUE, body) + + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + from fila.errors import _raise_from_error_frame + try: + _raise_from_error_frame(err) + except Exception as e: + req.future.set_exception(e) + return + + items = decode_enqueue_result(resp_body) + item = items[0] + if item.error_code == ErrorCode.OK: + req.future.set_result(item.message_id) + else: + req.future.set_exception( + _map_per_item_error(item.error_code, "enqueue") + ) + except Exception as e: + req.future.set_exception(e) + + +def _flush_many( + conn: Connection, + items: list[_EnqueueItem], +) -> None: + """Send multiple messages via the FIBP Enqueue request.""" + try: + body = encode_enqueue([item.msg for item in items]) + header, resp_body = conn.request(Opcode.ENQUEUE, body) + except Exception as e: + err = EnqueueError(f"enqueue request failed: {e}") + for item 
in items: + item.future.set_exception(err) + return + + if header.opcode == Opcode.ERROR: + try: + err_frame = decode_error(resp_body) + from fila.errors import _map_error_code + exc = _map_error_code(err_frame.code, err_frame.message) + except Exception as e: + exc = EnqueueError(f"enqueue failed: {e}") + for item in items: + item.future.set_exception(exc) + return + + results = decode_enqueue_result(resp_body) + for i, result in enumerate(results): + if i >= len(items): + break + item = items[i] + if result.error_code == ErrorCode.OK: + item.future.set_result(result.message_id) + else: + item.future.set_exception( + _map_per_item_error(result.error_code, "enqueue") + ) + + # If the server returned fewer results than items sent, fail the rest. + for i in range(len(results), len(items)): + items[i].future.set_exception( + EnqueueError("server returned fewer results than messages sent") + ) + + +class AutoAccumulator: + """Opportunistic accumulator: drains a queue and flushes in batches. + + A background daemon thread blocks on the first message, then non-blocking + drains any additional messages that arrived during processing and flushes + them as a single Enqueue request via a thread pool executor. + """ + + def __init__( + self, + conn: Connection, + max_messages: int = _DEFAULT_MAX_MESSAGES, + max_workers: int = 4, + ) -> None: + self._conn = conn + self._max_messages = max_messages + self._queue: queue.Queue[_EnqueueItem | object] = queue.Queue() + self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def submit(self, msg: dict[str, object]) -> Future[str]: + """Submit a message for accumulated enqueue. 
Returns a Future for the message ID.""" + fut: Future[str] = Future() + self._queue.put(_EnqueueItem(msg, fut)) + return fut + + def close(self, timeout: float | None = 30.0) -> None: + """Drain pending messages and shut down the accumulator.""" + self._queue.put(_STOP) + self._thread.join(timeout=timeout) + self._executor.shutdown(wait=True) + + def update_conn(self, conn: Connection) -> None: + """Update the connection (e.g. after leader-hint reconnect).""" + self._conn = conn + + def _run(self) -> None: + """Background loop: block for first item, drain rest, flush.""" + while True: + first = self._queue.get() + if first is _STOP: + return + + assert isinstance(first, _EnqueueItem) + batch: list[_EnqueueItem] = [first] + + while len(batch) < self._max_messages: + try: + item = self._queue.get_nowait() + except queue.Empty: + break + if item is _STOP: + self._flush(batch) + return + assert isinstance(item, _EnqueueItem) + batch.append(item) + + self._flush(batch) + + def _flush(self, batch: list[_EnqueueItem]) -> None: + """Dispatch a batch to the executor.""" + if len(batch) == 1: + self._executor.submit(_flush_single, self._conn, batch[0]) + else: + self._executor.submit(_flush_many, self._conn, batch) + + +class LingerAccumulator: + """Timer-based accumulator: holds messages for up to linger_ms or max_messages. + + A background daemon thread accumulates messages and flushes when either + the count reaches ``max_messages`` or ``linger_ms`` milliseconds have + elapsed since the first message in the current batch arrived. 
+ """ + + def __init__( + self, + conn: Connection, + linger_ms: float, + max_messages: int, + max_workers: int = 4, + ) -> None: + self._conn = conn + self._linger_s = linger_ms / 1000.0 + self._max_messages = max_messages + self._queue: queue.Queue[_EnqueueItem | object] = queue.Queue() + self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._thread = threading.Thread(target=self._run, daemon=True) + self._thread.start() + + def submit(self, msg: dict[str, object]) -> Future[str]: + """Submit a message for accumulated enqueue. Returns a Future for the message ID.""" + fut: Future[str] = Future() + self._queue.put(_EnqueueItem(msg, fut)) + return fut + + def close(self, timeout: float | None = 30.0) -> None: + """Drain pending messages and shut down the accumulator.""" + self._queue.put(_STOP) + self._thread.join(timeout=timeout) + self._executor.shutdown(wait=True) + + def update_conn(self, conn: Connection) -> None: + """Update the connection (e.g. after leader-hint reconnect).""" + self._conn = conn + + def _run(self) -> None: + """Background loop: accumulate up to max_messages or linger timeout.""" + import time + + while True: + first = self._queue.get() + if first is _STOP: + return + + assert isinstance(first, _EnqueueItem) + batch: list[_EnqueueItem] = [first] + + deadline = time.monotonic() + self._linger_s + + while len(batch) < self._max_messages: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + try: + item = self._queue.get(timeout=remaining) + except queue.Empty: + break + if item is _STOP: + self._flush(batch) + return + assert isinstance(item, _EnqueueItem) + batch.append(item) + + self._flush(batch) + + def _flush(self, batch: list[_EnqueueItem]) -> None: + """Dispatch a batch to the executor.""" + if len(batch) == 1: + self._executor.submit(_flush_single, self._conn, batch[0]) + else: + self._executor.submit(_flush_many, self._conn, batch) diff --git a/fila/client.py b/fila/client.py index 531c051..0d6aa5d 
100644 --- a/fila/client.py +++ b/fila/client.py @@ -2,88 +2,85 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any - -import grpc - -from fila.errors import _map_ack_error, _map_consume_error, _map_enqueue_error, _map_nack_error -from fila.types import ConsumeMessage -from fila.v1 import service_pb2, service_pb2_grpc +import ssl +from typing import TYPE_CHECKING + +from fila.batcher import AutoAccumulator, LingerAccumulator +from fila.conn import Connection +from fila.errors import ( + NotLeaderError, + _map_error_code, + _map_per_item_error, + _raise_from_error_frame, +) +from fila.fibp.codec import ( + decode_ack_result, + decode_create_api_key_result, + decode_create_queue_result, + decode_delete_queue_result, + decode_delivery, + decode_enqueue_result, + decode_error, + decode_get_acl_result, + decode_get_config_result, + decode_get_stats_result, + decode_list_api_keys_result, + decode_list_config_result, + decode_list_queues_result, + decode_nack_result, + decode_redrive_result, + decode_revoke_api_key_result, + decode_set_acl_result, + decode_set_config_result, + encode_ack, + encode_create_api_key, + encode_create_queue, + encode_delete_queue, + encode_enqueue, + encode_get_acl, + encode_get_config, + encode_get_stats, + encode_list_api_keys, + encode_list_config, + encode_list_queues, + encode_nack, + encode_redrive, + encode_revoke_api_key, + encode_set_acl, + encode_set_config, +) +from fila.fibp.opcodes import ErrorCode, FrameHeader, Opcode +from fila.types import ( + AccumulatorMode, + AclEntry, + AclPermission, + ApiKeyInfo, + ConsumeMessage, + CreateApiKeyResult, + EnqueueResult, + FairnessKeyStat, + Linger, + QueueInfo, + StatsResult, + ThrottleKeyStat, +) if TYPE_CHECKING: from collections.abc import Iterator -class _ClientCallDetails( - grpc.ClientCallDetails, # type: ignore[misc] -): - """Concrete ``ClientCallDetails`` that can be instantiated. 
- - ``grpc.ClientCallDetails`` is an abstract class with no ``__init__``, so we - need our own subclass to carry the fields through the interceptor chain. - """ - - def __init__( - self, - method: str, - timeout: float | None, - metadata: list[tuple[str, str | bytes]] | None, - credentials: grpc.CallCredentials | None, - wait_for_ready: bool | None, - compression: grpc.Compression | None, - ) -> None: - self.method = method - self.timeout = timeout - self.metadata = metadata - self.credentials = credentials - self.wait_for_ready = wait_for_ready - self.compression = compression - - -class _ApiKeyInterceptor( - grpc.UnaryUnaryClientInterceptor, # type: ignore[misc] - grpc.UnaryStreamClientInterceptor, # type: ignore[misc] -): - """Injects ``authorization: Bearer `` metadata into every RPC.""" - - def __init__(self, api_key: str) -> None: - self._metadata = (("authorization", f"Bearer {api_key}"),) - - def _inject( - self, client_call_details: grpc.ClientCallDetails - ) -> _ClientCallDetails: - metadata = list(client_call_details.metadata or []) - metadata.extend(self._metadata) - return _ClientCallDetails( - client_call_details.method, - client_call_details.timeout, - metadata, - client_call_details.credentials, - client_call_details.wait_for_ready, - client_call_details.compression, - ) - - def intercept_unary_unary( - self, - continuation: Any, - client_call_details: grpc.ClientCallDetails, - request: Any, - ) -> Any: - return continuation(self._inject(client_call_details), request) - - def intercept_unary_stream( - self, - continuation: Any, - client_call_details: grpc.ClientCallDetails, - request: Any, - ) -> Any: - return continuation(self._inject(client_call_details), request) +def _parse_addr(addr: str) -> tuple[str, int]: + """Parse 'host:port' into (host, port).""" + if ":" not in addr: + raise ValueError(f"invalid address (expected host:port): {addr}") + host, port_str = addr.rsplit(":", 1) + return host, int(port_str) class Client: """Synchronous client 
for the Fila message broker. - Wraps the hot-path gRPC operations: enqueue, consume, ack, nack. + Wraps the hot-path FIBP operations: enqueue, enqueue_many, consume, ack, nack. Usage:: @@ -98,6 +95,20 @@ class Client: with Client("localhost:5555") as client: client.enqueue("my-queue", None, b"hello") + Accumulator modes:: + + # AUTO (default): opportunistic accumulation via background thread + client = Client("localhost:5555") + + # DISABLED: each enqueue() is a direct request + client = Client("localhost:5555", accumulator_mode=AccumulatorMode.DISABLED) + + # LINGER: timer-based forced accumulation + client = Client( + "localhost:5555", + accumulator_mode=Linger(linger_ms=10, max_messages=100), + ) + TLS (system trust store):: client = Client("localhost:5555", tls=True) @@ -128,47 +139,83 @@ def __init__( client_cert: bytes | None = None, client_key: bytes | None = None, api_key: str | None = None, + accumulator_mode: AccumulatorMode | Linger = AccumulatorMode.AUTO, + max_accumulator_messages: int = 1000, ) -> None: - """Connect to a Fila broker at the given address. - - Args: - addr: Broker address in "host:port" format (e.g., "localhost:5555"). - tls: Enable TLS using the OS system trust store for server - verification. Ignored when ``ca_cert`` is provided (which - implies TLS). Defaults to ``False``. - ca_cert: PEM-encoded CA certificate for verifying the server. - When provided, a TLS channel is used instead of an insecure one. - client_cert: PEM-encoded client certificate for mutual TLS (optional). - client_key: PEM-encoded client private key for mutual TLS (optional). - api_key: API key for authentication. When set, every RPC includes an - ``authorization: Bearer `` metadata header. 
- """ - use_tls = tls or ca_cert is not None + self._addr = addr + self._tls = tls + self._ca_cert = ca_cert + self._client_cert = client_cert + self._client_key = client_key + self._api_key = api_key + use_tls = tls or ca_cert is not None if (client_cert is not None or client_key is not None) and not use_tls: raise ValueError( - "client_cert and client_key require ca_cert or tls=True to establish a TLS channel" + "client_cert and client_key require ca_cert or tls=True" ) - if use_tls: - creds = grpc.ssl_channel_credentials( - root_certificates=ca_cert, - private_key=client_key, - certificate_chain=client_cert, + self._ssl_ctx = self._make_ssl_context() if use_tls else None + self._conn = self._connect(addr) + + self._accumulator: AutoAccumulator | LingerAccumulator | None = None + if isinstance(accumulator_mode, Linger): + self._accumulator = LingerAccumulator( + self._conn, + linger_ms=accumulator_mode.linger_ms, + max_messages=accumulator_mode.max_messages, + ) + elif accumulator_mode is AccumulatorMode.AUTO: + self._accumulator = AutoAccumulator( + self._conn, + max_messages=max_accumulator_messages, ) - self._channel = grpc.secure_channel(addr, creds) + + def _make_ssl_context(self) -> ssl.SSLContext: + """Create an SSL context from stored credentials.""" + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + if self._ca_cert is not None: + ctx.load_verify_locations(cadata=self._ca_cert.decode("ascii")) else: - self._channel = grpc.insecure_channel(addr) + ctx.load_default_certs() + if self._client_cert is not None and self._client_key is not None: + import os + import tempfile + + cert_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pem") + key_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pem") + try: + cert_file.write(self._client_cert) + cert_file.close() + key_file.write(self._client_key) + key_file.close() + ctx.load_cert_chain(cert_file.name, key_file.name) + finally: + os.unlink(cert_file.name) + os.unlink(key_file.name) + return ctx + 
+ def _connect(self, addr: str) -> Connection: + host, port = _parse_addr(addr) + return Connection.connect( + host, port, ssl_context=self._ssl_ctx, api_key=self._api_key + ) - if api_key is not None: - interceptor = _ApiKeyInterceptor(api_key) - self._channel = grpc.intercept_channel(self._channel, interceptor) + def _reconnect(self, addr: str) -> None: + import contextlib - self._stub = service_pb2_grpc.FilaServiceStub(self._channel) # type: ignore[no-untyped-call] + with contextlib.suppress(OSError): + self._conn.close() + self._addr = addr + self._conn = self._connect(addr) + if self._accumulator is not None: + self._accumulator.update_conn(self._conn) def close(self) -> None: - """Close the underlying gRPC channel.""" - self._channel.close() + """Drain pending accumulated messages and close the connection.""" + if self._accumulator is not None: + self._accumulator.close() + self._conn.close() def __enter__(self) -> Client: return self @@ -176,127 +223,359 @@ def __enter__(self) -> Client: def __exit__(self, *args: object) -> None: self.close() + # -- hot-path operations ------------------------------------------------- + def enqueue( self, queue: str, headers: dict[str, str] | None, payload: bytes, ) -> str: - """Enqueue a message to the specified queue. + """Enqueue a message. Returns the broker-assigned message ID.""" + msg: dict[str, object] = { + "queue": queue, "headers": headers or {}, "payload": payload, + } - Args: - queue: Target queue name. - headers: Optional message headers. - payload: Message payload bytes. + if self._accumulator is not None: + future = self._accumulator.submit(msg) + return future.result() - Returns: - Broker-assigned message ID (UUIDv7). + return self._enqueue_direct([msg])[0] - Raises: - QueueNotFoundError: If the queue does not exist. - RPCError: For unexpected gRPC failures. 
- """ - try: - resp = self._stub.Enqueue( - service_pb2.EnqueueRequest( - queue=queue, - headers=headers or {}, - payload=payload, - ) - ) - except grpc.RpcError as e: - raise _map_enqueue_error(e) from e - return str(resp.message_id) + def enqueue_many( + self, + messages: list[tuple[str, dict[str, str] | None, bytes]], + ) -> list[EnqueueResult]: + """Enqueue multiple messages in a single request.""" + msgs: list[dict[str, object]] = [ + {"queue": q, "headers": h or {}, "payload": p} + for q, h, p in messages + ] + body = encode_enqueue(msgs) + header, resp_body = self._request_with_leader_retry(Opcode.ENQUEUE, body) + + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + items = decode_enqueue_result(resp_body) + results: list[EnqueueResult] = [] + for item in items: + if item.error_code == ErrorCode.OK: + results.append(EnqueueResult(message_id=item.message_id, error=None)) + else: + err_msg = f"error 0x{item.error_code:02x}" + results.append(EnqueueResult(message_id=None, error=err_msg)) + return results def consume(self, queue: str) -> Iterator[ConsumeMessage]: - """Open a streaming consumer on the specified queue. - - Yields messages as they become available. The iterator ends when the - server stream closes or an error occurs. Skip nil message frames - (keepalive signals) automatically. - - Args: - queue: Queue to consume from. - - Yields: - ConsumeMessage objects as they arrive. - - Raises: - QueueNotFoundError: If the queue does not exist. - RPCError: For unexpected gRPC failures. 
- """ + """Open a streaming consumer on the specified queue.""" try: - stream = self._stub.Consume( - service_pb2.ConsumeRequest(queue=queue) - ) - except grpc.RpcError as e: - raise _map_consume_error(e) from e + _req_id, consumer_id = self._conn.subscribe(queue) + except NotLeaderError as e: + if e.leader_addr is not None: + self._reconnect(e.leader_addr) + _req_id, consumer_id = self._conn.subscribe(queue) + else: + raise - return self._consume_iter(stream) + return self._consume_iter(consumer_id) - def _consume_iter( - self, - stream: Any, - ) -> Iterator[ConsumeMessage]: - """Internal generator reading from the gRPC stream.""" + def _consume_iter(self, consumer_id: str) -> Iterator[ConsumeMessage]: try: - for resp in stream: - msg = resp.message - if msg is None or not msg.ByteSize(): - continue # keepalive - metadata = msg.metadata - cm = ConsumeMessage( - id=msg.id, - headers=dict(msg.headers), - payload=bytes(msg.payload), - fairness_key=metadata.fairness_key if metadata else "", - attempt_count=metadata.attempt_count if metadata else 0, - queue=metadata.queue_id if metadata else "", - ) - yield cm - except grpc.RpcError: - return + while True: + try: + header, body = self._conn.read_frame() + except (ConnectionError, OSError): + return + + if header.opcode == Opcode.DELIVERY: + for msg in decode_delivery(body): + yield ConsumeMessage( + id=msg.message_id, + queue=msg.queue, + headers=msg.headers, + payload=msg.payload, + fairness_key=msg.fairness_key, + attempt_count=msg.attempt_count, + weight=msg.weight, + throttle_keys=msg.throttle_keys, + enqueued_at=msg.enqueued_at, + leased_at=msg.leased_at, + ) + elif header.opcode == Opcode.ERROR: + err = decode_error(body) + _raise_from_error_frame(err) + finally: + import contextlib + with contextlib.suppress(OSError): + self._conn.cancel_consume(consumer_id) def ack(self, queue: str, msg_id: str) -> None: - """Acknowledge a successfully processed message. 
+ """Acknowledge a successfully processed message.""" + body = encode_ack([{"queue": queue, "message_id": msg_id}]) + header, resp_body = self._request_with_leader_retry(Opcode.ACK, body) - The message is permanently removed from the queue. + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) - Args: - queue: Queue the message belongs to. - msg_id: ID of the message to acknowledge. - - Raises: - MessageNotFoundError: If the message does not exist. - RPCError: For unexpected gRPC failures. - """ - try: - self._stub.Ack( - service_pb2.AckRequest(queue=queue, message_id=msg_id) - ) - except grpc.RpcError as e: - raise _map_ack_error(e) from e + codes = decode_ack_result(resp_body) + if codes and codes[0] != ErrorCode.OK: + raise _map_per_item_error(codes[0], "ack") def nack(self, queue: str, msg_id: str, error: str) -> None: - """Negatively acknowledge a message that failed processing. - - The message is requeued for retry or routed to the dead-letter queue - based on the queue's on_failure Lua hook configuration. 
+ """Negatively acknowledge a message that failed processing.""" + body = encode_nack([{"queue": queue, "message_id": msg_id, "error": error}]) + header, resp_body = self._request_with_leader_retry(Opcode.NACK, body) + + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + codes = decode_nack_result(resp_body) + if codes and codes[0] != ErrorCode.OK: + raise _map_per_item_error(codes[0], "nack") + + # -- admin operations ---------------------------------------------------- + + def create_queue(self, name: str) -> None: + """Create a queue on the broker.""" + body = encode_create_queue(name) + header, resp_body = self._request_with_leader_retry(Opcode.CREATE_QUEUE, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + result = decode_create_queue_result(resp_body) + if result.error_code != ErrorCode.OK: + raise _map_error_code(result.error_code, "create_queue failed") + + def delete_queue(self, name: str) -> None: + """Delete a queue from the broker.""" + body = encode_delete_queue(name) + header, resp_body = self._request_with_leader_retry(Opcode.DELETE_QUEUE, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code = decode_delete_queue_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "delete_queue failed") + + def get_stats(self, queue: str) -> StatsResult: + """Get statistics for a queue.""" + body = encode_get_stats(queue) + header, resp_body = self._request_with_leader_retry(Opcode.GET_STATS, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + r = decode_get_stats_result(resp_body) + if r.error_code != ErrorCode.OK: + raise _map_error_code(r.error_code, "get_stats failed") + return StatsResult( + depth=r.depth, + in_flight=r.in_flight, + active_fairness_keys=r.active_fairness_keys, + 
active_consumers=r.active_consumers, + quantum=r.quantum, + leader_node_id=r.leader_node_id, + replication_count=r.replication_count, + per_key_stats=[ + FairnessKeyStat( + key=s.key, pending_count=s.pending_count, + current_deficit=s.current_deficit, weight=s.weight, + ) + for s in r.per_key_stats + ], + per_throttle_stats=[ + ThrottleKeyStat( + key=s.key, tokens=s.tokens, + rate_per_second=s.rate_per_second, burst=s.burst, + ) + for s in r.per_throttle_stats + ], + ) - Args: - queue: Queue the message belongs to. - msg_id: ID of the message to nack. - error: Description of the failure. + def list_queues(self) -> list[QueueInfo]: + """List all queues on the broker.""" + body = encode_list_queues() + header, resp_body = self._request_with_leader_retry(Opcode.LIST_QUEUES, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + r = decode_list_queues_result(resp_body) + if r.error_code != ErrorCode.OK: + raise _map_error_code(r.error_code, "list_queues failed") + return [ + QueueInfo( + name=q.name, depth=q.depth, in_flight=q.in_flight, + active_consumers=q.active_consumers, + leader_node_id=q.leader_node_id, + ) + for q in r.queues + ] + + def set_config(self, key: str, value: str) -> None: + """Set a runtime configuration key.""" + body = encode_set_config(key, value) + header, resp_body = self._request_with_leader_retry(Opcode.SET_CONFIG, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code = decode_set_config_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "set_config failed") + + def get_config(self, key: str) -> str: + """Get a runtime configuration value.""" + body = encode_get_config(key) + header, resp_body = self._request_with_leader_retry(Opcode.GET_CONFIG, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code, value = 
decode_get_config_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "get_config failed") + return value + + def list_config(self, prefix: str) -> dict[str, str]: + """List configuration entries matching a prefix.""" + body = encode_list_config(prefix) + header, resp_body = self._request_with_leader_retry(Opcode.LIST_CONFIG, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code, entries = decode_list_config_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "list_config failed") + return entries + + def redrive(self, dlq_queue: str, count: int) -> int: + """Redrive DLQ messages. Returns number of messages redriven.""" + body = encode_redrive(dlq_queue, count) + header, resp_body = self._request_with_leader_retry(Opcode.REDRIVE, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code, redriven = decode_redrive_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "redrive failed") + return redriven + + # -- auth operations ----------------------------------------------------- + + def create_api_key(self, name: str) -> CreateApiKeyResult: + """Create a new API key.""" + body = encode_create_api_key(name) + header, resp_body = self._request_with_leader_retry( + Opcode.CREATE_API_KEY, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code, key_id, raw_key, is_superadmin = decode_create_api_key_result( + resp_body + ) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "create_api_key failed") + return CreateApiKeyResult( + key_id=key_id, raw_key=raw_key, is_superadmin=is_superadmin + ) - Raises: - MessageNotFoundError: If the message does not exist. - RPCError: For unexpected gRPC failures. 
- """ - try: - self._stub.Nack( - service_pb2.NackRequest( - queue=queue, message_id=msg_id, error=error - ) + def revoke_api_key(self, key_id: str) -> None: + """Revoke an API key.""" + body = encode_revoke_api_key(key_id) + header, resp_body = self._request_with_leader_retry( + Opcode.REVOKE_API_KEY, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code = decode_revoke_api_key_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "revoke_api_key failed") + + def list_api_keys(self) -> list[ApiKeyInfo]: + """List all API keys.""" + body = encode_list_api_keys() + header, resp_body = self._request_with_leader_retry( + Opcode.LIST_API_KEYS, body + ) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code, items = decode_list_api_keys_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "list_api_keys failed") + return [ + ApiKeyInfo( + key_id=k.key_id, name=k.name, created_at=k.created_at, + expires_at=k.expires_at, is_superadmin=k.is_superadmin, ) - except grpc.RpcError as e: - raise _map_nack_error(e) from e + for k in items + ] + + def set_acl(self, key_id: str, permissions: list[tuple[str, str]]) -> None: + """Set ACL for an API key. 
permissions = [(kind, pattern), ...].""" + body = encode_set_acl(key_id, permissions) + header, resp_body = self._request_with_leader_retry(Opcode.SET_ACL, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + error_code = decode_set_acl_result(resp_body) + if error_code != ErrorCode.OK: + raise _map_error_code(error_code, "set_acl failed") + + def get_acl(self, key_id: str) -> AclEntry: + """Get ACL for an API key.""" + body = encode_get_acl(key_id) + header, resp_body = self._request_with_leader_retry(Opcode.GET_ACL, body) + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + result = decode_get_acl_result(resp_body) + if result.error_code != ErrorCode.OK: + raise _map_error_code(result.error_code, "get_acl failed") + return AclEntry( + key_id=result.key_id, + is_superadmin=result.is_superadmin, + permissions=[ + AclPermission(kind=p.kind, pattern=p.pattern) + for p in result.permissions + ], + ) + + # -- internal helpers ---------------------------------------------------- + + def _enqueue_direct(self, messages: list[dict[str, object]]) -> list[str]: + """Send an enqueue request directly and return message IDs.""" + body = encode_enqueue(messages) + header, resp_body = self._request_with_leader_retry(Opcode.ENQUEUE, body) + + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + _raise_from_error_frame(err) + + items = decode_enqueue_result(resp_body) + results: list[str] = [] + for item in items: + if item.error_code == ErrorCode.OK: + results.append(item.message_id) + else: + raise _map_per_item_error(item.error_code, "enqueue") + return results + + def _request_with_leader_retry( + self, opcode: int, body: bytes + ) -> tuple[FrameHeader, bytes]: + """Send a request, retrying once on NotLeader with leader hint.""" + header, resp_body = self._conn.request(opcode, body) + + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + if 
err.code == ErrorCode.NOT_LEADER: + leader_addr = err.metadata.get("leader_addr") + if leader_addr: + self._reconnect(leader_addr) + return self._conn.request(opcode, body) + + return header, resp_body diff --git a/fila/conn.py b/fila/conn.py new file mode 100644 index 0000000..e505b83 --- /dev/null +++ b/fila/conn.py @@ -0,0 +1,414 @@ +"""FIBP connection manager — synchronous and asynchronous.""" + +from __future__ import annotations + +import asyncio +import socket +import struct +import threading +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import ssl + +from fila.fibp.codec import ( + decode_error, + decode_handshake_ok, + encode_handshake, + encode_pong, +) +from fila.fibp.opcodes import ( + DEFAULT_MAX_FRAME_SIZE, + FRAME_HEADER_SIZE, + PROTOCOL_VERSION, + FrameHeader, + Opcode, +) + + +def _parse_header(data: bytes) -> FrameHeader: + """Parse 6 bytes into a FrameHeader.""" + opcode = data[0] + flags = data[1] + request_id = struct.unpack_from("!I", data, 2)[0] + return FrameHeader(opcode=opcode, flags=flags, request_id=request_id) + + +def _build_frame(opcode: int, request_id: int, body: bytes, flags: int = 0) -> bytes: + """Build a length-prefixed FIBP frame.""" + frame_body = struct.pack("!BBI", opcode, flags, request_id) + body + return struct.pack("!I", len(frame_body)) + frame_body + + +# --------------------------------------------------------------------------- +# Synchronous connection +# --------------------------------------------------------------------------- + +class Connection: + """Synchronous FIBP connection over a TCP socket.""" + + def __init__(self, sock: socket.socket, max_frame_size: int = DEFAULT_MAX_FRAME_SIZE) -> None: + self._sock = sock + self._max_frame_size = max_frame_size + self._req_counter = 0 + self._lock = threading.Lock() + self._pushback: tuple[FrameHeader, bytes] | None = None + + @classmethod + def connect( + cls, + host: str, + port: int, + *, + ssl_context: ssl.SSLContext | None = None, + api_key: str | 
None = None, + timeout: float = 10.0, + ) -> Connection: + """Open a TCP connection and perform the FIBP handshake.""" + sock = socket.create_connection((host, port), timeout=timeout) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + if ssl_context is not None: + sock = ssl_context.wrap_socket(sock, server_hostname=host) + + conn = cls(sock) + conn._handshake(api_key) + return conn + + def _next_request_id(self) -> int: + self._req_counter += 1 + return self._req_counter & 0xFFFFFFFF + + def _handshake(self, api_key: str | None) -> None: + """Perform the FIBP handshake.""" + body = encode_handshake(PROTOCOL_VERSION, api_key) + req_id = self._next_request_id() + self.write_frame(Opcode.HANDSHAKE, req_id, body) + + header, resp_body = self.read_frame() + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + from fila.errors import _raise_from_error_frame + _raise_from_error_frame(err) + + if header.opcode != Opcode.HANDSHAKE_OK: + raise ConnectionError( + f"expected HandshakeOk (0x02), got 0x{header.opcode:02x}" + ) + + version, _node_id, max_frame_size = decode_handshake_ok(resp_body) + if max_frame_size > 0: + self._max_frame_size = max_frame_size + + def write_frame(self, opcode: int, request_id: int, body: bytes, flags: int = 0) -> None: + """Write a single length-prefixed FIBP frame.""" + frame = _build_frame(opcode, request_id, body, flags) + self._sock.sendall(frame) + + def read_frame(self) -> tuple[FrameHeader, bytes]: + """Read a single length-prefixed FIBP frame. + + Handles Ping by responding with Pong automatically. + Handles continuation frames by concatenating bodies. + """ + if self._pushback is not None: + frame = self._pushback + self._pushback = None + return frame + + while True: + header, body = self._read_single_frame() + + # Auto-reply to Ping. + if header.opcode == Opcode.PING: + self.write_frame(Opcode.PONG, header.request_id, encode_pong()) + continue + + # Handle continuation frames. 
+ if header.is_continuation: + parts = [body] + while True: + cont_header, cont_body = self._read_single_frame() + if cont_header.opcode == Opcode.PING: + self.write_frame(Opcode.PONG, cont_header.request_id, encode_pong()) + continue + parts.append(cont_body) + if not cont_header.is_continuation: + break + return header, b"".join(parts) + + return header, body + + def _read_single_frame(self) -> tuple[FrameHeader, bytes]: + """Read one frame from the wire (no continuation handling).""" + length_bytes = self._recv_exact(4) + frame_len = struct.unpack("!I", length_bytes)[0] + if frame_len > self._max_frame_size: + raise ConnectionError( + f"frame size {frame_len} exceeds max {self._max_frame_size}" + ) + frame_data = self._recv_exact(frame_len) + header = _parse_header(frame_data[:FRAME_HEADER_SIZE]) + body = frame_data[FRAME_HEADER_SIZE:] + return header, body + + def _recv_exact(self, n: int) -> bytes: + """Read exactly n bytes from the socket.""" + buf = bytearray() + while len(buf) < n: + chunk = self._sock.recv(n - len(buf)) + if not chunk: + raise ConnectionError("connection closed by remote") + buf.extend(chunk) + return bytes(buf) + + def request(self, opcode: int, body: bytes) -> tuple[FrameHeader, bytes]: + """Send a request frame and read the response (synchronous request-response).""" + req_id = self._next_request_id() + with self._lock: + self.write_frame(opcode, req_id, body) + return self.read_frame() + + def subscribe(self, queue: str) -> tuple[int, str]: + """Send a Consume request and wait for ConsumeOk. + + Returns (request_id, consumer_id). 
+ """ + from fila.fibp.codec import decode_consume_ok, encode_consume + + req_id = self._next_request_id() + self.write_frame(Opcode.CONSUME, req_id, encode_consume(queue)) + + header, body = self.read_frame() + if header.opcode == Opcode.ERROR: + err = decode_error(body) + from fila.errors import _raise_from_error_frame + _raise_from_error_frame(err) + + if header.opcode == Opcode.CONSUME_OK: + consumer_id = decode_consume_ok(body) + return req_id, consumer_id + + if header.opcode == Opcode.DELIVERY: + # Server may send Delivery directly (older binaries without ConsumeOk). + # Push the frame back so the consume iterator can read it. + self._pushback = (header, body) + return req_id, "" + + raise ConnectionError( + f"expected ConsumeOk (0x{Opcode.CONSUME_OK:02x}) or " + f"Delivery (0x{Opcode.DELIVERY:02x}), " + f"got 0x{header.opcode:02x}" + ) + + def cancel_consume(self, consumer_id: str) -> None: + """Send a CancelConsume frame.""" + from fila.fibp.codec import encode_cancel_consume + + if not consumer_id: + return + req_id = self._next_request_id() + self.write_frame(Opcode.CANCEL_CONSUME, req_id, encode_cancel_consume(consumer_id)) + + def close(self) -> None: + """Send Disconnect and close the socket.""" + import contextlib + + with contextlib.suppress(OSError): + from fila.fibp.codec import encode_disconnect + req_id = self._next_request_id() + self.write_frame(Opcode.DISCONNECT, req_id, encode_disconnect()) + + with contextlib.suppress(OSError): + self._sock.close() + + +# --------------------------------------------------------------------------- +# Async connection +# --------------------------------------------------------------------------- + +class AsyncConnection: + """Asynchronous FIBP connection over asyncio streams.""" + + def __init__( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + max_frame_size: int = DEFAULT_MAX_FRAME_SIZE, + ) -> None: + self._reader = reader + self._writer = writer + self._max_frame_size = 
max_frame_size + self._req_counter = 0 + self._lock = asyncio.Lock() + self._pushback: tuple[FrameHeader, bytes] | None = None + + @classmethod + async def connect( + cls, + host: str, + port: int, + *, + ssl_context: ssl.SSLContext | None = None, + api_key: str | None = None, + timeout: float = 10.0, + ) -> AsyncConnection: + """Open a TCP connection and perform the FIBP handshake.""" + reader, writer = await asyncio.wait_for( + asyncio.open_connection(host, port, ssl=ssl_context), + timeout=timeout, + ) + + # Set TCP_NODELAY on the underlying socket. + sock = writer.get_extra_info("socket") + if sock is not None: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + conn = cls(reader, writer) + await conn._handshake(api_key) + return conn + + def _next_request_id(self) -> int: + self._req_counter += 1 + return self._req_counter & 0xFFFFFFFF + + async def _handshake(self, api_key: str | None) -> None: + """Perform the FIBP handshake.""" + body = encode_handshake(PROTOCOL_VERSION, api_key) + req_id = self._next_request_id() + await self.write_frame(Opcode.HANDSHAKE, req_id, body) + + header, resp_body = await self.read_frame() + if header.opcode == Opcode.ERROR: + err = decode_error(resp_body) + from fila.errors import _raise_from_error_frame + _raise_from_error_frame(err) + + if header.opcode != Opcode.HANDSHAKE_OK: + raise ConnectionError( + f"expected HandshakeOk (0x02), got 0x{header.opcode:02x}" + ) + + version, _node_id, max_frame_size = decode_handshake_ok(resp_body) + if max_frame_size > 0: + self._max_frame_size = max_frame_size + + async def write_frame( + self, opcode: int, request_id: int, body: bytes, flags: int = 0 + ) -> None: + """Write a single length-prefixed FIBP frame.""" + frame = _build_frame(opcode, request_id, body, flags) + self._writer.write(frame) + await self._writer.drain() + + async def read_frame(self) -> tuple[FrameHeader, bytes]: + """Read a single length-prefixed FIBP frame. 
+ + Handles Ping by responding with Pong automatically. + Handles continuation frames by concatenating bodies. + """ + if self._pushback is not None: + frame = self._pushback + self._pushback = None + return frame + + while True: + header, body = await self._read_single_frame() + + if header.opcode == Opcode.PING: + await self.write_frame(Opcode.PONG, header.request_id, encode_pong()) + continue + + if header.is_continuation: + parts = [body] + while True: + cont_header, cont_body = await self._read_single_frame() + if cont_header.opcode == Opcode.PING: + await self.write_frame( + Opcode.PONG, cont_header.request_id, encode_pong() + ) + continue + parts.append(cont_body) + if not cont_header.is_continuation: + break + return header, b"".join(parts) + + return header, body + + async def _read_single_frame(self) -> tuple[FrameHeader, bytes]: + """Read one frame from the wire (no continuation handling).""" + length_bytes = await self._reader.readexactly(4) + frame_len = struct.unpack("!I", length_bytes)[0] + if frame_len > self._max_frame_size: + raise ConnectionError( + f"frame size {frame_len} exceeds max {self._max_frame_size}" + ) + frame_data = await self._reader.readexactly(frame_len) + header = _parse_header(frame_data[:FRAME_HEADER_SIZE]) + body = frame_data[FRAME_HEADER_SIZE:] + return header, body + + async def request(self, opcode: int, body: bytes) -> tuple[FrameHeader, bytes]: + """Send a request frame and read the response.""" + req_id = self._next_request_id() + async with self._lock: + await self.write_frame(opcode, req_id, body) + return await self.read_frame() + + async def subscribe(self, queue: str) -> tuple[int, str]: + """Send a Consume request and wait for ConsumeOk. + + Returns (request_id, consumer_id). 
+ """ + from fila.fibp.codec import decode_consume_ok, encode_consume + + req_id = self._next_request_id() + await self.write_frame(Opcode.CONSUME, req_id, encode_consume(queue)) + + header, body = await self.read_frame() + if header.opcode == Opcode.ERROR: + err = decode_error(body) + from fila.errors import _raise_from_error_frame + _raise_from_error_frame(err) + + if header.opcode == Opcode.CONSUME_OK: + consumer_id = decode_consume_ok(body) + return req_id, consumer_id + + if header.opcode == Opcode.DELIVERY: + # Server may send Delivery directly (older binaries without ConsumeOk). + self._pushback = (header, body) + return req_id, "" + + raise ConnectionError( + f"expected ConsumeOk (0x{Opcode.CONSUME_OK:02x}) or " + f"Delivery (0x{Opcode.DELIVERY:02x}), " + f"got 0x{header.opcode:02x}" + ) + + async def cancel_consume(self, consumer_id: str) -> None: + """Send a CancelConsume frame.""" + from fila.fibp.codec import encode_cancel_consume + + if not consumer_id: + return + + req_id = self._next_request_id() + await self.write_frame( + Opcode.CANCEL_CONSUME, req_id, encode_cancel_consume(consumer_id) + ) + + async def close(self) -> None: + """Send Disconnect and close the stream.""" + try: + from fila.fibp.codec import encode_disconnect + req_id = self._next_request_id() + await self.write_frame(Opcode.DISCONNECT, req_id, encode_disconnect()) + except OSError: + pass + finally: + try: + self._writer.close() + await self._writer.wait_closed() + except OSError: + pass diff --git a/fila/errors.py b/fila/errors.py index 346c1c6..242968d 100644 --- a/fila/errors.py +++ b/fila/errors.py @@ -2,7 +2,10 @@ from __future__ import annotations -import grpc +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from fila.fibp.codec import ErrorFrame class FilaError(Exception): @@ -17,42 +20,159 @@ class MessageNotFoundError(FilaError): """Raised when the specified message does not exist.""" -class RPCError(FilaError): - """Raised for unexpected gRPC failures, preserving 
status code and message.""" +class QueueAlreadyExistsError(FilaError): + """Raised when creating a queue that already exists.""" + + +class InvalidArgumentError(FilaError): + """Raised when an argument is invalid.""" + + +class PermissionDeniedError(FilaError): + """Raised when permission is denied for the operation.""" + + +class UnauthorizedError(FilaError): + """Raised when the client is not authenticated.""" + + +class ForbiddenError(FilaError): + """Raised when the client lacks permission for the operation.""" - def __init__(self, code: grpc.StatusCode, message: str) -> None: - self.code = code - self.message = message - super().__init__(f"rpc error (code = {code.name}): {message}") +class NotLeaderError(FilaError): + """Raised when the request was sent to a non-leader node. -def _map_enqueue_error(err: grpc.RpcError) -> FilaError: - """Map a gRPC error from an enqueue call to a Fila exception.""" - code = err.code() - if code == grpc.StatusCode.NOT_FOUND: - return QueueNotFoundError(f"enqueue: {err.details()}") - return RPCError(code, err.details() or "") + The ``leader_addr`` attribute contains the address of the current leader, + if available. 
+ """ + def __init__(self, message: str, leader_addr: str | None = None) -> None: + self.leader_addr = leader_addr + super().__init__(message) -def _map_consume_error(err: grpc.RpcError) -> FilaError: - """Map a gRPC error from a consume call to a Fila exception.""" - code = err.code() - if code == grpc.StatusCode.NOT_FOUND: - return QueueNotFoundError(f"consume: {err.details()}") - return RPCError(code, err.details() or "") +class ChannelFullError(FilaError): + """Raised when a channel or buffer is full (backpressure).""" -def _map_ack_error(err: grpc.RpcError) -> FilaError: - """Map a gRPC error from an ack call to a Fila exception.""" - code = err.code() - if code == grpc.StatusCode.NOT_FOUND: - return MessageNotFoundError(f"ack: {err.details()}") - return RPCError(code, err.details() or "") +class ResourceExhaustedError(FilaError): + """Raised when a resource limit has been reached.""" -def _map_nack_error(err: grpc.RpcError) -> FilaError: - """Map a gRPC error from a nack call to a Fila exception.""" - code = err.code() - if code == grpc.StatusCode.NOT_FOUND: - return MessageNotFoundError(f"nack: {err.details()}") - return RPCError(code, err.details() or "") + +class UnavailableError(FilaError): + """Raised when the server is unavailable.""" + + +class LuaError(FilaError): + """Raised when a Lua script error occurs.""" + + +class EnqueueError(FilaError): + """Raised when an enqueue operation fails. + + In ``enqueue_many()``, individual per-message failures are reported via + ``EnqueueResult.error`` and do not raise this exception. It is also used + as a fallback for per-message enqueue failures that do not map to a more + specific type (e.g., storage or Lua errors). 
+ """ + + +class ProtocolError(FilaError): + """Raised for unexpected protocol-level failures.""" + + def __init__(self, code: int, message: str) -> None: + self.code = code + self.message = message + super().__init__(f"protocol error (code=0x{code:02x}): {message}") + + +class ApiKeyNotFoundError(FilaError): + """Raised when the specified API key does not exist.""" + + +class AclNotFoundError(FilaError): + """Raised when the specified ACL entry does not exist.""" + + +# Keep RPCError as a thin alias for backwards compatibility with existing +# callers that catch fila.RPCError. +RPCError = ProtocolError + + +# --------------------------------------------------------------------------- +# Error-code mapping +# --------------------------------------------------------------------------- + +def _map_error_code(code: int, message: str) -> FilaError: + """Map a FIBP error code to the appropriate exception.""" + from fila.fibp.opcodes import ErrorCode + + match code: + case ErrorCode.QUEUE_NOT_FOUND: + return QueueNotFoundError(message) + case ErrorCode.MESSAGE_NOT_FOUND: + return MessageNotFoundError(message) + case ErrorCode.QUEUE_ALREADY_EXISTS: + return QueueAlreadyExistsError(message) + case ErrorCode.INVALID_ARGUMENT: + return InvalidArgumentError(message) + case ErrorCode.PERMISSION_DENIED: + return PermissionDeniedError(message) + case ErrorCode.UNAUTHENTICATED: + return UnauthorizedError(message) + case ErrorCode.FORBIDDEN: + return ForbiddenError(message) + case ErrorCode.NOT_LEADER: + return NotLeaderError(message) + case ErrorCode.CHANNEL_FULL: + return ChannelFullError(message) + case ErrorCode.RESOURCE_EXHAUSTED: + return ResourceExhaustedError(message) + case ErrorCode.UNAVAILABLE: + return UnavailableError(message) + case ErrorCode.LUA_ERROR: + return LuaError(message) + case ErrorCode.API_KEY_NOT_FOUND: + return ApiKeyNotFoundError(message) + case ErrorCode.ACL_NOT_FOUND: + return AclNotFoundError(message) + case ErrorCode.DEADLINE_EXCEEDED: + return 
ProtocolError(code, message) + case ErrorCode.PRECONDITION_FAILED: + return ProtocolError(code, message) + case ErrorCode.ABORTED: + return ProtocolError(code, message) + case ErrorCode.INTERNAL_ERROR: + return ProtocolError(code, message) + case _: + return ProtocolError(code, message) + + +def _raise_from_error_frame(err: ErrorFrame) -> None: + """Raise the appropriate exception from a decoded Error frame. + + For NotLeader errors, extracts the leader_addr from metadata. + """ + from fila.fibp.opcodes import ErrorCode + + if err.code == ErrorCode.NOT_LEADER: + leader_addr = err.metadata.get("leader_addr") + raise NotLeaderError(err.message, leader_addr=leader_addr) + + raise _map_error_code(err.code, err.message) + + +def _map_per_item_error(code: int, context: str) -> FilaError: + """Map a per-item error code (from EnqueueResult, AckResult, etc.).""" + from fila.fibp.opcodes import ErrorCode + + match code: + case ErrorCode.QUEUE_NOT_FOUND: + return QueueNotFoundError(f"{context}: queue not found") + case ErrorCode.MESSAGE_NOT_FOUND: + return MessageNotFoundError(f"{context}: message not found") + case ErrorCode.PERMISSION_DENIED: + return PermissionDeniedError(f"{context}: permission denied") + case _: + return EnqueueError(f"{context}: error code 0x{code:02x}") diff --git a/fila/fibp/__init__.py b/fila/fibp/__init__.py new file mode 100644 index 0000000..134f1b9 --- /dev/null +++ b/fila/fibp/__init__.py @@ -0,0 +1,96 @@ +"""FIBP (Fila Binary Protocol) codec and primitives.""" + +from fila.fibp.codec import ( + decode_ack_result, + decode_consume_ok, + decode_create_api_key_result, + decode_create_queue_result, + decode_delete_queue_result, + decode_delivery, + decode_enqueue_result, + decode_error, + decode_get_acl_result, + decode_get_config_result, + decode_get_stats_result, + decode_handshake_ok, + decode_list_api_keys_result, + decode_list_config_result, + decode_list_queues_result, + decode_nack_result, + decode_redrive_result, + 
decode_revoke_api_key_result, + decode_set_acl_result, + decode_set_config_result, + encode_ack, + encode_cancel_consume, + encode_consume, + encode_create_api_key, + encode_create_queue, + encode_delete_queue, + encode_disconnect, + encode_enqueue, + encode_get_acl, + encode_get_config, + encode_get_stats, + encode_handshake, + encode_list_api_keys, + encode_list_config, + encode_list_queues, + encode_nack, + encode_pong, + encode_redrive, + encode_revoke_api_key, + encode_set_acl, + encode_set_config, +) +from fila.fibp.opcodes import ErrorCode, FrameHeader, Opcode +from fila.fibp.primitives import Reader, Writer + +__all__ = [ + "ErrorCode", + "FrameHeader", + "Opcode", + "Reader", + "Writer", + "decode_ack_result", + "decode_consume_ok", + "decode_create_api_key_result", + "decode_create_queue_result", + "decode_delete_queue_result", + "decode_delivery", + "decode_enqueue_result", + "decode_error", + "decode_get_acl_result", + "decode_get_config_result", + "decode_get_stats_result", + "decode_handshake_ok", + "decode_list_api_keys_result", + "decode_list_config_result", + "decode_list_queues_result", + "decode_nack_result", + "decode_redrive_result", + "decode_revoke_api_key_result", + "decode_set_acl_result", + "decode_set_config_result", + "encode_ack", + "encode_cancel_consume", + "encode_consume", + "encode_create_api_key", + "encode_create_queue", + "encode_delete_queue", + "encode_disconnect", + "encode_enqueue", + "encode_get_acl", + "encode_get_config", + "encode_get_stats", + "encode_handshake", + "encode_list_api_keys", + "encode_list_config", + "encode_list_queues", + "encode_nack", + "encode_pong", + "encode_redrive", + "encode_revoke_api_key", + "encode_set_acl", + "encode_set_config", +] diff --git a/fila/fibp/codec.py b/fila/fibp/codec.py new file mode 100644 index 0000000..407d578 --- /dev/null +++ b/fila/fibp/codec.py @@ -0,0 +1,634 @@ +"""Encode/decode functions for every FIBP opcode.""" + +from __future__ import annotations + +from 
dataclasses import dataclass, field + +from fila.fibp.primitives import Reader, Writer + +# --------------------------------------------------------------------------- +# Data types used by decode functions +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True, slots=True) +class DeliveryMessage: + """A single message within a Delivery frame.""" + + message_id: str + queue: str + headers: dict[str, str] + payload: bytes + fairness_key: str + weight: int + throttle_keys: list[str] + attempt_count: int + enqueued_at: int + leased_at: int + + +@dataclass(frozen=True, slots=True) +class EnqueueResultItem: + """Per-message result within an EnqueueResult frame.""" + + error_code: int + message_id: str + + +@dataclass(frozen=True, slots=True) +class ErrorFrame: + """Decoded Error frame.""" + + code: int + message: str + metadata: dict[str, str] + + +@dataclass(frozen=True, slots=True) +class CreateQueueResultFrame: + """Decoded CreateQueueResult.""" + + error_code: int + queue_id: str + + +@dataclass(frozen=True, slots=True) +class FairnessKeyStat: + """Per-fairness-key stats in GetStatsResult.""" + + key: str + pending_count: int + current_deficit: int + weight: int + + +@dataclass(frozen=True, slots=True) +class ThrottleKeyStat: + """Per-throttle-key stats in GetStatsResult.""" + + key: str + tokens: float + rate_per_second: float + burst: float + + +@dataclass(frozen=True, slots=True) +class StatsResultFrame: + """Decoded GetStatsResult frame.""" + + error_code: int + depth: int + in_flight: int + active_fairness_keys: int + active_consumers: int + quantum: int + leader_node_id: int + replication_count: int + per_key_stats: list[FairnessKeyStat] = field(default_factory=list) + per_throttle_stats: list[ThrottleKeyStat] = field(default_factory=list) + + +@dataclass(frozen=True, slots=True) +class ListQueuesQueueInfo: + """A single queue in ListQueuesResult.""" + + name: str + depth: int + in_flight: int + 
active_consumers: int + leader_node_id: int + + +@dataclass(frozen=True, slots=True) +class ListQueuesResultFrame: + """Decoded ListQueuesResult.""" + + error_code: int + cluster_node_count: int + queues: list[ListQueuesQueueInfo] + + +@dataclass(frozen=True, slots=True) +class ApiKeyInfoFrame: + """A single API key in ListApiKeysResult.""" + + key_id: str + name: str + created_at: int + expires_at: int + is_superadmin: bool + + +@dataclass(frozen=True, slots=True) +class AclPermission: + """A single ACL permission.""" + + kind: str + pattern: str + + +@dataclass(frozen=True, slots=True) +class GetAclResultFrame: + """Decoded GetAclResult.""" + + error_code: int + key_id: str + is_superadmin: bool + permissions: list[AclPermission] + + +# --------------------------------------------------------------------------- +# Encode: Control +# --------------------------------------------------------------------------- + + +def encode_handshake(version: int, api_key: str | None = None) -> bytes: + """Encode a Handshake frame body.""" + w = Writer() + w.write_u16(version) + w.write_optional_string(api_key) + return w.finish() + + +def encode_pong() -> bytes: + """Encode a Pong frame body (empty).""" + return b"" + + +def encode_disconnect() -> bytes: + """Encode a Disconnect frame body (empty).""" + return b"" + + +# --------------------------------------------------------------------------- +# Decode: Control +# --------------------------------------------------------------------------- + + +def decode_handshake_ok(data: bytes) -> tuple[int, int, int]: + """Decode a HandshakeOk frame body -> (version, node_id, max_frame_size).""" + r = Reader(data) + version = r.read_u16() + node_id = r.read_u64() + max_frame_size = r.read_u32() + return version, node_id, max_frame_size + + +# --------------------------------------------------------------------------- +# Encode: Hot-path +# --------------------------------------------------------------------------- + + +def 
encode_enqueue(messages: list[dict[str, object]]) -> bytes: + """Encode an Enqueue frame body. + + Each message dict has keys: queue (str), headers (dict[str,str]), payload (bytes). + """ + w = Writer() + w.write_u32(len(messages)) + for msg in messages: + w.write_string(str(msg["queue"])) + w.write_string_map(msg.get("headers") or {}) # type: ignore[arg-type] + w.write_bytes(msg.get("payload", b"") or b"") # type: ignore[arg-type] + return w.finish() + + +def encode_consume(queue: str) -> bytes: + """Encode a Consume frame body.""" + w = Writer() + w.write_string(queue) + return w.finish() + + +def encode_cancel_consume(consumer_id: str) -> bytes: + """Encode a CancelConsume frame body.""" + w = Writer() + w.write_string(consumer_id) + return w.finish() + + +def encode_ack(items: list[dict[str, str]]) -> bytes: + """Encode an Ack frame body. Each item: {queue, message_id}.""" + w = Writer() + w.write_u32(len(items)) + for item in items: + w.write_string(item["queue"]) + w.write_string(item["message_id"]) + return w.finish() + + +def encode_nack(items: list[dict[str, str]]) -> bytes: + """Encode a Nack frame body. 
Each item: {queue, message_id, error}.""" + w = Writer() + w.write_u32(len(items)) + for item in items: + w.write_string(item["queue"]) + w.write_string(item["message_id"]) + w.write_string(item.get("error", "")) + return w.finish() + + +# --------------------------------------------------------------------------- +# Decode: Hot-path +# --------------------------------------------------------------------------- + + +def decode_enqueue_result(data: bytes) -> list[EnqueueResultItem]: + """Decode an EnqueueResult frame body.""" + r = Reader(data) + count = r.read_u32() + results: list[EnqueueResultItem] = [] + for _ in range(count): + error_code = r.read_u8() + message_id = r.read_string() + results.append(EnqueueResultItem(error_code=error_code, message_id=message_id)) + return results + + +def decode_consume_ok(data: bytes) -> str: + """Decode a ConsumeOk frame body -> consumer_id.""" + r = Reader(data) + return r.read_string() + + +def decode_delivery(data: bytes) -> list[DeliveryMessage]: + """Decode a Delivery frame body.""" + r = Reader(data) + count = r.read_u32() + messages: list[DeliveryMessage] = [] + for _ in range(count): + msg_id = r.read_string() + queue = r.read_string() + headers = r.read_string_map() + payload = r.read_bytes() + fairness_key = r.read_string() + weight = r.read_u32() + throttle_keys = r.read_string_list() + attempt_count = r.read_u32() + enqueued_at = r.read_u64() + leased_at = r.read_u64() + messages.append(DeliveryMessage( + message_id=msg_id, + queue=queue, + headers=headers, + payload=payload, + fairness_key=fairness_key, + weight=weight, + throttle_keys=throttle_keys, + attempt_count=attempt_count, + enqueued_at=enqueued_at, + leased_at=leased_at, + )) + return messages + + +def decode_ack_result(data: bytes) -> list[int]: + """Decode an AckResult frame body -> list of error codes.""" + r = Reader(data) + count = r.read_u32() + return [r.read_u8() for _ in range(count)] + + +def decode_nack_result(data: bytes) -> list[int]: + 
"""Decode a NackResult frame body -> list of error codes.""" + r = Reader(data) + count = r.read_u32() + return [r.read_u8() for _ in range(count)] + + +# --------------------------------------------------------------------------- +# Decode: Error +# --------------------------------------------------------------------------- + + +def decode_error(data: bytes) -> ErrorFrame: + """Decode an Error frame body.""" + r = Reader(data) + code = r.read_u8() + message = r.read_string() + metadata = r.read_string_map() + return ErrorFrame(code=code, message=message, metadata=metadata) + + +# --------------------------------------------------------------------------- +# Encode: Admin +# --------------------------------------------------------------------------- + + +def encode_create_queue( + name: str, + *, + on_enqueue_script: str | None = None, + on_failure_script: str | None = None, + visibility_timeout_ms: int = 0, +) -> bytes: + """Encode a CreateQueue frame body. + + Wire format: [string name][optional on_enqueue_script] + [optional on_failure_script][u64 visibility_timeout_ms] + """ + w = Writer() + w.write_string(name) + w.write_optional_string(on_enqueue_script) + w.write_optional_string(on_failure_script) + w.write_u64(visibility_timeout_ms) + return w.finish() + + +def encode_delete_queue(name: str) -> bytes: + """Encode a DeleteQueue frame body.""" + w = Writer() + w.write_string(name) + return w.finish() + + +def encode_get_stats(queue: str) -> bytes: + """Encode a GetStats frame body.""" + w = Writer() + w.write_string(queue) + return w.finish() + + +def encode_list_queues() -> bytes: + """Encode a ListQueues frame body (empty).""" + return b"" + + +def encode_set_config(key: str, value: str) -> bytes: + """Encode a SetConfig frame body. Wire: [string key][string value].""" + w = Writer() + w.write_string(key) + w.write_string(value) + return w.finish() + + +def encode_get_config(key: str) -> bytes: + """Encode a GetConfig frame body. 
Wire: [string key].""" + w = Writer() + w.write_string(key) + return w.finish() + + +def encode_list_config(prefix: str) -> bytes: + """Encode a ListConfig frame body. Wire: [string prefix].""" + w = Writer() + w.write_string(prefix) + return w.finish() + + +def encode_redrive(dlq_queue: str, count: int) -> bytes: + """Encode a Redrive frame body. Wire: [string dlq_queue][u64 count].""" + w = Writer() + w.write_string(dlq_queue) + w.write_u64(count) + return w.finish() + + +# --------------------------------------------------------------------------- +# Decode: Admin results +# --------------------------------------------------------------------------- + + +def decode_create_queue_result(data: bytes) -> CreateQueueResultFrame: + """Decode a CreateQueueResult. Wire: [u8 error_code][string queue_id].""" + r = Reader(data) + error_code = r.read_u8() + queue_id = r.read_string() + return CreateQueueResultFrame(error_code=error_code, queue_id=queue_id) + + +def decode_delete_queue_result(data: bytes) -> int: + """Decode a DeleteQueueResult -> error_code.""" + r = Reader(data) + return r.read_u8() + + +def decode_get_stats_result(data: bytes) -> StatsResultFrame: + """Decode a GetStatsResult frame body.""" + r = Reader(data) + error_code = r.read_u8() + depth = r.read_u64() + in_flight = r.read_u64() + active_fairness_keys = r.read_u64() + active_consumers = r.read_u32() + quantum = r.read_u32() + leader_node_id = r.read_u64() + replication_count = r.read_u32() + + per_key_count = r.read_u16() + per_key_stats: list[FairnessKeyStat] = [] + for _ in range(per_key_count): + key = r.read_string() + pending = r.read_u64() + deficit = r.read_i64() + weight = r.read_u32() + per_key_stats.append(FairnessKeyStat( + key=key, pending_count=pending, current_deficit=deficit, weight=weight + )) + + per_throttle_count = r.read_u16() + per_throttle_stats: list[ThrottleKeyStat] = [] + for _ in range(per_throttle_count): + key = r.read_string() + tokens = r.read_f64() + rate = 
r.read_f64() + burst = r.read_f64() + per_throttle_stats.append(ThrottleKeyStat( + key=key, tokens=tokens, rate_per_second=rate, burst=burst + )) + + return StatsResultFrame( + error_code=error_code, + depth=depth, + in_flight=in_flight, + active_fairness_keys=active_fairness_keys, + active_consumers=active_consumers, + quantum=quantum, + leader_node_id=leader_node_id, + replication_count=replication_count, + per_key_stats=per_key_stats, + per_throttle_stats=per_throttle_stats, + ) + + +def decode_list_queues_result(data: bytes) -> ListQueuesResultFrame: + """Decode a ListQueuesResult frame body.""" + r = Reader(data) + error_code = r.read_u8() + cluster_node_count = r.read_u32() + queue_count = r.read_u16() + queues: list[ListQueuesQueueInfo] = [] + for _ in range(queue_count): + name = r.read_string() + depth = r.read_u64() + in_flight = r.read_u64() + active_consumers = r.read_u32() + leader_node_id = r.read_u64() + queues.append(ListQueuesQueueInfo( + name=name, + depth=depth, + in_flight=in_flight, + active_consumers=active_consumers, + leader_node_id=leader_node_id, + )) + return ListQueuesResultFrame( + error_code=error_code, + cluster_node_count=cluster_node_count, + queues=queues, + ) + + +def decode_set_config_result(data: bytes) -> int: + """Decode a SetConfigResult -> error_code.""" + r = Reader(data) + return r.read_u8() + + +def decode_get_config_result(data: bytes) -> tuple[int, str]: + """Decode a GetConfigResult -> (error_code, value).""" + r = Reader(data) + error_code = r.read_u8() + value = r.read_string() + return error_code, value + + +def decode_list_config_result(data: bytes) -> tuple[int, dict[str, str]]: + """Decode a ListConfigResult -> (error_code, entries).""" + r = Reader(data) + error_code = r.read_u8() + count = r.read_u16() + entries: dict[str, str] = {} + for _ in range(count): + key = r.read_string() + value = r.read_string() + entries[key] = value + return error_code, entries + + +def decode_redrive_result(data: bytes) -> 
tuple[int, int]: + """Decode a RedriveResult -> (error_code, redriven_count).""" + r = Reader(data) + error_code = r.read_u8() + redriven = r.read_u64() + return error_code, redriven + + +# --------------------------------------------------------------------------- +# Encode: Auth +# --------------------------------------------------------------------------- + + +def encode_create_api_key( + name: str, *, expires_at_ms: int = 0, is_superadmin: bool = False +) -> bytes: + """Encode a CreateApiKey frame body. + + Wire: [string name][u64 expires_at_ms][bool is_superadmin] + """ + w = Writer() + w.write_string(name) + w.write_u64(expires_at_ms) + w.write_bool(is_superadmin) + return w.finish() + + +def encode_revoke_api_key(key_id: str) -> bytes: + """Encode a RevokeApiKey frame body.""" + w = Writer() + w.write_string(key_id) + return w.finish() + + +def encode_list_api_keys() -> bytes: + """Encode a ListApiKeys frame body (empty).""" + return b"" + + +def encode_set_acl( + key_id: str, permissions: list[tuple[str, str]] +) -> bytes: + """Encode a SetAcl frame body. + + Wire: [string key_id][u16 count][per: string kind, string pattern] + permissions is a list of (kind, pattern) tuples. 
+ """ + w = Writer() + w.write_string(key_id) + w.write_u16(len(permissions)) + for kind, pattern in permissions: + w.write_string(kind) + w.write_string(pattern) + return w.finish() + + +def encode_get_acl(key_id: str) -> bytes: + """Encode a GetAcl frame body.""" + w = Writer() + w.write_string(key_id) + return w.finish() + + +# --------------------------------------------------------------------------- +# Decode: Auth results +# --------------------------------------------------------------------------- + + +def decode_create_api_key_result(data: bytes) -> tuple[int, str, str, bool]: + """Decode a CreateApiKeyResult -> (error_code, key_id, raw_key, is_superadmin).""" + r = Reader(data) + error_code = r.read_u8() + key_id = r.read_string() + raw_key = r.read_string() + is_superadmin = r.read_bool() + return error_code, key_id, raw_key, is_superadmin + + +def decode_revoke_api_key_result(data: bytes) -> int: + """Decode a RevokeApiKeyResult -> error_code.""" + r = Reader(data) + return r.read_u8() + + +def decode_list_api_keys_result(data: bytes) -> tuple[int, list[ApiKeyInfoFrame]]: + """Decode a ListApiKeysResult -> (error_code, list of ApiKeyInfoFrame).""" + r = Reader(data) + error_code = r.read_u8() + count = r.read_u16() + keys: list[ApiKeyInfoFrame] = [] + for _ in range(count): + key_id = r.read_string() + name = r.read_string() + created_at = r.read_u64() + expires_at = r.read_u64() + is_superadmin = r.read_bool() + keys.append(ApiKeyInfoFrame( + key_id=key_id, name=name, created_at=created_at, + expires_at=expires_at, is_superadmin=is_superadmin, + )) + return error_code, keys + + +def decode_set_acl_result(data: bytes) -> int: + """Decode a SetAclResult -> error_code.""" + r = Reader(data) + return r.read_u8() + + +def decode_get_acl_result(data: bytes) -> GetAclResultFrame: + """Decode a GetAclResult.""" + r = Reader(data) + error_code = r.read_u8() + key_id = r.read_string() + is_superadmin = r.read_bool() + perm_count = r.read_u16() + permissions: 
"""FIBP opcode constants, error codes, and frame header definition."""

from __future__ import annotations

from dataclasses import dataclass
from enum import IntEnum

# ---------------------------------------------------------------------------
# Protocol constants
# ---------------------------------------------------------------------------

# Current FIBP wire-protocol revision.
PROTOCOL_VERSION: int = 1
# Default cap on a single frame's size.
DEFAULT_MAX_FRAME_SIZE: int = 16 * 1024 * 1024  # 16 MiB
FRAME_HEADER_SIZE: int = 6  # opcode(1) + flags(1) + request_id(4)
# Flag bit in the header's flags byte marking a continuation frame
# (see FrameHeader.is_continuation).
CONTINUATION_FLAG: int = 0x01


# ---------------------------------------------------------------------------
# Opcodes
# ---------------------------------------------------------------------------

class Opcode(IntEnum):
    """FIBP opcodes.

    Layout convention visible in the values: control opcodes start at 0x01,
    hot-path opcodes occupy the 0x10 range, while admin opcodes count down
    from 0xFD and auth opcodes from 0xED. Request/result pairs are adjacent
    values.
    """

    # Control
    HANDSHAKE = 0x01
    HANDSHAKE_OK = 0x02
    PING = 0x03
    PONG = 0x04
    DISCONNECT = 0x05

    # Hot-path
    ENQUEUE = 0x10
    ENQUEUE_RESULT = 0x11
    CONSUME = 0x12
    CONSUME_OK = 0x13
    DELIVERY = 0x14
    CANCEL_CONSUME = 0x15
    ACK = 0x16
    ACK_RESULT = 0x17
    NACK = 0x18
    NACK_RESULT = 0x19

    # Error
    ERROR = 0xFE

    # Admin (descending from 0xFD)
    CREATE_QUEUE = 0xFD
    CREATE_QUEUE_RESULT = 0xFC
    DELETE_QUEUE = 0xFB
    DELETE_QUEUE_RESULT = 0xFA
    GET_STATS = 0xF9
    GET_STATS_RESULT = 0xF8
    LIST_QUEUES = 0xF7
    LIST_QUEUES_RESULT = 0xF6
    SET_CONFIG = 0xF5
    SET_CONFIG_RESULT = 0xF4
    GET_CONFIG = 0xF3
    GET_CONFIG_RESULT = 0xF2
    LIST_CONFIG = 0xF1
    LIST_CONFIG_RESULT = 0xF0
    REDRIVE = 0xEF
    REDRIVE_RESULT = 0xEE

    # Auth (descending from 0xED)
    CREATE_API_KEY = 0xED
    CREATE_API_KEY_RESULT = 0xEC
    REVOKE_API_KEY = 0xEB
    REVOKE_API_KEY_RESULT = 0xEA
    LIST_API_KEYS = 0xE9
    LIST_API_KEYS_RESULT = 0xE8
    SET_ACL = 0xE7
    SET_ACL_RESULT = 0xE6
    GET_ACL = 0xE5
    GET_ACL_RESULT = 0xE4


# ---------------------------------------------------------------------------
# Error codes
# ---------------------------------------------------------------------------

class ErrorCode(IntEnum):
    """FIBP error codes returned in Error frames and per-item results.

    NOTE(review): these presumably map onto the per-operation exception
    classes exported from the package (QueueNotFoundError, UnauthorizedError,
    etc.) — confirm against fila/errors.
    """

    OK = 0x00
    QUEUE_NOT_FOUND = 0x01
    MESSAGE_NOT_FOUND = 0x02
    QUEUE_ALREADY_EXISTS = 0x03
    INVALID_ARGUMENT = 0x04
    DEADLINE_EXCEEDED = 0x05
    PERMISSION_DENIED = 0x06
    RESOURCE_EXHAUSTED = 0x07
    PRECONDITION_FAILED = 0x08
    ABORTED = 0x09
    UNAVAILABLE = 0x0A
    UNAUTHENTICATED = 0x0B
    NOT_LEADER = 0x0C
    LUA_ERROR = 0x0D
    CHANNEL_FULL = 0x0E
    FORBIDDEN = 0x0F
    API_KEY_NOT_FOUND = 0x10
    ACL_NOT_FOUND = 0x11
    INTERNAL_ERROR = 0xFF


# ---------------------------------------------------------------------------
# Frame header
# ---------------------------------------------------------------------------

@dataclass(frozen=True, slots=True)
class FrameHeader:
    """Parsed 6-byte FIBP frame header: opcode(1) + flags(1) + request_id(4).

    Immutable and slotted — headers are created once per frame on the hot
    path.
    """

    opcode: int
    flags: int
    request_id: int

    @property
    def is_continuation(self) -> bool:
        # True when the CONTINUATION_FLAG bit is set in the flags byte.
        return bool(self.flags & CONTINUATION_FLAG)
+""" + +from __future__ import annotations + +import struct + +# --------------------------------------------------------------------------- +# Writer +# --------------------------------------------------------------------------- + +class Writer: + """Accumulates bytes for a FIBP frame body.""" + + __slots__ = ("_buf",) + + def __init__(self) -> None: + self._buf = bytearray() + + # -- scalars ------------------------------------------------------------- + + def write_u8(self, v: int) -> None: + self._buf.append(v & 0xFF) + + def write_u16(self, v: int) -> None: + self._buf.extend(struct.pack("!H", v)) + + def write_u32(self, v: int) -> None: + self._buf.extend(struct.pack("!I", v)) + + def write_u64(self, v: int) -> None: + self._buf.extend(struct.pack("!Q", v)) + + def write_i64(self, v: int) -> None: + self._buf.extend(struct.pack("!q", v)) + + def write_f64(self, v: float) -> None: + self._buf.extend(struct.pack("!d", v)) + + def write_bool(self, v: bool) -> None: + self._buf.append(1 if v else 0) + + # -- composites ---------------------------------------------------------- + + def write_string(self, s: str) -> None: + encoded = s.encode("utf-8") + self.write_u16(len(encoded)) + self._buf.extend(encoded) + + def write_bytes(self, b: bytes) -> None: + self.write_u32(len(b)) + self._buf.extend(b) + + def write_string_map(self, m: dict[str, str]) -> None: + self.write_u16(len(m)) + for k, v in m.items(): + self.write_string(k) + self.write_string(v) + + def write_string_list(self, items: list[str]) -> None: + self.write_u16(len(items)) + for s in items: + self.write_string(s) + + def write_optional_string(self, s: str | None) -> None: + if s is None: + self.write_u8(0) + else: + self.write_u8(1) + self.write_string(s) + + # -- access -------------------------------------------------------------- + + def finish(self) -> bytes: + return bytes(self._buf) + + +# --------------------------------------------------------------------------- +# Reader +# 
# ---------------------------------------------------------------------------
# Reader
# ---------------------------------------------------------------------------

class Reader:
    """Sequential decoder for a FIBP frame body with position tracking."""

    __slots__ = ("_raw", "_cursor")

    def __init__(self, data: bytes | bytearray | memoryview) -> None:
        # Normalize to immutable bytes; callers may hand in a bytearray or
        # memoryview backed by a reusable receive buffer.
        self._raw = bytes(data)
        self._cursor = 0

    @property
    def remaining(self) -> int:
        """Number of unread bytes left in the buffer."""
        return len(self._raw) - self._cursor

    # -- internal -------------------------------------------------------------

    def _unpack(self, fmt: str, size: int):
        # unpack_from reads in place (no slice copy); advance past `size` bytes.
        value = struct.unpack_from(fmt, self._raw, self._cursor)[0]
        self._cursor += size
        return value

    # -- scalars --------------------------------------------------------------

    def read_u8(self) -> int:
        byte = self._raw[self._cursor]
        self._cursor += 1
        return byte

    def read_u16(self) -> int:
        return self._unpack("!H", 2)

    def read_u32(self) -> int:
        return self._unpack("!I", 4)

    def read_u64(self) -> int:
        return self._unpack("!Q", 8)

    def read_i64(self) -> int:
        return self._unpack("!q", 8)

    def read_f64(self) -> float:
        return self._unpack("!d", 8)

    def read_bool(self) -> bool:
        # Any non-zero byte counts as True.
        return self.read_u8() != 0

    # -- composites -----------------------------------------------------------

    def read_string(self) -> str:
        """Read a u16-length-prefixed UTF-8 string.

        Raises:
            ValueError: if the declared length overruns the buffer.
        """
        length = self.read_u16()
        stop = self._cursor + length
        if stop > len(self._raw):
            raise ValueError(
                f"string length {length} exceeds remaining buffer "
                f"({len(self._raw) - self._cursor} bytes at offset {self._cursor})"
            )
        text = self._raw[self._cursor:stop].decode("utf-8")
        self._cursor = stop
        return text

    def read_bytes(self) -> bytes:
        """Read a u32-length-prefixed byte slice.

        Raises:
            ValueError: if the declared length overruns the buffer.
        """
        length = self.read_u32()
        stop = self._cursor + length
        if stop > len(self._raw):
            raise ValueError(
                f"bytes length {length} exceeds remaining buffer "
                f"({len(self._raw) - self._cursor} bytes at offset {self._cursor})"
            )
        blob = self._raw[self._cursor:stop]
        self._cursor = stop
        return blob

    def read_string_map(self) -> dict[str, str]:
        """Read a u16-count-prefixed sequence of key/value string pairs."""
        entries: dict[str, str] = {}
        for _ in range(self.read_u16()):
            key = self.read_string()
            entries[key] = self.read_string()
        return entries

    def read_string_list(self) -> list[str]:
        """Read a u16-count-prefixed list of strings."""
        return [self.read_string() for _ in range(self.read_u16())]

    def read_optional_string(self) -> str | None:
        """Read a presence byte, then the string when present; else None."""
        return self.read_string() if self.read_u8() else None
from dataclasses import dataclass, field
from enum import Enum, auto


@dataclass(frozen=True)
class EnqueueResult:
    """Per-message outcome of an enqueue operation.

    Exactly one of ``message_id`` or ``error`` is set: the broker-assigned
    ID on success, or a failure description otherwise.
    """

    message_id: str | None
    error: str | None

    @property
    def is_success(self) -> bool:
        """True when the message was enqueued (a message ID was assigned)."""
        return self.message_id is not None


class AccumulatorMode(Enum):
    """Routing policy for ``enqueue()`` calls.

    ``AUTO`` (the default) accumulates opportunistically via a background
    thread: messages go out individually at low load and cluster into
    batches under high load. ``DISABLED`` skips accumulation entirely, so
    every ``enqueue()`` is a direct RPC.
    """

    AUTO = auto()
    DISABLED = auto()


@dataclass(frozen=True)
class Linger:
    """Timer-based forced accumulation.

    Holds each message until either ``linger_ms`` milliseconds elapse or
    ``max_messages`` messages have accumulated, whichever happens first.
    """

    linger_ms: float      # maximum hold time before a flush, in milliseconds
    max_messages: int     # maximum number of messages per flush


@dataclass(frozen=True)
class CreateApiKeyResult:
    """Outcome of creating an API key.

    ``raw_key`` is the secret credential; presumably it is only returned
    once, at creation time — verify against server behavior.
    """

    key_id: str
    raw_key: str
    is_superadmin: bool = False


@dataclass(frozen=True)
class ApiKeyInfo:
    """Summary information about an API key."""

    key_id: str
    name: str
    # Timestamps: the legacy gRPC schema used epoch millis (created_at_ms /
    # expires_at_ms) — TODO confirm units for FIBP.
    created_at: int
    expires_at: int = 0
    is_superadmin: bool = False


@dataclass(frozen=True)
class AclPermission:
    """A single ACL permission (a kind plus a matching pattern)."""

    kind: str
    pattern: str


@dataclass(frozen=True)
class AclEntry:
    """The ACL attached to an API key."""

    key_id: str
    is_superadmin: bool
    permissions: list[AclPermission] = field(default_factory=list)


@dataclass(frozen=True)
class FairnessKeyStat:
    """Statistics for one fairness key within a queue."""

    key: str
    pending_count: int
    current_deficit: int
    weight: int


@dataclass(frozen=True)
class ThrottleKeyStat:
    """Statistics for one throttle key (token-bucket state)."""

    key: str
    tokens: float
    rate_per_second: float
    burst: float


@dataclass(frozen=True)
class StatsResult:
    """Statistics for a single queue."""

    depth: int
    in_flight: int
    active_fairness_keys: int
    active_consumers: int
    quantum: int
    leader_node_id: int = 0
    replication_count: int = 0
    per_key_stats: list[FairnessKeyStat] = field(default_factory=list)
    per_throttle_stats: list[ThrottleKeyStat] = field(default_factory=list)


@dataclass(frozen=True)
class QueueInfo:
    """Summary information about one queue, as returned by queue listings."""

    name: str
    depth: int
    in_flight: int
    active_consumers: int
    leader_node_id: int = 0
-# NO CHECKED-IN PROTOBUF GENCODE -# source: fila/v1/admin.proto -# Protobuf Python Version: 6.31.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 6, - 31, - 1, - '', - 'fila/v1/admin.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13\x66ila/v1/admin.proto\x12\x07\x66ila.v1\"H\n\x12\x43reateQueueRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12$\n\x06\x63onfig\x18\x02 \x01(\x0b\x32\x14.fila.v1.QueueConfig\"b\n\x0bQueueConfig\x12\x19\n\x11on_enqueue_script\x18\x01 \x01(\t\x12\x19\n\x11on_failure_script\x18\x02 \x01(\t\x12\x1d\n\x15visibility_timeout_ms\x18\x03 \x01(\x04\"\'\n\x13\x43reateQueueResponse\x12\x10\n\x08queue_id\x18\x01 \x01(\t\"#\n\x12\x44\x65leteQueueRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"\x15\n\x13\x44\x65leteQueueResponse\".\n\x10SetConfigRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"\x13\n\x11SetConfigResponse\"\x1f\n\x10GetConfigRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\"\"\n\x11GetConfigResponse\x12\r\n\x05value\x18\x01 \x01(\t\")\n\x0b\x43onfigEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"#\n\x11ListConfigRequest\x12\x0e\n\x06prefix\x18\x01 \x01(\t\"P\n\x12ListConfigResponse\x12%\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x14.fila.v1.ConfigEntry\x12\x13\n\x0btotal_count\x18\x02 \x01(\r\" \n\x0fGetStatsRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"b\n\x13PerFairnessKeyStats\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x15\n\rpending_count\x18\x02 \x01(\x04\x12\x17\n\x0f\x63urrent_deficit\x18\x03 
\x01(\x03\x12\x0e\n\x06weight\x18\x04 \x01(\r\"Z\n\x13PerThrottleKeyStats\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x0e\n\x06tokens\x18\x02 \x01(\x01\x12\x17\n\x0frate_per_second\x18\x03 \x01(\x01\x12\r\n\x05\x62urst\x18\x04 \x01(\x01\"\x9f\x02\n\x10GetStatsResponse\x12\r\n\x05\x64\x65pth\x18\x01 \x01(\x04\x12\x11\n\tin_flight\x18\x02 \x01(\x04\x12\x1c\n\x14\x61\x63tive_fairness_keys\x18\x03 \x01(\x04\x12\x18\n\x10\x61\x63tive_consumers\x18\x04 \x01(\r\x12\x0f\n\x07quantum\x18\x05 \x01(\r\x12\x33\n\rper_key_stats\x18\x06 \x03(\x0b\x32\x1c.fila.v1.PerFairnessKeyStats\x12\x38\n\x12per_throttle_stats\x18\x07 \x03(\x0b\x32\x1c.fila.v1.PerThrottleKeyStats\x12\x16\n\x0eleader_node_id\x18\x08 \x01(\x04\x12\x19\n\x11replication_count\x18\t \x01(\r\"2\n\x0eRedriveRequest\x12\x11\n\tdlq_queue\x18\x01 \x01(\t\x12\r\n\x05\x63ount\x18\x02 \x01(\x04\"#\n\x0fRedriveResponse\x12\x10\n\x08redriven\x18\x01 \x01(\x04\"\x13\n\x11ListQueuesRequest\"m\n\tQueueInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05\x64\x65pth\x18\x02 \x01(\x04\x12\x11\n\tin_flight\x18\x03 \x01(\x04\x12\x18\n\x10\x61\x63tive_consumers\x18\x04 \x01(\r\x12\x16\n\x0eleader_node_id\x18\x05 \x01(\x04\"T\n\x12ListQueuesResponse\x12\"\n\x06queues\x18\x01 \x03(\x0b\x32\x12.fila.v1.QueueInfo\x12\x1a\n\x12\x63luster_node_count\x18\x02 \x01(\r\"Q\n\x13\x43reateApiKeyRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x15\n\rexpires_at_ms\x18\x02 \x01(\x04\x12\x15\n\ris_superadmin\x18\x03 \x01(\x08\"J\n\x14\x43reateApiKeyResponse\x12\x0e\n\x06key_id\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\x12\x15\n\ris_superadmin\x18\x03 \x01(\x08\"%\n\x13RevokeApiKeyRequest\x12\x0e\n\x06key_id\x18\x01 \x01(\t\"\x16\n\x14RevokeApiKeyResponse\"\x14\n\x12ListApiKeysRequest\"o\n\nApiKeyInfo\x12\x0e\n\x06key_id\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x15\n\rcreated_at_ms\x18\x03 \x01(\x04\x12\x15\n\rexpires_at_ms\x18\x04 \x01(\x04\x12\x15\n\ris_superadmin\x18\x05 \x01(\x08\"8\n\x13ListApiKeysResponse\x12!\n\x04keys\x18\x01 
\x03(\x0b\x32\x13.fila.v1.ApiKeyInfo\".\n\rAclPermission\x12\x0c\n\x04kind\x18\x01 \x01(\t\x12\x0f\n\x07pattern\x18\x02 \x01(\t\"L\n\rSetAclRequest\x12\x0e\n\x06key_id\x18\x01 \x01(\t\x12+\n\x0bpermissions\x18\x02 \x03(\x0b\x32\x16.fila.v1.AclPermission\"\x10\n\x0eSetAclResponse\"\x1f\n\rGetAclRequest\x12\x0e\n\x06key_id\x18\x01 \x01(\t\"d\n\x0eGetAclResponse\x12\x0e\n\x06key_id\x18\x01 \x01(\t\x12+\n\x0bpermissions\x18\x02 \x03(\x0b\x32\x16.fila.v1.AclPermission\x12\x15\n\ris_superadmin\x18\x03 \x01(\x08\x32\x8e\x07\n\tFilaAdmin\x12H\n\x0b\x43reateQueue\x12\x1b.fila.v1.CreateQueueRequest\x1a\x1c.fila.v1.CreateQueueResponse\x12H\n\x0b\x44\x65leteQueue\x12\x1b.fila.v1.DeleteQueueRequest\x1a\x1c.fila.v1.DeleteQueueResponse\x12\x42\n\tSetConfig\x12\x19.fila.v1.SetConfigRequest\x1a\x1a.fila.v1.SetConfigResponse\x12\x42\n\tGetConfig\x12\x19.fila.v1.GetConfigRequest\x1a\x1a.fila.v1.GetConfigResponse\x12\x45\n\nListConfig\x12\x1a.fila.v1.ListConfigRequest\x1a\x1b.fila.v1.ListConfigResponse\x12?\n\x08GetStats\x12\x18.fila.v1.GetStatsRequest\x1a\x19.fila.v1.GetStatsResponse\x12<\n\x07Redrive\x12\x17.fila.v1.RedriveRequest\x1a\x18.fila.v1.RedriveResponse\x12\x45\n\nListQueues\x12\x1a.fila.v1.ListQueuesRequest\x1a\x1b.fila.v1.ListQueuesResponse\x12K\n\x0c\x43reateApiKey\x12\x1c.fila.v1.CreateApiKeyRequest\x1a\x1d.fila.v1.CreateApiKeyResponse\x12K\n\x0cRevokeApiKey\x12\x1c.fila.v1.RevokeApiKeyRequest\x1a\x1d.fila.v1.RevokeApiKeyResponse\x12H\n\x0bListApiKeys\x12\x1b.fila.v1.ListApiKeysRequest\x1a\x1c.fila.v1.ListApiKeysResponse\x12\x39\n\x06SetAcl\x12\x16.fila.v1.SetAclRequest\x1a\x17.fila.v1.SetAclResponse\x12\x39\n\x06GetAcl\x12\x16.fila.v1.GetAclRequest\x1a\x17.fila.v1.GetAclResponseb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fila.v1.admin_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - 
_globals['_CREATEQUEUEREQUEST']._serialized_start=32 - _globals['_CREATEQUEUEREQUEST']._serialized_end=104 - _globals['_QUEUECONFIG']._serialized_start=106 - _globals['_QUEUECONFIG']._serialized_end=204 - _globals['_CREATEQUEUERESPONSE']._serialized_start=206 - _globals['_CREATEQUEUERESPONSE']._serialized_end=245 - _globals['_DELETEQUEUEREQUEST']._serialized_start=247 - _globals['_DELETEQUEUEREQUEST']._serialized_end=282 - _globals['_DELETEQUEUERESPONSE']._serialized_start=284 - _globals['_DELETEQUEUERESPONSE']._serialized_end=305 - _globals['_SETCONFIGREQUEST']._serialized_start=307 - _globals['_SETCONFIGREQUEST']._serialized_end=353 - _globals['_SETCONFIGRESPONSE']._serialized_start=355 - _globals['_SETCONFIGRESPONSE']._serialized_end=374 - _globals['_GETCONFIGREQUEST']._serialized_start=376 - _globals['_GETCONFIGREQUEST']._serialized_end=407 - _globals['_GETCONFIGRESPONSE']._serialized_start=409 - _globals['_GETCONFIGRESPONSE']._serialized_end=443 - _globals['_CONFIGENTRY']._serialized_start=445 - _globals['_CONFIGENTRY']._serialized_end=486 - _globals['_LISTCONFIGREQUEST']._serialized_start=488 - _globals['_LISTCONFIGREQUEST']._serialized_end=523 - _globals['_LISTCONFIGRESPONSE']._serialized_start=525 - _globals['_LISTCONFIGRESPONSE']._serialized_end=605 - _globals['_GETSTATSREQUEST']._serialized_start=607 - _globals['_GETSTATSREQUEST']._serialized_end=639 - _globals['_PERFAIRNESSKEYSTATS']._serialized_start=641 - _globals['_PERFAIRNESSKEYSTATS']._serialized_end=739 - _globals['_PERTHROTTLEKEYSTATS']._serialized_start=741 - _globals['_PERTHROTTLEKEYSTATS']._serialized_end=831 - _globals['_GETSTATSRESPONSE']._serialized_start=834 - _globals['_GETSTATSRESPONSE']._serialized_end=1121 - _globals['_REDRIVEREQUEST']._serialized_start=1123 - _globals['_REDRIVEREQUEST']._serialized_end=1173 - _globals['_REDRIVERESPONSE']._serialized_start=1175 - _globals['_REDRIVERESPONSE']._serialized_end=1210 - _globals['_LISTQUEUESREQUEST']._serialized_start=1212 - 
_globals['_LISTQUEUESREQUEST']._serialized_end=1231 - _globals['_QUEUEINFO']._serialized_start=1233 - _globals['_QUEUEINFO']._serialized_end=1342 - _globals['_LISTQUEUESRESPONSE']._serialized_start=1344 - _globals['_LISTQUEUESRESPONSE']._serialized_end=1428 - _globals['_CREATEAPIKEYREQUEST']._serialized_start=1430 - _globals['_CREATEAPIKEYREQUEST']._serialized_end=1511 - _globals['_CREATEAPIKEYRESPONSE']._serialized_start=1513 - _globals['_CREATEAPIKEYRESPONSE']._serialized_end=1587 - _globals['_REVOKEAPIKEYREQUEST']._serialized_start=1589 - _globals['_REVOKEAPIKEYREQUEST']._serialized_end=1626 - _globals['_REVOKEAPIKEYRESPONSE']._serialized_start=1628 - _globals['_REVOKEAPIKEYRESPONSE']._serialized_end=1650 - _globals['_LISTAPIKEYSREQUEST']._serialized_start=1652 - _globals['_LISTAPIKEYSREQUEST']._serialized_end=1672 - _globals['_APIKEYINFO']._serialized_start=1674 - _globals['_APIKEYINFO']._serialized_end=1785 - _globals['_LISTAPIKEYSRESPONSE']._serialized_start=1787 - _globals['_LISTAPIKEYSRESPONSE']._serialized_end=1843 - _globals['_ACLPERMISSION']._serialized_start=1845 - _globals['_ACLPERMISSION']._serialized_end=1891 - _globals['_SETACLREQUEST']._serialized_start=1893 - _globals['_SETACLREQUEST']._serialized_end=1969 - _globals['_SETACLRESPONSE']._serialized_start=1971 - _globals['_SETACLRESPONSE']._serialized_end=1987 - _globals['_GETACLREQUEST']._serialized_start=1989 - _globals['_GETACLREQUEST']._serialized_end=2020 - _globals['_GETACLRESPONSE']._serialized_start=2022 - _globals['_GETACLRESPONSE']._serialized_end=2122 - _globals['_FILAADMIN']._serialized_start=2125 - _globals['_FILAADMIN']._serialized_end=3035 -# @@protoc_insertion_point(module_scope) diff --git a/fila/v1/admin_pb2.pyi b/fila/v1/admin_pb2.pyi deleted file mode 100644 index d603b29..0000000 --- a/fila/v1/admin_pb2.pyi +++ /dev/null @@ -1,269 +0,0 @@ -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf 
import message as _message -from collections.abc import Iterable as _Iterable, Mapping as _Mapping -from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class CreateQueueRequest(_message.Message): - __slots__ = ("name", "config") - NAME_FIELD_NUMBER: _ClassVar[int] - CONFIG_FIELD_NUMBER: _ClassVar[int] - name: str - config: QueueConfig - def __init__(self, name: _Optional[str] = ..., config: _Optional[_Union[QueueConfig, _Mapping]] = ...) -> None: ... - -class QueueConfig(_message.Message): - __slots__ = ("on_enqueue_script", "on_failure_script", "visibility_timeout_ms") - ON_ENQUEUE_SCRIPT_FIELD_NUMBER: _ClassVar[int] - ON_FAILURE_SCRIPT_FIELD_NUMBER: _ClassVar[int] - VISIBILITY_TIMEOUT_MS_FIELD_NUMBER: _ClassVar[int] - on_enqueue_script: str - on_failure_script: str - visibility_timeout_ms: int - def __init__(self, on_enqueue_script: _Optional[str] = ..., on_failure_script: _Optional[str] = ..., visibility_timeout_ms: _Optional[int] = ...) -> None: ... - -class CreateQueueResponse(_message.Message): - __slots__ = ("queue_id",) - QUEUE_ID_FIELD_NUMBER: _ClassVar[int] - queue_id: str - def __init__(self, queue_id: _Optional[str] = ...) -> None: ... - -class DeleteQueueRequest(_message.Message): - __slots__ = ("queue",) - QUEUE_FIELD_NUMBER: _ClassVar[int] - queue: str - def __init__(self, queue: _Optional[str] = ...) -> None: ... - -class DeleteQueueResponse(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class SetConfigRequest(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - -class SetConfigResponse(_message.Message): - __slots__ = () - def __init__(self) -> None: ... 
- -class GetConfigRequest(_message.Message): - __slots__ = ("key",) - KEY_FIELD_NUMBER: _ClassVar[int] - key: str - def __init__(self, key: _Optional[str] = ...) -> None: ... - -class GetConfigResponse(_message.Message): - __slots__ = ("value",) - VALUE_FIELD_NUMBER: _ClassVar[int] - value: str - def __init__(self, value: _Optional[str] = ...) -> None: ... - -class ConfigEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - -class ListConfigRequest(_message.Message): - __slots__ = ("prefix",) - PREFIX_FIELD_NUMBER: _ClassVar[int] - prefix: str - def __init__(self, prefix: _Optional[str] = ...) -> None: ... - -class ListConfigResponse(_message.Message): - __slots__ = ("entries", "total_count") - ENTRIES_FIELD_NUMBER: _ClassVar[int] - TOTAL_COUNT_FIELD_NUMBER: _ClassVar[int] - entries: _containers.RepeatedCompositeFieldContainer[ConfigEntry] - total_count: int - def __init__(self, entries: _Optional[_Iterable[_Union[ConfigEntry, _Mapping]]] = ..., total_count: _Optional[int] = ...) -> None: ... - -class GetStatsRequest(_message.Message): - __slots__ = ("queue",) - QUEUE_FIELD_NUMBER: _ClassVar[int] - queue: str - def __init__(self, queue: _Optional[str] = ...) -> None: ... - -class PerFairnessKeyStats(_message.Message): - __slots__ = ("key", "pending_count", "current_deficit", "weight") - KEY_FIELD_NUMBER: _ClassVar[int] - PENDING_COUNT_FIELD_NUMBER: _ClassVar[int] - CURRENT_DEFICIT_FIELD_NUMBER: _ClassVar[int] - WEIGHT_FIELD_NUMBER: _ClassVar[int] - key: str - pending_count: int - current_deficit: int - weight: int - def __init__(self, key: _Optional[str] = ..., pending_count: _Optional[int] = ..., current_deficit: _Optional[int] = ..., weight: _Optional[int] = ...) -> None: ... 
- -class PerThrottleKeyStats(_message.Message): - __slots__ = ("key", "tokens", "rate_per_second", "burst") - KEY_FIELD_NUMBER: _ClassVar[int] - TOKENS_FIELD_NUMBER: _ClassVar[int] - RATE_PER_SECOND_FIELD_NUMBER: _ClassVar[int] - BURST_FIELD_NUMBER: _ClassVar[int] - key: str - tokens: float - rate_per_second: float - burst: float - def __init__(self, key: _Optional[str] = ..., tokens: _Optional[float] = ..., rate_per_second: _Optional[float] = ..., burst: _Optional[float] = ...) -> None: ... - -class GetStatsResponse(_message.Message): - __slots__ = ("depth", "in_flight", "active_fairness_keys", "active_consumers", "quantum", "per_key_stats", "per_throttle_stats", "leader_node_id", "replication_count") - DEPTH_FIELD_NUMBER: _ClassVar[int] - IN_FLIGHT_FIELD_NUMBER: _ClassVar[int] - ACTIVE_FAIRNESS_KEYS_FIELD_NUMBER: _ClassVar[int] - ACTIVE_CONSUMERS_FIELD_NUMBER: _ClassVar[int] - QUANTUM_FIELD_NUMBER: _ClassVar[int] - PER_KEY_STATS_FIELD_NUMBER: _ClassVar[int] - PER_THROTTLE_STATS_FIELD_NUMBER: _ClassVar[int] - LEADER_NODE_ID_FIELD_NUMBER: _ClassVar[int] - REPLICATION_COUNT_FIELD_NUMBER: _ClassVar[int] - depth: int - in_flight: int - active_fairness_keys: int - active_consumers: int - quantum: int - per_key_stats: _containers.RepeatedCompositeFieldContainer[PerFairnessKeyStats] - per_throttle_stats: _containers.RepeatedCompositeFieldContainer[PerThrottleKeyStats] - leader_node_id: int - replication_count: int - def __init__(self, depth: _Optional[int] = ..., in_flight: _Optional[int] = ..., active_fairness_keys: _Optional[int] = ..., active_consumers: _Optional[int] = ..., quantum: _Optional[int] = ..., per_key_stats: _Optional[_Iterable[_Union[PerFairnessKeyStats, _Mapping]]] = ..., per_throttle_stats: _Optional[_Iterable[_Union[PerThrottleKeyStats, _Mapping]]] = ..., leader_node_id: _Optional[int] = ..., replication_count: _Optional[int] = ...) -> None: ... 
- -class RedriveRequest(_message.Message): - __slots__ = ("dlq_queue", "count") - DLQ_QUEUE_FIELD_NUMBER: _ClassVar[int] - COUNT_FIELD_NUMBER: _ClassVar[int] - dlq_queue: str - count: int - def __init__(self, dlq_queue: _Optional[str] = ..., count: _Optional[int] = ...) -> None: ... - -class RedriveResponse(_message.Message): - __slots__ = ("redriven",) - REDRIVEN_FIELD_NUMBER: _ClassVar[int] - redriven: int - def __init__(self, redriven: _Optional[int] = ...) -> None: ... - -class ListQueuesRequest(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class QueueInfo(_message.Message): - __slots__ = ("name", "depth", "in_flight", "active_consumers", "leader_node_id") - NAME_FIELD_NUMBER: _ClassVar[int] - DEPTH_FIELD_NUMBER: _ClassVar[int] - IN_FLIGHT_FIELD_NUMBER: _ClassVar[int] - ACTIVE_CONSUMERS_FIELD_NUMBER: _ClassVar[int] - LEADER_NODE_ID_FIELD_NUMBER: _ClassVar[int] - name: str - depth: int - in_flight: int - active_consumers: int - leader_node_id: int - def __init__(self, name: _Optional[str] = ..., depth: _Optional[int] = ..., in_flight: _Optional[int] = ..., active_consumers: _Optional[int] = ..., leader_node_id: _Optional[int] = ...) -> None: ... - -class ListQueuesResponse(_message.Message): - __slots__ = ("queues", "cluster_node_count") - QUEUES_FIELD_NUMBER: _ClassVar[int] - CLUSTER_NODE_COUNT_FIELD_NUMBER: _ClassVar[int] - queues: _containers.RepeatedCompositeFieldContainer[QueueInfo] - cluster_node_count: int - def __init__(self, queues: _Optional[_Iterable[_Union[QueueInfo, _Mapping]]] = ..., cluster_node_count: _Optional[int] = ...) -> None: ... 
- -class CreateApiKeyRequest(_message.Message): - __slots__ = ("name", "expires_at_ms", "is_superadmin") - NAME_FIELD_NUMBER: _ClassVar[int] - EXPIRES_AT_MS_FIELD_NUMBER: _ClassVar[int] - IS_SUPERADMIN_FIELD_NUMBER: _ClassVar[int] - name: str - expires_at_ms: int - is_superadmin: bool - def __init__(self, name: _Optional[str] = ..., expires_at_ms: _Optional[int] = ..., is_superadmin: bool = ...) -> None: ... - -class CreateApiKeyResponse(_message.Message): - __slots__ = ("key_id", "key", "is_superadmin") - KEY_ID_FIELD_NUMBER: _ClassVar[int] - KEY_FIELD_NUMBER: _ClassVar[int] - IS_SUPERADMIN_FIELD_NUMBER: _ClassVar[int] - key_id: str - key: str - is_superadmin: bool - def __init__(self, key_id: _Optional[str] = ..., key: _Optional[str] = ..., is_superadmin: bool = ...) -> None: ... - -class RevokeApiKeyRequest(_message.Message): - __slots__ = ("key_id",) - KEY_ID_FIELD_NUMBER: _ClassVar[int] - key_id: str - def __init__(self, key_id: _Optional[str] = ...) -> None: ... - -class RevokeApiKeyResponse(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class ListApiKeysRequest(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class ApiKeyInfo(_message.Message): - __slots__ = ("key_id", "name", "created_at_ms", "expires_at_ms", "is_superadmin") - KEY_ID_FIELD_NUMBER: _ClassVar[int] - NAME_FIELD_NUMBER: _ClassVar[int] - CREATED_AT_MS_FIELD_NUMBER: _ClassVar[int] - EXPIRES_AT_MS_FIELD_NUMBER: _ClassVar[int] - IS_SUPERADMIN_FIELD_NUMBER: _ClassVar[int] - key_id: str - name: str - created_at_ms: int - expires_at_ms: int - is_superadmin: bool - def __init__(self, key_id: _Optional[str] = ..., name: _Optional[str] = ..., created_at_ms: _Optional[int] = ..., expires_at_ms: _Optional[int] = ..., is_superadmin: bool = ...) -> None: ... 
- -class ListApiKeysResponse(_message.Message): - __slots__ = ("keys",) - KEYS_FIELD_NUMBER: _ClassVar[int] - keys: _containers.RepeatedCompositeFieldContainer[ApiKeyInfo] - def __init__(self, keys: _Optional[_Iterable[_Union[ApiKeyInfo, _Mapping]]] = ...) -> None: ... - -class AclPermission(_message.Message): - __slots__ = ("kind", "pattern") - KIND_FIELD_NUMBER: _ClassVar[int] - PATTERN_FIELD_NUMBER: _ClassVar[int] - kind: str - pattern: str - def __init__(self, kind: _Optional[str] = ..., pattern: _Optional[str] = ...) -> None: ... - -class SetAclRequest(_message.Message): - __slots__ = ("key_id", "permissions") - KEY_ID_FIELD_NUMBER: _ClassVar[int] - PERMISSIONS_FIELD_NUMBER: _ClassVar[int] - key_id: str - permissions: _containers.RepeatedCompositeFieldContainer[AclPermission] - def __init__(self, key_id: _Optional[str] = ..., permissions: _Optional[_Iterable[_Union[AclPermission, _Mapping]]] = ...) -> None: ... - -class SetAclResponse(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class GetAclRequest(_message.Message): - __slots__ = ("key_id",) - KEY_ID_FIELD_NUMBER: _ClassVar[int] - key_id: str - def __init__(self, key_id: _Optional[str] = ...) -> None: ... - -class GetAclResponse(_message.Message): - __slots__ = ("key_id", "permissions", "is_superadmin") - KEY_ID_FIELD_NUMBER: _ClassVar[int] - PERMISSIONS_FIELD_NUMBER: _ClassVar[int] - IS_SUPERADMIN_FIELD_NUMBER: _ClassVar[int] - key_id: str - permissions: _containers.RepeatedCompositeFieldContainer[AclPermission] - is_superadmin: bool - def __init__(self, key_id: _Optional[str] = ..., permissions: _Optional[_Iterable[_Union[AclPermission, _Mapping]]] = ..., is_superadmin: bool = ...) -> None: ... diff --git a/fila/v1/admin_pb2_grpc.py b/fila/v1/admin_pb2_grpc.py deleted file mode 100644 index 93d6c4e..0000000 --- a/fila/v1/admin_pb2_grpc.py +++ /dev/null @@ -1,618 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
-"""Client and server classes corresponding to protobuf-defined services.""" -import grpc -import warnings - -from fila.v1 import admin_pb2 as fila_dot_v1_dot_admin__pb2 - -GRPC_GENERATED_VERSION = '1.78.0' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + ' but the generated code in fila/v1/admin_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' - ) - - -class FilaAdminStub(object): - """Admin RPCs for operators and the CLI. - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.CreateQueue = channel.unary_unary( - '/fila.v1.FilaAdmin/CreateQueue', - request_serializer=fila_dot_v1_dot_admin__pb2.CreateQueueRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.CreateQueueResponse.FromString, - _registered_method=True) - self.DeleteQueue = channel.unary_unary( - '/fila.v1.FilaAdmin/DeleteQueue', - request_serializer=fila_dot_v1_dot_admin__pb2.DeleteQueueRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.DeleteQueueResponse.FromString, - _registered_method=True) - self.SetConfig = channel.unary_unary( - '/fila.v1.FilaAdmin/SetConfig', - request_serializer=fila_dot_v1_dot_admin__pb2.SetConfigRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.SetConfigResponse.FromString, - _registered_method=True) - self.GetConfig = channel.unary_unary( - '/fila.v1.FilaAdmin/GetConfig', - request_serializer=fila_dot_v1_dot_admin__pb2.GetConfigRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.GetConfigResponse.FromString, - _registered_method=True) - self.ListConfig = channel.unary_unary( - '/fila.v1.FilaAdmin/ListConfig', - request_serializer=fila_dot_v1_dot_admin__pb2.ListConfigRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.ListConfigResponse.FromString, - _registered_method=True) - self.GetStats = channel.unary_unary( - '/fila.v1.FilaAdmin/GetStats', - request_serializer=fila_dot_v1_dot_admin__pb2.GetStatsRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.GetStatsResponse.FromString, - _registered_method=True) - self.Redrive = channel.unary_unary( - '/fila.v1.FilaAdmin/Redrive', - request_serializer=fila_dot_v1_dot_admin__pb2.RedriveRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.RedriveResponse.FromString, - _registered_method=True) - self.ListQueues = channel.unary_unary( - '/fila.v1.FilaAdmin/ListQueues', - 
request_serializer=fila_dot_v1_dot_admin__pb2.ListQueuesRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.ListQueuesResponse.FromString, - _registered_method=True) - self.CreateApiKey = channel.unary_unary( - '/fila.v1.FilaAdmin/CreateApiKey', - request_serializer=fila_dot_v1_dot_admin__pb2.CreateApiKeyRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.CreateApiKeyResponse.FromString, - _registered_method=True) - self.RevokeApiKey = channel.unary_unary( - '/fila.v1.FilaAdmin/RevokeApiKey', - request_serializer=fila_dot_v1_dot_admin__pb2.RevokeApiKeyRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.RevokeApiKeyResponse.FromString, - _registered_method=True) - self.ListApiKeys = channel.unary_unary( - '/fila.v1.FilaAdmin/ListApiKeys', - request_serializer=fila_dot_v1_dot_admin__pb2.ListApiKeysRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.ListApiKeysResponse.FromString, - _registered_method=True) - self.SetAcl = channel.unary_unary( - '/fila.v1.FilaAdmin/SetAcl', - request_serializer=fila_dot_v1_dot_admin__pb2.SetAclRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.SetAclResponse.FromString, - _registered_method=True) - self.GetAcl = channel.unary_unary( - '/fila.v1.FilaAdmin/GetAcl', - request_serializer=fila_dot_v1_dot_admin__pb2.GetAclRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_admin__pb2.GetAclResponse.FromString, - _registered_method=True) - - -class FilaAdminServicer(object): - """Admin RPCs for operators and the CLI. 
- """ - - def CreateQueue(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def DeleteQueue(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def SetConfig(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetConfig(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def ListConfig(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetStats(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Redrive(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def ListQueues(self, request, context): - """Missing associated documentation comment in .proto file.""" - 
context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def CreateApiKey(self, request, context): - """API key management. CreateApiKey bypasses auth (bootstrap); others require a valid key. - """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def RevokeApiKey(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def ListApiKeys(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def SetAcl(self, request, context): - """Per-key ACL management. 
- """ - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def GetAcl(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_FilaAdminServicer_to_server(servicer, server): - rpc_method_handlers = { - 'CreateQueue': grpc.unary_unary_rpc_method_handler( - servicer.CreateQueue, - request_deserializer=fila_dot_v1_dot_admin__pb2.CreateQueueRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.CreateQueueResponse.SerializeToString, - ), - 'DeleteQueue': grpc.unary_unary_rpc_method_handler( - servicer.DeleteQueue, - request_deserializer=fila_dot_v1_dot_admin__pb2.DeleteQueueRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.DeleteQueueResponse.SerializeToString, - ), - 'SetConfig': grpc.unary_unary_rpc_method_handler( - servicer.SetConfig, - request_deserializer=fila_dot_v1_dot_admin__pb2.SetConfigRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.SetConfigResponse.SerializeToString, - ), - 'GetConfig': grpc.unary_unary_rpc_method_handler( - servicer.GetConfig, - request_deserializer=fila_dot_v1_dot_admin__pb2.GetConfigRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.GetConfigResponse.SerializeToString, - ), - 'ListConfig': grpc.unary_unary_rpc_method_handler( - servicer.ListConfig, - request_deserializer=fila_dot_v1_dot_admin__pb2.ListConfigRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.ListConfigResponse.SerializeToString, - ), - 'GetStats': grpc.unary_unary_rpc_method_handler( - servicer.GetStats, - request_deserializer=fila_dot_v1_dot_admin__pb2.GetStatsRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.GetStatsResponse.SerializeToString, - ), - 'Redrive': 
grpc.unary_unary_rpc_method_handler( - servicer.Redrive, - request_deserializer=fila_dot_v1_dot_admin__pb2.RedriveRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.RedriveResponse.SerializeToString, - ), - 'ListQueues': grpc.unary_unary_rpc_method_handler( - servicer.ListQueues, - request_deserializer=fila_dot_v1_dot_admin__pb2.ListQueuesRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.ListQueuesResponse.SerializeToString, - ), - 'CreateApiKey': grpc.unary_unary_rpc_method_handler( - servicer.CreateApiKey, - request_deserializer=fila_dot_v1_dot_admin__pb2.CreateApiKeyRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.CreateApiKeyResponse.SerializeToString, - ), - 'RevokeApiKey': grpc.unary_unary_rpc_method_handler( - servicer.RevokeApiKey, - request_deserializer=fila_dot_v1_dot_admin__pb2.RevokeApiKeyRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.RevokeApiKeyResponse.SerializeToString, - ), - 'ListApiKeys': grpc.unary_unary_rpc_method_handler( - servicer.ListApiKeys, - request_deserializer=fila_dot_v1_dot_admin__pb2.ListApiKeysRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.ListApiKeysResponse.SerializeToString, - ), - 'SetAcl': grpc.unary_unary_rpc_method_handler( - servicer.SetAcl, - request_deserializer=fila_dot_v1_dot_admin__pb2.SetAclRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.SetAclResponse.SerializeToString, - ), - 'GetAcl': grpc.unary_unary_rpc_method_handler( - servicer.GetAcl, - request_deserializer=fila_dot_v1_dot_admin__pb2.GetAclRequest.FromString, - response_serializer=fila_dot_v1_dot_admin__pb2.GetAclResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'fila.v1.FilaAdmin', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers('fila.v1.FilaAdmin', rpc_method_handlers) - - - # This class is part of an EXPERIMENTAL API. 
-class FilaAdmin(object): - """Admin RPCs for operators and the CLI. - """ - - @staticmethod - def CreateQueue(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/CreateQueue', - fila_dot_v1_dot_admin__pb2.CreateQueueRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.CreateQueueResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def DeleteQueue(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/DeleteQueue', - fila_dot_v1_dot_admin__pb2.DeleteQueueRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.DeleteQueueResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def SetConfig(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/SetConfig', - fila_dot_v1_dot_admin__pb2.SetConfigRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.SetConfigResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def GetConfig(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - 
wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/GetConfig', - fila_dot_v1_dot_admin__pb2.GetConfigRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.GetConfigResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def ListConfig(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/ListConfig', - fila_dot_v1_dot_admin__pb2.ListConfigRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.ListConfigResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def GetStats(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/GetStats', - fila_dot_v1_dot_admin__pb2.GetStatsRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.GetStatsResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Redrive(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/Redrive', - fila_dot_v1_dot_admin__pb2.RedriveRequest.SerializeToString, - 
fila_dot_v1_dot_admin__pb2.RedriveResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def ListQueues(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/ListQueues', - fila_dot_v1_dot_admin__pb2.ListQueuesRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.ListQueuesResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def CreateApiKey(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/CreateApiKey', - fila_dot_v1_dot_admin__pb2.CreateApiKeyRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.CreateApiKeyResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def RevokeApiKey(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/RevokeApiKey', - fila_dot_v1_dot_admin__pb2.RevokeApiKeyRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.RevokeApiKeyResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod 
- def ListApiKeys(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/ListApiKeys', - fila_dot_v1_dot_admin__pb2.ListApiKeysRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.ListApiKeysResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def SetAcl(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/SetAcl', - fila_dot_v1_dot_admin__pb2.SetAclRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.SetAclResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def GetAcl(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaAdmin/GetAcl', - fila_dot_v1_dot_admin__pb2.GetAclRequest.SerializeToString, - fila_dot_v1_dot_admin__pb2.GetAclResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) diff --git a/fila/v1/messages_pb2.py b/fila/v1/messages_pb2.py deleted file mode 100644 index 3cf7edb..0000000 --- a/fila/v1/messages_pb2.py +++ /dev/null @@ -1,45 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# NO CHECKED-IN PROTOBUF GENCODE -# source: fila/v1/messages.proto -# Protobuf Python Version: 6.31.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 6, - 31, - 1, - '', - 'fila/v1/messages.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x66ila/v1/messages.proto\x12\x07\x66ila.v1\x1a\x1fgoogle/protobuf/timestamp.proto\"\xe2\x01\n\x07Message\x12\n\n\x02id\x18\x01 \x01(\t\x12.\n\x07headers\x18\x02 \x03(\x0b\x32\x1d.fila.v1.Message.HeadersEntry\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x12*\n\x08metadata\x18\x04 \x01(\x0b\x32\x18.fila.v1.MessageMetadata\x12.\n\ntimestamps\x18\x05 \x01(\x0b\x32\x1a.fila.v1.MessageTimestamps\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"w\n\x0fMessageMetadata\x12\x14\n\x0c\x66\x61irness_key\x18\x01 \x01(\t\x12\x0e\n\x06weight\x18\x02 \x01(\r\x12\x15\n\rthrottle_keys\x18\x03 \x03(\t\x12\x15\n\rattempt_count\x18\x04 \x01(\r\x12\x10\n\x08queue_id\x18\x05 \x01(\t\"s\n\x11MessageTimestamps\x12/\n\x0b\x65nqueued_at\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12-\n\tleased_at\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestampb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fila.v1.messages_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - 
_globals['_MESSAGE_HEADERSENTRY']._loaded_options = None - _globals['_MESSAGE_HEADERSENTRY']._serialized_options = b'8\001' - _globals['_MESSAGE']._serialized_start=69 - _globals['_MESSAGE']._serialized_end=295 - _globals['_MESSAGE_HEADERSENTRY']._serialized_start=249 - _globals['_MESSAGE_HEADERSENTRY']._serialized_end=295 - _globals['_MESSAGEMETADATA']._serialized_start=297 - _globals['_MESSAGEMETADATA']._serialized_end=416 - _globals['_MESSAGETIMESTAMPS']._serialized_start=418 - _globals['_MESSAGETIMESTAMPS']._serialized_end=533 -# @@protoc_insertion_point(module_scope) diff --git a/fila/v1/messages_pb2.pyi b/fila/v1/messages_pb2.pyi deleted file mode 100644 index a91bb74..0000000 --- a/fila/v1/messages_pb2.pyi +++ /dev/null @@ -1,53 +0,0 @@ -import datetime - -from google.protobuf import timestamp_pb2 as _timestamp_pb2 -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from collections.abc import Iterable as _Iterable, Mapping as _Mapping -from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class Message(_message.Message): - __slots__ = ("id", "headers", "payload", "metadata", "timestamps") - class HeadersEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... 
- ID_FIELD_NUMBER: _ClassVar[int] - HEADERS_FIELD_NUMBER: _ClassVar[int] - PAYLOAD_FIELD_NUMBER: _ClassVar[int] - METADATA_FIELD_NUMBER: _ClassVar[int] - TIMESTAMPS_FIELD_NUMBER: _ClassVar[int] - id: str - headers: _containers.ScalarMap[str, str] - payload: bytes - metadata: MessageMetadata - timestamps: MessageTimestamps - def __init__(self, id: _Optional[str] = ..., headers: _Optional[_Mapping[str, str]] = ..., payload: _Optional[bytes] = ..., metadata: _Optional[_Union[MessageMetadata, _Mapping]] = ..., timestamps: _Optional[_Union[MessageTimestamps, _Mapping]] = ...) -> None: ... - -class MessageMetadata(_message.Message): - __slots__ = ("fairness_key", "weight", "throttle_keys", "attempt_count", "queue_id") - FAIRNESS_KEY_FIELD_NUMBER: _ClassVar[int] - WEIGHT_FIELD_NUMBER: _ClassVar[int] - THROTTLE_KEYS_FIELD_NUMBER: _ClassVar[int] - ATTEMPT_COUNT_FIELD_NUMBER: _ClassVar[int] - QUEUE_ID_FIELD_NUMBER: _ClassVar[int] - fairness_key: str - weight: int - throttle_keys: _containers.RepeatedScalarFieldContainer[str] - attempt_count: int - queue_id: str - def __init__(self, fairness_key: _Optional[str] = ..., weight: _Optional[int] = ..., throttle_keys: _Optional[_Iterable[str]] = ..., attempt_count: _Optional[int] = ..., queue_id: _Optional[str] = ...) -> None: ... - -class MessageTimestamps(_message.Message): - __slots__ = ("enqueued_at", "leased_at") - ENQUEUED_AT_FIELD_NUMBER: _ClassVar[int] - LEASED_AT_FIELD_NUMBER: _ClassVar[int] - enqueued_at: _timestamp_pb2.Timestamp - leased_at: _timestamp_pb2.Timestamp - def __init__(self, enqueued_at: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., leased_at: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ...) -> None: ... diff --git a/fila/v1/messages_pb2_grpc.py b/fila/v1/messages_pb2_grpc.py deleted file mode 100644 index fa0dc71..0000000 --- a/fila/v1/messages_pb2_grpc.py +++ /dev/null @@ -1,24 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. 
DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc -import warnings - - -GRPC_GENERATED_VERSION = '1.78.0' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + ' but the generated code in fila/v1/messages_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' - ) diff --git a/fila/v1/service_pb2.py b/fila/v1/service_pb2.py deleted file mode 100644 index 11ad1f0..0000000 --- a/fila/v1/service_pb2.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# NO CHECKED-IN PROTOBUF GENCODE -# source: fila/v1/service.proto -# Protobuf Python Version: 6.31.1 -"""Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -_runtime_version.ValidateProtobufRuntimeVersion( - _runtime_version.Domain.PUBLIC, - 6, - 31, - 1, - '', - 'fila/v1/service.proto' -) -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from fila.v1 import messages_pb2 as fila_dot_v1_dot_messages__pb2 - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66ila/v1/service.proto\x12\x07\x66ila.v1\x1a\x16\x66ila/v1/messages.proto\"\x97\x01\n\x0e\x45nqueueRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x35\n\x07headers\x18\x02 \x03(\x0b\x32$.fila.v1.EnqueueRequest.HeadersEntry\x12\x0f\n\x07payload\x18\x03 \x01(\x0c\x1a.\n\x0cHeadersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"%\n\x0f\x45nqueueResponse\x12\x12\n\nmessage_id\x18\x01 \x01(\t\"\x1f\n\x0e\x43onsumeRequest\x12\r\n\x05queue\x18\x01 \x01(\t\"4\n\x0f\x43onsumeResponse\x12!\n\x07message\x18\x01 \x01(\x0b\x32\x10.fila.v1.Message\"/\n\nAckRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\"\r\n\x0b\x41\x63kResponse\"?\n\x0bNackRequest\x12\r\n\x05queue\x18\x01 \x01(\t\x12\x12\n\nmessage_id\x18\x02 \x01(\t\x12\r\n\x05\x65rror\x18\x03 
\x01(\t\"\x0e\n\x0cNackResponse2\xf2\x01\n\x0b\x46ilaService\x12<\n\x07\x45nqueue\x12\x17.fila.v1.EnqueueRequest\x1a\x18.fila.v1.EnqueueResponse\x12>\n\x07\x43onsume\x12\x17.fila.v1.ConsumeRequest\x1a\x18.fila.v1.ConsumeResponse0\x01\x12\x30\n\x03\x41\x63k\x12\x13.fila.v1.AckRequest\x1a\x14.fila.v1.AckResponse\x12\x33\n\x04Nack\x12\x14.fila.v1.NackRequest\x1a\x15.fila.v1.NackResponseb\x06proto3') - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fila.v1.service_pb2', _globals) -if not _descriptor._USE_C_DESCRIPTORS: - DESCRIPTOR._loaded_options = None - _globals['_ENQUEUEREQUEST_HEADERSENTRY']._loaded_options = None - _globals['_ENQUEUEREQUEST_HEADERSENTRY']._serialized_options = b'8\001' - _globals['_ENQUEUEREQUEST']._serialized_start=59 - _globals['_ENQUEUEREQUEST']._serialized_end=210 - _globals['_ENQUEUEREQUEST_HEADERSENTRY']._serialized_start=164 - _globals['_ENQUEUEREQUEST_HEADERSENTRY']._serialized_end=210 - _globals['_ENQUEUERESPONSE']._serialized_start=212 - _globals['_ENQUEUERESPONSE']._serialized_end=249 - _globals['_CONSUMEREQUEST']._serialized_start=251 - _globals['_CONSUMEREQUEST']._serialized_end=282 - _globals['_CONSUMERESPONSE']._serialized_start=284 - _globals['_CONSUMERESPONSE']._serialized_end=336 - _globals['_ACKREQUEST']._serialized_start=338 - _globals['_ACKREQUEST']._serialized_end=385 - _globals['_ACKRESPONSE']._serialized_start=387 - _globals['_ACKRESPONSE']._serialized_end=400 - _globals['_NACKREQUEST']._serialized_start=402 - _globals['_NACKREQUEST']._serialized_end=465 - _globals['_NACKRESPONSE']._serialized_start=467 - _globals['_NACKRESPONSE']._serialized_end=481 - _globals['_FILASERVICE']._serialized_start=484 - _globals['_FILASERVICE']._serialized_end=726 -# @@protoc_insertion_point(module_scope) diff --git a/fila/v1/service_pb2.pyi b/fila/v1/service_pb2.pyi deleted file mode 100644 index c6478c4..0000000 --- a/fila/v1/service_pb2.pyi +++ 
/dev/null @@ -1,69 +0,0 @@ -from fila.v1 import messages_pb2 as _messages_pb2 -from google.protobuf.internal import containers as _containers -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from collections.abc import Mapping as _Mapping -from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class EnqueueRequest(_message.Message): - __slots__ = ("queue", "headers", "payload") - class HeadersEntry(_message.Message): - __slots__ = ("key", "value") - KEY_FIELD_NUMBER: _ClassVar[int] - VALUE_FIELD_NUMBER: _ClassVar[int] - key: str - value: str - def __init__(self, key: _Optional[str] = ..., value: _Optional[str] = ...) -> None: ... - QUEUE_FIELD_NUMBER: _ClassVar[int] - HEADERS_FIELD_NUMBER: _ClassVar[int] - PAYLOAD_FIELD_NUMBER: _ClassVar[int] - queue: str - headers: _containers.ScalarMap[str, str] - payload: bytes - def __init__(self, queue: _Optional[str] = ..., headers: _Optional[_Mapping[str, str]] = ..., payload: _Optional[bytes] = ...) -> None: ... - -class EnqueueResponse(_message.Message): - __slots__ = ("message_id",) - MESSAGE_ID_FIELD_NUMBER: _ClassVar[int] - message_id: str - def __init__(self, message_id: _Optional[str] = ...) -> None: ... - -class ConsumeRequest(_message.Message): - __slots__ = ("queue",) - QUEUE_FIELD_NUMBER: _ClassVar[int] - queue: str - def __init__(self, queue: _Optional[str] = ...) -> None: ... - -class ConsumeResponse(_message.Message): - __slots__ = ("message",) - MESSAGE_FIELD_NUMBER: _ClassVar[int] - message: _messages_pb2.Message - def __init__(self, message: _Optional[_Union[_messages_pb2.Message, _Mapping]] = ...) -> None: ... - -class AckRequest(_message.Message): - __slots__ = ("queue", "message_id") - QUEUE_FIELD_NUMBER: _ClassVar[int] - MESSAGE_ID_FIELD_NUMBER: _ClassVar[int] - queue: str - message_id: str - def __init__(self, queue: _Optional[str] = ..., message_id: _Optional[str] = ...) 
-> None: ... - -class AckResponse(_message.Message): - __slots__ = () - def __init__(self) -> None: ... - -class NackRequest(_message.Message): - __slots__ = ("queue", "message_id", "error") - QUEUE_FIELD_NUMBER: _ClassVar[int] - MESSAGE_ID_FIELD_NUMBER: _ClassVar[int] - ERROR_FIELD_NUMBER: _ClassVar[int] - queue: str - message_id: str - error: str - def __init__(self, queue: _Optional[str] = ..., message_id: _Optional[str] = ..., error: _Optional[str] = ...) -> None: ... - -class NackResponse(_message.Message): - __slots__ = () - def __init__(self) -> None: ... diff --git a/fila/v1/service_pb2_grpc.py b/fila/v1/service_pb2_grpc.py deleted file mode 100644 index 663ae2a..0000000 --- a/fila/v1/service_pb2_grpc.py +++ /dev/null @@ -1,229 +0,0 @@ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" -import grpc -import warnings - -from fila.v1 import service_pb2 as fila_dot_v1_dot_service__pb2 - -GRPC_GENERATED_VERSION = '1.78.0' -GRPC_VERSION = grpc.__version__ -_version_not_supported = False - -try: - from grpc._utilities import first_version_is_lower - _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) -except ImportError: - _version_not_supported = True - -if _version_not_supported: - raise RuntimeError( - f'The grpc package installed is at version {GRPC_VERSION},' - + ' but the generated code in fila/v1/service_pb2_grpc.py depends on' - + f' grpcio>={GRPC_GENERATED_VERSION}.' - + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' - + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' - ) - - -class FilaServiceStub(object): - """Hot-path RPCs for producers and consumers. - """ - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.Enqueue = channel.unary_unary( - '/fila.v1.FilaService/Enqueue', - request_serializer=fila_dot_v1_dot_service__pb2.EnqueueRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_service__pb2.EnqueueResponse.FromString, - _registered_method=True) - self.Consume = channel.unary_stream( - '/fila.v1.FilaService/Consume', - request_serializer=fila_dot_v1_dot_service__pb2.ConsumeRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_service__pb2.ConsumeResponse.FromString, - _registered_method=True) - self.Ack = channel.unary_unary( - '/fila.v1.FilaService/Ack', - request_serializer=fila_dot_v1_dot_service__pb2.AckRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_service__pb2.AckResponse.FromString, - _registered_method=True) - self.Nack = channel.unary_unary( - '/fila.v1.FilaService/Nack', - request_serializer=fila_dot_v1_dot_service__pb2.NackRequest.SerializeToString, - response_deserializer=fila_dot_v1_dot_service__pb2.NackResponse.FromString, - _registered_method=True) - - -class FilaServiceServicer(object): - """Hot-path RPCs for producers and consumers. 
- """ - - def Enqueue(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Consume(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Ack(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def Nack(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - -def add_FilaServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - 'Enqueue': grpc.unary_unary_rpc_method_handler( - servicer.Enqueue, - request_deserializer=fila_dot_v1_dot_service__pb2.EnqueueRequest.FromString, - response_serializer=fila_dot_v1_dot_service__pb2.EnqueueResponse.SerializeToString, - ), - 'Consume': grpc.unary_stream_rpc_method_handler( - servicer.Consume, - request_deserializer=fila_dot_v1_dot_service__pb2.ConsumeRequest.FromString, - response_serializer=fila_dot_v1_dot_service__pb2.ConsumeResponse.SerializeToString, - ), - 'Ack': grpc.unary_unary_rpc_method_handler( - servicer.Ack, - request_deserializer=fila_dot_v1_dot_service__pb2.AckRequest.FromString, - response_serializer=fila_dot_v1_dot_service__pb2.AckResponse.SerializeToString, - ), - 'Nack': grpc.unary_unary_rpc_method_handler( - servicer.Nack, - request_deserializer=fila_dot_v1_dot_service__pb2.NackRequest.FromString, - 
response_serializer=fila_dot_v1_dot_service__pb2.NackResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler( - 'fila.v1.FilaService', rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - server.add_registered_method_handlers('fila.v1.FilaService', rpc_method_handlers) - - - # This class is part of an EXPERIMENTAL API. -class FilaService(object): - """Hot-path RPCs for producers and consumers. - """ - - @staticmethod - def Enqueue(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaService/Enqueue', - fila_dot_v1_dot_service__pb2.EnqueueRequest.SerializeToString, - fila_dot_v1_dot_service__pb2.EnqueueResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Consume(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_stream( - request, - target, - '/fila.v1.FilaService/Consume', - fila_dot_v1_dot_service__pb2.ConsumeRequest.SerializeToString, - fila_dot_v1_dot_service__pb2.ConsumeResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Ack(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaService/Ack', - fila_dot_v1_dot_service__pb2.AckRequest.SerializeToString, - 
fila_dot_v1_dot_service__pb2.AckResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def Nack(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/fila.v1.FilaService/Nack', - fila_dot_v1_dot_service__pb2.NackRequest.SerializeToString, - fila_dot_v1_dot_service__pb2.NackResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) diff --git a/proto/fila/v1/admin.proto b/proto/fila/v1/admin.proto deleted file mode 100644 index 886e58d..0000000 --- a/proto/fila/v1/admin.proto +++ /dev/null @@ -1,197 +0,0 @@ -syntax = "proto3"; -package fila.v1; - -// Admin RPCs for operators and the CLI. -service FilaAdmin { - rpc CreateQueue(CreateQueueRequest) returns (CreateQueueResponse); - rpc DeleteQueue(DeleteQueueRequest) returns (DeleteQueueResponse); - rpc SetConfig(SetConfigRequest) returns (SetConfigResponse); - rpc GetConfig(GetConfigRequest) returns (GetConfigResponse); - rpc ListConfig(ListConfigRequest) returns (ListConfigResponse); - rpc GetStats(GetStatsRequest) returns (GetStatsResponse); - rpc Redrive(RedriveRequest) returns (RedriveResponse); - rpc ListQueues(ListQueuesRequest) returns (ListQueuesResponse); - - // API key management. CreateApiKey bypasses auth (bootstrap); others require a valid key. - rpc CreateApiKey(CreateApiKeyRequest) returns (CreateApiKeyResponse); - rpc RevokeApiKey(RevokeApiKeyRequest) returns (RevokeApiKeyResponse); - rpc ListApiKeys(ListApiKeysRequest) returns (ListApiKeysResponse); - - // Per-key ACL management. 
- rpc SetAcl(SetAclRequest) returns (SetAclResponse); - rpc GetAcl(GetAclRequest) returns (GetAclResponse); -} - -message CreateQueueRequest { - string name = 1; - QueueConfig config = 2; -} - -message QueueConfig { - string on_enqueue_script = 1; - string on_failure_script = 2; - uint64 visibility_timeout_ms = 3; -} - -message CreateQueueResponse { - string queue_id = 1; -} - -message DeleteQueueRequest { - string queue = 1; -} - -message DeleteQueueResponse {} - -message SetConfigRequest { - string key = 1; - string value = 2; -} - -message SetConfigResponse {} - -message GetConfigRequest { - string key = 1; -} - -message GetConfigResponse { - string value = 1; -} - -message ConfigEntry { - string key = 1; - string value = 2; -} - -message ListConfigRequest { - string prefix = 1; -} - -message ListConfigResponse { - repeated ConfigEntry entries = 1; - uint32 total_count = 2; -} - -message GetStatsRequest { - string queue = 1; -} - -message PerFairnessKeyStats { - string key = 1; - uint64 pending_count = 2; - int64 current_deficit = 3; - uint32 weight = 4; -} - -message PerThrottleKeyStats { - string key = 1; - double tokens = 2; - double rate_per_second = 3; - double burst = 4; -} - -message GetStatsResponse { - uint64 depth = 1; - uint64 in_flight = 2; - uint64 active_fairness_keys = 3; - uint32 active_consumers = 4; - uint32 quantum = 5; - repeated PerFairnessKeyStats per_key_stats = 6; - repeated PerThrottleKeyStats per_throttle_stats = 7; - // Cluster fields (0 when not in cluster mode). 
- uint64 leader_node_id = 8; - uint32 replication_count = 9; -} - -message RedriveRequest { - string dlq_queue = 1; - uint64 count = 2; -} - -message RedriveResponse { - uint64 redriven = 1; -} - -message ListQueuesRequest {} - -message QueueInfo { - string name = 1; - uint64 depth = 2; - uint64 in_flight = 3; - uint32 active_consumers = 4; - uint64 leader_node_id = 5; -} - -message ListQueuesResponse { - repeated QueueInfo queues = 1; - uint32 cluster_node_count = 2; -} - -// --- API Key Management --- - -message CreateApiKeyRequest { - /// Human-readable label for the key. - string name = 1; - /// Optional Unix timestamp (milliseconds) after which the key expires. - /// 0 means no expiration. - uint64 expires_at_ms = 2; - /// When true, the key bypasses all ACL checks (superadmin). - bool is_superadmin = 3; -} - -message CreateApiKeyResponse { - /// Opaque key ID for management operations (revoke, list, set-acl). - string key_id = 1; - /// Plaintext API key. Returned once — store it securely. - string key = 2; - /// Whether this key has superadmin privileges. - bool is_superadmin = 3; -} - -message RevokeApiKeyRequest { - string key_id = 1; -} - -message RevokeApiKeyResponse {} - -message ListApiKeysRequest {} - -message ApiKeyInfo { - string key_id = 1; - string name = 2; - uint64 created_at_ms = 3; - /// 0 means no expiration. - uint64 expires_at_ms = 4; - bool is_superadmin = 5; -} - -message ListApiKeysResponse { - repeated ApiKeyInfo keys = 1; -} - -// --- ACL Management --- - -/// A single permission grant: kind (produce/consume/admin) + queue pattern. -message AclPermission { - /// One of: "produce", "consume", "admin". - string kind = 1; - /// Queue name or wildcard ("*" or "orders.*"). 
- string pattern = 2; -} - -message SetAclRequest { - string key_id = 1; - repeated AclPermission permissions = 2; -} - -message SetAclResponse {} - -message GetAclRequest { - string key_id = 1; -} - -message GetAclResponse { - string key_id = 1; - repeated AclPermission permissions = 2; - bool is_superadmin = 3; -} diff --git a/proto/fila/v1/messages.proto b/proto/fila/v1/messages.proto deleted file mode 100644 index a0709cf..0000000 --- a/proto/fila/v1/messages.proto +++ /dev/null @@ -1,28 +0,0 @@ -syntax = "proto3"; -package fila.v1; - -import "google/protobuf/timestamp.proto"; - -// Core message envelope persisted in the broker. -message Message { - string id = 1; - map headers = 2; - bytes payload = 3; - MessageMetadata metadata = 4; - MessageTimestamps timestamps = 5; -} - -// Broker-assigned scheduling metadata. -message MessageMetadata { - string fairness_key = 1; - uint32 weight = 2; - repeated string throttle_keys = 3; - uint32 attempt_count = 4; - string queue_id = 5; -} - -// Lifecycle timestamps attached to every message. -message MessageTimestamps { - google.protobuf.Timestamp enqueued_at = 1; - google.protobuf.Timestamp leased_at = 2; -} diff --git a/proto/fila/v1/service.proto b/proto/fila/v1/service.proto deleted file mode 100644 index f14fdd0..0000000 --- a/proto/fila/v1/service.proto +++ /dev/null @@ -1,45 +0,0 @@ -syntax = "proto3"; -package fila.v1; - -import "fila/v1/messages.proto"; - -// Hot-path RPCs for producers and consumers. 
-service FilaService { - rpc Enqueue(EnqueueRequest) returns (EnqueueResponse); - rpc Consume(ConsumeRequest) returns (stream ConsumeResponse); - rpc Ack(AckRequest) returns (AckResponse); - rpc Nack(NackRequest) returns (NackResponse); -} - -message EnqueueRequest { - string queue = 1; - map headers = 2; - bytes payload = 3; -} - -message EnqueueResponse { - string message_id = 1; -} - -message ConsumeRequest { - string queue = 1; -} - -message ConsumeResponse { - Message message = 1; -} - -message AckRequest { - string queue = 1; - string message_id = 2; -} - -message AckResponse {} - -message NackRequest { - string queue = 1; - string message_id = 2; - string error = 3; -} - -message NackResponse {} diff --git a/pyproject.toml b/pyproject.toml index 26a4ab3..d25e4b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,24 +4,19 @@ build-backend = "setuptools.build_meta" [project] name = "fila-python" -version = "0.1.0" +version = "0.3.0" description = "Python client SDK for the Fila message broker" readme = "README.md" -license = "AGPL-3.0-or-later" +license = {text = "AGPL-3.0-or-later"} requires-python = ">=3.10" -dependencies = [ - "grpcio>=1.60.0", - "protobuf>=4.25.0", -] +dependencies = [] [project.optional-dependencies] dev = [ - "grpcio-tools>=1.60.0", "pytest>=8.0", "pytest-asyncio>=0.23", "ruff>=0.3", "mypy>=1.8", - "mypy-protobuf>=3.5", ] [tool.setuptools.packages.find] @@ -30,10 +25,10 @@ include = ["fila*"] [tool.ruff] target-version = "py310" line-length = 100 -exclude = ["fila/v1/"] [tool.ruff.lint] select = ["E", "F", "I", "N", "UP", "B", "SIM", "TCH"] +ignore = ["SIM115"] [tool.mypy] python_version = "3.10" @@ -41,13 +36,5 @@ strict = true warn_return_any = true warn_unused_configs = true -[[tool.mypy.overrides]] -module = "fila.v1.*" -ignore_errors = true - -[[tool.mypy.overrides]] -module = "grpc.*" -ignore_missing_imports = true - [tool.pytest.ini_options] asyncio_mode = "auto" diff --git a/tests/conftest.py b/tests/conftest.py index 
3b91d60..820aa02 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,13 +12,11 @@ from pathlib import Path from typing import TYPE_CHECKING -import grpc import pytest if TYPE_CHECKING: from collections.abc import Generator -from fila.v1 import admin_pb2, admin_pb2_grpc FILA_SERVER_BIN = os.environ.get( "FILA_SERVER_BIN", @@ -169,41 +167,43 @@ def stop(self) -> None: self._process.wait() shutil.rmtree(self._data_dir, ignore_errors=True) - def _make_channel(self) -> grpc.Channel: - """Create a gRPC channel to this server (TLS-aware).""" + def create_queue(self, name: str) -> None: + """Create a queue on the test server via FIBP.""" + from fila import Client + + kwargs: dict[str, object] = {} if self.tls_paths is not None: with open(self.tls_paths["ca_cert"], "rb") as f: - ca = f.read() + kwargs["ca_cert"] = f.read() with open(self.tls_paths["client_cert"], "rb") as f: - cert = f.read() + kwargs["client_cert"] = f.read() with open(self.tls_paths["client_key"], "rb") as f: - key = f.read() - creds = grpc.ssl_channel_credentials( - root_certificates=ca, - private_key=key, - certificate_chain=cert, - ) - channel = grpc.secure_channel(self.addr, creds) - else: - channel = grpc.insecure_channel(self.addr) - + kwargs["client_key"] = f.read() if self.api_key is not None: - from fila.client import _ApiKeyInterceptor - channel = grpc.intercept_channel(channel, _ApiKeyInterceptor(self.api_key)) + kwargs["api_key"] = self.api_key + kwargs["accumulator_mode"] = __import__("fila").AccumulatorMode.DISABLED - return channel + with Client(self.addr, **kwargs) as client: # type: ignore[arg-type] + client.create_queue(name) - def create_queue(self, name: str) -> None: - """Create a queue on the test server via admin gRPC.""" - channel = self._make_channel() - stub = admin_pb2_grpc.FilaAdminStub(channel) - stub.CreateQueue( - admin_pb2.CreateQueueRequest( - name=name, - config=admin_pb2.QueueConfig(), - ) - ) - channel.close() + +def _wait_for_server(addr: str, timeout: float 
= 10.0, **client_kwargs: object) -> bool: + """Wait for the fila-server to become ready by attempting a FIBP handshake.""" + from fila import AccumulatorMode, Client + + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + with Client( + addr, + accumulator_mode=AccumulatorMode.DISABLED, + **client_kwargs, # type: ignore[arg-type] + ) as client: + client.list_queues() + return True + except Exception: + time.sleep(0.05) + return False @pytest.fixture() @@ -217,7 +217,6 @@ def server() -> Generator[TestServer, None, None]: data_dir = tempfile.mkdtemp(prefix="fila-test-") - # Write config file for the server. config_path = os.path.join(data_dir, "fila.toml") with open(config_path, "w") as f: f.write(f'[server]\nlisten_addr = "{addr}"\n') @@ -233,24 +232,11 @@ def server() -> Generator[TestServer, None, None]: ts = TestServer(addr, process, data_dir) - # Wait for server to be ready. - deadline = time.monotonic() + 10.0 - while time.monotonic() < deadline: - channel = grpc.insecure_channel(addr) - try: - stub = admin_pb2_grpc.FilaAdminStub(channel) - stub.ListQueues(admin_pb2.ListQueuesRequest()) - channel.close() - break - except grpc.RpcError: - channel.close() - time.sleep(0.05) - else: + if not _wait_for_server(addr): ts.stop() pytest.fail("fila-server did not become ready within 10s") yield ts - ts.stop() @@ -271,7 +257,6 @@ def tls_server() -> Generator[TestServer, None, None]: data_dir = tempfile.mkdtemp(prefix="fila-tls-test-") tls_paths = _generate_self_signed_certs(data_dir) - # Write config with TLS enabled. config_path = os.path.join(data_dir, "fila.toml") with open(config_path, "w") as f: f.write( @@ -295,24 +280,20 @@ def tls_server() -> Generator[TestServer, None, None]: ts = TestServer(addr, process, data_dir, tls_paths=tls_paths) - # Wait for server to be ready (use TLS channel). 
- deadline = time.monotonic() + 10.0 - while time.monotonic() < deadline: - channel = ts._make_channel() - try: - stub = admin_pb2_grpc.FilaAdminStub(channel) - stub.ListQueues(admin_pb2.ListQueuesRequest()) - channel.close() - break - except grpc.RpcError: - channel.close() - time.sleep(0.05) - else: + with open(tls_paths["ca_cert"], "rb") as f: + ca_cert = f.read() + with open(tls_paths["client_cert"], "rb") as f: + client_cert = f.read() + with open(tls_paths["client_key"], "rb") as f: + client_key = f.read() + + if not _wait_for_server( + addr, ca_cert=ca_cert, client_cert=client_cert, client_key=client_key + ): ts.stop() pytest.fail("TLS fila-server did not become ready within 10s") yield ts - ts.stop() @@ -328,7 +309,6 @@ def auth_server() -> Generator[TestServer, None, None]: data_dir = tempfile.mkdtemp(prefix="fila-auth-test-") - # Write config with bootstrap API key. config_path = os.path.join(data_dir, "fila.toml") with open(config_path, "w") as f: f.write( @@ -348,22 +328,9 @@ def auth_server() -> Generator[TestServer, None, None]: ts = TestServer(addr, process, data_dir, api_key=bootstrap_key) - # Wait for server to be ready. - deadline = time.monotonic() + 10.0 - while time.monotonic() < deadline: - channel = ts._make_channel() - try: - stub = admin_pb2_grpc.FilaAdminStub(channel) - stub.ListQueues(admin_pb2.ListQueuesRequest()) - channel.close() - break - except grpc.RpcError: - channel.close() - time.sleep(0.05) - else: + if not _wait_for_server(addr, api_key=bootstrap_key): ts.stop() pytest.fail("auth fila-server did not become ready within 10s") yield ts - ts.stop() diff --git a/tests/test_batcher.py b/tests/test_batcher.py new file mode 100644 index 0000000..231b3cc --- /dev/null +++ b/tests/test_batcher.py @@ -0,0 +1,259 @@ +"""Unit tests for the batcher module. + +These tests use mock connections and do not require a running fila-server. 
+""" + +from __future__ import annotations + +import struct +from concurrent.futures import Future +from unittest.mock import MagicMock + +import pytest + +from fila.batcher import ( + AutoAccumulator, + LingerAccumulator, + _EnqueueItem, + _flush_many, + _flush_single, +) +from fila.errors import EnqueueError, QueueNotFoundError +from fila.fibp.opcodes import ErrorCode, FrameHeader, Opcode + + +def _make_enqueue_result_body(*results: tuple[int, str]) -> bytes: + """Build an EnqueueResult frame body from (error_code, message_id) tuples.""" + buf = struct.pack("!I", len(results)) + for code, msg_id in results: + buf += struct.pack("!B", code) + encoded = msg_id.encode("utf-8") + buf += struct.pack("!H", len(encoded)) + encoded + return buf + + +def _make_mock_conn(*results: tuple[int, str]) -> MagicMock: + """Create a mock Connection whose request() returns an EnqueueResult.""" + conn = MagicMock() + body = _make_enqueue_result_body(*results) + header = FrameHeader(opcode=Opcode.ENQUEUE_RESULT, flags=0, request_id=1) + conn.request.return_value = (header, body) + return conn + + +def _make_error_conn(error_code: int, message: str) -> MagicMock: + """Create a mock Connection whose request() returns an Error frame.""" + conn = MagicMock() + # Build error frame body: u8 code, string message, map metadata (empty) + encoded_msg = message.encode("utf-8") + body = ( + struct.pack("!B", error_code) + + struct.pack("!H", len(encoded_msg)) + encoded_msg + + struct.pack("!H", 0) # empty metadata map + ) + header = FrameHeader(opcode=Opcode.ERROR, flags=0, request_id=1) + conn.request.return_value = (header, body) + return conn + + +class TestFlushSingle: + """Test the _flush_single function.""" + + def test_success(self) -> None: + conn = _make_mock_conn((ErrorCode.OK, "msg-001")) + + msg = {"queue": "q", "headers": {}, "payload": b"data"} + fut: Future[str] = Future() + req = _EnqueueItem(msg, fut) + + _flush_single(conn, req) + + assert fut.result(timeout=1.0) == "msg-001" + 
conn.request.assert_called_once() + + def test_error_frame(self) -> None: + conn = _make_error_conn(ErrorCode.QUEUE_NOT_FOUND, "queue not found") + + msg = {"queue": "missing", "headers": {}, "payload": b"data"} + fut: Future[str] = Future() + req = _EnqueueItem(msg, fut) + + _flush_single(conn, req) + + with pytest.raises(QueueNotFoundError): + fut.result(timeout=1.0) + + +class TestFlushMany: + """Test the _flush_many function.""" + + def test_all_success(self) -> None: + conn = _make_mock_conn( + (ErrorCode.OK, "id-1"), + (ErrorCode.OK, "id-2"), + ) + + items = [ + _EnqueueItem( + {"queue": "q", "headers": {}, "payload": b"a"}, + Future(), + ), + _EnqueueItem( + {"queue": "q", "headers": {}, "payload": b"b"}, + Future(), + ), + ] + + _flush_many(conn, items) + + assert items[0].future.result(timeout=1.0) == "id-1" + assert items[1].future.result(timeout=1.0) == "id-2" + + def test_mixed_results(self) -> None: + conn = _make_mock_conn( + (ErrorCode.OK, "id-1"), + (ErrorCode.QUEUE_NOT_FOUND, ""), + ) + + items = [ + _EnqueueItem( + {"queue": "q", "headers": {}, "payload": b"a"}, + Future(), + ), + _EnqueueItem( + {"queue": "missing", "headers": {}, "payload": b"b"}, + Future(), + ), + ] + + _flush_many(conn, items) + + assert items[0].future.result(timeout=1.0) == "id-1" + with pytest.raises(QueueNotFoundError): + items[1].future.result(timeout=1.0) + + def test_connection_failure_sets_all_futures(self) -> None: + conn = MagicMock() + conn.request.side_effect = ConnectionError("server unavailable") + + items = [ + _EnqueueItem( + {"queue": "q", "headers": {}, "payload": b"a"}, + Future(), + ), + _EnqueueItem( + {"queue": "q", "headers": {}, "payload": b"b"}, + Future(), + ), + ] + + _flush_many(conn, items) + + for item in items: + with pytest.raises(EnqueueError): + item.future.result(timeout=1.0) + + +class TestAutoAccumulator: + """Test the AutoAccumulator end-to-end.""" + + def test_single_message(self) -> None: + """When only one message is queued, 
AutoAccumulator sends it.""" + conn = _make_mock_conn((ErrorCode.OK, "msg-solo")) + accumulator = AutoAccumulator(conn, max_messages=100) + + msg = {"queue": "q", "headers": {}, "payload": b"solo"} + fut = accumulator.submit(msg) + result = fut.result(timeout=5.0) + + assert result == "msg-solo" + conn.request.assert_called_once() + accumulator.close() + + def test_concurrent_messages_accumulated(self) -> None: + """When multiple messages arrive concurrently, they accumulate together.""" + conn = _make_mock_conn(*[(ErrorCode.OK, f"id-{i}") for i in range(5)]) + accumulator = AutoAccumulator(conn, max_messages=100) + + futures = [] + for i in range(5): + msg = {"queue": "q", "headers": {}, "payload": f"msg-{i}".encode()} + futures.append(accumulator.submit(msg)) + + for f in futures: + result = f.result(timeout=5.0) + assert result is not None + + accumulator.close() + + def test_close_drains_pending(self) -> None: + """close() waits for pending messages to be flushed.""" + conn = _make_mock_conn((ErrorCode.OK, "drained")) + accumulator = AutoAccumulator(conn, max_messages=100) + + msg = {"queue": "q", "headers": {}, "payload": b"drain-me"} + fut = accumulator.submit(msg) + + accumulator.close() + + assert fut.result(timeout=1.0) == "drained" + + def test_update_conn(self) -> None: + """update_conn replaces the connection used for flushing.""" + old_conn = MagicMock() + new_conn = _make_mock_conn((ErrorCode.OK, "new-conn")) + + accumulator = AutoAccumulator(old_conn, max_messages=100) + accumulator.update_conn(new_conn) + + msg = {"queue": "q", "headers": {}, "payload": b"data"} + fut = accumulator.submit(msg) + result = fut.result(timeout=5.0) + + assert result == "new-conn" + accumulator.close() + + +class TestLingerAccumulator: + """Test the LingerAccumulator.""" + + def test_flushes_at_max_messages(self) -> None: + """Flush triggers when max_messages messages accumulate.""" + conn = _make_mock_conn(*[(ErrorCode.OK, f"id-{i}") for i in range(3)]) + accumulator = 
LingerAccumulator(conn, linger_ms=5000, max_messages=3) + + futures = [] + for i in range(3): + msg = {"queue": "q", "headers": {}, "payload": f"m{i}".encode()} + futures.append(accumulator.submit(msg)) + + for i, f in enumerate(futures): + result = f.result(timeout=5.0) + assert result == f"id-{i}" + + accumulator.close() + + def test_flushes_at_linger_timeout(self) -> None: + """Flush triggers after linger_ms even if max_messages is not reached.""" + conn = _make_mock_conn((ErrorCode.OK, "lingered")) + accumulator = LingerAccumulator(conn, linger_ms=50, max_messages=100) + + msg = {"queue": "q", "headers": {}, "payload": b"linger"} + fut = accumulator.submit(msg) + + result = fut.result(timeout=5.0) + assert result == "lingered" + + accumulator.close() + + def test_close_drains_pending(self) -> None: + """close() drains any pending messages.""" + conn = _make_mock_conn((ErrorCode.OK, "drained")) + accumulator = LingerAccumulator(conn, linger_ms=10000, max_messages=100) + + msg = {"queue": "q", "headers": {}, "payload": b"drain"} + fut = accumulator.submit(msg) + + accumulator.close() + + assert fut.result(timeout=1.0) == "drained" diff --git a/tests/test_client.py b/tests/test_client.py index b8e353e..1fbfae8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -17,14 +17,14 @@ def test_enqueue_consume_ack(self, server: object) -> None: assert isinstance(server, TestServer) server.create_queue("test-sync-eca") - with fila.Client(server.addr) as client: - # Enqueue a message. + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED + ) as client: headers = {"tenant": "acme"} payload = b"hello world" msg_id = client.enqueue("test-sync-eca", headers, payload) assert msg_id != "" - # Consume the message. 
stream = client.consume("test-sync-eca") msg = next(stream) @@ -32,7 +32,6 @@ def test_enqueue_consume_ack(self, server: object) -> None: assert msg.headers["tenant"] == "acme" assert msg.payload == b"hello world" - # Ack the message. client.ack("test-sync-eca", msg.id) def test_enqueue_consume_nack_redeliver(self, server: object) -> None: @@ -42,26 +41,23 @@ def test_enqueue_consume_nack_redeliver(self, server: object) -> None: assert isinstance(server, TestServer) server.create_queue("test-sync-nack") - with fila.Client(server.addr) as client: + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED + ) as client: msg_id = client.enqueue("test-sync-nack", None, b"retry-me") - # Open consume stream. stream = client.consume("test-sync-nack") - # First delivery. msg = next(stream) assert msg.id == msg_id assert msg.attempt_count == 0 - # Nack the message. client.nack("test-sync-nack", msg.id, "transient failure") - # Redelivery on the same stream. msg2 = next(stream) assert msg2.id == msg_id assert msg2.attempt_count == 1 - # Ack to clean up. client.ack("test-sync-nack", msg2.id) def test_enqueue_nonexistent_queue(self, server: object) -> None: @@ -70,7 +66,9 @@ def test_enqueue_nonexistent_queue(self, server: object) -> None: assert isinstance(server, TestServer) - with fila.Client(server.addr) as client, pytest.raises(fila.QueueNotFoundError): + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED + ) as client, pytest.raises(fila.QueueNotFoundError): client.enqueue("does-not-exist", None, b"test") @@ -86,13 +84,11 @@ async def test_async_enqueue_consume_ack(self, server: object) -> None: server.create_queue("test-async-eca") async with fila.AsyncClient(server.addr) as client: - # Enqueue a message. msg_id = await client.enqueue( "test-async-eca", {"tenant": "acme"}, b"hello async" ) assert msg_id != "" - # Consume the message. 
stream = await client.consume("test-async-eca") msg = await stream.__anext__() @@ -100,7 +96,6 @@ async def test_async_enqueue_consume_ack(self, server: object) -> None: assert msg.headers["tenant"] == "acme" assert msg.payload == b"hello async" - # Ack the message. await client.ack("test-async-eca", msg.id) @@ -128,6 +123,7 @@ def test_tls_enqueue_consume_ack(self, tls_server: object) -> None: ca_cert=ca_cert, client_cert=client_cert, client_key=client_key, + accumulator_mode=fila.AccumulatorMode.DISABLED, ) as client: msg_id = client.enqueue("test-tls", {"secure": "true"}, b"tls payload") assert msg_id != "" @@ -187,7 +183,11 @@ def test_api_key_enqueue_consume_ack(self, auth_server: object) -> None: auth_server.create_queue("test-auth") - with fila.Client(auth_server.addr, api_key=auth_server.api_key) as client: + with fila.Client( + auth_server.addr, + api_key=auth_server.api_key, + accumulator_mode=fila.AccumulatorMode.DISABLED, + ) as client: msg_id = client.enqueue("test-auth", None, b"authenticated") assert msg_id != "" @@ -201,31 +201,19 @@ def test_api_key_enqueue_consume_ack(self, auth_server: object) -> None: def test_missing_api_key_rejected(self, auth_server: object) -> None: """Requests without API key are rejected when auth is enabled.""" - import grpc - from tests.conftest import TestServer assert isinstance(auth_server, TestServer) - # Probe whether the server actually enforces API key auth. - # The dev-latest binary may predate the bootstrap_apikey feature, - # in which case unauthenticated requests succeed rather than fail. 
- with fila.Client(auth_server.addr) as probe: - try: - probe.enqueue("__auth_probe__", None, b"probe") - except fila.RPCError as e: - if e.code != grpc.StatusCode.UNAUTHENTICATED: - pytest.fail(f"unexpected RPC error during auth probe: {e.code}") - except fila.QueueNotFoundError: - pytest.skip("server does not enforce API key auth") - else: - pytest.skip("server does not enforce API key auth") - - # If we reach here, the server enforces auth. - with fila.Client(auth_server.addr) as client: - with pytest.raises(fila.RPCError) as exc_info: - client.enqueue("test-auth", None, b"no-key") - assert exc_info.value.code == grpc.StatusCode.UNAUTHENTICATED + # Attempt to connect without an API key -- the handshake should fail. + with ( + pytest.raises((fila.UnauthorizedError, fila.FilaError)), + fila.Client( + auth_server.addr, + accumulator_mode=fila.AccumulatorMode.DISABLED, + ) as client, + ): + client.enqueue("test-auth", None, b"no-key") @pytest.mark.asyncio async def test_async_api_key_enqueue(self, auth_server: object) -> None: diff --git a/tests/test_enqueue_integration.py b/tests/test_enqueue_integration.py new file mode 100644 index 0000000..10522c9 --- /dev/null +++ b/tests/test_enqueue_integration.py @@ -0,0 +1,228 @@ +"""Integration tests for enqueue_many and accumulator modes. + +These tests require a running fila-server binary. They are skipped +automatically when the server is not found (local dev). 
+""" + +from __future__ import annotations + +import pytest + +import fila + + +class TestEnqueueMany: + """Integration tests for the explicit enqueue_many method.""" + + def test_enqueue_many_multiple_messages(self, server: object) -> None: + """enqueue_many sends multiple messages in one request and returns per-message results.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-enqueue-many") + + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED + ) as client: + results = client.enqueue_many([ + ("test-enqueue-many", {"idx": "0"}, b"payload-0"), + ("test-enqueue-many", {"idx": "1"}, b"payload-1"), + ("test-enqueue-many", {"idx": "2"}, b"payload-2"), + ]) + + assert len(results) == 3 + for r in results: + assert r.is_success + assert r.message_id is not None + assert r.error is None + + ids = [r.message_id for r in results] + assert len(set(ids)) == 3 + + def test_enqueue_many_single_message(self, server: object) -> None: + """enqueue_many works with a single message.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-enqueue-many-single") + + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED + ) as client: + results = client.enqueue_many([ + ("test-enqueue-many-single", None, b"solo"), + ]) + + assert len(results) == 1 + assert results[0].is_success + assert results[0].message_id is not None + + def test_enqueue_many_consume_verify(self, server: object) -> None: + """Messages enqueued via enqueue_many can be consumed and acked.""" + from tests.conftest import TestServer + + assert isinstance(server, TestServer) + server.create_queue("test-enqueue-many-consume") + + with fila.Client( + server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED + ) as client: + results = client.enqueue_many([ + ("test-enqueue-many-consume", {"k": "v"}, b"multi-msg"), + ]) + assert 
# NOTE(review): this chunk is the tail of a unified git diff, reconstructed as
# plain Python: the "+" line prefixes and hunk headers were removed and the
# collapsed newlines restored. It spans the end of one test module (name not
# visible in this chunk -- TODO confirm, presumably tests/test_enqueue_many.py)
# and the whole of the new file tests/test_fibp.py (boundary marked below).

        # (fragment) The enclosing sync enqueue_many test begins before this
        # chunk; the line below is the tail of a statement split at the chunk
        # boundary -- presumably `assert results[0].is_success` (TODO confirm).
        results[0].is_success

        # Round-trip: the message enqueued via enqueue_many must be consumable
        # with its id, headers, and payload intact.
        stream = client.consume("test-enqueue-many-consume")
        msg = next(stream)

        assert msg.id == results[0].message_id
        assert msg.headers["k"] == "v"
        assert msg.payload == b"multi-msg"

        client.ack("test-enqueue-many-consume", msg.id)


class TestAsyncEnqueueMany:
    """Integration tests for the async enqueue_many method."""

    @pytest.mark.asyncio
    async def test_async_enqueue_many(self, server: object) -> None:
        """Async enqueue_many sends multiple messages."""
        from tests.conftest import TestServer

        # The fixture is typed as `object`; narrow it before using its API.
        assert isinstance(server, TestServer)
        server.create_queue("test-async-enqueue-many")

        async with fila.AsyncClient(server.addr) as client:
            results = await client.enqueue_many([
                ("test-async-enqueue-many", None, b"async-0"),
                ("test-async-enqueue-many", None, b"async-1"),
            ])

            # One result per submitted message, each carrying a broker id.
            assert len(results) == 2
            for r in results:
                assert r.is_success
                assert r.message_id is not None


class TestAccumulatorModes:
    """Integration tests for accumulator modes (AccumulatorMode.AUTO, Linger)."""

    def test_auto_mode_enqueue(self, server: object) -> None:
        """AUTO mode enqueues messages through the accumulator."""
        from tests.conftest import TestServer

        assert isinstance(server, TestServer)
        server.create_queue("test-auto-accum")

        with fila.Client(
            server.addr, accumulator_mode=fila.AccumulatorMode.AUTO
        ) as client:
            msg_id = client.enqueue("test-auto-accum", None, b"auto-msg")
            assert msg_id != ""

            # The accumulated message must still arrive intact on consume.
            stream = client.consume("test-auto-accum")
            msg = next(stream)
            assert msg.id == msg_id
            assert msg.payload == b"auto-msg"
            client.ack("test-auto-accum", msg.id)

    def test_auto_mode_multiple_messages(self, server: object) -> None:
        """AUTO mode handles multiple sequential enqueues."""
        from tests.conftest import TestServer

        assert isinstance(server, TestServer)
        server.create_queue("test-auto-multi")

        with fila.Client(
            server.addr, accumulator_mode=fila.AccumulatorMode.AUTO
        ) as client:
            ids = []
            for i in range(5):
                msg_id = client.enqueue(
                    "test-auto-multi", None, f"msg-{i}".encode()
                )
                assert msg_id != ""
                ids.append(msg_id)

            # All five ids must be distinct -- batching must not merge them.
            assert len(set(ids)) == 5

    def test_disabled_mode_enqueue(self, server: object) -> None:
        """DISABLED mode sends each enqueue as a direct request."""
        from tests.conftest import TestServer

        assert isinstance(server, TestServer)
        server.create_queue("test-disabled")

        with fila.Client(
            server.addr, accumulator_mode=fila.AccumulatorMode.DISABLED
        ) as client:
            msg_id = client.enqueue("test-disabled", None, b"direct")
            assert msg_id != ""

            stream = client.consume("test-disabled")
            msg = next(stream)
            assert msg.id == msg_id
            client.ack("test-disabled", msg.id)

    def test_linger_mode_enqueue(self, server: object) -> None:
        """Linger mode enqueues messages through a timer-based accumulator."""
        from tests.conftest import TestServer

        assert isinstance(server, TestServer)
        server.create_queue("test-linger")

        with fila.Client(
            server.addr,
            # Short linger window so the test flushes quickly.
            accumulator_mode=fila.Linger(linger_ms=50, max_messages=10),
        ) as client:
            msg_id = client.enqueue("test-linger", None, b"lingered")
            assert msg_id != ""

            stream = client.consume("test-linger")
            msg = next(stream)
            assert msg.id == msg_id
            assert msg.payload == b"lingered"
            client.ack("test-linger", msg.id)

    def test_default_mode_is_auto(self, server: object) -> None:
        """Client defaults to AUTO accumulator mode."""
        # NOTE(review): this only exercises enqueue under the default
        # constructor -- it does not assert which mode is active; verify
        # the default in the Client constructor if that matters.
        from tests.conftest import TestServer

        assert isinstance(server, TestServer)
        server.create_queue("test-default-mode")

        with fila.Client(server.addr) as client:
            msg_id = client.enqueue("test-default-mode", None, b"default")
            assert msg_id != ""


class TestAccumulatorModeTypes:
    """Unit tests for AccumulatorMode and Linger types (no server needed)."""

    def test_accumulator_mode_enum(self) -> None:
        """AccumulatorMode has AUTO and DISABLED variants."""
        assert fila.AccumulatorMode.AUTO is not None
        assert fila.AccumulatorMode.DISABLED is not None
        # The two variants must be distinct (hashable, unequal).
        modes = {fila.AccumulatorMode.AUTO, fila.AccumulatorMode.DISABLED}
        assert len(modes) == 2

    def test_linger_fields(self) -> None:
        """Linger stores linger_ms and max_messages."""
        linger = fila.Linger(linger_ms=100, max_messages=50)
        assert linger.linger_ms == 100
        assert linger.max_messages == 50

    def test_enqueue_result_success(self) -> None:
        """EnqueueResult.is_success returns True when message_id is set."""
        r = fila.EnqueueResult(message_id="abc", error=None)
        assert r.is_success
        assert r.message_id == "abc"
        assert r.error is None

    def test_enqueue_result_error(self) -> None:
        """EnqueueResult.is_success returns False when error is set."""
        r = fila.EnqueueResult(message_id=None, error="queue not found")
        assert not r.is_success
        assert r.message_id is None
        assert r.error == "queue not found"


# ============================================================================
# New file boundary (from the diff header):
#   diff --git a/tests/test_fibp.py b/tests/test_fibp.py
#   new file mode 100644, index 0000000..343bd11, @@ -0,0 +1,289 @@
# ============================================================================

"""Unit tests for the FIBP codec and primitives."""

from __future__ import annotations

from fila.fibp.codec import (
    decode_ack_result,
    decode_delivery,
    decode_enqueue_result,
    decode_error,
    decode_handshake_ok,
    encode_ack,
    encode_consume,
    encode_enqueue,
    encode_handshake,
    encode_nack,
)
from fila.fibp.opcodes import ErrorCode, FrameHeader, Opcode
from fila.fibp.primitives import Reader, Writer


class TestPrimitives:
    """Test Writer/Reader round-trip for all primitive types."""

    def test_u8(self) -> None:
        """u8 round-trips through the wire format."""
        w = Writer()
        w.write_u8(42)
        r = Reader(w.finish())
        assert r.read_u8() == 42

    def test_u16(self) -> None:
        """u16 round-trips through the wire format."""
        w = Writer()
        w.write_u16(1234)
        r = Reader(w.finish())
        assert r.read_u16() == 1234

    def test_u32(self) -> None:
        """u32 round-trips; value exercises the high bit."""
        w = Writer()
        w.write_u32(0xDEADBEEF)
        r = Reader(w.finish())
        assert r.read_u32() == 0xDEADBEEF

    def test_u64(self) -> None:
        """u64 round-trips; value exercises all eight bytes."""
        w = Writer()
        w.write_u64(0xDEADBEEFCAFE0001)
        r = Reader(w.finish())
        assert r.read_u64() == 0xDEADBEEFCAFE0001

    def test_i64(self) -> None:
        """Signed i64 round-trips a negative value."""
        w = Writer()
        w.write_i64(-42)
        r = Reader(w.finish())
        assert r.read_i64() == -42

    def test_f64(self) -> None:
        """f64 round-trips within float tolerance."""
        w = Writer()
        w.write_f64(3.14)
        r = Reader(w.finish())
        assert abs(r.read_f64() - 3.14) < 1e-10

    def test_bool(self) -> None:
        """Both bool values round-trip in order."""
        w = Writer()
        w.write_bool(True)
        w.write_bool(False)
        r = Reader(w.finish())
        assert r.read_bool() is True
        assert r.read_bool() is False

    def test_string(self) -> None:
        """A non-empty string round-trips."""
        w = Writer()
        w.write_string("hello")
        r = Reader(w.finish())
        assert r.read_string() == "hello"

    def test_string_empty(self) -> None:
        """The empty string round-trips (zero-length edge case)."""
        w = Writer()
        w.write_string("")
        r = Reader(w.finish())
        assert r.read_string() == ""

    def test_bytes(self) -> None:
        """Raw bytes round-trip, including a NUL byte."""
        w = Writer()
        w.write_bytes(b"\x00\x01\x02")
        r = Reader(w.finish())
        assert r.read_bytes() == b"\x00\x01\x02"

    def test_string_map(self) -> None:
        """A string->string map round-trips with all entries."""
        w = Writer()
        w.write_string_map({"a": "1", "b": "2"})
        r = Reader(w.finish())
        m = r.read_string_map()
        assert m == {"a": "1", "b": "2"}

    def test_string_list(self) -> None:
        """A string list round-trips preserving order."""
        w = Writer()
        w.write_string_list(["x", "y", "z"])
        r = Reader(w.finish())
        assert r.read_string_list() == ["x", "y", "z"]

    def test_optional_string_present(self) -> None:
        """An optional string with a value round-trips."""
        w = Writer()
        w.write_optional_string("present")
        r = Reader(w.finish())
        assert r.read_optional_string() == "present"

    def test_optional_string_absent(self) -> None:
        """An absent optional string round-trips as None."""
        w = Writer()
        w.write_optional_string(None)
        r = Reader(w.finish())
        assert r.read_optional_string() is None


class TestCodec:
    """Test encode/decode round-trips for key opcodes."""

    def test_handshake_encode(self) -> None:
        """Handshake encodes version + optional API key."""
        data = encode_handshake(1, "my-key")
        r = Reader(data)
        assert r.read_u16() == 1
        assert r.read_optional_string() == "my-key"

    def test_handshake_no_key(self) -> None:
        """Handshake with no API key encodes the key as absent."""
        data = encode_handshake(1, None)
        r = Reader(data)
        assert r.read_u16() == 1
        assert r.read_optional_string() is None

    def test_handshake_ok_decode(self) -> None:
        """HandshakeOk decodes version, node_id, and max_frame_size."""
        w = Writer()
        w.write_u16(1)  # version
        w.write_u64(42)  # node_id
        w.write_u32(16 * 1024 * 1024)  # max_frame_size
        version, node_id, mfs = decode_handshake_ok(w.finish())
        assert version == 1
        assert node_id == 42
        assert mfs == 16 * 1024 * 1024

    def test_enqueue_encode_decode(self) -> None:
        """Enqueue encodes a count followed by (queue, headers, payload) per message."""
        msgs = [
            {"queue": "q1", "headers": {"k": "v"}, "payload": b"hello"},
            {"queue": "q2", "headers": {}, "payload": b"world"},
        ]
        data = encode_enqueue(msgs)
        r = Reader(data)
        count = r.read_u32()
        assert count == 2
        # First message
        assert r.read_string() == "q1"
        assert r.read_string_map() == {"k": "v"}
        assert r.read_bytes() == b"hello"
        # Second message
        assert r.read_string() == "q2"
        assert r.read_string_map() == {}
        assert r.read_bytes() == b"world"

    def test_enqueue_result_decode(self) -> None:
        """EnqueueResult decodes per-message (error_code, message_id) pairs."""
        w = Writer()
        w.write_u32(2)  # count
        w.write_u8(ErrorCode.OK)
        w.write_string("msg-001")
        w.write_u8(ErrorCode.QUEUE_NOT_FOUND)
        w.write_string("")  # failed entries carry an empty id
        items = decode_enqueue_result(w.finish())
        assert len(items) == 2
        assert items[0].error_code == ErrorCode.OK
        assert items[0].message_id == "msg-001"
        assert items[1].error_code == ErrorCode.QUEUE_NOT_FOUND

    def test_consume_encode(self) -> None:
        """Consume encodes just the queue name."""
        data = encode_consume("my-queue")
        r = Reader(data)
        assert r.read_string() == "my-queue"

    def test_delivery_decode(self) -> None:
        """Delivery decodes every per-message field in wire order."""
        w = Writer()
        w.write_u32(1)  # count
        w.write_string("msg-123")  # msg_id
        w.write_string("test-q")  # queue
        w.write_string_map({"h": "v"})  # headers
        w.write_bytes(b"payload")  # payload
        w.write_string("fk")  # fairness_key
        w.write_u32(10)  # weight
        w.write_string_list(["tk1"])  # throttle_keys
        w.write_u32(2)  # attempt_count
        w.write_u64(1000)  # enqueued_at
        w.write_u64(2000)  # leased_at

        msgs = decode_delivery(w.finish())
        assert len(msgs) == 1
        m = msgs[0]
        assert m.message_id == "msg-123"
        assert m.queue == "test-q"
        assert m.headers == {"h": "v"}
        assert m.payload == b"payload"
        assert m.fairness_key == "fk"
        assert m.weight == 10
        assert m.throttle_keys == ["tk1"]
        assert m.attempt_count == 2
        assert m.enqueued_at == 1000
        assert m.leased_at == 2000

    def test_ack_encode(self) -> None:
        """Ack encodes a count then (queue, message_id) per entry."""
        data = encode_ack([{"queue": "q", "message_id": "id-1"}])
        r = Reader(data)
        assert r.read_u32() == 1
        assert r.read_string() == "q"
        assert r.read_string() == "id-1"

    def test_ack_result_decode(self) -> None:
        """AckResult decodes one error code per acked message."""
        w = Writer()
        w.write_u32(2)
        w.write_u8(ErrorCode.OK)
        w.write_u8(ErrorCode.MESSAGE_NOT_FOUND)
        codes = decode_ack_result(w.finish())
        assert codes == [ErrorCode.OK, ErrorCode.MESSAGE_NOT_FOUND]

    def test_nack_encode(self) -> None:
        """Nack encodes a count then (queue, message_id, error) per entry."""
        data = encode_nack([{"queue": "q", "message_id": "id-1", "error": "bad"}])
        r = Reader(data)
        assert r.read_u32() == 1
        assert r.read_string() == "q"
        assert r.read_string() == "id-1"
        assert r.read_string() == "bad"

    def test_error_decode(self) -> None:
        """Error decodes code, message, and metadata (e.g. leader redirect)."""
        w = Writer()
        w.write_u8(ErrorCode.NOT_LEADER)
        w.write_string("not leader")
        w.write_string_map({"leader_addr": "10.0.0.1:5555"})
        err = decode_error(w.finish())
        assert err.code == ErrorCode.NOT_LEADER
        assert err.message == "not leader"
        assert err.metadata["leader_addr"] == "10.0.0.1:5555"


class TestFrameHeader:
    """Test FrameHeader dataclass."""

    def test_continuation_flag(self) -> None:
        """Flag bit 0x01 marks a continuation frame."""
        h = FrameHeader(opcode=Opcode.DELIVERY, flags=0x01, request_id=1)
        assert h.is_continuation is True

    def test_no_continuation(self) -> None:
        """Zero flags means not a continuation frame."""
        h = FrameHeader(opcode=Opcode.DELIVERY, flags=0x00, request_id=1)
        assert h.is_continuation is False


class TestErrorMapping:
    """Test error code -> exception mapping."""

    def test_queue_not_found(self) -> None:
        """QUEUE_NOT_FOUND maps to QueueNotFoundError."""
        from fila.errors import QueueNotFoundError, _map_error_code
        err = _map_error_code(ErrorCode.QUEUE_NOT_FOUND, "missing")
        assert isinstance(err, QueueNotFoundError)

    def test_message_not_found(self) -> None:
        """MESSAGE_NOT_FOUND maps to MessageNotFoundError."""
        from fila.errors import MessageNotFoundError, _map_error_code
        err = _map_error_code(ErrorCode.MESSAGE_NOT_FOUND, "gone")
        assert isinstance(err, MessageNotFoundError)

    def test_unauthenticated(self) -> None:
        """UNAUTHENTICATED maps to UnauthorizedError."""
        from fila.errors import UnauthorizedError, _map_error_code
        err = _map_error_code(ErrorCode.UNAUTHENTICATED, "no key")
        assert isinstance(err, UnauthorizedError)

    def test_not_leader(self) -> None:
        """NOT_LEADER maps to NotLeaderError."""
        from fila.errors import NotLeaderError, _map_error_code
        err = _map_error_code(ErrorCode.NOT_LEADER, "redirect")
        assert isinstance(err, NotLeaderError)

    def test_internal_error(self) -> None:
        """INTERNAL_ERROR maps to the generic ProtocolError."""
        from fila.errors import ProtocolError, _map_error_code
        err = _map_error_code(ErrorCode.INTERNAL_ERROR, "boom")
        assert isinstance(err, ProtocolError)

    def test_raise_from_error_frame_not_leader(self) -> None:
        """A NOT_LEADER error frame raises NotLeaderError carrying leader_addr."""
        import pytest

        from fila.errors import NotLeaderError, _raise_from_error_frame
        from fila.fibp.codec import ErrorFrame

        err = ErrorFrame(
            code=ErrorCode.NOT_LEADER,
            message="not leader",
            metadata={"leader_addr": "10.0.0.1:5555"},
        )
        with pytest.raises(NotLeaderError) as exc_info:
            _raise_from_error_frame(err)
        # The redirect address from metadata must surface on the exception.
        assert exc_info.value.leader_addr == "10.0.0.1:5555"