diff --git a/src/cmds/restic/cmd_dump.go b/src/cmds/restic/cmd_dump.go index f90fe5810..72a9d85b8 100644 --- a/src/cmds/restic/cmd_dump.go +++ b/src/cmds/restic/cmd_dump.go @@ -8,11 +8,14 @@ import ( "io" "os" - "github.com/juju/errors" "restic" "restic/backend" "restic/pack" "restic/repository" + + "restic/worker" + + "github.com/juju/errors" ) type CmdDump struct { @@ -32,7 +35,7 @@ func init() { } func (cmd CmdDump) Usage() string { - return "[indexes|snapshots|trees|all]" + return "[indexes|snapshots|trees|all|packs]" } func prettyPrintJSON(wr io.Writer, item interface{}) error { @@ -98,6 +101,82 @@ func printTrees(repo *repository.Repository, wr io.Writer) error { return nil } +const dumpPackWorkers = 10 + +// Pack is the struct used in printPacks. +type Pack struct { + Name string `json:"name"` + + Blobs []Blob `json:"blobs"` +} + +// Blob is the struct used in printPacks. +type Blob struct { + Type pack.BlobType `json:"type"` + Length uint `json:"length"` + ID backend.ID `json:"id"` + Offset uint `json:"offset"` +} + +func printPacks(repo *repository.Repository, wr io.Writer) error { + done := make(chan struct{}) + defer close(done) + + f := func(job worker.Job, done <-chan struct{}) (interface{}, error) { + name := job.Data.(string) + + h := backend.Handle{Type: backend.Data, Name: name} + rd := backend.NewReadSeeker(repo.Backend(), h) + + unpacker, err := pack.NewUnpacker(repo.Key(), rd) + if err != nil { + return nil, err + } + + return unpacker.Entries, nil + } + + jobCh := make(chan worker.Job) + resCh := make(chan worker.Job) + wp := worker.New(dumpPackWorkers, f, jobCh, resCh) + + go func() { + for name := range repo.Backend().List(backend.Data, done) { + jobCh <- worker.Job{Data: name} + } + close(jobCh) + }() + + for job := range resCh { + name := job.Data.(string) + + if job.Error != nil { + fmt.Fprintf(os.Stderr, "error for pack %v: %v\n", name, job.Error) + continue + } + + entries := job.Result.([]pack.Blob) + p := Pack{ + Name: name, + Blobs: make([]Blob, len(entries)), + } + for i, blob := range entries { + p.Blobs[i] = Blob{ + Type: blob.Type, + Length: blob.Length, + ID: blob.ID, + Offset: blob.Offset, + } + } + + prettyPrintJSON(os.Stdout, p) + } + + wp.Wait() + + return nil +} + func (cmd CmdDump) DumpIndexes() error { done := make(chan struct{}) defer close(done) @@ -150,6 +229,8 @@ func (cmd CmdDump) Execute(args []string) error { return printSnapshots(repo, os.Stdout) case "trees": return printTrees(repo, os.Stdout) + case "packs": + return printPacks(repo, os.Stdout) case "all": fmt.Printf("snapshots:\n") err := printSnapshots(repo, os.Stdout) diff --git a/src/cmds/restic/cmd_optimize.go b/src/cmds/restic/cmd_optimize.go deleted file mode 100644 index bd6f26ddf..000000000 --- a/src/cmds/restic/cmd_optimize.go +++ /dev/null @@ -1,84 +0,0 @@ -package main - -import ( - "errors" - "fmt" - - "restic/backend" - "restic/checker" -) - -type CmdOptimize struct { - global *GlobalOptions -} - -func init() { - _, err := parser.AddCommand("optimize", - "optimize the repository", - "The optimize command reorganizes the repository and removes uneeded data", - &CmdOptimize{global: &globalOpts}) - if err != nil { - panic(err) - } -} - -func (cmd CmdOptimize) Usage() string { - return "[optimize-options]" -} - -func (cmd CmdOptimize) Execute(args []string) error { - if len(args) != 0 { - return errors.New("optimize has no arguments") - } - - repo, err := cmd.global.OpenRepository() - if err != nil { - return err - } - - cmd.global.Verbosef("Create exclusive lock for repository\n") - lock, err := lockRepoExclusive(repo) - defer unlockRepo(lock) - if err != nil { - return err - } - - chkr := checker.New(repo) - - cmd.global.Verbosef("Load indexes\n") - _, errs := chkr.LoadIndex() - - if len(errs) > 0 { - for _, err := range errs { - cmd.global.Warnf("error: %v\n", err) - } - return fmt.Errorf("LoadIndex returned errors") - } - - done := make(chan struct{}) - errChan := make(chan error) - go chkr.Structure(errChan, done) - - for err := range errChan { - if e, ok := err.(checker.TreeError); ok { - cmd.global.Warnf("error for tree %v:\n", e.ID.Str()) - for _, treeErr := range e.Errors { - cmd.global.Warnf(" %v\n", treeErr) - } - } else { - cmd.global.Warnf("error: %v\n", err) - } - } - - unusedBlobs := backend.NewIDSet(chkr.UnusedBlobs()...) - cmd.global.Verbosef("%d unused blobs found, repacking...\n", len(unusedBlobs)) - - repacker := checker.NewRepacker(repo, unusedBlobs) - err = repacker.Repack() - if err != nil { - return err - } - - cmd.global.Verbosef("repacking done\n") - return nil -} diff --git a/src/cmds/restic/cmd_rebuild_index.go b/src/cmds/restic/cmd_rebuild_index.go index cab8bbc46..e3e82684a 100644 --- a/src/cmds/restic/cmd_rebuild_index.go +++ b/src/cmds/restic/cmd_rebuild_index.go @@ -1,13 +1,13 @@ package main import ( - "bytes" "fmt" - + "os" "restic/backend" "restic/debug" "restic/pack" "restic/repository" + "restic/worker" ) type CmdRebuildIndex struct { @@ -26,164 +26,101 @@ func init() { } } -func (cmd CmdRebuildIndex) storeIndex(index *repository.Index) (*repository.Index, error) { - debug.Log("RebuildIndex.RebuildIndex", "saving index") - - cmd.global.Printf(" saving new index\n") - id, err := repository.SaveIndex(cmd.repo, index) - if err != nil { - debug.Log("RebuildIndex.RebuildIndex", "error saving index: %v", err) - return nil, err - } - - debug.Log("RebuildIndex.RebuildIndex", "index saved as %v", id.Str()) - index = repository.NewIndex() - - return index, nil -} - -func (cmd CmdRebuildIndex) RebuildIndex() error { - debug.Log("RebuildIndex.RebuildIndex", "start") +const rebuildIndexWorkers = 10 +func loadBlobsFromPacks(repo *repository.Repository) (packs map[backend.ID][]pack.Blob) { done := make(chan struct{}) defer close(done) - indexIDs := backend.NewIDSet() - for id := range cmd.repo.List(backend.Index, done) { - indexIDs.Insert(id) - } + f := func(job worker.Job, done <-chan struct{}) (interface{}, error) { + id := job.Data.(backend.ID) - cmd.global.Printf("rebuilding index from %d indexes\n", len(indexIDs)) + h := backend.Handle{Type: backend.Data, Name: id.String()} + rd := backend.NewReadSeeker(repo.Backend(), h) - debug.Log("RebuildIndex.RebuildIndex", "found %v indexes", len(indexIDs)) - - combinedIndex := repository.NewIndex() - packsDone := backend.NewIDSet() - - type Blob struct { - id backend.ID - tpe pack.BlobType - } - blobsDone := make(map[Blob]struct{}) - - i := 0 - for indexID := range indexIDs { - cmd.global.Printf(" loading index %v\n", i) - - debug.Log("RebuildIndex.RebuildIndex", "load index %v", indexID.Str()) - idx, err := repository.LoadIndex(cmd.repo, indexID) + unpacker, err := pack.NewUnpacker(repo.Key(), rd) if err != nil { - return err + return nil, err } - debug.Log("RebuildIndex.RebuildIndex", "adding blobs from index %v", indexID.Str()) - - for packedBlob := range idx.Each(done) { - packsDone.Insert(packedBlob.PackID) - b := Blob{ - id: packedBlob.ID, - tpe: packedBlob.Type, - } - if _, ok := blobsDone[b]; ok { - continue - } - - blobsDone[b] = struct{}{} - combinedIndex.Store(packedBlob) - } - - combinedIndex.AddToSupersedes(indexID) - - if repository.IndexFull(combinedIndex) { - combinedIndex, err = cmd.storeIndex(combinedIndex) - if err != nil { - return err - } - } - - i++ + return unpacker.Entries, nil } - var err error - if combinedIndex.Length() > 0 { - combinedIndex, err = cmd.storeIndex(combinedIndex) - if err != nil { - return err + jobCh := make(chan worker.Job) + resCh := make(chan worker.Job) + wp := worker.New(rebuildIndexWorkers, f, jobCh, resCh) + + go func() { + for id := range repo.List(backend.Data, done) { + jobCh <- worker.Job{Data: id} } - } + close(jobCh) + }() - cmd.global.Printf("removing %d old indexes\n", len(indexIDs)) - for id := range indexIDs { - debug.Log("RebuildIndex.RebuildIndex", "remove index %v", id.Str()) + packs = make(map[backend.ID][]pack.Blob) + for job := range resCh { + id := job.Data.(backend.ID) - err := cmd.repo.Backend().Remove(backend.Index, id.String()) - if err != nil { - debug.Log("RebuildIndex.RebuildIndex", "error removing index %v: %v", id.Str(), err) - return err - } - } - - cmd.global.Printf("checking for additional packs\n") - newPacks := 0 - var buf []byte - for packID := range cmd.repo.List(backend.Data, done) { - if packsDone.Has(packID) { + if job.Error != nil { + fmt.Fprintf(os.Stderr, "error for pack %v: %v\n", id, job.Error) continue } - debug.Log("RebuildIndex.RebuildIndex", "pack %v not indexed", packID.Str()) - newPacks++ + entries := job.Result.([]pack.Blob) + packs[id] = entries + } - var err error + wp.Wait() - h := backend.Handle{Type: backend.Data, Name: packID.String()} - buf, err = backend.LoadAll(cmd.repo.Backend(), h, buf) - if err != nil { - debug.Log("RebuildIndex.RebuildIndex", "error while loading pack %v", packID.Str()) - return fmt.Errorf("error while loading pack %v: %v", packID.Str(), err) - } + return packs +} - hash := backend.Hash(buf) - if !hash.Equal(packID) { - debug.Log("RebuildIndex.RebuildIndex", "Pack ID does not match, want %v, got %v", packID.Str(), hash.Str()) - return fmt.Errorf("Pack ID does not match, want %v, got %v", packID.Str(), hash.Str()) - } +func listIndexIDs(repo *repository.Repository) (list backend.IDs) { + done := make(chan struct{}) + for id := range repo.List(backend.Index, done) { + list = append(list, id) + } - up, err := pack.NewUnpacker(cmd.repo.Key(), bytes.NewReader(buf)) - if err != nil { - debug.Log("RebuildIndex.RebuildIndex", "error while unpacking pack %v", packID.Str()) - return err - } + return list +} - for _, blob := range up.Entries { - debug.Log("RebuildIndex.RebuildIndex", "pack %v: blob %v", packID.Str(), blob) - combinedIndex.Store(repository.PackedBlob{ - Type: blob.Type, - ID: blob.ID, +func (cmd CmdRebuildIndex) rebuildIndex() error { + debug.Log("RebuildIndex.RebuildIndex", "start rebuilding index") + + packs := loadBlobsFromPacks(cmd.repo) + cmd.global.Verbosef("loaded blobs from %d packs\n", len(packs)) + + idx := repository.NewIndex() + for packID, entries := range packs { + for _, entry := range entries { + pb := repository.PackedBlob{ + ID: entry.ID, + Type: entry.Type, + Length: entry.Length, + Offset: entry.Offset, PackID: packID, - Offset: blob.Offset, - Length: blob.Length, - }) - } - - if repository.IndexFull(combinedIndex) { - combinedIndex, err = cmd.storeIndex(combinedIndex) - if err != nil { - return err } + idx.Store(pb) } } - if combinedIndex.Length() > 0 { - combinedIndex, err = cmd.storeIndex(combinedIndex) + oldIndexes := listIndexIDs(cmd.repo) + idx.AddToSupersedes(oldIndexes...) + cmd.global.Printf(" saving new index\n") + id, err := repository.SaveIndex(cmd.repo, idx) + if err != nil { + debug.Log("RebuildIndex.RebuildIndex", "error saving index: %v", err) + return err + } + debug.Log("RebuildIndex.RebuildIndex", "new index saved as %v", id.Str()) + + for _, indexID := range oldIndexes { + err := cmd.repo.Backend().Remove(backend.Index, indexID.String()) if err != nil { - return err + cmd.global.Warnf("unable to remove index %v: %v\n", indexID.Str(), err) } } - cmd.global.Printf("added %d packs to the index\n", newPacks) - - debug.Log("RebuildIndex.RebuildIndex", "done") return nil } @@ -200,5 +137,5 @@ func (cmd CmdRebuildIndex) Execute(args []string) error { return err } - return cmd.RebuildIndex() + return cmd.rebuildIndex() } diff --git a/src/cmds/restic/integration_test.go b/src/cmds/restic/integration_test.go index c8be04c5a..8c42c7548 100644 --- a/src/cmds/restic/integration_test.go +++ b/src/cmds/restic/integration_test.go @@ -110,11 +110,6 @@ func cmdRebuildIndex(t testing.TB, global GlobalOptions) { OK(t, cmd.Execute(nil)) } -func cmdOptimize(t testing.TB, global GlobalOptions) { - cmd := &CmdOptimize{global: &global} - OK(t, cmd.Execute(nil)) -} - func cmdLs(t testing.TB, global GlobalOptions, snapshotID string) []string { var buf bytes.Buffer global.stdout = &buf @@ -771,25 +766,6 @@ var optimizeTests = []struct { }, } -func TestOptimizeRemoveUnusedBlobs(t *testing.T) { - for i, test := range optimizeTests { - withTestEnvironment(t, func(env *testEnvironment, global GlobalOptions) { - SetupTarTestFixture(t, env.base, test.testFilename) - - for id := range test.snapshots { - OK(t, removeFile(filepath.Join(env.repo, "snapshots", id.String()))) - } - - cmdOptimize(t, global) - output := cmdCheckOutput(t, global) - - if len(output) > 0 { - t.Errorf("expected no output for check in test %d, got:\n%v", i, output) - } - }) - } -} - func TestCheckRestoreNoLock(t *testing.T) { withTestEnvironment(t, func(env *testEnvironment, global GlobalOptions) { datafile := filepath.Join("testdata", "small-repo.tar.gz") diff --git a/src/restic/backend/readseeker.go b/src/restic/backend/readseeker.go new file mode 100644 index 000000000..ea063e3f3 --- /dev/null +++ b/src/restic/backend/readseeker.go @@ -0,0 +1,63 @@ +package backend + +import ( + "errors" + "io" +) + +type readSeeker struct { + be Backend + h Handle + t Type + name string + offset int64 + size int64 +} + +// NewReadSeeker returns an io.ReadSeeker for the given object in the backend. +func NewReadSeeker(be Backend, h Handle) io.ReadSeeker { + return &readSeeker{be: be, h: h} +} + +func (rd *readSeeker) Read(p []byte) (int, error) { + n, err := rd.be.Load(rd.h, p, rd.offset) + rd.offset += int64(n) + return n, err +} + +func (rd *readSeeker) Seek(offset int64, whence int) (n int64, err error) { + switch whence { + case 0: + rd.offset = offset + case 1: + rd.offset += offset + case 2: + if rd.size == 0 { + rd.size, err = rd.getSize() + if err != nil { + return 0, err + } + } + + pos := rd.size + offset + if pos < 0 { + return 0, errors.New("invalid offset, before start of blob") + } + + rd.offset = pos + return rd.offset, nil + default: + return 0, errors.New("invalid value for parameter whence") + } + + return rd.offset, nil +} + +func (rd *readSeeker) getSize() (int64, error) { + stat, err := rd.be.Stat(rd.h) + if err != nil { + return 0, err + } + + return stat.Size, nil +} diff --git a/src/restic/backend/readseeker_test.go b/src/restic/backend/readseeker_test.go new file mode 100644 index 000000000..013f2528e --- /dev/null +++ b/src/restic/backend/readseeker_test.go @@ -0,0 +1,114 @@ +package backend_test + +import ( + "bytes" + "io" + "math/rand" + "restic/backend" + "restic/backend/mem" + "testing" + + . "restic/test" +) + +func abs(a int) int { + if a < 0 { + return -a + } + + return a +} + +func loadAndCompare(t testing.TB, rd io.ReadSeeker, size int, offset int64, expected []byte) { + var ( + pos int64 + err error + ) + + if offset >= 0 { + pos, err = rd.Seek(offset, 0) + } else { + pos, err = rd.Seek(offset, 2) + } + if err != nil { + t.Errorf("Seek(%d, 0) returned error: %v", offset, err) + return + } + + if offset >= 0 && pos != offset { + t.Errorf("pos after seek is wrong, want %d, got %d", offset, pos) + } else if offset < 0 && pos != int64(size)+offset { + t.Errorf("pos after relative seek is wrong, want %d, got %d", int64(size)+offset, pos) + } + + buf := make([]byte, len(expected)) + n, err := rd.Read(buf) + + // if we requested data beyond the end of the file, ignore + // ErrUnexpectedEOF error + if offset > 0 && len(buf) > size && err == io.ErrUnexpectedEOF { + err = nil + buf = buf[:size] + } + + if offset < 0 && len(buf) > abs(int(offset)) && err == io.ErrUnexpectedEOF { + err = nil + buf = buf[:abs(int(offset))] + } + + if n != len(buf) { + t.Errorf("Load(%d, %d): wrong length returned, want %d, got %d", + len(buf), offset, len(buf), n) + return + } + + if err != nil { + t.Errorf("Load(%d, %d): unexpected error: %v", len(buf), offset, err) + return + } + + buf = buf[:n] + if !bytes.Equal(buf, expected) { + t.Errorf("Load(%d, %d) returned wrong bytes", len(buf), offset) + return + } +} + +func TestReadSeeker(t *testing.T) { + b := mem.New() + + length := rand.Intn(1<<24) + 2000 + + data := Random(23, length) + id := backend.Hash(data) + + handle := backend.Handle{Type: backend.Data, Name: id.String()} + err := b.Save(handle, data) + if err != nil { + t.Fatalf("Save() error: %v", err) + } + + for i := 0; i < 50; i++ { + l := rand.Intn(length + 2000) + o := rand.Intn(length + 2000) + + if rand.Float32() > 0.5 { + o = -o + } + + d := data + if o > 0 && o < len(d) { + d = d[o:] + } else { + o = len(d) + d = d[:0] + } + + if l > 0 && l < len(d) { + d = d[:l] + } + + rd := backend.NewReadSeeker(b, handle) + loadAndCompare(t, rd, len(data), int64(o), d) + } +} diff --git a/src/restic/checker/checker.go b/src/restic/checker/checker.go index f545942d3..4b147d442 100644 --- a/src/restic/checker/checker.go +++ b/src/restic/checker/checker.go @@ -20,8 +20,8 @@ import ( // A Checker only tests for internal errors within the data structures of the // repository (e.g. missing blobs), and needs a valid Repository to work on. type Checker struct { - packs map[backend.ID]struct{} - blobs map[backend.ID]struct{} + packs backend.IDSet + blobs backend.IDSet blobRefs struct { sync.Mutex M map[backend.ID]uint @@ -37,8 +37,8 @@ type Checker struct { // New returns a new checker which runs on repo. func New(repo *repository.Repository) *Checker { c := &Checker{ - packs: make(map[backend.ID]struct{}), - blobs: make(map[backend.ID]struct{}), + packs: backend.NewIDSet(), + blobs: backend.NewIDSet(), masterIndex: repository.NewMasterIndex(), indexes: make(map[backend.ID]*repository.Index), repo: repo, @@ -136,8 +136,8 @@ func (c *Checker) LoadIndex() (hints []error, errs []error) { debug.Log("LoadIndex", "process blobs") cnt := 0 for blob := range res.Index.Each(done) { - c.packs[blob.PackID] = struct{}{} - c.blobs[blob.ID] = struct{}{} + c.packs.Insert(blob.PackID) + c.blobs.Insert(blob.ID) c.blobRefs.M[blob.ID] = 0 cnt++ @@ -217,7 +217,7 @@ func (c *Checker) Packs(errChan chan<- error, done <-chan struct{}) { defer close(errChan) debug.Log("Checker.Packs", "checking for %d packs", len(c.packs)) - seenPacks := make(map[backend.ID]struct{}) + seenPacks := backend.NewIDSet() var workerWG sync.WaitGroup @@ -228,7 +228,7 @@ func (c *Checker) Packs(errChan chan<- error, done <-chan struct{}) { } for id := range c.packs { - seenPacks[id] = struct{}{} + seenPacks.Insert(id) IDChan <- id } close(IDChan) @@ -239,7 +239,7 @@ func (c *Checker) Packs(errChan chan<- error, done <-chan struct{}) { for id := range c.repo.List(backend.Data, done) { debug.Log("Checker.Packs", "check data blob %v", id.Str()) - if _, ok := seenPacks[id]; !ok { + if !seenPacks.Has(id) { c.orphanedPacks = append(c.orphanedPacks, id) select { case <-done: @@ -252,20 +252,20 @@ func (c *Checker) Packs(errChan chan<- error, done <-chan struct{}) { // Error is an error that occurred while checking a repository. type Error struct { - TreeID *backend.ID - BlobID *backend.ID + TreeID backend.ID + BlobID backend.ID Err error } func (e Error) Error() string { - if e.BlobID != nil && e.TreeID != nil { + if !e.BlobID.IsNull() && !e.TreeID.IsNull() { msg := "tree " + e.TreeID.Str() msg += ", blob " + e.BlobID.Str() msg += ": " + e.Err.Error() return msg } - if e.TreeID != nil { + if !e.TreeID.IsNull() { return "tree " + e.TreeID.Str() + ": " + e.Err.Error() } @@ -583,19 +583,19 @@ func (c *Checker) checkTree(id backend.ID, tree *restic.Tree) (errs []error) { case "file": for b, blobID := range node.Content { if blobID.IsNull() { - errs = append(errs, Error{TreeID: &id, Err: fmt.Errorf("file %q blob %d has null ID", node.Name, b)}) + errs = append(errs, Error{TreeID: id, Err: fmt.Errorf("file %q blob %d has null ID", node.Name, b)}) continue } blobs = append(blobs, blobID) } case "dir": if node.Subtree == nil { - errs = append(errs, Error{TreeID: &id, Err: fmt.Errorf("dir node %q has no subtree", node.Name)}) + errs = append(errs, Error{TreeID: id, Err: fmt.Errorf("dir node %q has no subtree", node.Name)}) continue } if node.Subtree.IsNull() { - errs = append(errs, Error{TreeID: &id, Err: fmt.Errorf("dir node %q subtree id is null", node.Name)}) + errs = append(errs, Error{TreeID: id, Err: fmt.Errorf("dir node %q subtree id is null", node.Name)}) continue } } @@ -607,10 +607,10 @@ func (c *Checker) checkTree(id backend.ID, tree *restic.Tree) (errs []error) { debug.Log("Checker.checkTree", "blob %v refcount %d", blobID.Str(), c.blobRefs.M[blobID]) c.blobRefs.Unlock() - if _, ok := c.blobs[blobID]; !ok { + if !c.blobs.Has(blobID) { debug.Log("Checker.trees", "tree %v references blob %v which isn't contained in index", id.Str(), blobID.Str()) - errs = append(errs, Error{TreeID: &id, BlobID: &blobID, Err: errors.New("not found in index")}) + errs = append(errs, Error{TreeID: id, BlobID: blobID, Err: errors.New("not found in index")}) } } diff --git a/src/restic/pack/pack_test.go b/src/restic/pack/pack_test.go index 97075d52d..18a7a86f4 100644 --- a/src/restic/pack/pack_test.go +++ b/src/restic/pack/pack_test.go @@ -11,6 +11,7 @@ import ( "testing" "restic/backend" + "restic/backend/mem" "restic/crypto" "restic/pack" . "restic/test" @@ -18,12 +19,12 @@ import ( var lengths = []int{23, 31650, 25860, 10928, 13769, 19862, 5211, 127, 13690, 30231} -func TestCreatePack(t *testing.T) { - type Buf struct { - data []byte - id backend.ID - } +type Buf struct { + data []byte + id backend.ID +} +func newPack(t testing.TB, k *crypto.Key) ([]Buf, []byte, uint) { bufs := []Buf{} for _, l := range lengths { @@ -34,9 +35,6 @@ func TestCreatePack(t *testing.T) { bufs = append(bufs, Buf{data: b, id: h}) } - // create random keys - k := crypto.NewRandomKey() - // pack blobs p := pack.NewPacker(k, nil) for _, b := range bufs { @@ -46,6 +44,10 @@ func TestCreatePack(t *testing.T) { packData, err := p.Finalize() OK(t, err) + return bufs, packData, p.Size() +} + +func verifyBlobs(t testing.TB, bufs []Buf, k *crypto.Key, rd io.ReadSeeker, packSize uint) { written := 0 for _, l := range lengths { written += l @@ -58,11 +60,9 @@ func TestCreatePack(t *testing.T) { written += crypto.Extension // check length - Equals(t, written, len(packData)) - Equals(t, uint(written), p.Size()) + Equals(t, uint(written), packSize) // read and parse it again - rd := bytes.NewReader(packData) np, err := pack.NewUnpacker(k, rd) OK(t, err) Equals(t, len(np.Entries), len(bufs)) @@ -81,6 +81,15 @@ func TestCreatePack(t *testing.T) { } } +func TestCreatePack(t *testing.T) { + // create random keys + k := crypto.NewRandomKey() + + bufs, packData, packSize := newPack(t, k) + Equals(t, uint(len(packData)), packSize) + verifyBlobs(t, bufs, k, bytes.NewReader(packData), packSize) +} + var blobTypeJSON = []struct { t pack.BlobType res string @@ -103,3 +112,18 @@ func TestBlobTypeJSON(t *testing.T) { Equals(t, test.t, v) } } + +func TestUnpackReadSeeker(t *testing.T) { + // create random keys + k := crypto.NewRandomKey() + + bufs, packData, packSize := newPack(t, k) + + b := mem.New() + id := backend.Hash(packData) + + handle := backend.Handle{Type: backend.Data, Name: id.String()} + OK(t, b.Save(handle, packData)) + rd := backend.NewReadSeeker(b, handle) + verifyBlobs(t, bufs, k, rd, packSize) +} diff --git a/src/restic/worker/doc.go b/src/restic/worker/doc.go new file mode 100644 index 000000000..602bb5037 --- /dev/null +++ b/src/restic/worker/doc.go @@ -0,0 +1,2 @@ +// Package worker implements a worker pool. +package worker diff --git a/src/restic/worker/pool.go b/src/restic/worker/pool.go new file mode 100644 index 000000000..d2331f587 --- /dev/null +++ b/src/restic/worker/pool.go @@ -0,0 +1,106 @@ +package worker + +// Job is one unit of work. It is given to a Func, and the returned result and +// error are stored in Result and Error. +type Job struct { + Data interface{} + Result interface{} + Error error +} + +// Func does the actual work within a Pool. +type Func func(job Job, done <-chan struct{}) (result interface{}, err error) + +// Pool implements a worker pool. +type Pool struct { + f Func + done chan struct{} + jobCh <-chan Job + resCh chan<- Job + + numWorkers int + workersExit chan struct{} + allWorkersDone chan struct{} +} + +// New returns a new worker pool with n goroutines, each running the function +// f. The workers are started immediately. +func New(n int, f Func, jobChan <-chan Job, resultChan chan<- Job) *Pool { + p := &Pool{ + f: f, + done: make(chan struct{}), + workersExit: make(chan struct{}), + allWorkersDone: make(chan struct{}), + numWorkers: n, + jobCh: jobChan, + resCh: resultChan, + } + + for i := 0; i < n; i++ { + go p.runWorker(i) + } + + go p.waitForExit() + + return p +} + +// waitForExit receives from p.workersExit until all worker functions have +// exited, then closes the result channel. +func (p *Pool) waitForExit() { + n := p.numWorkers + for n > 0 { + <-p.workersExit + n-- + } + close(p.allWorkersDone) + close(p.resCh) +} + +// runWorker runs a worker function. +func (p *Pool) runWorker(numWorker int) { + defer func() { + p.workersExit <- struct{}{} + }() + + var ( + // enable the input channel when starting up a new goroutine + inCh = p.jobCh + // but do not enable the output channel until we have a result + outCh chan<- Job + + job Job + ok bool + ) + + for { + select { + case <-p.done: + return + + case job, ok = <-inCh: + if !ok { + return + } + + job.Result, job.Error = p.f(job, p.done) + inCh = nil + outCh = p.resCh + + case outCh <- job: + outCh = nil + inCh = p.jobCh + } + } +} + +// Cancel signals termination to all worker goroutines. +func (p *Pool) Cancel() { + close(p.done) +} + +// Wait waits for all worker goroutines to terminate, afterwards the output +// channel is closed. +func (p *Pool) Wait() { + <-p.allWorkersDone +} diff --git a/src/restic/worker/pool_test.go b/src/restic/worker/pool_test.go new file mode 100644 index 000000000..16b285702 --- /dev/null +++ b/src/restic/worker/pool_test.go @@ -0,0 +1,143 @@ +package worker_test + +import ( + "errors" + "testing" + "time" + + "restic/worker" +) + +const concurrency = 10 + +var errTooLarge = errors.New("too large") + +func square(job worker.Job, done <-chan struct{}) (interface{}, error) { + n := job.Data.(int) + if n > 2000 { + return nil, errTooLarge + } + return n * n, nil +} + +func newBufferedPool(bufsize int, n int, f worker.Func) (chan worker.Job, chan worker.Job, *worker.Pool) { + inCh := make(chan worker.Job, bufsize) + outCh := make(chan worker.Job, bufsize) + + return inCh, outCh, worker.New(n, f, inCh, outCh) +} + +func TestPool(t *testing.T) { + inCh, outCh, p := newBufferedPool(200, concurrency, square) + + for i := 0; i < 150; i++ { + inCh <- worker.Job{Data: i} + } + + close(inCh) + p.Wait() + + for res := range outCh { + if res.Error != nil { + t.Errorf("unexpected error for job %v received: %v", res.Data, res.Error) + continue + } + + n := res.Data.(int) + m := res.Result.(int) + + if m != n*n { + t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m) + } + } +} + +func TestPoolErrors(t *testing.T) { + inCh, outCh, p := newBufferedPool(200, concurrency, square) + + for i := 0; i < 150; i++ { + inCh <- worker.Job{Data: i + 1900} + } + + close(inCh) + p.Wait() + + for res := range outCh { + n := res.Data.(int) + + if n > 2000 { + if res.Error == nil { + t.Errorf("expected error not found, result is %v", res) + continue + } + + if res.Error != errTooLarge { + t.Errorf("unexpected error found, result is %v", res) + } + + continue + } else { + if res.Error != nil { + t.Errorf("unexpected error for job %v received: %v", res.Data, res.Error) + continue + } + } + + m := res.Result.(int) + if m != n*n { + t.Errorf("wrong value for job %d returned: want %d, got %d", n, n*n, m) + } + } +} + +var errCancelled = errors.New("cancelled") + +type Job struct { + suc chan struct{} + d time.Duration +} + +func wait(job worker.Job, done <-chan struct{}) (interface{}, error) { + j := job.Data.(Job) + select { + case j.suc <- struct{}{}: + return time.Now(), nil + case <-time.After(j.d): + return time.Now(), nil + case <-done: + return nil, errCancelled + } +} + +func TestPoolCancel(t *testing.T) { + jobCh, resCh, p := newBufferedPool(20, concurrency, wait) + + suc := make(chan struct{}, 1) + for i := 0; i < 20; i++ { + jobCh <- worker.Job{Data: Job{suc: suc, d: time.Second}} + } + + <-suc + p.Cancel() + p.Wait() + + foundResult := false + foundCancelError := false + for res := range resCh { + if res.Error == nil { + foundResult = true + } + + if res.Error == errCancelled { + foundCancelError = true + } + } + + if !foundResult { + t.Error("did not find one expected result") + } + + if !foundCancelError { + t.Error("did not find one expected cancel error") + } +}