Skip to content

Commit f699bcd

Browse files
Include task runner (#7)
1 parent af9fc37 commit f699bcd

File tree

4 files changed

+231
-41
lines changed

4 files changed

+231
-41
lines changed

README.md

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,49 +5,41 @@
55
[![Go Report Card](https://goreportcard.com/badge/github.com/StudioSol/async)](https://goreportcard.com/report/github.com/StudioSol/async)
66
[![GoDoc](https://godoc.org/github.com/StudioSol/async?status.svg)](https://godoc.org/github.com/StudioSol/async)
77

8-
Provides a safe way to execute `fns`'s functions asynchronously, recovering them in case of panic. It also provides an error stack aiming to facilitate fail causes discovery.
8+
Provides a safe way to execute functions asynchronously, recovering them in case of panic. It also provides an error stack aiming to facilitate fail causes discovery, and a simple way to control execution flow without `WaitGroup`.
99

1010
### Usage
1111
```go
12-
func InsertAsynchronously(ctx context.Context) error {
13-
transaction := db.Transaction().Begin()
14-
15-
err := async.Run(ctx,
16-
func(_ context.Context) error {
17-
_, err := transaction.Exec(`
18-
INSERT INTO foo (bar)
19-
VALUES ('Hello')
20-
`)
21-
22-
return err
23-
},
24-
25-
func(_ context.Context) error {
26-
_, err := transaction.Exec(`
27-
INSERT INTO foo (bar)
28-
VALUES ('world')
29-
`)
30-
31-
return err
32-
},
33-
34-
func(_ context.Context) error {
35-
_, err := transaction.Exec(`
36-
INSERT INTO foo (bar)
37-
VALUES ('asynchronously!')
38-
`)
39-
40-
return err
41-
},
42-
)
12+
var (
13+
user User
14+
songs []Songs
15+
photos []Photos
16+
)
17+
18+
err := async.Run(ctx,
19+
func(ctx context.Context) error {
20+
user, err = user.Get(ctx, id)
21+
return err
22+
},
23+
func(ctx context.Context) error {
24+
songs, err = song.GetByUserID(ctx, id)
25+
return err
26+
},
27+
func(ctx context.Context) error {
28+
photos, err = photo.GetByUserID(ctx, id)
29+
return err
30+
},
31+
)
32+
33+
if err != nil {
34+
log.Error(err)
35+
}
36+
```
4337

44-
if err != nil {
45-
e := transaction.Rollback()
46-
log.IfError(e)
47-
return err
48-
}
38+
You can also limit the number of asynchronous tasks
4939

50-
return transaction.Commit()
40+
```go
41+
runner := async.NewRunner(tasks...).WithLimit(3)
42+
if err := runner.Run(ctx); err != nil {
43+
log.Error(e)
5144
}
52-
5345
```

async_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
package async_test
22

33
import (
4+
"context"
45
"errors"
56
"sync"
67
"testing"
78
"time"
89

9-
"context"
10-
1110
"github.com/StudioSol/async"
1211
. "github.com/smartystreets/goconvey/convey"
1312
)

runner.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package async
2+
3+
import (
4+
"context"
5+
"sync"
6+
)
7+
8+
type Runner struct {
9+
sync.Mutex
10+
tasks []Task
11+
errs []error
12+
limit int
13+
waitErrors bool
14+
}
15+
16+
// NewRunner creates a new task manager to control async functions.
17+
func NewRunner(tasks ...Task) *Runner {
18+
return &Runner{
19+
tasks: tasks,
20+
limit: len(tasks),
21+
}
22+
}
23+
24+
// WaitErrors tells the runner to wait for the response from all functions instead of cancelling them all when the first error occurs.
25+
func (r *Runner) WaitErrors() *Runner {
26+
r.waitErrors = true
27+
return r
28+
}
29+
30+
// WithLimit defines a limit for concurrent tasks execution
31+
func (r *Runner) WithLimit(limit int) *Runner {
32+
r.limit = limit
33+
return r
34+
}
35+
36+
// AllErrors returns all errors reported by functions
37+
func (r *Runner) AllErrors() []error {
38+
return r.errs
39+
}
40+
41+
// registerErr store an error to final report
42+
func (r *Runner) registerErr(err error) {
43+
r.Lock()
44+
defer r.Unlock()
45+
if err != nil {
46+
r.errs = append(r.errs, err)
47+
}
48+
}
49+
50+
// wrapperChannel converts a given Task to a channel of errors
51+
func wrapperChannel(ctx context.Context, task Task) chan error {
52+
cerr := make(chan error, 1)
53+
go func() {
54+
cerr <- task(ctx)
55+
close(cerr)
56+
}()
57+
return cerr
58+
}
59+
60+
// Run starts the task manager and returns the first error or nil if succeed
61+
func (r *Runner) Run(parentCtx context.Context) error {
62+
ctx, cancel := context.WithCancel(parentCtx)
63+
cerr := make(chan error, len(r.tasks))
64+
queue := make(chan struct{}, r.limit)
65+
var wg sync.WaitGroup
66+
wg.Add(len(r.tasks))
67+
for _, task := range r.tasks {
68+
queue <- struct{}{}
69+
go func(fn func(context.Context) error) {
70+
defer func() {
71+
<-queue
72+
wg.Done()
73+
safePanic(cerr)
74+
}()
75+
76+
select {
77+
case <-parentCtx.Done():
78+
cerr <- parentCtx.Err()
79+
r.registerErr(parentCtx.Err())
80+
case err := <-wrapperChannel(ctx, fn):
81+
cerr <- err
82+
r.registerErr(err)
83+
}
84+
}(task)
85+
}
86+
87+
go func() {
88+
wg.Wait()
89+
cancel()
90+
close(cerr)
91+
}()
92+
93+
var firstErr error
94+
for err := range cerr {
95+
if err != nil && firstErr == nil {
96+
firstErr = err
97+
if r.waitErrors {
98+
continue
99+
}
100+
cancel()
101+
return firstErr
102+
}
103+
}
104+
105+
return firstErr
106+
}

runner_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package async
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestRunner_AllErrors(t *testing.T) {
13+
expectErr := errors.New("fail")
14+
runner := NewRunner(func(context.Context) error {
15+
return expectErr
16+
}).WaitErrors()
17+
err := runner.Run(context.Background())
18+
require.Equal(t, expectErr, err)
19+
require.Len(t, runner.AllErrors(), 1)
20+
require.Equal(t, expectErr, runner.AllErrors()[0])
21+
}
22+
23+
func TestRunner_WaitErrors(t *testing.T) {
24+
expectErrOne := errors.New("fail")
25+
expectErrTwo := errors.New("fail")
26+
runner := NewRunner(func(context.Context) error {
27+
return expectErrOne
28+
}, func(context.Context) error {
29+
return expectErrTwo
30+
}).WaitErrors()
31+
err := runner.Run(context.Background())
32+
require.False(t, err != expectErrOne && err != expectErrTwo)
33+
require.Len(t, runner.AllErrors(), 2)
34+
}
35+
36+
func TestRunner_Run(t *testing.T) {
37+
calledFist := false
38+
calledSecond := false
39+
runner := NewRunner(func(context.Context) error {
40+
calledFist = true
41+
return nil
42+
}, func(context.Context) error {
43+
calledSecond = true
44+
return nil
45+
})
46+
err := runner.Run(context.Background())
47+
require.Nil(t, err)
48+
require.True(t, calledFist)
49+
require.True(t, calledSecond)
50+
}
51+
52+
func TestRunner_WithLimit(t *testing.T) {
53+
order := 1
54+
runner := NewRunner(func(context.Context) error {
55+
require.Equal(t, 1, order)
56+
order++
57+
return nil
58+
}, func(context.Context) error {
59+
require.Equal(t, 2, order)
60+
order++
61+
return nil
62+
}).WithLimit(1)
63+
err := runner.Run(context.Background())
64+
require.Nil(t, err)
65+
}
66+
67+
func TestRunner_ContextCancelled(t *testing.T) {
68+
ctx, cancel := context.WithCancel(context.Background())
69+
70+
start := time.Now()
71+
runner := NewRunner(func(context.Context) error {
72+
cancel()
73+
time.Sleep(time.Minute)
74+
return nil
75+
})
76+
err := runner.Run(ctx)
77+
require.True(t, time.Since(start) < time.Minute)
78+
require.Equal(t, context.Canceled, err)
79+
}
80+
81+
func TestRunner_ContextTimeout(t *testing.T) {
82+
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
83+
defer cancel()
84+
85+
start := time.Now()
86+
runner := NewRunner(func(context.Context) error {
87+
time.Sleep(time.Minute)
88+
return nil
89+
})
90+
err := runner.Run(ctx)
91+
require.True(t, time.Since(start) < time.Minute)
92+
require.Equal(t, context.DeadlineExceeded, err)
93+
}

0 commit comments

Comments
 (0)