Updated SignalRCore to support websocket-client 1.0.0.

This commit is contained in:
morpheus65535 2021-05-18 14:33:37 -04:00
parent 1a07a00e71
commit 2643023240
20 changed files with 372 additions and 333 deletions

View File

@ -133,12 +133,12 @@ class WebsocketTransport(BaseTransport):
raise ValueError("Handshake error {0}".format(msg.error)) raise ValueError("Handshake error {0}".format(msg.error))
return messages return messages
def on_open(self): def on_open(self, _):
self.logger.debug("-- web socket open --") self.logger.debug("-- web socket open --")
msg = self.protocol.handshake_message() msg = self.protocol.handshake_message()
self.send(msg) self.send(msg)
def on_close(self): def on_close(self, _):
self.logger.debug("-- web socket close --") self.logger.debug("-- web socket close --")
self.state = ConnectionState.disconnected self.state = ConnectionState.disconnected
if self._on_close is not None and callable(self._on_close): if self._on_close is not None and callable(self._on_close):
@ -150,35 +150,23 @@ class WebsocketTransport(BaseTransport):
if self._on_close is not None and callable(self._on_close): if self._on_close is not None and callable(self._on_close):
self._on_close() self._on_close()
def on_socket_error(self, error): def on_socket_error(self, _, error):
""" """
Throws error related on
https://github.com/websocket-client/websocket-client/issues/449
Args: Args:
_: Required to support websocket-client version equal or greater than 0.58.0
error ([type]): [description] error ([type]): [description]
Raises: Raises:
HubError: [description] HubError: [description]
""" """
self.logger.debug("-- web socket error --") self.logger.debug("-- web socket error --")
if (type(error) is AttributeError and self.logger.error(traceback.format_exc(5, True))
"'NoneType' object has no attribute 'connected'" self.logger.error("{0} {1}".format(self, error))
in str(error)): self.logger.error("{0} {1}".format(error, type(error)))
url = "https://github.com/websocket-client" +\ self._on_close()
"/websocket-client/issues/449" raise HubError(error)
self.logger.warning(
"Websocket closing error: issue" +
url)
self._on_close()
else:
self.logger.error(traceback.format_exc(5, True))
self.logger.error("{0} {1}".format(self, error))
self.logger.error("{0} {1}".format(error, type(error)))
self._on_close()
raise HubError(error)
def on_message(self, raw_message): def on_message(self, _, raw_message):
self.logger.debug("Message received{0}".format(raw_message)) self.logger.debug("Message received{0}".format(raw_message))
self.connection_checker.last_message = time.time() self.connection_checker.last_message = time.time()
if not self.handshake_received: if not self.handshake_received:

View File

@ -30,7 +30,7 @@ rebulk=3.0.1
requests=2.18.4 requests=2.18.4
semver=2.13.0 semver=2.13.0
signalr-client-threads=0.0.12 <-- Modified to work with Sonarr signalr-client-threads=0.0.12 <-- Modified to work with Sonarr
signalrcore=0.9.2 <-- https://github.com/mandrewcito/signalrcore/pull/60 signalrcore=0.9.2 <-- https://github.com/mandrewcito/signalrcore/pull/60 and 61
SimpleConfigParser=0.1.0 <-- modified version: do not update!!! SimpleConfigParser=0.1.0 <-- modified version: do not update!!!
six=1.11.0 six=1.11.0
socketio=5.1.0 socketio=5.1.0
@ -40,7 +40,7 @@ subliminal=2.1.0dev
tzlocal=2.1b1 tzlocal=2.1b1
twine=3.4.1 twine=3.4.1
urllib3=1.23 urllib3=1.23
websocket-client=0.59.0 <-- Modified to work with SignalRCore: https://github.com/websocket-client/websocket-client/commit/3112b7d75b1e5d65cb8fdfca7801606649044ed1#commitcomment-50947250 websocket-client=1.0.0
## indirect dependencies ## indirect dependencies
auditok=0.1.5 # Required-by: ffsubsync auditok=0.1.5 # Required-by: ffsubsync

View File

@ -25,4 +25,4 @@ from ._exceptions import *
from ._logging import * from ._logging import *
from ._socket import * from ._socket import *
__version__ = "0.59.0" __version__ = "1.0.0"

View File

@ -26,17 +26,12 @@ import array
import os import os
import struct import struct
import six
from ._exceptions import * from ._exceptions import *
from ._utils import validate_utf8 from ._utils import validate_utf8
from threading import Lock from threading import Lock
try: try:
if six.PY3: import numpy
import numpy
else:
numpy = None
except ImportError: except ImportError:
numpy = None numpy = None
@ -53,10 +48,7 @@ except ImportError:
for i in range(len(_d)): for i in range(len(_d)):
_d[i] ^= _m[i % 4] _d[i] ^= _m[i % 4]
if six.PY3: return _d.tobytes()
return _d.tobytes()
else:
return _d.tostring()
__all__ = [ __all__ = [
@ -181,8 +173,7 @@ class ABNF(object):
if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]): if l > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
raise WebSocketProtocolException("Invalid close frame.") raise WebSocketProtocolException("Invalid close frame.")
code = 256 * \ code = 256 * self.data[0] + self.data[1]
six.byte2int(self.data[0:1]) + six.byte2int(self.data[1:2])
if not self._is_valid_close_status(code): if not self._is_valid_close_status(code):
raise WebSocketProtocolException("Invalid close opcode.") raise WebSocketProtocolException("Invalid close opcode.")
@ -211,7 +202,7 @@ class ABNF(object):
fin: <type> fin: <type>
fin flag. if set to 0, create continue fragmentation. fin flag. if set to 0, create continue fragmentation.
""" """
if opcode == ABNF.OPCODE_TEXT and isinstance(data, six.text_type): if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
data = data.encode("utf-8") data = data.encode("utf-8")
# mask must be set if send data from client # mask must be set if send data from client
return ABNF(fin, 0, 0, 0, opcode, 1, data) return ABNF(fin, 0, 0, 0, opcode, 1, data)
@ -230,17 +221,14 @@ class ABNF(object):
frame_header = chr(self.fin << 7 | frame_header = chr(self.fin << 7 |
self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 | self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 |
self.opcode) self.opcode).encode('latin-1')
if length < ABNF.LENGTH_7: if length < ABNF.LENGTH_7:
frame_header += chr(self.mask << 7 | length) frame_header += chr(self.mask << 7 | length).encode('latin-1')
frame_header = six.b(frame_header)
elif length < ABNF.LENGTH_16: elif length < ABNF.LENGTH_16:
frame_header += chr(self.mask << 7 | 0x7e) frame_header += chr(self.mask << 7 | 0x7e).encode('latin-1')
frame_header = six.b(frame_header)
frame_header += struct.pack("!H", length) frame_header += struct.pack("!H", length)
else: else:
frame_header += chr(self.mask << 7 | 0x7f) frame_header += chr(self.mask << 7 | 0x7f).encode('latin-1')
frame_header = six.b(frame_header)
frame_header += struct.pack("!Q", length) frame_header += struct.pack("!Q", length)
if not self.mask: if not self.mask:
@ -252,7 +240,7 @@ class ABNF(object):
def _get_masked(self, mask_key): def _get_masked(self, mask_key):
s = ABNF.mask(mask_key, self.data) s = ABNF.mask(mask_key, self.data)
if isinstance(mask_key, six.text_type): if isinstance(mask_key, str):
mask_key = mask_key.encode('utf-8') mask_key = mask_key.encode('utf-8')
return mask_key + s return mask_key + s
@ -265,34 +253,32 @@ class ABNF(object):
Parameters Parameters
---------- ----------
mask_key: <type> mask_key: <type>
4 byte string(byte). 4 byte string.
data: <type> data: <type>
data to mask/unmask. data to mask/unmask.
""" """
if data is None: if data is None:
data = "" data = ""
if isinstance(mask_key, six.text_type): if isinstance(mask_key, str):
mask_key = six.b(mask_key) mask_key = mask_key.encode('latin-1')
if isinstance(data, six.text_type): if isinstance(data, str):
data = six.b(data) data = data.encode('latin-1')
if numpy: if numpy:
origlen = len(data) origlen = len(data)
_mask_key = mask_key[3] << 24 | mask_key[2] << 16 | mask_key[1] << 8 | mask_key[0] _mask_key = mask_key[3] << 24 | mask_key[2] << 16 | mask_key[1] << 8 | mask_key[0]
# We need data to be a multiple of four... # We need data to be a multiple of four...
data += bytes(" " * (4 - (len(data) % 4)), "us-ascii") data += b' ' * (4 - (len(data) % 4))
a = numpy.frombuffer(data, dtype="uint32") a = numpy.frombuffer(data, dtype="uint32")
masked = numpy.bitwise_xor(a, [_mask_key]).astype("uint32") masked = numpy.bitwise_xor(a, [_mask_key]).astype("uint32")
if len(data) > origlen: if len(data) > origlen:
return masked.tobytes()[:origlen] return masked.tobytes()[:origlen]
return masked.tobytes() return masked.tobytes()
else: else:
_m = array.array("B", mask_key) return _mask(array.array("B", mask_key), array.array("B", data))
_d = array.array("B", data)
return _mask(_m, _d)
class frame_buffer(object): class frame_buffer(object):
@ -319,20 +305,12 @@ class frame_buffer(object):
def recv_header(self): def recv_header(self):
header = self.recv_strict(2) header = self.recv_strict(2)
b1 = header[0] b1 = header[0]
if six.PY2:
b1 = ord(b1)
fin = b1 >> 7 & 1 fin = b1 >> 7 & 1
rsv1 = b1 >> 6 & 1 rsv1 = b1 >> 6 & 1
rsv2 = b1 >> 5 & 1 rsv2 = b1 >> 5 & 1
rsv3 = b1 >> 4 & 1 rsv3 = b1 >> 4 & 1
opcode = b1 & 0xf opcode = b1 & 0xf
b2 = header[1] b2 = header[1]
if six.PY2:
b2 = ord(b2)
has_mask = b2 >> 7 & 1 has_mask = b2 >> 7 & 1
length_bits = b2 & 0x7f length_bits = b2 & 0x7f
@ -408,7 +386,7 @@ class frame_buffer(object):
self.recv_buffer.append(bytes_) self.recv_buffer.append(bytes_)
shortage -= len(bytes_) shortage -= len(bytes_)
unified = six.b("").join(self.recv_buffer) unified = bytes("", 'utf-8').join(self.recv_buffer)
if shortage == 0: if shortage == 0:
self.recv_buffer = [] self.recv_buffer = []

