diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationConsentService.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationConsentService.java new file mode 100644 index 000000000..2c7563974 --- /dev/null +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationConsentService.java @@ -0,0 +1,249 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Types; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.Function; + +import org.springframework.dao.DataRetrievalFailureException; +import org.springframework.jdbc.core.ArgumentPreparedStatementSetter; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.PreparedStatementSetter; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.SqlParameterValue; +import org.springframework.lang.Nullable; +import org.springframework.security.core.GrantedAuthority; +import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * A JDBC implementation of an {@link OAuth2AuthorizationConsentService} that uses a + *

+ * {@link JdbcOperations} for {@link OAuth2AuthorizationConsent} persistence. + * + *

+ * NOTE: This {@code OAuth2AuthorizationConsentService} depends on the table definition + * described in + * "classpath:org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql" and + * therefore MUST be defined in the database schema. + * + * @author Ovidiu Popa + * @see OAuth2AuthorizationConsentService + * @see OAuth2AuthorizationConsent + * @see JdbcOperations + * @see RowMapper + * @since 0.1.2 + */ +public final class JdbcOAuth2AuthorizationConsentService implements OAuth2AuthorizationConsentService { + + // @formatter:off + private static final String COLUMN_NAMES = "registered_client_id, " + + "principal_name, " + + "authorities"; + // @formatter:on + + private static final String TABLE_NAME = "oauth2_authorization_consent"; + + private static final String PK_FILTER = "registered_client_id = ? AND principal_name = ?"; + + // @formatter:off + private static final String LOAD_AUTHORIZATION_CONSENT_SQL = "SELECT " + COLUMN_NAMES + + " FROM " + TABLE_NAME + + " WHERE " + PK_FILTER; + // @formatter:on + + // @formatter:off + private static final String SAVE_AUTHORIZATION_CONSENT_SQL = "INSERT INTO " + TABLE_NAME + + " (" + COLUMN_NAMES + ") VALUES (?, ?, ?)"; + // @formatter:on + + // @formatter:off + private static final String UPDATE_AUTHORIZATION_CONSENT_SQL = "UPDATE " + TABLE_NAME + + " SET authorities = ?" + + " WHERE " + PK_FILTER; + // @formatter:on + + private static final String REMOVE_AUTHORIZATION_CONSENT_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER; + + private final JdbcOperations jdbcOperations; + private RowMapper authorizationConsentRowMapper; + private Function> authorizationConsentParametersMapper; + + /** + * Constructs a {@code JdbcOAuth2AuthorizationConsentService} using the provided parameters. + * + * @param jdbcOperations the JDBC operations + * @param registeredClientRepository the registered client repository + */ + public JdbcOAuth2AuthorizationConsentService(JdbcOperations jdbcOperations, + RegisteredClientRepository registeredClientRepository) { + Assert.notNull(jdbcOperations, "jdbcOperations cannot be null"); + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + this.jdbcOperations = jdbcOperations; + this.authorizationConsentRowMapper = new OAuth2AuthorizationConsentRowMapper(registeredClientRepository); + this.authorizationConsentParametersMapper = new OAuth2AuthorizationConsentParametersMapper(); + } + + @Override + public void save(OAuth2AuthorizationConsent authorizationConsent) { + Assert.notNull(authorizationConsent, "authorizationConsent cannot be null"); + + OAuth2AuthorizationConsent existingAuthorizationConsent = + findById(authorizationConsent.getRegisteredClientId(), authorizationConsent.getPrincipalName()); + + if (existingAuthorizationConsent == null) { + insertAuthorizationConsent(authorizationConsent); + } else { + updateAuthorizationConsent(authorizationConsent); + } + } + + private void updateAuthorizationConsent(OAuth2AuthorizationConsent authorizationConsent) { + List parameters = this.authorizationConsentParametersMapper.apply(authorizationConsent); + SqlParameterValue registeredClientId = parameters.remove(0); + SqlParameterValue principalName = parameters.remove(0); + parameters.add(registeredClientId); + parameters.add(principalName); + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); + this.jdbcOperations.update(UPDATE_AUTHORIZATION_CONSENT_SQL, pss); + } + + private void insertAuthorizationConsent(OAuth2AuthorizationConsent authorizationConsent) { + List parameters = this.authorizationConsentParametersMapper.apply(authorizationConsent); + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray()); + this.jdbcOperations.update(SAVE_AUTHORIZATION_CONSENT_SQL, pss); + } + + @Override + public void remove(OAuth2AuthorizationConsent authorizationConsent) { + Assert.notNull(authorizationConsent, "authorizationConsent cannot be null"); + SqlParameterValue[] parameters = new SqlParameterValue[]{ + new SqlParameterValue(Types.VARCHAR, authorizationConsent.getRegisteredClientId()), + new SqlParameterValue(Types.VARCHAR, authorizationConsent.getPrincipalName()) + }; + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); + this.jdbcOperations.update(REMOVE_AUTHORIZATION_CONSENT_SQL, pss); + } + + @Override + @Nullable + public OAuth2AuthorizationConsent findById(String registeredClientId, String principalName) { + Assert.hasText(registeredClientId, "registeredClientId cannot be empty"); + Assert.hasText(principalName, "principalName cannot be empty"); + SqlParameterValue[] parameters = new SqlParameterValue[]{ + new SqlParameterValue(Types.VARCHAR, registeredClientId), + new SqlParameterValue(Types.VARCHAR, principalName)}; + PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters); + List result = this.jdbcOperations.query(LOAD_AUTHORIZATION_CONSENT_SQL, pss, + this.authorizationConsentRowMapper); + return !result.isEmpty() ? result.get(0) : null; + } + + /** + * Sets the {@link RowMapper} used for mapping the current row in + * {@code java.sql.ResultSet} to {@link OAuth2AuthorizationConsent}. The default is + * {@link OAuth2AuthorizationConsentRowMapper}. + * + * @param authorizationConsentRowMapper the {@link RowMapper} used for mapping the current + * row in {@code ResultSet} to {@link OAuth2AuthorizationConsent} + */ + public void setAuthorizationConsentRowMapper(RowMapper authorizationConsentRowMapper) { + Assert.notNull(authorizationConsentRowMapper, "authorizationConsentRowMapper cannot be null"); + this.authorizationConsentRowMapper = authorizationConsentRowMapper; + } + + /** + * Sets the {@code Function} used for mapping {@link OAuth2AuthorizationConsent} to + * a {@code List} of {@link SqlParameterValue}. The default is + * {@link OAuth2AuthorizationConsentParametersMapper}. + * + * @param authorizationConsentParametersMapper the {@code Function} used for mapping + * {@link OAuth2AuthorizationConsent} to a {@code List} of {@link SqlParameterValue} + */ + public void setAuthorizationConsentParametersMapper( + Function> authorizationConsentParametersMapper) { + Assert.notNull(authorizationConsentParametersMapper, "authorizationConsentParametersMapper cannot be null"); + this.authorizationConsentParametersMapper = authorizationConsentParametersMapper; + } + + /** + * The default {@link RowMapper} that maps the current row in + * {@code ResultSet} to {@link OAuth2AuthorizationConsent}. + */ + public static class OAuth2AuthorizationConsentRowMapper implements RowMapper { + + private final RegisteredClientRepository registeredClientRepository; + + public OAuth2AuthorizationConsentRowMapper(RegisteredClientRepository registeredClientRepository) { + Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null"); + this.registeredClientRepository = registeredClientRepository; + } + + @Override + public OAuth2AuthorizationConsent mapRow(ResultSet rs, int rowNum) throws SQLException { + String registeredClientId = rs.getString("registered_client_id"); + + RegisteredClient registeredClient = this.registeredClientRepository + .findById(registeredClientId); + if (registeredClient == null) { + throw new DataRetrievalFailureException( + "The RegisteredClient with id '" + registeredClientId + "' it was not found in the RegisteredClientRepository."); + } + + String principalName = rs.getString("principal_name"); + + OAuth2AuthorizationConsent.Builder builder = OAuth2AuthorizationConsent.withId(registeredClientId, principalName); + String authorizationConsentAuthorities = rs.getString("authorities"); + if (authorizationConsentAuthorities != null) { + for (String authority : StringUtils.commaDelimitedListToSet(authorizationConsentAuthorities)) { + builder.authority(new SimpleGrantedAuthority(authority)); + } + } + return builder.build(); + } + } + + /** + * The default {@code Function} that maps {@link OAuth2AuthorizationConsent} to a + * {@code List} of {@link SqlParameterValue}. + */ + public static class OAuth2AuthorizationConsentParametersMapper implements Function> { + + @Override + public List apply(OAuth2AuthorizationConsent authorizationConsent) { + List parameters = new ArrayList<>(); + parameters.add(new SqlParameterValue(Types.VARCHAR, authorizationConsent.getRegisteredClientId())); + parameters.add(new SqlParameterValue(Types.VARCHAR, authorizationConsent.getPrincipalName())); + + Set authorities = new HashSet<>(); + for (GrantedAuthority authority : authorizationConsent.getAuthorities()) { + authorities.add(authority.getAuthority()); + } + parameters.add(new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToDelimitedString(authorities, ","))); + return parameters; + } + } + +} diff --git a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationConsent.java b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationConsent.java index 10a6463f7..dbe2fe2e2 100644 --- a/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationConsent.java +++ b/oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationConsent.java @@ -18,6 +18,7 @@ import java.io.Serializable; import java.util.Collections; import java.util.HashSet; +import java.util.Objects; import java.util.Set; import java.util.function.Consumer; import java.util.stream.Collectors; @@ -97,6 +98,25 @@ public Set getScopes() { .collect(Collectors.toSet()); } + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null || getClass() != obj.getClass()) { + return false; + } + OAuth2AuthorizationConsent that = (OAuth2AuthorizationConsent) obj; + return Objects.equals(this.registeredClientId, that.registeredClientId) && + Objects.equals(this.principalName, that.principalName) && + Objects.equals(this.authorities, that.authorities); + } + + @Override + public int hashCode() { + return Objects.hash(this.registeredClientId, this.principalName, this.authorities); + } + /** * Returns a new {@link Builder}, initialized with the values from the provided {@code OAuth2AuthorizationConsent}. * diff --git a/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql b/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql new file mode 100644 index 000000000..3020828ab --- /dev/null +++ b/oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql @@ -0,0 +1,6 @@ +CREATE TABLE oauth2_authorization_consent ( + registered_client_id varchar(100) NOT NULL, + principal_name varchar(200) NOT NULL, + authorities varchar(1000) NOT NULL, + PRIMARY KEY (registered_client_id, principal_name) +); diff --git a/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationConsentServiceTests.java b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationConsentServiceTests.java new file mode 100644 index 000000000..2fcbde8dd --- /dev/null +++ b/oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationConsentServiceTests.java @@ -0,0 +1,229 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.oauth2.server.authorization; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.springframework.jdbc.core.JdbcOperations; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder; +import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType; +import org.springframework.security.core.authority.SimpleGrantedAuthority; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClient; +import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository; +import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link JdbcOAuth2AuthorizationConsentService}. + * + * @author Ovidiu Popa + */ +public class JdbcOAuth2AuthorizationConsentServiceTests { + + private static final String OAUTH2_AUTHORIZATION_CONSENT_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql"; + private static final String PRINCIPAL_NAME = "principal-name"; + private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build(); + + private static final OAuth2AuthorizationConsent AUTHORIZATION_CONSENT = + OAuth2AuthorizationConsent.withId(REGISTERED_CLIENT.getId(), PRINCIPAL_NAME) + .authority(new SimpleGrantedAuthority("some.authority")) + .build(); + + private EmbeddedDatabase db; + private JdbcOperations jdbcOperations; + private RegisteredClientRepository registeredClientRepository; + private JdbcOAuth2AuthorizationConsentService authorizationConsentService; + + @Test + public void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> new JdbcOAuth2AuthorizationConsentService(null, this.registeredClientRepository)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("jdbcOperations cannot be null"); + // @formatter:on + } + + @Test + public void constructorWhenRegisteredClientRepositoryIsNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> new JdbcOAuth2AuthorizationConsentService(this.jdbcOperations, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("registeredClientRepository cannot be null"); + // @formatter:on + } + + @Test + public void setAuthorizationConsentRowMapperWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> this.authorizationConsentService.setAuthorizationConsentRowMapper(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationConsentRowMapper cannot be null"); + // @formatter:on + } + + @Test + public void setAuthorizationConsentParametersMapperWhenNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatThrownBy(() -> this.authorizationConsentService.setAuthorizationConsentParametersMapper(null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("authorizationConsentParametersMapper cannot be null"); + // @formatter:on + } + + @Test + public void saveWhenAuthorizationConsentNullThenThrowIllegalArgumentException() { + // @formatter:off + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizationConsentService.save(null)) + .withMessage("authorizationConsent cannot be null"); + // @formatter:on + } + + @Test + public void saveWhenAuthorizationConsentNewThenSaved() { + OAuth2AuthorizationConsent expectedAuthorizationConsent = + OAuth2AuthorizationConsent.withId("new-client", "new-principal") + .authority(new SimpleGrantedAuthority("new.authority")) + .build(); + + RegisteredClient newRegisteredClient = TestRegisteredClients.registeredClient() + .id("new-client").build(); + + when(registeredClientRepository.findById(eq(newRegisteredClient.getId()))) + .thenReturn(newRegisteredClient); + + this.authorizationConsentService.save(expectedAuthorizationConsent); + + OAuth2AuthorizationConsent authorizationConsent = + this.authorizationConsentService.findById("new-client", "new-principal"); + assertThat(authorizationConsent).isEqualTo(expectedAuthorizationConsent); + } + + @Test + public void saveWhenAuthorizationConsentExistsThenUpdated() { + OAuth2AuthorizationConsent expectedAuthorizationConsent = + OAuth2AuthorizationConsent.from(AUTHORIZATION_CONSENT) + .authority(new SimpleGrantedAuthority("new.authority")) + .build(); + when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId()))) + .thenReturn(REGISTERED_CLIENT); + + this.authorizationConsentService.save(expectedAuthorizationConsent); + + OAuth2AuthorizationConsent authorizationConsent = + this.authorizationConsentService.findById( + AUTHORIZATION_CONSENT.getRegisteredClientId(), AUTHORIZATION_CONSENT.getPrincipalName()); + assertThat(authorizationConsent).isEqualTo(expectedAuthorizationConsent); + assertThat(authorizationConsent).isNotEqualTo(AUTHORIZATION_CONSENT); + } + + @Test + public void saveLoadAuthorizationConsentWhenCustomStrategiesSetThenCalled() throws Exception { + when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId()))) + .thenReturn(REGISTERED_CLIENT); + + JdbcOAuth2AuthorizationConsentService.OAuth2AuthorizationConsentRowMapper authorizationConsentRowMapper = spy( + new JdbcOAuth2AuthorizationConsentService.OAuth2AuthorizationConsentRowMapper( + this.registeredClientRepository)); + this.authorizationConsentService.setAuthorizationConsentRowMapper(authorizationConsentRowMapper); + JdbcOAuth2AuthorizationConsentService.OAuth2AuthorizationConsentParametersMapper authorizationConsentParametersMapper = spy( + new JdbcOAuth2AuthorizationConsentService.OAuth2AuthorizationConsentParametersMapper()); + this.authorizationConsentService.setAuthorizationConsentParametersMapper(authorizationConsentParametersMapper); + + this.authorizationConsentService.save(AUTHORIZATION_CONSENT); + OAuth2AuthorizationConsent authorizationConsent = this.authorizationConsentService.findById( + AUTHORIZATION_CONSENT.getRegisteredClientId(), AUTHORIZATION_CONSENT.getPrincipalName()); + assertThat(authorizationConsent).isEqualTo(AUTHORIZATION_CONSENT); + verify(authorizationConsentRowMapper).mapRow(any(), anyInt()); + verify(authorizationConsentParametersMapper).apply(any()); + } + + @Test + public void removeWhenAuthorizationConsentNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizationConsentService.remove(null)) + .withMessage("authorizationConsent cannot be null"); + } + + @Test + public void removeWhenAuthorizationConsentProvidedThenRemoved() { + this.authorizationConsentService.remove(AUTHORIZATION_CONSENT); + assertThat(this.authorizationConsentService.findById( + AUTHORIZATION_CONSENT.getRegisteredClientId(), AUTHORIZATION_CONSENT.getPrincipalName())) + .isNull(); + } + + @Test + public void findByIdWhenRegisteredClientIdNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizationConsentService.findById(null, "some-user")) + .withMessage("registeredClientId cannot be empty"); + } + + @Test + public void findByIdWhenPrincipalNameNullThenThrowIllegalArgumentException() { + assertThatIllegalArgumentException() + .isThrownBy(() -> this.authorizationConsentService.findById("some-client", null)) + .withMessage("principalName cannot be empty"); + } + + @Test + public void findByIdWhenAuthorizationConsentDoesNotExistThenNull() { + this.authorizationConsentService.save(AUTHORIZATION_CONSENT); + assertThat(this.authorizationConsentService.findById("unknown-client", PRINCIPAL_NAME)).isNull(); + assertThat(this.authorizationConsentService.findById(REGISTERED_CLIENT.getId(), "unknown-user")).isNull(); + } + + @Before + public void setUp() { + this.db = createDb(); + this.registeredClientRepository = mock(RegisteredClientRepository.class); + this.jdbcOperations = new JdbcTemplate(this.db); + this.authorizationConsentService = new JdbcOAuth2AuthorizationConsentService(this.jdbcOperations, this.registeredClientRepository); + } + + @After + public void tearDown() { + this.db.shutdown(); + } + + private static EmbeddedDatabase createDb() { + return createDb(OAUTH2_AUTHORIZATION_CONSENT_SCHEMA_SQL_RESOURCE); + } + private static EmbeddedDatabase createDb(String schema) { + // @formatter:off + return new EmbeddedDatabaseBuilder() + .generateUniqueName(true) + .setType(EmbeddedDatabaseType.HSQL) + .setScriptEncoding("UTF-8") + .addScript(schema) + .build(); + // @formatter:on + } +}