diff --git a/.gitignore b/.gitignore index 12d2326c..73f4d464 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ **/node_modules/ **/package-lock.json **/__pycache__ +/docs/yarn.lock diff --git a/README.md b/README.md index 8f8193b9..99a19a6f 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,15 @@ Download and install the archive for your platform and architecture from the [re export OPENAI_API_KEY="your-api-key" ``` +Alternatively Azure OpenAI can be utilized + +```shell +export OPENAI_API_KEY="your-api-key" +export OPENAI_BASE_URL="your-endpiont" +export OPENAI_API_TYPE="AZURE" +export OPENAI_AZURE_DEPLOYMENT="your-deployment-name" +``` + #### Windows ```powershell diff --git a/go.mod b/go.mod index 4347ea7b..1d73035b 100644 --- a/go.mod +++ b/go.mod @@ -2,8 +2,6 @@ module github.com/gptscript-ai/gptscript go 1.22.0 -replace github.com/sashabaranov/go-openai => github.com/gptscript-ai/go-openai v0.0.0-20240227161457-daa30caa3185 - require ( github.com/BurntSushi/locker v0.0.0-20171006230638-a6e239ea1c69 github.com/acorn-io/broadcaster v0.0.0-20240105011354-bfadd4a7b45d @@ -14,7 +12,7 @@ require ( github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056 github.com/olahol/melody v1.1.4 github.com/rs/cors v1.10.1 - github.com/sashabaranov/go-openai v1.18.3 + github.com/sashabaranov/go-openai v1.20.1 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.0 github.com/stretchr/testify v1.8.4 diff --git a/go.sum b/go.sum index ab8fa286..4ad85942 100644 --- a/go.sum +++ b/go.sum @@ -40,8 +40,6 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaU github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= -github.com/gptscript-ai/go-openai v0.0.0-20240227161457-daa30caa3185 h1:+TfC9DYtWuexdL7x1lIdD1HP61IStb3ZTj/byBdiWs0= -github.com/gptscript-ai/go-openai v0.0.0-20240227161457-daa30caa3185/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/hexops/autogold v0.8.1/go.mod h1:97HLDXyG23akzAoRYJh/2OBs3kd80eHyKPvZw0S5ZBY= github.com/hexops/autogold v1.3.1 h1:YgxF9OHWbEIUjhDbpnLhgVsjUDsiHDTyDfy2lrfdlzo= github.com/hexops/autogold v1.3.1/go.mod h1:sQO+mQUCVfxOKPht+ipDSkJ2SCJ7BNJVHZexsXqWMx4= @@ -97,6 +95,8 @@ github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= github.com/samber/slog-logrus v1.0.0 h1:SsrN0p9akjCEaYd42Q5GtisMdHm0q11UD4fp4XCZi04= github.com/samber/slog-logrus v1.0.0/go.mod h1:ZTdPCmVWljwlfjz6XflKNvW4TcmYlexz4HMUOO/42bI= +github.com/sashabaranov/go-openai v1.20.1 h1:cFnTixAtc0I0cCBFr8gkvEbGCm6Rjf2JyoVWCjXwy9g= +github.com/sashabaranov/go-openai v1.20.1/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= diff --git a/pkg/openai/client.go b/pkg/openai/client.go index 28f92680..1590f29c 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -27,6 +27,7 @@ const ( var ( key = os.Getenv("OPENAI_API_KEY") url = os.Getenv("OPENAI_URL") + azureModel = os.Getenv("OPENAI_AZURE_DEPLOYMENT") completionID int64 ) @@ -80,6 +81,15 @@ func complete(opts ...Options) (result Options, err error) { return result, err } +func AzureMapperFunction(model string) string { + if azureModel == "" { + return model + } + return map[string]string{ + openai.GPT4TurboPreview: azureModel, + }[model] +} + func NewClient(opts ...Options) (*Client, error) { opt, err := complete(opts...) if err != nil { @@ -89,6 +99,7 @@ func NewClient(opts ...Options) (*Client, error) { cfg := openai.DefaultConfig(opt.APIKey) if strings.Contains(string(opt.APIType), "AZURE") { cfg = openai.DefaultAzureConfig(key, url) + cfg.AzureModelMapperFunc = AzureMapperFunction } cfg.BaseURL = types.FirstSet(opt.BaseURL, cfg.BaseURL) @@ -236,15 +247,16 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques } request := openai.ChatCompletionRequest{ - Model: messageRequest.Model, - Messages: msgs, - MaxTokens: messageRequest.MaxTokens, - Temperature: messageRequest.Temperature, - Grammar: messageRequest.Grammar, + Model: messageRequest.Model, + Messages: msgs, + MaxTokens: messageRequest.MaxTokens, } - if request.Temperature == nil { - request.Temperature = new(float32) + if messageRequest.Temperature == nil { + // this is a hack because the field is marked as omitempty, so we need it to be set to a non-zero value but arbitrarily small + request.Temperature = 1e-08 + } else { + request.Temperature = *messageRequest.Temperature } if messageRequest.JSONResponse { @@ -260,7 +272,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques } request.Tools = append(request.Tools, openai.Tool{ Type: openai.ToolTypeFunction, - Function: openai.FunctionDefinition{ + Function: &openai.FunctionDefinition{ Name: tool.Function.Name, Description: tool.Function.Description, Parameters: params,