Merged
62 changes: 31 additions & 31 deletions api/core/rag/retrieval/dataset_retrieval.py
@@ -1474,38 +1474,38 @@ def _multiple_retrieve_thread(
if cancel_event and cancel_event.is_set():
break

# Skip second reranking when there is only one dataset
if reranking_enable and dataset_count > 1:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
# Skip second reranking when there is only one dataset
if reranking_enable and dataset_count > 1:
# do rerank for searched documents
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
if query:
all_documents_item = data_post_processor.invoke(
query=query,
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.TEXT_QUERY,
)
if attachment_id:
all_documents_item = data_post_processor.invoke(
documents=all_documents_item,
score_threshold=score_threshold,
top_n=top_k,
query_type=QueryType.IMAGE_QUERY,
query=attachment_id,
)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
if index_type == IndexTechniqueType.ECONOMY:
if not query:
all_documents_item = []
else:
all_documents_item = self.calculate_keyword_score(query, all_documents_item, top_k)
elif index_type == IndexTechniqueType.HIGH_QUALITY:
all_documents_item = self.calculate_vector_score(all_documents_item, top_k, score_threshold)
else:
all_documents_item = all_documents_item[:top_k] if top_k else all_documents_item
if all_documents_item:
all_documents.extend(all_documents_item)
except Exception as e:
if cancel_event:
cancel_event.set()
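For context on the change above: Flask's `current_app` is a proxy that only resolves while an app context is pushed on the current thread, which is why the rerank step (and the `DataPostProcessor` it constructs) has to run inside the `with flask_app.app_context():` block rather than after it exits. A minimal standalone sketch of that behavior, illustrative only and not code from this PR:

import threading

from flask import Flask, current_app

app = Flask(__name__)

def worker() -> None:
    with app.app_context():
        # Inside the block the proxy resolves to `app`.
        print(current_app.name)
    try:
        # After the block exits, the context has been popped on this
        # thread, so touching the proxy raises RuntimeError
        # ("Working outside of application context").
        print(current_app.name)
    except RuntimeError as exc:
        print(f"outside context: {exc}")

t = threading.Thread(target=worker)
t.start()
t.join()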
113 changes: 113 additions & 0 deletions api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py
@@ -0,0 +1,113 @@
import threading
from unittest.mock import Mock, patch
from uuid import uuid4

import pytest
from flask import Flask, current_app

from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from models.dataset import Dataset


class TestRetrievalService:
@pytest.fixture
def mock_dataset(self) -> Dataset:
dataset = Mock(spec=Dataset)
dataset.id = str(uuid4())
dataset.tenant_id = str(uuid4())
dataset.name = "test_dataset"
dataset.indexing_technique = "high_quality"
dataset.provider = "dify"
return dataset

def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset):
"""
Repro test for current bug:
reranking runs after `with flask_app.app_context():` exits.
`_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`,
so we must assert from that list (not from an outer try/except).
"""
dataset_retrieval = DatasetRetrieval()
flask_app = Flask(__name__)
tenant_id = str(uuid4())

# second dataset so the dataset_count > 1 reranking branch runs
secondary_dataset = Mock(spec=Dataset)
secondary_dataset.id = str(uuid4())
secondary_dataset.provider = "dify"
secondary_dataset.indexing_technique = "high_quality"

# retriever returns 1 doc into internal list (all_documents_item)
document = Document(
page_content="Context aware doc",
metadata={
"doc_id": "doc1",
"score": 0.95,
"document_id": str(uuid4()),
"dataset_id": mock_dataset.id,
},
provider="dify",
)

def fake_retriever(
flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids
):
all_documents.append(document)

called = {"init": 0, "invoke": 0}

class ContextRequiredPostProcessor:
def __init__(self, *args, **kwargs):
called["init"] += 1
# will raise RuntimeError if no Flask app context exists
_ = current_app.name

def invoke(self, *args, **kwargs):
called["invoke"] += 1
_ = current_app.name
return kwargs.get("documents") or args[1]

# output list from _multiple_retrieve_thread
all_documents: list[Document] = []

# IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here
thread_exceptions: list[Exception] = []

def target():
with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever):
with patch(
"core.rag.retrieval.dataset_retrieval.DataPostProcessor",
ContextRequiredPostProcessor,
):
dataset_retrieval._multiple_retrieve_thread(
flask_app=flask_app,
available_datasets=[mock_dataset, secondary_dataset],
metadata_condition=None,
metadata_filter_document_ids=None,
all_documents=all_documents,
tenant_id=tenant_id,
reranking_enable=True,
reranking_mode="reranking_model",
reranking_model={
"reranking_provider_name": "cohere",
"reranking_model_name": "rerank-v2",
},
weights=None,
top_k=3,
score_threshold=0.0,
query="test query",
attachment_id=None,
dataset_count=2, # force reranking branch
thread_exceptions=thread_exceptions,  # exceptions are collected here, not raised
)

t = threading.Thread(target=target)
t.start()
t.join()

# Ensure reranking branch was actually executed
assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run."

# With the fix, reranking runs inside the app context, so no exception
# should have been recorded
assert not thread_exceptions, thread_exceptions
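A note on the `thread_exceptions` pattern the test leans on: the worker appends failures to a shared list instead of letting them propagate, so callers have to inspect that list after `join()`. A minimal sketch of the collect-instead-of-raise convention, assuming only the standard library (the `run_collecting` helper is hypothetical, not part of this codebase):

import threading

def run_collecting(target, thread_exceptions: list[Exception]) -> threading.Thread:
    # Hypothetical helper: run `target` on a thread and append any
    # exception to `thread_exceptions` instead of raising it, mirroring
    # how `_multiple_retrieve_thread` reports failures.
    def wrapper() -> None:
        try:
            target()
        except Exception as exc:
            thread_exceptions.append(exc)

    thread = threading.Thread(target=wrapper)
    thread.start()
    return thread

# Usage: join first, then assert from the list, as the test above does.
errors: list[Exception] = []
t = run_collecting(lambda: 1 / 0, errors)
t.join()
assert errors and isinstance(errors[0], ZeroDivisionError)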