bazarr/libs/dns/rdtypes/util.py

245 lines
8.5 KiB
Python

# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
# Copyright (C) 2006, 2007, 2009-2011 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
import collections
import random
import struct
import dns.exception
import dns.ipv4
import dns.ipv6
import dns.name
import dns.rdata
class Gateway:
    """A helper class for the IPSECKEY gateway and AMTRELAY relay fields.

    The ``type`` attribute selects how ``gateway`` is interpreted:

    * 0 -- no gateway; ``gateway`` is stored as ``None`` and rendered as ".".
    * 1 -- ``gateway`` is IPv4 address text accepted by ``dns.ipv4.inet_aton``.
    * 2 -- ``gateway`` is IPv6 address text accepted by ``dns.ipv6.inet_aton``.
    * 3 -- ``gateway`` is a ``dns.name.Name``.
    """

    # Label interpolated into the error messages below; empty in this class.
    name = ""

    def __init__(self, type, gateway=None):
        """Store *type* (coerced via ``Rdata._as_uint8``) and *gateway*,
        then validate the combination with ``_check()``."""
        self.type = dns.rdata.Rdata._as_uint8(type)
        self.gateway = gateway
        self._check()

    @classmethod
    def _invalid_type(cls, gateway_type):
        """Return the error text used for an unsupported *gateway_type*."""
        return f"invalid {cls.name} type: {gateway_type}"

    def _check(self):
        """Validate ``self.gateway`` against ``self.type``.

        For type 0 the gateway is normalized to ``None``.  Raises the
        built-in ``SyntaxError`` for a bad type-0 or type-3 value or an
        unknown type; for types 1 and 2 whatever the address parsers raise
        propagates unchanged.
        """
        kind = self.type
        if kind == 0:
            if self.gateway not in (".", None):
                raise SyntaxError(f"invalid {self.name} for type 0")
            self.gateway = None
            return
        if kind == 1:
            # Parsing is done purely for validation; the result is discarded.
            dns.ipv4.inet_aton(self.gateway)
            return
        if kind == 2:
            # Parsing is done purely for validation; the result is discarded.
            dns.ipv6.inet_aton(self.gateway)
            return
        if kind == 3:
            if isinstance(self.gateway, dns.name.Name):
                return
            raise SyntaxError(f"invalid {self.name}; not a name")
        raise SyntaxError(self._invalid_type(self.type))

    def to_text(self, origin=None, relativize=True):
        """Return the textual form of the gateway.

        Type 0 yields ".", types 1 and 2 yield the stored address text, and
        type 3 yields the name after ``choose_relativity(origin, relativize)``.
        """
        kind = self.type
        if kind == 0:
            return "."
        if kind in (1, 2):
            return self.gateway
        if kind == 3:
            return str(self.gateway.choose_relativity(origin, relativize))
        raise ValueError(self._invalid_type(self.type))  # pragma: no cover

    @classmethod
    def from_text(cls, gateway_type, tok, origin=None, relativize=True,
                  relativize_to=None):
        """Build an instance by reading the gateway from tokenizer *tok*.

        Types 0, 1 and 2 consume a string token; type 3 consumes a name.
        Any other type raises ``dns.exception.SyntaxError``.
        """
        if gateway_type == 3:
            gateway = tok.get_name(origin, relativize, relativize_to)
        elif gateway_type in (0, 1, 2):
            gateway = tok.get_string()
        else:
            raise dns.exception.SyntaxError(
                cls._invalid_type(gateway_type))  # pragma: no cover
        return cls(gateway_type, gateway)

    # pylint: disable=unused-argument
    def to_wire(self, file, compress=None, origin=None, canonicalize=False):
        """Write the wire form of the gateway to *file*.

        Nothing is written for type 0.  Addresses are written in packed
        form.  A type-3 name is written with ``None`` for compression and
        ``False`` for canonicalization; the *compress* and *canonicalize*
        arguments are not consulted.
        """
        kind = self.type
        if kind == 0:
            return
        if kind == 1:
            file.write(dns.ipv4.inet_aton(self.gateway))
            return
        if kind == 2:
            file.write(dns.ipv6.inet_aton(self.gateway))
            return
        if kind == 3:
            self.gateway.to_wire(file, None, origin, False)
            return
        raise ValueError(self._invalid_type(self.type))  # pragma: no cover
    # pylint: enable=unused-argument

    @classmethod
    def from_wire_parser(cls, gateway_type, parser, origin=None):
        """Build an instance by reading the gateway from *parser*.

        Type 1 reads 4 bytes, type 2 reads 16 bytes, type 3 reads a name and
        type 0 reads nothing.  Any other type raises
        ``dns.exception.FormError``.
        """
        if gateway_type == 0:
            return cls(gateway_type, None)
        if gateway_type == 1:
            return cls(gateway_type, dns.ipv4.inet_ntoa(parser.get_bytes(4)))
        if gateway_type == 2:
            return cls(gateway_type, dns.ipv6.inet_ntoa(parser.get_bytes(16)))
        if gateway_type == 3:
            return cls(gateway_type, parser.get_name(origin))
        raise dns.exception.FormError(cls._invalid_type(gateway_type))
