Skip to content
22 changes: 14 additions & 8 deletions api/controllers/console/app/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,10 +343,13 @@ def get(self, app_model):
)

if args.keyword:
from libs.helper import escape_like_pattern

escaped_keyword = escape_like_pattern(args.keyword)
query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
Message.query.ilike(f"%{args.keyword}%"),
Message.answer.ilike(f"%{args.keyword}%"),
Message.query.ilike(f"%{escaped_keyword}%", escape="\\"),
Message.answer.ilike(f"%{escaped_keyword}%", escape="\\"),
)
)

Expand Down Expand Up @@ -455,7 +458,10 @@ def get(self, app_model):
query = sa.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.is_deleted.is_(False))

if args.keyword:
keyword_filter = f"%{args.keyword}%"
from libs.helper import escape_like_pattern

escaped_keyword = escape_like_pattern(args.keyword)
keyword_filter = f"%{escaped_keyword}%"
query = (
query.join(
Message,
Expand All @@ -464,11 +470,11 @@ def get(self, app_model):
.join(subquery, subquery.c.conversation_id == Conversation.id)
.where(
or_(
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
Conversation.name.ilike(keyword_filter),
Conversation.introduction.ilike(keyword_filter),
subquery.c.from_end_user_session_id.ilike(keyword_filter),
Message.query.ilike(keyword_filter, escape="\\"),
Message.answer.ilike(keyword_filter, escape="\\"),
Conversation.name.ilike(keyword_filter, escape="\\"),
Conversation.introduction.ilike(keyword_filter, escape="\\"),
subquery.c.from_end_user_session_id.ilike(keyword_filter, escape="\\"),
),
)
.group_by(Conversation.id)
Expand Down
9 changes: 6 additions & 3 deletions api/controllers/console/datasets/datasets_segments.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
Expand Down Expand Up @@ -145,6 +146,8 @@ def get(self, dataset_id, document_id):
query = query.where(DocumentSegment.hit_count >= hit_count_gte)

if keyword:
# Escape special characters in keyword to prevent SQL injection via LIKE wildcards
escaped_keyword = escape_like_pattern(keyword)
# Search in both content and keywords fields
# Use database-specific methods for JSON array search
if dify_config.SQLALCHEMY_DATABASE_URI_SCHEME == "postgresql":
Expand All @@ -156,15 +159,15 @@ def get(self, dataset_id, document_id):
.scalar_subquery()
),
",",
).ilike(f"%{keyword}%")
).ilike(f"%{escaped_keyword}%", escape="\\")
else:
# MySQL: Cast JSON to string for pattern matching
# MySQL stores Chinese text directly in JSON without Unicode escaping
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{keyword}%")
keywords_condition = cast(DocumentSegment.keywords, String).ilike(f"%{escaped_keyword}%", escape="\\")

query = query.where(
or_(
DocumentSegment.content.ilike(f"%{keyword}%"),
DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"),
keywords_condition,
)
)
Expand Down
8 changes: 5 additions & 3 deletions api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -984,9 +984,11 @@ def _search_by_like(self, query: str, **kwargs: Any) -> list[Document]:

# No need for dataset_id filter since each dataset has its own table

# Use simple quote escaping for LIKE clause
escaped_query = query.replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%'")
# Escape special characters for LIKE clause to prevent SQL injection
from libs.helper import escape_like_pattern

escaped_query = escape_like_pattern(query).replace("'", "''")
filter_clauses.append(f"{Field.CONTENT_KEY} LIKE '%{escaped_query}%' ESCAPE '\\\\'")
where_clause = " AND ".join(filter_clauses)

