
AML-6 model manifest, tamper detection implementation #7

Merged · 10 commits · Oct 15, 2021
README.md (44 changes: 21 additions & 23 deletions)
@@ -93,22 +93,21 @@ from Algorithmia import ADK
 # API calls will begin at the apply() method, with the request body passed as 'input'
 # For more details, see algorithmia.com/developers/algorithm-development/languages

-def apply(input, globals):
+def apply(input, modelData):
     # If your apply function uses state that's loaded into memory via load, you can pass that loaded state to your apply
     # function by defining an additional "globals" parameter in your apply function.
-    return "hello {} {}".format(str(input), str(globals['payload']))
+    return "hello {} {}".format(str(input), str(modelData.user_data['payload']))


-def load():
+def load(modelData):
     # Here you can optionally define a function that will be called when the algorithm is loaded.
     # The return object from this function can be passed directly as input to your apply function.
     # A great example would be any model files that need to be available to this algorithm
     # during runtime.

     # Any variables returned here, will be passed as the secondary argument to your 'algorithm' function
-    globals = {}
-    globals['payload'] = "Loading has been completed."
-    return globals
+    modelData.user_data['payload'] = "Loading has been completed."
+    return modelData


 # This turns your library code into an algorithm that can run on the platform.
@@ -129,18 +128,19 @@ from PIL import Image
 import json
 from torchvision import models, transforms

-def load_labels(label_path, client):
-    local_path = client.file(label_path).getFile().name
-    with open(local_path) as f:
+
+client = Algorithmia.client()
+
+def load_labels(label_path):
+    with open(label_path) as f:
         labels = json.load(f)
     labels = [labels[str(k)][1] for k in range(len(labels))]
     return labels


-def load_model(model_paths, client):
+def load_model(model_path):
     model = models.squeezenet1_1()
-    local_file = client.file(model_paths["filepath"]).getFile().name
-    weights = torch.load(local_file)
+    weights = torch.load(model_path)
     model.load_state_dict(weights)
     return model.float().eval()

@@ -174,28 +174,26 @@ def infer_image(image_url, n, globals):
     return result


-def load(manifest):
+def load(modelData):

-    globals = {}
-    client = Algorithmia.client()
-    globals["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
-    globals["model"] = load_model(manifest["squeezenet"], client)
-    globals["labels"] = load_labels(manifest["label_file"], client)
-    return globals
+    modelData.user_data["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
+    modelData.user_data["model"] = load_model(modelData.get_model("squeezenet"))
+    modelData.user_data["labels"] = load_labels(modelData.get_model("labels"))
+    return modelData


-def apply(input, globals):
+def apply(input, modelData):
     if isinstance(input, dict):
         if "n" in input:
             n = input["n"]
         else:
             n = 3
         if "data" in input:
             if isinstance(input["data"], str):
-                output = infer_image(input["data"], n, globals)
+                output = infer_image(input["data"], n, modelData.user_data)
             elif isinstance(input["data"], list):
                 for row in input["data"]:
-                    row["predictions"] = infer_image(row["image_url"], n, globals)
+                    row["predictions"] = infer_image(row["image_url"], n, modelData.user_data)
                 output = input["data"]
             else:
                 raise Exception("\"data\" must be a image url or a list of image urls (with labels)")
@@ -206,7 +204,7 @@ def apply(input, globals):
         raise Exception("input must be a json object")


-algorithm = ADK(apply_func=apply, load_func=load)
+algorithm = ADK(apply_func=apply, load_func=load, client=client)
 algorithm.init({"data": "https://i.imgur.com/bXdORXl.jpeg"})

```
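The example above relies on a model manifest that maps names such as "squeezenet" and "labels" to model files that `modelData.get_model(...)` makes available at load time. The manifest schema itself is defined by the ADK's manifest module and is not shown in this diff; the sketch below is only a hypothetical illustration of the kind of entries such a file might contain, with all field names being assumptions rather than the confirmed format:

```python
import json

# Hypothetical manifest entries: names the algorithm looks up via
# modelData.get_model(...), plus a checksum that enables tamper detection.
# Field names are illustrative; the real schema lives in adk/manifest.
manifest = {
    "required_files": [
        {
            "name": "squeezenet",
            "source_uri": "data://path/to/squeezenet1_1.pth",
            "md5_checksum": "0123456789abcdef0123456789abcdef",  # placeholder md5 of the file
        },
        {
            "name": "labels",
            "source_uri": "data://path/to/imagenet_class_index.json",
            "md5_checksum": "0123456789abcdef0123456789abcdef",  # placeholder md5 of the file
        },
    ]
}

# ADK.py below looks for "model_manifest.json.freeze"; how the frozen copy is
# produced from a file like this is outside the scope of this diff.
with open("model_manifest.json", "w") as f:
    json.dump(manifest, f, indent=2)
```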
adk/ADK.py (125 changes: 43 additions & 82 deletions)
@@ -1,39 +1,55 @@
-import base64
 import inspect
 import json
 import os
 import sys
-import traceback
-import six
+import Algorithmia
+from adk.io import create_exception, format_data, format_response
+from adk.manifest.modeldata import ModelData


 class ADK(object):
-    def __init__(self, apply_func, load_func=None):
+    def __init__(self, apply_func, load_func=None, client=None):
         """
         Creates the adk object
         :param apply_func: A required function that can have an arity of 1-2, depending on if loading occurs
-        :param load_func: An optional supplier function used if load time events are required, has an arity of 0.
+        :param load_func: An optional supplier function, used if load-time events are required. If a model manifest is provided,
+        the function may take a single `manifest` parameter to interact with the model manifest; otherwise it must take no parameters.
+        :param client: An Algorithmia Client instance, which may be user defined, used for interacting with the model
+        manifest file if one is defined.
         """
         self.FIFO_PATH = "/tmp/algoout"

+        if client:
+            self.client = client
+        else:
+            self.client = Algorithmia.client()

         apply_args, _, _, _, _, _, _ = inspect.getfullargspec(apply_func)
+        self.apply_arity = len(apply_args)
         if load_func:
             load_args, _, _, _, _, _, _ = inspect.getfullargspec(load_func)
-            if len(load_args) > 0:
-                raise Exception("load function must not have parameters")
+            self.load_arity = len(load_args)
+            if self.load_arity != 1:
+                raise Exception("load function expects 1 parameter to be used to store algorithm state")
             self.load_func = load_func
         else:
             self.load_func = None
         if len(apply_args) > 2 or len(apply_args) == 0:
             raise Exception("apply function may have between 1 and 2 parameters, not {}".format(len(apply_args)))
         self.apply_func = apply_func
         self.is_local = not os.path.exists(self.FIFO_PATH)
         self.load_result = None
         self.loading_exception = None
+        self.manifest_path = "model_manifest.json.freeze"
+        self.model_data = self.init_manifest(self.manifest_path)

+    def init_manifest(self, path):
+        return ModelData(self.client, path)

     def load(self):
         try:
+            if self.model_data.available():
+                self.model_data.initialize()
             if self.load_func:
-                self.load_result = self.load_func()
+                self.load_result = self.load_func(self.model_data)
         except Exception as e:
             self.loading_exception = e
         finally:
@@ -45,55 +61,16 @@ def load(self):

     def apply(self, payload):
         try:
-            if self.load_result:
+            if self.load_result and self.apply_arity == 2:
                 apply_result = self.apply_func(payload, self.load_result)
             else:
                 apply_result = self.apply_func(payload)
-            response_obj = self.format_response(apply_result)
+            response_obj = format_response(apply_result)
             return response_obj
         except Exception as e:
-            response_obj = self.create_exception(e)
+            response_obj = create_exception(e)
             return response_obj

-    def format_data(self, request):
-        if request["content_type"] in ["text", "json"]:
-            data = request["data"]
-        elif request["content_type"] == "binary":
-            data = self.wrap_binary_data(base64.b64decode(request["data"]))
-        else:
-            raise Exception("Invalid content_type: {}".format(request["content_type"]))
-        return data
-
-    def is_binary(self, arg):
-        if six.PY3:
-            return isinstance(arg, base64.bytes_types)
-
-        return isinstance(arg, bytearray)
-
-    def wrap_binary_data(self, data):
-        if six.PY3:
-            return bytes(data)
-        else:
-            return bytearray(data)
-
-    def format_response(self, response):
-        if self.is_binary(response):
-            content_type = "binary"
-            response = str(base64.b64encode(response), "utf-8")
-        elif isinstance(response, six.string_types) or isinstance(response, six.text_type):
-            content_type = "text"
-        else:
-            content_type = "json"
-        response_string = json.dumps(
-            {
-                "result": response,
-                "metadata": {
-                    "content_type": content_type
-                }
-            }
-        )
-        return response_string
-
     def write_to_pipe(self, payload, pprint=print):
         if self.is_local:
             if isinstance(payload, dict):
@@ -109,40 +86,24 @@ def write_to_pipe(self, payload, pprint=print):
             if os.name == "nt":
                 sys.stdin = payload

-    def create_exception(self, exception, loading_exception=False):
-        if hasattr(exception, "error_type"):
-            error_type = exception.error_type
-        elif loading_exception:
-            error_type = "LoadingError"
-        else:
-            error_type = "AlgorithmError"
-        response = json.dumps({
-            "error": {
-                "message": str(exception),
-                "stacktrace": traceback.format_exc(),
-                "error_type": error_type,
-            }
-        })
-        return response
-
     def process_local(self, local_payload, pprint):
         result = self.apply(local_payload)
         self.write_to_pipe(result, pprint=pprint)

     def init(self, local_payload=None, pprint=print):
         self.load()
         if self.is_local and local_payload:
             if self.loading_exception:
-                load_error = self.create_exception(self.loading_exception, loading_exception=True)
+                load_error = create_exception(self.loading_exception, loading_exception=True)
                 self.write_to_pipe(load_error, pprint=pprint)
             self.process_local(local_payload, pprint)
         else:
             for line in sys.stdin:
                 request = json.loads(line)
-                formatted_input = self.format_data(request)
+                formatted_input = format_data(request)
                 if self.loading_exception:
-                    load_error = self.create_exception(self.loading_exception, loading_exception=True)
+                    load_error = create_exception(self.loading_exception, loading_exception=True)
                     self.write_to_pipe(load_error, pprint=pprint)
                 else:
                     result = self.apply(formatted_input)
                     self.write_to_pipe(result)
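To make the new constructor contract concrete, here is a minimal local sketch (not taken from the PR) of the signatures it now enforces: `apply` may take one or two parameters, while a `load` function must take exactly one parameter, which receives the ModelData object.

```python
from Algorithmia import ADK

def load(modelData):
    # Exactly one parameter, or __init__ raises
    # "load function expects 1 parameter to be used to store algorithm state".
    modelData.user_data["greeting"] = "hello"
    return modelData

def apply(input, modelData):
    # Two parameters: the request payload plus the state returned by load.
    # A one-parameter apply(input) is also accepted; it simply skips load state.
    return "{} {}".format(modelData.user_data["greeting"], input)

# client= is optional; when omitted, the ADK builds its own Algorithmia client.
algorithm = ADK(apply_func=apply, load_func=load)
algorithm.init("world")  # local run: prints the JSON-wrapped "hello world" response
```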
adk/io.py (64 changes: 64 additions & 0 deletions)
@@ -0,0 +1,64 @@
+import traceback
+import six
+import base64
+import json
+
+
+def format_data(request):
+    if request["content_type"] in ["text", "json"]:
+        data = request["data"]
+    elif request["content_type"] == "binary":
+        data = wrap_binary_data(base64.b64decode(request["data"]))
+    else:
+        raise Exception("Invalid content_type: {}".format(request["content_type"]))
+    return data
+
+
+def is_binary(arg):
+    if six.PY3:
+        return isinstance(arg, base64.bytes_types)
+
+    return isinstance(arg, bytearray)
+
+
+def wrap_binary_data(data):
+    if six.PY3:
+        return bytes(data)
+    else:
+        return bytearray(data)
+
+
+def format_response(response):
+    if is_binary(response):
+        content_type = "binary"
+        response = str(base64.b64encode(response), "utf-8")
+    elif isinstance(response, six.string_types) or isinstance(response, six.text_type):
+        content_type = "text"
+    else:
+        content_type = "json"
+    response_string = json.dumps(
+        {
+            "result": response,
+            "metadata": {
+                "content_type": content_type
+            }
+        }
+    )
+    return response_string
+
+
+def create_exception(exception, loading_exception=False):
+    if hasattr(exception, "error_type"):
+        error_type = exception.error_type
+    elif loading_exception:
+        error_type = "LoadingError"
+    else:
+        error_type = "AlgorithmError"
+    response = json.dumps({
+        "error": {
+            "message": str(exception),
+            "stacktrace": traceback.format_exc(),
+            "error_type": error_type,
+        }
+    })
+    return response
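Since these helpers are now free functions, they are easy to exercise in isolation. Here is a quick sanity check (not part of the PR, written against the functions shown above) of the request and response envelopes they expect and produce:

```python
import base64
import json

from adk.io import format_data, format_response

# Requests arrive as JSON envelopes with "content_type" and "data" fields.
assert format_data({"content_type": "text", "data": "hello"}) == "hello"

# Binary payloads are base64-encoded on the wire and decoded to bytes.
encoded = base64.b64encode(b"\x00\x01").decode("utf-8")
assert format_data({"content_type": "binary", "data": encoded}) == b"\x00\x01"

# Responses are wrapped with a content_type inferred from the Python type.
print(format_response("hello"))
# -> {"result": "hello", "metadata": {"content_type": "text"}}
print(json.loads(format_response({"score": 0.9}))["metadata"]["content_type"])
# -> json
```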
adk/manifest/__init__.py (new empty file)
adk/manifest/classes.py (6 changes: 6 additions & 0 deletions)
@@ -0,0 +1,6 @@
+
+class FileData(object):
+    def __init__(self, md5_checksum, file_path):
+        self.md5_checksum = md5_checksum
+        self.file_path = file_path
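The `FileData` pairing of a recorded checksum with a local path is what makes the tamper detection in this PR's title possible. The verification itself presumably lives in `adk/manifest/modeldata.py`, which is not shown in this diff; the helper below is only a hypothetical sketch of that kind of check:

```python
import hashlib

from adk.manifest.classes import FileData

def file_is_untampered(file_data):
    # Hypothetical helper (not from the PR): recompute the md5 of the file on
    # disk and compare it with the checksum recorded in the manifest.
    digest = hashlib.md5()
    with open(file_data.file_path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            digest.update(chunk)
    return digest.hexdigest() == file_data.md5_checksum

# Example: verify a downloaded model file against its manifest entry
# (placeholder checksum and path shown here).
# file_is_untampered(FileData("0123456789abcdef0123456789abcdef", "/tmp/squeezenet1_1.pth"))
```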
