Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
* irrelevant information that may affect the quality of the search results.
*
* @author Thomas Vitale
* @author Sun Yuhan
* @since 1.0.0
* @see <a href="https://arxiv.org/pdf/2305.14283">arXiv:2305.14283</a>
*/
Expand All @@ -60,15 +61,28 @@ public class RewriteQueryTransformer implements QueryTransformer {

private final String targetSearchSystem;

private final ValidationMode validationMode;

public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
@Nullable String targetSearchSystem) {
Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");

this.chatClient = chatClientBuilder.build();
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
this.targetSearchSystem = targetSearchSystem != null ? targetSearchSystem : DEFAULT_TARGET;
this.validationMode = ValidationMode.THROW;
validate();
}

public RewriteQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
@Nullable String targetSearchSystem, @Nullable ValidationMode validationMode) {
Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");

PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query");
this.chatClient = chatClientBuilder.build();
this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
this.targetSearchSystem = targetSearchSystem != null ? targetSearchSystem : DEFAULT_TARGET;
this.validationMode = validationMode != null ? validationMode : ValidationMode.THROW;
validate();
}

@Override
Expand All @@ -92,6 +106,23 @@ public Query transform(Query query) {
return query.mutate().text(rewrittenQueryText).build();
}

/**
* Verify whether the template contains the required variables.
*/
private void validate() {
switch (this.validationMode) {
case THROW -> PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query");
case WARN -> {
try {
PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "target", "query");
}
catch (IllegalArgumentException e) {
logger.warn(e.getMessage());
}
}
}
}

public static Builder builder() {
return new Builder();
}
Expand All @@ -106,6 +137,9 @@ public static final class Builder {
@Nullable
private String targetSearchSystem;

@Nullable
private ValidationMode validationMode;

private Builder() {
}

Expand All @@ -124,8 +158,14 @@ public Builder targetSearchSystem(String targetSearchSystem) {
return this;
}

public Builder validationMode(ValidationMode validationMode) {
this.validationMode = validationMode;
return this;
}

public RewriteQueryTransformer build() {
return new RewriteQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetSearchSystem);
return new RewriteQueryTransformer(this.chatClientBuilder, this.promptTemplate, this.targetSearchSystem,
this.validationMode);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.rag.preretrieval.query.transformation;

/**
* Validation modes for {@link RewriteQueryTransformer}.
*
* @author Sun Yuhan
*/
public enum ValidationMode {

/**
* If the validation fails, an exception is thrown. This is the default mode.
*/
THROW,

/**
* If the validation fails, a warning is logged.
*/
WARN,

/**
* No validation is performed.
*/
NONE

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,16 @@

package org.springframework.ai.rag.preretrieval.query.transformation;

import ch.qos.logback.classic.Logger;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.read.ListAppender;
import org.junit.jupiter.api.Test;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;

Expand Down Expand Up @@ -71,4 +76,90 @@ void whenPromptHasMissingQueryPlaceholderThenThrow() {
.hasMessageContaining("query");
}

@Test
void shouldLoggingWithMissingTargetPlaceholderInWarnMode() {
PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {query}");
Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class);

ListAppender<ILoggingEvent> listAppender = new ListAppender<>();
listAppender.start();
logger.addAppender(listAppender);

RewriteQueryTransformer.builder()
.chatClientBuilder(mock(ChatClient.Builder.class))
.targetSearchSystem("vector store")
.validationMode(ValidationMode.WARN)
.promptTemplate(customPromptTemplate)
.build();
var logsList = listAppender.list;

assertThat(logsList).isNotEmpty();
assertThat(logsList.get(0).getLevel()).isEqualTo(ch.qos.logback.classic.Level.WARN);
assertThat(logsList.get(0).getMessage())
.contains("The following placeholders must be present in the prompt template: target");
}

@Test
void shouldLoggingWithMissingQueryPlaceholderInWarnMode() {
PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite for {target}");
Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class);

ListAppender<ILoggingEvent> listAppender = new ListAppender<>();
listAppender.start();
logger.addAppender(listAppender);

RewriteQueryTransformer.builder()
.chatClientBuilder(mock(ChatClient.Builder.class))
.targetSearchSystem("search engine")
.validationMode(ValidationMode.WARN)
.promptTemplate(customPromptTemplate)
.build();
var logsList = listAppender.list;

assertThat(logsList).isNotEmpty();
assertThat(logsList.get(0).getLevel()).isEqualTo(ch.qos.logback.classic.Level.WARN);
assertThat(logsList.get(0).getMessage())
.contains("The following placeholders must be present in the prompt template: query");
}

@Test
void shouldContinueWithMissingTargetPlaceholderInNoneMode() {
PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {target}");
Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class);

ListAppender<ILoggingEvent> listAppender = new ListAppender<>();
listAppender.start();
logger.addAppender(listAppender);

RewriteQueryTransformer.builder()
.chatClientBuilder(mock(ChatClient.Builder.class))
.targetSearchSystem("vector store")
.validationMode(ValidationMode.NONE)
.promptTemplate(customPromptTemplate)
.build();
var logsList = listAppender.list;

assertThat(logsList).isEmpty();
}

@Test
void shouldContinueWithMissingQueryPlaceholderInNoneMode() {
PromptTemplate customPromptTemplate = new PromptTemplate("Rewrite {query}");
Logger logger = (Logger) LoggerFactory.getLogger(RewriteQueryTransformer.class);

ListAppender<ILoggingEvent> listAppender = new ListAppender<>();
listAppender.start();
logger.addAppender(listAppender);

RewriteQueryTransformer.builder()
.chatClientBuilder(mock(ChatClient.Builder.class))
.targetSearchSystem("search engine")
.validationMode(ValidationMode.NONE)
.promptTemplate(customPromptTemplate)
.build();
var logsList = listAppender.list;

assertThat(logsList).isEmpty();
}

}