diff --git a/backend/reader_test.go b/backend/reader_test.go index 708fbceb2..7b05424c5 100644 --- a/backend/reader_test.go +++ b/backend/reader_test.go @@ -65,7 +65,7 @@ func TestHashingReader(t *testing.T) { assert(t, n == int64(size), "HashAppendReader: invalid number of bytes read: got %d, expected %d", - n, size+len(expectedHash)) + n, size) resultingHash := rd.Sum(nil) assert(t, bytes.Equal(expectedHash[:], resultingHash), diff --git a/backend/writer.go b/backend/writer.go new file mode 100644 index 000000000..55ffb3279 --- /dev/null +++ b/backend/writer.go @@ -0,0 +1,63 @@ +package backend + +import ( + "errors" + "hash" + "io" +) + +type HashAppendWriter struct { + w io.Writer + origWr io.Writer + h hash.Hash + sum []byte + closed bool +} + +func NewHashAppendWriter(w io.Writer, h hash.Hash) *HashAppendWriter { + return &HashAppendWriter{ + h: h, + w: io.MultiWriter(w, h), + origWr: w, + sum: make([]byte, 0, h.Size()), + } +} + +func (h *HashAppendWriter) Close() error { + if !h.closed { + h.closed = true + + _, err := h.origWr.Write(h.h.Sum(nil)) + return err + } + + return nil +} + +func (h *HashAppendWriter) Write(p []byte) (n int, err error) { + if !h.closed { + return h.w.Write(p) + } + + return 0, errors.New("Write() called on closed HashAppendWriter") +} + +type HashingWriter struct { + w io.Writer + h hash.Hash +} + +func NewHashingWriter(w io.Writer, h hash.Hash) *HashingWriter { + return &HashingWriter{ + h: h, + w: io.MultiWriter(w, h), + } +} + +func (h *HashingWriter) Write(p []byte) (int, error) { + return h.w.Write(p) +} + +func (h *HashingWriter) Sum(d []byte) []byte { + return h.h.Sum(d) +} diff --git a/backend/writer_test.go b/backend/writer_test.go new file mode 100644 index 000000000..bf4655c25 --- /dev/null +++ b/backend/writer_test.go @@ -0,0 +1,76 @@ +package backend_test + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "io" + "io/ioutil" + "testing" + + "github.com/restic/restic/backend" +) + +func TestHashAppendWriter(t *testing.T) { + tests := []int{5, 23, 2<<18 + 23, 1 << 20} + + for _, size := range tests { + data := make([]byte, size) + _, err := io.ReadFull(rand.Reader, data) + if err != nil { + t.Fatalf("ReadFull: %v", err) + } + + expectedHash := sha256.Sum256(data) + + target := bytes.NewBuffer(nil) + wr := backend.NewHashAppendWriter(target, sha256.New()) + + _, err = wr.Write(data) + ok(t, err) + ok(t, wr.Close()) + + assert(t, len(target.Bytes()) == size+len(expectedHash), + "HashAppendWriter: invalid number of bytes written: got %d, expected %d", + len(target.Bytes()), size+len(expectedHash)) + + r := target.Bytes() + resultingHash := r[len(r)-len(expectedHash):] + assert(t, bytes.Equal(expectedHash[:], resultingHash), + "HashAppendWriter: hashes do not match: expected %02x, got %02x", + expectedHash, resultingHash) + + // write again, this must return an error + _, err = wr.Write([]byte{23}) + assert(t, err != nil, + "HashAppendWriter: Write() after Close() did not return an error") + } +} + +func TestHashingWriter(t *testing.T) { + tests := []int{5, 23, 2<<18 + 23, 1 << 20} + + for _, size := range tests { + data := make([]byte, size) + _, err := io.ReadFull(rand.Reader, data) + if err != nil { + t.Fatalf("ReadFull: %v", err) + } + + expectedHash := sha256.Sum256(data) + + wr := backend.NewHashingWriter(ioutil.Discard, sha256.New()) + + n, err := io.Copy(wr, bytes.NewReader(data)) + ok(t, err) + + assert(t, n == int64(size), + "HashAppendWriter: invalid number of bytes written: got %d, expected %d", + n, size) + + resultingHash := wr.Sum(nil) + assert(t, bytes.Equal(expectedHash[:], resultingHash), + "HashAppendWriter: hashes do not match: expected %02x, got %02x", + expectedHash, resultingHash) + } +}