diff --git a/pom.xml b/pom.xml index 3daab4d790..63a1959063 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.x-GH-5004-SNAPSHOT pom Spring Data MongoDB diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml index fc88571622..66a68de39f 100644 --- a/spring-data-mongodb-distribution/pom.xml +++ b/spring-data-mongodb-distribution/pom.xml @@ -15,7 +15,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.x-GH-5004-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index 6f34da5660..102427d19a 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -13,7 +13,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.x-GH-5004-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperation.java index 43c0d521c3..47fea8a02f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperation.java @@ -225,6 +225,14 @@ interface TerminatingFindNear { * @return never {@literal null}. */ GeoResults all(); + + /** + * Count matching elements. + * + * @return number of elements matching the query. + * @since 5.0 + */ + long count(); } /** diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupport.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupport.java index 46289ecfa4..39f4affd35 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupport.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupport.java @@ -243,6 +243,11 @@ public TerminatingFindNear map(QueryResultConverter all() { return template.doGeoNear(nearQuery, domainType, getCollectionName(), returnType, resultConverter); } + + @Override + public long count() { + return template.doGeoNearCount(nearQuery, domainType, getCollectionName()); + } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 8682f77ec8..03c0bb7682 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -48,6 +48,7 @@ import org.springframework.dao.support.PersistenceExceptionTranslator; import org.springframework.data.convert.EntityReader; import org.springframework.data.domain.OffsetScrollPosition; +import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Window; import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoResult; @@ -1044,6 +1045,31 @@ public GeoResults geoNear(NearQuery near, Class domainType, String col return doGeoNear(near, domainType, collectionName, returnType, QueryResultConverter.entity()); } + long doGeoNearCount(NearQuery near, Class domainType, String collectionName) { + + Builder optionsBuilder = AggregationOptions.builder().collation(near.getCollation()); + + if (near.hasReadPreference()) { + optionsBuilder.readPreference(near.getReadPreference()); + } + + if (near.hasReadConcern()) { + optionsBuilder.readConcern(near.getReadConcern()); + } + + String distanceField = operations.nearQueryDistanceFieldName(domainType); + Aggregation $geoNear = TypedAggregation.newAggregation(domainType, + Aggregation.geoNear(near, distanceField).skip(-1).limit(-1), Aggregation.count().as("_totalCount")) + .withOptions(optionsBuilder.build()); + + AggregationResults results = doAggregate($geoNear, collectionName, Document.class, + queryOperations.createAggregation($geoNear, (AggregationOperationContext) null)); + Iterator iterator = results.iterator(); + return iterator.hasNext() + ? NumberUtils.convertNumberToTargetClass(iterator.next().get("_totalCount", Integer.class), Long.class) + : 0L; + } + GeoResults doGeoNear(NearQuery near, Class domainType, String collectionName, Class returnType, QueryResultConverter resultConverter) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperation.java index bcfc64f2b4..04b793f839 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperation.java @@ -42,6 +42,8 @@ public class GeoNearOperation implements AggregationOperation { private final NearQuery nearQuery; private final String distanceField; private final @Nullable String indexKey; + private final @Nullable Long skip; + private final @Nullable Integer limit; /** * Creates a new {@link GeoNearOperation} from the given {@link NearQuery} and the given distance field. The @@ -51,7 +53,7 @@ public class GeoNearOperation implements AggregationOperation { * @param distanceField must not be {@literal null}. */ public GeoNearOperation(NearQuery nearQuery, String distanceField) { - this(nearQuery, distanceField, null); + this(nearQuery, distanceField, null, nearQuery.getSkip(), null); } /** @@ -63,7 +65,8 @@ public GeoNearOperation(NearQuery nearQuery, String distanceField) { * @param indexKey can be {@literal null}; * @since 2.1 */ - private GeoNearOperation(NearQuery nearQuery, String distanceField, @Nullable String indexKey) { + private GeoNearOperation(NearQuery nearQuery, String distanceField, @Nullable String indexKey, @Nullable Long skip, + @Nullable Integer limit) { Assert.notNull(nearQuery, "NearQuery must not be null"); Assert.hasLength(distanceField, "Distance field must not be null or empty"); @@ -71,6 +74,8 @@ private GeoNearOperation(NearQuery nearQuery, String distanceField, @Nullable St this.nearQuery = nearQuery; this.distanceField = distanceField; this.indexKey = indexKey; + this.skip = skip; + this.limit = limit; } /** @@ -83,7 +88,30 @@ private GeoNearOperation(NearQuery nearQuery, String distanceField, @Nullable St */ @Contract("_ -> new") public GeoNearOperation useIndex(String key) { - return new GeoNearOperation(nearQuery, distanceField, key); + return new GeoNearOperation(nearQuery, distanceField, key, skip, limit); + } + + /** + * Override potential skip applied via {@link NearQuery#getSkip()}. Adds an additional {@link SkipOperation} if value + * is non negative. + * + * @param skip + * @return new instance of {@link GeoNearOperation}. + * @since 5.0 + */ + public GeoNearOperation skip(long skip) { + return new GeoNearOperation(nearQuery, distanceField, indexKey, skip, limit); + } + + /** + * Override potential limit value. Adds an additional {@link LimitOperation} if value is non negative. + * + * @param limit + * @return new instance of {@link GeoNearOperation}. + * @since 5.0 + */ + public GeoNearOperation limit(Integer limit) { + return new GeoNearOperation(nearQuery, distanceField, indexKey, skip, limit); } @Override @@ -92,7 +120,13 @@ public Document toDocument(AggregationOperationContext context) { Document command = context.getMappedObject(nearQuery.toDocument()); if (command.containsKey("query")) { - command.replace("query", context.getMappedObject(command.get("query", Document.class))); + Document query = command.get("query", Document.class); + if (query == null || query.isEmpty()) { + command.remove("query"); + } else { + command.replace("query", context.getMappedObject(query)); + } + } command.remove("collation"); @@ -115,15 +149,18 @@ public List toPipelineStages(AggregationOperationContext context) { Document command = toDocument(context); Number limit = (Number) command.get("$geoNear", Document.class).remove("num"); + if (limit != null && this.limit != null) { + limit = this.limit; + } List stages = new ArrayList<>(3); stages.add(command); - if (nearQuery.getSkip() != null && nearQuery.getSkip() > 0) { - stages.add(new Document("$skip", nearQuery.getSkip())); + if (this.skip != null && this.skip > 0) { + stages.add(new Document("$skip", this.skip)); } - if (limit != null) { + if (limit != null && limit.longValue() > 0) { stages.add(new Document("$limit", limit.longValue())); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/CriteriaDefinition.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/CriteriaDefinition.java index 7777e5f554..4400baa6d6 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/CriteriaDefinition.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/CriteriaDefinition.java @@ -46,9 +46,7 @@ public interface CriteriaDefinition { * @since 5.0 * @author Christoph Strobl */ - class Placeholder { - - private final Object expression; + interface Placeholder { /** * Create a new placeholder for index bindable parameter. @@ -56,23 +54,29 @@ class Placeholder { * @param position the index of the parameter to bind. * @return new instance of {@link Placeholder}. */ - public static Placeholder indexed(int position) { - return new Placeholder("?%s".formatted(position)); + static Placeholder indexed(int position) { + return new PlaceholderImpl("?%s".formatted(position)); } - public static Placeholder placeholder(String expression) { - return new Placeholder(expression); + static Placeholder placeholder(String expression) { + return new PlaceholderImpl(expression); } - Placeholder(Object value) { - this.expression = value; + Object getValue(); + } + + static class PlaceholderImpl implements Placeholder { + private final Object expression; + + public PlaceholderImpl(Object expression) { + this.expression = expression; } + @Override public Object getValue() { return expression; } - @Override public String toString() { return getValue().toString(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/GeoCommand.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/GeoCommand.java index 19ecd94e23..4b8f81ef2b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/GeoCommand.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/GeoCommand.java @@ -22,6 +22,7 @@ import org.springframework.data.geo.Circle; import org.springframework.data.geo.Polygon; import org.springframework.data.geo.Shape; +import org.springframework.data.mongodb.core.geo.GeoJson; import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.util.Assert; @@ -75,6 +76,9 @@ private String getCommand(Shape shape) { Assert.notNull(shape, "Shape must not be null"); + if(shape instanceof GeoJson) { + return "$geometry"; + } if (shape instanceof Box) { return "$box"; } else if (shape instanceof Circle) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java index 88d7dc5c1d..4f42437704 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java @@ -671,7 +671,7 @@ public Document toDocument() { document.put("distanceMultiplier", getDistanceMultiplier()); } - if (limit != null) { + if (limit != null && limit > 0) { document.put("num", limit); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java new file mode 100644 index 0000000000..37f24cd849 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AggregationBlocks.java @@ -0,0 +1,363 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.bson.Document; +import org.jspecify.annotations.NullUnmarked; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.data.domain.SliceImpl; +import org.springframework.data.domain.Sort.Order; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationOptions; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; +import org.springframework.data.mongodb.core.aggregation.AggregationResults; +import org.springframework.data.mongodb.core.aggregation.TypedAggregation; +import org.springframework.data.mongodb.core.mapping.MongoSimpleTypes; +import org.springframework.data.mongodb.core.query.Collation; +import org.springframework.data.mongodb.repository.Hint; +import org.springframework.data.mongodb.repository.ReadPreference; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.data.util.ReflectionUtils; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class AggregationBlocks { + + @NullUnmarked + static class AggregationExecutionCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String aggregationVariableName; + + AggregationExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + } + + AggregationExecutionCodeBlockBuilder referencing(String aggregationVariableName) { + + this.aggregationVariableName = aggregationVariableName; + return this; + } + + CodeBlock build() { + + String mongoOpsRef = context.fieldNameOf(MongoOperations.class); + Builder builder = CodeBlock.builder(); + + builder.add("\n"); + + Class outputType = queryMethod.getReturnedObjectType(); + if (MongoSimpleTypes.HOLDER.isSimpleType(outputType)) { + outputType = Document.class; + } else if (ClassUtils.isAssignable(AggregationResults.class, outputType)) { + outputType = queryMethod.getReturnType().getComponentType().getType(); + } + + if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { + builder.addStatement("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); + return builder.build(); + } + + if (ClassUtils.isAssignable(AggregationResults.class, context.getMethod().getReturnType())) { + builder.addStatement("return $L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); + return builder.build(); + } + + if (outputType == Document.class) { + + Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); + + if (queryMethod.isStreamQuery()) { + + builder.addStatement("$T<$T> $L = $L.aggregateStream($L, $T.class)", Stream.class, Document.class, + context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); + + builder.addStatement("return $1L.map(it -> ($2T) convertSimpleRawResult($2T.class, it))", + context.localVariable("results"), returnType); + } else { + + builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, + context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); + + if (!queryMethod.isCollectionQuery()) { + builder.addStatement( + "return $1T.<$2T>firstElement(convertSimpleRawResults($2T.class, $3L.getMappedResults()))", + CollectionUtils.class, returnType, context.localVariable("results")); + } else { + builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType, + context.localVariable("results")); + } + } + } else { + if (queryMethod.isSliceQuery()) { + builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, + context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); + builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()", + context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName()); + builder.addStatement( + "return new $1T<>($2L ? $3L.getMappedResults().subList(0, $4L.getPageSize()) : $3L.getMappedResults(), $4L, $2L)", + SliceImpl.class, context.localVariable("hasNext"), context.localVariable("results"), + context.getPageableParameterName()); + } else { + + if (queryMethod.isStreamQuery()) { + builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, + outputType); + } else { + + builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, + aggregationVariableName, outputType); + } + } + } + + return builder.build(); + } + } + + @NullUnmarked + static class AggregationCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private final Map arguments; + + private AggregationInteraction source; + + private String aggregationVariableName; + private boolean pipelineOnly; + + AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.arguments = new LinkedHashMap<>(); + context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it))); + this.queryMethod = queryMethod; + } + + AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) { + + this.source = aggregation; + return this; + } + + AggregationCodeBlockBuilder usingAggregationVariableName(String aggregationVariableName) { + + this.aggregationVariableName = aggregationVariableName; + return this; + } + + AggregationCodeBlockBuilder pipelineOnly(boolean pipelineOnly) { + + this.pipelineOnly = pipelineOnly; + return this; + } + + CodeBlock build() { + + Builder builder = CodeBlock.builder(); + builder.add("\n"); + + String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline")); + builder.add(pipeline(pipelineName)); + + if (!pipelineOnly) { + + builder.addStatement("$1T<$2T> $3L = $4T.newAggregation($2T.class, $5L.getOperations())", + TypedAggregation.class, context.getRepositoryInformation().getDomainType(), aggregationVariableName, + Aggregation.class, pipelineName); + + builder.add(aggregationOptions(aggregationVariableName)); + } + + return builder.build(); + } + + private CodeBlock pipeline(String pipelineVariableName) { + + String sortParameter = context.getSortParameterName(); + String limitParameter = context.getLimitParameterName(); + String pageableParameter = context.getPageableParameterName(); + + boolean mightBeSorted = StringUtils.hasText(sortParameter); + boolean mightBeLimited = StringUtils.hasText(limitParameter); + boolean mightBePaged = StringUtils.hasText(pageableParameter); + + int stageCount = source.stages().size(); + if (mightBeSorted) { + stageCount++; + } + if (mightBeLimited) { + stageCount++; + } + if (mightBePaged) { + stageCount += 3; + } + + Builder builder = CodeBlock.builder(); + builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments)); + + if (mightBeSorted) { + builder.add(sortingStage(sortParameter)); + } + + if (mightBeLimited) { + builder.add(limitingStage(limitParameter)); + } + + if (mightBePaged) { + builder.add(pagingStage(pageableParameter, queryMethod.isSliceQuery())); + } + + builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName, + context.localVariable("stages")); + return builder.build(); + } + + private CodeBlock aggregationOptions(String aggregationVariableName) { + + Builder builder = CodeBlock.builder(); + List options = new ArrayList<>(5); + if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { + options.add(CodeBlock.of(".skipOutput()")); + } + + MergedAnnotation hintAnnotation = context.getAnnotation(Hint.class); + String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null; + if (StringUtils.hasText(hint)) { + options.add(CodeBlock.of(".hint($S)", hint)); + } + + MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); + String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; + if (StringUtils.hasText(readPreference)) { + options.add(CodeBlock.of(".readPreference($T.valueOf($S))", com.mongodb.ReadPreference.class, readPreference)); + } + + if (queryMethod.hasAnnotatedCollation()) { + options.add(CodeBlock.of(".collation($T.parse($S))", Collation.class, queryMethod.getAnnotatedCollation())); + } + + if (!options.isEmpty()) { + + Builder optionsBuilder = CodeBlock.builder(); + optionsBuilder.add("$1T $2L = $1T.builder()\n", AggregationOptions.class, + context.localVariable("aggregationOptions")); + optionsBuilder.indent(); + for (CodeBlock optionBlock : options) { + optionsBuilder.add(optionBlock); + optionsBuilder.add("\n"); + } + optionsBuilder.add(".build();\n"); + optionsBuilder.unindent(); + builder.add(optionsBuilder.build()); + + builder.addStatement("$1L = $1L.withOptions($2L)", aggregationVariableName, + context.localVariable("aggregationOptions")); + } + return builder.build(); + } + + private CodeBlock aggregationStages(String stageListVariableName, Iterable stages, int stageCount, + Map arguments) { + + Builder builder = CodeBlock.builder(); + builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class, + stageCount); + int stageCounter = 0; + + for (String stage : stages) { + String stageName = context.localVariable("stage_%s".formatted(stageCounter++)); + builder.add(MongoCodeBlocks.renderExpressionToDocument(stage, stageName, arguments)); + builder.addStatement("$L.add($L)", context.localVariable("stages"), stageName); + } + + return builder.build(); + } + + private CodeBlock sortingStage(String sortProvider) { + + Builder builder = CodeBlock.builder(); + + builder.beginControlFlow("if ($L.isSorted())", sortProvider); + builder.addStatement("$1T $2L = new $1T()", Document.class, context.localVariable("sortDocument")); + builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider); + builder.addStatement("$1L.append($2L.getProperty(), $2L.isAscending() ? 1 : -1);", + context.localVariable("sortDocument"), context.localVariable("order")); + builder.endControlFlow(); + builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort", + context.localVariable("sortDocument")); + builder.endControlFlow(); + + return builder.build(); + } + + private CodeBlock pagingStage(String pageableProvider, boolean slice) { + + Builder builder = CodeBlock.builder(); + + builder.add(sortingStage(pageableProvider + ".getSort()")); + + builder.beginControlFlow("if ($L.isPaged())", pageableProvider); + builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider); + builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class, + pageableProvider); + builder.endControlFlow(); + if (slice) { + builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"), + Aggregation.class, pageableProvider); + } else { + builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class, + pageableProvider); + } + builder.endControlFlow(); + + return builder.build(); + } + + private CodeBlock limitingStage(String limitProvider) { + + Builder builder = CodeBlock.builder(); + + builder.beginControlFlow("if ($L.isLimited())", limitProvider); + builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class, + limitProvider); + builder.endControlFlow(); + + return builder.build(); + } + + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java index 17c19ad951..219f90348c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java @@ -15,10 +15,10 @@ */ package org.springframework.data.mongodb.repository.aot; +import java.lang.reflect.Method; +import java.util.ArrayList; import java.util.Iterator; import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; import org.bson.conversions.Bson; import org.jspecify.annotations.NullUnmarked; @@ -29,10 +29,17 @@ import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Vector; +import org.springframework.data.geo.Box; +import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; +import org.springframework.data.geo.Metrics; import org.springframework.data.geo.Point; +import org.springframework.data.geo.Polygon; +import org.springframework.data.geo.Shape; import org.springframework.data.mongodb.core.convert.MongoCustomConversions; import org.springframework.data.mongodb.core.convert.MongoWriter; +import org.springframework.data.mongodb.core.geo.GeoJson; +import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; import org.springframework.data.mongodb.core.query.Collation; @@ -40,11 +47,17 @@ import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.TextCriteria; import org.springframework.data.mongodb.core.query.UpdateDefinition; +import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.mongodb.repository.query.ConvertingParameterAccessor; import org.springframework.data.mongodb.repository.query.MongoParameterAccessor; import org.springframework.data.mongodb.repository.query.MongoQueryCreator; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.query.Parameter; +import org.springframework.data.repository.query.Parameters; +import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.repository.query.parser.PartTree; import org.springframework.data.util.TypeInformation; +import org.springframework.util.ClassUtils; import com.mongodb.DBRef; @@ -68,11 +81,16 @@ public AotQueryCreator() { } @SuppressWarnings("NullAway") - StringQuery createQuery(PartTree partTree, int parameterCount) { + StringQuery createQuery(PartTree partTree, QueryMethod queryMethod, Method source) { + + boolean geoNear = queryMethod instanceof MongoQueryMethod mqm ? mqm.isGeoNearQuery() : false; + boolean searchQuery = queryMethod instanceof MongoQueryMethod mqm + ? mqm.isSearchQuery() || source.isAnnotationPresent(VectorSearch.class) + : source.isAnnotationPresent(VectorSearch.class); Query query = new MongoQueryCreator(partTree, - new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(parameterCount)), mappingContext) - .createQuery(); + new PlaceholderConvertingParameterAccessor(new PlaceholderParameterAccessor(queryMethod)), mappingContext, + geoNear, searchQuery).createQuery(); if (partTree.isLimiting()) { query.limit(partTree.getMaxResults()); @@ -118,17 +136,35 @@ static class PlaceholderParameterAccessor implements MongoParameterAccessor { private final List placeholders; - public PlaceholderParameterAccessor(int parameterCount) { - if (parameterCount == 0) { + public PlaceholderParameterAccessor(QueryMethod queryMethod) { + if (queryMethod.getParameters().getNumberOfParameters() == 0) { placeholders = List.of(); } else { - placeholders = IntStream.range(0, parameterCount).mapToObj(Placeholder::indexed).collect(Collectors.toList()); + placeholders = new ArrayList<>(); + Parameters parameters = queryMethod.getParameters(); + for (Parameter parameter : parameters.toList()) { + if (ClassUtils.isAssignable(GeoJson.class, parameter.getType())) { + placeholders.add(parameter.getIndex(), new GeoJsonPlaceholder(parameter.getIndex(), "")); + } else if (ClassUtils.isAssignable(Point.class, parameter.getType())) { + placeholders.add(parameter.getIndex(), new PointPlaceholder(parameter.getIndex())); + } else if (ClassUtils.isAssignable(Circle.class, parameter.getType())) { + placeholders.add(parameter.getIndex(), new CirclePlaceholder(parameter.getIndex())); + } else if (ClassUtils.isAssignable(Box.class, parameter.getType())) { + placeholders.add(parameter.getIndex(), new BoxPlaceholder(parameter.getIndex())); + } else if (ClassUtils.isAssignable(Sphere.class, parameter.getType())) { + placeholders.add(parameter.getIndex(), new SpherePlaceholder(parameter.getIndex())); + } else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) { + placeholders.add(parameter.getIndex(), new PolygonPlaceholder(parameter.getIndex())); + } else { + placeholders.add(parameter.getIndex(), Placeholder.indexed(parameter.getIndex())); + } + } } } @Override public Range getDistanceRange() { - return null; + return Range.unbounded(); } @Override @@ -207,4 +243,134 @@ public Iterator iterator() { return ((List) placeholders).iterator(); } } + + static class CirclePlaceholder extends Circle implements Placeholder { + + int index; + + public CirclePlaceholder(int index) { + super(new PointPlaceholder(index), Distance.of(1, Metrics.NEUTRAL)); // + this.index = index; + } + + @Override + public Object getValue() { + return "?%s".formatted(index); + } + + @Override + public String toString() { + return getValue().toString(); + } + } + + static class SpherePlaceholder extends Sphere implements Placeholder { + + int index; + + public SpherePlaceholder(int index) { + super(new PointPlaceholder(index), Distance.of(1, Metrics.NEUTRAL)); // + this.index = index; + } + + @Override + public Object getValue() { + return "?%s".formatted(index); + } + + @Override + public String toString() { + return getValue().toString(); + } + } + + static class GeoJsonPlaceholder implements Placeholder, GeoJson>, Shape { + + int index; + String type; + + public GeoJsonPlaceholder(int index, String type) { + this.index = index; + this.type = type; + } + + @Override + public Object getValue() { + return "?%s".formatted(index); + } + + @Override + public String toString() { + return getValue().toString(); + } + + @Override + public String getType() { + return type; + } + + @Override + public List getCoordinates() { + return List.of(); + } + } + + static class BoxPlaceholder extends Box implements Placeholder { + int index; + + public BoxPlaceholder(int index) { + super(new PointPlaceholder(index), new PointPlaceholder(index)); + this.index = index; + } + + @Override + public Object getValue() { + return "?%s".formatted(index); + } + + @Override + public String toString() { + return getValue().toString(); + } + } + + static class PolygonPlaceholder extends Polygon implements Placeholder { + int index; + + public PolygonPlaceholder(int index) { + super(new PointPlaceholder(index), new PointPlaceholder(index), new PointPlaceholder(index), + new PointPlaceholder(index)); + this.index = index; + } + + @Override + public Object getValue() { + return "?%s".formatted(index); + } + + @Override + public String toString() { + return getValue().toString(); + } + } + + static class PointPlaceholder extends Point implements Placeholder { + + int index; + + public PointPlaceholder(int index) { + super(Double.NaN, Double.NaN); + this.index = index; + } + + @Override + public Object getValue() { + return "?" + index; + } + + @Override + public String toString() { + return getValue().toString(); + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java new file mode 100644 index 0000000000..1d009f3085 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/DeleteBlocks.java @@ -0,0 +1,100 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.Optional; + +import org.jspecify.annotations.NullUnmarked; +import org.springframework.data.mongodb.core.ExecutableRemoveOperation.ExecutableRemove; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.repository.query.MongoQueryExecution.DeleteExecution; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.javapoet.TypeName; +import org.springframework.util.ClassUtils; +import org.springframework.util.ObjectUtils; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class DeleteBlocks { + + @NullUnmarked + static class DeleteExecutionCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String queryVariableName; + + DeleteExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + } + + DeleteExecutionCodeBlockBuilder referencing(String queryVariableName) { + + this.queryVariableName = queryVariableName; + return this; + } + + CodeBlock build() { + + String mongoOpsRef = context.fieldNameOf(MongoOperations.class); + Builder builder = CodeBlock.builder(); + + Class domainType = context.getRepositoryInformation().getDomainType(); + boolean isProjecting = context.getActualReturnType() != null + && !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType()); + + Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType; + + builder.add("\n"); + builder.addStatement("$1T<$2T> $3L = $4L.remove($2T.class)", ExecutableRemove.class, domainType, + context.localVariable("remover"), mongoOpsRef); + + DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL; + if (!queryMethod.isCollectionQuery()) { + if (!ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())) { + type = DeleteExecution.Type.FIND_AND_REMOVE_ONE; + } else { + type = DeleteExecution.Type.ALL; + } + } + + actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType()) + ? TypeName.get(context.getMethod().getReturnType()) + : queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType; + + if (ClassUtils.isVoidType(context.getMethod().getReturnType())) { + builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, context.localVariable("remover"), + DeleteExecution.Type.class, type.name(), queryVariableName); + } else if (context.getMethod().getReturnType() == Optional.class) { + builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class, + actualReturnType, DeleteExecution.class, context.localVariable("remover"), DeleteExecution.Type.class, + type.name(), queryVariableName); + } else { + builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class, + context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName); + } + + return builder.build(); + } + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java new file mode 100644 index 0000000000..b94f55adc2 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/GeoBlocks.java @@ -0,0 +1,145 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoPage; +import org.springframework.data.geo.GeoResults; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.query.NearQuery; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.data.support.PageableExecutionUtils; +import org.springframework.javapoet.CodeBlock; +import org.springframework.util.ClassUtils; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class GeoBlocks { + + static class GeoNearCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + + private String variableName; + + GeoNearCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + } + + CodeBlock build() { + + CodeBlock.Builder builder = CodeBlock.builder(); + builder.add("\n"); + + String locationParameterName = context.getParameterName(queryMethod.getParameters().getNearIndex()); + + builder.addStatement("$1T $2L = $1T.near($3L)", NearQuery.class, variableName, locationParameterName); + + if (queryMethod.getParameters().getRangeIndex() != -1) { + + String rangeParametername = context.getParameterName(queryMethod.getParameters().getRangeIndex()); + String minVarName = context.localVariable("min"); + String maxVarName = context.localVariable("max"); + + builder.beginControlFlow("if($L.getLowerBound().isBounded())", rangeParametername); + builder.addStatement("$1T $2L = $3L.getLowerBound().getValue().get()", Distance.class, minVarName, + rangeParametername); + builder.addStatement("$1L.minDistance($2L).in($2L.getMetric())", variableName, minVarName); + builder.endControlFlow(); + + builder.beginControlFlow("if($L.getUpperBound().isBounded())", rangeParametername); + builder.addStatement("$1T $2L = $3L.getUpperBound().getValue().get()", Distance.class, maxVarName, + rangeParametername); + builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, maxVarName); + builder.endControlFlow(); + } else { + + String distanceParametername = context.getParameterName(queryMethod.getParameters().getMaxDistanceIndex()); + builder.addStatement("$1L.maxDistance($2L).in($2L.getMetric())", variableName, distanceParametername); + } + + if (context.getPageableParameterName() != null) { + builder.addStatement("$L.with($L)", variableName, context.getPageableParameterName()); + } + + MongoCodeBlocks.appendReadPreference(context, builder, variableName); + + return builder.build(); + } + + public GeoNearCodeBlockBuilder usingQueryVariableName(String variableName) { + this.variableName = variableName; + return this; + } + } + + static class GeoNearExecutionCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String queryVariableName; + + GeoNearExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + } + + GeoNearExecutionCodeBlockBuilder referencing(String queryVariableName) { + + this.queryVariableName = queryVariableName; + return this; + } + + CodeBlock build() { + + CodeBlock.Builder builder = CodeBlock.builder(); + builder.add("\n"); + + String executorVar = context.localVariable("nearFinder"); + builder.addStatement("var $L = $L.query($T.class).near($L)", executorVar, + context.fieldNameOf(MongoOperations.class), context.getRepositoryInformation().getDomainType(), + queryVariableName); + + if (ClassUtils.isAssignable(GeoPage.class, context.getReturnType().getRawClass())) { + + String geoResultVar = context.localVariable("geoResult"); + builder.addStatement("var $L = $L.all()", geoResultVar, executorVar); + + builder.beginControlFlow("if($L.isUnpaged())", context.getPageableParameterName()); + builder.addStatement("return new $T<>($L)", GeoPage.class, geoResultVar); + builder.endControlFlow(); + + String pageVar = context.localVariable("resultPage"); + builder.addStatement("var $L = $T.getPage($L.getContent(), $L, () -> $L.count())", pageVar, + PageableExecutionUtils.class, geoResultVar, context.getPageableParameterName(), executorVar); + builder.addStatement("return new $T<>($L, $L, $L.getTotalElements())", GeoPage.class, geoResultVar, + context.getPageableParameterName(), pageVar); + } else if (ClassUtils.isAssignable(GeoResults.class, context.getReturnType().getRawClass())) { + builder.addStatement("return $L.all()", executorVar); + } else { + builder.addStatement("return $L.all().getContent()", executorVar); + } + return builder.build(); + } + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java index 178ce4bda6..86b3217b07 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoAotRepositoryFragmentSupport.java @@ -16,11 +16,20 @@ package org.springframework.data.mongodb.repository.aot; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Locale; import java.util.Map; +import java.util.function.Consumer; import org.bson.Document; import org.jspecify.annotations.Nullable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.expression.ValueEvaluationContext; +import org.springframework.data.expression.ValueExpression; +import org.springframework.data.mapping.model.ValueExpressionEvaluator; import org.springframework.data.mongodb.BindableMongoExpression; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.aggregation.AggregationOperation; @@ -28,10 +37,20 @@ import org.springframework.data.mongodb.core.convert.MongoConverter; import org.springframework.data.mongodb.core.mapping.FieldName; import org.springframework.data.mongodb.core.query.BasicQuery; +import org.springframework.data.mongodb.core.query.Collation; +import org.springframework.data.mongodb.core.query.Criteria; +import org.springframework.data.mongodb.repository.query.MongoParameters; +import org.springframework.data.mongodb.util.json.ParameterBindingContext; +import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; +import org.springframework.data.mongodb.util.json.ValueProvider; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.repository.core.RepositoryMetadata; import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; +import org.springframework.data.repository.query.ValueExpressionDelegate; +import org.springframework.expression.EvaluationContext; +import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import org.springframework.util.CollectionUtils; import org.springframework.util.ObjectUtils; /** @@ -46,31 +65,176 @@ public class MongoAotRepositoryFragmentSupport { private final MongoOperations mongoOperations; private final MongoConverter mongoConverter; private final ProjectionFactory projectionFactory; + private final ValueExpressionDelegate valueExpressionDelegate; protected MongoAotRepositoryFragmentSupport(MongoOperations mongoOperations, RepositoryFactoryBeanSupport.FragmentCreationContext context) { - this(mongoOperations, context.getRepositoryMetadata(), context.getProjectionFactory()); + this(mongoOperations, context.getRepositoryMetadata(), context.getProjectionFactory(), + context.getValueExpressionDelegate()); } protected MongoAotRepositoryFragmentSupport(MongoOperations mongoOperations, RepositoryMetadata repositoryMetadata, - ProjectionFactory projectionFactory) { + ProjectionFactory projectionFactory, ValueExpressionDelegate valueExpressionDelegate) { this.mongoOperations = mongoOperations; this.mongoConverter = mongoOperations.getConverter(); this.repositoryMetadata = repositoryMetadata; this.projectionFactory = projectionFactory; + this.valueExpressionDelegate = valueExpressionDelegate; } protected Document bindParameters(String source, Object[] parameters) { return new BindableMongoExpression(source, this.mongoConverter, parameters).toDocument(); } + protected Document bindParameters(String source, Map parameters) { + + ValueEvaluationContext valueEvaluationContext = this.valueExpressionDelegate.getEvaluationContextAccessor() + .create(new NoMongoParameters()).getEvaluationContext(parameters.values()); + + EvaluationContext evaluationContext = valueEvaluationContext.getEvaluationContext(); + parameters.forEach(evaluationContext::setVariable); + + ParameterBindingContext bindingContext = new ParameterBindingContext(new ValueProvider() { + + private final List args = new ArrayList<>(parameters.values()); + + @Override + public @Nullable Object getBindableValue(int index) { + return args.get(index); + } + }, new ValueExpressionEvaluator() { + + @Override + @SuppressWarnings("unchecked") + public @Nullable T evaluate(String expression) { + ValueExpression parse = valueExpressionDelegate.getValueExpressionParser().parse(expression); + return (T) parse.evaluate(valueEvaluationContext); + } + }); + + return new ParameterBindingDocumentCodec().decode(source, bindingContext); + } + + protected Object[] arguments(Object... arguments) { + return arguments; + } + + protected Map argumentMap(Object... parameters) { + + Assert.state(parameters.length % 2 == 0, "even number of args required"); + + LinkedHashMap argumentMap = CollectionUtils.newLinkedHashMap(parameters.length / 2); + for (int i = 0; i < parameters.length; i += 2) { + + if (!(parameters[i] instanceof String key)) { + throw new IllegalArgumentException("key must be a String"); + } + argumentMap.put(key, parameters[i + 1]); + } + + return argumentMap; + } + + protected @Nullable Object evaluate(String source, Map parameters) { + + ValueEvaluationContext valueEvaluationContext = this.valueExpressionDelegate.getEvaluationContextAccessor() + .create(new NoMongoParameters()).getEvaluationContext(parameters.values()); + + EvaluationContext evaluationContext = valueEvaluationContext.getEvaluationContext(); + parameters.forEach(evaluationContext::setVariable); + + ValueExpression parse = valueExpressionDelegate.getValueExpressionParser().parse(source); + return parse.evaluate(valueEvaluationContext); + } + + protected Consumer scoreBetween(Range.Bound lower, Range.Bound upper) { + + return criteria -> { + if (lower.isBounded()) { + double value = lower.getValue().get().getValue(); + if (lower.isInclusive()) { + criteria.gte(value); + } else { + criteria.gt(value); + } + } + + if (upper.isBounded()) { + + double value = upper.getValue().get().getValue(); + if (upper.isInclusive()) { + criteria.lte(value); + } else { + criteria.lt(value); + } + } + + }; + } + + protected ScoringFunction scoringFunction(Range scoreRange) { + + if (scoreRange != null) { + if (scoreRange.getUpperBound().isBounded()) { + return scoreRange.getUpperBound().getValue().get().getFunction(); + } + + if (scoreRange.getLowerBound().isBounded()) { + return scoreRange.getLowerBound().getValue().get().getFunction(); + } + } + + return ScoringFunction.unspecified(); + } + + // Range scoreRange = accessor.getScoreRange(); + // + // if (scoreRange != null) { + // if (scoreRange.getUpperBound().isBounded()) { + // return scoreRange.getUpperBound().getValue().get().getFunction(); + // } + // + // if (scoreRange.getLowerBound().isBounded()) { + // return scoreRange.getLowerBound().getValue().get().getFunction(); + // } + // } + // + // return ScoringFunction.unspecified(); + + protected Collation collationOf(@Nullable Object source) { + + if (source == null) { + return Collation.simple(); + } + if (source instanceof String) { + return Collation.parse(source.toString()); + } + if (source instanceof Locale locale) { + return Collation.of(locale); + } + if (source instanceof Document document) { + return Collation.from(document); + } + if (source instanceof Collation collation) { + return collation; + } + throw new IllegalArgumentException( + "Unsupported collation source [%s]".formatted(ObjectUtils.nullSafeClassName(source))); + } + protected BasicQuery createQuery(String queryString, Object[] parameters) { Document queryDocument = bindParameters(queryString, parameters); return new BasicQuery(queryDocument); } + protected BasicQuery createQuery(String queryString, Map parameters) { + + Document queryDocument = bindParameters(queryString, parameters); + return new BasicQuery(queryDocument); + } + protected AggregationPipeline createPipeline(List rawStages) { List stages = new ArrayList<>(rawStages.size()); @@ -151,4 +315,10 @@ private static T getPotentiallyConvertedSimpleTypeValue(MongoConverter conve return converter.getConversionService().convert(value, targetType); } + static class NoMongoParameters extends MongoParameters { + + NoMongoParameters() { + super(); + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java index 999391f5ec..4125139bd9 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoCodeBlocks.java @@ -15,48 +15,29 @@ */ package org.springframework.data.mongodb.repository.aot; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; +import java.util.Iterator; +import java.util.Map; +import java.util.Map.Entry; import java.util.regex.Pattern; -import java.util.stream.Stream; import org.bson.Document; -import org.jspecify.annotations.NullUnmarked; import org.jspecify.annotations.Nullable; - import org.springframework.core.annotation.MergedAnnotation; -import org.springframework.data.domain.SliceImpl; -import org.springframework.data.domain.Sort.Order; -import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery; -import org.springframework.data.mongodb.core.ExecutableRemoveOperation.ExecutableRemove; -import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; -import org.springframework.data.mongodb.core.MongoOperations; -import org.springframework.data.mongodb.core.aggregation.Aggregation; -import org.springframework.data.mongodb.core.aggregation.AggregationOptions; -import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; -import org.springframework.data.mongodb.core.aggregation.AggregationResults; -import org.springframework.data.mongodb.core.aggregation.TypedAggregation; -import org.springframework.data.mongodb.core.mapping.MongoSimpleTypes; -import org.springframework.data.mongodb.core.query.BasicQuery; -import org.springframework.data.mongodb.core.query.BasicUpdate; -import org.springframework.data.mongodb.core.query.Collation; -import org.springframework.data.mongodb.repository.Hint; -import org.springframework.data.mongodb.repository.Meta; import org.springframework.data.mongodb.repository.ReadPreference; -import org.springframework.data.mongodb.repository.query.MongoQueryExecution.DeleteExecution; -import org.springframework.data.mongodb.repository.query.MongoQueryExecution.PagedExecution; -import org.springframework.data.mongodb.repository.query.MongoQueryExecution.SlicedExecution; +import org.springframework.data.mongodb.repository.aot.AggregationBlocks.AggregationCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.AggregationBlocks.AggregationExecutionCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.DeleteBlocks.DeleteExecutionCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.GeoBlocks.GeoNearCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.GeoBlocks.GeoNearExecutionCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.QueryBlocks.QueryCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.QueryBlocks.QueryExecutionCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.UpdateBlocks.UpdateCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.UpdateBlocks.UpdateExecutionCodeBlockBuilder; import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; -import org.springframework.data.util.ReflectionUtils; import org.springframework.javapoet.CodeBlock; import org.springframework.javapoet.CodeBlock.Builder; -import org.springframework.javapoet.TypeName; -import org.springframework.util.ClassUtils; -import org.springframework.util.CollectionUtils; import org.springframework.util.NumberUtils; -import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; /** @@ -68,6 +49,8 @@ class MongoCodeBlocks { private static final Pattern PARAMETER_BINDING_PATTERN = Pattern.compile("\\?(\\d+)"); + private static final Pattern EXPRESSION_BINDING_PATTERN = Pattern.compile("[\\?:][#$]\\{.*\\}"); + private static final Pattern VALUE_EXPRESSION_PATTERN = Pattern.compile("^#\\{.*}$"); /** * Builder for generating query parsing {@link CodeBlock}. @@ -78,6 +61,7 @@ class MongoCodeBlocks { */ static QueryCodeBlockBuilder queryBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + return new QueryCodeBlockBuilder(context, queryMethod); } @@ -116,6 +100,7 @@ static DeleteExecutionCodeBlockBuilder deleteExecutionBlockBuilder(AotQueryMetho */ static UpdateCodeBlockBuilder updateBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + return new UpdateCodeBlockBuilder(context, queryMethod); } @@ -158,693 +143,124 @@ static AggregationExecutionCodeBlockBuilder aggregationExecutionBlockBuilder(Aot return new AggregationExecutionCodeBlockBuilder(context, queryMethod); } - @NullUnmarked - static class DeleteExecutionCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private String queryVariableName; - - DeleteExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.queryMethod = queryMethod; - } - - DeleteExecutionCodeBlockBuilder referencing(String queryVariableName) { - - this.queryVariableName = queryVariableName; - return this; - } - - CodeBlock build() { - - String mongoOpsRef = context.fieldNameOf(MongoOperations.class); - Builder builder = CodeBlock.builder(); - - Class domainType = context.getRepositoryInformation().getDomainType(); - boolean isProjecting = context.getActualReturnType() != null - && !ObjectUtils.nullSafeEquals(TypeName.get(domainType), context.getActualReturnType()); - - Object actualReturnType = isProjecting ? context.getActualReturnType().getType() : domainType; - - builder.add("\n"); - builder.addStatement("$T<$T> $L = $L.remove($T.class)", ExecutableRemove.class, domainType, - context.localVariable("remover"), mongoOpsRef, domainType); - - DeleteExecution.Type type = DeleteExecution.Type.FIND_AND_REMOVE_ALL; - if (!queryMethod.isCollectionQuery()) { - if (!ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType())) { - type = DeleteExecution.Type.FIND_AND_REMOVE_ONE; - } else { - type = DeleteExecution.Type.ALL; - } - } - - actualReturnType = ClassUtils.isPrimitiveOrWrapper(context.getMethod().getReturnType()) - ? TypeName.get(context.getMethod().getReturnType()) - : queryMethod.isCollectionQuery() ? context.getReturnTypeName() : actualReturnType; - - if (ClassUtils.isVoidType(context.getMethod().getReturnType())) { - builder.addStatement("new $T($L, $T.$L).execute($L)", DeleteExecution.class, context.localVariable("remover"), - DeleteExecution.Type.class, type.name(), queryVariableName); - } else if (context.getMethod().getReturnType() == Optional.class) { - builder.addStatement("return $T.ofNullable(($T) new $T($L, $T.$L).execute($L))", Optional.class, - actualReturnType, DeleteExecution.class, context.localVariable("remover"), DeleteExecution.Type.class, - type.name(), queryVariableName); - } else { - builder.addStatement("return ($T) new $T($L, $T.$L).execute($L)", actualReturnType, DeleteExecution.class, - context.localVariable("remover"), DeleteExecution.Type.class, type.name(), queryVariableName); - } + /** + * Builder for generating {@link org.springframework.data.mongodb.core.query.NearQuery} {@link CodeBlock}. + * + * @param context + * @param queryMethod + * @return + */ + static GeoNearCodeBlockBuilder geoNearBlockBuilder(AotQueryMethodGenerationContext context, + MongoQueryMethod queryMethod) { - return builder.build(); - } + return new GeoNearCodeBlockBuilder(context, queryMethod); } - @NullUnmarked - static class UpdateExecutionCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private String queryVariableName; - private String updateVariableName; - - UpdateExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.queryMethod = queryMethod; - } - - UpdateExecutionCodeBlockBuilder withFilter(String queryVariableName) { - - this.queryVariableName = queryVariableName; - return this; - } - - UpdateExecutionCodeBlockBuilder referencingUpdate(String updateVariableName) { - - this.updateVariableName = updateVariableName; - return this; - } - - CodeBlock build() { - - String mongoOpsRef = context.fieldNameOf(MongoOperations.class); - Builder builder = CodeBlock.builder(); - - builder.add("\n"); - - String updateReference = updateVariableName; - Class domainType = context.getRepositoryInformation().getDomainType(); - builder.addStatement("$T<$T> $L = $L.update($T.class)", ExecutableUpdate.class, domainType, - context.localVariable("updater"), mongoOpsRef, domainType); - - Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); - if (ReflectionUtils.isVoid(returnType)) { - builder.addStatement("$L.matching($L).apply($L).all()", context.localVariable("updater"), queryVariableName, - updateReference); - } else if (ClassUtils.isAssignable(Long.class, returnType)) { - builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()", - context.localVariable("updater"), queryVariableName, updateReference); - } else { - builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class, - context.localVariable("modifiedCount"), context.localVariable("updater"), queryVariableName, - updateReference); - builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class, - context.localVariable("modifiedCount"), returnType); - } + /** + * Builder for generating {@link org.springframework.data.mongodb.core.query.NearQuery} execution {@link CodeBlock} + * that can return {@link org.springframework.data.geo.GeoResults}. + * + * @param context + * @param queryMethod + * @return + */ + static GeoNearExecutionCodeBlockBuilder geoNearExecutionBlockBuilder(AotQueryMethodGenerationContext context, + MongoQueryMethod queryMethod) { - return builder.build(); - } + return new GeoNearExecutionCodeBlockBuilder(context, queryMethod); } - @NullUnmarked - static class AggregationExecutionCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private String aggregationVariableName; - - AggregationExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.queryMethod = queryMethod; - } - - AggregationExecutionCodeBlockBuilder referencing(String aggregationVariableName) { - - this.aggregationVariableName = aggregationVariableName; - return this; - } - - CodeBlock build() { - - String mongoOpsRef = context.fieldNameOf(MongoOperations.class); - Builder builder = CodeBlock.builder(); - - builder.add("\n"); + static CodeBlock renderExpressionToDocument(@Nullable String source, String variableName, + Map arguments) { - Class outputType = queryMethod.getReturnedObjectType(); - if (MongoSimpleTypes.HOLDER.isSimpleType(outputType)) { - outputType = Document.class; - } else if (ClassUtils.isAssignable(AggregationResults.class, outputType)) { - outputType = queryMethod.getReturnType().getComponentType().getType(); - } - - if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { - builder.addStatement("$L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - return builder.build(); - } - - if (ClassUtils.isAssignable(AggregationResults.class, context.getMethod().getReturnType())) { - builder.addStatement("return $L.aggregate($L, $T.class)", mongoOpsRef, aggregationVariableName, outputType); - return builder.build(); - } - - if (outputType == Document.class) { - - Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); - - if (queryMethod.isStreamQuery()) { - - builder.addStatement("$T<$T> $L = $L.aggregateStream($L, $T.class)", Stream.class, Document.class, - context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); - - builder.addStatement("return $L.map(it -> ($T) convertSimpleRawResult($T.class, it))", - context.localVariable("results"), returnType, returnType); - } else { - - builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, - context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); - - if (!queryMethod.isCollectionQuery()) { - builder.addStatement("return $T.<$T>firstElement(convertSimpleRawResults($T.class, $L.getMappedResults()))", - CollectionUtils.class, returnType, returnType, context.localVariable("results")); - } else { - builder.addStatement("return convertSimpleRawResults($T.class, $L.getMappedResults())", returnType, - context.localVariable("results")); - } - } + Builder builder = CodeBlock.builder(); + if (!StringUtils.hasText(source)) { + builder.addStatement("$1T $2L = new $1T()", Document.class, variableName); + } else if (!containsPlaceholder(source)) { + builder.addStatement("$1T $2L = $1T.parse($3S)", Document.class, variableName, source); + } else { + builder.add("$T $L = bindParameters($S, ", Document.class, variableName, source); + if (containsNamedPlaceholder(source)) { + builder.add(renderArgumentMap(arguments)); } else { - if (queryMethod.isSliceQuery()) { - builder.addStatement("$T $L = $L.aggregate($L, $T.class)", AggregationResults.class, - context.localVariable("results"), mongoOpsRef, aggregationVariableName, outputType); - builder.addStatement("boolean $L = $L.getMappedResults().size() > $L.getPageSize()", - context.localVariable("hasNext"), context.localVariable("results"), context.getPageableParameterName()); - builder.addStatement( - "return new $T<>($L ? $L.getMappedResults().subList(0, $L.getPageSize()) : $L.getMappedResults(), $L, $L)", - SliceImpl.class, context.localVariable("hasNext"), context.localVariable("results"), - context.getPageableParameterName(), context.localVariable("results"), context.getPageableParameterName(), - context.localVariable("hasNext")); - } else { - - if (queryMethod.isStreamQuery()) { - builder.addStatement("return $L.aggregateStream($L, $T.class)", mongoOpsRef, aggregationVariableName, - outputType); - } else { - - builder.addStatement("return $L.aggregate($L, $T.class).getMappedResults()", mongoOpsRef, - aggregationVariableName, outputType); - } - } + builder.add(renderArgumentArray(arguments)); } - - return builder.build(); + builder.add(");\n"); } + return builder.build(); } - @NullUnmarked - static class QueryExecutionCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - private QueryInteraction query; - - QueryExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.queryMethod = queryMethod; - } - - QueryExecutionCodeBlockBuilder forQuery(QueryInteraction query) { - - this.query = query; - return this; - } - - CodeBlock build() { - - String mongoOpsRef = context.fieldNameOf(MongoOperations.class); - - Builder builder = CodeBlock.builder(); - - boolean isProjecting = context.getReturnedType().isProjecting(); - Class domainType = context.getRepositoryInformation().getDomainType(); - Object actualReturnType = queryMethod.getParameters().hasDynamicProjection() || isProjecting - ? TypeName.get(context.getActualReturnType().getType()) - : domainType; - - builder.add("\n"); - - if (queryMethod.getParameters().hasDynamicProjection()) { - builder.addStatement("$T<$T> $L = $L.query($T.class).as($L)", FindWithQuery.class, actualReturnType, - context.localVariable("finder"), mongoOpsRef, domainType, context.getDynamicProjectionParameterName()); - } else if (isProjecting) { - builder.addStatement("$T<$T> $L = $L.query($T.class).as($T.class)", FindWithQuery.class, actualReturnType, - context.localVariable("finder"), mongoOpsRef, domainType, actualReturnType); - } else { - - builder.addStatement("$T<$T> $L = $L.query($T.class)", FindWithQuery.class, actualReturnType, - context.localVariable("finder"), mongoOpsRef, domainType); - } - - String terminatingMethod; + static CodeBlock renderArgumentMap(Map arguments) { - if (queryMethod.isCollectionQuery() || queryMethod.isPageQuery() || queryMethod.isSliceQuery()) { - terminatingMethod = "all()"; - } else if (query.isCount()) { - terminatingMethod = "count()"; - } else if (query.isExists()) { - terminatingMethod = "exists()"; - } else if (queryMethod.isStreamQuery()) { - terminatingMethod = "stream()"; - } else { - terminatingMethod = Optional.class.isAssignableFrom(context.getReturnType().toClass()) ? "one()" : "oneValue()"; - } - - if (queryMethod.isPageQuery()) { - builder.addStatement("return new $T($L, $L).execute($L)", PagedExecution.class, context.localVariable("finder"), - context.getPageableParameterName(), query.name()); - } else if (queryMethod.isSliceQuery()) { - builder.addStatement("return new $T($L, $L).execute($L)", SlicedExecution.class, - context.localVariable("finder"), context.getPageableParameterName(), query.name()); - } else if (queryMethod.isScrollQuery()) { - - String scrollPositionParameterName = context.getScrollPositionParameterName(); - - builder.addStatement("return $L.matching($L).scroll($L)", context.localVariable("finder"), query.name(), - scrollPositionParameterName); - } else { - if (query.isCount() && !ClassUtils.isAssignable(Long.class, context.getActualReturnType().getRawClass())) { - - Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); - builder.addStatement("return $T.convertNumberToTargetClass($L.matching($L).$L, $T.class)", NumberUtils.class, - context.localVariable("finder"), query.name(), terminatingMethod, returnType); - - } else { - builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(), - terminatingMethod); - } + Builder builder = CodeBlock.builder(); + builder.add("argumentMap("); + Iterator> iterator = arguments.entrySet().iterator(); + while (iterator.hasNext()) { + Entry next = iterator.next(); + builder.add("$S, ", next.getKey()); + builder.add(next.getValue()); + if (iterator.hasNext()) { + builder.add(", "); } - - return builder.build(); } + builder.add(")"); + return builder.build(); } - @NullUnmarked - static class AggregationCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - - private AggregationInteraction source; - private final List arguments; - private String aggregationVariableName; - private boolean pipelineOnly; - - AggregationCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.arguments = context.getBindableParameterNames(); - this.queryMethod = queryMethod; - } - - AggregationCodeBlockBuilder stages(AggregationInteraction aggregation) { - - this.source = aggregation; - return this; - } - - AggregationCodeBlockBuilder usingAggregationVariableName(String aggregationVariableName) { - - this.aggregationVariableName = aggregationVariableName; - return this; - } - - AggregationCodeBlockBuilder pipelineOnly(boolean pipelineOnly) { - - this.pipelineOnly = pipelineOnly; - return this; - } - - CodeBlock build() { - - CodeBlock.Builder builder = CodeBlock.builder(); - builder.add("\n"); - - String pipelineName = context.localVariable(aggregationVariableName + (pipelineOnly ? "" : "Pipeline")); - builder.add(pipeline(pipelineName)); - - if (!pipelineOnly) { - - builder.addStatement("$T<$T> $L = $T.newAggregation($T.class, $L.getOperations())", TypedAggregation.class, - context.getRepositoryInformation().getDomainType(), aggregationVariableName, Aggregation.class, - context.getRepositoryInformation().getDomainType(), pipelineName); + static CodeBlock renderArgumentArray(Map arguments) { - builder.add(aggregationOptions(aggregationVariableName)); - } - - return builder.build(); - } - - private CodeBlock pipeline(String pipelineVariableName) { - - String sortParameter = context.getSortParameterName(); - String limitParameter = context.getLimitParameterName(); - String pageableParameter = context.getPageableParameterName(); - - boolean mightBeSorted = StringUtils.hasText(sortParameter); - boolean mightBeLimited = StringUtils.hasText(limitParameter); - boolean mightBePaged = StringUtils.hasText(pageableParameter); - - int stageCount = source.stages().size(); - if (mightBeSorted) { - stageCount++; - } - if (mightBeLimited) { - stageCount++; - } - if (mightBePaged) { - stageCount += 3; - } - - Builder builder = CodeBlock.builder(); - builder.add(aggregationStages(context.localVariable("stages"), source.stages(), stageCount, arguments)); - - if (mightBeSorted) { - builder.add(sortingStage(sortParameter)); - } - - if (mightBeLimited) { - builder.add(limitingStage(limitParameter)); - } - - if (mightBePaged) { - builder.add(pagingStage(pageableParameter, queryMethod.isSliceQuery())); - } - - builder.addStatement("$T $L = createPipeline($L)", AggregationPipeline.class, pipelineVariableName, - context.localVariable("stages")); - return builder.build(); - } - - private CodeBlock aggregationOptions(String aggregationVariableName) { - - Builder builder = CodeBlock.builder(); - List options = new ArrayList<>(5); - if (ReflectionUtils.isVoid(queryMethod.getReturnedObjectType())) { - options.add(CodeBlock.of(".skipOutput()")); - } - - MergedAnnotation hintAnnotation = context.getAnnotation(Hint.class); - String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null; - if (StringUtils.hasText(hint)) { - options.add(CodeBlock.of(".hint($S)", hint)); - } - - MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); - String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; - if (StringUtils.hasText(readPreference)) { - options.add(CodeBlock.of(".readPreference($T.valueOf($S))", com.mongodb.ReadPreference.class, readPreference)); - } - - if (queryMethod.hasAnnotatedCollation()) { - options.add(CodeBlock.of(".collation($T.parse($S))", Collation.class, queryMethod.getAnnotatedCollation())); - } - - if (!options.isEmpty()) { - - Builder optionsBuilder = CodeBlock.builder(); - optionsBuilder.add("$T $L = $T.builder()\n", AggregationOptions.class, - context.localVariable("aggregationOptions"), AggregationOptions.class); - optionsBuilder.indent(); - for (CodeBlock optionBlock : options) { - optionsBuilder.add(optionBlock); - optionsBuilder.add("\n"); - } - optionsBuilder.add(".build();\n"); - optionsBuilder.unindent(); - builder.add(optionsBuilder.build()); - - builder.addStatement("$L = $L.withOptions($L)", aggregationVariableName, aggregationVariableName, - context.localVariable("aggregationOptions")); - } - return builder.build(); - } - - private CodeBlock aggregationStages(String stageListVariableName, Iterable stages, int stageCount, - List arguments) { - - Builder builder = CodeBlock.builder(); - builder.addStatement("$T<$T> $L = new $T($L)", List.class, Object.class, stageListVariableName, ArrayList.class, - stageCount); - int stageCounter = 0; - - for (String stage : stages) { - String stageName = context.localVariable("stage_%s".formatted(stageCounter++)); - builder.add(renderExpressionToDocument(stage, stageName, arguments)); - builder.addStatement("$L.add($L)", context.localVariable("stages"), stageName); - } - - return builder.build(); - } - - private CodeBlock sortingStage(String sortProvider) { - - Builder builder = CodeBlock.builder(); - - builder.beginControlFlow("if ($L.isSorted())", sortProvider); - builder.addStatement("$T $L = new $T()", Document.class, context.localVariable("sortDocument"), Document.class); - builder.beginControlFlow("for ($T $L : $L)", Order.class, context.localVariable("order"), sortProvider); - builder.addStatement("$L.append($L.getProperty(), $L.isAscending() ? 1 : -1);", - context.localVariable("sortDocument"), context.localVariable("order"), context.localVariable("order")); - builder.endControlFlow(); - builder.addStatement("stages.add(new $T($S, $L))", Document.class, "$sort", - context.localVariable("sortDocument")); - builder.endControlFlow(); - - return builder.build(); - } - - private CodeBlock pagingStage(String pageableProvider, boolean slice) { - - Builder builder = CodeBlock.builder(); - - builder.add(sortingStage(pageableProvider + ".getSort()")); - - builder.beginControlFlow("if ($L.isPaged())", pageableProvider); - builder.beginControlFlow("if ($L.getOffset() > 0)", pageableProvider); - builder.addStatement("$L.add($T.skip($L.getOffset()))", context.localVariable("stages"), Aggregation.class, - pageableProvider); - builder.endControlFlow(); - if (slice) { - builder.addStatement("$L.add($T.limit($L.getPageSize() + 1))", context.localVariable("stages"), - Aggregation.class, pageableProvider); - } else { - builder.addStatement("$L.add($T.limit($L.getPageSize()))", context.localVariable("stages"), Aggregation.class, - pageableProvider); + Builder builder = CodeBlock.builder(); + builder.add("arguments("); + Iterator iterator = arguments.values().iterator(); + while (iterator.hasNext()) { + builder.add(iterator.next()); + if (iterator.hasNext()) { + builder.add(", "); } - builder.endControlFlow(); - - return builder.build(); } - - private CodeBlock limitingStage(String limitProvider) { - - Builder builder = CodeBlock.builder(); - - builder.beginControlFlow("if ($L.isLimited())", limitProvider); - builder.addStatement("$L.add($T.limit($L.max()))", context.localVariable("stages"), Aggregation.class, - limitProvider); - builder.endControlFlow(); - - return builder.build(); - } - + builder.add(")"); + return builder.build(); } - @NullUnmarked - static class QueryCodeBlockBuilder { - - private final AotQueryMethodGenerationContext context; - private final MongoQueryMethod queryMethod; - - private QueryInteraction source; - private final List arguments; - private String queryVariableName; - - QueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - - this.context = context; - this.arguments = context.getBindableParameterNames(); - this.queryMethod = queryMethod; - } - - QueryCodeBlockBuilder filter(QueryInteraction query) { - - this.source = query; - return this; - } - - QueryCodeBlockBuilder usingQueryVariableName(String queryVariableName) { - this.queryVariableName = queryVariableName; - return this; - } - - CodeBlock build() { - - CodeBlock.Builder builder = CodeBlock.builder(); - - builder.add("\n"); - builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName)); - - if (StringUtils.hasText(source.getQuery().getFieldsString())) { - - builder.add(renderExpressionToDocument(source.getQuery().getFieldsString(), "fields", arguments)); - builder.addStatement("$L.setFieldsObject(fields)", queryVariableName); - } - - String sortParameter = context.getSortParameterName(); - if (StringUtils.hasText(sortParameter)) { - builder.addStatement("$L.with($L)", queryVariableName, sortParameter); - } else if (StringUtils.hasText(source.getQuery().getSortString())) { - - builder.add(renderExpressionToDocument(source.getQuery().getSortString(), "sort", arguments)); - builder.addStatement("$L.setSortObject(sort)", queryVariableName); - } - - String limitParameter = context.getLimitParameterName(); - if (StringUtils.hasText(limitParameter)) { - builder.addStatement("$L.limit($L)", queryVariableName, limitParameter); - } else if (context.getPageableParameterName() == null && source.getQuery().isLimited()) { - builder.addStatement("$L.limit($L)", queryVariableName, source.getQuery().getLimit()); - } - - String pageableParameter = context.getPageableParameterName(); - if (StringUtils.hasText(pageableParameter) && !queryMethod.isPageQuery() && !queryMethod.isSliceQuery()) { - builder.addStatement("$L.with($L)", queryVariableName, pageableParameter); - } - - MergedAnnotation hintAnnotation = context.getAnnotation(Hint.class); - String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null; - - if (StringUtils.hasText(hint)) { - builder.addStatement("$L.withHint($S)", queryVariableName, hint); - } - - MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); - String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; - - if (StringUtils.hasText(readPreference)) { - builder.addStatement("$L.withReadPreference($T.valueOf($S))", queryVariableName, - com.mongodb.ReadPreference.class, readPreference); - } - - MergedAnnotation metaAnnotation = context.getAnnotation(Meta.class); - - if (metaAnnotation.isPresent()) { - - long maxExecutionTimeMs = metaAnnotation.getLong("maxExecutionTimeMs"); - if (maxExecutionTimeMs != -1) { - builder.addStatement("$L.maxTimeMsec($L)", queryVariableName, maxExecutionTimeMs); - } - - int cursorBatchSize = metaAnnotation.getInt("cursorBatchSize"); - if (cursorBatchSize != 0) { - builder.addStatement("$L.cursorBatchSize($L)", queryVariableName, cursorBatchSize); - } - - String comment = metaAnnotation.getString("comment"); - if (StringUtils.hasText("comment")) { - builder.addStatement("$L.comment($S)", queryVariableName, comment); - } - } - - // TODO: Meta annotation: Disk usage - - return builder.build(); - } - - private CodeBlock renderExpressionToQuery(@Nullable String source, String variableName) { + static CodeBlock evaluateNumberPotentially(String value, Class targetType, + Map arguments) { + try { + Number number = NumberUtils.parseNumber(value, targetType); + return CodeBlock.of("$L", number); + } catch (IllegalArgumentException e) { Builder builder = CodeBlock.builder(); - if (!StringUtils.hasText(source)) { - - builder.addStatement("$T $L = new $T(new $T())", BasicQuery.class, variableName, BasicQuery.class, - Document.class); - } else if (!containsPlaceholder(source)) { - builder.addStatement("$T $L = new $T($T.parse($S))", BasicQuery.class, variableName, BasicQuery.class, - Document.class, source); - } else { - builder.addStatement("$T $L = createQuery($S, new $T[]{ $L })", BasicQuery.class, variableName, source, - Object.class, StringUtils.collectionToDelimitedString(arguments, ", ")); - } - + builder.add("($T) evaluate($S, ", targetType, value); + builder.add(MongoCodeBlocks.renderArgumentMap(arguments)); + builder.add(")"); return builder.build(); } } - @NullUnmarked - static class UpdateCodeBlockBuilder { - - private UpdateInteraction source; - private List arguments; - private String updateVariableName; - - public UpdateCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { - this.arguments = context.getBindableParameterNames(); - } - - public UpdateCodeBlockBuilder update(UpdateInteraction update) { - this.source = update; - return this; - } - - public UpdateCodeBlockBuilder usingUpdateVariableName(String updateVariableName) { - this.updateVariableName = updateVariableName; - return this; - } - - CodeBlock build() { + static boolean containsPlaceholder(String source) { + return containsIndexedPlaceholder(source) || containsNamedPlaceholder(source); + } - CodeBlock.Builder builder = CodeBlock.builder(); + static boolean containsExpression(String source) { + return VALUE_EXPRESSION_PATTERN.matcher(source).find(); + } - builder.add("\n"); - String tmpVariableName = updateVariableName + "Document"; - builder.add(renderExpressionToDocument(source.getUpdate().getUpdateString(), tmpVariableName, arguments)); - builder.addStatement("$T $L = new $T($L)", BasicUpdate.class, updateVariableName, BasicUpdate.class, - tmpVariableName); + static boolean containsNamedPlaceholder(String source) { + return EXPRESSION_BINDING_PATTERN.matcher(source).find(); + } - return builder.build(); - } + static boolean containsIndexedPlaceholder(String source) { + return PARAMETER_BINDING_PATTERN.matcher(source).find(); } - private static CodeBlock renderExpressionToDocument(@Nullable String source, String variableName, - List arguments) { + static void appendReadPreference(AotQueryMethodGenerationContext context, Builder builder, String queryVariableName) { - Builder builder = CodeBlock.builder(); - if (!StringUtils.hasText(source)) { - builder.addStatement("$T $L = new $T()", Document.class, variableName, Document.class); - } else if (!containsPlaceholder(source)) { - builder.addStatement("$T $L = $T.parse($S)", Document.class, variableName, Document.class, source); - } else { - builder.addStatement("$T $L = bindParameters($S, new $T[]{ $L })", Document.class, variableName, source, - Object.class, StringUtils.collectionToDelimitedString(arguments, ", ")); - } - return builder.build(); - } + MergedAnnotation readPreferenceAnnotation = context.getAnnotation(ReadPreference.class); + String readPreference = readPreferenceAnnotation.isPresent() ? readPreferenceAnnotation.getString("value") : null; - private static boolean containsPlaceholder(String source) { - return PARAMETER_BINDING_PATTERN.matcher(source).find(); + if (StringUtils.hasText(readPreference)) { + builder.addStatement("$L.withReadPreference($T.valueOf($S))", queryVariableName, com.mongodb.ReadPreference.class, + readPreference); + } } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index 424d067d74..524c5e8f23 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -15,11 +15,18 @@ */ package org.springframework.data.mongodb.repository.aot; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationExecutionBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.deleteExecutionBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.geoNearBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.geoNearExecutionBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryExecutionBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateExecutionBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.QueryBlocks.QueryCodeBlockBuilder; import java.lang.reflect.Method; -import java.util.Locale; -import java.util.regex.Pattern; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,6 +37,7 @@ import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.Update; +import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.aot.generate.AotRepositoryClassBuilder; import org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder; @@ -90,28 +98,26 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB MongoQueryMethod queryMethod = new MongoQueryMethod(method, getRepositoryInformation(), getProjectionFactory(), mappingContext); + if (backoff(queryMethod)) { + return null; + } + if (queryMethod.hasAnnotatedAggregation()) { AggregationInteraction aggregation = new AggregationInteraction(queryMethod.getAnnotatedAggregation()); return aggregationMethodContributor(queryMethod, aggregation); } QueryInteraction query = createStringQuery(getRepositoryInformation(), queryMethod, - AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method.getParameterCount()); + AnnotatedElementUtils.findMergedAnnotation(method, Query.class), method); - if (queryMethod.hasAnnotatedQuery()) { - if (StringUtils.hasText(queryMethod.getAnnotatedQuery()) - && Pattern.compile("[\\?:][#$]\\{.*\\}").matcher(queryMethod.getAnnotatedQuery()).find()) { - - if (logger.isDebugEnabled()) { - logger.debug( - "Skipping AOT generation for [%s]. SpEL expressions are not supported".formatted(method.getName())); - } - return MethodContributor.forQueryMethod(queryMethod).metadataOnly(query); - } + if (queryMethod.isSearchQuery() || method.isAnnotationPresent(VectorSearch.class)) { + return searchMethodContributor(queryMethod, new SearchInteraction(query.getQuery())); } - if (backoff(queryMethod)) { - return null; + if (queryMethod.isGeoNearQuery() || (queryMethod.getParameters().getMaxDistanceIndex() != -1 + && queryMethod.getReturnType().isCollectionLike())) { + NearQueryInteraction near = new NearQueryInteraction(query, queryMethod.getParameters()); + return nearQueryMethodContributor(queryMethod, near); } if (query.isDelete()) { @@ -125,8 +131,8 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB UpdateInteraction update = new UpdateInteraction(query, null, updateIndex); return updateMethodContributor(queryMethod, update); - } else { + Update updateSource = queryMethod.getUpdateSource(); if (StringUtils.hasText(updateSource.value())) { UpdateInteraction update = new UpdateInteraction(query, new StringUpdate(updateSource.value()), null); @@ -145,7 +151,7 @@ protected void customizeConstructor(AotRepositoryConstructorBuilder constructorB @SuppressWarnings("NullAway") private QueryInteraction createStringQuery(RepositoryInformation repositoryInformation, MongoQueryMethod queryMethod, - @Nullable Query queryAnnotation, int parameterCount) { + @Nullable Query queryAnnotation, Method source) { QueryInteraction query; if (queryMethod.hasAnnotatedQuery() && queryAnnotation != null) { @@ -154,8 +160,8 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor } else { PartTree partTree = new PartTree(queryMethod.getName(), repositoryInformation.getDomainType()); - query = new QueryInteraction(queryCreator.createQuery(partTree, parameterCount), partTree.isCountProjection(), - partTree.isDelete(), partTree.isExistsProjection()); + query = new QueryInteraction(queryCreator.createQuery(partTree, queryMethod, source), + partTree.isCountProjection(), partTree.isDelete(), partTree.isExistsProjection()); } if (queryAnnotation != null && StringUtils.hasText(queryAnnotation.sort())) { @@ -171,8 +177,7 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor private static boolean backoff(MongoQueryMethod method) { // TODO: namedQuery, Regex queries, queries accepting Shapes (e.g. within) or returning arrays. - boolean skip = method.isGeoNearQuery() || method.isSearchQuery() - || method.getName().toLowerCase(Locale.ROOT).contains("regex") || method.getReturnType().getType().isArray(); + boolean skip = method.getReturnType().getType().isArray(); if (skip && logger.isDebugEnabled()) { logger.debug("Skipping AOT generation for [%s]. Method is either returning an array or a geo-near, regex query" @@ -181,22 +186,61 @@ private static boolean backoff(MongoQueryMethod method) { return skip; } - private static MethodContributor aggregationMethodContributor(MongoQueryMethod queryMethod, + private static MethodContributor nearQueryMethodContributor(MongoQueryMethod queryMethod, + NearQueryInteraction interaction) { + + return MethodContributor.forQueryMethod(queryMethod).withMetadata(interaction).contribute(context -> { + + CodeBlock.Builder builder = CodeBlock.builder(); + + String variableName = context.localVariable("nearQuery"); + builder.add(geoNearBlockBuilder(context, queryMethod).usingQueryVariableName(variableName).build()); + + if (!context.getBindableParameterNames().isEmpty()) { + String filterQueryVariableName = context.localVariable("filterQuery"); + builder.add(queryBlockBuilder(context, queryMethod).usingQueryVariableName(filterQueryVariableName) + .filter(interaction.getQuery()).build()); + builder.addStatement("$L.query($L)", variableName, filterQueryVariableName); + } + + builder.add(geoNearExecutionBlockBuilder(context, queryMethod).referencing(variableName).build()); + + return builder.build(); + }); + } + + static MethodContributor aggregationMethodContributor(MongoQueryMethod queryMethod, AggregationInteraction aggregation) { return MethodContributor.forQueryMethod(queryMethod).withMetadata(aggregation).contribute(context -> { CodeBlock.Builder builder = CodeBlock.builder(); + String variableName = "aggregation"; builder.add(aggregationBlockBuilder(context, queryMethod).stages(aggregation) - .usingAggregationVariableName("aggregation").build()); - builder.add(aggregationExecutionBlockBuilder(context, queryMethod).referencing("aggregation").build()); + .usingAggregationVariableName(variableName).build()); + builder.add(aggregationExecutionBlockBuilder(context, queryMethod).referencing(variableName).build()); + + return builder.build(); + }); + } + + static MethodContributor searchMethodContributor(MongoQueryMethod queryMethod, + SearchInteraction interaction) { + return MethodContributor.forQueryMethod(queryMethod).withMetadata(interaction).contribute(context -> { + + CodeBlock.Builder builder = CodeBlock.builder(); + + String variableName = "search"; + + builder.add(new VectorSearchBocks.VectorSearchQueryCodeBlockBuilder(context, queryMethod) + .usingVariableName(variableName).withFilter(interaction.getFilter()).build()); return builder.build(); }); } - private static MethodContributor updateMethodContributor(MongoQueryMethod queryMethod, + static MethodContributor updateMethodContributor(MongoQueryMethod queryMethod, UpdateInteraction update) { return MethodContributor.forQueryMethod(queryMethod).withMetadata(update).contribute(context -> { @@ -225,7 +269,7 @@ private static MethodContributor updateMethodContributor(Mongo }); } - private static MethodContributor aggregationUpdateMethodContributor(MongoQueryMethod queryMethod, + static MethodContributor aggregationUpdateMethodContributor(MongoQueryMethod queryMethod, AggregationUpdateInteraction update) { return MethodContributor.forQueryMethod(queryMethod).withMetadata(update).contribute(context -> { @@ -251,7 +295,7 @@ private static MethodContributor aggregationUpdateMethodContri }); } - private static MethodContributor deleteMethodContributor(MongoQueryMethod queryMethod, + static MethodContributor deleteMethodContributor(MongoQueryMethod queryMethod, QueryInteraction query) { return MethodContributor.forQueryMethod(queryMethod).withMetadata(query).contribute(context -> { @@ -266,7 +310,7 @@ private static MethodContributor deleteMethodContributor(Mongo }); } - private static MethodContributor queryMethodContributor(MongoQueryMethod queryMethod, + static MethodContributor queryMethodContributor(MongoQueryMethod queryMethod, QueryInteraction query) { return MethodContributor.forQueryMethod(queryMethod).withMetadata(query).contribute(context -> { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java new file mode 100644 index 0000000000..2005626784 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/NearQueryInteraction.java @@ -0,0 +1,68 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.LinkedHashMap; +import java.util.Map; + +import org.springframework.data.mongodb.repository.query.MongoParameters; +import org.springframework.data.repository.aot.generate.QueryMetadata; + +/** + * An {@link MongoInteraction} to execute a query. + * + * @author Christoph Strobl + * @since 5.0 + */ +class NearQueryInteraction extends MongoInteraction implements QueryMetadata { + + private final InteractionType interactionType; + private final QueryInteraction query; + private final MongoParameters parameters; + + NearQueryInteraction(QueryInteraction query, MongoParameters parameters) { + interactionType = InteractionType.QUERY; + this.query = query; + this.parameters = parameters; + } + + @Override + InteractionType getExecutionType() { + return interactionType; + } + + public QueryInteraction getQuery() { + return query; + } + + @Override + public Map serialize() { + + Map serialized = new LinkedHashMap<>(); + serialized.put("near", "?%s".formatted(parameters.getNearIndex())); + if (parameters.getRangeIndex() != -1) { + serialized.put("minDistance", "?%s".formatted(parameters.getRangeIndex())); + serialized.put("maxDistance", "?%s".formatted(parameters.getRangeIndex())); + } else if (parameters.getMaxDistanceIndex() != -1) { + serialized.put("minDistance", "?%s".formatted(parameters.getMaxDistanceIndex())); + } + Object filter = query.serialize().get("filter"); // TODO: filter position index can be off due to bindable params + if (filter != null) { + serialized.put("filter", filter); + } + return serialized; + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java new file mode 100644 index 0000000000..7ad0c25b16 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/QueryBlocks.java @@ -0,0 +1,321 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import org.bson.Document; +import org.jspecify.annotations.NullUnmarked; +import org.jspecify.annotations.Nullable; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.data.geo.Box; +import org.springframework.data.geo.Circle; +import org.springframework.data.geo.Polygon; +import org.springframework.data.mongodb.core.ExecutableFindOperation.FindWithQuery; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.annotation.Collation; +import org.springframework.data.mongodb.core.geo.GeoJson; +import org.springframework.data.mongodb.core.geo.Sphere; +import org.springframework.data.mongodb.core.query.BasicQuery; +import org.springframework.data.mongodb.repository.Hint; +import org.springframework.data.mongodb.repository.Meta; +import org.springframework.data.mongodb.repository.query.MongoParameters.MongoParameter; +import org.springframework.data.mongodb.repository.query.MongoQueryExecution.PagedExecution; +import org.springframework.data.mongodb.repository.query.MongoQueryExecution.SlicedExecution; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.javapoet.TypeName; +import org.springframework.util.ClassUtils; +import org.springframework.util.NumberUtils; +import org.springframework.util.StringUtils; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class QueryBlocks { + + @NullUnmarked + static class QueryExecutionCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private QueryInteraction query; + + QueryExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + } + + QueryExecutionCodeBlockBuilder forQuery(QueryInteraction query) { + + this.query = query; + return this; + } + + CodeBlock build() { + + String mongoOpsRef = context.fieldNameOf(MongoOperations.class); + + Builder builder = CodeBlock.builder(); + + boolean isProjecting = context.getReturnedType().isProjecting(); + Class domainType = context.getRepositoryInformation().getDomainType(); + Object actualReturnType = queryMethod.getParameters().hasDynamicProjection() || isProjecting + ? TypeName.get(context.getActualReturnType().getType()) + : domainType; + + builder.add("\n"); + + if (queryMethod.getParameters().hasDynamicProjection()) { + builder.addStatement("$T<$T> $L = $L.query($T.class).as($L)", FindWithQuery.class, actualReturnType, + context.localVariable("finder"), mongoOpsRef, domainType, context.getDynamicProjectionParameterName()); + } else if (isProjecting) { + builder.addStatement("$T<$T> $L = $L.query($T.class).as($T.class)", FindWithQuery.class, actualReturnType, + context.localVariable("finder"), mongoOpsRef, domainType, actualReturnType); + } else { + + builder.addStatement("$T<$T> $L = $L.query($T.class)", FindWithQuery.class, actualReturnType, + context.localVariable("finder"), mongoOpsRef, domainType); + } + + String terminatingMethod; + + if (queryMethod.isCollectionQuery() || queryMethod.isPageQuery() || queryMethod.isSliceQuery()) { + terminatingMethod = "all()"; + } else if (query.isCount()) { + terminatingMethod = "count()"; + } else if (query.isExists()) { + terminatingMethod = "exists()"; + } else if (queryMethod.isStreamQuery()) { + terminatingMethod = "stream()"; + } else { + terminatingMethod = Optional.class.isAssignableFrom(context.getReturnType().toClass()) ? "one()" : "oneValue()"; + } + + if (queryMethod.isPageQuery()) { + builder.addStatement("return new $T($L, $L).execute($L)", PagedExecution.class, context.localVariable("finder"), + context.getPageableParameterName(), query.name()); + } else if (queryMethod.isSliceQuery()) { + builder.addStatement("return new $T($L, $L).execute($L)", SlicedExecution.class, + context.localVariable("finder"), context.getPageableParameterName(), query.name()); + } else if (queryMethod.isScrollQuery()) { + + String scrollPositionParameterName = context.getScrollPositionParameterName(); + + builder.addStatement("return $L.matching($L).scroll($L)", context.localVariable("finder"), query.name(), + scrollPositionParameterName); + } else { + if (query.isCount() && !ClassUtils.isAssignable(Long.class, context.getActualReturnType().getRawClass())) { + + Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); + builder.addStatement("return $T.convertNumberToTargetClass($L.matching($L).$L, $T.class)", NumberUtils.class, + context.localVariable("finder"), query.name(), terminatingMethod, returnType); + + } else { + builder.addStatement("return $L.matching($L).$L", context.localVariable("finder"), query.name(), + terminatingMethod); + } + } + + return builder.build(); + } + } + + @NullUnmarked + static class QueryCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + + private QueryInteraction source; + private final Map arguments; + private String queryVariableName; + + QueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + + this.arguments = new LinkedHashMap<>(); + this.queryMethod = queryMethod; + collectArguments(context); + + } + + private void collectArguments(AotQueryMethodGenerationContext context) { + + for (MongoParameter parameter : queryMethod.getParameters().getBindableParameters()) { + String parameterName = context.getParameterName(parameter.getIndex()); + if (ClassUtils.isAssignable(GeoJson.class, parameter.getType())) { + + // renders as generic $geometry, thus can be handled by the converter when parsing + arguments.put(parameterName, CodeBlock.of(parameterName)); + } else if (ClassUtils.isAssignable(Circle.class, parameter.getType()) + || ClassUtils.isAssignable(Sphere.class, parameter.getType())) { + + // $center | $centerSphere : [ [ , ], ] + arguments.put(parameterName, CodeBlock.builder().add( + "$1T.of($1T.of($2L.getCenter().getX(), $2L.getCenter().getY()), $2L.getRadius().getNormalizedValue())", + List.class, parameterName).build()); + } else if (ClassUtils.isAssignable(Box.class, parameter.getType())) { + + // $box: [ [ , ], [ , ] ] + arguments.put(parameterName, CodeBlock.builder().add( + "$1T.of($1T.of($2L.getFirst().getX(), $2L.getFirst().getY()), $1T.of($2L.getSecond().getX(), $2L.getSecond().getY()))", + List.class, parameterName).build()); + } else if (ClassUtils.isAssignable(Polygon.class, parameter.getType())) { + + // $polygon: [ [ , ], [ , ], [ , ], ... ] + String localVar = context.localVariable("_p"); + arguments.put(parameterName, + CodeBlock.builder().add("$1L.getPoints().stream().map($2L -> $3T.of($2L.getX(), $2L.getY())).toList()", + parameterName, localVar, List.class).build()); + } else { + arguments.put(parameterName, CodeBlock.of(parameterName)); + } + } + } + + QueryCodeBlockBuilder filter(QueryInteraction query) { + + this.source = query; + return this; + } + + QueryCodeBlockBuilder usingQueryVariableName(String queryVariableName) { + + this.queryVariableName = queryVariableName; + return this; + } + + CodeBlock build() { + + Builder builder = CodeBlock.builder(); + + builder.add(buildJustTheQuery()); + + if (StringUtils.hasText(source.getQuery().getFieldsString())) { + + builder + .add(MongoCodeBlocks.renderExpressionToDocument(source.getQuery().getFieldsString(), "fields", arguments)); + builder.addStatement("$L.setFieldsObject(fields)", queryVariableName); + } + + String sortParameter = context.getSortParameterName(); + if (StringUtils.hasText(sortParameter)) { + builder.addStatement("$L.with($L)", queryVariableName, sortParameter); + } else if (StringUtils.hasText(source.getQuery().getSortString())) { + + builder.add(MongoCodeBlocks.renderExpressionToDocument(source.getQuery().getSortString(), "sort", arguments)); + builder.addStatement("$L.setSortObject(sort)", queryVariableName); + } + + String limitParameter = context.getLimitParameterName(); + if (StringUtils.hasText(limitParameter)) { + builder.addStatement("$L.limit($L)", queryVariableName, limitParameter); + } else if (context.getPageableParameterName() == null && source.getQuery().isLimited()) { + builder.addStatement("$L.limit($L)", queryVariableName, source.getQuery().getLimit()); + } + + String pageableParameter = context.getPageableParameterName(); + if (StringUtils.hasText(pageableParameter) && !queryMethod.isPageQuery() && !queryMethod.isSliceQuery()) { + builder.addStatement("$L.with($L)", queryVariableName, pageableParameter); + } + + MergedAnnotation hintAnnotation = context.getAnnotation(Hint.class); + String hint = hintAnnotation.isPresent() ? hintAnnotation.getString("value") : null; + + if (StringUtils.hasText(hint)) { + builder.addStatement("$L.withHint($S)", queryVariableName, hint); + } + + MongoCodeBlocks.appendReadPreference(context, builder, queryVariableName); + + MergedAnnotation metaAnnotation = context.getAnnotation(Meta.class); + if (metaAnnotation.isPresent()) { + + long maxExecutionTimeMs = metaAnnotation.getLong("maxExecutionTimeMs"); + if (maxExecutionTimeMs != -1) { + builder.addStatement("$L.maxTimeMsec($L)", queryVariableName, maxExecutionTimeMs); + } + + int cursorBatchSize = metaAnnotation.getInt("cursorBatchSize"); + if (cursorBatchSize != 0) { + builder.addStatement("$L.cursorBatchSize($L)", queryVariableName, cursorBatchSize); + } + + String comment = metaAnnotation.getString("comment"); + if (StringUtils.hasText(comment)) { + builder.addStatement("$L.comment($S)", queryVariableName, comment); + } + } + + MergedAnnotation collationAnnotation = context.getAnnotation(Collation.class); + if (collationAnnotation.isPresent()) { + + String collationString = collationAnnotation.getString("value"); + if(StringUtils.hasText(collationString)) { + if (!MongoCodeBlocks.containsPlaceholder(collationString)) { + builder.addStatement("$L.collation($T.parse($S))", queryVariableName, + org.springframework.data.mongodb.core.query.Collation.class, collationString); + } else { + builder.add("$L.collation(collationOf(evaluate($S, ", queryVariableName, collationString); + builder.add(MongoCodeBlocks.renderArgumentMap(arguments)); + builder.add(")));\n"); + } + } + } + + return builder.build(); + } + + CodeBlock buildJustTheQuery() { + + Builder builder = CodeBlock.builder(); + builder.add("\n"); + builder.add(renderExpressionToQuery(source.getQuery().getQueryString(), queryVariableName)); + return builder.build(); + } + + private CodeBlock renderExpressionToQuery(@Nullable String source, String variableName) { + + Builder builder = CodeBlock.builder(); + if (!StringUtils.hasText(source)) { + + builder.addStatement("$1T $2L = new $1T(new $3T())", BasicQuery.class, variableName, Document.class); + } else if (!MongoCodeBlocks.containsPlaceholder(source)) { + builder.addStatement("$1T $2L = new $1T($3T.parse($4S))", BasicQuery.class, variableName, Document.class, + source); + } else { + builder.add("$T $L = createQuery($S, ", BasicQuery.class, variableName, source); + if (MongoCodeBlocks.containsNamedPlaceholder(source)) { + builder.add(MongoCodeBlocks.renderArgumentMap(arguments)); + } else { + builder.add(MongoCodeBlocks.renderArgumentArray(arguments)); + } + builder.add(");\n"); + } + + return builder.build(); + } + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java new file mode 100644 index 0000000000..a94ff1082b --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/SearchInteraction.java @@ -0,0 +1,48 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.Map; + +import org.jspecify.annotations.Nullable; +import org.springframework.data.repository.aot.generate.QueryMetadata; + +/** + * @author Christoph Strobl + */ +public class SearchInteraction extends MongoInteraction implements QueryMetadata { + + StringQuery filter; + + public SearchInteraction(StringQuery filter) { + this.filter = filter; + } + + public StringQuery getFilter() { + return filter; + } + + @Override + InteractionType getExecutionType() { + return InteractionType.AGGREGATION; + } + + @Override + public Map serialize() { + + return Map.of("FIXME", "please!"); + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java new file mode 100644 index 0000000000..e4061c7717 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/UpdateBlocks.java @@ -0,0 +1,147 @@ +/* + * Copyright 2025. the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.util.LinkedHashMap; +import java.util.Map; + +import org.jspecify.annotations.NullUnmarked; +import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.query.BasicUpdate; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.data.util.ReflectionUtils; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.util.ClassUtils; +import org.springframework.util.NumberUtils; + +/** + * @author Christoph Strobl + * @since 2025/06 + */ +class UpdateBlocks { + + @NullUnmarked + static class UpdateExecutionCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String queryVariableName; + private String updateVariableName; + + UpdateExecutionCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + } + + UpdateExecutionCodeBlockBuilder withFilter(String queryVariableName) { + + this.queryVariableName = queryVariableName; + return this; + } + + UpdateExecutionCodeBlockBuilder referencingUpdate(String updateVariableName) { + + this.updateVariableName = updateVariableName; + return this; + } + + CodeBlock build() { + + String mongoOpsRef = context.fieldNameOf(MongoOperations.class); + Builder builder = CodeBlock.builder(); + + builder.add("\n"); + + String updateReference = updateVariableName; + Class domainType = context.getRepositoryInformation().getDomainType(); + builder.addStatement("$1T<$2T> $3L = $4L.update($2T.class)", ExecutableUpdate.class, domainType, + context.localVariable("updater"), mongoOpsRef); + + Class returnType = ClassUtils.resolvePrimitiveIfNecessary(queryMethod.getReturnedObjectType()); + if (ReflectionUtils.isVoid(returnType)) { + builder.addStatement("$L.matching($L).apply($L).all()", context.localVariable("updater"), queryVariableName, + updateReference); + } else if (ClassUtils.isAssignable(Long.class, returnType)) { + builder.addStatement("return $L.matching($L).apply($L).all().getModifiedCount()", + context.localVariable("updater"), queryVariableName, updateReference); + } else { + builder.addStatement("$T $L = $L.matching($L).apply($L).all().getModifiedCount()", Long.class, + context.localVariable("modifiedCount"), context.localVariable("updater"), queryVariableName, + updateReference); + builder.addStatement("return $T.convertNumberToTargetClass($L, $T.class)", NumberUtils.class, + context.localVariable("modifiedCount"), returnType); + } + + return builder.build(); + } + } + + @NullUnmarked + static class UpdateCodeBlockBuilder { + + private UpdateInteraction source; + private Map arguments; + private String updateVariableName; + + public UpdateCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + this.arguments = new LinkedHashMap<>(); + context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it))); + } + + public UpdateCodeBlockBuilder update(UpdateInteraction update) { + this.source = update; + return this; + } + + public UpdateCodeBlockBuilder usingUpdateVariableName(String updateVariableName) { + this.updateVariableName = updateVariableName; + return this; + } + + CodeBlock build() { + + Builder builder = CodeBlock.builder(); + + builder.add("\n"); + String tmpVariableName = updateVariableName + "Document"; + builder.add( + MongoCodeBlocks.renderExpressionToDocument(source.getUpdate().getUpdateString(), tmpVariableName, arguments)); + builder.addStatement("$1T $2L = new $1T($3L)", BasicUpdate.class, updateVariableName, tmpVariableName); + + return builder.build(); + } + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java new file mode 100644 index 0000000000..3efdc080b2 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/VectorSearchBocks.java @@ -0,0 +1,211 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import java.lang.reflect.Field; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.bson.Document; +import org.springframework.core.annotation.MergedAnnotation; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationOperation; +import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.MongoQueryExecution.VectorSearchExecution; +import org.springframework.data.mongodb.repository.query.MongoQueryMethod; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.data.util.TypeInformation; +import org.springframework.javapoet.CodeBlock; +import org.springframework.javapoet.CodeBlock.Builder; +import org.springframework.util.StringUtils; + +/** + * @author Christoph Strobl + * @since 5.0 + */ +class VectorSearchBocks { + + static class VectorSearchQueryCodeBlockBuilder { + + private final AotQueryMethodGenerationContext context; + private final MongoQueryMethod queryMethod; + private String searchQueryVariableName; + private StringQuery filter; + private final Map arguments; + + VectorSearchQueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod) { + + this.context = context; + this.queryMethod = queryMethod; + this.arguments = new LinkedHashMap<>(); + context.getBindableParameterNames().forEach(it -> arguments.put(it, CodeBlock.of(it))); + } + + VectorSearchQueryCodeBlockBuilder usingVariableName(String searchQueryVariableName) { + + this.searchQueryVariableName = searchQueryVariableName; + return this; + } + + CodeBlock build() { + + Builder builder = CodeBlock.builder(); + + String vectorParameterName = context.getVectorParameterName(); + + MergedAnnotation annotation = context.getAnnotation(VectorSearch.class); + String searchPath = annotation.getString("path"); + String indexName = annotation.getString("indexName"); + String numCandidates = annotation.getString("numCandidates"); + SearchType searchType = annotation.getEnum("searchType", SearchType.class); + String limit = annotation.getString("limit"); + + if (!StringUtils.hasText(searchPath)) { // FIXME: somehow duplicate logic of AnnotatedQueryFactory + + Field[] declaredFields = context.getRepositoryInformation().getDomainType().getDeclaredFields(); + for (Field field : declaredFields) { + if (Vector.class.isAssignableFrom(field.getType())) { + searchPath = field.getName(); + break; + } + } + + } + + String vectorSearchVar = context.localVariable("$vectorSearch"); + builder.add("$T $L = $T.vectorSearch($S).path($S).vector($L)", VectorSearchOperation.class, vectorSearchVar, + Aggregation.class, indexName, searchPath, vectorParameterName); + + if (StringUtils.hasText(context.getLimitParameterName())) { + builder.add(".limit($L);\n", context.getLimitParameterName()); + } else if (filter.isLimited()) { + builder.add(".limit($L);\n", filter.getLimit()); + } else if (StringUtils.hasText(limit)) { + if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) { + builder.add(".limit("); + builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments)); + builder.add(");\n"); + } else { + builder.add(".limit($L);\n", limit); + } + } else { + builder.add(".limit($T.unlimited());\n", Limit.class); + } + + if (!searchType.equals(SearchType.DEFAULT)) { + builder.addStatement("$1L = $1L.searchType($2T.$3L)", vectorSearchVar, SearchType.class, searchType.name()); + } + + if (StringUtils.hasText(numCandidates)) { + builder.add("$1L = $1L.numCandidates(", vectorSearchVar); + builder.add(MongoCodeBlocks.evaluateNumberPotentially(numCandidates, Integer.class, arguments)); + builder.add(");\n"); + } else if (searchType == VectorSearchOperation.SearchType.ANN + || searchType == VectorSearchOperation.SearchType.DEFAULT) { + + builder.add( + "// MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return\n"); + if (StringUtils.hasText(context.getLimitParameterName())) { + builder.addStatement("$1L = $1L.numCandidates($2L.max() * 20)", vectorSearchVar, + context.getLimitParameterName()); + } else if (StringUtils.hasText(limit)) { + if (MongoCodeBlocks.containsPlaceholder(limit) || MongoCodeBlocks.containsExpression(limit)) { + + builder.add("$1L = $1L.numCandidates((", vectorSearchVar); + builder.add(MongoCodeBlocks.evaluateNumberPotentially(limit, Integer.class, arguments)); + builder.add(") * 20);\n"); + } else { + builder.addStatement("$1L = $1L.numCandidates($2L * 20)", vectorSearchVar, limit); + } + } else { + builder.addStatement("$1L = $1L.numCandidates($2L)", vectorSearchVar, filter.getLimit() * 20); + } + } + + builder.addStatement("$1L = $1L.withSearchScore(\"__score__\")", vectorSearchVar); + if (StringUtils.hasText(context.getScoreParameterName())) { + + String scoreCriteriaVar = context.localVariable("criteria"); + builder.addStatement("$1L = $1L.withFilterBySore($2L -> { $2L.gt($3L.getValue()); })", vectorSearchVar, + scoreCriteriaVar, context.getScoreParameterName()); + } else if (StringUtils.hasText(context.getScoreRangeParameterName())) { + builder.addStatement("$1L = $1L.withFilterBySore(scoreBetween($2L.getLowerBound(), $2L.getUpperBound()))", + vectorSearchVar, context.getScoreRangeParameterName()); + } + + if (StringUtils.hasText(filter.getQueryString())) { + + String filterVar = context.localVariable("filter"); + builder.add(MongoCodeBlocks.queryBlockBuilder(context, queryMethod).usingQueryVariableName("filter") + .filter(new QueryInteraction(this.filter, false, false, false)).buildJustTheQuery()); + builder.addStatement("$1L = $1L.filter($2L.getQueryObject())", vectorSearchVar, filterVar); + builder.add("\n"); + } + + + String sortStageVar = context.localVariable("$sort"); + if(filter.isSorted()) { + + builder.add("$T $L = (_ctx) -> {\n", AggregationOperation.class, sortStageVar); + builder.indent(); + + builder.addStatement("$1T _mappedSort = _ctx.getMappedObject($1T.parse($2S), $3T.class)", Document.class, filter.getSortString(), context.getActualReturnType().getType()); + builder.addStatement("return new $T($S, _mappedSort.append(\"__score__\", -1))", Document.class, "$sort"); + builder.unindent(); + builder.add("};"); + + } else { + builder.addStatement("var $L = $T.sort($T.Direction.DESC, $S)", sortStageVar, Aggregation.class, Sort.class, "__score__"); + } + builder.add("\n"); + + builder.addStatement("$1T $2L = new $1T($3T.of($4L, $5L))", AggregationPipeline.class, searchQueryVariableName, + List.class, vectorSearchVar, sortStageVar); + + String scoringFunctionVar = context.localVariable("scoringFunction"); + builder.add("$1T $2L = ", ScoringFunction.class, scoringFunctionVar); + if (StringUtils.hasText(context.getScoreParameterName())) { + builder.add("$L.getFunction();\n", context.getScoreParameterName()); + } else if (StringUtils.hasText(context.getScoreRangeParameterName())) { + builder.add("scoringFunction($L);\n", context.getScoreRangeParameterName()); + } else { + builder.add("$1T.unspecified();\n", ScoringFunction.class); + } + + builder.addStatement( + "return ($5T) new $1T($2L, $3T.class, $2L.getCollectionName($3T.class), $4T.of($5T.class), $6L, $7L).execute(null)", + VectorSearchExecution.class, context.fieldNameOf(MongoOperations.class), + context.getRepositoryInformation().getDomainType(), TypeInformation.class, + queryMethod.getReturnType().getType(), searchQueryVariableName, scoringFunctionVar); + return builder.build(); + } + + public VectorSearchQueryCodeBlockBuilder withFilter(StringQuery filter) { + this.filter = filter; + return this; + } + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java index 94acef17ce..76738bf375 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java @@ -122,6 +122,10 @@ private MongoParameters(List parameters, int maxDistanceIndex, i this.domainType = domainType; } + protected MongoParameters() { + this(List.of(), -1, -1, -1, -1, -1, -1, TypeInformation.OBJECT); + } + static boolean isGeoNearQuery(Method method) { Class returnType = method.getReturnType(); @@ -292,7 +296,7 @@ private int getTypeIndex(List> parameterTypes, Class type, * * @author Oliver Gierke */ - static class MongoParameter extends Parameter { + public static class MongoParameter extends Parameter { private final MethodParameter parameter; private final @Nullable Integer nearIndex; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java index ba7394ec17..3436a52d1f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java @@ -54,6 +54,7 @@ import org.springframework.data.repository.query.parser.Part.Type; import org.springframework.data.repository.query.parser.PartTree; import org.springframework.data.util.Streamable; +import org.springframework.lang.NonNull; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; @@ -220,8 +221,7 @@ private Criteria from(Part part, MongoPersistentProperty property, Criteria crit return createContainingCriteria(part, property, criteria.not(), parameters); case REGEX: - Object param = parameters.next(); - return param instanceof Pattern pattern ? criteria.regex(pattern) : criteria.regex(param.toString()); + return createPatternCriteria(criteria, parameters); case EXISTS: Object next = parameters.next(); if (next instanceof Placeholder placeholder) { @@ -235,35 +235,7 @@ private Criteria from(Part part, MongoPersistentProperty property, Criteria crit return criteria.is(false); case NEAR: - Range range = accessor.getDistanceRange(); - Optional distance = range.getUpperBound().getValue(); - Optional minDistance = range.getLowerBound().getValue(); - - Point point = accessor.getGeoNearLocation(); - Point pointToUse = point == null ? nextAs(parameters, Point.class) : point; - - boolean isSpherical = isSpherical(property); - - return distance.map(it -> { - - if (isSpherical || !Metrics.NEUTRAL.equals(it.getMetric())) { - criteria.nearSphere(pointToUse); - } else { - criteria.near(pointToUse); - } - - if (pointToUse instanceof GeoJson) { // using GeoJson distance is in meters. - - criteria.maxDistance(MetricConversion.getDistanceInMeters(it)); - minDistance.map(MetricConversion::getDistanceInMeters).ifPresent(criteria::minDistance); - } else { - criteria.maxDistance(it.getNormalizedValue()); - minDistance.map(Distance::getNormalizedValue).ifPresent(criteria::minDistance); - } - - return criteria; - - }).orElseGet(() -> isSpherical ? criteria.nearSphere(pointToUse) : criteria.near(pointToUse)); + return createNearCriteria(property, criteria, parameters); case WITHIN: @@ -283,6 +255,49 @@ private Criteria from(Part part, MongoPersistentProperty property, Criteria crit } } + @NonNull + private static Criteria createPatternCriteria(Criteria criteria, Iterator parameters) { + Object param = parameters.next(); + if (param instanceof Placeholder) { + return criteria.raw("$regex", param); + } + return param instanceof Pattern pattern ? criteria.regex(pattern) : criteria.regex(param.toString()); + } + + @NonNull + private Criteria createNearCriteria(MongoPersistentProperty property, Criteria criteria, + Iterator parameters) { + + Range range = accessor.getDistanceRange(); + Optional distance = range.getUpperBound().getValue(); + Optional minDistance = range.getLowerBound().getValue(); + + Point point = accessor.getGeoNearLocation(); + Point pointToUse = point == null ? nextAs(parameters, Point.class) : point; + + boolean isSpherical = isSpherical(property); + + return distance.map(it -> { + + if (isSpherical || !Metrics.NEUTRAL.equals(it.getMetric())) { + criteria.nearSphere(pointToUse); + } else { + criteria.near(pointToUse); + } + + if (pointToUse instanceof GeoJson) { // using GeoJson distance is in meters. + + criteria.maxDistance(MetricConversion.getDistanceInMeters(it)); + minDistance.map(MetricConversion::getDistanceInMeters).ifPresent(criteria::minDistance); + } else { + criteria.maxDistance(it.getNormalizedValue()); + minDistance.map(Distance::getNormalizedValue).ifPresent(criteria::minDistance); + } + + return criteria; + }).orElseGet(() -> isSpherical ? criteria.nearSphere(pointToUse) : criteria.near(pointToUse)); + } + private boolean isSimpleComparisonPossible(Part part) { return switch (part.shouldIgnoreCase()) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java index c0531e0e19..acf80db214 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java @@ -186,8 +186,11 @@ public Object execute(Query query) { return isListOfGeoResult(method.getReturnType()) ? results.getContent() : results; } - @SuppressWarnings({ "unchecked", "NullAway" }) GeoResults doExecuteQuery(Query query) { + return doExecuteQuery(nearQuery(query)); + } + + NearQuery nearQuery(Query query) { Point nearLocation = accessor.getGeoNearLocation(); Assert.notNull(nearLocation, "[query.location] must not be null"); @@ -205,9 +208,12 @@ GeoResults doExecuteQuery(Query query) { distances.getUpperBound().getValue().ifPresent(it -> nearQuery.maxDistance(it).in(it.getMetric())); Pageable pageable = accessor.getPageable(); - nearQuery.with(pageable); + return nearQuery.with(pageable); + } - return (GeoResults) operation.near(nearQuery).all(); + @SuppressWarnings({ "unchecked", "NullAway" }) + GeoResults doExecuteQuery(NearQuery query) { + return (GeoResults) operation.near(query).all(); } private static boolean isListOfGeoResult(TypeInformation returnType) { @@ -324,16 +330,11 @@ final class PagingGeoNearExecution extends GeoNearExecution { @Override public Object execute(Query query) { - GeoResults geoResults = doExecuteQuery(query); + NearQuery nearQuery = nearQuery(query); + GeoResults geoResults = doExecuteQuery(nearQuery); Page> page = PageableExecutionUtils.getPage(geoResults.getContent(), accessor.getPageable(), - () -> { - - Query countQuery = mongoQuery.createCountQuery(accessor); - countQuery = mongoQuery.applyQueryMetaAttributesWhenPresent(countQuery); - - return operation.matching(countQuery).count(); - }); + () -> operation.near(nearQuery).count()); // transform to GeoPage after applying optimization return new GeoPage<>(geoResults, accessor.getPageable(), page.getTotalElements()); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java index dc51da84ed..eb052da9a4 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/BsonUtils.java @@ -60,6 +60,7 @@ import org.bson.codecs.DocumentCodec; import org.bson.codecs.EncoderContext; import org.bson.codecs.configuration.CodecConfigurationException; +import org.bson.codecs.configuration.CodecProvider; import org.bson.codecs.configuration.CodecRegistries; import org.bson.codecs.configuration.CodecRegistry; import org.bson.conversions.Bson; @@ -74,6 +75,7 @@ import org.springframework.data.mongodb.core.mapping.FieldName; import org.springframework.data.mongodb.core.mapping.FieldName.Type; import org.springframework.data.mongodb.core.query.CriteriaDefinition.Placeholder; +import org.springframework.data.mongodb.core.query.GeoCommand; import org.springframework.lang.Contract; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -103,7 +105,7 @@ public class BsonUtils { public static final Document EMPTY_DOCUMENT = new EmptyDocument(); private static final CodecRegistry JSON_CODEC_REGISTRY = CodecRegistries.fromRegistries( - MongoClientSettings.getDefaultCodecRegistry(), CodecRegistries.fromCodecs(new PlaceholderCodec())); + MongoClientSettings.getDefaultCodecRegistry(), CodecRegistries.fromProviders(new PlaceholderCodecProvider())); @SuppressWarnings("unchecked") @Contract("null, _ -> null") @@ -377,7 +379,7 @@ public static BsonValue simpleToBsonValue(@Nullable Object source) { @Contract("null, _ -> !null") public static BsonValue simpleToBsonValue(@Nullable Object source, CodecRegistry codecRegistry) { - if(source == null) { + if (source == null) { return BsonNull.VALUE; } @@ -1031,6 +1033,25 @@ public void flush() { } } + @NullUnmarked + public static class PlaceholderCodecProvider implements CodecProvider { + + PlaceholderCodec placeholderCodec = new PlaceholderCodec(); + GeoCommandCodec geoCommandCodec = new GeoCommandCodec(); + + @Override + public Codec get(Class clazz, CodecRegistry registry) { + if (ClassUtils.isAssignable(Placeholder.class, clazz)) { + return (Codec) placeholderCodec; + } + if (ClassUtils.isAssignable(GeoCommand.class, clazz)) { + return (Codec) geoCommandCodec; + } + return null; + + } + } + /** * Internal {@link Codec} implementation to write * {@link org.springframework.data.mongodb.core.query.CriteriaDefinition.Placeholder placeholders}. @@ -1060,4 +1081,38 @@ public Class getEncoderClass() { return Placeholder.class; } } + + static class GeoCommandCodec implements Codec { + + @Override + public GeoCommand decode(BsonReader reader, DecoderContext decoderContext) { + return null; + } + + @Override + public void encode(BsonWriter writer, GeoCommand value, EncoderContext encoderContext) { + + if (writer instanceof SpringJsonWriter sjw) { + if (!value.getCommand().equals("$geometry")) { + writer.writeStartDocument(); + writer.writeName(value.getCommand()); + if (value.getShape() instanceof Placeholder p) { // maybe we should wrap input to use geo command object + sjw.writePlaceholder(p.toString()); + } + writer.writeEndDocument(); + } else { + if (value.getShape() instanceof Placeholder p) { // maybe we should wrap input to use geo command object + sjw.writePlaceholder(p.toString()); + } + } + } else { + writer.writeString(value.getCommand(), value.getShape().toString()); + } + } + + @Override + public Class getEncoderClass() { + return null; + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/SpringJsonWriter.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/SpringJsonWriter.java index 07eab92a01..98dbc3a682 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/SpringJsonWriter.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/SpringJsonWriter.java @@ -463,7 +463,7 @@ public void writePlaceholder(String placeholder) { write(placeholder); } - private void write(String str) { + public void write(String str) { buffer.append(str); } diff --git a/spring-data-mongodb/src/test/java/example/aot/Location.java b/spring-data-mongodb/src/test/java/example/aot/Location.java new file mode 100644 index 0000000000..210e9e0ce6 --- /dev/null +++ b/spring-data-mongodb/src/test/java/example/aot/Location.java @@ -0,0 +1,26 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package example.aot; + +import org.springframework.data.geo.Point; + +/** + * @param planet + * @param coordinates + * @author Christoph Strobl + */ +public record Location(String planet, Point coordinates) { +} diff --git a/spring-data-mongodb/src/test/java/example/aot/User.java b/spring-data-mongodb/src/test/java/example/aot/User.java index 06022c0a55..dfe3ec3553 100644 --- a/spring-data-mongodb/src/test/java/example/aot/User.java +++ b/spring-data-mongodb/src/test/java/example/aot/User.java @@ -17,6 +17,7 @@ import java.time.Instant; +import org.springframework.data.domain.Vector; import org.springframework.data.mongodb.core.mapping.Field; /** @@ -32,10 +33,14 @@ public class User { @Field("last_name") String lastname; + Location location; + Instant registrationDate; Instant lastSeen; Long visits; + Vector embedding; + public String getId() { return id; } @@ -91,4 +96,12 @@ public Long getVisits() { public void setVisits(Long visits) { this.visits = visits; } + + public Location getLocation() { + return location; + } + + public void setLocation(Location location) { + this.location = location; + } } diff --git a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java index 5eb9fed686..ee1058fdc6 100644 --- a/spring-data-mongodb/src/test/java/example/aot/UserRepository.java +++ b/spring-data-mongodb/src/test/java/example/aot/UserRepository.java @@ -22,23 +22,43 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.regex.Pattern; import java.util.stream.Stream; import org.springframework.data.annotation.Id; import org.springframework.data.domain.Limit; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.domain.Window; +import org.springframework.data.geo.Box; +import org.springframework.data.geo.Circle; +import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoPage; +import org.springframework.data.geo.GeoResult; +import org.springframework.data.geo.GeoResults; +import org.springframework.data.geo.Point; +import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.aggregation.AggregationResults; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; +import org.springframework.data.mongodb.core.geo.GeoJson; +import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; +import org.springframework.data.mongodb.core.geo.Sphere; import org.springframework.data.mongodb.repository.Aggregation; import org.springframework.data.mongodb.repository.Hint; import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.Update; +import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.repository.CrudRepository; +import org.springframework.data.repository.query.Param; /** * @author Christoph Strobl @@ -83,6 +103,8 @@ public interface UserRepository extends CrudRepository { List findByLastnameNot(String lastname); + List findByFirstnameRegex(Pattern pattern); + List findTop2ByLastnameStartingWith(String lastname); List findByLastnameStartingWithOrderByUsername(String lastname); @@ -103,7 +125,30 @@ public interface UserRepository extends CrudRepository { Window findTop2WindowByLastnameStartingWithOrderByUsername(String lastname, ScrollPosition scrollPosition); - // TODO: GeoQueries + List findByLocationCoordinatesNear(Point location); + + List findByLocationCoordinatesWithin(Circle circle); + + List findByLocationCoordinatesWithin(Sphere circle); + + List findByLocationCoordinatesWithin(Box box); + + List findByLocationCoordinatesWithin(Polygon polygon); + + List findByLocationCoordinatesWithin(GeoJsonPolygon polygon); + + List findUserByLocationCoordinatesWithin(GeoJson geoJson); + + GeoResults findByLocationCoordinatesNear(Point point, Distance maxDistance); + + GeoResults findByLocationCoordinatesNearAndLastname(Point point, Distance maxDistance, String lastname); + + List> findUserAsListByLocationCoordinatesNear(Point point, Distance maxDistance); + + GeoResults findByLocationCoordinatesNear(Point point, Range distance); + + GeoPage findByLocationCoordinatesNear(Point point, Distance maxDistance, Pageable pageable); + // TODO: TextSearch /* Annotated Queries */ @@ -143,6 +188,12 @@ public interface UserRepository extends CrudRepository { @Query("{ 'lastname' : { '$regex' : '^?0' } }") Slice findAnnotatedQuerySliceOfUsersByLastname(String lastname, Pageable pageable); + @Query("{ firstname : ?#{[0]} }") + List findWithExpressionUsingParameterIndex(String firstname); + + @Query("{ firstname : :#{#firstname} }") + List findWithExpressionUsingParameterName(@Param("firstname") String firstname); + /* deletes */ User deleteByUsername(String username); @@ -246,6 +297,30 @@ public interface UserRepository extends CrudRepository { "{ '$project': { '_id' : '$last_name' } }" }, collation = "no_collation") List findAllLastnamesWithCollation(); + // Vector Search + + @VectorSearch(indexName = "embedding.vector_cos", filter = "{lastname: ?0}", numCandidates = "#{10+10}", + searchType = VectorSearchOperation.SearchType.ANN) + SearchResults annotatedVectorSearch(String lastname, Vector vector, Score distance, Limit limit); + + @VectorSearch(indexName = "embedding.vector_cos") + SearchResults searchCosineByLastnameAndEmbeddingNear(String lastname, Vector vector, Score similarity, + Limit limit); + + @VectorSearch(indexName = "embedding.vector_cos") + List searchAsListByLastnameAndEmbeddingNear(String lastname, Vector vector, Limit limit); + + @VectorSearch(indexName = "embedding.vector_cos", limit = "10") + SearchResults searchByLastnameAndEmbeddingWithin(String lastname, Vector vector, Range distance); + + @VectorSearch(indexName = "embedding.vector_cos", limit = "10") + SearchResults searchByLastnameAndEmbeddingWithinOrderByFirstname(String lastname, Vector vector, + Range distance); + + @VectorSearch(indexName = "embedding.vector_cos") + SearchResults searchTop1ByLastnameAndEmbeddingWithin(String lastname, Vector vector, + Range distance); + class UserAggregate { @Id // diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupportTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupportTests.java index 835367990a..3c95a5a8ea 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupportTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/ExecutableFindOperationSupportTests.java @@ -15,10 +15,13 @@ */ package org.springframework.data.mongodb.core; -import static org.assertj.core.api.Assertions.*; -import static org.springframework.data.mongodb.core.query.Criteria.*; -import static org.springframework.data.mongodb.core.query.Query.*; -import static org.springframework.data.mongodb.test.util.DirtiesStateExtension.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.springframework.data.mongodb.core.query.Criteria.where; +import static org.springframework.data.mongodb.core.query.Query.query; +import static org.springframework.data.mongodb.test.util.DirtiesStateExtension.DirtiesState; +import static org.springframework.data.mongodb.test.util.DirtiesStateExtension.StateFunctions; import java.util.Date; import java.util.List; @@ -47,6 +50,7 @@ import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.core.query.NearQuery; +import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.test.util.DirtiesStateExtension; import org.springframework.data.mongodb.test.util.MongoTemplateExtension; import org.springframework.data.mongodb.test.util.MongoTestTemplate; @@ -81,7 +85,7 @@ public void clear() { @Override public void setupState() { - template.indexOps(Planet.class).ensureIndex( + template.indexOps(Planet.class).createIndex( new GeospatialIndex("coordinates").typed(GeoSpatialIndexType.GEO_2DSPHERE).named("planet-coordinate-idx")); initPersons(); @@ -162,7 +166,7 @@ void findAllByWithCollection() { void findAllAsDocument() { assertThat( template.query(Document.class).inCollection(STAR_WARS).matching(query(where("firstname").is("luke"))).all()) - .hasSize(1); + .hasSize(1); } @Test // DATAMONGO-1563 @@ -324,6 +328,14 @@ void findAllNearBy() { assertThat(results.getContent().get(0).getDistance()).isNotNull(); } + @Test + void countResultsOfNearQuery() { + + Long count = template.query(Planet.class) + .near(NearQuery.near(-73.9667, 40.78).spherical(true).query(new Query(where("name").is("alderan")))).count(); + assertThat(count).isEqualTo(1); + } + @Test // DATAMONGO-1563 void findAllNearByWithCollectionAndProjection() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java index c2cb6cacf8..493a23e4e5 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java @@ -15,10 +15,12 @@ */ package org.springframework.data.mongodb.repository; -import static java.util.Arrays.*; -import static org.assertj.core.api.Assertions.*; -import static org.assertj.core.api.Assumptions.*; -import static org.springframework.data.geo.Metrics.*; +import static java.util.Arrays.asList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assumptions.assumeThat; +import static org.springframework.data.geo.Metrics.KILOMETERS; import java.util.ArrayList; import java.util.Arrays; @@ -38,13 +40,22 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.extension.ExtendWith; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.dao.DuplicateKeyException; import org.springframework.dao.IncorrectResultSizeDataAccessException; -import org.springframework.data.domain.*; +import org.springframework.data.domain.Example; +import org.springframework.data.domain.ExampleMatcher; import org.springframework.data.domain.ExampleMatcher.GenericPropertyMatcher; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.PageRequest; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.Slice; +import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort.Direction; +import org.springframework.data.domain.Window; import org.springframework.data.geo.Box; import org.springframework.data.geo.Circle; import org.springframework.data.geo.Distance; @@ -216,8 +227,8 @@ void appliesScrollPositionCorrectly() { @Test // GH-4397 void appliesLimitToScrollingCorrectly() { - Window page = repository.findByLastnameLikeOrderByLastnameAscFirstnameAsc("*a*", - ScrollPosition.keyset(), Limit.of(2)); + Window page = repository.findByLastnameLikeOrderByLastnameAscFirstnameAsc("*a*", ScrollPosition.keyset(), + Limit.of(2)); assertThat(page.isLast()).isFalse(); assertThat(page.size()).isEqualTo(2); @@ -250,7 +261,8 @@ void executesPagedFinderCorrectly() { @Test // GH-4397 void executesFinderCorrectlyWithSortAndLimit() { - List page = repository.findByLastnameLike("*a*", Sort.by(Direction.ASC, "lastname", "firstname"), Limit.of(2)); + List page = repository.findByLastnameLike("*a*", Sort.by(Direction.ASC, "lastname", "firstname"), + Limit.of(2)); assertThat(page).containsExactly(carter, stefan); } @@ -462,6 +474,22 @@ void executesGeoNearQueryForResultsCorrectly() { assertThat(results.getContent()).isNotEmpty(); } + @Test + void executesGeoNearQueryWithAdditionalFilterCorrectly() { + + Point point = new Point(-73.99171, 40.738868); + dave.setLocation(point); + repository.save(dave); + + Person p2 = new Person("fn", "ln", 42, Sex.MALE); + p2.setLocation(point); + repository.save(p2); + + GeoResults results = repository.findByLocationNearAndLastname(new Point(-73.99, 40.73), + Distance.of(2000, Metrics.KILOMETERS), "ln"); + assertThat(results.getContent()).hasSize(1); + } + @Test void executesGeoPageQueryForResultsCorrectly() { @@ -638,6 +666,7 @@ void executesGeoPageQueryForWithPageRequestForPageInBetween() { assertThat(results.getContent()).isNotEmpty(); assertThat(results.getNumberOfElements()).isEqualTo(2); + assertThat(results.getTotalElements()).isEqualTo(5); assertThat(results.isFirst()).isFalse(); assertThat(results.isLast()).isFalse(); assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); @@ -697,6 +726,30 @@ void executesGeoPageQueryForWithPageRequestForJustOneElementEmptyPage() { assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); } + @Test + void executesGeoPageCountCorrectly() { + + Point farAway = new Point(-73.9, 40.7); + Point here = new Point(-73.99, 40.73); + + dave.setLocation(farAway); + oliver.setLocation(here); + carter.setLocation(here); + boyd.setLocation(here); + leroi.setLocation(here); + + repository.saveAll(Arrays.asList(dave, oliver, carter, boyd, leroi)); + + GeoPage results = repository.findByLocationNear(new Point(-73.99, 40.73), + Distance.of(5, Metrics.KILOMETERS), PageRequest.of(1, 2)); + + assertThat(results.getContent()).isNotEmpty(); + assertThat(results.getNumberOfElements()).isEqualTo(2); + assertThat(results.getTotalElements()).isEqualTo(4); + assertThat(results.isFirst()).isFalse(); + assertThat(results.isLast()).isTrue(); + } + @Test // DATAMONGO-1608 void findByFirstNameIgnoreCaseWithNull() { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java index 1f4f682ebc..9ab0d71dc3 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java @@ -222,6 +222,8 @@ Window findByLastnameLikeOrderByLastnameAscFirstnameAsc(String lastname, GeoResults findByLocationNear(Point point, Distance maxDistance); + GeoResults findByLocationNearAndLastname(Point point, Distance maxDistance, String Lastname); + // DATAMONGO-1110 GeoResults findPersonByLocationNear(Point point, Range distance); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java index eba08ecc2e..0a0549eb1b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/AotFragmentTestConfigurationSupport.java @@ -93,7 +93,8 @@ private Object getFragmentFacadeProxy(Object fragment) { Method target = ReflectionUtils.findMethod(fragment.getClass(), method.getName(), method.getParameterTypes()); if (target == null) { - throw new NoSuchMethodException("Method [%s] is not implemented by [%s]".formatted(method, target)); + throw new MethodNotImplementedException( + "Method [%s] is not implemented by [%s]".formatted(method, fragment.getClass())); } try { @@ -127,4 +128,11 @@ public ProjectionFactory getProjectionFactory() { } }; } + + public static class MethodNotImplementedException extends RuntimeException { + + public MethodNotImplementedException(String message) { + super(message); + } + } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java index a2840ec268..e40f0cd53b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributorTests.java @@ -15,22 +15,27 @@ */ package org.springframework.data.mongodb.repository.aot; -import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatException; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import example.aot.User; import example.aot.UserProjection; import example.aot.UserRepository; import example.aot.UserRepository.UserAggregate; +import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.regex.Pattern; +import org.bson.BsonString; import org.bson.Document; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -39,20 +44,43 @@ import org.springframework.data.domain.OffsetScrollPosition; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageRequest; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Slice; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.domain.Window; +import org.springframework.data.geo.Box; +import org.springframework.data.geo.Circle; +import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoPage; +import org.springframework.data.geo.GeoResult; +import org.springframework.data.geo.GeoResults; +import org.springframework.data.geo.Metrics; +import org.springframework.data.geo.Point; +import org.springframework.data.geo.Polygon; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.MongoTemplate; import org.springframework.data.mongodb.core.aggregation.AggregationResults; -import org.springframework.data.mongodb.test.util.Client; -import org.springframework.data.mongodb.test.util.MongoClientExtension; +import org.springframework.data.mongodb.core.geo.GeoJsonPoint; +import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; +import org.springframework.data.mongodb.test.util.AtlasContainer; import org.springframework.data.mongodb.test.util.MongoTestUtils; import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; import org.springframework.util.StringUtils; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.shaded.org.awaitility.Awaitility; import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.MongoCollection; +import com.mongodb.client.model.IndexOptions; +import com.mongodb.client.model.SearchIndexModel; +import com.mongodb.client.model.SearchIndexType; /** * Integration tests for the {@link UserRepository} AOT fragment. @@ -60,13 +88,15 @@ * @author Christoph Strobl * @author Mark Paluch */ -@ExtendWith(MongoClientExtension.class) +@Testcontainers(disabledWithoutDocker = true) @SpringJUnitConfig(classes = MongoRepositoryContributorTests.MongoRepositoryContributorConfiguration.class) class MongoRepositoryContributorTests { + private static final @Container AtlasContainer atlasLocal = AtlasContainer.bestMatch(); private static final String DB_NAME = "aot-repo-tests"; + private static final String COLLECTION_NAME = "user"; - @Client static MongoClient client; + static MongoClient client; @Autowired UserRepository fragment; @Configuration @@ -82,8 +112,40 @@ MongoOperations mongoOperations() { } } + @BeforeAll + static void beforeAll() throws InterruptedException { + + client = MongoClients.create(atlasLocal.getConnectionString()); + MongoCollection userCollection = client.getDatabase(DB_NAME).getCollection(COLLECTION_NAME); + userCollection.createIndex(new Document("location.coordinates", "2d"), new IndexOptions()); + userCollection.createIndex(new Document("location.coordinates", "2dsphere"), new IndexOptions()); + + Document searchIndex = new Document("fields", + List.of(new Document("type", "vector").append("path", "embedding").append("numDimensions", 5) + .append("similarity", "cosine"), new Document("type", "filter").append("path", "last_name"))); + + userCollection.createSearchIndexes(List.of( + new SearchIndexModel("embedding.vector_cos", searchIndex, SearchIndexType.of(new BsonString("vectorSearch"))))); + + Awaitility.await().atMost(Duration.ofSeconds(120)).pollInterval(Duration.ofMillis(200)).until(() -> { + + List execute = userCollection + .aggregate( + List.of(Document.parse("{'$listSearchIndexes': { 'name' : '%s'}}".formatted("embedding.vector_cos")))) + .into(new ArrayList<>()); + for (Document doc : execute) { + if (doc.getString("name").equals("embedding.vector_cos")) { + return doc.getString("status").equals("READY"); + } + } + return false; + }); + + Thread.sleep(250); // just wait a little or the index will be broken + } + @BeforeEach - void beforeEach() { + void beforeEach() throws InterruptedException { MongoTestUtils.flushCollection(DB_NAME, "user", client); initUsers(); @@ -208,6 +270,13 @@ void testNot() { assertThat(users).extracting(User::getUsername).isNotEmpty().doesNotContain("luke", "vader"); } + @Test // GH-4939 + void testRegex() { + + List lukes = fragment.findByFirstnameRegex(Pattern.compile(".*uk.*")); + assertThat(lukes).extracting(User::getUsername).containsExactly("luke"); + } + @Test void testExistsCriteria() { @@ -313,6 +382,20 @@ void testAnnotatedFinderReturningSingleValueWithQuery() { assertThat(user).isNotNull().extracting(User::getUsername).isEqualTo("yoda"); } + @Test // GH-5006 + void testAnnotatedFinderWithExpressionUsingParameterIndex() { + + List users = fragment.findWithExpressionUsingParameterIndex("Luke"); + assertThat(users).extracting(User::getUsername).containsExactly("luke"); + } + + @Test // GH-5006 + void testAnnotatedFinderWithExpressionUsingParameterName() { + + List users = fragment.findWithExpressionUsingParameterName("Luke"); + assertThat(users).extracting(User::getUsername).containsExactly("luke"); + } + @Test void testAnnotatedCount() { @@ -592,7 +675,194 @@ void testAggregationWithCollation() { .withMessageContaining("'locale' is invalid"); } - private static void initUsers() { + @Test // GH-5004 + void testNear() { + + List users = fragment.findByLocationCoordinatesNear(new Point(-73.99171, 40.738868)); + assertThat(users).extracting(User::getUsername).containsExactly("leia", "vader"); + } + + @Test // GH-5004 + void testNearWithGeoJson() { + + List users = fragment.findByLocationCoordinatesNear(new GeoJsonPoint(-73.99171, 40.738868)); + assertThat(users).extracting(User::getUsername).containsExactly("leia", "vader"); + } + + @Test // GH-5004 + void testGeoWithinCircle() { + + List users = fragment.findByLocationCoordinatesWithin(new Circle(-78.99171, 45.738868, 170)); + assertThat(users).extracting(User::getUsername).containsExactly("leia", "vader"); + } + + @Test // GH-5004 + void testWithinBox() { + + Box box = new Box(new Point(-78.99171, 35.738868), new Point(-68.99171, 45.738868)); + + List result = fragment.findByLocationCoordinatesWithin(box); + assertThat(result).extracting(User::getUsername).containsExactly("leia", "vader"); + } + + @Test // GH-5004 + void findsPeopleByLocationWithinPolygon() { + + Point first = new Point(-78.99171, 35.738868); + Point second = new Point(-78.99171, 45.738868); + Point third = new Point(-68.99171, 45.738868); + Point fourth = new Point(-68.99171, 35.738868); + + List result = fragment.findByLocationCoordinatesWithin(new Polygon(first, second, third, fourth)); + assertThat(result).extracting(User::getUsername).containsExactly("leia", "vader"); + } + + @Test // GH-5004 + void findsPeopleByLocationWithinGeoJsonPolygon() { + + Point first = new Point(-78.99171, 35.738868); + Point second = new Point(-78.99171, 45.738868); + Point third = new Point(-68.99171, 45.738868); + Point fourth = new Point(-68.99171, 35.738868); + + List result = fragment + .findByLocationCoordinatesWithin(new GeoJsonPolygon(first, second, third, fourth, first)); + assertThat(result).extracting(User::getUsername).containsExactly("leia", "vader"); + } + + @Test // GH-5004 + void findsPeopleByLocationWithinSomeGenericGeoJsonObject() { + + Point first = new Point(-78.99171, 35.738868); + Point second = new Point(-78.99171, 45.738868); + Point third = new Point(-68.99171, 45.738868); + Point fourth = new Point(-68.99171, 35.738868); + + List result = fragment + .findUserByLocationCoordinatesWithin(new GeoJsonPolygon(first, second, third, fourth, first)); + assertThat(result).extracting(User::getUsername).containsExactly("leia", "vader"); + } + + @Test // GH-5004 + void testNearWithGeoResult() { + + GeoResults users = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), + Distance.of(5, Metrics.KILOMETERS)); + assertThat(users).extracting(GeoResult::getContent).extracting(User::getUsername).containsExactly("leia"); + } + + @Test // GH-5004 + void testNearWithAdditionalFilterQueryAsGeoResult() { + + GeoResults users = fragment.findByLocationCoordinatesNearAndLastname(new Point(-73.99, 40.73), + Distance.of(50, Metrics.KILOMETERS), "Organa"); + assertThat(users).extracting(GeoResult::getContent).extracting(User::getUsername).containsExactly("leia"); + } + + @Test // GH-5004 + void testNearReturningListOfGeoResult() { + + List> users = fragment.findUserAsListByLocationCoordinatesNear(new Point(-73.99, 40.73), + Distance.of(5, Metrics.KILOMETERS)); + assertThat(users).extracting(GeoResult::getContent).extracting(User::getUsername).containsExactly("leia"); + } + + @Test // GH-5004 + void testNearWithRange() { + + Range range = Distance.between(Distance.of(5, Metrics.KILOMETERS), Distance.of(2000, Metrics.KILOMETERS)); + GeoResults users = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), range); + + assertThat(users).extracting(GeoResult::getContent).extracting(User::getUsername).containsExactly("vader"); + } + + @Test // GH-5004 + void testNearReturningGeoPage() { + + GeoPage page1 = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(0, 1)); + + assertThat(page1.hasNext()).isTrue(); + + GeoPage page2 = fragment.findByLocationCoordinatesNear(new Point(-73.99, 40.73), + Distance.of(2000, Metrics.KILOMETERS), page1.nextPageable()); + assertThat(page2.hasNext()).isFalse(); + } + + @Test + void vectorSearchFromAnnotation() throws InterruptedException { + + Thread.sleep(1000); // srly - reindex for vector search + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.annotatedVectorSearch("Skywalker", vector, Score.of(0.99), Limit.of(10)); + + assertThat(results).hasSize(1); + } + + @Test + void vectorSearchWithDerivedQuery() throws InterruptedException { + + Thread.sleep(1000); // srly - reindex for vector search + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.searchCosineByLastnameAndEmbeddingNear("Skywalker", vector, Score.of(0.98), + Limit.of(10)); + + assertThat(results).hasSize(1); + } + + @Test + void vectorSearchReturningResultsAsList() throws InterruptedException { + + Thread.sleep(1000); // srly - reindex for vector search + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + List results = fragment.searchAsListByLastnameAndEmbeddingNear("Skywalker", vector, Limit.of(10)); + + assertThat(results).hasSize(2); + } + + @Test + void vectorSearchWithLimitFromAnnotation() throws InterruptedException { + + Thread.sleep(1000); // srly - reindex for vector search + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.searchByLastnameAndEmbeddingWithin("Skywalker", vector, + Similarity.between(0.4, 0.99)); + + assertThat(results).hasSize(1); + } + + @Test + void vectorSearchWithSorting() throws InterruptedException { + + Thread.sleep(1000); // srly - reindex for vector search + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.searchByLastnameAndEmbeddingWithinOrderByFirstname("Skywalker", vector, + Similarity.between(0.4, 1.0)); + + assertThat(results).hasSize(2); + } + + @Test + void vectorSearchWithLimitFromDerivedQuery() throws InterruptedException { + + Thread.sleep(1000); // srly - reindex for vector search + + Vector vector = Vector.of(1.00000d, 1.12345d, 2.23456d, 3.34567d, 4.45678d); + SearchResults results = fragment.searchTop1ByLastnameAndEmbeddingWithin("Skywalker", vector, + Similarity.between(0.4, 1.0)); + + assertThat(results).hasSize(1); + } + + /** + * GeoResults results = repository.findPersonByLocationNear(new Point(-73.99, 40.73), range); + */ + private static void initUsers() throws InterruptedException { Document luke = Document.parse(""" { @@ -612,6 +882,7 @@ private static void initUsers() { } } ], + "embedding" : [1.00000, 1.12345, 2.23456, 3.34567, 4.45678], "_class": "example.springdata.aot.User" }"""); @@ -621,6 +892,13 @@ private static void initUsers() { "username": "leia", "first_name": "Leia", "last_name": "Organa", + "location" : { + "planet" : "Coruscant", + "coordinates" : { + "x" : -73.99171, "y" : 40.738868 + } + }, + "embedding" : [1.0001, 2.12345, 3.23456, 4.34567, 5.45678], "_class": "example.springdata.aot.User" }"""); @@ -638,6 +916,7 @@ private static void initUsers() { } } ], + "embedding" : [2.0002, 3.12345, 4.23456, 5.34567, 6.45678], "_class": "example.springdata.aot.User" }"""); @@ -648,6 +927,7 @@ private static void initUsers() { "lastSeen" : { "$date": "2025-01-01T00:00:00.000Z" }, + "embedding" : [3.0003, 4.12345, 5.23456, 6.34567, 7.45678], "_class": "example.springdata.aot.User" }"""); @@ -670,7 +950,8 @@ private static void initUsers() { "$date": "2025-01-15T13:53:33.855Z" } } - ] + ], + "embedding" : [4.0004, 5.12345, 6.23456, 7.34567, 8.45678] }"""); Document vader = Document.parse(""" @@ -679,6 +960,12 @@ private static void initUsers() { "username": "vader", "first_name": "Anakin", "last_name": "Skywalker", + "location" : { + "planet" : "Death Star", + "coordinates" : { + "x" : -73.9, "y" : 40.7 + } + }, "visits" : 50, "posts": [ { @@ -687,7 +974,8 @@ private static void initUsers() { "$date": "2025-01-15T13:46:33.855Z" } } - ] + ], + "embedding" : [5.0005, 6.12345, 7.23456, 8.34567, 9.45678] }"""); Document kylo = Document.parse(""" @@ -695,7 +983,8 @@ private static void initUsers() { "_id": "id-7", "username": "kylo", "first_name": "Ben", - "last_name": "Solo" + "last_name": "Solo", + "embedding" : [6.0006, 7.12345, 8.23456, 9.34567, 10.45678] } """); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java index aa069a2710..7fb8870263 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryMetadataTests.java @@ -15,19 +15,19 @@ */ package org.springframework.data.mongodb.repository.aot; -import static net.javacrumbs.jsonunit.assertj.JsonAssertions.*; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; import example.aot.UserRepository; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.List; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -150,7 +150,7 @@ void shouldDocumentAggregation() throws IOException { assertThatJson(json).inPath("$.methods[?(@.name == 'findAllLastnames')].query").isArray().element(0).isObject() .containsEntry("pipeline", - "[{ '$match' : { 'last_name' : { '$ne' : null } } }, { '$project': { '_id' : '$last_name' } }]"); + List.of("{ '$match' : { 'last_name' : { '$ne' : null } } }", "{ '$project': { '_id' : '$last_name' } }")); } @Test // GH-4964 @@ -165,7 +165,7 @@ void shouldDocumentPipelineUpdate() throws IOException { assertThatJson(json).inPath("$.methods[?(@.name == 'findAndIncrementVisitsViaPipelineByLastname')].query").isArray() .element(0).isObject().containsEntry("filter", "{'lastname':?0}").containsEntry("update-pipeline", - "[{ '$set' : { 'visits' : { '$ifNull' : [ {'$add' : [ '$visits', ?1 ] }, ?1 ] } } }]"); + List.of("{ '$set' : { 'visits' : { '$ifNull' : [ {'$add' : [ '$visits', ?1 ] }, ?1 ] } } }")); } @Test // GH-4964 diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java new file mode 100644 index 0000000000..d8de601d4e --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/aot/QueryMethodContributionUnitTests.java @@ -0,0 +1,413 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.aot; + +import static org.assertj.core.api.Assertions.assertThat; + +import example.aot.User; +import example.aot.UserRepository; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.List; +import java.util.regex.Pattern; + +import javax.lang.model.element.Modifier; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; +import org.springframework.data.domain.Vector; +import org.springframework.data.geo.Box; +import org.springframework.data.geo.Circle; +import org.springframework.data.geo.Distance; +import org.springframework.data.geo.GeoResults; +import org.springframework.data.geo.Point; +import org.springframework.data.geo.Polygon; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.annotation.Collation; +import org.springframework.data.mongodb.core.geo.GeoJsonPolygon; +import org.springframework.data.mongodb.core.geo.Sphere; +import org.springframework.data.mongodb.repository.Hint; +import org.springframework.data.mongodb.repository.ReadPreference; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.repository.Repository; +import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext; +import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata; +import org.springframework.data.repository.aot.generate.MethodContributor; +import org.springframework.data.repository.core.RepositoryInformation; +import org.springframework.data.repository.query.QueryMethod; +import org.springframework.javapoet.ClassName; +import org.springframework.javapoet.FieldSpec; +import org.springframework.javapoet.MethodSpec; + +/** + * @author Christoph Strobl + */ +public class QueryMethodContributionUnitTests { + + @Test // GH-5004 + void rendersQueryForNearUsingPoint() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesNear", Point.class); + + assertThat(methodSpec.toString()) // + .contains("{'location.coordinates':{'$near':?0}}") // + .contains("arguments(location)") // + .contains("return finder.matching(filterQuery).all()"); + } + + @Test // GH-5004 + void rendersQueryForWithinUsingCircle() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", Circle.class); + + assertThat(methodSpec.toString()) // + .contains("{'location.coordinates':{'$geoWithin':{'$center':?0}}") // + .contains( + "List.of(circle.getCenter().getX(), circle.getCenter().getY()), circle.getRadius().getNormalizedValue())") // + .contains("return finder.matching(filterQuery).all()"); + } + + @Test // GH-5004 + void rendersQueryForWithinUsingSphere() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", Sphere.class); + + assertThat(methodSpec.toString()) // + .contains("{'location.coordinates':{'$geoWithin':{'$centerSphere':?0}}") // + .contains( + "List.of(circle.getCenter().getX(), circle.getCenter().getY()), circle.getRadius().getNormalizedValue())") // + .contains("return finder.matching(filterQuery).all()"); + } + + @Test // GH-5004 + void rendersQueryForWithinUsingBox() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", Box.class); + + assertThat(methodSpec.toString()) // + .contains("{'location.coordinates':{'$geoWithin':{'$box':?0}}") // + .contains("List.of(box.getFirst().getX(), box.getFirst().getY())") // + .contains("List.of(box.getSecond().getX(), box.getSecond().getY())") // + .contains("return finder.matching(filterQuery).all()"); + } + + @Test // GH-5004 + void rendersQueryForWithinUsingPolygon() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", Polygon.class); + + assertThat(methodSpec.toString()) // + .contains("{'location.coordinates':{'$geoWithin':{'$polygon':?0}}") // + .contains("polygon.getPoints().stream().map(_p ->") // + .contains("List.of(_p.getX(), _p.getY())") // + .contains("return finder.matching(filterQuery).all()"); + } + + @Test // GH-5004 + void rendersQueryForWithinUsingGeoJsonPolygon() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesWithin", GeoJsonPolygon.class); + + assertThat(methodSpec.toString()) // + .contains("{'location.coordinates':{'$geoWithin':{'$geometry':?0}}") // + .contains("arguments(polygon)") // + .contains("return finder.matching(filterQuery).all()"); + } + + @Test // GH-5004 + void rendersNearQueryForGeoResults() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, "findByLocationCoordinatesNear", Point.class, + Distance.class); + + assertThat(methodSpec.toString()) // + .contains("NearQuery.near(point)") // + .contains("nearQuery.maxDistance(maxDistance).in(maxDistance.getMetric())") // + .contains(".withReadPreference(com.mongodb.ReadPreference.valueOf(\"NEAREST\")") // + .doesNotContain("nearQuery.query(") // + .contains(".near(nearQuery)") // + .contains("return nearFinder.all()"); + } + + @Test // GH-5004 + void rendersNearQueryWithDistanceRangeForGeoResults() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesNear", Point.class, Range.class); + + assertThat(methodSpec.toString()) // + .contains("NearQuery.near(point)") // + .contains("if(distance.getLowerBound().isBounded())") // + .contains("nearQuery.minDistance(min).in(min.getMetric())") // + .contains("if(distance.getUpperBound().isBounded())") // + .contains("nearQuery.maxDistance(max).in(max.getMetric())") // + .contains(".near(nearQuery)") // + .contains("return nearFinder.all()"); + } + + @Test // GH-5004 + void rendersNearQueryReturningGeoPage() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesNear", Point.class, Distance.class, + Pageable.class); + + assertThat(methodSpec.toString()) // + .contains("NearQuery.near(point)") // + .contains("nearQuery.maxDistance(maxDistance).in(maxDistance.getMetric())") // + .doesNotContain("nearQuery.query(") // + .contains("var geoResult = nearFinder.all()") // + .contains("PageableExecutionUtils.getPage(geoResult.getContent(), pageable, () -> nearFinder.count())") + .contains("GeoPage<>(geoResult, pageable, resultPage.getTotalElements())"); + } + + @Test // GH-5004 + void rendersNearQueryWithFilterForGeoResults() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByLocationCoordinatesNearAndLastname", Point.class, + Distance.class, String.class); + + assertThat(methodSpec.toString()) // + .contains("NearQuery.near(point)") // + .contains("nearQuery.maxDistance(maxDistance).in(maxDistance.getMetric())") // + .contains("filterQuery = createQuery(\"{'lastname':?0}\", arguments(lastname))") // + .contains("nearQuery.query(filterQuery)") // + .contains(".near(nearQuery)") // + .contains("return nearFinder.all()"); + } + + @Test // GH-5006 + void rendersExpressionUsingParameterIndex() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findWithExpressionUsingParameterIndex", String.class); + + assertThat(methodSpec.toString()) // + .contains("createQuery(\"{ firstname : ?#{[0]} }\", argumentMap(\"firstname\", firstname))"); + } + + @Test // GH-5006 + void rendersExpressionUsingParameterName() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findWithExpressionUsingParameterName", String.class); + + assertThat(methodSpec.toString()) // + .contains("createQuery(\"{ firstname : :#{#firstname} }\", argumentMap(\"firstname\", firstname))"); + } + + @Test // GH-4939 + void rendersRegexCriteria() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "findByFirstnameRegex", Pattern.class); + + assertThat(methodSpec.toString()) // + .contains("createQuery(\"{'firstname':{'$regex':?0}}\", arguments(pattern))"); + } + + @Test // GH-4939 + void rendersHint() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, "findByFirstname", String.class); + + assertThat(methodSpec.toString()) // + .contains(".withHint(\"fn-idx\")"); + } + + @Test // GH-4939 + void rendersCollation() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, "findByFirstname", String.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence(".collation(", "Collation.parse(\"en_US\"))"); + } + + @Test // GH-4939 + void rendersCollationFromExpression() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, "findWithCollationByFirstname", String.class, String.class); + + assertThat(methodSpec.toString()) // + .containsIgnoringWhitespaces( + "collationOf(evaluate(\"?#{[1]}\", argumentMap(\"firstname\", firstname, \"locale\", locale)))"); + } + + @Test + void rendersVectorSearchFilterFromAnnotatedQuery() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class, + Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence("$vectorSearch =", + "Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(limit);") + .contains("filter = createQuery(\"{lastname: ?0}\", arguments(lastname, distance))") + .contains("$vectorSearch.filter(filter.getQueryObject())"); + } + + @Test + void rendersVectorSearchNumCandidatesExpression() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class, + Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence("$vectorSearch.numCandidates", + "evaluate(\"#{10+10}\", argumentMap(\"lastname\", lastname, \"distance\", distance)))"); + } + + @Test + void rendersVectorSearchScoringFunctionFromScore() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class, + Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .contains("ScoringFunction scoringFunction = distance.getFunction()"); + } + + @Test + void rendersVectorSearchSearchTypeFromAnnotation() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "annotatedVectorSearch", String.class, Vector.class, + Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence("$vectorSearch.searchType(", "VectorSearchOperation.SearchType.ANN)"); + } + + @Test + void rendersVectorSearchQueryFromMethodName() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchCosineByLastnameAndEmbeddingNear", String.class, + Vector.class, Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .contains("filter = createQuery(\"{'lastname':?0}\", arguments(lastname, similarity))"); + } + + @Test + void rendersVectorSearchNumCandidatesFromLimitIfNotExplicitlyDefined() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchCosineByLastnameAndEmbeddingNear", String.class, + Vector.class, Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .contains("$vectorSearch.numCandidates(limit.max() * 20)"); + } + + @Test + void rendersVectorSearchLimitFromAnnotation() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchByLastnameAndEmbeddingWithin", String.class, + Vector.class, Range.class); + + assertThat(methodSpec.toString()) // + .contains("Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(10)") + .contains("$vectorSearch.numCandidates(10 * 20)"); + } + + @Test + void rendersVectorSearchLimitFromExpression() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepoWithMeta.class, + "searchWithLimitAsExpressionByLastnameAndEmbeddingWithinOrderByFirstname", String.class, Vector.class, + Range.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence( + "Aggregation.vectorSearch(\"embedding.vector_cos\").path(\"embedding\").vector(vector).limit(", + "evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance)") + .containsSubsequence("$vectorSearch.numCandidates(", + "evaluate(\"#{5+5}\", argumentMap(\"lastname\", lastname, \"distance\", distance))) * 20)"); + } + + @Test + void rendersVectorSearchOrderByScoreAsDefault() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchCosineByLastnameAndEmbeddingNear", String.class, + Vector.class, Score.class, Limit.class); + + assertThat(methodSpec.toString()) // + .contains("$vectorSearch.withSearchScore(\"__score__\")") + .containsSubsequence("$sort = ", "Aggregation.sort(", "DESC, \"__score__\")") + .containsSubsequence("AggregationPipeline(", "List.of($vectorSearch, $sort))"); + } + + @Test + void rendersVectorSearchOrderByWithScoreLast() throws NoSuchMethodException { + + MethodSpec methodSpec = codeOf(UserRepository.class, "searchByLastnameAndEmbeddingWithinOrderByFirstname", + String.class, Vector.class, Range.class); + + assertThat(methodSpec.toString()) // + .containsSubsequence("AggregationOperation $sort = (_ctx) -> {", // + "_mappedSort = _ctx.getMappedObject(", // + "Document.parse(\"{'firstname':{'$numberInt':'1'}}\")", // + "Document(\"$sort\", _mappedSort.append(\"__score__\", -1))"); + } + + private static MethodSpec codeOf(Class repository, String methodName, Class... args) + throws NoSuchMethodException { + + Method method = repository.getMethod(methodName, args); + + TestMongoAotRepositoryContext repoContext = new TestMongoAotRepositoryContext(repository, null); + MongoRepositoryContributor contributor = new MongoRepositoryContributor(repoContext); + MethodContributor methodContributor = contributor.contributeQueryMethod(method); + + if (methodContributor == null) { + Assertions.fail("No contribution for method %s.%s(%s)".formatted(repository.getSimpleName(), methodName, + Arrays.stream(args).map(Class::getSimpleName).toList())); + } + AotRepositoryFragmentMetadata metadata = new AotRepositoryFragmentMetadata(ClassName.get(repository)); + metadata.addField( + FieldSpec.builder(MongoOperations.class, "mongoOperations", Modifier.PRIVATE, Modifier.FINAL).build()); + + TestQueryMethodGenerationContext methodContext = new TestQueryMethodGenerationContext( + repoContext.getRepositoryInformation(), method, methodContributor.getQueryMethod(), metadata); + return methodContributor.contribute(methodContext); + } + + static class TestQueryMethodGenerationContext extends AotQueryMethodGenerationContext { + + protected TestQueryMethodGenerationContext(RepositoryInformation repositoryInformation, Method method, + QueryMethod queryMethod, AotRepositoryFragmentMetadata targetTypeMetadata) { + super(repositoryInformation, method, queryMethod, targetTypeMetadata); + } + } + + interface UserRepoWithMeta extends Repository { + + @Hint("fn-idx") + @Collation("en_US") + List findByFirstname(String firstname); + + @Collation("?#{[1]}") + List findWithCollationByFirstname(String firstname, String locale); + + @ReadPreference("NEAREST") + GeoResults findByLocationCoordinatesNear(Point point, Distance maxDistance); + + @VectorSearch(indexName = "embedding.vector_cos", limit = "#{5+5}") + SearchResults searchWithLimitAsExpressionByLastnameAndEmbeddingWithinOrderByFirstname(String lastname, + Vector vector, Range distance); + } +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java index 2c0c996bc3..11a025ea5c 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java @@ -32,7 +32,6 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; - import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; import org.springframework.data.geo.Distance; @@ -171,7 +170,6 @@ void pagingGeoExecutionRetrievesObjectsForPageableOutOfRange() { when(mongoOperationsMock.query(any(Class.class))).thenReturn(findOperationMock); when(findOperationMock.near(any(NearQuery.class))).thenReturn(terminatingGeoMock); doReturn(new GeoResults<>(Collections.emptyList())).when(terminatingGeoMock).all(); - doReturn(terminatingMock).when(findOperationMock).matching(any(Query.class)); ConvertingParameterAccessor accessor = new ConvertingParameterAccessor(converter, new MongoParametersParameterAccessor(queryMethod, new Object[] { POINT, DISTANCE, PageRequest.of(2, 10) })); @@ -183,7 +181,7 @@ void pagingGeoExecutionRetrievesObjectsForPageableOutOfRange() { execution.execute(new Query()); verify(terminatingGeoMock).all(); - verify(terminatingMock).count(); + verify(terminatingGeoMock).count(); } @Test // DATAMONGO-2351 diff --git a/spring-data-mongodb/src/test/resources/logback.xml b/spring-data-mongodb/src/test/resources/logback.xml index 55e4309a36..d0907937fa 100644 --- a/spring-data-mongodb/src/test/resources/logback.xml +++ b/spring-data-mongodb/src/test/resources/logback.xml @@ -20,8 +20,9 @@ - + + diff --git a/src/main/antora/modules/ROOT/pages/mongodb/repositories/query-methods.adoc b/src/main/antora/modules/ROOT/pages/mongodb/repositories/query-methods.adoc index adb2392f04..31a19b5aca 100644 --- a/src/main/antora/modules/ROOT/pages/mongodb/repositories/query-methods.adoc +++ b/src/main/antora/modules/ROOT/pages/mongodb/repositories/query-methods.adoc @@ -209,9 +209,9 @@ NOTE: If the property criterion compares a document, the order of the fields and == Geo-spatial Queries As you saw in the preceding table of keywords, a few keywords trigger geo-spatial operations within a MongoDB query. -The `Near` keyword allows some further modification, as the next few examples show. +The `Near` and `Within` keywords allows some further modification, as the next few examples show. -The following example shows how to define a `near` query that finds all persons with a given distance of a given point: +The following example shows how to define a `near` / `within` query that finds all persons using different shapes: .Advanced `Near` queries [tabs] @@ -222,8 +222,20 @@ Imperative:: ---- public interface PersonRepository extends MongoRepository { - // { 'location' : { '$near' : [point.x, point.y], '$maxDistance' : distance}} + // { 'location' : { '$near' : [point.x, point.y], '$maxDistance' : distance } } List findByLocationNear(Point location, Distance distance); + + // { 'location' : { $geoWithin: { $center: [ [ circle.center.x, circle.center.y ], circle.radius ] } } } + List findByLocationWithin(Circle circle); + + // { 'location' : { $geoWithin: { $box: [ [ box.first.x, box.first.y ], [ box.second.x, box.second.y ] ] } } } + List findByLocationWithin(Box box); + + // { 'location' : { $geoWithin: { $polygon: [ [ polygon.x1, polygon.y1 ], [ polygon.x2, polygon.y2 ], ... ] } } } + List findByLocationWithin(Polygon polygon); + + // { 'location' : { $geoWithin: { $geometry: { $type : 'polygon', coordinates: [[ polygon.x1, polygon.y1 ], [ polygon.x2, polygon.y2 ], ... ] } } } } + List findByLocationWithin(GeoJsonPolygon polygon); } ---- @@ -233,8 +245,20 @@ Reactive:: ---- interface PersonRepository extends ReactiveMongoRepository { - // { 'location' : { '$near' : [point.x, point.y], '$maxDistance' : distance}} + // { 'location' : { '$near' : [point.x, point.y], '$maxDistance' : distance } } Flux findByLocationNear(Point location, Distance distance); + + // { 'location' : { $geoWithin: { $center: [ [ circle.center.x, circle.center.y ], circle.radius ] } } } + Flux findByLocationWithin(Circle circle); + + // { 'location' : { $geoWithin: { $box: [ [ box.first.x, box.first.y ], [ box.second.x, box.second.y ] ] } } } + Flux findByLocationWithin(Box box); + + // { 'location' : { $geoWithin: { $polygon: [ [ polygon.x1, polygon.y1 ], [ polygon.x2, polygon.y2 ], ... ] } } } + Flux findByLocationWithin(Polygon polygon); + + // { 'location' : { $geoWithin: { $geometry: { $type : 'polygon', coordinates: [[ polygon.x1, polygon.y1 ], [ polygon.x2, polygon.y2 ], ... ] } } } } + Flux findByLocationWithin(GeoJsonPolygon polygon); } ---- ======