Skip to content

Commit 7ebd10e

Browse files
committed
format code with black
1 parent fac165c commit 7ebd10e

33 files changed

+497
-314
lines changed

kag/common/conf.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@ def update_conf(self, configs: dict):
207207
"""
208208
KAG_QA_TASK_CONFIG stores per-task configuration and should be cleaned up after use.
209209
"""
210-
KAG_QA_TASK_CONFIG = knext.common.cache.LinkCache(maxsize=100, ttl=300)
210+
KAG_QA_TASK_CONFIG = knext.common.cache.LinkCache(maxsize=100, ttl=300)
211+
211212

212213
class KAGConfigAccessor:
213214
@staticmethod
@@ -248,7 +249,11 @@ def init_env(config_file: str = None):
248249
project_id = os.getenv(KAGConstants.ENV_KAG_PROJECT_ID)
249250
host_addr = os.getenv(KAGConstants.ENV_KAG_PROJECT_HOST_ADDR)
250251
prod = False
251-
if project_id is not None and host_addr is not None and not validate_config_file(config_file):
252+
if (
253+
project_id is not None
254+
and host_addr is not None
255+
and not validate_config_file(config_file)
256+
):
252257
prod = True
253258
global KAG_CONFIG
254259
KAG_CONFIG.initialize(prod, config_file)

kag/common/tools/algorithm_tool/chunk_retriever/atomic_query_chunk_retriever.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -256,15 +256,11 @@ def invoke(self, task: Task, **kwargs) -> RetrieverOutput:
256256

257257
chunks = chunks + query_text_related_chunks
258258

259-
out = RetrieverOutput(
260-
retriever_method=self.name, chunks=chunks
261-
)
259+
out = RetrieverOutput(retriever_method=self.name, chunks=chunks)
262260
return out
263261
except Exception as e:
264262
logger.error(f"run calculate_sim_scores failed, info: {e}", exc_info=True)
265-
return RetrieverOutput(
266-
retriever_method=self.name, err_msg=str(e)
267-
)
263+
return RetrieverOutput(retriever_method=self.name, err_msg=str(e))
268264

269265
def schema(self):
270266
return {

kag/common/tools/algorithm_tool/chunk_retriever/outline_chunk_retriever.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -156,23 +156,17 @@ def invoke(self, task, **kwargs) -> RetrieverOutput:
156156
)
157157

158158
# to retrieve output
159-
out = RetrieverOutput(
160-
retriever_method=self.name, chunks=chunks
161-
)
159+
out = RetrieverOutput(retriever_method=self.name, chunks=chunks)
162160
chunk_cached_by_query_map.put(query, out)
163161
return out
164162

165163
except Exception as e:
166164
logger.error(f"run calculate_sim_scores failed, info: {e}", exc_info=True)
167-
return RetrieverOutput(
168-
retriever_method=self.name, err_msg=str(e)
169-
)
165+
return RetrieverOutput(retriever_method=self.name, err_msg=str(e))
170166

171167
@property
172168
def input_indices(self):
173169
return ["Outline"]
174170

175171
def schema(self):
176-
return {
177-
"name": "outline_chunk_retriever"
178-
}
172+
return {"name": "outline_chunk_retriever"}

kag/common/tools/algorithm_tool/chunk_retriever/ppr_chunk_retriever.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,7 @@ def invoke(self, task, **kwargs) -> RetrieverOutput:
309309
properties=node,
310310
)
311311
)
312-
return RetrieverOutput(
313-
retriever_method=self.name, chunks=matched_docs
314-
)
312+
return RetrieverOutput(retriever_method=self.name, chunks=matched_docs)
315313

