diff --git a/pkg/openai/client.go b/pkg/openai/client.go index db911962..a4984a0c 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -331,7 +331,12 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques messageRequest.Messages[len(messageRequest.Messages)-1].Content = types.Text(TooLongMessage) } - msgs = dropMessagesOverCount(messageRequest.MaxTokens, msgs) + toolsCount, err := countChatCompletionTools(messageRequest.Tools) + if err != nil { + return nil, err + } + + msgs = dropMessagesOverCount(messageRequest.MaxTokens-toolsCount, msgs) } if len(msgs) == 0 { @@ -408,7 +413,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques // If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass. var apiError *openai.APIError - if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" && messageRequest.Chat { + if err != nil && ((errors.As(err, &apiError) && apiError.Code == "context_length_exceeded") || strings.Contains(err.Error(), "maximum context length is")) && messageRequest.Chat { // Decrease maxTokens by 10% to make garbage collection more aggressive. // The retry loop will further decrease maxTokens if needed. maxTokens := decreaseTenPercent(messageRequest.MaxTokens) @@ -447,21 +452,22 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques } func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) { - var ( - response types.CompletionMessage - err error - ) + toolsCount, err := countOpenAITools(request.Tools) + if err != nil { + return types.CompletionMessage{}, err + } + var response types.CompletionMessage for range 10 { // maximum 10 tries // Try to drop older messages again, with a decreased max tokens. 
- request.Messages = dropMessagesOverCount(maxTokens, request.Messages) + request.Messages = dropMessagesOverCount(maxTokens-toolsCount, request.Messages) response, err = c.call(ctx, request, id, env, status) if err == nil { return response, nil } var apiError *openai.APIError - if errors.As(err, &apiError) && apiError.Code == "context_length_exceeded" { + if (errors.As(err, &apiError) && apiError.Code == "context_length_exceeded") || strings.Contains(err.Error(), "maximum context length is") { // Decrease maxTokens and try again maxTokens = decreaseTenPercent(maxTokens) continue diff --git a/pkg/openai/count.go b/pkg/openai/count.go index ffd902e5..0832824f 100644 --- a/pkg/openai/count.go +++ b/pkg/openai/count.go @@ -1,7 +1,10 @@ package openai import ( + "encoding/json" + openai "github.com/gptscript-ai/chat-completion-client" + "github.com/gptscript-ai/gptscript/pkg/types" ) const DefaultMaxTokens = 128_000 @@ -14,6 +17,15 @@ func decreaseTenPercent(maxTokens int) int { func getBudget(maxTokens int) int { if maxTokens == 0 { return DefaultMaxTokens + } else if maxTokens <= 0 { + // The caller's maxTokens was 0 or very small, and subtracting the tool token count pushed it negative, + // so add that negative remainder to the default budget to get a correspondingly smaller one. + if DefaultMaxTokens+maxTokens >= 0 { + return DefaultMaxTokens + maxTokens + } + + // If the deficit exceeds the default budget entirely, clamp the budget to 0. + return 0 } return maxTokens } @@ -73,3 +85,29 @@ func countMessage(msg openai.ChatCompletionMessage) (count int) { count += len(msg.ToolCallID) return count / 3 } + +func countChatCompletionTools(tools []types.ChatCompletionTool) (count int, err error) { + for _, t := range tools { + count += len(t.Function.Name) + count += len(t.Function.Description) + paramsJSON, err := json.Marshal(t.Function.Parameters) + if err != nil { + return 0, err + } + count += len(paramsJSON) + } + return count / 3, nil +} + +func countOpenAITools(tools 
[]openai.Tool) (count int, err error) { + for _, t := range tools { + count += len(t.Function.Name) + count += len(t.Function.Description) + paramsJSON, err := json.Marshal(t.Function.Parameters) + if err != nil { + return 0, err + } + count += len(paramsJSON) + } + return count / 3, nil +}