diff --git a/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerRepositoryQuery.java b/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerRepositoryQuery.java index 73138a7..e2c42b8 100644 --- a/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerRepositoryQuery.java +++ b/src/main/java/org/springframework/data/reindexer/repository/query/ReindexerRepositoryQuery.java @@ -15,9 +15,8 @@ */ package org.springframework.data.reindexer.repository.query; -import java.beans.PropertyDescriptor; import java.lang.reflect.Array; -import java.lang.reflect.Field; +import java.lang.reflect.Constructor; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -27,10 +26,14 @@ import java.util.Optional; import java.util.Spliterator; import java.util.Spliterators; +import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Stream; import java.util.stream.StreamSupport; import ru.rt.restream.reindexer.AggregationResult; + +import org.springframework.data.mapping.PreferredConstructor; +import org.springframework.data.mapping.model.PreferredConstructorDiscoverer; import org.springframework.data.reindexer.repository.support.TransactionalNamespace; import ru.rt.restream.reindexer.FieldType; import ru.rt.restream.reindexer.Namespace; @@ -46,6 +49,7 @@ import org.springframework.data.domain.Sort.Order; import org.springframework.data.repository.query.ParametersParameterAccessor; import org.springframework.data.repository.query.RepositoryQuery; +import org.springframework.data.repository.query.ReturnedType; import org.springframework.data.repository.query.parser.Part; import org.springframework.data.repository.query.parser.PartTree; import org.springframework.data.repository.query.parser.PartTree.OrPart; @@ -58,6 +62,8 @@ */ public class ReindexerRepositoryQuery implements RepositoryQuery { + private final Map, Constructor> preferredConstructors = new ConcurrentHashMap<>(); + private final ReindexerQueryMethod queryMethod; private final Namespace namespace; @@ -66,8 +72,6 @@ public class ReindexerRepositoryQuery implements RepositoryQuery { private final Map indexes; - private final String[] selectFields; - /** * Creates an instance. * @@ -85,31 +89,23 @@ public ReindexerRepositoryQuery(ReindexerQueryMethod queryMethod, ReindexerEntit for (ReindexerIndex index : namespace.getIndexes()) { this.indexes.put(index.getName(), index); } - this.selectFields = getSelectFields(); - } - - private String[] getSelectFields() { - if (this.queryMethod.getDomainClass() == this.queryMethod.getReturnedObjectType() - || this.queryMethod.getParameters().hasDynamicProjection()) { - return new String[0]; - } - return getSelectFields(this.queryMethod.getReturnedObjectType()); } @Override public Object execute(Object[] parameters) { - Query query = createQuery(parameters); + ReturnedType projectionType = getProjectionType(parameters); + Query query = createQuery(projectionType, parameters); if (this.queryMethod.isCollectionQuery()) { - return toCollection(query, parameters); + return toCollection(query, projectionType); } if (this.queryMethod.isStreamQuery()) { - return toStream(query, parameters); + return toStream(query, projectionType); } if (this.queryMethod.isIteratorQuery()) { - return new ProjectingResultIterator(query, parameters); + return new ProjectingResultIterator(query, projectionType); } if (this.queryMethod.isQueryForEntity()) { - Object entity = toEntity(query, parameters); + Object entity = toEntity(query, projectionType); Assert.state(entity != null, "Exactly one item expected, but there is zero"); return entity; } @@ -123,13 +119,25 @@ public Object execute(Object[] parameters) { query.delete(); return null; } - return Optional.ofNullable(toEntity(query, parameters)); + return Optional.ofNullable(toEntity(query, projectionType)); } - private Query createQuery(Object[] parameters) { + private ReturnedType getProjectionType(Object[] parameters) { + if (this.queryMethod.getDomainClass() == this.queryMethod.getReturnedObjectType()) { + return null; + } + Class type = this.queryMethod.getParameters().hasDynamicProjection() + ? (Class) parameters[this.queryMethod.getParameters().getDynamicProjectionIndex()] + : this.queryMethod.getReturnedObjectType(); + return ReturnedType.of(type, this.queryMethod.getDomainClass(), this.queryMethod.getFactory()); + } + + private Query createQuery(ReturnedType projectionType, Object[] parameters) { ParametersParameterAccessor accessor = new ParametersParameterAccessor(this.queryMethod.getParameters(), parameters); - Query base = this.namespace.query().select(getSelectFields(parameters)); + String[] selectFields = (projectionType != null) ? projectionType.getInputProperties().toArray(String[]::new) + : new String[0]; + Query base = this.namespace.query().select(selectFields); Iterator iterator = accessor.iterator(); for (OrPart node : this.tree) { Iterator parts = node.iterator(); @@ -149,34 +157,6 @@ private Query createQuery(Object[] parameters) { return base; } - private String[] getSelectFields(Object[] parameters) { - if (this.queryMethod.getParameters().hasDynamicProjection()) { - Class type = (Class) parameters[this.queryMethod.getParameters().getDynamicProjectionIndex()]; - return getSelectFields(type); - } - return this.selectFields; - } - - private String[] getSelectFields(Class type) { - if (type.isInterface()) { - List inputProperties = this.queryMethod.getFactory() - .getProjectionInformation(type).getInputProperties(); - String[] result = new String[inputProperties.size()]; - for (int i = 0; i < result.length; i++) { - result[i] = inputProperties.get(i).getName(); - } - return result; - } - else { - List inheritedFields = BeanPropertyUtils.getInheritedFields(type); - String[] result = new String[inheritedFields.size()]; - for (int i = 0; i < result.length; i++) { - result[i] = inheritedFields.get(i).getName(); - } - return result; - } - } - private Query where(Part part, Query criteria, Iterator parameters) { String indexName = part.getProperty().toDotPath(); switch (part.getType()) { @@ -240,8 +220,8 @@ private Object getParameterValue(String indexName, Object value) { return value; } - private Collection toCollection(Query query, Object[] parameters) { - try (ResultIterator iterator = new ProjectingResultIterator(query, parameters)) { + private Collection toCollection(Query query, ReturnedType projectionType) { + try (ResultIterator iterator = new ProjectingResultIterator(query, projectionType)) { Collection result = CollectionFactory.createCollection(this.queryMethod.getReturnType(), this.queryMethod.getReturnedObjectType(), (int) iterator.size()); while (iterator.hasNext()) { @@ -251,14 +231,14 @@ private Collection toCollection(Query query, Object[] parameters) { } } - private Stream toStream(Query query, Object[] parameters) { - ResultIterator iterator = new ProjectingResultIterator(query, parameters); + private Stream toStream(Query query, ReturnedType projectionType) { + ResultIterator iterator = new ProjectingResultIterator(query, projectionType); Spliterator spliterator = Spliterators.spliterator(iterator, iterator.size(), Spliterator.NONNULL); return StreamSupport.stream(spliterator, false).onClose(iterator::close); } - private Object toEntity(Query query, Object[] parameters) { - try (ResultIterator iterator = new ProjectingResultIterator(query, parameters)) { + private Object toEntity(Query query, ReturnedType projectionType) { + try (ResultIterator iterator = new ProjectingResultIterator(query, projectionType)) { Object item = null; if (iterator.hasNext()) { item = iterator.next(); @@ -279,28 +259,13 @@ public ReindexerQueryMethod getQueryMethod() { private final class ProjectingResultIterator implements ResultIterator { - private final Object[] parameters; - private final ResultIterator delegate; - private ProjectingResultIterator(Query query, Object[] parameters) { - this.parameters = parameters; - this.delegate = query.execute(determineReturnType()); - } + private final ReturnedType projectionType; - private Class determineReturnType() { - if (ReindexerRepositoryQuery.this.queryMethod.getParameters().hasDynamicProjection()) { - Class type = (Class) this.parameters[ReindexerRepositoryQuery.this.queryMethod.getParameters().getDynamicProjectionIndex()]; - if (type.isInterface()) { - return ReindexerRepositoryQuery.this.queryMethod.getDomainClass(); - } - return type; - } - if (ReindexerRepositoryQuery.this.queryMethod.getDomainClass() != ReindexerRepositoryQuery.this.queryMethod.getReturnedObjectType() - && ReindexerRepositoryQuery.this.queryMethod.getReturnedObjectType().isInterface()) { - return ReindexerRepositoryQuery.this.queryMethod.getDomainClass(); - } - return ReindexerRepositoryQuery.this.queryMethod.getReturnedObjectType(); + private ProjectingResultIterator(Query query, ReturnedType projectionType) { + this.delegate = query.execute(); + this.projectionType = projectionType; } @Override @@ -331,15 +296,26 @@ public boolean hasNext() { @Override public Object next() { Object item = this.delegate.next(); - if (ReindexerRepositoryQuery.this.queryMethod.getParameters().hasDynamicProjection()) { - Class type = (Class) this.parameters[ReindexerRepositoryQuery.this.queryMethod.getParameters().getDynamicProjectionIndex()]; - if (type.isInterface()) { - return ReindexerRepositoryQuery.this.queryMethod.getFactory().createProjection(type, item); + if (this.projectionType != null) { + if (this.projectionType.getReturnedType().isInterface()) { + return ReindexerRepositoryQuery.this.queryMethod.getFactory().createProjection(this.projectionType.getReturnedType(), item); + } + List properties = this.projectionType.getInputProperties(); + Object[] values = new Object[properties.size()]; + for (int i = 0; i < properties.size(); i++) { + values[i] = BeanPropertyUtils.getProperty(item, properties.get(i)); + } + Constructor constructor = ReindexerRepositoryQuery.this.preferredConstructors.computeIfAbsent(this.projectionType.getReturnedType(), (type) -> { + PreferredConstructor preferredConstructor = PreferredConstructorDiscoverer.discover(type); + Assert.state(preferredConstructor != null, () -> "No preferred constructor found for " + type); + return preferredConstructor.getConstructor(); + }); + try { + return constructor.newInstance(values); + } + catch (Exception e) { + throw new RuntimeException(e); } - } - else if (ReindexerRepositoryQuery.this.queryMethod.getReturnedObjectType().isInterface()) { - return ReindexerRepositoryQuery.this.queryMethod.getFactory() - .createProjection(ReindexerRepositoryQuery.this.queryMethod.getReturnedObjectType(), item); } return item; } diff --git a/src/test/java/org/springframework/data/reindexer/repository/ReindexerRepositoryTests.java b/src/test/java/org/springframework/data/reindexer/repository/ReindexerRepositoryTests.java index a9ee5f8..41ea82e 100644 --- a/src/test/java/org/springframework/data/reindexer/repository/ReindexerRepositoryTests.java +++ b/src/test/java/org/springframework/data/reindexer/repository/ReindexerRepositoryTests.java @@ -57,6 +57,7 @@ import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.Configuration; +import org.springframework.data.annotation.PersistenceCreator; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort.Direction; import org.springframework.data.reindexer.ReindexerTransactionManager; @@ -628,6 +629,54 @@ public void findItemDtoByIdIn() { } } + @Test + public void findItemPreferredConstructorDtoByIdIn() { + List expectedItems = new ArrayList<>(); + for (long i = 0; i < 100; i++) { + expectedItems.add(this.repository.save(new TestItem(i, "TestName" + i, "TestValue" + i))); + } + List foundItems = this.repository.findItemPreferredConstructorDtoByIdIn(expectedItems.stream() + .map(TestItem::getId) + .collect(Collectors.toList())); + assertEquals(expectedItems.size(), foundItems.size()); + for (int i = 0; i < foundItems.size(); i++) { + assertEquals(expectedItems.get(i).getId(), foundItems.get(i).getId()); + assertEquals(expectedItems.get(i).getName(), foundItems.get(i).getName()); + } + } + + @Test + public void findItemRecordByIdIn() { + List expectedItems = new ArrayList<>(); + for (long i = 0; i < 100; i++) { + expectedItems.add(this.repository.save(new TestItem(i, "TestName" + i, "TestValue" + i))); + } + List foundItems = this.repository.findItemRecordByIdIn(expectedItems.stream() + .map(TestItem::getId) + .collect(Collectors.toList())); + assertEquals(expectedItems.size(), foundItems.size()); + for (int i = 0; i < foundItems.size(); i++) { + assertEquals(expectedItems.get(i).getId(), foundItems.get(i).id()); + assertEquals(expectedItems.get(i).getName(), foundItems.get(i).name()); + } + } + + @Test + public void findItemPreferredConstructorRecordByIdIn() { + List expectedItems = new ArrayList<>(); + for (long i = 0; i < 100; i++) { + expectedItems.add(this.repository.save(new TestItem(i, "TestName" + i, "TestValue" + i))); + } + List foundItems = this.repository.findItemPreferredConstructorRecordByIdIn(expectedItems.stream() + .map(TestItem::getId) + .collect(Collectors.toList())); + assertEquals(expectedItems.size(), foundItems.size()); + for (int i = 0; i < foundItems.size(); i++) { + assertNull(foundItems.get(i).id()); + assertEquals(expectedItems.get(i).getName(), foundItems.get(i).name()); + } + } + @Test public void findDynamicItemProjectionByIdIn() { List expectedItems = new ArrayList<>(); @@ -660,6 +709,54 @@ public void findDynamicItemDtoByIdIn() { } } + @Test + public void findDynamicItemPreferredConstructorDtoDtoByIdIn() { + List expectedItems = new ArrayList<>(); + for (long i = 0; i < 100; i++) { + expectedItems.add(this.repository.save(new TestItem(i, "TestName" + i, "TestValue" + i))); + } + List foundItems = this.repository.findByIdIn(expectedItems.stream() + .map(TestItem::getId) + .collect(Collectors.toList()), TestItemPreferredConstructorDto.class); + assertEquals(expectedItems.size(), foundItems.size()); + for (int i = 0; i < foundItems.size(); i++) { + assertEquals(expectedItems.get(i).getId(), foundItems.get(i).getId()); + assertEquals(expectedItems.get(i).getName(), foundItems.get(i).getName()); + } + } + + @Test + public void findDynamicItemRecordByIdIn() { + List expectedItems = new ArrayList<>(); + for (long i = 0; i < 100; i++) { + expectedItems.add(this.repository.save(new TestItem(i, "TestName" + i, "TestValue" + i))); + } + List foundItems = this.repository.findByIdIn(expectedItems.stream() + .map(TestItem::getId) + .collect(Collectors.toList()), TestItemRecord.class); + assertEquals(expectedItems.size(), foundItems.size()); + for (int i = 0; i < foundItems.size(); i++) { + assertEquals(expectedItems.get(i).getId(), foundItems.get(i).id()); + assertEquals(expectedItems.get(i).getName(), foundItems.get(i).name()); + } + } + + @Test + public void findDynamicItemPreferredConstructorRecordByIdIn() { + List expectedItems = new ArrayList<>(); + for (long i = 0; i < 100; i++) { + expectedItems.add(this.repository.save(new TestItem(i, "TestName" + i, "TestValue" + i))); + } + List foundItems = this.repository.findByIdIn(expectedItems.stream() + .map(TestItem::getId) + .collect(Collectors.toList()), TestItemPreferredConstructorRecord.class); + assertEquals(expectedItems.size(), foundItems.size()); + for (int i = 0; i < foundItems.size(); i++) { + assertNull(foundItems.get(i).id()); + assertEquals(expectedItems.get(i).getName(), foundItems.get(i).name()); + } + } + @Test public void findByIdInArray() { List expectedItems = new ArrayList<>(); @@ -968,6 +1065,12 @@ Optional findOneSqlByNameAndValueManyParams(String name1, String name2 List findItemDtoByIdIn(List ids); + List findItemPreferredConstructorDtoByIdIn(List ids); + + List findItemRecordByIdIn(List ids); + + List findItemPreferredConstructorRecordByIdIn(List ids); + List findByIdIn(List ids, Class type); } @@ -1090,7 +1193,9 @@ public static class TestItemDto { private String name; - public TestItemDto() { + public TestItemDto(Long id, String name) { + this.id = id; + this.name = name; } public Long getId() { @@ -1111,6 +1216,52 @@ public void setName(String name) { } + public static class TestItemPreferredConstructorDto { + + private Long id; + + private String name; + + public TestItemPreferredConstructorDto(String name) { + this.name = name; + } + + @PersistenceCreator + public TestItemPreferredConstructorDto(Long id, String name) { + this.id = id; + this.name = name; + } + + public Long getId() { + return this.id; + } + + public void setId(Long id) { + this.id = id; + } + + public String getName() { + return this.name; + } + + public void setName(String name) { + this.name = name; + } + + } + + public record TestItemRecord(Long id, String name) { + } + + public record TestItemPreferredConstructorRecord(Long id, String name) { + + @PersistenceCreator + TestItemPreferredConstructorRecord(String name) { + this(null, name); + } + + } + public enum TestEnum { TEST_CONSTANT_1, TEST_CONSTANT_2,