Skip to content

Commit 08fcaaf

Browse files
pl04351820 and Sichen Liu authored
feat: support Vector Search (#896)
Co-authored-by: Sichen Liu <sichenliu@google.com>
1 parent a8ed3ea commit 08fcaaf

File tree

15 files changed

+1260
-33
lines changed

15 files changed

+1260
-33
lines changed

google/cloud/firestore_v1/_helpers.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from google.cloud import exceptions # type: ignore
2828
from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore
29+
from google.cloud.firestore_v1.vector import Vector
2930
from google.cloud.firestore_v1.types.write import DocumentTransform
3031
from google.cloud.firestore_v1 import transforms
3132
from google.cloud.firestore_v1 import types
@@ -160,7 +161,8 @@ def encode_value(value) -> types.document.Value:
160161
161162
Args:
162163
value (Union[NoneType, bool, int, float, datetime.datetime, \
163-
str, bytes, dict, ~google.cloud.Firestore.GeoPoint]): A native
164+
str, bytes, dict, ~google.cloud.Firestore.GeoPoint, \
165+
~google.cloud.firestore_v1.vector.Vector]): A native
164166
Python value to convert to a protobuf field.
165167
166168
Returns:
@@ -209,6 +211,9 @@ def encode_value(value) -> types.document.Value:
209211
value_pb = document.ArrayValue(values=value_list)
210212
return document.Value(array_value=value_pb)
211213

214+
if isinstance(value, Vector):
215+
return encode_value(value.to_map_value())
216+
212217
if isinstance(value, dict):
213218
value_dict = encode_dict(value)
214219
value_pb = document.MapValue(fields=value_dict)
@@ -331,7 +336,9 @@ def reference_value_to_document(reference_value, client) -> Any:
331336

332337
def decode_value(
333338
value, client
334-
) -> Union[None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint]:
339+
) -> Union[
340+
None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint, Vector
341+
]:
335342
"""Converts a Firestore protobuf ``Value`` to a native Python value.
336343
337344
Args:
@@ -382,7 +389,7 @@ def decode_value(
382389
raise ValueError("Unknown ``value_type``", value_type)
383390

384391

385-
def decode_dict(value_fields, client) -> dict:
392+
def decode_dict(value_fields, client) -> Union[dict, Vector]:
386393
"""Converts a protobuf map of Firestore ``Value``-s.
387394
388395
Args:
@@ -397,8 +404,14 @@ def decode_dict(value_fields, client) -> dict:
397404
of native Python values converted from the ``value_fields``.
398405
"""
399406
value_fields_pb = getattr(value_fields, "_pb", value_fields)
407+
res = {key: decode_value(value, client) for key, value in value_fields_pb.items()}
408+
409+
if res.get("__type__", None) == "__vector__":
410+
# Vector data type is represented as mapping.
411+
# {"__type__":"__vector__", "value": [1.0, 2.0, 3.0]}.
412+
return Vector(res["value"])
400413

401-
return {key: decode_value(value, client) for key, value in value_fields_pb.items()}
414+
return res
402415

403416

404417
def get_doc_id(document_pb, expected_prefix) -> str:

google/cloud/firestore_v1/base_collection.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@
1919
from google.api_core import retry as retries
2020

2121
from google.cloud.firestore_v1 import _helpers
22+
from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
2223
from google.cloud.firestore_v1.document import DocumentReference
2324
from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery
25+
from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery
2426
from google.cloud.firestore_v1.base_query import QueryType
27+
from google.cloud.firestore_v1.vector import Vector
2528

2629

2730
from typing import (
@@ -46,6 +49,7 @@
4649
from google.cloud.firestore_v1.base_document import DocumentSnapshot
4750
from google.cloud.firestore_v1.transaction import Transaction
4851
from google.cloud.firestore_v1.field_path import FieldPath
52+
from firestore_v1.vector_query import VectorQuery
4953

5054
_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
5155

@@ -120,6 +124,9 @@ def _query(self) -> QueryType:
120124
def _aggregation_query(self) -> BaseAggregationQuery:
121125
raise NotImplementedError
122126

127+
def _vector_query(self) -> BaseVectorQuery:
128+
raise NotImplementedError
129+
123130
def document(self, document_id: Optional[str] = None) -> DocumentReference:
124131
"""Create a sub-document underneath the current collection.
125132
@@ -539,6 +546,31 @@ def avg(self, field_ref: str | FieldPath, alias=None):
539546
"""
540547
return self._aggregation_query().avg(field_ref, alias=alias)
541548

549+
def find_nearest(
    self,
    vector_field: str,
    query_vector: Vector,
    limit: int,
    distance_measure: DistanceMeasure,
) -> VectorQuery:
    """Find the vector embeddings closest to the given query vector.

    Delegates to the vector query produced by ``self._vector_query()``.

    Args:
        vector_field (str): An indexed vector field to search upon. Only
            documents which contain vectors whose dimensionality match
            the query_vector can be returned.
        query_vector (Vector): The query vector that we are searching on.
            Must be a vector of no more than 2048 dimensions.
        limit (int): The number of nearest neighbors to return. Must be
            a positive integer of no more than 1000.
        distance_measure (:class:`DistanceMeasure`): The distance measure
            to use.

    Returns:
        :class:`~google.cloud.firestore_v1.vector_query.VectorQuery`:
        the vector query.
    """
    vector_query = self._vector_query()
    # Arguments are forwarded positionally, in the declared order.
    return vector_query.find_nearest(
        vector_field, query_vector, limit, distance_measure
    )
573+
542574

543575
def _auto_id() -> str:
544576
"""Generate a "random" automatically generated ID.

