@@ -51,8 +51,8 @@ class Field:
5151 dtype : bigframes .dtypes .Dtype
5252
5353
54- @dataclass (frozen = True )
55- class BigFrameNode :
54+ @dataclass (eq = False , frozen = True )
55+ class BigFrameNode ( abc . ABC ) :
5656 """
5757 Immutable node for representing 2D typed array as a tree of operators.
5858
@@ -95,12 +95,30 @@ def session(self):
9595 return sessions [0 ]
9696 return None
9797
def _as_tuple(self) -> Tuple:
    """Return the values of all dataclass fields of this node as a tuple.

    Values are read with ``getattr`` for each entry of ``fields(self)``,
    in the order that ``fields`` yields them.
    """
    # ``spec`` rather than ``field`` so the loop variable does not shadow
    # the ``field`` helper this module uses elsewhere.
    return tuple(getattr(self, spec.name) for spec in fields(self))
101+
def __hash__(self) -> int:
    """Return this node's hash via the cached ``_cached_hash`` property.

    Going through the cache avoids recomputing the hash on every call,
    which is costly.
    """
    return self._cached_hash
105+
def __eq__(self, other) -> bool:
    """Compare two nodes structurally, short-circuiting where possible.

    Returns ``True`` when ``other`` is this very object, or is an instance
    of this node's class whose field tuple (see ``_as_tuple``) compares
    equal to this node's. Returns ``NotImplemented`` when ``other`` is not
    an instance of this node's class, so Python can try the reflected
    comparison; ``node == unrelated`` still evaluates to ``False`` when
    neither operand handles the comparison.
    """
    # Identity is the cheapest possible check, so do it first.
    if self is other:
        return True
    if not isinstance(other, self.__class__):
        return NotImplemented
    # Equal field tuples always hash equal, and hash() of a node is served
    # from a cache, so a mismatch rules out equality without running the
    # full structural comparison below.
    if hash(self) != hash(other):
        return False
    return self._as_tuple() == other._as_tuple()
115+
98116 # BigFrameNode trees can be very deep, so it's important to avoid recalculating the hash from scratch.
99117 # BigFrameNode.__hash__ returns this cached property; subclasses inherit it (they are declared with eq=False) and should not define their own __hash__.
100118 # The default dataclass-generated __hash__ method is not cached.
@functools.cached_property
def _cached_hash(self):
    """Hash of this node's field tuple, computed on first access only.

    ``functools.cached_property`` stores the result on the instance, so
    later accesses reuse it instead of hashing the fields again.
    """
    field_values = self._as_tuple()
    return hash(field_values)
104122
105123 @property
106124 def roots (self ) -> typing .Set [BigFrameNode ]:
@@ -226,7 +244,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
226244 return self .transform_children (lambda x : x .prune (used_cols ))
227245
228246
229- @dataclass (frozen = True )
247+ @dataclass (frozen = True , eq = False )
230248class UnaryNode (BigFrameNode ):
231249 child : BigFrameNode
232250
@@ -252,7 +270,7 @@ def order_ambiguous(self) -> bool:
252270 return self .child .order_ambiguous
253271
254272
255- @dataclass (frozen = True )
273+ @dataclass (frozen = True , eq = False )
256274class JoinNode (BigFrameNode ):
257275 left_child : BigFrameNode
258276 right_child : BigFrameNode
@@ -285,9 +303,6 @@ def explicitly_ordered(self) -> bool:
285303 # Do not consider user pre-join ordering intent - they need to re-order post-join in unordered mode.
286304 return False
287305
288- def __hash__ (self ):
289- return self ._node_hash
290-
291306 @functools .cached_property
292307 def fields (self ) -> Tuple [Field , ...]:
293308 return tuple (itertools .chain (self .left_child .fields , self .right_child .fields ))
@@ -320,7 +335,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
320335 return self .transform_children (lambda x : x .prune (new_used ))
321336
322337
323- @dataclass (frozen = True )
338+ @dataclass (frozen = True , eq = False )
324339class ConcatNode (BigFrameNode ):
325340 # TODO: Explicitly map column ids from each child
326341 children : Tuple [BigFrameNode , ...]
@@ -345,9 +360,6 @@ def explicitly_ordered(self) -> bool:
345360 # Consider concat as an ordered operation (even though input frames may not be ordered)
346361 return True
347362
348- def __hash__ (self ):
349- return self ._node_hash
350-
351363 @functools .cached_property
352364 def fields (self ) -> Tuple [Field , ...]:
353365 # TODO: Output names should probably be aligned beforehand or be part of concat definition
@@ -371,16 +383,13 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
371383 return self
372384
373385
374- @dataclass (frozen = True )
386+ @dataclass (frozen = True , eq = False )
375387class FromRangeNode (BigFrameNode ):
376388 # TODO: Enforce single-row, single column constraint
377389 start : BigFrameNode
378390 end : BigFrameNode
379391 step : int
380392
381- def __hash__ (self ):
382- return self ._node_hash
383-
384393 @property
385394 def roots (self ) -> typing .Set [BigFrameNode ]:
386395 return {self }
@@ -419,7 +428,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
419428# Input Nodes
420429# TODO: Most leaf nodes produce fixed column names based on the datasource
421430# They should support renaming
422- @dataclass (frozen = True )
431+ @dataclass (frozen = True , eq = False )
423432class LeafNode (BigFrameNode ):
424433 @property
425434 def roots (self ) -> typing .Set [BigFrameNode ]:
@@ -451,7 +460,7 @@ class ScanList:
451460 items : typing .Tuple [ScanItem , ...]
452461
453462
454- @dataclass (frozen = True )
463+ @dataclass (frozen = True , eq = False )
455464class ReadLocalNode (LeafNode ):
456465 feather_bytes : bytes
457466 data_schema : schemata .ArraySchema
@@ -460,9 +469,6 @@ class ReadLocalNode(LeafNode):
460469 scan_list : ScanList
461470 session : typing .Optional [bigframes .session .Session ] = None
462471
463- def __hash__ (self ):
464- return self ._node_hash
465-
466472 @functools .cached_property
467473 def fields (self ) -> Tuple [Field , ...]:
468474 return tuple (Field (col_id , dtype ) for col_id , dtype , _ in self .scan_list .items )
@@ -547,7 +553,7 @@ class BigqueryDataSource:
547553
548554
549555## Put ordering in here or just add order_by node above?
550- @dataclass (frozen = True )
556+ @dataclass (frozen = True , eq = False )
551557class ReadTableNode (LeafNode ):
552558 source : BigqueryDataSource
553559 # Subset of physical schema column
@@ -570,9 +576,6 @@ def __post_init__(self):
570576 def session (self ):
571577 return self .table_session
572578
573- def __hash__ (self ):
574- return self ._node_hash
575-
576579 @functools .cached_property
577580 def fields (self ) -> Tuple [Field , ...]:
578581 return tuple (Field (col_id , dtype ) for col_id , dtype , _ in self .scan_list .items )
@@ -616,15 +619,12 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
616619 return ReadTableNode (self .source , new_scan_list , self .table_session )
617620
618621
619- @dataclass (frozen = True )
622+ @dataclass (frozen = True , eq = False )
620623class CachedTableNode (ReadTableNode ):
621624 # The original BFET subtree that was cached
622625 # note: this isn't a "child" node.
623626 original_node : BigFrameNode = field ()
624627
625- def __hash__ (self ):
626- return self ._node_hash
627-
628628 def prune (self , used_cols : COLUMN_SET ) -> BigFrameNode :
629629 new_scan_list = ScanList (
630630 tuple (item for item in self .scan_list .items if item .id in used_cols )
@@ -635,13 +635,10 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
635635
636636
637637# Unary nodes
638- @dataclass (frozen = True )
638+ @dataclass (frozen = True , eq = False )
639639class PromoteOffsetsNode (UnaryNode ):
640640 col_id : bigframes .core .identifiers .ColumnId
641641
642- def __hash__ (self ):
643- return self ._node_hash
644-
645642 @property
646643 def non_local (self ) -> bool :
647644 return True
@@ -666,17 +663,14 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
666663 return self .transform_children (lambda x : x .prune (new_used ))
667664
668665
669- @dataclass (frozen = True )
666+ @dataclass (frozen = True , eq = False )
670667class FilterNode (UnaryNode ):
671668 predicate : ex .Expression
672669
673670 @property
674671 def row_preserving (self ) -> bool :
675672 return False
676673
677- def __hash__ (self ):
678- return self ._node_hash
679-
680674 @property
681675 def variables_introduced (self ) -> int :
682676 return 1
@@ -687,13 +681,10 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
687681 return FilterNode (pruned_child , self .predicate )
688682
689683
690- @dataclass (frozen = True )
684+ @dataclass (frozen = True , eq = False )
691685class OrderByNode (UnaryNode ):
692686 by : Tuple [OrderingExpression , ...]
693687
694- def __hash__ (self ):
695- return self ._node_hash
696-
697688 @property
698689 def variables_introduced (self ) -> int :
699690 return 0
@@ -716,14 +707,11 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
716707 return OrderByNode (pruned_child , self .by )
717708
718709
719- @dataclass (frozen = True )
710+ @dataclass (frozen = True , eq = False )
720711class ReversedNode (UnaryNode ):
721712 # useless field to make sure has distinct hash
722713 reversed : bool = True
723714
724- def __hash__ (self ):
725- return self ._node_hash
726-
727715 @property
728716 def variables_introduced (self ) -> int :
729717 return 0
@@ -734,15 +722,12 @@ def relation_ops_created(self) -> int:
734722 return 0
735723
736724
737- @dataclass (frozen = True )
725+ @dataclass (frozen = True , eq = False )
738726class SelectionNode (UnaryNode ):
739727 input_output_pairs : typing .Tuple [
740728 typing .Tuple [ex .DerefOp , bigframes .core .identifiers .ColumnId ], ...
741729 ]
742730
743- def __hash__ (self ):
744- return self ._node_hash
745-
746731 @functools .cached_property
747732 def fields (self ) -> Tuple [Field , ...]:
748733 return tuple (
@@ -772,7 +757,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
772757 return SelectionNode (pruned_child , pruned_selections )
773758
774759
775- @dataclass (frozen = True )
760+ @dataclass (frozen = True , eq = False )
776761class ProjectionNode (UnaryNode ):
777762 """Assigns new variables (without modifying existing ones)"""
778763
@@ -788,9 +773,6 @@ def __post_init__(self):
788773 # Cannot assign to existing variables - append only!
789774 assert all (name not in self .child .schema .names for _ , name in self .assignments )
790775
791- def __hash__ (self ):
792- return self ._node_hash
793-
794776 @functools .cached_property
795777 def fields (self ) -> Tuple [Field , ...]:
796778 input_types = self .child ._dtype_lookup
@@ -819,7 +801,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
819801
820802# TODO: Merge RowCount into Aggregate Node?
821803# Row count can sometimes be computed from table metadata, so it is a bit special.
822- @dataclass (frozen = True )
804+ @dataclass (frozen = True , eq = False )
823805class RowCountNode (UnaryNode ):
824806 @property
825807 def row_preserving (self ) -> bool :
@@ -842,7 +824,7 @@ def defines_namespace(self) -> bool:
842824 return True
843825
844826
845- @dataclass (frozen = True )
827+ @dataclass (frozen = True , eq = False )
846828class AggregateNode (UnaryNode ):
847829 aggregations : typing .Tuple [
848830 typing .Tuple [ex .Aggregation , bigframes .core .identifiers .ColumnId ], ...
@@ -854,9 +836,6 @@ class AggregateNode(UnaryNode):
854836 def row_preserving (self ) -> bool :
855837 return False
856838
857- def __hash__ (self ):
858- return self ._node_hash
859-
860839 @property
861840 def non_local (self ) -> bool :
862841 return True
@@ -904,7 +883,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
904883 return AggregateNode (pruned_child , pruned_aggs , self .by_column_ids , self .dropna )
905884
906885
907- @dataclass (frozen = True )
886+ @dataclass (frozen = True , eq = False )
908887class WindowOpNode (UnaryNode ):
909888 column_name : ex .DerefOp
910889 op : agg_ops .UnaryWindowOp
@@ -913,9 +892,6 @@ class WindowOpNode(UnaryNode):
913892 never_skip_nulls : bool = False
914893 skip_reproject_unsafe : bool = False
915894
916- def __hash__ (self ):
917- return self ._node_hash
918-
919895 @property
920896 def non_local (self ) -> bool :
921897 return True
@@ -945,11 +921,8 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
945921
946922
947923# TODO: Remove this op
948- @dataclass (frozen = True )
924+ @dataclass (frozen = True , eq = False )
949925class ReprojectOpNode (UnaryNode ):
950- def __hash__ (self ):
951- return self ._node_hash
952-
953926 @property
954927 def variables_introduced (self ) -> int :
955928 return 0
@@ -960,7 +933,7 @@ def relation_ops_created(self) -> int:
960933 return 0
961934
962935
963- @dataclass (frozen = True )
936+ @dataclass (frozen = True , eq = False )
964937class RandomSampleNode (UnaryNode ):
965938 fraction : float
966939
@@ -972,26 +945,20 @@ def deterministic(self) -> bool:
972945 def row_preserving (self ) -> bool :
973946 return False
974947
975- def __hash__ (self ):
976- return self ._node_hash
977-
978948 @property
979949 def variables_introduced (self ) -> int :
980950 return 1
981951
982952
983953# TODO: Explode should create a new column instead of overriding the existing one
984- @dataclass (frozen = True )
954+ @dataclass (frozen = True , eq = False )
985955class ExplodeNode (UnaryNode ):
986956 column_ids : typing .Tuple [ex .DerefOp , ...]
987957
988958 @property
989959 def row_preserving (self ) -> bool :
990960 return False
991961
992- def __hash__ (self ):
993- return self ._node_hash
994-
995962 @functools .cached_property
996963 def fields (self ) -> Tuple [Field , ...]:
997964 return tuple (
0 commit comments