mirror of https://github.com/restic/restic.git
93 lines
1.8 KiB
Go
93 lines
1.8 KiB
Go
package worker_test
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
|
|
"github.com/restic/restic/internal/errors"
|
|
|
|
"github.com/restic/restic/internal/worker"
|
|
)
|
|
|
|
const concurrency = 10
|
|
|
|
var errTooLarge = errors.New("too large")
|
|
|
|
func square(ctx context.Context, job worker.Job) (interface{}, error) {
|
|
n := job.Data.(int)
|
|
if n > 2000 {
|
|
return nil, errTooLarge
|
|
}
|
|
return n * n, nil
|
|
}
|
|
|
|
func newBufferedPool(ctx context.Context, bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Job, *worker.Pool) {
|
|
inCh := make(chan worker.Job, bufsize)
|
|
outCh := make(chan worker.Job, bufsize)
|
|
|
|
return inCh, outCh, worker.New(ctx, n, f, inCh, outCh)
|
|
}
|
|
|
|
func TestPool(t *testing.T) {
|
|
inCh, outCh, p := newBufferedPool(context.TODO(), 200, concurrency, square)
|
|
|
|
for i := 0; i < 150; i++ {
|
|
inCh <- worker.Job{Data: i}
|
|
}
|
|
|
|
close(inCh)
|
|
p.Wait()
|
|
|
|
for res := range outCh {
|
|
if res.Error != nil {
|
|
t.Errorf("unexpected error for job %v received: %v", res.Data, res.Error)
|
|
continue
|
|
}
|
|
|
|
n := res.Data.(int)
|
|
m := res.Result.(int)
|
|
|
|
if m != n*n {
|
|
t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestPoolErrors(t *testing.T) {
|
|
inCh, outCh, p := newBufferedPool(context.TODO(), 200, concurrency, square)
|
|
|
|
for i := 0; i < 150; i++ {
|
|
inCh <- worker.Job{Data: i + 1900}
|
|
}
|
|
|
|
close(inCh)
|
|
p.Wait()
|
|
|
|
for res := range outCh {
|
|
n := res.Data.(int)
|
|
|
|
if n > 2000 {
|
|
if res.Error == nil {
|
|
t.Errorf("expected error not found, result is %v", res)
|
|
continue
|
|
}
|
|
|
|
if res.Error != errTooLarge {
|
|
t.Errorf("unexpected error found, result is %v", res)
|
|
}
|
|
|
|
continue
|
|
} else {
|
|
if res.Error != nil {
|
|
t.Errorf("unexpected error for job %v received: %v", res.Data, res.Error)
|
|
continue
|
|
}
|
|
}
|
|
|
|
m := res.Result.(int)
|
|
if m != n*n {
|
|
t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m)
|
|
}
|
|
}
|
|
}
|