From 2d132b9d678f4613914cab1670cacdb21de90ce9 Mon Sep 17 00:00:00 2001 From: Gohar Irfan Chaudhry Date: Tue, 20 Apr 2021 16:04:35 -0700 Subject: [PATCH 01/11] Separate field in the message to close shared memory maps which indicates if the maps should be deleted or just the reference should be dropped --- .../shared_memory_data_transfer/shared_memory_manager.py | 9 +++++---- azure_functions_worker/dispatcher.py | 4 +++- .../protos/_src/src/proto/FunctionRpc.proto | 1 + 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/azure_functions_worker/bindings/shared_memory_data_transfer/shared_memory_manager.py b/azure_functions_worker/bindings/shared_memory_data_transfer/shared_memory_manager.py index b4cba5444..661bf1fcc 100644 --- a/azure_functions_worker/bindings/shared_memory_data_transfer/shared_memory_manager.py +++ b/azure_functions_worker/bindings/shared_memory_data_transfer/shared_memory_manager.py @@ -157,10 +157,11 @@ def get_string(self, mem_map_name: str, offset: int, count: int) \ content_str = content_bytes.decode('utf-8') return content_str - def free_mem_map(self, mem_map_name: str) -> bool: + def free_mem_map(self, mem_map_name: str, + is_delete_backing_resources: bool = True) -> bool: """ - Frees the memory map and any backing resources (e.g. file in the case of - Unix) associated with it. + Frees the memory map and, if specified, any backing resources (e.g. + file in the case of Unix) associated with it. If there is no memory map with the given name being tracked, then no action is performed. Returns True if the memory map was freed successfully, False otherwise. 
@@ -170,7 +171,7 @@ def free_mem_map(self, mem_map_name: str) -> bool: f'Cannot find memory map in list of allocations {mem_map_name}') return False shared_mem_map = self.allocated_mem_maps[mem_map_name] - success = shared_mem_map.dispose() + success = shared_mem_map.dispose(is_delete_backing_resources) del self.allocated_mem_maps[mem_map_name] return success diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index 453559855..148c74335 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -504,6 +504,7 @@ async def _handle__close_shared_memory_resources_request(self, req): """ close_request = req.close_shared_memory_resources_request map_names = close_request.map_names + to_delete = close_request.to_delete # Assign default value of False to all result values. # If we are successfully able to close a memory map, its result will be # set to True. @@ -512,7 +513,8 @@ async def _handle__close_shared_memory_resources_request(self, req): try: for mem_map_name in map_names: try: - success = self._shmem_mgr.free_mem_map(mem_map_name) + success = self._shmem_mgr.free_mem_map(mem_map_name, + to_delete) results[mem_map_name] = success except Exception as e: logger.error(f'Cannot free memory map {mem_map_name} - {e}', diff --git a/azure_functions_worker/protos/_src/src/proto/FunctionRpc.proto b/azure_functions_worker/protos/_src/src/proto/FunctionRpc.proto index 403156e24..d3f8278bd 100644 --- a/azure_functions_worker/protos/_src/src/proto/FunctionRpc.proto +++ b/azure_functions_worker/protos/_src/src/proto/FunctionRpc.proto @@ -208,6 +208,7 @@ message FunctionEnvironmentReloadResponse { // Tell the out-of-proc worker to close any shared memory maps it allocated for given invocation message CloseSharedMemoryResourcesRequest { repeated string map_names = 1; + bool to_delete = 2; } // Response from the worker indicating which of the shared memory maps have been successfully closed and which have not 
been closed From 29019cbf2bff409904fe04e37bcc6e32db9f4278 Mon Sep 17 00:00:00 2001 From: Gohar Irfan Chaudhry Date: Wed, 28 Apr 2021 15:31:49 -0700 Subject: [PATCH 02/11] Check for function data cache capability from host --- azure_functions_worker/bindings/meta.py | 16 ++++++++++++++-- azure_functions_worker/constants.py | 1 + azure_functions_worker/dispatcher.py | 16 ++++++++++++---- .../protos/_src/src/proto/FunctionRpc.proto | 1 - 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/azure_functions_worker/bindings/meta.py b/azure_functions_worker/bindings/meta.py index cbf95fcb0..8279f78c1 100644 --- a/azure_functions_worker/bindings/meta.py +++ b/azure_functions_worker/bindings/meta.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +from operator import truediv import sys import typing @@ -111,6 +112,14 @@ def get_datum(binding: str, obj: typing.Any, return datum +def is_cache_supported(datum: datumdef.Datum): + if datum.type == 'bytes': + return True + elif datum.type == 'string': + return True + return False + + def to_outgoing_proto(binding: str, obj: typing.Any, *, pytype: typing.Optional[type]) -> protos.TypedData: datum = get_datum(binding, obj, pytype) @@ -120,13 +129,16 @@ def to_outgoing_proto(binding: str, obj: typing.Any, *, def to_outgoing_param_binding(binding: str, obj: typing.Any, *, pytype: typing.Optional[type], out_name: str, - shmem_mgr) \ + shmem_mgr, + is_function_data_cache_enabled: bool) \ -> protos.ParameterBinding: datum = get_datum(binding, obj, pytype) shared_mem_value = None # If shared memory is enabled and supported for the given datum, try to # transfer to host over shared memory as a default - if shmem_mgr.is_enabled() and shmem_mgr.is_supported(datum): + can_transfer_over_shmem = shmem_mgr.is_supported(datum) or \ + (is_function_data_cache_enabled and is_cache_supported(datum)) + if shmem_mgr.is_enabled() and can_transfer_over_shmem: shared_mem_value = 
datumdef.Datum.to_rpc_shared_memory(datum, shmem_mgr) # Check if data was written into shared memory if shared_mem_value is not None: diff --git a/azure_functions_worker/constants.py b/azure_functions_worker/constants.py index 6c75ddb19..16ab3419a 100644 --- a/azure_functions_worker/constants.py +++ b/azure_functions_worker/constants.py @@ -11,6 +11,7 @@ RPC_HTTP_TRIGGER_METADATA_REMOVED = "RpcHttpTriggerMetadataRemoved" WORKER_STATUS = "WorkerStatus" SHARED_MEMORY_DATA_TRANSFER = "SharedMemoryDataTransfer" +FUNCTION_DATA_CACHE = "FunctionDataCache" # Debug Flags PYAZURE_WEBHOST_DEBUG = "PYAZURE_WEBHOST_DEBUG" diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index 148c74335..19a0f4478 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -15,6 +15,7 @@ from asyncio import BaseEventLoop from logging import LogRecord from typing import List, Optional +import time # TODO del import grpc @@ -71,6 +72,7 @@ def __init__(self, loop: BaseEventLoop, host: str, port: int, self._port = port self._request_id = request_id self._worker_id = worker_id + self._function_data_cache_enabled = False self._functions = functions.Registry() self._shmem_mgr = SharedMemoryManager() @@ -258,6 +260,12 @@ async def _handle__worker_init_request(self, req): logger.info('Received WorkerInitRequest, request ID %s', self.request_id) + worker_init_request = req.worker_init_request + host_capabilities = worker_init_request.capabilities + if constants.FUNCTION_DATA_CACHE in host_capabilities: + val = host_capabilities[constants.FUNCTION_DATA_CACHE] + self._function_data_cache_enabled = val == _TRUE + capabilities = { constants.RAW_HTTP_BODY_BYTES: _TRUE, constants.TYPED_DATA_COLLECTION: _TRUE, @@ -392,6 +400,7 @@ async def _handle__invocation_request(self, req): 'binding returned a non-None value') output_data = [] + cache_enabled = self._function_data_cache_enabled if fi.output_types: for out_name, out_type_info in 
fi.output_types.items(): val = args[out_name].get() @@ -403,7 +412,8 @@ async def _handle__invocation_request(self, req): param_binding = bindings.to_outgoing_param_binding( out_type_info.binding_name, val, pytype=out_type_info.pytype, - out_name=out_name, shmem_mgr=self._shmem_mgr) + out_name=out_name, shmem_mgr=self._shmem_mgr, + is_function_data_cache_enabled=cache_enabled) output_data.append(param_binding) return_value = None @@ -504,7 +514,6 @@ async def _handle__close_shared_memory_resources_request(self, req): """ close_request = req.close_shared_memory_resources_request map_names = close_request.map_names - to_delete = close_request.to_delete # Assign default value of False to all result values. # If we are successfully able to close a memory map, its result will be # set to True. @@ -513,8 +522,7 @@ async def _handle__close_shared_memory_resources_request(self, req): try: for mem_map_name in map_names: try: - success = self._shmem_mgr.free_mem_map(mem_map_name, - to_delete) + success = self._shmem_mgr.free_mem_map(mem_map_name, False) results[mem_map_name] = success except Exception as e: logger.error(f'Cannot free memory map {mem_map_name} - {e}', diff --git a/azure_functions_worker/protos/_src/src/proto/FunctionRpc.proto b/azure_functions_worker/protos/_src/src/proto/FunctionRpc.proto index d3f8278bd..403156e24 100644 --- a/azure_functions_worker/protos/_src/src/proto/FunctionRpc.proto +++ b/azure_functions_worker/protos/_src/src/proto/FunctionRpc.proto @@ -208,7 +208,6 @@ message FunctionEnvironmentReloadResponse { // Tell the out-of-proc worker to close any shared memory maps it allocated for given invocation message CloseSharedMemoryResourcesRequest { repeated string map_names = 1; - bool to_delete = 2; } // Response from the worker indicating which of the shared memory maps have been successfully closed and which have not been closed From 934262253df53e6250dcb94a510056690356fa33 Mon Sep 17 00:00:00 2001 From: Gohar Irfan Chaudhry Date: Wed, 28 Apr 
2021 15:39:32 -0700 Subject: [PATCH 03/11] Removing unused import --- azure_functions_worker/dispatcher.py | 1 - 1 file changed, 1 deletion(-) diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index 19a0f4478..ad12bf400 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -15,7 +15,6 @@ from asyncio import BaseEventLoop from logging import LogRecord from typing import List, Optional -import time # TODO del import grpc From bbc2b04570678de1e5c9fe4c0318c96ff1c095a7 Mon Sep 17 00:00:00 2001 From: Gohar Irfan Chaudhry Date: Wed, 28 Apr 2021 15:42:53 -0700 Subject: [PATCH 04/11] Removing unused import --- azure_functions_worker/bindings/meta.py | 1 - 1 file changed, 1 deletion(-) diff --git a/azure_functions_worker/bindings/meta.py b/azure_functions_worker/bindings/meta.py index 8279f78c1..7526a952c 100644 --- a/azure_functions_worker/bindings/meta.py +++ b/azure_functions_worker/bindings/meta.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
-from operator import truediv import sys import typing From 486d71b53f94fe29e37129a23e3b21d42cfbb37a Mon Sep 17 00:00:00 2001 From: Gohar Irfan Chaudhry Date: Mon, 20 Sep 2021 08:25:50 -0700 Subject: [PATCH 05/11] Addressing comments --- .../shared_memory_manager.py | 402 ++--- azure_functions_worker/dispatcher.py | 1404 ++++++++--------- tests/unittests/test_shared_memory_manager.py | 712 +++++---- 3 files changed, 1270 insertions(+), 1248 deletions(-) diff --git a/azure_functions_worker/bindings/shared_memory_data_transfer/shared_memory_manager.py b/azure_functions_worker/bindings/shared_memory_data_transfer/shared_memory_manager.py index 661bf1fcc..c8efc55d9 100644 --- a/azure_functions_worker/bindings/shared_memory_data_transfer/shared_memory_manager.py +++ b/azure_functions_worker/bindings/shared_memory_data_transfer/shared_memory_manager.py @@ -1,201 +1,201 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import uuid -from typing import Dict, Optional -from .shared_memory_constants import SharedMemoryConstants as consts -from .file_accessor_factory import FileAccessorFactory -from .shared_memory_metadata import SharedMemoryMetadata -from .shared_memory_map import SharedMemoryMap -from ..datumdef import Datum -from ...logging import logger -from ...utils.common import is_envvar_true -from ...constants import FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED - - -class SharedMemoryManager: - """ - Performs all operations related to reading/writing data from/to shared - memory. - This is used for transferring input/output data of the function from/to the - functions host over shared memory as opposed to RPC to improve the rate of - data transfer and the function's end-to-end latency. - """ - def __init__(self): - # The allocated memory maps are tracked here so that a reference to them - # is kept open until they have been used (e.g. if they contain a - # function's output, it is read by the functions host). 
- # Having a mapping of the name and the memory map is then later used to - # close a given memory map by its name, after it has been used. - # key: mem_map_name, val: SharedMemoryMap - self._allocated_mem_maps: Dict[str, SharedMemoryMap] = {} - self._file_accessor = FileAccessorFactory.create_file_accessor() - - def __del__(self): - del self._file_accessor - del self._allocated_mem_maps - - @property - def allocated_mem_maps(self): - """ - List of allocated shared memory maps. - """ - return self._allocated_mem_maps - - @property - def file_accessor(self): - """ - FileAccessor instance for accessing memory maps. - """ - return self._file_accessor - - def is_enabled(self) -> bool: - """ - Whether supported types should be transferred between functions host and - the worker using shared memory. - """ - return is_envvar_true( - FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED) - - def is_supported(self, datum: Datum) -> bool: - """ - Whether the given Datum object can be transferred to the functions host - using shared memory. - This logic is kept consistent with the host's which can be found in - SharedMemoryManager.cs - """ - if datum.type == 'bytes': - num_bytes = len(datum.value) - if num_bytes >= consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER and \ - num_bytes <= consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER: - return True - elif datum.type == 'string': - num_bytes = len(datum.value) * consts.SIZE_OF_CHAR_BYTES - if num_bytes >= consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER and \ - num_bytes <= consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER: - return True - return False - - def put_bytes(self, content: bytes) -> Optional[SharedMemoryMetadata]: - """ - Writes the given bytes into shared memory. - Returns metadata about the shared memory region to which the content was - written if successful, None otherwise. 
- """ - if content is None: - return None - mem_map_name = str(uuid.uuid4()) - content_length = len(content) - shared_mem_map = self._create(mem_map_name, content_length) - if shared_mem_map is None: - return None - try: - num_bytes_written = shared_mem_map.put_bytes(content) - except Exception as e: - logger.warning(f'Cannot write {content_length} bytes into shared ' - f'memory {mem_map_name} - {e}') - shared_mem_map.dispose() - return None - if num_bytes_written != content_length: - logger.error( - f'Cannot write data into shared memory {mem_map_name} ' - f'({num_bytes_written} != {content_length})') - shared_mem_map.dispose() - return None - self.allocated_mem_maps[mem_map_name] = shared_mem_map - return SharedMemoryMetadata(mem_map_name, content_length) - - def put_string(self, content: str) -> Optional[SharedMemoryMetadata]: - """ - Writes the given string into shared memory. - Returns the name of the memory map into which the data was written if - succesful, None otherwise. - Note: The encoding used here must be consistent with what is used by the - host in SharedMemoryManager.cs (GetStringAsync/PutStringAsync). - """ - if content is None: - return None - content_bytes = content.encode('utf-8') - return self.put_bytes(content_bytes) - - def get_bytes(self, mem_map_name: str, offset: int, count: int) \ - -> Optional[bytes]: - """ - Reads data from the given memory map with the provided name, starting at - the provided offset and reading a total of count bytes. - Returns the data read from shared memory as bytes if successful, None - otherwise. - """ - if offset != 0: - logger.error( - f'Cannot read bytes. 
Non-zero offset ({offset}) ' - f'not supported.') - return None - shared_mem_map = self._open(mem_map_name, count) - if shared_mem_map is None: - return None - try: - content = shared_mem_map.get_bytes(content_offset=0, - bytes_to_read=count) - finally: - shared_mem_map.dispose(is_delete_file=False) - return content - - def get_string(self, mem_map_name: str, offset: int, count: int) \ - -> Optional[str]: - """ - Reads data from the given memory map with the provided name, starting at - the provided offset and reading a total of count bytes. - Returns the data read from shared memory as a string if successful, None - otherwise. - Note: The encoding used here must be consistent with what is used by the - host in SharedMemoryManager.cs (GetStringAsync/PutStringAsync). - """ - content_bytes = self.get_bytes(mem_map_name, offset, count) - if content_bytes is None: - return None - content_str = content_bytes.decode('utf-8') - return content_str - - def free_mem_map(self, mem_map_name: str, - is_delete_backing_resources: bool = True) -> bool: - """ - Frees the memory map and, if specified, any backing resources (e.g. - file in the case of Unix) associated with it. - If there is no memory map with the given name being tracked, then no - action is performed. - Returns True if the memory map was freed successfully, False otherwise. - """ - if mem_map_name not in self.allocated_mem_maps: - logger.error( - f'Cannot find memory map in list of allocations {mem_map_name}') - return False - shared_mem_map = self.allocated_mem_maps[mem_map_name] - success = shared_mem_map.dispose(is_delete_backing_resources) - del self.allocated_mem_maps[mem_map_name] - return success - - def _create(self, mem_map_name: str, content_length: int) \ - -> Optional[SharedMemoryMap]: - """ - Creates a new SharedMemoryMap with the given name and content length. - Returns the SharedMemoryMap object if successful, None otherwise. 
- """ - mem_map_size = consts.CONTENT_HEADER_TOTAL_BYTES + content_length - mem_map = self.file_accessor.create_mem_map(mem_map_name, mem_map_size) - if mem_map is None: - return None - return SharedMemoryMap(self.file_accessor, mem_map_name, mem_map) - - def _open(self, mem_map_name: str, content_length: int) \ - -> Optional[SharedMemoryMap]: - """ - Opens an existing SharedMemoryMap with the given name and content - length. - Returns the SharedMemoryMap object if successful, None otherwise. - """ - mem_map_size = consts.CONTENT_HEADER_TOTAL_BYTES + content_length - mem_map = self.file_accessor.open_mem_map(mem_map_name, mem_map_size) - if mem_map is None: - return None - return SharedMemoryMap(self.file_accessor, mem_map_name, mem_map) +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import uuid +from typing import Dict, Optional +from .shared_memory_constants import SharedMemoryConstants as consts +from .file_accessor_factory import FileAccessorFactory +from .shared_memory_metadata import SharedMemoryMetadata +from .shared_memory_map import SharedMemoryMap +from ..datumdef import Datum +from ...logging import logger +from ...utils.common import is_envvar_true +from ...constants import FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED + + +class SharedMemoryManager: + """ + Performs all operations related to reading/writing data from/to shared + memory. + This is used for transferring input/output data of the function from/to the + functions host over shared memory as opposed to RPC to improve the rate of + data transfer and the function's end-to-end latency. + """ + def __init__(self): + # The allocated memory maps are tracked here so that a reference to them + # is kept open until they have been used (e.g. if they contain a + # function's output, it is read by the functions host). 
+ # Having a mapping of the name and the memory map is then later used to + # close a given memory map by its name, after it has been used. + # key: mem_map_name, val: SharedMemoryMap + self._allocated_mem_maps: Dict[str, SharedMemoryMap] = {} + self._file_accessor = FileAccessorFactory.create_file_accessor() + + def __del__(self): + del self._file_accessor + del self._allocated_mem_maps + + @property + def allocated_mem_maps(self): + """ + List of allocated shared memory maps. + """ + return self._allocated_mem_maps + + @property + def file_accessor(self): + """ + FileAccessor instance for accessing memory maps. + """ + return self._file_accessor + + def is_enabled(self) -> bool: + """ + Whether supported types should be transferred between functions host and + the worker using shared memory. + """ + return is_envvar_true( + FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED) + + def is_supported(self, datum: Datum) -> bool: + """ + Whether the given Datum object can be transferred to the functions host + using shared memory. + This logic is kept consistent with the host's which can be found in + SharedMemoryManager.cs + """ + if datum.type == 'bytes': + num_bytes = len(datum.value) + if num_bytes >= consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER and \ + num_bytes <= consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER: + return True + elif datum.type == 'string': + num_bytes = len(datum.value) * consts.SIZE_OF_CHAR_BYTES + if num_bytes >= consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER and \ + num_bytes <= consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER: + return True + return False + + def put_bytes(self, content: bytes) -> Optional[SharedMemoryMetadata]: + """ + Writes the given bytes into shared memory. + Returns metadata about the shared memory region to which the content was + written if successful, None otherwise. 
+ """ + if content is None: + return None + mem_map_name = str(uuid.uuid4()) + content_length = len(content) + shared_mem_map = self._create(mem_map_name, content_length) + if shared_mem_map is None: + return None + try: + num_bytes_written = shared_mem_map.put_bytes(content) + except Exception as e: + logger.warning(f'Cannot write {content_length} bytes into shared ' + f'memory {mem_map_name} - {e}') + shared_mem_map.dispose() + return None + if num_bytes_written != content_length: + logger.error( + f'Cannot write data into shared memory {mem_map_name} ' + f'({num_bytes_written} != {content_length})') + shared_mem_map.dispose() + return None + self.allocated_mem_maps[mem_map_name] = shared_mem_map + return SharedMemoryMetadata(mem_map_name, content_length) + + def put_string(self, content: str) -> Optional[SharedMemoryMetadata]: + """ + Writes the given string into shared memory. + Returns the name of the memory map into which the data was written if + succesful, None otherwise. + Note: The encoding used here must be consistent with what is used by the + host in SharedMemoryManager.cs (GetStringAsync/PutStringAsync). + """ + if content is None: + return None + content_bytes = content.encode('utf-8') + return self.put_bytes(content_bytes) + + def get_bytes(self, mem_map_name: str, offset: int, count: int) \ + -> Optional[bytes]: + """ + Reads data from the given memory map with the provided name, starting at + the provided offset and reading a total of count bytes. + Returns the data read from shared memory as bytes if successful, None + otherwise. + """ + if offset != 0: + logger.error( + f'Cannot read bytes. 
Non-zero offset ({offset}) ' + f'not supported.') + return None + shared_mem_map = self._open(mem_map_name, count) + if shared_mem_map is None: + return None + try: + content = shared_mem_map.get_bytes(content_offset=0, + bytes_to_read=count) + finally: + shared_mem_map.dispose(is_delete_file=False) + return content + + def get_string(self, mem_map_name: str, offset: int, count: int) \ + -> Optional[str]: + """ + Reads data from the given memory map with the provided name, starting at + the provided offset and reading a total of count bytes. + Returns the data read from shared memory as a string if successful, None + otherwise. + Note: The encoding used here must be consistent with what is used by the + host in SharedMemoryManager.cs (GetStringAsync/PutStringAsync). + """ + content_bytes = self.get_bytes(mem_map_name, offset, count) + if content_bytes is None: + return None + content_str = content_bytes.decode('utf-8') + return content_str + + def free_mem_map(self, mem_map_name: str, + to_delete_backing_resources: bool = True) -> bool: + """ + Frees the memory map and, if specified, any backing resources (e.g. + file in the case of Unix) associated with it. + If there is no memory map with the given name being tracked, then no + action is performed. + Returns True if the memory map was freed successfully, False otherwise. + """ + if mem_map_name not in self.allocated_mem_maps: + logger.error( + f'Cannot find memory map in list of allocations {mem_map_name}') + return False + shared_mem_map = self.allocated_mem_maps[mem_map_name] + success = shared_mem_map.dispose(to_delete_backing_resources) + del self.allocated_mem_maps[mem_map_name] + return success + + def _create(self, mem_map_name: str, content_length: int) \ + -> Optional[SharedMemoryMap]: + """ + Creates a new SharedMemoryMap with the given name and content length. + Returns the SharedMemoryMap object if successful, None otherwise. 
+ """ + mem_map_size = consts.CONTENT_HEADER_TOTAL_BYTES + content_length + mem_map = self.file_accessor.create_mem_map(mem_map_name, mem_map_size) + if mem_map is None: + return None + return SharedMemoryMap(self.file_accessor, mem_map_name, mem_map) + + def _open(self, mem_map_name: str, content_length: int) \ + -> Optional[SharedMemoryMap]: + """ + Opens an existing SharedMemoryMap with the given name and content + length. + Returns the SharedMemoryMap object if successful, None otherwise. + """ + mem_map_size = consts.CONTENT_HEADER_TOTAL_BYTES + content_length + mem_map = self.file_accessor.open_mem_map(mem_map_name, mem_map_size) + if mem_map is None: + return None + return SharedMemoryMap(self.file_accessor, mem_map_name, mem_map) diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index ad12bf400..0c843e9ad 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -1,702 +1,702 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -"""GRPC client. - -Implements loading and execution of Python workers. -""" - -import asyncio -import concurrent.futures -import logging -import os -import queue -import sys -import threading -from asyncio import BaseEventLoop -from logging import LogRecord -from typing import List, Optional - -import grpc - -from . import bindings -from . import constants -from . import functions -from . import loader -from . 
import protos -from .constants import (CONSOLE_LOG_PREFIX, PYTHON_THREADPOOL_THREAD_COUNT, - PYTHON_THREADPOOL_THREAD_COUNT_DEFAULT, - PYTHON_THREADPOOL_THREAD_COUNT_MAX, - PYTHON_THREADPOOL_THREAD_COUNT_MIN) -from .logging import disable_console_logging, enable_console_logging -from .logging import error_logger, is_system_log_category, logger -from .utils.common import get_app_setting -from .utils.tracing import marshall_exception_trace -from .utils.dependency import DependencyManager -from .utils.wrappers import disable_feature_by -from .bindings.shared_memory_data_transfer import SharedMemoryManager - -_TRUE = "true" - -"""In Python 3.6, the current_task method was in the Task class, but got moved -out in 3.7+ and fully removed in 3.9. Thus, to support 3.6 and 3.9 together, we -need to switch the implementation of current_task for 3.6. -""" -_CURRENT_TASK = asyncio.Task.current_task \ - if (sys.version_info[0] == 3 and sys.version_info[1] == 6) \ - else asyncio.current_task - - -class DispatcherMeta(type): - - __current_dispatcher__ = None - - @property - def current(mcls): - disp = mcls.__current_dispatcher__ - if disp is None: - raise RuntimeError('no currently running Dispatcher is found') - return disp - - -class Dispatcher(metaclass=DispatcherMeta): - - _GRPC_STOP_RESPONSE = object() - - def __init__(self, loop: BaseEventLoop, host: str, port: int, - worker_id: str, request_id: str, - grpc_connect_timeout: float, - grpc_max_msg_len: int = -1) -> None: - self._loop = loop - self._host = host - self._port = port - self._request_id = request_id - self._worker_id = worker_id - self._function_data_cache_enabled = False - self._functions = functions.Registry() - self._shmem_mgr = SharedMemoryManager() - - self._old_task_factory = None - - # We allow the customer to change synchronous thread pool max worker - # count by setting the PYTHON_THREADPOOL_THREAD_COUNT app setting. - # For 3.[6|7|8] The default value is 1. 
- # For 3.9, we don't set this value by default but we honor incoming - # the app setting. - self._sync_call_tp: concurrent.futures.Executor = ( - self._create_sync_call_tp(self._get_sync_tp_max_workers()) - ) - - self._grpc_connect_timeout: float = grpc_connect_timeout - # This is set to -1 by default to remove the limitation on msg size - self._grpc_max_msg_len: int = grpc_max_msg_len - self._grpc_resp_queue: queue.Queue = queue.Queue() - self._grpc_connected_fut = loop.create_future() - self._grpc_thread: threading.Thread = threading.Thread( - name='grpc-thread', target=self.__poll_grpc) - - def get_sync_tp_workers_set(self): - """We don't know the exact value of the threadcount set for the Python - 3.9 scenarios (as we'll start passing only None by default), and we - need to get that information. - - Ref: concurrent.futures.thread.ThreadPoolExecutor.__init__._max_workers - """ - return self._sync_call_tp._max_workers - - @classmethod - async def connect(cls, host: str, port: int, worker_id: str, - request_id: str, connect_timeout: float): - loop = asyncio.events.get_event_loop() - disp = cls(loop, host, port, worker_id, request_id, connect_timeout) - disp._grpc_thread.start() - await disp._grpc_connected_fut - logger.info('Successfully opened gRPC channel to %s:%s ', host, port) - return disp - - async def dispatch_forever(self): - if DispatcherMeta.__current_dispatcher__ is not None: - raise RuntimeError('there can be only one running dispatcher per ' - 'process') - - self._old_task_factory = self._loop.get_task_factory() - - loader.install() - - DispatcherMeta.__current_dispatcher__ = self - try: - forever = self._loop.create_future() - - self._grpc_resp_queue.put_nowait( - protos.StreamingMessage( - request_id=self.request_id, - start_stream=protos.StartStream( - worker_id=self.worker_id))) - - self._loop.set_task_factory( - lambda loop, coro: ContextEnabledTask(coro, loop=loop)) - - # Detach console logging before enabling GRPC channel logging - 
logger.info('Detaching console logging.') - disable_console_logging() - - # Attach gRPC logging to the root logger. Since gRPC channel is - # established, should use it for system and user logs - logging_handler = AsyncLoggingHandler() - root_logger = logging.getLogger() - - # Don't change this unless you read #780 and #745 - root_logger.setLevel(logging.INFO) - root_logger.addHandler(logging_handler) - logger.info('Switched to gRPC logging.') - logging_handler.flush() - - try: - await forever - finally: - logger.warning('Detaching gRPC logging due to exception.') - logging_handler.flush() - root_logger.removeHandler(logging_handler) - - # Reenable console logging when there's an exception - enable_console_logging() - logger.warning('Switched to console logging due to exception.') - finally: - DispatcherMeta.__current_dispatcher__ = None - - loader.uninstall() - - self._loop.set_task_factory(self._old_task_factory) - self.stop() - - def stop(self) -> None: - if self._grpc_thread is not None: - self._grpc_resp_queue.put_nowait(self._GRPC_STOP_RESPONSE) - self._grpc_thread.join() - self._grpc_thread = None - - self._stop_sync_call_tp() - - def on_logging(self, record: logging.LogRecord, formatted_msg: str) -> None: - if record.levelno >= logging.CRITICAL: - log_level = protos.RpcLog.Critical - elif record.levelno >= logging.ERROR: - log_level = protos.RpcLog.Error - elif record.levelno >= logging.WARNING: - log_level = protos.RpcLog.Warning - elif record.levelno >= logging.INFO: - log_level = protos.RpcLog.Information - elif record.levelno >= logging.DEBUG: - log_level = protos.RpcLog.Debug - else: - log_level = getattr(protos.RpcLog, 'None') - - if is_system_log_category(record.name): - log_category = protos.RpcLog.RpcLogCategory.Value('System') - else: # customers using logging will yield 'root' in record.name - log_category = protos.RpcLog.RpcLogCategory.Value('User') - - log = dict( - level=log_level, - message=formatted_msg, - category=record.name, - 
log_category=log_category - ) - - invocation_id = get_current_invocation_id() - if invocation_id is not None: - log['invocation_id'] = invocation_id - - # XXX: When an exception field is set in RpcLog, WebHost doesn't - # wait for the call result and simply aborts the execution. - # - # if record.exc_info and record.exc_info[1] is not None: - # log['exception'] = self._serialize_exception(record.exc_info[1]) - - self._grpc_resp_queue.put_nowait( - protos.StreamingMessage( - request_id=self.request_id, - rpc_log=protos.RpcLog(**log))) - - @property - def request_id(self) -> str: - return self._request_id - - @property - def worker_id(self) -> str: - return self._worker_id - - # noinspection PyBroadException - @staticmethod - def _serialize_exception(exc: Exception): - try: - message = f'{type(exc).__name__}: {exc}' - except Exception: - message = ('Unhandled exception in function. ' - 'Could not serialize original exception message.') - - try: - stack_trace = marshall_exception_trace(exc) - except Exception: - stack_trace = '' - - return protos.RpcException(message=message, stack_trace=stack_trace) - - async def _dispatch_grpc_request(self, request): - content_type = request.WhichOneof('content') - request_handler = getattr(self, f'_handle__{content_type}', None) - if request_handler is None: - # Don't crash on unknown messages. Some of them can be ignored; - # and if something goes really wrong the host can always just - # kill the worker's process. 
- logger.error(f'unknown StreamingMessage content type ' - f'{content_type}') - return - - resp = await request_handler(request) - self._grpc_resp_queue.put_nowait(resp) - - async def _handle__worker_init_request(self, req): - logger.info('Received WorkerInitRequest, request ID %s', - self.request_id) - - worker_init_request = req.worker_init_request - host_capabilities = worker_init_request.capabilities - if constants.FUNCTION_DATA_CACHE in host_capabilities: - val = host_capabilities[constants.FUNCTION_DATA_CACHE] - self._function_data_cache_enabled = val == _TRUE - - capabilities = { - constants.RAW_HTTP_BODY_BYTES: _TRUE, - constants.TYPED_DATA_COLLECTION: _TRUE, - constants.RPC_HTTP_BODY_ONLY: _TRUE, - constants.RPC_HTTP_TRIGGER_METADATA_REMOVED: _TRUE, - constants.WORKER_STATUS: _TRUE, - constants.SHARED_MEMORY_DATA_TRANSFER: _TRUE, - } - - # Can detech worker packages - DependencyManager.prioritize_customer_dependencies() - - return protos.StreamingMessage( - request_id=self.request_id, - worker_init_response=protos.WorkerInitResponse( - capabilities=capabilities, - result=protos.StatusResult( - status=protos.StatusResult.Success))) - - async def _handle__worker_status_request(self, req): - # Logging is not necessary in this request since the response is used - # for host to judge scale decisions of out-of-proc languages. - # Having log here will reduce the responsiveness of the worker. 
- return protos.StreamingMessage( - request_id=self.request_id, - worker_status_response=protos.WorkerStatusResponse()) - - async def _handle__function_load_request(self, req): - func_request = req.function_load_request - function_id = func_request.function_id - function_name = func_request.metadata.name - - logger.info(f'Received FunctionLoadRequest, ' - f'request ID: {self.request_id}, ' - f'function ID: {function_id}' - f'function Name: {function_name}') - try: - func = loader.load_function( - func_request.metadata.name, - func_request.metadata.directory, - func_request.metadata.script_file, - func_request.metadata.entry_point) - - self._functions.add_function( - function_id, func, func_request.metadata) - - logger.info('Successfully processed FunctionLoadRequest, ' - f'request ID: {self.request_id}, ' - f'function ID: {function_id},' - f'function Name: {function_name}') - - return protos.StreamingMessage( - request_id=self.request_id, - function_load_response=protos.FunctionLoadResponse( - function_id=function_id, - result=protos.StatusResult( - status=protos.StatusResult.Success))) - - except Exception as ex: - return protos.StreamingMessage( - request_id=self.request_id, - function_load_response=protos.FunctionLoadResponse( - function_id=function_id, - result=protos.StatusResult( - status=protos.StatusResult.Failure, - exception=self._serialize_exception(ex)))) - - async def _handle__invocation_request(self, req): - invoc_request = req.invocation_request - - invocation_id = invoc_request.invocation_id - function_id = invoc_request.function_id - trace_context = bindings.TraceContext( - invoc_request.trace_context.trace_parent, - invoc_request.trace_context.trace_state, - invoc_request.trace_context.attributes) - # Set the current `invocation_id` to the current task so - # that our logging handler can find it. 
- current_task = _CURRENT_TASK(self._loop) - assert isinstance(current_task, ContextEnabledTask) - current_task.set_azure_invocation_id(invocation_id) - - try: - fi: functions.FunctionInfo = self._functions.get_function( - function_id) - - function_invocation_logs: List[str] = [ - 'Received FunctionInvocationRequest', - f'request ID: {self.request_id}', - f'function ID: {function_id}', - f'function name: {fi.name}', - f'invocation ID: {invocation_id}', - f'function type: {"async" if fi.is_async else "sync"}' - ] - if not fi.is_async: - function_invocation_logs.append( - f'sync threadpool max workers: ' - f'{self.get_sync_tp_workers_set()}' - ) - logger.info(', '.join(function_invocation_logs)) - - args = {} - for pb in invoc_request.input_data: - pb_type_info = fi.input_types[pb.name] - if bindings.is_trigger_binding(pb_type_info.binding_name): - trigger_metadata = invoc_request.trigger_metadata - else: - trigger_metadata = None - - args[pb.name] = bindings.from_incoming_proto( - pb_type_info.binding_name, pb, - trigger_metadata=trigger_metadata, - pytype=pb_type_info.pytype, - shmem_mgr=self._shmem_mgr) - - if fi.requires_context: - args['context'] = bindings.Context( - fi.name, fi.directory, invocation_id, trace_context) - - if fi.output_types: - for name in fi.output_types: - args[name] = bindings.Out() - - if fi.is_async: - call_result = await fi.func(**args) - else: - call_result = await self._loop.run_in_executor( - self._sync_call_tp, - self.__run_sync_func, invocation_id, fi.func, args) - if call_result is not None and not fi.has_return: - raise RuntimeError(f'function {fi.name!r} without a $return ' - 'binding returned a non-None value') - - output_data = [] - cache_enabled = self._function_data_cache_enabled - if fi.output_types: - for out_name, out_type_info in fi.output_types.items(): - val = args[out_name].get() - if val is None: - # TODO: is the "Out" parameter optional? - # Can "None" be marshaled into protos.TypedData? 
- continue - - param_binding = bindings.to_outgoing_param_binding( - out_type_info.binding_name, val, - pytype=out_type_info.pytype, - out_name=out_name, shmem_mgr=self._shmem_mgr, - is_function_data_cache_enabled=cache_enabled) - output_data.append(param_binding) - - return_value = None - if fi.return_type is not None: - return_value = bindings.to_outgoing_proto( - fi.return_type.binding_name, call_result, - pytype=fi.return_type.pytype) - - # Actively flush customer print() function to console - sys.stdout.flush() - - return protos.StreamingMessage( - request_id=self.request_id, - invocation_response=protos.InvocationResponse( - invocation_id=invocation_id, - return_value=return_value, - result=protos.StatusResult( - status=protos.StatusResult.Success), - output_data=output_data)) - - except Exception as ex: - return protos.StreamingMessage( - request_id=self.request_id, - invocation_response=protos.InvocationResponse( - invocation_id=invocation_id, - result=protos.StatusResult( - status=protos.StatusResult.Failure, - exception=self._serialize_exception(ex)))) - - async def _handle__function_environment_reload_request(self, req): - """Only runs on Linux Consumption placeholder specialization. 
- """ - try: - logger.info('Received FunctionEnvironmentReloadRequest, ' - 'request ID: %s', self.request_id) - - func_env_reload_request = req.function_environment_reload_request - - # Import before clearing path cache so that the default - # azure.functions modules is available in sys.modules for - # customer use - import azure.functions # NoQA - - # Append function project root to module finding sys.path - if func_env_reload_request.function_app_directory: - sys.path.append(func_env_reload_request.function_app_directory) - - # Clear sys.path import cache, reload all module from new sys.path - sys.path_importer_cache.clear() - - # Reload environment variables - os.environ.clear() - env_vars = func_env_reload_request.environment_variables - for var in env_vars: - os.environ[var] = env_vars[var] - - # Apply PYTHON_THREADPOOL_THREAD_COUNT - self._stop_sync_call_tp() - self._sync_call_tp = ( - self._create_sync_call_tp(self._get_sync_tp_max_workers()) - ) - - # Reload azure google namespaces - DependencyManager.reload_azure_google_namespace( - func_env_reload_request.function_app_directory - ) - - # Change function app directory - if getattr(func_env_reload_request, - 'function_app_directory', None): - self._change_cwd( - func_env_reload_request.function_app_directory) - - success_response = protos.FunctionEnvironmentReloadResponse( - result=protos.StatusResult( - status=protos.StatusResult.Success)) - - return protos.StreamingMessage( - request_id=self.request_id, - function_environment_reload_response=success_response) - - except Exception as ex: - failure_response = protos.FunctionEnvironmentReloadResponse( - result=protos.StatusResult( - status=protos.StatusResult.Failure, - exception=self._serialize_exception(ex))) - - return protos.StreamingMessage( - request_id=self.request_id, - function_environment_reload_response=failure_response) - - async def _handle__close_shared_memory_resources_request(self, req): - """ - Frees any memory maps that were produced as 
output for a given - invocation. - This is called after the functions host is done reading the output from - the worker and wants the worker to free up those resources. - """ - close_request = req.close_shared_memory_resources_request - map_names = close_request.map_names - # Assign default value of False to all result values. - # If we are successfully able to close a memory map, its result will be - # set to True. - results = {mem_map_name: False for mem_map_name in map_names} - - try: - for mem_map_name in map_names: - try: - success = self._shmem_mgr.free_mem_map(mem_map_name, False) - results[mem_map_name] = success - except Exception as e: - logger.error(f'Cannot free memory map {mem_map_name} - {e}', - exc_info=True) - finally: - response = protos.CloseSharedMemoryResourcesResponse( - close_map_results=results) - return protos.StreamingMessage( - request_id=self.request_id, - close_shared_memory_resources_response=response) - - @disable_feature_by(constants.PYTHON_ROLLBACK_CWD_PATH) - def _change_cwd(self, new_cwd: str): - if os.path.exists(new_cwd): - os.chdir(new_cwd) - logger.info('Changing current working directory to %s', new_cwd) - else: - logger.warning('Directory %s is not found when reloading', new_cwd) - - def _stop_sync_call_tp(self): - """Deallocate the current synchronous thread pool and assign - self._sync_call_tp to None. If the thread pool does not exist, - this will be a no op. 
- """ - if getattr(self, '_sync_call_tp', None): - self._sync_call_tp.shutdown() - self._sync_call_tp = None - - @staticmethod - def _get_sync_tp_max_workers() -> Optional[int]: - def tp_max_workers_validator(value: str) -> bool: - try: - int_value = int(value) - except ValueError: - logger.warning(f'{PYTHON_THREADPOOL_THREAD_COUNT} must be an ' - 'integer') - return False - - if int_value < PYTHON_THREADPOOL_THREAD_COUNT_MIN or ( - int_value > PYTHON_THREADPOOL_THREAD_COUNT_MAX): - logger.warning(f'{PYTHON_THREADPOOL_THREAD_COUNT} must be set ' - 'to a value between 1 and 32. ' - 'Reverting to default value for max_workers') - return False - - return True - - # Starting Python 3.9, worker won't be putting a limit on the - # max_workers count in the created threadpool. - default_value = None if sys.version_info.minor == 9 \ - else f'{PYTHON_THREADPOOL_THREAD_COUNT_DEFAULT}' - max_workers = get_app_setting(setting=PYTHON_THREADPOOL_THREAD_COUNT, - default_value=default_value, - validator=tp_max_workers_validator) - - # We can box the app setting as int for earlier python versions. - return int(max_workers) if max_workers else None - - def _create_sync_call_tp( - self, max_worker: Optional[int]) -> concurrent.futures.Executor: - """Create a thread pool executor with max_worker. This is a wrapper - over ThreadPoolExecutor constructor. Consider calling this method after - _stop_sync_call_tp() to ensure only 1 synchronous thread pool is - running. - """ - return concurrent.futures.ThreadPoolExecutor( - max_workers=max_worker - ) - - def __run_sync_func(self, invocation_id, func, params): - # This helper exists because we need to access the current - # invocation_id from ThreadPoolExecutor's threads. 
- _invocation_id_local.v = invocation_id - try: - return func(**params) - finally: - _invocation_id_local.v = None - - def __poll_grpc(self): - options = [] - if self._grpc_max_msg_len: - options.append(('grpc.max_receive_message_length', - self._grpc_max_msg_len)) - options.append(('grpc.max_send_message_length', - self._grpc_max_msg_len)) - - channel = grpc.insecure_channel( - f'{self._host}:{self._port}', options) - - try: - grpc.channel_ready_future(channel).result( - timeout=self._grpc_connect_timeout) - except Exception as ex: - self._loop.call_soon_threadsafe( - self._grpc_connected_fut.set_exception, ex) - return - else: - self._loop.call_soon_threadsafe( - self._grpc_connected_fut.set_result, True) - - stub = protos.FunctionRpcStub(channel) - - def gen(resp_queue): - while True: - msg = resp_queue.get() - if msg is self._GRPC_STOP_RESPONSE: - grpc_req_stream.cancel() - return - yield msg - - grpc_req_stream = stub.EventStream(gen(self._grpc_resp_queue)) - try: - for req in grpc_req_stream: - self._loop.call_soon_threadsafe( - self._loop.create_task, self._dispatch_grpc_request(req)) - except Exception as ex: - if ex is grpc_req_stream: - # Yes, this is how grpc_req_stream iterator exits. - return - error_logger.exception('unhandled error in gRPC thread') - raise - - -class AsyncLoggingHandler(logging.Handler): - - def emit(self, record: LogRecord) -> None: - # Since we disable console log after gRPC channel is initiated, - # we should redirect all the messages into dispatcher. - - # When dispatcher receives an exception, it should switch back - # to console logging. However, it is possible that - # __current_dispatcher__ is set to None as there are still messages - # buffered in this handler, not calling the emit yet. - msg = self.format(record) - try: - Dispatcher.current.on_logging(record, msg) - except RuntimeError as runtime_error: - # This will cause 'Dispatcher not found' failure. 
- # Logging such of an issue will cause infinite loop of gRPC logging - # To mitigate, we should suppress the 2nd level error logging here - # and use print function to report exception instead. - print(f'{CONSOLE_LOG_PREFIX} ERROR: {str(runtime_error)}', - file=sys.stderr, flush=True) - - -class ContextEnabledTask(asyncio.Task): - - AZURE_INVOCATION_ID = '__azure_function_invocation_id__' - - def __init__(self, coro, loop): - super().__init__(coro, loop=loop) - - current_task = _CURRENT_TASK(loop) - if current_task is not None: - invocation_id = getattr( - current_task, self.AZURE_INVOCATION_ID, None) - if invocation_id is not None: - self.set_azure_invocation_id(invocation_id) - - def set_azure_invocation_id(self, invocation_id: str) -> None: - setattr(self, self.AZURE_INVOCATION_ID, invocation_id) - - -def get_current_invocation_id() -> Optional[str]: - loop = asyncio._get_running_loop() - if loop is not None: - current_task = _CURRENT_TASK(loop) - if current_task is not None: - task_invocation_id = getattr(current_task, - ContextEnabledTask.AZURE_INVOCATION_ID, - None) - if task_invocation_id is not None: - return task_invocation_id - - return getattr(_invocation_id_local, 'v', None) - - -_invocation_id_local = threading.local() +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +"""GRPC client. + +Implements loading and execution of Python workers. +""" + +import asyncio +import concurrent.futures +import logging +import os +import queue +import sys +import threading +from asyncio import BaseEventLoop +from logging import LogRecord +from typing import List, Optional + +import grpc + +from . import bindings +from . import constants +from . import functions +from . import loader +from . 
import protos +from .constants import (CONSOLE_LOG_PREFIX, PYTHON_THREADPOOL_THREAD_COUNT, + PYTHON_THREADPOOL_THREAD_COUNT_DEFAULT, + PYTHON_THREADPOOL_THREAD_COUNT_MAX, + PYTHON_THREADPOOL_THREAD_COUNT_MIN) +from .logging import disable_console_logging, enable_console_logging +from .logging import error_logger, is_system_log_category, logger +from .utils.common import get_app_setting +from .utils.tracing import marshall_exception_trace +from .utils.dependency import DependencyManager +from .utils.wrappers import disable_feature_by +from .bindings.shared_memory_data_transfer import SharedMemoryManager + +_TRUE = "true" + +"""In Python 3.6, the current_task method was in the Task class, but got moved +out in 3.7+ and fully removed in 3.9. Thus, to support 3.6 and 3.9 together, we +need to switch the implementation of current_task for 3.6. +""" +_CURRENT_TASK = asyncio.Task.current_task \ + if (sys.version_info[0] == 3 and sys.version_info[1] == 6) \ + else asyncio.current_task + + +class DispatcherMeta(type): + + __current_dispatcher__ = None + + @property + def current(mcls): + disp = mcls.__current_dispatcher__ + if disp is None: + raise RuntimeError('no currently running Dispatcher is found') + return disp + + +class Dispatcher(metaclass=DispatcherMeta): + + _GRPC_STOP_RESPONSE = object() + + def __init__(self, loop: BaseEventLoop, host: str, port: int, + worker_id: str, request_id: str, + grpc_connect_timeout: float, + grpc_max_msg_len: int = -1) -> None: + self._loop = loop + self._host = host + self._port = port + self._request_id = request_id + self._worker_id = worker_id + self._function_data_cache_enabled = False + self._functions = functions.Registry() + self._shmem_mgr = SharedMemoryManager() + + self._old_task_factory = None + + # We allow the customer to change synchronous thread pool max worker + # count by setting the PYTHON_THREADPOOL_THREAD_COUNT app setting. + # For 3.[6|7|8] The default value is 1. 
+ # For 3.9, we don't set this value by default but we honor incoming + # the app setting. + self._sync_call_tp: concurrent.futures.Executor = ( + self._create_sync_call_tp(self._get_sync_tp_max_workers()) + ) + + self._grpc_connect_timeout: float = grpc_connect_timeout + # This is set to -1 by default to remove the limitation on msg size + self._grpc_max_msg_len: int = grpc_max_msg_len + self._grpc_resp_queue: queue.Queue = queue.Queue() + self._grpc_connected_fut = loop.create_future() + self._grpc_thread: threading.Thread = threading.Thread( + name='grpc-thread', target=self.__poll_grpc) + + def get_sync_tp_workers_set(self): + """We don't know the exact value of the threadcount set for the Python + 3.9 scenarios (as we'll start passing only None by default), and we + need to get that information. + + Ref: concurrent.futures.thread.ThreadPoolExecutor.__init__._max_workers + """ + return self._sync_call_tp._max_workers + + @classmethod + async def connect(cls, host: str, port: int, worker_id: str, + request_id: str, connect_timeout: float): + loop = asyncio.events.get_event_loop() + disp = cls(loop, host, port, worker_id, request_id, connect_timeout) + disp._grpc_thread.start() + await disp._grpc_connected_fut + logger.info('Successfully opened gRPC channel to %s:%s ', host, port) + return disp + + async def dispatch_forever(self): + if DispatcherMeta.__current_dispatcher__ is not None: + raise RuntimeError('there can be only one running dispatcher per ' + 'process') + + self._old_task_factory = self._loop.get_task_factory() + + loader.install() + + DispatcherMeta.__current_dispatcher__ = self + try: + forever = self._loop.create_future() + + self._grpc_resp_queue.put_nowait( + protos.StreamingMessage( + request_id=self.request_id, + start_stream=protos.StartStream( + worker_id=self.worker_id))) + + self._loop.set_task_factory( + lambda loop, coro: ContextEnabledTask(coro, loop=loop)) + + # Detach console logging before enabling GRPC channel logging + 
logger.info('Detaching console logging.') + disable_console_logging() + + # Attach gRPC logging to the root logger. Since gRPC channel is + # established, should use it for system and user logs + logging_handler = AsyncLoggingHandler() + root_logger = logging.getLogger() + + # Don't change this unless you read #780 and #745 + root_logger.setLevel(logging.INFO) + root_logger.addHandler(logging_handler) + logger.info('Switched to gRPC logging.') + logging_handler.flush() + + try: + await forever + finally: + logger.warning('Detaching gRPC logging due to exception.') + logging_handler.flush() + root_logger.removeHandler(logging_handler) + + # Reenable console logging when there's an exception + enable_console_logging() + logger.warning('Switched to console logging due to exception.') + finally: + DispatcherMeta.__current_dispatcher__ = None + + loader.uninstall() + + self._loop.set_task_factory(self._old_task_factory) + self.stop() + + def stop(self) -> None: + if self._grpc_thread is not None: + self._grpc_resp_queue.put_nowait(self._GRPC_STOP_RESPONSE) + self._grpc_thread.join() + self._grpc_thread = None + + self._stop_sync_call_tp() + + def on_logging(self, record: logging.LogRecord, formatted_msg: str) -> None: + if record.levelno >= logging.CRITICAL: + log_level = protos.RpcLog.Critical + elif record.levelno >= logging.ERROR: + log_level = protos.RpcLog.Error + elif record.levelno >= logging.WARNING: + log_level = protos.RpcLog.Warning + elif record.levelno >= logging.INFO: + log_level = protos.RpcLog.Information + elif record.levelno >= logging.DEBUG: + log_level = protos.RpcLog.Debug + else: + log_level = getattr(protos.RpcLog, 'None') + + if is_system_log_category(record.name): + log_category = protos.RpcLog.RpcLogCategory.Value('System') + else: # customers using logging will yield 'root' in record.name + log_category = protos.RpcLog.RpcLogCategory.Value('User') + + log = dict( + level=log_level, + message=formatted_msg, + category=record.name, + 
log_category=log_category + ) + + invocation_id = get_current_invocation_id() + if invocation_id is not None: + log['invocation_id'] = invocation_id + + # XXX: When an exception field is set in RpcLog, WebHost doesn't + # wait for the call result and simply aborts the execution. + # + # if record.exc_info and record.exc_info[1] is not None: + # log['exception'] = self._serialize_exception(record.exc_info[1]) + + self._grpc_resp_queue.put_nowait( + protos.StreamingMessage( + request_id=self.request_id, + rpc_log=protos.RpcLog(**log))) + + @property + def request_id(self) -> str: + return self._request_id + + @property + def worker_id(self) -> str: + return self._worker_id + + # noinspection PyBroadException + @staticmethod + def _serialize_exception(exc: Exception): + try: + message = f'{type(exc).__name__}: {exc}' + except Exception: + message = ('Unhandled exception in function. ' + 'Could not serialize original exception message.') + + try: + stack_trace = marshall_exception_trace(exc) + except Exception: + stack_trace = '' + + return protos.RpcException(message=message, stack_trace=stack_trace) + + async def _dispatch_grpc_request(self, request): + content_type = request.WhichOneof('content') + request_handler = getattr(self, f'_handle__{content_type}', None) + if request_handler is None: + # Don't crash on unknown messages. Some of them can be ignored; + # and if something goes really wrong the host can always just + # kill the worker's process. 
+ logger.error(f'unknown StreamingMessage content type ' + f'{content_type}') + return + + resp = await request_handler(request) + self._grpc_resp_queue.put_nowait(resp) + + async def _handle__worker_init_request(self, req): + logger.info('Received WorkerInitRequest, request ID %s', + self.request_id) + + worker_init_request = req.worker_init_request + host_capabilities = worker_init_request.capabilities + if constants.FUNCTION_DATA_CACHE in host_capabilities: + val = host_capabilities[constants.FUNCTION_DATA_CACHE] + self._function_data_cache_enabled = val == _TRUE + + capabilities = { + constants.RAW_HTTP_BODY_BYTES: _TRUE, + constants.TYPED_DATA_COLLECTION: _TRUE, + constants.RPC_HTTP_BODY_ONLY: _TRUE, + constants.RPC_HTTP_TRIGGER_METADATA_REMOVED: _TRUE, + constants.WORKER_STATUS: _TRUE, + constants.SHARED_MEMORY_DATA_TRANSFER: _TRUE, + } + + # Can detech worker packages + DependencyManager.prioritize_customer_dependencies() + + return protos.StreamingMessage( + request_id=self.request_id, + worker_init_response=protos.WorkerInitResponse( + capabilities=capabilities, + result=protos.StatusResult( + status=protos.StatusResult.Success))) + + async def _handle__worker_status_request(self, req): + # Logging is not necessary in this request since the response is used + # for host to judge scale decisions of out-of-proc languages. + # Having log here will reduce the responsiveness of the worker. 
+ return protos.StreamingMessage( + request_id=self.request_id, + worker_status_response=protos.WorkerStatusResponse()) + + async def _handle__function_load_request(self, req): + func_request = req.function_load_request + function_id = func_request.function_id + function_name = func_request.metadata.name + + logger.info(f'Received FunctionLoadRequest, ' + f'request ID: {self.request_id}, ' + f'function ID: {function_id}' + f'function Name: {function_name}') + try: + func = loader.load_function( + func_request.metadata.name, + func_request.metadata.directory, + func_request.metadata.script_file, + func_request.metadata.entry_point) + + self._functions.add_function( + function_id, func, func_request.metadata) + + logger.info('Successfully processed FunctionLoadRequest, ' + f'request ID: {self.request_id}, ' + f'function ID: {function_id},' + f'function Name: {function_name}') + + return protos.StreamingMessage( + request_id=self.request_id, + function_load_response=protos.FunctionLoadResponse( + function_id=function_id, + result=protos.StatusResult( + status=protos.StatusResult.Success))) + + except Exception as ex: + return protos.StreamingMessage( + request_id=self.request_id, + function_load_response=protos.FunctionLoadResponse( + function_id=function_id, + result=protos.StatusResult( + status=protos.StatusResult.Failure, + exception=self._serialize_exception(ex)))) + + async def _handle__invocation_request(self, req): + invoc_request = req.invocation_request + + invocation_id = invoc_request.invocation_id + function_id = invoc_request.function_id + trace_context = bindings.TraceContext( + invoc_request.trace_context.trace_parent, + invoc_request.trace_context.trace_state, + invoc_request.trace_context.attributes) + # Set the current `invocation_id` to the current task so + # that our logging handler can find it. 
+ current_task = _CURRENT_TASK(self._loop) + assert isinstance(current_task, ContextEnabledTask) + current_task.set_azure_invocation_id(invocation_id) + + try: + fi: functions.FunctionInfo = self._functions.get_function( + function_id) + + function_invocation_logs: List[str] = [ + 'Received FunctionInvocationRequest', + f'request ID: {self.request_id}', + f'function ID: {function_id}', + f'function name: {fi.name}', + f'invocation ID: {invocation_id}', + f'function type: {"async" if fi.is_async else "sync"}' + ] + if not fi.is_async: + function_invocation_logs.append( + f'sync threadpool max workers: ' + f'{self.get_sync_tp_workers_set()}' + ) + logger.info(', '.join(function_invocation_logs)) + + args = {} + for pb in invoc_request.input_data: + pb_type_info = fi.input_types[pb.name] + if bindings.is_trigger_binding(pb_type_info.binding_name): + trigger_metadata = invoc_request.trigger_metadata + else: + trigger_metadata = None + + args[pb.name] = bindings.from_incoming_proto( + pb_type_info.binding_name, pb, + trigger_metadata=trigger_metadata, + pytype=pb_type_info.pytype, + shmem_mgr=self._shmem_mgr) + + if fi.requires_context: + args['context'] = bindings.Context( + fi.name, fi.directory, invocation_id, trace_context) + + if fi.output_types: + for name in fi.output_types: + args[name] = bindings.Out() + + if fi.is_async: + call_result = await fi.func(**args) + else: + call_result = await self._loop.run_in_executor( + self._sync_call_tp, + self.__run_sync_func, invocation_id, fi.func, args) + if call_result is not None and not fi.has_return: + raise RuntimeError(f'function {fi.name!r} without a $return ' + 'binding returned a non-None value') + + output_data = [] + cache_enabled = self._function_data_cache_enabled + if fi.output_types: + for out_name, out_type_info in fi.output_types.items(): + val = args[out_name].get() + if val is None: + # TODO: is the "Out" parameter optional? + # Can "None" be marshaled into protos.TypedData? 
+ continue + + param_binding = bindings.to_outgoing_param_binding( + out_type_info.binding_name, val, + pytype=out_type_info.pytype, + out_name=out_name, shmem_mgr=self._shmem_mgr, + is_function_data_cache_enabled=cache_enabled) + output_data.append(param_binding) + + return_value = None + if fi.return_type is not None: + return_value = bindings.to_outgoing_proto( + fi.return_type.binding_name, call_result, + pytype=fi.return_type.pytype) + + # Actively flush customer print() function to console + sys.stdout.flush() + + return protos.StreamingMessage( + request_id=self.request_id, + invocation_response=protos.InvocationResponse( + invocation_id=invocation_id, + return_value=return_value, + result=protos.StatusResult( + status=protos.StatusResult.Success), + output_data=output_data)) + + except Exception as ex: + return protos.StreamingMessage( + request_id=self.request_id, + invocation_response=protos.InvocationResponse( + invocation_id=invocation_id, + result=protos.StatusResult( + status=protos.StatusResult.Failure, + exception=self._serialize_exception(ex)))) + + async def _handle__function_environment_reload_request(self, req): + """Only runs on Linux Consumption placeholder specialization. 
+ """ + try: + logger.info('Received FunctionEnvironmentReloadRequest, ' + 'request ID: %s', self.request_id) + + func_env_reload_request = req.function_environment_reload_request + + # Import before clearing path cache so that the default + # azure.functions modules is available in sys.modules for + # customer use + import azure.functions # NoQA + + # Append function project root to module finding sys.path + if func_env_reload_request.function_app_directory: + sys.path.append(func_env_reload_request.function_app_directory) + + # Clear sys.path import cache, reload all module from new sys.path + sys.path_importer_cache.clear() + + # Reload environment variables + os.environ.clear() + env_vars = func_env_reload_request.environment_variables + for var in env_vars: + os.environ[var] = env_vars[var] + + # Apply PYTHON_THREADPOOL_THREAD_COUNT + self._stop_sync_call_tp() + self._sync_call_tp = ( + self._create_sync_call_tp(self._get_sync_tp_max_workers()) + ) + + # Reload azure google namespaces + DependencyManager.reload_azure_google_namespace( + func_env_reload_request.function_app_directory + ) + + # Change function app directory + if getattr(func_env_reload_request, + 'function_app_directory', None): + self._change_cwd( + func_env_reload_request.function_app_directory) + + success_response = protos.FunctionEnvironmentReloadResponse( + result=protos.StatusResult( + status=protos.StatusResult.Success)) + + return protos.StreamingMessage( + request_id=self.request_id, + function_environment_reload_response=success_response) + + except Exception as ex: + failure_response = protos.FunctionEnvironmentReloadResponse( + result=protos.StatusResult( + status=protos.StatusResult.Failure, + exception=self._serialize_exception(ex))) + + return protos.StreamingMessage( + request_id=self.request_id, + function_environment_reload_response=failure_response) + + async def _handle__close_shared_memory_resources_request(self, req): + """ + Frees any memory maps that were produced as 
output for a given + invocation. + This is called after the functions host is done reading the output from + the worker and wants the worker to free up those resources. + """ + close_request = req.close_shared_memory_resources_request + map_names = close_request.map_names + # Assign default value of False to all result values. + # If we are successfully able to close a memory map, its result will be + # set to True. + results = {mem_map_name: False for mem_map_name in map_names} + + try: + for mem_map_name in map_names: + try: + success = self._shmem_mgr.free_mem_map(mem_map_name, False) + results[mem_map_name] = success + except Exception as e: + logger.error(f'Cannot free memory map {mem_map_name} - {e}', + exc_info=True) + finally: + response = protos.CloseSharedMemoryResourcesResponse( + close_map_results=results) + return protos.StreamingMessage( + request_id=self.request_id, + close_shared_memory_resources_response=response) + + @disable_feature_by(constants.PYTHON_ROLLBACK_CWD_PATH) + def _change_cwd(self, new_cwd: str): + if os.path.exists(new_cwd): + os.chdir(new_cwd) + logger.info('Changing current working directory to %s', new_cwd) + else: + logger.warning('Directory %s is not found when reloading', new_cwd) + + def _stop_sync_call_tp(self): + """Deallocate the current synchronous thread pool and assign + self._sync_call_tp to None. If the thread pool does not exist, + this will be a no op. 
+ """ + if getattr(self, '_sync_call_tp', None): + self._sync_call_tp.shutdown() + self._sync_call_tp = None + + @staticmethod + def _get_sync_tp_max_workers() -> Optional[int]: + def tp_max_workers_validator(value: str) -> bool: + try: + int_value = int(value) + except ValueError: + logger.warning(f'{PYTHON_THREADPOOL_THREAD_COUNT} must be an ' + 'integer') + return False + + if int_value < PYTHON_THREADPOOL_THREAD_COUNT_MIN or ( + int_value > PYTHON_THREADPOOL_THREAD_COUNT_MAX): + logger.warning(f'{PYTHON_THREADPOOL_THREAD_COUNT} must be set ' + 'to a value between 1 and 32. ' + 'Reverting to default value for max_workers') + return False + + return True + + # Starting Python 3.9, worker won't be putting a limit on the + # max_workers count in the created threadpool. + default_value = None if sys.version_info.minor == 9 \ + else f'{PYTHON_THREADPOOL_THREAD_COUNT_DEFAULT}' + max_workers = get_app_setting(setting=PYTHON_THREADPOOL_THREAD_COUNT, + default_value=default_value, + validator=tp_max_workers_validator) + + # We can box the app setting as int for earlier python versions. + return int(max_workers) if max_workers else None + + def _create_sync_call_tp( + self, max_worker: Optional[int]) -> concurrent.futures.Executor: + """Create a thread pool executor with max_worker. This is a wrapper + over ThreadPoolExecutor constructor. Consider calling this method after + _stop_sync_call_tp() to ensure only 1 synchronous thread pool is + running. + """ + return concurrent.futures.ThreadPoolExecutor( + max_workers=max_worker + ) + + def __run_sync_func(self, invocation_id, func, params): + # This helper exists because we need to access the current + # invocation_id from ThreadPoolExecutor's threads. 
+ _invocation_id_local.v = invocation_id + try: + return func(**params) + finally: + _invocation_id_local.v = None + + def __poll_grpc(self): + options = [] + if self._grpc_max_msg_len: + options.append(('grpc.max_receive_message_length', + self._grpc_max_msg_len)) + options.append(('grpc.max_send_message_length', + self._grpc_max_msg_len)) + + channel = grpc.insecure_channel( + f'{self._host}:{self._port}', options) + + try: + grpc.channel_ready_future(channel).result( + timeout=self._grpc_connect_timeout) + except Exception as ex: + self._loop.call_soon_threadsafe( + self._grpc_connected_fut.set_exception, ex) + return + else: + self._loop.call_soon_threadsafe( + self._grpc_connected_fut.set_result, True) + + stub = protos.FunctionRpcStub(channel) + + def gen(resp_queue): + while True: + msg = resp_queue.get() + if msg is self._GRPC_STOP_RESPONSE: + grpc_req_stream.cancel() + return + yield msg + + grpc_req_stream = stub.EventStream(gen(self._grpc_resp_queue)) + try: + for req in grpc_req_stream: + self._loop.call_soon_threadsafe( + self._loop.create_task, self._dispatch_grpc_request(req)) + except Exception as ex: + if ex is grpc_req_stream: + # Yes, this is how grpc_req_stream iterator exits. + return + error_logger.exception('unhandled error in gRPC thread') + raise + + +class AsyncLoggingHandler(logging.Handler): + + def emit(self, record: LogRecord) -> None: + # Since we disable console log after gRPC channel is initiated, + # we should redirect all the messages into dispatcher. + + # When dispatcher receives an exception, it should switch back + # to console logging. However, it is possible that + # __current_dispatcher__ is set to None as there are still messages + # buffered in this handler, not calling the emit yet. + msg = self.format(record) + try: + Dispatcher.current.on_logging(record, msg) + except RuntimeError as runtime_error: + # This will cause 'Dispatcher not found' failure. 
+ # Logging such of an issue will cause infinite loop of gRPC logging + # To mitigate, we should suppress the 2nd level error logging here + # and use print function to report exception instead. + print(f'{CONSOLE_LOG_PREFIX} ERROR: {str(runtime_error)}', + file=sys.stderr, flush=True) + + +class ContextEnabledTask(asyncio.Task): + + AZURE_INVOCATION_ID = '__azure_function_invocation_id__' + + def __init__(self, coro, loop): + super().__init__(coro, loop=loop) + + current_task = _CURRENT_TASK(loop) + if current_task is not None: + invocation_id = getattr( + current_task, self.AZURE_INVOCATION_ID, None) + if invocation_id is not None: + self.set_azure_invocation_id(invocation_id) + + def set_azure_invocation_id(self, invocation_id: str) -> None: + setattr(self, self.AZURE_INVOCATION_ID, invocation_id) + + +def get_current_invocation_id() -> Optional[str]: + loop = asyncio._get_running_loop() + if loop is not None: + current_task = _CURRENT_TASK(loop) + if current_task is not None: + task_invocation_id = getattr(current_task, + ContextEnabledTask.AZURE_INVOCATION_ID, + None) + if task_invocation_id is not None: + return task_invocation_id + + return getattr(_invocation_id_local, 'v', None) + + +_invocation_id_local = threading.local() diff --git a/tests/unittests/test_shared_memory_manager.py b/tests/unittests/test_shared_memory_manager.py index 9193b20e9..1c5cf994c 100644 --- a/tests/unittests/test_shared_memory_manager.py +++ b/tests/unittests/test_shared_memory_manager.py @@ -1,345 +1,367 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
- -import math -import os -import json -from azure_functions_worker.utils.common import is_envvar_true -from azure.functions import meta as bind_meta -from azure_functions_worker import testutils -from azure_functions_worker.bindings.shared_memory_data_transfer \ - import SharedMemoryManager -from azure_functions_worker.bindings.shared_memory_data_transfer \ - import SharedMemoryConstants as consts -from azure_functions_worker.constants \ - import FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED - - -class TestSharedMemoryManager(testutils.SharedMemoryTestCase): - """ - Tests for SharedMemoryManager. - """ - def test_is_enabled(self): - """ - Verify that when the AppSetting is enabled, SharedMemoryManager is - enabled. - """ - # Make sure shared memory data transfer is enabled - was_shmem_env_true = is_envvar_true( - FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED) - os.environ.update( - {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '1'}) - manager = SharedMemoryManager() - self.assertTrue(manager.is_enabled()) - # Restore the env variable to original value - if not was_shmem_env_true: - os.environ.update( - {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '0'}) - - def test_is_disabled(self): - """ - Verify that when the AppSetting is disabled, SharedMemoryManager is - disabled. - """ - # Make sure shared memory data transfer is disabled - was_shmem_env_true = is_envvar_true( - FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED) - os.environ.update( - {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '0'}) - manager = SharedMemoryManager() - self.assertFalse(manager.is_enabled()) - # Restore the env variable to original value - if was_shmem_env_true: - os.environ.update( - {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '1'}) - - def test_bytes_input_support(self): - """ - Verify that the given input is supported by SharedMemoryManager to be - transfered over shared memory. - The input is bytes. 
- """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - content = self.get_random_bytes(content_size) - bytes_datum = bind_meta.Datum(type='bytes', value=content) - is_supported = manager.is_supported(bytes_datum) - self.assertTrue(is_supported) - - def test_string_input_support(self): - """ - Verify that the given input is supported by SharedMemoryManager to be - transfered over shared memory. - The input is string. - """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) - content = self.get_random_string(num_chars) - bytes_datum = bind_meta.Datum(type='string', value=content) - is_supported = manager.is_supported(bytes_datum) - self.assertTrue(is_supported) - - def test_int_input_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is int. - """ - manager = SharedMemoryManager() - datum = bind_meta.Datum(type='int', value=1) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_double_input_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is double. - """ - manager = SharedMemoryManager() - datum = bind_meta.Datum(type='double', value=1.0) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_json_input_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is json. - """ - manager = SharedMemoryManager() - content = { - 'name': 'foo', - 'val': 'bar' - } - datum = bind_meta.Datum(type='json', value=json.dumps(content)) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_collection_string_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is collection_string. 
- """ - manager = SharedMemoryManager() - content = ['foo', 'bar'] - datum = bind_meta.Datum(type='collection_string', value=content) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_collection_bytes_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is collection_bytes. - """ - manager = SharedMemoryManager() - content = [b'x01', b'x02'] - datum = bind_meta.Datum(type='collection_bytes', value=content) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_collection_double_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is collection_double. - """ - manager = SharedMemoryManager() - content = [1.0, 2.0] - datum = bind_meta.Datum(type='collection_double', value=content) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_collection_sint64_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is collection_sint64. - """ - manager = SharedMemoryManager() - content = [1, 2] - datum = bind_meta.Datum(type='collection_sint64', value=content) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_large_invalid_bytes_input_support(self): - """ - Verify that the given input is NOT supported by SharedMemoryManager to - be transfered over shared memory. - The input is bytes of larger than the allowed size. 
- """ - manager = SharedMemoryManager() - content_size = consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - # Not using get_random_bytes to avoid slowing down for creating a large - # random input - content = b'x01' * content_size - bytes_datum = bind_meta.Datum(type='bytes', value=content) - is_supported = manager.is_supported(bytes_datum) - self.assertFalse(is_supported) - - def test_small_invalid_bytes_input_support(self): - """ - Verify that the given input is NOT supported by SharedMemoryManager to - be transfered over shared memory. - The input is bytes of smaller than the allowed size. - """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER - 10 - content = self.get_random_bytes(content_size) - bytes_datum = bind_meta.Datum(type='bytes', value=content) - is_supported = manager.is_supported(bytes_datum) - self.assertFalse(is_supported) - - def test_large_invalid_string_input_support(self): - """ - Verify that the given input is NOT supported by SharedMemoryManager to - be transfered over shared memory. - The input is string of larger than the allowed size. - """ - manager = SharedMemoryManager() - content_size = consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) - # Not using get_random_string to avoid slowing down for creating a large - # random input - content = 'a' * num_chars - string_datum = bind_meta.Datum(type='string', value=content) - is_supported = manager.is_supported(string_datum) - self.assertFalse(is_supported) - - def test_small_invalid_string_input_support(self): - """ - Verify that the given input is NOT supported by SharedMemoryManager to - be transfered over shared memory. - The input is string of smaller than the allowed size. 
- """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER - 10 - num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) - content = self.get_random_string(num_chars) - string_datum = bind_meta.Datum(type='string', value=content) - is_supported = manager.is_supported(string_datum) - self.assertFalse(is_supported) - - def test_put_bytes(self): - """ - Verify that the given input was successfully put into shared memory. - The input is bytes. - """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - content = self.get_random_bytes(content_size) - shared_mem_meta = manager.put_bytes(content) - self.assertIsNotNone(shared_mem_meta) - self.assertTrue(self.is_valid_uuid(shared_mem_meta.mem_map_name)) - self.assertEqual(content_size, shared_mem_meta.count_bytes) - free_success = manager.free_mem_map(shared_mem_meta.mem_map_name) - self.assertTrue(free_success) - - def test_invalid_put_bytes(self): - """ - Attempt to put bytes using an invalid input and verify that it fails. - """ - manager = SharedMemoryManager() - shared_mem_meta = manager.put_bytes(None) - self.assertIsNone(shared_mem_meta) - - def test_get_bytes(self): - """ - Verify that the output object was successfully gotten from shared - memory. - The output is bytes. - """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - content = self.get_random_bytes(content_size) - shared_mem_meta = manager.put_bytes(content) - mem_map_name = shared_mem_meta.mem_map_name - num_bytes_written = shared_mem_meta.count_bytes - read_content = manager.get_bytes(mem_map_name, offset=0, - count=num_bytes_written) - self.assertEqual(content, read_content) - free_success = manager.free_mem_map(mem_map_name) - self.assertTrue(free_success) - - def test_put_string(self): - """ - Verify that the given input was successfully put into shared memory. - The input is string. 
- """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) - content = self.get_random_string(num_chars) - expected_size = len(content.encode('utf-8')) - shared_mem_meta = manager.put_string(content) - self.assertIsNotNone(shared_mem_meta) - self.assertTrue(self.is_valid_uuid(shared_mem_meta.mem_map_name)) - self.assertEqual(expected_size, shared_mem_meta.count_bytes) - free_success = manager.free_mem_map(shared_mem_meta.mem_map_name) - self.assertTrue(free_success) - - def test_invalid_put_string(self): - """ - Attempt to put a string using an invalid input and verify that it fails. - """ - manager = SharedMemoryManager() - shared_mem_meta = manager.put_string(None) - self.assertIsNone(shared_mem_meta) - - def test_get_string(self): - """ - Verify that the output object was successfully gotten from shared - memory. - The output is string. - """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) - content = self.get_random_string(num_chars) - shared_mem_meta = manager.put_string(content) - mem_map_name = shared_mem_meta.mem_map_name - num_bytes_written = shared_mem_meta.count_bytes - read_content = manager.get_string(mem_map_name, offset=0, - count=num_bytes_written) - self.assertEqual(content, read_content) - free_success = manager.free_mem_map(mem_map_name) - self.assertTrue(free_success) - - def test_allocated_mem_maps(self): - """ - Verify that the SharedMemoryManager is tracking the shared memory maps - it has allocated after put operations. - Verify that those shared memory maps are freed and no longer tracked - after attempting to free them. 
- """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - content = self.get_random_bytes(content_size) - shared_mem_meta = manager.put_bytes(content) - self.assertIsNotNone(shared_mem_meta) - mem_map_name = shared_mem_meta.mem_map_name - is_mem_map_found = mem_map_name in manager.allocated_mem_maps - self.assertTrue(is_mem_map_found) - self.assertEqual(1, len(manager.allocated_mem_maps.keys())) - free_success = manager.free_mem_map(mem_map_name) - self.assertTrue(free_success) - is_mem_map_found = mem_map_name in manager.allocated_mem_maps - self.assertFalse(is_mem_map_found) - self.assertEqual(0, len(manager.allocated_mem_maps.keys())) - - def test_invalid_put_allocated_mem_maps(self): - """ - Verify that after an invalid put operation, no shared memory maps were - added to the list of allocated/tracked shared memory maps. - i.e. no resources were leaked for invalid operations. - """ - manager = SharedMemoryManager() - shared_mem_meta = manager.put_bytes(None) - self.assertIsNone(shared_mem_meta) - self.assertEqual(0, len(manager.allocated_mem_maps.keys())) - - def test_invalid_free_mem_map(self): - """ - Attempt to free a shared memory map that does not exist in the list of - allocated/tracked shared memory maps and verify that it fails. - """ - manager = SharedMemoryManager() - mem_map_name = self.get_new_mem_map_name() - free_success = manager.free_mem_map(mem_map_name) - self.assertFalse(free_success) +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+
+import math
+import os
+import json
+from azure_functions_worker.utils.common import is_envvar_true
+from azure.functions import meta as bind_meta
+from azure_functions_worker import testutils
+from azure_functions_worker.bindings.shared_memory_data_transfer \
+    import SharedMemoryManager
+from azure_functions_worker.bindings.shared_memory_data_transfer \
+    import SharedMemoryConstants as consts
+from azure_functions_worker.constants \
+    import FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED
+
+
+class TestSharedMemoryManager(testutils.SharedMemoryTestCase):
+    """
+    Tests for SharedMemoryManager.
+    """
+    def test_is_enabled(self):
+        """
+        Verify that when the AppSetting is enabled, SharedMemoryManager is
+        enabled.
+        """
+        # Make sure shared memory data transfer is enabled
+        was_shmem_env_true = is_envvar_true(
+            FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED)
+        os.environ.update(
+            {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '1'})
+        manager = SharedMemoryManager()
+        self.assertTrue(manager.is_enabled())
+        # Restore the env variable to original value
+        if not was_shmem_env_true:
+            os.environ.update(
+                {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '0'})
+
+    def test_is_disabled(self):
+        """
+        Verify that when the AppSetting is disabled, SharedMemoryManager is
+        disabled.
+        """
+        # Make sure shared memory data transfer is disabled
+        was_shmem_env_true = is_envvar_true(
+            FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED)
+        os.environ.update(
+            {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '0'})
+        manager = SharedMemoryManager()
+        self.assertFalse(manager.is_enabled())
+        # Restore the env variable to original value
+        if was_shmem_env_true:
+            os.environ.update(
+                {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '1'})
+
+    def test_bytes_input_support(self):
+        """
+        Verify that the given input is supported by SharedMemoryManager to be
+        transferred over shared memory.
+        The input is bytes.
+ """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + content = self.get_random_bytes(content_size) + bytes_datum = bind_meta.Datum(type='bytes', value=content) + is_supported = manager.is_supported(bytes_datum) + self.assertTrue(is_supported) + + def test_string_input_support(self): + """ + Verify that the given input is supported by SharedMemoryManager to be + transfered over shared memory. + The input is string. + """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) + content = self.get_random_string(num_chars) + bytes_datum = bind_meta.Datum(type='string', value=content) + is_supported = manager.is_supported(bytes_datum) + self.assertTrue(is_supported) + + def test_int_input_unsupported(self): + """ + Verify that the given input is unsupported by SharedMemoryManager. + This input is int. + """ + manager = SharedMemoryManager() + datum = bind_meta.Datum(type='int', value=1) + is_supported = manager.is_supported(datum) + self.assertFalse(is_supported) + + def test_double_input_unsupported(self): + """ + Verify that the given input is unsupported by SharedMemoryManager. + This input is double. + """ + manager = SharedMemoryManager() + datum = bind_meta.Datum(type='double', value=1.0) + is_supported = manager.is_supported(datum) + self.assertFalse(is_supported) + + def test_json_input_unsupported(self): + """ + Verify that the given input is unsupported by SharedMemoryManager. + This input is json. + """ + manager = SharedMemoryManager() + content = { + 'name': 'foo', + 'val': 'bar' + } + datum = bind_meta.Datum(type='json', value=json.dumps(content)) + is_supported = manager.is_supported(datum) + self.assertFalse(is_supported) + + def test_collection_string_unsupported(self): + """ + Verify that the given input is unsupported by SharedMemoryManager. + This input is collection_string. 
+ """ + manager = SharedMemoryManager() + content = ['foo', 'bar'] + datum = bind_meta.Datum(type='collection_string', value=content) + is_supported = manager.is_supported(datum) + self.assertFalse(is_supported) + + def test_collection_bytes_unsupported(self): + """ + Verify that the given input is unsupported by SharedMemoryManager. + This input is collection_bytes. + """ + manager = SharedMemoryManager() + content = [b'x01', b'x02'] + datum = bind_meta.Datum(type='collection_bytes', value=content) + is_supported = manager.is_supported(datum) + self.assertFalse(is_supported) + + def test_collection_double_unsupported(self): + """ + Verify that the given input is unsupported by SharedMemoryManager. + This input is collection_double. + """ + manager = SharedMemoryManager() + content = [1.0, 2.0] + datum = bind_meta.Datum(type='collection_double', value=content) + is_supported = manager.is_supported(datum) + self.assertFalse(is_supported) + + def test_collection_sint64_unsupported(self): + """ + Verify that the given input is unsupported by SharedMemoryManager. + This input is collection_sint64. + """ + manager = SharedMemoryManager() + content = [1, 2] + datum = bind_meta.Datum(type='collection_sint64', value=content) + is_supported = manager.is_supported(datum) + self.assertFalse(is_supported) + + def test_large_invalid_bytes_input_support(self): + """ + Verify that the given input is NOT supported by SharedMemoryManager to + be transfered over shared memory. + The input is bytes of larger than the allowed size. 
+ """ + manager = SharedMemoryManager() + content_size = consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + # Not using get_random_bytes to avoid slowing down for creating a large + # random input + content = b'x01' * content_size + bytes_datum = bind_meta.Datum(type='bytes', value=content) + is_supported = manager.is_supported(bytes_datum) + self.assertFalse(is_supported) + + def test_small_invalid_bytes_input_support(self): + """ + Verify that the given input is NOT supported by SharedMemoryManager to + be transfered over shared memory. + The input is bytes of smaller than the allowed size. + """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER - 10 + content = self.get_random_bytes(content_size) + bytes_datum = bind_meta.Datum(type='bytes', value=content) + is_supported = manager.is_supported(bytes_datum) + self.assertFalse(is_supported) + + def test_large_invalid_string_input_support(self): + """ + Verify that the given input is NOT supported by SharedMemoryManager to + be transfered over shared memory. + The input is string of larger than the allowed size. + """ + manager = SharedMemoryManager() + content_size = consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) + # Not using get_random_string to avoid slowing down for creating a large + # random input + content = 'a' * num_chars + string_datum = bind_meta.Datum(type='string', value=content) + is_supported = manager.is_supported(string_datum) + self.assertFalse(is_supported) + + def test_small_invalid_string_input_support(self): + """ + Verify that the given input is NOT supported by SharedMemoryManager to + be transfered over shared memory. + The input is string of smaller than the allowed size. 
+ """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER - 10 + num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) + content = self.get_random_string(num_chars) + string_datum = bind_meta.Datum(type='string', value=content) + is_supported = manager.is_supported(string_datum) + self.assertFalse(is_supported) + + def test_put_bytes(self): + """ + Verify that the given input was successfully put into shared memory. + The input is bytes. + """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + content = self.get_random_bytes(content_size) + shared_mem_meta = manager.put_bytes(content) + self.assertIsNotNone(shared_mem_meta) + self.assertTrue(self.is_valid_uuid(shared_mem_meta.mem_map_name)) + self.assertEqual(content_size, shared_mem_meta.count_bytes) + free_success = manager.free_mem_map(shared_mem_meta.mem_map_name) + self.assertTrue(free_success) + + def test_invalid_put_bytes(self): + """ + Attempt to put bytes using an invalid input and verify that it fails. + """ + manager = SharedMemoryManager() + shared_mem_meta = manager.put_bytes(None) + self.assertIsNone(shared_mem_meta) + + def test_get_bytes(self): + """ + Verify that the output object was successfully gotten from shared + memory. + The output is bytes. + """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + content = self.get_random_bytes(content_size) + shared_mem_meta = manager.put_bytes(content) + mem_map_name = shared_mem_meta.mem_map_name + num_bytes_written = shared_mem_meta.count_bytes + read_content = manager.get_bytes(mem_map_name, offset=0, + count=num_bytes_written) + self.assertEqual(content, read_content) + free_success = manager.free_mem_map(mem_map_name) + self.assertTrue(free_success) + + def test_put_string(self): + """ + Verify that the given input was successfully put into shared memory. + The input is string. 
+ """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) + content = self.get_random_string(num_chars) + expected_size = len(content.encode('utf-8')) + shared_mem_meta = manager.put_string(content) + self.assertIsNotNone(shared_mem_meta) + self.assertTrue(self.is_valid_uuid(shared_mem_meta.mem_map_name)) + self.assertEqual(expected_size, shared_mem_meta.count_bytes) + free_success = manager.free_mem_map(shared_mem_meta.mem_map_name) + self.assertTrue(free_success) + + def test_invalid_put_string(self): + """ + Attempt to put a string using an invalid input and verify that it fails. + """ + manager = SharedMemoryManager() + shared_mem_meta = manager.put_string(None) + self.assertIsNone(shared_mem_meta) + + def test_get_string(self): + """ + Verify that the output object was successfully gotten from shared + memory. + The output is string. + """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) + content = self.get_random_string(num_chars) + shared_mem_meta = manager.put_string(content) + mem_map_name = shared_mem_meta.mem_map_name + num_bytes_written = shared_mem_meta.count_bytes + read_content = manager.get_string(mem_map_name, offset=0, + count=num_bytes_written) + self.assertEqual(content, read_content) + free_success = manager.free_mem_map(mem_map_name) + self.assertTrue(free_success) + + def test_allocated_mem_maps(self): + """ + Verify that the SharedMemoryManager is tracking the shared memory maps + it has allocated after put operations. + Verify that those shared memory maps are freed and no longer tracked + after attempting to free them. 
+ """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + content = self.get_random_bytes(content_size) + shared_mem_meta = manager.put_bytes(content) + self.assertIsNotNone(shared_mem_meta) + mem_map_name = shared_mem_meta.mem_map_name + is_mem_map_found = mem_map_name in manager.allocated_mem_maps + self.assertTrue(is_mem_map_found) + self.assertEqual(1, len(manager.allocated_mem_maps.keys())) + free_success = manager.free_mem_map(mem_map_name) + self.assertTrue(free_success) + is_mem_map_found = mem_map_name in manager.allocated_mem_maps + self.assertFalse(is_mem_map_found) + self.assertEqual(0, len(manager.allocated_mem_maps.keys())) + + def test_do_not_free_resources_on_dispose(self): + """ + Verify that when the allocated shared memory maps are freed, + their backing resources are not freed. + Note: The shared memory map should no longer be tracked by the + SharedMemoryManager, though. + """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + content = self.get_random_bytes(content_size) + shared_mem_meta = manager.put_bytes(content) + self.assertIsNotNone(shared_mem_meta) + mem_map_name = shared_mem_meta.mem_map_name + is_mem_map_found = mem_map_name in manager.allocated_mem_maps + self.assertTrue(is_mem_map_found) + self.assertEqual(1, len(manager.allocated_mem_maps.keys())) + free_success = manager.free_mem_map(mem_map_name, False) + self.assertTrue(free_success) + is_mem_map_found = mem_map_name in manager.allocated_mem_maps + self.assertFalse(is_mem_map_found) + self.assertEqual(0, len(manager.allocated_mem_maps.keys())) + + def test_invalid_put_allocated_mem_maps(self): + """ + Verify that after an invalid put operation, no shared memory maps were + added to the list of allocated/tracked shared memory maps. + i.e. no resources were leaked for invalid operations. 
+ """ + manager = SharedMemoryManager() + shared_mem_meta = manager.put_bytes(None) + self.assertIsNone(shared_mem_meta) + self.assertEqual(0, len(manager.allocated_mem_maps.keys())) + + def test_invalid_free_mem_map(self): + """ + Attempt to free a shared memory map that does not exist in the list of + allocated/tracked shared memory maps and verify that it fails. + """ + manager = SharedMemoryManager() + mem_map_name = self.get_new_mem_map_name() + free_success = manager.free_mem_map(mem_map_name) + self.assertFalse(free_success) From 45265fd0f293d84312dd23fb89cb6f748b689895 Mon Sep 17 00:00:00 2001 From: Gohar Irfan Chaudhry Date: Mon, 20 Sep 2021 08:32:20 -0700 Subject: [PATCH 06/11] Undo bad file changes --- .../shared_memory_manager.py | 402 +++++----- tests/unittests/test_shared_memory_manager.py | 750 +++++++++--------- 2 files changed, 584 insertions(+), 568 deletions(-) diff --git a/azure_functions_worker/bindings/shared_memory_data_transfer/shared_memory_manager.py b/azure_functions_worker/bindings/shared_memory_data_transfer/shared_memory_manager.py index c8efc55d9..b5fc1ed03 100644 --- a/azure_functions_worker/bindings/shared_memory_data_transfer/shared_memory_manager.py +++ b/azure_functions_worker/bindings/shared_memory_data_transfer/shared_memory_manager.py @@ -1,201 +1,201 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -import uuid -from typing import Dict, Optional -from .shared_memory_constants import SharedMemoryConstants as consts -from .file_accessor_factory import FileAccessorFactory -from .shared_memory_metadata import SharedMemoryMetadata -from .shared_memory_map import SharedMemoryMap -from ..datumdef import Datum -from ...logging import logger -from ...utils.common import is_envvar_true -from ...constants import FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED - - -class SharedMemoryManager: - """ - Performs all operations related to reading/writing data from/to shared - memory. 
- This is used for transferring input/output data of the function from/to the - functions host over shared memory as opposed to RPC to improve the rate of - data transfer and the function's end-to-end latency. - """ - def __init__(self): - # The allocated memory maps are tracked here so that a reference to them - # is kept open until they have been used (e.g. if they contain a - # function's output, it is read by the functions host). - # Having a mapping of the name and the memory map is then later used to - # close a given memory map by its name, after it has been used. - # key: mem_map_name, val: SharedMemoryMap - self._allocated_mem_maps: Dict[str, SharedMemoryMap] = {} - self._file_accessor = FileAccessorFactory.create_file_accessor() - - def __del__(self): - del self._file_accessor - del self._allocated_mem_maps - - @property - def allocated_mem_maps(self): - """ - List of allocated shared memory maps. - """ - return self._allocated_mem_maps - - @property - def file_accessor(self): - """ - FileAccessor instance for accessing memory maps. - """ - return self._file_accessor - - def is_enabled(self) -> bool: - """ - Whether supported types should be transferred between functions host and - the worker using shared memory. - """ - return is_envvar_true( - FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED) - - def is_supported(self, datum: Datum) -> bool: - """ - Whether the given Datum object can be transferred to the functions host - using shared memory. 
- This logic is kept consistent with the host's which can be found in - SharedMemoryManager.cs - """ - if datum.type == 'bytes': - num_bytes = len(datum.value) - if num_bytes >= consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER and \ - num_bytes <= consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER: - return True - elif datum.type == 'string': - num_bytes = len(datum.value) * consts.SIZE_OF_CHAR_BYTES - if num_bytes >= consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER and \ - num_bytes <= consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER: - return True - return False - - def put_bytes(self, content: bytes) -> Optional[SharedMemoryMetadata]: - """ - Writes the given bytes into shared memory. - Returns metadata about the shared memory region to which the content was - written if successful, None otherwise. - """ - if content is None: - return None - mem_map_name = str(uuid.uuid4()) - content_length = len(content) - shared_mem_map = self._create(mem_map_name, content_length) - if shared_mem_map is None: - return None - try: - num_bytes_written = shared_mem_map.put_bytes(content) - except Exception as e: - logger.warning(f'Cannot write {content_length} bytes into shared ' - f'memory {mem_map_name} - {e}') - shared_mem_map.dispose() - return None - if num_bytes_written != content_length: - logger.error( - f'Cannot write data into shared memory {mem_map_name} ' - f'({num_bytes_written} != {content_length})') - shared_mem_map.dispose() - return None - self.allocated_mem_maps[mem_map_name] = shared_mem_map - return SharedMemoryMetadata(mem_map_name, content_length) - - def put_string(self, content: str) -> Optional[SharedMemoryMetadata]: - """ - Writes the given string into shared memory. - Returns the name of the memory map into which the data was written if - succesful, None otherwise. - Note: The encoding used here must be consistent with what is used by the - host in SharedMemoryManager.cs (GetStringAsync/PutStringAsync). 
- """ - if content is None: - return None - content_bytes = content.encode('utf-8') - return self.put_bytes(content_bytes) - - def get_bytes(self, mem_map_name: str, offset: int, count: int) \ - -> Optional[bytes]: - """ - Reads data from the given memory map with the provided name, starting at - the provided offset and reading a total of count bytes. - Returns the data read from shared memory as bytes if successful, None - otherwise. - """ - if offset != 0: - logger.error( - f'Cannot read bytes. Non-zero offset ({offset}) ' - f'not supported.') - return None - shared_mem_map = self._open(mem_map_name, count) - if shared_mem_map is None: - return None - try: - content = shared_mem_map.get_bytes(content_offset=0, - bytes_to_read=count) - finally: - shared_mem_map.dispose(is_delete_file=False) - return content - - def get_string(self, mem_map_name: str, offset: int, count: int) \ - -> Optional[str]: - """ - Reads data from the given memory map with the provided name, starting at - the provided offset and reading a total of count bytes. - Returns the data read from shared memory as a string if successful, None - otherwise. - Note: The encoding used here must be consistent with what is used by the - host in SharedMemoryManager.cs (GetStringAsync/PutStringAsync). - """ - content_bytes = self.get_bytes(mem_map_name, offset, count) - if content_bytes is None: - return None - content_str = content_bytes.decode('utf-8') - return content_str - - def free_mem_map(self, mem_map_name: str, - to_delete_backing_resources: bool = True) -> bool: - """ - Frees the memory map and, if specified, any backing resources (e.g. - file in the case of Unix) associated with it. - If there is no memory map with the given name being tracked, then no - action is performed. - Returns True if the memory map was freed successfully, False otherwise. 
- """ - if mem_map_name not in self.allocated_mem_maps: - logger.error( - f'Cannot find memory map in list of allocations {mem_map_name}') - return False - shared_mem_map = self.allocated_mem_maps[mem_map_name] - success = shared_mem_map.dispose(to_delete_backing_resources) - del self.allocated_mem_maps[mem_map_name] - return success - - def _create(self, mem_map_name: str, content_length: int) \ - -> Optional[SharedMemoryMap]: - """ - Creates a new SharedMemoryMap with the given name and content length. - Returns the SharedMemoryMap object if successful, None otherwise. - """ - mem_map_size = consts.CONTENT_HEADER_TOTAL_BYTES + content_length - mem_map = self.file_accessor.create_mem_map(mem_map_name, mem_map_size) - if mem_map is None: - return None - return SharedMemoryMap(self.file_accessor, mem_map_name, mem_map) - - def _open(self, mem_map_name: str, content_length: int) \ - -> Optional[SharedMemoryMap]: - """ - Opens an existing SharedMemoryMap with the given name and content - length. - Returns the SharedMemoryMap object if successful, None otherwise. - """ - mem_map_size = consts.CONTENT_HEADER_TOTAL_BYTES + content_length - mem_map = self.file_accessor.open_mem_map(mem_map_name, mem_map_size) - if mem_map is None: - return None - return SharedMemoryMap(self.file_accessor, mem_map_name, mem_map) +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import uuid +from typing import Dict, Optional +from .shared_memory_constants import SharedMemoryConstants as consts +from .file_accessor_factory import FileAccessorFactory +from .shared_memory_metadata import SharedMemoryMetadata +from .shared_memory_map import SharedMemoryMap +from ..datumdef import Datum +from ...logging import logger +from ...utils.common import is_envvar_true +from ...constants import FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED + + +class SharedMemoryManager: + """ + Performs all operations related to reading/writing data from/to shared + memory. + This is used for transferring input/output data of the function from/to the + functions host over shared memory as opposed to RPC to improve the rate of + data transfer and the function's end-to-end latency. + """ + def __init__(self): + # The allocated memory maps are tracked here so that a reference to them + # is kept open until they have been used (e.g. if they contain a + # function's output, it is read by the functions host). + # Having a mapping of the name and the memory map is then later used to + # close a given memory map by its name, after it has been used. + # key: mem_map_name, val: SharedMemoryMap + self._allocated_mem_maps: Dict[str, SharedMemoryMap] = {} + self._file_accessor = FileAccessorFactory.create_file_accessor() + + def __del__(self): + del self._file_accessor + del self._allocated_mem_maps + + @property + def allocated_mem_maps(self): + """ + List of allocated shared memory maps. + """ + return self._allocated_mem_maps + + @property + def file_accessor(self): + """ + FileAccessor instance for accessing memory maps. + """ + return self._file_accessor + + def is_enabled(self) -> bool: + """ + Whether supported types should be transferred between functions host and + the worker using shared memory. 
+ """ + return is_envvar_true( + FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED) + + def is_supported(self, datum: Datum) -> bool: + """ + Whether the given Datum object can be transferred to the functions host + using shared memory. + This logic is kept consistent with the host's which can be found in + SharedMemoryManager.cs + """ + if datum.type == 'bytes': + num_bytes = len(datum.value) + if num_bytes >= consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER and \ + num_bytes <= consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER: + return True + elif datum.type == 'string': + num_bytes = len(datum.value) * consts.SIZE_OF_CHAR_BYTES + if num_bytes >= consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER and \ + num_bytes <= consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER: + return True + return False + + def put_bytes(self, content: bytes) -> Optional[SharedMemoryMetadata]: + """ + Writes the given bytes into shared memory. + Returns metadata about the shared memory region to which the content was + written if successful, None otherwise. + """ + if content is None: + return None + mem_map_name = str(uuid.uuid4()) + content_length = len(content) + shared_mem_map = self._create(mem_map_name, content_length) + if shared_mem_map is None: + return None + try: + num_bytes_written = shared_mem_map.put_bytes(content) + except Exception as e: + logger.warning(f'Cannot write {content_length} bytes into shared ' + f'memory {mem_map_name} - {e}') + shared_mem_map.dispose() + return None + if num_bytes_written != content_length: + logger.error( + f'Cannot write data into shared memory {mem_map_name} ' + f'({num_bytes_written} != {content_length})') + shared_mem_map.dispose() + return None + self.allocated_mem_maps[mem_map_name] = shared_mem_map + return SharedMemoryMetadata(mem_map_name, content_length) + + def put_string(self, content: str) -> Optional[SharedMemoryMetadata]: + """ + Writes the given string into shared memory. 
+ Returns the name of the memory map into which the data was written if + succesful, None otherwise. + Note: The encoding used here must be consistent with what is used by the + host in SharedMemoryManager.cs (GetStringAsync/PutStringAsync). + """ + if content is None: + return None + content_bytes = content.encode('utf-8') + return self.put_bytes(content_bytes) + + def get_bytes(self, mem_map_name: str, offset: int, count: int) \ + -> Optional[bytes]: + """ + Reads data from the given memory map with the provided name, starting at + the provided offset and reading a total of count bytes. + Returns the data read from shared memory as bytes if successful, None + otherwise. + """ + if offset != 0: + logger.error( + f'Cannot read bytes. Non-zero offset ({offset}) ' + f'not supported.') + return None + shared_mem_map = self._open(mem_map_name, count) + if shared_mem_map is None: + return None + try: + content = shared_mem_map.get_bytes(content_offset=0, + bytes_to_read=count) + finally: + shared_mem_map.dispose(is_delete_file=False) + return content + + def get_string(self, mem_map_name: str, offset: int, count: int) \ + -> Optional[str]: + """ + Reads data from the given memory map with the provided name, starting at + the provided offset and reading a total of count bytes. + Returns the data read from shared memory as a string if successful, None + otherwise. + Note: The encoding used here must be consistent with what is used by the + host in SharedMemoryManager.cs (GetStringAsync/PutStringAsync). + """ + content_bytes = self.get_bytes(mem_map_name, offset, count) + if content_bytes is None: + return None + content_str = content_bytes.decode('utf-8') + return content_str + + def free_mem_map(self, mem_map_name: str, + to_delete_backing_resources: bool = True) -> bool: + """ + Frees the memory map and, if specified, any backing resources (e.g. + file in the case of Unix) associated with it. 
+ If there is no memory map with the given name being tracked, then no + action is performed. + Returns True if the memory map was freed successfully, False otherwise. + """ + if mem_map_name not in self.allocated_mem_maps: + logger.error( + f'Cannot find memory map in list of allocations {mem_map_name}') + return False + shared_mem_map = self.allocated_mem_maps[mem_map_name] + success = shared_mem_map.dispose(to_delete_backing_resources) + del self.allocated_mem_maps[mem_map_name] + return success + + def _create(self, mem_map_name: str, content_length: int) \ + -> Optional[SharedMemoryMap]: + """ + Creates a new SharedMemoryMap with the given name and content length. + Returns the SharedMemoryMap object if successful, None otherwise. + """ + mem_map_size = consts.CONTENT_HEADER_TOTAL_BYTES + content_length + mem_map = self.file_accessor.create_mem_map(mem_map_name, mem_map_size) + if mem_map is None: + return None + return SharedMemoryMap(self.file_accessor, mem_map_name, mem_map) + + def _open(self, mem_map_name: str, content_length: int) \ + -> Optional[SharedMemoryMap]: + """ + Opens an existing SharedMemoryMap with the given name and content + length. + Returns the SharedMemoryMap object if successful, None otherwise. + """ + mem_map_size = consts.CONTENT_HEADER_TOTAL_BYTES + content_length + mem_map = self.file_accessor.open_mem_map(mem_map_name, mem_map_size) + if mem_map is None: + return None + return SharedMemoryMap(self.file_accessor, mem_map_name, mem_map) diff --git a/tests/unittests/test_shared_memory_manager.py b/tests/unittests/test_shared_memory_manager.py index 1c5cf994c..0cdb7c234 100644 --- a/tests/unittests/test_shared_memory_manager.py +++ b/tests/unittests/test_shared_memory_manager.py @@ -1,367 +1,383 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
- -import math -import os -import json -from azure_functions_worker.utils.common import is_envvar_true -from azure.functions import meta as bind_meta -from azure_functions_worker import testutils -from azure_functions_worker.bindings.shared_memory_data_transfer \ - import SharedMemoryManager -from azure_functions_worker.bindings.shared_memory_data_transfer \ - import SharedMemoryConstants as consts -from azure_functions_worker.constants \ - import FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED - - -class TestSharedMemoryManager(testutils.SharedMemoryTestCase): - """ - Tests for SharedMemoryManager. - """ - def test_is_enabled(self): - """ - Verify that when the AppSetting is enabled, SharedMemoryManager is - enabled. - """ - # Make sure shared memory data transfer is enabled - was_shmem_env_true = is_envvar_true( - FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED) - os.environ.update( - {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '1'}) - manager = SharedMemoryManager() - self.assertTrue(manager.is_enabled()) - # Restore the env variable to original value - if not was_shmem_env_true: - os.environ.update( - {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '0'}) - - def test_is_disabled(self): - """ - Verify that when the AppSetting is disabled, SharedMemoryManager is - disabled. - """ - # Make sure shared memory data transfer is disabled - was_shmem_env_true = is_envvar_true( - FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED) - os.environ.update( - {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '0'}) - manager = SharedMemoryManager() - self.assertFalse(manager.is_enabled()) - # Restore the env variable to original value - if was_shmem_env_true: - os.environ.update( - {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '1'}) - - def test_bytes_input_support(self): - """ - Verify that the given input is supported by SharedMemoryManager to be - transfered over shared memory. - The input is bytes. 
- """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - content = self.get_random_bytes(content_size) - bytes_datum = bind_meta.Datum(type='bytes', value=content) - is_supported = manager.is_supported(bytes_datum) - self.assertTrue(is_supported) - - def test_string_input_support(self): - """ - Verify that the given input is supported by SharedMemoryManager to be - transfered over shared memory. - The input is string. - """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) - content = self.get_random_string(num_chars) - bytes_datum = bind_meta.Datum(type='string', value=content) - is_supported = manager.is_supported(bytes_datum) - self.assertTrue(is_supported) - - def test_int_input_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is int. - """ - manager = SharedMemoryManager() - datum = bind_meta.Datum(type='int', value=1) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_double_input_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is double. - """ - manager = SharedMemoryManager() - datum = bind_meta.Datum(type='double', value=1.0) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_json_input_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is json. - """ - manager = SharedMemoryManager() - content = { - 'name': 'foo', - 'val': 'bar' - } - datum = bind_meta.Datum(type='json', value=json.dumps(content)) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_collection_string_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is collection_string. 
- """ - manager = SharedMemoryManager() - content = ['foo', 'bar'] - datum = bind_meta.Datum(type='collection_string', value=content) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_collection_bytes_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is collection_bytes. - """ - manager = SharedMemoryManager() - content = [b'x01', b'x02'] - datum = bind_meta.Datum(type='collection_bytes', value=content) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_collection_double_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is collection_double. - """ - manager = SharedMemoryManager() - content = [1.0, 2.0] - datum = bind_meta.Datum(type='collection_double', value=content) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_collection_sint64_unsupported(self): - """ - Verify that the given input is unsupported by SharedMemoryManager. - This input is collection_sint64. - """ - manager = SharedMemoryManager() - content = [1, 2] - datum = bind_meta.Datum(type='collection_sint64', value=content) - is_supported = manager.is_supported(datum) - self.assertFalse(is_supported) - - def test_large_invalid_bytes_input_support(self): - """ - Verify that the given input is NOT supported by SharedMemoryManager to - be transfered over shared memory. - The input is bytes of larger than the allowed size. 
- """ - manager = SharedMemoryManager() - content_size = consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - # Not using get_random_bytes to avoid slowing down for creating a large - # random input - content = b'x01' * content_size - bytes_datum = bind_meta.Datum(type='bytes', value=content) - is_supported = manager.is_supported(bytes_datum) - self.assertFalse(is_supported) - - def test_small_invalid_bytes_input_support(self): - """ - Verify that the given input is NOT supported by SharedMemoryManager to - be transfered over shared memory. - The input is bytes of smaller than the allowed size. - """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER - 10 - content = self.get_random_bytes(content_size) - bytes_datum = bind_meta.Datum(type='bytes', value=content) - is_supported = manager.is_supported(bytes_datum) - self.assertFalse(is_supported) - - def test_large_invalid_string_input_support(self): - """ - Verify that the given input is NOT supported by SharedMemoryManager to - be transfered over shared memory. - The input is string of larger than the allowed size. - """ - manager = SharedMemoryManager() - content_size = consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) - # Not using get_random_string to avoid slowing down for creating a large - # random input - content = 'a' * num_chars - string_datum = bind_meta.Datum(type='string', value=content) - is_supported = manager.is_supported(string_datum) - self.assertFalse(is_supported) - - def test_small_invalid_string_input_support(self): - """ - Verify that the given input is NOT supported by SharedMemoryManager to - be transfered over shared memory. - The input is string of smaller than the allowed size. 
- """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER - 10 - num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) - content = self.get_random_string(num_chars) - string_datum = bind_meta.Datum(type='string', value=content) - is_supported = manager.is_supported(string_datum) - self.assertFalse(is_supported) - - def test_put_bytes(self): - """ - Verify that the given input was successfully put into shared memory. - The input is bytes. - """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - content = self.get_random_bytes(content_size) - shared_mem_meta = manager.put_bytes(content) - self.assertIsNotNone(shared_mem_meta) - self.assertTrue(self.is_valid_uuid(shared_mem_meta.mem_map_name)) - self.assertEqual(content_size, shared_mem_meta.count_bytes) - free_success = manager.free_mem_map(shared_mem_meta.mem_map_name) - self.assertTrue(free_success) - - def test_invalid_put_bytes(self): - """ - Attempt to put bytes using an invalid input and verify that it fails. - """ - manager = SharedMemoryManager() - shared_mem_meta = manager.put_bytes(None) - self.assertIsNone(shared_mem_meta) - - def test_get_bytes(self): - """ - Verify that the output object was successfully gotten from shared - memory. - The output is bytes. - """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - content = self.get_random_bytes(content_size) - shared_mem_meta = manager.put_bytes(content) - mem_map_name = shared_mem_meta.mem_map_name - num_bytes_written = shared_mem_meta.count_bytes - read_content = manager.get_bytes(mem_map_name, offset=0, - count=num_bytes_written) - self.assertEqual(content, read_content) - free_success = manager.free_mem_map(mem_map_name) - self.assertTrue(free_success) - - def test_put_string(self): - """ - Verify that the given input was successfully put into shared memory. - The input is string. 
- """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) - content = self.get_random_string(num_chars) - expected_size = len(content.encode('utf-8')) - shared_mem_meta = manager.put_string(content) - self.assertIsNotNone(shared_mem_meta) - self.assertTrue(self.is_valid_uuid(shared_mem_meta.mem_map_name)) - self.assertEqual(expected_size, shared_mem_meta.count_bytes) - free_success = manager.free_mem_map(shared_mem_meta.mem_map_name) - self.assertTrue(free_success) - - def test_invalid_put_string(self): - """ - Attempt to put a string using an invalid input and verify that it fails. - """ - manager = SharedMemoryManager() - shared_mem_meta = manager.put_string(None) - self.assertIsNone(shared_mem_meta) - - def test_get_string(self): - """ - Verify that the output object was successfully gotten from shared - memory. - The output is string. - """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) - content = self.get_random_string(num_chars) - shared_mem_meta = manager.put_string(content) - mem_map_name = shared_mem_meta.mem_map_name - num_bytes_written = shared_mem_meta.count_bytes - read_content = manager.get_string(mem_map_name, offset=0, - count=num_bytes_written) - self.assertEqual(content, read_content) - free_success = manager.free_mem_map(mem_map_name) - self.assertTrue(free_success) - - def test_allocated_mem_maps(self): - """ - Verify that the SharedMemoryManager is tracking the shared memory maps - it has allocated after put operations. - Verify that those shared memory maps are freed and no longer tracked - after attempting to free them. 
- """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - content = self.get_random_bytes(content_size) - shared_mem_meta = manager.put_bytes(content) - self.assertIsNotNone(shared_mem_meta) - mem_map_name = shared_mem_meta.mem_map_name - is_mem_map_found = mem_map_name in manager.allocated_mem_maps - self.assertTrue(is_mem_map_found) - self.assertEqual(1, len(manager.allocated_mem_maps.keys())) - free_success = manager.free_mem_map(mem_map_name) - self.assertTrue(free_success) - is_mem_map_found = mem_map_name in manager.allocated_mem_maps - self.assertFalse(is_mem_map_found) - self.assertEqual(0, len(manager.allocated_mem_maps.keys())) - - def test_do_not_free_resources_on_dispose(self): - """ - Verify that when the allocated shared memory maps are freed, - their backing resources are not freed. - Note: The shared memory map should no longer be tracked by the - SharedMemoryManager, though. - """ - manager = SharedMemoryManager() - content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 - content = self.get_random_bytes(content_size) - shared_mem_meta = manager.put_bytes(content) - self.assertIsNotNone(shared_mem_meta) - mem_map_name = shared_mem_meta.mem_map_name - is_mem_map_found = mem_map_name in manager.allocated_mem_maps - self.assertTrue(is_mem_map_found) - self.assertEqual(1, len(manager.allocated_mem_maps.keys())) - free_success = manager.free_mem_map(mem_map_name, False) - self.assertTrue(free_success) - is_mem_map_found = mem_map_name in manager.allocated_mem_maps - self.assertFalse(is_mem_map_found) - self.assertEqual(0, len(manager.allocated_mem_maps.keys())) - - def test_invalid_put_allocated_mem_maps(self): - """ - Verify that after an invalid put operation, no shared memory maps were - added to the list of allocated/tracked shared memory maps. - i.e. no resources were leaked for invalid operations. 
- """ - manager = SharedMemoryManager() - shared_mem_meta = manager.put_bytes(None) - self.assertIsNone(shared_mem_meta) - self.assertEqual(0, len(manager.allocated_mem_maps.keys())) - - def test_invalid_free_mem_map(self): - """ - Attempt to free a shared memory map that does not exist in the list of - allocated/tracked shared memory maps and verify that it fails. - """ - manager = SharedMemoryManager() - mem_map_name = self.get_new_mem_map_name() - free_success = manager.free_mem_map(mem_map_name) - self.assertFalse(free_success) +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import math +import os +import json +import sys +from unittest.mock import patch +from azure_functions_worker.utils.common import is_envvar_true +from azure.functions import meta as bind_meta +from azure_functions_worker import testutils +from azure_functions_worker.bindings.shared_memory_data_transfer \ + import SharedMemoryManager +from azure_functions_worker.bindings.shared_memory_data_transfer \ + import SharedMemoryConstants as consts +from azure_functions_worker.constants \ + import FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED + + +class TestSharedMemoryManager(testutils.SharedMemoryTestCase): + """ + Tests for SharedMemoryManager. + """ + def setUp(self): + self.mock_environ = patch.dict('os.environ', os.environ.copy()) + self.mock_sys_module = patch.dict('sys.modules', sys.modules.copy()) + self.mock_sys_path = patch('sys.path', sys.path.copy()) + self.mock_environ.start() + self.mock_sys_module.start() + self.mock_sys_path.start() + + def tearDown(self): + self.mock_sys_path.stop() + self.mock_sys_module.stop() + self.mock_environ.stop() + + def test_is_enabled(self): + """ + Verify that when the AppSetting is enabled, SharedMemoryManager is + enabled. 
+ """
+
+ # Make sure shared memory data transfer is enabled
+ was_shmem_env_true = is_envvar_true(
+ FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED)
+ os.environ.update(
+ {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '1'})
+ manager = SharedMemoryManager()
+ self.assertTrue(manager.is_enabled())
+ # Restore the env variable to original value
+ if not was_shmem_env_true:
+ os.environ.update(
+ {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '0'})
+
+ def test_is_disabled(self):
+ """
+ Verify that when the AppSetting is disabled, SharedMemoryManager is
+ disabled.
+ """
+ # Make sure shared memory data transfer is disabled
+ was_shmem_env_true = is_envvar_true(
+ FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED)
+ os.environ.update(
+ {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '0'})
+ manager = SharedMemoryManager()
+ self.assertFalse(manager.is_enabled())
+ # Restore the env variable to original value
+ if was_shmem_env_true:
+ os.environ.update(
+ {FUNCTIONS_WORKER_SHARED_MEMORY_DATA_TRANSFER_ENABLED: '1'})
+
+ def test_bytes_input_support(self):
+ """
+ Verify that the given input is supported by SharedMemoryManager to be
+ transferred over shared memory.
+ The input is bytes.
+ """
+ manager = SharedMemoryManager()
+ content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10
+ content = self.get_random_bytes(content_size)
+ bytes_datum = bind_meta.Datum(type='bytes', value=content)
+ is_supported = manager.is_supported(bytes_datum)
+ self.assertTrue(is_supported)
+
+ def test_string_input_support(self):
+ """
+ Verify that the given input is supported by SharedMemoryManager to be
+ transferred over shared memory.
+ The input is string. 
+ """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) + content = self.get_random_string(num_chars) + bytes_datum = bind_meta.Datum(type='string', value=content) + is_supported = manager.is_supported(bytes_datum) + self.assertTrue(is_supported) + + def test_int_input_unsupported(self): + """ + Verify that the given input is unsupported by SharedMemoryManager. + This input is int. + """ + manager = SharedMemoryManager() + datum = bind_meta.Datum(type='int', value=1) + is_supported = manager.is_supported(datum) + self.assertFalse(is_supported) + + def test_double_input_unsupported(self): + """ + Verify that the given input is unsupported by SharedMemoryManager. + This input is double. + """ + manager = SharedMemoryManager() + datum = bind_meta.Datum(type='double', value=1.0) + is_supported = manager.is_supported(datum) + self.assertFalse(is_supported) + + def test_json_input_unsupported(self): + """ + Verify that the given input is unsupported by SharedMemoryManager. + This input is json. + """ + manager = SharedMemoryManager() + content = { + 'name': 'foo', + 'val': 'bar' + } + datum = bind_meta.Datum(type='json', value=json.dumps(content)) + is_supported = manager.is_supported(datum) + self.assertFalse(is_supported) + + def test_collection_string_unsupported(self): + """ + Verify that the given input is unsupported by SharedMemoryManager. + This input is collection_string. + """ + manager = SharedMemoryManager() + content = ['foo', 'bar'] + datum = bind_meta.Datum(type='collection_string', value=content) + is_supported = manager.is_supported(datum) + self.assertFalse(is_supported) + + def test_collection_bytes_unsupported(self): + """ + Verify that the given input is unsupported by SharedMemoryManager. + This input is collection_bytes. 
+ """
+ manager = SharedMemoryManager()
+ content = [b'x01', b'x02']
+ datum = bind_meta.Datum(type='collection_bytes', value=content)
+ is_supported = manager.is_supported(datum)
+ self.assertFalse(is_supported)
+
+ def test_collection_double_unsupported(self):
+ """
+ Verify that the given input is unsupported by SharedMemoryManager.
+ This input is collection_double.
+ """
+ manager = SharedMemoryManager()
+ content = [1.0, 2.0]
+ datum = bind_meta.Datum(type='collection_double', value=content)
+ is_supported = manager.is_supported(datum)
+ self.assertFalse(is_supported)
+
+ def test_collection_sint64_unsupported(self):
+ """
+ Verify that the given input is unsupported by SharedMemoryManager.
+ This input is collection_sint64.
+ """
+ manager = SharedMemoryManager()
+ content = [1, 2]
+ datum = bind_meta.Datum(type='collection_sint64', value=content)
+ is_supported = manager.is_supported(datum)
+ self.assertFalse(is_supported)
+
+ def test_large_invalid_bytes_input_support(self):
+ """
+ Verify that the given input is NOT supported by SharedMemoryManager to
+ be transferred over shared memory.
+ The input is bytes of larger than the allowed size.
+ """
+ manager = SharedMemoryManager()
+ content_size = consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER + 10
+ # Not using get_random_bytes to avoid slowing down for creating a large
+ # random input
+ content = b'x01' * content_size
+ bytes_datum = bind_meta.Datum(type='bytes', value=content)
+ is_supported = manager.is_supported(bytes_datum)
+ self.assertFalse(is_supported)
+
+ def test_small_invalid_bytes_input_support(self):
+ """
+ Verify that the given input is NOT supported by SharedMemoryManager to
+ be transferred over shared memory.
+ The input is bytes of smaller than the allowed size. 
+ """
+ manager = SharedMemoryManager()
+ content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER - 10
+ content = self.get_random_bytes(content_size)
+ bytes_datum = bind_meta.Datum(type='bytes', value=content)
+ is_supported = manager.is_supported(bytes_datum)
+ self.assertFalse(is_supported)
+
+ def test_large_invalid_string_input_support(self):
+ """
+ Verify that the given input is NOT supported by SharedMemoryManager to
+ be transferred over shared memory.
+ The input is string of larger than the allowed size.
+ """
+ manager = SharedMemoryManager()
+ content_size = consts.MAX_BYTES_FOR_SHARED_MEM_TRANSFER + 10
+ num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES)
+ # Not using get_random_string to avoid slowing down for creating a large
+ # random input
+ content = 'a' * num_chars
+ string_datum = bind_meta.Datum(type='string', value=content)
+ is_supported = manager.is_supported(string_datum)
+ self.assertFalse(is_supported)
+
+ def test_small_invalid_string_input_support(self):
+ """
+ Verify that the given input is NOT supported by SharedMemoryManager to
+ be transferred over shared memory.
+ The input is string of smaller than the allowed size.
+ """
+ manager = SharedMemoryManager()
+ content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER - 10
+ num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES)
+ content = self.get_random_string(num_chars)
+ string_datum = bind_meta.Datum(type='string', value=content)
+ is_supported = manager.is_supported(string_datum)
+ self.assertFalse(is_supported)
+
+ def test_put_bytes(self):
+ """
+ Verify that the given input was successfully put into shared memory.
+ The input is bytes. 
+ """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + content = self.get_random_bytes(content_size) + shared_mem_meta = manager.put_bytes(content) + self.assertIsNotNone(shared_mem_meta) + self.assertTrue(self.is_valid_uuid(shared_mem_meta.mem_map_name)) + self.assertEqual(content_size, shared_mem_meta.count_bytes) + free_success = manager.free_mem_map(shared_mem_meta.mem_map_name) + self.assertTrue(free_success) + + def test_invalid_put_bytes(self): + """ + Attempt to put bytes using an invalid input and verify that it fails. + """ + manager = SharedMemoryManager() + shared_mem_meta = manager.put_bytes(None) + self.assertIsNone(shared_mem_meta) + + def test_get_bytes(self): + """ + Verify that the output object was successfully gotten from shared + memory. + The output is bytes. + """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + content = self.get_random_bytes(content_size) + shared_mem_meta = manager.put_bytes(content) + mem_map_name = shared_mem_meta.mem_map_name + num_bytes_written = shared_mem_meta.count_bytes + read_content = manager.get_bytes(mem_map_name, offset=0, + count=num_bytes_written) + self.assertEqual(content, read_content) + free_success = manager.free_mem_map(mem_map_name) + self.assertTrue(free_success) + + def test_put_string(self): + """ + Verify that the given input was successfully put into shared memory. + The input is string. 
+ """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) + content = self.get_random_string(num_chars) + expected_size = len(content.encode('utf-8')) + shared_mem_meta = manager.put_string(content) + self.assertIsNotNone(shared_mem_meta) + self.assertTrue(self.is_valid_uuid(shared_mem_meta.mem_map_name)) + self.assertEqual(expected_size, shared_mem_meta.count_bytes) + free_success = manager.free_mem_map(shared_mem_meta.mem_map_name) + self.assertTrue(free_success) + + def test_invalid_put_string(self): + """ + Attempt to put a string using an invalid input and verify that it fails. + """ + manager = SharedMemoryManager() + shared_mem_meta = manager.put_string(None) + self.assertIsNone(shared_mem_meta) + + def test_get_string(self): + """ + Verify that the output object was successfully gotten from shared + memory. + The output is string. + """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + num_chars = math.floor(content_size / consts.SIZE_OF_CHAR_BYTES) + content = self.get_random_string(num_chars) + shared_mem_meta = manager.put_string(content) + mem_map_name = shared_mem_meta.mem_map_name + num_bytes_written = shared_mem_meta.count_bytes + read_content = manager.get_string(mem_map_name, offset=0, + count=num_bytes_written) + self.assertEqual(content, read_content) + free_success = manager.free_mem_map(mem_map_name) + self.assertTrue(free_success) + + def test_allocated_mem_maps(self): + """ + Verify that the SharedMemoryManager is tracking the shared memory maps + it has allocated after put operations. + Verify that those shared memory maps are freed and no longer tracked + after attempting to free them. 
+ """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + content = self.get_random_bytes(content_size) + shared_mem_meta = manager.put_bytes(content) + self.assertIsNotNone(shared_mem_meta) + mem_map_name = shared_mem_meta.mem_map_name + is_mem_map_found = mem_map_name in manager.allocated_mem_maps + self.assertTrue(is_mem_map_found) + self.assertEqual(1, len(manager.allocated_mem_maps.keys())) + free_success = manager.free_mem_map(mem_map_name) + self.assertTrue(free_success) + is_mem_map_found = mem_map_name in manager.allocated_mem_maps + self.assertFalse(is_mem_map_found) + self.assertEqual(0, len(manager.allocated_mem_maps.keys())) + + def test_do_not_free_resources_on_dispose(self): + """ + Verify that when the allocated shared memory maps are freed, + their backing resources are not freed. + Note: The shared memory map should no longer be tracked by the + SharedMemoryManager, though. + """ + manager = SharedMemoryManager() + content_size = consts.MIN_BYTES_FOR_SHARED_MEM_TRANSFER + 10 + content = self.get_random_bytes(content_size) + shared_mem_meta = manager.put_bytes(content) + self.assertIsNotNone(shared_mem_meta) + mem_map_name = shared_mem_meta.mem_map_name + is_mem_map_found = mem_map_name in manager.allocated_mem_maps + self.assertTrue(is_mem_map_found) + self.assertEqual(1, len(manager.allocated_mem_maps.keys())) + free_success = manager.free_mem_map(mem_map_name, False) + self.assertTrue(free_success) + is_mem_map_found = mem_map_name in manager.allocated_mem_maps + self.assertFalse(is_mem_map_found) + self.assertEqual(0, len(manager.allocated_mem_maps.keys())) + + def test_invalid_put_allocated_mem_maps(self): + """ + Verify that after an invalid put operation, no shared memory maps were + added to the list of allocated/tracked shared memory maps. + i.e. no resources were leaked for invalid operations. 
+ """ + manager = SharedMemoryManager() + shared_mem_meta = manager.put_bytes(None) + self.assertIsNone(shared_mem_meta) + self.assertEqual(0, len(manager.allocated_mem_maps.keys())) + + def test_invalid_free_mem_map(self): + """ + Attempt to free a shared memory map that does not exist in the list of + allocated/tracked shared memory maps and verify that it fails. + """ + manager = SharedMemoryManager() + mem_map_name = self.get_new_mem_map_name() + free_success = manager.free_mem_map(mem_map_name) + self.assertFalse(free_success) From e23fe11a720c0de75af7e69b0c0e26047860d796 Mon Sep 17 00:00:00 2001 From: Gohar Irfan Chaudhry Date: Tue, 21 Sep 2021 09:24:35 -0700 Subject: [PATCH 07/11] Addressing comments --- azure_functions_worker/dispatcher.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index 0a40b348e..f93c0b6e4 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -535,12 +535,13 @@ async def _handle__close_shared_memory_resources_request(self, req): results = {mem_map_name: False for mem_map_name in map_names} try: - for mem_map_name in map_names: + for map_name in map_names: try: - success = self._shmem_mgr.free_mem_map(mem_map_name, False) - results[mem_map_name] = success + to_delete = False + success = self._shmem_mgr.free_mem_map(map_name, to_delete) + results[map_name] = success except Exception as e: - logger.error(f'Cannot free memory map {mem_map_name} - {e}', + logger.error(f'Cannot free memory map {map_name} - {e}', exc_info=True) finally: response = protos.CloseSharedMemoryResourcesResponse( From 0fbf2256ec29dae8d427c47eff713256e748c715 Mon Sep 17 00:00:00 2001 From: Gohar Irfan Chaudhry Date: Thu, 30 Sep 2021 18:52:48 -0700 Subject: [PATCH 08/11] Update azure_functions_worker/bindings/meta.py Co-authored-by: Varad Meru --- azure_functions_worker/bindings/meta.py | 7 ++----- 1 file changed, 2 
insertions(+), 5 deletions(-) diff --git a/azure_functions_worker/bindings/meta.py b/azure_functions_worker/bindings/meta.py index 7526a952c..a0495a6d0 100644 --- a/azure_functions_worker/bindings/meta.py +++ b/azure_functions_worker/bindings/meta.py @@ -112,11 +112,8 @@ def get_datum(binding: str, obj: typing.Any, def is_cache_supported(datum: datumdef.Datum): - if datum.type == 'bytes': - return True - elif datum.type == 'string': - return True - return False + supported_datatypes = ('bytes', 'string') + return datum.type in supported_datatypes def to_outgoing_proto(binding: str, obj: typing.Any, *, From accd1f893c83e055b3fec6168feea08563b50ed4 Mon Sep 17 00:00:00 2001 From: Gohar Irfan Chaudhry Date: Thu, 30 Sep 2021 21:52:27 -0700 Subject: [PATCH 09/11] Addressing comments --- azure_functions_worker/bindings/meta.py | 17 ++++++++++++----- azure_functions_worker/dispatcher.py | 11 ++++++++++- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/azure_functions_worker/bindings/meta.py b/azure_functions_worker/bindings/meta.py index a0495a6d0..293a18b0f 100644 --- a/azure_functions_worker/bindings/meta.py +++ b/azure_functions_worker/bindings/meta.py @@ -7,6 +7,7 @@ from . import datumdef from . 
import generic +from .shared_memory_data_transfer import SharedMemoryManager PB_TYPE = 'rpc_data' PB_TYPE_DATA = 'data' @@ -62,7 +63,7 @@ def from_incoming_proto( pb: protos.ParameterBinding, *, pytype: typing.Optional[type], trigger_metadata: typing.Optional[typing.Dict[str, protos.TypedData]], - shmem_mgr) -> typing.Any: + shmem_mgr: SharedMemoryManager) -> typing.Any: binding = get_binding(binding) if trigger_metadata: metadata = { @@ -111,7 +112,7 @@ def get_datum(binding: str, obj: typing.Any, return datum -def is_cache_supported(datum: datumdef.Datum): +def _does_datatype_support_caching(datum: datumdef.Datum): supported_datatypes = ('bytes', 'string') return datum.type in supported_datatypes @@ -125,15 +126,21 @@ def to_outgoing_proto(binding: str, obj: typing.Any, *, def to_outgoing_param_binding(binding: str, obj: typing.Any, *, pytype: typing.Optional[type], out_name: str, - shmem_mgr, + shmem_mgr: SharedMemoryManager, is_function_data_cache_enabled: bool) \ -> protos.ParameterBinding: datum = get_datum(binding, obj, pytype) shared_mem_value = None # If shared memory is enabled and supported for the given datum, try to - # transfer to host over shared memory as a default + # transfer to host over shared memory as a default. + # If caching is enabled, then also check if this type is supported - if so, + # transfer over shared memory. + # In case of caching, some conditions like object size may not be + # applicable since even small objects are also allowed to be cached. 
+ is_cache_supported = is_function_data_cache_enabled and \ + _does_datatype_support_caching(datum) can_transfer_over_shmem = shmem_mgr.is_supported(datum) or \ - (is_function_data_cache_enabled and is_cache_supported(datum)) + is_cache_supported if shmem_mgr.is_enabled() and can_transfer_over_shmem: shared_mem_value = datumdef.Datum.to_rpc_shared_memory(datum, shmem_mgr) # Check if data was written into shared memory diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index f93c0b6e4..aed478699 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -537,7 +537,16 @@ async def _handle__close_shared_memory_resources_request(self, req): try: for map_name in map_names: try: - to_delete = False + if self._function_data_cache_enabled: + # If the cache is enabled, let the host decide when to + # delete the resources. + # Just drop the reference from the worker. + to_delete = False + else: + # If the cache is not enabled, the worker should free + # the resources as at this point the host has read the + # memory maps and does not need them. + to_delete = True success = self._shmem_mgr.free_mem_map(map_name, to_delete) results[map_name] = success except Exception as e: From 73ecf8af9ad1f607ec8eb2bf206e1796175268f4 Mon Sep 17 00:00:00 2001 From: Gohar Irfan Chaudhry Date: Fri, 1 Oct 2021 14:51:34 -0700 Subject: [PATCH 10/11] Addressing comments --- azure_functions_worker/dispatcher.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/azure_functions_worker/dispatcher.py b/azure_functions_worker/dispatcher.py index aed478699..81c3f6d8f 100644 --- a/azure_functions_worker/dispatcher.py +++ b/azure_functions_worker/dispatcher.py @@ -526,6 +526,10 @@ async def _handle__close_shared_memory_resources_request(self, req): invocation. This is called after the functions host is done reading the output from the worker and wants the worker to free up those resources. 
+ If the cache is enabled, let the host decide when to delete the + resources. Just drop the reference from the worker. + If the cache is not enabled, the worker should free the resources as at + this point the host has read the memory maps and does not need them. """ close_request = req.close_shared_memory_resources_request map_names = close_request.map_names @@ -537,17 +541,10 @@ async def _handle__close_shared_memory_resources_request(self, req): try: for map_name in map_names: try: - if self._function_data_cache_enabled: - # If the cache is enabled, let the host decide when to - # delete the resources. - # Just drop the reference from the worker. - to_delete = False - else: - # If the cache is not enabled, the worker should free - # the resources as at this point the host has read the - # memory maps and does not need them. - to_delete = True - success = self._shmem_mgr.free_mem_map(map_name, to_delete) + to_delete_resources = \ + False if self._function_data_cache_enabled else True + success = self._shmem_mgr.free_mem_map(map_name, + to_delete_resources) results[map_name] = success except Exception as e: logger.error(f'Cannot free memory map {map_name} - {e}', From f577ea8c358fb39e2545fde37072f16ed7946a67 Mon Sep 17 00:00:00 2001 From: Gohar Irfan Chaudhry Date: Fri, 1 Oct 2021 15:00:27 -0700 Subject: [PATCH 11/11] Addressing comments for cleaning up shared memory usage checks in meta.py --- azure_functions_worker/bindings/meta.py | 41 ++++++++++++++++++------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/azure_functions_worker/bindings/meta.py b/azure_functions_worker/bindings/meta.py index 293a18b0f..52867a000 100644 --- a/azure_functions_worker/bindings/meta.py +++ b/azure_functions_worker/bindings/meta.py @@ -117,6 +117,34 @@ def _does_datatype_support_caching(datum: datumdef.Datum): return datum.type in supported_datatypes +def _can_transfer_over_shmem(shmem_mgr: SharedMemoryManager, + is_function_data_cache_enabled: bool, + datum: 
datumdef.Datum): + """ + If shared memory is enabled and supported for the given datum, try to + transfer to host over shared memory as a default. + If caching is enabled, then also check if this type is supported - if so, + transfer over shared memory. + In case of caching, some conditions like object size may not be + applicable since even small objects are also allowed to be cached. + """ + if not shmem_mgr.is_enabled(): + # If shared memory usage is not enabled, no further checks required + return False + if shmem_mgr.is_supported(datum): + # If transferring this object over shared memory is supported, do so. + return True + if is_function_data_cache_enabled and _does_datatype_support_caching(datum): + # If caching is enabled and this object can be cached, transfer over + # shared memory (since the cache uses shared memory). + # In this case, some requirements (like object size) for using shared + # memory may be ignored since we want to support caching of small + # objects (those that have sizes smaller than the minimum we transfer + # over shared memory when the cache is not enabled) as well. + return True + return False + + def to_outgoing_proto(binding: str, obj: typing.Any, *, pytype: typing.Optional[type]) -> protos.TypedData: datum = get_datum(binding, obj, pytype) @@ -131,17 +159,8 @@ def to_outgoing_param_binding(binding: str, obj: typing.Any, *, -> protos.ParameterBinding: datum = get_datum(binding, obj, pytype) shared_mem_value = None - # If shared memory is enabled and supported for the given datum, try to - # transfer to host over shared memory as a default. - # If caching is enabled, then also check if this type is supported - if so, - # transfer over shared memory. - # In case of caching, some conditions like object size may not be - # applicable since even small objects are also allowed to be cached. 
- is_cache_supported = is_function_data_cache_enabled and \ - _does_datatype_support_caching(datum) - can_transfer_over_shmem = shmem_mgr.is_supported(datum) or \ - is_cache_supported - if shmem_mgr.is_enabled() and can_transfer_over_shmem: + if _can_transfer_over_shmem(shmem_mgr, is_function_data_cache_enabled, + datum): shared_mem_value = datumdef.Datum.to_rpc_shared_memory(datum, shmem_mgr) # Check if data was written into shared memory if shared_mem_value is not None: