diff --git a/key.go b/key.go index 9339a6f7b..317b67c37 100644 --- a/key.go +++ b/key.go @@ -1,7 +1,6 @@ package restic import ( - "bytes" "crypto/aes" "crypto/cipher" "crypto/hmac" @@ -15,6 +14,7 @@ import ( "io/ioutil" "os" "os/user" + "sync" "time" "github.com/restic/restic/backend" @@ -338,6 +338,7 @@ type encryptWriter struct { iv []byte wroteIV bool h hash.Hash + s cipher.Stream w io.Writer origWr io.Writer err error // remember error writing iv @@ -353,6 +354,13 @@ func (e *encryptWriter) Close() error { return nil } +const encryptWriterChunkSize = 512 * 1024 // 512 KiB +var encryptWriterBufPool = sync.Pool{ + New: func() interface{} { + return make([]byte, encryptWriterChunkSize) + }, +} + func (e *encryptWriter) Write(p []byte) (int, error) { // write iv first if !e.wroteIV { @@ -364,13 +372,34 @@ func (e *encryptWriter) Write(p []byte) (int, error) { return 0, e.err } - n, err := e.w.Write(p) - if err != nil { - e.err = err - return n, err + buf := encryptWriterBufPool.Get().([]byte) + defer encryptWriterBufPool.Put(buf) + + written := 0 + for len(p) > 0 { + max := len(p) + if max > encryptWriterChunkSize { + max = encryptWriterChunkSize + } + + e.s.XORKeyStream(buf, p[:max]) + n, err := e.w.Write(buf[:max]) + if n != max { + if err == nil { // should never happen + err = io.ErrShortWrite + } + } + + written += n + p = p[n:] + + if err != nil { + e.err = err + return written, err + } } - return n, nil + return written, nil } func (k *Key) encryptTo(ks *keys, wr io.Writer) io.WriteCloser { @@ -396,10 +425,8 @@ func (k *Key) encryptTo(ks *keys, wr io.Writer) io.WriteCloser { panic(fmt.Sprintf("unable to create cipher: %v", err)) } - ew.w = cipher.StreamWriter{ - S: cipher.NewCTR(c, ew.iv), - W: io.MultiWriter(ew.h, wr), - } + ew.s = cipher.NewCTR(c, ew.iv) + ew.w = io.MultiWriter(ew.h, wr) return ew } @@ -474,6 +501,34 @@ func (k *Key) DecryptUser(ciphertext []byte) ([]byte, error) { return k.decrypt(k.user, ciphertext) } +type decryptReader struct { + buf []byte + pos int +} + +func (d *decryptReader) Read(dst []byte) (int, error) { + if d.buf == nil { + return 0, io.EOF + } + + if len(dst) == 0 { + return 0, nil + } + + remaining := len(d.buf) - d.pos + if len(dst) >= remaining { + n := copy(dst, d.buf[d.pos:]) + FreeChunkBuf("decryptReader", d.buf) + d.buf = nil + return n, io.EOF + } + + n := copy(dst, d.buf[d.pos:d.pos+len(dst)]) + d.pos += n + + return n, nil +} + // decryptFrom verifies and decrypts the ciphertext read from rd with ks and // makes it available on the returned Reader. Ciphertext must be in the form IV // || Ciphertext || HMAC. In order to correctly verify the ciphertext, rd is @@ -481,14 +536,28 @@ func (k *Key) DecryptUser(ciphertext []byte) ([]byte, error) { // afterwards. If an HMAC verification failure is observed, it is returned // immediately. func (k *Key) decryptFrom(ks *keys, rd io.Reader) (io.Reader, error) { - ciphertext, err := ioutil.ReadAll(rd) + ciphertext := GetChunkBuf("decryptReader") + ciphertext = ciphertext[0:cap(ciphertext)] + n, err := io.ReadFull(rd, ciphertext) + if err != io.ErrUnexpectedEOF { + // read remaining data + buf, e := ioutil.ReadAll(rd) + ciphertext = append(ciphertext, buf...) + n += len(buf) + err = e + } else { + err = nil + } + if err != nil { return nil, err } + ciphertext = ciphertext[:n] + // check for plausible length if len(ciphertext) < ivSize+hmacSize { - panic("trying to decryipt invalid data: ciphertext too small") + panic("trying to decrypt invalid data: ciphertext too small") } hm := hmac.New(sha256.New, ks.Sign) @@ -498,7 +567,7 @@ func (k *Key) decryptFrom(ks *keys, rd io.Reader) (io.Reader, error) { ciphertext, mac := ciphertext[:l], ciphertext[l:] // calculate new hmac - n, err := hm.Write(ciphertext) + n, err = hm.Write(ciphertext) if err != nil || n != len(ciphertext) { panic(fmt.Sprintf("unable to calculate hmac of ciphertext, err %v", err)) } @@ -519,12 +588,10 @@ func (k *Key) decryptFrom(ks *keys, rd io.Reader) (io.Reader, error) { panic(fmt.Sprintf("unable to create cipher: %v", err)) } - r := cipher.StreamReader{ - S: cipher.NewCTR(c, iv), - R: bytes.NewReader(ciphertext), - } + stream := cipher.NewCTR(c, iv) + stream.XORKeyStream(ciphertext, ciphertext) - return r, nil + return &decryptReader{buf: ciphertext}, nil } // DecryptFrom verifies and decrypts the ciphertext read from rd and makes it