1
0
Fork 0
mirror of https://github.com/borgbackup/borg.git synced 2024-12-25 09:19:31 +00:00

Merge pull request #1211 from enkore/issue/1138

Fix incorrect propagation of OSErrors in create code
This commit is contained in:
enkore 2016-06-29 17:07:51 +02:00 committed by GitHub
commit 67c69998d6
4 changed files with 89 additions and 24 deletions

View file

@ -1,4 +1,5 @@
from binascii import hexlify from binascii import hexlify
from contextlib import contextmanager
from datetime import datetime, timezone from datetime import datetime, timezone
from getpass import getuser from getpass import getuser
from itertools import groupby from itertools import groupby
@ -45,6 +46,37 @@
flags_noatime = flags_normal | getattr(os, 'O_NOATIME', 0) flags_noatime = flags_normal | getattr(os, 'O_NOATIME', 0)
class InputOSError(Exception):
"""Wrapper for OSError raised while accessing input files."""
def __init__(self, os_error):
self.os_error = os_error
self.errno = os_error.errno
self.strerror = os_error.strerror
self.filename = os_error.filename
def __str__(self):
return str(self.os_error)
@contextmanager
def input_io():
"""Context manager changing OSError to InputOSError."""
try:
yield
except OSError as os_error:
raise InputOSError(os_error) from os_error
def input_io_iter(iterator):
while True:
try:
with input_io():
item = next(iterator)
except StopIteration:
return
yield item
class DownloadPipeline: class DownloadPipeline:
def __init__(self, repository, key): def __init__(self, repository, key):
@ -464,12 +496,14 @@ def stat_attrs(self, st, path):
} }
if self.numeric_owner: if self.numeric_owner:
item[b'user'] = item[b'group'] = None item[b'user'] = item[b'group'] = None
xattrs = xattr.get_all(path, follow_symlinks=False) with input_io():
xattrs = xattr.get_all(path, follow_symlinks=False)
if xattrs: if xattrs:
item[b'xattrs'] = StableDict(xattrs) item[b'xattrs'] = StableDict(xattrs)
if has_lchflags and st.st_flags: if has_lchflags and st.st_flags:
item[b'bsdflags'] = st.st_flags item[b'bsdflags'] = st.st_flags
acl_get(path, item, st, self.numeric_owner) with input_io():
acl_get(path, item, st, self.numeric_owner)
return item return item
def process_dir(self, path, st): def process_dir(self, path, st):
@ -504,7 +538,7 @@ def process_stdin(self, path, cache):
uid, gid = 0, 0 uid, gid = 0, 0
fd = sys.stdin.buffer # binary fd = sys.stdin.buffer # binary
chunks = [] chunks = []
for chunk in self.chunker.chunkify(fd): for chunk in input_io_iter(self.chunker.chunkify(fd)):
chunks.append(cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats)) chunks.append(cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats))
self.stats.nfiles += 1 self.stats.nfiles += 1
t = int_to_bigint(int(time.time()) * 1000000000) t = int_to_bigint(int(time.time()) * 1000000000)
@ -552,10 +586,11 @@ def process_file(self, path, st, cache, ignore_inode=False):
item = {b'path': safe_path} item = {b'path': safe_path}
# Only chunkify the file if needed # Only chunkify the file if needed
if chunks is None: if chunks is None:
fh = Archive._open_rb(path) with input_io():
fh = Archive._open_rb(path)
with os.fdopen(fh, 'rb') as fd: with os.fdopen(fh, 'rb') as fd:
chunks = [] chunks = []
for chunk in self.chunker.chunkify(fd, fh): for chunk in input_io_iter(self.chunker.chunkify(fd, fh)):
chunks.append(cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats)) chunks.append(cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats))
if self.show_progress: if self.show_progress:
self.stats.show_progress(item=item, dt=0.2) self.stats.show_progress(item=item, dt=0.2)

View file

