diff --git a/internal/restorer/restorer.go b/internal/restorer/restorer.go index 289883ed0..4acd45f95 100644 --- a/internal/restorer/restorer.go +++ b/internal/restorer/restorer.go @@ -166,12 +166,14 @@ func (res *Restorer) restoreNodeTo(ctx context.Context, node *restic.Node, targe err := node.CreateAt(ctx, target, res.repo) if err != nil { debug.Log("node.CreateAt(%s) error %v", target, err) - } - if err == nil { - err = res.restoreNodeMetadataTo(node, target, location) + return err } - return err + if res.progress != nil { + res.progress.AddProgress(location, 0, 0) + } + + return res.restoreNodeMetadataTo(node, target, location) } func (res *Restorer) restoreNodeMetadataTo(node *restic.Node, target, location string) error { @@ -239,6 +241,9 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { _, err = res.traverseTree(ctx, dst, string(filepath.Separator), *res.sn.Tree, treeVisitor{ enterDir: func(node *restic.Node, target, location string) error { debug.Log("first pass, enterDir: mkdir %q, leaveDir should restore metadata", location) + if res.progress != nil { + res.progress.AddFile(0) + } // create dir with default permissions // #leaveDir restores dir metadata after visiting all children return fs.MkdirAll(target, 0700) @@ -254,24 +259,34 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { } if node.Type != "file" { + if res.progress != nil { + res.progress.AddFile(0) + } return nil } - if res.progress != nil { - res.progress.AddFile(node.Size) - } - if node.Size == 0 { + if res.progress != nil { + res.progress.AddFile(node.Size) + } return nil // deal with empty files later } if node.Links > 1 { if idx.Has(node.Inode, node.DeviceID) { + if res.progress != nil { + // a hardlinked file does not increase the restore size + res.progress.AddFile(0) + } return nil } idx.Add(node.Inode, node.DeviceID, location) } + if res.progress != nil { + res.progress.AddFile(node.Size) + } + filerestorer.addFile(location, node.Content, int64(node.Size)) return nil @@ -310,7 +325,13 @@ func (res *Restorer) RestoreTo(ctx context.Context, dst string) error { return res.restoreNodeMetadataTo(node, target, location) }, - leaveDir: res.restoreNodeMetadataTo, + leaveDir: func(node *restic.Node, target, location string) error { + err := res.restoreNodeMetadataTo(node, target, location) + if err == nil && res.progress != nil { + res.progress.AddProgress(location, 0, 0) + } + return err + }, }) return err } diff --git a/internal/restorer/restorer_unix_test.go b/internal/restorer/restorer_unix_test.go index e9c521e36..4c5f2a5b8 100644 --- a/internal/restorer/restorer_unix_test.go +++ b/internal/restorer/restorer_unix_test.go @@ -9,10 +9,12 @@ import ( "path/filepath" "syscall" "testing" + "time" "github.com/restic/restic/internal/repository" "github.com/restic/restic/internal/restic" rtest "github.com/restic/restic/internal/test" + restoreui "github.com/restic/restic/internal/ui/restore" ) func TestRestorerRestoreEmptyHardlinkedFileds(t *testing.T) { @@ -66,3 +68,56 @@ func getBlockCount(t *testing.T, filename string) int64 { } return st.Blocks } + +type printerMock struct { + filesFinished, filesTotal, allBytesWritten, allBytesTotal uint64 +} + +func (p *printerMock) Update(filesFinished, filesTotal, allBytesWritten, allBytesTotal uint64, duration time.Duration) { +} +func (p *printerMock) Finish(filesFinished, filesTotal, allBytesWritten, allBytesTotal uint64, duration time.Duration) { + p.filesFinished = filesFinished + p.filesTotal = filesTotal + p.allBytesWritten = allBytesWritten + p.allBytesTotal = allBytesTotal +} + +func TestRestorerProgressBar(t *testing.T) { + repo := repository.TestRepository(t) + + sn, _ := saveSnapshot(t, repo, Snapshot{ + Nodes: map[string]Node{ + "dirtest": Dir{ + Nodes: map[string]Node{ + "file1": File{Links: 2, Inode: 1, Data: "foo"}, + "file2": File{Links: 2, Inode: 1, Data: "foo"}, + }, + }, + "file2": File{Links: 1, Inode: 2, Data: "example"}, + }, + }) + + mock := &printerMock{} + progress := restoreui.NewProgress(mock, 0) + res := NewRestorer(context.TODO(), repo, sn, false, progress) + res.SelectFilter = func(item string, dstpath string, node *restic.Node) (selectedForRestore bool, childMayBeSelected bool) { + return true, true + } + + tempdir := rtest.TempDir(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + err := res.RestoreTo(ctx, tempdir) + rtest.OK(t, err) + progress.Finish() + + const filesFinished = 4 + const filesTotal = filesFinished + const allBytesWritten = 10 + const allBytesTotal = allBytesWritten + rtest.Assert(t, mock.filesFinished == filesFinished, "filesFinished: expected %v, got %v", filesFinished, mock.filesFinished) + rtest.Assert(t, mock.filesTotal == filesTotal, "filesTotal: expected %v, got %v", filesTotal, mock.filesTotal) + rtest.Assert(t, mock.allBytesWritten == allBytesWritten, "allBytesWritten: expected %v, got %v", allBytesWritten, mock.allBytesWritten) + rtest.Assert(t, mock.allBytesTotal == allBytesTotal, "allBytesTotal: expected %v, got %v", allBytesTotal, mock.allBytesTotal) +}