diff --git a/bigframes/exceptions.py b/bigframes/exceptions.py index 39a847de84..174f3e852a 100644 --- a/bigframes/exceptions.py +++ b/bigframes/exceptions.py @@ -103,6 +103,10 @@ class FunctionAxisOnePreviewWarning(PreviewWarning): """Remote Function and Managed UDF with axis=1 preview.""" +class FunctionConflictTypeHintWarning(UserWarning): + """Conflicting type hints in a BigFrames function.""" + + class FunctionPackageVersionWarning(PreviewWarning): """ Managed UDF package versions for Numpy, Pandas, and Pyarrow may not diff --git a/bigframes/functions/_function_session.py b/bigframes/functions/_function_session.py index 371784332c..64d2f0edfd 100644 --- a/bigframes/functions/_function_session.py +++ b/bigframes/functions/_function_session.py @@ -536,6 +536,11 @@ def wrapper(func): if input_types is not None: if not isinstance(input_types, collections.abc.Sequence): input_types = [input_types] + if _utils.has_conflict_input_type(py_sig, input_types): + msg = bfe.format_message( + "Conflicting input types detected, using the one from the decorator." + ) + warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning) py_sig = py_sig.replace( parameters=[ par.replace(annotation=itype) @@ -543,6 +548,11 @@ def wrapper(func): ] ) if output_type: + if _utils.has_conflict_output_type(py_sig, output_type): + msg = bfe.format_message( + "Conflicting return type detected, using the one from the decorator." + ) + warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning) py_sig = py_sig.replace(return_annotation=output_type) # Try to get input types via type annotations. @@ -838,6 +848,11 @@ def wrapper(func): if input_types is not None: if not isinstance(input_types, collections.abc.Sequence): input_types = [input_types] + if _utils.has_conflict_input_type(py_sig, input_types): + msg = bfe.format_message( + "Conflicting input types detected, using the one from the decorator." + ) + warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning) py_sig = py_sig.replace( parameters=[ par.replace(annotation=itype) @@ -845,6 +860,11 @@ def wrapper(func): ] ) if output_type: + if _utils.has_conflict_output_type(py_sig, output_type): + msg = bfe.format_message( + "Conflicting return type detected, using the one from the decorator." + ) + warnings.warn(msg, category=bfe.FunctionConflictTypeHintWarning) py_sig = py_sig.replace(return_annotation=output_type) # The function will actually be receiving a pandas Series, but allow diff --git a/bigframes/functions/_utils.py b/bigframes/functions/_utils.py index 0b7222db86..8d0ea57c92 100644 --- a/bigframes/functions/_utils.py +++ b/bigframes/functions/_utils.py @@ -14,10 +14,11 @@ import hashlib +import inspect import json import sys import typing -from typing import cast, Optional, Set +from typing import Any, cast, Optional, Sequence, Set import warnings import cloudpickle @@ -290,3 +291,36 @@ def post_process(input): return bbq.json_extract_string_array(input, value_dtype=result_dtype) return post_process + + +def has_conflict_input_type( + signature: inspect.Signature, + input_types: Sequence[Any], +) -> bool: + """Checks if the parameters have any conflict with the input_types.""" + params = list(signature.parameters.values()) + + if len(params) != len(input_types): + return True + + # Check for conflicts type hints. + for i, param in enumerate(params): + if param.annotation is not inspect.Parameter.empty: + if param.annotation != input_types[i]: + return True + + # No conflicts were found after checking all parameters. + return False + + +def has_conflict_output_type( + signature: inspect.Signature, + output_type: Any, +) -> bool: + """Checks if the return type annotation conflicts with the output_type.""" + return_annotation = signature.return_annotation + + if return_annotation is inspect.Parameter.empty: + return False + + return return_annotation != output_type diff --git a/tests/system/large/functions/test_managed_function.py b/tests/system/large/functions/test_managed_function.py index 5aa27e1775..5349529f1d 100644 --- a/tests/system/large/functions/test_managed_function.py +++ b/tests/system/large/functions/test_managed_function.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings + import google.api_core.exceptions import pandas import pyarrow @@ -31,12 +33,22 @@ def test_managed_function_array_output(session, scalars_dfs, dataset_id): try: - @session.udf( - dataset=dataset_id, - name=prefixer.create_prefix(), + with warnings.catch_warnings(record=True) as record: + + @session.udf( + dataset=dataset_id, + name=prefixer.create_prefix(), + ) + def featurize(x: int) -> list[float]: + return [float(i) for i in [x, x + 1, x + 2]] + + # No following conflict warning when there is no redundant type hints. + input_type_warning = "Conflicting input types detected" + return_type_warning = "Conflicting return type detected" + assert not any(input_type_warning in str(warning.message) for warning in record) + assert not any( + return_type_warning in str(warning.message) for warning in record ) - def featurize(x: int) -> list[float]: - return [float(i) for i in [x, x + 1, x + 2]] scalars_df, scalars_pandas_df = scalars_dfs @@ -222,7 +234,10 @@ def add(x: int, y: int) -> int: def test_managed_function_series_combine_array_output(session, dataset_id, scalars_dfs): try: - def add_list(x: int, y: int) -> list[int]: + # The type hints in this function's signature has conflicts. The + # `input_types` and `output_type` arguments from udf decorator take + # precedence and will be used instead. + def add_list(x, y: bool) -> list[bool]: return [x, y] scalars_df, scalars_pandas_df = scalars_dfs @@ -234,9 +249,18 @@ def add_list(x: int, y: int) -> list[int]: # Make sure there are NA values in the test column. assert any([pandas.isna(val) for val in bf_df[int_col_name_with_nulls]]) - add_list_managed_func = session.udf( - dataset=dataset_id, name=prefixer.create_prefix() - )(add_list) + with warnings.catch_warnings(record=True) as record: + add_list_managed_func = session.udf( + input_types=[int, int], + output_type=list[int], + dataset=dataset_id, + name=prefixer.create_prefix(), + )(add_list) + + input_type_warning = "Conflicting input types detected" + assert any(input_type_warning in str(warning.message) for warning in record) + return_type_warning = "Conflicting return type detected" + assert any(return_type_warning in str(warning.message) for warning in record) # After filtering out nulls the managed function application should work # similar to pandas. diff --git a/tests/system/large/functions/test_remote_function.py b/tests/system/large/functions/test_remote_function.py index f3e97aeb85..0d6029bd2c 100644 --- a/tests/system/large/functions/test_remote_function.py +++ b/tests/system/large/functions/test_remote_function.py @@ -843,22 +843,31 @@ def test_remote_function_with_external_package_dependencies( ): try: - def pd_np_foo(x): + # The return type hint in this function's signature has conflict. The + # `output_type` argument from remote_function decorator takes precedence + # and will be used instead. + def pd_np_foo(x) -> None: import numpy as mynp import pandas as mypd return mypd.Series([x, mynp.sqrt(mynp.abs(x))]).sum() - # Create the remote function with the name provided explicitly - pd_np_foo_remote = session.remote_function( - input_types=[int], - output_type=float, - dataset=dataset_id, - bigquery_connection=bq_cf_connection, - reuse=False, - packages=["numpy", "pandas >= 2.0.0"], - cloud_function_service_account="default", - )(pd_np_foo) + with warnings.catch_warnings(record=True) as record: + # Create the remote function with the name provided explicitly + pd_np_foo_remote = session.remote_function( + input_types=[int], + output_type=float, + dataset=dataset_id, + bigquery_connection=bq_cf_connection, + reuse=False, + packages=["numpy", "pandas >= 2.0.0"], + cloud_function_service_account="default", + )(pd_np_foo) + + input_type_warning = "Conflicting input types detected" + assert not any(input_type_warning in str(warning.message) for warning in record) + return_type_warning = "Conflicting return type detected" + assert any(return_type_warning in str(warning.message) for warning in record) # The behavior of the created remote function should be as expected scalars_df, scalars_pandas_df = scalars_dfs @@ -1999,10 +2008,25 @@ def test_remote_function_unnamed_removed_w_session_cleanup(): # create a clean session session = bigframes.connect() - # create an unnamed remote function in the session - @session.remote_function(reuse=False, cloud_function_service_account="default") - def foo(x: int) -> int: - return x + 1 + with warnings.catch_warnings(record=True) as record: + # create an unnamed remote function in the session. + # The type hints in this function's signature are redundant. The + # `input_types` and `output_type` arguments from remote_function + # decorator take precedence and will be used instead. + @session.remote_function( + input_types=[int], + output_type=int, + reuse=False, + cloud_function_service_account="default", + ) + def foo(x: int) -> int: + return x + 1 + + # No following warning with only redundant type hints (no conflict). + input_type_warning = "Conflicting input types detected" + assert not any(input_type_warning in str(warning.message) for warning in record) + return_type_warning = "Conflicting return type detected" + assert not any(return_type_warning in str(warning.message) for warning in record) # ensure that remote function artifacts are created assert foo.bigframes_remote_function is not None