Skip to content

Commit f5ac94c

Browse files
Add a missing getter to the CategoryScores class (#2939)
Signed-off-by: jonghoon park <[email protected]>
1 parent 19d7601 commit f5ac94c

File tree

2 files changed

+44
-10
lines changed

2 files changed

+44
-10
lines changed

models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiModerationModelIT.java

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@
1818

1919
import org.junit.jupiter.api.Test;
2020
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
21-
import org.slf4j.Logger;
22-
import org.slf4j.LoggerFactory;
2321

2422
import org.springframework.ai.mistralai.moderation.MistralAiModerationModel;
23+
import org.springframework.ai.moderation.CategoryScores;
2524
import org.springframework.ai.moderation.Moderation;
2625
import org.springframework.ai.moderation.ModerationPrompt;
2726
import org.springframework.ai.moderation.ModerationResult;
@@ -32,13 +31,12 @@
3231

3332
/**
3433
* @author Ricken Bazolo
34+
* @author Jonghoon Park
3535
*/
3636
@SpringBootTest(classes = MistralAiTestConfiguration.class)
3737
@EnabledIfEnvironmentVariable(named = "MISTRAL_AI_API_KEY", matches = ".+")
3838
public class MistralAiModerationModelIT {
3939

40-
private static final Logger logger = LoggerFactory.getLogger(MistralAiModerationModelIT.class);
41-
4240
@Autowired
4341
private MistralAiModerationModel mistralAiModerationModel;
4442

@@ -58,14 +56,23 @@ void moderationAsPositiveTest() {
5856
assertThat(moderation.getId()).isNotEmpty();
5957
assertThat(moderation.getResults()).isNotNull();
6058
assertThat(moderation.getResults().size()).isNotZero();
61-
logger.info(moderation.getResults().toString());
6259

6360
assertThat(moderation.getId()).isNotNull();
6461
assertThat(moderation.getModel()).isNotNull();
6562

6663
ModerationResult result = moderation.getResults().get(0);
6764
assertThat(result.isFlagged()).isTrue();
68-
assertThat(result.getCategories().isViolence()).isTrue();
65+
66+
CategoryScores scores = result.getCategoryScores();
67+
assertThat(scores.getSexual()).isNotNull();
68+
assertThat(scores.getHate()).isNotNull();
69+
assertThat(scores.getViolence()).isNotNull();
70+
assertThat(scores.getDangerousAndCriminalContent()).isNotNull();
71+
assertThat(scores.getSelfHarm()).isNotNull();
72+
assertThat(scores.getHealth()).isNotNull();
73+
assertThat(scores.getFinancial()).isNotNull();
74+
assertThat(scores.getLaw()).isNotNull();
75+
assertThat(scores.getPii()).isNotNull();
6976
}
7077

7178
}

spring-ai-model/src/main/java/org/springframework/ai/moderation/CategoryScores.java

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
2626
* @author Ahmed Yousri
2727
* @author Ilayaperumal Gopinathan
2828
* @author Ricken Bazolo
29+
* @author Jonghoon Park
2930
* @since 1.0.0
3031
*/
3132
public final class CategoryScores {
@@ -129,6 +130,26 @@ public double getViolence() {
129130
return this.violence;
130131
}
131132

133+
public double getDangerousAndCriminalContent() {
134+
return dangerousAndCriminalContent;
135+
}
136+
137+
public double getHealth() {
138+
return health;
139+
}
140+
141+
public double getFinancial() {
142+
return financial;
143+
}
144+
145+
public double getLaw() {
146+
return law;
147+
}
148+
149+
public double getPii() {
150+
return pii;
151+
}
152+
132153
@Override
133154
public boolean equals(Object o) {
134155
if (this == o) {
@@ -147,14 +168,18 @@ public boolean equals(Object o) {
147168
&& Double.compare(that.selfHarmIntent, this.selfHarmIntent) == 0
148169
&& Double.compare(that.selfHarmInstructions, this.selfHarmInstructions) == 0
149170
&& Double.compare(that.harassmentThreatening, this.harassmentThreatening) == 0
150-
&& Double.compare(that.violence, this.violence) == 0;
171+
&& Double.compare(that.violence, this.violence) == 0
172+
&& Double.compare(that.dangerousAndCriminalContent, this.dangerousAndCriminalContent) == 0
173+
&& Double.compare(that.health, this.health) == 0 && Double.compare(that.financial, this.financial) == 0
174+
&& Double.compare(that.law, this.law) == 0 && Double.compare(that.pii, this.pii) == 0;
151175
}
152176

153177
@Override
154178
public int hashCode() {
155179
return Objects.hash(this.sexual, this.hate, this.harassment, this.selfHarm, this.sexualMinors,
156180
this.hateThreatening, this.violenceGraphic, this.selfHarmIntent, this.selfHarmInstructions,
157-
this.harassmentThreatening, this.violence);
181+
this.harassmentThreatening, this.violence, this.dangerousAndCriminalContent, this.health,
182+
this.financial, this.law, this.pii);
158183
}
159184

160185
@Override
@@ -163,7 +188,9 @@ public String toString() {
163188
+ ", selfHarm=" + this.selfHarm + ", sexualMinors=" + this.sexualMinors + ", hateThreatening="
164189
+ this.hateThreatening + ", violenceGraphic=" + this.violenceGraphic + ", selfHarmIntent="
165190
+ this.selfHarmIntent + ", selfHarmInstructions=" + this.selfHarmInstructions
166-
+ ", harassmentThreatening=" + this.harassmentThreatening + ", violence=" + this.violence + '}';
191+
+ ", harassmentThreatening=" + this.harassmentThreatening + ", violence=" + this.violence
192+
+ ", dangerousAndCriminalContent=" + dangerousAndCriminalContent + ", health=" + health + ", financial="
193+
+ financial + ", law=" + law + ", pii=" + pii + '}';
167194
}
168195

169196
public static class Builder {

0 commit comments

Comments
 (0)