Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/ModelContextProtocol.Core/McpSessionHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,16 @@ async Task ProcessMessageAsync()
{
// Register before we yield, so that the tracking is guaranteed to be there
// when subsequent messages arrive, even if the asynchronous processing happens
// out of order.
// out of order. Per spec, "The initialize request MUST NOT be cancelled by clients",
// so we don't track it in _handlingRequests to prevent cancellation notifications from
// canceling it.
if (messageWithId is not null)
{
combinedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_handlingRequests[messageWithId.Id] = combinedCts;
if (message is not JsonRpcRequest { Method: RequestMethods.Initialize })
{
_handlingRequests[messageWithId.Id] = combinedCts;
}
}

// If we await the handler without yielding first, the transport may not be able to read more messages,
Expand Down Expand Up @@ -528,9 +533,10 @@ public async Task<JsonRpcResponse> SendRequestAsync(JsonRpcRequest request, Canc
// Now that the request has been sent, register for cancellation. If we registered before,
// a cancellation request could arrive before the server knew about that request ID, in which
// case the server could ignore it.
// Per spec, "The initialize request MUST NOT be cancelled by clients", so skip registration for initialize.
LogRequestSentAwaitingResponse(EndpointName, request.Method, request.Id);
JsonRpcMessage? response;
using (var registration = RegisterCancellation(cancellationToken, request))
using (var registration = method != RequestMethods.Initialize ? RegisterCancellation(cancellationToken, request) : default)
{
response = await tcs.Task.WaitAsync(cancellationToken).ConfigureAwait(false);
}
Expand Down
37 changes: 37 additions & 0 deletions tests/ModelContextProtocol.Tests/Protocol/CancellationTests.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
using Microsoft.Extensions.DependencyInjection;
using ModelContextProtocol.Client;
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Server;
using ModelContextProtocol.Tests.Utils;
using System.IO.Pipelines;
using System.Text;

namespace ModelContextProtocol.Tests;

Expand Down Expand Up @@ -65,4 +69,37 @@ public async Task CancellationPropagation_RequestingCancellationCancelsPendingRe
cts.Cancel();
await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await waitTask);
}

[Fact]
public async Task InitializeTimeout_DoesNotSendCancellationNotification()
{
// Arrange: Create a transport where the server never responds, so the client will time out.
var serverInput = new MemoryStream();
var serverOutputPipe = new Pipe();

var clientTransport = new StreamClientTransport(
serverInput: serverInput,
serverOutputPipe.Reader.AsStream(),
LoggerFactory);

var clientOptions = new McpClientOptions
{
InitializationTimeout = TimeSpan.FromMilliseconds(500),
};

// Act: Client will send initialize, then time out since no response comes.
// Per spec, "The initialize request MUST NOT be cancelled by clients",
// so no cancellation notification should be sent.
await Assert.ThrowsAsync<TimeoutException>(async () =>
{
await McpClient.CreateAsync(clientTransport, clientOptions: clientOptions, loggerFactory: LoggerFactory,
cancellationToken: TestContext.Current.CancellationToken);
});

// Assert: Read what was written to serverInput.
// The only message should be the initialize request, NOT a cancellation notification.
var content = Encoding.UTF8.GetString(serverInput.ToArray());
Assert.Contains("\"method\":\"initialize\"", content);
Assert.DoesNotContain("notifications/cancelled", content);
}
}
53 changes: 53 additions & 0 deletions tests/ModelContextProtocol.Tests/Server/McpServerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,59 @@ public async Task Can_SendMessage_Before_RunAsync()
Assert.Same(logNotification, transport.SentMessages[0]);
}

[Fact]
public async Task Server_IgnoresCancellationNotificationForInitializeRequest()
{
// Arrange
await using var transport = new TestServerTransport();
await using McpServer server = McpServer.Create(transport, _options, LoggerFactory);
var runTask = server.RunAsync(TestContext.Current.CancellationToken);

// Set up to capture the initialize response
var initializeRequest = new JsonRpcRequest
{
Id = new RequestId("init-cancel-test"),
Method = RequestMethods.Initialize,
Params = JsonSerializer.SerializeToNode(new InitializeRequestParams
{
ProtocolVersion = "2024-11-05",
Capabilities = new ClientCapabilities(),
ClientInfo = new Implementation { Name = "test-client", Version = "1.0.0" }
}, McpJsonUtilities.DefaultOptions)
};

var initResponseTcs = new TaskCompletionSource<JsonRpcResponse>();
transport.OnMessageSent = (message) =>
{
if (message is JsonRpcResponse response && response.Id == initializeRequest.Id)
{
initResponseTcs.TrySetResult(response);
}
};

// Act: Send initialize request and immediately send a cancellation notification for it.
// Per spec, "The initialize request MUST NOT be cancelled by clients", so the server
// should ignore the cancellation and still complete the initialize request.
await transport.SendClientMessageAsync(initializeRequest, TestContext.Current.CancellationToken);
await transport.SendClientMessageAsync(new JsonRpcNotification
{
Method = NotificationMethods.CancelledNotification,
Params = JsonSerializer.SerializeToNode(
new CancelledNotificationParams { RequestId = initializeRequest.Id },
McpJsonUtilities.DefaultOptions),
}, TestContext.Current.CancellationToken);

// Assert: The initialize response should still arrive (not cancelled)
var response = await initResponseTcs.Task.WaitAsync(TestConstants.DefaultTimeout, TestContext.Current.CancellationToken);
Assert.NotNull(response.Result);
var initResult = JsonSerializer.Deserialize<InitializeResult>(response.Result, McpJsonUtilities.DefaultOptions);
Assert.NotNull(initResult);
Assert.NotNull(initResult.ServerInfo);

await transport.DisposeAsync();
await runTask;
}

private static async Task InitializeServerAsync(TestServerTransport transport, ClientCapabilities capabilities, CancellationToken cancellationToken = default)
{
var initializeRequest = new JsonRpcRequest
Expand Down