Skip to content

Commit 690cedb

Browse files
Merge pull request #174 from ibuildthecloud/models
feat: add support for 3rd party model shims
2 parents 834f379 + bf43e68 commit 690cedb

File tree

14 files changed

+415
-130
lines changed

14 files changed

+415
-130
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ require (
1919
github.com/sirupsen/logrus v1.9.3
2020
github.com/spf13/cobra v1.8.0
2121
github.com/stretchr/testify v1.8.4
22+
golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc
2223
golang.org/x/sync v0.6.0
2324
golang.org/x/term v0.16.0
2425
)
@@ -66,7 +67,6 @@ require (
6667
github.com/therootcompany/xz v1.0.1 // indirect
6768
github.com/ulikunitz/xz v0.5.10 // indirect
6869
go4.org v0.0.0-20200411211856-f5505b9728dd // indirect
69-
golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc // indirect
7070
golang.org/x/mod v0.15.0 // indirect
7171
golang.org/x/net v0.20.0 // indirect
7272
golang.org/x/sys v0.16.0 // indirect

pkg/cache/cache.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ func New(opts ...Options) (*Client, error) {
5858
}, nil
5959
}
6060

61+
func (c *Client) CacheDir() string {
62+
return c.dir
63+
}
64+
6165
func (c *Client) Store(key string, content []byte) error {
6266
if c == nil || c.noop {
6367
return nil

pkg/cli/gptscript.go

Lines changed: 29 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package cli
22

33
import (
4-
"context"
54
"fmt"
65
"io"
76
"os"
@@ -12,15 +11,12 @@ import (
1211
"github.com/gptscript-ai/gptscript/pkg/builtin"
1312
"github.com/gptscript-ai/gptscript/pkg/cache"
1413
"github.com/gptscript-ai/gptscript/pkg/confirm"
15-
"github.com/gptscript-ai/gptscript/pkg/engine"
14+
"github.com/gptscript-ai/gptscript/pkg/gptscript"
1615
"github.com/gptscript-ai/gptscript/pkg/input"
17-
"github.com/gptscript-ai/gptscript/pkg/llm"
1816
"github.com/gptscript-ai/gptscript/pkg/loader"
1917
"github.com/gptscript-ai/gptscript/pkg/monitor"
2018
"github.com/gptscript-ai/gptscript/pkg/mvl"
2119
"github.com/gptscript-ai/gptscript/pkg/openai"
22-
"github.com/gptscript-ai/gptscript/pkg/repos/runtimes"
23-
"github.com/gptscript-ai/gptscript/pkg/runner"
2420
"github.com/gptscript-ai/gptscript/pkg/server"
2521
"github.com/gptscript-ai/gptscript/pkg/types"
2622
"github.com/gptscript-ai/gptscript/pkg/version"
@@ -50,8 +46,6 @@ type GPTScript struct {
5046
Server bool `usage:"Start server"`
5147
ListenAddress string `usage:"Server listen address" default:"127.0.0.1:9090"`
5248
Chdir string `usage:"Change current working directory" short:"C"`
53-
54-
_client llm.Client `usage:"-"`
5549
}
5650

5751
func New() *cobra.Command {
@@ -78,33 +72,6 @@ func (r *GPTScript) Customize(cmd *cobra.Command) {
7872
}
7973
}
8074

81-
func (r *GPTScript) getClient(ctx context.Context) (llm.Client, error) {
82-
if r._client != nil {
83-
return r._client, nil
84-
}
85-
86-
cacheClient, err := cache.New(cache.Options(r.CacheOptions))
87-
if err != nil {
88-
return nil, err
89-
}
90-
91-
oaClient, err := openai.NewClient(openai.Options(r.OpenAIOptions), openai.Options{
92-
Cache: cacheClient,
93-
})
94-
if err != nil {
95-
return nil, err
96-
}
97-
98-
registry := llm.NewRegistry()
99-
100-
if err := registry.AddClient(ctx, oaClient); err != nil {
101-
return nil, err
102-
}
103-
104-
r._client = registry
105-
return r._client, nil
106-
}
107-
10875
func (r *GPTScript) listTools() error {
10976
var lines []string
11077
for _, tool := range builtin.ListTools() {
@@ -114,24 +81,6 @@ func (r *GPTScript) listTools() error {
11481
return nil
11582
}
11683

117-
func (r *GPTScript) listModels(ctx context.Context) error {
118-
c, err := r.getClient(ctx)
119-
if err != nil {
120-
return err
121-
}
122-
123-
models, err := c.ListModels(ctx)
124-
if err != nil {
125-
return err
126-
}
127-
128-
for _, model := range models {
129-
fmt.Println(model)
130-
}
131-
132-
return nil
133-
}
134-
13584
func (r *GPTScript) Pre(*cobra.Command, []string) error {
13685
// chdir as soon as possible
13786
if r.Chdir != "" {
@@ -164,37 +113,50 @@ func (r *GPTScript) Pre(*cobra.Command, []string) error {
164113
}
165114

166115
func (r *GPTScript) Run(cmd *cobra.Command, args []string) error {
167-
defer engine.CloseDaemons()
168-
169-
if r.ListModels {
170-
return r.listModels(cmd.Context())
171-
}
172-
173-
if r.ListTools {
174-
return r.listTools()
116+
gptOpt := gptscript.Options{
117+
Cache: cache.Options(r.CacheOptions),
118+
OpenAI: openai.Options(r.OpenAIOptions),
119+
Monitor: monitor.Options(r.DisplayOptions),
120+
Quiet: r.Quiet,
121+
Env: os.Environ(),
175122
}
176123

177124
if r.Server {
178-
c, err := r.getClient(cmd.Context())
179-
if err != nil {
180-
return err
181-
}
182-
s, err := server.New(c, server.Options{
125+
s, err := server.New(&server.Options{
183126
ListenAddress: r.ListenAddress,
127+
GPTScript: gptOpt,
184128
})
185129
if err != nil {
186130
return err
187131
}
132+
defer s.Close()
188133
return s.Start(cmd.Context())
189134
}
190135

136+
gptScript, err := gptscript.New(&gptOpt)
137+
if err != nil {
138+
return err
139+
}
140+
defer gptScript.Close()
141+
142+
if r.ListModels {
143+
models, err := gptScript.ListModels(cmd.Context())
144+
if err != nil {
145+
return err
146+
}
147+
fmt.Println(strings.Join(models, "\n"))
148+
}
149+
150+
if r.ListTools {
151+
return r.listTools()
152+
}
153+
191154
if len(args) == 0 {
192155
return cmd.Help()
193156
}
194157

195158
var (
196159
prg types.Program
197-
err error
198160
)
199161

200162
if args[0] == "-" {
@@ -227,21 +189,6 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) error {
227189
return assemble.Assemble(prg, out)
228190
}
229191

230-
client, err := r.getClient(cmd.Context())
231-
if err != nil {
232-
return err
233-
}
234-
235-
runner, err := runner.New(client, runner.Options{
236-
MonitorFactory: monitor.NewConsole(monitor.Options(r.DisplayOptions), monitor.Options{
237-
DisplayProgress: !*r.Quiet,
238-
}),
239-
RuntimeManager: runtimes.Default(cache.Complete(cache.Options(r.CacheOptions)).CacheDir),
240-
})
241-
if err != nil {
242-
return err
243-
}
244-
245192
toolInput, err := input.FromCLI(r.Input, args)
246193
if err != nil {
247194
return err
@@ -251,7 +198,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) error {
251198
if r.Confirm {
252199
ctx = confirm.WithConfirm(ctx, confirm.TextPrompt{})
253200
}
254-
s, err := runner.Run(ctx, prg, os.Environ(), toolInput)
201+
s, err := gptScript.Run(ctx, prg, os.Environ(), toolInput)
255202
if err != nil {
256203
return err
257204
}

pkg/engine/cmd.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,13 @@ var ignoreENV = map[string]struct{}{
109109
"GPTSCRIPT_TOOL_DIR": {},
110110
}
111111

112-
func appendEnv(env []string, k, v string) []string {
113-
for _, k := range []string{k, strings.ToUpper(strings.ReplaceAll(k, "-", "_"))} {
112+
func appendEnv(envs []string, k, v string) []string {
113+
for _, k := range []string{k, env.ToEnvLike(k)} {
114114
if _, ignore := ignoreENV[k]; !ignore {
115-
env = append(env, k+"="+v)
115+
envs = append(envs, k+"="+v)
116116
}
117117
}
118-
return env
118+
return envs
119119
}
120120

121121
func appendInputAsEnv(env []string, input string) []string {

pkg/engine/http.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
5555
toolURL = parsed.String()
5656
}
5757

58+
if tool.Blocking {
59+
return &Return{
60+
Result: &toolURL,
61+
}, nil
62+
}
63+
5864
req, err := http.NewRequestWithContext(ctx, http.MethodPost, toolURL, strings.NewReader(input))
5965
if err != nil {
6066
return nil, err

pkg/env/env.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ func execEquals(bin, check string) bool {
1212
bin == check+".exe"
1313
}
1414

15+
func ToEnvLike(v string) string {
16+
return strings.ToUpper(strings.ReplaceAll(v, "-", "_"))
17+
}
18+
1519
func Matches(cmd []string, bin string) bool {
1620
switch len(cmd) {
1721
case 0:

pkg/gptscript/gptscript.go

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package gptscript
2+
3+
import (
4+
"context"
5+
"os"
6+
7+
"github.com/gptscript-ai/gptscript/pkg/cache"
8+
"github.com/gptscript-ai/gptscript/pkg/engine"
9+
"github.com/gptscript-ai/gptscript/pkg/llm"
10+
"github.com/gptscript-ai/gptscript/pkg/monitor"
11+
"github.com/gptscript-ai/gptscript/pkg/openai"
12+
"github.com/gptscript-ai/gptscript/pkg/remote"
13+
"github.com/gptscript-ai/gptscript/pkg/repos/runtimes"
14+
"github.com/gptscript-ai/gptscript/pkg/runner"
15+
"github.com/gptscript-ai/gptscript/pkg/types"
16+
)
17+
18+
type GPTScript struct {
19+
Registry *llm.Registry
20+
Runner *runner.Runner
21+
}
22+
23+
type Options struct {
24+
Cache cache.Options
25+
OpenAI openai.Options
26+
Monitor monitor.Options
27+
Runner runner.Options
28+
Quiet *bool `usage:"No output logging (set --quiet=false to force on even when there is no TTY)" short:"q"`
29+
Env []string `usage:"-"`
30+
}
31+
32+
func complete(opts *Options) (result *Options) {
33+
result = opts
34+
if result == nil {
35+
result = &Options{}
36+
}
37+
if result.Quiet == nil {
38+
result.Quiet = new(bool)
39+
}
40+
if len(result.Env) == 0 {
41+
result.Env = os.Environ()
42+
}
43+
return
44+
}
45+
46+
func New(opts *Options) (*GPTScript, error) {
47+
opts = complete(opts)
48+
49+
registry := llm.NewRegistry()
50+
51+
cacheClient, err := cache.New(opts.Cache)
52+
if err != nil {
53+
return nil, err
54+
}
55+
56+
oAIClient, err := openai.NewClient(append([]openai.Options{opts.OpenAI}, openai.Options{
57+
Cache: cacheClient,
58+
})...)
59+
if err != nil {
60+
return nil, err
61+
}
62+
63+
if err := registry.AddClient(oAIClient); err != nil {
64+
return nil, err
65+
}
66+
67+
if opts.Runner.MonitorFactory == nil {
68+
opts.Runner.MonitorFactory = monitor.NewConsole(append([]monitor.Options{opts.Monitor}, monitor.Options{
69+
DisplayProgress: !*opts.Quiet,
70+
})...)
71+
}
72+
73+
if opts.Runner.RuntimeManager == nil {
74+
opts.Runner.RuntimeManager = runtimes.Default(cacheClient.CacheDir())
75+
}
76+
77+
runner, err := runner.New(registry, opts.Runner)
78+
if err != nil {
79+
return nil, err
80+
}
81+
82+
remoteClient := remote.New(runner, opts.Env, cacheClient)
83+
84+
if err := registry.AddClient(remoteClient); err != nil {
85+
return nil, err
86+
}
87+
88+
return &GPTScript{
89+
Registry: registry,
90+
Runner: runner,
91+
}, nil
92+
}
93+
94+
func (g *GPTScript) Run(ctx context.Context, prg types.Program, envs []string, input string) (string, error) {
95+
return g.Runner.Run(ctx, prg, envs, input)
96+
}
97+
98+
func (g *GPTScript) Close() {
99+
engine.CloseDaemons()
100+
}
101+
102+
func (g *GPTScript) GetModel() engine.Model {
103+
return g.Registry
104+
}
105+
106+
func (g *GPTScript) ListModels(ctx context.Context) ([]string, error) {
107+
return g.Registry.ListModels(ctx)
108+
}

0 commit comments

Comments
 (0)