diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/Multiparameter.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/Multiparameter.java new file mode 100644 index 0000000000..accf004607 --- /dev/null +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/Multiparameter.java @@ -0,0 +1,34 @@ +/* + * Copyright 2020-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.jdbc.repository.query; + +import org.springframework.jdbc.core.namedparam.SqlParameterSource; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Indicates a method parameter, which may be a Map or JavaBean, + * should be regarded as multiple parameters and added to the {@link SqlParameterSource}. + * + * @author Zhou Xingyi + */ +@Target(ElementType.PARAMETER) +@Retention(RetentionPolicy.RUNTIME) +public @interface Multiparameter { +} diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQuery.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQuery.java index 58ac579b66..491a0051f9 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQuery.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQuery.java @@ -18,12 +18,15 @@ import static org.springframework.data.jdbc.repository.query.JdbcQueryExecution.*; import java.lang.reflect.Constructor; +import java.lang.reflect.Field; import java.sql.SQLType; import java.util.ArrayList; import java.util.List; +import java.util.Map; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.BeanFactory; +import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; import org.springframework.core.convert.converter.Converter; import org.springframework.data.jdbc.core.convert.JdbcColumnTypes; @@ -39,12 +42,14 @@ import org.springframework.data.repository.query.ResultProcessor; import org.springframework.jdbc.core.ResultSetExtractor; import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.namedparam.BeanPropertySqlParameterSource; import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; +import org.springframework.util.ReflectionUtils; import org.springframework.util.StringUtils; /** @@ -58,6 +63,7 @@ * @author Mark Paluch * @author Hebert Coelho * @author Chirag Tailor + * @author Zhou Xingyii * @since 2.0 */ public class StringBasedJdbcQuery extends AbstractJdbcQuery { @@ -162,43 +168,67 @@ private void convertAndAddParameter(MapSqlParameterSource parameters, Parameter String parameterName = p.getName().orElseThrow(() -> new IllegalStateException(PARAMETER_NEEDS_TO_BE_NAMED)); - RelationalParameters.RelationalParameter parameter = queryMethod.getParameters().getParameter(p.getIndex()); - ResolvableType resolvableType = parameter.getResolvableType(); - Class type = resolvableType.resolve(); - Assert.notNull(type, "@Query parameter type could not be resolved!"); + final Field methodParameterField = ReflectionUtils.findField(Parameter.class, "parameter", MethodParameter.class); + Assert.state(methodParameterField != null, "MethodParameter must not be null"); + ReflectionUtils.makeAccessible(methodParameterField); + final MethodParameter methodParameter = (MethodParameter) ReflectionUtils.getField(methodParameterField, p); + if (methodParameter != null && methodParameter.hasParameterAnnotation(Multiparameter.class)) { + if (value instanceof Map) { + final Map m = (Map) value; + m.forEach((propertyName, v) -> parameters.addValue(parameterName + '.' + propertyName, v)); + } else { + final BeanPropertySqlParameterSource parameterSource = new BeanPropertySqlParameterSource(value); + for (String propertyName : parameterSource.getParameterNames()) { + final String newParameterName = parameterName + '.' + propertyName; + final Object parameterValue = parameterSource.getValue(propertyName); + final int sqlType = parameterSource.getSqlType(propertyName); + final String typeName = parameterSource.getTypeName(propertyName); + if (typeName == null) { + parameters.addValue(newParameterName, parameterValue, sqlType); + } else { + parameters.addValue(newParameterName, parameterValue, sqlType, typeName); + } + } + } + } else { + RelationalParameters.RelationalParameter parameter = queryMethod.getParameters().getParameter(p.getIndex()); + ResolvableType resolvableType = parameter.getResolvableType(); + Class type = resolvableType.resolve(); + Assert.notNull(type, "@Query parameter type could not be resolved!"); - JdbcValue jdbcValue; - if (value instanceof Iterable) { + JdbcValue jdbcValue; + if (value instanceof Iterable) { - List mapped = new ArrayList<>(); - SQLType jdbcType = null; + List mapped = new ArrayList<>(); + SQLType jdbcType = null; - Class elementType = resolvableType.getGeneric(0).resolve(); + Class elementType = resolvableType.getGeneric(0).resolve(); - Assert.notNull(elementType, "@Query Iterable parameter generic type could not be resolved!"); + Assert.notNull(elementType, "@Query Iterable parameter generic type could not be resolved!"); - for (Object o : (Iterable) value) { - JdbcValue elementJdbcValue = converter.writeJdbcValue(o, elementType, - JdbcUtil.targetSqlTypeFor(JdbcColumnTypes.INSTANCE.resolvePrimitiveType(elementType))); - if (jdbcType == null) { - jdbcType = elementJdbcValue.getJdbcType(); + for (Object o : (Iterable) value) { + JdbcValue elementJdbcValue = converter.writeJdbcValue(o, elementType, + JdbcUtil.targetSqlTypeFor(JdbcColumnTypes.INSTANCE.resolvePrimitiveType(elementType))); + if (jdbcType == null) { + jdbcType = elementJdbcValue.getJdbcType(); + } + + mapped.add(elementJdbcValue.getValue()); } - mapped.add(elementJdbcValue.getValue()); + jdbcValue = JdbcValue.of(mapped, jdbcType); + } else { + jdbcValue = converter.writeJdbcValue(value, type, + JdbcUtil.targetSqlTypeFor(JdbcColumnTypes.INSTANCE.resolvePrimitiveType(type))); } - jdbcValue = JdbcValue.of(mapped, jdbcType); - } else { - jdbcValue = converter.writeJdbcValue(value, type, - JdbcUtil.targetSqlTypeFor(JdbcColumnTypes.INSTANCE.resolvePrimitiveType(type))); - } - - SQLType jdbcType = jdbcValue.getJdbcType(); - if (jdbcType == null) { + SQLType jdbcType = jdbcValue.getJdbcType(); + if (jdbcType == null) { - parameters.addValue(parameterName, jdbcValue.getValue()); - } else { - parameters.addValue(parameterName, jdbcValue.getValue(), jdbcType.getVendorTypeNumber()); + parameters.addValue(parameterName, jdbcValue.getValue()); + } else { + parameters.addValue(parameterName, jdbcValue.getValue(), jdbcType.getVendorTypeNumber()); + } } } diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQueryUnitTests.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQueryUnitTests.java index 603c6138ef..67fb7b5110 100644 --- a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQueryUnitTests.java +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/repository/query/StringBasedJdbcQueryUnitTests.java @@ -17,12 +17,19 @@ import static java.util.Arrays.*; import static org.assertj.core.api.Assertions.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; import java.lang.reflect.Method; import java.sql.JDBCType; import java.sql.ResultSet; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Properties; import java.util.Set; import java.util.stream.Stream; @@ -50,12 +57,15 @@ import org.springframework.data.repository.Repository; import org.springframework.data.repository.core.support.DefaultRepositoryMetadata; import org.springframework.data.repository.core.support.PropertiesBasedNamedQueries; +import org.springframework.data.repository.query.Param; import org.springframework.jdbc.core.ResultSetExtractor; import org.springframework.jdbc.core.RowMapper; import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations; import org.springframework.jdbc.core.namedparam.SqlParameterSource; import org.springframework.util.ReflectionUtils; +import lombok.Getter; + /** * Unit tests for {@link StringBasedJdbcQuery}. * @@ -231,6 +241,33 @@ void doesNotConvertNonCollectionParameter() { assertThat(sqlParameterSource.getValue("value")).isEqualTo(1); } + @Test + void convertMapAndJavaBeanParameter() { + JdbcQueryMethod queryMethod = createMethod("queryMethodWithQueryParameters", Map.class, PageInfo.class); + BasicJdbcConverter converter = new BasicJdbcConverter(mock(RelationalMappingContext.class), mock(RelationResolver.class)); + StringBasedJdbcQuery query = new StringBasedJdbcQuery(queryMethod, operations, result -> mock(RowMapper.class), converter); + + Map queryParams = new HashMap<>(1); + queryParams.put("status", "BLOCKED"); + + PageInfo pageInfo = new PageInfo(5L, 15); + + query.execute(new Object[] { queryParams, pageInfo }); + + ArgumentCaptor captor = ArgumentCaptor.forClass(SqlParameterSource.class); + verify(operations).queryForObject(anyString(), captor.capture(), any(RowMapper.class)); + + SqlParameterSource sqlParameterSource = captor.getValue(); + assertTrue(sqlParameterSource.hasValue("queryParams.status")); + assertEquals(queryParams.get("status"), sqlParameterSource.getValue("queryParams.status")); + assertTrue(sqlParameterSource.hasValue("page.size")); + assertEquals(pageInfo.getSize(), sqlParameterSource.getValue("page.size")); + assertTrue(sqlParameterSource.hasValue("page.pageNumber")); + assertEquals(pageInfo.getPageNumber(), sqlParameterSource.getValue("page.pageNumber")); + assertTrue(sqlParameterSource.hasValue("page.offset")); + assertEquals(pageInfo.getOffset(), sqlParameterSource.getValue("page.offset")); + } + private JdbcQueryMethod createMethod(String methodName, Class... paramTypes) { Method method = ReflectionUtils.findMethod(MyRepository.class, methodName, paramTypes); @@ -276,6 +313,11 @@ interface MyRepository extends Repository { @Query(value = "some sql statement") List findBySimpleValue(Integer value); + + @Query("SELECT something FROM table_name WHERE status = queryParams.status LIMIT page.size OFFSET page.offset") + List queryMethodWithQueryParameters( + @Multiparameter Map queryParams, + @Param("page") @Multiparameter PageInfo pageInfo); } private static class CustomRowMapper implements RowMapper { @@ -364,4 +406,16 @@ Long getId() { return id; } } + + @Getter + private static class PageInfo { + private final long pageNumber; + private final int size; + private final long offset; + public PageInfo(long pageNumber, int size) { + this.pageNumber = pageNumber; + this.size = size; + this.offset = (pageNumber - 1) * size; + } + } }