From ab1a85c42a2eaa75c8173f05ef12e987bfb2dbf9 Mon Sep 17 00:00:00 2001 From: francis2tm Date: Tue, 28 Jan 2025 22:38:18 +0000 Subject: [PATCH] feat: function calling --- packages/sui-agent/src/config/atoma.ts | 14 ++- .../sui-agent/src/tools/aftermath/index.ts | 96 +++++++++++++++---- 2 files changed, 90 insertions(+), 20 deletions(-) diff --git a/packages/sui-agent/src/config/atoma.ts b/packages/sui-agent/src/config/atoma.ts index ce1d1b5..4cf2980 100644 --- a/packages/sui-agent/src/config/atoma.ts +++ b/packages/sui-agent/src/config/atoma.ts @@ -16,18 +16,30 @@ export function initializeAtomaSDK(bearerAuth: string): AtomaSDK { * @param sdk - Initialized Atoma SDK instance * @param messages - Array of message objects with content and role * @param model - Optional model identifier (defaults to Llama-3.3-70B-Instruct) + * @param functions - Optional array of function definitions for function calling * @returns Chat completion response */ async function atomaChat( sdk: AtomaSDK, messages: { content: string; role: string }[], model?: string, + functions?: Array<{ + name: string; + description: string; + parameters: { + type: string; + properties: Record; + required?: string[]; + }; + }>, ) { try { return await sdk.chat.create({ messages, model: model || ATOMA_CHAT_COMPLETIONS_MODEL, maxTokens: 4096, + functions: functions, + functionCall: functions ? 'auto' : undefined, }); } catch (error) { // Log the error for monitoring @@ -59,7 +71,7 @@ async function atomaChat( /** * Health check function that returns service status * @param sdk - Initialized Atoma SDK instance - * @returns Boolean indicating if service is healthy + * @returns Promise indicating if service is healthy */ async function isAtomaHealthy(sdk: AtomaSDK): Promise { try { diff --git a/packages/sui-agent/src/tools/aftermath/index.ts b/packages/sui-agent/src/tools/aftermath/index.ts index ab22e28..2971a9a 100644 --- a/packages/sui-agent/src/tools/aftermath/index.ts +++ b/packages/sui-agent/src/tools/aftermath/index.ts @@ -37,33 +37,91 @@ class Tools { this.tools.push({ name, description, parameters, process }); } + /** + * Convert tool parameters to OpenAI function parameters format + * @param parameters - Tool parameters + * @returns OpenAI function parameters format + */ + private convertToFunctionParameters(parameters: Tool['parameters']) { + const properties: Record = {}; + const required: string[] = []; + + parameters.forEach((param) => { + properties[param.name] = { + type: param.type, + description: param.description, + }; + if (param.required) { + required.push(param.name); + } + }); + + return { + type: 'object', + properties, + required: required.length > 0 ? required : undefined, + }; + } + + /** + * Convert registered tools to OpenAI function format + * @returns Array of functions in OpenAI format + */ + private getToolsAsFunctions() { + return this.tools.map((tool) => ({ + name: tool.name, + description: tool.description, + parameters: this.convertToFunctionParameters(tool.parameters), + })); + } + /** * Select appropriate tool based on user query * @param query - User query * @returns Selected tool response or null if no tool found */ async selectAppropriateTool(query: string): Promise { - const finalPrompt = this.prompt.replace( - '${toolsList}', - JSON.stringify(this.getAllTools()), + const functions = this.getToolsAsFunctions(); + + const ai: any = await atomaChat( + this.sdk, + [ + { + content: this.prompt, + role: 'system', + }, + { + content: query || '', + role: 'user', + }, + ], + undefined, + functions, ); - const ai: any = await atomaChat(this.sdk, [ - { - content: finalPrompt, - role: 'system', - }, - { - content: query || '', - role: 'user', - }, - ]); - const res = ai.choices[0].message.content; - - const applicableTools: toolResponse[] = JSON.parse(res); - if (applicableTools.length > 0) return applicableTools[0]; - - return null; + const message = ai.choices[0].message; + if (!message.function_call) { + return null; + } + + const selectedTool = this.tools.find( + (tool) => tool.name === message.function_call.name, + ); + if (!selectedTool) { + return null; + } + + const args = JSON.parse(message.function_call.arguments); + const toolArgs = selectedTool.parameters.map((param) => args[param.name]); + + return { + success: true, + selected_tool: selectedTool.name, + response: null, + needs_additional_info: false, + additional_info_required: null, + tool_arguments: toolArgs, + }; } /**