package backend import ( "bytes" "crypto/md5" "hash" "io" "math/rand" "os" "path/filepath" "testing" "time" "github.com/restic/restic/internal/test" ) func TestByteReader(t *testing.T) { buf := []byte("foobar") for _, hasher := range []hash.Hash{nil, md5.New()} { fn := func() RewindReader { return NewByteReader(buf, hasher) } testRewindReader(t, fn, buf) } } func TestFileReader(t *testing.T) { buf := []byte("foobar") d := test.TempDir(t) filename := filepath.Join(d, "file-reader-test") err := os.WriteFile(filename, buf, 0600) if err != nil { t.Fatal(err) } f, err := os.Open(filename) if err != nil { t.Fatal(err) } defer func() { err := f.Close() if err != nil { t.Fatal(err) } }() for _, hasher := range []hash.Hash{nil, md5.New()} { fn := func() RewindReader { var hash []byte if hasher != nil { // must never fail according to interface _, err := hasher.Write(buf) if err != nil { panic(err) } hash = hasher.Sum(nil) } rd, err := NewFileReader(f, hash) if err != nil { t.Fatal(err) } return rd } testRewindReader(t, fn, buf) } } func testRewindReader(t *testing.T, fn func() RewindReader, data []byte) { seed := time.Now().UnixNano() t.Logf("seed is %d", seed) rnd := rand.New(rand.NewSource(seed)) type ReaderTestFunc func(t testing.TB, r RewindReader, data []byte) var tests = []ReaderTestFunc{ func(t testing.TB, rd RewindReader, data []byte) { if rd.Length() != int64(len(data)) { t.Fatalf("wrong length returned, want %d, got %d", int64(len(data)), rd.Length()) } buf := make([]byte, len(data)) _, err := io.ReadFull(rd, buf) if err != nil { t.Fatal(err) } if !bytes.Equal(buf, data) { t.Fatalf("wrong data returned") } if rd.Length() != int64(len(data)) { t.Fatalf("wrong length returned, want %d, got %d", int64(len(data)), rd.Length()) } err = rd.Rewind() if err != nil { t.Fatal(err) } if rd.Length() != int64(len(data)) { t.Fatalf("wrong length returned, want %d, got %d", int64(len(data)), rd.Length()) } buf2 := make([]byte, int64(len(data))) _, err = io.ReadFull(rd, buf2) if err != nil { t.Fatal(err) } if !bytes.Equal(buf2, data) { t.Fatalf("wrong data returned") } if rd.Length() != int64(len(data)) { t.Fatalf("wrong length returned, want %d, got %d", int64(len(data)), rd.Length()) } if rd.Hash() != nil { hasher := md5.New() // must never fail according to interface _, _ = hasher.Write(buf2) if !bytes.Equal(rd.Hash(), hasher.Sum(nil)) { t.Fatal("hash does not match data") } } }, func(t testing.TB, rd RewindReader, data []byte) { // read first bytes buf := make([]byte, rnd.Intn(len(data))) _, err := io.ReadFull(rd, buf) if err != nil { t.Fatal(err) } if !bytes.Equal(buf, data[:len(buf)]) { t.Fatalf("wrong data returned") } err = rd.Rewind() if err != nil { t.Fatal(err) } buf2 := make([]byte, rnd.Intn(len(data))) _, err = io.ReadFull(rd, buf2) if err != nil { t.Fatal(err) } if !bytes.Equal(buf2, data[:len(buf2)]) { t.Fatalf("wrong data returned") } // read remainder buf3 := make([]byte, len(data)-len(buf2)) _, err = io.ReadFull(rd, buf3) if err != nil { t.Fatal(err) } if !bytes.Equal(buf3, data[len(buf2):]) { t.Fatalf("wrong data returned") } }, } for _, test := range tests { t.Run("", func(t *testing.T) { rd := fn() test(t, rd, data) }) } }