View File

@ -22,15 +22,11 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
""" """
import inspect import selectors
import select
import sys import sys
import threading import threading
import time import time
import traceback import traceback
import six
from ._abnf import ABNF from ._abnf import ABNF
from ._core import WebSocket, getdefaulttimeout from ._core import WebSocket, getdefaulttimeout
from ._exceptions import * from ._exceptions import *
@ -50,12 +46,15 @@ class Dispatcher:
def read(self, sock, read_callback, check_callback): def read(self, sock, read_callback, check_callback):
while self.app.keep_running: while self.app.keep_running:
r, w, e = select.select( sel = selectors.DefaultSelector()
(self.app.sock.sock, ), (), (), self.ping_timeout) sel.register(self.app.sock.sock, selectors.EVENT_READ)
r = sel.select(self.ping_timeout)
if r: if r:
if not read_callback(): if not read_callback():
break break
check_callback() check_callback()
sel.close()
class SSLDispatcher: class SSLDispatcher:
@ -79,8 +78,14 @@ class SSLDispatcher:
if sock.pending(): if sock.pending():
return [sock,] return [sock,]
r, w, e = select.select((sock, ), (), (), self.ping_timeout) sel = selectors.DefaultSelector()
return r sel.register(sock, selectors.EVENT_READ)
r = sel.select(self.ping_timeout)
sel.close()
if len(r) > 0:
return r[0][0]
class WebSocketApp(object): class WebSocketApp(object):
@ -255,7 +260,9 @@ class WebSocketApp(object):
""" """
if ping_timeout is not None and ping_timeout <= 0: if ping_timeout is not None and ping_timeout <= 0:
ping_timeout = None raise WebSocketException("Ensure ping_timeout > 0")
if ping_interval is not None and ping_interval < 0:
raise WebSocketException("Ensure ping_interval >= 0")
if ping_timeout and ping_interval and ping_interval <= ping_timeout: if ping_timeout and ping_interval and ping_interval <= ping_timeout:
raise WebSocketException("Ensure ping_interval > ping_timeout") raise WebSocketException("Ensure ping_interval > ping_timeout")
if not sockopt: if not sockopt:
@ -276,15 +283,16 @@ class WebSocketApp(object):
If close_frame is set, we will invoke the on_close handler with the If close_frame is set, we will invoke the on_close handler with the
statusCode and reason from there. statusCode and reason from there.
""" """
if thread and thread.is_alive(): if thread and thread.is_alive():
event.set() event.set()
thread.join() thread.join()
self.keep_running = False self.keep_running = False
if self.sock: if self.sock:
self.sock.close() self.sock.close()
close_args = self._get_close_args( close_status_code, close_reason = self._get_close_args(
close_frame.data if close_frame else None) close_frame if close_frame else None)
self._callback(self.on_close, *close_args) self._callback(self.on_close, close_status_code, close_reason)
self.sock = None self.sock = None
try: try:
@ -332,7 +340,7 @@ class WebSocketApp(object):
frame.data, frame.fin) frame.data, frame.fin)
else: else:
data = frame.data data = frame.data
if six.PY3 and op_code == ABNF.OPCODE_TEXT: if op_code == ABNF.OPCODE_TEXT:
data = data.decode("utf-8") data = data.decode("utf-8")
self._callback(self.on_data, data, frame.opcode, True) self._callback(self.on_data, data, frame.opcode, True)
self._callback(self.on_message, data) self._callback(self.on_message, data)
@ -367,33 +375,29 @@ class WebSocketApp(object):
return Dispatcher(self, timeout) return Dispatcher(self, timeout)
def _get_close_args(self, data): def _get_close_args(self, close_frame):
""" """
_get_close_args extracts the code, reason from the close body _get_close_args extracts the close code and reason from the close body
if they exists, and if the self.on_close except three arguments if it exists (RFC6455 says WebSocket Connection Close Code is optional)
""" """
# if the on_close callback is "old", just return empty list # Need to catch the case where close_frame is None
if sys.version_info < (3, 0): # Otherwise the following if statement causes an error
if not self.on_close or len(inspect.getargspec(self.on_close).args) != 3: if not self.on_close or not close_frame:
return [] return [None, None]
# Extract close frame status code
if close_frame.data and len(close_frame.data) >= 2:
close_status_code = 256 * close_frame.data[0] + close_frame.data[1]
reason = close_frame.data[2:].decode('utf-8')
return [close_status_code, reason]
else: else:
if not self.on_close or len(inspect.getfullargspec(self.on_close).args) != 3: # Most likely reached this because len(close_frame_data.data) < 2
return [] return [None, None]
if data and len(data) >= 2:
code = 256 * six.byte2int(data[0:1]) + six.byte2int(data[1:2])
reason = data[2:].decode('utf-8')
return [code, reason]
return [None, None]
def _callback(self, callback, *args): def _callback(self, callback, *args):
if callback: if callback:
try: try:
if inspect.ismethod(callback): callback(self, *args)
callback(*args)
else:
callback(self, *args)
except Exception as e: except Exception as e:
_logging.error("error from callback {}: {}".format(callback, e)) _logging.error("error from callback {}: {}".format(callback, e))

View File

