bazarr/libs/dns/_asyncio_backend.py

150 lines
4.8 KiB
Python

# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""asyncio library query support"""
import socket
import asyncio
import sys
import dns._asyncbackend
import dns.exception
# True when running on Windows; make_socket() and
# datagram_connection_required() below consult this flag.
_is_win32 = sys.platform == 'win32'
def _get_running_loop():
try:
return asyncio.get_running_loop()
except AttributeError: # pragma: no cover
return asyncio.get_event_loop()
class _DatagramProtocol:
def __init__(self):
self.transport = None
self.recvfrom = None
def connection_made(self, transport):
self.transport = transport
def datagram_received(self, data, addr):
if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_result((data, addr))
self.recvfrom = None
def error_received(self, exc): # pragma: no cover
if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_exception(exc)
def connection_lost(self, exc):
if self.recvfrom and not self.recvfrom.done():
self.recvfrom.set_exception(exc)
def close(self):
self.transport.close()
async def _maybe_wait_for(awaitable, timeout):
if timeout:
try:
return await asyncio.wait_for(awaitable, timeout)
except asyncio.TimeoutError:
raise dns.exception.Timeout(timeout=timeout)
else:
return await awaitable
class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket backed by an asyncio transport/protocol pair."""

    def __init__(self, family, transport, protocol):
        self.family = family
        self.transport = transport
        self.protocol = protocol

    async def sendto(self, what, destination, timeout):  # pragma: no cover
        """Hand *what* to the transport for delivery to *destination*."""
        # no timeout for asyncio sendto
        self.transport.sendto(what, destination)

    async def recvfrom(self, size, timeout):
        """Wait for the next datagram and return ``(data, address)``.

        *size* is ignored.  Raises ``dns.exception.Timeout`` (via
        ``_maybe_wait_for``) when *timeout* elapses first.
        """
        # ignore size as there's no way I know to tell protocol about it
        done = _get_running_loop().create_future()
        # Only one receive may be outstanding on the protocol at a time.
        assert self.protocol.recvfrom is None
        self.protocol.recvfrom = done
        try:
            await _maybe_wait_for(done, timeout)
            return done.result()
        finally:
            # Always release the slot: after a timeout or an error the
            # finished future would otherwise stay registered and the
            # next recvfrom() would fail the assertion above.
            self.protocol.recvfrom = None

    async def close(self):
        """Close the socket through its protocol."""
        self.protocol.close()

    async def getpeername(self):
        """Return the transport's 'peername' extra info."""
        return self.transport.get_extra_info('peername')

    async def getsockname(self):
        """Return the transport's 'sockname' extra info."""
        return self.transport.get_extra_info('sockname')
class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket wrapping a reader object and a writer object."""

    def __init__(self, af, reader, writer):
        self.family = af
        self.reader = reader
        self.writer = writer

    async def sendall(self, what, timeout):
        """Write *what*, then wait (bounded by *timeout*) for the drain."""
        self.writer.write(what)
        draining = self.writer.drain()
        return await _maybe_wait_for(draining, timeout)

    async def recv(self, size, timeout):
        """Read up to *size* bytes, bounded by *timeout*."""
        reading = self.reader.read(size)
        return await _maybe_wait_for(reading, timeout)

    async def close(self):
        """Close the writer and wait for the close to finish."""
        self.writer.close()
        try:
            await self.writer.wait_closed()
        except AttributeError:  # pragma: no cover
            # Tolerated: nothing more to wait for in this case.
            pass

    async def getpeername(self):
        """Return the writer's 'peername' extra info."""
        peer = self.writer.get_extra_info('peername')
        return peer

    async def getsockname(self):
        """Return the writer's 'sockname' extra info."""
        local = self.writer.get_extra_info('sockname')
        return local
class Backend(dns._asyncbackend.Backend):
    """Backend that creates sockets on the running asyncio event loop."""

    def name(self):
        """Return the backend's name, ``'asyncio'``."""
        return 'asyncio'

    async def make_socket(self, af, socktype, proto=0,
                          source=None, destination=None, timeout=None,
                          ssl_context=None, server_hostname=None):
        """Create a datagram or stream socket.

        For ``socket.SOCK_DGRAM`` a datagram endpoint is created, bound to
        *source* and, when given, connected to *destination*; *timeout* is
        not applied to this step.  For ``socket.SOCK_STREAM`` a connection
        to *destination* (a sequence whose first two items are host and
        port) is opened, bounded by *timeout*, with *ssl_context* and
        *server_hostname* passed through to ``asyncio.open_connection``.

        Returns a ``DatagramSocket`` or a ``StreamSocket``.

        Raises ``NotImplementedError`` for a destinationless datagram
        socket on Windows or for any other socket type, and ``ValueError``
        for a stream socket without a destination.
        """
        if destination is None and socktype == socket.SOCK_DGRAM and \
           _is_win32:
            raise NotImplementedError('destinationless datagram sockets '
                                      'are not supported by asyncio '
                                      'on Windows')
        loop = _get_running_loop()
        if socktype == socket.SOCK_DGRAM:
            transport, protocol = await loop.create_datagram_endpoint(
                _DatagramProtocol, source, family=af,
                proto=proto, remote_addr=destination)
            return DatagramSocket(af, transport, protocol)
        elif socktype == socket.SOCK_STREAM:
            if destination is None:
                # Fail with a clear message instead of the opaque
                # "'NoneType' object is not subscriptable" below.
                raise ValueError('destination required for stream sockets')
            (r, w) = await _maybe_wait_for(
                asyncio.open_connection(destination[0],
                                        destination[1],
                                        ssl=ssl_context,
                                        family=af,
                                        proto=proto,
                                        local_addr=source,
                                        server_hostname=server_hostname),
                timeout)
            return StreamSocket(af, r, w)
        raise NotImplementedError('unsupported socket ' +
                                  f'type {socktype}')  # pragma: no cover

    async def sleep(self, interval):
        """Suspend the calling coroutine for *interval* seconds."""
        await asyncio.sleep(interval)

    def datagram_connection_required(self):
        """Return True on Windows, where datagram sockets need a destination."""
        return _is_win32