Skip to content

Commit c0b8139

Browse files
chore: Refactor CLI root and tool call sorting for code base improvements
- Renamed Root struct to GPTScript and added Output field to streamline CLI usage. - Enhanced tool call sorting in the engine package for consistent ordering. - Implemented seed generation method for API call consistency in the OpenAI client module. This commit enhances the overall code quality and user experience by introducing better struct naming conventions, ensuring deterministic sorting of tool calls, and improving cache key generation for API responses.
1 parent 6887aeb commit c0b8139

File tree

3 files changed

+51
-16
lines changed

3 files changed

+51
-16
lines changed

pkg/cli/root.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,22 @@ import (
1313
"golang.org/x/term"
1414
)
1515

16-
type Root struct {
16+
type GPTScript struct {
1717
runner.Options
18+
Output string `usage:"Save output to a file" short:"o"`
1819
}
1920

2021
func New() *cobra.Command {
21-
return cmd.Command(&Root{})
22+
return cmd.Command(&GPTScript{})
2223
}
2324

24-
func (r *Root) Customize(cmd *cobra.Command) {
25+
func (r *GPTScript) Customize(cmd *cobra.Command) {
2526
cmd.Use = version.ProgramName
2627
cmd.Args = cobra.MinimumNArgs(1)
2728
cmd.Flags().SetInterspersed(false)
2829
}
2930

30-
func (r *Root) Run(cmd *cobra.Command, args []string) error {
31+
func (r *GPTScript) Run(cmd *cobra.Command, args []string) error {
3132
in, err := os.Open(args[0])
3233
if err != nil {
3334
return err
@@ -55,9 +56,17 @@ func (r *Root) Run(cmd *cobra.Command, args []string) error {
5556
return err
5657
}
5758

58-
fmt.Print(s)
59-
if !strings.HasSuffix(s, "\n") {
60-
fmt.Println()
59+
if r.Output != "" {
60+
err = os.WriteFile(r.Output, []byte(s), 0644)
61+
if err != nil {
62+
return err
63+
}
64+
} else {
65+
fmt.Print(s)
66+
if !strings.HasSuffix(s, "\n") {
67+
fmt.Println()
68+
}
6169
}
62-
return err
70+
71+
return nil
6372
}

pkg/engine/engine.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -267,12 +267,12 @@ func (e *Engine) Continue(ctx context.Context, state *State, results ...CallResu
267267
}
268268

269269
var (
270-
added bool
271-
pendingIDs []string
270+
added bool
271+
pendingToolCalls []types.CompletionToolCall
272272
)
273273

274274
for id, pending := range state.Pending {
275-
pendingIDs = append(pendingIDs, id)
275+
pendingToolCalls = append(pendingToolCalls, pending)
276276
if _, ok := state.Results[id]; !ok {
277277
ret.Calls[id] = Call{
278278
ToolName: pending.Function.Name,
@@ -285,11 +285,18 @@ func (e *Engine) Continue(ctx context.Context, state *State, results ...CallResu
285285
return &ret, nil
286286
}
287287

288-
sort.Strings(pendingIDs)
288+
sort.Slice(pendingToolCalls, func(i, j int) bool {
289+
left := pendingToolCalls[i].Function.Name + pendingToolCalls[i].Function.Arguments
290+
right := pendingToolCalls[j].Function.Name + pendingToolCalls[j].Function.Arguments
291+
if left == right {
292+
return pendingToolCalls[i].ID < pendingToolCalls[j].ID
293+
}
294+
return left < right
295+
})
289296

290-
for _, id := range pendingIDs {
291-
pending := state.Pending[id]
292-
if result, ok := state.Results[id]; ok {
297+
for _, pending := range pendingToolCalls {
298+
pending := pending
299+
if result, ok := state.Results[pending.ID]; ok {
293300
added = true
294301
state.Completion.Messages = append(state.Completion.Messages, types.CompletionMessage{
295302
Role: types.CompletionMessageRoleTypeTool,

pkg/openai/client.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,25 @@ func (c *Client) cacheKey(request openai.ChatCompletionRequest) string {
5858
return hash.Encode(request)
5959
}
6060

61+
func (c *Client) seed(request openai.ChatCompletionRequest) int {
62+
newRequest := request
63+
newRequest.Messages = nil
64+
65+
for _, msg := range request.Messages {
66+
newMsg := msg
67+
newMsg.ToolCalls = nil
68+
newMsg.ToolCallID = ""
69+
70+
for _, tool := range msg.ToolCalls {
71+
tool.ID = ""
72+
newMsg.ToolCalls = append(newMsg.ToolCalls, tool)
73+
}
74+
75+
newRequest.Messages = append(newRequest.Messages, newMsg)
76+
}
77+
return hash.Seed(newRequest)
78+
}
79+
6180
func (c *Client) fromCache(ctx context.Context, messageRequest types.CompletionRequest, request openai.ChatCompletionRequest) (result []openai.ChatCompletionStreamResponse, _ bool, _ error) {
6281
if messageRequest.Cache != nil && !*messageRequest.Cache {
6382
return nil, false, nil
@@ -210,7 +229,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
210229
}
211230
}
212231

213-
request.Seed = ptr(hash.Seed(request))
232+
request.Seed = ptr(c.seed(request))
214233
response, ok, err := c.fromCache(ctx, messageRequest, request)
215234
if err != nil {
216235
return nil, err

0 commit comments

Comments
 (0)