316314
def schema(self):
317315
return {

kag/common/tools/algorithm_tool/chunk_retriever/rc_retriever.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,12 @@ def invoke(self, task, **kwargs) -> RetrieverOutput:
5252
output.retriever_method = self.name
5353
except Exception as e:
5454
logger.error(e, exc_info=True)
55-
output = RetrieverOutput(
56-
retriever_method=self.name, err_msg=f"{task} {e}"
57-
)
55+
output = RetrieverOutput(retriever_method=self.name, err_msg=f"{task} {e}")
5856
logger.debug(
5957
f"{self.schema().get('name', '')} `{task.arguments['query']}` Retrieved chunks num: {len(output.chunks)} cost={time.time() - start_time}"
6058
)
6159
return output
60+
6261
def schema(self):
6362
return {
6463
"name": "kg_rc_retriever",

kag/common/tools/algorithm_tool/chunk_retriever/summary_chunk_retriever.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121
VectorizeModelABC,
2222
ChunkData,
2323
RetrieverOutput,
24-
EntityData, Task,
24+
EntityData,
25+
Task,
2526
)
2627
from kag.interface.solver.model.schema_utils import SchemaUtils
2728
from kag.common.config import LogicFormConfiguration
@@ -36,13 +37,13 @@
3637
@RetrieverABC.register("summary_chunk_retriever")
3738
class SummaryChunkRetriever(RetrieverABC):
3839
def __init__(
39-
self,
40-
vectorize_model: VectorizeModelABC = None,
41-
search_api: SearchApiABC = None,
42-
graph_api: GraphApiABC = None,
43-
top_k: int = 10,
44-
score_threshold=0.85,
45-
**kwargs,
40+
self,
41+
vectorize_model: VectorizeModelABC = None,
42+
search_api: SearchApiABC = None,
43+
graph_api: GraphApiABC = None,
44+
top_k: int = 10,
45+
score_threshold=0.85,
46+
**kwargs,
4647
):
4748
super().__init__(top_k, **kwargs)
4849
self.vectorize_model = vectorize_model or VectorizeModelABC.from_config(
@@ -69,12 +70,12 @@ def _get_summaries(self, query, top_k) -> List[str]:
6970

7071
# recall top_k summaries
7172
top_k_summaries = self.search_api.search_vector(
72-
label=self.schema_helper.get_label_within_prefix("Summary"),
73-
property_key="content",
74-
query_vector=query_vector,
75-
topk=top_k,
76-
ef_search=top_k * 3,
77-
)
73+
label=self.schema_helper.get_label_within_prefix("Summary"),
74+
property_key="content",
75+
query_vector=query_vector,
76+
topk=top_k,
77+
ef_search=top_k * 3,
78+
)
7879
for item in top_k_summaries:
7980
topk_summary_ids.append(item["node"]["id"])
8081

@@ -172,23 +173,17 @@ def invoke(self, task: Task, **kwargs) -> RetrieverOutput:
172173
)
173174

174175
# to retrieve output
175-
out = RetrieverOutput(
176-
chunks=chunks, retriever_method=self.name
177-
)
176+
out = RetrieverOutput(chunks=chunks, retriever_method=self.name)
178177
chunk_cached_by_query_map.put(query, out)
179178
return out
180179

181180
except Exception as e:
182181
logger.error(f"run calculate_sim_scores failed, info: {e}", exc_info=True)
183-
return RetrieverOutput(
184-
retriever_method=self.name, err_msg=str(e)
185-
)
182+
return RetrieverOutput(retriever_method=self.name, err_msg=str(e))
186183

187184
@property
188185
def input_indices(self):
189186
return ["Summary"]
190187

191188
def schema(self):
192-
return {
193-
"name": "summary_chunk_retriever"
194-
}
189+
return {"name": "summary_chunk_retriever"}

kag/common/tools/algorithm_tool/chunk_retriever/table_retriever.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,23 +153,17 @@ def invoke(self, task, **kwargs) -> RetrieverOutput:
153153
chunks = self.get_related_chunks(topk_table_ids)
154154

155155
# to retrieve output
156-
out = RetrieverOutput(
157-
retriever_method=self.name, chunks=chunks
158-
)
156+
out = RetrieverOutput(retriever_method=self.name, chunks=chunks)
159157
chunk_cached_by_query_map.put(query, out)
160158
return out
161159

