Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
OPENAI_API_KEY=
OPENAI_ASSISTANT_ID=
LITERAL_API_KEY=
LITERAL_API_KEY=
96 changes: 93 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,99 @@ You can deploy your OpenAI assistant with Chainlit using this template.

### Supported Assistant Features

| Streaming | Files | Code Interpreter | File Search | Voice |
| --------- | ----- | ---------------- | ----------- | ----- |
| ✅ | ✅ | ✅ | ✅ | ✅ |

| Streaming | Files | Code Interpreter | File Search | Voice | Function Call |
| --------- | ----- | ---------------- | ----------- | ----- | -------------- |
| ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |


#### Function Call Feature


![openai-assistant-funciton-call](https://github.com/user-attachments/assets/920767aa-29af-493c-a88b-9dd0137c3108)

---

#### How to Implement Customized Functions

To implement custom functions with OpenAI's function call feature, follow these steps:

1. **Register Your Function**
- Register your function through the OpenAI API or in the Playground. Ensure that the function names and their parameters are properly defined.

```json
{
"name": "search_web",
"description": "search information online",
"strict": true,
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "search query for web search"
}
},
"additionalProperties": false,
"required": [
"query"
]
}
}
```

```json
{
"name": "add_two_numbers",
"description": "add two numbers together",
"strict": true,
"parameters": {
"type": "object",
"properties": {
"number1": {
"type": "integer",
"description": "the first number for the addinng operation"
},
"number2": {
"type": "integer",
"description": "the second number for the addinng operation"
}
},
"additionalProperties": false,
"required": [
"number1",
"number2"
]
}
}
```



2. **Edit the `function_map` Attribute in the `EventHandler` Class**
- In the `EventHandler` class, edit the `function_map` attribute like so:
```python
self.function_map = {
'search_web': self.search_web,
'add_two_numbers': self.add_two_numbers
}
```
- The **key names** (`'search_web'`, `'add_two_numbers'`) are the **function names** you registered with OpenAI, and they must match **exactly**.
- The values (e.g., `self.search_web`, `self.add_two_numbers`) are the actual methods you will implement inside the `EventHandler` class.

3. **Implement the Functions**
- Implement the functions that you registered and mapped in the `EventHandler` class. After the 'self' parameter, the parameter names and their count should exactly match the registered function in OpenAI’s system. For example:
```python
def search_web(self, query):
# Implementation of the search_web function
pass

def add_two_numbers(self, num1, num2):
# Implementation of the add_two_numbers function
return num1 + num2
```

---

### Get an OpenAI API key

Expand Down
156 changes: 143 additions & 13 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from io import BytesIO
from pathlib import Path
from typing import List
import json
from typing_extensions import override

from openai import AsyncAssistantEventHandler, AsyncOpenAI, OpenAI
from openai.types.beta.threads.runs import ToolCall, ToolCallDelta,RunStepDelta,RunStep

from literalai.helper import utc_now

import chainlit as cl
from chainlit.config import config
from chainlit.element import Element
Expand All @@ -21,6 +23,13 @@

config.ui.name = assistant.name



#how to add customized functions
# 1. edit class attribute function_map, key is the function name you regiesterd in openai's system, you can add function through playground or api
# value is the actual function you are going to implement
# 2. implement the function in the class,make sure the arguments name match with the parameters

class EventHandler(AsyncAssistantEventHandler):

def __init__(self, assistant_name: str) -> None:
Expand All @@ -29,6 +38,65 @@ def __init__(self, assistant_name: str) -> None:
self.current_step: cl.Step = None
self.current_tool_call = None
self.assistant_name = assistant_name
self.function_map={'search_web':self.search_web,
'add_two_numbers':self.add_two_numbers}



#implment search_web function.
def search_web(self,query):
print('search_web -----> query:',query)
return 'currently unimplemented, please implement this function '+query

#an example function that is already implemented
def add_two_numbers(self,number1,number2):
return str(number1+number2)

@override
async def on_event(self, event):
# Retrieve events that are denoted with 'requires_action',since these will have our tool_calls
if event.event == 'thread.run.requires_action':
run_id = event.data.id # Retrieve the run ID from the event data
self.current_run.id=run_id
await self.handle_requires_action(event.data, run_id)


async def handle_requires_action(self, data, run_id):
tool_outputs = []
for tool in data.required_action.submit_tool_outputs.tool_calls:
func_name = tool.function.name
func_args = tool.function.arguments

# Use the instance attribute function_map
func_to_call = self.function_map.get(func_name)

if func_to_call:
try:
# Parse the func_args JSON string to a dictionary
func_args_dict = json.loads(func_args)
tool_call_output = func_to_call(**func_args_dict)
tool_outputs.append({"tool_call_id": tool.id, "output": tool_call_output})
except TypeError as e:
print(f"Error calling function {func_name}: {e}")
else:
print(f"Function {func_name} not found")
# Submit all tool_outputs at the same time
await self.submit_tool_outputs(tool_outputs, run_id)


async def submit_tool_outputs(self, tool_outputs, run_id):
"""
Submits the tool outputs to the current run.
"""
async with async_openai_client.beta.threads.runs.submit_tool_outputs_stream(
thread_id=self.current_run.thread_id,
run_id=run_id,
tool_outputs=tool_outputs,
event_handler=EventHandler(assistant_name=self.assistant_name),
) as stream:
await stream.until_done()



async def on_text_created(self, text) -> None:
self.current_message = await cl.Message(author=self.assistant_name, content="").send()
Expand All @@ -39,20 +107,14 @@ async def on_text_delta(self, delta, snapshot):
async def on_text_done(self, text):
await self.current_message.update()

async def on_tool_call_created(self, tool_call):
self.current_tool_call = tool_call.id
self.current_step = cl.Step(name=tool_call.type, type="tool")
self.current_step.language = "python"
self.current_step.created_at = utc_now()
await self.current_step.send()

async def on_tool_call_delta(self, delta, snapshot):
if snapshot.id != self.current_tool_call:
self.current_tool_call = snapshot.id
self.current_step = cl.Step(name=delta.type, type="tool")
self.current_step.language = "python"
self.current_step.start = utc_now()
await self.current_step.send()
#await self.current_step.send()

if delta.type == "code_interpreter":
if delta.code_interpreter.outputs:
Expand All @@ -73,9 +135,72 @@ async def on_tool_call_delta(self, delta, snapshot):
await self.current_step.stream_token(delta.code_interpreter.input)


async def on_tool_call_done(self, tool_call):
self.current_step.end = utc_now()
await self.current_step.update()


async def on_run_step_done(self, run_step: RunStep):
if run_step.type == 'tool_calls':
tool_calls = run_step.step_details.tool_calls

# Handle tool call with type 'file_search and output quations to user'
if any(call.type == 'file_search' for call in tool_calls):
#retrieve quations from openai by adding parameter include
run_step = sync_openai_client.beta.threads.runs.steps.retrieve(
thread_id=cl.user_session.get("thread_id"),
run_id=run_step.run_id,
step_id=run_step.id,
include=["step_details.tool_calls[*].file_search.results[*].content"]
)
# Initialize an empty list to hold the citations
citations = []
# Extract tool_calls from run_step.step_details
tool_calls = run_step.step_details.tool_calls
# Iterate through each tool call
for call in tool_calls:
# Check if the type of the tool call is 'file_search'
if call.type == 'file_search':
# Extract the file search results from the file_search attribute
file_search_results = call.file_search.results
# Iterate through each result in file_search_results
for result in file_search_results:
# Extract the first content's text for the quote (if available)
quote = result.content[0].text if result.content else ""
# Create a citation dictionary
citation = {
"file_name": result.file_name,
"score": result.score,
"quote": quote
}
# Append the citation dictionary to the citations list
citations.append(citation)

# Create the final dictionary with the list of citations
citations_dict = {"citations": citations}
#Construct step
self.current_step = cl.Step(name=call.type, type="tool")
self.current_step.output=citations_dict
await self.current_step.send()

#Handle tool calls that are customized functions and render function input and output to user
else:
for tool_call in tool_calls:
if hasattr(tool_call, 'function'):
func_name = tool_call.function.name
func_args = tool_call.function.arguments
func_output=tool_call.function.output
self.current_step = cl.Step(name=func_name, type="tool")
self.current_step.input = json.loads(func_args)

# Handle func_output being either JSON string or plain text
try:
# Try to parse func_output as JSON
self.current_step.output = json.loads(func_output)
except json.JSONDecodeError:
# If it's not JSON, treat it as plain text
self.current_step.output = func_output
await self.current_step.send()




async def on_image_file_done(self, image_file):
image_id = image_file.file_id
Expand All @@ -92,6 +217,7 @@ async def on_image_file_done(self, image_file):
await self.current_message.update()



@cl.step(type="tool")
async def speech_to_text(audio_file):
response = await async_openai_client.audio.transcriptions.create(
Expand Down Expand Up @@ -132,8 +258,7 @@ async def start_chat():
thread = await async_openai_client.beta.threads.create()
# Store thread ID in user session for later use
cl.user_session.set("thread_id", thread.id)
await cl.Avatar(name=assistant.name, path="./public/logo.png").send()
await cl.Message(content=f"Hello, I'm {assistant.name}!", disable_feedback=True).send()



@cl.on_message
Expand Down Expand Up @@ -197,3 +322,8 @@ async def on_audio_end(elements: list[Element]):
msg = cl.Message(author="You", content=transcription, elements=elements)

await main(message=msg)


if __name__ == "__main__":
from chainlit.cli import run_chainlit
run_chainlit(__file__)