|
48 | 48 | import org.springframework.ai.content.Media;
|
49 | 49 | import org.springframework.ai.converter.ListOutputConverter;
|
50 | 50 | import org.springframework.ai.converter.StructuredOutputConverter;
|
| 51 | +import org.springframework.ai.template.TemplateRenderer; |
51 | 52 | import org.springframework.ai.tool.ToolCallback;
|
52 | 53 | import org.springframework.ai.tool.function.FunctionToolCallback;
|
53 | 54 | import org.springframework.core.ParameterizedTypeReference;
|
|
60 | 61 | import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
61 | 62 | import static org.mockito.BDDMockito.given;
|
62 | 63 | import static org.mockito.Mockito.mock;
|
| 64 | +import static org.mockito.Mockito.when; |
63 | 65 |
|
64 | 66 | /**
|
65 | 67 | * Unit tests for {@link DefaultChatClient}.
|
@@ -124,6 +126,51 @@ void whenPromptWithOptionsThenReturn() {
|
124 | 126 | assertThat(spec.getChatOptions()).isEqualTo(chatOptions);
|
125 | 127 | }
|
126 | 128 |
|
| 129 | + @Test |
| 130 | + void testMutate() { |
| 131 | + var media = mock(Media.class); |
| 132 | + var toolCallback = mock(ToolCallback.class); |
| 133 | + var advisor = mock(Advisor.class); |
| 134 | + var templateRenderer = mock(TemplateRenderer.class); |
| 135 | + var chatOptions = mock(ChatOptions.class); |
| 136 | + var copyChatOptions = mock(ChatOptions.class); |
| 137 | + when(chatOptions.copy()).thenReturn(copyChatOptions); |
| 138 | + var toolContext = new HashMap<String, Object>(); |
| 139 | + var userMessage1 = mock(UserMessage.class); |
| 140 | + var userMessage2 = mock(UserMessage.class); |
| 141 | + |
| 142 | + DefaultChatClientBuilder defaultChatClientBuilder = new DefaultChatClientBuilder(mock(ChatModel.class)); |
| 143 | + defaultChatClientBuilder.addMessages(List.of(userMessage1, userMessage2)); |
| 144 | + ChatClient originalChatClient = defaultChatClientBuilder.defaultAdvisors(advisor) |
| 145 | + .defaultOptions(chatOptions) |
| 146 | + .defaultUser(u -> u.text("original user {userParams}").param("userParams", "user value2").media(media)) |
| 147 | + .defaultSystem(s -> s.text("original system {sysParams}").param("sysParams", "system value1")) |
| 148 | + .defaultTemplateRenderer(templateRenderer) |
| 149 | + .defaultToolNames("toolName1", "toolName2") |
| 150 | + .defaultToolCallbacks(toolCallback) |
| 151 | + .defaultToolContext(toolContext) |
| 152 | + .build(); |
| 153 | + var originalSpec = (DefaultChatClient.DefaultChatClientRequestSpec) originalChatClient.prompt(); |
| 154 | + |
| 155 | + ChatClient mutateChatClient = originalChatClient.mutate().build(); |
| 156 | + var mutateSpec = (DefaultChatClient.DefaultChatClientRequestSpec) mutateChatClient.prompt(); |
| 157 | + |
| 158 | + assertThat(mutateSpec).isNotSameAs(originalSpec); |
| 159 | + |
| 160 | + assertThat(mutateSpec.getMessages()).hasSize(2).containsOnly(userMessage1, userMessage2); |
| 161 | + assertThat(mutateSpec.getAdvisors()).hasSize(1).containsOnly(advisor); |
| 162 | + assertThat(mutateSpec.getChatOptions()).isEqualTo(copyChatOptions); |
| 163 | + assertThat(mutateSpec.getUserText()).isEqualTo("original user {userParams}"); |
| 164 | + assertThat(mutateSpec.getUserParams()).containsEntry("userParams", "user value2"); |
| 165 | + assertThat(mutateSpec.getMedia()).hasSize(1).containsOnly(media); |
| 166 | + assertThat(mutateSpec.getSystemText()).isEqualTo("original system {sysParams}"); |
| 167 | + assertThat(mutateSpec.getSystemParams()).containsEntry("sysParams", "system value1"); |
| 168 | + assertThat(mutateSpec.getTemplateRenderer()).isEqualTo(templateRenderer); |
| 169 | + assertThat(mutateSpec.getToolNames()).containsExactly("toolName1", "toolName2"); |
| 170 | + assertThat(mutateSpec.getToolCallbacks()).containsExactly(toolCallback); |
| 171 | + assertThat(mutateSpec.getToolContext()).isEqualTo(toolContext); |
| 172 | + } |
| 173 | + |
127 | 174 | @Test
|
128 | 175 | void whenMutateChatClientRequest() {
|
129 | 176 | ChatClient chatClient = new DefaultChatClientBuilder(mock(ChatModel.class)).build();
|
|
0 commit comments