diff --git a/.grit/patterns/python/openai.md b/.grit/patterns/python/openai.md
index 01e1ada9..a19f056e 100644
--- a/.grit/patterns/python/openai.md
+++ b/.grit/patterns/python/openai.md
@@ -198,6 +198,17 @@ pattern pytest_patch() {
   },
 }
 
+
+// When there is a variable used by an openai call, make sure it isn't subscripted
+pattern fix_downstream_openai_usage() {
+  $var where {
+    $program <: maybe contains bubble($var) `$x['$y']` as $sub => `$x.$y` where {
+      $sub <: contains $var
+    }
+  }
+}
+
+
 pattern openai_main($client, $azure) {
   $body where {
     if ($client <: undefined) {
@@ -257,6 +268,9 @@ pattern openai_main($client, $azure) {
     contains `import openai` as $import_stmt where {
       $body <: contains bubble($has_sync, $has_async, $has_openai_import, $body, $client, $azure) `openai.$res.$func($params)` as $stmt where {
         $res <: rewrite_whole_fn_call(import = $has_openai_import, $has_sync, $has_async, $res, $func, $params, $stmt, $body, $client, $azure),
+        $stmt <: maybe within bubble($stmt) `$var = $stmt` where {
+          $var <: fix_downstream_openai_usage()
+        }
       },
     },
     contains `from openai import $resources` as $partial_import_stmt where {
@@ -562,3 +576,51 @@ response = client.chat.completions.create(
     ]
 )
 ```
+
+## Fix subscripting
+
+The new API does not support subscripting on the outputs.
+
+```python
+import openai
+
+model, token_limit, prompt_cost, comp_cost = 'gpt-4-32k', 32_768, 0.06, 0.12
+
+completion = openai.ChatCompletion.create(
+    model=model,
+    messages=[
+        {"role": "system", "content": system},
+        {"role": "user", "content":
+            user + text},
+    ]
+)
+output = completion['choices'][0]['message']['content']
+
+prom = completion['usage']['prompt_tokens']
+comp = completion['usage']['completion_tokens']
+
+# unrelated variable
+foo = something['else']
+```
+
+```python
+from openai import OpenAI
+
+client = OpenAI()
+
+model, token_limit, prompt_cost, comp_cost = 'gpt-4-32k', 32_768, 0.06, 0.12
+
+completion = client.chat.completions.create(model=model,
+messages=[
+    {"role": "system", "content": system},
+    {"role": "user", "content":
+        user + text},
+])
+output = completion.choices[0].message.content
+
+prom = completion.usage.prompt_tokens
+comp = completion.usage.completion_tokens
+
+# unrelated variable
+foo = something['else']
+```
diff --git a/.grit/patterns/python/openai_azure.md b/.grit/patterns/python/openai_azure.md
index 65a1d22b..1856f2e7 100644
--- a/.grit/patterns/python/openai_azure.md
+++ b/.grit/patterns/python/openai_azure.md
@@ -64,7 +64,7 @@ response = client.chat.completions.create(
     ]
 )
 
-print(response['choices'][0]['message']['content'])
+print(response.choices[0].message.content)
 ```
 
 ## Embeddings
@@ -99,7 +99,7 @@ response = client.embeddings.create(
     input="Your text string goes here",
     model="YOUR_DEPLOYMENT_NAME"
 )
-embeddings = response['data'][0]['embedding']
+embeddings = response.data[0].embedding
 
 print(embeddings)
 ```
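
For context on why the subscript rewrite is needed: the v1 Python SDK returns typed response objects rather than plain dicts, so the old `response['key']` access fails at runtime. Below is a minimal sketch of that behaviour, assuming `openai>=1.0`, an `OPENAI_API_KEY` in the environment, and an illustrative model name.

```python
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

completion = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model name, for illustration only
    messages=[{"role": "user", "content": "Say hello"}],
)

# Attribute access works on the typed response object.
print(completion.choices[0].message.content)
print(completion.usage.prompt_tokens, completion.usage.completion_tokens)

# Dict-style subscripting does not; it raises TypeError on the v1 SDK.
try:
    completion["choices"]
except TypeError as err:
    print(f"subscripting unsupported: {err}")

# When a plain dict is genuinely needed, dump the model explicitly.
as_dict = completion.model_dump()
print(as_dict["choices"][0]["message"]["content"])
```

This is the runtime behaviour `fix_downstream_openai_usage` compensates for: only subscripts on variables bound to an OpenAI call result are rewritten, so unrelated access such as `foo = something['else']` is left untouched.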