Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 127 additions & 13 deletions src/main/java/dev/faisca/fila/FilaClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import io.grpc.Grpc;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata;
import io.grpc.StatusRuntimeException;
import io.grpc.TlsChannelCredentials;
import java.io.ByteArrayInputStream;
Expand Down Expand Up @@ -35,12 +36,28 @@
* }</pre>
*/
public final class FilaClient implements AutoCloseable {
private static final Metadata.Key<String> LEADER_ADDR_KEY =
Metadata.Key.of("x-fila-leader-addr", Metadata.ASCII_STRING_MARSHALLER);

private final ManagedChannel channel;
private final FilaServiceGrpc.FilaServiceBlockingStub blockingStub;

private FilaClient(ManagedChannel channel) {
private final byte[] caCertPem;
private final byte[] clientCertPem;
private final byte[] clientKeyPem;
private final String apiKey;

private FilaClient(
ManagedChannel channel,
byte[] caCertPem,
byte[] clientCertPem,
byte[] clientKeyPem,
String apiKey) {
this.channel = channel;
this.blockingStub = FilaServiceGrpc.newBlockingStub(channel);
this.caCertPem = caCertPem;
this.clientCertPem = clientCertPem;
this.clientKeyPem = clientKeyPem;
this.apiKey = apiKey;
}

/** Returns a new builder for configuring a {@link FilaClient}. */
Expand Down Expand Up @@ -96,15 +113,15 @@ public ConsumerHandle consume(String queue, Consumer<ConsumeMessage> handler) {
() -> {
try {
Iterator<Service.ConsumeResponse> stream = blockingStub.consume(req);
while (stream.hasNext()) {
Service.ConsumeResponse resp = stream.next();
if (!resp.hasMessage() || resp.getMessage().getId().isEmpty()) {
continue;
}
handler.accept(buildConsumeMessage(resp.getMessage()));
}
consumeStream(stream, handler);
} catch (StatusRuntimeException e) {
if (e.getStatus().getCode() != io.grpc.Status.Code.CANCELLED) {
if (e.getStatus().getCode() == io.grpc.Status.Code.CANCELLED) {
return;
}
String leaderAddr = extractLeaderAddr(e);
if (leaderAddr != null) {
Comment thread
cubic-dev-ai[bot] marked this conversation as resolved.
retryOnLeader(leaderAddr, req, handler);
} else {
throw mapConsumeError(e);
}
}
Expand Down Expand Up @@ -171,6 +188,103 @@ public void close() {
}
}

private static void consumeStream(
Iterator<Service.ConsumeResponse> stream, Consumer<ConsumeMessage> handler) {
while (stream.hasNext()) {
Service.ConsumeResponse resp = stream.next();
if (!resp.hasMessage() || resp.getMessage().getId().isEmpty()) {
continue;
}
handler.accept(buildConsumeMessage(resp.getMessage()));
}
}

private static String extractLeaderAddr(StatusRuntimeException e) {
if (e.getStatus().getCode() != io.grpc.Status.Code.UNAVAILABLE) {
return null;
}
Metadata trailers = e.getTrailers();
if (trailers == null) {
return null;
}
return trailers.get(LEADER_ADDR_KEY);
}

private static void validateLeaderAddr(String addr) {
if (addr == null || addr.isEmpty()) {
throw new FilaException("invalid leader address: empty");
}
// Must not contain scheme (e.g. "http://") or path (e.g. "/foo")
if (addr.contains("//") || addr.contains("/")) {
throw new FilaException("invalid leader address: must be host:port, got: " + addr);
}
int colonIdx = addr.lastIndexOf(':');
if (colonIdx < 0) {
throw new FilaException("invalid leader address: missing port, got: " + addr);
}
String host = addr.substring(0, colonIdx);
String portStr = addr.substring(colonIdx + 1);
if (host.isEmpty()) {
throw new FilaException("invalid leader address: empty host, got: " + addr);
}
int port;
try {
port = Integer.parseInt(portStr);
} catch (NumberFormatException e) {
throw new FilaException("invalid leader address: non-numeric port, got: " + addr);
}
if (port < 1 || port > 65535) {
throw new FilaException("invalid leader address: port out of range, got: " + addr);
}
}

private void retryOnLeader(
String leaderAddr, Service.ConsumeRequest req, Consumer<ConsumeMessage> handler) {
validateLeaderAddr(leaderAddr);
ManagedChannel leaderChannel = buildChannel(leaderAddr);
try {
FilaServiceGrpc.FilaServiceBlockingStub leaderStub =
FilaServiceGrpc.newBlockingStub(leaderChannel);
Iterator<Service.ConsumeResponse> stream = leaderStub.consume(req);
consumeStream(stream, handler);
} catch (StatusRuntimeException e) {
if (e.getStatus().getCode() != io.grpc.Status.Code.CANCELLED) {
throw mapConsumeError(e);
}
} finally {
leaderChannel.shutdown();
}
}

private ManagedChannel buildChannel(String address) {
if (caCertPem != null) {
try {
TlsChannelCredentials.Builder tlsBuilder =
TlsChannelCredentials.newBuilder().trustManager(new ByteArrayInputStream(caCertPem));
if (clientCertPem != null && clientKeyPem != null) {
tlsBuilder.keyManager(
new ByteArrayInputStream(clientCertPem), new ByteArrayInputStream(clientKeyPem));
}
ChannelCredentials creds = tlsBuilder.build();
var channelBuilder =
Grpc.newChannelBuilderForAddress(
Builder.parseHost(address), Builder.parsePort(address), creds);
if (apiKey != null) {
channelBuilder.intercept(new ApiKeyInterceptor(apiKey));
}
return channelBuilder.build();
} catch (IOException e) {
throw new FilaException("failed to configure TLS for leader redirect", e);
}
} else {
var channelBuilder = ManagedChannelBuilder.forTarget(address).usePlaintext();
if (apiKey != null) {
channelBuilder.intercept(new ApiKeyInterceptor(apiKey));
}
return channelBuilder.build();
}
}

private static ConsumeMessage buildConsumeMessage(Messages.Message msg) {
Messages.MessageMetadata meta = msg.getMetadata();
return new ConsumeMessage(
Expand Down Expand Up @@ -333,10 +447,10 @@ public FilaClient build() {
channel = channelBuilder.build();
}

return new FilaClient(channel);
return new FilaClient(channel, caCertPem, clientCertPem, clientKeyPem, apiKey);
}

private static String parseHost(String address) {
static String parseHost(String address) {
// Handle IPv6 bracket notation: [::1]:5555
if (address.startsWith("[")) {
int closeBracket = address.indexOf(']');
Expand All @@ -352,7 +466,7 @@ private static String parseHost(String address) {
return address.substring(0, colonIdx);
}

private static int parsePort(String address) {
static int parsePort(String address) {
// Handle IPv6 bracket notation: [::1]:5555
if (address.startsWith("[")) {
int closeBracket = address.indexOf(']');
Expand Down
Loading