99import java .util .Optional ;
1010import java .util .Spliterator ;
1111import java .util .Spliterators ;
12+ import java .util .function .BiFunction ;
1213import java .util .function .Function ;
1314import java .util .stream .Stream ;
1415import java .util .stream .StreamSupport ;
@@ -46,6 +47,12 @@ public class ElasticsearchRetrievalChain extends RetrievalChain implements Close
4647 */
4748private final Function <String , ObjectNode > queryCreator ;
4849
50+ /**
51+ * Consumes an elasticsearch hit and the question and creates a document as a
52+ * result
53+ */
54+ private final BiFunction <ObjectNode , String , Map <String , String >> documentCreator ;
55+
4956/**
5057 * {@link ObjectMapper} used for query creation and document deserialization
5158 */
@@ -59,14 +66,31 @@ public class ElasticsearchRetrievalChain extends RetrievalChain implements Close
5966 * @param maxDocumentCount {@link #getMaxDocumentCount()}
6067 * @param objectMapper {@link #objectMapper}
6168 * @param queryCreator {@link #queryCreator}
69+ * @param documentCreator {@link #documentCreator}
6270 */
public ElasticsearchRetrievalChain(final String index, final RestClient restClient, final int maxDocumentCount,
        final ObjectMapper objectMapper, final Function<String, ObjectNode> queryCreator,
        final BiFunction<ObjectNode, String, Map<String, String>> documentCreator) {
    super(maxDocumentCount);

    // Elasticsearch target: the client and the index it is used with
    this.restClient = restClient;
    this.index = index;

    // json handling and the two pluggable mapping steps (question -> query,
    // hit -> document)
    this.objectMapper = objectMapper;
    this.queryCreator = queryCreator;
    this.documentCreator = documentCreator;
}
81+
/**
 * Creates an instance of {@link ElasticsearchRetrievalChain} which turns hits
 * into documents using the creator returned by
 * {@link #defaultDocumentCreator(ObjectMapper)}
 *
 * @param index            {@link #index}
 * @param restClient       {@link #restClient}
 * @param maxDocumentCount {@link #getMaxDocumentCount()}
 * @param objectMapper     {@link #objectMapper}, also passed to the default
 *                         document creator
 * @param queryCreator     {@link #queryCreator}
 */
public ElasticsearchRetrievalChain(final String index, final RestClient restClient, final int maxDocumentCount,
        final ObjectMapper objectMapper, final Function<String, ObjectNode> queryCreator) {
    // delegate to the full constructor, supplying the default hit -> document
    // mapping
    this(
            index,
            restClient,
            maxDocumentCount,
            objectMapper,
            queryCreator,
            defaultDocumentCreator(objectMapper));
}
7195
7296/**
@@ -144,24 +168,36 @@ public Stream<Map<String, String>> run(final String input) {
144168}
145169
146170return StreamSupport .stream (Spliterators .spliteratorUnknownSize (hits .iterator (), Spliterator .ORDERED ), false )
147- .map (ObjectNode .class ::cast ).map (o -> o .get ("_source" )).map (ObjectNode .class ::cast )
148- .map (source -> createDocument (source , input ));
171+ .map (ObjectNode .class ::cast ).map (hitNode -> documentCreator .apply (hitNode , input ));
149172}
150173
151- private Map <String , String > createDocument (final ObjectNode source , final String question ) {
152- final Map <String , Object > sourceMap = objectMapper .convertValue (source ,
153- new TypeReference <Map <String , Object >>() {
154- // noop
155- });
174+ /**
175+ * creates the default {@link #queryCreator}
176+ *
177+ * @param objectMapper the {@link ObjectMapper} used for json operations
178+ * @return {@link BiFunction} which consumes a hit node and the question and
179+ * produces a document consisting of all (key, value)-pairs of the hit's
180+ * _source object
181+ */
182+ public static BiFunction <ObjectNode , String , Map <String , String >> defaultDocumentCreator (
183+ final ObjectMapper objectMapper ) {
184+ return (hitObjectNode , question ) -> {
185+ final ObjectNode source = (ObjectNode ) hitObjectNode .get ("_source" );
156186
157- final Map <String , String > document = new HashMap <>();
158- document .put (PromptConstants .QUESTION , question );
187+ final Map <String , Object > sourceMap = objectMapper .convertValue (source ,
188+ new TypeReference <Map <String , Object >>() {
189+ // noop
190+ });
159191
160- for (final Entry <String , Object > sourceEntry : sourceMap .entrySet ()) {
161- document .put (sourceEntry .getKey (), sourceEntry .getValue ().toString ());
162- }
192+ final Map <String , String > document = new HashMap <>();
193+ document .put (PromptConstants .QUESTION , question );
194+
195+ for (final Entry <String , Object > sourceEntry : sourceMap .entrySet ()) {
196+ document .put (sourceEntry .getKey (), sourceEntry .getValue ().toString ());
197+ }
163198
164- return document ;
199+ return document ;
200+ };
165201}
166202
167203@ Override
0 commit comments