diff --git a/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/ConcurrencyTest.kt b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/ConcurrencyTest.kt new file mode 100644 index 00000000..feb78c9c --- /dev/null +++ b/integration-test/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/integration/ConcurrencyTest.kt @@ -0,0 +1,158 @@ +package io.modelcontextprotocol.kotlin.sdk.integration + +import io.modelcontextprotocol.kotlin.sdk.client.Client +import io.modelcontextprotocol.kotlin.sdk.client.ClientOptions +import io.modelcontextprotocol.kotlin.sdk.server.Server +import io.modelcontextprotocol.kotlin.sdk.server.ServerOptions +import io.modelcontextprotocol.kotlin.sdk.server.ServerSession +import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport +import io.modelcontextprotocol.kotlin.sdk.shared.TransportSendOptions +import io.modelcontextprotocol.kotlin.sdk.types.CallToolResult +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCMessage +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.TextContent +import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.SupervisorJob +import kotlinx.coroutines.cancel +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.joinAll +import kotlinx.coroutines.launch +import kotlinx.coroutines.runBlocking +import kotlinx.coroutines.withTimeout +import kotlin.concurrent.atomics.ExperimentalAtomicApi +import kotlin.test.Test +import kotlin.test.assertNotNull +import kotlin.test.assertTrue +import kotlin.time.Duration.Companion.seconds + +/** + * Tests that the Protocol layer handles incoming messages concurrently, + * preventing deadlock when a request handler needs to wait for other messages. + * + * See: https://github.com/modelcontextprotocol/kotlin-sdk/issues/176 + */ +class ConcurrencyTest { + + /** + * A channel-based transport that delivers messages asynchronously via Kotlin Channels, + * simulating real network transports. This is necessary to reproduce the concurrency + * bug — the synchronous InMemoryTransport masks the issue. + */ + @OptIn(ExperimentalAtomicApi::class) + private class ChannelTransport( + private val scope: CoroutineScope, + private val sendChannel: Channel, + private val receiveChannel: Channel, + ) : AbstractTransport() { + override suspend fun start() { + scope.launch { + for (message in receiveChannel) { + _onMessage.invoke(message) + } + } + } + + override suspend fun send(message: JSONRPCMessage, options: TransportSendOptions?) { + sendChannel.send(message) + } + + override suspend fun close() { + sendChannel.close() + receiveChannel.cancel() + invokeOnCloseCallback() + } + + companion object { + fun createLinkedPair(scope: CoroutineScope): Pair { + val clientToServer = Channel(Channel.UNLIMITED) + val serverToClient = Channel(Channel.UNLIMITED) + return Pair( + ChannelTransport(scope, serverToClient, clientToServer), + ChannelTransport(scope, clientToServer, serverToClient), + ) + } + } + } + + /** + * Verifies that concurrent tool calls are handled concurrently, not serially. + * + * Uses deterministic synchronization: the fast tool completes while the slow + * handler is still suspended, proving that handlers run concurrently rather + * than serially. No wall-clock timing thresholds are used. + */ + @OptIn(ExperimentalAtomicApi::class) + @Test + fun `server handles concurrent requests concurrently`() = runBlocking { + val serverOptions = ServerOptions( + capabilities = ServerCapabilities(tools = ServerCapabilities.Tools(null)), + ) + serverOptions.concurrentMessageHandling = true + + val server = Server( + serverInfo = Implementation("test-server", "1.0"), + options = serverOptions, + ) + + // Latch that blocks the slow handler until we signal it to finish. + // This lets us prove the fast handler completed while the slow one + // was still running — impossible under serial dispatch. + val slowHandlerCanFinish = CompletableDeferred() + + server.addTool("slow_tool", "A tool that blocks until signaled") { + slowHandlerCanFinish.await() + CallToolResult(content = listOf(TextContent("slow_tool_done"))) + } + + server.addTool("fast_tool", "A tool that completes immediately") { + CallToolResult(content = listOf(TextContent("fast_tool_done"))) + } + + val client = Client( + clientInfo = Implementation("test-client", "1.0"), + options = ClientOptions(), + ) + + val scope = CoroutineScope(SupervisorJob() + Dispatchers.Default) + val (clientTransport, serverTransport) = ChannelTransport.createLinkedPair(scope) + val serverSessionResult = CompletableDeferred() + + try { + listOf( + launch { client.connect(clientTransport) }, + launch { serverSessionResult.complete(server.createSession(serverTransport)) }, + ).joinAll() + + // Start the slow request (handler blocks on slowHandlerCanFinish) + val slowResult = CompletableDeferred() + launch { + slowResult.complete(client.callTool("slow_tool", mapOf())) + } + + // Start the fast request + val fastResult = CompletableDeferred() + launch { + fastResult.complete(client.callTool("fast_tool", mapOf())) + } + + // The fast request must complete while the slow handler is still suspended. + // Under serial dispatch, both requests would be blocked behind the slow handler, + // so the fast result could never arrive. + val fast = withTimeout(5.seconds) { fastResult.await() } + assertNotNull(fast) + + // Now release the slow handler and verify it completes + slowHandlerCanFinish.complete(Unit) + val slow = withTimeout(5.seconds) { slowResult.await() } + assertNotNull(slow) + Unit + } finally { + clientTransport.close() + serverTransport.close() + scope.cancel() + } + } +} \ No newline at end of file diff --git a/kotlin-sdk-core/Module.md b/kotlin-sdk-core/Module.md index b4ed71b5..2bfd69ef 100644 --- a/kotlin-sdk-core/Module.md +++ b/kotlin-sdk-core/Module.md @@ -16,7 +16,11 @@ designed for Kotlin Multiplatform with explicit API mode enabled. handling. `WebSocketMcpTransport` adds a shared WebSocket implementation for both client and server sides, and `ReadBuffer` handles streaming JSON-RPC framing. - **Protocol engine**: The `Protocol` base class manages request/response correlation, notifications, progress tokens, - and capability assertions. Higher-level modules extend it to become `Client` and `Server`. + and capability assertions. When `concurrentMessageHandling` is enabled on `ProtocolOptions`, incoming requests and + notifications are dispatched concurrently in separate coroutines backed by a `SupervisorJob`, preventing deadlock + when a request handler sends its own request (e.g., `roots/list`) before responding. Defaults to false for backward + compatibility; set to true for transports with independent receive loops (Stdio, WebSocket, Channel). + Higher-level modules extend `Protocol` to become `Client` and `Server`. - **Errors and safety**: Common exception types (`McpException`, parsing errors) plus capability enforcement hooks ensure callers cannot use endpoints the peer does not advertise. diff --git a/kotlin-sdk-core/api/kotlin-sdk-core.api b/kotlin-sdk-core/api/kotlin-sdk-core.api index c843aea0..e6b4f9b6 100644 --- a/kotlin-sdk-core/api/kotlin-sdk-core.api +++ b/kotlin-sdk-core/api/kotlin-sdk-core.api @@ -90,8 +90,10 @@ public final class io/modelcontextprotocol/kotlin/sdk/shared/ProtocolKt { public class io/modelcontextprotocol/kotlin/sdk/shared/ProtocolOptions { public synthetic fun (ZJILkotlin/jvm/internal/DefaultConstructorMarker;)V public synthetic fun (ZJLkotlin/jvm/internal/DefaultConstructorMarker;)V + public final fun getConcurrentMessageHandling ()Z public final fun getEnforceStrictCapabilities ()Z public final fun getTimeout-UwyO8pc ()J + public final fun setConcurrentMessageHandling (Z)V public final fun setEnforceStrictCapabilities (Z)V public final fun setTimeout-LRDsOJo (J)V } diff --git a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt index 8f86c035..e5d9a2c6 100644 --- a/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt +++ b/kotlin-sdk-core/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/shared/Protocol.kt @@ -30,8 +30,13 @@ import kotlinx.atomicfu.update import kotlinx.collections.immutable.PersistentMap import kotlinx.collections.immutable.persistentMapOf import kotlinx.coroutines.CompletableDeferred +import kotlinx.coroutines.CoroutineName +import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Deferred +import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.TimeoutCancellationException +import kotlinx.coroutines.cancel +import kotlinx.coroutines.launch import kotlinx.coroutines.withTimeout import kotlinx.serialization.json.JsonObject import kotlinx.serialization.json.JsonPrimitive @@ -64,7 +69,17 @@ public typealias ProgressCallback = (Progress) -> Unit public open class ProtocolOptions( public var enforceStrictCapabilities: Boolean = false, public var timeout: Duration = DEFAULT_REQUEST_TIMEOUT, -) +) { + /** + * When true, incoming requests and notifications are handled concurrently + * in separate coroutines, allowing the message receive loop to continue + * processing other messages. This prevents deadlock when a handler sends + * its own request to the peer. Defaults to false for backward compatibility; + * set to true for transports with independent receive loops (Stdio, WebSocket, + * Channel) where a blocking handler would otherwise stall message processing. + */ + public var concurrentMessageHandling: Boolean = false +} /** * The default request timeout. @@ -148,6 +163,13 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio public var transport: Transport? = null private set + /** + * Scope for launching concurrent request and notification handlers. + * Created on [connect] and cancelled on [doClose]. + * Using [SupervisorJob] so a failing handler doesn't cancel sibling handlers. + */ + private var handlerScope: CoroutineScope? = null + private val _requestHandlers: AtomicRef RequestResult?>> = atomic(persistentMapOf()) @@ -227,9 +249,20 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio * Attaches to the given transport, starts it, and starts listening for messages. * * The Protocol object assumes ownership of the Transport, replacing any callbacks that have already been set, and expects that it is the only user of the Transport instance going forward. + * + * When [ProtocolOptions.concurrentMessageHandling] is true, incoming requests and notifications + * are handled concurrently in separate coroutines, allowing the message receive loop to continue processing + * other messages (including responses to outgoing requests). This prevents deadlock when a request + * handler sends its own request to the peer and awaits the response. Defaults to false for backward + * compatibility; set to true for transports with independent receive loops (Stdio, WebSocket, + * Channel) where a blocking handler would otherwise stall message processing. */ public open suspend fun connect(transport: Transport) { this.transport = transport + if (options?.concurrentMessageHandling == true) { + handlerScope = CoroutineScope(SupervisorJob() + kotlinx.coroutines.Dispatchers.Default) + } + transport.onClose { doClose() } @@ -241,9 +274,35 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio transport.onMessage { message -> when (message) { is JSONRPCResponse -> onResponse(message, null) - is JSONRPCRequest -> onRequest(message) - is JSONRPCNotification -> onNotification(message) + is JSONRPCError -> onResponse(null, message) + + is JSONRPCRequest -> { + val scope = handlerScope + if (scope != null) { + // Concurrent handling: launch in a separate coroutine so the message + // receive loop is not blocked while the handler runs. + scope.launch(CoroutineName("MCP-Request-${message.id}")) { + onRequest(message) + } + } else { + // Synchronous handling: for transports that need responses sent within + // the same context (e.g., HTTP transports responding directly). + onRequest(message) + } + } + + is JSONRPCNotification -> { + val scope = handlerScope + if (scope != null) { + scope.launch(CoroutineName("MCP-Notification-${message.method}")) { + onNotification(message) + } + } else { + onNotification(message) + } + } + is JSONRPCEmptyMessage -> Unit } } @@ -253,6 +312,9 @@ public abstract class Protocol(@PublishedApi internal val options: ProtocolOptio } private fun doClose() { + handlerScope?.cancel() + handlerScope = null + val handlersToNotify = _responseHandlers.value.values.toList() _responseHandlers.getAndSet(persistentMapOf()) _progressHandlers.getAndSet(persistentMapOf())