borg/darc/remote.py

211 lines
7.2 KiB
Python
Raw Normal View History

2011-08-03 15:54:03 +00:00
from __future__ import with_statement
2010-11-15 21:18:47 +00:00
import fcntl
import msgpack
import os
import paramiko
import select
import sys
import getpass
from .store import Store
2011-08-02 19:45:21 +00:00
from .helpers import Counter
2010-11-15 21:18:47 +00:00
# Upper bound, in bytes, for each read from stdin (server side) and from the
# SSH channel's stdout/stderr streams (client side).
BUFSIZE = 1 << 20  # 1 MiB
2011-07-05 19:29:15 +00:00
class ChannelNotifyer(object):
    """Event-like object handed to a paramiko channel's input buffers.

    RemoteStore installs it via ``set_event()`` on both the stdout and the
    stderr buffer of its channel.  Instead of raising a flag, ``set()`` wakes
    up threads waiting on the channel's ``out_buffer_cv`` -- but only while
    the ``enabled`` counter is greater than zero.
    """

    def __init__(self, channel):
        self.channel = channel
        # RemoteStore increments this per request sent and decrements it per
        # reply received, so it tracks replies still outstanding.
        self.enabled = Counter()

    def set(self):
        """Notify all waiters on the channel's out_buffer_cv, if enabled."""
        if not (self.enabled > 0):
            return
        with self.channel.lock:
            self.channel.out_buffer_cv.notifyAll()

    def clear(self):
        """Do nothing: there is no flag to reset."""
