1
0
Fork 0
mirror of https://github.com/borgbackup/borg.git synced 2024-12-26 01:37:20 +00:00

refactor buffer code into helpers.Buffer class, add tests

This commit is contained in:
Thomas Waldmann 2016-08-12 19:34:29 +02:00
parent 00449ad7b0
commit ef9e8a584b
4 changed files with 112 additions and 23 deletions

View file

@ -10,6 +10,7 @@
from shutil import get_terminal_size from shutil import get_terminal_size
import sys import sys
import platform import platform
import threading
import time import time
import unicodedata import unicodedata
import io import io
@ -655,6 +656,47 @@ def decorated_function(*args):
return decorated_function return decorated_function
class Buffer:
"""
provide a thread-local buffer
"""
def __init__(self, allocator, size=4096, limit=None):
"""
Initialize the buffer: use allocator(size) call to allocate a buffer.
Optionally, set the upper <limit> for the buffer size.
"""
assert callable(allocator), 'must give alloc(size) function as first param'
assert limit is None or size <= limit, 'initial size must be <= limit'
self._thread_local = threading.local()
self.allocator = allocator
self.limit = limit
self.resize(size, init=True)
def __len__(self):
return len(self._thread_local.buffer)
def resize(self, size, init=False):
"""
resize the buffer - to avoid frequent reallocation, we usually always grow (if needed).
giving init=True it is possible to first-time initialize or shrink the buffer.
if a buffer size beyond the limit is requested, raise ValueError.
"""
size = int(size)
if self.limit is not None and size > self.limit:
raise ValueError('Requested buffer size %d is above the limit of %d.' % (size, self.limit))
if init or len(self) < size:
self._thread_local.buffer = self.allocator(size)
def get(self, size=None, init=False):
"""
return a buffer of at least the requested size (None: any current size).
init=True can be given to trigger shrinking of the buffer to the given size.
"""
if size is not None:
self.resize(size, init)
return self._thread_local.buffer
@memoize @memoize
def uid2user(uid, default=None): def uid2user(uid, default=None):
try: try:

View file

@ -15,7 +15,8 @@
yes, TRUISH, FALSISH, DEFAULTISH, \ yes, TRUISH, FALSISH, DEFAULTISH, \
StableDict, int_to_bigint, bigint_to_int, parse_timestamp, CompressionSpec, ChunkerParams, \ StableDict, int_to_bigint, bigint_to_int, parse_timestamp, CompressionSpec, ChunkerParams, \
ProgressIndicatorPercent, ProgressIndicatorEndless, load_excludes, parse_pattern, \ ProgressIndicatorPercent, ProgressIndicatorEndless, load_excludes, parse_pattern, \
PatternMatcher, RegexPattern, PathPrefixPattern, FnmatchPattern, ShellPattern PatternMatcher, RegexPattern, PathPrefixPattern, FnmatchPattern, ShellPattern, \
Buffer
from . import BaseTestCase, environment_variable, FakeInputs from . import BaseTestCase, environment_variable, FakeInputs
@ -714,6 +715,61 @@ def test_is_slow_msgpack():
assert not is_slow_msgpack() assert not is_slow_msgpack()
class TestBuffer:
def test_type(self):
buffer = Buffer(bytearray)
assert isinstance(buffer.get(), bytearray)
buffer = Buffer(bytes) # don't do that in practice
assert isinstance(buffer.get(), bytes)
def test_len(self):
buffer = Buffer(bytearray, size=0)
b = buffer.get()
assert len(buffer) == len(b) == 0
buffer = Buffer(bytearray, size=1234)
b = buffer.get()
assert len(buffer) == len(b) == 1234
def test_resize(self):
buffer = Buffer(bytearray, size=100)
assert len(buffer) == 100
b1 = buffer.get()
buffer.resize(200)
assert len(buffer) == 200
b2 = buffer.get()
assert b2 is not b1 # new, bigger buffer
buffer.resize(100)
assert len(buffer) >= 100
b3 = buffer.get()
assert b3 is b2 # still same buffer (200)
buffer.resize(100, init=True)
assert len(buffer) == 100 # except on init
b4 = buffer.get()
assert b4 is not b3 # new, smaller buffer
def test_limit(self):
buffer = Buffer(bytearray, size=100, limit=200)
buffer.resize(200)
assert len(buffer) == 200
with pytest.raises(ValueError):
buffer.resize(201)
assert len(buffer) == 200
def test_get(self):
buffer = Buffer(bytearray, size=100, limit=200)
b1 = buffer.get(50)
assert len(b1) >= 50 # == 100
b2 = buffer.get(100)
assert len(b2) >= 100 # == 100
assert b2 is b1 # did not need resizing yet
b3 = buffer.get(200)
assert len(b3) == 200
assert b3 is not b2 # new, resized buffer
with pytest.raises(ValueError):
buffer.get(201) # beyond limit
assert len(buffer) == 200
def test_yes_input(): def test_yes_input():
inputs = list(TRUISH) inputs = list(TRUISH)
input = FakeInputs(inputs) input = FakeInputs(inputs)

View file

