mirror of https://github.com/restic/restic.git
176 lines
3.5 KiB
Go
176 lines
3.5 KiB
Go
package backend
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/md5"
|
|
"hash"
|
|
"io"
|
|
"math/rand"
|
|
"os"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/restic/restic/internal/test"
|
|
)
|
|
|
|
func TestByteReader(t *testing.T) {
|
|
buf := []byte("foobar")
|
|
for _, hasher := range []hash.Hash{nil, md5.New()} {
|
|
fn := func() RewindReader {
|
|
return NewByteReader(buf, hasher)
|
|
}
|
|
testRewindReader(t, fn, buf)
|
|
}
|
|
}
|
|
|
|
func TestFileReader(t *testing.T) {
|
|
buf := []byte("foobar")
|
|
|
|
d := test.TempDir(t)
|
|
filename := filepath.Join(d, "file-reader-test")
|
|
err := os.WriteFile(filename, buf, 0600)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
f, err := os.Open(filename)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
defer func() {
|
|
err := f.Close()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
}()
|
|
|
|
for _, hasher := range []hash.Hash{nil, md5.New()} {
|
|
fn := func() RewindReader {
|
|
var hash []byte
|
|
if hasher != nil {
|
|
// must never fail according to interface
|
|
_, err := hasher.Write(buf)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
hash = hasher.Sum(nil)
|
|
}
|
|
rd, err := NewFileReader(f, hash)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return rd
|
|
}
|
|
|
|
testRewindReader(t, fn, buf)
|
|
}
|
|
}
|
|
|
|
func testRewindReader(t *testing.T, fn func() RewindReader, data []byte) {
|
|
seed := time.Now().UnixNano()
|
|
t.Logf("seed is %d", seed)
|
|
rnd := rand.New(rand.NewSource(seed))
|
|
|
|
type ReaderTestFunc func(t testing.TB, r RewindReader, data []byte)
|
|
var tests = []ReaderTestFunc{
|
|
func(t testing.TB, rd RewindReader, data []byte) {
|
|
if rd.Length() != int64(len(data)) {
|
|
t.Fatalf("wrong length returned, want %d, got %d", int64(len(data)), rd.Length())
|
|
}
|
|
|
|
buf := make([]byte, len(data))
|
|
_, err := io.ReadFull(rd, buf)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !bytes.Equal(buf, data) {
|
|
t.Fatalf("wrong data returned")
|
|
}
|
|
|
|
if rd.Length() != int64(len(data)) {
|
|
t.Fatalf("wrong length returned, want %d, got %d", int64(len(data)), rd.Length())
|
|
}
|
|
|
|
err = rd.Rewind()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if rd.Length() != int64(len(data)) {
|
|
t.Fatalf("wrong length returned, want %d, got %d", int64(len(data)), rd.Length())
|
|
}
|
|
|
|
buf2 := make([]byte, int64(len(data)))
|
|
_, err = io.ReadFull(rd, buf2)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !bytes.Equal(buf2, data) {
|
|
t.Fatalf("wrong data returned")
|
|
}
|
|
|
|
if rd.Length() != int64(len(data)) {
|
|
t.Fatalf("wrong length returned, want %d, got %d", int64(len(data)), rd.Length())
|
|
}
|
|
|
|
if rd.Hash() != nil {
|
|
hasher := md5.New()
|
|
// must never fail according to interface
|
|
_, _ = hasher.Write(buf2)
|
|
if !bytes.Equal(rd.Hash(), hasher.Sum(nil)) {
|
|
t.Fatal("hash does not match data")
|
|
}
|
|
}
|
|
},
|
|
func(t testing.TB, rd RewindReader, data []byte) {
|
|
// read first bytes
|
|
buf := make([]byte, rnd.Intn(len(data)))
|
|
_, err := io.ReadFull(rd, buf)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !bytes.Equal(buf, data[:len(buf)]) {
|
|
t.Fatalf("wrong data returned")
|
|
}
|
|
|
|
err = rd.Rewind()
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
buf2 := make([]byte, rnd.Intn(len(data)))
|
|
_, err = io.ReadFull(rd, buf2)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !bytes.Equal(buf2, data[:len(buf2)]) {
|
|
t.Fatalf("wrong data returned")
|
|
}
|
|
|
|
// read remainder
|
|
buf3 := make([]byte, len(data)-len(buf2))
|
|
_, err = io.ReadFull(rd, buf3)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
if !bytes.Equal(buf3, data[len(buf2):]) {
|
|
t.Fatalf("wrong data returned")
|
|
}
|
|
},
|
|
}
|
|
|
|
for _, test := range tests {
|
|
t.Run("", func(t *testing.T) {
|
|
rd := fn()
|
|
test(t, rd, data)
|
|
})
|
|
}
|
|
}
|