mirror of https://github.com/morpheus65535/bazarr
150 lines
4.8 KiB
Python
150 lines
4.8 KiB
Python
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
|
|
|
|
"""asyncio library query support"""
|
|
|
|
import socket
|
|
import asyncio
|
|
import sys
|
|
|
|
import dns._asyncbackend
|
|
import dns.exception
|
|
|
|
|
|
# asyncio cannot create destinationless datagram endpoints on Windows, so the
# backend below needs to know which platform it is running on.
_is_win32 = (sys.platform == 'win32')
|
|
|
|
def _get_running_loop():
    """Return the event loop the current coroutine should use.

    ``asyncio.get_running_loop()`` is missing on interpreters older than
    Python 3.7; in that case fall back to ``asyncio.get_event_loop()``.
    """
    try:
        loop_getter = asyncio.get_running_loop
    except AttributeError:  # pragma: no cover
        loop_getter = asyncio.get_event_loop
    return loop_getter()
|
|
|
|
|
|
class _DatagramProtocol:
    """asyncio datagram protocol that hands received datagrams to a waiter.

    A caller installs a future in ``recvfrom``. The next datagram, transport
    error, or loss of the connection resolves it and deregisters it, so that
    another receive can be started afterwards.
    """

    def __init__(self):
        # Set by connection_made() once the event loop creates the transport.
        self.transport = None
        # Future awaiting the next (data, addr) pair, or None when no receive
        # is outstanding.
        self.recvfrom = None

    def connection_made(self, transport):
        self.transport = transport

    def datagram_received(self, data, addr):
        if self.recvfrom and not self.recvfrom.done():
            self.recvfrom.set_result((data, addr))
            self.recvfrom = None

    def _fail_pending(self, exc):
        """Fail the outstanding receive future (if any) with *exc* and
        deregister it."""
        if self.recvfrom and not self.recvfrom.done():
            self.recvfrom.set_exception(exc)
            self.recvfrom = None

    def error_received(self, exc):  # pragma: no cover
        self._fail_pending(exc)

    def connection_lost(self, exc):
        # asyncio passes exc=None when the transport was closed normally
        # (e.g. by close() below).  Future.set_exception(None) is invalid and
        # would raise TypeError inside the event loop while leaving the
        # waiter hanging, so report end-of-file instead.
        if exc is None:
            exc = EOFError()
        self._fail_pending(exc)

    def close(self):
        self.transport.close()
|
|
|
|
|
|
async def _maybe_wait_for(awaitable, timeout):
    """Await *awaitable*, bounded by *timeout* seconds unless it is None.

    ``None`` means "wait indefinitely".  Every other value, including 0 (for
    instance a deadline that has already passed), is enforced with
    ``asyncio.wait_for``.  Previously a falsy 0 was mistaken for "no timeout"
    and could block forever.

    Raises:
        dns.exception.Timeout: if the timeout elapses before completion.
    """
    if timeout is None:
        return await awaitable
    try:
        return await asyncio.wait_for(awaitable, timeout)
    except asyncio.TimeoutError:
        raise dns.exception.Timeout(timeout=timeout)
|
|
|
|
|
|
class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket built on an asyncio transport and a _DatagramProtocol."""

    def __init__(self, family, transport, protocol):
        self.family = family
        self.transport = transport
        self.protocol = protocol

    async def sendto(self, what, destination, timeout):  # pragma: no cover
        # no timeout for asyncio sendto
        self.transport.sendto(what, destination)

    async def recvfrom(self, size, timeout):
        """Wait for the next datagram and return it as ``(data, addr)``.

        *size* is ignored: there is no way to pass it on to the protocol.

        Raises:
            dns.exception.Timeout: if *timeout* elapses first.
        """
        done = _get_running_loop().create_future()
        # Only one receive may be outstanding on this socket at a time.
        assert self.protocol.recvfrom is None
        self.protocol.recvfrom = done
        try:
            await _maybe_wait_for(done, timeout)
            return done.result()
        finally:
            # Always deregister, including on timeout, cancellation or error.
            # Otherwise the dead future stays installed and the assertion
            # above fails on the next call.
            self.protocol.recvfrom = None

    async def close(self):
        self.protocol.close()

    async def getpeername(self):
        return self.transport.get_extra_info('peername')

    async def getsockname(self):
        return self.transport.get_extra_info('sockname')
|
|
|
|
|
|
class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket wrapping an asyncio ``StreamReader``/``StreamWriter``."""

    def __init__(self, af, reader, writer):
        self.family = af
        self.reader = reader
        self.writer = writer

    async def sendall(self, what, timeout):
        """Buffer *what* on the writer, then wait for the writer to drain,
        bounded by *timeout*."""
        writer = self.writer
        writer.write(what)
        draining = writer.drain()
        return await _maybe_wait_for(draining, timeout)

    async def recv(self, size, timeout):
        """Read up to *size* bytes from the reader, bounded by *timeout*."""
        pending_read = self.reader.read(size)
        return await _maybe_wait_for(pending_read, timeout)

    async def close(self):
        """Close the writer and wait for the close to finish if supported."""
        writer = self.writer
        writer.close()
        try:
            await writer.wait_closed()
        except AttributeError:  # pragma: no cover
            # StreamWriter.wait_closed() does not exist on older Pythons.
            pass

    async def getpeername(self):
        """Return the remote address recorded on the writer's transport."""
        peer = self.writer.get_extra_info('peername')
        return peer

    async def getsockname(self):
        """Return the local address recorded on the writer's transport."""
        local = self.writer.get_extra_info('sockname')
        return local
|
|
|
|
|
|
class Backend(dns._asyncbackend.Backend):
    """dnspython async backend implemented on top of the asyncio library."""

    def name(self):
        return 'asyncio'

    async def make_socket(self, af, socktype, proto=0,
                          source=None, destination=None, timeout=None,
                          ssl_context=None, server_hostname=None):
        """Create a DatagramSocket (SOCK_DGRAM) or StreamSocket (SOCK_STREAM).

        *source* is the local address to bind and *destination* the remote
        address.  *timeout* bounds only the stream connection attempt.
        *ssl_context* and *server_hostname* are forwarded to
        ``asyncio.open_connection``.

        Raises:
            NotImplementedError: for a datagram socket without a destination
                on Windows, or for any other socket type.
        """
        if destination is None and socktype == socket.SOCK_DGRAM and \
                _is_win32:
            raise NotImplementedError('destinationless datagram sockets '
                                      'are not supported by asyncio '
                                      'on Windows')
        loop = _get_running_loop()

        if socktype == socket.SOCK_DGRAM:
            transport, protocol = await loop.create_datagram_endpoint(
                _DatagramProtocol, local_addr=source, family=af,
                proto=proto, remote_addr=destination)
            return DatagramSocket(af, transport, protocol)

        if socktype == socket.SOCK_STREAM:
            host = destination[0]
            port = destination[1]
            connecting = asyncio.open_connection(
                host, port,
                ssl=ssl_context,
                family=af,
                proto=proto,
                local_addr=source,
                server_hostname=server_hostname)
            reader, writer = await _maybe_wait_for(connecting, timeout)
            return StreamSocket(af, reader, writer)

        raise NotImplementedError(
            f'unsupported socket type {socktype}')  # pragma: no cover

    async def sleep(self, interval):
        await asyncio.sleep(interval)

    def datagram_connection_required(self):
        # make_socket() refuses destinationless datagram sockets on Windows.
        return _is_win32
|