From 378eb463fda9693890c00b6919bbdd4038debe4b Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Wed, 29 Jan 2025 21:40:11 -0500 Subject: [PATCH 1/3] fix: count tokens from tool definitions when adjusting for context window Signed-off-by: Grant Linville --- pkg/openai/client.go | 18 ++++++++++++------ pkg/openai/count.go | 29 +++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 6 deletions(-) diff --git a/pkg/openai/client.go b/pkg/openai/client.go index db911962..60dfee15 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -331,7 +331,12 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage) } - msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs) + toolsCount, err := countChatCompletionTools(messageRequest.Tools) + if err != nil { + return nil, err + } + + msgs = dropMessagesOverCount(messageRequest.MaxTokens-toolsCount, msgs) } if len(msgs) == 0 { @@ -447,14 +452,15 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques } func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) { - var ( - response types.CompletionMessage - err error - ) + toolsCount, err := countOpenAITools(request.Tools) + if err != nil { + return types.CompletionMessage{}, err + } + var response types.CompletionMessage for range 10 { // maximum 10 tries // Try to drop older messages again, with a decreased max tokens. 
- request.Messages = dropMessagesOverCount(maxTokens, request.Messages) + request.Messages = dropMessagesOverCount(maxTokens-toolsCount, request.Messages) response, err = c.call(ctx, request, id, env, status) if err == nil { return response, nil diff --git a/pkg/openai/count.go b/pkg/openai/count.go index ffd902e5..57f67f54 100644 --- a/pkg/openai/count.go +++ b/pkg/openai/count.go @@ -1,7 +1,10 @@ package openai import ( + "encoding/json" + openai "github.com/gptscript-ai/chat-completion-client" + "github.com/gptscript-ai/gptscript/pkg/types" ) const DefaultMaxTokens = 128_000 @@ -73,3 +76,29 @@ func countMessage(msg openai.ChatCompletionMessage) (count int) { count += len(msg.ToolCallID) return count / 3 } + +func countChatCompletionTools(tools []types.ChatCompletionTool) (count int, err error) { + for _, t := range tools { + count += len(t.Function.Name) + count += len(t.Function.Description) + paramsJSON, err := json.Marshal(t.Function.Parameters) + if err != nil { + return 0, err + } + count += len(paramsJSON) + } + return count / 3, nil +} + +func countOpenAITools(tools []openai.Tool) (count int, err error) { + for _, t := range tools { + count += len(t.Function.Name) + count += len(t.Function.Description) + paramsJSON, err := json.Marshal(t.Function.Parameters) + if err != nil { + return 0, err + } + count += len(paramsJSON) + } + return count / 3, nil +} From f2200970d0d7dd0f8eec13354b65c90c7301ea96 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Wed, 29 Jan 2025 22:14:56 -0500 Subject: [PATCH 2/3] fixes Signed-off-by: Grant Linville --- pkg/openai/client.go | 4 ++-- pkg/openai/count.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/openai/client.go b/pkg/openai/client.go index 60dfee15..a4984a0c 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -413,7 +413,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques // If we got back a context length exceeded error, keep retrying 
and shrinking the message history until we pass. var apiError *openai.APIError - if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat { + if err != nil && ((errors.As(err, &apiError) && apiError.Code == "context_length_exceeded") || strings.Contains(err.Error(), "maximum context length is")) && messageRequest.Chat { // Decrease maxTokens by 10% to make garbage collection more aggressive. // The retry loop will further decrease maxTokens if needed. maxTokens := decreaseTenPercent(messageRequest.MaxTokens) @@ -467,7 +467,7 @@ func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatC } var apiError *openai.APIError - if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" { + if (errors.As(err, &apiError) && apiError.Code == "context_length_exceeded") || strings.Contains(err.Error(), "maximum context length is") { // Decrease maxTokens and try again maxTokens = decreaseTenPercent(maxTokens) continue diff --git a/pkg/openai/count.go b/pkg/openai/count.go index 57f67f54..e625b8e4 100644 --- a/pkg/openai/count.go +++ b/pkg/openai/count.go @@ -15,7 +15,7 @@ func decreaseTenPercent(maxTokens int) int { } func getBudget(maxTokens int) int { - if maxTokens == 0 { + if maxTokens <= 0 { return DefaultMaxTokens } return maxTokens From eff6633f32786cc00e1ab76460fc9006b38b6f37 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Wed, 29 Jan 2025 22:22:31 -0500 Subject: [PATCH 3/3] fix Signed-off-by: Grant Linville --- pkg/openai/count.go | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/pkg/openai/count.go b/pkg/openai/count.go index e625b8e4..0832824f 100644 --- a/pkg/openai/count.go +++ b/pkg/openai/count.go @@ -15,8 +15,17 @@ func decreaseTenPercent(maxTokens int) int { } func getBudget(maxTokens int) int { - if maxTokens <= 0 { + if maxTokens == 0 { return DefaultMaxTokens + } else if maxTokens <= 0 { + // maxTokens was 0 (or some very small number), the tool 
count pushed it negative + // so we can add that negative number to the default max tokens to arrive at a lower budget + if DefaultMaxTokens+maxTokens >= 0 { + return DefaultMaxTokens + maxTokens + } + + // If maxTokens was so negative that even DefaultMaxTokens+maxTokens is below zero, fall back to a budget of 0 + return 0 } return maxTokens }