@ -22,10 +22,7 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
""" """
try: import http.cookies
import Cookie
except:
import http.cookies as Cookie
class SimpleCookieJar(object): class SimpleCookieJar(object):
@ -34,26 +31,20 @@ class SimpleCookieJar(object):
def add(self, set_cookie): def add(self, set_cookie):
if set_cookie: if set_cookie:
try: simpleCookie = http.cookies.SimpleCookie(set_cookie)
simpleCookie = Cookie.SimpleCookie(set_cookie)
except:
simpleCookie = Cookie.SimpleCookie(set_cookie.encode('ascii', 'ignore'))
for k, v in simpleCookie.items(): for k, v in simpleCookie.items():
domain = v.get("domain") domain = v.get("domain")
if domain: if domain:
if not domain.startswith("."): if not domain.startswith("."):
domain = "." + domain domain = "." + domain
cookie = self.jar.get(domain) if self.jar.get(domain) else Cookie.SimpleCookie() cookie = self.jar.get(domain) if self.jar.get(domain) else http.cookies.SimpleCookie()
cookie.update(simpleCookie) cookie.update(simpleCookie)
self.jar[domain.lower()] = cookie self.jar[domain.lower()] = cookie
def set(self, set_cookie): def set(self, set_cookie):
if set_cookie: if set_cookie:
try: simpleCookie = http.cookies.SimpleCookie(set_cookie)
simpleCookie = Cookie.SimpleCookie(set_cookie)
except:
simpleCookie = Cookie.SimpleCookie(set_cookie.encode('ascii', 'ignore'))
for k, v in simpleCookie.items(): for k, v in simpleCookie.items():
domain = v.get("domain") domain = v.get("domain")

View File

@ -1,4 +1,3 @@
from __future__ import print_function
""" """
_core.py _core.py
==================================== ====================================
@ -30,8 +29,6 @@ import struct
import threading import threading
import time import time
import six
# websocket modules # websocket modules
from ._abnf import * from ._abnf import *
from ._exceptions import * from ._exceptions import *
@ -226,6 +223,9 @@ class WebSocket(object):
cookie value. cookie value.
- origin: str - origin: str
custom origin url. custom origin url.
- connection: str
custom connection header value.
default value "Upgrade" set in _handshake.py
- suppress_origin: bool - suppress_origin: bool
suppress outputting origin header. suppress outputting origin header.
- host: str - host: str
@ -271,11 +271,11 @@ class WebSocket(object):
Parameters Parameters
---------- ----------
payload: <type> payload: str
Payload must be utf-8 string or unicode, Payload must be utf-8 string or unicode,
if the opcode is OPCODE_TEXT. if the opcode is OPCODE_TEXT.
Otherwise, it must be string(byte array) Otherwise, it must be string(byte array)
opcode: <type> opcode: int
operation code to send. Please see OPCODE_XXX. operation code to send. Please see OPCODE_XXX.
""" """
@ -296,7 +296,7 @@ class WebSocket(object):
Parameters Parameters
---------- ----------
frame: <type> frame: ABNF frame
frame data created by ABNF.create_frame frame data created by ABNF.create_frame
""" """
if self.get_mask_key: if self.get_mask_key:
@ -304,8 +304,8 @@ class WebSocket(object):
data = frame.format() data = frame.format()
length = len(data) length = len(data)
if (isEnabledForTrace()): if (isEnabledForTrace()):
trace("send: " + repr(data)) trace("++Sent raw: " + repr(data))
trace("++Sent decoded: " + frame.__str__())
with self.lock: with self.lock:
while data: while data:
l = self._send(data) l = self._send(data)
@ -322,10 +322,10 @@ class WebSocket(object):
Parameters Parameters
---------- ----------
payload: <type> payload: str
data payload to send server. data payload to send server.
""" """
if isinstance(payload, six.text_type): if isinstance(payload, str):
payload = payload.encode("utf-8") payload = payload.encode("utf-8")
self.send(payload, ABNF.OPCODE_PING) self.send(payload, ABNF.OPCODE_PING)
@ -335,10 +335,10 @@ class WebSocket(object):
Parameters Parameters
---------- ----------
payload: <type> payload: str
data payload to send server. data payload to send server.
""" """
if isinstance(payload, six.text_type): if isinstance(payload, str):
payload = payload.encode("utf-8") payload = payload.encode("utf-8")
self.send(payload, ABNF.OPCODE_PONG) self.send(payload, ABNF.OPCODE_PONG)
@ -352,7 +352,7 @@ class WebSocket(object):
""" """
with self.readlock: with self.readlock:
opcode, data = self.recv_data() opcode, data = self.recv_data()
if six.PY3 and opcode == ABNF.OPCODE_TEXT: if opcode == ABNF.OPCODE_TEXT:
return data.decode("utf-8") return data.decode("utf-8")
elif opcode == ABNF.OPCODE_TEXT or opcode == ABNF.OPCODE_BINARY: elif opcode == ABNF.OPCODE_TEXT or opcode == ABNF.OPCODE_BINARY:
return data return data
@ -394,6 +394,9 @@ class WebSocket(object):
""" """
while True: while True:
frame = self.recv_frame() frame = self.recv_frame()
if (isEnabledForTrace()):
trace("++Rcv raw: " + repr(frame.format()))
trace("++Rcv decoded: " + frame.__str__())
if not frame: if not frame:
# handle error: # handle error:
# 'NoneType' object has no attribute 'opcode' # 'NoneType' object has no attribute 'opcode'
@ -431,7 +434,7 @@ class WebSocket(object):
""" """
return self.frame_buffer.recv_frame() return self.frame_buffer.recv_frame()
def send_close(self, status=STATUS_NORMAL, reason=six.b("")): def send_close(self, status=STATUS_NORMAL, reason=bytes('', encoding='utf-8')):
""" """
Send close data to the server. Send close data to the server.
@ -447,16 +450,16 @@ class WebSocket(object):
self.connected = False self.connected = False
self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE) self.send(struct.pack('!H', status) + reason, ABNF.OPCODE_CLOSE)
def close(self, status=STATUS_NORMAL, reason=six.b(""), timeout=3): def close(self, status=STATUS_NORMAL, reason=bytes('', encoding='utf-8'), timeout=3):
""" """
Close Websocket object Close Websocket object
Parameters Parameters
---------- ----------
status: <type> status: int
status code to send. see STATUS_XXX. status code to send. see STATUS_XXX.
reason: <type> reason: bytes
the reason to close. This must be string. the reason to close.
timeout: int or float timeout: int or float
timeout until receive a close frame. timeout until receive a close frame.
If None, it will wait forever until receive a close frame. If None, it will wait forever until receive a close frame.
@ -487,8 +490,10 @@ class WebSocket(object):
break break
self.sock.settimeout(sock_timeout) self.sock.settimeout(sock_timeout)
self.sock.shutdown(socket.SHUT_RDWR) self.sock.shutdown(socket.SHUT_RDWR)
except: except OSError: # This happens often on Mac
pass pass
except:
raise
self.shutdown() self.shutdown()

View File

@ -21,36 +21,16 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import hashlib import hashlib
import hmac import hmac
import os import os
from base64 import encodebytes as base64encode
import six from http import client as HTTPStatus
from ._cookiejar import SimpleCookieJar from ._cookiejar import SimpleCookieJar
from ._exceptions import * from ._exceptions import *
from ._http import * from ._http import *
from ._logging import * from ._logging import *
from ._socket import * from ._socket import *
if hasattr(six, 'PY3') and six.PY3:
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
if hasattr(six, 'PY3') and six.PY3:
if hasattr(six, 'PY34') and six.PY34:
from http import client as HTTPStatus
else:
from http import HTTPStatus
else:
import httplib as HTTPStatus
__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"] __all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
def compare_digest(s1, s2):
return s1 == s2
# websocket supported version. # websocket supported version.
VERSION = 13 VERSION = 13
@ -194,12 +174,12 @@ def _validate(headers, key, subprotocols):
return False, None return False, None
result = result.lower() result = result.lower()
if isinstance(result, six.text_type): if isinstance(result, str):
result = result.encode('utf-8') result = result.encode('utf-8')
value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8') value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8')
hashed = base64encode(hashlib.sha1(value).digest()).strip().lower() hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
success = compare_digest(hashed, result) success = hmac.compare_digest(hashed, result)
if success: if success:
return True, subproto return True, subproto

View File

