Skip to content

Allow query methods get multiple parameters from Map or JavaBean and add them to the SqlParameterSource. #1564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Copyright 2020-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.jdbc.repository.query;

import org.springframework.jdbc.core.namedparam.SqlParameterSource;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
* Indicates a method parameter, which may be a Map or JavaBean,
* should be regarded as multiple parameters and added to the {@link SqlParameterSource}.
*
* @author Zhou Xingyi
*/
@Target(ElementType.PARAMETER)
@Retention(RetentionPolicy.RUNTIME)
public @interface Multiparameter {
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
import static org.springframework.data.jdbc.repository.query.JdbcQueryExecution.*;

import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.sql.SQLType;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.core.convert.converter.Converter;
import org.springframework.data.jdbc.core.convert.JdbcColumnTypes;
Expand All @@ -39,12 +42,14 @@
import org.springframework.data.repository.query.ResultProcessor;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.BeanPropertySqlParameterSource;
import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;

/**
Expand All @@ -58,6 +63,7 @@
* @author Mark Paluch
* @author Hebert Coelho
* @author Chirag Tailor
* @author Zhou Xingyii
* @since 2.0
*/
public class StringBasedJdbcQuery extends AbstractJdbcQuery {
Expand Down Expand Up @@ -162,43 +168,67 @@ private void convertAndAddParameter(MapSqlParameterSource parameters, Parameter

String parameterName = p.getName().orElseThrow(() -> new IllegalStateException(PARAMETER_NEEDS_TO_BE_NAMED));

RelationalParameters.RelationalParameter parameter = queryMethod.getParameters().getParameter(p.getIndex());
ResolvableType resolvableType = parameter.getResolvableType();
Class<?> type = resolvableType.resolve();
Assert.notNull(type, "@Query parameter type could not be resolved!");
final Field methodParameterField = ReflectionUtils.findField(Parameter.class, "parameter", MethodParameter.class);
Assert.state(methodParameterField != null, "MethodParameter must not be null");
ReflectionUtils.makeAccessible(methodParameterField);
final MethodParameter methodParameter = (MethodParameter) ReflectionUtils.getField(methodParameterField, p);
if (methodParameter != null && methodParameter.hasParameterAnnotation(Multiparameter.class)) {
if (value instanceof Map<?, ?>) {
final Map<?, ?> m = (Map<?, ?>) value;
m.forEach((propertyName, v) -> parameters.addValue(parameterName + '.' + propertyName, v));
} else {
final BeanPropertySqlParameterSource parameterSource = new BeanPropertySqlParameterSource(value);
for (String propertyName : parameterSource.getParameterNames()) {
final String newParameterName = parameterName + '.' + propertyName;
final Object parameterValue = parameterSource.getValue(propertyName);
final int sqlType = parameterSource.getSqlType(propertyName);
final String typeName = parameterSource.getTypeName(propertyName);
if (typeName == null) {
parameters.addValue(newParameterName, parameterValue, sqlType);
} else {
parameters.addValue(newParameterName, parameterValue, sqlType, typeName);
}
}
}
} else {
RelationalParameters.RelationalParameter parameter = queryMethod.getParameters().getParameter(p.getIndex());
ResolvableType resolvableType = parameter.getResolvableType();
Class<?> type = resolvableType.resolve();
Assert.notNull(type, "@Query parameter type could not be resolved!");

JdbcValue jdbcValue;
if (value instanceof Iterable) {
JdbcValue jdbcValue;
if (value instanceof Iterable) {

List<Object> mapped = new ArrayList<>();
SQLType jdbcType = null;
List<Object> mapped = new ArrayList<>();
SQLType jdbcType = null;

Class<?> elementType = resolvableType.getGeneric(0).resolve();
Class<?> elementType = resolvableType.getGeneric(0).resolve();

Assert.notNull(elementType, "@Query Iterable parameter generic type could not be resolved!");
Assert.notNull(elementType, "@Query Iterable parameter generic type could not be resolved!");

for (Object o : (Iterable<?>) value) {
JdbcValue elementJdbcValue = converter.writeJdbcValue(o, elementType,
JdbcUtil.targetSqlTypeFor(JdbcColumnTypes.INSTANCE.resolvePrimitiveType(elementType)));
if (jdbcType == null) {
jdbcType = elementJdbcValue.getJdbcType();
for (Object o : (Iterable<?>) value) {
JdbcValue elementJdbcValue = converter.writeJdbcValue(o, elementType,
JdbcUtil.targetSqlTypeFor(JdbcColumnTypes.INSTANCE.resolvePrimitiveType(elementType)));
if (jdbcType == null) {
jdbcType = elementJdbcValue.getJdbcType();
}

mapped.add(elementJdbcValue.getValue());
}

mapped.add(elementJdbcValue.getValue());
jdbcValue = JdbcValue.of(mapped, jdbcType);
} else {
jdbcValue = converter.writeJdbcValue(value, type,
JdbcUtil.targetSqlTypeFor(JdbcColumnTypes.INSTANCE.resolvePrimitiveType(type)));
}

jdbcValue = JdbcValue.of(mapped, jdbcType);
} else {
jdbcValue = converter.writeJdbcValue(value, type,
JdbcUtil.targetSqlTypeFor(JdbcColumnTypes.INSTANCE.resolvePrimitiveType(type)));
}

SQLType jdbcType = jdbcValue.getJdbcType();
if (jdbcType == null) {
SQLType jdbcType = jdbcValue.getJdbcType();
if (jdbcType == null) {

parameters.addValue(parameterName, jdbcValue.getValue());
} else {
parameters.addValue(parameterName, jdbcValue.getValue(), jdbcType.getVendorTypeNumber());
parameters.addValue(parameterName, jdbcValue.getValue());
} else {
parameters.addValue(parameterName, jdbcValue.getValue(), jdbcType.getVendorTypeNumber());
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,19 @@

import static java.util.Arrays.*;
import static org.assertj.core.api.Assertions.*;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.*;

import java.lang.reflect.Method;
import java.sql.JDBCType;
import java.sql.ResultSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.stream.Stream;
Expand Down Expand Up @@ -50,12 +57,15 @@
import org.springframework.data.repository.Repository;
import org.springframework.data.repository.core.support.DefaultRepositoryMetadata;
import org.springframework.data.repository.core.support.PropertiesBasedNamedQueries;
import org.springframework.data.repository.query.Param;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
import org.springframework.jdbc.core.namedparam.SqlParameterSource;
import org.springframework.util.ReflectionUtils;

import lombok.Getter;

/**
* Unit tests for {@link StringBasedJdbcQuery}.
*
Expand Down Expand Up @@ -231,6 +241,33 @@ void doesNotConvertNonCollectionParameter() {
assertThat(sqlParameterSource.getValue("value")).isEqualTo(1);
}

@Test
void convertMapAndJavaBeanParameter() {
JdbcQueryMethod queryMethod = createMethod("queryMethodWithQueryParameters", Map.class, PageInfo.class);
BasicJdbcConverter converter = new BasicJdbcConverter(mock(RelationalMappingContext.class), mock(RelationResolver.class));
StringBasedJdbcQuery query = new StringBasedJdbcQuery(queryMethod, operations, result -> mock(RowMapper.class), converter);

Map<String, Object> queryParams = new HashMap<>(1);
queryParams.put("status", "BLOCKED");

PageInfo pageInfo = new PageInfo(5L, 15);

query.execute(new Object[] { queryParams, pageInfo });

ArgumentCaptor<SqlParameterSource> captor = ArgumentCaptor.forClass(SqlParameterSource.class);
verify(operations).queryForObject(anyString(), captor.capture(), any(RowMapper.class));

SqlParameterSource sqlParameterSource = captor.getValue();
assertTrue(sqlParameterSource.hasValue("queryParams.status"));
assertEquals(queryParams.get("status"), sqlParameterSource.getValue("queryParams.status"));
assertTrue(sqlParameterSource.hasValue("page.size"));
assertEquals(pageInfo.getSize(), sqlParameterSource.getValue("page.size"));
assertTrue(sqlParameterSource.hasValue("page.pageNumber"));
assertEquals(pageInfo.getPageNumber(), sqlParameterSource.getValue("page.pageNumber"));
assertTrue(sqlParameterSource.hasValue("page.offset"));
assertEquals(pageInfo.getOffset(), sqlParameterSource.getValue("page.offset"));
}

private JdbcQueryMethod createMethod(String methodName, Class<?>... paramTypes) {

Method method = ReflectionUtils.findMethod(MyRepository.class, methodName, paramTypes);
Expand Down Expand Up @@ -276,6 +313,11 @@ interface MyRepository extends Repository<Object, Long> {

@Query(value = "some sql statement")
List<Object> findBySimpleValue(Integer value);

@Query("SELECT something FROM table_name WHERE status = queryParams.status LIMIT page.size OFFSET page.offset")
List<Object> queryMethodWithQueryParameters(
@Multiparameter Map<String, ?> queryParams,
@Param("page") @Multiparameter PageInfo pageInfo);
}

private static class CustomRowMapper implements RowMapper<Object> {
Expand Down Expand Up @@ -364,4 +406,16 @@ Long getId() {
return id;
}
}

@Getter
private static class PageInfo {
private final long pageNumber;
private final int size;
private final long offset;
public PageInfo(long pageNumber, int size) {
this.pageNumber = pageNumber;
this.size = size;
this.offset = (pageNumber - 1) * size;
}
}
}