diff --git a/src/restic/backend/rest/rest.go b/src/restic/backend/rest/rest.go index 6ed05965f..fe931a6b9 100644 --- a/src/restic/backend/rest/rest.go +++ b/src/restic/backend/rest/rest.go @@ -80,6 +80,10 @@ func (b *restBackend) Save(h restic.Handle, rd io.Reader) (err error) { return err } + // make sure that client.Post() cannot close the reader by wrapping it in + // backend.Closer, which has a noop method. + rd = backend.Closer{Reader: rd} + <-b.connChan resp, err := b.client.Post(restPath(b.url, h), "binary/octet-stream", rd) b.connChan <- struct{}{} diff --git a/src/restic/backend/test/tests.go b/src/restic/backend/test/tests.go index 54583b0ce..d9241ec2b 100644 --- a/src/restic/backend/test/tests.go +++ b/src/restic/backend/test/tests.go @@ -6,8 +6,10 @@ import ( "io" "io/ioutil" "math/rand" + "os" "reflect" "restic" + "restic/errors" "sort" "strings" "testing" @@ -271,6 +273,16 @@ func TestLoad(t testing.TB) { test.OK(t, b.Remove(restic.DataFile, id.String())) } +type errorCloser struct { + io.Reader + t testing.TB +} + +func (ec errorCloser) Close() error { + ec.t.Error("forbidden method close was called") + return errors.New("forbidden method close was called") +} + // TestSave tests saving data in the backend. func TestSave(t testing.TB) { b := open(t) @@ -312,6 +324,46 @@ func TestSave(t testing.TB) { t.Fatalf("error removing item: %v", err) } } + + // test saving from a tempfile + tmpfile, err := ioutil.TempFile("", "restic-backend-save-test-") + if err != nil { + t.Fatal(err) + } + + length := rand.Intn(1<<23) + 200000 + data := test.Random(23, length) + copy(id[:], data) + + if _, err = tmpfile.Write(data); err != nil { + t.Fatal(err) + } + + if _, err = tmpfile.Seek(0, 0); err != nil { + t.Fatal(err) + } + + h := restic.Handle{Type: restic.DataFile, Name: id.String()} + + // wrap the tempfile in an errorCloser, so we can detect if the backend + // closes the reader + err = b.Save(h, errorCloser{t: t, Reader: tmpfile}) + if err != nil { + t.Fatal(err) + } + + if err = tmpfile.Close(); err != nil { + t.Fatal(err) + } + + if err = os.Remove(tmpfile.Name()); err != nil { + t.Fatal(err) + } + + err = b.Remove(h.Type, h.Name) + if err != nil { + t.Fatalf("error removing item: %v", err) + } } var filenameTests = []struct { diff --git a/src/restic/repository/packer_manager.go b/src/restic/repository/packer_manager.go index 8e9327e3e..e3f49f389 100644 --- a/src/restic/repository/packer_manager.go +++ b/src/restic/repository/packer_manager.go @@ -29,20 +29,6 @@ type Packer struct { tmpfile *os.File } -// Finalize finalizes the pack.Packer and then closes the tempfile. -func (p *Packer) Finalize() (uint, error) { - n, err := p.Packer.Finalize() - if err != nil { - return n, err - } - - if err = p.tmpfile.Close(); err != nil { - return n, err - } - - return n, nil -} - // packerManager keeps a list of open packs and creates new on demand. type packerManager struct { be Saver diff --git a/src/restic/repository/packer_manager_test.go b/src/restic/repository/packer_manager_test.go index 465fbadcb..37718a5ea 100644 --- a/src/restic/repository/packer_manager_test.go +++ b/src/restic/repository/packer_manager_test.go @@ -47,25 +47,19 @@ func randomID(rd io.Reader) restic.ID { const maxBlobSize = 1 << 20 -func saveFile(t testing.TB, be Saver, filename string, id restic.ID) { - f, err := os.Open(filename) - if err != nil { - t.Fatal(err) - } - +func saveFile(t testing.TB, be Saver, f *os.File, id restic.ID) { h := restic.Handle{Type: restic.DataFile, Name: id.String()} t.Logf("save file %v", h) - if err = be.Save(h, f); err != nil { + if err := be.Save(h, f); err != nil { t.Fatal(err) } - if err = f.Close(); err != nil { + if err := f.Close(); err != nil { t.Fatal(err) } - err = os.Remove(filename) - if err != nil { + if err := os.Remove(f.Name()); err != nil { t.Fatal(err) } } @@ -104,8 +98,12 @@ func fillPacks(t testing.TB, rnd *randReader, be Saver, pm *packerManager, buf [ t.Fatal(err) } + if _, err = packer.tmpfile.Seek(0, 0); err != nil { + t.Fatal(err) + } + packID := restic.IDFromHash(packer.hw.Sum(nil)) - saveFile(t, be, packer.tmpfile.Name(), packID) + saveFile(t, be, packer.tmpfile, packID) } return bytes @@ -121,7 +119,7 @@ func flushRemainingPacks(t testing.TB, rnd *randReader, be Saver, pm *packerMana bytes += int(n) packID := restic.IDFromHash(packer.hw.Sum(nil)) - saveFile(t, be, packer.tmpfile.Name(), packID) + saveFile(t, be, packer.tmpfile, packID) } } diff --git a/src/restic/test/helpers.go b/src/restic/test/helpers.go index 072cc4db9..4e19000e8 100644 --- a/src/restic/test/helpers.go +++ b/src/restic/test/helpers.go @@ -79,7 +79,7 @@ func Random(seed, count int) []byte { for j := range data { cur := i + j - if len(p) >= cur { + if cur >= len(p) { break } p[cur] = data[j]