mirror of
https://github.com/morpheus65535/bazarr
synced 2024-12-23 00:03:33 +00:00
571 lines
21 KiB
Python
|
##############################################################################
#
# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
# All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE.
#
##############################################################################
|
||
|
|
||
|
from collections import deque
|
||
|
import socket
|
||
|
import sys
|
||
|
import threading
|
||
|
import time
|
||
|
|
||
|
from .buffers import ReadOnlyFileBasedBuffer
|
||
|
from .utilities import build_http_date, logger, queue_logger
|
||
|
|
||
|
# Request headers that are copied into the WSGI environ under their own name
# instead of receiving the usual "HTTP_" prefix; each entry maps to itself.
rename_headers = {name: name for name in ("CONTENT_LENGTH", "CONTENT_TYPE")}
|
||
|
|
||
|
# Lower-cased names of hop-by-hop headers. A WSGI application that passes any
# of these to start_response is rejected with an AssertionError.
hop_by_hop = frozenset(
    "connection keep-alive proxy-authenticate proxy-authorization "
    "te trailers transfer-encoding upgrade".split()
)
|
||
|
|
||
|
|
||
|
class ThreadedTaskDispatcher:
    """Dispatch tasks to a resizable pool of worker threads.

    Tasks are appended to a FIFO queue by ``add_task`` and serviced by the
    worker threads started through ``set_thread_count``.
    """

    stop_count = 0  # Number of threads that will stop soon.
    active_count = 0  # Number of currently active threads
    logger = logger
    queue_logger = queue_logger

    def __init__(self):
        self.threads = set()  # numbers of the worker threads still running
        self.queue = deque()  # pending tasks, serviced in FIFO order
        self.lock = threading.Lock()
        # Both conditions share self.lock, so all queue and thread bookkeeping
        # happens under one mutex.
        self.queue_cv = threading.Condition(self.lock)
        self.thread_exit_cv = threading.Condition(self.lock)

    def start_new_thread(self, target, thread_no):
        """Start a daemon thread named ``waitress-<thread_no>`` that runs
        ``target(thread_no)``."""
        t = threading.Thread(
            target=target, name="waitress-{}".format(thread_no), args=(thread_no,)
        )
        t.daemon = True
        t.start()

    def handler_thread(self, thread_no):
        """Worker loop: service queued tasks until asked to stop.

        A worker exits when it observes ``stop_count > 0``; it then removes
        ``thread_no`` from ``self.threads`` and notifies ``thread_exit_cv``.
        """
        while True:
            with self.lock:
                while not self.queue and self.stop_count == 0:
                    # Mark ourselves as idle before waiting to be
                    # woken up, then we will once again be active
                    self.active_count -= 1
                    self.queue_cv.wait()
                    self.active_count += 1

                if self.stop_count > 0:
                    self.active_count -= 1
                    self.stop_count -= 1
                    self.threads.discard(thread_no)
                    self.thread_exit_cv.notify()
                    break

                task = self.queue.popleft()
            # Service the task outside the lock so other workers keep running.
            # Every exception (BaseException) is logged instead of being
            # allowed to end this worker thread.
            try:
                task.service()
            except BaseException:
                self.logger.exception("Exception when servicing %r", task)

    def set_thread_count(self, count):
        """Grow or shrink the worker pool to ``count`` threads.

        New workers receive the lowest thread numbers not already in use.
        Surplus workers are retired by raising ``stop_count`` and waking every
        waiting worker.
        """
        with self.lock:
            threads = self.threads
            thread_no = 0
            running = len(threads) - self.stop_count
            while running < count:
                # Start threads.
                while thread_no in threads:
                    thread_no = thread_no + 1
                threads.add(thread_no)
                running += 1
                self.start_new_thread(self.handler_thread, thread_no)
                self.active_count += 1
                thread_no = thread_no + 1
            if running > count:
                # Stop threads.
                self.stop_count += running - count
                self.queue_cv.notify_all()

    def add_task(self, task):
        """Queue ``task`` and wake one waiting worker.

        Logs a warning on ``queue_logger`` when more tasks are queued than
        there are idle workers to take them.
        """
        with self.lock:
            self.queue.append(task)
            self.queue_cv.notify()
            queue_size = len(self.queue)
            idle_threads = len(self.threads) - self.stop_count - self.active_count
            if queue_size > idle_threads:
                self.queue_logger.warning(
                    "Task queue depth is %d", queue_size - idle_threads
                )

    def shutdown(self, cancel_pending=True, timeout=5):
        """Stop every worker thread.

        Waits up to ``timeout`` seconds for the workers to exit and logs a
        warning if some are still running after that. When ``cancel_pending``
        is true, every task still queued is removed and ``task.cancel()`` is
        called on it, and True is returned; otherwise False is returned.
        """
        self.set_thread_count(0)
        # Ensure the threads shut down.
        threads = self.threads
        # Measure the deadline on a monotonic clock, so that adjustments to the
        # system wall clock (NTP, manual changes) cannot cut the wait short or
        # stretch it out.
        expiration = time.monotonic() + timeout
        with self.lock:
            while threads:
                if time.monotonic() >= expiration:
                    self.logger.warning("%d thread(s) still running", len(threads))
                    break
                self.thread_exit_cv.wait(0.1)
            if cancel_pending:
                # Cancel remaining tasks.
                queue = self.queue
                if len(queue) > 0:
                    self.logger.warning("Canceling %d pending task(s)", len(queue))
                while queue:
                    task = queue.popleft()
                    task.cancel()
                self.queue_cv.notify_all()
                return True
        return False