162160
except Exception as e:
163161
logger.error(f"run calculate_sim_scores failed, info: {e}", exc_info=True)
164-
return RetrieverOutput(
165-
retriever_method=self.name, err_msg=str(e)
166-
)
162+
return RetrieverOutput(retriever_method=self.name, err_msg=str(e))
167163

168164
@property
169165
def input_indices(self):
170166
return ["Outline"]
171167

172168
def schema(self):
173-
return {
174-
"name": "table_retriever"
175-
}
169+
return {"name": "table_retriever"}

kag/common/tools/algorithm_tool/chunk_retriever/text_chunk_retriever.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,10 @@ def invoke(self, task, **kwargs) -> RetrieverOutput:
4848
score=score,
4949
)
5050
)
51-
return RetrieverOutput(
52-
chunks=chunks, retriever_method=self.name
53-
)
51+
return RetrieverOutput(chunks=chunks, retriever_method=self.name)
5452
except Exception as e:
5553
logger.error(f"run calculate_sim_scores failed, info: {e}", exc_info=True)
56-
return RetrieverOutput(
57-
retriever_method=self.name, err_msg=str(e)
58-
)
54+
return RetrieverOutput(retriever_method=self.name, err_msg=str(e))
5955

6056
def schema(self):
6157
return {

kag/common/tools/algorithm_tool/chunk_retriever/vector_chunk_retriever.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,11 @@ def invoke(self, task, **kwargs) -> RetrieverOutput:
8888
score = item.get("score", 0.0)
8989
if score >= self.score_threshold:
9090
chunk = ChunkData(
91-
content=item["node"].get("content", ""),
92-
title=item["node"]["name"],
93-
chunk_id=item["node"]["id"],
94-
score=score,
95-
)
91+
content=item["node"].get("content", ""),
92+
title=item["node"]["name"],
93+
chunk_id=item["node"]["id"],
94+
score=score,
95+
)
9696
if chunk.chunk_id not in merged:
9797
merged[chunk.chunk_id] = score
9898
if merged[chunk.chunk_id] < score:
@@ -111,17 +111,13 @@ def invoke(self, task, **kwargs) -> RetrieverOutput:
111111
score=score,
112112
)
113113
)
114-
out = RetrieverOutput(
115-
chunks=chunks, retriever_method=self.name
116-
)
114+
out = RetrieverOutput(chunks=chunks, retriever_method=self.name)
117115
chunk_cached_by_query_map.put(query, out)
118116
return out
119117

120118
except Exception as e:
121119
logger.error(f"run calculate_sim_scores failed, info: {e}", exc_info=True)
122-
return RetrieverOutput(
123-
retriever_method=self.name, err_msg=str(e)
124-
)
120+
return RetrieverOutput(retriever_method=self.name, err_msg=str(e))
125121

126122
def schema(self):
127123
return {

kag/common/tools/algorithm_tool/graph_retriever/entity_linking.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,9 @@ def invoke(self, query, name, type_name, topk_k=None, **kwargs) -> List[EntityDa
222222
# 6. Final filtering and sorting
223223
if not topk_k:
224224
topk_k = self.top_k
225-
recognition_threshold = kwargs.get("recognition_threshold", self.recognition_threshold)
225+
recognition_threshold = kwargs.get(
226+
"recognition_threshold", self.recognition_threshold
227+
)
226228
retdata = []
227229
if name is None:
228230
return retdata
@@ -260,7 +262,7 @@ def invoke(self, query, name, type_name, topk_k=None, **kwargs) -> List[EntityDa
260262
else:
261263
break
262264

263-
return retdata[: topk_k]
265+
return retdata[:topk_k]
264266
except Exception as e:
265267
logger.error(
266268
f"Error in entity_linking {query} name={name} type={type_name}: {e}",

0 commit comments

Comments
 (0)