@ -29,7 +29,7 @@
from .repository import Repository from .repository import Repository
from .cache import Cache from .cache import Cache
from .key import key_creator, RepoKey, PassphraseKey from .key import key_creator, RepoKey, PassphraseKey
from .archive import Archive, ArchiveChecker, CHUNKER_PARAMS from .archive import input_io, InputOSError, Archive, ArchiveChecker, CHUNKER_PARAMS
from .remote import RepositoryServer, RemoteRepository, cache_if_remote from .remote import RepositoryServer, RemoteRepository, cache_if_remote
has_lchflags = hasattr(os, 'lchflags') has_lchflags = hasattr(os, 'lchflags')
@ -198,7 +198,7 @@ def create_inner(archive, cache):
if not dry_run: if not dry_run:
try: try:
status = archive.process_stdin(path, cache) status = archive.process_stdin(path, cache)
except OSError as e: except InputOSError as e:
status = 'E' status = 'E'
self.print_warning('%s: %s', path, e) self.print_warning('%s: %s', path, e)
else: else:
@ -273,7 +273,7 @@ def _process(self, archive, cache, matcher, exclude_caches, exclude_if_present,
if not dry_run: if not dry_run:
try: try:
status = archive.process_file(path, st, cache, self.ignore_inode) status = archive.process_file(path, st, cache, self.ignore_inode)
except OSError as e: except InputOSError as e:
status = 'E' status = 'E'
self.print_warning('%s: %s', path, e) self.print_warning('%s: %s', path, e)
elif stat.S_ISDIR(st.st_mode): elif stat.S_ISDIR(st.st_mode):

View file

@ -241,6 +241,24 @@ def fetch_from_cache(args):
del self.cache[args] del self.cache[args]
return msgid return msgid
def handle_error(error, res):
if error == b'DoesNotExist':
raise Repository.DoesNotExist(self.location.orig)
elif error == b'AlreadyExists':
raise Repository.AlreadyExists(self.location.orig)
elif error == b'CheckNeeded':
raise Repository.CheckNeeded(self.location.orig)
elif error == b'IntegrityError':
raise IntegrityError(res)
elif error == b'PathNotAllowed':
raise PathNotAllowed(*res)
elif error == b'ObjectNotFound':
raise Repository.ObjectNotFound(res[0], self.location.orig)
elif error == b'InvalidRPCMethod':
raise InvalidRPCMethod(*res)
else:
raise self.RPCError(res.decode('utf-8'))
calls = list(calls) calls = list(calls)
waiting_for = [] waiting_for = []
w_fds = [self.stdin_fd] w_fds = [self.stdin_fd]
@ -250,22 +268,7 @@ def fetch_from_cache(args):
error, res = self.responses.pop(waiting_for[0]) error, res = self.responses.pop(waiting_for[0])
waiting_for.pop(0) waiting_for.pop(0)
if error: if error:
if error == b'DoesNotExist': handle_error(error, res)
raise Repository.DoesNotExist(self.location.orig)
elif error == b'AlreadyExists':
raise Repository.AlreadyExists(self.location.orig)
elif error == b'CheckNeeded':
raise Repository.CheckNeeded(self.location.orig)
elif error == b'IntegrityError':
raise IntegrityError(res)
elif error == b'PathNotAllowed':
raise PathNotAllowed(*res)
elif error == b'ObjectNotFound':
raise Repository.ObjectNotFound(res[0], self.location.orig)
elif error == b'InvalidRPCMethod':
raise InvalidRPCMethod(*res)
else:
raise self.RPCError(res.decode('utf-8'))
else: else:
yield res yield res
if not waiting_for and not calls: if not waiting_for and not calls:
@ -287,6 +290,8 @@ def fetch_from_cache(args):
type, msgid, error, res = unpacked type, msgid, error, res = unpacked
if msgid in self.ignore_responses: if msgid in self.ignore_responses:
self.ignore_responses.remove(msgid) self.ignore_responses.remove(msgid)
if error:
handle_error(error, res)
else: else:
self.responses[msgid] = error, res self.responses[msgid] = error, res
elif fd is self.stderr_fd: elif fd is self.stderr_fd:

View file

@ -5,6 +5,7 @@
import pytest import pytest
from ..archive import Archive, CacheChunkBuffer, RobustUnpacker, valid_msgpacked_dict, ITEM_KEYS from ..archive import Archive, CacheChunkBuffer, RobustUnpacker, valid_msgpacked_dict, ITEM_KEYS
from ..archive import InputOSError, input_io, input_io_iter
from ..key import PlaintextKey from ..key import PlaintextKey
from ..helpers import Manifest from ..helpers import Manifest
from . import BaseTestCase from . import BaseTestCase
@ -145,3 +146,27 @@ def test_key_length_msgpacked_items():
data = {key: b''} data = {key: b''}
item_keys_serialized = [msgpack.packb(key), ] item_keys_serialized = [msgpack.packb(key), ]
assert valid_msgpacked_dict(msgpack.packb(data), item_keys_serialized) assert valid_msgpacked_dict(msgpack.packb(data), item_keys_serialized)
def test_input_io():
with pytest.raises(InputOSError):
with input_io():
raise OSError(123)
def test_input_io_iter():
class Iterator:
def __init__(self, exc):
self.exc = exc
def __next__(self):
raise self.exc()
oserror_iterator = Iterator(OSError)
with pytest.raises(InputOSError):
for _ in input_io_iter(oserror_iterator):
pass
normal_iterator = Iterator(StopIteration)
for _ in input_io_iter(normal_iterator):
assert False, 'StopIteration handled incorrectly'