Skip to content

bug: fix model provider working with no openai key set #199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ module github.com/gptscript-ai/gptscript

go 1.22.0

replace github.com/sashabaranov/go-openai => github.com/gptscript-ai/go-openai v0.0.0-20240328093028-7993661f9eab

require (
github.com/AlecAivazis/survey/v2 v2.3.7
github.com/BurntSushi/locker v0.0.0-20171006230638-a6e239ea1c69
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gptscript-ai/go-openai v0.0.0-20240328093028-7993661f9eab h1:uZP7zZqtQI5lfK0fGBmi2ZUrI973tNCnCDx326LG00k=
github.com/gptscript-ai/go-openai v0.0.0-20240328093028-7993661f9eab/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
Expand Down Expand Up @@ -206,8 +208,6 @@ github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
github.com/samber/slog-logrus v1.0.0 h1:SsrN0p9akjCEaYd42Q5GtisMdHm0q11UD4fp4XCZi04=
github.com/samber/slog-logrus v1.0.0/go.mod h1:ZTdPCmVWljwlfjz6XflKNvW4TcmYlexz4HMUOO/42bI=
github.com/sashabaranov/go-openai v1.20.1 h1:cFnTixAtc0I0cCBFr8gkvEbGCm6Rjf2JyoVWCjXwy9g=
github.com/sashabaranov/go-openai v1.20.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
Expand Down
3 changes: 3 additions & 0 deletions pkg/engine/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,9 @@ func appendInputAsEnv(env []string, input string) []string {
func (e *Engine) newCommand(ctx context.Context, extraEnv []string, tool types.Tool, input string) (*exec.Cmd, func(), error) {
envvars := append(e.Env[:], extraEnv...)
envvars = appendInputAsEnv(envvars, input)
if log.IsDebug() {
envvars = append(envvars, "GPTSCRIPT_DEBUG=true")
}

interpreter, rest, _ := strings.Cut(tool.Instructions, "\n")
interpreter = strings.TrimSpace(interpreter)[2:]
Expand Down
31 changes: 26 additions & 5 deletions pkg/engine/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"math/rand"
"net/http"
"os"
"strings"
Expand All @@ -18,7 +19,7 @@ var (
daemonLock sync.Mutex

startPort, endPort int64
nextPort int64
usedPorts map[int64]struct{}
daemonCtx context.Context
daemonClose func()
daemonWG sync.WaitGroup
Expand All @@ -41,10 +42,29 @@ func (e *Engine) getNextPort() int64 {
startPort = 10240
endPort = 11240
}
count := endPort - startPort
nextPort++
nextPort = nextPort % count
return startPort + nextPort
// This is pretty simple and inefficient approach, but also never releases ports
count := endPort - startPort + 1
toTry := make([]int64, 0, count)
for i := startPort; i <= endPort; i++ {
toTry = append(toTry, i)
}

rand.Shuffle(len(toTry), func(i, j int) {
toTry[i], toTry[j] = toTry[j], toTry[i]
})

for _, nextPort := range toTry {
if _, ok := usedPorts[nextPort]; ok {
continue
}
if usedPorts == nil {
usedPorts = map[int64]struct{}{}
}
usedPorts[nextPort] = struct{}{}
return nextPort
}

panic("Ran out of usable ports")
}

func getPath(instructions string) (string, string) {
Expand Down Expand Up @@ -92,6 +112,7 @@ func (e *Engine) startDaemon(_ context.Context, tool types.Tool) (string, error)

cmd, stop, err := e.newCommand(ctx, []string{
fmt.Sprintf("PORT=%d", port),
fmt.Sprintf("GPTSCRIPT_PORT=%d", port),
},
tool,
"{}",
Expand Down
4 changes: 3 additions & 1 deletion pkg/env/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ func execEquals(bin, check string) bool {
}

// ToEnvLike converts an arbitrary name into an environment-variable-like
// identifier: dots and dashes are replaced with underscores and the result
// is upper-cased (e.g. "api.example-host" -> "API_EXAMPLE_HOST").
func ToEnvLike(v string) string {
	// The stale early return from the pre-fix version made the dot
	// replacement unreachable; normalize both separators before upper-casing.
	v = strings.ReplaceAll(v, ".", "_")
	v = strings.ReplaceAll(v, "-", "_")
	return strings.ToUpper(v)
}

func Matches(cmd []string, bin string) bool {
Expand Down
3 changes: 2 additions & 1 deletion pkg/gptscript/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ func New(opts *Options) (*GPTScript, error) {
}

oAIClient, err := openai.NewClient(append([]openai.Options{opts.OpenAI}, openai.Options{
Cache: cacheClient,
Cache: cacheClient,
SetSeed: true,
})...)
if err != nil {
return nil, err
Expand Down
48 changes: 36 additions & 12 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ var (
)

// Client wraps a go-openai client with caching, default-model selection,
// and lazy credential validation (so a missing OPENAI_API_KEY only fails
// when the client is actually used).
type Client struct {
url string
key string
// defaultModel is used when a completion request does not name a model.
defaultModel string
// c is the underlying go-openai client configured at construction.
c *openai.Client
// cache stores completion responses keyed by cacheKey.
cache *cache.Client
// invalidAuth is set when neither an API key nor a base URL was
// configured; ValidAuth surfaces it as an error on use.
invalidAuth bool
// cacheKeyBase namespaces cache keys per client configuration
// (explicit CacheKey option, or a hash of the API key and base URL).
cacheKeyBase string
// setSeed, when true, attaches a deterministic seed to requests so
// responses are cacheable.
setSeed bool
// NOTE(review): url and key above appear to be stale — cacheKey now
// uses cacheKeyBase instead of c.url/c.key; confirm they can be removed.
}

type Options struct {
Expand All @@ -47,6 +48,8 @@ type Options struct {
APIType openai.APIType `usage:"OpenAI API Type (valid: OPEN_AI, AZURE, AZURE_AD)" name:"openai-api-type" env:"OPENAI_API_TYPE"`
OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"`
DefaultModel string `usage:"Default LLM model to use" default:"gpt-4-turbo-preview"`
SetSeed bool `usage:"-"`
CacheKey string `usage:"-"`
Cache *cache.Client
}

Expand All @@ -59,6 +62,8 @@ func complete(opts ...Options) (result Options, err error) {
result.APIVersion = types.FirstSet(opt.APIVersion, result.APIVersion)
result.APIType = types.FirstSet(opt.APIType, result.APIType)
result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel)
result.SetSeed = types.FirstSet(opt.SetSeed, result.SetSeed)
result.CacheKey = types.FirstSet(opt.CacheKey, result.CacheKey)
}

if result.Cache == nil {
Expand All @@ -75,10 +80,6 @@ func complete(opts ...Options) (result Options, err error) {
result.APIKey = key
}

if result.APIKey == "" && result.BaseURL == "" {
return result, fmt.Errorf("OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable")
}

return result, err
}

Expand Down Expand Up @@ -112,13 +113,28 @@ func NewClient(opts ...Options) (*Client, error) {
cfg.APIVersion = types.FirstSet(opt.APIVersion, cfg.APIVersion)
cfg.APIType = types.FirstSet(opt.APIType, cfg.APIType)

cacheKeyBase := opt.CacheKey
if cacheKeyBase == "" {
cacheKeyBase = hash.ID(opt.APIKey, opt.BaseURL)
}

return &Client{
c: openai.NewClientWithConfig(cfg),
cache: opt.Cache,
defaultModel: opt.DefaultModel,
cacheKeyBase: cacheKeyBase,
invalidAuth: opt.APIKey == "" && opt.BaseURL == "",
setSeed: opt.SetSeed,
}, nil
}

// ValidAuth reports whether this client was constructed with usable
// credentials. It returns a descriptive error when neither an API key
// nor a base URL was provided, and nil otherwise.
func (c *Client) ValidAuth() error {
	if !c.invalidAuth {
		return nil
	}
	return fmt.Errorf("OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable")
}

func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
models, err := c.ListModels(ctx)
if err != nil {
Expand All @@ -133,6 +149,10 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
return nil, nil
}

if err := c.ValidAuth(); err != nil {
return nil, err
}

models, err := c.c.ListModels(ctx)
if err != nil {
return nil, err
Expand All @@ -146,8 +166,7 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []

// cacheKey derives a deterministic cache key for a chat completion
// request, namespaced by this client's configuration via cacheKeyBase
// so different providers/credentials never share cached responses.
// The stale "url"/"key" entries (whose backing fields were removed in
// favor of cacheKeyBase) are dropped.
func (c *Client) cacheKey(request openai.ChatCompletionRequest) string {
	return hash.Encode(map[string]any{
		"base":    c.cacheKeyBase,
		"request": request,
	})
}
Expand Down Expand Up @@ -277,6 +296,10 @@ func toMessages(request types.CompletionRequest) (result []openai.ChatCompletion
}

func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
if err := c.ValidAuth(); err != nil {
return nil, err
}

if messageRequest.Model == "" {
messageRequest.Model = c.defaultModel
}
Expand All @@ -296,10 +319,9 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
}

if messageRequest.Temperature == nil {
// this is a hack because the field is marked as omitempty, so we need it to be set to a non-zero value but arbitrarily small
request.Temperature = 1e-08
request.Temperature = new(float32)
} else {
request.Temperature = *messageRequest.Temperature
request.Temperature = messageRequest.Temperature
}

if messageRequest.JSONResponse {
Expand Down Expand Up @@ -330,7 +352,9 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
}

var cacheResponse bool
request.Seed = ptr(c.seed(request))
if c.setSeed {
request.Seed = ptr(c.seed(request))
}
response, ok, err := c.fromCache(ctx, messageRequest, request)
if err != nil {
return nil, err
Expand Down
2 changes: 2 additions & 0 deletions pkg/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
switch normalize(key) {
case "name":
tool.Parameters.Name = strings.ToLower(value)
case "modelprovider":
tool.Parameters.ModelProvider = true
case "model", "modelname":
tool.Parameters.ModelName = value
case "description":
Expand Down
18 changes: 5 additions & 13 deletions pkg/remote/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ import (
"fmt"
"net/url"
"os"
"slices"
"sort"
"strings"
"sync"

"github.com/gptscript-ai/gptscript/pkg/cache"
env2 "github.com/gptscript-ai/gptscript/pkg/env"
"github.com/gptscript-ai/gptscript/pkg/loader"
"github.com/gptscript-ai/gptscript/pkg/openai"
"github.com/gptscript-ai/gptscript/pkg/runner"
Expand Down Expand Up @@ -78,15 +78,6 @@ func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
return false, err
}

models, err := client.ListModels(ctx)
if err != nil {
return false, err
}

if !slices.Contains(models, modelNameSuffix) {
return false, fmt.Errorf("Failed in find model [%s], supported [%s]", modelNameSuffix, strings.Join(models, ", "))
}

c.clientsLock.Lock()
defer c.clientsLock.Unlock()

Expand All @@ -108,7 +99,7 @@ func (c *Client) clientFromURL(apiURL string) (*openai.Client, error) {
if err != nil {
return nil, err
}
env := strings.ToUpper(strings.ReplaceAll(parsed.Hostname(), ".", "_")) + "_API_KEY"
env := "GPTSCRIPT_PROVIDER_" + env2.ToEnvLike(parsed.Hostname()) + "_API_KEY"
apiKey := os.Getenv(env)
if apiKey == "" {
apiKey = "<unset>"
Expand Down Expand Up @@ -159,8 +150,9 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
}

client, err = openai.NewClient(openai.Options{
BaseURL: url,
Cache: c.cache,
BaseURL: url,
Cache: c.cache,
CacheKey: prg.EntryToolID,
})
if err != nil {
return nil, err
Expand Down
4 changes: 4 additions & 0 deletions pkg/types/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type Parameters struct {
Description string `json:"description,omitempty"`
MaxTokens int `json:"maxTokens,omitempty"`
ModelName string `json:"modelName,omitempty"`
ModelProvider bool `json:"modelProvider,omitempty"`
JSONResponse bool `json:"jsonResponse,omitempty"`
Temperature *float32 `json:"temperature,omitempty"`
Cache *bool `json:"cache,omitempty"`
Expand Down Expand Up @@ -81,6 +82,9 @@ func (t Tool) String() string {
if t.Parameters.ModelName != "" {
_, _ = fmt.Fprintf(buf, "Model Name: %s\n", t.Parameters.ModelName)
}
if t.Parameters.ModelProvider {
_, _ = fmt.Fprintf(buf, "Model Provider: true\n")
}
if t.Parameters.JSONResponse {
_, _ = fmt.Fprintln(buf, "JSON Response: true")
}
Expand Down