@@ -100,6 +100,7 @@ def cache_results_table(
100100 original_root : nodes .BigFrameNode ,
101101 table : bigquery .Table ,
102102 ordering : order .RowOrdering ,
103+ num_rows : Optional [int ] = None ,
103104 ):
104105 # Assumption: GBQ cached table uses field name as bq column name
105106 scan_list = nodes .ScanList (
@@ -112,7 +113,7 @@ def cache_results_table(
112113 source = nodes .BigqueryDataSource (
113114 nodes .GbqTable .from_table (table ),
114115 ordering = ordering ,
115- n_rows = table . num_rows ,
116+ n_rows = num_rows ,
116117 ),
117118 scan_list = scan_list ,
118119 table_session = original_root .session ,
@@ -468,14 +469,16 @@ def _cache_with_cluster_cols(
468469 plan , sort_rows = False , materialize_all_order_keys = True
469470 )
470471 )
471- tmp_table_ref = self ._sql_as_cached_temp_table (
472+ tmp_table_ref , num_rows = self ._sql_as_cached_temp_table (
472473 compiled .sql ,
473474 compiled .sql_schema ,
474475 cluster_cols = bq_io .select_cluster_cols (compiled .sql_schema , cluster_cols ),
475476 )
476477 tmp_table = self .bqclient .get_table (tmp_table_ref )
477478 assert compiled .row_order is not None
478- self .cache .cache_results_table (array_value .node , tmp_table , compiled .row_order )
479+ self .cache .cache_results_table (
480+ array_value .node , tmp_table , compiled .row_order , num_rows = num_rows
481+ )
479482
480483 def _cache_with_offsets (self , array_value : bigframes .core .ArrayValue ):
481484 """Executes the query and uses the resulting table to rewrite future executions."""
@@ -487,14 +490,16 @@ def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
487490 sort_rows = False ,
488491 )
489492 )
490- tmp_table_ref = self ._sql_as_cached_temp_table (
493+ tmp_table_ref , num_rows = self ._sql_as_cached_temp_table (
491494 compiled .sql ,
492495 compiled .sql_schema ,
493496 cluster_cols = [offset_column ],
494497 )
495498 tmp_table = self .bqclient .get_table (tmp_table_ref )
496499 assert compiled .row_order is not None
497- self .cache .cache_results_table (array_value .node , tmp_table , compiled .row_order )
500+ self .cache .cache_results_table (
501+ array_value .node , tmp_table , compiled .row_order , num_rows = num_rows
502+ )
498503
499504 def _cache_with_session_awareness (
500505 self ,
@@ -552,7 +557,7 @@ def _sql_as_cached_temp_table(
552557 sql : str ,
553558 schema : Sequence [bigquery .SchemaField ],
554559 cluster_cols : Sequence [str ],
555- ) -> bigquery .TableReference :
560+ ) -> tuple [ bigquery .TableReference , Optional [ int ]] :
556561 assert len (cluster_cols ) <= _MAX_CLUSTER_COLUMNS
557562 temp_table = self .storage_manager .create_temp_table (schema , cluster_cols )
558563
@@ -567,8 +572,8 @@ def _sql_as_cached_temp_table(
567572 job_config = job_config ,
568573 )
569574 assert query_job is not None
570- query_job .result ()
571- return query_job .destination
575+ iter = query_job .result ()
576+ return query_job .destination , iter . total_rows
572577
573578 def _validate_result_schema (
574579 self ,
0 commit comments