cache: Always use cached file if it exists

A file is always cached whole. Thus, any out-of-bounds access will also
fail when directed at the backend. To handle the case in which the cached
file is broken, the caller must call Cache.Forget(h) for the file in
question.
This commit is contained in:
Michael Eischer 2024-05-09 18:21:53 +02:00
parent 8cce06d915
commit 97a307df1a
3 changed files with 29 additions and 24 deletions

View File

@ -40,7 +40,8 @@ func (b *Backend) Remove(ctx context.Context, h backend.Handle) error {
return err return err
} }
return b.Cache.remove(h) err = b.Cache.remove(h)
return err
} }
func autoCacheTypes(h backend.Handle) bool { func autoCacheTypes(h backend.Handle) bool {
@ -133,9 +134,9 @@ func (b *Backend) cacheFile(ctx context.Context, h backend.Handle) error {
// loadFromCache will try to load the file from the cache. // loadFromCache will try to load the file from the cache.
func (b *Backend) loadFromCache(h backend.Handle, length int, offset int64, consumer func(rd io.Reader) error) (bool, error) { func (b *Backend) loadFromCache(h backend.Handle, length int, offset int64, consumer func(rd io.Reader) error) (bool, error) {
rd, err := b.Cache.load(h, length, offset) rd, inCache, err := b.Cache.load(h, length, offset)
if err != nil { if err != nil {
return false, err return inCache, err
} }
err = consumer(rd) err = consumer(rd)

View File

@ -34,46 +34,48 @@ func (c *Cache) canBeCached(t backend.FileType) bool {
// load returns a reader that yields the contents of the file with the // load returns a reader that yields the contents of the file with the
// given handle. rd must be closed after use. If an error is returned, the // given handle. rd must be closed after use. If an error is returned, the
// ReadCloser is nil. // ReadCloser is nil. The bool return value indicates whether the requested
func (c *Cache) load(h backend.Handle, length int, offset int64) (io.ReadCloser, error) { // file exists in the cache. It can be true even when no reader is returned
// because length or offset are out of bounds
func (c *Cache) load(h backend.Handle, length int, offset int64) (io.ReadCloser, bool, error) {
debug.Log("Load(%v, %v, %v) from cache", h, length, offset) debug.Log("Load(%v, %v, %v) from cache", h, length, offset)
if !c.canBeCached(h.Type) { if !c.canBeCached(h.Type) {
return nil, errors.New("cannot be cached") return nil, false, errors.New("cannot be cached")
} }
f, err := fs.Open(c.filename(h)) f, err := fs.Open(c.filename(h))
if err != nil { if err != nil {
return nil, errors.WithStack(err) return nil, false, errors.WithStack(err)
} }
fi, err := f.Stat() fi, err := f.Stat()
if err != nil { if err != nil {
_ = f.Close() _ = f.Close()
return nil, errors.WithStack(err) return nil, true, errors.WithStack(err)
} }
size := fi.Size() size := fi.Size()
if size <= int64(crypto.CiphertextLength(0)) { if size <= int64(crypto.CiphertextLength(0)) {
_ = f.Close() _ = f.Close()
return nil, errors.Errorf("cached file %v is truncated", h) return nil, true, errors.Errorf("cached file %v is truncated", h)
} }
if size < offset+int64(length) { if size < offset+int64(length) {
_ = f.Close() _ = f.Close()
return nil, errors.Errorf("cached file %v is too short", h) return nil, true, errors.Errorf("cached file %v is too short", h)
} }
if offset > 0 { if offset > 0 {
if _, err = f.Seek(offset, io.SeekStart); err != nil { if _, err = f.Seek(offset, io.SeekStart); err != nil {
_ = f.Close() _ = f.Close()
return nil, err return nil, true, err
} }
} }
if length <= 0 { if length <= 0 {
return f, nil return f, true, nil
} }
return util.LimitReadCloser(f, int64(length)), nil return util.LimitReadCloser(f, int64(length)), true, nil
} }
// save saves a file in the cache. // save saves a file in the cache.

View File

@ -14,7 +14,7 @@ import (
"github.com/restic/restic/internal/errors" "github.com/restic/restic/internal/errors"
"github.com/restic/restic/internal/fs" "github.com/restic/restic/internal/fs"
"github.com/restic/restic/internal/restic" "github.com/restic/restic/internal/restic"
"github.com/restic/restic/internal/test" rtest "github.com/restic/restic/internal/test"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
) )
@ -22,7 +22,7 @@ import (
func generateRandomFiles(t testing.TB, tpe backend.FileType, c *Cache) restic.IDSet { func generateRandomFiles(t testing.TB, tpe backend.FileType, c *Cache) restic.IDSet {
ids := restic.NewIDSet() ids := restic.NewIDSet()
for i := 0; i < rand.Intn(15)+10; i++ { for i := 0; i < rand.Intn(15)+10; i++ {
buf := test.Random(rand.Int(), 1<<19) buf := rtest.Random(rand.Int(), 1<<19)
id := restic.Hash(buf) id := restic.Hash(buf)
h := backend.Handle{Type: tpe, Name: id.String()} h := backend.Handle{Type: tpe, Name: id.String()}
@ -48,10 +48,11 @@ func randomID(s restic.IDSet) restic.ID {
} }
func load(t testing.TB, c *Cache, h backend.Handle) []byte { func load(t testing.TB, c *Cache, h backend.Handle) []byte {
rd, err := c.load(h, 0, 0) rd, inCache, err := c.load(h, 0, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
rtest.Equals(t, true, inCache, "expected inCache flag to be true")
if rd == nil { if rd == nil {
t.Fatalf("load() returned nil reader") t.Fatalf("load() returned nil reader")
@ -144,7 +145,7 @@ func TestFileLoad(t *testing.T) {
c := TestNewCache(t) c := TestNewCache(t)
// save about 5 MiB of data in the cache // save about 5 MiB of data in the cache
data := test.Random(rand.Int(), 5234142) data := rtest.Random(rand.Int(), 5234142)
id := restic.ID{} id := restic.ID{}
copy(id[:], data) copy(id[:], data)
h := backend.Handle{ h := backend.Handle{
@ -169,10 +170,11 @@ func TestFileLoad(t *testing.T) {
for _, test := range tests { for _, test := range tests {
t.Run(fmt.Sprintf("%v/%v", test.length, test.offset), func(t *testing.T) { t.Run(fmt.Sprintf("%v/%v", test.length, test.offset), func(t *testing.T) {
rd, err := c.load(h, test.length, test.offset) rd, inCache, err := c.load(h, test.length, test.offset)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
rtest.Equals(t, true, inCache, "expected inCache flag to be true")
buf, err := io.ReadAll(rd) buf, err := io.ReadAll(rd)
if err != nil { if err != nil {
@ -225,7 +227,7 @@ func TestFileSaveConcurrent(t *testing.T) {
var ( var (
c = TestNewCache(t) c = TestNewCache(t)
data = test.Random(1, 10000) data = rtest.Random(1, 10000)
g errgroup.Group g errgroup.Group
id restic.ID id restic.ID
) )
@ -245,7 +247,7 @@ func TestFileSaveConcurrent(t *testing.T) {
// ensure is ENOENT or nil error. // ensure is ENOENT or nil error.
time.Sleep(time.Duration(100+rand.Intn(200)) * time.Millisecond) time.Sleep(time.Duration(100+rand.Intn(200)) * time.Millisecond)
f, err := c.load(h, 0, 0) f, _, err := c.load(h, 0, 0)
t.Logf("Load error: %v", err) t.Logf("Load error: %v", err)
switch { switch {
case err == nil: case err == nil:
@ -264,17 +266,17 @@ func TestFileSaveConcurrent(t *testing.T) {
}) })
} }
test.OK(t, g.Wait()) rtest.OK(t, g.Wait())
saved := load(t, c, h) saved := load(t, c, h)
test.Equals(t, data, saved) rtest.Equals(t, data, saved)
} }
func TestFileSaveAfterDamage(t *testing.T) { func TestFileSaveAfterDamage(t *testing.T) {
c := TestNewCache(t) c := TestNewCache(t)
test.OK(t, fs.RemoveAll(c.path)) rtest.OK(t, fs.RemoveAll(c.path))
// save a few bytes of data in the cache // save a few bytes of data in the cache
data := test.Random(123456789, 42) data := rtest.Random(123456789, 42)
id := restic.Hash(data) id := restic.Hash(data)
h := backend.Handle{ h := backend.Handle{
Type: restic.PackFile, Type: restic.PackFile,