Skip to content

Commit b23cf83

Browse files
refactor: ExecuteResult is reusable, sampleable (#2159)
1 parent ecee2bc commit b23cf83

27 files changed

+721
-369
lines changed

bigframes/core/array_value.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
import pandas
2424
import pyarrow as pa
2525

26-
from bigframes.core import agg_expressions
26+
from bigframes.core import agg_expressions, bq_data
2727
import bigframes.core.expression as ex
2828
import bigframes.core.guid
2929
import bigframes.core.identifiers as ids
@@ -63,7 +63,7 @@ def from_pyarrow(cls, arrow_table: pa.Table, session: Session):
6363
def from_managed(cls, source: local_data.ManagedArrowTable, session: Session):
6464
scan_list = nodes.ScanList(
6565
tuple(
66-
nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
66+
nodes.ScanItem(ids.ColumnId(item.column), item.column)
6767
for item in source.schema.items
6868
)
6969
)
@@ -88,9 +88,9 @@ def from_range(cls, start, end, step):
8888
def from_table(
8989
cls,
9090
table: google.cloud.bigquery.Table,
91-
schema: schemata.ArraySchema,
9291
session: Session,
9392
*,
93+
columns: Optional[Sequence[str]] = None,
9494
predicate: Optional[str] = None,
9595
at_time: Optional[datetime.datetime] = None,
9696
primary_key: Sequence[str] = (),
@@ -100,7 +100,7 @@ def from_table(
100100
if offsets_col and primary_key:
101101
raise ValueError("must set at most one of 'offests', 'primary_key'")
102102
# define data source only for needed columns, this makes row-hashing cheaper
103-
table_def = nodes.GbqTable.from_table(table, columns=schema.names)
103+
table_def = bq_data.GbqTable.from_table(table, columns=columns or ())
104104

105105
# create ordering from info
106106
ordering = None
@@ -111,15 +111,17 @@ def from_table(
111111
[ids.ColumnId(key_part) for key_part in primary_key]
112112
)
113113

114+
bf_schema = schemata.ArraySchema.from_bq_table(table, columns=columns)
114115
# Scan all columns by default, we define this list as it can be pruned while preserving source_def
115116
scan_list = nodes.ScanList(
116117
tuple(
117-
nodes.ScanItem(ids.ColumnId(item.column), item.dtype, item.column)
118-
for item in schema.items
118+
nodes.ScanItem(ids.ColumnId(item.column), item.column)
119+
for item in bf_schema.items
119120
)
120121
)
121-
source_def = nodes.BigqueryDataSource(
122+
source_def = bq_data.BigqueryDataSource(
122123
table=table_def,
124+
schema=bf_schema,
123125
at_time=at_time,
124126
sql_predicate=predicate,
125127
ordering=ordering,
@@ -130,7 +132,7 @@ def from_table(
130132
@classmethod
131133
def from_bq_data_source(
132134
cls,
133-
source: nodes.BigqueryDataSource,
135+
source: bq_data.BigqueryDataSource,
134136
scan_list: nodes.ScanList,
135137
session: Session,
136138
):

bigframes/core/blocks.py

Lines changed: 41 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,6 @@
3737
Optional,
3838
Sequence,
3939
Tuple,
40-
TYPE_CHECKING,
4140
Union,
4241
)
4342
import warnings
@@ -70,9 +69,6 @@
7069
from bigframes.session import dry_runs, execution_spec
7170
from bigframes.session import executor as executors
7271

73-
if TYPE_CHECKING:
74-
from bigframes.session.executor import ExecuteResult
75-
7672
# Type constraint for wherever column labels are used
7773
Label = typing.Hashable
7874

@@ -98,7 +94,6 @@
9894
LevelsType = typing.Union[LevelType, typing.Sequence[LevelType]]
9995

10096

101-
@dataclasses.dataclass
10297
class PandasBatches(Iterator[pd.DataFrame]):
10398
"""Interface for mutable objects with state represented by a block value object."""
10499

@@ -271,10 +266,14 @@ def shape(self) -> typing.Tuple[int, int]:
271266
except Exception:
272267
pass
273268

274-
row_count = self.session._executor.execute(
275-
self.expr.row_count(),
276-
execution_spec.ExecutionSpec(promise_under_10gb=True, ordered=False),
277-
).to_py_scalar()
269+
row_count = (
270+
self.session._executor.execute(
271+
self.expr.row_count(),
272+
execution_spec.ExecutionSpec(promise_under_10gb=True, ordered=False),
273+
)
274+
.batches()
275+
.to_py_scalar()
276+
)
278277
return (row_count, len(self.value_columns))
279278

280279
@property
@@ -584,7 +583,7 @@ def to_arrow(
584583
ordered=ordered,
585584
),
586585
)
587-
pa_table = execute_result.to_arrow_table()
586+
pa_table = execute_result.batches().to_arrow_table()
588587

589588
pa_index_labels = []
590589
for index_level, index_label in enumerate(self._index_labels):
@@ -636,17 +635,13 @@ def to_pandas(
636635
max_download_size, sampling_method, random_state
637636
)
638637

639-
ex_result = self._materialize_local(
638+
return self._materialize_local(
640639
materialize_options=MaterializationOptions(
641640
downsampling=sampling,
642641
allow_large_results=allow_large_results,
643642
ordered=ordered,
644643
)
645644
)
646-
df = ex_result.to_pandas()
647-
df = self._copy_index_to_pandas(df)
648-
df.set_axis(self.column_labels, axis=1, copy=False)
649-
return df, ex_result.query_job
650645

651646
def _get_sampling_option(
652647
self,
@@ -683,7 +678,7 @@ def try_peek(
683678
self.expr,
684679
execution_spec.ExecutionSpec(promise_under_10gb=under_10gb, peek=n),
685680
)
686-
df = result.to_pandas()
681+
df = result.batches().to_pandas()
687682
return self._copy_index_to_pandas(df)
688683
else:
689684
return None
@@ -704,13 +699,14 @@ def to_pandas_batches(
704699
if (allow_large_results is not None)
705700
else not bigframes.options._allow_large_results
706701
)
707-
execute_result = self.session._executor.execute(
702+
execution_result = self.session._executor.execute(
708703
self.expr,
709704
execution_spec.ExecutionSpec(
710705
promise_under_10gb=under_10gb,
711706
ordered=True,
712707
),
713708
)
709+
result_batches = execution_result.batches()
714710

715711
# To reduce the number of edge cases to consider when working with the
716712
# results of this, always return at least one DataFrame. See:
@@ -724,19 +720,21 @@ def to_pandas_batches(
724720
dfs = map(
725721
lambda a: a[0],
726722
itertools.zip_longest(
727-
execute_result.to_pandas_batches(page_size, max_results),
723+
result_batches.to_pandas_batches(page_size, max_results),
728724
[0],
729725
fillvalue=empty_val,
730726
),
731727
)
732728
dfs = iter(map(self._copy_index_to_pandas, dfs))
733729

734-
total_rows = execute_result.total_rows
730+
total_rows = result_batches.approx_total_rows
735731
if (total_rows is not None) and (max_results is not None):
736732
total_rows = min(total_rows, max_results)
737733

738734
return PandasBatches(
739-
dfs, total_rows, total_bytes_processed=execute_result.total_bytes_processed
735+
dfs,
736+
total_rows,
737+
total_bytes_processed=execution_result.total_bytes_processed,
740738
)
741739

742740
def _copy_index_to_pandas(self, df: pd.DataFrame) -> pd.DataFrame:
@@ -754,7 +752,7 @@ def _copy_index_to_pandas(self, df: pd.DataFrame) -> pd.DataFrame:
754752

755753
def _materialize_local(
756754
self, materialize_options: MaterializationOptions = MaterializationOptions()
757-
) -> ExecuteResult:
755+
) -> tuple[pd.DataFrame, Optional[bigquery.QueryJob]]:
758756
"""Run query and download results as a pandas DataFrame. Return the total number of results as well."""
759757
# TODO(swast): Allow for dry run and timeout.
760758
under_10gb = (
@@ -769,9 +767,11 @@ def _materialize_local(
769767
ordered=materialize_options.ordered,
770768
),
771769
)
770+
result_batches = execute_result.batches()
771+
772772
sample_config = materialize_options.downsampling
773-
if execute_result.total_bytes is not None:
774-
table_mb = execute_result.total_bytes / _BYTES_TO_MEGABYTES
773+
if result_batches.approx_total_bytes is not None:
774+
table_mb = result_batches.approx_total_bytes / _BYTES_TO_MEGABYTES
775775
max_download_size = sample_config.max_download_size
776776
fraction = (
777777
max_download_size / table_mb
@@ -792,7 +792,7 @@ def _materialize_local(
792792

793793
# TODO: Maybe materialize before downsampling
794794
# Some downsampling methods
795-
if fraction < 1 and (execute_result.total_rows is not None):
795+
if fraction < 1 and (result_batches.approx_total_rows is not None):
796796
if not sample_config.enable_downsampling:
797797
raise RuntimeError(
798798
f"The data size ({table_mb:.2f} MB) exceeds the maximum download limit of "
@@ -811,7 +811,7 @@ def _materialize_local(
811811
"the downloading limit."
812812
)
813813
warnings.warn(msg, category=UserWarning)
814-
total_rows = execute_result.total_rows
814+
total_rows = result_batches.approx_total_rows
815815
# Remove downsampling config from subsequent invocations, as otherwise could result in many
816816
# iterations if downsampling undershoots
817817
return self._downsample(
@@ -823,7 +823,10 @@ def _materialize_local(
823823
MaterializationOptions(ordered=materialize_options.ordered)
824824
)
825825
else:
826-
return execute_result
826+
df = result_batches.to_pandas()
827+
df = self._copy_index_to_pandas(df)
828+
df.set_axis(self.column_labels, axis=1, copy=False)
829+
return df, execute_result.query_job
827830

828831
def _downsample(
829832
self, total_rows: int, sampling_method: str, fraction: float, random_state
@@ -1662,15 +1665,19 @@ def retrieve_repr_request_results(
16621665
ordered=True,
16631666
),
16641667
)
1665-
row_count = self.session._executor.execute(
1666-
self.expr.row_count(),
1667-
execution_spec.ExecutionSpec(
1668-
promise_under_10gb=True,
1669-
ordered=False,
1670-
),
1671-
).to_py_scalar()
1668+
row_count = (
1669+
self.session._executor.execute(
1670+
self.expr.row_count(),
1671+
execution_spec.ExecutionSpec(
1672+
promise_under_10gb=True,
1673+
ordered=False,
1674+
),
1675+
)
1676+
.batches()
1677+
.to_py_scalar()
1678+
)
16721679

1673-
head_df = head_result.to_pandas()
1680+
head_df = head_result.batches().to_pandas()
16741681
return self._copy_index_to_pandas(head_df), row_count, head_result.query_job
16751682

16761683
def promote_offsets(self, label: Label = None) -> typing.Tuple[Block, str]:

0 commit comments

Comments
 (0)