diff --git a/libs/elasticsearch/langchain_elasticsearch/_async/retrievers.py b/libs/elasticsearch/langchain_elasticsearch/_async/retrievers.py index 487ee10..decde5c 100644 --- a/libs/elasticsearch/langchain_elasticsearch/_async/retrievers.py +++ b/libs/elasticsearch/langchain_elasticsearch/_async/retrievers.py @@ -13,34 +13,373 @@ class AsyncElasticsearchRetriever(BaseRetriever): - """ - Elasticsearch retriever - - Args: - es_client: Elasticsearch client connection. Alternatively you can use the - `from_es_params` method with parameters to initialize the client. - index_name: The name of the index to query. Can also be a list of names. - body_func: Function to create an Elasticsearch DSL query body from a search - string. The returned query body must fit what you would normally send in a - POST request the the _search endpoint. If applicable, it also includes - parameters the `size` parameter etc. - content_field: The document field name that contains the page content. If - multiple indices are queried, specify a dict {index_name: field_name} here. - document_mapper: Function to map Elasticsearch hits to LangChain Documents. - - For synchronous applications, use the ``ElasticsearchRetriever`` class. - For asyhchronous applications, use the ``AsyncElasticsearchRetriever`` class. + """`Elasticsearch` retriever. + + Setup: + Install `langchain_elasticsearch` and start Elasticsearch locally using + the start-local script. + + ```bash + pip install -qU langchain_elasticsearch + curl -fsSL https://elastic.co/start-local | sh + ``` + + This will create an `elastic-start-local` folder. To start Elasticsearch + and Kibana: + ```bash + cd elastic-start-local + ./start.sh + ``` + + Elasticsearch will be available at `http://localhost:9200`. The password + for the `elastic` user and API key are stored in the `.env` file in the + `elastic-start-local` folder. + + Key init args — query params: + index_name: Union[str, Sequence[str]] + The name of the index to query. 
Can also be a list of names. + body_func: Callable[[str], Dict] + Function to create an Elasticsearch DSL query body from a search string. + The returned query body must fit what you would normally send in a POST + request to the _search endpoint. If applicable, it also includes parameters + like the `size` parameter etc. + content_field: Optional[Union[str, Mapping[str, str]]] + The document field name that contains the page content. If multiple indices + are queried, specify a dict {index_name: field_name} here. + document_mapper: Optional[Callable[[Mapping], Document]] + Function to map Elasticsearch hits to LangChain Documents. If not provided, + will be automatically created based on content_field. + + Key init args — client params: + client: Optional[AsyncElasticsearch or Elasticsearch] + Pre-existing Elasticsearch connection. Either provide this OR credentials. + es_url: Optional[str] + URL of the Elasticsearch instance to connect to. + es_cloud_id: Optional[str] + Cloud ID of the Elasticsearch instance to connect to. + es_user: Optional[str] + Username to use when connecting to Elasticsearch. + es_api_key: Optional[str] + API key to use when connecting to Elasticsearch. + es_password: Optional[str] + Password to use when connecting to Elasticsearch. 
+ + Instantiate: + ```python + from langchain_elasticsearch import ElasticsearchRetriever + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + content_field="text", + es_url="http://localhost:9200", + ) + ``` + + Instantiate with API key (URL): + ```python + from langchain_elasticsearch import ElasticsearchRetriever + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + content_field="text", + es_url="http://localhost:9200", + es_api_key="your-api-key" + ) + ``` + + Instantiate with username/password (URL): + ```python + from langchain_elasticsearch import ElasticsearchRetriever + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + content_field="text", + es_url="http://localhost:9200", + es_user="elastic", + es_password="password" + ) + ``` + + If you want to use a cloud hosted Elasticsearch instance, you can pass in the + es_cloud_id argument instead of the es_url argument. 
+ + Instantiate from cloud (with username/password): + ```python + from langchain_elasticsearch import ElasticsearchRetriever + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + content_field="text", + es_cloud_id="<cloud_id>", + es_user="elastic", + es_password="<password>" + ) + ``` + + Instantiate from cloud (with API key): + ```python + from langchain_elasticsearch import ElasticsearchRetriever + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + content_field="text", + es_cloud_id="<cloud_id>", + es_api_key="your-api-key" + ) + ``` + + You can also connect to an existing Elasticsearch instance by passing in a + pre-existing Elasticsearch connection via the client argument. + + Instantiate from existing connection: + ```python + from langchain_elasticsearch import ElasticsearchRetriever + from elasticsearch import Elasticsearch + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + client = Elasticsearch("http://localhost:9200") + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + content_field="text", + client=client + ) + ``` + + Retrieve documents: + Note: Use `invoke()` or `ainvoke()` instead of the deprecated + `get_relevant_documents()` or `aget_relevant_documents()` methods. 
+ + First, index some documents: + ```python + from elasticsearch import Elasticsearch + + client = Elasticsearch("http://localhost:9200") + + # Index sample documents + client.index( + index="some-index", + document={"text": "The quick brown fox jumps over the lazy dog"}, + id="1", + refresh=True + ) + client.index( + index="some-index", + document={"text": "Python is a popular programming language"}, + id="2", + refresh=True + ) + client.index( + index="some-index", + document={"text": "Elasticsearch is a search engine"}, + id="3", + refresh=True + ) + ``` + + Then retrieve documents: + ```python + from langchain_elasticsearch import ElasticsearchRetriever + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + retriever = ElasticsearchRetriever( + index_name="some-index", + body_func=body_func, + content_field="text", + es_url="http://localhost:9200" + ) + + # Retrieve documents + documents = retriever.invoke("Python") + for doc in documents: + print(f"* {doc.page_content}") + ``` + ```python + * Python is a popular programming language + ``` + + + + Use custom document mapper: + ```python + from langchain_elasticsearch import ElasticsearchRetriever + from langchain_core.documents import Document + from elasticsearch import Elasticsearch + from typing import Mapping, Any + + def body_func(query: str) -> dict: + return {"query": {"match": {"custom_field": {"query": query}}}} + + def custom_mapper(hit: Mapping[str, Any]) -> Document: + # Custom logic to extract content and metadata + return Document( + page_content=hit["_source"]["custom_field"], + metadata={"score": hit["_score"]} + ) + + client = Elasticsearch("http://localhost:9200") + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + document_mapper=custom_mapper, + client=client + ) + ``` + + Use with multiple indices: + ```python + from langchain_elasticsearch import ElasticsearchRetriever + from elasticsearch import 
Elasticsearch + + def body_func(query: str) -> dict: + return { + "query": { + "multi_match": { + "query": query, + "fields": ["text_field_1", "text_field_2"] + } + } + } + + client = Elasticsearch("http://localhost:9200") + retriever = ElasticsearchRetriever( + index_name=["index1", "index2"], + body_func=body_func, + content_field={ + "index1": "text_field_1", + "index2": "text_field_2" + }, + client=client + ) + ``` + + Use as LangChain retriever in chains: + Note: Before running this example, ensure you have indexed documents + in your Elasticsearch index. The retriever will search this index + for relevant documents to use as context. + + ```python + from langchain_elasticsearch import ElasticsearchRetriever + from langchain_core.runnables import RunnablePassthrough + from langchain_core.prompts import ChatPromptTemplate + from langchain_ollama import ChatOllama + + # ElasticsearchRetriever is already a BaseRetriever + retriever = ElasticsearchRetriever( + index_name="some-index", + body_func=lambda q: {"query": {"match": {"text": {"query": q}}}}, + content_field="text", + es_url="http://localhost:9200" + ) + + llm = ChatOllama(model="llama3", temperature=0) + + # Create a chain that retrieves documents and then generates a response + def format_docs(docs): + # Format documents for the prompt + return "\n\n".join(doc.page_content for doc in docs) + + system_prompt = ( + "You are an assistant for question-answering tasks. " + "Use the following pieces of retrieved context to answer " + "the question. If you don't know the answer, say that you " + "don't know. Use three sentences maximum and keep the " + "answer concise." 
+ "\n\n" + "Context: {context}" + ) + + prompt = ChatPromptTemplate.from_messages([ + ("system", system_prompt), + ("human", "{question}"), + ]) + + chain = ( + {"context": retriever | format_docs, "question": RunnablePassthrough()} + | prompt + | llm + ) + + result = chain.invoke("what is the answer to this question?") + ``` + + For synchronous applications, use the `ElasticsearchRetriever` class. + For asynchronous applications, use the `AsyncElasticsearchRetriever` class. """ - es_client: AsyncElasticsearch + client: AsyncElasticsearch index_name: Union[str, Sequence[str]] body_func: Callable[[str], Dict] content_field: Optional[Union[str, Mapping[str, str]]] = None document_mapper: Optional[Callable[[Mapping], Document]] = None - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) + def __init__( + self, + index_name: Union[str, Sequence[str]], + body_func: Callable[[str], Dict], + *, + content_field: Optional[Union[str, Mapping[str, str]]] = None, + document_mapper: Optional[Callable[[Mapping], Document]] = None, + client: Optional[AsyncElasticsearch] = None, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_user: Optional[str] = None, + es_api_key: Optional[str] = None, + es_password: Optional[str] = None, + ) -> None: + # Create client from credentials if needed (BEFORE super().__init__) + if client is not None: + es_connection = client + elif es_url is not None or es_cloud_id is not None: + es_connection = create_async_elasticsearch_client( + url=es_url, + cloud_id=es_cloud_id, + api_key=es_api_key, + username=es_user, + password=es_password, + ) + else: + raise ValueError( + "Either 'client' or credentials (es_url, es_cloud_id, etc.) " + "must be provided." 
+ ) + + # Apply user agent + es_connection = async_with_user_agent_header(es_connection, "langchain-py-r") + + super().__init__( + client=es_connection, + index_name=index_name, + body_func=body_func, + content_field=content_field, + document_mapper=document_mapper, + ) + # Now Pydantic has set everything, do validation if self.content_field is None and self.document_mapper is None: raise ValueError("One of content_field or document_mapper must be defined.") if self.content_field is not None and self.document_mapper is not None: @@ -59,52 +398,14 @@ def __init__(self, **kwargs: Any) -> None: "unknown type for content_field, expected string or dict." ) - self.es_client = async_with_user_agent_header(self.es_client, "langchain-py-r") - - @classmethod - def from_es_params( - cls, - index_name: Union[str, Sequence[str]], - body_func: Callable[[str], Dict], - content_field: Optional[Union[str, Mapping[str, str]]] = None, - document_mapper: Optional[Callable[[Mapping], Document]] = None, - url: Optional[str] = None, - cloud_id: Optional[str] = None, - api_key: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - params: Optional[Dict[str, Any]] = None, - ) -> "AsyncElasticsearchRetriever": - client = None - try: - client = create_async_elasticsearch_client( - url=url, - cloud_id=cloud_id, - api_key=api_key, - username=username, - password=password, - params=params, - ) - except Exception as err: - logger.error(f"Error connecting to Elasticsearch: {err}") - raise err - - return cls( - es_client=client, - index_name=index_name, - body_func=body_func, - content_field=content_field, - document_mapper=document_mapper, - ) - async def _aget_relevant_documents( self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun ) -> List[Document]: - if not self.es_client or not self.document_mapper: + if not self.client or not self.document_mapper: raise ValueError("faulty configuration") # should not happen body = self.body_func(query) 
- results = await self.es_client.search(index=self.index_name, body=body) + results = await self.client.search(index=self.index_name, body=body) return [self.document_mapper(hit) for hit in results["hits"]["hits"]] def _single_field_mapper(self, hit: Mapping[str, Any]) -> Document: diff --git a/libs/elasticsearch/langchain_elasticsearch/_sync/retrievers.py b/libs/elasticsearch/langchain_elasticsearch/_sync/retrievers.py index c35974d..efc9d24 100644 --- a/libs/elasticsearch/langchain_elasticsearch/_sync/retrievers.py +++ b/libs/elasticsearch/langchain_elasticsearch/_sync/retrievers.py @@ -13,34 +13,373 @@ class ElasticsearchRetriever(BaseRetriever): - """ - Elasticsearch retriever - - Args: - es_client: Elasticsearch client connection. Alternatively you can use the - `from_es_params` method with parameters to initialize the client. - index_name: The name of the index to query. Can also be a list of names. - body_func: Function to create an Elasticsearch DSL query body from a search - string. The returned query body must fit what you would normally send in a - POST request the the _search endpoint. If applicable, it also includes - parameters the `size` parameter etc. - content_field: The document field name that contains the page content. If - multiple indices are queried, specify a dict {index_name: field_name} here. - document_mapper: Function to map Elasticsearch hits to LangChain Documents. - - For synchronous applications, use the ``ElasticsearchRetriever`` class. - For asyhchronous applications, use the ``AsyncElasticsearchRetriever`` class. + """`Elasticsearch` retriever. + + Setup: + Install `langchain_elasticsearch` and start Elasticsearch locally using + the start-local script. + + ```bash + pip install -qU langchain_elasticsearch + curl -fsSL https://elastic.co/start-local | sh + ``` + + This will create an `elastic-start-local` folder. 
To start Elasticsearch + and Kibana: + ```bash + cd elastic-start-local + ./start.sh + ``` + + Elasticsearch will be available at `http://localhost:9200`. The password + for the `elastic` user and API key are stored in the `.env` file in the + `elastic-start-local` folder. + + Key init args — query params: + index_name: Union[str, Sequence[str]] + The name of the index to query. Can also be a list of names. + body_func: Callable[[str], Dict] + Function to create an Elasticsearch DSL query body from a search string. + The returned query body must fit what you would normally send in a POST + request to the _search endpoint. If applicable, it also includes parameters + like the `size` parameter etc. + content_field: Optional[Union[str, Mapping[str, str]]] + The document field name that contains the page content. If multiple indices + are queried, specify a dict {index_name: field_name} here. + document_mapper: Optional[Callable[[Mapping], Document]] + Function to map Elasticsearch hits to LangChain Documents. If not provided, + will be automatically created based on content_field. + + Key init args — client params: + client: Optional[AsyncElasticsearch or Elasticsearch] + Pre-existing Elasticsearch connection. Either provide this OR credentials. + es_url: Optional[str] + URL of the Elasticsearch instance to connect to. + es_cloud_id: Optional[str] + Cloud ID of the Elasticsearch instance to connect to. + es_user: Optional[str] + Username to use when connecting to Elasticsearch. + es_api_key: Optional[str] + API key to use when connecting to Elasticsearch. + es_password: Optional[str] + Password to use when connecting to Elasticsearch. 
+ + Instantiate: + ```python + from langchain_elasticsearch import ElasticsearchRetriever + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + content_field="text", + es_url="http://localhost:9200", + ) + ``` + + Instantiate with API key (URL): + ```python + from langchain_elasticsearch import ElasticsearchRetriever + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + content_field="text", + es_url="http://localhost:9200", + es_api_key="your-api-key" + ) + ``` + + Instantiate with username/password (URL): + ```python + from langchain_elasticsearch import ElasticsearchRetriever + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + content_field="text", + es_url="http://localhost:9200", + es_user="elastic", + es_password="password" + ) + ``` + + If you want to use a cloud hosted Elasticsearch instance, you can pass in the + es_cloud_id argument instead of the es_url argument. 
+ + Instantiate from cloud (with username/password): + ```python + from langchain_elasticsearch import ElasticsearchRetriever + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + content_field="text", + es_cloud_id="<cloud_id>", + es_user="elastic", + es_password="<password>" + ) + ``` + + Instantiate from cloud (with API key): + ```python + from langchain_elasticsearch import ElasticsearchRetriever + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + content_field="text", + es_cloud_id="<cloud_id>", + es_api_key="your-api-key" + ) + ``` + + You can also connect to an existing Elasticsearch instance by passing in a + pre-existing Elasticsearch connection via the client argument. + + Instantiate from existing connection: + ```python + from langchain_elasticsearch import ElasticsearchRetriever + from elasticsearch import Elasticsearch + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + client = Elasticsearch("http://localhost:9200") + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + content_field="text", + client=client + ) + ``` + + Retrieve documents: + Note: Use `invoke()` or `ainvoke()` instead of the deprecated + `get_relevant_documents()` or `aget_relevant_documents()` methods. 
+ + First, index some documents: + ```python + from elasticsearch import Elasticsearch + + client = Elasticsearch("http://localhost:9200") + + # Index sample documents + client.index( + index="some-index", + document={"text": "The quick brown fox jumps over the lazy dog"}, + id="1", + refresh=True + ) + client.index( + index="some-index", + document={"text": "Python is a popular programming language"}, + id="2", + refresh=True + ) + client.index( + index="some-index", + document={"text": "Elasticsearch is a search engine"}, + id="3", + refresh=True + ) + ``` + + Then retrieve documents: + ```python + from langchain_elasticsearch import ElasticsearchRetriever + + def body_func(query: str) -> dict: + return {"query": {"match": {"text": {"query": query}}}} + + retriever = ElasticsearchRetriever( + index_name="some-index", + body_func=body_func, + content_field="text", + es_url="http://localhost:9200" + ) + + # Retrieve documents + documents = retriever.invoke("Python") + for doc in documents: + print(f"* {doc.page_content}") + ``` + ```python + * Python is a popular programming language + ``` + + + + Use custom document mapper: + ```python + from langchain_elasticsearch import ElasticsearchRetriever + from langchain_core.documents import Document + from elasticsearch import Elasticsearch + from typing import Mapping, Any + + def body_func(query: str) -> dict: + return {"query": {"match": {"custom_field": {"query": query}}}} + + def custom_mapper(hit: Mapping[str, Any]) -> Document: + # Custom logic to extract content and metadata + return Document( + page_content=hit["_source"]["custom_field"], + metadata={"score": hit["_score"]} + ) + + client = Elasticsearch("http://localhost:9200") + retriever = ElasticsearchRetriever( + index_name="langchain-demo", + body_func=body_func, + document_mapper=custom_mapper, + client=client + ) + ``` + + Use with multiple indices: + ```python + from langchain_elasticsearch import ElasticsearchRetriever + from elasticsearch import 
Elasticsearch + + def body_func(query: str) -> dict: + return { + "query": { + "multi_match": { + "query": query, + "fields": ["text_field_1", "text_field_2"] + } + } + } + + client = Elasticsearch("http://localhost:9200") + retriever = ElasticsearchRetriever( + index_name=["index1", "index2"], + body_func=body_func, + content_field={ + "index1": "text_field_1", + "index2": "text_field_2" + }, + client=client + ) + ``` + + Use as LangChain retriever in chains: + Note: Before running this example, ensure you have indexed documents + in your Elasticsearch index. The retriever will search this index + for relevant documents to use as context. + + ```python + from langchain_elasticsearch import ElasticsearchRetriever + from langchain_core.runnables import RunnablePassthrough + from langchain_core.prompts import ChatPromptTemplate + from langchain_ollama import ChatOllama + + # ElasticsearchRetriever is already a BaseRetriever + retriever = ElasticsearchRetriever( + index_name="some-index", + body_func=lambda q: {"query": {"match": {"text": {"query": q}}}}, + content_field="text", + es_url="http://localhost:9200" + ) + + llm = ChatOllama(model="llama3", temperature=0) + + # Create a chain that retrieves documents and then generates a response + def format_docs(docs): + # Format documents for the prompt + return "\n\n".join(doc.page_content for doc in docs) + + system_prompt = ( + "You are an assistant for question-answering tasks. " + "Use the following pieces of retrieved context to answer " + "the question. If you don't know the answer, say that you " + "don't know. Use three sentences maximum and keep the " + "answer concise." 
+ "\n\n" + "Context: {context}" + ) + + prompt = ChatPromptTemplate.from_messages([ + ("system", system_prompt), + ("human", "{question}"), + ]) + + chain = ( + {"context": retriever | format_docs, "question": RunnablePassthrough()} + | prompt + | llm + ) + + result = chain.invoke("what is the answer to this question?") + ``` + + For synchronous applications, use the `ElasticsearchRetriever` class. + For asynchronous applications, use the `AsyncElasticsearchRetriever` class. """ - es_client: Elasticsearch + client: Elasticsearch index_name: Union[str, Sequence[str]] body_func: Callable[[str], Dict] content_field: Optional[Union[str, Mapping[str, str]]] = None document_mapper: Optional[Callable[[Mapping], Document]] = None - def __init__(self, **kwargs: Any) -> None: - super().__init__(**kwargs) + def __init__( + self, + index_name: Union[str, Sequence[str]], + body_func: Callable[[str], Dict], + *, + content_field: Optional[Union[str, Mapping[str, str]]] = None, + document_mapper: Optional[Callable[[Mapping], Document]] = None, + client: Optional[Elasticsearch] = None, + es_url: Optional[str] = None, + es_cloud_id: Optional[str] = None, + es_user: Optional[str] = None, + es_api_key: Optional[str] = None, + es_password: Optional[str] = None, + ) -> None: + # Create client from credentials if needed (BEFORE super().__init__) + if client is not None: + es_connection = client + elif es_url is not None or es_cloud_id is not None: + es_connection = create_elasticsearch_client( + url=es_url, + cloud_id=es_cloud_id, + api_key=es_api_key, + username=es_user, + password=es_password, + ) + else: + raise ValueError( + "Either 'client' or credentials (es_url, es_cloud_id, etc.) " + "must be provided." 
+ ) + + # Apply user agent + es_connection = with_user_agent_header(es_connection, "langchain-py-r") + + super().__init__( + client=es_connection, + index_name=index_name, + body_func=body_func, + content_field=content_field, + document_mapper=document_mapper, + ) + # Now Pydantic has set everything, do validation if self.content_field is None and self.document_mapper is None: raise ValueError("One of content_field or document_mapper must be defined.") if self.content_field is not None and self.document_mapper is not None: @@ -59,52 +398,14 @@ def __init__(self, **kwargs: Any) -> None: "unknown type for content_field, expected string or dict." ) - self.es_client = with_user_agent_header(self.es_client, "langchain-py-r") - - @classmethod - def from_es_params( - cls, - index_name: Union[str, Sequence[str]], - body_func: Callable[[str], Dict], - content_field: Optional[Union[str, Mapping[str, str]]] = None, - document_mapper: Optional[Callable[[Mapping], Document]] = None, - url: Optional[str] = None, - cloud_id: Optional[str] = None, - api_key: Optional[str] = None, - username: Optional[str] = None, - password: Optional[str] = None, - params: Optional[Dict[str, Any]] = None, - ) -> "ElasticsearchRetriever": - client = None - try: - client = create_elasticsearch_client( - url=url, - cloud_id=cloud_id, - api_key=api_key, - username=username, - password=password, - params=params, - ) - except Exception as err: - logger.error(f"Error connecting to Elasticsearch: {err}") - raise err - - return cls( - es_client=client, - index_name=index_name, - body_func=body_func, - content_field=content_field, - document_mapper=document_mapper, - ) - def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun ) -> List[Document]: - if not self.es_client or not self.document_mapper: + if not self.client or not self.document_mapper: raise ValueError("faulty configuration") # should not happen body = self.body_func(query) - results = 
self.es_client.search(index=self.index_name, body=body) + results = self.client.search(index=self.index_name, body=body) return [self.document_mapper(hit) for hit in results["hits"]["hits"]] def _single_field_mapper(self, hit: Mapping[str, Any]) -> Document: diff --git a/libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py b/libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py index 545e4ba..f54c8e0 100644 --- a/libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py +++ b/libs/elasticsearch/tests/integration_tests/_async/test_retrievers.py @@ -2,7 +2,7 @@ import re import uuid -from typing import Any, Dict +from typing import Any, Dict, Mapping import pytest from elasticsearch import AsyncElasticsearch @@ -58,11 +58,11 @@ async def test_user_agent_header( index_name=index_name, body_func=lambda _: {"query": {"match_all": {}}}, content_field="text", - es_client=es_client, + client=es_client, ) - assert retriever.es_client - user_agent = retriever.es_client._headers["User-Agent"] + assert retriever.client + user_agent = retriever.client._headers["User-Agent"] assert ( re.match(r"^langchain-py-r/\d+\.\d+\.\d+(?:rc\d+)?$", user_agent) is not None @@ -91,20 +91,20 @@ def body_func(query: str) -> Dict: # Map test utility format to retriever format config = {} if "es_url" in env_config: - config["url"] = env_config["es_url"] + config["es_url"] = env_config["es_url"] if "es_api_key" in env_config: - config["api_key"] = env_config["es_api_key"] + config["es_api_key"] = env_config["es_api_key"] if "es_cloud_id" in env_config: - config["cloud_id"] = env_config["es_cloud_id"] + config["es_cloud_id"] = env_config["es_cloud_id"] - retriever = AsyncElasticsearchRetriever.from_es_params( + retriever = AsyncElasticsearchRetriever( index_name=index_name, body_func=body_func, content_field=text_field, **config, # type: ignore[arg-type] ) - await index_test_data(retriever.es_client, index_name, text_field) + await 
index_test_data(retriever.client, index_name, text_field) result = await retriever.aget_relevant_documents("foo") assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"} @@ -129,7 +129,7 @@ def body_func(query: str) -> Dict: index_name=index_name, body_func=body_func, content_field=text_field, - es_client=es_client, + client=es_client, ) await index_test_data(es_client, index_name, text_field) @@ -166,7 +166,7 @@ def body_func(query: str) -> Dict: index_name=[index_name_1, index_name_2], content_field={index_name_1: text_field_1, index_name_2: text_field_2}, body_func=body_func, - es_client=es_client, + client=es_client, ) await index_test_data(es_client, index_name_1, text_field_1) @@ -195,14 +195,14 @@ async def test_custom_mapper( def body_func(query: str) -> Dict: return {"query": {"match": {text_field: {"query": query}}}} - def id_as_content(hit: Dict) -> Document: + def id_as_content(hit: Mapping[str, Any]) -> Document: return Document(page_content=hit["_id"], metadata=meta) retriever = AsyncElasticsearchRetriever( index_name=index_name, body_func=body_func, document_mapper=id_as_content, - es_client=es_client, + client=es_client, ) await index_test_data(es_client, index_name, text_field) @@ -220,10 +220,10 @@ async def test_fail_content_field_and_mapper( with pytest.raises(ValueError): AsyncElasticsearchRetriever( content_field="text", - document_mapper=lambda x: x, + document_mapper=lambda x: x, # type: ignore[arg-type,return-value] index_name="foo", - body_func=lambda x: x, - es_client=es_client, + body_func=lambda x: x, # type: ignore[arg-type,return-value] + client=es_client, ) @pytest.mark.asyncio @@ -235,6 +235,6 @@ async def test_fail_neither_content_field_nor_mapper( with pytest.raises(ValueError): AsyncElasticsearchRetriever( index_name="foo", - body_func=lambda x: x, - es_client=es_client, + body_func=lambda x: x, # type: ignore[arg-type,return-value] + client=es_client, ) diff --git 
a/libs/elasticsearch/tests/integration_tests/_sync/test_retrievers.py b/libs/elasticsearch/tests/integration_tests/_sync/test_retrievers.py index eda486f..457b1da 100644 --- a/libs/elasticsearch/tests/integration_tests/_sync/test_retrievers.py +++ b/libs/elasticsearch/tests/integration_tests/_sync/test_retrievers.py @@ -2,7 +2,7 @@ import re import uuid -from typing import Any, Dict +from typing import Any, Dict, Mapping import pytest from elasticsearch import Elasticsearch @@ -54,11 +54,11 @@ def test_user_agent_header(self, es_client: Elasticsearch, index_name: str) -> N index_name=index_name, body_func=lambda _: {"query": {"match_all": {}}}, content_field="text", - es_client=es_client, + client=es_client, ) - assert retriever.es_client - user_agent = retriever.es_client._headers["User-Agent"] + assert retriever.client + user_agent = retriever.client._headers["User-Agent"] assert ( re.match(r"^langchain-py-r/\d+\.\d+\.\d+(?:rc\d+)?$", user_agent) is not None @@ -87,20 +87,20 @@ def body_func(query: str) -> Dict: # Map test utility format to retriever format config = {} if "es_url" in env_config: - config["url"] = env_config["es_url"] + config["es_url"] = env_config["es_url"] if "es_api_key" in env_config: - config["api_key"] = env_config["es_api_key"] + config["es_api_key"] = env_config["es_api_key"] if "es_cloud_id" in env_config: - config["cloud_id"] = env_config["es_cloud_id"] + config["es_cloud_id"] = env_config["es_cloud_id"] - retriever = ElasticsearchRetriever.from_es_params( + retriever = ElasticsearchRetriever( index_name=index_name, body_func=body_func, content_field=text_field, **config, # type: ignore[arg-type] ) - index_test_data(retriever.es_client, index_name, text_field) + index_test_data(retriever.client, index_name, text_field) result = retriever.get_relevant_documents("foo") assert {r.page_content for r in result} == {"foo", "foo bar", "foo baz"} @@ -123,7 +123,7 @@ def body_func(query: str) -> Dict: index_name=index_name, body_func=body_func, 
content_field=text_field, - es_client=es_client, + client=es_client, ) index_test_data(es_client, index_name, text_field) @@ -160,7 +160,7 @@ def body_func(query: str) -> Dict: index_name=[index_name_1, index_name_2], content_field={index_name_1: text_field_1, index_name_2: text_field_2}, body_func=body_func, - es_client=es_client, + client=es_client, ) index_test_data(es_client, index_name_1, text_field_1) @@ -187,14 +187,14 @@ def test_custom_mapper(self, es_client: Elasticsearch, index_name: str) -> None: def body_func(query: str) -> Dict: return {"query": {"match": {text_field: {"query": query}}}} - def id_as_content(hit: Dict) -> Document: + def id_as_content(hit: Mapping[str, Any]) -> Document: return Document(page_content=hit["_id"], metadata=meta) retriever = ElasticsearchRetriever( index_name=index_name, body_func=body_func, document_mapper=id_as_content, - es_client=es_client, + client=es_client, ) index_test_data(es_client, index_name, text_field) @@ -210,10 +210,10 @@ def test_fail_content_field_and_mapper(self, es_client: Elasticsearch) -> None: with pytest.raises(ValueError): ElasticsearchRetriever( content_field="text", - document_mapper=lambda x: x, + document_mapper=lambda x: x, # type: ignore[arg-type,return-value] index_name="foo", - body_func=lambda x: x, - es_client=es_client, + body_func=lambda x: x, # type: ignore[arg-type,return-value] + client=es_client, ) @pytest.mark.sync @@ -225,6 +225,6 @@ def test_fail_neither_content_field_nor_mapper( with pytest.raises(ValueError): ElasticsearchRetriever( index_name="foo", - body_func=lambda x: x, - es_client=es_client, + body_func=lambda x: x, # type: ignore[arg-type,return-value] + client=es_client, )