Skip to content

Commit 131b436

Browse files
committed
added custom document creation from hit to ElasticsearchRetrievalChain
1 parent 0c9cf38 commit 131b436

File tree

1 file changed

+50
-14
lines changed

1 file changed

+50
-14
lines changed

src/main/java/com/github/hakenadu/javalangchains/chains/data/retrieval/ElasticsearchRetrievalChain.java

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import java.util.Optional;
1010
import java.util.Spliterator;
1111
import java.util.Spliterators;
12+
import java.util.function.BiFunction;
1213
import java.util.function.Function;
1314
import java.util.stream.Stream;
1415
import java.util.stream.StreamSupport;
@@ -46,6 +47,12 @@ public class ElasticsearchRetrievalChain extends RetrievalChain implements Close
4647
*/
4748
private final Function<String, ObjectNode> queryCreator;
4849

50+
/**
51+
* Consumes an elasticsearch hit and the question and creates a document as a
52+
* result
53+
*/
54+
private final BiFunction<ObjectNode, String, Map<String, String>> documentCreator;
55+
4956
/**
5057
* {@link ObjectMapper} used for query creation and document deserialization
5158
*/
@@ -59,14 +66,31 @@ public class ElasticsearchRetrievalChain extends RetrievalChain implements Close
5966
* @param maxDocumentCount {@link #getMaxDocumentCount()}
6067
* @param objectMapper {@link #objectMapper}
6168
* @param queryCreator {@link #queryCreator}
69+
* @param documentCreator {@link #documentCreator}
6270
*/
6371
public ElasticsearchRetrievalChain(final String index, final RestClient restClient, final int maxDocumentCount,
64-
final ObjectMapper objectMapper, final Function<String, ObjectNode> queryCreator) {
72+
final ObjectMapper objectMapper, final Function<String, ObjectNode> queryCreator,
73+
final BiFunction<ObjectNode, String, Map<String, String>> documentCreator) {
6574
super(maxDocumentCount);
6675
this.index = index;
6776
this.restClient = restClient;
6877
this.objectMapper = objectMapper;
6978
this.queryCreator = queryCreator;
79+
this.documentCreator = documentCreator;
80+
}
81+
82+
/**
83+
* Creates an instance of {@link ElasticsearchRetrievalChain}
84+
*
85+
* @param index {@link #index}
86+
* @param restClient {@link #restClient}
87+
* @param maxDocumentCount {@link #getMaxDocumentCount()}
88+
* @param objectMapper {@link #objectMapper}
89+
* @param queryCreator {@link #queryCreator}
90+
*/
91+
public ElasticsearchRetrievalChain(final String index, final RestClient restClient, final int maxDocumentCount,
92+
final ObjectMapper objectMapper, final Function<String, ObjectNode> queryCreator) {
93+
this(index, restClient, maxDocumentCount, objectMapper, queryCreator, defaultDocumentCreator(objectMapper));
7094
}
7195

7296
/**
@@ -144,24 +168,36 @@ public Stream<Map<String, String>> run(final String input) {
144168
}
145169

146170
return StreamSupport.stream(Spliterators.spliteratorUnknownSize(hits.iterator(), Spliterator.ORDERED), false)
147-
.map(ObjectNode.class::cast).map(o -> o.get("_source")).map(ObjectNode.class::cast)
148-
.map(source -> createDocument(source, input));
171+
.map(ObjectNode.class::cast).map(hitNode -> documentCreator.apply(hitNode, input));
149172
}
150173

151-
private Map<String, String> createDocument(final ObjectNode source, final String question) {
152-
final Map<String, Object> sourceMap = objectMapper.convertValue(source,
153-
new TypeReference<Map<String, Object>>() {
154-
// noop
155-
});
174+
/**
175+
* creates the default {@link #queryCreator}
176+
*
177+
* @param objectMapper the {@link ObjectMapper} used for json operations
178+
* @return {@link BiFunction} which consumes a hit node and the question and
179+
* produces a document consisting of all (key, value)-pairs of the hit's
180+
* _source object
181+
*/
182+
public static BiFunction<ObjectNode, String, Map<String, String>> defaultDocumentCreator(
183+
final ObjectMapper objectMapper) {
184+
return (hitObjectNode, question) -> {
185+
final ObjectNode source = (ObjectNode) hitObjectNode.get("_source");
156186

157-
final Map<String, String> document = new HashMap<>();
158-
document.put(PromptConstants.QUESTION, question);
187+
final Map<String, Object> sourceMap = objectMapper.convertValue(source,
188+
new TypeReference<Map<String, Object>>() {
189+
// noop
190+
});
159191

160-
for (final Entry<String, Object> sourceEntry : sourceMap.entrySet()) {
161-
document.put(sourceEntry.getKey(), sourceEntry.getValue().toString());
162-
}
192+
final Map<String, String> document = new HashMap<>();
193+
document.put(PromptConstants.QUESTION, question);
194+
195+
for (final Entry<String, Object> sourceEntry : sourceMap.entrySet()) {
196+
document.put(sourceEntry.getKey(), sourceEntry.getValue().toString());
197+
}
163198

164-
return document;
199+
return document;
200+
};
165201
}
166202

167203
@Override

0 commit comments

Comments
 (0)