@ -23,18 +23,13 @@ import os
import socket import socket
import sys import sys
import six
from ._exceptions import * from ._exceptions import *
from ._logging import * from ._logging import *
from ._socket import* from ._socket import*
from ._ssl_compat import * from ._ssl_compat import *
from ._url import * from ._url import *
if six.PY3: from base64 import encodebytes as base64encode
from base64 import encodebytes as base64encode
else:
from base64 import encodestring as base64encode
__all__ = ["proxy_info", "connect", "read_headers"] __all__ = ["proxy_info", "connect", "read_headers"]
@ -92,11 +87,10 @@ def _open_proxied_socket(url, options, proxy):
socket_options=DEFAULT_SOCKET_OPTION + options.sockopt socket_options=DEFAULT_SOCKET_OPTION + options.sockopt
) )
if is_secure: if is_secure and HAVE_SSL:
if HAVE_SSL: sock = _ssl_socket(sock, options.sslopt, hostname)
sock = _ssl_socket(sock, options.sslopt, hostname) elif is_secure:
else: raise WebSocketException("SSL not available.")
raise WebSocketException("SSL not available.")
return sock, (hostname, port, resource) return sock, (hostname, port, resource)
@ -190,6 +184,8 @@ def _open_socket(addrinfo_list, sockopt, timeout):
err = error err = error
continue continue
else: else:
if sock:
sock.close()
raise error raise error
else: else:
break break
@ -203,10 +199,6 @@ def _open_socket(addrinfo_list, sockopt, timeout):
return sock return sock
def _can_use_sni():
return six.PY2 and sys.version_info >= (2, 7, 9) or sys.version_info >= (3, 2)
def _wrap_sni_socket(sock, sslopt, hostname, check_hostname): def _wrap_sni_socket(sock, sslopt, hostname, check_hostname):
context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23)) context = ssl.SSLContext(sslopt.get('ssl_version', ssl.PROTOCOL_SSLv23))
@ -250,8 +242,7 @@ def _ssl_socket(sock, user_sslopt, hostname):
certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE') certPath = os.environ.get('WEBSOCKET_CLIENT_CA_BUNDLE')
if certPath and os.path.isfile(certPath) \ if certPath and os.path.isfile(certPath) \
and user_sslopt.get('ca_certs', None) is None \ and user_sslopt.get('ca_certs', None) is None:
and user_sslopt.get('ca_cert', None) is None:
sslopt['ca_certs'] = certPath sslopt['ca_certs'] = certPath
elif certPath and os.path.isdir(certPath) \ elif certPath and os.path.isdir(certPath) \
and user_sslopt.get('ca_cert_path', None) is None: and user_sslopt.get('ca_cert_path', None) is None:
@ -259,12 +250,7 @@ def _ssl_socket(sock, user_sslopt, hostname):
check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop( check_hostname = sslopt["cert_reqs"] != ssl.CERT_NONE and sslopt.pop(
'check_hostname', True) 'check_hostname', True)
sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname)
if _can_use_sni():
sock = _wrap_sni_socket(sock, sslopt, hostname, check_hostname)
else:
sslopt.pop('check_hostname', True)
sock = ssl.wrap_socket(sock, **sslopt)
if not HAVE_CONTEXT_CHECK_HOSTNAME and check_hostname: if not HAVE_CONTEXT_CHECK_HOSTNAME and check_hostname:
match_hostname(sock.getpeercert(), hostname) match_hostname(sock.getpeercert(), hostname)

View File

@ -23,11 +23,9 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
""" """
import errno import errno
import select import selectors
import socket import socket
import six
from ._exceptions import * from ._exceptions import *
from ._ssl_compat import * from ._ssl_compat import *
from ._utils import * from ._utils import *
@ -101,7 +99,12 @@ def recv(sock, bufsize):
if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK: if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK:
raise raise
r, w, e = select.select((sock, ), (), (), sock.gettimeout()) sel = selectors.DefaultSelector()
sel.register(sock, selectors.EVENT_READ)
r = sel.select(sock.gettimeout())
sel.close()
if r: if r:
return sock.recv(bufsize) return sock.recv(bufsize)
@ -132,13 +135,13 @@ def recv_line(sock):
while True: while True:
c = recv(sock, 1) c = recv(sock, 1)
line.append(c) line.append(c)
if c == six.b("\n"): if c == b'\n':
break break
return six.b("").join(line) return b''.join(line)
def send(sock, data): def send(sock, data):
if isinstance(data, six.text_type): if isinstance(data, str):
data = data.encode('utf-8') data = data.encode('utf-8')
if not sock: if not sock:
@ -156,7 +159,12 @@ def send(sock, data):
if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK: if error_code != errno.EAGAIN or error_code != errno.EWOULDBLOCK:
raise raise
r, w, e = select.select((), (sock, ), (), sock.gettimeout()) sel = selectors.DefaultSelector()
sel.register(sock, selectors.EVENT_WRITE)
w = sel.select(sock.gettimeout())
sel.close()
if w: if w:
return sock.send(data) return sock.send(data)

View File

@ -25,20 +25,14 @@ try:
from ssl import SSLError from ssl import SSLError
from ssl import SSLWantReadError from ssl import SSLWantReadError
from ssl import SSLWantWriteError from ssl import SSLWantWriteError
HAVE_CONTEXT_CHECK_HOSTNAME = False
if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'): if hasattr(ssl, 'SSLContext') and hasattr(ssl.SSLContext, 'check_hostname'):
HAVE_CONTEXT_CHECK_HOSTNAME = True HAVE_CONTEXT_CHECK_HOSTNAME = True
else:
HAVE_CONTEXT_CHECK_HOSTNAME = False
if hasattr(ssl, "match_hostname"):
from ssl import match_hostname
else:
from backports.ssl_match_hostname import match_hostname
__all__.append("match_hostname")
__all__.append("HAVE_CONTEXT_CHECK_HOSTNAME")
__all__.append("HAVE_CONTEXT_CHECK_HOSTNAME")
HAVE_SSL = True HAVE_SSL = True
except ImportError: except ImportError:
# dummy class of SSLError for ssl none-support environment. # dummy class of SSLError for environment without ssl support
class SSLError(Exception): class SSLError(Exception):
pass pass
@ -49,5 +43,4 @@ except ImportError:
pass pass
ssl = None ssl = None
HAVE_SSL = False HAVE_SSL = False

View File

@ -26,7 +26,7 @@ import os
import socket import socket
import struct import struct
from six.moves.urllib.parse import urlparse from urllib.parse import urlparse
__all__ = ["parse_url", "get_proxy_info"] __all__ = ["parse_url", "get_proxy_info"]

View File

@ -18,8 +18,6 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
""" """
import six
__all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"] __all__ = ["NoLock", "validate_utf8", "extract_err_message", "extract_error_code"]
@ -80,8 +78,6 @@ except ImportError:
state = _UTF8_ACCEPT state = _UTF8_ACCEPT
codep = 0 codep = 0
for i in utfbytes: for i in utfbytes:
if six.PY2:
i = ord(i)
state, codep = _decode(state, codep, i) state, codep = _decode(state, codep, i)
if state == _UTF8_REJECT: if state == _UTF8_REJECT:
return False return False

View File

@ -2,5 +2,6 @@ HTTP/1.1 101 WebSocket Protocol Handshake
Connection: Upgrade, Keep-Alive Connection: Upgrade, Keep-Alive
Upgrade: WebSocket Upgrade: WebSocket
Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0= Sec-WebSocket-Accept: Kxep+hNu9n51529fGidYu7a3wO0=
Set-Cookie: Token=ABCDE
some_header: something some_header: something

View File

