diff --git a/azure_functions_worker/__init__.py b/azure_functions_worker/__init__.py index 5b7f7a925..a75af377d 100644 --- a/azure_functions_worker/__init__.py +++ b/azure_functions_worker/__init__.py @@ -1,2 +1,4 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. + +__version__ = '1.1.10' diff --git a/azure_functions_worker/bindings/tracecontext.py b/azure_functions_worker/bindings/tracecontext.py index 2c12f8c7d..e90312ddb 100644 --- a/azure_functions_worker/bindings/tracecontext.py +++ b/azure_functions_worker/bindings/tracecontext.py @@ -1,23 +1,44 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from typing import Dict + class TraceContext: + """Check https://www.w3.org/TR/trace-context/ for more information""" def __init__(self, trace_parent: str, - trace_state: str, attributes: dict) -> None: + trace_state: str, attributes: Dict[str, str]) -> None: self.__trace_parent = trace_parent self.__trace_state = trace_state self.__attributes = attributes @property def Tracestate(self) -> str: + """Get trace state from trace-context (deprecated).""" return self.__trace_state @property def Traceparent(self) -> str: + """Get trace parent from trace-context (deprecated).""" + return self.__trace_parent + + @property + def Attributes(self) -> Dict[str, str]: + """Get trace-context attributes (deprecated).""" + return self.__attributes + + @property + def trace_state(self) -> str: + """Get trace state from trace-context""" + return self.__trace_state + + @property + def trace_parent(self) -> str: + """Get trace parent from trace-context""" return self.__trace_parent @property - def Attributes(self) -> str: + def attributes(self) -> Dict[str, str]: + """Get trace-context attributes""" return self.__attributes diff --git a/azure_functions_worker/constants.py b/azure_functions_worker/constants.py index 6c75ddb19..fb2d29a0d 100644 --- a/azure_functions_worker/constants.py +++ b/azure_functions_worker/constants.py @@ -1,9 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# Prefixes -CONSOLE_LOG_PREFIX = "LanguageWorkerConsoleLog" - # Capabilities RAW_HTTP_BODY_BYTES = "RawHttpBodyBytes" TYPED_DATA_COLLECTION = "TypedDataCollection" @@ -26,6 +23,7 @@ PYTHON_ROLLBACK_CWD_PATH = "PYTHON_ROLLBACK_CWD_PATH" PYTHON_THREADPOOL_THREAD_COUNT = "PYTHON_THREADPOOL_THREAD_COUNT" PYTHON_ISOLATE_WORKER_DEPENDENCIES = "PYTHON_ISOLATE_WORKER_DEPENDENCIES" +PYTHON_ENABLE_WORKER_EXTENSIONS = "PYTHON_ENABLE_WORKER_EXTENSIONS" FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED = \ "FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED" """ @@ -40,6 +38,8 @@ PYTHON_THREADPOOL_THREAD_COUNT_MAX = 32 PYTHON_ISOLATE_WORKER_DEPENDENCIES_DEFAULT = False PYTHON_ISOLATE_WORKER_DEPENDENCIES_DEFAULT_39 = True +PYTHON_ENABLE_WORKER_EXTENSIONS_DEFAULT = False +PYTHON_ENABLE_WORKER_EXTENSIONS_DEFAULT_39 = True # External Site URLs MODULE_NOT_FOUND_TS_URL = "https://aka.ms/functions-modulenotfound" diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index 453559855..6d934d15e 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -18,17 +18,20 @@ import grpc +from . import __version__ from . import bindings from . import constants from . import functions from . import loader from . import protos -from .constants import (CONSOLE_LOG_PREFIX, PYTHON_THREADPOOL_THREAD_COUNT, +from .constants import (PYTHON_THREADPOOL_THREAD_COUNT, PYTHON_THREADPOOL_THREAD_COUNT_DEFAULT, PYTHON_THREADPOOL_THREAD_COUNT_MAX, PYTHON_THREADPOOL_THREAD_COUNT_MIN) from .logging import disable_console_logging, enable_console_logging -from .logging import error_logger, is_system_log_category, logger +from .logging import (logger, error_logger, is_system_log_category, + CONSOLE_LOG_PREFIX) +from .extension import ExtensionManager from .utils.common import get_app_setting from .utils.tracing import marshall_exception_trace from .utils.dependency import DependencyManager @@ -255,8 +258,9 @@ async def _dispatch_grpc_request(self, request): self._grpc_resp_queue.put_nowait(resp) async def _handle__worker_init_request(self, req): - logger.info('Received WorkerInitRequest, request ID %s', - self.request_id) + logger.info('Received WorkerInitRequest, ' + 'python version %s, worker version %s, request ID %s', + sys.version, __version__, self.request_id) capabilities = { constants.RAW_HTTP_BODY_BYTES: _TRUE, @@ -304,6 +308,11 @@ async def _handle__function_load_request(self, req): self._functions.add_function( function_id, func, func_request.metadata) + ExtensionManager.function_load_extension( + function_name, + func_request.metadata.directory + ) + logger.info('Successfully processed FunctionLoadRequest, ' f'request ID: {self.request_id}, ' f'function ID: {function_id},' @@ -373,20 +382,24 @@ async def _handle__invocation_request(self, req): pytype=pb_type_info.pytype, shmem_mgr=self._shmem_mgr) + fi_context = bindings.Context( + fi.name, fi.directory, invocation_id, trace_context) if fi.requires_context: - args['context'] = bindings.Context( - fi.name, fi.directory, invocation_id, trace_context) + args['context'] = fi_context if fi.output_types: for name in fi.output_types: args[name] = bindings.Out() if fi.is_async: - call_result = await fi.func(**args) + call_result = await self._run_async_func( + fi_context, fi.func, args + ) else: call_result = await self._loop.run_in_executor( self._sync_call_tp, - self.__run_sync_func, invocation_id, fi.func, args) + self._run_sync_func, + invocation_id, fi_context, fi.func, args) if call_result is not None and not fi.has_return: raise RuntimeError(f'function {fi.name!r} without a $return ' 'binding returned a non-None value') @@ -582,15 +595,21 @@ def _create_sync_call_tp( max_workers=max_worker ) - def __run_sync_func(self, invocation_id, func, params): + def _run_sync_func(self, invocation_id, context, func, params): # This helper exists because we need to access the current # invocation_id from ThreadPoolExecutor's threads. _invocation_id_local.v = invocation_id try: - return func(**params) + return ExtensionManager.get_sync_invocation_wrapper(context, + func)(params) finally: _invocation_id_local.v = None + async def _run_async_func(self, context, func, params): + return await ExtensionManager.get_async_invocation_wrapper( + context, func, params + ) + def __poll_grpc(self): options = [] if self._grpc_max_msg_len: diff --git a/azure_functions_worker/extension.py b/azure_functions_worker/extension.py new file mode 100644 index 000000000..3c6791b3e --- /dev/null +++ b/azure_functions_worker/extension.py @@ -0,0 +1,259 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from types import ModuleType +from typing import Any, Callable, List, Optional +import logging +import functools +from .utils.common import ( + is_python_version, + get_sdk_from_sys_path, + get_sdk_version +) +from .utils.wrappers import enable_feature_by +from .constants import ( + PYTHON_ISOLATE_WORKER_DEPENDENCIES, + PYTHON_ENABLE_WORKER_EXTENSIONS, + PYTHON_ENABLE_WORKER_EXTENSIONS_DEFAULT, + PYTHON_ENABLE_WORKER_EXTENSIONS_DEFAULT_39 +) +from .logging import logger, SYSTEM_LOG_PREFIX + + +# Extension Hooks +FUNC_EXT_POST_FUNCTION_LOAD = "post_function_load" +FUNC_EXT_PRE_INVOCATION = "pre_invocation" +FUNC_EXT_POST_INVOCATION = "post_invocation" +APP_EXT_POST_FUNCTION_LOAD = "post_function_load_app_level" +APP_EXT_PRE_INVOCATION = "pre_invocation_app_level" +APP_EXT_POST_INVOCATION = "post_invocation_app_level" + + +class ExtensionManager: + _is_sdk_detected: bool = False + """This marks if the ExtensionManager has already proceeded a detection, + if so, the sdk will be cached in ._extension_enabled_sdk + """ + + _extension_enabled_sdk: Optional[ModuleType] = None + """This is a cache of azure.functions module that supports extension + interfaces. If this is None, that mean the sdk does not support extension. + """ + + @classmethod + @enable_feature_by( + flag=PYTHON_ENABLE_WORKER_EXTENSIONS, + flag_default=( + PYTHON_ENABLE_WORKER_EXTENSIONS_DEFAULT_39 if + is_python_version('3.9') else + PYTHON_ENABLE_WORKER_EXTENSIONS_DEFAULT + ) + ) + def function_load_extension(cls, func_name, func_directory): + """Helper to execute function load extensions. If one of the extension + fails in the extension chain, the rest of them will continue, emitting + an error log of an exception trace for failed extension. + + Parameters + ---------- + func_name: str + The name of the trigger (e.g. HttpTrigger) + func_directory: str + The folder path of the trigger + (e.g. /home/site/wwwroot/HttpTrigger). + """ + sdk = cls._try_get_sdk_with_extension_enabled() + if sdk is None: + return + + # Reports application & function extensions installed on customer's app + cls._info_discover_extension_list(func_name, sdk) + + # Get function hooks from azure.functions.extension.ExtensionMeta + # The return type is FuncExtensionHooks + funcs = sdk.ExtensionMeta.get_function_hooks(func_name) + + # Invoke function hooks + cls._safe_execute_function_load_hooks( + funcs, FUNC_EXT_POST_FUNCTION_LOAD, func_name, func_directory + ) + + # Get application hooks from azure.functions.extension.ExtensionMeta + # The reutnr type is AppExtensionHooks + apps = sdk.ExtensionMeta.get_application_hooks() + + # Invoke application hook + cls._safe_execute_function_load_hooks( + apps, APP_EXT_POST_FUNCTION_LOAD, func_name, func_directory + ) + + @classmethod + @enable_feature_by( + flag=PYTHON_ENABLE_WORKER_EXTENSIONS, + flag_default=( + PYTHON_ENABLE_WORKER_EXTENSIONS_DEFAULT_39 if + is_python_version('3.9') else + PYTHON_ENABLE_WORKER_EXTENSIONS_DEFAULT + ) + ) + def _invocation_extension(cls, ctx, hook_name, func_args, func_ret=None): + """Helper to execute extensions. If one of the extension fails in the + extension chain, the rest of them will continue, emitting an error log + of an exception trace for failed extension. + + Parameters + ---------- + ctx: azure.functions.Context + Azure Functions context to be passed onto extension + hook_name: str + The exetension name to be executed (e.g. pre_invocations). + These are defined in azure.functions.FuncExtensionHooks. + """ + sdk = cls._try_get_sdk_with_extension_enabled() + if sdk is None: + return + + # Get function hooks from azure.functions.extension.ExtensionMeta + # The return type is FuncExtensionHooks + funcs = sdk.ExtensionMeta.get_function_hooks(ctx.function_name) + + # Invoke function hooks + cls._safe_execute_invocation_hooks( + funcs, hook_name, ctx, func_args, func_ret + ) + + # Get application hooks from azure.functions.extension.ExtensionMeta + # The reutnr type is AppExtensionHooks + apps = sdk.ExtensionMeta.get_application_hooks() + + # Invoke application hook + cls._safe_execute_invocation_hooks( + apps, hook_name, ctx, func_args, func_ret + ) + + @classmethod + def get_sync_invocation_wrapper(cls, ctx, func) -> Callable[[List], Any]: + """Get a synchronous lambda of extension wrapped function which takes + function parameters + """ + return functools.partial(cls._raw_invocation_wrapper, ctx, func) + + @classmethod + async def get_async_invocation_wrapper(cls, ctx, function, args) -> Any: + """An asynchronous coroutine for executing function with extensions + """ + cls._invocation_extension(ctx, APP_EXT_PRE_INVOCATION, args) + cls._invocation_extension(ctx, FUNC_EXT_PRE_INVOCATION, args) + result = await function(**args) + cls._invocation_extension(ctx, FUNC_EXT_POST_INVOCATION, args, result) + cls._invocation_extension(ctx, APP_EXT_POST_INVOCATION, args, result) + return result + + @staticmethod + def _is_extension_enabled_in_sdk(module: ModuleType) -> bool: + """Check if the extension feature is enabled in particular + azure.functions package. + + Parameters + ---------- + module: ModuleType + The azure.functions SDK module + + Returns + ------- + bool + True on azure.functions SDK supports extension registration + """ + return getattr(module, 'ExtensionMeta', None) is not None + + @classmethod + def _is_pre_invocation_hook(cls, name) -> bool: + return name in (FUNC_EXT_PRE_INVOCATION, APP_EXT_PRE_INVOCATION) + + @classmethod + def _is_post_invocation_hook(cls, name) -> bool: + return name in (FUNC_EXT_POST_INVOCATION, APP_EXT_POST_INVOCATION) + + @classmethod + def _safe_execute_invocation_hooks(cls, hooks, hook_name, ctx, fargs, fret): + # hooks from azure.functions.ExtensionMeta.get_function_hooks() or + # azure.functions.ExtensionMeta.get_application_hooks() + if hooks: + # Invoke extension implementation from ..ext_impl + for hook_meta in getattr(hooks, hook_name, []): + # Register a system logger with prefix azure_functions_worker + ext_logger = logging.getLogger( + f'{SYSTEM_LOG_PREFIX}.extension.{hook_meta.ext_name}' + ) + try: + if cls._is_pre_invocation_hook(hook_name): + hook_meta.ext_impl(ext_logger, ctx, fargs) + elif cls._is_post_invocation_hook(hook_name): + hook_meta.ext_impl(ext_logger, ctx, fargs, fret) + except Exception as e: + ext_logger.error(e, exc_info=True) + + @classmethod + def _safe_execute_function_load_hooks(cls, hooks, hook_name, fname, fdir): + # hooks from azure.functions.ExtensionMeta.get_function_hooks() or + # azure.functions.ExtensionMeta.get_application_hooks() + if hooks: + # Invoke extension implementation from ..ext_impl + for hook_meta in getattr(hooks, hook_name, []): + try: + hook_meta.ext_impl(fname, fdir) + except Exception as e: + logger.error(e, exc_info=True) + + @classmethod + def _raw_invocation_wrapper(cls, ctx, function, args) -> Any: + """Calls pre_invocation and post_invocation extensions additional + to function invocation + """ + cls._invocation_extension(ctx, APP_EXT_PRE_INVOCATION, args) + cls._invocation_extension(ctx, FUNC_EXT_PRE_INVOCATION, args) + result = function(**args) + cls._invocation_extension(ctx, FUNC_EXT_POST_INVOCATION, args, result) + cls._invocation_extension(ctx, APP_EXT_POST_INVOCATION, args, result) + return result + + @classmethod + def _try_get_sdk_with_extension_enabled(cls) -> Optional[ModuleType]: + if cls._is_sdk_detected: + return cls._extension_enabled_sdk + + sdk = get_sdk_from_sys_path() + if cls._is_extension_enabled_in_sdk(sdk): + cls._info_extension_is_enabled(sdk) + cls._extension_enabled_sdk = sdk + else: + cls._warn_sdk_not_support_extension(sdk) + cls._extension_enabled_sdk = None + + cls._is_sdk_detected = True + return cls._extension_enabled_sdk + + @classmethod + def _info_extension_is_enabled(cls, sdk): + logger.info( + 'Python Worker Extension is enabled in azure.functions ' + f'({get_sdk_version(sdk)}).' + ) + + @classmethod + def _info_discover_extension_list(cls, function_name, sdk): + logger.info( + f'Python Worker Extension Manager is loading {function_name}, ' + 'current registered extensions: ' + f'{sdk.ExtensionMeta.get_registered_extensions_json()}' + ) + + @classmethod + def _warn_sdk_not_support_extension(cls, sdk): + logger.warning( + f'The azure.functions ({get_sdk_version(sdk)}) does not ' + 'support Python worker extensions. If you believe extensions ' + 'are correctly installed, please set the ' + f'{PYTHON_ISOLATE_WORKER_DEPENDENCIES} and ' + f'{PYTHON_ENABLE_WORKER_EXTENSIONS} to "true"' + ) diff --git a/azure_functions_worker/logging.py b/azure_functions_worker/logging.py index 50cfc2e8f..c37e3ae0e 100644 --- a/azure_functions_worker/logging.py +++ b/azure_functions_worker/logging.py @@ -6,7 +6,9 @@ import logging.handlers import sys -from .constants import CONSOLE_LOG_PREFIX +# Logging Prefixes +CONSOLE_LOG_PREFIX = "LanguageWorkerConsoleLog" +SYSTEM_LOG_PREFIX = "azure_functions_worker" logger: logging.Logger = logging.getLogger('azure_functions_worker') @@ -76,4 +78,4 @@ def enable_console_logging() -> None: def is_system_log_category(ctg: str) -> bool: # Category starts with 'azure_functions_worker' or # 'azure_functions_worker_errors' will be treated as system logs - return ctg.lower().startswith('azure_functions_worker') + return ctg.lower().startswith(SYSTEM_LOG_PREFIX) diff --git a/azure_functions_worker/utils/common.py b/azure_functions_worker/utils/common.py index d203fb315..350f858f2 100644 --- a/azure_functions_worker/utils/common.py +++ b/azure_functions_worker/utils/common.py @@ -1,8 +1,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. from typing import Optional, Callable +from types import ModuleType import os import sys +import importlib def is_true_like(setting: str) -> bool: @@ -79,3 +81,48 @@ def get_app_setting( if validator(app_setting_value): return app_setting_value return default_value + + +def get_sdk_version(module: ModuleType) -> str: + """Check the version of azure.functions sdk. + + Parameters + ---------- + module: ModuleType + The azure.functions SDK module + + Returns + ------- + str + The SDK version that our customer has installed. + """ + + return getattr(module, '__version__', 'undefined') + + +def get_sdk_from_sys_path() -> ModuleType: + """Get the azure.functions SDK from the latest sys.path defined. + This is to ensure the extension loaded from SDK coming from customer's + site-packages. + + Returns + ------- + ModuleType + The azure.functions that is loaded from the first sys.path entry + """ + backup_azure_functions = None + backup_azure = None + + if 'azure.functions' in sys.modules: + backup_azure_functions = sys.modules.pop('azure.functions') + if 'azure' in sys.modules: + backup_azure = sys.modules.pop('azure') + + module = importlib.import_module('azure.functions') + + if backup_azure: + sys.modules['azure'] = backup_azure + if backup_azure_functions: + sys.modules['azure.functions'] = backup_azure_functions + + return module diff --git a/setup.py b/setup.py index 4787fbef8..a1dc11923 100644 --- a/setup.py +++ b/setup.py @@ -16,6 +16,8 @@ from setuptools import setup from setuptools.command import develop +from azure_functions_worker import __version__ + # The GitHub repository of the Azure Functions Host WEBHOST_GITHUB_API = "https://api.github.com/repos/Azure/azure-functions-host" WEBHOST_TAG_PREFIX = "v3." @@ -354,7 +356,7 @@ def run(self): setup( name='azure-functions-worker', - version='1.1.10', + version=__version__, description='Python Language Worker for Azure Functions Host', author="Microsoft Corp.", author_email="azurefunctions@microsoft.com", diff --git a/tests/unittests/resources/mock_azure_functions/azure/__init__.py b/tests/unittests/resources/mock_azure_functions/azure/__init__.py new file mode 100644 index 000000000..649cbaa5f --- /dev/null +++ b/tests/unittests/resources/mock_azure_functions/azure/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/tests/unittests/resources/mock_azure_functions/azure/functions/__init__.py b/tests/unittests/resources/mock_azure_functions/azure/functions/__init__.py new file mode 100644 index 000000000..9f561659c --- /dev/null +++ b/tests/unittests/resources/mock_azure_functions/azure/functions/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +__version__ = "dummy" diff --git a/tests/unittests/resources/mock_azure_functions/readme.md b/tests/unittests/resources/mock_azure_functions/readme.md new file mode 100644 index 000000000..c40015fb4 --- /dev/null +++ b/tests/unittests/resources/mock_azure_functions/readme.md @@ -0,0 +1,3 @@ +# Instruction + +This is a dummy azure.functions SDK for testing the backward compatibility \ No newline at end of file diff --git a/tests/unittests/test_extension.py b/tests/unittests/test_extension.py new file mode 100644 index 000000000..3c50276e2 --- /dev/null +++ b/tests/unittests/test_extension.py @@ -0,0 +1,831 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import logging +import os +import sys +import unittest +from unittest.mock import patch, Mock, call +from importlib import import_module +from azure_functions_worker._thirdparty import aio_compat +from azure_functions_worker.extension import ( + ExtensionManager, + APP_EXT_POST_FUNCTION_LOAD, FUNC_EXT_POST_FUNCTION_LOAD, + APP_EXT_PRE_INVOCATION, FUNC_EXT_PRE_INVOCATION, + APP_EXT_POST_INVOCATION, FUNC_EXT_POST_INVOCATION +) +from azure_functions_worker.utils.common import get_sdk_from_sys_path +from azure_functions_worker.constants import PYTHON_ENABLE_WORKER_EXTENSIONS + + +class MockContext: + def __init__(self, function_name: str, function_directory: str): + self.function_name = function_name + self.function_directory = function_directory + + +class TestExtension(unittest.TestCase): + + def setUp(self): + # Initialize Extension Manager Instance + self._instance = ExtensionManager + self._instance._is_sdk_detected = False + self._instance._extension_enabled_sdk = None + + # Initialize Azure Functions SDK and clear cache + self._sdk = import_module('azure.functions') + self._sdk.ExtensionMeta._func_exts = {} + self._sdk.ExtensionMeta._app_exts = None + self._sdk.ExtensionMeta._info = {} + sys.modules.pop('azure.functions') + sys.modules.pop('azure') + + # Derived dummy SDK Python system path + self._dummy_sdk_sys_path = os.path.join( + os.path.dirname(__file__), + 'resources', + 'mock_azure_functions' + ) + + # Initialize mock context + self._mock_arguments = {'req': 'request'} + self._mock_func_name = 'HttpTrigger' + self._mock_func_dir = '/home/site/wwwroot/HttpTrigger' + self._mock_context = MockContext( + function_name=self._mock_func_name, + function_directory=self._mock_func_dir + ) + + # Patch sys.modules and sys.path to avoid pollution between tests + self.mock_sys_module = patch.dict('sys.modules', sys.modules.copy()) + self.mock_sys_path = patch('sys.path', sys.path.copy()) + self.mock_sys_module.start() + self.mock_sys_path.start() + + # Set feature flag to on + os.environ[PYTHON_ENABLE_WORKER_EXTENSIONS] = 'true' + + def tearDown(self) -> None: + self.mock_sys_path.stop() + self.mock_sys_module.stop() + os.environ.pop(PYTHON_ENABLE_WORKER_EXTENSIONS) + + def test_extension_is_supported_by_latest_sdk(self): + """Test if extension interface supports check as expected on + new version of azure.functions SDK + """ + module = get_sdk_from_sys_path() + sdk_enabled = self._instance._is_extension_enabled_in_sdk(module) + self.assertTrue(sdk_enabled) + + def test_extension_is_not_supported_by_mock_sdk(self): + """Test if the detection works when an azure.functions SDK does not + support extension management. + """ + sys.path.insert(0, self._dummy_sdk_sys_path) + module = get_sdk_from_sys_path() + sdk_enabled = self._instance._is_extension_enabled_in_sdk(module) + self.assertFalse(sdk_enabled) + + @patch('azure_functions_worker.extension.get_sdk_from_sys_path') + def test_function_load_extension_enable_when_feature_flag_is_on( + self, + get_sdk_from_sys_path_mock: Mock + ): + """When turning off the feature flag PYTHON_ENABLE_WORKER_EXTENSIONS, + the post_function_load extension should be disabled + """ + self._instance.function_load_extension( + func_name=self._mock_func_name, + func_directory=self._mock_func_dir + ) + get_sdk_from_sys_path_mock.assert_called_once() + + @patch('azure_functions_worker.extension.get_sdk_from_sys_path') + def test_function_load_extension_disable_when_feature_flag_is_off( + self, + get_sdk_from_sys_path_mock: Mock + ): + """When turning off the feature flag PYTHON_ENABLE_WORKER_EXTENSIONS, + the post_function_load extension should be disabled + """ + os.environ[PYTHON_ENABLE_WORKER_EXTENSIONS] = 'false' + self._instance.function_load_extension( + func_name=self._mock_func_name, + func_directory=self._mock_func_dir + ) + get_sdk_from_sys_path_mock.assert_not_called() + + @patch('azure_functions_worker.extension.ExtensionManager.' + '_warn_sdk_not_support_extension') + def test_function_load_extension_warns_when_sdk_does_not_support( + self, + _warn_sdk_not_support_extension_mock: Mock + ): + """When customer is using an old version of sdk which does not have + extension support and turning on the feature flag, we should warn them + """ + sys.path.insert(0, self._dummy_sdk_sys_path) + self._instance.function_load_extension( + func_name=self._mock_func_name, + func_directory=self._mock_func_dir + ) + _warn_sdk_not_support_extension_mock.assert_called_once() + + @patch('azure_functions_worker.extension.ExtensionManager.' + '_safe_execute_function_load_hooks') + def test_function_load_extension_should_invoke_extension_call( + self, + safe_execute_function_load_hooks_mock: Mock + ): + """Should invoke extension if SDK suports extension interface + """ + self._instance.function_load_extension( + func_name=self._mock_func_name, + func_directory=self._mock_func_dir + ) + # No registered hooks + safe_execute_function_load_hooks_mock.assert_has_calls( + calls=[ + call( + None, APP_EXT_POST_FUNCTION_LOAD, + self._mock_func_name, self._mock_func_dir + ), + call( + None, FUNC_EXT_POST_FUNCTION_LOAD, + self._mock_func_name, self._mock_func_dir + ) + ], + any_order=True + ) + + @patch('azure_functions_worker.extension.get_sdk_from_sys_path') + def test_invocation_extension_enable_when_feature_flag_is_on( + self, + get_sdk_from_sys_path_mock: Mock + ): + """When turning off the feature flag PYTHON_ENABLE_WORKER_EXTENSIONS, + the pre_invocation and post_invocation extension should be disabled + """ + self._instance._invocation_extension( + ctx=self._mock_context, + hook_name=FUNC_EXT_PRE_INVOCATION, + func_args=[], + func_ret=None + ) + get_sdk_from_sys_path_mock.assert_called_once() + + @patch('azure_functions_worker.extension.get_sdk_from_sys_path') + def test_invocation_extension_extension_disable_when_feature_flag_is_off( + self, + get_sdk_from_sys_path_mock: Mock + ): + """When turning off the feature flag PYTHON_ENABLE_WORKER_EXTENSIONS, + the pre_invocation and post_invocation extension should be disabled + """ + os.environ[PYTHON_ENABLE_WORKER_EXTENSIONS] = 'false' + self._instance._invocation_extension( + ctx=self._mock_context, + hook_name=FUNC_EXT_PRE_INVOCATION, + func_args=[], + func_ret=None + ) + get_sdk_from_sys_path_mock.assert_not_called() + + @patch('azure_functions_worker.extension.ExtensionManager.' + '_warn_sdk_not_support_extension') + def test_invocation_extension_warns_when_sdk_does_not_support( + self, + _warn_sdk_not_support_extension_mock: Mock + ): + """When customer is using an old version of sdk which does not have + extension support and turning on the feature flag, we should warn them + """ + sys.path.insert(0, self._dummy_sdk_sys_path) + self._instance._invocation_extension( + ctx=self._mock_context, + hook_name=FUNC_EXT_PRE_INVOCATION, + func_args=[], + func_ret=None + ) + _warn_sdk_not_support_extension_mock.assert_called_once() + + @patch('azure_functions_worker.extension.ExtensionManager.' + '_safe_execute_invocation_hooks') + def test_invocation_extension_should_invoke_extension_call( + self, + safe_execute_invocation_hooks_mock: Mock + ): + """Should invoke extension if SDK suports extension interface + """ + for hook_name in (APP_EXT_PRE_INVOCATION, FUNC_EXT_PRE_INVOCATION, + APP_EXT_POST_INVOCATION, FUNC_EXT_POST_INVOCATION): + self._instance._invocation_extension( + ctx=self._mock_context, + hook_name=hook_name, + func_args=[], + func_ret=None + ) + + safe_execute_invocation_hooks_mock.assert_has_calls( + calls=[ + call( + None, hook_name, self._mock_context, + [], None + ) + ], + any_order=True + ) + + @patch('azure_functions_worker.extension.ExtensionManager.' + '_is_pre_invocation_hook') + @patch('azure_functions_worker.extension.ExtensionManager.' + '_is_post_invocation_hook') + def test_empty_hooks_should_not_receive_any_invocation( + self, + _is_post_invocation_hook_mock: Mock, + _is_pre_invocation_hook_mock: Mock + ): + """If there is no life-cycle hooks implemented under a function, + then we should skip it + """ + for hook_name in (APP_EXT_PRE_INVOCATION, FUNC_EXT_PRE_INVOCATION, + APP_EXT_POST_INVOCATION, FUNC_EXT_POST_INVOCATION): + self._instance._safe_execute_invocation_hooks( + hooks=[], + hook_name=hook_name, + ctx=self._mock_context, + fargs=[], + fret=None + ) + _is_pre_invocation_hook_mock.assert_not_called() + _is_post_invocation_hook_mock.assert_not_called() + + def test_invocation_hooks_should_be_executed(self): + """If there is an extension implemented the pre_invocation and + post_invocation life-cycle hooks, it should be invoked in + safe_execute_invocation_hooks + """ + FuncExtClass = self._generate_new_func_extension_class( + base=self._sdk.FuncExtensionBase, + trigger=self._mock_func_name + ) + func_ext_instance = FuncExtClass() + hook_instances = ( + self._sdk.ExtensionMeta.get_function_hooks(self._mock_func_name) + ) + for hook_name in (FUNC_EXT_PRE_INVOCATION, FUNC_EXT_POST_INVOCATION): + self._instance._safe_execute_invocation_hooks( + hooks=hook_instances, + hook_name=hook_name, + ctx=self._mock_context, + fargs=[], + fret=None + ) + self.assertFalse(func_ext_instance._post_function_load_executed) + self.assertTrue(func_ext_instance._pre_invocation_executed) + self.assertTrue(func_ext_instance._post_invocation_executed) + + def test_post_function_load_hook_should_be_executed(self): + """If there is an extension implemented the post_function_load + life-cycle hook, it invokes in safe_execute_function_load_hooks + """ + FuncExtClass = self._generate_new_func_extension_class( + base=self._sdk.FuncExtensionBase, + trigger=self._mock_func_name + ) + func_ext_instance = FuncExtClass() + hook_instances = ( + self._sdk.ExtensionMeta.get_function_hooks(self._mock_func_name) + ) + for hook_name in (FUNC_EXT_POST_FUNCTION_LOAD,): + self._instance._safe_execute_function_load_hooks( + hooks=hook_instances, + hook_name=hook_name, + fname=self._mock_func_name, + fdir=self._mock_func_dir + ) + self.assertTrue(func_ext_instance._post_function_load_executed) + self.assertFalse(func_ext_instance._pre_invocation_executed) + self.assertFalse(func_ext_instance._post_invocation_executed) + + def test_invocation_hooks_app_level_should_be_executed(self): + """If there is an extension implemented the pre_invocation and + post_invocation life-cycle hooks, it should be invoked in + safe_execute_invocation_hooks + """ + AppExtClass = self._generate_new_app_extension( + base=self._sdk.AppExtensionBase + ) + hook_instances = ( + self._sdk.ExtensionMeta.get_application_hooks() + ) + for hook_name in (APP_EXT_PRE_INVOCATION, APP_EXT_POST_INVOCATION): + self._instance._safe_execute_invocation_hooks( + hooks=hook_instances, + hook_name=hook_name, + ctx=self._mock_context, + fargs=[], + fret=None + ) + self.assertFalse(AppExtClass._post_function_load_app_level_executed) + self.assertTrue(AppExtClass._pre_invocation_app_level_executed) + self.assertTrue(AppExtClass._post_invocation_app_level_executed) + + def test_post_function_load_app_level_hook_should_be_executed(self): + """If there is an extension implemented the post_function_load + life-cycle hook, it invokes in safe_execute_function_load_hooks + """ + AppExtClass = self._generate_new_app_extension( + base=self._sdk.AppExtensionBase + ) + hook_instances = ( + self._sdk.ExtensionMeta.get_application_hooks() + ) + for hook_name in (APP_EXT_POST_FUNCTION_LOAD,): + self._instance._safe_execute_function_load_hooks( + hooks=hook_instances, + hook_name=hook_name, + fname=self._mock_func_name, + fdir=self._mock_func_dir + ) + self.assertTrue(AppExtClass._post_function_load_app_level_executed) + self.assertFalse(AppExtClass._pre_invocation_app_level_executed) + self.assertFalse(AppExtClass._post_invocation_app_level_executed) + + def test_raw_invocation_wrapper(self): + """This wrapper should automatically invoke all invocation extensions + """ + # Instantiate extensions + AppExtClass = self._generate_new_app_extension( + base=self._sdk.AppExtensionBase + ) + FuncExtClass = self._generate_new_func_extension_class( + base=self._sdk.FuncExtensionBase, + trigger=self._mock_func_name + ) + func_ext_instance = FuncExtClass() + + # Invoke with wrapper + self._instance._raw_invocation_wrapper( + self._mock_context, self._mock_function_main, self._mock_arguments + ) + + # Assert: invocation hooks should be executed + self.assertTrue(func_ext_instance._pre_invocation_executed) + self.assertTrue(func_ext_instance._post_invocation_executed) + self.assertTrue(AppExtClass._pre_invocation_app_level_executed) + self.assertTrue(AppExtClass._post_invocation_app_level_executed) + + # Assert: arguments should be passed into the extension + comparisons = ( + func_ext_instance._pre_invocation_executed_fargs, + func_ext_instance._post_invocation_executed_fargs, + AppExtClass._pre_invocation_app_level_executed_fargs, + AppExtClass._post_invocation_app_level_executed_fargs + ) + for current_argument in comparisons: + self.assertEqual(current_argument, self._mock_arguments) + + # Assert: returns should be passed into the extension + comparisons = ( + func_ext_instance._post_invocation_executed_fret, + AppExtClass._post_invocation_app_level_executed_fret + ) + for current_return in comparisons: + self.assertEqual(current_return, 'request_ok') + + @patch('azure_functions_worker.extension.logger.error') + def test_exception_handling_in_post_function_load_app_level( + self, + error_mock: Mock + ): + """When there's a chain breaks in the extension chain, it should not + pause other executions. For post_function_load_app_level, becasue the + logger is not fully initialized, the exception will be suppressed. + """ + # Create an customized exception + expt = Exception('Exception in post_function_load_app_level') + + # Register an application extension + class BadAppExtension(self._sdk.AppExtensionBase): + post_function_load_app_level_executed = False + + @classmethod + def post_function_load_app_level(cls, + function_name, + function_directory, + *args, + **kwargs): + cls.post_function_load_app_level_executed = True + raise expt + + # Execute function with a broken extension + hooks = self._sdk.ExtensionMeta.get_application_hooks() + self._instance._safe_execute_function_load_hooks( + hooks=hooks, + hook_name=APP_EXT_POST_FUNCTION_LOAD, + fname=self._mock_func_name, + fdir=self._mock_func_dir + ) + + # Ensure the extension is executed, but the exception shouldn't surface + self.assertTrue(BadAppExtension.post_function_load_app_level_executed) + + # Ensure errors are reported from system logger + error_mock.assert_called_with(expt, exc_info=True) + + def test_exception_handling_in_pre_invocation_app_level(self): + """When there's a chain breaks in the extension chain, it should not + pause other executions, but report with a system logger, so that the + error is accessible to customers and ours. + """ + # Create an customized exception + expt = Exception('Exception in pre_invocation_app_level') + + # Register an application extension + class BadAppExtension(self._sdk.AppExtensionBase): + @classmethod + def pre_invocation_app_level(cls, logger, context, func_args, + *args, **kwargs): + raise expt + + # Create a mocked customer_function + wrapped = self._instance.get_sync_invocation_wrapper( + self._mock_context, + self._mock_function_main + ) + + # Mock logger + ext_logger = logging.getLogger( + 'azure_functions_worker.extension.BadAppExtension' + ) + ext_logger_error_mock = Mock() + ext_logger.error = ext_logger_error_mock + + # Invocation with arguments. This will throw an exception, but should + # not break the execution chain. + result = wrapped(self._mock_arguments) + + # Ensure the customer's function is executed + self.assertEqual(result, 'request_ok') + + # Ensure the error is reported + ext_logger_error_mock.assert_called_with(expt, exc_info=True) + + def test_get_sync_invocation_wrapper_no_extension(self): + """The wrapper is using functools.partial() to expose the arguments + for synchronous execution in dispatcher. + """ + # Create a mocked customer_function + wrapped = self._instance.get_sync_invocation_wrapper( + self._mock_context, + self._mock_function_main + ) + + # Invocation with arguments + result = wrapped(self._mock_arguments) + + # Ensure the return value matches the function method + self.assertEqual(result, 'request_ok') + + def test_get_sync_invocation_wrapper_with_func_extension(self): + """The wrapper is using functools.partial() to expose the arguments. + Ensure the func extension can be executed along with customer's funcs. + """ + # Register a function extension + FuncExtClass = self._generate_new_func_extension_class( + self._sdk.FuncExtensionBase, + self._mock_func_name + ) + _func_ext_instance = FuncExtClass() + + # Create a mocked customer_function + wrapped = self._instance.get_sync_invocation_wrapper( + self._mock_context, + self._mock_function_main + ) + + # Invocation via wrapper with arguments + result = wrapped(self._mock_arguments) + + # Ensure the extension is executed + self.assertTrue(_func_ext_instance._pre_invocation_executed) + + # Ensure the customer's function is executed + self.assertEqual(result, 'request_ok') + + def test_get_sync_invocation_wrapper_disabled_with_flag(self): + """The wrapper should still exist, customer's functions should still + be executed, but not the extension + """ + # Turn off feature flag + os.environ[PYTHON_ENABLE_WORKER_EXTENSIONS] = 'false' + + # Register a function extension + FuncExtClass = self._generate_new_func_extension_class( + self._sdk.FuncExtensionBase, + self._mock_func_name + ) + _func_ext_instance = FuncExtClass() + + # Create a mocked customer_function + wrapped = self._instance.get_sync_invocation_wrapper( + self._mock_context, + self._mock_function_main + ) + + # Invocation via wrapper with arguments + result = wrapped(self._mock_arguments) + + # The extension SHOULD NOT be executed, since the feature flag is off + self.assertFalse(_func_ext_instance._pre_invocation_executed) + + # Ensure the customer's function is executed + self.assertEqual(result, 'request_ok') + + def test_get_async_invocation_wrapper_no_extension(self): + """The async wrapper will wrap an asynchronous function with a + coroutine interface. When there is no extension, it should only invoke + the customer's function. + """ + # Create a mocked customer_function with async wrapper + result = aio_compat.run( + self._instance.get_async_invocation_wrapper( + self._mock_context, + self._mock_function_main_async, + self._mock_arguments + ) + ) + + # Ensure the return value matches the function method + self.assertEqual(result, 'request_ok') + + def test_get_async_invocation_wrapper_with_func_extension(self): + """The async wrapper will wrap an asynchronous function with a + coroutine interface. When there is registered extension, it should + execute the extension as well. + """ + # Register a function extension + FuncExtClass = self._generate_new_func_extension_class( + self._sdk.FuncExtensionBase, + self._mock_func_name + ) + _func_ext_instance = FuncExtClass() + + # Create a mocked customer_function with async wrapper + result = aio_compat.run( + self._instance.get_async_invocation_wrapper( + self._mock_context, + self._mock_function_main_async, + self._mock_arguments + ) + ) + + # Ensure the extension is executed + self.assertTrue(_func_ext_instance._pre_invocation_executed) + + # Ensure the customer's function is executed + self.assertEqual(result, 'request_ok') + + def test_get_invocation_async_disabled_with_flag(self): + """The async wrapper will only execute customer's function. This + should not execute the extension. + """ + # Turn off feature flag + os.environ[PYTHON_ENABLE_WORKER_EXTENSIONS] = 'false' + + # Register a function extension + FuncExtClass = self._generate_new_func_extension_class( + self._sdk.FuncExtensionBase, + self._mock_func_name + ) + _func_ext_instance = FuncExtClass() + + # Create a mocked customer_function with async wrapper + result = aio_compat.run( + self._instance.get_async_invocation_wrapper( + self._mock_context, + self._mock_function_main_async, + self._mock_arguments + ) + ) + + # The extension SHOULD NOT be executed + self.assertFalse(_func_ext_instance._pre_invocation_executed) + + # Ensure the customer's function is executed + self.assertEqual(result, 'request_ok') + + def test_is_pre_invocation_hook(self): + for name in (FUNC_EXT_PRE_INVOCATION, APP_EXT_PRE_INVOCATION): + self.assertTrue( + self._instance._is_pre_invocation_hook(name) + ) + + def test_is_pre_invocation_hook_negative(self): + for name in (FUNC_EXT_POST_INVOCATION, APP_EXT_POST_INVOCATION, + FUNC_EXT_POST_FUNCTION_LOAD, APP_EXT_POST_FUNCTION_LOAD): + self.assertFalse( + self._instance._is_pre_invocation_hook(name) + ) + + def test_is_post_invocation_hook(self): + for name in (FUNC_EXT_POST_INVOCATION, APP_EXT_POST_INVOCATION): + self.assertTrue( + self._instance._is_post_invocation_hook(name) + ) + + def test_is_post_invocation_hook_negative(self): + for name in (FUNC_EXT_PRE_INVOCATION, APP_EXT_PRE_INVOCATION, + FUNC_EXT_POST_FUNCTION_LOAD, APP_EXT_POST_FUNCTION_LOAD): + self.assertFalse( + self._instance._is_post_invocation_hook(name) + ) + + @patch('azure_functions_worker.extension.' + 'ExtensionManager._info_extension_is_enabled') + def test_try_get_sdk_with_extension_enabled_should_execute_once( + self, + info_extension_is_enabled_mock: Mock + ): + """The result of an extension enabled SDK should be cached. No need + to be derived multiple times. + """ + # Call twice the function + self._instance._try_get_sdk_with_extension_enabled() + sdk = self._instance._try_get_sdk_with_extension_enabled() + + # The actual execution will only process once (e.g. list extensions) + info_extension_is_enabled_mock.assert_called_once() + + # Ensure the SDK is returned correctly + self.assertIsNotNone(sdk) + + @patch('azure_functions_worker.extension.' + 'ExtensionManager._warn_sdk_not_support_extension') + def test_try_get_sdk_with_extension_disabled_should_execute_once( + self, + warn_sdk_not_support_extension_mock: Mock + ): + """When SDK does not support extension interface, it should return + None and throw a warning. + """ + # Point to dummy SDK + sys.path.insert(0, self._dummy_sdk_sys_path) + + # Call twice the function + self._instance._try_get_sdk_with_extension_enabled() + sdk = self._instance._try_get_sdk_with_extension_enabled() + + # The actual execution will only process once (e.g. warning) + warn_sdk_not_support_extension_mock.assert_called_once() + + # The SDK does not support Extension Interface, should be None + self.assertIsNone(sdk) + + @patch('azure_functions_worker.extension.logger.info') + def test_info_extension_is_enabled(self, info_mock: Mock): + # Get SDK from sys.path + sdk = get_sdk_from_sys_path() + + # Check logs + self._instance._info_extension_is_enabled(sdk) + info_mock.assert_called_once_with( + 'Python Worker Extension is enabled in azure.functions ' + f'({sdk.__version__}).' + ) + + @patch('azure_functions_worker.extension.logger.info') + def test_info_discover_extension_list_func_ext(self, info_mock: Mock): + # Get SDK from sys.path + sdk = get_sdk_from_sys_path() + + # Register a function extension class + FuncExtClass = self._generate_new_func_extension_class( + sdk.FuncExtensionBase, + self._mock_func_name + ) + + # Instantiate a function extension + FuncExtClass() + + # Check logs + self._instance._info_discover_extension_list(self._mock_func_name, sdk) + info_mock.assert_called_once_with( + 'Python Worker Extension Manager is loading HttpTrigger, ' + 'current registered extensions: ' + r'{"FuncExtension": {"HttpTrigger": ["NewFuncExtension"]}}' + ) + + @patch('azure_functions_worker.extension.logger.info') + def test_info_discover_extension_list_app_ext(self, info_mock: Mock): + # Get SDK from sys.path + sdk = get_sdk_from_sys_path() + + # Register a function extension class + self._generate_new_app_extension(sdk.AppExtensionBase) + + # Check logs + self._instance._info_discover_extension_list(self._mock_func_name, sdk) + info_mock.assert_called_once_with( + 'Python Worker Extension Manager is loading HttpTrigger, ' + 'current registered extensions: ' + r'{"AppExtension": ["NewAppExtension"]}' + ) + + @patch('azure_functions_worker.extension.logger.warning') + def test_warn_sdk_not_support_extension(self, warning_mock: Mock): + # Get SDK from dummy + sys.path.insert(0, self._dummy_sdk_sys_path) + sdk = get_sdk_from_sys_path() + + # Check logs + self._instance._warn_sdk_not_support_extension(sdk) + warning_mock.assert_called_once_with( + 'The azure.functions (dummy) does not ' + 'support Python worker extensions. If you believe extensions ' + 'are correctly installed, please set the ' + 'PYTHON_ISOLATE_WORKER_DEPENDENCIES and ' + 'PYTHON_ENABLE_WORKER_EXTENSIONS to "true"' + ) + + def _generate_new_func_extension_class(self, base: type, trigger: str): + class NewFuncExtension(base): + def __init__(self): + self._trigger_name = trigger + self._post_function_load_executed = False + self._pre_invocation_executed = False + self._post_invocation_executed = False + + self._pre_invocation_executed_fargs = {} + self._post_invocation_executed_fargs = {} + self._post_invocation_executed_fret = None + + def post_function_load(self, + function_name, + function_directory, + *args, + **kwargs): + self._post_function_load_executed = True + + def pre_invocation(self, logger, context, fargs, + *args, **kwargs): + self._pre_invocation_executed = True + self._pre_invocation_executed_fargs = fargs + + def post_invocation(self, logger, context, fargs, fret, + *args, **kwargs): + self._post_invocation_executed = True + self._post_invocation_executed_fargs = fargs + self._post_invocation_executed_fret = fret + + return NewFuncExtension + + def _generate_new_app_extension(self, base: type): + class NewAppExtension(base): + _init_executed = False + + _post_function_load_app_level_executed = False + _pre_invocation_app_level_executed = False + _post_invocation_app_level_executed = False + + _pre_invocation_app_level_executed_fargs = {} + _post_invocation_app_level_executed_fargs = {} + _post_invocation_app_level_executed_fret = None + + @classmethod + def init(cls): + cls._init_executed = True + + @classmethod + def post_function_load_app_level(cls, + function_name, + function_directory, + *args, + **kwargs): + cls._post_function_load_app_level_executed = True + + @classmethod + def pre_invocation_app_level(cls, logger, context, func_args, + *args, **kwargs): + cls._pre_invocation_app_level_executed = True + cls._pre_invocation_app_level_executed_fargs = func_args + + @classmethod + def post_invocation_app_level(cls, logger, context, + func_args, func_ret, + *args, **kwargs): + cls._post_invocation_app_level_executed = True + cls._post_invocation_app_level_executed_fargs = func_args + cls._post_invocation_app_level_executed_fret = func_ret + + return NewAppExtension + + def _mock_function_main(self, req): + assert req == 'request' + return req + '_ok' + + async def _mock_function_main_async(self, req): + assert req == 'request' + return req + '_ok' diff --git a/tests/unittests/test_utilities.py b/tests/unittests/test_utilities.py index 935659816..df3ea1e8e 100644 --- a/tests/unittests/test_utilities.py +++ b/tests/unittests/test_utilities.py @@ -1,7 +1,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. import os +import sys import unittest +from unittest.mock import patch import typing from azure_functions_worker.utils import common, wrappers @@ -68,8 +70,22 @@ class TestUtilities(unittest.TestCase): def setUp(self): self._pre_env = dict(os.environ) + self._dummy_sdk_sys_path = os.path.join( + os.path.dirname(__file__), + 'resources', + 'mock_azure_functions' + ) + + self.mock_sys_module = patch.dict('sys.modules', sys.modules.copy()) + self.mock_sys_path = patch('sys.path', sys.path.copy()) + + self.mock_sys_module.start() + self.mock_sys_path.start() def tearDown(self): + self.mock_sys_path.stop() + self.mock_sys_module.stop() + os.environ.clear() os.environ.update(self._pre_env) @@ -312,6 +328,40 @@ def test_is_python_version(self): is_python_version_39 ])) + def test_get_sdk_from_sys_path(self): + """Test if the extension manager can find azure.functions module + """ + module = common.get_sdk_from_sys_path() + self.assertIsNotNone(module.__file__) + + def test_get_sdk_from_sys_path_after_updating_sys_path(self): + """Test if the get_sdk_from_sys_path can find the newer azure.functions + module after updating the sys.path. This is specifically for a scenario + after the dependency manager is switched to customer's path + """ + sys.path.insert(0, self._dummy_sdk_sys_path) + module = common.get_sdk_from_sys_path() + self.assertEqual( + os.path.dirname(module.__file__), + os.path.join(self._dummy_sdk_sys_path, 'azure', 'functions') + ) + + def test_get_sdk_version(self): + """Test if sdk version can be retrieved correctly + """ + module = common.get_sdk_from_sys_path() + sdk_version = common.get_sdk_version(module) + # e.g. 1.6.0, 1.7.0b, 1.8.1dev + self.assertRegex(sdk_version, r'\d+\.\d+\.\w+') + + def test_get_sdk_dummy_version(self): + """Test if sdk version can get dummy sdk version + """ + sys.path.insert(0, self._dummy_sdk_sys_path) + module = common.get_sdk_from_sys_path() + sdk_version = common.get_sdk_version(module) + self.assertEqual(sdk_version, 'dummy') + def _unset_feature_flag(self): try: os.environ.pop(TEST_FEATURE_FLAG)