Skip to content

Feat/memory impl #157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions arkitect/core/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .base import Client, ClientPool, get_client_pool
from .http import default_ark_client, load_request
from .redis import RedisClient
from .sse import AsyncSSEDecoder

__all__ = [
Expand All @@ -23,4 +24,5 @@
"default_ark_client",
"load_request",
"get_client_pool",
"RedisClient",
]
121 changes: 121 additions & 0 deletions arkitect/core/client/redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import redis.asyncio as redis
from redis.asyncio.retry import Retry
from redis.backoff import ExponentialBackoff
from redis.exceptions import BusyLoadingError, ConnectionError, TimeoutError

from arkitect.core.client.base import Client


class RedisClient(Client):
"""
Initialize a new Redis client object.

Parameters:
host (str): The hostname of the Redis server.
username (str): The username for the Redis server.
password (str): The password for the Redis server.

Returns:
None.

"""

def __init__(self, host: str, username: str, password: str):
self.client = redis.Redis(
host=host,
username=username,
password=password,
retry=Retry(ExponentialBackoff(), 3),
retry_on_error=[BusyLoadingError, ConnectionError, TimeoutError],
)

async def get(self, key: str) -> str:
"""
Get the value of a key from the Redis database.

Args:
key (str): The key to retrieve from the Redis database.

Returns:
str: The value of the key, or None if the key does not exist.

"""
return await self.client.get(key)

async def set(self, key: str, value: str) -> None:
"""
Set the value of a key in the Redis database.
Args:
key (str): The key to set in the Redis database.
value (str): The value to set for the key.
Returns:
None.
"""
await self.client.set(key, value)

async def get_with_prefix(self, prefix: str) -> tuple[list[str], list[str]]:
"""
Asynchronous method to obtain all keys and values from the
Redis database that match the specified prefix

:param prefix: The specified prefix

:return: A list of tuples containing matching keys
and their corresponding values
"""

cursor = 0
keys = []

while True:
# 使用 SCAN 命令进行迭代查询
cursor, key_data = await self.client.scan(cursor, match=prefix, count=1000)

# 将匹配到的 key 添加到列表中
keys.extend(key_data)

# 如果游标值为 0,则表示遍历完成
if cursor == 0 or len(key_data) == 0:
break

# 使用 MGET 命令获取所有匹配到的 key 的对应 value
values = await self.client.mget(keys)

return keys, values

async def mget(self, keys: list[str]) -> list[str]:
"""
Get the values of multiple keys from the Redis database.

Args:
keys (list): A list of keys to retrieve from the Redis database.

Returns:
list: A list of values corresponding to the given keys.

"""
return await self.client.mget(keys)

async def delete(self, key: str) -> None:
"""
Delete a key from the Redis database.
Args:
key (str): The key to delete from the Redis database.
Returns:
None.
"""
await self.client.delete(key)
2 changes: 1 addition & 1 deletion arkitect/core/component/asr/asr_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def init(self) -> None:
"X-Api-Request-Id": self.log_id,
}

self.conn = await websockets.connect(self.base_url, extra_headers=headers)
self.conn = await websockets.connect(self.base_url, additional_headers=headers)
INFO(f"Connected to {self.base_url}, log_id: {self.log_id}")

# send init response
Expand Down
14 changes: 10 additions & 4 deletions arkitect/core/component/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,19 @@
)
from arkitect.core.component.tool.mcp_client import MCPClient
from arkitect.core.component.tool.tool_pool import ToolPool, build_tool_pool
from arkitect.telemetry.trace import task
from arkitect.core.component.tool.utils import (
convert_to_chat_completion_content_part_param,
)
from arkitect.telemetry.trace.wrapper import task
from arkitect.types.llm.model import (
ArkChatParameters,
ArkContextParameters,
)
from arkitect.types.responses.event import ToolChunk

from .chat_completion import _AsyncChat
from .context_completion import _AsyncContext
from .model import ContextInterruption, State, ToolChunk
from .model import ContextInterruption, State


class _AsyncCompletions:
Expand All @@ -63,7 +67,7 @@ def __init__(self, ctx: "Context"):
async def handle_tool_call(self) -> bool:
last_message = self._ctx.get_latest_message()
if last_message is None or not last_message.get("tool_calls"):
return True
return False
if self._ctx.tool_pool is None:
return False
for tool_call in last_message.get("tool_calls"):
Expand All @@ -86,6 +90,7 @@ async def handle_tool_call(self) -> bool:
tool_resp = await self._ctx.tool_pool.execute_tool(
tool_name=tool_name, parameters=json.loads(parameters)
)
tool_resp = convert_to_chat_completion_content_part_param(tool_resp)
except Exception as e:
tool_exception = e

Expand Down Expand Up @@ -170,7 +175,7 @@ async def create(
)

