@@ -35,16 +35,21 @@ class SquashedSelect:
3535 columns : Tuple [Tuple [scalar_exprs .Expression , str ], ...]
3636 predicate : Optional [scalar_exprs .Expression ]
3737 ordering : Tuple [order .OrderingExpression , ...]
38+ reverse_root : bool = False
3839
3940 @classmethod
40- def from_node (cls , node : nodes .BigFrameNode ) -> SquashedSelect :
41+ def from_node (
42+ cls , node : nodes .BigFrameNode , projections_only : bool = False
43+ ) -> SquashedSelect :
4144 if isinstance (node , nodes .ProjectionNode ):
42- return cls .from_node (node .child ).project (node .assignments )
43- elif isinstance (node , nodes .FilterNode ):
45+ return cls .from_node (node .child , projections_only = projections_only ).project (
46+ node .assignments
47+ )
48+ elif not projections_only and isinstance (node , nodes .FilterNode ):
4449 return cls .from_node (node .child ).filter (node .predicate )
45- elif isinstance (node , nodes .ReversedNode ):
50+ elif not projections_only and isinstance (node , nodes .ReversedNode ):
4651 return cls .from_node (node .child ).reverse ()
47- elif isinstance (node , nodes .OrderByNode ):
52+ elif not projections_only and isinstance (node , nodes .OrderByNode ):
4853 return cls .from_node (node .child ).order_with (node .by )
4954 else :
5055 selection = tuple (
@@ -63,7 +68,9 @@ def project(
6368 new_columns = tuple (
6469 (expr .bind_all_variables (self .column_lookup ), id ) for expr , id in projection
6570 )
66- return SquashedSelect (self .root , new_columns , self .predicate , self .ordering )
71+ return SquashedSelect (
72+ self .root , new_columns , self .predicate , self .ordering , self .reverse_root
73+ )
6774
6875 def filter (self , predicate : scalar_exprs .Expression ) -> SquashedSelect :
6976 if self .predicate is None :
@@ -72,18 +79,24 @@ def filter(self, predicate: scalar_exprs.Expression) -> SquashedSelect:
7279 new_predicate = ops .and_op .as_expr (
7380 self .predicate , predicate .bind_all_variables (self .column_lookup )
7481 )
75- return SquashedSelect (self .root , self .columns , new_predicate , self .ordering )
82+ return SquashedSelect (
83+ self .root , self .columns , new_predicate , self .ordering , self .reverse_root
84+ )
7685
def reverse(self) -> SquashedSelect:
    """Return a new select with the row order flipped.

    Each ordering expression is replaced by its ``with_reverse()``
    counterpart and the ``reverse_root`` flag is toggled; the root, columns
    and predicate are carried over unchanged.
    """
    flipped_ordering = tuple([term.with_reverse() for term in self.ordering])
    flipped_root = not self.reverse_root
    return SquashedSelect(
        self.root, self.columns, self.predicate, flipped_ordering, flipped_root
    )
8091
def order_with(self, by: Tuple[order.OrderingExpression, ...]) -> SquashedSelect:
    """Return a new select ordered by ``by``, then by the existing ordering.

    Args:
        by: Ordering expressions to apply. Each one is rebound through
            ``self.column_lookup`` before being stored.

    Returns:
        A ``SquashedSelect`` with the same root, columns, predicate and
        ``reverse_root`` flag, whose ordering is the rebound ``by``
        expressions followed by the current ``self.ordering``.
    """
    adjusted_orderings = tuple(
        order_part.bind_variables(self.column_lookup) for order_part in by
    )
    # The newest ordering takes precedence; the prior ordering is kept as the
    # trailing tie-breaker.
    new_ordering = (*adjusted_orderings, *self.ordering)
    return SquashedSelect(
        self.root, self.columns, self.predicate, new_ordering, self.reverse_root
    )
87100
88101 def maybe_join (
89102 self , right : SquashedSelect , join_def : join_defs .JoinDefinition
@@ -126,8 +139,10 @@ def maybe_join(
126139 new_columns = remap_names (join_def , lselection , rselection )
127140
128141 # Reconstruct ordering
142+ reverse_root = self .reverse_root
129143 if join_type == "right" :
130144 new_ordering = right .ordering
145+ reverse_root = right .reverse_root
131146 elif join_type == "outer" :
132147 if lmask is not None :
133148 prefix = order .OrderingExpression (lmask , order .OrderingDirection .DESC )
@@ -158,18 +173,31 @@ def maybe_join(
158173 new_ordering = self .ordering
159174 else :
160175 raise ValueError (f"Unexpected join type { join_type } " )
161- return SquashedSelect (self .root , new_columns , new_predicate , new_ordering )
176+ return SquashedSelect (
177+ self .root , new_columns , new_predicate , new_ordering , reverse_root
178+ )
162179
def expand(self) -> nodes.BigFrameNode:
    """Rebuild a node tree from this squashed select.

    Starting from ``self.root``, the result is wrapped, innermost first, in:
    a ``ReversedNode`` when ``reverse_root`` is set, a ``FilterNode`` when a
    predicate is present, an ``OrderByNode`` when the ordering is non-empty,
    and finally a ``ProjectionNode`` carrying ``self.columns``.

    Returns:
        The outermost ``ProjectionNode``.
    """
    # Safest to apply predicates first, as it may filter out inputs that cannot be handled by other expressions
    root = self.root
    if self.reverse_root:
        root = nodes.ReversedNode(child=root)
    # `predicate` is Optional: test against None explicitly (as `filter` does)
    # rather than relying on the truthiness of an expression object.
    if self.predicate is not None:
        root = nodes.FilterNode(child=root, predicate=self.predicate)
    if self.ordering:
        root = nodes.OrderByNode(child=root, by=self.ordering)
    return nodes.ProjectionNode(child=root, assignments=self.columns)
171190
172191
def maybe_squash_projection(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
    """Squash a projection stacked directly on another projection.

    When ``node`` is a ``ProjectionNode`` whose child is also a
    ``ProjectionNode``, the node is rebuilt via
    ``SquashedSelect.from_node(node, projections_only=True).expand()``.
    Any other node is returned unchanged.
    """
    if not isinstance(node, nodes.ProjectionNode):
        return node
    if not isinstance(node.child, nodes.ProjectionNode):
        return node
    # Deliberately conservative: only back-to-back projections are squashed
    # here, although filters and reorderings could be squashed as well.
    squashed = SquashedSelect.from_node(node, projections_only=True)
    return squashed.expand()
199+
200+
173201def maybe_rewrite_join (join_node : nodes .JoinNode ) -> nodes .BigFrameNode :
174202 left_side = SquashedSelect .from_node (join_node .left_child )
175203 right_side = SquashedSelect .from_node (join_node .right_child )
0 commit comments