Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions examples/ai/chat/pydantic-ai-chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import marimo

__generated_with = "0.23.6"
__generated_with = "0.23.10"
app = marimo.App(width="medium")

with app.setup(hide_code=True):
Expand Down Expand Up @@ -272,19 +272,6 @@ def _pending_approval(messages) -> dict | None:
return None


async def custom_model(messages, config):
    """Stream response chunks for the custom chat model.

    When `_pending_approval(messages)` returns a pending entry, the chunks
    come from `_resume_after_approval`; otherwise they come from
    `_showcase_turn`. `config` is accepted for signature compatibility and
    discarded immediately.
    """
    del config

    approval = _pending_approval(messages)
    # Choose exactly one chunk source, then relay it unchanged.
    stream = (
        _showcase_turn()
        if approval is None
        else _resume_after_approval(approval)
    )
    async for chunk in stream:
        yield chunk


async def _showcase_turn():
reasoning_id = _new_id("reasoning")
search_id = _new_id("tc")
Expand Down Expand Up @@ -497,6 +484,19 @@ async def _resume_after_approval(pending: dict):
yield vercel.FinishChunk(finish_reason="stop")


async def custom_model(messages, config):
    """Stream response chunks for the custom chat model.

    When `_pending_approval(messages)` returns a pending entry, the chunks
    come from `_resume_after_approval`; otherwise they come from
    `_showcase_turn`. `config` is accepted for signature compatibility and
    discarded immediately.
    """
    del config

    approval = _pending_approval(messages)
    # Choose exactly one chunk source, then relay it unchanged.
    stream = (
        _showcase_turn()
        if approval is None
        else _resume_after_approval(approval)
    )
    async for chunk in stream:
        yield chunk


custom_chat = mo.ui.chat(
custom_model,
prompts=[
Expand Down
31 changes: 5 additions & 26 deletions marimo/_ai/llm/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,21 +796,6 @@ def _serialize_vercel_ai_chunk(
result,
)
return result # type: ignore[no-any-return]
except TypeError:
# Fallback for pydantic-ai < 1.52.0 which doesn't have sdk_version param
try:
# by_alias=True: Use camelCase keys expected by Vercel AI SDK.
# exclude_none=True: Remove null values which cause validation errors.
serialized = chunk.model_dump(
mode="json", by_alias=True, exclude_none=True
)
except Exception as e:
LOGGER.error("Error serializing vercel ai chunk: %s", e)
return None
else:
if serialized.get("type") == "done":
return None
return serialized
except Exception as e:
LOGGER.error("Error serializing vercel ai chunk: %s", e)
return None
Expand Down Expand Up @@ -838,17 +823,11 @@ async def _stream_response(
messages=ui_messages,
)

try:
adapter = VercelAIAdapter(
agent=self.agent,
run_input=run_input,
sdk_version=AI_SDK_VERSION,
)
except TypeError:
adapter = VercelAIAdapter(
agent=self.agent,
run_input=run_input,
)
adapter = VercelAIAdapter(
agent=self.agent,
run_input=run_input,
sdk_version=AI_SDK_VERSION,
)
event_stream = adapter.run_stream(model_settings=model_settings)
async for event in event_stream:
if serialized := self._serialize_vercel_ai_chunk(event):
Expand Down
10 changes: 0 additions & 10 deletions marimo/_plugins/ui/_impl/chat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,6 @@
DONE_CHUNK: Final[str] = "[DONE]"


def require_vercel_ai_sdk_support() -> None:
    """Require a Pydantic AI release that supports AI SDK v6.

    Only Pydantic AI >=1.52.0 supports AI SDK v6, so this delegates to
    `DependencyManager.pydantic_ai.require_at_version` with that minimum
    version. Any failure behavior is whatever that call implements.
    """
    pydantic_ai_dependency = DependencyManager.pydantic_ai
    pydantic_ai_dependency.require_at_version(
        min_version="1.52.0",
        why="for Vercel AI SDK support",
    )


@dataclass
class SendMessageRequest:
messages: list[ChatMessage]
Expand Down Expand Up @@ -416,7 +409,6 @@ def _emit_cancellation_chunks(
abort_payload: dict[str, Any] | None = None
if DependencyManager.pydantic_ai.imported():
try:
require_vercel_ai_sdk_support()
from pydantic_ai.ui.vercel_ai.response_types import (
AbortChunk,
)
Expand Down Expand Up @@ -580,7 +572,6 @@ def _convert_value(self, value: dict[str, Any]) -> list[ChatMessage]:

part_validator_class = None
if DependencyManager.pydantic_ai.imported():
require_vercel_ai_sdk_support()
from pydantic_ai.ui.vercel_ai.request_types import UIMessagePart

# The frontend sends messages as ChatMessage parts so we use pydantic-ai to cast them
Expand Down Expand Up @@ -645,7 +636,6 @@ def handle_chunk(self, chunk: Any) -> None:

# Handle Pydantic AI's Vercel AI SDK chunks
if DependencyManager.pydantic_ai.imported():
require_vercel_ai_sdk_support()
from pydantic_ai.ui.vercel_ai.response_types import (
BaseChunk,
)
Expand Down
39 changes: 5 additions & 34 deletions marimo/_server/ai/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from marimo._dependencies.dependencies import Dependency, DependencyManager
from marimo._plugins.ui._impl.chat.chat import (
AI_SDK_VERSION,
require_vercel_ai_sdk_support,
)
from marimo._server.ai.config import AnyProviderConfig
from marimo._server.ai.constants import ANTHROPIC_DEFAULT_MAX_TOKENS
Expand Down Expand Up @@ -107,7 +106,6 @@ def __init__(
*(deps or []),
source="server",
)
require_vercel_ai_sdk_support()

self.model: str = model
self.config: AnyProviderConfig = config
Expand Down Expand Up @@ -150,11 +148,13 @@ def _build_agent_settings(self, model: Model) -> ModelSettings | None:
thinking = self._default_thinking(model)
if thinking is None:
return None

if not (
model.profile.supports_thinking
or model.profile.thinking_always_enabled
model.profile.get("supports_thinking", False)
or model.profile.get("thinking_always_enabled", False)
):
return None

return ModelSettings(thinking=thinking)

def _default_thinking(self, model: Model) -> ThinkingLevel | None:
Expand Down Expand Up @@ -333,7 +333,7 @@ def create_provider(self, config: AnyProviderConfig) -> PydanticGoogle:
)
else:
# Try default initialization which may work with environment variables
provider = PydanticGoogle()
provider = PydanticGoogle() # type: ignore[call-overload]
return provider

@override
Expand Down Expand Up @@ -852,15 +852,6 @@ def _default_thinking(self, model: Model) -> ThinkingLevel | None:


class AnthropicProvider(PydanticProvider["PydanticAnthropic"]):
# Temperature of 0.2 was recommended for coding and data science in these links:
# https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api/172683
# https://docs.anthropic.com/en/docs/test-and-evaluate/strengthen-guardrails/reduce-latency?utm_source=chatgpt.com
DEFAULT_TEMPERATURE: float = 0.2

# Extended thinking requires temperature of 1.
# https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking
DEFAULT_EXTENDED_THINKING_TEMPERATURE: float = 1

@override
def create_provider(self, config: AnyProviderConfig) -> PydanticAnthropic:
from pydantic_ai.providers.anthropic import (
Expand All @@ -875,33 +866,13 @@ def create_model(self, max_tokens: int | None) -> Model:
AnthropicModel,
AnthropicModelSettings,
)
from pydantic_ai.profiles.anthropic import (
AnthropicModelProfile,
anthropic_model_profile,
)

settings: AnthropicModelSettings = {
"max_tokens": max_tokens
if max_tokens is not None
else ANTHROPIC_DEFAULT_MAX_TOKENS
}

# Anthropic extended thinking requires temperature=1; non-thinking
# models keep our default coding temperature. Some adaptive-only
# models (Opus 4.7+) reject sampling settings entirely — skip
# `temperature` for them so pydantic-ai doesn't drop it with a warning.
profile = AnthropicModelProfile.from_profile(
anthropic_model_profile(self.model)
)
if not getattr(
profile, "anthropic_disallows_sampling_settings", False
):
settings["temperature"] = (
self.DEFAULT_EXTENDED_THINKING_TEMPERATURE
if profile.supports_thinking
else self.DEFAULT_TEMPERATURE
)

return AnthropicModel(
model_name=self.model,
provider=self.provider,
Expand Down
4 changes: 2 additions & 2 deletions marimo/_server/ai/tools/code_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
def build_execute_code_toolset(
session: Session,
request: Request,
) -> FunctionToolset[None]:
) -> FunctionToolset:
"""Build a `FunctionToolset` exposing one tool: `execute_code`.

The tool is bound to the caller's *session* and *request*; the model
Expand All @@ -31,7 +31,7 @@ def build_execute_code_toolset(

from pydantic_ai import FunctionToolset

toolset: FunctionToolset[None] = FunctionToolset()
toolset: FunctionToolset = FunctionToolset()

async def execute_code(code: str) -> CodeExecutionResult:
"""Run Python inside the running notebook's kernel scratchpad.
Expand Down
4 changes: 1 addition & 3 deletions marimo/_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,7 @@ def _instrument_ai(provider: trace.TracerProvider) -> None:
from pydantic_ai import Agent
from pydantic_ai.models.instrumented import InstrumentationSettings

Agent.instrument_all(
InstrumentationSettings(tracer_provider=provider, version=5)
)
Agent.instrument_all(InstrumentationSettings(tracer_provider=provider))
LOGGER.debug("Enabled AI instrumentation")
except Exception as e:
LOGGER.debug("AI instrumentation failed: %s", e)
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ recommended = [
"marimo[sql]",
"marimo[sandbox]", # For `marimo edit --sandbox DIRECTORY`
"altair>=5.4.0", # Plotting in datasource viewer
"pydantic-ai-slim[openai]>=1.107.0,<2.0.0", # AI features
"pydantic-ai-slim[openai]>=1.107.0,<3.0.0", # AI features
"ruff", # Formatting
"nbformat>=5.7.0", # Export as IPYNB
]
Expand Down Expand Up @@ -142,7 +142,7 @@ dev = [
# For linting
"ruff>=0.15.16",
# For AI
"pydantic-ai-slim[openai]>=1.107.0,<2.0.0",
"pydantic-ai-slim[openai]>=1.107.0,<3.0.0",
]

test = [
Expand Down Expand Up @@ -205,7 +205,7 @@ test-optional = [
"anywidget~=0.9.21",
"ipython~=8.12.3",
# testing gen ai
"pydantic-ai-slim[google,anthropic,bedrock,openai]>=1.107.0,<2.0.0",
"pydantic-ai-slim[google,anthropic,bedrock,openai]>=1.107.0,<3.0.0",
# - google-auth uses cachetools, and cachetools<5.0.0 uses collections.MutableMapping (removed in Python 3.10)
"cachetools>=5.0.0",
"boto3>=1.38.46",
Expand Down Expand Up @@ -242,7 +242,7 @@ typecheck = [
"sqlalchemy>=2.0.40",
"obstore>=0.8.2",
"fsspec>=2026.2.0",
"pydantic-ai-slim[google,anthropic,bedrock,openai]>=1.107.0,<2.0.0",
"pydantic-ai-slim[google,anthropic,bedrock,openai]>=1.107.0,<3.0.0",
"loro>=1.5.0",
"boto3-stubs>=1.38.46",
"pandas-stubs>=1.5.3.230321",
Expand Down
13 changes: 5 additions & 8 deletions tests/_ai/llm/test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1673,8 +1673,8 @@ def test_build_ui_messages_preserves_tool_approval_field(self):
async def test_stream_response_emits_tool_approval_request(self):
"""Tools with `requires_approval=True` should surface an
approval-request chunk so the frontend can render an Approve/Deny
card. This is the v6-only behavior unlocked by passing
`sdk_version=AI_SDK_VERSION` to the adapter.
card. This behavior requires passing `sdk_version=AI_SDK_VERSION`
to the adapter.
"""
from pydantic_ai import Agent, DeferredToolRequests
from pydantic_ai.models.function import (
Expand Down Expand Up @@ -1725,8 +1725,7 @@ def delete_file(path: str) -> str:
for chunk in chunks
if chunk.get("type") == "tool-approval-request"
]
# Older pydantic-ai generates a UUID approvalId; newer versions reuse
# toolCallId. Either is fine — assert the shape, not the exact value.
# pydantic-ai may reuse toolCallId as approvalId; assert the shape.
assert len(approval_chunks) == 1
chunk = approval_chunks[0]
assert chunk["type"] == "tool-approval-request"
Expand All @@ -1738,10 +1737,8 @@ def delete_file(path: str) -> str:
class MockBaseChunkWithError:
"""Mock BaseChunk that raises on serialization."""

def model_dump(
self, mode: str, by_alias: bool, exclude_none: bool
) -> dict[str, Any]:
del mode, by_alias, exclude_none
def encode(self, *, sdk_version: int) -> str:
del sdk_version
raise ValueError("Serialization error")


Expand Down
14 changes: 0 additions & 14 deletions tests/_ai/test_pydantic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,6 @@ def test_generate_id_with_empty_prefix(self):
assert result.startswith("_")


def _has_pydantic_function_like() -> bool:
"""Check if pydantic has the _function_like attribute required by pydantic-ai."""
try:
from pydantic._internal import _decorators

return hasattr(_decorators, "_function_like")
except ImportError:
return False


@pytest.mark.skipif(
not _has_pydantic_function_like(),
reason="pydantic version missing _function_like (required by pydantic-ai)",
)
class TestFormToolsets:
def test_form_toolsets_empty_list(self):
tool_invoker = AsyncMock()
Expand Down
2 changes: 1 addition & 1 deletion tests/_server/ai/test_ai_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def test_for_model_openrouter(self) -> None:
config: AiConfig = {"openrouter": {"api_key": "test-openrouter-key"}}

provider_config = AnyProviderConfig.for_model(
"openrouter/gpt-4", config
"openrouter/openai/gpt-4", config
)

assert provider_config.api_key == "test-openrouter-key"
Expand Down
Loading
Loading