Skip to content

Commit 63c2ab1

Browse files
authored
Standardize ElasticsearchEmbeddings (#95)
* standardize embeddings class * update tests
1 parent d53b8be commit 63c2ab1

File tree

4 files changed

+344
-314
lines changed

4 files changed

+344
-314
lines changed

libs/elasticsearch/langchain_elasticsearch/_async/embeddings.py

Lines changed: 170 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Any, Dict, List, Optional
3+
from typing import List, Optional
44

55
from elasticsearch import AsyncElasticsearch
66
from elasticsearch.helpers.vectorstore import AsyncEmbeddingService
@@ -13,187 +13,202 @@
1313

1414

1515
class AsyncElasticsearchEmbeddings(Embeddings):
16-
"""Elasticsearch embedding models.
16+
"""`Elasticsearch` embedding models.
1717
1818
This class provides an interface to generate embeddings using a model deployed
19-
in an Elasticsearch cluster. It requires an Elasticsearch connection object
20-
and the model_id of the model deployed in the cluster.
19+
in an Elasticsearch cluster. It requires an Elasticsearch connection and the
20+
model_id of the model deployed in the cluster.
2121
2222
In Elasticsearch you need to have an embedding model loaded and deployed.
2323
- https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html
2424
- https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html
2525
26+
Setup:
27+
Install `langchain_elasticsearch` and start Elasticsearch locally using
28+
the start-local script.
29+
30+
```bash
31+
pip install -qU langchain_elasticsearch
32+
curl -fsSL https://elastic.co/start-local | sh
33+
```
34+
35+
This will create an `elastic-start-local` folder. To start Elasticsearch
36+
and Kibana:
37+
```bash
38+
cd elastic-start-local
39+
./start.sh
40+
```
41+
42+
Elasticsearch will be available at `http://localhost:9200`. The password
43+
for the `elastic` user and API key are stored in the `.env` file in the
44+
`elastic-start-local` folder.
45+
46+
Key init args:
47+
model_id: str
48+
The model_id of the model deployed in the Elasticsearch cluster.
49+
input_field: str
50+
The name of the key for the input text field in the document.
51+
Defaults to 'text_field'.
52+
53+
Key init args — client params:
54+
client: Optional[AsyncElasticsearch or Elasticsearch]
55+
Pre-existing Elasticsearch connection. Either provide this OR credentials.
56+
es_url: Optional[str]
57+
URL of the Elasticsearch instance to connect to.
58+
es_cloud_id: Optional[str]
59+
Cloud ID of the Elasticsearch instance to connect to.
60+
es_user: Optional[str]
61+
Username to use when connecting to Elasticsearch.
62+
es_api_key: Optional[str]
63+
API key to use when connecting to Elasticsearch.
64+
es_password: Optional[str]
65+
Password to use when connecting to Elasticsearch.
66+
67+
Instantiate:
68+
```python
69+
from langchain_elasticsearch import ElasticsearchEmbeddings
70+
71+
embeddings = ElasticsearchEmbeddings(
72+
model_id="your_model_id",
73+
es_url="http://localhost:9200"
74+
)
75+
```
76+
77+
Instantiate with API key (URL):
78+
```python
79+
from langchain_elasticsearch import ElasticsearchEmbeddings
80+
81+
embeddings = ElasticsearchEmbeddings(
82+
model_id="your_model_id",
83+
es_url="http://localhost:9200",
84+
es_api_key="your-api-key"
85+
)
86+
```
87+
88+
Instantiate with username/password (URL):
89+
```python
90+
from langchain_elasticsearch import ElasticsearchEmbeddings
91+
92+
embeddings = ElasticsearchEmbeddings(
93+
model_id="your_model_id",
94+
es_url="http://localhost:9200",
95+
es_user="elastic",
96+
es_password="password"
97+
)
98+
```
99+
100+
If you want to use a cloud hosted Elasticsearch instance, you can pass in the
101+
es_cloud_id argument instead of the es_url argument.
102+
103+
Instantiate from cloud (with username/password):
104+
```python
105+
from langchain_elasticsearch import ElasticsearchEmbeddings
106+
107+
embeddings = ElasticsearchEmbeddings(
108+
model_id="your_model_id",
109+
es_cloud_id="<cloud_id>",
110+
es_user="elastic",
111+
es_password="<password>"
112+
)
113+
```
114+
115+
Instantiate from cloud (with API key):
116+
```python
117+
from langchain_elasticsearch import ElasticsearchEmbeddings
118+
119+
embeddings = ElasticsearchEmbeddings(
120+
model_id="your_model_id",
121+
es_cloud_id="<cloud_id>",
122+
es_api_key="your-api-key"
123+
)
124+
```
125+
126+
You can also connect to an existing Elasticsearch instance by passing in a
127+
pre-existing Elasticsearch connection via the client argument.
128+
129+
Instantiate from existing connection:
130+
```python
131+
from langchain_elasticsearch import ElasticsearchEmbeddings
132+
from elasticsearch import Elasticsearch
133+
134+
client = Elasticsearch("http://localhost:9200")
135+
embeddings = ElasticsearchEmbeddings(
136+
model_id="your_model_id",
137+
client=client
138+
)
139+
```
140+
141+
Generate embeddings:
142+
```python
143+
documents = [
144+
"This is an example document.",
145+
"Another example document to generate embeddings for.",
146+
]
147+
embeddings_list = embeddings.embed_documents(documents)
148+
```
149+
150+
Generate query embedding:
151+
```python
152+
query_embedding = embeddings.embed_query("What is this about?")
153+
```
154+
26155
For synchronous applications, use the `ElasticsearchEmbeddings` class.
27-
For asyhchronous applications, use the `AsyncElasticsearchEmbeddings` class.
156+
For asynchronous applications, use the `AsyncElasticsearchEmbeddings` class.
28157
""" # noqa: E501
29158

30159
def __init__(
31160
self,
32-
client: AsyncElasticsearch,
33161
model_id: str,
34162
*,
35163
input_field: str = "text_field",
36-
):
37-
"""
38-
Initialize the ElasticsearchEmbeddings instance.
39-
40-
Args:
41-
client (Elasticsearch or AsyncElasticsearch): An Elasticsearch client
42-
object or an AsyncElasticsearch client object.
43-
model_id (str): The model_id of the model deployed in the Elasticsearch
44-
cluster.
45-
input_field (str): The name of the key for the input text field in the
46-
document. Defaults to 'text_field'.
47-
"""
48-
# Apply User-Agent for telemetry
49-
# (applies to both passed and internally created clients)
50-
self.client = async_with_user_agent_header(client, "langchain-py-e")
51-
self.model_id = model_id
52-
self.input_field = input_field
53-
54-
@classmethod
55-
def from_credentials(
56-
cls,
57-
model_id: str,
58-
*,
164+
client: Optional[AsyncElasticsearch] = None,
59165
es_url: Optional[str] = None,
60166
es_cloud_id: Optional[str] = None,
61-
es_api_key: Optional[str] = None,
62167
es_user: Optional[str] = None,
168+
es_api_key: Optional[str] = None,
63169
es_password: Optional[str] = None,
64-
es_params: Optional[Dict[str, Any]] = None,
65-
input_field: str = "text_field",
66-
) -> AsyncElasticsearchEmbeddings:
67-
"""Instantiate embeddings from Elasticsearch credentials.
170+
):
171+
"""Initialize the ElasticsearchEmbeddings instance.
68172
69173
Args:
70174
model_id (str): The model_id of the model deployed in the Elasticsearch
71175
cluster.
72176
input_field (str): The name of the key for the input text field in the
73177
document. Defaults to 'text_field'.
74-
es_url: (str, optional): URL of the Elasticsearch instance to connect to.
75-
es_cloud_id: (str, optional): The Elasticsearch cloud ID to connect to.
76-
es_api_key: (str, optional): API key to use connecting to Elasticsearch.
77-
es_user: (str, optional): Elasticsearch username.
78-
es_password: (str, optional): Elasticsearch password.
79-
es_params: (dict, optional): Additional parameters for the
80-
Elasticsearch client.
81-
82-
Example:
83-
.. code-block:: python
84-
85-
from langchain_elasticserach.embeddings import ElasticsearchEmbeddings
86-
87-
88-
# Define the model ID and input field name (if different from default)
89-
model_id = "your_model_id"
90-
# Optional, only if different from 'text_field'
91-
input_field = "your_input_field"
92-
93-
# Provide either es_url (local) or es_cloud_id (cloud).
94-
# For authentication, provide either es_api_key or
95-
# (es_user + es_password).
96-
embeddings = ElasticsearchEmbeddings.from_credentials(
97-
model_id,
98-
input_field=input_field,
99-
es_cloud_id="foo",
100-
es_api_key="bar",
101-
)
102-
103-
# Or use local URL with API key:
104-
embeddings = ElasticsearchEmbeddings.from_credentials(
105-
model_id,
106-
es_url="http://localhost:9200",
107-
es_api_key="bar"
108-
)
109-
110-
# Or use username/password authentication:
111-
embeddings = ElasticsearchEmbeddings.from_credentials(
112-
model_id,
113-
es_url="http://localhost:9200",
114-
es_user="elastic",
115-
es_password="password"
116-
)
117-
118-
# Note: To use environment variables, read them yourself:
119-
# import os
120-
# embeddings = ElasticsearchEmbeddings.from_credentials(
121-
# model_id,
122-
# es_cloud_id=os.environ.get("ES_CLOUD_ID"),
123-
# es_api_key=os.environ.get("ES_API_KEY"),
124-
# )
125-
126-
documents = [
127-
"This is an example document.",
128-
"Another example document to generate embeddings for.",
129-
]
130-
embeddings_generator.embed_documents(documents)
178+
client (AsyncElasticsearch or Elasticsearch, optional):
179+
Pre-existing Elasticsearch connection. Either provide this OR
180+
credentials.
181+
es_url (str, optional): URL of the Elasticsearch instance to connect to.
182+
es_cloud_id (str, optional): Cloud ID of the Elasticsearch instance.
183+
es_user (str, optional): Username to use when connecting to
184+
Elasticsearch.
185+
es_api_key (str, optional): API key to use when connecting to
186+
Elasticsearch.
187+
es_password (str, optional): Password to use when connecting to
188+
Elasticsearch.
131189
"""
190+
# Accept either client OR credentials (one required)
191+
if client is not None:
192+
es_connection = client
193+
elif es_url is not None or es_cloud_id is not None:
194+
es_connection = create_async_elasticsearch_client(
195+
url=es_url,
196+
cloud_id=es_cloud_id,
197+
api_key=es_api_key,
198+
username=es_user,
199+
password=es_password,
200+
)
201+
else:
202+
raise ValueError(
203+
"Either 'client' or credentials (es_url, es_cloud_id, etc.) "
204+
"must be provided."
205+
)
132206

