diff --git a/src/main/java/org/springframework/data/reindexer/repository/query/StringBasedReindexerRepositoryQuery.java b/src/main/java/org/springframework/data/reindexer/repository/query/StringBasedReindexerRepositoryQuery.java index 5254feb..45cc44f 100644 --- a/src/main/java/org/springframework/data/reindexer/repository/query/StringBasedReindexerRepositoryQuery.java +++ b/src/main/java/org/springframework/data/reindexer/repository/query/StringBasedReindexerRepositoryQuery.java @@ -18,6 +18,8 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.Map; import java.util.Optional; import java.util.Spliterator; import java.util.Spliterators; @@ -29,8 +31,10 @@ import ru.rt.restream.reindexer.Reindexer; import ru.rt.restream.reindexer.ResultIterator; +import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.repository.query.RepositoryQuery; +import org.springframework.util.Assert; /** * A string-based {@link RepositoryQuery} implementation for Reindexer. @@ -43,6 +47,8 @@ public class StringBasedReindexerRepositoryQuery implements RepositoryQuery { private final Namespace namespace; + private final Map namedParameters; + /** * Creates an instance. * @@ -54,6 +60,12 @@ public StringBasedReindexerRepositoryQuery(ReindexerQueryMethod queryMethod, Rei this.queryMethod = queryMethod; this.namespace = reindexer.openNamespace(entityInformation.getNamespaceName(), entityInformation.getNamespaceOptions(), entityInformation.getJavaType()); + this.namedParameters = new LinkedHashMap<>(); + for (Parameter parameter : queryMethod.getParameters()) { + if (parameter.isNamedParameter()) { + parameter.getName().ifPresent(name -> this.namedParameters.put(name, parameter.getIndex())); + } + } } @Override @@ -84,7 +96,60 @@ public Object execute(Object[] parameters) { } private String prepareQuery(Object[] parameters) { - return String.format(this.queryMethod.getQuery(), parameters); + String query = this.queryMethod.getQuery(); + StringBuilder sb = new StringBuilder(query); + char[] queryParts = query.toCharArray(); + int offset = 0; + for (int i = 1; i < queryParts.length; i++) { + char c = queryParts[i - 1]; + switch (c) { + case '?': { + int j = i; + int index = 0; + int digits = 0; + while (j < queryParts.length) { + if (!Character.isDigit(queryParts[j])) { + break; + } + index *= 10; + index += Character.getNumericValue(queryParts[j++]); + digits++; + } + String value = getParameterValuePart(parameters[index - 1]); + sb.replace(offset + i - 1, offset + i + digits, value); + offset += value.length() - digits - 1; + break; + } + case ':': { + String parameterName = getParameterName(queryParts, i); + Integer index = this.namedParameters.get(parameterName); + Assert.notNull(index, () -> "No parameter found for name: " + parameterName); + String value = getParameterValuePart(parameters[index]); + sb.replace(offset + i - 1, offset + i + parameterName.length(), value); + offset += value.length() - parameterName.length() - 1; + break; + } + } + } + return sb.toString(); + } + + private String getParameterName(char[] queryParts, int i) { + StringBuilder sb = new StringBuilder(); + while (i < queryParts.length) { + if (Character.isWhitespace(queryParts[i])) { + break; + } + sb.append(queryParts[i++]); + } + return sb.toString(); + } + + private String getParameterValuePart(Object value) { + if (value instanceof String) { + return "'" + value + "'"; + } + return String.valueOf(value); } private Stream toStream(ResultIterator iterator) { diff --git a/src/test/java/org/springframework/data/reindexer/repository/ReindexerRepositoryTests.java b/src/test/java/org/springframework/data/reindexer/repository/ReindexerRepositoryTests.java index f00f036..6c54409 100644 --- a/src/test/java/org/springframework/data/reindexer/repository/ReindexerRepositoryTests.java +++ b/src/test/java/org/springframework/data/reindexer/repository/ReindexerRepositoryTests.java @@ -58,6 +58,7 @@ import org.springframework.data.reindexer.core.mapping.Namespace; import org.springframework.data.reindexer.core.mapping.Query; import org.springframework.data.reindexer.repository.config.EnableReindexerRepositories; +import org.springframework.data.repository.query.Param; import org.springframework.stereotype.Repository; import org.springframework.stereotype.Service; import org.springframework.test.context.ContextConfiguration; @@ -182,6 +183,19 @@ public void findIteratorSqlByName() { } } + @Test + public void findIteratorSqlByNameParam() { + TestItem testItem = this.repository.save(new TestItem(1L, "TestName", null)); + try (ResultIterator it = this.repository.findIteratorSqlByNameParam("TestName")) { + assertTrue(it.hasNext()); + TestItem item = it.next(); + assertEquals(testItem.getId(), item.getId()); + assertEquals(testItem.getName(), item.getName()); + assertEquals(testItem.getValue(), item.getValue()); + assertFalse(it.hasNext()); + } + } + @Test public void getByName() { TestItem testItem = this.repository.save(new TestItem(1L, "TestName", null)); @@ -200,6 +214,15 @@ public void getOneSqlByName() { assertEquals(testItem.getValue(), item.getValue()); } + @Test + public void getOneSqlByNameParam() { + TestItem testItem = this.repository.save(new TestItem(1L, "TestName", null)); + TestItem item = this.repository.getOneSqlByNameParam("TestName"); + assertEquals(testItem.getId(), item.getId()); + assertEquals(testItem.getName(), item.getName()); + assertEquals(testItem.getValue(), item.getValue()); + } + @Test public void findOneSqlByName() { TestItem testItem = this.repository.save(new TestItem(1L, "TestName", null)); @@ -210,6 +233,67 @@ public void findOneSqlByName() { assertEquals(testItem.getValue(), item.getValue()); } + @Test + public void findOneSqlByNameParam() { + TestItem testItem = this.repository.save(new TestItem(1L, "TestName", null)); + TestItem item = this.repository.findOneSqlByNameParam("TestName").orElse(null); + assertNotNull(item); + assertEquals(testItem.getId(), item.getId()); + assertEquals(testItem.getName(), item.getName()); + assertEquals(testItem.getValue(), item.getValue()); + } + + @Test + public void findOneSqlByNameManyParameters() { + TestItem testItem = this.repository.save(new TestItem(1L, "TestName", "TestValue")); + TestItem item = this.repository.findOneSqlByNameAndValueManyParams(null, null, + null, null, null, null, null, null, null, null, "TestName", "TestValue").orElse(null); + assertNotNull(item); + assertEquals(testItem.getId(), item.getId()); + assertEquals(testItem.getName(), item.getName()); + assertEquals(testItem.getValue(), item.getValue()); + } + + @Test + public void findOneSqlByNameOrValue() { + TestItem testItem = this.repository.save(new TestItem(1L, "TestName", "TestValue")); + TestItem item = this.repository.findOneSqlByIdAndNameAndValue(1L, "TestName", "TestValue").orElse(null); + assertNotNull(item); + assertEquals(testItem.getId(), item.getId()); + assertEquals(testItem.getName(), item.getName()); + assertEquals(testItem.getValue(), item.getValue()); + } + + @Test + public void findOneSqlByNameOrValueParam() { + TestItem testItem = this.repository.save(new TestItem(1L, "TestName", "TestValue")); + TestItem item = this.repository.findOneSqlByIdAndNameAndValueParam(1L, "TestName", "TestValue").orElse(null); + assertNotNull(item); + assertEquals(testItem.getId(), item.getId()); + assertEquals(testItem.getName(), item.getName()); + assertEquals(testItem.getValue(), item.getValue()); + } + + @Test + public void findOneSqlByIdAndNameAndValueAnyParameterOrder() { + TestItem testItem = this.repository.save(new TestItem(1L, "TestName", "TestValue")); + TestItem item = this.repository.findOneSqlByIdAndNameAndValue("TestValue", 1L, "TestName").orElse(null); + assertNotNull(item); + assertEquals(testItem.getId(), item.getId()); + assertEquals(testItem.getName(), item.getName()); + assertEquals(testItem.getValue(), item.getValue()); + } + + @Test + public void findOneSqlByIdAndNameAndValueParamAnyParameterOrder() { + TestItem testItem = this.repository.save(new TestItem(1L, "TestName", "TestValue")); + TestItem item = this.repository.findOneSqlByIdAndNameAndValueParam("TestValue", 1L, "TestName").orElse(null); + assertNotNull(item); + assertEquals(testItem.getId(), item.getId()); + assertEquals(testItem.getName(), item.getName()); + assertEquals(testItem.getValue(), item.getValue()); + } + @Test public void getByNameWhenNotExistsThenException() { assertThrows(IllegalStateException.class, () -> this.repository.getByName("notExists"), @@ -228,6 +312,18 @@ public void updateNameSql() { assertEquals(testItem.getValue(), item.getValue()); } + @Test + public void updateNameSqlParam() { + TestItem testItem = this.repository.save(new TestItem(1L, "TestName", "TestValue")); + assertNotNull(testItem); + this.repository.updateNameSqlParam("TestNameUpdated", 1L); + TestItem item = this.repository.findById(1L).orElse(null); + assertNotNull(item); + assertEquals(testItem.getId(), item.getId()); + assertEquals("TestNameUpdated", item.getName()); + assertEquals(testItem.getValue(), item.getValue()); + } + @Test public void save() { TestItem testItem = this.repository.save(new TestItem(1L, "TestName", "TestValue")); @@ -495,20 +591,48 @@ interface TestItemReindexerRepository extends ReindexerRepository findIteratorByName(String name); - @Query("SELECT * FROM items WHERE name = '%s'") + @Query("SELECT * FROM items WHERE name = ?1") ResultIterator findIteratorSqlByName(String name); - @Query(value = "UPDATE items SET name = '%s' WHERE id = %d", update = true) + @Query("SELECT * FROM items WHERE name = :name") + ResultIterator findIteratorSqlByNameParam(@Param("name") String name); + + @Query(value = "UPDATE items SET name = ?1 WHERE id = ?2", update = true) void updateNameSql(String name, Long id); + @Query(value = "UPDATE items SET name = :name WHERE id = :id", update = true) + void updateNameSqlParam(@Param("name") String name, @Param("id") Long id); + TestItem getByName(String name); - @Query("SELECT * FROM items WHERE name = '%s'") + @Query("SELECT * FROM items WHERE name = ?1") Optional findOneSqlByName(String name); - @Query("SELECT * FROM items WHERE name = '%s'") + @Query("SELECT * FROM items WHERE name = :name") + Optional findOneSqlByNameParam(@Param("name") String name); + + @Query("SELECT * FROM items WHERE name = ?11 AND value = ?12") + Optional findOneSqlByNameAndValueManyParams(String name1, String name2, String name3, String name4, String name5, + String name6, String name7, String name8, String name9, String name10, String name11, String value); + + @Query("SELECT * FROM items WHERE id = ?1 AND name = ?2 AND value = ?3") + Optional findOneSqlByIdAndNameAndValue(Long id, String name, String value); + + @Query("SELECT * FROM items WHERE id = :id AND name = :name AND value = :value") + Optional findOneSqlByIdAndNameAndValueParam(@Param("id") Long id, @Param("name") String name, @Param("value") String value); + + @Query("SELECT * FROM items WHERE id = ?2 AND name = ?3 AND value = ?1") + Optional findOneSqlByIdAndNameAndValue(String value, Long id, String name); + + @Query("SELECT * FROM items WHERE id = :id AND name = :name AND value = :value") + Optional findOneSqlByIdAndNameAndValueParam(@Param("value") String value, @Param("id") Long id, @Param("name") String name); + + @Query("SELECT * FROM items WHERE name = ?1") TestItem getOneSqlByName(String name); + @Query("SELECT * FROM items WHERE name = :name") + TestItem getOneSqlByNameParam(@Param("name") String name); + @Query("SELECT * FROM items") List findAllListSql();