diff --git a/internal/restorer/restorer_test.go b/internal/restorer/restorer_test.go index 25ce668db..b1fb0ff92 100644 --- a/internal/restorer/restorer_test.go +++ b/internal/restorer/restorer_test.go @@ -895,6 +895,31 @@ func TestRestorerSparseFiles(t *testing.T) { len(zeros), blocks, 100*sparsity) } +func saveSnapshotsAndOverwrite(t *testing.T, baseSnapshot Snapshot, overwriteSnapshot Snapshot, options Options) string { + repo := repository.TestRepository(t) + tempdir := filepath.Join(rtest.TempDir(t), "target") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // base snapshot + sn, id := saveSnapshot(t, repo, baseSnapshot, noopGetGenericAttributes) + t.Logf("base snapshot saved as %v", id.Str()) + + res := NewRestorer(repo, sn, options) + rtest.OK(t, res.RestoreTo(ctx, tempdir)) + + // overwrite snapshot + sn, id = saveSnapshot(t, repo, overwriteSnapshot, noopGetGenericAttributes) + t.Logf("overwrite snapshot saved as %v", id.Str()) + res = NewRestorer(repo, sn, options) + rtest.OK(t, res.RestoreTo(ctx, tempdir)) + + _, err := res.VerifyFiles(ctx, tempdir) + rtest.OK(t, err) + + return tempdir +} + func TestRestorerSparseOverwrite(t *testing.T) { baseSnapshot := Snapshot{ Nodes: map[string]Node{ @@ -908,29 +933,7 @@ func TestRestorerSparseOverwrite(t *testing.T) { }, } - repo := repository.TestRepository(t) - tempdir := filepath.Join(rtest.TempDir(t), "target") - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // base snapshot - sn, id := saveSnapshot(t, repo, baseSnapshot, noopGetGenericAttributes) - t.Logf("base snapshot saved as %v", id.Str()) - - res := NewRestorer(repo, sn, Options{Sparse: true}) - err := res.RestoreTo(ctx, tempdir) - rtest.OK(t, err) - - // sparse snapshot - sn, id = saveSnapshot(t, repo, sparseSnapshot, noopGetGenericAttributes) - t.Logf("base snapshot saved as %v", id.Str()) - - res = NewRestorer(repo, sn, Options{Sparse: true, Overwrite: OverwriteAlways}) - err = res.RestoreTo(ctx, tempdir) - rtest.OK(t, err) - files, err := res.VerifyFiles(ctx, tempdir) - rtest.OK(t, err) - rtest.Equals(t, 1, files, "unexpected number of verified files") + saveSnapshotsAndOverwrite(t, baseSnapshot, sparseSnapshot, Options{Sparse: true, Overwrite: OverwriteAlways}) } func TestRestorerOverwriteBehavior(t *testing.T) { @@ -993,26 +996,7 @@ func TestRestorerOverwriteBehavior(t *testing.T) { for _, test := range tests { t.Run("", func(t *testing.T) { - repo := repository.TestRepository(t) - tempdir := filepath.Join(rtest.TempDir(t), "target") - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // base snapshot - sn, id := saveSnapshot(t, repo, baseSnapshot, noopGetGenericAttributes) - t.Logf("base snapshot saved as %v", id.Str()) - - res := NewRestorer(repo, sn, Options{}) - rtest.OK(t, res.RestoreTo(ctx, tempdir)) - - // overwrite snapshot - sn, id = saveSnapshot(t, repo, overwriteSnapshot, noopGetGenericAttributes) - t.Logf("overwrite snapshot saved as %v", id.Str()) - res = NewRestorer(repo, sn, Options{Overwrite: test.Overwrite}) - rtest.OK(t, res.RestoreTo(ctx, tempdir)) - - _, err := res.VerifyFiles(ctx, tempdir) - rtest.OK(t, err) + tempdir := saveSnapshotsAndOverwrite(t, baseSnapshot, overwriteSnapshot, Options{Overwrite: test.Overwrite}) for filename, content := range test.Files { data, err := os.ReadFile(filepath.Join(tempdir, filepath.FromSlash(filename)))