Skip to content
10 changes: 8 additions & 2 deletions bigframes/functions/remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import collections.abc
import hashlib
import inspect
import logging
Expand Down Expand Up @@ -1043,6 +1044,8 @@ def wrapper(func):
"Types are required to use @remote_function."
)
input_types.append(param_type)
elif not isinstance(input_types, collections.abc.Sequence):
input_types = [input_types]

if output_type is None:
if (output_type := signature.return_annotation) is inspect.Signature.empty:
Expand All @@ -1055,9 +1058,12 @@ def wrapper(func):
# The function will actually be receiving a pandas Series, but allow both
# BigQuery DataFrames and pandas object types for compatibility.
is_row_processor = False
if input_types == bigframes.series.Series or input_types == pandas.Series:
if len(input_types) == 1 and (
(input_type := input_types[0]) == bigframes.series.Series
or input_type == pandas.Series
):
warnings.warn(
"input_types=Series scenario is in preview.",
"input_types=Series is in preview.",
stacklevel=1,
category=bigframes.exceptions.PreviewWarning,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/system/small/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ def add_ints(row):

with pytest.warns(
bigframes.exceptions.PreviewWarning,
match="input_types=Series scenario is in preview.",
match="input_types=Series is in preview.",
):
add_ints_remote = session.remote_function(bigframes.series.Series, int)(
add_ints
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import pytest

import bigframes
import bigframes.clients
import bigframes.core as core
import bigframes.core.ordering
import bigframes.dataframe
Expand Down Expand Up @@ -97,6 +98,9 @@ def query_mock(query, *args, **kwargs):

bqoptions = bigframes.BigQueryOptions(credentials=credentials, location=location)
session = bigframes.Session(context=bqoptions, clients_provider=clients_provider)
session._bq_connection_manager = mock.create_autospec(
bigframes.clients.BqConnectionManager, instance=True
)
return session


Expand Down
38 changes: 38 additions & 0 deletions tests/unit/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re

import bigframes_vendored.ibis.backends.bigquery.datatypes as third_party_ibis_bqtypes
from ibis.expr import datatypes as ibis_types
import pandas
import pytest

import bigframes.dtypes
import bigframes.functions.remote_function
import bigframes.series
from tests.unit import resources


@pytest.mark.parametrize(
"series_type",
(
pytest.param(
pandas.Series,
id="pandas.Series",
),
pytest.param(
bigframes.series.Series,
id="bigframes.series.Series",
),
),
)
def test_series_input_types_to_str(series_type):
"""Check that is_row_processor=True uses str as the input type to serialize a row."""
session = resources.create_bigquery_session()
remote_function_decorator = bigframes.functions.remote_function.remote_function(
session=session
)

with pytest.warns(
bigframes.exceptions.PreviewWarning,
match=re.escape("input_types=Series is in preview."),
):

@remote_function_decorator
def axis_1_function(myparam: series_type) -> str: # type: ignore
return "Hello, " + myparam["str_col"] + "!" # type: ignore

# Still works as a normal function.
assert axis_1_function(pandas.Series({"str_col": "World"})) == "Hello, World!"
assert axis_1_function.ibis_node is not None


def test_supported_types_correspond():
# The same types should be representable by the supported Python and BigQuery types.
ibis_types_from_python = {
Expand Down