From ba553ec628b84cfa2f2f92898627b711244ea1d6 Mon Sep 17 00:00:00 2001 From: Martin Hostettler Date: Thu, 10 Nov 2016 09:56:18 +0100 Subject: [PATCH] remote: Introduce rpc protocol with named parameters. --- src/borg/remote.py | 286 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 236 insertions(+), 50 deletions(-) diff --git a/src/borg/remote.py b/src/borg/remote.py index b071a07b2..b4e1b3d2d 100644 --- a/src/borg/remote.py +++ b/src/borg/remote.py @@ -1,5 +1,7 @@ import errno import fcntl +import functools +import inspect import logging import os import select @@ -20,8 +22,11 @@ from .helpers import bin_to_hex from .helpers import replace_placeholders from .helpers import yes from .repository import Repository +from .version import parse_version, format_version RPC_PROTOCOL_VERSION = 2 +BORG_VERSION = parse_version(__version__) +MSGID, MSG, ARGS, RESULT = b'i', b'm', b'a', b'r' BUFSIZE = 10 * 1024 * 1024 @@ -54,6 +59,51 @@ class UnexpectedRPCDataFormatFromServer(Error): """Got unexpected RPC data format from server.""" +# Protocol compatibility: +# In general the server is responsible for rejecting too old clients and the client it responsible for rejecting +# too old servers. This ensures that the knowledge what is compatible is always held by the newer component. +# +# The server can do checks for the client version in RepositoryServer.negotiate. If the client_data is 2 then +# client is in the version range [0.29.0, 1.0.x] inclusive. For newer clients client_data is a dict which contains +# client_version. +# +# For the client the return of the negotiate method is either 2 if the server is in the version range [0.29.0, 1.0.x] +# inclusive, or it is a dict which includes the server version. +# +# All method calls on the remote repository object must be whitelisted in RepositoryServer.rpc_methods and have api +# stubs in RemoteRepository. The @api decorator on these stubs is used to set server version requirements. +# +# Method parameters are identified only by name and never by position. Unknown parameters are ignored by the server side. +# If a new parameter is important and may not be ignored, on the client a parameter specific version requirement needs +# to be added. +# When parameters are removed, they need to be preserved as defaulted parameters on the client stubs so that older +# servers still get compatible input. + + +compatMap = { + 'check': ('repair', 'save_space', ), + 'commit': ('save_space', ), + 'rollback': (), + 'destroy': (), + '__len__': (), + 'list': ('limit', 'marker', ), + 'put': ('id', 'data', ), + 'get': ('id_', ), + 'delete': ('id', ), + 'save_key': ('keydata', ), + 'load_key': (), + 'break_lock': (), + 'negotiate': ('client_data', ), + 'open': ('path', 'create', 'lock_wait', 'lock', 'exclusive', 'append_only', ), + 'get_free_nonce': (), + 'commit_nonce_reservation': ('next_unreserved', 'start_nonce', ), +} + + +def decode_keys(d): + return {k.decode(): d[k] for k in d} + + class RepositoryServer: # pragma: no cover rpc_methods = ( '__len__', @@ -79,6 +129,16 @@ class RepositoryServer: # pragma: no cover self.repository = None self.restrict_to_paths = restrict_to_paths self.append_only = append_only + self.client_version = parse_version('1.0.8') # fallback version if client is too old to send version information + + def positional_to_named(self, method, argv): + """Translate from positional protocol to named protocol.""" + return {name: argv[pos] for pos, name in enumerate(compatMap[method])} + + def filter_args(self, f, kwargs): + """Remove unknown named parameters from call, because client did (implicitly) say it's ok.""" + known = set(inspect.signature(f).parameters) + return {name: kwargs[name] for name in kwargs if name in known} def serve(self): stdin_fd = sys.stdin.fileno() @@ -107,12 +167,20 @@ class RepositoryServer: # pragma: no cover return unpacker.feed(data) for unpacked in unpacker: - if not (isinstance(unpacked, tuple) and len(unpacked) == 4): + if isinstance(unpacked, dict): + dictFormat = True + msgid = unpacked[MSGID] + method = unpacked[MSG].decode() + args = decode_keys(unpacked[ARGS]) + elif isinstance(unpacked, tuple) and len(unpacked) == 4: + dictFormat = False + type, msgid, method, args = unpacked + method = method.decode() + args = self.positional_to_named(method, args) + else: if self.repository is not None: self.repository.close() raise UnexpectedRPCDataFormatFromClient(__version__) - type, msgid, method, args = unpacked - method = method.decode() try: if method not in self.rpc_methods: raise InvalidRPCMethod(method) @@ -120,7 +188,8 @@ class RepositoryServer: # pragma: no cover f = getattr(self, method) except AttributeError: f = getattr(self.repository, method) - res = f(*args) + args = self.filter_args(f, args) + res = f(**args) except BaseException as e: if isinstance(e, (Repository.DoesNotExist, Repository.AlreadyExists, PathNotAllowed)): # These exceptions are reconstructed on the client end in RemoteRepository.call_many(), @@ -138,18 +207,35 @@ class RepositoryServer: # pragma: no cover logging.error(msg) logging.log(tb_log_level, tb) exc = 'Remote Exception (see remote log for the traceback)' - os.write(stdout_fd, msgpack.packb((1, msgid, e.__class__.__name__, exc))) + if dictFormat: + os.write(stdout_fd, msgpack.packb({MSGID: msgid, b'exception_class': e.__class__.__name__})) + else: + os.write(stdout_fd, msgpack.packb((1, msgid, e.__class__.__name__, exc))) else: - os.write(stdout_fd, msgpack.packb((1, msgid, None, res))) + if dictFormat: + os.write(stdout_fd, msgpack.packb({MSGID: msgid, RESULT: res})) + else: + os.write(stdout_fd, msgpack.packb((1, msgid, None, res))) if es: self.repository.close() return - def negotiate(self, versions): - return RPC_PROTOCOL_VERSION + def negotiate(self, client_data): + # old format used in 1.0.x + if client_data == RPC_PROTOCOL_VERSION: + return RPC_PROTOCOL_VERSION + # clients since 1.1.0b3 use a dict as client_data + if isinstance(client_data, dict): + self.client_version = client_data[b'client_version'] + else: + self.client_version = BORG_VERSION # seems to be newer than current version (no known old format) + + # not a known old format, send newest negotiate this version knows + return {'server_version': BORG_VERSION} def open(self, path, create=False, lock_wait=None, lock=True, exclusive=None, append_only=False): - path = os.fsdecode(path) + if isinstance(path, bytes): + path = os.fsdecode(path) if path.startswith('/~'): # /~/x = path x relative to home dir, /~username/x = relative to "user" home dir path = os.path.join(get_home_dir(), path[2:]) # XXX check this (see also 1.0-maint), is it correct for ~u? elif path.startswith('/./'): # /./x = path x relative to cwd @@ -203,6 +289,54 @@ class SleepingBandwidthLimiter: return written +def api(*, since, **kwargs_decorator): + """Check version requirements and use self.call to do the remote method call. + + specifies the version in which borg introduced this method, + calling this method when connected to an older version will fail without transmiting + anything to the server. + + Further kwargs can be used to encode version specific restrictions. + If a previous hardcoded behaviour is parameterized in a version, this allows calls that + use the previously hardcoded behaviour to pass through and generates an error if another + behaviour is requested by the client. + + e.g. when 'append_only' was introduced in 1.0.7 the previous behaviour was what now is append_only=False. + Thus @api(..., append_only={'since': parse_version('1.0.7'), 'previously': False}) allows calls + with append_only=False for all version but rejects calls using append_only=True on versions older than 1.0.7. + """ + def decorator(f): + @functools.wraps(f) + def do_rpc(self, *args, **kwargs): + sig = inspect.signature(f) + bound_args = sig.bind(self, *args, **kwargs) + named = {} + for name, param in sig.parameters.items(): + if name == 'self': + continue + if name in bound_args.arguments: + named[name] = bound_args.arguments[name] + else: + if param.default is not param.empty: + named[name] = param.default + + if self.server_version < since: + raise self.RPCServerOutdated(f.__name__, format_version(since)) + + for name, restriction in kwargs_decorator.items(): + if restriction['since'] <= self.server_version: + continue + if 'previously' in restriction and named[name] == restriction['previously']: + continue + + raise self.RPCServerOutdated("{0} {1}={2!s}".format(f.__name__, name, named[name]), + format_version(restriction['since'])) + + return self.call(f.__name__, named) + return do_rpc + return decorator + + class RemoteRepository: extra_test_args = [] @@ -214,6 +348,17 @@ class RemoteRepository: class NoAppendOnlyOnServer(Error): """Server does not support --append-only.""" + class RPCServerOutdated(Error): + """Borg server is too old for {}. Required version {}""" + + @property + def method(self): + return self.args[0] + + @property + def required_version(self): + return self.args[1] + def __init__(self, location, create=False, exclusive=False, lock_wait=None, lock=True, append_only=False, args=None): self.location = self._location = location self.preload_ids = [] @@ -225,6 +370,8 @@ class RemoteRepository: self.ratelimit = SleepingBandwidthLimiter(args.remote_ratelimit * 1024 if args and args.remote_ratelimit else 0) self.unpacker = msgpack.Unpacker(use_list=False) + self.dictFormat = False + self.server_version = parse_version('1.0.8') # fallback version if server is too old to send version information self.p = None testing = location.host == '__testsuite__' borg_cmd = self.borg_cmd(args, testing) @@ -254,15 +401,22 @@ class RemoteRepository: try: try: - version = self.call('negotiate', RPC_PROTOCOL_VERSION) + version = self.call('negotiate', {'client_data': {b'client_version': BORG_VERSION}}) except ConnectionClosed: raise ConnectionClosedWithHint('Is borg working on the server?') from None - if version != RPC_PROTOCOL_VERSION: - raise Exception('Server insisted on using unsupported protocol version %d' % version) + if version == RPC_PROTOCOL_VERSION: + self.dictFormat = False + elif isinstance(version, dict) and b'server_version' in version: + self.dictFormat = True + self.server_version = version[b'server_version'] + else: + raise Exception('Server insisted on using unsupported protocol version %s' % version) + try: - self.id = self.call('open', self.location.path, create, lock_wait, lock, exclusive, append_only) + self.id = self.call('open', {'path': self.location.path, 'create': create, 'lock_wait': lock_wait, + 'lock': lock, 'exclusive': exclusive, 'append_only': append_only}) except self.RPCError as err: - if err.remote_type != 'TypeError': + if self.dictFormat or err.remote_type != 'TypeError': raise msg = """\ Please note: @@ -276,7 +430,9 @@ This problem will go away as soon as the server has been upgraded to 1.0.7+. sys.stderr.write(msg) if append_only: raise self.NoAppendOnlyOnServer() - self.id = self.call('open', self.location.path, create, lock_wait, lock) + compatMap['open'] = ('path', 'create', 'lock_wait', 'lock', ) + self.id = self.call('open', {'path': self.location.path, 'create': create, 'lock_wait': lock_wait, + 'lock': lock, 'exclusive': exclusive, 'append_only': append_only}) except Exception: self.close() raise @@ -348,7 +504,10 @@ This problem will go away as soon as the server has been upgraded to 1.0.7+. args.append('%s' % location.host) return args - def call(self, cmd, *args, **kw): + def named_to_positional(self, method, kwargs): + return [kwargs[name] for name in compatMap[method]] + + def call(self, cmd, args, **kw): for resp in self.call_many(cmd, [args], **kw): return resp @@ -386,12 +545,12 @@ This problem will go away as soon as the server has been upgraded to 1.0.7+. while wait or calls: while waiting_for: try: - error, res = self.responses.pop(waiting_for[0]) + unpacked = self.responses.pop(waiting_for[0]) waiting_for.pop(0) - if error: - handle_error(error, res) + if b'exception_class' in unpacked: + handle_error(unpacked[b'exception_class'], None) else: - yield res + yield unpacked[RESULT] if not waiting_for and not calls: return except KeyError: @@ -410,15 +569,22 @@ This problem will go away as soon as the server has been upgraded to 1.0.7+. raise ConnectionClosed() self.unpacker.feed(data) for unpacked in self.unpacker: - if not (isinstance(unpacked, tuple) and len(unpacked) == 4): + if isinstance(unpacked, dict): + msgid = unpacked[MSGID] + elif isinstance(unpacked, tuple) and len(unpacked) == 4: + type, msgid, error, res = unpacked + if error: + unpacked = {MSGID: msgid, b'exception_class': error} + else: + unpacked = {MSGID: msgid, RESULT: res} + else: raise UnexpectedRPCDataFormatFromServer() - type, msgid, error, res = unpacked if msgid in self.ignore_responses: self.ignore_responses.remove(msgid) - if error: - handle_error(error, res) + if b'exception_class' in unpacked: + handle_error(unpacked[b'exception_class'], None) else: - self.responses[msgid] = error, res + self.responses[msgid] = unpacked elif fd is self.stderr_fd: data = os.read(fd, 32768) if not data: @@ -431,22 +597,28 @@ This problem will go away as soon as the server has been upgraded to 1.0.7+. if calls: if is_preloaded: assert cmd == 'get', "is_preload is only supported for 'get'" - if calls[0][0] in self.chunkid_to_msgids: - waiting_for.append(pop_preload_msgid(calls.pop(0)[0])) + if calls[0]['id_'] in self.chunkid_to_msgids: + waiting_for.append(pop_preload_msgid(calls.pop(0)['id_'])) else: args = calls.pop(0) - if cmd == 'get' and args[0] in self.chunkid_to_msgids: - waiting_for.append(pop_preload_msgid(args[0])) + if cmd == 'get' and args['id_'] in self.chunkid_to_msgids: + waiting_for.append(pop_preload_msgid(args['id_'])) else: self.msgid += 1 waiting_for.append(self.msgid) - self.to_send = msgpack.packb((1, self.msgid, cmd, args)) + if self.dictFormat: + self.to_send = msgpack.packb({MSGID: self.msgid, MSG: cmd, ARGS: args}) + else: + self.to_send = msgpack.packb((1, self.msgid, cmd, self.named_to_positional(cmd, args))) if not self.to_send and self.preload_ids: chunk_id = self.preload_ids.pop(0) - args = (chunk_id,) + args = {'id_': chunk_id} self.msgid += 1 self.chunkid_to_msgids.setdefault(chunk_id, []).append(self.msgid) - self.to_send = msgpack.packb((1, self.msgid, 'get', args)) + if self.dictFormat: + self.to_send = msgpack.packb({MSGID: self.msgid, MSG: 'get', ARGS: args}) + else: + self.to_send = msgpack.packb((1, self.msgid, 'get', self.named_to_positional(cmd, args))) if self.to_send: try: @@ -458,55 +630,69 @@ This problem will go away as soon as the server has been upgraded to 1.0.7+. raise self.ignore_responses |= set(waiting_for) + @api(since=parse_version('1.0.0')) def check(self, repair=False, save_space=False): - return self.call('check', repair, save_space) + """actual remoting is done via self.call in the @api decorator""" + @api(since=parse_version('1.0.0')) def commit(self, save_space=False): - return self.call('commit', save_space) + """actual remoting is done via self.call in the @api decorator""" - def rollback(self, *args): - return self.call('rollback') + @api(since=parse_version('1.0.0')) + def rollback(self): + """actual remoting is done via self.call in the @api decorator""" + @api(since=parse_version('1.0.0')) def destroy(self): - return self.call('destroy') + """actual remoting is done via self.call in the @api decorator""" + @api(since=parse_version('1.0.0')) def __len__(self): - return self.call('__len__') + """actual remoting is done via self.call in the @api decorator""" + @api(since=parse_version('1.0.0')) def list(self, limit=None, marker=None): - return self.call('list', limit, marker) + """actual remoting is done via self.call in the @api decorator""" + @api(since=parse_version('1.1.0b3')) def scan(self, limit=None, marker=None): - return self.call('scan', limit, marker) + """actual remoting is done via self.call in the @api decorator""" def get(self, id_): for resp in self.get_many([id_]): return resp def get_many(self, ids, is_preloaded=False): - for resp in self.call_many('get', [(id_,) for id_ in ids], is_preloaded=is_preloaded): + for resp in self.call_many('get', [{'id_': id_} for id_ in ids], is_preloaded=is_preloaded): yield resp - def put(self, id_, data, wait=True): - return self.call('put', id_, data, wait=wait) + @api(since=parse_version('1.0.0')) + def put(self, id, data, wait=True): + """actual remoting is done via self.call in the @api decorator""" - def delete(self, id_, wait=True): - return self.call('delete', id_, wait=wait) + @api(since=parse_version('1.0.0')) + def delete(self, id, wait=True): + """actual remoting is done via self.call in the @api decorator""" + @api(since=parse_version('1.0.0')) def save_key(self, keydata): - return self.call('save_key', keydata) + """actual remoting is done via self.call in the @api decorator""" + @api(since=parse_version('1.0.0')) def load_key(self): - return self.call('load_key') + """actual remoting is done via self.call in the @api decorator""" + @api(since=parse_version('1.0.0')) def get_free_nonce(self): - return self.call('get_free_nonce') + """actual remoting is done via self.call in the @api decorator""" + @api(since=parse_version('1.0.0')) def commit_nonce_reservation(self, next_unreserved, start_nonce): - return self.call('commit_nonce_reservation', next_unreserved, start_nonce) + """actual remoting is done via self.call in the @api decorator""" + @api(since=parse_version('1.0.0')) def break_lock(self): - return self.call('break_lock') + """actual remoting is done via self.call in the @api decorator""" def close(self): if self.p: