Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,7 @@ def run(
name=run_name,
run_id=run_id,
inputs=filtered_tool_input,
tool_call_id=tool_call_id,
**kwargs,
)

Expand Down Expand Up @@ -1004,6 +1005,7 @@ async def arun(
name=run_name,
run_id=run_id,
inputs=filtered_tool_input,
tool_call_id=tool_call_id,
**kwargs,
)
content = None
Expand Down
194 changes: 194 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3127,3 +3127,197 @@ class MockRuntime:
assert captured is not None
assert captured == {"query": "test", "limit": 5}
assert "runtime" not in captured


class CallbackHandlerWithToolCallIdCapture(FakeCallbackHandler):
"""Callback handler that captures `tool_call_id` passed to `on_tool_start`.

Used to verify that `tool_call_id` is correctly forwarded to the `on_tool_start`
callback method.
"""

captured_tool_call_ids: list[str | None] = []

def on_tool_start(
self,
serialized: dict[str, Any],
input_str: str,
*,
run_id: Any,
parent_run_id: Any | None = None,
tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
inputs: dict[str, Any] | None = None,
tool_call_id: str | None = None,
**kwargs: Any,
) -> Any:
"""Capture the `tool_call_id` passed to `on_tool_start`.

Args:
serialized: Serialized tool information.
input_str: String representation of tool input.
run_id: Unique identifier for this run.
parent_run_id: Identifier of the parent run.
tags: Optional tags for this run.
metadata: Optional metadata for this run.
inputs: Dictionary of tool inputs.
tool_call_id: The tool call identifier from the LLM.
**kwargs: Additional keyword arguments.

Returns:
Result from parent `on_tool_start` call.
"""
self.captured_tool_call_ids.append(tool_call_id)
return super().on_tool_start(
serialized,
input_str,
run_id=run_id,
parent_run_id=parent_run_id,
tags=tags,
metadata=metadata,
inputs=inputs,
**kwargs,
)


@pytest.mark.parametrize("method", ["invoke", "ainvoke"])
async def test_tool_call_id_passed_to_on_tool_start_callback(method: str) -> None:
"""Test that `tool_call_id` is passed to the `on_tool_start` callback."""

@tool
def simple_tool(query: str) -> str:
"""Simple tool for testing.

Args:
query: The query string.
"""
return f"Result: {query}"

handler = CallbackHandlerWithToolCallIdCapture(captured_tool_call_ids=[])

tool_call: ToolCall = {
"name": "simple_tool",
"args": {"query": "test"},
"id": "test_tool_call_id_123",
"type": "tool_call",
}

if method == "ainvoke":
result = await simple_tool.ainvoke(tool_call, config={"callbacks": [handler]})
else:
result = simple_tool.invoke(tool_call, config={"callbacks": [handler]})

assert result == ToolMessage(
content="Result: test", name="simple_tool", tool_call_id="test_tool_call_id_123"
)
assert handler.tool_starts == 1
assert len(handler.captured_tool_call_ids) == 1
assert handler.captured_tool_call_ids[0] == "test_tool_call_id_123"


def test_tool_call_id_none_when_invoked_without_tool_call() -> None:
"""Test that `tool_call_id` is `None` when tool is invoked without a `ToolCall`.

When a tool is invoked directly with arguments (not via a `ToolCall`),
the `tool_call_id` should be `None` in the callback.
"""

@tool
def simple_tool(query: str) -> str:
"""Simple tool for testing.

Args:
query: The query string.
"""
return f"Result: {query}"

handler = CallbackHandlerWithToolCallIdCapture(captured_tool_call_ids=[])

# Invoke tool directly with arguments, not a ToolCall
result = simple_tool.invoke({"query": "test"}, config={"callbacks": [handler]})

assert result == "Result: test"
assert handler.tool_starts == 1
assert len(handler.captured_tool_call_ids) == 1
# tool_call_id should be None when not invoked with a ToolCall
assert handler.captured_tool_call_ids[0] is None


def test_tool_call_id_empty_string_passed_to_callback() -> None:
"""Test that empty string `tool_call_id` is correctly passed to callback.

Some systems may use empty strings as `tool_call_id`, and this should
be passed through correctly (not converted to `None`).
"""

@tool
def simple_tool(query: str) -> str:
"""Simple tool for testing.

Args:
query: The query string.
"""
return f"Result: {query}"

handler = CallbackHandlerWithToolCallIdCapture(captured_tool_call_ids=[])

# Invoke tool with empty string tool_call_id
tool_call: ToolCall = {
"name": "simple_tool",
"args": {"query": "test"},
"id": "",
"type": "tool_call",
}

result = simple_tool.invoke(tool_call, config={"callbacks": [handler]})

assert result == ToolMessage(
content="Result: test", name="simple_tool", tool_call_id=""
)
assert handler.tool_starts == 1
assert len(handler.captured_tool_call_ids) == 1
# Empty string should be passed as-is, not converted to None
assert handler.captured_tool_call_ids[0] == ""


@pytest.mark.parametrize("method", ["run", "arun"])
async def test_tool_call_id_passed_via_run_method(method: str) -> None:
"""Test that `tool_call_id` is passed to callback when using run/arun method.

The `run()` and `arun()` methods are the lower-level APIs that `invoke()`
and `ainvoke()` call internally. This test ensures `tool_call_id` works
at this level as well.
"""

@tool
def simple_tool(query: str) -> str:
"""Simple tool for testing.

Args:
query: The query string.
"""
return f"Result: {query}"

handler = CallbackHandlerWithToolCallIdCapture(captured_tool_call_ids=[])

if method == "arun":
result = await simple_tool.arun(
{"query": "test"},
callbacks=[handler],
tool_call_id="run_method_tool_call_id",
)
else:
result = simple_tool.run(
{"query": "test"},
callbacks=[handler],
tool_call_id="run_method_tool_call_id",
)

assert result == ToolMessage(
content="Result: test",
name="simple_tool",
tool_call_id="run_method_tool_call_id",
)
assert handler.tool_starts == 1
assert len(handler.captured_tool_call_ids) == 1
assert handler.captured_tool_call_ids[0] == "run_method_tool_call_id"