diff --git a/lib/fila/batcher.rb b/lib/fila/batcher.rb index d75fd89..3f09509 100644 --- a/lib/fila/batcher.rb +++ b/lib/fila/batcher.rb @@ -127,7 +127,7 @@ def drain_nonblocking(batch) # Flush a batch of items via the FIBP transport. # Groups items by queue to produce one frame per queue. - def flush_batch(items) # rubocop:disable Metrics/AbcSize + def flush_batch(items) # Group by queue, preserving per-item result queues groups = items.each_with_index.group_by { |item, _| item.message[:queue] } @@ -145,7 +145,9 @@ def flush_batch(items) # rubocop:disable Metrics/AbcSize rescue Transport::ConnectionClosed => e broadcast_error(items, RPCError.new(0, "connection closed: #{e.message}")) rescue StandardError => e - broadcast_error(items, Fila::Error.new(e.message)) + # Re-broadcast the original exception so callers see the specific type + # (e.g. QueueNotFoundError, RPCError) rather than a generic Fila::Error. + broadcast_error(items, e) end # Convert an EnqueueResult into a String (message_id) or Exception. diff --git a/lib/fila/client.rb b/lib/fila/client.rb index 6a731c1..f305150 100644 --- a/lib/fila/client.rb +++ b/lib/fila/client.rb @@ -251,7 +251,7 @@ def consume_stream(queue, &block) when Transport::ConnectionClosed then break when Exception then raise frame when String - msg = Codec.decode_consume_push(frame) + msg = Codec.decode_consume_push(frame, queue_name: queue) block.call(msg) if msg end end diff --git a/lib/fila/codec.rb b/lib/fila/codec.rb index fb24b38..26320b6 100644 --- a/lib/fila/codec.rb +++ b/lib/fila/codec.rb @@ -9,14 +9,93 @@ module Fila module Codec # rubocop:disable Metrics/ModuleLength module_function + # ----------------------------------------------------------------------- + # Low-level read/write primitives + # + # These are listed first so that module_function makes them available as + # both public module methods (Codec.read_u16 etc.) and private instance + # methods. All higher-level encode/decode methods rely on them via + # implicit self, which only works when they are module functions too. + # ----------------------------------------------------------------------- + + # @param str [String] + # @return [String] binary fragment (len:u16BE | utf8-bytes) + def encode_str16(str) + bytes = str.encode('UTF-8').b + [bytes.bytesize].pack('n') + bytes + end + + # @param buf [String] binary buffer + # @param pos [Integer] byte offset + # @return [Array(String, Integer)] decoded string and new offset + def read_str16(buf, pos) + len, pos = read_u16(buf, pos) + [buf.byteslice(pos, len).force_encoding('UTF-8'), pos + len] + end + + # @param buf [String] binary buffer + # @param pos [Integer] byte offset + # @return [Array(Hash, Integer)] decoded headers hash and new offset + def read_headers(buf, pos) + count, pos = read_u8(buf, pos) + headers = {} + count.times do + key, pos = read_str16(buf, pos) + val, pos = read_str16(buf, pos) + headers[key] = val + end + [headers, pos] + end + + # @param buf [String] binary buffer + # @param pos [Integer] byte offset + # @return [Array(Integer, Integer)] decoded u8 value and new offset + def read_u8(buf, pos) + [buf.getbyte(pos), pos + 1] + end + + # @param buf [String] binary buffer + # @param pos [Integer] byte offset + # @return [Array(Integer, Integer)] decoded big-endian u16 value and new offset + def read_u16(buf, pos) + [buf.byteslice(pos, 2).unpack1('n'), pos + 2] + end + + # @param buf [String] binary buffer + # @param pos [Integer] byte offset + # @return [Array(Integer, Integer)] decoded big-endian u32 value and new offset + def read_u32(buf, pos) + [buf.byteslice(pos, 4).unpack1('N'), pos + 4] + end + + # ----------------------------------------------------------------------- + # Single-message encoder + # + # Wire format: header_count:u8 | + # headers: (key_len:u16BE+key, val_len:u16BE+val)* | + # payload_len:u32BE | payload + # ----------------------------------------------------------------------- + + # @param msg [Hash] with :payload and optional :headers + # @return [String] binary fragment + def encode_message(msg) + headers = msg[:headers] || {} + buf = [headers.size].pack('C') + headers.each do |key, val| + buf += encode_str16(key.to_s) + buf += encode_str16(val.to_s) + end + payload_b = (msg[:payload] || '').b + buf += [payload_b.bytesize].pack('N') + payload_b + buf + end + # ----------------------------------------------------------------------- # Enqueue request # # queue_len:u16BE | queue:utf8 # msg_count:u16BE - # messages... (each: header_count:u8 | - # headers: (key_len:u16BE+key, val_len:u16BE+val)* | - # payload_len:u32BE | payload) + # messages... (each encoded by encode_message) # ----------------------------------------------------------------------- # @param queue [String] @@ -72,21 +151,25 @@ def encode_consume(queue, initial_credits: 256) # # Frame payload: msg_count:u16BE | messages... # Each message: msg_id_len:u16BE+msg_id | fk_len:u16BE+fk | - # attempt_count:u32BE | queue_id_len:u16BE+queue_id | - # header_count:u8 | headers | payload_len:u32BE | payload + # attempt_count:u32BE | + # header_count:u8 | headers: (key_len:u16BE+key, val_len:u16BE+val)* | + # payload_len:u32BE | payload + # + # Note: the queue name is NOT included in the wire frame; it is passed in + # by the caller (from the original consume request) via +queue_name+. # # @param payload [String] raw binary frame payload + # @param queue_name [String] name of the queue being consumed # @return [ConsumeMessage, nil] the first message in the frame - def decode_consume_push(payload) + def decode_consume_push(payload, queue_name: '') pos = 0 - _msg_count, pos = read_u16(payload, pos) - msg_id, pos = read_str16(payload, pos) - fairness_key, pos = read_str16(payload, pos) - attempt_count, pos = read_u32(payload, pos) - queue_id, pos = read_str16(payload, pos) - headers, pos = read_headers(payload, pos) - pay_len, pos = read_u32(payload, pos) - body = payload.byteslice(pos, pay_len) + _msg_count, pos = read_u16(payload, pos) + msg_id, pos = read_str16(payload, pos) + fairness_key, pos = read_str16(payload, pos) + attempt_count, pos = read_u32(payload, pos) + headers, pos = read_headers(payload, pos) + pay_len, pos = read_u32(payload, pos) + body = payload.byteslice(pos, pay_len) ConsumeMessage.new( id: msg_id, @@ -94,7 +177,7 @@ def decode_consume_push(payload) payload: body, fairness_key: fairness_key, attempt_count: attempt_count, - queue: queue_id + queue: queue_name ) end @@ -159,53 +242,11 @@ def encode_nack(items) end # Decode a nack response (same shape as ack response). - alias decode_nack_response decode_ack_response - - private - - def encode_message(msg) - headers = msg[:headers] || {} - buf = [headers.size].pack('C') - headers.each do |key, val| - buf += encode_str16(key.to_s) - buf += encode_str16(val.to_s) - end - payload_b = (msg[:payload] || '').b - buf += [payload_b.bytesize].pack('N') + payload_b - buf - end - - def encode_str16(str) - bytes = str.encode('UTF-8').b - [bytes.bytesize].pack('n') + bytes - end - - def read_str16(buf, pos) - len, pos = read_u16(buf, pos) - [buf.byteslice(pos, len).force_encoding('UTF-8'), pos + len] - end - - def read_headers(buf, pos) - count, pos = read_u8(buf, pos) - headers = {} - count.times do - key, pos = read_str16(buf, pos) - val, pos = read_str16(buf, pos) - headers[key] = val - end - [headers, pos] - end - - def read_u8(buf, pos) - [buf.getbyte(pos), pos + 1] - end - - def read_u16(buf, pos) - [buf.byteslice(pos, 2).unpack1('n'), pos + 2] - end - - def read_u32(buf, pos) - [buf.byteslice(pos, 4).unpack1('N'), pos + 4] + # + # module_function does not propagate to aliases, so we define this + # explicitly to ensure it is callable as Codec.decode_nack_response. + def decode_nack_response(payload) + decode_ack_response(payload) end end end diff --git a/lib/fila/transport.rb b/lib/fila/transport.rb index 1f29944..fe63eb5 100644 --- a/lib/fila/transport.rb +++ b/lib/fila/transport.rb @@ -100,25 +100,43 @@ def request(opcode, payload) # Register a push queue for consume-stream server-push frames. # Returns corr_id used to issue the consume request. # + # The server sends two kinds of frames after a consume request: + # 1. An ACK frame (flags=0, corr_id=, empty payload) + # confirming the consume subscription was registered. + # 2. Push frames (FLAG_SERVER_PUSH set, corr_id=0) carrying messages. + # + # We use a one-shot queue for the ACK (to block until the subscription is + # confirmed) and register +push_q+ at corr_id=0 for ongoing push frames. + # # @param payload [String] consume request payload # @param push_q [Queue] messages pushed here as they arrive # @return [Integer] corr_id def start_consume(payload, push_q) corr_id = next_corr_id + ack_q = Queue.new @mutex.synchronize do raise ConnectionClosed, 'connection is closed' if @closed - @pending[corr_id] = push_q + @pending[corr_id] = ack_q # consume ACK routed here (one-shot) + @pending[0] = push_q # server-push frames carry corr_id=0 end write_frame(OP_CONSUME, corr_id, payload) + + # Wait for the ACK so we know the subscription is active before returning. + outcome = ack_q.pop + raise outcome if outcome.is_a?(Exception) + corr_id end # Remove the consume push queue and stop dispatching to it. def stop_consume(corr_id) - @mutex.synchronize { @pending.delete(corr_id) } + @mutex.synchronize do + @pending.delete(corr_id) + @pending.delete(0) + end end # Close the connection. @@ -169,8 +187,8 @@ def perform_handshake end def send_auth - key_bytes = @api_key.encode('UTF-8').b - payload = [key_bytes.bytesize].pack('n') + key_bytes + # The FIBP AUTH frame payload is the raw API key bytes — no length prefix. + payload = @api_key.encode('UTF-8').b request(OP_AUTH, payload) end @@ -216,16 +234,36 @@ def dispatch_frame(frame) dest.push(result) end + # Parse an OP_ERROR frame payload. + # + # FIBP OP_ERROR frames carry a plain UTF-8 message (no numeric code + # prefix). We map well-known error messages to typed Ruby exceptions so + # callers can rescue specific error classes. def parse_error_frame(payload) - err_code = payload.byteslice(0, 2).unpack1('n') - msg_len = payload.byteslice(2, 2).unpack1('n') - msg = payload.byteslice(4, msg_len).force_encoding('UTF-8') - case err_code - when ERR_QUEUE_NOT_FOUND then QueueNotFoundError.new(msg) - when ERR_MESSAGE_NOT_FOUND then MessageNotFoundError.new(msg) - when ERR_UNAUTHENTICATED then RPCError.new(ERR_UNAUTHENTICATED, msg) - else RPCError.new(err_code, msg) - end + msg = payload.force_encoding('UTF-8') + error_from_message(msg) + end + + # Map a plain-text FIBP error message to the appropriate Ruby exception. + def error_from_message(msg) + return RPCError.new(ERR_UNAUTHENTICATED, msg) if auth_error?(msg) + return QueueNotFoundError.new(msg) if queue_not_found_error?(msg) + return MessageNotFoundError.new(msg) if message_not_found_error?(msg) + + RPCError.new(0, msg) + end + + def auth_error?(msg) + msg.include?('authentication') || msg.include?('unauthenticated') || + msg.include?('api key') || msg.include?('OP_AUTH') + end + + def queue_not_found_error?(msg) + msg.include?('queue not found') || msg.include?('queue does not exist') + end + + def message_not_found_error?(msg) + msg.include?('message not found') || msg.include?('lease not found') end def write_frame(opcode, corr_id, payload) @@ -254,9 +292,12 @@ def read_raw(num_bytes) buf end + # corr_id=0 is permanently reserved for server-push frames. Regular + # request IDs cycle through 1..0xFFFFFFFF so they never collide. def next_corr_id @mutex.synchronize do - @corr_seq = (@corr_seq + 1) & 0xFFFFFFFF + @corr_seq += 1 + @corr_seq = 1 if @corr_seq > 0xFFFFFFFF @corr_seq end end diff --git a/test/test_helper.rb b/test/test_helper.rb index d7afa98..f53ecf7 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -29,6 +29,20 @@ def self.find_free_port # @return [Hash] server info with :addr, :host, :port, :pid, :data_dir # and optional :tls_config, :bootstrap_apikey def self.start(tls_config: nil, bootstrap_apikey: nil) + # Retry up to 3 times to handle the TOCTOU race between find_free_port + # and the server process binding the port. + 3.times do |attempt| + result = try_start(tls_config: tls_config, bootstrap_apikey: bootstrap_apikey) + return result if result + + raise "fila-server failed to bind after #{attempt + 1} attempt(s)" if attempt == 2 + end + end + + # Attempt to start a fila-server instance once. Returns the server info + # hash on success, or nil if the port was already in use (retry-able). + # Raises on any other failure. + def self.try_start(tls_config: nil, bootstrap_apikey: nil) port = find_free_port addr = "127.0.0.1:#{port}" @@ -73,23 +87,39 @@ def self.start(tls_config: nil, bootstrap_apikey: nil) wait_for_ready(server_info, stderr_path, toml) server_info + rescue RuntimeError => e + # If the port was already in use the server exits immediately and stderr + # contains "Address already in use". Clean up and signal the caller to + # retry with a different port. + if e.message.include?('Address already in use') + FileUtils.rm_rf(data_dir) + return nil + end + + raise end + FIBP_HANDSHAKE = "FIBP\x01\x00".b.freeze + def self.wait_for_ready(server_info, stderr_path, toml) deadline = Time.now + 10 ready = false while Time.now < deadline begin - # Use a plain TCP connect to check the port is accepting connections. - # A full FIBP handshake would fail against a non-FIBP server (e.g. old - # gRPC binary), masking real startup failures. - sock = TCPSocket.new(server_info[:host], server_info[:port]) - sock.close - ready = true - break - rescue SystemCallError - sleep 0.05 + # Perform a full FIBP handshake probe so we only return when the server + # is actually ready to accept FIBP connections, not just when the TCP + # port is open. A plain TCP connect can succeed before the server has + # finished its initialization, causing the first real request to fail. + fibp_ready = probe_fibp(server_info[:host], server_info[:port], + server_info[:tls_config]) + if fibp_ready + ready = true + break + end + rescue SystemCallError, IOError, OpenSSL::SSL::SSLError + # not up yet — fall through to sleep end + sleep 0.05 end return if ready @@ -105,6 +135,51 @@ def self.wait_for_ready(server_info, stderr_path, toml) raise "fila-server failed to start within 10s on #{server_info[:addr]}\nConfig:\n#{toml}\nStderr:\n#{stderr_output}" end + # Attempt a single FIBP handshake probe. Returns true if the server echoes + # the 6-byte handshake back within the timeout. Raises SystemCallError if + # the port is not yet accepting connections. + # + # Uses a blocking read with IO.select so a single probe attempt waits up to + # PROBE_TIMEOUT_S seconds rather than returning false immediately and forcing + # the caller to reconnect (which can leave half-open connections). + PROBE_TIMEOUT_S = 0.5 + + def self.probe_fibp(host, port, tls_config) + tcp = TCPSocket.new(host, port) + tcp.setsockopt(Socket::IPPROTO_TCP, Socket::TCP_NODELAY, 1) + sock = tcp + + if tls_config + ctx = OpenSSL::SSL::SSLContext.new + # Don't verify the server cert in the readiness probe — we only need to + # confirm the server is accepting FIBP connections. + ctx.set_params(verify_mode: OpenSSL::SSL::VERIFY_NONE) + + # For mTLS servers the server requires a client cert; supply one when + # available so the TLS handshake succeeds. + if tls_config[:client_cert_path] && tls_config[:client_key_path] + ctx.cert = OpenSSL::X509::Certificate.new(File.read(tls_config[:client_cert_path])) + ctx.key = OpenSSL::PKey::RSA.new(File.read(tls_config[:client_key_path])) + end + + ssl = OpenSSL::SSL::SSLSocket.new(tcp, ctx) + ssl.hostname = host + ssl.connect + sock = ssl + end + + sock.write(FIBP_HANDSHAKE) + + # Wait up to PROBE_TIMEOUT_S for the server to echo back the handshake. + return false unless sock.wait_readable(PROBE_TIMEOUT_S) + + echo = sock.read(6) + echo == FIBP_HANDSHAKE + ensure + sock&.close rescue nil # rubocop:disable Style/RescueModifier + tcp.close rescue nil if sock != tcp # rubocop:disable Style/RescueModifier + end + def self.stop(server) Process.kill('TERM', server[:pid]) Process.wait(server[:pid]) @@ -137,17 +212,71 @@ def self.admin_transport(server) end # Send a CreateQueue admin frame via FIBP. + # The payload is a protobuf-encoded CreateQueueRequest { name: }. OP_CREATE_QUEUE = 0x10 def self.create_queue(server, name) transport = admin_transport(server) - name_b = name.encode('UTF-8').b - payload = [name_b.bytesize].pack('n') + name_b + - [0].pack('n') # config_count: 0 key-value pairs + payload = proto_encode_create_queue(name) transport.request(OP_CREATE_QUEUE, payload) rescue StandardError => e raise "create_queue #{name.inspect} failed: #{e.message}" ensure transport&.close end + + # Hand-encode a CreateQueueRequest protobuf message. + # + # CreateQueueRequest { string name = 1; QueueConfig config = 2; } + # + # We only set field 1 (name). For strings ≤ 127 bytes the varint length + # fits in one byte, which covers all queue names used in tests. + # + # Proto3 wire format for a string field: + # tag: (field_number << 3) | wire_type → field 1, wire type 2 → 0x0a + # len: varint-encoded byte length of the string + # data: UTF-8 bytes + def self.proto_encode_create_queue(name) + name_b = name.encode('UTF-8').b + raise ArgumentError, "queue name too long (#{name_b.bytesize} bytes)" if name_b.bytesize > 127 + + "\x0a".b + [name_b.bytesize].pack('C') + name_b + end + + # Probe whether the running server binary supports TLS. Some bleeding-edge + # builds panic on TLS startup when the Rustls crypto provider is not + # installed. Returns true if a TLS-configured server starts successfully. + def self.tls_supported? + return false unless FILA_SERVER_AVAILABLE + + cert_dir = Dir.mktmpdir('fila-tls-probe-') + # Generate a minimal self-signed server cert for the probe. + key = OpenSSL::PKey::RSA.new(2048) + cert = OpenSSL::X509::Certificate.new + cert.version = 2 + cert.serial = 1 + cert.subject = OpenSSL::X509::Name.parse('/CN=fila-tls-probe') + cert.issuer = cert.subject + cert.public_key = key.public_key + cert.not_before = Time.now - 60 + cert.not_after = Time.now + 3600 + cert.sign(key, OpenSSL::Digest.new('SHA256')) + + cert_path = File.join(cert_dir, 'server.crt') + key_path = File.join(cert_dir, 'server.key') + File.write(cert_path, cert.to_pem) + File.write(key_path, key.to_pem) + + server = start(tls_config: { server_cert_path: cert_path, server_key_path: key_path }) + stop(server) + true + rescue StandardError + false + ensure + FileUtils.rm_rf(cert_dir) if cert_dir + end end + +# Probe TLS support once at load time so each TLS test can guard itself with: +# skip 'TLS not supported by this server binary' unless FILA_TLS_AVAILABLE +FILA_TLS_AVAILABLE = TestServerHelper.tls_supported? diff --git a/test/test_tls_auth.rb b/test/test_tls_auth.rb index a008387..458f951 100644 --- a/test/test_tls_auth.rb +++ b/test/test_tls_auth.rb @@ -3,7 +3,7 @@ require 'test_helper' require 'openssl' -return unless FILA_SERVER_AVAILABLE +return unless FILA_SERVER_AVAILABLE && FILA_TLS_AVAILABLE # Helper to generate self-signed CA and server/client certificates for tests. module CertHelper