|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from typing import Any, Dict, List, Optional |
| 3 | +from typing import List, Optional |
4 | 4 |
|
5 | 5 | from elasticsearch import AsyncElasticsearch |
6 | 6 | from elasticsearch.helpers.vectorstore import AsyncEmbeddingService |
|
13 | 13 |
|
14 | 14 |
|
15 | 15 | class AsyncElasticsearchEmbeddings(Embeddings): |
16 | | - """Elasticsearch embedding models. |
| 16 | + """`Elasticsearch` embedding models. |
17 | 17 |
|
18 | 18 | This class provides an interface to generate embeddings using a model deployed |
19 | | - in an Elasticsearch cluster. It requires an Elasticsearch connection object |
20 | | - and the model_id of the model deployed in the cluster. |
| 19 | + in an Elasticsearch cluster. It requires an Elasticsearch connection and the |
| 20 | + model_id of the model deployed in the cluster. |
21 | 21 |
|
22 | 22 | In Elasticsearch you need to have an embedding model loaded and deployed. |
23 | 23 | - https://www.elastic.co/guide/en/elasticsearch/reference/current/infer-trained-model.html |
24 | 24 | - https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-deploy-models.html |
25 | 25 |
|
| 26 | + Setup: |
| 27 | + Install `langchain_elasticsearch` and start Elasticsearch locally using |
| 28 | + the start-local script. |
| 29 | +
|
| 30 | + ```bash |
| 31 | + pip install -qU langchain_elasticsearch |
| 32 | + curl -fsSL https://elastic.co/start-local | sh |
| 33 | + ``` |
| 34 | +
|
| 35 | + This will create an `elastic-start-local` folder. To start Elasticsearch |
| 36 | + and Kibana: |
| 37 | + ```bash |
| 38 | + cd elastic-start-local |
| 39 | + ./start.sh |
| 40 | + ``` |
| 41 | +
|
| 42 | + Elasticsearch will be available at `http://localhost:9200`. The password |
| 43 | + for the `elastic` user and API key are stored in the `.env` file in the |
| 44 | + `elastic-start-local` folder. |
| 45 | +
|
| 46 | + Key init args: |
| 47 | + model_id: str |
| 48 | + The model_id of the model deployed in the Elasticsearch cluster. |
| 49 | + input_field: str |
| 50 | + The name of the key for the input text field in the document. |
| 51 | + Defaults to 'text_field'. |
| 52 | +
|
| 53 | + Key init args — client params: |
| 54 | + client: Optional[AsyncElasticsearch or Elasticsearch] |
| 55 | + Pre-existing Elasticsearch connection. Either provide this OR credentials. |
| 56 | + es_url: Optional[str] |
| 57 | + URL of the Elasticsearch instance to connect to. |
| 58 | + es_cloud_id: Optional[str] |
| 59 | + Cloud ID of the Elasticsearch instance to connect to. |
| 60 | + es_user: Optional[str] |
| 61 | + Username to use when connecting to Elasticsearch. |
| 62 | + es_api_key: Optional[str] |
| 63 | + API key to use when connecting to Elasticsearch. |
| 64 | + es_password: Optional[str] |
| 65 | + Password to use when connecting to Elasticsearch. |
| 66 | +
|
| 67 | + Instantiate: |
| 68 | + ```python |
| 69 | + from langchain_elasticsearch import ElasticsearchEmbeddings |
| 70 | +
|
| 71 | + embeddings = ElasticsearchEmbeddings( |
| 72 | + model_id="your_model_id", |
| 73 | + es_url="http://localhost:9200" |
| 74 | + ) |
| 75 | + ``` |
| 76 | +
|
| 77 | + Instantiate with API key (URL): |
| 78 | + ```python |
| 79 | + from langchain_elasticsearch import ElasticsearchEmbeddings |
| 80 | +
|
| 81 | + embeddings = ElasticsearchEmbeddings( |
| 82 | + model_id="your_model_id", |
| 83 | + es_url="http://localhost:9200", |
| 84 | + es_api_key="your-api-key" |
| 85 | + ) |
| 86 | + ``` |
| 87 | +
|
| 88 | + Instantiate with username/password (URL): |
| 89 | + ```python |
| 90 | + from langchain_elasticsearch import ElasticsearchEmbeddings |
| 91 | +
|
| 92 | + embeddings = ElasticsearchEmbeddings( |
| 93 | + model_id="your_model_id", |
| 94 | + es_url="http://localhost:9200", |
| 95 | + es_user="elastic", |
| 96 | + es_password="password" |
| 97 | + ) |
| 98 | + ``` |
| 99 | +
|
| 100 | + If you want to use a cloud hosted Elasticsearch instance, you can pass in the |
| 101 | + es_cloud_id argument instead of the es_url argument. |
| 102 | +
|
| 103 | + Instantiate from cloud (with username/password): |
| 104 | + ```python |
| 105 | + from langchain_elasticsearch import ElasticsearchEmbeddings |
| 106 | +
|
| 107 | + embeddings = ElasticsearchEmbeddings( |
| 108 | + model_id="your_model_id", |
| 109 | + es_cloud_id="<cloud_id>", |
| 110 | + es_user="elastic", |
| 111 | + es_password="<password>" |
| 112 | + ) |
| 113 | + ``` |
| 114 | +
|
| 115 | + Instantiate from cloud (with API key): |
| 116 | + ```python |
| 117 | + from langchain_elasticsearch import ElasticsearchEmbeddings |
| 118 | +
|
| 119 | + embeddings = ElasticsearchEmbeddings( |
| 120 | + model_id="your_model_id", |
| 121 | + es_cloud_id="<cloud_id>", |
| 122 | + es_api_key="your-api-key" |
| 123 | + ) |
| 124 | + ``` |
| 125 | +
|
| 126 | + You can also connect to an existing Elasticsearch instance by passing in a |
| 127 | + pre-existing Elasticsearch connection via the client argument. |
| 128 | +
|
| 129 | + Instantiate from existing connection: |
| 130 | + ```python |
| 131 | + from langchain_elasticsearch import ElasticsearchEmbeddings |
| 132 | + from elasticsearch import Elasticsearch |
| 133 | +
|
| 134 | + client = Elasticsearch("http://localhost:9200") |
| 135 | + embeddings = ElasticsearchEmbeddings( |
| 136 | + model_id="your_model_id", |
| 137 | + client=client |
| 138 | + ) |
| 139 | + ``` |
| 140 | +
|
| 141 | + Generate embeddings: |
| 142 | + ```python |
| 143 | + documents = [ |
| 144 | + "This is an example document.", |
| 145 | + "Another example document to generate embeddings for.", |
| 146 | + ] |
| 147 | + embeddings_list = embeddings.embed_documents(documents) |
| 148 | + ``` |
| 149 | +
|
| 150 | + Generate query embedding: |
| 151 | + ```python |
| 152 | + query_embedding = embeddings.embed_query("What is this about?") |
| 153 | + ``` |
| 154 | +
|
26 | 155 | For synchronous applications, use the `ElasticsearchEmbeddings` class. |
27 | | - For asyhchronous applications, use the `AsyncElasticsearchEmbeddings` class. |
| 156 | + For asynchronous applications, use the `AsyncElasticsearchEmbeddings` class. |
28 | 157 | """ # noqa: E501 |
29 | 158 |
|
30 | 159 | def __init__( |
31 | 160 | self, |
32 | | - client: AsyncElasticsearch, |
33 | 161 | model_id: str, |
34 | 162 | *, |
35 | 163 | input_field: str = "text_field", |
36 | | - ): |
37 | | - """ |
38 | | - Initialize the ElasticsearchEmbeddings instance. |
39 | | -
|
40 | | - Args: |
41 | | - client (Elasticsearch or AsyncElasticsearch): An Elasticsearch client |
42 | | - object or an AsyncElasticsearch client object. |
43 | | - model_id (str): The model_id of the model deployed in the Elasticsearch |
44 | | - cluster. |
45 | | - input_field (str): The name of the key for the input text field in the |
46 | | - document. Defaults to 'text_field'. |
47 | | - """ |
48 | | - # Apply User-Agent for telemetry |
49 | | - # (applies to both passed and internally created clients) |
50 | | - self.client = async_with_user_agent_header(client, "langchain-py-e") |
51 | | - self.model_id = model_id |
52 | | - self.input_field = input_field |
53 | | - |
54 | | - @classmethod |
55 | | - def from_credentials( |
56 | | - cls, |
57 | | - model_id: str, |
58 | | - *, |
| 164 | + client: Optional[AsyncElasticsearch] = None, |
59 | 165 | es_url: Optional[str] = None, |
60 | 166 | es_cloud_id: Optional[str] = None, |
61 | | - es_api_key: Optional[str] = None, |
62 | 167 | es_user: Optional[str] = None, |
| 168 | + es_api_key: Optional[str] = None, |
63 | 169 | es_password: Optional[str] = None, |
64 | | - es_params: Optional[Dict[str, Any]] = None, |
65 | | - input_field: str = "text_field", |
66 | | - ) -> AsyncElasticsearchEmbeddings: |
67 | | - """Instantiate embeddings from Elasticsearch credentials. |
| 170 | + ): |
| 171 | + """Initialize the ElasticsearchEmbeddings instance. |
68 | 172 |
|
69 | 173 | Args: |
70 | 174 | model_id (str): The model_id of the model deployed in the Elasticsearch |
71 | 175 | cluster. |
72 | 176 | input_field (str): The name of the key for the input text field in the |
73 | 177 | document. Defaults to 'text_field'. |
74 | | - es_url: (str, optional): URL of the Elasticsearch instance to connect to. |
75 | | - es_cloud_id: (str, optional): The Elasticsearch cloud ID to connect to. |
76 | | - es_api_key: (str, optional): API key to use connecting to Elasticsearch. |
77 | | - es_user: (str, optional): Elasticsearch username. |
78 | | - es_password: (str, optional): Elasticsearch password. |
79 | | - es_params: (dict, optional): Additional parameters for the |
80 | | - Elasticsearch client. |
81 | | -
|
82 | | - Example: |
83 | | - .. code-block:: python |
84 | | -
|
85 | | - from langchain_elasticserach.embeddings import ElasticsearchEmbeddings |
86 | | -
|
87 | | -
|
88 | | - # Define the model ID and input field name (if different from default) |
89 | | - model_id = "your_model_id" |
90 | | - # Optional, only if different from 'text_field' |
91 | | - input_field = "your_input_field" |
92 | | -
|
93 | | - # Provide either es_url (local) or es_cloud_id (cloud). |
94 | | - # For authentication, provide either es_api_key or |
95 | | - # (es_user + es_password). |
96 | | - embeddings = ElasticsearchEmbeddings.from_credentials( |
97 | | - model_id, |
98 | | - input_field=input_field, |
99 | | - es_cloud_id="foo", |
100 | | - es_api_key="bar", |
101 | | - ) |
102 | | -
|
103 | | - # Or use local URL with API key: |
104 | | - embeddings = ElasticsearchEmbeddings.from_credentials( |
105 | | - model_id, |
106 | | - es_url="http://localhost:9200", |
107 | | - es_api_key="bar" |
108 | | - ) |
109 | | -
|
110 | | - # Or use username/password authentication: |
111 | | - embeddings = ElasticsearchEmbeddings.from_credentials( |
112 | | - model_id, |
113 | | - es_url="http://localhost:9200", |
114 | | - es_user="elastic", |
115 | | - es_password="password" |
116 | | - ) |
117 | | -
|
118 | | - # Note: To use environment variables, read them yourself: |
119 | | - # import os |
120 | | - # embeddings = ElasticsearchEmbeddings.from_credentials( |
121 | | - # model_id, |
122 | | - # es_cloud_id=os.environ.get("ES_CLOUD_ID"), |
123 | | - # es_api_key=os.environ.get("ES_API_KEY"), |
124 | | - # ) |
125 | | -
|
126 | | - documents = [ |
127 | | - "This is an example document.", |
128 | | - "Another example document to generate embeddings for.", |
129 | | - ] |
130 | | - embeddings_generator.embed_documents(documents) |
| 178 | + client (AsyncElasticsearch or Elasticsearch, optional): |
| 179 | + Pre-existing Elasticsearch connection. Either provide this OR |
| 180 | + credentials. |
| 181 | + es_url (str, optional): URL of the Elasticsearch instance to connect to. |
| 182 | + es_cloud_id (str, optional): Cloud ID of the Elasticsearch instance. |
| 183 | + es_user (str, optional): Username to use when connecting to |
| 184 | + Elasticsearch. |
| 185 | + es_api_key (str, optional): API key to use when connecting to |
| 186 | + Elasticsearch. |
| 187 | + es_password (str, optional): Password to use when connecting to |
| 188 | + Elasticsearch. |
131 | 189 | """ |
| 190 | + # Accept either client OR credentials (one required) |
| 191 | + if client is not None: |
| 192 | + es_connection = client |
| 193 | + elif es_url is not None or es_cloud_id is not None: |
| 194 | + es_connection = create_async_elasticsearch_client( |
| 195 | + url=es_url, |
| 196 | + cloud_id=es_cloud_id, |
| 197 | + api_key=es_api_key, |
| 198 | + username=es_user, |
| 199 | + password=es_password, |
| 200 | + ) |
| 201 | + else: |
| 202 | + raise ValueError( |
| 203 | + "Either 'client' or credentials (es_url, es_cloud_id, etc.) " |
| 204 | + "must be provided." |
| 205 | + ) |
132 | 206 |
|
133 | | - # Connect to Elasticsearch using create_elasticsearch_client for consistency |
134 | | - es_connection = create_async_elasticsearch_client( |
135 | | - url=es_url, |
136 | | - cloud_id=es_cloud_id, |
137 | | - api_key=es_api_key, |
138 | | - username=es_user, |
139 | | - password=es_password, |
140 | | - params=es_params, |
141 | | - ) |
142 | | - return cls(es_connection, model_id, input_field=input_field) |
143 | | - |
144 | | - @classmethod |
145 | | - def from_es_connection( |
146 | | - cls, |
147 | | - model_id: str, |
148 | | - es_connection: AsyncElasticsearch, |
149 | | - input_field: str = "text_field", |
150 | | - ) -> AsyncElasticsearchEmbeddings: |
151 | | - """ |
152 | | - Instantiate embeddings from an existing Elasticsearch connection. |
153 | | -
|
154 | | - This method provides a way to create an instance of the ElasticsearchEmbeddings |
155 | | - class using an existing Elasticsearch connection. |
156 | | -
|
157 | | - Args: |
158 | | - model_id (str): The model_id of the model deployed in the Elasticsearch cluster. |
159 | | - es_connection (elasticsearch.Elasticsearch): An existing Elasticsearch |
160 | | - connection object. input_field (str, optional): The name of the key for the |
161 | | - input text field in the document. Defaults to 'text_field'. |
162 | | -
|
163 | | - Returns: |
164 | | - ElasticsearchEmbeddings: An instance of the ElasticsearchEmbeddings class. |
165 | | -
|
166 | | - Example: |
167 | | - .. code-block:: python |
168 | | -
|
169 | | - from elasticsearch import Elasticsearch |
170 | | -
|
171 | | - from langchain_elasticsearch.embeddings import ElasticsearchEmbeddings |
172 | | -
|
173 | | - # Define the model ID and input field name (if different from default) |
174 | | - model_id = "your_model_id" |
175 | | - # Optional, only if different from 'text_field' |
176 | | - input_field = "your_input_field" |
177 | | -
|
178 | | - # Create Elasticsearch connection |
179 | | - es_connection = Elasticsearch( |
180 | | - hosts=["localhost:9200"], http_auth=("user", "password") |
181 | | - ) |
182 | | -
|
183 | | - # Instantiate ElasticsearchEmbeddings using the existing connection |
184 | | - embeddings = ElasticsearchEmbeddings.from_es_connection( |
185 | | - model_id, |
186 | | - es_connection, |
187 | | - input_field=input_field, |
188 | | - ) |
189 | | -
|
190 | | - documents = [ |
191 | | - "This is an example document.", |
192 | | - "Another example document to generate embeddings for.", |
193 | | - ] |
194 | | - embeddings_generator.embed_documents(documents) |
195 | | - """ |
196 | | - return cls(es_connection, model_id, input_field=input_field) |
| 207 | + # Apply User-Agent for telemetry |
| 208 | + # (applies to both passed and internally created clients) |
| 209 | + self.client = async_with_user_agent_header(es_connection, "langchain-py-e") |
| 210 | + self.model_id = model_id |
| 211 | + self.input_field = input_field |
197 | 212 |
|
198 | 213 | async def _embedding_func(self, texts: List[str]) -> List[List[float]]: |
199 | 214 | """ |
|
0 commit comments