diff --git a/README.md b/README.md index 474ca6c..050b673 100644 --- a/README.md +++ b/README.md @@ -93,22 +93,21 @@ from Algorithmia import ADK # API calls will begin at the apply() method, with the request body passed as 'input' # For more details, see algorithmia.com/developers/algorithm-development/languages -def apply(input, globals): +def apply(input, modelData): # If your apply function uses state that's loaded into memory via load, you can pass that loaded state to your apply # function by defining an additional "globals" parameter in your apply function. - return "hello {} {}".format(str(input), str(globals['payload'])) + return "hello {} {}".format(str(input), str(modelData.user_data['payload'])) -def load(): +def load(modelData): # Here you can optionally define a function that will be called when the algorithm is loaded. # The return object from this function can be passed directly as input to your apply function. # A great example would be any model files that need to be available to this algorithm # during runtime. # Any variables returned here, will be passed as the secondary argument to your 'algorithm' function - globals = {} - globals['payload'] = "Loading has been completed." - return globals + modelData.user_data['payload'] = "Loading has been completed." + return modelData # This turns your library code into an algorithm that can run on the platform. @@ -129,18 +128,19 @@ from PIL import Image import json from torchvision import models, transforms -def load_labels(label_path, client): - local_path = client.file(label_path).getFile().name - with open(local_path) as f: + +client = Algorithmia.client() + +def load_labels(label_path): + with open(label_path) as f: labels = json.load(f) labels = [labels[str(k)][1] for k in range(len(labels))] return labels -def load_model(model_paths, client): +def load_model(model_path): model = models.squeezenet1_1() - local_file = client.file(model_paths["filepath"]).getFile().name - weights = torch.load(local_file) + weights = torch.load(model_path) model.load_state_dict(weights) return model.float().eval() @@ -174,17 +174,15 @@ def infer_image(image_url, n, globals): return result -def load(manifest): +def load(modelData): - globals = {} - client = Algorithmia.client() - globals["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x" - globals["model"] = load_model(manifest["squeezenet"], client) - globals["labels"] = load_labels(manifest["label_file"], client) - return globals + modelData.user_data["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x" + modelData.user_data["model"] = load_model(modelData.get_model("squeezenet")) + modelData.user_data["labels"] = load_labels(modelData.get_model("labels")) + return modelData -def apply(input, globals): +def apply(input, modelData): if isinstance(input, dict): if "n" in input: n = input["n"] @@ -192,10 +190,10 @@ def apply(input, globals): n = 3 if "data" in input: if isinstance(input["data"], str): - output = infer_image(input["data"], n, globals) + output = infer_image(input["data"], n, modelData.user_data) elif isinstance(input["data"], list): for row in input["data"]: - row["predictions"] = infer_image(row["image_url"], n, globals) + row["predictions"] = infer_image(row["image_url"], n, modelData.user_data) output = input["data"] else: raise Exception("\"data\" must be a image url or a list of image urls (with labels)") @@ -206,7 +204,7 @@ def apply(input, globals): raise Exception("input must be a json object") -algorithm = ADK(apply_func=apply, load_func=load) +algorithm = ADK(apply_func=apply, load_func=load, client=client) algorithm.init({"data": "https://i.imgur.com/bXdORXl.jpeg"}) ``` diff --git a/adk/ADK.py b/adk/ADK.py index 2d3f587..fb1dd01 100644 --- a/adk/ADK.py +++ b/adk/ADK.py @@ -1,39 +1,55 @@ -import base64 import inspect import json import os import sys -import traceback -import six +import Algorithmia +from adk.io import create_exception, format_data, format_response +from adk.manifest.modeldata import ModelData class ADK(object): - def __init__(self, apply_func, load_func=None): + def __init__(self, apply_func, load_func=None, client=None): """ Creates the adk object :param apply_func: A required function that can have an arity of 1-2, depending on if loading occurs - :param load_func: An optional supplier function used if load time events are required, has an arity of 0. + :param load_func: An optional supplier function used if load time events are required, if a model manifest is provided; + the function may have a single `manifest` parameter to interact with the model manifest, otherwise must have no parameters. + :param client: A Algorithmia Client instance that might be user defined, + and is used for interacting with a model manifest file; if defined. """ self.FIFO_PATH = "/tmp/algoout" + + if client: + self.client = client + else: + self.client = Algorithmia.client() + apply_args, _, _, _, _, _, _ = inspect.getfullargspec(apply_func) + self.apply_arity = len(apply_args) if load_func: load_args, _, _, _, _, _, _ = inspect.getfullargspec(load_func) - if len(load_args) > 0: - raise Exception("load function must not have parameters") + self.load_arity = len(load_args) + if self.load_arity != 1: + raise Exception("load function expects 1 parameter to be used to store algorithm state") self.load_func = load_func else: self.load_func = None - if len(apply_args) > 2 or len(apply_args) == 0: - raise Exception("apply function may have between 1 and 2 parameters, not {}".format(len(apply_args))) self.apply_func = apply_func self.is_local = not os.path.exists(self.FIFO_PATH) self.load_result = None self.loading_exception = None + self.manifest_path = "model_manifest.json.freeze" + self.model_data = self.init_manifest(self.manifest_path) + + def init_manifest(self, path): + return ModelData(self.client, path) def load(self): try: + if self.model_data.available(): + self.model_data.initialize() if self.load_func: - self.load_result = self.load_func() + self.load_result = self.load_func(self.model_data) except Exception as e: self.loading_exception = e finally: @@ -45,55 +61,16 @@ def load(self): def apply(self, payload): try: - if self.load_result: + if self.load_result and self.apply_arity == 2: apply_result = self.apply_func(payload, self.load_result) else: apply_result = self.apply_func(payload) - response_obj = self.format_response(apply_result) + response_obj = format_response(apply_result) return response_obj except Exception as e: - response_obj = self.create_exception(e) + response_obj = create_exception(e) return response_obj - def format_data(self, request): - if request["content_type"] in ["text", "json"]: - data = request["data"] - elif request["content_type"] == "binary": - data = self.wrap_binary_data(base64.b64decode(request["data"])) - else: - raise Exception("Invalid content_type: {}".format(request["content_type"])) - return data - - def is_binary(self, arg): - if six.PY3: - return isinstance(arg, base64.bytes_types) - - return isinstance(arg, bytearray) - - def wrap_binary_data(self, data): - if six.PY3: - return bytes(data) - else: - return bytearray(data) - - def format_response(self, response): - if self.is_binary(response): - content_type = "binary" - response = str(base64.b64encode(response), "utf-8") - elif isinstance(response, six.string_types) or isinstance(response, six.text_type): - content_type = "text" - else: - content_type = "json" - response_string = json.dumps( - { - "result": response, - "metadata": { - "content_type": content_type - } - } - ) - return response_string - def write_to_pipe(self, payload, pprint=print): if self.is_local: if isinstance(payload, dict): @@ -109,40 +86,24 @@ def write_to_pipe(self, payload, pprint=print): if os.name == "nt": sys.stdin = payload - def create_exception(self, exception, loading_exception=False): - if hasattr(exception, "error_type"): - error_type = exception.error_type - elif loading_exception: - error_type = "LoadingError" - else: - error_type = "AlgorithmError" - response = json.dumps({ - "error": { - "message": str(exception), - "stacktrace": traceback.format_exc(), - "error_type": error_type, - } - }) - return response - def process_local(self, local_payload, pprint): result = self.apply(local_payload) self.write_to_pipe(result, pprint=pprint) def init(self, local_payload=None, pprint=print): - self.load() - if self.is_local and local_payload: + self.load() + if self.is_local and local_payload: + if self.loading_exception: + load_error = create_exception(self.loading_exception, loading_exception=True) + self.write_to_pipe(load_error, pprint=pprint) + self.process_local(local_payload, pprint) + else: + for line in sys.stdin: + request = json.loads(line) + formatted_input = format_data(request) if self.loading_exception: - load_error = self.create_exception(self.loading_exception, loading_exception=True) + load_error = create_exception(self.loading_exception, loading_exception=True) self.write_to_pipe(load_error, pprint=pprint) - self.process_local(local_payload, pprint) - else: - for line in sys.stdin: - request = json.loads(line) - formatted_input = self.format_data(request) - if self.loading_exception: - load_error = self.create_exception(self.loading_exception, loading_exception=True) - self.write_to_pipe(load_error, pprint=pprint) - else: - result = self.apply(formatted_input) - self.write_to_pipe(result) + else: + result = self.apply(formatted_input) + self.write_to_pipe(result) diff --git a/adk/io.py b/adk/io.py new file mode 100644 index 0000000..be2045c --- /dev/null +++ b/adk/io.py @@ -0,0 +1,64 @@ +import traceback +import six +import base64 +import json + + +def format_data(request): + if request["content_type"] in ["text", "json"]: + data = request["data"] + elif request["content_type"] == "binary": + data = wrap_binary_data(base64.b64decode(request["data"])) + else: + raise Exception("Invalid content_type: {}".format(request["content_type"])) + return data + + +def is_binary(arg): + if six.PY3: + return isinstance(arg, base64.bytes_types) + + return isinstance(arg, bytearray) + + +def wrap_binary_data(data): + if six.PY3: + return bytes(data) + else: + return bytearray(data) + + +def format_response(response): + if is_binary(response): + content_type = "binary" + response = str(base64.b64encode(response), "utf-8") + elif isinstance(response, six.string_types) or isinstance(response, six.text_type): + content_type = "text" + else: + content_type = "json" + response_string = json.dumps( + { + "result": response, + "metadata": { + "content_type": content_type + } + } + ) + return response_string + + +def create_exception(exception, loading_exception=False): + if hasattr(exception, "error_type"): + error_type = exception.error_type + elif loading_exception: + error_type = "LoadingError" + else: + error_type = "AlgorithmError" + response = json.dumps({ + "error": { + "message": str(exception), + "stacktrace": traceback.format_exc(), + "error_type": error_type, + } + }) + return response diff --git a/adk/manifest/__init__.py b/adk/manifest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/adk/manifest/classes.py b/adk/manifest/classes.py new file mode 100644 index 0000000..f2d2235 --- /dev/null +++ b/adk/manifest/classes.py @@ -0,0 +1,6 @@ + +class FileData(object): + def __init__(self, md5_checksum, file_path): + self.md5_checksum = md5_checksum + self.file_path = file_path + diff --git a/adk/manifest/modeldata.py b/adk/manifest/modeldata.py new file mode 100644 index 0000000..1033d1e --- /dev/null +++ b/adk/manifest/modeldata.py @@ -0,0 +1,95 @@ +import os +import json +import hashlib +from adk.manifest.classes import FileData + + +class ModelData(object): + def __init__(self, client, model_manifest_path): + self.manifest_freeze_path = model_manifest_path + self.manifest_data = get_manifest(self.manifest_freeze_path) + self.client = client + self.models = {} + self.user_data = {} + self.system_data = {} + + def available(self): + if self.manifest_data: + return True + else: + return False + + def initialize(self): + if self.client is None: + raise Exception("Client was not defined, please define a Client when using Model Manifests.") + for required_file in self.manifest_data['required_files']: + name = required_file['name'] + if name in self.models: + raise Exception("Duplicate 'name' detected. \n" + + name + " was found to be used by more than one data file, please rename.") + expected_hash = required_file['md5_checksum'] + with self.client.file(required_file['source_uri']).getFile() as f: + local_data_path = f.name + real_hash = md5_for_file(local_data_path) + if real_hash != expected_hash and required_file['fail_on_tamper']: + raise Exception("Model File Mismatch for " + name + + "\nexpected hash: " + expected_hash + "\nreal hash: " + real_hash) + else: + self.models[name] = FileData(real_hash, local_data_path) + + def get_model(self, model_name): + if model_name in self.models: + return self.models[model_name].file_path + elif len([optional for optional in self.manifest_data['optional_files'] if + optional['name'] == model_name]) > 0: + self.find_optional_model(model_name) + return self.models[model_name].file_path + else: + raise Exception("model name " + model_name + " not found in manifest") + + def find_optional_model(self, file_name): + + found_models = [optional for optional in self.manifest_data['optional_files'] if + optional['name'] == file_name] + if len(found_models) == 0: + raise Exception("file with name '" + file_name + "' not found in model manifest.") + model_info = found_models[0] + self.models[file_name] = {} + expected_hash = model_info['md5_checksum'] + with self.client.file(model_info['source_uri']).getFile() as f: + local_data_path = f.name + real_hash = md5_for_file(local_data_path) + if real_hash != expected_hash and model_info['fail_on_tamper']: + raise Exception("Model File Mismatch for " + file_name + + "\nexpected hash: " + expected_hash + "\nreal hash: " + real_hash) + else: + self.models[file_name] = FileData(real_hash, local_data_path) + + +def get_manifest(manifest_path): + if os.path.exists(manifest_path): + with open(manifest_path) as f: + manifest_data = json.load(f) + expected_lock_checksum = manifest_data.get('lock_checksum') + del manifest_data['lock_checksum'] + detected_lock_checksum = md5_for_str(str(manifest_data)) + if expected_lock_checksum != detected_lock_checksum: + raise Exception("Manifest FreezeFile Tamper Detected; please use the CLI and 'algo freeze' to rebuild your " + "algorithm's freeze file.") + return manifest_data + else: + return None + + +def md5_for_file(fname): + hash_md5 = hashlib.md5() + with open(fname, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + return str(hash_md5.hexdigest()) + + +def md5_for_str(content): + hash_md5 = hashlib.md5() + hash_md5.update(content.encode()) + return str(hash_md5.hexdigest()) diff --git a/examples/loaded_state_hello_world/src/Algorithm.py b/examples/loaded_state_hello_world/src/Algorithm.py index d7fe987..a58fd71 100644 --- a/examples/loaded_state_hello_world/src/Algorithm.py +++ b/examples/loaded_state_hello_world/src/Algorithm.py @@ -4,22 +4,21 @@ # API calls will begin at the apply() method, with the request body passed as 'input' # For more details, see algorithmia.com/developers/algorithm-development/languages -def apply(input, globals): +def apply(input, modelData): # If your apply function uses state that's loaded into memory via load, you can pass that loaded state to your apply # function by defining an additional "globals" parameter in your apply function. - return "hello {} {}".format(str(input), str(globals['payload'])) + return "hello {} {}".format(str(input), str(modelData.user_data['payload'])) -def load(): +def load(modelData): # Here you can optionally define a function that will be called when the algorithm is loaded. # The return object from this function can be passed directly as input to your apply function. # A great example would be any model files that need to be available to this algorithm # during runtime. # Any variables returned here, will be passed as the secondary argument to your 'algorithm' function - globals = {} - globals['payload'] = "Loading has been completed." - return globals + modelData.user_data['payload'] = "Loading has been completed." + return modelData # This turns your library code into an algorithm that can run on the platform. diff --git a/examples/pytorch_image_classification/model_manifest.json b/examples/pytorch_image_classification/model_manifest.json index 5194941..ba6cbf5 100644 --- a/examples/pytorch_image_classification/model_manifest.json +++ b/examples/pytorch_image_classification/model_manifest.json @@ -1,14 +1,29 @@ { - "label_file": { - "filepath": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json", - "md5_hash": "c2c37ea517e94d9795004a39431a14cb", - "origin_ref": "this file came from imagenet.org", - "uploaded_utc": "2021-05-03-11:05" - }, - "squeezenet": { - "filepath": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth", - "md5_hash": "46a44d32d2c5c07f7f66324bef4c7266", - "origin_ref": "From https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth", - "uploaded_utc": "2021-05-03-11:05" + "required_files" : [ + { "name": "squeezenet", + "source_uri": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth", + "fail_on_tamper": true, + "metadata": { + "dataset_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266" + } + }, + { + "name": "labels", + "source_uri": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json", + "fail_on_tamper": true, + "metadata": { + "dataset_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266" + } } + ], + "optional_files": [ + { + "name": "mobilenet", + "source_uri": "data://AlgorithmiaSE/image_cassification_demo/mobilenet_v2-b0353104.pth", + "fail_on_tamper": false, + "metadata": { + "dataset_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266" + } + } + ] } \ No newline at end of file diff --git a/examples/pytorch_image_classification/model_manifest.json.freeze b/examples/pytorch_image_classification/model_manifest.json.freeze new file mode 100644 index 0000000..11d4203 --- /dev/null +++ b/examples/pytorch_image_classification/model_manifest.json.freeze @@ -0,0 +1,34 @@ +{ + "required_files":[ + { + "name":"squeezenet", + "source_uri":"data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth", + "fail_on_tamper":true, + "metadata":{ + "dataset_md5_checksum":"46a44d32d2c5c07f7f66324bef4c7266" + }, + "md5_checksum":"46a44d32d2c5c07f7f66324bef4c7266" + }, + { + "name":"labels", + "source_uri":"data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json", + "fail_on_tamper":true, + "metadata":{ + "dataset_md5_checksum":"46a44d32d2c5c07f7f66324bef4c7266" + }, + "md5_checksum":"c2c37ea517e94d9795004a39431a14cb" + } + ], + "optional_files":[ + { + "name":"mobilenet", + "source_uri":"data://AlgorithmiaSE/image_cassification_demo/mobilenet_v2-b0353104.pth", + "fail_on_tamper":false, + "metadata":{ + "dataset_md5_checksum":"46a44d32d2c5c07f7f66324bef4c7266" + } + } + ], + "timestamp":"1633450866.985464", + "lock_checksum":"24f5eca888d87661ca6fc08042e40cb7" +} \ No newline at end of file diff --git a/examples/pytorch_image_classification/src/Algorithm.py b/examples/pytorch_image_classification/src/Algorithm.py index 35d9431..85e8e55 100644 --- a/examples/pytorch_image_classification/src/Algorithm.py +++ b/examples/pytorch_image_classification/src/Algorithm.py @@ -5,18 +5,19 @@ import json from torchvision import models, transforms -def load_labels(label_path, client): - local_path = client.file(label_path).getFile().name - with open(local_path) as f: + +client = Algorithmia.client() + +def load_labels(label_path): + with open(label_path) as f: labels = json.load(f) labels = [labels[str(k)][1] for k in range(len(labels))] return labels -def load_model(model_paths, client): +def load_model(model_path): model = models.squeezenet1_1() - local_file = client.file(model_paths["filepath"]).getFile().name - weights = torch.load(local_file) + weights = torch.load(model_path) model.load_state_dict(weights) return model.float().eval() @@ -50,17 +51,15 @@ def infer_image(image_url, n, globals): return result -def load(manifest): +def load(modelData): - globals = {} - client = Algorithmia.client() - globals["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x" - globals["model"] = load_model(manifest["squeezenet"], client) - globals["labels"] = load_labels(manifest["label_file"], client) - return globals + modelData.user_data["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x" + modelData.user_data["model"] = load_model(modelData.get_model("squeezenet")) + modelData.user_data["labels"] = load_labels(modelData.get_model("labels")) + return modelData -def apply(input, globals): +def apply(input, modelData): if isinstance(input, dict): if "n" in input: n = input["n"] @@ -68,10 +67,10 @@ def apply(input, globals): n = 3 if "data" in input: if isinstance(input["data"], str): - output = infer_image(input["data"], n, globals) + output = infer_image(input["data"], n, modelData.user_data) elif isinstance(input["data"], list): for row in input["data"]: - row["predictions"] = infer_image(row["image_url"], n, globals) + row["predictions"] = infer_image(row["image_url"], n, modelData.user_data) output = input["data"] else: raise Exception("\"data\" must be a image url or a list of image urls (with labels)") @@ -82,5 +81,5 @@ def apply(input, globals): raise Exception("input must be a json object") -algorithm = ADK(apply_func=apply, load_func=load) +algorithm = ADK(apply_func=apply, load_func=load, client=client) algorithm.init({"data": "https://i.imgur.com/bXdORXl.jpeg"}) diff --git a/tests/AdkTest.py b/tests/AdkTest.py new file mode 100644 index 0000000..e6b4672 --- /dev/null +++ b/tests/AdkTest.py @@ -0,0 +1,7 @@ +from adk import ADK + + +class ADKTest(ADK): + def __init__(self, apply_func, load_func=None, client=None, manifest_path="model_manifest.json.freeze"): + super(ADKTest, self).__init__(apply_func, load_func, client) + self.model_data = self.init_manifest(manifest_path) diff --git a/tests/adk_algorithms.py b/tests/adk_algorithms.py index c8c21bf..7a41600 100644 --- a/tests/adk_algorithms.py +++ b/tests/adk_algorithms.py @@ -1,36 +1,52 @@ import Algorithmia import base64 +import os + # -- Apply functions --- # def apply_basic(input): return "hello " + input + def apply_binary(input): if isinstance(input, bytes): input = input.decode('utf8') return bytes("hello " + input, encoding='utf8') -def apply_input_or_context(input, globals=None): - if isinstance(globals, dict): - return globals + +def apply_input_or_context(input, model_data=None): + if model_data: + return model_data.user_data else: return "hello " + input +def apply_successful_manifest_parsing(input, model_data): + if model_data: + return "all model files were successfully loaded" + else: + return "model files were not loaded correctly" + + # -- Loading functions --- # -def loading_text(): - context = dict() - context['message'] = 'This message was loaded prior to runtime' - return context +def loading_text(modelData): + modelData.user_data['message'] = 'This message was loaded prior to runtime' + return modelData -def loading_exception(): +def loading_exception(modelData): raise Exception("This exception was thrown in loading") -def loading_file_from_algorithmia(): - context = dict() - client = Algorithmia.client() - context['data_url'] = 'data://demo/collection/somefile.json' - context['data'] = client.file(context['data_url']).getJson() - return context +def loading_file_from_algorithmia(modelData): + modelData.user_data['data_url'] = 'data://demo/collection/somefile.json' + modelData.user_data['data'] = modelData.client.file(modelData.user_data['data_url']).getJson() + return modelData + + +def loading_with_manifest(modelData): + modelData.user_data["squeezenet"] = modelData.get_model("squeezenet") + modelData.user_data['labels'] = modelData.get_model("labels") + # optional model + modelData.user_data['mobilenet'] = modelData.get_model("mobilenet") + return modelData diff --git a/tests/manifests/bad_model_manifest.json.freeze b/tests/manifests/bad_model_manifest.json.freeze new file mode 100644 index 0000000..ec346aa --- /dev/null +++ b/tests/manifests/bad_model_manifest.json.freeze @@ -0,0 +1,35 @@ +{ + "algorithm_name": "test_algorithm", + "timestamp": "1632770803", + "lock_checksum": "36162a15980fa79975d2a747eb1bb842", + "required_files" : [ + { "name": "squeezenet", + "source_uri": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth", + "md5_checksum": "f20b50b44fdef367a225d41f747a0963", + "fail_on_tamper": true, + "metadata": { + "dataset_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266" + } + }, + { + "name": "labels", + "data_api_path": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json", + "md5_checksum": "c2c37ea517e94d9795004a39431a14cb", + "fail_on_tamper": true, + "metadata": { + "dataset_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266" + } + } + ], + "optional_files": [ + { + "name": "mobilenet", + "source_uri": "data://AlgorithmiaSE/image_cassification_demo/mobilenet_v2-b0353104.pth", + "md5_checksum": "c2c37ea517e94d9795004a39431a14cb", + "fail_on_tamper": false, + "metadata": { + "dataset_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266" + } + } + ] +} \ No newline at end of file diff --git a/tests/manifests/good_model_manifest.json.freeze b/tests/manifests/good_model_manifest.json.freeze new file mode 100644 index 0000000..2e2b74a --- /dev/null +++ b/tests/manifests/good_model_manifest.json.freeze @@ -0,0 +1,34 @@ +{ + "timestamp": "1632770803", + "lock_checksum": "0c1cd66787a7368ef31b23d653e31bf7", + "required_files" : [ + { "name": "squeezenet", + "source_uri": "data://AlgorithmiaSE/image_cassification_demo/squeezenet1_1-f364aa15.pth", + "md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266", + "fail_on_tamper": true, + "metadata": { + "dataset_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266" + } + }, + { + "name": "labels", + "source_uri": "data://AlgorithmiaSE/image_cassification_demo/imagenet_class_index.json", + "md5_checksum": "c2c37ea517e94d9795004a39431a14cb", + "fail_on_tamper": true, + "metadata": { + "dataset_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266" + } + } + ], + "optional_files": [ + { + "name": "mobilenet", + "source_uri": "data://AlgorithmiaSE/image_cassification_demo/mobilenet_v2-b0353104.pth", + "md5_checksum": "f20b50b44fdef367a225d41f747a0963", + "fail_on_tamper": false, + "metadata": { + "dataset_md5_checksum": "46a44d32d2c5c07f7f66324bef4c7266" + } + } + ] +} \ No newline at end of file diff --git a/tests/test_adk_local.py b/tests/test_adk_local.py index db9ec05..5592bfc 100644 --- a/tests/test_adk_local.py +++ b/tests/test_adk_local.py @@ -1,7 +1,7 @@ import json import os import unittest -from adk import ADK +from tests.AdkTest import ADKTest from tests.adk_algorithms import * @@ -14,14 +14,24 @@ def setUp(self): except: pass - def execute_example(self, input, apply, load=lambda: None): - algo = ADK(apply, load) + def execute_example(self, input, apply, load=None): + if load: + algo = ADKTest(apply, load) + else: + algo = ADKTest(apply) + output = [] + algo.init(input, pprint=lambda x: output.append(x)) + return output[0] + + def execute_manifest_example(self, input, apply, load, manifest_path="manifests/good_model_manifest.json.freeze"): + client = Algorithmia.client() + algo = ADKTest(apply, load, manifest_path=manifest_path, client=client) output = [] algo.init(input, pprint=lambda x: output.append(x)) return output[0] def execute_without_load(self, input, apply): - algo = ADK(apply) + algo = ADKTest(apply) output = [] algo.init(input, pprint=lambda x: output.append(x)) return output[0] @@ -110,6 +120,34 @@ def test_binary_data(self): actual_output = json.loads(self.execute_without_load(input, apply_binary)) self.assertEqual(expected_output, actual_output) + def test_manifest_file_success(self): + input = "Algorithmia" + expected_output = {'metadata': + { + 'content_type': 'text' + }, + 'result': "all model files were successfully loaded" + } + actual_output = json.loads(self.execute_manifest_example(input, apply_successful_manifest_parsing, + loading_with_manifest, + manifest_path="tests/manifests/good_model_manifest" + ".json.freeze")) + self.assertEqual(expected_output, actual_output) + + def test_manifest_file_tampered(self): + input = "Algorithmia" + expected_output = {"error": {"error_type": "LoadingError", + "message": "Model File Mismatch for squeezenet\n" + "expected hash: f20b50b44fdef367a225d41f747a0963\n" + "real hash: 46a44d32d2c5c07f7f66324bef4c7266", + "stacktrace": "NoneType: None\n"}} + + actual_output = json.loads(self.execute_manifest_example(input, apply_successful_manifest_parsing, + loading_with_manifest, + manifest_path="tests/manifests/bad_model_manifest" + ".json.freeze")) + self.assertEqual(expected_output, actual_output) + def run_test(): unittest.main() diff --git a/tests/test_adk_remote.py b/tests/test_adk_remote.py index 2a77285..f0d69ba 100644 --- a/tests/test_adk_remote.py +++ b/tests/test_adk_remote.py @@ -2,7 +2,7 @@ import json import unittest import os -from adk import ADK +from tests.AdkTest import ADKTest import base64 from tests.adk_algorithms import * @@ -43,9 +43,9 @@ def open_pipe(self): if os.name == "posix": self.fifo_pipe = os.open(self.fifo_pipe_path, os.O_RDONLY | os.O_NONBLOCK) - def execute_example(self, input, apply, load=lambda: None): + def execute_example(self, input, apply, load=None): self.open_pipe() - algo = ADK(apply, load) + algo = ADKTest(apply, load) sys.stdin = input algo.init() output = self.read_in() @@ -53,20 +53,22 @@ def execute_example(self, input, apply, load=lambda: None): def execute_without_load(self, input, apply): self.open_pipe() - algo = ADK(apply) + algo = ADKTest(apply) sys.stdin = input algo.init() output = self.read_in() return output - def execute_example_local(self, input, apply, load=lambda: None): - algo = ADK(apply, load) - output = algo.init(input, pprint=lambda x: x) + def execute_manifest_example(self, input, apply, load, manifest_path): + client = Algorithmia.client() + self.open_pipe() + algo = ADKTest(apply, load, manifest_path=manifest_path, client=client) + sys.stdin = input + algo.init() + output = self.read_in() return output - - -# ----- Tests ----- # + # ----- Tests ----- # def test_basic(self): input = {'content_type': 'json', 'data': 'Algorithmia'} @@ -92,7 +94,6 @@ def test_basic_2(self): actual_output = self.execute_without_load(input, apply_basic) self.assertEqual(expected_output, actual_output) - def test_algorithm_loading_basic(self): input = {'content_type': 'json', 'data': 'ignore me'} expected_output = {'metadata': @@ -160,6 +161,21 @@ def test_binary_data(self): actual_output = self.execute_without_load(input, apply_binary) self.assertEqual(expected_output, actual_output) + def test_manifest_file_success(self): + input = {'content_type': 'json', 'data': 'Algorithmia'} + expected_output = {'metadata': + { + 'content_type': 'text' + }, + 'result': "all model files were successfully loaded" + } + input = [str(json.dumps(input))] + actual_output = self.execute_manifest_example(input, apply_successful_manifest_parsing, + loading_with_manifest, + manifest_path="tests/manifests/good_model_manifest" + ".json.freeze") + self.assertEqual(expected_output, actual_output) + def run_test(): - unittest.main() \ No newline at end of file + unittest.main()