Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions fila/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
AsyncFibpConnection,
FibpError,
decode_ack_nack_response,
decode_consume_message,
decode_consume_push,
decode_enqueue_response,
encode_ack,
encode_consume,
Expand Down Expand Up @@ -259,11 +259,12 @@ async def consume(self, queue: str) -> AsyncIterator[ConsumeMessage]:
except FibpError as e:
raise _map_fibp_error(e.code, e.message) from e

return self._consume_iter(q)
return self._consume_iter(q, queue)

async def _consume_iter(
self,
q: object,
queue: str,
) -> AsyncIterator[ConsumeMessage]:
import asyncio
# q is an asyncio.Queue[bytes | None]
Expand All @@ -273,23 +274,22 @@ async def _consume_iter(
if body is None:
return
try:
msg_id, queue, headers, payload, fairness_key, attempt_count = (
decode_consume_message(body)
)
messages = decode_consume_push(body)
except Exception:
_log.warning(
"failed to decode consume message; skipping frame",
"failed to decode consume push frame; skipping",
exc_info=True,
)
continue
yield ConsumeMessage(
id=msg_id,
headers=headers,
payload=payload,
fairness_key=fairness_key,
attempt_count=attempt_count,
queue=queue,
)
for msg_id, headers, payload, fairness_key, attempt_count in messages:
yield ConsumeMessage(
id=msg_id,
headers=headers,
payload=payload,
fairness_key=fairness_key,
attempt_count=attempt_count,
queue=queue,
)

async def ack(self, queue: str, msg_id: str) -> None:
"""Acknowledge a successfully processed message.
Expand Down
6 changes: 3 additions & 3 deletions fila/batcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from concurrent.futures import Future, ThreadPoolExecutor
from typing import TYPE_CHECKING

from fila.errors import _map_enqueue_error_code
from fila.errors import _map_enqueue_error_code, _map_fibp_error
from fila.fibp import (
FibpError,
decode_enqueue_response,
Expand Down Expand Up @@ -59,7 +59,7 @@ def _flush_single(
else:
item.future.set_exception(_map_enqueue_error_code(err_code, err_msg))
except FibpError as e:
item.future.set_exception(_map_enqueue_error_code(e.code, e.message))
item.future.set_exception(_map_fibp_error(e.code, e.message))
except Exception as e:
item.future.set_exception(e)

Expand Down Expand Up @@ -99,7 +99,7 @@ def _flush_queue_batch(
try:
body = conn.send_request(frame, corr_id).result()
except FibpError as e:
err = _map_enqueue_error_code(e.code, e.message)
err = _map_fibp_error(e.code, e.message)
for item in items:
item.future.set_exception(err)
return
Expand Down
29 changes: 14 additions & 15 deletions fila/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
FibpConnection,
FibpError,
decode_ack_nack_response,
decode_consume_message,
decode_consume_push,
decode_enqueue_response,
encode_ack,
encode_consume,
Expand Down Expand Up @@ -296,33 +296,32 @@ def consume(self, queue: str) -> Iterator[ConsumeMessage]:
except FibpError as e:
raise _map_fibp_error(e.code, e.message) from e

return self._consume_iter(cq)
return self._consume_iter(cq, queue)

def _consume_iter(self, cq: object) -> Iterator[ConsumeMessage]:
def _consume_iter(self, cq: object, queue: str) -> Iterator[ConsumeMessage]:
from fila.fibp import _ConsumeQueue
assert isinstance(cq, _ConsumeQueue)
while True:
body = cq.get()
if body is None:
return
try:
msg_id, queue, headers, payload, fairness_key, attempt_count = (
decode_consume_message(body)
)
messages = decode_consume_push(body)
except Exception:
_log.warning(
"failed to decode consume message; skipping frame",
"failed to decode consume push frame; skipping",
exc_info=True,
)
continue
yield ConsumeMessage(
id=msg_id,
headers=headers,
payload=payload,
fairness_key=fairness_key,
attempt_count=attempt_count,
queue=queue,
)
for msg_id, headers, payload, fairness_key, attempt_count in messages:
yield ConsumeMessage(
id=msg_id,
headers=headers,
payload=payload,
fairness_key=fairness_key,
attempt_count=attempt_count,
queue=queue,
)

def ack(self, queue: str, msg_id: str) -> None:
"""Acknowledge a successfully processed message.
Expand Down
3 changes: 3 additions & 0 deletions fila/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

from fila.fibp import (
ERR_AUTH_REQUIRED,
ERR_INTERNAL,
ERR_MESSAGE_NOT_FOUND,
ERR_PERMISSION_DENIED,
Expand Down Expand Up @@ -79,4 +80,6 @@ def _map_fibp_error(code: int, message: str) -> FilaError:
return QueueNotFoundError(message)
if code == ERR_MESSAGE_NOT_FOUND:
return MessageNotFoundError(message)
if code in (ERR_AUTH_REQUIRED, ERR_PERMISSION_DENIED):
return TransportError(code, message)
return TransportError(code, message)
138 changes: 92 additions & 46 deletions fila/fibp.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,12 @@ def encode_nack(corr_id: int, items: list[tuple[str, str, str]]) -> bytes:


def encode_auth(corr_id: int, api_key: str) -> bytes:
"""Encode an AUTH frame carrying the API key."""
return _encode_frame(0, OP_AUTH, corr_id, _encode_str(api_key))
"""Encode an AUTH frame carrying the API key.

The server expects the raw UTF-8 bytes of the key as the payload —
no u16 length prefix (unlike most string fields in this protocol).
"""
return _encode_frame(0, OP_AUTH, corr_id, api_key.encode())


def encode_admin(op: int, corr_id: int, proto_body: bytes) -> bytes:
Expand Down Expand Up @@ -221,37 +225,47 @@ def decode_enqueue_response(body: bytes) -> list[tuple[bool, str, int, str]]:
return results


def decode_consume_message(body: bytes) -> tuple[str, str, dict[str, str], bytes, str, int]:
"""Decode a single server-pushed consume frame body.
def decode_consume_push(
body: bytes,
) -> list[tuple[str, dict[str, str], bytes, str, int]]:
"""Decode a server-pushed consume frame body (batch format).

Returns ``(msg_id, queue, headers, payload, fairness_key, attempt_count)``.
Returns a list of ``(msg_id, headers, payload, fairness_key, attempt_count)``
tuples. The queue name is *not* included in the push frame — callers must
supply it from the subscribe context.

The consume push wire format is::
The server wire format is::

msg_id_len:u16 | msg_id
queue_len:u16 | queue
fairness_key_len:u16 | fairness_key
attempt_count:u32
header_count:u8 | (key_len:u16 key val_len:u16 val)...
payload_len:u32 | payload
msg_count:u16
for each message:
msg_id_len:u16 | msg_id
            fairness_key_len:u16 | fairness_key
attempt_count:u32
header_count:u8 | (key_len:u16 key val_len:u16 val)...
payload_len:u32 | payload
"""
offset = 0
msg_id, offset = _decode_str(body, offset)
queue, offset = _decode_str(body, offset)
fairness_key, offset = _decode_str(body, offset)
(attempt_count,) = struct.unpack_from(">I", body, offset)
offset += 4
(header_count,) = struct.unpack_from(">B", body, offset)
offset += 1
headers: dict[str, str] = {}
for _ in range(header_count):
k, offset = _decode_str(body, offset)
v, offset = _decode_str(body, offset)
headers[k] = v
(payload_len,) = struct.unpack_from(">I", body, offset)
offset += 4
payload = body[offset: offset + payload_len]
return msg_id, queue, headers, payload, fairness_key, attempt_count
(count,) = struct.unpack_from(">H", body, offset)
offset += 2
results: list[tuple[str, dict[str, str], bytes, str, int]] = []
for _ in range(count):
msg_id, offset = _decode_str(body, offset)
fairness_key, offset = _decode_str(body, offset)
(attempt_count,) = struct.unpack_from(">I", body, offset)
offset += 4
(header_count,) = struct.unpack_from(">B", body, offset)
offset += 1
headers: dict[str, str] = {}
for _ in range(header_count):
k, offset = _decode_str(body, offset)
v, offset = _decode_str(body, offset)
headers[k] = v
(payload_len,) = struct.unpack_from(">I", body, offset)
offset += 4
payload = body[offset: offset + payload_len]
offset += payload_len
results.append((msg_id, headers, payload, fairness_key, attempt_count))
return results


def decode_ack_nack_response(body: bytes) -> list[tuple[bool, int, str]]:
Expand All @@ -276,10 +290,28 @@ def decode_ack_nack_response(body: bytes) -> list[tuple[bool, int, str]]:


def decode_error_frame(body: bytes) -> tuple[int, str]:
"""Decode a 0xFE ERROR frame body. Returns ``(error_code, message)``."""
(code,) = struct.unpack_from(">H", body, 0)
msg, _ = _decode_str(body, 2)
return code, msg
"""Decode a 0xFE ERROR frame body. Returns ``(error_code, message)``.

The server encodes error frames as raw UTF-8 message bytes with no code
prefix. This function infers the error code from the message content so
that callers can perform type-safe error handling.
"""
msg = body.decode(errors="replace")
# Infer the error code from well-known message prefixes.
lower = msg.lower()
if "queue" in lower and "not found" in lower:
return ERR_QUEUE_NOT_FOUND, msg
if "message" in lower and "not found" in lower:
return ERR_MESSAGE_NOT_FOUND, msg
if "permission denied" in lower or "does not have" in lower:
return ERR_PERMISSION_DENIED, msg
if (
"authentication required" in lower
or "invalid or missing api key" in lower
or "auth" in lower
):
return ERR_AUTH_REQUIRED, msg
return ERR_INTERNAL, msg


# ------------------------------------------------------------------
Expand Down Expand Up @@ -416,10 +448,18 @@ def send_request(self, frame: bytes, corr_id: int) -> Future[bytes]:
return fut

def open_consume_stream(self, frame: bytes, corr_id: int) -> _ConsumeQueue:
"""Register a consume queue, send *frame*, and return the queue."""
"""Register a consume queue, send *frame*, and return the queue.

The server sends push frames with correlation_id=0 (FLAG_STREAM set),
so the queue is registered under both the original corr_id (to absorb
the initial stream-accepted ack) and 0 (to receive pushed messages).
Only one consume stream per connection is supported.
"""
cq = _ConsumeQueue()
with self._lock:
self._consume_queues[corr_id] = cq
# Push frames always arrive with corr_id=0.
self._consume_queues[0] = cq
Copy link
Copy Markdown

@cubic-dev-ai cubic-dev-ai bot Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: Enforce the single-consume-stream constraint when binding corr_id=0; otherwise a second consume() call silently steals all server-push frames from the first stream.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At fila/fibp.py, line 457:

<comment>Enforce the single-consume-stream constraint when binding `corr_id=0`; otherwise a second `consume()` call silently steals all server-push frames from the first stream.</comment>

<file context>
@@ -416,10 +443,18 @@ def send_request(self, frame: bytes, corr_id: int) -> Future[bytes]:
         with self._lock:
             self._consume_queues[corr_id] = cq
+            # Push frames always arrive with corr_id=0.
+            self._consume_queues[0] = cq
         with self._send_lock:
             self._sock.sendall(frame)
</file context>
Fix with Cubic

with self._send_lock:
self._sock.sendall(frame)
return cq
Expand Down Expand Up @@ -487,15 +527,12 @@ def _dispatch(self, flags: int, op: int, corr_id: int, body: bytes) -> None:
# Resolve a pending future.
with self._lock:
fut: Future[bytes] | None = self._pending.pop(corr_id, None)
# Also check if this is the "end of consume stream" signal
# (op == OP_CONSUME response with no push flag).
cq = self._consume_queues.get(corr_id)

if cq is not None and op == OP_CONSUME:
# Server closed the consume stream.
cq.close()
with self._lock:
self._consume_queues.pop(corr_id, None)
# A non-push OP_CONSUME frame with an empty body is the server's
# "stream accepted" acknowledgment. The consume queue was already
# registered under corr_id=0 in open_consume_stream, so there is
# nothing to do here — just discard the ack frame.
if op == OP_CONSUME and not body:
return

if fut is not None and not fut.done():
Expand Down Expand Up @@ -612,11 +649,19 @@ async def send_request(self, frame: bytes, corr_id: int) -> bytes:
async def open_consume_stream(
self, frame: bytes, corr_id: int
) -> asyncio.Queue[bytes | None]:
"""Send *frame* and return a queue that receives pushed bodies."""
"""Send *frame* and return a queue that receives pushed bodies.

The server sends push frames with correlation_id=0 (FLAG_STREAM set),
so the queue is registered under both the original corr_id (to absorb
the initial stream-accepted ack) and 0 (to receive pushed messages).
Only one consume stream per connection is supported.
"""
assert self._write_lock is not None
assert self._writer is not None
q: asyncio.Queue[bytes | None] = asyncio.Queue()
self._consume_queues[corr_id] = q
# Push frames always arrive with corr_id=0.
self._consume_queues[0] = q
async with self._write_lock:
self._writer.write(frame)
await self._writer.drain()
Expand Down Expand Up @@ -655,10 +700,11 @@ def _dispatch(self, flags: int, op: int, corr_id: int, body: bytes) -> None:
self._wake_all(FibpError(0, "server sent GOAWAY"))
return

# End of consume stream (server sends a non-push CONSUME frame to close).
if op == OP_CONSUME and corr_id in self._consume_queues:
q = self._consume_queues.pop(corr_id)
q.put_nowait(None)
# A non-push OP_CONSUME frame with an empty body is the server's
# "stream accepted" acknowledgment. The consume queue was already
# registered under corr_id=0 in open_consume_stream, so there is
# nothing to do here — just discard the ack frame.
if op == OP_CONSUME and not body:
return

fut = self._pending.pop(corr_id, None)
Expand Down
Loading
Loading