|
||
|
|
||
|
|
||
|
class Task:
    """Base class for one unit of work that writes an HTTP response to a
    channel. Subclasses provide ``execute()``, which ``service()`` calls."""

    close_on_finish = False  # True once the connection is to be closed after this response
    status = "200 OK"
    wrote_header = False  # True once the status line and headers were queued
    start_time = 0  # wall-clock time set by start(); used for the Date header
    content_length = None  # declared body length in bytes, if known
    content_bytes_written = 0
    logged_write_excess = False  # "too much body" warning already logged
    logged_write_no_body = False  # "body ignored" warning already logged
    complete = False  # True once the status/headers are set; write() needs it
    chunked_response = False  # True when the body uses chunked encoding
    logger = logger

    def __init__(self, channel, request):
        self.channel = channel
        self.request = request
        self.response_headers = []
        version = request.version
        if version not in ("1.0", "1.1"):
            # fall back to a version we support.
            version = "1.0"
        self.version = version

    def service(self):
        """Run ``start``, ``execute`` and ``finish`` for this task.

        An ``OSError`` marks the connection for closing. It is re-raised only
        when ``self.channel.adj.log_socket_errors`` is enabled.
        """
        try:
            self.start()
            self.execute()
            self.finish()
        except OSError:
            self.close_on_finish = True
            if self.channel.adj.log_socket_errors:
                raise

    @property
    def has_body(self):
        """False for 1xx, 204 and 304 responses, which carry no message body."""
        return not self.status.startswith(("1", "204", "304"))

    def build_response_header(self):
        """Build the status line and header block for this response.

        Normalizes header-name capitalization and adds Content-Length,
        Connection, Transfer-Encoding, Server/Via and Date headers as needed.
        Stores the final list in ``self.response_headers`` and returns the
        header block, terminated by a blank line, encoded as latin-1 bytes.
        May set ``self.close_on_finish`` and ``self.chunked_response``.

        Raises AssertionError if ``self.version`` is neither "1.0" nor "1.1".
        """
        version = self.version
        # Figure out whether the connection should be closed.
        connection = self.request.headers.get("CONNECTION", "").lower()
        response_headers = []
        content_length_header = None
        date_header = None
        server_header = None
        connection_close_header = None

        for (headername, headerval) in self.response_headers:
            headername = "-".join([x.capitalize() for x in headername.split("-")])

            if headername == "Content-Length":
                if self.has_body:
                    content_length_header = headerval
                else:
                    continue  # pragma: no cover

            if headername == "Date":
                date_header = headerval

            if headername == "Server":
                server_header = headerval

            if headername == "Connection":
                connection_close_header = headerval.lower()
            # replace with properly capitalized version
            response_headers.append((headername, headerval))

        if (
            content_length_header is None
            and self.content_length is not None
            and self.has_body
        ):
            content_length_header = str(self.content_length)
            response_headers.append(("Content-Length", content_length_header))

        def mark_close():
            # Add "Connection: close" unless a Connection header is already
            # present, and flag the connection for closing.
            if connection_close_header is None:
                response_headers.append(("Connection", "close"))
            self.close_on_finish = True

        if version == "1.0":
            # HTTP/1.0 closes by default. Advertise keep-alive only when the
            # client asked for it, the body length is known, and this task has
            # not already decided to close. For example, ErrorTask sets
            # close_on_finish and adds "Connection: close"; also sending
            # "Connection: Keep-Alive" would contradict it.
            if (
                connection == "keep-alive"
                and content_length_header
                and not self.close_on_finish
            ):
                response_headers.append(("Connection", "Keep-Alive"))
            else:
                mark_close()

        elif version == "1.1":
            if connection == "close":
                mark_close()

            if not content_length_header:
                # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length
                # for any response with a status code of 1xx, 204 or 304.

                if self.has_body:
                    response_headers.append(("Transfer-Encoding", "chunked"))
                    self.chunked_response = True

                if not self.close_on_finish:
                    mark_close()

            # under HTTP 1.1 keep-alive is default, no need to set the header
        else:
            raise AssertionError("neither HTTP/1.0 or HTTP/1.1")

        # Set the Server and Date field, if not yet specified. This is needed
        # if the server is used as a proxy.
        ident = self.channel.server.adj.ident

        if not server_header:
            if ident:
                response_headers.append(("Server", ident))
        else:
            response_headers.append(("Via", ident or "waitress"))

        if not date_header:
            response_headers.append(("Date", build_http_date(self.start_time)))

        self.response_headers = response_headers

        first_line = "HTTP/%s %s" % (self.version, self.status)
        # NB: sorting headers needs to preserve same-named-header order
        # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here;
        # rely on stable sort to keep relative position of same-named headers
        next_lines = [
            "%s: %s" % hv for hv in sorted(self.response_headers, key=lambda x: x[0])
        ]
        lines = [first_line] + next_lines
        res = "%s\r\n\r\n" % "\r\n".join(lines)

        return res.encode("latin-1")

    def remove_content_length_header(self):
        """Drop every Content-Length header (matched case-insensitively)
        from ``self.response_headers``."""
        self.response_headers = [
            (header_name, header_value)
            for header_name, header_value in self.response_headers
            if header_name.lower() != "content-length"
        ]

    def start(self):
        # Wall-clock time on purpose: it feeds the HTTP Date header.
        self.start_time = time.time()

    def finish(self):
        """Make sure the headers were sent, then terminate a chunked body."""
        if not self.wrote_header:
            self.write(b"")
        if self.chunked_response:
            # not self.write, it will chunk it!
            self.channel.write_soon(b"0\r\n\r\n")

    def write(self, data):
        """Queue ``data`` as part of the response body.

        The response header is sent first if that has not happened yet. With
        ``chunked_response`` set, the data is chunk-encoded. With a declared
        ``content_length``, anything beyond that length is dropped and a
        warning is logged once. For responses that may not carry a body, the
        data is discarded (with a warning logged once) but still counted in
        ``content_bytes_written``.

        Raises RuntimeError if the status and headers have not been set.
        """
        if not self.complete:
            raise RuntimeError("start_response was not called before body written")
        channel = self.channel
        if not self.wrote_header:
            rh = self.build_response_header()
            channel.write_soon(rh)
            self.wrote_header = True

        if data and self.has_body:
            towrite = data
            cl = self.content_length
            if self.chunked_response:
                # use chunked encoding response: upper-case hex size line,
                # the data itself, then CRLF
                towrite = b"%X\r\n" % len(data) + data + b"\r\n"
            elif cl is not None:
                towrite = data[: cl - self.content_bytes_written]
                self.content_bytes_written += len(towrite)
                if towrite != data and not self.logged_write_excess:
                    self.logger.warning(
                        "application-written content exceeded the number of "
                        "bytes specified by Content-Length header (%s)" % cl
                    )
                    self.logged_write_excess = True
            if towrite:
                channel.write_soon(towrite)
        elif data:
            # Cheat, and tell the application we have written all of the bytes,
            # even though the response shouldn't have a body and we are
            # ignoring it entirely.
            self.content_bytes_written += len(data)

            if not self.logged_write_no_body:
                self.logger.warning(
                    "application-written content was ignored due to HTTP "
                    "response that may not contain a message-body: (%s)" % self.status
                )
                self.logged_write_no_body = True
