diff --git a/internal/restorer/restorer.go b/internal/restorer/restorer.go index 7ce7bb04a..5b6604421 100644 --- a/internal/restorer/restorer.go +++ b/internal/restorer/restorer.go @@ -99,10 +99,13 @@ func (res *Restorer) traverseTree(ctx context.Context, target, location string, } sanitizeError := func(err error) error { - if err != nil { - err = res.Error(nodeLocation, err) + switch err { + case nil, context.Canceled, context.DeadlineExceeded: + // Context errors are permanent. + return err + default: + return res.Error(nodeLocation, err) } - return err } if node.Type == "dir" { @@ -364,7 +367,7 @@ func (res *Restorer) VerifyFiles(ctx context.Context, dst string) (int, error) { } atomic.AddUint64(&nchecked, 1) } - return + return err }) } diff --git a/internal/restorer/restorer_test.go b/internal/restorer/restorer_test.go index 4a51c2c19..4c4093464 100644 --- a/internal/restorer/restorer_test.go +++ b/internal/restorer/restorer_test.go @@ -808,3 +808,41 @@ func TestRestorerConsistentTimestampsAndPermissions(t *testing.T) { checkConsistentInfo(t, test.path, f, test.modtime, test.mode) } } + +// VerifyFiles must not report cancelation of its context through res.Error. +func TestVerifyCancel(t *testing.T) { + snapshot := Snapshot{ + Nodes: map[string]Node{ + "foo": File{Data: "content: foo\n"}, + }, + } + + repo, cleanup := repository.TestRepository(t) + defer cleanup() + + _, id := saveSnapshot(t, repo, snapshot) + + res, err := NewRestorer(context.TODO(), repo, id) + rtest.OK(t, err) + + tempdir, cleanup := rtest.TempDir(t) + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rtest.OK(t, res.RestoreTo(ctx, tempdir)) + err = ioutil.WriteFile(filepath.Join(tempdir, "foo"), []byte("bar"), 0644) + rtest.OK(t, err) + + var errs []error + res.Error = func(filename string, err error) error { + errs = append(errs, err) + return err + } + + nverified, err := res.VerifyFiles(ctx, tempdir) + rtest.Equals(t, 0, nverified) + rtest.Assert(t, err != nil, "nil error from VerifyFiles") + rtest.Equals(t, []error(nil), errs) +}