Upgraded engine.io module to improve socket.io connection stability. Should help to prevent #1613.

This commit is contained in:
morpheus65535 2021-11-29 23:07:14 -05:00
parent a7a685491a
commit c60c7513a5
16 changed files with 235 additions and 120 deletions

View File

@ -28,7 +28,8 @@ def create_app():
else:
app.config["DEBUG"] = False
socketio.init_app(app, path=base_url.rstrip('/')+'/api/socket.io', cors_allowed_origins='*', async_mode='threading')
socketio.init_app(app, path=base_url.rstrip('/')+'/api/socket.io', cors_allowed_origins='*',
async_mode='threading', allow_upgrades=False, transports='polling')
return app

View File

@ -17,7 +17,7 @@ else: # pragma: no cover
get_tornado_handler = None
ASGIApp = None
__version__ = '4.0.2dev'
__version__ = '4.2.1dev'
__all__ = ['__version__', 'Server', 'WSGIApp', 'Middleware', 'Client']
if AsyncServer is not None: # pragma: no cover

View File

@ -43,19 +43,23 @@ class ASGIApp:
on_startup=None, on_shutdown=None):
self.engineio_server = engineio_server
self.other_asgi_app = other_asgi_app
self.engineio_path = engineio_path.strip('/')
self.engineio_path = engineio_path
if not self.engineio_path.startswith('/'):
self.engineio_path = '/' + self.engineio_path
if not self.engineio_path.endswith('/'):
self.engineio_path += '/'
self.static_files = static_files or {}
self.on_startup = on_startup
self.on_shutdown = on_shutdown
async def __call__(self, scope, receive, send):
if scope['type'] in ['http', 'websocket'] and \
scope['path'].startswith('/{0}/'.format(self.engineio_path)):
scope['path'].startswith(self.engineio_path):
await self.engineio_server.handle_request(scope, receive, send)
else:
static_file = get_static_file(scope['path'], self.static_files) \
if scope['type'] == 'http' and self.static_files else None
if static_file:
if static_file and os.path.exists(static_file['filename']):
await self.serve_static_file(static_file, receive, send)
elif self.other_asgi_app is not None:
await self.other_asgi_app(scope, receive, send)
@ -68,17 +72,14 @@ class ASGIApp:
send): # pragma: no cover
event = await receive()
if event['type'] == 'http.request':
if os.path.exists(static_file['filename']):
with open(static_file['filename'], 'rb') as f:
payload = f.read()
await send({'type': 'http.response.start',
'status': 200,
'headers': [(b'Content-Type', static_file[
'content_type'].encode('utf-8'))]})
await send({'type': 'http.response.body',
'body': payload})
else:
await self.not_found(receive, send)
with open(static_file['filename'], 'rb') as f:
payload = f.read()
await send({'type': 'http.response.start',
'status': 200,
'headers': [(b'Content-Type', static_file[
'content_type'].encode('utf-8'))]})
await send({'type': 'http.response.body',
'body': payload})
async def lifespan(self, receive, send):
while True:
@ -195,7 +196,13 @@ async def make_response(status, headers, payload, environ):
await environ['asgi.send']({'type': 'websocket.accept',
'headers': headers})
else:
await environ['asgi.send']({'type': 'websocket.close'})
if payload:
reason = payload.decode('utf-8') \
if isinstance(payload, bytes) else str(payload)
await environ['asgi.send']({'type': 'websocket.close',
'reason': reason})
else:
await environ['asgi.send']({'type': 'websocket.close'})
return
await environ['asgi.send']({'type': 'http.response.start',

View File

@ -1,8 +1,7 @@
from __future__ import absolute_import
import gevent
from gevent import queue
from gevent.event import Event
from gevent import selectors
import uwsgi
_websocket_available = hasattr(uwsgi, 'websocket_handshake')
@ -40,21 +39,20 @@ class uWSGIWebSocket(object): # pragma: no cover
self._req_ctx = uwsgi.request_context()
else:
# use event and queue for sending messages
from gevent.event import Event
from gevent.queue import Queue
from gevent.select import select
self._event = Event()
self._send_queue = Queue()
self._send_queue = queue.Queue()
# spawn a select greenlet
def select_greenlet_runner(fd, event):
"""Sets event when data becomes available to read on fd."""
while True:
event.set()
try:
select([fd], [], [])[0]
except ValueError:
break
sel = selectors.DefaultSelector()
sel.register(fd, selectors.EVENT_READ)
try:
while True:
sel.select()
event.set()
except gevent.GreenletExit:
sel.unregister(fd)
self._select_greenlet = gevent.spawn(
select_greenlet_runner,
self._sock,

View File

@ -1,17 +1,48 @@
from __future__ import absolute_import
import queue
import threading
import time
try:
import queue
from simple_websocket import Server, ConnectionClosed
_websocket_available = True
except ImportError: # pragma: no cover
import Queue as queue
_websocket_available = False
class WebSocketWSGI(object): # pragma: no cover
"""
This wrapper class provides a threading WebSocket interface that is
compatible with eventlet's implementation.
"""
def __init__(self, app):
self.app = app
def __call__(self, environ, start_response):
self.ws = Server(environ)
return self.app(self)
def close(self):
return self.ws.close()
def send(self, message):
try:
return self.ws.send(message)
except ConnectionClosed:
raise IOError()
def wait(self):
try:
return self.ws.receive()
except ConnectionClosed:
raise IOError()
_async = {
'thread': threading.Thread,
'queue': queue.Queue,
'queue_empty': queue.Empty,
'event': threading.Event,
'websocket': None,
'websocket': WebSocketWSGI if _websocket_available else None,
'sleep': time.sleep,
}

View File

@ -57,6 +57,11 @@ class AsyncClient(client.Client):
skip SSL certificate verification, allowing
connections to servers with self signed certificates.
The default is ``True``.
:param handle_sigint: Set to ``True`` to automatically handle disconnection
when the process is interrupted, or to ``False`` to
leave interrupt handling to the calling application.
Interrupt handling can only be enabled when the
client instance is created in the main thread.
"""
def is_asyncio_based(self):
return True
@ -85,9 +90,8 @@ class AsyncClient(client.Client):
await eio.connect('http://localhost:5000')
"""
global async_signal_handler_set
if not async_signal_handler_set and \
if self.handle_sigint and not async_signal_handler_set and \
threading.current_thread() == threading.main_thread():
try:
asyncio.get_event_loop().add_signal_handler(
signal.SIGINT, async_signal_handler)
@ -166,11 +170,7 @@ class AsyncClient(client.Client):
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
Note: this method is a coroutine.
The return value is a ``asyncio.Task`` object.
"""
return asyncio.ensure_future(target(*args, **kwargs))
@ -191,10 +191,17 @@ class AsyncClient(client.Client):
"""Create an event object."""
return asyncio.Event()
def _reset(self):
if self.http: # pragma: no cover
asyncio.ensure_future(self.http.close())
super()._reset()
def __del__(self): # pragma: no cover
# try to close the aiohttp session if it is still open
if self.http and not self.http.closed:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
loop.ensure_future(self.http.close())
else:
loop.run_until_complete(self.http.close())
except:
pass
async def _connect_polling(self, url, headers, engineio_path):
"""Establish a long-polling connection to the Engine.IO server."""
@ -207,10 +214,10 @@ class AsyncClient(client.Client):
r = await self._send_request(
'GET', self.base_url + self._get_url_timestamp(), headers=headers,
timeout=self.request_timeout)
if r is None:
if r is None or isinstance(r, str):
self._reset()
raise exceptions.ConnectionError(
'Connection refused by the server')
r or 'Connection refused by the server')
if r.status < 200 or r.status >= 300:
self._reset()
try:
@ -416,6 +423,7 @@ class AsyncClient(client.Client):
except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
self.logger.info('HTTP %s request to %s failed with error %s.',
method, url, exc)
return str(exc)
async def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
@ -462,9 +470,9 @@ class AsyncClient(client.Client):
r = await self._send_request(
'GET', self.base_url + self._get_url_timestamp(),
timeout=max(self.ping_interval, self.ping_timeout) + 5)
if r is None:
if r is None or isinstance(r, str):
self.logger.warning(
'Connection refused by the server, aborting')
r or 'Connection refused by the server, aborting')
await self.queue.put(None)
break
if r.status < 200 or r.status >= 300:
@ -578,13 +586,13 @@ class AsyncClient(client.Client):
p = payload.Payload(packets=packets)
r = await self._send_request(
'POST', self.base_url, body=p.encode(),
headers={'Content-Type': 'application/octet-stream'},
headers={'Content-Type': 'text/plain'},
timeout=self.request_timeout)
for pkt in packets:
self.queue.task_done()
if r is None:
if r is None or isinstance(r, str):
self.logger.warning(
'Connection refused by the server, aborting')
r or 'Connection refused by the server, aborting')
break
if r.status < 200 or r.status >= 300:
self.logger.warning('Unexpected status code %s in server '

View File

@ -29,7 +29,7 @@ class AsyncServer(server.Server):
is a grace period added by the server.
:param ping_timeout: The time in seconds that the client waits for the
server to respond before disconnecting. The default
is 5 seconds.
is 20 seconds.
:param max_http_buffer_size: The maximum size of a message when using the
polling transport. The default is 1,000,000
bytes.
@ -63,6 +63,9 @@ class AsyncServer(server.Server):
:param async_handlers: If set to ``True``, run message event handlers in
non-blocking threads. To run handlers synchronously,
set to ``False``. The default is ``True``.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. Defaults to
``['polling', 'websocket']``.
:param kwargs: Reserved for future extensions, any additional parameters
given as keyword arguments will be silently ignored.
"""
@ -213,6 +216,13 @@ class AsyncServer(server.Server):
jsonp = False
jsonp_index = None
# make sure the client uses an allowed transport
transport = query.get('transport', ['polling'])[0]
if transport not in self.transports:
self._log_error_once('Invalid transport', 'bad-transport')
return await self._make_response(
self._bad_request('Invalid transport'), environ)
# make sure the client speaks a compatible Engine.IO version
sid = query['sid'][0] if 'sid' in query else None
if sid is None and query.get('EIO') != ['4']:
@ -239,7 +249,6 @@ class AsyncServer(server.Server):
r = self._bad_request('Invalid JSONP index number')
elif method == 'GET':
if sid is None:
transport = query.get('transport', ['polling'])[0]
# transport must be one of 'polling' or 'websocket'.
# if 'websocket', the HTTP_UPGRADE header must match.
upgrade_header = environ.get('HTTP_UPGRADE').lower() \
@ -249,9 +258,9 @@ class AsyncServer(server.Server):
r = await self._handle_connect(environ, transport,
jsonp_index)
else:
self._log_error_once('Invalid transport ' + transport,
'bad-transport')
r = self._bad_request('Invalid transport ' + transport)
self._log_error_once('Invalid websocket upgrade',
'bad-upgrade')
r = self._bad_request('Invalid websocket upgrade')
else:
if sid not in self.sockets:
self._log_error_once('Invalid session ' + sid, 'bad-sid')

View File

@ -143,12 +143,18 @@ class AsyncSocket(socket.Socket):
async def _websocket_handler(self, ws):
"""Engine.IO handler for websocket transport."""
async def websocket_wait():
data = await ws.wait()
if data and len(data) > self.server.max_http_buffer_size:
raise ValueError('packet is too large')
return data
if self.connected:
# the socket was already connected, so this is an upgrade
self.upgrading = True # hold packet sends during the upgrade
try:
pkt = await ws.wait()
pkt = await websocket_wait()
except IOError: # pragma: no cover
return
decoded_pkt = packet.Packet(encoded_packet=pkt)
@ -162,7 +168,7 @@ class AsyncSocket(socket.Socket):
await self.queue.put(packet.Packet(packet.NOOP)) # end poll
try:
pkt = await ws.wait()
pkt = await websocket_wait()
except IOError: # pragma: no cover
self.upgrading = False
return
@ -204,7 +210,7 @@ class AsyncSocket(socket.Socket):
while True:
p = None
wait_task = asyncio.ensure_future(ws.wait())
wait_task = asyncio.ensure_future(websocket_wait())
try:
p = await asyncio.wait_for(
wait_task,

View File

@ -1,10 +1,7 @@
from base64 import b64encode
from json import JSONDecodeError
from engineio.json import JSONDecodeError
import logging
try:
import queue
except ImportError: # pragma: no cover
import Queue as queue
import queue
import signal
import ssl
import threading
@ -69,17 +66,18 @@ class Client(object):
skip SSL certificate verification, allowing
connections to servers with self signed certificates.
The default is ``True``.
:param handle_sigint: Set to ``True`` to automatically handle disconnection
when the process is interrupted, or to ``False`` to
leave interrupt handling to the calling application.
Interrupt handling can only be enabled when the
client instance is created in the main thread.
"""
event_names = ['connect', 'disconnect', 'message']
def __init__(self,
logger=False,
json=None,
request_timeout=5,
http_session=None,
ssl_verify=True):
def __init__(self, logger=False, json=None, request_timeout=5,
http_session=None, ssl_verify=True, handle_sigint=True):
global original_signal_handler
if original_signal_handler is None and \
if handle_sigint and original_signal_handler is None and \
threading.current_thread() == threading.main_thread():
original_signal_handler = signal.signal(signal.SIGINT,
signal_handler)
@ -92,6 +90,7 @@ class Client(object):
self.ping_interval = None
self.ping_timeout = None
self.http = http_session
self.handle_sigint = handle_sigint
self.ws = None
self.read_loop_task = None
self.write_loop_task = None
@ -244,9 +243,9 @@ class Client(object):
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
This function returns an object that represents the background task,
on which the ``join()`` method can be invoked to wait for the task to
complete.
"""
th = threading.Thread(target=target, args=args, kwargs=kwargs)
th.start()
@ -282,10 +281,10 @@ class Client(object):
r = self._send_request(
'GET', self.base_url + self._get_url_timestamp(), headers=headers,
timeout=self.request_timeout)
if r is None:
if r is None or isinstance(r, str):
self._reset()
raise exceptions.ConnectionError(
'Connection refused by the server')
r or 'Connection refused by the server')
if r.status_code < 200 or r.status_code >= 300:
self._reset()
try:
@ -528,6 +527,7 @@ class Client(object):
except requests.exceptions.RequestException as exc:
self.logger.info('HTTP %s request to %s failed with error %s.',
method, url, exc)
return str(exc)
def _trigger_event(self, event, *args, **kwargs):
"""Invoke an event handler."""
@ -574,9 +574,9 @@ class Client(object):
r = self._send_request(
'GET', self.base_url + self._get_url_timestamp(),
timeout=max(self.ping_interval, self.ping_timeout) + 5)
if r is None:
if r is None or isinstance(r, str):
self.logger.warning(
'Connection refused by the server, aborting')
r or 'Connection refused by the server, aborting')
self.queue.put(None)
break
if r.status_code < 200 or r.status_code >= 300:
@ -682,13 +682,13 @@ class Client(object):
p = payload.Payload(packets=packets)
r = self._send_request(
'POST', self.base_url, body=p.encode(),
headers={'Content-Type': 'application/octet-stream'},
headers={'Content-Type': 'text/plain'},
timeout=self.request_timeout)
for pkt in packets:
self.queue.task_done()
if r is None:
if r is None or isinstance(r, str):
self.logger.warning(
'Connection refused by the server, aborting')
r or 'Connection refused by the server, aborting')
break
if r.status_code < 200 or r.status_code >= 300:
self.logger.warning('Unexpected status code %s in server '

16
libs/engineio/json.py Normal file
View File

@ -0,0 +1,16 @@
"""JSON-compatible module with sane defaults."""
from json import * # noqa: F401, F403
from json import loads as original_loads
def _safe_int(s):
if len(s) > 100:
raise ValueError('Integer is too large')
return int(s)
def loads(*args, **kwargs):
if 'parse_int' not in kwargs: # pragma: no cover
kwargs['parse_int'] = _safe_int
return original_loads(*args, **kwargs)

View File

@ -35,7 +35,11 @@ class WSGIApp(object):
engineio_path='engine.io'):
self.engineio_app = engineio_app
self.wsgi_app = wsgi_app
self.engineio_path = engineio_path.strip('/')
self.engineio_path = engineio_path
if not self.engineio_path.startswith('/'):
self.engineio_path = '/' + self.engineio_path
if not self.engineio_path.endswith('/'):
self.engineio_path += '/'
self.static_files = static_files or {}
def __call__(self, environ, start_response):
@ -55,21 +59,17 @@ class WSGIApp(object):
environ['eventlet.input'] = Input(environ['gunicorn.socket'])
path = environ['PATH_INFO']
if path is not None and \
path.startswith('/{0}/'.format(self.engineio_path)):
if path is not None and path.startswith(self.engineio_path):
return self.engineio_app.handle_request(environ, start_response)
else:
static_file = get_static_file(path, self.static_files) \
if self.static_files else None
if static_file:
if os.path.exists(static_file['filename']):
start_response(
'200 OK',
[('Content-Type', static_file['content_type'])])
with open(static_file['filename'], 'rb') as f:
return [f.read()]
else:
return self.not_found(start_response)
if static_file and os.path.exists(static_file['filename']):
start_response(
'200 OK',
[('Content-Type', static_file['content_type'])])
with open(static_file['filename'], 'rb') as f:
return [f.read()]
elif self.wsgi_app is not None:
return self.wsgi_app(environ, start_response)
return self.not_found(start_response)

View File

@ -1,5 +1,5 @@
import base64
import json as _json
from engineio import json as _json
(OPEN, CLOSE, PING, PONG, MESSAGE, UPGRADE, NOOP) = (0, 1, 2, 3, 4, 5, 6)
packet_names = ['OPEN', 'CLOSE', 'PING', 'PONG', 'MESSAGE', 'UPGRADE', 'NOOP']
@ -23,7 +23,7 @@ class Packet(object):
self.binary = False
if self.binary and self.packet_type != MESSAGE:
raise ValueError('Binary packets can only be of type MESSAGE')
if encoded_packet:
if encoded_packet is not None:
self.decode(encoded_packet)
def encode(self, b64=False):

View File

@ -36,7 +36,7 @@ class Server(object):
is a grace period added by the server.
:param ping_timeout: The time in seconds that the client waits for the
server to respond before disconnecting. The default
is 5 seconds.
is 20 seconds.
:param max_http_buffer_size: The maximum size of a message when using the
polling transport. The default is 1,000,000
bytes.
@ -78,20 +78,25 @@ class Server(object):
inactive clients are closed. Set to ``False`` to
disable the monitoring task (not recommended). The
default is ``True``.
:param transports: The list of allowed transports. Valid transports
are ``'polling'`` and ``'websocket'``. Defaults to
``['polling', 'websocket']``.
:param kwargs: Reserved for future extensions, any additional parameters
given as keyword arguments will be silently ignored.
"""
compression_methods = ['gzip', 'deflate']
event_names = ['connect', 'disconnect', 'message']
valid_transports = ['polling', 'websocket']
_default_monitor_clients = True
sequence_number = 0
def __init__(self, async_mode=None, ping_interval=25, ping_timeout=5,
def __init__(self, async_mode=None, ping_interval=25, ping_timeout=20,
max_http_buffer_size=1000000, allow_upgrades=True,
http_compression=True, compression_threshold=1024,
cookie=None, cors_allowed_origins=None,
cors_credentials=True, logger=False, json=None,
async_handlers=True, monitor_clients=None, **kwargs):
async_handlers=True, monitor_clients=None, transports=None,
**kwargs):
self.ping_timeout = ping_timeout
if isinstance(ping_interval, tuple):
self.ping_interval = ping_interval[0]
@ -152,6 +157,14 @@ class Server(object):
self._async['asyncio']: # pragma: no cover
raise ValueError('The selected async_mode requires asyncio and '
'must use the AsyncServer class')
if transports is not None:
if isinstance(transports, str):
transports = [transports]
transports = [transport for transport in transports
if transport in self.valid_transports]
if not transports:
raise ValueError('No valid transports provided')
self.transports = transports or self.valid_transports
self.logger.info('Server initialized for %s.', self.async_mode)
def is_asyncio_based(self):
@ -333,8 +346,7 @@ class Server(object):
allowed_origins:
self._log_error_once(
origin + ' is not an accepted origin.', 'bad-origin')
r = self._bad_request(
origin + ' is not an accepted origin.')
r = self._bad_request('Not an accepted origin.')
start_response(r['status'], r['headers'])
return [r['response']]
@ -343,6 +355,14 @@ class Server(object):
jsonp = False
jsonp_index = None
# make sure the client uses an allowed transport
transport = query.get('transport', ['polling'])[0]
if transport not in self.transports:
self._log_error_once('Invalid transport', 'bad-transport')
r = self._bad_request('Invalid transport')
start_response(r['status'], r['headers'])
return [r['response']]
# make sure the client speaks a compatible Engine.IO version
sid = query['sid'][0] if 'sid' in query else None
if sid is None and query.get('EIO') != ['4']:
@ -369,7 +389,6 @@ class Server(object):
r = self._bad_request('Invalid JSONP index number')
elif method == 'GET':
if sid is None:
transport = query.get('transport', ['polling'])[0]
# transport must be one of 'polling' or 'websocket'.
# if 'websocket', the HTTP_UPGRADE header must match.
upgrade_header = environ.get('HTTP_UPGRADE').lower() \
@ -379,13 +398,13 @@ class Server(object):
r = self._handle_connect(environ, start_response,
transport, jsonp_index)
else:
self._log_error_once('Invalid transport ' + transport,
'bad-transport')
r = self._bad_request('Invalid transport ' + transport)
self._log_error_once('Invalid websocket upgrade',
'bad-upgrade')
r = self._bad_request('Invalid websocket upgrade')
else:
if sid not in self.sockets:
self._log_error_once('Invalid session ' + sid, 'bad-sid')
r = self._bad_request('Invalid session ' + sid)
r = self._bad_request('Invalid session')
else:
socket = self._get_socket(sid)
try:
@ -405,7 +424,7 @@ class Server(object):
if sid is None or sid not in self.sockets:
self._log_error_once(
'Invalid session ' + (sid or 'None'), 'bad-sid')
r = self._bad_request('Invalid session ' + (sid or 'None'))
r = self._bad_request('Invalid session')
else:
socket = self._get_socket(sid)
try:
@ -453,9 +472,9 @@ class Server(object):
:param args: arguments to pass to the function.
:param kwargs: keyword arguments to pass to the function.
This function returns an object compatible with the `Thread` class in
the Python standard library. The `start()` method on this object is
already called by this function.
This function returns an object that represents the background task,
on which the ``join()`` methond can be invoked to wait for the task to
complete.
"""
th = self._async['thread'](target=target, args=args, kwargs=kwargs)
th.start()
@ -581,7 +600,14 @@ class Server(object):
def _upgrades(self, sid, transport):
"""Return the list of possible upgrades for a client connection."""
if not self.allow_upgrades or self._get_socket(sid).upgraded or \
self._async['websocket'] is None or transport == 'websocket':
transport == 'websocket':
return []
if self._async['websocket'] is None: # pragma: no cover
self._log_error_once(
'The WebSocket transport is not available, you must install a '
'WebSocket server that is compatible with your async mode to '
'enable it. See the documentation for details.',
'no-websocket')
return []
return ['websocket']
@ -656,13 +682,15 @@ class Server(object):
if 'wsgi.url_scheme' in environ and 'HTTP_HOST' in environ:
default_origins.append('{scheme}://{host}'.format(
scheme=environ['wsgi.url_scheme'], host=environ['HTTP_HOST']))
if 'HTTP_X_FORWARDED_HOST' in environ:
if 'HTTP_X_FORWARDED_PROTO' in environ or \
'HTTP_X_FORWARDED_HOST' in environ:
scheme = environ.get(
'HTTP_X_FORWARDED_PROTO',
environ['wsgi.url_scheme']).split(',')[0].strip()
default_origins.append('{scheme}://{host}'.format(
scheme=scheme, host=environ['HTTP_X_FORWARDED_HOST'].split(
',')[0].strip()))
scheme=scheme, host=environ.get(
'HTTP_X_FORWARDED_HOST', environ['HTTP_HOST']).split(
',')[0].strip()))
if self.cors_allowed_origins is None:
allowed_origins = default_origins
elif self.cors_allowed_origins == '*':

View File

@ -159,6 +159,12 @@ class Socket(object):
def _websocket_handler(self, ws):
"""Engine.IO handler for websocket transport."""
def websocket_wait():
data = ws.wait()
if data and len(data) > self.server.max_http_buffer_size:
raise ValueError('packet is too large')
return data
# try to set a socket timeout matching the configured ping interval
# and timeout
for attr in ['_sock', 'socket']: # pragma: no cover
@ -170,7 +176,7 @@ class Socket(object):
# the socket was already connected, so this is an upgrade
self.upgrading = True # hold packet sends during the upgrade
pkt = ws.wait()
pkt = websocket_wait()
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.PING or \
decoded_pkt.data != 'probe':
@ -181,7 +187,7 @@ class Socket(object):
ws.send(packet.Packet(packet.PONG, data='probe').encode())
self.queue.put(packet.Packet(packet.NOOP)) # end poll
pkt = ws.wait()
pkt = websocket_wait()
decoded_pkt = packet.Packet(encoded_packet=pkt)
if decoded_pkt.packet_type != packet.UPGRADE:
self.upgraded = False
@ -221,7 +227,7 @@ class Socket(object):
while True:
p = None
try:
p = ws.wait()
p = websocket_wait()
except Exception as e:
# if the socket is already closed, we can assume this is a
# downstream error of that

View File

@ -21,23 +21,28 @@ def get_static_file(path, static_files):
"content_type". If the requested URL does not match any static file, the
return value is None.
"""
extra_path = ''
if path in static_files:
f = static_files[path]
else:
f = None
rest = ''
while path != '':
path, last = path.rsplit('/', 1)
rest = '/' + last + rest
extra_path = '/' + last + extra_path
if path in static_files:
f = static_files[path] + rest
f = static_files[path]
break
elif path + '/' in static_files:
f = static_files[path + '/'] + rest[1:]
f = static_files[path + '/']
break
if f:
if isinstance(f, str):
f = {'filename': f}
else:
f = f.copy() # in case it is mutated below
if f['filename'].endswith('/') and extra_path.startswith('/'):
extra_path = extra_path[1:]
f['filename'] += extra_path
if f['filename'].endswith('/'):
if '' in static_files:
if isinstance(static_files[''], str):

View File

@ -10,7 +10,7 @@ chardet=3.0.4
cloudscraper=1.2.58
deep-translator=1.5.4
dogpile.cache=0.6.5
engineio=4.0.2dev
engineio=4.3.0
enzyme=0.4.1
ffsubsync=0.4.11
Flask=1.1.1