Skip to content

Commit

Permalink
Consider supporting SpEL in native queries
Browse files Browse the repository at this point in the history
Closes gh-13
  • Loading branch information
evgeniycheban committed Dec 26, 2024
1 parent 267c854 commit 426e3ee
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@
import org.springframework.data.repository.query.Parameter;
import org.springframework.data.repository.query.QueryMethod;
import org.springframework.data.repository.query.RepositoryQuery;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.PropertyAccessor;
import org.springframework.expression.TypedValue;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.util.Assert;

/**
Expand All @@ -43,6 +49,10 @@
*/
public class StringBasedReindexerRepositoryQuery implements RepositoryQuery {

private final SpelExpressionParser spelExpressionParser = new SpelExpressionParser();

private final NamedParameterPropertyAccessor propertyAccessor = new NamedParameterPropertyAccessor();

private final ReindexerQueryMethod queryMethod;

private final Namespace<?> namespace;
Expand Down Expand Up @@ -121,22 +131,50 @@ private String prepareQuery(Object[] parameters) {
String value = getParameterValuePart(parameters[index - 1]);
result.replace(offset + i - 1, offset + i + digits, value);
offset += value.length() - digits - 1;
i += digits + 1;
break;
}
case ':': {
StringBuilder sb = new StringBuilder();
for (int j = i; j < queryParts.length; j++) {
if (Character.isWhitespace(queryParts[j])) {
break;
if (queryParts[i] == '#') {
int special = 1;
StringBuilder sb = new StringBuilder();
for (int j = i + 1; j < queryParts.length; j++) {
if (queryParts[j] == '{') {
special++;
continue;
}
if (queryParts[j] == '}') {
special++;
break;
}
sb.append(queryParts[j]);
}
if (special != 3) {
throw new IllegalStateException("Invalid SpEL expression provided at index: " + i);
}
sb.append(queryParts[j]);
Expression expression = this.spelExpressionParser.parseExpression(sb.toString());
StandardEvaluationContext ctx = new StandardEvaluationContext(parameters);
ctx.addPropertyAccessor(this.propertyAccessor);
String value = getParameterValuePart(expression.getValue(ctx));
result.replace(offset + i - 1, offset + i + expression.getExpressionString().length() + special, value);
offset += value.length() - expression.getExpressionString().length() - special - 1;
i += expression.getExpressionString().length() + special;
} else {
StringBuilder sb = new StringBuilder();
for (int j = i; j < queryParts.length; j++) {
if (Character.isWhitespace(queryParts[j])) {
break;
}
sb.append(queryParts[j]);
}
String parameterName = sb.toString();
Integer index = this.namedParameters.get(parameterName);
Assert.notNull(index, () -> "No parameter found for name: " + parameterName);
String value = getParameterValuePart(parameters[index]);
result.replace(offset + i - 1, offset + i + parameterName.length(), value);
offset += value.length() - parameterName.length() - 1;
i += parameterName.length() + 1;
}
String parameterName = sb.toString();
Integer index = this.namedParameters.get(parameterName);
Assert.notNull(index, () -> "No parameter found for name: " + parameterName);
String value = getParameterValuePart(parameters[index]);
result.replace(offset + i - 1, offset + i + parameterName.length(), value);
offset += value.length() - parameterName.length() - 1;
break;
}
}
Expand Down Expand Up @@ -197,4 +235,38 @@ public QueryMethod getQueryMethod() {
return this.queryMethod;
}

private final class NamedParameterPropertyAccessor implements PropertyAccessor {

@Override
public boolean canRead(EvaluationContext context, Object target, String name) {
return StringBasedReindexerRepositoryQuery.this.namedParameters.containsKey(name);
}

@Override
public TypedValue read(EvaluationContext context, Object target, String name) {
Assert.state(target instanceof Object[], "target must be an array");
Object[] arguments = (Object[]) target;
Integer index = StringBasedReindexerRepositoryQuery.this.namedParameters.get(name);
Assert.notNull(index, () -> "No parameter found for name: " + name);
Object value = arguments[index];
return new TypedValue(value);
}

@Override
public boolean canWrite(EvaluationContext context, Object target, String name) {
return false;
}

@Override
public void write(EvaluationContext context, Object target, String name, Object newValue) {
// NOOP
}

@Override
public Class<?>[] getSpecificTargetClasses() {
return new Class[0];
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,26 @@ public void findOneSqlByNameOrValueParam() {
assertEquals(testItem.getValue(), item.getValue());
}

@Test
public void findOneSqlSpelByItemIdAndNameAndValueParam() {
TestItem testItem = this.repository.save(new TestItem(1L, "TestName", "TestValue"));
TestItem item = this.repository.findOneSqlSpelByItemIdAndNameAndValueParam(testItem).orElse(null);
assertNotNull(item);
assertEquals(testItem.getId(), item.getId());
assertEquals(testItem.getName(), item.getName());
assertEquals(testItem.getValue(), item.getValue());
}

@Test
public void findOneSqlSpelByIdAndNameAndValueParam() {
TestItem testItem = this.repository.save(new TestItem(1L, "TEST_NAME", "test_value"));
TestItem item = this.repository.findOneSqlSpelByIdAndNameAndValueParam(0L, "test_name", "TEST", "VALUE").orElse(null);
assertNotNull(item);
assertEquals(testItem.getId(), item.getId());
assertEquals(testItem.getName(), item.getName());
assertEquals(testItem.getValue(), item.getValue());
}

@Test
public void findOneSqlByIdAndNameAndValueAnyParameterOrder() {
TestItem testItem = this.repository.save(new TestItem(1L, "TestName", "TestValue"));
Expand Down Expand Up @@ -765,6 +785,12 @@ Optional<TestItem> findOneSqlByNameAndValueManyParams(String name1, String name2
@Query("SELECT * FROM items WHERE id = :id AND name = :name AND value = :value")
Optional<TestItem> findOneSqlByIdAndNameAndValueParam(@Param("id") Long id, @Param("name") String name, @Param("value") String value);

@Query("SELECT * FROM items WHERE id = :#{id + 1} AND name = :#{name.toUpperCase()} AND value = :#{value1.toLowerCase() + '_' + value2.toLowerCase()}")
Optional<TestItem> findOneSqlSpelByIdAndNameAndValueParam(@Param("id") Long id, @Param("name") String name, @Param("value1") String value1, @Param("value2") String value2);

@Query("SELECT * FROM items WHERE id = :#{item.id} AND name = :#{item.name} AND value = :#{item.value}")
Optional<TestItem> findOneSqlSpelByItemIdAndNameAndValueParam(@Param("item") TestItem item);

@Query("SELECT * FROM items WHERE id = ?2 AND name = ?3 AND value = ?1")
Optional<TestItem> findOneSqlByIdAndNameAndValue(String value, Long id, String name);

Expand Down

0 comments on commit 426e3ee

Please sign in to comment.