Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions kotlin-sdk-server/api/kotlin-sdk-server.api
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/KtorServerKt {
public static synthetic fun mcp$default (Lio/ktor/server/routing/Route;ZLjava/util/List;Ljava/util/List;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V
public static final fun mcpStatelessStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V
public static synthetic fun mcpStatelessStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V
public static final fun mcpStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;)V
public static synthetic fun mcpStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V
public static final fun mcpStreamableHttp (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)V
public static synthetic fun mcpStreamableHttp$default (Lio/ktor/server/application/Application;Ljava/lang/String;ZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;ILjava/lang/Object;)V
}

public final class io/modelcontextprotocol/kotlin/sdk/server/RegisteredPrompt : io/modelcontextprotocol/kotlin/sdk/server/Feature {
Expand Down Expand Up @@ -245,15 +245,16 @@ public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServe
}

public final class io/modelcontextprotocol/kotlin/sdk/server/StreamableHttpServerTransport$Configuration {
public synthetic fun <init> (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/time/Duration;JILkotlin/jvm/internal/DefaultConstructorMarker;)V
public synthetic fun <init> (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/time/Duration;JLkotlin/jvm/internal/DefaultConstructorMarker;)V
public synthetic fun <init> (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/time/Duration;JLkotlin/jvm/functions/Function1;ILkotlin/jvm/internal/DefaultConstructorMarker;)V
public synthetic fun <init> (ZZLjava/util/List;Ljava/util/List;Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;Lkotlin/time/Duration;JLkotlin/jvm/functions/Function1;Lkotlin/jvm/internal/DefaultConstructorMarker;)V
public final fun getAllowedHosts ()Ljava/util/List;
public final fun getAllowedOrigins ()Ljava/util/List;
public final fun getEnableDnsRebindingProtection ()Z
public final fun getEnableJsonResponse ()Z
public final fun getEventStore ()Lio/modelcontextprotocol/kotlin/sdk/server/EventStore;
public final fun getMaxRequestBodySize ()J
public final fun getRetryInterval-FghU774 ()Lkotlin/time/Duration;
public final fun getSseHeartbeatConfig ()Lkotlin/jvm/functions/Function1;
}