133-
# Connect to Elasticsearch using create_elasticsearch_client for consistency
134-
es_connection = create_async_elasticsearch_client(
135-
url=es_url,
136-
cloud_id=es_cloud_id,
137-
api_key=es_api_key,
138-
username=es_user,
139-
password=es_password,
140-
params=es_params,
141-
)
142-
return cls(es_connection, model_id, input_field=input_field)
143-
144-
@classmethod
145-
def from_es_connection(
146-
cls,
147-
model_id: str,
148-
es_connection: AsyncElasticsearch,
149-
input_field: str = "text_field",
150-
) -> AsyncElasticsearchEmbeddings:
151-
"""
152-
Instantiate embeddings from an existing Elasticsearch connection.
153-
154-
This method provides a way to create an instance of the ElasticsearchEmbeddings
155-
class using an existing Elasticsearch connection.
156-
157-
Args:
158-
model_id (str): The model_id of the model deployed in the Elasticsearch cluster.
159-
es_connection (elasticsearch.Elasticsearch): An existing Elasticsearch
160-
connection object. input_field (str, optional): The name of the key for the
161-
input text field in the document. Defaults to 'text_field'.
162-
163-
Returns:
164-
ElasticsearchEmbeddings: An instance of the ElasticsearchEmbeddings class.
165-
166-
Example:
167-
.. code-block:: python
168-
169-
from elasticsearch import Elasticsearch
170-
171-
from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings
172-
173-
# Define the model ID and input field name (if different from default)
174-
model_id = "your_model_id"
175-
# Optional, only if different from 'text_field'
176-
input_field = "your_input_field"
177-
178-
# Create Elasticsearch connection
179-
es_connection = Elasticsearch(
180-
hosts=["localhost:9200"], http_auth=("user", "password")
181-
)
182-
183-
# Instantiate ElasticsearchEmbeddings using the existing connection
184-
embeddings = ElasticsearchEmbeddings.from_es_connection(
185-
model_id,
186-
es_connection,
187-
input_field=input_field,
188-
)
189-
190-
documents = [
191-
"This is an example document.",
192-
"Another example document to generate embeddings for.",
193-
]
194-
embeddings_generator.embed_documents(documents)
195-
"""
196-
return cls(es_connection, model_id, input_field=input_field)
207+
# Apply User-Agent for telemetry
208+
# (applies to both passed and internally created clients)
209+
self.client = async_with_user_agent_header(es_connection, "langchain-py-e")
210+
self.model_id = model_id
211+
self.input_field = input_field
197212

198213
async def _embedding_func(self, texts: List[str]) -> List[List[float]]:
199214
"""

0 commit comments

Comments
 (0)