diff --git a/src/aleph/sdk/client/abstract.py b/src/aleph/sdk/client/abstract.py index 3daa198e..59b72554 100644 --- a/src/aleph/sdk/client/abstract.py +++ b/src/aleph/sdk/client/abstract.py @@ -42,7 +42,13 @@ from aleph.sdk.utils import extended_json_encoder from ..query.filters import MessageFilter, PostFilter -from ..query.responses import MessagesResponse, PostsResponse, PriceResponse +from ..query.responses import ( + CursorMessagesResponse, + CursorPostsResponse, + MessagesResponse, + PostsResponse, + PriceResponse, +) from ..types import GenericMessage, StorageEnum from ..utils import Writable, compute_sha256 @@ -120,26 +126,47 @@ async def get_posts( """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + @abstractmethod + async def get_posts_cursor( + self, + page_size: int = DEFAULT_PAGE_SIZE, + cursor: str = "", + post_filter: Optional[PostFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + ) -> CursorPostsResponse: + """ + Fetch a list of posts from the network using cursor-based pagination. + + :param page_size: Number of items to fetch, max 200 (Default: 200) + :param cursor: Opaque cursor from a previous response's next_cursor. Empty string starts from the beginning. + :param post_filter: Filter to apply to the posts (Default: None) + :param ignore_invalid_messages: Ignore invalid messages (Default: True) + :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) + """ + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + async def get_posts_iterator( self, post_filter: Optional[PostFilter] = None, ) -> AsyncIterable[PostMessage]: """ - Fetch all filtered posts, returning an async iterator and fetching them page by page. Might return duplicates - but will always return all posts. + Fetch all filtered posts, returning an async iterator and fetching them + using cursor-based pagination. Does not return duplicates. :param post_filter: Filter to apply to the posts (Default: None) """ - page = 1 - resp = None - while resp is None or len(resp.posts) > 0: - resp = await self.get_posts( - page=page, + cursor: str = "" + while True: + resp = await self.get_posts_cursor( + cursor=cursor, post_filter=post_filter, ) - page += 1 for post in resp.posts: yield post # type: ignore + if resp.next_cursor is None: + break + cursor = resp.next_cursor @abstractmethod async def download_file(self, file_hash: str) -> bytes: @@ -224,26 +251,47 @@ async def get_messages( """ raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + @abstractmethod + async def get_messages_cursor( + self, + page_size: int = DEFAULT_PAGE_SIZE, + cursor: str = "", + message_filter: Optional[MessageFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + ) -> CursorMessagesResponse: + """ + Fetch a list of messages from the network using cursor-based pagination. + + :param page_size: Number of items to fetch, max 200 (Default: 200) + :param cursor: Opaque cursor from a previous response's next_cursor. Empty string starts from the beginning. + :param message_filter: Filter to apply to the messages + :param ignore_invalid_messages: Ignore invalid messages (Default: True) + :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) + """ + raise NotImplementedError("Did you mean to import `AlephHttpClient`?") + async def get_messages_iterator( self, message_filter: Optional[MessageFilter] = None, ) -> AsyncIterable[AlephMessage]: """ - Fetch all filtered messages, returning an async iterator and fetching them page by page. Might return duplicates - but will always return all messages. + Fetch all filtered messages, returning an async iterator and fetching + them using cursor-based pagination. Does not return duplicates. :param message_filter: Filter to apply to the messages """ - page = 1 - resp = None - while resp is None or len(resp.messages) > 0: - resp = await self.get_messages( - page=page, + cursor: str = "" + while True: + resp = await self.get_messages_cursor( + cursor=cursor, message_filter=message_filter, ) - page += 1 for message in resp.messages: yield message + if resp.next_cursor is None: + break + cursor = resp.next_cursor @abstractmethod async def get_message( diff --git a/src/aleph/sdk/client/http.py b/src/aleph/sdk/client/http.py index 6fa6d4ac..c58b5571 100644 --- a/src/aleph/sdk/client/http.py +++ b/src/aleph/sdk/client/http.py @@ -52,10 +52,12 @@ RemovedMessageError, ResourceNotFoundError, ) -from ..query.filters import BalanceFilter, MessageFilter, PostFilter +from ..query.filters import BalanceFilter, MessageFilter, PostFilter, SortBy from ..query.responses import ( BalanceResponse, CreditsHistoryResponse, + CursorMessagesResponse, + CursorPostsResponse, MessagesResponse, Post, PostsResponse, @@ -260,6 +262,56 @@ async def get_posts( pagination_item=response_json["pagination_item"], ) + async def get_posts_cursor( + self, + page_size: int = 200, + cursor: str = "", + post_filter: Optional[PostFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + ) -> CursorPostsResponse: + ignore_invalid_messages = ( + True if ignore_invalid_messages is None else ignore_invalid_messages + ) + invalid_messages_log_level = ( + logging.NOTSET + if invalid_messages_log_level is None + else invalid_messages_log_level + ) + + if post_filter and post_filter.sort_by == SortBy.TX_TIME: + raise ValueError( + "sortBy=tx-time is not compatible with cursor-based pagination" + ) + + page_size = min(page_size, 200) + + params: Dict[str, str] = {} + if post_filter: + params = post_filter.as_http_params() + params["cursor"] = cursor + params["pagination"] = str(page_size) + + async with self.http_session.get("/api/v0/posts.json", params=params) as resp: + resp.raise_for_status() + response_json = await resp.json() + posts_raw = response_json["posts"] + + posts: List[Post] = [] + for post_raw in posts_raw: + try: + posts.append(Post.model_validate(post_raw)) + except ValidationError as e: + if not ignore_invalid_messages: + raise e + if invalid_messages_log_level: + logger.log(level=invalid_messages_log_level, msg=e) + return CursorPostsResponse( + posts=posts, + pagination_per_page=response_json["pagination_per_page"], + next_cursor=response_json.get("next_cursor"), + ) + async def download_file_to_buffer( self, file_hash: str, @@ -425,6 +477,67 @@ async def get_messages( pagination_item=response_json["pagination_item"], ) + async def get_messages_cursor( + self, + page_size: int = 200, + cursor: str = "", + message_filter: Optional[MessageFilter] = None, + ignore_invalid_messages: Optional[bool] = True, + invalid_messages_log_level: Optional[int] = logging.NOTSET, + ) -> CursorMessagesResponse: + ignore_invalid_messages = ( + True if ignore_invalid_messages is None else ignore_invalid_messages + ) + invalid_messages_log_level = ( + logging.NOTSET + if invalid_messages_log_level is None + else invalid_messages_log_level + ) + + if message_filter and message_filter.sort_by == SortBy.TX_TIME: + raise ValueError( + "sortBy=tx-time is not compatible with cursor-based pagination" + ) + + page_size = min(page_size, 200) + + params: Dict[str, str] = {} + if message_filter: + params = message_filter.as_http_params() + params["cursor"] = cursor + params["pagination"] = str(page_size) + + async with self.http_session.get( + "/api/v0/messages.json", params=params + ) as resp: + resp.raise_for_status() + response_json = await resp.json() + messages_raw = response_json["messages"] + + messages: List[AlephMessage] = [] + for message_raw in messages_raw: + try: + message = parse_message(message_raw) + messages.append(message) + except KeyError as e: + if not ignore_invalid_messages: + raise e + logger.log( + level=invalid_messages_log_level, + msg=f"KeyError: Field '{e.args[0]}' not found", + ) + except ValidationError as e: + if not ignore_invalid_messages: + raise e + if invalid_messages_log_level: + logger.log(level=invalid_messages_log_level, msg=e) + + return CursorMessagesResponse( + messages=messages, + pagination_per_page=response_json["pagination_per_page"], + next_cursor=response_json.get("next_cursor"), + ) + @overload async def get_message( # type: ignore self, diff --git a/src/aleph/sdk/query/responses.py b/src/aleph/sdk/query/responses.py index 6efade14..4b901050 100644 --- a/src/aleph/sdk/query/responses.py +++ b/src/aleph/sdk/query/responses.py @@ -76,6 +76,23 @@ class MessagesResponse(PaginationResponse): pagination_item: str = "messages" +class CursorPaginationResponse(BaseModel): + pagination_per_page: int + next_cursor: Optional[str] = None + + +class CursorPostsResponse(CursorPaginationResponse): + """Cursor-paginated response from /api/v0/posts.json""" + + posts: List[Post] + + +class CursorMessagesResponse(CursorPaginationResponse): + """Cursor-paginated response from /api/v0/messages.json""" + + messages: List[AlephMessage] + + class PriceResponse(BaseModel): """Response from an aleph.im node API on the path /api/v0/price/{item_hash}""" diff --git a/tests/unit/services/test_authorizations.py b/tests/unit/services/test_authorizations.py index 7ab2b7ee..04f52eb0 100644 --- a/tests/unit/services/test_authorizations.py +++ b/tests/unit/services/test_authorizations.py @@ -143,6 +143,12 @@ async def download_file_to_path(self, *args, **kwargs): async def get_messages(self, *args, **kwargs): raise NotImplementedError + async def get_posts_cursor(self, *args, **kwargs): + raise NotImplementedError + + async def get_messages_cursor(self, *args, **kwargs): + raise NotImplementedError + async def get_message(self, *args, **kwargs): raise NotImplementedError