public final class io/modelcontextprotocol/kotlin/sdk/server/WebSocketMcpKtorServerExtensionsKt {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ import io.ktor.server.routing.get
import io.ktor.server.routing.post
import io.ktor.server.routing.route
import io.ktor.server.routing.routing
import io.ktor.server.sse.Heartbeat
import io.ktor.server.sse.SSE
import io.ktor.server.sse.ServerSSESession
import io.ktor.server.sse.heartbeat
import io.ktor.server.sse.sse
import io.ktor.utils.io.KtorDsl
import io.modelcontextprotocol.kotlin.sdk.types.RPCError
Expand Down Expand Up @@ -167,6 +169,7 @@ private fun Application.mcpStreamableHttp(
}

sse {
configuration.sseHeartbeatConfig?.let { config -> heartbeat(config) }
val transport = existingStreamableTransport(call, transportManager) ?: return@sse
transport.handleRequest(this, call)
}
Expand Down Expand Up @@ -208,6 +211,8 @@ private fun Application.mcpStreamableHttp(
* If `null`, origin validation is disabled.
* @param eventStore An optional [EventStore] instance to enable resumable event stream functionality.
* Allows storing and replaying events.
* @param sseHeartbeatConfig The heartbeat configuration option for SSE connections. Null by default,
* meaning no heartbeat is sent.
* @param block factory block with access to the [RoutingContext] (for reading request headers)
* that creates and returns the [Server] to handle the connection.
*/
Expand All @@ -218,6 +223,7 @@ public fun Application.mcpStreamableHttp(
allowedHosts: List<String>? = null,
allowedOrigins: List<String>? = null,
eventStore: EventStore? = null,
sseHeartbeatConfig: (Heartbeat.() -> Unit)? = null,
block: RoutingContext.() -> Server,
) {
mcpStreamableHttp(
Expand All @@ -228,6 +234,7 @@ public fun Application.mcpStreamableHttp(
configuration = StreamableHttpServerTransport.Configuration(
eventStore = eventStore,
enableJsonResponse = true,
sseHeartbeatConfig = sseHeartbeatConfig,
),
block = block,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import io.ktor.server.request.receiveText
import io.ktor.server.response.header
import io.ktor.server.response.respond
import io.ktor.server.response.respondNullable
import io.ktor.server.sse.Heartbeat
import io.ktor.server.sse.ServerSSESession
import io.ktor.util.collections.ConcurrentMap
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
Expand Down Expand Up @@ -133,6 +134,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat
* @property retryInterval retry interval for SSE reconnection attempts
* @property maxRequestBodySize Maximum allowed size (in bytes) for incoming request bodies.
* Defaults to 4 MB (4,194,304 bytes).
* @property sseHeartbeatConfig Configuration options for SSE heartbeat
*/
public class Configuration(
public val enableJsonResponse: Boolean = false,
Expand All @@ -154,6 +156,7 @@ public class StreamableHttpServerTransport(private val configuration: Configurat
public val eventStore: EventStore? = null,
public val retryInterval: Duration? = null,
public val maxRequestBodySize: Long = DEFAULT_MAX_REQUEST_BODY_SIZE,
public val sseHeartbeatConfig: (Heartbeat.() -> Unit)? = null,
) {
init {
require(maxRequestBodySize > 0) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package io.modelcontextprotocol.kotlin.sdk.server

import io.kotest.matchers.shouldBe
import io.ktor.client.HttpClient
import io.ktor.client.request.HttpRequestBuilder
import io.ktor.client.request.header
import io.ktor.client.request.post
import io.ktor.client.request.prepareGet
import io.ktor.client.request.setBody
import io.ktor.client.statement.bodyAsChannel
import io.ktor.http.ContentType
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.contentType
import io.ktor.serialization.kotlinx.json.json
import io.ktor.server.testing.ApplicationTestBuilder
import io.ktor.server.testing.testApplication
import io.ktor.sse.ServerSentEvent
import io.ktor.utils.io.ByteReadChannel
import io.ktor.utils.io.readUTF8Line
import io.modelcontextprotocol.kotlin.sdk.types.ClientCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.Implementation
import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequest
import io.modelcontextprotocol.kotlin.sdk.types.InitializeRequestParams
import io.modelcontextprotocol.kotlin.sdk.types.JSONRPCRequest
import io.modelcontextprotocol.kotlin.sdk.types.LATEST_PROTOCOL_VERSION
import io.modelcontextprotocol.kotlin.sdk.types.McpJson
import io.modelcontextprotocol.kotlin.sdk.types.ServerCapabilities
import io.modelcontextprotocol.kotlin.sdk.types.toJSON
import kotlinx.coroutines.withTimeoutOrNull
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.test.Test
import kotlin.test.assertNotNull
import kotlin.time.Duration.Companion.milliseconds
import io.ktor.client.plugins.contentnegotiation.ContentNegotiation as ClientContentNegotiation

class StreamableHttpHeartbeatTest {
private val path = "/mcp"

@Test
fun `GET SSE stream applies configured heartbeat`() = testApplication {
val heartbeatConfigured = AtomicBoolean(false)

application {
mcpStreamableHttp(
path = path,
sseHeartbeatConfig = {
heartbeatConfigured.set(true)
period = 50.milliseconds
event = ServerSentEvent(comments = "mcp-heartbeat")
},
) {
testServer()
}
}

val client = createTestClient()
val sessionId = initializeSession(client)

client.prepareGet(path) {
addSseHeaders(sessionId)
}.execute { response ->
response.status shouldBe HttpStatusCode.OK
response.headers[MCP_SESSION_ID_HEADER] shouldBe sessionId
assertNotNull(response.bodyAsChannel().readUTF8Line())
heartbeatConfigured.get() shouldBe true
}
}

@Test
fun `GET SSE stream does not send heartbeat by default`() = testApplication {
application {
mcpStreamableHttp(path = path) {
testServer()
}
}

val client = createTestClient()
val sessionId = initializeSession(client)

client.prepareGet(path) {
addSseHeaders(sessionId)
}.execute { response ->
response.status shouldBe HttpStatusCode.OK
response.headers[MCP_SESSION_ID_HEADER] shouldBe sessionId

val heartbeatLine = response.bodyAsChannel().readLineMatching(": heartbeat", timeoutMillis = 150)

heartbeatLine shouldBe null
}
}

private fun testServer(): Server = Server(
Implementation("test-server", "1.0.0"),
ServerOptions(capabilities = ServerCapabilities()),
)

private suspend fun initializeSession(client: HttpClient): String {
val response = client.post(path) {
header(HttpHeaders.Host, "localhost")
addStreamableHeaders()
setBody(buildInitializeRequestPayload())
}

response.status shouldBe HttpStatusCode.OK
return assertNotNull(response.headers[MCP_SESSION_ID_HEADER])
}

private suspend fun ByteReadChannel.readLineMatching(expectedLine: String, timeoutMillis: Long = 2_000): String? =
withTimeoutOrNull(timeoutMillis.milliseconds) {
var line = readUTF8Line()
while (line != null) {
if (line == expectedLine) return@withTimeoutOrNull line
line = readUTF8Line()
}
null
}

private fun HttpRequestBuilder.addStreamableHeaders() {
header(
HttpHeaders.Accept,
listOf(ContentType.Application.Json, ContentType.Text.EventStream).joinToString(", ") {
it.toString()
},
)
contentType(ContentType.Application.Json)
}

private fun HttpRequestBuilder.addSseHeaders(sessionId: String) {
header(HttpHeaders.Host, "localhost")
header(HttpHeaders.Accept, ContentType.Text.EventStream.toString())
header(MCP_SESSION_ID_HEADER, sessionId)
header("mcp-protocol-version", LATEST_PROTOCOL_VERSION)
}

private fun buildInitializeRequestPayload(): JSONRPCRequest = InitializeRequest(
InitializeRequestParams(
protocolVersion = LATEST_PROTOCOL_VERSION,
capabilities = ClientCapabilities(),
clientInfo = Implementation(name = "test-client", version = "1.0.0"),
),
).toJSON()

private fun ApplicationTestBuilder.createTestClient(): HttpClient = createClient {
install(ClientContentNegotiation) {
json(McpJson)
}
}
}