@ -25,13 +25,9 @@ import os
import websocket as ws import websocket as ws
from websocket._abnf import * from websocket._abnf import *
import sys import sys
import unittest
sys.path[0:0] = [""] sys.path[0:0] = [""]
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
class ABNFTest(unittest.TestCase): class ABNFTest(unittest.TestCase):
@ -48,19 +44,40 @@ class ABNFTest(unittest.TestCase):
self.assertEqual(a_bad.opcode, 77) self.assertEqual(a_bad.opcode, 77)
def testValidate(self): def testValidate(self):
a = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING) a_invalid_ping = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING)
self.assertRaises(ws.WebSocketProtocolException, a.validate) self.assertRaises(ws._exceptions.WebSocketProtocolException, a_invalid_ping.validate, skip_utf8_validation=False)
a_bad = ABNF(0,1,0,0, opcode=77) a_bad_rsv_value = ABNF(0,1,0,0, opcode=ABNF.OPCODE_TEXT)
self.assertRaises(ws.WebSocketProtocolException, a_bad.validate) self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_rsv_value.validate, skip_utf8_validation=False)
a_close = ABNF(0,1,0,0, opcode=ABNF.OPCODE_CLOSE, data="abcdefgh1234567890abcdefgh1234567890abcdefgh1234567890abcdefgh1234567890") a_bad_opcode = ABNF(0,0,0,0, opcode=77)
self.assertRaises(ws.WebSocketProtocolException, a_close.validate) self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_opcode.validate, skip_utf8_validation=False)
a_bad_close_frame = ABNF(0,0,0,0, opcode=ABNF.OPCODE_CLOSE, data=b'\x01')
self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_close_frame.validate, skip_utf8_validation=False)
a_bad_close_frame_2 = ABNF(0,0,0,0, opcode=ABNF.OPCODE_CLOSE, data=b'\x01\x8a\xaa\xff\xdd')
self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_close_frame_2.validate, skip_utf8_validation=False)
a_bad_close_frame_3 = ABNF(0,0,0,0, opcode=ABNF.OPCODE_CLOSE, data=b'\x03\xe7')
self.assertRaises(ws._exceptions.WebSocketProtocolException, a_bad_close_frame_3.validate, skip_utf8_validation=True)
# This caused an error in the Python 2.7 Github Actions build def testMask(self):
# Uncomment test case when Python 2 support no longer wanted abnf_none_data = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING, mask=1, data=None)
# def testMask(self): bytes_val = bytes("aaaa", 'utf-8')
# ab = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING) self.assertEqual(abnf_none_data._get_masked(bytes_val), bytes_val)
# bytes_val = bytes("aaaa", 'utf-8') abnf_str_data = ABNF(0,0,0,0, opcode=ABNF.OPCODE_PING, mask=1, data="a")
# self.assertEqual(ab._get_masked(bytes_val), bytes_val) self.assertEqual(abnf_str_data._get_masked(bytes_val), b'aaaa\x00')
def testFormat(self):
abnf_bad_rsv_bits = ABNF(2,0,0,0, opcode=ABNF.OPCODE_TEXT)
self.assertRaises(ValueError, abnf_bad_rsv_bits.format)
abnf_bad_opcode = ABNF(0,0,0,0, opcode=5)
self.assertRaises(ValueError, abnf_bad_opcode.format)
abnf_length_10 = ABNF(0,0,0,0, opcode=ABNF.OPCODE_TEXT, data="abcdefghij")
self.assertEqual(b'\x01', abnf_length_10.format()[0].to_bytes(1, 'big'))
self.assertEqual(b'\x8a', abnf_length_10.format()[1].to_bytes(1, 'big'))
self.assertEqual("fin=0 opcode=1 data=abcdefghij", abnf_length_10.__str__())
abnf_length_20 = ABNF(0,0,0,0, opcode=ABNF.OPCODE_BINARY, data="abcdefghijabcdefghij")
self.assertEqual(b'\x02', abnf_length_20.format()[0].to_bytes(1, 'big'))
self.assertEqual(b'\x94', abnf_length_20.format()[1].to_bytes(1, 'big'))
abnf_no_mask = ABNF(0,0,0,0, opcode=ABNF.OPCODE_TEXT, mask=0, data=b'\x01\x8a\xcc')
self.assertEqual(b'\x01\x03\x01\x8a\xcc', abnf_no_mask.format())
def testFrameBuffer(self): def testFrameBuffer(self):
fb = frame_buffer(0, True) fb = frame_buffer(0, True)

View File

@ -25,18 +25,10 @@ import os
import os.path import os.path
import websocket as ws import websocket as ws
import sys import sys
import ssl
import unittest
sys.path[0:0] = [""] sys.path[0:0] = [""]
try:
import ssl
except ImportError:
HAVE_SSL = False
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
# Skip test to access the internet. # Skip test to access the internet.
TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1' TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1'
TRACEABLE = True TRACEABLE = True
@ -71,27 +63,17 @@ class WebSocketAppTest(unittest.TestCase):
close the connection. close the connection.
""" """
WebSocketAppTest.keep_running_open = self.keep_running WebSocketAppTest.keep_running_open = self.keep_running
self.close() self.close()
def on_close(self, *args, **kwargs): def on_close(self, *args, **kwargs):
""" Set the keep_running flag for the test to use. """ Set the keep_running flag for the test to use.
""" """
WebSocketAppTest.keep_running_close = self.keep_running WebSocketAppTest.keep_running_close = self.keep_running
self.send("connection should be closed here")
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close) app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, on_close=on_close)
app.run_forever() app.run_forever()
# if numpy is installed, this assertion fail
# self.assertFalse(isinstance(WebSocketAppTest.keep_running_open,
# WebSocketAppTest.NotSetYet))
# self.assertFalse(isinstance(WebSocketAppTest.keep_running_close,
# WebSocketAppTest.NotSetYet))
# self.assertEqual(True, WebSocketAppTest.keep_running_open)
# self.assertEqual(False, WebSocketAppTest.keep_running_close)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSockMaskKey(self): def testSockMaskKey(self):
""" A WebSocketApp should forward the received mask_key function down """ A WebSocketApp should forward the received mask_key function down
@ -99,25 +81,17 @@ class WebSocketAppTest(unittest.TestCase):
""" """
def my_mask_key_func(): def my_mask_key_func():
pass return "\x00\x00\x00\x00"
def on_open(self, *args, **kwargs): app = ws.WebSocketApp('wss://stream.meetup.com/2/rsvps', get_mask_key=my_mask_key_func)
""" Set the value so the test can use it later on and immediately
close the connection.
"""
WebSocketAppTest.get_mask_key_id = id(self.get_mask_key)
self.close()
app = ws.WebSocketApp('ws://echo.websocket.org/', on_open=on_open, get_mask_key=my_mask_key_func)
app.run_forever()
# if numpy is installed, this assertion fail # if numpy is installed, this assertion fail
# Note: We can't use 'is' for comparing the functions directly, need to use 'id'. # Note: We can't use 'is' for comparing the functions directly, need to use 'id'.
# self.assertEqual(WebSocketAppTest.get_mask_key_id, id(my_mask_key_func)) self.assertEqual(id(app.get_mask_key), id(my_mask_key_func))
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testPingInterval(self): def testInvalidPingIntervalPingTimeout(self):
""" A WebSocketApp should ping regularly """ Test exception handling if ping_interval < ping_timeout
""" """
def on_ping(app, msg): def on_ping(app, msg):
@ -129,8 +103,73 @@ class WebSocketAppTest(unittest.TestCase):
app.close() app.close()
app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1', on_ping=on_ping, on_pong=on_pong) app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1', on_ping=on_ping, on_pong=on_pong)
app.run_forever(ping_interval=2, ping_timeout=1) # , sslopt={"cert_reqs": ssl.CERT_NONE} self.assertRaises(ws.WebSocketException, app.run_forever, ping_interval=1, ping_timeout=2, sslopt={"cert_reqs": ssl.CERT_NONE})
self.assertRaises(ws.WebSocketException, app.run_forever, ping_interval=2, ping_timeout=3, sslopt={"cert_reqs": ssl.CERT_NONE})
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testPingInterval(self):
""" Test WebSocketApp proper ping functionality
"""
def on_ping(app, msg):
print("Got a ping!")
app.close()
def on_pong(app, msg):
print("Got a pong! No need to respond")
app.close()
app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1', on_ping=on_ping, on_pong=on_pong)
app.run_forever(ping_interval=2, ping_timeout=1, sslopt={"cert_reqs": ssl.CERT_NONE})
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testOpcodeClose(self):
""" Test WebSocketApp close opcode
"""
app = ws.WebSocketApp('wss://tsock.us1.twilio.com/v3/wsconnect')
app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testOpcodeBinary(self):
""" Test WebSocketApp binary opcode
"""
app = ws.WebSocketApp('streaming.vn.teslamotors.com/streaming/')
app.run_forever(ping_interval=2, ping_timeout=1, ping_payload="Ping payload")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testBadPingInterval(self):
""" A WebSocketApp handling of negative ping_interval
"""
app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1')
self.assertRaises(ws.WebSocketException, app.run_forever, ping_interval=-5, sslopt={"cert_reqs": ssl.CERT_NONE})
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testBadPingTimeout(self):
""" A WebSocketApp handling of negative ping_timeout
"""
app = ws.WebSocketApp('wss://api-pub.bitfinex.com/ws/1')
self.assertRaises(ws.WebSocketException, app.run_forever, ping_timeout=-3, sslopt={"cert_reqs": ssl.CERT_NONE})
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testCloseStatusCode(self):
""" Test extraction of close frame status code and close reason in WebSocketApp
"""
def on_close(wsapp, close_status_code, close_msg):
print("on_close reached")
app = ws.WebSocketApp('wss://tsock.us1.twilio.com/v3/wsconnect', on_close=on_close)
closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b'\x03\xe8no-init-from-client')
self.assertEqual([1000, 'no-init-from-client'], app._get_close_args(closeframe))
closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b'')
self.assertEqual([None, None], app._get_close_args(closeframe))
app2 = ws.WebSocketApp('wss://tsock.us1.twilio.com/v3/wsconnect')
closeframe = ws.ABNF(opcode=ws.ABNF.OPCODE_CLOSE, data=b'')
self.assertEqual([None, None], app2._get_close_args(closeframe))
self.assertRaises(ws.WebSocketConnectionClosedException, app.send, data="test if connection is closed")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -49,6 +49,7 @@ class CookieJarTest(unittest.TestCase):
cookie_jar = SimpleCookieJar() cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc") cookie_jar.add("a=b; c=d; domain=abc")
self.assertEqual(cookie_jar.get("abc"), "a=b; c=d") self.assertEqual(cookie_jar.get("abc"), "a=b; c=d")
self.assertEqual(cookie_jar.get(None), "")
cookie_jar = SimpleCookieJar() cookie_jar = SimpleCookieJar()
cookie_jar.add("a=b; c=d; domain=abc") cookie_jar.add("a=b; c=d; domain=abc")