2010-11-15 21:18:47 +00:00
class StoreServer(object):
    """Serve a Store over stdin/stdout using msgpack-encoded messages.

    Each request read from stdin is a ``(type, msgid, method, args)`` tuple.
    It is answered on stdout with ``(1, msgid, error_name, result)``:
    ``error_name`` is None on success, otherwise it is the class name of the
    exception raised and ``result`` is None.
    """

    def __init__(self):
        # Bound by open(); method names this class does not define are
        # looked up on this object instead.
        self.store = None

    def serve(self):
        """Answer requests from stdin until end-of-file is read."""
        stdin_fd = sys.stdin.fileno()
        # Put stdin into non-blocking mode.
        flags = fcntl.fcntl(stdin_fd, fcntl.F_GETFL)
        fcntl.fcntl(stdin_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
        unpacker = msgpack.Unpacker()
        while True:
            readable, _, exceptional = select.select([sys.stdin], [], [], 10)
            if readable:
                chunk = os.read(stdin_fd, BUFSIZE)
                if not chunk:
                    return  # end-of-file on stdin
                unpacker.feed(chunk)
                for _, msgid, method, args in unpacker:
                    # NOTE(review): method is chosen by the client and may name
                    # any attribute of this object or of the store.
                    try:
                        try:
                            handler = getattr(self, method)
                        except AttributeError:
                            handler = getattr(self.store, method)
                        result = handler(*args)
                    except Exception:
                        # sys.exc_info() keeps the syntax valid on old Python 2
                        # releases and on Python 3; only the class name is sent.
                        failure = sys.exc_info()[1]
                        reply = (1, msgid, failure.__class__.__name__, None)
                    else:
                        reply = (1, msgid, None, result)
                    sys.stdout.write(msgpack.packb(reply))
                    sys.stdout.flush()
            # NOTE(review): select() is given an empty exceptional list, so
            # this list is always empty and the branch is never taken.
            if exceptional:
                return

    def open(self, path, create=False):
        """Open a Store at *path* (passing *create* through) and return its id."""
        # Drop the '/' in front of '~' so expanduser() sees a home-relative path.
        local_path = path[1:] if path.startswith('/~') else path
        self.store = Store(os.path.expanduser(local_path), create)
        return self.store.id
2010-11-17 21:40:39 +00:00
2010-11-15 21:18:47 +00:00
class RemoteStore(object):
class DoesNotExist(Exception):
pass
class AlreadyExists(Exception):
pass
class RPCError(Exception):
def __init__(self, name):
self.name = name
def __init__(self, location, create=False):
self.client = paramiko.SSHClient()
self.client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
params = {'username': location.user or getpass.getuser(),
'hostname': location.host, 'port': location.port}
2010-11-15 21:18:47 +00:00
while True:
try:
self.client.connect(**params)
break
2011-01-04 21:37:27 +00:00
except (paramiko.PasswordRequiredException,
2011-07-05 19:29:15 +00:00
paramiko.AuthenticationException,
paramiko.SSHException):
if not 'password' in params:
params['password'] = getpass.getpass('Password for %(username)s@%(hostname)s:' % params)
else:
raise
2010-11-15 21:18:47 +00:00
self.unpacker = msgpack.Unpacker()
self.transport = self.client.get_transport()
self.channel = self.transport.open_session()
2011-07-05 19:29:15 +00:00
self.notifier = ChannelNotifyer(self.channel)
self.channel.in_buffer.set_event(self.notifier)
self.channel.in_stderr_buffer.set_event(self.notifier)
2010-11-15 21:18:47 +00:00
self.channel.exec_command('darc serve')
2011-07-17 20:31:37 +00:00
self.callbacks = {}
2010-11-15 21:18:47 +00:00
self.msgid = 0
2011-08-02 19:45:21 +00:00
self.recursion = 0
self.odata = []
self.id = self.cmd('open', (location.path, create))
2010-11-15 21:18:47 +00:00
def wait(self, write=True):
2011-07-17 21:53:23 +00:00
with self.channel.lock:
if ((not write or self.channel.out_window_size == 0) and
2011-07-31 16:22:58 +00:00
len(self.channel.in_buffer._buffer) == 0 and
len(self.channel.in_stderr_buffer._buffer) == 0):
2011-08-02 19:45:21 +00:00
self.channel.out_buffer_cv.wait(1)
2011-07-17 21:53:23 +00:00
2011-07-17 20:31:37 +00:00
def cmd(self, cmd, args, callback=None, callback_data=None):
2010-11-15 21:18:47 +00:00
self.msgid += 1
2011-08-02 19:45:21 +00:00
self.notifier.enabled.inc()
self.odata.append(msgpack.packb((1, self.msgid, cmd, args)))
2011-08-02 19:45:21 +00:00
self.recursion += 1
2011-07-17 20:31:37 +00:00
if callback:
2011-09-12 19:34:09 +00:00
self.add_callback(callback, callback_data)
2011-08-02 19:45:21 +00:00
if self.recursion > 1:
self.recursion -= 1
return
2010-11-15 21:18:47 +00:00
while True:
2011-07-05 19:29:15 +00:00
if self.channel.closed:
2011-08-02 19:45:21 +00:00
self.recursion -= 1
2011-07-05 19:29:15 +00:00
raise Exception('Connection closed')
elif self.channel.recv_stderr_ready():
print >> sys.stderr, 'remote stderr:', self.channel.recv_stderr(BUFSIZE)
elif self.channel.recv_ready():
self.unpacker.feed(self.channel.recv(BUFSIZE))
for type, msgid, error, res in self.unpacker:
2011-08-02 19:45:21 +00:00
self.notifier.enabled.dec()
2011-07-17 20:31:37 +00:00
if msgid == self.msgid:
if error:
self.recursion -= 1
raise self.RPCError(error)
2011-08-02 19:45:21 +00:00
self.recursion -= 1
2011-07-17 20:31:37 +00:00
return res
else:
2011-09-12 19:34:09 +00:00
for c, d in self.callbacks.pop(msgid, []):
2011-07-17 20:31:37 +00:00
c(res, error, d)
2011-08-02 19:45:21 +00:00
elif self.odata and self.channel.send_ready():
data = self.odata.pop(0)
n = self.channel.send(data)
if n != len(data):
self.odata.insert(0, data[n:])
2011-08-02 19:45:21 +00:00
if not self.odata and callback:
self.recursion -= 1
2011-07-17 21:53:23 +00:00
return
2011-07-17 20:31:37 +00:00
else:
2011-08-02 19:45:21 +00:00
self.wait(self.odata)
2011-09-05 19:20:17 +00:00
def commit(self, *args):
self.cmd('commit', args)
2010-11-15 21:18:47 +00:00
def rollback(self, *args):
return self.cmd('rollback', args)
2010-11-15 21:18:47 +00:00
def get(self, id, callback=None, callback_data=None):
2010-11-15 21:18:47 +00:00
try:
return self.cmd('get', (id, ), callback, callback_data)
except self.RPCError, e:
if e.name == 'DoesNotExist':
raise self.DoesNotExist
raise
def put(self, id, data, callback=None, callback_data=None):
2010-11-15 21:18:47 +00:00
try:
return self.cmd('put', (id, data), callback, callback_data)
2010-11-15 21:18:47 +00:00
except self.RPCError, e:
if e.name == 'AlreadyExists':
raise self.AlreadyExists
def delete(self, id, callback=None, callback_data=None):
return self.cmd('delete', (id, ), callback, callback_data)
2010-11-15 21:18:47 +00:00
2011-09-12 19:34:09 +00:00
def add_callback(self, cb, data):
self.callbacks.setdefault(self.msgid, []).append((cb, data))
2011-08-02 19:45:21 +00:00
def flush_rpc(self, counter=None, backlog=0):
counter = counter or self.notifier.enabled
while counter > backlog:
2011-07-17 21:53:23 +00:00
if self.channel.closed:
raise Exception('Connection closed')
2011-08-02 19:45:21 +00:00
elif self.odata and self.channel.send_ready():
n = self.channel.send(self.odata)
if n > 0:
self.odata = self.odata[n:]
2011-07-17 21:53:23 +00:00
elif self.channel.recv_stderr_ready():
print >> sys.stderr, 'remote stderr:', self.channel.recv_stderr(BUFSIZE)
elif self.channel.recv_ready():
self.unpacker.feed(self.channel.recv(BUFSIZE))
for type, msgid, error, res in self.unpacker:
2011-08-02 19:45:21 +00:00
self.notifier.enabled.dec()
2011-09-12 19:34:09 +00:00
for c, d in self.callbacks.pop(msgid, []):
2011-07-17 21:53:23 +00:00
c(res, error, d)
if msgid == self.msgid:
return
else:
2011-08-02 19:45:21 +00:00
self.wait(self.odata)