Skip to content

Commit f99c115

Browse files
authored
feat(graph): Handle cases when a single table is multiplexed by multi… (#169)
* feat(graph): Handle cases when a single table is multiplexed by multiple element definitions 1) Allow the same table to be used to define multiple node/edge definitions; for example, ``` ShareHolding AS PersonOwnsCompany SOURCE Person DESTINATION Company, ShareHolding AS CompanyOwnsCompany SOURCE Company DESTINATION Company, ``` 2) Allow the graph schema to reference TOKENLIST-typed properties (by ignoring them). This is a blocker if a customer manually adds a search index to a property and wants to use it through the graph query language. * Adjust the schema representation to be more explicit * Address comments
1 parent 9940b86 commit f99c115

File tree

3 files changed

+197
-70
lines changed

3 files changed

+197
-70
lines changed

src/langchain_google_spanner/graph_store.py

Lines changed: 99 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -188,11 +188,14 @@ class ElementSchema(object):
188188

189189
NODE_KEY_COLUMN_NAME: str = "id"
190190
TARGET_NODE_KEY_COLUMN_NAME: str = "target_id"
191+
192+
# Reserved column names when `use_flexible_schema` is true.
193+
# Properties are stored in a JSON column named `properties`;
194+
# Edge types are stored in a string column named `label`.
191195
DYNAMIC_PROPERTY_COLUMN_NAME: str = "properties"
192196
DYNAMIC_LABEL_COLUMN_NAME: str = "label"
193197

