diff --git a/src/main/java/dev/faisca/fila/FilaClient.java b/src/main/java/dev/faisca/fila/FilaClient.java index 8551fc0..188f92c 100644 --- a/src/main/java/dev/faisca/fila/FilaClient.java +++ b/src/main/java/dev/faisca/fila/FilaClient.java @@ -8,6 +8,7 @@ import io.grpc.Grpc; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; +import io.grpc.Metadata; import io.grpc.StatusRuntimeException; import io.grpc.TlsChannelCredentials; import java.io.ByteArrayInputStream; @@ -35,12 +36,28 @@ * } */ public final class FilaClient implements AutoCloseable { + private static final Metadata.Key LEADER_ADDR_KEY = + Metadata.Key.of("x-fila-leader-addr", Metadata.ASCII_STRING_MARSHALLER); + private final ManagedChannel channel; private final FilaServiceGrpc.FilaServiceBlockingStub blockingStub; - - private FilaClient(ManagedChannel channel) { + private final byte[] caCertPem; + private final byte[] clientCertPem; + private final byte[] clientKeyPem; + private final String apiKey; + + private FilaClient( + ManagedChannel channel, + byte[] caCertPem, + byte[] clientCertPem, + byte[] clientKeyPem, + String apiKey) { this.channel = channel; this.blockingStub = FilaServiceGrpc.newBlockingStub(channel); + this.caCertPem = caCertPem; + this.clientCertPem = clientCertPem; + this.clientKeyPem = clientKeyPem; + this.apiKey = apiKey; } /** Returns a new builder for configuring a {@link FilaClient}. */ @@ -96,15 +113,15 @@ public ConsumerHandle consume(String queue, Consumer handler) { () -> { try { Iterator stream = blockingStub.consume(req); - while (stream.hasNext()) { - Service.ConsumeResponse resp = stream.next(); - if (!resp.hasMessage() || resp.getMessage().getId().isEmpty()) { - continue; - } - handler.accept(buildConsumeMessage(resp.getMessage())); - } + consumeStream(stream, handler); } catch (StatusRuntimeException e) { - if (e.getStatus().getCode() != io.grpc.Status.Code.CANCELLED) { + if (e.getStatus().getCode() == io.grpc.Status.Code.CANCELLED) { + return; + } + String leaderAddr = extractLeaderAddr(e); + if (leaderAddr != null) { + retryOnLeader(leaderAddr, req, handler); + } else { throw mapConsumeError(e); } } @@ -171,6 +188,103 @@ public void close() { } } + private static void consumeStream( + Iterator stream, Consumer handler) { + while (stream.hasNext()) { + Service.ConsumeResponse resp = stream.next(); + if (!resp.hasMessage() || resp.getMessage().getId().isEmpty()) { + continue; + } + handler.accept(buildConsumeMessage(resp.getMessage())); + } + } + + private static String extractLeaderAddr(StatusRuntimeException e) { + if (e.getStatus().getCode() != io.grpc.Status.Code.UNAVAILABLE) { + return null; + } + Metadata trailers = e.getTrailers(); + if (trailers == null) { + return null; + } + return trailers.get(LEADER_ADDR_KEY); + } + + private static void validateLeaderAddr(String addr) { + if (addr == null || addr.isEmpty()) { + throw new FilaException("invalid leader address: empty"); + } + // Must not contain scheme (e.g. "http://") or path (e.g. "/foo") + if (addr.contains("//") || addr.contains("/")) { + throw new FilaException("invalid leader address: must be host:port, got: " + addr); + } + int colonIdx = addr.lastIndexOf(':'); + if (colonIdx < 0) { + throw new FilaException("invalid leader address: missing port, got: " + addr); + } + String host = addr.substring(0, colonIdx); + String portStr = addr.substring(colonIdx + 1); + if (host.isEmpty()) { + throw new FilaException("invalid leader address: empty host, got: " + addr); + } + int port; + try { + port = Integer.parseInt(portStr); + } catch (NumberFormatException e) { + throw new FilaException("invalid leader address: non-numeric port, got: " + addr); + } + if (port < 1 || port > 65535) { + throw new FilaException("invalid leader address: port out of range, got: " + addr); + } + } + + private void retryOnLeader( + String leaderAddr, Service.ConsumeRequest req, Consumer handler) { + validateLeaderAddr(leaderAddr); + ManagedChannel leaderChannel = buildChannel(leaderAddr); + try { + FilaServiceGrpc.FilaServiceBlockingStub leaderStub = + FilaServiceGrpc.newBlockingStub(leaderChannel); + Iterator stream = leaderStub.consume(req); + consumeStream(stream, handler); + } catch (StatusRuntimeException e) { + if (e.getStatus().getCode() != io.grpc.Status.Code.CANCELLED) { + throw mapConsumeError(e); + } + } finally { + leaderChannel.shutdown(); + } + } + + private ManagedChannel buildChannel(String address) { + if (caCertPem != null) { + try { + TlsChannelCredentials.Builder tlsBuilder = + TlsChannelCredentials.newBuilder().trustManager(new ByteArrayInputStream(caCertPem)); + if (clientCertPem != null && clientKeyPem != null) { + tlsBuilder.keyManager( + new ByteArrayInputStream(clientCertPem), new ByteArrayInputStream(clientKeyPem)); + } + ChannelCredentials creds = tlsBuilder.build(); + var channelBuilder = + Grpc.newChannelBuilderForAddress( + Builder.parseHost(address), Builder.parsePort(address), creds); + if (apiKey != null) { + channelBuilder.intercept(new ApiKeyInterceptor(apiKey)); + } + return channelBuilder.build(); + } catch (IOException e) { + throw new FilaException("failed to configure TLS for leader redirect", e); + } + } else { + var channelBuilder = ManagedChannelBuilder.forTarget(address).usePlaintext(); + if (apiKey != null) { + channelBuilder.intercept(new ApiKeyInterceptor(apiKey)); + } + return channelBuilder.build(); + } + } + private static ConsumeMessage buildConsumeMessage(Messages.Message msg) { Messages.MessageMetadata meta = msg.getMetadata(); return new ConsumeMessage( @@ -333,10 +447,10 @@ public FilaClient build() { channel = channelBuilder.build(); } - return new FilaClient(channel); + return new FilaClient(channel, caCertPem, clientCertPem, clientKeyPem, apiKey); } - private static String parseHost(String address) { + static String parseHost(String address) { // Handle IPv6 bracket notation: [::1]:5555 if (address.startsWith("[")) { int closeBracket = address.indexOf(']'); @@ -352,7 +466,7 @@ private static String parseHost(String address) { return address.substring(0, colonIdx); } - private static int parsePort(String address) { + static int parsePort(String address) { // Handle IPv6 bracket notation: [::1]:5555 if (address.startsWith("[")) { int closeBracket = address.indexOf(']');