diff --git a/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java b/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java index 3d94aa6167..f66d2acf6a 100644 --- a/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java +++ b/src/main/java/org/springframework/data/jpa/repository/query/JpaQueryExecution.java @@ -15,10 +15,15 @@ */ package org.springframework.data.jpa.repository.query; +import static java.util.regex.Pattern.CASE_INSENSITIVE; +import static java.util.regex.Pattern.compile; + import java.lang.reflect.Method; import java.util.Collection; import java.util.List; import java.util.Optional; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import javax.persistence.EntityManager; import javax.persistence.NoResultException; @@ -56,6 +61,7 @@ * @author Christoph Strobl * @author Nicolas Cirigliano * @author Jens Schauder + * @author Chao Jiang */ public abstract class JpaQueryExecution { @@ -204,13 +210,33 @@ protected Object doExecute(final AbstractJpaQuery repositoryQuery, final Object[ } private long count(AbstractJpaQuery repositoryQuery, Object[] values) { - + StringBuilder builder = new StringBuilder(); + String queryString; + List totals = repositoryQuery.createCountQuery(values).getResultList(); - return (totals.size() == 1 ? CONVERSION_SERVICE.convert(totals.get(0), Long.class) : totals.size()); + + if(repositoryQuery instanceof AbstractStringBasedJpaQuery) { + queryString = ((AbstractStringBasedJpaQuery)repositoryQuery).getQuery().getQueryString(); + } else if(repositoryQuery instanceof NamedQuery) { + queryString = ((NamedQuery)repositoryQuery).getQuery().getQueryString(); + } else { + //OartTreeJpaQuery, StoredProcedureJpaQuery, etc. + queryString = ""; + } + + builder.append("(.*)?(group\\s+by\\s+).*"); + Pattern GROUP_MATCH = compile(builder.toString(), CASE_INSENSITIVE); + Matcher matcher = GROUP_MATCH.matcher(queryString); + + if(matcher.matches()) { + return totals.size(); + } else { + return (totals.size() == 1 ? CONVERSION_SERVICE.convert(totals.get(0), Long.class) : totals.size()); + } } } - /** + /** * Executes a {@link AbstractStringBasedJpaQuery} to return a single entity. */ static class SingleEntityExecution extends JpaQueryExecution { diff --git a/src/main/java/org/springframework/data/jpa/repository/query/NamedQuery.java b/src/main/java/org/springframework/data/jpa/repository/query/NamedQuery.java index 73528f24a8..7d9476b9a9 100644 --- a/src/main/java/org/springframework/data/jpa/repository/query/NamedQuery.java +++ b/src/main/java/org/springframework/data/jpa/repository/query/NamedQuery.java @@ -41,6 +41,7 @@ * @author Oliver Gierke * @author Thomas Darimont * @author Mark Paluch + * @author Chao Jiang */ final class NamedQuery extends AbstractJpaQuery { @@ -224,4 +225,11 @@ protected Optional> getTypeToRead(ReturnedType returnedType) { ? Optional.empty() // : super.getTypeToRead(returnedType); } + + /** + * @return the query + */ + public DeclaredQuery getQuery() { + return declaredQuery; + } } diff --git a/src/test/java/org/springframework/data/jpa/repository/query/JpaQueryExecutionUnitTests.java b/src/test/java/org/springframework/data/jpa/repository/query/JpaQueryExecutionUnitTests.java index 572f174e6c..384c0bcf9e 100644 --- a/src/test/java/org/springframework/data/jpa/repository/query/JpaQueryExecutionUnitTests.java +++ b/src/test/java/org/springframework/data/jpa/repository/query/JpaQueryExecutionUnitTests.java @@ -34,6 +34,7 @@ import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; +import org.springframework.data.domain.Page; import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; import org.springframework.data.jpa.repository.query.JpaQueryExecution.ModifyingExecution; @@ -49,6 +50,7 @@ * @author Mark Paluch * @author Nicolas Cirigliano * @author Jens Schauder + * @author Chao Jiang */ @RunWith(MockitoJUnitRunner.Silent.class) public class JpaQueryExecutionUnitTests { @@ -57,6 +59,7 @@ public class JpaQueryExecutionUnitTests { @Mock AbstractStringBasedJpaQuery jpaQuery; @Mock Query query; @Mock JpaQueryMethod method; + @Mock DeclaredQuery declaredQuery; @Mock TypedQuery countQuery; @@ -147,6 +150,9 @@ public void pagedExecutionRetrievesObjectsForPageableOutOfRange() throws Excepti when(jpaQuery.createCountQuery(Mockito.any(Object[].class))).thenReturn(countQuery); when(jpaQuery.createQuery(Mockito.any(Object[].class))).thenReturn(query); when(countQuery.getResultList()).thenReturn(Arrays.asList(20L)); + + when(jpaQuery.getQuery()).thenReturn(declaredQuery); + when(declaredQuery.getQueryString()).thenReturn("select count(1) from User u"); PagedExecution execution = new PagedExecution(parameters); execution.doExecute(jpaQuery, new Object[] { PageRequest.of(2, 10) }); @@ -168,6 +174,32 @@ public void pagedExecutionShouldNotGenerateCountQueryIfQueryReportedNoResults() verify(countQuery, times(0)).getResultList(); verify(jpaQuery, times(0)).createCountQuery((Object[]) any()); } + + @Test // DATAJPA-1544 + public void pagedExecutionShouldUseTotalSizeInCount() throws Exception { + + Parameters parameters = new DefaultParameters(getClass().getMethod("sampleMethod", Pageable.class)); + when(jpaQuery.createCountQuery(Mockito.any(Object[].class))).thenReturn(countQuery); + when(jpaQuery.createQuery(Mockito.any(Object[].class))).thenReturn(query); + when(countQuery.getResultList()).thenReturn(Arrays.asList(20L)); + when(query.getResultList()).thenReturn(Arrays.asList(20L)); + when(method.getCountQuery()).thenReturn("select count(1) from User u"); + + when(jpaQuery.getQuery()).thenReturn(declaredQuery); + when(declaredQuery.getQueryString()).thenReturn("select count(1) from User u"); + + PagedExecution execution = new PagedExecution(parameters); + Page page = (Page) execution.doExecute(jpaQuery, new Object[] { PageRequest.of(0, 1) }); + + assertEquals(page.getTotalElements(), 20); + + when(declaredQuery.getQueryString()).thenReturn("select count(1) from User u group by u.id"); + + page = (Page) execution.doExecute(jpaQuery, new Object[] { PageRequest.of(0, 1) }); + + assertEquals(page.getTotalElements(), 1); + + } @Test // DATAJPA-912 public void pagedExecutionShouldUseCountFromResultIfOffsetIsZeroAndResultsWithinPageSize() throws Exception { @@ -205,6 +237,9 @@ public void pagedExecutionShouldUseRequestCountFromResultWithOffsetAndResultsHit when(jpaQuery.createCountQuery(Mockito.any(Object[].class))).thenReturn(query); when(countQuery.getResultList()).thenReturn(Arrays.asList(20L)); + when(jpaQuery.getQuery()).thenReturn(declaredQuery); + when(declaredQuery.getQueryString()).thenReturn("select count(1) from User u"); + PagedExecution execution = new PagedExecution(parameters); execution.doExecute(jpaQuery, new Object[] { PageRequest.of(4, 4) }); @@ -221,6 +256,9 @@ public void pagedExecutionShouldUseRequestCountFromResultWithOffsetAndResultsHit when(jpaQuery.createCountQuery(Mockito.any(Object[].class))).thenReturn(query); when(countQuery.getResultList()).thenReturn(Arrays.asList(20L)); + when(jpaQuery.getQuery()).thenReturn(declaredQuery); + when(declaredQuery.getQueryString()).thenReturn("select count(1) from User u"); + PagedExecution execution = new PagedExecution(parameters); execution.doExecute(jpaQuery, new Object[] { PageRequest.of(4, 4) });