Skip to content
20 changes: 12 additions & 8 deletions bigframes/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import bigframes.dtypes
import bigframes.operations as ops
import bigframes.operations.aggregations as agg_ops
import bigframes.session._io.bigquery

if typing.TYPE_CHECKING:
from bigframes.session import Session
Expand Down Expand Up @@ -153,25 +154,28 @@ def start_query(

def cached(self, cluster_cols: typing.Sequence[str]) -> ArrayValue:
    """Write the ArrayValue to a session table and create a new block object that references it.

    Args:
        cluster_cols:
            Column ids used to cluster the materialized session table.

    Returns:
        A new ``ArrayValue`` backed by the cached session table. Visible
        columns, hidden ordering columns, and the ordering metadata of the
        original value are all preserved.
    """
    # NOTE(review): this span was diff residue with old and new lines
    # interleaved; this is the reconstructed post-change implementation.
    compiled_value = self.compile()
    # Expose hidden ordering columns so they survive the round trip through
    # the session table and can be re-attached below.
    ibis_expr = compiled_value._to_ibis_expr(
        ordering_mode="unordered", expose_hidden_cols=True
    )
    tmp_table = self.session._ibis_to_session_table(
        ibis_expr, cluster_cols=cluster_cols, api_name="cached"
    )

    # Re-read the materialized table through the ibis client by its
    # fully-qualified name.
    table_expression = self.session.ibis_client.table(
        f"{tmp_table.project}.{tmp_table.dataset_id}.{tmp_table.table_id}"
    )
    new_columns = [table_expression[column] for column in compiled_value.column_ids]
    new_hidden_columns = [
        table_expression[column]
        for column in compiled_value._hidden_ordering_column_names
    ]
    return ArrayValue.from_ibis(
        self.session,
        table_expression,
        columns=new_columns,
        hidden_ordering_columns=new_hidden_columns,
        ordering=compiled_value._ordering,
    )

# Operations
Expand Down
72 changes: 25 additions & 47 deletions bigframes/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import datetime
import logging
import os
import re
Expand Down Expand Up @@ -325,9 +326,15 @@ def _query_to_destination(
# internal issue 303057336.
# Since we have a `statement_type == 'SELECT'`, schema should be populated.
schema = typing.cast(Iterable[bigquery.SchemaField], dry_run_job.schema)
temp_table = self._create_session_table_empty(api_name, schema, index_cols)
cluster_cols = [
item.name
for item in schema
if (item.name in index_cols) and _can_cluster_bq(item)
][:_MAX_CLUSTER_COLUMNS]
temp_table = self._create_empty_temp_table(schema, cluster_cols)

job_config = bigquery.QueryJobConfig()
job_config.labels["bigframes-api"] = api_name
job_config.destination = temp_table

try:
Expand Down Expand Up @@ -422,17 +429,15 @@ def _read_gbq_query(
index_col: Iterable[str] | str = (),
col_order: Iterable[str] = (),
max_results: Optional[int] = None,
api_name: str,
api_name: str = "read_gbq_query",
) -> dataframe.DataFrame:
if isinstance(index_col, str):
index_cols = [index_col]
else:
index_cols = list(index_col)

destination, query_job = self._query_to_destination(
query,
index_cols,
api_name=api_name,
query, index_cols, api_name=api_name
)

# If there was no destination table, that means the query must have
Expand Down Expand Up @@ -1273,53 +1278,26 @@ def _create_session_table(self) -> bigquery.TableReference:
)
return dataset.table(table_name)

def _create_empty_temp_table(
    self,
    schema: Iterable[bigquery.SchemaField],
    cluster_cols: List[str],
) -> bigquery.TableReference:
    """Create an empty clustered temp table in the session's anonymous dataset.

    Args:
        schema:
            BigQuery schema for the new table.
        cluster_cols:
            Columns to cluster the table by. Callers are expected to have
            already filtered these to clusterable columns (see
            ``_can_cluster_bq`` usage at the call site).

    Returns:
        A reference to the newly created table.
    """
    # NOTE(review): this span was diff residue (old DDL-based
    # `_create_session_table_empty` interleaved with its replacement); this
    # is the reconstructed post-change implementation, which delegates table
    # creation to the shared `bigframes_io.create_temp_table` helper.
    dataset = self._anonymous_dataset
    # Timezone-aware "now" plus the session-wide default TTL so the table
    # cleans itself up.
    expiration = (
        datetime.datetime.now(datetime.timezone.utc) + constants.DEFAULT_EXPIRATION
    )

    table = bigframes_io.create_temp_table(
        self.bqclient,
        dataset,
        expiration,
        schema=schema,
        cluster_columns=cluster_cols,
    )
    # The helper returns the fully-qualified table id as a string; convert it
    # back to a TableReference for callers that set it as a job destination.
    return bigquery.TableReference.from_string(table)

def _create_sequential_ordering(
self,
Expand Down Expand Up @@ -1356,13 +1334,13 @@ def _ibis_to_session_table(
cluster_cols: Iterable[str],
api_name: str,
) -> bigquery.TableReference:
desination, _ = self._query_to_destination(
destination, _ = self._query_to_destination(
self.ibis_client.compile(table),
index_cols=list(cluster_cols),
api_name=api_name,
)
# There should always be a destination table for this query type.
return typing.cast(bigquery.TableReference, desination)
return typing.cast(bigquery.TableReference, destination)

def remote_function(
self,
Expand Down
8 changes: 7 additions & 1 deletion bigframes/session/_io/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import datetime
import textwrap
import types
from typing import Dict, Iterable, Union
from typing import Dict, Iterable, Optional, Union
import uuid

import google.cloud.bigquery as bigquery
Expand Down Expand Up @@ -121,11 +121,17 @@ def create_temp_table(
bqclient: bigquery.Client,
dataset: bigquery.DatasetReference,
expiration: datetime.datetime,
*,
schema: Optional[Iterable[bigquery.SchemaField]] = None,
cluster_columns: Optional[list[str]] = None,
) -> str:
"""Create an empty table with an expiration in the desired dataset."""
table_ref = random_table(dataset)
destination = bigquery.Table(table_ref)
destination.expires = expiration
destination.schema = schema
if cluster_columns:
destination.clustering_fields = cluster_columns
bqclient.create_table(destination)
return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}"

Expand Down