Skip to content

Commit e277844

Browse files
authored
Merge pull request #7 from algorithmiaio/AML-6-model-manifest
AML-6 model manifest, tamper detection implementation
2 parents 64dbb70 + 0818abe commit e277844

File tree

16 files changed

+486
-169
lines changed

16 files changed

+486
-169
lines changed

README.md

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -93,22 +93,21 @@ from Algorithmia import ADK
9393
# API calls will begin at the apply() method, with the request body passed as 'input'
9494
# For more details, see algorithmia.com/developers/algorithm-development/languages
9595

96-
def apply(input, globals):
96+
def apply(input, modelData):
9797
# If your apply function uses state that's loaded into memory via load, you can pass that loaded state to your apply
9898
# function by defining an additional parameter (named "modelData" here) in your apply function.
99-
return "hello {} {}".format(str(input), str(globals['payload']))
99+
return "hello {} {}".format(str(input), str(modelData.user_data['payload']))
100100

101101

102-
def load():
102+
def load(modelData):
103103
# Here you can optionally define a function that will be called when the algorithm is loaded.
104104
# The return object from this function can be passed directly as input to your apply function.
105105
# A great example would be any model files that need to be available to this algorithm
106106
# during runtime.
107107

108108
# Anything returned here will be passed as the second argument to your apply function.
109-
globals = {}
110-
globals['payload'] = "Loading has been completed."
111-
return globals
109+
modelData.user_data['payload'] = "Loading has been completed."
110+
return modelData
112111

113112

114113
# This turns your library code into an algorithm that can run on the platform.
@@ -129,18 +128,19 @@ from PIL import Image
129128
import json
130129
from torchvision import models, transforms
131130

132-
def load_labels(label_path, client):
133-
local_path = client.file(label_path).getFile().name
134-
with open(local_path) as f:
131+
132+
client = Algorithmia.client()
133+
134+
def load_labels(label_path):
135+
with open(label_path) as f:
135136
labels = json.load(f)
136137
labels = [labels[str(k)][1] for k in range(len(labels))]
137138
return labels
138139

139140

140-
def load_model(model_paths, client):
141+
def load_model(model_path):
141142
model = models.squeezenet1_1()
142-
local_file = client.file(model_paths["filepath"]).getFile().name
143-
weights = torch.load(local_file)
143+
weights = torch.load(model_path)
144144
model.load_state_dict(weights)
145145
return model.float().eval()
146146

@@ -174,28 +174,26 @@ def infer_image(image_url, n, globals):
174174
return result
175175

176176

177-
def load(manifest):
177+
def load(modelData):
178178

