mirror of https://github.com/borgbackup/borg.git

RepositoryCache: checksum decrypted cache

Author: Marian Beermann, 2017-06-03 12:14:17 +02:00
commit b544af2af1 (parent 5b3667b617)
2 changed files with 97 additions and 24 deletions
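
In short: entries in the decrypted repository cache gain an integrity checksum. pack() now prepends an xxh64 digest of the LZ4-compressed payload to the 32-bit csize header (the struct format grows from '=I' to '=I8s'), and unpack() recomputes the digest on every cache read. On a mismatch it raises the new RepositoryCache.InvalidateCacheEntry exception, which get_many() catches to evict the corrupted cache file, refetch the object from the repository, and count the event in a new checksum_errors statistic reported by close().

A minimal stand-alone sketch of the new entry layout; this is not borg code: zlib stands in for borg's LZ4 wrapper and an 8-byte blake2b digest stands in for xxh64, which has no stdlib equivalent.

import hashlib
import struct
import zlib

CACHE_STRUCT = struct.Struct('=I8s')  # 32-bit csize + 8-byte digest, as in the diff below

def digest8(data):
    # stand-in for xxh64; any fast 8-byte digest illustrates the scheme
    return hashlib.blake2b(data, digest_size=8).digest()

def pack(csize, decrypted):
    compressed = zlib.compress(decrypted)  # borg compresses with LZ4 here
    return CACHE_STRUCT.pack(csize, digest8(compressed)) + compressed

def unpack(entry):
    csize, checksum = CACHE_STRUCT.unpack(entry[:CACHE_STRUCT.size])
    compressed = entry[CACHE_STRUCT.size:]
    if checksum != digest8(compressed):
        raise ValueError('corrupted cache entry')  # borg raises InvalidateCacheEntry
    return csize, zlib.decompress(compressed)

entry = pack(7, b'1234')
assert unpack(entry) == (7, b'1234')
corrupted = entry[:-1] + bytes([entry[-1] ^ 2])  # flip one payload bit, as the new test does
try:
    unpack(corrupted)
except ValueError:
    print('corruption detected before decompression')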

src/borg/remote.py

@@ -30,6 +30,7 @@
 from .logger import create_logger, setup_logging
 from .repository import Repository, MAX_OBJECT_SIZE, LIST_SCAN_LIMIT
 from .version import parse_version, format_version
+from .algorithms.checksums import xxh64

 logger = create_logger(__name__)
@@ -1086,6 +1087,9 @@ class RepositoryCache(RepositoryNoCache):
     should return the initial data (as returned by *transform*).
     """

+    class InvalidateCacheEntry(Exception):
+        pass
+
     def __init__(self, repository, pack=None, unpack=None, transform=None):
         super().__init__(repository, transform)
         self.pack = pack or (lambda data: data)
@@ -1100,6 +1104,7 @@ def __init__(self, repository, pack=None, unpack=None, transform=None):
         self.slow_misses = 0
         self.slow_lat = 0.0
         self.evictions = 0
+        self.checksum_errors = 0
         self.enospc = 0

     def query_size_limit(self):
@@ -1144,10 +1149,10 @@ def add_entry(self, key, data, cache):
     def close(self):
         logger.debug('RepositoryCache: current items %d, size %s / %s, %d hits, %d misses, %d slow misses (+%.1fs), '
-                     '%d evictions, %d ENOSPC hit',
+                     '%d evictions, %d ENOSPC hit, %d checksum errors',
                      len(self.cache), format_file_size(self.size), format_file_size(self.size_limit),
                      self.hits, self.misses, self.slow_misses, self.slow_lat,
-                     self.evictions, self.enospc)
+                     self.evictions, self.enospc, self.checksum_errors)
         self.cache.clear()
         shutil.rmtree(self.basedir)
@@ -1157,30 +1162,37 @@ def get_many(self, keys, cache=True):
         for key in keys:
             if key in self.cache:
                 file = self.key_filename(key)
-                with open(file, 'rb') as fd:
-                    self.hits += 1
-                    yield self.unpack(fd.read())
-            else:
-                for key_, data in repository_iterator:
-                    if key_ == key:
-                        transformed = self.add_entry(key, data, cache)
-                        self.misses += 1
-                        yield transformed
-                        break
-                else:
-                    # slow path: eviction during this get_many removed this key from the cache
-                    t0 = time.perf_counter()
-                    data = self.repository.get(key)
-                    self.slow_lat += time.perf_counter() - t0
-                    transformed = self.add_entry(key, data, cache)
-                    self.slow_misses += 1
-                    yield transformed
+                try:
+                    with open(file, 'rb') as fd:
+                        self.hits += 1
+                        yield self.unpack(fd.read())
+                    continue  # go to the next key
+                except self.InvalidateCacheEntry:
+                    self.cache.remove(key)
+                    self.size -= os.stat(file).st_size
+                    self.checksum_errors += 1
+                    os.unlink(file)
+                    # fall through to fetch the object again
+            for key_, data in repository_iterator:
+                if key_ == key:
+                    transformed = self.add_entry(key, data, cache)
+                    self.misses += 1
+                    yield transformed
+                    break
+            else:
+                # slow path: eviction during this get_many removed this key from the cache
+                t0 = time.perf_counter()
+                data = self.repository.get(key)
+                self.slow_lat += time.perf_counter() - t0
+                transformed = self.add_entry(key, data, cache)
+                self.slow_misses += 1
+                yield transformed
         # Consume any pending requests
         for _ in repository_iterator:
             pass


-def cache_if_remote(repository, *, decrypted_cache=False, pack=None, unpack=None, transform=None):
+def cache_if_remote(repository, *, decrypted_cache=False, pack=None, unpack=None, transform=None, force_cache=False):
     """
     Return a Repository(No)Cache for *repository*.
@@ -1194,21 +1206,30 @@ def cache_if_remote(repository, *, decrypted_cache=False, pack=None, unpack=None
         raise ValueError('decrypted_cache and pack/unpack/transform are incompatible')
     elif decrypted_cache:
         key = decrypted_cache
-        cache_struct = struct.Struct('=I')
+        # 32 bit csize, 64 bit (8 byte) xxh64
+        cache_struct = struct.Struct('=I8s')
         compressor = LZ4()

         def pack(data):
-            return cache_struct.pack(data[0]) + compressor.compress(data[1])
+            csize, decrypted = data
+            compressed = compressor.compress(decrypted)
+            return cache_struct.pack(csize, xxh64(compressed)) + compressed

         def unpack(data):
-            return cache_struct.unpack(data[:cache_struct.size])[0], compressor.decompress(data[cache_struct.size:])
+            data = memoryview(data)
+            csize, checksum = cache_struct.unpack(data[:cache_struct.size])
+            compressed = data[cache_struct.size:]
+            if checksum != xxh64(compressed):
+                logger.warning('Repository metadata cache: detected corrupted data in cache!')
+                raise RepositoryCache.InvalidateCacheEntry
+            return csize, compressor.decompress(compressed)

         def transform(id_, data):
             csize = len(data)
             decrypted = key.decrypt(id_, data)
             return csize, decrypted

-    if isinstance(repository, RemoteRepository):
+    if isinstance(repository, RemoteRepository) or force_cache:
         return RepositoryCache(repository, pack, unpack, transform)
     else:
         return RepositoryNoCache(repository, transform)
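
Two details of the new get_many() are easy to miss: the hit path now sits in a try block whose handler deliberately does not re-raise but falls through to the miss path, so a corrupted entry is healed transparently; and since unpack() runs inside the yield expression, the checksum failure surfaces exactly at the cache read. A minimal sketch of the same control flow, with illustrative names only (cache maps keys to packed entries, unpack may raise, fetch is the authoritative source):

class InvalidateCacheEntry(Exception):
    pass

def get_many(keys, cache, unpack, fetch):
    for key in keys:
        if key in cache:
            try:
                yield unpack(cache[key])
                continue  # served from cache, go to the next key
            except InvalidateCacheEntry:
                del cache[key]  # evict the bad entry, fall through to refetch
        yield fetch(key)  # miss (or invalidated hit): fetch from the source

def unpack(entry):
    if not entry.startswith(b'ok:'):
        raise InvalidateCacheEntry
    return entry[3:]

cache = {'a': b'ok:A', 'b': b'corrupt'}
print(list(get_many(['a', 'b'], cache, unpack, lambda key: key.upper().encode())))
# [b'A', b'B'] -- 'b' failed verification, was evicted and refetched

In borg itself the refetch first tries the already-pending repository iterator and only then falls back to the slow path via repository.get(), as the diff above shows.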

src/borg/testsuite/remote.py

@@ -1,13 +1,17 @@
 import errno
 import os
+import io
 import time
 from unittest.mock import patch

 import pytest

-from ..remote import SleepingBandwidthLimiter, RepositoryCache
+from ..remote import SleepingBandwidthLimiter, RepositoryCache, cache_if_remote
 from ..repository import Repository
+from ..crypto.key import PlaintextKey
+from ..compress import CompressionSpec
 from .hashindex import H
+from .key import TestKey


 class TestSleepingBandwidthLimiter:
@@ -147,3 +151,51 @@ def write(self, data):
         assert cache.evictions == 0
         assert next(iterator) == bytes(100)
+
+    @pytest.fixture
+    def key(self, repository, monkeypatch):
+        monkeypatch.setenv('BORG_PASSPHRASE', 'test')
+        key = PlaintextKey.create(repository, TestKey.MockArgs())
+        key.compressor = CompressionSpec('none').compressor
+        return key
+
+    def _put_encrypted_object(self, key, repository, data):
+        id_ = key.id_hash(data)
+        repository.put(id_, key.encrypt(data))
+        return id_
+
+    @pytest.fixture
+    def H1(self, key, repository):
+        return self._put_encrypted_object(key, repository, b'1234')
+
+    @pytest.fixture
+    def H2(self, key, repository):
+        return self._put_encrypted_object(key, repository, b'5678')
+
+    @pytest.fixture
+    def H3(self, key, repository):
+        return self._put_encrypted_object(key, repository, bytes(100))
+
+    @pytest.fixture
+    def decrypted_cache(self, key, repository):
+        return cache_if_remote(repository, decrypted_cache=key, force_cache=True)
+
+    def test_cache_corruption(self, decrypted_cache: RepositoryCache, H1, H2, H3):
+        list(decrypted_cache.get_many([H1, H2, H3]))
+
+        iterator = decrypted_cache.get_many([H1, H2, H3])
+        assert next(iterator) == (7, b'1234')
+
+        with open(decrypted_cache.key_filename(H2), 'a+b') as fd:
+            fd.seek(-1, io.SEEK_END)
+            corrupted = (int.from_bytes(fd.read(), 'little') ^ 2).to_bytes(1, 'little')
+            fd.seek(-1, io.SEEK_END)
+            fd.write(corrupted)
+            fd.truncate()
+
+        assert next(iterator) == (7, b'5678')
+        assert decrypted_cache.checksum_errors == 1
+        assert decrypted_cache.slow_misses == 1
+        assert next(iterator) == (103, bytes(100))
+        assert decrypted_cache.hits == 3
+        assert decrypted_cache.misses == 3
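
Notes on the asserted counters: hits still reaches 3 because self.hits is incremented before unpack() runs, so the corrupted read of H2 counts both as a hit and as a checksum error. The H2 refetch registers as a slow miss rather than a regular miss because all three objects were already cached when the second get_many() started, so the repository iterator has no pending entry for H2 and the fallback goes through repository.get(). The csize values (7 for the four-byte payloads, 103 for bytes(100)) include three bytes of per-object overhead: the key-type byte and the two-byte compression header added by encrypt().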