From e49cab5024f79b188a444d3e898e65263d790969 Mon Sep 17 00:00:00 2001 From: Theo van Kraay Date: Thu, 3 Jul 2025 20:12:40 +0100 Subject: [PATCH] changes to support local model --- python/src/app/banking_agents.py | 1 + python/src/app/banking_agents_api.py | 47 ++++++++++++++------------ python/src/app/services/local_model.py | 31 +++++++++++++++++ python/src/app/tools/sales.py | 2 +- 4 files changed, 59 insertions(+), 22 deletions(-) create mode 100644 python/src/app/services/local_model.py diff --git a/python/src/app/banking_agents.py b/python/src/app/banking_agents.py index d2b61823c..3ecd50dc3 100644 --- a/python/src/app/banking_agents.py +++ b/python/src/app/banking_agents.py @@ -9,6 +9,7 @@ from langgraph_checkpoint_cosmosdb import CosmosDBSaver from langsmith import traceable from src.app.services.azure_open_ai import model +# from src.app.services.local_model import model # Use local model from src.app.services.azure_cosmos_db import DATABASE_NAME, checkpoint_container, chat_container, \ update_chat_container, patch_active_agent from src.app.tools.sales import get_offer_information, calculate_monthly_payment, create_account diff --git a/python/src/app/banking_agents_api.py b/python/src/app/banking_agents_api.py index 3c445b416..14c1c6a75 100644 --- a/python/src/app/banking_agents_api.py +++ b/python/src/app/banking_agents_api.py @@ -6,7 +6,7 @@ from datetime import datetime from fastapi import BackgroundTasks -from azure.monitor.opentelemetry import configure_azure_monitor +#from azure.monitor.opentelemetry import configure_azure_monitor from azure.cosmos.exceptions import CosmosHttpResponseError @@ -16,6 +16,7 @@ from pydantic import BaseModel from typing import List, Dict from src.app.services.azure_open_ai import model +# from src.app.services.local_model import model # Use local model from langgraph_checkpoint_cosmosdb import CosmosDBSaver from langgraph.graph.state import CompiledStateGraph from starlette.middleware.cors import CORSMiddleware @@ -32,7 +33,7 @@ load_dotenv(override=False) -configure_azure_monitor() +#configure_azure_monitor() endpointTitle = "ChatEndpoints" @@ -134,25 +135,29 @@ def store_debug_log(sessionId, tenantId, userId, response_data): if "messages" in details: for msg in details["messages"]: if hasattr(msg, "response_metadata"): - metadata = msg.response_metadata - finish_reason = metadata.get("finish_reason", finish_reason) - model_name = metadata.get("model_name", model_name) - system_fingerprint = metadata.get("system_fingerprint", system_fingerprint) - input_tokens = metadata.get("token_usage", {}).get("prompt_tokens", input_tokens) - output_tokens = metadata.get("token_usage", {}).get("completion_tokens", output_tokens) - total_tokens = metadata.get("token_usage", {}).get("total_tokens", total_tokens) - cached_tokens = metadata.get("token_usage", {}).get("prompt_tokens_details", {}).get( - "cached_tokens", cached_tokens) - logprobs = metadata.get("logprobs", logprobs) - content_filter_results = metadata.get("content_filter_results", content_filter_results) - - if "tool_calls" in msg.additional_kwargs: - tool_calls.extend(msg.additional_kwargs["tool_calls"]) - transfer_success = any( - call.get("name", "").startswith("transfer_to_") for call in tool_calls) - previous_agent = agent_selected - agent_selected = tool_calls[-1].get("name", "").replace("transfer_to_", - "") if tool_calls else agent_selected + metadata = getattr(msg, "response_metadata", None) + if metadata: + finish_reason = metadata.get("finish_reason", finish_reason) + model_name = metadata.get("model_name", model_name) + system_fingerprint = metadata.get("system_fingerprint", system_fingerprint) + + token_usage = metadata.get("token_usage", {}) or {} + input_tokens = token_usage.get("prompt_tokens", input_tokens) + output_tokens = token_usage.get("completion_tokens", output_tokens) + total_tokens = token_usage.get("total_tokens", total_tokens) + + prompt_details = token_usage.get("prompt_tokens_details", {}) or {} + cached_tokens = prompt_details.get("cached_tokens", cached_tokens) + + logprobs = metadata.get("logprobs", logprobs) + content_filter_results = metadata.get("content_filter_results", content_filter_results) + + if "tool_calls" in msg.additional_kwargs: + tool_calls.extend(msg.additional_kwargs["tool_calls"]) + transfer_success = any( + call.get("name", "").startswith("transfer_to_") for call in tool_calls) + previous_agent = agent_selected + agent_selected = tool_calls[-1].get("name", "").replace("transfer_to_", "") if tool_calls else agent_selected property_bag = [ {"key": "agent_selected", "value": agent_selected, "timeStamp": timestamp}, diff --git a/python/src/app/services/local_model.py b/python/src/app/services/local_model.py new file mode 100644 index 000000000..d875a9471 --- /dev/null +++ b/python/src/app/services/local_model.py @@ -0,0 +1,31 @@ +from langchain_openai import ChatOpenAI + +from openai import OpenAI + +# tested with LM Studio 1.2.0 and Qwen2.5-7b-instruct model +# Ensure you have the LM Studio server running and the model loaded +# replace imports with local model imports: +# from src.app.services.local_model import model (in banking_agents.py and banking_agents_api.py) +# from src.app.services.local_model import generate_embedding (in sales.py) + + +model = ChatOpenAI( + model_name="qwen2.5-7b-instruct", + openai_api_base="http://172.26.208.1:1234/v1", + openai_api_key="lm-studio", # Arbitrary, just needs to be set + temperature=0, + max_tokens=1024, +) + +client = OpenAI( + base_url="http://localhost:1235/v1", # LM Studio embedding model + api_key="lm-studio" +) + +def generate_embedding(text): + response = client.embeddings.create( + input=text, + model="nomic-embed-text-v1.5" + ) + return response.data[0].embedding + diff --git a/python/src/app/tools/sales.py b/python/src/app/tools/sales.py index a80aa8ed2..675515502 100644 --- a/python/src/app/tools/sales.py +++ b/python/src/app/tools/sales.py @@ -7,7 +7,7 @@ from src.app.services.azure_cosmos_db import vector_search, create_account_record, \ fetch_latest_account_number from src.app.services.azure_open_ai import generate_embedding - +# from src.app.services.local_model import generate_embedding # Use local model @tool @traceable