diff --git a/attic/archive.py b/attic/archive.py index 3d660d207..50bc674dc 100644 --- a/attic/archive.py +++ b/attic/archive.py @@ -5,6 +5,7 @@ import errno import shutil import tempfile from attic.key import key_factory +from attic.remote import RemoteRepository, RepositoryCache import msgpack import os import socket @@ -606,11 +607,16 @@ class ArchiveChecker: continue if state > 0: unpacker.resync() - for chunk_id, cdata in zip(items, self.repository.get_many(items)): + for chunk_id, cdata in zip(items, repository.get_many(items)): unpacker.feed(self.key.decrypt(chunk_id, cdata)) for item in unpacker: yield item + if isinstance(self.repository, RemoteRepository): + repository = RepositoryCache(self.repository) + else: + repository = self.repository + num_archives = len(self.manifest.archives) for i, (name, info) in enumerate(list(self.manifest.archives.items()), 1): self.report_progress('Analyzing archive {} ({}/{})'.format(name, i, num_archives)) diff --git a/attic/cache.py b/attic/cache.py index 8c4a4caf3..d029394f8 100644 --- a/attic/cache.py +++ b/attic/cache.py @@ -1,5 +1,5 @@ from configparser import RawConfigParser -from itertools import zip_longest +from attic.remote import RemoteRepository, RepositoryCache import msgpack import os from binascii import hexlify @@ -146,24 +146,28 @@ class Cache(object): print('Initializing cache...') self.chunks.clear() unpacker = msgpack.Unpacker() + if isinstance(self.repository, RemoteRepository): + repository = RepositoryCache(self.repository) + else: + repository = self.repository for name, info in self.manifest.archives.items(): - id = info[b'id'] - cdata = self.repository.get(id) - data = self.key.decrypt(id, cdata) - add(id, len(data), len(cdata)) + archive_id = info[b'id'] + cdata = repository.get(archive_id) + data = self.key.decrypt(archive_id, cdata) + add(archive_id, len(data), len(cdata)) archive = msgpack.unpackb(data) if archive[b'version'] != 1: raise Exception('Unknown archive metadata version') - 
decode_dict(archive, (b'name', b'hostname', b'username', b'time')) # fixme: argv + decode_dict(archive, (b'name',)) print('Analyzing archive:', archive[b'name']) - for id_, chunk in zip_longest(archive[b'items'], self.repository.get_many(archive[b'items'])): - data = self.key.decrypt(id_, chunk) - add(id_, len(data), len(chunk)) + for key, chunk in zip(archive[b'items'], repository.get_many(archive[b'items'])): + data = self.key.decrypt(key, chunk) + add(key, len(data), len(chunk)) unpacker.feed(data) for item in unpacker: if b'chunks' in item: - for id_, size, csize in item[b'chunks']: - add(id_, size, csize) + for chunk_id, size, csize in item[b'chunks']: + add(chunk_id, size, csize) def add_chunk(self, id, data, stats): if not self.txn_active: diff --git a/attic/remote.py b/attic/remote.py index 5017ef4de..b6bfd99c0 100644 --- a/attic/remote.py +++ b/attic/remote.py @@ -2,9 +2,12 @@ import fcntl import msgpack import os import select +import shutil from subprocess import Popen, PIPE import sys +import tempfile +from .hashindex import NSIndex from .helpers import Error, IntegrityError from .repository import Repository @@ -220,3 +223,63 @@ class RemoteRepository(object): def preload(self, ids): self.preload_ids += ids + + +class RepositoryCache: + """A caching Repository wrapper + + Caches Repository GET operations using a temporary file + """ + def __init__(self, repository): + self.tmppath = None + self.index = None + self.data_fd = None + self.repository = repository + self.entries = {} + self.initialize() + + def __del__(self): + self.cleanup() + + def initialize(self): + self.tmppath = tempfile.mkdtemp() + self.index = NSIndex.create(os.path.join(self.tmppath, 'index')) + self.data_fd = open(os.path.join(self.tmppath, 'data'), 'a+b') + + def cleanup(self): + del self.index + if self.data_fd: + self.data_fd.close() + if self.tmppath: + shutil.rmtree(self.tmppath) + + def load_object(self, offset, size): + self.data_fd.seek(offset) + data = 
self.data_fd.read(size) + assert len(data) == size + return data + + def store_object(self, key, data): + self.data_fd.seek(0, os.SEEK_END) + self.data_fd.write(data) + offset = self.data_fd.tell() + self.index[key] = offset - len(data), len(data) + + def get(self, key): + return next(self.get_many([key])) + + def get_many(self, keys): + unknown_keys = [key for key in keys if key not in self.index] + repository_iterator = zip(unknown_keys, self.repository.get_many(unknown_keys)) + for key in keys: + try: + yield self.load_object(*self.index[key]) + except KeyError: + for key_, data in repository_iterator: + if key_ == key: + self.store_object(key, data) + yield data + break + # Consume any pending requests + for _ in repository_iterator: + pass