Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions kag/common/tools/search_api/impl/openspg_search_api.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import logging
from typing import List

from kag.common.conf import KAGConstants, KAGConfigAccessor
from kag.common.tools.search_api.search_api_abc import SearchApiABC
from knext.search.client import SearchClient

logger = logging.getLogger(__name__)


@SearchApiABC.register("openspg_search_api", as_default=True)
class OpenSPGSearchAPI(SearchApiABC):
Expand All @@ -25,27 +28,34 @@ def search_text(
self, query_string, label_constraints=None, topk=10, params=None
) -> List:
if self.sc:
return self.sc.search_text(
query_string=query_string,
label_constraints=label_constraints,
topk=topk,
params=params,
)
try:
return self.sc.search_text(
query_string=query_string,
label_constraints=label_constraints,
topk=topk,
params=params,
)
except Exception as e:
logger.debug(f"search_vector error {e}", exc_info=True)
return []

def search_vector(
self, label, property_key, query_vector, topk=10, ef_search=None, params=None
) -> List:
if self.sc is None:
return []
res = self.sc.search_vector(
label=label,
property_key=property_key,
query_vector=query_vector,
topk=topk,
ef_search=ef_search,
params=params,
)
try:
res = self.sc.search_vector(
label=label,
property_key=property_key,
query_vector=query_vector,
topk=topk,
ef_search=ef_search,
params=params,
)
except Exception as e:
logger.debug(f"search_vector error {e}", exc_info=True)
res = []
if res is None:
return []
return res
Expand Down
7 changes: 7 additions & 0 deletions kag/interface/common/llm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ def invoke(
tools = kwargs.get("tools", None)
if tools:
with_json_parse = False
context = kwargs.get("context", None)
segment_name = kwargs.get("segment_name", None)
try:
self.sync_limiter.acquire()
response = (
Expand All @@ -335,6 +337,11 @@ def invoke(
)
if tools:
return response
if context and segment_name and hasattr(context, "add_kwargs"):
try:
context.add_kwargs(segment_name, response)
except Exception as e:
logger.warning(f"Failed to add kwargs to context: {e}")
result = prompt_op.parse_response(response, model=self.model, **variables)
logger.debug(f"Result: {result}")
return result
Expand Down
17 changes: 17 additions & 0 deletions kag/interface/solver/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,23 @@ def __init__(self):
self.variables_graph = KgGraph()
self.kwargs = {}

def add_kwargs(self, key, value):
"""
Adds a key-value pair to the context's keyword arguments dictionary.

This method stores arbitrary key-value data in the context that can be
accessed by other components during pipeline execution.

Args:
key: The string key to store the value under
value: The value to be stored, can be of any type

Example:
context.add_kwargs("temperature", 0.7)
context.add_kwargs("max_tokens", 1000)
"""
self.kwargs[key] = value

def add_task(self, task: Task):
"""Adds a task to the context.

Expand Down
31 changes: 18 additions & 13 deletions kag/solver/pipeline/kag_static_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,21 +129,26 @@ async def ainvoke(self, query, **kwargs):
while True:
context: Context = Context()
tasks = await self.planning(query, context, **kwargs)

for task in tasks:
context.add_task(task)

for task_group in context.gen_task(group=True):
await asyncio.gather(
*[
asyncio.create_task(
self.execute_task(query, task, context, **kwargs)
)
for task in task_group
]
if not tasks:
think_response = context.kwargs.get("thinker", "")
answer = await self.generator.ainvoke(
query + "\n" + think_response, context, **kwargs
)
else:
for task in tasks:
context.add_task(task)

for task_group in context.gen_task(group=True):
await asyncio.gather(
*[
asyncio.create_task(
self.execute_task(query, task, context, **kwargs)
)
for task in task_group
]
)

answer = await self.generator.ainvoke(query, context, **kwargs)
answer = await self.generator.ainvoke(query, context, **kwargs)
from kag.common.utils import red, green, reset

task_info = []
Expand Down
11 changes: 6 additions & 5 deletions kag/solver/reporter/open_spg_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,18 +530,19 @@ def add_report_line(self, segment, tag_name, content, status, **kwargs):
}

def do_report(self):
if not self.client:
return
content, status_enum, metrics = self.generate_report_data()

request = TaskStreamRequest(
task_id=self.task_id, content=content, status_enum=status_enum
)
# logging.info(f"do_report:{request}")
try:
ret = self.client.reasoner_dialog_report_completions_post(
task_stream_request=request
)
if self.client:
ret = self.client.reasoner_dialog_report_completions_post(
task_stream_request=request
)
else:
ret = {}
if self.last_report is None:
logger.info(f"begin do_report: {request} ret={ret}")
self.last_report = request
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ json_repair
gitpython
docstring_parser
aiolimiter
pyarrow==19.0.1
pyodps==0.12.2
aliyun-log-python-sdk==0.8.8
pyvis
Expand Down