google/cloud/firestore_v1/base_query.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from google.cloud.firestore_v1 import document
3434
from google.cloud.firestore_v1 import field_path as field_path_module
3535
from google.cloud.firestore_v1 import transforms
36+
from google.cloud.firestore_v1.base_vector_query import DistanceMeasure
3637
from google.cloud.firestore_v1.types import StructuredQuery
3738
from google.cloud.firestore_v1.types import query
3839
from google.cloud.firestore_v1.types import Cursor
@@ -51,11 +52,13 @@
5152
Union,
5253
TYPE_CHECKING,
5354
)
55+
from google.cloud.firestore_v1.vector import Vector
5456

5557
# Types needed only for Type Hints
5658
from google.cloud.firestore_v1.base_document import DocumentSnapshot
5759

5860
if TYPE_CHECKING: # pragma: NO COVER
61+
from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery
5962
from google.cloud.firestore_v1.field_path import FieldPath
6063

6164
_BAD_DIR_STRING: str
@@ -972,6 +975,15 @@ def _to_protobuf(self) -> StructuredQuery:
972975
query_kwargs["limit"] = wrappers_pb2.Int32Value(value=self._limit)
973976
return query.StructuredQuery(**query_kwargs)
974977

978+
def find_nearest(
    self,
    vector_field: str,
    queryVector: Vector,
    limit: int,
    distance_measure: DistanceMeasure,
) -> BaseVectorQuery:
    """Vector search hook; not implemented on the base query.

    Args:
        vector_field (str): Name of the vector field to search upon.
        queryVector (Vector): The query vector to search with.
        limit (int): The number of nearest neighbors to return.
        distance_measure (DistanceMeasure): The distance measure to use.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    # NOTE(review): ``queryVector`` is camelCase while the other
    # ``find_nearest`` signatures in this change use ``query_vector``.
    # The name is kept as-is because callers may pass it by keyword.
    raise NotImplementedError
986+
975987
def count(
976988
self, alias: str | None = None
977989
) -> Type["firestore_v1.base_aggregation.BaseAggregationQuery"]:
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Classes for representing vector queries for the Google Cloud Firestore API.
16+
"""
17+
18+
import abc
19+
20+
from abc import ABC
21+
from enum import Enum
22+
from typing import Iterable, Optional, Tuple, Union
23+
from google.api_core import gapic_v1
24+
from google.api_core import retry as retries
25+
from google.cloud.firestore_v1.base_document import DocumentSnapshot
26+
from google.cloud.firestore_v1.types import query
27+
from google.cloud.firestore_v1.vector import Vector
28+
from google.cloud.firestore_v1 import _helpers
29+
30+
31+
class DistanceMeasure(Enum):
    """Distance measures selectable for a vector (nearest-neighbor) query.

    Each member is translated to the protobuf member of the same name
    when the vector query is converted to a ``StructuredQuery``.
    """

    EUCLIDEAN = 1
    COSINE = 2
    DOT_PRODUCT = 3