search_sql = f"""
Expand Down
8 changes: 6 additions & 2 deletions api/core/rag/datasource/vdb/iris/iris_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,15 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
cursor.execute(sql, (query,))
else:
# Fallback to LIKE search (inefficient for large datasets)
query_pattern = f"%{query}%"
# Escape special characters for LIKE clause to prevent SQL injection
from libs.helper import escape_like_pattern

escaped_query = escape_like_pattern(query)
query_pattern = f"%{escaped_query}%"
sql = f"""
SELECT TOP {top_k} id, text, meta
FROM {self.schema}.{self.table_name}
WHERE text LIKE ?
WHERE text LIKE ? ESCAPE '\\'
"""
cursor.execute(sql, (query_pattern,))

Expand Down
14 changes: 10 additions & 4 deletions api/core/rag/retrieval/dataset_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,18 +1195,24 @@ def process_metadata_filter_func(

json_field = DatasetDocument.doc_metadata[metadata_name].as_string()

from libs.helper import escape_like_pattern

match condition:
case "contains":
filters.append(json_field.like(f"%{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"%{escaped_value}%", escape="\\"))

case "not contains":
filters.append(json_field.notlike(f"%{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.notlike(f"%{escaped_value}%", escape="\\"))

case "start with":
filters.append(json_field.like(f"{value}%"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"{escaped_value}%", escape="\\"))

case "end with":
filters.append(json_field.like(f"%{value}"))
escaped_value = escape_like_pattern(str(value))
filters.append(json_field.like(f"%{escaped_value}", escape="\\"))

case "is" | "=":
if isinstance(value, str):
Expand Down
32 changes: 32 additions & 0 deletions api/libs/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,38 @@
logger = logging.getLogger(__name__)


def escape_like_pattern(pattern: str) -> str:
    r"""
    Escape SQL LIKE metacharacters so that a string is matched literally.

    Each of the following characters is prefixed with a backslash:
    - Backslash (\) -> \\
    - Percent (%) -> \%
    - Underscore (_) -> \_

    The result is only meaningful when the LIKE / ILIKE expression declares
    backslash as its escape character (``ESCAPE '\'`` in raw SQL, or
    ``escape="\\"`` with SQLAlchemy's ``like()`` / ``ilike()``). It stops
    user-supplied ``%`` and ``_`` from acting as wildcards.

    This function does NOT quote the value for embedding in a SQL string
    literal (quote characters are left untouched); pass the result as a bound
    parameter, or apply literal quoting separately.

    Args:
        pattern: The string to escape.

    Returns:
        The escaped string. Falsy input (``""`` or ``None``) is returned
        unchanged.

    Examples:
        >>> escape_like_pattern("50% discount")
        '50\\% discount'
        >>> escape_like_pattern("test_data")
        'test\\_data'
        >>> escape_like_pattern("path\\to\\file")
        'path\\\\to\\\\file'
    """
    if not pattern:
        return pattern
    # The backslash must be doubled first; otherwise the backslashes inserted
    # in front of "%" and "_" would themselves be escaped a second time.
    return pattern.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")


def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
"""
Extract tenant_id from Account or EndUser object.
Expand Down
7 changes: 5 additions & 2 deletions api/services/annotation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,13 +137,16 @@ def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keywo
if not app:
raise NotFound("App not found")
if keyword:
from libs.helper import escape_like_pattern

escaped_keyword = escape_like_pattern(keyword)
stmt = (
select(MessageAnnotation)
.where(MessageAnnotation.app_id == app_id)
.where(
or_(
MessageAnnotation.question.ilike(f"%{keyword}%"),
MessageAnnotation.content.ilike(f"%{keyword}%"),
MessageAnnotation.question.ilike(f"%{escaped_keyword}%", escape="\\"),
MessageAnnotation.content.ilike(f"%{escaped_keyword}%", escape="\\"),
)
)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
Expand Down
5 changes: 4 additions & 1 deletion api/services/app_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,11 @@ def get_paginate_apps(self, user_id: str, tenant_id: str, args: dict) -> Paginat
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
if args.get("name"):
from libs.helper import escape_like_pattern

name = args["name"][:30]
filters.append(App.name.ilike(f"%{name}%"))
escaped_name = escape_like_pattern(name)
filters.append(App.name.ilike(f"%{escaped_name}%", escape="\\"))
# Check if tag_ids is not empty to avoid WHERE false condition
if args.get("tag_ids") and len(args["tag_ids"]) > 0:
target_ids = TagService.get_target_ids_by_tag_ids("app", tenant_id, args["tag_ids"])
Expand Down
4 changes: 3 additions & 1 deletion api/services/conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,9 @@ def get_conversational_variable(
# Apply variable_name filter if provided
if variable_name:
# Filter using JSON extraction to match variable names case-insensitively
escaped_variable_name = variable_name.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
from libs.helper import escape_like_pattern

escaped_variable_name = escape_like_pattern(variable_name)
# Filter using JSON extraction to match variable names case-insensitively
if dify_config.DB_TYPE in ["mysql", "oceanbase", "seekdb"]:
stmt = stmt.where(
Expand Down
9 changes: 6 additions & 3 deletions api/services/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids
query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)

if search:
query = query.where(Dataset.name.ilike(f"%{search}%"))
escaped_search = helper.escape_like_pattern(search)
query = query.where(Dataset.name.ilike(f"%{escaped_search}%", escape="\\"))

# Check if tag_ids is not empty to avoid WHERE false condition
if tag_ids and len(tag_ids) > 0:
Expand Down Expand Up @@ -3423,7 +3424,8 @@ def get_child_chunks(
.order_by(ChildChunk.position.asc())
)
if keyword:
query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
escaped_keyword = helper.escape_like_pattern(keyword)
query = query.where(ChildChunk.content.ilike(f"%{escaped_keyword}%", escape="\\"))
return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

@classmethod
Expand Down Expand Up @@ -3456,7 +3458,8 @@ def get_segments(
query = query.where(DocumentSegment.status.in_(status_list))

if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
escaped_keyword = helper.escape_like_pattern(keyword)
query = query.where(DocumentSegment.content.ilike(f"%{escaped_keyword}%", escape="\\"))

query = query.order_by(DocumentSegment.position.asc(), DocumentSegment.id.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
Expand Down
5 changes: 4 additions & 1 deletion api/services/external_knowledge_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def get_external_knowledge_apis(
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
from libs.helper import escape_like_pattern

escaped_search = escape_like_pattern(search)
query = query.where(ExternalKnowledgeApis.name.ilike(f"%{escaped_search}%", escape="\\"))

external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
Expand Down
5 changes: 4 additions & 1 deletion api/services/tag_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ def get_tags(tag_type: str, current_tenant_id: str, keyword: str | None = None):
.where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
)
if keyword:
query = query.where(sa.and_(Tag.name.ilike(f"%{keyword}%")))
from libs.helper import escape_like_pattern

escaped_keyword = escape_like_pattern(keyword)
query = query.where(sa.and_(Tag.name.ilike(f"%{escaped_keyword}%", escape="\\")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all()
return results
Expand Down
15 changes: 11 additions & 4 deletions api/services/workflow_app_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,19 @@ def get_paginate_workflow_app_logs(
# Join to workflow run for filtering when needed.

if keyword:
keyword_like_val = f"%{keyword[:30].encode('unicode_escape').decode('utf-8')}%".replace(r"\u", r"\\u")
from libs.helper import escape_like_pattern

# Escape special characters in keyword to prevent SQL injection via LIKE wildcards
escaped_keyword = escape_like_pattern(keyword[:30])
keyword_like_val = f"%{escaped_keyword}%"
keyword_conditions = [
WorkflowRun.inputs.ilike(keyword_like_val),
WorkflowRun.outputs.ilike(keyword_like_val),
WorkflowRun.inputs.ilike(keyword_like_val, escape="\\"),
WorkflowRun.outputs.ilike(keyword_like_val, escape="\\"),
# filter keyword by end user session id if created by end user role
and_(WorkflowRun.created_by_role == "end_user", EndUser.session_id.ilike(keyword_like_val)),
and_(
WorkflowRun.created_by_role == "end_user",
EndUser.session_id.ilike(keyword_like_val, escape="\\"),
),
]

# filter keyword by workflow run id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,78 @@ def test_get_annotation_list_by_app_id_with_keyword(
assert total == 1
assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content

def test_get_annotation_list_by_app_id_with_special_characters_in_keyword(
    self, db_session_with_containers, mock_external_service_dependencies
):
    r"""
    Test that LIKE metacharacters in the annotation keyword are matched literally.

    This test verifies:
    - Special characters (%, _, \) in keyword are escaped before being used in the LIKE pattern
    - Search treats special characters as literal characters, not wildcards
    - Rows that would only match under wildcard semantics ("decoys") are NOT returned
    """
    app, _ = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)

    annotations = [
        # Rows containing the literal special characters.
        {"question": "Question with 50% discount", "answer": "Answer about 50% discount offer"},
        {"question": "Question with test_data", "answer": "Answer about test_data value"},
        {"question": "Question with path\\to\\file", "answer": "Answer about path\\to\\file location"},
        # Contains "%" but not "50%".
        {"question": "Question with 100% different", "answer": "Answer about 100% different content"},
        # Decoys: each matches its keyword only if the metacharacter is NOT escaped.
        # "50%" with a live trailing wildcard matches any text containing "50".
        {"question": "Question with 500 items", "answer": "Answer about 500 items in stock"},
        # "test_data" with "_" as a single-character wildcard matches "testXdata".
        {"question": "Question with testXdata", "answer": "Answer about testXdata value"},
        # "path\to\file" with "\" consumed as an escape character matches "pathtofile".
        {"question": "Question with pathtofile", "answer": "Answer about pathtofile location"},
    ]
    for annotation in annotations:
        AppAnnotationService.insert_app_annotation_directly(annotation, app.id)

    def search(keyword):
        return AppAnnotationService.get_annotation_list_by_app_id(app.id, page=1, limit=10, keyword=keyword)

    def contains(item, text):
        return text in (item.question or "") or text in (item.content or "")

    # Test 1: "%" inside a keyword is literal - only the "50%" row matches, not "500 items"
    annotation_list, total = search("50%")
    assert total == 1
    assert len(annotation_list) == 1
    assert contains(annotation_list[0], "50%")

    # Test 2: "_" inside a keyword is literal - only "test_data" matches, not "testXdata"
    annotation_list, total = search("test_data")
    assert total == 1
    assert len(annotation_list) == 1
    assert contains(annotation_list[0], "test_data")

    # Test 3: "\" inside a keyword is literal - only "path\to\file" matches, not "pathtofile"
    annotation_list, total = search("path\\to\\file")
    assert total == 1
    assert len(annotation_list) == 1
    assert contains(annotation_list[0], "path\\to\\file")

    # Test 4: a bare "%" must match only rows containing a literal percent sign
    # (the "50%" and "100%" rows), not every row.
    annotation_list, total = search("%")
    assert total == 2
    assert all(contains(item, "%") for item in annotation_list)

    # Test 5: a bare "_" must match only the row containing a literal underscore.
    annotation_list, total = search("_")
    assert total == 1
    assert all(contains(item, "_") for item in annotation_list)

    # Test 6: a bare "\" must match only the row containing a literal backslash.
    annotation_list, total = search("\\")
    assert total == 1
    assert all(contains(item, "\\") for item in annotation_list)

def test_get_annotation_list_by_app_id_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):
Expand Down
Loading