Skip to content

Commit 6e73e5c

Browse files
Merge pull request #80 from ibuildthecloud/main
Refactor client code to make way for additional AI API backends
2 parents 6c12178 + 046d340 commit 6e73e5c

File tree

15 files changed

+341
-206
lines changed

15 files changed

+341
-206
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module github.com/gptscript-ai/gptscript
22

33
go 1.22.0
44

5-
replace github.com/sashabaranov/go-openai => github.com/gptscript-ai/go-openai v0.0.0-20240206232711-45b6e096246a
5+
replace github.com/sashabaranov/go-openai => github.com/gptscript-ai/go-openai v0.0.0-20240227161457-daa30caa3185
66

77
require (
88
github.com/BurntSushi/locker v0.0.0-20171006230638-a6e239ea1c69

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaU
4040
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
4141
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
4242
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
43-
github.com/gptscript-ai/go-openai v0.0.0-20240206232711-45b6e096246a h1:AdBbQ1ODOYK5AwCey4VFEmKeu9gG4PCzuO80pQmgupE=
44-
github.com/gptscript-ai/go-openai v0.0.0-20240206232711-45b6e096246a/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
43+
github.com/gptscript-ai/go-openai v0.0.0-20240227161457-daa30caa3185 h1:+TfC9DYtWuexdL7x1lIdD1HP61IStb3ZTj/byBdiWs0=
44+
github.com/gptscript-ai/go-openai v0.0.0-20240227161457-daa30caa3185/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
4545
github.com/hexops/autogold v0.8.1/go.mod h1:97HLDXyG23akzAoRYJh/2OBs3kd80eHyKPvZw0S5ZBY=
4646
github.com/hexops/autogold v1.3.1 h1:YgxF9OHWbEIUjhDbpnLhgVsjUDsiHDTyDfy2lrfdlzo=
4747
github.com/hexops/autogold v1.3.1/go.mod h1:sQO+mQUCVfxOKPht+ipDSkJ2SCJ7BNJVHZexsXqWMx4=

pkg/builtin/defaults.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,20 @@ import (
66
)
77

88
var (
9-
DefaultModel = openai.DefaultModel
9+
defaultModel = openai.DefaultModel
1010
)
1111

12+
// GetDefaultModel returns the model name applied to tools that do not
// specify one themselves (see SetDefaults in this package).
func GetDefaultModel() string {
	return defaultModel
}

// SetDefaultModel overrides the package-wide default model name.
// NOTE(review): this mutates package-level state with no locking —
// assumed to be called once during CLI startup; confirm before calling
// from concurrent code.
func SetDefaultModel(model string) {
	defaultModel = model
}
19+
1220
func SetDefaults(tool types.Tool) types.Tool {
1321
if tool.Parameters.ModelName == "" {
14-
tool.Parameters.ModelName = DefaultModel
22+
tool.Parameters.ModelName = GetDefaultModel()
1523
}
1624
return tool
1725
}

pkg/cache/cache.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cache
22

33
import (
4+
"context"
45
"errors"
56
"io/fs"
67
"os"
@@ -35,6 +36,17 @@ func complete(opts ...Options) (result Options) {
3536
return
3637
}
3738

39+
type noCacheKey struct{}
40+
41+
func IsNoCache(ctx context.Context) bool {
42+
v, _ := ctx.Value(noCacheKey{}).(bool)
43+
return v
44+
}
45+
46+
func WithNoCache(ctx context.Context) context.Context {
47+
return context.WithValue(ctx, noCacheKey{}, true)
48+
}
49+
3850
func New(opts ...Options) (*Client, error) {
3951
opt := complete(opts...)
4052
if err := os.MkdirAll(opt.CacheDir, 0755); err != nil {

pkg/cli/gptscript.go

Lines changed: 52 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ import (
1010
"github.com/acorn-io/cmd"
1111
"github.com/gptscript-ai/gptscript/pkg/assemble"
1212
"github.com/gptscript-ai/gptscript/pkg/builtin"
13+
"github.com/gptscript-ai/gptscript/pkg/cache"
1314
"github.com/gptscript-ai/gptscript/pkg/engine"
1415
"github.com/gptscript-ai/gptscript/pkg/input"
16+
"github.com/gptscript-ai/gptscript/pkg/llm"
1517
"github.com/gptscript-ai/gptscript/pkg/loader"
1618
"github.com/gptscript-ai/gptscript/pkg/monitor"
1719
"github.com/gptscript-ai/gptscript/pkg/mvl"
@@ -26,10 +28,13 @@ import (
2628

2729
type (
2830
DisplayOptions monitor.Options
31+
CacheOptions cache.Options
32+
OpenAIOptions openai.Options
2933
)
3034

3135
type GPTScript struct {
32-
runner.Options
36+
CacheOptions
37+
OpenAIOptions
3338
DisplayOptions
3439
Debug bool `usage:"Enable debug logging"`
3540
Quiet *bool `usage:"No output logging" short:"q"`
@@ -41,6 +46,8 @@ type GPTScript struct {
4146
ListTools bool `usage:"List built-in tools and exit"`
4247
Server bool `usage:"Start server"`
4348
ListenAddress string `usage:"Server listen address" default:"127.0.0.1:9090"`
49+
50+
_client llm.Client `usage:"-"`
4451
}
4552

4653
func New() *cobra.Command {
@@ -67,6 +74,33 @@ func (r *GPTScript) Customize(cmd *cobra.Command) {
6774
}
6875
}
6976

77+
// getClient lazily constructs the LLM client used by all subcommands and
// memoizes it in r._client, so repeated calls return the same registry.
//
// Construction order: a cache client from the CLI cache options, an OpenAI
// client that carries that cache, then an llm registry into which the
// OpenAI client is registered. The registry is what callers receive.
//
// NOTE(review): the memoization is not guarded by a lock — assumed safe
// because the CLI resolves its client once on the main goroutine; confirm
// before calling concurrently.
func (r *GPTScript) getClient(ctx context.Context) (llm.Client, error) {
	// Fast path: client already built.
	if r._client != nil {
		return r._client, nil
	}

	cacheClient, err := cache.New(cache.Options(r.CacheOptions))
	if err != nil {
		return nil, err
	}

	// Two Options values are passed; presumably openai.NewClient merges
	// them, overlaying the cache onto the CLI-provided options — verify
	// against the openai package.
	oaClient, err := openai.NewClient(openai.Options(r.OpenAIOptions), openai.Options{
		Cache: cacheClient,
	})
	if err != nil {
		return nil, err
	}

	registry := llm.NewRegistry()

	if err := registry.AddClient(ctx, oaClient); err != nil {
		return nil, err
	}

	r._client = registry
	return r._client, nil
}
103+
70104
func (r *GPTScript) listTools() error {
71105
var lines []string
72106
for _, tool := range builtin.ListTools() {
@@ -77,12 +111,12 @@ func (r *GPTScript) listTools() error {
77111
}
78112

79113
func (r *GPTScript) listModels(ctx context.Context) error {
80-
c, err := openai.NewClient(openai.Options(r.OpenAIOptions))
114+
c, err := r.getClient(ctx)
81115
if err != nil {
82116
return err
83117
}
84118

85-
models, err := c.ListModules(ctx)
119+
models, err := c.ListModels(ctx)
86120
if err != nil {
87121
return err
88122
}
@@ -95,6 +129,10 @@ func (r *GPTScript) listModels(ctx context.Context) error {
95129
}
96130

97131
func (r *GPTScript) Pre(*cobra.Command, []string) error {
132+
if r.DefaultModel != "" {
133+
builtin.SetDefaultModel(r.DefaultModel)
134+
}
135+
98136
if r.Quiet == nil {
99137
if term.IsTerminal(int(os.Stdout.Fd())) {
100138
r.Quiet = new(bool)
@@ -126,9 +164,11 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) error {
126164
}
127165

128166
if r.Server {
129-
s, err := server.New(server.Options{
130-
CacheOptions: r.CacheOptions,
131-
OpenAIOptions: r.OpenAIOptions,
167+
c, err := r.getClient(cmd.Context())
168+
if err != nil {
169+
return err
170+
}
171+
s, err := server.New(c, server.Options{
132172
ListenAddress: r.ListenAddress,
133173
})
134174
if err != nil {
@@ -176,9 +216,12 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) error {
176216
return assemble.Assemble(prg, out)
177217
}
178218

179-
runner, err := runner.New(r.Options, runner.Options{
180-
CacheOptions: r.CacheOptions,
181-
OpenAIOptions: r.OpenAIOptions,
219+
client, err := r.getClient(cmd.Context())
220+
if err != nil {
221+
return err
222+
}
223+
224+
runner, err := runner.New(client, runner.Options{
182225
MonitorFactory: monitor.NewConsole(monitor.Options(r.DisplayOptions), monitor.Options{
183226
DisplayProgress: !*r.Quiet,
184227
}),

pkg/engine/cmd.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ import (
1212
"sync/atomic"
1313

1414
"github.com/google/shlex"
15-
"github.com/gptscript-ai/gptscript/pkg/openai"
1615
"github.com/gptscript-ai/gptscript/pkg/types"
1716
"github.com/gptscript-ai/gptscript/pkg/version"
1817
)
@@ -21,7 +20,7 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string)
2120
id := fmt.Sprint(atomic.AddInt64(&completionID, 1))
2221

2322
defer func() {
24-
e.Progress <- openai.Status{
23+
e.Progress <- types.CompletionStatus{
2524
CompletionID: id,
2625
Response: map[string]any{
2726
"output": cmdOut,
@@ -31,7 +30,7 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string)
3130
}()
3231

3332
if tool.BuiltinFunc != nil {
34-
e.Progress <- openai.Status{
33+
e.Progress <- types.CompletionStatus{
3534
CompletionID: id,
3635
Request: map[string]any{
3736
"command": []string{tool.ID},
@@ -47,7 +46,7 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string)
4746
}
4847
defer stop()
4948

50-
e.Progress <- openai.Status{
49+
e.Progress <- types.CompletionStatus{
5150
CompletionID: id,
5251
Request: map[string]any{
5352
"command": cmd.Args,

pkg/engine/engine.go

Lines changed: 19 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,16 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7-
"os"
87
"sync"
98
"sync/atomic"
109

11-
"github.com/gptscript-ai/gptscript/pkg/openai"
10+
"github.com/gptscript-ai/gptscript/pkg/system"
1211
"github.com/gptscript-ai/gptscript/pkg/types"
1312
"github.com/gptscript-ai/gptscript/pkg/version"
1413
)
1514

16-
// InternalSystemPrompt is added to all threads. Changing this is very dangerous as it has a
17-
// terrible global effect and changes the behavior of all scripts.
18-
var InternalSystemPrompt = `
19-
You are task oriented system.
20-
You receive input from a user, process the input from the given instructions, and then output the result.
21-
Your objective is to provide consistent and correct results.
22-
You do not need to explain the steps taken, only provide the result to the given instructions.
23-
You are referred to as a tool.
24-
`
25-
26-
var DefaultToolSchema = types.JSONSchema{
27-
Property: types.Property{
28-
Type: "object",
29-
},
30-
Properties: map[string]types.Property{
31-
openai.DefaultPromptParameter: {
32-
Description: "Prompt to send to the tool or assistant. This may be instructions or question.",
33-
Type: "string",
34-
},
35-
},
36-
Required: []string{openai.DefaultPromptParameter},
37-
}
38-
3915
var completionID int64
4016

41-
func init() {
42-
if p := os.Getenv("GPTSCRIPT_INTERNAL_SYSTEM_PROMPT"); p != "" {
43-
InternalSystemPrompt = p
44-
}
45-
}
46-
4717
type ErrToolNotFound struct {
4818
ToolName string
4919
}
@@ -52,10 +22,14 @@ func (e *ErrToolNotFound) Error() string {
5222
return fmt.Sprintf("tool not found: %s", e.ToolName)
5323
}
5424

25+
type Model interface {
26+
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
27+
}
28+
5529
type Engine struct {
56-
Client *openai.Client
30+
Model Model
5731
Env []string
58-
Progress chan<- openai.Status
32+
Progress chan<- types.CompletionStatus
5933
}
6034

6135
type State struct {
@@ -172,18 +146,12 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
172146
}
173147

174148
completion := types.CompletionRequest{
175-
Model: tool.Parameters.ModelName,
176-
MaxToken: tool.Parameters.MaxTokens,
177-
JSONResponse: tool.Parameters.JSONResponse,
178-
Cache: tool.Parameters.Cache,
179-
Temperature: tool.Parameters.Temperature,
180-
}
181-
182-
if InternalSystemPrompt != "" && (tool.Parameters.InternalPrompt == nil || *tool.Parameters.InternalPrompt) {
183-
completion.Messages = append(completion.Messages, types.CompletionMessage{
184-
Role: types.CompletionMessageRoleTypeSystem,
185-
Content: types.Text(InternalSystemPrompt),
186-
})
149+
Model: tool.Parameters.ModelName,
150+
MaxTokens: tool.Parameters.MaxTokens,
151+
JSONResponse: tool.Parameters.JSONResponse,
152+
Cache: tool.Parameters.Cache,
153+
Temperature: tool.Parameters.Temperature,
154+
InternalSystemPrompt: tool.Parameters.InternalPrompt,
187155
}
188156

189157
for _, subToolName := range tool.Parameters.Tools {
@@ -193,10 +161,9 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
193161
}
194162
args := subTool.Parameters.Arguments
195163
if args == nil && !subTool.IsCommand() {
196-
args = &DefaultToolSchema
164+
args = &system.DefaultToolSchema
197165
}
198166
completion.Tools = append(completion.Tools, types.CompletionTool{
199-
Type: types.CompletionToolTypeFunction,
200167
Function: types.CompletionFunctionDefinition{
201168
Name: subToolName,
202169
Description: subTool.Parameters.Description,
@@ -207,12 +174,8 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
207174

208175
if tool.Instructions != "" {
209176
completion.Messages = append(completion.Messages, types.CompletionMessage{
210-
Role: types.CompletionMessageRoleTypeSystem,
211-
Content: []types.ContentPart{
212-
{
213-
Text: tool.Instructions,
214-
},
215-
},
177+
Role: types.CompletionMessageRoleTypeSystem,
178+
Content: types.Text(tool.Instructions),
216179
})
217180
}
218181

@@ -230,7 +193,7 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
230193

231194
func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) {
232195
var (
233-
progress = make(chan openai.Status)
196+
progress = make(chan types.CompletionStatus)
234197
ret = Return{
235198
State: state,
236199
Calls: map[string]Call{},
@@ -241,6 +204,7 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) {
241204
// ensure we aren't writing to the channel anymore on exit
242205
wg.Add(1)
243206
defer wg.Wait()
207+
defer close(progress)
244208

245209
go func() {
246210
defer wg.Done()
@@ -251,8 +215,7 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) {
251215
}
252216
}()
253217

254-
resp, err := e.Client.Call(ctx, state.Completion, progress)
255-
close(progress)
218+
resp, err := e.Model.Call(ctx, state.Completion, progress)
256219
if err != nil {
257220
return nil, err
258221
}

pkg/hash/sha256.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,16 @@ import (
66
"encoding/json"
77
)
88

9+
// Digest returns a hex-encoded digest of the JSON encoding of obj.
// Note that it uses sha256.Sum224 — the truncated SHA-224 variant —
// even though it lives alongside the SHA-256 helpers in this file.
// It panics if obj cannot be marshaled to JSON, which indicates a
// programmer error rather than a runtime condition.
func Digest(obj any) string {
	encoded, err := json.Marshal(obj)
	if err != nil {
		panic(err)
	}

	sum := sha256.Sum224(encoded)
	return hex.EncodeToString(sum[:])
}
18+
919
func Encode(obj any) string {
1020
data, err := json.Marshal(obj)
1121
if err != nil {

0 commit comments

Comments
 (0)