Skip to content

Commit bbb8eff

Browse files
committed
feat: add msgpack support for WebSocket communication
- Implement format detection (`json`/`msgpack`) in WebSocket transport. - Add `use_binary_protocol` option to enable `msgpack` encoding/decoding. - Ensure compatibility with protocol parameters and set appropriate formats during connection. - Add tests to verify `msgpack` and `json` behavior based on `use_binary_protocol` setting.
1 parent c9be051 commit bbb8eff

File tree

5 files changed

+118
-5
lines changed

5 files changed

+118
-5
lines changed

ably/realtime/connectionmanager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ async def __get_transport_params(self) -> dict:
154154
params["v"] = protocol_version
155155
if self.connection_details:
156156
params["resume"] = self.connection_details.connection_key
157+
# RTN2a: Set format to msgpack if use_binary_protocol is enabled
158+
if self.options.use_binary_protocol:
159+
params["format"] = "msgpack"
157160
return params
158161

159162
async def close_impl(self) -> None:

ably/realtime/realtime_channel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,8 @@ def _on_message(self, proto_msg: dict) -> None:
558558
elif action == ProtocolMessageAction.MESSAGE:
559559
messages = []
560560
try:
561-
messages = Message.from_encoded_array(proto_msg.get('messages'), context=self.__decoding_context)
561+
messages = Message.from_encoded_array(proto_msg.get('messages'),
562+
cipher=self.cipher, context=self.__decoding_context)
562563
self.__decoding_context.last_message_id = messages[-1].id
563564
self.__channel_serial = channel_serial
564565
except AblyException as e:

ably/transport/websockettransport.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from enum import IntEnum
99
from typing import TYPE_CHECKING
1010

11+
import msgpack
12+
1113
from ably.http.httputils import HttpUtils
1214
from ably.types.connectiondetails import ConnectionDetails
1315
from ably.util.eventemitter import EventEmitter
@@ -71,6 +73,7 @@ def __init__(self, connection_manager: ConnectionManager, host: str, params: dic
7173
self.is_disposed = False
7274
self.host = host
7375
self.params = params
76+
self.format = params.get('format', 'json')
7477
super().__init__()
7578

7679
def connect(self):
@@ -189,12 +192,18 @@ async def ws_read_loop(self):
189192
raise AblyException('ws_read_loop started with no websocket', 500, 50000)
190193
try:
191194
async for raw in self.websocket:
192-
msg = json.loads(raw)
195+
# Decode based on format
196+
msg = self.decode_raw_websocket_frame(raw)
193197
task = asyncio.create_task(self.on_protocol_message(msg))
194198
task.add_done_callback(self.on_protcol_message_handled)
195199
except ConnectionClosedOK:
196200
return
197201

202+
def decode_raw_websocket_frame(self, raw: str | bytes) -> dict:
203+
if self.format == 'msgpack':
204+
return msgpack.unpackb(raw)
205+
return json.loads(raw)
206+
198207
def on_protcol_message_handled(self, task):
199208
try:
200209
exception = task.exception()
@@ -231,8 +240,13 @@ async def close(self):
231240
async def send(self, message: dict):
232241
if self.websocket is None:
233242
raise Exception()
234-
raw_msg = json.dumps(message)
235-
log.info(f'WebSocketTransport.send(): sending {raw_msg}')
243+
# Encode based on format
244+
if self.format == 'msgpack':
245+
raw_msg = msgpack.packb(message)
246+
log.info(f'WebSocketTransport.send(): sending msgpack message (length: {len(raw_msg)} bytes)')
247+
else:
248+
raw_msg = json.dumps(message)
249+
log.info(f'WebSocketTransport.send(): sending {raw_msg}')
236250
await self.websocket.send(raw_msg)
237251

238252
def set_idle_timer(self, timeout: float):

test/ably/realtime/realtimechannel_publish_test.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
import pytest
44

55
from ably.realtime.connection import ConnectionState
6-
from ably.realtime.realtime_channel import ChannelState
6+
from ably.realtime.realtime_channel import ChannelOptions, ChannelState
77
from ably.transport.websockettransport import ProtocolMessageAction
88
from ably.types.message import Message
9+
from ably.util.crypto import CipherParams
910
from ably.util.exceptions import AblyException, IncompatibleClientIdException
1011
from test.ably.testapp import TestApp
1112
from test.ably.utils import BaseAsyncTestCase, WaitableEvent, assert_waiter
@@ -975,3 +976,37 @@ def on_message(message):
975976
assert data_received[1] == 'third message'
976977

977978
await ably.close()
979+
980+
async def test_publish_with_encryption(self):
981+
"""Verify that encrypted messages can be published and received correctly"""
982+
# Create connection with binary protocol enabled
983+
ably = await TestApp.get_ably_realtime(use_binary_protocol=True)
984+
await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5)
985+
986+
# Get channel with encryption enabled
987+
cipher_params = CipherParams(secret_key=b'0123456789abcdef0123456789abcdef')
988+
channel_options = ChannelOptions(cipher=cipher_params)
989+
channel = ably.channels.get('encrypted_channel', channel_options)
990+
await channel.attach()
991+
992+
received_data = None
993+
data_received = WaitableEvent()
994+
def on_message(message):
995+
nonlocal received_data
996+
try:
997+
# message.decode()
998+
received_data = message.data
999+
data_received.finish()
1000+
except Exception as e:
1001+
data_received.finish()
1002+
raise e
1003+
1004+
await channel.subscribe(on_message)
1005+
1006+
await channel.publish('encrypted_event', 'sensitive data')
1007+
1008+
await data_received.wait()
1009+
1010+
assert received_data == 'sensitive data'
1011+
1012+
await ably.close()

test/ably/realtime/realtimeconnection_test.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,3 +400,63 @@ async def on_protocol_message(msg):
400400
await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5)
401401

402402
await ably.close()
403+
404+
# RTN2f - Test msgpack format parameter when use_binary_protocol is enabled
405+
async def test_connection_format_msgpack_with_binary_protocol(self):
406+
"""Test that format=msgpack is sent when use_binary_protocol=True"""
407+
ably = await TestApp.get_ably_realtime(use_binary_protocol=True)
408+
await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5)
409+
410+
received_raw_websocket_frames = []
411+
transport = ably.connection.connection_manager.transport
412+
original_decode_raw_websocket_frame = transport.decode_raw_websocket_frame
413+
414+
def intercepted_websocket_frame(data):
415+
received_raw_websocket_frames.append(data)
416+
return original_decode_raw_websocket_frame(data)
417+
418+
transport.decode_raw_websocket_frame = intercepted_websocket_frame
419+
420+
# Verify transport has format set to msgpack
421+
assert ably.connection.connection_manager.transport is not None
422+
assert ably.connection.connection_manager.transport.format == 'msgpack'
423+
424+
# Verify params include format=msgpack
425+
assert ably.connection.connection_manager.transport.params.get('format') == 'msgpack'
426+
427+
await ably.channels.get('connection_test').publish('test', b'test')
428+
429+
assert len(received_raw_websocket_frames) > 0
430+
assert all(isinstance(frame, bytes) for frame in received_raw_websocket_frames)
431+
432+
await ably.close()
433+
434+
async def test_connection_format_json_without_binary_protocol(self):
435+
"""Test that format defaults to json when use_binary_protocol=False"""
436+
ably = await TestApp.get_ably_realtime(use_binary_protocol=False)
437+
await asyncio.wait_for(ably.connection.once_async(ConnectionState.CONNECTED), timeout=5)
438+
439+
received_raw_websocket_frames = []
440+
transport = ably.connection.connection_manager.transport
441+
original_decode_raw_websocket_frame = transport.decode_raw_websocket_frame
442+
443+
def intercepted_websocket_frame(data):
444+
received_raw_websocket_frames.append(data)
445+
return original_decode_raw_websocket_frame(data)
446+
447+
transport.decode_raw_websocket_frame = intercepted_websocket_frame
448+
449+
# Verify transport has format set to json (default)
450+
assert ably.connection.connection_manager.transport is not None
451+
assert ably.connection.connection_manager.transport.format == 'json'
452+
453+
await ably.channels.get('connection_test').publish('test', b'test')
454+
455+
# Verify params don't include format parameter (or it's json)
456+
transport_format = ably.connection.connection_manager.transport.params.get('format')
457+
assert transport_format is None or transport_format == 'json'
458+
459+
assert len(received_raw_websocket_frames) > 0
460+
assert all(isinstance(frame, str) for frame in received_raw_websocket_frames)
461+
462+
await ably.close()

0 commit comments

Comments
 (0)