
Commit eedd123

Merge branch 'main' into fix_hash

2 parents: 6d25225 + 3acc494

File tree: 6 files changed, +196 -60 lines


README.rst (1 addition, 0 deletions)

@@ -25,6 +25,7 @@ Documentation
 * `BigQuery DataFrames source code (GitHub) <https://github.com/googleapis/python-bigquery-dataframes>`_
 * `BigQuery DataFrames sample notebooks <https://github.com/googleapis/python-bigquery-dataframes/tree/main/notebooks>`_
 * `BigQuery DataFrames API reference <https://cloud.google.com/python/docs/reference/bigframes/latest/summary_overview>`_
+* `BigQuery DataFrames supported pandas APIs <https://cloud.google.com/python/docs/reference/bigframes/latest/supported_pandas_apis>`_


 Getting started with BigQuery DataFrames

bigframes/dtypes.py (10 additions, 6 deletions)

@@ -658,10 +658,14 @@ def is_compatible(scalar: typing.Any, dtype: Dtype) -> typing.Optional[Dtype]:
     return None


-def lcd_type(dtype1: Dtype, dtype2: Dtype) -> Dtype:
-    """Get the supertype of the two types."""
-    if dtype1 == dtype2:
-        return dtype1
+def lcd_type(*dtypes: Dtype) -> Dtype:
+    if len(dtypes) < 1:
+        raise ValueError("at least one dtype should be provided")
+    if len(dtypes) == 1:
+        return dtypes[0]
+    unique_dtypes = set(dtypes)
+    if len(unique_dtypes) == 1:
+        return unique_dtypes.pop()
     # Implicit conversion currently only supported for numeric types
     hierarchy: list[Dtype] = [
         pd.BooleanDtype(),
@@ -670,9 +674,9 @@ def lcd_type(dtype1: Dtype, dtype2: Dtype) -> Dtype:
         pd.ArrowDtype(pa.decimal256(76, 38)),
         pd.Float64Dtype(),
     ]
-    if (dtype1 not in hierarchy) or (dtype2 not in hierarchy):
+    if any([dtype not in hierarchy for dtype in dtypes]):
         return None
-    lcd_index = max(hierarchy.index(dtype1), hierarchy.index(dtype2))
+    lcd_index = max([hierarchy.index(dtype) for dtype in dtypes])
     return hierarchy[lcd_index]
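
For orientation, a minimal sketch of what the variadic `lcd_type` now accepts (assuming the import path `bigframes.dtypes`, and that `pd.Int64Dtype()` sits between Boolean and Float64 in the hierarchy, as it does in the full file):

```python
import pandas as pd

from bigframes.dtypes import lcd_type

# One dtype, or all-identical dtypes, short-circuit before the hierarchy lookup.
assert lcd_type(pd.Int64Dtype()) == pd.Int64Dtype()
assert lcd_type(pd.Int64Dtype(), pd.Int64Dtype()) == pd.Int64Dtype()

# Mixed numeric dtypes resolve to the widest member of the hierarchy.
assert lcd_type(pd.BooleanDtype(), pd.Int64Dtype()) == pd.Int64Dtype()
assert lcd_type(pd.BooleanDtype(), pd.Int64Dtype(), pd.Float64Dtype()) == pd.Float64Dtype()

# Anything outside the numeric hierarchy has no common supertype.
assert lcd_type(pd.Int64Dtype(), pd.StringDtype(storage="pyarrow")) is None
```

This is what lets the integer-only aggregations tested below keep an `Int64` result instead of widening everything to `Float64`.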
677681

678682

bigframes/session/__init__.py (17 additions, 35 deletions)

@@ -232,7 +232,9 @@ def __init__(
         # Now that we're starting the session, don't allow the options to be
         # changed.
         context._session_started = True
-        self._df_snapshot: Dict[bigquery.TableReference, datetime.datetime] = {}
+        self._df_snapshot: Dict[
+            bigquery.TableReference, Tuple[datetime.datetime, bigquery.Table]
+        ] = {}

     @property
     def bqclient(self):
@@ -699,14 +701,25 @@ def _get_snapshot_sql_and_primary_key(
         column(s), then return those too so that ordering generation can be
         avoided.
         """
-        # If there are primary keys defined, the query engine assumes these
-        # columns are unique, even if the constraint is not enforced. We make
-        # the same assumption and use these columns as the total ordering keys.
+        (
+            snapshot_timestamp,
+            table,
+        ) = bigframes_io.get_snapshot_datetime_and_table_metadata(
+            self.bqclient,
+            table_ref=table_ref,
+            api_name=api_name,
+            cache=self._df_snapshot,
+            use_cache=use_cache,
+        )
+
         if table.location.casefold() != self._location.casefold():
             raise ValueError(
                 f"Current session is in {self._location} but dataset '{table.project}.{table.dataset_id}' is located in {table.location}"
             )

+        # If there are primary keys defined, the query engine assumes these
+        # columns are unique, even if the constraint is not enforced. We make
+        # the same assumption and use these columns as the total ordering keys.
         primary_keys = None
         if (
             (table_constraints := getattr(table, "table_constraints", None)) is not None
@@ -717,37 +730,6 @@ def _get_snapshot_sql_and_primary_key(
         ):
             primary_keys = columns

-        job_config = bigquery.QueryJobConfig()
-        job_config.labels["bigframes-api"] = api_name
-        if use_cache and table.reference in self._df_snapshot.keys():
-            snapshot_timestamp = self._df_snapshot[table.reference]
-
-            # Cache hit could be unexpected. See internal issue 329545805.
-            # Raise a warning with more information about how to avoid the
-            # problems with the cache.
-            warnings.warn(
-                f"Reading cached table from {snapshot_timestamp} to avoid "
-                "incompatibilities with previous reads of this table. To read "
-                "the latest version, set `use_cache=False` or close the "
-                "current session with Session.close() or "
-                "bigframes.pandas.close_session().",
-                # There are many layers before we get to (possibly) the user's code:
-                # pandas.read_gbq_table
-                # -> with_default_session
-                # -> Session.read_gbq_table
-                # -> _read_gbq_table
-                # -> _get_snapshot_sql_and_primary_key
-                stacklevel=6,
-            )
-        else:
-            snapshot_timestamp = list(
-                self.bqclient.query(
-                    "SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
-                    job_config=job_config,
-                ).result()
-            )[0][0]
-            self._df_snapshot[table.reference] = snapshot_timestamp
-
         try:
             table_expression = self.ibis_client.sql(
                 bigframes_io.create_snapshot_sql(table.reference, snapshot_timestamp)
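
The user-visible behavior is unchanged by this refactor; a hedged sketch of the caching contract it preserves (the table name is hypothetical, and `use_cache` is the flag threaded through the call chain above):

```python
import bigframes.pandas as bpd

# The first read pins a snapshot timestamp (and now also the table
# metadata) in the session's _df_snapshot cache.
df = bpd.read_gbq_table("my-project.my_dataset.my_table")

# A repeat read within the session reuses the pinned snapshot and emits
# the warning above; use_cache=False requests a fresh timestamp instead.
df_latest = bpd.read_gbq_table("my-project.my_dataset.my_table", use_cache=False)
```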

bigframes/session/_io/bigquery.py (54 additions, 0 deletions)

@@ -23,6 +23,7 @@
 import types
 from typing import Dict, Iterable, Optional, Sequence, Tuple, Union
 import uuid
+import warnings

 import google.api_core.exceptions
 import google.cloud.bigquery as bigquery
@@ -121,6 +122,59 @@ def table_ref_to_sql(table: bigquery.TableReference) -> str:
     return f"`{table.project}`.`{table.dataset_id}`.`{table.table_id}`"


+def get_snapshot_datetime_and_table_metadata(
+    bqclient: bigquery.Client,
+    table_ref: bigquery.TableReference,
+    *,
+    api_name: str,
+    cache: Dict[bigquery.TableReference, Tuple[datetime.datetime, bigquery.Table]],
+    use_cache: bool = True,
+) -> Tuple[datetime.datetime, bigquery.Table]:
+    cached_table = cache.get(table_ref)
+    if use_cache and cached_table is not None:
+        snapshot_timestamp, _ = cached_table
+
+        # Cache hit could be unexpected. See internal issue 329545805.
+        # Raise a warning with more information about how to avoid the
+        # problems with the cache.
+        warnings.warn(
+            f"Reading cached table from {snapshot_timestamp} to avoid "
+            "incompatibilities with previous reads of this table. To read "
+            "the latest version, set `use_cache=False` or close the "
+            "current session with Session.close() or "
+            "bigframes.pandas.close_session().",
+            # There are many layers before we get to (possibly) the user's code:
+            # pandas.read_gbq_table
+            # -> with_default_session
+            # -> Session.read_gbq_table
+            # -> _read_gbq_table
+            # -> _get_snapshot_sql_and_primary_key
+            # -> get_snapshot_datetime_and_table_metadata
+            stacklevel=7,
+        )
+        return cached_table
+
+    # TODO(swast): It's possible that the table metadata is changed between now
+    # and when we run the CURRENT_TIMESTAMP() query to see when we can time
+    # travel to. Find a way to fetch the table metadata and BQ's current time
+    # atomically.
+    table = bqclient.get_table(table_ref)
+
+    # TODO(b/336521938): Refactor to make sure we set the "bigframes-api"
+    # label wherever we execute a query.
+    job_config = bigquery.QueryJobConfig()
+    job_config.labels["bigframes-api"] = api_name
+    snapshot_timestamp = list(
+        bqclient.query(
+            "SELECT CURRENT_TIMESTAMP() AS `current_timestamp`",
+            job_config=job_config,
+        ).result()
+    )[0][0]
+    cached_table = (snapshot_timestamp, table)
+    cache[table_ref] = cached_table
+    return cached_table
+
+
 def create_snapshot_sql(
     table_ref: bigquery.TableReference, current_timestamp: datetime.datetime
 ) -> str:
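
A hedged sketch of driving the new helper directly (project, dataset, and table names are hypothetical; assumes ambient Google Cloud credentials). The returned timestamp feeds the time-travel SQL that `create_snapshot_sql` builds for the session:

```python
import datetime
from typing import Dict, Tuple

import google.cloud.bigquery as bigquery

import bigframes.session._io.bigquery as bigframes_io

bqclient = bigquery.Client()
table_ref = bigquery.TableReference.from_string("my-project.my_dataset.my_table")

# The session normally owns this cache; an empty one forces the slow path.
cache: Dict[bigquery.TableReference, Tuple[datetime.datetime, bigquery.Table]] = {}

snapshot_timestamp, table = bigframes_io.get_snapshot_datetime_and_table_metadata(
    bqclient,
    table_ref,
    api_name="read_gbq_table",
    cache=cache,
    use_cache=True,
)

# A second call with the same cache takes the cache-hit branch and warns.
sql = bigframes_io.create_snapshot_sql(table_ref, snapshot_timestamp)
```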

tests/system/small/test_dataframe.py (109 additions, 17 deletions)

@@ -2390,12 +2390,27 @@ def test_dataframe_pct_change(scalars_df_index, scalars_pandas_df_index, periods
 def test_dataframe_agg_single_string(scalars_dfs):
     numeric_cols = ["int64_col", "int64_too", "float64_col"]
     scalars_df, scalars_pandas_df = scalars_dfs
+
     bf_result = scalars_df[numeric_cols].agg("sum").to_pandas()
     pd_result = scalars_pandas_df[numeric_cols].agg("sum")

-    # Pandas may produce narrower numeric types, but bigframes always produces Float64
-    pd_result = pd_result.astype("Float64")
-    pd.testing.assert_series_equal(pd_result, bf_result, check_index_type=False)
+    assert bf_result.dtype == "Float64"
+    pd.testing.assert_series_equal(
+        pd_result, bf_result, check_dtype=False, check_index_type=False
+    )
+
+
+def test_dataframe_agg_int_single_string(scalars_dfs):
+    numeric_cols = ["int64_col", "int64_too", "bool_col"]
+    scalars_df, scalars_pandas_df = scalars_dfs
+
+    bf_result = scalars_df[numeric_cols].agg("sum").to_pandas()
+    pd_result = scalars_pandas_df[numeric_cols].agg("sum")
+
+    assert bf_result.dtype == "Int64"
+    pd.testing.assert_series_equal(
+        pd_result, bf_result, check_dtype=False, check_index_type=False
+    )


 def test_dataframe_agg_multi_string(scalars_dfs):
@@ -2431,6 +2446,27 @@ def test_dataframe_agg_multi_string(scalars_dfs):
     ).all()


+def test_dataframe_agg_int_multi_string(scalars_dfs):
+    numeric_cols = ["int64_col", "int64_too", "bool_col"]
+    aggregations = [
+        "sum",
+        "nunique",
+        "count",
+    ]
+    scalars_df, scalars_pandas_df = scalars_dfs
+    bf_result = scalars_df[numeric_cols].agg(aggregations).to_pandas()
+    pd_result = scalars_pandas_df[numeric_cols].agg(aggregations)
+
+    for dtype in bf_result.dtypes:
+        assert dtype == "Int64"
+
+    # Pandas may produce narrower numeric types
+    # Pandas has object index type
+    pd.testing.assert_frame_equal(
+        pd_result, bf_result, check_dtype=False, check_index_type=False
+    )
+
+
 @skip_legacy_pandas
 def test_df_describe(scalars_dfs):
     scalars_df, scalars_pandas_df = scalars_dfs
@@ -2982,6 +3018,58 @@ def test_loc_setitem_bool_series_scalar_error(scalars_dfs):
     pd_df.loc[pd_df["int64_too"] == 1, "string_col"] = 99


+@pytest.mark.parametrize(
+    ("col", "op"),
+    [
+        # Int aggregates
+        pytest.param("int64_col", lambda x: x.sum(), id="int-sum"),
+        pytest.param("int64_col", lambda x: x.min(), id="int-min"),
+        pytest.param("int64_col", lambda x: x.max(), id="int-max"),
+        pytest.param("int64_col", lambda x: x.count(), id="int-count"),
+        pytest.param("int64_col", lambda x: x.nunique(), id="int-nunique"),
+        # Float aggregates
+        pytest.param("float64_col", lambda x: x.count(), id="float-count"),
+        pytest.param("float64_col", lambda x: x.nunique(), id="float-nunique"),
+        # Bool aggregates
+        pytest.param("bool_col", lambda x: x.sum(), id="bool-sum"),
+        pytest.param("bool_col", lambda x: x.count(), id="bool-count"),
+        pytest.param("bool_col", lambda x: x.nunique(), id="bool-nunique"),
+        # String aggregates
+        pytest.param("string_col", lambda x: x.count(), id="string-count"),
+        pytest.param("string_col", lambda x: x.nunique(), id="string-nunique"),
+    ],
+)
+def test_dataframe_aggregate_int(scalars_df_index, scalars_pandas_df_index, col, op):
+    bf_result = op(scalars_df_index[[col]]).to_pandas()
+    pd_result = op(scalars_pandas_df_index[[col]])
+
+    # Check dtype separately
+    assert bf_result.dtype == "Int64"
+
+    # Pandas may produce narrower numeric types
+    # Pandas has object index type
+    assert_series_equal(pd_result, bf_result, check_dtype=False, check_index_type=False)
+
+
+@pytest.mark.parametrize(
+    ("col", "op"),
+    [
+        pytest.param("bool_col", lambda x: x.min(), id="bool-min"),
+        pytest.param("bool_col", lambda x: x.max(), id="bool-max"),
+    ],
+)
+def test_dataframe_aggregate_bool(scalars_df_index, scalars_pandas_df_index, col, op):
+    bf_result = op(scalars_df_index[[col]]).to_pandas()
+    pd_result = op(scalars_pandas_df_index[[col]])
+
+    # Check dtype separately
+    assert bf_result.dtype == "boolean"
+
+    # Pandas may produce narrower numeric types
+    # Pandas has object index type
+    assert_series_equal(pd_result, bf_result, check_dtype=False, check_index_type=False)
+
+
 @pytest.mark.parametrize(
     ("ordered"),
     [
@@ -2990,34 +3078,38 @@ def test_loc_setitem_bool_series_scalar_error(scalars_dfs):
     ],
 )
 @pytest.mark.parametrize(
-    ("op"),
+    ("op", "bf_dtype"),
     [
-        (lambda x: x.sum(numeric_only=True)),
-        (lambda x: x.mean(numeric_only=True)),
-        (lambda x: x.min(numeric_only=True)),
-        (lambda x: x.max(numeric_only=True)),
-        (lambda x: x.std(numeric_only=True)),
-        (lambda x: x.var(numeric_only=True)),
-        (lambda x: x.count(numeric_only=False)),
-        (lambda x: x.nunique()),
+        (lambda x: x.sum(numeric_only=True), "Float64"),
+        (lambda x: x.mean(numeric_only=True), "Float64"),
+        (lambda x: x.min(numeric_only=True), "Float64"),
+        (lambda x: x.max(numeric_only=True), "Float64"),
+        (lambda x: x.std(numeric_only=True), "Float64"),
+        (lambda x: x.var(numeric_only=True), "Float64"),
+        (lambda x: x.count(numeric_only=False), "Int64"),
+        (lambda x: x.nunique(), "Int64"),
     ],
     ids=["sum", "mean", "min", "max", "std", "var", "count", "nunique"],
 )
-def test_dataframe_aggregates(scalars_df_index, scalars_pandas_df_index, op, ordered):
+def test_dataframe_aggregates(
+    scalars_df_index, scalars_pandas_df_index, op, bf_dtype, ordered
+):
     col_names = ["int64_too", "float64_col", "string_col", "int64_col", "bool_col"]
     bf_series = op(scalars_df_index[col_names])
-    pd_series = op(scalars_pandas_df_index[col_names])
     bf_result = bf_series.to_pandas(ordered=ordered)
+    pd_result = op(scalars_pandas_df_index[col_names])
+
+    # Check dtype separately
+    assert bf_result.dtype == bf_dtype

     # Pandas may produce narrower numeric types, but bigframes always produces Float64
     # Pandas has object index type
-    pd_series.index = pd_series.index.astype(pd.StringDtype(storage="pyarrow"))
     assert_series_equal(
-        pd_series,
+        pd_result,
         bf_result,
+        check_dtype=False,
         check_index_type=False,
         ignore_order=not ordered,
-        check_dtype=False,
     )
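Taken together, these tests pin down the dtype contract the `lcd_type` change enables: integer and boolean inputs now aggregate to `Int64` where they previously widened to `Float64`. A hedged sketch of the user-visible effect (column names follow the `scalars` test fixtures; the table itself is hypothetical):

```python
import bigframes.pandas as bpd

df = bpd.read_gbq_table("my-project.my_dataset.scalars")  # hypothetical table

# Int64 and boolean columns now aggregate to Int64...
int_sums = df[["int64_col", "int64_too", "bool_col"]].agg("sum")
assert int_sums.to_pandas().dtype == "Int64"

# ...while mixing in a float column still widens the result to Float64.
mixed_sums = df[["int64_col", "float64_col"]].agg("sum")
assert mixed_sums.to_pandas().dtype == "Float64"
```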

tests/unit/session/test_session.py (5 additions, 2 deletions)

@@ -42,8 +42,11 @@ def test_read_gbq_cached_table():
         google.cloud.bigquery.DatasetReference("my-project", "my_dataset"),
         "my_table",
     )
-    session._df_snapshot[table_ref] = datetime.datetime(
-        1999, 1, 2, 3, 4, 5, 678901, tzinfo=datetime.timezone.utc
+    table = google.cloud.bigquery.Table(table_ref)
+    table._properties["location"] = session._location
+    session._df_snapshot[table_ref] = (
+        datetime.datetime(1999, 1, 2, 3, 4, 5, 678901, tzinfo=datetime.timezone.utc),
+        table,
     )

     def get_table_mock(table_ref):
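
Seeding `_df_snapshot` with a `(timestamp, Table)` tuple lets the unit test exercise the cache-hit branch without touching BigQuery. A hedged sketch of additionally asserting the warning that branch emits (the fixture is hypothetical; `warnings.warn` defaults to `UserWarning`):

```python
import pytest

def test_cached_read_warns(session_with_seeded_cache):  # hypothetical fixture
    # The cache-hit branch of get_snapshot_datetime_and_table_metadata
    # warns that a pinned snapshot is being reused.
    with pytest.warns(UserWarning, match="Reading cached table"):
        session_with_seeded_cache.read_gbq_table("my-project.my_dataset.my_table")
```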
