Mirror of https://github.com/morpheus65535/bazarr

Upgraded vendored Python dependencies to the latest versions and removed unused dependencies.

This commit is contained in:
parent 36bf0d219d
commit 0c3c5a02a7
@@ -52,7 +52,10 @@ def refine_from_ffprobe(path, video):
        if isinstance(data['ffprobe']['video'][0]['frame_rate'], float):
            video.fps = data['ffprobe']['video'][0]['frame_rate']
        else:
            video.fps = data['ffprobe']['video'][0]['frame_rate'].magnitude
        try:
            video.fps = data['ffprobe']['video'][0]['frame_rate'].magnitude
        except AttributeError:
            video.fps = data['ffprobe']['video'][0]['frame_rate']

    if 'audio' not in data['ffprobe']:
        logging.debug('BAZARR FFprobe was unable to find audio tracks in the file!')
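Note: the change above swaps a type check for an EAFP-style fallback. The frame rate coming back from ffprobe may be either a plain float or a pint-like quantity exposing .magnitude, and try/except covers both without naming the type. A minimal sketch (the Quantity class below is a hypothetical stand-in, not Bazarr's actual type):

# Hypothetical stand-in for the quantity object the ffprobe refiner may return.
class Quantity:
    def __init__(self, magnitude):
        self.magnitude = magnitude

def extract_fps(frame_rate):
    # EAFP: prefer .magnitude, fall back to the raw float.
    try:
        return frame_rate.magnitude
    except AttributeError:
        return frame_rate

assert extract_fps(23.976) == 23.976
assert extract_fps(Quantity(25.0)) == 25.0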
@@ -184,9 +184,6 @@ def init_binaries():
        except Exception:
            logging.debug("custom check failed for: %s", exe)

        rarfile.OPEN_ARGS = rarfile.ORIG_OPEN_ARGS
        rarfile.EXTRACT_ARGS = rarfile.ORIG_EXTRACT_ARGS
        rarfile.TEST_ARGS = rarfile.ORIG_TEST_ARGS
        logging.debug("Using UnRAR from: %s", exe)
        unrar = exe
@@ -7,6 +7,8 @@ import platform
import warnings

from logging.handlers import TimedRotatingFileHandler
from pytz_deprecation_shim import PytzUsageWarning

from get_args import args
from config import settings
@@ -55,6 +57,7 @@ class NoExceptionFormatter(logging.Formatter):

def configure_logging(debug=False):
    warnings.simplefilter('ignore', category=ResourceWarning)
    warnings.simplefilter('ignore', category=PytzUsageWarning)

    if not debug:
        log_level = "INFO"
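For context on the simplefilter() call added above: a warnings filter only suppresses warnings emitted after it is installed, which is why configure_logging() registers it up front. A minimal sketch of the effect (assuming pytz_deprecation_shim is importable, as in the diff):

import warnings
from pytz_deprecation_shim import PytzUsageWarning

warnings.simplefilter('ignore', category=PytzUsageWarning)
# Any shim warning raised later in the process is now dropped silently:
warnings.warn("tzinfo usage is deprecated", PytzUsageWarning)  # no output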
@@ -1,16 +0,0 @@
try:
    import ast
    from _markerlib.markers import default_environment, compile, interpret
except ImportError:
    if 'ast' in globals():
        raise
    def default_environment():
        return {}
    def compile(marker):
        def marker_fn(environment=None, override=None):
            # 'empty markers are True' heuristic won't install extra deps.
            return not marker.strip()
        marker_fn.__doc__ = marker
        return marker_fn
    def interpret(marker, environment=None, override=None):
        return compile(marker)()
@@ -1,119 +0,0 @@
# -*- coding: utf-8 -*-
"""Interpret PEP 345 environment markers.

EXPR [in|==|!=|not in] EXPR [or|and] ...

where EXPR belongs to any of those:

    python_version = '%s.%s' % (sys.version_info[0], sys.version_info[1])
    python_full_version = sys.version.split()[0]
    os.name = os.name
    sys.platform = sys.platform
    platform.version = platform.version()
    platform.machine = platform.machine()
    platform.python_implementation = platform.python_implementation()
    a free string, like '2.6', or 'win32'
"""

__all__ = ['default_environment', 'compile', 'interpret']

import ast
import os
import platform
import sys
import weakref

_builtin_compile = compile

try:
    from platform import python_implementation
except ImportError:
    if os.name == "java":
        # Jython 2.5 has ast module, but not platform.python_implementation() function.
        def python_implementation():
            return "Jython"
    else:
        raise


# restricted set of variables
_VARS = {'sys.platform': sys.platform,
         'python_version': '%s.%s' % sys.version_info[:2],
         # FIXME parsing sys.platform is not reliable, but there is no other
         # way to get e.g. 2.7.2+, and the PEP is defined with sys.version
         'python_full_version': sys.version.split(' ', 1)[0],
         'os.name': os.name,
         'platform.version': platform.version(),
         'platform.machine': platform.machine(),
         'platform.python_implementation': python_implementation(),
         'extra': None  # wheel extension
        }

for var in list(_VARS.keys()):
    if '.' in var:
        _VARS[var.replace('.', '_')] = _VARS[var]

def default_environment():
    """Return copy of default PEP 345 globals dictionary."""
    return dict(_VARS)

class ASTWhitelist(ast.NodeTransformer):
    def __init__(self, statement):
        self.statement = statement  # for error messages

    ALLOWED = (ast.Compare, ast.BoolOp, ast.Attribute, ast.Name, ast.Load, ast.Str)
    # Bool operations
    ALLOWED += (ast.And, ast.Or)
    # Comparison operations
    ALLOWED += (ast.Eq, ast.Gt, ast.GtE, ast.In, ast.Is, ast.IsNot, ast.Lt, ast.LtE, ast.NotEq, ast.NotIn)

    def visit(self, node):
        """Ensure statement only contains allowed nodes."""
        if not isinstance(node, self.ALLOWED):
            raise SyntaxError('Not allowed in environment markers.\n%s\n%s' %
                              (self.statement,
                               (' ' * node.col_offset) + '^'))
        return ast.NodeTransformer.visit(self, node)

    def visit_Attribute(self, node):
        """Flatten one level of attribute access."""
        new_node = ast.Name("%s.%s" % (node.value.id, node.attr), node.ctx)
        return ast.copy_location(new_node, node)

def parse_marker(marker):
    tree = ast.parse(marker, mode='eval')
    new_tree = ASTWhitelist(marker).generic_visit(tree)
    return new_tree

def compile_marker(parsed_marker):
    return _builtin_compile(parsed_marker, '<environment marker>', 'eval',
                            dont_inherit=True)

_cache = weakref.WeakValueDictionary()

def compile(marker):
    """Return compiled marker as a function accepting an environment dict."""
    try:
        return _cache[marker]
    except KeyError:
        pass
    if not marker.strip():
        def marker_fn(environment=None, override=None):
            """"""
            return True
    else:
        compiled_marker = compile_marker(parse_marker(marker))
        def marker_fn(environment=None, override=None):
            """override updates environment"""
            if override is None:
                override = {}
            if environment is None:
                environment = default_environment()
            environment.update(override)
            return eval(compiled_marker, environment)
    marker_fn.__doc__ = marker
    _cache[marker] = marker_fn
    return _cache[marker]

def interpret(marker, environment=None):
    return compile(marker)(environment)
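A brief usage sketch of the module deleted above, assuming the vendored package was importable as _markerlib (as the fallback shim earlier in this diff suggests); printed values depend on the interpreter running it:

from _markerlib.markers import interpret, default_environment

env = default_environment()                       # e.g. env['python_version'] == '3.8'
print(interpret("python_version >= '2.6'"))       # True on any modern interpreter
print(interpret("sys.platform == 'win32'", env))  # True only on Windows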
@@ -13,8 +13,8 @@ See <http://github.com/ActiveState/appdirs> for details and usage.
# - Mac OS X: http://developer.apple.com/documentation/MacOSX/Conceptual/BPFileSystem/index.html
# - XDG spec for Un*x: http://standards.freedesktop.org/basedir-spec/basedir-spec-latest.html

__version_info__ = (1, 4, 3)
__version__ = '.'.join(map(str, __version_info__))
__version__ = "1.4.4"
__version_info__ = tuple(int(segment) for segment in __version__.split("."))


import sys
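The 1.4.4 bump also reverses the derivation: the version string is now the source of truth and the tuple is computed from it, so the two can no longer drift apart. A quick check of the new expression:

__version__ = "1.4.4"
__version_info__ = tuple(int(segment) for segment in __version__.split("."))
assert __version_info__ == (1, 4, 4)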
@@ -98,7 +98,7 @@ def user_data_dir(appname=None, appauthor=None, version=None, roaming=False):


def site_data_dir(appname=None, appauthor=None, version=None, multipath=False):
    """Return full path to the user-shared data dir for this application.
    r"""Return full path to the user-shared data dir for this application.

    "appname" is the name of application.
        If None, just the system directory is returned.
@@ -204,7 +204,7 @@ def user_config_dir(appname=None, appauthor=None, version=None, roaming=False):


def site_config_dir(appname=None, appauthor=None, version=None, multipath=False):
    """Return full path to the user-shared data dir for this application.
    r"""Return full path to the user-shared data dir for this application.

    "appname" is the name of application.
        If None, just the system directory is returned.
@@ -1,116 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) 2014-2016 The arghelper developers. All rights reserved.
# Project site: https://github.com/questrail/arghelper
# Use of this source code is governed by a MIT-style license that
# can be found in the LICENSE.txt file for the project.
"""Provide helper functions for argparse

"""

# Try to future proof code so that it's Python 3.x ready
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from __future__ import absolute_import

# Standard module imports
import argparse
import sys
import os


def extant_file(arg):
    """Facade for extant_item(arg, arg_type="file")
    """
    return extant_item(arg, "file")


def extant_dir(arg):
    """Facade for extant_item(arg, arg_type="directory")
    """
    return extant_item(arg, "directory")


def extant_item(arg, arg_type):
    """Determine if parser argument is an existing file or directory.

    This technique comes from http://stackoverflow.com/a/11541450/95592
    and from http://stackoverflow.com/a/11541495/95592

    Args:
        arg: parser argument containing filename to be checked
        arg_type: string of either "file" or "directory"

    Returns:
        If the file exists, return the filename or directory.

    Raises:
        If the file does not exist, raise a parser error.
    """
    if arg_type == "file":
        if not os.path.isfile(arg):
            raise argparse.ArgumentError(
                None,
                "The file {arg} does not exist.".format(arg=arg))
        else:
            # File exists so return the filename
            return arg
    elif arg_type == "directory":
        if not os.path.isdir(arg):
            raise argparse.ArgumentError(
                None,
                "The directory {arg} does not exist.".format(arg=arg))
        else:
            # Directory exists so return the directory name
            return arg


def parse_config_input_output(args=sys.argv):
    """Parse the args using the config_file, input_dir, output_dir pattern

    Args:
        args: sys.argv

    Returns:
        The populated namespace object from parser.parse_args().

    Raises:
        TBD
    """
    parser = argparse.ArgumentParser(
        description='Process the input files using the given config')
    parser.add_argument(
        'config_file',
        help='Configuration file.',
        metavar='FILE', type=extant_file)
    parser.add_argument(
        'input_dir',
        help='Directory containing the input files.',
        metavar='DIR', type=extant_dir)
    parser.add_argument(
        'output_dir',
        help='Directory where the output files should be saved.',
        metavar='DIR', type=extant_dir)
    return parser.parse_args(args[1:])


def parse_config(args=sys.argv):
    """Parse the args using the config_file pattern

    Args:
        args: sys.argv

    Returns:
        The populated namespace object from parser.parse_args().

    Raises:
        TBD
    """
    parser = argparse.ArgumentParser(
        description='Read in the config file')
    parser.add_argument(
        'config_file',
        help='Configuration file.',
        metavar='FILE', type=extant_file)
    return parser.parse_args(args[1:])
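For reference, a minimal sketch of how the deleted helpers were wired into a parser (the file and directory names are hypothetical):

import argparse
from arghelper import extant_file, extant_dir  # the module removed above

parser = argparse.ArgumentParser()
parser.add_argument('config_file', metavar='FILE', type=extant_file)
parser.add_argument('input_dir', metavar='DIR', type=extant_dir)
# Exits with a parser error unless both paths actually exist:
args = parser.parse_args(['settings.ini', '/tmp'])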
@@ -1,61 +0,0 @@
# Copyright 2013 Dean Gardiner <gardiner91@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from asio.file import SEEK_ORIGIN_CURRENT
from asio.file_opener import FileOpener
from asio.open_parameters import OpenParameters
from asio.interfaces.posix import PosixInterface
from asio.interfaces.windows import WindowsInterface

import os


class ASIO(object):
    platform_handler = None

    @classmethod
    def get_handler(cls):
        if cls.platform_handler:
            return cls.platform_handler

        if os.name == 'nt':
            cls.platform_handler = WindowsInterface
        elif os.name == 'posix':
            cls.platform_handler = PosixInterface
        else:
            raise NotImplementedError()

        return cls.platform_handler

    @classmethod
    def open(cls, file_path, opener=True, parameters=None):
        """Open file

        :type file_path: str

        :param opener: Use FileOpener, for use with the 'with' statement
        :type opener: bool

        :rtype: asio.file.File
        """
        if not parameters:
            parameters = OpenParameters()

        if opener:
            return FileOpener(file_path, parameters)

        return ASIO.get_handler().open(
            file_path,
            parameters=parameters.handlers.get(ASIO.get_handler())
        )
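The intended call pattern for the deleted class, sketched below (the path is hypothetical): with opener=True, the default, ASIO.open returns a FileOpener, so it works with the with statement, and the POSIX or Windows interface is chosen from os.name.

from asio import ASIO

with ASIO.open('/var/log/app.log') as f:  # FileOpener handles close()
    data = f.read(4096)                   # delegated to the platform interface
    print(f.get_path())                   # resolves the path back from the descriptor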
@@ -1,92 +0,0 @@
# Copyright 2013 Dean Gardiner <gardiner91@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from io import RawIOBase
import time

DEFAULT_BUFFER_SIZE = 4096

SEEK_ORIGIN_BEGIN = 0
SEEK_ORIGIN_CURRENT = 1
SEEK_ORIGIN_END = 2


class ReadTimeoutError(Exception):
    pass


class File(RawIOBase):
    platform_handler = None

    def __init__(self, *args, **kwargs):
        super(File, self).__init__(*args, **kwargs)

    def get_handler(self):
        """
        :rtype: asio.interfaces.base.Interface
        """
        if not self.platform_handler:
            raise ValueError()

        return self.platform_handler

    def get_size(self):
        """Get the current file size

        :rtype: int
        """
        return self.get_handler().get_size(self)

    def get_path(self):
        """Get the path of this file

        :rtype: str
        """
        return self.get_handler().get_path(self)

    def seek(self, offset, origin):
        """Sets a reference point of a file to the given value.

        :param offset: The point relative to origin to move
        :type offset: int

        :param origin: Reference point to seek (SEEK_ORIGIN_BEGIN, SEEK_ORIGIN_CURRENT, SEEK_ORIGIN_END)
        :type origin: int
        """
        return self.get_handler().seek(self, offset, origin)

    def read(self, n=-1):
        """Read up to n bytes from the object and return them.

        :type n: int
        :rtype: str
        """
        return self.get_handler().read(self, n)

    def readinto(self, b):
        """Read up to len(b) bytes into bytearray b and return the number of bytes read."""
        data = self.read(len(b))

        if data is None:
            return None

        b[:len(data)] = data
        return len(data)

    def close(self):
        """Close the file handle"""
        return self.get_handler().close(self)

    def readable(self, *args, **kwargs):
        return True
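File.readinto above adapts the handler's read() to the io.RawIOBase contract: copy into the caller's buffer and report the byte count, with None signalling that no data is available yet. The logic in isolation:

def readinto(buf, data):
    # 'data' stands in for self.read(len(buf)) from the deleted class.
    if data is None:
        return None
    buf[:len(data)] = data
    return len(data)

b = bytearray(8)
assert readinto(b, b"abc") == 3 and bytes(b[:3]) == b"abc"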
@@ -1,21 +0,0 @@
class FileOpener(object):
    def __init__(self, file_path, parameters=None):
        self.file_path = file_path
        self.parameters = parameters

        self.file = None

    def __enter__(self):
        self.file = ASIO.get_handler().open(
            self.file_path,
            self.parameters.handlers.get(ASIO.get_handler())
        )

        return self.file

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self.file:
            return

        self.file.close()
        self.file = None
@@ -1,41 +0,0 @@
# Copyright 2013 Dean Gardiner <gardiner91@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from asio.file import DEFAULT_BUFFER_SIZE


class Interface(object):
    @classmethod
    def open(cls, file_path, parameters=None):
        raise NotImplementedError()

    @classmethod
    def get_size(cls, fp):
        raise NotImplementedError()

    @classmethod
    def get_path(cls, fp):
        raise NotImplementedError()

    @classmethod
    def seek(cls, fp, pointer, distance):
        raise NotImplementedError()

    @classmethod
    def read(cls, fp, n=DEFAULT_BUFFER_SIZE):
        raise NotImplementedError()

    @classmethod
    def close(cls, fp):
        raise NotImplementedError()
@@ -1,123 +0,0 @@
# Copyright 2013 Dean Gardiner <gardiner91@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from asio.file import File, DEFAULT_BUFFER_SIZE
from asio.interfaces.base import Interface

import sys
import os

if os.name == 'posix':
    import select

    # fcntl is only required on darwin
    if sys.platform == 'darwin':
        import fcntl

F_GETPATH = 50


class PosixInterface(Interface):
    @classmethod
    def open(cls, file_path, parameters=None):
        """
        :type file_path: str
        :rtype: asio.interfaces.posix.PosixFile
        """
        if not parameters:
            parameters = {}

        if not parameters.get('mode'):
            parameters.pop('mode')

        if not parameters.get('buffering'):
            parameters.pop('buffering')

        fd = os.open(file_path, os.O_RDONLY | os.O_NONBLOCK)

        return PosixFile(fd)

    @classmethod
    def get_size(cls, fp):
        """
        :type fp: asio.interfaces.posix.PosixFile
        :rtype: int
        """
        return os.fstat(fp.fd).st_size

    @classmethod
    def get_path(cls, fp):
        """
        :type fp: asio.interfaces.posix.PosixFile
        :rtype: int
        """

        # readlink /dev/fd fails on darwin, so instead use fcntl F_GETPATH
        if sys.platform == 'darwin':
            return fcntl.fcntl(fp.fd, F_GETPATH, '\0' * 1024).rstrip('\0')

        # Use /proc/self/fd if available
        if os.path.lexists("/proc/self/fd/"):
            return os.readlink("/proc/self/fd/%s" % fp.fd)

        # Fallback to /dev/fd
        if os.path.lexists("/dev/fd/"):
            return os.readlink("/dev/fd/%s" % fp.fd)

        raise NotImplementedError('Environment not supported (fdescfs not mounted?)')

    @classmethod
    def seek(cls, fp, offset, origin):
        """
        :type fp: asio.interfaces.posix.PosixFile
        :type offset: int
        :type origin: int
        """
        os.lseek(fp.fd, offset, origin)

    @classmethod
    def read(cls, fp, n=DEFAULT_BUFFER_SIZE):
        """
        :type fp: asio.interfaces.posix.PosixFile
        :type n: int
        :rtype: str
        """
        r, w, x = select.select([fp.fd], [], [], 5)

        if r:
            return os.read(fp.fd, n)

        return None

    @classmethod
    def close(cls, fp):
        """
        :type fp: asio.interfaces.posix.PosixFile
        """
        os.close(fp.fd)


class PosixFile(File):
    platform_handler = PosixInterface

    def __init__(self, fd, *args, **kwargs):
        """
        :type fd: asio.file.File
        """
        super(PosixFile, self).__init__(*args, **kwargs)

        self.fd = fd

    def __str__(self):
        return "<asio_posix.PosixFile file: %s>" % self.fd
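The path-recovery trick in PosixInterface.get_path is runnable on its own wherever procfs is mounted: a file descriptor maps back to its path through /proc/self/fd.

import os

fd = os.open("/etc/hostname", os.O_RDONLY)
link = "/proc/self/fd/%d" % fd
if os.path.lexists(link):
    print(os.readlink(link))  # -> /etc/hostname
os.close(fd)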
@@ -1,201 +0,0 @@
# Copyright 2013 Dean Gardiner <gardiner91@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from asio.file import File, DEFAULT_BUFFER_SIZE
from asio.interfaces.base import Interface

import os


NULL = 0

if os.name == 'nt':
    from asio.interfaces.windows.interop import WindowsInterop


class WindowsInterface(Interface):
    @classmethod
    def open(cls, file_path, parameters=None):
        """
        :type file_path: str
        :rtype: asio.interfaces.windows.WindowsFile
        """
        if not parameters:
            parameters = {}

        return WindowsFile(WindowsInterop.create_file(
            file_path,
            parameters.get('desired_access', WindowsInterface.GenericAccess.READ),
            parameters.get('share_mode', WindowsInterface.ShareMode.ALL),
            parameters.get('creation_disposition', WindowsInterface.CreationDisposition.OPEN_EXISTING),
            parameters.get('flags_and_attributes', NULL)
        ))

    @classmethod
    def get_size(cls, fp):
        """
        :type fp: asio.interfaces.windows.WindowsFile
        :rtype: int
        """
        return WindowsInterop.get_file_size(fp.handle)

    @classmethod
    def get_path(cls, fp):
        """
        :type fp: asio.interfaces.windows.WindowsFile
        :rtype: str
        """

        if not fp.file_map:
            fp.file_map = WindowsInterop.create_file_mapping(fp.handle, WindowsInterface.Protection.READONLY)

        if not fp.map_view:
            fp.map_view = WindowsInterop.map_view_of_file(fp.file_map, WindowsInterface.FileMapAccess.READ, 1)

        file_name = WindowsInterop.get_mapped_file_name(fp.map_view)

        return file_name

    @classmethod
    def seek(cls, fp, offset, origin):
        """
        :type fp: asio.interfaces.windows.WindowsFile
        :type offset: int
        :type origin: int
        :rtype: int
        """

        return WindowsInterop.set_file_pointer(
            fp.handle,
            offset,
            origin
        )

    @classmethod
    def read(cls, fp, n=DEFAULT_BUFFER_SIZE):
        """
        :type fp: asio.interfaces.windows.WindowsFile
        :type n: int
        :rtype: str
        """
        return WindowsInterop.read(fp.handle, n)

    @classmethod
    def read_into(cls, fp, b):
        """
        :type fp: asio.interfaces.windows.WindowsFile
        :type b: str
        :rtype: int
        """
        return WindowsInterop.read_into(fp.handle, b)

    @classmethod
    def close(cls, fp):
        """
        :type fp: asio.interfaces.windows.WindowsFile
        :rtype: bool
        """
        if fp.map_view:
            WindowsInterop.unmap_view_of_file(fp.map_view)

        if fp.file_map:
            WindowsInterop.close_handle(fp.file_map)

        return bool(WindowsInterop.close_handle(fp.handle))

    class GenericAccess(object):
        READ = 0x80000000
        WRITE = 0x40000000
        EXECUTE = 0x20000000
        ALL = 0x10000000

    class ShareMode(object):
        READ = 0x00000001
        WRITE = 0x00000002
        DELETE = 0x00000004
        ALL = READ | WRITE | DELETE

    class CreationDisposition(object):
        CREATE_NEW = 1
        CREATE_ALWAYS = 2
        OPEN_EXISTING = 3
        OPEN_ALWAYS = 4
        TRUNCATE_EXISTING = 5

    class Attribute(object):
        READONLY = 0x00000001
        HIDDEN = 0x00000002
        SYSTEM = 0x00000004
        DIRECTORY = 0x00000010
        ARCHIVE = 0x00000020
        DEVICE = 0x00000040
        NORMAL = 0x00000080
        TEMPORARY = 0x00000100
        SPARSE_FILE = 0x00000200
        REPARSE_POINT = 0x00000400
        COMPRESSED = 0x00000800
        OFFLINE = 0x00001000
        NOT_CONTENT_INDEXED = 0x00002000
        ENCRYPTED = 0x00004000

    class Flag(object):
        WRITE_THROUGH = 0x80000000
        OVERLAPPED = 0x40000000
        NO_BUFFERING = 0x20000000
        RANDOM_ACCESS = 0x10000000
        SEQUENTIAL_SCAN = 0x08000000
        DELETE_ON_CLOSE = 0x04000000
        BACKUP_SEMANTICS = 0x02000000
        POSIX_SEMANTICS = 0x01000000
        OPEN_REPARSE_POINT = 0x00200000
        OPEN_NO_RECALL = 0x00100000
        FIRST_PIPE_INSTANCE = 0x00080000

    class Protection(object):
        NOACCESS = 0x01
        READONLY = 0x02
        READWRITE = 0x04
        WRITECOPY = 0x08
        EXECUTE = 0x10
        EXECUTE_READ = 0x20,
        EXECUTE_READWRITE = 0x40
        EXECUTE_WRITECOPY = 0x80
        GUARD = 0x100
        NOCACHE = 0x200
        WRITECOMBINE = 0x400

    class FileMapAccess(object):
        COPY = 0x0001
        WRITE = 0x0002
        READ = 0x0004
        ALL_ACCESS = 0x001f
        EXECUTE = 0x0020


class WindowsFile(File):
    platform_handler = WindowsInterface

    def __init__(self, handle, *args, **kwargs):
        super(WindowsFile, self).__init__(*args, **kwargs)

        self.handle = handle

        self.file_map = None
        self.map_view = None

    def readinto(self, b):
        return self.get_handler().read_into(self, b)

    def __str__(self):
        return "<asio_windows.WindowsFile file: %s>" % self.handle
@@ -1,230 +0,0 @@
# Copyright 2013 Dean Gardiner <gardiner91@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ctypes.wintypes import *
from ctypes import *
import logging

log = logging.getLogger(__name__)


CreateFileW = windll.kernel32.CreateFileW
CreateFileW.argtypes = (LPCWSTR, DWORD, DWORD, c_void_p, DWORD, DWORD, HANDLE)
CreateFileW.restype = HANDLE

ReadFile = windll.kernel32.ReadFile
ReadFile.argtypes = (HANDLE, c_void_p, DWORD, POINTER(DWORD), HANDLE)
ReadFile.restype = BOOL


NULL = 0
MAX_PATH = 260
DEFAULT_BUFFER_SIZE = 4096
LPSECURITY_ATTRIBUTES = c_void_p


class WindowsInterop(object):
    ri_buffer = None

    @classmethod
    def create_file(cls, path, desired_access, share_mode, creation_disposition, flags_and_attributes):
        h = CreateFileW(
            path,
            desired_access,
            share_mode,
            NULL,
            creation_disposition,
            flags_and_attributes,
            NULL
        )

        error = GetLastError()
        if error != 0:
            raise Exception('[WindowsASIO.open] "%s"' % FormatError(error))

        return h

    @classmethod
    def read(cls, handle, buf_size=DEFAULT_BUFFER_SIZE):
        buf = create_string_buffer(buf_size)
        bytes_read = c_ulong(0)

        success = ReadFile(handle, buf, buf_size, byref(bytes_read), NULL)

        error = GetLastError()
        if error:
            log.debug('read_file - error: (%s) "%s"', error, FormatError(error))

        if not success and error:
            raise Exception('[WindowsInterop.read_file] (%s) "%s"' % (error, FormatError(error)))

        # Return if we have a valid buffer
        if success and bytes_read.value:
            return buf.value

        return None

    @classmethod
    def read_into(cls, handle, b):
        if cls.ri_buffer is None or len(cls.ri_buffer) < len(b):
            cls.ri_buffer = create_string_buffer(len(b))

        bytes_read = c_ulong(0)

        success = ReadFile(handle, cls.ri_buffer, len(b), byref(bytes_read), NULL)
        bytes_read = int(bytes_read.value)

        b[:bytes_read] = cls.ri_buffer[:bytes_read]

        error = GetLastError()

        if not success and error:
            raise Exception('[WindowsInterop.read_file] (%s) "%s"' % (error, FormatError(error)))

        # Return if we have a valid buffer
        if success and bytes_read:
            return bytes_read

        return None

    @classmethod
    def set_file_pointer(cls, handle, distance, method):
        pos_high = DWORD(NULL)

        result = windll.kernel32.SetFilePointer(
            handle,
            c_ulong(distance),
            byref(pos_high),
            DWORD(method)
        )

        if result == -1:
            raise Exception('[WindowsASIO.seek] INVALID_SET_FILE_POINTER: "%s"' % FormatError(GetLastError()))

        return result

    @classmethod
    def get_file_size(cls, handle):
        return windll.kernel32.GetFileSize(
            handle,
            DWORD(NULL)
        )

    @classmethod
    def close_handle(cls, handle):
        return windll.kernel32.CloseHandle(handle)

    @classmethod
    def create_file_mapping(cls, handle, protect, maximum_size_high=0, maximum_size_low=1):
        return HANDLE(windll.kernel32.CreateFileMappingW(
            handle,
            LPSECURITY_ATTRIBUTES(NULL),
            DWORD(protect),
            DWORD(maximum_size_high),
            DWORD(maximum_size_low),
            LPCSTR(NULL)
        ))

    @classmethod
    def map_view_of_file(cls, map_handle, desired_access, num_bytes, file_offset_high=0, file_offset_low=0):
        return HANDLE(windll.kernel32.MapViewOfFile(
            map_handle,
            DWORD(desired_access),
            DWORD(file_offset_high),
            DWORD(file_offset_low),
            num_bytes
        ))

    @classmethod
    def unmap_view_of_file(cls, view_handle):
        return windll.kernel32.UnmapViewOfFile(view_handle)

    @classmethod
    def get_mapped_file_name(cls, view_handle, translate_device_name=True):
        buf = create_string_buffer(MAX_PATH + 1)

        result = windll.psapi.GetMappedFileNameW(
            cls.get_current_process(),
            view_handle,
            buf,
            MAX_PATH
        )

        # Raise exception on error
        error = GetLastError()
        if result == 0:
            raise Exception(FormatError(error))

        # Retrieve a clean file name (skipping over NUL bytes)
        file_name = cls.clean_buffer_value(buf)

        # If we are not translating the device name return here
        if not translate_device_name:
            return file_name

        drives = cls.get_logical_drive_strings()

        # Find the drive matching the file_name device name
        translated = False
        for drive in drives:
            device_name = cls.query_dos_device(drive)

            if file_name.startswith(device_name):
                file_name = drive + file_name[len(device_name):]
                translated = True
                break

        if not translated:
            raise Exception('Unable to translate device name')

        return file_name

    @classmethod
    def get_logical_drive_strings(cls, buf_size=512):
        buf = create_string_buffer(buf_size)

        result = windll.kernel32.GetLogicalDriveStringsW(buf_size, buf)

        error = GetLastError()
        if result == 0:
            raise Exception(FormatError(error))

        drive_strings = cls.clean_buffer_value(buf)
        return [dr for dr in drive_strings.split('\\') if dr != '']

    @classmethod
    def query_dos_device(cls, drive, buf_size=MAX_PATH):
        buf = create_string_buffer(buf_size)

        result = windll.kernel32.QueryDosDeviceA(
            drive,
            buf,
            buf_size
        )

        return cls.clean_buffer_value(buf)

    @classmethod
    def get_current_process(cls):
        return HANDLE(windll.kernel32.GetCurrentProcess())

    @classmethod
    def clean_buffer_value(cls, buf):
        value = ""

        for ch in buf.raw:
            if ord(ch) != 0:
                value += ch

        return value
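One caveat worth recording about the deleted interop layer: it reads GetLastError() after control has returned to Python, where the interpreter itself may have made intervening Windows API calls. ctypes documents use_last_error=True together with ctypes.get_last_error() as the reliable pattern; a hedged sketch of that alternative (not what the deleted code did):

import ctypes
from ctypes import wintypes  # Windows only

# Snapshot the error code atomically at call time instead of reading
# GetLastError() later, when it may already have been overwritten.
kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
kernel32.CreateFileW.restype = wintypes.HANDLE

INVALID_HANDLE_VALUE = wintypes.HANDLE(-1).value
handle = kernel32.CreateFileW("missing.bin", 0x80000000, 0, None, 3, 0, None)
if handle == INVALID_HANDLE_VALUE:
    raise ctypes.WinError(ctypes.get_last_error())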
@@ -1,47 +0,0 @@
from asio.interfaces.posix import PosixInterface
from asio.interfaces.windows import WindowsInterface


class OpenParameters(object):
    def __init__(self):
        self.handlers = {}

        # Update handler_parameters with defaults
        self.posix()
        self.windows()

    def posix(self, mode=None, buffering=None):
        """
        :type mode: str
        :type buffering: int
        """
        self.handlers.update({PosixInterface: {
            'mode': mode,
            'buffering': buffering
        }})

    def windows(self, desired_access=WindowsInterface.GenericAccess.READ,
                share_mode=WindowsInterface.ShareMode.ALL,
                creation_disposition=WindowsInterface.CreationDisposition.OPEN_EXISTING,
                flags_and_attributes=0):

        """
        :param desired_access: WindowsInterface.DesiredAccess
        :type desired_access: int

        :param share_mode: WindowsInterface.ShareMode
        :type share_mode: int

        :param creation_disposition: WindowsInterface.CreationDisposition
        :type creation_disposition: int

        :param flags_and_attributes: WindowsInterface.Attribute, WindowsInterface.Flag
        :type flags_and_attributes: int
        """

        self.handlers.update({WindowsInterface: {
            'desired_access': desired_access,
            'share_mode': share_mode,
            'creation_disposition': creation_disposition,
            'flags_and_attributes': flags_and_attributes
        }})
@@ -2,20 +2,16 @@
:author:

Amine SEHILI <amine.sehili@gmail.com>
2015-2016
2015-2021

:License:

This package is published under GNU GPL Version 3.
This package is published under the MIT license.
"""

from __future__ import absolute_import
from .core import *
from .io import *
from .util import *
from . import dataset
from .exceptions import *

__version__ = "0.1.5"


__version__ = "0.2.0"

File diff suppressed because it is too large.
@@ -0,0 +1,126 @@
import sys
import logging
from collections import namedtuple
from . import workers
from .util import AudioDataSource
from .io import player_for

_AUDITOK_LOGGER = "AUDITOK_LOGGER"
KeywordArguments = namedtuple(
    "KeywordArguments", ["io", "split", "miscellaneous"]
)


def make_kwargs(args_ns):
    if args_ns.save_stream is None:
        record = args_ns.plot or (args_ns.save_image is not None)
    else:
        record = False
    try:
        use_channel = int(args_ns.use_channel)
    except (ValueError, TypeError):
        use_channel = args_ns.use_channel

    io_kwargs = {
        "input": args_ns.input,
        "audio_format": args_ns.input_format,
        "max_read": args_ns.max_read,
        "block_dur": args_ns.analysis_window,
        "sampling_rate": args_ns.sampling_rate,
        "sample_width": args_ns.sample_width,
        "channels": args_ns.channels,
        "use_channel": use_channel,
        "save_stream": args_ns.save_stream,
        "save_detections_as": args_ns.save_detections_as,
        "export_format": args_ns.output_format,
        "large_file": args_ns.large_file,
        "frames_per_buffer": args_ns.frame_per_buffer,
        "input_device_index": args_ns.input_device_index,
        "record": record,
    }

    split_kwargs = {
        "min_dur": args_ns.min_duration,
        "max_dur": args_ns.max_duration,
        "max_silence": args_ns.max_silence,
        "drop_trailing_silence": args_ns.drop_trailing_silence,
        "strict_min_dur": args_ns.strict_min_duration,
        "energy_threshold": args_ns.energy_threshold,
    }

    miscellaneous = {
        "echo": args_ns.echo,
        "progress_bar": args_ns.progress_bar,
        "command": args_ns.command,
        "quiet": args_ns.quiet,
        "printf": args_ns.printf,
        "time_format": args_ns.time_format,
        "timestamp_format": args_ns.timestamp_format,
    }
    return KeywordArguments(io_kwargs, split_kwargs, miscellaneous)


def make_logger(stderr=False, file=None, name=_AUDITOK_LOGGER):
    if not stderr and file is None:
        return None
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    if stderr:
        handler = logging.StreamHandler(sys.stderr)
        handler.setLevel(logging.INFO)
        logger.addHandler(handler)

    if file is not None:
        handler = logging.FileHandler(file, "w")
        fmt = logging.Formatter("[%(asctime)s] | %(message)s")
        handler.setFormatter(fmt)
        handler.setLevel(logging.INFO)
        logger.addHandler(handler)
    return logger


def initialize_workers(logger=None, **kwargs):
    observers = []
    reader = AudioDataSource(source=kwargs["input"], **kwargs)
    if kwargs["save_stream"] is not None:
        reader = workers.StreamSaverWorker(
            reader,
            filename=kwargs["save_stream"],
            export_format=kwargs["export_format"],
        )
        reader.start()

    if kwargs["save_detections_as"] is not None:
        worker = workers.RegionSaverWorker(
            kwargs["save_detections_as"],
            kwargs["export_format"],
            logger=logger,
        )
        observers.append(worker)

    if kwargs["echo"]:
        player = player_for(reader)
        worker = workers.PlayerWorker(
            player, progress_bar=kwargs["progress_bar"], logger=logger
        )
        observers.append(worker)

    if kwargs["command"] is not None:
        worker = workers.CommandLineWorker(
            command=kwargs["command"], logger=logger
        )
        observers.append(worker)

    if not kwargs["quiet"]:
        print_format = (
            kwargs["printf"]
            .replace("\\n", "\n")
            .replace("\\t", "\t")
            .replace("\\r", "\r")
        )
        worker = workers.PrintWorker(
            print_format, kwargs["time_format"], kwargs["timestamp_format"]
        )
        observers.append(worker)

    return reader, observers
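A quick sketch of the new helper in use, as the auditok CLI presumably calls it (the log path is hypothetical); detections are mirrored both to stderr and to a file:

from auditok.cmdline_util import make_logger

logger = make_logger(stderr=True, file="/tmp/auditok.log")
logger.info("[DET]: Detection 1 (start: 0.000, end: 1.250, duration: 1.250)")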
libs/auditok/core.py: 1654 lines changed. File diff suppressed because it is too large.
@@ -1,19 +1,31 @@
"""
This module contains links to audio files you can use for test purposes.
This module contains links to audio files that can be used for test purposes.

.. autosummary::
    :toctree: generated/

    one_to_six_arabic_16000_mono_bc_noise
    was_der_mensch_saet_mono_44100_lead_trail_silence
"""

import os

__all__ = ["one_to_six_arabic_16000_mono_bc_noise", "was_der_mensch_saet_mono_44100_lead_trail_silence"]
__all__ = [
    "one_to_six_arabic_16000_mono_bc_noise",
    "was_der_mensch_saet_mono_44100_lead_trail_silence",
]

_current_dir = os.path.dirname(os.path.realpath(__file__))

one_to_six_arabic_16000_mono_bc_noise = "{cd}{sep}data{sep}1to6arabic_\
16000_mono_bc_noise.wav".format(cd=_current_dir, sep=os.path.sep)
16000_mono_bc_noise.wav".format(
    cd=_current_dir, sep=os.path.sep
)
"""A wave file that contains a pronunciation of Arabic numbers from 1 to 6"""


was_der_mensch_saet_mono_44100_lead_trail_silence = "{cd}{sep}data{sep}was_\
der_mensch_saet_das_wird_er_vielfach_ernten_44100Hz_mono_lead_trail_\
silence.wav".format(cd=_current_dir, sep=os.path.sep)
""" A wave file that contains a sentence between long leading and trailing periods of silence"""
silence.wav".format(
    cd=_current_dir, sep=os.path.sep
)
"""A wave file that contains a sentence with a long leading and trailing silence"""
@@ -1,9 +1,41 @@
"""
November 2015
@author: Amine SEHILI <amine.sehili@gmail.com>
"""

class DuplicateArgument(Exception):
    pass


class TooSamllBlockDuration(ValueError):
    """Raised when block_dur results in a block_size smaller than one sample."""

    def __init__(self, message, block_dur, sampling_rate):
        self.block_dur = block_dur
        self.sampling_rate = sampling_rate
        super(TooSamllBlockDuration, self).__init__(message)


class TimeFormatError(Exception):
    """Raised when a duration formatting directive is unknown."""


class EndOfProcessing(Exception):
    """Raised within command line script's main function to jump to
    postprocessing code."""


class AudioIOError(Exception):
    """Raised when a compressed audio file cannot be loaded or when trying
    to read from a not yet open AudioSource"""


class AudioParameterError(AudioIOError):
    """Raised when one audio parameter is missing when loading raw data or
    saving data to a format other than raw. Also raised when an audio
    parameter has a wrong value."""


class AudioEncodingError(Exception):
    """Raised if audio data can not be encoded in the provided format"""


class AudioEncodingWarning(RuntimeWarning):
    """Raised if audio data can not be encoded in the provided format
    but saved as wav.
    """
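Note that TooSamllBlockDuration keeps upstream auditok's own misspelling; renaming it here would break callers that catch it by name. A catch-site sketch with hypothetical values:

from auditok.exceptions import TooSamllBlockDuration  # upstream spelling, kept verbatim

try:
    raise TooSamllBlockDuration("block_dur too small", 1e-6, 16000)
except TooSamllBlockDuration as err:
    print(err.block_dur, err.sampling_rate)  # 1e-06 16000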
libs/auditok/io.py: 1270 lines changed. File diff suppressed because it is too large.
@@ -0,0 +1,150 @@
import matplotlib.pyplot as plt
import numpy as np

AUDITOK_PLOT_THEME = {
    "figure": {"facecolor": "#482a36", "alpha": 0.2},
    "plot": {"facecolor": "#282a36"},
    "energy_threshold": {
        "color": "#e31f8f",
        "linestyle": "--",
        "linewidth": 1,
    },
    "signal": {"color": "#40d970", "linestyle": "-", "linewidth": 1},
    "detections": {
        "facecolor": "#777777",
        "edgecolor": "#ff8c1a",
        "linewidth": 1,
        "alpha": 0.75,
    },
}


def _make_time_axis(nb_samples, sampling_rate):
    sample_duration = 1 / sampling_rate
    x = np.linspace(0, sample_duration * (nb_samples - 1), nb_samples)
    return x


def _plot_line(x, y, theme, xlabel=None, ylabel=None, **kwargs):
    color = theme.get("color", theme.get("c"))
    ls = theme.get("linestyle", theme.get("ls"))
    lw = theme.get("linewidth", theme.get("lw"))
    plt.plot(x, y, c=color, ls=ls, lw=lw, **kwargs)
    plt.xlabel(xlabel, fontsize=8)
    plt.ylabel(ylabel, fontsize=8)


def _plot_detections(subplot, detections, theme):
    fc = theme.get("facecolor", theme.get("fc"))
    ec = theme.get("edgecolor", theme.get("ec"))
    ls = theme.get("linestyle", theme.get("ls"))
    lw = theme.get("linewidth", theme.get("lw"))
    alpha = theme.get("alpha")
    for (start, end) in detections:
        subplot.axvspan(start, end, fc=fc, ec=ec, ls=ls, lw=lw, alpha=alpha)


def plot(
    audio_region,
    scale_signal=True,
    detections=None,
    energy_threshold=None,
    show=True,
    figsize=None,
    save_as=None,
    dpi=120,
    theme="auditok",
):
    y = np.asarray(audio_region)
    if len(y.shape) == 1:
        y = y.reshape(1, -1)
    nb_subplots, nb_samples = y.shape
    sampling_rate = audio_region.sampling_rate
    time_axis = _make_time_axis(nb_samples, sampling_rate)
    if energy_threshold is not None:
        eth_log10 = energy_threshold * np.log(10) / 10
        amplitude_threshold = np.sqrt(np.exp(eth_log10))
    else:
        amplitude_threshold = None
    if detections is None:
        detections = []
    else:
        # End of detection corresponds to the end of the last sample but
        # to stay compatible with the time axis of signal plotting we want end
        # of detection to correspond to the *start* of that last sample.
        detections = [
            (start, end - (1 / sampling_rate)) for (start, end) in detections
        ]
    if theme == "auditok":
        theme = AUDITOK_PLOT_THEME

    fig = plt.figure(figsize=figsize, dpi=dpi)
    fig_theme = theme.get("figure", theme.get("fig", {}))
    fig_fc = fig_theme.get("facecolor", fig_theme.get("ffc"))
    fig_alpha = fig_theme.get("alpha", 1)
    fig.patch.set_facecolor(fig_fc)
    fig.patch.set_alpha(fig_alpha)

    plot_theme = theme.get("plot", {})
    plot_fc = plot_theme.get("facecolor", plot_theme.get("pfc"))

    if nb_subplots > 2 and nb_subplots % 2 == 0:
        nb_rows = nb_subplots // 2
        nb_columns = 2
    else:
        nb_rows = nb_subplots
        nb_columns = 1

    for sid, samples in enumerate(y, 1):
        ax = fig.add_subplot(nb_rows, nb_columns, sid)
        ax.set_facecolor(plot_fc)
        if scale_signal:
            std = samples.std()
            if std > 0:
                mean = samples.mean()
                std = samples.std()
                samples = (samples - mean) / std
            max_ = samples.max()
            plt.ylim(-1.5 * max_, 1.5 * max_)
        if amplitude_threshold is not None:
            if scale_signal and std > 0:
                amp_th = (amplitude_threshold - mean) / std
            else:
                amp_th = amplitude_threshold
            eth_theme = theme.get("energy_threshold", theme.get("eth", {}))
            _plot_line(
                [time_axis[0], time_axis[-1]],
                [amp_th] * 2,
                eth_theme,
                label="Detection threshold",
            )
            if sid == 1:
                legend = plt.legend(
                    ["Detection threshold"],
                    facecolor=fig_fc,
                    framealpha=0.1,
                    bbox_to_anchor=(0.0, 1.15, 1.0, 0.102),
                    loc=2,
                )
                legend = plt.gca().add_artist(legend)

        signal_theme = theme.get("signal", {})
        _plot_line(
            time_axis,
            samples,
            signal_theme,
            xlabel="Time (seconds)",
            ylabel="Signal{}".format(" (scaled)" if scale_signal else ""),
        )
        detections_theme = theme.get("detections", {})
        _plot_detections(ax, detections, detections_theme)
        plt.title("Channel {}".format(sid), fontsize=10)

    plt.xticks(fontsize=8)
    plt.yticks(fontsize=8)
    plt.tight_layout()

    if save_as is not None:
        plt.savefig(save_as, dpi=dpi)
    if show:
        plt.show()
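The threshold conversion in plot() is a dB-to-linear-amplitude mapping: for an energy E in dB, exp(E * ln(10) / 10) equals the power ratio 10^(E/10), and its square root is the amplitude 10^(E/20). A numeric check:

import math

E = 50  # energy threshold in dB (hypothetical value)
amplitude = math.sqrt(math.exp(E * math.log(10) / 10))
assert abs(amplitude - 10 ** (E / 20)) < 1e-9  # both are ~316.23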
@@ -0,0 +1,179 @@
"""
Module for basic audio signal processing and array operations.

.. autosummary::
    :toctree: generated/

    to_array
    extract_single_channel
    compute_average_channel
    compute_average_channel_stereo
    separate_channels
    calculate_energy_single_channel
    calculate_energy_multichannel
"""
from array import array as array_
import audioop
import math

FORMAT = {1: "b", 2: "h", 4: "i"}
_EPSILON = 1e-10


def to_array(data, sample_width, channels):
    """Extract individual channels of audio data and return a list of arrays of
    numeric samples. This will always return a list of `array.array` objects
    (one per channel) even if audio data is mono.

    Parameters
    ----------
    data : bytes
        raw audio data.
    sample_width : int
        size in bytes of one audio sample (one channel considered).

    Returns
    -------
    samples_arrays : list
        list of arrays of audio samples.
    """
    fmt = FORMAT[sample_width]
    if channels == 1:
        return [array_(fmt, data)]
    return separate_channels(data, fmt, channels)


def extract_single_channel(data, fmt, channels, selected):
    samples = array_(fmt, data)
    return samples[selected::channels]


def compute_average_channel(data, fmt, channels):
    """
    Compute and return average channel of multi-channel audio data. If the
    number of channels is 2, use :func:`compute_average_channel_stereo` (much
    faster). This function uses the standard `array` module to convert `bytes`
    data into an array of numeric values.

    Parameters
    ----------
    data : bytes
        multi-channel audio data to mix down.
    fmt : str
        format (single character) to pass to `array.array` to convert `data`
        into an array of samples. This should be "b" if audio data's sample width
        is 1, "h" if it's 2 and "i" if it's 4.
    channels : int
        number of channels of audio data.

    Returns
    -------
    mono_audio : bytes
        mixed down audio data.
    """
    all_channels = array_(fmt, data)
    mono_channels = [
        array_(fmt, all_channels[ch::channels]) for ch in range(channels)
    ]
    avg_arr = array_(
        fmt,
        (round(sum(samples) / channels) for samples in zip(*mono_channels)),
    )
    return avg_arr


def compute_average_channel_stereo(data, sample_width):
    """Compute and return average channel of stereo audio data. This function
    should be used when the number of channels is exactly 2 because in that
    case we can use the standard `audioop` module, which is *much* faster than
    calling :func:`compute_average_channel`.

    Parameters
    ----------
    data : bytes
        2-channel audio data to mix down.
    sample_width : int
        size in bytes of one audio sample (one channel considered).

    Returns
    -------
    mono_audio : bytes
        mixed down audio data.
    """
    fmt = FORMAT[sample_width]
    arr = array_(fmt, audioop.tomono(data, sample_width, 0.5, 0.5))
    return arr


def separate_channels(data, fmt, channels):
    """Create a list of arrays of audio samples (`array.array` objects), one for
    each channel.

    Parameters
    ----------
    data : bytes
        multi-channel audio data to mix down.
    fmt : str
        format (single character) to pass to `array.array` to convert `data`
        into an array of samples. This should be "b" if audio data's sample width
        is 1, "h" if it's 2 and "i" if it's 4.
    channels : int
        number of channels of audio data.

    Returns
    -------
    channels_arr : list
        list of audio channels, each as a standard `array.array`.
    """
    all_channels = array_(fmt, data)
    mono_channels = [
        array_(fmt, all_channels[ch::channels]) for ch in range(channels)
    ]
    return mono_channels


def calculate_energy_single_channel(data, sample_width):
    """Calculate the energy of mono audio data. Energy is computed as:

    .. math:: energy = 20 \log(\sqrt({1}/{N}\sum_{i}^{N}{a_i}^2)) % # noqa: W605

    where `a_i` is the i-th audio sample and `N` is the number of audio samples
    in data.

    Parameters
    ----------
    data : bytes
        single-channel audio data.
    sample_width : int
        size in bytes of one audio sample.

    Returns
    -------
    energy : float
        energy of audio signal.
    """
    energy_sqrt = max(audioop.rms(data, sample_width), _EPSILON)
    return 20 * math.log10(energy_sqrt)


def calculate_energy_multichannel(x, sample_width, aggregation_fn=max):
    """Calculate the energy of multi-channel audio data. Energy is calculated
    channel-wise. An aggregation function is applied to the resulting energies
    (default: `max`). Also see :func:`calculate_energy_single_channel`.

    Parameters
    ----------
    data : bytes
        single-channel audio data.
    sample_width : int
        size in bytes of one audio sample (one channel considered).
    aggregation_fn : callable, default: max
        aggregation function to apply to the resulting per-channel energies.

    Returns
    -------
    energy : float
        aggregated energy of multi-channel audio signal.
    """
    energies = (calculate_energy_single_channel(xi, sample_width) for xi in x)
    return aggregation_fn(energies)
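A worked check of the energy formula above: for a constant-amplitude signal the RMS equals the amplitude, so the energy reduces to 20 * log10(amplitude).

import audioop
import math
from array import array

samples = array("h", [1000] * 160)       # 16-bit mono, constant amplitude 1000
rms = audioop.rms(samples.tobytes(), 2)  # -> 1000
print(20 * math.log10(rms))              # -> 60.0 dB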
@@ -0,0 +1,30 @@
import numpy as np
from .signal import (
    compute_average_channel_stereo,
    calculate_energy_single_channel,
    calculate_energy_multichannel,
)

FORMAT = {1: np.int8, 2: np.int16, 4: np.int32}


def to_array(data, sample_width, channels):
    fmt = FORMAT[sample_width]
    if channels == 1:
        return np.frombuffer(data, dtype=fmt).astype(np.float64)
    return separate_channels(data, fmt, channels).astype(np.float64)


def extract_single_channel(data, fmt, channels, selected):
    samples = np.frombuffer(data, dtype=fmt)
    return np.asanyarray(samples[selected::channels], order="C")


def compute_average_channel(data, fmt, channels):
    array = np.frombuffer(data, dtype=fmt).astype(np.float64)
    return array.reshape(-1, channels).mean(axis=1).round().astype(fmt)


def separate_channels(data, fmt, channels):
    array = np.frombuffer(data, dtype=fmt)
    return np.asanyarray(array.reshape(-1, channels).T, order="C")

libs/auditok/util.py: 1734 lines changed. File diff suppressed because it is too large.
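The numpy variants mirror signal.py one-for-one; the core move in separate_channels is de-interleaving via reshape(-1, channels).T, which turns an interleaved buffer into one row per channel:

import numpy as np

interleaved = np.array([0, 100, 1, 101, 2, 102], dtype=np.int16)  # L, R, L, R, ...
channels = interleaved.reshape(-1, 2).T
print(channels[0])  # [0 1 2]        (left)
print(channels[1])  # [100 101 102]  (right)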
@@ -0,0 +1,427 @@
import os
import sys
from tempfile import NamedTemporaryFile
from abc import ABCMeta, abstractmethod
from threading import Thread
from datetime import datetime, timedelta
from collections import namedtuple
import wave
import subprocess
from queue import Queue, Empty
from .io import _guess_audio_format
from .util import AudioDataSource, make_duration_formatter
from .core import split
from .exceptions import (
    EndOfProcessing,
    AudioEncodingError,
    AudioEncodingWarning,
)


_STOP_PROCESSING = "STOP_PROCESSING"
_Detection = namedtuple("_Detection", "id start end duration")


def _run_subprocess(command):
    try:
        with subprocess.Popen(
            command,
            stdin=open(os.devnull, "rb"),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        ) as proc:
            stdout, stderr = proc.communicate()
            return proc.returncode, stdout, stderr
    except Exception:
        err_msg = "Couldn't export audio using command: '{}'".format(command)
        raise AudioEncodingError(err_msg)


class Worker(Thread, metaclass=ABCMeta):
    def __init__(self, timeout=0.5, logger=None):
        self._timeout = timeout
        self._logger = logger
        self._inbox = Queue()
        Thread.__init__(self)

    def run(self):
        while True:
            message = self._get_message()
            if message == _STOP_PROCESSING:
                break
            if message is not None:
                self._process_message(message)
        self._post_process()

    @abstractmethod
    def _process_message(self, message):
        """Process incoming messages"""

    def _post_process(self):
        pass

    def _log(self, message):
        self._logger.info(message)

    def _stop_requested(self):
        try:
            message = self._inbox.get_nowait()
            if message == _STOP_PROCESSING:
                return True
        except Empty:
            return False

    def stop(self):
        self.send(_STOP_PROCESSING)
        self.join()

    def send(self, message):
        self._inbox.put(message)

    def _get_message(self):
        try:
            message = self._inbox.get(timeout=self._timeout)
            return message
        except Empty:
            return None


class TokenizerWorker(Worker, AudioDataSource):
    def __init__(self, reader, observers=None, logger=None, **kwargs):
        self._observers = observers if observers is not None else []
        self._reader = reader
        self._audio_region_gen = split(self, **kwargs)
        self._detections = []
        self._log_format = "[DET]: Detection {0.id} (start: {0.start:.3f}, "
        self._log_format += "end: {0.end:.3f}, duration: {0.duration:.3f})"
        Worker.__init__(self, timeout=0.2, logger=logger)

    def _process_message(self, message):
        pass

    @property
    def detections(self):
        return self._detections

    def _notify_observers(self, message):
        for observer in self._observers:
            observer.send(message)

    def run(self):
        self._reader.open()
        start_processing_timestamp = datetime.now()
        for _id, audio_region in enumerate(self._audio_region_gen, start=1):
            timestamp = start_processing_timestamp + timedelta(
                seconds=audio_region.meta.start
            )
            audio_region.meta.timestamp = timestamp
            detection = _Detection(
                _id,
                audio_region.meta.start,
                audio_region.meta.end,
                audio_region.duration,
            )
            self._detections.append(detection)
            if self._logger is not None:
                message = self._log_format.format(detection)
                self._log(message)
            self._notify_observers((_id, audio_region))
        self._notify_observers(_STOP_PROCESSING)
        self._reader.close()

    def start_all(self):
        for observer in self._observers:
            observer.start()
        self.start()

    def stop_all(self):
        self.stop()
        for observer in self._observers:
            observer.stop()
        self._reader.close()

    def read(self):
        if self._stop_requested():
            return None
        else:
            return self._reader.read()

    def __getattr__(self, name):
        return getattr(self._reader, name)


class StreamSaverWorker(Worker):
    def __init__(
        self,
        audio_reader,
        filename,
        export_format=None,
        cache_size_sec=0.5,
        timeout=0.2,
    ):
        self._reader = audio_reader
        sample_size_bytes = self._reader.sw * self._reader.ch
        self._cache_size = cache_size_sec * self._reader.sr * sample_size_bytes
        self._output_filename = filename
        self._export_format = _guess_audio_format(export_format, filename)
        if self._export_format is None:
            self._export_format = "wav"
        self._init_output_stream()
        self._exported = False
        self._cache = []
        self._total_cached = 0
        Worker.__init__(self, timeout=timeout)

    def _get_non_existent_filename(self):
        filename = self._output_filename + ".wav"
        i = 0
        while os.path.exists(filename):
            i += 1
            filename = self._output_filename + "({}).wav".format(i)
        return filename

    def _init_output_stream(self):
        if self._export_format != "wav":
            self._tmp_output_filename = self._get_non_existent_filename()
        else:
            self._tmp_output_filename = self._output_filename
        self._wfp = wave.open(self._tmp_output_filename, "wb")
        self._wfp.setframerate(self._reader.sr)
        self._wfp.setsampwidth(self._reader.sw)
        self._wfp.setnchannels(self._reader.ch)

    @property
    def sr(self):
        return self._reader.sampling_rate

    @property
    def sw(self):
        return self._reader.sample_width

    @property
    def ch(self):
        return self._reader.channels

    def __del__(self):
        self._post_process()

        if (
            (self._tmp_output_filename != self._output_filename)
            and self._exported
            and os.path.exists(self._tmp_output_filename)
        ):
            os.remove(self._tmp_output_filename)

    def _process_message(self, data):
        self._cache.append(data)
        self._total_cached += len(data)
        if self._total_cached >= self._cache_size:
            self._write_cached_data()

    def _post_process(self):
        while True:
            try:
                data = self._inbox.get_nowait()
                if data != _STOP_PROCESSING:
                    self._cache.append(data)
                    self._total_cached += len(data)
            except Empty:
                break
        self._write_cached_data()
        self._wfp.close()

    def _write_cached_data(self):
        if self._cache:
            data = b"".join(self._cache)
            self._wfp.writeframes(data)
            self._cache = []
            self._total_cached = 0

    def open(self):
        self._reader.open()

    def close(self):
        self._reader.close()
        self.stop()

    def rewind(self):
        # ensure compatibility with AudioDataSource with record=True
        pass

    @property
    def data(self):
        with wave.open(self._tmp_output_filename, "rb") as wfp:
            return wfp.readframes(-1)

    def save_stream(self):
        if self._exported:
            return self._output_filename

        if self._export_format in ("raw", "wav"):
            if self._export_format == "raw":
                self._export_raw()
            self._exported = True
            return self._output_filename
        try:
            self._export_with_ffmpeg_or_avconv()
        except AudioEncodingError:
            try:
                self._export_with_sox()
            except AudioEncodingError:
                warn_msg = "Couldn't save audio data in the desired format "
                warn_msg += "'{}'. Either none of 'ffmpeg', 'avconv' or 'sox' "
                warn_msg += "is installed or this format is not recognized.\n"
                warn_msg += "Audio file was saved as '{}'"
                raise AudioEncodingWarning(
                    warn_msg.format(
                        self._export_format, self._tmp_output_filename
                    )
                )
        finally:
            self._exported = True
        return self._output_filename

    def _export_raw(self):
        with open(self._output_filename, "wb") as wfp:
            wfp.write(self.data)

    def _export_with_ffmpeg_or_avconv(self):
        command = [
            "-y",
            "-f",
            "wav",
            "-i",
            self._tmp_output_filename,
            "-f",
            self._export_format,
            self._output_filename,
        ]
        returncode, stdout, stderr = _run_subprocess(["ffmpeg"] + command)
        if returncode != 0:
            returncode, stdout, stderr = _run_subprocess(["avconv"] + command)
            if returncode != 0:
                raise AudioEncodingError(stderr)
        return stdout, stderr

    def _export_with_sox(self):
        command = [
            "sox",
            "-t",
            "wav",
            self._tmp_output_filename,
            self._output_filename,
        ]
        returncode, stdout, stderr = _run_subprocess(command)
        if returncode != 0:
            raise AudioEncodingError(stderr)
        return stdout, stderr

    def close_output(self):
        self._wfp.close()

    def read(self):
        data = self._reader.read()
        if data is not None:
            self.send(data)
        else:
            self.send(_STOP_PROCESSING)
        return data

    def __getattr__(self, name):
        if name == "data":
            return self.data
        return getattr(self._reader, name)


class PlayerWorker(Worker):
    def __init__(self, player, progress_bar=False, timeout=0.2, logger=None):
        self._player = player
        self._progress_bar = progress_bar
        self._log_format = "[PLAY]: Detection {id} played"
        Worker.__init__(self, timeout=timeout, logger=logger)

    def _process_message(self, message):
        _id, audio_region = message
        if self._logger is not None:
            message = self._log_format.format(id=_id)
            self._log(message)
        audio_region.play(
            player=self._player, progress_bar=self._progress_bar, leave=False
        )


class RegionSaverWorker(Worker):
    def __init__(
        self,
        filename_format,
        audio_format=None,
        timeout=0.2,
        logger=None,
        **audio_parameters
    ):
        self._filename_format = filename_format
        self._audio_format = audio_format
        self._audio_parameters = audio_parameters
        self._debug_format = "[SAVE]: Detection {id} saved as '{filename}'"
        Worker.__init__(self, timeout=timeout, logger=logger)

    def _process_message(self, message):
        _id, audio_region = message
        filename = self._filename_format.format(
            id=_id,
            start=audio_region.meta.start,
            end=audio_region.meta.end,
            duration=audio_region.duration,
        )
        filename = audio_region.save(
            filename, self._audio_format, **self._audio_parameters
        )
        if self._logger:
            message = self._debug_format.format(id=_id, filename=filename)
            self._log(message)


class CommandLineWorker(Worker):
    def __init__(self, command, timeout=0.2, logger=None):
        self._command = command
        Worker.__init__(self, timeout=timeout, logger=logger)
        self._debug_format = "[COMMAND]: Detection {id} command: '{command}'"

    def _process_message(self, message):
        _id, audio_region = message
        with NamedTemporaryFile(delete=False) as file:
            filename = audio_region.save(file.name, audio_format="wav")
            command = self._command.format(file=filename)
            os.system(command)
            if self._logger is not None:
                message = self._debug_format.format(id=_id, command=command)
                self._log(message)


class PrintWorker(Worker):
    def __init__(
        self,
        print_format="{start} {end}",
        time_format="%S",
        timestamp_format="%Y/%m/%d %H:%M:%S.%f",
        timeout=0.2,
    ):

        self._print_format = print_format
        self._format_time = make_duration_formatter(time_format)
        self._timestamp_format = timestamp_format
        self.detections = []
        Worker.__init__(self, timeout=timeout)

    def _process_message(self, message):
        _id, audio_region = message
        timestamp = audio_region.meta.timestamp
        timestamp = timestamp.strftime(self._timestamp_format)
        text = self._print_format.format(
            id=_id,
            start=self._format_time(audio_region.meta.start),
            end=self._format_time(audio_region.meta.end),
            duration=self._format_time(audio_region.duration),
            timestamp=timestamp,
        )
        print(text)
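The Worker base class above is a plain inbox pattern: a thread polls a Queue with a timeout, dispatches non-sentinel messages to _process_message, and breaks out of the loop when it sees the stop sentinel. A self-contained sketch of the same pattern, independent of auditok (the EchoWorker name and the object sentinel are illustrative; auditok uses the string "STOP_PROCESSING"):

import threading
import queue

_STOP = object()  # sentinel; auditok uses the string "STOP_PROCESSING"

class EchoWorker(threading.Thread):
    def __init__(self, timeout=0.2):
        super().__init__()
        self._inbox = queue.Queue()
        self._timeout = timeout

    def send(self, message):
        self._inbox.put(message)

    def stop(self):
        # Post the sentinel, then wait for the thread to drain and exit.
        self.send(_STOP)
        self.join()

    def run(self):
        while True:
            try:
                message = self._inbox.get(timeout=self._timeout)
            except queue.Empty:
                continue  # same effect as auditok returning None and re-polling
            if message is _STOP:
                break
            print("processed:", message)

w = EchoWorker()
w.start()
w.send("detection 1")
w.stop()  # prints "processed: detection 1" before stopping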
@@ -1,11 +1 @@
-# A Python "namespace package" http://www.python.org/dev/peps/pep-0382/
-# This always goes inside of a namespace package's __init__.py
-
-from pkgutil import extend_path
-__path__ = extend_path(__path__, __name__)
-
-try:
-    import pkg_resources
-    pkg_resources.declare_namespace(__name__)
-except ImportError:
-    pass
+__path__ = __import__('pkgutil').extend_path(__path__, __name__)  # type: ignore
@@ -4,14 +4,16 @@ import functools
from collections import namedtuple
from threading import RLock

-_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
+_CacheInfo = namedtuple("_CacheInfo", ["hits", "misses", "maxsize", "currsize"])


@functools.wraps(functools.update_wrapper)
-def update_wrapper(wrapper,
-                   wrapped,
-                   assigned = functools.WRAPPER_ASSIGNMENTS,
-                   updated = functools.WRAPPER_UPDATES):
+def update_wrapper(
+    wrapper,
+    wrapped,
+    assigned=functools.WRAPPER_ASSIGNMENTS,
+    updated=functools.WRAPPER_UPDATES,
+):
    """
    Patch two bugs in functools.update_wrapper.
    """

@@ -34,10 +36,17 @@ class _HashedSeq(list):
        return self.hashvalue


-def _make_key(args, kwds, typed,
-              kwd_mark=(object(),),
-              fasttypes=set([int, str, frozenset, type(None)]),
-              sorted=sorted, tuple=tuple, type=type, len=len):
+def _make_key(
+    args,
+    kwds,
+    typed,
+    kwd_mark=(object(),),
+    fasttypes=set([int, str, frozenset, type(None)]),
+    sorted=sorted,
+    tuple=tuple,
+    type=type,
+    len=len,
+):
    'Make a cache key from optionally typed positional and keyword arguments'
    key = args
    if kwds:

@@ -54,7 +63,7 @@ def _make_key(args, kwds, typed,
    return _HashedSeq(key)


-def lru_cache(maxsize=100, typed=False):
+def lru_cache(maxsize=100, typed=False):  # noqa: C901
    """Least-recently-used cache decorator.

    If *maxsize* is set to None, the LRU features are disabled and the cache

@@ -82,16 +91,16 @@ def lru_cache(maxsize=100, typed=False):
    def decorating_function(user_function):

        cache = dict()
-        stats = [0, 0]                  # make statistics updateable non-locally
-        HITS, MISSES = 0, 1             # names for the stats fields
+        stats = [0, 0]  # make statistics updateable non-locally
+        HITS, MISSES = 0, 1  # names for the stats fields
        make_key = _make_key
-        cache_get = cache.get           # bound method to lookup key or return None
-        _len = len                      # localize the global len() function
-        lock = RLock()                  # because linkedlist updates aren't threadsafe
-        root = []                       # root of the circular doubly linked list
-        root[:] = [root, root, None, None]      # initialize by pointing to self
-        nonlocal_root = [root]          # make updateable non-locally
-        PREV, NEXT, KEY, RESULT = 0, 1, 2, 3    # names for the link fields
+        cache_get = cache.get  # bound method to lookup key or return None
+        _len = len  # localize the global len() function
+        lock = RLock()  # because linkedlist updates aren't threadsafe
+        root = []  # root of the circular doubly linked list
+        root[:] = [root, root, None, None]  # initialize by pointing to self
+        nonlocal_root = [root]  # make updateable non-locally
+        PREV, NEXT, KEY, RESULT = 0, 1, 2, 3  # names for the link fields

        if maxsize == 0:

@@ -106,7 +115,9 @@ def lru_cache(maxsize=100, typed=False):
        def wrapper(*args, **kwds):
            # simple caching without ordering or size limit
            key = make_key(args, kwds, typed)
-            result = cache_get(key, root)   # root used here as a unique not-found sentinel
+            result = cache_get(
+                key, root
+            )  # root used here as a unique not-found sentinel
            if result is not root:
                stats[HITS] += 1
                return result

@@ -123,8 +134,9 @@ def lru_cache(maxsize=100, typed=False):
            with lock:
                link = cache_get(key)
                if link is not None:
-                    # record recent use of the key by moving it to the front of the list
-                    root, = nonlocal_root
+                    # record recent use of the key by moving it
+                    # to the front of the list
+                    (root,) = nonlocal_root
                    link_prev, link_next, key, result = link
                    link_prev[NEXT] = link_next
                    link_next[PREV] = link_prev

@@ -136,7 +148,7 @@ def lru_cache(maxsize=100, typed=False):
                return result
            result = user_function(*args, **kwds)
            with lock:
-                root, = nonlocal_root
+                (root,) = nonlocal_root
                if key in cache:
                    # getting here means that this same key was added to the
                    # cache while the lock was released.  since the link
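The backported decorator mirrors the functools.lru_cache API, including cache_info(). A quick usage sketch, shown here with the stdlib functools, which behaves identically for this case:

from functools import lru_cache

@lru_cache(maxsize=100)
def fib(n):
    return n if n < 2 else fib(n - 1) + fib(n - 2)

print(fib(30))           # 832040, computed in linear rather than exponential time
print(fib.cache_info())  # CacheInfo(hits=28, misses=31, maxsize=100, currsize=31)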
@@ -0,0 +1,49 @@
__all__ = [
    "ZoneInfo",
    "reset_tzpath",
    "available_timezones",
    "TZPATH",
    "ZoneInfoNotFoundError",
    "InvalidTZPathWarning",
]
import sys

from . import _tzpath
from ._common import ZoneInfoNotFoundError
from ._version import __version__

try:
    from ._czoneinfo import ZoneInfo
except ImportError:  # pragma: nocover
    from ._zoneinfo import ZoneInfo

reset_tzpath = _tzpath.reset_tzpath
available_timezones = _tzpath.available_timezones
InvalidTZPathWarning = _tzpath.InvalidTZPathWarning

if sys.version_info < (3, 7):
    # Module-level __getattr__ was added in Python 3.7, so instead of lazily
    # populating TZPATH on every access, we will register a callback with
    # reset_tzpath to update the top-level tuple.
    TZPATH = _tzpath.TZPATH

    def _tzpath_callback(new_tzpath):
        global TZPATH
        TZPATH = new_tzpath

    _tzpath.TZPATH_CALLBACKS.append(_tzpath_callback)
    del _tzpath_callback

else:

    def __getattr__(name):
        if name == "TZPATH":
            return _tzpath.TZPATH
        else:
            raise AttributeError(
                f"module {__name__!r} has no attribute {name!r}"
            )

    def __dir__():
        return sorted(list(globals()) + ["TZPATH"])
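The package mirrors the interface of the standard zoneinfo module from Python 3.9+; on 3.9+ the import below would simply be `from zoneinfo import ZoneInfo`. A usage sketch, assuming time zone data for the keys used is available either on TZPATH or via the tzdata package:

from datetime import datetime
from backports.zoneinfo import ZoneInfo

dt = datetime(2020, 7, 1, 12, 0, tzinfo=ZoneInfo("America/New_York"))
print(dt.tzname())     # EDT
print(dt.utcoffset())  # -1 day, 20:00:00 (i.e. UTC-4)

# Constructing the same key twice returns the cached instance:
assert ZoneInfo("UTC") is ZoneInfo("UTC")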
@@ -0,0 +1,45 @@
import os
import typing
from datetime import datetime, tzinfo
from typing import (
    Any,
    Iterable,
    Optional,
    Protocol,
    Sequence,
    Set,
    Type,
    Union,
)

_T = typing.TypeVar("_T", bound="ZoneInfo")

class _IOBytes(Protocol):
    def read(self, __size: int) -> bytes: ...
    def seek(self, __size: int, __whence: int = ...) -> Any: ...

class ZoneInfo(tzinfo):
    @property
    def key(self) -> str: ...
    def __init__(self, key: str) -> None: ...
    @classmethod
    def no_cache(cls: Type[_T], key: str) -> _T: ...
    @classmethod
    def from_file(
        cls: Type[_T], __fobj: _IOBytes, key: Optional[str] = ...
    ) -> _T: ...
    @classmethod
    def clear_cache(cls, *, only_keys: Iterable[str] = ...) -> None: ...

# Note: Both here and in clear_cache, the types allow the use of `str` where
# a sequence of strings is required. This should be remedied if a solution
# to this typing bug is found: https://github.com/python/typing/issues/256
def reset_tzpath(
    to: Optional[Sequence[Union[os.PathLike, str]]] = ...
) -> None: ...
def available_timezones() -> Set[str]: ...

TZPATH: Sequence[str]

class ZoneInfoNotFoundError(KeyError): ...
class InvalidTZPathWarning(RuntimeWarning): ...
@@ -0,0 +1,171 @@
import struct


def load_tzdata(key):
    try:
        import importlib.resources as importlib_resources
    except ImportError:
        import importlib_resources

    components = key.split("/")
    package_name = ".".join(["tzdata.zoneinfo"] + components[:-1])
    resource_name = components[-1]

    try:
        return importlib_resources.open_binary(package_name, resource_name)
    except (ImportError, FileNotFoundError, UnicodeEncodeError):
        # There are three types of exception that can be raised that all amount
        # to "we cannot find this key":
        #
        # ImportError: If package_name doesn't exist (e.g. if tzdata is not
        #   installed, or if there's an error in the folder name like
        #   Amrica/New_York)
        # FileNotFoundError: If resource_name doesn't exist in the package
        #   (e.g. Europe/Krasnoy)
        # UnicodeEncodeError: If package_name or resource_name are not UTF-8,
        #   such as keys containing a surrogate character.
        raise ZoneInfoNotFoundError(f"No time zone found with key {key}")


def load_data(fobj):
    header = _TZifHeader.from_file(fobj)

    if header.version == 1:
        time_size = 4
        time_type = "l"
    else:
        # Version 2+ has 64-bit integer transition times
        time_size = 8
        time_type = "q"

        # Version 2+ also starts with a Version 1 header and data, which
        # we need to skip now
        skip_bytes = (
            header.timecnt * 5  # Transition times and types
            + header.typecnt * 6  # Local time type records
            + header.charcnt  # Time zone designations
            + header.leapcnt * 8  # Leap second records
            + header.isstdcnt  # Standard/wall indicators
            + header.isutcnt  # UT/local indicators
        )

        fobj.seek(skip_bytes, 1)

        # Now we need to read the second header, which is not the same
        # as the first
        header = _TZifHeader.from_file(fobj)

    typecnt = header.typecnt
    timecnt = header.timecnt
    charcnt = header.charcnt

    # The data portion starts with timecnt transitions and indices
    if timecnt:
        trans_list_utc = struct.unpack(
            f">{timecnt}{time_type}", fobj.read(timecnt * time_size)
        )
        trans_idx = struct.unpack(f">{timecnt}B", fobj.read(timecnt))
    else:
        trans_list_utc = ()
        trans_idx = ()

    # Read the ttinfo struct, (utoff, isdst, abbrind)
    if typecnt:
        utcoff, isdst, abbrind = zip(
            *(struct.unpack(">lbb", fobj.read(6)) for i in range(typecnt))
        )
    else:
        utcoff = ()
        isdst = ()
        abbrind = ()

    # Now read the abbreviations. They are null-terminated strings, indexed
    # not by position in the array but by position in the unsplit
    # abbreviation string. I suppose this makes more sense in C, which uses
    # null to terminate the strings, but it's inconvenient here...
    abbr_vals = {}
    abbr_chars = fobj.read(charcnt)

    def get_abbr(idx):
        # Gets a string starting at idx and running until the next \x00
        #
        # We cannot pre-populate abbr_vals by splitting on \x00 because there
        # are some zones that use subsets of longer abbreviations, like so:
        #
        #  LMT\x00AHST\x00HDT\x00
        #
        # Where the idx to abbr mapping should be:
        #
        #  {0: "LMT", 4: "AHST", 5: "HST", 9: "HDT"}
        if idx not in abbr_vals:
            span_end = abbr_chars.find(b"\x00", idx)
            abbr_vals[idx] = abbr_chars[idx:span_end].decode()

        return abbr_vals[idx]

    abbr = tuple(get_abbr(idx) for idx in abbrind)

    # The remainder of the file consists of leap seconds (currently unused) and
    # the standard/wall and ut/local indicators, which are metadata we don't need.
    # In version 2 files, we need to skip the unnecessary data to get at the TZ string:
    if header.version >= 2:
        # Each leap second record has size (time_size + 4)
        skip_bytes = header.isutcnt + header.isstdcnt + header.leapcnt * 12
        fobj.seek(skip_bytes, 1)

        c = fobj.read(1)  # Should be \n
        assert c == b"\n", c

        tz_bytes = b""
        while True:
            c = fobj.read(1)
            if c == b"\n":
                break
            tz_bytes += c

        tz_str = tz_bytes
    else:
        tz_str = None

    return trans_idx, trans_list_utc, utcoff, isdst, abbr, tz_str


class _TZifHeader:
    __slots__ = [
        "version",
        "isutcnt",
        "isstdcnt",
        "leapcnt",
        "timecnt",
        "typecnt",
        "charcnt",
    ]

    def __init__(self, *args):
        assert len(self.__slots__) == len(args)
        for attr, val in zip(self.__slots__, args):
            setattr(self, attr, val)

    @classmethod
    def from_file(cls, stream):
        # The header starts with a 4-byte "magic" value
        if stream.read(4) != b"TZif":
            raise ValueError("Invalid TZif file: magic not found")

        _version = stream.read(1)
        if _version == b"\x00":
            version = 1
        else:
            version = int(_version)
        stream.read(15)

        args = (version,)

        # Slots are defined in the order that the bytes are arranged
        args = args + struct.unpack(">6l", stream.read(24))

        return cls(*args)


class ZoneInfoNotFoundError(KeyError):
    """Exception raised when a ZoneInfo key is not found."""
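The header layout decoded by _TZifHeader.from_file is fixed: the 4-byte magic "TZif", a 1-byte version, 15 reserved bytes, then six big-endian 32-bit counts. A sketch that parses a synthetic header; the byte string is fabricated to show the field order, not taken from a real file:

import struct
from io import BytesIO

# magic + version '2' + 15 reserved bytes + isutcnt, isstdcnt, leapcnt,
# timecnt, typecnt, charcnt (six big-endian 32-bit integers)
raw = b"TZif" + b"2" + b"\x00" * 15 + struct.pack(">6l", 0, 0, 0, 5, 2, 8)

fobj = BytesIO(raw)
assert fobj.read(4) == b"TZif"
version = int(fobj.read(1))  # int() accepts the bytes b"2" directly
fobj.read(15)  # reserved
isutcnt, isstdcnt, leapcnt, timecnt, typecnt, charcnt = struct.unpack(
    ">6l", fobj.read(24)
)
print(version, timecnt, typecnt, charcnt)  # 2 5 2 8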
@@ -0,0 +1,207 @@
import os
import sys

PY36 = sys.version_info < (3, 7)


def reset_tzpath(to=None):
    global TZPATH

    tzpaths = to
    if tzpaths is not None:
        if isinstance(tzpaths, (str, bytes)):
            raise TypeError(
                f"tzpaths must be a list or tuple, "
                + f"not {type(tzpaths)}: {tzpaths!r}"
            )

        if not all(map(os.path.isabs, tzpaths)):
            raise ValueError(_get_invalid_paths_message(tzpaths))
        base_tzpath = tzpaths
    else:
        env_var = os.environ.get("PYTHONTZPATH", None)
        if env_var is not None:
            base_tzpath = _parse_python_tzpath(env_var)
        elif sys.platform != "win32":
            base_tzpath = [
                "/usr/share/zoneinfo",
                "/usr/lib/zoneinfo",
                "/usr/share/lib/zoneinfo",
                "/etc/zoneinfo",
            ]

            base_tzpath.sort(key=lambda x: not os.path.exists(x))
        else:
            base_tzpath = ()

    TZPATH = tuple(base_tzpath)

    if TZPATH_CALLBACKS:
        for callback in TZPATH_CALLBACKS:
            callback(TZPATH)


def _parse_python_tzpath(env_var):
    if not env_var:
        return ()

    raw_tzpath = env_var.split(os.pathsep)
    new_tzpath = tuple(filter(os.path.isabs, raw_tzpath))

    # If anything has been filtered out, we will warn about it
    if len(new_tzpath) != len(raw_tzpath):
        import warnings

        msg = _get_invalid_paths_message(raw_tzpath)

        warnings.warn(
            "Invalid paths specified in PYTHONTZPATH environment variable."
            + msg,
            InvalidTZPathWarning,
        )

    return new_tzpath


def _get_invalid_paths_message(tzpaths):
    invalid_paths = (path for path in tzpaths if not os.path.isabs(path))

    prefix = "\n    "
    indented_str = prefix + prefix.join(invalid_paths)

    return (
        "Paths should be absolute but found the following relative paths:"
        + indented_str
    )


if sys.version_info < (3, 8):

    def _isfile(path):
        # bpo-33721: In Python 3.8 non-UTF8 paths return False rather than
        # raising an error. See https://bugs.python.org/issue33721
        try:
            return os.path.isfile(path)
        except ValueError:
            return False


else:
    _isfile = os.path.isfile


def find_tzfile(key):
    """Retrieve the path to a TZif file from a key."""
    _validate_tzfile_path(key)
    for search_path in TZPATH:
        filepath = os.path.join(search_path, key)
        if _isfile(filepath):
            return filepath

    return None


_TEST_PATH = os.path.normpath(os.path.join("_", "_"))[:-1]


def _validate_tzfile_path(path, _base=_TEST_PATH):
    if os.path.isabs(path):
        raise ValueError(
            f"ZoneInfo keys may not be absolute paths, got: {path}"
        )

    # We only care about the kinds of path normalizations that would change the
    # length of the key - e.g. a/../b -> a/b, or a/b/ -> a/b. On Windows,
    # normpath will also change from a/b to a\b, but that would still preserve
    # the length.
    new_path = os.path.normpath(path)
    if len(new_path) != len(path):
        raise ValueError(
            f"ZoneInfo keys must be normalized relative paths, got: {path}"
        )

    resolved = os.path.normpath(os.path.join(_base, new_path))
    if not resolved.startswith(_base):
        raise ValueError(
            f"ZoneInfo keys must refer to subdirectories of TZPATH, got: {path}"
        )


del _TEST_PATH


def available_timezones():
    """Returns a set containing all available time zones.

    .. caution::

        This may attempt to open a large number of files, since the best way
        to determine whether a given file on the time zone search path is a
        valid time zone is to open it and check for the "magic string" at the
        beginning.
    """
    try:
        from importlib import resources
    except ImportError:
        import importlib_resources as resources

    valid_zones = set()

    # Start with loading from the tzdata package if it exists: this has a
    # pre-assembled list of zones that only requires opening one file.
    try:
        with resources.open_text("tzdata", "zones") as f:
            for zone in f:
                zone = zone.strip()
                if zone:
                    valid_zones.add(zone)
    except (ImportError, FileNotFoundError):
        pass

    def valid_key(fpath):
        try:
            with open(fpath, "rb") as f:
                return f.read(4) == b"TZif"
        except Exception:  # pragma: nocover
            return False

    for tz_root in TZPATH:
        if not os.path.exists(tz_root):
            continue

        for root, dirnames, files in os.walk(tz_root):
            if root == tz_root:
                # right/ and posix/ are special directories and shouldn't be
                # included in the output of available zones
                if "right" in dirnames:
                    dirnames.remove("right")
                if "posix" in dirnames:
                    dirnames.remove("posix")

            for file in files:
                fpath = os.path.join(root, file)

                key = os.path.relpath(fpath, start=tz_root)
                if os.sep != "/":  # pragma: nocover
                    key = key.replace(os.sep, "/")

                if not key or key in valid_zones:
                    continue

                if valid_key(fpath):
                    valid_zones.add(key)

    if "posixrules" in valid_zones:
        # posixrules is a special symlink-only time zone; where it exists, it
        # should not be included in the output
        valid_zones.remove("posixrules")

    return valid_zones


class InvalidTZPathWarning(RuntimeWarning):
    """Warning raised if an invalid path is specified in PYTHONTZPATH."""


TZPATH = ()
TZPATH_CALLBACKS = []
reset_tzpath()
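The key validation rejects anything that could escape the search path: absolute paths, keys whose normalization changes their length (trailing slashes, ".."), and keys that resolve outside the base. A sketch of the observable behavior using the backport's private helper, so the exact import path is an assumption:

from backports.zoneinfo._tzpath import _validate_tzfile_path

_validate_tzfile_path("Europe/Paris")  # OK: already a normalized relative path

for bad in ("/etc/passwd", "Europe/../Europe/Paris", "../secrets"):
    try:
        _validate_tzfile_path(bad)
    except ValueError:
        print("rejected:", bad)  # all three raise ValueError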
@@ -0,0 +1 @@
__version__ = "0.2.1"
@@ -0,0 +1,754 @@
import bisect
import calendar
import collections
import functools
import re
import weakref
from datetime import datetime, timedelta, tzinfo

from . import _common, _tzpath

EPOCH = datetime(1970, 1, 1)
EPOCHORDINAL = datetime(1970, 1, 1).toordinal()

# It is relatively expensive to construct new timedelta objects, and in most
# cases we're looking at the same deltas, like integer numbers of hours, etc.
# To improve speed and memory use, we'll keep a dictionary with references
# to the ones we've already used so far.
#
# Loading every time zone in the 2020a version of the time zone database
# requires 447 timedeltas, which requires approximately the amount of space
# that ZoneInfo("America/New_York") with 236 transitions takes up, so we will
# set the cache size to 512 so that in the common case we always get cache
# hits, but specifically crafted ZoneInfo objects don't leak arbitrary amounts
# of memory.
@functools.lru_cache(maxsize=512)
def _load_timedelta(seconds):
    return timedelta(seconds=seconds)


class ZoneInfo(tzinfo):
    _strong_cache_size = 8
    _strong_cache = collections.OrderedDict()
    _weak_cache = weakref.WeakValueDictionary()
    __module__ = "backports.zoneinfo"

    def __init_subclass__(cls):
        cls._strong_cache = collections.OrderedDict()
        cls._weak_cache = weakref.WeakValueDictionary()

    def __new__(cls, key):
        instance = cls._weak_cache.get(key, None)
        if instance is None:
            instance = cls._weak_cache.setdefault(key, cls._new_instance(key))
            instance._from_cache = True

        # Update the "strong" cache
        cls._strong_cache[key] = cls._strong_cache.pop(key, instance)

        if len(cls._strong_cache) > cls._strong_cache_size:
            cls._strong_cache.popitem(last=False)

        return instance

    @classmethod
    def no_cache(cls, key):
        obj = cls._new_instance(key)
        obj._from_cache = False

        return obj

    @classmethod
    def _new_instance(cls, key):
        obj = super().__new__(cls)
        obj._key = key
        obj._file_path = obj._find_tzfile(key)

        if obj._file_path is not None:
            file_obj = open(obj._file_path, "rb")
        else:
            file_obj = _common.load_tzdata(key)

        with file_obj as f:
            obj._load_file(f)

        return obj

    @classmethod
    def from_file(cls, fobj, key=None):
        obj = super().__new__(cls)
        obj._key = key
        obj._file_path = None
        obj._load_file(fobj)
        obj._file_repr = repr(fobj)

        # Disable pickling for objects created from files
        obj.__reduce__ = obj._file_reduce

        return obj

    @classmethod
    def clear_cache(cls, *, only_keys=None):
        if only_keys is not None:
            for key in only_keys:
                cls._weak_cache.pop(key, None)
                cls._strong_cache.pop(key, None)

        else:
            cls._weak_cache.clear()
            cls._strong_cache.clear()

    @property
    def key(self):
        return self._key

    def utcoffset(self, dt):
        return self._find_trans(dt).utcoff

    def dst(self, dt):
        return self._find_trans(dt).dstoff

    def tzname(self, dt):
        return self._find_trans(dt).tzname

    def fromutc(self, dt):
        """Convert from datetime in UTC to datetime in local time"""

        if not isinstance(dt, datetime):
            raise TypeError("fromutc() requires a datetime argument")
        if dt.tzinfo is not self:
            raise ValueError("dt.tzinfo is not self")

        timestamp = self._get_local_timestamp(dt)
        num_trans = len(self._trans_utc)

        if num_trans >= 1 and timestamp < self._trans_utc[0]:
            tti = self._tti_before
            fold = 0
        elif (
            num_trans == 0 or timestamp > self._trans_utc[-1]
        ) and not isinstance(self._tz_after, _ttinfo):
            tti, fold = self._tz_after.get_trans_info_fromutc(
                timestamp, dt.year
            )
        elif num_trans == 0:
            tti = self._tz_after
            fold = 0
        else:
            idx = bisect.bisect_right(self._trans_utc, timestamp)

            if num_trans > 1 and timestamp >= self._trans_utc[1]:
                tti_prev, tti = self._ttinfos[idx - 2 : idx]
            elif timestamp > self._trans_utc[-1]:
                tti_prev = self._ttinfos[-1]
                tti = self._tz_after
            else:
                tti_prev = self._tti_before
                tti = self._ttinfos[0]

            # Detect fold
            shift = tti_prev.utcoff - tti.utcoff
            fold = shift.total_seconds() > timestamp - self._trans_utc[idx - 1]
        dt += tti.utcoff
        if fold:
            return dt.replace(fold=1)
        else:
            return dt

    def _find_trans(self, dt):
        if dt is None:
            if self._fixed_offset:
                return self._tz_after
            else:
                return _NO_TTINFO

        ts = self._get_local_timestamp(dt)

        lt = self._trans_local[dt.fold]

        num_trans = len(lt)

        if num_trans and ts < lt[0]:
            return self._tti_before
        elif not num_trans or ts > lt[-1]:
            if isinstance(self._tz_after, _TZStr):
                return self._tz_after.get_trans_info(ts, dt.year, dt.fold)
            else:
                return self._tz_after
        else:
            # idx is the transition that occurs after this timestamp, so we
            # subtract off 1 to get the current ttinfo
            idx = bisect.bisect_right(lt, ts) - 1
            assert idx >= 0
            return self._ttinfos[idx]

    def _get_local_timestamp(self, dt):
        return (
            (dt.toordinal() - EPOCHORDINAL) * 86400
            + dt.hour * 3600
            + dt.minute * 60
            + dt.second
        )

    def __str__(self):
        if self._key is not None:
            return f"{self._key}"
        else:
            return repr(self)

    def __repr__(self):
        if self._key is not None:
            return f"{self.__class__.__name__}(key={self._key!r})"
        else:
            return f"{self.__class__.__name__}.from_file({self._file_repr})"

    def __reduce__(self):
        return (self.__class__._unpickle, (self._key, self._from_cache))

    def _file_reduce(self):
        import pickle

        raise pickle.PicklingError(
            "Cannot pickle a ZoneInfo file created from a file stream."
        )

    @classmethod
    def _unpickle(cls, key, from_cache):
        if from_cache:
            return cls(key)
        else:
            return cls.no_cache(key)

    def _find_tzfile(self, key):
        return _tzpath.find_tzfile(key)

    def _load_file(self, fobj):
        # Retrieve all the data as it exists in the zoneinfo file
        trans_idx, trans_utc, utcoff, isdst, abbr, tz_str = _common.load_data(
            fobj
        )

        # Infer the DST offsets (needed for .dst()) from the data
        dstoff = self._utcoff_to_dstoff(trans_idx, utcoff, isdst)

        # Convert all the transition times (UTC) into "seconds since 1970-01-01 local time"
        trans_local = self._ts_to_local(trans_idx, trans_utc, utcoff)

        # Construct `_ttinfo` objects for each transition in the file
        _ttinfo_list = [
            _ttinfo(
                _load_timedelta(utcoffset), _load_timedelta(dstoffset), tzname
            )
            for utcoffset, dstoffset, tzname in zip(utcoff, dstoff, abbr)
        ]

        self._trans_utc = trans_utc
        self._trans_local = trans_local
        self._ttinfos = [_ttinfo_list[idx] for idx in trans_idx]

        # Find the first non-DST transition
        for i in range(len(isdst)):
            if not isdst[i]:
                self._tti_before = _ttinfo_list[i]
                break
        else:
            if self._ttinfos:
                self._tti_before = self._ttinfos[0]
            else:
                self._tti_before = None

        # Set the "fallback" time zone
        if tz_str is not None and tz_str != b"":
            self._tz_after = _parse_tz_str(tz_str.decode())
        else:
            if not self._ttinfos and not _ttinfo_list:
                raise ValueError("No time zone information found.")

            if self._ttinfos:
                self._tz_after = self._ttinfos[-1]
            else:
                self._tz_after = _ttinfo_list[-1]

        # Determine if this is a "fixed offset" zone, meaning that the output
        # of the utcoffset, dst and tzname functions does not depend on the
        # specific datetime passed.
        #
        # We make three simplifying assumptions here:
        #
        # 1. If _tz_after is not a _ttinfo, it has transitions that might
        #    actually occur (it is possible to construct TZ strings that
        #    specify STD and DST but no transitions ever occur, such as
        #    AAA0BBB,0/0,J365/25).
        # 2. If _ttinfo_list contains more than one _ttinfo object, the objects
        #    represent different offsets.
        # 3. _ttinfo_list contains no unused _ttinfos (in which case an
        #    otherwise fixed-offset zone with extra _ttinfos defined may
        #    appear to *not* be a fixed offset zone).
        #
        # Violations to these assumptions would be fairly exotic, and exotic
        # zones should almost certainly not be used with datetime.time (the
        # only thing that would be affected by this).
        if len(_ttinfo_list) > 1 or not isinstance(self._tz_after, _ttinfo):
            self._fixed_offset = False
        elif not _ttinfo_list:
            self._fixed_offset = True
        else:
            self._fixed_offset = _ttinfo_list[0] == self._tz_after

    @staticmethod
    def _utcoff_to_dstoff(trans_idx, utcoffsets, isdsts):
        # Now we must transform our ttis and abbrs into `_ttinfo` objects,
        # but there is an issue: .dst() must return a timedelta with the
        # difference between utcoffset() and the "standard" offset, but
        # the "base offset" and "DST offset" are not encoded in the file;
        # we can infer what they are from the isdst flag, but it is not
        # sufficient to just look at the last standard offset, because
        # occasionally countries will shift both DST offset and base offset.

        typecnt = len(isdsts)
        dstoffs = [0] * typecnt  # Provisionally assign all to 0.
        dst_cnt = sum(isdsts)
        dst_found = 0

        for i in range(1, len(trans_idx)):
            if dst_cnt == dst_found:
                break

            idx = trans_idx[i]

            dst = isdsts[idx]

            # We're only going to look at daylight saving time
            if not dst:
                continue

            # Skip any offsets that have already been assigned
            if dstoffs[idx] != 0:
                continue

            dstoff = 0
            utcoff = utcoffsets[idx]

            comp_idx = trans_idx[i - 1]

            if not isdsts[comp_idx]:
                dstoff = utcoff - utcoffsets[comp_idx]

            if not dstoff and idx < (typecnt - 1):
                comp_idx = trans_idx[i + 1]

                # If the following transition is also DST and we couldn't
                # find the DST offset by this point, we're going to have to
                # skip it and hope this transition gets assigned later
                if isdsts[comp_idx]:
                    continue

                dstoff = utcoff - utcoffsets[comp_idx]

            if dstoff:
                dst_found += 1
                dstoffs[idx] = dstoff
        else:
            # If we didn't find a valid value for a given index, we'll end up
            # with dstoff = 0 for something where `isdst=1`. This is obviously
            # wrong - one hour will be a much better guess than 0
            for idx in range(typecnt):
                if not dstoffs[idx] and isdsts[idx]:
                    dstoffs[idx] = 3600

        return dstoffs

    @staticmethod
    def _ts_to_local(trans_idx, trans_list_utc, utcoffsets):
        """Generate number of seconds since 1970 *in the local time*.

        This is necessary to easily find the transition times in local time"""
        if not trans_list_utc:
            return [[], []]

        # Start with the timestamps and modify in-place
        trans_list_wall = [list(trans_list_utc), list(trans_list_utc)]

        if len(utcoffsets) > 1:
            offset_0 = utcoffsets[0]
            offset_1 = utcoffsets[trans_idx[0]]
            if offset_1 > offset_0:
                offset_1, offset_0 = offset_0, offset_1
        else:
            offset_0 = offset_1 = utcoffsets[0]

        trans_list_wall[0][0] += offset_0
        trans_list_wall[1][0] += offset_1

        for i in range(1, len(trans_idx)):
            offset_0 = utcoffsets[trans_idx[i - 1]]
            offset_1 = utcoffsets[trans_idx[i]]

            if offset_1 > offset_0:
                offset_1, offset_0 = offset_0, offset_1

            trans_list_wall[0][i] += offset_0
            trans_list_wall[1][i] += offset_1

        return trans_list_wall


class _ttinfo:
    __slots__ = ["utcoff", "dstoff", "tzname"]

    def __init__(self, utcoff, dstoff, tzname):
        self.utcoff = utcoff
        self.dstoff = dstoff
        self.tzname = tzname

    def __eq__(self, other):
        return (
            self.utcoff == other.utcoff
            and self.dstoff == other.dstoff
            and self.tzname == other.tzname
        )

    def __repr__(self):  # pragma: nocover
        return (
            f"{self.__class__.__name__}"
            + f"({self.utcoff}, {self.dstoff}, {self.tzname})"
        )


_NO_TTINFO = _ttinfo(None, None, None)


class _TZStr:
    __slots__ = (
        "std",
        "dst",
        "start",
        "end",
        "get_trans_info",
        "get_trans_info_fromutc",
        "dst_diff",
    )

    def __init__(
        self, std_abbr, std_offset, dst_abbr, dst_offset, start=None, end=None
    ):
        self.dst_diff = dst_offset - std_offset
        std_offset = _load_timedelta(std_offset)
        self.std = _ttinfo(
            utcoff=std_offset, dstoff=_load_timedelta(0), tzname=std_abbr
        )

        self.start = start
        self.end = end

        dst_offset = _load_timedelta(dst_offset)
        delta = _load_timedelta(self.dst_diff)
        self.dst = _ttinfo(utcoff=dst_offset, dstoff=delta, tzname=dst_abbr)

        # These are assertions because the constructor should only be called
        # by functions that would fail before passing start or end
        assert start is not None, "No transition start specified"
        assert end is not None, "No transition end specified"

        self.get_trans_info = self._get_trans_info
        self.get_trans_info_fromutc = self._get_trans_info_fromutc

    def transitions(self, year):
        start = self.start.year_to_epoch(year)
        end = self.end.year_to_epoch(year)
        return start, end

    def _get_trans_info(self, ts, year, fold):
        """Get the information about the current transition - tti"""
        start, end = self.transitions(year)

        # With fold = 0, the period (denominated in local time) with the
        # smaller offset starts at the end of the gap and ends at the end of
        # the fold; with fold = 1, it runs from the start of the gap to the
        # beginning of the fold.
        #
        # So in order to determine the DST boundaries we need to know both
        # the fold and whether DST is positive or negative (rare), and it
        # turns out that this boils down to fold XOR is_positive.
        if fold == (self.dst_diff >= 0):
            end -= self.dst_diff
        else:
            start += self.dst_diff

        if start < end:
            isdst = start <= ts < end
        else:
            isdst = not (end <= ts < start)

        return self.dst if isdst else self.std

    def _get_trans_info_fromutc(self, ts, year):
        start, end = self.transitions(year)
        start -= self.std.utcoff.total_seconds()
        end -= self.dst.utcoff.total_seconds()

        if start < end:
            isdst = start <= ts < end
        else:
            isdst = not (end <= ts < start)

        # For positive DST, the ambiguous period is one dst_diff after the end
        # of DST; for negative DST, the ambiguous period is one dst_diff before
        # the start of DST.
        if self.dst_diff > 0:
            ambig_start = end
            ambig_end = end + self.dst_diff
        else:
            ambig_start = start
            ambig_end = start - self.dst_diff

        fold = ambig_start <= ts < ambig_end

        return (self.dst if isdst else self.std, fold)


def _post_epoch_days_before_year(year):
    """Get the number of days between 1970-01-01 and YEAR-01-01"""
    y = year - 1
    return y * 365 + y // 4 - y // 100 + y // 400 - EPOCHORDINAL


class _DayOffset:
    __slots__ = ["d", "julian", "hour", "minute", "second"]

    def __init__(self, d, julian, hour=2, minute=0, second=0):
        if not (0 + julian) <= d <= 365:
            min_day = 0 + julian
            raise ValueError(f"d must be in [{min_day}, 365], not: {d}")

        self.d = d
        self.julian = julian
        self.hour = hour
        self.minute = minute
        self.second = second

    def year_to_epoch(self, year):
        days_before_year = _post_epoch_days_before_year(year)

        d = self.d
        if self.julian and d >= 59 and calendar.isleap(year):
            d += 1

        epoch = (days_before_year + d) * 86400
        epoch += self.hour * 3600 + self.minute * 60 + self.second

        return epoch


class _CalendarOffset:
    __slots__ = ["m", "w", "d", "hour", "minute", "second"]

    _DAYS_BEFORE_MONTH = (
        -1,
        0,
        31,
        59,
        90,
        120,
        151,
        181,
        212,
        243,
        273,
        304,
        334,
    )

    def __init__(self, m, w, d, hour=2, minute=0, second=0):
        if not 0 < m <= 12:
            raise ValueError("m must be in (0, 12]")

        if not 0 < w <= 5:
            raise ValueError("w must be in (0, 5]")

        if not 0 <= d <= 6:
            raise ValueError("d must be in [0, 6]")

        self.m = m
        self.w = w
        self.d = d
        self.hour = hour
        self.minute = minute
        self.second = second

    @classmethod
    def _ymd2ord(cls, year, month, day):
        return (
            _post_epoch_days_before_year(year)
            + cls._DAYS_BEFORE_MONTH[month]
            + (month > 2 and calendar.isleap(year))
            + day
        )

    # TODO: These are not actually epoch dates as they are expressed in local time
    def year_to_epoch(self, year):
        """Calculates the datetime of the occurrence from the year"""
        # We know year and month, we need to convert w, d into day of month
        #
        # Week 1 is the first week in which day `d` (where 0 = Sunday) appears.
        # Week 5 represents the last occurrence of day `d`, so we need to know
        # the range of the month.
        first_day, days_in_month = calendar.monthrange(year, self.m)

        # This equation seems magical, so I'll break it down:
        # 1. calendar says 0 = Monday, POSIX says 0 = Sunday
        #    so we need first_day + 1 to get 1 = Monday -> 7 = Sunday,
        #    which is still equivalent because this math is mod 7
        # 2. Get first day - desired day mod 7: -1 % 7 = 6, so we don't need
        #    to do anything to adjust negative numbers.
        # 3. Add 1 because month days are a 1-based index.
        month_day = (self.d - (first_day + 1)) % 7 + 1

        # Now use a 0-based index version of `w` to calculate the w-th
        # occurrence of `d`
        month_day += (self.w - 1) * 7

        # month_day will only be > days_in_month if w was 5, and `w` means
        # "last occurrence of `d`", so now we just check if we over-shot the
        # end of the month and if so knock off 1 week.
        if month_day > days_in_month:
            month_day -= 7

        ordinal = self._ymd2ord(year, self.m, month_day)
        epoch = ordinal * 86400
        epoch += self.hour * 3600 + self.minute * 60 + self.second
        return epoch


def _parse_tz_str(tz_str):
    # The tz string has the format:
    #
    # std[offset[dst[offset],start[/time],end[/time]]]
    #
    # std and dst must be 3 or more characters long and must not contain
    # a leading colon, embedded digits, commas, nor plus or minus signs;
    # The spaces between "std" and "offset" are only for display and are
    # not actually present in the string.
    #
    # The format of the offset is ``[+|-]hh[:mm[:ss]]``

    offset_str, *start_end_str = tz_str.split(",", 1)

    # fmt: off
    parser_re = re.compile(
        r"(?P<std>[^<0-9:.+-]+|<[a-zA-Z0-9+\-]+>)" +
        r"((?P<stdoff>[+-]?\d{1,2}(:\d{2}(:\d{2})?)?)" +
            r"((?P<dst>[^0-9:.+-]+|<[a-zA-Z0-9+\-]+>)" +
                r"((?P<dstoff>[+-]?\d{1,2}(:\d{2}(:\d{2})?)?))?" +
            r")?" +  # dst
        r")?$"  # stdoff
    )
    # fmt: on

    m = parser_re.match(offset_str)

    if m is None:
        raise ValueError(f"{tz_str} is not a valid TZ string")

    std_abbr = m.group("std")
    dst_abbr = m.group("dst")
    dst_offset = None

    std_abbr = std_abbr.strip("<>")

    if dst_abbr:
        dst_abbr = dst_abbr.strip("<>")

    std_offset = m.group("stdoff")
    if std_offset:
        try:
            std_offset = _parse_tz_delta(std_offset)
        except ValueError as e:
            raise ValueError(f"Invalid STD offset in {tz_str}") from e
    else:
        std_offset = 0

    if dst_abbr is not None:
        dst_offset = m.group("dstoff")
        if dst_offset:
            try:
                dst_offset = _parse_tz_delta(dst_offset)
            except ValueError as e:
                raise ValueError(f"Invalid DST offset in {tz_str}") from e
        else:
            dst_offset = std_offset + 3600

        if not start_end_str:
            raise ValueError(f"Missing transition rules: {tz_str}")

        start_end_strs = start_end_str[0].split(",", 1)
        try:
            start, end = (_parse_dst_start_end(x) for x in start_end_strs)
        except ValueError as e:
            raise ValueError(f"Invalid TZ string: {tz_str}") from e

        return _TZStr(std_abbr, std_offset, dst_abbr, dst_offset, start, end)
    elif start_end_str:
        raise ValueError(f"Transition rule present without DST: {tz_str}")
    else:
        # This is a static ttinfo, don't return _TZStr
        return _ttinfo(
            _load_timedelta(std_offset), _load_timedelta(0), std_abbr
        )


def _parse_dst_start_end(dststr):
    date, *time = dststr.split("/")
    if date[0] == "M":
        n_is_julian = False
        m = re.match(r"M(\d{1,2})\.(\d).(\d)$", date)
        if m is None:
            raise ValueError(f"Invalid dst start/end date: {dststr}")
        date_offset = tuple(map(int, m.groups()))
        offset = _CalendarOffset(*date_offset)
    else:
        if date[0] == "J":
            n_is_julian = True
            date = date[1:]
        else:
            n_is_julian = False

        doy = int(date)
        offset = _DayOffset(doy, n_is_julian)

    if time:
        time_components = list(map(int, time[0].split(":")))
        n_components = len(time_components)
        if n_components < 3:
            time_components.extend([0] * (3 - n_components))
        offset.hour, offset.minute, offset.second = time_components

    return offset


def _parse_tz_delta(tz_delta):
    match = re.match(
        r"(?P<sign>[+-])?(?P<h>\d{1,2})(:(?P<m>\d{2})(:(?P<s>\d{2}))?)?",
        tz_delta,
    )
    # Anything passed to this function should already have hit an equivalent
    # regular expression to find the section to parse.
    assert match is not None, tz_delta

    h, m, s = (
        int(v) if v is not None else 0
        for v in map(match.group, ("h", "m", "s"))
    )

    total = h * 3600 + m * 60 + s

    if not -86400 < total < 86400:
        raise ValueError(
            "Offset must be strictly between -24h and +24h:" + tz_delta
        )

    # Yes, +5 maps to an offset of -5h
    if match.group("sign") != "-":
        total *= -1

    return total
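The sign inversion at the end of _parse_tz_delta encodes POSIX TZ semantics: the offset in a TZ string is what you add to local time to reach UTC, so the "5" in EST5EDT means UTC-5. A sketch of the end-to-end effect through the public API, assuming the EST5EDT key is available from tzdata:

from datetime import datetime
from backports.zoneinfo import ZoneInfo

tz = ZoneInfo("EST5EDT")
winter = datetime(2021, 1, 15, 12, 0, tzinfo=tz)
summer = datetime(2021, 7, 15, 12, 0, tzinfo=tz)
print(winter.tzname(), winter.utcoffset())  # EST -1 day, 19:00:00 (UTC-5)
print(summer.tzname(), summer.utcoffset())  # EDT -1 day, 20:00:00 (UTC-4)
print(summer.dst())                         # 1:00:00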
@@ -1 +0,0 @@
-__version__ = '1.10.0'
@@ -1,169 +0,0 @@
from __future__ import absolute_import
import sys

# True if we are running on Python 2.
PY2 = sys.version_info[0] == 2
PYVER = sys.version_info[:2]
JYTHON = sys.platform.startswith('java')

if PY2 and not JYTHON:  # pragma: no cover
    import cPickle as pickle
else:  # pragma: no cover
    import pickle


if not PY2:  # pragma: no cover
    xrange_ = range
    NoneType = type(None)

    string_type = str
    unicode_text = str
    byte_string = bytes

    from urllib.parse import urlencode as url_encode
    from urllib.parse import quote as url_quote
    from urllib.parse import unquote as url_unquote
    from urllib.parse import urlparse as url_parse
    from urllib.request import url2pathname
    import http.cookies as http_cookies
    from base64 import b64decode as _b64decode, b64encode as _b64encode

    try:
        import dbm as anydbm
    except:
        import dumbdbm as anydbm

    def b64decode(b):
        return _b64decode(b.encode('ascii'))

    def b64encode(s):
        return _b64encode(s).decode('ascii')

    def u_(s):
        return str(s)

    def bytes_(s):
        if isinstance(s, byte_string):
            return s
        return str(s).encode('ascii', 'strict')

    def dictkeyslist(d):
        return list(d.keys())

else:
    xrange_ = xrange
    from types import NoneType

    string_type = basestring
    unicode_text = unicode
    byte_string = str

    from urllib import urlencode as url_encode
    from urllib import quote as url_quote
    from urllib import unquote as url_unquote
    from urlparse import urlparse as url_parse
    from urllib import url2pathname
    import Cookie as http_cookies
    from base64 import b64decode, b64encode
    import anydbm

    def u_(s):
        if isinstance(s, unicode_text):
            return s

        if not isinstance(s, byte_string):
            s = str(s)
        return unicode(s, 'utf-8')

    def bytes_(s):
        if isinstance(s, byte_string):
            return s
        return str(s)

    def dictkeyslist(d):
        return d.keys()


def im_func(f):
    if not PY2:  # pragma: no cover
        return getattr(f, '__func__', None)
    else:
        return getattr(f, 'im_func', None)


def default_im_func(f):
    if not PY2:  # pragma: no cover
        return getattr(f, '__func__', f)
    else:
        return getattr(f, 'im_func', f)


def im_self(f):
    if not PY2:  # pragma: no cover
        return getattr(f, '__self__', None)
    else:
        return getattr(f, 'im_self', None)


def im_class(f):
    if not PY2:  # pragma: no cover
        self = im_self(f)
        if self is not None:
            return self.__class__
        else:
            return None
    else:
        return getattr(f, 'im_class', None)


def add_metaclass(metaclass):
    """Class decorator for creating a class with a metaclass."""
    def wrapper(cls):
        orig_vars = cls.__dict__.copy()
        slots = orig_vars.get('__slots__')
        if slots is not None:
            if isinstance(slots, str):
                slots = [slots]
            for slots_var in slots:
                orig_vars.pop(slots_var)
        orig_vars.pop('__dict__', None)
        orig_vars.pop('__weakref__', None)
        return metaclass(cls.__name__, cls.__bases__, orig_vars)
    return wrapper


if not PY2:  # pragma: no cover
    import builtins
    exec_ = getattr(builtins, "exec")

    def reraise(tp, value, tb=None):
        if value.__traceback__ is not tb:
            raise value.with_traceback(tb)
        raise value
else:  # pragma: no cover
    def exec_(code, globs=None, locs=None):
        """Execute code in a namespace."""
        if globs is None:
|
||||
frame = sys._getframe(1)
|
||||
globs = frame.f_globals
|
||||
if locs is None:
|
||||
locs = frame.f_locals
|
||||
del frame
|
||||
elif locs is None:
|
||||
locs = globs
|
||||
exec("""exec code in globs, locs""")
|
||||
|
||||
exec_("""def reraise(tp, value, tb=None):
|
||||
raise tp, value, tb
|
||||
""")
|
||||
|
||||
|
||||
try:
|
||||
from inspect import signature as func_signature
|
||||
except ImportError:
|
||||
from funcsigs import signature as func_signature
|
||||
|
||||
|
||||
def bindfuncargs(arginfo, args, kwargs):
|
||||
boundargs = arginfo.bind(*args, **kwargs)
|
||||
return boundargs.args, boundargs.kwargs
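
As a usage sketch for the add_metaclass decorator above (the Meta and Thing names are hypothetical), the same code works on Python 2 and 3:

    class Meta(type):
        def __new__(mcs, name, bases, ns):
            ns['tagged'] = True
            return super(Meta, mcs).__new__(mcs, name, bases, ns)

    @add_metaclass(Meta)
    class Thing(object):
        pass

    assert Thing.tagged
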
@@ -1,615 +0,0 @@
"""This package contains the "front end" classes and functions
for Beaker caching.

Included are the :class:`.Cache` and :class:`.CacheManager` classes,
as well as the function decorators :func:`.region_decorate`,
:func:`.region_invalidate`.

"""
import warnings
from itertools import chain

from beaker._compat import u_, unicode_text, func_signature, bindfuncargs
import beaker.container as container
import beaker.util as util
from beaker.crypto.util import sha1
from beaker.exceptions import BeakerException, InvalidCacheBackendError
from beaker.synchronization import _threading

import beaker.ext.memcached as memcached
import beaker.ext.database as database
import beaker.ext.sqla as sqla
import beaker.ext.google as google
import beaker.ext.mongodb as mongodb
import beaker.ext.redisnm as redisnm
from functools import wraps

# Initialize the cache region dict
cache_regions = {}
"""Dictionary of 'region' arguments.

A "region" is a string name that refers to a series of cache
configuration arguments.  An application may have multiple
"regions" - one which stores things in a memory cache, one
which writes data to files, etc.

The dictionary stores string key names mapped to dictionaries
of configuration arguments.  Example::

    from beaker.cache import cache_regions
    cache_regions.update({
        'short_term':{
            'expire':60,
            'type':'memory'
        },
        'long_term':{
            'expire':1800,
            'type':'dbm',
            'data_dir':'/tmp',
        }
    })
"""


cache_managers = {}


class _backends(object):
    initialized = False

    def __init__(self, clsmap):
        self._clsmap = clsmap
        self._mutex = _threading.Lock()

    def __getitem__(self, key):
        try:
            return self._clsmap[key]
        except KeyError as e:
            if not self.initialized:
                self._mutex.acquire()
                try:
                    if not self.initialized:
                        self._init()
                        self.initialized = True

                    return self._clsmap[key]
                finally:
                    self._mutex.release()

            raise e

    def _init(self):
        try:
            import pkg_resources

            # Load up the additional entry point defined backends
            for entry_point in pkg_resources.iter_entry_points('beaker.backends'):
                try:
                    namespace_manager = entry_point.load()
                    name = entry_point.name
                    if name in self._clsmap:
                        raise BeakerException("NamespaceManager name conflict,'%s' "
                                              "already loaded" % name)
                    self._clsmap[name] = namespace_manager
                except (InvalidCacheBackendError, SyntaxError):
                    # Ignore invalid backends
                    pass
                except:
                    import sys
                    from pkg_resources import DistributionNotFound
                    # Warn when there's a problem loading a NamespaceManager
                    if not isinstance(sys.exc_info()[1], DistributionNotFound):
                        import traceback
                        try:
                            from StringIO import StringIO  # Python2
                        except ImportError:
                            from io import StringIO  # Python3

                        tb = StringIO()
                        traceback.print_exc(file=tb)
                        warnings.warn(
                            "Unable to load NamespaceManager "
                            "entry point: '%s': %s" % (
                                entry_point,
                                tb.getvalue()),
                            RuntimeWarning, 2)
        except ImportError:
            pass

# Initialize the basic available backends
clsmap = _backends({
    'memory': container.MemoryNamespaceManager,
    'dbm': container.DBMNamespaceManager,
    'file': container.FileNamespaceManager,
    'ext:memcached': memcached.MemcachedNamespaceManager,
    'ext:database': database.DatabaseNamespaceManager,
    'ext:sqla': sqla.SqlaNamespaceManager,
    'ext:google': google.GoogleNamespaceManager,
    'ext:mongodb': mongodb.MongoNamespaceManager,
    'ext:redis': redisnm.RedisNamespaceManager
})


def cache_region(region, *args):
    """Decorate a function such that its return result is cached,
    using a "region" to indicate the cache arguments.

    Example::

        from beaker.cache import cache_regions, cache_region

        # configure regions
        cache_regions.update({
            'short_term':{
                'expire':60,
                'type':'memory'
            }
        })

        @cache_region('short_term', 'load_things')
        def load(search_term, limit, offset):
            '''Load from a database given a search term, limit, offset.'''
            return database.query(search_term)[offset:offset + limit]

    The decorator can also be used with object methods.  The ``self``
    argument is not part of the cache key.  This is based on the
    actual string name ``self`` being in the first argument
    position (new in 1.6)::

        class MyThing(object):
            @cache_region('short_term', 'load_things')
            def load(self, search_term, limit, offset):
                '''Load from a database given a search term, limit, offset.'''
                return database.query(search_term)[offset:offset + limit]

    Classmethods work as well - use ``cls`` as the name of the class argument,
    and place the decorator around the function underneath ``@classmethod``
    (new in 1.6)::

        class MyThing(object):
            @classmethod
            @cache_region('short_term', 'load_things')
            def load(cls, search_term, limit, offset):
                '''Load from a database given a search term, limit, offset.'''
                return database.query(search_term)[offset:offset + limit]

    :param region: String name of the region corresponding to the desired
      caching arguments, established in :attr:`.cache_regions`.

    :param \*args: Optional ``str()``-compatible arguments which will uniquely
      identify the key used by this decorated function, in addition
      to the positional arguments passed to the function itself at call time.
      This is recommended as it is needed to distinguish between any two functions
      or methods that have the same name (regardless of parent class or not).

    .. note::

        The function being decorated must only be called with
        positional arguments, and the arguments must support
        being stringified with ``str()``.  The concatenation
        of the ``str()`` version of each argument, combined
        with that of the ``*args`` sent to the decorator,
        forms the unique cache key.

    .. note::

        When a method on a class is decorated, the ``self`` or ``cls``
        argument in the first position is
        not included in the "key" used for caching.  New in 1.6.

    """
    return _cache_decorate(args, None, None, region)


def region_invalidate(namespace, region, *args):
    """Invalidate a cache region corresponding to a function
    decorated with :func:`.cache_region`.

    :param namespace: The namespace of the cache to invalidate.  This is typically
      a reference to the original function (as returned by the :func:`.cache_region`
      decorator), where the :func:`.cache_region` decorator applies a "memo" to
      the function in order to locate the string name of the namespace.

    :param region: String name of the region used with the decorator.  This can be
      ``None`` in the usual case that the decorated function itself is passed,
      not the string name of the namespace.

    :param args: Stringifyable arguments that are used to locate the correct
      key.  This consists of the ``*args`` sent to the :func:`.cache_region`
      decorator itself, plus the ``*args`` sent to the function itself
      at runtime.

    Example::

        from beaker.cache import cache_regions, cache_region, region_invalidate

        # configure regions
        cache_regions.update({
            'short_term':{
                'expire':60,
                'type':'memory'
            }
        })

        @cache_region('short_term', 'load_data')
        def load(search_term, limit, offset):
            '''Load from a database given a search term, limit, offset.'''
            return database.query(search_term)[offset:offset + limit]

        def invalidate_search(search_term, limit, offset):
            '''Invalidate the cached storage for a given search term, limit, offset.'''
            region_invalidate(load, 'short_term', 'load_data', search_term, limit, offset)

    Note that when a method on a class is decorated, the first argument ``cls``
    or ``self`` is not included in the cache key.  This means you don't send
    it to :func:`.region_invalidate`::

        class MyThing(object):
            @cache_region('short_term', 'some_data')
            def load(self, search_term, limit, offset):
                '''Load from a database given a search term, limit, offset.'''
                return database.query(search_term)[offset:offset + limit]

            def invalidate_search(self, search_term, limit, offset):
                '''Invalidate the cached storage for a given search term, limit, offset.'''
                region_invalidate(self.load, 'short_term', 'some_data', search_term, limit, offset)

    """
    if callable(namespace):
        if not region:
            region = namespace._arg_region
        namespace = namespace._arg_namespace

    if not region:
        raise BeakerException("Region or callable function "
                              "namespace is required")
    else:
        region = cache_regions[region]

    cache = Cache._get_cache(namespace, region)
    _cache_decorator_invalidate(cache,
                                region.get('key_length', util.DEFAULT_CACHE_KEY_LENGTH),
                                args)


class Cache(object):
    """Front-end to the containment API implementing a data cache.

    :param namespace: the namespace of this Cache

    :param type: type of cache to use

    :param expire: seconds to keep cached data

    :param expiretime: seconds to keep cached data (legacy support)

    :param starttime: time when the cache was started

    """
    def __init__(self, namespace, type='memory', expiretime=None,
                 starttime=None, expire=None, **nsargs):
        try:
            cls = clsmap[type]
            if isinstance(cls, InvalidCacheBackendError):
                raise cls
        except KeyError:
            raise TypeError("Unknown cache implementation %r" % type)

        if expire is not None:
            expire = int(expire)

        self.namespace_name = namespace
        self.namespace = cls(namespace, **nsargs)
        self.expiretime = expiretime or expire
        self.starttime = starttime
        self.nsargs = nsargs

    @classmethod
    def _get_cache(cls, namespace, kw):
        key = namespace + str(kw)
        try:
            return cache_managers[key]
        except KeyError:
            cache_managers[key] = cache = cls(namespace, **kw)
            return cache

    def put(self, key, value, **kw):
        self._get_value(key, **kw).set_value(value)
    set_value = put

    def get(self, key, **kw):
        """Retrieve a cached value from the container"""
        return self._get_value(key, **kw).get_value()
    get_value = get

    def remove_value(self, key, **kw):
        mycontainer = self._get_value(key, **kw)
        mycontainer.clear_value()
    remove = remove_value

    def _get_value(self, key, **kw):
        if isinstance(key, unicode_text):
            key = key.encode('ascii', 'backslashreplace')

        if 'type' in kw:
            return self._legacy_get_value(key, **kw)

        kw.setdefault('expiretime', self.expiretime)
        kw.setdefault('starttime', self.starttime)

        return container.Value(key, self.namespace, **kw)

    @util.deprecated("Specifying a "
                     "'type' and other namespace configuration with cache.get()/put()/etc. "
                     "is deprecated. Specify 'type' and other namespace configuration to "
                     "cache_manager.get_cache() and/or the Cache constructor instead.")
    def _legacy_get_value(self, key, type, **kw):
        expiretime = kw.pop('expiretime', self.expiretime)
        starttime = kw.pop('starttime', None)
        createfunc = kw.pop('createfunc', None)
        kwargs = self.nsargs.copy()
        kwargs.update(kw)
        c = Cache(self.namespace.namespace, type=type, **kwargs)
        return c._get_value(key, expiretime=expiretime, createfunc=createfunc,
                            starttime=starttime)

    def clear(self):
        """Clear all the values from the namespace"""
        self.namespace.remove()

    # dict interface
    def __getitem__(self, key):
        return self.get(key)

    def __contains__(self, key):
        return self._get_value(key).has_current_value()

    def has_key(self, key):
        return key in self

    def __delitem__(self, key):
        self.remove_value(key)

    def __setitem__(self, key, value):
        self.put(key, value)


class CacheManager(object):
    def __init__(self, **kwargs):
        """Initialize a CacheManager object with a set of options

        Options should be parsed with the
        :func:`~beaker.util.parse_cache_config_options` function to
        ensure only valid options are used.

        """
        self.kwargs = kwargs
        self.regions = kwargs.pop('cache_regions', {})

        # Add these regions to the module global
        cache_regions.update(self.regions)

    def get_cache(self, name, **kwargs):
        kw = self.kwargs.copy()
        kw.update(kwargs)
        return Cache._get_cache(name, kw)

    def get_cache_region(self, name, region):
        if region not in self.regions:
            raise BeakerException('Cache region not configured: %s' % region)
        kw = self.regions[region]
        return Cache._get_cache(name, kw)

    def region(self, region, *args):
        """Decorate a function to cache itself using a cache region

        The region decorator requires arguments if there are more than
        two of the same named function, in the same module. This is
        because the namespace used for the functions cache is based on
        the functions name and the module.


        Example::

            # Assuming a cache object is available like:
            cache = CacheManager(dict_of_config_options)


            def populate_things():

                @cache.region('short_term', 'some_data')
                def load(search_term, limit, offset):
                    return load_the_data(search_term, limit, offset)

                return load('rabbits', 20, 0)

        .. note::

            The function being decorated must only be called with
            positional arguments.

        """
        return cache_region(region, *args)

    def region_invalidate(self, namespace, region, *args):
        """Invalidate a cache region namespace or decorated function

        This function only invalidates cache spaces created with the
        cache_region decorator.

        :param namespace: Either the namespace of the result to invalidate, or the
            cached function

        :param region: The region the function was cached to. If the function was
            cached to a single region then this argument can be None

        :param args: Arguments that were used to differentiate the cached
            function as well as the arguments passed to the decorated
            function

        Example::

            # Assuming a cache object is available like:
            cache = CacheManager(dict_of_config_options)

            def populate_things(invalidate=False):

                @cache.region('short_term', 'some_data')
                def load(search_term, limit, offset):
                    return load_the_data(search_term, limit, offset)

                # If the results should be invalidated first
                if invalidate:
                    cache.region_invalidate(load, None, 'some_data',
                                            'rabbits', 20, 0)
                return load('rabbits', 20, 0)


        """
        return region_invalidate(namespace, region, *args)

    def cache(self, *args, **kwargs):
        """Decorate a function to cache itself with supplied parameters

        :param args: Used to make the key unique for this function, as in region()
            above.

        :param kwargs: Parameters to be passed to get_cache(), will override defaults

        Example::

            # Assuming a cache object is available like:
            cache = CacheManager(dict_of_config_options)


            def populate_things():

                @cache.cache('mycache', expire=15)
                def load(search_term, limit, offset):
                    return load_the_data(search_term, limit, offset)

                return load('rabbits', 20, 0)

        .. note::

            The function being decorated must only be called with
            positional arguments.

        """
        return _cache_decorate(args, self, kwargs, None)

    def invalidate(self, func, *args, **kwargs):
        """Invalidate a cache decorated function

        This function only invalidates cache spaces created with the
        cache decorator.

        :param func: Decorated function to invalidate

        :param args: Used to make the key unique for this function, as in region()
            above.

        :param kwargs: Parameters that were passed for use by get_cache(), note that
            this is only required if a ``type`` was specified for the
            function

        Example::

            # Assuming a cache object is available like:
            cache = CacheManager(dict_of_config_options)


            def populate_things(invalidate=False):

                @cache.cache('mycache', type="file", expire=15)
                def load(search_term, limit, offset):
                    return load_the_data(search_term, limit, offset)

                # If the results should be invalidated first
                if invalidate:
                    cache.invalidate(load, 'mycache', 'rabbits', 20, 0, type="file")
                return load('rabbits', 20, 0)

        """
        namespace = func._arg_namespace

        cache = self.get_cache(namespace, **kwargs)
        if hasattr(func, '_arg_region'):
            cachereg = cache_regions[func._arg_region]
            key_length = cachereg.get('key_length', util.DEFAULT_CACHE_KEY_LENGTH)
        else:
            key_length = kwargs.pop('key_length', util.DEFAULT_CACHE_KEY_LENGTH)
        _cache_decorator_invalidate(cache, key_length, args)


def _cache_decorate(deco_args, manager, options, region):
    """Return a caching function decorator."""

    cache = [None]

    def decorate(func):
        namespace = util.func_namespace(func)
        skip_self = util.has_self_arg(func)
        signature = func_signature(func)

        @wraps(func)
        def cached(*args, **kwargs):
            if not cache[0]:
                if region is not None:
                    if region not in cache_regions:
                        raise BeakerException(
                            'Cache region not configured: %s' % region)
                    reg = cache_regions[region]
                    if not reg.get('enabled', True):
                        return func(*args, **kwargs)
                    cache[0] = Cache._get_cache(namespace, reg)
                elif manager:
                    cache[0] = manager.get_cache(namespace, **options)
                else:
                    raise Exception("'manager + kwargs' or 'region' "
                                    "argument is required")

            cache_key_kwargs = []
            if kwargs:
                # kwargs provided, merge them in positional args
                # to avoid having different cache keys.
                args, kwargs = bindfuncargs(signature, args, kwargs)
                cache_key_kwargs = [u_(':').join((u_(key), u_(value))) for key, value in kwargs.items()]

            cache_key_args = args
            if skip_self:
                cache_key_args = args[1:]

            cache_key = u_(" ").join(map(u_, chain(deco_args, cache_key_args, cache_key_kwargs)))

            if region:
                cachereg = cache_regions[region]
                key_length = cachereg.get('key_length', util.DEFAULT_CACHE_KEY_LENGTH)
            else:
                key_length = options.pop('key_length', util.DEFAULT_CACHE_KEY_LENGTH)

            # TODO: This is probably a bug as length is checked before converting to UTF8
            # which will cause cache_key to grow in size.
            if len(cache_key) + len(namespace) > int(key_length):
                cache_key = sha1(cache_key.encode('utf-8')).hexdigest()

            def go():
                return func(*args, **kwargs)
            # save org function name
            go.__name__ = '_cached_%s' % (func.__name__,)

            return cache[0].get_value(cache_key, createfunc=go)
        cached._arg_namespace = namespace
        if region is not None:
            cached._arg_region = region
        return cached
    return decorate


def _cache_decorator_invalidate(cache, key_length, args):
    """Invalidate a cache key based on function arguments."""

    cache_key = u_(" ").join(map(u_, args))
    if len(cache_key) + len(cache.namespace_name) > key_length:
        cache_key = sha1(cache_key.encode('utf-8')).hexdigest()
    cache.remove_value(cache_key)
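
A minimal sketch of the region decorator defined above, assuming the beaker package is importable; the region name and function are illustrative:

    from beaker.cache import cache_regions, cache_region

    cache_regions.update({'short_term': {'expire': 60, 'type': 'memory'}})

    @cache_region('short_term', 'fib')
    def fib(n):
        return n if n < 2 else fib(n - 1) + fib(n - 2)

    fib(30)  # computed once; calls within 60 seconds hit the memory region
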
@@ -1,760 +0,0 @@
"""Container and Namespace classes"""
import errno

from ._compat import pickle, anydbm, add_metaclass, PYVER, unicode_text

import beaker.util as util
import logging
import os
import time

from beaker.exceptions import CreationAbortedError, MissingCacheParameter
from beaker.synchronization import _threading, file_synchronizer, \
    mutex_synchronizer, NameLock, null_synchronizer

__all__ = ['Value', 'Container', 'ContainerContext',
           'MemoryContainer', 'DBMContainer', 'NamespaceManager',
           'MemoryNamespaceManager', 'DBMNamespaceManager', 'FileContainer',
           'OpenResourceNamespaceManager',
           'FileNamespaceManager', 'CreationAbortedError']


logger = logging.getLogger('beaker.container')
if logger.isEnabledFor(logging.DEBUG):
    debug = logger.debug
else:
    def debug(message, *args):
        pass


class NamespaceManager(object):
    """Handles dictionary operations and locking for a namespace of
    values.

    :class:`.NamespaceManager` provides a dictionary-like interface,
    implementing ``__getitem__()``, ``__setitem__()``, and
    ``__contains__()``, as well as functions related to lock
    acquisition.

    The implementation for setting and retrieving the namespace data is
    handled by subclasses.

    NamespaceManager may be used alone, or may be accessed by
    one or more :class:`.Value` objects.  :class:`.Value` objects provide per-key
    services like expiration times and automatic recreation of values.

    Multiple NamespaceManagers created with a particular name will all
    share access to the same underlying datasource and will attempt to
    synchronize against a common mutex object.  The scope of this
    sharing may be within a single process or across multiple
    processes, depending on the type of NamespaceManager used.

    The NamespaceManager itself is generally threadsafe, except in the
    case of the DBMNamespaceManager in conjunction with the gdbm dbm
    implementation.

    """

    @classmethod
    def _init_dependencies(cls):
        """Initialize module-level dependent libraries required
        by this :class:`.NamespaceManager`."""

    def __init__(self, namespace):
        self._init_dependencies()
        self.namespace = namespace

    def get_creation_lock(self, key):
        """Return a locking object that is used to synchronize
        multiple threads or processes which wish to generate a new
        cache value.

        This function is typically an instance of
        :class:`.FileSynchronizer`, :class:`.ConditionSynchronizer`,
        or :class:`.null_synchronizer`.

        The creation lock is only used when a requested value
        does not exist, or has been expired, and is only used
        by the :class:`.Value` key-management object in conjunction
        with a "createfunc" value-creation function.

        """
        raise NotImplementedError()

    def do_remove(self):
        """Implement removal of the entire contents of this
        :class:`.NamespaceManager`.

        e.g. for a file-based namespace, this would remove
        all the files.

        The front-end to this method is the
        :meth:`.NamespaceManager.remove` method.

        """
        raise NotImplementedError()

    def acquire_read_lock(self):
        """Establish a read lock.

        This operation is called before a key is read.  By
        default the function does nothing.

        """

    def release_read_lock(self):
        """Release a read lock.

        This operation is called after a key is read.  By
        default the function does nothing.

        """

    def acquire_write_lock(self, wait=True, replace=False):
        """Establish a write lock.

        This operation is called before a key is written.
        A return value of ``True`` indicates the lock has
        been acquired.

        By default the function returns ``True`` unconditionally.

        'replace' is a hint indicating the full contents
        of the namespace may be safely discarded. Some backends
        may implement this (i.e. file backend won't unpickle the
        current contents).

        """
        return True

    def release_write_lock(self):
        """Release a write lock.

        This operation is called after a new value is written.
        By default this function does nothing.

        """

    def has_key(self, key):
        """Return ``True`` if the given key is present in this
        :class:`.Namespace`.
        """
        return self.__contains__(key)

    def __getitem__(self, key):
        raise NotImplementedError()

    def __setitem__(self, key, value):
        raise NotImplementedError()

    def set_value(self, key, value, expiretime=None):
        """Sets a value in this :class:`.NamespaceManager`.

        This is the same as ``__setitem__()``, but
        also allows an expiration time to be passed
        at the same time.

        """
        self[key] = value

    def __contains__(self, key):
        raise NotImplementedError()

    def __delitem__(self, key):
        raise NotImplementedError()

    def keys(self):
        """Return the list of all keys.

        This method may not be supported by all
        :class:`.NamespaceManager` implementations.

        """
        raise NotImplementedError()

    def remove(self):
        """Remove the entire contents of this
        :class:`.NamespaceManager`.

        e.g. for a file-based namespace, this would remove
        all the files.
        """
        self.do_remove()


class OpenResourceNamespaceManager(NamespaceManager):
    """A NamespaceManager where read/write operations require opening/
    closing of a resource which is possibly mutexed.

    """
    def __init__(self, namespace):
        NamespaceManager.__init__(self, namespace)
        self.access_lock = self.get_access_lock()
        self.openers = 0
        self.mutex = _threading.Lock()

    def get_access_lock(self):
        raise NotImplementedError()

    def do_open(self, flags, replace):
        raise NotImplementedError()

    def do_close(self):
        raise NotImplementedError()

    def acquire_read_lock(self):
        self.access_lock.acquire_read_lock()
        try:
            self.open('r', checkcount=True)
        except:
            self.access_lock.release_read_lock()
            raise

    def release_read_lock(self):
        try:
            self.close(checkcount=True)
        finally:
            self.access_lock.release_read_lock()

    def acquire_write_lock(self, wait=True, replace=False):
        r = self.access_lock.acquire_write_lock(wait)
        try:
            if (wait or r):
                self.open('c', checkcount=True, replace=replace)
            return r
        except:
            self.access_lock.release_write_lock()
            raise

    def release_write_lock(self):
        try:
            self.close(checkcount=True)
        finally:
            self.access_lock.release_write_lock()

    def open(self, flags, checkcount=False, replace=False):
        self.mutex.acquire()
        try:
            if checkcount:
                if self.openers == 0:
                    self.do_open(flags, replace)
                self.openers += 1
            else:
                self.do_open(flags, replace)
                self.openers = 1
        finally:
            self.mutex.release()

    def close(self, checkcount=False):
        self.mutex.acquire()
        try:
            if checkcount:
                self.openers -= 1
                if self.openers == 0:
                    self.do_close()
            else:
                if self.openers > 0:
                    self.do_close()
                self.openers = 0
        finally:
            self.mutex.release()

    def remove(self):
        self.access_lock.acquire_write_lock()
        try:
            self.close(checkcount=False)
            self.do_remove()
        finally:
            self.access_lock.release_write_lock()


class Value(object):
    """Implements synchronization, expiration, and value-creation logic
    for a single value stored in a :class:`.NamespaceManager`.

    """

    __slots__ = 'key', 'createfunc', 'expiretime', 'expire_argument', 'starttime', 'storedtime',\
                'namespace'

    def __init__(self, key, namespace, createfunc=None, expiretime=None, starttime=None):
        self.key = key
        self.createfunc = createfunc
        self.expire_argument = expiretime
        self.starttime = starttime
        self.storedtime = -1
        self.namespace = namespace

    def has_value(self):
        """return true if the container has a value stored.

        This is regardless of it being expired or not.

        """
        self.namespace.acquire_read_lock()
        try:
            return self.key in self.namespace
        finally:
            self.namespace.release_read_lock()

    def can_have_value(self):
        return self.has_current_value() or self.createfunc is not None

    def has_current_value(self):
        self.namespace.acquire_read_lock()
        try:
            has_value = self.key in self.namespace
            if has_value:
                try:
                    stored, expired, value = self._get_value()
                    return not self._is_expired(stored, expired)
                except KeyError:
                    pass
            return False
        finally:
            self.namespace.release_read_lock()

    def _is_expired(self, storedtime, expiretime):
        """Return true if this container's value is expired."""
        return (
            (
                self.starttime is not None and
                storedtime < self.starttime
            )
            or
            (
                expiretime is not None and
                time.time() >= expiretime + storedtime
            )
        )

    def get_value(self):
        self.namespace.acquire_read_lock()
        try:
            has_value = self.has_value()
            if has_value:
                try:
                    stored, expired, value = self._get_value()
                    if not self._is_expired(stored, expired):
                        return value
                except KeyError:
                    # guard against un-mutexed backends raising KeyError
                    has_value = False

            if not self.createfunc:
                raise KeyError(self.key)
        finally:
            self.namespace.release_read_lock()

        has_createlock = False
        creation_lock = self.namespace.get_creation_lock(self.key)
        if has_value:
            if not creation_lock.acquire(wait=False):
                debug("get_value returning old value while new one is created")
                return value
            else:
                debug("lock_createfunc (didn't wait)")
                has_createlock = True

        if not has_createlock:
            debug("lock_createfunc (waiting)")
            creation_lock.acquire()
            debug("lock_createfunc (waited)")

        try:
            # see if someone created the value already
            self.namespace.acquire_read_lock()
            try:
                if self.has_value():
                    try:
                        stored, expired, value = self._get_value()
                        if not self._is_expired(stored, expired):
                            return value
                    except KeyError:
                        # guard against un-mutexed backends raising KeyError
                        pass
            finally:
                self.namespace.release_read_lock()

            debug("get_value creating new value")
            v = self.createfunc()
            self.set_value(v)
            return v
        finally:
            creation_lock.release()
            debug("released create lock")

    def _get_value(self):
        value = self.namespace[self.key]
        try:
            stored, expired, value = value
        except ValueError:
            if not len(value) == 2:
                raise
            # Old format: upgrade
            stored, value = value
            expired = self.expire_argument
            debug("get_value upgrading time %r expire time %r", stored, self.expire_argument)
            self.namespace.release_read_lock()
            self.set_value(value, stored)
            self.namespace.acquire_read_lock()
        except TypeError:
            # occurs when the value is None.  memcached
            # may yank the rug from under us in which case
            # that's the result
            raise KeyError(self.key)
        return stored, expired, value

    def set_value(self, value, storedtime=None):
        self.namespace.acquire_write_lock()
        try:
            if storedtime is None:
                storedtime = time.time()
            debug("set_value stored time %r expire time %r", storedtime, self.expire_argument)
            self.namespace.set_value(self.key, (storedtime, self.expire_argument, value),
                                     expiretime=self.expire_argument)
        finally:
            self.namespace.release_write_lock()

    def clear_value(self):
        self.namespace.acquire_write_lock()
        try:
            debug("clear_value")
            if self.key in self.namespace:
                try:
                    del self.namespace[self.key]
                except KeyError:
                    # guard against un-mutexed backends raising KeyError
                    pass
            self.storedtime = -1
        finally:
            self.namespace.release_write_lock()


class AbstractDictionaryNSManager(NamespaceManager):
    """A subclassable NamespaceManager that places data in a dictionary.

    Subclasses should provide a "dictionary" attribute or descriptor
    which returns a dict-like object.  The dictionary will store keys
    that are local to the "namespace" attribute of this manager, so
    ensure that the dictionary will not be used by any other namespace.

    e.g.::

        import collections
        cached_data = collections.defaultdict(dict)

        class MyDictionaryManager(AbstractDictionaryNSManager):
            def __init__(self, namespace):
                AbstractDictionaryNSManager.__init__(self, namespace)
                self.dictionary = cached_data[self.namespace]

    The above stores data in a global dictionary called "cached_data",
    which is structured as a dictionary of dictionaries, keyed
    first on namespace name to a sub-dictionary, then on actual
    cache key to value.

    """

    def get_creation_lock(self, key):
        return NameLock(
            identifier="memorynamespace/funclock/%s/%s" %
                       (self.namespace, key),
            reentrant=True
        )

    def __getitem__(self, key):
        return self.dictionary[key]

    def __contains__(self, key):
        return self.dictionary.__contains__(key)

    def has_key(self, key):
        return self.dictionary.__contains__(key)

    def __setitem__(self, key, value):
        self.dictionary[key] = value

    def __delitem__(self, key):
        del self.dictionary[key]

    def do_remove(self):
        self.dictionary.clear()

    def keys(self):
        return self.dictionary.keys()


class MemoryNamespaceManager(AbstractDictionaryNSManager):
    """:class:`.NamespaceManager` that uses a Python dictionary for storage."""

    namespaces = util.SyncDict()

    def __init__(self, namespace, **kwargs):
        AbstractDictionaryNSManager.__init__(self, namespace)
        self.dictionary = MemoryNamespaceManager.\
            namespaces.get(self.namespace, dict)


class DBMNamespaceManager(OpenResourceNamespaceManager):
    """:class:`.NamespaceManager` that uses ``dbm`` files for storage."""

    def __init__(self, namespace, dbmmodule=None, data_dir=None,
                 dbm_dir=None, lock_dir=None,
                 digest_filenames=True, **kwargs):
        self.digest_filenames = digest_filenames

        if not dbm_dir and not data_dir:
            raise MissingCacheParameter("data_dir or dbm_dir is required")
        elif dbm_dir:
            self.dbm_dir = dbm_dir
        else:
            self.dbm_dir = data_dir + "/container_dbm"
        util.verify_directory(self.dbm_dir)

        if not lock_dir and not data_dir:
            raise MissingCacheParameter("data_dir or lock_dir is required")
        elif lock_dir:
            self.lock_dir = lock_dir
        else:
            self.lock_dir = data_dir + "/container_dbm_lock"
        util.verify_directory(self.lock_dir)

        self.dbmmodule = dbmmodule or anydbm

        self.dbm = None
        OpenResourceNamespaceManager.__init__(self, namespace)

        self.file = util.encoded_path(root=self.dbm_dir,
                                      identifiers=[self.namespace],
                                      extension='.dbm',
                                      digest_filenames=self.digest_filenames)

        debug("data file %s", self.file)
        self._checkfile()

    def get_access_lock(self):
        return file_synchronizer(identifier=self.namespace,
                                 lock_dir=self.lock_dir)

    def get_creation_lock(self, key):
        return file_synchronizer(
            identifier="dbmcontainer/funclock/%s/%s" % (
                self.namespace, key
            ),
            lock_dir=self.lock_dir
        )

    def file_exists(self, file):
        if os.access(file, os.F_OK):
            return True
        else:
            for ext in ('db', 'dat', 'pag', 'dir'):
                if os.access(file + os.extsep + ext, os.F_OK):
                    return True

        return False

    def _ensuredir(self, filename):
        dirname = os.path.dirname(filename)
        if not os.path.exists(dirname):
            util.verify_directory(dirname)

    def _checkfile(self):
        if not self.file_exists(self.file):
            self._ensuredir(self.file)
            g = self.dbmmodule.open(self.file, 'c')
            g.close()

    def get_filenames(self):
        list = []
        if os.access(self.file, os.F_OK):
            list.append(self.file)

        for ext in ('pag', 'dir', 'db', 'dat'):
            if os.access(self.file + os.extsep + ext, os.F_OK):
                list.append(self.file + os.extsep + ext)
        return list

    def do_open(self, flags, replace):
        debug("opening dbm file %s", self.file)
        try:
            self.dbm = self.dbmmodule.open(self.file, flags)
        except:
            self._checkfile()
            self.dbm = self.dbmmodule.open(self.file, flags)

    def do_close(self):
        if self.dbm is not None:
            debug("closing dbm file %s", self.file)
            self.dbm.close()

    def do_remove(self):
        for f in self.get_filenames():
            os.remove(f)

    def __getitem__(self, key):
        return pickle.loads(self.dbm[key])

    def __contains__(self, key):
        if PYVER == (3, 2):
            # Looks like this is a bug that got solved in PY3.3 and PY3.4
            # http://bugs.python.org/issue19288
            if isinstance(key, unicode_text):
                key = key.encode('UTF-8')
        return key in self.dbm

    def __setitem__(self, key, value):
        self.dbm[key] = pickle.dumps(value)

    def __delitem__(self, key):
        del self.dbm[key]

    def keys(self):
        return self.dbm.keys()


class FileNamespaceManager(OpenResourceNamespaceManager):
    """:class:`.NamespaceManager` that uses binary files for storage.

    Each namespace is implemented as a single file storing a
    dictionary of key/value pairs, serialized using the Python
    ``pickle`` module.

    """
    def __init__(self, namespace, data_dir=None, file_dir=None, lock_dir=None,
                 digest_filenames=True, **kwargs):
        self.digest_filenames = digest_filenames

        if not file_dir and not data_dir:
            raise MissingCacheParameter("data_dir or file_dir is required")
        elif file_dir:
            self.file_dir = file_dir
        else:
            self.file_dir = data_dir + "/container_file"
        util.verify_directory(self.file_dir)

        if not lock_dir and not data_dir:
            raise MissingCacheParameter("data_dir or lock_dir is required")
        elif lock_dir:
            self.lock_dir = lock_dir
        else:
            self.lock_dir = data_dir + "/container_file_lock"
        util.verify_directory(self.lock_dir)
        OpenResourceNamespaceManager.__init__(self, namespace)

        self.file = util.encoded_path(root=self.file_dir,
                                      identifiers=[self.namespace],
                                      extension='.cache',
                                      digest_filenames=self.digest_filenames)
        self.hash = {}

        debug("data file %s", self.file)

    def get_access_lock(self):
        return file_synchronizer(identifier=self.namespace,
                                 lock_dir=self.lock_dir)

    def get_creation_lock(self, key):
        return file_synchronizer(
            identifier="dbmcontainer/funclock/%s/%s" % (
                self.namespace, key
            ),
            lock_dir=self.lock_dir
        )

    def file_exists(self, file):
        return os.access(file, os.F_OK)

    def do_open(self, flags, replace):
        if not replace and self.file_exists(self.file):
            try:
                with open(self.file, 'rb') as fh:
                    self.hash = pickle.load(fh)
            except IOError as e:
                # Ignore EACCES and ENOENT as it just means we are no longer
                # able to access the file or that it no longer exists
                if e.errno not in [errno.EACCES, errno.ENOENT]:
                    raise

        self.flags = flags

    def do_close(self):
        if self.flags == 'c' or self.flags == 'w':
            pickled = pickle.dumps(self.hash)
            util.safe_write(self.file, pickled)

        self.hash = {}
        self.flags = None

    def do_remove(self):
        try:
            os.remove(self.file)
        except OSError:
            # for instance, because we haven't yet used this cache,
            # but client code has asked for a clear() operation...
            pass
        self.hash = {}

    def __getitem__(self, key):
        return self.hash[key]

    def __contains__(self, key):
        return key in self.hash

    def __setitem__(self, key, value):
        self.hash[key] = value

    def __delitem__(self, key):
        del self.hash[key]

    def keys(self):
        return self.hash.keys()


#### legacy stuff to support the old "Container" class interface

namespace_classes = {}

ContainerContext = dict


class ContainerMeta(type):
    def __init__(cls, classname, bases, dict_):
        namespace_classes[cls] = cls.namespace_class
        return type.__init__(cls, classname, bases, dict_)

    def __call__(self, key, context, namespace, createfunc=None,
                 expiretime=None, starttime=None, **kwargs):
        if namespace in context:
            ns = context[namespace]
        else:
            nscls = namespace_classes[self]
            context[namespace] = ns = nscls(namespace, **kwargs)
        return Value(key, ns, createfunc=createfunc,
                     expiretime=expiretime, starttime=starttime)

@add_metaclass(ContainerMeta)
class Container(object):
    """Implements synchronization and value-creation logic
    for a 'value' stored in a :class:`.NamespaceManager`.

    :class:`.Container` and its subclasses are deprecated.  The
    :class:`.Value` class is now used for this purpose.

    """
    namespace_class = NamespaceManager


class FileContainer(Container):
    namespace_class = FileNamespaceManager


class MemoryContainer(Container):
    namespace_class = MemoryNamespaceManager


class DBMContainer(Container):
    namespace_class = DBMNamespaceManager

DbmContainer = DBMContainer
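
A short sketch of the Value/NamespaceManager interplay above, assuming beaker is importable; the namespace and key names are illustrative:

    from beaker.container import MemoryNamespaceManager, Value

    ns = MemoryNamespaceManager('demo')
    val = Value('answer', ns, createfunc=lambda: 42, expiretime=300)
    print(val.get_value())  # runs createfunc under the creation lock
    print(val.get_value())  # served from the namespace until expiry
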
@@ -1,29 +0,0 @@
from beaker._compat import string_type

# (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org)
# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php
def asbool(obj):
    if isinstance(obj, string_type):
        obj = obj.strip().lower()
        if obj in ['true', 'yes', 'on', 'y', 't', '1']:
            return True
        elif obj in ['false', 'no', 'off', 'n', 'f', '0']:
            return False
        else:
            raise ValueError(
                "String is not true/false: %r" % obj)
    return bool(obj)


def aslist(obj, sep=None, strip=True):
    if isinstance(obj, string_type):
        lst = obj.split(sep)
        if strip:
            lst = [v.strip() for v in lst]
        return lst
    elif isinstance(obj, (list, tuple)):
        return obj
    elif obj is None:
        return []
    else:
        return [obj]
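
The converters above normalize common config inputs; a quick sketch of their behavior (the sample values are illustrative):

    from beaker.converters import asbool, aslist

    assert asbool(' Yes ') is True
    assert asbool('0') is False
    assert aslist('a, b, c', sep=',') == ['a', 'b', 'c']
    assert aslist(None) == []
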
@@ -1,72 +0,0 @@
import sys
from ._compat import http_cookies

# Some versions of Python 2.7 and later won't need this encoding bug fix:
_cookie_encodes_correctly = http_cookies.SimpleCookie().value_encode(';') == (';', '"\\073"')

# Cookie pickling bug is fixed in Python 2.7.9 and Python 3.4.3+
# http://bugs.python.org/issue22775
cookie_pickles_properly = (
    (sys.version_info[:2] == (2, 7) and sys.version_info >= (2, 7, 9)) or
    sys.version_info >= (3, 4, 3)
)

# Add support for the SameSite attribute (obsolete when PY37 is unsupported).
http_cookies.Morsel._reserved.setdefault('samesite', 'SameSite')


# Adapted from Django.http.cookies and always enabled the bad_cookies
# behaviour to cope with any invalid cookie key while keeping around
# the session.
class SimpleCookie(http_cookies.SimpleCookie):
    if not cookie_pickles_properly:
        def __setitem__(self, key, value):
            # Apply the fix from http://bugs.python.org/issue22775 where
            # it's not fixed in Python itself
            if isinstance(value, http_cookies.Morsel):
                # allow assignment of constructed Morsels (e.g. for pickling)
                dict.__setitem__(self, key, value)
            else:
                super(SimpleCookie, self).__setitem__(key, value)

    if not _cookie_encodes_correctly:
        def value_encode(self, val):
            # Some browsers do not support quoted-string from RFC 2109,
            # including some versions of Safari and Internet Explorer.
            # These browsers split on ';', and some versions of Safari
            # are known to split on ', '. Therefore, we encode ';' and ','

            # SimpleCookie already does the hard work of encoding and decoding.
            # It uses octal sequences like '\\012' for newline etc.
            # and non-ASCII chars.  We just make use of this mechanism, to
            # avoid introducing two encoding schemes which would be confusing
            # and especially awkward for javascript.

            # NB, contrary to Python docs, value_encode returns a tuple containing
            # (real val, encoded_val)
            val, encoded = super(SimpleCookie, self).value_encode(val)

            encoded = encoded.replace(";", "\\073").replace(",", "\\054")
            # If encoded now contains any quoted chars, we need double quotes
            # around the whole string.
            if "\\" in encoded and not encoded.startswith('"'):
                encoded = '"' + encoded + '"'

            return val, encoded

    def load(self, rawdata):
        self.bad_cookies = set()
        super(SimpleCookie, self).load(rawdata)
        for key in self.bad_cookies:
            del self[key]

    # override private __set() method:
    # (needed for using our Morsel, and for laxness with CookieError)
    def _BaseCookie__set(self, key, real_value, coded_value):
        try:
            super(SimpleCookie, self)._BaseCookie__set(key, real_value, coded_value)
        except http_cookies.CookieError:
            if not hasattr(self, 'bad_cookies'):
                self.bad_cookies = set()
            self.bad_cookies.add(key)
            dict.__setitem__(self, key, http_cookies.Morsel())
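
A small sketch of the SameSite registration above, using the stdlib cookie machinery; note that Python 3.8+ already registers the attribute, so the setdefault only matters on older interpreters:

    from http.cookies import SimpleCookie

    c = SimpleCookie()
    c['session'] = 'abc123'
    c['session']['samesite'] = 'Lax'
    print(c.output())  # Set-Cookie: session=abc123; SameSite=Lax
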
@@ -1,83 +0,0 @@
from .._compat import JYTHON


from beaker.crypto.pbkdf2 import pbkdf2
from beaker.crypto.util import hmac, sha1, hmac_sha1, md5
from beaker import util
from beaker.exceptions import InvalidCryptoBackendError

keyLength = None
DEFAULT_NONCE_BITS = 128

CRYPTO_MODULES = {}


def load_default_module():
    """ Load the default crypto module
    """
    if JYTHON:
        try:
            from beaker.crypto import jcecrypto
            return jcecrypto
        except ImportError:
            pass
    else:
        try:
            from beaker.crypto import nsscrypto
            return nsscrypto
        except ImportError:
            try:
                from beaker.crypto import pycrypto
                return pycrypto
            except ImportError:
                pass
    from beaker.crypto import noencryption
    return noencryption


def register_crypto_module(name, mod):
    """
    Register the given module under the name given.
    """
    CRYPTO_MODULES[name] = mod


def get_crypto_module(name):
    """
    Get the active crypto module for this name
    """
    if name not in CRYPTO_MODULES:
        if name == 'default':
            register_crypto_module('default', load_default_module())
        elif name == 'nss':
            from beaker.crypto import nsscrypto
            register_crypto_module(name, nsscrypto)
        elif name == 'pycrypto':
            from beaker.crypto import pycrypto
            register_crypto_module(name, pycrypto)
        elif name == 'cryptography':
            from beaker.crypto import pyca_cryptography
            register_crypto_module(name, pyca_cryptography)
        else:
            raise InvalidCryptoBackendError(
                "No crypto backend with name '%s' is registered." % name)

    return CRYPTO_MODULES[name]


def generateCryptoKeys(master_key, salt, iterations, keylen):
    # NB: We XOR parts of the keystream into the randomly-generated parts, just
    # in case os.urandom() isn't as random as it should be.  Note that if
    # os.urandom() returns truly random data, this will have no effect on the
    # overall security.
    return pbkdf2(master_key, salt, iterations=iterations, dklen=keylen)


def get_nonce_size(number_of_bits):
    if number_of_bits % 8:
        raise ValueError('Nonce complexity currently supports multiples of 8')

    bytes = number_of_bits // 8
    b64bytes = ((4 * bytes // 3) + 3) & ~3
    return bytes, b64bytes
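
The backend registry above resolves lazily; a minimal usage sketch, assuming beaker is importable:

    from beaker.crypto import get_crypto_module

    mod = get_crypto_module('default')  # falls back to noencryption if no backend imports
    print(mod.has_aes, mod.getKeyLength())
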
@@ -1,41 +0,0 @@
"""
Encryption module that uses the Java Cryptography Extensions (JCE).

Note that in default installations of the Java Runtime Environment, the
maximum key length is limited to 128 bits due to US export
restrictions. This makes the generated keys incompatible with the ones
generated by pycryptopp, which has no such restrictions. To fix this,
download the "Unlimited Strength Jurisdiction Policy Files" from Sun,
which will allow encryption using 256 bit AES keys.
"""
from warnings import warn

from javax.crypto import Cipher
from javax.crypto.spec import SecretKeySpec, IvParameterSpec

import jarray

# Initialization vector filled with zeros
_iv = IvParameterSpec(jarray.zeros(16, 'b'))


def aesEncrypt(data, key):
    cipher = Cipher.getInstance('AES/CTR/NoPadding')
    skeySpec = SecretKeySpec(key, 'AES')
    cipher.init(Cipher.ENCRYPT_MODE, skeySpec, _iv)
    return cipher.doFinal(data).tostring()

# magic.
aesDecrypt = aesEncrypt

has_aes = True


def getKeyLength():
    maxlen = Cipher.getMaxAllowedKeyLength('AES/CTR/NoPadding')
    return min(maxlen, 256) / 8


if getKeyLength() < 32:
    warn('Crypto implementation only supports key lengths up to %d bits. '
         'Generated session cookies may be incompatible with other '
         'environments' % (getKeyLength() * 8))
@ -1,12 +0,0 @@
"""Encryption module that does nothing"""

def aesEncrypt(data, key):
    return data

def aesDecrypt(data, key):
    return data

has_aes = False

def getKeyLength():
    return 32

@ -1,47 +0,0 @@
"""Encryption module that uses nsscrypto"""
import nss.nss

nss.nss.nss_init_nodb()

# Apparently the rest of beaker doesn't care about the particular cipher,
# mode and padding used.
# NOTE: A constant IV!!! This is only secure if the KEY is never reused!!!
_mech = nss.nss.CKM_AES_CBC_PAD
_iv = '\0' * nss.nss.get_iv_length(_mech)

def aesEncrypt(data, key):
    slot = nss.nss.get_best_slot(_mech)

    key_obj = nss.nss.import_sym_key(slot, _mech, nss.nss.PK11_OriginGenerated,
                                     nss.nss.CKA_ENCRYPT, nss.nss.SecItem(key))

    param = nss.nss.param_from_iv(_mech, nss.nss.SecItem(_iv))
    ctx = nss.nss.create_context_by_sym_key(_mech, nss.nss.CKA_ENCRYPT, key_obj,
                                            param)
    l1 = ctx.cipher_op(data)
    # Yes, DIGEST. This needs fixing in NSS, but apparently nobody (including
    # me :( ) cares enough.
    l2 = ctx.digest_final()

    return l1 + l2

def aesDecrypt(data, key):
    slot = nss.nss.get_best_slot(_mech)

    key_obj = nss.nss.import_sym_key(slot, _mech, nss.nss.PK11_OriginGenerated,
                                     nss.nss.CKA_DECRYPT, nss.nss.SecItem(key))

    param = nss.nss.param_from_iv(_mech, nss.nss.SecItem(_iv))
    ctx = nss.nss.create_context_by_sym_key(_mech, nss.nss.CKA_DECRYPT, key_obj,
                                            param)
    l1 = ctx.cipher_op(data)
    # Yes, DIGEST. This needs fixing in NSS, but apparently nobody (including
    # me :( ) cares enough.
    l2 = ctx.digest_final()

    return l1 + l2

has_aes = True

def getKeyLength():
    return 32

@ -1,94 +0,0 @@
"""
PBKDF2 Implementation adapted from django.utils.crypto.

This is used to generate the encryption key for enciphered sessions.
"""
from beaker._compat import bytes_, xrange_

import hmac
import struct
import hashlib
import binascii


def _bin_to_long(x):
    """Convert a binary string into a long integer"""
    return int(binascii.hexlify(x), 16)


def _long_to_bin(x, hex_format_string):
    """
    Convert a long integer into a binary string.
    hex_format_string is like "%020x" for padding 10 characters.
    """
    return binascii.unhexlify((hex_format_string % x).encode('ascii'))


if hasattr(hashlib, "pbkdf2_hmac"):
    def pbkdf2(password, salt, iterations, dklen=0, digest=None):
        """
        Implements PBKDF2 using the stdlib. This is used in Python 2.7.8+ and 3.4+.

        HMAC+SHA256 is used as the default pseudo random function.

        As of 2014, 100,000 iterations was the recommended default which took
        100ms on a 2.7Ghz Intel i7 with an optimized implementation. This is
        probably the bare minimum for security given 1000 iterations was
        recommended in 2001.
        """
        if digest is None:
            digest = hashlib.sha1
        if not dklen:
            dklen = None
        password = bytes_(password)
        salt = bytes_(salt)
        return hashlib.pbkdf2_hmac(
            digest().name, password, salt, iterations, dklen)
else:
    def pbkdf2(password, salt, iterations, dklen=0, digest=None):
        """
        Implements PBKDF2 as defined in RFC 2898, section 5.2

        HMAC+SHA256 is used as the default pseudo random function.

        As of 2014, 100,000 iterations was the recommended default which took
        100ms on a 2.7Ghz Intel i7 with an optimized implementation. This is
        probably the bare minimum for security given 1000 iterations was
        recommended in 2001. This code is very well optimized for CPython and
        is about five times slower than OpenSSL's implementation.
        """
        assert iterations > 0
        if not digest:
            digest = hashlib.sha1
        password = bytes_(password)
        salt = bytes_(salt)
        hlen = digest().digest_size
        if not dklen:
            dklen = hlen
        if dklen > (2 ** 32 - 1) * hlen:
            raise OverflowError('dklen too big')
        l = -(-dklen // hlen)
        r = dklen - (l - 1) * hlen

        hex_format_string = "%%0%ix" % (hlen * 2)

        inner, outer = digest(), digest()
        if len(password) > inner.block_size:
            password = digest(password).digest()
        password += b'\x00' * (inner.block_size - len(password))
        inner.update(password.translate(hmac.trans_36))
        outer.update(password.translate(hmac.trans_5C))

        def F(i):
            u = salt + struct.pack(b'>I', i)
            result = 0
            for j in xrange_(int(iterations)):
                dig1, dig2 = inner.copy(), outer.copy()
                dig1.update(u)
                dig2.update(dig1.digest())
                u = dig2.digest()
                result ^= _bin_to_long(u)
            return _long_to_bin(result, hex_format_string)

        T = [F(x) for x in xrange_(1, l)]
        return b''.join(T) + F(l)[:r]

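A quick sketch (not in the diff) of the property the module above guarantees: the derived key is deterministic for fixed inputs, whichever code path is taken. All argument values are placeholders:

# Same inputs produce the same key on both the stdlib and pure-Python branches.
from beaker.crypto.pbkdf2 import pbkdf2

k1 = pbkdf2('password', 'salt', iterations=1000, dklen=32)
k2 = pbkdf2('password', 'salt', iterations=1000, dklen=32)
assert k1 == k2 and len(k1) == 32
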
@ -1,52 +0,0 @@
"""Encryption module that uses pyca/cryptography"""

import os
import json

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import (
    Cipher, algorithms, modes
)


def aesEncrypt(data, key):
    # Generate a random 96-bit IV.
    iv = os.urandom(12)

    # Construct an AES-GCM Cipher object with the given key and a
    # randomly generated IV.
    encryptor = Cipher(
        algorithms.AES(key),
        modes.GCM(iv),
        backend=default_backend()
    ).encryptor()

    # Encrypt the plaintext and get the associated ciphertext.
    # GCM does not require padding.
    ciphertext = encryptor.update(data) + encryptor.finalize()

    return iv + encryptor.tag + ciphertext


def aesDecrypt(data, key):
    iv = data[:12]
    tag = data[12:28]
    ciphertext = data[28:]

    # Construct a Cipher object, with the key, iv, and additionally the
    # GCM tag used for authenticating the message.
    decryptor = Cipher(
        algorithms.AES(key),
        modes.GCM(iv, tag),
        backend=default_backend()
    ).decryptor()

    # Decryption gets us the authenticated plaintext.
    # If the tag does not match an InvalidTag exception will be raised.
    return decryptor.update(ciphertext) + decryptor.finalize()


has_aes = True

def getKeyLength():
    return 32

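The token layout produced above is a 12-byte IV, then the 16-byte GCM tag, then the ciphertext; a small sketch (not part of the commit, with a throwaway random key) that checks it:

# Round trip through the removed module; the key is a random placeholder.
import os
from beaker.crypto.pyca_cryptography import aesEncrypt, aesDecrypt

key = os.urandom(32)
token = aesEncrypt(b'session payload', key)
assert len(token) == 12 + 16 + len(b'session payload')   # iv + tag + ciphertext
assert aesDecrypt(token, key) == b'session payload'
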
@ -1,34 +0,0 @@
"""Encryption module that uses pycryptopp or pycrypto"""
try:
    # Pycryptopp is preferred over Crypto because Crypto has had
    # various periods of not being maintained, and pycryptopp uses
    # the Crypto++ library which is generally considered the 'gold standard'
    # of crypto implementations
    from pycryptopp.cipher import aes

    def aesEncrypt(data, key):
        cipher = aes.AES(key)
        return cipher.process(data)

    # magic.
    aesDecrypt = aesEncrypt

except ImportError:
    from Crypto.Cipher import AES
    from Crypto.Util import Counter

    def aesEncrypt(data, key):
        cipher = AES.new(key, AES.MODE_CTR,
                         counter=Counter.new(128, initial_value=0))

        return cipher.encrypt(data)

    def aesDecrypt(data, key):
        cipher = AES.new(key, AES.MODE_CTR,
                         counter=Counter.new(128, initial_value=0))
        return cipher.decrypt(data)

has_aes = True

def getKeyLength():
    return 32

@ -1,16 +0,0 @@
from hashlib import md5

try:
    # Use PyCrypto (if available)
    from Crypto.Hash import HMAC as hmac, SHA as hmac_sha1
    sha1 = hmac_sha1.new

except ImportError:

    # PyCrypto not available. Use the Python standard library.
    import hmac

    # NOTE: We have to use the callable with hashlib (hashlib.sha1),
    # otherwise hmac only accepts the sha module object itself
    from hashlib import sha1
    hmac_sha1 = sha1

@ -1,29 +0,0 @@
"""Beaker exception classes"""


class BeakerException(Exception):
    pass


class BeakerWarning(RuntimeWarning):
    """Issued at runtime."""


class CreationAbortedError(Exception):
    """Deprecated."""


class InvalidCacheBackendError(BeakerException, ImportError):
    pass


class MissingCacheParameter(BeakerException):
    pass


class LockError(BeakerException):
    pass


class InvalidCryptoBackendError(BeakerException):
    pass

@ -1,180 +0,0 @@
from beaker._compat import pickle

import logging
import pickle
from datetime import datetime

from beaker.container import OpenResourceNamespaceManager, Container
from beaker.exceptions import InvalidCacheBackendError, MissingCacheParameter
from beaker.synchronization import file_synchronizer, null_synchronizer
from beaker.util import verify_directory, SyncDict

log = logging.getLogger(__name__)

sa = None
pool = None
types = None


class DatabaseNamespaceManager(OpenResourceNamespaceManager):
    metadatas = SyncDict()
    tables = SyncDict()

    @classmethod
    def _init_dependencies(cls):
        global sa, pool, types
        if sa is not None:
            return
        try:
            import sqlalchemy as sa
            import sqlalchemy.pool as pool
            from sqlalchemy import types
        except ImportError:
            raise InvalidCacheBackendError("Database cache backend requires "
                                           "the 'sqlalchemy' library")

    def __init__(self, namespace, url=None, sa_opts=None, optimistic=False,
                 table_name='beaker_cache', data_dir=None, lock_dir=None,
                 schema_name=None, **params):
        """Creates a database namespace manager

        ``url``
            SQLAlchemy compliant db url
        ``sa_opts``
            A dictionary of SQLAlchemy keyword options to initialize the engine
            with.
        ``optimistic``
            Use optimistic session locking, note that this will result in an
            additional select when updating a cache value to compare version
            numbers.
        ``table_name``
            The table name to use in the database for the cache.
        ``schema_name``
            The schema name to use in the database for the cache.
        """
        OpenResourceNamespaceManager.__init__(self, namespace)

        if sa_opts is None:
            sa_opts = {}

        self.lock_dir = None

        if lock_dir:
            self.lock_dir = lock_dir
        elif data_dir:
            self.lock_dir = data_dir + "/container_db_lock"
        if self.lock_dir:
            verify_directory(self.lock_dir)

        # Check to see if the table's been created before
        url = url or sa_opts['sa.url']
        table_key = url + table_name

        def make_cache():
            # Check to see if we have a connection pool open already
            meta_key = url + table_name

            def make_meta():
                # SQLAlchemy pops the url, this ensures it sticks around
                # later
                sa_opts['sa.url'] = url
                engine = sa.engine_from_config(sa_opts, 'sa.')
                meta = sa.MetaData()
                meta.bind = engine
                return meta
            meta = DatabaseNamespaceManager.metadatas.get(meta_key, make_meta)
            # Create the table object and cache it now
            cache = sa.Table(table_name, meta,
                             sa.Column('id', types.Integer, primary_key=True),
                             sa.Column('namespace', types.String(255), nullable=False),
                             sa.Column('accessed', types.DateTime, nullable=False),
                             sa.Column('created', types.DateTime, nullable=False),
                             sa.Column('data', types.PickleType, nullable=False),
                             sa.UniqueConstraint('namespace'),
                             schema=schema_name if schema_name else meta.schema
                             )
            cache.create(checkfirst=True)
            return cache
        self.hash = {}
        self._is_new = False
        self.loaded = False
        self.cache = DatabaseNamespaceManager.tables.get(table_key, make_cache)

    def get_access_lock(self):
        return null_synchronizer()

    def get_creation_lock(self, key):
        return file_synchronizer(
            identifier="databasecontainer/funclock/%s/%s" % (
                self.namespace, key
            ),
            lock_dir=self.lock_dir)

    def do_open(self, flags, replace):
        # If we already loaded the data, don't bother loading it again
        if self.loaded:
            self.flags = flags
            return

        cache = self.cache
        result_proxy = sa.select([cache.c.data],
                                 cache.c.namespace == self.namespace
                                 ).execute()
        result = result_proxy.fetchone()
        result_proxy.close()

        if not result:
            self._is_new = True
            self.hash = {}
        else:
            self._is_new = False
            try:
                self.hash = result['data']
            except (IOError, OSError, EOFError, pickle.PickleError):
                log.debug("Couldn't load pickle data, creating new storage")
                self.hash = {}
                self._is_new = True
        self.flags = flags
        self.loaded = True

    def do_close(self):
        if self.flags is not None and (self.flags == 'c' or self.flags == 'w'):
            cache = self.cache
            if self._is_new:
                cache.insert().execute(namespace=self.namespace, data=self.hash,
                                       accessed=datetime.now(),
                                       created=datetime.now())
                self._is_new = False
            else:
                cache.update(cache.c.namespace == self.namespace).execute(
                    data=self.hash, accessed=datetime.now())
        self.flags = None

    def do_remove(self):
        cache = self.cache
        cache.delete(cache.c.namespace == self.namespace).execute()
        self.hash = {}

        # We can retain the fact that we did a load attempt, but since the
        # file is gone this will be a new namespace should it be saved.
        self._is_new = True

    def __getitem__(self, key):
        return self.hash[key]

    def __contains__(self, key):
        return key in self.hash

    def __setitem__(self, key, value):
        self.hash[key] = value

    def __delitem__(self, key):
        del self.hash[key]

    def keys(self):
        return self.hash.keys()


class DatabaseContainer(Container):
    namespace_manager = DatabaseNamespaceManager

@ -1,122 +0,0 @@
from beaker._compat import pickle

import logging
from datetime import datetime

from beaker.container import OpenResourceNamespaceManager, Container
from beaker.exceptions import InvalidCacheBackendError
from beaker.synchronization import null_synchronizer

log = logging.getLogger(__name__)

db = None


class GoogleNamespaceManager(OpenResourceNamespaceManager):
    tables = {}

    @classmethod
    def _init_dependencies(cls):
        global db
        if db is not None:
            return
        try:
            db = __import__('google.appengine.ext.db').appengine.ext.db
        except ImportError:
            raise InvalidCacheBackendError("Datastore cache backend requires the "
                                           "'google.appengine.ext' library")

    def __init__(self, namespace, table_name='beaker_cache', **params):
        """Creates a datastore namespace manager"""
        OpenResourceNamespaceManager.__init__(self, namespace)

        def make_cache():
            table_dict = dict(created=db.DateTimeProperty(),
                              accessed=db.DateTimeProperty(),
                              data=db.BlobProperty())
            table = type(table_name, (db.Model,), table_dict)
            return table
        self.table_name = table_name
        self.cache = GoogleNamespaceManager.tables.setdefault(table_name, make_cache())
        self.hash = {}
        self._is_new = False
        self.loaded = False
        self.log_debug = logging.DEBUG >= log.getEffectiveLevel()

        # Google wants namespaces to start with letters, change the namespace
        # to start with a letter
        self.namespace = 'p%s' % self.namespace

    def get_access_lock(self):
        return null_synchronizer()

    def get_creation_lock(self, key):
        # this is weird, should probably be present
        return null_synchronizer()

    def do_open(self, flags, replace):
        # If we already loaded the data, don't bother loading it again
        if self.loaded:
            self.flags = flags
            return

        item = self.cache.get_by_key_name(self.namespace)

        if not item:
            self._is_new = True
            self.hash = {}
        else:
            self._is_new = False
            try:
                self.hash = pickle.loads(str(item.data))
            except (IOError, OSError, EOFError, pickle.PickleError):
                if self.log_debug:
                    log.debug("Couldn't load pickle data, creating new storage")
                self.hash = {}
                self._is_new = True
        self.flags = flags
        self.loaded = True

    def do_close(self):
        if self.flags is not None and (self.flags == 'c' or self.flags == 'w'):
            if self._is_new:
                item = self.cache(key_name=self.namespace)
                item.data = pickle.dumps(self.hash)
                item.created = datetime.now()
                item.accessed = datetime.now()
                item.put()
                self._is_new = False
            else:
                item = self.cache.get_by_key_name(self.namespace)
                item.data = pickle.dumps(self.hash)
                item.accessed = datetime.now()
                item.put()
        self.flags = None

    def do_remove(self):
        item = self.cache.get_by_key_name(self.namespace)
        item.delete()
        self.hash = {}

        # We can retain the fact that we did a load attempt, but since the
        # file is gone this will be a new namespace should it be saved.
        self._is_new = True

    def __getitem__(self, key):
        return self.hash[key]

    def __contains__(self, key):
        return key in self.hash

    def __setitem__(self, key, value):
        self.hash[key] = value

    def __delitem__(self, key):
        del self.hash[key]

    def keys(self):
        return self.hash.keys()


class GoogleContainer(Container):
    namespace_class = GoogleNamespaceManager

@ -1,218 +0,0 @@
from .._compat import PY2

from beaker.container import NamespaceManager, Container
from beaker.crypto.util import sha1
from beaker.exceptions import InvalidCacheBackendError, MissingCacheParameter
from beaker.synchronization import file_synchronizer
from beaker.util import verify_directory, SyncDict, parse_memcached_behaviors
import warnings

MAX_KEY_LENGTH = 250

_client_libs = {}


def _load_client(name='auto'):
    if name in _client_libs:
        return _client_libs[name]

    def _pylibmc():
        global pylibmc
        import pylibmc
        return pylibmc

    def _cmemcache():
        global cmemcache
        import cmemcache
        warnings.warn("cmemcache is known to have serious "
                      "concurrency issues; consider using 'memcache' "
                      "or 'pylibmc'")
        return cmemcache

    def _memcache():
        global memcache
        import memcache
        return memcache

    def _bmemcached():
        global bmemcached
        import bmemcached
        return bmemcached

    def _auto():
        for _client in (_pylibmc, _cmemcache, _memcache, _bmemcached):
            try:
                return _client()
            except ImportError:
                pass
        else:
            raise InvalidCacheBackendError(
                "Memcached cache backend requires one "
                "of: 'pylibmc' or 'memcache' to be installed.")

    clients = {
        'pylibmc': _pylibmc,
        'cmemcache': _cmemcache,
        'memcache': _memcache,
        'bmemcached': _bmemcached,
        'auto': _auto
    }
    _client_libs[name] = clib = clients[name]()
    return clib


def _is_configured_for_pylibmc(memcache_module_config, memcache_client):
    return memcache_module_config == 'pylibmc' or \
        memcache_client.__name__.startswith('pylibmc')


class MemcachedNamespaceManager(NamespaceManager):
    """Provides the :class:`.NamespaceManager` API over a memcache client library."""

    clients = SyncDict()

    def __new__(cls, *args, **kw):
        memcache_module = kw.pop('memcache_module', 'auto')

        memcache_client = _load_client(memcache_module)

        if _is_configured_for_pylibmc(memcache_module, memcache_client):
            return object.__new__(PyLibMCNamespaceManager)
        else:
            return object.__new__(MemcachedNamespaceManager)

    def __init__(self, namespace, url,
                 memcache_module='auto',
                 data_dir=None, lock_dir=None,
                 **kw):
        NamespaceManager.__init__(self, namespace)

        _memcache_module = _client_libs[memcache_module]

        if not url:
            raise MissingCacheParameter("url is required")

        self.lock_dir = None

        if lock_dir:
            self.lock_dir = lock_dir
        elif data_dir:
            self.lock_dir = data_dir + "/container_mcd_lock"
        if self.lock_dir:
            verify_directory(self.lock_dir)

        # Check for pylibmc namespace manager, in which case client will be
        # instantiated by subclass __init__, to handle behavior passing to the
        # pylibmc client
        if not _is_configured_for_pylibmc(memcache_module, _memcache_module):
            self.mc = MemcachedNamespaceManager.clients.get(
                (memcache_module, url),
                _memcache_module.Client,
                url.split(';'))

    def get_creation_lock(self, key):
        return file_synchronizer(
            identifier="memcachedcontainer/funclock/%s/%s" %
                       (self.namespace, key), lock_dir=self.lock_dir)

    def _format_key(self, key):
        if not isinstance(key, str):
            key = key.decode('ascii')
        formated_key = (self.namespace + '_' + key).replace(' ', '\302\267')
        if len(formated_key) > MAX_KEY_LENGTH:
            if not PY2:
                formated_key = formated_key.encode('utf-8')
            formated_key = sha1(formated_key).hexdigest()
        return formated_key

    def __getitem__(self, key):
        return self.mc.get(self._format_key(key))

    def __contains__(self, key):
        value = self.mc.get(self._format_key(key))
        return value is not None

    def has_key(self, key):
        return key in self

    def set_value(self, key, value, expiretime=None):
        if expiretime:
            self.mc.set(self._format_key(key), value, time=expiretime)
        else:
            self.mc.set(self._format_key(key), value)

    def __setitem__(self, key, value):
        self.set_value(key, value)

    def __delitem__(self, key):
        self.mc.delete(self._format_key(key))

    def do_remove(self):
        self.mc.flush_all()

    def keys(self):
        raise NotImplementedError(
            "Memcache caching does not "
            "support iteration of all cache keys")


class PyLibMCNamespaceManager(MemcachedNamespaceManager):
    """Provide thread-local support for pylibmc."""

    pools = SyncDict()

    def __init__(self, *arg, **kw):
        super(PyLibMCNamespaceManager, self).__init__(*arg, **kw)

        memcache_module = kw.get('memcache_module', 'auto')
        _memcache_module = _client_libs[memcache_module]
        protocol = kw.get('protocol', 'text')
        username = kw.get('username', None)
        password = kw.get('password', None)
        url = kw.get('url')
        behaviors = parse_memcached_behaviors(kw)

        self.mc = MemcachedNamespaceManager.clients.get(
            (memcache_module, url),
            _memcache_module.Client,
            servers=url.split(';'), behaviors=behaviors,
            binary=(protocol == 'binary'), username=username,
            password=password)
        self.pool = PyLibMCNamespaceManager.pools.get(
            (memcache_module, url),
            pylibmc.ThreadMappedPool, self.mc)

    def __getitem__(self, key):
        with self.pool.reserve() as mc:
            return mc.get(self._format_key(key))

    def __contains__(self, key):
        with self.pool.reserve() as mc:
            value = mc.get(self._format_key(key))
            return value is not None

    def has_key(self, key):
        return key in self

    def set_value(self, key, value, expiretime=None):
        with self.pool.reserve() as mc:
            if expiretime:
                mc.set(self._format_key(key), value, time=expiretime)
            else:
                mc.set(self._format_key(key), value)

    def __setitem__(self, key, value):
        self.set_value(key, value)

    def __delitem__(self, key):
        with self.pool.reserve() as mc:
            mc.delete(self._format_key(key))

    def do_remove(self):
        with self.pool.reserve() as mc:
            mc.flush_all()


class MemcachedContainer(Container):
    """Container class which invokes :class:`.MemcacheNamespaceManager`."""
    namespace_class = MemcachedNamespaceManager

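A hedged usage sketch (not in the diff) of the namespace manager above; the server address is a placeholder and one of the listed memcache client libraries must be installed:

# Keys are namespaced as '<namespace>_<key>' and SHA-1 hashed past 250 chars.
from beaker.ext.memcached import MemcachedNamespaceManager

nm = MemcachedNamespaceManager('myapp', url='127.0.0.1:11211')
nm['greeting'] = 'hello'            # stored under 'myapp_greeting'
assert nm['greeting'] == 'hello'
del nm['greeting']
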
@ -1,184 +0,0 @@
import datetime
import os
import threading
import time
import pickle

try:
    import pymongo
    import pymongo.errors
    import bson
except ImportError:
    pymongo = None
    bson = None

from beaker.container import NamespaceManager
from beaker.synchronization import SynchronizerImpl
from beaker.util import SyncDict, machine_identifier
from beaker.crypto.util import sha1
from beaker._compat import string_type, PY2


class MongoNamespaceManager(NamespaceManager):
    """Provides the :class:`.NamespaceManager` API over MongoDB.

    Provided ``url`` can be both a mongodb connection string or
    an already existing MongoClient instance.

    The data will be stored into ``beaker_cache`` collection of the
    *default database*, so make sure your connection string or
    MongoClient point to a default database.
    """
    MAX_KEY_LENGTH = 1024

    clients = SyncDict()

    def __init__(self, namespace, url, **kw):
        super(MongoNamespaceManager, self).__init__(namespace)
        self.lock_dir = None  # MongoDB uses mongo itself for locking.

        if pymongo is None:
            raise RuntimeError('pymongo3 is not available')

        if isinstance(url, string_type):
            self.client = MongoNamespaceManager.clients.get(url, pymongo.MongoClient, url)
        else:
            self.client = url
        self.db = self.client.get_default_database()

    def _format_key(self, key):
        if not isinstance(key, str):
            key = key.decode('ascii')
        if len(key) > (self.MAX_KEY_LENGTH - len(self.namespace) - 1):
            if not PY2:
                key = key.encode('utf-8')
            key = sha1(key).hexdigest()
        return '%s:%s' % (self.namespace, key)

    def get_creation_lock(self, key):
        return MongoSynchronizer(self._format_key(key), self.client)

    def __getitem__(self, key):
        self._clear_expired()
        entry = self.db.backer_cache.find_one({'_id': self._format_key(key)})
        if entry is None:
            raise KeyError(key)
        return pickle.loads(entry['value'])

    def __contains__(self, key):
        self._clear_expired()
        entry = self.db.backer_cache.find_one({'_id': self._format_key(key)})
        return entry is not None

    def has_key(self, key):
        return key in self

    def set_value(self, key, value, expiretime=None):
        self._clear_expired()

        expiration = None
        if expiretime is not None:
            expiration = time.time() + expiretime

        value = pickle.dumps(value)
        self.db.backer_cache.update_one({'_id': self._format_key(key)},
                                        {'$set': {'value': bson.Binary(value),
                                                  'expiration': expiration}},
                                        upsert=True)

    def __setitem__(self, key, value):
        self.set_value(key, value)

    def __delitem__(self, key):
        self._clear_expired()
        self.db.backer_cache.delete_many({'_id': self._format_key(key)})

    def do_remove(self):
        self.db.backer_cache.delete_many({'_id': {'$regex': '^%s' % self.namespace}})

    def keys(self):
        return [e['key'].split(':', 1)[-1] for e in self.db.backer_cache.find_all(
            {'_id': {'$regex': '^%s' % self.namespace}}
        )]

    def _clear_expired(self):
        now = time.time()
        self.db.backer_cache.delete_many({'_id': {'$regex': '^%s' % self.namespace},
                                          'expiration': {'$ne': None, '$lte': now}})


class MongoSynchronizer(SynchronizerImpl):
    """Provides a Writer/Reader lock based on MongoDB.

    Provided ``url`` can be both a mongodb connection string or
    an already existing MongoClient instance.

    The data will be stored into ``beaker_locks`` collection of the
    *default database*, so make sure your connection string or
    MongoClient point to a default database.

    Locks are identified by local machine, PID and threadid, so
    are suitable for use in both local and distributed environments.
    """
    # A cache entry generation function can take a while, but 15 minutes
    # is more than a reasonable time.
    LOCK_EXPIRATION = 900
    MACHINE_ID = machine_identifier()

    def __init__(self, identifier, url):
        super(MongoSynchronizer, self).__init__()
        self.identifier = identifier
        if isinstance(url, string_type):
            self.client = MongoNamespaceManager.clients.get(url, pymongo.MongoClient, url)
        else:
            self.client = url
        self.db = self.client.get_default_database()

    def _clear_expired_locks(self):
        now = datetime.datetime.utcnow()
        expired = now - datetime.timedelta(seconds=self.LOCK_EXPIRATION)
        self.db.beaker_locks.delete_many({'_id': self.identifier, 'timestamp': {'$lte': expired}})
        return now

    def _get_owner_id(self):
        return '%s-%s-%s' % (self.MACHINE_ID, os.getpid(), threading.current_thread().ident)

    def do_release_read_lock(self):
        owner_id = self._get_owner_id()
        self.db.beaker_locks.update_one({'_id': self.identifier, 'readers': owner_id},
                                        {'$pull': {'readers': owner_id}})

    def do_acquire_read_lock(self, wait):
        now = self._clear_expired_locks()
        owner_id = self._get_owner_id()
        while True:
            try:
                self.db.beaker_locks.update_one({'_id': self.identifier, 'owner': None},
                                                {'$set': {'timestamp': now},
                                                 '$push': {'readers': owner_id}},
                                                upsert=True)
                return True
            except pymongo.errors.DuplicateKeyError:
                if not wait:
                    return False
                time.sleep(0.2)

    def do_release_write_lock(self):
        self.db.beaker_locks.delete_one({'_id': self.identifier, 'owner': self._get_owner_id()})

    def do_acquire_write_lock(self, wait):
        now = self._clear_expired_locks()
        owner_id = self._get_owner_id()
        while True:
            try:
                self.db.beaker_locks.update_one({'_id': self.identifier, 'owner': None,
                                                 'readers': []},
                                                {'$set': {'owner': owner_id,
                                                          'timestamp': now}},
                                                upsert=True)
                return True
            except pymongo.errors.DuplicateKeyError:
                if not wait:
                    return False
                time.sleep(0.2)

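A sketch (not part of the commit) of the writer-lock flow the synchronizer above implements; the connection string is a placeholder pointing at a default database:

# do_acquire_write_lock upserts a lock document; release deletes it if we own it.
from beaker.ext.mongodb import MongoSynchronizer

lock = MongoSynchronizer('myapp:cache_entry', 'mongodb://localhost/beaker')
if lock.do_acquire_write_lock(True):
    try:
        pass  # regenerate the cache entry here
    finally:
        lock.do_release_write_lock()
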
@ -1,144 +0,0 @@
import os
import threading
import time
import pickle

try:
    import redis
except ImportError:
    redis = None

from beaker.container import NamespaceManager
from beaker.synchronization import SynchronizerImpl
from beaker.util import SyncDict, machine_identifier
from beaker.crypto.util import sha1
from beaker._compat import string_type, PY2


class RedisNamespaceManager(NamespaceManager):
    """Provides the :class:`.NamespaceManager` API over Redis.

    Provided ``url`` can be both a redis connection string or
    an already existing StrictRedis instance.

    The data will be stored into redis keys, with their name
    starting with ``beaker_cache:``. So make sure you provide
    a specific database number if you don't want to mix them
    with your own data.
    """
    MAX_KEY_LENGTH = 1024

    clients = SyncDict()

    def __init__(self, namespace, url, timeout=None, **kw):
        super(RedisNamespaceManager, self).__init__(namespace)
        self.lock_dir = None  # Redis uses redis itself for locking.
        self.timeout = timeout

        if redis is None:
            raise RuntimeError('redis is not available')

        if isinstance(url, string_type):
            self.client = RedisNamespaceManager.clients.get(url, redis.StrictRedis.from_url, url)
        else:
            self.client = url

    def _format_key(self, key):
        if not isinstance(key, str):
            key = key.decode('ascii')
        if len(key) > (self.MAX_KEY_LENGTH - len(self.namespace) - len('beaker_cache:') - 1):
            if not PY2:
                key = key.encode('utf-8')
            key = sha1(key).hexdigest()
        return 'beaker_cache:%s:%s' % (self.namespace, key)

    def get_creation_lock(self, key):
        return RedisSynchronizer(self._format_key(key), self.client)

    def __getitem__(self, key):
        entry = self.client.get(self._format_key(key))
        if entry is None:
            raise KeyError(key)
        return pickle.loads(entry)

    def __contains__(self, key):
        return self.client.exists(self._format_key(key))

    def has_key(self, key):
        return key in self

    def set_value(self, key, value, expiretime=None):
        value = pickle.dumps(value)
        if expiretime is None and self.timeout is not None:
            expiretime = self.timeout
        if expiretime is not None:
            self.client.setex(self._format_key(key), int(expiretime), value)
        else:
            self.client.set(self._format_key(key), value)

    def __setitem__(self, key, value):
        self.set_value(key, value)

    def __delitem__(self, key):
        self.client.delete(self._format_key(key))

    def do_remove(self):
        for k in self.keys():
            self.client.delete(k)

    def keys(self):
        return self.client.keys('beaker_cache:%s:*' % self.namespace)


class RedisSynchronizer(SynchronizerImpl):
    """Synchronizer based on redis.

    Provided ``url`` can be both a redis connection string or
    an already existing StrictRedis instance.

    This Synchronizer only supports 1 reader or 1 writer at time, not concurrent readers.
    """
    # A cache entry generation function can take a while, but 15 minutes
    # is more than a reasonable time.
    LOCK_EXPIRATION = 900
    MACHINE_ID = machine_identifier()

    def __init__(self, identifier, url):
        super(RedisSynchronizer, self).__init__()
        self.identifier = 'beaker_lock:%s' % identifier
        if isinstance(url, string_type):
            self.client = RedisNamespaceManager.clients.get(url, redis.StrictRedis.from_url, url)
        else:
            self.client = url

    def _get_owner_id(self):
        return (
            '%s-%s-%s' % (self.MACHINE_ID, os.getpid(), threading.current_thread().ident)
        ).encode('ascii')

    def do_release_read_lock(self):
        self.do_release_write_lock()

    def do_acquire_read_lock(self, wait):
        self.do_acquire_write_lock(wait)

    def do_release_write_lock(self):
        identifier = self.identifier
        owner_id = self._get_owner_id()

        def execute_release(pipe):
            lock_value = pipe.get(identifier)
            if lock_value == owner_id:
                pipe.delete(identifier)
        self.client.transaction(execute_release, identifier)

    def do_acquire_write_lock(self, wait):
        owner_id = self._get_owner_id()
        while True:
            if self.client.setnx(self.identifier, owner_id):
                self.client.pexpire(self.identifier, self.LOCK_EXPIRATION * 1000)
                return True

            if not wait:
                return False
            time.sleep(0.2)

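For context, a minimal sketch (not in the diff) of the redis backend above; the URL is a placeholder and the redis package must be available:

# Values are pickled under 'beaker_cache:<namespace>:<key>'; timeout maps to SETEX.
from beaker.ext.redisnm import RedisNamespaceManager

nm = RedisNamespaceManager('myapp', 'redis://localhost:6379/0', timeout=300)
nm['greeting'] = {'text': 'hello'}   # expires in 300 seconds via setex()
assert nm['greeting']['text'] == 'hello'
del nm['greeting']
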
@ -1,137 +0,0 @@
from beaker._compat import pickle

import logging
import pickle
from datetime import datetime

from beaker.container import OpenResourceNamespaceManager, Container
from beaker.exceptions import InvalidCacheBackendError, MissingCacheParameter
from beaker.synchronization import file_synchronizer, null_synchronizer
from beaker.util import verify_directory, SyncDict


log = logging.getLogger(__name__)

sa = None


class SqlaNamespaceManager(OpenResourceNamespaceManager):
    binds = SyncDict()
    tables = SyncDict()

    @classmethod
    def _init_dependencies(cls):
        global sa
        if sa is not None:
            return
        try:
            import sqlalchemy as sa
        except ImportError:
            raise InvalidCacheBackendError("SQLAlchemy, which is required by "
                                           "this backend, is not installed")

    def __init__(self, namespace, bind, table, data_dir=None, lock_dir=None,
                 **kwargs):
        """Create a namespace manager for use with a database table via
        SQLAlchemy.

        ``bind``
            SQLAlchemy ``Engine`` or ``Connection`` object

        ``table``
            SQLAlchemy ``Table`` object in which to store namespace data.
            This should usually be something created by ``make_cache_table``.
        """
        OpenResourceNamespaceManager.__init__(self, namespace)

        if lock_dir:
            self.lock_dir = lock_dir
        elif data_dir:
            self.lock_dir = data_dir + "/container_db_lock"
        if self.lock_dir:
            verify_directory(self.lock_dir)

        self.bind = self.__class__.binds.get(str(bind.url), lambda: bind)
        self.table = self.__class__.tables.get('%s:%s' % (bind.url, table.name),
                                               lambda: table)
        self.hash = {}
        self._is_new = False
        self.loaded = False

    def get_access_lock(self):
        return null_synchronizer()

    def get_creation_lock(self, key):
        return file_synchronizer(
            identifier="databasecontainer/funclock/%s" % self.namespace,
            lock_dir=self.lock_dir)

    def do_open(self, flags, replace):
        if self.loaded:
            self.flags = flags
            return
        select = sa.select([self.table.c.data],
                           (self.table.c.namespace == self.namespace))
        result = self.bind.execute(select).fetchone()
        if not result:
            self._is_new = True
            self.hash = {}
        else:
            self._is_new = False
            try:
                self.hash = result['data']
            except (IOError, OSError, EOFError, pickle.PickleError):
                log.debug("Couldn't load pickle data, creating new storage")
                self.hash = {}
                self._is_new = True
        self.flags = flags
        self.loaded = True

    def do_close(self):
        if self.flags is not None and (self.flags == 'c' or self.flags == 'w'):
            if self._is_new:
                insert = self.table.insert()
                self.bind.execute(insert, namespace=self.namespace, data=self.hash,
                                  accessed=datetime.now(), created=datetime.now())
                self._is_new = False
            else:
                update = self.table.update(self.table.c.namespace == self.namespace)
                self.bind.execute(update, data=self.hash, accessed=datetime.now())
        self.flags = None

    def do_remove(self):
        delete = self.table.delete(self.table.c.namespace == self.namespace)
        self.bind.execute(delete)
        self.hash = {}
        self._is_new = True

    def __getitem__(self, key):
        return self.hash[key]

    def __contains__(self, key):
        return key in self.hash

    def __setitem__(self, key, value):
        self.hash[key] = value

    def __delitem__(self, key):
        del self.hash[key]

    def keys(self):
        return self.hash.keys()


class SqlaContainer(Container):
    namespace_manager = SqlaNamespaceManager


def make_cache_table(metadata, table_name='beaker_cache', schema_name=None):
    """Return a ``Table`` object suitable for storing cached values for the
    namespace manager. Do not create the table."""
    return sa.Table(table_name, metadata,
                    sa.Column('namespace', sa.String(255), primary_key=True),
                    sa.Column('accessed', sa.DateTime, nullable=False),
                    sa.Column('created', sa.DateTime, nullable=False),
                    sa.Column('data', sa.PickleType, nullable=False),
                    schema=schema_name if schema_name else metadata.schema)

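A usage sketch (not part of the commit) wiring the manager above to a table from make_cache_table(); the sqlite URL and lock directory are placeholders, assuming the SQLAlchemy 1.x API this vendored copy targets:

# make_cache_table() only defines the table, so create it explicitly.
import sqlalchemy as sa
from beaker.ext.sqla import SqlaNamespaceManager, make_cache_table

engine = sa.create_engine('sqlite:///beaker.db')
metadata = sa.MetaData()
table = make_cache_table(metadata)
metadata.create_all(engine)

nm = SqlaNamespaceManager('myapp', bind=engine, table=table,
                          lock_dir='/tmp/beaker_locks')
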
@ -1,169 +0,0 @@
import warnings

try:
    from paste.registry import StackedObjectProxy
    beaker_session = StackedObjectProxy(name="Beaker Session")
    beaker_cache = StackedObjectProxy(name="Cache Manager")
except:
    beaker_cache = None
    beaker_session = None

from beaker.cache import CacheManager
from beaker.session import Session, SessionObject
from beaker.util import coerce_cache_params, coerce_session_params, \
    parse_cache_config_options


class CacheMiddleware(object):
    cache = beaker_cache

    def __init__(self, app, config=None, environ_key='beaker.cache', **kwargs):
        """Initialize the Cache Middleware

        The Cache middleware will make a CacheManager instance available
        every request under the ``environ['beaker.cache']`` key by
        default. The location in environ can be changed by setting
        ``environ_key``.

        ``config``
            dict  All settings should be prefixed by 'cache.'. This
            method of passing variables is intended for Paste and other
            setups that accumulate multiple component settings in a
            single dictionary. If config contains *no cache. prefixed
            args*, then *all* of the config options will be used to
            initialize the Cache objects.

        ``environ_key``
            Location where the Cache instance will keyed in the WSGI
            environ

        ``**kwargs``
            All keyword arguments are assumed to be cache settings and
            will override any settings found in ``config``

        """
        self.app = app
        config = config or {}

        self.options = {}

        # Update the options with the parsed config
        self.options.update(parse_cache_config_options(config))

        # Add any options from kwargs, but leave out the defaults this
        # time
        self.options.update(
            parse_cache_config_options(kwargs, include_defaults=False))

        # Assume all keys are intended for cache if none are prefixed with
        # 'cache.'
        if not self.options and config:
            self.options = config

        self.options.update(kwargs)
        self.cache_manager = CacheManager(**self.options)
        self.environ_key = environ_key

    def __call__(self, environ, start_response):
        if environ.get('paste.registry'):
            if environ['paste.registry'].reglist:
                environ['paste.registry'].register(self.cache,
                                                   self.cache_manager)
        environ[self.environ_key] = self.cache_manager
        return self.app(environ, start_response)


class SessionMiddleware(object):
    session = beaker_session

    def __init__(self, wrap_app, config=None, environ_key='beaker.session',
                 **kwargs):
        """Initialize the Session Middleware

        The Session middleware will make a lazy session instance
        available every request under the ``environ['beaker.session']``
        key by default. The location in environ can be changed by
        setting ``environ_key``.

        ``config``
            dict  All settings should be prefixed by 'session.'. This
            method of passing variables is intended for Paste and other
            setups that accumulate multiple component settings in a
            single dictionary. If config contains *no session. prefixed
            args*, then *all* of the config options will be used to
            initialize the Session objects.

        ``environ_key``
            Location where the Session instance will keyed in the WSGI
            environ

        ``**kwargs``
            All keyword arguments are assumed to be session settings and
            will override any settings found in ``config``

        """
        config = config or {}

        # Load up the default params
        self.options = dict(invalidate_corrupt=True, type=None,
                            data_dir=None, key='beaker.session.id',
                            timeout=None, save_accessed_time=True, secret=None,
                            log_file=None)

        # Pull out any config args meant for beaker session. if there are any
        for dct in [config, kwargs]:
            for key, val in dct.items():
                if key.startswith('beaker.session.'):
                    self.options[key[15:]] = val
                if key.startswith('session.'):
                    self.options[key[8:]] = val
                if key.startswith('session_'):
                    warnings.warn('Session options should start with session. '
                                  'instead of session_.', DeprecationWarning, 2)
                    self.options[key[8:]] = val

        # Coerce and validate session params
        coerce_session_params(self.options)

        # Assume all keys are intended for session if none are prefixed with
        # 'session.'
        if not self.options and config:
            self.options = config

        self.options.update(kwargs)
        self.wrap_app = self.app = wrap_app
        self.environ_key = environ_key

    def __call__(self, environ, start_response):
        session = SessionObject(environ, **self.options)
        if environ.get('paste.registry'):
            if environ['paste.registry'].reglist:
                environ['paste.registry'].register(self.session, session)
        environ[self.environ_key] = session
        environ['beaker.get_session'] = self._get_session

        if 'paste.testing_variables' in environ and 'webtest_varname' in self.options:
            environ['paste.testing_variables'][self.options['webtest_varname']] = session

        def session_start_response(status, headers, exc_info=None):
            if session.accessed():
                session.persist()
                if session.__dict__['_headers']['set_cookie']:
                    cookie = session.__dict__['_headers']['cookie_out']
                    if cookie:
                        headers.append(('Set-cookie', cookie))
            return start_response(status, headers, exc_info)
        return self.wrap_app(environ, session_start_response)

    def _get_session(self):
        return Session({}, use_cookies=False, **self.options)


def session_filter_factory(global_conf, **kwargs):
    def filter(app):
        return SessionMiddleware(app, global_conf, **kwargs)
    return filter


def session_filter_app_factory(app, global_conf, **kwargs):
    return SessionMiddleware(app, global_conf, **kwargs)

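A minimal WSGI sketch (not in the diff) of the session middleware above; the config values are placeholders, not bazarr's settings:

# Every request sees a lazy session under environ['beaker.session'].
from beaker.middleware import SessionMiddleware

def app(environ, start_response):
    session = environ['beaker.session']
    session['visits'] = session.get('visits', 0) + 1
    session.save()
    start_response('200 OK', [('Content-Type', 'text/plain')])
    return [('visits: %d' % session['visits']).encode('ascii')]

wsgi_app = SessionMiddleware(app, {'session.type': 'memory',
                                   'session.key': 'sid'})
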
@ -1,845 +0,0 @@
|
|||
from ._compat import PY2, pickle, http_cookies, unicode_text, b64encode, b64decode, string_type
|
||||
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from beaker.crypto import hmac as HMAC, hmac_sha1 as SHA1, sha1, get_nonce_size, DEFAULT_NONCE_BITS, get_crypto_module
|
||||
from beaker import crypto, util
|
||||
from beaker.cache import clsmap
|
||||
from beaker.exceptions import BeakerException, InvalidCryptoBackendError
|
||||
from beaker.cookie import SimpleCookie
|
||||
|
||||
__all__ = ['SignedCookie', 'Session', 'InvalidSignature']
|
||||
|
||||
|
||||
class _InvalidSignatureType(object):
|
||||
"""Returned from SignedCookie when the value's signature was invalid."""
|
||||
def __nonzero__(self):
|
||||
return False
|
||||
|
||||
def __bool__(self):
|
||||
return False
|
||||
|
||||
|
||||
InvalidSignature = _InvalidSignatureType()
|
||||
|
||||
|
||||
try:
|
||||
import uuid
|
||||
|
||||
def _session_id():
|
||||
return uuid.uuid4().hex
|
||||
except ImportError:
|
||||
import random
|
||||
if hasattr(os, 'getpid'):
|
||||
getpid = os.getpid
|
||||
else:
|
||||
def getpid():
|
||||
return ''
|
||||
|
||||
def _session_id():
|
||||
id_str = "%f%s%f%s" % (
|
||||
time.time(),
|
||||
id({}),
|
||||
random.random(),
|
||||
getpid()
|
||||
)
|
||||
# NB: nothing against second parameter to b64encode, but it seems
|
||||
# to be slower than simple chained replacement
|
||||
if not PY2:
|
||||
raw_id = b64encode(sha1(id_str.encode('ascii')).digest())
|
||||
return str(raw_id.replace(b'+', b'-').replace(b'/', b'_').rstrip(b'='))
|
||||
else:
|
||||
raw_id = b64encode(sha1(id_str).digest())
|
||||
return raw_id.replace('+', '-').replace('/', '_').rstrip('=')
|
||||
|
||||
|
||||
class SignedCookie(SimpleCookie):
|
||||
"""Extends python cookie to give digital signature support"""
|
||||
def __init__(self, secret, input=None):
|
||||
self.secret = secret.encode('UTF-8')
|
||||
http_cookies.BaseCookie.__init__(self, input)
|
||||
|
||||
def value_decode(self, val):
|
||||
val = val.strip('"')
|
||||
if not val:
|
||||
return None, val
|
||||
|
||||
sig = HMAC.new(self.secret, val[40:].encode('utf-8'), SHA1).hexdigest()
|
||||
|
||||
# Avoid timing attacks
|
||||
invalid_bits = 0
|
||||
input_sig = val[:40]
|
||||
if len(sig) != len(input_sig):
|
||||
return InvalidSignature, val
|
||||
|
||||
for a, b in zip(sig, input_sig):
|
||||
invalid_bits += a != b
|
||||
|
||||
if invalid_bits:
|
||||
return InvalidSignature, val
|
||||
else:
|
||||
return val[40:], val
|
||||
|
||||
def value_encode(self, val):
|
||||
sig = HMAC.new(self.secret, val.encode('utf-8'), SHA1).hexdigest()
|
||||
return str(val), ("%s%s" % (sig, val))
|
||||
|
||||
|
||||
class Session(dict):
|
||||
"""Session object that uses container package for storage.
|
||||
|
||||
:param invalidate_corrupt: How to handle corrupt data when loading. When
|
||||
set to True, then corrupt data will be silently
|
||||
invalidated and a new session created,
|
||||
otherwise invalid data will cause an exception.
|
||||
:type invalidate_corrupt: bool
|
||||
:param use_cookies: Whether or not cookies should be created. When set to
|
||||
False, it is assumed the user will handle storing the
|
||||
session on their own.
|
||||
:type use_cookies: bool
|
||||
:param type: What data backend type should be used to store the underlying
|
||||
session data
|
||||
:param key: The name the cookie should be set to.
|
||||
:param timeout: How long session data is considered valid. This is used
|
||||
regardless of the cookie being present or not to determine
|
||||
whether session data is still valid. Can be set to None to
|
||||
disable session time out.
|
||||
:type timeout: int or None
|
||||
:param save_accessed_time: Whether beaker should save the session's access
|
||||
time (True) or only modification time (False).
|
||||
Defaults to True.
|
||||
:param cookie_expires: Expiration date for cookie
|
||||
:param cookie_domain: Domain to use for the cookie.
|
||||
:param cookie_path: Path to use for the cookie.
|
||||
:param data_serializer: If ``"json"`` or ``"pickle"`` should be used
|
||||
to serialize data. Can also be an object with
|
||||
``loads` and ``dumps`` methods. By default
|
||||
``"pickle"`` is used.
|
||||
:param secure: Whether or not the cookie should only be sent over SSL.
|
||||
:param httponly: Whether or not the cookie should only be accessible by
|
||||
the browser not by JavaScript.
|
||||
:param encrypt_key: The key to use for the local session encryption, if not
|
||||
provided the session will not be encrypted.
|
||||
:param validate_key: The key used to sign the local encrypted session
|
||||
:param encrypt_nonce_bits: Number of bits used to generate nonce for encryption key salt.
|
||||
For security reason this is 128bits be default. If you want
|
||||
to keep backward compatibility with sessions generated before 1.8.0
|
||||
set this to 48.
|
||||
:param crypto_type: encryption module to use
|
||||
:param samesite: SameSite value for the cookie -- should be either 'Lax',
|
||||
'Strict', or None.
|
||||
"""
|
||||
def __init__(self, request, id=None, invalidate_corrupt=False,
|
||||
use_cookies=True, type=None, data_dir=None,
|
||||
key='beaker.session.id', timeout=None, save_accessed_time=True,
|
||||
cookie_expires=True, cookie_domain=None, cookie_path='/',
|
||||
data_serializer='pickle', secret=None,
|
||||
secure=False, namespace_class=None, httponly=False,
|
||||
encrypt_key=None, validate_key=None, encrypt_nonce_bits=DEFAULT_NONCE_BITS,
|
||||
crypto_type='default', samesite='Lax',
|
||||
**namespace_args):
|
||||
if not type:
|
||||
if data_dir:
|
||||
self.type = 'file'
|
||||
else:
|
||||
self.type = 'memory'
|
||||
else:
|
||||
self.type = type
|
||||
|
||||
self.namespace_class = namespace_class or clsmap[self.type]
|
||||
|
||||
self.namespace_args = namespace_args
|
||||
|
||||
self.request = request
|
||||
self.data_dir = data_dir
|
||||
self.key = key
|
||||
|
||||
if timeout and not save_accessed_time:
|
||||
raise BeakerException("timeout requires save_accessed_time")
|
||||
self.timeout = timeout
|
||||
|
||||
# If a timeout was provided, forward it to the backend too, so the backend
|
||||
# can automatically expire entries if it's supported.
|
||||
if self.timeout is not None:
|
||||
# The backend expiration should always be a bit longer than the
|
||||
# session expiration itself to prevent the case where the backend data expires while
|
||||
# the session is being read (PR#153). 2 Minutes seems a reasonable time.
|
||||
self.namespace_args['timeout'] = self.timeout + 60 * 2
|
||||
|
||||
self.save_atime = save_accessed_time
|
||||
self.use_cookies = use_cookies
|
||||
self.cookie_expires = cookie_expires
|
||||
|
||||
self._set_serializer(data_serializer)
|
||||
|
||||
# Default cookie domain/path
|
||||
self._domain = cookie_domain
|
||||
self._path = cookie_path
|
||||
self.was_invalidated = False
|
||||
self.secret = secret
|
||||
self.secure = secure
|
||||
self.httponly = httponly
|
||||
self.samesite = samesite
|
||||
self.encrypt_key = encrypt_key
|
||||
self.validate_key = validate_key
|
||||
self.encrypt_nonce_size = get_nonce_size(encrypt_nonce_bits)
|
||||
self.crypto_module = get_crypto_module(crypto_type)
|
||||
self.id = id
|
||||
self.accessed_dict = {}
|
||||
self.invalidate_corrupt = invalidate_corrupt
|
||||
|
||||
if self.use_cookies:
|
||||
cookieheader = request.get('cookie', '')
|
||||
if secret:
|
||||
try:
|
||||
self.cookie = SignedCookie(
|
||||
secret,
|
||||
input=cookieheader,
|
||||
)
|
||||
except http_cookies.CookieError:
|
||||
self.cookie = SignedCookie(
|
||||
secret,
|
||||
input=None,
|
||||
)
|
||||
else:
|
||||
self.cookie = SimpleCookie(input=cookieheader)
|
||||
|
||||
if not self.id and self.key in self.cookie:
|
||||
cookie_data = self.cookie[self.key].value
|
||||
# Should we check invalidate_corrupt here?
|
||||
if cookie_data is InvalidSignature:
|
||||
cookie_data = None
|
||||
self.id = cookie_data
|
||||
|
||||
self.is_new = self.id is None
|
||||
if self.is_new:
|
||||
self._create_id()
|
||||
self['_accessed_time'] = self['_creation_time'] = time.time()
|
||||
else:
|
||||
try:
|
||||
self.load()
|
||||
except Exception as e:
|
||||
if self.invalidate_corrupt:
|
||||
util.warn(
|
||||
"Invalidating corrupt session %s; "
|
||||
"error was: %s. Set invalidate_corrupt=False "
|
||||
"to propagate this exception." % (self.id, e))
|
||||
self.invalidate()
|
||||
else:
|
||||
raise
|
||||
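# Illustrative usage sketch (added comment, not part of the original
# source; the request mapping shape mirrors what SessionObject builds
# below, and 'environ' is an assumed WSGI environ):
#
#   >>> req = {'cookie': environ.get('HTTP_COOKIE', ''), 'cookie_out': None}
#   >>> sess = Session(req, secret='s3cr3t', type='memory', timeout=300)
#   >>> sess['user_id'] = 42
#   >>> sess.save()
#   >>> req['cookie_out']  # now holds the rendered Set-Cookie value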
|
||||
def _set_serializer(self, data_serializer):
|
||||
self.data_serializer = data_serializer
|
||||
if self.data_serializer == 'json':
|
||||
self.serializer = util.JsonSerializer()
|
||||
elif self.data_serializer == 'pickle':
|
||||
self.serializer = util.PickleSerializer()
|
||||
elif isinstance(self.data_serializer, string_type):
|
||||
raise BeakerException('Invalid value for data_serializer: %s' % data_serializer)
|
||||
else:
|
||||
self.serializer = data_serializer
|
||||
|
||||
def has_key(self, name):
|
||||
return name in self
|
||||
|
||||
def _set_cookie_values(self, expires=None):
|
||||
self.cookie[self.key] = self.id
|
||||
if self._domain:
|
||||
self.cookie[self.key]['domain'] = self._domain
|
||||
if self.secure:
|
||||
self.cookie[self.key]['secure'] = True
|
||||
if self.samesite:
|
||||
self.cookie[self.key]['samesite'] = self.samesite
|
||||
self._set_cookie_http_only()
|
||||
self.cookie[self.key]['path'] = self._path
|
||||
|
||||
self._set_cookie_expires(expires)
|
||||
|
||||
def _set_cookie_expires(self, expires):
|
||||
if expires is None:
|
||||
expires = self.cookie_expires
|
||||
if expires is False:
|
||||
expires_date = datetime.fromtimestamp(0x7FFFFFFF)
|
||||
elif isinstance(expires, timedelta):
|
||||
expires_date = datetime.utcnow() + expires
|
||||
elif isinstance(expires, datetime):
|
||||
expires_date = expires
|
||||
elif expires is not True:
|
||||
raise ValueError("Invalid argument for cookie_expires: %s"
|
||||
% repr(self.cookie_expires))
|
||||
self.cookie_expires = expires
|
||||
if not self.cookie or self.key not in self.cookie:
|
||||
self.cookie[self.key] = self.id
|
||||
if expires is True:
|
||||
self.cookie[self.key]['expires'] = ''
|
||||
return True
|
||||
self.cookie[self.key]['expires'] = \
|
||||
expires_date.strftime("%a, %d-%b-%Y %H:%M:%S GMT")
|
||||
return expires_date
|
||||
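# The cookie_expires contract implemented above, summarized for reference
# (added comment; values are examples):
#
#   cookie_expires=True                  -> session cookie ('expires' left empty)
#   cookie_expires=False                 -> far-future date (end of 32-bit epoch)
#   cookie_expires=timedelta(hours=1)    -> datetime.utcnow() + 1 hour
#   cookie_expires=datetime(2030, 1, 1)  -> that absolute moment, as given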
|
||||
def _update_cookie_out(self, set_cookie=True):
|
||||
self._set_cookie_values()
|
||||
self.request['cookie_out'] = self.cookie[self.key].output(header='')
|
||||
self.request['set_cookie'] = set_cookie
|
||||
|
||||
def _set_cookie_http_only(self):
|
||||
try:
|
||||
if self.httponly:
|
||||
self.cookie[self.key]['httponly'] = True
|
||||
except http_cookies.CookieError as e:
|
||||
if 'Invalid Attribute httponly' not in str(e):
|
||||
raise
|
||||
util.warn('Python 2.6+ is required to use httponly')
|
||||
|
||||
def _create_id(self, set_new=True):
|
||||
self.id = _session_id()
|
||||
|
||||
if set_new:
|
||||
self.is_new = True
|
||||
self.last_accessed = None
|
||||
if self.use_cookies:
|
||||
sc = set_new is False
|
||||
self._update_cookie_out(set_cookie=sc)
|
||||
|
||||
@property
|
||||
def created(self):
|
||||
return self['_creation_time']
|
||||
|
||||
def _set_domain(self, domain):
|
||||
self['_domain'] = self._domain = domain
|
||||
self._update_cookie_out()
|
||||
|
||||
def _get_domain(self):
|
||||
return self._domain
|
||||
|
||||
domain = property(_get_domain, _set_domain)
|
||||
|
||||
def _set_path(self, path):
|
||||
self['_path'] = self._path = path
|
||||
self._update_cookie_out()
|
||||
|
||||
def _get_path(self):
|
||||
return self._path
|
||||
|
||||
path = property(_get_path, _set_path)
|
||||
|
||||
def _encrypt_data(self, session_data=None):
|
||||
"""Serialize, encipher, and base64 the session dict"""
|
||||
session_data = session_data or self.copy()
|
||||
if self.encrypt_key:
|
||||
nonce_len, nonce_b64len = self.encrypt_nonce_size
|
||||
nonce = b64encode(os.urandom(nonce_len))[:nonce_b64len]
|
||||
encrypt_key = crypto.generateCryptoKeys(self.encrypt_key,
|
||||
self.validate_key + nonce,
|
||||
1,
|
||||
self.crypto_module.getKeyLength())
|
||||
data = self.serializer.dumps(session_data)
|
||||
return nonce + b64encode(self.crypto_module.aesEncrypt(data, encrypt_key))
|
||||
else:
|
||||
data = self.serializer.dumps(session_data)
|
||||
return b64encode(data)
|
||||
|
||||
def _decrypt_data(self, session_data):
|
||||
"""Base64, decipher, then un-serialize the data for the session
|
||||
dict"""
|
||||
if self.encrypt_key:
|
||||
__, nonce_b64len = self.encrypt_nonce_size
|
||||
nonce = session_data[:nonce_b64len]
|
||||
encrypt_key = crypto.generateCryptoKeys(self.encrypt_key,
|
||||
self.validate_key + nonce,
|
||||
1,
|
||||
self.crypto_module.getKeyLength())
|
||||
payload = b64decode(session_data[nonce_b64len:])
|
||||
data = self.crypto_module.aesDecrypt(payload, encrypt_key)
|
||||
else:
|
||||
data = b64decode(session_data)
|
||||
|
||||
return self.serializer.loads(data)
|
||||
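# Payload layout used by the two methods above (descriptive comment):
# with an encrypt_key the stored value is
#
#   nonce + b64encode(aesEncrypt(serializer.dumps(session_dict), derived_key))
#
# where derived_key = crypto.generateCryptoKeys(encrypt_key,
# validate_key + nonce, 1, key_length). _decrypt_data re-derives the same
# key by splitting the nonce prefix back off. Without an encrypt_key the
# value is simply b64encode(serializer.dumps(session_dict)).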
|
||||
def _delete_cookie(self):
|
||||
self.request['set_cookie'] = True
|
||||
expires = datetime.utcnow() - timedelta(365)
|
||||
self._set_cookie_values(expires)
|
||||
self._update_cookie_out()
|
||||
|
||||
def delete(self):
|
||||
"""Deletes the session from the persistent storage, and sends
|
||||
an expired cookie out"""
|
||||
if self.use_cookies:
|
||||
self._delete_cookie()
|
||||
self.clear()
|
||||
|
||||
def invalidate(self):
|
||||
"""Invalidates this session, creates a new session id, returns
|
||||
to the is_new state"""
|
||||
self.clear()
|
||||
self.was_invalidated = True
|
||||
self._create_id()
|
||||
self.load()
|
||||
|
||||
def load(self):
|
||||
"Loads the data from this session from persistent storage"
|
||||
self.namespace = self.namespace_class(self.id,
|
||||
data_dir=self.data_dir,
|
||||
digest_filenames=False,
|
||||
**self.namespace_args)
|
||||
now = time.time()
|
||||
if self.use_cookies:
|
||||
self.request['set_cookie'] = True
|
||||
|
||||
self.namespace.acquire_read_lock()
|
||||
timed_out = False
|
||||
try:
|
||||
self.clear()
|
||||
try:
|
||||
session_data = self.namespace['session']
|
||||
|
||||
if (session_data is not None and self.encrypt_key):
|
||||
session_data = self._decrypt_data(session_data)
|
||||
|
||||
# Memcached always returns a key; it's None when it's not
|
||||
# present
|
||||
if session_data is None:
|
||||
session_data = {
|
||||
'_creation_time': now,
|
||||
'_accessed_time': now
|
||||
}
|
||||
self.is_new = True
|
||||
except (KeyError, TypeError):
|
||||
session_data = {
|
||||
'_creation_time': now,
|
||||
'_accessed_time': now
|
||||
}
|
||||
self.is_new = True
|
||||
|
||||
if session_data is None or len(session_data) == 0:
|
||||
session_data = {
|
||||
'_creation_time': now,
|
||||
'_accessed_time': now
|
||||
}
|
||||
self.is_new = True
|
||||
|
||||
if self.timeout is not None and \
|
||||
now - session_data['_accessed_time'] > self.timeout:
|
||||
timed_out = True
|
||||
else:
|
||||
# Properly set the last_accessed time, which is different
|
||||
# from the *current* _accessed_time
|
||||
if self.is_new or '_accessed_time' not in session_data:
|
||||
self.last_accessed = None
|
||||
else:
|
||||
self.last_accessed = session_data['_accessed_time']
|
||||
|
||||
# Update the current _accessed_time
|
||||
session_data['_accessed_time'] = now
|
||||
|
||||
# Set the path if applicable
|
||||
if '_path' in session_data:
|
||||
self._path = session_data['_path']
|
||||
self.update(session_data)
|
||||
self.accessed_dict = session_data.copy()
|
||||
finally:
|
||||
self.namespace.release_read_lock()
|
||||
if timed_out:
|
||||
self.invalidate()
|
||||
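# Worked example of the timeout check above (illustrative numbers): with
# timeout=300 and an '_accessed_time' recorded 400 seconds ago,
# now - session_data['_accessed_time'] == 400 > 300, so timed_out is set
# and invalidate() issues a fresh id once the read lock is released.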
|
||||
def save(self, accessed_only=False):
|
||||
"""Saves the data for this session to persistent storage
|
||||
|
||||
If accessed_only is True, then only the original data loaded
|
||||
at the beginning of the request will be saved, with the updated
|
||||
last accessed time.
|
||||
|
||||
"""
|
||||
# Look to see if it's a new session that was only accessed;
|
||||
# don't save it in that case
|
||||
if accessed_only and (self.is_new or not self.save_atime):
|
||||
return None
|
||||
|
||||
# this session might not have a namespace yet or the session id
|
||||
# might have been regenerated
|
||||
if not hasattr(self, 'namespace') or self.namespace.namespace != self.id:
|
||||
self.namespace = self.namespace_class(
|
||||
self.id,
|
||||
data_dir=self.data_dir,
|
||||
digest_filenames=False,
|
||||
**self.namespace_args)
|
||||
|
||||
self.namespace.acquire_write_lock(replace=True)
|
||||
try:
|
||||
if accessed_only:
|
||||
data = dict(self.accessed_dict.items())
|
||||
else:
|
||||
data = dict(self.items())
|
||||
|
||||
if self.encrypt_key:
|
||||
data = self._encrypt_data(data)
|
||||
|
||||
# Save the data
|
||||
if not data and 'session' in self.namespace:
|
||||
del self.namespace['session']
|
||||
else:
|
||||
self.namespace['session'] = data
|
||||
finally:
|
||||
self.namespace.release_write_lock()
|
||||
if self.use_cookies and self.is_new:
|
||||
self.request['set_cookie'] = True
|
||||
|
||||
def revert(self):
|
||||
"""Revert the session to its original state from its first
|
||||
access in the request"""
|
||||
self.clear()
|
||||
self.update(self.accessed_dict)
|
||||
|
||||
def regenerate_id(self):
|
||||
"""
|
||||
Creates a new session id while retaining all session data.
|
||||
|
||||
It's a good security practice to regenerate the id after a client
|
||||
elevates privileges.
|
||||
|
||||
"""
|
||||
self._create_id(set_new=False)
|
||||
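# Illustrative call site (assumed application code, not from this module):
#
#   >>> if login_succeeded:
#   ...     session.regenerate_id()  # mitigates session-fixation attacks
#   ...     session['user_id'] = user.id
#   ...     session.save()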
|
||||
# TODO: I think both these methods should be removed. They're from
|
||||
# the original mod_python code I was ripping off, but they really
|
||||
# have no use here.
|
||||
def lock(self):
|
||||
"""Locks this session against other processes/threads. This is
|
||||
automatic when load/save is called.
|
||||
|
||||
***use with caution*** and always with a corresponding 'unlock'
|
||||
inside a "finally:" block, as a stray lock typically cannot be
|
||||
unlocked without shutting down the whole application.
|
||||
|
||||
"""
|
||||
self.namespace.acquire_write_lock()
|
||||
|
||||
def unlock(self):
|
||||
"""Unlocks this session against other processes/threads. This
|
||||
is automatic when load/save is called.
|
||||
|
||||
***use with caution*** and always within a "finally:" block, as
|
||||
a stray lock typically cannot be unlocked without shutting down
|
||||
the whole application.
|
||||
|
||||
"""
|
||||
self.namespace.release_write_lock()
|
||||
|
||||
|
||||
class CookieSession(Session):
|
||||
"""Pure cookie-based session
|
||||
|
||||
Options recognized when using cookie-based sessions are slightly
|
||||
more restricted than general sessions.
|
||||
|
||||
:param key: The name the cookie should be set to.
|
||||
:param timeout: How long session data is considered valid. This is used
|
||||
regardless of the cookie being present or not to determine
|
||||
whether session data is still valid.
|
||||
:type timeout: int
|
||||
:param save_accessed_time: Whether beaker should save the session's access
|
||||
time (True) or only modification time (False).
|
||||
Defaults to True.
|
||||
:param cookie_expires: Expiration date for cookie
|
||||
:param cookie_domain: Domain to use for the cookie.
|
||||
:param cookie_path: Path to use for the cookie.
|
||||
:param data_serializer: If ``"json"`` or ``"pickle"`` should be used
|
||||
to serialize data. Can also be an object with
|
||||
``loads`` and ``dumps`` methods. By default
|
||||
``"pickle"`` is used.
|
||||
:param secure: Whether or not the cookie should only be sent over SSL.
|
||||
:param httponly: Whether or not the cookie should only be accessible by
|
||||
the browser, not by JavaScript.
|
||||
:param encrypt_key: The key to use for the local session encryption; if not
|
||||
provided, the session will not be encrypted.
|
||||
:param validate_key: The key used to sign the local encrypted session
|
||||
:param invalidate_corrupt: How to handle corrupt data when loading. When
|
||||
set to True, corrupt data will be silently
|
||||
invalidated and a new session created;
|
||||
otherwise invalid data will cause an exception.
|
||||
:type invalidate_corrupt: bool
|
||||
:param crypto_type: The crypto module to use.
|
||||
:param samesite: SameSite value for the cookie -- should be either 'Lax',
|
||||
'Strict', or None.
|
||||
"""
|
||||
def __init__(self, request, key='beaker.session.id', timeout=None,
|
||||
save_accessed_time=True, cookie_expires=True, cookie_domain=None,
|
||||
cookie_path='/', encrypt_key=None, validate_key=None, secure=False,
|
||||
httponly=False, data_serializer='pickle',
|
||||
encrypt_nonce_bits=DEFAULT_NONCE_BITS, invalidate_corrupt=False,
|
||||
crypto_type='default', samesite='Lax',
|
||||
**kwargs):
|
||||
|
||||
self.crypto_module = get_crypto_module(crypto_type)
|
||||
|
||||
if encrypt_key and not self.crypto_module.has_aes:
|
||||
raise InvalidCryptoBackendError("No AES library is installed, can't generate "
|
||||
"encrypted cookie-only Session.")
|
||||
|
||||
self.request = request
|
||||
self.key = key
|
||||
self.timeout = timeout
|
||||
self.save_atime = save_accessed_time
|
||||
self.cookie_expires = cookie_expires
|
||||
self.encrypt_key = encrypt_key
|
||||
self.validate_key = validate_key
|
||||
self.encrypt_nonce_size = get_nonce_size(encrypt_nonce_bits)
|
||||
self.request['set_cookie'] = False
|
||||
self.secure = secure
|
||||
self.httponly = httponly
|
||||
self.samesite = samesite
|
||||
self._domain = cookie_domain
|
||||
self._path = cookie_path
|
||||
self.invalidate_corrupt = invalidate_corrupt
|
||||
self._set_serializer(data_serializer)
|
||||
|
||||
try:
|
||||
cookieheader = request['cookie']
|
||||
except KeyError:
|
||||
cookieheader = ''
|
||||
|
||||
if validate_key is None:
|
||||
raise BeakerException("No validate_key specified for Cookie only "
|
||||
"Session.")
|
||||
if timeout and not save_accessed_time:
|
||||
raise BeakerException("timeout requires save_accessed_time")
|
||||
|
||||
try:
|
||||
self.cookie = SignedCookie(
|
||||
validate_key,
|
||||
input=cookieheader,
|
||||
)
|
||||
except http_cookies.CookieError:
|
||||
self.cookie = SignedCookie(
|
||||
validate_key,
|
||||
input=None,
|
||||
)
|
||||
|
||||
self['_id'] = _session_id()
|
||||
self.is_new = True
|
||||
|
||||
# If we have a cookie, load it
|
||||
if self.key in self.cookie and self.cookie[self.key].value is not None:
|
||||
self.is_new = False
|
||||
try:
|
||||
cookie_data = self.cookie[self.key].value
|
||||
if cookie_data is InvalidSignature:
|
||||
raise BeakerException("Invalid signature")
|
||||
self.update(self._decrypt_data(cookie_data))
|
||||
self._path = self.get('_path', '/')
|
||||
except Exception as e:
|
||||
if self.invalidate_corrupt:
|
||||
util.warn(
|
||||
"Invalidating corrupt session %s; "
|
||||
"error was: %s. Set invalidate_corrupt=False "
|
||||
"to propagate this exception." % (self.id, e))
|
||||
self.invalidate()
|
||||
else:
|
||||
raise
|
||||
|
||||
if self.timeout is not None:
|
||||
now = time.time()
|
||||
last_accessed_time = self.get('_accessed_time', now)
|
||||
if now - last_accessed_time > self.timeout:
|
||||
self.clear()
|
||||
|
||||
self.accessed_dict = self.copy()
|
||||
self._create_cookie()
|
||||
|
||||
def created(self):
|
||||
return self['_creation_time']
|
||||
created = property(created)
|
||||
|
||||
def id(self):
|
||||
return self['_id']
|
||||
id = property(id)
|
||||
|
||||
def _set_domain(self, domain):
|
||||
self['_domain'] = domain
|
||||
self._domain = domain
|
||||
|
||||
def _get_domain(self):
|
||||
return self._domain
|
||||
|
||||
domain = property(_get_domain, _set_domain)
|
||||
|
||||
def _set_path(self, path):
|
||||
self['_path'] = self._path = path
|
||||
|
||||
def _get_path(self):
|
||||
return self._path
|
||||
|
||||
path = property(_get_path, _set_path)
|
||||
|
||||
def save(self, accessed_only=False):
|
||||
"""Saves the data for this session to persistent storage"""
|
||||
if accessed_only and (self.is_new or not self.save_atime):
|
||||
return
|
||||
if accessed_only:
|
||||
self.clear()
|
||||
self.update(self.accessed_dict)
|
||||
self._create_cookie()
|
||||
|
||||
def expire(self):
|
||||
"""Delete the 'expires' attribute on this Session, if any."""
|
||||
|
||||
self.pop('_expires', None)
|
||||
|
||||
def _create_cookie(self):
|
||||
if '_creation_time' not in self:
|
||||
self['_creation_time'] = time.time()
|
||||
if '_id' not in self:
|
||||
self['_id'] = _session_id()
|
||||
self['_accessed_time'] = time.time()
|
||||
|
||||
val = self._encrypt_data()
|
||||
if len(val) > 4064:
|
||||
raise BeakerException("Cookie value is too long to store")
|
||||
|
||||
self.cookie[self.key] = val
|
||||
|
||||
if '_expires' in self:
|
||||
expires = self['_expires']
|
||||
else:
|
||||
expires = None
|
||||
expires = self._set_cookie_expires(expires)
|
||||
if expires is not None:
|
||||
self['_expires'] = expires
|
||||
|
||||
if '_domain' in self:
|
||||
self.cookie[self.key]['domain'] = self['_domain']
|
||||
elif self._domain:
|
||||
self.cookie[self.key]['domain'] = self._domain
|
||||
if self.secure:
|
||||
self.cookie[self.key]['secure'] = True
|
||||
self._set_cookie_http_only()
|
||||
|
||||
self.cookie[self.key]['path'] = self.get('_path', '/')
|
||||
|
||||
self.request['cookie_out'] = self.cookie[self.key].output(header='')
|
||||
self.request['set_cookie'] = True
|
||||
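# Sizing note for the 4064-byte guard above (added comment): the whole
# session dict must round-trip through _encrypt_data and still fit in one
# cookie, so cookie-only sessions suit a handful of small scalars:
#
#   >>> sess['prefs'] = {'theme': 'dark'}  # fine
#   >>> sess['avatar'] = image_bytes       # likely raises BeakerException
#
# 'image_bytes' is a stand-in name for any large value.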
|
||||
def delete(self):
|
||||
"""Delete the cookie, and clear the session"""
|
||||
# Send a delete cookie request
|
||||
self._delete_cookie()
|
||||
self.clear()
|
||||
|
||||
def invalidate(self):
|
||||
"""Clear the contents and start a new session"""
|
||||
self.clear()
|
||||
self['_id'] = _session_id()
|
||||
|
||||
|
||||
class SessionObject(object):
|
||||
"""Session proxy/lazy creator
|
||||
|
||||
This object proxies access to the actual session object, so that in
|
||||
the case that the session hasn't been used before, it will be
|
||||
set up. This avoids creating and loading the session from persistent
|
||||
storage unless it's actually used during the request.
|
||||
|
||||
"""
|
||||
def __init__(self, environ, **params):
|
||||
self.__dict__['_params'] = params
|
||||
self.__dict__['_environ'] = environ
|
||||
self.__dict__['_sess'] = None
|
||||
self.__dict__['_headers'] = {}
|
||||
|
||||
def _session(self):
|
||||
"""Lazy initial creation of session object"""
|
||||
if self.__dict__['_sess'] is None:
|
||||
params = self.__dict__['_params']
|
||||
environ = self.__dict__['_environ']
|
||||
self.__dict__['_headers'] = req = {'cookie_out': None}
|
||||
req['cookie'] = environ.get('HTTP_COOKIE')
|
||||
session_cls = params.get('session_class', None)
|
||||
if session_cls is None:
|
||||
if params.get('type') == 'cookie':
|
||||
session_cls = CookieSession
|
||||
else:
|
||||
session_cls = Session
|
||||
else:
|
||||
assert issubclass(session_cls, Session),\
|
||||
"Not a Session: " + session_cls
|
||||
self.__dict__['_sess'] = session_cls(req, **params)
|
||||
return self.__dict__['_sess']
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self._session(), attr)
|
||||
|
||||
def __setattr__(self, attr, value):
|
||||
setattr(self._session(), attr, value)
|
||||
|
||||
def __delattr__(self, name):
|
||||
self._session().__delattr__(name)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._session()[key]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self._session()[key] = value
|
||||
|
||||
def __delitem__(self, key):
|
||||
self._session().__delitem__(key)
|
||||
|
||||
def __repr__(self):
|
||||
return self._session().__repr__()
|
||||
|
||||
def __iter__(self):
|
||||
"""Only works for proxying to a dict"""
|
||||
return iter(self._session().keys())
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._session()
|
||||
|
||||
def has_key(self, key):
|
||||
return key in self._session()
|
||||
|
||||
def get_by_id(self, id):
|
||||
"""Loads a session given a session ID"""
|
||||
params = self.__dict__['_params']
|
||||
session = Session({}, use_cookies=False, id=id, **params)
|
||||
if session.is_new:
|
||||
return None
|
||||
return session
|
||||
|
||||
def save(self):
|
||||
self.__dict__['_dirty'] = True
|
||||
|
||||
def delete(self):
|
||||
self.__dict__['_dirty'] = True
|
||||
self._session().delete()
|
||||
|
||||
def persist(self):
|
||||
"""Persist the session to the storage
|
||||
|
||||
Always saves the whole session if save() or delete() have been called.
|
||||
If they haven't:
|
||||
|
||||
- If autosave is set to true, saves the entire session regardless.
|
||||
- If save_accessed_time is set to true or unset, only saves the updated
|
||||
access time.
|
||||
- If save_accessed_time is set to false, doesn't save anything.
|
||||
|
||||
"""
|
||||
if self.__dict__['_params'].get('auto'):
|
||||
self._session().save()
|
||||
elif self.__dict__['_params'].get('save_accessed_time', True):
|
||||
if self.dirty():
|
||||
self._session().save()
|
||||
else:
|
||||
self._session().save(accessed_only=True)
|
||||
else: # save_accessed_time is false
|
||||
if self.dirty():
|
||||
self._session().save()
|
||||
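# Typical end-of-request wiring (illustrative; the surrounding middleware
# is assumed -- persist() is the single call that applies the rules above):
#
#   >>> session = SessionObject(environ, type='file', data_dir='/tmp/sess')
#   >>> ...                    # application code reads/writes the session
#   >>> if session.accessed():
#   ...     session.persist()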
|
||||
def dirty(self):
|
||||
"""Returns True if save() or delete() have been called"""
|
||||
return self.__dict__.get('_dirty', False)
|
||||
|
||||
def accessed(self):
|
||||
"""Returns whether or not the session has been accessed"""
|
||||
return self.__dict__['_sess'] is not None
|
|
@@ -1,392 +0,0 @@
|
|||
"""Synchronization functions.
|
||||
|
||||
File- and mutex-based mutual exclusion synchronizers are provided,
|
||||
as well as a name-based mutex which locks within an application
|
||||
based on a string name.
|
||||
|
||||
"""
|
||||
import errno
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
try:
|
||||
import threading as _threading
|
||||
except ImportError:
|
||||
import dummy_threading as _threading
|
||||
|
||||
# check for fcntl module
|
||||
try:
|
||||
sys.getwindowsversion()
|
||||
has_flock = False
|
||||
except:
|
||||
try:
|
||||
import fcntl
|
||||
has_flock = True
|
||||
except ImportError:
|
||||
has_flock = False
|
||||
|
||||
from beaker import util
|
||||
from beaker.exceptions import LockError
|
||||
|
||||
__all__ = ["file_synchronizer", "mutex_synchronizer", "null_synchronizer",
|
||||
"NameLock", "_threading"]
|
||||
|
||||
|
||||
class NameLock(object):
|
||||
"""a proxy for an RLock object that is stored in a name based
|
||||
registry.
|
||||
|
||||
Multiple threads can get a reference to the same RLock based on the
|
||||
name alone, and synchronize operations related to that name.
|
||||
|
||||
"""
|
||||
locks = util.WeakValuedRegistry()
|
||||
|
||||
class NLContainer(object):
|
||||
def __init__(self, reentrant):
|
||||
if reentrant:
|
||||
self.lock = _threading.RLock()
|
||||
else:
|
||||
self.lock = _threading.Lock()
|
||||
|
||||
def __call__(self):
|
||||
return self.lock
|
||||
|
||||
def __init__(self, identifier=None, reentrant=False):
|
||||
if identifier is None:
|
||||
self._lock = NameLock.NLContainer(reentrant)
|
||||
else:
|
||||
self._lock = NameLock.locks.get(identifier, NameLock.NLContainer,
|
||||
reentrant)
|
||||
|
||||
def acquire(self, wait=True):
|
||||
return self._lock().acquire(wait)
|
||||
|
||||
def release(self):
|
||||
self._lock().release()
|
||||
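# Usage sketch (illustrative identifier): callers asking for the same name
# share one underlying lock, so work keyed by that name serializes:
#
#   >>> lock = NameLock('users-cache', reentrant=True)
#   >>> lock.acquire()
#   >>> try:
#   ...     pass  # critical section for 'users-cache'
#   ... finally:
#   ...     lock.release()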
|
||||
|
||||
_synchronizers = util.WeakValuedRegistry()
|
||||
|
||||
|
||||
def _synchronizer(identifier, cls, **kwargs):
|
||||
return _synchronizers.sync_get((identifier, cls), cls, identifier, **kwargs)
|
||||
|
||||
|
||||
def file_synchronizer(identifier, **kwargs):
|
||||
if not has_flock or 'lock_dir' not in kwargs:
|
||||
return mutex_synchronizer(identifier)
|
||||
else:
|
||||
return _synchronizer(identifier, FileSynchronizer, **kwargs)
|
||||
|
||||
|
||||
def mutex_synchronizer(identifier, **kwargs):
|
||||
return _synchronizer(identifier, ConditionSynchronizer, **kwargs)
|
||||
|
||||
|
||||
class null_synchronizer(object):
|
||||
"""A 'null' synchronizer, which provides the :class:`.SynchronizerImpl` interface
|
||||
without any locking.
|
||||
|
||||
"""
|
||||
def acquire_write_lock(self, wait=True):
|
||||
return True
|
||||
|
||||
def acquire_read_lock(self):
|
||||
pass
|
||||
|
||||
def release_write_lock(self):
|
||||
pass
|
||||
|
||||
def release_read_lock(self):
|
||||
pass
|
||||
acquire = acquire_write_lock
|
||||
release = release_write_lock
|
||||
|
||||
|
||||
class SynchronizerImpl(object):
|
||||
"""Base class for a synchronization object that allows
|
||||
multiple readers, single writers.
|
||||
|
||||
"""
|
||||
def __init__(self):
|
||||
self._state = util.ThreadLocal()
|
||||
|
||||
class SyncState(object):
|
||||
__slots__ = 'reentrantcount', 'writing', 'reading'
|
||||
|
||||
def __init__(self):
|
||||
self.reentrantcount = 0
|
||||
self.writing = False
|
||||
self.reading = False
|
||||
|
||||
def state(self):
|
||||
if not self._state.has():
|
||||
state = SynchronizerImpl.SyncState()
|
||||
self._state.put(state)
|
||||
return state
|
||||
else:
|
||||
return self._state.get()
|
||||
state = property(state)
|
||||
|
||||
def release_read_lock(self):
|
||||
state = self.state
|
||||
|
||||
if state.writing:
|
||||
raise LockError("lock is in writing state")
|
||||
if not state.reading:
|
||||
raise LockError("lock is not in reading state")
|
||||
|
||||
if state.reentrantcount == 1:
|
||||
self.do_release_read_lock()
|
||||
state.reading = False
|
||||
|
||||
state.reentrantcount -= 1
|
||||
|
||||
def acquire_read_lock(self, wait=True):
|
||||
state = self.state
|
||||
|
||||
if state.writing:
|
||||
raise LockError("lock is in writing state")
|
||||
|
||||
if state.reentrantcount == 0:
|
||||
x = self.do_acquire_read_lock(wait)
|
||||
if (wait or x):
|
||||
state.reentrantcount += 1
|
||||
state.reading = True
|
||||
return x
|
||||
elif state.reading:
|
||||
state.reentrantcount += 1
|
||||
return True
|
||||
|
||||
def release_write_lock(self):
|
||||
state = self.state
|
||||
|
||||
if state.reading:
|
||||
raise LockError("lock is in reading state")
|
||||
if not state.writing:
|
||||
raise LockError("lock is not in writing state")
|
||||
|
||||
if state.reentrantcount == 1:
|
||||
self.do_release_write_lock()
|
||||
state.writing = False
|
||||
|
||||
state.reentrantcount -= 1
|
||||
|
||||
release = release_write_lock
|
||||
|
||||
def acquire_write_lock(self, wait=True):
|
||||
state = self.state
|
||||
|
||||
if state.reading:
|
||||
raise LockError("lock is in reading state")
|
||||
|
||||
if state.reentrantcount == 0:
|
||||
x = self.do_acquire_write_lock(wait)
|
||||
if (wait or x):
|
||||
state.reentrantcount += 1
|
||||
state.writing = True
|
||||
return x
|
||||
elif state.writing:
|
||||
state.reentrantcount += 1
|
||||
return True
|
||||
|
||||
acquire = acquire_write_lock
|
||||
|
||||
def do_release_read_lock(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def do_acquire_read_lock(self, wait):
|
||||
raise NotImplementedError()
|
||||
|
||||
def do_release_write_lock(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def do_acquire_write_lock(self, wait):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class FileSynchronizer(SynchronizerImpl):
|
||||
"""A synchronizer which locks using flock().
|
||||
|
||||
"""
|
||||
def __init__(self, identifier, lock_dir):
|
||||
super(FileSynchronizer, self).__init__()
|
||||
self._filedescriptor = util.ThreadLocal()
|
||||
|
||||
if lock_dir is None:
|
||||
lock_dir = tempfile.gettempdir()
|
||||
|
||||
|
||||
self.filename = util.encoded_path(
|
||||
lock_dir,
|
||||
[identifier],
|
||||
extension='.lock'
|
||||
)
|
||||
self.lock_dir = os.path.dirname(self.filename)
|
||||
|
||||
def _filedesc(self):
|
||||
return self._filedescriptor.get()
|
||||
_filedesc = property(_filedesc)
|
||||
|
||||
def _ensuredir(self):
|
||||
if not os.path.exists(self.lock_dir):
|
||||
util.verify_directory(self.lock_dir)
|
||||
|
||||
def _open(self, mode):
|
||||
filedescriptor = self._filedesc
|
||||
if filedescriptor is None:
|
||||
self._ensuredir()
|
||||
filedescriptor = os.open(self.filename, mode)
|
||||
self._filedescriptor.put(filedescriptor)
|
||||
return filedescriptor
|
||||
|
||||
def do_acquire_read_lock(self, wait):
|
||||
filedescriptor = self._open(os.O_CREAT | os.O_RDONLY)
|
||||
if not wait:
|
||||
try:
|
||||
fcntl.flock(filedescriptor, fcntl.LOCK_SH | fcntl.LOCK_NB)
|
||||
return True
|
||||
except IOError:
|
||||
os.close(filedescriptor)
|
||||
self._filedescriptor.remove()
|
||||
return False
|
||||
else:
|
||||
fcntl.flock(filedescriptor, fcntl.LOCK_SH)
|
||||
return True
|
||||
|
||||
def do_acquire_write_lock(self, wait):
|
||||
filedescriptor = self._open(os.O_CREAT | os.O_WRONLY)
|
||||
if not wait:
|
||||
try:
|
||||
fcntl.flock(filedescriptor, fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||
return True
|
||||
except IOError:
|
||||
os.close(filedescriptor)
|
||||
self._filedescriptor.remove()
|
||||
return False
|
||||
else:
|
||||
fcntl.flock(filedescriptor, fcntl.LOCK_EX)
|
||||
return True
|
||||
|
||||
def do_release_read_lock(self):
|
||||
self._release_all_locks()
|
||||
|
||||
def do_release_write_lock(self):
|
||||
self._release_all_locks()
|
||||
|
||||
def _release_all_locks(self):
|
||||
filedescriptor = self._filedesc
|
||||
if filedescriptor is not None:
|
||||
fcntl.flock(filedescriptor, fcntl.LOCK_UN)
|
||||
os.close(filedescriptor)
|
||||
self._filedescriptor.remove()
|
||||
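# Usage sketch (illustrative paths): file_synchronizer() above selects this
# class when flock() is available and a lock_dir is supplied:
#
#   >>> sync = file_synchronizer('user-42', lock_dir='/tmp/locks')
#   >>> sync.acquire_write_lock()
#   >>> try:
#   ...     pass  # exclusive across all processes sharing /tmp/locks
#   ... finally:
#   ...     sync.release_write_lock()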
|
||||
|
||||
class ConditionSynchronizer(SynchronizerImpl):
|
||||
"""a synchronizer using a Condition."""
|
||||
|
||||
def __init__(self, identifier):
|
||||
super(ConditionSynchronizer, self).__init__()
|
||||
|
||||
# counts how many asynchronous methods are executing
|
||||
self.asynch = 0
|
||||
|
||||
# pointer to thread that is the current sync operation
|
||||
self.current_sync_operation = None
|
||||
|
||||
# condition object to lock on
|
||||
self.condition = _threading.Condition(_threading.Lock())
|
||||
|
||||
def do_acquire_read_lock(self, wait=True):
|
||||
self.condition.acquire()
|
||||
try:
|
||||
# see if a synchronous operation is waiting to start
|
||||
# or is already running, in which case we wait (or just
|
||||
# give up and return)
|
||||
if wait:
|
||||
while self.current_sync_operation is not None:
|
||||
self.condition.wait()
|
||||
else:
|
||||
if self.current_sync_operation is not None:
|
||||
return False
|
||||
|
||||
self.asynch += 1
|
||||
finally:
|
||||
self.condition.release()
|
||||
|
||||
if not wait:
|
||||
return True
|
||||
|
||||
def do_release_read_lock(self):
|
||||
self.condition.acquire()
|
||||
try:
|
||||
self.asynch -= 1
|
||||
|
||||
# check if we are the last asynchronous reader thread
|
||||
# out the door.
|
||||
if self.asynch == 0:
|
||||
# yes. so if a sync operation is waiting, notifyAll to wake
|
||||
# it up
|
||||
if self.current_sync_operation is not None:
|
||||
self.condition.notifyAll()
|
||||
elif self.asynch < 0:
|
||||
raise LockError("Synchronizer error - too many "
|
||||
"release_read_locks called")
|
||||
finally:
|
||||
self.condition.release()
|
||||
|
||||
def do_acquire_write_lock(self, wait=True):
|
||||
self.condition.acquire()
|
||||
try:
|
||||
# here, we are not a synchronous reader, and after returning,
|
||||
# assuming waiting or immediate availability, we will be.
|
||||
|
||||
if wait:
|
||||
# if another sync is working, wait
|
||||
while self.current_sync_operation is not None:
|
||||
self.condition.wait()
|
||||
else:
|
||||
# if another sync is working,
|
||||
# we don't want to wait, so forget it
|
||||
if self.current_sync_operation is not None:
|
||||
return False
|
||||
|
||||
# establish ourselves as the current sync
|
||||
# this indicates to other read/write operations
|
||||
# that they should wait until this is None again
|
||||
self.current_sync_operation = _threading.currentThread()
|
||||
|
||||
# now wait again for asyncs to finish
|
||||
if self.asynch > 0:
|
||||
if wait:
|
||||
# wait
|
||||
self.condition.wait()
|
||||
else:
|
||||
# we don't want to wait, so forget it
|
||||
self.current_sync_operation = None
|
||||
return False
|
||||
finally:
|
||||
self.condition.release()
|
||||
|
||||
if not wait:
|
||||
return True
|
||||
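# Protocol recap for the method above (descriptive comment): a writer first
# claims current_sync_operation, which blocks new readers, then waits for
# the reader count 'asynch' to drain to zero before proceeding; in no-wait
# mode it backs out at either step instead of blocking.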
|
||||
def do_release_write_lock(self):
|
||||
self.condition.acquire()
|
||||
try:
|
||||
if self.current_sync_operation is not _threading.currentThread():
|
||||
raise LockError("Synchronizer error - current thread doesn't "
|
||||
"have the write lock")
|
||||
|
||||
# reset the current sync operation so
|
||||
# another can get it
|
||||
self.current_sync_operation = None
|
||||
|
||||
# tell everyone to get ready
|
||||
self.condition.notifyAll()
|
||||
finally:
|
||||
# everyone go !!
|
||||
self.condition.release()
|
|
@@ -1,507 +0,0 @@
|
|||
"""Beaker utilities"""
|
||||
import hashlib
|
||||
import socket
|
||||
|
||||
import binascii
|
||||
|
||||
from ._compat import PY2, string_type, unicode_text, NoneType, dictkeyslist, im_class, im_func, pickle, func_signature, \
|
||||
default_im_func
|
||||
|
||||
try:
|
||||
import threading as _threading
|
||||
except ImportError:
|
||||
import dummy_threading as _threading
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
import types
|
||||
import weakref
|
||||
import warnings
|
||||
import sys
|
||||
import inspect
|
||||
import json
|
||||
import zlib
|
||||
|
||||
from beaker.converters import asbool
|
||||
from beaker import exceptions
|
||||
from threading import local as _tlocal
|
||||
|
||||
DEFAULT_CACHE_KEY_LENGTH = 250
|
||||
|
||||
__all__ = ["ThreadLocal", "WeakValuedRegistry", "SyncDict", "encoded_path",
|
||||
"verify_directory",
|
||||
"serialize", "deserialize"]
|
||||
|
||||
|
||||
def function_named(fn, name):
|
||||
"""Return a function with a given __name__.
|
||||
|
||||
Will assign to __name__ and return the original function if possible on
|
||||
the Python implementation, otherwise a new function will be constructed.
|
||||
|
||||
"""
|
||||
fn.__name__ = name
|
||||
return fn
|
||||
|
||||
|
||||
def skip_if(predicate, reason=None):
|
||||
"""Skip a test if predicate is true."""
|
||||
reason = reason or predicate.__name__
|
||||
|
||||
from nose import SkipTest
|
||||
|
||||
def decorate(fn):
|
||||
fn_name = fn.__name__
|
||||
|
||||
def maybe(*args, **kw):
|
||||
if predicate():
|
||||
msg = "'%s' skipped: %s" % (
|
||||
fn_name, reason)
|
||||
raise SkipTest(msg)
|
||||
else:
|
||||
return fn(*args, **kw)
|
||||
return function_named(maybe, fn_name)
|
||||
return decorate
|
||||
|
||||
|
||||
def assert_raises(except_cls, callable_, *args, **kw):
|
||||
"""Assert the given exception is raised by the given function + arguments."""
|
||||
|
||||
try:
|
||||
callable_(*args, **kw)
|
||||
success = False
|
||||
except except_cls:
|
||||
success = True
|
||||
|
||||
# assert outside the block so it works for AssertionError too !
|
||||
assert success, "Callable did not raise an exception"
|
||||
|
||||
|
||||
def verify_directory(dir):
|
||||
"""verifies and creates a directory. tries to
|
||||
ignore collisions with other threads and processes."""
|
||||
|
||||
tries = 0
|
||||
while not os.access(dir, os.F_OK):
|
||||
try:
|
||||
tries += 1
|
||||
os.makedirs(dir)
|
||||
except:
|
||||
if tries > 5:
|
||||
raise
|
||||
|
||||
|
||||
def has_self_arg(func):
|
||||
"""Return True if the given function has a 'self' argument."""
|
||||
args = list(func_signature(func).parameters)
|
||||
if args and args[0] in ('self', 'cls'):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def warn(msg, stacklevel=3):
|
||||
"""Issue a warning."""
|
||||
if isinstance(msg, string_type):
|
||||
warnings.warn(msg, exceptions.BeakerWarning, stacklevel=stacklevel)
|
||||
else:
|
||||
warnings.warn(msg, stacklevel=stacklevel)
|
||||
|
||||
|
||||
def deprecated(message):
|
||||
def wrapper(fn):
|
||||
def deprecated_method(*args, **kargs):
|
||||
warnings.warn(message, DeprecationWarning, 2)
|
||||
return fn(*args, **kargs)
|
||||
# TODO: use decorator ? functools.wrapper ?
|
||||
deprecated_method.__name__ = fn.__name__
|
||||
deprecated_method.__doc__ = "%s\n\n%s" % (message, fn.__doc__)
|
||||
return deprecated_method
|
||||
return wrapper
|
||||
|
||||
|
||||
class ThreadLocal(object):
|
||||
"""stores a value on a per-thread basis"""
|
||||
|
||||
__slots__ = '_tlocal'
|
||||
|
||||
def __init__(self):
|
||||
self._tlocal = _tlocal()
|
||||
|
||||
def put(self, value):
|
||||
self._tlocal.value = value
|
||||
|
||||
def has(self):
|
||||
return hasattr(self._tlocal, 'value')
|
||||
|
||||
def get(self, default=None):
|
||||
return getattr(self._tlocal, 'value', default)
|
||||
|
||||
def remove(self):
|
||||
del self._tlocal.value
|
||||
|
||||
|
||||
class SyncDict(object):
|
||||
"""
|
||||
An efficient/threadsafe singleton map algorithm, a.k.a.
|
||||
"get a value based on this key, and create if not found or not
|
||||
valid" paradigm:
|
||||
|
||||
exists && isvalid ? get : create
|
||||
|
||||
Designed to work with weakref dictionaries to expect items
|
||||
to asynchronously disappear from the dictionary.
|
||||
|
||||
Use Python 2.3.3 or greater! A major bug was fixed in Nov.
|
||||
2003 that was driving me nuts with garbage collection/weakrefs in
|
||||
this section.
|
||||
|
||||
"""
|
||||
def __init__(self):
|
||||
self.mutex = _threading.Lock()
|
||||
self.dict = {}
|
||||
|
||||
def get(self, key, createfunc, *args, **kwargs):
|
||||
try:
|
||||
if key in self.dict:
|
||||
return self.dict[key]
|
||||
else:
|
||||
return self.sync_get(key, createfunc, *args, **kwargs)
|
||||
except KeyError:
|
||||
return self.sync_get(key, createfunc, *args, **kwargs)
|
||||
|
||||
def sync_get(self, key, createfunc, *args, **kwargs):
|
||||
self.mutex.acquire()
|
||||
try:
|
||||
try:
|
||||
if key in self.dict:
|
||||
return self.dict[key]
|
||||
else:
|
||||
return self._create(key, createfunc, *args, **kwargs)
|
||||
except KeyError:
|
||||
return self._create(key, createfunc, *args, **kwargs)
|
||||
finally:
|
||||
self.mutex.release()
|
||||
|
||||
def _create(self, key, createfunc, *args, **kwargs):
|
||||
self[key] = obj = createfunc(*args, **kwargs)
|
||||
return obj
|
||||
|
||||
def has_key(self, key):
|
||||
return key in self.dict
|
||||
|
||||
def __contains__(self, key):
|
||||
return self.dict.__contains__(key)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.dict.__getitem__(key)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.dict.__setitem__(key, value)
|
||||
|
||||
def __delitem__(self, key):
|
||||
return self.dict.__delitem__(key)
|
||||
|
||||
def clear(self):
|
||||
self.dict.clear()
|
||||
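# Usage sketch of the create-if-missing contract in get() above
# ('make_connection' is a stand-in zero-argument factory):
#
#   >>> registry = SyncDict()
#   >>> conn = registry.get('db', make_connection)  # created once, under lock
#   >>> registry.get('db', make_connection) is conn
#   True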
|
||||
|
||||
class WeakValuedRegistry(SyncDict):
|
||||
def __init__(self):
|
||||
self.mutex = _threading.RLock()
|
||||
self.dict = weakref.WeakValueDictionary()
|
||||
|
||||
sha1 = None
|
||||
|
||||
|
||||
def encoded_path(root, identifiers, extension=".enc", depth=3,
|
||||
digest_filenames=True):
|
||||
|
||||
"""Generate a unique file-accessible path from the given list of
|
||||
identifiers starting at the given root directory."""
|
||||
ident = "_".join(identifiers)
|
||||
|
||||
global sha1
|
||||
if sha1 is None:
|
||||
from beaker.crypto import sha1
|
||||
|
||||
if digest_filenames:
|
||||
if isinstance(ident, unicode_text):
|
||||
ident = sha1(ident.encode('utf-8')).hexdigest()
|
||||
else:
|
||||
ident = sha1(ident).hexdigest()
|
||||
|
||||
ident = os.path.basename(ident)
|
||||
|
||||
tokens = []
|
||||
for d in range(1, depth):
|
||||
tokens.append(ident[0:d])
|
||||
|
||||
dir = os.path.join(root, *tokens)
|
||||
verify_directory(dir)
|
||||
|
||||
return os.path.join(dir, ident + extension)
|
||||
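# Worked example of the sharding above (illustrative digest, shortened):
# with identifiers=['mysession'], depth=3 and digest_filenames=True, ident
# becomes a hex digest such as 'a4d55a8d...' and tokens ['a', 'a4'], giving
#
#   <root>/a/a4/a4d55a8d....enc
#
# so entries spread across subdirectories instead of one flat directory.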
|
||||
|
||||
def asint(obj):
|
||||
if isinstance(obj, int):
|
||||
return obj
|
||||
elif isinstance(obj, string_type) and re.match(r'^\d+$', obj):
|
||||
return int(obj)
|
||||
else:
|
||||
raise Exception("This is not a proper int")
|
||||
|
||||
|
||||
def verify_options(opt, types, error):
|
||||
if not isinstance(opt, types):
|
||||
if not isinstance(types, tuple):
|
||||
types = (types,)
|
||||
coerced = False
|
||||
for typ in types:
|
||||
try:
|
||||
if typ in (list, tuple):
|
||||
opt = [x.strip() for x in opt.split(',')]
|
||||
else:
|
||||
if typ == bool:
|
||||
typ = asbool
|
||||
elif typ == int:
|
||||
typ = asint
|
||||
elif typ in (timedelta, datetime):
|
||||
if not isinstance(opt, typ):
|
||||
raise Exception("%s requires a timedelta type", typ)
|
||||
opt = typ(opt)
|
||||
coerced = True
|
||||
except:
|
||||
pass
|
||||
if coerced:
|
||||
break
|
||||
if not coerced:
|
||||
raise Exception(error)
|
||||
elif isinstance(opt, str) and not opt.strip():
|
||||
raise Exception("Empty strings are invalid for: %s" % error)
|
||||
return opt
|
||||
|
||||
|
||||
def verify_rules(params, ruleset):
|
||||
for key, types, message in ruleset:
|
||||
if key in params:
|
||||
params[key] = verify_options(params[key], types, message)
|
||||
return params
|
||||
|
||||
|
||||
def coerce_session_params(params):
|
||||
rules = [
|
||||
('data_dir', (str, NoneType), "data_dir must be a string referring to a directory."),
|
||||
('lock_dir', (str, NoneType), "lock_dir must be a string referring to a directory."),
|
||||
('type', (str, NoneType), "Session type must be a string."),
|
||||
('cookie_expires', (bool, datetime, timedelta, int),
|
||||
"Cookie expires was not a boolean, datetime, int, or timedelta instance."),
|
||||
('cookie_domain', (str, NoneType), "Cookie domain must be a string."),
|
||||
('cookie_path', (str, NoneType), "Cookie path must be a string."),
|
||||
('id', (str,), "Session id must be a string."),
|
||||
('key', (str,), "Session key must be a string."),
|
||||
('secret', (str, NoneType), "Session secret must be a string."),
|
||||
('validate_key', (str, NoneType), "Session validate_key must be a string."),
|
||||
('encrypt_key', (str, NoneType), "Session encrypt_key must be a string."),
|
||||
('encrypt_nonce_bits', (int, NoneType), "Session encrypt_nonce_bits must be a number"),
|
||||
('secure', (bool, NoneType), "Session secure must be a boolean."),
|
||||
('httponly', (bool, NoneType), "Session httponly must be a boolean."),
|
||||
('timeout', (int, NoneType), "Session timeout must be an integer."),
|
||||
('save_accessed_time', (bool, NoneType),
|
||||
"Session save_accessed_time must be a boolean (defaults to true)."),
|
||||
('auto', (bool, NoneType), "Session auto (create session if accessed) must be a boolean."),
|
||||
('webtest_varname', (str, NoneType), "Session varname must be a string."),
|
||||
('data_serializer', (str,), "data_serializer must be a string.")
|
||||
]
|
||||
opts = verify_rules(params, rules)
|
||||
cookie_expires = opts.get('cookie_expires')
|
||||
if cookie_expires and isinstance(cookie_expires, int) and \
|
||||
not isinstance(cookie_expires, bool):
|
||||
opts['cookie_expires'] = timedelta(seconds=cookie_expires)
|
||||
|
||||
if opts.get('timeout') is not None and not opts.get('save_accessed_time', True):
|
||||
raise Exception("save_accessed_time must be true to use timeout")
|
||||
|
||||
return opts
|
||||
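# Illustrative coercion (config values usually arrive as strings from an
# ini file; the keys are real rule names from the table above):
#
#   >>> coerce_session_params({'timeout': '600', 'secure': 'true',
#   ...                        'cookie_expires': '3600'})
#
# yields roughly {'timeout': 600, 'secure': True,
# 'cookie_expires': timedelta(seconds=3600)}.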
|
||||
|
||||
def coerce_cache_params(params):
|
||||
rules = [
|
||||
('data_dir', (str, NoneType), "data_dir must be a string referring to a directory."),
|
||||
('lock_dir', (str, NoneType), "lock_dir must be a string referring to a directory."),
|
||||
('type', (str,), "Cache type must be a string."),
|
||||
('enabled', (bool, NoneType), "enabled must be true/false if present."),
|
||||
('expire', (int, NoneType),
|
||||
"expire must be an integer representing how many seconds the cache is valid for"),
|
||||
('regions', (list, tuple, NoneType),
|
||||
"Regions must be a comma separated list of valid regions"),
|
||||
('key_length', (int, NoneType),
|
||||
"key_length must be an integer which indicates the longest a key can be before hashing"),
|
||||
]
|
||||
return verify_rules(params, rules)
|
||||
|
||||
|
||||
def coerce_memcached_behaviors(behaviors):
|
||||
rules = [
|
||||
('cas', (bool, int), 'cas must be a boolean or an integer'),
|
||||
('no_block', (bool, int), 'no_block must be a boolean or an integer'),
|
||||
('receive_timeout', (int,), 'receive_timeout must be an integer'),
|
||||
('send_timeout', (int,), 'send_timeout must be an integer'),
|
||||
('ketama_hash', (str,),
|
||||
'ketama_hash must be a string designating a valid hashing strategy option'),
|
||||
('_poll_timeout', (int,), '_poll_timeout must be an integer'),
|
||||
('auto_eject', (bool, int), 'auto_eject must be a boolean or an integer'),
|
||||
('retry_timeout', (int,), 'retry_timeout must be an integer'),
|
||||
('_sort_hosts', (bool, int), '_sort_hosts must be a boolean or an integer'),
|
||||
('_io_msg_watermark', (int,), '_io_msg_watermark must be an integer'),
|
||||
('ketama', (bool, int), 'ketama must be a boolean or an integer'),
|
||||
('ketama_weighted', (bool, int), 'ketama_weighted must be a boolean or an integer'),
|
||||
('_io_key_prefetch', (int, bool), '_io_key_prefetch must be a boolean or an integer'),
|
||||
('_hash_with_prefix_key', (bool, int),
|
||||
'_hash_with_prefix_key must be a boolean or an integer'),
|
||||
('tcp_nodelay', (bool, int), 'tcp_nodelay must be a boolean or an integer'),
|
||||
('failure_limit', (int,), 'failure_limit must be an integer'),
|
||||
('buffer_requests', (bool, int), 'buffer_requests must be a boolean or an integer'),
|
||||
('_socket_send_size', (int,), '_socket_send_size must be an integer'),
|
||||
('num_replicas', (int,), 'num_replicas must be an integer'),
|
||||
('remove_failed', (int,), 'remove_failed must be an integer'),
|
||||
('_noreply', (bool, int), '_noreply must be a boolean or an integer'),
|
||||
('_io_bytes_watermark', (int,), '_io_bytes_watermark must be an integer'),
|
||||
('_socket_recv_size', (int,), '_socket_recv_size must be an integer'),
|
||||
('distribution', (str,),
|
||||
'distribution must be a string designating a valid distribution option'),
|
||||
('connect_timeout', (int,), 'connect_timeout must be an integer'),
|
||||
('hash', (str,), 'hash must be a string designating a valid hashing option'),
|
||||
('verify_keys', (bool, int), 'verify_keys must be a boolean or an integer'),
|
||||
('dead_timeout', (int,), 'dead_timeout must be an integer')
|
||||
]
|
||||
return verify_rules(behaviors, rules)
|
||||
|
||||
|
||||
def parse_cache_config_options(config, include_defaults=True):
|
||||
"""Parse configuration options and validate for use with the
|
||||
CacheManager"""
|
||||
|
||||
# Load default cache options
|
||||
if include_defaults:
|
||||
options = dict(type='memory', data_dir=None, expire=None,
|
||||
log_file=None)
|
||||
else:
|
||||
options = {}
|
||||
for key, val in config.items():
|
||||
if key.startswith('beaker.cache.'):
|
||||
options[key[13:]] = val
|
||||
if key.startswith('cache.'):
|
||||
options[key[6:]] = val
|
||||
coerce_cache_params(options)
|
||||
|
||||
# Set cache to enabled if not turned off
|
||||
if 'enabled' not in options and include_defaults:
|
||||
options['enabled'] = True
|
||||
|
||||
# Configure region dict if regions are available
|
||||
regions = options.pop('regions', None)
|
||||
if regions:
|
||||
region_configs = {}
|
||||
for region in regions:
|
||||
if not region: # ensure region name is valid
|
||||
continue
|
||||
# Setup the default cache options
|
||||
region_options = dict(data_dir=options.get('data_dir'),
|
||||
lock_dir=options.get('lock_dir'),
|
||||
type=options.get('type'),
|
||||
enabled=options['enabled'],
|
||||
expire=options.get('expire'),
|
||||
key_length=options.get('key_length', DEFAULT_CACHE_KEY_LENGTH))
|
||||
region_prefix = '%s.' % region
|
||||
region_len = len(region_prefix)
|
||||
for key in dictkeyslist(options):
|
||||
if key.startswith(region_prefix):
|
||||
region_options[key[region_len:]] = options.pop(key)
|
||||
coerce_cache_params(region_options)
|
||||
region_configs[region] = region_options
|
||||
options['cache_regions'] = region_configs
|
||||
return options
|
||||
|
||||
|
||||
def parse_memcached_behaviors(config):
|
||||
"""Parse behavior options and validate for use with pylibmc
|
||||
client/PylibMCNamespaceManager, or potentially other memcached
|
||||
NamespaceManagers that support behaviors"""
|
||||
behaviors = {}
|
||||
|
||||
for key, val in config.items():
|
||||
if key.startswith('behavior.'):
|
||||
behaviors[key[9:]] = val
|
||||
|
||||
coerce_memcached_behaviors(behaviors)
|
||||
return behaviors
|
||||
|
||||
|
||||
def func_namespace(func):
|
||||
"""Generates a unique namespace for a function"""
|
||||
kls = None
|
||||
if hasattr(func, 'im_func') or hasattr(func, '__func__'):
|
||||
kls = im_class(func)
|
||||
func = im_func(func)
|
||||
|
||||
if kls:
|
||||
return '%s.%s' % (kls.__module__, kls.__name__)
|
||||
else:
|
||||
return '%s|%s' % (inspect.getsourcefile(func), func.__name__)
|
||||
|
||||
|
||||
class PickleSerializer(object):
|
||||
def loads(self, data_string):
|
||||
return pickle.loads(data_string)
|
||||
|
||||
def dumps(self, data):
|
||||
return pickle.dumps(data, 2)
|
||||
|
||||
|
||||
class JsonSerializer(object):
|
||||
def loads(self, data_string):
|
||||
return json.loads(zlib.decompress(data_string).decode('utf-8'))
|
||||
|
||||
def dumps(self, data):
|
||||
return zlib.compress(json.dumps(data).encode('utf-8'))
|
||||
|
||||
|
||||
def serialize(data, method):
|
||||
if method == 'json':
|
||||
serializer = JsonSerializer()
|
||||
else:
|
||||
serializer = PickleSerializer()
|
||||
return serializer.dumps(data)
|
||||
|
||||
|
||||
def deserialize(data_string, method):
|
||||
if method == 'json':
|
||||
serializer = JsonSerializer()
|
||||
else:
|
||||
serializer = PickleSerializer()
|
||||
return serializer.loads(data_string)
|
||||
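# Round-trip sketch: 'json' payloads are zlib-compressed UTF-8 JSON and
# 'pickle' payloads are protocol-2 pickles, per the classes above:
#
#   >>> blob = serialize({'a': 1}, 'json')
#   >>> deserialize(blob, 'json')
#   {'a': 1}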
|
||||
|
||||
def machine_identifier():
|
||||
machine_hash = hashlib.md5()
|
||||
if not PY2:
|
||||
machine_hash.update(socket.gethostname().encode())
|
||||
else:
|
||||
machine_hash.update(socket.gethostname())
|
||||
return binascii.hexlify(machine_hash.digest()[0:3]).decode('ascii')
|
||||
|
||||
|
||||
def safe_write(filepath, contents):
|
||||
if os.name == 'posix':
|
||||
tempname = '%s.temp' % (filepath)
|
||||
fh = open(tempname, 'wb')
|
||||
fh.write(contents)
|
||||
fh.close()
|
||||
os.rename(tempname, filepath)
|
||||
else:
|
||||
fh = open(filepath, 'wb')
|
||||
fh.write(contents)
|
||||
fh.close()
|
|
@@ -1,5 +1,5 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
|
||||
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
|
||||
#
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@@ -26,8 +26,9 @@
|
|||
#==============================================================================
|
||||
|
||||
|
||||
"""
|
||||
Efficient, Pythonic bidirectional map implementation and related functionality.
|
||||
"""The bidirectional mapping library for Python.
|
||||
|
||||
bidict by example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
|
@@ -44,66 +45,45 @@ https://bidict.readthedocs.io for the most up-to-date documentation
|
|||
if you are reading this elsewhere.
|
||||
|
||||
|
||||
.. :copyright: (c) 2019 Joshua Bronson.
|
||||
.. :copyright: (c) 2009-2021 Joshua Bronson.
|
||||
.. :license: MPLv2. See LICENSE for details.
|
||||
"""
|
||||
|
||||
# This __init__.py only collects functionality implemented in the rest of the
|
||||
# source and exports it under the `bidict` module namespace (via `__all__`).
|
||||
# Use private aliases to not re-export these publicly (for Sphinx automodule with imported-members).
|
||||
from sys import version_info as _version_info
|
||||
|
||||
from ._abc import BidirectionalMapping
|
||||
|
||||
if _version_info < (3, 6): # pragma: no cover
|
||||
raise ImportError('Python 3.6+ is required.')
|
||||
|
||||
from ._abc import BidirectionalMapping, MutableBidirectionalMapping
|
||||
from ._base import BidictBase
|
||||
from ._mut import MutableBidict
|
||||
from ._bidict import bidict
|
||||
from ._dup import DuplicationPolicy, IGNORE, OVERWRITE, RAISE
|
||||
from ._exc import (
|
||||
BidictException, DuplicationError,
|
||||
KeyDuplicationError, ValueDuplicationError, KeyAndValueDuplicationError)
|
||||
from ._util import inverted
|
||||
from ._frozenbidict import frozenbidict
|
||||
from ._frozenordered import FrozenOrderedBidict
|
||||
from ._named import namedbidict
|
||||
from ._orderedbase import OrderedBidictBase
|
||||
from ._orderedbidict import OrderedBidict
|
||||
from ._dup import ON_DUP_DEFAULT, ON_DUP_RAISE, ON_DUP_DROP_OLD, RAISE, DROP_OLD, DROP_NEW, OnDup, OnDupAction
|
||||
from ._exc import BidictException, DuplicationError, KeyDuplicationError, ValueDuplicationError, KeyAndValueDuplicationError
|
||||
from ._iter import inverted
|
||||
from .metadata import (
|
||||
__author__, __maintainer__, __copyright__, __email__, __credits__, __url__,
|
||||
__license__, __status__, __description__, __keywords__, __version__, __version_info__)
|
||||
|
||||
|
||||
__all__ = (
|
||||
'__author__',
|
||||
'__maintainer__',
|
||||
'__copyright__',
|
||||
'__email__',
|
||||
'__credits__',
|
||||
'__license__',
|
||||
'__status__',
|
||||
'__description__',
|
||||
'__keywords__',
|
||||
'__url__',
|
||||
'__version__',
|
||||
'__version_info__',
|
||||
'BidirectionalMapping',
|
||||
'BidictException',
|
||||
'DuplicationPolicy',
|
||||
'IGNORE',
|
||||
'OVERWRITE',
|
||||
'RAISE',
|
||||
'DuplicationError',
|
||||
'KeyDuplicationError',
|
||||
'ValueDuplicationError',
|
||||
'KeyAndValueDuplicationError',
|
||||
'BidictBase',
|
||||
'MutableBidict',
|
||||
'frozenbidict',
|
||||
'bidict',
|
||||
'namedbidict',
|
||||
'FrozenOrderedBidict',
|
||||
'OrderedBidictBase',
|
||||
'OrderedBidict',
|
||||
'inverted',
|
||||
__license__, __status__, __description__, __keywords__, __version__,
|
||||
)
|
||||
|
||||
# Set __module__ of re-exported classes to the 'bidict' top-level module name
|
||||
# so that private/internal submodules are not exposed to users e.g. in repr strings.
|
||||
_locals = tuple(locals().items())
|
||||
for _name, _obj in _locals: # pragma: no cover
|
||||
if not getattr(_obj, '__module__', '').startswith('bidict.'):
|
||||
continue
|
||||
try:
|
||||
_obj.__module__ = 'bidict'
|
||||
except AttributeError: # raised when __module__ is read-only (as in OnDup)
|
||||
pass
|
||||
|
||||
|
||||
# * Code review nav *
|
||||
#==============================================================================
|
||||
|
|
|
@@ -1,5 +1,5 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
|
||||
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
|
||||
#
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@@ -26,12 +26,15 @@
|
|||
#==============================================================================
|
||||
|
||||
|
||||
"""Provides the :class:`BidirectionalMapping` abstract base class."""
|
||||
"""Provide the :class:`BidirectionalMapping` abstract base class."""
|
||||
|
||||
from .compat import Mapping, abstractproperty, iteritems
|
||||
import typing as _t
|
||||
from abc import abstractmethod
|
||||
|
||||
from ._typing import KT, VT
|
||||
|
||||
|
||||
class BidirectionalMapping(Mapping): # pylint: disable=abstract-method,no-init
|
||||
class BidirectionalMapping(_t.Mapping[KT, VT]):
|
||||
"""Abstract base class (ABC) for bidirectional mapping types.
|
||||
|
||||
Extends :class:`collections.abc.Mapping` primarily by adding the
|
||||
|
@@ -43,8 +46,9 @@ class BidirectionalMapping(Mapping): # pylint: disable=abstract-method,no-init
|
|||
|
||||
__slots__ = ()
|
||||
|
||||
@abstractproperty
|
||||
def inverse(self):
|
||||
@property
|
||||
@abstractmethod
|
||||
def inverse(self) -> 'BidirectionalMapping[VT, KT]':
|
||||
"""The inverse of this bidirectional mapping instance.
|
||||
|
||||
*See also* :attr:`bidict.BidictBase.inverse`, :attr:`bidict.BidictBase.inv`
|
||||
|
@@ -58,7 +62,7 @@ class BidirectionalMapping(Mapping): # pylint: disable=abstract-method,no-init
|
|||
# clear there's no reason to call this implementation (e.g. via super() after overriding).
|
||||
raise NotImplementedError
|
||||
|
||||
def __inverted__(self):
|
||||
def __inverted__(self) -> _t.Iterator[_t.Tuple[VT, KT]]:
|
||||
"""Get an iterator over the items in :attr:`inverse`.
|
||||
|
||||
This is functionally equivalent to iterating over the items in the
|
||||
|
@ -72,7 +76,27 @@ class BidirectionalMapping(Mapping): # pylint: disable=abstract-method,no-init
|
|||
|
||||
*See also* :func:`bidict.inverted`
|
||||
"""
|
||||
return iteritems(self.inverse)
|
||||
return iter(self.inverse.items())
|
||||
|
||||
def values(self) -> _t.KeysView[VT]: # type: ignore [override] # https://github.com/python/typeshed/issues/4435
|
||||
"""A set-like object providing a view on the contained values.
|
||||
|
||||
Override the implementation inherited from
|
||||
:class:`~collections.abc.Mapping`.
|
||||
Because the values of a :class:`~bidict.BidirectionalMapping`
|
||||
are the keys of its inverse,
|
||||
this returns a :class:`~collections.abc.KeysView`
|
||||
rather than a :class:`~collections.abc.ValuesView`,
|
||||
which has the advantages of constant-time containment checks
|
||||
and supporting set operations.
|
||||
"""
|
||||
return self.inverse.keys() # type: ignore [return-value]
|
||||
|
||||
|
||||
class MutableBidirectionalMapping(BidirectionalMapping[KT, VT], _t.MutableMapping[KT, VT]):
|
||||
"""Abstract base class (ABC) for mutable bidirectional mapping types."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
# * Code review nav *
|
||||
|
|
|
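The set-like `values()` override above is the main behavioral addition in this hunk. A minimal sketch of what it enables, using only the public API of the upgraded bidict as vendored here (the `element_by_symbol` name is illustrative, and item order assumes a modern CPython where dicts preserve insertion order):

    from bidict import bidict

    element_by_symbol = bidict(H='hydrogen', He='helium')

    # values() is backed by the inverse mapping's keys(), so containment
    # checks are O(1) and set operations work:
    assert 'hydrogen' in element_by_symbol.values()
    assert element_by_symbol.values() & {'helium', 'neon'} == {'helium'}

    # __inverted__ yields the inverse items (value, key), as used by bidict.inverted():
    assert list(element_by_symbol.__inverted__()) == [('hydrogen', 'H'), ('helium', 'He')]
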
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this

@@ -22,139 +22,118 @@

# * Code review nav *
#==============================================================================
# ← Prev: _abc.py Current: _base.py Next: _delegating_mixins.py →
# ← Prev: _abc.py Current: _base.py Next: _frozenbidict.py →
#==============================================================================


"""Provides :class:`BidictBase`."""
"""Provide :class:`BidictBase`."""

import typing as _t
from collections import namedtuple
from copy import copy
from weakref import ref

from ._abc import BidirectionalMapping
from ._dup import RAISE, OVERWRITE, IGNORE, _OnDup
from ._exc import (
    DuplicationError, KeyDuplicationError, ValueDuplicationError, KeyAndValueDuplicationError)
from ._miss import _MISS
from ._noop import _NOOP
from ._util import _iteritems_args_kw
from .compat import PY2, KeysView, ItemsView, Mapping, iteritems
from ._dup import ON_DUP_DEFAULT, RAISE, DROP_OLD, DROP_NEW, OnDup
from ._exc import DuplicationError, KeyDuplicationError, ValueDuplicationError, KeyAndValueDuplicationError
from ._iter import _iteritems_args_kw
from ._typing import _NONE, KT, VT, OKT, OVT, IterItems, MapOrIterItems


_DedupResult = namedtuple('_DedupResult', 'isdupkey isdupval invbyval fwdbykey')
_WriteResult = namedtuple('_WriteResult', 'key val oldkey oldval')
_NODUP = _DedupResult(False, False, _MISS, _MISS)
_DedupResult = namedtuple('_DedupResult', 'isdupkey isdupval invbyval fwdbykey')
_NODUP = _DedupResult(False, False, _NONE, _NONE)

BT = _t.TypeVar('BT', bound='BidictBase')  # typevar for BidictBase.copy


class BidictBase(BidirectionalMapping):
class BidictBase(BidirectionalMapping[KT, VT]):
    """Base class implementing :class:`BidirectionalMapping`."""

    __slots__ = ('_fwdm', '_invm', '_inv', '_invweak', '_hash') + (() if PY2 else ('__weakref__',))
    __slots__ = ['_fwdm', '_invm', '_inv', '_invweak', '__weakref__']

    #: The default :class:`DuplicationPolicy`
    #: (in effect during e.g. :meth:`~bidict.bidict.__init__` calls)
    #: The default :class:`~bidict.OnDup`
    #: that governs behavior when a provided item
    #: duplicates only the key of another item.
    #:
    #: Defaults to :attr:`~bidict.OVERWRITE`
    #: to match :class:`dict`'s behavior.
    #: duplicates the key or value of other item(s).
    #:
    #: *See also* :ref:`basic-usage:Values Must Be Unique`, :doc:`extending`
    on_dup_key = OVERWRITE
    on_dup = ON_DUP_DEFAULT

    #: The default :class:`DuplicationPolicy`
    #: (in effect during e.g. :meth:`~bidict.bidict.__init__` calls)
    #: that governs behavior when a provided item
    #: duplicates only the value of another item.
    #:
    #: Defaults to :attr:`~bidict.RAISE`
    #: to prevent unintended overwrite of another item.
    #:
    #: *See also* :ref:`basic-usage:Values Must Be Unique`, :doc:`extending`
    on_dup_val = RAISE

    #: The default :class:`DuplicationPolicy`
    #: (in effect during e.g. :meth:`~bidict.bidict.__init__` calls)
    #: that governs behavior when a provided item
    #: duplicates the key of another item and the value of a third item.
    #:
    #: Defaults to ``None``, which causes the *on_dup_kv* policy to match
    #: whatever *on_dup_val* policy is in effect.
    #:
    #: *See also* :ref:`basic-usage:Values Must Be Unique`, :doc:`extending`
    on_dup_kv = None

    _fwdm_cls = dict
    _invm_cls = dict
    _fwdm_cls: _t.Type[_t.MutableMapping[KT, VT]] = dict  #: class of the backing forward mapping
    _invm_cls: _t.Type[_t.MutableMapping[VT, KT]] = dict  #: class of the backing inverse mapping

    #: The object used by :meth:`__repr__` for printing the contained items.
    _repr_delegate = dict
    _repr_delegate: _t.Callable = dict

    def __init__(self, *args, **kw):  # pylint: disable=super-init-not-called
    _inv: 'BidictBase[VT, KT]'
    _inv_cls: '_t.Type[BidictBase[VT, KT]]'

    def __init_subclass__(cls, **kw):
        super().__init_subclass__(**kw)
        # Compute and set _inv_cls, the inverse of this bidict class.
        if '_inv_cls' in cls.__dict__:
            return
        if cls._fwdm_cls is cls._invm_cls:
            cls._inv_cls = cls
            return
        inv_cls = type(cls.__name__ + 'Inv', cls.__bases__, {
            **cls.__dict__,
            '_inv_cls': cls,
            '_fwdm_cls': cls._invm_cls,
            '_invm_cls': cls._fwdm_cls,
        })
        cls._inv_cls = inv_cls

    @_t.overload
    def __init__(self, __arg: _t.Mapping[KT, VT], **kw: VT) -> None: ...
    @_t.overload
    def __init__(self, __arg: IterItems[KT, VT], **kw: VT) -> None: ...
    @_t.overload
    def __init__(self, **kw: VT) -> None: ...
    def __init__(self, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
        """Make a new bidirectional dictionary.
        The signature is the same as that of regular dictionaries.
        The signature behaves like that of :class:`dict`.
        Items passed in are added in the order they are passed,
        respecting the current duplication policies in the process.

        *See also* :attr:`on_dup_key`, :attr:`on_dup_val`, :attr:`on_dup_kv`
        respecting the :attr:`on_dup` class attribute in the process.
        """
        #: The backing :class:`~collections.abc.Mapping`
        #: storing the forward mapping data (*key* → *value*).
        self._fwdm = self._fwdm_cls()
        self._fwdm: _t.MutableMapping[KT, VT] = self._fwdm_cls()
        #: The backing :class:`~collections.abc.Mapping`
        #: storing the inverse mapping data (*value* → *key*).
        self._invm = self._invm_cls()
        self._init_inv()  # lgtm [py/init-calls-subclass]
        self._invm: _t.MutableMapping[VT, KT] = self._invm_cls()
        self._init_inv()
        if args or kw:
            self._update(True, None, *args, **kw)
            self._update(True, self.on_dup, *args, **kw)

    def _init_inv(self):
        # Compute the type for this bidict's inverse bidict (will be different from this
        # bidict's type if _fwdm_cls and _invm_cls are different).
        inv_cls = self._inv_cls()
    def _init_inv(self) -> None:
        # Create the inverse bidict instance via __new__, bypassing its __init__ so that its
        # _fwdm and _invm can be assigned to this bidict's _invm and _fwdm. Store it in self._inv,
        # which holds a strong reference to a bidict's inverse, if one is available.
        self._inv = inv = inv_cls.__new__(inv_cls)
        inv._fwdm = self._invm  # pylint: disable=protected-access
        inv._invm = self._fwdm  # pylint: disable=protected-access
        self._inv = inv = self._inv_cls.__new__(self._inv_cls)
        inv._fwdm = self._invm
        inv._invm = self._fwdm
        # Only give the inverse a weak reference to this bidict to avoid creating a reference cycle,
        # stored in the _invweak attribute. See also the docs in
        # :ref:`addendum:Bidict Avoids Reference Cycles`
        inv._inv = None  # pylint: disable=protected-access
        inv._invweak = ref(self)  # pylint: disable=protected-access
        inv._inv = None
        inv._invweak = ref(self)
        # Since this bidict has a strong reference to its inverse already, set its _invweak to None.
        self._invweak = None

    @classmethod
    def _inv_cls(cls):
        """The inverse of this bidict type, i.e. one with *_fwdm_cls* and *_invm_cls* swapped."""
        if cls._fwdm_cls is cls._invm_cls:
            return cls
        if not getattr(cls, '_inv_cls_', None):
            class _Inv(cls):
                _fwdm_cls = cls._invm_cls
                _invm_cls = cls._fwdm_cls
                _inv_cls_ = cls
            _Inv.__name__ = cls.__name__ + 'Inv'
            cls._inv_cls_ = _Inv
        return cls._inv_cls_

    @property
    def _isinv(self):
    def _isinv(self) -> bool:
        return self._inv is None

    @property
    def inverse(self):
        """The inverse of this bidict.

        *See also* :attr:`inv`
        """
    def inverse(self) -> 'BidictBase[VT, KT]':
        """The inverse of this bidict."""
        # Resolve and return a strong reference to the inverse bidict.
        # One may be stored in self._inv already.
        if self._inv is not None:
            return self._inv
        # Otherwise a weakref is stored in self._invweak. Try to get a strong ref from it.
        assert self._invweak is not None
        inv = self._invweak()
        if inv is not None:
            return inv
@@ -162,12 +141,10 @@ class BidictBase(BidirectionalMapping):
        self._init_inv()  # Now this bidict will retain a strong ref to its inverse.
        return self._inv

    @property
    def inv(self):
        """Alias for :attr:`inverse`."""
        return self.inverse
    #: Alias for :attr:`inverse`.
    inv = inverse

    def __getstate__(self):
    def __getstate__(self) -> dict:
        """Needed to enable pickling due to use of :attr:`__slots__` and weakrefs.

        *See also* :meth:`object.__getstate__`

@@ -183,27 +160,27 @@ class BidictBase(BidirectionalMapping):
        state.pop('__weakref__', None)  # Not added back in __setstate__. Python manages this one.
        return state

    def __setstate__(self, state):
    def __setstate__(self, state: dict) -> None:
        """Implemented because use of :attr:`__slots__` would prevent unpickling otherwise.

        *See also* :meth:`object.__setstate__`
        """
        for slot, value in iteritems(state):
        for slot, value in state.items():
            setattr(self, slot, value)
        self._init_inv()

    def __repr__(self):
    def __repr__(self) -> str:
        """See :func:`repr`."""
        clsname = self.__class__.__name__
        if not self:
            return '%s()' % clsname
        return '%s(%r)' % (clsname, self._repr_delegate(iteritems(self)))
        return f'{clsname}()'
        return f'{clsname}({self._repr_delegate(self.items())})'

    # The inherited Mapping.__eq__ implementation would work, but it's implemented in terms of an
    # inefficient ``dict(self.items()) == dict(other.items())`` comparison, so override it with a
    # more efficient implementation.
    def __eq__(self, other):
        u"""*x.__eq__(other) ⟺ x == other*
    def __eq__(self, other: object) -> bool:
        """*x.__eq__(other) ⟺ x == other*

        Equivalent to *dict(x.items()) == dict(other.items())*
        but more efficient.
@@ -216,101 +193,98 @@ class BidictBase(BidirectionalMapping):

        *See also* :meth:`bidict.FrozenOrderedBidict.equals_order_sensitive`
        """
        if not isinstance(other, Mapping) or len(self) != len(other):
        if not isinstance(other, _t.Mapping) or len(self) != len(other):
            return False
        selfget = self.get
        return all(selfget(k, _MISS) == v for (k, v) in iteritems(other))
        return all(selfget(k, _NONE) == v for (k, v) in other.items())  # type: ignore [arg-type]

    def equals_order_sensitive(self, other: object) -> bool:
        """Order-sensitive equality check.

        *See also* :ref:`eq-order-insensitive`
        """
        # Same short-circuit as in __eq__ above. Factoring out not worth function call overhead.
        if not isinstance(other, _t.Mapping) or len(self) != len(other):
            return False
        return all(i == j for (i, j) in zip(self.items(), other.items()))

    # The following methods are mutating and so are not public. But they are implemented in this
    # non-mutable base class (rather than the mutable `bidict` subclass) because they are used here
    # during initialization (starting with the `_update` method). (Why is this? Because `__init__`
    # and `update` share a lot of the same behavior (inserting the provided items while respecting
    # the active duplication policies), so it makes sense for them to share implementation too.)
    def _pop(self, key):
    # `on_dup`), so it makes sense for them to share implementation too.)
    def _pop(self, key: KT) -> VT:
        val = self._fwdm.pop(key)
        del self._invm[val]
        return val

    def _put(self, key, val, on_dup):
    def _put(self, key: KT, val: VT, on_dup: OnDup) -> None:
        dedup_result = self._dedup_item(key, val, on_dup)
        if dedup_result is not _NOOP:
        if dedup_result is not None:
            self._write_item(key, val, dedup_result)

    def _dedup_item(self, key, val, on_dup):
        """
        Check *key* and *val* for any duplication in self.
    def _dedup_item(self, key: KT, val: VT, on_dup: OnDup) -> _t.Optional[_DedupResult]:
        """Check *key* and *val* for any duplication in self.

        Handle any duplication as per the duplication policies given in *on_dup*.
        Handle any duplication as per the passed in *on_dup*.

        (key, val) already present is construed as a no-op, not a duplication.

        If duplication is found and the corresponding duplication policy is
        If duplication is found and the corresponding :class:`~bidict.OnDupAction` is
        :attr:`~bidict.DROP_NEW`, return None.

        If duplication is found and the corresponding :class:`~bidict.OnDupAction` is
        :attr:`~bidict.RAISE`, raise the appropriate error.

        If duplication is found and the corresponding duplication policy is
        :attr:`~bidict.IGNORE`, return *None*.

        If duplication is found and the corresponding duplication policy is
        :attr:`~bidict.OVERWRITE`,
        If duplication is found and the corresponding :class:`~bidict.OnDupAction` is
        :attr:`~bidict.DROP_OLD`,
        or if no duplication is found,
        return the _DedupResult *(isdupkey, isdupval, oldkey, oldval)*.
        return the :class:`_DedupResult` *(isdupkey, isdupval, oldkey, oldval)*.
        """
        fwdm = self._fwdm
        invm = self._invm
        oldval = fwdm.get(key, _MISS)
        oldkey = invm.get(val, _MISS)
        isdupkey = oldval is not _MISS
        isdupval = oldkey is not _MISS
        oldval: OVT = fwdm.get(key, _NONE)
        oldkey: OKT = invm.get(val, _NONE)
        isdupkey = oldval is not _NONE
        isdupval = oldkey is not _NONE
        dedup_result = _DedupResult(isdupkey, isdupval, oldkey, oldval)
        if isdupkey and isdupval:
            if self._isdupitem(key, val, dedup_result):
            if self._already_have(key, val, oldkey, oldval):
                # (key, val) duplicates an existing item -> no-op.
                return _NOOP
                return None
            # key and val each duplicate a different existing item.
            if on_dup.kv is RAISE:
                raise KeyAndValueDuplicationError(key, val)
            elif on_dup.kv is IGNORE:
                return _NOOP
            assert on_dup.kv is OVERWRITE, 'invalid on_dup_kv: %r' % on_dup.kv
            if on_dup.kv is DROP_NEW:
                return None
            assert on_dup.kv is DROP_OLD
            # Fall through to the return statement on the last line.
        elif isdupkey:
            if on_dup.key is RAISE:
                raise KeyDuplicationError(key)
            elif on_dup.key is IGNORE:
                return _NOOP
            assert on_dup.key is OVERWRITE, 'invalid on_dup.key: %r' % on_dup.key
            if on_dup.key is DROP_NEW:
                return None
            assert on_dup.key is DROP_OLD
            # Fall through to the return statement on the last line.
        elif isdupval:
            if on_dup.val is RAISE:
                raise ValueDuplicationError(val)
            elif on_dup.val is IGNORE:
                return _NOOP
            assert on_dup.val is OVERWRITE, 'invalid on_dup.val: %r' % on_dup.val
            if on_dup.val is DROP_NEW:
                return None
            assert on_dup.val is DROP_OLD
            # Fall through to the return statement on the last line.
        # else neither isdupkey nor isdupval.
        return dedup_result

    @staticmethod
    def _isdupitem(key, val, dedup_result):
        isdupkey, isdupval, oldkey, oldval = dedup_result
        isdupitem = oldkey == key
        assert isdupitem == (oldval == val), '%r %r %r' % (key, val, dedup_result)
        if isdupitem:
            assert isdupkey
            assert isdupval
        return isdupitem
    def _already_have(key: KT, val: VT, oldkey: OKT, oldval: OVT) -> bool:
        # Overridden by _orderedbase.OrderedBidictBase.
        isdup = oldkey == key
        assert isdup == (oldval == val), f'{key} {val} {oldkey} {oldval}'
        return isdup

    @classmethod
    def _get_on_dup(cls, on_dup=None):
        if on_dup is None:
            on_dup = _OnDup(cls.on_dup_key, cls.on_dup_val, cls.on_dup_kv)
        elif not isinstance(on_dup, _OnDup):
            on_dup = _OnDup(*on_dup)
        if on_dup.kv is None:
            on_dup = on_dup._replace(kv=on_dup.val)
        return on_dup

    def _write_item(self, key, val, dedup_result):
    def _write_item(self, key: KT, val: VT, dedup_result: _DedupResult) -> _WriteResult:
        # Overridden by _orderedbase.OrderedBidictBase.
        isdupkey, isdupval, oldkey, oldval = dedup_result
        fwdm = self._fwdm
        invm = self._invm
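The new `equals_order_sensitive` method above complements the order-insensitive `__eq__`. A quick sketch of the difference, assuming the upgraded bidict (as vendored here) is importable:

    from bidict import OrderedBidict

    a = OrderedBidict([(1, 'one'), (2, 'two')])
    b = OrderedBidict([(2, 'two'), (1, 'one')])

    # __eq__ compares contents only, like dict equality:
    assert a == b
    # equals_order_sensitive also requires the same item order:
    assert not a.equals_order_sensitive(b)
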
@@ -322,35 +296,34 @@ class BidictBase(BidirectionalMapping):
            del fwdm[oldkey]
        return _WriteResult(key, val, oldkey, oldval)

    def _update(self, init, on_dup, *args, **kw):
    def _update(self, init: bool, on_dup: OnDup, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
        # args[0] may be a generator that yields many items, so process input in a single pass.
        if not args and not kw:
            return
        can_skip_dup_check = not self and not kw and isinstance(args[0], BidirectionalMapping)
        if can_skip_dup_check:
            self._update_no_dup_check(args[0])
            self._update_no_dup_check(args[0])  # type: ignore [arg-type]
            return
        on_dup = self._get_on_dup(on_dup)
        can_skip_rollback = init or RAISE not in on_dup
        if can_skip_rollback:
            self._update_no_rollback(on_dup, *args, **kw)
        else:
            self._update_with_rollback(on_dup, *args, **kw)

    def _update_no_dup_check(self, other, _nodup=_NODUP):
    def _update_no_dup_check(self, other: BidirectionalMapping[KT, VT]) -> None:
        write_item = self._write_item
        for (key, val) in iteritems(other):
            write_item(key, val, _nodup)
        for (key, val) in other.items():
            write_item(key, val, _NODUP)

    def _update_no_rollback(self, on_dup, *args, **kw):
    def _update_no_rollback(self, on_dup: OnDup, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
        put = self._put
        for (key, val) in _iteritems_args_kw(*args, **kw):
            put(key, val, on_dup)

    def _update_with_rollback(self, on_dup, *args, **kw):
    def _update_with_rollback(self, on_dup: OnDup, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
        """Update, rolling back on failure."""
        writelog = []
        appendlog = writelog.append
        writes: _t.List[_t.Tuple[_DedupResult, _WriteResult]] = []
        append_write = writes.append
        dedup_item = self._dedup_item
        write_item = self._write_item
        for (key, val) in _iteritems_args_kw(*args, **kw):

@@ -358,14 +331,14 @@ class BidictBase(BidirectionalMapping):
                dedup_result = dedup_item(key, val, on_dup)
            except DuplicationError:
                undo_write = self._undo_write
                for dedup_result, write_result in reversed(writelog):
                for dedup_result, write_result in reversed(writes):
                    undo_write(dedup_result, write_result)
                raise
            if dedup_result is not _NOOP:
            if dedup_result is not None:
                write_result = write_item(key, val, dedup_result)
                appendlog((dedup_result, write_result))
                append_write((dedup_result, write_result))

    def _undo_write(self, dedup_result, write_result):
    def _undo_write(self, dedup_result: _DedupResult, write_result: _WriteResult) -> None:
        isdupkey, isdupval, _, _ = dedup_result
        key, val, oldkey, oldval = write_result
        if not isdupkey and not isdupval:
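`_update_with_rollback` is what gives the bulk operations their fails-clean behavior: writes are logged and undone in reverse if a later item raises. An illustrative sketch via the public `putall` method, which takes this path when a RAISE action is in effect (a sketch under that assumption, not part of the diff):

    from bidict import bidict, ValueDuplicationError

    b = bidict(one=1)
    try:
        # The second item's value duplicates an existing one, so putall raises
        # and the first, already-written item ('two', 2) is rolled back:
        b.putall({'two': 2, 'uno': 1})
    except ValueDuplicationError:
        pass
    assert b == bidict(one=1)  # unchanged: none of the items was inserted
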
@@ -384,79 +357,48 @@ class BidictBase(BidirectionalMapping):
            if not isdupkey:
                del fwdm[key]

    def copy(self):
    def copy(self: BT) -> BT:
        """A shallow copy."""
        # Could just ``return self.__class__(self)`` here instead, but the below is faster. It uses
        # __new__ to create a copy instance while bypassing its __init__, which would result
        # in copying this bidict's items into the copy instance one at a time. Instead, make whole
        # copies of each of the backing mappings, and make them the backing mappings of the copy,
        # avoiding copying items one at a time.
        copy = self.__class__.__new__(self.__class__)
        copy._fwdm = self._fwdm.copy()  # pylint: disable=protected-access
        copy._invm = self._invm.copy()  # pylint: disable=protected-access
        copy._init_inv()  # pylint: disable=protected-access
        return copy
        cp: BT = self.__class__.__new__(self.__class__)
        cp._fwdm = copy(self._fwdm)
        cp._invm = copy(self._invm)
        cp._init_inv()
        return cp

    def __copy__(self):
        """Used for the copy protocol.
    #: Used for the copy protocol.
    #: *See also* the :mod:`copy` module
    __copy__ = copy

        *See also* the :mod:`copy` module
        """
        return self.copy()

    def __len__(self):
    def __len__(self) -> int:
        """The number of contained items."""
        return len(self._fwdm)

    def __iter__(self):  # lgtm [py/inheritance/incorrect-overridden-signature]
        """Iterator over the contained items."""
        # No default implementation for __iter__ inherited from Mapping ->
        # always delegate to _fwdm.
    def __iter__(self) -> _t.Iterator[KT]:
        """Iterator over the contained keys."""
        return iter(self._fwdm)

    def __getitem__(self, key):
        u"""*x.__getitem__(key) ⟺ x[key]*"""
    def __getitem__(self, key: KT) -> VT:
        """*x.__getitem__(key) ⟺ x[key]*"""
        return self._fwdm[key]

    def values(self):
        """A set-like object providing a view on the contained values.
    # On Python 3.8+, dicts are reversible, so even non-Ordered bidicts can provide an efficient
    # __reversed__ implementation. (On Python < 3.8, they cannot.) Once support is dropped for
    # Python < 3.8, can remove the following if statement to provide __reversed__ unconditionally.
    if hasattr(_fwdm_cls, '__reversed__'):
        def __reversed__(self) -> _t.Iterator[KT]:
            """Iterator over the contained keys in reverse order."""
            return reversed(self._fwdm)  # type: ignore [no-any-return,call-overload]

        Note that because the values of a :class:`~bidict.BidirectionalMapping`
        are the keys of its inverse,
        this returns a :class:`~collections.abc.KeysView`
        rather than a :class:`~collections.abc.ValuesView`,
        which has the advantages of constant-time containment checks
        and supporting set operations.
        """
        return self.inverse.keys()

    if PY2:
        # For iterkeys and iteritems, inheriting from Mapping already provides
        # the best default implementations so no need to define here.

        def itervalues(self):
            """An iterator over the contained values."""
            return self.inverse.iterkeys()

        def viewkeys(self):  # noqa: D102; pylint: disable=missing-docstring
            return KeysView(self)

        def viewvalues(self):  # noqa: D102; pylint: disable=missing-docstring
            return self.inverse.viewkeys()

        viewvalues.__doc__ = values.__doc__
        values.__doc__ = 'A list of the contained values.'

        def viewitems(self):  # noqa: D102; pylint: disable=missing-docstring
            return ItemsView(self)

        # __ne__ added automatically in Python 3 when you implement __eq__, but not in Python 2.
        def __ne__(self, other):  # noqa: N802
            u"""*x.__ne__(other) ⟺ x != other*"""
            return not self == other  # Implement __ne__ in terms of __eq__.


# Work around weakref slot with Generics bug on Python 3.6 (https://bugs.python.org/issue41451):
BidictBase.__slots__.remove('__weakref__')

# * Code review nav *
#==============================================================================
# ← Prev: _abc.py Current: _base.py Next: _delegating_mixins.py →
# ← Prev: _abc.py Current: _base.py Next: _frozenbidict.py →
#==============================================================================
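Behavior-wise, the reworked `copy` above is still an ordinary shallow copy that just clones the two backing mappings wholesale instead of re-inserting items one at a time. A minimal sketch (names illustrative):

    from copy import copy

    from bidict import bidict

    orig = bidict(H='hydrogen')
    dupe = copy(orig)  # dispatches to __copy__, which is now an alias for copy()

    dupe['He'] = 'helium'
    assert 'He' not in orig                  # the copy is independent of the original...
    assert dupe.inverse['hydrogen'] == 'H'   # ...and has its own inverse, kept in sync
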
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this

@@ -26,18 +26,23 @@
#==============================================================================


"""Provides :class:`bidict`."""
"""Provide :class:`bidict`."""

import typing as _t

from ._delegating import _DelegatingBidict
from ._mut import MutableBidict
from ._delegating_mixins import _DelegateKeysAndItemsToFwdm
from ._typing import KT, VT


class bidict(_DelegateKeysAndItemsToFwdm, MutableBidict):  # noqa: N801,E501; pylint: disable=invalid-name
class bidict(_DelegatingBidict[KT, VT], MutableBidict[KT, VT]):
    """Base class for mutable bidirectional mappings."""

    __slots__ = ()

    __hash__ = None  # since this class is mutable; explicit > implicit.
    if _t.TYPE_CHECKING:
        @property
        def inverse(self) -> 'bidict[VT, KT]': ...


# * Code review nav *
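For reference, the public `bidict` class assembled above behaves as before; a minimal usage sketch:

    from bidict import bidict

    b = bidict(one=1)
    assert b['one'] == 1
    assert b.inverse[1] == 'one'

    b['two'] = 2        # mutations are mirrored in the inverse...
    del b.inverse[1]    # ...and inverse mutations are mirrored back
    assert b == bidict(two=2)
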
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.


"""Provide :class:`_DelegatingBidict`."""

import typing as _t

from ._base import BidictBase
from ._typing import KT, VT


class _DelegatingBidict(BidictBase[KT, VT]):
    """Provide optimized implementations of several methods by delegating to backing dicts.

    Used to override less efficient implementations inherited by :class:`~collections.abc.Mapping`.
    """

    __slots__ = ()

    def __iter__(self) -> _t.Iterator[KT]:
        """Iterator over the contained keys."""
        return iter(self._fwdm)

    def keys(self) -> _t.KeysView[KT]:
        """A set-like object providing a view on the contained keys."""
        return self._fwdm.keys()  # type: ignore [return-value]

    def values(self) -> _t.KeysView[VT]:  # type: ignore [override]  # https://github.com/python/typeshed/issues/4435
        """A set-like object providing a view on the contained values."""
        return self._invm.keys()  # type: ignore [return-value]

    def items(self) -> _t.ItemsView[KT, VT]:
        """A set-like object providing a view on the contained items."""
        return self._fwdm.items()  # type: ignore [return-value]
@@ -1,92 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.


#==============================================================================
# * Welcome to the bidict source code *
#==============================================================================

# Doing a code review? You'll find a "Code review nav" comment like the one
# below at the top and bottom of the most important source files. This provides
# a suggested initial path through the source when reviewing.
#
# Note: If you aren't reading this on https://github.com/jab/bidict, you may be
# viewing an outdated version of the code. Please head to GitHub to review the
# latest version, which contains important improvements over older versions.
#
# Thank you for reading and for any feedback you provide.

# * Code review nav *
#==============================================================================
# ← Prev: _base.py Current: _delegating_mixins.py Next: _frozenbidict.py →
#==============================================================================


r"""Provides mixin classes that delegate to ``self._fwdm`` for various operations.

This allows methods such as :meth:`bidict.bidict.items`
to be implemented in terms of a ``self._fwdm.items()`` call,
which is potentially much more efficient (e.g. in CPython 2)
compared to the implementation inherited from :class:`~collections.abc.Mapping`
(which returns ``[(key, self[key]) for key in self]`` in Python 2).

Because this depends on implementation details that aren't necessarily true
(such as the bidict's values being the same as its ``self._fwdm.values()``,
which is not true for e.g. ordered bidicts where ``_fwdm``\'s values are nodes),
these should always be mixed in at a layer below a more general layer,
as they are in e.g. :class:`~bidict.frozenbidict`
which extends :class:`~bidict.BidictBase`.

See the :ref:`extending:Sorted Bidict Recipes`
for another example of where this comes into play.
``SortedBidict`` extends :class:`bidict.MutableBidict`
rather than :class:`bidict.bidict`
to avoid inheriting these mixins,
which are incompatible with the backing
:class:`sortedcontainers.SortedDict`s.
"""

from .compat import PY2


_KEYS_METHODS = ('keys',) + (('viewkeys', 'iterkeys') if PY2 else ())
_ITEMS_METHODS = ('items',) + (('viewitems', 'iteritems') if PY2 else ())
_DOCSTRING_BY_METHOD = {
    'keys': 'A set-like object providing a view on the contained keys.',
    'items': 'A set-like object providing a view on the contained items.',
}
if PY2:
    _DOCSTRING_BY_METHOD['viewkeys'] = _DOCSTRING_BY_METHOD['keys']
    _DOCSTRING_BY_METHOD['viewitems'] = _DOCSTRING_BY_METHOD['items']
    _DOCSTRING_BY_METHOD['keys'] = 'A list of the contained keys.'
    _DOCSTRING_BY_METHOD['items'] = 'A list of the contained items.'


def _make_method(methodname):
    def method(self):
        return getattr(self._fwdm, methodname)()  # pylint: disable=protected-access
    method.__name__ = methodname
    method.__doc__ = _DOCSTRING_BY_METHOD.get(methodname, '')
    return method


def _make_fwdm_delegating_mixin(clsname, methodnames):
    clsdict = dict({name: _make_method(name) for name in methodnames}, __slots__=())
    return type(clsname, (object,), clsdict)


_DelegateKeysToFwdm = _make_fwdm_delegating_mixin('_DelegateKeysToFwdm', _KEYS_METHODS)
_DelegateItemsToFwdm = _make_fwdm_delegating_mixin('_DelegateItemsToFwdm', _ITEMS_METHODS)
_DelegateKeysAndItemsToFwdm = type(
    '_DelegateKeysAndItemsToFwdm',
    (_DelegateKeysToFwdm, _DelegateItemsToFwdm),
    {'__slots__': ()})

# * Code review nav *
#==============================================================================
# ← Prev: _base.py Current: _delegating_mixins.py Next: _frozenbidict.py →
#==============================================================================
@@ -1,36 +1,58 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.


"""Provides bidict duplication policies and the :class:`_OnDup` class."""
"""Provide :class:`OnDup` and related functionality."""


from collections import namedtuple

from ._marker import _Marker
from enum import Enum


_OnDup = namedtuple('_OnDup', 'key val kv')
class OnDupAction(Enum):
    """An action to take to prevent duplication from occurring."""

    #: Raise a :class:`~bidict.DuplicationError`.
    RAISE = 'RAISE'
    #: Overwrite existing items with new items.
    DROP_OLD = 'DROP_OLD'
    #: Keep existing items and drop new items.
    DROP_NEW = 'DROP_NEW'

    def __repr__(self) -> str:
        return f'<{self.name}>'


class DuplicationPolicy(_Marker):
    """Base class for bidict's duplication policies.
RAISE = OnDupAction.RAISE
DROP_OLD = OnDupAction.DROP_OLD
DROP_NEW = OnDupAction.DROP_NEW


class OnDup(namedtuple('_OnDup', 'key val kv')):
    r"""A 3-tuple of :class:`OnDupAction`\s specifying how to handle the 3 kinds of duplication.

    *See also* :ref:`basic-usage:Values Must Be Unique`

    If *kv* is not specified, *val* will be used for *kv*.
    """

    __slots__ = ()

    def __new__(cls, key: OnDupAction = DROP_OLD, val: OnDupAction = RAISE, kv: OnDupAction = RAISE) -> 'OnDup':
        """Override to provide user-friendly default values."""
        return super().__new__(cls, key, val, kv or val)


#: Raise an exception when a duplication is encountered.
RAISE = DuplicationPolicy('DUP_POLICY.RAISE')

#: Overwrite an existing item when a duplication is encountered.
OVERWRITE = DuplicationPolicy('DUP_POLICY.OVERWRITE')

#: Keep the existing item and ignore the new item when a duplication is encountered.
IGNORE = DuplicationPolicy('DUP_POLICY.IGNORE')
#: Default :class:`OnDup` used for the
#: :meth:`~bidict.bidict.__init__`,
#: :meth:`~bidict.bidict.__setitem__`, and
#: :meth:`~bidict.bidict.update` methods.
ON_DUP_DEFAULT = OnDup()
#: An :class:`OnDup` whose members are all :obj:`RAISE`.
ON_DUP_RAISE = OnDup(key=RAISE, val=RAISE, kv=RAISE)
#: An :class:`OnDup` whose members are all :obj:`DROP_OLD`.
ON_DUP_DROP_OLD = OnDup(key=DROP_OLD, val=DROP_OLD, kv=DROP_OLD)
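The `OnDup`/`OnDupAction` API above replaces the old trio of per-kind duplication-policy arguments. A sketch of how the new names are used (the `PermissiveBidict` subclass is purely illustrative):

    from bidict import bidict, OnDup, DROP_OLD, RAISE

    b = bidict(one=1)

    # Per-call: put() now takes a single OnDup triple instead of three keyword args.
    b.put('uno', 1, OnDup(key=RAISE, val=DROP_OLD, kv=RAISE))
    assert b == bidict(uno=1)  # the old item ('one', 1) was dropped

    # Per-class default: subclasses can override the on_dup class attribute.
    class PermissiveBidict(bidict):
        on_dup = OnDup(key=DROP_OLD, val=DROP_OLD, kv=DROP_OLD)
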
@@ -1,12 +1,12 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.


"""Provides all bidict exceptions."""
"""Provide all bidict exceptions."""


class BidictException(Exception):

@@ -15,7 +15,7 @@ class BidictException(Exception):

class DuplicationError(BidictException):
    """Base class for exceptions raised when uniqueness is violated
    as per the RAISE duplication policy.
    as per the :attr:`~bidict.RAISE` :class:`~bidict.OnDupAction`.
    """
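Apart from the docstring, the exception hierarchy is unchanged: `KeyDuplicationError`, `ValueDuplicationError`, and `KeyAndValueDuplicationError` all subclass `DuplicationError`, so one handler catches all three. A quick sketch:

    from bidict import bidict, DuplicationError

    b = bidict(one=1, two=2)
    try:
        b.put('one', 2)  # duplicates the key of one item and the value of another
    except DuplicationError as exc:
        print(type(exc).__name__)  # KeyAndValueDuplicationError
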
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this

@@ -22,30 +22,39 @@

# * Code review nav *
#==============================================================================
# ← Prev: _delegating_mixins.py Current: _frozenbidict.py Next: _mut.py →
# ← Prev: _base.py Current: _frozenbidict.py Next: _mut.py →
#==============================================================================

"""Provides :class:`frozenbidict`, an immutable, hashable bidirectional mapping type."""
"""Provide :class:`frozenbidict`, an immutable, hashable bidirectional mapping type."""

from ._base import BidictBase
from ._delegating_mixins import _DelegateKeysAndItemsToFwdm
from .compat import ItemsView
import typing as _t

from ._delegating import _DelegatingBidict
from ._typing import KT, VT


class frozenbidict(_DelegateKeysAndItemsToFwdm, BidictBase):  # noqa: N801,E501; pylint: disable=invalid-name
class frozenbidict(_DelegatingBidict[KT, VT]):
    """Immutable, hashable bidict type."""

    __slots__ = ()
    __slots__ = ('_hash',)

    def __hash__(self):  # lgtm [py/equals-hash-mismatch]
    _hash: int

    # Work around lack of support for higher-kinded types in mypy.
    # Ref: https://github.com/python/typing/issues/548#issuecomment-621571821
    # Remove this and similar type stubs from other classes if support is ever added.
    if _t.TYPE_CHECKING:
        @property
        def inverse(self) -> 'frozenbidict[VT, KT]': ...

    def __hash__(self) -> int:
        """The hash of this bidict as determined by its items."""
        if getattr(self, '_hash', None) is None:
            # pylint: disable=protected-access,attribute-defined-outside-init
            self._hash = ItemsView(self)._hash()
            self._hash = _t.ItemsView(self)._hash()  # type: ignore [attr-defined]
        return self._hash


# * Code review nav *
#==============================================================================
# ← Prev: _delegating_mixins.py Current: _frozenbidict.py Next: _mut.py →
# ← Prev: _base.py Current: _frozenbidict.py Next: _mut.py →
#==============================================================================
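Since `frozenbidict` is hashable (with the hash computed from the items on first use and then cached in the new `_hash` slot), it can go anywhere a hashable mapping is needed; for example as a dict key:

    from bidict import frozenbidict

    f = frozenbidict(H='hydrogen')

    # Equal frozenbidicts hash equal, so lookups by an equal instance work:
    cache = {f: 'some derived result'}
    assert cache[frozenbidict(H='hydrogen')] == 'some derived result'
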
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this

@@ -25,38 +25,61 @@
#← Prev: _orderedbase.py Current: _frozenordered.py Next: _orderedbidict.py →
#==============================================================================

"""Provides :class:`FrozenOrderedBidict`, an immutable, hashable, ordered bidict."""
"""Provide :class:`FrozenOrderedBidict`, an immutable, hashable, ordered bidict."""

import typing as _t

from ._delegating_mixins import _DelegateKeysToFwdm
from ._frozenbidict import frozenbidict
from ._orderedbase import OrderedBidictBase
from .compat import DICTS_ORDERED, PY2, izip
from ._typing import KT, VT


# If the Python implementation's dict type is ordered (e.g. PyPy or CPython >= 3.6), then
# `FrozenOrderedBidict` can delegate to `_fwdm` for keys: Both `_fwdm` and `_invm` will always
# be initialized with the provided items in the correct order, and since `FrozenOrderedBidict`
# is immutable, their respective orders can't get out of sync after a mutation. (Can't delegate
# to `_fwdm` for items though because values in `_fwdm` are nodes.)
_BASES = ((_DelegateKeysToFwdm,) if DICTS_ORDERED else ()) + (OrderedBidictBase,)
_CLSDICT = dict(
    __slots__=(),
    # Must set __hash__ explicitly, Python prevents inheriting it.
    # frozenbidict.__hash__ can be reused for FrozenOrderedBidict:
    # FrozenOrderedBidict inherits BidictBase.__eq__ which is order-insensitive,
    # and frozenbidict.__hash__ is consistent with BidictBase.__eq__.
    __hash__=frozenbidict.__hash__.__func__ if PY2 else frozenbidict.__hash__,
    __doc__='Hashable, immutable, ordered bidict type.',
    __module__=__name__,  # Otherwise unpickling fails in Python 2.
)
class FrozenOrderedBidict(OrderedBidictBase[KT, VT]):
    """Hashable, immutable, ordered bidict type.

# When PY2 (so we provide iteritems) and DICTS_ORDERED, e.g. on PyPy, the following implementation
# of iteritems may be more efficient than that inherited from `Mapping`. This exploits the property
# that the keys in `_fwdm` and `_invm` are already in the right order:
if PY2 and DICTS_ORDERED:
    _CLSDICT['iteritems'] = lambda self: izip(self._fwdm, self._invm)  # noqa: E501; pylint: disable=protected-access
    Like a hashable :class:`bidict.OrderedBidict`
    without the mutating APIs, or like a
    reversible :class:`bidict.frozenbidict` even on Python < 3.8.
    (All bidicts are order-preserving when never mutated, so frozenbidict is
    already order-preserving, but only on Python 3.8+, where dicts are
    reversible, are all bidicts (including frozenbidict) also reversible.)

FrozenOrderedBidict = type('FrozenOrderedBidict', _BASES, _CLSDICT)  # pylint: disable=invalid-name
    If you are using Python 3.8+, frozenbidict gives you everything that
    FrozenOrderedBidict gives you, but with less space overhead.
    """

    __slots__ = ('_hash',)
    __hash__ = frozenbidict.__hash__

    if _t.TYPE_CHECKING:
        @property
        def inverse(self) -> 'FrozenOrderedBidict[VT, KT]': ...

    # Delegate to backing dicts for more efficient implementations of keys() and values().
    # Possible with FrozenOrderedBidict but not OrderedBidict since FrozenOrderedBidict
    # is immutable, i.e. these can't get out of sync after initialization due to mutation.
    def keys(self) -> _t.KeysView[KT]:
        """A set-like object providing a view on the contained keys."""
        return self._fwdm._fwdm.keys()  # type: ignore [return-value]

    def values(self) -> _t.KeysView[VT]:  # type: ignore [override]
        """A set-like object providing a view on the contained values."""
        return self._invm._fwdm.keys()  # type: ignore [return-value]

    # Can't delegate for items() because values in _fwdm and _invm are nodes.

    # On Python 3.8+, delegate to backing dicts for a more efficient implementation
    # of __iter__ and __reversed__ (both of which call this _iter() method):
    if hasattr(dict, '__reversed__'):
        def _iter(self, *, reverse: bool = False) -> _t.Iterator[KT]:
            itfn = reversed if reverse else iter
            return itfn(self._fwdm._fwdm)  # type: ignore [operator,no-any-return]
    else:
        # On Python < 3.8, just optimize __iter__:
        def _iter(self, *, reverse: bool = False) -> _t.Iterator[KT]:
            if not reverse:
                return iter(self._fwdm._fwdm)
            return super()._iter(reverse=True)


# * Code review nav *
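Per the docstring above, `FrozenOrderedBidict` is essentially a hashable `OrderedBidict` without the mutating APIs. A quick sketch of the properties it guarantees on any supported interpreter:

    from bidict import FrozenOrderedBidict

    fob = FrozenOrderedBidict([('H', 'hydrogen'), ('He', 'helium')])

    assert list(fob) == ['H', 'He']            # insertion order preserved
    assert list(reversed(fob)) == ['He', 'H']  # reversible even on Python < 3.8
    assert fob in {fob}                        # hashable, reusing frozenbidict.__hash__
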
@@ -1,50 +1,56 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.


"""Useful functions for working with bidirectional mappings and related data."""
"""Functions for iterating over items in a mapping."""

from itertools import chain, repeat
import typing as _t
from collections.abc import Mapping
from itertools import chain

from .compat import iteritems, Mapping
from ._typing import KT, VT, IterItems, MapOrIterItems


_NULL_IT = repeat(None, 0)  # repeat 0 times -> raise StopIteration from the start
_NULL_IT: IterItems = iter(())


def _iteritems_mapping_or_iterable(arg):
def _iteritems_mapping_or_iterable(arg: MapOrIterItems[KT, VT]) -> IterItems[KT, VT]:
    """Yield the items in *arg*.

    If *arg* is a :class:`~collections.abc.Mapping`, return an iterator over its items.
    Otherwise return an iterator over *arg* itself.
    """
    return iteritems(arg) if isinstance(arg, Mapping) else iter(arg)
    return iter(arg.items() if isinstance(arg, Mapping) else arg)


def _iteritems_args_kw(*args, **kw):
def _iteritems_args_kw(*args: MapOrIterItems[KT, VT], **kw: VT) -> IterItems[KT, VT]:
    """Yield the items from the positional argument (if given) and then any from *kw*.

    :raises TypeError: if more than one positional argument is given.
    """
    args_len = len(args)
    if args_len > 1:
        raise TypeError('Expected at most 1 positional argument, got %d' % args_len)
    itemchain = None
        raise TypeError(f'Expected at most 1 positional argument, got {args_len}')
    it: IterItems = ()
    if args:
        arg = args[0]
        if arg:
            itemchain = _iteritems_mapping_or_iterable(arg)
            it = _iteritems_mapping_or_iterable(arg)
    if kw:
        iterkw = iteritems(kw)
        itemchain = chain(itemchain, iterkw) if itemchain else iterkw
    return itemchain or _NULL_IT
        iterkw = iter(kw.items())
        it = chain(it, iterkw) if it else iterkw
    return it or _NULL_IT


def inverted(arg):
@_t.overload
def inverted(arg: _t.Mapping[KT, VT]) -> IterItems[VT, KT]: ...
@_t.overload
def inverted(arg: IterItems[KT, VT]) -> IterItems[VT, KT]: ...
def inverted(arg: MapOrIterItems[KT, VT]) -> IterItems[VT, KT]:
    """Yield the inverse items of the provided object.

    If *arg* has a :func:`callable` ``__inverted__`` attribute,

@@ -57,5 +63,5 @@ def inverted(arg):
    """
    inv = getattr(arg, '__inverted__', None)
    if callable(inv):
        return inv()
        return inv()  # type: ignore [no-any-return]
    return ((val, key) for (key, val) in _iteritems_mapping_or_iterable(arg))
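`inverted` works on anything that yields items, preferring the object's own `__inverted__` shortcut when it has one. A quick sketch:

    from bidict import bidict, inverted

    # For a plain mapping or iterable of pairs, the items are swapped lazily:
    assert dict(inverted({'one': 1})) == {1: 'one'}
    assert list(inverted([(1, 'one'), (2, 'two')])) == [('one', 1), ('two', 2)]

    # For a bidict, the callable __inverted__ defined in _abc.py is used instead:
    assert dict(inverted(bidict(one=1))) == {1: 'one'}
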
@@ -1,19 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.


"""Provides :class:`_Marker`, an internal type for representing singletons."""

from collections import namedtuple


class _Marker(namedtuple('_Marker', 'name')):

    __slots__ = ()

    def __repr__(self):
        return '<%s>' % self.name  # pragma: no cover
@@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.


"""Provides the :obj:`_MISS` sentinel, for internally signaling "missing/not found"."""

from ._marker import _Marker


_MISS = _Marker('MISSING')
@ -1,5 +1,5 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
|
||||
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
|
||||
#
|
||||
# This Source Code Form is subject to the terms of the Mozilla Public
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
@ -26,32 +26,31 @@
|
|||
#==============================================================================
|
||||
|
||||
|
||||
"""Provides :class:`bidict`."""
|
||||
"""Provide :class:`MutableBidict`."""
|
||||
|
||||
import typing as _t
|
||||
|
||||
from ._abc import MutableBidirectionalMapping
|
||||
from ._base import BidictBase
|
||||
from ._dup import OVERWRITE, RAISE, _OnDup
|
||||
from ._miss import _MISS
|
||||
from .compat import MutableMapping
|
||||
from ._dup import OnDup, ON_DUP_RAISE, ON_DUP_DROP_OLD
|
||||
from ._typing import _NONE, KT, VT, VDT, IterItems, MapOrIterItems
|
||||
|
||||
|
||||
# Extend MutableMapping explicitly because it doesn't implement __subclasshook__, as well as to
|
||||
# inherit method implementations it provides that we can reuse (namely `setdefault`).
|
||||
class MutableBidict(BidictBase, MutableMapping):
|
||||
class MutableBidict(BidictBase[KT, VT], MutableBidirectionalMapping[KT, VT]):
|
||||
"""Base class for mutable bidirectional mappings."""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
__hash__ = None # since this class is mutable; explicit > implicit.
|
||||
if _t.TYPE_CHECKING:
|
||||
@property
|
||||
def inverse(self) -> 'MutableBidict[VT, KT]': ...
|
||||
|
||||
_ON_DUP_OVERWRITE = _OnDup(key=OVERWRITE, val=OVERWRITE, kv=OVERWRITE)
|
||||
|
||||
def __delitem__(self, key):
|
||||
u"""*x.__delitem__(y) ⟺ del x[y]*"""
|
||||
def __delitem__(self, key: KT) -> None:
|
||||
"""*x.__delitem__(y) ⟺ del x[y]*"""
|
||||
self._pop(key)
|
||||
|
||||
def __setitem__(self, key, val):
|
||||
"""
|
||||
Set the value for *key* to *val*.
|
||||
def __setitem__(self, key: KT, val: VT) -> None:
|
||||
"""Set the value for *key* to *val*.
|
||||
|
||||
If *key* is already associated with *val*, this is a no-op.
|
||||
|
||||
|
@ -64,7 +63,7 @@ class MutableBidict(BidictBase, MutableMapping):
|
|||
to protect against accidental removal of the key
|
||||
that's currently associated with *val*.
|
||||
|
||||
Use :meth:`put` instead if you want to specify different policy in
|
||||
Use :meth:`put` instead if you want to specify different behavior in
|
||||
the case that the provided key or value duplicates an existing one.
|
||||
Or use :meth:`forceput` to unconditionally associate *key* with *val*,
|
||||
replacing any existing items as necessary to preserve uniqueness.
|
||||
|
@ -76,16 +75,12 @@ class MutableBidict(BidictBase, MutableMapping):
|
|||
existing item and *val* duplicates the value of a different
|
||||
existing item.
|
||||
"""
|
||||
on_dup = self._get_on_dup()
|
||||
self._put(key, val, on_dup)
|
||||
self._put(key, val, self.on_dup)
|
||||
|
||||
def put(self, key, val, on_dup_key=RAISE, on_dup_val=RAISE, on_dup_kv=None):
|
||||
"""
|
||||
Associate *key* with *val* with the specified duplication policies.
|
||||
def put(self, key: KT, val: VT, on_dup: OnDup = ON_DUP_RAISE) -> None:
|
||||
"""Associate *key* with *val*, honoring the :class:`OnDup` given in *on_dup*.
|
||||
|
||||
If *on_dup_kv* is ``None``, the *on_dup_val* policy will be used for it.
|
||||
|
||||
For example, if all given duplication policies are :attr:`~bidict.RAISE`,
|
||||
For example, if *on_dup* is :attr:`~bidict.ON_DUP_RAISE`,
|
||||
then *key* will be associated with *val* if and only if
|
||||
*key* is not already associated with an existing value and
|
||||
*val* is not already associated with an existing key,
|
||||
|
@ -94,37 +89,39 @@ class MutableBidict(BidictBase, MutableMapping):
|
|||
If *key* is already associated with *val*, this is a no-op.
|
||||
|
||||
:raises bidict.KeyDuplicationError: if attempting to insert an item
|
||||
whose key only duplicates an existing item's, and *on_dup_key* is
|
||||
whose key only duplicates an existing item's, and *on_dup.key* is
|
||||
:attr:`~bidict.RAISE`.
|
||||
|
||||
:raises bidict.ValueDuplicationError: if attempting to insert an item
|
||||
whose value only duplicates an existing item's, and *on_dup_val* is
|
||||
whose value only duplicates an existing item's, and *on_dup.val* is
|
||||
:attr:`~bidict.RAISE`.
|
||||
|
||||
:raises bidict.KeyAndValueDuplicationError: if attempting to insert an
|
||||
item whose key duplicates one existing item's, and whose value
|
||||
duplicates another existing item's, and *on_dup_kv* is
|
||||
duplicates another existing item's, and *on_dup.kv* is
|
||||
:attr:`~bidict.RAISE`.
|
||||
"""
|
||||
on_dup = self._get_on_dup((on_dup_key, on_dup_val, on_dup_kv))
|
||||
self._put(key, val, on_dup)
|
||||
|
||||
def forceput(self, key, val):
|
||||
"""
|
||||
Associate *key* with *val* unconditionally.
|
||||
def forceput(self, key: KT, val: VT) -> None:
|
||||
"""Associate *key* with *val* unconditionally.
|
||||
|
||||
Replace any existing mappings containing key *key* or value *val*
|
||||
as necessary to preserve uniqueness.
|
||||
"""
|
||||
self._put(key, val, self._ON_DUP_OVERWRITE)
|
||||
self._put(key, val, ON_DUP_DROP_OLD)
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
"""Remove all items."""
|
||||
self._fwdm.clear()
|
||||
self._invm.clear()
|
||||
|
||||
def pop(self, key, default=_MISS):
|
||||
u"""*x.pop(k[, d]) → v*
|
||||
@_t.overload
|
||||
def pop(self, key: KT) -> VT: ...
|
||||
@_t.overload
|
||||
def pop(self, key: KT, default: VDT = ...) -> VDT: ...
|
||||
def pop(self, key: KT, default: VDT = _NONE) -> VDT:
|
||||
"""*x.pop(k[, d]) → v*
|
||||
|
||||
Remove specified key and return the corresponding value.
|
||||
|
||||
|
@ -133,12 +130,12 @@ class MutableBidict(BidictBase, MutableMapping):
|
|||
try:
|
||||
return self._pop(key)
|
||||
except KeyError:
|
||||
if default is _MISS:
|
||||
if default is _NONE:
|
||||
raise
|
||||
return default
|
||||
|
||||
def popitem(self):
|
||||
u"""*x.popitem() → (k, v)*
|
||||
def popitem(self) -> _t.Tuple[KT, VT]:
|
||||
"""*x.popitem() → (k, v)*
|
||||
|
||||
Remove and return some item as a (key, value) pair.
|
||||
|
||||
|
@ -150,24 +147,38 @@ class MutableBidict(BidictBase, MutableMapping):
|
|||
del self._invm[val]
|
||||
return key, val
|
||||
|
||||
def update(self, *args, **kw):
|
||||
"""Like :meth:`putall` with default duplication policies."""
|
||||
@_t.overload
|
||||
def update(self, __arg: _t.Mapping[KT, VT], **kw: VT) -> None: ...
|
||||
@_t.overload
|
||||
def update(self, __arg: IterItems[KT, VT], **kw: VT) -> None: ...
|
||||
@_t.overload
|
||||
def update(self, **kw: VT) -> None: ...
|
||||
def update(self, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
|
||||
"""Like calling :meth:`putall` with *self.on_dup* passed for *on_dup*."""
|
||||
if args or kw:
|
||||
self._update(False, None, *args, **kw)
|
||||
self._update(False, self.on_dup, *args, **kw)
|
||||
|
||||
def forceupdate(self, *args, **kw):
|
||||
@_t.overload
|
||||
def forceupdate(self, __arg: _t.Mapping[KT, VT], **kw: VT) -> None: ...
|
||||
@_t.overload
|
||||
def forceupdate(self, __arg: IterItems[KT, VT], **kw: VT) -> None: ...
|
||||
@_t.overload
|
||||
def forceupdate(self, **kw: VT) -> None: ...
|
||||
def forceupdate(self, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
|
||||
"""Like a bulk :meth:`forceput`."""
|
||||
self._update(False, self._ON_DUP_OVERWRITE, *args, **kw)
|
||||
self._update(False, ON_DUP_DROP_OLD, *args, **kw)
|
||||
|
||||
def putall(self, items, on_dup_key=RAISE, on_dup_val=RAISE, on_dup_kv=None):
|
||||
"""
|
||||
Like a bulk :meth:`put`.
|
||||
@_t.overload
|
||||
def putall(self, items: _t.Mapping[KT, VT], on_dup: OnDup) -> None: ...
|
||||
@_t.overload
|
||||
def putall(self, items: IterItems[KT, VT], on_dup: OnDup = ON_DUP_RAISE) -> None: ...
|
||||
def putall(self, items: MapOrIterItems[KT, VT], on_dup: OnDup = ON_DUP_RAISE) -> None:
|
||||
"""Like a bulk :meth:`put`.
|
||||
|
||||
If one of the given items causes an exception to be raised,
|
||||
none of the items is inserted.
|
||||
"""
|
||||
if items:
|
||||
on_dup = self._get_on_dup((on_dup_key, on_dup_val, on_dup_kv))
|
||||
self._update(False, on_dup, items)
|
||||
|
||||
|
||||
|
|
|
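Taken together, the hunks above replace the three per-policy arguments (on_dup_key, on_dup_val, on_dup_kv) with a single OnDup value. A minimal sketch of how the new API behaves, assuming the bidict 0.21 release vendored by this commit:

    from bidict import bidict, ON_DUP_RAISE, ValueDuplicationError

    b = bidict({'H': 'hydrogen'})
    b.put('He', 'helium')                     # put() raises on any duplication by default
    try:
        b.put('X', 'hydrogen', ON_DUP_RAISE)  # 'hydrogen' is already a value -> raises
    except ValueDuplicationError:
        pass
    b.forceput('X', 'hydrogen')               # ON_DUP_DROP_OLD: ('H', 'hydrogen') is dropped
    assert dict(b) == {'He': 'helium', 'X': 'hydrogen'}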
@@ -1,34 +1,35 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

"""Provides :func:`bidict.namedbidict`."""
"""Provide :func:`bidict.namedbidict`."""

import re
import typing as _t
from sys import _getframe

from ._abc import BidirectionalMapping
from ._abc import BidirectionalMapping, KT, VT
from ._bidict import bidict
from .compat import PY2


_isidentifier = (  # pylint: disable=invalid-name
re.compile('[A-Za-z_][A-Za-z0-9_]*$').match if PY2 else str.isidentifier
)


def namedbidict(typename, keyname, valname, base_type=bidict):
def namedbidict(
typename: str,
keyname: str,
valname: str,
*,
base_type: _t.Type[BidirectionalMapping[KT, VT]] = bidict,
) -> _t.Type[BidirectionalMapping[KT, VT]]:
r"""Create a new subclass of *base_type* with custom accessors.

Analagous to :func:`collections.namedtuple`.
Like :func:`collections.namedtuple` for bidicts.

The new class's ``__name__`` and ``__qualname__``
will be set based on *typename*.
The new class's ``__name__`` and ``__qualname__`` will be set to *typename*,
and its ``__module__`` will be set to the caller's module.

Instances of it will provide access to their
:attr:`inverse <BidirectionalMapping.inverse>`\s
Instances of the new class will provide access to their
:attr:`inverse <BidirectionalMapping.inverse>` instances
via the custom *keyname*\_for property,
and access to themselves
via the custom *valname*\_for property.

@@ -39,63 +40,58 @@ def namedbidict(typename, keyname, valname, base_type=bidict):
:raises ValueError: if any of the *typename*, *keyname*, or *valname*
strings is not a valid Python identifier, or if *keyname == valname*.

:raises TypeError: if *base_type* is not a subclass of
:class:`BidirectionalMapping`.
(This function requires slightly more of *base_type*,
e.g. the availability of an ``_isinv`` attribute,
but all the :ref:`concrete bidict types
<other-bidict-types:Bidict Types Diagram>`
that the :mod:`bidict` module provides can be passed in.
Check out the code if you actually need to pass in something else.)
:raises TypeError: if *base_type* is not a :class:`BidirectionalMapping` subclass
that provides ``_isinv`` and :meth:`~object.__getstate__` attributes.
(Any :class:`~bidict.BidictBase` subclass can be passed in, including all the
concrete bidict types pictured in the :ref:`other-bidict-types:Bidict Types Diagram`.
"""
# Re the `base_type` docs above:
# The additional requirements (providing _isinv and __getstate__) do not belong in the
# BidirectionalMapping interface, and it's overkill to create additional interface(s) for this.
# On the other hand, it's overkill to require that base_type be a subclass of BidictBase, since
# that's too specific. The BidirectionalMapping check along with the docs above should suffice.
if not issubclass(base_type, BidirectionalMapping):
if not issubclass(base_type, BidirectionalMapping) or not all(hasattr(base_type, i) for i in ('_isinv', '__getstate__')):
raise TypeError(base_type)
names = (typename, keyname, valname)
if not all(map(_isidentifier, names)) or keyname == valname:
if not all(map(str.isidentifier, names)) or keyname == valname:
raise ValueError(names)

class _Named(base_type):  # pylint: disable=too-many-ancestors
class _Named(base_type):  # type: ignore [valid-type,misc]

__slots__ = ()

def _getfwd(self):
return self.inverse if self._isinv else self
def _getfwd(self) -> '_Named':
return self.inverse if self._isinv else self  # type: ignore [no-any-return]

def _getinv(self):
return self if self._isinv else self.inverse
def _getinv(self) -> '_Named':
return self if self._isinv else self.inverse  # type: ignore [no-any-return]

@property
def _keyname(self):
def _keyname(self) -> str:
return valname if self._isinv else keyname

@property
def _valname(self):
def _valname(self) -> str:
return keyname if self._isinv else valname

def __reduce__(self):
def __reduce__(self) -> '_t.Tuple[_t.Callable[[str, str, str, _t.Type[BidirectionalMapping]], BidirectionalMapping], _t.Tuple[str, str, str, _t.Type[BidirectionalMapping]], dict]':
return (_make_empty, (typename, keyname, valname, base_type), self.__getstate__())

bname = base_type.__name__
fname = valname + '_for'
iname = keyname + '_for'
names = dict(typename=typename, bname=bname, keyname=keyname, valname=valname)
fdoc = u'{typename} forward {bname}: {keyname} → {valname}'.format(**names)
idoc = u'{typename} inverse {bname}: {valname} → {keyname}'.format(**names)
setattr(_Named, fname, property(_Named._getfwd, doc=fdoc))  # pylint: disable=protected-access
setattr(_Named, iname, property(_Named._getinv, doc=idoc))  # pylint: disable=protected-access
fdoc = f'{typename} forward {bname}: {keyname} → {valname}'
idoc = f'{typename} inverse {bname}: {valname} → {keyname}'
setattr(_Named, fname, property(_Named._getfwd, doc=fdoc))
setattr(_Named, iname, property(_Named._getinv, doc=idoc))

if not PY2:
_Named.__qualname__ = _Named.__qualname__[:-len(_Named.__name__)] + typename
_Named.__name__ = typename
_Named.__qualname__ = typename
_Named.__module__ = _getframe(1).f_globals.get('__name__')  # type: ignore [assignment]
return _Named


def _make_empty(typename, keyname, valname, base_type):
def _make_empty(
typename: str,
keyname: str,
valname: str,
base_type: _t.Type[BidirectionalMapping] = bidict,
) -> BidirectionalMapping:
"""Create a named bidict with the indicated arguments and return an empty instance.
Used to make :func:`bidict.namedbidict` instances picklable.
"""
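For reference, a small usage sketch of the namedbidict factory shown above (the keyword-only base_type defaults to bidict):

    from bidict import namedbidict

    ElementMap = namedbidict('ElementMap', 'symbol', 'name')
    noble = ElementMap(He='helium', Ne='neon')
    assert noble.name_for['He'] == 'helium'   # the *valname*_for accessor (forward)
    assert noble.symbol_for['neon'] == 'Ne'   # the *keyname*_for accessor (inverse)
    assert noble.inverse['helium'] == 'He'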
@@ -1,14 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.


"""Provides the :obj:`_NOOP` sentinel, for internally signaling "no-op"."""

from ._marker import _Marker


_NOOP = _Marker('NO-OP')
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this

@@ -26,17 +26,19 @@
#==============================================================================


"""Provides :class:`OrderedBidictBase`."""
"""Provide :class:`OrderedBidictBase`."""

import typing as _t
from copy import copy
from weakref import ref

from ._base import _WriteResult, BidictBase
from ._abc import MutableBidirectionalMapping
from ._base import _NONE, _DedupResult, _WriteResult, BidictBase, BT
from ._bidict import bidict
from ._miss import _MISS
from .compat import Mapping, PY2, iteritems, izip
from ._typing import KT, VT, OKT, OVT, IterItems, MapOrIterItems


class _Node(object):  # pylint: disable=too-few-public-methods
class _Node:
"""A node in a circular doubly-linked list
used to encode the order of items in an ordered bidict.

@@ -55,33 +57,33 @@ class _Node(object):  # pylint: disable=too-few-public-methods

__slots__ = ('_prv', '_nxt', '__weakref__')

def __init__(self, prv=None, nxt=None):
def __init__(self, prv: '_Node' = None, nxt: '_Node' = None) -> None:
self._setprv(prv)
self._setnxt(nxt)

def __repr__(self):  # pragma: no cover
def __repr__(self) -> str:
clsname = self.__class__.__name__
prv = id(self.prv)
nxt = id(self.nxt)
return '%s(prv=%s, self=%s, nxt=%s)' % (clsname, prv, id(self), nxt)
return f'{clsname}(prv={prv}, self={id(self)}, nxt={nxt})'

def _getprv(self):
def _getprv(self) -> '_t.Optional[_Node]':
return self._prv() if isinstance(self._prv, ref) else self._prv

def _setprv(self, prv):
def _setprv(self, prv: '_t.Optional[_Node]') -> None:
self._prv = prv and ref(prv)

prv = property(_getprv, _setprv)

def _getnxt(self):
def _getnxt(self) -> '_t.Optional[_Node]':
return self._nxt() if isinstance(self._nxt, ref) else self._nxt

def _setnxt(self, nxt):
def _setnxt(self, nxt: '_t.Optional[_Node]') -> None:
self._nxt = nxt and ref(nxt)

nxt = property(_getnxt, _setnxt)

def __getstate__(self):
def __getstate__(self) -> dict:
"""Return the instance state dictionary
but with weakrefs converted to strong refs
so that it can be pickled.

@@ -90,13 +92,13 @@ class _Node(object):  # pylint: disable=too-few-public-methods
"""
return dict(_prv=self.prv, _nxt=self.nxt)

def __setstate__(self, state):
def __setstate__(self, state: dict) -> None:
"""Set the instance state from *state*."""
self._setprv(state['_prv'])
self._setnxt(state['_nxt'])


class _Sentinel(_Node):  # pylint: disable=too-few-public-methods
class _SentinelNode(_Node):
"""Special node in a circular doubly-linked list
that links the first node with the last node.
When its next and previous references point back to itself

@@ -105,19 +107,16 @@ class _Sentinel(_Node):  # pylint: disable=too-few-public-methods

__slots__ = ()

def __init__(self, prv=None, nxt=None):
super(_Sentinel, self).__init__(prv or self, nxt or self)
def __init__(self, prv: _Node = None, nxt: _Node = None) -> None:
super().__init__(prv or self, nxt or self)

def __repr__(self):  # pragma: no cover
return '<SENTINEL>'
def __repr__(self) -> str:
return '<SNTL>'

def __bool__(self):
def __bool__(self) -> bool:
return False

if PY2:
__nonzero__ = __bool__

def __iter__(self, reverse=False):
def _iter(self, *, reverse: bool = False) -> _t.Iterator[_Node]:
"""Iterator yielding nodes in the requested order,
i.e. traverse the linked list via :attr:`nxt`
(or :attr:`prv` if *reverse* is truthy)

@@ -130,26 +129,35 @@ class _Sentinel(_Node):  # pylint: disable=too-few-public-methods
node = getattr(node, attr)


class OrderedBidictBase(BidictBase):
class OrderedBidictBase(BidictBase[KT, VT]):
"""Base class implementing an ordered :class:`BidirectionalMapping`."""

__slots__ = ('_sntl',)

_fwdm_cls = bidict
_invm_cls = bidict
_fwdm_cls: _t.Type[MutableBidirectionalMapping[KT, _Node]] = bidict  # type: ignore [assignment]
_invm_cls: _t.Type[MutableBidirectionalMapping[VT, _Node]] = bidict  # type: ignore [assignment]
_fwdm: bidict[KT, _Node]  # type: ignore [assignment]
_invm: bidict[VT, _Node]  # type: ignore [assignment]

#: The object used by :meth:`__repr__` for printing the contained items.
_repr_delegate = list

def __init__(self, *args, **kw):
@_t.overload
def __init__(self, __arg: _t.Mapping[KT, VT], **kw: VT) -> None: ...
@_t.overload
def __init__(self, __arg: IterItems[KT, VT], **kw: VT) -> None: ...
@_t.overload
def __init__(self, **kw: VT) -> None: ...
def __init__(self, *args: MapOrIterItems[KT, VT], **kw: VT) -> None:
"""Make a new ordered bidirectional mapping.
The signature is the same as that of regular dictionaries.
The signature behaves like that of :class:`dict`.
Items passed in are added in the order they are passed,
respecting this bidict type's duplication policies along the way.
respecting the :attr:`on_dup` class attribute in the process.

The order in which items are inserted is remembered,
similar to :class:`collections.OrderedDict`.
"""
self._sntl = _Sentinel()
self._sntl = _SentinelNode()

# Like unordered bidicts, ordered bidicts also store two backing one-directional mappings
# `_fwdm` and `_invm`. But rather than mapping `key` to `val` and `val` to `key`

@@ -159,55 +167,58 @@ class OrderedBidictBase(BidictBase):
# To effect this difference, `_write_item` and `_undo_write` are overridden. But much of the
# rest of BidictBase's implementation, including BidictBase.__init__ and BidictBase._update,
# are inherited and are able to be reused without modification.
super(OrderedBidictBase, self).__init__(*args, **kw)
super().__init__(*args, **kw)

def _init_inv(self):
super(OrderedBidictBase, self)._init_inv()
self.inverse._sntl = self._sntl  # pylint: disable=protected-access
if _t.TYPE_CHECKING:
@property
def inverse(self) -> 'OrderedBidictBase[VT, KT]': ...

def _init_inv(self) -> None:
super()._init_inv()
self.inverse._sntl = self._sntl

# Can't reuse BidictBase.copy since ordered bidicts have different internal structure.
def copy(self):
def copy(self: BT) -> BT:
"""A shallow copy of this ordered bidict."""
# Fast copy implementation bypassing __init__. See comments in :meth:`BidictBase.copy`.
copy = self.__class__.__new__(self.__class__)
sntl = _Sentinel()
fwdm = self._fwdm.copy()
invm = self._invm.copy()
cp: BT = self.__class__.__new__(self.__class__)
sntl = _SentinelNode()
fwdm = copy(self._fwdm)
invm = copy(self._invm)
cur = sntl
nxt = sntl.nxt
for (key, val) in iteritems(self):
for (key, val) in self.items():
nxt = _Node(cur, sntl)
cur.nxt = fwdm[key] = invm[val] = nxt
cur = nxt
sntl.prv = nxt
copy._sntl = sntl  # pylint: disable=protected-access
copy._fwdm = fwdm  # pylint: disable=protected-access
copy._invm = invm  # pylint: disable=protected-access
copy._init_inv()  # pylint: disable=protected-access
return copy
cp._sntl = sntl  # type: ignore [attr-defined]
cp._fwdm = fwdm
cp._invm = invm
cp._init_inv()
return cp

def __getitem__(self, key):
__copy__ = copy

def __getitem__(self, key: KT) -> VT:
nodefwd = self._fwdm[key]
val = self._invm.inverse[nodefwd]
return val

def _pop(self, key):
def _pop(self, key: KT) -> VT:
nodefwd = self._fwdm.pop(key)
val = self._invm.inverse.pop(nodefwd)
nodefwd.prv.nxt = nodefwd.nxt
nodefwd.nxt.prv = nodefwd.prv
return val

def _isdupitem(self, key, val, dedup_result):
"""Return whether (key, val) duplicates an existing item."""
isdupkey, isdupval, nodeinv, nodefwd = dedup_result
isdupitem = nodeinv is nodefwd
if isdupitem:
assert isdupkey
assert isdupval
return isdupitem
@staticmethod
def _already_have(key: KT, val: VT, nodeinv: _Node, nodefwd: _Node) -> bool:  # type: ignore [override]
# Overrides _base.BidictBase.
return nodeinv is nodefwd

def _write_item(self, key, val, dedup_result):  # pylint: disable=too-many-locals
def _write_item(self, key: KT, val: VT, dedup_result: _DedupResult) -> _WriteResult:
# Overrides _base.BidictBase.
fwdm = self._fwdm  # bidict mapping keys to nodes
invm = self._invm  # bidict mapping vals to nodes
isdupkey, isdupval, nodeinv, nodefwd = dedup_result

@@ -217,7 +228,8 @@ class OrderedBidictBase(BidictBase):
last = sntl.prv
node = _Node(last, sntl)
last.nxt = sntl.prv = fwdm[key] = invm[val] = node
oldkey = oldval = _MISS
oldkey: OKT = _NONE
oldval: OVT = _NONE
elif isdupkey and isdupval:
# Key and value duplication across two different nodes.
assert nodefwd is not nodeinv

@@ -239,19 +251,19 @@ class OrderedBidictBase(BidictBase):
fwdm[key] = invm[val] = nodefwd
elif isdupkey:
oldval = invm.inverse[nodefwd]
oldkey = _MISS
oldkey = _NONE
oldnodeinv = invm.pop(oldval)
assert oldnodeinv is nodefwd
invm[val] = nodefwd
else:  # isdupval
oldkey = fwdm.inverse[nodeinv]
oldval = _MISS
oldval = _NONE
oldnodefwd = fwdm.pop(oldkey)
assert oldnodefwd is nodeinv
fwdm[key] = nodeinv
return _WriteResult(key, val, oldkey, oldval)

def _undo_write(self, dedup_result, write_result):  # pylint: disable=too-many-locals
def _undo_write(self, dedup_result: _DedupResult, write_result: _WriteResult) -> None:
fwdm = self._fwdm
invm = self._invm
isdupkey, isdupval, nodeinv, nodefwd = dedup_result

@@ -274,26 +286,18 @@ class OrderedBidictBase(BidictBase):
fwdm[oldkey] = nodeinv
assert invm[val] is nodeinv

def __iter__(self, reverse=False):
"""An iterator over this bidict's items in order."""
def __iter__(self) -> _t.Iterator[KT]:
"""Iterator over the contained keys in insertion order."""
return self._iter()

def _iter(self, *, reverse: bool = False) -> _t.Iterator[KT]:
fwdm_inv = self._fwdm.inverse
for node in self._sntl.__iter__(reverse=reverse):
for node in self._sntl._iter(reverse=reverse):
yield fwdm_inv[node]

def __reversed__(self):
"""An iterator over this bidict's items in reverse order."""
for key in self.__iter__(reverse=True):
yield key

def equals_order_sensitive(self, other):
"""Order-sensitive equality check.

*See also* :ref:`eq-order-insensitive`
"""
# Same short-circuit as BidictBase.__eq__. Factoring out not worth function call overhead.
if not isinstance(other, Mapping) or len(self) != len(other):
return False
return all(i == j for (i, j) in izip(iteritems(self), iteritems(other)))
def __reversed__(self) -> _t.Iterator[KT]:
"""Iterator over the contained keys in reverse insertion order."""
yield from self._iter(reverse=True)


# * Code review nav *
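The circular doubly-linked list above is what gives ordered bidicts their insertion-order iteration. A minimal sketch of the observable behavior, assuming bidict 0.21 (where equals_order_sensitive remains available via the base class):

    from bidict import OrderedBidict

    ob = OrderedBidict([(1, 'one'), (2, 'two'), (3, 'three')])
    assert list(ob) == [1, 2, 3]            # __iter__ walks the linked list forwards
    assert list(reversed(ob)) == [3, 2, 1]  # __reversed__ walks it backwards
    assert ob.equals_order_sensitive(OrderedBidict([(1, 'one'), (2, 'two'), (3, 'three')]))
    assert not ob.equals_order_sensitive(OrderedBidict([(3, 'three'), (2, 'two'), (1, 'one')]))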
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this

@@ -26,26 +26,32 @@
#==============================================================================


"""Provides :class:`OrderedBidict`."""
"""Provide :class:`OrderedBidict`."""

import typing as _t

from ._mut import MutableBidict
from ._orderedbase import OrderedBidictBase
from ._typing import KT, VT


class OrderedBidict(OrderedBidictBase, MutableBidict):
class OrderedBidict(OrderedBidictBase[KT, VT], MutableBidict[KT, VT]):
"""Mutable bidict type that maintains items in insertion order."""

__slots__ = ()
__hash__ = None  # since this class is mutable; explicit > implicit.

def clear(self):
if _t.TYPE_CHECKING:
@property
def inverse(self) -> 'OrderedBidict[VT, KT]': ...

def clear(self) -> None:
"""Remove all items."""
self._fwdm.clear()
self._invm.clear()
self._sntl.nxt = self._sntl.prv = self._sntl

def popitem(self, last=True):  # pylint: disable=arguments-differ
u"""*x.popitem() → (k, v)*
def popitem(self, last: bool = True) -> _t.Tuple[KT, VT]:
"""*x.popitem() → (k, v)*

Remove and return the most recently added item as a (key, value) pair
if *last* is True, else the least recently added item.

@@ -54,11 +60,13 @@ class OrderedBidict(OrderedBidictBase, MutableBidict):
"""
if not self:
raise KeyError('mapping is empty')
key = next((reversed if last else iter)(self))
itfn: _t.Callable = reversed if last else iter  # type: ignore [assignment]
it = itfn(self)
key = next(it)
val = self._pop(key)
return key, val

def move_to_end(self, key, last=True):
def move_to_end(self, key: KT, last: bool = True) -> None:
"""Move an existing key to the beginning or end of this ordered bidict.

The item is moved to the end if *last* is True, else to the beginning.

@@ -70,15 +78,15 @@ class OrderedBidict(OrderedBidictBase, MutableBidict):
node.nxt.prv = node.prv
sntl = self._sntl
if last:
last = sntl.prv
node.prv = last
lastnode = sntl.prv
node.prv = lastnode
node.nxt = sntl
sntl.prv = last.nxt = node
sntl.prv = lastnode.nxt = node
else:
first = sntl.nxt
firstnode = sntl.nxt
node.prv = sntl
node.nxt = first
sntl.nxt = first.prv = node
node.nxt = firstnode
sntl.nxt = firstnode.prv = node


# * Code review nav *
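A minimal sketch of the OrderedBidict-specific mutators touched above, assuming the vendored bidict 0.21:

    from bidict import OrderedBidict

    ob = OrderedBidict([('a', 1), ('b', 2), ('c', 3)])
    assert ob.popitem() == ('c', 3)            # last=True: most recently added
    assert ob.popitem(last=False) == ('a', 1)  # least recently added
    ob['a'] = 1                                # re-added at the end
    ob.move_to_end('b')                        # relink 'b' as the newest item
    assert list(ob.items()) == [('a', 1), ('b', 2)]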
@@ -0,0 +1,33 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.


"""Provide typing-related objects."""

import typing as _t


KT = _t.TypeVar('KT')
VT = _t.TypeVar('VT')
IterItems = _t.Iterable[_t.Tuple[KT, VT]]
MapOrIterItems = _t.Union[_t.Mapping[KT, VT], IterItems[KT, VT]]

DT = _t.TypeVar('DT')  #: for default arguments
VDT = _t.Union[VT, DT]


class _BareReprMeta(type):
def __repr__(cls) -> str:
return f'<{cls.__name__}>'


class _NONE(metaclass=_BareReprMeta):
"""Sentinel type used to represent 'missing'."""


OKT = _t.Union[KT, _NONE]  #: optional key type
OVT = _t.Union[VT, _NONE]  #: optional value type
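The _NONE sentinel added here is what lets methods like pop() distinguish "no default given" from a legitimate None default. A self-contained sketch of the pattern (a standalone re-implementation for illustration, not bidict's internals):

    class _BareReprMeta(type):
        def __repr__(cls) -> str:
            return f'<{cls.__name__}>'

    class _NONE(metaclass=_BareReprMeta):
        """Sentinel type used to represent 'missing'."""

    def pop(mapping, key, default=_NONE):
        try:
            return mapping.pop(key)
        except KeyError:
            if default is _NONE:  # caller passed no default -> propagate
                raise
            return default

    assert pop({'k': None}, 'k') is None  # a stored None is a real value
    assert pop({}, 'k', None) is None     # an explicit None default is honored
    assert repr(_NONE) == '<_NONE>'       # the metaclass gives the bare repr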
@@ -1,78 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.


"""Compatibility helpers."""

from operator import methodcaller
from platform import python_implementation
from sys import version_info
from warnings import warn


# Use #: (before or) at the end of each line with a member we want to show up in the docs,
# otherwise Sphinx won't include (even though we configure automodule with undoc-members).

PYMAJOR, PYMINOR = version_info[:2]  #:
PY2 = PYMAJOR == 2  #:
PYIMPL = python_implementation()  #:
CPY = PYIMPL == 'CPython'  #:
PYPY = PYIMPL == 'PyPy'  #:
DICTS_ORDERED = PYPY or (CPY and (PYMAJOR, PYMINOR) >= (3, 6))  #:

# Without the following, pylint gives lots of false positives.
# pylint: disable=invalid-name,unused-import,ungrouped-imports,no-name-in-module

if PY2:
if PYMINOR < 7:  # pragma: no cover
raise ImportError('Python 2.7 or 3.5+ is required.')
warn('Python 2 support will be dropped in a future release.')

# abstractproperty deprecated in Python 3.3 in favor of using @property with @abstractmethod.
# Before 3.3, this silently fails to detect when an abstract property has not been overridden.
from abc import abstractproperty  #:

from itertools import izip  #:

# In Python 3, the collections ABCs were moved into collections.abc, which does not exist in
# Python 2. Support for importing them directly from collections is dropped in Python 3.8.
import collections as collections_abc  # noqa: F401 (imported but unused)
from collections import (  # noqa: F401 (imported but unused)
Mapping, MutableMapping, KeysView, ValuesView, ItemsView)

viewkeys = lambda m: m.viewkeys() if hasattr(m, 'viewkeys') else KeysView(m)  #:
viewvalues = lambda m: m.viewvalues() if hasattr(m, 'viewvalues') else ValuesView(m)  #:
viewitems = lambda m: m.viewitems() if hasattr(m, 'viewitems') else ItemsView(m)  #:

iterkeys = lambda m: m.iterkeys() if hasattr(m, 'iterkeys') else iter(m.keys())  #:
itervalues = lambda m: m.itervalues() if hasattr(m, 'itervalues') else iter(m.values())  #:
iteritems = lambda m: m.iteritems() if hasattr(m, 'iteritems') else iter(m.items())  #:

else:
# Assume Python 3 when not PY2, but explicitly check before showing this warning.
if PYMAJOR == 3 and PYMINOR < 5:  # pragma: no cover
warn('Python 3.4 and below are not supported.')

import collections.abc as collections_abc  # noqa: F401 (imported but unused)
from collections.abc import (  # noqa: F401 (imported but unused)
Mapping, MutableMapping, KeysView, ValuesView, ItemsView)

viewkeys = methodcaller('keys')  #:
viewvalues = methodcaller('values')  #:
viewitems = methodcaller('items')  #:

def _compose(f, g):
return lambda x: f(g(x))

iterkeys = _compose(iter, viewkeys)  #:
itervalues = _compose(iter, viewvalues)  #:
iteritems = _compose(iter, viewitems)  #:

from abc import abstractmethod
abstractproperty = _compose(property, abstractmethod)  #:

izip = zip  #:
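With Python 2 support dropped, this whole shim module is deleted; on Python 3 its helpers reduced to trivial compositions. A sketch of what the removed code boiled down to:

    from operator import methodcaller

    def _compose(f, g):
        return lambda x: f(g(x))

    viewitems = methodcaller('items')
    iteritems = _compose(iter, viewitems)  # iteritems(m) == iter(m.items()) on Python 3

    assert next(iteritems({'a': 1})) == ('a', 1)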
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
# Copyright 2009-2021 Joshua Bronson. All Rights Reserved.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this

@@ -8,42 +8,22 @@
"""Define bidict package metadata."""


__version__ = '0.0.0.VERSION_NOT_FOUND'

# _version.py is generated by setuptools_scm (via its `write_to` param, see setup.py)
try:
from ._version import version as __version__  # pylint: disable=unused-import
except (ImportError, ValueError, SystemError):  # pragma: no cover
try:
import pkg_resources
except ImportError:
pass
else:
try:
__version__ = pkg_resources.get_distribution('bidict').version
except pkg_resources.DistributionNotFound:
pass

try:
__version_info__ = tuple(int(p) if i < 3 else p for (i, p) in enumerate(__version__.split('.')))
except Exception:  # noqa: E722; pragma: no cover; pylint: disable=broad-except
__vesion_info__ = (0, 0, 0, 'PARSE FAILURE: __version__=%s' % __version__)

__author__ = u'Joshua Bronson'
__maintainer__ = u'Joshua Bronson'
__copyright__ = u'Copyright 2019 Joshua Bronson'
__email__ = u'jab@math.brown.edu'
__version__ = '0.21.4'
__author__ = 'Joshua Bronson'
__maintainer__ = 'Joshua Bronson'
__copyright__ = 'Copyright 2009-2021 Joshua Bronson'
__email__ = 'jabronson@gmail.com'

# See: ../docs/thanks.rst
__credits__ = [i.strip() for i in u"""
__credits__ = [i.strip() for i in """
Joshua Bronson, Michael Arntzenius, Francis Carr, Gregory Ewing, Raymond Hettinger, Jozef Knaperek,
Daniel Pope, Terry Reedy, David Turner, Tom Viner, Richard Sanger, Zeyi Wang
""".split(u',')]
""".split(',')]

__description__ = u'Efficient, Pythonic bidirectional map implementation and related functionality'
__description__ = 'The bidirectional mapping library for Python.'
__keywords__ = 'dict dictionary mapping datastructure bimap bijection bijective ' \
'injective inverse reverse bidirectional two-way 2-way'

__license__ = u'MPL 2.0'
__status__ = u'Beta'
__url__ = u'https://bidict.readthedocs.io'
__license__ = 'MPL 2.0'
__status__ = 'Beta'
__url__ = 'https://bidict.readthedocs.io'
@@ -1,6 +1,5 @@
"""Beautiful Soup
Elixir and Tonic
"The Screen-Scraper's Friend"
"""Beautiful Soup Elixir and Tonic - "The Screen-Scraper's Friend".

http://www.crummy.com/software/BeautifulSoup/

Beautiful Soup uses a pluggable XML or HTML parser to parse a

@@ -8,29 +7,34 @@ Beautiful Soup uses a pluggable XML or HTML parser to parse a
provides methods and Pythonic idioms that make it easy to navigate,
search, and modify the parse tree.

Beautiful Soup works with Python 2.7 and up. It works better if lxml
Beautiful Soup works with Python 3.5 and up. It works better if lxml
and/or html5lib is installed.

For more than you ever wanted to know about Beautiful Soup, see the
documentation:
http://www.crummy.com/software/BeautifulSoup/bs4/doc/

documentation: http://www.crummy.com/software/BeautifulSoup/bs4/doc/
"""

__author__ = "Leonard Richardson (leonardr@segfault.org)"
__version__ = "4.8.0"
__copyright__ = "Copyright (c) 2004-2019 Leonard Richardson"
__version__ = "4.10.0"
__copyright__ = "Copyright (c) 2004-2021 Leonard Richardson"
# Use of this source code is governed by the MIT license.
__license__ = "MIT"

__all__ = ['BeautifulSoup']


from collections import Counter
import os
import re
import sys
import traceback
import warnings

# The very first thing we do is give a useful error if someone is
# running this code under Python 2.
if sys.version_info.major < 3:
raise ImportError('You are trying to use a Python 3-specific version of Beautiful Soup under Python 2. This will not work. The final version of Beautiful Soup to support Python 2 was 4.9.3.')

from .builder import builder_registry, ParserRejectedMarkup
from .dammit import UnicodeDammit
from .element import (

@@ -42,28 +46,49 @@ from .element import (
NavigableString,
PageElement,
ProcessingInstruction,
PYTHON_SPECIFIC_ENCODINGS,
ResultSet,
Script,
Stylesheet,
SoupStrainer,
Tag,
TemplateString,
)

# The very first thing we do is give a useful error if someone is
# running this code under Python 3 without converting it.
'You are trying to run the Python 2 version of Beautiful Soup under Python 3. This will not work.'!='You need to convert the code, either by installing it (`python setup.py install`) or by running 2to3 (`2to3 -w bs4`).'
# Define some custom warnings.
class GuessedAtParserWarning(UserWarning):
"""The warning issued when BeautifulSoup has to guess what parser to
use -- probably because no parser was specified in the constructor.
"""

class MarkupResemblesLocatorWarning(UserWarning):
"""The warning issued when BeautifulSoup is given 'markup' that
actually looks like a resource locator -- a URL or a path to a file
on disk.
"""


class BeautifulSoup(Tag):
"""
This class defines the basic interface called by the tree builders.
"""A data structure representing a parsed HTML or XML document.

These methods will be called by the parser:
reset()
feed(markup)
Most of the methods you'll call on a BeautifulSoup object are inherited from
PageElement or Tag.

Internally, this class defines the basic interface called by the
tree builders when converting an HTML/XML document into a data
structure. The interface abstracts away the differences between
parsers. To write a new tree builder, you'll need to understand
these methods as a whole.

These methods will be called by the BeautifulSoup constructor:
* reset()
* feed(markup)

The tree builder may call these methods from its feed() implementation:
handle_starttag(name, attrs) # See note about return value
handle_endtag(name)
handle_data(data) # Appends to the current data node
endData(containerClass=NavigableString) # Ends the current data node
* handle_starttag(name, attrs) # See note about return value
* handle_endtag(name)
* handle_data(data) # Appends to the current data node
* endData(containerClass) # Ends the current data node

No matter how complicated the underlying parser is, you should be
able to build a tree using 'start tag' events, 'end tag' events,

@@ -73,62 +98,75 @@ class BeautifulSoup(Tag):
like HTML's <br> tag), call handle_starttag and then
handle_endtag.
"""

# Since BeautifulSoup subclasses Tag, it's possible to treat it as
# a Tag with a .name. This name makes it clear the BeautifulSoup
# object isn't a real markup tag.
ROOT_TAG_NAME = '[document]'

# If the end-user gives no indication which tree builder they
# want, look for one with these features.
DEFAULT_BUILDER_FEATURES = ['html', 'fast']

# A string containing all ASCII whitespace characters, used in
# endData() to detect data chunks that seem 'empty'.
ASCII_SPACES = '\x20\x0a\x09\x0c\x0d'

NO_PARSER_SPECIFIED_WARNING = "No parser was explicitly specified, so I'm using the best available %(markup_type)s parser for this system (\"%(parser)s\"). This usually isn't a problem, but if you run this code on another system, or in a different virtual environment, it may use a different parser and behave differently.\n\nThe code that caused this warning is on line %(line_number)s of the file %(filename)s. To get rid of this warning, pass the additional argument 'features=\"%(parser)s\"' to the BeautifulSoup constructor.\n"


def __init__(self, markup="", features=None, builder=None,
parse_only=None, from_encoding=None, exclude_encodings=None,
**kwargs):
element_classes=None, **kwargs):
"""Constructor.

:param markup: A string or a file-like object representing
markup to be parsed.

:param features: Desirable features of the parser to be used. This
may be the name of a specific parser ("lxml", "lxml-xml",
"html.parser", or "html5lib") or it may be the type of markup
to be used ("html", "html5", "xml"). It's recommended that you
name a specific parser, so that Beautiful Soup gives you the
same results across platforms and virtual environments.
:param features: Desirable features of the parser to be
used. This may be the name of a specific parser ("lxml",
"lxml-xml", "html.parser", or "html5lib") or it may be the
type of markup to be used ("html", "html5", "xml"). It's
recommended that you name a specific parser, so that
Beautiful Soup gives you the same results across platforms
and virtual environments.

:param builder: A TreeBuilder subclass to instantiate (or
instance to use) instead of looking one up based on
`features`. You only need to use this if you've implemented a
custom TreeBuilder.

:param parse_only: A SoupStrainer. Only parts of the document
matching the SoupStrainer will be considered. This is useful
when parsing part of a document that would otherwise be too
large to fit into memory.

:param from_encoding: A string indicating the encoding of the
document to be parsed. Pass this in if Beautiful Soup is
guessing wrongly about the document's encoding.

:param exclude_encodings: A list of strings indicating
encodings known to be wrong. Pass this in if you don't know
the document's encoding but you know Beautiful Soup's guess is
wrong.

:param element_classes: A dictionary mapping BeautifulSoup
classes like Tag and NavigableString, to other classes you'd
like to be instantiated instead as the parse tree is
built. This is useful for subclassing Tag or NavigableString
to modify default behavior.

:param kwargs: For backwards compatibility purposes, the
constructor accepts certain keyword arguments used in
Beautiful Soup 3. None of these arguments do anything in
Beautiful Soup 4; they will result in a warning and then be ignored.

Apart from this, any keyword arguments passed into the BeautifulSoup
constructor are propagated to the TreeBuilder constructor. This
makes it possible to configure a TreeBuilder beyond saying
which one to use.

Beautiful Soup 4; they will result in a warning and then be
ignored.

Apart from this, any keyword arguments passed into the
BeautifulSoup constructor are propagated to the TreeBuilder
constructor. This makes it possible to configure a
TreeBuilder by passing in arguments, not just by saying which
one to use.
"""

if 'convertEntities' in kwargs:
del kwargs['convertEntities']
warnings.warn(
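A minimal sketch of the constructor parameters documented in the hunk above ('html.parser' is used so the example needs no third-party parser):

    from bs4 import BeautifulSoup, SoupStrainer

    only_links = SoupStrainer('a')
    soup = BeautifulSoup(
        '<p>hi <a href="/x">x</a></p>',
        features='html.parser',  # naming a parser avoids GuessedAtParserWarning
        parse_only=only_links,   # only <a> tags are kept in the tree
    )
    assert [a['href'] for a in soup.find_all('a')] == ['/x']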
@@ -185,6 +223,8 @@ class BeautifulSoup(Tag):
warnings.warn("You provided Unicode markup but also provided a value for from_encoding. Your from_encoding will be ignored.")
from_encoding = None

self.element_classes = element_classes or dict()

# We need this information to track whether or not the builder
# was specified well enough that we can omit the 'you need to
# specify a parser' warning.
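The element_classes mapping stored above lets subclasses replace the default tree element types. A small sketch (MyTag is a hypothetical subclass, for illustration only):

    from bs4 import BeautifulSoup
    from bs4.element import Tag

    class MyTag(Tag):  # hypothetical Tag subclass
        pass

    soup = BeautifulSoup('<b>x</b>', 'html.parser', element_classes={Tag: MyTag})
    assert isinstance(soup.b, MyTag)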
@@ -215,7 +255,9 @@ class BeautifulSoup(Tag):
if not original_builder and not (
original_features == builder.NAME or
original_features in builder.ALTERNATE_NAMES
):
) and markup:
# The user did not tell us which TreeBuilder to use,
# and we had to guess. Issue a warning.
if builder.is_xml:
markup_type = "XML"
else:

@@ -249,7 +291,10 @@ class BeautifulSoup(Tag):
parser=builder.NAME,
markup_type=markup_type
)
warnings.warn(self.NO_PARSER_SPECIFIED_WARNING % values, stacklevel=2)
warnings.warn(
self.NO_PARSER_SPECIFIED_WARNING % values,
GuessedAtParserWarning, stacklevel=2
)
else:
if kwargs:
warnings.warn("Keyword arguments to the BeautifulSoup constructor will be ignored. These would normally be passed into the TreeBuilder constructor, but a TreeBuilder instance was passed in as `builder`.")

@@ -278,22 +323,36 @@ class BeautifulSoup(Tag):
else:
possible_filename = markup
is_file = False
is_directory = False
try:
is_file = os.path.exists(possible_filename)
if is_file:
is_directory = os.path.isdir(possible_filename)
except Exception as e:
# This is almost certainly a problem involving
# characters not valid in filenames on this
# system. Just let it go.
pass
if is_file:
if isinstance(markup, str):
markup = markup.encode("utf8")
if is_directory:
warnings.warn(
'"%s" looks like a directory name, not markup. You may'
' want to open a file found in this directory and pass'
' the filehandle into Beautiful Soup.' % (
self._decode_markup(markup)
),
MarkupResemblesLocatorWarning
)
elif is_file:
warnings.warn(
'"%s" looks like a filename, not markup. You should'
' probably open this file and pass the filehandle into'
' Beautiful Soup.' % markup)
' Beautiful Soup.' % self._decode_markup(markup),
MarkupResemblesLocatorWarning
)
self._check_markup_is_url(markup)

rejections = []
success = False
for (self.markup, self.original_encoding, self.declared_html_encoding,
self.contains_replacement_characters) in (
self.builder.prepare_markup(

@@ -301,16 +360,25 @@ class BeautifulSoup(Tag):
self.reset()
try:
self._feed()
success = True
break
except ParserRejectedMarkup:
except ParserRejectedMarkup as e:
rejections.append(e)
pass

if not success:
other_exceptions = [str(e) for e in rejections]
raise ParserRejectedMarkup(
"The markup you provided was rejected by the parser. Trying a different parser or a different encoding may help.\n\nOriginal exception(s) from parser:\n " + "\n ".join(other_exceptions)
)

# Clear out the markup and remove the builder's circular
# reference to this object.
self.markup = None
self.builder.soup = None

def __copy__(self):
"""Copy a BeautifulSoup object by converting the document to a string and parsing it again."""
copy = type(self)(
self.encode('utf-8'), builder=self.builder, from_encoding='utf-8'
)

@@ -329,11 +397,25 @@ class BeautifulSoup(Tag):
d['builder'] = None
return d

@staticmethod
def _check_markup_is_url(markup):
"""
Check if markup looks like it's actually a url and raise a warning
if so. Markup can be unicode or str (py2) / bytes (py3).
@classmethod
def _decode_markup(cls, markup):
"""Ensure `markup` is bytes so it's safe to send into warnings.warn.

TODO: warnings.warn had this problem back in 2010 but it might not
anymore.
"""
if isinstance(markup, bytes):
decoded = markup.decode('utf-8', 'replace')
else:
decoded = markup
return decoded

@classmethod
def _check_markup_is_url(cls, markup):
"""Error-handling method to raise a warning if incoming markup looks
like a URL.

:param markup: A string.
"""
if isinstance(markup, bytes):
space = b' '
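The new MarkupResemblesLocatorWarning category introduced by these hunks can be caught or filtered like any other warning. A minimal sketch:

    import warnings
    from bs4 import BeautifulSoup, MarkupResemblesLocatorWarning

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        BeautifulSoup('http://example.com/', 'html.parser')  # looks like a URL, not markup
    assert any(issubclass(w.category, MarkupResemblesLocatorWarning) for w in caught)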
@@ -346,18 +428,20 @@ class BeautifulSoup(Tag):

if any(markup.startswith(prefix) for prefix in cant_start_with):
if not space in markup:
if isinstance(markup, bytes):
decoded_markup = markup.decode('utf-8', 'replace')
else:
decoded_markup = markup
warnings.warn(
'"%s" looks like a URL. Beautiful Soup is not an'
' HTTP client. You should probably use an HTTP client like'
' requests to get the document behind the URL, and feed'
' that document to Beautiful Soup.' % decoded_markup
' that document to Beautiful Soup.' % cls._decode_markup(
markup
),
MarkupResemblesLocatorWarning
)

def _feed(self):
"""Internal method that parses previously set markup, creating a large
number of Tag and NavigableString objects.
"""
# Convert the document to Unicode.
self.builder.reset()

@@ -368,49 +452,110 @@ class BeautifulSoup(Tag):
self.popTag()

def reset(self):
"""Reset this object to a state as though it had never parsed any
markup.
"""
Tag.__init__(self, self, self.builder, self.ROOT_TAG_NAME)
self.hidden = 1
self.builder.reset()
self.current_data = []
self.currentTag = None
self.tagStack = []
self.open_tag_counter = Counter()
self.preserve_whitespace_tag_stack = []
self.string_container_stack = []
self.pushTag(self)

def new_tag(self, name, namespace=None, nsprefix=None, attrs={}, **kwattrs):
"""Create a new tag associated with this soup."""
def new_tag(self, name, namespace=None, nsprefix=None, attrs={},
sourceline=None, sourcepos=None, **kwattrs):
"""Create a new Tag associated with this BeautifulSoup object.

:param name: The name of the new Tag.
:param namespace: The URI of the new Tag's XML namespace, if any.
:param prefix: The prefix for the new Tag's XML namespace, if any.
:param attrs: A dictionary of this Tag's attribute values; can
be used instead of `kwattrs` for attributes like 'class'
that are reserved words in Python.
:param sourceline: The line number where this tag was
(purportedly) found in its source document.
:param sourcepos: The character position within `sourceline` where this
tag was (purportedly) found.
:param kwattrs: Keyword arguments for the new Tag's attribute values.

"""
kwattrs.update(attrs)
return Tag(None, self.builder, name, namespace, nsprefix, kwattrs)
return self.element_classes.get(Tag, Tag)(
None, self.builder, name, namespace, nsprefix, kwattrs,
sourceline=sourceline, sourcepos=sourcepos
)

def new_string(self, s, subclass=NavigableString):
"""Create a new NavigableString associated with this soup."""
return subclass(s)
def string_container(self, base_class=None):
container = base_class or NavigableString

# There may be a general override of NavigableString.
container = self.element_classes.get(
container, container
)

def insert_before(self, successor):
# On top of that, we may be inside a tag that needs a special
# container class.
if self.string_container_stack and container is NavigableString:
container = self.builder.string_containers.get(
self.string_container_stack[-1].name, container
)
return container

def new_string(self, s, subclass=None):
"""Create a new NavigableString associated with this BeautifulSoup
object.
"""
container = self.string_container(subclass)
return container(s)

def insert_before(self, *args):
"""This method is part of the PageElement API, but `BeautifulSoup` doesn't implement
it because there is nothing before or after it in the parse tree.
"""
raise NotImplementedError("BeautifulSoup objects don't support insert_before().")

def insert_after(self, successor):
def insert_after(self, *args):
"""This method is part of the PageElement API, but `BeautifulSoup` doesn't implement
it because there is nothing before or after it in the parse tree.
"""
raise NotImplementedError("BeautifulSoup objects don't support insert_after().")

def popTag(self):
"""Internal method called by _popToTag when a tag is closed."""
tag = self.tagStack.pop()
if tag.name in self.open_tag_counter:
self.open_tag_counter[tag.name] -= 1
if self.preserve_whitespace_tag_stack and tag == self.preserve_whitespace_tag_stack[-1]:
self.preserve_whitespace_tag_stack.pop()
#print "Pop", tag.name
if self.string_container_stack and tag == self.string_container_stack[-1]:
self.string_container_stack.pop()
#print("Pop", tag.name)
if self.tagStack:
self.currentTag = self.tagStack[-1]
return self.currentTag

def pushTag(self, tag):
#print "Push", tag.name
"""Internal method called by handle_starttag when a tag is opened."""
#print("Push", tag.name)
if self.currentTag is not None:
self.currentTag.contents.append(tag)
self.tagStack.append(tag)
self.currentTag = self.tagStack[-1]
if tag.name != self.ROOT_TAG_NAME:
self.open_tag_counter[tag.name] += 1
if tag.name in self.builder.preserve_whitespace_tags:
self.preserve_whitespace_tag_stack.append(tag)
if tag.name in self.builder.string_containers:
self.string_container_stack.append(tag)

def endData(self, containerClass=NavigableString):
def endData(self, containerClass=None):
"""Method called by the TreeBuilder when the end of a data segment
occurs.
"""
if self.current_data:
current_data = ''.join(self.current_data)
# If whitespace is not preserved, and this string contains
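A small usage sketch of the extended new_tag() shown in the hunk above:

    from bs4 import BeautifulSoup

    soup = BeautifulSoup('<div></div>', 'html.parser')
    link = soup.new_tag('a', href='/home')  # attribute values may be passed as kwattrs
    link.string = 'Home'
    soup.div.append(link)
    assert str(soup) == '<div><a href="/home">Home</a></div>'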
@@ -437,11 +582,12 @@ class BeautifulSoup(Tag):
not self.parse_only.search(current_data)):
return

containerClass = self.string_container(containerClass)
o = containerClass(current_data)
self.object_was_parsed(o)

def object_was_parsed(self, o, parent=None, most_recent_element=None):
"""Add an object to the parse tree."""
"""Method called by the TreeBuilder to integrate an object into the parse tree."""
if parent is None:
parent = self.currentTag
if most_recent_element is not None:

@@ -510,10 +656,19 @@ class BeautifulSoup(Tag):

def _popToTag(self, name, nsprefix=None, inclusivePop=True):
"""Pops the tag stack up to and including the most recent
instance of the given tag. If inclusivePop is false, pops the tag
stack up to but *not* including the most recent instance of
the given tag."""
#print "Popping to %s" % name
instance of the given tag.

If there are no open tags with the given name, nothing will be
popped.

:param name: Pop up to the most recent tag with this name.
:param nsprefix: The namespace prefix that goes with `name`.
:param inclusivePop: If this is false, pops the tag stack up
to but *not* including the most recent instance of the
given tag.

"""
#print("Popping to %s" % name)
if name == self.ROOT_TAG_NAME:
# The BeautifulSoup object itself can never be popped.
return

@@ -522,6 +677,8 @@ class BeautifulSoup(Tag):

stack_size = len(self.tagStack)
for i in range(stack_size - 1, 0, -1):
if not self.open_tag_counter.get(name):
break
t = self.tagStack[i]
if (name == t.name and nsprefix == t.prefix):
if inclusivePop:

@@ -531,16 +688,24 @@ class BeautifulSoup(Tag):

return most_recently_popped

def handle_starttag(self, name, namespace, nsprefix, attrs):
"""Push a start tag on to the stack.
def handle_starttag(self, name, namespace, nsprefix, attrs, sourceline=None,
sourcepos=None):
"""Called by the tree builder when a new tag is encountered.

If this method returns None, the tag was rejected by the
:param name: Name of the tag.
:param nsprefix: Namespace prefix for the tag.
:param attrs: A dictionary of attribute values.
:param sourceline: The line number where this tag was found in its
source document.
:param sourcepos: The character position within `sourceline` where this
tag was found.

If this method returns None, the tag was rejected by an active
SoupStrainer. You should proceed as if the tag had not occurred
in the document. For instance, if this was a self-closing tag,
don't call handle_endtag.
"""

# print "Start tag %s: %s" % (name, attrs)
# print("Start tag %s: %s" % (name, attrs))
self.endData()

if (self.parse_only and len(self.tagStack) <= 1

@@ -548,8 +713,11 @@ class BeautifulSoup(Tag):
or not self.parse_only.search_tag(name, attrs))):
return None

tag = Tag(self, self.builder, name, namespace, nsprefix, attrs,
self.currentTag, self._most_recent_element)
tag = self.element_classes.get(Tag, Tag)(
self, self.builder, name, namespace, nsprefix, attrs,
self.currentTag, self._most_recent_element,
sourceline=sourceline, sourcepos=sourcepos
)
if tag is None:
return tag
if self._most_recent_element is not None:

@@ -559,22 +727,38 @@ class BeautifulSoup(Tag):
return tag

def handle_endtag(self, name, nsprefix=None):
#print "End tag: " + name
"""Called by the tree builder when an ending tag is encountered.

:param name: Name of the tag.
:param nsprefix: Namespace prefix for the tag.
"""
#print("End tag: " + name)
self.endData()
self._popToTag(name, nsprefix)

def handle_data(self, data):
"""Called by the tree builder when a chunk of textual data is encountered."""
self.current_data.append(data)


def decode(self, pretty_print=False,
eventual_encoding=DEFAULT_OUTPUT_ENCODING,
formatter="minimal"):
"""Returns a string or Unicode representation of this document.
To get Unicode, pass None for encoding."""
"""Returns a string or Unicode representation of the parse tree
as an HTML or XML document.

:param pretty_print: If this is True, indentation will be used to
make the document more readable.
:param eventual_encoding: The encoding of the final document.
If this is None, the document will be a Unicode string.
"""
if self.is_xml:
# Print the XML declaration
encoding_part = ''
if eventual_encoding in PYTHON_SPECIFIC_ENCODINGS:
# This is a special Python encoding; it can't actually
# go into an XML document because it means nothing
# outside of Python.
eventual_encoding = None
if eventual_encoding != None:
encoding_part = ' encoding="%s"' % eventual_encoding
prefix = '<?xml version="1.0"%s?>\n' % encoding_part
@ -587,7 +771,7 @@ class BeautifulSoup(Tag):
|
|||
return prefix + super(BeautifulSoup, self).decode(
|
||||
indent_level, eventual_encoding, formatter)
|
||||
|
||||
# Alias to make it easier to type import: 'from bs4 import _soup'
|
||||
# Aliases to make it easier to get started quickly, e.g. 'from bs4 import _soup'
|
||||
_s = BeautifulSoup
|
||||
_soup = BeautifulSoup
|
||||
|
||||
|
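The sourceline/sourcepos parameters threaded through handle_starttag above surface as attributes on every Tag. A minimal sketch of what that looks like after this upgrade (the markup string is illustrative):

    from bs4 import BeautifulSoup

    # html.parser tracks positions, so store_line_numbers defaults to on.
    soup = BeautifulSoup("<html>\n<p>hello</p></html>", "html.parser")
    print(soup.p.sourceline, soup.p.sourcepos)  # e.g. 2 0

    # Pass store_line_numbers=False to skip the bookkeeping entirely.
    soup = BeautifulSoup("<p>hello</p>", "html.parser", store_line_numbers=False)
    print(soup.p.sourceline)  # None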
@@ -603,14 +787,18 @@ class BeautifulStoneSoup(BeautifulSoup):


class StopParsing(Exception):
    """Exception raised by a TreeBuilder if it's unable to continue parsing."""
    pass

class FeatureNotFound(ValueError):
    """Exception raised by the BeautifulSoup constructor if no parser with the
    requested features is found.
    """
    pass


#By default, act as an HTML pretty-printer.
#If this file is run as a script, act as an HTML pretty-printer.
if __name__ == '__main__':
    import sys
    soup = BeautifulSoup(sys.stdin)
    print(soup.prettify())
    print((soup.prettify()))

libs/bs4/builder/__init__.py
@@ -7,8 +7,11 @@ import sys
from bs4.element import (
    CharsetMetaAttributeValue,
    ContentMetaAttributeValue,
    Stylesheet,
    Script,
    TemplateString,
    nonwhitespace_re
    )
    )

__all__ = [
    'HTMLTreeBuilder',

@@ -27,18 +30,33 @@ HTML_5 = 'html5'


class TreeBuilderRegistry(object):

    """A way of looking up TreeBuilder subclasses by their name or by desired
    features.
    """

    def __init__(self):
        self.builders_for_feature = defaultdict(list)
        self.builders = []

    def register(self, treebuilder_class):
        """Register a treebuilder based on its advertised features."""
        """Register a treebuilder based on its advertised features.

        :param treebuilder_class: A subclass of TreeBuilder. Its .features
           attribute should list its features.
        """
        for feature in treebuilder_class.features:
            self.builders_for_feature[feature].insert(0, treebuilder_class)
        self.builders.insert(0, treebuilder_class)

    def lookup(self, *features):
        """Look up a TreeBuilder subclass with the desired features.

        :param features: A list of features to look for. If none are
            provided, the most recently registered TreeBuilder subclass
            will be used.
        :return: A TreeBuilder subclass, or None if there's no
            registered subclass with all the requested features.
        """
        if len(self.builders) == 0:
            # There are no builders at all.
            return None

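This registry is what the BeautifulSoup constructor consults when you pass a features string. A short, hedged illustration of the lookup API described in this hunk:

    from bs4.builder import builder_registry

    # The most recently registered builder advertising every requested
    # feature, or None if nothing matches.
    cls = builder_registry.lookup('html', 'fast')
    if cls is None:
        cls = builder_registry.lookup()  # most recently registered overall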
@@ -81,7 +99,7 @@ class TreeBuilderRegistry(object):
builder_registry = TreeBuilderRegistry()

class TreeBuilder(object):
    """Turn a document into a Beautiful Soup object tree."""
    """Turn a textual document into a Beautiful Soup object tree."""

    NAME = "[Unknown tree builder]"
    ALTERNATE_NAMES = []

@@ -96,24 +114,53 @@ class TreeBuilder(object):
    # comma-separated list of CDATA, rather than a single CDATA.
    DEFAULT_CDATA_LIST_ATTRIBUTES = {}

    # Whitespace should be preserved inside these tags.
    DEFAULT_PRESERVE_WHITESPACE_TAGS = set()

    # The textual contents of tags with these names should be
    # instantiated with some class other than NavigableString.
    DEFAULT_STRING_CONTAINERS = {}

    USE_DEFAULT = object()

    # Most parsers don't keep track of line numbers.
    TRACKS_LINE_NUMBERS = False

    def __init__(self, multi_valued_attributes=USE_DEFAULT, preserve_whitespace_tags=USE_DEFAULT):
    def __init__(self, multi_valued_attributes=USE_DEFAULT,
                 preserve_whitespace_tags=USE_DEFAULT,
                 store_line_numbers=USE_DEFAULT,
                 string_containers=USE_DEFAULT,
                 ):
        """Constructor.

        :param multi_valued_attributes: If this is set to None, the
          TreeBuilder will not turn any values for attributes like
          'class' into lists. Setting this do a dictionary will
          customize this behavior; look at DEFAULT_CDATA_LIST_ATTRIBUTES
          for an example.
          TreeBuilder will not turn any values for attributes like
          'class' into lists. Setting this to a dictionary will
          customize this behavior; look at DEFAULT_CDATA_LIST_ATTRIBUTES
          for an example.

          Internally, these are called "CDATA list attributes", but that
          probably doesn't make sense to an end-user, so the argument name
          is `multi_valued_attributes`.
          Internally, these are called "CDATA list attributes", but that
          probably doesn't make sense to an end-user, so the argument name
          is `multi_valued_attributes`.

        :param preserve_whitespace_tags:
        :param preserve_whitespace_tags: A list of tags to treat
          the way <pre> tags are treated in HTML. Tags in this list
          are immune from pretty-printing; their contents will always be
          output as-is.

        :param string_containers: A dictionary mapping tag names to
          the classes that should be instantiated to contain the textual
          contents of those tags. The default is to use NavigableString
          for every tag, no matter what the name. You can override the
          default by changing DEFAULT_STRING_CONTAINERS.

        :param store_line_numbers: If the parser keeps track of the
          line numbers and positions of the original markup, that
          information will, by default, be stored in each corresponding
          `Tag` object. You can turn this off by passing
          store_line_numbers=False. If the parser you're using doesn't
          keep track of this information, then setting store_line_numbers=True
          will do nothing.
        """
        self.soup = None
        if multi_valued_attributes is self.USE_DEFAULT:
@@ -122,14 +169,27 @@ class TreeBuilder(object):
        if preserve_whitespace_tags is self.USE_DEFAULT:
            preserve_whitespace_tags = self.DEFAULT_PRESERVE_WHITESPACE_TAGS
        self.preserve_whitespace_tags = preserve_whitespace_tags

        if store_line_numbers == self.USE_DEFAULT:
            store_line_numbers = self.TRACKS_LINE_NUMBERS
        self.store_line_numbers = store_line_numbers
        if string_containers == self.USE_DEFAULT:
            string_containers = self.DEFAULT_STRING_CONTAINERS
        self.string_containers = string_containers

    def initialize_soup(self, soup):
        """The BeautifulSoup object has been initialized and is now
        being associated with the TreeBuilder.

        :param soup: A BeautifulSoup object.
        """
        self.soup = soup

    def reset(self):
        """Do any work necessary to reset the underlying parser
        for a new document.

        By default, this does nothing.
        """
        pass

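These constructor options are normally reached through BeautifulSoup, which forwards unrecognized keyword arguments to the tree builder. An illustrative sketch of multi_valued_attributes, the first option documented above:

    from bs4 import BeautifulSoup

    markup = '<p class="body strikeout"></p>'

    # By default 'class' is a CDATA list attribute, so it parses to a list.
    soup = BeautifulSoup(markup, 'html.parser')
    print(soup.p['class'])  # ['body', 'strikeout']

    # multi_valued_attributes=None switches the conversion off.
    soup = BeautifulSoup(markup, 'html.parser', multi_valued_attributes=None)
    print(soup.p['class'])  # 'body strikeout'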
    def can_be_empty_element(self, tag_name):

@@ -141,24 +201,58 @@ class TreeBuilder(object):
        For instance: an HTMLBuilder does not consider a <p> tag to be
        an empty-element tag (it's not in
        HTMLBuilder.empty_element_tags). This means an empty <p> tag
        will be presented as "<p></p>", not "<p />".
        will be presented as "<p></p>", not "<p/>" or "<p>".

        The default implementation has no opinion about which tags are
        empty-element tags, so a tag will be presented as an
        empty-element tag if and only if it has no contents.
        "<foo></foo>" will become "<foo />", and "<foo>bar</foo>" will
        empty-element tag if and only if it has no children.
        "<foo></foo>" will become "<foo/>", and "<foo>bar</foo>" will
        be left alone.

        :param tag_name: The name of a markup tag.
        """
        if self.empty_element_tags is None:
            return True
        return tag_name in self.empty_element_tags

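The observable effect of can_be_empty_element is in how empty tags are re-serialized; a quick demonstration:

    from bs4 import BeautifulSoup

    # <p> is not an empty-element tag in HTML, so an empty one round-trips
    # as "<p></p>"; <br> is, so it comes out as "<br/>".
    soup = BeautifulSoup("<p></p><br>", "html.parser")
    print(soup.decode())  # <p></p><br/>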
    def feed(self, markup):
        """Run some incoming markup through some parsing process,
        populating the `BeautifulSoup` object in self.soup.

        This method is not implemented in TreeBuilder; it must be
        implemented in subclasses.

        :return: None.
        """
        raise NotImplementedError()

    def prepare_markup(self, markup, user_specified_encoding=None,
                       document_declared_encoding=None):
        return markup, None, None, False
                       document_declared_encoding=None, exclude_encodings=None):
        """Run any preliminary steps necessary to make incoming markup
        acceptable to the parser.

        :param markup: Some markup -- probably a bytestring.
        :param user_specified_encoding: The user asked to try this encoding.
        :param document_declared_encoding: The markup itself claims to be
            in this encoding. NOTE: This argument is not used by the
            calling code and can probably be removed.
        :param exclude_encodings: The user asked _not_ to try any of
            these encodings.

        :yield: A series of 4-tuples:
         (markup, encoding, declared encoding,
          has undergone character replacement)

        Each 4-tuple represents a strategy for converting the
        document to Unicode and parsing it. Each strategy will be tried
        in turn.

        By default, the only strategy is to parse the markup
        as-is. See `LXMLTreeBuilderForXML` and
        `HTMLParserTreeBuilder` for implementations that take into
        account the quirks of particular parsers.
        """
        yield markup, None, None, False

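Together, feed() and prepare_markup() are the whole subclass contract. A minimal hypothetical subclass, purely to show the shape of the generator protocol (the class name and its trivial parsing logic are invented for illustration):

    from bs4.builder import TreeBuilder

    class TrivialTreeBuilder(TreeBuilder):
        # Hypothetical builder that treats the whole document as one text node.
        NAME = 'trivial'
        features = [NAME]
        is_xml = False

        def prepare_markup(self, markup, user_specified_encoding=None,
                           document_declared_encoding=None, exclude_encodings=None):
            # A single strategy: hand the markup to feed() unchanged.
            yield markup, None, None, False

        def feed(self, markup):
            self.soup.handle_data(markup)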
    def test_fragment_to_document(self, fragment):
        """Wrap an HTML fragment to make it look like a document.

@@ -170,16 +264,36 @@ class TreeBuilder(object):
        results against other HTML fragments.

        This method should not be used outside of tests.

        :param fragment: A string -- fragment of HTML.
        :return: A string -- a full HTML document.
        """
        return fragment

    def set_up_substitutions(self, tag):
        """Set up any substitutions that will need to be performed on
        a `Tag` when it's output as a string.

        By default, this does nothing. See `HTMLTreeBuilder` for a
        case where this is used.

        :param tag: A `Tag`
        :return: Whether or not a substitution was performed.
        """
        return False

    def _replace_cdata_list_attribute_values(self, tag_name, attrs):
        """Replaces class="foo bar" with class=["foo", "bar"]
        """When an attribute value is associated with a tag that can
        have multiple values for that attribute, convert the string
        value to a list of strings.

        Modifies its input in place.
        Basically, replaces class="foo bar" with class=["foo", "bar"]

        NOTE: This method modifies its input in place.

        :param tag_name: The name of a tag.
        :param attrs: A dictionary containing the tag's attributes.
           Any appropriate attribute values will be modified in place.
        """
        if not attrs:
            return attrs

@@ -207,7 +321,11 @@ class TreeBuilder(object):
        return attrs

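That in-place rewrite is what makes multi-valued attributes come back as lists while single-valued ones stay strings; for example:

    from bs4 import BeautifulSoup

    soup = BeautifulSoup('<a class="foo bar" id="baz"></a>', 'html.parser')
    print(soup.a['class'])  # ['foo', 'bar'] -- 'class' is a CDATA list attribute
    print(soup.a['id'])     # 'baz'         -- 'id' keeps a single value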
class SAXTreeBuilder(TreeBuilder):
    """A Beautiful Soup treebuilder that listens for SAX events."""
    """A Beautiful Soup treebuilder that listens for SAX events.

    This is not currently used for anything, but it demonstrates
    how a simple TreeBuilder would work.
    """

    def feed(self, markup):
        raise NotImplementedError()

@@ -217,11 +335,11 @@ class SAXTreeBuilder(TreeBuilder):

    def startElement(self, name, attrs):
        attrs = dict((key[1], value) for key, value in list(attrs.items()))
        #print "Start %s, %r" % (name, attrs)
        #print("Start %s, %r" % (name, attrs))
        self.soup.handle_starttag(name, attrs)

    def endElement(self, name):
        #print "End %s" % name
        #print("End %s" % name)
        self.soup.handle_endtag(name)

    def startElementNS(self, nsTuple, nodeName, attrs):

@@ -271,6 +389,22 @@ class HTMLTreeBuilder(TreeBuilder):
    # but it may do so eventually, and this information is available if
    # you need to use it.
    block_elements = set(["address", "article", "aside", "blockquote", "canvas", "dd", "div", "dl", "dt", "fieldset", "figcaption", "figure", "footer", "form", "h1", "h2", "h3", "h4", "h5", "h6", "header", "hr", "li", "main", "nav", "noscript", "ol", "output", "p", "pre", "section", "table", "tfoot", "ul", "video"])

    # The HTML standard defines an unusual content model for these tags.
    # We represent this by using a string class other than NavigableString
    # inside these tags.
    #
    # I made this list by going through the HTML spec
    # (https://html.spec.whatwg.org/#metadata-content) and looking for
    # "metadata content" elements that can contain strings.
    #
    # TODO: Arguably <noscript> could go here but it seems
    # qualitatively different from the other tags.
    DEFAULT_STRING_CONTAINERS = {
        'style': Stylesheet,
        'script': Script,
        'template': TemplateString,
    }

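The new DEFAULT_STRING_CONTAINERS mapping is what makes the text inside these three tags come back as specialized NavigableString subclasses; a quick check:

    from bs4 import BeautifulSoup
    from bs4.element import Script, Stylesheet

    soup = BeautifulSoup("<style>p {}</style><script>x = 1</script>", "html.parser")
    print(type(soup.style.string))                 # <class 'bs4.element.Stylesheet'>
    print(isinstance(soup.script.string, Script))  # True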
    # The HTML standard defines these attributes as containing a
    # space-separated list of values, not a single value. That is,

@@ -299,6 +433,16 @@ class HTMLTreeBuilder(TreeBuilder):
    DEFAULT_PRESERVE_WHITESPACE_TAGS = set(['pre', 'textarea'])

    def set_up_substitutions(self, tag):
        """Replace the declared encoding in a <meta> tag with a placeholder,
        to be substituted when the tag is output to a string.

        An HTML document may come in to Beautiful Soup as one
        encoding, but exit in a different encoding, and the <meta> tag
        needs to be changed to reflect this.

        :param tag: A `Tag`
        :return: Whether or not a substitution was performed.
        """
        # We are only interested in <meta> tags
        if tag.name != 'meta':
            return False

@@ -333,8 +477,7 @@ class HTMLTreeBuilder(TreeBuilder):

def register_treebuilders_from(module):
    """Copy TreeBuilders from the given module into this module."""
    # I'm fairly sure this is not the best way to do this.
    this_module = sys.modules['bs4.builder']
    this_module = sys.modules[__name__]
    for name in module.__all__:
        obj = getattr(module, name)


@@ -345,12 +488,22 @@ def register_treebuilders_from(module):
            this_module.builder_registry.register(obj)

class ParserRejectedMarkup(Exception):
    pass

    """An Exception to be raised when the underlying parser simply
    refuses to parse the given markup.
    """
    def __init__(self, message_or_exception):
        """Explain why the parser rejected the given markup, either
        with a textual explanation or another exception.
        """
        if isinstance(message_or_exception, Exception):
            e = message_or_exception
            message_or_exception = "%s: %s" % (e.__class__.__name__, str(e))
        super(ParserRejectedMarkup, self).__init__(message_or_exception)

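With the constructor added above, ParserRejectedMarkup can wrap another exception as well as a plain message; both forms, sketched:

    from bs4.builder import ParserRejectedMarkup

    try:
        raise ValueError("bad byte at offset 0")
    except ValueError as e:
        print(ParserRejectedMarkup(e))         # ValueError: bad byte at offset 0

    print(ParserRejectedMarkup("parser gave up"))  # parser gave up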
# Builders are registered in reverse order of priority, so that custom
# builder registrations will take precedence. In general, we want lxml
# to take precedence over html5lib, because it's faster. And we only
# want to use HTMLParser as a last result.
# want to use HTMLParser as a last resort.
from . import _htmlparser
register_treebuilders_from(_htmlparser)
try:

libs/bs4/builder/_html5lib.py

@@ -39,12 +39,27 @@ except ImportError as e:
    new_html5lib = True

class HTML5TreeBuilder(HTMLTreeBuilder):
    """Use html5lib to build a tree."""
    """Use html5lib to build a tree.

    Note that this TreeBuilder does not support some features common
    to HTML TreeBuilders. Some of these features could theoretically
    be implemented, but at the very least it's quite difficult,
    because html5lib moves the parse tree around as it's being built.

    * This TreeBuilder doesn't use different subclasses of NavigableString
      based on the name of the tag in which the string was found.

    * You can't use a SoupStrainer to parse only part of a document.
    """

    NAME = "html5lib"

    features = [NAME, PERMISSIVE, HTML_5, HTML]

    # html5lib can tell us which line number and position in the
    # original file is the source of an element.
    TRACKS_LINE_NUMBERS = True

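As the expanded docstring notes, html5lib trades some Beautiful Soup features for browser-grade parsing, and it now reports line numbers. Selecting it is just a matter of the features string (html5lib must be installed):

    from bs4 import BeautifulSoup

    soup = BeautifulSoup("<p>unclosed", "html5lib")
    print(soup.p.sourceline)  # position reported by html5lib's tokenizer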
    def prepare_markup(self, markup, user_specified_encoding,
                       document_declared_encoding=None, exclude_encodings=None):
        # Store the user-specified encoding for use later on.

@@ -62,7 +77,7 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
        if self.soup.parse_only is not None:
            warnings.warn("You provided a value for parse_only, but the html5lib tree builder doesn't support parse_only. The entire document will be parsed.")
        parser = html5lib.HTMLParser(tree=self.create_treebuilder)

        self.underlying_builder.parser = parser
        extra_kwargs = dict()
        if not isinstance(markup, str):
            if new_html5lib:

@@ -70,7 +85,7 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
            else:
                extra_kwargs['encoding'] = self.user_specified_encoding
        doc = parser.parse(markup, **extra_kwargs)


        # Set the character encoding detected by the tokenizer.
        if isinstance(markup, str):
            # We need to special-case this because html5lib sets

@@ -84,10 +99,13 @@ class HTML5TreeBuilder(HTMLTreeBuilder):
                # with other tree builders.
                original_encoding = original_encoding.name
            doc.original_encoding = original_encoding

        self.underlying_builder.parser = None

    def create_treebuilder(self, namespaceHTMLElements):
        self.underlying_builder = TreeBuilderForHtml5lib(
            namespaceHTMLElements, self.soup)
            namespaceHTMLElements, self.soup,
            store_line_numbers=self.store_line_numbers
        )
        return self.underlying_builder

    def test_fragment_to_document(self, fragment):

@@ -96,15 +114,29 @@ class HTML5TreeBuilder(HTMLTreeBuilder):


class TreeBuilderForHtml5lib(treebuilder_base.TreeBuilder):

    def __init__(self, namespaceHTMLElements, soup=None):

    def __init__(self, namespaceHTMLElements, soup=None,
                 store_line_numbers=True, **kwargs):
        if soup:
            self.soup = soup
        else:
            from bs4 import BeautifulSoup
            self.soup = BeautifulSoup("", "html.parser")
            # TODO: Why is the parser 'html.parser' here? To avoid an
            # infinite loop?
            self.soup = BeautifulSoup(
                "", "html.parser", store_line_numbers=store_line_numbers,
                **kwargs
            )
            # TODO: What are **kwargs exactly? Should they be passed in
            # here in addition to/instead of being passed to the BeautifulSoup
            # constructor?
        super(TreeBuilderForHtml5lib, self).__init__(namespaceHTMLElements)

        # This will be set later to an html5lib.html5parser.HTMLParser
        # object, which we can use to track the current line number.
        self.parser = None
        self.store_line_numbers = store_line_numbers

    def documentClass(self):
        self.soup.reset()
        return Element(self.soup, self.soup, None)

@@ -118,7 +150,16 @@ class TreeBuilderForHtml5lib(treebuilder_base.TreeBuilder):
        self.soup.object_was_parsed(doctype)

    def elementClass(self, name, namespace):
        tag = self.soup.new_tag(name, namespace)
        kwargs = {}
        if self.parser and self.store_line_numbers:
            # This represents the point immediately after the end of the
            # tag. We don't know when the tag started, but we do know
            # where it ended -- the character just before this one.
            sourceline, sourcepos = self.parser.tokenizer.stream.position()
            kwargs['sourceline'] = sourceline
            kwargs['sourcepos'] = sourcepos-1
        tag = self.soup.new_tag(name, namespace, **kwargs)

        return Element(tag, self.soup, namespace)

    def commentClass(self, data):

@@ -126,6 +167,8 @@ class TreeBuilderForHtml5lib(treebuilder_base.TreeBuilder):

    def fragmentClass(self):
        from bs4 import BeautifulSoup
        # TODO: Why is the parser 'html.parser' here? To avoid an
        # infinite loop?
        self.soup = BeautifulSoup("", "html.parser")
        self.soup.name = "[document_fragment]"
        return Element(self.soup, self.soup, None)

@@ -287,9 +330,7 @@ class Element(treebuilder_base.Node):
        return AttrList(self.element)

    def setAttributes(self, attributes):

        if attributes is not None and len(attributes) > 0:

            converted_attributes = []
            for name, value in list(attributes.items()):
                if isinstance(name, tuple):

@@ -334,9 +375,9 @@ class Element(treebuilder_base.Node):

    def reparentChildren(self, new_parent):
        """Move all of this tag's children into another tag."""
        # print "MOVE", self.element.contents
        # print "FROM", self.element
        # print "TO", new_parent.element
        # print("MOVE", self.element.contents)
        # print("FROM", self.element)
        # print("TO", new_parent.element)

        element = self.element
        new_parent_element = new_parent.element

@@ -394,9 +435,9 @@ class Element(treebuilder_base.Node):
        element.contents = []
        element.next_element = final_next_element

        # print "DONE WITH MOVE"
        # print "FROM", self.element
        # print "TO", new_parent_element
        # print("DONE WITH MOVE")
        # print("FROM", self.element)
        # print("TO", new_parent_element)

    def cloneNode(self):
        tag = self.soup.new_tag(self.element.name, self.namespace)

libs/bs4/builder/_htmlparser.py
@@ -53,8 +53,30 @@ from bs4.builder import (
HTMLPARSER = 'html.parser'

class BeautifulSoupHTMLParser(HTMLParser):
    """A subclass of the Python standard library's HTMLParser class, which
    listens for HTMLParser events and translates them into calls
    to Beautiful Soup's tree construction API.
    """

    # Strategies for handling duplicate attributes
    IGNORE = 'ignore'
    REPLACE = 'replace'

    def __init__(self, *args, **kwargs):
        """Constructor.

        :param on_duplicate_attribute: A strategy for what to do if a
            tag includes the same attribute more than once. Accepted
            values are: REPLACE (replace earlier values with later
            ones, the default), IGNORE (keep the earliest value
            encountered), or a callable. A callable must take three
            arguments: the dictionary of attributes already processed,
            the name of the duplicate attribute, and the most recent value
            encountered.
        """
        self.on_duplicate_attribute = kwargs.pop(
            'on_duplicate_attribute', self.REPLACE
        )
        HTMLParser.__init__(self, *args, **kwargs)

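The new on_duplicate_attribute knob is reachable from the BeautifulSoup constructor, since the html.parser builder pops it out of kwargs (see further down this file). Illustrative usage:

    from bs4 import BeautifulSoup

    markup = '<a href="first.html" href="second.html">link</a>'

    print(BeautifulSoup(markup, 'html.parser').a['href'])
    # second.html -- REPLACE is the default strategy

    soup = BeautifulSoup(markup, 'html.parser', on_duplicate_attribute='ignore')
    print(soup.a['href'])  # first.html -- the earliest value wins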
        # Keep a list of empty-element tags that were encountered

@@ -67,20 +89,26 @@ class BeautifulSoupHTMLParser(HTMLParser):
        self.already_closed_empty_element = []

    def error(self, msg):
        """In Python 3, HTMLParser subclasses must implement error(), although this
        requirement doesn't appear to be documented.
        """In Python 3, HTMLParser subclasses must implement error(), although
        this requirement doesn't appear to be documented.

        In Python 2, HTMLParser implements error() as raising an exception.
        In Python 2, HTMLParser implements error() by raising an exception,
        which we don't want to do.

        In any event, this method is called only on very strange markup and our best strategy
        is to pretend it didn't happen and keep going.
        In any event, this method is called only on very strange
        markup and our best strategy is to pretend it didn't happen
        and keep going.
        """
        warnings.warn(msg)

    def handle_startendtag(self, name, attrs):
        # This is only called when the markup looks like
        # <tag/>.
        """Handle an incoming empty-element tag.

        This is only called when the markup looks like <tag/>.

        :param name: Name of the tag.
        :param attrs: Dictionary of the tag's attributes.
        """
        # is_startend() tells handle_starttag not to close the tag
        # just because its name matches a known empty-element tag. We
        # know that this is an empty-element tag and we want to call

@@ -89,6 +117,14 @@ class BeautifulSoupHTMLParser(HTMLParser):
        self.handle_endtag(name)

    def handle_starttag(self, name, attrs, handle_empty_element=True):
        """Handle an opening tag, e.g. '<tag>'

        :param name: Name of the tag.
        :param attrs: Dictionary of the tag's attributes.
        :param handle_empty_element: True if this tag is known to be
            an empty-element tag (i.e. there is not expected to be any
            closing tag).
        """
        # XXX namespace
        attr_dict = {}
        for key, value in attrs:

@@ -96,10 +132,26 @@ class BeautifulSoupHTMLParser(HTMLParser):
            # for consistency with the other tree builders.
            if value is None:
                value = ''
            attr_dict[key] = value
            if key in attr_dict:
                # A single attribute shows up multiple times in this
                # tag. How to handle it depends on the
                # on_duplicate_attribute setting.
                on_dupe = self.on_duplicate_attribute
                if on_dupe == self.IGNORE:
                    pass
                elif on_dupe in (None, self.REPLACE):
                    attr_dict[key] = value
                else:
                    on_dupe(attr_dict, key, value)
            else:
                attr_dict[key] = value
            attrvalue = '""'
        #print "START", name
        tag = self.soup.handle_starttag(name, None, None, attr_dict)
        #print("START", name)
        sourceline, sourcepos = self.getpos()
        tag = self.soup.handle_starttag(
            name, None, None, attr_dict, sourceline=sourceline,
            sourcepos=sourcepos
        )
        if tag and tag.is_empty_element and handle_empty_element:
            # Unlike other parsers, html.parser doesn't send separate end tag
            # events for empty-element tags. (It's handled in

@@ -117,20 +169,34 @@ class BeautifulSoupHTMLParser(HTMLParser):
            self.already_closed_empty_element.append(name)

    def handle_endtag(self, name, check_already_closed=True):
        #print "END", name
        """Handle a closing tag, e.g. '</tag>'

        :param name: A tag name.
        :param check_already_closed: True if this tag is expected to
           be the closing portion of an empty-element tag,
           e.g. '<tag></tag>'.
        """
        #print("END", name)
        if check_already_closed and name in self.already_closed_empty_element:
            # This is a redundant end tag for an empty-element tag.
            # We've already called handle_endtag() for it, so just
            # check it off the list.
            # print "ALREADY CLOSED", name
            #print("ALREADY CLOSED", name)
            self.already_closed_empty_element.remove(name)
        else:
            self.soup.handle_endtag(name)

    def handle_data(self, data):
        """Handle some textual data that shows up between tags."""
        self.soup.handle_data(data)

    def handle_charref(self, name):
        """Handle a numeric character reference by converting it to the
        corresponding Unicode character and treating it as textual
        data.

        :param name: Character number, possibly in hexadecimal.
        """
        # XXX workaround for a bug in HTMLParser. Remove this once
        # it's fixed in all supported versions.
        # http://bugs.python.org/issue13633

@@ -164,6 +230,12 @@ class BeautifulSoupHTMLParser(HTMLParser):
        self.handle_data(data)

    def handle_entityref(self, name):
        """Handle a named entity reference by converting it to the
        corresponding Unicode character(s) and treating it as textual
        data.

        :param name: Name of the entity reference.
        """
        character = EntitySubstitution.HTML_ENTITY_TO_CHARACTER.get(name)
        if character is not None:
            data = character

@@ -177,21 +249,29 @@ class BeautifulSoupHTMLParser(HTMLParser):
        self.handle_data(data)

    def handle_comment(self, data):
        """Handle an HTML comment.

        :param data: The text of the comment.
        """
        self.soup.endData()
        self.soup.handle_data(data)
        self.soup.endData(Comment)

    def handle_decl(self, data):
        """Handle a DOCTYPE declaration.

        :param data: The text of the declaration.
        """
        self.soup.endData()
        if data.startswith("DOCTYPE "):
            data = data[len("DOCTYPE "):]
        elif data == 'DOCTYPE':
            # i.e. "<!DOCTYPE>"
            data = ''
        data = data[len("DOCTYPE "):]
        self.soup.handle_data(data)
        self.soup.endData(Doctype)

    def unknown_decl(self, data):
        """Handle a declaration of unknown type -- probably a CDATA block.

        :param data: The text of the declaration.
        """
        if data.upper().startswith('CDATA['):
            cls = CData
            data = data[len('CDATA['):]

@@ -202,47 +282,109 @@ class BeautifulSoupHTMLParser(HTMLParser):
        self.soup.endData(cls)

    def handle_pi(self, data):
        """Handle a processing instruction.

        :param data: The text of the instruction.
        """
        self.soup.endData()
        self.soup.handle_data(data)
        self.soup.endData(ProcessingInstruction)


class HTMLParserTreeBuilder(HTMLTreeBuilder):

    """A Beautiful Soup `TreeBuilder` that uses the `HTMLParser` parser,
    found in the Python standard library.
    """
    is_xml = False
    picklable = True
    NAME = HTMLPARSER
    features = [NAME, HTML, STRICT]

    # The html.parser knows which line number and position in the
    # original file is the source of an element.
    TRACKS_LINE_NUMBERS = True

    def __init__(self, parser_args=None, parser_kwargs=None, **kwargs):
        """Constructor.

        :param parser_args: Positional arguments to pass into
            the BeautifulSoupHTMLParser constructor, once it's
            invoked.
        :param parser_kwargs: Keyword arguments to pass into
            the BeautifulSoupHTMLParser constructor, once it's
            invoked.
        :param kwargs: Keyword arguments for the superclass constructor.
        """
        # Some keyword arguments will be pulled out of kwargs and placed
        # into parser_kwargs.
        extra_parser_kwargs = dict()
        for arg in ('on_duplicate_attribute',):
            if arg in kwargs:
                value = kwargs.pop(arg)
                extra_parser_kwargs[arg] = value
        super(HTMLParserTreeBuilder, self).__init__(**kwargs)
        parser_args = parser_args or []
        parser_kwargs = parser_kwargs or {}
        parser_kwargs.update(extra_parser_kwargs)
        if CONSTRUCTOR_TAKES_STRICT and not CONSTRUCTOR_STRICT_IS_DEPRECATED:
            parser_kwargs['strict'] = False
        if CONSTRUCTOR_TAKES_CONVERT_CHARREFS:
            parser_kwargs['convert_charrefs'] = False
        self.parser_args = (parser_args, parser_kwargs)


    def prepare_markup(self, markup, user_specified_encoding=None,
                       document_declared_encoding=None, exclude_encodings=None):
        """
        :return: A 4-tuple (markup, original encoding, encoding
            declared within markup, whether any characters had to be
            replaced with REPLACEMENT CHARACTER).

        """Run any preliminary steps necessary to make incoming markup
        acceptable to the parser.

        :param markup: Some markup -- probably a bytestring.
        :param user_specified_encoding: The user asked to try this encoding.
        :param document_declared_encoding: The markup itself claims to be
            in this encoding.
        :param exclude_encodings: The user asked _not_ to try any of
            these encodings.

        :yield: A series of 4-tuples:
         (markup, encoding, declared encoding,
          has undergone character replacement)

        Each 4-tuple represents a strategy for converting the
        document to Unicode and parsing it. Each strategy will be tried
        in turn.
        """
        if isinstance(markup, str):
            # Parse Unicode as-is.
            yield (markup, None, None, False)
            return

        # Ask UnicodeDammit to sniff the most likely encoding.

        # This was provided by the end-user; treat it as a known
        # definite encoding per the algorithm laid out in the HTML5
        # spec. (See the EncodingDetector class for details.)
        known_definite_encodings = [user_specified_encoding]

        # This was found in the document; treat it as a slightly lower-priority
        # user encoding.
        user_encodings = [document_declared_encoding]

        try_encodings = [user_specified_encoding, document_declared_encoding]
        dammit = UnicodeDammit(markup, try_encodings, is_html=True,
                               exclude_encodings=exclude_encodings)
        dammit = UnicodeDammit(
            markup,
            known_definite_encodings=known_definite_encodings,
            user_encodings=user_encodings,
            is_html=True,
            exclude_encodings=exclude_encodings
        )
        yield (dammit.markup, dammit.original_encoding,
               dammit.declared_html_encoding,
               dammit.contains_replacement_characters)

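This hunk moves to the upgraded UnicodeDammit signature, where known-definite encodings outrank document-declared ones. A sketch against that new API, using the argument names exactly as they appear in this diff:

    from bs4.dammit import UnicodeDammit

    data = "Sacr\xe9 bleu!".encode("latin-1")
    dammit = UnicodeDammit(
        data,
        known_definite_encodings=["latin-1"],  # e.g. supplied by the end-user
        user_encodings=["utf-8"],              # e.g. declared in the document
        is_html=True,
    )
    print(dammit.unicode_markup)     # Sacré bleu!
    print(dammit.original_encoding)  # latin-1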
    def feed(self, markup):
        """Run some incoming markup through some parsing process,
        populating the `BeautifulSoup` object in self.soup.
        """
        args, kwargs = self.parser_args
        parser = BeautifulSoupHTMLParser(*args, **kwargs)
        parser.soup = self.soup

libs/bs4/builder/_lxml.py
@@ -57,9 +57,18 @@ class LXMLTreeBuilderForXML(TreeBuilder):

    DEFAULT_NSMAPS_INVERTED = _invert(DEFAULT_NSMAPS)

    # NOTE: If we parsed Element objects and looked at .sourceline,
    # we'd be able to see the line numbers from the original document.
    # But instead we build an XMLParser or HTMLParser object to serve
    # as the target of parse messages, and those messages don't include
    # line numbers.
    # See: https://bugs.launchpad.net/lxml/+bug/1846906

    def initialize_soup(self, soup):
        """Let the BeautifulSoup object know about the standard namespace
        mapping.

        :param soup: A `BeautifulSoup`.
        """
        super(LXMLTreeBuilderForXML, self).initialize_soup(soup)
        self._register_namespaces(self.DEFAULT_NSMAPS)

@@ -69,6 +78,8 @@ class LXMLTreeBuilderForXML(TreeBuilder):
        while parsing the document.

        This might be useful later on when creating CSS selectors.

        :param mapping: A dictionary mapping namespace prefixes to URIs.
        """
        for key, value in list(mapping.items()):
            if key and key not in self.soup._namespaces:

@@ -78,20 +89,31 @@ class LXMLTreeBuilderForXML(TreeBuilder):
                self.soup._namespaces[key] = value

    def default_parser(self, encoding):
        # This can either return a parser object or a class, which
        # will be instantiated with default arguments.
        """Find the default parser for the given encoding.

        :param encoding: A string.
        :return: Either a parser object or a class, which
            will be instantiated with default arguments.
        """
        if self._default_parser is not None:
            return self._default_parser
        return etree.XMLParser(
            target=self, strip_cdata=False, recover=True, encoding=encoding)

    def parser_for(self, encoding):
        """Instantiate an appropriate parser for the given encoding.

        :param encoding: A string.
        :return: A parser object such as an `etree.XMLParser`.
        """
        # Use the default parser.
        parser = self.default_parser(encoding)

        if isinstance(parser, Callable):
            # Instantiate the parser with default arguments
            parser = parser(target=self, strip_cdata=False, encoding=encoding)
            parser = parser(
                target=self, strip_cdata=False, recover=True, encoding=encoding
            )
        return parser

    def __init__(self, parser=None, empty_element_tags=None, **kwargs):

@@ -116,17 +138,31 @@ class LXMLTreeBuilderForXML(TreeBuilder):
    def prepare_markup(self, markup, user_specified_encoding=None,
                       exclude_encodings=None,
                       document_declared_encoding=None):
        """
        :yield: A series of 4-tuples.
        """Run any preliminary steps necessary to make incoming markup
        acceptable to the parser.

        lxml really wants to get a bytestring and convert it to
        Unicode itself. So instead of using UnicodeDammit to convert
        the bytestring to Unicode using different encodings, this
        implementation uses EncodingDetector to iterate over the
        encodings, and tell lxml to try to parse the document as each
        one in turn.

        :param markup: Some markup -- hopefully a bytestring.
        :param user_specified_encoding: The user asked to try this encoding.
        :param document_declared_encoding: The markup itself claims to be
            in this encoding.
        :param exclude_encodings: The user asked _not_ to try any of
            these encodings.

        :yield: A series of 4-tuples:
         (markup, encoding, declared encoding,
          has undergone character replacement)

        Each 4-tuple represents a strategy for parsing the document.
        Each 4-tuple represents a strategy for converting the
        document to Unicode and parsing it. Each strategy will be tried
        in turn.
        """
        # Instead of using UnicodeDammit to convert the bytestring to
        # Unicode using different encodings, use EncodingDetector to
        # iterate over the encodings, and tell lxml to try to parse
        # the document as each one in turn.
        is_html = not self.is_xml
        if is_html:
            self.processing_instruction_class = ProcessingInstruction

@@ -144,9 +180,19 @@ class LXMLTreeBuilderForXML(TreeBuilder):
            yield (markup.encode("utf8"), "utf8",
                   document_declared_encoding, False)

        try_encodings = [user_specified_encoding, document_declared_encoding]
        # This was provided by the end-user; treat it as a known
        # definite encoding per the algorithm laid out in the HTML5
        # spec. (See the EncodingDetector class for details.)
        known_definite_encodings = [user_specified_encoding]

        # This was found in the document; treat it as a slightly lower-priority
        # user encoding.
        user_encodings = [document_declared_encoding]
        detector = EncodingDetector(
            markup, try_encodings, is_html, exclude_encodings)
            markup, known_definite_encodings=known_definite_encodings,
            user_encodings=user_encodings, is_html=is_html,
            exclude_encodings=exclude_encodings
        )
        for encoding in detector.encodings:
            yield (detector.markup, encoding, document_declared_encoding, False)

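The lxml builders use the same two-tier priority but go through EncodingDetector directly, because lxml wants bytes plus a candidate encoding. A sketch of the detector loop this code relies on (keyword arguments as shown in the diff):

    from bs4.dammit import EncodingDetector

    data = '<html><meta charset="utf-8"><p>caf\u00e9</p></html>'.encode('utf-8')
    detector = EncodingDetector(
        data,
        known_definite_encodings=['utf-8'],
        user_encodings=[],
        is_html=True,
    )
    for encoding in detector.encodings:
        # Each candidate is handed to lxml in turn until one parses cleanly.
        print(encoding)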
@@ -169,7 +215,7 @@ class LXMLTreeBuilderForXML(TreeBuilder):
            self.parser.feed(data)
            self.parser.close()
        except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
            raise ParserRejectedMarkup(str(e))
            raise ParserRejectedMarkup(e)

    def close(self):
        self.nsmaps = [self.DEFAULT_NSMAPS_INVERTED]

@@ -288,7 +334,7 @@ class LXMLTreeBuilder(HTMLTreeBuilder, LXMLTreeBuilderForXML):
            self.parser.feed(markup)
            self.parser.close()
        except (UnicodeDecodeError, LookupError, etree.ParserError) as e:
            raise ParserRejectedMarkup(str(e))
            raise ParserRejectedMarkup(e)


    def test_fragment_to_document(self, fragment):

libs/bs4/dammit.py (2606 changes): file diff suppressed because it is too large.
@@ -20,9 +20,13 @@ import sys
import cProfile

def diagnose(data):
    """Diagnostic suite for isolating common problems."""
    print("Diagnostic running on Beautiful Soup %s" % __version__)
    print("Python version %s" % sys.version)
    """Diagnostic suite for isolating common problems.

    :param data: A string containing markup that needs to be explained.
    :return: None; diagnostics are printed to standard output.
    """
    print(("Diagnostic running on Beautiful Soup %s" % __version__))
    print(("Python version %s" % sys.version))

    basic_parsers = ["html.parser", "html5lib", "lxml"]
    for name in basic_parsers:

@@ -39,65 +43,76 @@ def diagnose(data):
        basic_parsers.append("lxml-xml")
        try:
            from lxml import etree
            print("Found lxml version %s" % ".".join(map(str,etree.LXML_VERSION)))
            print(("Found lxml version %s" % ".".join(map(str,etree.LXML_VERSION))))
        except ImportError as e:
            print (
            print(
                "lxml is not installed or couldn't be imported.")


    if 'html5lib' in basic_parsers:
        try:
            import html5lib
            print("Found html5lib version %s" % html5lib.__version__)
            print(("Found html5lib version %s" % html5lib.__version__))
        except ImportError as e:
            print (
            print(
                "html5lib is not installed or couldn't be imported.")

    if hasattr(data, 'read'):
        data = data.read()
    elif data.startswith("http:") or data.startswith("https:"):
        print('"%s" looks like a URL. Beautiful Soup is not an HTTP client.' % data)
        print(('"%s" looks like a URL. Beautiful Soup is not an HTTP client.' % data))
        print("You need to use some other library to get the document behind the URL, and feed that document to Beautiful Soup.")
        return
    else:
        try:
            if os.path.exists(data):
                print('"%s" looks like a filename. Reading data from the file.' % data)
                print(('"%s" looks like a filename. Reading data from the file.' % data))
                with open(data) as fp:
                    data = fp.read()
        except ValueError:
            # This can happen on some platforms when the 'filename' is
            # too long. Assume it's data and not a filename.
            pass
        print()
        print("")

    for parser in basic_parsers:
        print("Trying to parse your markup with %s" % parser)
        print(("Trying to parse your markup with %s" % parser))
        success = False
        try:
            soup = BeautifulSoup(data, features=parser)
            success = True
        except Exception as e:
            print("%s could not parse the markup." % parser)
            print(("%s could not parse the markup." % parser))
            traceback.print_exc()
        if success:
            print("Here's what %s did with the markup:" % parser)
            print(soup.prettify())
            print(("Here's what %s did with the markup:" % parser))
            print((soup.prettify()))

        print("-" * 80)
        print(("-" * 80))

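diagnose() is the entry point users are pointed at when filing bugs; a typical invocation (the filename is hypothetical):

    from bs4.diagnose import diagnose

    with open("problem.html") as fp:
        diagnose(fp.read())
    # Prints version information, then what each installed parser
    # makes of the markup.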
def lxml_trace(data, html=True, **kwargs):
    """Print out the lxml events that occur during parsing.

    This lets you see how lxml parses a document when no Beautiful
    Soup code is running.
    Soup code is running. You can use this to determine whether
    an lxml-specific problem is in Beautiful Soup's lxml tree builders
    or in lxml itself.

    :param data: Some markup.
    :param html: If True, markup will be parsed with lxml's HTML parser.
        if False, lxml's XML parser will be used.
    """
    from lxml import etree
    for event, element in etree.iterparse(StringIO(data), html=html, **kwargs):
        print(("%s, %4s, %s" % (event, element.tag, element.text)))

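Per the expanded docstring, lxml_trace shows raw lxml events with no Beautiful Soup code involved; for example:

    from bs4.diagnose import lxml_trace

    lxml_trace("<p>one</p><p>two</p>")
    # end,    p, one
    # end,    p, two
    # ...plus the html/body wrapper elements lxml inserts, exactly as
    # lxml's iterparse reports them.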
class AnnouncingParser(HTMLParser):
    """Announces HTMLParser parse events, without doing anything else."""
    """Subclass of HTMLParser that announces parse events, without doing
    anything else.

    You can use this to get a picture of how html.parser sees a given
    document. The easiest way to do this is to call `htmlparser_trace`.
    """

    def _p(self, s):
        print(s)

@@ -134,6 +149,8 @@ def htmlparser_trace(data):

    This lets you see how HTMLParser parses a document when no
    Beautiful Soup code is running.

    :param data: Some markup.
    """
    parser = AnnouncingParser()
    parser.feed(data)

@@ -154,7 +171,7 @@ def rword(length=5):

def rsentence(length=4):
    "Generate a random sentence-like string."
    return " ".join(rword(random.randint(4,9)) for i in list(range(length)))
    return " ".join(rword(random.randint(4,9)) for i in range(length))

def rdoc(num_elements=1000):
    """Randomly generate an invalid HTML document."""

@@ -176,9 +193,9 @@ def rdoc(num_elements=1000):

def benchmark_parsers(num_elements=100000):
    """Very basic head-to-head performance benchmark."""
    print("Comparative parser benchmark on Beautiful Soup %s" % __version__)
    print(("Comparative parser benchmark on Beautiful Soup %s" % __version__))
    data = rdoc(num_elements)
    print("Generated a large invalid HTML document (%d bytes)." % len(data))
    print(("Generated a large invalid HTML document (%d bytes)." % len(data)))

    for parser in ["lxml", ["lxml", "html"], "html5lib", "html.parser"]:
        success = False

@@ -188,26 +205,26 @@ def benchmark_parsers(num_elements=100000):
            b = time.time()
            success = True
        except Exception as e:
            print("%s could not parse the markup." % parser)
            print(("%s could not parse the markup." % parser))
            traceback.print_exc()
        if success:
            print("BS4+%s parsed the markup in %.2fs." % (parser, b-a))
            print(("BS4+%s parsed the markup in %.2fs." % (parser, b-a)))

    from lxml import etree
    a = time.time()
    etree.HTML(data)
    b = time.time()
    print("Raw lxml parsed the markup in %.2fs." % (b-a))
    print(("Raw lxml parsed the markup in %.2fs." % (b-a)))

    import html5lib
    parser = html5lib.HTMLParser()
    a = time.time()
    parser.parse(data)
    b = time.time()
    print("Raw html5lib parsed the markup in %.2fs." % (b-a))
    print(("Raw html5lib parsed the markup in %.2fs." % (b-a)))

def profile(num_elements=100000, parser="lxml"):

    """Use Python's profiler on a randomly generated document."""
    filehandle = tempfile.NamedTemporaryFile()
    filename = filehandle.name


@@ -220,5 +237,6 @@ def profile(num_elements=100000, parser="lxml"):
    stats.sort_stats("cumulative")
    stats.print_stats('_html5lib|bs4', 50)

# If this file is run as a script, standard input is diagnosed.
if __name__ == '__main__':
    diagnose(sys.stdin.read())

File diff suppressed because it is too large.
@@ -5,6 +5,28 @@ class Formatter(EntitySubstitution):

    Some parts of this strategy come from the distinction between
    HTML4, HTML5, and XML. Others are configurable by the user.

    Formatters are passed in as the `formatter` argument to methods
    like `PageElement.encode`. Most people won't need to think about
    formatters, and most people who need to think about them can pass
    in one of these predefined strings as `formatter` rather than
    making a new Formatter object:

    For HTML documents:
     * 'html' - HTML entity substitution for generic HTML documents. (default)
     * 'html5' - HTML entity substitution for HTML5 documents, as
       well as some optimizations in the way tags are rendered.
     * 'minimal' - Only make the substitutions necessary to guarantee
       valid HTML.
     * None - Do not perform any substitution. This will be faster
       but may result in invalid markup.

    For XML documents:
     * 'html' - Entity substitution for XHTML documents.
     * 'minimal' - Only make the substitutions necessary to guarantee
       valid XML. (default)
     * None - Do not perform any substitution. This will be faster
       but may result in invalid markup.
    """
    # Registries of XML and HTML formatters.
    XML_FORMATTERS = {}

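These registry strings are the values most callers pass as formatter=. The new html5 entry also switches on the boolean-attribute rendering added in this commit; an illustrative comparison:

    from bs4 import BeautifulSoup

    soup = BeautifulSoup('<option selected=""></option><br/>', 'html.parser')
    print(soup.decode(formatter="html"))   # <option selected=""></option><br/>
    print(soup.decode(formatter="html5"))  # <option selected></option><br>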
@@ -27,11 +49,26 @@ class Formatter(EntitySubstitution):
    def __init__(
            self, language=None, entity_substitution=None,
            void_element_close_prefix='/', cdata_containing_tags=None,
            empty_attributes_are_booleans=False,
    ):
        """
        """Constructor.

        :param void_element_close_prefix: By default, represent void
        elements as <tag/> rather than <tag>
        :param language: This should be Formatter.XML if you are formatting
           XML markup and Formatter.HTML if you are formatting HTML markup.

        :param entity_substitution: A function to call to replace special
           characters with XML/HTML entities. For examples, see
           bs4.dammit.EntitySubstitution.substitute_html and substitute_xml.
        :param void_element_close_prefix: By default, void elements
           are represented as <tag/> (XML rules) rather than <tag>
           (HTML rules). To get <tag>, pass in the empty string.
        :param cdata_containing_tags: The list of tags that are defined
           as containing CDATA in this dialect. For example, in HTML,
           <script> and <style> tags are defined as containing CDATA,
           and their contents should not be formatted.
        :param empty_attributes_are_booleans: Render attributes whose value
           is the empty string as HTML-style boolean attributes.
           (Attributes whose value is None are always rendered this way.)
        """
        self.language = language
        self.entity_substitution = entity_substitution

@@ -39,9 +76,17 @@ class Formatter(EntitySubstitution):
        self.cdata_containing_tags = self._default(
            language, cdata_containing_tags, 'cdata_containing_tags'
        )

        self.empty_attributes_are_booleans = empty_attributes_are_booleans

    def substitute(self, ns):
        """Process a string that needs to undergo entity substitution."""
        """Process a string that needs to undergo entity substitution.
        This may be a string encountered in an attribute value or as
        text.

        :param ns: A string.
        :return: A string with certain characters replaced by named
           or numeric entities.
        """
        if not self.entity_substitution:
            return ns
        from .element import NavigableString

@@ -54,21 +99,41 @@ class Formatter(EntitySubstitution):
        return self.entity_substitution(ns)

    def attribute_value(self, value):
        """Process the value of an attribute."""
        """Process the value of an attribute.

        :param value: A string.
        :return: A string with certain characters replaced by named
           or numeric entities.
        """
        return self.substitute(value)

    def attributes(self, tag):
        """Reorder a tag's attributes however you want."""
        return sorted(tag.attrs.items())
        """Reorder a tag's attributes however you want.

        By default, attributes are sorted alphabetically. This makes
        behavior consistent between Python 2 and Python 3, and preserves
        backwards compatibility with older versions of Beautiful Soup.

        If `empty_attributes_are_booleans` is True, then attributes whose
        values are set to the empty string will be treated as boolean
        attributes.
        """
        if tag.attrs is None:
            return []
        return sorted(
            (k, (None if self.empty_attributes_are_booleans and v == '' else v))
            for k, v in list(tag.attrs.items())
        )

class HTMLFormatter(Formatter):
    """A generic Formatter for HTML."""
    REGISTRY = {}
    def __init__(self, *args, **kwargs):
        return super(HTMLFormatter, self).__init__(self.HTML, *args, **kwargs)


class XMLFormatter(Formatter):
    """A generic Formatter for XML."""
    REGISTRY = {}
    def __init__(self, *args, **kwargs):
        return super(XMLFormatter, self).__init__(self.XML, *args, **kwargs)

@@ -80,7 +145,8 @@ HTMLFormatter.REGISTRY['html'] = HTMLFormatter(
)
HTMLFormatter.REGISTRY["html5"] = HTMLFormatter(
    entity_substitution=EntitySubstitution.substitute_html,
    void_element_close_prefix = None
    void_element_close_prefix=None,
    empty_attributes_are_booleans=True,
)
HTMLFormatter.REGISTRY["minimal"] = HTMLFormatter(
    entity_substitution=EntitySubstitution.substitute_xml

libs/bs4/testing.py
libs/bs4/testing.py
@@ -8,6 +8,7 @@ import pickle
import copy
import functools
import unittest
+import warnings
from unittest import TestCase
from bs4 import BeautifulSoup
from bs4.element import (
@@ -15,7 +16,10 @@ from bs4.element import (
    Comment,
    ContentMetaAttributeValue,
    Doctype,
+    PYTHON_SPECIFIC_ENCODINGS,
    SoupStrainer,
+    Script,
+    Stylesheet,
    Tag
)

@@ -83,8 +87,22 @@ class SoupTest(unittest.TestCase):
        if compare_parsed_to is None:
            compare_parsed_to = to_parse

        # Verify that the documents come out the same.
        self.assertEqual(obj.decode(), self.document_for(compare_parsed_to))

+        # Also run some checks on the BeautifulSoup object itself:
+
+        # Verify that every tag that was opened was eventually closed.
+
+        # There are no tags in the open tag counter.
+        assert all(v == 0 for v in list(obj.open_tag_counter.values()))
+
+        # The only tag in the tag stack is the one for the root
+        # document.
+        self.assertEqual(
+            [obj.ROOT_TAG_NAME], [x.name for x in obj.tagStack]
+        )

    def assertConnectedness(self, element):
        """Ensure that next_element and previous_element are properly
        set for all descendants of the given element.
@@ -211,7 +229,41 @@ class SoupTest(unittest.TestCase):
        return child


-class HTMLTreeBuilderSmokeTest(object):
+class TreeBuilderSmokeTest(object):
+    # Tests that are common to HTML and XML tree builders.
+
+    def test_fuzzed_input(self):
+        # This test centralizes in one place the various fuzz tests
+        # for Beautiful Soup created by the oss-fuzz project.
+
+        # These strings superficially resemble markup, but they
+        # generally can't be parsed into anything. The best we can
+        # hope for is that parsing these strings won't crash the
+        # parser.
+        #
+        # n.b. This markup is commented out because these fuzz tests
+        # _do_ crash the parser. However the crashes are due to bugs
+        # in html.parser, not Beautiful Soup -- otherwise I'd fix the
+        # bugs!
+        bad_markup = [
+            # https://bugs.chromium.org/p/oss-fuzz/issues/detail?id=28873
+            # https://github.com/guidovranken/python-library-fuzzers/blob/master/corp-html/519e5b4269a01185a0d5e76295251921da2f0700
+            # https://bugs.python.org/issue37747
+            #
+            #b'\n<![\xff\xfe\xfe\xcd\x00',
+
+            #https://github.com/guidovranken/python-library-fuzzers/blob/master/corp-html/de32aa55785be29bbc72a1a8e06b00611fb3d9f8
+            # https://bugs.python.org/issue34480
+            #
+            #b'<![n\x00'
+        ]
+        for markup in bad_markup:
+            with warnings.catch_warnings(record=False):
+                soup = self.soup(markup)
+
+
+class HTMLTreeBuilderSmokeTest(TreeBuilderSmokeTest):

    """A basic test of a treebuilder's competence.

@@ -233,6 +285,22 @@ class HTMLTreeBuilderSmokeTest(object):
        new_tag = soup.new_tag(name)
        self.assertEqual(True, new_tag.is_empty_element)

+    def test_special_string_containers(self):
+        soup = self.soup(
+            "<style>Some CSS</style><script>Some Javascript</script>"
+        )
+        assert isinstance(soup.style.string, Stylesheet)
+        assert isinstance(soup.script.string, Script)
+
+        soup = self.soup(
+            "<style><!--Some CSS--></style>"
+        )
+        assert isinstance(soup.style.string, Stylesheet)
+        # The contents of the style tag resemble an HTML comment, but
+        # it's not treated as a comment.
+        self.assertEqual("<!--Some CSS-->", soup.style.string)
+        assert isinstance(soup.style.string, Stylesheet)

    def test_pickle_and_unpickle_identity(self):
        # Pickling a tree, then unpickling it, yields a tree identical
        # to the original.
@@ -250,18 +318,21 @@ class HTMLTreeBuilderSmokeTest(object):
        doctype = soup.contents[0]
        self.assertEqual(doctype.__class__, Doctype)
        self.assertEqual(doctype, doctype_fragment)
-        self.assertEqual(str(soup)[:len(doctype_str)], doctype_str)
+        self.assertEqual(
+            soup.encode("utf8")[:len(doctype_str)],
+            doctype_str
+        )

        # Make sure that the doctype was correctly associated with the
        # parse tree and that the rest of the document parsed.
        self.assertEqual(soup.p.contents[0], 'foo')

-    def _document_with_doctype(self, doctype_fragment):
+    def _document_with_doctype(self, doctype_fragment, doctype_string="DOCTYPE"):
        """Generate and parse a document with the given doctype."""
-        doctype = '<!DOCTYPE %s>' % doctype_fragment
+        doctype = '<!%s %s>' % (doctype_string, doctype_fragment)
        markup = doctype + '\n<p>foo</p>'
        soup = self.soup(markup)
-        return doctype, soup
+        return doctype.encode("utf8"), soup

    def test_normal_doctypes(self):
        """Make sure normal, everyday HTML doctypes are handled correctly."""
@@ -274,6 +345,27 @@ class HTMLTreeBuilderSmokeTest(object):
        doctype = soup.contents[0]
        self.assertEqual("", doctype.strip())

+    def test_mixed_case_doctype(self):
+        # A lowercase or mixed-case doctype becomes a Doctype.
+        for doctype_fragment in ("doctype", "DocType"):
+            doctype_str, soup = self._document_with_doctype(
+                "html", doctype_fragment
+            )
+
+            # Make sure a Doctype object was created and that the DOCTYPE
+            # is uppercase.
+            doctype = soup.contents[0]
+            self.assertEqual(doctype.__class__, Doctype)
+            self.assertEqual(doctype, "html")
+            self.assertEqual(
+                soup.encode("utf8")[:len(doctype_str)],
+                b"<!DOCTYPE html>"
+            )
+
+            # Make sure that the doctype was correctly associated with the
+            # parse tree and that the rest of the document parsed.
+            self.assertEqual(soup.p.contents[0], 'foo')

    def test_public_doctype_with_url(self):
        doctype = 'html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"'
        self.assertDoctypeHandled(doctype)
@@ -532,7 +624,7 @@ Hello, world!
        self.assertSoupEquals("&#10000000000000;", expect)
        self.assertSoupEquals("&#x10000000000000;", expect)
        self.assertSoupEquals("&#1000000000;", expect)

    def test_multipart_strings(self):
        "Mostly to prevent a recurrence of a bug in the html5lib treebuilder."
        soup = self.soup("<html><h2>\nfoo</h2><p></p></html>")
@@ -594,7 +686,7 @@ Hello, world!
        markup = b'<a class="foo bar">'
        soup = self.soup(markup)
        self.assertEqual(['foo', 'bar'], soup.a['class'])

    #
    # Generally speaking, tests below this point are more tests of
    # Beautiful Soup than tests of the tree builders. But parsers are
@@ -779,11 +871,44 @@ Hello, world!
        # encoding.
        self.assertEqual('utf8', charset.encode("utf8"))

+    def test_python_specific_encodings_not_used_in_charset(self):
+        # You can encode an HTML document using a Python-specific
+        # encoding, but that encoding won't be mentioned _inside_ the
+        # resulting document. Instead, the document will appear to
+        # have no encoding.
+        for markup in [
+            b'<meta charset="utf8"></head>',
+            b'<meta id="encoding" charset="utf-8" />'
+        ]:
+            soup = self.soup(markup)
+            for encoding in PYTHON_SPECIFIC_ENCODINGS:
+                if encoding in (
+                    'idna', 'mbcs', 'oem', 'undefined',
+                    'string_escape', 'string-escape'
+                ):
+                    # For one reason or another, these will raise an
+                    # exception if we actually try to use them, so don't
+                    # bother.
+                    continue
+                encoded = soup.encode(encoding)
+                assert b'meta charset=""' in encoded
+                assert encoding.encode("ascii") not in encoded

    def test_tag_with_no_attributes_can_have_attributes_added(self):
        data = self.soup("<a>text</a>")
        data.a['foo'] = 'bar'
        self.assertEqual('<a foo="bar">text</a>', data.a.decode())

+    def test_closing_tag_with_no_opening_tag(self):
+        # Without BeautifulSoup.open_tag_counter, the </span> tag will
+        # cause _popToTag to be called over and over again as we look
+        # for a <span> tag that wasn't there. The result is that 'text2'
+        # will show up outside the body of the document.
+        soup = self.soup("<body><div><p>text1</p></span>text2</div></body>")
+        self.assertEqual(
+            "<body><div><p>text1</p>text2</div></body>", soup.body.decode()
+        )

    def test_worst_case(self):
        """Test the worst case (currently) for linking issues."""

@@ -791,7 +916,7 @@ Hello, world!
        self.linkage_validator(soup)


-class XMLTreeBuilderSmokeTest(object):
+class XMLTreeBuilderSmokeTest(TreeBuilderSmokeTest):

    def test_pickle_and_unpickle_identity(self):
        # Pickling a tree, then unpickling it, yields a tree identical
@@ -812,6 +937,25 @@ class XMLTreeBuilderSmokeTest(object):
        soup = self.soup(markup)
        self.assertEqual(markup, soup.encode("utf8"))

+    def test_python_specific_encodings_not_used_in_xml_declaration(self):
+        # You can encode an XML document using a Python-specific
+        # encoding, but that encoding won't be mentioned _inside_ the
+        # resulting document.
+        markup = b"""<?xml version="1.0"?>\n<foo/>"""
+        soup = self.soup(markup)
+        for encoding in PYTHON_SPECIFIC_ENCODINGS:
+            if encoding in (
+                'idna', 'mbcs', 'oem', 'undefined',
+                'string_escape', 'string-escape'
+            ):
+                # For one reason or another, these will raise an
+                # exception if we actually try to use them, so don't
+                # bother.
+                continue
+            encoded = soup.encode(encoding)
+            assert b'<?xml version="1.0"?>' in encoded
+            assert encoding.encode("ascii") not in encoded

    def test_processing_instruction(self):
        markup = b"""<?xml version="1.0" encoding="utf8"?>\n<?PITarget PIContent?>"""
        soup = self.soup(markup)
@@ -828,7 +972,7 @@ class XMLTreeBuilderSmokeTest(object):
        soup = self.soup(markup)
        self.assertEqual(
            soup.encode("utf-8"), markup)

    def test_nested_namespaces(self):
        doc = b"""<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.1//EN" "http://www.w3.org/TR/xhtml11/DTD/xhtml11.dtd">
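The Script and Stylesheet string containers exercised by test_special_string_containers above are part of the bs4 4.9+ API. A short sketch of the behavior the test relies on, assuming the vendored bs4 is on the path:

    from bs4 import BeautifulSoup
    from bs4.element import Script, Stylesheet

    soup = BeautifulSoup("<style>p {}</style><script>f();</script>", "html.parser")
    # Strings inside <style>/<script> get specialized NavigableString subclasses:
    print(isinstance(soup.style.string, Stylesheet))  # True
    print(isinstance(soup.script.string, Script))     # True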
libs/bs4/tests/test_html5lib.py
@@ -168,3 +168,59 @@ class HTML5LibBuilderSmokeTest(SoupTest, HTML5TreeBuilderSmokeTest):
        for form in soup.find_all('form'):
            inputs.extend(form.find_all('input'))
        self.assertEqual(len(inputs), 1)

+    def test_tracking_line_numbers(self):
+        # The html.parser TreeBuilder keeps track of line number and
+        # position of each element.
+        markup = "\n   <p>\n\n<sourceline>\n<b>text</b></sourceline><sourcepos></p>"
+        soup = self.soup(markup)
+        self.assertEqual(2, soup.p.sourceline)
+        self.assertEqual(5, soup.p.sourcepos)
+        self.assertEqual("sourceline", soup.p.find('sourceline').name)
+
+        # You can deactivate this behavior.
+        soup = self.soup(markup, store_line_numbers=False)
+        self.assertEqual("sourceline", soup.p.sourceline.name)
+        self.assertEqual("sourcepos", soup.p.sourcepos.name)
+
+    def test_special_string_containers(self):
+        # The html5lib tree builder doesn't support this standard feature,
+        # because there's no way of knowing, when a string is created,
+        # where in the tree it will eventually end up.
+        pass
+
+    def test_html5_attributes(self):
+        # The html5lib TreeBuilder can convert any entity named in
+        # the HTML5 spec to a sequence of Unicode characters, and
+        # convert those Unicode characters to a (potentially
+        # different) named entity on the way out.
+        #
+        # This is a copy of the same test from
+        # HTMLParserTreeBuilderSmokeTest. It's not in the superclass
+        # because the lxml HTML TreeBuilder _doesn't_ work this way.
+        for input_element, output_unicode, output_element in (
+            ("&RightArrowLeftArrow;", '\u21c4', b'&rlarr;'),
+            ('&models;', '\u22a7', b'&models;'),
+            ('&Nfr;', '\U0001d511', b'&Nfr;'),
+            ('&ngeqq;', '\u2267\u0338', b'&ngeqq;'),
+            ('&not;', '\xac', b'&not;'),
+            ('&Not;', '\u2aec', b'&Not;'),
+            ('&quot;', '"', b'"'),
+            ('&there4;', '\u2234', b'&there4;'),
+            ('&therefore;', '\u2234', b'&there4;'),
+            ('&Therefore;', '\u2234', b'&there4;'),
+            ('&fjlig;', 'fj', b'fj'),
+            ("&sqcup;", '\u2294', b'&sqcup;'),
+            ("&sqcups;", '\u2294\ufe00', b'&sqcups;'),
+            ("&apos;", "'", b"'"),
+            ("&verbar;", "|", b"|"),
+        ):
+            markup = '<div>%s</div>' % input_element
+            div = self.soup(markup).div
+            without_element = div.encode()
+            expect = b"<div>%s</div>" % output_unicode.encode("utf8")
+            self.assertEqual(without_element, expect)
+
+            with_element = div.encode(formatter="html")
+            expect = b"<div>%s</div>" % output_element
+            self.assertEqual(with_element, expect)
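The mirrored page HTML-unescaped the entity strings in the table above, so the named-entity spellings shown here are reconstructed from the characters that survived; any entity that maps to the same code point would satisfy the assertions. A sketch of the round-trip being tested, using one entry from the table:

    from bs4 import BeautifulSoup

    div = BeautifulSoup("<div>&models;</div>", "html.parser").div
    # Parsed into the Unicode character U+22A7...
    print(div.encode())                  # b'<div>\xe2\x8a\xa7</div>'
    # ...and converted back to a named entity on the way out.
    print(div.encode(formatter="html"))  # b'<div>&models;</div>'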
libs/bs4/tests/test_htmlparser.py
@@ -3,6 +3,7 @@ trees."""

from pdb import set_trace
import pickle
+import warnings
from bs4.testing import SoupTest, HTMLTreeBuilderSmokeTest
from bs4.builder import HTMLParserTreeBuilder
from bs4.builder._htmlparser import BeautifulSoupHTMLParser
@@ -37,6 +38,88 @@ class HTMLParserTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
        # finishes working is handled.
        self.assertSoupEquals("foo &# bar", "foo &amp;# bar")

+    def test_tracking_line_numbers(self):
+        # The html.parser TreeBuilder keeps track of line number and
+        # position of each element.
+        markup = "\n   <p>\n\n<sourceline>\n<b>text</b></sourceline><sourcepos></p>"
+        soup = self.soup(markup)
+        self.assertEqual(2, soup.p.sourceline)
+        self.assertEqual(3, soup.p.sourcepos)
+        self.assertEqual("sourceline", soup.p.find('sourceline').name)
+
+        # You can deactivate this behavior.
+        soup = self.soup(markup, store_line_numbers=False)
+        self.assertEqual("sourceline", soup.p.sourceline.name)
+        self.assertEqual("sourcepos", soup.p.sourcepos.name)
+
+    def test_on_duplicate_attribute(self):
+        # The html.parser tree builder has a variety of ways of
+        # handling a tag that contains the same attribute multiple times.
+
+        markup = '<a class="cls" href="url1" href="url2" href="url3" id="id">'
+
+        # If you don't provide any particular value for
+        # on_duplicate_attribute, later values replace earlier values.
+        soup = self.soup(markup)
+        self.assertEqual("url3", soup.a['href'])
+        self.assertEqual(["cls"], soup.a['class'])
+        self.assertEqual("id", soup.a['id'])
+
+        # You can also get this behavior explicitly.
+        def assert_attribute(on_duplicate_attribute, expected):
+            soup = self.soup(
+                markup, on_duplicate_attribute=on_duplicate_attribute
+            )
+            self.assertEqual(expected, soup.a['href'])
+
+            # Verify that non-duplicate attributes are treated normally.
+            self.assertEqual(["cls"], soup.a['class'])
+            self.assertEqual("id", soup.a['id'])
+        assert_attribute(None, "url3")
+        assert_attribute(BeautifulSoupHTMLParser.REPLACE, "url3")
+
+        # You can ignore subsequent values in favor of the first.
+        assert_attribute(BeautifulSoupHTMLParser.IGNORE, "url1")
+
+        # And you can pass in a callable that does whatever you want.
+        def accumulate(attrs, key, value):
+            if not isinstance(attrs[key], list):
+                attrs[key] = [attrs[key]]
+            attrs[key].append(value)
+        assert_attribute(accumulate, ["url1", "url2", "url3"])
+
+    def test_html5_attributes(self):
+        # The html.parser TreeBuilder can convert any entity named in
+        # the HTML5 spec to a sequence of Unicode characters, and
+        # convert those Unicode characters to a (potentially
+        # different) named entity on the way out.
+        for input_element, output_unicode, output_element in (
+            ("&RightArrowLeftArrow;", '\u21c4', b'&rlarr;'),
+            ('&models;', '\u22a7', b'&models;'),
+            ('&Nfr;', '\U0001d511', b'&Nfr;'),
+            ('&ngeqq;', '\u2267\u0338', b'&ngeqq;'),
+            ('&not;', '\xac', b'&not;'),
+            ('&Not;', '\u2aec', b'&Not;'),
+            ('&quot;', '"', b'"'),
+            ('&there4;', '\u2234', b'&there4;'),
+            ('&therefore;', '\u2234', b'&there4;'),
+            ('&Therefore;', '\u2234', b'&there4;'),
+            ('&fjlig;', 'fj', b'fj'),
+            ("&sqcup;", '\u2294', b'&sqcup;'),
+            ("&sqcups;", '\u2294\ufe00', b'&sqcups;'),
+            ("&apos;", "'", b"'"),
+            ("&verbar;", "|", b"|"),
+        ):
+            markup = '<div>%s</div>' % input_element
+            div = self.soup(markup).div
+            without_element = div.encode()
+            expect = b"<div>%s</div>" % output_unicode.encode("utf8")
+            self.assertEqual(without_element, expect)
+
+            with_element = div.encode(formatter="html")
+            expect = b"<div>%s</div>" % output_element
+            self.assertEqual(with_element, expect)


class TestHTMLParserSubclass(SoupTest):
    def test_error(self):
@@ -44,4 +127,8 @@ class TestHTMLParserSubclass(SoupTest):
        that doesn't cause a crash.
        """
        parser = BeautifulSoupHTMLParser()
-        parser.error("don't crash")
+        with warnings.catch_warnings(record=True) as warns:
+            parser.error("don't crash")
+        [warning] = warns
+        assert "don't crash" == str(warning.message)
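The on_duplicate_attribute hook tested above is specific to the html.parser tree builder. A minimal usage sketch, assuming the vendored bs4 exposes the constants shown in the test:

    from bs4 import BeautifulSoup
    from bs4.builder._htmlparser import BeautifulSoupHTMLParser

    markup = '<a href="url1" href="url2">'
    soup = BeautifulSoup(markup, "html.parser",
                         on_duplicate_attribute=BeautifulSoupHTMLParser.IGNORE)
    # IGNORE keeps the first value; the default (REPLACE) would give "url2".
    print(soup.a["href"])  # url1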
libs/bs4/tests/test_lxml.py
@@ -45,7 +45,7 @@ class LXMLTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
            "<p>foo&#x10000000000000;bar</p>", "<p>foobar</p>")
        self.assertSoupEquals(
            "<p>foo&#1000000000;bar</p>", "<p>foobar</p>")

    def test_entities_in_foreign_document_encoding(self):
        # We can't implement this case correctly because by the time we
        # hear about markup like "&#147;", it's been (incorrectly) converted into
@@ -71,6 +71,21 @@ class LXMLTreeBuilderSmokeTest(SoupTest, HTMLTreeBuilderSmokeTest):
        self.assertEqual("<b/>", str(soup.b))
        self.assertTrue("BeautifulStoneSoup class is deprecated" in str(w[0].message))

+    def test_tracking_line_numbers(self):
+        # The lxml TreeBuilder cannot keep track of line numbers from
+        # the original markup. Even if you ask for line numbers, we
+        # don't have 'em.
+        #
+        # This means that if you have a tag like <sourceline> or
+        # <sourcepos>, attribute access will find it rather than
+        # giving you a numeric answer.
+        soup = self.soup(
+            "\n   <p>\n\n<sourceline>\n<b>text</b></sourceline><sourcepos></p>",
+            store_line_numbers=True
+        )
+        self.assertEqual("sourceline", soup.p.sourceline.name)
+        self.assertEqual("sourcepos", soup.p.sourcepos.name)

    @skipIf(
        not LXML_PRESENT,
        "lxml seems not to be present, not testing its XML tree builder.")
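As the lxml test above notes, only some builders track source positions; with html.parser the feature looks like this (a sketch; sourceline is 1-based and sourcepos is a 0-based column offset for this builder):

    from bs4 import BeautifulSoup

    soup = BeautifulSoup("<p>\n <b>text</b></p>", "html.parser")
    print(soup.b.sourceline, soup.b.sourcepos)  # 2 1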
libs/bs4/tests/test_soup.py
@@ -3,6 +3,7 @@

from pdb import set_trace
import logging
+import os
import unittest
import sys
import tempfile
@@ -10,18 +11,27 @@ import tempfile
from bs4 import (
    BeautifulSoup,
    BeautifulStoneSoup,
+    GuessedAtParserWarning,
+    MarkupResemblesLocatorWarning,
)
+from bs4.builder import (
+    TreeBuilder,
+    ParserRejectedMarkup,
+)
from bs4.element import (
    CharsetMetaAttributeValue,
+    Comment,
    ContentMetaAttributeValue,
    SoupStrainer,
    NamespacedAttribute,
+    Tag,
+    NavigableString,
)
import bs4.dammit
from bs4.dammit import (
    EntitySubstitution,
    UnicodeDammit,
    EncodingDetector,
)
from bs4.testing import (
    default_builder,
@@ -62,10 +72,21 @@ class TestConstructor(SoupTest):
            def __init__(self, **kwargs):
                self.called_with = kwargs
                self.is_xml = True
+                self.store_line_numbers = False
+                self.cdata_list_attributes = []
+                self.preserve_whitespace_tags = []
+                self.string_containers = {}
+            def initialize_soup(self, soup):
+                pass
            def feed(self, markup):
                self.fed = markup
            def reset(self):
                pass
            def ignore(self, ignore):
                pass
            set_up_substitutions = can_be_empty_element = ignore
            def prepare_markup(self, *args, **kwargs):
-                return ''
+                yield "prepared markup", "original encoding", "declared encoding", "contains replacement characters"

        kwargs = dict(
            var="value",
@@ -77,7 +98,8 @@ class TestConstructor(SoupTest):
        soup = BeautifulSoup('', builder=Mock, **kwargs)
        assert isinstance(soup.builder, Mock)
        self.assertEqual(dict(var="value"), soup.builder.called_with)
+        self.assertEqual("prepared markup", soup.builder.fed)

        # You can also instantiate the TreeBuilder yourself. In this
        # case, that specific object is used and any keyword arguments
        # to the BeautifulSoup constructor are ignored.
@@ -91,6 +113,26 @@ class TestConstructor(SoupTest):
        self.assertEqual(builder, soup.builder)
        self.assertEqual(kwargs, builder.called_with)

+    def test_parser_markup_rejection(self):
+        # If markup is completely rejected by the parser, an
+        # explanatory ParserRejectedMarkup exception is raised.
+        class Mock(TreeBuilder):
+            def feed(self, *args, **kwargs):
+                raise ParserRejectedMarkup("Nope.")
+
+            def prepare_markup(self, *args, **kwargs):
+                # We're going to try two different ways of preparing this markup,
+                # but feed() will reject both of them.
+                yield markup, None, None, False
+                yield markup, None, None, False
+
+        import re
+        self.assertRaisesRegex(
+            ParserRejectedMarkup,
+            "The markup you provided was rejected by the parser. Trying a different parser or a different encoding may help.",
+            BeautifulSoup, '', builder=Mock,
+        )

    def test_cdata_list_attributes(self):
        # Most attribute values are represented as scalars, but the
        # HTML standard says that some attributes, like 'class' have
@@ -120,28 +162,96 @@ class TestConstructor(SoupTest):
        self.assertEqual(["an", "id"], a['id'])
        self.assertEqual(" a class ", a['class'])

+    def test_replacement_classes(self):
+        # Test the ability to pass in replacements for element classes
+        # which will be used when building the tree.
+        class TagPlus(Tag):
+            pass
+
+        class StringPlus(NavigableString):
+            pass
+
+        class CommentPlus(Comment):
+            pass
+
+        soup = self.soup(
+            "<a><b>foo</b>bar</a><!--whee-->",
+            element_classes = {
+                Tag: TagPlus,
+                NavigableString: StringPlus,
+                Comment: CommentPlus,
+            }
+        )
+
+        # The tree was built with TagPlus, StringPlus, and CommentPlus objects,
+        # rather than Tag, String, and Comment objects.
+        assert all(
+            isinstance(x, (TagPlus, StringPlus, CommentPlus))
+            for x in soup.recursiveChildGenerator()
+        )
+
+    def test_alternate_string_containers(self):
+        # Test the ability to customize the string containers for
+        # different types of tags.
+        class PString(NavigableString):
+            pass
+
+        class BString(NavigableString):
+            pass
+
+        soup = self.soup(
+            "<div>Hello.<p>Here is <b>some <i>bolded</i></b> text",
+            string_containers = {
+                'b': BString,
+                'p': PString,
+            }
+        )
+
+        # The string before the <p> tag is a regular NavigableString.
+        assert isinstance(soup.div.contents[0], NavigableString)
+
+        # The string inside the <p> tag, but not inside the <i> tag,
+        # is a PString.
+        assert isinstance(soup.p.contents[0], PString)
+
+        # Every string inside the <b> tag is a BString, even the one that
+        # was also inside an <i> tag.
+        for s in soup.b.strings:
+            assert isinstance(s, BString)
+
+        # Now that parsing was complete, the string_container_stack
+        # (where this information was kept) has been cleared out.
+        self.assertEqual([], soup.string_container_stack)


class TestWarnings(SoupTest):

-    def _no_parser_specified(self, s, is_there=True):
-        v = s.startswith(BeautifulSoup.NO_PARSER_SPECIFIED_WARNING[:80])
-        self.assertTrue(v)
+    def _assert_warning(self, warnings, cls):
+        for w in warnings:
+            if isinstance(w.message, cls):
+                return w
+        raise Exception("%s warning not found in %r" % (cls, warnings))
+
+    def _assert_no_parser_specified(self, w):
+        warning = self._assert_warning(w, GuessedAtParserWarning)
+        message = str(warning.message)
+        self.assertTrue(
+            message.startswith(BeautifulSoup.NO_PARSER_SPECIFIED_WARNING[:60])
+        )

    def test_warning_if_no_parser_specified(self):
        with warnings.catch_warnings(record=True) as w:
-            soup = self.soup("<a><b></b></a>")
-        msg = str(w[0].message)
-        self._assert_no_parser_specified(msg)
+            soup = BeautifulSoup("<a><b></b></a>")
+        self._assert_no_parser_specified(w)

    def test_warning_if_parser_specified_too_vague(self):
        with warnings.catch_warnings(record=True) as w:
-            soup = self.soup("<a><b></b></a>", "html")
-        msg = str(w[0].message)
-        self._assert_no_parser_specified(msg)
+            soup = BeautifulSoup("<a><b></b></a>", "html")
+        self._assert_no_parser_specified(w)

    def test_no_warning_if_explicit_parser_specified(self):
        with warnings.catch_warnings(record=True) as w:
-            soup = self.soup("<a><b></b></a>", "html.parser")
+            soup = BeautifulSoup("<a><b></b></a>", "html.parser")
        self.assertEqual([], w)

    def test_parseOnlyThese_renamed_to_parse_only(self):
@@ -165,41 +275,58 @@ class TestWarnings(SoupTest):
        self.assertRaises(
            TypeError, self.soup, "<a>", no_such_argument=True)

-class TestWarnings(SoupTest):
-
    def test_disk_file_warning(self):
        filehandle = tempfile.NamedTemporaryFile()
        filename = filehandle.name
        try:
            with warnings.catch_warnings(record=True) as w:
                soup = self.soup(filename)
-            msg = str(w[0].message)
-            self.assertTrue("looks like a filename" in msg)
+            warning = self._assert_warning(w, MarkupResemblesLocatorWarning)
+            self.assertTrue("looks like a filename" in str(warning.message))
        finally:
            filehandle.close()

        # The file no longer exists, so Beautiful Soup will no longer issue the warning.
        with warnings.catch_warnings(record=True) as w:
            soup = self.soup(filename)
-        self.assertEqual(0, len(w))
+        self.assertEqual([], w)

+    def test_directory_warning(self):
+        try:
+            filename = tempfile.mkdtemp()
+            with warnings.catch_warnings(record=True) as w:
+                soup = self.soup(filename)
+            warning = self._assert_warning(w, MarkupResemblesLocatorWarning)
+            self.assertTrue("looks like a directory" in str(warning.message))
+        finally:
+            os.rmdir(filename)
+
+        # The directory no longer exists, so Beautiful Soup will no longer issue the warning.
+        with warnings.catch_warnings(record=True) as w:
+            soup = self.soup(filename)
+        self.assertEqual([], w)

    def test_url_warning_with_bytes_url(self):
        with warnings.catch_warnings(record=True) as warning_list:
            soup = self.soup(b"http://www.crummybytes.com/")
-        # Be aware this isn't the only warning that can be raised during
-        # execution..
-        self.assertTrue(any("looks like a URL" in str(w.message)
-            for w in warning_list))
+        warning = self._assert_warning(
+            warning_list, MarkupResemblesLocatorWarning
+        )
+        self.assertTrue("looks like a URL" in str(warning.message))

    def test_url_warning_with_unicode_url(self):
        with warnings.catch_warnings(record=True) as warning_list:
            # note - this url must differ from the bytes one otherwise
            # python's warnings system swallows the second warning
            soup = self.soup("http://www.crummyunicode.com/")
-        self.assertTrue(any("looks like a URL" in str(w.message)
-            for w in warning_list))
+        warning = self._assert_warning(
+            warning_list, MarkupResemblesLocatorWarning
+        )
+        self.assertTrue("looks like a URL" in str(warning.message))

    def test_url_warning_with_bytes_and_space(self):
        # Here the markup contains something besides a URL, so no warning
        # is issued.
        with warnings.catch_warnings(record=True) as warning_list:
            soup = self.soup(b"http://www.crummybytes.com/ is great")
            self.assertFalse(any("looks like a URL" in str(w.message)
@@ -241,6 +368,51 @@ class TestEntitySubstitution(unittest.TestCase):
        self.assertEqual(self.sub.substitute_html(dammit.markup),
                         "&lsquo;&rsquo;foo&ldquo;&rdquo;")

+    def test_html5_entity(self):
+        # Some HTML5 entities correspond to single- or multi-character
+        # Unicode sequences.
+
+        for entity, u in (
+            # A few spot checks of our ability to recognize
+            # special character sequences and convert them
+            # to named entities.
+            ('&models;', '\u22a7'),
+            ('&Nfr;', '\U0001d511'),
+            ('&ngeqq;', '\u2267\u0338'),
+            ('&not;', '\xac'),
+            ('&Not;', '\u2aec'),
+
+            # We _could_ convert | to &verbarr;, but we don't, because
+            # | is an ASCII character.
+            ('|', '|'),
+
+            # Similarly for the fj ligature, which we could convert to
+            # &fjlig;, but we don't.
+            ("fj", "fj"),
+
+            # We do convert _these_ ASCII characters to HTML entities,
+            # because that's required to generate valid HTML.
+            ('&gt;', '>'),
+            ('&lt;', '<'),
+            ('&amp;', '&'),
+        ):
+            template = '3 %s 4'
+            raw = template % u
+            with_entities = template % entity
+            self.assertEqual(self.sub.substitute_html(raw), with_entities)
+
+    def test_html5_entity_with_variation_selector(self):
+        # Some HTML5 entities correspond either to a single-character
+        # Unicode sequence _or_ to the same character plus U+FE00,
+        # VARIATION SELECTOR 1. We can handle this.
+        data = "fjords \u2294 penguins"
+        markup = "fjords &sqcup; penguins"
+        self.assertEqual(self.sub.substitute_html(data), markup)
+
+        data = "fjords \u2294\ufe00 penguins"
+        markup = "fjords &sqcups; penguins"
+        self.assertEqual(self.sub.substitute_html(data), markup)

    def test_xml_converstion_includes_no_quotes_if_make_quoted_attribute_is_false(self):
        s = 'Welcome to "my bar"'
        self.assertEqual(self.sub.substitute_xml(s, False), s)
@@ -350,186 +522,26 @@ class TestEncodingConversion(SoupTest):
        markup = '<div><a \N{SNOWMAN}="snowman"></a></div>'
        self.assertEqual(self.soup(markup).div.encode("utf8"), markup.encode("utf8"))

-class TestUnicodeDammit(unittest.TestCase):
-    """Standalone tests of UnicodeDammit."""
-
-    def test_unicode_input(self):
-        markup = "I'm already Unicode! \N{SNOWMAN}"
-        dammit = UnicodeDammit(markup)
-        self.assertEqual(dammit.unicode_markup, markup)
-
-    def test_smart_quotes_to_unicode(self):
-        markup = b"<foo>\x91\x92\x93\x94</foo>"
-        dammit = UnicodeDammit(markup)
-        self.assertEqual(
-            dammit.unicode_markup, "<foo>\u2018\u2019\u201c\u201d</foo>")
-
-    def test_smart_quotes_to_xml_entities(self):
-        markup = b"<foo>\x91\x92\x93\x94</foo>"
-        dammit = UnicodeDammit(markup, smart_quotes_to="xml")
-        self.assertEqual(
-            dammit.unicode_markup, "<foo>&#x2018;&#x2019;&#x201C;&#x201D;</foo>")
-
-    def test_smart_quotes_to_html_entities(self):
-        markup = b"<foo>\x91\x92\x93\x94</foo>"
-        dammit = UnicodeDammit(markup, smart_quotes_to="html")
-        self.assertEqual(
-            dammit.unicode_markup, "<foo>&lsquo;&rsquo;&ldquo;&rdquo;</foo>")
-
-    def test_smart_quotes_to_ascii(self):
-        markup = b"<foo>\x91\x92\x93\x94</foo>"
-        dammit = UnicodeDammit(markup, smart_quotes_to="ascii")
-        self.assertEqual(
-            dammit.unicode_markup, """<foo>''""</foo>""")
-
-    def test_detect_utf8(self):
-        utf8 = b"Sacr\xc3\xa9 bleu! \xe2\x98\x83"
-        dammit = UnicodeDammit(utf8)
-        self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
-        self.assertEqual(dammit.unicode_markup, 'Sacr\xe9 bleu! \N{SNOWMAN}')
-
-    def test_convert_hebrew(self):
-        hebrew = b"\xed\xe5\xec\xf9"
-        dammit = UnicodeDammit(hebrew, ["iso-8859-8"])
-        self.assertEqual(dammit.original_encoding.lower(), 'iso-8859-8')
-        self.assertEqual(dammit.unicode_markup, '\u05dd\u05d5\u05dc\u05e9')
-
-    def test_dont_see_smart_quotes_where_there_are_none(self):
-        utf_8 = b"\343\202\261\343\203\274\343\202\277\343\202\244 Watch"
-        dammit = UnicodeDammit(utf_8)
-        self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
-        self.assertEqual(dammit.unicode_markup.encode("utf-8"), utf_8)
-
-    def test_ignore_inappropriate_codecs(self):
-        utf8_data = "Räksmörgås".encode("utf-8")
-        dammit = UnicodeDammit(utf8_data, ["iso-8859-8"])
-        self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
-
-    def test_ignore_invalid_codecs(self):
-        utf8_data = "Räksmörgås".encode("utf-8")
-        for bad_encoding in ['.utf8', '...', 'utF---16.!']:
-            dammit = UnicodeDammit(utf8_data, [bad_encoding])
-            self.assertEqual(dammit.original_encoding.lower(), 'utf-8')
-
-    def test_exclude_encodings(self):
-        # This is UTF-8.
-        utf8_data = "Räksmörgås".encode("utf-8")
-
-        # But if we exclude UTF-8 from consideration, the guess is
-        # Windows-1252.
-        dammit = UnicodeDammit(utf8_data, exclude_encodings=["utf-8"])
-        self.assertEqual(dammit.original_encoding.lower(), 'windows-1252')
-
-        # And if we exclude that, there is no valid guess at all.
-        dammit = UnicodeDammit(
-            utf8_data, exclude_encodings=["utf-8", "windows-1252"])
-        self.assertEqual(dammit.original_encoding, None)
-
-    def test_encoding_detector_replaces_junk_in_encoding_name_with_replacement_character(self):
-        detected = EncodingDetector(
-            b'<?xml version="1.0" encoding="UTF-\xdb" ?>')
-        encodings = list(detected.encodings)
-        assert 'utf-\N{REPLACEMENT CHARACTER}' in encodings
-
-    def test_detect_html5_style_meta_tag(self):
-        for data in (
-            b'<html><meta charset="euc-jp" /></html>',
-            b"<html><meta charset='euc-jp' /></html>",
-            b"<html><meta charset=euc-jp /></html>",
-            b"<html><meta charset=euc-jp/></html>"):
-            dammit = UnicodeDammit(data, is_html=True)
-            self.assertEqual(
-                "euc-jp", dammit.original_encoding)
-
-    def test_last_ditch_entity_replacement(self):
-        # This is a UTF-8 document that contains bytestrings
-        # completely incompatible with UTF-8 (ie. encoded with some other
-        # encoding).
-        #
-        # Since there is no consistent encoding for the document,
-        # Unicode, Dammit will eventually encode the document as UTF-8
-        # and encode the incompatible characters as REPLACEMENT
-        # CHARACTER.
-        #
-        # If chardet is installed, it will detect that the document
-        # can be converted into ISO-8859-1 without errors. This happens
-        # to be the wrong encoding, but it is a consistent encoding, so the
-        # code we're testing here won't run.
-        #
-        # So we temporarily disable chardet if it's present.
-        doc = b"""\357\273\277<?xml version="1.0" encoding="UTF-8"?>
-<html><b>\330\250\330\252\330\261</b>
-<i>\310\322\321\220\312\321\355\344</i></html>"""
-        chardet = bs4.dammit.chardet_dammit
-        logging.disable(logging.WARNING)
-        try:
-            def noop(str):
-                return None
-            bs4.dammit.chardet_dammit = noop
-            dammit = UnicodeDammit(doc)
-            self.assertEqual(True, dammit.contains_replacement_characters)
-            self.assertTrue("\ufffd" in dammit.unicode_markup)
-
-            soup = BeautifulSoup(doc, "html.parser")
-            self.assertTrue(soup.contains_replacement_characters)
-        finally:
-            logging.disable(logging.NOTSET)
-            bs4.dammit.chardet_dammit = chardet
-
-    def test_byte_order_mark_removed(self):
-        # A document written in UTF-16LE will have its byte order marker stripped.
-        data = b'\xff\xfe<\x00a\x00>\x00\xe1\x00\xe9\x00<\x00/\x00a\x00>\x00'
-        dammit = UnicodeDammit(data)
-        self.assertEqual("<a>áé</a>", dammit.unicode_markup)
-        self.assertEqual("utf-16le", dammit.original_encoding)
-
-    def test_detwingle(self):
-        # Here's a UTF8 document.
-        utf8 = ("\N{SNOWMAN}" * 3).encode("utf8")
-
-        # Here's a Windows-1252 document.
-        windows_1252 = (
-            "\N{LEFT DOUBLE QUOTATION MARK}Hi, I like Windows!"
-            "\N{RIGHT DOUBLE QUOTATION MARK}").encode("windows_1252")
-
-        # Through some unholy alchemy, they've been stuck together.
-        doc = utf8 + windows_1252 + utf8
-
-        # The document can't be turned into UTF-8:
-        self.assertRaises(UnicodeDecodeError, doc.decode, "utf8")
-
-        # Unicode, Dammit thinks the whole document is Windows-1252,
-        # and decodes it into "☃☃☃“Hi, I like Windows!”☃☃☃"
-
-        # But if we run it through fix_embedded_windows_1252, it's fixed:
-        fixed = UnicodeDammit.detwingle(doc)
-        self.assertEqual(
-            "☃☃☃“Hi, I like Windows!”☃☃☃", fixed.decode("utf8"))
-
-    def test_detwingle_ignores_multibyte_characters(self):
-        # Each of these characters has a UTF-8 representation ending
-        # in \x93. \x93 is a smart quote if interpreted as
-        # Windows-1252. But our code knows to skip over multibyte
-        # UTF-8 characters, so they'll survive the process unscathed.
-        for tricky_unicode_char in (
-            "\N{LATIN SMALL LIGATURE OE}", # 2-byte char '\xc5\x93'
-            "\N{LATIN SUBSCRIPT SMALL LETTER X}", # 3-byte char '\xe2\x82\x93'
-            "\xf0\x90\x90\x93", # This is a CJK character, not sure which one.
-        ):
-            input = tricky_unicode_char.encode("utf8")
-            self.assertTrue(input.endswith(b'\x93'))
-            output = UnicodeDammit.detwingle(input)
-            self.assertEqual(output, input)
-
class TestNamedspacedAttribute(SoupTest):

-    def test_name_may_be_none(self):
+    def test_name_may_be_none_or_missing(self):
        a = NamespacedAttribute("xmlns", None)
        self.assertEqual(a, "xmlns")

+        a = NamespacedAttribute("xmlns", "")
+        self.assertEqual(a, "xmlns")
+
+        a = NamespacedAttribute("xmlns")
+        self.assertEqual(a, "xmlns")
+
+    def test_namespace_may_be_none_or_missing(self):
+        a = NamespacedAttribute(None, "tag")
+        self.assertEqual(a, "tag")
+
+        a = NamespacedAttribute("", "tag")
+        self.assertEqual(a, "tag")

    def test_attribute_is_equivalent_to_colon_separated_string(self):
        a = NamespacedAttribute("a", "b")
        self.assertEqual("a:b", a)
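test_replacement_classes above exercises the element_classes constructor argument (bs4 4.9+), which swaps in custom subclasses while the tree is being built. A condensed sketch:

    from bs4 import BeautifulSoup
    from bs4.element import Tag

    class TagPlus(Tag):
        pass

    # Every tag in the resulting tree is a TagPlus instead of a Tag.
    soup = BeautifulSoup("<a><b></b></a>", "html.parser",
                         element_classes={Tag: TagPlus})
    print(type(soup.a).__name__)  # TagPlus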
libs/bs4/tests/test_tree.py
@@ -27,13 +27,17 @@ from bs4.element import (
    Doctype,
    Formatter,
    NavigableString,
+    Script,
    SoupStrainer,
+    Stylesheet,
    Tag,
+    TemplateString,
)
from bs4.testing import (
    SoupTest,
    skipIf,
)
+from soupsieve import SelectorSyntaxError

XML_BUILDER_PRESENT = (builder_registry.lookup("xml") is not None)
LXML_PRESENT = (builder_registry.lookup("lxml") is not None)
@@ -741,6 +745,30 @@ class TestPreviousSibling(SiblingTest):
        self.assertEqual(start.find_previous_sibling(text="nonesuch"), None)


+class TestTag(SoupTest):
+
+    # Test various methods of Tag.
+
+    def test__should_pretty_print(self):
+        # Test the rules about when a tag should be pretty-printed.
+        tag = self.soup("").new_tag("a_tag")
+
+        # No list of whitespace-preserving tags -> pretty-print
+        tag._preserve_whitespace_tags = None
+        self.assertEqual(True, tag._should_pretty_print(0))
+
+        # List exists but tag is not on the list -> pretty-print
+        tag.preserve_whitespace_tags = ["some_other_tag"]
+        self.assertEqual(True, tag._should_pretty_print(1))
+
+        # Indent level is None -> don't pretty-print
+        self.assertEqual(False, tag._should_pretty_print(None))
+
+        # Tag is on the whitespace-preserving list -> don't pretty-print
+        tag.preserve_whitespace_tags = ["some_other_tag", "a_tag"]
+        self.assertEqual(False, tag._should_pretty_print(1))
+
+
class TestTagCreation(SoupTest):
    """Test the ability to create new tags."""
    def test_new_tag(self):
@@ -981,6 +1009,15 @@ class TestTreeModification(SoupTest):
        soup.a.extend(l)
        self.assertEqual("<a><g></g><f></f><e></e><d></d><c></c><b></b></a>", soup.decode())

+    def test_extend_with_another_tags_contents(self):
+        data = '<body><div id="d1"><a>1</a><a>2</a><a>3</a><a>4</a></div><div id="d2"></div></body>'
+        soup = self.soup(data)
+        d1 = soup.find('div', id='d1')
+        d2 = soup.find('div', id='d2')
+        d2.extend(d1)
+        self.assertEqual('<div id="d1"></div>', d1.decode())
+        self.assertEqual('<div id="d2"><a>1</a><a>2</a><a>3</a><a>4</a></div>', d2.decode())

    def test_move_tag_to_beginning_of_parent(self):
        data = "<a><b></b><c></c><d></d></a>"
        soup = self.soup(data)
@@ -1093,6 +1130,37 @@ class TestTreeModification(SoupTest):
        self.assertEqual(no.next_element, "no")
        self.assertEqual(no.next_sibling, " business")

+    def test_replace_with_errors(self):
+        # Can't replace a tag that's not part of a tree.
+        a_tag = Tag(name="a")
+        self.assertRaises(ValueError, a_tag.replace_with, "won't work")
+
+        # Can't replace a tag with its parent.
+        a_tag = self.soup("<a><b></b></a>").a
+        self.assertRaises(ValueError, a_tag.b.replace_with, a_tag)
+
+        # Or with a list that includes its parent.
+        self.assertRaises(ValueError, a_tag.b.replace_with,
+                          "string1", a_tag, "string2")
+
+    def test_replace_with_multiple(self):
+        data = "<a><b></b><c></c></a>"
+        soup = self.soup(data)
+        d_tag = soup.new_tag("d")
+        d_tag.string = "Text In D Tag"
+        e_tag = soup.new_tag("e")
+        f_tag = soup.new_tag("f")
+        a_string = "Random Text"
+        soup.c.replace_with(d_tag, e_tag, a_string, f_tag)
+        self.assertEqual(
+            "<a><b></b><d>Text In D Tag</d><e></e>Random Text<f></f></a>",
+            soup.decode()
+        )
+        assert soup.b.next_element == d_tag
+        assert d_tag.string.next_element == e_tag
+        assert e_tag.next_element.string == a_string
+        assert e_tag.next_element.next_element == f_tag

    def test_replace_first_child(self):
        data = "<a><b></b><c></c></a>"
        soup = self.soup(data)
@@ -1251,6 +1319,23 @@ class TestTreeModification(SoupTest):
        a.clear(decompose=True)
        self.assertEqual(0, len(em.contents))

+    def test_decompose(self):
+        # Test PageElement.decompose() and PageElement.decomposed
+        soup = self.soup("<p><a>String <em>Italicized</em></a></p><p>Another para</p>")
+        p1, p2 = soup.find_all('p')
+        a = p1.a
+        text = p1.em.string
+        for i in [p1, p2, a, text]:
+            self.assertEqual(False, i.decomposed)
+
+        # This sets p1 and everything beneath it to decomposed.
+        p1.decompose()
+        for i in [p1, a, text]:
+            self.assertEqual(True, i.decomposed)
+        # p2 is unaffected.
+        self.assertEqual(False, p2.decomposed)

    def test_string_set(self):
        """Tag.string = 'string'"""
        soup = self.soup("<a></a> <b><c></c></b>")
@@ -1367,7 +1452,7 @@ class TestElementObjects(SoupTest):
        self.assertEqual(soup.a.get_text(","), "a,r, , t ")
        self.assertEqual(soup.a.get_text(",", strip=True), "a,r,t")

-    def test_get_text_ignores_comments(self):
+    def test_get_text_ignores_special_string_containers(self):
        soup = self.soup("foo<!--IGNORE-->bar")
        self.assertEqual(soup.get_text(), "foobar")

@@ -1376,10 +1461,51 @@ class TestElementObjects(SoupTest):
        self.assertEqual(
            soup.get_text(types=None), "fooIGNOREbar")

-    def test_all_strings_ignores_comments(self):
+        soup = self.soup("foo<style>CSS</style><script>Javascript</script>bar")
+        self.assertEqual(soup.get_text(), "foobar")
+
+    def test_all_strings_ignores_special_string_containers(self):
        soup = self.soup("foo<!--IGNORE-->bar")
        self.assertEqual(['foo', 'bar'], list(soup.strings))

+        soup = self.soup("foo<style>CSS</style><script>Javascript</script>bar")
+        self.assertEqual(['foo', 'bar'], list(soup.strings))
+
+    def test_string_methods_inside_special_string_container_tags(self):
+        # Strings inside tags like <script> are generally ignored by
+        # methods like get_text, because they're not what humans
+        # consider 'text'. But if you call get_text on the <script>
+        # tag itself, those strings _are_ considered to be 'text',
+        # because there's nothing else you might be looking for.
+
+        style = self.soup("<div>a<style>Some CSS</style></div>")
+        template = self.soup("<div>a<template><p>Templated <b>text</b>.</p><!--With a comment.--></template></div>")
+        script = self.soup("<div>a<script><!--a comment-->Some text</script></div>")
+
+        self.assertEqual(style.div.get_text(), "a")
+        self.assertEqual(list(style.div.strings), ["a"])
+        self.assertEqual(style.div.style.get_text(), "Some CSS")
+        self.assertEqual(list(style.div.style.strings),
+                         ['Some CSS'])
+
+        # The comment is not picked up here. That's because it was
+        # parsed into a Comment object, which is not considered
+        # interesting by template.strings.
+        self.assertEqual(template.div.get_text(), "a")
+        self.assertEqual(list(template.div.strings), ["a"])
+        self.assertEqual(template.div.template.get_text(), "Templated text.")
+        self.assertEqual(list(template.div.template.strings),
+                         ["Templated ", "text", "."])
+
+        # The comment is included here, because it didn't get parsed
+        # into a Comment object--it's part of the Script string.
+        self.assertEqual(script.div.get_text(), "a")
+        self.assertEqual(list(script.div.strings), ["a"])
+        self.assertEqual(script.div.script.get_text(),
+                         "<!--a comment-->Some text")
+        self.assertEqual(list(script.div.script.strings),
+                         ['<!--a comment-->Some text'])

class TestCDAtaListAttributes(SoupTest):

    """Testing cdata-list attributes like 'class'.
@@ -1466,6 +1592,31 @@ class TestPersistence(SoupTest):
        self.assertEqual("<p> </p>", str(copy))
        self.assertEqual(encoding, copy.original_encoding)

+    def test_copy_preserves_builder_information(self):
+
+        tag = self.soup('<p></p>').p
+
+        # Simulate a tag obtained from a source file.
+        tag.sourceline = 10
+        tag.sourcepos = 33
+
+        copied = tag.__copy__()
+
+        # The TreeBuilder object is no longer available, but information
+        # obtained from it gets copied over to the new Tag object.
+        self.assertEqual(tag.sourceline, copied.sourceline)
+        self.assertEqual(tag.sourcepos, copied.sourcepos)
+        self.assertEqual(
+            tag.can_be_empty_element, copied.can_be_empty_element
+        )
+        self.assertEqual(
+            tag.cdata_list_attributes, copied.cdata_list_attributes
+        )
+        self.assertEqual(
+            tag.preserve_whitespace_tags, copied.preserve_whitespace_tags
+        )

    def test_unicode_pickle(self):
        # A tree containing Unicode characters can be pickled.
        html = "<b>\N{SNOWMAN}</b>"
@@ -1726,71 +1877,7 @@ class TestEncoding(SoupTest):
        else:
            self.assertEqual(b'<b>\\u2603</b>', repr(soup))

-class TestFormatter(SoupTest):
-
-    def test_sort_attributes(self):
-        # Test the ability to override Formatter.attributes() to,
-        # e.g., disable the normal sorting of attributes.
-        class UnsortedFormatter(Formatter):
-            def attributes(self, tag):
-                self.called_with = tag
-                for k, v in sorted(tag.attrs.items()):
-                    if k == 'ignore':
-                        continue
-                    yield k, v
-
-        soup = self.soup('<p cval="1" aval="2" ignore="ignored"></p>')
-        formatter = UnsortedFormatter()
-        decoded = soup.decode(formatter=formatter)
-
-        # attributes() was called on the <p> tag. It filtered out one
-        # attribute and sorted the other two.
-        self.assertEqual(formatter.called_with, soup.p)
-        self.assertEqual('<p aval="2" cval="1"></p>', decoded)
-
-
-class TestNavigableStringSubclasses(SoupTest):
-
-    def test_cdata(self):
-        # None of the current builders turn CDATA sections into CData
-        # objects, but you can create them manually.
-        soup = self.soup("")
-        cdata = CData("foo")
-        soup.insert(1, cdata)
-        self.assertEqual(str(soup), "<![CDATA[foo]]>")
-        self.assertEqual(soup.find(text="foo"), "foo")
-        self.assertEqual(soup.contents[0], "foo")
-
-    def test_cdata_is_never_formatted(self):
-        """Text inside a CData object is passed into the formatter.
-
-        But the return value is ignored.
-        """
-
-        self.count = 0
-        def increment(*args):
-            self.count += 1
-            return "BITTER FAILURE"
-
-        soup = self.soup("")
-        cdata = CData("<><><>")
-        soup.insert(1, cdata)
-        self.assertEqual(
-            b"<![CDATA[<><><>]]>", soup.encode(formatter=increment))
-        self.assertEqual(1, self.count)
-
-    def test_doctype_ends_in_newline(self):
-        # Unlike other NavigableString subclasses, a DOCTYPE always ends
-        # in a newline.
-        doctype = Doctype("foo")
-        soup = self.soup("")
-        soup.insert(1, doctype)
-        self.assertEqual(soup.encode(), b"<!DOCTYPE foo>\n")
-
-    def test_declaration(self):
-        d = Declaration("foo")
-        self.assertEqual("<?foo?>", d.output_ready())
-

class TestSoupSelector(TreeTest):

    HTML = """
@@ -1900,7 +1987,7 @@ class TestSoupSelector(TreeTest):
        self.assertEqual(len(self.soup.select('del')), 0)

    def test_invalid_tag(self):
-        self.assertRaises(SyntaxError, self.soup.select, 'tag%t')
+        self.assertRaises(SelectorSyntaxError, self.soup.select, 'tag%t')

    def test_select_dashed_tag_ids(self):
        self.assertSelects('custom-dashed-tag', ['dash1', 'dash2'])
@@ -2091,7 +2178,7 @@ class TestSoupSelector(TreeTest):
            NotImplementedError, self.soup.select, "a:no-such-pseudoclass")

        self.assertRaises(
-            SyntaxError, self.soup.select, "a:nth-of-type(a)")
+            SelectorSyntaxError, self.soup.select, "a:nth-of-type(a)")

    def test_nth_of_type(self):
        # Try to select first paragraph
@@ -2147,7 +2234,7 @@ class TestSoupSelector(TreeTest):
        self.assertEqual([], self.soup.select('#inner ~ h2'))

    def test_dangling_combinator(self):
-        self.assertRaises(SyntaxError, self.soup.select, 'h1 >')
+        self.assertRaises(SelectorSyntaxError, self.soup.select, 'h1 >')

    def test_sibling_combinator_wont_select_same_tag_twice(self):
        self.assertSelects('p[lang] ~ p', ['lang-en-gb', 'lang-en-us', 'lang-fr'])
@@ -2178,8 +2265,8 @@ class TestSoupSelector(TreeTest):
        self.assertSelects('div x,y, z', ['xid', 'yid', 'zida', 'zidb', 'zidab', 'zidac'])

    def test_invalid_multiple_select(self):
-        self.assertRaises(SyntaxError, self.soup.select, ',x, y')
-        self.assertRaises(SyntaxError, self.soup.select, 'x,,y')
+        self.assertRaises(SelectorSyntaxError, self.soup.select, ',x, y')
+        self.assertRaises(SelectorSyntaxError, self.soup.select, 'x,,y')

    def test_multiple_select_attrs(self):
        self.assertSelects('p[lang=en], p[lang=en-gb]', ['lang-en', 'lang-en-gb'])
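test_replace_with_multiple above relies on replace_with() accepting several replacements at once, which (to my knowledge) arrived in bs4 4.10 — another hint at the version being vendored here. A short sketch:

    from bs4 import BeautifulSoup

    soup = BeautifulSoup("<a><b></b><c></c></a>", "html.parser")
    d = soup.new_tag("d")
    # One call replaces <c> with both the new tag and a string.
    soup.c.replace_with(d, "text")
    print(soup.decode())  # <a><b></b><d></d>text</a>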
libs/certifi/__init__.py
@@ -1,3 +1,3 @@
-from .core import where
+from .core import contents, where

-__version__ = "2019.09.11"
+__version__ = "2021.10.08"

libs/certifi/__main__.py
@@ -1,2 +1,12 @@
-from certifi import where
-print(where())
+import argparse
+
+from certifi import contents, where
+
+parser = argparse.ArgumentParser()
+parser.add_argument("-c", "--contents", action="store_true")
+args = parser.parse_args()
+
+if args.contents:
+    print(contents())
+else:
+    print(where())
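The new certifi entry point adds a -c/--contents flag alongside the existing path lookup. The equivalent library calls, assuming a certifi recent enough to expose contents():

    import certifi

    print(certifi.where())          # filesystem path to the bundled cacert.pem
    print(certifi.contents()[:60])  # the PEM certificate data itself, as a string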
Some files were not shown because too many files have changed in this diff.