Skip to content

Commit bddc2e4

Browse files
authored
[Inference API] Add Google Vertex AI Rerank support (elastic#110273)
1 parent 258a8b5 commit bddc2e4

File tree

33 files changed

+1975
-18
lines changed

33 files changed

+1975
-18
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ static TransportVersion def(int id) {
206206
public static final TransportVersion SECURITY_MIGRATIONS_MIGRATION_NEEDED_ADDED = def(8_697_00_0);
207207
public static final TransportVersion K_FOR_KNN_QUERY_ADDED = def(8_698_00_0);
208208
public static final TransportVersion TEXT_SIMILARITY_RERANKER_RETRIEVER = def(8_699_00_0);
209+
public static final TransportVersion ML_INFERENCE_GOOGLE_VERTEX_AI_RERANKING_ADDED = def(8_700_00_0);
209210

210211
/*
211212
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
5454
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsServiceSettings;
5555
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsTaskSettings;
56+
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankServiceSettings;
57+
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankTaskSettings;
5658
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
5759
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
5860
import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings;
@@ -314,6 +316,22 @@ private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry
314316
GoogleVertexAiEmbeddingsTaskSettings::new
315317
)
316318
);
319+
320+
namedWriteables.add(
321+
new NamedWriteableRegistry.Entry(
322+
ServiceSettings.class,
323+
GoogleVertexAiRerankServiceSettings.NAME,
324+
GoogleVertexAiRerankServiceSettings::new
325+
)
326+
);
327+
328+
namedWriteables.add(
329+
new NamedWriteableRegistry.Entry(
330+
TaskSettings.class,
331+
GoogleVertexAiRerankTaskSettings.NAME,
332+
GoogleVertexAiRerankTaskSettings::new
333+
)
334+
);
317335
}
318336

319337
private static void addInternalElserNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionCreator.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1212
import org.elasticsearch.xpack.inference.services.ServiceComponents;
1313
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
14+
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
1415

1516
import java.util.Map;
1617
import java.util.Objects;
@@ -30,4 +31,9 @@ public GoogleVertexAiActionCreator(Sender sender, ServiceComponents serviceCompo
3031
public ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings) {
3132
return new GoogleVertexAiEmbeddingsAction(sender, model, serviceComponents);
3233
}
34+
35+
@Override
36+
public ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings) {
37+
return new GoogleVertexAiRerankAction(sender, model, serviceComponents.threadPool());
38+
}
3339
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googlevertexai/GoogleVertexAiActionVisitor.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@
99

1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1111
import org.elasticsearch.xpack.inference.services.googlevertexai.embeddings.GoogleVertexAiEmbeddingsModel;
12+
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
1213

1314
import java.util.Map;
1415

1516
public interface GoogleVertexAiActionVisitor {
1617

1718
ExecutableAction create(GoogleVertexAiEmbeddingsModel model, Map<String, Object> taskSettings);
1819

20+
ExecutableAction create(GoogleVertexAiRerankModel model, Map<String, Object> taskSettings);
1921
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action.googlevertexai;
9+
10+
import org.elasticsearch.ElasticsearchException;
11+
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.core.TimeValue;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
16+
import org.elasticsearch.xpack.inference.external.http.sender.GoogleVertexAiRerankRequestManager;
17+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
18+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
19+
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
20+
21+
import java.util.Objects;
22+
23+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
24+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
25+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
26+
27+
public class GoogleVertexAiRerankAction implements ExecutableAction {
28+
29+
private final String failedToSendRequestErrorMessage;
30+
31+
private final Sender sender;
32+
33+
private final GoogleVertexAiRerankRequestManager requestManager;
34+
35+
public GoogleVertexAiRerankAction(Sender sender, GoogleVertexAiRerankModel model, ThreadPool threadPool) {
36+
Objects.requireNonNull(model);
37+
this.sender = Objects.requireNonNull(sender);
38+
this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "Google Vertex AI rerank");
39+
this.requestManager = GoogleVertexAiRerankRequestManager.of(model, threadPool);
40+
}
41+
42+
@Override
43+
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
44+
try {
45+
ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(
46+
failedToSendRequestErrorMessage,
47+
listener
48+
);
49+
sender.send(requestManager, inferenceInputs, timeout, wrappedListener);
50+
} catch (ElasticsearchException e) {
51+
listener.onFailure(e);
52+
} catch (Exception e) {
53+
listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage));
54+
}
55+
}
56+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,19 @@ private static ResponseHandler createEmbeddingsHandler() {
4141
private final Truncator truncator;
4242

4343
public GoogleVertexAiEmbeddingsRequestManager(GoogleVertexAiEmbeddingsModel model, Truncator truncator, ThreadPool threadPool) {
44-
super(threadPool, model);
44+
super(threadPool, model, RateLimitGrouping.of(model));
4545
this.model = Objects.requireNonNull(model);
4646
this.truncator = Objects.requireNonNull(truncator);
4747
}
4848

49+
record RateLimitGrouping(int modelIdHash) {
50+
public static RateLimitGrouping of(GoogleVertexAiEmbeddingsModel model) {
51+
Objects.requireNonNull(model);
52+
53+
return new RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode());
54+
}
55+
}
56+
4957
@Override
5058
public void execute(
5159
String query,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRequestManager.java

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,9 @@
1010
import org.elasticsearch.threadpool.ThreadPool;
1111
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiModel;
1212

13-
import java.util.Objects;
14-
1513
public abstract class GoogleVertexAiRequestManager extends BaseRequestManager {
1614

17-
GoogleVertexAiRequestManager(ThreadPool threadPool, GoogleVertexAiModel model) {
18-
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
19-
}
20-
21-
record RateLimitGrouping(int modelIdHash) {
22-
public static RateLimitGrouping of(GoogleVertexAiModel model) {
23-
Objects.requireNonNull(model);
24-
25-
return new RateLimitGrouping(model.rateLimitServiceSettings().modelId().hashCode());
26-
}
15+
GoogleVertexAiRequestManager(ThreadPool threadPool, GoogleVertexAiModel model, Object rateLimitGroup) {
16+
super(threadPool, model.getInferenceEntityId(), rateLimitGroup, model.rateLimitServiceSettings().rateLimitSettings());
2717
}
2818
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.http.sender;
9+
10+
import org.apache.logging.log4j.LogManager;
11+
import org.apache.logging.log4j.Logger;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.threadpool.ThreadPool;
15+
import org.elasticsearch.xpack.inference.external.googlevertexai.GoogleVertexAiResponseHandler;
16+
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
17+
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
18+
import org.elasticsearch.xpack.inference.external.request.googlevertexai.GoogleVertexAiRerankRequest;
19+
import org.elasticsearch.xpack.inference.external.response.googlevertexai.GoogleVertexAiRerankResponseEntity;
20+
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
21+
22+
import java.util.List;
23+
import java.util.Objects;
24+
import java.util.function.Supplier;
25+
26+
public class GoogleVertexAiRerankRequestManager extends GoogleVertexAiRequestManager {
27+
28+
private static final Logger logger = LogManager.getLogger(GoogleVertexAiRerankRequestManager.class);
29+
30+
private static final ResponseHandler HANDLER = createGoogleVertexAiResponseHandler();
31+
32+
private static ResponseHandler createGoogleVertexAiResponseHandler() {
33+
return new GoogleVertexAiResponseHandler(
34+
"Google Vertex AI rerank",
35+
(request, response) -> GoogleVertexAiRerankResponseEntity.fromResponse(response)
36+
);
37+
}
38+
39+
public static GoogleVertexAiRerankRequestManager of(GoogleVertexAiRerankModel model, ThreadPool threadPool) {
40+
return new GoogleVertexAiRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool));
41+
}
42+
43+
private final GoogleVertexAiRerankModel model;
44+
45+
private GoogleVertexAiRerankRequestManager(GoogleVertexAiRerankModel model, ThreadPool threadPool) {
46+
super(threadPool, model, RateLimitGrouping.of(model));
47+
this.model = model;
48+
}
49+
50+
record RateLimitGrouping(int projectIdHash) {
51+
public static RateLimitGrouping of(GoogleVertexAiRerankModel model) {
52+
Objects.requireNonNull(model);
53+
54+
return new RateLimitGrouping(model.rateLimitServiceSettings().projectId().hashCode());
55+
}
56+
}
57+
58+
@Override
59+
public void execute(
60+
String query,
61+
List<String> input,
62+
RequestSender requestSender,
63+
Supplier<Boolean> hasRequestCompletedFunction,
64+
ActionListener<InferenceServiceResults> listener
65+
) {
66+
GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(query, input, model);
67+
68+
execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
69+
}
70+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.googlevertexai;
9+
10+
import org.apache.http.HttpHeaders;
11+
import org.apache.http.client.methods.HttpPost;
12+
import org.apache.http.entity.ByteArrayEntity;
13+
import org.elasticsearch.common.Strings;
14+
import org.elasticsearch.xcontent.XContentType;
15+
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
16+
import org.elasticsearch.xpack.inference.external.request.Request;
17+
import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel;
18+
19+
import java.net.URI;
20+
import java.nio.charset.StandardCharsets;
21+
import java.util.List;
22+
import java.util.Objects;
23+
24+
public class GoogleVertexAiRerankRequest implements GoogleVertexAiRequest {
25+
26+
private final GoogleVertexAiRerankModel model;
27+
28+
private final String query;
29+
30+
private final List<String> input;
31+
32+
public GoogleVertexAiRerankRequest(String query, List<String> input, GoogleVertexAiRerankModel model) {
33+
this.model = Objects.requireNonNull(model);
34+
this.query = Objects.requireNonNull(query);
35+
this.input = Objects.requireNonNull(input);
36+
}
37+
38+
@Override
39+
public HttpRequest createHttpRequest() {
40+
HttpPost httpPost = new HttpPost(model.uri());
41+
42+
ByteArrayEntity byteEntity = new ByteArrayEntity(
43+
Strings.toString(
44+
new GoogleVertexAiRerankRequestEntity(query, input, model.getServiceSettings().modelId(), model.getTaskSettings().topN())
45+
).getBytes(StandardCharsets.UTF_8)
46+
);
47+
48+
httpPost.setEntity(byteEntity);
49+
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
50+
51+
decorateWithAuth(httpPost);
52+
53+
return new HttpRequest(httpPost, getInferenceEntityId());
54+
}
55+
56+
public void decorateWithAuth(HttpPost httpPost) {
57+
GoogleVertexAiRequest.decorateWithBearerToken(httpPost, model.getSecretSettings());
58+
}
59+
60+
public GoogleVertexAiRerankModel model() {
61+
return model;
62+
}
63+
64+
@Override
65+
public String getInferenceEntityId() {
66+
return model.getInferenceEntityId();
67+
}
68+
69+
@Override
70+
public URI getURI() {
71+
return model.uri();
72+
}
73+
74+
@Override
75+
public Request truncate() {
76+
return this;
77+
}
78+
79+
@Override
80+
public boolean[] getTruncationInfo() {
81+
return null;
82+
}
83+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.request.googlevertexai;
9+
10+
import org.elasticsearch.core.Nullable;
11+
import org.elasticsearch.xcontent.ToXContentObject;
12+
import org.elasticsearch.xcontent.XContentBuilder;
13+
14+
import java.io.IOException;
15+
import java.util.List;
16+
import java.util.Objects;
17+
18+
public record GoogleVertexAiRerankRequestEntity(String query, List<String> inputs, @Nullable String model, @Nullable Integer topN)
19+
implements
20+
ToXContentObject {
21+
22+
private static final String MODEL_FIELD = "model";
23+
private static final String QUERY_FIELD = "query";
24+
private static final String RECORDS_FIELD = "records";
25+
private static final String ID_FIELD = "id";
26+
27+
private static final String CONTENT_FIELD = "content";
28+
private static final String TOP_N_FIELD = "topN";
29+
30+
public GoogleVertexAiRerankRequestEntity {
31+
Objects.requireNonNull(query);
32+
Objects.requireNonNull(inputs);
33+
}
34+
35+
@Override
36+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
37+
builder.startObject();
38+
39+
if (model != null) {
40+
builder.field(MODEL_FIELD, model);
41+
}
42+
43+
builder.field(QUERY_FIELD, query);
44+
45+
builder.startArray(RECORDS_FIELD);
46+
47+
for (int recordId = 0; recordId < inputs.size(); recordId++) {
48+
builder.startObject();
49+
50+
{
51+
builder.field(ID_FIELD, String.valueOf(recordId));
52+
builder.field(CONTENT_FIELD, inputs.get(recordId));
53+
}
54+
55+
builder.endObject();
56+
}
57+
58+
builder.endArray();
59+
60+
if (topN != null) {
61+
builder.field(TOP_N_FIELD, topN);
62+
}
63+
64+
builder.endObject();
65+
66+
return builder;
67+
}
68+
}

0 commit comments

Comments
 (0)