diff --git a/kotlin-sdk-server/api/kotlin-sdk-server.api b/kotlin-sdk-server/api/kotlin-sdk-server.api index f400e80d..cdc336e2 100644 --- a/kotlin-sdk-server/api/kotlin-sdk-server.api +++ b/kotlin-sdk-server/api/kotlin-sdk-server.api @@ -59,8 +59,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt { public static synthetic fun mcp$default (Lio/ktor/server/routing/Route;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V public static final fun mcpStatelessStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V public static synthetic fun mcpStatelessStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V - public static final fun mcpStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V - public static synthetic fun mcpStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V + public static final fun mcpStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)V + public static synthetic fun mcpStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V } public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt : io/modelcontextprotocol/kotlin/sdk/server/Feature { @@ -245,8 +245,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServe } public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport$Configuration { - public synthetic fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/time/Duration;JILkotlin/jvm/internal/DefaultConstructorMarker;)V - public synthetic fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/time/Duration;JLkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/time/Duration;JLkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V + public synthetic fun (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/time/Duration;JLkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun getAllowedHosts ()Ljava/util/List; public final fun getAllowedOrigins ()Ljava/util/List; public final fun getEnableDnsRebindingProtection ()Z @@ -254,6 +254,7 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServe public final fun getEventStore ()Lio/modelcontextprotocol/kotlin/sdk/server/EventStore; public final fun getMaxRequestBodySize ()J public final fun getRetryInterval-FghU774 ()Lkotlin/time/Duration; + public final fun getSseHeartbeatConfig ()Lkotlin/jvm/functions/Function1; } public final class io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensionsKt { diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt index d6191dff..560f95c9 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/KtorServer.kt @@ -20,8 +20,10 @@ import io.ktor.server.routing.get import io.ktor.server.routing.post import io.ktor.server.routing.route import io.ktor.server.routing.routing +import io.ktor.server.sse.Heartbeat import io.ktor.server.sse.SSE import io.ktor.server.sse.ServerSSESession +import io.ktor.server.sse.heartbeat import io.ktor.server.sse.sse import io.ktor.utils.io.KtorDsl import io.modelcontextprotocol.kotlin.sdk.types.RPCError @@ -167,6 +169,7 @@ private fun Application.mcpStreamableHttp( } sse { + configuration.sseHeartbeatConfig?.let { config -> heartbeat(config) } val transport = existingStreamableTransport(call, transportManager) ?: return@sse transport.handleRequest(this, call) } @@ -208,6 +211,8 @@ private fun Application.mcpStreamableHttp( * If `null`, origin validation is disabled. * @param eventStore An optional [EventStore] instance to enable resumable event stream functionality. * Allows storing and replaying events. + * @param sseHeartbeatConfig The heartbeat configuration option for SSE connections. Null by default, + * meaning no heartbeat is sent. * @param block factory block with access to the [RoutingContext] (for reading request headers) * that creates and returns the [Server] to handle the connection. */ @@ -218,6 +223,7 @@ public fun Application.mcpStreamableHttp( allowedHosts: List? = null, allowedOrigins: List? = null, eventStore: EventStore? = null, + sseHeartbeatConfig: (Heartbeat.() -> Unit)? = null, block: RoutingContext.() -> Server, ) { mcpStreamableHttp( @@ -228,6 +234,7 @@ public fun Application.mcpStreamableHttp( configuration = StreamableHttpServerTransport.Configuration( eventStore = eventStore, enableJsonResponse = true, + sseHeartbeatConfig = sseHeartbeatConfig, ), block = block, ) diff --git a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt index ae8a39d1..cbb94c5a 100644 --- a/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt +++ b/kotlin-sdk-server/src/commonMain/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport.kt @@ -12,6 +12,7 @@ import io.ktor.server.request.receiveText import io.ktor.server.response.header import io.ktor.server.response.respond import io.ktor.server.response.respondNullable +import io.ktor.server.sse.Heartbeat import io.ktor.server.sse.ServerSSESession import io.ktor.util.collections.ConcurrentMap import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport @@ -133,6 +134,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat * @property retryInterval retry interval for SSE reconnection attempts * @property maxRequestBodySize Maximum allowed size (in bytes) for incoming request bodies. * Defaults to 4 MB (4,194,304 bytes). + * @property sseHeartbeatConfig Configuration options for SSE heartbeat */ public class Configuration( public val enableJsonResponse: Boolean = false, @@ -154,6 +156,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat public val eventStore: EventStore? = null, public val retryInterval: Duration? = null, public val maxRequestBodySize: Long = DEFAULT_MAX_REQUEST_BODY_SIZE, + public val sseHeartbeatConfig: (Heartbeat.() -> Unit)? = null, ) { init { require(maxRequestBodySize > 0) { diff --git a/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpHeartbeatTest.kt b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpHeartbeatTest.kt new file mode 100644 index 00000000..0cafd87a --- /dev/null +++ b/kotlin-sdk-server/src/jvmTest/kotlin/io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpHeartbeatTest.kt @@ -0,0 +1,149 @@ +package io.modelcontextprotocol.kotlin.sdk.server + +import io.kotest.matchers.shouldBe +import io.ktor.client.HttpClient +import io.ktor.client.request.HttpRequestBuilder +import io.ktor.client.request.header +import io.ktor.client.request.post +import io.ktor.client.request.prepareGet +import io.ktor.client.request.setBody +import io.ktor.client.statement.bodyAsChannel +import io.ktor.http.ContentType +import io.ktor.http.HttpHeaders +import io.ktor.http.HttpStatusCode +import io.ktor.http.contentType +import io.ktor.serialization.kotlinx.json.json +import io.ktor.server.testing.ApplicationTestBuilder +import io.ktor.server.testing.testApplication +import io.ktor.sse.ServerSentEvent +import io.ktor.utils.io.ByteReadChannel +import io.ktor.utils.io.readUTF8Line +import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.Implementation +import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequest +import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequestParams +import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest +import io.modelcontextprotocol.kotlin.sdk.types.LATEST_PROTOCOL_VERSION +import io.modelcontextprotocol.kotlin.sdk.types.McpJson +import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities +import io.modelcontextprotocol.kotlin.sdk.types.toJSON +import kotlinx.coroutines.withTimeoutOrNull +import java.util.concurrent.atomic.AtomicBoolean +import kotlin.test.Test +import kotlin.test.assertNotNull +import kotlin.time.Duration.Companion.milliseconds +import io.ktor.client.plugins.contentnegotiation.ContentNegotiation as ClientContentNegotiation + +class StreamableHttpHeartbeatTest { + private val path = "/mcp" + + @Test + fun `GET SSE stream applies configured heartbeat`() = testApplication { + val heartbeatConfigured = AtomicBoolean(false) + + application { + mcpStreamableHttp( + path = path, + sseHeartbeatConfig = { + heartbeatConfigured.set(true) + period = 50.milliseconds + event = ServerSentEvent(comments = "mcp-heartbeat") + }, + ) { + testServer() + } + } + + val client = createTestClient() + val sessionId = initializeSession(client) + + client.prepareGet(path) { + addSseHeaders(sessionId) + }.execute { response -> + response.status shouldBe HttpStatusCode.OK + response.headers[MCP_SESSION_ID_HEADER] shouldBe sessionId + assertNotNull(response.bodyAsChannel().readUTF8Line()) + heartbeatConfigured.get() shouldBe true + } + } + + @Test + fun `GET SSE stream does not send heartbeat by default`() = testApplication { + application { + mcpStreamableHttp(path = path) { + testServer() + } + } + + val client = createTestClient() + val sessionId = initializeSession(client) + + client.prepareGet(path) { + addSseHeaders(sessionId) + }.execute { response -> + response.status shouldBe HttpStatusCode.OK + response.headers[MCP_SESSION_ID_HEADER] shouldBe sessionId + + val heartbeatLine = response.bodyAsChannel().readLineMatching(": heartbeat", timeoutMillis = 150) + + heartbeatLine shouldBe null + } + } + + private fun testServer(): Server = Server( + Implementation("test-server", "1.0.0"), + ServerOptions(capabilities = ServerCapabilities()), + ) + + private suspend fun initializeSession(client: HttpClient): String { + val response = client.post(path) { + header(HttpHeaders.Host, "localhost") + addStreamableHeaders() + setBody(buildInitializeRequestPayload()) + } + + response.status shouldBe HttpStatusCode.OK + return assertNotNull(response.headers[MCP_SESSION_ID_HEADER]) + } + + private suspend fun ByteReadChannel.readLineMatching(expectedLine: String, timeoutMillis: Long = 2_000): String? = + withTimeoutOrNull(timeoutMillis.milliseconds) { + var line = readUTF8Line() + while (line != null) { + if (line == expectedLine) return@withTimeoutOrNull line + line = readUTF8Line() + } + null + } + + private fun HttpRequestBuilder.addStreamableHeaders() { + header( + HttpHeaders.Accept, + listOf(ContentType.Application.Json, ContentType.Text.EventStream).joinToString(", ") { + it.toString() + }, + ) + contentType(ContentType.Application.Json) + } + + private fun HttpRequestBuilder.addSseHeaders(sessionId: String) { + header(HttpHeaders.Host, "localhost") + header(HttpHeaders.Accept, ContentType.Text.EventStream.toString()) + header(MCP_SESSION_ID_HEADER, sessionId) + header("mcp-protocol-version", LATEST_PROTOCOL_VERSION) + } + + private fun buildInitializeRequestPayload(): JSONRPCRequest = InitializeRequest( + InitializeRequestParams( + protocolVersion = LATEST_PROTOCOL_VERSION, + capabilities = ClientCapabilities(), + clientInfo = Implementation(name = "test-client", version = "1.0.0"), + ), + ).toJSON() + + private fun ApplicationTestBuilder.createTestClient(): HttpClient = createClient { + install(ClientContentNegotiation) { + json(McpJson) + } + } +}