@ -2,7 +2,7 @@
import tempfile import tempfile
import unittest import unittest
from ..xattr import is_enabled, getxattr, setxattr, listxattr, get_buffer from ..xattr import is_enabled, getxattr, setxattr, listxattr, buffer
from . import BaseTestCase from . import BaseTestCase
@ -41,20 +41,20 @@ def test(self):
def test_listxattr_buffer_growth(self): def test_listxattr_buffer_growth(self):
# make it work even with ext4, which imposes rather low limits # make it work even with ext4, which imposes rather low limits
get_buffer(size=64, init=True) buffer.resize(size=64, init=True)
# xattr raw key list will be size 9 * (10 + 1), which is > 64 # xattr raw key list will be size 9 * (10 + 1), which is > 64
keys = ['user.attr%d' % i for i in range(9)] keys = ['user.attr%d' % i for i in range(9)]
for key in keys: for key in keys:
setxattr(self.tmpfile.name, key, b'x') setxattr(self.tmpfile.name, key, b'x')
got_keys = listxattr(self.tmpfile.name) got_keys = listxattr(self.tmpfile.name)
self.assert_equal_se(got_keys, keys) self.assert_equal_se(got_keys, keys)
self.assert_equal(len(get_buffer()), 128) self.assert_equal(len(buffer), 128)
def test_getxattr_buffer_growth(self): def test_getxattr_buffer_growth(self):
# make it work even with ext4, which imposes rather low limits # make it work even with ext4, which imposes rather low limits
get_buffer(size=64, init=True) buffer.resize(size=64, init=True)
value = b'x' * 126 value = b'x' * 126
setxattr(self.tmpfile.name, 'user.big', value) setxattr(self.tmpfile.name, 'user.big', value)
got_value = getxattr(self.tmpfile.name, 'user.big') got_value = getxattr(self.tmpfile.name, 'user.big')
self.assert_equal(value, got_value) self.assert_equal(value, got_value)
self.assert_equal(len(get_buffer()), 128) self.assert_equal(len(buffer), 128)

View file

@ -6,11 +6,12 @@
import subprocess import subprocess
import sys import sys
import tempfile import tempfile
import threading
from ctypes import CDLL, create_string_buffer, c_ssize_t, c_size_t, c_char_p, c_int, c_uint32, get_errno from ctypes import CDLL, create_string_buffer, c_ssize_t, c_size_t, c_char_p, c_int, c_uint32, get_errno
from ctypes.util import find_library from ctypes.util import find_library
from distutils.version import LooseVersion from distutils.version import LooseVersion
from .helpers import Buffer
from .logger import create_logger from .logger import create_logger
logger = create_logger() logger = create_logger()
@ -22,17 +23,7 @@
ENOATTR = errno.ENODATA ENOATTR = errno.ENODATA
def get_buffer(size=None, init=False): buffer = Buffer(create_string_buffer, limit=2**24)
if size is not None:
size = int(size)
assert size < 2 ** 24
if init or len(thread_local.buffer) < size:
thread_local.buffer = create_string_buffer(size)
return thread_local.buffer
thread_local = threading.local()
get_buffer(size=4096, init=True)
def is_enabled(path=None): def is_enabled(path=None):
@ -144,7 +135,7 @@ def _check(rv, path=None, detect_buffer_too_small=False):
if isinstance(path, int): if isinstance(path, int):
path = '<FD %d>' % path path = '<FD %d>' % path
raise OSError(e, msg, path) raise OSError(e, msg, path)
if detect_buffer_too_small and rv >= len(get_buffer()): if detect_buffer_too_small and rv >= len(buffer):
# freebsd does not error with ERANGE if the buffer is too small, # freebsd does not error with ERANGE if the buffer is too small,
# it just fills the buffer, truncates and returns. # it just fills the buffer, truncates and returns.
# so, we play sure and just assume that result is truncated if # so, we play sure and just assume that result is truncated if
@ -156,9 +147,9 @@ def _check(rv, path=None, detect_buffer_too_small=False):
def _listxattr_inner(func, path): def _listxattr_inner(func, path):
if isinstance(path, str): if isinstance(path, str):
path = os.fsencode(path) path = os.fsencode(path)
size = len(get_buffer()) size = len(buffer)
while True: while True:
buf = get_buffer(size) buf = buffer.get(size)
try: try:
n = _check(func(path, buf, size), path, detect_buffer_too_small=True) n = _check(func(path, buf, size), path, detect_buffer_too_small=True)
except BufferTooSmallError: except BufferTooSmallError:
@ -171,9 +162,9 @@ def _getxattr_inner(func, path, name):
if isinstance(path, str): if isinstance(path, str):
path = os.fsencode(path) path = os.fsencode(path)
name = os.fsencode(name) name = os.fsencode(name)
size = len(get_buffer()) size = len(buffer)
while True: while True:
buf = get_buffer(size) buf = buffer.get(size)
try: try:
n = _check(func(path, name, buf, size), path, detect_buffer_too_small=True) n = _check(func(path, name, buf, size), path, detect_buffer_too_small=True)
except BufferTooSmallError: except BufferTooSmallError: