Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion wayflowcore/src/wayflowcore/_utils/async_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ async def run_async_function_in_parallel(
passed inputs, with a given max number of workers
"""
max_workers_semaphore: AsyncContextManager[Any] = (
anyio.Semaphore(initial_value=max_workers) # type: ignore
anyio.Semaphore(initial_value=max_workers)
if max_workers is not None
else contextlib.nullcontext()
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from wayflowcore._metadata import MetadataType
from wayflowcore._utils.async_helpers import run_async_in_sync
from wayflowcore.embeddingmodels.embeddingmodel import EmbeddingModel
from wayflowcore.models._requesthelpers import _RetryStrategy, request_post_with_retries
from wayflowcore.models._requesthelpers import request_post_with_retries
from wayflowcore.retrypolicy import RetryPolicy
from wayflowcore.serialization.context import DeserializationContext, SerializationContext
from wayflowcore.serialization.serializer import SerializableObject

Expand All @@ -36,6 +37,7 @@ def __init__(
id: Optional[str] = None,
name: Optional[str] = None,
description: Optional[str] = None,
retry_policy: Optional[RetryPolicy] = None,
):
super().__init__(
__metadata_info__=__metadata_info__,
Expand All @@ -45,7 +47,7 @@ def __init__(
)
self._model_id = model_id
self._base_url = _add_leading_http_if_needed(base_url).rstrip("/")
self._retry_strategy = _RetryStrategy()
self.retry_policy = retry_policy

def _get_headers(self) -> Dict[str, str]:
"""
Expand Down Expand Up @@ -73,7 +75,7 @@ async def embed_async(self, data: List[str]) -> List[List[float]]:
headers = self._get_headers()
response_data = await request_post_with_retries(
request_params=dict(url=url, headers=headers, json=payload),
retry_strategy=self._retry_strategy,
retry_policy=self.retry_policy,
)

return [item["embedding"] for item in response_data["data"]]
Expand Down
6 changes: 6 additions & 0 deletions wayflowcore/src/wayflowcore/mcp/clienttransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from mcp.client.streamable_http import streamablehttp_client
from typing_extensions import TypeAlias

from wayflowcore.retrypolicy import RetryPolicy
from wayflowcore.serialization.serializer import (
SerializableDataclass,
SerializableDataclassMixin,
Expand Down Expand Up @@ -167,6 +168,9 @@ class RemoteBaseTransport(SerializableDataclass, ClientTransport, ABC):
timeout: float = 5
"""The timeout for the HTTP request. Defaults to 5 seconds."""

retry_policy: Optional[RetryPolicy] = None
"""Optional retry policy configuration applied to MCP HTTP calls."""

sse_read_timeout: float = 60 * 5
"""The timeout for the SSE connection, in seconds. Defaults to 5 minutes."""

Expand All @@ -182,6 +186,8 @@ class RemoteBaseTransport(SerializableDataclass, ClientTransport, ABC):
"""Arguments for the MCP session."""

def __post_init__(self) -> None:
if self.retry_policy is not None and not isinstance(self.retry_policy, RetryPolicy):
raise TypeError("retry_policy must be a wayflowcore.retrypolicy.RetryPolicy instance")
repeated_headers = set(self.headers or {}).intersection(set(self.sensitive_headers or {}))
if repeated_headers:
raise ValueError(
Expand Down
Loading