diff --git a/changelog/unreleased/pull-3106 b/changelog/unreleased/pull-3106 new file mode 100644 index 000000000..67d80b4d0 --- /dev/null +++ b/changelog/unreleased/pull-3106 @@ -0,0 +1,10 @@ +Enhancement: Parallelize scan of snapshot content in copy and prune + +The copy and the prune commands used to traverse the directories of +snapshots one by one to find used data. This snapshot traversal is +now parallelized, which can speed up this step several times. + +In addition, the check command now reports how many snapshots have +already been processed. + +https://github.com/restic/restic/pull/3106 diff --git a/cmd/restic/cmd_check.go b/cmd/restic/cmd_check.go index 774879490..c1a6e5464 100644 --- a/cmd/restic/cmd_check.go +++ b/cmd/restic/cmd_check.go @@ -240,7 +240,11 @@ func runCheck(opts CheckOptions, gopts GlobalOptions, args []string) error { Verbosef("check snapshots, trees and blobs\n") errChan = make(chan error) - go chkr.Structure(gopts.ctx, errChan) + go func() { + bar := newProgressMax(!gopts.Quiet, 0, "snapshots") + defer bar.Done() + chkr.Structure(gopts.ctx, bar, errChan) + }() for err := range errChan { errorsFound = true diff --git a/cmd/restic/cmd_copy.go b/cmd/restic/cmd_copy.go index 216dab836..cb8296d4b 100644 --- a/cmd/restic/cmd_copy.go +++ b/cmd/restic/cmd_copy.go @@ -6,6 +6,7 @@ import ( "github.com/restic/restic/internal/debug" "github.com/restic/restic/internal/restic" + "golang.org/x/sync/errgroup" "github.com/spf13/cobra" ) @@ -103,12 +104,8 @@ func runCopy(opts CopyOptions, gopts GlobalOptions, args []string) error { dstSnapshotByOriginal[*sn.ID()] = append(dstSnapshotByOriginal[*sn.ID()], sn) } - cloner := &treeCloner{ - srcRepo: srcRepo, - dstRepo: dstRepo, - visitedTrees: restic.NewIDSet(), - buf: nil, - } + // remember already processed trees across all snapshots + visitedTrees := restic.NewIDSet() for sn := range FindFilteredSnapshots(ctx, srcRepo, opts.Hosts, opts.Tags, opts.Paths, args) { Verbosef("\nsnapshot %s of %v at %s)\n", sn.ID().Str(), sn.Paths, sn.Time) @@ -133,7 +130,7 @@ func runCopy(opts CopyOptions, gopts GlobalOptions, args []string) error { } Verbosef(" copy started, this may take a while...\n") - if err := cloner.copyTree(ctx, *sn.Tree); err != nil { + if err := copyTree(ctx, srcRepo, dstRepo, visitedTrees, *sn.Tree); err != nil { return err } debug.Log("tree copied") @@ -177,64 +174,64 @@ func similarSnapshots(sna *restic.Snapshot, snb *restic.Snapshot) bool { return true } -type treeCloner struct { - srcRepo restic.Repository - dstRepo restic.Repository - visitedTrees restic.IDSet - buf []byte -} +func copyTree(ctx context.Context, srcRepo restic.Repository, dstRepo restic.Repository, + visitedTrees restic.IDSet, rootTreeID restic.ID) error { -func (t *treeCloner) copyTree(ctx context.Context, treeID restic.ID) error { - // We have already processed this tree - if t.visitedTrees.Has(treeID) { + wg, ctx := errgroup.WithContext(ctx) + + treeStream := restic.StreamTrees(ctx, wg, srcRepo, restic.IDs{rootTreeID}, func(treeID restic.ID) bool { + visited := visitedTrees.Has(treeID) + visitedTrees.Insert(treeID) + return visited + }, nil) + + wg.Go(func() error { + // reused buffer + var buf []byte + + for tree := range treeStream { + if tree.Error != nil { + return fmt.Errorf("LoadTree(%v) returned error %v", tree.ID.Str(), tree.Error) + } + + // Do we already have this tree blob?
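The new copyTree consults the destination index before saving anything, so tree and data blobs that already exist in the target repository are skipped rather than transferred again. A minimal sketch of that per-blob step, assuming only the Repository methods used in this diff (the helper name and package are illustrative, not part of the patch):

```go
package copysketch

import (
	"context"

	"github.com/restic/restic/internal/restic"
)

// copyBlobIfMissing mirrors the per-blob logic of the new copyTree: consult
// the destination index first and only load and re-save blobs that are
// actually missing. The buffer is returned so callers can reuse it.
func copyBlobIfMissing(ctx context.Context, srcRepo, dstRepo restic.Repository,
	id restic.ID, buf []byte) ([]byte, error) {

	if dstRepo.Index().Has(restic.BlobHandle{ID: id, Type: restic.DataBlob}) {
		return buf, nil // already present in the destination, nothing to do
	}

	buf, err := srcRepo.LoadBlob(ctx, restic.DataBlob, id, buf)
	if err != nil {
		return buf, err
	}

	// the blob ID is already known from the tree, so it is passed to SaveBlob
	_, _, err = dstRepo.SaveBlob(ctx, restic.DataBlob, buf, id, false)
	return buf, err
}
```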
+ if !dstRepo.Index().Has(restic.BlobHandle{ID: tree.ID, Type: restic.TreeBlob}) { + newTreeID, err := dstRepo.SaveTree(ctx, tree.Tree) + if err != nil { + return fmt.Errorf("SaveTree(%v) returned error %v", tree.ID.Str(), err) + } + // Assurance only. + if newTreeID != tree.ID { + return fmt.Errorf("SaveTree(%v) returned unexpected id %s", tree.ID.Str(), newTreeID.Str()) + } + } + + // TODO: parallelize blob down/upload + + for _, entry := range tree.Nodes { + // Recursion into directories is handled by StreamTrees + // Copy the blobs for this file. + for _, blobID := range entry.Content { + // Do we already have this data blob? + if dstRepo.Index().Has(restic.BlobHandle{ID: blobID, Type: restic.DataBlob}) { + continue + } + debug.Log("Copying blob %s\n", blobID.Str()) + var err error + buf, err = srcRepo.LoadBlob(ctx, restic.DataBlob, blobID, buf) + if err != nil { + return fmt.Errorf("LoadBlob(%v) returned error %v", blobID, err) + } + + _, _, err = dstRepo.SaveBlob(ctx, restic.DataBlob, buf, blobID, false) + if err != nil { + return fmt.Errorf("SaveBlob(%v) returned error %v", blobID, err) + } + } + } + + } return nil - } - - tree, err := t.srcRepo.LoadTree(ctx, treeID) - if err != nil { - return fmt.Errorf("LoadTree(%v) returned error %v", treeID.Str(), err) - } - t.visitedTrees.Insert(treeID) - - // Do we already have this tree blob? - if !t.dstRepo.Index().Has(restic.BlobHandle{ID: treeID, Type: restic.TreeBlob}) { - newTreeID, err := t.dstRepo.SaveTree(ctx, tree) - if err != nil { - return fmt.Errorf("SaveTree(%v) returned error %v", treeID.Str(), err) - } - // Assurance only. - if newTreeID != treeID { - return fmt.Errorf("SaveTree(%v) returned unexpected id %s", treeID.Str(), newTreeID.Str()) - } - } - - // TODO: parellize this stuff, likely only needed inside a tree. - - for _, entry := range tree.Nodes { - // If it is a directory, recurse - if entry.Type == "dir" && entry.Subtree != nil { - if err := t.copyTree(ctx, *entry.Subtree); err != nil { - return err - } - } - // Copy the blobs for this file. - for _, blobID := range entry.Content { - // Do we already have this data blob? 
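The skip callback passed to restic.StreamTrees both tests and marks a tree in a single call; as the StreamTrees documentation later in this diff notes, it is always invoked from one goroutine, which is why copyTree can use a plain restic.IDSet without locking. A hedged sketch of that idiom as a reusable closure (the wrapper itself is hypothetical):

```go
package copysketch

import "github.com/restic/restic/internal/restic"

// markVisited wraps the skip-callback idiom used by the new copyTree: report
// whether a tree was seen before and mark it as seen, in a single call.
// StreamTrees guarantees the callback runs in one goroutine, so the plain
// IDSet needs no extra locking here.
func markVisited(visited restic.IDSet) func(restic.ID) bool {
	return func(treeID restic.ID) bool {
		seen := visited.Has(treeID)
		visited.Insert(treeID) // no-op if already present
		return seen
	}
}
```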
- if t.dstRepo.Index().Has(restic.BlobHandle{ID: blobID, Type: restic.DataBlob}) { - continue - } - debug.Log("Copying blob %s\n", blobID.Str()) - t.buf, err = t.srcRepo.LoadBlob(ctx, restic.DataBlob, blobID, t.buf) - if err != nil { - return fmt.Errorf("LoadBlob(%v) returned error %v", blobID, err) - } - - _, _, err = t.dstRepo.SaveBlob(ctx, restic.DataBlob, t.buf, blobID, false) - if err != nil { - return fmt.Errorf("SaveBlob(%v) returned error %v", blobID, err) - } - } - } - - return nil + }) + return wg.Wait() } diff --git a/cmd/restic/cmd_prune.go b/cmd/restic/cmd_prune.go index bdad4efd9..90fe7693d 100644 --- a/cmd/restic/cmd_prune.go +++ b/cmd/restic/cmd_prune.go @@ -574,20 +574,14 @@ func getUsedBlobs(gopts GlobalOptions, repo restic.Repository, ignoreSnapshots r bar := newProgressMax(!gopts.Quiet, uint64(len(snapshotTrees)), "snapshots") defer bar.Done() - for _, tree := range snapshotTrees { - debug.Log("process tree %v", tree) - err = restic.FindUsedBlobs(ctx, repo, tree, usedBlobs) - if err != nil { - if repo.Backend().IsNotExist(err) { - return nil, errors.Fatal("unable to load a tree from the repo: " + err.Error()) - } - - return nil, err + err = restic.FindUsedBlobs(ctx, repo, snapshotTrees, usedBlobs, bar) + if err != nil { + if repo.Backend().IsNotExist(err) { + return nil, errors.Fatal("unable to load a tree from the repo: " + err.Error()) } - debug.Log("processed tree %v", tree) - bar.Add(1) + return nil, err } return usedBlobs, nil } diff --git a/cmd/restic/cmd_stats.go b/cmd/restic/cmd_stats.go index 81ec66843..deb649e26 100644 --- a/cmd/restic/cmd_stats.go +++ b/cmd/restic/cmd_stats.go @@ -166,7 +166,7 @@ func statsWalkSnapshot(ctx context.Context, snapshot *restic.Snapshot, repo rest if statsOptions.countMode == countModeRawData { // count just the sizes of unique blobs; we don't need to walk the tree // ourselves in this case, since a nifty function does it for us - return restic.FindUsedBlobs(ctx, repo, *snapshot.Tree, stats.blobs) + return restic.FindUsedBlobs(ctx, repo, restic.IDs{*snapshot.Tree}, stats.blobs, nil) } err := walker.Walk(ctx, repo, *snapshot.Tree, restic.NewIDSet(), statsWalkTree(repo, stats)) diff --git a/cmd/restic/progress.go b/cmd/restic/progress.go index 2fbd97c6c..c0b6c56fb 100644 --- a/cmd/restic/progress.go +++ b/cmd/restic/progress.go @@ -33,11 +33,14 @@ func newProgressMax(show bool, max uint64, description string) *progress.Counter } interval := calculateProgressInterval() - return progress.New(interval, func(v uint64, d time.Duration, final bool) { - status := fmt.Sprintf("[%s] %s %d / %d %s", - formatDuration(d), - formatPercent(v, max), - v, max, description) + return progress.New(interval, max, func(v uint64, max uint64, d time.Duration, final bool) { + var status string + if max == 0 { + status = fmt.Sprintf("[%s] %d %s", formatDuration(d), v, description) + } else { + status = fmt.Sprintf("[%s] %s %d / %d %s", + formatDuration(d), formatPercent(v, max), v, max, description) + } if w := stdoutTerminalWidth(); w > 0 { status = shortenStatus(w, status) diff --git a/internal/checker/checker.go b/internal/checker/checker.go index 23564ea45..1ed470e99 100644 --- a/internal/checker/checker.go +++ b/internal/checker/checker.go @@ -308,200 +308,27 @@ func (e TreeError) Error() string { return fmt.Sprintf("tree %v: %v", e.ID.Str(), e.Errors) } -type treeJob struct { - restic.ID - error - *restic.Tree -} - -// loadTreeWorker loads trees from repo and sends them to out. 
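With the new signature, prune hands all snapshot roots to restic.FindUsedBlobs in a single call and passes a progress counter that advances as each root finishes. A sketch of such a caller, assuming only the APIs introduced in this diff (function and variable names are illustrative):

```go
package prunesketch

import (
	"context"
	"time"

	"github.com/restic/restic/internal/restic"
	"github.com/restic/restic/internal/ui/progress"
)

// collectUsedBlobs mirrors the new call pattern in cmd_prune: all snapshot
// tree roots are passed at once and FindUsedBlobs parallelizes the traversal
// internally, ticking the counter once per fully processed root.
func collectUsedBlobs(ctx context.Context, repo restic.Repository, roots restic.IDs) (restic.BlobSet, error) {
	bar := progress.New(time.Second, uint64(len(roots)), func(v, total uint64, d time.Duration, final bool) {
		// a real caller would print a status line here, as newProgressMax does
	})
	defer bar.Done()

	used := restic.NewBlobSet()
	if err := restic.FindUsedBlobs(ctx, repo, roots, used, bar); err != nil {
		return nil, err
	}
	return used, nil
}
```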
-func loadTreeWorker(ctx context.Context, repo restic.Repository, - in <-chan restic.ID, out chan<- treeJob, - wg *sync.WaitGroup) { - - defer func() { - debug.Log("exiting") - wg.Done() - }() - - var ( - inCh = in - outCh = out - job treeJob - ) - - outCh = nil - for { - select { - case <-ctx.Done(): - return - - case treeID, ok := <-inCh: - if !ok { - return - } - debug.Log("load tree %v", treeID) - - tree, err := repo.LoadTree(ctx, treeID) - debug.Log("load tree %v (%v) returned err: %v", tree, treeID, err) - job = treeJob{ID: treeID, error: err, Tree: tree} - outCh = out - inCh = nil - - case outCh <- job: - debug.Log("sent tree %v", job.ID) - outCh = nil - inCh = in - } - } -} - // checkTreeWorker checks the trees received and sends out errors to errChan. -func (c *Checker) checkTreeWorker(ctx context.Context, in <-chan treeJob, out chan<- error, wg *sync.WaitGroup) { - defer func() { - debug.Log("exiting") - wg.Done() - }() +func (c *Checker) checkTreeWorker(ctx context.Context, trees <-chan restic.TreeItem, out chan<- error) { + for job := range trees { + debug.Log("check tree %v (tree %v, err %v)", job.ID, job.Tree, job.Error) - var ( - inCh = in - outCh = out - treeError TreeError - ) + var errs []error + if job.Error != nil { + errs = append(errs, job.Error) + } else { + errs = c.checkTree(job.ID, job.Tree) + } - outCh = nil - for { + if len(errs) == 0 { + continue + } + treeError := TreeError{ID: job.ID, Errors: errs} select { case <-ctx.Done(): - debug.Log("done channel closed, exiting") return - - case job, ok := <-inCh: - if !ok { - debug.Log("input channel closed, exiting") - return - } - - debug.Log("check tree %v (tree %v, err %v)", job.ID, job.Tree, job.error) - - var errs []error - if job.error != nil { - errs = append(errs, job.error) - } else { - errs = c.checkTree(job.ID, job.Tree) - } - - if len(errs) > 0 { - debug.Log("checked tree %v: %v errors", job.ID, len(errs)) - treeError = TreeError{ID: job.ID, Errors: errs} - outCh = out - inCh = nil - } - - case outCh <- treeError: + case out <- treeError: debug.Log("tree %v: sent %d errors", treeError.ID, len(treeError.Errors)) - outCh = nil - inCh = in - } - } -} - -func (c *Checker) filterTrees(ctx context.Context, backlog restic.IDs, loaderChan chan<- restic.ID, in <-chan treeJob, out chan<- treeJob) { - defer func() { - debug.Log("closing output channels") - close(loaderChan) - close(out) - }() - - var ( - inCh = in - outCh = out - loadCh = loaderChan - job treeJob - nextTreeID restic.ID - outstandingLoadTreeJobs = 0 - ) - - outCh = nil - loadCh = nil - - for { - if loadCh == nil && len(backlog) > 0 { - // process last added ids first, that is traverse the tree in depth-first order - ln := len(backlog) - 1 - nextTreeID, backlog = backlog[ln], backlog[:ln] - - // use a separate flag for processed trees to ensure that check still processes trees - // even when a file references a tree blob - c.blobRefs.Lock() - h := restic.BlobHandle{ID: nextTreeID, Type: restic.TreeBlob} - blobReferenced := c.blobRefs.M.Has(h) - // noop if already referenced - c.blobRefs.M.Insert(h) - c.blobRefs.Unlock() - if blobReferenced { - continue - } - - loadCh = loaderChan - } - - if loadCh == nil && outCh == nil && outstandingLoadTreeJobs == 0 { - debug.Log("backlog is empty, all channels nil, exiting") - return - } - - select { - case <-ctx.Done(): - return - - case loadCh <- nextTreeID: - outstandingLoadTreeJobs++ - loadCh = nil - - case j, ok := <-inCh: - if !ok { - debug.Log("input channel closed") - inCh = nil - in = nil - continue - } - - 
outstandingLoadTreeJobs-- - - debug.Log("input job tree %v", j.ID) - - if j.error != nil { - debug.Log("received job with error: %v (tree %v, ID %v)", j.error, j.Tree, j.ID) - } else if j.Tree == nil { - debug.Log("received job with nil tree pointer: %v (ID %v)", j.error, j.ID) - // send a new job with the new error instead of the old one - j = treeJob{ID: j.ID, error: errors.New("tree is nil and error is nil")} - } else { - subtrees := j.Tree.Subtrees() - debug.Log("subtrees for tree %v: %v", j.ID, subtrees) - // iterate backwards over subtree to compensate backwards traversal order of nextTreeID selection - for i := len(subtrees) - 1; i >= 0; i-- { - id := subtrees[i] - if id.IsNull() { - // We do not need to raise this error here, it is - // checked when the tree is checked. Just make sure - // that we do not add any null IDs to the backlog. - debug.Log("tree %v has nil subtree", j.ID) - continue - } - backlog = append(backlog, id) - } - } - - job = j - outCh = out - inCh = nil - - case outCh <- job: - debug.Log("tree sent to check: %v", job.ID) - outCh = nil - inCh = in } } } @@ -527,10 +354,9 @@ func loadSnapshotTreeIDs(ctx context.Context, repo restic.Repository) (ids resti // Structure checks that for all snapshots all referenced data blobs and // subtrees are available in the index. errChan is closed after all trees have // been traversed. -func (c *Checker) Structure(ctx context.Context, errChan chan<- error) { - defer close(errChan) - +func (c *Checker) Structure(ctx context.Context, p *progress.Counter, errChan chan<- error) { trees, errs := loadSnapshotTreeIDs(ctx, c.repo) + p.SetMax(uint64(len(trees))) debug.Log("need to check %d trees from snapshots, %d errs returned", len(trees), len(errs)) for _, err := range errs { @@ -541,19 +367,26 @@ func (c *Checker) Structure(ctx context.Context, errChan chan<- error) { } } - treeIDChan := make(chan restic.ID) - treeJobChan1 := make(chan treeJob) - treeJobChan2 := make(chan treeJob) + wg, ctx := errgroup.WithContext(ctx) + treeStream := restic.StreamTrees(ctx, wg, c.repo, trees, func(treeID restic.ID) bool { + // blobRefs may be accessed in parallel by checkTree + c.blobRefs.Lock() + h := restic.BlobHandle{ID: treeID, Type: restic.TreeBlob} + blobReferenced := c.blobRefs.M.Has(h) + // noop if already referenced + c.blobRefs.M.Insert(h) + c.blobRefs.Unlock() + return blobReferenced + }, p) - var wg sync.WaitGroup + defer close(errChan) for i := 0; i < defaultParallelism; i++ { - wg.Add(2) - go loadTreeWorker(ctx, c.repo, treeIDChan, treeJobChan1, &wg) - go c.checkTreeWorker(ctx, treeJobChan2, errChan, &wg) + wg.Go(func() error { + c.checkTreeWorker(ctx, treeStream, errChan) + return nil + }) } - c.filterTrees(ctx, trees, treeIDChan, treeJobChan1, treeJobChan2) - wg.Wait() } diff --git a/internal/checker/checker_test.go b/internal/checker/checker_test.go index f8efd05e8..ad1b15f1a 100644 --- a/internal/checker/checker_test.go +++ b/internal/checker/checker_test.go @@ -43,7 +43,9 @@ func checkPacks(chkr *checker.Checker) []error { } func checkStruct(chkr *checker.Checker) []error { - return collectErrors(context.TODO(), chkr.Structure) + return collectErrors(context.TODO(), func(ctx context.Context, errChan chan<- error) { + chkr.Structure(ctx, nil, errChan) + }) } func checkData(chkr *checker.Checker) []error { diff --git a/internal/checker/testing.go b/internal/checker/testing.go index 6c5be84e2..d672911b1 100644 --- a/internal/checker/testing.go +++ b/internal/checker/testing.go @@ -30,7 +30,7 @@ func TestCheckRepo(t testing.TB, repo 
restic.Repository) { // structure errChan = make(chan error) - go chkr.Structure(context.TODO(), errChan) + go chkr.Structure(context.TODO(), nil, errChan) for err := range errChan { t.Error(err) diff --git a/internal/repository/master_index_test.go b/internal/repository/master_index_test.go index 9ccf0e59e..3c279696e 100644 --- a/internal/repository/master_index_test.go +++ b/internal/repository/master_index_test.go @@ -368,7 +368,7 @@ func TestIndexSave(t *testing.T) { defer cancel() errCh := make(chan error) - go checker.Structure(ctx, errCh) + go checker.Structure(ctx, nil, errCh) i := 0 for err := range errCh { t.Errorf("checker returned error: %v", err) diff --git a/internal/restic/find.go b/internal/restic/find.go index b5bef0720..4d6433d60 100644 --- a/internal/restic/find.go +++ b/internal/restic/find.go @@ -1,6 +1,12 @@ package restic -import "context" +import ( + "context" + "sync" + + "github.com/restic/restic/internal/ui/progress" + "golang.org/x/sync/errgroup" +) // TreeLoader loads a tree from a repository. type TreeLoader interface { @@ -9,31 +15,39 @@ type TreeLoader interface { // FindUsedBlobs traverses the tree ID and adds all seen blobs (trees and data // blobs) to the set blobs. Already seen tree blobs will not be visited again. -func FindUsedBlobs(ctx context.Context, repo TreeLoader, treeID ID, blobs BlobSet) error { - h := BlobHandle{ID: treeID, Type: TreeBlob} - if blobs.Has(h) { - return nil - } - blobs.Insert(h) +func FindUsedBlobs(ctx context.Context, repo TreeLoader, treeIDs IDs, blobs BlobSet, p *progress.Counter) error { + var lock sync.Mutex - tree, err := repo.LoadTree(ctx, treeID) - if err != nil { - return err - } + wg, ctx := errgroup.WithContext(ctx) + treeStream := StreamTrees(ctx, wg, repo, treeIDs, func(treeID ID) bool { + // locking is necessary the goroutine below concurrently adds data blobs + lock.Lock() + h := BlobHandle{ID: treeID, Type: TreeBlob} + blobReferenced := blobs.Has(h) + // noop if already referenced + blobs.Insert(h) + lock.Unlock() + return blobReferenced + }, p) - for _, node := range tree.Nodes { - switch node.Type { - case "file": - for _, blob := range node.Content { - blobs.Insert(BlobHandle{ID: blob, Type: DataBlob}) + wg.Go(func() error { + for tree := range treeStream { + if tree.Error != nil { + return tree.Error } - case "dir": - err := FindUsedBlobs(ctx, repo, *node.Subtree, blobs) - if err != nil { - return err + + lock.Lock() + for _, node := range tree.Nodes { + switch node.Type { + case "file": + for _, blob := range node.Content { + blobs.Insert(BlobHandle{ID: blob, Type: DataBlob}) + } + } } + lock.Unlock() } - } - - return nil + return nil + }) + return wg.Wait() } diff --git a/internal/restic/find_test.go b/internal/restic/find_test.go index 635421d8b..c599e5fdb 100644 --- a/internal/restic/find_test.go +++ b/internal/restic/find_test.go @@ -15,6 +15,8 @@ import ( "github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/restic" + "github.com/restic/restic/internal/test" + "github.com/restic/restic/internal/ui/progress" ) func loadIDSet(t testing.TB, filename string) restic.BlobSet { @@ -92,9 +94,12 @@ func TestFindUsedBlobs(t *testing.T) { snapshots = append(snapshots, sn) } + p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {}) + defer p.Done() + for i, sn := range snapshots { usedBlobs := restic.NewBlobSet() - err := restic.FindUsedBlobs(context.TODO(), repo, *sn.Tree, 
usedBlobs) + err := restic.FindUsedBlobs(context.TODO(), repo, restic.IDs{*sn.Tree}, usedBlobs, p) if err != nil { t.Errorf("FindUsedBlobs returned error: %v", err) continue @@ -105,6 +110,8 @@ func TestFindUsedBlobs(t *testing.T) { continue } + test.Equals(t, p.Get(), uint64(i+1)) + goldenFilename := filepath.Join("testdata", fmt.Sprintf("used_blobs_snapshot%d", i)) want := loadIDSet(t, goldenFilename) @@ -119,6 +126,40 @@ func TestFindUsedBlobs(t *testing.T) { } } +func TestMultiFindUsedBlobs(t *testing.T) { + repo, cleanup := repository.TestRepository(t) + defer cleanup() + + var snapshotTrees restic.IDs + for i := 0; i < findTestSnapshots; i++ { + sn := restic.TestCreateSnapshot(t, repo, findTestTime.Add(time.Duration(i)*time.Second), findTestDepth, 0) + t.Logf("snapshot %v saved, tree %v", sn.ID().Str(), sn.Tree.Str()) + snapshotTrees = append(snapshotTrees, *sn.Tree) + } + + want := restic.NewBlobSet() + for i := range snapshotTrees { + goldenFilename := filepath.Join("testdata", fmt.Sprintf("used_blobs_snapshot%d", i)) + want.Merge(loadIDSet(t, goldenFilename)) + } + + p := progress.New(time.Second, findTestSnapshots, func(value uint64, total uint64, runtime time.Duration, final bool) {}) + defer p.Done() + + // run twice to check progress bar handling of duplicate tree roots + usedBlobs := restic.NewBlobSet() + for i := 1; i < 3; i++ { + err := restic.FindUsedBlobs(context.TODO(), repo, snapshotTrees, usedBlobs, p) + test.OK(t, err) + test.Equals(t, p.Get(), uint64(i*len(snapshotTrees))) + + if !want.Equals(usedBlobs) { + t.Errorf("wrong list of blobs returned:\n missing blobs: %v\n extra blobs: %v", + want.Sub(usedBlobs), usedBlobs.Sub(want)) + } + } +} + type ForbiddenRepo struct{} func (r ForbiddenRepo) LoadTree(ctx context.Context, id restic.ID) (*restic.Tree, error) { @@ -133,12 +174,12 @@ func TestFindUsedBlobsSkipsSeenBlobs(t *testing.T) { t.Logf("snapshot %v saved, tree %v", snapshot.ID().Str(), snapshot.Tree.Str()) usedBlobs := restic.NewBlobSet() - err := restic.FindUsedBlobs(context.TODO(), repo, *snapshot.Tree, usedBlobs) + err := restic.FindUsedBlobs(context.TODO(), repo, restic.IDs{*snapshot.Tree}, usedBlobs, nil) if err != nil { t.Fatalf("FindUsedBlobs returned error: %v", err) } - err = restic.FindUsedBlobs(context.TODO(), ForbiddenRepo{}, *snapshot.Tree, usedBlobs) + err = restic.FindUsedBlobs(context.TODO(), ForbiddenRepo{}, restic.IDs{*snapshot.Tree}, usedBlobs, nil) if err != nil { t.Fatalf("FindUsedBlobs returned error: %v", err) } @@ -154,7 +195,7 @@ func BenchmarkFindUsedBlobs(b *testing.B) { for i := 0; i < b.N; i++ { blobs := restic.NewBlobSet() - err := restic.FindUsedBlobs(context.TODO(), repo, *sn.Tree, blobs) + err := restic.FindUsedBlobs(context.TODO(), repo, restic.IDs{*sn.Tree}, blobs, nil) if err != nil { b.Error(err) } diff --git a/internal/restic/tree_stream.go b/internal/restic/tree_stream.go new file mode 100644 index 000000000..871ba8998 --- /dev/null +++ b/internal/restic/tree_stream.go @@ -0,0 +1,183 @@ +package restic + +import ( + "context" + "errors" + "sync" + + "github.com/restic/restic/internal/debug" + "github.com/restic/restic/internal/ui/progress" + "golang.org/x/sync/errgroup" +) + +const streamTreeParallelism = 5 + +// TreeItem is used to return either an error or the tree for a tree id +type TreeItem struct { + ID + Error error + *Tree +} + +type trackedTreeItem struct { + TreeItem + rootIdx int +} + +type trackedID struct { + ID + rootIdx int +} + +// loadTreeWorker loads trees from repo and sends them to out. 
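restic.StreamTrees is the shared traversal primitive that copy, check, and FindUsedBlobs now build on: callers supply the roots, a skip callback, and an optional progress counter, and consume a channel of TreeItem values. A minimal, illustrative consumer that merely records every reachable tree once (not part of the patch):

```go
package treesketch

import (
	"context"

	"github.com/restic/restic/internal/restic"
	"golang.org/x/sync/errgroup"
)

// reachableTrees returns the IDs of all trees reachable from the given roots,
// visiting every tree blob exactly once.
func reachableTrees(ctx context.Context, repo restic.TreeLoader, roots restic.IDs) (restic.IDSet, error) {
	visited := restic.NewIDSet()

	wg, ctx := errgroup.WithContext(ctx)
	trees := restic.StreamTrees(ctx, wg, repo, roots, func(id restic.ID) bool {
		// skip is called from a single goroutine, so no locking is needed
		seen := visited.Has(id)
		visited.Insert(id)
		return seen
	}, nil)

	wg.Go(func() error {
		for tree := range trees {
			if tree.Error != nil {
				return tree.Error
			}
			// recursion into subtrees is handled by StreamTrees itself
		}
		return nil
	})
	return visited, wg.Wait()
}
```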
+func loadTreeWorker(ctx context.Context, repo TreeLoader, + in <-chan trackedID, out chan<- trackedTreeItem) { + + for treeID := range in { + tree, err := repo.LoadTree(ctx, treeID.ID) + debug.Log("load tree %v (%v) returned err: %v", tree, treeID, err) + job := trackedTreeItem{TreeItem: TreeItem{ID: treeID.ID, Error: err, Tree: tree}, rootIdx: treeID.rootIdx} + + select { + case <-ctx.Done(): + return + case out <- job: + } + } +} + +func filterTrees(ctx context.Context, trees IDs, loaderChan chan<- trackedID, + in <-chan trackedTreeItem, out chan<- TreeItem, skip func(tree ID) bool, p *progress.Counter) { + + var ( + inCh = in + outCh chan<- TreeItem + loadCh chan<- trackedID + job TreeItem + nextTreeID trackedID + outstandingLoadTreeJobs = 0 + ) + rootCounter := make([]int, len(trees)) + backlog := make([]trackedID, 0, len(trees)) + for idx, id := range trees { + backlog = append(backlog, trackedID{ID: id, rootIdx: idx}) + rootCounter[idx] = 1 + } + + for { + if loadCh == nil && len(backlog) > 0 { + // process last added ids first, that is traverse the tree in depth-first order + ln := len(backlog) - 1 + nextTreeID, backlog = backlog[ln], backlog[:ln] + + if skip(nextTreeID.ID) { + rootCounter[nextTreeID.rootIdx]-- + if p != nil && rootCounter[nextTreeID.rootIdx] == 0 { + p.Add(1) + } + continue + } + + loadCh = loaderChan + } + + if loadCh == nil && outCh == nil && outstandingLoadTreeJobs == 0 { + debug.Log("backlog is empty, all channels nil, exiting") + return + } + + select { + case <-ctx.Done(): + return + + case loadCh <- nextTreeID: + outstandingLoadTreeJobs++ + loadCh = nil + + case j, ok := <-inCh: + if !ok { + debug.Log("input channel closed") + inCh = nil + in = nil + continue + } + + outstandingLoadTreeJobs-- + rootCounter[j.rootIdx]-- + + debug.Log("input job tree %v", j.ID) + + if j.Error != nil { + debug.Log("received job with error: %v (tree %v, ID %v)", j.Error, j.Tree, j.ID) + } else if j.Tree == nil { + debug.Log("received job with nil tree pointer: %v (ID %v)", j.Error, j.ID) + // send a new job with the new error instead of the old one + j = trackedTreeItem{TreeItem: TreeItem{ID: j.ID, Error: errors.New("tree is nil and error is nil")}, rootIdx: j.rootIdx} + } else { + subtrees := j.Tree.Subtrees() + debug.Log("subtrees for tree %v: %v", j.ID, subtrees) + // iterate backwards over subtree to compensate backwards traversal order of nextTreeID selection + for i := len(subtrees) - 1; i >= 0; i-- { + id := subtrees[i] + if id.IsNull() { + // We do not need to raise this error here, it is + // checked when the tree is checked. Just make sure + // that we do not add any null IDs to the backlog. + debug.Log("tree %v has nil subtree", j.ID) + continue + } + backlog = append(backlog, trackedID{ID: id, rootIdx: j.rootIdx}) + rootCounter[j.rootIdx]++ + } + } + if p != nil && rootCounter[j.rootIdx] == 0 { + p.Add(1) + } + + job = j.TreeItem + outCh = out + inCh = nil + + case outCh <- job: + debug.Log("tree sent to process: %v", job.ID) + outCh = nil + inCh = in + } + } +} + +// StreamTrees iteratively loads the given trees and their subtrees. The skip method +// is guaranteed to always be called from the same goroutine. To shutdown the started +// goroutines, either read all items from the channel or cancel the context. Then `Wait()` +// on the errgroup until all goroutines were stopped. 
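The shutdown contract described above matters for consumers that stop early: cancelling the context lets the loader and filter goroutines exit even though the channel is not drained, and the caller still has to Wait() on the errgroup. A hedged sketch of that pattern (the function is hypothetical):

```go
package treesketch

import (
	"context"

	"github.com/restic/restic/internal/restic"
	"golang.org/x/sync/errgroup"
)

// firstTree loads only the root tree and then shuts the stream down by
// cancelling the context before waiting on the errgroup.
func firstTree(ctx context.Context, repo restic.TreeLoader, root restic.ID) (*restic.Tree, error) {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	wg, ctx := errgroup.WithContext(ctx)
	trees := restic.StreamTrees(ctx, wg, repo, restic.IDs{root}, func(restic.ID) bool { return false }, nil)

	var result *restic.Tree
	wg.Go(func() error {
		item, ok := <-trees
		if !ok {
			return nil
		}
		if item.Error != nil {
			return item.Error
		}
		result = item.Tree
		cancel() // stop loading further trees; the remaining goroutines drain and exit
		return nil
	})
	if err := wg.Wait(); err != nil {
		return nil, err
	}
	return result, nil
}
```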
+func StreamTrees(ctx context.Context, wg *errgroup.Group, repo TreeLoader, trees IDs, skip func(tree ID) bool, p *progress.Counter) <-chan TreeItem { + loaderChan := make(chan trackedID) + loadedTreeChan := make(chan trackedTreeItem) + treeStream := make(chan TreeItem) + + var loadTreeWg sync.WaitGroup + + for i := 0; i < streamTreeParallelism; i++ { + loadTreeWg.Add(1) + wg.Go(func() error { + defer loadTreeWg.Done() + loadTreeWorker(ctx, repo, loaderChan, loadedTreeChan) + return nil + }) + } + + // close once all loadTreeWorkers have completed + wg.Go(func() error { + loadTreeWg.Wait() + close(loadedTreeChan) + return nil + }) + + wg.Go(func() error { + defer close(loaderChan) + defer close(treeStream) + filterTrees(ctx, trees, loaderChan, loadedTreeChan, treeStream, skip, p) + return nil + }) + return treeStream +} diff --git a/internal/ui/progress/counter.go b/internal/ui/progress/counter.go index bf4906978..d2f75c9bf 100644 --- a/internal/ui/progress/counter.go +++ b/internal/ui/progress/counter.go @@ -12,7 +12,7 @@ import ( // // The final argument is true if Counter.Done has been called, // which means that the current call will be the last. -type Func func(value uint64, runtime time.Duration, final bool) +type Func func(value uint64, total uint64, runtime time.Duration, final bool) // A Counter tracks a running count and controls a goroutine that passes its // value periodically to a Func. @@ -27,16 +27,19 @@ type Counter struct { valueMutex sync.Mutex value uint64 + max uint64 } // New starts a new Counter. -func New(interval time.Duration, report Func) *Counter { +func New(interval time.Duration, total uint64, report Func) *Counter { c := &Counter{ report: report, start: time.Now(), stopped: make(chan struct{}), stop: make(chan struct{}), + max: total, } + if interval > 0 { c.tick = time.NewTicker(interval) } @@ -56,6 +59,16 @@ func (c *Counter) Add(v uint64) { c.valueMutex.Unlock() } +// SetMax sets the maximum expected counter value. This method is concurrency-safe. +func (c *Counter) SetMax(max uint64) { + if c == nil { + return + } + c.valueMutex.Lock() + c.max = max + c.valueMutex.Unlock() +} + // Done tells a Counter to stop and waits for it to report its final value. func (c *Counter) Done() { if c == nil { @@ -69,7 +82,8 @@ func (c *Counter) Done() { *c = Counter{} // Prevent reuse. } -func (c *Counter) get() uint64 { +// Get the current Counter value. This method is concurrency-safe. +func (c *Counter) Get() uint64 { c.valueMutex.Lock() v := c.value c.valueMutex.Unlock() @@ -77,11 +91,19 @@ func (c *Counter) get() uint64 { return v } +func (c *Counter) getMax() uint64 { + c.valueMutex.Lock() + max := c.max + c.valueMutex.Unlock() + + return max +} + func (c *Counter) run() { defer close(c.stopped) defer func() { // Must be a func so that time.Since isn't called at defer time. 
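The progress.Counter changes make the expected total part of the API: New takes it up front, the report callback receives it, and SetMax can raise it later, which is how check starts its snapshot bar with a total of 0 and fills it in once Structure knows the snapshot count. A small usage sketch (names are illustrative):

```go
package progresssketch

import (
	"fmt"
	"time"

	"github.com/restic/restic/internal/ui/progress"
)

// demoCounter shows the extended progress API: a counter created with an
// unknown total of 0, raised via SetMax once the real count is known.
func demoCounter() {
	report := func(value, total uint64, runtime time.Duration, final bool) {
		if total == 0 {
			fmt.Printf("%d items\n", value) // total not known yet
			return
		}
		fmt.Printf("%d / %d items\n", value, total)
	}

	c := progress.New(100*time.Millisecond, 0, report)
	c.SetMax(3) // the real code does this once len(trees) is known
	for i := 0; i < 3; i++ {
		c.Add(1)
	}
	c.Done() // reports the final value
}
```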
- c.report(c.get(), time.Since(c.start), true) + c.report(c.Get(), c.getMax(), time.Since(c.start), true) }() var tick <-chan time.Time @@ -101,6 +123,6 @@ func (c *Counter) run() { return } - c.report(c.get(), now.Sub(c.start), false) + c.report(c.Get(), c.getMax(), now.Sub(c.start), false) } } diff --git a/internal/ui/progress/counter_test.go b/internal/ui/progress/counter_test.go index 9a76d9cbf..49a99f7ee 100644 --- a/internal/ui/progress/counter_test.go +++ b/internal/ui/progress/counter_test.go @@ -10,23 +10,32 @@ import ( func TestCounter(t *testing.T) { const N = 100 + const startTotal = uint64(12345) var ( finalSeen = false increasing = true last uint64 + lastTotal = startTotal ncalls int + nmaxChange int ) - report := func(value uint64, d time.Duration, final bool) { - finalSeen = true + report := func(value uint64, total uint64, d time.Duration, final bool) { + if final { + finalSeen = true + } if value < last { increasing = false } last = value + if total != lastTotal { + nmaxChange++ + } + lastTotal = total ncalls++ } - c := progress.New(10*time.Millisecond, report) + c := progress.New(10*time.Millisecond, startTotal, report) done := make(chan struct{}) go func() { @@ -35,6 +44,7 @@ func TestCounter(t *testing.T) { time.Sleep(time.Millisecond) c.Add(1) } + c.SetMax(42) }() <-done @@ -43,6 +53,8 @@ func TestCounter(t *testing.T) { test.Assert(t, finalSeen, "final call did not happen") test.Assert(t, increasing, "values not increasing") test.Equals(t, uint64(N), last) + test.Equals(t, uint64(42), lastTotal) + test.Equals(t, int(1), nmaxChange) t.Log("number of calls:", ncalls) } @@ -58,14 +70,14 @@ func TestCounterNoTick(t *testing.T) { finalSeen := false otherSeen := false - report := func(value uint64, d time.Duration, final bool) { + report := func(value, total uint64, d time.Duration, final bool) { if final { finalSeen = true } else { otherSeen = true } } - c := progress.New(0, report) + c := progress.New(0, 1, report) time.Sleep(time.Millisecond) c.Done()
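For reference, the locking scheme of the new FindUsedBlobs in internal/restic/find.go can be condensed as follows: the skip callback (run by the filter goroutine) marks tree blobs, the consumer goroutine adds data blobs, and both guard the shared BlobSet with the same mutex. This is an illustrative restatement of the pattern, not additional code from the patch:

```go
package findsketch

import (
	"context"
	"sync"

	"github.com/restic/restic/internal/restic"
	"golang.org/x/sync/errgroup"
)

// usedBlobsSketch records every tree and data blob reachable from roots into
// blobs, guarding the shared set with one mutex because the skip callback and
// the consumer run in different goroutines.
func usedBlobsSketch(ctx context.Context, repo restic.TreeLoader, roots restic.IDs, blobs restic.BlobSet) error {
	var lock sync.Mutex

	wg, ctx := errgroup.WithContext(ctx)
	trees := restic.StreamTrees(ctx, wg, repo, roots, func(id restic.ID) bool {
		lock.Lock()
		defer lock.Unlock()
		h := restic.BlobHandle{ID: id, Type: restic.TreeBlob}
		seen := blobs.Has(h)
		blobs.Insert(h) // no-op if already referenced
		return seen
	}, nil)

	wg.Go(func() error {
		for tree := range trees {
			if tree.Error != nil {
				return tree.Error
			}
			lock.Lock()
			for _, node := range tree.Nodes {
				if node.Type == "file" {
					for _, blob := range node.Content {
						blobs.Insert(restic.BlobHandle{ID: blob, Type: restic.DataBlob})
					}
				}
			}
			lock.Unlock()
		}
		return nil
	})
	return wg.Wait()
}
```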