|
||
|
|
||
|
|
||
|
class ErrorTask(Task):
    """Task that renders ``self.request.error`` as a complete response and
    then closes the connection."""

    # There is no start_response step for an error: status and headers are
    # available immediately, and write() refuses to run until this is True.
    complete = True

    def execute(self):
        status, extra_headers, body = self.request.error.to_response()
        self.status = status
        self.response_headers.extend(extra_headers)
        # Tell the client explicitly that the connection is going away,
        # because close_on_finish is set just below and the server will close
        # it as soon as this response has gone out.
        self.response_headers.append(("Connection", "close"))
        self.close_on_finish = True
        # The length is taken from the text body; latin-1 encodes exactly one
        # byte per character, so it matches the encoded payload.
        self.content_length = len(body)
        payload = body.encode("latin-1")
        self.write(payload)
|
||
|
|
||
|
|
||
|
class WSGITask(Task):
    """A WSGI task produces a response from a WSGI application."""

    environ = None  # cached result of get_environment()

    def execute(self):
        """Call the WSGI application for this request and send its response."""
        environ = self.get_environment()

        def start_response(status, headers, exc_info=None):
            # The start_response callable handed to the application (PEP 3333).
            if self.complete and not exc_info:
                raise AssertionError(
                    "start_response called a second time without providing exc_info."
                )
            if exc_info:
                try:
                    if self.wrote_header:
                        # Too late to replace the headers. Higher levels will
                        # catch and handle the raised exception:
                        # 1. "service" method in task.py
                        # 2. "service" method in channel.py
                        # 3. "handler_thread" method in task.py
                        raise exc_info[1]
                    # As per WSGI spec existing headers must be cleared
                    self.response_headers = []
                finally:
                    # Drop the reference to the exception info.
                    exc_info = None

            self.complete = True

            if status.__class__ is not str:
                raise AssertionError("status %s is not a string" % status)
            if "\n" in status or "\r" in status:
                raise ValueError(
                    "carriage return/line feed character present in status"
                )

            self.status = status

            # Check every header before any of them is added to the response.
            for name, value in headers:
                if name.__class__ is not str:
                    raise AssertionError(
                        "Header name %r is not a string in %r" % (name, (name, value))
                    )
                if value.__class__ is not str:
                    raise AssertionError(
                        "Header value %r is not a string in %r" % (value, (name, value))
                    )
                if "\n" in value or "\r" in value:
                    raise ValueError(
                        "carriage return/line feed character present in header value"
                    )
                if "\n" in name or "\r" in name:
                    raise ValueError(
                        "carriage return/line feed character present in header name"
                    )

                lowered = name.lower()
                if lowered == "content-length":
                    self.content_length = int(value)
                elif lowered in hop_by_hop:
                    raise AssertionError(
                        '%s is a "hop-by-hop" header; it cannot be used by '
                        "a WSGI application (see PEP 3333)" % name
                    )

            self.response_headers.extend(headers)

            # Give the application the legacy write() callable.
            return self.write

        # Call the application to handle the request and write a response
        app_iter = self.channel.server.application(environ, start_response)

        # Becomes False once the channel has taken over app_iter.
        close_app_iter = True
        try:
            if app_iter.__class__ is ReadOnlyFileBasedBuffer:
                declared = self.content_length
                size = app_iter.prepare(declared)
                if size:
                    if declared != size:
                        # When prepare() reports a size different from the
                        # declared length, drop any Content-Length header the
                        # application set and use the reported size instead.
                        if declared is not None:
                            self.remove_content_length_header()
                        self.content_length = size
                    self.write(b"")  # generate headers
                    # If the write_soon below succeeds, the channel takes over
                    # closing the underlying file (via its _flush_some or
                    # handle_close), so we deliberately skip close() in the
                    # finally block.
                    self.channel.write_soon(app_iter)
                    close_app_iter = False
                    return

            first_chunk_len = None
            for chunk in app_iter:
                if first_chunk_len is None:
                    first_chunk_len = len(chunk)
                    # Infer Content-Length when the application gave none and
                    # the iterable holds exactly one chunk. start_response may
                    # only have been called during this first iteration (which
                    # the PEP allows), so self.content_length is checked here
                    # rather than earlier.
                    if self.content_length is None:
                        sized = hasattr(app_iter, "__len__")
                        if sized and len(app_iter) == 1:
                            self.content_length = first_chunk_len
                # transmit headers only after first iteration of the iterable
                # that returns a non-empty bytestring (PEP 3333)
                if chunk:
                    self.write(chunk)

            expected = self.content_length
            if expected is not None and self.content_bytes_written != expected:
                # Close the connection so the client does not sit waiting for
                # bytes that will never arrive.
                self.close_on_finish = True
                if self.request.command != "HEAD":
                    self.logger.warning(
                        "application returned too few bytes (%s) "
                        "for specified Content-Length (%s) via app_iter"
                        % (self.content_bytes_written, expected),
                    )
        finally:
            if close_app_iter and hasattr(app_iter, "close"):
                app_iter.close()

    def get_environment(self):
        """Returns a WSGI environment (built once, then cached on the task)."""
        if self.environ is not None:
            # Return the cached copy.
            return self.environ

        request = self.request
        path = request.path
        channel = self.channel
        server = channel.server
        url_prefix = server.adj.url_prefix

        if path.startswith("/"):
            # collapse any run of leading slashes into a single slash
            path = "/" + path.lstrip("/")

        if url_prefix:
            # NB: url_prefix is guaranteed by the configuration machinery to
            # be either the empty string or a string that starts with a single
            # slash and ends without any slashes
            if path == url_prefix:
                # The request targets the prefix itself: SCRIPT_NAME is the
                # prefix and PATH_INFO is empty.
                path = ""
            elif path.startswith(url_prefix + "/"):
                # The request lies below the prefix: SCRIPT_NAME is the prefix
                # and PATH_INFO is the remainder, starting at the slash that
                # follows it.
                path = path[len(url_prefix) :]

        environ = {
            "REMOTE_ADDR": channel.addr[0],
            # No reverse DNS lookup is done; the spec allows REMOTE_HOST to
            # simply repeat REMOTE_ADDR.
            "REMOTE_HOST": channel.addr[0],
            # try and set the REMOTE_PORT to something useful, but maybe None
            "REMOTE_PORT": str(channel.addr[1]),
            "REQUEST_METHOD": request.command.upper(),
            "SERVER_PORT": str(server.effective_port),
            "SERVER_NAME": server.server_name,
            "SERVER_SOFTWARE": server.adj.ident,
            "SERVER_PROTOCOL": "HTTP/%s" % self.version,
            "SCRIPT_NAME": url_prefix,
            "PATH_INFO": path,
            "QUERY_STRING": request.query,
            "wsgi.url_scheme": request.url_scheme,
            # the following environment variables are required by the WSGI spec
            "wsgi.version": (1, 0),
            # apps should use the logging module
            "wsgi.errors": sys.stderr,
            "wsgi.multithread": True,
            "wsgi.multiprocess": False,
            "wsgi.run_once": False,
            "wsgi.input": request.get_body_stream(),
            "wsgi.file_wrapper": ReadOnlyFileBasedBuffer,
            "wsgi.input_terminated": True,  # wsgi.input is EOF terminated
        }

        # Copy request headers in as HTTP_<NAME> (or under their own name for
        # rename_headers entries) without overwriting keys already set above.
        for key, value in dict(request.headers).items():
            stripped = value.strip()
            environ.setdefault(rename_headers.get(key, "HTTP_" + key), stripped)

        # Insert a callable into the environment that allows the application to
        # check if the client disconnected. Only works with
        # channel_request_lookahead larger than 0.
        environ["waitress.client_disconnected"] = self.channel.check_client_disconnected

        # cache the environ for this request
        self.environ = environ
        return environ
|