try:
if await self.handle_tool_call():
if not await self.handle_tool_call():
break
except HookInterruptException as he:
return ContextInterruption(
Expand Down Expand Up @@ -279,6 +284,7 @@ async def execute_tool(
tool_resp = await self._ctx.tool_pool.execute_tool( # type: ignore
tool_name=tool_name, parameters=json.loads(parameters)
)
tool_resp = convert_to_chat_completion_content_part_param(tool_resp)
except Exception as e:
tool_exception = e
return tool_resp, tool_exception
Expand Down
8 changes: 4 additions & 4 deletions arkitect/core/component/context/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,13 @@ async def pre_tool_call(
if len(state.messages) == 0:
return state
last_message = state.messages[-1]
if not last_message.get("tool_calls"):
if not last_message.tool_calls:
return state

formated_output = []
for tool_call in last_message.get("tool_calls"):
tool_name = tool_call.get("function", {}).get("name")
tool_call_param = tool_call.get("function", {}).get("arguments", "{}")
for tool_call in last_message.tool_calls:
tool_name = tool_call.function.name
tool_call_param = tool_call.function.arguments
formated_output.append(
f"tool_name: {tool_name}\ntool_call_param: {tool_call_param}\n"
)
Expand Down
24 changes: 7 additions & 17 deletions arkitect/core/component/context/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,20 @@
from typing import Any, List, Literal, Optional

from pydantic import BaseModel, Field
from volcenginesdkarkruntime.types.chat import ChatCompletionMessageParam

from arkitect.types.llm.model import ArkChatParameters, ArkContextParameters


class ToolChunk(BaseModel):
tool_call_id: str
tool_name: str
tool_arguments: str
tool_exception: Optional[Exception] = None
tool_response: Any | None = None

class Config:
"""Configuration for this pydantic object."""

arbitrary_types_allowed = True
from arkitect.types.llm.model import ArkChatParameters, ArkContextParameters, Message
from arkitect.types.responses.event import StateUpdateEvent


class State(BaseModel):
checkpoint_id: str = ""

context_id: Optional[str] = Field(default=None)
messages: List[ChatCompletionMessageParam] = Field(default_factory=list)
messages: List[Message] = Field(default_factory=list)
parameters: Optional[ArkChatParameters] = Field(default=None)
context_parameters: Optional[ArkContextParameters] = Field(default=None)
details: Optional[Any] = None
details: dict = {}
events: List[StateUpdateEvent] = Field(default_factory=list)


class ContextInterruption(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion arkitect/core/component/context/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from volcenginesdkarkruntime.types.chat import ChatCompletion, ChatCompletionChunk

from arkitect.core.component.context.model import ToolChunk
from arkitect.telemetry import logger
from arkitect.types.llm.model import (
ActionDetail,
Expand All @@ -25,6 +24,7 @@
BotUsage,
ToolDetail,
)
from arkitect.types.responses.event import ToolChunk


def convert_chunk(
Expand Down
4 changes: 4 additions & 0 deletions arkitect/core/component/llm/function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
)

from arkitect.core.component.tool.tool_pool import ToolPool
from arkitect.core.component.tool.utils import (
convert_to_chat_completion_content_part_param,
)
from arkitect.telemetry.logger import INFO, WARN
from arkitect.telemetry.trace import task
from arkitect.utils import dump_json_str
Expand Down Expand Up @@ -88,6 +91,7 @@ async def handle_function_call(
tool_name=tool_name,
parameters=parameters,
)
resp = convert_to_chat_completion_content_part_param(resp)
INFO(
f"Function {tool_name} called with parameters:"
+ dump_json_str(parameters)
Expand Down
26 changes: 26 additions & 0 deletions arkitect/core/component/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base_memory_service import BaseMemoryService
from .in_memory_memory_service import (
InMemoryMemoryService,
InMemoryMemoryServiceSingleton,
)


__all__ = [
"BaseMemoryService",
"InMemoryMemoryService",
"InMemoryMemoryServiceSingleton",
]
58 changes: 58 additions & 0 deletions arkitect/core/component/memory/base_memory_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any

from openai.types.responses import Response
from pydantic import BaseModel
from volcenginesdkarkruntime.types.chat.chat_completion_message import (
ChatCompletionMessage,
)

from arkitect.types.llm.model import Message


class Memory(BaseModel):
memory_content: str
reference: Any | None = None
metadata: Any | None = None


class SearchMemoryResponse(BaseModel):
memories: list[Memory]

@property
def content(self) -> str:
return "\n".join([m.memory_content for m in self.memories])


class BaseMemoryService(ABC):
@abstractmethod
async def update_memory(
self,
user_id: str,
new_messages: list[Message | dict | Response | ChatCompletionMessage],
**kwargs: Any,
) -> None:
pass

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add_or_update有点冗长,叫upsert_memory 呢?
然后接口感觉可以扩展成
@abstractmethod async def upsert_memory( self, user_id: str, data: list[Any], source: DataSource = DataSource.CHAT_MESSAGE, # 将来可以扩展更多的入库的数据,不单单局限在对话的message **kwargs: Any, ) -> None: ...
从而将来可以接入更多数据源,而不单单是对话消息

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

嗯嗯 我觉得要不先把new_messages 改为data: list[Any] 吧,其他source 之后需要加了再加上?暂时不放在base 的interface里
不一定强求所有实现都可以支持各种datasource,但来自message的应该是都会支持的

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好,我其实想留一个 datasource 是想让接口看起来不是专门针对对话消息的,将来可能可以继续扩展出来各种生产资料,比如doc啥的


@abstractmethod
async def search_memory(
self,
user_id: str,
query: str,
**kwargs: Any,
) -> SearchMemoryResponse:
pass
Loading