Skip to content

fix: Aggregate state_delta from parallel AgentTool calls #1155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions contributing/samples/parallel_agent_tool_state_merge/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import agent
117 changes: 117 additions & 0 deletions contributing/samples/parallel_agent_tool_state_merge/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

from google.adk import Agent
from google.adk.tools.agent_tool import AgentTool
from google.genai import types


# --- Roll Die Sub-Agent ---
def roll_die(sides: int) -> int:
"""Roll a die and return the rolled result."""
return random.randint(1, sides)


roll_agent = Agent(
model="gemini-2.5-flash-preview-05-20",
name="roll_agent",
description="Handles rolling dice of different sizes.",
instruction="""
You are responsible for rolling dice based on the user's request.
When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
""",
tools=[roll_die],
generate_content_config=types.GenerateContentConfig(
safety_settings=[
types.SafetySetting( # avoid false alarm about rolling dice.
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=types.HarmBlockThreshold.OFF,
),
]
),
output_key="result_roll_agent",
)


# --- Prime Check Sub-Agent ---
def check_prime(nums: list[int]) -> str:
"""Check if a given list of numbers are prime."""
primes = set()
for number in nums:
number = int(number)
if number <= 1:
continue
is_prime = True
for i in range(2, int(number**0.5) + 1):
if number % i == 0:
is_prime = False
break
if is_prime:
primes.add(number)
return (
"No prime numbers found."
if not primes
else f"{', '.join(str(num) for num in primes)} are prime numbers."
)


prime_agent = Agent(
model="gemini-2.5-flash-preview-05-20",
name="prime_agent",
description="Handles checking if numbers are prime.",
instruction="""
You are responsible for checking whether numbers are prime.
When asked to check primes, you must call the check_prime tool with a list of integers.
Never attempt to determine prime numbers manually.
Return the prime number results to the root agent.
""",
tools=[check_prime],
generate_content_config=types.GenerateContentConfig(
safety_settings=[
types.SafetySetting( # avoid false alarm about rolling dice.
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=types.HarmBlockThreshold.OFF,
),
]
),
output_key="result_prime_agent",
)

root_agent = Agent(
model="gemini-2.5-flash-preview-05-20",
name="root_agent",
instruction="""
You are a helpful assistant that can roll dice and check if numbers are prime.
You delegate rolling dice tasks to the roll_agent and prime checking tasks to the prime_agent.
Follow these steps:
1. If the user asks to roll a die, delegate to the roll_agent.
2. If the user asks to check primes, delegate to the prime_agent.
3. If the user asks to roll a die and then check if the result is prime, call roll_agent first, then pass the result to prime_agent.
Always clarify the results before proceeding.
""",
global_instruction=(
"You are DicePrimeBot, ready to roll dice and check prime numbers."
),
tools=[AgentTool(agent=roll_agent), AgentTool(agent=prime_agent)],
generate_content_config=types.GenerateContentConfig(
safety_settings=[
types.SafetySetting( # avoid false alarm about rolling dice.
category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold=types.HarmBlockThreshold.OFF,
),
]
),
)
81 changes: 81 additions & 0 deletions contributing/samples/parallel_agent_tool_state_merge/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import time

import agent
from dotenv import load_dotenv
from google.adk.cli.utils import logs
from google.adk.runners import InMemoryRunner
from google.adk.sessions import Session
from google.genai import types

load_dotenv(override=True)
logs.log_to_tmp_folder()


async def main():
app_name = 'parallel_tool_test_app'
user_id_1 = 'user1'
runner = InMemoryRunner(
agent=agent.root_agent,
app_name=app_name,
)
session_11 = await runner.session_service.create_session(
app_name=app_name, user_id=user_id_1
)

async def run_prompt(session: Session, new_message: str):
content = types.Content(
role='user', parts=[types.Part.from_text(text=new_message)]
)
print('** User says:', content.model_dump(exclude_none=True))
async for event in runner.run_async(
user_id=user_id_1,
session_id=session.id,
new_message=content,
):
if event.content.parts and event.content.parts[0].text:
print(f'** {event.author}: {event.content.parts[0].text}')

start_time = time.time()
print('Start time:', start_time)
print('------------------------------------')
await run_prompt(session_11, 'Hi')
await run_prompt(
session_11,
'Roll a 6-sided die and check if 2 is a prime number at the same time.',
)
print('\n--- Session State After Parallel Tool Call ---')
updated_session_1 = await runner.session_service.get_session(
app_name=app_name, user_id=user_id_1, session_id=session_11.id
)
print('\n--- Session State After Parallel Tool Call ---')
if updated_session_1 and updated_session_1.state:
print(f'Session ID: {updated_session_1.id}')
print('Session State:')
for key, value in updated_session_1.state.items():
print(f' {key}: {value}')
else:
print('No session state found or session could not be retrieved.')

end_time = time.time()
print('------------------------------------')
print('End time:', end_time)
print('Total time:', end_time - start_time)


if __name__ == '__main__':
asyncio.run(main())
17 changes: 13 additions & 4 deletions src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,12 +502,21 @@ def merge_parallel_function_response_events(

merged_actions = EventActions()
merged_requested_auth_configs = {}
merged_state_delta = {}
for event in function_response_events:
merged_requested_auth_configs.update(event.actions.requested_auth_configs)
merged_actions = merged_actions.model_copy(
update=event.actions.model_dump()
)
if event.actions:
if event.actions.requested_auth_configs:
merged_requested_auth_configs.update(
event.actions.requested_auth_configs
)
if event.actions.state_delta:
merged_state_delta.update(event.actions.state_delta)

merged_actions = merged_actions.model_copy(
update=event.actions.model_dump()
)
merged_actions.requested_auth_configs = merged_requested_auth_configs
merged_actions.state_delta = merged_state_delta
# Create the new merged event
merged_event = Event(
invocation_id=Event.new_id(),
Expand Down