1
- import base64
2
1
import inspect
3
2
import json
4
3
import os
5
4
import sys
6
- import traceback
7
- import six
5
+ import Algorithmia
6
+ from adk .io import create_exception , format_data , format_response
7
+ from adk .manifest .modeldata import ModelData
8
8
9
9
10
10
class ADK (object ):
11
- def __init__ (self , apply_func , load_func = None ):
11
+ def __init__ (self , apply_func , load_func = None , client = None ):
12
12
"""
13
13
Creates the adk object
14
14
:param apply_func: A required function that can have an arity of 1-2, depending on if loading occurs
15
- :param load_func: An optional supplier function used if load time events are required, has an arity of 0.
15
+ :param load_func: An optional supplier function used if load time events are required, if a model manifest is provided;
16
+ the function may have a single `manifest` parameter to interact with the model manifest, otherwise must have no parameters.
17
+ :param client: A Algorithmia Client instance that might be user defined,
18
+ and is used for interacting with a model manifest file; if defined.
16
19
"""
17
20
self .FIFO_PATH = "/tmp/algoout"
21
+
22
+ if client :
23
+ self .client = client
24
+ else :
25
+ self .client = Algorithmia .client ()
26
+
18
27
apply_args , _ , _ , _ , _ , _ , _ = inspect .getfullargspec (apply_func )
28
+ self .apply_arity = len (apply_args )
19
29
if load_func :
20
30
load_args , _ , _ , _ , _ , _ , _ = inspect .getfullargspec (load_func )
21
- if len (load_args ) > 0 :
22
- raise Exception ("load function must not have parameters" )
31
+ self .load_arity = len (load_args )
32
+ if self .load_arity != 1 :
33
+ raise Exception ("load function expects 1 parameter to be used to store algorithm state" )
23
34
self .load_func = load_func
24
35
else :
25
36
self .load_func = None
26
- if len (apply_args ) > 2 or len (apply_args ) == 0 :
27
- raise Exception ("apply function may have between 1 and 2 parameters, not {}" .format (len (apply_args )))
28
37
self .apply_func = apply_func
29
38
self .is_local = not os .path .exists (self .FIFO_PATH )
30
39
self .load_result = None
31
40
self .loading_exception = None
41
+ self .manifest_path = "model_manifest.json.freeze"
42
+ self .model_data = self .init_manifest (self .manifest_path )
43
+
44
+ def init_manifest (self , path ):
45
+ return ModelData (self .client , path )
32
46
33
47
def load (self ):
34
48
try :
49
+ if self .model_data .available ():
50
+ self .model_data .initialize ()
35
51
if self .load_func :
36
- self .load_result = self .load_func ()
52
+ self .load_result = self .load_func (self . model_data )
37
53
except Exception as e :
38
54
self .loading_exception = e
39
55
finally :
@@ -45,55 +61,16 @@ def load(self):
45
61
46
62
def apply (self , payload ):
47
63
try :
48
- if self .load_result :
64
+ if self .load_result and self . apply_arity == 2 :
49
65
apply_result = self .apply_func (payload , self .load_result )
50
66
else :
51
67
apply_result = self .apply_func (payload )
52
- response_obj = self . format_response (apply_result )
68
+ response_obj = format_response (apply_result )
53
69
return response_obj
54
70
except Exception as e :
55
- response_obj = self . create_exception (e )
71
+ response_obj = create_exception (e )
56
72
return response_obj
57
73
58
- def format_data (self , request ):
59
- if request ["content_type" ] in ["text" , "json" ]:
60
- data = request ["data" ]
61
- elif request ["content_type" ] == "binary" :
62
- data = self .wrap_binary_data (base64 .b64decode (request ["data" ]))
63
- else :
64
- raise Exception ("Invalid content_type: {}" .format (request ["content_type" ]))
65
- return data
66
-
67
- def is_binary (self , arg ):
68
- if six .PY3 :
69
- return isinstance (arg , base64 .bytes_types )
70
-
71
- return isinstance (arg , bytearray )
72
-
73
- def wrap_binary_data (self , data ):
74
- if six .PY3 :
75
- return bytes (data )
76
- else :
77
- return bytearray (data )
78
-
79
- def format_response (self , response ):
80
- if self .is_binary (response ):
81
- content_type = "binary"
82
- response = str (base64 .b64encode (response ), "utf-8" )
83
- elif isinstance (response , six .string_types ) or isinstance (response , six .text_type ):
84
- content_type = "text"
85
- else :
86
- content_type = "json"
87
- response_string = json .dumps (
88
- {
89
- "result" : response ,
90
- "metadata" : {
91
- "content_type" : content_type
92
- }
93
- }
94
- )
95
- return response_string
96
-
97
74
def write_to_pipe (self , payload , pprint = print ):
98
75
if self .is_local :
99
76
if isinstance (payload , dict ):
@@ -109,40 +86,24 @@ def write_to_pipe(self, payload, pprint=print):
109
86
if os .name == "nt" :
110
87
sys .stdin = payload
111
88
112
- def create_exception (self , exception , loading_exception = False ):
113
- if hasattr (exception , "error_type" ):
114
- error_type = exception .error_type
115
- elif loading_exception :
116
- error_type = "LoadingError"
117
- else :
118
- error_type = "AlgorithmError"
119
- response = json .dumps ({
120
- "error" : {
121
- "message" : str (exception ),
122
- "stacktrace" : traceback .format_exc (),
123
- "error_type" : error_type ,
124
- }
125
- })
126
- return response
127
-
128
89
def process_local (self , local_payload , pprint ):
129
90
result = self .apply (local_payload )
130
91
self .write_to_pipe (result , pprint = pprint )
131
92
132
93
def init (self , local_payload = None , pprint = print ):
133
- self .load ()
134
- if self .is_local and local_payload :
94
+ self .load ()
95
+ if self .is_local and local_payload :
96
+ if self .loading_exception :
97
+ load_error = create_exception (self .loading_exception , loading_exception = True )
98
+ self .write_to_pipe (load_error , pprint = pprint )
99
+ self .process_local (local_payload , pprint )
100
+ else :
101
+ for line in sys .stdin :
102
+ request = json .loads (line )
103
+ formatted_input = format_data (request )
135
104
if self .loading_exception :
136
- load_error = self . create_exception (self .loading_exception , loading_exception = True )
105
+ load_error = create_exception (self .loading_exception , loading_exception = True )
137
106
self .write_to_pipe (load_error , pprint = pprint )
138
- self .process_local (local_payload , pprint )
139
- else :
140
- for line in sys .stdin :
141
- request = json .loads (line )
142
- formatted_input = self .format_data (request )
143
- if self .loading_exception :
144
- load_error = self .create_exception (self .loading_exception , loading_exception = True )
145
- self .write_to_pipe (load_error , pprint = pprint )
146
- else :
147
- result = self .apply (formatted_input )
148
- self .write_to_pipe (result )
107
+ else :
108
+ result = self .apply (formatted_input )
109
+ self .write_to_pipe (result )
0 commit comments