Skip to content

Commit fa5419d

Browse files
authored
refactor(js/plugins/compat-oai): Refactor existing compat plugin to match proposal. (#3086)
1 parent 31b18b0 commit fa5419d

File tree

21 files changed

+1880
-1682
lines changed

21 files changed

+1880
-1682
lines changed

js/plugins/compat-oai/jest.config.js

Lines changed: 0 additions & 22 deletions
This file was deleted.

js/plugins/compat-oai/jest.config.ts

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
/**
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
/**
18+
* For a detailed explanation regarding each configuration property, visit:
19+
* https://jestjs.io/docs/configuration
20+
*/
21+
22+
import type { Config } from 'jest';
23+
24+
const config: Config = {
25+
// Automatically clear mock calls, instances, contexts and results before every test
26+
clearMocks: true,
27+
28+
// A preset that is used as a base for Jest's configuration
29+
preset: 'ts-jest',
30+
31+
// The glob patterns Jest uses to detect test files
32+
testMatch: ['**/tests/**/*_test.ts'],
33+
34+
// An array of regexp pattern strings that are matched against all test paths, matched tests are skipped
35+
testPathIgnorePatterns: ['/node_modules/'],
36+
37+
// A map from regular expressions to paths to transformers
38+
transform: {
39+
'^.+\\.ts$': 'ts-jest',
40+
},
41+
42+
// An array of regexp pattern strings that are matched against all source file paths, matched files will skip transformation
43+
transformIgnorePatterns: ['/node_modules/'],
44+
45+
moduleNameMapper: {
46+
'^(\\.{1,2}/.*)\\.js$': '$1',
47+
},
48+
};
49+
50+
export default config;

js/plugins/compat-oai/package.json

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,19 @@
4242
"import": "./lib/index.mjs",
4343
"types": "./lib/index.d.ts",
4444
"default": "./lib/index.js"
45+
},
46+
"./openai": {
47+
"require": "./lib/openai/index.js",
48+
"import": "./lib/openai/index.mjs",
49+
"types": "./lib/openai/index.d.ts",
50+
"default": "./lib/openai/index.js"
51+
}
52+
},
53+
"typesVersions": {
54+
"*": {
55+
"openai": [
56+
"lib/openai"
57+
]
4558
}
4659
},
4760
"files": [

js/plugins/compat-oai/src/audio.ts

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
/**
2+
* Copyright 2024 The Fire Company
3+
* Copyright 2024 Google LLC
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
import type {
18+
GenerateRequest,
19+
GenerateResponseData,
20+
Genkit,
21+
ModelReference,
22+
} from 'genkit';
23+
import { Message, z } from 'genkit';
24+
import type { ModelAction } from 'genkit/model';
25+
import type OpenAI from 'openai';
26+
import type {
27+
SpeechCreateParams,
28+
Transcription,
29+
TranscriptionCreateParams,
30+
} from 'openai/resources/audio/index.mjs';
31+
32+
/**
33+
* Supported media formats for Audio generation
34+
*/
35+
export const RESPONSE_FORMAT_MEDIA_TYPES = {
36+
mp3: 'audio/mpeg',
37+
opus: 'audio/opus',
38+
aac: 'audio/aac',
39+
flac: 'audio/flac',
40+
wav: 'audio/wav',
41+
pcm: 'audio/L16',
42+
};
43+
44+
function toTTSRequest(
45+
modelName: string,
46+
request: GenerateRequest
47+
): SpeechCreateParams {
48+
const {
49+
voice,
50+
version: modelVersion,
51+
temperature,
52+
maxOutputTokens,
53+
stopSequences,
54+
topK,
55+
topP,
56+
...restOfConfig
57+
} = request.config ?? {};
58+
59+
const options: SpeechCreateParams = {
60+
model: modelVersion ?? modelName,
61+
input: new Message(request.messages[0]).text,
62+
voice: voice ?? 'alloy',
63+
...restOfConfig, // passthorugh rest of the config
64+
};
65+
for (const k in options) {
66+
if (options[k] === undefined) {
67+
delete options[k];
68+
}
69+
}
70+
return options;
71+
}
72+
73+
function toGenerateResponse(
74+
result: Buffer,
75+
responseFormat: 'mp3' | 'opus' | 'aac' | 'flac' | 'wav' | 'pcm' = 'mp3'
76+
): GenerateResponseData {
77+
const mediaType = RESPONSE_FORMAT_MEDIA_TYPES[responseFormat];
78+
return {
79+
message: {
80+
role: 'model',
81+
content: [
82+
{
83+
media: {
84+
contentType: mediaType,
85+
url: `data:${mediaType};base64,${result.toString('base64')}`,
86+
},
87+
},
88+
],
89+
},
90+
finishReason: 'stop',
91+
};
92+
}
93+
94+
/**
95+
* Method to define a new Genkit Model that is compatible with the Open AI Audio
96+
* API.
97+
*
98+
* These models are to be used to create audio speech from a given request.
99+
* @param params An object containing parameters for defining the OpenAI speech
100+
* model.
101+
* @param params.ai The Genkit AI instance.
102+
* @param params.name The name of the model.
103+
* @param params.client The OpenAI client instance.
104+
* @param params.modelRef Optional reference to the model's configuration and
105+
* custom options.
106+
107+
* @returns the created {@link ModelAction}
108+
*/
109+
export function defineCompatOpenAISpeechModel<
110+
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
111+
>(params: {
112+
ai: Genkit;
113+
name: string;
114+
client: OpenAI;
115+
modelRef?: ModelReference<CustomOptions>;
116+
}): ModelAction {
117+
const { ai, name, client, modelRef } = params;
118+
119+
const model = name.split('/').pop();
120+
return ai.defineModel(
121+
{
122+
name,
123+
...modelRef?.info,
124+
configSchema: modelRef?.configSchema,
125+
},
126+
async (request) => {
127+
const ttsRequest = toTTSRequest(model!, request);
128+
const result = await client.audio.speech.create(ttsRequest);
129+
const resultArrayBuffer = await result.arrayBuffer();
130+
const resultBuffer = Buffer.from(new Uint8Array(resultArrayBuffer));
131+
return toGenerateResponse(resultBuffer, ttsRequest.response_format);
132+
}
133+
);
134+
}
135+
136+
function toSttRequest(
137+
modelName: string,
138+
request: GenerateRequest
139+
): TranscriptionCreateParams {
140+
const message = new Message(request.messages[0]);
141+
const media = message.media;
142+
if (!media?.url) {
143+
throw new Error('No media found in the request');
144+
}
145+
const mediaBuffer = Buffer.from(
146+
media.url.slice(media.url.indexOf(',') + 1),
147+
'base64'
148+
);
149+
const mediaFile = new File([mediaBuffer], 'input', {
150+
type:
151+
media.contentType ??
152+
media.url.slice('data:'.length, media.url.indexOf(';')),
153+
});
154+
const {
155+
temperature,
156+
version: modelVersion,
157+
maxOutputTokens,
158+
stopSequences,
159+
topK,
160+
topP,
161+
...restOfConfig
162+
} = request.config ?? {};
163+
164+
const options: TranscriptionCreateParams = {
165+
model: modelVersion ?? modelName,
166+
file: mediaFile,
167+
prompt: message.text,
168+
temperature,
169+
...restOfConfig, // passthrough rest of the config
170+
};
171+
const outputFormat = request.output?.format as 'json' | 'text' | 'media';
172+
const customFormat = request.config?.response_format;
173+
if (outputFormat && customFormat) {
174+
if (
175+
outputFormat === 'json' &&
176+
customFormat !== 'json' &&
177+
customFormat !== 'verbose_json'
178+
) {
179+
throw new Error(
180+
`Custom response format ${customFormat} is not compatible with output format ${outputFormat}`
181+
);
182+
}
183+
}
184+
if (outputFormat === 'media') {
185+
throw new Error(`Output format ${outputFormat} is not supported.`);
186+
}
187+
options.response_format = customFormat || outputFormat || 'text';
188+
for (const k in options) {
189+
if (options[k] === undefined) {
190+
delete options[k];
191+
}
192+
}
193+
return options;
194+
}
195+
196+
function transcriptionToGenerateResponse(
197+
result: Transcription | string
198+
): GenerateResponseData {
199+
return {
200+
message: {
201+
role: 'model',
202+
content: [
203+
{
204+
text: typeof result === 'string' ? result : result.text,
205+
},
206+
],
207+
},
208+
finishReason: 'stop',
209+
};
210+
}
211+
212+
/**
213+
* Method to define a new Genkit Model that is compatible with Open AI
214+
* Transcriptions API.
215+
*
216+
* These models are to be used to transcribe audio to text.
217+
*
218+
* @param params An object containing parameters for defining the OpenAI
219+
* transcription model.
220+
* @param params.ai The Genkit AI instance.
221+
* @param params.name The name of the model.
222+
* @param params.client The OpenAI client instance.
223+
* @param params.modelRef Optional reference to the model's configuration and
224+
* custom options.
225+
226+
* @returns the created {@link ModelAction}
227+
*/
228+
export function defineCompatOpenAITranscriptionModel<
229+
CustomOptions extends z.ZodTypeAny = z.ZodTypeAny,
230+
>(params: {
231+
ai: Genkit;
232+
name: string;
233+
client: OpenAI;
234+
modelRef?: ModelReference<CustomOptions>;
235+
}): ModelAction {
236+
const { ai, name, client, modelRef } = params;
237+
238+
return ai.defineModel(
239+
{
240+
name,
241+
...modelRef?.info,
242+
configSchema: modelRef?.configSchema,
243+
},
244+
async (request) => {
245+
const modelName = name.split('/').pop();
246+
const params = toSttRequest(modelName!, request);
247+
// Explicitly setting stream to false ensures we use the non-streaming overload
248+
const result = await client.audio.transcriptions.create({
249+
...params,
250+
stream: false,
251+
});
252+
return transcriptionToGenerateResponse(result);
253+
}
254+
);
255+
}

0 commit comments

Comments
 (0)