diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index 8837ed77b..f031812f8 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -27,6 +27,8 @@ PYTHON_THREADPOOL_THREAD_COUNT_DEFAULT, PYTHON_THREADPOOL_THREAD_COUNT_MAX, PYTHON_THREADPOOL_THREAD_COUNT_MIN) +from .extensions import (get_before_invocation_request_callbacks, + get_after_invocation_request_callbacks) from .logging import disable_console_logging, enable_console_logging from .logging import error_logger, is_system_log_category, logger from .utils.common import get_app_setting @@ -359,9 +361,19 @@ async def _handle__invocation_request(self, req): trigger_metadata=trigger_metadata, pytype=pb_type_info.pytype) + context = bindings.Context( + fi.name, fi.directory, invocation_id, trace_context) + + # Execute before invocation callbacks + for callback in get_before_invocation_request_callbacks(): + try: + callback(context) + except Exception as ex: + logger.warning( + "Before invocation callback failed with: %s.", ex) + if fi.requires_context: - args['context'] = bindings.Context( - fi.name, fi.directory, invocation_id, trace_context) + args['context'] = context if fi.output_types: for name in fi.output_types: @@ -402,6 +414,14 @@ async def _handle__invocation_request(self, req): fi.return_type.binding_name, call_result, pytype=fi.return_type.pytype) + # Execute after invocation callbacks + for callback in get_after_invocation_request_callbacks(): + try: + callback(context) + except Exception as ex: + logger.warning( + "After invocation callback failed with: %s.", ex) + # Actively flush customer print() function to console sys.stdout.flush() diff --git a/azure_functions_worker/extensions.py b/azure_functions_worker/extensions.py new file mode 100644 index 000000000..052a4246c --- /dev/null +++ b/azure_functions_worker/extensions.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +_EXTENSIONS_CONTEXT = dict() + + +def register_before_invocation_request(callback): + if _EXTENSIONS_CONTEXT.get("BEFORE_INVOCATION_REQUEST_CALLBACKS"): + _EXTENSIONS_CONTEXT.get( + "BEFORE_INVOCATION_REQUEST_CALLBACKS").append(callback) + else: + _EXTENSIONS_CONTEXT["BEFORE_INVOCATION_REQUEST_CALLBACKS"] = [callback] + + +def register_after_invocation_request(callback): + if _EXTENSIONS_CONTEXT.get("AFTER_INVOCATION_REQUEST_CALLBACKS"): + _EXTENSIONS_CONTEXT.get( + "AFTER_INVOCATION_REQUEST_CALLBACKS").append(callback) + else: + _EXTENSIONS_CONTEXT["AFTER_INVOCATION_REQUEST_CALLBACKS"] = [callback] + + +def clear_before_invocation_request_callbacks(): + _EXTENSIONS_CONTEXT.pop("BEFORE_INVOCATION_REQUEST_CALLBACKS", None) + + +def clear_after_invocation_request_callbacks(): + _EXTENSIONS_CONTEXT.pop("AFTER_INVOCATION_REQUEST_CALLBACKS", None) + + +def get_before_invocation_request_callbacks(): + return _EXTENSIONS_CONTEXT.get("BEFORE_INVOCATION_REQUEST_CALLBACKS", []) + + +def get_after_invocation_request_callbacks(): + return _EXTENSIONS_CONTEXT.get("AFTER_INVOCATION_REQUEST_CALLBACKS", []) diff --git a/tests/unittests/test_extensions.py b/tests/unittests/test_extensions.py new file mode 100644 index 000000000..6634e4160 --- /dev/null +++ b/tests/unittests/test_extensions.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +from unittest import TestCase, mock + +from azure_functions_worker import extensions + + +class TestExtensions(TestCase): + + def tearDown(self): + extensions._EXTENSIONS_CONTEXT.clear() + + def test_register_before_invocation_request(self): + mock_cb = mock.Mock() + mock_cb2 = mock.Mock() + extensions.register_before_invocation_request(mock_cb) + self.assertEqual( + extensions._EXTENSIONS_CONTEXT + ["BEFORE_INVOCATION_REQUEST_CALLBACKS"][0], + mock_cb, + ) + extensions.register_before_invocation_request(mock_cb2) + self.assertEqual( + extensions._EXTENSIONS_CONTEXT + ["BEFORE_INVOCATION_REQUEST_CALLBACKS"][1], + mock_cb2, + ) + + def test_register_after_invocation_request(self): + mock_cb = mock.Mock() + mock_cb2 = mock.Mock() + extensions.register_after_invocation_request(mock_cb) + self.assertEqual( + extensions._EXTENSIONS_CONTEXT + ["AFTER_INVOCATION_REQUEST_CALLBACKS"][0], + mock_cb, + ) + extensions.register_after_invocation_request(mock_cb2) + self.assertEqual( + extensions._EXTENSIONS_CONTEXT + ["AFTER_INVOCATION_REQUEST_CALLBACKS"][1], + mock_cb2, + ) + + def test_clear_before_invocation_request_callbacks(self): + mock_cb = mock.Mock() + extensions.register_before_invocation_request(mock_cb) + self.assertEqual( + extensions._EXTENSIONS_CONTEXT + ["BEFORE_INVOCATION_REQUEST_CALLBACKS"][0], + mock_cb, + ) + extensions.clear_before_invocation_request_callbacks() + self.assertIsNone( + extensions._EXTENSIONS_CONTEXT. + get("BEFORE_INVOCATION_REQUEST_CALLBACKS"), + ) + + def test_clear_after_invocation_request_callbacks(self): + mock_cb = mock.Mock() + extensions.register_after_invocation_request(mock_cb) + self.assertEqual( + extensions._EXTENSIONS_CONTEXT + ["AFTER_INVOCATION_REQUEST_CALLBACKS"][0], + mock_cb, + ) + extensions.clear_after_invocation_request_callbacks() + self.assertIsNone( + extensions._EXTENSIONS_CONTEXT. + get("AFTER_INVOCATION_REQUEST_CALLBACKS"), + ) + + def test_get_before_invocation_request_callbacks(self): + mock_cb = mock.Mock() + extensions.register_before_invocation_request(mock_cb) + self.assertEqual( + extensions._EXTENSIONS_CONTEXT + ["BEFORE_INVOCATION_REQUEST_CALLBACKS"][0], + mock_cb, + ) + self.assertEqual( + extensions.get_before_invocation_request_callbacks()[0], + mock_cb + ) + + def test_get_after_invocation_request_callbacks(self): + mock_cb = mock.Mock() + extensions.register_after_invocation_request(mock_cb) + self.assertEqual( + extensions._EXTENSIONS_CONTEXT + ["AFTER_INVOCATION_REQUEST_CALLBACKS"][0], + mock_cb, + ) + self.assertEqual( + extensions.get_after_invocation_request_callbacks()[0], + mock_cb + )