Skip to content

Commit 81f0154

Browse files
committed
Add pre and post invocation hooks
1 parent 6c8fb5c commit 81f0154

File tree

7 files changed

+429
-151
lines changed

7 files changed

+429
-151
lines changed

.eslintrc.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
"@typescript-eslint/restrict-template-expressions": "off",
4444
"@typescript-eslint/unbound-method": "off",
4545
"no-empty": "off",
46+
"prefer-const": ["error", { "destructuring": "all" }],
4647
"prefer-rest-params": "off",
4748
"prefer-spread": "off"
4849
},

src/Disposable.ts

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// Copyright (c) .NET Foundation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
/**
5+
* Based off of VS Code
6+
* https://github.com/microsoft/vscode/blob/a64e8e5673a44e5b9c2d493666bde684bd5a135c/src/vs/workbench/api/common/extHostTypes.ts#L32
7+
*/
8+
export class Disposable {
9+
static from(...inDisposables: { dispose(): any }[]): Disposable {
10+
let disposables: ReadonlyArray<{ dispose(): any }> | undefined = inDisposables;
11+
return new Disposable(function () {
12+
if (disposables) {
13+
for (const disposable of disposables) {
14+
if (disposable && typeof disposable.dispose === 'function') {
15+
disposable.dispose();
16+
}
17+
}
18+
disposables = undefined;
19+
}
20+
});
21+
}
22+
23+
#callOnDispose?: () => any;
24+
25+
constructor(callOnDispose: () => any) {
26+
this.#callOnDispose = callOnDispose;
27+
}
28+
29+
dispose(): any {
30+
if (this.#callOnDispose instanceof Function) {
31+
this.#callOnDispose();
32+
this.#callOnDispose = undefined;
33+
}
34+
}
35+
}

src/WorkerChannel.ts

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,23 @@
11
// Copyright (c) .NET Foundation. All rights reserved.
22
// Licensed under the MIT License.
33

4-
import { Context } from '@azure/functions';
4+
import { HookCallback } from '@azure/functions-worker';
55
import { AzureFunctionsRpcMessages as rpc } from '../azure-functions-language-worker-protobuf/src/rpc';
6+
import { Disposable } from './Disposable';
67
import { IFunctionLoader } from './FunctionLoader';
78
import { IEventStream } from './GrpcClient';
8-
9-
type InvocationRequestBefore = (context: Context, userFn: Function) => Function;
10-
type InvocationRequestAfter = (context: Context) => void;
9+
import Module = require('module');
1110