35+
36+
37+
class BaseVectorQuery(ABC):
    """Represents a vector query to the Firestore API.

    A vector query wraps another ("nested") query and adds a
    nearest-neighbor search to it. The search parameters are supplied
    through :meth:`find_nearest`; subclasses implement :meth:`get` to run
    the query.

    Args:
        nested_query: The query the nearest-neighbor search is applied to.
            It must provide a ``_parent`` attribute (used as the
            collection reference) and a ``_to_protobuf()`` method.
    """

    def __init__(self, nested_query) -> None:
        self._nested_query = nested_query
        self._collection_ref = nested_query._parent
        # Search parameters stay unset until ``find_nearest`` is called.
        self._vector_field: Optional[str] = None
        self._query_vector: Optional[Vector] = None
        self._limit: Optional[int] = None
        self._distance_measure: Optional[DistanceMeasure] = None

    @property
    def _client(self):
        """The client held by the collection reference of the nested query."""
        return self._collection_ref._client

    def _to_protobuf(self) -> query.StructuredQuery:
        """Convert this vector query to its protobuf representation.

        Returns:
            The message produced by the nested query's ``_to_protobuf()``
            with its ``find_nearest`` field set from the stored search
            parameters.

        Raises:
            ValueError: If the stored distance measure is not one of the
                :class:`DistanceMeasure` members (this includes the case
                where ``find_nearest`` was never called and it is ``None``).
        """
        proto_measures = query.StructuredQuery.FindNearest.DistanceMeasure

        # Validate and translate the distance measure before asking the
        # nested query for its protobuf.
        if self._distance_measure == DistanceMeasure.EUCLIDEAN:
            distance_measure_proto = proto_measures.EUCLIDEAN
        elif self._distance_measure == DistanceMeasure.COSINE:
            distance_measure_proto = proto_measures.COSINE
        elif self._distance_measure == DistanceMeasure.DOT_PRODUCT:
            distance_measure_proto = proto_measures.DOT_PRODUCT
        else:
            raise ValueError("Invalid distance_measure")

        # Start from the nested query's message; there is no need to
        # construct an empty ``StructuredQuery`` beforehand.
        pb = self._nested_query._to_protobuf()
        pb.find_nearest = query.StructuredQuery.FindNearest(
            vector_field=query.StructuredQuery.FieldReference(
                field_path=self._vector_field
            ),
            query_vector=_helpers.encode_value(self._query_vector),
            distance_measure=distance_measure_proto,
            limit=self._limit,
        )
        return pb

    def _prep_stream(
        self,
        transaction=None,
        retry: Union[retries.Retry, None, gapic_v1.method._MethodDefault] = None,
        timeout: Optional[float] = None,
    ) -> Tuple[dict, str, dict]:
        """Build the request pieces shared by implementations that run the query.

        Args:
            transaction: Optional transaction; its ID is obtained through
                ``_helpers.get_transaction_id``.
            retry: Retry setting forwarded to
                ``_helpers.make_retry_timeout_kwargs``.
            timeout: Timeout forwarded to
                ``_helpers.make_retry_timeout_kwargs``.

        Returns:
            A ``(request, expected_prefix, kwargs)`` tuple: the request
            dict (``parent``, ``structured_query``, ``transaction``), the
            prefix reported by the collection reference's
            ``_parent_info()``, and the retry/timeout keyword arguments.

        Raises:
            ValueError: Propagated from :meth:`_to_protobuf` when the
                distance measure is invalid.
        """
        parent_path, expected_prefix = self._collection_ref._parent_info()
        request = {
            "parent": parent_path,
            "structured_query": self._to_protobuf(),
            "transaction": _helpers.get_transaction_id(transaction),
        }
        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

        return request, expected_prefix, kwargs

    @abc.abstractmethod
    def get(
        self,
        transaction=None,
        retry: retries.Retry = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
    ) -> Iterable[DocumentSnapshot]:
        """Runs the vector query."""

    def find_nearest(
        self,
        vector_field: str,
        query_vector: Vector,
        limit: int,
        distance_measure: DistanceMeasure,
    ) -> "BaseVectorQuery":
        """Finds the closest vector embeddings to the given query vector.

        The parameters are stored on this instance (replacing any earlier
        values); no validation happens here. The distance measure is
        checked later, in :meth:`_to_protobuf`.

        Args:
            vector_field (str): The vector field to search upon.
            query_vector (Vector): The query vector to search with.
            limit (int): The number of nearest neighbors to return.
            distance_measure (DistanceMeasure): The distance measure to use.

        Returns:
            This same query instance, so calls can be chained.
        """
        self._vector_field = vector_field
        self._query_vector = query_vector
        self._limit = limit
        self._distance_measure = distance_measure
        return self

google/cloud/firestore_v1/collection.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
)
2424
from google.cloud.firestore_v1 import query as query_mod
2525
from google.cloud.firestore_v1 import aggregation
26+
from google.cloud.firestore_v1 import vector_query
2627
from google.cloud.firestore_v1.watch import Watch
2728
from google.cloud.firestore_v1 import document
2829
from typing import Any, Callable, Generator, Tuple, Union
@@ -76,6 +77,14 @@ def _aggregation_query(self) -> aggregation.AggregationQuery:
7677
"""
7778
return aggregation.AggregationQuery(self._query())
7879

80+
def _vector_query(self) -> vector_query.VectorQuery:
    """Create a vector query wrapping this collection's query.

    Returns:
        :class:`~google.cloud.firestore_v1.vector_query.VectorQuery`:
        a new vector query built around ``self._query()``.
    """
    nested_query = self._query()
    return vector_query.VectorQuery(nested_query)
87+
7988
def add(
8089
self,
8190
document_data: dict,

0 commit comments

Comments
 (0)