class Bitmap:
    """A helper class for the NSEC/NSEC3/CSYNC type bitmaps.

    ``windows`` holds ``(window, bitmap)`` pairs where ``window`` is an
    integer window number and ``bitmap`` is a ``bytes`` object of 1 to 32
    octets.  Window numbers must be strictly increasing.
    """

    # Label interpolated into the error messages below; empty in this class.
    type_name = ""

    def __init__(self, windows=None):
        """Store and validate *windows*.

        *windows* defaults to an empty list when omitted or ``None``.

        Raises ``ValueError`` if a window number is not an ``int``, is not
        greater than the previous one, or exceeds 255, or if a bitmap is not
        ``bytes`` or is empty or longer than 32 octets.
        """
        if windows is None:
            # The signature advertises None as the default, so make it usable
            # instead of failing when iterated below.
            windows = []
        self.windows = windows
        last_window = -1
        for (window, bitmap) in self.windows:
            if not isinstance(window, int):
                raise ValueError(f"bad {self.type_name} window type")
            # last_window starts at -1, so this also rejects negative numbers.
            if window <= last_window:
                raise ValueError(f"bad {self.type_name} window order")
            # to_wire() packs the window number as a single unsigned byte,
            # so 255 is the largest value that can be represented.
            if window > 255:
                raise ValueError(f"bad {self.type_name} window number")
            last_window = window
            if not isinstance(bitmap, bytes):
                raise ValueError(f"bad {self.type_name} octets type")
            if len(bitmap) == 0 or len(bitmap) > 32:
                raise ValueError(f"bad {self.type_name} octets")

    def to_text(self):
        """Return the type mnemonics for every set bit.

        Each window contributes a leading space followed by its
        space-separated mnemonics; an empty bitmap list yields ``""``.
        """
        parts = []
        for (window, bitmap) in self.windows:
            bits = []
            for (i, byte) in enumerate(bitmap):
                for j in range(0, 8):
                    # Bit 0 is the most significant bit of each octet.
                    if byte & (0x80 >> j):
                        rdtype = window * 256 + i * 8 + j
                        bits.append(dns.rdatatype.to_text(rdtype))
            parts.append(' ' + ' '.join(bits))
        return ''.join(parts)

    @classmethod
    def from_text(cls, tok):
        """Build an instance from the remaining tokens of tokenizer *tok*.

        Every remaining token is converted with ``dns.rdatatype.from_text``.
        Duplicate types are ignored.  Raises ``dns.exception.SyntaxError``
        if any token maps to type 0.
        """
        rdtypes = []
        for token in tok.get_remaining():
            rdtype = dns.rdatatype.from_text(token.unescape().value)
            if rdtype == 0:
                raise dns.exception.SyntaxError(f"{cls.type_name} with bit 0")
            rdtypes.append(rdtype)
        # Sorting lets the windows be emitted in increasing order in one pass.
        rdtypes.sort()
        window = 0
        octets = 0
        prior_rdtype = 0
        bitmap = bytearray(b'\0' * 32)
        windows = []
        for rdtype in rdtypes:
            if rdtype == prior_rdtype:
                continue
            prior_rdtype = rdtype
            new_window = rdtype // 256
            if new_window != window:
                # Flush the finished window (if it has any bits) and start
                # a fresh, zeroed bitmap for the next one.
                if octets != 0:
                    windows.append((window, bytes(bitmap[0:octets])))
                bitmap = bytearray(b'\0' * 32)
                window = new_window
            offset = rdtype % 256
            byte = offset // 8
            bit = offset % 8
            # Types arrive sorted, so the latest byte index is the highest
            # one used in this window.
            octets = byte + 1
            bitmap[byte] = bitmap[byte] | (0x80 >> bit)
        if octets != 0:
            windows.append((window, bytes(bitmap[0:octets])))
        return cls(windows)

    def to_wire(self, file):
        """Write each window to *file* as window number, bitmap length
        (one unsigned byte each) and the bitmap octets."""
        for (window, bitmap) in self.windows:
            file.write(struct.pack('!BB', window, len(bitmap)))
            file.write(bitmap)

    @classmethod
    def from_wire_parser(cls, parser):
        """Build an instance by reading windows from *parser* until it has
        no data remaining."""
        windows = []
        while parser.remaining() > 0:
            window = parser.get_uint8()
            bitmap = parser.get_counted_bytes()
            windows.append((window, bitmap))
        return cls(windows)
