Skip to content

fix: add warnings for duplicated or conflicting type hints in bigfram… #1956

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions bigframes/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ class FunctionAxisOnePreviewWarning(PreviewWarning):
"""Remote Function and Managed UDF with axis=1 preview."""


class FunctionConflictTypeHintWarning(UserWarning):
    """Conflicting type hints in a BigFrames function.

    Issued when the types passed to the function decorator disagree with the
    type hints in the decorated function's signature; the types from the
    decorator are the ones that get used.
    """


class FunctionPackageVersionWarning(PreviewWarning):
"""
Managed UDF package versions for Numpy, Pandas, and Pyarrow may not
Expand Down
20 changes: 20 additions & 0 deletions bigframes/functions/_function_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,13 +536,23 @@ def wrapper(func):
if input_types is not None:
if not isinstance(input_types, collections.abc.Sequence):
input_types = [input_types]
if _utils.has_conflict_input_type(py_sig, input_types):
msg = bfe.format_message(
"Conflicting input types detected, using the one from the decorator."
)
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
py_sig = py_sig.replace(
parameters=[
par.replace(annotation=itype)
for par, itype in zip(py_sig.parameters.values(), input_types)
]
)
if output_type:
if _utils.has_conflict_output_type(py_sig, output_type):
msg = bfe.format_message(
"Conflicting return type detected, using the one from the decorator."
)
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
py_sig = py_sig.replace(return_annotation=output_type)

# Try to get input types via type annotations.
Expand Down Expand Up @@ -838,13 +848,23 @@ def wrapper(func):
if input_types is not None:
if not isinstance(input_types, collections.abc.Sequence):
input_types = [input_types]
if _utils.has_conflict_input_type(py_sig, input_types):
msg = bfe.format_message(
"Conflicting input types detected, using the one from the decorator."
)
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
py_sig = py_sig.replace(
parameters=[
par.replace(annotation=itype)
for par, itype in zip(py_sig.parameters.values(), input_types)
]
)
if output_type:
if _utils.has_conflict_output_type(py_sig, output_type):
msg = bfe.format_message(
"Conflicting return type detected, using the one from the decorator."
)
warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning)
py_sig = py_sig.replace(return_annotation=output_type)

# The function will actually be receiving a pandas Series, but allow
Expand Down
36 changes: 35 additions & 1 deletion bigframes/functions/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@


import hashlib
import inspect
import json
import sys
import typing
from typing import cast, Optional, Set
from typing import Any, cast, Optional, Sequence, Set
import warnings

import cloudpickle
Expand Down Expand Up @@ -290,3 +291,36 @@ def post_process(input):
return bbq.json_extract_string_array(input, value_dtype=result_dtype)

return post_process


def has_conflict_input_type(
    signature: inspect.Signature,
    input_types: Sequence[Any],
) -> bool:
    """Report whether ``input_types`` clash with the signature's parameters.

    A conflict exists when the number of ``input_types`` differs from the
    number of parameters in ``signature``, or when an annotated parameter has
    an annotation that is not equal to the input type at the same position.
    Parameters without an annotation never conflict.

    Args:
        signature: Signature whose parameter annotations are inspected.
        input_types: Types compared positionally against the parameters.

    Returns:
        True if a conflict is found, False otherwise.
    """
    declared = [param.annotation for param in signature.parameters.values()]

    # A count mismatch is treated as a conflict on its own.
    if len(declared) != len(input_types):
        return True

    unannotated = inspect.Parameter.empty
    # Only parameters that actually carry a type hint can disagree.
    return any(
        annotation is not unannotated and annotation != expected
        for annotation, expected in zip(declared, input_types)
    )


def has_conflict_output_type(
    signature: inspect.Signature,
    output_type: Any,
) -> bool:
    """Report whether ``output_type`` clashes with the return annotation.

    A signature without a return annotation never conflicts; otherwise the
    annotation conflicts when it is not equal to ``output_type``.

    Args:
        signature: Signature whose return annotation is inspected.
        output_type: Type compared against the return annotation.

    Returns:
        True if the return annotation exists and differs from
        ``output_type``, False otherwise.
    """
    declared = signature.return_annotation
    # The "no annotation" sentinel is shared by Signature and Parameter.
    is_annotated = declared is not inspect.Signature.empty
    return is_annotated and declared != output_type
42 changes: 33 additions & 9 deletions tests/system/large/functions/test_managed_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import google.api_core.exceptions
import pandas
import pyarrow
Expand All @@ -31,12 +33,22 @@
def test_managed_function_array_output(session, scalars_dfs, dataset_id):
try:

@session.udf(
dataset=dataset_id,
name=prefixer.create_prefix(),
with warnings.catch_warnings(record=True) as record:

@session.udf(
dataset=dataset_id,
name=prefixer.create_prefix(),
)
def featurize(x: int) -> list[float]:
return [float(i) for i in [x, x + 1, x + 2]]

# No conflict warnings are expected when there are no redundant type hints.
input_type_warning = "Conflicting input types detected"
return_type_warning = "Conflicting return type detected"
assert not any(input_type_warning in str(warning.message) for warning in record)
assert not any(
return_type_warning in str(warning.message) for warning in record
)
def featurize(x: int) -> list[float]:
return [float(i) for i in [x, x + 1, x + 2]]

scalars_df, scalars_pandas_df = scalars_dfs

Expand Down Expand Up @@ -222,7 +234,10 @@ def add(x: int, y: int) -> int:
def test_managed_function_series_combine_array_output(session, dataset_id, scalars_dfs):
try:

def add_list(x: int, y: int) -> list[int]:
# The type hints in this function's signature conflict with the decorator
# arguments. The `input_types` and `output_type` arguments from the udf
# decorator take precedence and will be used instead.
def add_list(x, y: bool) -> list[bool]:
return [x, y]

scalars_df, scalars_pandas_df = scalars_dfs
Expand All @@ -234,9 +249,18 @@ def add_list(x: int, y: int) -> list[int]:
# Make sure there are NA values in the test column.
assert any([pandas.isna(val) for val in bf_df[int_col_name_with_nulls]])

add_list_managed_func = session.udf(
dataset=dataset_id, name=prefixer.create_prefix()
)(add_list)
with warnings.catch_warnings(record=True) as record:
add_list_managed_func = session.udf(
input_types=[int, int],
output_type=list[int],
dataset=dataset_id,
name=prefixer.create_prefix(),
)(add_list)

input_type_warning = "Conflicting input types detected"
assert any(input_type_warning in str(warning.message) for warning in record)
return_type_warning = "Conflicting return type detected"
assert any(return_type_warning in str(warning.message) for warning in record)

# After filtering out nulls the managed function application should work
# similar to pandas.
Expand Down
54 changes: 39 additions & 15 deletions tests/system/large/functions/test_remote_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,22 +843,31 @@ def test_remote_function_with_external_package_dependencies(
):
try:

def pd_np_foo(x):
# The return type hint in this function's signature conflicts with the
# decorator argument. The `output_type` argument from the remote_function
# decorator takes precedence and will be used instead.
def pd_np_foo(x) -> None:
import numpy as mynp
import pandas as mypd

return mypd.Series([x, mynp.sqrt(mynp.abs(x))]).sum()

# Create the remote function with the name provided explicitly
pd_np_foo_remote = session.remote_function(
input_types=[int],
output_type=float,
dataset=dataset_id,
bigquery_connection=bq_cf_connection,
reuse=False,
packages=["numpy", "pandas >= 2.0.0"],
cloud_function_service_account="default",
)(pd_np_foo)
with warnings.catch_warnings(record=True) as record:
# Create the remote function with the name provided explicitly
pd_np_foo_remote = session.remote_function(
input_types=[int],
output_type=float,
dataset=dataset_id,
bigquery_connection=bq_cf_connection,
reuse=False,
packages=["numpy", "pandas >= 2.0.0"],
cloud_function_service_account="default",
)(pd_np_foo)

input_type_warning = "Conflicting input types detected"
assert not any(input_type_warning in str(warning.message) for warning in record)
return_type_warning = "Conflicting return type detected"
assert any(return_type_warning in str(warning.message) for warning in record)

# The behavior of the created remote function should be as expected
scalars_df, scalars_pandas_df = scalars_dfs
Expand Down Expand Up @@ -1999,10 +2008,25 @@ def test_remote_function_unnamed_removed_w_session_cleanup():
# create a clean session
session = bigframes.connect()

# create an unnamed remote function in the session
@session.remote_function(reuse=False, cloud_function_service_account="default")
def foo(x: int) -> int:
return x + 1
with warnings.catch_warnings(record=True) as record:
# create an unnamed remote function in the session.
# The type hints in this function's signature are redundant. The
# `input_types` and `output_type` arguments from remote_function
# decorator take precedence and will be used instead.
@session.remote_function(
input_types=[int],
output_type=int,
reuse=False,
cloud_function_service_account="default",
)
def foo(x: int) -> int:
return x + 1

# No following warning with only redundant type hints (no conflict).
input_type_warning = "Conflicting input types detected"
assert not any(input_type_warning in str(warning.message) for warning in record)
return_type_warning = "Conflicting return type detected"
assert not any(return_type_warning in str(warning.message) for warning in record)

# ensure that remote function artifacts are created
assert foo.bigframes_remote_function is not None
Expand Down