diff --git a/internal/repository/repository.go b/internal/repository/repository.go index 41f22f307..3ed9f7afa 100644 --- a/internal/repository/repository.go +++ b/internal/repository/repository.go @@ -950,10 +950,10 @@ const maxUnusedRange = 4 * 1024 * 1024 // then LoadBlobsFromPack will abort and not retry it. The buf passed to the callback is only valid within // this specific call. The callback must not keep a reference to buf. func (r *Repository) LoadBlobsFromPack(ctx context.Context, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error { - return streamPack(ctx, r.Backend().Load, r.LoadBlob, r.key, packID, blobs, handleBlobFn) + return streamPack(ctx, r.Backend().Load, r.LoadBlob, r.getZstdDecoder(), r.key, packID, blobs, handleBlobFn) } -func streamPack(ctx context.Context, beLoad backendLoadFn, loadBlobFn loadBlobFn, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error { +func streamPack(ctx context.Context, beLoad backendLoadFn, loadBlobFn loadBlobFn, dec *zstd.Decoder, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error { if len(blobs) == 0 { // nothing to do return nil @@ -987,7 +987,7 @@ func streamPack(ctx context.Context, beLoad backendLoadFn, loadBlobFn loadBlobFn if split { // load everything up to the skipped file section - err := streamPackPart(ctx, beLoad, loadBlobFn, key, packID, blobs[lowerIdx:i], handleBlobFn) + err := streamPackPart(ctx, beLoad, loadBlobFn, dec, key, packID, blobs[lowerIdx:i], handleBlobFn) if err != nil { return err } @@ -996,10 +996,10 @@ func streamPack(ctx context.Context, beLoad backendLoadFn, loadBlobFn loadBlobFn lastPos = blobs[i].Offset + blobs[i].Length } // load remainder - return streamPackPart(ctx, beLoad, loadBlobFn, key, packID, blobs[lowerIdx:], handleBlobFn) + return streamPackPart(ctx, beLoad, loadBlobFn, dec, key, packID, blobs[lowerIdx:], handleBlobFn) } -func streamPackPart(ctx context.Context, beLoad backendLoadFn, loadBlobFn loadBlobFn, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error { +func streamPackPart(ctx context.Context, beLoad backendLoadFn, loadBlobFn loadBlobFn, dec *zstd.Decoder, key *crypto.Key, packID restic.ID, blobs []restic.Blob, handleBlobFn func(blob restic.BlobHandle, buf []byte, err error) error) error { h := backend.Handle{Type: restic.PackFile, Name: packID.String(), IsMetadata: false} dataStart := blobs[0].Offset @@ -1007,14 +1007,8 @@ func streamPackPart(ctx context.Context, beLoad backendLoadFn, loadBlobFn loadBl debug.Log("streaming pack %v (%d to %d bytes), blobs: %v", packID, dataStart, dataEnd, len(blobs)) - dec, err := zstd.NewReader(nil) - if err != nil { - panic(dec) - } - defer dec.Close() - data := make([]byte, int(dataEnd-dataStart)) - err = beLoad(ctx, h, int(dataEnd-dataStart), int64(dataStart), func(rd io.Reader) error { + err := beLoad(ctx, h, int(dataEnd-dataStart), int64(dataStart), func(rd io.Reader) error { _, cerr := io.ReadFull(rd, data) return cerr }) diff --git a/internal/repository/repository_internal_test.go b/internal/repository/repository_internal_test.go index 1f71b17de..16e6e8484 100644 --- a/internal/repository/repository_internal_test.go +++ b/internal/repository/repository_internal_test.go @@ -146,6 +146,12 @@ func TestStreamPack(t *testing.T) { } func testStreamPack(t *testing.T, version uint) { + dec, err := zstd.NewReader(nil) + if err != nil { + panic(dec) + } + defer dec.Close() + // always use the same key for deterministic output key := testKey(t) @@ -270,7 +276,7 @@ func testStreamPack(t *testing.T, version uint) { loadCalls = 0 shortFirstLoad = test.shortFirstLoad - err := streamPack(ctx, load, nil, &key, restic.ID{}, test.blobs, handleBlob) + err := streamPack(ctx, load, nil, dec, &key, restic.ID{}, test.blobs, handleBlob) if err != nil { t.Fatal(err) } @@ -333,7 +339,7 @@ func testStreamPack(t *testing.T, version uint) { return err } - err := streamPack(ctx, load, nil, &key, restic.ID{}, test.blobs, handleBlob) + err := streamPack(ctx, load, nil, dec, &key, restic.ID{}, test.blobs, handleBlob) if err == nil { t.Fatalf("wanted error %v, got nil", test.err) } @@ -456,6 +462,12 @@ func testKey(t *testing.T) crypto.Key { } func TestStreamPackFallback(t *testing.T) { + dec, err := zstd.NewReader(nil) + if err != nil { + panic(dec) + } + defer dec.Close() + test := func(t *testing.T, failLoad bool) { key := testKey(t) ctx, cancel := context.WithCancel(context.Background()) @@ -503,7 +515,7 @@ func TestStreamPackFallback(t *testing.T) { return err } - err := streamPack(ctx, loadPack, loadBlob, &key, restic.ID{}, blobs, handleBlob) + err := streamPack(ctx, loadPack, loadBlob, dec, &key, restic.ID{}, blobs, handleBlob) rtest.OK(t, err) rtest.Assert(t, blobOK, "blob failed to load") }