View File

@ -24,14 +24,17 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import os import os
import os.path import os.path
import websocket as ws import websocket as ws
from websocket._http import proxy_info, read_headers, _open_proxied_socket, _tunnel from websocket._http import proxy_info, read_headers, _open_proxied_socket, _tunnel, _get_addrinfo_list, connect
import sys import sys
import unittest
import ssl
import websocket
import socks
import socket
sys.path[0:0] = [""] sys.path[0:0] = [""]
if sys.version_info[0] == 2 and sys.version_info[1] < 7: # Skip test to access the internet.
import unittest2 as unittest TEST_WITH_INTERNET = os.environ.get('TEST_WITH_INTERNET', '0') == '1'
else:
import unittest
class SockMock(object): class SockMock(object):
@ -74,7 +77,7 @@ class HeaderSockMock(SockMock):
class OptsList(): class OptsList():
def __init__(self): def __init__(self):
self.timeout = 0 self.timeout = 1
self.sockopt = [] self.sockopt = []
@ -91,11 +94,49 @@ class HttpTest(unittest.TestCase):
self.assertRaises(ws.WebSocketProxyException, _tunnel, HeaderSockMock("data/header01.txt"), "example.com", 80, ("username", "password")) self.assertRaises(ws.WebSocketProxyException, _tunnel, HeaderSockMock("data/header01.txt"), "example.com", 80, ("username", "password"))
self.assertRaises(ws.WebSocketProxyException, _tunnel, HeaderSockMock("data/header02.txt"), "example.com", 80, ("username", "password")) self.assertRaises(ws.WebSocketProxyException, _tunnel, HeaderSockMock("data/header02.txt"), "example.com", 80, ("username", "password"))
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testConnect(self): def testConnect(self):
# Not currently testing an actual proxy connection, so just check whether TypeError is raised # Not currently testing an actual proxy connection, so just check whether TypeError is raised. This requires internet for a DNS lookup
self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host=None, http_proxy_port=None, proxy_type=None))
self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="http")) self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="http"))
self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks4")) self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks4"))
self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks5h")) self.assertRaises(TypeError, _open_proxied_socket, "wss://example.com", OptsList(), proxy_info(http_proxy_host="example.com", http_proxy_port="8080", proxy_type="socks5h"))
self.assertRaises(TypeError, _get_addrinfo_list, None, 80, True, proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"))
self.assertRaises(TypeError, _get_addrinfo_list, None, 80, True, proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http"))
self.assertRaises(socks.ProxyConnectionError, connect, "wss://example.com", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port=8080, proxy_type="socks4"), None)
self.assertRaises(socket.timeout, connect, "wss://google.com", OptsList(), proxy_info(http_proxy_host="8.8.8.8", http_proxy_port=8080, proxy_type="http"), None)
self.assertEqual(
connect("wss://google.com", OptsList(), proxy_info(http_proxy_host="8.8.8.8", http_proxy_port=8080, proxy_type="http"), True),
(True, ("google.com", 443, "/")))
# The following test fails on Mac OS with a gaierror, not an OverflowError
# self.assertRaises(OverflowError, connect, "wss://example.com", OptsList(), proxy_info(http_proxy_host="127.0.0.1", http_proxy_port=99999, proxy_type="socks4", timeout=2), False)
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testSSLopt(self):
ssloptions = {
"cert_reqs": ssl.CERT_NONE,
"check_hostname": False,
"ssl_version": ssl.PROTOCOL_SSLv23,
"ciphers": "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256:\
TLS_AES_128_GCM_SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:\
ECDHE-RSA-AES256-GCM-SHA384:DHE-RSA-AES256-GCM-SHA384:\
ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:\
DHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-GCM-SHA256:\
ECDHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES128-GCM-SHA256:\
ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:\
DHE-RSA-AES256-SHA256:ECDHE-ECDSA-AES128-SHA256:\
ECDHE-RSA-AES128-SHA256:DHE-RSA-AES128-SHA256:\
ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA",
"ecdh_curve": "prime256v1"
}
ws_ssl1 = websocket.WebSocket(sslopt=ssloptions)
ws_ssl1.connect("wss://api.bitfinex.com/ws/2")
ws_ssl1.send("Hello")
ws_ssl1.close()
ws_ssl2 = websocket.WebSocket(sslopt={"check_hostname": True})
ws_ssl2.connect("wss://api.bitfinex.com/ws/2")
ws_ssl2.close
def testProxyInfo(self): def testProxyInfo(self):
self.assertEqual(proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http").type, "http") self.assertEqual(proxy_info(http_proxy_host="127.0.0.1", http_proxy_port="8080", proxy_type="http").type, "http")

View File

@ -23,14 +23,9 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import sys import sys
import os import os
import unittest
from websocket._url import get_proxy_info, parse_url, _is_address_in_network, _is_no_proxy_host
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
sys.path[0:0] = [""] sys.path[0:0] = [""]
from websocket._url import get_proxy_info, parse_url, _is_address_in_network, _is_no_proxy_host
class UrlTest(unittest.TestCase): class UrlTest(unittest.TestCase):
@ -97,9 +92,6 @@ class UrlTest(unittest.TestCase):
self.assertRaises(ValueError, parse_url, "http://www.example.com/r") self.assertRaises(ValueError, parse_url, "http://www.example.com/r")
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
return
p = parse_url("ws://[2a03:4000:123:83::3]/r") p = parse_url("ws://[2a03:4000:123:83::3]/r")
self.assertEqual(p[0], "2a03:4000:123:83::3") self.assertEqual(p[0], "2a03:4000:123:83::3")
self.assertEqual(p[1], 80) self.assertEqual(p[1], 80)

View File

@ -27,31 +27,20 @@ Copyright (C) 2010 Hiroki Ohtani(liris)
import sys import sys
sys.path[0:0] = [""] sys.path[0:0] = [""]
import os import os
import os.path import os.path
import socket import socket
import six
# websocket-client
import websocket as ws import websocket as ws
from websocket._handshake import _create_sec_websocket_key, \ from websocket._handshake import _create_sec_websocket_key, \
_validate as _validate_header _validate as _validate_header
from websocket._http import read_headers from websocket._http import read_headers
from websocket._utils import validate_utf8 from websocket._utils import validate_utf8
from base64 import decodebytes as base64decode
if six.PY3: import unittest
from base64 import decodebytes as base64decode
else:
from base64 import decodestring as base64decode
if sys.version_info[0] == 2 and sys.version_info[1] < 7:
import unittest2 as unittest
else:
import unittest
try: try:
import ssl
from ssl import SSLError from ssl import SSLError
except ImportError: except ImportError:
# dummy class of SSLError for ssl none-support environment. # dummy class of SSLError for ssl none-support environment.
@ -120,7 +109,7 @@ class WebSocketTest(unittest.TestCase):
def testWSKey(self): def testWSKey(self):
key = _create_sec_websocket_key() key = _create_sec_websocket_key()
self.assertTrue(key != 24) self.assertTrue(key != 24)
self.assertTrue(six.u("¥n") not in key) self.assertTrue(str("¥n") not in key)
def testNonce(self): def testNonce(self):
""" WebSocket key should be a random 16-byte nonce. """ WebSocket key should be a random 16-byte nonce.
@ -158,6 +147,7 @@ class WebSocketTest(unittest.TestCase):
header = required_header.copy() header = required_header.copy()
header["sec-websocket-protocol"] = "sub1" header["sec-websocket-protocol"] = "sub1"
self.assertEqual(_validate_header(header, key, ["sub1", "sub2"]), (True, "sub1")) self.assertEqual(_validate_header(header, key, ["sub1", "sub2"]), (True, "sub1"))
# This case will print out a logging error using the error() function, but that is expected
self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None)) self.assertEqual(_validate_header(header, key, ["sub2", "sub3"]), (False, None))
header = required_header.copy() header = required_header.copy()
@ -165,6 +155,7 @@ class WebSocketTest(unittest.TestCase):
self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (True, "sub1")) self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (True, "sub1"))
header = required_header.copy() header = required_header.copy()
# This case will print out a logging error using the error() function, but that is expected
self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None)) self.assertEqual(_validate_header(header, key, ["Sub1", "suB2"]), (False, None))
def testReadHeader(self): def testReadHeader(self):
@ -185,16 +176,13 @@ class WebSocketTest(unittest.TestCase):
sock.set_mask_key(create_mask_key) sock.set_mask_key(create_mask_key)
s = sock.sock = HeaderSockMock("data/header01.txt") s = sock.sock = HeaderSockMock("data/header01.txt")
sock.send("Hello") sock.send("Hello")
self.assertEqual(s.sent[0], six.b("\x81\x85abcd)\x07\x0f\x08\x0e")) self.assertEqual(s.sent[0], b'\x81\x85abcd)\x07\x0f\x08\x0e')
sock.send("こんにちは") sock.send("こんにちは")
self.assertEqual(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")) self.assertEqual(s.sent[1], b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc')
sock.send(u"こんにちは")
self.assertEqual(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc"))
# sock.send("x" * 5000) # sock.send("x" * 5000)
# self.assertEqual(s.sent[1], six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")) # self.assertEqual(s.sent[1], b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc")
self.assertEqual(sock.send_binary(b'1111111111101'), 19) self.assertEqual(sock.send_binary(b'1111111111101'), 19)
@ -202,12 +190,12 @@ class WebSocketTest(unittest.TestCase):
# TODO: add longer frame data # TODO: add longer frame data
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
something = six.b("\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc") something = b'\x81\x8fabcd\x82\xe3\xf0\x87\xe3\xf1\x80\xe5\xca\x81\xe2\xc5\x82\xe3\xcc'
s.add_packet(something) s.add_packet(something)
data = sock.recv() data = sock.recv()
self.assertEqual(data, "こんにちは") self.assertEqual(data, "こんにちは")
s.add_packet(six.b("\x81\x85abcd)\x07\x0f\x08\x0e")) s.add_packet(b'\x81\x85abcd)\x07\x0f\x08\x0e')
data = sock.recv() data = sock.recv()
self.assertEqual(data, "Hello") self.assertEqual(data, "Hello")
@ -227,32 +215,28 @@ class WebSocketTest(unittest.TestCase):
def testInternalRecvStrict(self): def testInternalRecvStrict(self):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
s.add_packet(six.b("foo")) s.add_packet(b'foo')
s.add_packet(socket.timeout()) s.add_packet(socket.timeout())
s.add_packet(six.b("bar")) s.add_packet(b'bar')
# s.add_packet(SSLError("The read operation timed out")) # s.add_packet(SSLError("The read operation timed out"))
s.add_packet(six.b("baz")) s.add_packet(b'baz')
with self.assertRaises(ws.WebSocketTimeoutException): with self.assertRaises(ws.WebSocketTimeoutException):
sock.frame_buffer.recv_strict(9) sock.frame_buffer.recv_strict(9)
# if six.PY2:
# with self.assertRaises(ws.WebSocketTimeoutException):
# data = sock._recv_strict(9)
# else:
# with self.assertRaises(SSLError): # with self.assertRaises(SSLError):
# data = sock._recv_strict(9) # data = sock._recv_strict(9)
data = sock.frame_buffer.recv_strict(9) data = sock.frame_buffer.recv_strict(9)
self.assertEqual(data, six.b("foobarbaz")) self.assertEqual(data, b'foobarbaz')
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.frame_buffer.recv_strict(1) sock.frame_buffer.recv_strict(1)
def testRecvTimeout(self): def testRecvTimeout(self):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
s.add_packet(six.b("\x81")) s.add_packet(b'\x81')
s.add_packet(socket.timeout()) s.add_packet(socket.timeout())
s.add_packet(six.b("\x8dabcd\x29\x07\x0f\x08\x0e")) s.add_packet(b'\x8dabcd\x29\x07\x0f\x08\x0e')
s.add_packet(socket.timeout()) s.add_packet(socket.timeout())
s.add_packet(six.b("\x4e\x43\x33\x0e\x10\x0f\x00\x40")) s.add_packet(b'\x4e\x43\x33\x0e\x10\x0f\x00\x40')
with self.assertRaises(ws.WebSocketTimeoutException): with self.assertRaises(ws.WebSocketTimeoutException):
sock.recv() sock.recv()
with self.assertRaises(ws.WebSocketTimeoutException): with self.assertRaises(ws.WebSocketTimeoutException):
@ -266,9 +250,9 @@ class WebSocketTest(unittest.TestCase):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is " # OPCODE=TEXT, FIN=0, MSG="Brevity is "
s.add_packet(six.b("\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")) s.add_packet(b'\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
# OPCODE=CONT, FIN=1, MSG="the soul of wit" # OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")) s.add_packet(b'\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17')
data = sock.recv() data = sock.recv()
self.assertEqual(data, "Brevity is the soul of wit") self.assertEqual(data, "Brevity is the soul of wit")
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
@ -278,21 +262,21 @@ class WebSocketTest(unittest.TestCase):
sock = ws.WebSocket(fire_cont_frame=True) sock = ws.WebSocket(fire_cont_frame=True)
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Brevity is " # OPCODE=TEXT, FIN=0, MSG="Brevity is "
s.add_packet(six.b("\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")) s.add_packet(b'\x01\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
# OPCODE=CONT, FIN=0, MSG="Brevity is " # OPCODE=CONT, FIN=0, MSG="Brevity is "
s.add_packet(six.b("\x00\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")) s.add_packet(b'\x00\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
# OPCODE=CONT, FIN=1, MSG="the soul of wit" # OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")) s.add_packet(b'\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17')
_, data = sock.recv_data() _, data = sock.recv_data()
self.assertEqual(data, six.b("Brevity is ")) self.assertEqual(data, b'Brevity is ')
_, data = sock.recv_data() _, data = sock.recv_data()
self.assertEqual(data, six.b("Brevity is ")) self.assertEqual(data, b'Brevity is ')
_, data = sock.recv_data() _, data = sock.recv_data()
self.assertEqual(data, six.b("the soul of wit")) self.assertEqual(data, b'the soul of wit')
# OPCODE=CONT, FIN=0, MSG="Brevity is " # OPCODE=CONT, FIN=0, MSG="Brevity is "
s.add_packet(six.b("\x80\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C")) s.add_packet(b'\x80\x8babcd#\x10\x06\x12\x08\x16\x1aD\x08\x11C')
with self.assertRaises(ws.WebSocketException): with self.assertRaises(ws.WebSocketException):
sock.recv_data() sock.recv_data()
@ -302,15 +286,13 @@ class WebSocketTest(unittest.TestCase):
def testClose(self): def testClose(self):
sock = ws.WebSocket() sock = ws.WebSocket()
sock.sock = SockMock()
sock.connected = True sock.connected = True
sock.close() self.assertRaises(ws._exceptions.WebSocketConnectionClosedException, sock.close)
self.assertEqual(sock.connected, False)
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
sock.connected = True sock.connected = True
s.add_packet(six.b('\x88\x80\x17\x98p\x84')) s.add_packet(b'\x88\x80\x17\x98p\x84')
sock.recv() sock.recv()
self.assertEqual(sock.connected, False) self.assertEqual(sock.connected, False)
@ -318,20 +300,18 @@ class WebSocketTest(unittest.TestCase):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=CONT, FIN=1, MSG="the soul of wit" # OPCODE=CONT, FIN=1, MSG="the soul of wit"
s.add_packet(six.b("\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17")) s.add_packet(b'\x80\x8fabcd\x15\n\x06D\x12\r\x16\x08A\r\x05D\x16\x0b\x17')
self.assertRaises(ws.WebSocketException, sock.recv) self.assertRaises(ws.WebSocketException, sock.recv)
def testRecvWithProlongedFragmentation(self): def testRecvWithProlongedFragmentation(self):
sock = ws.WebSocket() sock = ws.WebSocket()
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, " # OPCODE=TEXT, FIN=0, MSG="Once more unto the breach, "
s.add_packet(six.b("\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15" s.add_packet(b'\x01\x9babcd.\x0c\x00\x01A\x0f\x0c\x16\x04B\x16\n\x15\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC')
"\rC\x10\t\x07C\x06\x13\x07\x02\x07\tNC"))
# OPCODE=CONT, FIN=0, MSG="dear friends, " # OPCODE=CONT, FIN=0, MSG="dear friends, "
s.add_packet(six.b("\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07" s.add_packet(b'\x00\x8eabcd\x05\x07\x02\x16A\x04\x11\r\x04\x0c\x07\x17MB')
"\x17MB"))
# OPCODE=CONT, FIN=1, MSG="once more" # OPCODE=CONT, FIN=1, MSG="once more"
s.add_packet(six.b("\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04")) s.add_packet(b'\x80\x89abcd\x0e\x0c\x00\x01A\x0f\x0c\x16\x04')
data = sock.recv() data = sock.recv()
self.assertEqual( self.assertEqual(
data, data,
@ -344,19 +324,18 @@ class WebSocketTest(unittest.TestCase):
sock.set_mask_key(create_mask_key) sock.set_mask_key(create_mask_key)
s = sock.sock = SockMock() s = sock.sock = SockMock()
# OPCODE=TEXT, FIN=0, MSG="Too much " # OPCODE=TEXT, FIN=0, MSG="Too much "
s.add_packet(six.b("\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA")) s.add_packet(b'\x01\x89abcd5\r\x0cD\x0c\x17\x00\x0cA')
# OPCODE=PING, FIN=1, MSG="Please PONG this" # OPCODE=PING, FIN=1, MSG="Please PONG this"
s.add_packet(six.b("\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")) s.add_packet(b'\x89\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17')
# OPCODE=CONT, FIN=1, MSG="of a good thing" # OPCODE=CONT, FIN=1, MSG="of a good thing"
s.add_packet(six.b("\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c" s.add_packet(b'\x80\x8fabcd\x0e\x04C\x05A\x05\x0c\x0b\x05B\x17\x0c\x08\x0c\x04')
"\x08\x0c\x04"))
data = sock.recv() data = sock.recv()
self.assertEqual(data, "Too much of a good thing") self.assertEqual(data, "Too much of a good thing")
with self.assertRaises(ws.WebSocketConnectionClosedException): with self.assertRaises(ws.WebSocketConnectionClosedException):
sock.recv() sock.recv()
self.assertEqual( self.assertEqual(
s.sent[0], s.sent[0],
six.b("\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17")) b'\x8a\x90abcd1\x0e\x06\x05\x12\x07C4.,$D\x15\n\n\x17')
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testWebSocket(self): def testWebSocket(self):
@ -366,7 +345,7 @@ class WebSocketTest(unittest.TestCase):
result = s.recv() result = s.recv()
self.assertEqual(result, "Hello, World") self.assertEqual(result, "Hello, World")
s.send(u"こにゃにゃちは、世界") s.send("こにゃにゃちは、世界")
result = s.recv() result = s.recv()
self.assertEqual(result, "こにゃにゃちは、世界") self.assertEqual(result, "こにゃにゃちは、世界")
self.assertRaises(ValueError, s.send_close, -1, "") self.assertRaises(ValueError, s.send_close, -1, "")
@ -388,7 +367,10 @@ class WebSocketTest(unittest.TestCase):
self.assertTrue(isinstance(s.sock, ssl.SSLSocket)) self.assertTrue(isinstance(s.sock, ssl.SSLSocket))
self.assertEqual(s.getstatus(), 101) self.assertEqual(s.getstatus(), 101)
self.assertNotEqual(s.getheaders(), None) self.assertNotEqual(s.getheaders(), None)
s.close() s.settimeout(10)
self.assertEqual(s.gettimeout(), 10)
self.assertEqual(s.getsubprotocol(), None)
s.abort()
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled") @unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testWebSocketWithCustomHeader(self): def testWebSocketWithCustomHeader(self):
@ -421,13 +403,50 @@ class SockOptTest(unittest.TestCase):
class UtilsTest(unittest.TestCase): class UtilsTest(unittest.TestCase):
def testUtf8Validator(self): def testUtf8Validator(self):
state = validate_utf8(six.b('\xf0\x90\x80\x80')) state = validate_utf8(b'\xf0\x90\x80\x80')
self.assertEqual(state, True) self.assertEqual(state, True)
state = validate_utf8(six.b('\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5\xed\xa0\x80edited')) state = validate_utf8(b'\xce\xba\xe1\xbd\xb9\xcf\x83\xce\xbc\xce\xb5\xed\xa0\x80edited')
self.assertEqual(state, False) self.assertEqual(state, False)
state = validate_utf8(six.b('')) state = validate_utf8(b'')
self.assertEqual(state, True) self.assertEqual(state, True)
class HandshakeTest(unittest.TestCase):
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def test_http_SSL(self):
websock1 = ws.WebSocket(sslopt={"cert_chain": ssl.get_default_verify_paths().capath})
self.assertRaises(ValueError,
websock1.connect, "wss://api.bitfinex.com/ws/2")
websock2 = ws.WebSocket(sslopt={"certfile": "myNonexistentCertFile"})
self.assertRaises(FileNotFoundError,
websock2.connect, "wss://api.bitfinex.com/ws/2")
@unittest.skipUnless(TEST_WITH_INTERNET, "Internet-requiring tests are disabled")
def testManualHeaders(self):
websock3 = ws.WebSocket(sslopt={"cert_reqs": ssl.CERT_NONE,
"ca_certs": ssl.get_default_verify_paths().capath,
"ca_cert_path": ssl.get_default_verify_paths().openssl_cafile})
self.assertRaises(ws._exceptions.WebSocketBadStatusException,
websock3.connect, "wss://api.bitfinex.com/ws/2", cookie="chocolate",
origin="testing_websockets.com",
host="echo.websocket.org/websocket-client-test",
subprotocols=["testproto"],
connection="Upgrade",
header={"CustomHeader1":"123",
"Cookie":"TestValue",
"Sec-WebSocket-Key":"k9kFAUWNAMmf5OEMfTlOEA==",
"Sec-WebSocket-Protocol":"newprotocol"})
def testIPv6(self):
websock2 = ws.WebSocket()
self.assertRaises(ValueError, websock2.connect, "2001:4860:4860::8888")
def testBadURLs(self):
websock3 = ws.WebSocket()
self.assertRaises(ValueError, websock3.connect, "ws//example.com")
self.assertRaises(ws.WebSocketAddressException, websock3.connect, "ws://example")
self.assertRaises(ValueError, websock3.connect, "example.com")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()