From 0a6e6cfe2e6e28644d139ac6662e122f2f63da40 Mon Sep 17 00:00:00 2001 From: Thomas Waldmann Date: Sun, 1 Nov 2015 19:18:29 +0100 Subject: [PATCH] refactor confirmation code, reduce code duplication, add tests --- borg/archiver.py | 32 ++++++++-------- borg/cache.py | 29 ++++++-------- borg/helpers.py | 81 +++++++++++++++++++++++++++++++++++++++ borg/testsuite/helpers.py | 78 ++++++++++++++++++++++++++++++++++++- 4 files changed, 184 insertions(+), 36 deletions(-) diff --git a/borg/archiver.py b/borg/archiver.py index 50a4d7017..f948db793 100644 --- a/borg/archiver.py +++ b/borg/archiver.py @@ -19,7 +19,7 @@ from .helpers import Error, location_validator, format_time, format_file_size, \ format_file_mode, ExcludePattern, IncludePattern, exclude_path, adjust_patterns, to_localtime, timestamp, \ get_cache_dir, get_keys_dir, format_timedelta, prune_within, prune_split, \ Manifest, remove_surrogates, update_excludes, format_archive, check_extension_modules, Statistics, \ - is_cachedir, bigint_to_int, ChunkerParams, CompressionSpec, have_cython, is_slow_msgpack, \ + is_cachedir, bigint_to_int, ChunkerParams, CompressionSpec, have_cython, is_slow_msgpack, yes, \ EXIT_SUCCESS, EXIT_WARNING, EXIT_ERROR from .logger import create_logger, setup_logging logger = create_logger() @@ -88,13 +88,12 @@ class Archiver: """Check repository consistency""" repository = self.open_repository(args.repository, exclusive=args.repair) if args.repair: - while not os.environ.get('BORG_CHECK_I_KNOW_WHAT_I_AM_DOING'): - self.print_warning("""'check --repair' is an experimental feature that might result -in data loss. - -Type "Yes I am sure" if you understand this and want to continue.\n""") - if input('Do you want to continue? ') == 'Yes I am sure': - break + msg = ("'check --repair' is an experimental feature that might result in data loss." + + "\n" + + "Type 'YES' if you understand this and want to continue: ") + if not yes(msg, false_msg="Aborting.", + env_var_override='BORG_CHECK_I_KNOW_WHAT_I_AM_DOING', truish=('YES', )): + return EXIT_ERROR if not args.archives_only: logger.info('Starting repository check...') if repository.check(repair=args.repair): @@ -330,15 +329,16 @@ Type "Yes I am sure" if you understand this and want to continue.\n""") logger.info(str(cache)) else: if not args.cache_only: - print("You requested to completely DELETE the repository *including* all archives it contains:", file=sys.stderr) + msg = [] + msg.append("You requested to completely DELETE the repository *including* all archives it contains:") for archive_info in manifest.list_archive_infos(sort_by='ts'): - print(format_archive(archive_info), file=sys.stderr) - if not os.environ.get('BORG_CHECK_I_KNOW_WHAT_I_AM_DOING'): - print("""Type "YES" if you understand this and want to continue.\n""", file=sys.stderr) - # XXX: prompt may end up on stdout, but we'll assume that input() does the right thing - if input('Do you want to continue? ') != 'YES': - self.exit_code = EXIT_ERROR - return self.exit_code + msg.append(format_archive(archive_info)) + msg.append("Type 'YES' if you understand this and want to continue: ") + msg = '\n'.join(msg) + if not yes(msg, false_msg="Aborting.", + env_var_override='BORG_CHECK_I_KNOW_WHAT_I_AM_DOING', truish=('YES', )): + self.exit_code = EXIT_ERROR + return self.exit_code repository.destroy() logger.info("Repository deleted.") cache.destroy() diff --git a/borg/cache.py b/borg/cache.py index 12c302274..c3f085cdc 100644 --- a/borg/cache.py +++ b/borg/cache.py @@ -14,7 +14,7 @@ from .key import PlaintextKey from .logger import create_logger logger = create_logger() from .helpers import Error, get_cache_dir, decode_dict, st_mtime_ns, unhexlify, int_to_bigint, \ - bigint_to_int, format_file_size, have_cython + bigint_to_int, format_file_size, have_cython, yes from .locking import UpgradableLock from .hashindex import ChunkIndex @@ -51,15 +51,21 @@ class Cache: # Warn user before sending data to a never seen before unencrypted repository if not os.path.exists(self.path): if warn_if_unencrypted and isinstance(key, PlaintextKey): - if not self._confirm('Warning: Attempting to access a previously unknown unencrypted repository', - 'BORG_UNKNOWN_UNENCRYPTED_REPO_ACCESS_IS_OK'): + msg = ("Warning: Attempting to access a previously unknown unencrypted repository!" + + "\n" + + "Do you want to continue? [yN] ") + if not yes(msg, false_msg="Aborting.", + env_var_override='BORG_UNKNOWN_UNENCRYPTED_REPO_ACCESS_IS_OK'): raise self.CacheInitAbortedError() self.create() self.open() # Warn user before sending data to a relocated repository if self.previous_location and self.previous_location != repository._location.canonical_path(): - msg = 'Warning: The repository at location {} was previously located at {}'.format(repository._location.canonical_path(), self.previous_location) - if not self._confirm(msg, 'BORG_RELOCATED_REPO_ACCESS_IS_OK'): + msg = ("Warning: The repository at location {} was previously located at {}".format(repository._location.canonical_path(), self.previous_location) + + "\n" + + "Do you want to continue? [yN] ") + if not yes(msg, false_msg="Aborting.", + env_var_override='BORG_RELOCATED_REPO_ACCESS_IS_OK'): raise self.RepositoryAccessAborted() if sync and self.manifest.id != self.manifest_id: @@ -92,19 +98,6 @@ Chunk index: {0.total_unique_chunks:20d} {0.total_chunks:20d}""" stats[field] = format_file_size(stats[field]) return Summary(**stats) - def _confirm(self, message, env_var_override=None): - print(message, file=sys.stderr) - if env_var_override and os.environ.get(env_var_override): - print("Yes (From {})".format(env_var_override), file=sys.stderr) - return True - if not sys.stdin.isatty(): - return False - try: - answer = input('Do you want to continue? [yN] ') - except EOFError: - return False - return answer and answer in 'Yy' - def create(self): """Create a new empty cache at `self.path` """ diff --git a/borg/helpers.py b/borg/helpers.py index 6a170596c..a2941ddd4 100644 --- a/borg/helpers.py +++ b/borg/helpers.py @@ -804,3 +804,84 @@ def int_to_bigint(value): def is_slow_msgpack(): return msgpack.Packer is msgpack.fallback.Packer + + +def yes(msg=None, retry_msg=None, false_msg=None, true_msg=None, + default=False, default_notty=None, default_eof=None, + falsish=('No', 'no', 'N', 'n'), truish=('Yes', 'yes', 'Y', 'y'), + env_var_override=None, ifile=None, ofile=None, input=input): + """ + Output (usually a question) and let user input an answer. + Qualifies the answer according to falsish and truish as True or False. + If it didn't qualify and retry_msg is None (no retries wanted), + return the default [which defaults to False]. Otherwise let user retry + answering until answer is qualified. + + If env_var_override is given and it is non-empty, counts as truish answer + and won't ask user for an answer. + If we don't have a tty as input and default_notty is not None, return its value. + Otherwise read input from non-tty and proceed as normal. + If EOF is received instead an input, return default_eof [or default, if not given]. + + :param msg: introducing message to output on ofile, no \n is added [None] + :param retry_msg: retry message to output on ofile, no \n is added [None] + (also enforces retries instead of returning default) + :param false_msg: message to output before returning False [None] + :param true_msg: message to output before returning True [None] + :param default: default return value (empty answer is given) [False] + :param default_notty: if not None, return its value if no tty is connected [None] + :param default_eof: return value if EOF was read as answer [same as default] + :param falsish: sequence of answers qualifying as False + :param truish: sequence of answers qualifying as True + :param env_var_override: environment variable name [None] + :param ifile: input stream [sys.stdin] (only for testing!) + :param ofile: output stream [sys.stderr] + :param input: input function [input from builtins] + :return: boolean answer value, True or False + """ + # note: we do not assign sys.stdin/stderr as defaults above, so they are + # really evaluated NOW, not at function definition time. + if ifile is None: + ifile = sys.stdin + if ofile is None: + ofile = sys.stderr + if default not in (True, False): + raise ValueError("invalid default value, must be True or False") + if default_notty not in (None, True, False): + raise ValueError("invalid default_notty value, must be None, True or False") + if default_eof not in (None, True, False): + raise ValueError("invalid default_eof value, must be None, True or False") + if msg: + print(msg, file=ofile, end='') + ofile.flush() + if env_var_override: + value = os.environ.get(env_var_override) + # currently, any non-empty value counts as truish + # TODO: change this so one can give y/n there? + if value: + value = bool(value) + value_str = truish[0] if value else falsish[0] + print("{} (from {})".format(value_str, env_var_override), file=ofile) + return value + if default_notty is not None and not ifile.isatty(): + # looks like ifile is not a terminal (but e.g. a pipe) + return default_notty + while True: + try: + answer = input() # XXX how can we use ifile? + except EOFError: + return default_eof if default_eof is not None else default + if answer in truish: + if true_msg: + print(true_msg, file=ofile) + return True + if answer in falsish: + if false_msg: + print(false_msg, file=ofile) + return False + if retry_msg is None: + # no retries wanted, we just return the default + return default + if retry_msg: + print(retry_msg, file=ofile, end='') + ofile.flush() diff --git a/borg/testsuite/helpers.py b/borg/testsuite/helpers.py index 2faa569da..58861aaab 100644 --- a/borg/testsuite/helpers.py +++ b/borg/testsuite/helpers.py @@ -10,9 +10,9 @@ import msgpack import msgpack.fallback from ..helpers import adjust_patterns, exclude_path, Location, format_file_size, format_timedelta, IncludePattern, ExcludePattern, make_path_safe, \ - prune_within, prune_split, get_cache_dir, Statistics, is_slow_msgpack, \ + prune_within, prune_split, get_cache_dir, Statistics, is_slow_msgpack, yes, \ StableDict, int_to_bigint, bigint_to_int, parse_timestamp, CompressionSpec, ChunkerParams -from . import BaseTestCase +from . import BaseTestCase, environment_variable, FakeInputs class BigIntTestCase(BaseTestCase): @@ -492,3 +492,77 @@ def test_is_slow_msgpack(): msgpack.Packer = saved_packer # this assumes that we have fast msgpack on test platform: assert not is_slow_msgpack() + + +def test_yes_simple(): + input = FakeInputs(['y', 'Y', 'yes', 'Yes', ]) + assert yes(input=input) + assert yes(input=input) + assert yes(input=input) + assert yes(input=input) + input = FakeInputs(['n', 'N', 'no', 'No', ]) + assert not yes(input=input) + assert not yes(input=input) + assert not yes(input=input) + assert not yes(input=input) + + +def test_yes_custom(): + input = FakeInputs(['YES', 'SURE', 'NOPE', ]) + assert yes(truish=('YES', ), input=input) + assert yes(truish=('SURE', ), input=input) + assert not yes(falsish=('NOPE', ), input=input) + + +def test_yes_env(): + input = FakeInputs(['n', 'n']) + with environment_variable(OVERRIDE_THIS='nonempty'): + assert yes(env_var_override='OVERRIDE_THIS', input=input) + with environment_variable(OVERRIDE_THIS=None): # env not set + assert not yes(env_var_override='OVERRIDE_THIS', input=input) + + +def test_yes_defaults(): + input = FakeInputs(['invalid', '', ' ']) + assert not yes(input=input) # default=False + assert not yes(input=input) + assert not yes(input=input) + input = FakeInputs(['invalid', '', ' ']) + assert yes(default=True, input=input) + assert yes(default=True, input=input) + assert yes(default=True, input=input) + ifile = StringIO() + assert yes(default_notty=True, ifile=ifile) + assert not yes(default_notty=False, ifile=ifile) + input = FakeInputs([]) + assert yes(default_eof=True, input=input) + assert not yes(default_eof=False, input=input) + with pytest.raises(ValueError): + yes(default=None) + with pytest.raises(ValueError): + yes(default_notty='invalid') + with pytest.raises(ValueError): + yes(default_eof='invalid') + + +def test_yes_retry(): + input = FakeInputs(['foo', 'bar', 'y', ]) + assert yes(retry_msg='Retry: ', input=input) + input = FakeInputs(['foo', 'bar', 'N', ]) + assert not yes(retry_msg='Retry: ', input=input) + + +def test_yes_output(capfd): + input = FakeInputs(['invalid', 'y', 'n']) + assert yes(msg='intro-msg', false_msg='false-msg', true_msg='true-msg', retry_msg='retry-msg', input=input) + out, err = capfd.readouterr() + assert out == '' + assert 'intro-msg' in err + assert 'retry-msg' in err + assert 'true-msg' in err + assert not yes(msg='intro-msg', false_msg='false-msg', true_msg='true-msg', retry_msg='retry-msg', input=input) + out, err = capfd.readouterr() + assert out == '' + assert 'intro-msg' in err + assert 'retry-msg' not in err + assert 'false-msg' in err