1211
export class WorkerChannel {
1312
public eventStream: IEventStream;
1413
public functionLoader: IFunctionLoader;
15-
private _invocationRequestBefore: InvocationRequestBefore[];
16-
private _invocationRequestAfter: InvocationRequestAfter[];
14+
private _preInvocationHooks: HookCallback[] = [];
15+
private _postInvocationHooks: HookCallback[] = [];
1716

1817
constructor(eventStream: IEventStream, functionLoader: IFunctionLoader) {
1918
this.eventStream = eventStream;
2019
this.functionLoader = functionLoader;
21-
this._invocationRequestBefore = [];
22-
this._invocationRequestAfter = [];
20+
this.initWorkerModule(this);
2321
}
2422

2523
/**
@@ -33,32 +31,49 @@ export class WorkerChannel {
3331
});
3432
}
3533

36-
/**
37-
* Register a patching function to be run before User Function is executed.
38-
* Hook should return a patched version of User Function.
39-
*/
40-
public registerBeforeInvocationRequest(beforeCb: InvocationRequestBefore): void {
41-
this._invocationRequestBefore.push(beforeCb);
34+
public registerHook(hookName: string, callback: HookCallback): Disposable {
35+
const hooks = this.getHooks(hookName);
36+
hooks.push(callback);
37+
return new Disposable(() => {
38+
const index = hooks.indexOf(callback);
39+
if (index > -1) {
40+
hooks.splice(index, 1);
41+
}
42+
});
4243
}
4344

44-
/**
45-
* Register a function to be run after User Function resolves.
46-
*/
47-
public registerAfterInvocationRequest(afterCb: InvocationRequestAfter): void {
48-
this._invocationRequestAfter.push(afterCb);
45+
public async executeHooks(hookName: string, context: {}): Promise<void> {
46+
const callbacks = this.getHooks(hookName);
47+
for (const callback of callbacks) {
48+
await callback(context);
49+
}
4950
}
5051

51-
public runInvocationRequestBefore(context: Context, userFunction: Function): Function {
52-
let wrappedFunction = userFunction;
53-
for (const before of this._invocationRequestBefore) {
54-
wrappedFunction = before(context, wrappedFunction);
52+
private getHooks(hookName: string): HookCallback[] {
53+
switch (hookName) {
54+
case 'preInvocation':
55+
return this._preInvocationHooks;
56+
case 'postInvocation':
57+
return this._postInvocationHooks;
58+
default:
59+
throw new RangeError(`Unrecognized hook "${hookName}"`);
5560
}
56-
return wrappedFunction;
5761
}
5862

59-
public runInvocationRequestAfter(context: Context) {
60-
for (const after of this._invocationRequestAfter) {
61-
after(context);
62-
}
63+
private initWorkerModule(channel: WorkerChannel) {
64+
const workerApi = {
65+
registerHook: (hookName: string, callback: HookCallback) => channel.registerHook(hookName, callback),
66+
Disposable,
67+
};
68+
69+
Module.prototype.require = new Proxy(Module.prototype.require, {
70+
apply(target, thisArg, argArray) {
71+
if (argArray[0] === '@azure/functions-worker') {
72+
return workerApi;
73+
} else {
74+
return Reflect.apply(target, thisArg, argArray);
75+
}
76+
},
77+
});
6378
}
6479
}

src/eventHandlers/invocationRequest.ts

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) .NET Foundation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
import { PostInvocationContext, PreInvocationContext } from '@azure/functions-worker';
45
import { format } from 'util';
56
import { AzureFunctionsRpcMessages as rpc } from '../../azure-functions-language-worker-protobuf/src/rpc';
67
import { CreateContextAndInputs } from '../Context';
@@ -66,7 +67,7 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
6667
isDone = true;
6768
}
6869

69-
const { context, inputs, doneEmitter } = CreateContextAndInputs(info, msg, userLog);
70+
let { context, inputs, doneEmitter } = CreateContextAndInputs(info, msg, userLog);
7071
try {
7172
const legacyDoneTask = new Promise((resolve, reject) => {
7273
doneEmitter.on('done', (err?: unknown, result?: any) => {
@@ -79,8 +80,12 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
7980
});
8081
});
8182

82-
let userFunction = channel.functionLoader.getFunc(nonNullProp(msg, 'functionId'));
83-
userFunction = channel.runInvocationRequestBefore(context, userFunction);
83+
const userFunction = channel.functionLoader.getFunc(nonNullProp(msg, 'functionId'));
84+
const preInvocContext: PreInvocationContext = { invocationContext: context, inputs };
85+
86+
await channel.executeHooks('preInvocation', preInvocContext);
87+
inputs = preInvocContext.inputs;
88+
8489
let rawResult = userFunction(context, ...inputs);
8590
resultIsPromise = rawResult && typeof rawResult.then === 'function';
8691
let resultTask: Promise<any>;
@@ -94,7 +99,18 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
9499
resultTask = legacyDoneTask;
95100
}
96101

97-
const result = await resultTask;
102+
const postInvocContext: PostInvocationContext = Object.assign(preInvocContext, { result: null, error: null });
103+
try {
104+
postInvocContext.result = await resultTask;
105+
} catch (err) {
106+
postInvocContext.error = err;
107+
}
108+
await channel.executeHooks('postInvocation', postInvocContext);
109+
110+
if (isError(postInvocContext.error)) {
111+
throw postInvocContext.error;
112+
}
113+
const result = postInvocContext.result;
98114

99115
// Allow HTTP response from context.res if HTTP response is not defined from the context.bindings object
100116
if (info.httpOutputName && context.res && context.bindings[info.httpOutputName] === undefined) {
@@ -163,6 +179,4 @@ export async function invocationRequest(channel: WorkerChannel, requestId: strin
163179
requestId: requestId,
164180
invocationResponse: response,
165181
});
166-
167-
channel.runInvocationRequestAfter(context);
168182
}

0 commit comments

Comments
 (0)