mirror of https://github.com/borgbackup/borg.git
278 lines
9.6 KiB
Python
278 lines
9.6 KiB
Python
import fcntl
|
|
import msgpack
|
|
import os
|
|
import select
|
|
from subprocess import Popen, PIPE
|
|
import sys
|
|
|
|
from .helpers import Error
|
|
from .repository import Repository
|
|
from .lrucache import LRUCache
|
|
|
|
BUFSIZE = 10 * 1024 * 1024
|
|
|
|
|
|
class ConnectionClosed(Error):
|
|
"""Connection closed by remote host"""
|
|
|
|
|
|
class RepositoryServer(object):
|
|
|
|
def __init__(self):
|
|
self.repository = None
|
|
|
|
def serve(self):
|
|
# Make stdin non-blocking
|
|
fl = fcntl.fcntl(sys.stdin.fileno(), fcntl.F_GETFL)
|
|
fcntl.fcntl(sys.stdin.fileno(), fcntl.F_SETFL, fl | os.O_NONBLOCK)
|
|
# Make stdout blocking
|
|
fl = fcntl.fcntl(sys.stdout.fileno(), fcntl.F_GETFL)
|
|
fcntl.fcntl(sys.stdout.fileno(), fcntl.F_SETFL, fl & ~os.O_NONBLOCK)
|
|
unpacker = msgpack.Unpacker(use_list=False)
|
|
while True:
|
|
r, w, es = select.select([sys.stdin], [], [], 10)
|
|
if r:
|
|
data = os.read(sys.stdin.fileno(), BUFSIZE)
|
|
if not data:
|
|
return
|
|
unpacker.feed(data)
|
|
for type, msgid, method, args in unpacker:
|
|
method = method.decode('ascii')
|
|
try:
|
|
try:
|
|
f = getattr(self, method)
|
|
except AttributeError:
|
|
f = getattr(self.repository, method)
|
|
res = f(*args)
|
|
except Exception as e:
|
|
sys.stdout.buffer.write(msgpack.packb((1, msgid, e.__class__.__name__, None)))
|
|
else:
|
|
sys.stdout.buffer.write(msgpack.packb((1, msgid, None, res)))
|
|
sys.stdout.flush()
|
|
if es:
|
|
return
|
|
|
|
def negotiate(self, versions):
|
|
return 1
|
|
|
|
def open(self, path, create=False):
|
|
path = os.fsdecode(path)
|
|
if path.startswith('/~'):
|
|
path = path[1:]
|
|
self.repository = Repository(os.path.expanduser(path), create)
|
|
return self.repository.id
|
|
|
|
|
|
class RemoteRepository(object):
|
|
|
|
class RPCError(Exception):
|
|
|
|
def __init__(self, name):
|
|
self.name = name
|
|
|
|
def __init__(self, location, create=False):
|
|
self.repository_url = '%s@%s:%s' % (location.user, location.host, location.path)
|
|
self.p = None
|
|
self.cache = LRUCache(256)
|
|
self.to_send = b''
|
|
self.extra = {}
|
|
self.pending = {}
|
|
self.unpacker = msgpack.Unpacker(use_list=False)
|
|
self.msgid = 0
|
|
self.received_msgid = 0
|
|
if location.host == '__testsuite__':
|
|
args = [sys.executable, '-m', 'attic.archiver', 'serve']
|
|
else:
|
|
args = ['ssh',]
|
|
if location.port:
|
|
args += ['-p', str(location.port)]
|
|
if location.user:
|
|
args.append('%s@%s' % (location.user, location.host))
|
|
else:
|
|
args.append('%s' % location.host)
|
|
args += ['attic', 'serve']
|
|
self.p = Popen(args, bufsize=0, stdin=PIPE, stdout=PIPE)
|
|
self.stdin_fd = self.p.stdin.fileno()
|
|
self.stdout_fd = self.p.stdout.fileno()
|
|
fcntl.fcntl(self.stdin_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdin_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
|
|
fcntl.fcntl(self.stdout_fd, fcntl.F_SETFL, fcntl.fcntl(self.stdout_fd, fcntl.F_GETFL) | os.O_NONBLOCK)
|
|
self.r_fds = [self.stdout_fd]
|
|
self.x_fds = [self.stdin_fd, self.stdout_fd]
|
|
|
|
version = self.call('negotiate', (1,))
|
|
if version != 1:
|
|
raise Exception('Server insisted on using unsupported protocol version %d' % version)
|
|
try:
|
|
self.id = self.call('open', (location.path, create))
|
|
except self.RPCError as e:
|
|
if e.name == b'DoesNotExist':
|
|
raise Repository.DoesNotExist(self.repository_url)
|
|
elif e.name == b'AlreadyExists':
|
|
raise Repository.AlreadyExists(self.repository_url)
|
|
|
|
def __del__(self):
|
|
self.close()
|
|
|
|
def call(self, cmd, args, wait=True):
|
|
self.msgid += 1
|
|
to_send = msgpack.packb((1, self.msgid, cmd, args))
|
|
w_fds = [self.stdin_fd]
|
|
while wait or to_send:
|
|
r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
|
|
if x:
|
|
raise Exception('FD exception occured')
|
|
if r:
|
|
data = os.read(self.stdout_fd, BUFSIZE)
|
|
if not data:
|
|
raise ConnectionClosed()
|
|
self.unpacker.feed(data)
|
|
for type, msgid, error, res in self.unpacker:
|
|
if msgid == self.msgid:
|
|
self.received_msgid = msgid
|
|
if error:
|
|
raise self.RPCError(error)
|
|
else:
|
|
return res
|
|
else:
|
|
args = self.pending.pop(msgid, None)
|
|
if args is not None:
|
|
self.cache[args] = msgid, res, error
|
|
if w:
|
|
if to_send:
|
|
n = os.write(self.stdin_fd, to_send)
|
|
assert n > 0
|
|
to_send = memoryview(to_send)[n:]
|
|
if not to_send:
|
|
w_fds = []
|
|
|
|
def _read(self):
|
|
data = os.read(self.stdout_fd, BUFSIZE)
|
|
if not data:
|
|
raise Exception('Remote host closed connection')
|
|
self.unpacker.feed(data)
|
|
to_yield = []
|
|
for type, msgid, error, res in self.unpacker:
|
|
self.received_msgid = msgid
|
|
args = self.pending.pop(msgid, None)
|
|
if args is not None:
|
|
self.cache[args] = msgid, res, error
|
|
for args, resp, error in self.extra.pop(msgid, []):
|
|
if not resp and not error:
|
|
resp, error = self.cache[args][1:]
|
|
to_yield.append((resp, error))
|
|
for res, error in to_yield:
|
|
if error:
|
|
raise self.RPCError(error)
|
|
else:
|
|
yield res
|
|
|
|
def gen_request(self, cmd, argsv, wait):
|
|
data = []
|
|
m = self.received_msgid
|
|
for args in argsv:
|
|
# Make sure to invalidate any existing cache entries for non-get requests
|
|
if not args in self.cache:
|
|
self.msgid += 1
|
|
msgid = self.msgid
|
|
self.pending[msgid] = args
|
|
self.cache[args] = msgid, None, None
|
|
data.append(msgpack.packb((1, msgid, cmd, args)))
|
|
if wait:
|
|
msgid, resp, error = self.cache[args]
|
|
m = max(m, msgid)
|
|
self.extra.setdefault(m, []).append((args, resp, error))
|
|
return b''.join(data)
|
|
|
|
def gen_cache_requests(self, cmd, peek):
|
|
data = []
|
|
while True:
|
|
try:
|
|
args = (peek()[0],)
|
|
except StopIteration:
|
|
break
|
|
if args in self.cache:
|
|
continue
|
|
self.msgid += 1
|
|
msgid = self.msgid
|
|
self.pending[msgid] = args
|
|
self.cache[args] = msgid, None, None
|
|
data.append(msgpack.packb((1, msgid, cmd, args)))
|
|
return b''.join(data)
|
|
|
|
def call_multi(self, cmd, argsv, wait=True, peek=None):
|
|
w_fds = [self.stdin_fd]
|
|
left = len(argsv)
|
|
data = self.gen_request(cmd, argsv, wait)
|
|
self.to_send += data
|
|
for args, resp, error in self.extra.pop(self.received_msgid, []):
|
|
left -= 1
|
|
if not resp and not error:
|
|
resp, error = self.cache[args][1:]
|
|
if error:
|
|
raise self.RPCError(error)
|
|
else:
|
|
yield resp
|
|
while left:
|
|
r, w, x = select.select(self.r_fds, w_fds, self.x_fds, 1)
|
|
if x:
|
|
raise Exception('FD exception occured')
|
|
if r:
|
|
for res in self._read():
|
|
left -= 1
|
|
yield res
|
|
if w:
|
|
if not self.to_send and peek:
|
|
self.to_send = self.gen_cache_requests(cmd, peek)
|
|
if self.to_send:
|
|
n = os.write(self.stdin_fd, self.to_send)
|
|
assert n > 0
|
|
# self.to_send = memoryview(self.to_send)[n:]
|
|
self.to_send = self.to_send[n:]
|
|
else:
|
|
w_fds = []
|
|
if not wait:
|
|
return
|
|
|
|
def commit(self, *args):
|
|
self.call('commit', args)
|
|
|
|
def rollback(self, *args):
|
|
self.cache.clear()
|
|
self.pending.clear()
|
|
self.extra.clear()
|
|
return self.call('rollback', args)
|
|
|
|
def get(self, id):
|
|
try:
|
|
for res in self.call_multi('get', [(id, )]):
|
|
return res
|
|
except self.RPCError as e:
|
|
if e.name == b'DoesNotExist':
|
|
raise Repository.DoesNotExist(self.repository_url)
|
|
raise
|
|
|
|
def get_many(self, ids, peek=None):
|
|
return self.call_multi('get', [(id, ) for id in ids], peek=peek)
|
|
|
|
def _invalidate(self, id):
|
|
key = (id, )
|
|
if key in self.cache:
|
|
self.pending.pop(self.cache.pop(key)[0], None)
|
|
|
|
def put(self, id, data, wait=True):
|
|
resp = self.call('put', (id, data), wait=wait)
|
|
self._invalidate(id)
|
|
return resp
|
|
|
|
def delete(self, id, wait=True):
|
|
resp = self.call('delete', (id, ), wait=wait)
|
|
self._invalidate(id)
|
|
return resp
|
|
|
|
def close(self):
|
|
if self.p:
|
|
self.p.stdin.close()
|
|
self.p.stdout.close()
|
|
self.p.wait()
|
|
self.p = None
|