Mirror of https://github.com/borgbackup/borg.git
Merge pull request #1211 from enkore/issue/1138

Fix incorrect propagation of OSErrors in create code

commit 67c69998d6
4 changed files with 89 additions and 24 deletions
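What this fixes: the create code used to catch bare OSError around
process_file()/process_stdin(), so an OSError escaping from borg's own
machinery (repository, cache) could be misreported as a per-file read error
while the backup carried on. The diff below wraps only OSErrors raised while
accessing input files in a new InputOSError, and the archiver catches that
instead. A minimal sketch of the behavioural change (the failure scenario is
hypothetical; the names come from the diff):

    # Before: any OSError became a per-file 'E' status, even one raised
    # far away from the input file being read.
    try:
        status = archive.process_file(path, st, cache)
    except OSError as e:
        status = 'E'

    # After: only input-side errors, wrapped as InputOSError, are downgraded
    # to 'E'; other OSErrors propagate instead of being swallowed.
    try:
        status = archive.process_file(path, st, cache)
    except InputOSError as e:
        status = 'E'

The first group of hunks below appears to come from the archive module,
judging by the stat_attrs/process_file contexts.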
@@ -1,4 +1,5 @@
 from binascii import hexlify
+from contextlib import contextmanager
 from datetime import datetime, timezone
 from getpass import getuser
 from itertools import groupby
@@ -45,6 +46,37 @@
 flags_noatime = flags_normal | getattr(os, 'O_NOATIME', 0)


+class InputOSError(Exception):
+    """Wrapper for OSError raised while accessing input files."""
+    def __init__(self, os_error):
+        self.os_error = os_error
+        self.errno = os_error.errno
+        self.strerror = os_error.strerror
+        self.filename = os_error.filename
+
+    def __str__(self):
+        return str(self.os_error)
+
+
+@contextmanager
+def input_io():
+    """Context manager changing OSError to InputOSError."""
+    try:
+        yield
+    except OSError as os_error:
+        raise InputOSError(os_error) from os_error
+
+
+def input_io_iter(iterator):
+    while True:
+        try:
+            with input_io():
+                item = next(iterator)
+        except StopIteration:
+            return
+        yield item
+
+
 class DownloadPipeline:

     def __init__(self, repository, key):
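For reference, a quick sketch of how the wrapper behaves in isolation (the
open() call and path are hypothetical):

    try:
        with input_io():
            fh = open('/backup/source/file', 'rb')  # any OSError here...
    except InputOSError as e:
        # ...arrives wrapped, with errno/strerror/filename copied over and
        # str() delegating to the original OSError.
        print('%s: %s' % (e.filename, e))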
@@ -464,12 +496,14 @@ def stat_attrs(self, st, path):
         }
         if self.numeric_owner:
             item[b'user'] = item[b'group'] = None
-        xattrs = xattr.get_all(path, follow_symlinks=False)
+        with input_io():
+            xattrs = xattr.get_all(path, follow_symlinks=False)
         if xattrs:
             item[b'xattrs'] = StableDict(xattrs)
         if has_lchflags and st.st_flags:
             item[b'bsdflags'] = st.st_flags
-        acl_get(path, item, st, self.numeric_owner)
+        with input_io():
+            acl_get(path, item, st, self.numeric_owner)
         return item

     def process_dir(self, path, st):
@@ -504,7 +538,7 @@ def process_stdin(self, path, cache):
         uid, gid = 0, 0
         fd = sys.stdin.buffer  # binary
         chunks = []
-        for chunk in self.chunker.chunkify(fd):
+        for chunk in input_io_iter(self.chunker.chunkify(fd)):
             chunks.append(cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats))
         self.stats.nfiles += 1
         t = int_to_bigint(int(time.time()) * 1000000000)
@@ -552,10 +586,11 @@ def process_file(self, path, st, cache, ignore_inode=False):
         item = {b'path': safe_path}
         # Only chunkify the file if needed
         if chunks is None:
-            fh = Archive._open_rb(path)
+            with input_io():
+                fh = Archive._open_rb(path)
             with os.fdopen(fh, 'rb') as fd:
                 chunks = []
-                for chunk in self.chunker.chunkify(fd, fh):
+                for chunk in input_io_iter(self.chunker.chunkify(fd, fh)):
                     chunks.append(cache.add_chunk(self.key.id_hash(chunk), chunk, self.stats))
                     if self.show_progress:
                         self.stats.show_progress(item=item, dt=0.2)
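Why chunkify() gets input_io_iter() rather than a plain with-block: the
chunker is an iterator, and read errors surface inside next(), which a
context manager wrapped around the loop body would never see. A small sketch
(failing_reads is hypothetical):

    def failing_reads():
        yield b'chunk-1'
        raise OSError(5, 'Input/output error')  # raised at the second next()

    for chunk in input_io_iter(failing_reads()):
        pass  # after the first chunk, this raises InputOSError, not OSError

The next two hunks adjust the archiver to import and catch the new exception.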
@@ -29,7 +29,7 @@
 from .repository import Repository
 from .cache import Cache
 from .key import key_creator, RepoKey, PassphraseKey
-from .archive import Archive, ArchiveChecker, CHUNKER_PARAMS
+from .archive import input_io, InputOSError, Archive, ArchiveChecker, CHUNKER_PARAMS
 from .remote import RepositoryServer, RemoteRepository, cache_if_remote

 has_lchflags = hasattr(os, 'lchflags')
@@ -198,7 +198,7 @@ def create_inner(archive, cache):
                 if not dry_run:
                     try:
                         status = archive.process_stdin(path, cache)
-                    except OSError as e:
+                    except InputOSError as e:
                         status = 'E'
                         self.print_warning('%s: %s', path, e)
                 else:
@@ -273,7 +273,7 @@ def _process(self, archive, cache, matcher, exclude_caches, exclude_if_present,
            if not dry_run:
                try:
                    status = archive.process_file(path, st, cache, self.ignore_inode)
-               except OSError as e:
+               except InputOSError as e:
                    status = 'E'
                    self.print_warning('%s: %s', path, e)
        elif stat.S_ISDIR(st.st_mode):
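The print_warning() calls stay unchanged because InputOSError copies errno,
strerror and filename from the wrapped error and delegates str() to it, so
the message renders exactly as before. A quick check (values hypothetical):

    try:
        with input_io():
            raise OSError(13, 'Permission denied', '/etc/shadow')
    except InputOSError as e:
        assert e.errno == 13 and e.filename == '/etc/shadow'
        assert str(e) == "[Errno 13] Permission denied: '/etc/shadow'"

The remaining hunks touch the remote/RPC layer and the test suite.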
@@ -241,6 +241,24 @@ def fetch_from_cache(args):
                 del self.cache[args]
             return msgid

+        def handle_error(error, res):
+            if error == b'DoesNotExist':
+                raise Repository.DoesNotExist(self.location.orig)
+            elif error == b'AlreadyExists':
+                raise Repository.AlreadyExists(self.location.orig)
+            elif error == b'CheckNeeded':
+                raise Repository.CheckNeeded(self.location.orig)
+            elif error == b'IntegrityError':
+                raise IntegrityError(res)
+            elif error == b'PathNotAllowed':
+                raise PathNotAllowed(*res)
+            elif error == b'ObjectNotFound':
+                raise Repository.ObjectNotFound(res[0], self.location.orig)
+            elif error == b'InvalidRPCMethod':
+                raise InvalidRPCMethod(*res)
+            else:
+                raise self.RPCError(res.decode('utf-8'))
+
         calls = list(calls)
         waiting_for = []
         w_fds = [self.stdin_fd]
@@ -250,22 +268,7 @@ def fetch_from_cache(args):
                 error, res = self.responses.pop(waiting_for[0])
                 waiting_for.pop(0)
                 if error:
-                    if error == b'DoesNotExist':
-                        raise Repository.DoesNotExist(self.location.orig)
-                    elif error == b'AlreadyExists':
-                        raise Repository.AlreadyExists(self.location.orig)
-                    elif error == b'CheckNeeded':
-                        raise Repository.CheckNeeded(self.location.orig)
-                    elif error == b'IntegrityError':
-                        raise IntegrityError(res)
-                    elif error == b'PathNotAllowed':
-                        raise PathNotAllowed(*res)
-                    elif error == b'ObjectNotFound':
-                        raise Repository.ObjectNotFound(res[0], self.location.orig)
-                    elif error == b'InvalidRPCMethod':
-                        raise InvalidRPCMethod(*res)
-                    else:
-                        raise self.RPCError(res.decode('utf-8'))
+                    handle_error(error, res)
                 else:
                     yield res
                 if not waiting_for and not calls:
@@ -287,6 +290,8 @@ def fetch_from_cache(args):
                         type, msgid, error, res = unpacked
                         if msgid in self.ignore_responses:
                             self.ignore_responses.remove(msgid)
+                            if error:
+                                handle_error(error, res)
                         else:
                             self.responses[msgid] = error, res
                 elif fd is self.stderr_fd:
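Besides deduplicating the long elif chain, hoisting the dispatch into
handle_error() closes a gap: an error response whose msgid sat in
ignore_responses used to be dropped silently, while now both paths raise the
same typed exception. A sketch of the shared path (the tuple is hypothetical):

    error, res = b'DoesNotExist', None  # as decoded from a server reply
    if error:
        handle_error(error, res)        # raises Repository.DoesNotExist(...)

Finally, the test suite gains direct coverage for the two new helpers.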
@@ -5,6 +5,7 @@
 import pytest

 from ..archive import Archive, CacheChunkBuffer, RobustUnpacker, valid_msgpacked_dict, ITEM_KEYS
+from ..archive import InputOSError, input_io, input_io_iter
 from ..key import PlaintextKey
 from ..helpers import Manifest
 from . import BaseTestCase
@@ -145,3 +146,27 @@ def test_key_length_msgpacked_items():
     data = {key: b''}
     item_keys_serialized = [msgpack.packb(key), ]
     assert valid_msgpacked_dict(msgpack.packb(data), item_keys_serialized)
+
+
+def test_input_io():
+    with pytest.raises(InputOSError):
+        with input_io():
+            raise OSError(123)
+
+
+def test_input_io_iter():
+    class Iterator:
+        def __init__(self, exc):
+            self.exc = exc
+
+        def __next__(self):
+            raise self.exc()
+
+    oserror_iterator = Iterator(OSError)
+    with pytest.raises(InputOSError):
+        for _ in input_io_iter(oserror_iterator):
+            pass
+
+    normal_iterator = Iterator(StopIteration)
+    for _ in input_io_iter(normal_iterator):
+        assert False, 'StopIteration handled incorrectly'
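The tests pin down the two contracts the create code now relies on: OSError
raised under input_io() becomes InputOSError, and input_io_iter() lets
StopIteration end the loop untouched. Assuming the usual source layout, they
can be run with something like:

    python -m pytest borg/testsuite/archive.py -k input_io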