diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c68e7b5..8a208f3 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -20,30 +20,34 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12' ] + python-version: [ '3.9', '3.10', '3.11', '3.12', '3.13' ] steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - - uses: actions/setup-python@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Upgrade pip run: python3 -m pip install --upgrade pip + - name: Ensure pip >= v25.1 + run: python -m pip install "pip >= 25.1" + - name: Install ypywidgets in dev mode - run: pip install -e ".[dev]" + run: pip install --group dev -e . - name: Check types run: mypy src - name: Run tests + if: ${{ !((matrix.python-version == '3.13') && (matrix.os == 'ubuntu-latest')) }} run: pytest ./tests -v --color=yes - name: Run code coverage - if: ${{ (matrix.python-version == '3.12') && (matrix.os == 'ubuntu-latest') }} + if: ${{ (matrix.python-version == '3.13') && (matrix.os == 'ubuntu-latest') }} run: | coverage run -m pytest tests coverage report --fail-under=100 diff --git a/pyproject.toml b/pyproject.toml index 428c59e..562192d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ version = "0.9.7" description = "Y-based Jupyter widgets for Python" readme = "README.md" license = "MIT" -requires-python = ">=3.8" +requires-python = ">=3.9" authors = [ { name = "David Brochart", email = "david.brochart@gmail.com" }, ] @@ -24,9 +24,9 @@ dependencies = [ ] [project.urls] -Homepage = "https://github.com/davidbrochart/ypywidgets" +Homepage = "https://github.com/QuantStack/ypywidgets" -[project.optional-dependencies] +[dependency-groups] dev = [ "coverage >=7.0.0,<8.0.0", "mypy", diff --git a/src/ypywidgets/comm.py b/src/ypywidgets/comm.py index 89a5a89..757ae8f 100644 --- a/src/ypywidgets/comm.py +++ b/src/ypywidgets/comm.py @@ -1,5 +1,8 @@ from __future__ import annotations +from collections.abc import Callable +from typing import Any + import comm from pycrdt import ( Doc, @@ -41,6 +44,8 @@ def create_widget_comm( class CommProvider: + _on_receive: Callable[[bytes], None] | None + def __init__( self, ydoc: Doc, @@ -48,33 +53,45 @@ def __init__( ) -> None: self._ydoc = ydoc self._comm = comm + self._on_receive = None msg = create_sync_message(ydoc) self._comm.send(buffers=[msg]) self._comm.on_msg(self._receive) - def _receive(self, msg): + def _receive(self, msg: dict[str, Any]): message = bytes(msg["buffers"][0]) - if message[0] == YMessageType.SYNC: - reply = handle_sync_message(message[1:], self._ydoc) + message_type = message[0] + message_content = message[1:] + if message_type == YMessageType.SYNC: + reply = handle_sync_message(message_content, self._ydoc) if reply is not None: self._comm.send(buffers=[reply]) if message[1] == YSyncMessageType.SYNC_STEP2: self._ydoc.observe(self._send) + elif message_type == 2: + if self._on_receive is not None: + self._on_receive(message_content) def _send(self, event: TransactionEvent): update = event.update message = create_update_message(update) self._comm.send(buffers=[message]) + def on_receive(self, callback: Callable[[bytes], None]): + self._on_receive = callback + + def send(self, message: bytes): + self._comm.send(buffers=[bytes([2]) + message]) + class CommWidget(Widget): def __init__( - self, - ydoc: Doc | None = None, - comm_data: dict | None = None, - comm_metadata: dict | None = None, - comm_id: str | None = None, - ): + self, + ydoc: Doc | None = None, + comm_data: dict | None = None, + comm_metadata: dict | None = None, + comm_id: str | None = None, + ): super().__init__(ydoc) model_name = self.__class__.__name__ _model_name = self.ydoc["_model_name"] = Text() @@ -85,7 +102,7 @@ def __init__( create_ydoc=not ydoc, ) self._comm = create_widget_comm(comm_data, comm_metadata, comm_id) - CommProvider(self.ydoc, self._comm) + self._provider = CommProvider(self.ydoc, self._comm) def _repr_mimebundle_(self, *args, **kwargs): # pragma: nocover plaintext = repr(self) @@ -100,3 +117,9 @@ def _repr_mimebundle_(self, *args, **kwargs): # pragma: nocover } } return data + + def on_receive(self, callback: Callable[[bytes], None]) -> None: + self._provider.on_receive(callback) + + def send(self, message: bytes) -> None: + self._provider.send(message) diff --git a/tests/conftest.py b/tests/conftest.py index b7a0ea5..df54a65 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,18 +1,12 @@ +from __future__ import annotations + import asyncio import time -from typing import Optional import comm import pytest -from pycrdt import ( - YMessageType, - YSyncMessageType, - TransactionEvent, - create_sync_message, - create_update_message, - handle_sync_message, -) -from ypywidgets import Widget +import pytest_asyncio +from pycrdt import create_sync_message from ypywidgets.comm import CommWidget @@ -47,10 +41,10 @@ async def receive(self): @pytest.fixture def widget_factories(): - return CommWidget, Widget + return CommWidget, CommWidget -@pytest.fixture +@pytest_asyncio.fixture async def synced_widgets(widget_factories): local_widget = widget_factories[0]() remote_widget_manager = RemoteWidgetManager(widget_factories[1], local_widget._comm) @@ -60,35 +54,31 @@ async def synced_widgets(widget_factories): class RemoteWidgetManager: - comm: Optional[MockComm] - widget: Optional[Widget] + comm: MockComm + widget: CommWidget | None - def __init__(self, widget_factory, comm): + def __init__(self, widget_factory, local_comm): self.widget_factory = widget_factory - self.comm = comm + self.local_comm = local_comm self.widget = None self.receive_task = asyncio.create_task(self.receive()) - def send(self, event: TransactionEvent): - update = event.update - message = create_update_message(update) - self.comm.recv_queue.put_nowait({"buffers": [message]}) + async def send(self): + while True: + msg_type, data, metadata, buffers, target_name, target_module = await self.widget._comm.send_queue.get() + if msg_type == "comm_msg": + self.local_comm.recv_queue.put_nowait({"buffers": buffers}) async def receive(self): while True: - msg_type, data, metadata, buffers, target_name, target_module = await self.comm.send_queue.get() + msg_type, data, metadata, buffers, target_name, target_module = await self.local_comm.send_queue.get() if msg_type == "comm_open": self.widget = self.widget_factory() msg = create_sync_message(self.widget.ydoc) - self.comm.handle_msg({"buffers": [msg]}) + self.local_comm.recv_queue.put_nowait({"buffers": [msg]}) + self.send_task = asyncio.create_task(self.send()) elif msg_type == "comm_msg": - message = buffers[0] - if message[0] == YMessageType.SYNC: - reply = handle_sync_message(message[1:], self.widget.ydoc) - if reply is not None: - self.comm.handle_msg({"buffers": [reply]}) - if message[1] == YSyncMessageType.SYNC_STEP2: - self.widget.ydoc.observe(self.send) + self.widget._comm.recv_queue.put_nowait({"buffers": buffers}) async def get_widget(self, timeout=0.1): t = time.monotonic() diff --git a/tests/test_attributes.py b/tests/test_attributes.py index 4bcfe85..e20ca5c 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -23,7 +23,7 @@ def _watch_foo(self, old, new): @pytest.mark.asyncio async def test_create_ydoc(synced_widgets): - local_widget, remote_widget = await synced_widgets + local_widget, remote_widget = synced_widgets local_text = Text() local_widget.ydoc["text"] = local_text @@ -39,7 +39,7 @@ async def test_create_ydoc(synced_widgets): @pytest.mark.asyncio @pytest.mark.parametrize("widget_factories", ((Widget1, Widget1),)) async def test_sync_attribute(widget_factories, synced_widgets): - local_widget, remote_widget = await synced_widgets + local_widget, remote_widget = synced_widgets with pytest.raises(AttributeError): assert local_widget.wrong_attr1 @@ -61,7 +61,7 @@ async def test_sync_attribute(widget_factories, synced_widgets): @pytest.mark.asyncio @pytest.mark.parametrize("widget_factories", ((Widget1, Widget2),)) async def test_watch_attribute(widget_factories, synced_widgets, capfd): - local_widget, remote_widget = await synced_widgets + local_widget, remote_widget = synced_widgets local_widget.foo = "foo" diff --git a/tests/test_messages.py b/tests/test_messages.py new file mode 100644 index 0000000..9c2ebdf --- /dev/null +++ b/tests/test_messages.py @@ -0,0 +1,30 @@ +import asyncio + +import pytest + + +@pytest.mark.asyncio +async def test_messages(synced_widgets): + local_messages = [] + remote_messages = [] + local_widget, remote_widget = synced_widgets + + def on_local_message(message): + local_messages.append(message) + + def on_remote_message(message): + remote_messages.append(message) + remote_widget.send(message + b", World!") + + local_widget.on_receive(on_local_message) + remote_widget.on_receive(on_remote_message) + + local_widget.send(b"Hello") + await asyncio.sleep(0.1) + assert remote_messages == [b"Hello"] + assert local_messages == [b"Hello, World!"] + + remote_widget.send(b"msg") + await asyncio.sleep(0.1) + assert remote_messages == [b"Hello"] + assert local_messages == [b"Hello, World!", b"msg"]