Skip to content

Commit d0c665a

Browse files
committed
check if transfer target is a sibling agent
1 parent 69997cd commit d0c665a

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,15 @@ def _get_agent_to_run(
723723
agent_to_run = root_agent.find_agent(agent_name)
724724
if not agent_to_run:
725725
raise ValueError(f'Agent {agent_name} not found in the agent tree.')
726+
727+
from ...agents.llm_agent import LlmAgent
728+
729+
if all([
730+
isinstance(invocation_context.agent, LlmAgent),
731+
invocation_context.agent.disallow_transfer_to_peers,
732+
agent_to_run.parent_agent == invocation_context.agent.parent_agent,
733+
]):
734+
raise ValueError(f'Transfer to sibling agent {agent_name} is disallowed.')
726735
return agent_to_run
727736

728737
async def _call_llm_async(
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
"""Tests for BaseLlmFlow._get_agent_to_run transfer-to-peers behavior."""
2+
3+
from __future__ import annotations
4+
5+
from google.adk.agents.llm_agent import LlmAgent
6+
from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
7+
import pytest
8+
9+
from ... import testing_utils
10+
11+
12+
def make_agent_tree():
13+
root = LlmAgent(name='root')
14+
child1 = LlmAgent(name='child1')
15+
child2 = LlmAgent(name='child2')
16+
17+
child1.parent_agent = root
18+
child2.parent_agent = root
19+
root.sub_agents = [child1, child2]
20+
return root, child1, child2
21+
22+
23+
@pytest.mark.asyncio
24+
async def test_transfer_to_sibling_disallowed_raises():
25+
root, child1, child2 = make_agent_tree()
26+
27+
caller = child1
28+
caller.disallow_transfer_to_peers = True
29+
30+
ctx = await testing_utils.create_invocation_context(caller)
31+
32+
flow = BaseLlmFlow()
33+
34+
with pytest.raises(ValueError) as exc:
35+
flow._get_agent_to_run(ctx, 'child2')
36+
37+
38+
@pytest.mark.asyncio
39+
async def test_transfer_to_sibling_allowed_returns_agent():
40+
root, child1, child2 = make_agent_tree()
41+
42+
caller = child1
43+
caller.disallow_transfer_to_peers = False
44+
45+
ctx = await testing_utils.create_invocation_context(caller)
46+
47+
flow = BaseLlmFlow()
48+
agent = flow._get_agent_to_run(ctx, 'child2')
49+
assert agent is not None
50+
assert agent.name == 'child2'
51+
52+
53+
@pytest.mark.asyncio
54+
async def test_transfer_to_unknown_agent_raises():
55+
root, child1, child2 = make_agent_tree()
56+
57+
caller = child1
58+
ctx = await testing_utils.create_invocation_context(caller)
59+
flow = BaseLlmFlow()
60+
61+
with pytest.raises(ValueError) as exc:
62+
flow._get_agent_to_run(ctx, 'not_in_tree')

0 commit comments

Comments
 (0)