diff --git a/build.gradle b/build.gradle index 655b18a..bc50d8a 100644 --- a/build.gradle +++ b/build.gradle @@ -3,12 +3,11 @@ plugins { id 'maven-publish' id 'signing' id 'io.github.gradle-nexus.publish-plugin' version '2.0.0' - id 'com.google.protobuf' version '0.9.4' id 'com.diffplug.spotless' version '7.0.2' } group = 'dev.faisca' -version = '0.1.0' +version = '0.3.0' java { sourceCompatibility = JavaVersion.VERSION_17 @@ -21,47 +20,12 @@ repositories { mavenCentral() } -def grpcVersion = '1.71.0' -def protobufVersion = '4.29.3' - dependencies { - api "io.grpc:grpc-stub:${grpcVersion}" - api "io.grpc:grpc-protobuf:${grpcVersion}" - - implementation "io.grpc:grpc-netty-shaded:${grpcVersion}" - implementation "com.google.protobuf:protobuf-java:${protobufVersion}" - - compileOnly 'org.apache.tomcat:annotations-api:6.0.53' - testImplementation platform('org.junit:junit-bom:5.11.4') testImplementation 'org.junit.jupiter:junit-jupiter' testRuntimeOnly 'org.junit.platform:junit-platform-launcher' } -protobuf { - protoc { - artifact = "com.google.protobuf:protoc:${protobufVersion}" - } - plugins { - grpc { - artifact = "io.grpc:protoc-gen-grpc-java:${grpcVersion}" - } - } - generateProtoTasks { - all()*.plugins { - grpc {} - } - } -} - -sourceSets { - main { - proto { - srcDir 'proto' - } - } -} - test { useJUnitPlatform() testLogging { diff --git a/proto/fila/v1/admin.proto b/proto/fila/v1/admin.proto deleted file mode 100644 index 886e58d..0000000 --- a/proto/fila/v1/admin.proto +++ /dev/null @@ -1,197 +0,0 @@ -syntax = "proto3"; -package fila.v1; - -// Admin RPCs for operators and the CLI. -service FilaAdmin { - rpc CreateQueue(CreateQueueRequest) returns (CreateQueueResponse); - rpc DeleteQueue(DeleteQueueRequest) returns (DeleteQueueResponse); - rpc SetConfig(SetConfigRequest) returns (SetConfigResponse); - rpc GetConfig(GetConfigRequest) returns (GetConfigResponse); - rpc ListConfig(ListConfigRequest) returns (ListConfigResponse); - rpc GetStats(GetStatsRequest) returns (GetStatsResponse); - rpc Redrive(RedriveRequest) returns (RedriveResponse); - rpc ListQueues(ListQueuesRequest) returns (ListQueuesResponse); - - // API key management. CreateApiKey bypasses auth (bootstrap); others require a valid key. - rpc CreateApiKey(CreateApiKeyRequest) returns (CreateApiKeyResponse); - rpc RevokeApiKey(RevokeApiKeyRequest) returns (RevokeApiKeyResponse); - rpc ListApiKeys(ListApiKeysRequest) returns (ListApiKeysResponse); - - // Per-key ACL management. - rpc SetAcl(SetAclRequest) returns (SetAclResponse); - rpc GetAcl(GetAclRequest) returns (GetAclResponse); -} - -message CreateQueueRequest { - string name = 1; - QueueConfig config = 2; -} - -message QueueConfig { - string on_enqueue_script = 1; - string on_failure_script = 2; - uint64 visibility_timeout_ms = 3; -} - -message CreateQueueResponse { - string queue_id = 1; -} - -message DeleteQueueRequest { - string queue = 1; -} - -message DeleteQueueResponse {} - -message SetConfigRequest { - string key = 1; - string value = 2; -} - -message SetConfigResponse {} - -message GetConfigRequest { - string key = 1; -} - -message GetConfigResponse { - string value = 1; -} - -message ConfigEntry { - string key = 1; - string value = 2; -} - -message ListConfigRequest { - string prefix = 1; -} - -message ListConfigResponse { - repeated ConfigEntry entries = 1; - uint32 total_count = 2; -} - -message GetStatsRequest { - string queue = 1; -} - -message PerFairnessKeyStats { - string key = 1; - uint64 pending_count = 2; - int64 current_deficit = 3; - uint32 weight = 4; -} - -message PerThrottleKeyStats { - string key = 1; - double tokens = 2; - double rate_per_second = 3; - double burst = 4; -} - -message GetStatsResponse { - uint64 depth = 1; - uint64 in_flight = 2; - uint64 active_fairness_keys = 3; - uint32 active_consumers = 4; - uint32 quantum = 5; - repeated PerFairnessKeyStats per_key_stats = 6; - repeated PerThrottleKeyStats per_throttle_stats = 7; - // Cluster fields (0 when not in cluster mode). - uint64 leader_node_id = 8; - uint32 replication_count = 9; -} - -message RedriveRequest { - string dlq_queue = 1; - uint64 count = 2; -} - -message RedriveResponse { - uint64 redriven = 1; -} - -message ListQueuesRequest {} - -message QueueInfo { - string name = 1; - uint64 depth = 2; - uint64 in_flight = 3; - uint32 active_consumers = 4; - uint64 leader_node_id = 5; -} - -message ListQueuesResponse { - repeated QueueInfo queues = 1; - uint32 cluster_node_count = 2; -} - -// --- API Key Management --- - -message CreateApiKeyRequest { - /// Human-readable label for the key. - string name = 1; - /// Optional Unix timestamp (milliseconds) after which the key expires. - /// 0 means no expiration. - uint64 expires_at_ms = 2; - /// When true, the key bypasses all ACL checks (superadmin). - bool is_superadmin = 3; -} - -message CreateApiKeyResponse { - /// Opaque key ID for management operations (revoke, list, set-acl). - string key_id = 1; - /// Plaintext API key. Returned once — store it securely. - string key = 2; - /// Whether this key has superadmin privileges. - bool is_superadmin = 3; -} - -message RevokeApiKeyRequest { - string key_id = 1; -} - -message RevokeApiKeyResponse {} - -message ListApiKeysRequest {} - -message ApiKeyInfo { - string key_id = 1; - string name = 2; - uint64 created_at_ms = 3; - /// 0 means no expiration. - uint64 expires_at_ms = 4; - bool is_superadmin = 5; -} - -message ListApiKeysResponse { - repeated ApiKeyInfo keys = 1; -} - -// --- ACL Management --- - -/// A single permission grant: kind (produce/consume/admin) + queue pattern. -message AclPermission { - /// One of: "produce", "consume", "admin". - string kind = 1; - /// Queue name or wildcard ("*" or "orders.*"). - string pattern = 2; -} - -message SetAclRequest { - string key_id = 1; - repeated AclPermission permissions = 2; -} - -message SetAclResponse {} - -message GetAclRequest { - string key_id = 1; -} - -message GetAclResponse { - string key_id = 1; - repeated AclPermission permissions = 2; - bool is_superadmin = 3; -} diff --git a/proto/fila/v1/messages.proto b/proto/fila/v1/messages.proto deleted file mode 100644 index a0709cf..0000000 --- a/proto/fila/v1/messages.proto +++ /dev/null @@ -1,28 +0,0 @@ -syntax = "proto3"; -package fila.v1; - -import "google/protobuf/timestamp.proto"; - -// Core message envelope persisted in the broker. -message Message { - string id = 1; - map headers = 2; - bytes payload = 3; - MessageMetadata metadata = 4; - MessageTimestamps timestamps = 5; -} - -// Broker-assigned scheduling metadata. -message MessageMetadata { - string fairness_key = 1; - uint32 weight = 2; - repeated string throttle_keys = 3; - uint32 attempt_count = 4; - string queue_id = 5; -} - -// Lifecycle timestamps attached to every message. -message MessageTimestamps { - google.protobuf.Timestamp enqueued_at = 1; - google.protobuf.Timestamp leased_at = 2; -} diff --git a/proto/fila/v1/service.proto b/proto/fila/v1/service.proto deleted file mode 100644 index f14fdd0..0000000 --- a/proto/fila/v1/service.proto +++ /dev/null @@ -1,45 +0,0 @@ -syntax = "proto3"; -package fila.v1; - -import "fila/v1/messages.proto"; - -// Hot-path RPCs for producers and consumers. -service FilaService { - rpc Enqueue(EnqueueRequest) returns (EnqueueResponse); - rpc Consume(ConsumeRequest) returns (stream ConsumeResponse); - rpc Ack(AckRequest) returns (AckResponse); - rpc Nack(NackRequest) returns (NackResponse); -} - -message EnqueueRequest { - string queue = 1; - map headers = 2; - bytes payload = 3; -} - -message EnqueueResponse { - string message_id = 1; -} - -message ConsumeRequest { - string queue = 1; -} - -message ConsumeResponse { - Message message = 1; -} - -message AckRequest { - string queue = 1; - string message_id = 2; -} - -message AckResponse {} - -message NackRequest { - string queue = 1; - string message_id = 2; - string error = 3; -} - -message NackResponse {} diff --git a/src/main/java/dev/faisca/fila/ApiKeyInterceptor.java b/src/main/java/dev/faisca/fila/ApiKeyInterceptor.java deleted file mode 100644 index e7ea461..0000000 --- a/src/main/java/dev/faisca/fila/ApiKeyInterceptor.java +++ /dev/null @@ -1,36 +0,0 @@ -package dev.faisca.fila; - -import io.grpc.CallOptions; -import io.grpc.Channel; -import io.grpc.ClientCall; -import io.grpc.ClientInterceptor; -import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; - -/** - * gRPC client interceptor that attaches a {@code Bearer} API key to the {@code authorization} - * metadata header on every outgoing RPC. - */ -final class ApiKeyInterceptor implements ClientInterceptor { - private static final Metadata.Key AUTH_KEY = - Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER); - - private final String headerValue; - - ApiKeyInterceptor(String apiKey) { - this.headerValue = "Bearer " + apiKey; - } - - @Override - public ClientCall interceptCall( - MethodDescriptor method, CallOptions callOptions, Channel next) { - return new SimpleForwardingClientCall<>(next.newCall(method, callOptions)) { - @Override - public void start(Listener responseListener, Metadata headers) { - headers.put(AUTH_KEY, headerValue); - super.start(responseListener, headers); - } - }; - } -} diff --git a/src/main/java/dev/faisca/fila/BatchMode.java b/src/main/java/dev/faisca/fila/BatchMode.java new file mode 100644 index 0000000..0be14fc --- /dev/null +++ b/src/main/java/dev/faisca/fila/BatchMode.java @@ -0,0 +1,93 @@ +package dev.faisca.fila; + +/** + * Controls how the SDK batches {@link FilaClient#enqueue} calls. + * + *

The default is {@link #auto()} -- opportunistic batching that requires zero configuration. At + * low load each message is sent individually (zero added latency). At high load messages accumulate + * naturally and are flushed together. + */ +public final class BatchMode { + enum Kind { + AUTO, + LINGER, + DISABLED + } + + private final Kind kind; + private final int maxBatchSize; + private final long lingerMs; + + private BatchMode(Kind kind, int maxBatchSize, long lingerMs) { + this.kind = kind; + this.maxBatchSize = maxBatchSize; + this.lingerMs = lingerMs; + } + + /** + * Opportunistic batching (default). + * + *

A background thread blocks for the first message, then drains any additional messages that + * arrived concurrently. At low load each message is sent individually. At high load messages + * accumulate naturally into batches. Zero config, zero latency penalty. + * + * @return a new AUTO batch mode with default max batch size (100) + */ + public static BatchMode auto() { + return new BatchMode(Kind.AUTO, 100, 0); + } + + /** + * Opportunistic batching with a custom max batch size. + * + * @param maxBatchSize safety cap on batch size + * @return a new AUTO batch mode + */ + public static BatchMode auto(int maxBatchSize) { + if (maxBatchSize < 1) { + throw new IllegalArgumentException("maxBatchSize must be >= 1"); + } + return new BatchMode(Kind.AUTO, maxBatchSize, 0); + } + + /** + * Timer-based forced batching. + * + *

Buffers messages and flushes when either {@code batchSize} messages accumulate or {@code + * lingerMs} milliseconds elapse since the first message in the batch -- whichever comes first. + * + * @param lingerMs time threshold in milliseconds before a partial batch is flushed + * @param batchSize maximum messages per batch + * @return a new LINGER batch mode + */ + public static BatchMode linger(long lingerMs, int batchSize) { + if (lingerMs < 1) { + throw new IllegalArgumentException("lingerMs must be >= 1"); + } + if (batchSize < 1) { + throw new IllegalArgumentException("batchSize must be >= 1"); + } + return new BatchMode(Kind.LINGER, batchSize, lingerMs); + } + + /** + * No batching. Each {@link FilaClient#enqueue} call is a direct single-message RPC. + * + * @return a DISABLED batch mode + */ + public static BatchMode disabled() { + return new BatchMode(Kind.DISABLED, 0, 0); + } + + Kind getKind() { + return kind; + } + + int getMaxBatchSize() { + return maxBatchSize; + } + + long getLingerMs() { + return lingerMs; + } +} diff --git a/src/main/java/dev/faisca/fila/Batcher.java b/src/main/java/dev/faisca/fila/Batcher.java new file mode 100644 index 0000000..dde79b7 --- /dev/null +++ b/src/main/java/dev/faisca/fila/Batcher.java @@ -0,0 +1,256 @@ +package dev.faisca.fila; + +import dev.faisca.fila.fibp.Codec; +import dev.faisca.fila.fibp.Connection; +import dev.faisca.fila.fibp.Opcodes; +import dev.faisca.fila.fibp.Primitives; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Background batcher that coalesces individual enqueue calls into multi-message RPCs. + * + *

Supports two modes: AUTO (opportunistic, Nagle-style) and LINGER (timer-based). The batcher + * runs on a dedicated daemon thread and flushes RPCs on an executor pool. + */ +final class Batcher { + private final LinkedBlockingQueue queue = new LinkedBlockingQueue<>(); + private final AtomicBoolean running = new AtomicBoolean(true); + private final Connection connection; + private final BatchMode mode; + private final Thread batcherThread; + private final ExecutorService flushExecutor; + private final ScheduledExecutorService scheduler; + + static final class BatchItem { + final EnqueueMessage message; + final CompletableFuture future; + + BatchItem(EnqueueMessage message, CompletableFuture future) { + this.message = message; + this.future = future; + } + } + + Batcher(Connection connection, BatchMode mode) { + this.connection = connection; + this.mode = mode; + this.flushExecutor = Executors.newCachedThreadPool(r -> newDaemon(r, "fila-batch-flush")); + this.scheduler = + mode.getKind() == BatchMode.Kind.LINGER + ? Executors.newSingleThreadScheduledExecutor(r -> newDaemon(r, "fila-batch-scheduler")) + : null; + + this.batcherThread = + new Thread( + mode.getKind() == BatchMode.Kind.AUTO ? this::runAuto : this::runLinger, + "fila-batcher"); + this.batcherThread.setDaemon(true); + this.batcherThread.start(); + } + + CompletableFuture submit(EnqueueMessage message) { + CompletableFuture future = new CompletableFuture<>(); + if (!running.get()) { + future.completeExceptionally(new FilaException("batcher is shut down")); + return future; + } + queue.add(new BatchItem(message, future)); + return future; + } + + void shutdown() { + running.set(false); + batcherThread.interrupt(); + try { + batcherThread.join(5000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + + List remaining = new ArrayList<>(); + queue.drainTo(remaining); + if (!remaining.isEmpty()) { + flushBatch(remaining); + } + + flushExecutor.shutdown(); + try { + if (!flushExecutor.awaitTermination(5, TimeUnit.SECONDS)) { + flushExecutor.shutdownNow(); + } + } catch (InterruptedException e) { + flushExecutor.shutdownNow(); + Thread.currentThread().interrupt(); + } + + if (scheduler != null) { + scheduler.shutdown(); + try { + if (!scheduler.awaitTermination(5, TimeUnit.SECONDS)) { + scheduler.shutdownNow(); + } + } catch (InterruptedException e) { + scheduler.shutdownNow(); + Thread.currentThread().interrupt(); + } + } + } + + private void runAuto() { + int maxBatchSize = mode.getMaxBatchSize(); + while (running.get()) { + try { + BatchItem first = queue.take(); + List batch = new ArrayList<>(); + batch.add(first); + queue.drainTo(batch, maxBatchSize - 1); + + List toFlush = List.copyOf(batch); + flushExecutor.submit(() -> flushBatch(toFlush)); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return; + } + } + } + + private void runLinger() { + int batchSize = mode.getMaxBatchSize(); + long lingerMs = mode.getLingerMs(); + List buffer = new ArrayList<>(); + ScheduledFuture lingerTimer = null; + + while (running.get()) { + try { + if (buffer.isEmpty()) { + BatchItem item = queue.take(); + buffer.add(item); + + if (buffer.size() >= batchSize) { + List toFlush = List.copyOf(buffer); + buffer.clear(); + flushExecutor.submit(() -> flushBatch(toFlush)); + } else { + lingerTimer = + scheduler.schedule( + () -> batcherThread.interrupt(), lingerMs, TimeUnit.MILLISECONDS); + } + } else { + BatchItem item = queue.poll(lingerMs, TimeUnit.MILLISECONDS); + if (item != null) { + buffer.add(item); + queue.drainTo(buffer, batchSize - buffer.size()); + } + + if (buffer.size() >= batchSize || item == null) { + if (lingerTimer != null) { + lingerTimer.cancel(false); + lingerTimer = null; + } + List toFlush = List.copyOf(buffer); + buffer.clear(); + flushExecutor.submit(() -> flushBatch(toFlush)); + } + } + } catch (InterruptedException e) { + if (!buffer.isEmpty()) { + if (lingerTimer != null) { + lingerTimer.cancel(false); + lingerTimer = null; + } + List toFlush = List.copyOf(buffer); + buffer.clear(); + flushExecutor.submit(() -> flushBatch(toFlush)); + } + if (!running.get()) { + Thread.currentThread().interrupt(); + return; + } + } + } + } + + @SuppressWarnings("unchecked") + private void flushBatch(List items) { + if (items.isEmpty()) { + return; + } + + int count = items.size(); + String[] queues = new String[count]; + Map[] headers = new Map[count]; + byte[][] payloads = new byte[count][]; + + for (int i = 0; i < count; i++) { + EnqueueMessage msg = items.get(i).message; + queues[i] = msg.getQueue(); + headers[i] = msg.getHeaders(); + payloads[i] = msg.getPayload(); + } + + int requestId = connection.nextRequestId(); + byte[] frame = Codec.encodeEnqueue(requestId, queues, headers, payloads); + + try { + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + byte opcode = response.header().opcode(); + + if (opcode == Opcodes.ERROR) { + FilaException ex = FilaClient.mapErrorFrame(response.body()); + for (BatchItem item : items) { + item.future.completeExceptionally(ex); + } + return; + } + + if (opcode != Opcodes.ENQUEUE_RESULT) { + FilaException ex = new RpcException(Opcodes.ERR_INTERNAL, "unexpected response opcode"); + for (BatchItem item : items) { + item.future.completeExceptionally(ex); + } + return; + } + + Primitives.Reader r = new Primitives.Reader(response.body()); + long resultCount = r.readU32(); + + for (int i = 0; i < items.size(); i++) { + BatchItem item = items.get(i); + if (i < resultCount) { + int errorCode = r.readU8(); + String messageId = r.readString(); + if (errorCode == Opcodes.ERR_OK) { + item.future.complete(messageId); + } else { + item.future.completeExceptionally(FilaClient.mapErrorCode(errorCode, messageId)); + } + } else { + item.future.completeExceptionally( + new RpcException(Opcodes.ERR_INTERNAL, "server returned fewer results than sent")); + } + } + } catch (IOException | InterruptedException e) { + FilaException ex = new FilaException("batch enqueue failed", e); + for (BatchItem item : items) { + item.future.completeExceptionally(ex); + } + } + } + + private static Thread newDaemon(Runnable r, String name) { + Thread t = new Thread(r, name); + t.setDaemon(true); + return t; + } +} diff --git a/src/main/java/dev/faisca/fila/ConsumerHandle.java b/src/main/java/dev/faisca/fila/ConsumerHandle.java index b81ee83..16fda53 100644 --- a/src/main/java/dev/faisca/fila/ConsumerHandle.java +++ b/src/main/java/dev/faisca/fila/ConsumerHandle.java @@ -1,24 +1,31 @@ package dev.faisca.fila; -import io.grpc.Context; - /** Handle for a running consume stream. Call {@link #cancel()} to stop consuming. */ public final class ConsumerHandle { - private final Context.CancellableContext context; + private final Runnable cancelAction; private final Thread thread; - ConsumerHandle(Context.CancellableContext context, Thread thread) { - this.context = context; + ConsumerHandle(Runnable cancelAction, Thread thread) { + this.cancelAction = cancelAction; this.thread = thread; } /** Cancel the consume stream and wait for the consumer thread to finish. */ public void cancel() { - context.cancel(null); + cancelAction.run(); try { thread.join(5000); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } + + /** Block until the consumer thread finishes (without cancelling). */ + void awaitDone() { + try { + thread.join(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } } diff --git a/src/main/java/dev/faisca/fila/EnqueueMessage.java b/src/main/java/dev/faisca/fila/EnqueueMessage.java new file mode 100644 index 0000000..80d63ec --- /dev/null +++ b/src/main/java/dev/faisca/fila/EnqueueMessage.java @@ -0,0 +1,43 @@ +package dev.faisca.fila; + +import java.util.Map; + +/** + * A message to be enqueued via {@link FilaClient#enqueueMany(java.util.List)}. + * + *

Each message specifies its target queue, headers, and payload independently, allowing a single + * call to target multiple queues. + */ +public final class EnqueueMessage { + private final String queue; + private final Map headers; + private final byte[] payload; + + /** + * Create a new enqueue message. + * + * @param queue target queue name + * @param headers message headers (may be empty) + * @param payload message payload bytes + */ + public EnqueueMessage(String queue, Map headers, byte[] payload) { + this.queue = queue; + this.headers = Map.copyOf(headers); + this.payload = payload.clone(); + } + + /** Returns the target queue name. */ + public String getQueue() { + return queue; + } + + /** Returns the message headers. */ + public Map getHeaders() { + return headers; + } + + /** Returns the message payload bytes. */ + public byte[] getPayload() { + return payload.clone(); + } +} diff --git a/src/main/java/dev/faisca/fila/EnqueueResult.java b/src/main/java/dev/faisca/fila/EnqueueResult.java new file mode 100644 index 0000000..3bea453 --- /dev/null +++ b/src/main/java/dev/faisca/fila/EnqueueResult.java @@ -0,0 +1,57 @@ +package dev.faisca.fila; + +/** + * The result of a single message within an enqueue call. + * + *

Each message in a multi-message enqueue is independently validated and processed. A failed + * message does not affect the others. Use {@link #isSuccess()} to check the outcome, then either + * {@link #getMessageId()} or {@link #getError()}. + */ +public final class EnqueueResult { + private final String messageId; + private final String error; + + private EnqueueResult(String messageId, String error) { + this.messageId = messageId; + this.error = error; + } + + /** Create a successful result with the broker-assigned message ID. */ + static EnqueueResult success(String messageId) { + return new EnqueueResult(messageId, null); + } + + /** Create a failed result with an error description. */ + static EnqueueResult error(String error) { + return new EnqueueResult(null, error); + } + + /** Returns true if the message was successfully enqueued. */ + public boolean isSuccess() { + return messageId != null; + } + + /** + * Returns the broker-assigned message ID. + * + * @throws IllegalStateException if this result is an error + */ + public String getMessageId() { + if (messageId == null) { + throw new IllegalStateException("result is an error: " + error); + } + return messageId; + } + + /** + * Returns the error description. + * + * @throws IllegalStateException if this result is a success + */ + public String getError() { + if (error == null) { + throw new IllegalStateException("result is a success"); + } + return error; + } +} diff --git a/src/main/java/dev/faisca/fila/FilaClient.java b/src/main/java/dev/faisca/fila/FilaClient.java index 8551fc0..51cea1c 100644 --- a/src/main/java/dev/faisca/fila/FilaClient.java +++ b/src/main/java/dev/faisca/fila/FilaClient.java @@ -1,26 +1,37 @@ package dev.faisca.fila; -import fila.v1.FilaServiceGrpc; -import fila.v1.Messages; -import fila.v1.Service; -import io.grpc.ChannelCredentials; -import io.grpc.Context; -import io.grpc.Grpc; -import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; -import io.grpc.StatusRuntimeException; -import io.grpc.TlsChannelCredentials; +import dev.faisca.fila.fibp.Codec; +import dev.faisca.fila.fibp.Connection; +import dev.faisca.fila.fibp.Opcodes; +import dev.faisca.fila.fibp.Primitives; import java.io.ByteArrayInputStream; import java.io.IOException; -import java.util.Iterator; +import java.security.KeyFactory; +import java.security.KeyStore; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; /** - * Client for the Fila message broker. + * Client for the Fila message broker using the FIBP binary protocol. * - *

Wraps the hot-path gRPC operations: enqueue, consume, ack, nack. + *

Wraps the hot-path operations: enqueue, consume, ack, nack. Also exposes admin and auth + * operations. + * + *

By default, {@code enqueue()} routes through an opportunistic batcher that coalesces messages + * at high load without adding latency at low load. * *

{@code
  * try (FilaClient client = FilaClient.builder("localhost:5555").build()) {
@@ -35,12 +46,26 @@
  * }
*/ public final class FilaClient implements AutoCloseable { - private final ManagedChannel channel; - private final FilaServiceGrpc.FilaServiceBlockingStub blockingStub; - - private FilaClient(ManagedChannel channel) { - this.channel = channel; - this.blockingStub = FilaServiceGrpc.newBlockingStub(channel); + private final Connection connection; + private final byte[] caCertPem; + private final byte[] clientCertPem; + private final byte[] clientKeyPem; + private final String apiKey; + private final Batcher batcher; + + private FilaClient( + Connection connection, + byte[] caCertPem, + byte[] clientCertPem, + byte[] clientKeyPem, + String apiKey, + Batcher batcher) { + this.connection = connection; + this.caCertPem = caCertPem; + this.clientCertPem = clientCertPem; + this.clientKeyPem = clientKeyPem; + this.apiKey = apiKey; + this.batcher = batcher; } /** Returns a new builder for configuring a {@link FilaClient}. */ @@ -56,64 +81,197 @@ public static Builder builder(String address) { * @param payload message payload bytes * @return the broker-assigned message ID (UUIDv7) * @throws QueueNotFoundException if the queue does not exist - * @throws RpcException for unexpected gRPC failures + * @throws RpcException for unexpected protocol failures */ public String enqueue(String queue, Map headers, byte[] payload) { - Service.EnqueueRequest req = - Service.EnqueueRequest.newBuilder() - .setQueue(queue) - .putAllHeaders(headers) - .setPayload(com.google.protobuf.ByteString.copyFrom(payload)) - .build(); + if (batcher != null) { + CompletableFuture future = + batcher.submit(new EnqueueMessage(queue, headers, payload)); + try { + return future.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FilaException("enqueue interrupted", e); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof FilaException fe) { + throw fe; + } + throw new RpcException(Opcodes.ERR_INTERNAL, cause.getMessage()); + } + } + + return enqueueDirect(queue, headers, payload); + } + + /** + * Enqueue multiple messages in a single RPC call. + * + * @param messages the messages to enqueue + * @return a list of results, one per input message + * @throws RpcException for transport-level failures affecting the entire batch + */ + @SuppressWarnings("unchecked") + public List enqueueMany(List messages) { + int count = messages.size(); + String[] queues = new String[count]; + Map[] headers = new Map[count]; + byte[][] payloads = new byte[count][]; + for (int i = 0; i < count; i++) { + EnqueueMessage msg = messages.get(i); + queues[i] = msg.getQueue(); + headers[i] = msg.getHeaders(); + payloads[i] = msg.getPayload(); + } + + int requestId = connection.nextRequestId(); + byte[] frame = Codec.encodeEnqueue(requestId, queues, headers, payloads); + try { - Service.EnqueueResponse resp = blockingStub.enqueue(req); - return resp.getMessageId(); - } catch (StatusRuntimeException e) { - throw mapEnqueueError(e); + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + checkForError(response); + + if (response.header().opcode() != Opcodes.ENQUEUE_RESULT) { + throw new RpcException(Opcodes.ERR_INTERNAL, "unexpected response opcode"); + } + + Primitives.Reader r = new Primitives.Reader(response.body()); + long resultCount = r.readU32(); + List results = new ArrayList<>((int) resultCount); + for (int i = 0; i < resultCount; i++) { + int errorCode = r.readU8(); + String messageId = r.readString(); + if (errorCode == Opcodes.ERR_OK) { + results.add(EnqueueResult.success(messageId)); + } else { + results.add(EnqueueResult.error(messageId.isEmpty() ? errorName(errorCode) : messageId)); + } + } + return results; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FilaException("enqueueMany failed", e); + } catch (IOException e) { + throw new FilaException("enqueueMany failed", e); } } /** * Open a streaming consumer on the specified queue. * - *

Messages are delivered to the handler on a background thread. Nacked messages are - * redelivered on the same stream. Call {@link ConsumerHandle#cancel()} to stop consuming. - * * @param queue queue to consume from * @param handler callback invoked for each message * @return a handle to cancel the consumer - * @throws QueueNotFoundException if the queue does not exist - * @throws RpcException for unexpected gRPC failures */ public ConsumerHandle consume(String queue, Consumer handler) { - Service.ConsumeRequest req = Service.ConsumeRequest.newBuilder().setQueue(queue).build(); + int requestId = connection.nextRequestId(); + byte[] frame = Codec.encodeConsume(requestId, queue); + + LinkedBlockingQueue deliveryChan = + connection.registerDeliveryChannel(requestId); + + AtomicBoolean cancelled = new AtomicBoolean(false); - Context.CancellableContext ctx = Context.current().withCancellation(); Thread thread = new Thread( () -> { - ctx.run( - () -> { - try { - Iterator stream = blockingStub.consume(req); - while (stream.hasNext()) { - Service.ConsumeResponse resp = stream.next(); - if (!resp.hasMessage() || resp.getMessage().getId().isEmpty()) { - continue; - } - handler.accept(buildConsumeMessage(resp.getMessage())); - } - } catch (StatusRuntimeException e) { - if (e.getStatus().getCode() != io.grpc.Status.Code.CANCELLED) { - throw mapConsumeError(e); - } + try { + // Send subscribe request + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + + byte opcode = response.header().opcode(); + if (opcode == Opcodes.ERROR) { + Primitives.Reader er = new Primitives.Reader(response.body()); + int errorCode = er.readU8(); + String message = er.readString(); + Map metadata = er.readStringMap(); + + if (errorCode == Opcodes.ERR_NOT_LEADER) { + String leaderAddr = metadata.get("leader_addr"); + if (leaderAddr != null) { + retryConsumeOnLeader(leaderAddr, queue, handler); + return; + } + } + throw mapErrorCode(errorCode, message); + } + + if (opcode != Opcodes.CONSUME_OK) { + throw new RpcException( + Opcodes.ERR_INTERNAL, "unexpected consume response opcode"); + } + + // ConsumeOk received, now read deliveries from the channel + // The delivery frames use the same requestId + while (!cancelled.get() && !connection.isClosed()) { + Connection.Frame delivery = deliveryChan.poll(1, TimeUnit.SECONDS); + if (delivery == null) { + continue; + } + if (delivery.header().opcode() == Opcodes.ERROR) { + break; + } + if (delivery.header().opcode() != Opcodes.DELIVERY) { + continue; + } + + Primitives.Reader dr = new Primitives.Reader(delivery.body()); + long msgCount = dr.readU32(); + for (long m = 0; m < msgCount; m++) { + String msgId = dr.readString(); + String msgQueue = dr.readString(); + Map msgHeaders = dr.readStringMap(); + byte[] msgPayload = dr.readBytes(); + String fairnessKey = dr.readString(); + long weight = dr.readU32(); + String[] throttleKeys = dr.readStringList(); + long attemptCount = dr.readU32(); + long enqueuedAt = dr.readU64(); + long leasedAt = dr.readU64(); + + if (msgId.isEmpty()) { + continue; } - }); + + handler.accept( + new ConsumeMessage( + msgId, + msgHeaders, + msgPayload, + fairnessKey, + (int) attemptCount, + msgQueue)); + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + if (!cancelled.get()) { + throw new FilaException("consume stream failed", e); + } + } catch (IOException e) { + if (!cancelled.get()) { + throw new FilaException("consume stream failed", e); + } + } finally { + connection.unregisterDeliveryChannel(requestId); + } }, "fila-consumer-" + queue); thread.setDaemon(true); thread.start(); - return new ConsumerHandle(ctx, thread); + + Runnable cancelAction = + () -> { + cancelled.set(true); + try { + connection.send(Codec.encodeCancelConsume(requestId)); + } catch (IOException ignored) { + // best effort + } + connection.unregisterDeliveryChannel(requestId); + }; + + return new ConsumerHandle(cancelAction, thread); } /** @@ -121,16 +279,33 @@ public ConsumerHandle consume(String queue, Consumer handler) { * * @param queue queue the message belongs to * @param msgId ID of the message to acknowledge - * @throws MessageNotFoundException if the message does not exist - * @throws RpcException for unexpected gRPC failures */ public void ack(String queue, String msgId) { - Service.AckRequest req = - Service.AckRequest.newBuilder().setQueue(queue).setMessageId(msgId).build(); + int requestId = connection.nextRequestId(); + byte[] frame = Codec.encodeAck(requestId, new String[] {queue}, new String[] {msgId}); + try { - blockingStub.ack(req); - } catch (StatusRuntimeException e) { - throw mapAckError(e); + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + checkForError(response); + + if (response.header().opcode() != Opcodes.ACK_RESULT) { + throw new RpcException(Opcodes.ERR_INTERNAL, "unexpected ack response opcode"); + } + + Primitives.Reader r = new Primitives.Reader(response.body()); + long resultCount = r.readU32(); + if (resultCount < 1) { + throw new RpcException(Opcodes.ERR_INTERNAL, "no result from server"); + } + int errorCode = r.readU8(); + if (errorCode != Opcodes.ERR_OK) { + throw mapErrorCode(errorCode, "ack failed"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FilaException("ack failed", e); + } catch (IOException e) { + throw new FilaException("ack failed", e); } } @@ -140,76 +315,416 @@ public void ack(String queue, String msgId) { * @param queue queue the message belongs to * @param msgId ID of the message to nack * @param error description of the failure - * @throws MessageNotFoundException if the message does not exist - * @throws RpcException for unexpected gRPC failures */ public void nack(String queue, String msgId, String error) { - Service.NackRequest req = - Service.NackRequest.newBuilder() - .setQueue(queue) - .setMessageId(msgId) - .setError(error) - .build(); + int requestId = connection.nextRequestId(); + byte[] frame = + Codec.encodeNack( + requestId, new String[] {queue}, new String[] {msgId}, new String[] {error}); + + try { + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + checkForError(response); + + if (response.header().opcode() != Opcodes.NACK_RESULT) { + throw new RpcException(Opcodes.ERR_INTERNAL, "unexpected nack response opcode"); + } + + Primitives.Reader r = new Primitives.Reader(response.body()); + long resultCount = r.readU32(); + if (resultCount < 1) { + throw new RpcException(Opcodes.ERR_INTERNAL, "no result from server"); + } + int errorCode = r.readU8(); + if (errorCode != Opcodes.ERR_OK) { + throw mapErrorCode(errorCode, "nack failed"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FilaException("nack failed", e); + } catch (IOException e) { + throw new FilaException("nack failed", e); + } + } + + // --- Admin operations --- + + /** + * Create a queue on the server. + * + * @param name queue name + */ + public void createQueue(String name) { + int requestId = connection.nextRequestId(); + byte[] frame = Codec.encodeCreateQueue(requestId, name, null, null, 0); + + try { + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + checkForError(response); + + if (response.header().opcode() != Opcodes.CREATE_QUEUE_RESULT) { + throw new RpcException(Opcodes.ERR_INTERNAL, "unexpected createQueue response opcode"); + } + + Primitives.Reader r = new Primitives.Reader(response.body()); + int errorCode = r.readU8(); + if (errorCode != Opcodes.ERR_OK) { + String queueId = r.readString(); + throw mapErrorCode(errorCode, "createQueue: " + queueId); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FilaException("createQueue failed", e); + } catch (IOException e) { + throw new FilaException("createQueue failed", e); + } + } + + /** + * Delete a queue on the server. + * + * @param queue queue name + */ + public void deleteQueue(String queue) { + int requestId = connection.nextRequestId(); + byte[] frame = Codec.encodeDeleteQueue(requestId, queue); + + try { + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + checkForError(response); + + if (response.header().opcode() != Opcodes.DELETE_QUEUE_RESULT) { + throw new RpcException(Opcodes.ERR_INTERNAL, "unexpected deleteQueue response opcode"); + } + + Primitives.Reader r = new Primitives.Reader(response.body()); + int errorCode = r.readU8(); + if (errorCode != Opcodes.ERR_OK) { + throw mapErrorCode(errorCode, "deleteQueue failed"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FilaException("deleteQueue failed", e); + } catch (IOException e) { + throw new FilaException("deleteQueue failed", e); + } + } + + /** Set a runtime configuration key. */ + public void setConfig(String key, String value) { + int requestId = connection.nextRequestId(); + byte[] frame = Codec.encodeSetConfig(requestId, key, value); + + try { + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + checkForError(response); + + if (response.header().opcode() != Opcodes.SET_CONFIG_RESULT) { + throw new RpcException(Opcodes.ERR_INTERNAL, "unexpected setConfig response opcode"); + } + + Primitives.Reader r = new Primitives.Reader(response.body()); + int errorCode = r.readU8(); + if (errorCode != Opcodes.ERR_OK) { + throw mapErrorCode(errorCode, "setConfig failed"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FilaException("setConfig failed", e); + } catch (IOException e) { + throw new FilaException("setConfig failed", e); + } + } + + /** Get a runtime configuration value. */ + public String getConfig(String key) { + int requestId = connection.nextRequestId(); + byte[] frame = Codec.encodeGetConfig(requestId, key); + + try { + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + checkForError(response); + + if (response.header().opcode() != Opcodes.GET_CONFIG_RESULT) { + throw new RpcException(Opcodes.ERR_INTERNAL, "unexpected getConfig response opcode"); + } + + Primitives.Reader r = new Primitives.Reader(response.body()); + int errorCode = r.readU8(); + if (errorCode != Opcodes.ERR_OK) { + throw mapErrorCode(errorCode, "getConfig failed"); + } + return r.readString(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FilaException("getConfig failed", e); + } catch (IOException e) { + throw new FilaException("getConfig failed", e); + } + } + + /** Redrive messages from a dead-letter queue back to its parent. */ + public long redrive(String dlqQueue, long count) { + int requestId = connection.nextRequestId(); + byte[] frame = Codec.encodeRedrive(requestId, dlqQueue, count); + try { - blockingStub.nack(req); - } catch (StatusRuntimeException e) { - throw mapNackError(e); + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + checkForError(response); + + if (response.header().opcode() != Opcodes.REDRIVE_RESULT) { + throw new RpcException(Opcodes.ERR_INTERNAL, "unexpected redrive response opcode"); + } + + Primitives.Reader r = new Primitives.Reader(response.body()); + int errorCode = r.readU8(); + if (errorCode != Opcodes.ERR_OK) { + throw mapErrorCode(errorCode, "redrive failed"); + } + return r.readU64(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FilaException("redrive failed", e); + } catch (IOException e) { + throw new FilaException("redrive failed", e); + } + } + + // --- Auth operations --- + + /** + * Create an API key. + * + * @return array of [keyId, key, isSuperadmin] + */ + public String[] createApiKey(String name, long expiresAtMs, boolean isSuperadmin) { + int requestId = connection.nextRequestId(); + byte[] frame = Codec.encodeCreateApiKey(requestId, name, expiresAtMs, isSuperadmin); + + try { + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + checkForError(response); + + if (response.header().opcode() != Opcodes.CREATE_API_KEY_RESULT) { + throw new RpcException(Opcodes.ERR_INTERNAL, "unexpected createApiKey response opcode"); + } + + Primitives.Reader r = new Primitives.Reader(response.body()); + int errorCode = r.readU8(); + if (errorCode != Opcodes.ERR_OK) { + throw mapErrorCode(errorCode, "createApiKey failed"); + } + String keyId = r.readString(); + String key = r.readString(); + boolean superadmin = r.readBool(); + return new String[] {keyId, key, String.valueOf(superadmin)}; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FilaException("createApiKey failed", e); + } catch (IOException e) { + throw new FilaException("createApiKey failed", e); + } + } + + /** Revoke an API key. */ + public void revokeApiKey(String keyId) { + int requestId = connection.nextRequestId(); + byte[] frame = Codec.encodeRevokeApiKey(requestId, keyId); + + try { + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + checkForError(response); + + if (response.header().opcode() != Opcodes.REVOKE_API_KEY_RESULT) { + throw new RpcException(Opcodes.ERR_INTERNAL, "unexpected revokeApiKey response opcode"); + } + + Primitives.Reader r = new Primitives.Reader(response.body()); + int errorCode = r.readU8(); + if (errorCode != Opcodes.ERR_OK) { + throw mapErrorCode(errorCode, "revokeApiKey failed"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FilaException("revokeApiKey failed", e); + } catch (IOException e) { + throw new FilaException("revokeApiKey failed", e); + } + } + + /** Set ACL permissions for an API key. */ + public void setAcl(String keyId, String[] kinds, String[] patterns) { + int requestId = connection.nextRequestId(); + byte[] frame = Codec.encodeSetAcl(requestId, keyId, kinds, patterns); + + try { + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + checkForError(response); + + if (response.header().opcode() != Opcodes.SET_ACL_RESULT) { + throw new RpcException(Opcodes.ERR_INTERNAL, "unexpected setAcl response opcode"); + } + + Primitives.Reader r = new Primitives.Reader(response.body()); + int errorCode = r.readU8(); + if (errorCode != Opcodes.ERR_OK) { + throw mapErrorCode(errorCode, "setAcl failed"); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new FilaException("setAcl failed", e); + } catch (IOException e) { + throw new FilaException("setAcl failed", e); } } - /** Shut down the underlying gRPC channel. */ @Override public void close() { - channel.shutdown(); + if (batcher != null) { + batcher.shutdown(); + } + connection.close(); + } + + // --- Internal helpers --- + + @SuppressWarnings("unchecked") + private String enqueueDirect(String queue, Map headers, byte[] payload) { + int requestId = connection.nextRequestId(); + byte[] frame = + Codec.encodeEnqueue( + requestId, new String[] {queue}, new Map[] {headers}, new byte[][] {payload}); + try { - if (!channel.awaitTermination(5, TimeUnit.SECONDS)) { - channel.shutdownNow(); + Connection.Frame response = connection.sendAndReceive(frame, requestId, 30_000); + checkForError(response); + + if (response.header().opcode() != Opcodes.ENQUEUE_RESULT) { + throw new RpcException(Opcodes.ERR_INTERNAL, "unexpected response opcode"); + } + + Primitives.Reader r = new Primitives.Reader(response.body()); + long resultCount = r.readU32(); + if (resultCount < 1) { + throw new RpcException(Opcodes.ERR_INTERNAL, "no result from server"); } + int errorCode = r.readU8(); + String messageId = r.readString(); + if (errorCode == Opcodes.ERR_OK) { + return messageId; + } + throw mapErrorCode(errorCode, "enqueue: " + messageId); } catch (InterruptedException e) { - channel.shutdownNow(); Thread.currentThread().interrupt(); + throw new FilaException("enqueue failed", e); + } catch (IOException e) { + throw new FilaException("enqueue failed", e); } } - private static ConsumeMessage buildConsumeMessage(Messages.Message msg) { - Messages.MessageMetadata meta = msg.getMetadata(); - return new ConsumeMessage( - msg.getId(), - msg.getHeadersMap(), - msg.getPayload().toByteArray(), - meta.getFairnessKey(), - meta.getAttemptCount(), - meta.getQueueId()); + private void retryConsumeOnLeader( + String leaderAddr, String queue, Consumer handler) { + validateLeaderAddr(leaderAddr); + String host = Builder.parseHost(leaderAddr); + int port = Builder.parsePort(leaderAddr); + + try { + SSLContext sslContext = null; + if (caCertPem != null) { + sslContext = Builder.buildSSLContext(caCertPem, clientCertPem, clientKeyPem); + } + Connection leaderConn = Connection.connect(host, port, apiKey, sslContext); + try { + FilaClient leaderClient = + new FilaClient(leaderConn, caCertPem, clientCertPem, clientKeyPem, apiKey, null); + ConsumerHandle handle = leaderClient.consume(queue, handler); + // Block until the consumer thread finishes. This method is called from + // within the original consumer thread, so blocking here is expected. + handle.awaitDone(); + } finally { + leaderConn.close(); + } + } catch (Exception e) { + throw new FilaException("failed to connect to leader at " + leaderAddr, e); + } } - private static FilaException mapEnqueueError(StatusRuntimeException e) { - return switch (e.getStatus().getCode()) { - case NOT_FOUND -> new QueueNotFoundException("enqueue: " + e.getStatus().getDescription()); - default -> new RpcException(e.getStatus().getCode(), e.getStatus().getDescription()); - }; + private void checkForError(Connection.Frame response) { + if (response.header().opcode() == Opcodes.ERROR) { + throw mapErrorFrame(response.body()); + } } - private static FilaException mapConsumeError(StatusRuntimeException e) { - return switch (e.getStatus().getCode()) { - case NOT_FOUND -> new QueueNotFoundException("consume: " + e.getStatus().getDescription()); - default -> new RpcException(e.getStatus().getCode(), e.getStatus().getDescription()); - }; + static FilaException mapErrorFrame(byte[] body) { + if (body.length == 0) { + return new RpcException(Opcodes.ERR_INTERNAL, "empty error frame"); + } + Primitives.Reader r = new Primitives.Reader(body); + int errorCode = r.readU8(); + String message = r.readString(); + // Read metadata map but don't use it for now (preserved for forward compat) + if (r.remaining() > 0) { + r.readStringMap(); + } + return mapErrorCode(errorCode, message); } - private static FilaException mapAckError(StatusRuntimeException e) { - return switch (e.getStatus().getCode()) { - case NOT_FOUND -> new MessageNotFoundException("ack: " + e.getStatus().getDescription()); - default -> new RpcException(e.getStatus().getCode(), e.getStatus().getDescription()); + static FilaException mapErrorCode(int errorCode, String message) { + return switch (errorCode) { + case Opcodes.ERR_QUEUE_NOT_FOUND -> new QueueNotFoundException("queue not found: " + message); + case Opcodes.ERR_MESSAGE_NOT_FOUND -> + new MessageNotFoundException("message not found: " + message); + case Opcodes.ERR_UNAUTHORIZED -> new RpcException(errorCode, "unauthorized: " + message); + case Opcodes.ERR_FORBIDDEN -> new RpcException(errorCode, "forbidden: " + message); + case Opcodes.ERR_NOT_LEADER -> new RpcException(errorCode, "not leader: " + message); + case Opcodes.ERR_QUEUE_ALREADY_EXISTS -> + new RpcException(errorCode, "queue already exists: " + message); + case Opcodes.ERR_CHANNEL_FULL -> new RpcException(errorCode, "channel full: " + message); + case Opcodes.ERR_API_KEY_NOT_FOUND -> + new RpcException(errorCode, "api key not found: " + message); + default -> new RpcException(errorCode, message); }; } - private static FilaException mapNackError(StatusRuntimeException e) { - return switch (e.getStatus().getCode()) { - case NOT_FOUND -> new MessageNotFoundException("nack: " + e.getStatus().getDescription()); - default -> new RpcException(e.getStatus().getCode(), e.getStatus().getDescription()); + private static String errorName(int code) { + return switch (code) { + case Opcodes.ERR_QUEUE_NOT_FOUND -> "queue not found"; + case Opcodes.ERR_MESSAGE_NOT_FOUND -> "message not found"; + case Opcodes.ERR_QUEUE_ALREADY_EXISTS -> "queue already exists"; + case Opcodes.ERR_UNAUTHORIZED -> "unauthorized"; + case Opcodes.ERR_FORBIDDEN -> "forbidden"; + case Opcodes.ERR_NOT_LEADER -> "not leader"; + default -> "error code 0x" + Integer.toHexString(code); }; } + private static void validateLeaderAddr(String addr) { + if (addr == null || addr.isEmpty()) { + throw new FilaException("invalid leader address: empty"); + } + if (addr.contains("//") || addr.contains("/")) { + throw new FilaException("invalid leader address: must be host:port, got: " + addr); + } + int colonIdx = addr.lastIndexOf(':'); + if (colonIdx < 0) { + throw new FilaException("invalid leader address: missing port, got: " + addr); + } + String host = addr.substring(0, colonIdx); + String portStr = addr.substring(colonIdx + 1); + if (host.isEmpty()) { + throw new FilaException("invalid leader address: empty host, got: " + addr); + } + int port; + try { + port = Integer.parseInt(portStr); + } catch (NumberFormatException ex) { + throw new FilaException("invalid leader address: non-numeric port, got: " + addr); + } + if (port < 1 || port > 65535) { + throw new FilaException("invalid leader address: port out of range, got: " + addr); + } + } + /** Builder for {@link FilaClient}. */ public static final class Builder { private final String address; @@ -218,6 +733,7 @@ public static final class Builder { private byte[] clientCertPem; private byte[] clientKeyPem; private String apiKey; + private BatchMode batchMode = BatchMode.auto(); private Builder(String address) { this.address = address; @@ -226,10 +742,6 @@ private Builder(String address) { /** * Enable TLS using the JVM's default trust store (cacerts). * - *

Use this when the Fila server's certificate is issued by a public CA already trusted by - * the JVM. For servers using self-signed or private CA certificates, use {@link - * #withTlsCaCert(byte[])} instead. - * * @return this builder */ public Builder withTls() { @@ -240,9 +752,6 @@ public Builder withTls() { /** * Set the CA certificate for TLS server verification. * - *

When set, the client connects over TLS instead of plaintext. The CA certificate is used to - * verify the server's identity. Implies {@link #withTls()}. - * * @param caCertPem PEM-encoded CA certificate bytes * @return this builder */ @@ -255,9 +764,6 @@ public Builder withTlsCaCert(byte[] caCertPem) { /** * Set the client certificate and key for mutual TLS (mTLS). * - *

Requires either {@link #withTls()} or {@link #withTlsCaCert(byte[])} to be called first. - * When provided, the client presents its certificate to the server for mutual authentication. - * * @param certPem PEM-encoded client certificate bytes * @param keyPem PEM-encoded client private key bytes * @return this builder @@ -271,9 +777,6 @@ public Builder withTlsClientCert(byte[] certPem, byte[] keyPem) { /** * Set an API key for authentication. * - *

When set, the key is sent as a {@code Bearer} token in the {@code authorization} metadata - * header on every outgoing RPC. - * * @param apiKey the API key string * @return this builder */ @@ -282,6 +785,17 @@ public Builder withApiKey(String apiKey) { return this; } + /** + * Set the batching mode for {@link FilaClient#enqueue} calls. + * + * @param batchMode the batch mode + * @return this builder + */ + public Builder withBatchMode(BatchMode batchMode) { + this.batchMode = batchMode; + return this; + } + /** Build and connect the client. */ public FilaClient build() { if (clientCertPem != null && !tlsEnabled) { @@ -289,55 +803,32 @@ public FilaClient build() { "client certificate requires TLS — call withTls() or withTlsCaCert() first"); } - ManagedChannel channel; - - if (tlsEnabled) { - // Parse host/port before the TLS try block so that NumberFormatException - // (a subclass of IllegalArgumentException) from address parsing is not - // misreported as "invalid certificate". - String host = parseHost(address); - int port = parsePort(address); - - try { - TlsChannelCredentials.Builder tlsBuilder = TlsChannelCredentials.newBuilder(); + String host = parseHost(address); + int port = parsePort(address); - if (caCertPem != null) { - tlsBuilder.trustManager(new ByteArrayInputStream(caCertPem)); - } - - if (clientCertPem != null && clientKeyPem != null) { - tlsBuilder.keyManager( - new ByteArrayInputStream(clientCertPem), new ByteArrayInputStream(clientKeyPem)); - } - - ChannelCredentials creds = tlsBuilder.build(); - var channelBuilder = Grpc.newChannelBuilderForAddress(host, port, creds); - - if (apiKey != null) { - channelBuilder.intercept(new ApiKeyInterceptor(apiKey)); - } - - channel = channelBuilder.build(); - } catch (IllegalArgumentException e) { - throw new FilaException("failed to configure TLS: invalid certificate", e); - } catch (IOException e) { - throw new FilaException("failed to configure TLS", e); + try { + SSLContext sslContext = null; + if (tlsEnabled) { + sslContext = buildSSLContext(caCertPem, clientCertPem, clientKeyPem); } - } else { - var channelBuilder = ManagedChannelBuilder.forTarget(address).usePlaintext(); - if (apiKey != null) { - channelBuilder.intercept(new ApiKeyInterceptor(apiKey)); + Connection connection = Connection.connect(host, port, apiKey, sslContext); + + Batcher batcherInstance = null; + if (batchMode.getKind() != BatchMode.Kind.DISABLED) { + batcherInstance = new Batcher(connection, batchMode); } - channel = channelBuilder.build(); + return new FilaClient( + connection, caCertPem, clientCertPem, clientKeyPem, apiKey, batcherInstance); + } catch (FilaException e) { + throw e; + } catch (Exception e) { + throw new FilaException("failed to connect to " + address, e); } - - return new FilaClient(channel); } - private static String parseHost(String address) { - // Handle IPv6 bracket notation: [::1]:5555 + static String parseHost(String address) { if (address.startsWith("[")) { int closeBracket = address.indexOf(']'); if (closeBracket < 0) { @@ -352,8 +843,7 @@ private static String parseHost(String address) { return address.substring(0, colonIdx); } - private static int parsePort(String address) { - // Handle IPv6 bracket notation: [::1]:5555 + static int parsePort(String address) { if (address.startsWith("[")) { int closeBracket = address.indexOf(']'); if (closeBracket < 0 || closeBracket + 2 > address.length()) { @@ -367,5 +857,67 @@ private static int parsePort(String address) { } return Integer.parseInt(address.substring(colonIdx + 1)); } + + static SSLContext buildSSLContext(byte[] caCertPem, byte[] clientCertPem, byte[] clientKeyPem) { + try { + SSLContext sslContext = SSLContext.getInstance("TLS"); + + TrustManagerFactory tmf = null; + if (caCertPem != null) { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + X509Certificate caCert = + (X509Certificate) cf.generateCertificate(new ByteArrayInputStream(caCertPem)); + KeyStore ts = KeyStore.getInstance(KeyStore.getDefaultType()); + ts.load(null, null); + ts.setCertificateEntry("ca", caCert); + tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(ts); + } + + KeyManagerFactory kmf = null; + if (clientCertPem != null && clientKeyPem != null) { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + X509Certificate clientCert = + (X509Certificate) cf.generateCertificate(new ByteArrayInputStream(clientCertPem)); + + // Parse PEM private key + String keyPemStr = new String(clientKeyPem); + keyPemStr = + keyPemStr + .replace("-----BEGIN PRIVATE KEY-----", "") + .replace("-----END PRIVATE KEY-----", "") + .replace("-----BEGIN EC PRIVATE KEY-----", "") + .replace("-----END EC PRIVATE KEY-----", "") + .replaceAll("\\s", ""); + byte[] keyDer = Base64.getDecoder().decode(keyPemStr); + + java.security.spec.PKCS8EncodedKeySpec keySpec = + new java.security.spec.PKCS8EncodedKeySpec(keyDer); + + // Try EC first, then RSA + java.security.PrivateKey privateKey; + try { + privateKey = KeyFactory.getInstance("EC").generatePrivate(keySpec); + } catch (Exception e) { + privateKey = KeyFactory.getInstance("RSA").generatePrivate(keySpec); + } + + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + ks.load(null, null); + ks.setKeyEntry( + "client", privateKey, new char[0], new java.security.cert.Certificate[] {clientCert}); + kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + kmf.init(ks, new char[0]); + } + + sslContext.init( + kmf != null ? kmf.getKeyManagers() : null, + tmf != null ? tmf.getTrustManagers() : null, + null); + return sslContext; + } catch (Exception e) { + throw new FilaException("failed to configure TLS: invalid certificate", e); + } + } } } diff --git a/src/main/java/dev/faisca/fila/RpcException.java b/src/main/java/dev/faisca/fila/RpcException.java index 55241c5..2162a75 100644 --- a/src/main/java/dev/faisca/fila/RpcException.java +++ b/src/main/java/dev/faisca/fila/RpcException.java @@ -1,18 +1,16 @@ package dev.faisca.fila; -import io.grpc.Status; - -/** Thrown for unexpected gRPC failures not mapped to a specific Fila exception. */ +/** Thrown for protocol-level failures not mapped to a specific Fila exception. */ public class RpcException extends FilaException { - private final Status.Code code; + private final int errorCode; - public RpcException(Status.Code code, String message) { + public RpcException(int errorCode, String message) { super(message); - this.code = code; + this.errorCode = errorCode; } - /** Returns the gRPC status code of the failed call. */ - public Status.Code getCode() { - return code; + /** Returns the FIBP error code of the failed call. */ + public int getErrorCode() { + return errorCode; } } diff --git a/src/main/java/dev/faisca/fila/fibp/Codec.java b/src/main/java/dev/faisca/fila/fibp/Codec.java new file mode 100644 index 0000000..6a9f786 --- /dev/null +++ b/src/main/java/dev/faisca/fila/fibp/Codec.java @@ -0,0 +1,216 @@ +package dev.faisca.fila.fibp; + +import java.util.Map; + +/** + * Encodes and decodes FIBP frames. + * + *

Encoding methods produce complete frame bodies (including the 6-byte header). Decoding methods + * consume the body bytes after the header has been parsed. + */ +public final class Codec { + private Codec() {} + + // --- Frame body encoding (header + payload) --- + + /** Encode a complete frame: [u32 length][body]. */ + public static byte[] encodeFrame(byte opcode, byte flags, int requestId, byte[] bodyPayload) { + int bodyLen = FrameHeader.SIZE + bodyPayload.length; + Primitives.Writer w = new Primitives.Writer(4 + bodyLen); + w.writeU32(bodyLen); + w.writeU8(opcode & 0xFF); + w.writeU8(flags & 0xFF); + w.writeU32(requestId); + if (bodyPayload.length > 0) { + // Write raw bytes directly + byte[] frameBytes = w.toByteArray(); + byte[] result = new byte[frameBytes.length + bodyPayload.length]; + System.arraycopy(frameBytes, 0, result, 0, frameBytes.length); + System.arraycopy(bodyPayload, 0, result, frameBytes.length, bodyPayload.length); + return result; + } + return w.toByteArray(); + } + + /** Encode a Handshake frame (0x01). */ + public static byte[] encodeHandshake(int requestId, int version, String apiKey) { + Primitives.Writer w = new Primitives.Writer(32); + w.writeU16(version); + w.writeOptionalString(apiKey); + return encodeFrame(Opcodes.HANDSHAKE, (byte) 0, requestId, w.toByteArray()); + } + + /** Encode a Ping frame (0x03). */ + public static byte[] encodePing(int requestId) { + return encodeFrame(Opcodes.PING, (byte) 0, requestId, new byte[0]); + } + + /** Encode a Pong frame (0x04). */ + public static byte[] encodePong(int requestId) { + return encodeFrame(Opcodes.PONG, (byte) 0, requestId, new byte[0]); + } + + /** Encode a Disconnect frame (0x05). */ + public static byte[] encodeDisconnect(int requestId) { + return encodeFrame(Opcodes.DISCONNECT, (byte) 0, requestId, new byte[0]); + } + + /** Encode an Enqueue frame (0x10) for a batch of messages. */ + public static byte[] encodeEnqueue( + int requestId, String[] queues, Map[] headers, byte[][] payloads) { + Primitives.Writer w = new Primitives.Writer(256); + w.writeU32(queues.length); + for (int i = 0; i < queues.length; i++) { + w.writeString(queues[i]); + w.writeStringMap(headers[i]); + w.writeBytes(payloads[i]); + } + return encodeFrame(Opcodes.ENQUEUE, (byte) 0, requestId, w.toByteArray()); + } + + /** Encode a Consume frame (0x12). */ + public static byte[] encodeConsume(int requestId, String queue) { + Primitives.Writer w = new Primitives.Writer(32); + w.writeString(queue); + return encodeFrame(Opcodes.CONSUME, (byte) 0, requestId, w.toByteArray()); + } + + /** Encode a CancelConsume frame (0x14). */ + public static byte[] encodeCancelConsume(int requestId) { + return encodeFrame(Opcodes.CANCEL_CONSUME, (byte) 0, requestId, new byte[0]); + } + + /** Encode an Ack frame (0x16). */ + public static byte[] encodeAck(int requestId, String[] queues, String[] messageIds) { + Primitives.Writer w = new Primitives.Writer(128); + w.writeU32(queues.length); + for (int i = 0; i < queues.length; i++) { + w.writeString(queues[i]); + w.writeString(messageIds[i]); + } + return encodeFrame(Opcodes.ACK, (byte) 0, requestId, w.toByteArray()); + } + + /** Encode a Nack frame (0x18). */ + public static byte[] encodeNack( + int requestId, String[] queues, String[] messageIds, String[] errors) { + Primitives.Writer w = new Primitives.Writer(128); + w.writeU32(queues.length); + for (int i = 0; i < queues.length; i++) { + w.writeString(queues[i]); + w.writeString(messageIds[i]); + w.writeString(errors[i]); + } + return encodeFrame(Opcodes.NACK, (byte) 0, requestId, w.toByteArray()); + } + + // --- Admin opcodes --- + + /** Encode a CreateQueue frame (0xFD). */ + public static byte[] encodeCreateQueue( + int requestId, + String name, + String onEnqueueScript, + String onFailureScript, + long visibilityTimeoutMs) { + Primitives.Writer w = new Primitives.Writer(64); + w.writeString(name); + w.writeOptionalString(onEnqueueScript); + w.writeOptionalString(onFailureScript); + w.writeU64(visibilityTimeoutMs); + return encodeFrame(Opcodes.CREATE_QUEUE, (byte) 0, requestId, w.toByteArray()); + } + + /** Encode a DeleteQueue frame (0xFB). */ + public static byte[] encodeDeleteQueue(int requestId, String queue) { + Primitives.Writer w = new Primitives.Writer(32); + w.writeString(queue); + return encodeFrame(Opcodes.DELETE_QUEUE, (byte) 0, requestId, w.toByteArray()); + } + + /** Encode a GetStats frame (0xF9). */ + public static byte[] encodeGetStats(int requestId, String queue) { + Primitives.Writer w = new Primitives.Writer(32); + w.writeString(queue); + return encodeFrame(Opcodes.GET_STATS, (byte) 0, requestId, w.toByteArray()); + } + + /** Encode a ListQueues frame (0xF7). */ + public static byte[] encodeListQueues(int requestId) { + return encodeFrame(Opcodes.LIST_QUEUES, (byte) 0, requestId, new byte[0]); + } + + /** Encode a SetConfig frame (0xF5). */ + public static byte[] encodeSetConfig(int requestId, String key, String value) { + Primitives.Writer w = new Primitives.Writer(64); + w.writeString(key); + w.writeString(value); + return encodeFrame(Opcodes.SET_CONFIG, (byte) 0, requestId, w.toByteArray()); + } + + /** Encode a GetConfig frame (0xF3). */ + public static byte[] encodeGetConfig(int requestId, String key) { + Primitives.Writer w = new Primitives.Writer(32); + w.writeString(key); + return encodeFrame(Opcodes.GET_CONFIG, (byte) 0, requestId, w.toByteArray()); + } + + /** Encode a ListConfig frame (0xF1). */ + public static byte[] encodeListConfig(int requestId, String prefix) { + Primitives.Writer w = new Primitives.Writer(32); + w.writeString(prefix); + return encodeFrame(Opcodes.LIST_CONFIG, (byte) 0, requestId, w.toByteArray()); + } + + /** Encode a Redrive frame (0xEF). */ + public static byte[] encodeRedrive(int requestId, String dlqQueue, long count) { + Primitives.Writer w = new Primitives.Writer(32); + w.writeString(dlqQueue); + w.writeU64(count); + return encodeFrame(Opcodes.REDRIVE, (byte) 0, requestId, w.toByteArray()); + } + + // --- Auth opcodes --- + + /** Encode a CreateApiKey frame (0xED). */ + public static byte[] encodeCreateApiKey( + int requestId, String name, long expiresAtMs, boolean isSuperadmin) { + Primitives.Writer w = new Primitives.Writer(64); + w.writeString(name); + w.writeU64(expiresAtMs); + w.writeBool(isSuperadmin); + return encodeFrame(Opcodes.CREATE_API_KEY, (byte) 0, requestId, w.toByteArray()); + } + + /** Encode a RevokeApiKey frame (0xEB). */ + public static byte[] encodeRevokeApiKey(int requestId, String keyId) { + Primitives.Writer w = new Primitives.Writer(32); + w.writeString(keyId); + return encodeFrame(Opcodes.REVOKE_API_KEY, (byte) 0, requestId, w.toByteArray()); + } + + /** Encode a ListApiKeys frame (0xE9). */ + public static byte[] encodeListApiKeys(int requestId) { + return encodeFrame(Opcodes.LIST_API_KEYS, (byte) 0, requestId, new byte[0]); + } + + /** Encode a SetAcl frame (0xE7). */ + public static byte[] encodeSetAcl( + int requestId, String keyId, String[] kinds, String[] patterns) { + Primitives.Writer w = new Primitives.Writer(128); + w.writeString(keyId); + w.writeU16(kinds.length); + for (int i = 0; i < kinds.length; i++) { + w.writeString(kinds[i]); + w.writeString(patterns[i]); + } + return encodeFrame(Opcodes.SET_ACL, (byte) 0, requestId, w.toByteArray()); + } + + /** Encode a GetAcl frame (0xE5). */ + public static byte[] encodeGetAcl(int requestId, String keyId) { + Primitives.Writer w = new Primitives.Writer(32); + w.writeString(keyId); + return encodeFrame(Opcodes.GET_ACL, (byte) 0, requestId, w.toByteArray()); + } +} diff --git a/src/main/java/dev/faisca/fila/fibp/Connection.java b/src/main/java/dev/faisca/fila/fibp/Connection.java new file mode 100644 index 0000000..6722ccb --- /dev/null +++ b/src/main/java/dev/faisca/fila/fibp/Connection.java @@ -0,0 +1,311 @@ +package dev.faisca.fila.fibp; + +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.net.Socket; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +/** + * Manages a TCP connection to a Fila server using the FIBP binary protocol. + * + *

Handles handshake, request/response multiplexing, server-push delivery routing, and keepalive. + */ +public final class Connection implements AutoCloseable { + + /** A received frame (header + body after header). */ + public record Frame(FrameHeader header, byte[] body) {} + + private final Socket socket; + private final DataInputStream input; + private final OutputStream output; + private final Object writeLock = new Object(); + + private final AtomicInteger nextRequestId = new AtomicInteger(1); + private final ConcurrentHashMap> waiters = + new ConcurrentHashMap<>(); + private final ConcurrentHashMap> deliveryChannels = + new ConcurrentHashMap<>(); + + // Continuation frame reassembly buffers, keyed by request ID. + private final ConcurrentHashMap continuationBuffers = + new ConcurrentHashMap<>(); + // Track the opcode for each continuation sequence. + private final ConcurrentHashMap continuationOpcodes = new ConcurrentHashMap<>(); + + private final AtomicBoolean closed = new AtomicBoolean(false); + private final Thread readLoop; + + private long nodeId; + private int maxFrameSize; + + private Connection(Socket socket) throws IOException { + this.socket = socket; + this.input = new DataInputStream(socket.getInputStream()); + this.output = socket.getOutputStream(); + this.readLoop = new Thread(this::readLoopRun, "fibp-read-loop"); + this.readLoop.setDaemon(true); + } + + /** + * Connect to a Fila server and perform the FIBP handshake. + * + * @param host server host + * @param port server port + * @param apiKey optional API key (null if auth disabled) + * @param sslContext optional SSLContext for TLS (null for plaintext) + * @return a connected Connection + */ + public static Connection connect(String host, int port, String apiKey, SSLContext sslContext) + throws IOException { + Socket socket; + if (sslContext != null) { + SSLSocketFactory factory = sslContext.getSocketFactory(); + SSLSocket sslSocket = (SSLSocket) factory.createSocket(host, port); + sslSocket.startHandshake(); + socket = sslSocket; + } else { + socket = new Socket(host, port); + } + socket.setTcpNoDelay(true); + + Connection conn = new Connection(socket); + try { + conn.performHandshake(apiKey); + } catch (IOException e) { + try { + socket.close(); + } catch (IOException suppressed) { + e.addSuppressed(suppressed); + } + throw e; + } + conn.readLoop.start(); + return conn; + } + + /** Returns the next request ID. */ + public int nextRequestId() { + return nextRequestId.getAndIncrement(); + } + + /** Returns the server's node ID from the handshake. */ + public long getNodeId() { + return nodeId; + } + + /** Returns the server's max frame size from the handshake. */ + public int getMaxFrameSize() { + return maxFrameSize; + } + + /** Send raw bytes (a complete encoded frame) to the server. */ + public void send(byte[] frameBytes) throws IOException { + synchronized (writeLock) { + output.write(frameBytes); + output.flush(); + } + } + + /** + * Send a request and wait for the response. + * + * @param frameBytes the complete encoded frame + * @param requestId the request ID used in the frame + * @param timeoutMs timeout in milliseconds + * @return the response frame + */ + public Frame sendAndReceive(byte[] frameBytes, int requestId, long timeoutMs) + throws IOException, InterruptedException { + LinkedBlockingQueue queue = new LinkedBlockingQueue<>(1); + waiters.put(requestId, queue); + try { + send(frameBytes); + Frame frame = queue.poll(timeoutMs, TimeUnit.MILLISECONDS); + if (frame == null) { + throw new IOException("request timed out (requestId=" + requestId + ")"); + } + return frame; + } finally { + waiters.remove(requestId); + } + } + + /** + * Register a delivery channel for a consume subscription. + * + * @param requestId the consume request ID + * @return a queue that will receive delivery frames + */ + public LinkedBlockingQueue registerDeliveryChannel(int requestId) { + LinkedBlockingQueue queue = new LinkedBlockingQueue<>(); + deliveryChannels.put(requestId, queue); + return queue; + } + + /** Unregister a delivery channel. */ + public void unregisterDeliveryChannel(int requestId) { + deliveryChannels.remove(requestId); + } + + /** Returns true if the connection is closed. */ + public boolean isClosed() { + return closed.get(); + } + + @Override + public void close() { + if (closed.compareAndSet(false, true)) { + // Try to send disconnect + try { + send(Codec.encodeDisconnect(0)); + } catch (IOException ignored) { + // best effort + } + + try { + socket.close(); + } catch (IOException ignored) { + // best effort + } + + readLoop.interrupt(); + + // Fail all pending waiters + for (Map.Entry> entry : waiters.entrySet()) { + // Signal by offering a sentinel (or the waiter will timeout) + entry + .getValue() + .offer( + new Frame(new FrameHeader(Opcodes.ERROR, (byte) 0, entry.getKey()), new byte[0])); + } + } + } + + private void performHandshake(String apiKey) throws IOException { + byte[] handshakeFrame = Codec.encodeHandshake(0, Opcodes.PROTOCOL_VERSION, apiKey); + send(handshakeFrame); + + // Read the response synchronously (before read loop starts). + // Loop to skip continuation frames (readFrame returns null for them). + Frame response; + do { + response = readFrame(); + } while (response == null); + if (response.header().opcode() == Opcodes.HANDSHAKE_OK) { + Primitives.Reader r = new Primitives.Reader(response.body()); + int negotiatedVersion = r.readU16(); + this.nodeId = r.readU64(); + int maxFrame = r.readU32AsInt(); + this.maxFrameSize = maxFrame == 0 ? Opcodes.DEFAULT_MAX_FRAME_SIZE : maxFrame; + } else if (response.header().opcode() == Opcodes.ERROR) { + Primitives.Reader r = new Primitives.Reader(response.body()); + int errorCode = r.readU8(); + String message = r.readString(); + throw new IOException( + "handshake rejected: code=0x" + Integer.toHexString(errorCode) + " message=" + message); + } else { + throw new IOException( + "unexpected handshake response opcode: 0x" + + Integer.toHexString(response.header().opcode() & 0xFF)); + } + } + + private Frame readFrame() throws IOException { + // Read 4-byte length prefix + byte[] lenBytes = new byte[4]; + input.readFully(lenBytes); + int bodyLen = ByteBuffer.wrap(lenBytes).order(ByteOrder.BIG_ENDIAN).getInt(); + + // Read body + byte[] body = new byte[bodyLen]; + input.readFully(body); + + // Parse header (first 6 bytes of body) + byte opcode = body[0]; + byte flags = body[1]; + int requestId = ByteBuffer.wrap(body, 2, 4).order(ByteOrder.BIG_ENDIAN).getInt(); + + // Payload is everything after the 6-byte header + byte[] payload = new byte[bodyLen - FrameHeader.SIZE]; + System.arraycopy(body, FrameHeader.SIZE, payload, 0, payload.length); + + FrameHeader header = new FrameHeader(opcode, flags, requestId); + + // Handle continuation frames + if (header.isContinuation()) { + ByteArrayOutputStream buf = + continuationBuffers.computeIfAbsent(requestId, k -> new ByteArrayOutputStream()); + continuationOpcodes.putIfAbsent(requestId, opcode); + buf.write(payload); + // Return null to signal the caller to read another frame + return null; + } else { + // Check if we have buffered continuation data + ByteArrayOutputStream contBuf = continuationBuffers.remove(requestId); + Byte contOpcode = continuationOpcodes.remove(requestId); + if (contBuf != null) { + contBuf.write(payload); + byte resolvedOpcode = contOpcode != null ? contOpcode : opcode; + return new Frame( + new FrameHeader(resolvedOpcode, (byte) 0, requestId), contBuf.toByteArray()); + } + return new Frame(header, payload); + } + } + + private void readLoopRun() { + try { + while (!closed.get() && !Thread.currentThread().isInterrupted()) { + Frame frame = readFrame(); + if (frame == null) { + // Continuation frame buffered, read next + continue; + } + + byte opcode = frame.header().opcode(); + int reqId = frame.header().requestId(); + + if (opcode == Opcodes.PING) { + // Respond with Pong + try { + send(Codec.encodePong(reqId)); + } catch (IOException ignored) { + // connection closing + } + continue; + } + + if (opcode == Opcodes.DELIVERY) { + // Route to delivery channel + LinkedBlockingQueue ch = deliveryChannels.get(reqId); + if (ch != null) { + ch.offer(frame); + } + continue; + } + + // Route to request waiter + LinkedBlockingQueue waiter = waiters.get(reqId); + if (waiter != null) { + waiter.offer(frame); + } + } + } catch (IOException e) { + if (!closed.get()) { + close(); + } + } + } +} diff --git a/src/main/java/dev/faisca/fila/fibp/FrameHeader.java b/src/main/java/dev/faisca/fila/fibp/FrameHeader.java new file mode 100644 index 0000000..6a9e952 --- /dev/null +++ b/src/main/java/dev/faisca/fila/fibp/FrameHeader.java @@ -0,0 +1,13 @@ +package dev.faisca.fila.fibp; + +/** Represents the 6-byte frame header in the FIBP protocol. */ +public record FrameHeader(byte opcode, byte flags, int requestId) { + + /** Size of the frame header in bytes. */ + public static final int SIZE = 6; + + /** Returns true if the CONTINUATION flag is set. */ + public boolean isContinuation() { + return (flags & Opcodes.FLAG_CONTINUATION) != 0; + } +} diff --git a/src/main/java/dev/faisca/fila/fibp/Opcodes.java b/src/main/java/dev/faisca/fila/fibp/Opcodes.java new file mode 100644 index 0000000..159caa4 --- /dev/null +++ b/src/main/java/dev/faisca/fila/fibp/Opcodes.java @@ -0,0 +1,85 @@ +package dev.faisca.fila.fibp; + +/** FIBP protocol opcodes and error codes. */ +public final class Opcodes { + private Opcodes() {} + + // Protocol version + public static final int PROTOCOL_VERSION = 1; + + // Default max frame size (16 MiB) + public static final int DEFAULT_MAX_FRAME_SIZE = 16 * 1024 * 1024; + + // Control opcodes (0x00-0x0F) + public static final byte HANDSHAKE = 0x01; + public static final byte HANDSHAKE_OK = 0x02; + public static final byte PING = 0x03; + public static final byte PONG = 0x04; + public static final byte DISCONNECT = 0x05; + + // Hot-path opcodes (0x10-0x1F) + public static final byte ENQUEUE = 0x10; + public static final byte ENQUEUE_RESULT = 0x11; + public static final byte CONSUME = 0x12; + public static final byte CONSUME_OK = 0x13; + public static final byte DELIVERY = 0x14; + public static final byte CANCEL_CONSUME = 0x15; + public static final byte ACK = 0x16; + public static final byte ACK_RESULT = 0x17; + public static final byte NACK = 0x18; + public static final byte NACK_RESULT = 0x19; + + // Error opcode + public static final byte ERROR = (byte) 0xFE; + + // Admin opcodes (0xFD downward) + public static final byte CREATE_QUEUE = (byte) 0xFD; + public static final byte CREATE_QUEUE_RESULT = (byte) 0xFC; + public static final byte DELETE_QUEUE = (byte) 0xFB; + public static final byte DELETE_QUEUE_RESULT = (byte) 0xFA; + public static final byte GET_STATS = (byte) 0xF9; + public static final byte GET_STATS_RESULT = (byte) 0xF8; + public static final byte LIST_QUEUES = (byte) 0xF7; + public static final byte LIST_QUEUES_RESULT = (byte) 0xF6; + public static final byte SET_CONFIG = (byte) 0xF5; + public static final byte SET_CONFIG_RESULT = (byte) 0xF4; + public static final byte GET_CONFIG = (byte) 0xF3; + public static final byte GET_CONFIG_RESULT = (byte) 0xF2; + public static final byte LIST_CONFIG = (byte) 0xF1; + public static final byte LIST_CONFIG_RESULT = (byte) 0xF0; + public static final byte REDRIVE = (byte) 0xEF; + public static final byte REDRIVE_RESULT = (byte) 0xEE; + public static final byte CREATE_API_KEY = (byte) 0xED; + public static final byte CREATE_API_KEY_RESULT = (byte) 0xEC; + public static final byte REVOKE_API_KEY = (byte) 0xEB; + public static final byte REVOKE_API_KEY_RESULT = (byte) 0xEA; + public static final byte LIST_API_KEYS = (byte) 0xE9; + public static final byte LIST_API_KEYS_RESULT = (byte) 0xE8; + public static final byte SET_ACL = (byte) 0xE7; + public static final byte SET_ACL_RESULT = (byte) 0xE6; + public static final byte GET_ACL = (byte) 0xE5; + public static final byte GET_ACL_RESULT = (byte) 0xE4; + + // Error codes + public static final byte ERR_OK = 0x00; + public static final byte ERR_QUEUE_NOT_FOUND = 0x01; + public static final byte ERR_MESSAGE_NOT_FOUND = 0x02; + public static final byte ERR_QUEUE_ALREADY_EXISTS = 0x03; + public static final byte ERR_LUA_COMPILATION = 0x04; + public static final byte ERR_STORAGE = 0x05; + public static final byte ERR_NOT_A_DLQ = 0x06; + public static final byte ERR_PARENT_QUEUE_NOT_FOUND = 0x07; + public static final byte ERR_INVALID_CONFIG_VALUE = 0x08; + public static final byte ERR_CHANNEL_FULL = 0x09; + public static final byte ERR_UNAUTHORIZED = 0x0A; + public static final byte ERR_FORBIDDEN = 0x0B; + public static final byte ERR_NOT_LEADER = 0x0C; + public static final byte ERR_UNSUPPORTED_VERSION = 0x0D; + public static final byte ERR_INVALID_FRAME = 0x0E; + public static final byte ERR_API_KEY_NOT_FOUND = 0x0F; + public static final byte ERR_NODE_NOT_READY = 0x10; + public static final byte ERR_INTERNAL = (byte) 0xFF; + + // Flags + public static final byte FLAG_CONTINUATION = 0x01; +} diff --git a/src/main/java/dev/faisca/fila/fibp/Primitives.java b/src/main/java/dev/faisca/fila/fibp/Primitives.java new file mode 100644 index 0000000..fc2e7c7 --- /dev/null +++ b/src/main/java/dev/faisca/fila/fibp/Primitives.java @@ -0,0 +1,234 @@ +package dev.faisca.fila.fibp; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Reader and writer for FIBP primitive types using big-endian encoding. + * + *

The writer accumulates bytes into an expanding buffer. The reader consumes bytes from a + * provided byte array. + */ +public final class Primitives { + private Primitives() {} + + /** Writer that accumulates encoded bytes. */ + public static final class Writer { + private ByteBuffer buf; + + public Writer() { + this(256); + } + + public Writer(int initialCapacity) { + buf = ByteBuffer.allocate(initialCapacity).order(ByteOrder.BIG_ENDIAN); + } + + private void ensureCapacity(int needed) { + if (buf.remaining() < needed) { + int newCap = Math.max(buf.capacity() * 2, buf.position() + needed); + ByteBuffer newBuf = ByteBuffer.allocate(newCap).order(ByteOrder.BIG_ENDIAN); + buf.flip(); + newBuf.put(buf); + buf = newBuf; + } + } + + public void writeU8(int value) { + ensureCapacity(1); + buf.put((byte) (value & 0xFF)); + } + + public void writeU16(int value) { + ensureCapacity(2); + buf.putShort((short) (value & 0xFFFF)); + } + + public void writeU32(int value) { + ensureCapacity(4); + buf.putInt(value); + } + + public void writeU32(long value) { + ensureCapacity(4); + buf.putInt((int) (value & 0xFFFFFFFFL)); + } + + public void writeU64(long value) { + ensureCapacity(8); + buf.putLong(value); + } + + public void writeI64(long value) { + ensureCapacity(8); + buf.putLong(value); + } + + public void writeF64(double value) { + ensureCapacity(8); + buf.putDouble(value); + } + + public void writeBool(boolean value) { + writeU8(value ? 1 : 0); + } + + public void writeString(String value) { + byte[] bytes = value.getBytes(StandardCharsets.UTF_8); + if (bytes.length > 65535) { + throw new IllegalArgumentException( + "string exceeds u16 max length: " + bytes.length + " bytes"); + } + writeU16(bytes.length); + ensureCapacity(bytes.length); + buf.put(bytes); + } + + public void writeBytes(byte[] value) { + writeU32(value.length); + ensureCapacity(value.length); + buf.put(value); + } + + public void writeStringMap(Map map) { + if (map.size() > 65535) { + throw new IllegalArgumentException("map exceeds u16 max entry count: " + map.size()); + } + writeU16(map.size()); + for (Map.Entry entry : map.entrySet()) { + writeString(entry.getKey()); + writeString(entry.getValue()); + } + } + + public void writeStringList(String[] list) { + writeU16(list.length); + for (String s : list) { + writeString(s); + } + } + + public void writeOptionalString(String value) { + if (value == null) { + writeU8(0); + } else { + writeU8(1); + writeString(value); + } + } + + /** Returns the encoded bytes. */ + public byte[] toByteArray() { + byte[] result = new byte[buf.position()]; + buf.flip(); + buf.get(result); + buf.flip(); // restore position for potential further use + buf.position(result.length); + return result; + } + + /** Returns current write position (number of bytes written). */ + public int position() { + return buf.position(); + } + } + + /** Reader that consumes bytes from a buffer. */ + public static final class Reader { + private final ByteBuffer buf; + + public Reader(byte[] data) { + buf = ByteBuffer.wrap(data).order(ByteOrder.BIG_ENDIAN); + } + + public Reader(byte[] data, int offset, int length) { + buf = ByteBuffer.wrap(data, offset, length).order(ByteOrder.BIG_ENDIAN); + } + + public int readU8() { + return buf.get() & 0xFF; + } + + public int readU16() { + return buf.getShort() & 0xFFFF; + } + + public int readU32AsInt() { + return buf.getInt(); + } + + public long readU32() { + return buf.getInt() & 0xFFFFFFFFL; + } + + public long readU64() { + return buf.getLong(); + } + + public long readI64() { + return buf.getLong(); + } + + public double readF64() { + return buf.getDouble(); + } + + public boolean readBool() { + return readU8() != 0; + } + + public String readString() { + int len = readU16(); + byte[] bytes = new byte[len]; + buf.get(bytes); + return new String(bytes, StandardCharsets.UTF_8); + } + + public byte[] readBytes() { + int len = readU32AsInt(); + if (len < 0 || len > buf.remaining()) { + throw new IllegalArgumentException( + "invalid byte array length: " + len + " (remaining: " + buf.remaining() + ")"); + } + byte[] bytes = new byte[len]; + buf.get(bytes); + return bytes; + } + + public Map readStringMap() { + int count = readU16(); + Map map = new LinkedHashMap<>(count); + for (int i = 0; i < count; i++) { + String key = readString(); + String value = readString(); + map.put(key, value); + } + return map; + } + + public String[] readStringList() { + int count = readU16(); + String[] list = new String[count]; + for (int i = 0; i < count; i++) { + list[i] = readString(); + } + return list; + } + + public String readOptionalString() { + int present = readU8(); + if (present == 0) { + return null; + } + return readString(); + } + + /** Returns the number of remaining bytes. */ + public int remaining() { + return buf.remaining(); + } + } +} diff --git a/src/test/java/dev/faisca/fila/BatchClientTest.java b/src/test/java/dev/faisca/fila/BatchClientTest.java new file mode 100644 index 0000000..90fc203 --- /dev/null +++ b/src/test/java/dev/faisca/fila/BatchClientTest.java @@ -0,0 +1,234 @@ +package dev.faisca.fila; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIf; + +/** + * Integration tests for enqueueMany and smart batching. + * + *

Requires a fila-server binary. Skipped if not available. + */ +@EnabledIf("serverAvailable") +class BatchClientTest { + private static TestServer server; + + @BeforeAll + static void setUp() throws Exception { + server = TestServer.start(); + server.createQueue("test-batch-explicit"); + server.createQueue("test-batch-auto"); + server.createQueue("test-batch-linger"); + server.createQueue("test-batch-disabled"); + server.createQueue("test-batch-consume"); + server.createQueue("test-batch-mixed"); + } + + @AfterAll + static void tearDown() { + if (server != null) server.stop(); + } + + static boolean serverAvailable() { + return TestServer.isBinaryAvailable(); + } + + @Test + void explicitEnqueueMany() { + try (FilaClient client = + FilaClient.builder(server.address()).withBatchMode(BatchMode.disabled()).build()) { + List messages = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + messages.add( + new EnqueueMessage( + "test-batch-explicit", + Map.of("idx", String.valueOf(i)), + ("batch-msg-" + i).getBytes())); + } + + List results = client.enqueueMany(messages); + assertEquals(5, results.size()); + + Set ids = new HashSet<>(); + for (EnqueueResult result : results) { + assertTrue(result.isSuccess(), "each message should succeed"); + assertFalse(result.getMessageId().isEmpty()); + ids.add(result.getMessageId()); + } + assertEquals(5, ids.size(), "all message IDs should be unique"); + } + } + + @Test + void explicitEnqueueManyWithNonexistentQueue() { + try (FilaClient client = + FilaClient.builder(server.address()).withBatchMode(BatchMode.disabled()).build()) { + List messages = new ArrayList<>(); + messages.add(new EnqueueMessage("test-batch-explicit", Map.of(), "good-msg".getBytes())); + messages.add(new EnqueueMessage("no-such-queue", Map.of(), "bad-msg".getBytes())); + messages.add(new EnqueueMessage("test-batch-explicit", Map.of(), "another-good".getBytes())); + + List results = client.enqueueMany(messages); + assertEquals(3, results.size()); + + // First and third should succeed, second should fail + assertTrue(results.get(0).isSuccess()); + assertFalse(results.get(1).isSuccess()); + assertTrue(results.get(2).isSuccess()); + } + } + + @Test + void autoBatchingEnqueue() throws Exception { + try (FilaClient client = + FilaClient.builder(server.address()).withBatchMode(BatchMode.auto()).build()) { + String msgId = + client.enqueue("test-batch-auto", Map.of("mode", "auto"), "auto-msg".getBytes()); + assertNotNull(msgId); + assertFalse(msgId.isEmpty()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference received = new AtomicReference<>(); + + ConsumerHandle handle = + client.consume( + "test-batch-auto", + msg -> { + received.set(msg); + client.ack("test-batch-auto", msg.getId()); + latch.countDown(); + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS), "should receive message within 10s"); + handle.cancel(); + + ConsumeMessage msg = received.get(); + assertNotNull(msg); + assertEquals(msgId, msg.getId()); + assertEquals("auto", msg.getHeaders().get("mode")); + assertArrayEquals("auto-msg".getBytes(), msg.getPayload()); + } + } + + @Test + void autoBatchingMultipleMessages() throws Exception { + try (FilaClient client = + FilaClient.builder(server.address()).withBatchMode(BatchMode.auto(50)).build()) { + int count = 10; + Set sentIds = new HashSet<>(); + for (int i = 0; i < count; i++) { + String msgId = + client.enqueue( + "test-batch-consume", Map.of("idx", String.valueOf(i)), ("msg-" + i).getBytes()); + assertNotNull(msgId); + sentIds.add(msgId); + } + assertEquals(count, sentIds.size(), "all message IDs should be unique"); + + CountDownLatch latch = new CountDownLatch(count); + Set receivedIds = java.util.Collections.synchronizedSet(new HashSet<>()); + + ConsumerHandle handle = + client.consume( + "test-batch-consume", + msg -> { + receivedIds.add(msg.getId()); + client.ack("test-batch-consume", msg.getId()); + latch.countDown(); + }); + + assertTrue(latch.await(15, TimeUnit.SECONDS), "should receive all messages within 15s"); + handle.cancel(); + + assertEquals(sentIds, receivedIds, "should receive all sent messages"); + } + } + + @Test + void lingerBatchingEnqueue() throws Exception { + try (FilaClient client = + FilaClient.builder(server.address()).withBatchMode(BatchMode.linger(50, 10)).build()) { + String msgId = + client.enqueue("test-batch-linger", Map.of("mode", "linger"), "linger-msg".getBytes()); + assertNotNull(msgId); + assertFalse(msgId.isEmpty()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference received = new AtomicReference<>(); + + ConsumerHandle handle = + client.consume( + "test-batch-linger", + msg -> { + received.set(msg); + client.ack("test-batch-linger", msg.getId()); + latch.countDown(); + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS), "should receive message within 10s"); + handle.cancel(); + + assertEquals(msgId, received.get().getId()); + } + } + + @Test + void disabledBatchingEnqueue() throws Exception { + try (FilaClient client = + FilaClient.builder(server.address()).withBatchMode(BatchMode.disabled()).build()) { + String msgId = + client.enqueue( + "test-batch-disabled", Map.of("mode", "disabled"), "direct-msg".getBytes()); + assertNotNull(msgId); + assertFalse(msgId.isEmpty()); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference received = new AtomicReference<>(); + + ConsumerHandle handle = + client.consume( + "test-batch-disabled", + msg -> { + received.set(msg); + client.ack("test-batch-disabled", msg.getId()); + latch.countDown(); + }); + + assertTrue(latch.await(10, TimeUnit.SECONDS), "should receive message within 10s"); + handle.cancel(); + + assertEquals(msgId, received.get().getId()); + } + } + + @Test + void enqueueNonexistentQueueThroughBatcher() { + try (FilaClient client = + FilaClient.builder(server.address()).withBatchMode(BatchMode.auto()).build()) { + assertThrows( + QueueNotFoundException.class, + () -> client.enqueue("no-such-queue-batch", Map.of(), "data".getBytes())); + } + } + + @Test + void defaultBatchModeIsAuto() throws Exception { + try (FilaClient client = FilaClient.builder(server.address()).build()) { + String msgId = + client.enqueue("test-batch-mixed", Map.of("default", "true"), "default-batch".getBytes()); + assertNotNull(msgId); + assertFalse(msgId.isEmpty()); + } + } +} diff --git a/src/test/java/dev/faisca/fila/BatchModeTest.java b/src/test/java/dev/faisca/fila/BatchModeTest.java new file mode 100644 index 0000000..2c5a930 --- /dev/null +++ b/src/test/java/dev/faisca/fila/BatchModeTest.java @@ -0,0 +1,57 @@ +package dev.faisca.fila; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +/** Unit tests for BatchMode configuration. */ +class BatchModeTest { + + @Test + void autoDefaultMaxBatchSize() { + BatchMode mode = BatchMode.auto(); + assertEquals(BatchMode.Kind.AUTO, mode.getKind()); + assertEquals(100, mode.getMaxBatchSize()); + } + + @Test + void autoCustomMaxBatchSize() { + BatchMode mode = BatchMode.auto(50); + assertEquals(BatchMode.Kind.AUTO, mode.getKind()); + assertEquals(50, mode.getMaxBatchSize()); + } + + @Test + void autoRejectsZeroMaxBatchSize() { + assertThrows(IllegalArgumentException.class, () -> BatchMode.auto(0)); + } + + @Test + void autoRejectsNegativeMaxBatchSize() { + assertThrows(IllegalArgumentException.class, () -> BatchMode.auto(-1)); + } + + @Test + void lingerConfigValues() { + BatchMode mode = BatchMode.linger(10, 50); + assertEquals(BatchMode.Kind.LINGER, mode.getKind()); + assertEquals(10, mode.getLingerMs()); + assertEquals(50, mode.getMaxBatchSize()); + } + + @Test + void lingerRejectsZeroLingerMs() { + assertThrows(IllegalArgumentException.class, () -> BatchMode.linger(0, 50)); + } + + @Test + void lingerRejectsZeroBatchSize() { + assertThrows(IllegalArgumentException.class, () -> BatchMode.linger(10, 0)); + } + + @Test + void disabledMode() { + BatchMode mode = BatchMode.disabled(); + assertEquals(BatchMode.Kind.DISABLED, mode.getKind()); + } +} diff --git a/src/test/java/dev/faisca/fila/BuilderTest.java b/src/test/java/dev/faisca/fila/BuilderTest.java index b144878..25a2904 100644 --- a/src/test/java/dev/faisca/fila/BuilderTest.java +++ b/src/test/java/dev/faisca/fila/BuilderTest.java @@ -4,33 +4,17 @@ import org.junit.jupiter.api.Test; -/** Unit tests for FilaClient.Builder configuration. */ +/** Unit tests for FilaClient.Builder configuration validation. */ class BuilderTest { @Test - void builderPlaintextDoesNotThrow() { - // Plaintext builder should create a client without error - FilaClient client = FilaClient.builder("localhost:5555").build(); - assertNotNull(client); - client.close(); - } - - @Test - void builderWithApiKeyDoesNotThrow() { - // API key without TLS should work (for backward compat / dev mode) - FilaClient client = FilaClient.builder("localhost:5555").withApiKey("test-key").build(); - assertNotNull(client); - client.close(); - } - - @Test - void builderWithInvalidCaCertThrows() { - // Invalid PEM bytes should throw FilaException + void builderClientCertWithoutTlsThrows() { + // Client cert without TLS enabled should fail fast assertThrows( FilaException.class, () -> FilaClient.builder("localhost:5555") - .withTlsCaCert("not-a-valid-cert".getBytes()) + .withTlsClientCert("cert".getBytes(), "key".getBytes()) .build()); } @@ -40,39 +24,12 @@ void builderChainingReturnsBuilder() { FilaClient.Builder builder = FilaClient.builder("localhost:5555") .withApiKey("key") + .withBatchMode(BatchMode.auto()) .withTlsCaCert("cert".getBytes()) .withTlsClientCert("cert".getBytes(), "key".getBytes()); assertNotNull(builder); } - @Test - void builderClientCertWithoutTlsThrows() { - // Client cert without TLS enabled should fail fast - assertThrows( - FilaException.class, - () -> - FilaClient.builder("localhost:5555") - .withTlsClientCert("cert".getBytes(), "key".getBytes()) - .build()); - } - - @Test - void builderWithTlsSystemTrustDoesNotThrow() { - // withTls() using system trust store should create a client without error - FilaClient client = FilaClient.builder("localhost:5555").withTls().build(); - assertNotNull(client); - client.close(); - } - - @Test - void builderWithTlsAndApiKeyDoesNotThrow() { - // withTls() combined with API key should work - FilaClient client = - FilaClient.builder("localhost:5555").withTls().withApiKey("test-key").build(); - assertNotNull(client); - client.close(); - } - @Test void builderChainingWithTlsReturnsBuilder() { // Verify fluent API for withTls() returns the builder for chaining @@ -83,4 +40,15 @@ void builderChainingWithTlsReturnsBuilder() { .withTlsClientCert("cert".getBytes(), "key".getBytes()); assertNotNull(builder); } + + @Test + void builderWithInvalidCaCertThrows() { + // Invalid PEM bytes should throw FilaException during TLS setup + assertThrows( + FilaException.class, + () -> + FilaClient.builder("localhost:5555") + .withTlsCaCert("not-a-valid-cert".getBytes()) + .build()); + } } diff --git a/src/test/java/dev/faisca/fila/EnqueueResultTest.java b/src/test/java/dev/faisca/fila/EnqueueResultTest.java new file mode 100644 index 0000000..d22a317 --- /dev/null +++ b/src/test/java/dev/faisca/fila/EnqueueResultTest.java @@ -0,0 +1,35 @@ +package dev.faisca.fila; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +/** Unit tests for EnqueueResult. */ +class EnqueueResultTest { + + @Test + void successResult() { + EnqueueResult result = EnqueueResult.success("msg-123"); + assertTrue(result.isSuccess()); + assertEquals("msg-123", result.getMessageId()); + } + + @Test + void successGetErrorThrows() { + EnqueueResult result = EnqueueResult.success("msg-123"); + assertThrows(IllegalStateException.class, result::getError); + } + + @Test + void errorResult() { + EnqueueResult result = EnqueueResult.error("queue not found"); + assertFalse(result.isSuccess()); + assertEquals("queue not found", result.getError()); + } + + @Test + void errorGetMessageIdThrows() { + EnqueueResult result = EnqueueResult.error("queue not found"); + assertThrows(IllegalStateException.class, result::getMessageId); + } +} diff --git a/src/test/java/dev/faisca/fila/TestServer.java b/src/test/java/dev/faisca/fila/TestServer.java index 9cfb14d..b4d8940 100644 --- a/src/test/java/dev/faisca/fila/TestServer.java +++ b/src/test/java/dev/faisca/fila/TestServer.java @@ -1,27 +1,23 @@ package dev.faisca.fila; -import fila.v1.Admin; -import fila.v1.FilaAdminGrpc; -import io.grpc.ChannelCredentials; -import io.grpc.Grpc; -import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; -import io.grpc.TlsChannelCredentials; -import java.io.ByteArrayInputStream; +import dev.faisca.fila.fibp.Codec; +import dev.faisca.fila.fibp.Connection; +import dev.faisca.fila.fibp.Opcodes; +import dev.faisca.fila.fibp.Primitives; import java.io.IOException; import java.net.ServerSocket; import java.nio.file.Files; import java.nio.file.Path; import java.util.Comparator; import java.util.concurrent.TimeUnit; +import javax.net.ssl.SSLContext; /** Manages a fila-server subprocess for integration tests. */ final class TestServer { private final Process process; private final Path dataDir; private final String address; - private final ManagedChannel adminChannel; - private final FilaAdminGrpc.FilaAdminBlockingStub adminStub; + private final Connection adminConn; private final boolean tlsEnabled; private final byte[] caCertPem; private final byte[] clientCertPem; @@ -32,7 +28,7 @@ private TestServer( Process process, Path dataDir, String address, - ManagedChannel adminChannel, + Connection adminConn, boolean tlsEnabled, byte[] caCertPem, byte[] clientCertPem, @@ -41,8 +37,7 @@ private TestServer( this.process = process; this.dataDir = dataDir; this.address = address; - this.adminChannel = adminChannel; - this.adminStub = FilaAdminGrpc.newBlockingStub(adminChannel); + this.adminConn = adminConn; this.tlsEnabled = tlsEnabled; this.caCertPem = caCertPem; this.clientCertPem = clientCertPem; @@ -50,55 +45,64 @@ private TestServer( this.apiKey = apiKey; } - /** Returns the address of the running server. */ String address() { return address; } - /** Returns true if TLS is enabled on this server. */ boolean isTlsEnabled() { return tlsEnabled; } - /** Returns the CA certificate PEM bytes. Only valid when TLS is enabled. */ byte[] caCertPem() { return caCertPem; } - /** Returns the client certificate PEM bytes. Only valid when TLS is enabled. */ byte[] clientCertPem() { return clientCertPem; } - /** Returns the client private key PEM bytes. Only valid when TLS is enabled. */ byte[] clientKeyPem() { return clientKeyPem; } - /** Returns the bootstrap API key. Only valid when auth is enabled. */ String apiKey() { return apiKey; } - /** Creates a queue on the test server (plaintext mode). */ + /** Creates a queue on the test server via FIBP. */ void createQueue(String name) { - adminStub.createQueue(Admin.CreateQueueRequest.newBuilder().setName(name).build()); + int requestId = adminConn.nextRequestId(); + byte[] frame = Codec.encodeCreateQueue(requestId, name, null, null, 0); + try { + Connection.Frame response = adminConn.sendAndReceive(frame, requestId, 10_000); + if (response.header().opcode() == Opcodes.ERROR) { + Primitives.Reader r = new Primitives.Reader(response.body()); + int code = r.readU8(); + String msg = r.readString(); + throw new RuntimeException("createQueue failed: code=" + code + " msg=" + msg); + } + if (response.header().opcode() == Opcodes.CREATE_QUEUE_RESULT) { + Primitives.Reader r = new Primitives.Reader(response.body()); + int code = r.readU8(); + if (code != Opcodes.ERR_OK && code != Opcodes.ERR_QUEUE_ALREADY_EXISTS) { + throw new RuntimeException("createQueue failed: code=" + code); + } + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("createQueue failed", e); + } catch (IOException e) { + throw new RuntimeException("createQueue failed", e); + } } - /** Creates a queue using an authenticated admin stub (TLS + API key mode). */ + /** Creates a queue using an authenticated admin connection (TLS + API key mode). */ void createQueueWithApiKey(String name) { - // The admin channel was already created with TLS + API key interceptor - adminStub.createQueue(Admin.CreateQueueRequest.newBuilder().setName(name).build()); + createQueue(name); } - /** Stops the server and cleans up temporary files. */ void stop() { - adminChannel.shutdown(); - try { - adminChannel.awaitTermination(2, TimeUnit.SECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } + adminConn.close(); process.destroyForcibly(); try { process.waitFor(5, TimeUnit.SECONDS); @@ -108,14 +112,6 @@ void stop() { deleteDirectory(dataDir); } - /** - * Returns true if the fila-server binary is available at a known local path. - * - *

Note: This intentionally does NOT check PATH. The TLS integration tests require a local dev - * build to ensure cert generation and server config are compatible. In CI, the plaintext - * integration tests run via {@link FilaClientTest} using the downloaded binary; the TLS tests are - * skipped until the CI pipeline is configured to provision TLS test infrastructure. - */ static boolean isBinaryAvailable() { try { String path = findBinary(); @@ -140,14 +136,13 @@ static TestServer start() throws IOException, InterruptedException { pb.environment().put("FILA_DATA_DIR", dataDir.resolve("db").toString()); Process process = pb.start(); - if (!waitForPort(port, 10_000)) { + Connection adminConn = waitForHandshake("127.0.0.1", port, null, null, 10_000); + if (adminConn == null) { process.destroyForcibly(); deleteDirectory(dataDir); throw new IOException("fila-server failed to start within 10s on " + address); } - - ManagedChannel adminChannel = ManagedChannelBuilder.forTarget(address).usePlaintext().build(); - return new TestServer(process, dataDir, address, adminChannel, false, null, null, null, null); + return new TestServer(process, dataDir, address, adminConn, false, null, null, null, null); } /** Starts a fila-server with TLS and API key auth on a random port. */ @@ -157,14 +152,12 @@ static TestServer startWithTls() throws IOException, InterruptedException { Path dataDir = Files.createTempDirectory("fila-test-tls-"); - // Generate self-signed CA, server cert, and client cert using openssl generateCerts(dataDir); byte[] caCert = Files.readAllBytes(dataDir.resolve("ca.pem")); byte[] clientCert = Files.readAllBytes(dataDir.resolve("client.pem")); byte[] clientKey = Files.readAllBytes(dataDir.resolve("client-key.pem")); - // Bootstrap API key for auth String bootstrapKey = "test-bootstrap-key-" + System.currentTimeMillis(); Path configFile = dataDir.resolve("fila.toml"); @@ -175,15 +168,15 @@ static TestServer startWithTls() throws IOException, InterruptedException { + "\"\n" + "\n" + "[tls]\n" - + "ca_cert = \"" - + dataDir.resolve("ca.pem") - + "\"\n" - + "server_cert = \"" + + "cert_file = \"" + dataDir.resolve("server.pem") + "\"\n" - + "server_key = \"" + + "key_file = \"" + dataDir.resolve("server-key.pem") + "\"\n" + + "ca_file = \"" + + dataDir.resolve("ca.pem") + + "\"\n" + "\n" + "[auth]\n" + "bootstrap_apikey = \"" @@ -197,30 +190,19 @@ static TestServer startWithTls() throws IOException, InterruptedException { pb.environment().put("FILA_DATA_DIR", dataDir.resolve("db").toString()); Process process = pb.start(); - if (!waitForPort(port, 10_000)) { + SSLContext sslContext = FilaClient.Builder.buildSSLContext(caCert, clientCert, clientKey); + Connection adminConn = waitForHandshake("127.0.0.1", port, bootstrapKey, sslContext, 10_000); + if (adminConn == null) { process.destroyForcibly(); deleteDirectory(dataDir); throw new IOException("fila-server failed to start within 10s on " + address); } - // Create admin channel with TLS + API key - TlsChannelCredentials.Builder tlsBuilder = - TlsChannelCredentials.newBuilder().trustManager(new ByteArrayInputStream(caCert)); - tlsBuilder.keyManager( - new ByteArrayInputStream(clientCert), new ByteArrayInputStream(clientKey)); - ChannelCredentials creds = tlsBuilder.build(); - - ManagedChannel adminChannel = - Grpc.newChannelBuilderForAddress("127.0.0.1", port, creds) - .intercept(new ApiKeyInterceptor(bootstrapKey)) - .build(); - return new TestServer( - process, dataDir, address, adminChannel, true, caCert, clientCert, clientKey, bootstrapKey); + process, dataDir, address, adminConn, true, caCert, clientCert, clientKey, bootstrapKey); } private static void generateCerts(Path dir) throws IOException, InterruptedException { - // Generate CA key and cert exec( dir, "openssl", @@ -240,7 +222,6 @@ private static void generateCerts(Path dir) throws IOException, InterruptedExcep "-subj", "/CN=fila-test-ca"); - // Generate server key and CSR exec( dir, "openssl", @@ -257,11 +238,9 @@ private static void generateCerts(Path dir) throws IOException, InterruptedExcep "-subj", "/CN=127.0.0.1"); - // Write SAN extension file Files.writeString( dir.resolve("server-ext.cnf"), "subjectAltName=IP:127.0.0.1\nbasicConstraints=CA:FALSE\n"); - // Sign server cert with CA exec( dir, "openssl", @@ -281,7 +260,6 @@ private static void generateCerts(Path dir) throws IOException, InterruptedExcep "-extfile", "server-ext.cnf"); - // Generate client key and CSR exec( dir, "openssl", @@ -291,14 +269,25 @@ private static void generateCerts(Path dir) throws IOException, InterruptedExcep "-pkeyopt", "ec_paramgen_curve:prime256v1", "-keyout", - "client-key.pem", + "client-key-ec.pem", "-out", "client.csr", "-nodes", "-subj", "/CN=fila-test-client"); - // Sign client cert with CA + // Convert EC key to PKCS#8 format for Java compatibility + exec( + dir, + "openssl", + "pkcs8", + "-topk8", + "-nocrypt", + "-in", + "client-key-ec.pem", + "-out", + "client-key.pem"); + exec( dir, "openssl", @@ -354,16 +343,23 @@ private static int findFreePort() throws IOException { } } - private static boolean waitForPort(int port, long timeoutMs) throws InterruptedException { + /** + * Wait for the server to accept a FIBP handshake, retrying up to timeoutMs. + * + * @return a connected Connection, or null if timed out + */ + private static Connection waitForHandshake( + String host, int port, String apiKey, SSLContext sslContext, long timeoutMs) + throws InterruptedException { long deadline = System.currentTimeMillis() + timeoutMs; while (System.currentTimeMillis() < deadline) { - try (var sock = new java.net.Socket("127.0.0.1", port)) { - return true; + try { + return Connection.connect(host, port, apiKey, sslContext); } catch (IOException e) { - Thread.sleep(100); + Thread.sleep(200); } } - return false; + return null; } private static void deleteDirectory(Path dir) { diff --git a/src/test/java/dev/faisca/fila/TlsAuthClientTest.java b/src/test/java/dev/faisca/fila/TlsAuthClientTest.java index ac3560b..93cb4dc 100644 --- a/src/test/java/dev/faisca/fila/TlsAuthClientTest.java +++ b/src/test/java/dev/faisca/fila/TlsAuthClientTest.java @@ -71,42 +71,35 @@ void connectWithTlsAndApiKey() throws Exception { } } - @Test - void connectWithTlsOnly() throws Exception { - // TLS without API key — validates TLS transport works independently of auth - try (FilaClient client = - FilaClient.builder(server.address()) - .withTlsCaCert(server.caCertPem()) - .withTlsClientCert(server.clientCertPem(), server.clientKeyPem()) - .build()) { - // Without an API key on an auth-enabled server, the enqueue should be rejected. - // This validates TLS transport is working (connection succeeds) but auth is enforced. - RpcException ex = - assertThrows( - RpcException.class, - () -> client.enqueue("test-tls-auth", Map.of(), "tls-only".getBytes())); - assertEquals( - io.grpc.Status.Code.UNAUTHENTICATED, - ex.getCode(), - "should reject with UNAUTHENTICATED when no API key is provided"); + private static String fullExceptionMessage(Throwable t) { + StringBuilder sb = new StringBuilder(); + for (Throwable cur = t; cur != null; cur = cur.getCause()) { + if (cur.getMessage() != null) { + sb.append(cur.getMessage()).append(" | "); + } } + return sb.toString(); } @Test void rejectWithoutApiKey() { - try (FilaClient client = - FilaClient.builder(server.address()) - .withTlsCaCert(server.caCertPem()) - .withTlsClientCert(server.clientCertPem(), server.clientKeyPem()) - .build()) { - RpcException ex = - assertThrows( - RpcException.class, - () -> client.enqueue("test-tls-auth", Map.of(), "no-key".getBytes())); - assertEquals( - io.grpc.Status.Code.UNAUTHENTICATED, - ex.getCode(), - "should reject with UNAUTHENTICATED when no API key is provided"); - } + // Without an API key on an auth-enabled server, the FIBP handshake is rejected. + // The client.build() should throw because the connection is refused during handshake. + FilaException ex = + assertThrows( + FilaException.class, + () -> + FilaClient.builder(server.address()) + .withTlsCaCert(server.caCertPem()) + .withTlsClientCert(server.clientCertPem(), server.clientKeyPem()) + .build()); + // The exception may wrap the root cause; check full chain for auth-related content. + String fullMessage = fullExceptionMessage(ex); + assertTrue( + fullMessage.contains("handshake rejected") + || fullMessage.contains("unauthorized") + || fullMessage.contains("auth") + || fullMessage.contains("rejected"), + "expected auth-related error in exception chain, got: " + fullMessage); } } diff --git a/src/test/java/dev/faisca/fila/fibp/PrimitivesTest.java b/src/test/java/dev/faisca/fila/fibp/PrimitivesTest.java new file mode 100644 index 0000000..c6de01b --- /dev/null +++ b/src/test/java/dev/faisca/fila/fibp/PrimitivesTest.java @@ -0,0 +1,144 @@ +package dev.faisca.fila.fibp; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.LinkedHashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; + +/** Unit tests for FIBP primitives encoding/decoding. */ +class PrimitivesTest { + + @Test + void u8RoundTrip() { + Primitives.Writer w = new Primitives.Writer(); + w.writeU8(0); + w.writeU8(127); + w.writeU8(255); + + Primitives.Reader r = new Primitives.Reader(w.toByteArray()); + assertEquals(0, r.readU8()); + assertEquals(127, r.readU8()); + assertEquals(255, r.readU8()); + } + + @Test + void u16RoundTrip() { + Primitives.Writer w = new Primitives.Writer(); + w.writeU16(0); + w.writeU16(1000); + w.writeU16(65535); + + Primitives.Reader r = new Primitives.Reader(w.toByteArray()); + assertEquals(0, r.readU16()); + assertEquals(1000, r.readU16()); + assertEquals(65535, r.readU16()); + } + + @Test + void u32RoundTrip() { + Primitives.Writer w = new Primitives.Writer(); + w.writeU32(0); + w.writeU32(100_000); + w.writeU32(0xFFFFFFFFL); + + Primitives.Reader r = new Primitives.Reader(w.toByteArray()); + assertEquals(0, r.readU32()); + assertEquals(100_000, r.readU32()); + assertEquals(0xFFFFFFFFL, r.readU32()); + } + + @Test + void u64RoundTrip() { + Primitives.Writer w = new Primitives.Writer(); + w.writeU64(0); + w.writeU64(Long.MAX_VALUE); + + Primitives.Reader r = new Primitives.Reader(w.toByteArray()); + assertEquals(0, r.readU64()); + assertEquals(Long.MAX_VALUE, r.readU64()); + } + + @Test + void boolRoundTrip() { + Primitives.Writer w = new Primitives.Writer(); + w.writeBool(true); + w.writeBool(false); + + Primitives.Reader r = new Primitives.Reader(w.toByteArray()); + assertTrue(r.readBool()); + assertFalse(r.readBool()); + } + + @Test + void stringRoundTrip() { + Primitives.Writer w = new Primitives.Writer(); + w.writeString("hello"); + w.writeString(""); + w.writeString("unicode: \u00e9\u00e8\u00ea"); + + Primitives.Reader r = new Primitives.Reader(w.toByteArray()); + assertEquals("hello", r.readString()); + assertEquals("", r.readString()); + assertEquals("unicode: \u00e9\u00e8\u00ea", r.readString()); + } + + @Test + void bytesRoundTrip() { + Primitives.Writer w = new Primitives.Writer(); + byte[] data = {1, 2, 3, 4, 5}; + w.writeBytes(data); + w.writeBytes(new byte[0]); + + Primitives.Reader r = new Primitives.Reader(w.toByteArray()); + assertArrayEquals(data, r.readBytes()); + assertArrayEquals(new byte[0], r.readBytes()); + } + + @Test + void stringMapRoundTrip() { + Primitives.Writer w = new Primitives.Writer(); + Map map = new LinkedHashMap<>(); + map.put("key1", "val1"); + map.put("key2", "val2"); + w.writeStringMap(map); + + Primitives.Reader r = new Primitives.Reader(w.toByteArray()); + Map result = r.readStringMap(); + assertEquals(2, result.size()); + assertEquals("val1", result.get("key1")); + assertEquals("val2", result.get("key2")); + } + + @Test + void optionalStringRoundTrip() { + Primitives.Writer w = new Primitives.Writer(); + w.writeOptionalString("present"); + w.writeOptionalString(null); + + Primitives.Reader r = new Primitives.Reader(w.toByteArray()); + assertEquals("present", r.readOptionalString()); + assertNull(r.readOptionalString()); + } + + @Test + void f64RoundTrip() { + Primitives.Writer w = new Primitives.Writer(); + w.writeF64(3.14); + w.writeF64(0.0); + + Primitives.Reader r = new Primitives.Reader(w.toByteArray()); + assertEquals(3.14, r.readF64(), 0.001); + assertEquals(0.0, r.readF64(), 0.001); + } + + @Test + void stringListRoundTrip() { + Primitives.Writer w = new Primitives.Writer(); + w.writeStringList(new String[] {"a", "b", "c"}); + + Primitives.Reader r = new Primitives.Reader(w.toByteArray()); + String[] list = r.readStringList(); + assertArrayEquals(new String[] {"a", "b", "c"}, list); + } +}