@staticmethod
def _google_compatible_schema(schema: dict[str, Any]) -> dict[str, Any]:
    """Convert a JSON Schema into the subset accepted by Gemini function declarations.

    The input is never mutated. A new dict is assembled from a whitelist of
    keys, so every key outside that whitelist (for example ``default``,
    ``examples`` or ``additionalProperties``) is dropped at every nesting
    level.

    Args:
        schema: A JSON Schema fragment, typically a tool's ``parameters``.

    Returns:
        A new schema dict. A schema containing ``anyOf`` is reduced to
        ``{"anyOf": [...]}`` with each alternative converted recursively.
        A list-valued ``type`` is collapsed to its first non-``"null"``
        entry, and an unsupported or missing ``type`` becomes ``"null"``.
        Array schemas always carry an ``items`` schema.
    """
    if not isinstance(schema, dict):
        # JSON Schema permits boolean sub-schemas (`true` / `false`) as
        # property or `anyOf` entries. They have no structure to convert, so
        # reuse the permissive fallback already applied to arrays without a
        # usable `items` schema instead of raising on the lookups below.
        return {"type": "string"}

    supported_types = {
        "string",
        "number",
        "integer",
        "boolean",
        "array",
        "object",
        "null",
    }
    supported_formats = {
        "string": {"enum", "date-time"},
        "integer": {"int32", "int64"},
        "number": {"float", "double"},
    }
    support_fields = {
        "title",
        "description",
        "enum",
        "minimum",
        "maximum",
        "maxItems",
        "minItems",
        "nullable",
        "required",
    }

    if "anyOf" in schema:
        # Only the converted alternatives are returned; sibling keys of
        # `anyOf` are intentionally not copied.
        return {
            "anyOf": [ToolSet._google_compatible_schema(s) for s in schema["anyOf"]]
        }

    result: dict[str, Any] = {}

    # Work on local copies of the type so the original schema is untouched.
    origin_type = schema.get("type")
    target_type = origin_type

    # Gemini API expects 'type' to be a string, while JSON Schema allows lists.
    if isinstance(origin_type, list):
        target_type = next((t for t in origin_type if t != "null"), "string")

    # The isinstance checks keep malformed (unhashable) `type` / `format`
    # values from raising inside the set membership tests.
    if isinstance(target_type, str) and target_type in supported_types:
        result["type"] = target_type
        schema_format = schema.get("format")
        if isinstance(schema_format, str) and schema_format in supported_formats.get(
            target_type,
            set(),
        ):
            result["format"] = schema_format
    else:
        result["type"] = "null"

    result.update({k: schema[k] for k in support_fields if k in schema})

    raw_properties = schema.get("properties")
    if isinstance(raw_properties, dict):
        properties = {
            key: ToolSet._google_compatible_schema(value)
            for key, value in raw_properties.items()
        }
        # An empty `properties` mapping is omitted entirely.
        if properties:
            result["properties"] = properties

    if target_type == "array":
        items_schema = schema.get("items")
        if isinstance(items_schema, dict):
            result["items"] = ToolSet._google_compatible_schema(items_schema)
        else:
            # Gemini requires array schemas to include an `items` schema.
            result["items"] = {"type": "string"}

    return result
- if isinstance(origin_type, list): - target_type = next((t for t in origin_type if t != "null"), "string") - - if target_type in supported_types: - result["type"] = target_type - if "format" in schema and schema["format"] in supported_formats.get( - result["type"], - set(), - ): - result["format"] = schema["format"] - else: - result["type"] = "null" - - support_fields = { - "title", - "description", - "enum", - "minimum", - "maximum", - "maxItems", - "minItems", - "nullable", - "required", - } - result.update({k: schema[k] for k in support_fields if k in schema}) - - if "properties" in schema: - properties = {} - for key, value in schema["properties"].items(): - prop_value = convert_schema(value) - if "default" in prop_value: - del prop_value["default"] - # see #5217 - if "additionalProperties" in prop_value: - del prop_value["additionalProperties"] - properties[key] = prop_value - - if properties: - result["properties"] = properties - - if target_type == "array": - items_schema = schema.get("items") - if isinstance(items_schema, dict): - result["items"] = convert_schema(items_schema) - else: - # Gemini requires array schemas to include an `items` schema. - # JSON Schema allows omitting it, so fall back to a permissive - # string item schema instead of emitting an invalid declaration. 
@deprecated(reason="Use openai_schema() instead", version="4.0.0")
def get_func_desc_openai_style(
    self,
    omit_empty_parameter_field: bool = False,
    gemini_compatible_schema: bool = False,
):
    """Deprecated alias that forwards both flags unchanged to ``openai_schema()``.

    Args:
        omit_empty_parameter_field: Passed through to ``openai_schema()``.
        gemini_compatible_schema: Passed through to ``openai_schema()``.

    Returns:
        Whatever ``openai_schema()`` returns for the given flags.
    """
    schema_options = {
        "omit_empty_parameter_field": omit_empty_parameter_field,
        "gemini_compatible_schema": gemini_compatible_schema,
    }
    return self.openai_schema(**schema_options)