194198
name: str
195-
original_name: str
196199
kind: str
197200
key_columns: List[str]
198201
base_table_name: str
@@ -220,8 +223,7 @@ def make_node_schema(
220223
node.properties = CaseInsensitiveDict({prop: prop for prop in node.types})
221224
node.labels = [node_label]
222225
node.base_table_name = "%s_%s" % (graph_name, node_label)
223-
node.original_name = node_type
224-
node.name = node.base_table_name
226+
node.name = node_type
225227
node.kind = NODE_KIND
226228
node.key_columns = [ElementSchema.NODE_KEY_COLUMN_NAME]
227229
return node
@@ -243,15 +245,18 @@ def make_edge_schema(
243245
edge.labels = [edge_label]
244246
edge.base_table_name = "%s_%s" % (graph_schema.graph_name, edge_label)
245247
edge.key_columns = key_columns
246-
edge.original_name = edge_type
247-
edge.name = edge.base_table_name
248+
edge.name = edge_type
248249
edge.kind = EDGE_KIND
249250

250-
source_node_schema = graph_schema.get_node_schema(source_node_type)
251+
source_node_schema = graph_schema.get_node_schema(
252+
graph_schema.node_type_name(source_node_type)
253+
)
251254
if source_node_schema is None:
252255
raise ValueError("No source node schema `%s` found" % source_node_type)
253256

254-
target_node_schema = graph_schema.get_node_schema(target_node_type)
257+
target_node_schema = graph_schema.get_node_schema(
258+
graph_schema.node_type_name(target_node_type)
259+
)
255260
if target_node_schema is None:
256261
raise ValueError("No target node schema `%s` found" % target_node_type)
257262

@@ -346,7 +351,7 @@ def from_dynamic_nodes(
346351
)
347352
)
348353
return ElementSchema.make_node_schema(
349-
name, NODE_KIND, graph_schema.graph_name, types
354+
NODE_KIND, NODE_KIND, graph_schema.graph_name, types
350355
)
351356

352357
@staticmethod
@@ -452,7 +457,7 @@ def from_dynamic_edges(
452457
)
453458
)
454459
return ElementSchema.make_edge_schema(
455-
name,
460+
EDGE_KIND,
456461
EDGE_KIND,
457462
graph_schema,
458463
[
@@ -567,14 +572,13 @@ def add_edges(
567572
@staticmethod
568573
def from_info_schema(
569574
element_schema: Dict[str, Any],
570-
property_decls: List[Any],
575+
decl_by_types: CaseInsensitiveDict,
571576
) -> ElementSchema:
572577
"""Builds ElementSchema from information schema represenation of an element.
573578
574579
Args:
575580
element_schema: the information schema representation of an element;
576-
property_decls: the information schema represenation of property
577-
declarations.
581+
decl_by_types: type information of property declarations.
578582
579583
Returns:
580584
ElementSchema
@@ -584,26 +588,23 @@ def from_info_schema(
584588
"""
585589
element = ElementSchema()
586590
element.name = element_schema["name"]
587-
element.original_name = element.name
588591
element.kind = element_schema["kind"]
589592
if element.kind not in [NODE_KIND, EDGE_KIND]:
590593
raise ValueError("Invalid element kind `{}`".format(element.kind))
591594

592595
element.key_columns = element_schema["keyColumns"]
593596
element.base_table_name = element_schema["baseTableName"]
594597
element.labels = element_schema["labelNames"]
598+
595599
element.properties = CaseInsensitiveDict(
596600
{
597601
prop_def["propertyDeclarationName"]: prop_def["valueExpressionSql"]
598602
for prop_def in element_schema.get("propertyDefinitions", [])
603+
if prop_def["propertyDeclarationName"] in decl_by_types
599604
}
600605
)
601606
element.types = CaseInsensitiveDict(
602-
{
603-
decl["name"]: TypeUtility.schema_str_to_spanner_type(decl["type"])
604-
for decl in property_decls
605-
if decl["name"] in element.properties
606-
}
607+
{decl: decl_by_types[decl] for decl in element.properties.keys()}
607608
)
608609

609610
if element.kind == EDGE_KIND:
@@ -636,7 +637,7 @@ def to_ddl(self, graph_schema: SpannerGraphSchema) -> str:
636637
to_identifiers = GraphDocumentUtility.to_identifiers
637638

638639
def get_reference_node_table(name: str) -> str:
639-
node_schema = graph_schema.node_tables.get(name, None)
640+
node_schema = graph_schema.nodes.get(name, None)
640641
if node_schema is None:
641642
raise ValueError("No node schema `%s` found" % name)
642643
return node_schema.base_table_name
@@ -708,13 +709,17 @@ def evolve(self, new_schema: ElementSchema) -> List[str]:
708709
)
709710
)
710711

711-
for k, v in new_schema.properties.items():
712-
if k in self.properties:
713-
if self.properties[k].casefold() != v.casefold():
714-
raise ValueError(
715-
"Property with name `{}` should have the same definition, got {},"
716-
" expected {}".format(k, v, self.properties[k])
717-
)
712+
# Only validate property definition when they're the same definition,
713+
# don't validate when two different definitions are based on the same
714+
# underlying table.
715+
if self.name == new_schema.name:
716+
for k, v in new_schema.properties.items():
717+
if k in self.properties:
718+
if self.properties[k].casefold() != v.casefold():
719+
raise ValueError(
720+
"Property with name `{}` should have the same definition, got {},"
721+
" expected {}".format(k, v, self.properties[k])
722+
)
718723

719724
for k, v in new_schema.types.items():
720725
if k in self.types:
@@ -845,16 +850,29 @@ def from_information_schema(self, info_schema: Dict[str, Any]) -> None:
845850
info_schema: the information schema representation of a graph;
846851
"""
847852
property_decls = info_schema.get("propertyDeclarations", [])
853+
decl_by_types = CaseInsensitiveDict(
854+
{
855+
decl["name"]: TypeUtility.schema_str_to_spanner_type(decl["type"])
856+
for decl in property_decls
857+
if TypeUtility.schema_str_to_spanner_type(decl["type"]) is not None
858+
}
859+
)
848860
for node in info_schema["nodeTables"]:
849-
node_schema = ElementSchema.from_info_schema(node, property_decls)
861+
node_schema = ElementSchema.from_info_schema(node, decl_by_types)
850862
self._update_node_schema(node_schema)
851863
self._update_labels_and_properties(node_schema)
852864

853865
for edge in info_schema.get("edgeTables", []):
854-
edge_schema = ElementSchema.from_info_schema(edge, property_decls)
866+
edge_schema = ElementSchema.from_info_schema(edge, decl_by_types)
855867
self._update_edge_schema(edge_schema)
856868
self._update_labels_and_properties(edge_schema)
857869

870+
def node_type_name(self, name: str) -> str:
871+
return NODE_KIND if self.use_flexible_schema else name
872+
873+
def edge_type_name(self, name: str) -> str:
874+
return EDGE_KIND if self.use_flexible_schema else name
875+
858876
def get_node_schema(self, name: str) -> Optional[ElementSchema]:
859877
"""Gets the node schema by name.
860878
@@ -919,40 +937,54 @@ def __repr__(self) -> str:
919937
for k, v in self.properties.items()
920938
}
921939
)
940+
node_labels = {label for node in self.nodes.values() for label in node.labels}
941+
edge_labels = {label for edge in self.edges.values() for label in edge.labels}
942+
Triplet = Tuple[ElementSchema, ElementSchema, ElementSchema]
943+
triplets_per_label: CaseInsensitiveDict[List[Triplet]] = CaseInsensitiveDict({})
944+
for edge in self.edges.values():
945+
for label in edge.labels:
946+
source_node = self.get_node_schema(edge.source.node_name)
947+
target_node = self.get_node_schema(edge.target.node_name)
948+
if source_node is None:
949+
raise ValueError(f"Source node {edge.source.node_name} not found")
950+
if target_node is None:
951+
raise ValueError(f"Tource node {edge.target.node_name} not found")
952+
triplets_per_label.setdefault(label, []).append(
953+
(source_node, edge, target_node)
954+
)
922955
return json.dumps(
923956
{
924957
"Name of graph": self.graph_name,
925-
"Node properties per node type": {
926-
node.name: [
958+
"Node properties per node label": {
959+
label: [
927960
{
928961
"property name": name,
929962
"property type": properties[name],
930963
}
931-
for name in node.properties.keys()
964+
for name in self.labels[label].prop_names
932965
]
933-
for node in self.nodes.values()
966+
for label in node_labels
934967
},
935-
"Edge properties per edge type": {
936-
edge.name: [
968+
"Edge properties per edge label": {
969+
label: [
937970
{
938971
"property name": name,
939972
"property type": properties[name],
940973
}
941-
for name in edge.properties.keys()
974+
for name in self.labels[label].prop_names
942975
]
943-
for edge in self.edges.values()
944-
},
945-
"Node labels per node type": {
946-
node.name: node.labels for node in self.nodes.values()
976+
for label in edge_labels
947977
},
948-
"Edge labels per edge type": {
949-
edge.name: edge.labels for edge in self.edges.values()
950-
},
951-
"Edges": {
952-
edge.name: "From {} nodes to {} nodes".format(
953-
edge.source.node_name, edge.target.node_name
954-
)
955-
for edge in self.edges.values()
978+
"Possible edges per label": {
979+
label: [
980+
"(:{}) -[:{}]-> (:{})".format(
981+
source_node_label, label, target_node_label
982+
)
983+
for (source, edge, target) in triplets
984+
for source_node_label in source.labels
985+
for target_node_label in target.labels
986+
]
987+
for label, triplets in triplets_per_label.items()
956988
},
957989
},
958990
indent=2,
@@ -1035,18 +1067,15 @@ def construct_element_table(
10351067
)
10361068
ddl += "\nNODE TABLES(\n "
10371069
ddl += ",\n ".join(
1038-
(
1039-
construct_element_table(node, self.labels)
1040-
for node in self.node_tables.values()
1041-
)
1070+
(construct_element_table(node, self.labels) for node in self.nodes.values())
10421071
)
10431072
ddl += "\n)"
10441073
if len(self.edges) > 0:
10451074
ddl += "\nEDGE TABLES(\n "
10461075
ddl += ",\n ".join(
10471076
(
10481077
construct_element_table(edge, self.labels)
1049-
for edge in self.edge_tables.values()
1078+
for edge in self.edges.values()
10501079
)
10511080
)
10521081
ddl += "\n)"
@@ -1062,14 +1091,16 @@ def _update_node_schema(self, node_schema: ElementSchema) -> List[str]:
10621091
List[str]: a list of DDL statements required to evolve the schema.
10631092
"""
10641093

1065-
old_schema = self.node_tables.get(node_schema.name, None)
1066-
if old_schema is None:
1067-
ddls = [node_schema.to_ddl(self)]
1068-
self.node_tables[node_schema.name] = node_schema
1069-
else:
1094+
old_schema = self.nodes.get(node_schema.name, None)
1095+
if old_schema is not None:
10701096
ddls = old_schema.evolve(node_schema)
1097+
elif node_schema.base_table_name in self.node_tables:
1098+
ddls = self.node_tables[node_schema.base_table_name].evolve(node_schema)
1099+
else:
1100+
ddls = [node_schema.to_ddl(self)]
1101+
self.node_tables[node_schema.base_table_name] = node_schema
10711102

1072-
self.nodes[node_schema.original_name] = old_schema or node_schema
1103+
self.nodes[node_schema.name] = old_schema or node_schema
10731104
return ddls
10741105

10751106
def _update_edge_schema(self, edge_schema: ElementSchema) -> List[str]:
@@ -1081,15 +1112,16 @@ def _update_edge_schema(self, edge_schema: ElementSchema) -> List[str]:
10811112
Returns:
10821113
List[str]: a list of DDL statements required to evolve the schema.
10831114
"""
1084-
if edge_schema.base_table_name not in self.edge_tables:
1115+
old_schema = self.edges.get(edge_schema.name, None)
1116+
if old_schema is not None:
1117+
ddls = old_schema.evolve(edge_schema)
1118+
elif edge_schema.base_table_name in self.edge_tables:
1119+
ddls = self.edge_tables[edge_schema.base_table_name].evolve(edge_schema)
1120+
else:
10851121
ddls = [edge_schema.to_ddl(self)]
10861122
self.edge_tables[edge_schema.base_table_name] = edge_schema
1087-
else:
1088-
ddls = self.edge_tables[edge_schema.base_table_name].evolve(edge_schema)
10891123

1090-
self.edges[edge_schema.original_name] = self.edge_tables[
1091-
edge_schema.base_table_name
1092-
]
1124+
self.edges[edge_schema.name] = old_schema or edge_schema
10931125
return ddls
10941126

10951127
def _update_labels_and_properties(self, element_schema: ElementSchema) -> None:
@@ -1121,7 +1153,7 @@ def add_nodes(
11211153
List[str]: a list of column names;
11221154
List[List[Any]]: a list of rows.
11231155
"""
1124-
node_schema = self.get_node_schema(name)
1156+
node_schema = self.get_node_schema(self.node_type_name(name))
11251157
if node_schema is None:
11261158
raise ValueError("Unknown node schema: `%s`" % name)
11271159
for v in node_schema.add_nodes(name, nodes):
@@ -1142,8 +1174,9 @@ def add_edges(
11421174
List[str]: a list of column names;
11431175
List[List[Any]]: a list of rows.
11441176
"""
1145-
edge_schema = self.get_edge_schema(name)
1177+
edge_schema = self.get_edge_schema(self.edge_type_name(name))
11461178
if edge_schema is None:
1179+
print(list(self.edges.keys()))
11471180
raise ValueError("Unknown edge schema `%s`" % name)
11481181
for v in edge_schema.add_edges(name, edges):
11491182
yield v

src/langchain_google_spanner/type_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import base64
1818
import datetime
19-
from typing import Any
19+
from typing import Any, Optional
2020

2121
from google.cloud.spanner_v1 import JsonObject, param_types
2222

@@ -67,14 +67,14 @@ def spanner_type_to_schema_str(
6767
raise ValueError("Unsupported type: %s" % t)
6868

6969
@staticmethod
70-
def schema_str_to_spanner_type(s: str) -> param_types.Type:
70+
def schema_str_to_spanner_type(s: str) -> Optional[param_types.Type]:
7171
"""Returns a Spanner type corresponding to the string representation from Spanner schema type.
7272
7373
Parameters:
7474
- s: string representation of a Spanner schema type.
7575
7676
Returns:
77-
- Type[Any]: the corresponding Spanner type.
77+
- Optional[param_types.Type]: the corresponding Spanner type.
7878
"""
7979
if s == "BOOL":
8080
return param_types.BOOL
@@ -98,6 +98,10 @@ def schema_str_to_spanner_type(s: str) -> param_types.Type:
9898
return param_types.Array(
9999
TypeUtility.schema_str_to_spanner_type(s[len("ARRAY<") : -len(">")])
100100
)
101+
if s == "TOKENLIST":
102+
# There is no corresponding type for TOKENLIST in value type yet.
103+
# Returns none to allow TOKENLIST in the schema.
104+
return None
101105
raise ValueError("Unsupported type: %s" % s)
102106

103107
@staticmethod

0 commit comments

Comments
 (0)