diff --git a/tests/system/small/ml/test_llm.py b/tests/system/small/ml/test_llm.py
index 245fead028..999950007f 100644
--- a/tests/system/small/ml/test_llm.py
+++ b/tests/system/small/ml/test_llm.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable
+from contextlib import AbstractContextManager, nullcontext
+from typing import Any, Callable
 from unittest import mock
 
 import pandas as pd
@@ -25,6 +26,21 @@
 from bigframes.testing import utils
 
 
+@pytest.fixture(scope="function")
+def text_generator_model(request, bq_connection, session):
+    """Creates a text generator model, mocking creation for Claude models."""
+    model_class = request.param
+    if model_class == llm.Claude3TextGenerator:
+        # For Claude, mock the BQML model creation to avoid the network call
+        # that fails due to the region issue.
+        with mock.patch.object(llm.Claude3TextGenerator, "_create_bqml_model"):
+            model = model_class(connection_name=bq_connection, session=session)
+    else:
+        # For other models like Gemini, create as usual.
+        model = model_class(connection_name=bq_connection, session=session)
+    yield model
+
+
 @pytest.mark.parametrize(
     "model_name",
     ("text-embedding-005", "text-embedding-004", "text-multilingual-embedding-002"),
@@ -251,14 +267,10 @@ def __eq__(self, other):
         return self.equals(other)
 
 
-@pytest.mark.skip("b/436340035 test failed")
 @pytest.mark.parametrize(
-    (
-        "model_class",
-        "options",
-    ),
+    ("text_generator_model", "options"),
     [
-        (
+        pytest.param(
             llm.GeminiTextGenerator,
             {
                 "temperature": 0.9,
@@ -266,22 +278,24 @@ def __eq__(self, other):
                 "top_p": 1.0,
                 "ground_with_google_search": False,
             },
+            id="gemini",
         ),
-        (
+        pytest.param(
             llm.Claude3TextGenerator,
             {
                 "max_output_tokens": 128,
                 "top_k": 40,
                 "top_p": 0.95,
             },
+            id="claude",
         ),
     ],
+    indirect=["text_generator_model"],
 )
 def test_text_generator_retry_success(
     session,
-    model_class,
+    text_generator_model,
     options,
-    bq_connection,
 ):
     # Requests.
     df0 = EqCmpAllDataFrame(
@@ -298,21 +312,13 @@ def test_text_generator_retry_success(
     df1 = EqCmpAllDataFrame(
         {
             "ml_generate_text_status": ["error", "error"],
-            "prompt": [
-                "What is BQML?",
-                "What is BigQuery DataFrame?",
-            ],
+            "prompt": ["What is BQML?", "What is BigQuery DataFrame?"],
         },
         index=[1, 2],
         session=session,
     )
     df2 = EqCmpAllDataFrame(
-        {
-            "ml_generate_text_status": ["error"],
-            "prompt": [
-                "What is BQML?",
-            ],
-        },
+        {"ml_generate_text_status": ["error"], "prompt": ["What is BQML?"]},
         index=[1],
         session=session,
     )
@@ -342,31 +348,21 @@ def test_text_generator_retry_success(
         EqCmpAllDataFrame(
             {
                 "ml_generate_text_status": ["error", ""],
-                "prompt": [
-                    "What is BQML?",
-                    "What is BigQuery DataFrame?",
-                ],
+                "prompt": ["What is BQML?", "What is BigQuery DataFrame?"],
             },
             index=[1, 2],
             session=session,
         ),
         EqCmpAllDataFrame(
-            {
-                "ml_generate_text_status": [""],
-                "prompt": [
-                    "What is BQML?",
-                ],
-            },
+            {"ml_generate_text_status": [""], "prompt": ["What is BQML?"]},
             index=[1],
             session=session,
         ),
     ]
 
-    text_generator_model = model_class(connection_name=bq_connection, session=session)
     text_generator_model._bqml_model = mock_bqml_model
 
     with mock.patch.object(core.BqmlModel, "generate_text_tvf", generate_text_tvf):
-        # 3rd retry isn't triggered
         result = text_generator_model.predict(df0, max_retries=3)
 
         mock_generate_text.assert_has_calls(
@@ -391,17 +387,14 @@ def test_text_generator_retry_success(
             ),
             check_dtype=False,
             check_index_type=False,
+            check_like=True,
         )
 
 
-@pytest.mark.skip("b/436340035 test failed")
 @pytest.mark.parametrize(
-    (
-        "model_class",
-        "options",
-    ),
+    ("text_generator_model", "options"),
     [
-        (
+        pytest.param(
             llm.GeminiTextGenerator,
             {
                 "temperature": 0.9,
@@ -409,18 +402,21 @@ def test_text_generator_retry_success(
                 "top_p": 1.0,
                 "ground_with_google_search": False,
             },
+            id="gemini",
         ),
-        (
+        pytest.param(
             llm.Claude3TextGenerator,
             {
                 "max_output_tokens": 128,
                 "top_k": 40,
                 "top_p": 0.95,
             },
+            id="claude",
         ),
     ],
+    indirect=["text_generator_model"],
 )
-def test_text_generator_retry_no_progress(session, model_class, options, bq_connection):
+def test_text_generator_retry_no_progress(session, text_generator_model, options):
     # Requests.
     df0 = EqCmpAllDataFrame(
         {
@@ -480,7 +476,6 @@ def test_text_generator_retry_no_progress(session, model_class, options, bq_conn
         ),
     ]
 
-    text_generator_model = model_class(connection_name=bq_connection, session=session)
     text_generator_model._bqml_model = mock_bqml_model
 
     with mock.patch.object(core.BqmlModel, "generate_text_tvf", generate_text_tvf):
@@ -508,10 +503,10 @@ def test_text_generator_retry_no_progress(session, model_class, options, bq_conn
             ),
             check_dtype=False,
             check_index_type=False,
+            check_like=True,
         )
 
 
-@pytest.mark.skip("b/436340035 test failed")
 def test_text_embedding_generator_retry_success(session, bq_connection):
     # Requests.
     df0 = EqCmpAllDataFrame(
@@ -793,17 +788,28 @@ def test_gemini_preview_model_warnings(model_name):
         llm.GeminiTextGenerator(model_name=model_name)
 
 
-# b/436340035 temp disable the test to unblock presumbit
 @pytest.mark.parametrize(
     "model_class",
     [
         llm.TextEmbeddingGenerator,
         llm.MultimodalEmbeddingGenerator,
         llm.GeminiTextGenerator,
-        # llm.Claude3TextGenerator,
+        llm.Claude3TextGenerator,
     ],
 )
 def test_text_embedding_generator_no_default_model_warning(model_class):
     message = "Since upgrading the default model can cause unintended breakages, the\ndefault model will be removed in BigFrames 3.0. Please supply an\nexplicit model to avoid this message."
-    with pytest.warns(FutureWarning, match=message):
-        model_class(model_name=None)
+
+    # For Claude models, we must mock the model creation to avoid network errors.
+    # For all other models, we do nothing. contextlib.nullcontext() is a
+    # placeholder that allows the "with" statement to work for all cases.
+    patcher: AbstractContextManager[Any]
+    if model_class == llm.Claude3TextGenerator:
+        patcher = mock.patch.object(llm.Claude3TextGenerator, "_create_bqml_model")
+    else:
+        # We can now call nullcontext() directly
+        patcher = nullcontext()
+
+    with patcher:
+        with pytest.warns(FutureWarning, match=message):
+            model_class(model_name=None)
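
Note (not part of the patch): the retry tests above rely on pytest's indirect parametrization. With indirect=["text_generator_model"], pytest hands each pytest.param value to the fixture as request.param, and the test receives whatever the fixture yields rather than the raw class. A minimal, self-contained sketch of that mechanism, using hypothetical toy names:

import pytest


class FakeGemini:
    pass


class FakeClaude:
    pass


@pytest.fixture
def generator(request):
    # With indirect parametrization, the class supplied via pytest.param
    # arrives here as request.param; the test gets the built instance.
    return request.param()


@pytest.mark.parametrize(
    ("generator", "label"),
    [
        pytest.param(FakeGemini, "gemini", id="gemini"),
        pytest.param(FakeClaude, "claude", id="claude"),
    ],
    indirect=["generator"],  # only "generator" is routed through the fixture
)
def test_generator(generator, label):
    # "generator" is already an instance; "label" is passed through directly.
    assert isinstance(generator, (FakeGemini, FakeClaude))
    assert label in ("gemini", "claude")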
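Similarly, the no-default-model-warning test picks between mock.patch.object(...) and contextlib.nullcontext() so that a single "with" block covers both the Claude and non-Claude branches. A small sketch of the same pattern, with hypothetical names (not from this change):

from contextlib import AbstractContextManager, nullcontext
from unittest import mock


class Service:
    def connect(self):
        # Stands in for a call that would hit the network.
        raise RuntimeError("network unavailable in tests")


def make_service(skip_connect: bool) -> Service:
    # Choose a real patcher or a no-op placeholder so the "with" block
    # below is identical for both branches.
    patcher: AbstractContextManager = (
        mock.patch.object(Service, "connect") if skip_connect else nullcontext()
    )
    with patcher:
        service = Service()
        service.connect()
        return service


make_service(skip_connect=True)  # connect() is mocked out, so no error is raised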