Skip to content

Commit 96a8368

Browse files
Merge pull request #199 from ibuildthecloud/model-provider
bug: fix model provider working with no openai key set
2 parents 09da727 + 259e094 commit 96a8368

File tree

10 files changed

+85
-34
lines changed

10 files changed

+85
-34
lines changed

go.mod

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ module github.com/gptscript-ai/gptscript
22

33
go 1.22.0
44

5+
replace github.com/sashabaranov/go-openai => github.com/gptscript-ai/go-openai v0.0.0-20240328093028-7993661f9eab
6+
57
require (
68
github.com/AlecAivazis/survey/v2 v2.3.7
79
github.com/BurntSushi/locker v0.0.0-20171006230638-a6e239ea1c69

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+
116116
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
117117
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
118118
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
119+
github.com/gptscript-ai/go-openai v0.0.0-20240328093028-7993661f9eab h1:uZP7zZqtQI5lfK0fGBmi2ZUrI973tNCnCDx326LG00k=
120+
github.com/gptscript-ai/go-openai v0.0.0-20240328093028-7993661f9eab/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
119121
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
120122
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
121123
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
@@ -206,8 +208,6 @@ github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM=
206208
github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA=
207209
github.com/samber/slog-logrus v1.0.0 h1:SsrN0p9akjCEaYd42Q5GtisMdHm0q11UD4fp4XCZi04=
208210
github.com/samber/slog-logrus v1.0.0/go.mod h1:ZTdPCmVWljwlfjz6XflKNvW4TcmYlexz4HMUOO/42bI=
209-
github.com/sashabaranov/go-openai v1.20.1 h1:cFnTixAtc0I0cCBFr8gkvEbGCm6Rjf2JyoVWCjXwy9g=
210-
github.com/sashabaranov/go-openai v1.20.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
211211
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
212212
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
213213
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=

pkg/engine/cmd.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,9 @@ func appendInputAsEnv(env []string, input string) []string {
151151
func (e *Engine) newCommand(ctx context.Context, extraEnv []string, tool types.Tool, input string) (*exec.Cmd, func(), error) {
152152
envvars := append(e.Env[:], extraEnv...)
153153
envvars = appendInputAsEnv(envvars, input)
154+
if log.IsDebug() {
155+
envvars = append(envvars, "GPTSCRIPT_DEBUG=true")
156+
}
154157

155158
interpreter, rest, _ := strings.Cut(tool.Instructions, "\n")
156159
interpreter = strings.TrimSpace(interpreter)[2:]

pkg/engine/daemon.go

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"io"
7+
"math/rand"
78
"net/http"
89
"os"
910
"strings"
@@ -18,7 +19,7 @@ var (
1819
daemonLock sync.Mutex
1920

2021
startPort, endPort int64
21-
nextPort int64
22+
usedPorts map[int64]struct{}
2223
daemonCtx context.Context
2324
daemonClose func()
2425
daemonWG sync.WaitGroup
@@ -41,10 +42,29 @@ func (e *Engine) getNextPort() int64 {
4142
startPort = 10240
4243
endPort = 11240
4344
}
44-
count := endPort - startPort
45-
nextPort++
46-
nextPort = nextPort % count
47-
return startPort + nextPort
45+
// This is pretty simple and inefficient approach, but also never releases ports
46+
count := endPort - startPort + 1
47+
toTry := make([]int64, 0, count)
48+
for i := startPort; i <= endPort; i++ {
49+
toTry = append(toTry, i)
50+
}
51+
52+
rand.Shuffle(len(toTry), func(i, j int) {
53+
toTry[i], toTry[j] = toTry[j], toTry[i]
54+
})
55+
56+
for _, nextPort := range toTry {
57+
if _, ok := usedPorts[nextPort]; ok {
58+
continue
59+
}
60+
if usedPorts == nil {
61+
usedPorts = map[int64]struct{}{}
62+
}
63+
usedPorts[nextPort] = struct{}{}
64+
return nextPort
65+
}
66+
67+
panic("Ran out of usable ports")
4868
}
4969

5070
func getPath(instructions string) (string, string) {
@@ -92,6 +112,7 @@ func (e *Engine) startDaemon(_ context.Context, tool types.Tool) (string, error)
92112

93113
cmd, stop, err := e.newCommand(ctx, []string{
94114
fmt.Sprintf("PORT=%d", port),
115+
fmt.Sprintf("GPTSCRIPT_PORT=%d", port),
95116
},
96117
tool,
97118
"{}",

pkg/env/env.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ func execEquals(bin, check string) bool {
1313
}
1414

1515
func ToEnvLike(v string) string {
16-
return strings.ToUpper(strings.ReplaceAll(v, "-", "_"))
16+
v = strings.ReplaceAll(v, ".", "_")
17+
v = strings.ReplaceAll(v, "-", "_")
18+
return strings.ToUpper(v)
1719
}
1820

1921
func Matches(cmd []string, bin string) bool {

pkg/gptscript/gptscript.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ func New(opts *Options) (*GPTScript, error) {
5454
}
5555

5656
oAIClient, err := openai.NewClient(append([]openai.Options{opts.OpenAI}, openai.Options{
57-
Cache: cacheClient,
57+
Cache: cacheClient,
58+
SetSeed: true,
5859
})...)
5960
if err != nil {
6061
return nil, err

pkg/openai/client.go

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,12 @@ var (
3333
)
3434

3535
type Client struct {
36-
url string
37-
key string
3836
defaultModel string
3937
c *openai.Client
4038
cache *cache.Client
39+
invalidAuth bool
40+
cacheKeyBase string
41+
setSeed bool
4142
}
4243

4344
type Options struct {
@@ -47,6 +48,8 @@ type Options struct {
4748
APIType openai.APIType `usage:"OpenAI API Type (valid: OPEN_AI, AZURE, AZURE_AD)" name:"openai-api-type" env:"OPENAI_API_TYPE"`
4849
OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"`
4950
DefaultModel string `usage:"Default LLM model to use" default:"gpt-4-turbo-preview"`
51+
SetSeed bool `usage:"-"`
52+
CacheKey string `usage:"-"`
5053
Cache *cache.Client
5154
}
5255

@@ -59,6 +62,8 @@ func complete(opts ...Options) (result Options, err error) {
5962
result.APIVersion = types.FirstSet(opt.APIVersion, result.APIVersion)
6063
result.APIType = types.FirstSet(opt.APIType, result.APIType)
6164
result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel)
65+
result.SetSeed = types.FirstSet(opt.SetSeed, result.SetSeed)
66+
result.CacheKey = types.FirstSet(opt.CacheKey, result.CacheKey)
6267
}
6368

6469
if result.Cache == nil {
@@ -75,10 +80,6 @@ func complete(opts ...Options) (result Options, err error) {
7580
result.APIKey = key
7681
}
7782

78-
if result.APIKey == "" && result.BaseURL == "" {
79-
return result, fmt.Errorf("OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable")
80-
}
81-
8283
return result, err
8384
}
8485

@@ -112,13 +113,28 @@ func NewClient(opts ...Options) (*Client, error) {
112113
cfg.APIVersion = types.FirstSet(opt.APIVersion, cfg.APIVersion)
113114
cfg.APIType = types.FirstSet(opt.APIType, cfg.APIType)
114115

116+
cacheKeyBase := opt.CacheKey
117+
if cacheKeyBase == "" {
118+
cacheKeyBase = hash.ID(opt.APIKey, opt.BaseURL)
119+
}
120+
115121
return &Client{
116122
c: openai.NewClientWithConfig(cfg),
117123
cache: opt.Cache,
118124
defaultModel: opt.DefaultModel,
125+
cacheKeyBase: cacheKeyBase,
126+
invalidAuth: opt.APIKey == "" && opt.BaseURL == "",
127+
setSeed: opt.SetSeed,
119128
}, nil
120129
}
121130

131+
func (c *Client) ValidAuth() error {
132+
if c.invalidAuth {
133+
return fmt.Errorf("OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable")
134+
}
135+
return nil
136+
}
137+
122138
func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
123139
models, err := c.ListModels(ctx)
124140
if err != nil {
@@ -133,6 +149,10 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
133149
return nil, nil
134150
}
135151

152+
if err := c.ValidAuth(); err != nil {
153+
return nil, err
154+
}
155+
136156
models, err := c.c.ListModels(ctx)
137157
if err != nil {
138158
return nil, err
@@ -146,8 +166,7 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
146166

147167
func (c *Client) cacheKey(request openai.ChatCompletionRequest) string {
148168
return hash.Encode(map[string]any{
149-
"url": c.url,
150-
"key": c.key,
169+
"base": c.cacheKeyBase,
151170
"request": request,
152171
})
153172
}
@@ -277,6 +296,10 @@ func toMessages(request types.CompletionRequest) (result []openai.ChatCompletion
277296
}
278297

279298
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
299+
if err := c.ValidAuth(); err != nil {
300+
return nil, err
301+
}
302+
280303
if messageRequest.Model == "" {
281304
messageRequest.Model = c.defaultModel
282305
}
@@ -296,10 +319,9 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
296319
}
297320

298321
if messageRequest.Temperature == nil {
299-
// this is a hack because the field is marked as omitempty, so we need it to be set to a non-zero value but arbitrarily small
300-
request.Temperature = 1e-08
322+
request.Temperature = new(float32)
301323
} else {
302-
request.Temperature = *messageRequest.Temperature
324+
request.Temperature = messageRequest.Temperature
303325
}
304326

305327
if messageRequest.JSONResponse {
@@ -330,7 +352,9 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
330352
}
331353

332354
var cacheResponse bool
333-
request.Seed = ptr(c.seed(request))
355+
if c.setSeed {
356+
request.Seed = ptr(c.seed(request))
357+
}
334358
response, ok, err := c.fromCache(ctx, messageRequest, request)
335359
if err != nil {
336360
return nil, err

pkg/parser/parser.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
7676
switch normalize(key) {
7777
case "name":
7878
tool.Parameters.Name = strings.ToLower(value)
79+
case "modelprovider":
80+
tool.Parameters.ModelProvider = true
7981
case "model", "modelname":
8082
tool.Parameters.ModelName = value
8183
case "description":

pkg/remote/remote.go

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ import (
55
"fmt"
66
"net/url"
77
"os"
8-
"slices"
98
"sort"
109
"strings"
1110
"sync"
1211

1312
"github.com/gptscript-ai/gptscript/pkg/cache"
13+
env2 "github.com/gptscript-ai/gptscript/pkg/env"
1414
"github.com/gptscript-ai/gptscript/pkg/loader"
1515
"github.com/gptscript-ai/gptscript/pkg/openai"
1616
"github.com/gptscript-ai/gptscript/pkg/runner"
@@ -78,15 +78,6 @@ func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
7878
return false, err
7979
}
8080

81-
models, err := client.ListModels(ctx)
82-
if err != nil {
83-
return false, err
84-
}
85-
86-
if !slices.Contains(models, modelNameSuffix) {
87-
return false, fmt.Errorf("Failed in find model [%s], supported [%s]", modelNameSuffix, strings.Join(models, ", "))
88-
}
89-
9081
c.clientsLock.Lock()
9182
defer c.clientsLock.Unlock()
9283

@@ -108,7 +99,7 @@ func (c *Client) clientFromURL(apiURL string) (*openai.Client, error) {
10899
if err != nil {
109100
return nil, err
110101
}
111-
env := strings.ToUpper(strings.ReplaceAll(parsed.Hostname(), ".", "_")) + "_API_KEY"
102+
env := "GPTSCRIPT_PROVIDER_" + env2.ToEnvLike(parsed.Hostname()) + "_API_KEY"
112103
apiKey := os.Getenv(env)
113104
if apiKey == "" {
114105
apiKey = "<unset>"
@@ -159,8 +150,9 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
159150
}
160151

161152
client, err = openai.NewClient(openai.Options{
162-
BaseURL: url,
163-
Cache: c.cache,
153+
BaseURL: url,
154+
Cache: c.cache,
155+
CacheKey: prg.EntryToolID,
164156
})
165157
if err != nil {
166158
return nil, err

pkg/types/tool.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ type Parameters struct {
3939
Description string `json:"description,omitempty"`
4040
MaxTokens int `json:"maxTokens,omitempty"`
4141
ModelName string `json:"modelName,omitempty"`
42+
ModelProvider bool `json:"modelProvider,omitempty"`
4243
JSONResponse bool `json:"jsonResponse,omitempty"`
4344
Temperature *float32 `json:"temperature,omitempty"`
4445
Cache *bool `json:"cache,omitempty"`
@@ -81,6 +82,9 @@ func (t Tool) String() string {
8182
if t.Parameters.ModelName != "" {
8283
_, _ = fmt.Fprintf(buf, "Model Name: %s\n", t.Parameters.ModelName)
8384
}
85+
if t.Parameters.ModelProvider {
86+
_, _ = fmt.Fprintf(buf, "Model Provider: true\n")
87+
}
8488
if t.Parameters.JSONResponse {
8589
_, _ = fmt.Fprintln(buf, "JSON Response: true")
8690
}

0 commit comments

Comments
 (0)