179-
globals = {}
180-
client = Algorithmia.client()
181-
globals["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
182-
globals["model"] = load_model(manifest["squeezenet"], client)
183-
globals["labels"] = load_labels(manifest["label_file"], client)
184-
return globals
179+
modelData.user_data["SMID_ALGO"] = "algo://util/SmartImageDownloader/0.2.x"
180+
modelData.user_data["model"] = load_model(modelData.get_model("squeezenet"))
181+
modelData.user_data["labels"] = load_labels(modelData.get_model("labels"))
182+
return modelData
185183

186184

187-
def apply(input, globals):
185+
def apply(input, modelData):
188186
if isinstance(input, dict):
189187
if "n" in input:
190188
n = input["n"]
191189
else:
192190
n = 3
193191
if "data" in input:
194192
if isinstance(input["data"], str):
195-
output = infer_image(input["data"], n, globals)
193+
output = infer_image(input["data"], n, modelData.user_data)
196194
elif isinstance(input["data"], list):
197195
for row in input["data"]:
198-
row["predictions"] = infer_image(row["image_url"], n, globals)
196+
row["predictions"] = infer_image(row["image_url"], n, modelData.user_data)
199197
output = input["data"]
200198
else:
201199
raise Exception("\"data\" must be a image url or a list of image urls (with labels)")
@@ -206,7 +204,7 @@ def apply(input, globals):
206204
raise Exception("input must be a json object")
207205

208206

209-
algorithm = ADK(apply_func=apply, load_func=load)
207+
algorithm = ADK(apply_func=apply, load_func=load, client=client)
210208
algorithm.init({"data": "https://i.imgur.com/bXdORXl.jpeg"})
211209

212210
```

adk/ADK.py

Lines changed: 43 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,55 @@
1-
import base64
21
import inspect
32
import json
43
import os
54
import sys
6-
import traceback
7-
import six
5+
import Algorithmia
6+
from adk.io import create_exception, format_data, format_response
7+
from adk.manifest.modeldata import ModelData
88

99

1010
class ADK(object):
11-
def __init__(self, apply_func, load_func=None):
11+
def __init__(self, apply_func, load_func=None, client=None):
1212
"""
1313
Creates the adk object
1414
:param apply_func: A required function that can have an arity of 1-2, depending on if loading occurs
15-
:param load_func: An optional supplier function used if load time events are required, has an arity of 0.
15+
:param load_func: An optional function called once at load time. It must accept exactly one parameter,
16+
the model data object, which gives access to the model manifest and holds the state later passed to apply.
17+
:param client: An Algorithmia client instance, optionally supplied by the user, that is used for
18+
interacting with the model manifest file. If omitted, a default client is created.
1619
"""
1720
self.FIFO_PATH = "/tmp/algoout"
21+
22+
if client:
23+
self.client = client
24+
else:
25+
self.client = Algorithmia.client()
26+
1827
apply_args, _, _, _, _, _, _ = inspect.getfullargspec(apply_func)
28+
self.apply_arity = len(apply_args)
1929
if load_func:
2030
load_args, _, _, _, _, _, _ = inspect.getfullargspec(load_func)
21-
if len(load_args) > 0:
22-
raise Exception("load function must not have parameters")
31+
self.load_arity = len(load_args)
32+
if self.load_arity != 1:
33+
raise Exception("load function expects 1 parameter to be used to store algorithm state")
2334
self.load_func = load_func
2435
else:
2536
self.load_func = None
26-
if len(apply_args) > 2 or len(apply_args) == 0:
27-
raise Exception("apply function may have between 1 and 2 parameters, not {}".format(len(apply_args)))
2837
self.apply_func = apply_func
2938
self.is_local = not os.path.exists(self.FIFO_PATH)
3039
self.load_result = None
3140
self.loading_exception = None
41+
self.manifest_path = "model_manifest.json.freeze"
42+
self.model_data = self.init_manifest(self.manifest_path)
43+
44+
def init_manifest(self, path):
45+
return ModelData(self.client, path)
3246

3347
def load(self):
3448
try:
49+
if self.model_data.available():
50+
self.model_data.initialize()
3551
if self.load_func:
36-
self.load_result = self.load_func()
52+
self.load_result = self.load_func(self.model_data)
3753
except Exception as e:
3854
self.loading_exception = e
3955
finally:
@@ -45,55 +61,16 @@ def load(self):
4561

4662
def apply(self, payload):
4763
try:
48-
if self.load_result:
64+
if self.load_result and self.apply_arity == 2:
4965
apply_result = self.apply_func(payload, self.load_result)
5066
else:
5167
apply_result = self.apply_func(payload)
52-
response_obj = self.format_response(apply_result)
68+
response_obj = format_response(apply_result)
5369
return response_obj
5470
except Exception as e:
55-
response_obj = self.create_exception(e)
71+
response_obj = create_exception(e)
5672
return response_obj
5773

58-
def format_data(self, request):
59-
if request["content_type"] in ["text", "json"]:
60-
data = request["data"]
61-
elif request["content_type"] == "binary":
62-
data = self.wrap_binary_data(base64.b64decode(request["data"]))
63-
else:
64-
raise Exception("Invalid content_type: {}".format(request["content_type"]))
65-
return data
66-
67-
def is_binary(self, arg):
68-
if six.PY3:
69-
return isinstance(arg, base64.bytes_types)
70-
71-
return isinstance(arg, bytearray)
72-
73-
def wrap_binary_data(self, data):
74-
if six.PY3:
75-
return bytes(data)
76-
else:
77-
return bytearray(data)
78-
79-
def format_response(self, response):
80-
if self.is_binary(response):
81-
content_type = "binary"
82-
response = str(base64.b64encode(response), "utf-8")
83-
elif isinstance(response, six.string_types) or isinstance(response, six.text_type):
84-
content_type = "text"
85-
else:
86-
content_type = "json"
87-
response_string = json.dumps(
88-
{
89-
"result": response,
90-
"metadata": {
91-
"content_type": content_type
92-
}
93-
}
94-
)
95-
return response_string
96-
9774
def write_to_pipe(self, payload, pprint=print):
9875
if self.is_local:
9976
if isinstance(payload, dict):
@@ -109,40 +86,24 @@ def write_to_pipe(self, payload, pprint=print):
10986
if os.name == "nt":
11087
sys.stdin = payload
11188

112-
def create_exception(self, exception, loading_exception=False):
113-
if hasattr(exception, "error_type"):
114-
error_type = exception.error_type
115-
elif loading_exception:
116-
error_type = "LoadingError"
117-
else:
118-
error_type = "AlgorithmError"
119-
response = json.dumps({
120-
"error": {
121-
"message": str(exception),
122-
"stacktrace": traceback.format_exc(),
123-
"error_type": error_type,
124-
}
125-
})
126-
return response
127-
12889
def process_local(self, local_payload, pprint):
12990
result = self.apply(local_payload)
13091
self.write_to_pipe(result, pprint=pprint)
13192

13293
def init(self, local_payload=None, pprint=print):
133-
self.load()
134-
if self.is_local and local_payload:
94+
self.load()
95+
if self.is_local and local_payload:
96+
if self.loading_exception:
97+
load_error = create_exception(self.loading_exception, loading_exception=True)
98+
self.write_to_pipe(load_error, pprint=pprint)
99+
self.process_local(local_payload, pprint)
100+
else:
101+
for line in sys.stdin:
102+
request = json.loads(line)
103+
formatted_input = format_data(request)
135104
if self.loading_exception:
136-
load_error = self.create_exception(self.loading_exception, loading_exception=True)
105+
load_error = create_exception(self.loading_exception, loading_exception=True)
137106
self.write_to_pipe(load_error, pprint=pprint)
138-
self.process_local(local_payload, pprint)
139-
else:
140-
for line in sys.stdin:
141-
request = json.loads(line)
142-
formatted_input = self.format_data(request)
143-
if self.loading_exception:
144-
load_error = self.create_exception(self.loading_exception, loading_exception=True)
145-
self.write_to_pipe(load_error, pprint=pprint)
146-
else:
147-
result = self.apply(formatted_input)
148-
self.write_to_pipe(result)
107+
else:
108+
result = self.apply(formatted_input)
109+
self.write_to_pipe(result)

adk/io.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import traceback
2+
import six
3+
import base64
4+
import json
5+
6+
7+
def format_data(request):
    """Extract the payload from an incoming request dictionary.

    :param request: A mapping with a "content_type" key and a "data" key.
    :return: The "data" value unchanged for "text" and "json" requests; for
        "binary" requests, the base64-decoded data wrapped by wrap_binary_data.
    :raises Exception: If the content type is not "text", "json" or "binary".
    """
    if request["content_type"] in ["text", "json"]:
        # Text and JSON payloads are handed through untouched.
        return request["data"]
    if request["content_type"] == "binary":
        # Binary payloads arrive base64-encoded.
        decoded = base64.b64decode(request["data"])
        return wrap_binary_data(decoded)
    raise Exception("Invalid content_type: {}".format(request["content_type"]))
15+
16+
17+
def is_binary(arg):
    """Report whether ``arg`` should be treated as binary data.

    :param arg: Any value.
    :return: On Python 3, True for ``bytes`` and ``bytearray`` instances;
        on Python 2, True only for ``bytearray`` instances.
    """
    if six.PY3:
        # Spelled out explicitly rather than via base64.bytes_types, which is
        # an undocumented internal of the base64 module.
        return isinstance(arg, (bytes, bytearray))

    return isinstance(arg, bytearray)
22+
23+
24+
def wrap_binary_data(data):
    """Convert ``data`` to the binary container type for the running Python.

    :param data: The raw binary data to wrap.
    :return: ``bytes(data)`` on Python 3, ``bytearray(data)`` otherwise.
    """
    container = bytes if six.PY3 else bytearray
    return container(data)
29+
30+
31+
def format_response(response):
    """Serialize an algorithm result into the JSON response string.

    :param response: The value returned by the apply function.
    :return: A JSON string of the form
        ``{"result": ..., "metadata": {"content_type": ...}}`` where the
        content type is "binary", "text" or "json".
    """
    if is_binary(response):
        content_type = "binary"
        # Binary results are base64-encoded so they can travel inside JSON.
        # .decode() works on both Python 2 and 3, whereas the two-argument
        # str(..., "utf-8") form raises TypeError on Python 2.
        response = base64.b64encode(response).decode("utf-8")
    elif isinstance(response, (six.string_types, six.text_type)):
        content_type = "text"
    else:
        content_type = "json"
    response_string = json.dumps(
        {
            "result": response,
            "metadata": {
                "content_type": content_type
            }
        }
    )
    return response_string
48+
49+
50+
def create_exception(exception, loading_exception=False):
    """Serialize an exception into the JSON error response string.

    :param exception: The exception instance to report.
    :param loading_exception: True when the exception was raised at load time;
        selects the "LoadingError" error type when the exception does not
        carry its own ``error_type`` attribute.
    :return: A JSON string of the form
        ``{"error": {"message": ..., "stacktrace": ..., "error_type": ...}}``.
    """
    if hasattr(exception, "error_type"):
        error_type = exception.error_type
    elif loading_exception:
        error_type = "LoadingError"
    else:
        error_type = "AlgorithmError"

    # traceback.format_exc() only reports the exception currently being
    # handled. A stored exception serialized after its except block has
    # finished would otherwise yield "NoneType: None", so prefer the traceback
    # attached to the exception object itself.
    tb = getattr(exception, "__traceback__", None)
    if tb is not None:
        stacktrace = "".join(
            traceback.format_exception(type(exception), exception, tb)
        )
    else:
        # No attached traceback (never raised, or Python 2): fall back to
        # whatever exception is being handled right now.
        stacktrace = traceback.format_exc()

    response = json.dumps({
        "error": {
            "message": str(exception),
            "stacktrace": stacktrace,
            "error_type": error_type,
        }
    })
    return response

adk/manifest/__init__.py

Whitespace-only changes.

adk/manifest/classes.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
class FileData(object):
    """Pair a file path with its MD5 checksum."""

    def __init__(self, md5_checksum, file_path):
        """Store the checksum and path exactly as given.

        :param md5_checksum: The MD5 checksum associated with the file.
        :param file_path: The path of the file.
        """
        self.md5_checksum, self.file_path = md5_checksum, file_path
6+

0 commit comments

Comments
 (0)