def _priority_table(items):
    """Group *items* by the value of each item's ``_processing_priority()``.

    Returns a ``collections.defaultdict(list)`` mapping each priority to the
    items that reported it, in their original relative order.
    """
    table = collections.defaultdict(list)
    for item in items:
        bucket = table[item._processing_priority()]
        bucket.append(item)
    return table
def priority_processing_order(iterable):
    """Return the items of *iterable* as a list in processing order.

    Items are grouped by ``_processing_priority()`` and the groups are
    emitted in ascending priority order; the items inside each group are
    shuffled with ``random.shuffle``.  A single-item input is returned as a
    one-element list without inspecting the item.
    """
    items = list(iterable)
    if len(items) == 1:
        return items
    result = []
    table = _priority_table(items)
    for priority in sorted(table.keys()):
        group = table[priority]
        random.shuffle(group)
        result.extend(group)
    return result
# Weight substituted whenever an item's _processing_weight() is falsy
# (for example 0), so that such items can still be selected.
_no_weight = 0.1


def weighted_processing_order(iterable):
    """Return the items of *iterable* as a list in weighted processing order.

    Items are grouped by ``_processing_priority()`` and the groups are
    emitted in ascending priority order.  Inside a group, items are drawn
    one at a time without replacement: a point is chosen uniformly between
    0 and the remaining total weight, and the group is walked, subtracting
    each item's weight, until an item's weight exceeds what is left of the
    point.  A single-item input is returned as a one-element list without
    inspecting the item.
    """
    items = list(iterable)
    if len(items) == 1:
        return items
    result = []
    table = _priority_table(items)
    for priority in sorted(table.keys()):
        pool = table[priority]
        remaining = sum(entry._processing_weight() or _no_weight
                        for entry in pool)
        while len(pool) > 1:
            point = random.uniform(0, remaining)
            for (index, candidate) in enumerate(pool):
                weight = candidate._processing_weight() or _no_weight
                if weight > point:
                    break
                point -= weight
            # If the walk ran off the end without breaking, the loop
            # variables still refer to the last item, which is then chosen.
            remaining -= weight
            # pylint: disable=undefined-loop-variable
            result.append(pool.pop(index))
            # pylint: enable=undefined-loop-variable
        result.append(pool[0])
    return result
def parse_formatted_hex(formatted, num_chunks, chunk_size, separator):
    """Convert separator-delimited hexadecimal text into bytes.

    *formatted* is a text string that must consist of exactly *num_chunks*
    groups of *chunk_size* hexadecimal digits, with one *separator*
    character between consecutive groups (for example ``"00-1a-ff"`` with
    ``num_chunks=3, chunk_size=2, separator="-"``).

    Returns the groups concatenated, each converted to ``chunk_size // 2``
    big-endian bytes.

    Raises ``ValueError`` if the overall length is wrong, a group contains
    anything other than ASCII hexadecimal digits, or a separator position
    holds a different character.
    """
    if len(formatted) != num_chunks * (chunk_size + 1) - 1:
        raise ValueError('invalid formatted hex string')
    hex_digits = frozenset('0123456789abcdefABCDEF')
    stride = chunk_size + 1
    pieces = []
    for index in range(num_chunks):
        start = index * stride
        end = start + chunk_size
        chunk = formatted[start:end]
        # int() on its own is too lenient: it accepts signs, a "0x" prefix,
        # underscores, surrounding whitespace and non-ASCII digits, and a
        # negative value would surface as OverflowError from to_bytes().
        if not hex_digits.issuperset(chunk):
            raise ValueError('invalid formatted hex string')
        pieces.append(int(chunk, 16).to_bytes(chunk_size // 2, 'big'))
        # Every group except the last must be followed by the separator.
        if end < len(formatted) and formatted[end] != separator:
            raise ValueError('invalid formatted hex string')
    return b''.join(pieces)