Skip to content

Add support to custom search tool #119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ You can customize the research assistant workflow through several parameters:
- `writer_model`: Model for writing the report (default: "claude-3-7-sonnet-latest")
- `writer_model_kwargs`: Additional parameter for writer_model
- `search_api`: API to use for web searches (default: "tavily", options include "perplexity", "exa", "arxiv", "pubmed", "linkup")
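- `search_api_custom_function`: *(optional)*
Async function implementing custom search logic. Required when `"search_api": "customsearch"` is set (a `ValueError` is raised if it is missing or not callable); it is awaited with the list of search queries as its first argument, followed by the filtered search parameters as keyword arguments.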

## 2. Multi-Agent Implementation (`src/open_deep_research/multi_agent.py`)

Expand Down
4 changes: 3 additions & 1 deletion src/open_deep_research/configuration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from enum import Enum
from dataclasses import dataclass, fields, field
from typing import Any, Optional, Dict, Literal
from typing import Any, Optional, Dict, Literal, Awaitable, Callable

from langchain_core.runnables import RunnableConfig

Expand Down Expand Up @@ -50,6 +50,8 @@ class WorkflowConfiguration:
writer_provider: str = "anthropic"
writer_model: str = "claude-3-7-sonnet-latest"
writer_model_kwargs: Optional[Dict[str, Any]] = None
search_api_custom_function: Optional[Callable[..., Awaitable[Any]]] = None # Async function for custom search logic; use with `"search_api": "customsearch"`.


@classmethod
def from_runnable_config(
Expand Down
7 changes: 5 additions & 2 deletions src/open_deep_research/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ async def generate_report_plan(state: ReportState, config: RunnableConfig):
number_of_queries = configurable.number_of_queries
search_api = get_config_value(configurable.search_api)
search_api_config = configurable.search_api_config or {} # Get the config dict, default to empty
search_api_custom_function = configurable.search_api_custom_function
params_to_pass = get_search_params(search_api, search_api_config) # Filter parameters

# Convert JSON object to string if necessary
Expand Down Expand Up @@ -101,7 +102,8 @@ async def generate_report_plan(state: ReportState, config: RunnableConfig):
query_list = [query.search_query for query in results.queries]

# Search the web with parameters
source_str = await select_and_execute_search(search_api, query_list, params_to_pass)
source_str = await select_and_execute_search(search_api, query_list, params_to_pass, search_api_custom_function)


# Format system instructions
system_instructions_sections = report_planner_instructions.format(topic=topic, report_organization=report_structure, context=source_str, feedback=feedback)
Expand Down Expand Up @@ -255,13 +257,14 @@ async def search_web(state: SectionState, config: RunnableConfig):
configurable = WorkflowConfiguration.from_runnable_config(config)
search_api = get_config_value(configurable.search_api)
search_api_config = configurable.search_api_config or {} # Get the config dict, default to empty
search_api_custom_function = configurable.search_api_custom_function
params_to_pass = get_search_params(search_api, search_api_config) # Filter parameters

# Web search
query_list = [query.search_query for query in search_queries]

# Search the web with parameters
source_str = await select_and_execute_search(search_api, query_list, params_to_pass)
source_str = await select_and_execute_search(search_api, query_list, params_to_pass, search_api_custom_function)

return {"source_str": source_str, "search_iterations": state["search_iterations"] + 1}

Expand Down
8 changes: 6 additions & 2 deletions src/open_deep_research/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import aiohttp
import httpx
import time
from typing import List, Optional, Dict, Any, Union, Literal, Annotated, cast
from typing import List, Optional, Dict, Any, Union, Literal, Annotated, cast, Callable, Awaitable
from urllib.parse import unquote
from collections import defaultdict
import itertools
Expand Down Expand Up @@ -1498,7 +1498,7 @@ async def azureaisearch_search(queries: List[str], max_results: int = 5, topic:
return "No valid search results found. Please try different search queries or use a different search API."


async def select_and_execute_search(search_api: str, query_list: list[str], params_to_pass: dict) -> str:
async def select_and_execute_search(search_api: str, query_list: list[str], params_to_pass: dict, search_api_custom_function: Optional[Callable[..., Awaitable[Any]]]) -> str:
"""Select and execute the appropriate search API.

Args:
Expand Down Expand Up @@ -1533,6 +1533,10 @@ async def select_and_execute_search(search_api: str, query_list: list[str], para
search_results = await google_search_async(query_list, **params_to_pass)
elif search_api == "azureaisearch":
search_results = await azureaisearch_search_async(query_list, **params_to_pass)
elif search_api == "customsearch":
if search_api_custom_function is None or not callable(search_api_custom_function):
raise ValueError("For 'customsearch', you must provide a valid async function as 'search_api_custom_function'.")
search_results = await search_api_custom_function(query_list, **params_to_pass)
else:
raise ValueError(f"Unsupported search API: {search_api}")

Expand Down
4 changes: 3 additions & 1 deletion src/open_deep_research/workflow/configuration.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from open_deep_research.configuration import DEFAULT_REPORT_STRUCTURE, SearchAPI
from dataclasses import dataclass, fields
from typing import Optional, Dict, Any, Literal
from typing import Optional, Dict, Any, Literal, Awaitable, Callable
from langchain_core.runnables import RunnableConfig
import os

Expand Down Expand Up @@ -29,6 +29,8 @@ class WorkflowConfiguration:
writer_provider: str = "anthropic"
writer_model: str = "claude-3-7-sonnet-latest"
writer_model_kwargs: Optional[Dict[str, Any]] = None
search_api_custom_function: Optional[Callable[..., Awaitable[Any]]] = None # Async function for custom search logic; use with `"search_api": "customsearch"`.


@classmethod
def from_runnable_config(
Expand Down
6 changes: 4 additions & 2 deletions src/open_deep_research/workflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ async def generate_report_plan(state: ReportState, config: RunnableConfig) -> Co
number_of_queries = configurable.number_of_queries
search_api = get_config_value(configurable.search_api)
search_api_config = configurable.search_api_config or {} # Get the config dict, default to empty
search_api_custom_function = configurable.search_api_custom_function
params_to_pass = get_search_params(search_api, search_api_config) # Filter parameters
sections_user_approval = configurable.sections_user_approval

Expand All @@ -94,7 +95,7 @@ async def generate_report_plan(state: ReportState, config: RunnableConfig) -> Co
HumanMessage(content="Generate search queries that will help with planning the sections of the report.")])

query_list = [query.search_query for query in results.queries]
source_str = await select_and_execute_search(search_api, query_list, params_to_pass)
source_str = await select_and_execute_search(search_api, query_list, params_to_pass, search_api_custom_function)
system_instructions_sections = report_planner_instructions.format(messages=get_buffer_string(messages), report_organization=report_structure, context=source_str, feedback=feedback)

planner_provider = get_config_value(configurable.planner_provider)
Expand Down Expand Up @@ -182,10 +183,11 @@ async def search_web(state: SectionState, config: RunnableConfig):
configurable = WorkflowConfiguration.from_runnable_config(config)
search_api = get_config_value(configurable.search_api)
search_api_config = configurable.search_api_config or {}
search_api_custom_function = configurable.search_api_custom_function
params_to_pass = get_search_params(search_api, search_api_config)

query_list = [query.search_query for query in search_queries]
source_str = await select_and_execute_search(search_api, query_list, params_to_pass)
source_str = await select_and_execute_search(search_api, query_list, params_to_pass, search_api_custom_function)

return {"source_str": source_str, "search_iterations": state["search_iterations"] + 1}

Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ def pytest_addoption(parser):
"""Add command-line options to pytest."""
parser.addoption("--research-agent", action="store", help="Agent type: multi_agent or graph")
parser.addoption("--search-api", action="store", help="Search API to use")
parser.addoption("--search-api-custom-function", action="store", help="# Async function tool for custom search logic")
parser.addoption("--eval-model", action="store", help="Model for evaluation")
parser.addoption("--supervisor-model", action="store", help="Model for supervisor agent")
parser.addoption("--researcher-model", action="store", help="Model for researcher agent")
Expand Down
4 changes: 4 additions & 0 deletions tests/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def main():
# Search API configuration
parser.add_argument("--search-api", choices=["tavily", "duckduckgo"],
help="Search API to use for content retrieval")
parser.add_argument("--search-api-custom-function", help="Dotted path to a custom async search function")


args = parser.parse_args()

Expand Down Expand Up @@ -159,6 +161,8 @@ def add_model_configs(cmd, args):
cmd.append(f"--eval-model={args.eval_model}")
if args.search_api:
cmd.append(f"--search-api={args.search_api}")
if args.search_api_custom_function:
cmd.append(f"--search-api-custom-function={args.search_api_custom_function}")
if args.max_search_depth:
cmd.append(f"--max-search-depth={args.max_search_depth}")

Expand Down