Skip to content

Commit cb82be7

Browse files
committed
chat-memory-jdbc : Fix message order when retrieving + also save order for batch inserts
1 parent 53a7af5 commit cb82be7

File tree

3 files changed

+49
-22
lines changed

3 files changed

+49
-22
lines changed

auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-jdbc/src/test/java/org/springframework/ai/model/chat/memory/jdbc/autoconfigure/JdbcChatMemoryAutoConfigurationIT.java

+8
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,14 @@ void addGetAndClear_shouldAllExecute() {
8282
assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(1);
8383
assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEqualTo(List.of(userMessage));
8484

85+
var assistantMessage = new AssistantMessage("Message from the assistant");
86+
87+
chatMemory.add(conversationId, assistantMessage);
88+
89+
assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).hasSize(2);
90+
assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE))
91+
.isEqualTo(List.of(userMessage, assistantMessage));
92+
8593
chatMemory.clear(conversationId);
8694

8795
assertThat(chatMemory.get(conversationId, Integer.MAX_VALUE)).isEmpty();

memory/spring-ai-model-chat-memory-jdbc/src/main/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemory.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
import java.sql.PreparedStatement;
2020
import java.sql.ResultSet;
2121
import java.sql.SQLException;
22+
import java.sql.Timestamp;
23+
import java.time.Instant;
24+
import java.util.Collections;
2225
import java.util.List;
2326

2427
import org.springframework.ai.chat.memory.ChatMemory;
@@ -42,7 +45,7 @@
4245
public class JdbcChatMemory implements ChatMemory {
4346

4447
private static final String QUERY_ADD = """
45-
INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)""";
48+
INSERT INTO ai_chat_memory (conversation_id, content, type, "timestamp") VALUES (?, ?, ?, ?)""";
4649

4750
private static final String QUERY_GET = """
4851
SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?""";
@@ -66,7 +69,9 @@ public void add(String conversationId, List<Message> messages) {
6669

6770
@Override
6871
public List<Message> get(String conversationId, int lastN) {
69-
return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN);
72+
List<Message> messages = this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN);
73+
Collections.reverse(messages);
74+
return messages;
7075
}
7176

7277
@Override
@@ -83,6 +88,7 @@ public void setValues(PreparedStatement ps, int i) throws SQLException {
8388
ps.setString(1, this.conversationId);
8489
ps.setString(2, message.getText());
8590
ps.setString(3, message.getMessageType().name());
91+
ps.setTimestamp(4, Timestamp.from(Instant.now()));
8692
}
8793

8894
@Override

memory/spring-ai-model-chat-memory-jdbc/src/test/java/org/springframework/ai/chat/memory/jdbc/JdbcChatMemoryIT.java

+33-20
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,12 @@
1616

1717
package org.springframework.ai.chat.memory.jdbc;
1818

19-
import java.sql.Timestamp;
20-
import java.util.List;
21-
import java.util.UUID;
22-
23-
import javax.sql.DataSource;
24-
2519
import org.junit.jupiter.api.BeforeAll;
2620
import org.junit.jupiter.api.Test;
2721
import org.junit.jupiter.params.ParameterizedTest;
2822
import org.junit.jupiter.params.provider.CsvSource;
29-
import org.testcontainers.containers.PostgreSQLContainer;
30-
import org.testcontainers.junit.jupiter.Container;
31-
import org.testcontainers.junit.jupiter.Testcontainers;
32-
import org.testcontainers.utility.MountableFile;
33-
3423
import org.springframework.ai.chat.memory.ChatMemory;
35-
import org.springframework.ai.chat.messages.AssistantMessage;
36-
import org.springframework.ai.chat.messages.Message;
37-
import org.springframework.ai.chat.messages.MessageType;
38-
import org.springframework.ai.chat.messages.SystemMessage;
39-
import org.springframework.ai.chat.messages.UserMessage;
24+
import org.springframework.ai.chat.messages.*;
4025
import org.springframework.boot.SpringBootConfiguration;
4126
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
4227
import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration;
@@ -46,6 +31,15 @@
4631
import org.springframework.context.annotation.Bean;
4732
import org.springframework.context.annotation.Primary;
4833
import org.springframework.jdbc.core.JdbcTemplate;
34+
import org.testcontainers.containers.PostgreSQLContainer;
35+
import org.testcontainers.junit.jupiter.Container;
36+
import org.testcontainers.junit.jupiter.Testcontainers;
37+
import org.testcontainers.utility.MountableFile;
38+
39+
import javax.sql.DataSource;
40+
import java.sql.Timestamp;
41+
import java.util.List;
42+
import java.util.UUID;
4943

5044
import static org.assertj.core.api.Assertions.assertThat;
5145

@@ -147,10 +141,11 @@ void get_shouldReturnMessages() {
147141
this.contextRunner.run(context -> {
148142
var chatMemory = context.getBean(ChatMemory.class);
149143
var conversationId = UUID.randomUUID().toString();
150-
var messages = List.<Message>of(new AssistantMessage("Message from assistant 1 - " + conversationId),
151-
new AssistantMessage("Message from assistant 2 - " + conversationId),
152-
new UserMessage("Message from user - " + conversationId),
153-
new SystemMessage("Message from system - " + conversationId));
144+
var messages = List.<Message>of(new SystemMessage("Message from system - " + conversationId),
145+
new UserMessage("Message from user 1 - " + conversationId),
146+
new AssistantMessage("Message from assistant 1 - " + conversationId),
147+
new UserMessage("Message from user 2 - " + conversationId),
148+
new AssistantMessage("Message from assistant 2 - " + conversationId));
154149

155150
chatMemory.add(conversationId, messages);
156151

@@ -161,6 +156,24 @@ void get_shouldReturnMessages() {
161156
});
162157
}
163158

159+
@Test
160+
void get_afterMultipleAdds_shouldReturnMessagesInSameOrder() {
161+
this.contextRunner.run(context -> {
162+
var chatMemory = context.getBean(ChatMemory.class);
163+
var conversationId = UUID.randomUUID().toString();
164+
var userMessage = new UserMessage("Message from user - " + conversationId);
165+
var assistantMessage = new AssistantMessage("Message from assistant - " + conversationId);
166+
167+
chatMemory.add(conversationId, userMessage);
168+
chatMemory.add(conversationId, assistantMessage);
169+
170+
var results = chatMemory.get(conversationId, Integer.MAX_VALUE);
171+
172+
assertThat(results.size()).isEqualTo(2);
173+
assertThat(results).isEqualTo(List.of(userMessage, assistantMessage));
174+
});
175+
}
176+
164177
@Test
165178
void clear_shouldDeleteMessages() {
166179
this.contextRunner.run(context -> {

0 commit comments

Comments
 (0)