diff --git a/src/ModelContextProtocol.Core/Protocol/BlobResourceContents.cs b/src/ModelContextProtocol.Core/Protocol/BlobResourceContents.cs index 859d45a62..809c1aea1 100644 --- a/src/ModelContextProtocol.Core/Protocol/BlobResourceContents.cs +++ b/src/ModelContextProtocol.Core/Protocol/BlobResourceContents.cs @@ -1,6 +1,7 @@ using System.Buffers; using System.Buffers.Text; using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; using System.Runtime.InteropServices; using System.Text.Json.Serialization; @@ -28,7 +29,7 @@ namespace ModelContextProtocol.Protocol; public sealed class BlobResourceContents : ResourceContents { private ReadOnlyMemory? _decodedData; - private ReadOnlyMemory _blob; + private ReadOnlyMemory? _blob; /// /// Creates an from raw data. @@ -40,15 +41,20 @@ public sealed class BlobResourceContents : ResourceContents /// public static BlobResourceContents FromBytes(ReadOnlyMemory bytes, string uri, string? mimeType = null) { - ReadOnlyMemory blob = EncodingUtilities.EncodeToBase64Utf8(bytes); - - return new() - { - _decodedData = bytes, - Blob = blob, - MimeType = mimeType, - Uri = uri - }; + return new(bytes, uri, mimeType); + } + + /// Initializes a new instance of the class. + public BlobResourceContents() + { + } + + [SetsRequiredMembers] + private BlobResourceContents(ReadOnlyMemory decodedData, string uri, string? mimeType) + { + _decodedData = decodedData; + Uri = uri; + MimeType = mimeType; } /// @@ -60,7 +66,16 @@ public static BlobResourceContents FromBytes(ReadOnlyMemory bytes, string [JsonPropertyName("blob")] public required ReadOnlyMemory Blob { - get => _blob; + get + { + if (_blob is null) + { + Debug.Assert(_decodedData is not null); + _blob = EncodingUtilities.EncodeToBase64Utf8(_decodedData!.Value); + } + + return _blob.Value; + } set { _blob = value; diff --git a/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs b/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs index a2206bbe3..5946dc940 100644 --- a/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs +++ b/src/ModelContextProtocol.Core/Protocol/ContentBlock.cs @@ -92,6 +92,7 @@ public sealed class Converter : JsonConverter string? name = null; string? title = null; ReadOnlyMemory? data = null; + ReadOnlyMemory? decodedData = null; string? mimeType = null; string? uri = null; string? description = null; @@ -137,7 +138,14 @@ public sealed class Converter : JsonConverter break; case "data": - data = reader.HasValueSequence ? reader.ValueSequence.ToArray() : reader.ValueSpan.ToArray(); + if (!reader.ValueIsEscaped) + { + data = reader.HasValueSequence ? reader.ValueSequence.ToArray() : reader.ValueSpan.ToArray(); + } + else + { + decodedData = reader.GetBytesFromBase64(); + } break; case "mimeType": @@ -230,17 +238,23 @@ public sealed class Converter : JsonConverter Text = text ?? throw new JsonException("Text contents must be provided for 'text' type."), }, - "image" => new ImageContentBlock - { - Data = data ?? throw new JsonException("Image data must be provided for 'image' type."), - MimeType = mimeType ?? throw new JsonException("MIME type must be provided for 'image' type."), - }, - - "audio" => new AudioContentBlock - { - Data = data ?? throw new JsonException("Audio data must be provided for 'audio' type."), - MimeType = mimeType ?? throw new JsonException("MIME type must be provided for 'audio' type."), - }, + "image" => decodedData is not null ? + ImageContentBlock.FromBytes(decodedData.Value, + mimeType ?? throw new JsonException("MIME type must be provided for 'image' type.")) : + new ImageContentBlock + { + Data = data ?? throw new JsonException("Image data must be provided for 'image' type."), + MimeType = mimeType ?? throw new JsonException("MIME type must be provided for 'image' type."), + }, + + "audio" => decodedData is not null ? + AudioContentBlock.FromBytes(decodedData.Value, + mimeType ?? throw new JsonException("MIME type must be provided for 'audio' type.")) : + new AudioContentBlock + { + Data = data ?? throw new JsonException("Audio data must be provided for 'audio' type."), + MimeType = mimeType ?? throw new JsonException("MIME type must be provided for 'audio' type."), + }, "resource" => new EmbeddedResourceBlock { @@ -414,7 +428,7 @@ public sealed class TextContentBlock : ContentBlock public sealed class ImageContentBlock : ContentBlock { private ReadOnlyMemory? _decodedData; - private ReadOnlyMemory _data; + private ReadOnlyMemory? _data; /// /// Creates an from decoded image bytes. @@ -423,7 +437,7 @@ public sealed class ImageContentBlock : ContentBlock /// The MIME type of the image. /// A new instance. /// - /// This method stores the provided bytes as and encodes them to base64 UTF-8 bytes for . + /// This method stores the provided bytes as and lazily encodes them to base64 UTF-8 bytes for . /// /// is . /// is empty or composed entirely of whitespace. @@ -431,14 +445,19 @@ public static ImageContentBlock FromBytes(ReadOnlyMemory bytes, string mim { Throw.IfNullOrWhiteSpace(mimeType); - ReadOnlyMemory data = EncodingUtilities.EncodeToBase64Utf8(bytes); - - return new() - { - _decodedData = bytes, - Data = data, - MimeType = mimeType - }; + return new(bytes, mimeType); + } + + /// Initializes a new instance of the class. + public ImageContentBlock() + { + } + + [SetsRequiredMembers] + private ImageContentBlock(ReadOnlyMemory decodedData, string mimeType) + { + _decodedData = decodedData; + MimeType = mimeType; } /// @@ -453,7 +472,16 @@ public static ImageContentBlock FromBytes(ReadOnlyMemory bytes, string mim [JsonPropertyName("data")] public required ReadOnlyMemory Data { - get => _data; + get + { + if (_data is null) + { + Debug.Assert(_decodedData is not null); + _data = EncodingUtilities.EncodeToBase64Utf8(_decodedData!.Value); + } + + return _data.Value; + } set { _data = value; @@ -494,7 +522,14 @@ public ReadOnlyMemory DecodedData public required string MimeType { get; set; } [DebuggerBrowsable(DebuggerBrowsableState.Never)] - private string DebuggerDisplay => $"MimeType = {MimeType}, Length = {DebuggerDisplayHelper.GetBase64LengthDisplay(Data)}"; + private string DebuggerDisplay + { + get + { + string lengthDisplay = _decodedData is not null ? $"{_decodedData.Value.Length} bytes" : DebuggerDisplayHelper.GetBase64LengthDisplay(Data); + return $"MimeType = {MimeType}, Length = {lengthDisplay}"; + } + } } /// Represents audio provided to or from an LLM. @@ -502,7 +537,7 @@ public ReadOnlyMemory DecodedData public sealed class AudioContentBlock : ContentBlock { private ReadOnlyMemory? _decodedData; - private ReadOnlyMemory _data; + private ReadOnlyMemory? _data; /// /// Creates an from decoded audio bytes. @@ -511,7 +546,7 @@ public sealed class AudioContentBlock : ContentBlock /// The MIME type of the audio. /// A new instance. /// - /// This method stores the provided bytes as and encodes them to base64 UTF-8 bytes for . + /// This method stores the provided bytes as and lazily encodes them to base64 UTF-8 bytes for . /// /// is . /// is empty or composed entirely of whitespace. @@ -519,14 +554,19 @@ public static AudioContentBlock FromBytes(ReadOnlyMemory bytes, string mim { Throw.IfNullOrWhiteSpace(mimeType); - ReadOnlyMemory data = EncodingUtilities.EncodeToBase64Utf8(bytes); - - return new() - { - _decodedData = bytes, - Data = data, - MimeType = mimeType - }; + return new(bytes, mimeType); + } + + /// Initializes a new instance of the class. + public AudioContentBlock() + { + } + + [SetsRequiredMembers] + private AudioContentBlock(ReadOnlyMemory decodedData, string mimeType) + { + _decodedData = decodedData; + MimeType = mimeType; } /// @@ -541,7 +581,16 @@ public static AudioContentBlock FromBytes(ReadOnlyMemory bytes, string mim [JsonPropertyName("data")] public required ReadOnlyMemory Data { - get => _data; + get + { + if (_data is null) + { + Debug.Assert(_decodedData is not null); + _data = EncodingUtilities.EncodeToBase64Utf8(_decodedData!.Value); + } + + return _data.Value; + } set { _data = value; @@ -582,7 +631,14 @@ public ReadOnlyMemory DecodedData public required string MimeType { get; set; } [DebuggerBrowsable(DebuggerBrowsableState.Never)] - private string DebuggerDisplay => $"MimeType = {MimeType}, Length = {DebuggerDisplayHelper.GetBase64LengthDisplay(Data)}"; + private string DebuggerDisplay + { + get + { + string lengthDisplay = _decodedData is not null ? $"{_decodedData.Value.Length} bytes" : DebuggerDisplayHelper.GetBase64LengthDisplay(Data); + return $"MimeType = {MimeType}, Length = {lengthDisplay}"; + } + } } /// Represents the contents of a resource, embedded into a prompt or tool call result. diff --git a/src/ModelContextProtocol.Core/Protocol/ResourceContents.cs b/src/ModelContextProtocol.Core/Protocol/ResourceContents.cs index f1b190277..283c47d6b 100644 --- a/src/ModelContextProtocol.Core/Protocol/ResourceContents.cs +++ b/src/ModelContextProtocol.Core/Protocol/ResourceContents.cs @@ -80,6 +80,7 @@ public sealed class Converter : JsonConverter string? uri = null; string? mimeType = null; ReadOnlyMemory? blob = null; + ReadOnlyMemory? decodedBlob = null; string? text = null; JsonObject? meta = null; @@ -105,7 +106,14 @@ public sealed class Converter : JsonConverter break; case "blob": - blob = reader.HasValueSequence ? reader.ValueSequence.ToArray() : reader.ValueSpan.ToArray(); + if (!reader.ValueIsEscaped) + { + blob = reader.HasValueSequence ? reader.ValueSequence.ToArray() : reader.ValueSpan.ToArray(); + } + else + { + decodedBlob = reader.GetBytesFromBase64(); + } break; case "text": @@ -122,6 +130,13 @@ public sealed class Converter : JsonConverter } } + if (decodedBlob is not null) + { + var blobResource = BlobResourceContents.FromBytes(decodedBlob.Value, uri ?? string.Empty, mimeType); + blobResource.Meta = meta; + return blobResource; + } + if (blob is not null) { return new BlobResourceContents diff --git a/tests/ModelContextProtocol.Tests/AIContentExtensionsTests.cs b/tests/ModelContextProtocol.Tests/AIContentExtensionsTests.cs index b46a34268..77be2f922 100644 --- a/tests/ModelContextProtocol.Tests/AIContentExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/AIContentExtensionsTests.cs @@ -1,5 +1,6 @@ using Microsoft.Extensions.AI; using ModelContextProtocol.Protocol; +using ModelContextProtocol.Tests.Protocol; using System.Text.Json; using System.Text.Json.Serialization; @@ -402,6 +403,183 @@ public void ToChatMessage_CallToolResult_WithAnonymousTypeInContent_Works() Assert.Null(exception); } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public void ImageContentBlock_ToAIContent_RoundTrips(byte[] originalBytes) + { + var image = ImageContentBlock.FromBytes(originalBytes, "image/png"); + + var aiContent = Assert.IsType(image.ToAIContent()); + Assert.Equal("image/png", aiContent.MediaType); + Assert.Equal(originalBytes, aiContent.Data.ToArray()); + + var roundTripped = Assert.IsType(aiContent.ToContentBlock()); + Assert.Equal("image/png", roundTripped.MimeType); + Assert.Equal(originalBytes, roundTripped.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public void ImageContentBlock_DataSetter_ToAIContent_RoundTrips(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + var image = new ImageContentBlock + { + Data = System.Text.Encoding.UTF8.GetBytes(base64), + MimeType = "image/jpeg" + }; + + var aiContent = Assert.IsType(image.ToAIContent()); + Assert.Equal("image/jpeg", aiContent.MediaType); + Assert.Equal(originalBytes, aiContent.Data.ToArray()); + + var roundTripped = Assert.IsType(aiContent.ToContentBlock()); + Assert.Equal("image/jpeg", roundTripped.MimeType); + Assert.Equal(originalBytes, roundTripped.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public void AudioContentBlock_ToAIContent_RoundTrips(byte[] originalBytes) + { + var audio = AudioContentBlock.FromBytes(originalBytes, "audio/wav"); + + var aiContent = Assert.IsType(audio.ToAIContent()); + Assert.Equal("audio/wav", aiContent.MediaType); + Assert.Equal(originalBytes, aiContent.Data.ToArray()); + + var roundTripped = Assert.IsType(aiContent.ToContentBlock()); + Assert.Equal("audio/wav", roundTripped.MimeType); + Assert.Equal(originalBytes, roundTripped.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public void AudioContentBlock_DataSetter_ToAIContent_RoundTrips(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + var audio = new AudioContentBlock + { + Data = System.Text.Encoding.UTF8.GetBytes(base64), + MimeType = "audio/mp3" + }; + + var aiContent = Assert.IsType(audio.ToAIContent()); + Assert.Equal("audio/mp3", aiContent.MediaType); + Assert.Equal(originalBytes, aiContent.Data.ToArray()); + + var roundTripped = Assert.IsType(aiContent.ToContentBlock()); + Assert.Equal("audio/mp3", roundTripped.MimeType); + Assert.Equal(originalBytes, roundTripped.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public void BlobResourceContents_ToAIContent_RoundTrips(byte[] originalBytes) + { + var blob = BlobResourceContents.FromBytes(originalBytes, "file:///test.bin", "application/octet-stream"); + var embedded = new EmbeddedResourceBlock { Resource = blob }; + + var aiContent = Assert.IsType(embedded.ToAIContent()); + Assert.Equal("application/octet-stream", aiContent.MediaType); + Assert.Equal(originalBytes, aiContent.Data.ToArray()); + + var roundTripped = Assert.IsType(aiContent.ToContentBlock()); + var roundTrippedBlob = Assert.IsType(roundTripped.Resource); + Assert.Equal("application/octet-stream", roundTrippedBlob.MimeType); + Assert.Equal(originalBytes, roundTrippedBlob.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public void BlobResourceContents_BlobSetter_ToAIContent_RoundTrips(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + var blob = new BlobResourceContents + { + Blob = System.Text.Encoding.UTF8.GetBytes(base64), + Uri = "file:///test.bin", + MimeType = "application/octet-stream" + }; + var embedded = new EmbeddedResourceBlock { Resource = blob }; + + var aiContent = Assert.IsType(embedded.ToAIContent()); + Assert.Equal("application/octet-stream", aiContent.MediaType); + Assert.Equal(originalBytes, aiContent.Data.ToArray()); + + var roundTripped = Assert.IsType(aiContent.ToContentBlock()); + var roundTrippedBlob = Assert.IsType(roundTripped.Resource); + Assert.Equal("application/octet-stream", roundTrippedBlob.MimeType); + Assert.Equal(originalBytes, roundTrippedBlob.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public void ImageContentBlock_JsonDeserialized_ToAIContent_RoundTrips(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + string json = $$"""{"type":"image","data":"{{base64}}","mimeType":"image/png"}"""; + + var image = Assert.IsType( + JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + var aiContent = Assert.IsType(image.ToAIContent()); + Assert.Equal(originalBytes, aiContent.Data.ToArray()); + + var roundTripped = Assert.IsType(aiContent.ToContentBlock()); + Assert.Equal(originalBytes, roundTripped.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public void ImageContentBlock_EscapedJsonDeserialized_ToAIContent_RoundTrips(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + string json = $$"""{"type":"image","data":"{{base64.Replace("/", "\\/")}}","mimeType":"image/png"}"""; + + var image = Assert.IsType( + JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + var aiContent = Assert.IsType(image.ToAIContent()); + Assert.Equal(originalBytes, aiContent.Data.ToArray()); + + var roundTripped = Assert.IsType(aiContent.ToContentBlock()); + Assert.Equal(originalBytes, roundTripped.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public void AudioContentBlock_EscapedJsonDeserialized_ToAIContent_RoundTrips(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + string json = $$"""{"type":"audio","data":"{{base64.Replace("/", "\\/")}}","mimeType":"audio/wav"}"""; + + var audio = Assert.IsType( + JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + var aiContent = Assert.IsType(audio.ToAIContent()); + Assert.Equal(originalBytes, aiContent.Data.ToArray()); + + var roundTripped = Assert.IsType(aiContent.ToContentBlock()); + Assert.Equal(originalBytes, roundTripped.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public void BlobResourceContents_EscapedJsonDeserialized_ToAIContent_RoundTrips(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + string json = $$"""{"uri":"file:///test.bin","blob":"{{base64.Replace("/", "\\/")}}","mimeType":"application/octet-stream"}"""; + + var blob = Assert.IsType( + JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + var embedded = new EmbeddedResourceBlock { Resource = blob }; + + var aiContent = Assert.IsType(embedded.ToAIContent()); + Assert.Equal(originalBytes, aiContent.Data.ToArray()); + + var roundTripped = Assert.IsType(aiContent.ToContentBlock()); + var roundTrippedBlob = Assert.IsType(roundTripped.Resource); + Assert.Equal(originalBytes, roundTrippedBlob.DecodedData.ToArray()); + } } // Test type for named user-defined type test diff --git a/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs index 957a45260..6e943f025 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ContentBlockTests.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Protocol; +using System.Text; using System.Text.Json; namespace ModelContextProtocol.Tests.Protocol; @@ -290,4 +291,275 @@ public void AudioContentBlock_FromBytes_ThrowsForNullOrWhiteSpaceMimeType(string { Assert.ThrowsAny(() => AudioContentBlock.FromBytes((byte[])[1, 2, 3], mimeType!)); } + + [Fact] + public void ImageContentBlock_Deserialization_HandlesEscapedForwardSlashInBase64() + { + // Base64 uses '/' which some JSON encoders escape as '\/' (valid JSON). + // The converter must unescape before storing the base64 UTF-8 bytes. + byte[] originalBytes = [0xFF, 0xD8, 0xFF, 0xE0]; // sample bytes that produce '/' in base64 + string base64 = Convert.ToBase64String(originalBytes); // "/9j/4A==" + Assert.Contains("/", base64); + + // Simulate a JSON encoder that escapes '/' as '\/' + string json = $$"""{"type":"image","data":"{{base64.Replace("/", "\\/")}}","mimeType":"image/jpeg"}"""; + + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var image = Assert.IsType(deserialized); + Assert.Equal(base64, System.Text.Encoding.UTF8.GetString(image.Data.ToArray())); + Assert.Equal(originalBytes, image.DecodedData.ToArray()); + } + + [Fact] + public void AudioContentBlock_Deserialization_HandlesEscapedForwardSlashInBase64() + { + byte[] originalBytes = [0xFF, 0xD8, 0xFF, 0xE0]; + string base64 = Convert.ToBase64String(originalBytes); + Assert.Contains("/", base64); + + string json = $$"""{"type":"audio","data":"{{base64.Replace("/", "\\/")}}","mimeType":"audio/wav"}"""; + + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var audio = Assert.IsType(deserialized); + Assert.Equal(base64, System.Text.Encoding.UTF8.GetString(audio.Data.ToArray())); + Assert.Equal(originalBytes, audio.DecodedData.ToArray()); + } + + /// + /// Provides test data for base64 roundtrip tests. Each entry is a byte array that exercises + /// different base64 encoding characteristics: + /// - Various lengths producing 0, 1, or 2 padding characters + /// - Bytes that produce all 64 base64 alphabet characters including '+' and '/' + /// + public static TheoryData Base64TestData() + { + var data = new TheoryData + { + Array.Empty(), // empty: "" + new byte[] { 0x00 }, // 1 byte, 2 padding chars: "AA==" + new byte[] { 0x00, 0x01 }, // 2 bytes, 1 padding char: "AAE=" + new byte[] { 0x00, 0x01, 0x02 }, // 3 bytes, no padding: "AAEC" + new byte[] { 0xFF, 0xD8, 0xFF, 0xE0 }, // produces '/' in base64: "/9j/4A==" + new byte[] { 0xFB, 0xEF, 0xBE }, // produces '+' in base64: "++++" + }; + + // All 256 byte values to exercise the full base64 alphabet + byte[] allBytes = new byte[256]; + for (int i = 0; i < 256; i++) + { + allBytes[i] = (byte)i; + } + data.Add(allBytes); + + // Larger payload (1024 bytes) + byte[] largePayload = new byte[1024]; + new Random(42).NextBytes(largePayload); + data.Add(largePayload); + + return data; + } + + [Theory] + [MemberData(nameof(Base64TestData))] + public void ImageContentBlock_FromBytes_RoundtripsCorrectly(byte[] originalBytes) + { + string expectedBase64 = Convert.ToBase64String(originalBytes); + + var image = ImageContentBlock.FromBytes(originalBytes, "image/png"); + + Assert.Equal("image/png", image.MimeType); + Assert.Equal(originalBytes, image.DecodedData.ToArray()); + Assert.Equal(expectedBase64, Encoding.UTF8.GetString(image.Data.ToArray())); + } + + [Theory] + [MemberData(nameof(Base64TestData))] + public void ImageContentBlock_DataSetter_RoundtripsCorrectly(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + byte[] base64Utf8 = Encoding.UTF8.GetBytes(base64); + + var image = new ImageContentBlock { Data = base64Utf8, MimeType = "image/png" }; + + Assert.Equal(base64Utf8, image.Data.ToArray()); + Assert.Equal(originalBytes, image.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(Base64TestData))] + public void ImageContentBlock_JsonRoundtrip_PreservesData(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + byte[] base64Utf8 = Encoding.UTF8.GetBytes(base64); + + var original = new ImageContentBlock { Data = base64Utf8, MimeType = "image/png" }; + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = Assert.IsType( + JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + + Assert.Equal(base64Utf8, deserialized.Data.ToArray()); + Assert.Equal(originalBytes, deserialized.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(Base64TestData))] + public void ImageContentBlock_FromBytes_JsonRoundtrip_PreservesData(byte[] originalBytes) + { + string expectedBase64 = Convert.ToBase64String(originalBytes); + + var original = ImageContentBlock.FromBytes(originalBytes, "image/jpeg"); + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = Assert.IsType( + JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + + Assert.Equal(expectedBase64, Encoding.UTF8.GetString(deserialized.Data.ToArray())); + Assert.Equal(originalBytes, deserialized.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(Base64TestData))] + public void ImageContentBlock_EscapedJsonRoundtrip_PreservesData(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + + // Simulate JSON encoder that escapes '/' as '\/' + string json = $$"""{"type":"image","data":"{{base64.Replace("/", "\\/")}}","mimeType":"image/png"}"""; + + var deserialized = Assert.IsType( + JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + + Assert.Equal(base64, Encoding.UTF8.GetString(deserialized.Data.ToArray())); + Assert.Equal(originalBytes, deserialized.DecodedData.ToArray()); + } + + [Fact] + public void ImageContentBlock_DataSetterInvalidatesCachedDecodedData() + { + byte[] bytes1 = [1, 2, 3]; + var image = ImageContentBlock.FromBytes(bytes1, "image/png"); + + // Access DecodedData to populate cache + Assert.Equal(bytes1, image.DecodedData.ToArray()); + + // Set new Data to invalidate cache + byte[] newBytes = [4, 5, 6]; + string newBase64 = Convert.ToBase64String(newBytes); + image.Data = Encoding.UTF8.GetBytes(newBase64); + + Assert.Equal(newBytes, image.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(Base64TestData))] + public void AudioContentBlock_FromBytes_RoundtripsCorrectly(byte[] originalBytes) + { + string expectedBase64 = Convert.ToBase64String(originalBytes); + + var audio = AudioContentBlock.FromBytes(originalBytes, "audio/wav"); + + Assert.Equal("audio/wav", audio.MimeType); + Assert.Equal(originalBytes, audio.DecodedData.ToArray()); + Assert.Equal(expectedBase64, Encoding.UTF8.GetString(audio.Data.ToArray())); + } + + [Theory] + [MemberData(nameof(Base64TestData))] + public void AudioContentBlock_DataSetter_RoundtripsCorrectly(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + byte[] base64Utf8 = Encoding.UTF8.GetBytes(base64); + + var audio = new AudioContentBlock { Data = base64Utf8, MimeType = "audio/wav" }; + + Assert.Equal(base64Utf8, audio.Data.ToArray()); + Assert.Equal(originalBytes, audio.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(Base64TestData))] + public void AudioContentBlock_JsonRoundtrip_PreservesData(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + byte[] base64Utf8 = Encoding.UTF8.GetBytes(base64); + + var original = new AudioContentBlock { Data = base64Utf8, MimeType = "audio/wav" }; + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = Assert.IsType( + JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + + Assert.Equal(base64Utf8, deserialized.Data.ToArray()); + Assert.Equal(originalBytes, deserialized.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(Base64TestData))] + public void AudioContentBlock_FromBytes_JsonRoundtrip_PreservesData(byte[] originalBytes) + { + string expectedBase64 = Convert.ToBase64String(originalBytes); + + var original = AudioContentBlock.FromBytes(originalBytes, "audio/mp3"); + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = Assert.IsType( + JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + + Assert.Equal(expectedBase64, Encoding.UTF8.GetString(deserialized.Data.ToArray())); + Assert.Equal(originalBytes, deserialized.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(Base64TestData))] + public void AudioContentBlock_EscapedJsonRoundtrip_PreservesData(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + + string json = $$"""{"type":"audio","data":"{{base64.Replace("/", "\\/")}}","mimeType":"audio/wav"}"""; + + var deserialized = Assert.IsType( + JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + + Assert.Equal(base64, Encoding.UTF8.GetString(deserialized.Data.ToArray())); + Assert.Equal(originalBytes, deserialized.DecodedData.ToArray()); + } + + [Fact] + public void AudioContentBlock_DataSetterInvalidatesCachedDecodedData() + { + byte[] bytes1 = [1, 2, 3]; + var audio = AudioContentBlock.FromBytes(bytes1, "audio/wav"); + + Assert.Equal(bytes1, audio.DecodedData.ToArray()); + + byte[] newBytes = [4, 5, 6]; + string newBase64 = Convert.ToBase64String(newBytes); + audio.Data = Encoding.UTF8.GetBytes(newBase64); + + Assert.Equal(newBytes, audio.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(Base64TestData))] + public void ImageContentBlock_FromBytes_LazilyEncodesData(byte[] originalBytes) + { + // FromBytes should only decode when Data is accessed + var image = ImageContentBlock.FromBytes(originalBytes, "image/png"); + + // First, access DecodedData without touching Data + Assert.Equal(originalBytes, image.DecodedData.ToArray()); + + // Now access Data and verify it lazily encoded correctly + string expectedBase64 = Convert.ToBase64String(originalBytes); + Assert.Equal(expectedBase64, Encoding.UTF8.GetString(image.Data.ToArray())); + } + + [Theory] + [MemberData(nameof(Base64TestData))] + public void AudioContentBlock_FromBytes_LazilyEncodesData(byte[] originalBytes) + { + var audio = AudioContentBlock.FromBytes(originalBytes, "audio/wav"); + + Assert.Equal(originalBytes, audio.DecodedData.ToArray()); + + string expectedBase64 = Convert.ToBase64String(originalBytes); + Assert.Equal(expectedBase64, Encoding.UTF8.GetString(audio.Data.ToArray())); + } } \ No newline at end of file diff --git a/tests/ModelContextProtocol.Tests/Protocol/ResourceContentsTests.cs b/tests/ModelContextProtocol.Tests/Protocol/ResourceContentsTests.cs index 0940caf08..073f349de 100644 --- a/tests/ModelContextProtocol.Tests/Protocol/ResourceContentsTests.cs +++ b/tests/ModelContextProtocol.Tests/Protocol/ResourceContentsTests.cs @@ -1,4 +1,5 @@ using ModelContextProtocol.Protocol; +using System.Text; using System.Text.Json; namespace ModelContextProtocol.Tests.Protocol; @@ -489,4 +490,129 @@ public static void BlobResourceContents_NullMimeType_OmittedFromJson() Assert.DoesNotContain("mimeType", json); } + + [Fact] + public static void BlobResourceContents_Deserialization_HandlesEscapedForwardSlashInBase64() + { + // Base64 uses '/' which some JSON encoders escape as '\/' (valid JSON). + // The converter must unescape before storing the base64 UTF-8 bytes. + byte[] originalBytes = [0xFF, 0xD8, 0xFF, 0xE0]; // sample bytes that produce '/' in base64 + string base64 = Convert.ToBase64String(originalBytes); // "/9j/4A==" + Assert.Contains("/", base64); + + // Simulate a JSON encoder that escapes '/' as '\/' + string json = $$"""{"uri":"file:///test.bin","blob":"{{base64.Replace("/", "\\/")}}","mimeType":"application/octet-stream"}"""; + + var deserialized = JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions); + var blob = Assert.IsType(deserialized); + Assert.Equal(base64, System.Text.Encoding.UTF8.GetString(blob.Blob.ToArray())); + Assert.Equal(originalBytes, blob.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public static void BlobResourceContents_FromBytes_RoundtripsCorrectly(byte[] originalBytes) + { + string expectedBase64 = Convert.ToBase64String(originalBytes); + + var blob = BlobResourceContents.FromBytes(originalBytes, "file:///test.bin", "application/octet-stream"); + + Assert.Equal("file:///test.bin", blob.Uri); + Assert.Equal("application/octet-stream", blob.MimeType); + Assert.Equal(originalBytes, blob.DecodedData.ToArray()); + Assert.Equal(expectedBase64, Encoding.UTF8.GetString(blob.Blob.ToArray())); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public static void BlobResourceContents_BlobSetter_RoundtripsCorrectly(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + byte[] base64Utf8 = Encoding.UTF8.GetBytes(base64); + + var blob = new BlobResourceContents { Blob = base64Utf8, Uri = "file:///test.bin" }; + + Assert.Equal(base64Utf8, blob.Blob.ToArray()); + Assert.Equal(originalBytes, blob.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public static void BlobResourceContents_JsonRoundtrip_PreservesData(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + byte[] base64Utf8 = Encoding.UTF8.GetBytes(base64); + + var original = new BlobResourceContents + { + Blob = base64Utf8, + Uri = "file:///test.bin", + MimeType = "application/octet-stream" + }; + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = Assert.IsType( + JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + + Assert.Equal(base64Utf8, deserialized.Blob.ToArray()); + Assert.Equal(originalBytes, deserialized.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public static void BlobResourceContents_FromBytes_JsonRoundtrip_PreservesData(byte[] originalBytes) + { + string expectedBase64 = Convert.ToBase64String(originalBytes); + + var original = BlobResourceContents.FromBytes(originalBytes, "file:///test.bin", "application/octet-stream"); + string json = JsonSerializer.Serialize(original, McpJsonUtilities.DefaultOptions); + var deserialized = Assert.IsType( + JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + + Assert.Equal(expectedBase64, Encoding.UTF8.GetString(deserialized.Blob.ToArray())); + Assert.Equal(originalBytes, deserialized.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public static void BlobResourceContents_EscapedJsonRoundtrip_PreservesData(byte[] originalBytes) + { + string base64 = Convert.ToBase64String(originalBytes); + + string json = $$"""{"uri":"file:///test.bin","blob":"{{base64.Replace("/", "\\/")}}","mimeType":"application/octet-stream"}"""; + + var deserialized = Assert.IsType( + JsonSerializer.Deserialize(json, McpJsonUtilities.DefaultOptions)); + + Assert.Equal(base64, Encoding.UTF8.GetString(deserialized.Blob.ToArray())); + Assert.Equal(originalBytes, deserialized.DecodedData.ToArray()); + } + + [Fact] + public static void BlobResourceContents_BlobSetterInvalidatesCachedDecodedData() + { + byte[] bytes1 = [1, 2, 3]; + var blob = BlobResourceContents.FromBytes(bytes1, "file:///test.bin"); + + Assert.Equal(bytes1, blob.DecodedData.ToArray()); + + byte[] newBytes = [4, 5, 6]; + string newBase64 = Convert.ToBase64String(newBytes); + blob.Blob = Encoding.UTF8.GetBytes(newBase64); + + Assert.Equal(newBytes, blob.DecodedData.ToArray()); + } + + [Theory] + [MemberData(nameof(ContentBlockTests.Base64TestData), MemberType = typeof(ContentBlockTests))] + public static void BlobResourceContents_FromBytes_LazilyEncodesBlob(byte[] originalBytes) + { + var blob = BlobResourceContents.FromBytes(originalBytes, "file:///test.bin"); + + // Access DecodedData first without touching Blob + Assert.Equal(originalBytes, blob.DecodedData.ToArray()); + + // Now access Blob and verify lazy encoding + string expectedBase64 = Convert.ToBase64String(originalBytes); + Assert.Equal(expectedBase64, Encoding.UTF8.GetString(blob.Blob.ToArray())); + } }