def load_tool_module():
    """Load the tool module from ``TOOL_MODULE_PATH`` behind stubbed dependencies.

    Stub packages and stub dependency modules are registered in
    ``sys.modules`` while the file is executed. Whatever happens, every
    touched ``sys.modules`` entry is put back to its previous state (or
    removed if it did not exist before) when the function exits.

    Returns:
        The freshly executed ``astrbot.core.agent.tool`` module object.
    """
    stub_packages = (
        "astrbot",
        "astrbot.core",
        "astrbot.core.agent",
        "astrbot.core.message",
    )
    touched_names = (
        *stub_packages,
        "astrbot.core.message.message_event_result",
        "astrbot.core.agent.run_context",
        "astrbot.core.agent.tool",
    )
    # Sentinel distinguishing "was not in sys.modules" from any real entry.
    absent = object()
    snapshot = {name: sys.modules.get(name, absent) for name in touched_names}

    try:
        # Only create placeholder packages for names that are not imported yet.
        for pkg_name in stub_packages:
            if pkg_name in sys.modules:
                continue
            pkg = types.ModuleType(pkg_name)
            pkg.__path__ = []
            sys.modules[pkg_name] = pkg

        event_result_stub = types.ModuleType(
            "astrbot.core.message.message_event_result"
        )
        event_result_stub.MessageEventResult = type("MessageEventResult", (), {})
        sys.modules[event_result_stub.__name__] = event_result_stub

        context_var = TypeVar("TContext")

        class ContextWrapper(Generic[context_var]):
            pass

        run_context_stub = types.ModuleType("astrbot.core.agent.run_context")
        run_context_stub.TContext = context_var
        run_context_stub.ContextWrapper = ContextWrapper
        sys.modules[run_context_stub.__name__] = run_context_stub

        spec = importlib.util.spec_from_file_location(
            "astrbot.core.agent.tool", TOOL_MODULE_PATH
        )
        assert spec and spec.loader
        tool_module = importlib.util.module_from_spec(spec)
        # The module must be registered before execution, then is unregistered
        # again by the cleanup below.
        sys.modules[spec.name] = tool_module
        spec.loader.exec_module(tool_module)
        return tool_module
    finally:
        for name, previous in snapshot.items():
            if previous is absent:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = previous


def test_openai_schema_keeps_raw_parameter_fields_by_default():
    """Without the Gemini flag, ``examples`` fields survive in the OpenAI schema."""
    tool_module = load_tool_module()

    raw_parameters = {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "Search query.",
                "examples": ["astrbot"],
            }
        },
        "examples": [{"query": "astrbot"}],
    }
    search_tool = tool_module.FunctionTool(
        name="search_sources",
        description="Search sources by query.",
        parameters=raw_parameters,
    )

    tool_defs = tool_module.ToolSet([search_tool]).openai_schema()
    emitted = tool_defs[0]["function"]["parameters"]

    assert emitted["examples"] == [{"query": "astrbot"}]
    assert emitted["properties"]["query"]["examples"] == ["astrbot"]


def test_openai_schema_can_sanitize_gemini_parameter_fields():
    """With ``gemini_compatible_schema=True`` unsupported keys are removed at every level."""
    tool_module = load_tool_module()

    query_schema = {
        "type": "string",
        "description": "Search query.",
        "examples": ["astrbot"],
        "default": "astrbot",
    }
    filters_schema = {
        "type": "object",
        "description": "Nested filters.",
        "properties": {
            "tag": {
                "type": "string",
                "examples": ["docs"],
            }
        },
        "additionalProperties": False,
    }
    items_schema = {
        "type": "array",
        "description": "Search result items.",
        "items": {
            "type": "object",
            "properties": {
                "id": {
                    "type": "string",
                    "examples": ["source-1"],
                }
            },
        },
    }
    search_tool = tool_module.FunctionTool(
        name="search_sources",
        description="Search sources by query.",
        parameters={
            "type": "object",
            "properties": {
                "query": query_schema,
                "filters": filters_schema,
                "items": items_schema,
            },
            "required": ["query"],
            "examples": [{"query": "astrbot"}],
            "additionalProperties": False,
        },
    )

    tool_defs = tool_module.ToolSet([search_tool]).openai_schema(
        gemini_compatible_schema=True
    )
    emitted = tool_defs[0]["function"]["parameters"]
    emitted_props = emitted["properties"]

    # (schema node, key that must have been stripped from it)
    stripped = [
        (emitted, "examples"),
        (emitted, "additionalProperties"),
        (emitted_props["query"], "examples"),
        (emitted_props["query"], "default"),
        (emitted_props["filters"], "additionalProperties"),
        (emitted_props["filters"]["properties"]["tag"], "examples"),
        (emitted_props["items"]["items"]["properties"]["id"], "examples"),
    ]
    for node, key in stripped:
        assert key not in node