Skip to content

Commit f391ff2

Browse files
YunKuiLuilayaperumalg
authored andcommitted
feat: ChatClient#mutate adds copies of advisors and advisorParams.
Auto-cherry-pick to 1.0.x Fixes #3459 - The `mutate()` method adds copies of `advisors` and `advisorParams`. - Add unit test to validate `mutate()` behavior. Signed-off-by: YunKui Lu <[email protected]>
1 parent 5036ed3 commit f391ff2

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

spring-ai-client-chat/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import org.springframework.core.io.Resource;
6262
import org.springframework.lang.Nullable;
6363
import org.springframework.util.Assert;
64+
import org.springframework.util.CollectionUtils;
6465
import org.springframework.util.MimeType;
6566
import org.springframework.util.StringUtils;
6667

@@ -705,6 +706,7 @@ public TemplateRenderer getTemplateRenderer() {
705706
* Return a {@code ChatClient2Builder} to create a new {@code ChatClient2} whose
706707
* settings are replicated from this {@code ChatClientRequest}.
707708
*/
709+
@Override
708710
public Builder mutate() {
709711
DefaultChatClientBuilder builder = (DefaultChatClientBuilder) ChatClient
710712
.builder(this.chatModel, this.observationRegistry, this.observationConvention)
@@ -713,6 +715,10 @@ public Builder mutate() {
713715
.defaultToolContext(this.toolContext)
714716
.defaultToolNames(StringUtils.toStringArray(this.toolNames));
715717

718+
if (!CollectionUtils.isEmpty(this.advisors)) {
719+
builder.defaultAdvisors(a -> a.advisors(this.advisors).params(this.advisorParams));
720+
}
721+
716722
if (StringUtils.hasText(this.userText)) {
717723
builder.defaultUser(
718724
u -> u.text(this.userText).params(this.userParams).media(this.media.toArray(new Media[0])));

spring-ai-client-chat/src/test/java/org/springframework/ai/chat/client/DefaultChatClientTests.java

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.springframework.ai.content.Media;
4949
import org.springframework.ai.converter.ListOutputConverter;
5050
import org.springframework.ai.converter.StructuredOutputConverter;
51+
import org.springframework.ai.template.TemplateRenderer;
5152
import org.springframework.ai.tool.ToolCallback;
5253
import org.springframework.ai.tool.function.FunctionToolCallback;
5354
import org.springframework.core.ParameterizedTypeReference;
@@ -60,6 +61,7 @@
6061
import static org.assertj.core.api.Assertions.assertThatThrownBy;
6162
import static org.mockito.BDDMockito.given;
6263
import static org.mockito.Mockito.mock;
64+
import static org.mockito.Mockito.when;
6365

6466
/**
6567
* Unit tests for {@link DefaultChatClient}.
@@ -124,6 +126,51 @@ void whenPromptWithOptionsThenReturn() {
124126
assertThat(spec.getChatOptions()).isEqualTo(chatOptions);
125127
}
126128

129+
@Test
130+
void testMutate() {
131+
var media = mock(Media.class);
132+
var toolCallback = mock(ToolCallback.class);
133+
var advisor = mock(Advisor.class);
134+
var templateRenderer = mock(TemplateRenderer.class);
135+
var chatOptions = mock(ChatOptions.class);
136+
var copyChatOptions = mock(ChatOptions.class);
137+
when(chatOptions.copy()).thenReturn(copyChatOptions);
138+
var toolContext = new HashMap<String, Object>();
139+
var userMessage1 = mock(UserMessage.class);
140+
var userMessage2 = mock(UserMessage.class);
141+
142+
DefaultChatClientBuilder defaultChatClientBuilder = new DefaultChatClientBuilder(mock(ChatModel.class));
143+
defaultChatClientBuilder.addMessages(List.of(userMessage1, userMessage2));
144+
ChatClient originalChatClient = defaultChatClientBuilder.defaultAdvisors(advisor)
145+
.defaultOptions(chatOptions)
146+
.defaultUser(u -> u.text("original user {userParams}").param("userParams", "user value2").media(media))
147+
.defaultSystem(s -> s.text("original system {sysParams}").param("sysParams", "system value1"))
148+
.defaultTemplateRenderer(templateRenderer)
149+
.defaultToolNames("toolName1", "toolName2")
150+
.defaultToolCallbacks(toolCallback)
151+
.defaultToolContext(toolContext)
152+
.build();
153+
var originalSpec = (DefaultChatClient.DefaultChatClientRequestSpec) originalChatClient.prompt();
154+
155+
ChatClient mutateChatClient = originalChatClient.mutate().build();
156+
var mutateSpec = (DefaultChatClient.DefaultChatClientRequestSpec) mutateChatClient.prompt();
157+
158+
assertThat(mutateSpec).isNotSameAs(originalSpec);
159+
160+
assertThat(mutateSpec.getMessages()).hasSize(2).containsOnly(userMessage1, userMessage2);
161+
assertThat(mutateSpec.getAdvisors()).hasSize(1).containsOnly(advisor);
162+
assertThat(mutateSpec.getChatOptions()).isEqualTo(copyChatOptions);
163+
assertThat(mutateSpec.getUserText()).isEqualTo("original user {userParams}");
164+
assertThat(mutateSpec.getUserParams()).containsEntry("userParams", "user value2");
165+
assertThat(mutateSpec.getMedia()).hasSize(1).containsOnly(media);
166+
assertThat(mutateSpec.getSystemText()).isEqualTo("original system {sysParams}");
167+
assertThat(mutateSpec.getSystemParams()).containsEntry("sysParams", "system value1");
168+
assertThat(mutateSpec.getTemplateRenderer()).isEqualTo(templateRenderer);
169+
assertThat(mutateSpec.getToolNames()).containsExactly("toolName1", "toolName2");
170+
assertThat(mutateSpec.getToolCallbacks()).containsExactly(toolCallback);
171+
assertThat(mutateSpec.getToolContext()).isEqualTo(toolContext);
172+
}
173+
127174
@Test
128175
void whenMutateChatClientRequest() {
129176
ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();

0 commit comments

Comments
 (0)