diff --git a/dev_requirements.txt b/dev_requirements.txt index 2a0938bec3..ad7330598d 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -13,4 +13,4 @@ ujson>=4.2.0 uvloop vulture>=2.3.0 numpy>=1.24.0 -redis-entraid==0.3.0b1 +redis-entraid==0.4.0b2 diff --git a/tests/entraid_utils.py b/tests/entraid_utils.py index daefbd3956..529c3ccdee 100644 --- a/tests/entraid_utils.py +++ b/tests/entraid_utils.py @@ -19,6 +19,8 @@ ServicePrincipalIdentityProviderConfig, _create_provider_from_managed_identity, _create_provider_from_service_principal, + DefaultAzureCredentialIdentityProviderConfig, + _create_provider_from_default_azure_credential, ) from tests.conftest import mock_identity_provider @@ -26,6 +28,7 @@ class AuthType(Enum): MANAGED_IDENTITY = "managed_identity" SERVICE_PRINCIPAL = "service_principal" + DEFAULT_AZURE_CREDENTIAL = "default_azure_credential" def identity_provider(request) -> IdentityProviderInterface: @@ -37,18 +40,25 @@ def identity_provider(request) -> IdentityProviderInterface: if request.param.get("mock_idp", None) is not None: return mock_identity_provider() - auth_type = kwargs.pop("auth_type", AuthType.SERVICE_PRINCIPAL) + auth_type = kwargs.get("auth_type", AuthType.SERVICE_PRINCIPAL) config = get_identity_provider_config(request=request) - if auth_type == "MANAGED_IDENTITY": + if auth_type == AuthType.MANAGED_IDENTITY: return _create_provider_from_managed_identity(config) + if auth_type == AuthType.DEFAULT_AZURE_CREDENTIAL: + return _create_provider_from_default_azure_credential(config) + return _create_provider_from_service_principal(config) def get_identity_provider_config( request, -) -> Union[ManagedIdentityProviderConfig, ServicePrincipalIdentityProviderConfig]: +) -> Union[ + ManagedIdentityProviderConfig, + ServicePrincipalIdentityProviderConfig, + DefaultAzureCredentialIdentityProviderConfig, +]: if hasattr(request, "param"): kwargs = request.param.get("idp_kwargs", {}) else: @@ -59,6 +69,9 @@ def get_identity_provider_config( if auth_type == AuthType.MANAGED_IDENTITY: return _get_managed_identity_provider_config(request) + if auth_type == AuthType.DEFAULT_AZURE_CREDENTIAL: + return _get_default_azure_credential_provider_config(request) + return _get_service_principal_provider_config(request) @@ -114,6 +127,26 @@ def _get_service_principal_provider_config( ) +def _get_default_azure_credential_provider_config( + request, +) -> DefaultAzureCredentialIdentityProviderConfig: + scopes = os.getenv("AZURE_REDIS_SCOPES", ()) + + if hasattr(request, "param"): + kwargs = request.param.get("idp_kwargs", {}) + token_kwargs = request.param.get("token_kwargs", {}) + else: + kwargs = {} + token_kwargs = {} + + if isinstance(scopes, str): + scopes = scopes.split(",") + + return DefaultAzureCredentialIdentityProviderConfig( + scopes=scopes, app_kwargs=kwargs, token_kwargs=token_kwargs + ) + + def get_entra_id_credentials_provider(request, cred_provider_kwargs): idp = identity_provider(request) expiration_refresh_ratio = cred_provider_kwargs.get( diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 60e447e6fd..340d146ea3 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -1,6 +1,5 @@ import random from contextlib import asynccontextmanager as _asynccontextmanager -from enum import Enum from typing import Union import pytest @@ -18,11 +17,6 @@ from .compat import mock -class AuthType(Enum): - MANAGED_IDENTITY = "managed_identity" - SERVICE_PRINCIPAL = "service_principal" - - async def _get_info(redis_url): client = redis.Redis.from_url(redis_url) info = await client.info() diff --git a/tests/test_asyncio/test_credentials.py b/tests/test_asyncio/test_credentials.py index ce8d76ea45..b4824be469 100644 --- a/tests/test_asyncio/test_credentials.py +++ b/tests/test_asyncio/test_credentials.py @@ -18,6 +18,7 @@ from redis.exceptions import ConnectionError from redis.utils import str_if_bytes from tests.conftest import get_endpoint, skip_if_redis_enterprise +from tests.entraid_utils import AuthType from tests.test_asyncio.conftest import get_credential_provider try: @@ -616,8 +617,12 @@ class TestEntraIdCredentialsProvider: "cred_provider_class": EntraIdCredentialsProvider, "cred_provider_kwargs": {"block_for_initial": True}, }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, ], - ids=["blocked", "non-blocked"], + ids=["blocked", "non-blocked", "DefaultAzureCredential"], indirect=True, ) @pytest.mark.asyncio @@ -692,8 +697,12 @@ class TestClusterEntraIdCredentialsProvider: "cred_provider_class": EntraIdCredentialsProvider, "cred_provider_kwargs": {"block_for_initial": True}, }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, ], - ids=["blocked", "non-blocked"], + ids=["blocked", "non-blocked", "DefaultAzureCredential"], indirect=True, ) @pytest.mark.asyncio diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 1f98c5208d..58bbd01f28 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -22,6 +22,7 @@ get_endpoint, skip_if_redis_enterprise, ) +from tests.entraid_utils import AuthType try: from redis_entraid.cred_provider import EntraIdCredentialsProvider @@ -585,8 +586,12 @@ class TestEntraIdCredentialsProvider: "cred_provider_class": EntraIdCredentialsProvider, "single_connection_client": True, }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, ], - ids=["pool", "single"], + ids=["pool", "single", "DefaultAzureCredential"], indirect=True, ) @pytest.mark.onlynoncluster @@ -656,8 +661,12 @@ class TestClusterEntraIdCredentialsProvider: "cred_provider_class": EntraIdCredentialsProvider, "single_connection_client": True, }, + { + "cred_provider_class": EntraIdCredentialsProvider, + "idp_kwargs": {"auth_type": AuthType.DEFAULT_AZURE_CREDENTIAL}, + }, ], - ids=["pool", "single"], + ids=["pool", "single", "DefaultAzureCredential"], indirect=True, ) @pytest.mark.onlycluster