fix: replace subscripting in OpenAI responses #50

Merged · 6 commits · Jan 10, 2024
62 changes: 62 additions & 0 deletions .grit/patterns/python/openai.md
@@ -198,6 +198,17 @@ pattern pytest_patch() {
},
}



// When a variable is assigned from an openai call, rewrite downstream subscript access on it to attribute access
pattern fix_downstream_openai_usage() {
$var where {
$program <: maybe contains bubble($var) `$x['$y']` as $sub => `$x.$y` where {
$sub <: contains $var
}
}
}
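For illustration, this is the kind of downstream rewrite the pattern performs (a minimal sketch mirroring the test case further down; the variable names are taken from that example):

```python
# Before: `completion` is assigned from an openai call and then subscripted
output = completion['choices'][0]['message']['content']

# After: the pattern rewrites each subscript on that variable to attribute access
output = completion.choices[0].message.content
```

Note the `$sub <: contains $var` guard: only subscript expressions that actually contain the tracked variable are rewritten, so unrelated subscripting elsewhere in the program is left untouched.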

pattern openai_main($client, $azure) {
$body where {
if ($client <: undefined) {
@@ -257,6 +268,9 @@ pattern openai_main($client, $azure) {
contains `import openai` as $import_stmt where {
$body <: contains bubble($has_sync, $has_async, $has_openai_import, $body, $client, $azure) `openai.$res.$func($params)` as $stmt where {
$res <: rewrite_whole_fn_call(import = $has_openai_import, $has_sync, $has_async, $res, $func, $params, $stmt, $body, $client, $azure),
$stmt <: maybe within bubble($stmt) `$var = $stmt` where {
$var <: fix_downstream_openai_usage()
}
},
},
contains `from openai import $resources` as $partial_import_stmt where {
@@ -562,3 +576,51 @@ response = client.chat.completions.create(
]
)
```

## Fix subscripting

The new API does not support subscripting on response objects; fields must be accessed as attributes instead.

```python
import openai

model, token_limit, prompt_cost, comp_cost = 'gpt-4-32k', 32_768, 0.06, 0.12

completion = openai.ChatCompletion.create(
model=model,
messages=[
{"role": "system", "content": system},
{"role": "user", "content":
user + text},
]
)
output = completion['choices'][0]['message']['content']

prom = completion['usage']['prompt_tokens']
comp = completion['usage']['completion_tokens']

# unrelated variable
foo = something['else']
```

```python
from openai import OpenAI

client = OpenAI()

model, token_limit, prompt_cost, comp_cost = 'gpt-4-32k', 32_768, 0.06, 0.12

completion = client.chat.completions.create(model=model,
messages=[
{"role": "system", "content": system},
{"role": "user", "content":
user + text},
])
output = completion.choices[0].message.content

prom = completion.usage.prompt_tokens
comp = completion.usage.completion_tokens

# unrelated variable
foo = something['else']
```
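If dict-style access is still genuinely needed after migrating, note (outside the scope of this pattern) that the v1 client returns pydantic model objects, which can be converted explicitly — a hedged sketch, assuming the pydantic-v2-based `openai>=1.0` SDK:

```python
# Assumption: openai>=1.0, where responses are pydantic models.
# Convert to a plain dict only if subscripting is unavoidable.
data = completion.model_dump()
output = data["choices"][0]["message"]["content"]
```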
4 changes: 2 additions & 2 deletions .grit/patterns/python/openai_azure.md
@@ -64,7 +64,7 @@ response = client.chat.completions.create(
]
)

- print(response['choices'][0]['message']['content'])
+ print(response.choices[0].message.content)
```

## Embeddings
@@ -99,7 +99,7 @@ response = client.embeddings.create(
input="Your text string goes here",
model="YOUR_DEPLOYMENT_NAME"
)
- embeddings = response['data'][0]['embedding']
+ embeddings = response.data[0].embedding
print(embeddings)
```
