diff --git a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 28a51b641..caccb22a1 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -911,6 +911,7 @@ async def generate_async( await streaming_handler.push_chunk(END_OF_STREAM) # IF tracing is enabled we need to set GenerationLog attrs + original_log_options = None if self.config.tracing.enabled: if options is None: options = GenerationOptions() @@ -921,6 +922,7 @@ async def generate_async( else: # If options is a dict, convert it to GenerationOptions options = GenerationOptions(**options) + original_log_options = options.log.model_copy(deep=True) # enable log options # it is aggressive, but these are required for tracing @@ -1038,6 +1040,25 @@ async def generate_async( ) await tracer.export_async() + # respect original log specification, if tracing added information to the output + if original_log_options: + if not any( + ( + original_log_options.internal_events, + original_log_options.activated_rails, + original_log_options.llm_calls, + original_log_options.colang_history, + ) + ): + res.log = None + else: + if not original_log_options.internal_events: + res.log.internal_events = [] + if not original_log_options.activated_rails: + res.log.activated_rails = [] + if not original_log_options.llm_calls: + res.log.llm_calls = [] + return res else: # If a prompt is used, we only return the content of the message. diff --git a/tests/test_tracing.py b/tests/test_tracing.py index 56a7b1875..f0663803a 100644 --- a/tests/test_tracing.py +++ b/tests/test_tracing.py @@ -14,6 +14,7 @@ # limitations under the License. import asyncio +import itertools import unittest from unittest.mock import AsyncMock, MagicMock, patch @@ -312,10 +313,9 @@ async def test_tracing_does_not_mutate_user_options(): ), "User's original options were modified! This causes instability." # verify that tracing still works - assert response.log is not None, "Tracing should still work correctly" assert ( - response.log.activated_rails is not None - ), "Should have activated rails data" + response.log is None + ), "Tracing should still work correctly, without affecting returned log" @pytest.mark.asyncio @@ -358,9 +358,7 @@ async def test_tracing_with_none_options(): messages=[{"role": "user", "content": "hello"}], options=None ) - assert response.log is not None - assert response.log.activated_rails is not None - assert response.log.stats is not None + assert response.log is None @pytest.mark.asyncio @@ -368,7 +366,8 @@ async def test_tracing_aggressive_override_when_all_disabled(): """Test that tracing aggressively enables all logging when user disables all options. When user disables all three tracing related options, tracing still enables - ALL of them to ensure comprehensive logging data. + ALL of them to ensure comprehensive logging data. However, this should not contaminate the + returned response object """ config = RailsConfig.from_content( @@ -424,12 +423,9 @@ async def test_tracing_aggressive_override_when_all_disabled(): assert user_options.log.colang_history == original_colang_history assert response.log is not None - assert ( - response.log.activated_rails is not None - and len(response.log.activated_rails) > 0 - ) - assert response.log.llm_calls is not None - assert response.log.internal_events is not None + assert response.log.activated_rails == [] + assert response.log.llm_calls == [] + assert response.log.internal_events == [] assert user_options.log.activated_rails == original_activated_rails assert user_options.log.llm_calls == original_llm_calls @@ -439,6 +435,104 @@ async def test_tracing_aggressive_override_when_all_disabled(): assert user_options.log.internal_events == False +@pytest.mark.asyncio +@pytest.mark.parametrize( + "activated_rails,llm_calls,internal_events,colang_history", + list(itertools.product([False, True], repeat=4)), +) +async def test_tracing_preserves_specific_log_fields( + activated_rails, llm_calls, internal_events, colang_history +): + """Test that adding tracing respects the original user logging options in the response object""" + + config = RailsConfig.from_content( + colang_content=""" + define user express greeting + "hello" + + define flow + user express greeting + bot express greeting + + define bot express greeting + "Hello! How can I assist you today?" + """, + config={ + "models": [], + "tracing": {"enabled": True, "adapters": [{"name": "FileSystem"}]}, + }, + ) + + chat = TestChat( + config, + llm_completions=[ + "user express greeting", + "bot express greeting", + "Hello! How can I assist you today?", + ], + ) + + # user enables some subset of log options + user_options = GenerationOptions( + log=GenerationLogOptions( + activated_rails=activated_rails, + llm_calls=llm_calls, + internal_events=internal_events, + colang_history=colang_history, + ) + ) + + original_activated_rails = user_options.log.activated_rails + original_llm_calls = user_options.log.llm_calls + original_internal_events = user_options.log.internal_events + original_colang_history = user_options.log.colang_history + + with patch.object(Tracer, "export_async", return_value=None): + response = await chat.app.generate_async( + messages=[{"role": "user", "content": "hello"}], options=user_options + ) + + assert user_options.log.activated_rails == original_activated_rails + assert user_options.log.llm_calls == original_llm_calls + assert user_options.log.internal_events == original_internal_events + assert user_options.log.colang_history == original_colang_history + + # verify that only the requested log options are returned in the response + if not any( + ( + user_options.log.activated_rails, + user_options.log.llm_calls, + user_options.log.internal_events, + user_options.log.colang_history, + ) + ): + assert response.log is None + else: + assert response.log is not None + + if user_options.log.activated_rails: + assert len(response.log.activated_rails) > 0 + else: + assert len(response.log.activated_rails) == 0 + + if user_options.log.llm_calls: + assert len(response.log.llm_calls) > 0 + else: + assert len(response.log.llm_calls) == 0 + + if user_options.log.internal_events: + assert len(response.log.internal_events) > 0 + else: + assert len(response.log.internal_events) == 0 + + assert user_options.log.activated_rails == original_activated_rails + assert user_options.log.llm_calls == original_llm_calls + assert user_options.log.internal_events == original_internal_events + assert user_options.log.activated_rails == activated_rails + assert user_options.log.llm_calls == llm_calls + assert user_options.log.internal_events == internal_events + + @pytest.mark.asyncio async def test_tracing_aggressive_override_with_dict_options(): """Test that tracing works correctly when options are passed as a dict. @@ -502,11 +596,11 @@ async def test_tracing_aggressive_override_with_dict_options(): assert response.log is not None assert ( - response.log.activated_rails is not None - and len(response.log.activated_rails) > 0 + response.log.activated_rails == [] + and len(response.log.activated_rails) == 0 ) - assert response.log.llm_calls is not None - assert response.log.internal_events is not None + assert response.log.llm_calls == [] + assert response.log.internal_events == [] if __name__ == "__main__":