diff --git a/bazarr/api.py b/bazarr/api.py index 40145f0a0..dc8e57fec 100644 --- a/bazarr/api.py +++ b/bazarr/api.py @@ -20,7 +20,7 @@ from logger import empty_log from init import * import logging from database import database, get_exclusion_clause, get_profiles_list, get_desired_languages, get_profile_id_name, \ - get_audio_profile_languages, update_profile_id_list + get_audio_profile_languages, update_profile_id_list, convert_list_to_clause from helper import path_mappings from get_languages import language_from_alpha2, language_from_alpha3, alpha2_from_alpha3, alpha3_from_alpha2 from get_subtitle import download_subtitle, series_download_subtitles, manual_search, manual_download_subtitle, \ @@ -31,7 +31,7 @@ from list_subtitles import store_subtitles, store_subtitles_movie, series_scan_s list_missing_subtitles, list_missing_subtitles_movies from utils import history_log, history_log_movie, blacklist_log, blacklist_delete, blacklist_delete_all, \ blacklist_log_movie, blacklist_delete_movie, blacklist_delete_all_movie, get_sonarr_version, get_radarr_version, \ - delete_subtitles, subtitles_apply_mods, translate_subtitles_file, check_credentials + delete_subtitles, subtitles_apply_mods, translate_subtitles_file, check_credentials, get_health_issues from get_providers import get_providers, get_providers_auth, list_throttled_providers, reset_throttled_providers, \ get_throttled_providers, set_throttled_providers from event_handler import event_stream @@ -125,8 +125,6 @@ def postprocessSeries(item): if 'path' in item: item['path'] = path_mappings.path_replace(item['path']) - # Confirm if path exist - item['exist'] = os.path.isdir(item['path']) # map poster and fanart to server proxy if 'poster' in item: @@ -138,9 +136,7 @@ def postprocessSeries(item): item['fanart'] = f"{base_url}/images/series{fanart}" -def postprocessEpisode(item, desired=None): - if desired is None: - desired = [] +def postprocessEpisode(item): postprocess(item) if 'audio_language' in item and item['audio_language'] is not None: item['audio_language'] = get_audio_profile_languages(episode_id=item['sonarrEpisodeId']) @@ -168,10 +164,6 @@ def postprocessEpisode(item, desired=None): item.update({"subtitles": subtitles}) - if settings.general.getboolean('embedded_subs_show_desired'): - item['subtitles'] = [x for x in item['subtitles'] if - x['code2'] in desired or x['path']] - # Parse missing subtitles if 'missing_subtitles' in item: if item['missing_subtitles'] is None: @@ -195,11 +187,9 @@ def postprocessEpisode(item, desired=None): item["sceneName"] = item["scene_name"] del item["scene_name"] - if 'path' in item: - if item['path']: - # Provide mapped path - item['path'] = path_mappings.path_replace(item['path']) - item['exist'] = os.path.isfile(item['path']) + if 'path' in item and item['path']: + # Provide mapped path + item['path'] = path_mappings.path_replace(item['path']) # TODO: Move @@ -270,8 +260,6 @@ def postprocessMovie(item): if 'path' in item: if item['path']: item['path'] = path_mappings.path_replace_movie(item['path']) - # Confirm if path exist - item['exist'] = os.path.isfile(item['path']) if 'subtitles_path' in item: # Provide mapped subtitles path @@ -319,7 +307,7 @@ class System(Resource): return '', 204 -class BadgesSeries(Resource): +class Badges(Resource): @authenticate def get(self): missing_episodes = database.execute("SELECT table_shows.tags, table_episodes.monitored, table_shows.seriesType " @@ -334,10 +322,13 @@ class BadgesSeries(Resource): throttled_providers = 
len(eval(str(get_throttled_providers()))) + health_issues = len(get_health_issues()) + result = { "episodes": missing_episodes, "movies": missing_movies, - "providers": throttled_providers + "providers": throttled_providers, + "status": health_issues } return jsonify(result) @@ -421,7 +412,8 @@ class SystemSettings(Resource): if len(enabled_languages) != 0: database.execute("UPDATE table_settings_languages SET enabled=0") for code in enabled_languages: - database.execute("UPDATE table_settings_languages SET enabled=1 WHERE code2=?", (code,)) + database.execute("UPDATE table_settings_languages SET enabled=1 WHERE code2=?",(code,)) + event_stream("languages") languages_profiles = request.form.get('languages-profiles') if languages_profiles: @@ -451,6 +443,7 @@ class SystemSettings(Resource): database.execute('DELETE FROM table_languages_profiles WHERE profileId = ?', (profileId,)) update_profile_id_list() + event_stream("languages") if settings.general.getboolean('use_sonarr'): scheduler.add_job(list_missing_subtitles, kwargs={'send_event': False}) @@ -465,6 +458,7 @@ class SystemSettings(Resource): (item['enabled'], item['url'], item['name'])) save_settings(zip(request.form.keys(), request.form.listvalues())) + event_stream("settings") return '', 204 @@ -533,6 +527,12 @@ class SystemStatus(Resource): return jsonify(data=system_status) +class SystemHealth(Resource): + @authenticate + def get(self): + return jsonify(data=get_health_issues()) + + class SystemReleases(Resource): @authenticate def get(self): @@ -577,9 +577,8 @@ class Series(Resource): count = database.execute("SELECT COUNT(*) as count FROM table_shows", only_one=True)['count'] if len(seriesId) != 0: - seriesIdList = ','.join(seriesId) result = database.execute( - f"SELECT * FROM table_shows WHERE sonarrSeriesId in ({seriesIdList}) ORDER BY sortTitle ASC") + f"SELECT * FROM table_shows WHERE sonarrSeriesId in {convert_list_to_clause(seriesId)} ORDER BY sortTitle ASC") else: result = database.execute("SELECT * FROM table_shows ORDER BY sortTitle ASC LIMIT ? OFFSET ?" , (length, start)) @@ -627,9 +626,10 @@ class Series(Resource): database.execute("UPDATE table_shows SET profileId=? WHERE sonarrSeriesId=?", (profileId, seriesId)) - list_missing_subtitles(no=seriesId) + list_missing_subtitles(no=seriesId, send_event=False) - # event_stream(type='series', action='update', series=seriesId) + event_stream(type='series', payload=seriesId) + event_stream(type='badges') return '', 204 @@ -653,23 +653,20 @@ class Series(Resource): class Episodes(Resource): @authenticate def get(self): - seriesId = request.args.get('seriesid') - episodeId = request.args.get('episodeid') - if episodeId: - result = database.execute("SELECT * FROM table_episodes WHERE sonarrEpisodeId=?", (episodeId,)) - elif seriesId: - result = database.execute("SELECT * FROM table_episodes WHERE sonarrSeriesId=? 
ORDER BY season DESC, " - "episode DESC", (seriesId,)) - else: - return "Series ID not provided", 400 + seriesId = request.args.getlist('seriesid[]') + episodeId = request.args.getlist('episodeid[]') - profileId = database.execute("SELECT profileId FROM table_shows WHERE sonarrSeriesId = ?", (seriesId,), - only_one=True)['profileId'] - desired_languages = str(get_desired_languages(profileId)) - desired = ast.literal_eval(desired_languages) + if len(episodeId) > 0: + result = database.execute(f"SELECT * FROM table_episodes WHERE sonarrEpisodeId in {convert_list_to_clause(episodeId)}") + elif len(seriesId) > 0: + result = database.execute("SELECT * FROM table_episodes " + f"WHERE sonarrSeriesId in {convert_list_to_clause(seriesId)} ORDER BY season DESC, " + "episode DESC") + else: + return "Series or Episode ID not provided", 400 for item in result: - postprocessEpisode(item, desired) + postprocessEpisode(item) return jsonify(data=result) @@ -727,7 +724,7 @@ class EpisodesSubtitles(Resource): send_notifications(sonarrSeriesId, sonarrEpisodeId, message) store_subtitles(path, episodePath) else: - event_stream(type='episode', action='update', series=int(sonarrSeriesId), episode=int(sonarrEpisodeId)) + event_stream(type='episode', payload=sonarrEpisodeId) except OSError: pass @@ -820,14 +817,12 @@ class Movies(Resource): def get(self): start = request.args.get('start') or 0 length = request.args.get('length') or -1 - id = request.args.getlist('radarrid[]') + radarrId = request.args.getlist('radarrid[]') count = database.execute("SELECT COUNT(*) as count FROM table_movies", only_one=True)['count'] - if len(id) != 0: - movieIdList = ','.join(id) - result = database.execute( - f"SELECT * FROM table_movies WHERE radarrId in ({movieIdList}) ORDER BY sortTitle ASC") + if len(radarrId) != 0: + result = database.execute(f"SELECT * FROM table_movies WHERE radarrId in {convert_list_to_clause(radarrId)} ORDER BY sortTitle ASC") else: result = database.execute("SELECT * FROM table_movies ORDER BY sortTitle ASC LIMIT ? OFFSET ?", (length, start)) @@ -857,7 +852,8 @@ class Movies(Resource): list_missing_subtitles_movies(no=radarrId) - # event_stream(type='movies', action='update', movie=radarrId) + event_stream(type='movies', payload=radarrId) + event_stream(type='badges') return '', 204 @@ -933,7 +929,7 @@ class MoviesSubtitles(Resource): send_notifications_movie(radarrId, message) store_subtitles_movie(path, moviePath) else: - event_stream(type='movie', action='update', movie=int(radarrId)) + event_stream(type='movie', payload=radarrId) except OSError: pass @@ -1442,17 +1438,30 @@ class HistoryStats(Resource): class EpisodesWanted(Resource): @authenticate def get(self): - start = request.args.get('start') or 0 - length = request.args.get('length') or -1 - data = database.execute("SELECT table_shows.title as seriesTitle, table_episodes.monitored, " - "table_episodes.season || 'x' || table_episodes.episode as episode_number, " - "table_episodes.title as episodeTitle, table_episodes.missing_subtitles, " - "table_episodes.sonarrSeriesId, " - "table_episodes.sonarrEpisodeId, table_episodes.scene_name as sceneName, table_shows.tags, " - "table_episodes.failedAttempts, table_shows.seriesType FROM table_episodes INNER JOIN " - "table_shows on table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId WHERE " - "table_episodes.missing_subtitles != '[]'" + get_exclusion_clause('series') + - " ORDER BY table_episodes._rowid_ DESC LIMIT ? 
OFFSET ?", (length, start)) + episodeid = request.args.getlist('episodeid[]') + if len(episodeid) > 0: + data = database.execute("SELECT table_shows.title as seriesTitle, table_episodes.monitored, " + "table_episodes.season || 'x' || table_episodes.episode as episode_number, " + "table_episodes.title as episodeTitle, table_episodes.missing_subtitles, " + "table_episodes.sonarrSeriesId, " + "table_episodes.sonarrEpisodeId, table_episodes.scene_name as sceneName, table_shows.tags, " + "table_episodes.failedAttempts, table_shows.seriesType FROM table_episodes INNER JOIN " + "table_shows on table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId WHERE " + "table_episodes.missing_subtitles != '[]'" + get_exclusion_clause('series') + + f" AND sonarrEpisodeId in {convert_list_to_clause(episodeid)}") + pass + else: + start = request.args.get('start') or 0 + length = request.args.get('length') or -1 + data = database.execute("SELECT table_shows.title as seriesTitle, table_episodes.monitored, " + "table_episodes.season || 'x' || table_episodes.episode as episode_number, " + "table_episodes.title as episodeTitle, table_episodes.missing_subtitles, " + "table_episodes.sonarrSeriesId, " + "table_episodes.sonarrEpisodeId, table_episodes.scene_name as sceneName, table_shows.tags, " + "table_episodes.failedAttempts, table_shows.seriesType FROM table_episodes INNER JOIN " + "table_shows on table_shows.sonarrSeriesId = table_episodes.sonarrSeriesId WHERE " + "table_episodes.missing_subtitles != '[]'" + get_exclusion_clause('series') + + " ORDER BY table_episodes._rowid_ DESC LIMIT ? OFFSET ?", (length, start)) for item in data: postprocessEpisode(item) @@ -1469,20 +1478,28 @@ class EpisodesWanted(Resource): class MoviesWanted(Resource): @authenticate def get(self): - start = request.args.get('start') or 0 - length = request.args.get('length') or -1 - data = database.execute("SELECT title, missing_subtitles, radarrId, sceneName, " - "failedAttempts, tags, monitored FROM table_movies WHERE missing_subtitles != '[]'" + - get_exclusion_clause('movie') + - " ORDER BY _rowid_ DESC LIMIT ? OFFSET ?", (length, start)) + radarrid = request.args.getlist("radarrid[]") + if len(radarrid) > 0: + result = database.execute("SELECT title, missing_subtitles, radarrId, sceneName, " + "failedAttempts, tags, monitored FROM table_movies WHERE missing_subtitles != '[]'" + + get_exclusion_clause('movie') + + f" AND radarrId in {convert_list_to_clause(radarrid)}") + pass + else: + start = request.args.get('start') or 0 + length = request.args.get('length') or -1 + result = database.execute("SELECT title, missing_subtitles, radarrId, sceneName, " + "failedAttempts, tags, monitored FROM table_movies WHERE missing_subtitles != '[]'" + + get_exclusion_clause('movie') + + " ORDER BY _rowid_ DESC LIMIT ? 
OFFSET ?", (length, start)) - for item in data: + for item in result: postprocessMovie(item) count = database.execute("SELECT COUNT(*) as count FROM table_movies WHERE missing_subtitles != '[]'" + get_exclusion_clause('movie'), only_one=True)['count'] - return jsonify(data=data, total=count) + return jsonify(data=result, total=count) # GET: get blacklist @@ -1540,7 +1557,7 @@ class EpisodesBlacklist(Resource): sonarr_series_id=sonarr_series_id, sonarr_episode_id=sonarr_episode_id) episode_download_subtitles(sonarr_episode_id) - event_stream(type='episodeHistory') + event_stream(type='episode-history') return '', 200 @authenticate @@ -1606,7 +1623,7 @@ class MoviesBlacklist(Resource): subtitles_path=subtitles_path, radarr_id=radarr_id) movies_download_subtitles(radarr_id) - event_stream(type='movieHistory') + event_stream(type='movie-history') return '', 200 @authenticate @@ -1746,7 +1763,7 @@ class BrowseRadarrFS(Resource): return jsonify(data) -api.add_resource(BadgesSeries, '/badges') +api.add_resource(Badges, '/badges') api.add_resource(Providers, '/providers') api.add_resource(ProviderMovies, '/providers/movies') @@ -1758,6 +1775,7 @@ api.add_resource(SystemAccount, '/system/account') api.add_resource(SystemTasks, '/system/tasks') api.add_resource(SystemLogs, '/system/logs') api.add_resource(SystemStatus, '/system/status') +api.add_resource(SystemHealth, '/system/health') api.add_resource(SystemReleases, '/system/releases') api.add_resource(SystemSettings, '/system/settings') api.add_resource(Languages, '/system/languages') diff --git a/bazarr/app.py b/bazarr/app.py index 101053ea9..dcd240636 100644 --- a/bazarr/app.py +++ b/bazarr/app.py @@ -38,7 +38,7 @@ def create_app(): toolbar = DebugToolbarExtension(app) - socketio.init_app(app, path=base_url.rstrip('/')+'/socket.io', cors_allowed_origins='*', async_mode='threading') + socketio.init_app(app, path=base_url.rstrip('/')+'/api/socket.io', cors_allowed_origins='*', async_mode='gevent') return app diff --git a/bazarr/config.py b/bazarr/config.py index 462491710..6d29eb58b 100644 --- a/bazarr/config.py +++ b/bazarr/config.py @@ -397,8 +397,7 @@ def save_settings(settings_items): if exclusion_updated: from event_handler import event_stream - event_stream(type='badges_series') - event_stream(type='badges_movies') + event_stream(type='badges') def url_sonarr(): diff --git a/bazarr/database.py b/bazarr/database.py index 3e0e2c19d..3513194b0 100644 --- a/bazarr/database.py +++ b/bazarr/database.py @@ -160,6 +160,12 @@ def db_upgrade(): database.execute("CREATE TABLE IF NOT EXISTS table_blacklist_movie (radarr_id integer, timestamp integer, " "provider text, subs_id text, language text)") + # Create rootfolder tables + database.execute("CREATE TABLE IF NOT EXISTS table_shows_rootfolder (id integer, path text, accessible integer, " + "error text)") + database.execute("CREATE TABLE IF NOT EXISTS table_movies_rootfolder (id integer, path text, accessible integer, " + "error text)") + # Create languages profiles table and populate it lang_table_content = database.execute("SELECT * FROM table_languages_profiles") if isinstance(lang_table_content, list): @@ -483,3 +489,9 @@ def get_audio_profile_languages(series_id=None, episode_id=None, movie_id=None): ) return audio_languages + +def convert_list_to_clause(arr: list): + if isinstance(arr, list): + return f"({','.join(str(x) for x in arr)})" + else: + return "" diff --git a/bazarr/event_handler.py b/bazarr/event_handler.py index e824ea492..e03e22ede 100644 --- a/bazarr/event_handler.py +++ 
b/bazarr/event_handler.py @@ -1,23 +1,19 @@ # coding=utf-8 -import json from app import socketio -def event_stream(type=None, action=None, series=None, episode=None, movie=None, task=None): +def event_stream(type, action="update", payload=None): """ :param type: The type of element. :type type: str - :param action: The action type of element from insert, update, delete. + :param action: The action to apply to the element, either update or delete. :type action: str - :param series: The series id. - :type series: str - :param episode: The episode id. - :type episode: str - :param movie: The movie id. - :type movie: str - :param task: The task id. - :type task: str + :param payload: The payload to send; can be any JSON-serializable value. """ - socketio.emit('event', json.dumps({"type": type, "action": action, "series": series, "episode": episode, - "movie": movie, "task": task})) + + # coerce numeric ids passed as strings to int; leave any other payload untouched + try: + payload = int(payload) + except (ValueError, TypeError): + pass + socketio.emit("data", {"type": type, "action": action, "payload": payload}) diff --git a/bazarr/get_episodes.py b/bazarr/get_episodes.py index 582d0e32d..60d72c262 100644 --- a/bazarr/get_episodes.py +++ b/bazarr/get_episodes.py @@ -143,8 +143,7 @@ def sync_episodes(): episode_to_delete = database.execute("SELECT sonarrSeriesId, sonarrEpisodeId FROM table_episodes WHERE " "sonarrEpisodeId=?", (removed_episode,), only_one=True) database.execute("DELETE FROM table_episodes WHERE sonarrEpisodeId=?", (removed_episode,)) - event_stream(type='episode', action='delete', series=episode_to_delete['sonarrSeriesId'], - episode=episode_to_delete['sonarrEpisodeId']) + event_stream(type='episode', action='delete', payload=episode_to_delete['sonarrEpisodeId']) # Update existing episodes in DB episode_in_db_list = [] @@ -175,8 +174,7 @@ def sync_episodes(): altered_episodes.append([added_episode['sonarrEpisodeId'], added_episode['path'], added_episode['monitored']]) - event_stream(type='episode', action='insert', series=added_episode['sonarrSeriesId'], - episode=added_episode['sonarrEpisodeId']) + event_stream(type='episode', payload=added_episode['sonarrEpisodeId']) else: logging.debug('BAZARR unable to insert this episode into the database:{}'.format(path_mappings.path_replace(added_episode['path']))) diff --git a/bazarr/get_movies.py b/bazarr/get_movies.py index 8441ac7e4..55b58fdbc 100644 --- a/bazarr/get_movies.py +++ b/bazarr/get_movies.py @@ -8,6 +8,7 @@ from config import settings, url_radarr from helper import path_mappings from utils import get_radarr_version from list_subtitles import store_subtitles_movie, movies_full_scan_subtitles +from get_rootfolder import check_radarr_rootfolder from get_subtitle import movies_download_subtitles from database import database, dict_converter, get_exclusion_clause @@ -21,6 +22,7 @@ def update_all_movies(): def update_movies(): + check_radarr_rootfolder() logging.debug('BAZARR Starting movie sync from Radarr.') apikey_radarr = settings.radarr.apikey diff --git a/bazarr/get_providers.py b/bazarr/get_providers.py index 6be3e4588..e0fd0293b 100644 --- a/bazarr/get_providers.py +++ b/bazarr/get_providers.py @@ -278,7 +278,7 @@ def update_throttled_provider(): del tp[provider] set_throttled_providers(str(tp)) - event_stream(type='badges_providers') + event_stream(type='badges') def list_throttled_providers(): diff --git a/bazarr/get_rootfolder.py b/bazarr/get_rootfolder.py new file mode 100644 index 000000000..625ab523e --- /dev/null +++ b/bazarr/get_rootfolder.py @@ -0,0 +1,116 @@ +# coding=utf-8 + +import os +import requests +import logging 
+ +from config import settings, url_sonarr, url_radarr +from helper import path_mappings +from database import database + +headers = {"User-Agent": os.environ["SZ_USER_AGENT"]} + + +def get_sonarr_rootfolder(): + apikey_sonarr = settings.sonarr.apikey + sonarr_rootfolder = [] + + # Get root folder data from Sonarr + url_sonarr_api_rootfolder = url_sonarr() + "/api/rootfolder?apikey=" + apikey_sonarr + + try: + rootfolder = requests.get(url_sonarr_api_rootfolder, timeout=60, verify=False, headers=headers) + except requests.exceptions.ConnectionError: + logging.exception("BAZARR Error trying to get rootfolder from Sonarr. Connection Error.") + return [] + except requests.exceptions.Timeout: + logging.exception("BAZARR Error trying to get rootfolder from Sonarr. Timeout Error.") + return [] + except requests.exceptions.RequestException: + logging.exception("BAZARR Error trying to get rootfolder from Sonarr.") + return [] + else: + for folder in rootfolder.json(): + sonarr_rootfolder.append({'id': folder['id'], 'path': folder['path']}) + # reconcile the root folders reported by Sonarr with the cached table_shows_rootfolder rows + db_rootfolder = database.execute('SELECT id, path FROM table_shows_rootfolder') + rootfolder_to_remove = [x for x in db_rootfolder if not + next((item for item in sonarr_rootfolder if item['id'] == x['id']), False)] + rootfolder_to_update = [x for x in sonarr_rootfolder if + next((item for item in db_rootfolder if item['id'] == x['id']), False)] + rootfolder_to_insert = [x for x in sonarr_rootfolder if not + next((item for item in db_rootfolder if item['id'] == x['id']), False)] + + for item in rootfolder_to_remove: + database.execute('DELETE FROM table_shows_rootfolder WHERE id = ?', (item['id'],)) + for item in rootfolder_to_update: + database.execute('UPDATE table_shows_rootfolder SET path=? WHERE id = ?', (item['path'], item['id'])) + for item in rootfolder_to_insert: + database.execute('INSERT INTO table_shows_rootfolder (id, path) VALUES (?, ?)', (item['id'], item['path'])) + + +def check_sonarr_rootfolder(): + get_sonarr_rootfolder() + rootfolder = database.execute('SELECT id, path FROM table_shows_rootfolder') + for item in rootfolder: + if not os.path.isdir(path_mappings.path_replace(item['path'])): + database.execute("UPDATE table_shows_rootfolder SET accessible = 0, error = 'This Sonarr root directory " + "does not seem to be accessible by Bazarr. Please check path mapping.' WHERE id = ?", + (item['id'],)) + elif not os.access(path_mappings.path_replace(item['path']), os.W_OK): + database.execute("UPDATE table_shows_rootfolder SET accessible = 0, error = 'Bazarr cannot write to " + "this directory' WHERE id = ?", (item['id'],)) + else: + database.execute("UPDATE table_shows_rootfolder SET accessible = 1, error = '' WHERE id = ?", (item['id'],)) + + +def get_radarr_rootfolder(): + apikey_radarr = settings.radarr.apikey + radarr_rootfolder = [] + + # Get root folder data from Radarr + url_radarr_api_rootfolder = url_radarr() + "/api/rootfolder?apikey=" + apikey_radarr + + try: + rootfolder = requests.get(url_radarr_api_rootfolder, timeout=60, verify=False, headers=headers) + except requests.exceptions.ConnectionError: + logging.exception("BAZARR Error trying to get rootfolder from Radarr. Connection Error.") + return [] + except requests.exceptions.Timeout: + logging.exception("BAZARR Error trying to get rootfolder from Radarr. 
Timeout Error.") + return [] + except requests.exceptions.RequestException: + logging.exception("BAZARR Error trying to get rootfolder from Radarr.") + return [] + else: + for folder in rootfolder.json(): + radarr_rootfolder.append({'id': folder['id'], 'path': folder['path']}) + db_rootfolder = database.execute('SELECT id, path FROM table_movies_rootfolder') + rootfolder_to_remove = [x for x in db_rootfolder if not + next((item for item in radarr_rootfolder if item['id'] == x['id']), False)] + rootfolder_to_update = [x for x in radarr_rootfolder if + next((item for item in db_rootfolder if item['id'] == x['id']), False)] + rootfolder_to_insert = [x for x in radarr_rootfolder if not + next((item for item in db_rootfolder if item['id'] == x['id']), False)] + + for item in rootfolder_to_remove: + database.execute('DELETE FROM table_movies_rootfolder WHERE id = ?', (item['id'],)) + for item in rootfolder_to_update: + database.execute('UPDATE table_movies_rootfolder SET path=? WHERE id = ?', (item['path'], item['id'])) + for item in rootfolder_to_insert: + database.execute('INSERT INTO table_movies_rootfolder (id, path) VALUES (?, ?)', (item['id'], item['path'])) + + +def check_radarr_rootfolder(): + get_radarr_rootfolder() + rootfolder = database.execute('SELECT id, path FROM table_movies_rootfolder') + for item in rootfolder: + if not os.path.isdir(path_mappings.path_replace_movie(item['path'])): + database.execute("UPDATE table_movies_rootfolder SET accessible = 0, error = 'This Radarr root directory " + "does not seems to be accessible by Bazarr. Please check path mapping.' WHERE id = ?", + (item['id'],)) + elif not os.access(path_mappings.path_replace_movie(item['path']), os.W_OK): + database.execute("UPDATE table_movies_rootfolder SET accessible = 0, error = 'Bazarr cannot write to " + "this directory' WHERE id = ?", (item['id'],)) + else: + database.execute("UPDATE table_movies_rootfolder SET accessible = 1, error = '' WHERE id = ?", + (item['id'],)) diff --git a/bazarr/get_series.py b/bazarr/get_series.py index 1160b1843..0cbef5a8d 100644 --- a/bazarr/get_series.py +++ b/bazarr/get_series.py @@ -6,6 +6,7 @@ import logging from config import settings, url_sonarr from list_subtitles import list_missing_subtitles +from get_rootfolder import check_sonarr_rootfolder from database import database, dict_converter from utils import get_sonarr_version from helper import path_mappings @@ -15,6 +16,7 @@ headers = {"User-Agent": os.environ["SZ_USER_AGENT"]} def update_series(): + check_sonarr_rootfolder() apikey_sonarr = settings.sonarr.apikey if apikey_sonarr is None: return @@ -125,7 +127,7 @@ def update_series(): for series in removed_series: database.execute("DELETE FROM table_shows WHERE sonarrSeriesId=?",(series,)) - event_stream(type='series', action='delete', series=series) + event_stream(type='series', action='delete', payload=series) # Update existing series in DB series_in_db_list = [] @@ -141,7 +143,7 @@ def update_series(): query = dict_converter.convert(updated_series) database.execute('''UPDATE table_shows SET ''' + query.keys_update + ''' WHERE sonarrSeriesId = ?''', query.values + (updated_series['sonarrSeriesId'],)) - event_stream(type='series', action='update', series=updated_series['sonarrSeriesId']) + event_stream(type='series', payload=updated_series['sonarrSeriesId']) # Insert new series in DB for added_series in series_to_add: @@ -155,7 +157,7 @@ def update_series(): logging.debug('BAZARR unable to insert this series into the database:', 
path_mappings.path_replace(added_series['path'])) - event_stream(type='series', action='insert', series=added_series['sonarrSeriesId']) + event_stream(type='series', payload=added_series['sonarrSeriesId']) logging.debug('BAZARR All series synced from Sonarr into database.') diff --git a/bazarr/get_subtitle.py b/bazarr/get_subtitle.py index d416c239d..bceb1d432 100644 --- a/bazarr/get_subtitle.py +++ b/bazarr/get_subtitle.py @@ -33,6 +33,7 @@ from subsyncer import subsync from guessit import guessit from database import database, dict_mapper, get_exclusion_clause, get_profiles_list, get_audio_profile_languages, \ get_desired_languages +from event_handler import event_stream from embedded_subs_reader import parse_video_metadata from analytics import track_event @@ -982,6 +983,7 @@ def wanted_download_subtitles(path, l, count_episodes): store_subtitles(episode['path'], path_mappings.path_replace(episode['path'])) history_log(1, episode['sonarrSeriesId'], episode['sonarrEpisodeId'], message, path, language_code, provider, score, subs_id, subs_path) + event_stream(type='episode-wanted', action='delete', payload=episode['sonarrEpisodeId']) send_notifications(episode['sonarrSeriesId'], episode['sonarrEpisodeId'], message) else: logging.debug( @@ -1050,6 +1052,7 @@ def wanted_download_subtitles_movie(path, l, count_movies): store_subtitles_movie(movie['path'], path_mappings.path_replace_movie(movie['path'])) history_log_movie(1, movie['radarrId'], message, path, language_code, provider, score, subs_id, subs_path) + event_stream(type='movie-wanted', action='delete', payload=movie['radarrId']) send_notifications_movie(movie['radarrId'], message) else: logging.info( diff --git a/bazarr/init.py b/bazarr/init.py index 1e9d4303b..8f390c651 100644 --- a/bazarr/init.py +++ b/bazarr/init.py @@ -45,7 +45,7 @@ import logging # deploy requirements.txt if not args.no_update: try: - import lxml, numpy, webrtcvad + import lxml, numpy, webrtcvad, gevent, geventwebsocket except ImportError: try: import pip diff --git a/bazarr/list_subtitles.py b/bazarr/list_subtitles.py index 061f86d14..69cb17251 100644 --- a/bazarr/list_subtitles.py +++ b/bazarr/list_subtitles.py @@ -365,9 +365,8 @@ def list_missing_subtitles(no=None, epno=None, send_event=True): (missing_subtitles_text, episode_subtitles['sonarrEpisodeId'])) if send_event: - event_stream(type='episode', action='update', series=episode_subtitles['sonarrSeriesId'], - episode=episode_subtitles['sonarrEpisodeId']) - event_stream(type='badges_series') + event_stream(type='episode', payload=episode_subtitles['sonarrEpisodeId']) + event_stream(type='badges') def list_missing_subtitles_movies(no=None, epno=None, send_event=True): @@ -475,8 +474,8 @@ def list_missing_subtitles_movies(no=None, epno=None, send_event=True): (missing_subtitles_text, movie_subtitles['radarrId'])) if send_event: - event_stream(type='movie', action='update', movie=movie_subtitles['radarrId']) - event_stream(type='badges_movies') + event_stream(type='movie', payload=movie_subtitles['radarrId']) + event_stream(type='badges') def series_full_scan_subtitles(): diff --git a/bazarr/logger.py b/bazarr/logger.py index 1dd062762..d2e8c4bf1 100644 --- a/bazarr/logger.py +++ b/bazarr/logger.py @@ -104,7 +104,7 @@ def configure_logging(debug=False): logging.getLogger("ffsubsync.ffsubsync").setLevel(logging.ERROR) logging.getLogger("srt").setLevel(logging.ERROR) - logging.getLogger("waitress").setLevel(logging.CRITICAL) + logging.getLogger("geventwebsocket.handler").setLevel(logging.WARNING) 
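+ # quiet the gevent-websocket handler that replaces waitress as the WSGI server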
logging.getLogger("knowit").setLevel(logging.CRITICAL) logging.getLogger("enzyme").setLevel(logging.CRITICAL) logging.getLogger("guessit").setLevel(logging.WARNING) diff --git a/bazarr/scheduler.py b/bazarr/scheduler.py index c24635f38..8cde8262c 100644 --- a/bazarr/scheduler.py +++ b/bazarr/scheduler.py @@ -6,7 +6,7 @@ from get_series import update_series from config import settings from get_subtitle import wanted_search_missing_subtitles_series, wanted_search_missing_subtitles_movies, \ upgrade_subtitles -from utils import cache_maintenance +from utils import cache_maintenance, check_health from get_args import args if not args.no_update: from check_update import check_if_new_update, check_releases @@ -36,18 +36,19 @@ class Scheduler: def task_listener_add(event): if event.job_id not in self.__running_tasks: self.__running_tasks.append(event.job_id) - event_stream(type='task', task=event.job_id) + event_stream(type='task') def task_listener_remove(event): if event.job_id in self.__running_tasks: self.__running_tasks.remove(event.job_id) - event_stream(type='task', task=event.job_id) + event_stream(type='task') self.aps_scheduler.add_listener(task_listener_add, EVENT_JOB_SUBMITTED) self.aps_scheduler.add_listener(task_listener_remove, EVENT_JOB_EXECUTED | EVENT_JOB_ERROR) # configure all tasks self.__cache_cleanup_task() + self.__check_health_task() self.update_configurable_tasks() self.aps_scheduler.start() @@ -161,6 +162,10 @@ class Scheduler: self.aps_scheduler.add_job(cache_maintenance, IntervalTrigger(hours=24), max_instances=1, coalesce=True, misfire_grace_time=15, id='cache_cleanup', name='Cache maintenance') + def __check_health_task(self): + self.aps_scheduler.add_job(check_health, IntervalTrigger(hours=6), max_instances=1, coalesce=True, + misfire_grace_time=15, id='check_health', name='Check health') + def __sonarr_full_update_task(self): if settings.general.getboolean('use_sonarr'): full_update = settings.sonarr.full_update diff --git a/bazarr/server.py b/bazarr/server.py index 260969677..4eb0e801a 100644 --- a/bazarr/server.py +++ b/bazarr/server.py @@ -4,7 +4,8 @@ import warnings import logging import os import io -from waitress.server import create_server +from gevent import pywsgi +from geventwebsocket.handler import WebSocketHandler from get_args import args from config import settings, base_url @@ -30,10 +31,10 @@ class Server: self.server = app.run(host=str(settings.general.ip), port=(int(args.port) if args.port else int(settings.general.port))) else: - self.server = create_server(app, - host=str(settings.general.ip), - port=int(args.port) if args.port else int(settings.general.port), - threads=24) + self.server = pywsgi.WSGIServer((str(settings.general.ip), + int(args.port) if args.port else int(settings.general.port)), + app, + handler_class=WebSocketHandler) def start(self): try: @@ -41,13 +42,13 @@ class Server: 'BAZARR is started and waiting for request on http://' + str(settings.general.ip) + ':' + (str( args.port) if args.port else str(settings.general.port)) + str(base_url)) if not args.dev: - self.server.run() + self.server.serve_forever() except KeyboardInterrupt: self.shutdown() def shutdown(self): try: - self.server.close() + self.server.stop() except Exception as e: logging.error('BAZARR Cannot stop Waitress: ' + repr(e)) else: @@ -64,7 +65,7 @@ class Server: def restart(self): try: - self.server.close() + self.server.stop() except Exception as e: logging.error('BAZARR Cannot stop Waitress: ' + repr(e)) else: diff --git a/bazarr/utils.py b/bazarr/utils.py 
index 0e15242a1..c6ea9dd48 100644 --- a/bazarr/utils.py +++ b/bazarr/utils.py @@ -40,24 +40,24 @@ def history_log(action, sonarr_series_id, sonarr_episode_id, description, video_ "video_path, language, provider, score, subs_id, subtitles_path) VALUES (?,?,?,?,?,?,?,?,?,?,?)", (action, sonarr_series_id, sonarr_episode_id, time.time(), description, video_path, language, provider, score, subs_id, subtitles_path)) - event_stream(type='episodeHistory') + event_stream(type='episode-history') def blacklist_log(sonarr_series_id, sonarr_episode_id, provider, subs_id, language): database.execute("INSERT INTO table_blacklist (sonarr_series_id, sonarr_episode_id, timestamp, provider, " "subs_id, language) VALUES (?,?,?,?,?,?)", (sonarr_series_id, sonarr_episode_id, time.time(), provider, subs_id, language)) - event_stream(type='episodeBlacklist') + event_stream(type='episode-blacklist') def blacklist_delete(provider, subs_id): database.execute("DELETE FROM table_blacklist WHERE provider=? AND subs_id=?", (provider, subs_id)) - event_stream(type='episodeBlacklist') + event_stream(type='episode-blacklist', action='delete') def blacklist_delete_all(): database.execute("DELETE FROM table_blacklist") - event_stream(type='episodeBlacklist') + event_stream(type='episode-blacklist', action='delete') def history_log_movie(action, radarr_id, description, video_path=None, language=None, provider=None, score=None, @@ -65,23 +65,23 @@ def history_log_movie(action, radarr_id, description, video_path=None, language= database.execute("INSERT INTO table_history_movie (action, radarrId, timestamp, description, video_path, language, " "provider, score, subs_id, subtitles_path) VALUES (?,?,?,?,?,?,?,?,?,?)", (action, radarr_id, time.time(), description, video_path, language, provider, score, subs_id, subtitles_path)) - event_stream(type='movieHistory') + event_stream(type='movie-history') def blacklist_log_movie(radarr_id, provider, subs_id, language): database.execute("INSERT INTO table_blacklist_movie (radarr_id, timestamp, provider, subs_id, language) " "VALUES (?,?,?,?,?)", (radarr_id, time.time(), provider, subs_id, language)) - event_stream(type='movieBlacklist') + event_stream(type='movie-blacklist') def blacklist_delete_movie(provider, subs_id): database.execute("DELETE FROM table_blacklist_movie WHERE provider=? 
AND subs_id=?", (provider, subs_id)) - event_stream(type='movieBlacklist') + event_stream(type='movie-blacklist', action='delete') def blacklist_delete_all_movie(): database.execute("DELETE FROM table_blacklist_movie") - event_stream(type='movieBlacklist') + event_stream(type='movie-blacklist', action='delete') @region.cache_on_arguments() @@ -401,7 +401,39 @@ def translate_subtitles_file(video_path, source_srt_file, to_lang, forced, hi): return dest_srt_file + def check_credentials(user, pw): username = settings.auth.username password = settings.auth.password - return hashlib.md5(pw.encode('utf-8')).hexdigest() == password and user == username \ No newline at end of file + return hashlib.md5(pw.encode('utf-8')).hexdigest() == password and user == username + + +def check_health(): + from get_rootfolder import check_sonarr_rootfolder, check_radarr_rootfolder + if settings.general.getboolean('use_sonarr'): + check_sonarr_rootfolder() + if settings.general.getboolean('use_radarr'): + check_radarr_rootfolder() + event_stream(type='badges') + + +def get_health_issues(): + # this function must return a list of dictionaries consisting of to keys: object and issue + health_issues = [] + + # get Sonarr rootfolder issues + if settings.general.getboolean('use_sonarr'): + rootfolder = database.execute('SELECT path, accessible, error FROM table_shows_rootfolder WHERE accessible = 0') + for item in rootfolder: + health_issues.append({'object': path_mappings.path_replace(item['path']), + 'issue': item['error']}) + + # get Radarr rootfolder issues + if settings.general.getboolean('use_radarr'): + rootfolder = database.execute('SELECT path, accessible, error FROM table_movies_rootfolder ' + 'WHERE accessible = 0') + for item in rootfolder: + health_issues.append({'object': path_mappings.path_replace_movie(item['path']), + 'issue': item['error']}) + + return health_issues diff --git a/frontend/package-lock.json b/frontend/package-lock.json index effc57e2a..928e3fc12 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -31,6 +31,7 @@ "@types/redux-promise": "^0.5.0", "axios": "^0.21.0", "bootstrap": "^4.0.0", + "http-proxy-middleware": "^0.19.1", "lodash": "^4.0.0", "rc-slider": "^9.7.1", "react": "^16.0.0", @@ -48,6 +49,7 @@ "redux-promise": "^0.6.0", "redux-thunk": "^2.3.0", "sass": "^1.0.0", + "socket.io-client": "^4.0.0", "typescript": "^4.0.0" }, "devDependencies": { @@ -2793,6 +2795,11 @@ "resolved": "https://registry.npmjs.org/@types/classnames/-/classnames-2.2.11.tgz", "integrity": "sha512-2koNhpWm3DgWRp5tpkiJ8JGc1xTn2q0l+jUNUE7oMKXUf5NpI9AIdC4kbjGNFBdHtcxBD18LAksoudAVhFKCjw==" }, + "node_modules/@types/component-emitter": { + "version": "1.2.10", + "resolved": "https://registry.npmjs.org/@types/component-emitter/-/component-emitter-1.2.10.tgz", + "integrity": "sha512-bsjleuRKWmGqajMerkzox19aGbscQX5rmmvvXl3wlIp5gMG1HgkiwPxsN5p070fBDKTNSPgojVbuY1+HWMbFhg==" + }, "node_modules/@types/d3-path": { "version": "1.0.9", "resolved": "https://registry.npmjs.org/@types/d3-path/-/d3-path-1.0.9.tgz", @@ -4544,6 +4551,11 @@ "babylon": "bin/babylon.js" } }, + "node_modules/backo2": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/backo2/-/backo2-1.0.2.tgz", + "integrity": "sha1-MasayLEpNjRj41s+u2n038+6eUc=" + }, "node_modules/balanced-match": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.0.tgz", @@ -4577,6 +4589,14 @@ "node": ">=0.10.0" } }, + "node_modules/base64-arraybuffer": { + "version": "0.1.4", + "resolved": 
"https://registry.npmjs.org/base64-arraybuffer/-/base64-arraybuffer-0.1.4.tgz", + "integrity": "sha1-mBjHngWbE1X5fgQooBfIOOkLqBI=", + "engines": { + "node": ">= 0.6.0" + } + }, "node_modules/base64-js": { "version": "1.5.1", "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", @@ -7069,6 +7089,33 @@ "once": "^1.4.0" } }, + "node_modules/engine.io-client": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-5.0.1.tgz", + "integrity": "sha512-CQtGN3YwfvbxVwpPugcsHe5rHT4KgT49CEcQppNtu9N7WxbPN0MAG27lGaem7bvtCFtGNLSL+GEqXsFSz36jTg==", + "dependencies": { + "base64-arraybuffer": "0.1.4", + "component-emitter": "~1.3.0", + "debug": "~4.3.1", + "engine.io-parser": "~4.0.1", + "has-cors": "1.1.0", + "parseqs": "0.0.6", + "parseuri": "0.0.6", + "ws": "~7.4.2", + "yeast": "0.1.2" + } + }, + "node_modules/engine.io-parser": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-4.0.2.tgz", + "integrity": "sha512-sHfEQv6nmtJrq6TKuIz5kyEKH/qSdK56H/A+7DnAuUPWosnIZAS2NHNcPLmyjtY3cGS/MqJdZbUjW97JU72iYg==", + "dependencies": { + "base64-arraybuffer": "0.1.4" + }, + "engines": { + "node": ">=8.0.0" + } + }, "node_modules/enhanced-resolve": { "version": "4.5.0", "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-4.5.0.tgz", @@ -9397,6 +9444,11 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/has-cors": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-cors/-/has-cors-1.1.0.tgz", + "integrity": "sha1-XkdHk/fqmEPRu5nCPu9J/xJv/zk=" + }, "node_modules/has-flag": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", @@ -13634,6 +13686,16 @@ "resolved": "https://registry.npmjs.org/parse5/-/parse5-6.0.1.tgz", "integrity": "sha512-Ofn/CTFzRGTTxwpNEs9PP93gXShHcTq255nzRYSKe8AkVpZY7e1fpmTfOyoIvjP5HG7Z2ZM7VS9PPhQGW2pOpw==" }, + "node_modules/parseqs": { + "version": "0.0.6", + "resolved": "https://registry.npmjs.org/parseqs/-/parseqs-0.0.6.tgz", + "integrity": "sha512-jeAGzMDbfSHHA091hr0r31eYfTig+29g3GKKE/PPbEQ65X0lmMwlEoqmhzu0iztID5uJpZsFlUPDP8ThPL7M8w==" + }, + "node_modules/parseuri": { + "version": "0.0.6", + "resolved": "https://registry.npmjs.org/parseuri/-/parseuri-0.0.6.tgz", + "integrity": "sha512-AUjen8sAkGgao7UyCX6Ahv0gIK2fABKmYjvP4xmy5JaKvcbTRueIqIPHLAfq30xJddqSE033IOMUSOMCcK3Sow==" + }, "node_modules/parseurl": { "version": "1.3.3", "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.3.tgz", @@ -18161,6 +18223,36 @@ "resolved": "https://registry.npmjs.org/ms/-/ms-2.0.0.tgz", "integrity": "sha1-VgiurfwAvmwpAd9fmGF4jeDVl8g=" }, + "node_modules/socket.io-client": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.0.1.tgz", + "integrity": "sha512-6AkaEG5zrVuSVW294cH1chioag9i1OqnCYjKwTc3EBGXbnyb98Lw7yMa40ifLjFj3y6fsFKsd0llbUZUCRf3Qw==", + "dependencies": { + "@types/component-emitter": "^1.2.10", + "backo2": "~1.0.2", + "component-emitter": "~1.3.0", + "debug": "~4.3.1", + "engine.io-client": "~5.0.0", + "parseuri": "0.0.6", + "socket.io-parser": "~4.0.4" + }, + "engines": { + "node": ">=10.0.0" + } + }, + "node_modules/socket.io-parser": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.0.4.tgz", + "integrity": "sha512-t+b0SS+IxG7Rxzda2EVvyBZbvFPBCjJoyHuE0P//7OAsN23GItzDRdWa6ALxZI/8R5ygK7jAR6t028/z+7295g==", + "dependencies": { + "@types/component-emitter": "^1.2.10", + 
"component-emitter": "~1.3.0", + "debug": "~4.3.1" + }, + "engines": { + "node": ">=10.0.0" + } + }, "node_modules/sockjs": { "version": "0.3.21", "resolved": "https://registry.npmjs.org/sockjs/-/sockjs-0.3.21.tgz", @@ -22025,6 +22117,11 @@ "node": ">=8" } }, + "node_modules/yeast": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/yeast/-/yeast-0.1.2.tgz", + "integrity": "sha1-AI4G2AlDIMNy28L47XagymyKxBk=" + }, "node_modules/yocto-queue": { "version": "0.1.0", "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", @@ -24144,6 +24241,11 @@ "resolved": "https://registry.npmjs.org/@types/classnames/-/classnames-2.2.11.tgz", "integrity": "sha512-2koNhpWm3DgWRp5tpkiJ8JGc1xTn2q0l+jUNUE7oMKXUf5NpI9AIdC4kbjGNFBdHtcxBD18LAksoudAVhFKCjw==" }, + "@types/component-emitter": { + "version": "1.2.10", + "resolved": "https://registry.npmjs.org/@types/component-emitter/-/component-emitter-1.2.10.tgz", + "integrity": "sha512-bsjleuRKWmGqajMerkzox19aGbscQX5rmmvvXl3wlIp5gMG1HgkiwPxsN5p070fBDKTNSPgojVbuY1+HWMbFhg==" + }, "@types/d3-path": { "version": "1.0.9", "resolved": "https://registry.npmjs.org/@types/d3-path/-/d3-path-1.0.9.tgz", @@ -25605,6 +25707,11 @@ "resolved": "https://registry.npmjs.org/babylon/-/babylon-6.18.0.tgz", "integrity": "sha512-q/UEjfGJ2Cm3oKV71DJz9d25TPnq5rhBVL2Q4fA5wcC3jcrdn7+SssEybFIxwAvvP+YCsCYNKughoF33GxgycQ==" }, + "backo2": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/backo2/-/backo2-1.0.2.tgz", + "integrity": "sha1-MasayLEpNjRj41s+u2n038+6eUc=" + }, "balanced-match": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.0.tgz", @@ -25634,6 +25741,11 @@ } } }, + "base64-arraybuffer": { + "version": "0.1.4", + "resolved": "https://registry.npmjs.org/base64-arraybuffer/-/base64-arraybuffer-0.1.4.tgz", + "integrity": "sha1-mBjHngWbE1X5fgQooBfIOOkLqBI=" + }, "base64-js": { "version": "1.5.1", "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", @@ -27674,6 +27786,30 @@ "once": "^1.4.0" } }, + "engine.io-client": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-5.0.1.tgz", + "integrity": "sha512-CQtGN3YwfvbxVwpPugcsHe5rHT4KgT49CEcQppNtu9N7WxbPN0MAG27lGaem7bvtCFtGNLSL+GEqXsFSz36jTg==", + "requires": { + "base64-arraybuffer": "0.1.4", + "component-emitter": "~1.3.0", + "debug": "~4.3.1", + "engine.io-parser": "~4.0.1", + "has-cors": "1.1.0", + "parseqs": "0.0.6", + "parseuri": "0.0.6", + "ws": "~7.4.2", + "yeast": "0.1.2" + } + }, + "engine.io-parser": { + "version": "4.0.2", + "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-4.0.2.tgz", + "integrity": "sha512-sHfEQv6nmtJrq6TKuIz5kyEKH/qSdK56H/A+7DnAuUPWosnIZAS2NHNcPLmyjtY3cGS/MqJdZbUjW97JU72iYg==", + "requires": { + "base64-arraybuffer": "0.1.4" + } + }, "enhanced-resolve": { "version": "4.5.0", "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-4.5.0.tgz", @@ -29466,6 +29602,11 @@ "resolved": "https://registry.npmjs.org/has-bigints/-/has-bigints-1.0.1.tgz", "integrity": "sha512-LSBS2LjbNBTf6287JEbEzvJgftkF5qFkmCo9hDRpAzKhUOlJ+hx8dd4USs00SgsUNwc4617J9ki5YtEClM2ffA==" }, + "has-cors": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/has-cors/-/has-cors-1.1.0.tgz", + "integrity": "sha1-XkdHk/fqmEPRu5nCPu9J/xJv/zk=" + }, "has-flag": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", @@ -32737,6 +32878,16 @@ "resolved": 
"https://registry.npmjs.org/parse5/-/parse5-6.0.1.tgz", "integrity": "sha512-Ofn/CTFzRGTTxwpNEs9PP93gXShHcTq255nzRYSKe8AkVpZY7e1fpmTfOyoIvjP5HG7Z2ZM7VS9PPhQGW2pOpw==" }, + "parseqs": { + "version": "0.0.6", + "resolved": "https://registry.npmjs.org/parseqs/-/parseqs-0.0.6.tgz", + "integrity": "sha512-jeAGzMDbfSHHA091hr0r31eYfTig+29g3GKKE/PPbEQ65X0lmMwlEoqmhzu0iztID5uJpZsFlUPDP8ThPL7M8w==" + }, + "parseuri": { + "version": "0.0.6", + "resolved": "https://registry.npmjs.org/parseuri/-/parseuri-0.0.6.tgz", + "integrity": "sha512-AUjen8sAkGgao7UyCX6Ahv0gIK2fABKmYjvP4xmy5JaKvcbTRueIqIPHLAfq30xJddqSE033IOMUSOMCcK3Sow==" + }, "parseurl": { "version": "1.3.3", "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.3.tgz", @@ -36328,6 +36479,30 @@ } } }, + "socket.io-client": { + "version": "4.0.1", + "resolved": "https://registry.npmjs.org/socket.io-client/-/socket.io-client-4.0.1.tgz", + "integrity": "sha512-6AkaEG5zrVuSVW294cH1chioag9i1OqnCYjKwTc3EBGXbnyb98Lw7yMa40ifLjFj3y6fsFKsd0llbUZUCRf3Qw==", + "requires": { + "@types/component-emitter": "^1.2.10", + "backo2": "~1.0.2", + "component-emitter": "~1.3.0", + "debug": "~4.3.1", + "engine.io-client": "~5.0.0", + "parseuri": "0.0.6", + "socket.io-parser": "~4.0.4" + } + }, + "socket.io-parser": { + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/socket.io-parser/-/socket.io-parser-4.0.4.tgz", + "integrity": "sha512-t+b0SS+IxG7Rxzda2EVvyBZbvFPBCjJoyHuE0P//7OAsN23GItzDRdWa6ALxZI/8R5ygK7jAR6t028/z+7295g==", + "requires": { + "@types/component-emitter": "^1.2.10", + "component-emitter": "~1.3.0", + "debug": "~4.3.1" + } + }, "sockjs": { "version": "0.3.21", "resolved": "https://registry.npmjs.org/sockjs/-/sockjs-0.3.21.tgz", @@ -39451,6 +39626,11 @@ } } }, + "yeast": { + "version": "0.1.2", + "resolved": "https://registry.npmjs.org/yeast/-/yeast-0.1.2.tgz", + "integrity": "sha1-AI4G2AlDIMNy28L47XagymyKxBk=" + }, "yocto-queue": { "version": "0.1.0", "resolved": "https://registry.npmjs.org/yocto-queue/-/yocto-queue-0.1.0.tgz", diff --git a/frontend/package.json b/frontend/package.json index 6a806de05..aa04a1b10 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -12,7 +12,6 @@ "url": "https://github.com/morpheus65535/bazarr/issues" }, "homepage": "./", - "proxy": "http://localhost:6767", "dependencies": { "@fontsource/roboto": "^4.2.2", "@fortawesome/fontawesome-svg-core": "^1.2.0", @@ -36,6 +35,7 @@ "@types/redux-promise": "^0.5.0", "axios": "^0.21.0", "bootstrap": "^4.0.0", + "http-proxy-middleware": "^0.19.1", "lodash": "^4.0.0", "rc-slider": "^9.7.1", "react": "^16.0.0", @@ -53,6 +53,7 @@ "redux-promise": "^0.6.0", "redux-thunk": "^2.3.0", "sass": "^1.0.0", + "socket.io-client": "^4.0.0", "typescript": "^4.0.0" }, "devDependencies": { diff --git a/frontend/src/@redux/actions/factory.ts b/frontend/src/@redux/actions/factory.ts index 6303ca96a..22a24e1f3 100644 --- a/frontend/src/@redux/actions/factory.ts +++ b/frontend/src/@redux/actions/factory.ts @@ -1,5 +1,4 @@ -import { isEqual } from "lodash"; -import { log } from "../../utilites/logger"; +import { createAction } from "redux-actions"; import { ActionCallback, ActionDispatcher, @@ -10,42 +9,12 @@ import { PromiseCreator, } from "../types"; -// Limiter the call to API -const gLimiter: Map = new Map(); -const gArgs: Map = new Map(); - -const LIMIT_CALL_MS = 200; - function asyncActionFactory( type: string, promise: T, args: Parameters ): AsyncActionDispatcher>> { return (dispatch) => { - const previousArgs = gArgs.get(promise); - const date = new Date(); - - if 
(isEqual(previousArgs, args)) { - // Get last execute date - const previousExec = gLimiter.get(promise); - if (previousExec) { - const distInMs = date.getTime() - previousExec.getTime(); - if (distInMs < LIMIT_CALL_MS) { - log( - "warning", - "Multiple calls to API within the range", - promise, - args - ); - return Promise.resolve(); - } - } - } else { - gArgs.set(promise, args); - } - - gLimiter.set(promise, date); - dispatch({ type, payload: { @@ -153,3 +122,8 @@ export function createCallbackAction( return (...args: Parameters) => callbackActionFactory(fn(args), success, error); } + +// Helper +export function createDeleteAction(type: string): SocketIO.ActionFn { + return createAction(type, (id?: number[]) => id ?? []); +} diff --git a/frontend/src/@redux/actions/index.ts b/frontend/src/@redux/actions/index.ts index ed38f6326..afb0e5255 100644 --- a/frontend/src/@redux/actions/index.ts +++ b/frontend/src/@redux/actions/index.ts @@ -1,5 +1,4 @@ export * from "./movie"; -export * from "./providers"; export * from "./series"; export * from "./site"; export * from "./system"; diff --git a/frontend/src/@redux/actions/movie.ts b/frontend/src/@redux/actions/movie.ts index e2154d7e4..8de392181 100644 --- a/frontend/src/@redux/actions/movie.ts +++ b/frontend/src/@redux/actions/movie.ts @@ -1,58 +1,45 @@ import { MoviesApi } from "../../apis"; import { + MOVIES_DELETE_ITEMS, + MOVIES_DELETE_WANTED_ITEMS, MOVIES_UPDATE_BLACKLIST, MOVIES_UPDATE_HISTORY_LIST, - MOVIES_UPDATE_INFO, MOVIES_UPDATE_LIST, - MOVIES_UPDATE_RANGE, MOVIES_UPDATE_WANTED_LIST, - MOVIES_UPDATE_WANTED_RANGE, } from "../constants"; -import { - createAsyncAction, - createAsyncCombineAction, - createCombineAction, -} from "./factory"; -import { badgeUpdateAll } from "./site"; +import { createAsyncAction, createDeleteAction } from "./factory"; -export const movieUpdateList = createAsyncAction(MOVIES_UPDATE_LIST, () => - MoviesApi.movies() +export const movieUpdateList = createAsyncAction( + MOVIES_UPDATE_LIST, + (id?: number[]) => MoviesApi.movies(id) ); -const movieUpdateWantedList = createAsyncAction( +export const movieDeleteItems = createDeleteAction(MOVIES_DELETE_ITEMS); + +export const movieUpdateWantedList = createAsyncAction( MOVIES_UPDATE_WANTED_LIST, - (radarrid?: number) => MoviesApi.wantedBy(radarrid) + (radarrid: number[]) => MoviesApi.wantedBy(radarrid) +); + +export const movieDeleteWantedItems = createDeleteAction( + MOVIES_DELETE_WANTED_ITEMS ); export const movieUpdateWantedByRange = createAsyncAction( - MOVIES_UPDATE_WANTED_RANGE, + MOVIES_UPDATE_WANTED_LIST, (start: number, length: number) => MoviesApi.wanted(start, length) ); -export const movieUpdateWantedBy = createCombineAction((radarrid?: number) => [ - movieUpdateWantedList(radarrid), - badgeUpdateAll(), -]); - export const movieUpdateHistoryList = createAsyncAction( MOVIES_UPDATE_HISTORY_LIST, () => MoviesApi.history() ); export const movieUpdateByRange = createAsyncAction( - MOVIES_UPDATE_RANGE, + MOVIES_UPDATE_LIST, (start: number, length: number) => MoviesApi.moviesBy(start, length) ); -const movieUpdateInfo = createAsyncAction(MOVIES_UPDATE_INFO, (id?: number[]) => - MoviesApi.movies(id) -); - -export const movieUpdateInfoAll = createAsyncCombineAction((id?: number[]) => [ - movieUpdateInfo(id), - badgeUpdateAll(), -]); - export const movieUpdateBlacklist = createAsyncAction( MOVIES_UPDATE_BLACKLIST, () => MoviesApi.blacklist() diff --git a/frontend/src/@redux/actions/providers.ts b/frontend/src/@redux/actions/providers.ts deleted file mode 100644 
index 59afb659c..000000000 --- a/frontend/src/@redux/actions/providers.ts +++ /dev/null @@ -1,13 +0,0 @@ -import { ProvidersApi } from "../../apis"; -import { PROVIDER_UPDATE_LIST } from "../constants"; -import { createAsyncAction, createCombineAction } from "./factory"; -import { badgeUpdateAll } from "./site"; - -const providerUpdateList = createAsyncAction(PROVIDER_UPDATE_LIST, () => - ProvidersApi.providers() -); - -export const providerUpdateAll = createCombineAction(() => [ - providerUpdateList(), - badgeUpdateAll(), -]); diff --git a/frontend/src/@redux/actions/series.ts b/frontend/src/@redux/actions/series.ts index 14d3366b4..bbee5b3b5 100644 --- a/frontend/src/@redux/actions/series.ts +++ b/frontend/src/@redux/actions/series.ts @@ -1,50 +1,52 @@ import { EpisodesApi, SeriesApi } from "../../apis"; import { + SERIES_DELETE_EPISODES, + SERIES_DELETE_ITEMS, + SERIES_DELETE_WANTED_ITEMS, SERIES_UPDATE_BLACKLIST, SERIES_UPDATE_EPISODE_LIST, SERIES_UPDATE_HISTORY_LIST, - SERIES_UPDATE_INFO, - SERIES_UPDATE_RANGE, + SERIES_UPDATE_LIST, SERIES_UPDATE_WANTED_LIST, - SERIES_UPDATE_WANTED_RANGE, } from "../constants"; -import { - createAsyncAction, - createAsyncCombineAction, - createCombineAction, -} from "./factory"; -import { badgeUpdateAll } from "./site"; +import { createAsyncAction, createDeleteAction } from "./factory"; -const seriesUpdateWantedList = createAsyncAction( +export const seriesUpdateWantedList = createAsyncAction( SERIES_UPDATE_WANTED_LIST, - (episodeid?: number) => EpisodesApi.wantedBy(episodeid) + (episodeid: number[]) => EpisodesApi.wantedBy(episodeid) ); -const seriesUpdateBy = createAsyncAction(SERIES_UPDATE_INFO, (id?: number[]) => - SeriesApi.series(id) -); - -const episodeUpdateBy = createAsyncAction( - SERIES_UPDATE_EPISODE_LIST, - (seriesid: number) => EpisodesApi.bySeriesId(seriesid) -); - -export const seriesUpdateByRange = createAsyncAction( - SERIES_UPDATE_RANGE, - (start: number, length: number) => SeriesApi.seriesBy(start, length) +export const seriesDeleteWantedItems = createDeleteAction( + SERIES_DELETE_WANTED_ITEMS ); export const seriesUpdateWantedByRange = createAsyncAction( - SERIES_UPDATE_WANTED_RANGE, + SERIES_UPDATE_WANTED_LIST, (start: number, length: number) => EpisodesApi.wanted(start, length) ); -export const seriesUpdateWantedBy = createCombineAction( - (episodeid?: number) => [seriesUpdateWantedList(episodeid), badgeUpdateAll()] +export const seriesUpdateList = createAsyncAction( + SERIES_UPDATE_LIST, + (id?: number[]) => SeriesApi.series(id) ); -export const episodeUpdateBySeriesId = createCombineAction( - (seriesid: number) => [episodeUpdateBy(seriesid), badgeUpdateAll()] +export const seriesDeleteItems = createDeleteAction(SERIES_DELETE_ITEMS); + +export const episodeUpdateBy = createAsyncAction( + SERIES_UPDATE_EPISODE_LIST, + (seriesid: number[]) => EpisodesApi.bySeriesId(seriesid) +); + +export const episodeDeleteItems = createDeleteAction(SERIES_DELETE_EPISODES); + +export const episodeUpdateById = createAsyncAction( + SERIES_UPDATE_EPISODE_LIST, + (episodeid: number[]) => EpisodesApi.byEpisodeId(episodeid) +); + +export const seriesUpdateByRange = createAsyncAction( + SERIES_UPDATE_LIST, + (start: number, length: number) => SeriesApi.seriesBy(start, length) ); export const seriesUpdateHistoryList = createAsyncAction( @@ -52,10 +54,6 @@ export const seriesUpdateHistoryList = createAsyncAction( () => EpisodesApi.history() ); -export const seriesUpdateInfoAll = createAsyncCombineAction( - (seriesid?: number[]) => 
-    [seriesUpdateBy(seriesid), badgeUpdateAll()]
-);
-
 export const seriesUpdateBlacklist = createAsyncAction(
   SERIES_UPDATE_BLACKLIST,
   () => EpisodesApi.blacklist()
diff --git a/frontend/src/@redux/actions/site.ts b/frontend/src/@redux/actions/site.ts
index 9ea9e6af5..038ca5fe7 100644
--- a/frontend/src/@redux/actions/site.ts
+++ b/frontend/src/@redux/actions/site.ts
@@ -16,7 +16,7 @@ import { createAsyncAction, createCallbackAction } from "./factory";
 import { systemUpdateLanguagesAll, systemUpdateSettings } from "./system";
 
 export const bootstrap = createCallbackAction(
-  () => [systemUpdateLanguagesAll(), systemUpdateSettings()],
+  () => [systemUpdateLanguagesAll(), systemUpdateSettings(), badgeUpdateAll()],
   () => siteInitialized(),
   () => siteInitializeFailed()
 );
@@ -36,17 +36,17 @@ export const siteSaveLocalstorage = createAction(
   (settings: LooseObject) => settings
 );
 
-export const siteAddError = createAction(
+export const siteAddNotification = createAction(
   SITE_NOTIFICATIONS_ADD,
   (err: ReduxStore.Notification) => err
 );
 
-export const siteRemoveError = createAction(
+export const siteRemoveNotification = createAction(
   SITE_NOTIFICATIONS_REMOVE,
   (id: string) => id
 );
 
-export const siteRemoveErrorByTimestamp = createAction(
+export const siteRemoveNotificationByTime = createAction(
   SITE_NOTIFICATIONS_REMOVE_BY_TIMESTAMP,
   (date: Date) => date
 );
diff --git a/frontend/src/@redux/actions/system.ts b/frontend/src/@redux/actions/system.ts
index ed0b75a0a..adf73b3eb 100644
--- a/frontend/src/@redux/actions/system.ts
+++ b/frontend/src/@redux/actions/system.ts
@@ -1,10 +1,10 @@
-import { Action } from "redux-actions";
-import { SystemApi } from "../../apis";
+import { ProvidersApi, SystemApi } from "../../apis";
 import {
-  SYSTEM_RUN_TASK,
+  SYSTEM_UPDATE_HEALTH,
   SYSTEM_UPDATE_LANGUAGES_LIST,
   SYSTEM_UPDATE_LANGUAGES_PROFILE_LIST,
   SYSTEM_UPDATE_LOGS,
+  SYSTEM_UPDATE_PROVIDERS,
   SYSTEM_UPDATE_RELEASES,
   SYSTEM_UPDATE_SETTINGS,
   SYSTEM_UPDATE_STATUS,
@@ -31,17 +31,14 @@ export const systemUpdateStatus = createAsyncAction(SYSTEM_UPDATE_STATUS, () =>
   SystemApi.status()
 );
 
+export const systemUpdateHealth = createAsyncAction(SYSTEM_UPDATE_HEALTH, () =>
+  SystemApi.health()
+);
+
 export const systemUpdateTasks = createAsyncAction(SYSTEM_UPDATE_TASKS, () =>
   SystemApi.getTasks()
 );
 
-export function systemRunTasks(id: string): Action {
-  return {
-    type: SYSTEM_RUN_TASK,
-    payload: id,
-  };
-}
-
 export const systemUpdateLogs = createAsyncAction(SYSTEM_UPDATE_LOGS, () =>
   SystemApi.logs()
 );
@@ -56,6 +53,11 @@ export const systemUpdateSettings = createAsyncAction(
   () => SystemApi.settings()
 );
 
+export const providerUpdateList = createAsyncAction(
+  SYSTEM_UPDATE_PROVIDERS,
+  () => ProvidersApi.providers()
+);
+
 export const systemUpdateSettingsAll = createAsyncCombineAction(() => [
   systemUpdateSettings(),
   systemUpdateLanguagesAll(),
diff --git a/frontend/src/@redux/constants/index.ts b/frontend/src/@redux/constants/index.ts
index d4be20f04..1201528af 100644
--- a/frontend/src/@redux/constants/index.ts
+++ b/frontend/src/@redux/constants/index.ts
@@ -1,33 +1,33 @@
 // Provider action
-export const PROVIDER_UPDATE_LIST = "UPDATE_PROVIDER_LIST";
 
 // System action
 export const SYSTEM_UPDATE_LANGUAGES_LIST = "UPDATE_ALL_LANGUAGES_LIST";
 export const SYSTEM_UPDATE_LANGUAGES_PROFILE_LIST =
   "UPDATE_LANGUAGES_PROFILE_LIST";
 export const SYSTEM_UPDATE_STATUS = "UPDATE_SYSTEM_STATUS";
+export const SYSTEM_UPDATE_HEALTH = "UPDATE_SYSTEM_HEALTH";
 export const SYSTEM_UPDATE_TASKS = "UPDATE_SYSTEM_TASKS";
 export const SYSTEM_UPDATE_LOGS = "UPDATE_SYSTEM_LOGS";
 export const SYSTEM_UPDATE_RELEASES = "SYSTEM_UPDATE_RELEASES";
 export const SYSTEM_UPDATE_SETTINGS = "UPDATE_SYSTEM_SETTINGS";
-export const SYSTEM_RUN_TASK = "SYSTEM_RUN_TASK";
+export const SYSTEM_UPDATE_PROVIDERS = "SYSTEM_UPDATE_PROVIDERS";
 
 // Series action
-export const SERIES_UPDATE_WANTED_RANGE = "SERIES_UPDATE_WANTED_RANGE";
 export const SERIES_UPDATE_WANTED_LIST = "UPDATE_SERIES_WANTED_LIST";
+export const SERIES_DELETE_WANTED_ITEMS = "SERIES_DELETE_WANTED_ITEMS";
 export const SERIES_UPDATE_EPISODE_LIST = "UPDATE_SERIES_EPISODE_LIST";
+export const SERIES_DELETE_EPISODES = "SERIES_DELETE_EPISODES";
 export const SERIES_UPDATE_HISTORY_LIST = "UPDATE_SERIES_HISTORY_LIST";
-export const SERIES_UPDATE_INFO = "UPDATE_SEIRES_INFO";
-export const SERIES_UPDATE_RANGE = "SERIES_UPDATE_RANGE";
+export const SERIES_UPDATE_LIST = "UPDATE_SEIRES_LIST";
+export const SERIES_DELETE_ITEMS = "SERIES_DELETE_ITEMS";
 export const SERIES_UPDATE_BLACKLIST = "UPDATE_SERIES_BLACKLIST";
 
 // Movie action
 export const MOVIES_UPDATE_LIST = "UPDATE_MOVIE_LIST";
-export const MOVIES_UPDATE_WANTED_RANGE = "MOVIES_UPDATE_WANTED_RANGE";
+export const MOVIES_DELETE_ITEMS = "MOVIES_DELETE_ITEMS";
 export const MOVIES_UPDATE_WANTED_LIST = "UPDATE_MOVIE_WANTED_LIST";
+export const MOVIES_DELETE_WANTED_ITEMS = "MOVIES_DELETE_WANTED_ITEMS";
 export const MOVIES_UPDATE_HISTORY_LIST = "UPDATE_MOVIE_HISTORY_LIST";
-export const MOVIES_UPDATE_INFO = "UPDATE_MOVIE_INFO";
-export const MOVIES_UPDATE_RANGE = "MOVIES_UPDATE_RANGE";
 export const MOVIES_UPDATE_BLACKLIST = "UPDATE_MOVIES_BLACKLIST";
 
 // Site Action
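Note: the old `*_UPDATE_INFO` / `*_UPDATE_RANGE` split is gone; ranged fetches and id fetches now dispatch the same `*_UPDATE_LIST` constant, with dedicated `*_DELETE_*` constants for socket-driven removals. Roughly, with parameter shapes taken from the actions above:

```ts
// Both fetch styles funnel into one reducer branch:
// movieUpdateByRange(0, 25)  -> type UPDATE_MOVIE_LIST, parameters [0, 25]
// movieUpdateList([12, 34])  -> type UPDATE_MOVIE_LIST, parameters [[12, 34]]
// Removals pushed from the server use the paired delete constant:
// movieDeleteItems([12])     -> type MOVIES_DELETE_ITEMS, payload [12]
```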
useSocketIOReducer("task", update); + + useEffect(() => { + update(); + }, [update]); + return stateBuilder(items, update); +} + +export function useSystemStatus() { + const items = useReduxStore((s) => s.system.status.data); + const update = useReduxAction(systemUpdateStatus); + + useEffect(() => { + update(); + }, [update]); + return stateBuilder(items, update); +} + +export function useSystemHealth() { + const update = useReduxAction(systemUpdateHealth); + const items = useReduxStore((s) => s.system.health); + + useEffect(() => { + update(); + }, [update]); + return stateBuilder(items, update); +} + +export function useSystemProviders() { + const update = useReduxAction(providerUpdateList); + const items = useReduxStore((d) => d.system.providers); + + useEffect(() => { + update(); + }, [update]); + return stateBuilder(items, update); +} + +export function useSystemReleases() { + const items = useReduxStore(({ system }) => system.releases); + const update = useReduxAction(systemUpdateReleases); + + useEffect(() => { + update(); + }, [update]); + return stateBuilder(items, update); } export function useLanguageProfiles() { @@ -92,9 +164,9 @@ export function useProfileItems(profile?: Profile.Languages) { } export function useRawSeries() { - const action = useReduxAction(seriesUpdateInfoAll); + const update = useReduxAction(seriesUpdateList); const items = useReduxStore((d) => d.series.seriesList); - return stateBuilder(items, action); + return stateBuilder(items, update); } export function useSeries(order = true) { @@ -118,7 +190,6 @@ export function useSeries(order = true) { export function useSerieBy(id?: number) { const [series, updateSerie] = useRawSeries(); - const updateEpisodes = useReduxAction(episodeUpdateBySeriesId); const serie = useMemo>(() => { const items = series.data.items; let item: Item.Series | null = null; @@ -134,18 +205,22 @@ export function useSerieBy(id?: number) { const update = useCallback(() => { if (id && !isNaN(id)) { updateSerie([id]); - updateEpisodes(id); } - }, [id, updateSerie, updateEpisodes]); + }, [id, updateSerie]); + useEffect(() => { + if (serie.data === null) { + update(); + } + }, [serie.data, update]); return stateBuilder(serie, update); } export function useEpisodesBy(seriesId?: number) { - const action = useReduxAction(episodeUpdateBySeriesId); - const callback = useCallback(() => { + const action = useReduxAction(episodeUpdateBy); + const update = useCallback(() => { if (seriesId !== undefined && !isNaN(seriesId)) { - action(seriesId); + action([seriesId]); } }, [action, seriesId]); @@ -153,24 +228,38 @@ export function useEpisodesBy(seriesId?: number) { const items = useMemo(() => { if (seriesId !== undefined && !isNaN(seriesId)) { - return list.data[seriesId] ?? 
[]; + return list.data.filter((v) => v.sonarrSeriesId === seriesId); } else { return []; } }, [seriesId, list.data]); - const state: AsyncState = { - ...list, - data: items, - }; + const state: AsyncState = useMemo( + () => ({ + ...list, + data: items, + }), + [list, items] + ); - return stateBuilder(state, callback); + const actionById = useReduxAction(episodeUpdateById); + const wrapActionById = useWrapToOptionalId(actionById); + const deleteAction = useReduxAction(episodeDeleteItems); + useSocketIOReducer("episode", undefined, wrapActionById, deleteAction); + + const wrapAction = useWrapToOptionalId(action); + useSocketIOReducer("series", undefined, wrapAction); + + useEffect(() => { + update(); + }, [update]); + return stateBuilder(state, update); } export function useRawMovies() { - const action = useReduxAction(movieUpdateInfoAll); + const update = useReduxAction(movieUpdateList); const items = useReduxStore((d) => d.movie.movieList); - return stateBuilder(items, action); + return stateBuilder(items, update); } export function useMovies(order = true) { @@ -212,54 +301,80 @@ export function useMovieBy(id?: number) { } }, [id, updateMovies]); + useEffect(() => { + if (movie.data === null) { + update(); + } + }, [movie.data, update]); return stateBuilder(movie, update); } export function useWantedSeries() { - const action = useReduxAction(seriesUpdateWantedBy); + const update = useReduxAction(seriesUpdateWantedList); const items = useReduxStore((d) => d.series.wantedEpisodesList); - return stateBuilder(items, action); + const updateAction = useWrapToOptionalId(update); + const deleteAction = useReduxAction(seriesDeleteWantedItems); + useSocketIOReducer("episode-wanted", undefined, updateAction, deleteAction); + + return stateBuilder(items, update); } export function useWantedMovies() { - const action = useReduxAction(movieUpdateWantedBy); + const update = useReduxAction(movieUpdateWantedList); const items = useReduxStore((d) => d.movie.wantedMovieList); - return stateBuilder(items, action); -} + const updateAction = useWrapToOptionalId(update); + const deleteAction = useReduxAction(movieDeleteWantedItems); + useSocketIOReducer("movie-wanted", undefined, updateAction, deleteAction); -export function useProviders() { - const action = useReduxAction(providerUpdateAll); - const items = useReduxStore((d) => d.system.providers); - - return stateBuilder(items, action); + return stateBuilder(items, update); } export function useBlacklistMovies() { - const action = useReduxAction(movieUpdateBlacklist); + const update = useReduxAction(movieUpdateBlacklist); const items = useReduxStore((d) => d.movie.blacklist); - return stateBuilder(items, action); + useSocketIOReducer("movie-blacklist", update); + + useEffect(() => { + update(); + }, [update]); + return stateBuilder(items, update); } export function useBlacklistSeries() { - const action = useReduxAction(seriesUpdateBlacklist); + const update = useReduxAction(seriesUpdateBlacklist); const items = useReduxStore((d) => d.series.blacklist); - return stateBuilder(items, action); + useSocketIOReducer("episode-blacklist", update); + + useEffect(() => { + update(); + }, [update]); + return stateBuilder(items, update); } export function useMoviesHistory() { - const action = useReduxAction(movieUpdateHistoryList); + const update = useReduxAction(movieUpdateHistoryList); const items = useReduxStore((s) => s.movie.historyList); - return stateBuilder(items, action); + useSocketIOReducer("movie-history", update); + + useEffect(() => { + update(); + }, 
[update]); + return stateBuilder(items, update); } export function useSeriesHistory() { - const action = useReduxAction(seriesUpdateHistoryList); + const update = useReduxAction(seriesUpdateHistoryList); const items = useReduxStore((s) => s.series.historyList); - return stateBuilder(items, action); + useSocketIOReducer("episode-history", update); + + useEffect(() => { + update(); + }, [update]); + return stateBuilder(items, update); } diff --git a/frontend/src/@redux/hooks/site.ts b/frontend/src/@redux/hooks/site.ts index 139f128d0..c789bb026 100644 --- a/frontend/src/@redux/hooks/site.ts +++ b/frontend/src/@redux/hooks/site.ts @@ -1,11 +1,15 @@ -import { useCallback } from "react"; +import { useCallback, useEffect } from "react"; import { useSystemSettings } from "."; -import { siteAddError, siteRemoveErrorByTimestamp } from "../actions"; +import { + siteAddNotification, + siteChangeSidebar, + siteRemoveNotificationByTime, +} from "../actions"; import { useReduxAction, useReduxStore } from "./base"; export function useNotification(id: string, sec: number = 5) { - const add = useReduxAction(siteAddError); - const remove = useReduxAction(siteRemoveErrorByTimestamp); + const add = useReduxAction(siteAddNotification); + const remove = useReduxAction(siteRemoveNotificationByTime); return useCallback( (msg: Omit) => { @@ -34,3 +38,15 @@ export function useIsRadarrEnabled() { const [settings] = useSystemSettings(); return settings.data?.general.use_radarr ?? true; } + +export function useShowOnlyDesired() { + const [settings] = useSystemSettings(); + return settings.data?.general.embedded_subs_show_desired ?? false; +} + +export function useSetSidebar(key: string) { + const update = useReduxAction(siteChangeSidebar); + useEffect(() => { + update(key); + }, [update, key]); +} diff --git a/frontend/src/@redux/reducers/mapper.ts b/frontend/src/@redux/reducers/mapper.ts deleted file mode 100644 index a2cfd2588..000000000 --- a/frontend/src/@redux/reducers/mapper.ts +++ /dev/null @@ -1,112 +0,0 @@ -import { mergeArray } from "../../utilites"; -import { AsyncAction } from "../types"; - -export function updateAsyncState( - action: AsyncAction, - defVal: Readonly -): AsyncState { - if (action.payload.loading) { - return { - updating: true, - data: defVal, - }; - } else if (action.error !== undefined) { - return { - updating: false, - error: action.payload.item as Error, - data: defVal, - }; - } else { - return { - updating: false, - error: undefined, - data: action.payload.item as Payload, - }; - } -} - -export function updateOrderIdState( - action: AsyncAction>, - state: AsyncState>, - id: ItemIdType -): AsyncState> { - if (action.payload.loading) { - return { - ...state, - updating: true, - }; - } else if (action.error !== undefined) { - return { - ...state, - updating: false, - error: action.payload.item as Error, - }; - } else { - const { data, total } = action.payload.item as AsyncDataWrapper; - const [start, length] = action.payload.parameters; - - // Convert item list to object - const idState: IdState = data.reduce>((prev, curr) => { - const tid = curr[id]; - prev[tid] = curr; - return prev; - }, {}); - - const dataOrder: number[] = data.map((v) => v[id]); - - let newItems = { ...state.data.items, ...idState }; - let newOrder = state.data.order; - - const countDist = total - newOrder.length; - if (countDist > 0) { - newOrder.push(...Array(countDist).fill(null)); - } else if (countDist < 0) { - // Completely drop old data if list has shrinked - newOrder = Array(total).fill(null); - newItems = { 
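Note: every data hook above now follows the same shape: fetch once on mount, then let Socket.IO pushes keep the store fresh. A distilled sketch of that pattern (the helper name is hypothetical; `AsyncState`, `SocketIO.EventType`, and `useSocketIOReducer` are the project's own):

```ts
import { useEffect } from "react";
import { useSocketIOReducer } from "../../@socketio/hooks";

// Hypothetical distillation of useBlacklistMovies, useMoviesHistory, etc.
function useSocketBackedState<T>(
  items: AsyncState<T>,
  update: () => void,
  event: SocketIO.EventType
) {
  useSocketIOReducer(event, update); // server push -> re-fetch
  useEffect(() => {
    update(); // initial fetch on mount
  }, [update]);
  return [items, update] as const;
}
```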
diff --git a/frontend/src/@redux/reducers/mapper.ts b/frontend/src/@redux/reducers/mapper.ts
deleted file mode 100644
index a2cfd2588..000000000
--- a/frontend/src/@redux/reducers/mapper.ts
+++ /dev/null
@@ -1,112 +0,0 @@
-import { mergeArray } from "../../utilites";
-import { AsyncAction } from "../types";
-
-export function updateAsyncState(
-  action: AsyncAction,
-  defVal: Readonly
-): AsyncState {
-  if (action.payload.loading) {
-    return {
-      updating: true,
-      data: defVal,
-    };
-  } else if (action.error !== undefined) {
-    return {
-      updating: false,
-      error: action.payload.item as Error,
-      data: defVal,
-    };
-  } else {
-    return {
-      updating: false,
-      error: undefined,
-      data: action.payload.item as Payload,
-    };
-  }
-}
-
-export function updateOrderIdState(
-  action: AsyncAction<AsyncDataWrapper<T>>,
-  state: AsyncState<OrderIdState<T>>,
-  id: ItemIdType
-): AsyncState<OrderIdState<T>> {
-  if (action.payload.loading) {
-    return {
-      ...state,
-      updating: true,
-    };
-  } else if (action.error !== undefined) {
-    return {
-      ...state,
-      updating: false,
-      error: action.payload.item as Error,
-    };
-  } else {
-    const { data, total } = action.payload.item as AsyncDataWrapper;
-    const [start, length] = action.payload.parameters;
-
-    // Convert item list to object
-    const idState: IdState = data.reduce<IdState<T>>((prev, curr) => {
-      const tid = curr[id];
-      prev[tid] = curr;
-      return prev;
-    }, {});
-
-    const dataOrder: number[] = data.map((v) => v[id]);
-
-    let newItems = { ...state.data.items, ...idState };
-    let newOrder = state.data.order;
-
-    const countDist = total - newOrder.length;
-    if (countDist > 0) {
-      newOrder.push(...Array(countDist).fill(null));
-    } else if (countDist < 0) {
-      // Completely drop old data if list has shrinked
-      newOrder = Array(total).fill(null);
-      newItems = { ...idState };
-    }
-
-    if (typeof start === "number" && typeof length === "number") {
-      newOrder.splice(start, length, ...dataOrder);
-    } else if (start === undefined) {
-      // Full Update
-      newOrder = dataOrder;
-    }
-
-    return {
-      updating: false,
-      data: {
-        items: newItems,
-        order: newOrder,
-      },
-    };
-  }
-}
-
-export function updateAsyncList(
-  action: AsyncAction,
-  state: AsyncState,
-  match: ID
-): AsyncState {
-  if (action.payload.loading) {
-    return {
-      ...state,
-      updating: true,
-    };
-  } else if (action.error !== undefined) {
-    return {
-      ...state,
-      updating: false,
-      error: action.payload.item as Error,
-    };
-  } else {
-    const list = state.data as T[];
-    const payload = action.payload.item as T[];
-    const result = mergeArray(list, payload, (l, r) => l[match] === r[match]);
-
-    return {
-      updating: false,
-      data: result,
-    };
-  }
-}
diff --git a/frontend/src/@redux/reducers/movie.ts b/frontend/src/@redux/reducers/movie.ts
index 2ee2c39f8..2b7fe8df1 100644
--- a/frontend/src/@redux/reducers/movie.ts
+++ b/frontend/src/@redux/reducers/movie.ts
@@ -1,14 +1,19 @@
-import { handleActions } from "redux-actions";
+import { Action, handleActions } from "redux-actions";
 import {
+  MOVIES_DELETE_ITEMS,
+  MOVIES_DELETE_WANTED_ITEMS,
   MOVIES_UPDATE_BLACKLIST,
   MOVIES_UPDATE_HISTORY_LIST,
-  MOVIES_UPDATE_INFO,
-  MOVIES_UPDATE_RANGE,
+  MOVIES_UPDATE_LIST,
   MOVIES_UPDATE_WANTED_LIST,
-  MOVIES_UPDATE_WANTED_RANGE,
 } from "../constants";
 import { AsyncAction } from "../types";
-import { updateAsyncState, updateOrderIdState } from "./mapper";
+import { defaultAOS } from "../utils";
+import {
+  deleteOrderListItemBy,
+  updateAsyncState,
+  updateOrderIdState,
+} from "../utils/mapper";
 
 const reducer = handleActions(
   {
@@ -25,17 +30,10 @@ const reducer = handleActions(
         ),
       };
     },
-    [MOVIES_UPDATE_WANTED_RANGE]: (
-      state,
-      action: AsyncAction<AsyncDataWrapper<Wanted.Movie>>
-    ) => {
+    [MOVIES_DELETE_WANTED_ITEMS]: (state, action: Action) => {
       return {
         ...state,
-        wantedMovieList: updateOrderIdState(
-          action,
-          state.wantedMovieList,
-          "radarrId"
-        ),
+        wantedMovieList: deleteOrderListItemBy(action, state.wantedMovieList),
       };
     },
    [MOVIES_UPDATE_HISTORY_LIST]: (
      state,
      action: AsyncAction
    ) => {
      return {
        ...state,
        historyList: updateAsyncState(action, state.historyList.data),
      };
    },
-    [MOVIES_UPDATE_INFO]: (
+    [MOVIES_UPDATE_LIST]: (
      state,
      action: AsyncAction<AsyncDataWrapper<Item.Movie>>
    ) => {
      return {
        ...state,
        movieList: updateOrderIdState(action, state.movieList, "radarrId"),
      };
    },
-    [MOVIES_UPDATE_RANGE]: (
-      state,
-      action: AsyncAction<AsyncDataWrapper<Item.Movie>>
-    ) => {
+    [MOVIES_DELETE_ITEMS]: (state, action: Action) => {
      return {
        ...state,
-        movieList: updateOrderIdState(action, state.movieList, "radarrId"),
+        movieList: deleteOrderListItemBy(action, state.movieList),
      };
    },
    [MOVIES_UPDATE_BLACKLIST]: (
      state,
      action: AsyncAction
    ) => {
      return {
        ...state,
        blacklist: updateAsyncState(action, state.blacklist.data),
      };
    },
  },
  {
-    movieList: { updating: true, data: { items: {}, order: [] } },
-    wantedMovieList: { updating: true, data: { items: {}, order: [] } },
+    movieList: defaultAOS(),
+    wantedMovieList: defaultAOS(),
    historyList: { updating: true, data: [] },
    blacklist: { updating: true, data: [] },
  }
diff --git a/frontend/src/@redux/reducers/series.ts b/frontend/src/@redux/reducers/series.ts
index 57292a054..23ae4ecd6 100644
--- a/frontend/src/@redux/reducers/series.ts
+++ b/frontend/src/@redux/reducers/series.ts
@@ -1,15 +1,23 @@
-import { handleActions } from "redux-actions";
+import { Action, handleActions } from "redux-actions";
 import {
+  SERIES_DELETE_EPISODES,
+  SERIES_DELETE_ITEMS,
+  SERIES_DELETE_WANTED_ITEMS,
   SERIES_UPDATE_BLACKLIST,
   SERIES_UPDATE_EPISODE_LIST,
   SERIES_UPDATE_HISTORY_LIST,
-  SERIES_UPDATE_INFO,
-  SERIES_UPDATE_RANGE,
+  SERIES_UPDATE_LIST,
   SERIES_UPDATE_WANTED_LIST,
-  SERIES_UPDATE_WANTED_RANGE,
 } from "../constants";
 import { AsyncAction } from "../types";
-import { updateAsyncState, updateOrderIdState } from "./mapper";
+import { defaultAOS } from "../utils";
+import {
+  deleteAsyncListItemBy,
+  deleteOrderListItemBy,
+  updateAsyncList,
+  updateAsyncState,
+  updateOrderIdState,
+} from "../utils/mapper";
 
 const reducer = handleActions(
   {
@@ -26,16 +34,12 @@ const reducer = handleActions(
         ),
       };
     },
-    [SERIES_UPDATE_WANTED_RANGE]: (
-      state,
-      action: AsyncAction<AsyncDataWrapper<Wanted.Episode>>
-    ) => {
+    [SERIES_DELETE_WANTED_ITEMS]: (state, action: Action) => {
      return {
        ...state,
-        wantedEpisodesList: updateOrderIdState(
+        wantedEpisodesList: deleteOrderListItemBy(
          action,
-          state.wantedEpisodesList,
-          "sonarrEpisodeId"
+          state.wantedEpisodesList
        ),
      };
    },
    [SERIES_UPDATE_EPISODE_LIST]: (
      state,
      action: AsyncAction
    ) => {
-      const { updating, error, data: items } = updateAsyncState(action, []);
-
-      const stateItems = { ...state.episodeList.data };
-
-      if (items.length > 0) {
-        const id = items[0].sonarrSeriesId;
-        stateItems[id] = items;
-      }
-
      return {
        ...state,
-        episodeList: {
-          updating,
-          error,
-          data: stateItems,
-        },
+        episodeList: updateAsyncList(
+          action,
+          state.episodeList,
+          "sonarrEpisodeId"
+        ),
+      };
+    },
+    [SERIES_DELETE_EPISODES]: (state, action: Action) => {
+      return {
+        ...state,
+        episodeList: deleteAsyncListItemBy(
+          action,
+          state.episodeList,
+          "sonarrEpisodeId"
+        ),
      };
    },
    [SERIES_UPDATE_HISTORY_LIST]: (
      state,
      action: AsyncAction
    ) => {
      return {
        ...state,
        historyList: updateAsyncState(action, state.historyList.data),
      };
    },
-    [SERIES_UPDATE_INFO]: (
+    [SERIES_UPDATE_LIST]: (
      state,
      action: AsyncAction<AsyncDataWrapper<Item.Series>>
    ) => {
      return {
        ...state,
        seriesList: updateOrderIdState(
          action,
          state.seriesList,
          "sonarrSeriesId"
        ),
      };
    },
-    [SERIES_UPDATE_RANGE]: (
-      state,
-      action: AsyncAction<AsyncDataWrapper<Item.Series>>
-    ) => {
+    [SERIES_DELETE_ITEMS]: (state, action: Action) => {
      return {
        ...state,
-        seriesList: updateOrderIdState(
-          action,
-          state.seriesList,
-          "sonarrSeriesId"
-        ),
+        seriesList: deleteOrderListItemBy(action, state.seriesList),
      };
    },
    [SERIES_UPDATE_BLACKLIST]: (
      state,
      action: AsyncAction
    ) => {
      return {
        ...state,
        blacklist: updateAsyncState(action, state.blacklist.data),
      };
    },
  },
  {
-    seriesList: { updating: true, data: { items: {}, order: [] } },
-    wantedEpisodesList: { updating: true, data: { items: {}, order: [] } },
-    episodeList: { updating: true, data: {} },
+    seriesList: defaultAOS(),
+    wantedEpisodesList: defaultAOS(),
+    episodeList: { updating: true, data: [] },
    historyList: { updating: true, data: [] },
    blacklist: { updating: true, data: [] },
  }
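Note: the reducers now answer socket deletions with `deleteOrderListItemBy` / `deleteAsyncListItemBy` (defined in utils/mapper.ts below). A worked example of the ordered variant, with assumed data:

```ts
// Given: state.data = { items: { 1: A, 2: B, 3: C }, order: [1, 2, 3] }
// Action: payload = [2]
// deleteOrderListItemBy drops id 2 from the item map and the order array:
// result.data = { items: { 1: A, 3: C }, order: [1, 3], fetched: true }
```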
diff --git a/frontend/src/@redux/reducers/site.ts b/frontend/src/@redux/reducers/site.ts
index 78191908a..797b2bf0f 100644
--- a/frontend/src/@redux/reducers/site.ts
+++ b/frontend/src/@redux/reducers/site.ts
@@ -101,6 +101,7 @@ const reducer = handleActions(
       movies: 0,
       episodes: 0,
       providers: 0,
+      status: 0,
     },
     offline: false,
     ...updateLocalStorage(),
diff --git a/frontend/src/@redux/reducers/system.ts b/frontend/src/@redux/reducers/system.ts
index 38b9d32f9..f3498be69 100644
--- a/frontend/src/@redux/reducers/system.ts
+++ b/frontend/src/@redux/reducers/system.ts
@@ -1,16 +1,16 @@
-import { Action, handleActions } from "redux-actions";
+import { handleActions } from "redux-actions";
 import {
-  PROVIDER_UPDATE_LIST,
-  SYSTEM_RUN_TASK,
+  SYSTEM_UPDATE_HEALTH,
   SYSTEM_UPDATE_LANGUAGES_LIST,
   SYSTEM_UPDATE_LANGUAGES_PROFILE_LIST,
   SYSTEM_UPDATE_LOGS,
+  SYSTEM_UPDATE_PROVIDERS,
   SYSTEM_UPDATE_RELEASES,
   SYSTEM_UPDATE_SETTINGS,
   SYSTEM_UPDATE_STATUS,
   SYSTEM_UPDATE_TASKS,
 } from "../constants";
-import { updateAsyncState } from "./mapper";
+import { updateAsyncState } from "../utils/mapper";
 
 const reducer = handleActions(
   {
@@ -46,32 +46,19 @@ const reducer = handleActions(
       ),
     };
   },
+  [SYSTEM_UPDATE_HEALTH]: (state, action) => {
+    return {
+      ...state,
+      health: updateAsyncState(action, state.health.data),
+    };
+  },
   [SYSTEM_UPDATE_TASKS]: (state, action) => {
     return {
       ...state,
       tasks: updateAsyncState<Array<System.Task>>(action, state.tasks.data),
     };
   },
-  [SYSTEM_RUN_TASK]: (state, action: Action) => {
-    const id = action.payload;
-    const tasks = state.tasks;
-    const newItems = [...tasks.data];
-
-    const idx = newItems.findIndex((v) => v.job_id === id);
-
-    if (idx !== -1) {
-      newItems[idx].job_running = true;
-    }
-
-    return {
-      ...state,
-      tasks: {
-        ...tasks,
-        data: newItems,
-      },
-    };
-  },
-  [PROVIDER_UPDATE_LIST]: (state, action) => {
+  [SYSTEM_UPDATE_PROVIDERS]: (state, action) => {
     return {
       ...state,
       providers: updateAsyncState(action, state.providers.data),
@@ -104,6 +91,10 @@ const reducer = handleActions(
       updating: true,
       data: undefined,
     },
+    health: {
+      updating: true,
+      data: [],
+    },
     tasks: {
       updating: true,
       data: [],
diff --git a/frontend/src/@redux/redux.d.ts b/frontend/src/@redux/redux.d.ts
index e7835975f..56365c0bd 100644
--- a/frontend/src/@redux/redux.d.ts
+++ b/frontend/src/@redux/redux.d.ts
@@ -1,12 +1,3 @@
-interface IdState<T> {
-  [key: number]: Readonly<T>;
-}
-
-interface OrderIdState<T> {
-  items: IdState<T>;
-  order: (number | null)[];
-}
-
 interface ReduxStore {
   system: ReduxStore.System;
   series: ReduxStore.Series;
@@ -38,6 +29,7 @@ namespace ReduxStore {
     enabledLanguage: AsyncState<Array<ApiLanguage>>;
     languagesProfiles: AsyncState<Array<Profile.Languages>>;
     status: AsyncState;
+    health: AsyncState<Array<System.Health>>;
     tasks: AsyncState<Array<System.Task>>;
     providers: AsyncState<Array<System.Provider>>;
     logs: AsyncState<Array<System.Log>>;
@@ -46,16 +38,16 @@ namespace ReduxStore {
   }
 
   interface Series {
-    seriesList: AsyncState<OrderIdState<Item.Series>>;
-    wantedEpisodesList: AsyncState<OrderIdState<Wanted.Episode>>;
-    episodeList: AsyncState<Map<number, Item.Episode[]>>;
+    seriesList: AsyncOrderState<Item.Series>;
+    wantedEpisodesList: AsyncOrderState<Wanted.Episode>;
+    episodeList: AsyncState<Item.Episode[]>;
     historyList: AsyncState<Array<History.Episode>>;
     blacklist: AsyncState<Array<Blacklist.Episode>>;
   }
 
   interface Movie {
-    movieList: AsyncState<OrderIdState<Item.Movie>>;
-    wantedMovieList: AsyncState<OrderIdState<Wanted.Movie>>;
+    movieList: AsyncOrderState<Item.Movie>;
+    wantedMovieList: AsyncOrderState<Wanted.Movie>;
     historyList: AsyncState<Array<History.Movie>>;
     blacklist: AsyncState<Array<Blacklist.Movie>>;
   }
diff --git a/frontend/src/@redux/utils/index.ts b/frontend/src/@redux/utils/index.ts
new file mode 100644
index 000000000..61340a902
--- /dev/null
+++ b/frontend/src/@redux/utils/index.ts
@@ -0,0 +1,10 @@
+export function defaultAOS<T>(): AsyncOrderState<T> {
+  return {
+    updating: true,
+    data: {
+      items: [],
+      order: [],
+      fetched: false,
+    },
+  };
+}
diff --git a/frontend/src/@redux/utils/mapper.ts b/frontend/src/@redux/utils/mapper.ts
new file mode 100644
index 000000000..32c25b8aa
--- /dev/null
+++ b/frontend/src/@redux/utils/mapper.ts
@@ -0,0 +1,181 @@
+import { difference, has, isArray, isNull, isNumber, uniqBy } from "lodash";
+import { Action } from "redux-actions";
+import { conditionalLog } from "../../utilites/logger";
+import { AsyncAction } from "../types";
+
+export function updateAsyncState(
+  action: AsyncAction,
+  defVal: Readonly
+): AsyncState {
+  if (action.payload.loading) {
+    return {
+      updating: true,
+      data: defVal,
+    };
+  } else if (action.error !== undefined) {
+    return {
+      updating: false,
+      error: action.payload.item as Error,
+      data: defVal,
+    };
+  } else {
+    return {
+      updating: false,
+      error: undefined,
+      data: action.payload.item as Payload,
+    };
+  }
+}
+
+export function updateOrderIdState(
+  action: AsyncAction<AsyncDataWrapper<T>>,
+  state: AsyncOrderState,
+  id: ItemIdType
+): AsyncOrderState {
+  if (action.payload.loading) {
+    return {
+      data: {
+        ...state.data,
+        fetched: true,
+      },
+      updating: true,
+    };
+  } else if (action.error !== undefined) {
+    return {
+      data: {
+        ...state.data,
+        fetched: true,
+      },
+      updating: false,
+      error: action.payload.item as Error,
+    };
+  } else {
+    const { data, total } = action.payload.item as AsyncDataWrapper;
+    const { parameters } = action.payload;
+    const [start, length] = parameters;
+
+    // Convert item list to object
+    const newItems = data.reduce<IdState<T>>(
+      (prev, curr) => {
+        const tid = curr[id];
+        prev[tid] = curr;
+        return prev;
+      },
+      { ...state.data.items }
+    );
+
+    let newOrder = [...state.data.order];
+
+    const countDist = total - newOrder.length;
+    if (countDist > 0) {
+      newOrder = Array(countDist).fill(null).concat(newOrder);
+    } else if (countDist < 0) {
+      // Completely drop old data if list has shrinked
+      newOrder = Array(total).fill(null);
+    }
+
+    const idList = newOrder.filter(isNumber);
+
+    const dataOrder: number[] = data.map((v) => v[id]);
+
+    if (typeof start === "number" && typeof length === "number") {
+      newOrder.splice(start, length, ...dataOrder);
+    } else if (isArray(start)) {
+      // Find the null values and delete them, insert new values to the front of array
+      const addition = difference(dataOrder, idList);
+      let addCount = addition.length;
+      newOrder.unshift(...addition);
+
+      newOrder = newOrder.flatMap((v) => {
+        if (isNull(v) && addCount > 0) {
+          --addCount;
+          return [];
+        } else {
+          return [v];
+        }
+      }, []);
+
+      conditionalLog(
+        addCount !== 0,
+        "Error when replacing item in OrderIdState"
+      );
+    } else if (parameters.length === 0) {
+      // TODO: Delete me -> Full Update
+      newOrder = dataOrder;
+    }
+
+    return {
+      updating: false,
+      data: {
+        fetched: true,
+        items: newItems,
+        order: newOrder,
+      },
+    };
+  }
+}
+
+export function deleteOrderListItemBy(
+  action: Action,
+  state: AsyncOrderState
+): AsyncOrderState {
+  const ids = action.payload;
+  const { items, order } = state.data;
+  const newItems = { ...items };
+  ids.forEach((v) => {
+    if (has(newItems, v)) {
+      delete newItems[v];
+    }
+  });
+  const newOrder = difference(order, ids);
+  return {
+    ...state,
+    data: {
+      fetched: true,
+      items: newItems,
+      order: newOrder,
+    },
+  };
+}
+
+export function deleteAsyncListItemBy(
+  action: Action,
+  state: AsyncState,
+  match: ItemIdType
+): AsyncState {
+  const ids = new Set(action.payload);
+  const data = [...state.data].filter((v) => !ids.has(v[match]));
+  return {
+    ...state,
+    data,
+  };
+}
+
+export function updateAsyncList(
+  action: AsyncAction,
+  state: AsyncState,
+  match: ID
+): AsyncState {
+  if (action.payload.loading) {
+    return {
+      ...state,
+      updating: true,
+    };
+  } else if (action.error !== undefined) {
+    return {
+      ...state,
+      updating: false,
+      error: action.payload.item as Error,
+    };
+  } else {
+    const olds = state.data as T[];
+    const news = action.payload.item as T[];
+
+    const result = uniqBy([...news, ...olds], match);
+
+    return {
+      updating: false,
+      data: result,
+    };
+  }
+}
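Note: `updateOrderIdState` above keeps a fixed-length `order` array padded with `null` for rows that have not been fetched yet. A worked example of the ranged branch, with assumed numbers:

```ts
// total = 5 on the server, but only ids [10, 11] known locally:
//   pad front -> order = [null, null, null, 10, 11]
// A page fetch with start = 2, length = 2 returning ids [12, 13]:
//   splice    -> order = [null, null, 12, 13, 11]
// The id-based branch instead unshifts genuinely new ids and consumes the
// same number of nulls in the flatMap pass, logging if the counts disagree.
```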
diff --git a/frontend/src/@socketio/hooks.ts b/frontend/src/@socketio/hooks.ts
new file mode 100644
index 000000000..802907ea1
--- /dev/null
+++ b/frontend/src/@socketio/hooks.ts
@@ -0,0 +1,35 @@
+import { useCallback, useEffect, useMemo } from "react";
+import Socketio from ".";
+import { log } from "../utilites/logger";
+
+export function useSocketIOReducer(
+  key: SocketIO.EventType,
+  any?: () => void,
+  update?: SocketIO.ActionFn,
+  remove?: SocketIO.ActionFn
+) {
+  const reducer = useMemo(
+    () => ({ key, any, update, delete: remove }),
+    [key, any, update, remove]
+  );
+  useEffect(() => {
+    Socketio.addReducer(reducer);
+    log("info", "listening to SocketIO event", key);
+    return () => {
+      Socketio.removeReducer(reducer);
+    };
+  }, [reducer, key]);
+}
+
+export function useWrapToOptionalId(
+  fn: (id: number[]) => void
+): SocketIO.ActionFn {
+  return useCallback(
+    (id?: number[]) => {
+      if (id) {
+        fn(id);
+      }
+    },
+    [fn]
+  );
+}
diff --git a/frontend/src/@socketio/index.ts b/frontend/src/@socketio/index.ts
new file mode 100644
index 000000000..5a576e1b2
--- /dev/null
+++ b/frontend/src/@socketio/index.ts
@@ -0,0 +1,123 @@
+import { debounce, forIn, remove, uniq } from "lodash";
+import { io, Socket } from "socket.io-client";
+import { getBaseUrl } from "../utilites";
+import { conditionalLog, log } from "../utilites/logger";
+import { createDefaultReducer } from "./reducer";
+
+class SocketIOClient {
+  private socket: Socket;
+  private events: SocketIO.Event[];
+  private debounceReduce: () => void;
+
+  private reducers: SocketIO.Reducer[];
+
+  constructor() {
+    const baseUrl = getBaseUrl();
+    this.socket = io({
+      path: `${baseUrl}/api/socket.io`,
+      transports: ["polling", "websocket"],
+      upgrade: true,
+      rememberUpgrade: true,
+      autoConnect: false,
+    });
+
+    this.socket.on("connect", this.onConnect.bind(this));
+    this.socket.on("disconnect", this.onDisconnect.bind(this));
+    this.socket.on("connect_error", this.onDisconnect.bind(this));
+    this.socket.on("data", this.onEvent.bind(this));
+
+    this.events = [];
+    this.debounceReduce = debounce(this.reduce, 200);
+    this.reducers = [];
+  }
+
+  initialize() {
+    this.reducers.push(...createDefaultReducer());
+    this.socket.connect();
+
+    // Debug Command
+    window._socketio = {
+      dump: this.dump.bind(this),
+      emit: this.onEvent.bind(this),
+    };
+  }
+
+  private dump() {
+    console.log("SocketIO reducers", this.reducers);
+  }
+
+  addReducer(reducer: SocketIO.Reducer) {
+    this.reducers.push(reducer);
+  }
+
+  removeReducer(reducer: SocketIO.Reducer) {
+    const removed = remove(this.reducers, (r) => r === reducer);
+    conditionalLog(removed.length === 0, "Fail to remove reducer", reducer);
+  }
+
+  private reduce() {
+    const events = [...this.events];
+    this.events = [];
+
+    const records: SocketIO.ActionRecord = {};
+
+    events.forEach((e) => {
+      if (!(e.type in records)) {
+        records[e.type] = {};
+      }
+      const record = records[e.type]!;
+      if (!(e.action in record)) {
+        record[e.action] = [];
+      }
+      if (e.payload) {
+        record[e.action]?.push(e.payload);
+      }
+    });
+
+    forIn(records, (element, type) => {
+      if (element) {
+        const handlers = this.reducers.filter((v) => v.key === type);
+        if (handlers.length === 0) {
+          log("warning", "Unhandle SocketIO event", type);
+          return;
+        }
+
+        // eslint-disable-next-line no-loop-func
+        handlers.forEach((handler) => {
+          const anyAction = handler.any;
+          if (anyAction) {
+            anyAction();
+          }
+
+          forIn(element, (ids, key) => {
+            ids = uniq(ids);
+            const action = handler[key as SocketIO.ActionType];
+            if (action) {
+              action(ids);
+            } else if (anyAction === undefined) {
+              log("warning", "Unhandle action of SocketIO event", key, type);
+            }
+          });
+        });
+      }
+    });
+  }
+
+  private onConnect() {
+    log("info", "Socket.IO has connected");
+    this.onEvent({ type: "connect", action: "update", payload: null });
+  }
+
+  private onDisconnect() {
+    log("warning", "Socket.IO has disconnected");
+    this.onEvent({ type: "disconnect", action: "update", payload: null });
+  }
+
+  private onEvent(event: SocketIO.Event) {
+    log("info", "Socket.IO receives", event);
+    this.events.push(event);
+    this.debounceReduce();
+  }
+}
+
+export default new SocketIOClient();
diff --git a/frontend/src/@socketio/reducer.ts b/frontend/src/@socketio/reducer.ts
new file mode 100644
index 000000000..0507f89e2
--- /dev/null
+++ b/frontend/src/@socketio/reducer.ts
@@ -0,0 +1,55 @@
+import {
+  badgeUpdateAll,
+  bootstrap,
+  movieDeleteItems,
+  movieUpdateList,
+  seriesDeleteItems,
+  seriesUpdateList,
+  siteUpdateOffline,
+  systemUpdateLanguagesAll,
+  systemUpdateSettings,
+} from "../@redux/actions";
+import reduxStore from "../@redux/store";
+
+function bindToReduxStore(fn: (ids?: number[]) => any): SocketIO.ActionFn {
+  return (ids?: number[]) => reduxStore.dispatch(fn(ids));
+}
+
+export function createDefaultReducer(): SocketIO.Reducer[] {
+  return [
+    {
+      key: "connect",
+      any: () => reduxStore.dispatch(siteUpdateOffline(false)),
+    },
+    {
+      key: "connect",
+      any: () => reduxStore.dispatch(bootstrap()),
+    },
+    {
+      key: "disconnect",
+      any: () => reduxStore.dispatch(siteUpdateOffline(true)),
+    },
+    {
+      key: "series",
+      update: bindToReduxStore(seriesUpdateList),
+      delete: bindToReduxStore(seriesDeleteItems),
+    },
+    {
+      key: "movie",
+      update: bindToReduxStore(movieUpdateList),
+      delete: bindToReduxStore(movieDeleteItems),
+    },
+    {
+      key: "settings",
+      any: bindToReduxStore(systemUpdateSettings),
+    },
+    {
+      key: "languages",
+      any: bindToReduxStore(systemUpdateLanguagesAll),
+    },
+    {
+      key: "badges",
+      any: bindToReduxStore(badgeUpdateAll),
+    },
+  ];
+}
diff --git a/frontend/src/@types/api.d.ts b/frontend/src/@types/api.d.ts
index 21dbc332f..968237290 100644
--- a/frontend/src/@types/api.d.ts
+++ b/frontend/src/@types/api.d.ts
@@ -4,6 +4,7 @@ interface Badge {
   episodes: number;
   movies: number;
   providers: number;
+  status: number;
 }
 
 interface ApiLanguage {
@@ -40,7 +41,6 @@ interface Subtitle extends Language {
 
 interface PathType {
   path: string;
-  exist: boolean;
 }
 
 interface SubtitlePathType {
diff --git a/frontend/src/@types/basic.d.ts b/frontend/src/@types/basic.d.ts
index cc27a4361..c0fe42c3c 100644
--- a/frontend/src/@types/basic.d.ts
+++ b/frontend/src/@types/basic.d.ts
@@ -11,12 +11,20 @@ type FileTree = {
 
 type StorageType = string | null;
 
+interface OrderIdState<T> {
+  items: IdState<T>;
+  order: (number | null)[];
+  fetched: boolean;
+}
+
 interface AsyncState<T> {
   updating: boolean;
   error?: Error;
   data: Readonly<T>;
 }
 
+type AsyncOrderState<T> = AsyncState<OrderIdState<T>>;
+
 type AsyncPayload<T> = T extends AsyncState<infer D> ? D : never;
 
 type SelectorOption = {
@@ -32,3 +40,5 @@ type SimpleStateType = [
   T,
   ((item: T) => void) | ((fn: (item: T) => T) => void)
 ];
+
+type Factory<T> = () => T;
diff --git a/frontend/src/@types/react-table.d.ts b/frontend/src/@types/react-table.d.ts
index f78cba86a..0c917d200 100644
--- a/frontend/src/@types/react-table.d.ts
+++ b/frontend/src/@types/react-table.d.ts
@@ -32,7 +32,6 @@ import {
   UseSortByState,
 } from "react-table";
 import {} from "../components/tables/plugins";
-import { PageControlAction } from "../components/tables/types";
 
 declare module "react-table" {
   // take this file as-is, or comment out the sections that don't apply to your plugin configuration
@@ -40,17 +39,6 @@ declare module "react-table" {
   // Customize of React Table
   type TableUpdater = (row: Row, ...others: any[]) => void;
 
-  interface useAsyncPaginationProps<D extends Record<string, unknown>> {
-    asyncLoader?: (start: number, length: number) => void;
-    asyncState?: AsyncState<AsyncDataWrapper<D>>;
-    asyncId?: (item: D) => number;
-  }
-
-  interface useAsyncPaginationState<D extends Record<string, unknown>> {
-    pageToLoad?: PageControlAction;
-    needLoadingScreen?: boolean;
-  }
-
   interface useSelectionProps<D extends Record<string, unknown>> {
     isSelecting?: boolean;
     onSelect?: (items: D[]) => void;
@@ -59,15 +47,13 @@ declare module "react-table" {
   interface useSelectionState<D extends Record<string, unknown>> {}
 
   interface CustomTableProps<D extends Record<string, unknown>>
-    extends useSelectionProps,
-      useAsyncPaginationProps {
+    extends useSelectionProps {
     externalUpdate?: TableUpdater;
     loose?: any[];
   }
 
   interface CustomTableState<D extends Record<string, unknown>>
-    extends useSelectionState,
-      useAsyncPaginationState {}
+    extends useSelectionState {}
 
   export interface TableOptions<
     D extends Record
diff --git a/frontend/src/@types/socket.d.ts b/frontend/src/@types/socket.d.ts
new file mode 100644
index 000000000..47d28d736
--- /dev/null
+++ b/frontend/src/@types/socket.d.ts
@@ -0,0 +1,39 @@
+namespace SocketIO {
+  type EventType =
+    | "connect"
+    | "disconnect"
+    | "movie"
+    | "series"
+    | "episode"
+    | "episode-history"
+    | "episode-blacklist"
+    | "episode-wanted"
+    | "movie-history"
+    | "movie-blacklist"
+    | "movie-wanted"
+    | "badges"
+    | "task"
+    | "settings"
+    | "languages"
+    | "message";
+
+  type ActionType = "update" | "delete";
+
+  interface Event {
+    type: EventType;
+    action: ActionType;
+    payload: any; // TODO: Use specific types
+  }
+
+  type ActionFn = (payload?: any[]) => void;
+
+  type Reducer = {
+    key: EventType;
+    any?: () => any;
+  } & Partial<Record<ActionType, ActionFn>>;
+
+  type ActionRecord = OptionalRecord<
+    EventType,
+    OptionalRecord<ActionType, any[]>
+  >;
+}
diff --git a/frontend/src/@types/system.d.ts b/frontend/src/@types/system.d.ts
index cd79892b2..66bbab8cc 100644
--- a/frontend/src/@types/system.d.ts
+++ b/frontend/src/@types/system.d.ts
@@ -18,6 +18,11 @@ namespace System {
     sonarr_version: string;
   }
 
+  interface Health {
+    object: string;
+    issue: string;
+  }
+
   interface Provider {
     name: string;
     status: string;
diff --git a/frontend/src/@types/utilities.d.ts b/frontend/src/@types/utilities.d.ts
index 1c2babf04..023ddfebf 100644
--- a/frontend/src/@types/utilities.d.ts
+++ b/frontend/src/@types/utilities.d.ts
@@ -37,3 +37,9 @@ type KeysOfType = NonNullable<
 >;
 
 type ItemIdType = KeysOfType;
+
+type OptionalRecord<T extends string, D> = { [P in T]?: D };
+
+interface IdState<T> {
+  [key: number]: Readonly<T>;
+}
diff --git a/frontend/src/@types/window.d.ts b/frontend/src/@types/window.d.ts
index ab5f39761..8f421d90b 100644
--- a/frontend/src/@types/window.d.ts
+++ b/frontend/src/@types/window.d.ts
@@ -1,6 +1,12 @@
+interface SocketIODebugger {
+  dump: () => void;
+  emit: (event: SocketIO.Event) => void;
+}
+
 declare global {
   interface Window {
     Bazarr: BazarrServer;
+    _socketio: SocketIODebugger;
   }
 }
diff --git a/frontend/src/App/Header.tsx b/frontend/src/App/Header.tsx
index 27a6c5c1f..47b4263b1 100644
--- a/frontend/src/App/Header.tsx
+++ b/frontend/src/App/Header.tsx
@@ -5,13 +5,7 @@ import {
   faUser,
 } from "@fortawesome/free-solid-svg-icons";
 import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
-import React, {
-  FunctionComponent,
-  useCallback,
-  useContext,
-  useMemo,
-  useState,
-} from "react";
+import React, { FunctionComponent, useContext, useMemo } from "react";
 import {
   Button,
   Col,
@@ -100,12 +94,6 @@ const Header: FunctionComponent = () => {
     [canLogout, setNeedAuth]
   );
 
-  const [reconnecting, setReconnect] = useState(false);
-  const reconnect = useCallback(() => {
-    setReconnect(true);
-    SystemApi.status().finally(() => setReconnect(false));
-  }, []);
-
   const goHome = useGotoHomepage();
 
   return (
@@ -137,13 +125,13 @@ const Header: FunctionComponent = () => {
         {offline ? (
-            Reconnect
+            Connecting...
         ) : (
           dropdown
         )}
diff --git a/frontend/src/App/index.tsx b/frontend/src/App/index.tsx
index 2523e6911..d98ff2388 100644
--- a/frontend/src/App/index.tsx
+++ b/frontend/src/App/index.tsx
@@ -6,8 +6,7 @@ import React, {
 } from "react";
 import { Row } from "react-bootstrap";
 import { Redirect } from "react-router-dom";
-import { bootstrap as ReduxBootstrap } from "../@redux/actions";
-import { useReduxAction, useReduxStore } from "../@redux/hooks/base";
+import { useReduxStore } from "../@redux/hooks/base";
 import { useNotification } from "../@redux/hooks/site";
 import { LoadingIndicator, ModalProvider } from "../components";
 import Sidebar from "../Sidebar";
@@ -24,8 +23,6 @@ export const SidebarToggleContext = React.createContext<() => void>(() => {});
 interface Props {}
 
 const App: FunctionComponent = () => {
-  const bootstrap = useReduxAction(ReduxBootstrap);
-
   const { initialized, auth } = useReduxStore((s) => s.site);
 
   const notify = useNotification("has-update", 10);
@@ -44,10 +41,6 @@ const App: FunctionComponent = () => {
     }
   }, [initialized, hasUpdate, notify]);
 
-  useEffect(() => {
-    bootstrap();
-  }, [bootstrap]);
-
   const [sidebar, setSidebar] = useState(false);
   const toggleSidebar = useCallback(() => setSidebar(!sidebar), [sidebar]);
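Note: with the App-mount `bootstrap()` call removed, startup and reconnection are both driven by the socket lifecycle registered in `@socketio/reducer.ts` above:

```ts
// "connect"    -> siteUpdateOffline(false) and bootstrap()
//                 (settings, languages, and badges re-sync on every reconnect)
// "disconnect" -> siteUpdateOffline(true); the header now simply shows
//                 "Connecting..." instead of the old manual Reconnect button.
```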
diff --git a/frontend/src/App/notifications/index.tsx b/frontend/src/App/notifications/index.tsx
index 5f99613fe..dff37dd76 100644
--- a/frontend/src/App/notifications/index.tsx
+++ b/frontend/src/App/notifications/index.tsx
@@ -3,22 +3,14 @@ import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
 import { capitalize } from "lodash";
 import React, { FunctionComponent, useCallback, useMemo } from "react";
 import { Toast } from "react-bootstrap";
-import { siteRemoveError } from "../../@redux/actions";
+import { siteRemoveNotification } from "../../@redux/actions";
 import { useReduxAction, useReduxStore } from "../../@redux/hooks/base";
 import "./style.scss";
 
-function useNotificationList() {
-  return useReduxStore((s) => s.site.notifications);
-}
-
-function useRemoveNotification() {
-  return useReduxAction(siteRemoveError);
-}
-
 export interface NotificationContainerProps {}
 
 const NotificationContainer: FunctionComponent = () => {
-  const list = useNotificationList();
+  const list = useReduxStore((s) => s.site.notifications);
 
   const items = useMemo(
     () =>
@@ -38,7 +30,7 @@ type MessageHolderProps = ReduxStore.Notification & {};
 const NotificationToast: FunctionComponent = (props) => {
   const { message, id, type } = props;
 
-  const removeNotification = useRemoveNotification();
+  const removeNotification = useReduxAction(siteRemoveNotification);
 
   const remove = useCallback(() => removeNotification(id), [
     removeNotification,
diff --git a/frontend/src/Blacklist/Movies/index.tsx b/frontend/src/Blacklist/Movies/index.tsx
index 88e169807..f8913a3e2 100644
--- a/frontend/src/Blacklist/Movies/index.tsx
+++ b/frontend/src/Blacklist/Movies/index.tsx
@@ -5,17 +5,15 @@ import { Helmet } from "react-helmet";
 import { useBlacklistMovies } from "../../@redux/hooks";
 import { MoviesApi } from "../../apis";
 import { AsyncStateOverlay, ContentHeader } from "../../components";
-import { useAutoUpdate } from "../../utilites/hooks";
 import Table from "./table";
 
 interface Props {}
 
 const BlacklistMoviesView: FunctionComponent = () => {
-  const [blacklist, update] = useBlacklistMovies();
-  useAutoUpdate(update);
+  const [blacklist] = useBlacklistMovies();
   return (
-      {(data) => (
+      {({ data }) => (
           Movies Blacklist - Bazarr
@@ -25,13 +23,12 @@
             icon={faTrash}
             disabled={data.length === 0}
             promise={() => MoviesApi.deleteBlacklist(true)}
-            onSuccess={update}
           >
             Remove All
-
+
       )}
diff --git a/frontend/src/Blacklist/Movies/table.tsx b/frontend/src/Blacklist/Movies/table.tsx
index b2b8dbaf0..4c642d344 100644
--- a/frontend/src/Blacklist/Movies/table.tsx
+++ b/frontend/src/Blacklist/Movies/table.tsx
@@ -13,10 +13,9 @@ import {
 
 interface Props {
   blacklist: readonly Blacklist.Movie[];
-  update: () => void;
 }
 
-const Table: FunctionComponent = ({ blacklist, update }) => {
+const Table: FunctionComponent = ({ blacklist }) => {
   const columns = useMemo<Column<Blacklist.Movie>[]>(
     () => [
       {
@@ -78,7 +77,6 @@ const Table: FunctionComponent = ({ blacklist, update }) => {
                 subs_id,
               })
             }
-            onSuccess={update}
           >
 
@@ -86,7 +84,7 @@ const Table: FunctionComponent = ({ blacklist, update }) => {
         },
       },
     ],
-    [update]
+    []
   );
 
   return (
diff --git a/frontend/src/Blacklist/Router.tsx b/frontend/src/Blacklist/Router.tsx
 {
   const sonarr = useIsSonarrEnabled();
   const radarr = useIsRadarrEnabled();
+
+  useSetSidebar("Blacklist");
   return (
     {sonarr && (
diff --git a/frontend/src/Blacklist/Series/index.tsx b/frontend/src/Blacklist/Series/index.tsx
index df2c0eecc..bde82a1a8 100644
--- a/frontend/src/Blacklist/Series/index.tsx
+++ b/frontend/src/Blacklist/Series/index.tsx
@@ -5,17 +5,15 @@ import { Helmet } from "react-helmet";
 import { useBlacklistSeries } from "../../@redux/hooks";
 import { EpisodesApi } from "../../apis";
 import { AsyncStateOverlay, ContentHeader } from "../../components";
-import { useAutoUpdate } from "../../utilites";
 import Table from "./table";
 
 interface Props {}
 
 const BlacklistSeriesView: FunctionComponent = () => {
-  const [blacklist, update] = useBlacklistSeries();
-  useAutoUpdate(update);
+  const [blacklist] = useBlacklistSeries();
   return (
-      {(data) => (
+      {({ data }) => (
           Series Blacklist - Bazarr
@@ -25,13 +23,12 @@
             icon={faTrash}
             disabled={data.length === 0}
             promise={() => EpisodesApi.deleteBlacklist(true)}
-            onSuccess={update}
           >
             Remove All
-
+
       )}
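Note: the dropped `onSuccess={update}` props are not an oversight. The assumption in these views is that the backend emits a matching event after each mutation, so the hook's socket subscription refreshes the table. The presumed round trip:

```ts
// 1. UI calls MoviesApi.deleteBlacklist(true)          (plain HTTP request)
// 2. Server processes it and emits
//      { type: "movie-blacklist", action: "update" }
// 3. useBlacklistMovies() registered
//      useSocketIOReducer("movie-blacklist", update)
//    so the list re-fetches and the table re-renders without a callback.
```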
diff --git a/frontend/src/Blacklist/Series/table.tsx b/frontend/src/Blacklist/Series/table.tsx
index 4cbd2bee5..6448389a6 100644
--- a/frontend/src/Blacklist/Series/table.tsx
+++ b/frontend/src/Blacklist/Series/table.tsx
@@ -13,10 +13,9 @@ import {
 
 interface Props {
   blacklist: readonly Blacklist.Episode[];
-  update: () => void;
 }
 
-const Table: FunctionComponent = ({ blacklist, update }) => {
+const Table: FunctionComponent = ({ blacklist }) => {
   const columns = useMemo<Column<Blacklist.Episode>[]>(
     () => [
       {
@@ -84,7 +83,6 @@ const Table: FunctionComponent = ({ blacklist, update }) => {
                 subs_id,
               })
             }
-            onSuccess={update}
           >
 
@@ -92,7 +90,7 @@ const Table: FunctionComponent = ({ blacklist, update }) => {
         },
      },
    ],
-    [update]
+    []
   );
 
   return (
diff --git a/frontend/src/History/Movies/index.tsx b/frontend/src/History/Movies/index.tsx
 = () => {
-  const [movies, update] = useMoviesHistory();
-  useAutoUpdate(update);
-
-  const tableUpdate = useCallback((row: Row) => update(), [
-    update,
-  ]);
+  const [movies] = useMoviesHistory();
 
   const columns: Column[] = useMemo<Column<History.Movie>[]>(
     () => [
@@ -114,12 +108,11 @@
     },
     {
       accessor: "blacklisted",
-      Cell: ({ row, externalUpdate }) => {
+      Cell: ({ row }) => {
         const original = row.original;
         return (
-            onAdd={() => externalUpdate && externalUpdate(row)}
             promise={(form) =>
               MoviesApi.addBlacklist(original.radarrId, form)
             }
@@ -136,7 +129,6 @@
       type="movies"
       state={movies}
       columns={columns as Column[]}
-      tableUpdater={tableUpdate}
    >
   );
 };
diff --git a/frontend/src/History/Router.tsx b/frontend/src/History/Router.tsx
index 0ea21c546..b7693355f 100644
--- a/frontend/src/History/Router.tsx
+++ b/frontend/src/History/Router.tsx
@@ -1,6 +1,10 @@
 import React, { FunctionComponent } from "react";
 import { Redirect, Route, Switch } from "react-router-dom";
-import { useIsRadarrEnabled, useIsSonarrEnabled } from "../@redux/hooks/site";
+import {
+  useIsRadarrEnabled,
+  useIsSonarrEnabled,
+  useSetSidebar,
+} from "../@redux/hooks/site";
 import { RouterEmptyPath } from "../special-pages/404";
 import MoviesHistory from "./Movies";
 import SeriesHistory from "./Series";
@@ -9,6 +13,8 @@ import HistoryStats from "./Statistics";
 const Router: FunctionComponent = () => {
   const sonarr = useIsSonarrEnabled();
   const radarr = useIsRadarrEnabled();
+
+  useSetSidebar("History");
   return (
     {sonarr && (
diff --git a/frontend/src/History/Series/index.tsx b/frontend/src/History/Series/index.tsx
index 02dbda4f1..d29cbc656 100644
--- a/frontend/src/History/Series/index.tsx
+++ b/frontend/src/History/Series/index.tsx
@@ -1,25 +1,19 @@
 import { faInfoCircle, faRecycle } from "@fortawesome/free-solid-svg-icons";
 import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
-import React, { FunctionComponent, useCallback, useMemo } from "react";
+import React, { FunctionComponent, useMemo } from "react";
 import { Badge, OverlayTrigger, Popover } from "react-bootstrap";
 import { Link } from "react-router-dom";
-import { Column, Row } from "react-table";
+import { Column } from "react-table";
 import { useSeriesHistory } from "../../@redux/hooks";
 import { EpisodesApi } from "../../apis";
 import { HistoryIcon, LanguageText, TextPopover } from "../../components";
 import { BlacklistButton } from "../../generic/blacklist";
-import { useAutoUpdate } from "../../utilites/hooks";
 import HistoryGenericView from "../generic";
 
 interface Props {}
 
 const SeriesHistoryView: FunctionComponent = () => {
-  const [series, update] = useSeriesHistory();
-  useAutoUpdate(update);
-
-  const tableUpdate = useCallback((row: Row) => update(), [
-    update,
-  ]);
+  const [series] = useSeriesHistory();
 
   const columns: Column[] = useMemo<Column<History.Episode>[]>(
     () => [
@@ -121,14 +115,13 @@
     },
     {
       accessor: "blacklisted",
-      Cell: ({ row, externalUpdate }) => {
+      Cell: ({ row }) => {
         const original = row.original;
 
         const { sonarrEpisodeId, sonarrSeriesId } = original;
         return (
-            onAdd={() => externalUpdate && externalUpdate(row)}
             promise={(form) =>
               EpisodesApi.addBlacklist(sonarrSeriesId, sonarrEpisodeId, form)
             }
@@ -145,7 +138,6 @@
       type="series"
       state={series}
       columns={columns as Column[]}
-      tableUpdater={tableUpdate}
    >
   );
 };
diff --git a/frontend/src/History/Statistics/index.tsx b/frontend/src/History/Statistics/index.tsx
index f072464f1..f8ecc986e 100644
--- a/frontend/src/History/Statistics/index.tsx
+++ b/frontend/src/History/Statistics/index.tsx
@@ -12,7 +12,7 @@ import {
   XAxis,
   YAxis,
 } from "recharts";
-import { useLanguages, useProviders } from "../../@redux/hooks";
+import { useLanguages, useSystemProviders } from "../../@redux/hooks";
 import { HistoryApi } from "../../apis";
 import {
   AsyncSelector,
@@ -21,7 +21,6 @@ import {
   PromiseOverlay,
   Selector,
 } from "../../components";
-import { useAutoUpdate } from "../../utilites/hooks";
 import { actionOptions, timeframeOptions } from "./options";
 
 function converter(item: History.Stat) {
@@ -48,8 +47,7 @@ const SelectorContainer: FunctionComponent = ({ children }) => (
 
 const HistoryStats: FunctionComponent = () => {
   const [languages] = useLanguages(true);
-  const [providerList, update] = useProviders();
-  useAutoUpdate(update);
+  const [providerList] = useSystemProviders();
 
   const [timeframe, setTimeframe] = useState("month");
   const [action, setAction] = useState(null);
diff --git a/frontend/src/History/generic/index.tsx b/frontend/src/History/generic/index.tsx
index 91bf5e760..4b14c486e 100644
--- a/frontend/src/History/generic/index.tsx
+++ b/frontend/src/History/generic/index.tsx
@@ -2,21 +2,19 @@ import { capitalize } from "lodash";
 import React, { FunctionComponent } from "react";
 import { Container, Row } from "react-bootstrap";
 import { Helmet } from "react-helmet";
-import { Column, TableUpdater } from "react-table";
+import { Column } from "react-table";
 import { AsyncStateOverlay, PageTable } from "../../components";
 
 interface Props {
   type: "movies" | "series";
   state: Readonly;
   columns: Column[];
-  tableUpdater?: TableUpdater;
 }
 
 const HistoryGenericView: FunctionComponent = ({
   state,
   columns,
   type,
-  tableUpdater,
 }) => {
   const typeName = capitalize(type);
   return (
@@ -26,12 +24,11 @@ const HistoryGenericView: FunctionComponent = ({
 
-      {(data) => (
+      {({ data }) => (
       )}
diff --git a/frontend/src/Movies/Detail/index.tsx b/frontend/src/Movies/Detail/index.tsx
index 41dfeb834..8cc6a9239 100644
--- a/frontend/src/Movies/Detail/index.tsx
+++ b/frontend/src/Movies/Detail/index.tsx
@@ -25,7 +25,7 @@ import {
 import { ManualSearchModal } from "../../components/modals/ManualSearchModal";
 import ItemOverview from "../../generic/ItemOverview";
 import { RouterEmptyPath } from "../../special-pages/404";
-import { useAutoUpdate, useWhenLoadingFinish } from "../../utilites";
+import { useWhenLoadingFinish } from "../../utilites";
 import Table from "./table";
 
 const download = (item: any, result: SearchResultType) => {
@@ -48,8 +48,7 @@ interface Props extends RouteComponentProps {}
 
 const MovieDetailView: FunctionComponent = ({ match }) => {
   const id = Number.parseInt(match.params.id);
-  const [movie, update] = useMovieBy(id);
-  useAutoUpdate(update);
+  const [movie] = useMovieBy(id);
 
   const item = movie.data;
 
   const showModal = useShowModal();
@@ -86,7 +85,6 @@
             promise={() =>
               MoviesApi.action({ action: "scan-disk", radarrid: item.radarrId })
             }
-            onSuccess={update}
           >
             Scan Disk
@@ -99,7 +97,6 @@
                 radarrid: item.radarrId,
               })
             }
-            onSuccess={update}
           >
             Search
@@ -144,23 +141,17 @@
 
-
+
           MoviesApi.modify(form)}
-          onSuccess={update}
         >
-
+
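Note: the detail tables below filter embedded subtitles down to the languages of the item's profile when the `embedded_subs_show_desired` setting is on (surfaced by `useShowOnlyDesired`). The core of that filter, extracted for clarity; `subtitles` and `profileItems` stand in for the real props:

```ts
import { intersectionWith } from "lodash";

declare const subtitles: { code2: string }[];    // item's embedded subtitles
declare const profileItems: { code2: string }[]; // languages in the profile

// Keep a subtitle only if its two-letter language code is in the profile.
const visible = intersectionWith(
  subtitles,
  profileItems,
  (l, r) => l.code2 === r.code2
);
```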
diff --git a/frontend/src/Movies/Detail/table.tsx b/frontend/src/Movies/Detail/table.tsx
index f1d6eb51c..a50b9c82b 100644
--- a/frontend/src/Movies/Detail/table.tsx
+++ b/frontend/src/Movies/Detail/table.tsx
@@ -1,8 +1,11 @@
 import { faSearch, faTrash } from "@fortawesome/free-solid-svg-icons";
 import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
+import { intersectionWith } from "lodash";
 import React, { FunctionComponent, useMemo } from "react";
 import { Badge } from "react-bootstrap";
 import { Column } from "react-table";
+import { useProfileItems } from "../../@redux/hooks";
+import { useShowOnlyDesired } from "../../@redux/hooks/site";
 import { MoviesApi } from "../../apis";
 import { AsyncButton, LanguageText, SimpleTable } from "../../components";
 
@@ -10,11 +13,13 @@ const missingText = "Missing Subtitles";
 
 interface Props {
   movie: Item.Movie;
-  update: (id: number) => void;
+  profile?: Profile.Languages;
 }
 
-const Table: FunctionComponent = (props) => {
-  const { movie, update } = props;
+const Table: FunctionComponent = ({ movie, profile }) => {
+  const onlyDesired = useShowOnlyDesired();
+
+  const profileItems = useProfileItems(profile);
 
   const columns: Column[] = useMemo<Column<Subtitle>[]>(
     () => [
@@ -66,7 +71,6 @@
                 forced: original.forced,
               })
             }
-            onSuccess={() => update(movie.radarrId)}
             variant="light"
             size="sm"
           >
@@ -86,7 +90,6 @@
                 path: original.path ?? "",
               })
             }
-            onSuccess={() => update(movie.radarrId)}
           >
 
@@ -95,7 +98,7 @@
         },
      },
    ],
-    [movie, update]
+    [movie]
  );
 
   const data: Subtitle[] = useMemo(() => {
      return item;
    });
 
-    return movie.subtitles.concat(missing);
-  }, [movie.missing_subtitles, movie.subtitles]);
+    let raw_subtitles = movie.subtitles;
+    if (onlyDesired) {
+      raw_subtitles = intersectionWith(
+        raw_subtitles,
+        profileItems,
+        (l, r) => l.code2 === r.code2
+      );
+    }
+
+    return [...raw_subtitles, ...missing];
+  }, [movie.missing_subtitles, movie.subtitles, onlyDesired, profileItems]);
 
   return (
diff --git a/frontend/src/Movies/index.tsx b/frontend/src/Movies/index.tsx
 = () => {
        }
      },
    },
-    {
-      Header: "Exist",
-      accessor: "exist",
-      selectHide: true,
-      Cell: ({ row, value }) => {
-        const exist = value;
-        const { path } = row.original;
-        return (
-        );
-      },
-    },
    {
      Header: "Audio",
      accessor: "audio_language",
@@ -133,8 +113,8 @@
      state={movies}
      name="Movies"
      loader={load}
-      updateAction={movieUpdateInfoAll}
-      columns={columns as Column[]}
+      updateAction={movieUpdateList}
+      columns={columns}
      modify={(form) => MoviesApi.modify(form)}
    >
  );
 };
diff --git a/frontend/src/Series/Episodes/components.tsx b/frontend/src/Series/Episodes/components.tsx
index ef4ca43b3..7f9c34bfb 100644
--- a/frontend/src/Series/Episodes/components.tsx
+++ b/frontend/src/Series/Episodes/components.tsx
@@ -2,7 +2,6 @@ import { faSearch, faTrash } from "@fortawesome/free-solid-svg-icons";
 import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
 import React, { FunctionComponent } from "react";
 import { Badge } from "react-bootstrap";
-import { useSerieBy } from "../../@redux/hooks";
 import { EpisodesApi } from "../../apis";
 import { AsyncButton, LanguageText } from "../../components";
 
@@ -21,8 +20,6 @@ export const SubtitleAction: FunctionComponent = ({
 }) => {
   const { hi, forced } = subtitle;
 
-  const [, update] = useSerieBy(seriesid);
-
   const path = subtitle.path;
 
   if (missing || path) {
@@ -46,7 +43,6 @@
         return null;
       }
     }}
-    onSuccess={update}
     as={Badge}
     className="mr-1"
     variant={missing ? "primary" : "secondary"}
diff --git a/frontend/src/Series/Episodes/index.tsx b/frontend/src/Series/Episodes/index.tsx
index 866c1c67a..aebbefac8 100644
--- a/frontend/src/Series/Episodes/index.tsx
+++ b/frontend/src/Series/Episodes/index.tsx
@@ -16,7 +16,7 @@ import React, {
 import { Container, Row } from "react-bootstrap";
 import { Helmet } from "react-helmet";
 import { Redirect, RouteComponentProps, withRouter } from "react-router-dom";
-import { useEpisodesBy, useSerieBy } from "../../@redux/hooks";
+import { useEpisodesBy, useProfileBy, useSerieBy } from "../../@redux/hooks";
 import { SeriesApi } from "../../apis";
 import {
   ContentHeader,
@@ -27,7 +27,7 @@ import {
 } from "../../components";
 import ItemOverview from "../../generic/ItemOverview";
 import { RouterEmptyPath } from "../../special-pages/404";
-import { useAutoUpdate, useWhenLoadingFinish } from "../../utilites";
+import { useWhenLoadingFinish } from "../../utilites";
 import Table from "./table";
 
 interface Params {
@@ -39,13 +39,11 @@ interface Props extends RouteComponentProps {}
 const SeriesEpisodesView: FunctionComponent = (props) => {
   const { match } = props;
   const id = Number.parseInt(match.params.id);
-  const [serie, update] = useSerieBy(id);
+  const [serie] = useSerieBy(id);
   const item = serie.data;
 
   const [episodes] = useEpisodesBy(serie.data?.sonarrSeriesId);
 
-  useAutoUpdate(update);
-
   const available = episodes.data.length !== 0;
 
   const details = useMemo(
@@ -74,6 +72,8 @@
 
   useWhenLoadingFinish(serie, validator);
 
+  const profile = useProfileBy(serie.data?.profileId);
+
   if (isNaN(id) || !valid) {
     return ;
   }
@@ -95,7 +95,6 @@
             promise={() =>
               SeriesApi.action({ action: "scan-disk", seriesid: id })
             }
-            onSuccess={update}
           >
             Scan Disk
@@ -104,7 +103,6 @@
             promise={() =>
               SeriesApi.action({ action: "search-missing", seriesid: id })
             }
-            onSuccess={update}
             disabled={
               item.episodeFileCount === 0 ||
               item.profileId === null ||
@@ -145,14 +143,16 @@
 
-
+
           SeriesApi.modify(form)}
-          onSuccess={update}
         >
-
+
+
   );
 };
SeriesApi.modify(form)} - onSuccess={update} > - + ); }; diff --git a/frontend/src/Series/Episodes/table.tsx b/frontend/src/Series/Episodes/table.tsx index 10ad8fadf..ad2df7596 100644 --- a/frontend/src/Series/Episodes/table.tsx +++ b/frontend/src/Series/Episodes/table.tsx @@ -6,10 +6,12 @@ import { faUser, } from "@fortawesome/free-solid-svg-icons"; import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; +import { intersectionWith } from "lodash"; import React, { FunctionComponent, useCallback, useMemo } from "react"; import { Badge, ButtonGroup } from "react-bootstrap"; -import { Column, TableOptions, TableUpdater } from "react-table"; -import { useSerieBy } from "../../@redux/hooks"; +import { Column, TableUpdater } from "react-table"; +import { useProfileItems, useSerieBy } from "../../@redux/hooks"; +import { useShowOnlyDesired } from "../../@redux/hooks/site"; import { ProvidersApi } from "../../apis"; import { ActionButton, @@ -26,7 +28,7 @@ import { SubtitleAction } from "./components"; interface Props { episodes: AsyncState; - update: () => void; + profile?: Profile.Languages; } const download = (item: any, result: SearchResultType) => { @@ -45,9 +47,13 @@ const download = (item: any, result: SearchResultType) => { ); }; -const Table: FunctionComponent = ({ episodes, update }) => { +const Table: FunctionComponent = ({ episodes, profile }) => { const showModal = useShowModal(); + const onlyDesired = useShowOnlyDesired(); + + const profileItems = useProfileItems(profile); + const columns: Column[] = useMemo[]>( () => [ { @@ -113,7 +119,16 @@ const Table: FunctionComponent = ({ episodes, update }) => { > )); - const subtitles = episode.subtitles.map((val, idx) => ( + let raw_subtitles = episode.subtitles; + if (onlyDesired) { + raw_subtitles = intersectionWith( + raw_subtitles, + profileItems, + (l, r) => l.code2 === r.code2 + ); + } + + const subtitles = raw_subtitles.map((val, idx) => ( = ({ episodes, update }) => { }, }, ], - [] + [onlyDesired, profileItems] ); const updateRow = useCallback>( @@ -183,43 +198,32 @@ const Table: FunctionComponent = ({ episodes, update }) => { [episodes] ); - const options: TableOptions = useMemo(() => { - return { - columns, - data: episodes.data, - externalUpdate: updateRow, - initialState: { - sortBy: [ - { id: "season", desc: true }, - { id: "episode", desc: true }, - ], - groupBy: ["season"], - expanded: { - [`season:${maxSeason}`]: true, - }, - }, - }; - }, [episodes, columns, maxSeason, updateRow]); - return ( - {() => ( + {({ data }) => ( )} - + diff --git a/frontend/src/Series/index.tsx b/frontend/src/Series/index.tsx index a01b8c06a..86757805b 100644 --- a/frontend/src/Series/index.tsx +++ b/frontend/src/Series/index.tsx @@ -1,14 +1,9 @@ -import { - faCheck, - faExclamationTriangle, - faWrench, -} from "@fortawesome/free-solid-svg-icons"; -import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; +import { faWrench } from "@fortawesome/free-solid-svg-icons"; import React, { FunctionComponent, useMemo } from "react"; import { Badge, ProgressBar } from "react-bootstrap"; import { Link } from "react-router-dom"; import { Column } from "react-table"; -import { seriesUpdateByRange, seriesUpdateInfoAll } from "../@redux/actions"; +import { seriesUpdateByRange, seriesUpdateList } from "../@redux/actions"; import { useRawSeries } from "../@redux/hooks"; import { useReduxAction } from "../@redux/hooks/base"; import { SeriesApi } from "../apis"; @@ -40,21 +35,6 @@ const SeriesView: FunctionComponent = () => { } }, }, - { - Header: 
"Exist", - accessor: "exist", - selectHide: true, - Cell: (row) => { - const exist = row.value; - const { path } = row.row.original; - return ( - - ); - }, - }, { Header: "Audio", accessor: "audio_language", @@ -138,9 +118,9 @@ const SeriesView: FunctionComponent = () => { []} + columns={columns} modify={(form) => SeriesApi.modify(form)} > ); diff --git a/frontend/src/Settings/Notifications/components.tsx b/frontend/src/Settings/Notifications/components.tsx index 9a71a991b..81a8a578c 100644 --- a/frontend/src/Settings/Notifications/components.tsx +++ b/frontend/src/Settings/Notifications/components.tsx @@ -17,18 +17,13 @@ import { useShowModal, } from "../../components"; import { BuildKey } from "../../utilites"; -import { ColCard, useLatestMergeArray, useUpdateArray } from "../components"; +import { ColCard, useLatestArray, useUpdateArray } from "../components"; import { notificationsKey } from "../keys"; interface ModalProps { selections: readonly Settings.NotificationInfo[]; } -const notificationComparer = ( - one: Settings.NotificationInfo, - another: Settings.NotificationInfo -) => one.name === another.name; - const NotificationModal: FunctionComponent = ({ selections, ...modal @@ -46,7 +41,7 @@ const NotificationModal: FunctionComponent = ({ const update = useUpdateArray( notificationsKey, - notificationComparer + "name" ); const payload = usePayload(modal.modalKey); @@ -158,10 +153,10 @@ const NotificationModal: FunctionComponent = ({ }; export const NotificationView: FunctionComponent = () => { - const notifications = useLatestMergeArray( + const notifications = useLatestArray( notificationsKey, - notificationComparer, - (settings) => settings.notifications.providers + "name", + (s) => s.notifications.providers ); const showModal = useShowModal(); diff --git a/frontend/src/Settings/Router.tsx b/frontend/src/Settings/Router.tsx index ca721967f..e7f4e277c 100644 --- a/frontend/src/Settings/Router.tsx +++ b/frontend/src/Settings/Router.tsx @@ -1,9 +1,9 @@ -import React, { FunctionComponent } from "react"; +import React, { FunctionComponent, useEffect } from "react"; import { Redirect, Route, Switch } from "react-router-dom"; import { systemUpdateSettings } from "../@redux/actions"; import { useReduxAction } from "../@redux/hooks/base"; +import { useSetSidebar } from "../@redux/hooks/site"; import { RouterEmptyPath } from "../special-pages/404"; -import { useAutoUpdate } from "../utilites/hooks"; import General from "./General"; import Languages from "./Languages"; import Notifications from "./Notifications"; @@ -18,8 +18,9 @@ interface Props {} const Router: FunctionComponent = () => { const update = useReduxAction(systemUpdateSettings); - useAutoUpdate(update); + useEffect(() => update, [update]); + useSetSidebar("Settings"); return ( diff --git a/frontend/src/Settings/components/hooks.ts b/frontend/src/Settings/components/hooks.ts index 6579c8c43..819323e3d 100644 --- a/frontend/src/Settings/components/hooks.ts +++ b/frontend/src/Settings/components/hooks.ts @@ -1,8 +1,7 @@ -import { isArray, isEqual } from "lodash"; +import { isArray, uniqBy } from "lodash"; import { useCallback, useContext, useMemo } from "react"; import { useStore } from "react-redux"; import { useSystemSettings } from "../../@redux/hooks"; -import { mergeArray } from "../../utilites"; import { log } from "../../utilites/logger"; import { StagedChangesContext } from "./provider"; @@ -96,40 +95,6 @@ export function useExtract( } } -export function useUpdateArray( - key: string, - compare?: (one: T, another: T) => 
boolean -) { - const update = useSingleUpdate(); - const stagedValue = useStagedValues(); - - if (compare === undefined) { - compare = isEqual; - } - - const staged: T[] = useMemo(() => { - if (key in stagedValue) { - return stagedValue[key]; - } else { - return []; - } - }, [key, stagedValue]); - - return useCallback( - (v: T) => { - const newArray = [...staged]; - const idx = newArray.findIndex((inn) => compare!(inn, v)); - if (idx !== -1) { - newArray[idx] = v; - } else { - newArray.push(v); - } - update(newArray, key); - }, - [compare, staged, key, update] - ); -} - export function useLatest( key: string, validate: ValidateFuncType, @@ -144,19 +109,14 @@ export function useLatest( } } -// Merge Two Array -export function useLatestMergeArray( +export function useLatestArray( key: string, - compare: Comparer, + compare: keyof T, override?: OverrideFuncType ): Readonly> { const extractValue = useExtract(key, isArray, override); const stagedValue = useStagedValues(); - if (compare === undefined) { - compare = isEqual; - } - let staged: T[] | undefined = undefined; if (key in stagedValue) { staged = stagedValue[key]; @@ -164,9 +124,30 @@ export function useLatestMergeArray( return useMemo(() => { if (staged !== undefined && extractValue) { - return mergeArray(extractValue, staged, compare); + return uniqBy([...staged, ...extractValue], compare); } else { return extractValue; } }, [extractValue, staged, compare]); } + +export function useUpdateArray(key: string, compare: keyof T) { + const update = useSingleUpdate(); + const stagedValue = useStagedValues(); + + const staged: T[] = useMemo(() => { + if (key in stagedValue) { + return stagedValue[key]; + } else { + return []; + } + }, [key, stagedValue]); + + return useCallback( + (v: T) => { + const newArray = uniqBy([v, ...staged], compare); + update(newArray, key); + }, + [compare, staged, key, update] + ); +} diff --git a/frontend/src/Settings/components/provider.tsx b/frontend/src/Settings/components/provider.tsx index 52cf5e9bf..580f3dd2f 100644 --- a/frontend/src/Settings/components/provider.tsx +++ b/frontend/src/Settings/components/provider.tsx @@ -10,13 +10,12 @@ import React, { import { Container, Row } from "react-bootstrap"; import { Helmet } from "react-helmet"; import { Prompt } from "react-router"; -import { - siteSaveLocalstorage, - systemUpdateSettingsAll, -} from "../../@redux/actions"; -import { useReduxAction, useReduxActionWith } from "../../@redux/hooks/base"; +import { siteSaveLocalstorage } from "../../@redux/actions"; +import { useSystemSettings } from "../../@redux/hooks"; +import { useReduxAction } from "../../@redux/hooks/base"; import { SystemApi } from "../../apis"; import { ContentHeader } from "../../components"; +import { useWhenLoadingFinish } from "../../utilites"; import { log } from "../../utilites/logger"; import { enabledLanguageKey, @@ -66,17 +65,15 @@ const SettingsProvider: FunctionComponent = (props) => { setUpdating(false); }, []); - const update = useReduxActionWith(systemUpdateSettingsAll, cleanup); + const [settings] = useSystemSettings(); + useWhenLoadingFinish(settings, cleanup); - const saveSettings = useCallback( - (settings: LooseObject) => { - submitHooks(settings); - setUpdating(true); - log("info", "submitting settings", settings); - SystemApi.setSettings(settings).finally(update); - }, - [update] - ); + const saveSettings = useCallback((settings: LooseObject) => { + submitHooks(settings); + setUpdating(true); + log("info", "submitting settings", settings); + 
SystemApi.setSettings(settings); + }, []); const saveLocalStorage = useCallback( (settings: LooseObject) => { diff --git a/frontend/src/Sidebar/index.tsx b/frontend/src/Sidebar/index.tsx index 339ab4768..dd2c19375 100644 --- a/frontend/src/Sidebar/index.tsx +++ b/frontend/src/Sidebar/index.tsx @@ -1,17 +1,10 @@ -import React, { - FunctionComponent, - useContext, - useEffect, - useMemo, -} from "react"; +import React, { FunctionComponent, useContext, useMemo } from "react"; import { Container, Image, ListGroup } from "react-bootstrap"; -import { useHistory } from "react-router-dom"; -import { badgeUpdateAll, siteChangeSidebar } from "../@redux/actions"; -import { useReduxAction, useReduxStore } from "../@redux/hooks/base"; +import { useReduxStore } from "../@redux/hooks/base"; import { useIsRadarrEnabled, useIsSonarrEnabled } from "../@redux/hooks/site"; import logo from "../@static/logo64.png"; import { SidebarToggleContext } from "../App"; -import { useAutoUpdate, useGotoHomepage } from "../utilites/hooks"; +import { useGotoHomepage } from "../utilites/hooks"; import { BadgesContext, CollapseItem, @@ -22,25 +15,16 @@ import { RadarrDisabledKey, SidebarList, SonarrDisabledKey } from "./list"; import "./style.scss"; import { BadgeProvider } from "./types"; -export function useSidebarKey() { - return useReduxStore((s) => s.site.sidebar); -} - -export function useUpdateSidebar() { - return useReduxAction(siteChangeSidebar); -} - interface Props { open?: boolean; } const Sidebar: FunctionComponent = ({ open }) => { - const updateBadges = useReduxAction(badgeUpdateAll); - useAutoUpdate(updateBadges); - const toggle = useContext(SidebarToggleContext); - const { movies, episodes, providers } = useReduxStore((s) => s.site.badges); + const { movies, episodes, providers, status } = useReduxStore( + (s) => s.site.badges + ); const sonarrEnabled = useIsSonarrEnabled(); const radarrEnabled = useIsRadarrEnabled(); @@ -53,9 +37,10 @@ const Sidebar: FunctionComponent = ({ open }) => { }, System: { Providers: providers, + Status: status, }, }), - [movies, episodes, providers, sonarrEnabled, radarrEnabled] + [movies, episodes, providers, sonarrEnabled, radarrEnabled, status] ); const hiddenKeys = useMemo(() => { @@ -69,20 +54,6 @@ const Sidebar: FunctionComponent = ({ open }) => { return list; }, [sonarrEnabled, radarrEnabled]); - const history = useHistory(); - - const updateSidebar = useUpdateSidebar(); - - useEffect(() => { - const path = history.location.pathname.split("/"); - const len = path.length; - if (len >= 3) { - updateSidebar(path[len - 2]); - } else { - updateSidebar(path[len - 1]); - } - }, [history.location.pathname, updateSidebar]); - const cls = ["sidebar-container"]; const overlay = ["sidebar-overlay"]; diff --git a/frontend/src/Sidebar/items.tsx b/frontend/src/Sidebar/items.tsx index 8982e29c8..beb376eb2 100644 --- a/frontend/src/Sidebar/items.tsx +++ b/frontend/src/Sidebar/items.tsx @@ -3,7 +3,8 @@ import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; import React, { FunctionComponent, useContext, useMemo } from "react"; import { Badge, Collapse, ListGroupItem } from "react-bootstrap"; import { NavLink } from "react-router-dom"; -import { useSidebarKey, useUpdateSidebar } from "."; +import { siteChangeSidebar } from "../@redux/actions"; +import { useReduxAction, useReduxStore } from "../@redux/hooks/base"; import { SidebarToggleContext } from "../App"; import { BadgeProvider, @@ -16,6 +17,14 @@ export const HiddenKeysContext = React.createContext([]); export const 
BadgesContext = React.createContext({}); +function useToggleSidebar() { + return useReduxAction(siteChangeSidebar); +} + +function useSidebarKey() { + return useReduxStore((s) => s.site.sidebar); +} + export const LinkItem: FunctionComponent = ({ link, name, @@ -60,10 +69,8 @@ export const CollapseItem: FunctionComponent = ({ const hiddenKeys = useContext(HiddenKeysContext); const toggleSidebar = useContext(SidebarToggleContext); - const itemKey = name.toLowerCase(); - const sidebarKey = useSidebarKey(); - const updateSidebar = useUpdateSidebar(); + const updateSidebar = useToggleSidebar(); const [badgeValue, childValue] = useMemo< [Nullable, Nullable] @@ -86,7 +93,7 @@ export const CollapseItem: FunctionComponent = ({ return [badge, child]; }, [badges, name]); - const active = useMemo(() => sidebarKey === itemKey, [sidebarKey, itemKey]); + const active = useMemo(() => sidebarKey === name, [sidebarKey, name]); const collapseBoxClass = useMemo( () => `sidebar-collapse-box ${active ? "active" : ""}`, @@ -133,7 +140,7 @@ export const CollapseItem: FunctionComponent = ({ if (active) { updateSidebar(""); } else { - updateSidebar(itemKey); + updateSidebar(name); } }} > diff --git a/frontend/src/System/Logs/index.tsx b/frontend/src/System/Logs/index.tsx index 43bfec04c..151d6fb69 100644 --- a/frontend/src/System/Logs/index.tsx +++ b/frontend/src/System/Logs/index.tsx @@ -2,20 +2,16 @@ import { faDownload, faSync, faTrash } from "@fortawesome/free-solid-svg-icons"; import React, { FunctionComponent, useCallback, useState } from "react"; import { Container, Row } from "react-bootstrap"; import { Helmet } from "react-helmet"; -import { systemUpdateLogs } from "../../@redux/actions"; -import { useReduxAction, useReduxStore } from "../../@redux/hooks/base"; +import { useSystemLogs } from "../../@redux/hooks"; import { SystemApi } from "../../apis"; import { AsyncStateOverlay, ContentHeader } from "../../components"; import { useBaseUrl } from "../../utilites"; -import { useAutoUpdate } from "../../utilites/hooks"; import Table from "./table"; interface Props {} const SystemLogsView: FunctionComponent = () => { - const logs = useReduxStore(({ system }) => system.logs); - const update = useReduxAction(systemUpdateLogs); - useAutoUpdate(update); + const [logs, update] = useSystemLogs(); const [resetting, setReset] = useState(false); @@ -27,7 +23,7 @@ const SystemLogsView: FunctionComponent = () => { return ( - {(data) => ( + {({ data }) => ( Logs - Bazarr (System) diff --git a/frontend/src/System/Providers/index.tsx b/frontend/src/System/Providers/index.tsx index 85cef7265..3f8a15ae4 100644 --- a/frontend/src/System/Providers/index.tsx +++ b/frontend/src/System/Providers/index.tsx @@ -2,21 +2,19 @@ import { faSync, faTrash } from "@fortawesome/free-solid-svg-icons"; import React, { FunctionComponent } from "react"; import { Container, Row } from "react-bootstrap"; import { Helmet } from "react-helmet"; -import { useProviders } from "../../@redux/hooks"; +import { useSystemProviders } from "../../@redux/hooks"; import { ProvidersApi } from "../../apis"; import { AsyncStateOverlay, ContentHeader } from "../../components"; -import { useAutoUpdate } from "../../utilites/hooks"; import Table from "./table"; interface Props {} const SystemProvidersView: FunctionComponent = () => { - const [providers, update] = useProviders(); - useAutoUpdate(update); + const [providers, update] = useSystemProviders(); return ( - {(data) => ( + {({ data }) => ( Providers - Bazarr (System) diff --git 
a/frontend/src/System/Releases/index.tsx b/frontend/src/System/Releases/index.tsx index f28b14461..778916c35 100644 --- a/frontend/src/System/Releases/index.tsx +++ b/frontend/src/System/Releases/index.tsx @@ -1,28 +1,24 @@ import React, { FunctionComponent, useMemo } from "react"; import { Badge, Card, Col, Container, Row } from "react-bootstrap"; import { Helmet } from "react-helmet"; -import { systemUpdateReleases } from "../../@redux/actions"; -import { useReduxAction, useReduxStore } from "../../@redux/hooks/base"; +import { useSystemReleases } from "../../@redux/hooks"; import { AsyncStateOverlay } from "../../components"; import { BuildKey } from "../../utilites"; -import { useAutoUpdate } from "../../utilites/hooks"; interface Props {} const ReleasesView: FunctionComponent = () => { - const releases = useReduxStore(({ system }) => system.releases); - const update = useReduxAction(systemUpdateReleases); - useAutoUpdate(update); + const [releases] = useSystemReleases(); return ( - {(item) => ( + {({ data }) => ( Releases - Bazarr (System) - {item.map((v, idx) => ( + {data.map((v, idx) => ( diff --git a/frontend/src/System/Router.tsx b/frontend/src/System/Router.tsx index ab7254e06..575cf228b 100644 --- a/frontend/src/System/Router.tsx +++ b/frontend/src/System/Router.tsx @@ -1,5 +1,6 @@ import React, { FunctionComponent } from "react"; import { Redirect, Route, Switch } from "react-router-dom"; +import { useSetSidebar } from "../@redux/hooks/site"; import { RouterEmptyPath } from "../special-pages/404"; import Logs from "./Logs"; import Providers from "./Providers"; @@ -8,6 +9,7 @@ import Status from "./Status"; import Tasks from "./Tasks"; const Router: FunctionComponent = () => { + useSetSidebar("System"); return ( diff --git a/frontend/src/System/Status/index.tsx b/frontend/src/System/Status/index.tsx index 55d7dabc9..422c41dd8 100644 --- a/frontend/src/System/Status/index.tsx +++ b/frontend/src/System/Status/index.tsx @@ -9,10 +9,10 @@ import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; import React, { FunctionComponent } from "react"; import { Col, Container, Row } from "react-bootstrap"; import { Helmet } from "react-helmet"; -import { systemUpdateStatus } from "../../@redux/actions"; -import { useReduxAction, useReduxStore } from "../../@redux/hooks/base"; +import { useSystemHealth, useSystemStatus } from "../../@redux/hooks"; +import { AsyncStateOverlay } from "../../components"; import { GithubRepoRoot } from "../../constants"; -import { useAutoUpdate } from "../../utilites/hooks"; +import Table from "./table"; interface InfoProps { title: string; @@ -65,15 +65,28 @@ const InfoContainer: FunctionComponent<{ title: string }> = ({ interface Props {} const SystemStatusView: FunctionComponent = () => { - const status = useReduxStore((s) => s.system.status.data); - const update = useReduxAction(systemUpdateStatus); - useAutoUpdate(update); + const [health] = useSystemHealth(); + const [status] = useSystemStatus(); + + let health_table; + if (health.data.length) { + health_table = ( + + {({ data }) =>
<Table health={data} />}
+ ); + } else { + health_table = "No issues with your configuration"; + } return ( Status - Bazarr (System) + + {health_table} + diff --git a/frontend/src/System/Status/table.tsx b/frontend/src/System/Status/table.tsx new file mode 100644 index 000000000..db6e9173d --- /dev/null +++ b/frontend/src/System/Status/table.tsx @@ -0,0 +1,27 @@ +import React, { FunctionComponent, useMemo } from "react"; +import { Column } from "react-table"; +import { SimpleTable } from "../../components"; + +interface Props { + health: readonly System.Health[]; +} + +const Table: FunctionComponent = (props) => { + const columns: Column[] = useMemo[]>( + () => [ + { + Header: "Object", + accessor: "object", + }, + { + Header: "Issue", + accessor: "issue", + }, + ], + [] + ); + + return ; +}; + +export default Table; diff --git a/frontend/src/System/Tasks/index.tsx b/frontend/src/System/Tasks/index.tsx index 8f9d6dc85..149624406 100644 --- a/frontend/src/System/Tasks/index.tsx +++ b/frontend/src/System/Tasks/index.tsx @@ -2,24 +2,18 @@ import { faSync } from "@fortawesome/free-solid-svg-icons"; import React, { FunctionComponent } from "react"; import { Container, Row } from "react-bootstrap"; import { Helmet } from "react-helmet"; -import { systemUpdateTasks } from "../../@redux/actions"; -import { useReduxAction, useReduxStore } from "../../@redux/hooks/base"; +import { useSystemTasks } from "../../@redux/hooks"; import { AsyncStateOverlay, ContentHeader } from "../../components"; -import { useAutoUpdate } from "../../utilites"; import Table from "./table"; interface Props {} const SystemTasksView: FunctionComponent = () => { - const tasks = useReduxStore((s) => s.system.tasks); - const update = useReduxAction(systemUpdateTasks); - - // TODO: Use Websocket - useAutoUpdate(update, 10 * 1000); + const [tasks, update] = useSystemTasks(); return ( - {(data) => ( + {({ data }) => ( Tasks - Bazarr (System) diff --git a/frontend/src/System/Tasks/table.tsx b/frontend/src/System/Tasks/table.tsx index 42b941158..55729dc01 100644 --- a/frontend/src/System/Tasks/table.tsx +++ b/frontend/src/System/Tasks/table.tsx @@ -2,8 +2,6 @@ import { faSync } from "@fortawesome/free-solid-svg-icons"; import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; import React, { FunctionComponent, useMemo } from "react"; import { Column } from "react-table"; -import { systemRunTasks } from "../../@redux/actions"; -import { useReduxAction } from "../../@redux/hooks/base"; import { SystemApi } from "../../apis"; import { AsyncButton, SimpleTable } from "../../components"; @@ -12,7 +10,6 @@ interface Props { } const Table: FunctionComponent = ({ tasks }) => { - const run = useReduxAction(systemRunTasks); const columns: Column[] = useMemo[]>( () => [ { @@ -37,10 +34,10 @@ const Table: FunctionComponent = ({ tasks }) => { return ( SystemApi.runTask(job_id)} - onSuccess={() => run(job_id)} variant="light" size="sm" disabled={row.value} + animation={false} > @@ -48,7 +45,7 @@ const Table: FunctionComponent = ({ tasks }) => { }, }, ], - [run] + [] ); return ; diff --git a/frontend/src/Wanted/Movies/index.tsx b/frontend/src/Wanted/Movies/index.tsx index a0a9c563a..3d9e5cfe0 100644 --- a/frontend/src/Wanted/Movies/index.tsx +++ b/frontend/src/Wanted/Movies/index.tsx @@ -15,7 +15,7 @@ import GenericWantedView from "../generic"; interface Props {} const WantedMoviesView: FunctionComponent = () => { - const [movies, update] = useWantedMovies(); + const [movies] = useWantedMovies(); const loader = useReduxAction(movieUpdateWantedByRange); @@ -74,9 
+74,8 @@ const WantedMoviesView: FunctionComponent = () => { return ( []} + columns={columns} state={movies} - update={update} loader={loader} searchAll={searchAll} > diff --git a/frontend/src/Wanted/Router.tsx b/frontend/src/Wanted/Router.tsx index fa61a7899..750c18ee3 100644 --- a/frontend/src/Wanted/Router.tsx +++ b/frontend/src/Wanted/Router.tsx @@ -1,6 +1,10 @@ import React, { FunctionComponent } from "react"; import { Redirect, Route, Switch } from "react-router-dom"; -import { useIsRadarrEnabled, useIsSonarrEnabled } from "../@redux/hooks/site"; +import { + useIsRadarrEnabled, + useIsSonarrEnabled, + useSetSidebar, +} from "../@redux/hooks/site"; import { RouterEmptyPath } from "../special-pages/404"; import Movies from "./Movies"; import Series from "./Series"; @@ -8,6 +12,8 @@ import Series from "./Series"; const Router: FunctionComponent = () => { const sonarr = useIsSonarrEnabled(); const radarr = useIsRadarrEnabled(); + + useSetSidebar("Wanted"); return ( {sonarr && ( diff --git a/frontend/src/Wanted/Series/index.tsx b/frontend/src/Wanted/Series/index.tsx index c27aad6e0..04d268c4c 100644 --- a/frontend/src/Wanted/Series/index.tsx +++ b/frontend/src/Wanted/Series/index.tsx @@ -15,7 +15,7 @@ import GenericWantedView from "../generic"; interface Props {} const WantedSeriesView: FunctionComponent = () => { - const [series, update] = useWantedSeries(); + const [series] = useWantedSeries(); const loader = useReduxAction(seriesUpdateWantedByRange); @@ -82,9 +82,8 @@ const WantedSeriesView: FunctionComponent = () => { return ( []} + columns={columns} state={series} - update={update} loader={loader} searchAll={searchAll} > diff --git a/frontend/src/Wanted/generic/index.tsx b/frontend/src/Wanted/generic/index.tsx index b0aa336a8..0a198afd2 100644 --- a/frontend/src/Wanted/generic/index.tsx +++ b/frontend/src/Wanted/generic/index.tsx @@ -1,39 +1,29 @@ import { faSearch } from "@fortawesome/free-solid-svg-icons"; import { capitalize } from "lodash"; -import React, { FunctionComponent, useCallback, useMemo } from "react"; +import React from "react"; import { Container, Row } from "react-bootstrap"; import { Helmet } from "react-helmet"; -import { Column, TableUpdater } from "react-table"; -import { ContentHeader, PageTable } from "../../components"; -import { buildOrderList, GetItemId } from "../../utilites"; +import { Column } from "react-table"; +import { AsyncPageTable, ContentHeader } from "../../components"; -interface Props { +interface Props { type: "movies" | "series"; - columns: Column[]; - state: Readonly>>; + columns: Column[]; + state: Readonly>; loader: (start: number, length: number) => void; - update: (id?: number) => void; searchAll: () => Promise; } -const GenericWantedView: FunctionComponent = ({ +function GenericWantedView({ type, columns, state, - update, loader, searchAll, -}) => { +}: Props) { const typeName = capitalize(type); - const data = useMemo(() => buildOrderList(state.data), [state.data]); - - const updater = useCallback>( - (row, id: number) => { - update(id); - }, - [update] - ); + const dataCount = Object.keys(state.data.items).length; return ( @@ -42,28 +32,24 @@ const GenericWantedView: FunctionComponent = ({ void} icon={faSearch} > Search All - + data={[]} + > ); -}; +} export default GenericWantedView; diff --git a/frontend/src/apis/episodes.ts b/frontend/src/apis/episodes.ts index b49de1d72..72cd9de15 100644 --- a/frontend/src/apis/episodes.ts +++ b/frontend/src/apis/episodes.ts @@ -5,7 +5,7 @@ class EpisodeApi extends BaseApi { super("/episodes"); } 
- async bySeriesId(seriesid: number): Promise> { + async bySeriesId(seriesid: number[]): Promise> { return new Promise>((resolve, reject) => { this.get>>("", { seriesid }) .then((result) => { @@ -17,6 +17,18 @@ class EpisodeApi extends BaseApi { }); } + async byEpisodeId(episodeid: number[]): Promise> { + return new Promise>((resolve, reject) => { + this.get>>("", { episodeid }) + .then((result) => { + resolve(result.data.data); + }) + .catch((reason) => { + reject(reason); + }); + }); + } + async wanted(start: number, length: number) { return new Promise>((resolve, reject) => { this.get>("/wanted", { start, length }) @@ -29,8 +41,7 @@ class EpisodeApi extends BaseApi { }); } - // TODO: Implement this on backend - async wantedBy(episodeid?: number) { + async wantedBy(episodeid: number[]) { return new Promise>((resolve, reject) => { this.get>("/wanted", { episodeid }) .then((result) => { @@ -42,18 +53,6 @@ class EpisodeApi extends BaseApi { }); } - async byEpisodeId(episodeid: number): Promise> { - return new Promise>((resolve, reject) => { - this.get>>("", { episodeid }) - .then((result) => { - resolve(result.data.data); - }) - .catch((reason) => { - reject(reason); - }); - }); - } - async history(episodeid?: number): Promise> { return new Promise>((resolve, reject) => { this.get>>("/history", { episodeid }) diff --git a/frontend/src/apis/index.ts b/frontend/src/apis/index.ts index b583d3a84..f79ca6976 100644 --- a/frontend/src/apis/index.ts +++ b/frontend/src/apis/index.ts @@ -1,18 +1,16 @@ import Axios, { AxiosError, AxiosInstance, CancelTokenSource } from "axios"; import { siteRedirectToAuth, siteUpdateOffline } from "../@redux/actions"; import reduxStore from "../@redux/store"; +import { getBaseUrl } from "../utilites"; class Api { axios!: AxiosInstance; source!: CancelTokenSource; constructor() { + const baseUrl = `${getBaseUrl()}/api/`; if (process.env.NODE_ENV === "development") { - this.initialize("/api/", process.env["REACT_APP_APIKEY"]!); + this.initialize(baseUrl, process.env["REACT_APP_APIKEY"]!); } else { - const baseUrl = - window.Bazarr.baseUrl === "/" - ? 
"/api/" - : `${window.Bazarr.baseUrl}/api/`; this.initialize(baseUrl, window.Bazarr.apiKey); } } @@ -34,7 +32,6 @@ class Api { this.axios.interceptors.response.use( (resp) => { - this.onOnline(); if (resp.status >= 200 && resp.status < 300) { return Promise.resolve(resp); } else { @@ -46,9 +43,7 @@ class Api { if (error.response) { const response = error.response; this.handleError(response.status); - this.onOnline(); } else { - this.onOffline(); error.message = "You have disconnected to Bazarr backend"; } return Promise.reject(error); diff --git a/frontend/src/apis/movies.ts b/frontend/src/apis/movies.ts index 4dd265ccc..044b4540a 100644 --- a/frontend/src/apis/movies.ts +++ b/frontend/src/apis/movies.ts @@ -37,9 +37,9 @@ class MovieApi extends BaseApi { }); } - async movies(id?: number[]) { + async movies(radarrid?: number[]) { return new Promise>((resolve, reject) => { - this.get>("", { radarrid: id }) + this.get>("", { radarrid }) .then((result) => { resolve(result.data); }) @@ -81,8 +81,7 @@ class MovieApi extends BaseApi { }); } - // TODO: Implement this on backend - async wantedBy(radarrid?: number) { + async wantedBy(radarrid: number[]) { return new Promise>((resolve, reject) => { this.get>("/wanted", { radarrid }) .then((result) => { diff --git a/frontend/src/apis/series.ts b/frontend/src/apis/series.ts index 5dce47198..516891b50 100644 --- a/frontend/src/apis/series.ts +++ b/frontend/src/apis/series.ts @@ -5,9 +5,9 @@ class SeriesApi extends BaseApi { super("/series"); } - async series(id?: number[]) { + async series(seriesid?: number[]) { return new Promise>((resolve, reject) => { - this.get>("", { seriesid: id }) + this.get>("", { seriesid }) .then((result) => { resolve(result.data); }) diff --git a/frontend/src/apis/system.ts b/frontend/src/apis/system.ts index bc0bb181c..8322781fd 100644 --- a/frontend/src/apis/system.ts +++ b/frontend/src/apis/system.ts @@ -89,6 +89,18 @@ class SystemApi extends BaseApi { }); } + async health() { + return new Promise((resolve, reject) => { + this.get>("/health") + .then((result) => { + resolve(result.data.data); + }) + .catch((reason) => { + reject(reason); + }); + }); + } + async logs() { return new Promise>((resolve, reject) => { this.get>>("/logs") diff --git a/frontend/src/components/async.tsx b/frontend/src/components/async.tsx index c34a5e743..092a08a8d 100644 --- a/frontend/src/components/async.tsx +++ b/frontend/src/components/async.tsx @@ -25,10 +25,15 @@ enum RequestState { Invalid, } +interface ChildProps { + data: NonNullable>; + error?: Error; +} + interface AsyncStateOverlayProps { state: AsyncState; exist?: (item: T) => boolean; - children?: (item: NonNullable>, error?: Error) => JSX.Element; + children?: FunctionComponent>; } function defaultExist(item: any) { @@ -83,7 +88,7 @@ export function AsyncStateOverlay(props: AsyncStateOverlayProps) { } } - return children ? children(state.data!, state.error) : null; + return children ? children({ data: state.data!, error: state.error }) : null; } interface PromiseProps { @@ -156,6 +161,7 @@ interface AsyncButtonProps { onChange?: (v: boolean) => void; noReset?: boolean; + animation?: boolean; promise: () => Promise | null; onSuccess?: (result: T) => void; @@ -171,6 +177,7 @@ export function AsyncButton( promise, onSuccess, noReset, + animation, error, onChange, disabled, @@ -230,15 +237,19 @@ export function AsyncButton( } }, [error, onChange, promise, onSuccess, state]); - let children = propChildren; - if (loading) { - children = ; - } + const showAnimation = animation ?? 
true; - if (state === RequestState.Success) { - children = ; - } else if (state === RequestState.Error) { - children = ; + let children = propChildren; + if (showAnimation) { + if (loading) { + children = ; + } + + if (state === RequestState.Success) { + children = ; + } else if (state === RequestState.Error) { + children = ; + } } return ( diff --git a/frontend/src/components/buttons.tsx b/frontend/src/components/buttons.tsx index 06b247804..c472c1256 100644 --- a/frontend/src/components/buttons.tsx +++ b/frontend/src/components/buttons.tsx @@ -53,6 +53,7 @@ export const ActionButton: FunctionComponent = ({ interface ActionButtonItemProps { loading?: boolean; + alwaysShowText?: boolean; icon: IconDefinition; children?: string; } @@ -61,7 +62,9 @@ export const ActionButtonItem: FunctionComponent = ({ icon, children, loading, + alwaysShowText, }) => { + const showText = alwaysShowText === true || loading !== true; return ( = ({ icon={loading ? faCircleNotch : icon} spin={loading} > - {children && !loading ? ( + {children && showText ? ( {children} ) : null} diff --git a/frontend/src/components/modals/HistoryModal.tsx b/frontend/src/components/modals/HistoryModal.tsx index 6ffa3d8ff..eb0977157 100644 --- a/frontend/src/components/modals/HistoryModal.tsx +++ b/frontend/src/components/modals/HistoryModal.tsx @@ -105,7 +105,7 @@ export const MovieHistoryModal: FunctionComponent = (props) => { return ( - {(data) => ( + {({ data }) => ( - {(data) => ( + {({ data }) => ( = ( const [availableLanguages] = useLanguages(true); const movie = usePayload(modal.modalKey); - const [, update] = useMovieBy(movie?.radarrId); const closeModal = useCloseModal(); @@ -63,10 +57,7 @@ const MovieUploadModal: FunctionComponent = ( return null; } }} - onSuccess={() => { - closeModal(); - update(); - }} + onSuccess={closeModal} > Upload diff --git a/frontend/src/components/modals/SeriesUploadModal.tsx b/frontend/src/components/modals/SeriesUploadModal.tsx index 4e60964d4..7c479cdc7 100644 --- a/frontend/src/components/modals/SeriesUploadModal.tsx +++ b/frontend/src/components/modals/SeriesUploadModal.tsx @@ -24,11 +24,7 @@ import { useCloseModal, usePayload, } from ".."; -import { - useEpisodesBy, - useProfileBy, - useProfileItems, -} from "../../@redux/hooks"; +import { useProfileBy, useProfileItems } from "../../@redux/hooks"; import { EpisodesApi, SubtitlesApi } from "../../apis"; import { Selector } from "../inputs"; import BaseModal, { BaseModalProps } from "./BaseModal"; @@ -59,15 +55,16 @@ type EpisodeMap = { [name: string]: Item.Episode; }; -interface MovieProps {} +interface SerieProps { + episodes: readonly Item.Episode[]; +} -const SeriesUploadModal: FunctionComponent = ( - modal -) => { +const SeriesUploadModal: FunctionComponent = ({ + episodes, + ...modal +}) => { const series = usePayload(modal.modalKey); - const [episodes, updateEpisodes] = useEpisodesBy(series?.sonarrSeriesId); - const [uploading, setUpload] = useState(false); const closeModal = useCloseModal(); @@ -122,7 +119,7 @@ const SeriesUploadModal: FunctionComponent = ( const results = await SubtitlesApi.info(names); const episodeMap = results.reduce((prev, curr) => { - const ep = episodes.data.find( + const ep = episodes.find( (v) => v.season === curr.season && v.episode === curr.episode ); if (ep) { @@ -140,7 +137,7 @@ const SeriesUploadModal: FunctionComponent = ( ); } }, - [episodes.data] + [episodes] ); const updateLanguage = useCallback( @@ -386,7 +383,6 @@ const SeriesUploadModal: FunctionComponent = ( onSuccess={() => { closeModal(); 
setFiles([]); - updateEpisodes(); }} > Upload @@ -419,7 +415,7 @@ const SeriesUploadModal: FunctionComponent = ( diff --git a/frontend/src/components/modals/SubtitleToolModal.tsx b/frontend/src/components/modals/SubtitleToolModal.tsx index 5971bc494..a21878696 100644 --- a/frontend/src/components/modals/SubtitleToolModal.tsx +++ b/frontend/src/components/modals/SubtitleToolModal.tsx @@ -330,14 +330,9 @@ const TranslateModal: FunctionComponent = ({ ); }; -interface STMProps { - update: () => void; -} +interface STMProps {} -const STM: FunctionComponent = ({ - update, - ...props -}) => { +const STM: FunctionComponent = ({ ...props }) => { const items = usePayload(props.modalKey); const [updating, setUpdate] = useState(false); @@ -380,10 +375,8 @@ const STM: FunctionComponent = ({ setProcessState(states); } setUpdate(false); - - update(); }, - [closeUntil, selections, update] + [closeUntil, selections] ); const showModal = useShowModal(); diff --git a/frontend/src/components/tables/AsyncPageTable.tsx b/frontend/src/components/tables/AsyncPageTable.tsx new file mode 100644 index 000000000..3ccf54627 --- /dev/null +++ b/frontend/src/components/tables/AsyncPageTable.tsx @@ -0,0 +1,128 @@ +import { isNull } from "lodash"; +import React, { useCallback, useEffect, useMemo, useState } from "react"; +import { PluginHook, TableOptions, useTable } from "react-table"; +import { LoadingIndicator } from ".."; +import { useReduxStore } from "../../@redux/hooks/base"; +import { buildOrderListFrom, isNonNullable, ScrollToTop } from "../../utilites"; +import BaseTable, { TableStyleProps, useStyleAndOptions } from "./BaseTable"; +import PageControl from "./PageControl"; +import { useDefaultSettings } from "./plugins"; + +type Props = TableOptions & + TableStyleProps & { + plugins?: PluginHook[]; + aos: AsyncOrderState; + loader: (start: number, length: number) => void; + }; + +export default function AsyncPageTable(props: Props) { + const { aos, plugins, loader, ...remain } = props; + const { style, options } = useStyleAndOptions(remain); + + const { + updating, + data: { order, items, fetched }, + } = aos; + + const allPlugins: PluginHook[] = [useDefaultSettings]; + + if (plugins) { + allPlugins.push(...plugins); + } + + // Impl a new pagination system instead of hooking into the existing one + const [pageIndex, setIndex] = useState(0); + const pageSize = useReduxStore((s) => s.site.pageSize); + const totalRows = order.length; + const pageCount = Math.ceil(totalRows / pageSize); + + const previous = useCallback(() => { + setIndex((idx) => idx - 1); + }, []); + + const next = useCallback(() => { + setIndex((idx) => idx + 1); + }, []); + + const goto = useCallback((idx: number) => { + setIndex(idx); + }, []); + + const pageStart = pageIndex * pageSize; + const pageEnd = pageStart + pageSize; + + const visibleItemIds = useMemo(() => order.slice(pageStart, pageEnd), [ + pageStart, + pageEnd, + order, + ]); + + const newData = useMemo(() => buildOrderListFrom(items, visibleItemIds), [ + items, + visibleItemIds, + ]); + + const newOptions = useMemo>( + () => ({ + ...options, + data: newData, + }), + [options, newData] + ); + + const instance = useTable(newOptions, ...allPlugins); + + const { + getTableProps, + getTableBodyProps, + headerGroups, + rows, + prepareRow, + } = instance; + + useEffect(() => { + ScrollToTop(); + }, [pageIndex]); + + useEffect(() => { + const needInit = visibleItemIds.length === 0 && fetched === false; + const needRefresh = !visibleItemIds.every(isNonNullable); + if (needInit || 
needRefresh) { + loader(pageStart, pageSize); + } + }, [visibleItemIds, pageStart, pageSize, loader, fetched]); + + const showLoading = useMemo( + () => + updating && (visibleItemIds.every(isNull) || visibleItemIds.length === 0), + [visibleItemIds, updating] + ); + + if (showLoading) { + return ; + } + + return ( + + + 0} + canNext={pageIndex < pageCount - 1} + previous={previous} + next={next} + goto={goto} + > + + ); +} diff --git a/frontend/src/components/tables/PageControl.tsx b/frontend/src/components/tables/PageControl.tsx index 32acad742..1680c5d39 100644 --- a/frontend/src/components/tables/PageControl.tsx +++ b/frontend/src/components/tables/PageControl.tsx @@ -79,12 +79,12 @@ const PageControl: FunctionComponent = ({ diff --git a/frontend/src/components/tables/PageTable.tsx b/frontend/src/components/tables/PageTable.tsx index 1ed2d1632..6df6d5308 100644 --- a/frontend/src/components/tables/PageTable.tsx +++ b/frontend/src/components/tables/PageTable.tsx @@ -1,5 +1,5 @@ -import { isNull, isUndefined } from "lodash"; -import React, { useCallback, useEffect } from "react"; +import { isUndefined } from "lodash"; +import React, { useEffect } from "react"; import { PluginHook, TableOptions, @@ -9,33 +9,23 @@ import { } from "react-table"; import { useReduxStore } from "../../@redux/hooks/base"; import { ScrollToTop } from "../../utilites"; -import { AsyncStateOverlay } from "../async"; import BaseTable, { TableStyleProps, useStyleAndOptions } from "./BaseTable"; import PageControl from "./PageControl"; -import { - useAsyncPagination, - useCustomSelection, - useDefaultSettings, -} from "./plugins"; +import { useCustomSelection, useDefaultSettings } from "./plugins"; type Props = TableOptions & TableStyleProps & { - async?: boolean; canSelect?: boolean; autoScroll?: boolean; plugins?: PluginHook[]; }; export default function PageTable(props: Props) { - const { async, autoScroll, canSelect, plugins, ...remain } = props; + const { autoScroll, canSelect, plugins, ...remain } = props; const { style, options } = useStyleAndOptions(remain); const allPlugins: PluginHook[] = [useDefaultSettings, usePagination]; - if (async) { - allPlugins.push(useAsyncPagination); - } - if (canSelect) { allPlugins.push(useRowSelect, useCustomSelection); } @@ -62,7 +52,7 @@ export default function PageTable(props: Props) { nextPage, previousPage, setPageSize, - state: { pageIndex, pageSize, pageToLoad, needLoadingScreen }, + state: { pageIndex, pageSize }, } = instance; const globalPageSize = useReduxStore((s) => s.site.pageSize); @@ -91,28 +81,6 @@ export default function PageTable(props: Props) { setPageSize, ]); - const total = options.asyncState - ? 
options.asyncState.data.order.length - : rows.length; - - const orderIdStateValidater = useCallback( - (state: OrderIdState) => { - const start = pageIndex * pageSize; - const end = start + pageSize; - return state.order.slice(start, end).every(isNull) === false; - }, - [pageIndex, pageSize] - ); - - if (needLoadingScreen && options.asyncState) { - return ( - - ); - } - return ( (props: Props) { tableBodyProps={getTableBodyProps()} > (hooks: Hooks) { - hooks.stateReducers.push(reducer); - hooks.useInstance.push(useInstance); - hooks.useOptions.push(useOptions); -} -useAsyncPagination.pluginName = pluginName; - -function reducer( - state: TableState, - action: ActionType, - previous: TableState | undefined, - instance: TableInstance | undefined -): ReducerTableState { - if (action.type === ActionLoadingChange && instance) { - let pageToLoad: - | PageControlAction - | undefined = action.pageToLoad as PageControlAction; - let needLoadingScreen = false; - const { asyncState } = instance; - const { pageIndex, pageSize } = state; - let index = pageIndex; - if (pageToLoad === "prev") { - index -= 1; - } else if (pageToLoad === "next") { - index += 1; - } else if (typeof pageToLoad === "number") { - index = pageToLoad; - } - const pageStart = index * pageSize; - const pageEnd = pageStart + pageSize; - if (asyncState) { - const error = asyncState.error; - const order = asyncState.data.order.slice(pageStart, pageEnd); - - const isInitializedError = order.length === 0 && error !== undefined; - const isLoadingError = order.length !== 0 && order.every(isNull); - - if (isInitializedError || isLoadingError) { - needLoadingScreen = true; - } else if (order.every(isNonNullable)) { - pageToLoad = undefined; - } - } - return { ...state, pageToLoad, needLoadingScreen }; - } - return state; -} - -function useOptions(options: TableOptions) { - options.manualPagination = true; - if (options.initialState === undefined) { - options.initialState = {}; - } - options.initialState.pageToLoad = 0; - options.initialState.needLoadingScreen = true; - return options; -} - -function useInstance(instance: TableInstance) { - const { - plugins, - asyncLoader, - dispatch, - asyncState, - asyncId, - rows, - nextPage, - previousPage, - gotoPage, - state: { pageIndex, pageSize, pageToLoad }, - } = instance; - - ensurePluginOrder(plugins, ["usePagination"], pluginName); - - const totalCount = asyncState?.data.order.length ?? 
0; - const pageCount = Math.ceil(totalCount / pageSize); - const pageStart = pageIndex * pageSize; - const pageEnd = pageStart + pageSize; - - useEffect(() => { - // TODO Lazy Load - if (pageToLoad === undefined) { - return; - } - asyncLoader && asyncLoader(pageStart, pageSize); - }, [asyncLoader, pageStart, pageSize, pageToLoad]); - - const setPageToLoad = useCallback( - (pageToLoad?: PageControlAction) => { - dispatch({ type: ActionLoadingChange, pageToLoad }); - }, - [dispatch] - ); - - useEffect(() => { - if (asyncState?.updating === false) { - setPageToLoad(); - } - }, [asyncState?.updating, setPageToLoad]); - - const newGoto = useCallback( - (updater: number | ((pageIndex: number) => number)) => { - let page: number; - if (typeof updater === "number") { - page = updater; - } else { - page = updater(pageIndex); - } - if (page === pageIndex) { - return; - } - setPageToLoad(page); - gotoPage(page); - }, - [pageIndex, setPageToLoad, gotoPage] - ); - - const newPrevious = useCallback(() => { - if (pageIndex === 0) { - return; - } - setPageToLoad("prev"); - previousPage(); - }, [setPageToLoad, previousPage, pageIndex]); - - const newNext = useCallback(() => { - if (pageIndex === pageCount) { - return; - } - setPageToLoad("next"); - nextPage(); - }, [setPageToLoad, nextPage, pageCount, pageIndex]); - - const newPages = useMemo(() => { - // TODO: Performance - - const order = (asyncState?.data.order - .slice(pageStart, pageEnd) - .filter(isNonNullable) ?? []) as number[]; - - return order.flatMap((num) => { - const row = rows.find((v) => asyncId && asyncId(v.original) === num); - if (row) { - return [row]; - } else { - return []; - } - }); - }, [pageStart, pageEnd, asyncId, asyncState?.data.order, rows]); - - Object.assign, Partial>>(instance, { - previousPage: newPrevious, - nextPage: newNext, - gotoPage: newGoto, - page: newPages, - pageCount, - }); -} - -export default useAsyncPagination; diff --git a/frontend/src/components/tables/plugins/useDefaultSettings.tsx b/frontend/src/components/tables/plugins/useDefaultSettings.tsx index 438316739..8f36ec1dd 100644 --- a/frontend/src/components/tables/plugins/useDefaultSettings.tsx +++ b/frontend/src/components/tables/plugins/useDefaultSettings.tsx @@ -23,16 +23,10 @@ function useOptions(options: TableOptions) { options.initialState = {}; } - options.initialState.needLoadingScreen = false; - if (options.initialState.pageSize === undefined) { options.initialState.pageSize = pageSize; } - if (options.asyncLoader === undefined) { - options.initialState.pageToLoad = undefined; - } - return options; } diff --git a/frontend/src/generic/BaseItemView/index.tsx b/frontend/src/generic/BaseItemView/index.tsx index 34a3a27e9..616e4f771 100644 --- a/frontend/src/generic/BaseItemView/index.tsx +++ b/frontend/src/generic/BaseItemView/index.tsx @@ -1,10 +1,6 @@ import { faCheck, faList, faUndo } from "@fortawesome/free-solid-svg-icons"; -import React, { - FunctionComponent, - useCallback, - useMemo, - useState, -} from "react"; +import { uniqBy } from "lodash"; +import React, { useCallback, useMemo, useState } from "react"; import { Container, Dropdown, Row } from "react-bootstrap"; import { Helmet } from "react-helmet"; import { Column } from "react-table"; @@ -12,29 +8,25 @@ import { useLanguageProfiles } from "../../@redux/hooks"; import { useReduxActionWith } from "../../@redux/hooks/base"; import { AsyncActionDispatcher } from "../../@redux/types"; import { ContentHeader } from "../../components"; -import { GetItemId, isNonNullable, mergeArray } from 
"../../utilites"; +import { GetItemId, isNonNullable } from "../../utilites"; import Table from "./table"; -export interface SharedProps { +export interface SharedProps { name: string; loader: (start: number, length: number) => void; - columns: Column[]; + columns: Column[]; modify: (form: FormType.ModifyItem) => Promise; - state: AsyncState>; + state: AsyncOrderState; } -export function ExtendItemComparer(lhs: Item.Base, rhs: Item.Base): boolean { - return GetItemId(lhs) === GetItemId(rhs); -} - -interface Props extends SharedProps { +interface Props extends SharedProps { updateAction: (id?: number[]) => AsyncActionDispatcher; } -const BaseItemView: FunctionComponent = ({ +function BaseItemView({ updateAction, ...shared -}) => { +}: Props) { const state = shared.state; const [pendingEditMode, setPendingEdit] = useState(false); @@ -51,8 +43,8 @@ const BaseItemView: FunctionComponent = ({ const update = useReduxActionWith(updateAction, onUpdated); - const [selections, setSelections] = useState([]); - const [dirtyItems, setDirty] = useState([]); + const [selections, setSelections] = useState([]); + const [dirtyItems, setDirty] = useState([]); const [profiles] = useLanguageProfiles(); @@ -80,7 +72,7 @@ const BaseItemView: FunctionComponent = ({ item.profileId = id; return item; }); - const newDirty = mergeArray(dirtyItems, newItems, ExtendItemComparer); + const newDirty = uniqBy([...newItems, ...dirtyItems], GetItemId); setDirty(newDirty); }, [selections, dirtyItems] @@ -95,20 +87,12 @@ const BaseItemView: FunctionComponent = ({ setPendingEdit(true); }, [shared.state.data.order, update]); - const endEdit = useCallback( - (cancel: boolean = false) => { - if (!cancel && dirtyItems.length > 0) { - const ids = dirtyItems.map(GetItemId); - update(ids); - } else { - setEdit(false); - setDirty([]); - } - setPendingEdit(false); - setSelections([]); - }, - [dirtyItems, update] - ); + const endEdit = useCallback(() => { + setEdit(false); + setDirty([]); + setPendingEdit(false); + setSelections([]); + }, []); const saveItems = useCallback(() => { const form: FormType.ModifyItem = { @@ -143,14 +127,14 @@ const BaseItemView: FunctionComponent = ({ - endEdit(true)}> + Cancel endEdit()} + onSuccess={endEdit} > Save @@ -170,7 +154,6 @@ const BaseItemView: FunctionComponent = ({ = ({ ); -}; +} export default BaseItemView; diff --git a/frontend/src/generic/BaseItemView/table.tsx b/frontend/src/generic/BaseItemView/table.tsx index 387a9be58..8029c9231 100644 --- a/frontend/src/generic/BaseItemView/table.tsx +++ b/frontend/src/generic/BaseItemView/table.tsx @@ -1,31 +1,37 @@ -import React, { FunctionComponent, useCallback, useMemo } from "react"; -import { TableUpdater } from "react-table"; -import { ExtendItemComparer, SharedProps } from "."; +import { uniqBy } from "lodash"; +import React, { useCallback, useMemo } from "react"; +import { TableOptions, TableUpdater, useRowSelect } from "react-table"; +import { SharedProps } from "."; import { useLanguageProfiles } from "../../@redux/hooks"; -import { ItemEditorModal, PageTable, useShowModal } from "../../components"; -import { buildOrderList, GetItemId, useMergeArray } from "../../utilites"; +import { + AsyncPageTable, + ItemEditorModal, + SimpleTable, + useShowModal, +} from "../../components"; +import { TableStyleProps } from "../../components/tables/BaseTable"; +import { useCustomSelection } from "../../components/tables/plugins"; +import { buildOrderList, GetItemId } from "../../utilites"; -interface Props extends SharedProps { - dirtyItems: readonly 
Item.Base[]; +interface Props extends SharedProps { + dirtyItems: readonly T[]; editMode: boolean; - select: React.Dispatch; - update: (ids?: number[]) => void; + select: React.Dispatch; } -const Table: FunctionComponent = ({ +function Table({ state, dirtyItems, - update, modify, editMode, select, columns, loader, name, -}) => { +}: Props) { const showModal = useShowModal(); - const updateRow = useCallback>( + const updateRow = useCallback>( (row, modalKey: string) => { showModal(modalKey, row.original); }, @@ -36,37 +42,43 @@ const Table: FunctionComponent = ({ const orderList = useMemo(() => buildOrderList(idState), [idState]); - const data = useMergeArray(orderList, dirtyItems, ExtendItemComparer); + const data = useMemo(() => uniqBy([...dirtyItems, ...orderList], GetItemId), [ + dirtyItems, + orderList, + ]); const [profiles] = useLanguageProfiles(); + const options: Partial & TableStyleProps> = { + loose: [profiles], + emptyText: `No ${name} Found`, + externalUpdate: updateRow, + }; + return ( - - { - const id = GetItemId(item); - update([id]); - }} - > + {editMode ? ( + // TODO: Use PageTable + + ) : ( + + )} + ); -}; +} export default Table; diff --git a/frontend/src/generic/blacklist.tsx b/frontend/src/generic/blacklist.tsx index c26e0e22f..df33e3efe 100644 --- a/frontend/src/generic/blacklist.tsx +++ b/frontend/src/generic/blacklist.tsx @@ -5,7 +5,7 @@ import { AsyncButton } from "../components"; interface Props { history: History.Base; - update: () => void; + update?: () => void; promise: (form: FormType.AddBlacklist) => Promise; } diff --git a/frontend/src/index.tsx b/frontend/src/index.tsx index 05d81769a..a15e7c9f1 100644 --- a/frontend/src/index.tsx +++ b/frontend/src/index.tsx @@ -1,11 +1,12 @@ import "@fontsource/roboto/300.css"; -import React, { FunctionComponent } from "react"; +import React, { FunctionComponent, useEffect } from "react"; import ReactDOM from "react-dom"; import { Provider } from "react-redux"; import { Route, Switch } from "react-router"; import { BrowserRouter } from "react-router-dom"; import store from "./@redux/store"; import "./@scss/index.scss"; +import Socketio from "./@socketio"; import App from "./App"; import Auth from "./Auth"; import { useBaseUrl } from "./utilites"; @@ -13,6 +14,10 @@ import { useBaseUrl } from "./utilites"; const MainRouter: FunctionComponent = () => { const baseUrl = useBaseUrl(); + useEffect(() => { + Socketio.initialize(); + }, []); + return ( diff --git a/frontend/src/setupProxy.js b/frontend/src/setupProxy.js new file mode 100644 index 000000000..9c8092142 --- /dev/null +++ b/frontend/src/setupProxy.js @@ -0,0 +1,18 @@ +const proxy = require("http-proxy-middleware"); + +const target = "http://localhost:6767"; + +module.exports = function (app) { + app.use( + proxy(["/api", "/images", "/test", "/bazarr.log"], { + target, + }) + ); + app.use( + proxy("/api/socket.io", { + target, + ws: true, + logLevel: "error", + }) + ); +}; diff --git a/frontend/src/utilites/hooks.ts b/frontend/src/utilites/hooks.ts index ad355a099..a0b3f9310 100644 --- a/frontend/src/utilites/hooks.ts +++ b/frontend/src/utilites/hooks.ts @@ -1,17 +1,9 @@ import React, { useCallback, useEffect, useMemo, useState } from "react"; import { useHistory } from "react-router"; -import { mergeArray } from "."; +import { getBaseUrl } from "."; export function useBaseUrl(slash: boolean = false) { - if (process.env.NODE_ENV === "development") { - return "/"; - } else { - let url = window.Bazarr.baseUrl ?? 
"/"; - if (slash && !url.endsWith("/")) { - url += "/"; - } - return url; - } + return useMemo(() => getBaseUrl(slash), [slash]); } export function useGotoHomepage() { @@ -51,34 +43,6 @@ export function useSessionStorage( return [sessionStorage.getItem(key), dispatch]; } -export function useMergeArray( - olds: readonly T[], - news: readonly T[], - comparer: Comparer> -) { - return useMemo(() => mergeArray(olds, news, comparer), [ - olds, - news, - comparer, - ]); -} - -export function useAutoUpdate(action: () => void, interval?: number) { - useEffect(() => { - action(); - - let hd: NodeJS.Timeout | null = null; - if (interval !== undefined) { - hd = setInterval(action, interval); - } - return () => { - if (hd !== null) { - clearInterval(hd); - } - }; - }, [action, interval]); -} - export function useWatcher( curr: T, expected: T, diff --git a/frontend/src/utilites/index.ts b/frontend/src/utilites/index.ts index 37c71552d..f4860e171 100644 --- a/frontend/src/utilites/index.ts +++ b/frontend/src/utilites/index.ts @@ -1,5 +1,5 @@ import { Dispatch } from "react"; -import { isEpisode, isMovie, isNullable, isSeries } from "./validate"; +import { isEpisode, isMovie, isSeries } from "./validate"; export function updateAsyncState( promise: Promise, @@ -26,6 +26,21 @@ export function updateAsyncState( }); } +export function getBaseUrl(slash: boolean = false) { + let url: string = "/"; + if (process.env.NODE_ENV !== "development") { + url = window.Bazarr.baseUrl; + } + + const endsWithSlash = url.endsWith("/"); + if (slash && !endsWithSlash) { + return `${url}/`; + } else if (!slash && endsWithSlash) { + return url.slice(0, -1); + } + return url; +} + export function copyToClipboard(s: string) { let field = document.createElement("textarea"); field.innerText = s; @@ -62,7 +77,14 @@ export function GetItemId(item: any): number { } export function buildOrderList(state: OrderIdState): T[] { - const { items, order } = state; + const { order, items } = state; + return buildOrderListFrom(items, order); +} + +export function buildOrderListFrom( + items: IdState, + order: (number | null)[] +): T[] { return order.flatMap((v) => { if (v !== null && v in items) { const item = items[v]; @@ -73,32 +95,6 @@ export function buildOrderList(state: OrderIdState): T[] { }); } -// Replace elements in old array with news -export function mergeArray( - olds: readonly T[], - news: readonly T[], - comparer: Comparer> -) { - const list = [...olds]; - const newList = news.filter((v) => !isNullable(v)) as NonNullable[]; - // Performance - newList.forEach((v) => { - const idx = list.findIndex((n, idx) => { - if (!isNullable(n)) { - return comparer(n, v); - } else { - return false; - } - }); - if (idx !== -1) { - list[idx] = v; - } else { - list.push(v); - } - }); - return list; -} - export function BuildKey(...args: any[]) { return args.join("-"); } diff --git a/frontend/src/utilites/logger.ts b/frontend/src/utilites/logger.ts index 708706e77..c32272d6d 100644 --- a/frontend/src/utilites/logger.ts +++ b/frontend/src/utilites/logger.ts @@ -11,3 +11,13 @@ export function log(type: LoggerType, msg: string, ...payload: any[]) { logger(`[${type}] ${msg}`, ...payload); } } + +export function conditionalLog( + condition: boolean, + msg: string, + ...payload: any[] +) { + if (condition) { + log("error", msg, payload); + } +} diff --git a/libs/bidict/__init__.py b/libs/bidict/__init__.py new file mode 100644 index 000000000..725e18750 --- /dev/null +++ b/libs/bidict/__init__.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Copyright 
2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +#============================================================================== +# * Welcome to the bidict source code * +#============================================================================== + +# Doing a code review? You'll find a "Code review nav" comment like the one +# below at the top and bottom of the most important source files. This provides +# a suggested initial path through the source when reviewing. +# +# Note: If you aren't reading this on https://github.com/jab/bidict, you may be +# viewing an outdated version of the code. Please head to GitHub to review the +# latest version, which contains important improvements over older versions. +# +# Thank you for reading and for any feedback you provide. + +# * Code review nav * +#============================================================================== +# Current: __init__.py Next: _abc.py → +#============================================================================== + + +""" +Efficient, Pythonic bidirectional map implementation and related functionality. + +.. code-block:: python + + >>> from bidict import bidict + >>> element_by_symbol = bidict({'H': 'hydrogen'}) + >>> element_by_symbol['H'] + 'hydrogen' + >>> element_by_symbol.inverse['hydrogen'] + 'H' + + +Please see https://github.com/jab/bidict for the most up-to-date code and +https://bidict.readthedocs.io for the most up-to-date documentation +if you are reading this elsewhere. + + +.. :copyright: (c) 2019 Joshua Bronson. +.. :license: MPLv2. See LICENSE for details. +""" + +# This __init__.py only collects functionality implemented in the rest of the +# source and exports it under the `bidict` module namespace (via `__all__`). 
+ +from ._abc import BidirectionalMapping +from ._base import BidictBase +from ._mut import MutableBidict +from ._bidict import bidict +from ._dup import DuplicationPolicy, IGNORE, OVERWRITE, RAISE +from ._exc import ( + BidictException, DuplicationError, + KeyDuplicationError, ValueDuplicationError, KeyAndValueDuplicationError) +from ._util import inverted +from ._frozenbidict import frozenbidict +from ._frozenordered import FrozenOrderedBidict +from ._named import namedbidict +from ._orderedbase import OrderedBidictBase +from ._orderedbidict import OrderedBidict +from .metadata import ( + __author__, __maintainer__, __copyright__, __email__, __credits__, __url__, + __license__, __status__, __description__, __keywords__, __version__, __version_info__) + + +__all__ = ( + '__author__', + '__maintainer__', + '__copyright__', + '__email__', + '__credits__', + '__license__', + '__status__', + '__description__', + '__keywords__', + '__url__', + '__version__', + '__version_info__', + 'BidirectionalMapping', + 'BidictException', + 'DuplicationPolicy', + 'IGNORE', + 'OVERWRITE', + 'RAISE', + 'DuplicationError', + 'KeyDuplicationError', + 'ValueDuplicationError', + 'KeyAndValueDuplicationError', + 'BidictBase', + 'MutableBidict', + 'frozenbidict', + 'bidict', + 'namedbidict', + 'FrozenOrderedBidict', + 'OrderedBidictBase', + 'OrderedBidict', + 'inverted', +) + + +# * Code review nav * +#============================================================================== +# Current: __init__.py Next: _abc.py → +#============================================================================== diff --git a/libs/bidict/_abc.py b/libs/bidict/_abc.py new file mode 100644 index 000000000..268b00480 --- /dev/null +++ b/libs/bidict/_abc.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +#============================================================================== +# * Welcome to the bidict source code * +#============================================================================== + +# Doing a code review? You'll find a "Code review nav" comment like the one +# below at the top and bottom of the most important source files. This provides +# a suggested initial path through the source when reviewing. +# +# Note: If you aren't reading this on https://github.com/jab/bidict, you may be +# viewing an outdated version of the code. Please head to GitHub to review the +# latest version, which contains important improvements over older versions. +# +# Thank you for reading and for any feedback you provide. + +# * Code review nav * +#============================================================================== +# ← Prev: __init__.py Current: _abc.py Next: _base.py → +#============================================================================== + + +"""Provides the :class:`BidirectionalMapping` abstract base class.""" + +from .compat import Mapping, abstractproperty, iteritems + + +class BidirectionalMapping(Mapping): # pylint: disable=abstract-method,no-init + """Abstract base class (ABC) for bidirectional mapping types. 
+ + Extends :class:`collections.abc.Mapping` primarily by adding the + (abstract) :attr:`inverse` property, + which implementors of :class:`BidirectionalMapping` + should override to return a reference to the inverse + :class:`BidirectionalMapping` instance. + """ + + __slots__ = () + + @abstractproperty + def inverse(self): + """The inverse of this bidirectional mapping instance. + + *See also* :attr:`bidict.BidictBase.inverse`, :attr:`bidict.BidictBase.inv` + + :raises NotImplementedError: Meant to be overridden in subclasses. + """ + # The @abstractproperty decorator prevents BidirectionalMapping subclasses from being + # instantiated unless they override this method. So users shouldn't be able to get to the + # point where they can unintentionally call this implementation of .inverse on something + # anyway. Could leave the method body empty, but raise NotImplementedError so it's extra + # clear there's no reason to call this implementation (e.g. via super() after overriding). + raise NotImplementedError + + def __inverted__(self): + """Get an iterator over the items in :attr:`inverse`. + + This is functionally equivalent to iterating over the items in the + forward mapping and inverting each one on the fly, but this provides a + more efficient implementation: Assuming the already-inverted items + are stored in :attr:`inverse`, just return an iterator over them directly. + + Providing this default implementation enables external functions, + particularly :func:`~bidict.inverted`, to use this optimized + implementation when available, instead of having to invert on the fly. + + *See also* :func:`bidict.inverted` + """ + return iteritems(self.inverse) + + +# * Code review nav * +#============================================================================== +# ← Prev: __init__.py Current: _abc.py Next: _base.py → +#============================================================================== diff --git a/libs/bidict/_base.py b/libs/bidict/_base.py new file mode 100644 index 000000000..b2a852df2 --- /dev/null +++ b/libs/bidict/_base.py @@ -0,0 +1,462 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +#============================================================================== +# * Welcome to the bidict source code * +#============================================================================== + +# Doing a code review? You'll find a "Code review nav" comment like the one +# below at the top and bottom of the most important source files. This provides +# a suggested initial path through the source when reviewing. +# +# Note: If you aren't reading this on https://github.com/jab/bidict, you may be +# viewing an outdated version of the code. Please head to GitHub to review the +# latest version, which contains important improvements over older versions. +# +# Thank you for reading and for any feedback you provide. 
+ +# * Code review nav * +#============================================================================== +# ← Prev: _abc.py Current: _base.py Next: _delegating_mixins.py → +#============================================================================== + + +"""Provides :class:`BidictBase`.""" + +from collections import namedtuple +from weakref import ref + +from ._abc import BidirectionalMapping +from ._dup import RAISE, OVERWRITE, IGNORE, _OnDup +from ._exc import ( + DuplicationError, KeyDuplicationError, ValueDuplicationError, KeyAndValueDuplicationError) +from ._miss import _MISS +from ._noop import _NOOP +from ._util import _iteritems_args_kw +from .compat import PY2, KeysView, ItemsView, Mapping, iteritems + + +_DedupResult = namedtuple('_DedupResult', 'isdupkey isdupval invbyval fwdbykey') +_WriteResult = namedtuple('_WriteResult', 'key val oldkey oldval') +_NODUP = _DedupResult(False, False, _MISS, _MISS) + + +class BidictBase(BidirectionalMapping): + """Base class implementing :class:`BidirectionalMapping`.""" + + __slots__ = ('_fwdm', '_invm', '_inv', '_invweak', '_hash') + (() if PY2 else ('__weakref__',)) + + #: The default :class:`DuplicationPolicy` + #: (in effect during e.g. :meth:`~bidict.bidict.__init__` calls) + #: that governs behavior when a provided item + #: duplicates only the key of another item. + #: + #: Defaults to :attr:`~bidict.OVERWRITE` + #: to match :class:`dict`'s behavior. + #: + #: *See also* :ref:`basic-usage:Values Must Be Unique`, :doc:`extending` + on_dup_key = OVERWRITE + + #: The default :class:`DuplicationPolicy` + #: (in effect during e.g. :meth:`~bidict.bidict.__init__` calls) + #: that governs behavior when a provided item + #: duplicates only the value of another item. + #: + #: Defaults to :attr:`~bidict.RAISE` + #: to prevent unintended overwrite of another item. + #: + #: *See also* :ref:`basic-usage:Values Must Be Unique`, :doc:`extending` + on_dup_val = RAISE + + #: The default :class:`DuplicationPolicy` + #: (in effect during e.g. :meth:`~bidict.bidict.__init__` calls) + #: that governs behavior when a provided item + #: duplicates the key of another item and the value of a third item. + #: + #: Defaults to ``None``, which causes the *on_dup_kv* policy to match + #: whatever *on_dup_val* policy is in effect. + #: + #: *See also* :ref:`basic-usage:Values Must Be Unique`, :doc:`extending` + on_dup_kv = None + + _fwdm_cls = dict + _invm_cls = dict + + #: The object used by :meth:`__repr__` for printing the contained items. + _repr_delegate = dict + + def __init__(self, *args, **kw): # pylint: disable=super-init-not-called + """Make a new bidirectional dictionary. + The signature is the same as that of regular dictionaries. + Items passed in are added in the order they are passed, + respecting the current duplication policies in the process. + + *See also* :attr:`on_dup_key`, :attr:`on_dup_val`, :attr:`on_dup_kv` + """ + #: The backing :class:`~collections.abc.Mapping` + #: storing the forward mapping data (*key* → *value*). + self._fwdm = self._fwdm_cls() + #: The backing :class:`~collections.abc.Mapping` + #: storing the inverse mapping data (*value* → *key*). + self._invm = self._invm_cls() + self._init_inv() # lgtm [py/init-calls-subclass] + if args or kw: + self._update(True, None, *args, **kw) + + def _init_inv(self): + # Compute the type for this bidict's inverse bidict (will be different from this + # bidict's type if _fwdm_cls and _invm_cls are different). 
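+        # (For example, a plain bidict has _fwdm_cls is _invm_cls is dict, so
+        # its inverse class is the class itself; only a hybrid type with two
+        # different backing mappings gets a dynamically generated "*Inv" class.)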
+ inv_cls = self._inv_cls() + # Create the inverse bidict instance via __new__, bypassing its __init__ so that its + # _fwdm and _invm can be assigned to this bidict's _invm and _fwdm. Store it in self._inv, + # which holds a strong reference to a bidict's inverse, if one is available. + self._inv = inv = inv_cls.__new__(inv_cls) + inv._fwdm = self._invm # pylint: disable=protected-access + inv._invm = self._fwdm # pylint: disable=protected-access + # Only give the inverse a weak reference to this bidict to avoid creating a reference cycle, + # stored in the _invweak attribute. See also the docs in + # :ref:`addendum:Bidict Avoids Reference Cycles` + inv._inv = None # pylint: disable=protected-access + inv._invweak = ref(self) # pylint: disable=protected-access + # Since this bidict has a strong reference to its inverse already, set its _invweak to None. + self._invweak = None + + @classmethod + def _inv_cls(cls): + """The inverse of this bidict type, i.e. one with *_fwdm_cls* and *_invm_cls* swapped.""" + if cls._fwdm_cls is cls._invm_cls: + return cls + if not getattr(cls, '_inv_cls_', None): + class _Inv(cls): + _fwdm_cls = cls._invm_cls + _invm_cls = cls._fwdm_cls + _inv_cls_ = cls + _Inv.__name__ = cls.__name__ + 'Inv' + cls._inv_cls_ = _Inv + return cls._inv_cls_ + + @property + def _isinv(self): + return self._inv is None + + @property + def inverse(self): + """The inverse of this bidict. + + *See also* :attr:`inv` + """ + # Resolve and return a strong reference to the inverse bidict. + # One may be stored in self._inv already. + if self._inv is not None: + return self._inv + # Otherwise a weakref is stored in self._invweak. Try to get a strong ref from it. + inv = self._invweak() + if inv is not None: + return inv + # Refcount of referent must have dropped to zero, as in `bidict().inv.inv`. Init a new one. + self._init_inv() # Now this bidict will retain a strong ref to its inverse. + return self._inv + + @property + def inv(self): + """Alias for :attr:`inverse`.""" + return self.inverse + + def __getstate__(self): + """Needed to enable pickling due to use of :attr:`__slots__` and weakrefs. + + *See also* :meth:`object.__getstate__` + """ + state = {} + for cls in self.__class__.__mro__: + slots = getattr(cls, '__slots__', ()) + for slot in slots: + if hasattr(self, slot): + state[slot] = getattr(self, slot) + # weakrefs can't be pickled. + state.pop('_invweak', None) # Added back in __setstate__ via _init_inv call. + state.pop('__weakref__', None) # Not added back in __setstate__. Python manages this one. + return state + + def __setstate__(self, state): + """Implemented because use of :attr:`__slots__` would prevent unpickling otherwise. + + *See also* :meth:`object.__setstate__` + """ + for slot, value in iteritems(state): + setattr(self, slot, value) + self._init_inv() + + def __repr__(self): + """See :func:`repr`.""" + clsname = self.__class__.__name__ + if not self: + return '%s()' % clsname + return '%s(%r)' % (clsname, self._repr_delegate(iteritems(self))) + + # The inherited Mapping.__eq__ implementation would work, but it's implemented in terms of an + # inefficient ``dict(self.items()) == dict(other.items())`` comparison, so override it with a + # more efficient implementation. + def __eq__(self, other): + u"""*x.__eq__(other) ⟺ x == other* + + Equivalent to *dict(x.items()) == dict(other.items())* + but more efficient. 
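+
+        For example (an illustrative doctest, not from the upstream file)::
+
+            >>> from bidict import bidict
+            >>> bidict({1: 'one'}) == {1: 'one'}
+            True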
+ + Note that :meth:`bidict's __eq__() ` implementation + is inherited by subclasses, + in particular by the ordered bidict subclasses, + so even with ordered bidicts, + :ref:`== comparison is order-insensitive `. + + *See also* :meth:`bidict.FrozenOrderedBidict.equals_order_sensitive` + """ + if not isinstance(other, Mapping) or len(self) != len(other): + return False + selfget = self.get + return all(selfget(k, _MISS) == v for (k, v) in iteritems(other)) + + # The following methods are mutating and so are not public. But they are implemented in this + # non-mutable base class (rather than the mutable `bidict` subclass) because they are used here + # during initialization (starting with the `_update` method). (Why is this? Because `__init__` + # and `update` share a lot of the same behavior (inserting the provided items while respecting + # the active duplication policies), so it makes sense for them to share implementation too.) + def _pop(self, key): + val = self._fwdm.pop(key) + del self._invm[val] + return val + + def _put(self, key, val, on_dup): + dedup_result = self._dedup_item(key, val, on_dup) + if dedup_result is not _NOOP: + self._write_item(key, val, dedup_result) + + def _dedup_item(self, key, val, on_dup): + """ + Check *key* and *val* for any duplication in self. + + Handle any duplication as per the duplication policies given in *on_dup*. + + (key, val) already present is construed as a no-op, not a duplication. + + If duplication is found and the corresponding duplication policy is + :attr:`~bidict.RAISE`, raise the appropriate error. + + If duplication is found and the corresponding duplication policy is + :attr:`~bidict.IGNORE`, return *None*. + + If duplication is found and the corresponding duplication policy is + :attr:`~bidict.OVERWRITE`, + or if no duplication is found, + return the _DedupResult *(isdupkey, isdupval, oldkey, oldval)*. + """ + fwdm = self._fwdm + invm = self._invm + oldval = fwdm.get(key, _MISS) + oldkey = invm.get(val, _MISS) + isdupkey = oldval is not _MISS + isdupval = oldkey is not _MISS + dedup_result = _DedupResult(isdupkey, isdupval, oldkey, oldval) + if isdupkey and isdupval: + if self._isdupitem(key, val, dedup_result): + # (key, val) duplicates an existing item -> no-op. + return _NOOP + # key and val each duplicate a different existing item. + if on_dup.kv is RAISE: + raise KeyAndValueDuplicationError(key, val) + elif on_dup.kv is IGNORE: + return _NOOP + assert on_dup.kv is OVERWRITE, 'invalid on_dup_kv: %r' % on_dup.kv + # Fall through to the return statement on the last line. + elif isdupkey: + if on_dup.key is RAISE: + raise KeyDuplicationError(key) + elif on_dup.key is IGNORE: + return _NOOP + assert on_dup.key is OVERWRITE, 'invalid on_dup.key: %r' % on_dup.key + # Fall through to the return statement on the last line. + elif isdupval: + if on_dup.val is RAISE: + raise ValueDuplicationError(val) + elif on_dup.val is IGNORE: + return _NOOP + assert on_dup.val is OVERWRITE, 'invalid on_dup.val: %r' % on_dup.val + # Fall through to the return statement on the last line. + # else neither isdupkey nor isdupval. 
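+        # Worked example (illustrative, not upstream code): with the class
+        # defaults and b = bidict({1: 'one'}), the item (2, 'one') duplicates
+        # only the value, so on_dup.val decides the outcome:
+        #   b[2] = 'one'                        -> ValueDuplicationError (RAISE)
+        #   b.put(2, 'one', on_dup_val=IGNORE)  -> no-op; b stays {1: 'one'}
+        #   b.forceput(2, 'one')                -> overwrite; b becomes {2: 'one'}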
+ return dedup_result + + @staticmethod + def _isdupitem(key, val, dedup_result): + isdupkey, isdupval, oldkey, oldval = dedup_result + isdupitem = oldkey == key + assert isdupitem == (oldval == val), '%r %r %r' % (key, val, dedup_result) + if isdupitem: + assert isdupkey + assert isdupval + return isdupitem + + @classmethod + def _get_on_dup(cls, on_dup=None): + if on_dup is None: + on_dup = _OnDup(cls.on_dup_key, cls.on_dup_val, cls.on_dup_kv) + elif not isinstance(on_dup, _OnDup): + on_dup = _OnDup(*on_dup) + if on_dup.kv is None: + on_dup = on_dup._replace(kv=on_dup.val) + return on_dup + + def _write_item(self, key, val, dedup_result): + isdupkey, isdupval, oldkey, oldval = dedup_result + fwdm = self._fwdm + invm = self._invm + fwdm[key] = val + invm[val] = key + if isdupkey: + del invm[oldval] + if isdupval: + del fwdm[oldkey] + return _WriteResult(key, val, oldkey, oldval) + + def _update(self, init, on_dup, *args, **kw): + # args[0] may be a generator that yields many items, so process input in a single pass. + if not args and not kw: + return + can_skip_dup_check = not self and not kw and isinstance(args[0], BidirectionalMapping) + if can_skip_dup_check: + self._update_no_dup_check(args[0]) + return + on_dup = self._get_on_dup(on_dup) + can_skip_rollback = init or RAISE not in on_dup + if can_skip_rollback: + self._update_no_rollback(on_dup, *args, **kw) + else: + self._update_with_rollback(on_dup, *args, **kw) + + def _update_no_dup_check(self, other, _nodup=_NODUP): + write_item = self._write_item + for (key, val) in iteritems(other): + write_item(key, val, _nodup) + + def _update_no_rollback(self, on_dup, *args, **kw): + put = self._put + for (key, val) in _iteritems_args_kw(*args, **kw): + put(key, val, on_dup) + + def _update_with_rollback(self, on_dup, *args, **kw): + """Update, rolling back on failure.""" + writelog = [] + appendlog = writelog.append + dedup_item = self._dedup_item + write_item = self._write_item + for (key, val) in _iteritems_args_kw(*args, **kw): + try: + dedup_result = dedup_item(key, val, on_dup) + except DuplicationError: + undo_write = self._undo_write + for dedup_result, write_result in reversed(writelog): + undo_write(dedup_result, write_result) + raise + if dedup_result is not _NOOP: + write_result = write_item(key, val, dedup_result) + appendlog((dedup_result, write_result)) + + def _undo_write(self, dedup_result, write_result): + isdupkey, isdupval, _, _ = dedup_result + key, val, oldkey, oldval = write_result + if not isdupkey and not isdupval: + self._pop(key) + return + fwdm = self._fwdm + invm = self._invm + if isdupkey: + fwdm[key] = oldval + invm[oldval] = key + if not isdupval: + del invm[val] + if isdupval: + invm[val] = oldkey + fwdm[oldkey] = val + if not isdupkey: + del fwdm[key] + + def copy(self): + """A shallow copy.""" + # Could just ``return self.__class__(self)`` here instead, but the below is faster. It uses + # __new__ to create a copy instance while bypassing its __init__, which would result + # in copying this bidict's items into the copy instance one at a time. Instead, make whole + # copies of each of the backing mappings, and make them the backing mappings of the copy, + # avoiding copying items one at a time. + copy = self.__class__.__new__(self.__class__) + copy._fwdm = self._fwdm.copy() # pylint: disable=protected-access + copy._invm = self._invm.copy() # pylint: disable=protected-access + copy._init_inv() # pylint: disable=protected-access + return copy + + def __copy__(self): + """Used for the copy protocol. 
+ + *See also* the :mod:`copy` module + """ + return self.copy() + + def __len__(self): + """The number of contained items.""" + return len(self._fwdm) + + def __iter__(self): # lgtm [py/inheritance/incorrect-overridden-signature] + """Iterator over the contained items.""" + # No default implementation for __iter__ inherited from Mapping -> + # always delegate to _fwdm. + return iter(self._fwdm) + + def __getitem__(self, key): + u"""*x.__getitem__(key) ⟺ x[key]*""" + return self._fwdm[key] + + def values(self): + """A set-like object providing a view on the contained values. + + Note that because the values of a :class:`~bidict.BidirectionalMapping` + are the keys of its inverse, + this returns a :class:`~collections.abc.KeysView` + rather than a :class:`~collections.abc.ValuesView`, + which has the advantages of constant-time containment checks + and supporting set operations. + """ + return self.inverse.keys() + + if PY2: + # For iterkeys and iteritems, inheriting from Mapping already provides + # the best default implementations so no need to define here. + + def itervalues(self): + """An iterator over the contained values.""" + return self.inverse.iterkeys() + + def viewkeys(self): # noqa: D102; pylint: disable=missing-docstring + return KeysView(self) + + def viewvalues(self): # noqa: D102; pylint: disable=missing-docstring + return self.inverse.viewkeys() + + viewvalues.__doc__ = values.__doc__ + values.__doc__ = 'A list of the contained values.' + + def viewitems(self): # noqa: D102; pylint: disable=missing-docstring + return ItemsView(self) + + # __ne__ added automatically in Python 3 when you implement __eq__, but not in Python 2. + def __ne__(self, other): # noqa: N802 + u"""*x.__ne__(other) ⟺ x != other*""" + return not self == other # Implement __ne__ in terms of __eq__. + + +# * Code review nav * +#============================================================================== +# ← Prev: _abc.py Current: _base.py Next: _delegating_mixins.py → +#============================================================================== diff --git a/libs/bidict/_bidict.py b/libs/bidict/_bidict.py new file mode 100644 index 000000000..9082775b6 --- /dev/null +++ b/libs/bidict/_bidict.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +#============================================================================== +# * Welcome to the bidict source code * +#============================================================================== + +# Doing a code review? You'll find a "Code review nav" comment like the one +# below at the top and bottom of the most important source files. This provides +# a suggested initial path through the source when reviewing. +# +# Note: If you aren't reading this on https://github.com/jab/bidict, you may be +# viewing an outdated version of the code. Please head to GitHub to review the +# latest version, which contains important improvements over older versions. +# +# Thank you for reading and for any feedback you provide. 
+ +# * Code review nav * +#============================================================================== +# ← Prev: _mut.py Current: _bidict.py Next: _orderedbase.py → +#============================================================================== + + +"""Provides :class:`bidict`.""" + +from ._mut import MutableBidict +from ._delegating_mixins import _DelegateKeysAndItemsToFwdm + + +class bidict(_DelegateKeysAndItemsToFwdm, MutableBidict): # noqa: N801,E501; pylint: disable=invalid-name + """Base class for mutable bidirectional mappings.""" + + __slots__ = () + + __hash__ = None # since this class is mutable; explicit > implicit. + + +# * Code review nav * +#============================================================================== +# ← Prev: _mut.py Current: _bidict.py Next: _orderedbase.py → +#============================================================================== diff --git a/libs/bidict/_delegating_mixins.py b/libs/bidict/_delegating_mixins.py new file mode 100644 index 000000000..8772490c7 --- /dev/null +++ b/libs/bidict/_delegating_mixins.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +#============================================================================== +# * Welcome to the bidict source code * +#============================================================================== + +# Doing a code review? You'll find a "Code review nav" comment like the one +# below at the top and bottom of the most important source files. This provides +# a suggested initial path through the source when reviewing. +# +# Note: If you aren't reading this on https://github.com/jab/bidict, you may be +# viewing an outdated version of the code. Please head to GitHub to review the +# latest version, which contains important improvements over older versions. +# +# Thank you for reading and for any feedback you provide. + +# * Code review nav * +#============================================================================== +# ← Prev: _base.py Current: _delegating_mixins.py Next: _frozenbidict.py → +#============================================================================== + + +r"""Provides mixin classes that delegate to ``self._fwdm`` for various operations. + +This allows methods such as :meth:`bidict.bidict.items` +to be implemented in terms of a ``self._fwdm.items()`` call, +which is potentially much more efficient (e.g. in CPython 2) +compared to the implementation inherited from :class:`~collections.abc.Mapping` +(which returns ``[(key, self[key]) for key in self]`` in Python 2). + +Because this depends on implementation details that aren't necessarily true +(such as the bidict's values being the same as its ``self._fwdm.values()``, +which is not true for e.g. ordered bidicts where ``_fwdm``\'s values are nodes), +these should always be mixed in at a layer below a more general layer, +as they are in e.g. :class:`~bidict.frozenbidict` +which extends :class:`~bidict.BidictBase`. + +See the :ref:`extending:Sorted Bidict Recipes` +for another example of where this comes into play. +``SortedBidict`` extends :class:`bidict.MutableBidict` +rather than :class:`bidict.bidict` +to avoid inheriting these mixins, +which are incompatible with the backing +:class:`sortedcontainers.SortedDict`s. 
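+
+For example (an illustrative doctest, on Python 3), a plain
+:class:`bidict.bidict` hands back its backing dict's key view directly::
+
+    >>> from bidict import bidict
+    >>> bidict({1: 'one'}).keys()
+    dict_keys([1])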
+""" + +from .compat import PY2 + + +_KEYS_METHODS = ('keys',) + (('viewkeys', 'iterkeys') if PY2 else ()) +_ITEMS_METHODS = ('items',) + (('viewitems', 'iteritems') if PY2 else ()) +_DOCSTRING_BY_METHOD = { + 'keys': 'A set-like object providing a view on the contained keys.', + 'items': 'A set-like object providing a view on the contained items.', +} +if PY2: + _DOCSTRING_BY_METHOD['viewkeys'] = _DOCSTRING_BY_METHOD['keys'] + _DOCSTRING_BY_METHOD['viewitems'] = _DOCSTRING_BY_METHOD['items'] + _DOCSTRING_BY_METHOD['keys'] = 'A list of the contained keys.' + _DOCSTRING_BY_METHOD['items'] = 'A list of the contained items.' + + +def _make_method(methodname): + def method(self): + return getattr(self._fwdm, methodname)() # pylint: disable=protected-access + method.__name__ = methodname + method.__doc__ = _DOCSTRING_BY_METHOD.get(methodname, '') + return method + + +def _make_fwdm_delegating_mixin(clsname, methodnames): + clsdict = dict({name: _make_method(name) for name in methodnames}, __slots__=()) + return type(clsname, (object,), clsdict) + + +_DelegateKeysToFwdm = _make_fwdm_delegating_mixin('_DelegateKeysToFwdm', _KEYS_METHODS) +_DelegateItemsToFwdm = _make_fwdm_delegating_mixin('_DelegateItemsToFwdm', _ITEMS_METHODS) +_DelegateKeysAndItemsToFwdm = type( + '_DelegateKeysAndItemsToFwdm', + (_DelegateKeysToFwdm, _DelegateItemsToFwdm), + {'__slots__': ()}) + +# * Code review nav * +#============================================================================== +# ← Prev: _base.py Current: _delegating_mixins.py Next: _frozenbidict.py → +#============================================================================== diff --git a/libs/bidict/_dup.py b/libs/bidict/_dup.py new file mode 100644 index 000000000..4670dcc57 --- /dev/null +++ b/libs/bidict/_dup.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +"""Provides bidict duplication policies and the :class:`_OnDup` class.""" + + +from collections import namedtuple + +from ._marker import _Marker + + +_OnDup = namedtuple('_OnDup', 'key val kv') + + +class DuplicationPolicy(_Marker): + """Base class for bidict's duplication policies. + + *See also* :ref:`basic-usage:Values Must Be Unique` + """ + + __slots__ = () + + +#: Raise an exception when a duplication is encountered. +RAISE = DuplicationPolicy('DUP_POLICY.RAISE') + +#: Overwrite an existing item when a duplication is encountered. +OVERWRITE = DuplicationPolicy('DUP_POLICY.OVERWRITE') + +#: Keep the existing item and ignore the new item when a duplication is encountered. +IGNORE = DuplicationPolicy('DUP_POLICY.IGNORE') diff --git a/libs/bidict/_exc.py b/libs/bidict/_exc.py new file mode 100644 index 000000000..5370361e6 --- /dev/null +++ b/libs/bidict/_exc.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ + +"""Provides all bidict exceptions.""" + + +class BidictException(Exception): + """Base class for bidict exceptions.""" + + +class DuplicationError(BidictException): + """Base class for exceptions raised when uniqueness is violated + as per the RAISE duplication policy. + """ + + +class KeyDuplicationError(DuplicationError): + """Raised when a given key is not unique.""" + + +class ValueDuplicationError(DuplicationError): + """Raised when a given value is not unique.""" + + +class KeyAndValueDuplicationError(KeyDuplicationError, ValueDuplicationError): + """Raised when a given item's key and value are not unique. + + That is, its key duplicates that of another item, + and its value duplicates that of a different other item. + """ diff --git a/libs/bidict/_frozenbidict.py b/libs/bidict/_frozenbidict.py new file mode 100644 index 000000000..07831fd91 --- /dev/null +++ b/libs/bidict/_frozenbidict.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +#============================================================================== +# * Welcome to the bidict source code * +#============================================================================== + +# Doing a code review? You'll find a "Code review nav" comment like the one +# below at the top and bottom of the most important source files. This provides +# a suggested initial path through the source when reviewing. +# +# Note: If you aren't reading this on https://github.com/jab/bidict, you may be +# viewing an outdated version of the code. Please head to GitHub to review the +# latest version, which contains important improvements over older versions. +# +# Thank you for reading and for any feedback you provide. + +# * Code review nav * +#============================================================================== +# ← Prev: _delegating_mixins.py Current: _frozenbidict.py Next: _mut.py → +#============================================================================== + +"""Provides :class:`frozenbidict`, an immutable, hashable bidirectional mapping type.""" + +from ._base import BidictBase +from ._delegating_mixins import _DelegateKeysAndItemsToFwdm +from .compat import ItemsView + + +class frozenbidict(_DelegateKeysAndItemsToFwdm, BidictBase): # noqa: N801,E501; pylint: disable=invalid-name + """Immutable, hashable bidict type.""" + + __slots__ = () + + def __hash__(self): # lgtm [py/equals-hash-mismatch] + """The hash of this bidict as determined by its items.""" + if getattr(self, '_hash', None) is None: + # pylint: disable=protected-access,attribute-defined-outside-init + self._hash = ItemsView(self)._hash() + return self._hash + + +# * Code review nav * +#============================================================================== +# ← Prev: _delegating_mixins.py Current: _frozenbidict.py Next: _mut.py → +#============================================================================== diff --git a/libs/bidict/_frozenordered.py b/libs/bidict/_frozenordered.py new file mode 100644 index 000000000..25cbace3b --- /dev/null +++ b/libs/bidict/_frozenordered.py @@ -0,0 +1,65 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. 
If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +#============================================================================== +# * Welcome to the bidict source code * +#============================================================================== + +# Doing a code review? You'll find a "Code review nav" comment like the one +# below at the top and bottom of the most important source files. This provides +# a suggested initial path through the source when reviewing. +# +# Note: If you aren't reading this on https://github.com/jab/bidict, you may be +# viewing an outdated version of the code. Please head to GitHub to review the +# latest version, which contains important improvements over older versions. +# +# Thank you for reading and for any feedback you provide. + +# * Code review nav * +#============================================================================== +#← Prev: _orderedbase.py Current: _frozenordered.py Next: _orderedbidict.py → +#============================================================================== + +"""Provides :class:`FrozenOrderedBidict`, an immutable, hashable, ordered bidict.""" + +from ._delegating_mixins import _DelegateKeysToFwdm +from ._frozenbidict import frozenbidict +from ._orderedbase import OrderedBidictBase +from .compat import DICTS_ORDERED, PY2, izip + + +# If the Python implementation's dict type is ordered (e.g. PyPy or CPython >= 3.6), then +# `FrozenOrderedBidict` can delegate to `_fwdm` for keys: Both `_fwdm` and `_invm` will always +# be initialized with the provided items in the correct order, and since `FrozenOrderedBidict` +# is immutable, their respective orders can't get out of sync after a mutation. (Can't delegate +# to `_fwdm` for items though because values in `_fwdm` are nodes.) +_BASES = ((_DelegateKeysToFwdm,) if DICTS_ORDERED else ()) + (OrderedBidictBase,) +_CLSDICT = dict( + __slots__=(), + # Must set __hash__ explicitly, Python prevents inheriting it. + # frozenbidict.__hash__ can be reused for FrozenOrderedBidict: + # FrozenOrderedBidict inherits BidictBase.__eq__ which is order-insensitive, + # and frozenbidict.__hash__ is consistent with BidictBase.__eq__. + __hash__=frozenbidict.__hash__.__func__ if PY2 else frozenbidict.__hash__, + __doc__='Hashable, immutable, ordered bidict type.', + __module__=__name__, # Otherwise unpickling fails in Python 2. +) + +# When PY2 (so we provide iteritems) and DICTS_ORDERED, e.g. on PyPy, the following implementation +# of iteritems may be more efficient than that inherited from `Mapping`. This exploits the property +# that the keys in `_fwdm` and `_invm` are already in the right order: +if PY2 and DICTS_ORDERED: + _CLSDICT['iteritems'] = lambda self: izip(self._fwdm, self._invm) # noqa: E501; pylint: disable=protected-access + +FrozenOrderedBidict = type('FrozenOrderedBidict', _BASES, _CLSDICT) # pylint: disable=invalid-name + + +# * Code review nav * +#============================================================================== +#← Prev: _orderedbase.py Current: _frozenordered.py Next: _orderedbidict.py → +#============================================================================== diff --git a/libs/bidict/_marker.py b/libs/bidict/_marker.py new file mode 100644 index 000000000..f2f9c57cb --- /dev/null +++ b/libs/bidict/_marker.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. 
+# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +"""Provides :class:`_Marker`, an internal type for representing singletons.""" + +from collections import namedtuple + + +class _Marker(namedtuple('_Marker', 'name')): + + __slots__ = () + + def __repr__(self): + return '<%s>' % self.name # pragma: no cover diff --git a/libs/bidict/_miss.py b/libs/bidict/_miss.py new file mode 100644 index 000000000..32d02c584 --- /dev/null +++ b/libs/bidict/_miss.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +"""Provides the :obj:`_MISS` sentinel, for internally signaling "missing/not found".""" + +from ._marker import _Marker + + +_MISS = _Marker('MISSING') diff --git a/libs/bidict/_mut.py b/libs/bidict/_mut.py new file mode 100644 index 000000000..1a117c2ab --- /dev/null +++ b/libs/bidict/_mut.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +#============================================================================== +# * Welcome to the bidict source code * +#============================================================================== + +# Doing a code review? You'll find a "Code review nav" comment like the one +# below at the top and bottom of the most important source files. This provides +# a suggested initial path through the source when reviewing. +# +# Note: If you aren't reading this on https://github.com/jab/bidict, you may be +# viewing an outdated version of the code. Please head to GitHub to review the +# latest version, which contains important improvements over older versions. +# +# Thank you for reading and for any feedback you provide. + +# * Code review nav * +#============================================================================== +# ← Prev: _frozenbidict.py Current: _mut.py Next: _bidict.py → +#============================================================================== + + +"""Provides :class:`bidict`.""" + +from ._base import BidictBase +from ._dup import OVERWRITE, RAISE, _OnDup +from ._miss import _MISS +from .compat import MutableMapping + + +# Extend MutableMapping explicitly because it doesn't implement __subclasshook__, as well as to +# inherit method implementations it provides that we can reuse (namely `setdefault`). +class MutableBidict(BidictBase, MutableMapping): + """Base class for mutable bidirectional mappings.""" + + __slots__ = () + + __hash__ = None # since this class is mutable; explicit > implicit. + + _ON_DUP_OVERWRITE = _OnDup(key=OVERWRITE, val=OVERWRITE, kv=OVERWRITE) + + def __delitem__(self, key): + u"""*x.__delitem__(y) ⟺ del x[y]*""" + self._pop(key) + + def __setitem__(self, key, val): + """ + Set the value for *key* to *val*. + + If *key* is already associated with *val*, this is a no-op. + + If *key* is already associated with a different value, + the old value will be replaced with *val*, + as with dict's :meth:`__setitem__`. 
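+
+        For example (an illustrative doctest)::
+
+            >>> b = bidict({1: 'one'})
+            >>> b[1] = 'uno'   # same key: old value is overwritten, like dict
+            >>> b
+            bidict({1: 'uno'})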
+ + If *val* is already associated with a different key, + an exception is raised + to protect against accidental removal of the key + that's currently associated with *val*. + + Use :meth:`put` instead if you want to specify different policy in + the case that the provided key or value duplicates an existing one. + Or use :meth:`forceput` to unconditionally associate *key* with *val*, + replacing any existing items as necessary to preserve uniqueness. + + :raises bidict.ValueDuplicationError: if *val* duplicates that of an + existing item. + + :raises bidict.KeyAndValueDuplicationError: if *key* duplicates the key of an + existing item and *val* duplicates the value of a different + existing item. + """ + on_dup = self._get_on_dup() + self._put(key, val, on_dup) + + def put(self, key, val, on_dup_key=RAISE, on_dup_val=RAISE, on_dup_kv=None): + """ + Associate *key* with *val* with the specified duplication policies. + + If *on_dup_kv* is ``None``, the *on_dup_val* policy will be used for it. + + For example, if all given duplication policies are :attr:`~bidict.RAISE`, + then *key* will be associated with *val* if and only if + *key* is not already associated with an existing value and + *val* is not already associated with an existing key, + otherwise an exception will be raised. + + If *key* is already associated with *val*, this is a no-op. + + :raises bidict.KeyDuplicationError: if attempting to insert an item + whose key only duplicates an existing item's, and *on_dup_key* is + :attr:`~bidict.RAISE`. + + :raises bidict.ValueDuplicationError: if attempting to insert an item + whose value only duplicates an existing item's, and *on_dup_val* is + :attr:`~bidict.RAISE`. + + :raises bidict.KeyAndValueDuplicationError: if attempting to insert an + item whose key duplicates one existing item's, and whose value + duplicates another existing item's, and *on_dup_kv* is + :attr:`~bidict.RAISE`. + """ + on_dup = self._get_on_dup((on_dup_key, on_dup_val, on_dup_kv)) + self._put(key, val, on_dup) + + def forceput(self, key, val): + """ + Associate *key* with *val* unconditionally. + + Replace any existing mappings containing key *key* or value *val* + as necessary to preserve uniqueness. + """ + self._put(key, val, self._ON_DUP_OVERWRITE) + + def clear(self): + """Remove all items.""" + self._fwdm.clear() + self._invm.clear() + + def pop(self, key, default=_MISS): + u"""*x.pop(k[, d]) → v* + + Remove specified key and return the corresponding value. + + :raises KeyError: if *key* is not found and no *default* is provided. + """ + try: + return self._pop(key) + except KeyError: + if default is _MISS: + raise + return default + + def popitem(self): + u"""*x.popitem() → (k, v)* + + Remove and return some item as a (key, value) pair. + + :raises KeyError: if *x* is empty. + """ + if not self: + raise KeyError('mapping is empty') + key, val = self._fwdm.popitem() + del self._invm[val] + return key, val + + def update(self, *args, **kw): + """Like :meth:`putall` with default duplication policies.""" + if args or kw: + self._update(False, None, *args, **kw) + + def forceupdate(self, *args, **kw): + """Like a bulk :meth:`forceput`.""" + self._update(False, self._ON_DUP_OVERWRITE, *args, **kw) + + def putall(self, items, on_dup_key=RAISE, on_dup_val=RAISE, on_dup_kv=None): + """ + Like a bulk :meth:`put`. + + If one of the given items causes an exception to be raised, + none of the items is inserted. 
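+
+        For example (an illustrative doctest)::
+
+            >>> from bidict import bidict, OVERWRITE
+            >>> b = bidict({1: 'one'})
+            >>> b.putall([(2, 'two'), (1, 'uno')], on_dup_key=OVERWRITE)
+            >>> sorted(b.items())
+            [(1, 'uno'), (2, 'two')]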
+ """ + if items: + on_dup = self._get_on_dup((on_dup_key, on_dup_val, on_dup_kv)) + self._update(False, on_dup, items) + + +# * Code review nav * +#============================================================================== +# ← Prev: _frozenbidict.py Current: _mut.py Next: _bidict.py → +#============================================================================== diff --git a/libs/bidict/_named.py b/libs/bidict/_named.py new file mode 100644 index 000000000..8748b98cc --- /dev/null +++ b/libs/bidict/_named.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +"""Provides :func:`bidict.namedbidict`.""" + +import re + +from ._abc import BidirectionalMapping +from ._bidict import bidict +from .compat import PY2 + + +_isidentifier = ( # pylint: disable=invalid-name + re.compile('[A-Za-z_][A-Za-z0-9_]*$').match if PY2 else str.isidentifier +) + + +def namedbidict(typename, keyname, valname, base_type=bidict): + r"""Create a new subclass of *base_type* with custom accessors. + + Analagous to :func:`collections.namedtuple`. + + The new class's ``__name__`` and ``__qualname__`` + will be set based on *typename*. + + Instances of it will provide access to their + :attr:`inverse `\s + via the custom *keyname*\_for property, + and access to themselves + via the custom *valname*\_for property. + + *See also* the :ref:`namedbidict usage documentation + ` + + :raises ValueError: if any of the *typename*, *keyname*, or *valname* + strings is not a valid Python identifier, or if *keyname == valname*. + + :raises TypeError: if *base_type* is not a subclass of + :class:`BidirectionalMapping`. + (This function requires slightly more of *base_type*, + e.g. the availability of an ``_isinv`` attribute, + but all the :ref:`concrete bidict types + ` + that the :mod:`bidict` module provides can be passed in. + Check out the code if you actually need to pass in something else.) + """ + # Re the `base_type` docs above: + # The additional requirements (providing _isinv and __getstate__) do not belong in the + # BidirectionalMapping interface, and it's overkill to create additional interface(s) for this. + # On the other hand, it's overkill to require that base_type be a subclass of BidictBase, since + # that's too specific. The BidirectionalMapping check along with the docs above should suffice. 
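+    # For example (an illustrative doctest):
+    #
+    #     >>> ElementMap = namedbidict('ElementMap', 'symbol', 'name')
+    #     >>> noble_gases = ElementMap(He='helium')
+    #     >>> noble_gases.name_for['He']
+    #     'helium'
+    #     >>> noble_gases.symbol_for['helium']
+    #     'He'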
+ if not issubclass(base_type, BidirectionalMapping): + raise TypeError(base_type) + names = (typename, keyname, valname) + if not all(map(_isidentifier, names)) or keyname == valname: + raise ValueError(names) + + class _Named(base_type): # pylint: disable=too-many-ancestors + + __slots__ = () + + def _getfwd(self): + return self.inverse if self._isinv else self + + def _getinv(self): + return self if self._isinv else self.inverse + + @property + def _keyname(self): + return valname if self._isinv else keyname + + @property + def _valname(self): + return keyname if self._isinv else valname + + def __reduce__(self): + return (_make_empty, (typename, keyname, valname, base_type), self.__getstate__()) + + bname = base_type.__name__ + fname = valname + '_for' + iname = keyname + '_for' + names = dict(typename=typename, bname=bname, keyname=keyname, valname=valname) + fdoc = u'{typename} forward {bname}: {keyname} → {valname}'.format(**names) + idoc = u'{typename} inverse {bname}: {valname} → {keyname}'.format(**names) + setattr(_Named, fname, property(_Named._getfwd, doc=fdoc)) # pylint: disable=protected-access + setattr(_Named, iname, property(_Named._getinv, doc=idoc)) # pylint: disable=protected-access + + if not PY2: + _Named.__qualname__ = _Named.__qualname__[:-len(_Named.__name__)] + typename + _Named.__name__ = typename + return _Named + + +def _make_empty(typename, keyname, valname, base_type): + """Create a named bidict with the indicated arguments and return an empty instance. + Used to make :func:`bidict.namedbidict` instances picklable. + """ + cls = namedbidict(typename, keyname, valname, base_type=base_type) + return cls() diff --git a/libs/bidict/_noop.py b/libs/bidict/_noop.py new file mode 100644 index 000000000..b045e8d72 --- /dev/null +++ b/libs/bidict/_noop.py @@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +"""Provides the :obj:`_NOOP` sentinel, for internally signaling "no-op".""" + +from ._marker import _Marker + + +_NOOP = _Marker('NO-OP') diff --git a/libs/bidict/_orderedbase.py b/libs/bidict/_orderedbase.py new file mode 100644 index 000000000..aa085a2d5 --- /dev/null +++ b/libs/bidict/_orderedbase.py @@ -0,0 +1,302 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +#============================================================================== +# * Welcome to the bidict source code * +#============================================================================== + +# Doing a code review? You'll find a "Code review nav" comment like the one +# below at the top and bottom of the most important source files. This provides +# a suggested initial path through the source when reviewing. +# +# Note: If you aren't reading this on https://github.com/jab/bidict, you may be +# viewing an outdated version of the code. Please head to GitHub to review the +# latest version, which contains important improvements over older versions. +# +# Thank you for reading and for any feedback you provide. 
+ +# * Code review nav * +#============================================================================== +# ← Prev: _bidict.py Current: _orderedbase.py Next: _frozenordered.py → +#============================================================================== + + +"""Provides :class:`OrderedBidictBase`.""" + +from weakref import ref + +from ._base import _WriteResult, BidictBase +from ._bidict import bidict +from ._miss import _MISS +from .compat import Mapping, PY2, iteritems, izip + + +class _Node(object): # pylint: disable=too-few-public-methods + """A node in a circular doubly-linked list + used to encode the order of items in an ordered bidict. + + Only weak references to the next and previous nodes + are held to avoid creating strong reference cycles. + + Because an ordered bidict retains two strong references + to each node instance (one from its backing `_fwdm` mapping + and one from its `_invm` mapping), a node's refcount will not + drop to zero (and so will not be garbage collected) as long as + the ordered bidict that contains it is still alive. + Because nodes don't have strong reference cycles, + once their containing bidict is freed, + they too are immediately freed. + """ + + __slots__ = ('_prv', '_nxt', '__weakref__') + + def __init__(self, prv=None, nxt=None): + self._setprv(prv) + self._setnxt(nxt) + + def __repr__(self): # pragma: no cover + clsname = self.__class__.__name__ + prv = id(self.prv) + nxt = id(self.nxt) + return '%s(prv=%s, self=%s, nxt=%s)' % (clsname, prv, id(self), nxt) + + def _getprv(self): + return self._prv() if isinstance(self._prv, ref) else self._prv + + def _setprv(self, prv): + self._prv = prv and ref(prv) + + prv = property(_getprv, _setprv) + + def _getnxt(self): + return self._nxt() if isinstance(self._nxt, ref) else self._nxt + + def _setnxt(self, nxt): + self._nxt = nxt and ref(nxt) + + nxt = property(_getnxt, _setnxt) + + def __getstate__(self): + """Return the instance state dictionary + but with weakrefs converted to strong refs + so that it can be pickled. + + *See also* :meth:`object.__getstate__` + """ + return dict(_prv=self.prv, _nxt=self.nxt) + + def __setstate__(self, state): + """Set the instance state from *state*.""" + self._setprv(state['_prv']) + self._setnxt(state['_nxt']) + + +class _Sentinel(_Node): # pylint: disable=too-few-public-methods + """Special node in a circular doubly-linked list + that links the first node with the last node. + When its next and previous references point back to itself + it represents an empty list. + """ + + __slots__ = () + + def __init__(self, prv=None, nxt=None): + super(_Sentinel, self).__init__(prv or self, nxt or self) + + def __repr__(self): # pragma: no cover + return '' + + def __bool__(self): + return False + + if PY2: + __nonzero__ = __bool__ + + def __iter__(self, reverse=False): + """Iterator yielding nodes in the requested order, + i.e. traverse the linked list via :attr:`nxt` + (or :attr:`prv` if *reverse* is truthy) + until reaching a falsy (i.e. sentinel) node. + """ + attr = 'prv' if reverse else 'nxt' + node = getattr(self, attr) + while node: + yield node + node = getattr(node, attr) + + +class OrderedBidictBase(BidictBase): + """Base class implementing an ordered :class:`BidirectionalMapping`.""" + + __slots__ = ('_sntl',) + + _fwdm_cls = bidict + _invm_cls = bidict + + #: The object used by :meth:`__repr__` for printing the contained items. + _repr_delegate = list + + def __init__(self, *args, **kw): + """Make a new ordered bidirectional mapping. 
+ The signature is the same as that of regular dictionaries. + Items passed in are added in the order they are passed, + respecting this bidict type's duplication policies along the way. + The order in which items are inserted is remembered, + similar to :class:`collections.OrderedDict`. + """ + self._sntl = _Sentinel() + + # Like unordered bidicts, ordered bidicts also store two backing one-directional mappings + # `_fwdm` and `_invm`. But rather than mapping `key` to `val` and `val` to `key` + # (respectively), they map `key` to `nodefwd` and `val` to `nodeinv` (respectively), where + # `nodefwd` is `nodeinv` when `key` and `val` are associated with one another. + + # To effect this difference, `_write_item` and `_undo_write` are overridden. But much of the + # rest of BidictBase's implementation, including BidictBase.__init__ and BidictBase._update, + # are inherited and are able to be reused without modification. + super(OrderedBidictBase, self).__init__(*args, **kw) + + def _init_inv(self): + super(OrderedBidictBase, self)._init_inv() + self.inverse._sntl = self._sntl # pylint: disable=protected-access + + # Can't reuse BidictBase.copy since ordered bidicts have different internal structure. + def copy(self): + """A shallow copy of this ordered bidict.""" + # Fast copy implementation bypassing __init__. See comments in :meth:`BidictBase.copy`. + copy = self.__class__.__new__(self.__class__) + sntl = _Sentinel() + fwdm = self._fwdm.copy() + invm = self._invm.copy() + cur = sntl + nxt = sntl.nxt + for (key, val) in iteritems(self): + nxt = _Node(cur, sntl) + cur.nxt = fwdm[key] = invm[val] = nxt + cur = nxt + sntl.prv = nxt + copy._sntl = sntl # pylint: disable=protected-access + copy._fwdm = fwdm # pylint: disable=protected-access + copy._invm = invm # pylint: disable=protected-access + copy._init_inv() # pylint: disable=protected-access + return copy + + def __getitem__(self, key): + nodefwd = self._fwdm[key] + val = self._invm.inverse[nodefwd] + return val + + def _pop(self, key): + nodefwd = self._fwdm.pop(key) + val = self._invm.inverse.pop(nodefwd) + nodefwd.prv.nxt = nodefwd.nxt + nodefwd.nxt.prv = nodefwd.prv + return val + + def _isdupitem(self, key, val, dedup_result): + """Return whether (key, val) duplicates an existing item.""" + isdupkey, isdupval, nodeinv, nodefwd = dedup_result + isdupitem = nodeinv is nodefwd + if isdupitem: + assert isdupkey + assert isdupval + return isdupitem + + def _write_item(self, key, val, dedup_result): # pylint: disable=too-many-locals + fwdm = self._fwdm # bidict mapping keys to nodes + invm = self._invm # bidict mapping vals to nodes + isdupkey, isdupval, nodeinv, nodefwd = dedup_result + if not isdupkey and not isdupval: + # No key or value duplication -> create and append a new node. + sntl = self._sntl + last = sntl.prv + node = _Node(last, sntl) + last.nxt = sntl.prv = fwdm[key] = invm[val] = node + oldkey = oldval = _MISS + elif isdupkey and isdupval: + # Key and value duplication across two different nodes. + assert nodefwd is not nodeinv + oldval = invm.inverse[nodefwd] + oldkey = fwdm.inverse[nodeinv] + assert oldkey != key + assert oldval != val + # We have to collapse nodefwd and nodeinv into a single node, i.e. drop one of them. + # Drop nodeinv, so that the item with the same key is the one overwritten in place. + nodeinv.prv.nxt = nodeinv.nxt + nodeinv.nxt.prv = nodeinv.prv + # Don't remove nodeinv's references to its neighbors since + # if the update fails, we'll need them to undo this write. + # Update fwdm and invm. 
+ tmp = fwdm.pop(oldkey) + assert tmp is nodeinv + tmp = invm.pop(oldval) + assert tmp is nodefwd + fwdm[key] = invm[val] = nodefwd + elif isdupkey: + oldval = invm.inverse[nodefwd] + oldkey = _MISS + oldnodeinv = invm.pop(oldval) + assert oldnodeinv is nodefwd + invm[val] = nodefwd + else: # isdupval + oldkey = fwdm.inverse[nodeinv] + oldval = _MISS + oldnodefwd = fwdm.pop(oldkey) + assert oldnodefwd is nodeinv + fwdm[key] = nodeinv + return _WriteResult(key, val, oldkey, oldval) + + def _undo_write(self, dedup_result, write_result): # pylint: disable=too-many-locals + fwdm = self._fwdm + invm = self._invm + isdupkey, isdupval, nodeinv, nodefwd = dedup_result + key, val, oldkey, oldval = write_result + if not isdupkey and not isdupval: + self._pop(key) + elif isdupkey and isdupval: + # Restore original items. + nodeinv.prv.nxt = nodeinv.nxt.prv = nodeinv + fwdm[oldkey] = invm[val] = nodeinv + invm[oldval] = fwdm[key] = nodefwd + elif isdupkey: + tmp = invm.pop(val) + assert tmp is nodefwd + invm[oldval] = nodefwd + assert fwdm[key] is nodefwd + else: # isdupval + tmp = fwdm.pop(key) + assert tmp is nodeinv + fwdm[oldkey] = nodeinv + assert invm[val] is nodeinv + + def __iter__(self, reverse=False): + """An iterator over this bidict's items in order.""" + fwdm_inv = self._fwdm.inverse + for node in self._sntl.__iter__(reverse=reverse): + yield fwdm_inv[node] + + def __reversed__(self): + """An iterator over this bidict's items in reverse order.""" + for key in self.__iter__(reverse=True): + yield key + + def equals_order_sensitive(self, other): + """Order-sensitive equality check. + + *See also* :ref:`eq-order-insensitive` + """ + # Same short-circuit as BidictBase.__eq__. Factoring out not worth function call overhead. + if not isinstance(other, Mapping) or len(self) != len(other): + return False + return all(i == j for (i, j) in izip(iteritems(self), iteritems(other))) + + +# * Code review nav * +#============================================================================== +# ← Prev: _bidict.py Current: _orderedbase.py Next: _frozenordered.py → +#============================================================================== diff --git a/libs/bidict/_orderedbidict.py b/libs/bidict/_orderedbidict.py new file mode 100644 index 000000000..874954838 --- /dev/null +++ b/libs/bidict/_orderedbidict.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +#============================================================================== +# * Welcome to the bidict source code * +#============================================================================== + +# Doing a code review? You'll find a "Code review nav" comment like the one +# below at the top and bottom of the most important source files. This provides +# a suggested initial path through the source when reviewing. +# +# Note: If you aren't reading this on https://github.com/jab/bidict, you may be +# viewing an outdated version of the code. Please head to GitHub to review the +# latest version, which contains important improvements over older versions. +# +# Thank you for reading and for any feedback you provide. 
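
Editor's note — illustration only, not part of this patch. In `_orderedbase.py` above, plain `==` stays order-insensitive while `equals_order_sensitive` also compares item order, and `__iter__`/`__reversed__` walk the sentinel-linked node list. A minimal sketch (same vendored-`bidict` import assumption as above):

    from bidict import OrderedBidict

    a = OrderedBidict([('one', 1), ('two', 2)])
    b = OrderedBidict([('two', 2), ('one', 1)])

    assert a == b                                # __eq__ ignores order
    assert not a.equals_order_sensitive(b)       # item order differs
    assert list(reversed(a)) == ['two', 'one']   # __reversed__ yields keys
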
+ +# * Code review nav * +#============================================================================== +# ← Prev: _frozenordered.py Current: _orderedbidict.py +#============================================================================== + + +"""Provides :class:`OrderedBidict`.""" + +from ._mut import MutableBidict +from ._orderedbase import OrderedBidictBase + + +class OrderedBidict(OrderedBidictBase, MutableBidict): + """Mutable bidict type that maintains items in insertion order.""" + + __slots__ = () + __hash__ = None # since this class is mutable; explicit > implicit. + + def clear(self): + """Remove all items.""" + self._fwdm.clear() + self._invm.clear() + self._sntl.nxt = self._sntl.prv = self._sntl + + def popitem(self, last=True): # pylint: disable=arguments-differ + u"""*x.popitem() → (k, v)* + + Remove and return the most recently added item as a (key, value) pair + if *last* is True, else the least recently added item. + + :raises KeyError: if *x* is empty. + """ + if not self: + raise KeyError('mapping is empty') + key = next((reversed if last else iter)(self)) + val = self._pop(key) + return key, val + + def move_to_end(self, key, last=True): + """Move an existing key to the beginning or end of this ordered bidict. + + The item is moved to the end if *last* is True, else to the beginning. + + :raises KeyError: if the key does not exist + """ + node = self._fwdm[key] + node.prv.nxt = node.nxt + node.nxt.prv = node.prv + sntl = self._sntl + if last: + last = sntl.prv + node.prv = last + node.nxt = sntl + sntl.prv = last.nxt = node + else: + first = sntl.nxt + node.prv = sntl + node.nxt = first + sntl.nxt = first.prv = node + + +# * Code review nav * +#============================================================================== +# ← Prev: _frozenordered.py Current: _orderedbidict.py +#============================================================================== diff --git a/libs/bidict/_util.py b/libs/bidict/_util.py new file mode 100644 index 000000000..89636e66c --- /dev/null +++ b/libs/bidict/_util.py @@ -0,0 +1,61 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +"""Useful functions for working with bidirectional mappings and related data.""" + +from itertools import chain, repeat + +from .compat import iteritems, Mapping + + +_NULL_IT = repeat(None, 0) # repeat 0 times -> raise StopIteration from the start + + +def _iteritems_mapping_or_iterable(arg): + """Yield the items in *arg*. + + If *arg* is a :class:`~collections.abc.Mapping`, return an iterator over its items. + Otherwise return an iterator over *arg* itself. + """ + return iteritems(arg) if isinstance(arg, Mapping) else iter(arg) + + +def _iteritems_args_kw(*args, **kw): + """Yield the items from the positional argument (if given) and then any from *kw*. + + :raises TypeError: if more than one positional argument is given. + """ + args_len = len(args) + if args_len > 1: + raise TypeError('Expected at most 1 positional argument, got %d' % args_len) + itemchain = None + if args: + arg = args[0] + if arg: + itemchain = _iteritems_mapping_or_iterable(arg) + if kw: + iterkw = iteritems(kw) + itemchain = chain(itemchain, iterkw) if itemchain else iterkw + return itemchain or _NULL_IT + + +def inverted(arg): + """Yield the inverse items of the provided object. 
+ + If *arg* has a :func:`callable` ``__inverted__`` attribute, + return the result of calling it. + + Otherwise, return an iterator over the items in `arg`, + inverting each item on the fly. + + *See also* :attr:`bidict.BidirectionalMapping.__inverted__` + """ + inv = getattr(arg, '__inverted__', None) + if callable(inv): + return inv() + return ((val, key) for (key, val) in _iteritems_mapping_or_iterable(arg)) diff --git a/libs/bidict/compat.py b/libs/bidict/compat.py new file mode 100644 index 000000000..dc095c920 --- /dev/null +++ b/libs/bidict/compat.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +# Copyright 2009-2019 Joshua Bronson. All Rights Reserved. +# +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + + +"""Compatibility helpers.""" + +from operator import methodcaller +from platform import python_implementation +from sys import version_info +from warnings import warn + + +# Use #: (before or) at the end of each line with a member we want to show up in the docs, +# otherwise Sphinx won't include (even though we configure automodule with undoc-members). + +PYMAJOR, PYMINOR = version_info[:2] #: +PY2 = PYMAJOR == 2 #: +PYIMPL = python_implementation() #: +CPY = PYIMPL == 'CPython' #: +PYPY = PYIMPL == 'PyPy' #: +DICTS_ORDERED = PYPY or (CPY and (PYMAJOR, PYMINOR) >= (3, 6)) #: + +# Without the following, pylint gives lots of false positives. +# pylint: disable=invalid-name,unused-import,ungrouped-imports,no-name-in-module + +if PY2: + if PYMINOR < 7: # pragma: no cover + raise ImportError('Python 2.7 or 3.5+ is required.') + warn('Python 2 support will be dropped in a future release.') + + # abstractproperty deprecated in Python 3.3 in favor of using @property with @abstractmethod. + # Before 3.3, this silently fails to detect when an abstract property has not been overridden. + from abc import abstractproperty #: + + from itertools import izip #: + + # In Python 3, the collections ABCs were moved into collections.abc, which does not exist in + # Python 2. Support for importing them directly from collections is dropped in Python 3.8. + import collections as collections_abc # noqa: F401 (imported but unused) + from collections import ( # noqa: F401 (imported but unused) + Mapping, MutableMapping, KeysView, ValuesView, ItemsView) + + viewkeys = lambda m: m.viewkeys() if hasattr(m, 'viewkeys') else KeysView(m) #: + viewvalues = lambda m: m.viewvalues() if hasattr(m, 'viewvalues') else ValuesView(m) #: + viewitems = lambda m: m.viewitems() if hasattr(m, 'viewitems') else ItemsView(m) #: + + iterkeys = lambda m: m.iterkeys() if hasattr(m, 'iterkeys') else iter(m.keys()) #: + itervalues = lambda m: m.itervalues() if hasattr(m, 'itervalues') else iter(m.values()) #: + iteritems = lambda m: m.iteritems() if hasattr(m, 'iteritems') else iter(m.items()) #: + +else: + # Assume Python 3 when not PY2, but explicitly check before showing this warning. 
+    if PYMAJOR == 3 and PYMINOR < 5:  # pragma: no cover
+        warn('Python 3.4 and below are not supported.')
+
+    import collections.abc as collections_abc  # noqa: F401 (imported but unused)
+    from collections.abc import (  # noqa: F401 (imported but unused)
+        Mapping, MutableMapping, KeysView, ValuesView, ItemsView)
+
+    viewkeys = methodcaller('keys')  #:
+    viewvalues = methodcaller('values')  #:
+    viewitems = methodcaller('items')  #:
+
+    def _compose(f, g):
+        return lambda x: f(g(x))
+
+    iterkeys = _compose(iter, viewkeys)  #:
+    itervalues = _compose(iter, viewvalues)  #:
+    iteritems = _compose(iter, viewitems)  #:
+
+    from abc import abstractmethod
+    abstractproperty = _compose(property, abstractmethod)  #:
+
+    izip = zip  #:
diff --git a/libs/bidict/metadata.py b/libs/bidict/metadata.py
new file mode 100644
index 000000000..95ec8af78
--- /dev/null
+++ b/libs/bidict/metadata.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+# Copyright 2009-2019 Joshua Bronson. All Rights Reserved.
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+"""Define bidict package metadata."""
+
+
+__version__ = '0.0.0.VERSION_NOT_FOUND'
+
+# _version.py is generated by setuptools_scm (via its `write_to` param, see setup.py)
+try:
+    from ._version import version as __version__  # pylint: disable=unused-import
+except (ImportError, ValueError, SystemError):  # pragma: no cover
+    try:
+        import pkg_resources
+    except ImportError:
+        pass
+    else:
+        try:
+            __version__ = pkg_resources.get_distribution('bidict').version
+        except pkg_resources.DistributionNotFound:
+            pass
+
+try:
+    __version_info__ = tuple(int(p) if i < 3 else p for (i, p) in enumerate(__version__.split('.')))
+except Exception:  # noqa: E722; pragma: no cover; pylint: disable=broad-except
+    __version_info__ = (0, 0, 0, 'PARSE FAILURE: __version__=%s' % __version__)
+
+__author__ = u'Joshua Bronson'
+__maintainer__ = u'Joshua Bronson'
+__copyright__ = u'Copyright 2019 Joshua Bronson'
+__email__ = u'jab@math.brown.edu'
+
+# See: ../docs/thanks.rst
+__credits__ = [i.strip() for i in u"""
+Joshua Bronson, Michael Arntzenius, Francis Carr, Gregory Ewing, Raymond Hettinger, Jozef Knaperek,
+Daniel Pope, Terry Reedy, David Turner, Tom Viner, Richard Sanger, Zeyi Wang
+""".split(u',')]
+
+__description__ = u'Efficient, Pythonic bidirectional map implementation and related functionality'
+__keywords__ = 'dict dictionary mapping datastructure bimap bijection bijective ' \
+               'injective inverse reverse bidirectional two-way 2-way'
+
+__license__ = u'MPL 2.0'
+__status__ = u'Beta'
+__url__ = u'https://bidict.readthedocs.io'
diff --git a/libs/engineio/__init__.py b/libs/engineio/__init__.py
index f2c5b774c..b897468d2 100644
--- a/libs/engineio/__init__.py
+++ b/libs/engineio/__init__.py
@@ -17,7 +17,7 @@ else:  # pragma: no cover
     get_tornado_handler = None
     ASGIApp = None
 
-__version__ = '3.11.2'
+__version__ = '4.0.2dev'
 
 __all__ = ['__version__', 'Server', 'WSGIApp', 'Middleware', 'Client']
 
 if AsyncServer is not None:  # pragma: no cover
diff --git a/libs/engineio/async_drivers/aiohttp.py b/libs/engineio/async_drivers/aiohttp.py
index ad6987649..a59199588 100644
--- a/libs/engineio/async_drivers/aiohttp.py
+++ b/libs/engineio/async_drivers/aiohttp.py
@@ -3,7 +3,6 @@ import sys
 from urllib.parse import urlsplit
 
 from aiohttp.web import Response, WebSocketResponse
-import six
 
 
 def create_route(app, engineio_server, engineio_endpoint):
@@ 
-113,8 +112,8 @@ class WebSocket(object): # pragma: no cover async def wait(self): msg = await self._sock.receive() - if not isinstance(msg.data, six.binary_type) and \ - not isinstance(msg.data, six.text_type): + if not isinstance(msg.data, bytes) and \ + not isinstance(msg.data, str): raise IOError() return msg.data diff --git a/libs/engineio/async_drivers/asgi.py b/libs/engineio/async_drivers/asgi.py index 9f14ef05f..eb3139b5e 100644 --- a/libs/engineio/async_drivers/asgi.py +++ b/libs/engineio/async_drivers/asgi.py @@ -1,5 +1,6 @@ import os import sys +import asyncio from engineio.static_files import get_static_file @@ -19,6 +20,10 @@ class ASGIApp: :param engineio_path: The endpoint where the Engine.IO application should be installed. The default value is appropriate for most cases. + :param on_startup: function to be called on application startup; can be + coroutine + :param on_shutdown: function to be called on application shutdown; can be + coroutine Example usage:: @@ -34,11 +39,14 @@ class ASGIApp: uvicorn.run(app, '127.0.0.1', 5000) """ def __init__(self, engineio_server, other_asgi_app=None, - static_files=None, engineio_path='engine.io'): + static_files=None, engineio_path='engine.io', + on_startup=None, on_shutdown=None): self.engineio_server = engineio_server self.other_asgi_app = other_asgi_app self.engineio_path = engineio_path.strip('/') self.static_files = static_files or {} + self.on_startup = on_startup + self.on_shutdown = on_shutdown async def __call__(self, scope, receive, send): if scope['type'] in ['http', 'websocket'] and \ @@ -73,11 +81,29 @@ class ASGIApp: await self.not_found(receive, send) async def lifespan(self, receive, send): - event = await receive() - if event['type'] == 'lifespan.startup': - await send({'type': 'lifespan.startup.complete'}) - elif event['type'] == 'lifespan.shutdown': - await send({'type': 'lifespan.shutdown.complete'}) + while True: + event = await receive() + if event['type'] == 'lifespan.startup': + if self.on_startup: + try: + await self.on_startup() \ + if asyncio.iscoroutinefunction(self.on_startup) \ + else self.on_startup() + except: + await send({'type': 'lifespan.startup.failed'}) + return + await send({'type': 'lifespan.startup.complete'}) + elif event['type'] == 'lifespan.shutdown': + if self.on_shutdown: + try: + await self.on_shutdown() \ + if asyncio.iscoroutinefunction(self.on_shutdown) \ + else self.on_shutdown() + except: + await send({'type': 'lifespan.shutdown.failed'}) + return + await send({'type': 'lifespan.shutdown.complete'}) + return async def not_found(self, receive, send): """Return a 404 Not Found error to the client.""" @@ -111,7 +137,7 @@ async def translate_request(scope, receive, send): if event['type'] == 'http.request': payload += event.get('body') or b'' elif event['type'] == 'websocket.connect': - await send({'type': 'websocket.accept'}) + pass else: return {} @@ -139,6 +165,7 @@ async def translate_request(scope, receive, send): 'SERVER_PORT': '0', 'asgi.receive': receive, 'asgi.send': send, + 'asgi.scope': scope, } for hdr_name, hdr_value in scope['headers']: @@ -163,6 +190,14 @@ async def translate_request(scope, receive, send): async def make_response(status, headers, payload, environ): headers = [(h[0].encode('utf-8'), h[1].encode('utf-8')) for h in headers] + if environ['asgi.scope']['type'] == 'websocket': + if status.startswith('200 '): + await environ['asgi.send']({'type': 'websocket.accept', + 'headers': headers}) + else: + await environ['asgi.send']({'type': 'websocket.close'}) + return + 
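
Editor's note — illustration only, not part of this patch. The `on_startup`/`on_shutdown` hooks added to `ASGIApp` earlier in this file are invoked from the new `lifespan` handler; either plain functions or coroutines are accepted. A usage sketch (handler names and port are illustrative):

    import engineio
    # import uvicorn  # any ASGI server works

    eio = engineio.AsyncServer(async_mode='asgi')

    async def startup():    # a coroutine is awaited
        print('lifespan.startup received')

    def shutdown():         # a plain function is called synchronously
        print('lifespan.shutdown received')

    app = engineio.ASGIApp(eio, on_startup=startup, on_shutdown=shutdown)
    # uvicorn.run(app, host='127.0.0.1', port=5000)
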
await environ['asgi.send']({'type': 'http.response.start', 'status': int(status.split(' ')[0]), 'headers': headers}) @@ -183,6 +218,7 @@ class WebSocket(object): # pragma: no cover async def __call__(self, environ): self.asgi_receive = environ['asgi.receive'] self.asgi_send = environ['asgi.send'] + await self.asgi_send({'type': 'websocket.accept'}) await self.handler(self) async def close(self): diff --git a/libs/engineio/async_drivers/gevent_uwsgi.py b/libs/engineio/async_drivers/gevent_uwsgi.py index 07fa2a79d..bdee812de 100644 --- a/libs/engineio/async_drivers/gevent_uwsgi.py +++ b/libs/engineio/async_drivers/gevent_uwsgi.py @@ -1,7 +1,5 @@ from __future__ import absolute_import -import six - import gevent from gevent import queue from gevent.event import Event @@ -75,7 +73,7 @@ class uWSGIWebSocket(object): # pragma: no cover def _send(self, msg): """Transmits message either in binary or UTF-8 text mode, depending on its type.""" - if isinstance(msg, six.binary_type): + if isinstance(msg, bytes): method = uwsgi.websocket_send_binary else: method = uwsgi.websocket_send @@ -86,11 +84,11 @@ class uWSGIWebSocket(object): # pragma: no cover def _decode_received(self, msg): """Returns either bytes or str, depending on message type.""" - if not isinstance(msg, six.binary_type): + if not isinstance(msg, bytes): # already decoded - do nothing return msg # only decode from utf-8 if message is not binary data - type = six.byte2int(msg[0:1]) + type = ord(msg[0:1]) if type >= 48: # no binary return msg.decode('utf-8') # binary message, don't try to decode diff --git a/libs/engineio/async_drivers/sanic.py b/libs/engineio/async_drivers/sanic.py index 6929654b9..e9555f310 100644 --- a/libs/engineio/async_drivers/sanic.py +++ b/libs/engineio/async_drivers/sanic.py @@ -1,16 +1,15 @@ import sys from urllib.parse import urlsplit -from sanic.response import HTTPResponse -try: +try: # pragma: no cover + from sanic.response import HTTPResponse from sanic.websocket import WebSocketProtocol except ImportError: - # the installed version of sanic does not have websocket support + HTTPResponse = None WebSocketProtocol = None -import six -def create_route(app, engineio_server, engineio_endpoint): +def create_route(app, engineio_server, engineio_endpoint): # pragma: no cover """This function sets up the engine.io endpoint as a route for the application. @@ -26,7 +25,7 @@ def create_route(app, engineio_server, engineio_endpoint): pass -def translate_request(request): +def translate_request(request): # pragma: no cover """This function takes the arguments passed to the request handler and uses them to generate a WSGI compatible environ dictionary. """ @@ -89,7 +88,7 @@ def translate_request(request): return environ -def make_response(status, headers, payload, environ): +def make_response(status, headers, payload, environ): # pragma: no cover """This function generates an appropriate response object for this async mode. 
""" @@ -100,7 +99,7 @@ def make_response(status, headers, payload, environ): content_type = h[1] else: headers_dict[h[0]] = h[1] - return HTTPResponse(body_bytes=payload, content_type=content_type, + return HTTPResponse(body=payload, content_type=content_type, status=int(status.split()[0]), headers=headers_dict) @@ -129,8 +128,8 @@ class WebSocket(object): # pragma: no cover async def wait(self): data = await self._sock.recv() - if not isinstance(data, six.binary_type) and \ - not isinstance(data, six.text_type): + if not isinstance(data, bytes) and \ + not isinstance(data, str): raise IOError() return data diff --git a/libs/engineio/async_drivers/tornado.py b/libs/engineio/async_drivers/tornado.py index adfe18f5a..eb1c4de8a 100644 --- a/libs/engineio/async_drivers/tornado.py +++ b/libs/engineio/async_drivers/tornado.py @@ -5,15 +5,13 @@ from .. import exceptions import tornado.web import tornado.websocket -import six def get_tornado_handler(engineio_server): class Handler(tornado.websocket.WebSocketHandler): # pragma: no cover def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if isinstance(engineio_server.cors_allowed_origins, - six.string_types): + if isinstance(engineio_server.cors_allowed_origins, str): if engineio_server.cors_allowed_origins == '*': self.allowed_origins = None else: @@ -170,8 +168,8 @@ class WebSocket(object): # pragma: no cover async def wait(self): msg = await self.tornado_handler.get_next_message() - if not isinstance(msg, six.binary_type) and \ - not isinstance(msg, six.text_type): + if not isinstance(msg, bytes) and \ + not isinstance(msg, str): raise IOError() return msg diff --git a/libs/engineio/asyncio_client.py b/libs/engineio/asyncio_client.py index 049b4bd95..4a11eb3b2 100644 --- a/libs/engineio/asyncio_client.py +++ b/libs/engineio/asyncio_client.py @@ -1,17 +1,36 @@ import asyncio +import signal import ssl +import threading try: import aiohttp except ImportError: # pragma: no cover aiohttp = None -import six from . import client from . import exceptions from . import packet from . import payload +async_signal_handler_set = False + + +def async_signal_handler(): + """SIGINT handler. + + Disconnect all active async clients. + """ + async def _handler(): + asyncio.get_event_loop().stop() + for c in client.connected_clients[:]: + if c.is_asyncio_based(): + await c.disconnect() + else: # pragma: no cover + pass + + asyncio.ensure_future(_handler()) + class AsyncClient(client.Client): """An Engine.IO client for asyncio. @@ -22,13 +41,18 @@ class AsyncClient(client.Client): :param logger: To enable logging set to ``True`` or pass a logger object to use. To disable logging set to ``False``. The default is - ``False``. + ``False``. Note that fatal errors are logged even when + ``logger`` is ``False``. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library versions. :param request_timeout: A timeout in seconds for requests. The default is 5 seconds. + :param http_session: an initialized ``aiohttp.ClientSession`` object to be + used when sending requests to the server. Use it if + you need to add special client options such as proxy + servers, SSL certificates, etc. :param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to skip SSL certificate verification, allowing connections to servers with self signed certificates. 
@@ -37,7 +61,7 @@ class AsyncClient(client.Client): def is_asyncio_based(self): return True - async def connect(self, url, headers={}, transports=None, + async def connect(self, url, headers=None, transports=None, engineio_path='engine.io'): """Connect to an Engine.IO server. @@ -60,11 +84,22 @@ class AsyncClient(client.Client): eio = engineio.Client() await eio.connect('http://localhost:5000') """ + global async_signal_handler_set + if not async_signal_handler_set and \ + threading.current_thread() == threading.main_thread(): + + try: + asyncio.get_event_loop().add_signal_handler( + signal.SIGINT, async_signal_handler) + async_signal_handler_set = True + except NotImplementedError: # pragma: no cover + self.logger.warning('Signal handler is unsupported') + if self.state != 'disconnected': raise ValueError('Client is not in a disconnected state') valid_transports = ['polling', 'websocket'] if transports is not None: - if isinstance(transports, six.text_type): + if isinstance(transports, str): transports = [transports] transports = [transport for transport in transports if transport in valid_transports] @@ -73,7 +108,7 @@ class AsyncClient(client.Client): self.transports = transports or valid_transports self.queue = self.create_queue() return await getattr(self, '_connect_' + self.transports[0])( - url, headers, engineio_path) + url, headers or {}, engineio_path) async def wait(self): """Wait until the connection with the server ends. @@ -86,21 +121,16 @@ class AsyncClient(client.Client): if self.read_loop_task: await self.read_loop_task - async def send(self, data, binary=None): + async def send(self, data): """Send a message to a client. :param data: The data to send to the client. Data can be of type ``str``, ``bytes``, ``list`` or ``dict``. If a ``list`` or ``dict``, the data will be serialized as JSON. - :param binary: ``True`` to send packet as binary, ``False`` to send - as text. If not given, unicode (Python 2) and str - (Python 3) are sent as text, and str (Python 2) and - bytes (Python 3) are sent as binary. Note: this method is a coroutine. """ - await self._send_packet(packet.Packet(packet.MESSAGE, data=data, - binary=binary)) + await self._send_packet(packet.Packet(packet.MESSAGE, data=data)) async def disconnect(self, abort=False): """Disconnect from the server. 
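
Editor's note — illustration only, not part of this patch. With the `binary` argument removed from `send()` above, the payload type alone now selects text vs. binary framing. A sketch:

    import engineio

    async def demo(eio: engineio.AsyncClient):
        await eio.send('hello')         # str is sent as a text packet
        await eio.send(b'\x01\x02')     # bytes are sent as a binary packet
        await eio.send({'op': 'sync'})  # list/dict payloads are serialized as JSON
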
@@ -182,14 +212,20 @@ class AsyncClient(client.Client):
             raise exceptions.ConnectionError(
                 'Connection refused by the server')
         if r.status < 200 or r.status >= 300:
+            self._reset()
+            try:
+                arg = await r.json()
+            except aiohttp.ClientError:
+                arg = None
             raise exceptions.ConnectionError(
                 'Unexpected status code {} in server response'.format(
-                    r.status))
+                    r.status), arg)
         try:
-            p = payload.Payload(encoded_payload=await r.read())
+            p = payload.Payload(encoded_payload=(await r.read()).decode(
+                'utf-8'))
         except ValueError:
-            six.raise_from(exceptions.ConnectionError(
-                'Unexpected response from server'), None)
+            raise exceptions.ConnectionError(
+                'Unexpected response from server') from None
         open_packet = p.packets[0]
         if open_packet.packet_type != packet.OPEN:
             raise exceptions.ConnectionError(
@@ -198,8 +234,8 @@
             'Polling connection accepted with ' + str(open_packet.data))
         self.sid = open_packet.data['sid']
         self.upgrades = open_packet.data['upgrades']
-        self.ping_interval = open_packet.data['pingInterval'] / 1000.0
-        self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0
+        self.ping_interval = int(open_packet.data['pingInterval']) / 1000.0
+        self.ping_timeout = int(open_packet.data['pingTimeout']) / 1000.0
         self.current_transport = 'polling'
 
         self.base_url += '&sid=' + self.sid
@@ -216,7 +252,6 @@
             # upgrade to websocket succeeded, we're done here
             return
 
-        self.ping_loop_task = self.start_background_task(self._ping_loop)
         self.write_loop_task = self.start_background_task(self._write_loop)
         self.read_loop_task = self.start_background_task(
             self._read_loop_polling)
@@ -242,6 +277,17 @@
         if self.http is None or self.http.closed:  # pragma: no cover
             self.http = aiohttp.ClientSession()
 
+        # extract any new cookies passed in a header so that they can also be
+        # sent to the WebSocket route
+        cookies = {}
+        for header, value in headers.items():
+            if header.lower() == 'cookie':
+                cookies = dict(
+                    [cookie.split('=', 1) for cookie in value.split('; ')])
+                del headers[header]
+                break
+        self.http.cookie_jar.update_cookies(cookies)
+
         try:
             if not self.ssl_verify:
                 ssl_context = ssl.create_default_context()
@@ -255,7 +301,8 @@
                 websocket_url + self._get_url_timestamp(), headers=headers)
         except (aiohttp.client_exceptions.WSServerHandshakeError,
-                aiohttp.client_exceptions.ServerConnectionError):
+                aiohttp.client_exceptions.ServerConnectionError,
+                aiohttp.client_exceptions.ClientConnectionError):
             if upgrade:
                 self.logger.warning(
                     'WebSocket upgrade failed: connection error')
@@ -263,8 +310,7 @@
             else:
                 raise exceptions.ConnectionError('Connection error')
         if upgrade:
-            p = packet.Packet(packet.PING, data='probe').encode(
-                always_bytes=False)
+            p = packet.Packet(packet.PING, data='probe').encode()
             try:
                 await ws.send_str(p)
             except Exception as e:  # pragma: no cover
@@ -284,7 +330,7 @@
                     self.logger.warning(
                         'WebSocket upgrade failed: no PONG packet')
                     return False
-            p = packet.Packet(packet.UPGRADE).encode(always_bytes=False)
+            p = packet.Packet(packet.UPGRADE).encode()
             try:
                 await ws.send_str(p)
             except Exception as e:  # pragma: no cover
@@ -307,8 +353,8 @@
                 'WebSocket connection accepted with ' + str(open_packet.data))
             self.sid = open_packet.data['sid']
             self.upgrades = open_packet.data['upgrades']
-            self.ping_interval = open_packet.data['pingInterval'] / 1000.0
-            self.ping_timeout = 
open_packet.data['pingTimeout'] / 1000.0 + self.ping_interval = int(open_packet.data['pingInterval']) / 1000.0 + self.ping_timeout = int(open_packet.data['pingTimeout']) / 1000.0 self.current_transport = 'websocket' self.state = 'connected' @@ -316,7 +362,6 @@ class AsyncClient(client.Client): await self._trigger_event('connect', run_async=False) self.ws = ws - self.ping_loop_task = self.start_background_task(self._ping_loop) self.write_loop_task = self.start_background_task(self._write_loop) self.read_loop_task = self.start_background_task( self._read_loop_websocket) @@ -331,8 +376,8 @@ class AsyncClient(client.Client): pkt.data if not isinstance(pkt.data, bytes) else '') if pkt.packet_type == packet.MESSAGE: await self._trigger_event('message', pkt.data, run_async=True) - elif pkt.packet_type == packet.PONG: - self.pong_received = True + elif pkt.packet_type == packet.PING: + await self._send_packet(packet.Packet(packet.PONG, pkt.data)) elif pkt.packet_type == packet.CLOSE: await self.disconnect(abort=True) elif pkt.packet_type == packet.NOOP: @@ -409,33 +454,6 @@ class AsyncClient(client.Client): return False return ret - async def _ping_loop(self): - """This background task sends a PING to the server at the requested - interval. - """ - self.pong_received = True - if self.ping_loop_event is None: - self.ping_loop_event = self.create_event() - else: - self.ping_loop_event.clear() - while self.state == 'connected': - if not self.pong_received: - self.logger.info( - 'PONG response has not been received, aborting') - if self.ws: - await self.ws.close() - await self.queue.put(None) - break - self.pong_received = False - await self._send_packet(packet.Packet(packet.PING)) - try: - await asyncio.wait_for(self.ping_loop_event.wait(), - self.ping_interval) - except (asyncio.TimeoutError, - asyncio.CancelledError): # pragma: no cover - pass - self.logger.info('Exiting ping task') - async def _read_loop_polling(self): """Read packets by polling the Engine.IO server.""" while self.state == 'connected': @@ -455,7 +473,8 @@ class AsyncClient(client.Client): await self.queue.put(None) break try: - p = payload.Payload(encoded_payload=await r.read()) + p = payload.Payload(encoded_payload=(await r.read()).decode( + 'utf-8')) except ValueError: self.logger.warning( 'Unexpected packet from server, aborting') @@ -466,10 +485,6 @@ class AsyncClient(client.Client): self.logger.info('Waiting for write loop task to end') await self.write_loop_task - self.logger.info('Waiting for ping loop task to end') - if self.ping_loop_event: # pragma: no cover - self.ping_loop_event.set() - await self.ping_loop_task if self.state == 'connected': await self._trigger_event('disconnect', run_async=False) try: @@ -484,9 +499,18 @@ class AsyncClient(client.Client): while self.state == 'connected': p = None try: - p = (await self.ws.receive()).data + p = await asyncio.wait_for( + self.ws.receive(), + timeout=self.ping_interval + self.ping_timeout) + p = p.data if p is None: # pragma: no cover - raise RuntimeError('WebSocket read returned None') + await self.queue.put(None) + break # the connection is broken + except asyncio.TimeoutError: + self.logger.warning( + 'Server has stopped communicating, aborting') + await self.queue.put(None) + break except aiohttp.client_exceptions.ServerDisconnectedError: self.logger.info( 'Read loop: WebSocket connection was closed, aborting') @@ -494,20 +518,21 @@ class AsyncClient(client.Client): break except Exception as e: self.logger.info( - 'Unexpected error "%s", aborting', str(e)) + 
'Unexpected error receiving packet: "%s", aborting',
+                    str(e))
+                await self.queue.put(None)
+                break
+            try:
+                pkt = packet.Packet(encoded_packet=p)
+            except Exception as e:  # pragma: no cover
+                self.logger.info(
+                    'Unexpected error decoding packet: "%s", aborting', str(e))
                 await self.queue.put(None)
                 break
-            if isinstance(p, six.text_type):  # pragma: no cover
-                p = p.encode('utf-8')
-            pkt = packet.Packet(encoded_packet=p)
             await self._receive_packet(pkt)
         self.logger.info('Waiting for write loop task to end')
         await self.write_loop_task
-        self.logger.info('Waiting for ping loop task to end')
-        if self.ping_loop_event:  # pragma: no cover
-            self.ping_loop_event.set()
-            await self.ping_loop_task
         if self.state == 'connected':
             await self._trigger_event('disconnect', run_async=False)
             try:
@@ -571,13 +596,12 @@
             try:
                 for pkt in packets:
                     if pkt.binary:
-                        await self.ws.send_bytes(pkt.encode(
-                            always_bytes=False))
+                        await self.ws.send_bytes(pkt.encode())
                     else:
-                        await self.ws.send_str(pkt.encode(
-                            always_bytes=False))
+                        await self.ws.send_str(pkt.encode())
                 self.queue.task_done()
-            except aiohttp.client_exceptions.ServerDisconnectedError:
+            except (aiohttp.client_exceptions.ServerDisconnectedError,
+                    BrokenPipeError, OSError):
                 self.logger.info(
                     'Write loop: WebSocket connection was closed, '
                     'aborting')
diff --git a/libs/engineio/asyncio_server.py b/libs/engineio/asyncio_server.py
index d52b556db..6639f26bf 100644
--- a/libs/engineio/asyncio_server.py
+++ b/libs/engineio/asyncio_server.py
@@ -1,7 +1,5 @@
 import asyncio
-
-import six
-from six.moves import urllib
+import urllib
 
 from . import exceptions
 from . import packet
@@ -24,23 +22,30 @@ class AsyncServer(server.Server):
                        "tornado", and finally "asgi". The first async mode that
                        has all its dependencies installed is the one that is
                        chosen.
-    :param ping_timeout: The time in seconds that the client waits for the
-                         server to respond before disconnecting.
-    :param ping_interval: The interval in seconds at which the client pings
-                          the server. The default is 25 seconds. For advanced
+    :param ping_interval: The interval in seconds at which the server pings
+                          the client. The default is 25 seconds. For advanced
                           control, a two element tuple can be given, where
                           the first number is the ping interval and the second
-                          is a grace period added by the server. The default
-                          grace period is 5 seconds.
+                          is a grace period added by the server.
+    :param ping_timeout: The time in seconds that the client waits for the
+                         server to respond before disconnecting. The default
+                         is 5 seconds.
     :param max_http_buffer_size: The maximum size of a message when using the
-                                 polling transport.
+                                 polling transport. The default is 1,000,000
+                                 bytes.
    :param allow_upgrades: Whether to allow transport upgrades or not.
    :param http_compression: Whether to compress packages when using the
                             polling transport.
    :param compression_threshold: Only compress messages when their byte size
                                  is greater than this value.
-    :param cookie: Name of the HTTP cookie that contains the client session
-                   id. If set to ``None``, a cookie is not sent to the client.
+    :param cookie: If set to a string, it is the name of the HTTP cookie the
+                   server sends back to the client containing the client
+                   session id. If set to a dictionary, the ``'name'`` key
+                   contains the cookie name and other keys define cookie
+                   attributes, where the value of each attribute can be a
+                   string, a callable with no arguments, or a boolean. If set
+                   to ``None`` (the default), a cookie is not sent to the
+                   client.
:param cors_allowed_origins: Origin or list of origins that are allowed to connect to this server. Only the same origin is allowed by default. Set this argument to @@ -49,7 +54,8 @@ class AsyncServer(server.Server): :param cors_credentials: Whether credentials (cookies, authentication) are allowed in requests to this server. :param logger: To enable logging set to ``True`` or pass a logger object to - use. To disable logging set to ``False``. + use. To disable logging set to ``False``. Note that fatal + errors are logged even when ``logger`` is ``False``. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library @@ -71,17 +77,13 @@ class AsyncServer(server.Server): engineio_path = engineio_path.strip('/') self._async['create_route'](app, self, '/{}/'.format(engineio_path)) - async def send(self, sid, data, binary=None): + async def send(self, sid, data): """Send a message to a client. :param sid: The session id of the recipient client. :param data: The data to send to the client. Data can be of type ``str``, ``bytes``, ``list`` or ``dict``. If a ``list`` or ``dict``, the data will be serialized as JSON. - :param binary: ``True`` to send packet as binary, ``False`` to send - as text. If not given, unicode (Python 2) and str - (Python 3) are sent as text, and str (Python 2) and - bytes (Python 3) are sent as binary. Note: this method is a coroutine. """ @@ -91,8 +93,7 @@ class AsyncServer(server.Server): # the socket is not available self.logger.warning('Cannot send to sid %s', sid) return - await socket.send(packet.Packet(packet.MESSAGE, data=data, - binary=binary)) + await socket.send(packet.Packet(packet.MESSAGE, data=data)) async def get_session(self, sid): """Return the user session for a client. 
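
Editor's note — illustration only, not part of this patch. The reworked `cookie` option described in the docstring above accepts either a plain cookie name or a dictionary of attributes. A sketch (attribute values are illustrative):

    import engineio

    # A string configures just the name; _handle_connect (below) fills in
    # path='/' and SameSite=Lax.
    eio = engineio.AsyncServer(async_mode='asgi', cookie='io')

    # A dictionary gives full control; per the docstring, each attribute
    # value may be a string, a zero-argument callable, or a boolean.
    eio = engineio.AsyncServer(async_mode='asgi', cookie={
        'name': 'io',
        'path': '/engine.io',
        'SameSite': 'Strict',
        'secure': True,
    })
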
@@ -172,7 +173,7 @@ class AsyncServer(server.Server): del self.sockets[sid] else: await asyncio.wait([client.close() - for client in six.itervalues(self.sockets)]) + for client in self.sockets.values()]) self.sockets = {} async def handle_request(self, *args, **kwargs): @@ -198,28 +199,32 @@ class AsyncServer(server.Server): allowed_origins = self._cors_allowed_origins(environ) if allowed_origins is not None and origin not in \ allowed_origins: - self.logger.info(origin + ' is not an accepted origin.') - r = self._bad_request() - make_response = self._async['make_response'] - if asyncio.iscoroutinefunction(make_response): - response = await make_response( - r['status'], r['headers'], r['response'], environ) - else: - response = make_response(r['status'], r['headers'], - r['response'], environ) - return response + self._log_error_once( + origin + ' is not an accepted origin.', 'bad-origin') + return await self._make_response( + self._bad_request( + origin + ' is not an accepted origin.'), + environ) method = environ['REQUEST_METHOD'] query = urllib.parse.parse_qs(environ.get('QUERY_STRING', '')) sid = query['sid'][0] if 'sid' in query else None - b64 = False jsonp = False jsonp_index = None - if 'b64' in query: - if query['b64'][0] == "1" or query['b64'][0].lower() == "true": - b64 = True + # make sure the client speaks a compatible Engine.IO version + sid = query['sid'][0] if 'sid' in query else None + if sid is None and query.get('EIO') != ['4']: + self._log_error_once( + 'The client is using an unsupported version of the Socket.IO ' + 'or Engine.IO protocols', 'bad-version' + ) + return await self._make_response(self._bad_request( + 'The client is using an unsupported version of the Socket.IO ' + 'or Engine.IO protocols' + ), environ) + if 'j' in query: jsonp = True try: @@ -229,28 +234,34 @@ class AsyncServer(server.Server): pass if jsonp and jsonp_index is None: - self.logger.warning('Invalid JSONP index number') - r = self._bad_request() + self._log_error_once('Invalid JSONP index number', + 'bad-jsonp-index') + r = self._bad_request('Invalid JSONP index number') elif method == 'GET': if sid is None: transport = query.get('transport', ['polling'])[0] - if transport != 'polling' and transport != 'websocket': - self.logger.warning('Invalid transport %s', transport) - r = self._bad_request() - else: + # transport must be one of 'polling' or 'websocket'. + # if 'websocket', the HTTP_UPGRADE header must match. 
+ upgrade_header = environ.get('HTTP_UPGRADE').lower() \ + if 'HTTP_UPGRADE' in environ else None + if transport == 'polling' \ + or transport == upgrade_header == 'websocket': r = await self._handle_connect(environ, transport, - b64, jsonp_index) + jsonp_index) + else: + self._log_error_once('Invalid transport ' + transport, + 'bad-transport') + r = self._bad_request('Invalid transport ' + transport) else: if sid not in self.sockets: - self.logger.warning('Invalid session %s', sid) - r = self._bad_request() + self._log_error_once('Invalid session ' + sid, 'bad-sid') + r = self._bad_request('Invalid session ' + sid) else: socket = self._get_socket(sid) try: packets = await socket.handle_get_request(environ) if isinstance(packets, list): - r = self._ok(packets, b64=b64, - jsonp_index=jsonp_index) + r = self._ok(packets, jsonp_index=jsonp_index) else: r = packets except exceptions.EngineIOError: @@ -261,8 +272,8 @@ class AsyncServer(server.Server): del self.sockets[sid] elif method == 'POST': if sid is None or sid not in self.sockets: - self.logger.warning('Invalid session %s', sid) - r = self._bad_request() + self._log_error_once('Invalid session ' + sid, 'bad-sid') + r = self._bad_request('Invalid session ' + sid) else: socket = self._get_socket(sid) try: @@ -294,16 +305,7 @@ class AsyncServer(server.Server): getattr(self, '_' + encoding)(r['response']) r['headers'] += [('Content-Encoding', encoding)] break - cors_headers = self._cors_headers(environ) - make_response = self._async['make_response'] - if asyncio.iscoroutinefunction(make_response): - response = await make_response(r['status'], - r['headers'] + cors_headers, - r['response'], environ) - else: - response = make_response(r['status'], r['headers'] + cors_headers, - r['response'], environ) - return response + return await self._make_response(r, environ) def start_background_task(self, target, *args, **kwargs): """Start a background task using the appropriate async model. 
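
Editor's note — illustration only, not part of this patch. The version gate added to `handle_request` above rejects any connection attempt that does not carry `EIO=4` in its query string, before a session is created. A sketch of the first polling exchange, assuming a server built from this patch is listening on localhost:5000:

    import requests

    base = 'http://localhost:5000/engine.io/'

    # A v4 client opens with EIO=4 and the polling transport:
    r = requests.get(base, params={'EIO': '4', 'transport': 'polling'})
    assert r.status_code == 200   # OPEN packet with sid, upgrades, ping settings

    # A legacy (EIO=3) client is now turned away with a bad-request response:
    r = requests.get(base, params={'EIO': '3', 'transport': 'polling'})
    assert r.status_code == 400
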
@@ -362,15 +364,29 @@ class AsyncServer(server.Server): """ return asyncio.Event(*args, **kwargs) - async def _handle_connect(self, environ, transport, b64=False, - jsonp_index=None): + async def _make_response(self, response_dict, environ): + cors_headers = self._cors_headers(environ) + make_response = self._async['make_response'] + if asyncio.iscoroutinefunction(make_response): + response = await make_response( + response_dict['status'], + response_dict['headers'] + cors_headers, + response_dict['response'], environ) + else: + response = make_response( + response_dict['status'], + response_dict['headers'] + cors_headers, + response_dict['response'], environ) + return response + + async def _handle_connect(self, environ, transport, jsonp_index=None): """Handle a client connection request.""" if self.start_service_task: # start the service task to monitor connected clients self.start_service_task = False self.start_background_task(self._service_task) - sid = self._generate_id() + sid = self.generate_id() s = asyncio_socket.AsyncSocket(self, sid) self.sockets[sid] = s @@ -380,17 +396,18 @@ class AsyncServer(server.Server): 'pingTimeout': int(self.ping_timeout * 1000), 'pingInterval': int(self.ping_interval * 1000)}) await s.send(pkt) + s.schedule_ping() ret = await self._trigger_event('connect', sid, environ, run_async=False) - if ret is False: + if ret is not None and ret is not True: del self.sockets[sid] self.logger.warning('Application rejected connection') - return self._unauthorized() + return self._unauthorized(ret or None) if transport == 'websocket': ret = await s.handle_get_request(environ) - if s.closed: + if s.closed and sid in self.sockets: # websocket connection ended, so we are done del self.sockets[sid] return ret @@ -398,9 +415,20 @@ class AsyncServer(server.Server): s.connected = True headers = None if self.cookie: - headers = [('Set-Cookie', self.cookie + '=' + sid)] + if isinstance(self.cookie, dict): + headers = [( + 'Set-Cookie', + self._generate_sid_cookie(sid, self.cookie) + )] + else: + headers = [( + 'Set-Cookie', + self._generate_sid_cookie(sid, { + 'name': self.cookie, 'path': '/', 'SameSite': 'Lax' + }) + )] try: - return self._ok(await s.poll(), headers=headers, b64=b64, + return self._ok(await s.poll(), headers=headers, jsonp_index=jsonp_index) except exceptions.QueueEmpty: return self._bad_request() @@ -459,7 +487,12 @@ class AsyncServer(server.Server): if not socket.closing and not socket.closed: await socket.check_ping_timeout() await self.sleep(sleep_interval) - except (SystemExit, KeyboardInterrupt, asyncio.CancelledError): + except ( + SystemExit, + KeyboardInterrupt, + asyncio.CancelledError, + GeneratorExit, + ): self.logger.info('service task canceled') break except: diff --git a/libs/engineio/asyncio_socket.py b/libs/engineio/asyncio_socket.py index 7057a6cc3..508ee3ca2 100644 --- a/libs/engineio/asyncio_socket.py +++ b/libs/engineio/asyncio_socket.py @@ -1,5 +1,4 @@ import asyncio -import six import sys import time @@ -13,18 +12,24 @@ class AsyncSocket(socket.Socket): async def poll(self): """Wait for packets to send to the client.""" try: - packets = [await asyncio.wait_for(self.queue.get(), - self.server.ping_timeout)] + packets = [await asyncio.wait_for( + self.queue.get(), + self.server.ping_interval + self.server.ping_timeout)] self.queue.task_done() except (asyncio.TimeoutError, asyncio.CancelledError): raise exceptions.QueueEmpty() if packets == [None]: return [] - try: - packets.append(self.queue.get_nowait()) - self.queue.task_done() - 
except asyncio.QueueEmpty: - pass + while True: + try: + pkt = self.queue.get_nowait() + self.queue.task_done() + if pkt is None: + self.queue.put_nowait(None) + break + packets.append(pkt) + except asyncio.QueueEmpty: + break return packets async def receive(self, pkt): @@ -33,9 +38,8 @@ class AsyncSocket(socket.Socket): self.sid, packet.packet_names[pkt.packet_type], pkt.data if not isinstance(pkt.data, bytes) else '') - if pkt.packet_type == packet.PING: - self.last_ping = time.time() - await self.send(packet.Packet(packet.PONG, pkt.data)) + if pkt.packet_type == packet.PONG: + self.schedule_ping() elif pkt.packet_type == packet.MESSAGE: await self.server._trigger_event( 'message', self.sid, pkt.data, @@ -48,14 +52,11 @@ class AsyncSocket(socket.Socket): raise exceptions.UnknownPacketError() async def check_ping_timeout(self): - """Make sure the client is still sending pings. - - This helps detect disconnections for long-polling clients. - """ + """Make sure the client is still sending pings.""" if self.closed: raise exceptions.SocketIsClosedError() - if time.time() - self.last_ping > self.server.ping_interval + \ - self.server.ping_interval_grace_period: + if self.last_ping and \ + time.time() - self.last_ping > self.server.ping_timeout: self.server.logger.info('%s: Client is gone, closing socket', self.sid) # Passing abort=False here will cause close() to write a @@ -69,8 +70,6 @@ class AsyncSocket(socket.Socket): """Send a packet to the client.""" if not await self.check_ping_timeout(): return - if self.upgrading: - self.packet_backlog.append(pkt) else: await self.queue.put(pkt) self.server.logger.info('%s: Sending packet %s data %s', @@ -88,12 +87,16 @@ class AsyncSocket(socket.Socket): self.server.logger.info('%s: Received request to upgrade to %s', self.sid, transport) return await getattr(self, '_upgrade_' + transport)(environ) + if self.upgrading or self.upgraded: + # we are upgrading to WebSocket, do not return any more packets + # through the polling endpoint + return [packet.Packet(packet.NOOP)] try: packets = await self.poll() except exceptions.QueueEmpty: exc = sys.exc_info() await self.close(wait=False) - six.reraise(*exc) + raise exc[1].with_traceback(exc[2]) return packets async def handle_post_request(self, environ): @@ -102,7 +105,7 @@ class AsyncSocket(socket.Socket): if length > self.server.max_http_buffer_size: raise exceptions.ContentTooLongError() else: - body = await environ['wsgi.input'].read(length) + body = (await environ['wsgi.input'].read(length)).decode('utf-8') p = payload.Payload(encoded_payload=body) for pkt in p.packets: await self.receive(pkt) @@ -118,6 +121,16 @@ class AsyncSocket(socket.Socket): if wait: await self.queue.join() + def schedule_ping(self): + async def send_ping(): + self.last_ping = None + await asyncio.sleep(self.server.ping_interval) + if not self.closing and not self.closed: + self.last_ping = time.time() + await self.send(packet.Packet(packet.PING)) + + self.server.start_background_task(send_ping) + async def _upgrade_websocket(self, environ): """Upgrade the connection from polling to websocket.""" if self.upgraded: @@ -143,15 +156,15 @@ class AsyncSocket(socket.Socket): decoded_pkt.data != 'probe': self.server.logger.info( '%s: Failed websocket upgrade, no PING packet', self.sid) + self.upgrading = False return - await ws.send(packet.Packet( - packet.PONG, - data=six.text_type('probe')).encode(always_bytes=False)) + await ws.send(packet.Packet(packet.PONG, data='probe').encode()) await self.queue.put(packet.Packet(packet.NOOP)) # 
end poll try: pkt = await ws.wait() except IOError: # pragma: no cover + self.upgrading = False return decoded_pkt = packet.Packet(encoded_packet=pkt) if decoded_pkt.packet_type != packet.UPGRADE: @@ -160,13 +173,9 @@ class AsyncSocket(socket.Socket): ('%s: Failed websocket upgrade, expected UPGRADE packet, ' 'received %s instead.'), self.sid, pkt) + self.upgrading = False return self.upgraded = True - - # flush any packets that were sent during the upgrade - for pkt in self.packet_backlog: - await self.queue.put(pkt) - self.packet_backlog = [] self.upgrading = False else: self.connected = True @@ -185,7 +194,7 @@ class AsyncSocket(socket.Socket): break try: for pkt in packets: - await ws.send(pkt.encode(always_bytes=False)) + await ws.send(pkt.encode()) except: break writer_task = asyncio.ensure_future(writer()) @@ -197,7 +206,9 @@ class AsyncSocket(socket.Socket): p = None wait_task = asyncio.ensure_future(ws.wait()) try: - p = await asyncio.wait_for(wait_task, self.server.ping_timeout) + p = await asyncio.wait_for( + wait_task, + self.server.ping_interval + self.server.ping_timeout) except asyncio.CancelledError: # pragma: no cover # there is a bug (https://bugs.python.org/issue30508) in # asyncio that causes a "Task exception never retrieved" error @@ -216,8 +227,6 @@ class AsyncSocket(socket.Socket): if p is None: # connection closed by client break - if isinstance(p, six.text_type): # pragma: no cover - p = p.encode('utf-8') pkt = packet.Packet(encoded_packet=p) try: await self.receive(pkt) diff --git a/libs/engineio/client.py b/libs/engineio/client.py index b5ab50377..d307a5d62 100644 --- a/libs/engineio/client.py +++ b/libs/engineio/client.py @@ -1,3 +1,5 @@ +from base64 import b64encode +from json import JSONDecodeError import logging try: import queue @@ -7,9 +9,8 @@ import signal import ssl import threading import time +import urllib -import six -from six.moves import urllib try: import requests except ImportError: # pragma: no cover @@ -25,9 +26,6 @@ from . import payload default_logger = logging.getLogger('engineio.client') connected_clients = [] -if six.PY2: # pragma: no cover - ConnectionError = OSError - def signal_handler(sig, frame): """SIGINT handler. @@ -35,10 +33,8 @@ def signal_handler(sig, frame): Disconnect all active clients and then invoke the original signal handler. """ for client in connected_clients[:]: - if client.is_asyncio_based(): - client.start_background_task(client.disconnect, abort=True) - else: - client.disconnect(abort=True) + if not client.is_asyncio_based(): + client.disconnect() if callable(original_signal_handler): return original_signal_handler(sig, frame) else: # pragma: no cover @@ -57,13 +53,18 @@ class Client(object): :param logger: To enable logging set to ``True`` or pass a logger object to use. To disable logging set to ``False``. The default is - ``False``. + ``False``. Note that fatal errors are logged even when + ``logger`` is ``False``. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library versions. :param request_timeout: A timeout in seconds for requests. The default is 5 seconds. + :param http_session: an initialized ``requests.Session`` object to be used + when sending requests to the server. Use it if you + need to add special client options such as proxy + servers, SSL certificates, custom CA bundle, etc. 
:param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to skip SSL certificate verification, allowing connections to servers with self signed certificates. @@ -75,9 +76,11 @@ class Client(object): logger=False, json=None, request_timeout=5, + http_session=None, ssl_verify=True): global original_signal_handler - if original_signal_handler is None: + if original_signal_handler is None and \ + threading.current_thread() == threading.main_thread(): original_signal_handler = signal.signal(signal.SIGINT, signal_handler) self.handlers = {} @@ -88,13 +91,10 @@ class Client(object): self.upgrades = None self.ping_interval = None self.ping_timeout = None - self.pong_received = True - self.http = None + self.http = http_session self.ws = None self.read_loop_task = None self.write_loop_task = None - self.ping_loop_task = None - self.ping_loop_event = None self.queue = None self.state = 'disconnected' self.ssl_verify = ssl_verify @@ -105,8 +105,7 @@ class Client(object): self.logger = logger else: self.logger = default_logger - if not logging.root.handlers and \ - self.logger.level == logging.NOTSET: + if self.logger.level == logging.NOTSET: if logger: self.logger.setLevel(logging.INFO) else: @@ -151,7 +150,7 @@ class Client(object): return set_handler set_handler(handler) - def connect(self, url, headers={}, transports=None, + def connect(self, url, headers=None, transports=None, engineio_path='engine.io'): """Connect to an Engine.IO server. @@ -176,7 +175,7 @@ class Client(object): raise ValueError('Client is not in a disconnected state') valid_transports = ['polling', 'websocket'] if transports is not None: - if isinstance(transports, six.string_types): + if isinstance(transports, str): transports = [transports] transports = [transport for transport in transports if transport in valid_transports] @@ -185,7 +184,7 @@ class Client(object): self.transports = transports or valid_transports self.queue = self.create_queue() return getattr(self, '_connect_' + self.transports[0])( - url, headers, engineio_path) + url, headers or {}, engineio_path) def wait(self): """Wait until the connection with the server ends. @@ -196,19 +195,14 @@ class Client(object): if self.read_loop_task: self.read_loop_task.join() - def send(self, data, binary=None): + def send(self, data): """Send a message to a client. :param data: The data to send to the client. Data can be of type ``str``, ``bytes``, ``list`` or ``dict``. If a ``list`` or ``dict``, the data will be serialized as JSON. - :param binary: ``True`` to send packet as binary, ``False`` to send - as text. If not given, unicode (Python 2) and str - (Python 3) are sent as text, and str (Python 2) and - bytes (Python 3) are sent as binary. """ - self._send_packet(packet.Packet(packet.MESSAGE, data=data, - binary=binary)) + self._send_packet(packet.Packet(packet.MESSAGE, data=data)) def disconnect(self, abort=False): """Disconnect from the server. 
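
Editor's note — illustration only, not part of this patch. The sync client gained the same `http_session` hook documented above, and the hunks that follow show proxy, client-certificate, and CA-bundle settings from that session being carried over to the WebSocket transport. A usage sketch (proxy URL and CA path are illustrative):

    import requests
    import engineio

    session = requests.Session()
    session.proxies = {'http': 'http://proxy.local:3128'}   # reused for ws://
    session.verify = '/etc/ssl/certs/internal-ca.pem'       # reused as ca_certs

    eio = engineio.Client(http_session=session)
    eio.connect('http://localhost:5000')
    eio.send('hello')
    eio.disconnect()
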
@@ -293,14 +287,19 @@ class Client(object): raise exceptions.ConnectionError( 'Connection refused by the server') if r.status_code < 200 or r.status_code >= 300: + self._reset() + try: + arg = r.json() + except JSONDecodeError: + arg = None raise exceptions.ConnectionError( 'Unexpected status code {} in server response'.format( - r.status_code)) + r.status_code), arg) try: - p = payload.Payload(encoded_payload=r.content) + p = payload.Payload(encoded_payload=r.content.decode('utf-8')) except ValueError: - six.raise_from(exceptions.ConnectionError( - 'Unexpected response from server'), None) + raise exceptions.ConnectionError( + 'Unexpected response from server') from None open_packet = p.packets[0] if open_packet.packet_type != packet.OPEN: raise exceptions.ConnectionError( @@ -309,8 +308,8 @@ class Client(object): 'Polling connection accepted with ' + str(open_packet.data)) self.sid = open_packet.data['sid'] self.upgrades = open_packet.data['upgrades'] - self.ping_interval = open_packet.data['pingInterval'] / 1000.0 - self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0 + self.ping_interval = int(open_packet.data['pingInterval']) / 1000.0 + self.ping_timeout = int(open_packet.data['pingTimeout']) / 1000.0 self.current_transport = 'polling' self.base_url += '&sid=' + self.sid @@ -328,7 +327,6 @@ class Client(object): return # start background tasks associated with this client - self.ping_loop_task = self.start_background_task(self._ping_loop) self.write_loop_task = self.start_background_task(self._write_loop) self.read_loop_task = self.start_background_task( self._read_loop_polling) @@ -337,8 +335,8 @@ class Client(object): """Establish or upgrade to a WebSocket connection with the server.""" if websocket is None: # pragma: no cover # not installed - self.logger.warning('websocket-client package not installed, only ' - 'polling transport is available') + self.logger.error('websocket-client package not installed, only ' + 'polling transport is available') return False websocket_url = self._get_engineio_url(url, engineio_path, 'websocket') if self.sid: @@ -352,22 +350,75 @@ class Client(object): self.logger.info( 'Attempting WebSocket connection to ' + websocket_url) - # get the cookies from the long-polling connection so that they can - # also be sent the the WebSocket route + # get cookies and other settings from the long-polling connection + # so that they are preserved when connecting to the WebSocket route cookies = None + extra_options = {} if self.http: + # cookies cookies = '; '.join(["{}={}".format(cookie.name, cookie.value) for cookie in self.http.cookies]) + for header, value in headers.items(): + if header.lower() == 'cookie': + if cookies: + cookies += '; ' + cookies += value + del headers[header] + break + # auth + if 'Authorization' not in headers and self.http.auth is not None: + if not isinstance(self.http.auth, tuple): # pragma: no cover + raise ValueError('Only basic authentication is supported') + basic_auth = '{}:{}'.format( + self.http.auth[0], self.http.auth[1]).encode('utf-8') + basic_auth = b64encode(basic_auth).decode('utf-8') + headers['Authorization'] = 'Basic ' + basic_auth + + # cert + # this can be given as ('certfile', 'keyfile') or just 'certfile' + if isinstance(self.http.cert, tuple): + extra_options['sslopt'] = { + 'certfile': self.http.cert[0], + 'keyfile': self.http.cert[1]} + elif self.http.cert: + extra_options['sslopt'] = {'certfile': self.http.cert} + + # proxies + if self.http.proxies: + proxy_url = None + if 
websocket_url.startswith('ws://'): + proxy_url = self.http.proxies.get( + 'ws', self.http.proxies.get('http')) + else: # wss:// + proxy_url = self.http.proxies.get( + 'wss', self.http.proxies.get('https')) + if proxy_url: + parsed_url = urllib.parse.urlparse( + proxy_url if '://' in proxy_url + else 'scheme://' + proxy_url) + extra_options['http_proxy_host'] = parsed_url.hostname + extra_options['http_proxy_port'] = parsed_url.port + extra_options['http_proxy_auth'] = ( + (parsed_url.username, parsed_url.password) + if parsed_url.username or parsed_url.password + else None) + + # verify + if isinstance(self.http.verify, str): + if 'sslopt' in extra_options: + extra_options['sslopt']['ca_certs'] = self.http.verify + else: + extra_options['sslopt'] = {'ca_certs': self.http.verify} + elif not self.http.verify: + self.ssl_verify = False + + if not self.ssl_verify: + extra_options['sslopt'] = {"cert_reqs": ssl.CERT_NONE} try: - if not self.ssl_verify: - ws = websocket.create_connection( - websocket_url + self._get_url_timestamp(), header=headers, - cookie=cookies, sslopt={"cert_reqs": ssl.CERT_NONE}) - else: - ws = websocket.create_connection( - websocket_url + self._get_url_timestamp(), header=headers, - cookie=cookies) + ws = websocket.create_connection( + websocket_url + self._get_url_timestamp(), header=headers, + cookie=cookies, enable_multithread=True, **extra_options) except (ConnectionError, IOError, websocket.WebSocketException): if upgrade: self.logger.warning( @@ -376,8 +427,7 @@ class Client(object): else: raise exceptions.ConnectionError('Connection error') if upgrade: - p = packet.Packet(packet.PING, - data=six.text_type('probe')).encode() + p = packet.Packet(packet.PING, data='probe').encode() try: ws.send(p) except Exception as e: # pragma: no cover @@ -420,17 +470,17 @@ class Client(object): 'WebSocket connection accepted with ' + str(open_packet.data)) self.sid = open_packet.data['sid'] self.upgrades = open_packet.data['upgrades'] - self.ping_interval = open_packet.data['pingInterval'] / 1000.0 - self.ping_timeout = open_packet.data['pingTimeout'] / 1000.0 + self.ping_interval = int(open_packet.data['pingInterval']) / 1000.0 + self.ping_timeout = int(open_packet.data['pingTimeout']) / 1000.0 self.current_transport = 'websocket' self.state = 'connected' connected_clients.append(self) self._trigger_event('connect', run_async=False) self.ws = ws + self.ws.settimeout(self.ping_interval + self.ping_timeout) # start background tasks associated with this client - self.ping_loop_task = self.start_background_task(self._ping_loop) self.write_loop_task = self.start_background_task(self._write_loop) self.read_loop_task = self.start_background_task( self._read_loop_websocket) @@ -445,8 +495,8 @@ class Client(object): pkt.data if not isinstance(pkt.data, bytes) else '') if pkt.packet_type == packet.MESSAGE: self._trigger_event('message', pkt.data, run_async=True) - elif pkt.packet_type == packet.PONG: - self.pong_received = True + elif pkt.packet_type == packet.PING: + self._send_packet(packet.Packet(packet.PONG, pkt.data)) elif pkt.packet_type == packet.CLOSE: self.disconnect(abort=True) elif pkt.packet_type == packet.NOOP: @@ -470,9 +520,11 @@ class Client(object): timeout=None): # pragma: no cover if self.http is None: self.http = requests.Session() + if not self.ssl_verify: + self.http.verify = False try: return self.http.request(method, url, headers=headers, data=body, - timeout=timeout, verify=self.ssl_verify) + timeout=timeout) except requests.exceptions.RequestException as exc: 
self.logger.info('HTTP %s request to %s failed with error %s.', method, url, exc) @@ -504,7 +556,7 @@ class Client(object): scheme += 's' return ('{scheme}://{netloc}/{path}/?{query}' - '{sep}transport={transport}&EIO=3').format( + '{sep}transport={transport}&EIO=4').format( scheme=scheme, netloc=parsed_url.netloc, path=engineio_path, query=parsed_url.query, sep='&' if parsed_url.query else '', @@ -514,28 +566,6 @@ class Client(object): """Generate the Engine.IO query string timestamp.""" return '&t=' + str(time.time()) - def _ping_loop(self): - """This background task sends a PING to the server at the requested - interval. - """ - self.pong_received = True - if self.ping_loop_event is None: - self.ping_loop_event = self.create_event() - else: - self.ping_loop_event.clear() - while self.state == 'connected': - if not self.pong_received: - self.logger.info( - 'PONG response has not been received, aborting') - if self.ws: - self.ws.close(timeout=0) - self.queue.put(None) - break - self.pong_received = False - self._send_packet(packet.Packet(packet.PING)) - self.ping_loop_event.wait(timeout=self.ping_interval) - self.logger.info('Exiting ping task') - def _read_loop_polling(self): """Read packets by polling the Engine.IO server.""" while self.state == 'connected': @@ -555,7 +585,7 @@ class Client(object): self.queue.put(None) break try: - p = payload.Payload(encoded_payload=r.content) + p = payload.Payload(encoded_payload=r.content.decode('utf-8')) except ValueError: self.logger.warning( 'Unexpected packet from server, aborting') @@ -566,10 +596,6 @@ class Client(object): self.logger.info('Waiting for write loop task to end') self.write_loop_task.join() - self.logger.info('Waiting for ping loop task to end') - if self.ping_loop_event: # pragma: no cover - self.ping_loop_event.set() - self.ping_loop_task.join() if self.state == 'connected': self._trigger_event('disconnect', run_async=False) try: @@ -585,6 +611,11 @@ class Client(object): p = None try: p = self.ws.recv() + except websocket.WebSocketTimeoutException: + self.logger.warning( + 'Server has stopped communicating, aborting') + self.queue.put(None) + break except websocket.WebSocketConnectionClosedException: self.logger.warning( 'WebSocket connection was closed, aborting') @@ -592,20 +623,21 @@ class Client(object): break except Exception as e: self.logger.info( - 'Unexpected error "%s", aborting', str(e)) + 'Unexpected error receiving packet: "%s", aborting', + str(e)) + self.queue.put(None) + break + try: + pkt = packet.Packet(encoded_packet=p) + except Exception as e: # pragma: no cover + self.logger.info( + 'Unexpected error decoding packet: "%s", aborting', str(e)) self.queue.put(None) break - if isinstance(p, six.text_type): # pragma: no cover - p = p.encode('utf-8') - pkt = packet.Packet(encoded_packet=p) self._receive_packet(pkt) self.logger.info('Waiting for write loop task to end') self.write_loop_task.join() - self.logger.info('Waiting for ping loop task to end') - if self.ping_loop_event: # pragma: no cover - self.ping_loop_event.set() - self.ping_loop_task.join() if self.state == 'connected': self._trigger_event('disconnect', run_async=False) try: @@ -667,13 +699,14 @@ class Client(object): # websocket try: for pkt in packets: - encoded_packet = pkt.encode(always_bytes=False) + encoded_packet = pkt.encode() if pkt.binary: self.ws.send_binary(encoded_packet) else: self.ws.send(encoded_packet) self.queue.task_done() - except websocket.WebSocketConnectionClosedException: + except 
(websocket.WebSocketConnectionClosedException, + BrokenPipeError, OSError): self.logger.warning( 'WebSocket connection was closed, aborting') break diff --git a/libs/engineio/packet.py b/libs/engineio/packet.py index a3aa6d476..9dbd6c684 100644 --- a/libs/engineio/packet.py +++ b/libs/engineio/packet.py @@ -1,12 +1,10 @@ import base64 import json as _json -import six - (OPEN, CLOSE, PING, PONG, MESSAGE, UPGRADE, NOOP) = (0, 1, 2, 3, 4, 5, 6) packet_names = ['OPEN', 'CLOSE', 'PING', 'PONG', 'MESSAGE', 'UPGRADE', 'NOOP'] -binary_types = (six.binary_type, bytearray) +binary_types = (bytes, bytearray) class Packet(object): @@ -14,79 +12,61 @@ class Packet(object): json = _json - def __init__(self, packet_type=NOOP, data=None, binary=None, - encoded_packet=None): + def __init__(self, packet_type=NOOP, data=None, encoded_packet=None): self.packet_type = packet_type self.data = data - if binary is not None: - self.binary = binary - elif isinstance(data, six.text_type): + if isinstance(data, str): self.binary = False elif isinstance(data, binary_types): self.binary = True else: self.binary = False + if self.binary and self.packet_type != MESSAGE: + raise ValueError('Binary packets can only be of type MESSAGE') if encoded_packet: self.decode(encoded_packet) - def encode(self, b64=False, always_bytes=True): + def encode(self, b64=False): """Encode the packet for transmission.""" - if self.binary and not b64: - encoded_packet = six.int2byte(self.packet_type) - else: - encoded_packet = six.text_type(self.packet_type) - if self.binary and b64: - encoded_packet = 'b' + encoded_packet if self.binary: if b64: - encoded_packet += base64.b64encode(self.data).decode('utf-8') + encoded_packet = 'b' + base64.b64encode(self.data).decode( + 'utf-8') else: + encoded_packet = self.data + else: + encoded_packet = str(self.packet_type) + if isinstance(self.data, str): encoded_packet += self.data - elif isinstance(self.data, six.string_types): - encoded_packet += self.data - elif isinstance(self.data, dict) or isinstance(self.data, list): - encoded_packet += self.json.dumps(self.data, - separators=(',', ':')) - elif self.data is not None: - encoded_packet += str(self.data) - if always_bytes and not isinstance(encoded_packet, binary_types): - encoded_packet = encoded_packet.encode('utf-8') + elif isinstance(self.data, dict) or isinstance(self.data, list): + encoded_packet += self.json.dumps(self.data, + separators=(',', ':')) + elif self.data is not None: + encoded_packet += str(self.data) return encoded_packet def decode(self, encoded_packet): """Decode a transmitted package.""" - b64 = False - if not isinstance(encoded_packet, binary_types): - encoded_packet = encoded_packet.encode('utf-8') - elif not isinstance(encoded_packet, bytes): - encoded_packet = bytes(encoded_packet) - self.packet_type = six.byte2int(encoded_packet[0:1]) - if self.packet_type == 98: # 'b' --> binary base64 encoded packet + self.binary = isinstance(encoded_packet, binary_types) + b64 = not self.binary and encoded_packet[0] == 'b' + if b64: self.binary = True - encoded_packet = encoded_packet[1:] - self.packet_type = six.byte2int(encoded_packet[0:1]) - self.packet_type -= 48 - b64 = True - elif self.packet_type >= 48: - self.packet_type -= 48 - self.binary = False + self.packet_type = MESSAGE + self.data = base64.b64decode(encoded_packet[1:]) else: - self.binary = True - self.data = None - if len(encoded_packet) > 1: + if self.binary and not isinstance(encoded_packet, bytes): + encoded_packet = bytes(encoded_packet) if self.binary: - if b64: 
- self.data = base64.b64decode(encoded_packet[1:]) - else: - self.data = encoded_packet[1:] + self.packet_type = MESSAGE + self.data = encoded_packet else: + self.packet_type = int(encoded_packet[0]) try: - self.data = self.json.loads( - encoded_packet[1:].decode('utf-8')) + self.data = self.json.loads(encoded_packet[1:]) if isinstance(self.data, int): # do not allow integer payloads, see # github.com/miguelgrinberg/python-engineio/issues/75 # for background on this decision raise ValueError except ValueError: - self.data = encoded_packet[1:].decode('utf-8') + self.data = encoded_packet[1:] diff --git a/libs/engineio/payload.py b/libs/engineio/payload.py index fbf9cbd27..f0e9e343d 100644 --- a/libs/engineio/payload.py +++ b/libs/engineio/payload.py @@ -1,9 +1,7 @@ -import six +import urllib from . import packet -from six.moves import urllib - class Payload(object): """Engine.IO payload.""" @@ -14,31 +12,19 @@ class Payload(object): if encoded_payload is not None: self.decode(encoded_payload) - def encode(self, b64=False, jsonp_index=None): + def encode(self, jsonp_index=None): """Encode the payload for transmission.""" - encoded_payload = b'' + encoded_payload = '' for pkt in self.packets: - encoded_packet = pkt.encode(b64=b64) - packet_len = len(encoded_packet) - if b64: - encoded_payload += str(packet_len).encode('utf-8') + b':' + \ - encoded_packet - else: - binary_len = b'' - while packet_len != 0: - binary_len = six.int2byte(packet_len % 10) + binary_len - packet_len = int(packet_len / 10) - if not pkt.binary: - encoded_payload += b'\0' - else: - encoded_payload += b'\1' - encoded_payload += binary_len + b'\xff' + encoded_packet + if encoded_payload: + encoded_payload += '\x1e' + encoded_payload += pkt.encode(b64=True) if jsonp_index is not None: - encoded_payload = b'___eio[' + \ - str(jsonp_index).encode() + \ - b']("' + \ - encoded_payload.replace(b'"', b'\\"') + \ - b'");' + encoded_payload = '___eio[' + \ + str(jsonp_index) + \ + ']("' + \ + encoded_payload.replace('"', '\\"') + \ + '");' return encoded_payload def decode(self, encoded_payload): @@ -49,33 +35,12 @@ class Payload(object): return # JSONP POST payload starts with 'd=' - if encoded_payload.startswith(b'd='): + if encoded_payload.startswith('d='): encoded_payload = urllib.parse.parse_qs( - encoded_payload)[b'd'][0] + encoded_payload)['d'][0] - i = 0 - if six.byte2int(encoded_payload[0:1]) <= 1: - # binary encoding - while i < len(encoded_payload): - if len(self.packets) >= self.max_decode_packets: - raise ValueError('Too many packets in payload') - packet_len = 0 - i += 1 - while six.byte2int(encoded_payload[i:i + 1]) != 255: - packet_len = packet_len * 10 + six.byte2int( - encoded_payload[i:i + 1]) - i += 1 - self.packets.append(packet.Packet( - encoded_packet=encoded_payload[i + 1:i + 1 + packet_len])) - i += packet_len + 1 - else: - # assume text encoding - encoded_payload = encoded_payload.decode('utf-8') - while i < len(encoded_payload): - if len(self.packets) >= self.max_decode_packets: - raise ValueError('Too many packets in payload') - j = encoded_payload.find(':', i) - packet_len = int(encoded_payload[i:j]) - pkt = encoded_payload[j + 1:j + 1 + packet_len] - self.packets.append(packet.Packet(encoded_packet=pkt)) - i = j + 1 + packet_len + encoded_packets = encoded_payload.split('\x1e') + if len(encoded_packets) > self.max_decode_packets: + raise ValueError('Too many packets in payload') + self.packets = [packet.Packet(encoded_packet=encoded_packet) + for encoded_packet in encoded_packets] diff --git 
a/libs/engineio/server.py b/libs/engineio/server.py index e1543c2dc..7498f3f6b 100644 --- a/libs/engineio/server.py +++ b/libs/engineio/server.py @@ -1,12 +1,12 @@ +import base64 import gzip import importlib +import io import logging -import uuid +import secrets +import urllib import zlib -import six -from six.moves import urllib - from . import exceptions from . import packet from . import payload @@ -29,17 +29,16 @@ class Server(object): "gevent_uwsgi", then "gevent", and finally "threading". The first async mode that has all its dependencies installed is the one that is chosen. - :param ping_timeout: The time in seconds that the client waits for the - server to respond before disconnecting. The default - is 60 seconds. - :param ping_interval: The interval in seconds at which the client pings - the server. The default is 25 seconds. For advanced + :param ping_interval: The interval in seconds at which the server pings + the client. The default is 25 seconds. For advanced control, a two element tuple can be given, where the first number is the ping interval and the second - is a grace period added by the server. The default - grace period is 5 seconds. + is a grace period added by the server. + :param ping_timeout: The time in seconds that the client waits for the + server to respond before disconnecting. The default + is 5 seconds. :param max_http_buffer_size: The maximum size of a message when using the - polling transport. The default is 100,000,000 + polling transport. The default is 1,000,000 bytes. :param allow_upgrades: Whether to allow transport upgrades or not. The default is ``True``. @@ -48,9 +47,14 @@ :param compression_threshold: Only compress messages when their byte size is greater than this value. The default is 1024 bytes. - :param cookie: Name of the HTTP cookie that contains the client session - id. If set to ``None``, a cookie is not sent to the client. - The default is ``'io'``. + :param cookie: If set to a string, it is the name of the HTTP cookie the + server sends back to the client containing the client + session id. If set to a dictionary, the ``'name'`` key + contains the cookie name and other keys define cookie + attributes, where the value of each attribute can be a + string, a callable with no arguments, or a boolean. If set + to ``None`` (the default), a cookie is not sent to the + client. :param cors_allowed_origins: Origin or list of origins that are allowed to connect to this server. Only the same origin is allowed by default. Set this argument to @@ -61,7 +65,8 @@ is ``True``. :param logger: To enable logging set to ``True`` or pass a logger object to use. To disable logging set to ``False``. The default is - ``False``. + ``False``. Note that fatal errors are logged even when + ``logger`` is ``False``. :param json: An alternative json module to use for encoding and decoding packets.
Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library @@ -79,11 +84,12 @@ class Server(object): compression_methods = ['gzip', 'deflate'] event_names = ['connect', 'disconnect', 'message'] _default_monitor_clients = True + sequence_number = 0 - def __init__(self, async_mode=None, ping_timeout=60, ping_interval=25, - max_http_buffer_size=100000000, allow_upgrades=True, + def __init__(self, async_mode=None, ping_interval=25, ping_timeout=5, + max_http_buffer_size=1000000, allow_upgrades=True, http_compression=True, compression_threshold=1024, - cookie='io', cors_allowed_origins=None, + cookie=None, cors_allowed_origins=None, cors_credentials=True, logger=False, json=None, async_handlers=True, monitor_clients=None, **kwargs): self.ping_timeout = ping_timeout @@ -92,7 +98,7 @@ class Server(object): self.ping_interval_grace_period = ping_interval[1] else: self.ping_interval = ping_interval - self.ping_interval_grace_period = 5 + self.ping_interval_grace_period = 0 self.max_http_buffer_size = max_http_buffer_size self.allow_upgrades = allow_upgrades self.http_compression = http_compression @@ -103,6 +109,7 @@ class Server(object): self.async_handlers = async_handlers self.sockets = {} self.handlers = {} + self.log_message_keys = set() self.start_service_task = monitor_clients \ if monitor_clients is not None else self._default_monitor_clients if json is not None: @@ -111,8 +118,7 @@ class Server(object): self.logger = logger else: self.logger = default_logger - if not logging.root.handlers and \ - self.logger.level == logging.NOTSET: + if self.logger.level == logging.NOTSET: if logger: self.logger.setLevel(logging.INFO) else: @@ -196,17 +202,13 @@ class Server(object): return set_handler set_handler(handler) - def send(self, sid, data, binary=None): + def send(self, sid, data): """Send a message to a client. :param sid: The session id of the recipient client. :param data: The data to send to the client. Data can be of type ``str``, ``bytes``, ``list`` or ``dict``. If a ``list`` or ``dict``, the data will be serialized as JSON. - :param binary: ``True`` to send packet as binary, ``False`` to send - as text. If not given, unicode (Python 2) and str - (Python 3) are sent as text, and str (Python 2) and - bytes (Python 3) are sent as binary. """ try: socket = self._get_socket(sid) @@ -214,7 +216,7 @@ class Server(object): # the socket is not available self.logger.warning('Cannot send to sid %s', sid) return - socket.send(packet.Packet(packet.MESSAGE, data=data, binary=binary)) + socket.send(packet.Packet(packet.MESSAGE, data=data)) def get_session(self, sid): """Return the user session for a client. 
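The dictionary form of the ``cookie`` option documented above is rendered by the ``_generate_sid_cookie()`` helper added in a later hunk. A configuration sketch (the attribute choices are illustrative, not defaults)::

    import engineio

    eio = engineio.Server(cookie={
        'name': 'io',          # cookie name; a plain string argument sets only this
        'path': '/',           # string attributes are rendered as attribute=value
        'SameSite': 'Strict',
        'secure': True,        # a boolean True renders the attribute with no value
    })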
@@ -292,7 +294,7 @@ class Server(object): if sid in self.sockets: # pragma: no cover del self.sockets[sid] else: - for client in six.itervalues(self.sockets): + for client in self.sockets.values(): client.close() self.sockets = {} @@ -329,22 +331,30 @@ class Server(object): allowed_origins = self._cors_allowed_origins(environ) if allowed_origins is not None and origin not in \ allowed_origins: - self.logger.info(origin + ' is not an accepted origin.') - r = self._bad_request() + self._log_error_once( + origin + ' is not an accepted origin.', 'bad-origin') + r = self._bad_request( + origin + ' is not an accepted origin.') start_response(r['status'], r['headers']) return [r['response']] method = environ['REQUEST_METHOD'] query = urllib.parse.parse_qs(environ.get('QUERY_STRING', '')) - - sid = query['sid'][0] if 'sid' in query else None - b64 = False jsonp = False jsonp_index = None - if 'b64' in query: - if query['b64'][0] == "1" or query['b64'][0].lower() == "true": - b64 = True + # make sure the client speaks a compatible Engine.IO version + sid = query['sid'][0] if 'sid' in query else None + if sid is None and query.get('EIO') != ['4']: + self._log_error_once( + 'The client is using an unsupported version of the Socket.IO ' + 'or Engine.IO protocols', 'bad-version') + r = self._bad_request( + 'The client is using an unsupported version of the Socket.IO ' + 'or Engine.IO protocols') + start_response(r['status'], r['headers']) + return [r['response']] + if 'j' in query: jsonp = True try: @@ -354,29 +364,35 @@ class Server(object): pass if jsonp and jsonp_index is None: - self.logger.warning('Invalid JSONP index number') - r = self._bad_request() + self._log_error_once('Invalid JSONP index number', + 'bad-jsonp-index') + r = self._bad_request('Invalid JSONP index number') elif method == 'GET': if sid is None: transport = query.get('transport', ['polling'])[0] - if transport != 'polling' and transport != 'websocket': - self.logger.warning('Invalid transport %s', transport) - r = self._bad_request() - else: + # transport must be one of 'polling' or 'websocket'. + # if 'websocket', the HTTP_UPGRADE header must match. 
+ upgrade_header = environ.get('HTTP_UPGRADE').lower() \ + if 'HTTP_UPGRADE' in environ else None + if transport == 'polling' \ + or transport == upgrade_header == 'websocket': r = self._handle_connect(environ, start_response, - transport, b64, jsonp_index) + transport, jsonp_index) + else: + self._log_error_once('Invalid transport ' + transport, + 'bad-transport') + r = self._bad_request('Invalid transport ' + transport) else: if sid not in self.sockets: - self.logger.warning('Invalid session %s', sid) - r = self._bad_request() + self._log_error_once('Invalid session ' + sid, 'bad-sid') + r = self._bad_request('Invalid session ' + sid) else: socket = self._get_socket(sid) try: packets = socket.handle_get_request( environ, start_response) if isinstance(packets, list): - r = self._ok(packets, b64=b64, - jsonp_index=jsonp_index) + r = self._ok(packets, jsonp_index=jsonp_index) else: r = packets except exceptions.EngineIOError: @@ -387,8 +403,9 @@ del self.sockets[sid] elif method == 'POST': if sid is None or sid not in self.sockets: - self.logger.warning('Invalid session %s', sid) - r = self._bad_request() + self._log_error_once( + 'Invalid session ' + (sid or 'None'), 'bad-sid') + r = self._bad_request('Invalid session ' + (sid or 'None')) else: socket = self._get_socket(sid) try: @@ -481,11 +498,28 @@ """ return self._async['event'](*args, **kwargs) - def _generate_id(self): + def generate_id(self): """Generate a unique session id.""" - return uuid.uuid4().hex + id = base64.b64encode( + secrets.token_bytes(12) + self.sequence_number.to_bytes(3, 'big')) + self.sequence_number = (self.sequence_number + 1) & 0xffffff + return id.decode('utf-8').replace('/', '_').replace('+', '-') - def _handle_connect(self, environ, start_response, transport, b64=False, + def _generate_sid_cookie(self, sid, attributes): + """Generate the sid cookie.""" + cookie = attributes.get('name', 'io') + '=' + sid + for attribute, value in attributes.items(): + if attribute == 'name': + continue + if callable(value): + value = value() + if value is True: + cookie += '; ' + attribute + else: + cookie += '; ' + attribute + '=' + value + return cookie + + def _handle_connect(self, environ, start_response, transport, jsonp_index=None): """Handle a client connection request.""" if self.start_service_task: @@ -493,36 +527,53 @@ self.start_service_task = False self.start_background_task(self._service_task) - sid = self._generate_id() + sid = self.generate_id() s = socket.Socket(self, sid) self.sockets[sid] = s - pkt = packet.Packet( - packet.OPEN, {'sid': sid, - 'upgrades': self._upgrades(sid, transport), - 'pingTimeout': int(self.ping_timeout * 1000), - 'pingInterval': int(self.ping_interval * 1000)}) + pkt = packet.Packet(packet.OPEN, { + 'sid': sid, + 'upgrades': self._upgrades(sid, transport), + 'pingTimeout': int(self.ping_timeout * 1000), + 'pingInterval': int( + self.ping_interval + self.ping_interval_grace_period) * 1000}) s.send(pkt) + s.schedule_ping() + # NOTE: some sections below are marked as "no cover" to work around + # what seems to be a bug in the coverage package.
All the lines below + # are covered by tests, but some are not reported as such for some + # reason ret = self._trigger_event('connect', sid, environ, run_async=False) - if ret is False: + if ret is not None and ret is not True: # pragma: no cover del self.sockets[sid] self.logger.warning('Application rejected connection') - return self._unauthorized() + return self._unauthorized(ret or None) - if transport == 'websocket': + if transport == 'websocket': # pragma: no cover ret = s.handle_get_request(environ, start_response) - if s.closed: + if s.closed and sid in self.sockets: # websocket connection ended, so we are done del self.sockets[sid] return ret - else: + else: # pragma: no cover s.connected = True headers = None if self.cookie: - headers = [('Set-Cookie', self.cookie + '=' + sid)] + if isinstance(self.cookie, dict): + headers = [( + 'Set-Cookie', + self._generate_sid_cookie(sid, self.cookie) + )] + else: + headers = [( + 'Set-Cookie', + self._generate_sid_cookie(sid, { + 'name': self.cookie, 'path': '/', 'SameSite': 'Lax' + }) + )] try: - return self._ok(s.poll(), headers=headers, b64=b64, + return self._ok(s.poll(), headers=headers, jsonp_index=jsonp_index) except exceptions.QueueEmpty: return self._bad_request() @@ -561,29 +612,29 @@ raise KeyError('Session is disconnected') return s - def _ok(self, packets=None, headers=None, b64=False, jsonp_index=None): + def _ok(self, packets=None, headers=None, jsonp_index=None): """Generate a successful HTTP response.""" if packets is not None: if headers is None: headers = [] - if b64: - headers += [('Content-Type', 'text/plain; charset=UTF-8')] - else: - headers += [('Content-Type', 'application/octet-stream')] + headers += [('Content-Type', 'text/plain; charset=UTF-8')] return {'status': '200 OK', 'headers': headers, 'response': payload.Payload(packets=packets).encode( - b64=b64, jsonp_index=jsonp_index)} + jsonp_index=jsonp_index).encode('utf-8')} else: return {'status': '200 OK', 'headers': [('Content-Type', 'text/plain')], 'response': b'OK'} - def _bad_request(self): + def _bad_request(self, message=None): """Generate a bad request HTTP error response.""" + if message is None: + message = 'Bad Request' + message = packet.Packet.json.dumps(message) return {'status': '400 BAD REQUEST', 'headers': [('Content-Type', 'text/plain')], - 'response': b'Bad Request'} + 'response': message.encode('utf-8')} def _method_not_found(self): """Generate a method not found HTTP error response.""" @@ -591,11 +642,14 @@ 'headers': [('Content-Type', 'text/plain')], 'response': b'Method Not Found'} - def _unauthorized(self): + def _unauthorized(self, message=None): """Generate an unauthorized HTTP error response.""" + if message is None: + message = 'Unauthorized' + message = packet.Packet.json.dumps(message) return {'status': '401 UNAUTHORIZED', - 'headers': [('Content-Type', 'text/plain')], - 'response': b'Unauthorized'} + 'headers': [('Content-Type', 'application/json')], + 'response': message.encode('utf-8')} def _cors_allowed_origins(self, environ): default_origins = [] @@ -613,7 +667,7 @@ allowed_origins = default_origins elif self.cors_allowed_origins == '*': allowed_origins = None - elif isinstance(self.cors_allowed_origins, six.string_types): + elif isinstance(self.cors_allowed_origins, str): allowed_origins = [self.cors_allowed_origins] else: allowed_origins = self.cors_allowed_origins @@ -641,7 +695,7 @@ def _gzip(self, response): """Apply gzip compression to a
response.""" - bytesio = six.BytesIO() + bytesio = io.BytesIO() with gzip.GzipFile(fileobj=bytesio, mode='w') as gz: gz.write(response) return bytesio.getvalue() @@ -650,6 +704,16 @@ class Server(object): """Apply deflate compression to a response.""" return zlib.compress(response) + def _log_error_once(self, message, message_key): + """Log message with logging.ERROR level the first time, then log + with given level.""" + if message_key not in self.log_message_keys: + self.logger.error(message + ' (further occurrences of this error ' + 'will be logged with level INFO)') + self.log_message_keys.add(message_key) + else: + self.logger.info(message) + def _service_task(self): # pragma: no cover """Monitor connected clients and clean up those that time out.""" while True: diff --git a/libs/engineio/socket.py b/libs/engineio/socket.py index 38593e7c7..1434b191d 100644 --- a/libs/engineio/socket.py +++ b/libs/engineio/socket.py @@ -1,4 +1,3 @@ -import six import sys import time @@ -15,11 +14,10 @@ class Socket(object): self.server = server self.sid = sid self.queue = self.server.create_queue() - self.last_ping = time.time() + self.last_ping = None self.connected = False self.upgrading = False self.upgraded = False - self.packet_backlog = [] self.closing = False self.closed = False self.session = {} @@ -28,7 +26,8 @@ class Socket(object): """Wait for packets to send to the client.""" queue_empty = self.server.get_queue_empty_exception() try: - packets = [self.queue.get(timeout=self.server.ping_timeout)] + packets = [self.queue.get( + timeout=self.server.ping_interval + self.server.ping_timeout)] self.queue.task_done() except queue_empty: raise exceptions.QueueEmpty() @@ -36,8 +35,12 @@ class Socket(object): return [] while True: try: - packets.append(self.queue.get(block=False)) + pkt = self.queue.get(block=False) self.queue.task_done() + if pkt is None: + self.queue.put(None) + break + packets.append(pkt) except queue_empty: break return packets @@ -50,9 +53,8 @@ class Socket(object): self.sid, packet_name, pkt.data if not isinstance(pkt.data, bytes) else '') - if pkt.packet_type == packet.PING: - self.last_ping = time.time() - self.send(packet.Packet(packet.PONG, pkt.data)) + if pkt.packet_type == packet.PONG: + self.schedule_ping() elif pkt.packet_type == packet.MESSAGE: self.server._trigger_event('message', self.sid, pkt.data, run_async=self.server.async_handlers) @@ -64,14 +66,11 @@ class Socket(object): raise exceptions.UnknownPacketError() def check_ping_timeout(self): - """Make sure the client is still sending pings. - - This helps detect disconnections for long-polling clients. 
- """ + """Make sure the client is still responding to pings.""" if self.closed: raise exceptions.SocketIsClosedError() - if time.time() - self.last_ping > self.server.ping_interval + \ - self.server.ping_interval_grace_period: + if self.last_ping and \ + time.time() - self.last_ping > self.server.ping_timeout: self.server.logger.info('%s: Client is gone, closing socket', self.sid) # Passing abort=False here will cause close() to write a @@ -85,8 +84,6 @@ class Socket(object): """Send a packet to the client.""" if not self.check_ping_timeout(): return - if self.upgrading: - self.packet_backlog.append(pkt) else: self.queue.put(pkt) self.server.logger.info('%s: Sending packet %s data %s', @@ -105,12 +102,16 @@ class Socket(object): self.sid, transport) return getattr(self, '_upgrade_' + transport)(environ, start_response) + if self.upgrading or self.upgraded: + # we are upgrading to WebSocket, do not return any more packets + # through the polling endpoint + return [packet.Packet(packet.NOOP)] try: packets = self.poll() except exceptions.QueueEmpty: exc = sys.exc_info() self.close(wait=False) - six.reraise(*exc) + raise exc[1].with_traceback(exc[2]) return packets def handle_post_request(self, environ): @@ -119,7 +120,7 @@ class Socket(object): if length > self.server.max_http_buffer_size: raise exceptions.ContentTooLongError() else: - body = environ['wsgi.input'].read(length) + body = environ['wsgi.input'].read(length).decode('utf-8') p = payload.Payload(encoded_payload=body) for pkt in p.packets: self.receive(pkt) @@ -136,6 +137,16 @@ class Socket(object): if wait: self.queue.join() + def schedule_ping(self): + def send_ping(): + self.last_ping = None + self.server.sleep(self.server.ping_interval) + if not self.closing and not self.closed: + self.last_ping = time.time() + self.send(packet.Packet(packet.PING)) + + self.server.start_background_task(send_ping) + def _upgrade_websocket(self, environ, start_response): """Upgrade the connection from polling to websocket.""" if self.upgraded: @@ -149,9 +160,11 @@ class Socket(object): def _websocket_handler(self, ws): """Engine.IO handler for websocket transport.""" # try to set a socket timeout matching the configured ping interval + # and timeout for attr in ['_sock', 'socket']: # pragma: no cover if hasattr(ws, attr) and hasattr(getattr(ws, attr), 'settimeout'): - getattr(ws, attr).settimeout(self.server.ping_timeout) + getattr(ws, attr).settimeout( + self.server.ping_interval + self.server.ping_timeout) if self.connected: # the socket was already connected, so this is an upgrade @@ -163,10 +176,9 @@ class Socket(object): decoded_pkt.data != 'probe': self.server.logger.info( '%s: Failed websocket upgrade, no PING packet', self.sid) + self.upgrading = False return [] - ws.send(packet.Packet( - packet.PONG, - data=six.text_type('probe')).encode(always_bytes=False)) + ws.send(packet.Packet(packet.PONG, data='probe').encode()) self.queue.put(packet.Packet(packet.NOOP)) # end poll pkt = ws.wait() @@ -177,13 +189,9 @@ class Socket(object): ('%s: Failed websocket upgrade, expected UPGRADE packet, ' 'received %s instead.'), self.sid, pkt) + self.upgrading = False return [] self.upgraded = True - - # flush any packets that were sent during the upgrade - for pkt in self.packet_backlog: - self.queue.put(pkt) - self.packet_backlog = [] self.upgrading = False else: self.connected = True @@ -202,7 +210,7 @@ class Socket(object): break try: for pkt in packets: - ws.send(pkt.encode(always_bytes=False)) + ws.send(pkt.encode()) except: break writer_task = 
self.server.start_background_task(writer) @@ -225,8 +233,6 @@ class Socket(object): if p is None: # connection closed by client break - if isinstance(p, six.text_type): # pragma: no cover - p = p.encode('utf-8') pkt = packet.Packet(encoded_packet=p) try: self.receive(pkt) diff --git a/libs/flask_socketio/__init__.py b/libs/flask_socketio/__init__.py index e4209f1e9..b8f993532 100644 --- a/libs/flask_socketio/__init__.py +++ b/libs/flask_socketio/__init__.py @@ -6,7 +6,7 @@ import sys # python-socketio gevent_socketio_found = True try: - from socketio import socketio_manage + from socketio import socketio_manage # noqa: F401 except ImportError: gevent_socketio_found = False if gevent_socketio_found: @@ -16,17 +16,17 @@ if gevent_socketio_found: sys.exit(1) import flask -from flask import _request_ctx_stack, json as flask_json +from flask import _request_ctx_stack, has_request_context, json as flask_json from flask.sessions import SessionMixin import socketio -from socketio.exceptions import ConnectionRefusedError +from socketio.exceptions import ConnectionRefusedError # noqa: F401 from werkzeug.debug import DebuggedApplication from werkzeug.serving import run_with_reloader from .namespace import Namespace from .test_client import SocketIOTestClient -__version__ = '4.2.1' +__version__ = '5.0.2dev' class _SocketIOMiddleware(socketio.WSGIApp): @@ -75,8 +75,8 @@ class SocketIO(object): :param channel: The channel name, when using a message queue. If a channel isn't specified, a default channel will be used. If multiple clusters of SocketIO processes need to use the - same message queue without interfering with each other, then - each cluster should use a different channel. + same message queue without interfering with each other, + then each cluster should use a different channel. :param path: The path where the Socket.IO server is exposed. Defaults to ``'socket.io'``. Leave this as is unless you know what you are doing. @@ -93,13 +93,8 @@ class SocketIO(object): explicitly. :param logger: To enable logging set to ``True`` or pass a logger object to use. To disable logging set to ``False``. The default is - ``False``. - :param binary: ``True`` to support binary payloads, ``False`` to treat all - payloads as text. On Python 2, if this is set to ``True``, - ``unicode`` values are treated as text, and ``str`` and - ``bytes`` values are treated as binary. This option has no - effect on Python 3, where text and binary payloads are - always automatically discovered. + ``False``. Note that fatal errors will be logged even when + ``logger`` is ``False``. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library @@ -124,20 +119,22 @@ class SocketIO(object): :param async_mode: The asynchronous model to use. See the Deployment section in the documentation for a description of the - available options. Valid async modes are - ``threading``, ``eventlet``, ``gevent`` and - ``gevent_uwsgi``. If this argument is not given, - ``eventlet`` is tried first, then ``gevent_uwsgi``, - then ``gevent``, and finally ``threading``. The - first async mode that has all its dependencies installed - is then one that is chosen. + available options. Valid async modes are ``threading``, + ``eventlet``, ``gevent`` and ``gevent_uwsgi``. If this + argument is not given, ``eventlet`` is tried first, then + ``gevent_uwsgi``, then ``gevent``, and finally + ``threading``. 
The first async mode that has all its + dependencies installed is the one that is chosen. + :param ping_interval: The interval in seconds at which the server pings + the client. The default is 25 seconds. For advanced + control, a two element tuple can be given, where + the first number is the ping interval and the second + is a grace period added by the server. :param ping_timeout: The time in seconds that the client waits for the - server to respond before disconnecting. The default is - 60 seconds. - :param ping_interval: The interval in seconds at which the client pings - the server. The default is 25 seconds. + server to respond before disconnecting. The default + is 5 seconds. :param max_http_buffer_size: The maximum size of a message when using the - polling transport. The default is 100,000,000 + polling transport. The default is 1,000,000 bytes. :param allow_upgrades: Whether to allow transport upgrades or not. The default is ``True``. @@ -146,9 +143,14 @@ :param compression_threshold: Only compress messages when their byte size is greater than this value. The default is 1024 bytes. - :param cookie: Name of the HTTP cookie that contains the client session - id. If set to ``None``, a cookie is not sent to the client. - The default is ``'io'``. + :param cookie: If set to a string, it is the name of the HTTP cookie the + server sends back to the client containing the client + session id. If set to a dictionary, the ``'name'`` key + contains the cookie name and other keys define cookie + attributes, where the value of each attribute can be a + string, a callable with no arguments, or a boolean. If set + to ``None`` (the default), a cookie is not sent to the + client. :param cors_allowed_origins: Origin or list of origins that are allowed to connect to this server. Only the same origin is allowed by default. Set this argument to @@ -163,7 +165,9 @@ default is ``True``. :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass a logger object to use. To disable logging set to - ``False``. The default is ``False``. + ``False``. The default is ``False``. Note that + fatal errors are logged even when + ``engineio_logger`` is ``False``. """ def __init__(self, app=None, **kwargs): @@ -237,10 +241,6 @@ resource = resource[1:] if os.environ.get('FLASK_RUN_FROM_CLI'): if self.server_options.get('async_mode') is None: - if app is not None: - app.logger.warning( - 'Flask-SocketIO is Running under Werkzeug, WebSocket ' - 'is not available.') self.server_options['async_mode'] = 'threading' self.server = socketio.Server(**self.server_options) self.async_mode = self.server.async_mode @@ -250,8 +250,9 @@ self.server.register_namespace(namespace_handler) if app is not None: - # here we attach the SocketIO middlware to the SocketIO object so it - # can be referenced later if debug middleware needs to be inserted + # here we attach the SocketIO middleware to the SocketIO object so + # it can be referenced later if debug middleware needs to be + # inserted self.sockio_mw = _SocketIOMiddleware(self.server, app, socketio_path=resource) app.wsgi_app = self.sockio_mw @@ -355,6 +356,41 @@ """ self.on(message, namespace=namespace)(handler) + def event(self, *args, **kwargs): + """Decorator to register an event handler. + + This is a simplified version of the ``on()`` method that takes the + event name from the decorated function.
+ + Example usage:: + + @socketio.event + def my_event(data): + print('Received data: ', data) + + The above example is equivalent to:: + + @socketio.on('my_event') + def my_event(data): + print('Received data: ', data) + + A custom namespace can be given as an argument to the decorator:: + + @socketio.event(namespace='/test') + def my_event(data): + print('Received data: ', data) + """ + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + # the decorator was invoked without arguments + # args[0] is the decorated function + return self.on(args[0].__name__)(args[0]) + else: + # the decorator was invoked with arguments + def set_handler(handler): + return self.on(handler.__name__, *args, **kwargs)(handler) + + return set_handler + def on_namespace(self, namespace_handler): if not isinstance(namespace_handler, Namespace): raise ValueError('Not a namespace instance.') @@ -382,9 +418,12 @@ :param args: A dictionary with the JSON data to send as payload. :param namespace: The namespace under which the message is to be sent. Defaults to the global namespace. - :param room: Send the message to all the users in the given room. If - this parameter is not included, the event is sent to - all connected users. + :param to: Send the message to all the users in the given room. If + this parameter is not included, the event is sent to all + connected users. + :param include_self: ``True`` to include the sender when broadcasting + or addressing a room, or ``False`` to send to + everyone but the sender. :param skip_sid: The session id of a client to ignore when broadcasting or addressing a room. This is typically set to the originator of the message, so that everyone except @@ -397,7 +436,7 @@ only be used when addressing an individual client. """ namespace = kwargs.pop('namespace', '/') - room = kwargs.pop('room', None) + to = kwargs.pop('to', kwargs.pop('room', None)) include_self = kwargs.pop('include_self', True) skip_sid = kwargs.pop('skip_sid', None) if not include_self and not skip_sid: @@ -405,18 +444,25 @@ callback = kwargs.pop('callback', None) if callback: # wrap the callback so that it sets app and request contexts - sid = flask.request.sid + sid = None + if has_request_context(): + sid = getattr(flask.request, 'sid', None) original_callback = callback def _callback_wrapper(*args): return self._handle_event(original_callback, None, namespace, sid, *args) - callback = _callback_wrapper - self.server.emit(event, *args, namespace=namespace, room=room, + if sid: + # the callback wrapper above will install a request context + # before invoking the original callback + # we only use it if the emit was issued from a Socket.IO + # populated request context (i.e. request.sid is defined) + callback = _callback_wrapper + self.server.emit(event, *args, namespace=namespace, to=to, skip_sid=skip_sid, callback=callback, **kwargs) - def send(self, data, json=False, namespace=None, room=None, + def send(self, data, json=False, namespace=None, to=None, callback=None, include_self=True, skip_sid=None, **kwargs): """Send a server-generated SocketIO message. @@ -431,9 +477,12 @@ otherwise. :param namespace: The namespace under which the message is to be sent. Defaults to the global namespace. - :param room: Send the message only to the users in the given room. If - this parameter is not included, the message is sent to - all connected users. + :param to: Send the message only to the users in the given room.
If + this parameter is not included, the message is sent to all + connected users. + :param include_self: ``True`` to include the sender when broadcasting + or addressing a room, or ``False`` to send to + everyone but the sender. :param skip_sid: The session id of a client to ignore when broadcasting or addressing a room. This is typically set to the originator of the message, so that everyone except @@ -447,10 +496,10 @@ class SocketIO(object): """ skip_sid = flask.request.sid if not include_self else skip_sid if json: - self.emit('json', data, namespace=namespace, room=room, + self.emit('json', data, namespace=namespace, to=to, skip_sid=skip_sid, callback=callback, **kwargs) else: - self.emit('message', data, namespace=namespace, room=room, + self.emit('message', data, namespace=namespace, to=to, skip_sid=skip_sid, callback=callback, **kwargs) def close_room(self, room, namespace=None): @@ -480,7 +529,7 @@ class SocketIO(object): to disable it. :param extra_files: A list of additional files that the Flask reloader should watch. Defaults to ``None`` - :param log_output: If ``True``, the server logs all incomming + :param log_output: If ``True``, the server logs all incoming connections. If ``False`` logging is disabled. Defaults to ``True`` in debug mode, ``False`` in normal mode. Unused when the threading async @@ -529,8 +578,8 @@ class SocketIO(object): # o # Flask-SocketIO WebSocket handler # - self.sockio_mw.wsgi_app = DebuggedApplication(self.sockio_mw.wsgi_app, - evalex=True) + self.sockio_mw.wsgi_app = DebuggedApplication( + self.sockio_mw.wsgi_app, evalex=True) if self.server.eio.async_mode == 'threading': from werkzeug._internal import _log @@ -546,8 +595,10 @@ class SocketIO(object): import eventlet.green addresses = eventlet.green.socket.getaddrinfo(host, port) if not addresses: - raise RuntimeError('Could not resolve host to a valid address') - eventlet_socket = eventlet.listen(addresses[0][4], addresses[0][0]) + raise RuntimeError( + 'Could not resolve host to a valid address') + eventlet_socket = eventlet.listen(addresses[0][4], + addresses[0][0]) # If provided an SSL argument, use an SSL socket ssl_args = ['keyfile', 'certfile', 'server_side', 'cert_reqs', @@ -575,6 +626,9 @@ class SocketIO(object): from geventwebsocket.handler import WebSocketHandler websocket = True except ImportError: + app.logger.warning( + 'WebSocket transport not available. 
Install ' + 'gevent-websocket for improved performance.') websocket = False log = 'default' @@ -668,18 +722,18 @@ flask_test_client=flask_test_client) def _handle_event(self, handler, message, namespace, sid, *args): - if sid not in self.server.environ: + environ = self.server.get_environ(sid, namespace=namespace) + if not environ: # we don't have record of this client, ignore this event return '', 400 - app = self.server.environ[sid]['flask.app'] - with app.request_context(self.server.environ[sid]): + app = environ['flask.app'] + with app.request_context(environ): if self.manage_session: # manage a separate session for this client's Socket.IO events # created as a copy of the regular user session - if 'saved_session' not in self.server.environ[sid]: - self.server.environ[sid]['saved_session'] = \ - _ManagedSession(flask.session) - session_obj = self.server.environ[sid]['saved_session'] + if 'saved_session' not in environ: + environ['saved_session'] = _ManagedSession(flask.session) + session_obj = environ['saved_session'] else: # let Flask handle the user session # for cookie based sessions, this effectively freezes the @@ -705,7 +759,8 @@ return err_handler(value) if not self.manage_session: # when Flask is managing the user session, it needs to save it - if not hasattr(session_obj, 'modified') or session_obj.modified: + if not hasattr(session_obj, 'modified') or \ + session_obj.modified: resp = app.response_class() app.session_interface.save_session(app, session_obj, resp) return ret @@ -733,17 +788,23 @@ def emit(event, *args, **kwargs): acknowledgement. :param broadcast: ``True`` to send the message to all clients, or ``False`` to only reply to the sender of the originating event. - :param room: Send the message to all the users in the given room. If this - argument is set, then broadcast is implied to be ``True``. + :param to: Send the message to all the users in the given room. If this + argument is not set and ``broadcast`` is ``False``, then the + message is sent only to the originating user. :param include_self: ``True`` to include the sender when broadcasting or addressing a room, or ``False`` to send to everyone but the sender. + :param skip_sid: The session id of a client to ignore when broadcasting + or addressing a room. This is typically set to the + originator of the message, so that everyone except + that client receives the message. To skip multiple sids, + pass a list. :param ignore_queue: Only used when a message queue is configured. If set to ``True``, the event is emitted to the clients directly, without going through the queue. This is more efficient, but only works when a single server process is used, or when there is a - single addresee. It is recommended to always leave + single addressee. It is recommended to always leave this parameter with its default value of ``False``.
""" if 'namespace' in kwargs: @@ -752,16 +813,17 @@ def emit(event, *args, **kwargs): namespace = flask.request.namespace callback = kwargs.get('callback') broadcast = kwargs.get('broadcast') - room = kwargs.get('room') - if room is None and not broadcast: - room = flask.request.sid + to = kwargs.pop('to', kwargs.pop('room', None)) + if to is None and not broadcast: + to = flask.request.sid include_self = kwargs.get('include_self', True) + skip_sid = kwargs.get('skip_sid') ignore_queue = kwargs.get('ignore_queue', False) socketio = flask.current_app.extensions['socketio'] - return socketio.emit(event, *args, namespace=namespace, room=room, - include_self=include_self, callback=callback, - ignore_queue=ignore_queue) + return socketio.emit(event, *args, namespace=namespace, to=to, + include_self=include_self, skip_sid=skip_sid, + callback=callback, ignore_queue=ignore_queue) def send(message, **kwargs): @@ -783,16 +845,23 @@ def send(message, **kwargs): :param broadcast: ``True`` to send the message to all connected clients, or ``False`` to only reply to the sender of the originating event. - :param room: Send the message to all the users in the given room. + :param to: Send the message to all the users in the given room. If this + argument is not set and ``broadcast`` is ``False``, then the + message is sent only to the originating user. :param include_self: ``True`` to include the sender when broadcasting or addressing a room, or ``False`` to send to everyone but the sender. + :param skip_sid: The session id of a client to ignore when broadcasting + or addressing a room. This is typically set to the + originator of the message, so that everyone except + that client receive the message. To skip multiple sids + pass a list. :param ignore_queue: Only used when a message queue is configured. If set to ``True``, the event is emitted to the clients directly, without going through the queue. This is more efficient, but only works when a single server process is used, or when there is a - single addresee. It is recommended to always leave + single addressee. It is recommended to always leave this parameter with its default value of ``False``. """ json = kwargs.get('json', False) @@ -802,16 +871,17 @@ def send(message, **kwargs): namespace = flask.request.namespace callback = kwargs.get('callback') broadcast = kwargs.get('broadcast') - room = kwargs.get('room') - if room is None and not broadcast: - room = flask.request.sid + to = kwargs.pop('to', kwargs.pop('room', None)) + if to is None and not broadcast: + to = flask.request.sid include_self = kwargs.get('include_self', True) + skip_sid = kwargs.get('skip_sid') ignore_queue = kwargs.get('ignore_queue', False) socketio = flask.current_app.extensions['socketio'] - return socketio.send(message, json=json, namespace=namespace, room=room, - include_self=include_self, callback=callback, - ignore_queue=ignore_queue) + return socketio.send(message, json=json, namespace=namespace, to=to, + include_self=include_self, skip_sid=skip_sid, + callback=callback, ignore_queue=ignore_queue) def join_room(room, sid=None, namespace=None): diff --git a/libs/flask_socketio/namespace.py b/libs/flask_socketio/namespace.py index 914ff3816..43833a9bd 100644 --- a/libs/flask_socketio/namespace.py +++ b/libs/flask_socketio/namespace.py @@ -14,7 +14,7 @@ class Namespace(_Namespace): In the most common usage, this method is not overloaded by subclasses, as it performs the routing of events to methods. 
However, this - method can be overriden if special dispatching rules are needed, or if + method can be overridden if special dispatching rules are needed, or if having a single method that catches all events is desired. """ handler_name = 'on_' + event @@ -44,4 +44,3 @@ class Namespace(_Namespace): """Close a room.""" return self.socketio.close_room(room=room, namespace=namespace or self.namespace) - diff --git a/libs/flask_socketio/test_client.py b/libs/flask_socketio/test_client.py index 0c4592034..84d3f5649 100644 --- a/libs/flask_socketio/test_client.py +++ b/libs/flask_socketio/test_client.py @@ -28,36 +28,46 @@ class SocketIOTestClient(object): def __init__(self, app, socketio, namespace=None, query_string=None, headers=None, flask_test_client=None): - def _mock_send_packet(sid, pkt): + def _mock_send_packet(eio_sid, pkt): + # make sure the packet can be encoded and decoded + epkt = pkt.encode() + if not isinstance(epkt, list): + pkt = packet.Packet(encoded_packet=epkt) + else: + pkt = packet.Packet(encoded_packet=epkt[0]) + for att in epkt[1:]: + pkt.add_attachment(att) if pkt.packet_type == packet.EVENT or \ pkt.packet_type == packet.BINARY_EVENT: - if sid not in self.queue: - self.queue[sid] = [] + if eio_sid not in self.queue: + self.queue[eio_sid] = [] if pkt.data[0] == 'message' or pkt.data[0] == 'json': - self.queue[sid].append({'name': pkt.data[0], - 'args': pkt.data[1], - 'namespace': pkt.namespace or '/'}) + self.queue[eio_sid].append({ + 'name': pkt.data[0], + 'args': pkt.data[1], + 'namespace': pkt.namespace or '/'}) else: - self.queue[sid].append({'name': pkt.data[0], - 'args': pkt.data[1:], - 'namespace': pkt.namespace or '/'}) + self.queue[eio_sid].append({ + 'name': pkt.data[0], + 'args': pkt.data[1:], + 'namespace': pkt.namespace or '/'}) elif pkt.packet_type == packet.ACK or \ pkt.packet_type == packet.BINARY_ACK: - self.acks[sid] = {'args': pkt.data, - 'namespace': pkt.namespace or '/'} - elif pkt.packet_type == packet.DISCONNECT: + self.acks[eio_sid] = {'args': pkt.data, + 'namespace': pkt.namespace or '/'} + elif pkt.packet_type in [packet.DISCONNECT, packet.CONNECT_ERROR]: self.connected[pkt.namespace or '/'] = False self.app = app self.flask_test_client = flask_test_client - self.sid = uuid.uuid4().hex - self.queue[self.sid] = [] - self.acks[self.sid] = None + self.eio_sid = uuid.uuid4().hex + self.acks[self.eio_sid] = None + self.queue[self.eio_sid] = [] self.callback_counter = 0 self.socketio = socketio self.connected = {} socketio.server._send_packet = _mock_send_packet - socketio.server.environ[self.sid] = {} + socketio.server.environ[self.eio_sid] = {} socketio.server.async_handlers = False # easier to test when socketio.server.eio.async_handlers = False # events are sync if isinstance(socketio.server.manager, PubSubManager): @@ -91,6 +101,7 @@ class SocketIOTestClient(object): is when the application accepts multiple namespace connections. """ url = '/socket.io' + namespace = namespace or '/' if query_string: if query_string[0] != '?': query_string = '?' 
+ query_string @@ -100,17 +111,15 @@ class SocketIOTestClient(object): if self.flask_test_client: # inject cookies from Flask self.flask_test_client.cookie_jar.inject_wsgi(environ) - self.connected['/'] = True - if self.socketio.server._handle_eio_connect( - self.sid, environ) is False: - del self.connected['/'] - if namespace is not None and namespace != '/': + self.socketio.server._handle_eio_connect(self.eio_sid, environ) + pkt = packet.Packet(packet.CONNECT, namespace=namespace) + with self.app.app_context(): + self.socketio.server._handle_eio_message(self.eio_sid, + pkt.encode()) + sid = self.socketio.server.manager.sid_from_eio_sid(self.eio_sid, + namespace) + if sid: self.connected[namespace] = True - pkt = packet.Packet(packet.CONNECT, namespace=namespace) - with self.app.app_context(): - if self.socketio.server._handle_eio_message( - self.sid, pkt.encode()) is False: - del self.connected[namespace] def disconnect(self, namespace=None): """Disconnect the client. @@ -122,7 +131,8 @@ class SocketIOTestClient(object): raise RuntimeError('not connected') pkt = packet.Packet(packet.DISCONNECT, namespace=namespace) with self.app.app_context(): - self.socketio.server._handle_eio_message(self.sid, pkt.encode()) + self.socketio.server._handle_eio_message(self.eio_sid, + pkt.encode()) del self.connected[namespace or '/'] def emit(self, event, *args, **kwargs): @@ -154,10 +164,12 @@ class SocketIOTestClient(object): encoded_pkt = pkt.encode() if isinstance(encoded_pkt, list): for epkt in encoded_pkt: - self.socketio.server._handle_eio_message(self.sid, epkt) + self.socketio.server._handle_eio_message(self.eio_sid, + epkt) else: - self.socketio.server._handle_eio_message(self.sid, encoded_pkt) - ack = self.acks.pop(self.sid, None) + self.socketio.server._handle_eio_message(self.eio_sid, + encoded_pkt) + ack = self.acks.pop(self.eio_sid, None) if ack is not None: return ack['args'][0] if len(ack['args']) == 1 \ else ack['args'] @@ -198,8 +210,8 @@ class SocketIOTestClient(object): if not self.is_connected(namespace): raise RuntimeError('not connected') namespace = namespace or '/' - r = [pkt for pkt in self.queue[self.sid] + r = [pkt for pkt in self.queue[self.eio_sid] if pkt['namespace'] == namespace] - self.queue[self.sid] = [pkt for pkt in self.queue[self.sid] - if pkt not in r] + self.queue[self.eio_sid] = [ + pkt for pkt in self.queue[self.eio_sid] if pkt not in r] return r diff --git a/libs/socketio/__init__.py b/libs/socketio/__init__.py index d3ee7242b..529f506bb 100644 --- a/libs/socketio/__init__.py +++ b/libs/socketio/__init__.py @@ -27,7 +27,7 @@ else: # pragma: no cover AsyncRedisManager = None AsyncAioPikaManager = None -__version__ = '4.4.0' +__version__ = '5.1.0' __all__ = ['__version__', 'Client', 'Server', 'BaseManager', 'PubSubManager', 'KombuManager', 'RedisManager', 'ZmqManager', 'KafkaManager', diff --git a/libs/socketio/asgi.py b/libs/socketio/asgi.py index 9bcdd03ba..2394ee1e2 100644 --- a/libs/socketio/asgi.py +++ b/libs/socketio/asgi.py @@ -16,6 +16,10 @@ class ASGIApp(engineio.ASGIApp): # pragma: no cover :param socketio_path: The endpoint where the Socket.IO application should be installed. The default value is appropriate for most cases. 
+ :param on_startup: function to be called on application startup; can be + coroutine + :param on_shutdown: function to be called on application shutdown; can be + coroutine Example usage:: @@ -30,7 +34,9 @@ class ASGIApp(engineio.ASGIApp): # pragma: no cover uvicorn.run(app, host='127.0.0.1', port=5000) """ def __init__(self, socketio_server, other_asgi_app=None, - static_files=None, socketio_path='socket.io'): + static_files=None, socketio_path='socket.io', + on_startup=None, on_shutdown=None): super().__init__(socketio_server, other_asgi_app, static_files=static_files, - engineio_path=socketio_path) + engineio_path=socketio_path, on_startup=on_startup, + on_shutdown=on_shutdown) diff --git a/libs/socketio/asyncio_aiopika_manager.py b/libs/socketio/asyncio_aiopika_manager.py index b20d6afd9..905057d5c 100644 --- a/libs/socketio/asyncio_aiopika_manager.py +++ b/libs/socketio/asyncio_aiopika_manager.py @@ -30,7 +30,7 @@ class AsyncAioPikaManager(AsyncPubSubManager): # pragma: no cover notifications. Must be the same in all the servers. With this manager, the channel name is the exchange name in rabbitmq - :param write_only: If set ot ``True``, only initialize to emit events. The + :param write_only: If set to ``True``, only initialize to emit events. The default of ``False`` initializes the class for emitting and receiving. """ @@ -89,6 +89,7 @@ class AsyncAioPikaManager(AsyncPubSubManager): # pragma: no cover self.listener_queue = await self._queue( self.listener_channel, exchange ) + retry_sleep = 1 async with self.listener_queue.iterator() as queue_iter: async for message in queue_iter: diff --git a/libs/socketio/asyncio_client.py b/libs/socketio/asyncio_client.py index 2b10434ae..d89c6272c 100644 --- a/libs/socketio/asyncio_client.py +++ b/libs/socketio/asyncio_client.py @@ -3,7 +3,6 @@ import logging import random import engineio -import six from . import client from . import exceptions @@ -35,13 +34,8 @@ class AsyncClient(client.Client): adjusted by +/- 50%. :param logger: To enable logging set to ``True`` or pass a logger object to use. To disable logging set to ``False``. The default is - ``False``. - :param binary: ``True`` to support binary payloads, ``False`` to treat all - payloads as text. On Python 2, if this is set to ``True``, - ``unicode`` values are treated as text, and ``str`` and - ``bytes`` values are treated as binary. This option has no - effect on Python 3, where text and binary payloads are - always automatically discovered. + ``False``. Note that fatal errors are logged even when + ``logger`` is ``False``. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library versions. @@ -51,19 +45,26 @@ :param request_timeout: A timeout in seconds for requests. The default is 5 seconds. + :param http_session: an initialized ``aiohttp.ClientSession`` object to be used + when sending requests to the server. Use it if you + need to add special client options such as proxy + servers, SSL certificates, etc. :param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to skip SSL certificate verification, allowing connections to servers with self signed certificates. The default is ``True``. :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass a logger object to use. To disable logging set to - ``False``. The default is ``False``.
Note that + fatal errors are logged even when + ``engineio_logger`` is ``False``. """ def is_asyncio_based(self): return True async def connect(self, url, headers={}, transports=None, - namespaces=None, socketio_path='socket.io'): + namespaces=None, socketio_path='socket.io', wait=True, + wait_timeout=1): """Connect to a Socket.IO server. :param url: The URL of the Socket.IO server. It can include custom @@ -74,21 +75,32 @@ class AsyncClient(client.Client): are ``'polling'`` and ``'websocket'``. If not given, the polling transport is connected first, then an upgrade to websocket is attempted. - :param namespaces: The list of custom namespaces to connect, in - addition to the default namespace. If not given, - the namespace list is obtained from the registered - event handlers. + :param namespaces: The namespaces to connect as a string or list of + strings. If not given, the namespaces that have + registered event handlers are connected. :param socketio_path: The endpoint where the Socket.IO server is installed. The default value is appropriate for most cases. + :param wait: if set to ``True`` (the default) the call only returns + when all the namespaces are connected. If set to + ``False``, the call returns as soon as the Engine.IO + transport is connected, and the namespaces will connect + in the background. + :param wait_timeout: How long the client should wait for the + connection. The default is 1 second. This + argument is only considered when ``wait`` is set + to ``True``. Note: this method is a coroutine. Example usage:: - sio = socketio.Client() + sio = socketio.AsyncClient() sio.connect('http://localhost:5000') """ + if self.connected: + raise exceptions.ConnectionError('Already connected') + self.connection_url = url self.connection_headers = headers self.connection_transports = transports @@ -96,18 +108,43 @@ class AsyncClient(client.Client): self.socketio_path = socketio_path if namespaces is None: - namespaces = set(self.handlers.keys()).union( - set(self.namespace_handlers.keys())) - elif isinstance(namespaces, six.string_types): + namespaces = list(set(self.handlers.keys()).union( + set(self.namespace_handlers.keys()))) + if len(namespaces) == 0: + namespaces = ['/'] + elif isinstance(namespaces, str): namespaces = [namespaces] - self.connection_namespaces = namespaces - self.namespaces = [n for n in namespaces if n != '/'] + self.connection_namespaces = namespaces + self.namespaces = {} + if self._connect_event is None: + self._connect_event = self.eio.create_event() + else: + self._connect_event.clear() try: await self.eio.connect(url, headers=headers, transports=transports, engineio_path=socketio_path) except engineio.exceptions.ConnectionError as exc: - six.raise_from(exceptions.ConnectionError(exc.args[0]), None) + await self._trigger_event( + 'connect_error', '/', + exc.args[1] if len(exc.args) > 1 else exc.args[0]) + raise exceptions.ConnectionError(exc.args[0]) from None + + if wait: + try: + while True: + await asyncio.wait_for(self._connect_event.wait(), + wait_timeout) + self._connect_event.clear() + if set(self.namespaces) == set(self.connection_namespaces): + break + except asyncio.TimeoutError: + pass + if set(self.namespaces) != set(self.connection_namespaces): + await self.disconnect() + raise exceptions.ConnectionError( + 'One or more namespaces failed to connect') + self.connected = True async def wait(self): @@ -133,22 +170,28 @@ class AsyncClient(client.Client): :param event: The event name. It can be any string. 
The event names ``'connect'``, ``'message'`` and ``'disconnect'`` are reserved and should not be used. - :param data: The data to send to the client or clients. Data can be of - type ``str``, ``bytes``, ``list`` or ``dict``. If a - ``list`` or ``dict``, the data will be serialized as JSON. + :param data: The data to send to the server. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. To send + multiple arguments, use a tuple where each element is of + one of the types indicated above. :param namespace: The Socket.IO namespace for the event. If this argument is omitted the event is emitted to the default namespace. :param callback: If given, this function will be called to acknowledge - the the client has received the message. The arguments + that the server has received the message. The arguments that will be passed to the function are those provided - by the client. Callback functions can only be used - when addressing an individual client. + by the server. - Note: this method is a coroutine. + Note: this method is not designed to be used concurrently. If multiple + tasks are emitting at the same time on the same client connection, then + messages composed of multiple packets may end up being sent in an + incorrect sequence. Use standard concurrency solutions (such as a Lock + object) to prevent this situation. + + Note 2: this method is a coroutine. """ namespace = namespace or '/' - if namespace != '/' and namespace not in self.namespaces: + if namespace not in self.namespaces: raise exceptions.BadNamespaceError( namespace + ' is not a connected namespace.') self.logger.info('Emitting event "%s" [%s]', event, namespace) @@ -156,10 +199,6 @@ class AsyncClient(client.Client): id = self._generate_ack_id(namespace, callback) else: id = None - if six.PY2 and not self.binary: - binary = False # pragma: nocover - else: - binary = None # tuples are expanded to multiple arguments, everything else is sent # as a single argument if isinstance(data, tuple): @@ -169,8 +208,7 @@ class AsyncClient(client.Client): else: data = [] await self._send_packet(packet.Packet( - packet.EVENT, namespace=namespace, data=[event] + data, id=id, - binary=binary)) + packet.EVENT, namespace=namespace, data=[event] + data, id=id)) async def send(self, data, namespace=None, callback=None): """Send a message to one or more connected clients. @@ -178,17 +216,17 @@ This function emits an event with the name ``'message'``. Use :func:`emit` to issue custom event names. - :param data: The data to send to the client or clients. Data can be of - type ``str``, ``bytes``, ``list`` or ``dict``. If a - ``list`` or ``dict``, the data will be serialized as JSON. + :param data: The data to send to the server. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. To send + multiple arguments, use a tuple where each element is of + one of the types indicated above. :param namespace: The Socket.IO namespace for the event. If this argument is omitted the event is emitted to the default namespace. :param callback: If given, this function will be called to acknowledge - the the client has received the message. The arguments + that the server has received the message. The arguments that will be passed to the function are those provided - by the client. Callback functions can only be used - when addressing an individual client. + by the server. Note: this method is a coroutine. """ @@ -201,9 +239,10 @@ class AsyncClient(client.Client): :param event: The event name. It can be any string.
The event names ``'connect'``, ``'message'`` and ``'disconnect'`` are reserved and should not be used. - :param data: The data to send to the client or clients. Data can be of - type ``str``, ``bytes``, ``list`` or ``dict``. If a - ``list`` or ``dict``, the data will be serialized as JSON. + :param data: The data to send to the server. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. To send + multiple arguments, use a tuple where each element is of + one of the types indicated above. :param namespace: The Socket.IO namespace for the event. If this argument is omitted the event is emitted to the default namespace. @@ -211,7 +250,13 @@ class AsyncClient(client.Client): the client acknowledges the event, then a ``TimeoutError`` exception is raised. - Note: this method is a coroutine. + Note: this method is not designed to be used concurrently. If multiple + tasks are emitting at the same time on the same client connection, then + messages composed of multiple packets may end up being sent in an + incorrect sequence. Use standard concurrency solutions (such as a Lock + object) to prevent this situation. + + Note 2: this method is a coroutine. """ callback_event = self.eio.create_event() callback_args = [] @@ -225,7 +270,7 @@ class AsyncClient(client.Client): try: await asyncio.wait_for(callback_event.wait(), timeout) except asyncio.TimeoutError: - six.raise_from(exceptions.TimeoutError(), None) + raise exceptions.TimeoutError() from None return callback_args[0] if len(callback_args[0]) > 1 \ else callback_args[0][0] if len(callback_args[0]) == 1 \ else None @@ -240,9 +285,6 @@ class AsyncClient(client.Client): for n in self.namespaces: await self._send_packet(packet.Packet(packet.DISCONNECT, namespace=n)) - await self._send_packet(packet.Packet( - packet.DISCONNECT, namespace='/')) - self.connected = False await self.eio.disconnect(abort=True) def start_background_task(self, target, *args, **kwargs): @@ -278,37 +320,29 @@ class AsyncClient(client.Client): """Send a Socket.IO packet to the server.""" encoded_packet = pkt.encode() if isinstance(encoded_packet, list): - binary = False for ep in encoded_packet: - await self.eio.send(ep, binary=binary) - binary = True + await self.eio.send(ep) else: - await self.eio.send(encoded_packet, binary=False) + await self.eio.send(encoded_packet) - async def _handle_connect(self, namespace): + async def _handle_connect(self, namespace, data): namespace = namespace or '/' - self.logger.info('Namespace {} is connected'.format(namespace)) - await self._trigger_event('connect', namespace=namespace) - if namespace == '/': - for n in self.namespaces: - await self._send_packet(packet.Packet(packet.CONNECT, - namespace=n)) - elif namespace not in self.namespaces: - self.namespaces.append(namespace) + if namespace not in self.namespaces: + self.logger.info('Namespace {} is connected'.format(namespace)) + self.namespaces[namespace] = (data or {}).get('sid', self.sid) + await self._trigger_event('connect', namespace=namespace) + self._connect_event.set() async def _handle_disconnect(self, namespace): if not self.connected: return namespace = namespace or '/' - if namespace == '/': - for n in self.namespaces: - await self._trigger_event('disconnect', namespace=n) - self.namespaces = [] await self._trigger_event('disconnect', namespace=namespace) if namespace in self.namespaces: - self.namespaces.remove(namespace) - if namespace == '/': + del self.namespaces[namespace] + if not self.namespaces: self.connected = False + await 
self.eio.disconnect(abort=True) async def _handle_event(self, namespace, id, data): namespace = namespace or '/' @@ -323,13 +357,8 @@ class AsyncClient(client.Client): data = list(r) else: data = [r] - if six.PY2 and not self.binary: - binary = False # pragma: nocover - else: - binary = None await self._send_packet(packet.Packet( - packet.ACK, namespace=namespace, id=id, data=data, - binary=binary)) + packet.ACK, namespace=namespace, id=id, data=data)) async def _handle_ack(self, namespace, id, data): namespace = namespace or '/' @@ -357,10 +386,11 @@ class AsyncClient(client.Client): elif not isinstance(data, (tuple, list)): data = (data,) await self._trigger_event('connect_error', namespace, *data) + self._connect_event.set() if namespace in self.namespaces: - self.namespaces.remove(namespace) + del self.namespaces[namespace] if namespace == '/': - self.namespaces = [] + self.namespaces = {} self.connected = False async def _trigger_event(self, event, namespace, *args): @@ -382,6 +412,8 @@ class AsyncClient(client.Client): event, *args) async def _handle_reconnect(self): + if self._reconnect_abort is None: # pragma: no cover + self._reconnect_abort = self.eio.create_event() self._reconnect_abort.clear() client.reconnecting_clients.append(self) attempt_count = 0 @@ -421,10 +453,12 @@ class AsyncClient(client.Client): break client.reconnecting_clients.remove(self) - def _handle_eio_connect(self): + async def _handle_eio_connect(self): """Handle the Engine.IO connection event.""" self.logger.info('Engine.IO connection established') self.sid = self.eio.sid + for n in self.connection_namespaces: + await self._send_packet(packet.Packet(packet.CONNECT, namespace=n)) async def _handle_eio_message(self, data): """Dispatch Engine.IO messages.""" @@ -439,7 +473,7 @@ class AsyncClient(client.Client): else: pkt = packet.Packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: - await self._handle_connect(pkt.namespace) + await self._handle_connect(pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: await self._handle_disconnect(pkt.namespace) elif pkt.packet_type == packet.EVENT: @@ -449,7 +483,7 @@ class AsyncClient(client.Client): elif pkt.packet_type == packet.BINARY_EVENT or \ pkt.packet_type == packet.BINARY_ACK: self._binary_packet = pkt - elif pkt.packet_type == packet.ERROR: + elif pkt.packet_type == packet.CONNECT_ERROR: await self._handle_error(pkt.namespace, pkt.data) else: raise ValueError('Unknown packet type.') @@ -457,12 +491,10 @@ class AsyncClient(client.Client): async def _handle_eio_disconnect(self): """Handle the Engine.IO disconnection event.""" self.logger.info('Engine.IO connection dropped') - self._reconnect_abort.set() if self.connected: for n in self.namespaces: await self._trigger_event('disconnect', namespace=n) - await self._trigger_event('disconnect', namespace='/') - self.namespaces = [] + self.namespaces = {} self.connected = False self.callbacks = {} self._binary_packet = None diff --git a/libs/socketio/asyncio_manager.py b/libs/socketio/asyncio_manager.py index f4496ec7f..f89022c62 100644 --- a/libs/socketio/asyncio_manager.py +++ b/libs/socketio/asyncio_manager.py @@ -5,6 +5,9 @@ from .base_manager import BaseManager class AsyncManager(BaseManager): """Manage a client list for an asyncio server.""" + async def can_disconnect(self, sid, namespace): + return self.is_connected(sid, namespace) + async def emit(self, event, data, namespace, room=None, skip_sid=None, callback=None, **kwargs): """Emit a message to a single client, a room, or all 
the clients @@ -17,13 +20,13 @@ class AsyncManager(BaseManager): tasks = [] if not isinstance(skip_sid, list): skip_sid = [skip_sid] - for sid in self.get_participants(namespace, room): + for sid, eio_sid in self.get_participants(namespace, room): if sid not in skip_sid: if callback is not None: - id = self._generate_ack_id(sid, namespace, callback) + id = self._generate_ack_id(sid, callback) else: id = None - tasks.append(self.server._emit_internal(sid, event, data, + tasks.append(self.server._emit_internal(eio_sid, event, data, namespace, id)) if tasks == []: # pragma: no cover return @@ -36,19 +39,19 @@ class AsyncManager(BaseManager): """ return super().close_room(room, namespace) - async def trigger_callback(self, sid, namespace, id, data): + async def trigger_callback(self, sid, id, data): """Invoke an application callback. Note: this method is a coroutine. """ callback = None try: - callback = self.callbacks[sid][namespace][id] + callback = self.callbacks[sid][id] except KeyError: # if we get an unknown callback we just ignore it self._get_logger().warning('Unknown callback received, ignoring.') else: - del self.callbacks[sid][namespace][id] + del self.callbacks[sid][id] if callback is not None: ret = callback(*data) if asyncio.iscoroutine(ret): diff --git a/libs/socketio/asyncio_namespace.py b/libs/socketio/asyncio_namespace.py index 12e9c0fe6..f95ec2377 100644 --- a/libs/socketio/asyncio_namespace.py +++ b/libs/socketio/asyncio_namespace.py @@ -24,7 +24,7 @@ class AsyncNamespace(namespace.Namespace): In the most common usage, this method is not overloaded by subclasses, as it performs the routing of events to methods. However, this - method can be overriden if special dispatching rules are needed, or if + method can be overridden if special dispatching rules are needed, or if having a single method that catches all events is desired. Note: this method is a coroutine. @@ -149,7 +149,7 @@ class AsyncClientNamespace(namespace.ClientNamespace): In the most common usage, this method is not overloaded by subclasses, as it performs the routing of events to methods. However, this - method can be overriden if special dispatching rules are needed, or if + method can be overridden if special dispatching rules are needed, or if having a single method that catches all events is desired. Note: this method is a coroutine. 
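As context for the ``dispatch`` docstrings above: events are routed to ``on_<event>`` methods, so a class-based async handler looks roughly like this sketch (namespace and event names are hypothetical)::

    import socketio

    class ChatNamespace(socketio.AsyncNamespace):
        async def on_connect(self, sid, environ):
            print('client connected:', sid)

        async def on_message(self, sid, data):
            # dispatch() resolves the 'message' event to this method
            await self.emit('message', data, skip_sid=sid)

    sio = socketio.AsyncServer(async_mode='asgi')
    sio.register_namespace(ChatNamespace('/chat'))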
diff --git a/libs/socketio/asyncio_pubsub_manager.py b/libs/socketio/asyncio_pubsub_manager.py index 6fdba6d0c..cabd41e70 100644 --- a/libs/socketio/asyncio_pubsub_manager.py +++ b/libs/socketio/asyncio_pubsub_manager.py @@ -3,7 +3,6 @@ import uuid import json import pickle -import six from .asyncio_manager import AsyncManager @@ -60,7 +59,7 @@ class AsyncPubSubManager(AsyncManager): 'context of a server.') if room is None: raise ValueError('Cannot use callback without a room set.') - id = self._generate_ack_id(room, namespace, callback) + id = self._generate_ack_id(room, callback) callback = (room, namespace, id) else: callback = None @@ -69,6 +68,15 @@ 'skip_sid': skip_sid, 'callback': callback, 'host_id': self.host_id}) + async def can_disconnect(self, sid, namespace): + if self.is_connected(sid, namespace): + # client is in this server, so we can disconnect directly + return await super().can_disconnect(sid, namespace) + else: + # client is in another server, so we post a request to the queue + await self._publish({'method': 'disconnect', 'sid': sid, + 'namespace': namespace or '/'}) + async def close_room(self, room, namespace=None): await self._publish({'method': 'close_room', 'room': room, 'namespace': namespace or '/'}) @@ -113,12 +121,11 @@ if self.host_id == message.get('host_id'): try: sid = message['sid'] - namespace = message['namespace'] id = message['id'] args = message['args'] except KeyError: return - await self.trigger_callback(sid, namespace, id, args) + await self.trigger_callback(sid, id, args) async def _return_callback(self, host_id, sid, namespace, callback_id, *args): @@ -128,6 +135,11 @@ 'sid': sid, 'namespace': namespace, 'id': callback_id, 'args': args}) + async def _handle_disconnect(self, message): + await self.server.disconnect(sid=message.get('sid'), + namespace=message.get('namespace'), + ignore_queue=True) + async def _handle_close_room(self, message): await super().close_room( room=message.get('room'), namespace=message.get('namespace')) @@ -144,7 +156,7 @@ if isinstance(message, dict): data = message else: - if isinstance(message, six.binary_type): # pragma: no cover + if isinstance(message, bytes): # pragma: no cover try: data = pickle.loads(message) except: @@ -155,9 +167,13 @@ pass if data and 'method' in data: + self._get_logger().info('pubsub message: {}'.format( + data['method'])) if data['method'] == 'emit': await self._handle_emit(data) elif data['method'] == 'callback': await self._handle_callback(data) + elif data['method'] == 'disconnect': + await self._handle_disconnect(data) elif data['method'] == 'close_room': await self._handle_close_room(data) diff --git a/libs/socketio/asyncio_redis_manager.py b/libs/socketio/asyncio_redis_manager.py index 21499c26c..9762d3eb9 100644 --- a/libs/socketio/asyncio_redis_manager.py +++ b/libs/socketio/asyncio_redis_manager.py @@ -44,7 +44,7 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover SSL connection, use ``rediss://``. :param channel: The channel name on which the server sends and receives notifications. Must be the same in all the servers. - :param write_only: If set ot ``True``, only initialize to emit events. The + :param write_only: If set to ``True``, only initialize to emit events. The default of ``False`` initializes the class for emitting and receiving.
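The new queue-based ``disconnect`` handling above only comes into play when several server processes share a message queue; a sketch of that setup with a Redis broker (the URL is hypothetical)::

    import socketio

    # every server process subscribes to the same channel
    mgr = socketio.AsyncRedisManager('redis://localhost:6379/0')
    sio = socketio.AsyncServer(client_manager=mgr, async_mode='aiohttp')

    # a worker process that only emits into the cluster
    external = socketio.AsyncRedisManager('redis://localhost:6379/0',
                                          write_only=True)

With a queue configured, disconnecting a client that lives on another process publishes the ``'disconnect'`` message consumed by ``_handle_disconnect()`` above.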
""" @@ -95,6 +95,7 @@ class AsyncRedisManager(AsyncPubSubManager): # pragma: no cover password=self.password, ssl=self.ssl ) self.ch = (await self.sub.subscribe(self.channel))[0] + retry_sleep = 1 return await self.ch.get() except (aioredis.RedisError, OSError): self._get_logger().error('Cannot receive from redis... ' diff --git a/libs/socketio/asyncio_server.py b/libs/socketio/asyncio_server.py index 251d58180..778abc0dc 100644 --- a/libs/socketio/asyncio_server.py +++ b/libs/socketio/asyncio_server.py @@ -1,7 +1,6 @@ import asyncio import engineio -import six from . import asyncio_manager from . import exceptions @@ -14,65 +13,95 @@ class AsyncServer(server.Server): This class implements a fully compliant Socket.IO web server with support for websocket and long-polling transports, compatible with the asyncio - framework on Python 3.5 or newer. + framework. :param client_manager: The client manager instance that will manage the client list. When this is omitted, the client list is stored in an in-memory structure, so the use of multiple connected servers is not possible. :param logger: To enable logging set to ``True`` or pass a logger object to - use. To disable logging set to ``False``. + use. To disable logging set to ``False``. Note that fatal + errors are logged even when ``logger`` is ``False``. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library versions. - :param async_handlers: If set to ``True``, event handlers are executed in - separate threads. To run handlers synchronously, - set to ``False``. The default is ``True``. + :param async_handlers: If set to ``True``, event handlers for a client are + executed in separate threads. To run handlers for a + client synchronously, set to ``False``. The default + is ``True``. + :param always_connect: When set to ``False``, new connections are + provisory until the connect handler returns + something other than ``False``, at which point they + are accepted. When set to ``True``, connections are + immediately accepted, and then if the connect + handler returns ``False`` a disconnect is issued. + Set to ``True`` if you need to emit events from the + connect handler and your client is confused when it + receives events before the connection acceptance. + In any other case use the default of ``False``. :param kwargs: Connection parameters for the underlying Engine.IO server. The Engine.IO configuration supports the following settings: :param async_mode: The asynchronous model to use. See the Deployment section in the documentation for a description of the - available options. Valid async modes are "aiohttp". If - this argument is not given, an async mode is chosen - based on the installed packages. + available options. Valid async modes are "threading", + "eventlet", "gevent" and "gevent_uwsgi". If this + argument is not given, "eventlet" is tried first, then + "gevent_uwsgi", then "gevent", and finally "threading". + The first async mode that has all its dependencies + installed is then one that is chosen. + :param ping_interval: The interval in seconds at which the server pings + the client. The default is 25 seconds. For advanced + control, a two element tuple can be given, where + the first number is the ping interval and the second + is a grace period added by the server. :param ping_timeout: The time in seconds that the client waits for the - server to respond before disconnecting. 
- :param ping_interval: The interval in seconds at which the client pings - the server. + server to respond before disconnecting. The default + is 5 seconds. :param max_http_buffer_size: The maximum size of a message when using the - polling transport. - :param allow_upgrades: Whether to allow transport upgrades or not. + polling transport. The default is 1,000,000 + bytes. + :param allow_upgrades: Whether to allow transport upgrades or not. The + default is ``True``. :param http_compression: Whether to compress packets when using the - polling transport. + polling transport. The default is ``True``. :param compression_threshold: Only compress messages when their byte size - is greater than this value. - :param cookie: Name of the HTTP cookie that contains the client session - id. If set to ``None``, a cookie is not sent to the client. + is greater than this value. The default is + 1024 bytes. + :param cookie: If set to a string, it is the name of the HTTP cookie the + server sends back to the client containing the client + session id. If set to a dictionary, the ``'name'`` key + contains the cookie name and other keys define cookie + attributes, where the value of each attribute can be a + string, a callable with no arguments, or a boolean. If set + to ``None`` (the default), a cookie is not sent to the + client. :param cors_allowed_origins: Origin or list of origins that are allowed to connect to this server. Only the same origin is allowed by default. Set this argument to ``'*'`` to allow all origins, or to ``[]`` to disable CORS handling. :param cors_credentials: Whether credentials (cookies, authentication) are - allowed in requests to this server. + allowed in requests to this server. The default is + ``True``. :param monitor_clients: If set to ``True``, a background task will ensure inactive clients are closed. Set to ``False`` to disable the monitoring task (not recommended). The default is ``True``. :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass a logger object to use. To disable logging set to - ``False``. + ``False``. The default is ``False``. Note that + fatal errors are logged even when + ``engineio_logger`` is ``False``. """ def __init__(self, client_manager=None, logger=False, json=None, async_handlers=True, **kwargs): if client_manager is None: client_manager = asyncio_manager.AsyncManager() super().__init__(client_manager=client_manager, logger=logger, - binary=False, json=json, - async_handlers=async_handlers, **kwargs) + json=json, async_handlers=async_handlers, **kwargs) def is_asyncio_based(self): return True @@ -89,8 +118,9 @@ class AsyncServer(server.Server): ``'connect'``, ``'message'`` and ``'disconnect'`` are reserved and should not be used. :param data: The data to send to the client or clients. Data can be of - type ``str``, ``bytes``, ``list`` or ``dict``. If a - ``list`` or ``dict``, the data will be serialized as JSON. + type ``str``, ``bytes``, ``list`` or ``dict``. To send + multiple arguments, use a tuple where each element is of + one of the types indicated above. :param to: The recipient of the message. This can be set to the session ID of a client to address only that client, or to any custom room created by the application to address all @@ -116,7 +146,13 @@ class AsyncServer(server.Server): to always leave this parameter with its default value of ``False``. - Note: this method is a coroutine. + Note: this method is not designed to be used concurrently.
If multiple + tasks are emitting at the same time to the same client connection, then + messages composed of multiple packets may end up being sent in an + incorrect sequence. Use standard concurrency solutions (such as a Lock + object) to prevent this situation. + + Note 2: this method is a coroutine. """ namespace = namespace or '/' room = to or room @@ -134,8 +170,9 @@ :func:`emit` to issue custom event names. :param data: The data to send to the client or clients. Data can be of - type ``str``, ``bytes``, ``list`` or ``dict``. If a - ``list`` or ``dict``, the data will be serialized as JSON. + type ``str``, ``bytes``, ``list`` or ``dict``. To send + multiple arguments, use a tuple where each element is of + one of the types indicated above. :param to: The recipient of the message. This can be set to the session ID of a client to address only that client, or to any custom room created by the application to address all @@ -175,8 +212,9 @@ ``'connect'``, ``'message'`` and ``'disconnect'`` are reserved and should not be used. :param data: The data to send to the client or clients. Data can be of - type ``str``, ``bytes``, ``list`` or ``dict``. If a - ``list`` or ``dict``, the data will be serialized as JSON. + type ``str``, ``bytes``, ``list`` or ``dict``. To send + multiple arguments, use a tuple where each element is of + one of the types indicated above. :param to: The session ID of the recipient client. :param sid: Alias for the ``to`` parameter. :param namespace: The Socket.IO namespace for the event. If this @@ -192,7 +230,17 @@ single server process is used. It is recommended to always leave this parameter with its default value of ``False``. + + Note: this method is not designed to be used concurrently. If multiple + tasks are emitting at the same time to the same client connection, then + messages composed of multiple packets may end up being sent in an + incorrect sequence. Use standard concurrency solutions (such as a Lock + object) to prevent this situation. + + Note 2: this method is a coroutine. """ + if to is None and sid is None: + raise ValueError('Cannot use call() to broadcast.') if not self.async_handlers: raise RuntimeError( 'Cannot use call() when async_handlers is False.') @@ -208,7 +256,7 @@ try: await asyncio.wait_for(callback_event.wait(), timeout) except asyncio.TimeoutError: - six.raise_from(exceptions.TimeoutError(), None) + raise exceptions.TimeoutError() from None return callback_args[0] if len(callback_args[0]) > 1 \ else callback_args[0][0] if len(callback_args[0]) == 1 \ else None @@ -240,7 +288,8 @@ the user session, use the ``session`` context manager instead. """ namespace = namespace or '/' - eio_session = await self.eio.get_session(sid) + eio_sid = self.manager.eio_sid_from_sid(sid, namespace) + eio_session = await self.eio.get_session(eio_sid) return eio_session.setdefault(namespace, {}) async def save_session(self, sid, session, namespace=None): @@ -252,7 +301,8 @@ the default namespace is used.
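The session API itself is unchanged by the ``eio_sid`` mapping above; a brief sketch of typical usage on an ``AsyncServer`` instance ``sio`` (handler and key names are hypothetical)::

    @sio.event
    async def login(sid, data):
        # session() wraps get_session()/save_session() in a context manager
        async with sio.session(sid) as session:
            session['username'] = data.get('username')

    @sio.event
    async def whoami(sid):
        session = await sio.get_session(sid)
        await sio.emit('whoami', session.get('username'), to=sid)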
""" namespace = namespace or '/' - eio_session = await self.eio.get_session(sid) + eio_sid = self.manager.eio_sid_from_sid(sid, namespace) + eio_session = await self.eio.get_session(eio_sid) eio_session[namespace] = session def session(self, sid, namespace=None): @@ -295,25 +345,32 @@ class AsyncServer(server.Server): return _session_context_manager(self, sid, namespace) - async def disconnect(self, sid, namespace=None): + async def disconnect(self, sid, namespace=None, ignore_queue=False): """Disconnect a client. :param sid: Session ID of the client. :param namespace: The Socket.IO namespace to disconnect. If this argument is omitted the default namespace is used. + :param ignore_queue: Only used when a message queue is configured. If + set to ``True``, the disconnect is processed + locally, without broadcasting on the queue. It is + recommended to always leave this parameter with + its default value of ``False``. Note: this method is a coroutine. """ namespace = namespace or '/' - if self.manager.is_connected(sid, namespace=namespace): + if ignore_queue: + delete_it = self.manager.is_connected(sid, namespace) + else: + delete_it = await self.manager.can_disconnect(sid, namespace) + if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) - self.manager.pre_disconnect(sid, namespace=namespace) - await self._send_packet(sid, packet.Packet(packet.DISCONNECT, - namespace=namespace)) + eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) + await self._send_packet(eio_sid, packet.Packet( + packet.DISCONNECT, namespace=namespace)) await self._trigger_event('disconnect', namespace, sid) self.manager.disconnect(sid, namespace=namespace) - if namespace == '/': - await self.eio.disconnect(sid) async def handle_request(self, *args, **kwargs): """Handle an HTTP request from the client. 
@@ -360,34 +417,41 @@ class AsyncServer(server.Server): # as a single argument if isinstance(data, tuple): data = list(data) - else: + elif data is not None: data = [data] + else: + data = [] await self._send_packet(sid, packet.Packet( - packet.EVENT, namespace=namespace, data=[event] + data, id=id, - binary=None)) + packet.EVENT, namespace=namespace, data=[event] + data, id=id)) - async def _send_packet(self, sid, pkt): + async def _send_packet(self, eio_sid, pkt): """Send a Socket.IO packet to a client.""" encoded_packet = pkt.encode() if isinstance(encoded_packet, list): - binary = False for ep in encoded_packet: - await self.eio.send(sid, ep, binary=binary) - binary = True + await self.eio.send(eio_sid, ep) else: - await self.eio.send(sid, encoded_packet, binary=False) + await self.eio.send(eio_sid, encoded_packet) - async def _handle_connect(self, sid, namespace): + async def _handle_connect(self, eio_sid, namespace, data): """Handle a client connection request.""" namespace = namespace or '/' - self.manager.connect(sid, namespace) + sid = self.manager.connect(eio_sid, namespace) if self.always_connect: - await self._send_packet(sid, packet.Packet(packet.CONNECT, - namespace=namespace)) - fail_reason = None + await self._send_packet(eio_sid, packet.Packet( + packet.CONNECT, {'sid': sid}, namespace=namespace)) + fail_reason = exceptions.ConnectionRefusedError().error_args try: - success = await self._trigger_event('connect', namespace, sid, - self.environ[sid]) + if data: + success = await self._trigger_event( + 'connect', namespace, sid, self.environ[eio_sid], data) + else: + try: + success = await self._trigger_event( + 'connect', namespace, sid, self.environ[eio_sid]) + except TypeError: + success = await self._trigger_event( + 'connect', namespace, sid, self.environ[eio_sid], None) except exceptions.ConnectionRefusedError as exc: fail_reason = exc.error_args success = False @@ -395,36 +459,31 @@ class AsyncServer(server.Server): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - await self._send_packet(sid, packet.Packet( + await self._send_packet(eio_sid, packet.Packet( packet.DISCONNECT, data=fail_reason, namespace=namespace)) + else: + await self._send_packet(eio_sid, packet.Packet( + packet.CONNECT_ERROR, data=fail_reason, + namespace=namespace)) self.manager.disconnect(sid, namespace) - if not self.always_connect: - await self._send_packet(sid, packet.Packet( - packet.ERROR, data=fail_reason, namespace=namespace)) - if sid in self.environ: # pragma: no cover - del self.environ[sid] elif not self.always_connect: - await self._send_packet(sid, packet.Packet(packet.CONNECT, - namespace=namespace)) + await self._send_packet(eio_sid, packet.Packet( + packet.CONNECT, {'sid': sid}, namespace=namespace)) - async def _handle_disconnect(self, sid, namespace): + async def _handle_disconnect(self, eio_sid, namespace): """Handle a client disconnect.""" namespace = namespace or '/' - if namespace == '/': - namespace_list = list(self.manager.get_namespaces()) - else: - namespace_list = [namespace] - for n in namespace_list: - if n != '/' and self.manager.is_connected(sid, n): - await self._trigger_event('disconnect', n, sid) - self.manager.disconnect(sid, n) - if namespace == '/' and self.manager.is_connected(sid, namespace): - await self._trigger_event('disconnect', '/', sid) - self.manager.disconnect(sid, '/') + sid = self.manager.sid_from_eio_sid(eio_sid, namespace) + if not self.manager.is_connected(sid, namespace): # pragma: no cover + return + 
self.manager.pre_disconnect(sid, namespace=namespace) + await self._trigger_event('disconnect', namespace, sid) + self.manager.disconnect(sid, namespace) - async def _handle_event(self, sid, namespace, id, data): + async def _handle_event(self, eio_sid, namespace, id, data): """Handle an incoming client event.""" namespace = namespace or '/' + sid = self.manager.sid_from_eio_sid(eio_sid, namespace) self.logger.info('received event "%s" from %s [%s]', data[0], sid, namespace) if not self.manager.is_connected(sid, namespace): @@ -433,11 +492,13 @@ class AsyncServer(server.Server): return if self.async_handlers: self.start_background_task(self._handle_event_internal, self, sid, - data, namespace, id) + eio_sid, data, namespace, id) else: - await self._handle_event_internal(self, sid, data, namespace, id) + await self._handle_event_internal(self, sid, eio_sid, data, + namespace, id) - async def _handle_event_internal(self, server, sid, data, namespace, id): + async def _handle_event_internal(self, server, sid, eio_sid, data, + namespace, id): r = await server._trigger_event(data[0], namespace, sid, *data[1:]) if id is not None: # send ACK packet with the response returned by the handler @@ -448,16 +509,15 @@ class AsyncServer(server.Server): data = list(r) else: data = [r] - await server._send_packet(sid, packet.Packet(packet.ACK, - namespace=namespace, - id=id, data=data, - binary=None)) + await server._send_packet(eio_sid, packet.Packet( + packet.ACK, namespace=namespace, id=id, data=data)) - async def _handle_ack(self, sid, namespace, id, data): + async def _handle_ack(self, eio_sid, namespace, id, data): """Handle ACK packets from the client.""" namespace = namespace or '/' + sid = self.manager.sid_from_eio_sid(eio_sid, namespace) self.logger.info('received ack from %s [%s]', sid, namespace) - await self.manager.trigger_callback(sid, namespace, id, data) + await self.manager.trigger_callback(sid, id, data) async def _trigger_event(self, event, namespace, *args): """Invoke an application event handler.""" @@ -478,49 +538,51 @@ class AsyncServer(server.Server): return await self.namespace_handlers[namespace].trigger_event( event, *args) - async def _handle_eio_connect(self, sid, environ): + async def _handle_eio_connect(self, eio_sid, environ): """Handle the Engine.IO connection event.""" if not self.manager_initialized: self.manager_initialized = True self.manager.initialize() - self.environ[sid] = environ - return await self._handle_connect(sid, '/') + self.environ[eio_sid] = environ - async def _handle_eio_message(self, sid, data): + async def _handle_eio_message(self, eio_sid, data): """Dispatch Engine.IO messages.""" - if sid in self._binary_packet: - pkt = self._binary_packet[sid] + if eio_sid in self._binary_packet: + pkt = self._binary_packet[eio_sid] if pkt.add_attachment(data): - del self._binary_packet[sid] + del self._binary_packet[eio_sid] if pkt.packet_type == packet.BINARY_EVENT: - await self._handle_event(sid, pkt.namespace, pkt.id, + await self._handle_event(eio_sid, pkt.namespace, pkt.id, pkt.data) else: - await self._handle_ack(sid, pkt.namespace, pkt.id, + await self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: pkt = packet.Packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: - await self._handle_connect(sid, pkt.namespace) + await self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: - await self._handle_disconnect(sid, pkt.namespace) + await self._handle_disconnect(eio_sid, pkt.namespace) elif 
pkt.packet_type == packet.EVENT: - await self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) + await self._handle_event(eio_sid, pkt.namespace, pkt.id, + pkt.data) elif pkt.packet_type == packet.ACK: - await self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) + await self._handle_ack(eio_sid, pkt.namespace, pkt.id, + pkt.data) elif pkt.packet_type == packet.BINARY_EVENT or \ pkt.packet_type == packet.BINARY_ACK: - self._binary_packet[sid] = pkt - elif pkt.packet_type == packet.ERROR: - raise ValueError('Unexpected ERROR packet.') + self._binary_packet[eio_sid] = pkt + elif pkt.packet_type == packet.CONNECT_ERROR: + raise ValueError('Unexpected CONNECT_ERROR packet.') else: raise ValueError('Unknown packet type.') - async def _handle_eio_disconnect(self, sid): + async def _handle_eio_disconnect(self, eio_sid): """Handle Engine.IO disconnect event.""" - await self._handle_disconnect(sid, '/') - if sid in self.environ: - del self.environ[sid] + for n in list(self.manager.get_namespaces()).copy(): + await self._handle_disconnect(eio_sid, n) + if eio_sid in self.environ: + del self.environ[eio_sid] def _engineio_server_class(self): return engineio.AsyncServer diff --git a/libs/socketio/base_manager.py b/libs/socketio/base_manager.py index 3cccb8569..795bb93a9 100644 --- a/libs/socketio/base_manager.py +++ b/libs/socketio/base_manager.py @@ -1,7 +1,7 @@ import itertools import logging -import six +from bidict import bidict default_logger = logging.getLogger('socketio') @@ -18,7 +18,8 @@ class BaseManager(object): def __init__(self): self.logger = None self.server = None - self.rooms = {} + self.rooms = {} # self.rooms[namespace][room][sio_sid] = eio_sid + self.eio_to_sid = {} self.callbacks = {} self.pending_disconnect = {} @@ -33,17 +34,19 @@ class BaseManager(object): def get_namespaces(self): """Return an iterable with the active namespace names.""" - return six.iterkeys(self.rooms) + return self.rooms.keys() def get_participants(self, namespace, room): """Return an iterable with the active participants in a room.""" - for sid, active in six.iteritems(self.rooms[namespace][room].copy()): - yield sid + for sid, eio_sid in self.rooms[namespace][room]._fwdm.copy().items(): + yield sid, eio_sid - def connect(self, sid, namespace): + def connect(self, eio_sid, namespace): """Register a client connection to a namespace.""" - self.enter_room(sid, namespace, None) - self.enter_room(sid, namespace, sid) + sid = self.server.eio.generate_id() + self.enter_room(sid, namespace, None, eio_sid=eio_sid) + self.enter_room(sid, namespace, sid, eio_sid=eio_sid) + return sid def is_connected(self, sid, namespace): if namespace in self.pending_disconnect and \ @@ -51,10 +54,23 @@ class BaseManager(object): # the client is in the process of being disconnected return False try: - return self.rooms[namespace][None][sid] + return self.rooms[namespace][None][sid] is not None except KeyError: pass + def sid_from_eio_sid(self, eio_sid, namespace): + try: + return self.rooms[namespace][None]._invm[eio_sid] + except KeyError: + pass + + def eio_sid_from_sid(self, sid, namespace): + if namespace in self.rooms: + return self.rooms[namespace][None].get(sid) + + def can_disconnect(self, sid, namespace): + return self.is_connected(sid, namespace) + def pre_disconnect(self, sid, namespace): """Put the client in the to-be-disconnected list. 
@@ -65,34 +81,35 @@ class BaseManager(object): if namespace not in self.pending_disconnect: self.pending_disconnect[namespace] = [] self.pending_disconnect[namespace].append(sid) + return self.rooms[namespace][None].get(sid) def disconnect(self, sid, namespace): """Register a client disconnect from a namespace.""" if namespace not in self.rooms: return rooms = [] - for room_name, room in six.iteritems(self.rooms[namespace].copy()): + for room_name, room in self.rooms[namespace].copy().items(): if sid in room: rooms.append(room_name) for room in rooms: self.leave_room(sid, namespace, room) - if sid in self.callbacks and namespace in self.callbacks[sid]: - del self.callbacks[sid][namespace] - if len(self.callbacks[sid]) == 0: - del self.callbacks[sid] + if sid in self.callbacks: + del self.callbacks[sid] if namespace in self.pending_disconnect and \ sid in self.pending_disconnect[namespace]: self.pending_disconnect[namespace].remove(sid) if len(self.pending_disconnect[namespace]) == 0: del self.pending_disconnect[namespace] - def enter_room(self, sid, namespace, room): + def enter_room(self, sid, namespace, room, eio_sid=None): """Add a client to a room.""" if namespace not in self.rooms: self.rooms[namespace] = {} if room not in self.rooms[namespace]: - self.rooms[namespace][room] = {} - self.rooms[namespace][room][sid] = True + self.rooms[namespace][room] = bidict() + if eio_sid is None: + eio_sid = self.rooms[namespace][None][sid] + self.rooms[namespace][room][sid] = eio_sid def leave_room(self, sid, namespace, room): """Remove a client from a room.""" @@ -108,7 +125,7 @@ class BaseManager(object): def close_room(self, room, namespace): """Remove all participants from a room.""" try: - for sid in self.get_participants(namespace, room): + for sid, _ in self.get_participants(namespace, room): self.leave_room(sid, namespace, room) except KeyError: pass @@ -117,8 +134,8 @@ class BaseManager(object): """Return the rooms a client is in.""" r = [] try: - for room_name, room in six.iteritems(self.rooms[namespace]): - if room_name is not None and sid in room and room[sid]: + for room_name, room in self.rooms[namespace].items(): + if room_name is not None and sid in room: r.append(room_name) except KeyError: pass @@ -132,36 +149,33 @@ class BaseManager(object): return if not isinstance(skip_sid, list): skip_sid = [skip_sid] - for sid in self.get_participants(namespace, room): + for sid, eio_sid in self.get_participants(namespace, room): if sid not in skip_sid: if callback is not None: - id = self._generate_ack_id(sid, namespace, callback) + id = self._generate_ack_id(sid, callback) else: id = None - self.server._emit_internal(sid, event, data, namespace, id) + self.server._emit_internal(eio_sid, event, data, namespace, id) - def trigger_callback(self, sid, namespace, id, data): + def trigger_callback(self, sid, id, data): """Invoke an application callback.""" callback = None try: - callback = self.callbacks[sid][namespace][id] + callback = self.callbacks[sid][id] except KeyError: # if we get an unknown callback we just ignore it self._get_logger().warning('Unknown callback received, ignoring.') else: - del self.callbacks[sid][namespace][id] + del self.callbacks[sid][id] if callback is not None: callback(*data) - def _generate_ack_id(self, sid, namespace, callback): + def _generate_ack_id(self, sid, callback): """Generate a unique identifier for an ACK packet.""" - namespace = namespace or '/' if sid not in self.callbacks: - self.callbacks[sid] = {} - if namespace not in self.callbacks[sid]: - 
self.callbacks[sid][namespace] = {0: itertools.count(1)} - id = six.next(self.callbacks[sid][namespace][0]) - self.callbacks[sid][namespace][id] = callback + self.callbacks[sid] = {0: itertools.count(1)} + id = next(self.callbacks[sid][0]) + self.callbacks[sid][id] = callback return id def _get_logger(self): diff --git a/libs/socketio/client.py b/libs/socketio/client.py index e917d634d..80c2e3182 100644 --- a/libs/socketio/client.py +++ b/libs/socketio/client.py @@ -2,9 +2,9 @@ import itertools import logging import random import signal +import threading import engineio -import six from . import exceptions from . import namespace @@ -22,10 +22,14 @@ def signal_handler(sig, frame): # pragma: no cover """ for client in reconnecting_clients[:]: client._reconnect_abort.set() - return original_signal_handler(sig, frame) + if callable(original_signal_handler): + return original_signal_handler(sig, frame) + else: # pragma: no cover + # Handle case where no original SIGINT handler was present. + return signal.default_int_handler(sig, frame) -original_signal_handler = signal.signal(signal.SIGINT, signal_handler) +original_signal_handler = None class Client(object): @@ -51,13 +55,8 @@ class Client(object): adjusted by +/- 50%. :param logger: To enable logging set to ``True`` or pass a logger object to use. To disable logging set to ``False``. The default is - ``False``. - :param binary: ``True`` to support binary payloads, ``False`` to treat all - payloads as text. On Python 2, if this is set to ``True``, - ``unicode`` values are treated as text, and ``str`` and - ``bytes`` values are treated as binary. This option has no - effect on Python 3, where text and binary payloads are - always automatically discovered. + ``False``. Note that fatal errors are logged even when + ``logger`` is ``False``. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library @@ -67,24 +66,33 @@ class Client(object): :param request_timeout: A timeout in seconds for requests. The default is 5 seconds. + :param http_session: an initialized ``requests.Session`` object to be used + when sending requests to the server. Use it if you + need to add special client options such as proxy + servers, SSL certificates, etc. :param ssl_verify: ``True`` to verify SSL certificates, or ``False`` to skip SSL certificate verification, allowing connections to servers with self signed certificates. The default is ``True``. :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass a logger object to use. To disable logging set to - ``False``. The default is ``False``. + ``False``. The default is ``False``. Note that + fatal errors are logged even when + ``engineio_logger`` is ``False``. 
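A minimal sketch of constructing the client with the reconnection and ``http_session`` options described above; the proxy URL is hypothetical::

    import requests
    import socketio

    session = requests.Session()
    session.proxies = {'https': 'https://proxy.example.com:8080'}

    sio = socketio.Client(reconnection=True, reconnection_attempts=5,
                          reconnection_delay=1, reconnection_delay_max=5,
                          http_session=session, logger=True)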
""" def __init__(self, reconnection=True, reconnection_attempts=0, reconnection_delay=1, reconnection_delay_max=5, - randomization_factor=0.5, logger=False, binary=False, - json=None, **kwargs): + randomization_factor=0.5, logger=False, json=None, **kwargs): + global original_signal_handler + if original_signal_handler is None and \ + threading.current_thread() == threading.main_thread(): + original_signal_handler = signal.signal(signal.SIGINT, + signal_handler) self.reconnection = reconnection self.reconnection_attempts = reconnection_attempts self.reconnection_delay = reconnection_delay self.reconnection_delay_max = reconnection_delay_max self.randomization_factor = randomization_factor - self.binary = binary engineio_options = kwargs engineio_logger = engineio_options.pop('engineio_logger', None) @@ -103,8 +111,7 @@ class Client(object): self.logger = logger else: self.logger = default_logger - if not logging.root.handlers and \ - self.logger.level == logging.NOTSET: + if self.logger.level == logging.NOTSET: if logger: self.logger.setLevel(logging.INFO) else: @@ -114,18 +121,19 @@ class Client(object): self.connection_url = None self.connection_headers = None self.connection_transports = None - self.connection_namespaces = None + self.connection_namespaces = [] self.socketio_path = None self.sid = None self.connected = False - self.namespaces = [] + self.namespaces = {} self.handlers = {} self.namespace_handlers = {} self.callbacks = {} self._binary_packet = None + self._connect_event = None self._reconnect_task = None - self._reconnect_abort = self.eio.create_event() + self._reconnect_abort = None def is_asyncio_based(self): return False @@ -226,7 +234,8 @@ class Client(object): namespace_handler def connect(self, url, headers={}, transports=None, - namespaces=None, socketio_path='socket.io'): + namespaces=None, socketio_path='socket.io', wait=True, + wait_timeout=1): """Connect to a Socket.IO server. :param url: The URL of the Socket.IO server. It can include custom @@ -237,19 +246,30 @@ class Client(object): are ``'polling'`` and ``'websocket'``. If not given, the polling transport is connected first, then an upgrade to websocket is attempted. - :param namespaces: The list of custom namespaces to connect, in - addition to the default namespace. If not given, - the namespace list is obtained from the registered - event handlers. + :param namespaces: The namespaces to connect as a string or list of + strings. If not given, the namespaces that have + registered event handlers are connected. :param socketio_path: The endpoint where the Socket.IO server is installed. The default value is appropriate for most cases. + :param wait: if set to ``True`` (the default) the call only returns + when all the namespaces are connected. If set to + ``False``, the call returns as soon as the Engine.IO + transport is connected, and the namespaces will connect + in the background. + :param wait_timeout: How long the client should wait for the + connection. The default is 1 second. This + argument is only considered when ``wait`` is set + to ``True``. 
         Example usage::

             sio = socketio.Client()
             sio.connect('http://localhost:5000')
         """
+        if self.connected:
+            raise exceptions.ConnectionError('Already connected')
+
         self.connection_url = url
         self.connection_headers = headers
         self.connection_transports = transports
@@ -257,17 +277,37 @@
         self.socketio_path = socketio_path
         if namespaces is None:
-            namespaces = set(self.handlers.keys()).union(
-                set(self.namespace_handlers.keys()))
-        elif isinstance(namespaces, six.string_types):
+            namespaces = list(set(self.handlers.keys()).union(
+                set(self.namespace_handlers.keys())))
+            if len(namespaces) == 0:
+                namespaces = ['/']
+        elif isinstance(namespaces, str):
             namespaces = [namespaces]
-        self.connection_namespaces = namespaces
-        self.namespaces = [n for n in namespaces if n != '/']
+        self.connection_namespaces = namespaces
+        self.namespaces = {}
+        if self._connect_event is None:
+            self._connect_event = self.eio.create_event()
+        else:
+            self._connect_event.clear()
         try:
             self.eio.connect(url, headers=headers,
                              transports=transports,
                              engineio_path=socketio_path)
         except engineio.exceptions.ConnectionError as exc:
-            six.raise_from(exceptions.ConnectionError(exc.args[0]), None)
+            self._trigger_event(
+                'connect_error', '/',
+                exc.args[1] if len(exc.args) > 1 else exc.args[0])
+            raise exceptions.ConnectionError(exc.args[0]) from None
+
+        if wait:
+            while self._connect_event.wait(timeout=wait_timeout):
+                self._connect_event.clear()
+                if set(self.namespaces) == set(self.connection_namespaces):
+                    break
+            if set(self.namespaces) != set(self.connection_namespaces):
+                self.disconnect()
+                raise exceptions.ConnectionError(
+                    'One or more namespaces failed to connect')
+
         self.connected = True

     def wait(self):
@@ -291,20 +331,26 @@
         :param event: The event name. It can be any string. The event names
                       ``'connect'``, ``'message'`` and ``'disconnect'`` are
                       reserved and should not be used.
-        :param data: The data to send to the client or clients. Data can be of
-                     type ``str``, ``bytes``, ``list`` or ``dict``. If a
-                     ``list`` or ``dict``, the data will be serialized as JSON.
+        :param data: The data to send to the server. Data can be of
+                     type ``str``, ``bytes``, ``list`` or ``dict``. To send
+                     multiple arguments, use a tuple where each element is of
+                     one of the types indicated above.
         :param namespace: The Socket.IO namespace for the event. If this
                           argument is omitted the event is emitted to the
                           default namespace.
         :param callback: If given, this function will be called to acknowledge
-                         the the client has received the message. The arguments
+                         that the server has received the message. The arguments
                          that will be passed to the function are those provided
-                         by the client. Callback functions can only be used
-                         when addressing an individual client.
+                         by the server.
+
+        Note: this method is not thread safe. If multiple threads are emitting
+        at the same time on the same client connection, messages composed of
+        multiple packets may end up being sent in an incorrect sequence. Use
+        standard concurrency solutions (such as a Lock object) to prevent this
+        situation.
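# Aside: a minimal sketch of the Lock-based workaround suggested in the
# note above; the wrapper class is invented for illustration and is not
# part of this patch.
import threading

import socketio

class SerializedEmitter:
    """Serialize emits issued from multiple threads over one client."""

    def __init__(self, sio):
        self.sio = sio
        self._lock = threading.Lock()

    def emit(self, event, data=None, namespace=None):
        # Holding the lock for the full emit keeps multi-packet (binary)
        # messages from interleaving with other threads' packets.
        with self._lock:
            self.sio.emit(event, data=data, namespace=namespace)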
""" namespace = namespace or '/' - if namespace != '/' and namespace not in self.namespaces: + if namespace not in self.namespaces: raise exceptions.BadNamespaceError( namespace + ' is not a connected namespace.') self.logger.info('Emitting event "%s" [%s]', event, namespace) @@ -312,10 +358,6 @@ class Client(object): id = self._generate_ack_id(namespace, callback) else: id = None - if six.PY2 and not self.binary: - binary = False # pragma: nocover - else: - binary = None # tuples are expanded to multiple arguments, everything else is sent # as a single argument if isinstance(data, tuple): @@ -325,8 +367,7 @@ class Client(object): else: data = [] self._send_packet(packet.Packet(packet.EVENT, namespace=namespace, - data=[event] + data, id=id, - binary=binary)) + data=[event] + data, id=id)) def send(self, data, namespace=None, callback=None): """Send a message to one or more connected clients. @@ -334,17 +375,17 @@ class Client(object): This function emits an event with the name ``'message'``. Use :func:`emit` to issue custom event names. - :param data: The data to send to the client or clients. Data can be of - type ``str``, ``bytes``, ``list`` or ``dict``. If a - ``list`` or ``dict``, the data will be serialized as JSON. + :param data: The data to send to the server. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. To send + multiple arguments, use a tuple where each element is of + one of the types indicated above. :param namespace: The Socket.IO namespace for the event. If this argument is omitted the event is emitted to the default namespace. :param callback: If given, this function will be called to acknowledge - the the client has received the message. The arguments + the the server has received the message. The arguments that will be passed to the function are those provided - by the client. Callback functions can only be used - when addressing an individual client. + by the server. """ self.emit('message', data=data, namespace=namespace, callback=callback) @@ -355,15 +396,22 @@ class Client(object): :param event: The event name. It can be any string. The event names ``'connect'``, ``'message'`` and ``'disconnect'`` are reserved and should not be used. - :param data: The data to send to the client or clients. Data can be of - type ``str``, ``bytes``, ``list`` or ``dict``. If a - ``list`` or ``dict``, the data will be serialized as JSON. + :param data: The data to send to the server. Data can be of + type ``str``, ``bytes``, ``list`` or ``dict``. To send + multiple arguments, use a tuple where each element is of + one of the types indicated above. :param namespace: The Socket.IO namespace for the event. If this argument is omitted the event is emitted to the default namespace. :param timeout: The waiting timeout. If the timeout is reached before the client acknowledges the event, then a ``TimeoutError`` exception is raised. + + Note: this method is not thread safe. If multiple threads are emitting + at the same time on the same client connection, messages composed of + multiple packets may end up being sent in an incorrect sequence. Use + standard concurrency solutions (such as a Lock object) to prevent this + situation. 
""" callback_event = self.eio.create_event() callback_args = [] @@ -386,11 +434,22 @@ class Client(object): # later in _handle_eio_disconnect we invoke the disconnect handler for n in self.namespaces: self._send_packet(packet.Packet(packet.DISCONNECT, namespace=n)) - self._send_packet(packet.Packet( - packet.DISCONNECT, namespace='/')) - self.connected = False self.eio.disconnect(abort=True) + def get_sid(self, namespace=None): + """Return the ``sid`` associated with a connection. + + :param namespace: The Socket.IO namespace. If this argument is omitted + the handler is associated with the default + namespace. Note that unlike previous versions, the + current version of the Socket.IO protocol uses + different ``sid`` values per namespace. + + This method returns the ``sid`` for the requested namespace as a + string. + """ + return self.namespaces.get(namespace or '/') + def transport(self): """Return the name of the transport used by the client. @@ -430,45 +489,38 @@ class Client(object): """Send a Socket.IO packet to the server.""" encoded_packet = pkt.encode() if isinstance(encoded_packet, list): - binary = False for ep in encoded_packet: - self.eio.send(ep, binary=binary) - binary = True + self.eio.send(ep) else: - self.eio.send(encoded_packet, binary=False) + self.eio.send(encoded_packet) def _generate_ack_id(self, namespace, callback): """Generate a unique identifier for an ACK packet.""" namespace = namespace or '/' if namespace not in self.callbacks: self.callbacks[namespace] = {0: itertools.count(1)} - id = six.next(self.callbacks[namespace][0]) + id = next(self.callbacks[namespace][0]) self.callbacks[namespace][id] = callback return id - def _handle_connect(self, namespace): + def _handle_connect(self, namespace, data): namespace = namespace or '/' - self.logger.info('Namespace {} is connected'.format(namespace)) - self._trigger_event('connect', namespace=namespace) - if namespace == '/': - for n in self.namespaces: - self._send_packet(packet.Packet(packet.CONNECT, namespace=n)) - elif namespace not in self.namespaces: - self.namespaces.append(namespace) + if namespace not in self.namespaces: + self.logger.info('Namespace {} is connected'.format(namespace)) + self.namespaces[namespace] = (data or {}).get('sid', self.sid) + self._trigger_event('connect', namespace=namespace) + self._connect_event.set() def _handle_disconnect(self, namespace): if not self.connected: return namespace = namespace or '/' - if namespace == '/': - for n in self.namespaces: - self._trigger_event('disconnect', namespace=n) - self.namespaces = [] self._trigger_event('disconnect', namespace=namespace) if namespace in self.namespaces: - self.namespaces.remove(namespace) - if namespace == '/': + del self.namespaces[namespace] + if not self.namespaces: self.connected = False + self.eio.disconnect(abort=True) def _handle_event(self, namespace, id, data): namespace = namespace or '/' @@ -483,12 +535,8 @@ class Client(object): data = list(r) else: data = [r] - if six.PY2 and not self.binary: - binary = False # pragma: nocover - else: - binary = None self._send_packet(packet.Packet(packet.ACK, namespace=namespace, - id=id, data=data, binary=binary)) + id=id, data=data)) def _handle_ack(self, namespace, id, data): namespace = namespace or '/' @@ -513,10 +561,11 @@ class Client(object): elif not isinstance(data, (tuple, list)): data = (data,) self._trigger_event('connect_error', namespace, *data) + self._connect_event.set() if namespace in self.namespaces: - self.namespaces.remove(namespace) + del 
self.namespaces[namespace] if namespace == '/': - self.namespaces = [] + self.namespaces = {} self.connected = False def _trigger_event(self, event, namespace, *args): @@ -531,6 +580,8 @@ class Client(object): event, *args) def _handle_reconnect(self): + if self._reconnect_abort is None: # pragma: no cover + self._reconnect_abort = self.eio.create_event() self._reconnect_abort.clear() reconnecting_clients.append(self) attempt_count = 0 @@ -571,6 +622,8 @@ class Client(object): """Handle the Engine.IO connection event.""" self.logger.info('Engine.IO connection established') self.sid = self.eio.sid + for n in self.connection_namespaces: + self._send_packet(packet.Packet(packet.CONNECT, namespace=n)) def _handle_eio_message(self, data): """Dispatch Engine.IO messages.""" @@ -585,7 +638,7 @@ class Client(object): else: pkt = packet.Packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: - self._handle_connect(pkt.namespace) + self._handle_connect(pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: self._handle_disconnect(pkt.namespace) elif pkt.packet_type == packet.EVENT: @@ -595,7 +648,7 @@ class Client(object): elif pkt.packet_type == packet.BINARY_EVENT or \ pkt.packet_type == packet.BINARY_ACK: self._binary_packet = pkt - elif pkt.packet_type == packet.ERROR: + elif pkt.packet_type == packet.CONNECT_ERROR: self._handle_error(pkt.namespace, pkt.data) else: raise ValueError('Unknown packet type.') @@ -606,8 +659,7 @@ class Client(object): if self.connected: for n in self.namespaces: self._trigger_event('disconnect', namespace=n) - self._trigger_event('disconnect', namespace='/') - self.namespaces = [] + self.namespaces = {} self.connected = False self.callbacks = {} self._binary_packet = None diff --git a/libs/socketio/exceptions.py b/libs/socketio/exceptions.py index 36dddd9fc..d9dae4a5f 100644 --- a/libs/socketio/exceptions.py +++ b/libs/socketio/exceptions.py @@ -15,11 +15,15 @@ class ConnectionRefusedError(ConnectionError): """ def __init__(self, *args): if len(args) == 0: - self.error_args = None - elif len(args) == 1 and not isinstance(args[0], list): - self.error_args = args[0] + self.error_args = {'message': 'Connection rejected by server'} + elif len(args) == 1: + self.error_args = {'message': str(args[0])} else: - self.error_args = args + self.error_args = {'message': str(args[0])} + if len(args) == 2: + self.error_args['data'] = args[1] + else: + self.error_args['data'] = args[1:] class TimeoutError(SocketIOError): diff --git a/libs/socketio/kafka_manager.py b/libs/socketio/kafka_manager.py index 00a2e7f05..b5eb63635 100644 --- a/libs/socketio/kafka_manager.py +++ b/libs/socketio/kafka_manager.py @@ -28,7 +28,7 @@ class KafkaManager(PubSubManager): # pragma: no cover :param channel: The channel name (topic) on which the server sends and receives notifications. Must be the same in all the servers. - :param write_only: If set ot ``True``, only initialize to emit events. The + :param write_only: If set to ``True``, only initialize to emit events. The default of ``False`` initializes the class for emitting and receiving. """ diff --git a/libs/socketio/kombu_manager.py b/libs/socketio/kombu_manager.py index 4eb9ee498..61eebd00e 100644 --- a/libs/socketio/kombu_manager.py +++ b/libs/socketio/kombu_manager.py @@ -31,7 +31,7 @@ class KombuManager(PubSubManager): # pragma: no cover connection URLs. :param channel: The channel name on which the server sends and receives notifications. Must be the same in all the servers. 
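# Aside: a hedged sketch of wiring one of these pub/sub managers into a
# server so several processes share the same channel; the broker URL is a
# placeholder.
import socketio

mgr = socketio.KombuManager('amqp://guest:guest@localhost:5672//',
                            channel='socketio')
sio = socketio.Server(client_manager=mgr)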
- :param write_only: If set ot ``True``, only initialize to emit events. The + :param write_only: If set to ``True``, only initialize to emit events. The default of ``False`` initializes the class for emitting and receiving. :param connection_options: additional keyword arguments to be passed to diff --git a/libs/socketio/namespace.py b/libs/socketio/namespace.py index 418615ff8..97be4ee40 100644 --- a/libs/socketio/namespace.py +++ b/libs/socketio/namespace.py @@ -10,7 +10,7 @@ class BaseNamespace(object): In the most common usage, this method is not overloaded by subclasses, as it performs the routing of events to methods. However, this - method can be overriden if special dispatching rules are needed, or if + method can be overridden if special dispatching rules are needed, or if having a single method that catches all events is desired. """ handler_name = 'on_' + event diff --git a/libs/socketio/packet.py b/libs/socketio/packet.py index 73b469d6d..49f210e53 100644 --- a/libs/socketio/packet.py +++ b/libs/socketio/packet.py @@ -1,11 +1,9 @@ import functools import json as _json -import six - -(CONNECT, DISCONNECT, EVENT, ACK, ERROR, BINARY_EVENT, BINARY_ACK) = \ +(CONNECT, DISCONNECT, EVENT, ACK, CONNECT_ERROR, BINARY_EVENT, BINARY_ACK) = \ (0, 1, 2, 3, 4, 5, 6) -packet_names = ['CONNECT', 'DISCONNECT', 'EVENT', 'ACK', 'ERROR', +packet_names = ['CONNECT', 'DISCONNECT', 'EVENT', 'ACK', 'CONNECT_ERROR', 'BINARY_EVENT', 'BINARY_ACK'] @@ -49,10 +47,10 @@ class Packet(object): of packets where the first is the original packet with placeholders for the binary components and the remaining ones the binary attachments. """ - encoded_packet = six.text_type(self.packet_type) + encoded_packet = str(self.packet_type) if self.packet_type == BINARY_EVENT or self.packet_type == BINARY_ACK: data, attachments = self._deconstruct_binary(self.data) - encoded_packet += six.text_type(len(attachments)) + '-' + encoded_packet += str(len(attachments)) + '-' else: data = self.data attachments = None @@ -64,7 +62,7 @@ class Packet(object): if needs_comma: encoded_packet += ',' needs_comma = False - encoded_packet += six.text_type(self.id) + encoded_packet += str(self.id) if data is not None: if needs_comma: encoded_packet += ',' @@ -139,7 +137,7 @@ class Packet(object): else: return {key: self._reconstruct_binary_internal(value, attachments) - for key, value in six.iteritems(data)} + for key, value in data.items()} else: return data @@ -150,7 +148,7 @@ class Packet(object): return data, attachments def _deconstruct_binary_internal(self, data, attachments): - if isinstance(data, six.binary_type): + if isinstance(data, bytes): attachments.append(data) return {'_placeholder': True, 'num': len(attachments) - 1} elif isinstance(data, list): @@ -158,13 +156,13 @@ class Packet(object): for item in data] elif isinstance(data, dict): return {key: self._deconstruct_binary_internal(value, attachments) - for key, value in six.iteritems(data)} + for key, value in data.items()} else: return data def _data_is_binary(self, data): """Check if the data contains binary components.""" - if isinstance(data, six.binary_type): + if isinstance(data, bytes): return True elif isinstance(data, list): return functools.reduce( @@ -173,7 +171,7 @@ class Packet(object): elif isinstance(data, dict): return functools.reduce( lambda a, b: a or b, [self._data_is_binary(item) - for item in six.itervalues(data)], + for item in data.values()], False) else: return False diff --git a/libs/socketio/pubsub_manager.py b/libs/socketio/pubsub_manager.py 
index 2905b2c32..ff3304ce0 100644 --- a/libs/socketio/pubsub_manager.py +++ b/libs/socketio/pubsub_manager.py @@ -3,7 +3,6 @@ import uuid import json import pickle -import six from .base_manager import BaseManager @@ -58,7 +57,7 @@ class PubSubManager(BaseManager): 'context of a server.') if room is None: raise ValueError('Cannot use callback without a room set.') - id = self._generate_ack_id(room, namespace, callback) + id = self._generate_ack_id(room, callback) callback = (room, namespace, id) else: callback = None @@ -67,6 +66,15 @@ class PubSubManager(BaseManager): 'skip_sid': skip_sid, 'callback': callback, 'host_id': self.host_id}) + def can_disconnect(self, sid, namespace): + if self.is_connected(sid, namespace): + # client is in this server, so we can disconnect directly + return super().can_disconnect(sid, namespace) + else: + # client is in another server, so we post request to the queue + self._publish({'method': 'disconnect', 'sid': sid, + 'namespace': namespace or '/'}) + def close_room(self, room, namespace=None): self._publish({'method': 'close_room', 'room': room, 'namespace': namespace or '/'}) @@ -111,20 +119,24 @@ class PubSubManager(BaseManager): if self.host_id == message.get('host_id'): try: sid = message['sid'] - namespace = message['namespace'] id = message['id'] args = message['args'] except KeyError: return - self.trigger_callback(sid, namespace, id, args) + self.trigger_callback(sid, id, args) def _return_callback(self, host_id, sid, namespace, callback_id, *args): # When an event callback is received, the callback is returned back - # the sender, which is identified by the host_id + # to the sender, which is identified by the host_id self._publish({'method': 'callback', 'host_id': host_id, 'sid': sid, 'namespace': namespace, 'id': callback_id, 'args': args}) + def _handle_disconnect(self, message): + self.server.disconnect(sid=message.get('sid'), + namespace=message.get('namespace'), + ignore_queue=True) + def _handle_close_room(self, message): super(PubSubManager, self).close_room( room=message.get('room'), namespace=message.get('namespace')) @@ -135,7 +147,7 @@ class PubSubManager(BaseManager): if isinstance(message, dict): data = message else: - if isinstance(message, six.binary_type): # pragma: no cover + if isinstance(message, bytes): # pragma: no cover try: data = pickle.loads(message) except: @@ -146,9 +158,13 @@ class PubSubManager(BaseManager): except: pass if data and 'method' in data: + self._get_logger().info('pubsub message: {}'.format( + data['method'])) if data['method'] == 'emit': self._handle_emit(data) elif data['method'] == 'callback': self._handle_callback(data) + elif data['method'] == 'disconnect': + self._handle_disconnect(data) elif data['method'] == 'close_room': self._handle_close_room(data) diff --git a/libs/socketio/redis_manager.py b/libs/socketio/redis_manager.py index ad383345e..7e99d31eb 100644 --- a/libs/socketio/redis_manager.py +++ b/libs/socketio/redis_manager.py @@ -30,7 +30,7 @@ class RedisManager(PubSubManager): # pragma: no cover store running on the same host, use ``redis://``. :param channel: The channel name on which the server sends and receives notifications. Must be the same in all the servers. - :param write_only: If set ot ``True``, only initialize to emit events. The + :param write_only: If set to ``True``, only initialize to emit events. The default of ``False`` initializes the class for emitting and receiving. 
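# Aside: illustrating the write_only flag documented above: an external
# process that only emits can skip the listener thread entirely. The event
# name and payload are invented.
import socketio

external = socketio.RedisManager('redis://localhost:6379/0',
                                 channel='socketio', write_only=True)
external.emit('announcement', {'text': 'maintenance at 22:00'})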
:param redis_options: additional keyword arguments to be passed to @@ -78,7 +78,7 @@ class RedisManager(PubSubManager): # pragma: no cover if not retry: self._redis_connect() return self.redis.publish(self.channel, pickle.dumps(data)) - except redis.exceptions.ConnectionError: + except redis.exceptions.RedisError: if retry: logger.error('Cannot publish to redis... retrying') retry = False @@ -94,9 +94,10 @@ class RedisManager(PubSubManager): # pragma: no cover if connect: self._redis_connect() self.pubsub.subscribe(self.channel) + retry_sleep = 1 for message in self.pubsub.listen(): yield message - except redis.exceptions.ConnectionError: + except redis.exceptions.RedisError: logger.error('Cannot receive from redis... ' 'retrying in {} secs'.format(retry_sleep)) connect = True diff --git a/libs/socketio/server.py b/libs/socketio/server.py index 76b7d2e8f..22da0ac20 100644 --- a/libs/socketio/server.py +++ b/libs/socketio/server.py @@ -1,7 +1,6 @@ import logging import engineio -import six from . import base_manager from . import exceptions @@ -23,13 +22,8 @@ class Server(object): multiple connected servers is not possible. :param logger: To enable logging set to ``True`` or pass a logger object to use. To disable logging set to ``False``. The default is - ``False``. - :param binary: ``True`` to support binary payloads, ``False`` to treat all - payloads as text. On Python 2, if this is set to ``True``, - ``unicode`` values are treated as text, and ``str`` and - ``bytes`` values are treated as binary. This option has no - effect on Python 3, where text and binary payloads are - always automatically discovered. + ``False``. Note that fatal errors are logged even when + ``logger`` is ``False``. :param json: An alternative json module to use for encoding and decoding packets. Custom json modules must have ``dumps`` and ``loads`` functions that are compatible with the standard library @@ -60,13 +54,16 @@ class Server(object): "gevent_uwsgi", then "gevent", and finally "threading". The first async mode that has all its dependencies installed is then one that is chosen. + :param ping_interval: The interval in seconds at which the server pings + the client. The default is 25 seconds. For advanced + control, a two element tuple can be given, where + the first number is the ping interval and the second + is a grace period added by the server. :param ping_timeout: The time in seconds that the client waits for the server to respond before disconnecting. The default - is 60 seconds. - :param ping_interval: The interval in seconds at which the client pings - the server. The default is 25 seconds. + is 5 seconds. :param max_http_buffer_size: The maximum size of a message when using the - polling transport. The default is 100,000,000 + polling transport. The default is 1,000,000 bytes. :param allow_upgrades: Whether to allow transport upgrades or not. The default is ``True``. @@ -75,9 +72,14 @@ class Server(object): :param compression_threshold: Only compress messages when their byte size is greater than this value. The default is 1024 bytes. - :param cookie: Name of the HTTP cookie that contains the client session - id. If set to ``None``, a cookie is not sent to the client. - The default is ``'io'``. + :param cookie: If set to a string, it is the name of the HTTP cookie the + server sends back tot he client containing the client + session id. 
If set to a dictionary, the ``'name'`` key + contains the cookie name and other keys define cookie + attributes, where the value of each attribute can be a + string, a callable with no arguments, or a boolean. If set + to ``None`` (the default), a cookie is not sent to the + client. :param cors_allowed_origins: Origin or list of origins that are allowed to connect to this server. Only the same origin is allowed by default. Set this argument to @@ -92,11 +94,12 @@ class Server(object): default is ``True``. :param engineio_logger: To enable Engine.IO logging set to ``True`` or pass a logger object to use. To disable logging set to - ``False``. The default is ``False``. + ``False``. The default is ``False``. Note that + fatal errors are logged even when + ``engineio_logger`` is ``False``. """ - def __init__(self, client_manager=None, logger=False, binary=False, - json=None, async_handlers=True, always_connect=False, - **kwargs): + def __init__(self, client_manager=None, logger=False, json=None, + async_handlers=True, always_connect=False, **kwargs): engineio_options = kwargs engineio_logger = engineio_options.pop('engineio_logger', None) if engineio_logger is not None: @@ -109,7 +112,6 @@ class Server(object): self.eio.on('connect', self._handle_eio_connect) self.eio.on('message', self._handle_eio_message) self.eio.on('disconnect', self._handle_eio_disconnect) - self.binary = binary self.environ = {} self.handlers = {} @@ -121,8 +123,7 @@ class Server(object): self.logger = logger else: self.logger = default_logger - if not logging.root.handlers and \ - self.logger.level == logging.NOTSET: + if self.logger.level == logging.NOTSET: if logger: self.logger.setLevel(logging.INFO) else: @@ -169,7 +170,7 @@ class Server(object): def message_handler(sid, msg): print('Received message: ', msg) eio.send(sid, 'response') - socket_io.on('message', namespace='/chat', message_handler) + socket_io.on('message', namespace='/chat', handler=message_handler) The handler function receives the ``sid`` (session ID) for the client as first argument. The ``'connect'`` event handler receives the @@ -250,8 +251,9 @@ class Server(object): ``'connect'``, ``'message'`` and ``'disconnect'`` are reserved and should not be used. :param data: The data to send to the client or clients. Data can be of - type ``str``, ``bytes``, ``list`` or ``dict``. If a - ``list`` or ``dict``, the data will be serialized as JSON. + type ``str``, ``bytes``, ``list`` or ``dict``. To send + multiple arguments, use a tuple where each element is of + one of the types indicated above. :param to: The recipient of the message. This can be set to the session ID of a client to address only that client, or to to any custom room created by the application to address all @@ -277,6 +279,12 @@ class Server(object): single server process is used. It is recommended to always leave this parameter with its default value of ``False``. + + Note: this method is not thread safe. If multiple threads are emitting + at the same time to the same client, then messages composed of + multiple packets may end up being sent in an incorrect sequence. Use + standard concurrency solutions (such as a Lock object) to prevent this + situation. """ namespace = namespace or '/' room = to or room @@ -293,8 +301,9 @@ class Server(object): :func:`emit` to issue custom event names. :param data: The data to send to the client or clients. Data can be of - type ``str``, ``bytes``, ``list`` or ``dict``. If a - ``list`` or ``dict``, the data will be serialized as JSON. 
+ type ``str``, ``bytes``, ``list`` or ``dict``. To send + multiple arguments, use a tuple where each element is of + one of the types indicated above. :param to: The recipient of the message. This can be set to the session ID of a client to address only that client, or to to any custom room created by the application to address all @@ -332,8 +341,9 @@ class Server(object): ``'connect'``, ``'message'`` and ``'disconnect'`` are reserved and should not be used. :param data: The data to send to the client or clients. Data can be of - type ``str``, ``bytes``, ``list`` or ``dict``. If a - ``list`` or ``dict``, the data will be serialized as JSON. + type ``str``, ``bytes``, ``list`` or ``dict``. To send + multiple arguments, use a tuple where each element is of + one of the types indicated above. :param to: The session ID of the recipient client. :param sid: Alias for the ``to`` parameter. :param namespace: The Socket.IO namespace for the event. If this @@ -349,7 +359,15 @@ class Server(object): single server process is used. It is recommended to always leave this parameter with its default value of ``False``. + + Note: this method is not thread safe. If multiple threads are emitting + at the same time to the same client, then messages composed of + multiple packets may end up being sent in an incorrect sequence. Use + standard concurrency solutions (such as a Lock object) to prevent this + situation. """ + if to is None and sid is None: + raise ValueError('Cannot use call() to broadcast.') if not self.async_handlers: raise RuntimeError( 'Cannot use call() when async_handlers is False.') @@ -434,7 +452,8 @@ class Server(object): is used. """ namespace = namespace or '/' - eio_session = self.eio.get_session(sid) + eio_sid = self.manager.eio_sid_from_sid(sid, namespace) + eio_session = self.eio.get_session(eio_sid) return eio_session.setdefault(namespace, {}) def save_session(self, sid, session, namespace=None): @@ -446,7 +465,8 @@ class Server(object): the default namespace is used. """ namespace = namespace or '/' - eio_session = self.eio.get_session(sid) + eio_sid = self.manager.eio_sid_from_sid(sid, namespace) + eio_session = self.eio.get_session(eio_sid) eio_session[namespace] = session def session(self, sid, namespace=None): @@ -489,23 +509,30 @@ class Server(object): return _session_context_manager(self, sid, namespace) - def disconnect(self, sid, namespace=None): + def disconnect(self, sid, namespace=None, ignore_queue=False): """Disconnect a client. :param sid: Session ID of the client. :param namespace: The Socket.IO namespace to disconnect. If this argument is omitted the default namespace is used. + :param ignore_queue: Only used when a message queue is configured. If + set to ``True``, the disconnect is processed + locally, without broadcasting on the queue. It is + recommended to always leave this parameter with + its default value of ``False``. 
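# Aside: a sketch of the ignore_queue semantics described above, assuming a
# Redis message queue is configured; the namespace and helper names are
# illustrative.
import socketio

sio = socketio.Server(
    client_manager=socketio.RedisManager('redis://localhost:6379/0'))

def kick(sid):
    # Default behaviour: the disconnect is published on the queue, so
    # whichever server owns the connection carries it out.
    sio.disconnect(sid, namespace='/chat')

def kick_local(sid):
    # ignore_queue=True: handle it locally; only effective if this server
    # owns the connection.
    sio.disconnect(sid, namespace='/chat', ignore_queue=True)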
""" namespace = namespace or '/' - if self.manager.is_connected(sid, namespace=namespace): + if ignore_queue: + delete_it = self.manager.is_connected(sid, namespace) + else: + delete_it = self.manager.can_disconnect(sid, namespace) + if delete_it: self.logger.info('Disconnecting %s [%s]', sid, namespace) - self.manager.pre_disconnect(sid, namespace=namespace) - self._send_packet(sid, packet.Packet(packet.DISCONNECT, - namespace=namespace)) + eio_sid = self.manager.pre_disconnect(sid, namespace=namespace) + self._send_packet(eio_sid, packet.Packet( + packet.DISCONNECT, namespace=namespace)) self._trigger_event('disconnect', namespace, sid) self.manager.disconnect(sid, namespace=namespace) - if namespace == '/': - self.eio.disconnect(sid) def transport(self, sid): """Return the name of the transport used by the client. @@ -517,6 +544,16 @@ class Server(object): """ return self.eio.transport(sid) + def get_environ(self, sid, namespace=None): + """Return the WSGI environ dictionary for a client. + + :param sid: The session of the client. + :param namespace: The Socket.IO namespace. If this argument is omitted + the default namespace is used. + """ + eio_sid = self.manager.eio_sid_from_sid(sid, namespace or '/') + return self.environ.get(eio_sid) + def handle_request(self, environ, start_response): """Handle an HTTP request from the client. @@ -560,44 +597,47 @@ class Server(object): """ return self.eio.sleep(seconds) - def _emit_internal(self, sid, event, data, namespace=None, id=None): + def _emit_internal(self, eio_sid, event, data, namespace=None, id=None): """Send a message to a client.""" - if six.PY2 and not self.binary: - binary = False # pragma: nocover - else: - binary = None # tuples are expanded to multiple arguments, everything else is sent # as a single argument if isinstance(data, tuple): data = list(data) - else: + elif data is not None: data = [data] - self._send_packet(sid, packet.Packet(packet.EVENT, namespace=namespace, - data=[event] + data, id=id, - binary=binary)) + else: + data = [] + self._send_packet(eio_sid, packet.Packet( + packet.EVENT, namespace=namespace, data=[event] + data, id=id)) - def _send_packet(self, sid, pkt): + def _send_packet(self, eio_sid, pkt): """Send a Socket.IO packet to a client.""" encoded_packet = pkt.encode() if isinstance(encoded_packet, list): - binary = False for ep in encoded_packet: - self.eio.send(sid, ep, binary=binary) - binary = True + self.eio.send(eio_sid, ep) else: - self.eio.send(sid, encoded_packet, binary=False) + self.eio.send(eio_sid, encoded_packet) - def _handle_connect(self, sid, namespace): + def _handle_connect(self, eio_sid, namespace, data): """Handle a client connection request.""" namespace = namespace or '/' - self.manager.connect(sid, namespace) + sid = self.manager.connect(eio_sid, namespace) if self.always_connect: - self._send_packet(sid, packet.Packet(packet.CONNECT, - namespace=namespace)) - fail_reason = None + self._send_packet(eio_sid, packet.Packet( + packet.CONNECT, {'sid': sid}, namespace=namespace)) + fail_reason = exceptions.ConnectionRefusedError().error_args try: - success = self._trigger_event('connect', namespace, sid, - self.environ[sid]) + if data: + success = self._trigger_event( + 'connect', namespace, sid, self.environ[eio_sid], data) + else: + try: + success = self._trigger_event( + 'connect', namespace, sid, self.environ[eio_sid]) + except TypeError: + success = self._trigger_event( + 'connect', namespace, sid, self.environ[eio_sid], None) except exceptions.ConnectionRefusedError as exc: 
fail_reason = exc.error_args success = False @@ -605,36 +645,31 @@ class Server(object): if success is False: if self.always_connect: self.manager.pre_disconnect(sid, namespace) - self._send_packet(sid, packet.Packet( + self._send_packet(eio_sid, packet.Packet( packet.DISCONNECT, data=fail_reason, namespace=namespace)) + else: + self._send_packet(eio_sid, packet.Packet( + packet.CONNECT_ERROR, data=fail_reason, + namespace=namespace)) self.manager.disconnect(sid, namespace) - if not self.always_connect: - self._send_packet(sid, packet.Packet( - packet.ERROR, data=fail_reason, namespace=namespace)) - if sid in self.environ: # pragma: no cover - del self.environ[sid] elif not self.always_connect: - self._send_packet(sid, packet.Packet(packet.CONNECT, - namespace=namespace)) + self._send_packet(eio_sid, packet.Packet( + packet.CONNECT, {'sid': sid}, namespace=namespace)) - def _handle_disconnect(self, sid, namespace): + def _handle_disconnect(self, eio_sid, namespace): """Handle a client disconnect.""" namespace = namespace or '/' - if namespace == '/': - namespace_list = list(self.manager.get_namespaces()) - else: - namespace_list = [namespace] - for n in namespace_list: - if n != '/' and self.manager.is_connected(sid, n): - self._trigger_event('disconnect', n, sid) - self.manager.disconnect(sid, n) - if namespace == '/' and self.manager.is_connected(sid, namespace): - self._trigger_event('disconnect', '/', sid) - self.manager.disconnect(sid, '/') + sid = self.manager.sid_from_eio_sid(eio_sid, namespace) + if not self.manager.is_connected(sid, namespace): # pragma: no cover + return + self.manager.pre_disconnect(sid, namespace=namespace) + self._trigger_event('disconnect', namespace, sid) + self.manager.disconnect(sid, namespace) - def _handle_event(self, sid, namespace, id, data): + def _handle_event(self, eio_sid, namespace, id, data): """Handle an incoming client event.""" namespace = namespace or '/' + sid = self.manager.sid_from_eio_sid(eio_sid, namespace) self.logger.info('received event "%s" from %s [%s]', data[0], sid, namespace) if not self.manager.is_connected(sid, namespace): @@ -643,11 +678,13 @@ class Server(object): return if self.async_handlers: self.start_background_task(self._handle_event_internal, self, sid, - data, namespace, id) + eio_sid, data, namespace, id) else: - self._handle_event_internal(self, sid, data, namespace, id) + self._handle_event_internal(self, sid, eio_sid, data, namespace, + id) - def _handle_event_internal(self, server, sid, data, namespace, id): + def _handle_event_internal(self, server, sid, eio_sid, data, namespace, + id): r = server._trigger_event(data[0], namespace, sid, *data[1:]) if id is not None: # send ACK packet with the response returned by the handler @@ -658,20 +695,15 @@ class Server(object): data = list(r) else: data = [r] - if six.PY2 and not self.binary: - binary = False # pragma: nocover - else: - binary = None - server._send_packet(sid, packet.Packet(packet.ACK, - namespace=namespace, - id=id, data=data, - binary=binary)) + server._send_packet(eio_sid, packet.Packet( + packet.ACK, namespace=namespace, id=id, data=data)) - def _handle_ack(self, sid, namespace, id, data): + def _handle_ack(self, eio_sid, namespace, id, data): """Handle ACK packets from the client.""" namespace = namespace or '/' + sid = self.manager.sid_from_eio_sid(eio_sid, namespace) self.logger.info('received ack from %s [%s]', sid, namespace) - self.manager.trigger_callback(sid, namespace, id, data) + self.manager.trigger_callback(sid, id, data) def 
_trigger_event(self, event, namespace, *args): """Invoke an application event handler.""" @@ -684,47 +716,48 @@ class Server(object): return self.namespace_handlers[namespace].trigger_event( event, *args) - def _handle_eio_connect(self, sid, environ): + def _handle_eio_connect(self, eio_sid, environ): """Handle the Engine.IO connection event.""" if not self.manager_initialized: self.manager_initialized = True self.manager.initialize() - self.environ[sid] = environ - return self._handle_connect(sid, '/') + self.environ[eio_sid] = environ - def _handle_eio_message(self, sid, data): + def _handle_eio_message(self, eio_sid, data): """Dispatch Engine.IO messages.""" - if sid in self._binary_packet: - pkt = self._binary_packet[sid] + if eio_sid in self._binary_packet: + pkt = self._binary_packet[eio_sid] if pkt.add_attachment(data): - del self._binary_packet[sid] + del self._binary_packet[eio_sid] if pkt.packet_type == packet.BINARY_EVENT: - self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) + self._handle_event(eio_sid, pkt.namespace, pkt.id, + pkt.data) else: - self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) + self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) else: pkt = packet.Packet(encoded_packet=data) if pkt.packet_type == packet.CONNECT: - self._handle_connect(sid, pkt.namespace) + self._handle_connect(eio_sid, pkt.namespace, pkt.data) elif pkt.packet_type == packet.DISCONNECT: - self._handle_disconnect(sid, pkt.namespace) + self._handle_disconnect(eio_sid, pkt.namespace) elif pkt.packet_type == packet.EVENT: - self._handle_event(sid, pkt.namespace, pkt.id, pkt.data) + self._handle_event(eio_sid, pkt.namespace, pkt.id, pkt.data) elif pkt.packet_type == packet.ACK: - self._handle_ack(sid, pkt.namespace, pkt.id, pkt.data) + self._handle_ack(eio_sid, pkt.namespace, pkt.id, pkt.data) elif pkt.packet_type == packet.BINARY_EVENT or \ pkt.packet_type == packet.BINARY_ACK: - self._binary_packet[sid] = pkt - elif pkt.packet_type == packet.ERROR: - raise ValueError('Unexpected ERROR packet.') + self._binary_packet[eio_sid] = pkt + elif pkt.packet_type == packet.CONNECT_ERROR: + raise ValueError('Unexpected CONNECT_ERROR packet.') else: raise ValueError('Unknown packet type.') - def _handle_eio_disconnect(self, sid): + def _handle_eio_disconnect(self, eio_sid): """Handle Engine.IO disconnect event.""" - self._handle_disconnect(sid, '/') - if sid in self.environ: - del self.environ[sid] + for n in list(self.manager.get_namespaces()).copy(): + self._handle_disconnect(eio_sid, n) + if eio_sid in self.environ: + del self.environ[eio_sid] def _engineio_server_class(self): return engineio.Server diff --git a/libs/socketio/zmq_manager.py b/libs/socketio/zmq_manager.py index f2a2ae5dc..54538cf1b 100644 --- a/libs/socketio/zmq_manager.py +++ b/libs/socketio/zmq_manager.py @@ -5,7 +5,6 @@ try: import eventlet.green.zmq as zmq except ImportError: zmq = None -import six from .pubsub_manager import PubSubManager @@ -98,7 +97,7 @@ class ZmqManager(PubSubManager): # pragma: no cover def _listen(self): for message in self.zmq_listen(): - if isinstance(message, six.binary_type): + if isinstance(message, bytes): try: message = pickle.loads(message) except Exception: diff --git a/libs/version.txt b/libs/version.txt index b5f2fafae..440a8e802 100644 --- a/libs/version.txt +++ b/libs/version.txt @@ -3,14 +3,16 @@ apscheduler=3.5.1 babelfish=0.5.5 backports.functools-lru-cache=1.5 Beaker=1.10.0 +bidict=0.18.4 bottle-fdsend=0.1.1 bottle=0.12.13 chardet=3.0.4 dogpile.cache=0.6.5 +engineio=4.0.2dev 
enzyme=0.4.1 ffsubsync=0.4.11 Flask=1.1.1 -gevent-websocker=0.10.1 +flask-socketio=5.0.2dev gitpython=2.1.9 guessit=3.3.1 guess_language-spirit=0.5.3 @@ -27,11 +29,11 @@ requests=2.18.4 semver=2.13.0 SimpleConfigParser=0.1.0 <-- modified version: do not update!!! six=1.11.0 +socketio=5.1.0 stevedore=1.28.0 subliminal=2.1.0dev tzlocal=2.1b1 urllib3=1.23 -Waitress=1.4.3 ## indirect dependencies auditok=0.1.5 # Required-by: ffsubsync diff --git a/libs/waitress/__init__.py b/libs/waitress/__init__.py deleted file mode 100644 index e6e5911a5..000000000 --- a/libs/waitress/__init__.py +++ /dev/null @@ -1,45 +0,0 @@ -from waitress.server import create_server -import logging - - -def serve(app, **kw): - _server = kw.pop("_server", create_server) # test shim - _quiet = kw.pop("_quiet", False) # test shim - _profile = kw.pop("_profile", False) # test shim - if not _quiet: # pragma: no cover - # idempotent if logging has already been set up - logging.basicConfig() - server = _server(app, **kw) - if not _quiet: # pragma: no cover - server.print_listen("Serving on http://{}:{}") - if _profile: # pragma: no cover - profile("server.run()", globals(), locals(), (), False) - else: - server.run() - - -def serve_paste(app, global_conf, **kw): - serve(app, **kw) - return 0 - - -def profile(cmd, globals, locals, sort_order, callers): # pragma: no cover - # runs a command under the profiler and print profiling output at shutdown - import os - import profile - import pstats - import tempfile - - fd, fn = tempfile.mkstemp() - try: - profile.runctx(cmd, globals, locals, fn) - stats = pstats.Stats(fn) - stats.strip_dirs() - # calls,time,cumulative and cumulative,calls,time are useful - stats.sort_stats(*(sort_order or ("cumulative", "calls", "time"))) - if callers: - stats.print_callers(0.3) - else: - stats.print_stats(0.3) - finally: - os.remove(fn) diff --git a/libs/waitress/__main__.py b/libs/waitress/__main__.py deleted file mode 100644 index 9bcd07e59..000000000 --- a/libs/waitress/__main__.py +++ /dev/null @@ -1,3 +0,0 @@ -from waitress.runner import run # pragma nocover - -run() # pragma nocover diff --git a/libs/waitress/adjustments.py b/libs/waitress/adjustments.py deleted file mode 100644 index 93439eab8..000000000 --- a/libs/waitress/adjustments.py +++ /dev/null @@ -1,515 +0,0 @@ -############################################################################## -# -# Copyright (c) 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""Adjustments are tunable parameters. -""" -import getopt -import socket -import warnings - -from .proxy_headers import PROXY_HEADERS -from .compat import ( - PY2, - WIN, - string_types, - HAS_IPV6, -) - -truthy = frozenset(("t", "true", "y", "yes", "on", "1")) - -KNOWN_PROXY_HEADERS = frozenset( - header.lower().replace("_", "-") for header in PROXY_HEADERS -) - - -def asbool(s): - """ Return the boolean value ``True`` if the case-lowered value of string - input ``s`` is any of ``t``, ``true``, ``y``, ``on``, or ``1``, otherwise - return the boolean value ``False``. 
If ``s`` is the value ``None``, - return ``False``. If ``s`` is already one of the boolean values ``True`` - or ``False``, return it.""" - if s is None: - return False - if isinstance(s, bool): - return s - s = str(s).strip() - return s.lower() in truthy - - -def asoctal(s): - """Convert the given octal string to an actual number.""" - return int(s, 8) - - -def aslist_cronly(value): - if isinstance(value, string_types): - value = filter(None, [x.strip() for x in value.splitlines()]) - return list(value) - - -def aslist(value): - """ Return a list of strings, separating the input based on newlines - and, if flatten=True (the default), also split on spaces within - each line.""" - values = aslist_cronly(value) - result = [] - for value in values: - subvalues = value.split() - result.extend(subvalues) - return result - - -def asset(value): - return set(aslist(value)) - - -def slash_fixed_str(s): - s = s.strip() - if s: - # always have a leading slash, replace any number of leading slashes - # with a single slash, and strip any trailing slashes - s = "/" + s.lstrip("/").rstrip("/") - return s - - -def str_iftruthy(s): - return str(s) if s else None - - -def as_socket_list(sockets): - """Checks if the elements in the list are of type socket and - removes them if not.""" - return [sock for sock in sockets if isinstance(sock, socket.socket)] - - -class _str_marker(str): - pass - - -class _int_marker(int): - pass - - -class _bool_marker(object): - pass - - -class Adjustments(object): - """This class contains tunable parameters. - """ - - _params = ( - ("host", str), - ("port", int), - ("ipv4", asbool), - ("ipv6", asbool), - ("listen", aslist), - ("threads", int), - ("trusted_proxy", str_iftruthy), - ("trusted_proxy_count", int), - ("trusted_proxy_headers", asset), - ("log_untrusted_proxy_headers", asbool), - ("clear_untrusted_proxy_headers", asbool), - ("url_scheme", str), - ("url_prefix", slash_fixed_str), - ("backlog", int), - ("recv_bytes", int), - ("send_bytes", int), - ("outbuf_overflow", int), - ("outbuf_high_watermark", int), - ("inbuf_overflow", int), - ("connection_limit", int), - ("cleanup_interval", int), - ("channel_timeout", int), - ("log_socket_errors", asbool), - ("max_request_header_size", int), - ("max_request_body_size", int), - ("expose_tracebacks", asbool), - ("ident", str_iftruthy), - ("asyncore_loop_timeout", int), - ("asyncore_use_poll", asbool), - ("unix_socket", str), - ("unix_socket_perms", asoctal), - ("sockets", as_socket_list), - ) - - _param_map = dict(_params) - - # hostname or IP address to listen on - host = _str_marker("0.0.0.0") - - # TCP port to listen on - port = _int_marker(8080) - - listen = ["{}:{}".format(host, port)] - - # number of threads available for tasks - threads = 4 - - # Host allowed to overrid ``wsgi.url_scheme`` via header - trusted_proxy = None - - # How many proxies we trust when chained - # - # X-Forwarded-For: 192.0.2.1, "[2001:db8::1]" - # - # or - # - # Forwarded: for=192.0.2.1, For="[2001:db8::1]" - # - # means there were (potentially), two proxies involved. If we know there is - # only 1 valid proxy, then that initial IP address "192.0.2.1" is not - # trusted and we completely ignore it. If there are two trusted proxies in - # the path, this value should be set to a higher number. - trusted_proxy_count = None - - # Which of the proxy headers should we trust, this is a set where you - # either specify forwarded or one or more of forwarded-host, forwarded-for, - # forwarded-proto, forwarded-port. 
- trusted_proxy_headers = set() - - # Would you like waitress to log warnings about untrusted proxy headers - # that were encountered while processing the proxy headers? This only makes - # sense to set when you have a trusted_proxy, and you expect the upstream - # proxy server to filter invalid headers - log_untrusted_proxy_headers = False - - # Should waitress clear any proxy headers that are not deemed trusted from - # the environ? Change to True by default in 2.x - clear_untrusted_proxy_headers = _bool_marker - - # default ``wsgi.url_scheme`` value - url_scheme = "http" - - # default ``SCRIPT_NAME`` value, also helps reset ``PATH_INFO`` - # when nonempty - url_prefix = "" - - # server identity (sent in Server: header) - ident = "waitress" - - # backlog is the value waitress passes to pass to socket.listen() This is - # the maximum number of incoming TCP connections that will wait in an OS - # queue for an available channel. From listen(1): "If a connection - # request arrives when the queue is full, the client may receive an error - # with an indication of ECONNREFUSED or, if the underlying protocol - # supports retransmission, the request may be ignored so that a later - # reattempt at connection succeeds." - backlog = 1024 - - # recv_bytes is the argument to pass to socket.recv(). - recv_bytes = 8192 - - # deprecated setting controls how many bytes will be buffered before - # being flushed to the socket - send_bytes = 1 - - # A tempfile should be created if the pending output is larger than - # outbuf_overflow, which is measured in bytes. The default is 1MB. This - # is conservative. - outbuf_overflow = 1048576 - - # The app_iter will pause when pending output is larger than this value - # in bytes. - outbuf_high_watermark = 16777216 - - # A tempfile should be created if the pending input is larger than - # inbuf_overflow, which is measured in bytes. The default is 512K. This - # is conservative. - inbuf_overflow = 524288 - - # Stop creating new channels if too many are already active (integer). - # Each channel consumes at least one file descriptor, and, depending on - # the input and output body sizes, potentially up to three. The default - # is conservative, but you may need to increase the number of file - # descriptors available to the Waitress process on most platforms in - # order to safely change it (see ``ulimit -a`` "open files" setting). - # Note that this doesn't control the maximum number of TCP connections - # that can be waiting for processing; the ``backlog`` argument controls - # that. - connection_limit = 100 - - # Minimum seconds between cleaning up inactive channels. - cleanup_interval = 30 - - # Maximum seconds to leave an inactive connection open. - channel_timeout = 120 - - # Boolean: turn off to not log premature client disconnects. - log_socket_errors = True - - # maximum number of bytes of all request headers combined (256K default) - max_request_header_size = 262144 - - # maximum number of bytes in request body (1GB default) - max_request_body_size = 1073741824 - - # expose tracebacks of uncaught exceptions - expose_tracebacks = False - - # Path to a Unix domain socket to use. - unix_socket = None - - # Path to a Unix domain socket to use. - unix_socket_perms = 0o600 - - # The socket options to set on receiving a connection. It is a list of - # (level, optname, value) tuples. TCP_NODELAY disables the Nagle - # algorithm for writes (Waitress already buffers its writes). 
- socket_options = [ - (socket.SOL_TCP, socket.TCP_NODELAY, 1), - ] - - # The asyncore.loop timeout value - asyncore_loop_timeout = 1 - - # The asyncore.loop flag to use poll() instead of the default select(). - asyncore_use_poll = False - - # Enable IPv4 by default - ipv4 = True - - # Enable IPv6 by default - ipv6 = True - - # A list of sockets that waitress will use to accept connections. They can - # be used for e.g. socket activation - sockets = [] - - def __init__(self, **kw): - - if "listen" in kw and ("host" in kw or "port" in kw): - raise ValueError("host or port may not be set if listen is set.") - - if "listen" in kw and "sockets" in kw: - raise ValueError("socket may not be set if listen is set.") - - if "sockets" in kw and ("host" in kw or "port" in kw): - raise ValueError("host or port may not be set if sockets is set.") - - if "sockets" in kw and "unix_socket" in kw: - raise ValueError("unix_socket may not be set if sockets is set") - - if "unix_socket" in kw and ("host" in kw or "port" in kw): - raise ValueError("unix_socket may not be set if host or port is set") - - if "unix_socket" in kw and "listen" in kw: - raise ValueError("unix_socket may not be set if listen is set") - - if "send_bytes" in kw: - warnings.warn( - "send_bytes will be removed in a future release", DeprecationWarning, - ) - - for k, v in kw.items(): - if k not in self._param_map: - raise ValueError("Unknown adjustment %r" % k) - setattr(self, k, self._param_map[k](v)) - - if not isinstance(self.host, _str_marker) or not isinstance( - self.port, _int_marker - ): - self.listen = ["{}:{}".format(self.host, self.port)] - - enabled_families = socket.AF_UNSPEC - - if not self.ipv4 and not HAS_IPV6: # pragma: no cover - raise ValueError( - "IPv4 is disabled but IPv6 is not available. Cowardly refusing to start." - ) - - if self.ipv4 and not self.ipv6: - enabled_families = socket.AF_INET - - if not self.ipv4 and self.ipv6 and HAS_IPV6: - enabled_families = socket.AF_INET6 - - wanted_sockets = [] - hp_pairs = [] - for i in self.listen: - if ":" in i: - (host, port) = i.rsplit(":", 1) - - # IPv6 we need to make sure that we didn't split on the address - if "]" in port: # pragma: nocover - (host, port) = (i, str(self.port)) - else: - (host, port) = (i, str(self.port)) - - if WIN and PY2: # pragma: no cover - try: - # Try turning the port into an integer - port = int(port) - - except Exception: - raise ValueError( - "Windows does not support service names instead of port numbers" - ) - - try: - if "[" in host and "]" in host: # pragma: nocover - host = host.strip("[").rstrip("]") - - if host == "*": - host = None - - for s in socket.getaddrinfo( - host, - port, - enabled_families, - socket.SOCK_STREAM, - socket.IPPROTO_TCP, - socket.AI_PASSIVE, - ): - (family, socktype, proto, _, sockaddr) = s - - # It seems that getaddrinfo() may sometimes happily return - # the same result multiple times, this of course makes - # bind() very unhappy... - # - # Split on %, and drop the zone-index from the host in the - # sockaddr. Works around a bug in OS X whereby - # getaddrinfo() returns the same link-local interface with - # two different zone-indices (which makes no sense what so - # ever...) yet treats them equally when we attempt to bind(). 
- if ( - sockaddr[1] == 0 - or (sockaddr[0].split("%", 1)[0], sockaddr[1]) not in hp_pairs - ): - wanted_sockets.append((family, socktype, proto, sockaddr)) - hp_pairs.append((sockaddr[0].split("%", 1)[0], sockaddr[1])) - - except Exception: - raise ValueError("Invalid host/port specified.") - - if self.trusted_proxy_count is not None and self.trusted_proxy is None: - raise ValueError( - "trusted_proxy_count has no meaning without setting " "trusted_proxy" - ) - - elif self.trusted_proxy_count is None: - self.trusted_proxy_count = 1 - - if self.trusted_proxy_headers and self.trusted_proxy is None: - raise ValueError( - "trusted_proxy_headers has no meaning without setting " "trusted_proxy" - ) - - if self.trusted_proxy_headers: - self.trusted_proxy_headers = { - header.lower() for header in self.trusted_proxy_headers - } - - unknown_values = self.trusted_proxy_headers - KNOWN_PROXY_HEADERS - if unknown_values: - raise ValueError( - "Received unknown trusted_proxy_headers value (%s) expected one " - "of %s" - % (", ".join(unknown_values), ", ".join(KNOWN_PROXY_HEADERS)) - ) - - if ( - "forwarded" in self.trusted_proxy_headers - and self.trusted_proxy_headers - {"forwarded"} - ): - raise ValueError( - "The Forwarded proxy header and the " - "X-Forwarded-{By,Host,Proto,Port,For} headers are mutually " - "exclusive. Can't trust both!" - ) - - elif self.trusted_proxy is not None: - warnings.warn( - "No proxy headers were marked as trusted, but trusted_proxy was set. " - "Implicitly trusting X-Forwarded-Proto for backwards compatibility. " - "This will be removed in future versions of waitress.", - DeprecationWarning, - ) - self.trusted_proxy_headers = {"x-forwarded-proto"} - - if self.clear_untrusted_proxy_headers is _bool_marker: - warnings.warn( - "In future versions of Waitress clear_untrusted_proxy_headers will be " - "set to True by default. You may opt-out by setting this value to " - "False, or opt-in explicitly by setting this to True.", - DeprecationWarning, - ) - self.clear_untrusted_proxy_headers = False - - self.listen = wanted_sockets - - self.check_sockets(self.sockets) - - @classmethod - def parse_args(cls, argv): - """Pre-parse command line arguments for input into __init__. Note that - this does not cast values into adjustment types, it just creates a - dictionary suitable for passing into __init__, where __init__ does the - casting. 
- """ - long_opts = ["help", "call"] - for opt, cast in cls._params: - opt = opt.replace("_", "-") - if cast is asbool: - long_opts.append(opt) - long_opts.append("no-" + opt) - else: - long_opts.append(opt + "=") - - kw = { - "help": False, - "call": False, - } - - opts, args = getopt.getopt(argv, "", long_opts) - for opt, value in opts: - param = opt.lstrip("-").replace("-", "_") - - if param == "listen": - kw["listen"] = "{} {}".format(kw.get("listen", ""), value) - continue - - if param.startswith("no_"): - param = param[3:] - kw[param] = "false" - elif param in ("help", "call"): - kw[param] = True - elif cls._param_map[param] is asbool: - kw[param] = "true" - else: - kw[param] = value - - return kw, args - - @classmethod - def check_sockets(cls, sockets): - has_unix_socket = False - has_inet_socket = False - has_unsupported_socket = False - for sock in sockets: - if ( - sock.family == socket.AF_INET or sock.family == socket.AF_INET6 - ) and sock.type == socket.SOCK_STREAM: - has_inet_socket = True - elif ( - hasattr(socket, "AF_UNIX") - and sock.family == socket.AF_UNIX - and sock.type == socket.SOCK_STREAM - ): - has_unix_socket = True - else: - has_unsupported_socket = True - if has_unix_socket and has_inet_socket: - raise ValueError("Internet and UNIX sockets may not be mixed.") - if has_unsupported_socket: - raise ValueError("Only Internet or UNIX stream sockets may be used.") diff --git a/libs/waitress/buffers.py b/libs/waitress/buffers.py deleted file mode 100644 index 04f6b4274..000000000 --- a/libs/waitress/buffers.py +++ /dev/null @@ -1,308 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001-2004 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""Buffers -""" -from io import BytesIO - -# copy_bytes controls the size of temp. strings for shuffling data around. -COPY_BYTES = 1 << 18 # 256K - -# The maximum number of bytes to buffer in a simple string. 
-STRBUF_LIMIT = 8192 - - -class FileBasedBuffer(object): - - remain = 0 - - def __init__(self, file, from_buffer=None): - self.file = file - if from_buffer is not None: - from_file = from_buffer.getfile() - read_pos = from_file.tell() - from_file.seek(0) - while True: - data = from_file.read(COPY_BYTES) - if not data: - break - file.write(data) - self.remain = int(file.tell() - read_pos) - from_file.seek(read_pos) - file.seek(read_pos) - - def __len__(self): - return self.remain - - def __nonzero__(self): - return True - - __bool__ = __nonzero__ # py3 - - def append(self, s): - file = self.file - read_pos = file.tell() - file.seek(0, 2) - file.write(s) - file.seek(read_pos) - self.remain = self.remain + len(s) - - def get(self, numbytes=-1, skip=False): - file = self.file - if not skip: - read_pos = file.tell() - if numbytes < 0: - # Read all - res = file.read() - else: - res = file.read(numbytes) - if skip: - self.remain -= len(res) - else: - file.seek(read_pos) - return res - - def skip(self, numbytes, allow_prune=0): - if self.remain < numbytes: - raise ValueError( - "Can't skip %d bytes in buffer of %d bytes" % (numbytes, self.remain) - ) - self.file.seek(numbytes, 1) - self.remain = self.remain - numbytes - - def newfile(self): - raise NotImplementedError() - - def prune(self): - file = self.file - if self.remain == 0: - read_pos = file.tell() - file.seek(0, 2) - sz = file.tell() - file.seek(read_pos) - if sz == 0: - # Nothing to prune. - return - nf = self.newfile() - while True: - data = file.read(COPY_BYTES) - if not data: - break - nf.write(data) - self.file = nf - - def getfile(self): - return self.file - - def close(self): - if hasattr(self.file, "close"): - self.file.close() - self.remain = 0 - - -class TempfileBasedBuffer(FileBasedBuffer): - def __init__(self, from_buffer=None): - FileBasedBuffer.__init__(self, self.newfile(), from_buffer) - - def newfile(self): - from tempfile import TemporaryFile - - return TemporaryFile("w+b") - - -class BytesIOBasedBuffer(FileBasedBuffer): - def __init__(self, from_buffer=None): - if from_buffer is not None: - FileBasedBuffer.__init__(self, BytesIO(), from_buffer) - else: - # Shortcut. 
:-) - self.file = BytesIO() - - def newfile(self): - return BytesIO() - - -def _is_seekable(fp): - if hasattr(fp, "seekable"): - return fp.seekable() - return hasattr(fp, "seek") and hasattr(fp, "tell") - - -class ReadOnlyFileBasedBuffer(FileBasedBuffer): - # used as wsgi.file_wrapper - - def __init__(self, file, block_size=32768): - self.file = file - self.block_size = block_size # for __iter__ - - def prepare(self, size=None): - if _is_seekable(self.file): - start_pos = self.file.tell() - self.file.seek(0, 2) - end_pos = self.file.tell() - self.file.seek(start_pos) - fsize = end_pos - start_pos - if size is None: - self.remain = fsize - else: - self.remain = min(fsize, size) - return self.remain - - def get(self, numbytes=-1, skip=False): - # never read more than self.remain (it can be user-specified) - if numbytes == -1 or numbytes > self.remain: - numbytes = self.remain - file = self.file - if not skip: - read_pos = file.tell() - res = file.read(numbytes) - if skip: - self.remain -= len(res) - else: - file.seek(read_pos) - return res - - def __iter__(self): # called by task if self.filelike has no seek/tell - return self - - def next(self): - val = self.file.read(self.block_size) - if not val: - raise StopIteration - return val - - __next__ = next # py3 - - def append(self, s): - raise NotImplementedError - - -class OverflowableBuffer(object): - """ - This buffer implementation has four stages: - - No data - - Bytes-based buffer - - BytesIO-based buffer - - Temporary file storage - The first two stages are fastest for simple transfers. - """ - - overflowed = False - buf = None - strbuf = b"" # Bytes-based buffer. - - def __init__(self, overflow): - # overflow is the maximum to be stored in a StringIO buffer. - self.overflow = overflow - - def __len__(self): - buf = self.buf - if buf is not None: - # use buf.__len__ rather than len(buf) FBO of not getting - # OverflowError on Python 2 - return buf.__len__() - else: - return self.strbuf.__len__() - - def __nonzero__(self): - # use self.__len__ rather than len(self) FBO of not getting - # OverflowError on Python 2 - return self.__len__() > 0 - - __bool__ = __nonzero__ # py3 - - def _create_buffer(self): - strbuf = self.strbuf - if len(strbuf) >= self.overflow: - self._set_large_buffer() - else: - self._set_small_buffer() - buf = self.buf - if strbuf: - buf.append(self.strbuf) - self.strbuf = b"" - return buf - - def _set_small_buffer(self): - self.buf = BytesIOBasedBuffer(self.buf) - self.overflowed = False - - def _set_large_buffer(self): - self.buf = TempfileBasedBuffer(self.buf) - self.overflowed = True - - def append(self, s): - buf = self.buf - if buf is None: - strbuf = self.strbuf - if len(strbuf) + len(s) < STRBUF_LIMIT: - self.strbuf = strbuf + s - return - buf = self._create_buffer() - buf.append(s) - # use buf.__len__ rather than len(buf) FBO of not getting - # OverflowError on Python 2 - sz = buf.__len__() - if not self.overflowed: - if sz >= self.overflow: - self._set_large_buffer() - - def get(self, numbytes=-1, skip=False): - buf = self.buf - if buf is None: - strbuf = self.strbuf - if not skip: - return strbuf - buf = self._create_buffer() - return buf.get(numbytes, skip) - - def skip(self, numbytes, allow_prune=False): - buf = self.buf - if buf is None: - if allow_prune and numbytes == len(self.strbuf): - # We could slice instead of converting to - # a buffer, but that would eat up memory in - # large transfers. 
- self.strbuf = b"" - return - buf = self._create_buffer() - buf.skip(numbytes, allow_prune) - - def prune(self): - """ - A potentially expensive operation that removes all data - already retrieved from the buffer. - """ - buf = self.buf - if buf is None: - self.strbuf = b"" - return - buf.prune() - if self.overflowed: - # use buf.__len__ rather than len(buf) FBO of not getting - # OverflowError on Python 2 - sz = buf.__len__() - if sz < self.overflow: - # Revert to a faster buffer. - self._set_small_buffer() - - def getfile(self): - buf = self.buf - if buf is None: - buf = self._create_buffer() - return buf.getfile() - - def close(self): - buf = self.buf - if buf is not None: - buf.close() diff --git a/libs/waitress/channel.py b/libs/waitress/channel.py deleted file mode 100644 index a8bc76f74..000000000 --- a/libs/waitress/channel.py +++ /dev/null @@ -1,414 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001, 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -import socket -import threading -import time -import traceback - -from waitress.buffers import ( - OverflowableBuffer, - ReadOnlyFileBasedBuffer, -) - -from waitress.parser import HTTPRequestParser - -from waitress.task import ( - ErrorTask, - WSGITask, -) - -from waitress.utilities import InternalServerError - -from . import wasyncore - - -class ClientDisconnected(Exception): - """ Raised when attempting to write to a closed socket.""" - - -class HTTPChannel(wasyncore.dispatcher, object): - """ - Setting self.requests = [somerequest] prevents more requests from being - received until the out buffers have been flushed. - - Setting self.requests = [] allows more requests to be received. - """ - - task_class = WSGITask - error_task_class = ErrorTask - parser_class = HTTPRequestParser - - request = None # A request parser instance - last_activity = 0 # Time of last activity - will_close = False # set to True to close the socket. - close_when_flushed = False # set to True to close the socket when flushed - requests = () # currently pending requests - sent_continue = False # used as a latch after sending 100 continue - total_outbufs_len = 0 # total bytes ready to send - current_outbuf_count = 0 # total bytes written to current outbuf - - # - # ASYNCHRONOUS METHODS (including __init__) - # - - def __init__( - self, server, sock, addr, adj, map=None, - ): - self.server = server - self.adj = adj - self.outbufs = [OverflowableBuffer(adj.outbuf_overflow)] - self.creation_time = self.last_activity = time.time() - self.sendbuf_len = sock.getsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF) - - # task_lock used to push/pop requests - self.task_lock = threading.Lock() - # outbuf_lock used to access any outbuf (expected to use an RLock) - self.outbuf_lock = threading.Condition() - - wasyncore.dispatcher.__init__(self, sock, map=map) - - # Don't let wasyncore.dispatcher throttle self.addr on us. 
-        self.addr = addr
-
-    def writable(self):
-        # if there's data in the out buffer or we've been instructed to close
-        # the channel (possibly by our server maintenance logic), run
-        # handle_write
-        return self.total_outbufs_len or self.will_close or self.close_when_flushed
-
-    def handle_write(self):
-        # Precondition: there's data in the out buffer to be sent, or
-        # there's a pending will_close request
-        if not self.connected:
-            # we don't want to close the channel twice
-            return
-
-        # try to flush any pending output
-        if not self.requests:
-            # 1. There are no running tasks, so we don't need to try to lock
-            #    the outbuf before sending
-            # 2. The data in the out buffer should be sent as soon as possible
-            #    because it's either data left over from task output
-            #    or a 100 Continue line sent within "received".
-            flush = self._flush_some
-        elif self.total_outbufs_len >= self.adj.send_bytes:
-            # 1. There's a running task, so we need to try to lock
-            #    the outbuf before sending
-            # 2. Only try to send if the data in the out buffer is larger
-            #    than self.adj.send_bytes to avoid TCP fragmentation
-            flush = self._flush_some_if_lockable
-        else:
-            # 1. There's not enough data in the out buffer to bother to send
-            #    right now.
-            flush = None
-
-        if flush:
-            try:
-                flush()
-            except socket.error:
-                if self.adj.log_socket_errors:
-                    self.logger.exception("Socket error")
-                self.will_close = True
-            except Exception:
-                self.logger.exception("Unexpected exception when flushing")
-                self.will_close = True
-
-        if self.close_when_flushed and not self.total_outbufs_len:
-            self.close_when_flushed = False
-            self.will_close = True
-
-        if self.will_close:
-            self.handle_close()
-
-    def readable(self):
-        # We might want to create a new task. We can only do this if:
-        # 1. We're not already about to close the connection.
-        # 2. There's no already currently running task(s).
-        # 3. There's no data in the output buffer that needs to be sent
-        #    before we potentially create a new task.
-        return not (self.will_close or self.requests or self.total_outbufs_len)
-
-    def handle_read(self):
-        try:
-            data = self.recv(self.adj.recv_bytes)
-        except socket.error:
-            if self.adj.log_socket_errors:
-                self.logger.exception("Socket error")
-            self.handle_close()
-            return
-        if data:
-            self.last_activity = time.time()
-            self.received(data)
-
-    def received(self, data):
-        """
-        Receives input asynchronously and assigns one or more requests to the
-        channel.
-        """
-        # Preconditions: there's no task(s) already running
-        request = self.request
-        requests = []
-
-        if not data:
-            return False
-
-        while data:
-            if request is None:
-                request = self.parser_class(self.adj)
-            n = request.received(data)
-            if request.expect_continue and request.headers_finished:
-                # guaranteed by parser to be a 1.1 request
-                request.expect_continue = False
-                if not self.sent_continue:
-                    # there's no current task, so we don't need to try to
-                    # lock the outbuf to append to it.
-                    outbuf_payload = b"HTTP/1.1 100 Continue\r\n\r\n"
-                    self.outbufs[-1].append(outbuf_payload)
-                    self.current_outbuf_count += len(outbuf_payload)
-                    self.total_outbufs_len += len(outbuf_payload)
-                    self.sent_continue = True
-                    self._flush_some()
-                    request.completed = False
-            if request.completed:
-                # The request (with the body) is ready to use.
- self.request = None - if not request.empty: - requests.append(request) - request = None - else: - self.request = request - if n >= len(data): - break - data = data[n:] - - if requests: - self.requests = requests - self.server.add_task(self) - - return True - - def _flush_some_if_lockable(self): - # Since our task may be appending to the outbuf, we try to acquire - # the lock, but we don't block if we can't. - if self.outbuf_lock.acquire(False): - try: - self._flush_some() - - if self.total_outbufs_len < self.adj.outbuf_high_watermark: - self.outbuf_lock.notify() - finally: - self.outbuf_lock.release() - - def _flush_some(self): - # Send as much data as possible to our client - - sent = 0 - dobreak = False - - while True: - outbuf = self.outbufs[0] - # use outbuf.__len__ rather than len(outbuf) FBO of not getting - # OverflowError on 32-bit Python - outbuflen = outbuf.__len__() - while outbuflen > 0: - chunk = outbuf.get(self.sendbuf_len) - num_sent = self.send(chunk) - if num_sent: - outbuf.skip(num_sent, True) - outbuflen -= num_sent - sent += num_sent - self.total_outbufs_len -= num_sent - else: - # failed to write anything, break out entirely - dobreak = True - break - else: - # self.outbufs[-1] must always be a writable outbuf - if len(self.outbufs) > 1: - toclose = self.outbufs.pop(0) - try: - toclose.close() - except Exception: - self.logger.exception("Unexpected error when closing an outbuf") - else: - # caught up, done flushing for now - dobreak = True - - if dobreak: - break - - if sent: - self.last_activity = time.time() - return True - - return False - - def handle_close(self): - with self.outbuf_lock: - for outbuf in self.outbufs: - try: - outbuf.close() - except Exception: - self.logger.exception( - "Unknown exception while trying to close outbuf" - ) - self.total_outbufs_len = 0 - self.connected = False - self.outbuf_lock.notify() - wasyncore.dispatcher.close(self) - - def add_channel(self, map=None): - """See wasyncore.dispatcher - - This hook keeps track of opened channels. - """ - wasyncore.dispatcher.add_channel(self, map) - self.server.active_channels[self._fileno] = self - - def del_channel(self, map=None): - """See wasyncore.dispatcher - - This hook keeps track of closed channels. 
- """ - fd = self._fileno # next line sets this to None - wasyncore.dispatcher.del_channel(self, map) - ac = self.server.active_channels - if fd in ac: - del ac[fd] - - # - # SYNCHRONOUS METHODS - # - - def write_soon(self, data): - if not self.connected: - # if the socket is closed then interrupt the task so that it - # can cleanup possibly before the app_iter is exhausted - raise ClientDisconnected - if data: - # the async mainloop might be popping data off outbuf; we can - # block here waiting for it because we're in a task thread - with self.outbuf_lock: - self._flush_outbufs_below_high_watermark() - if not self.connected: - raise ClientDisconnected - num_bytes = len(data) - if data.__class__ is ReadOnlyFileBasedBuffer: - # they used wsgi.file_wrapper - self.outbufs.append(data) - nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) - self.outbufs.append(nextbuf) - self.current_outbuf_count = 0 - else: - if self.current_outbuf_count > self.adj.outbuf_high_watermark: - # rotate to a new buffer if the current buffer has hit - # the watermark to avoid it growing unbounded - nextbuf = OverflowableBuffer(self.adj.outbuf_overflow) - self.outbufs.append(nextbuf) - self.current_outbuf_count = 0 - self.outbufs[-1].append(data) - self.current_outbuf_count += num_bytes - self.total_outbufs_len += num_bytes - if self.total_outbufs_len >= self.adj.send_bytes: - self.server.pull_trigger() - return num_bytes - return 0 - - def _flush_outbufs_below_high_watermark(self): - # check first to avoid locking if possible - if self.total_outbufs_len > self.adj.outbuf_high_watermark: - with self.outbuf_lock: - while ( - self.connected - and self.total_outbufs_len > self.adj.outbuf_high_watermark - ): - self.server.pull_trigger() - self.outbuf_lock.wait() - - def service(self): - """Execute all pending requests """ - with self.task_lock: - while self.requests: - request = self.requests[0] - if request.error: - task = self.error_task_class(self, request) - else: - task = self.task_class(self, request) - try: - task.service() - except ClientDisconnected: - self.logger.info( - "Client disconnected while serving %s" % task.request.path - ) - task.close_on_finish = True - except Exception: - self.logger.exception( - "Exception while serving %s" % task.request.path - ) - if not task.wrote_header: - if self.adj.expose_tracebacks: - body = traceback.format_exc() - else: - body = ( - "The server encountered an unexpected " - "internal server error" - ) - req_version = request.version - req_headers = request.headers - request = self.parser_class(self.adj) - request.error = InternalServerError(body) - # copy some original request attributes to fulfill - # HTTP 1.1 requirements - request.version = req_version - try: - request.headers["CONNECTION"] = req_headers["CONNECTION"] - except KeyError: - pass - task = self.error_task_class(self, request) - try: - task.service() # must not fail - except ClientDisconnected: - task.close_on_finish = True - else: - task.close_on_finish = True - # we cannot allow self.requests to drop to empty til - # here; otherwise the mainloop gets confused - if task.close_on_finish: - self.close_when_flushed = True - for request in self.requests: - request.close() - self.requests = [] - else: - # before processing a new request, ensure there is not too - # much data in the outbufs waiting to be flushed - # NB: currently readable() returns False while we are - # flushing data so we know no new requests will come in - # that we need to account for, otherwise it'd be better - # to do this check at 
the start of the request instead of - # at the end to account for consecutive service() calls - if len(self.requests) > 1: - self._flush_outbufs_below_high_watermark() - request = self.requests.pop(0) - request.close() - - if self.connected: - self.server.pull_trigger() - self.last_activity = time.time() - - def cancel(self): - """ Cancels all pending / active requests """ - self.will_close = True - self.connected = False - self.last_activity = time.time() - self.requests = [] diff --git a/libs/waitress/compat.py b/libs/waitress/compat.py deleted file mode 100644 index fe72a7610..000000000 --- a/libs/waitress/compat.py +++ /dev/null @@ -1,179 +0,0 @@ -import os -import sys -import types -import platform -import warnings - -try: - import urlparse -except ImportError: # pragma: no cover - from urllib import parse as urlparse - -try: - import fcntl -except ImportError: # pragma: no cover - fcntl = None # windows - -# True if we are running on Python 3. -PY2 = sys.version_info[0] == 2 -PY3 = sys.version_info[0] == 3 - -# True if we are running on Windows -WIN = platform.system() == "Windows" - -if PY3: # pragma: no cover - string_types = (str,) - integer_types = (int,) - class_types = (type,) - text_type = str - binary_type = bytes - long = int -else: - string_types = (basestring,) - integer_types = (int, long) - class_types = (type, types.ClassType) - text_type = unicode - binary_type = str - long = long - -if PY3: # pragma: no cover - from urllib.parse import unquote_to_bytes - - def unquote_bytes_to_wsgi(bytestring): - return unquote_to_bytes(bytestring).decode("latin-1") - - -else: - from urlparse import unquote as unquote_to_bytes - - def unquote_bytes_to_wsgi(bytestring): - return unquote_to_bytes(bytestring) - - -def text_(s, encoding="latin-1", errors="strict"): - """ If ``s`` is an instance of ``binary_type``, return - ``s.decode(encoding, errors)``, otherwise return ``s``""" - if isinstance(s, binary_type): - return s.decode(encoding, errors) - return s # pragma: no cover - - -if PY3: # pragma: no cover - - def tostr(s): - if isinstance(s, text_type): - s = s.encode("latin-1") - return str(s, "latin-1", "strict") - - def tobytes(s): - return bytes(s, "latin-1") - - -else: - tostr = str - - def tobytes(s): - return s - - -if PY3: # pragma: no cover - import builtins - - exec_ = getattr(builtins, "exec") - - def reraise(tp, value, tb=None): - if value is None: - value = tp - if value.__traceback__ is not tb: - raise value.with_traceback(tb) - raise value - - del builtins - -else: # pragma: no cover - - def exec_(code, globs=None, locs=None): - """Execute code in a namespace.""" - if globs is None: - frame = sys._getframe(1) - globs = frame.f_globals - if locs is None: - locs = frame.f_locals - del frame - elif locs is None: - locs = globs - exec("""exec code in globs, locs""") - - exec_( - """def reraise(tp, value, tb=None): - raise tp, value, tb -""" - ) - -try: - from StringIO import StringIO as NativeIO -except ImportError: # pragma: no cover - from io import StringIO as NativeIO - -try: - import httplib -except ImportError: # pragma: no cover - from http import client as httplib - -try: - MAXINT = sys.maxint -except AttributeError: # pragma: no cover - MAXINT = sys.maxsize - - -# Fix for issue reported in https://github.com/Pylons/waitress/issues/138, -# Python on Windows may not define IPPROTO_IPV6 in socket. 
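The compatibility shim that follows probes the socket module for the IPv6 constants and falls back to Windows' well-known numeric values (41 and 27) when they are missing. For context, IPV6_V6ONLY is what lets a server bind IPv4 and IPv6 wildcards separately; a minimal illustration, assuming a platform where these constants exist (not taken from the diff):

```python
import socket

# Restrict a socket to IPv6 only, so a separate IPv4 wildcard bind
# on the same port does not conflict with it.
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
sock.close()
```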
-import socket
-
-HAS_IPV6 = socket.has_ipv6
-
-if hasattr(socket, "IPPROTO_IPV6") and hasattr(socket, "IPV6_V6ONLY"):
-    IPPROTO_IPV6 = socket.IPPROTO_IPV6
-    IPV6_V6ONLY = socket.IPV6_V6ONLY
-else:  # pragma: no cover
-    if WIN:
-        IPPROTO_IPV6 = 41
-        IPV6_V6ONLY = 27
-    else:
-        warnings.warn(
-            "OS does not support required IPv6 socket flags. This is a requirement "
-            "for Waitress. Please open an issue at https://github.com/Pylons/waitress. "
-            "IPv6 support has been disabled.",
-            RuntimeWarning,
-        )
-        HAS_IPV6 = False
-
-
-def set_nonblocking(fd):  # pragma: no cover
-    if PY3 and sys.version_info[1] >= 5:
-        os.set_blocking(fd, False)
-    elif fcntl is None:
-        raise RuntimeError("no fcntl module present")
-    else:
-        flags = fcntl.fcntl(fd, fcntl.F_GETFL, 0)
-        flags = flags | os.O_NONBLOCK
-        fcntl.fcntl(fd, fcntl.F_SETFL, flags)
-
-
-if PY3:
-    ResourceWarning = ResourceWarning
-else:
-    ResourceWarning = UserWarning
-
-
-def qualname(cls):
-    if PY3:
-        return cls.__qualname__
-    return cls.__name__
-
-
-try:
-    import thread
-except ImportError:
-    # py3
-    import _thread as thread
diff --git a/libs/waitress/parser.py b/libs/waitress/parser.py
deleted file mode 100644
index fef8a3da6..000000000
--- a/libs/waitress/parser.py
+++ /dev/null
@@ -1,413 +0,0 @@
-##############################################################################
-#
-# Copyright (c) 2001, 2002 Zope Foundation and Contributors.
-# All Rights Reserved.
-#
-# This software is subject to the provisions of the Zope Public License,
-# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
-# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
-# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
-# FOR A PARTICULAR PURPOSE.
-#
-##############################################################################
-"""HTTP Request Parser
-
-This server uses asyncore to accept connections and do initial
-processing, but uses threads to do the work.
-"""
-import re
-from io import BytesIO
-
-from waitress.buffers import OverflowableBuffer
-from waitress.compat import tostr, unquote_bytes_to_wsgi, urlparse
-from waitress.receiver import ChunkedReceiver, FixedStreamReceiver
-from waitress.utilities import (
-    BadRequest,
-    RequestEntityTooLarge,
-    RequestHeaderFieldsTooLarge,
-    ServerNotImplemented,
-    find_double_newline,
-)
-from .rfc7230 import HEADER_FIELD
-
-
-class ParsingError(Exception):
-    pass
-
-
-class TransferEncodingNotImplemented(Exception):
-    pass
-
-class HTTPRequestParser(object):
-    """A structure that collects the HTTP request.
-
-    Once the stream is completed, the instance is passed to
-    a server task constructor.
-    """
-
-    completed = False  # Set once request is completed.
-    empty = False  # Set if no request was made.
-    expect_continue = False  # client sent "Expect: 100-continue" header
-    headers_finished = False  # True when headers have been read
-    header_plus = b""
-    chunked = False
-    content_length = 0
-    header_bytes_received = 0
-    body_bytes_received = 0
-    body_rcv = None
-    version = "1.0"
-    error = None
-    connection_close = False
-
-    # Other attributes: first_line, header, headers, command, uri, version,
-    # path, query, fragment
-
-    def __init__(self, adj):
-        """
-        adj is an Adjustments object.
-        """
-        # headers is a mapping containing keys translated to uppercase
-        # with dashes turned into underscores.
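For illustration, the translation that comment describes turns a wire-format header name into the CGI-style key the rest of the parser uses; a hypothetical one-liner equivalent to what parse_header does further down:

```python
# b"Content-Length" -> "CONTENT_LENGTH", matching
# key1 = tostr(key.upper().replace(b"-", b"_")) in parse_header below.
name = b"Content-Length"
key = name.upper().replace(b"-", b"_").decode("latin-1")
assert key == "CONTENT_LENGTH"
```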
-        self.headers = {}
-        self.adj = adj
-
-    def received(self, data):
-        """
-        Receives the HTTP stream for one request. Returns the number of
-        bytes consumed. Sets the completed flag once both the header and the
-        body have been received.
-        """
-        if self.completed:
-            return 0  # Can't consume any more.
-
-        datalen = len(data)
-        br = self.body_rcv
-        if br is None:
-            # In header.
-            max_header = self.adj.max_request_header_size
-
-            s = self.header_plus + data
-            index = find_double_newline(s)
-            consumed = 0
-
-            if index >= 0:
-                # If the headers have ended, and we also have part of the body
-                # message in data we still want to validate we aren't going
-                # over our limit for received headers.
-                self.header_bytes_received += index
-                consumed = datalen - (len(s) - index)
-            else:
-                self.header_bytes_received += datalen
-                consumed = datalen
-
-            # If the first line + headers is over the max length, we return a
-            # RequestHeaderFieldsTooLarge error rather than continuing to
-            # attempt to parse the headers.
-            if self.header_bytes_received >= max_header:
-                self.parse_header(b"GET / HTTP/1.0\r\n")
-                self.error = RequestHeaderFieldsTooLarge(
-                    "exceeds max_header of %s" % max_header
-                )
-                self.completed = True
-                return consumed
-
-            if index >= 0:
-                # Header finished.
-                header_plus = s[:index]
-
-                # Remove preceding blank lines. This is suggested by
-                # https://tools.ietf.org/html/rfc7230#section-3.5 to support
-                # clients sending an extra CR LF after another request when
-                # using HTTP pipelining
-                header_plus = header_plus.lstrip()
-
-                if not header_plus:
-                    self.empty = True
-                    self.completed = True
-                else:
-                    try:
-                        self.parse_header(header_plus)
-                    except ParsingError as e:
-                        self.error = BadRequest(e.args[0])
-                        self.completed = True
-                    except TransferEncodingNotImplemented as e:
-                        self.error = ServerNotImplemented(e.args[0])
-                        self.completed = True
-                    else:
-                        if self.body_rcv is None:
-                            # no content-length header and not a t-e: chunked
-                            # request
-                            self.completed = True
-
-                        if self.content_length > 0:
-                            max_body = self.adj.max_request_body_size
-                            # we won't accept this request if the content-length
-                            # is too large
-
-                            if self.content_length >= max_body:
-                                self.error = RequestEntityTooLarge(
-                                    "exceeds max_body of %s" % max_body
-                                )
-                                self.completed = True
-                self.headers_finished = True
-
-                return consumed
-
-            # Header not finished yet.
-            self.header_plus = s
-
-            return datalen
-        else:
-            # In body.
-            consumed = br.received(data)
-            self.body_bytes_received += consumed
-            max_body = self.adj.max_request_body_size
-
-            if self.body_bytes_received >= max_body:
-                # this will only be raised during t-e: chunked requests
-                self.error = RequestEntityTooLarge("exceeds max_body of %s" % max_body)
-                self.completed = True
-            elif br.error:
-                # garbage in chunked encoding input probably
-                self.error = br.error
-                self.completed = True
-            elif br.completed:
-                # The request (with the body) is ready to use.
-                self.completed = True
-
-                if self.chunked:
-                    # We've converted the chunked transfer encoding request
-                    # body into a normal request body, so we know its content
-                    # length; set the header here. We already popped the
-                    # TRANSFER_ENCODING header in parse_header, so this will
-                    # appear to the client to be an entirely non-chunked HTTP
-                    # request with a valid content-length.
-                    self.headers["CONTENT_LENGTH"] = str(br.__len__())
-
-            return consumed
-
-    def parse_header(self, header_plus):
-        """
-        Parses the header_plus block of text (the headers plus the
-        first line of the request).
- """ - index = header_plus.find(b"\r\n") - if index >= 0: - first_line = header_plus[:index].rstrip() - header = header_plus[index + 2 :] - else: - raise ParsingError("HTTP message header invalid") - - if b"\r" in first_line or b"\n" in first_line: - raise ParsingError("Bare CR or LF found in HTTP message") - - self.first_line = first_line # for testing - - lines = get_header_lines(header) - - headers = self.headers - for line in lines: - header = HEADER_FIELD.match(line) - - if not header: - raise ParsingError("Invalid header") - - key, value = header.group("name", "value") - - if b"_" in key: - # TODO(xistence): Should we drop this request instead? - continue - - # Only strip off whitespace that is considered valid whitespace by - # RFC7230, don't strip the rest - value = value.strip(b" \t") - key1 = tostr(key.upper().replace(b"-", b"_")) - # If a header already exists, we append subsequent values - # seperated by a comma. Applications already need to handle - # the comma seperated values, as HTTP front ends might do - # the concatenation for you (behavior specified in RFC2616). - try: - headers[key1] += tostr(b", " + value) - except KeyError: - headers[key1] = tostr(value) - - # command, uri, version will be bytes - command, uri, version = crack_first_line(first_line) - version = tostr(version) - command = tostr(command) - self.command = command - self.version = version - ( - self.proxy_scheme, - self.proxy_netloc, - self.path, - self.query, - self.fragment, - ) = split_uri(uri) - self.url_scheme = self.adj.url_scheme - connection = headers.get("CONNECTION", "") - - if version == "1.0": - if connection.lower() != "keep-alive": - self.connection_close = True - - if version == "1.1": - # since the server buffers data from chunked transfers and clients - # never need to deal with chunked requests, downstream clients - # should not see the HTTP_TRANSFER_ENCODING header; we pop it - # here - te = headers.pop("TRANSFER_ENCODING", "") - - # NB: We can not just call bare strip() here because it will also - # remove other non-printable characters that we explicitly do not - # want removed so that if someone attempts to smuggle a request - # with these characters we don't fall prey to it. - # - # For example \x85 is stripped by default, but it is not considered - # valid whitespace to be stripped by RFC7230. - encodings = [ - encoding.strip(" \t").lower() for encoding in te.split(",") if encoding - ] - - for encoding in encodings: - # Out of the transfer-codings listed in - # https://tools.ietf.org/html/rfc7230#section-4 we only support - # chunked at this time. - - # Note: the identity transfer-coding was removed in RFC7230: - # https://tools.ietf.org/html/rfc7230#appendix-A.2 and is thus - # not supported - if encoding not in {"chunked"}: - raise TransferEncodingNotImplemented( - "Transfer-Encoding requested is not supported." - ) - - if encodings and encodings[-1] == "chunked": - self.chunked = True - buf = OverflowableBuffer(self.adj.inbuf_overflow) - self.body_rcv = ChunkedReceiver(buf) - elif encodings: # pragma: nocover - raise TransferEncodingNotImplemented( - "Transfer-Encoding requested is not supported." 
-                )
-
-        expect = headers.get("EXPECT", "").lower()
-        self.expect_continue = expect == "100-continue"
-        if connection.lower() == "close":
-            self.connection_close = True
-
-        if not self.chunked:
-            try:
-                cl = int(headers.get("CONTENT_LENGTH", 0))
-            except ValueError:
-                raise ParsingError("Content-Length is invalid")
-
-            self.content_length = cl
-            if cl > 0:
-                buf = OverflowableBuffer(self.adj.inbuf_overflow)
-                self.body_rcv = FixedStreamReceiver(cl, buf)
-
-    def get_body_stream(self):
-        body_rcv = self.body_rcv
-        if body_rcv is not None:
-            return body_rcv.getfile()
-        else:
-            return BytesIO()
-
-    def close(self):
-        body_rcv = self.body_rcv
-        if body_rcv is not None:
-            body_rcv.getbuf().close()
-
-
-def split_uri(uri):
-    # urlsplit handles byte input by returning bytes on py3, so
-    # scheme, netloc, path, query, and fragment are bytes
-
-    scheme = netloc = path = query = fragment = b""
-
-    # urlsplit below will treat this as a scheme-less netloc, thereby losing
-    # the original intent of the request. Here we shamelessly stole 4 lines of
-    # code from the CPython stdlib to parse out the fragment and query but
-    # leave the path alone. See
-    # https://github.com/python/cpython/blob/8c9e9b0cd5b24dfbf1424d1f253d02de80e8f5ef/Lib/urllib/parse.py#L465-L468
-    # and https://github.com/Pylons/waitress/issues/260
-
-    if uri[:2] == b"//":
-        path = uri
-
-        if b"#" in path:
-            path, fragment = path.split(b"#", 1)
-
-        if b"?" in path:
-            path, query = path.split(b"?", 1)
-    else:
-        try:
-            scheme, netloc, path, query, fragment = urlparse.urlsplit(uri)
-        except UnicodeError:
-            raise ParsingError("Bad URI")
-
-    return (
-        tostr(scheme),
-        tostr(netloc),
-        unquote_bytes_to_wsgi(path),
-        tostr(query),
-        tostr(fragment),
-    )
-
-
-def get_header_lines(header):
-    """
-    Splits the header into lines, putting multi-line headers together.
-    """
-    r = []
-    lines = header.split(b"\r\n")
-    for line in lines:
-        if not line:
-            continue
-
-        if b"\r" in line or b"\n" in line:
-            raise ParsingError('Bare CR or LF found in header line "%s"' % tostr(line))
-
-        if line.startswith((b" ", b"\t")):
-            if not r:
-                # https://corte.si/posts/code/pathod/pythonservers/index.html
-                raise ParsingError('Malformed header line "%s"' % tostr(line))
-            r[-1] += line
-        else:
-            r.append(line)
-    return r
-
-
-first_line_re = re.compile(
-    b"([^ ]+) "
-    b"((?:[^ :?#]+://[^ ?#/]*(?::[0-9]{1,5})?)?[^ ]+)"
-    b"(( HTTP/([0-9.]+))$|$)"
-)
-
-
-def crack_first_line(line):
-    m = first_line_re.match(line)
-    if m is not None and m.end() == len(line):
-        if m.group(3):
-            version = m.group(5)
-        else:
-            version = b""
-        method = m.group(1)
-
-        # the request methods that are currently defined are all uppercase:
-        # https://www.iana.org/assignments/http-methods/http-methods.xhtml and
-        # the request method is case sensitive according to
-        # https://tools.ietf.org/html/rfc7231#section-4.1
-
-        # By disallowing anything but uppercase methods we save poor
-        # unsuspecting souls from sending lowercase HTTP methods to waitress
-        # and having the request complete, while servers like nginx drop the
-        # request onto the floor.
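The case check that follows enforces exactly that: any method that is not already uppercase is rejected. A small illustration of the predicate (hypothetical, not part of the file):

```python
# b"GET" passes; b"Get" and b"get" would raise
# ParsingError('Malformed HTTP method ...') in crack_first_line below.
for method in (b"GET", b"Get", b"get"):
    print(method, "ok" if method == method.upper() else "rejected")
```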
- if method != method.upper(): - raise ParsingError('Malformed HTTP method "%s"' % tostr(method)) - uri = m.group(2) - return method, uri, version - else: - return b"", b"", b"" diff --git a/libs/waitress/proxy_headers.py b/libs/waitress/proxy_headers.py deleted file mode 100644 index 1df8b8eba..000000000 --- a/libs/waitress/proxy_headers.py +++ /dev/null @@ -1,333 +0,0 @@ -from collections import namedtuple - -from .utilities import logger, undquote, BadRequest - - -PROXY_HEADERS = frozenset( - { - "X_FORWARDED_FOR", - "X_FORWARDED_HOST", - "X_FORWARDED_PROTO", - "X_FORWARDED_PORT", - "X_FORWARDED_BY", - "FORWARDED", - } -) - -Forwarded = namedtuple("Forwarded", ["by", "for_", "host", "proto"]) - - -class MalformedProxyHeader(Exception): - def __init__(self, header, reason, value): - self.header = header - self.reason = reason - self.value = value - super(MalformedProxyHeader, self).__init__(header, reason, value) - - -def proxy_headers_middleware( - app, - trusted_proxy=None, - trusted_proxy_count=1, - trusted_proxy_headers=None, - clear_untrusted=True, - log_untrusted=False, - logger=logger, -): - def translate_proxy_headers(environ, start_response): - untrusted_headers = PROXY_HEADERS - remote_peer = environ["REMOTE_ADDR"] - if trusted_proxy == "*" or remote_peer == trusted_proxy: - try: - untrusted_headers = parse_proxy_headers( - environ, - trusted_proxy_count=trusted_proxy_count, - trusted_proxy_headers=trusted_proxy_headers, - logger=logger, - ) - except MalformedProxyHeader as ex: - logger.warning( - 'Malformed proxy header "%s" from "%s": %s value: %s', - ex.header, - remote_peer, - ex.reason, - ex.value, - ) - error = BadRequest('Header "{0}" malformed.'.format(ex.header)) - return error.wsgi_response(environ, start_response) - - # Clear out the untrusted proxy headers - if clear_untrusted: - clear_untrusted_headers( - environ, untrusted_headers, log_warning=log_untrusted, logger=logger, - ) - - return app(environ, start_response) - - return translate_proxy_headers - - -def parse_proxy_headers( - environ, trusted_proxy_count, trusted_proxy_headers, logger=logger, -): - if trusted_proxy_headers is None: - trusted_proxy_headers = set() - - forwarded_for = [] - forwarded_host = forwarded_proto = forwarded_port = forwarded = "" - client_addr = None - untrusted_headers = set(PROXY_HEADERS) - - def raise_for_multiple_values(): - raise ValueError("Unspecified behavior for multiple values found in header",) - - if "x-forwarded-for" in trusted_proxy_headers and "HTTP_X_FORWARDED_FOR" in environ: - try: - forwarded_for = [] - - for forward_hop in environ["HTTP_X_FORWARDED_FOR"].split(","): - forward_hop = forward_hop.strip() - forward_hop = undquote(forward_hop) - - # Make sure that all IPv6 addresses are surrounded by brackets, - # this is assuming that the IPv6 representation here does not - # include a port number. - - if "." 
not in forward_hop and ( - ":" in forward_hop and forward_hop[-1] != "]" - ): - forwarded_for.append("[{}]".format(forward_hop)) - else: - forwarded_for.append(forward_hop) - - forwarded_for = forwarded_for[-trusted_proxy_count:] - client_addr = forwarded_for[0] - - untrusted_headers.remove("X_FORWARDED_FOR") - except Exception as ex: - raise MalformedProxyHeader( - "X-Forwarded-For", str(ex), environ["HTTP_X_FORWARDED_FOR"], - ) - - if ( - "x-forwarded-host" in trusted_proxy_headers - and "HTTP_X_FORWARDED_HOST" in environ - ): - try: - forwarded_host_multiple = [] - - for forward_host in environ["HTTP_X_FORWARDED_HOST"].split(","): - forward_host = forward_host.strip() - forward_host = undquote(forward_host) - forwarded_host_multiple.append(forward_host) - - forwarded_host_multiple = forwarded_host_multiple[-trusted_proxy_count:] - forwarded_host = forwarded_host_multiple[0] - - untrusted_headers.remove("X_FORWARDED_HOST") - except Exception as ex: - raise MalformedProxyHeader( - "X-Forwarded-Host", str(ex), environ["HTTP_X_FORWARDED_HOST"], - ) - - if "x-forwarded-proto" in trusted_proxy_headers: - try: - forwarded_proto = undquote(environ.get("HTTP_X_FORWARDED_PROTO", "")) - if "," in forwarded_proto: - raise_for_multiple_values() - untrusted_headers.remove("X_FORWARDED_PROTO") - except Exception as ex: - raise MalformedProxyHeader( - "X-Forwarded-Proto", str(ex), environ["HTTP_X_FORWARDED_PROTO"], - ) - - if "x-forwarded-port" in trusted_proxy_headers: - try: - forwarded_port = undquote(environ.get("HTTP_X_FORWARDED_PORT", "")) - if "," in forwarded_port: - raise_for_multiple_values() - untrusted_headers.remove("X_FORWARDED_PORT") - except Exception as ex: - raise MalformedProxyHeader( - "X-Forwarded-Port", str(ex), environ["HTTP_X_FORWARDED_PORT"], - ) - - if "x-forwarded-by" in trusted_proxy_headers: - # Waitress itself does not use X-Forwarded-By, but we can not - # remove it so it can get set in the environ - untrusted_headers.remove("X_FORWARDED_BY") - - if "forwarded" in trusted_proxy_headers: - forwarded = environ.get("HTTP_FORWARDED", None) - untrusted_headers = PROXY_HEADERS - {"FORWARDED"} - - # If the Forwarded header exists, it gets priority - if forwarded: - proxies = [] - try: - for forwarded_element in forwarded.split(","): - # Remove whitespace that may have been introduced when - # appending a new entry - forwarded_element = forwarded_element.strip() - - forwarded_for = forwarded_host = forwarded_proto = "" - forwarded_port = forwarded_by = "" - - for pair in forwarded_element.split(";"): - pair = pair.lower() - - if not pair: - continue - - token, equals, value = pair.partition("=") - - if equals != "=": - raise ValueError('Invalid forwarded-pair missing "="') - - if token.strip() != token: - raise ValueError("Token may not be surrounded by whitespace") - - if value.strip() != value: - raise ValueError("Value may not be surrounded by whitespace") - - if token == "by": - forwarded_by = undquote(value) - - elif token == "for": - forwarded_for = undquote(value) - - elif token == "host": - forwarded_host = undquote(value) - - elif token == "proto": - forwarded_proto = undquote(value) - - else: - logger.warning("Unknown Forwarded token: %s" % token) - - proxies.append( - Forwarded( - forwarded_by, forwarded_for, forwarded_host, forwarded_proto - ) - ) - except Exception as ex: - raise MalformedProxyHeader( - "Forwarded", str(ex), environ["HTTP_FORWARDED"], - ) - - proxies = proxies[-trusted_proxy_count:] - - # Iterate backwards and fill in some values, the oldest entry 
that - # contains the information we expect is the one we use. We expect - # that intermediate proxies may re-write the host header or proto, - # but the oldest entry is the one that contains the information the - # client expects when generating URL's - # - # Forwarded: for="[2001:db8::1]";host="example.com:8443";proto="https" - # Forwarded: for=192.0.2.1;host="example.internal:8080" - # - # (After HTTPS header folding) should mean that we use as values: - # - # Host: example.com - # Protocol: https - # Port: 8443 - - for proxy in proxies[::-1]: - client_addr = proxy.for_ or client_addr - forwarded_host = proxy.host or forwarded_host - forwarded_proto = proxy.proto or forwarded_proto - - if forwarded_proto: - forwarded_proto = forwarded_proto.lower() - - if forwarded_proto not in {"http", "https"}: - raise MalformedProxyHeader( - "Forwarded Proto=" if forwarded else "X-Forwarded-Proto", - "unsupported proto value", - forwarded_proto, - ) - - # Set the URL scheme to the proxy provided proto - environ["wsgi.url_scheme"] = forwarded_proto - - if not forwarded_port: - if forwarded_proto == "http": - forwarded_port = "80" - - if forwarded_proto == "https": - forwarded_port = "443" - - if forwarded_host: - if ":" in forwarded_host and forwarded_host[-1] != "]": - host, port = forwarded_host.rsplit(":", 1) - host, port = host.strip(), str(port) - - # We trust the port in the Forwarded Host/X-Forwarded-Host over - # X-Forwarded-Port, or whatever we got from Forwarded - # Proto/X-Forwarded-Proto. - - if forwarded_port != port: - forwarded_port = port - - # We trust the proxy server's forwarded Host - environ["SERVER_NAME"] = host - environ["HTTP_HOST"] = forwarded_host - else: - # We trust the proxy server's forwarded Host - environ["SERVER_NAME"] = forwarded_host - environ["HTTP_HOST"] = forwarded_host - - if forwarded_port: - if forwarded_port not in {"443", "80"}: - environ["HTTP_HOST"] = "{}:{}".format( - forwarded_host, forwarded_port - ) - elif forwarded_port == "80" and environ["wsgi.url_scheme"] != "http": - environ["HTTP_HOST"] = "{}:{}".format( - forwarded_host, forwarded_port - ) - elif forwarded_port == "443" and environ["wsgi.url_scheme"] != "https": - environ["HTTP_HOST"] = "{}:{}".format( - forwarded_host, forwarded_port - ) - - if forwarded_port: - environ["SERVER_PORT"] = str(forwarded_port) - - if client_addr: - if ":" in client_addr and client_addr[-1] != "]": - addr, port = client_addr.rsplit(":", 1) - environ["REMOTE_ADDR"] = strip_brackets(addr.strip()) - environ["REMOTE_PORT"] = port.strip() - else: - environ["REMOTE_ADDR"] = strip_brackets(client_addr.strip()) - environ["REMOTE_HOST"] = environ["REMOTE_ADDR"] - - return untrusted_headers - - -def strip_brackets(addr): - if addr[0] == "[" and addr[-1] == "]": - return addr[1:-1] - return addr - - -def clear_untrusted_headers( - environ, untrusted_headers, log_warning=False, logger=logger -): - untrusted_headers_removed = [ - header - for header in untrusted_headers - if environ.pop("HTTP_" + header, False) is not False - ] - - if log_warning and untrusted_headers_removed: - untrusted_headers_removed = [ - "-".join(x.capitalize() for x in header.split("_")) - for header in untrusted_headers_removed - ] - logger.warning( - "Removed untrusted headers (%s). 
Waitress recommends these be " - "removed upstream.", - ", ".join(untrusted_headers_removed), - ) diff --git a/libs/waitress/receiver.py b/libs/waitress/receiver.py deleted file mode 100644 index 5d1568d51..000000000 --- a/libs/waitress/receiver.py +++ /dev/null @@ -1,186 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001, 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""Data Chunk Receiver -""" - -from waitress.utilities import BadRequest, find_double_newline - - -class FixedStreamReceiver(object): - - # See IStreamConsumer - completed = False - error = None - - def __init__(self, cl, buf): - self.remain = cl - self.buf = buf - - def __len__(self): - return self.buf.__len__() - - def received(self, data): - "See IStreamConsumer" - rm = self.remain - - if rm < 1: - self.completed = True # Avoid any chance of spinning - - return 0 - datalen = len(data) - - if rm <= datalen: - self.buf.append(data[:rm]) - self.remain = 0 - self.completed = True - - return rm - else: - self.buf.append(data) - self.remain -= datalen - - return datalen - - def getfile(self): - return self.buf.getfile() - - def getbuf(self): - return self.buf - - -class ChunkedReceiver(object): - - chunk_remainder = 0 - validate_chunk_end = False - control_line = b"" - chunk_end = b"" - all_chunks_received = False - trailer = b"" - completed = False - error = None - - # max_control_line = 1024 - # max_trailer = 65536 - - def __init__(self, buf): - self.buf = buf - - def __len__(self): - return self.buf.__len__() - - def received(self, s): - # Returns the number of bytes consumed. - - if self.completed: - return 0 - orig_size = len(s) - - while s: - rm = self.chunk_remainder - - if rm > 0: - # Receive the remainder of a chunk. - to_write = s[:rm] - self.buf.append(to_write) - written = len(to_write) - s = s[written:] - - self.chunk_remainder -= written - - if self.chunk_remainder == 0: - self.validate_chunk_end = True - elif self.validate_chunk_end: - s = self.chunk_end + s - - pos = s.find(b"\r\n") - - if pos < 0 and len(s) < 2: - self.chunk_end = s - s = b"" - else: - self.chunk_end = b"" - if pos == 0: - # Chop off the terminating CR LF from the chunk - s = s[2:] - else: - self.error = BadRequest("Chunk not properly terminated") - self.all_chunks_received = True - - # Always exit this loop - self.validate_chunk_end = False - elif not self.all_chunks_received: - # Receive a control line. - s = self.control_line + s - pos = s.find(b"\r\n") - - if pos < 0: - # Control line not finished. - self.control_line = s - s = b"" - else: - # Control line finished. - line = s[:pos] - s = s[pos + 2 :] - self.control_line = b"" - line = line.strip() - - if line: - # Begin a new chunk. - semi = line.find(b";") - - if semi >= 0: - # discard extension info. - line = line[:semi] - try: - sz = int(line.strip(), 16) # hexadecimal - except ValueError: # garbage in input - self.error = BadRequest("garbage in chunked encoding input") - sz = 0 - - if sz > 0: - # Start a new chunk. 
-                            self.chunk_remainder = sz
-                        else:
-                            # Finished chunks.
-                            self.all_chunks_received = True
-                    # else expect a control line.
-            else:
-                # Receive the trailer.
-                trailer = self.trailer + s
-
-                if trailer.startswith(b"\r\n"):
-                    # No trailer.
-                    self.completed = True
-
-                    return orig_size - (len(trailer) - 2)
-                pos = find_double_newline(trailer)
-
-                if pos < 0:
-                    # Trailer not finished.
-                    self.trailer = trailer
-                    s = b""
-                else:
-                    # Finished the trailer.
-                    self.completed = True
-                    self.trailer = trailer[:pos]
-
-                    return orig_size - (len(trailer) - pos)
-
-        return orig_size
-
-    def getfile(self):
-        return self.buf.getfile()
-
-    def getbuf(self):
-        return self.buf
diff --git a/libs/waitress/rfc7230.py b/libs/waitress/rfc7230.py
deleted file mode 100644
index cd33c9064..000000000
--- a/libs/waitress/rfc7230.py
+++ /dev/null
@@ -1,52 +0,0 @@
-"""
-This contains a bunch of RFC7230 definitions and regular expressions that are
-needed to properly parse HTTP messages.
-"""
-
-import re
-
-from .compat import tobytes
-
-WS = "[ \t]"
-OWS = WS + "{0,}?"
-RWS = WS + "{1,}?"
-BWS = OWS
-
-# RFC 7230 Section 3.2.6 "Field Value Components":
-# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*"
-#       / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~"
-#       / DIGIT / ALPHA
-# obs-text = %x80-FF
-TCHAR = r"[!#$%&'*+\-.^_`|~0-9A-Za-z]"
-OBS_TEXT = r"\x80-\xff"
-
-TOKEN = TCHAR + "{1,}"
-
-# RFC 5234 Appendix B.1 "Core Rules":
-# VCHAR = %x21-7E
-#       ; visible (printing) characters
-VCHAR = r"\x21-\x7e"
-
-# header-field = field-name ":" OWS field-value OWS
-# field-name = token
-# field-value = *( field-content / obs-fold )
-# field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ]
-# field-vchar = VCHAR / obs-text
-
-# Errata from: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189
-# changes field-content to:
-#
-# field-content = field-vchar [ 1*( SP / HTAB / field-vchar )
-#                 field-vchar ]
-
-FIELD_VCHAR = "[" + VCHAR + OBS_TEXT + "]"
-# Field content is more greedy than the ABNF, in that it will match the whole value
-FIELD_CONTENT = FIELD_VCHAR + "+(?:[ \t]+" + FIELD_VCHAR + "+)*"
-# Which allows the field value here to just see if there is even a value in the first place
-FIELD_VALUE = "(?:" + FIELD_CONTENT + ")?"
-
-HEADER_FIELD = re.compile(
-    tobytes(
-        "^(?P<name>" + TOKEN + "):" + OWS + "(?P<value>" + FIELD_VALUE + ")" + OWS + "$"
-    )
-)
diff --git a/libs/waitress/runner.py b/libs/waitress/runner.py
deleted file mode 100644
index 2495084f0..000000000
--- a/libs/waitress/runner.py
+++ /dev/null
@@ -1,286 +0,0 @@
-##############################################################################
-#
-# Copyright (c) 2013 Zope Foundation and Contributors.
-# All Rights Reserved.
-#
-# This software is subject to the provisions of the Zope Public License,
-# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
-# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
-# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
-# FOR A PARTICULAR PURPOSE.
-#
-##############################################################################
-"""Command line runner.
-"""
-
-from __future__ import print_function, unicode_literals
-
-import getopt
-import os
-import os.path
-import re
-import sys
-
-from waitress import serve
-from waitress.adjustments import Adjustments
-
-HELP = """\
-Usage:
-
-    {0} [OPTS] MODULE:OBJECT
-
-Standard options:
-
-    --help
-        Show this information.
-
-    --call
-        Call the given object to get the WSGI application.
-
-    --host=ADDR
-        Hostname or IP address on which to listen, default is '0.0.0.0',
-        which means "all IP addresses on this host".
-
-        Note: May not be used together with --listen
-
-    --port=PORT
-        TCP port on which to listen, default is '8080'
-
-        Note: May not be used together with --listen
-
-    --listen=ip:port
-        Tell waitress to listen on an ip port combination.
-
-        Example:
-
-            --listen=127.0.0.1:8080
-            --listen=[::1]:8080
-            --listen=*:8080
-
-        This option may be used multiple times to listen on multiple sockets.
-        A wildcard for the hostname is also supported and will bind to both
-        IPv4/IPv6 depending on whether they are enabled or disabled.
-
-    --[no-]ipv4
-        Toggle on/off IPv4 support.
-
-        Example:
-
-            --no-ipv4
-
-        This will disable IPv4 socket support. This affects wildcard matching
-        when generating the list of sockets.
-
-    --[no-]ipv6
-        Toggle on/off IPv6 support.
-
-        Example:
-
-            --no-ipv6
-
-        This will turn off IPv6 socket support. This affects wildcard matching
-        when generating a list of sockets.
-
-    --unix-socket=PATH
-        Path of Unix socket. If a socket path is specified, a Unix domain
-        socket is made instead of the usual inet domain socket.
-
-        Not available on Windows.
-
-    --unix-socket-perms=PERMS
-        Octal permissions to use for the Unix domain socket, default is
-        '600'.
-
-    --url-scheme=STR
-        Default wsgi.url_scheme value, default is 'http'.
-
-    --url-prefix=STR
-        The ``SCRIPT_NAME`` WSGI environment value. Setting this to anything
-        except the empty string will cause the WSGI ``SCRIPT_NAME`` value to be
-        the value passed minus any trailing slashes you add, and it will cause
-        the ``PATH_INFO`` of any request which is prefixed with this value to
-        be stripped of the prefix. Default is the empty string.
-
-    --ident=STR
-        Server identity used in the 'Server' header in responses. Default
-        is 'waitress'.
-
-Tuning options:
-
-    --threads=INT
-        Number of threads used to process application logic, default is 4.
-
-    --backlog=INT
-        Connection backlog for the server. Default is 1024.
-
-    --recv-bytes=INT
-        Number of bytes to request when calling socket.recv(). Default is
-        8192.
-
-    --send-bytes=INT
-        Number of bytes to send to socket.send(). Default is 18000.
-        Multiples of 9000 should avoid partly-filled TCP packets.
-
-    --outbuf-overflow=INT
-        A temporary file should be created if the pending output is larger
-        than this. Default is 1048576 (1MB).
-
-    --outbuf-high-watermark=INT
-        The app_iter will pause when pending output is larger than this value
-        and will resume once enough data is written to the socket to fall below
-        this threshold. Default is 16777216 (16MB).
-
-    --inbuf-overflow=INT
-        A temporary file should be created if the pending input is larger
-        than this. Default is 524288 (512KB).
-
-    --connection-limit=INT
-        Stop creating new channels if too many are already active.
-        Default is 100.
-
-    --cleanup-interval=INT
-        Minimum seconds between cleaning up inactive channels. Default
-        is 30. See '--channel-timeout'.
-
-    --channel-timeout=INT
-        Maximum number of seconds to leave inactive connections open.
-        Default is 120. 'Inactive' is defined as 'has received no data
-        from the client and has sent no data to the client'.
-
-    --[no-]log-socket-errors
-        Toggle whether premature client disconnect tracebacks ought to be
-        logged. On by default.
-
-    --max-request-header-size=INT
-        Maximum size of all request headers combined. Default is 262144
-        (256KB).
- - --max-request-body-size=INT - Maximum size of request body. Default is 1073741824 (1GB). - - --[no-]expose-tracebacks - Toggle whether to expose tracebacks of unhandled exceptions to the - client. Off by default. - - --asyncore-loop-timeout=INT - The timeout value in seconds passed to asyncore.loop(). Default is 1. - - --asyncore-use-poll - The use_poll argument passed to ``asyncore.loop()``. Helps overcome - open file descriptors limit. Default is False. - -""" - -RUNNER_PATTERN = re.compile( - r""" - ^ - (?P - [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)* - ) - : - (?P - [a-z_][a-z0-9_]*(?:\.[a-z_][a-z0-9_]*)* - ) - $ - """, - re.I | re.X, -) - - -def match(obj_name): - matches = RUNNER_PATTERN.match(obj_name) - if not matches: - raise ValueError("Malformed application '{0}'".format(obj_name)) - return matches.group("module"), matches.group("object") - - -def resolve(module_name, object_name): - """Resolve a named object in a module.""" - # We cast each segments due to an issue that has been found to manifest - # in Python 2.6.6, but not 2.6.8, and may affect other revisions of Python - # 2.6 and 2.7, whereby ``__import__`` chokes if the list passed in the - # ``fromlist`` argument are unicode strings rather than 8-bit strings. - # The error triggered is "TypeError: Item in ``fromlist '' not a string". - # My guess is that this was fixed by checking against ``basestring`` - # rather than ``str`` sometime between the release of 2.6.6 and 2.6.8, - # but I've yet to go over the commits. I know, however, that the NEWS - # file makes no mention of such a change to the behaviour of - # ``__import__``. - segments = [str(segment) for segment in object_name.split(".")] - obj = __import__(module_name, fromlist=segments[:1]) - for segment in segments: - obj = getattr(obj, segment) - return obj - - -def show_help(stream, name, error=None): # pragma: no cover - if error is not None: - print("Error: {0}\n".format(error), file=stream) - print(HELP.format(name), file=stream) - - -def show_exception(stream): - exc_type, exc_value = sys.exc_info()[:2] - args = getattr(exc_value, "args", None) - print( - ("There was an exception ({0}) importing your module.\n").format( - exc_type.__name__, - ), - file=stream, - ) - if args: - print("It had these arguments: ", file=stream) - for idx, arg in enumerate(args, start=1): - print("{0}. {1}\n".format(idx, arg), file=stream) - else: - print("It had no arguments.", file=stream) - - -def run(argv=sys.argv, _serve=serve): - """Command line runner.""" - name = os.path.basename(argv[0]) - - try: - kw, args = Adjustments.parse_args(argv[1:]) - except getopt.GetoptError as exc: - show_help(sys.stderr, name, str(exc)) - return 1 - - if kw["help"]: - show_help(sys.stdout, name) - return 0 - - if len(args) != 1: - show_help(sys.stderr, name, "Specify one application only") - return 1 - - try: - module, obj_name = match(args[0]) - except ValueError as exc: - show_help(sys.stderr, name, str(exc)) - show_exception(sys.stderr) - return 1 - - # Add the current directory onto sys.path - sys.path.append(os.getcwd()) - - # Get the WSGI function. - try: - app = resolve(module, obj_name) - except ImportError: - show_help(sys.stderr, name, "Bad module '{0}'".format(module)) - show_exception(sys.stderr) - return 1 - except AttributeError: - show_help(sys.stderr, name, "Bad object name '{0}'".format(obj_name)) - show_exception(sys.stderr) - return 1 - if kw["call"]: - app = app() - - # These arguments are specific to the runner, not waitress itself. 
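
# --- Editor's aside: the runner's MODULE:OBJECT resolution, condensed. This
# sketch uses importlib in place of the __import__ dance above; the fromlist
# workaround it documents targeted Python 2.6 and is moot on modern Pythons.
import importlib

def resolve(spec):
    """Resolve 'package.module:attr.attr' to the named object."""
    module_name, _, object_name = spec.partition(":")
    obj = importlib.import_module(module_name)
    for segment in object_name.split("."):
        obj = getattr(obj, segment)
    return obj

# e.g. resolve("wsgiref.simple_server:demo_app") yields a ready-made WSGI app.
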
- del kw["call"], kw["help"] - - _serve(app, **kw) - return 0 diff --git a/libs/waitress/server.py b/libs/waitress/server.py deleted file mode 100644 index ae566994f..000000000 --- a/libs/waitress/server.py +++ /dev/null @@ -1,436 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001, 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## - -import os -import os.path -import socket -import time - -from waitress import trigger -from waitress.adjustments import Adjustments -from waitress.channel import HTTPChannel -from waitress.task import ThreadedTaskDispatcher -from waitress.utilities import cleanup_unix_socket - -from waitress.compat import ( - IPPROTO_IPV6, - IPV6_V6ONLY, -) -from . import wasyncore -from .proxy_headers import proxy_headers_middleware - - -def create_server( - application, - map=None, - _start=True, # test shim - _sock=None, # test shim - _dispatcher=None, # test shim - **kw # adjustments -): - """ - if __name__ == '__main__': - server = create_server(app) - server.run() - """ - if application is None: - raise ValueError( - 'The "app" passed to ``create_server`` was ``None``. You forgot ' - "to return a WSGI app within your application." - ) - adj = Adjustments(**kw) - - if map is None: # pragma: nocover - map = {} - - dispatcher = _dispatcher - if dispatcher is None: - dispatcher = ThreadedTaskDispatcher() - dispatcher.set_thread_count(adj.threads) - - if adj.unix_socket and hasattr(socket, "AF_UNIX"): - sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) - return UnixWSGIServer( - application, - map, - _start, - _sock, - dispatcher=dispatcher, - adj=adj, - sockinfo=sockinfo, - ) - - effective_listen = [] - last_serv = None - if not adj.sockets: - for sockinfo in adj.listen: - # When TcpWSGIServer is called, it registers itself in the map. This - # side-effect is all we need it for, so we don't store a reference to - # or return it to the user. 
- last_serv = TcpWSGIServer( - application, - map, - _start, - _sock, - dispatcher=dispatcher, - adj=adj, - sockinfo=sockinfo, - ) - effective_listen.append( - (last_serv.effective_host, last_serv.effective_port) - ) - - for sock in adj.sockets: - sockinfo = (sock.family, sock.type, sock.proto, sock.getsockname()) - if sock.family == socket.AF_INET or sock.family == socket.AF_INET6: - last_serv = TcpWSGIServer( - application, - map, - _start, - sock, - dispatcher=dispatcher, - adj=adj, - bind_socket=False, - sockinfo=sockinfo, - ) - effective_listen.append( - (last_serv.effective_host, last_serv.effective_port) - ) - elif hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX: - last_serv = UnixWSGIServer( - application, - map, - _start, - sock, - dispatcher=dispatcher, - adj=adj, - bind_socket=False, - sockinfo=sockinfo, - ) - effective_listen.append( - (last_serv.effective_host, last_serv.effective_port) - ) - - # We are running a single server, so we can just return the last server, - # saves us from having to create one more object - if len(effective_listen) == 1: - # In this case we have no need to use a MultiSocketServer - return last_serv - - # Return a class that has a utility function to print out the sockets it's - # listening on, and has a .run() function. All of the TcpWSGIServers - # registered themselves in the map above. - return MultiSocketServer(map, adj, effective_listen, dispatcher) - - -# This class is only ever used if we have multiple listen sockets. It allows -# the serve() API to call .run() which starts the wasyncore loop, and catches -# SystemExit/KeyboardInterrupt so that it can atempt to cleanly shut down. -class MultiSocketServer(object): - asyncore = wasyncore # test shim - - def __init__( - self, map=None, adj=None, effective_listen=None, dispatcher=None, - ): - self.adj = adj - self.map = map - self.effective_listen = effective_listen - self.task_dispatcher = dispatcher - - def print_listen(self, format_str): # pragma: nocover - for l in self.effective_listen: - l = list(l) - - if ":" in l[0]: - l[0] = "[{}]".format(l[0]) - - print(format_str.format(*l)) - - def run(self): - try: - self.asyncore.loop( - timeout=self.adj.asyncore_loop_timeout, - map=self.map, - use_poll=self.adj.asyncore_use_poll, - ) - except (SystemExit, KeyboardInterrupt): - self.close() - - def close(self): - self.task_dispatcher.shutdown() - wasyncore.close_all(self.map) - - -class BaseWSGIServer(wasyncore.dispatcher, object): - - channel_class = HTTPChannel - next_channel_cleanup = 0 - socketmod = socket # test shim - asyncore = wasyncore # test shim - - def __init__( - self, - application, - map=None, - _start=True, # test shim - _sock=None, # test shim - dispatcher=None, # dispatcher - adj=None, # adjustments - sockinfo=None, # opaque object - bind_socket=True, - **kw - ): - if adj is None: - adj = Adjustments(**kw) - - if adj.trusted_proxy or adj.clear_untrusted_proxy_headers: - # wrap the application to deal with proxy headers - # we wrap it here because webtest subclasses the TcpWSGIServer - # directly and thus doesn't run any code that's in create_server - application = proxy_headers_middleware( - application, - trusted_proxy=adj.trusted_proxy, - trusted_proxy_count=adj.trusted_proxy_count, - trusted_proxy_headers=adj.trusted_proxy_headers, - clear_untrusted=adj.clear_untrusted_proxy_headers, - log_untrusted=adj.log_untrusted_proxy_headers, - logger=self.logger, - ) - - if map is None: - # use a nonglobal socket map by default to hopefully prevent - # conflicts with apps and 
libs that use the wasyncore global socket - # map ala https://github.com/Pylons/waitress/issues/63 - map = {} - if sockinfo is None: - sockinfo = adj.listen[0] - - self.sockinfo = sockinfo - self.family = sockinfo[0] - self.socktype = sockinfo[1] - self.application = application - self.adj = adj - self.trigger = trigger.trigger(map) - if dispatcher is None: - dispatcher = ThreadedTaskDispatcher() - dispatcher.set_thread_count(self.adj.threads) - - self.task_dispatcher = dispatcher - self.asyncore.dispatcher.__init__(self, _sock, map=map) - if _sock is None: - self.create_socket(self.family, self.socktype) - if self.family == socket.AF_INET6: # pragma: nocover - self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1) - - self.set_reuse_addr() - - if bind_socket: - self.bind_server_socket() - - self.effective_host, self.effective_port = self.getsockname() - self.server_name = self.get_server_name(self.effective_host) - self.active_channels = {} - if _start: - self.accept_connections() - - def bind_server_socket(self): - raise NotImplementedError # pragma: no cover - - def get_server_name(self, ip): - """Given an IP or hostname, try to determine the server name.""" - - if not ip: - raise ValueError("Requires an IP to get the server name") - - server_name = str(ip) - - # If we are bound to all IP's, just return the current hostname, only - # fall-back to "localhost" if we fail to get the hostname - if server_name == "0.0.0.0" or server_name == "::": - try: - return str(self.socketmod.gethostname()) - except (socket.error, UnicodeDecodeError): # pragma: no cover - # We also deal with UnicodeDecodeError in case of Windows with - # non-ascii hostname - return "localhost" - - # Now let's try and convert the IP address to a proper hostname - try: - server_name = self.socketmod.gethostbyaddr(server_name)[0] - except (socket.error, UnicodeDecodeError): # pragma: no cover - # We also deal with UnicodeDecodeError in case of Windows with - # non-ascii hostname - pass - - # If it contains an IPv6 literal, make sure to surround it with - # brackets - if ":" in server_name and "[" not in server_name: - server_name = "[{}]".format(server_name) - - return server_name - - def getsockname(self): - raise NotImplementedError # pragma: no cover - - def accept_connections(self): - self.accepting = True - self.socket.listen(self.adj.backlog) # Get around asyncore NT limit - - def add_task(self, task): - self.task_dispatcher.add_task(task) - - def readable(self): - now = time.time() - if now >= self.next_channel_cleanup: - self.next_channel_cleanup = now + self.adj.cleanup_interval - self.maintenance(now) - return self.accepting and len(self._map) < self.adj.connection_limit - - def writable(self): - return False - - def handle_read(self): - pass - - def handle_connect(self): - pass - - def handle_accept(self): - try: - v = self.accept() - if v is None: - return - conn, addr = v - except socket.error: - # Linux: On rare occasions we get a bogus socket back from - # accept. socketmodule.c:makesockaddr complains that the - # address family is unknown. We don't want the whole server - # to shut down because of this. 
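
# --- Editor's aside: the hostname fallback implemented by get_server_name()
# above, reduced to its decision tree. A sketch, not the original code:
import socket

def server_name_for(ip):
    if ip in ("0.0.0.0", "::"):              # bound to all interfaces
        try:
            return socket.gethostname()
        except OSError:
            return "localhost"
    try:
        name = socket.gethostbyaddr(ip)[0]   # reverse lookup, best effort
    except (OSError, UnicodeDecodeError):
        name = ip
    # IPv6 literals get bracketed so they are usable in URLs
    return "[{}]".format(name) if ":" in name and "[" not in name else name
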
- if self.adj.log_socket_errors: - self.logger.warning("server accept() threw an exception", exc_info=True) - return - self.set_socket_options(conn) - addr = self.fix_addr(addr) - self.channel_class(self, conn, addr, self.adj, map=self._map) - - def run(self): - try: - self.asyncore.loop( - timeout=self.adj.asyncore_loop_timeout, - map=self._map, - use_poll=self.adj.asyncore_use_poll, - ) - except (SystemExit, KeyboardInterrupt): - self.task_dispatcher.shutdown() - - def pull_trigger(self): - self.trigger.pull_trigger() - - def set_socket_options(self, conn): - pass - - def fix_addr(self, addr): - return addr - - def maintenance(self, now): - """ - Closes channels that have not had any activity in a while. - - The timeout is configured through adj.channel_timeout (seconds). - """ - cutoff = now - self.adj.channel_timeout - for channel in self.active_channels.values(): - if (not channel.requests) and channel.last_activity < cutoff: - channel.will_close = True - - def print_listen(self, format_str): # pragma: nocover - print(format_str.format(self.effective_host, self.effective_port)) - - def close(self): - self.trigger.close() - return wasyncore.dispatcher.close(self) - - -class TcpWSGIServer(BaseWSGIServer): - def bind_server_socket(self): - (_, _, _, sockaddr) = self.sockinfo - self.bind(sockaddr) - - def getsockname(self): - try: - return self.socketmod.getnameinfo( - self.socket.getsockname(), self.socketmod.NI_NUMERICSERV - ) - except: # pragma: no cover - # This only happens on Linux because a DNS issue is considered a - # temporary failure that will raise (even when NI_NAMEREQD is not - # set). Instead we try again, but this time we just ask for the - # numerichost and the numericserv (port) and return those. It is - # better than nothing. - return self.socketmod.getnameinfo( - self.socket.getsockname(), - self.socketmod.NI_NUMERICHOST | self.socketmod.NI_NUMERICSERV, - ) - - def set_socket_options(self, conn): - for (level, optname, value) in self.adj.socket_options: - conn.setsockopt(level, optname, value) - - -if hasattr(socket, "AF_UNIX"): - - class UnixWSGIServer(BaseWSGIServer): - def __init__( - self, - application, - map=None, - _start=True, # test shim - _sock=None, # test shim - dispatcher=None, # dispatcher - adj=None, # adjustments - sockinfo=None, # opaque object - **kw - ): - if sockinfo is None: - sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) - - super(UnixWSGIServer, self).__init__( - application, - map=map, - _start=_start, - _sock=_sock, - dispatcher=dispatcher, - adj=adj, - sockinfo=sockinfo, - **kw - ) - - def bind_server_socket(self): - cleanup_unix_socket(self.adj.unix_socket) - self.bind(self.adj.unix_socket) - if os.path.exists(self.adj.unix_socket): - os.chmod(self.adj.unix_socket, self.adj.unix_socket_perms) - - def getsockname(self): - return ("unix", self.socket.getsockname()) - - def fix_addr(self, addr): - return ("localhost", None) - - def get_server_name(self, ip): - return "localhost" - - -# Compatibility alias. -WSGIServer = TcpWSGIServer diff --git a/libs/waitress/task.py b/libs/waitress/task.py deleted file mode 100644 index 8e7ab1888..000000000 --- a/libs/waitress/task.py +++ /dev/null @@ -1,570 +0,0 @@ -############################################################################## -# -# Copyright (c) 2001, 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. 
-# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## - -import socket -import sys -import threading -import time -from collections import deque - -from .buffers import ReadOnlyFileBasedBuffer -from .compat import reraise, tobytes -from .utilities import build_http_date, logger, queue_logger - -rename_headers = { # or keep them without the HTTP_ prefix added - "CONTENT_LENGTH": "CONTENT_LENGTH", - "CONTENT_TYPE": "CONTENT_TYPE", -} - -hop_by_hop = frozenset( - ( - "connection", - "keep-alive", - "proxy-authenticate", - "proxy-authorization", - "te", - "trailers", - "transfer-encoding", - "upgrade", - ) -) - - -class ThreadedTaskDispatcher(object): - """A Task Dispatcher that creates a thread for each task. - """ - - stop_count = 0 # Number of threads that will stop soon. - active_count = 0 # Number of currently active threads - logger = logger - queue_logger = queue_logger - - def __init__(self): - self.threads = set() - self.queue = deque() - self.lock = threading.Lock() - self.queue_cv = threading.Condition(self.lock) - self.thread_exit_cv = threading.Condition(self.lock) - - def start_new_thread(self, target, args): - t = threading.Thread(target=target, name="waitress", args=args) - t.daemon = True - t.start() - - def handler_thread(self, thread_no): - while True: - with self.lock: - while not self.queue and self.stop_count == 0: - # Mark ourselves as idle before waiting to be - # woken up, then we will once again be active - self.active_count -= 1 - self.queue_cv.wait() - self.active_count += 1 - - if self.stop_count > 0: - self.active_count -= 1 - self.stop_count -= 1 - self.threads.discard(thread_no) - self.thread_exit_cv.notify() - break - - task = self.queue.popleft() - try: - task.service() - except BaseException: - self.logger.exception("Exception when servicing %r", task) - - def set_thread_count(self, count): - with self.lock: - threads = self.threads - thread_no = 0 - running = len(threads) - self.stop_count - while running < count: - # Start threads. - while thread_no in threads: - thread_no = thread_no + 1 - threads.add(thread_no) - running += 1 - self.start_new_thread(self.handler_thread, (thread_no,)) - self.active_count += 1 - thread_no = thread_no + 1 - if running > count: - # Stop threads. - self.stop_count += running - count - self.queue_cv.notify_all() - - def add_task(self, task): - with self.lock: - self.queue.append(task) - self.queue_cv.notify() - queue_size = len(self.queue) - idle_threads = len(self.threads) - self.stop_count - self.active_count - if queue_size > idle_threads: - self.queue_logger.warning( - "Task queue depth is %d", queue_size - idle_threads - ) - - def shutdown(self, cancel_pending=True, timeout=5): - self.set_thread_count(0) - # Ensure the threads shut down. - threads = self.threads - expiration = time.time() + timeout - with self.lock: - while threads: - if time.time() >= expiration: - self.logger.warning("%d thread(s) still running", len(threads)) - break - self.thread_exit_cv.wait(0.1) - if cancel_pending: - # Cancel remaining tasks. 
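
# --- Editor's aside: at its core, ThreadedTaskDispatcher above is a classic
# condition-variable worker pool. A minimal sketch of the same shape, without
# the thread-count resizing and shutdown bookkeeping:
import threading
from collections import deque

class MiniDispatcher:
    def __init__(self, threads=4):
        self.queue = deque()
        self.lock = threading.Lock()
        self.queue_cv = threading.Condition(self.lock)
        for _ in range(threads):
            threading.Thread(target=self._worker, daemon=True).start()

    def _worker(self):
        while True:
            with self.lock:
                while not self.queue:
                    self.queue_cv.wait()   # sleep until add_task notifies
                task = self.queue.popleft()
            task()                         # service the task outside the lock

    def add_task(self, task):
        with self.lock:
            self.queue.append(task)
            self.queue_cv.notify()
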
- queue = self.queue - if len(queue) > 0: - self.logger.warning("Canceling %d pending task(s)", len(queue)) - while queue: - task = queue.popleft() - task.cancel() - self.queue_cv.notify_all() - return True - return False - - -class Task(object): - close_on_finish = False - status = "200 OK" - wrote_header = False - start_time = 0 - content_length = None - content_bytes_written = 0 - logged_write_excess = False - logged_write_no_body = False - complete = False - chunked_response = False - logger = logger - - def __init__(self, channel, request): - self.channel = channel - self.request = request - self.response_headers = [] - version = request.version - if version not in ("1.0", "1.1"): - # fall back to a version we support. - version = "1.0" - self.version = version - - def service(self): - try: - try: - self.start() - self.execute() - self.finish() - except socket.error: - self.close_on_finish = True - if self.channel.adj.log_socket_errors: - raise - finally: - pass - - @property - def has_body(self): - return not ( - self.status.startswith("1") - or self.status.startswith("204") - or self.status.startswith("304") - ) - - def build_response_header(self): - version = self.version - # Figure out whether the connection should be closed. - connection = self.request.headers.get("CONNECTION", "").lower() - response_headers = [] - content_length_header = None - date_header = None - server_header = None - connection_close_header = None - - for (headername, headerval) in self.response_headers: - headername = "-".join([x.capitalize() for x in headername.split("-")]) - - if headername == "Content-Length": - if self.has_body: - content_length_header = headerval - else: - continue # pragma: no cover - - if headername == "Date": - date_header = headerval - - if headername == "Server": - server_header = headerval - - if headername == "Connection": - connection_close_header = headerval.lower() - # replace with properly capitalized version - response_headers.append((headername, headerval)) - - if ( - content_length_header is None - and self.content_length is not None - and self.has_body - ): - content_length_header = str(self.content_length) - response_headers.append(("Content-Length", content_length_header)) - - def close_on_finish(): - if connection_close_header is None: - response_headers.append(("Connection", "close")) - self.close_on_finish = True - - if version == "1.0": - if connection == "keep-alive": - if not content_length_header: - close_on_finish() - else: - response_headers.append(("Connection", "Keep-Alive")) - else: - close_on_finish() - - elif version == "1.1": - if connection == "close": - close_on_finish() - - if not content_length_header: - # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length - # for any response with a status code of 1xx, 204 or 304. - - if self.has_body: - response_headers.append(("Transfer-Encoding", "chunked")) - self.chunked_response = True - - if not self.close_on_finish: - close_on_finish() - - # under HTTP 1.1 keep-alive is default, no need to set the header - else: - raise AssertionError("neither HTTP/1.0 or HTTP/1.1") - - # Set the Server and Date field, if not yet specified. This is needed - # if the server is used as a proxy. 
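
# --- Editor's aside: two framing rules from the Task response code in this
# region, restated as checkable one-liners (a sketch, not the original methods):
def has_body(status):
    # 1xx, 204 and 304 responses must not carry a message body
    return not (status.startswith("1")
                or status.startswith("204")
                or status.startswith("304"))

assert has_body("200 OK") and not has_body("304 Not Modified")

def encode_chunk(data):
    # hex length, CRLF, payload, CRLF: one Transfer-Encoding: chunked frame
    return hex(len(data))[2:].upper().encode("ascii") + b"\r\n" + data + b"\r\n"

assert encode_chunk(b"hello") == b"5\r\nhello\r\n"
assert encode_chunk(b"") == b"0\r\n\r\n"  # zero-length chunk terminates the body
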
- ident = self.channel.server.adj.ident - - if not server_header: - if ident: - response_headers.append(("Server", ident)) - else: - response_headers.append(("Via", ident or "waitress")) - - if not date_header: - response_headers.append(("Date", build_http_date(self.start_time))) - - self.response_headers = response_headers - - first_line = "HTTP/%s %s" % (self.version, self.status) - # NB: sorting headers needs to preserve same-named-header order - # as per RFC 2616 section 4.2; thus the key=lambda x: x[0] here; - # rely on stable sort to keep relative position of same-named headers - next_lines = [ - "%s: %s" % hv for hv in sorted(self.response_headers, key=lambda x: x[0]) - ] - lines = [first_line] + next_lines - res = "%s\r\n\r\n" % "\r\n".join(lines) - - return tobytes(res) - - def remove_content_length_header(self): - response_headers = [] - - for header_name, header_value in self.response_headers: - if header_name.lower() == "content-length": - continue # pragma: nocover - response_headers.append((header_name, header_value)) - - self.response_headers = response_headers - - def start(self): - self.start_time = time.time() - - def finish(self): - if not self.wrote_header: - self.write(b"") - if self.chunked_response: - # not self.write, it will chunk it! - self.channel.write_soon(b"0\r\n\r\n") - - def write(self, data): - if not self.complete: - raise RuntimeError("start_response was not called before body written") - channel = self.channel - if not self.wrote_header: - rh = self.build_response_header() - channel.write_soon(rh) - self.wrote_header = True - - if data and self.has_body: - towrite = data - cl = self.content_length - if self.chunked_response: - # use chunked encoding response - towrite = tobytes(hex(len(data))[2:].upper()) + b"\r\n" - towrite += data + b"\r\n" - elif cl is not None: - towrite = data[: cl - self.content_bytes_written] - self.content_bytes_written += len(towrite) - if towrite != data and not self.logged_write_excess: - self.logger.warning( - "application-written content exceeded the number of " - "bytes specified by Content-Length header (%s)" % cl - ) - self.logged_write_excess = True - if towrite: - channel.write_soon(towrite) - elif data: - # Cheat, and tell the application we have written all of the bytes, - # even though the response shouldn't have a body and we are - # ignoring it entirely. - self.content_bytes_written += len(data) - - if not self.logged_write_no_body: - self.logger.warning( - "application-written content was ignored due to HTTP " - "response that may not contain a message-body: (%s)" % self.status - ) - self.logged_write_no_body = True - - -class ErrorTask(Task): - """ An error task produces an error response - """ - - complete = True - - def execute(self): - e = self.request.error - status, headers, body = e.to_response() - self.status = status - self.response_headers.extend(headers) - # We need to explicitly tell the remote client we are closing the - # connection, because self.close_on_finish is set, and we are going to - # slam the door in the clients face. - self.response_headers.append(("Connection", "close")) - self.close_on_finish = True - self.content_length = len(body) - self.write(tobytes(body)) - - -class WSGITask(Task): - """A WSGI task produces a response from a WSGI application. 
- """ - - environ = None - - def execute(self): - environ = self.get_environment() - - def start_response(status, headers, exc_info=None): - if self.complete and not exc_info: - raise AssertionError( - "start_response called a second time without providing exc_info." - ) - if exc_info: - try: - if self.wrote_header: - # higher levels will catch and handle raised exception: - # 1. "service" method in task.py - # 2. "service" method in channel.py - # 3. "handler_thread" method in task.py - reraise(exc_info[0], exc_info[1], exc_info[2]) - else: - # As per WSGI spec existing headers must be cleared - self.response_headers = [] - finally: - exc_info = None - - self.complete = True - - if not status.__class__ is str: - raise AssertionError("status %s is not a string" % status) - if "\n" in status or "\r" in status: - raise ValueError( - "carriage return/line feed character present in status" - ) - - self.status = status - - # Prepare the headers for output - for k, v in headers: - if not k.__class__ is str: - raise AssertionError( - "Header name %r is not a string in %r" % (k, (k, v)) - ) - if not v.__class__ is str: - raise AssertionError( - "Header value %r is not a string in %r" % (v, (k, v)) - ) - - if "\n" in v or "\r" in v: - raise ValueError( - "carriage return/line feed character present in header value" - ) - if "\n" in k or "\r" in k: - raise ValueError( - "carriage return/line feed character present in header name" - ) - - kl = k.lower() - if kl == "content-length": - self.content_length = int(v) - elif kl in hop_by_hop: - raise AssertionError( - '%s is a "hop-by-hop" header; it cannot be used by ' - "a WSGI application (see PEP 3333)" % k - ) - - self.response_headers.extend(headers) - - # Return a method used to write the response data. - return self.write - - # Call the application to handle the request and write a response - app_iter = self.channel.server.application(environ, start_response) - - can_close_app_iter = True - try: - if app_iter.__class__ is ReadOnlyFileBasedBuffer: - cl = self.content_length - size = app_iter.prepare(cl) - if size: - if cl != size: - if cl is not None: - self.remove_content_length_header() - self.content_length = size - self.write(b"") # generate headers - # if the write_soon below succeeds then the channel will - # take over closing the underlying file via the channel's - # _flush_some or handle_close so we intentionally avoid - # calling close in the finally block - self.channel.write_soon(app_iter) - can_close_app_iter = False - return - - first_chunk_len = None - for chunk in app_iter: - if first_chunk_len is None: - first_chunk_len = len(chunk) - # Set a Content-Length header if one is not supplied. 
- # start_response may not have been called until first - # iteration as per PEP, so we must reinterrogate - # self.content_length here - if self.content_length is None: - app_iter_len = None - if hasattr(app_iter, "__len__"): - app_iter_len = len(app_iter) - if app_iter_len == 1: - self.content_length = first_chunk_len - # transmit headers only after first iteration of the iterable - # that returns a non-empty bytestring (PEP 3333) - if chunk: - self.write(chunk) - - cl = self.content_length - if cl is not None: - if self.content_bytes_written != cl: - # close the connection so the client isn't sitting around - # waiting for more data when there are too few bytes - # to service content-length - self.close_on_finish = True - if self.request.command != "HEAD": - self.logger.warning( - "application returned too few bytes (%s) " - "for specified Content-Length (%s) via app_iter" - % (self.content_bytes_written, cl), - ) - finally: - if can_close_app_iter and hasattr(app_iter, "close"): - app_iter.close() - - def get_environment(self): - """Returns a WSGI environment.""" - environ = self.environ - if environ is not None: - # Return the cached copy. - return environ - - request = self.request - path = request.path - channel = self.channel - server = channel.server - url_prefix = server.adj.url_prefix - - if path.startswith("/"): - # strip extra slashes at the beginning of a path that starts - # with any number of slashes - path = "/" + path.lstrip("/") - - if url_prefix: - # NB: url_prefix is guaranteed by the configuration machinery to - # be either the empty string or a string that starts with a single - # slash and ends without any slashes - if path == url_prefix: - # if the path is the same as the url prefix, the SCRIPT_NAME - # should be the url_prefix and PATH_INFO should be empty - path = "" - else: - # if the path starts with the url prefix plus a slash, - # the SCRIPT_NAME should be the url_prefix and PATH_INFO should - # the value of path from the slash until its end - url_prefix_with_trailing_slash = url_prefix + "/" - if path.startswith(url_prefix_with_trailing_slash): - path = path[len(url_prefix) :] - - environ = { - "REMOTE_ADDR": channel.addr[0], - # Nah, we aren't actually going to look up the reverse DNS for - # REMOTE_ADDR, but we will happily set this environment variable - # for the WSGI application. Spec says we can just set this to - # REMOTE_ADDR, so we do. 
- "REMOTE_HOST": channel.addr[0], - # try and set the REMOTE_PORT to something useful, but maybe None - "REMOTE_PORT": str(channel.addr[1]), - "REQUEST_METHOD": request.command.upper(), - "SERVER_PORT": str(server.effective_port), - "SERVER_NAME": server.server_name, - "SERVER_SOFTWARE": server.adj.ident, - "SERVER_PROTOCOL": "HTTP/%s" % self.version, - "SCRIPT_NAME": url_prefix, - "PATH_INFO": path, - "QUERY_STRING": request.query, - "wsgi.url_scheme": request.url_scheme, - # the following environment variables are required by the WSGI spec - "wsgi.version": (1, 0), - # apps should use the logging module - "wsgi.errors": sys.stderr, - "wsgi.multithread": True, - "wsgi.multiprocess": False, - "wsgi.run_once": False, - "wsgi.input": request.get_body_stream(), - "wsgi.file_wrapper": ReadOnlyFileBasedBuffer, - "wsgi.input_terminated": True, # wsgi.input is EOF terminated - } - - for key, value in dict(request.headers).items(): - value = value.strip() - mykey = rename_headers.get(key, None) - if mykey is None: - mykey = "HTTP_" + key - if mykey not in environ: - environ[mykey] = value - - # cache the environ for this request - self.environ = environ - return environ diff --git a/libs/waitress/tests/__init__.py b/libs/waitress/tests/__init__.py deleted file mode 100644 index b711d3609..000000000 --- a/libs/waitress/tests/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# -# This file is necessary to make this directory a package. diff --git a/libs/waitress/tests/fixtureapps/__init__.py b/libs/waitress/tests/fixtureapps/__init__.py deleted file mode 100644 index f215a2b90..000000000 --- a/libs/waitress/tests/fixtureapps/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# package (for -m) diff --git a/libs/waitress/tests/fixtureapps/badcl.py b/libs/waitress/tests/fixtureapps/badcl.py deleted file mode 100644 index 24067de41..000000000 --- a/libs/waitress/tests/fixtureapps/badcl.py +++ /dev/null @@ -1,11 +0,0 @@ -def app(environ, start_response): # pragma: no cover - body = b"abcdefghi" - cl = len(body) - if environ["PATH_INFO"] == "/short_body": - cl = len(body) + 1 - if environ["PATH_INFO"] == "/long_body": - cl = len(body) - 1 - start_response( - "200 OK", [("Content-Length", str(cl)), ("Content-Type", "text/plain")] - ) - return [body] diff --git a/libs/waitress/tests/fixtureapps/echo.py b/libs/waitress/tests/fixtureapps/echo.py deleted file mode 100644 index 813bdacea..000000000 --- a/libs/waitress/tests/fixtureapps/echo.py +++ /dev/null @@ -1,56 +0,0 @@ -from collections import namedtuple -import json - - -def app_body_only(environ, start_response): # pragma: no cover - cl = environ.get("CONTENT_LENGTH", None) - if cl is not None: - cl = int(cl) - body = environ["wsgi.input"].read(cl) - cl = str(len(body)) - start_response("200 OK", [("Content-Length", cl), ("Content-Type", "text/plain"),]) - return [body] - - -def app(environ, start_response): # pragma: no cover - cl = environ.get("CONTENT_LENGTH", None) - if cl is not None: - cl = int(cl) - request_body = environ["wsgi.input"].read(cl) - cl = str(len(request_body)) - meta = { - "method": environ["REQUEST_METHOD"], - "path_info": environ["PATH_INFO"], - "script_name": environ["SCRIPT_NAME"], - "query_string": environ["QUERY_STRING"], - "content_length": cl, - "scheme": environ["wsgi.url_scheme"], - "remote_addr": environ["REMOTE_ADDR"], - "remote_host": environ["REMOTE_HOST"], - "server_port": environ["SERVER_PORT"], - "server_name": environ["SERVER_NAME"], - "headers": { - k[len("HTTP_") :]: v for k, v in environ.items() if k.startswith("HTTP_") - }, - } 
- response = json.dumps(meta).encode("utf8") + b"\r\n\r\n" + request_body - start_response( - "200 OK", - [("Content-Length", str(len(response))), ("Content-Type", "text/plain"),], - ) - return [response] - - -Echo = namedtuple( - "Echo", - ( - "method path_info script_name query_string content_length scheme " - "remote_addr remote_host server_port server_name headers body" - ), -) - - -def parse_response(response): - meta, body = response.split(b"\r\n\r\n", 1) - meta = json.loads(meta.decode("utf8")) - return Echo(body=body, **meta) diff --git a/libs/waitress/tests/fixtureapps/error.py b/libs/waitress/tests/fixtureapps/error.py deleted file mode 100644 index 5afb1c542..000000000 --- a/libs/waitress/tests/fixtureapps/error.py +++ /dev/null @@ -1,21 +0,0 @@ -def app(environ, start_response): # pragma: no cover - cl = environ.get("CONTENT_LENGTH", None) - if cl is not None: - cl = int(cl) - body = environ["wsgi.input"].read(cl) - cl = str(len(body)) - if environ["PATH_INFO"] == "/before_start_response": - raise ValueError("wrong") - write = start_response( - "200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")] - ) - if environ["PATH_INFO"] == "/after_write_cb": - write("abc") - if environ["PATH_INFO"] == "/in_generator": - - def foo(): - yield "abc" - raise ValueError - - return foo() - raise ValueError("wrong") diff --git a/libs/waitress/tests/fixtureapps/filewrapper.py b/libs/waitress/tests/fixtureapps/filewrapper.py deleted file mode 100644 index 63df5a6dc..000000000 --- a/libs/waitress/tests/fixtureapps/filewrapper.py +++ /dev/null @@ -1,93 +0,0 @@ -import io -import os - -here = os.path.dirname(os.path.abspath(__file__)) -fn = os.path.join(here, "groundhog1.jpg") - - -class KindaFilelike(object): # pragma: no cover - def __init__(self, bytes): - self.bytes = bytes - - def read(self, n): - bytes = self.bytes[:n] - self.bytes = self.bytes[n:] - return bytes - - -class UnseekableIOBase(io.RawIOBase): # pragma: no cover - def __init__(self, bytes): - self.buf = io.BytesIO(bytes) - - def writable(self): - return False - - def readable(self): - return True - - def seekable(self): - return False - - def read(self, n): - return self.buf.read(n) - - -def app(environ, start_response): # pragma: no cover - path_info = environ["PATH_INFO"] - if path_info.startswith("/filelike"): - f = open(fn, "rb") - f.seek(0, 2) - cl = f.tell() - f.seek(0) - if path_info == "/filelike": - headers = [ - ("Content-Length", str(cl)), - ("Content-Type", "image/jpeg"), - ] - elif path_info == "/filelike_nocl": - headers = [("Content-Type", "image/jpeg")] - elif path_info == "/filelike_shortcl": - # short content length - headers = [ - ("Content-Length", "1"), - ("Content-Type", "image/jpeg"), - ] - else: - # long content length (/filelike_longcl) - headers = [ - ("Content-Length", str(cl + 10)), - ("Content-Type", "image/jpeg"), - ] - else: - with open(fn, "rb") as fp: - data = fp.read() - cl = len(data) - f = KindaFilelike(data) - if path_info == "/notfilelike": - headers = [ - ("Content-Length", str(len(data))), - ("Content-Type", "image/jpeg"), - ] - elif path_info == "/notfilelike_iobase": - headers = [ - ("Content-Length", str(len(data))), - ("Content-Type", "image/jpeg"), - ] - f = UnseekableIOBase(data) - elif path_info == "/notfilelike_nocl": - headers = [("Content-Type", "image/jpeg")] - elif path_info == "/notfilelike_shortcl": - # short content length - headers = [ - ("Content-Length", "1"), - ("Content-Type", "image/jpeg"), - ] - else: - # long content length (/notfilelike_longcl) - headers = 
[ - ("Content-Length", str(cl + 10)), - ("Content-Type", "image/jpeg"), - ] - - start_response("200 OK", headers) - return environ["wsgi.file_wrapper"](f, 8192) diff --git a/libs/waitress/tests/fixtureapps/getline.py b/libs/waitress/tests/fixtureapps/getline.py deleted file mode 100644 index 5e0ad3ae5..000000000 --- a/libs/waitress/tests/fixtureapps/getline.py +++ /dev/null @@ -1,17 +0,0 @@ -import sys - -if __name__ == "__main__": - try: - from urllib.request import urlopen, URLError - except ImportError: - from urllib2 import urlopen, URLError - - url = sys.argv[1] - headers = {"Content-Type": "text/plain; charset=utf-8"} - try: - resp = urlopen(url) - line = resp.readline().decode("ascii") # py3 - except URLError: - line = "failed to read %s" % url - sys.stdout.write(line) - sys.stdout.flush() diff --git a/libs/waitress/tests/fixtureapps/groundhog1.jpg b/libs/waitress/tests/fixtureapps/groundhog1.jpg deleted file mode 100644 index 90f610ea0..000000000 Binary files a/libs/waitress/tests/fixtureapps/groundhog1.jpg and /dev/null differ diff --git a/libs/waitress/tests/fixtureapps/nocl.py b/libs/waitress/tests/fixtureapps/nocl.py deleted file mode 100644 index f82bba0c8..000000000 --- a/libs/waitress/tests/fixtureapps/nocl.py +++ /dev/null @@ -1,23 +0,0 @@ -def chunks(l, n): # pragma: no cover - """ Yield successive n-sized chunks from l. - """ - for i in range(0, len(l), n): - yield l[i : i + n] - - -def gen(body): # pragma: no cover - for chunk in chunks(body, 10): - yield chunk - - -def app(environ, start_response): # pragma: no cover - cl = environ.get("CONTENT_LENGTH", None) - if cl is not None: - cl = int(cl) - body = environ["wsgi.input"].read(cl) - start_response("200 OK", [("Content-Type", "text/plain")]) - if environ["PATH_INFO"] == "/list": - return [body] - if environ["PATH_INFO"] == "/list_lentwo": - return [body[0:1], body[1:]] - return gen(body) diff --git a/libs/waitress/tests/fixtureapps/runner.py b/libs/waitress/tests/fixtureapps/runner.py deleted file mode 100644 index 1d66ad1cc..000000000 --- a/libs/waitress/tests/fixtureapps/runner.py +++ /dev/null @@ -1,6 +0,0 @@ -def app(): # pragma: no cover - return None - - -def returns_app(): # pragma: no cover - return app diff --git a/libs/waitress/tests/fixtureapps/sleepy.py b/libs/waitress/tests/fixtureapps/sleepy.py deleted file mode 100644 index 2d171d8be..000000000 --- a/libs/waitress/tests/fixtureapps/sleepy.py +++ /dev/null @@ -1,12 +0,0 @@ -import time - - -def app(environ, start_response): # pragma: no cover - if environ["PATH_INFO"] == "/sleepy": - time.sleep(2) - body = b"sleepy returned" - else: - body = b"notsleepy returned" - cl = str(len(body)) - start_response("200 OK", [("Content-Length", cl), ("Content-Type", "text/plain")]) - return [body] diff --git a/libs/waitress/tests/fixtureapps/toolarge.py b/libs/waitress/tests/fixtureapps/toolarge.py deleted file mode 100644 index a0f36d2cc..000000000 --- a/libs/waitress/tests/fixtureapps/toolarge.py +++ /dev/null @@ -1,7 +0,0 @@ -def app(environ, start_response): # pragma: no cover - body = b"abcdef" - cl = len(body) - start_response( - "200 OK", [("Content-Length", str(cl)), ("Content-Type", "text/plain")] - ) - return [body] diff --git a/libs/waitress/tests/fixtureapps/writecb.py b/libs/waitress/tests/fixtureapps/writecb.py deleted file mode 100644 index e1d2792e6..000000000 --- a/libs/waitress/tests/fixtureapps/writecb.py +++ /dev/null @@ -1,14 +0,0 @@ -def app(environ, start_response): # pragma: no cover - path_info = environ["PATH_INFO"] - if path_info == 
"/no_content_length": - headers = [] - else: - headers = [("Content-Length", "9")] - write = start_response("200 OK", headers) - if path_info == "/long_body": - write(b"abcdefghij") - elif path_info == "/short_body": - write(b"abcdefgh") - else: - write(b"abcdefghi") - return [] diff --git a/libs/waitress/tests/test_adjustments.py b/libs/waitress/tests/test_adjustments.py deleted file mode 100644 index 303c1aa3a..000000000 --- a/libs/waitress/tests/test_adjustments.py +++ /dev/null @@ -1,481 +0,0 @@ -import sys -import socket -import warnings - -from waitress.compat import ( - PY2, - WIN, -) - -if sys.version_info[:2] == (2, 6): # pragma: no cover - import unittest2 as unittest -else: # pragma: no cover - import unittest - - -class Test_asbool(unittest.TestCase): - def _callFUT(self, s): - from waitress.adjustments import asbool - - return asbool(s) - - def test_s_is_None(self): - result = self._callFUT(None) - self.assertEqual(result, False) - - def test_s_is_True(self): - result = self._callFUT(True) - self.assertEqual(result, True) - - def test_s_is_False(self): - result = self._callFUT(False) - self.assertEqual(result, False) - - def test_s_is_true(self): - result = self._callFUT("True") - self.assertEqual(result, True) - - def test_s_is_false(self): - result = self._callFUT("False") - self.assertEqual(result, False) - - def test_s_is_yes(self): - result = self._callFUT("yes") - self.assertEqual(result, True) - - def test_s_is_on(self): - result = self._callFUT("on") - self.assertEqual(result, True) - - def test_s_is_1(self): - result = self._callFUT(1) - self.assertEqual(result, True) - - -class Test_as_socket_list(unittest.TestCase): - def test_only_sockets_in_list(self): - from waitress.adjustments import as_socket_list - - sockets = [ - socket.socket(socket.AF_INET, socket.SOCK_STREAM), - socket.socket(socket.AF_INET6, socket.SOCK_STREAM), - ] - if hasattr(socket, "AF_UNIX"): - sockets.append(socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)) - new_sockets = as_socket_list(sockets) - self.assertEqual(sockets, new_sockets) - for sock in sockets: - sock.close() - - def test_not_only_sockets_in_list(self): - from waitress.adjustments import as_socket_list - - sockets = [ - socket.socket(socket.AF_INET, socket.SOCK_STREAM), - socket.socket(socket.AF_INET6, socket.SOCK_STREAM), - {"something": "else"}, - ] - new_sockets = as_socket_list(sockets) - self.assertEqual(new_sockets, [sockets[0], sockets[1]]) - for sock in [sock for sock in sockets if isinstance(sock, socket.socket)]: - sock.close() - - -class TestAdjustments(unittest.TestCase): - def _hasIPv6(self): # pragma: nocover - if not socket.has_ipv6: - return False - - try: - socket.getaddrinfo( - "::1", - 0, - socket.AF_UNSPEC, - socket.SOCK_STREAM, - socket.IPPROTO_TCP, - socket.AI_PASSIVE | socket.AI_ADDRCONFIG, - ) - - return True - except socket.gaierror as e: - # Check to see what the error is - if e.errno == socket.EAI_ADDRFAMILY: - return False - else: - raise e - - def _makeOne(self, **kw): - from waitress.adjustments import Adjustments - - return Adjustments(**kw) - - def test_goodvars(self): - inst = self._makeOne( - host="localhost", - port="8080", - threads="5", - trusted_proxy="192.168.1.1", - trusted_proxy_headers={"forwarded"}, - trusted_proxy_count=2, - log_untrusted_proxy_headers=True, - url_scheme="https", - backlog="20", - recv_bytes="200", - send_bytes="300", - outbuf_overflow="400", - inbuf_overflow="500", - connection_limit="1000", - cleanup_interval="1100", - channel_timeout="1200", - log_socket_errors="true", 
- max_request_header_size="1300", - max_request_body_size="1400", - expose_tracebacks="true", - ident="abc", - asyncore_loop_timeout="5", - asyncore_use_poll=True, - unix_socket_perms="777", - url_prefix="///foo/", - ipv4=True, - ipv6=False, - ) - - self.assertEqual(inst.host, "localhost") - self.assertEqual(inst.port, 8080) - self.assertEqual(inst.threads, 5) - self.assertEqual(inst.trusted_proxy, "192.168.1.1") - self.assertEqual(inst.trusted_proxy_headers, {"forwarded"}) - self.assertEqual(inst.trusted_proxy_count, 2) - self.assertEqual(inst.log_untrusted_proxy_headers, True) - self.assertEqual(inst.url_scheme, "https") - self.assertEqual(inst.backlog, 20) - self.assertEqual(inst.recv_bytes, 200) - self.assertEqual(inst.send_bytes, 300) - self.assertEqual(inst.outbuf_overflow, 400) - self.assertEqual(inst.inbuf_overflow, 500) - self.assertEqual(inst.connection_limit, 1000) - self.assertEqual(inst.cleanup_interval, 1100) - self.assertEqual(inst.channel_timeout, 1200) - self.assertEqual(inst.log_socket_errors, True) - self.assertEqual(inst.max_request_header_size, 1300) - self.assertEqual(inst.max_request_body_size, 1400) - self.assertEqual(inst.expose_tracebacks, True) - self.assertEqual(inst.asyncore_loop_timeout, 5) - self.assertEqual(inst.asyncore_use_poll, True) - self.assertEqual(inst.ident, "abc") - self.assertEqual(inst.unix_socket_perms, 0o777) - self.assertEqual(inst.url_prefix, "/foo") - self.assertEqual(inst.ipv4, True) - self.assertEqual(inst.ipv6, False) - - bind_pairs = [ - sockaddr[:2] - for (family, _, _, sockaddr) in inst.listen - if family == socket.AF_INET - ] - - # On Travis, somehow we start listening to two sockets when resolving - # localhost... - self.assertEqual(("127.0.0.1", 8080), bind_pairs[0]) - - def test_goodvar_listen(self): - inst = self._makeOne(listen="127.0.0.1") - - bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen] - - self.assertEqual(bind_pairs, [("127.0.0.1", 8080)]) - - def test_default_listen(self): - inst = self._makeOne() - - bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen] - - self.assertEqual(bind_pairs, [("0.0.0.0", 8080)]) - - def test_multiple_listen(self): - inst = self._makeOne(listen="127.0.0.1:9090 127.0.0.1:8080") - - bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] - - self.assertEqual(bind_pairs, [("127.0.0.1", 9090), ("127.0.0.1", 8080)]) - - def test_wildcard_listen(self): - inst = self._makeOne(listen="*:8080") - - bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] - - self.assertTrue(len(bind_pairs) >= 1) - - def test_ipv6_no_port(self): # pragma: nocover - if not self._hasIPv6(): - return - - inst = self._makeOne(listen="[::1]") - - bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] - - self.assertEqual(bind_pairs, [("::1", 8080)]) - - def test_bad_port(self): - self.assertRaises(ValueError, self._makeOne, listen="127.0.0.1:test") - - def test_service_port(self): - if WIN and PY2: # pragma: no cover - # On Windows and Python 2 this is broken, so we raise a ValueError - self.assertRaises( - ValueError, self._makeOne, listen="127.0.0.1:http", - ) - return - - inst = self._makeOne(listen="127.0.0.1:http 0.0.0.0:https") - - bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] - - self.assertEqual(bind_pairs, [("127.0.0.1", 80), ("0.0.0.0", 443)]) - - def test_dont_mix_host_port_listen(self): - self.assertRaises( - ValueError, - self._makeOne, - host="localhost", - port="8080", - listen="127.0.0.1:8080", - ) - - def 
test_good_sockets(self): - sockets = [ - socket.socket(socket.AF_INET6, socket.SOCK_STREAM), - socket.socket(socket.AF_INET, socket.SOCK_STREAM), - ] - inst = self._makeOne(sockets=sockets) - self.assertEqual(inst.sockets, sockets) - sockets[0].close() - sockets[1].close() - - def test_dont_mix_sockets_and_listen(self): - sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] - self.assertRaises( - ValueError, self._makeOne, listen="127.0.0.1:8080", sockets=sockets - ) - sockets[0].close() - - def test_dont_mix_sockets_and_host_port(self): - sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] - self.assertRaises( - ValueError, self._makeOne, host="localhost", port="8080", sockets=sockets - ) - sockets[0].close() - - def test_dont_mix_sockets_and_unix_socket(self): - sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] - self.assertRaises( - ValueError, self._makeOne, unix_socket="./tmp/test", sockets=sockets - ) - sockets[0].close() - - def test_dont_mix_unix_socket_and_host_port(self): - self.assertRaises( - ValueError, - self._makeOne, - unix_socket="./tmp/test", - host="localhost", - port="8080", - ) - - def test_dont_mix_unix_socket_and_listen(self): - self.assertRaises( - ValueError, self._makeOne, unix_socket="./tmp/test", listen="127.0.0.1:8080" - ) - - def test_dont_use_unsupported_socket_types(self): - sockets = [socket.socket(socket.AF_INET, socket.SOCK_DGRAM)] - self.assertRaises(ValueError, self._makeOne, sockets=sockets) - sockets[0].close() - - def test_dont_mix_forwarded_with_x_forwarded(self): - with self.assertRaises(ValueError) as cm: - self._makeOne( - trusted_proxy="localhost", - trusted_proxy_headers={"forwarded", "x-forwarded-for"}, - ) - - self.assertIn("The Forwarded proxy header", str(cm.exception)) - - def test_unknown_trusted_proxy_header(self): - with self.assertRaises(ValueError) as cm: - self._makeOne( - trusted_proxy="localhost", - trusted_proxy_headers={"forwarded", "x-forwarded-unknown"}, - ) - - self.assertIn( - "unknown trusted_proxy_headers value (x-forwarded-unknown)", - str(cm.exception), - ) - - def test_trusted_proxy_count_no_trusted_proxy(self): - with self.assertRaises(ValueError) as cm: - self._makeOne(trusted_proxy_count=1) - - self.assertIn("trusted_proxy_count has no meaning", str(cm.exception)) - - def test_trusted_proxy_headers_no_trusted_proxy(self): - with self.assertRaises(ValueError) as cm: - self._makeOne(trusted_proxy_headers={"forwarded"}) - - self.assertIn("trusted_proxy_headers has no meaning", str(cm.exception)) - - def test_trusted_proxy_headers_string_list(self): - inst = self._makeOne( - trusted_proxy="localhost", - trusted_proxy_headers="x-forwarded-for x-forwarded-by", - ) - self.assertEqual( - inst.trusted_proxy_headers, {"x-forwarded-for", "x-forwarded-by"} - ) - - def test_trusted_proxy_headers_string_list_newlines(self): - inst = self._makeOne( - trusted_proxy="localhost", - trusted_proxy_headers="x-forwarded-for\nx-forwarded-by\nx-forwarded-host", - ) - self.assertEqual( - inst.trusted_proxy_headers, - {"x-forwarded-for", "x-forwarded-by", "x-forwarded-host"}, - ) - - def test_no_trusted_proxy_headers_trusted_proxy(self): - with warnings.catch_warnings(record=True) as w: - warnings.resetwarnings() - warnings.simplefilter("always") - self._makeOne(trusted_proxy="localhost") - - self.assertGreaterEqual(len(w), 1) - self.assertTrue(issubclass(w[0].category, DeprecationWarning)) - self.assertIn("Implicitly trusting X-Forwarded-Proto", str(w[0])) - - def test_clear_untrusted_proxy_headers(self): - with 
warnings.catch_warnings(record=True) as w: - warnings.resetwarnings() - warnings.simplefilter("always") - self._makeOne( - trusted_proxy="localhost", trusted_proxy_headers={"x-forwarded-for"} - ) - - self.assertGreaterEqual(len(w), 1) - self.assertTrue(issubclass(w[0].category, DeprecationWarning)) - self.assertIn( - "clear_untrusted_proxy_headers will be set to True", str(w[0]) - ) - - def test_deprecated_send_bytes(self): - with warnings.catch_warnings(record=True) as w: - warnings.resetwarnings() - warnings.simplefilter("always") - self._makeOne(send_bytes=1) - - self.assertGreaterEqual(len(w), 1) - self.assertTrue(issubclass(w[0].category, DeprecationWarning)) - self.assertIn("send_bytes", str(w[0])) - - def test_badvar(self): - self.assertRaises(ValueError, self._makeOne, nope=True) - - def test_ipv4_disabled(self): - self.assertRaises( - ValueError, self._makeOne, ipv4=False, listen="127.0.0.1:8080" - ) - - def test_ipv6_disabled(self): - self.assertRaises(ValueError, self._makeOne, ipv6=False, listen="[::]:8080") - - def test_server_header_removable(self): - inst = self._makeOne(ident=None) - self.assertEqual(inst.ident, None) - - inst = self._makeOne(ident="") - self.assertEqual(inst.ident, None) - - inst = self._makeOne(ident="specific_header") - self.assertEqual(inst.ident, "specific_header") - - -class TestCLI(unittest.TestCase): - def parse(self, argv): - from waitress.adjustments import Adjustments - - return Adjustments.parse_args(argv) - - def test_noargs(self): - opts, args = self.parse([]) - self.assertDictEqual(opts, {"call": False, "help": False}) - self.assertSequenceEqual(args, []) - - def test_help(self): - opts, args = self.parse(["--help"]) - self.assertDictEqual(opts, {"call": False, "help": True}) - self.assertSequenceEqual(args, []) - - def test_call(self): - opts, args = self.parse(["--call"]) - self.assertDictEqual(opts, {"call": True, "help": False}) - self.assertSequenceEqual(args, []) - - def test_both(self): - opts, args = self.parse(["--call", "--help"]) - self.assertDictEqual(opts, {"call": True, "help": True}) - self.assertSequenceEqual(args, []) - - def test_positive_boolean(self): - opts, args = self.parse(["--expose-tracebacks"]) - self.assertDictContainsSubset({"expose_tracebacks": "true"}, opts) - self.assertSequenceEqual(args, []) - - def test_negative_boolean(self): - opts, args = self.parse(["--no-expose-tracebacks"]) - self.assertDictContainsSubset({"expose_tracebacks": "false"}, opts) - self.assertSequenceEqual(args, []) - - def test_cast_params(self): - opts, args = self.parse( - ["--host=localhost", "--port=80", "--unix-socket-perms=777"] - ) - self.assertDictContainsSubset( - {"host": "localhost", "port": "80", "unix_socket_perms": "777",}, opts - ) - self.assertSequenceEqual(args, []) - - def test_listen_params(self): - opts, args = self.parse(["--listen=test:80",]) - - self.assertDictContainsSubset({"listen": " test:80"}, opts) - self.assertSequenceEqual(args, []) - - def test_multiple_listen_params(self): - opts, args = self.parse(["--listen=test:80", "--listen=test:8080",]) - - self.assertDictContainsSubset({"listen": " test:80 test:8080"}, opts) - self.assertSequenceEqual(args, []) - - def test_bad_param(self): - import getopt - - self.assertRaises(getopt.GetoptError, self.parse, ["--no-host"]) - - -if hasattr(socket, "AF_UNIX"): - - class TestUnixSocket(unittest.TestCase): - def _makeOne(self, **kw): - from waitress.adjustments import Adjustments - - return Adjustments(**kw) - - def test_dont_mix_internet_and_unix_sockets(self): - 
sockets = [ - socket.socket(socket.AF_INET, socket.SOCK_STREAM), - socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), - ] - self.assertRaises(ValueError, self._makeOne, sockets=sockets) - sockets[0].close() - sockets[1].close() diff --git a/libs/waitress/tests/test_buffers.py b/libs/waitress/tests/test_buffers.py deleted file mode 100644 index a1330ac1b..000000000 --- a/libs/waitress/tests/test_buffers.py +++ /dev/null @@ -1,523 +0,0 @@ -import unittest -import io - - -class TestFileBasedBuffer(unittest.TestCase): - def _makeOne(self, file=None, from_buffer=None): - from waitress.buffers import FileBasedBuffer - - buf = FileBasedBuffer(file, from_buffer=from_buffer) - self.buffers_to_close.append(buf) - return buf - - def setUp(self): - self.buffers_to_close = [] - - def tearDown(self): - for buf in self.buffers_to_close: - buf.close() - - def test_ctor_from_buffer_None(self): - inst = self._makeOne("file") - self.assertEqual(inst.file, "file") - - def test_ctor_from_buffer(self): - from_buffer = io.BytesIO(b"data") - from_buffer.getfile = lambda *x: from_buffer - f = io.BytesIO() - inst = self._makeOne(f, from_buffer) - self.assertEqual(inst.file, f) - del from_buffer.getfile - self.assertEqual(inst.remain, 4) - from_buffer.close() - - def test___len__(self): - inst = self._makeOne() - inst.remain = 10 - self.assertEqual(len(inst), 10) - - def test___nonzero__(self): - inst = self._makeOne() - inst.remain = 10 - self.assertEqual(bool(inst), True) - inst.remain = 0 - self.assertEqual(bool(inst), True) - - def test_append(self): - f = io.BytesIO(b"data") - inst = self._makeOne(f) - inst.append(b"data2") - self.assertEqual(f.getvalue(), b"datadata2") - self.assertEqual(inst.remain, 5) - - def test_get_skip_true(self): - f = io.BytesIO(b"data") - inst = self._makeOne(f) - result = inst.get(100, skip=True) - self.assertEqual(result, b"data") - self.assertEqual(inst.remain, -4) - - def test_get_skip_false(self): - f = io.BytesIO(b"data") - inst = self._makeOne(f) - result = inst.get(100, skip=False) - self.assertEqual(result, b"data") - self.assertEqual(inst.remain, 0) - - def test_get_skip_bytes_less_than_zero(self): - f = io.BytesIO(b"data") - inst = self._makeOne(f) - result = inst.get(-1, skip=False) - self.assertEqual(result, b"data") - self.assertEqual(inst.remain, 0) - - def test_skip_remain_gt_bytes(self): - f = io.BytesIO(b"d") - inst = self._makeOne(f) - inst.remain = 1 - inst.skip(1) - self.assertEqual(inst.remain, 0) - - def test_skip_remain_lt_bytes(self): - f = io.BytesIO(b"d") - inst = self._makeOne(f) - inst.remain = 1 - self.assertRaises(ValueError, inst.skip, 2) - - def test_newfile(self): - inst = self._makeOne() - self.assertRaises(NotImplementedError, inst.newfile) - - def test_prune_remain_notzero(self): - f = io.BytesIO(b"d") - inst = self._makeOne(f) - inst.remain = 1 - nf = io.BytesIO() - inst.newfile = lambda *x: nf - inst.prune() - self.assertTrue(inst.file is not f) - self.assertEqual(nf.getvalue(), b"d") - - def test_prune_remain_zero_tell_notzero(self): - f = io.BytesIO(b"d") - inst = self._makeOne(f) - nf = io.BytesIO(b"d") - inst.newfile = lambda *x: nf - inst.remain = 0 - inst.prune() - self.assertTrue(inst.file is not f) - self.assertEqual(nf.getvalue(), b"d") - - def test_prune_remain_zero_tell_zero(self): - f = io.BytesIO() - inst = self._makeOne(f) - inst.remain = 0 - inst.prune() - self.assertTrue(inst.file is f) - - def test_close(self): - f = io.BytesIO() - inst = self._makeOne(f) - inst.close() - self.assertTrue(f.closed) - 
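
# --- Editor's aside: the get()/skip semantics these FileBasedBuffer tests pin
# down, in one self-contained run (assumes waitress.buffers is importable from
# the PyPI package, since this diff removes the vendored copy):
import io
from waitress.buffers import FileBasedBuffer

buf = FileBasedBuffer(io.BytesIO(b"data"))
buf.remain = 4
assert buf.get(100, skip=False) == b"data"  # peek: file position is rewound
assert buf.remain == 4
assert buf.get(100, skip=True) == b"data"   # consume: remain is drawn down
assert buf.remain == 0
buf.close()
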
self.buffers_to_close.remove(inst) - - -class TestTempfileBasedBuffer(unittest.TestCase): - def _makeOne(self, from_buffer=None): - from waitress.buffers import TempfileBasedBuffer - - buf = TempfileBasedBuffer(from_buffer=from_buffer) - self.buffers_to_close.append(buf) - return buf - - def setUp(self): - self.buffers_to_close = [] - - def tearDown(self): - for buf in self.buffers_to_close: - buf.close() - - def test_newfile(self): - inst = self._makeOne() - r = inst.newfile() - self.assertTrue(hasattr(r, "fileno")) # file - r.close() - - -class TestBytesIOBasedBuffer(unittest.TestCase): - def _makeOne(self, from_buffer=None): - from waitress.buffers import BytesIOBasedBuffer - - return BytesIOBasedBuffer(from_buffer=from_buffer) - - def test_ctor_from_buffer_not_None(self): - f = io.BytesIO() - f.getfile = lambda *x: f - inst = self._makeOne(f) - self.assertTrue(hasattr(inst.file, "read")) - - def test_ctor_from_buffer_None(self): - inst = self._makeOne() - self.assertTrue(hasattr(inst.file, "read")) - - def test_newfile(self): - inst = self._makeOne() - r = inst.newfile() - self.assertTrue(hasattr(r, "read")) - - -class TestReadOnlyFileBasedBuffer(unittest.TestCase): - def _makeOne(self, file, block_size=8192): - from waitress.buffers import ReadOnlyFileBasedBuffer - - buf = ReadOnlyFileBasedBuffer(file, block_size) - self.buffers_to_close.append(buf) - return buf - - def setUp(self): - self.buffers_to_close = [] - - def tearDown(self): - for buf in self.buffers_to_close: - buf.close() - - def test_prepare_not_seekable(self): - f = KindaFilelike(b"abc") - inst = self._makeOne(f) - result = inst.prepare() - self.assertEqual(result, False) - self.assertEqual(inst.remain, 0) - - def test_prepare_not_seekable_closeable(self): - f = KindaFilelike(b"abc", close=1) - inst = self._makeOne(f) - result = inst.prepare() - self.assertEqual(result, False) - self.assertEqual(inst.remain, 0) - self.assertTrue(hasattr(inst, "close")) - - def test_prepare_seekable_closeable(self): - f = Filelike(b"abc", close=1, tellresults=[0, 10]) - inst = self._makeOne(f) - result = inst.prepare() - self.assertEqual(result, 10) - self.assertEqual(inst.remain, 10) - self.assertEqual(inst.file.seeked, 0) - self.assertTrue(hasattr(inst, "close")) - - def test_get_numbytes_neg_one(self): - f = io.BytesIO(b"abcdef") - inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(-1) - self.assertEqual(result, b"ab") - self.assertEqual(inst.remain, 2) - self.assertEqual(f.tell(), 0) - - def test_get_numbytes_gt_remain(self): - f = io.BytesIO(b"abcdef") - inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(3) - self.assertEqual(result, b"ab") - self.assertEqual(inst.remain, 2) - self.assertEqual(f.tell(), 0) - - def test_get_numbytes_lt_remain(self): - f = io.BytesIO(b"abcdef") - inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(1) - self.assertEqual(result, b"a") - self.assertEqual(inst.remain, 2) - self.assertEqual(f.tell(), 0) - - def test_get_numbytes_gt_remain_withskip(self): - f = io.BytesIO(b"abcdef") - inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(3, skip=True) - self.assertEqual(result, b"ab") - self.assertEqual(inst.remain, 0) - self.assertEqual(f.tell(), 2) - - def test_get_numbytes_lt_remain_withskip(self): - f = io.BytesIO(b"abcdef") - inst = self._makeOne(f) - inst.remain = 2 - result = inst.get(1, skip=True) - self.assertEqual(result, b"a") - self.assertEqual(inst.remain, 1) - self.assertEqual(f.tell(), 1) - - def test___iter__(self): - data = b"a" * 10000 - f = 
io.BytesIO(data) - inst = self._makeOne(f) - r = b"" - for val in inst: - r += val - self.assertEqual(r, data) - - def test_append(self): - inst = self._makeOne(None) - self.assertRaises(NotImplementedError, inst.append, "a") - - -class TestOverflowableBuffer(unittest.TestCase): - def _makeOne(self, overflow=10): - from waitress.buffers import OverflowableBuffer - - buf = OverflowableBuffer(overflow) - self.buffers_to_close.append(buf) - return buf - - def setUp(self): - self.buffers_to_close = [] - - def tearDown(self): - for buf in self.buffers_to_close: - buf.close() - - def test___len__buf_is_None(self): - inst = self._makeOne() - self.assertEqual(len(inst), 0) - - def test___len__buf_is_not_None(self): - inst = self._makeOne() - inst.buf = b"abc" - self.assertEqual(len(inst), 3) - self.buffers_to_close.remove(inst) - - def test___nonzero__(self): - inst = self._makeOne() - inst.buf = b"abc" - self.assertEqual(bool(inst), True) - inst.buf = b"" - self.assertEqual(bool(inst), False) - self.buffers_to_close.remove(inst) - - def test___nonzero___on_int_overflow_buffer(self): - inst = self._makeOne() - - class int_overflow_buf(bytes): - def __len__(self): - # maxint + 1 - return 0x7FFFFFFFFFFFFFFF + 1 - - inst.buf = int_overflow_buf() - self.assertEqual(bool(inst), True) - inst.buf = b"" - self.assertEqual(bool(inst), False) - self.buffers_to_close.remove(inst) - - def test__create_buffer_large(self): - from waitress.buffers import TempfileBasedBuffer - - inst = self._makeOne() - inst.strbuf = b"x" * 11 - inst._create_buffer() - self.assertEqual(inst.buf.__class__, TempfileBasedBuffer) - self.assertEqual(inst.buf.get(100), b"x" * 11) - self.assertEqual(inst.strbuf, b"") - - def test__create_buffer_small(self): - from waitress.buffers import BytesIOBasedBuffer - - inst = self._makeOne() - inst.strbuf = b"x" * 5 - inst._create_buffer() - self.assertEqual(inst.buf.__class__, BytesIOBasedBuffer) - self.assertEqual(inst.buf.get(100), b"x" * 5) - self.assertEqual(inst.strbuf, b"") - - def test_append_with_len_more_than_max_int(self): - from waitress.compat import MAXINT - - inst = self._makeOne() - inst.overflowed = True - buf = DummyBuffer(length=MAXINT) - inst.buf = buf - result = inst.append(b"x") - # we don't want this to throw an OverflowError on Python 2 (see - # https://github.com/Pylons/waitress/issues/47) - self.assertEqual(result, None) - self.buffers_to_close.remove(inst) - - def test_append_buf_None_not_longer_than_srtbuf_limit(self): - inst = self._makeOne() - inst.strbuf = b"x" * 5 - inst.append(b"hello") - self.assertEqual(inst.strbuf, b"xxxxxhello") - - def test_append_buf_None_longer_than_strbuf_limit(self): - inst = self._makeOne(10000) - inst.strbuf = b"x" * 8192 - inst.append(b"hello") - self.assertEqual(inst.strbuf, b"") - self.assertEqual(len(inst.buf), 8197) - - def test_append_overflow(self): - inst = self._makeOne(10) - inst.strbuf = b"x" * 8192 - inst.append(b"hello") - self.assertEqual(inst.strbuf, b"") - self.assertEqual(len(inst.buf), 8197) - - def test_append_sz_gt_overflow(self): - from waitress.buffers import BytesIOBasedBuffer - - f = io.BytesIO(b"data") - inst = self._makeOne(f) - buf = BytesIOBasedBuffer() - inst.buf = buf - inst.overflow = 2 - inst.append(b"data2") - self.assertEqual(f.getvalue(), b"data") - self.assertTrue(inst.overflowed) - self.assertNotEqual(inst.buf, buf) - - def test_get_buf_None_skip_False(self): - inst = self._makeOne() - inst.strbuf = b"x" * 5 - r = inst.get(5) - self.assertEqual(r, b"xxxxx") - - def test_get_buf_None_skip_True(self): 
- inst = self._makeOne() - inst.strbuf = b"x" * 5 - r = inst.get(5, skip=True) - self.assertFalse(inst.buf is None) - self.assertEqual(r, b"xxxxx") - - def test_skip_buf_None(self): - inst = self._makeOne() - inst.strbuf = b"data" - inst.skip(4) - self.assertEqual(inst.strbuf, b"") - self.assertNotEqual(inst.buf, None) - - def test_skip_buf_None_allow_prune_True(self): - inst = self._makeOne() - inst.strbuf = b"data" - inst.skip(4, True) - self.assertEqual(inst.strbuf, b"") - self.assertEqual(inst.buf, None) - - def test_prune_buf_None(self): - inst = self._makeOne() - inst.prune() - self.assertEqual(inst.strbuf, b"") - - def test_prune_with_buf(self): - inst = self._makeOne() - - class Buf(object): - def prune(self): - self.pruned = True - - inst.buf = Buf() - inst.prune() - self.assertEqual(inst.buf.pruned, True) - self.buffers_to_close.remove(inst) - - def test_prune_with_buf_overflow(self): - inst = self._makeOne() - - class DummyBuffer(io.BytesIO): - def getfile(self): - return self - - def prune(self): - return True - - def __len__(self): - return 5 - - def close(self): - pass - - buf = DummyBuffer(b"data") - inst.buf = buf - inst.overflowed = True - inst.overflow = 10 - inst.prune() - self.assertNotEqual(inst.buf, buf) - - def test_prune_with_buflen_more_than_max_int(self): - from waitress.compat import MAXINT - - inst = self._makeOne() - inst.overflowed = True - buf = DummyBuffer(length=MAXINT + 1) - inst.buf = buf - result = inst.prune() - # we don't want this to throw an OverflowError on Python 2 (see - # https://github.com/Pylons/waitress/issues/47) - self.assertEqual(result, None) - - def test_getfile_buf_None(self): - inst = self._makeOne() - f = inst.getfile() - self.assertTrue(hasattr(f, "read")) - - def test_getfile_buf_not_None(self): - inst = self._makeOne() - buf = io.BytesIO() - buf.getfile = lambda *x: buf - inst.buf = buf - f = inst.getfile() - self.assertEqual(f, buf) - - def test_close_nobuf(self): - inst = self._makeOne() - inst.buf = None - self.assertEqual(inst.close(), None) # doesnt raise - self.buffers_to_close.remove(inst) - - def test_close_withbuf(self): - class Buffer(object): - def close(self): - self.closed = True - - buf = Buffer() - inst = self._makeOne() - inst.buf = buf - inst.close() - self.assertTrue(buf.closed) - self.buffers_to_close.remove(inst) - - -class KindaFilelike(object): - def __init__(self, bytes, close=None, tellresults=None): - self.bytes = bytes - self.tellresults = tellresults - if close is not None: - self.close = lambda: close - - -class Filelike(KindaFilelike): - def seek(self, v, whence=0): - self.seeked = v - - def tell(self): - v = self.tellresults.pop(0) - return v - - -class DummyBuffer(object): - def __init__(self, length=0): - self.length = length - - def __len__(self): - return self.length - - def append(self, s): - self.length = self.length + len(s) - - def prune(self): - pass - - def close(self): - pass diff --git a/libs/waitress/tests/test_channel.py b/libs/waitress/tests/test_channel.py deleted file mode 100644 index 14ef5a0ec..000000000 --- a/libs/waitress/tests/test_channel.py +++ /dev/null @@ -1,882 +0,0 @@ -import unittest -import io - - -class TestHTTPChannel(unittest.TestCase): - def _makeOne(self, sock, addr, adj, map=None): - from waitress.channel import HTTPChannel - - server = DummyServer() - return HTTPChannel(server, sock, addr, adj=adj, map=map) - - def _makeOneWithMap(self, adj=None): - if adj is None: - adj = DummyAdjustments() - sock = DummySock() - map = {} - inst = self._makeOne(sock, "127.0.0.1", 
adj, map=map) - inst.outbuf_lock = DummyLock() - return inst, sock, map - - def test_ctor(self): - inst, _, map = self._makeOneWithMap() - self.assertEqual(inst.addr, "127.0.0.1") - self.assertEqual(inst.sendbuf_len, 2048) - self.assertEqual(map[100], inst) - - def test_total_outbufs_len_an_outbuf_size_gt_sys_maxint(self): - from waitress.compat import MAXINT - - inst, _, map = self._makeOneWithMap() - - class DummyBuffer(object): - chunks = [] - - def append(self, data): - self.chunks.append(data) - - class DummyData(object): - def __len__(self): - return MAXINT - - inst.total_outbufs_len = 1 - inst.outbufs = [DummyBuffer()] - inst.write_soon(DummyData()) - # we are testing that this method does not raise an OverflowError - # (see https://github.com/Pylons/waitress/issues/47) - self.assertEqual(inst.total_outbufs_len, MAXINT + 1) - - def test_writable_something_in_outbuf(self): - inst, sock, map = self._makeOneWithMap() - inst.total_outbufs_len = 3 - self.assertTrue(inst.writable()) - - def test_writable_nothing_in_outbuf(self): - inst, sock, map = self._makeOneWithMap() - self.assertFalse(inst.writable()) - - def test_writable_nothing_in_outbuf_will_close(self): - inst, sock, map = self._makeOneWithMap() - inst.will_close = True - self.assertTrue(inst.writable()) - - def test_handle_write_not_connected(self): - inst, sock, map = self._makeOneWithMap() - inst.connected = False - self.assertFalse(inst.handle_write()) - - def test_handle_write_with_requests(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = True - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.last_activity, 0) - - def test_handle_write_no_request_with_outbuf(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - inst.outbufs = [DummyBuffer(b"abc")] - inst.total_outbufs_len = len(inst.outbufs[0]) - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertNotEqual(inst.last_activity, 0) - self.assertEqual(sock.sent, b"abc") - - def test_handle_write_outbuf_raises_socketerror(self): - import socket - - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - outbuf = DummyBuffer(b"abc", socket.error) - inst.outbufs = [outbuf] - inst.total_outbufs_len = len(outbuf) - inst.last_activity = 0 - inst.logger = DummyLogger() - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.last_activity, 0) - self.assertEqual(sock.sent, b"") - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(outbuf.closed) - - def test_handle_write_outbuf_raises_othererror(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - outbuf = DummyBuffer(b"abc", IOError) - inst.outbufs = [outbuf] - inst.total_outbufs_len = len(outbuf) - inst.last_activity = 0 - inst.logger = DummyLogger() - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.last_activity, 0) - self.assertEqual(sock.sent, b"") - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(outbuf.closed) - - def test_handle_write_no_requests_no_outbuf_will_close(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - outbuf = DummyBuffer(b"") - inst.outbufs = [outbuf] - inst.will_close = True - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.connected, False) - self.assertEqual(sock.closed, True) - self.assertEqual(inst.last_activity, 0) - self.assertTrue(outbuf.closed) 
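The deleted channel tests above and below (test_total_outbufs_len_an_outbuf_size_gt_sys_maxint, test__flush_some_outbuf_len_gt_sys_maxint) pin down behavior from Pylons/waitress#47: buffer-size accounting has to keep working once a single buffer reports more than sys.maxsize bytes. A minimal standalone sketch of the CPython boundary involved follows; it assumes only the standard library, and HugeBuffer is a hypothetical stand-in rather than waitress code.

import sys

class HugeBuffer:
    """Hypothetical buffer claiming more bytes than fit in a Py_ssize_t."""

    def __len__(self):
        return sys.maxsize + 1

total_outbufs_len = sys.maxsize
total_outbufs_len += 1  # plain int arithmetic never overflows in Python

# calling __len__() directly returns the honest (huge) size
assert HugeBuffer().__len__() == sys.maxsize + 1

try:
    len(HugeBuffer())  # len() coerces to an index-sized int and overflows
except OverflowError:
    print("len() cannot report sizes above sys.maxsize")

Direct __len__() calls sidestep the cap, which is presumably why the channel code exercised here reads buffer sizes that way; the tests then assert that neither the running total nor the flush path raises OverflowError.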
- - def test_handle_write_no_requests_outbuf_gt_send_bytes(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [True] - inst.outbufs = [DummyBuffer(b"abc")] - inst.total_outbufs_len = len(inst.outbufs[0]) - inst.adj.send_bytes = 2 - inst.will_close = False - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.will_close, False) - self.assertTrue(inst.outbuf_lock.acquired) - self.assertEqual(sock.sent, b"abc") - - def test_handle_write_close_when_flushed(self): - inst, sock, map = self._makeOneWithMap() - outbuf = DummyBuffer(b"abc") - inst.outbufs = [outbuf] - inst.total_outbufs_len = len(outbuf) - inst.will_close = False - inst.close_when_flushed = True - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.will_close, True) - self.assertEqual(inst.close_when_flushed, False) - self.assertEqual(sock.sent, b"abc") - self.assertTrue(outbuf.closed) - - def test_readable_no_requests_not_will_close(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - inst.will_close = False - self.assertEqual(inst.readable(), True) - - def test_readable_no_requests_will_close(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - inst.will_close = True - self.assertEqual(inst.readable(), False) - - def test_readable_with_requests(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = True - self.assertEqual(inst.readable(), False) - - def test_handle_read_no_error(self): - inst, sock, map = self._makeOneWithMap() - inst.will_close = False - inst.recv = lambda *arg: b"abc" - inst.last_activity = 0 - L = [] - inst.received = lambda x: L.append(x) - result = inst.handle_read() - self.assertEqual(result, None) - self.assertNotEqual(inst.last_activity, 0) - self.assertEqual(L, [b"abc"]) - - def test_handle_read_error(self): - import socket - - inst, sock, map = self._makeOneWithMap() - inst.will_close = False - - def recv(b): - raise socket.error - - inst.recv = recv - inst.last_activity = 0 - inst.logger = DummyLogger() - result = inst.handle_read() - self.assertEqual(result, None) - self.assertEqual(inst.last_activity, 0) - self.assertEqual(len(inst.logger.exceptions), 1) - - def test_write_soon_empty_byte(self): - inst, sock, map = self._makeOneWithMap() - wrote = inst.write_soon(b"") - self.assertEqual(wrote, 0) - self.assertEqual(len(inst.outbufs[0]), 0) - - def test_write_soon_nonempty_byte(self): - inst, sock, map = self._makeOneWithMap() - wrote = inst.write_soon(b"a") - self.assertEqual(wrote, 1) - self.assertEqual(len(inst.outbufs[0]), 1) - - def test_write_soon_filewrapper(self): - from waitress.buffers import ReadOnlyFileBasedBuffer - - f = io.BytesIO(b"abc") - wrapper = ReadOnlyFileBasedBuffer(f, 8192) - wrapper.prepare() - inst, sock, map = self._makeOneWithMap() - outbufs = inst.outbufs - orig_outbuf = outbufs[0] - wrote = inst.write_soon(wrapper) - self.assertEqual(wrote, 3) - self.assertEqual(len(outbufs), 3) - self.assertEqual(outbufs[0], orig_outbuf) - self.assertEqual(outbufs[1], wrapper) - self.assertEqual(outbufs[2].__class__.__name__, "OverflowableBuffer") - - def test_write_soon_disconnected(self): - from waitress.channel import ClientDisconnected - - inst, sock, map = self._makeOneWithMap() - inst.connected = False - self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"stuff")) - - def test_write_soon_disconnected_while_over_watermark(self): - from waitress.channel import ClientDisconnected - - inst, 
sock, map = self._makeOneWithMap() - - def dummy_flush(): - inst.connected = False - - inst._flush_outbufs_below_high_watermark = dummy_flush - self.assertRaises(ClientDisconnected, lambda: inst.write_soon(b"stuff")) - - def test_write_soon_rotates_outbuf_on_overflow(self): - inst, sock, map = self._makeOneWithMap() - inst.adj.outbuf_high_watermark = 3 - inst.current_outbuf_count = 4 - wrote = inst.write_soon(b"xyz") - self.assertEqual(wrote, 3) - self.assertEqual(len(inst.outbufs), 2) - self.assertEqual(inst.outbufs[0].get(), b"") - self.assertEqual(inst.outbufs[1].get(), b"xyz") - - def test_write_soon_waits_on_backpressure(self): - inst, sock, map = self._makeOneWithMap() - inst.adj.outbuf_high_watermark = 3 - inst.total_outbufs_len = 4 - inst.current_outbuf_count = 4 - - class Lock(DummyLock): - def wait(self): - inst.total_outbufs_len = 0 - super(Lock, self).wait() - - inst.outbuf_lock = Lock() - wrote = inst.write_soon(b"xyz") - self.assertEqual(wrote, 3) - self.assertEqual(len(inst.outbufs), 2) - self.assertEqual(inst.outbufs[0].get(), b"") - self.assertEqual(inst.outbufs[1].get(), b"xyz") - self.assertTrue(inst.outbuf_lock.waited) - - def test_handle_write_notify_after_flush(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [True] - inst.outbufs = [DummyBuffer(b"abc")] - inst.total_outbufs_len = len(inst.outbufs[0]) - inst.adj.send_bytes = 1 - inst.adj.outbuf_high_watermark = 5 - inst.will_close = False - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.will_close, False) - self.assertTrue(inst.outbuf_lock.acquired) - self.assertTrue(inst.outbuf_lock.notified) - self.assertEqual(sock.sent, b"abc") - - def test_handle_write_no_notify_after_flush(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [True] - inst.outbufs = [DummyBuffer(b"abc")] - inst.total_outbufs_len = len(inst.outbufs[0]) - inst.adj.send_bytes = 1 - inst.adj.outbuf_high_watermark = 2 - sock.send = lambda x: False - inst.will_close = False - inst.last_activity = 0 - result = inst.handle_write() - self.assertEqual(result, None) - self.assertEqual(inst.will_close, False) - self.assertTrue(inst.outbuf_lock.acquired) - self.assertFalse(inst.outbuf_lock.notified) - self.assertEqual(sock.sent, b"") - - def test__flush_some_empty_outbuf(self): - inst, sock, map = self._makeOneWithMap() - result = inst._flush_some() - self.assertEqual(result, False) - - def test__flush_some_full_outbuf_socket_returns_nonzero(self): - inst, sock, map = self._makeOneWithMap() - inst.outbufs[0].append(b"abc") - inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) - result = inst._flush_some() - self.assertEqual(result, True) - - def test__flush_some_full_outbuf_socket_returns_zero(self): - inst, sock, map = self._makeOneWithMap() - sock.send = lambda x: False - inst.outbufs[0].append(b"abc") - inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) - result = inst._flush_some() - self.assertEqual(result, False) - - def test_flush_some_multiple_buffers_first_empty(self): - inst, sock, map = self._makeOneWithMap() - sock.send = lambda x: len(x) - buffer = DummyBuffer(b"abc") - inst.outbufs.append(buffer) - inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) - result = inst._flush_some() - self.assertEqual(result, True) - self.assertEqual(buffer.skipped, 3) - self.assertEqual(inst.outbufs, [buffer]) - - def test_flush_some_multiple_buffers_close_raises(self): - inst, sock, map = self._makeOneWithMap() - sock.send = lambda x: len(x) - 
buffer = DummyBuffer(b"abc") - inst.outbufs.append(buffer) - inst.total_outbufs_len = sum(len(x) for x in inst.outbufs) - inst.logger = DummyLogger() - - def doraise(): - raise NotImplementedError - - inst.outbufs[0].close = doraise - result = inst._flush_some() - self.assertEqual(result, True) - self.assertEqual(buffer.skipped, 3) - self.assertEqual(inst.outbufs, [buffer]) - self.assertEqual(len(inst.logger.exceptions), 1) - - def test__flush_some_outbuf_len_gt_sys_maxint(self): - from waitress.compat import MAXINT - - inst, sock, map = self._makeOneWithMap() - - class DummyHugeOutbuffer(object): - def __init__(self): - self.length = MAXINT + 1 - - def __len__(self): - return self.length - - def get(self, numbytes): - self.length = 0 - return b"123" - - buf = DummyHugeOutbuffer() - inst.outbufs = [buf] - inst.send = lambda *arg: 0 - result = inst._flush_some() - # we are testing that _flush_some doesn't raise an OverflowError - # when one of its outbufs has a __len__ that returns gt sys.maxint - self.assertEqual(result, False) - - def test_handle_close(self): - inst, sock, map = self._makeOneWithMap() - inst.handle_close() - self.assertEqual(inst.connected, False) - self.assertEqual(sock.closed, True) - - def test_handle_close_outbuf_raises_on_close(self): - inst, sock, map = self._makeOneWithMap() - - def doraise(): - raise NotImplementedError - - inst.outbufs[0].close = doraise - inst.logger = DummyLogger() - inst.handle_close() - self.assertEqual(inst.connected, False) - self.assertEqual(sock.closed, True) - self.assertEqual(len(inst.logger.exceptions), 1) - - def test_add_channel(self): - inst, sock, map = self._makeOneWithMap() - fileno = inst._fileno - inst.add_channel(map) - self.assertEqual(map[fileno], inst) - self.assertEqual(inst.server.active_channels[fileno], inst) - - def test_del_channel(self): - inst, sock, map = self._makeOneWithMap() - fileno = inst._fileno - inst.server.active_channels[fileno] = True - inst.del_channel(map) - self.assertEqual(map.get(fileno), None) - self.assertEqual(inst.server.active_channels.get(fileno), None) - - def test_received(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.server.tasks, [inst]) - self.assertTrue(inst.requests) - - def test_received_no_chunk(self): - inst, sock, map = self._makeOneWithMap() - self.assertEqual(inst.received(b""), False) - - def test_received_preq_not_completed(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.completed = False - preq.empty = True - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.requests, ()) - self.assertEqual(inst.server.tasks, []) - - def test_received_preq_completed_empty(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.completed = True - preq.empty = True - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.request, None) - self.assertEqual(inst.server.tasks, []) - - def test_received_preq_error(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.completed = True - preq.error = True - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.request, None) - self.assertEqual(len(inst.server.tasks), 1) - self.assertTrue(inst.requests) - - def test_received_preq_completed_connection_close(self): - inst, sock, map = 
self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.completed = True - preq.empty = True - preq.connection_close = True - inst.received(b"GET / HTTP/1.1\r\n\r\n" + b"a" * 50000) - self.assertEqual(inst.request, None) - self.assertEqual(inst.server.tasks, []) - - def test_received_headers_finished_expect_continue_false(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.expect_continue = False - preq.headers_finished = True - preq.completed = False - preq.empty = False - preq.retval = 1 - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.request, preq) - self.assertEqual(inst.server.tasks, []) - self.assertEqual(inst.outbufs[0].get(100), b"") - - def test_received_headers_finished_expect_continue_true(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.expect_continue = True - preq.headers_finished = True - preq.completed = False - preq.empty = False - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.request, preq) - self.assertEqual(inst.server.tasks, []) - self.assertEqual(sock.sent, b"HTTP/1.1 100 Continue\r\n\r\n") - self.assertEqual(inst.sent_continue, True) - self.assertEqual(preq.completed, False) - - def test_received_headers_finished_expect_continue_true_sent_true(self): - inst, sock, map = self._makeOneWithMap() - inst.server = DummyServer() - preq = DummyParser() - inst.request = preq - preq.expect_continue = True - preq.headers_finished = True - preq.completed = False - preq.empty = False - inst.sent_continue = True - inst.received(b"GET / HTTP/1.1\r\n\r\n") - self.assertEqual(inst.request, preq) - self.assertEqual(inst.server.tasks, []) - self.assertEqual(sock.sent, b"") - self.assertEqual(inst.sent_continue, True) - self.assertEqual(preq.completed, False) - - def test_service_no_requests(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [] - inst.service() - self.assertEqual(inst.requests, []) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - - def test_service_with_one_request(self): - inst, sock, map = self._makeOneWithMap() - request = DummyRequest() - inst.task_class = DummyTaskClass() - inst.requests = [request] - inst.service() - self.assertEqual(inst.requests, []) - self.assertTrue(request.serviced) - self.assertTrue(request.closed) - - def test_service_with_one_error_request(self): - inst, sock, map = self._makeOneWithMap() - request = DummyRequest() - request.error = DummyError() - inst.error_task_class = DummyTaskClass() - inst.requests = [request] - inst.service() - self.assertEqual(inst.requests, []) - self.assertTrue(request.serviced) - self.assertTrue(request.closed) - - def test_service_with_multiple_requests(self): - inst, sock, map = self._makeOneWithMap() - request1 = DummyRequest() - request2 = DummyRequest() - inst.task_class = DummyTaskClass() - inst.requests = [request1, request2] - inst.service() - self.assertEqual(inst.requests, []) - self.assertTrue(request1.serviced) - self.assertTrue(request2.serviced) - self.assertTrue(request1.closed) - self.assertTrue(request2.closed) - - def test_service_with_request_raises(self): - inst, sock, map = self._makeOneWithMap() - inst.adj.expose_tracebacks = False - inst.server = DummyServer() - request = DummyRequest() - inst.requests = [request] - inst.task_class = DummyTaskClass(ValueError) - inst.task_class.wrote_header = 
False - inst.error_task_class = DummyTaskClass() - inst.logger = DummyLogger() - inst.service() - self.assertTrue(request.serviced) - self.assertEqual(inst.requests, []) - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - self.assertFalse(inst.will_close) - self.assertEqual(inst.error_task_class.serviced, True) - self.assertTrue(request.closed) - - def test_service_with_requests_raises_already_wrote_header(self): - inst, sock, map = self._makeOneWithMap() - inst.adj.expose_tracebacks = False - inst.server = DummyServer() - request = DummyRequest() - inst.requests = [request] - inst.task_class = DummyTaskClass(ValueError) - inst.error_task_class = DummyTaskClass() - inst.logger = DummyLogger() - inst.service() - self.assertTrue(request.serviced) - self.assertEqual(inst.requests, []) - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - self.assertTrue(inst.close_when_flushed) - self.assertEqual(inst.error_task_class.serviced, False) - self.assertTrue(request.closed) - - def test_service_with_requests_raises_didnt_write_header_expose_tbs(self): - inst, sock, map = self._makeOneWithMap() - inst.adj.expose_tracebacks = True - inst.server = DummyServer() - request = DummyRequest() - inst.requests = [request] - inst.task_class = DummyTaskClass(ValueError) - inst.task_class.wrote_header = False - inst.error_task_class = DummyTaskClass() - inst.logger = DummyLogger() - inst.service() - self.assertTrue(request.serviced) - self.assertFalse(inst.will_close) - self.assertEqual(inst.requests, []) - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - self.assertEqual(inst.error_task_class.serviced, True) - self.assertTrue(request.closed) - - def test_service_with_requests_raises_didnt_write_header(self): - inst, sock, map = self._makeOneWithMap() - inst.adj.expose_tracebacks = False - inst.server = DummyServer() - request = DummyRequest() - inst.requests = [request] - inst.task_class = DummyTaskClass(ValueError) - inst.task_class.wrote_header = False - inst.logger = DummyLogger() - inst.service() - self.assertTrue(request.serviced) - self.assertEqual(inst.requests, []) - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - self.assertTrue(inst.close_when_flushed) - self.assertTrue(request.closed) - - def test_service_with_request_raises_disconnect(self): - from waitress.channel import ClientDisconnected - - inst, sock, map = self._makeOneWithMap() - inst.adj.expose_tracebacks = False - inst.server = DummyServer() - request = DummyRequest() - inst.requests = [request] - inst.task_class = DummyTaskClass(ClientDisconnected) - inst.error_task_class = DummyTaskClass() - inst.logger = DummyLogger() - inst.service() - self.assertTrue(request.serviced) - self.assertEqual(inst.requests, []) - self.assertEqual(len(inst.logger.infos), 1) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - self.assertFalse(inst.will_close) - self.assertEqual(inst.error_task_class.serviced, False) - self.assertTrue(request.closed) - - def test_service_with_request_error_raises_disconnect(self): - from waitress.channel import ClientDisconnected - - inst, sock, map = self._makeOneWithMap() - inst.adj.expose_tracebacks = False - inst.server = DummyServer() - request = 
DummyRequest() - err_request = DummyRequest() - inst.requests = [request] - inst.parser_class = lambda x: err_request - inst.task_class = DummyTaskClass(RuntimeError) - inst.task_class.wrote_header = False - inst.error_task_class = DummyTaskClass(ClientDisconnected) - inst.logger = DummyLogger() - inst.service() - self.assertTrue(request.serviced) - self.assertTrue(err_request.serviced) - self.assertEqual(inst.requests, []) - self.assertEqual(len(inst.logger.exceptions), 1) - self.assertEqual(len(inst.logger.infos), 0) - self.assertTrue(inst.server.trigger_pulled) - self.assertTrue(inst.last_activity) - self.assertFalse(inst.will_close) - self.assertEqual(inst.task_class.serviced, True) - self.assertEqual(inst.error_task_class.serviced, True) - self.assertTrue(request.closed) - - def test_cancel_no_requests(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = () - inst.cancel() - self.assertEqual(inst.requests, []) - - def test_cancel_with_requests(self): - inst, sock, map = self._makeOneWithMap() - inst.requests = [None] - inst.cancel() - self.assertEqual(inst.requests, []) - - -class DummySock(object): - blocking = False - closed = False - - def __init__(self): - self.sent = b"" - - def setblocking(self, *arg): - self.blocking = True - - def fileno(self): - return 100 - - def getpeername(self): - return "127.0.0.1" - - def getsockopt(self, level, option): - return 2048 - - def close(self): - self.closed = True - - def send(self, data): - self.sent += data - return len(data) - - -class DummyLock(object): - notified = False - - def __init__(self, acquirable=True): - self.acquirable = acquirable - - def acquire(self, val): - self.val = val - self.acquired = True - return self.acquirable - - def release(self): - self.released = True - - def notify(self): - self.notified = True - - def wait(self): - self.waited = True - - def __exit__(self, type, val, traceback): - self.acquire(True) - - def __enter__(self): - pass - - -class DummyBuffer(object): - closed = False - - def __init__(self, data, toraise=None): - self.data = data - self.toraise = toraise - - def get(self, *arg): - if self.toraise: - raise self.toraise - data = self.data - self.data = b"" - return data - - def skip(self, num, x): - self.skipped = num - - def __len__(self): - return len(self.data) - - def close(self): - self.closed = True - - -class DummyAdjustments(object): - outbuf_overflow = 1048576 - outbuf_high_watermark = 1048576 - inbuf_overflow = 512000 - cleanup_interval = 900 - url_scheme = "http" - channel_timeout = 300 - log_socket_errors = True - recv_bytes = 8192 - send_bytes = 1 - expose_tracebacks = True - ident = "waitress" - max_request_header_size = 10000 - - -class DummyServer(object): - trigger_pulled = False - adj = DummyAdjustments() - - def __init__(self): - self.tasks = [] - self.active_channels = {} - - def add_task(self, task): - self.tasks.append(task) - - def pull_trigger(self): - self.trigger_pulled = True - - -class DummyParser(object): - version = 1 - data = None - completed = True - empty = False - headers_finished = False - expect_continue = False - retval = None - error = None - connection_close = False - - def received(self, data): - self.data = data - if self.retval is not None: - return self.retval - return len(data) - - -class DummyRequest(object): - error = None - path = "/" - version = "1.0" - closed = False - - def __init__(self): - self.headers = {} - - def close(self): - self.closed = True - - -class DummyLogger(object): - def __init__(self): - self.exceptions = [] - 
self.infos = [] - self.warnings = [] - - def info(self, msg): - self.infos.append(msg) - - def exception(self, msg): - self.exceptions.append(msg) - - -class DummyError(object): - code = "431" - reason = "Bleh" - body = "My body" - - -class DummyTaskClass(object): - wrote_header = True - close_on_finish = False - serviced = False - - def __init__(self, toraise=None): - self.toraise = toraise - - def __call__(self, channel, request): - self.request = request - return self - - def service(self): - self.serviced = True - self.request.serviced = True - if self.toraise: - raise self.toraise diff --git a/libs/waitress/tests/test_compat.py b/libs/waitress/tests/test_compat.py deleted file mode 100644 index 37c219303..000000000 --- a/libs/waitress/tests/test_compat.py +++ /dev/null @@ -1,22 +0,0 @@ -# -*- coding: utf-8 -*- - -import unittest - - -class Test_unquote_bytes_to_wsgi(unittest.TestCase): - def _callFUT(self, v): - from waitress.compat import unquote_bytes_to_wsgi - - return unquote_bytes_to_wsgi(v) - - def test_highorder(self): - from waitress.compat import PY3 - - val = b"/a%C5%9B" - result = self._callFUT(val) - if PY3: # pragma: no cover - # PEP 3333 urlunquoted-latin1-decoded-bytes - self.assertEqual(result, "/aÅ\x9b") - else: # pragma: no cover - # sanity - self.assertEqual(result, b"/a\xc5\x9b") diff --git a/libs/waitress/tests/test_functional.py b/libs/waitress/tests/test_functional.py deleted file mode 100644 index 8f4b262fe..000000000 --- a/libs/waitress/tests/test_functional.py +++ /dev/null @@ -1,1667 +0,0 @@ -import errno -import logging -import multiprocessing -import os -import signal -import socket -import string -import subprocess -import sys -import time -import unittest -from waitress import server -from waitress.compat import httplib, tobytes -from waitress.utilities import cleanup_unix_socket - -dn = os.path.dirname -here = dn(__file__) - - -class NullHandler(logging.Handler): # pragma: no cover - """A logging handler that swallows all emitted messages. """ - - def emit(self, record): - pass - - -def start_server(app, svr, queue, **kwargs): # pragma: no cover - """Run a fixture application. """ - logging.getLogger("waitress").addHandler(NullHandler()) - try_register_coverage() - svr(app, queue, **kwargs).run() - - -def try_register_coverage(): # pragma: no cover - # Hack around multiprocessing exiting early and not triggering coverage's - # atexit handler by always registering a signal handler - - if "COVERAGE_PROCESS_START" in os.environ: - def sigterm(*args): - sys.exit(0) - - signal.signal(signal.SIGTERM, sigterm) - - -class FixtureTcpWSGIServer(server.TcpWSGIServer): - """A version of TcpWSGIServer that relays back what it's bound to. """ - - family = socket.AF_INET # Testing - - def __init__(self, application, queue, **kw): # pragma: no cover - # Coverage doesn't see this as it's run in a separate process. - kw["port"] = 0 # Bind to any available port. - super(FixtureTcpWSGIServer, self).__init__(application, **kw) - host, port = self.socket.getsockname() - if os.name == "nt": - host = "127.0.0.1" - queue.put((host, port)) - - -class SubprocessTests(object): - - # For nose: all tests may be run in separate processes. - _multiprocess_can_split_ = True - - exe = sys.executable - - server = None - - def start_subprocess(self, target, **kw): - # Spawn a server process.
- self.queue = multiprocessing.Queue() - - if "COVERAGE_RCFILE" in os.environ: - os.environ["COVERAGE_PROCESS_START"] = os.environ["COVERAGE_RCFILE"] - - self.proc = multiprocessing.Process( - target=start_server, args=(target, self.server, self.queue), kwargs=kw, - ) - self.proc.start() - - if self.proc.exitcode is not None: # pragma: no cover - raise RuntimeError("%s didn't start" % str(target)) - # Get the socket the server is listening on. - self.bound_to = self.queue.get(timeout=5) - self.sock = self.create_socket() - - def stop_subprocess(self): - if self.proc.exitcode is None: - self.proc.terminate() - self.sock.close() - # This gives us one FD back ... - self.queue.close() - self.proc.join() - - def assertline(self, line, status, reason, version): - v, s, r = (x.strip() for x in line.split(None, 2)) - self.assertEqual(s, tobytes(status)) - self.assertEqual(r, tobytes(reason)) - self.assertEqual(v, tobytes(version)) - - def create_socket(self): - return socket.socket(self.server.family, socket.SOCK_STREAM) - - def connect(self): - self.sock.connect(self.bound_to) - - def make_http_connection(self): - raise NotImplementedError # pragma: no cover - - def send_check_error(self, to_send): - self.sock.send(to_send) - - -class TcpTests(SubprocessTests): - - server = FixtureTcpWSGIServer - - def make_http_connection(self): - return httplib.HTTPConnection(*self.bound_to) - - -class SleepyThreadTests(TcpTests, unittest.TestCase): - # test that a sleepy thread doesn't block other requests - - def setUp(self): - from waitress.tests.fixtureapps import sleepy - - self.start_subprocess(sleepy.app) - - def tearDown(self): - self.stop_subprocess() - - def test_it(self): - getline = os.path.join(here, "fixtureapps", "getline.py") - cmds = ( - [self.exe, getline, "http://%s:%d/sleepy" % self.bound_to], - [self.exe, getline, "http://%s:%d/" % self.bound_to], - ) - r, w = os.pipe() - procs = [] - for cmd in cmds: - procs.append(subprocess.Popen(cmd, stdout=w)) - time.sleep(3) - for proc in procs: - if proc.returncode is not None: # pragma: no cover - proc.terminate() - proc.wait() - # the notsleepy response should always be returned first (the sleepy - # request sleeps for 2 seconds before returning; the notsleepy - # response should be processed in the meantime) - result = os.read(r, 10000) - os.close(r) - os.close(w) - self.assertEqual(result, b"notsleepy returnedsleepy returned") - - -class EchoTests(object): - def setUp(self): - from waitress.tests.fixtureapps import echo - - self.start_subprocess( - echo.app, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-for", "x-forwarded-proto"}, - clear_untrusted_proxy_headers=True, - ) - - def tearDown(self): - self.stop_subprocess() - - def _read_echo(self, fp): - from waitress.tests.fixtureapps import echo - - line, headers, body = read_http(fp) - return line, headers, echo.parse_response(body) - - def test_date_and_server(self): - to_send = "GET / HTTP/1.0\r\nContent-Length: 0\r\n\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(headers.get("server"), "waitress") - self.assertTrue(headers.get("date")) - - def test_bad_host_header(self): - # https://corte.si/posts/code/pathod/pythonservers/index.html - to_send = "GET / HTTP/1.0\r\n Host: 0\r\n\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers,
response_body = read_http(fp) - self.assertline(line, "400", "Bad Request", "HTTP/1.0") - self.assertEqual(headers.get("server"), "waitress") - self.assertTrue(headers.get("date")) - - def test_send_with_body(self): - to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\n\r\n" - to_send += "hello" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(echo.content_length, "5") - self.assertEqual(echo.body, b"hello") - - def test_send_empty_body(self): - to_send = "GET / HTTP/1.0\r\nContent-Length: 0\r\n\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(echo.content_length, "0") - self.assertEqual(echo.body, b"") - - def test_multiple_requests_with_body(self): - orig_sock = self.sock - for x in range(3): - self.sock = self.create_socket() - self.test_send_with_body() - self.sock.close() - self.sock = orig_sock - - def test_multiple_requests_without_body(self): - orig_sock = self.sock - for x in range(3): - self.sock = self.create_socket() - self.test_send_empty_body() - self.sock.close() - self.sock = orig_sock - - def test_without_crlf(self): - data = "Echo\r\nthis\r\nplease" - s = tobytes( - "GET / HTTP/1.0\r\n" - "Connection: close\r\n" - "Content-Length: %d\r\n" - "\r\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(int(echo.content_length), len(data)) - self.assertEqual(len(echo.body), len(data)) - self.assertEqual(echo.body, tobytes(data)) - - def test_large_body(self): - # 1024 characters. 
- body = "This string has 32 characters.\r\n" * 32 - s = tobytes( - "GET / HTTP/1.0\r\nContent-Length: %d\r\n\r\n%s" % (len(body), body) - ) - self.connect() - self.sock.send(s) - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(echo.content_length, "1024") - self.assertEqual(echo.body, tobytes(body)) - - def test_many_clients(self): - conns = [] - for n in range(50): - h = self.make_http_connection() - h.request("GET", "/", headers={"Accept": "text/plain"}) - conns.append(h) - responses = [] - for h in conns: - response = h.getresponse() - self.assertEqual(response.status, 200) - responses.append(response) - for response in responses: - response.read() - for h in conns: - h.close() - - def test_chunking_request_without_content(self): - header = tobytes("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n") - self.connect() - self.sock.send(header) - self.sock.send(b"0\r\n\r\n") - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - self.assertEqual(echo.body, b"") - self.assertEqual(echo.content_length, "0") - self.assertFalse("transfer-encoding" in headers) - - def test_chunking_request_with_content(self): - control_line = b"20;\r\n" # 20 hex = 32 dec - s = b"This string has 32 characters.\r\n" - expected = s * 12 - header = tobytes("GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n") - self.connect() - self.sock.send(header) - fp = self.sock.makefile("rb", 0) - for n in range(12): - self.sock.send(control_line) - self.sock.send(s) - self.sock.send(b"\r\n") # End the chunk - self.sock.send(b"0\r\n\r\n") - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.1") - self.assertEqual(echo.body, expected) - self.assertEqual(echo.content_length, str(len(expected))) - self.assertFalse("transfer-encoding" in headers) - - def test_broken_chunked_encoding(self): - control_line = "20;\r\n" # 20 hex = 32 dec - s = "This string has 32 characters.\r\n" - to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" - to_send += control_line + s + "\r\n" - # garbage in input - to_send += "garbage\r\n" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - # receiver caught garbage and turned it into a 400 - self.assertline(line, "400", "Bad Request", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - self.assertEqual( - sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] - ) - self.assertEqual(headers["content-type"], "text/plain") - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_broken_chunked_encoding_missing_chunk_end(self): - control_line = "20;\r\n" # 20 hex = 32 dec - s = "This string has 32 characters.\r\n" - to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n" - to_send += control_line + s - # garbage in input - to_send += "garbage" - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, response_body = read_http(fp) - # receiver caught garbage and turned it into a 400 - self.assertline(line, "400", "Bad Request", "HTTP/1.1") - cl = int(headers["content-length"]) - self.assertEqual(cl, len(response_body)) - self.assertTrue(b"Chunk not properly 
terminated" in response_body) - self.assertEqual( - sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"] - ) - self.assertEqual(headers["content-type"], "text/plain") - # connection has been closed - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_keepalive_http_10(self): - # Handling of Keep-Alive within HTTP 1.0 - data = "Default: Don't keep me alive" - s = tobytes( - "GET / HTTP/1.0\r\nContent-Length: %d\r\n\r\n%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - connection = response.getheader("Connection", "") - # We sent no Connection: Keep-Alive header - # Connection: close (or no header) is default. - self.assertTrue(connection != "Keep-Alive") - - def test_keepalive_http10_explicit(self): - # If header Connection: Keep-Alive is explicitly sent, - # we want to keept the connection open, we also need to return - # the corresponding header - data = "Keep me alive" - s = tobytes( - "GET / HTTP/1.0\r\n" - "Connection: Keep-Alive\r\n" - "Content-Length: %d\r\n" - "\r\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - connection = response.getheader("Connection", "") - self.assertEqual(connection, "Keep-Alive") - - def test_keepalive_http_11(self): - # Handling of Keep-Alive within HTTP 1.1 - - # All connections are kept alive, unless stated otherwise - data = "Default: Keep me alive" - s = tobytes( - "GET / HTTP/1.1\r\nContent-Length: %d\r\n\r\n%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - self.assertTrue(response.getheader("connection") != "close") - - def test_keepalive_http11_explicit(self): - # Explicitly set keep-alive - data = "Default: Keep me alive" - s = tobytes( - "GET / HTTP/1.1\r\n" - "Connection: keep-alive\r\n" - "Content-Length: %d\r\n" - "\r\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - self.assertTrue(response.getheader("connection") != "close") - - def test_keepalive_http11_connclose(self): - # specifying Connection: close explicitly - data = "Don't keep me alive" - s = tobytes( - "GET / HTTP/1.1\r\n" - "Connection: close\r\n" - "Content-Length: %d\r\n" - "\r\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(s) - response = httplib.HTTPResponse(self.sock) - response.begin() - self.assertEqual(int(response.status), 200) - self.assertEqual(response.getheader("connection"), "close") - - def test_proxy_headers(self): - to_send = ( - "GET / HTTP/1.0\r\n" - "Content-Length: 0\r\n" - "Host: www.google.com:8080\r\n" - "X-Forwarded-For: 192.168.1.1\r\n" - "X-Forwarded-Proto: https\r\n" - "X-Forwarded-Port: 5000\r\n\r\n" - ) - to_send = tobytes(to_send) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line, headers, echo = self._read_echo(fp) - self.assertline(line, "200", "OK", "HTTP/1.0") - self.assertEqual(headers.get("server"), "waitress") - self.assertTrue(headers.get("date")) - self.assertIsNone(echo.headers.get("X_FORWARDED_PORT")) - self.assertEqual(echo.headers["HOST"], "www.google.com:8080") - 
self.assertEqual(echo.scheme, "https") - self.assertEqual(echo.remote_addr, "192.168.1.1") - self.assertEqual(echo.remote_host, "192.168.1.1") - - -class PipeliningTests(object): - def setUp(self): - from waitress.tests.fixtureapps import echo - - self.start_subprocess(echo.app_body_only) - - def tearDown(self): - self.stop_subprocess() - - def test_pipelining(self): - s = ( - "GET / HTTP/1.0\r\n" - "Connection: %s\r\n" - "Content-Length: %d\r\n" - "\r\n" - "%s" - ) - to_send = b"" - count = 25 - for n in range(count): - body = "Response #%d\r\n" % (n + 1) - if n + 1 < count: - conn = "keep-alive" - else: - conn = "close" - to_send += tobytes(s % (conn, len(body), body)) - - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - for n in range(count): - expect_body = tobytes("Response #%d\r\n" % (n + 1)) - line = fp.readline() # status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - headers = parse_headers(fp) - length = int(headers.get("content-length")) or None - response_body = fp.read(length) - self.assertEqual(int(status), 200) - self.assertEqual(length, len(response_body)) - self.assertEqual(response_body, expect_body) - - -class ExpectContinueTests(object): - def setUp(self): - from waitress.tests.fixtureapps import echo - - self.start_subprocess(echo.app_body_only) - - def tearDown(self): - self.stop_subprocess() - - def test_expect_continue(self): - # specifying Connection: close explicitly - data = "I have expectations" - to_send = tobytes( - "GET / HTTP/1.1\r\n" - "Connection: close\r\n" - "Content-Length: %d\r\n" - "Expect: 100-continue\r\n" - "\r\n" - "%s" % (len(data), data) - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line = fp.readline() # continue status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - self.assertEqual(int(status), 100) - self.assertEqual(reason, b"Continue") - self.assertEqual(version, b"HTTP/1.1") - fp.readline() # blank line - line = fp.readline() # next status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - headers = parse_headers(fp) - length = int(headers.get("content-length")) or None - response_body = fp.read(length) - self.assertEqual(int(status), 200) - self.assertEqual(length, len(response_body)) - self.assertEqual(response_body, tobytes(data)) - - -class BadContentLengthTests(object): - def setUp(self): - from waitress.tests.fixtureapps import badcl - - self.start_subprocess(badcl.app) - - def tearDown(self): - self.stop_subprocess() - - def test_short_body(self): - # check to see if server closes connection when body is too short - # for cl header - to_send = tobytes( - "GET /short_body HTTP/1.0\r\n" - "Connection: Keep-Alive\r\n" - "Content-Length: 0\r\n" - "\r\n" - ) - self.connect() - self.sock.send(to_send) - fp = self.sock.makefile("rb", 0) - line = fp.readline() # status line - version, status, reason = (x.strip() for x in line.split(None, 2)) - headers = parse_headers(fp) - content_length = int(headers.get("content-length")) - response_body = fp.read(content_length) - self.assertEqual(int(status), 200) - self.assertNotEqual(content_length, len(response_body)) - self.assertEqual(len(response_body), content_length - 1) - self.assertEqual(response_body, tobytes("abcdefghi")) - # remote closed connection (despite keepalive header); not sure why - # first send succeeds - self.send_check_error(to_send) - self.assertRaises(ConnectionClosed, read_http, fp) - - def test_long_body(self): - # check server 
-class BadContentLengthTests(object):
-    def setUp(self):
-        from waitress.tests.fixtureapps import badcl
-
-        self.start_subprocess(badcl.app)
-
-    def tearDown(self):
-        self.stop_subprocess()
-
-    def test_short_body(self):
-        # check to see if server closes connection when body is too short
-        # for cl header
-        to_send = tobytes(
-            "GET /short_body HTTP/1.0\r\n"
-            "Connection: Keep-Alive\r\n"
-            "Content-Length: 0\r\n"
-            "\r\n"
-        )
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line = fp.readline()  # status line
-        version, status, reason = (x.strip() for x in line.split(None, 2))
-        headers = parse_headers(fp)
-        content_length = int(headers.get("content-length"))
-        response_body = fp.read(content_length)
-        self.assertEqual(int(status), 200)
-        self.assertNotEqual(content_length, len(response_body))
-        self.assertEqual(len(response_body), content_length - 1)
-        self.assertEqual(response_body, tobytes("abcdefghi"))
-        # remote closed connection (despite keepalive header); not sure why
-        # first send succeeds
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_long_body(self):
-        # check server doesnt close connection when body is too long
-        # for cl header
-        to_send = tobytes(
-            "GET /long_body HTTP/1.0\r\n"
-            "Connection: Keep-Alive\r\n"
-            "Content-Length: 0\r\n"
-            "\r\n"
-        )
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line = fp.readline()  # status line
-        version, status, reason = (x.strip() for x in line.split(None, 2))
-        headers = parse_headers(fp)
-        content_length = int(headers.get("content-length")) or None
-        response_body = fp.read(content_length)
-        self.assertEqual(int(status), 200)
-        self.assertEqual(content_length, len(response_body))
-        self.assertEqual(response_body, tobytes("abcdefgh"))
-        # remote does not close connection (keepalive header)
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line = fp.readline()  # status line
-        version, status, reason = (x.strip() for x in line.split(None, 2))
-        headers = parse_headers(fp)
-        content_length = int(headers.get("content-length")) or None
-        response_body = fp.read(content_length)
-        self.assertEqual(int(status), 200)
-
-
-class NoContentLengthTests(object):
-    def setUp(self):
-        from waitress.tests.fixtureapps import nocl
-
-        self.start_subprocess(nocl.app)
-
-    def tearDown(self):
-        self.stop_subprocess()
-
-    def test_http10_generator(self):
-        body = string.ascii_letters
-        to_send = (
-            "GET / HTTP/1.0\r\n"
-            "Connection: Keep-Alive\r\n"
-            "Content-Length: %d\r\n\r\n" % len(body)
-        )
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-        self.assertEqual(headers.get("content-length"), None)
-        self.assertEqual(headers.get("connection"), "close")
-        self.assertEqual(response_body, tobytes(body))
-        # remote closed connection (despite keepalive header), because
-        # generators cannot have a content-length divined
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_http10_list(self):
-        body = string.ascii_letters
-        to_send = (
-            "GET /list HTTP/1.0\r\n"
-            "Connection: Keep-Alive\r\n"
-            "Content-Length: %d\r\n\r\n" % len(body)
-        )
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-        self.assertEqual(headers["content-length"], str(len(body)))
-        self.assertEqual(headers.get("connection"), "Keep-Alive")
-        self.assertEqual(response_body, tobytes(body))
-        # remote keeps connection open because it divined the content length
-        # from a length-1 list
-        self.sock.send(to_send)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-
-    def test_http10_listlentwo(self):
-        body = string.ascii_letters
-        to_send = (
-            "GET /list_lentwo HTTP/1.0\r\n"
-            "Connection: Keep-Alive\r\n"
-            "Content-Length: %d\r\n\r\n" % len(body)
-        )
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-        self.assertEqual(headers.get("content-length"), None)
-        self.assertEqual(headers.get("connection"), "close")
-        self.assertEqual(response_body, tobytes(body))
-        # remote closed connection (despite keepalive header), because
-        # lists of length > 1 cannot have their content length divined
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_http11_generator(self):
-        body = string.ascii_letters
-        to_send = "GET / HTTP/1.1\r\nContent-Length: %s\r\n\r\n" % len(body)
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb")
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.1")
-        expected = b""
-        for chunk in chunks(body, 10):
-            expected += tobytes(
-                "%s\r\n%s\r\n" % (str(hex(len(chunk))[2:].upper()), chunk)
-            )
-        expected += b"0\r\n\r\n"
-        self.assertEqual(response_body, expected)
-        # connection is always closed at the end of a chunked response
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_http11_list(self):
-        body = string.ascii_letters
-        to_send = "GET /list HTTP/1.1\r\nContent-Length: %d\r\n\r\n" % len(body)
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.1")
-        self.assertEqual(headers["content-length"], str(len(body)))
-        self.assertEqual(response_body, tobytes(body))
-        # remote keeps connection open because it divined the content length
-        # from a length-1 list
-        self.sock.send(to_send)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.1")
-
-    def test_http11_listlentwo(self):
-        body = string.ascii_letters
-        to_send = "GET /list_lentwo HTTP/1.1\r\nContent-Length: %s\r\n\r\n" % len(body)
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb")
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.1")
-        expected = b""
-        for chunk in (body[0], body[1:]):
-            expected += tobytes(
-                "%s\r\n%s\r\n" % (str(hex(len(chunk))[2:].upper()), chunk)
-            )
-        expected += b"0\r\n\r\n"
-        self.assertEqual(response_body, expected)
-        # connection is always closed at the end of a chunked response
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-
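The HTTP/1.1 generator tests above assert waitress's chunked transfer-encoding framing: each chunk is written as its length in uppercase hex, CRLF, the payload, CRLF, with a zero-length chunk terminating the body. A minimal standalone sketch of that framing, mirroring the `expected` construction in the tests (illustrative only):

# Illustrative sketch: the chunked framing the deleted tests assert.
def encode_chunked(data: bytes, size: int = 10) -> bytes:
    out = b""
    for i in range(0, len(data), size):
        chunk = data[i : i + size]
        out += b"%X\r\n%s\r\n" % (len(chunk), chunk)  # hex size, CRLF, payload, CRLF
    return out + b"0\r\n\r\n"  # zero-length chunk ends the body

assert encode_chunked(b"abcdefghijkl") == b"A\r\nabcdefghij\r\n2\r\nkl\r\n0\r\n\r\n"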
-class WriteCallbackTests(object):
-    def setUp(self):
-        from waitress.tests.fixtureapps import writecb
-
-        self.start_subprocess(writecb.app)
-
-    def tearDown(self):
-        self.stop_subprocess()
-
-    def test_short_body(self):
-        # check to see if server closes connection when body is too short
-        # for cl header
-        to_send = tobytes(
-            "GET /short_body HTTP/1.0\r\n"
-            "Connection: Keep-Alive\r\n"
-            "Content-Length: 0\r\n"
-            "\r\n"
-        )
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        # server trusts the content-length header (5)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, 9)
-        self.assertNotEqual(cl, len(response_body))
-        self.assertEqual(len(response_body), cl - 1)
-        self.assertEqual(response_body, tobytes("abcdefgh"))
-        # remote closed connection (despite keepalive header)
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_long_body(self):
-        # check server doesnt close connection when body is too long
-        # for cl header
-        to_send = tobytes(
-            "GET /long_body HTTP/1.0\r\n"
-            "Connection: Keep-Alive\r\n"
-            "Content-Length: 0\r\n"
-            "\r\n"
-        )
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        content_length = int(headers.get("content-length")) or None
-        self.assertEqual(content_length, 9)
-        self.assertEqual(content_length, len(response_body))
-        self.assertEqual(response_body, tobytes("abcdefghi"))
-        # remote does not close connection (keepalive header)
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-
-    def test_equal_body(self):
-        # check server doesnt close connection when body is equal to
-        # cl header
-        to_send = tobytes(
-            "GET /equal_body HTTP/1.0\r\n"
-            "Connection: Keep-Alive\r\n"
-            "Content-Length: 0\r\n"
-            "\r\n"
-        )
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        content_length = int(headers.get("content-length")) or None
-        self.assertEqual(content_length, 9)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-        self.assertEqual(content_length, len(response_body))
-        self.assertEqual(response_body, tobytes("abcdefghi"))
-        # remote does not close connection (keepalive header)
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-
-    def test_no_content_length(self):
-        # wtf happens when there's no content-length
-        to_send = tobytes(
-            "GET /no_content_length HTTP/1.0\r\n"
-            "Connection: Keep-Alive\r\n"
-            "Content-Length: 0\r\n"
-            "\r\n"
-        )
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line = fp.readline()  # status line
-        line, headers, response_body = read_http(fp)
-        content_length = headers.get("content-length")
-        self.assertEqual(content_length, None)
-        self.assertEqual(response_body, tobytes("abcdefghi"))
-        # remote closed connection (despite keepalive header)
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-
-class TooLargeTests(object):
-
-    toobig = 1050
-
-    def setUp(self):
-        from waitress.tests.fixtureapps import toolarge
-
-        self.start_subprocess(
-            toolarge.app, max_request_header_size=1000, max_request_body_size=1000
-        )
-
-    def tearDown(self):
-        self.stop_subprocess()
-
-    def test_request_body_too_large_with_wrong_cl_http10(self):
-        body = "a" * self.toobig
-        to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\n\r\n"
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb")
-        # first request succeeds (content-length 5)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        # server trusts the content-length header; no pipelining,
-        # so request fulfilled, extra bytes are thrown away
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_request_body_too_large_with_wrong_cl_http10_keepalive(self):
-        body = "a" * self.toobig
-        to_send = "GET / HTTP/1.0\r\nContent-Length: 5\r\nConnection: Keep-Alive\r\n\r\n"
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb")
-        # first request succeeds (content-length 5)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_request_body_too_large_with_no_cl_http10(self):
-        body = "a" * self.toobig
-        to_send = "GET / HTTP/1.0\r\n\r\n"
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        # extra bytes are thrown away (no pipelining), connection closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_request_body_too_large_with_no_cl_http10_keepalive(self):
-        body = "a" * self.toobig
-        to_send = "GET / HTTP/1.0\r\nConnection: Keep-Alive\r\n\r\n"
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        # server trusts the content-length header (assumed zero)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        line, headers, response_body = read_http(fp)
-        # next response overruns because the extra data appears to be
-        # header data
-        self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_request_body_too_large_with_wrong_cl_http11(self):
-        body = "a" * self.toobig
-        to_send = "GET / HTTP/1.1\r\nContent-Length: 5\r\n\r\n"
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb")
-        # first request succeeds (content-length 5)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.1")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        # second response is an error response
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_request_body_too_large_with_wrong_cl_http11_connclose(self):
-        body = "a" * self.toobig
-        to_send = "GET / HTTP/1.1\r\nContent-Length: 5\r\nConnection: close\r\n\r\n"
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        # server trusts the content-length header (5)
-        self.assertline(line, "200", "OK", "HTTP/1.1")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_request_body_too_large_with_no_cl_http11(self):
-        body = "a" * self.toobig
-        to_send = "GET / HTTP/1.1\r\n\r\n"
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb")
-        # server trusts the content-length header (assumed 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.1")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        # server assumes pipelined requests due to http/1.1, and the first
-        # request was assumed c-l 0 because it had no content-length header,
-        # so entire body looks like the header of the subsequent request
-        # second response is an error response
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "431", "Request Header Fields Too Large", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_request_body_too_large_with_no_cl_http11_connclose(self):
-        body = "a" * self.toobig
-        to_send = "GET / HTTP/1.1\r\nConnection: close\r\n\r\n"
-        to_send += body
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        # server trusts the content-length header (assumed 0)
-        self.assertline(line, "200", "OK", "HTTP/1.1")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_request_body_too_large_chunked_encoding(self):
-        control_line = "20;\r\n"  # 20 hex = 32 dec
-        s = "This string has 32 characters.\r\n"
-        to_send = "GET / HTTP/1.1\r\nTransfer-Encoding: chunked\r\n\r\n"
-        repeat = control_line + s
-        to_send += repeat * ((self.toobig // len(repeat)) + 1)
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        # body bytes counter caught a max_request_body_size overrun
-        self.assertline(line, "413", "Request Entity Too Large", "HTTP/1.1")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        self.assertEqual(headers["content-type"], "text/plain")
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-
-class InternalServerErrorTests(object):
-    def setUp(self):
-        from waitress.tests.fixtureapps import error
-
-        self.start_subprocess(error.app, expose_tracebacks=True)
-
-    def tearDown(self):
-        self.stop_subprocess()
-
-    def test_before_start_response_http_10(self):
-        to_send = "GET /before_start_response HTTP/1.0\r\n\r\n"
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "500", "Internal Server Error", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        self.assertTrue(response_body.startswith(b"Internal Server Error"))
-        self.assertEqual(headers["connection"], "close")
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_before_start_response_http_11(self):
-        to_send = "GET /before_start_response HTTP/1.1\r\n\r\n"
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "500", "Internal Server Error", "HTTP/1.1")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        self.assertTrue(response_body.startswith(b"Internal Server Error"))
-        self.assertEqual(
-            sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"]
-        )
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_before_start_response_http_11_close(self):
-        to_send = tobytes(
-            "GET /before_start_response HTTP/1.1\r\nConnection: close\r\n\r\n"
-        )
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "500", "Internal Server Error", "HTTP/1.1")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        self.assertTrue(response_body.startswith(b"Internal Server Error"))
-        self.assertEqual(
-            sorted(headers.keys()),
-            ["connection", "content-length", "content-type", "date", "server"],
-        )
-        self.assertEqual(headers["connection"], "close")
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_after_start_response_http10(self):
-        to_send = "GET /after_start_response HTTP/1.0\r\n\r\n"
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "500", "Internal Server Error", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        self.assertTrue(response_body.startswith(b"Internal Server Error"))
-        self.assertEqual(
-            sorted(headers.keys()),
-            ["connection", "content-length", "content-type", "date", "server"],
-        )
-        self.assertEqual(headers["connection"], "close")
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_after_start_response_http11(self):
-        to_send = "GET /after_start_response HTTP/1.1\r\n\r\n"
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "500", "Internal Server Error", "HTTP/1.1")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        self.assertTrue(response_body.startswith(b"Internal Server Error"))
-        self.assertEqual(
-            sorted(headers.keys()), ["connection", "content-length", "content-type", "date", "server"]
-        )
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_after_start_response_http11_close(self):
-        to_send = tobytes(
-            "GET /after_start_response HTTP/1.1\r\nConnection: close\r\n\r\n"
-        )
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "500", "Internal Server Error", "HTTP/1.1")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        self.assertTrue(response_body.startswith(b"Internal Server Error"))
-        self.assertEqual(
-            sorted(headers.keys()),
-            ["connection", "content-length", "content-type", "date", "server"],
-        )
-        self.assertEqual(headers["connection"], "close")
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_after_write_cb(self):
-        to_send = "GET /after_write_cb HTTP/1.1\r\n\r\n"
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.1")
-        self.assertEqual(response_body, b"")
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_in_generator(self):
-        to_send = "GET /in_generator HTTP/1.1\r\n\r\n"
-        to_send = tobytes(to_send)
-        self.connect()
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.1")
-        self.assertEqual(response_body, b"")
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-
-class FileWrapperTests(object):
-    def setUp(self):
-        from waitress.tests.fixtureapps import filewrapper
-
-        self.start_subprocess(filewrapper.app)
-
-    def tearDown(self):
-        self.stop_subprocess()
-
-    def test_filelike_http11(self):
-        to_send = "GET /filelike HTTP/1.1\r\n\r\n"
-        to_send = tobytes(to_send)
-
-        self.connect()
-
-        for t in range(0, 2):
-            self.sock.send(to_send)
-            fp = self.sock.makefile("rb", 0)
-            line, headers, response_body = read_http(fp)
-            self.assertline(line, "200", "OK", "HTTP/1.1")
-            cl = int(headers["content-length"])
-            self.assertEqual(cl, len(response_body))
-            ct = headers["content-type"]
-            self.assertEqual(ct, "image/jpeg")
-            self.assertTrue(b"\377\330\377" in response_body)
-            # connection has not been closed
-
-    def test_filelike_nocl_http11(self):
-        to_send = "GET /filelike_nocl HTTP/1.1\r\n\r\n"
-        to_send = tobytes(to_send)
-
-        self.connect()
-
-        for t in range(0, 2):
-            self.sock.send(to_send)
-            fp = self.sock.makefile("rb", 0)
-            line, headers, response_body = read_http(fp)
-            self.assertline(line, "200", "OK", "HTTP/1.1")
-            cl = int(headers["content-length"])
-            self.assertEqual(cl, len(response_body))
-            ct = headers["content-type"]
-            self.assertEqual(ct, "image/jpeg")
-            self.assertTrue(b"\377\330\377" in response_body)
-            # connection has not been closed
-
-    def test_filelike_shortcl_http11(self):
-        to_send = "GET /filelike_shortcl HTTP/1.1\r\n\r\n"
-        to_send = tobytes(to_send)
-
-        self.connect()
-
-        for t in range(0, 2):
-            self.sock.send(to_send)
-            fp = self.sock.makefile("rb", 0)
-            line, headers, response_body = read_http(fp)
-            self.assertline(line, "200", "OK", "HTTP/1.1")
-            cl = int(headers["content-length"])
-            self.assertEqual(cl, 1)
-            self.assertEqual(cl, len(response_body))
-            ct = headers["content-type"]
-            self.assertEqual(ct, "image/jpeg")
-            self.assertTrue(b"\377" in response_body)
-            # connection has not been closed
-
-    def test_filelike_longcl_http11(self):
-        to_send = "GET /filelike_longcl HTTP/1.1\r\n\r\n"
-        to_send = tobytes(to_send)
-
-        self.connect()
-
-        for t in range(0, 2):
-            self.sock.send(to_send)
-            fp = self.sock.makefile("rb", 0)
-            line, headers, response_body = read_http(fp)
-            self.assertline(line, "200", "OK", "HTTP/1.1")
-            cl = int(headers["content-length"])
-            self.assertEqual(cl, len(response_body))
-            ct = headers["content-type"]
-            self.assertEqual(ct, "image/jpeg")
-            self.assertTrue(b"\377\330\377" in response_body)
-            # connection has not been closed
-
-    def test_notfilelike_http11(self):
-        to_send = "GET /notfilelike HTTP/1.1\r\n\r\n"
-        to_send = tobytes(to_send)
-
-        self.connect()
-
-        for t in range(0, 2):
-            self.sock.send(to_send)
-            fp = self.sock.makefile("rb", 0)
-            line, headers, response_body = read_http(fp)
-            self.assertline(line, "200", "OK", "HTTP/1.1")
-            cl = int(headers["content-length"])
-            self.assertEqual(cl, len(response_body))
-            ct = headers["content-type"]
-            self.assertEqual(ct, "image/jpeg")
-            self.assertTrue(b"\377\330\377" in response_body)
-            # connection has not been closed
-
-    def test_notfilelike_iobase_http11(self):
-        to_send = "GET /notfilelike_iobase HTTP/1.1\r\n\r\n"
-        to_send = tobytes(to_send)
-
-        self.connect()
-
-        for t in range(0, 2):
-            self.sock.send(to_send)
-            fp = self.sock.makefile("rb", 0)
-            line, headers, response_body = read_http(fp)
-            self.assertline(line, "200", "OK", "HTTP/1.1")
-            cl = int(headers["content-length"])
-            self.assertEqual(cl, len(response_body))
-            ct = headers["content-type"]
-            self.assertEqual(ct, "image/jpeg")
-            self.assertTrue(b"\377\330\377" in response_body)
-            # connection has not been closed
-
-    def test_notfilelike_nocl_http11(self):
-        to_send = "GET /notfilelike_nocl HTTP/1.1\r\n\r\n"
-        to_send = tobytes(to_send)
-
-        self.connect()
-
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.1")
-        ct = headers["content-type"]
-        self.assertEqual(ct, "image/jpeg")
-        self.assertTrue(b"\377\330\377" in response_body)
-        # connection has been closed (no content-length)
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_notfilelike_shortcl_http11(self):
-        to_send = "GET /notfilelike_shortcl HTTP/1.1\r\n\r\n"
-        to_send = tobytes(to_send)
-
-        self.connect()
-
-        for t in range(0, 2):
-            self.sock.send(to_send)
-            fp = self.sock.makefile("rb", 0)
-            line, headers, response_body = read_http(fp)
-            self.assertline(line, "200", "OK", "HTTP/1.1")
-            cl = int(headers["content-length"])
-            self.assertEqual(cl, 1)
-            self.assertEqual(cl, len(response_body))
-            ct = headers["content-type"]
-            self.assertEqual(ct, "image/jpeg")
-            self.assertTrue(b"\377" in response_body)
-            # connection has not been closed
-
-    def test_notfilelike_longcl_http11(self):
-        to_send = "GET /notfilelike_longcl HTTP/1.1\r\n\r\n"
-        to_send = tobytes(to_send)
-
-        self.connect()
-
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.1")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body) + 10)
-        ct = headers["content-type"]
-        self.assertEqual(ct, "image/jpeg")
-        self.assertTrue(b"\377\330\377" in response_body)
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_filelike_http10(self):
-        to_send = "GET /filelike HTTP/1.0\r\n\r\n"
-        to_send = tobytes(to_send)
-
-        self.connect()
-
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        ct = headers["content-type"]
-        self.assertEqual(ct, "image/jpeg")
-        self.assertTrue(b"\377\330\377" in response_body)
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_filelike_nocl_http10(self):
-        to_send = "GET /filelike_nocl HTTP/1.0\r\n\r\n"
-        to_send = tobytes(to_send)
-
-        self.connect()
-
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        ct = headers["content-type"]
-        self.assertEqual(ct, "image/jpeg")
-        self.assertTrue(b"\377\330\377" in response_body)
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_notfilelike_http10(self):
-        to_send = "GET /notfilelike HTTP/1.0\r\n\r\n"
-        to_send = tobytes(to_send)
-
-        self.connect()
-
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-        cl = int(headers["content-length"])
-        self.assertEqual(cl, len(response_body))
-        ct = headers["content-type"]
-        self.assertEqual(ct, "image/jpeg")
-        self.assertTrue(b"\377\330\377" in response_body)
-        # connection has been closed
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-    def test_notfilelike_nocl_http10(self):
-        to_send = "GET /notfilelike_nocl HTTP/1.0\r\n\r\n"
-        to_send = tobytes(to_send)
-
-        self.connect()
-
-        self.sock.send(to_send)
-        fp = self.sock.makefile("rb", 0)
-        line, headers, response_body = read_http(fp)
-        self.assertline(line, "200", "OK", "HTTP/1.0")
-        ct = headers["content-type"]
-        self.assertEqual(ct, "image/jpeg")
-        self.assertTrue(b"\377\330\377" in response_body)
-        # connection has been closed (no content-length)
-        self.send_check_error(to_send)
-        self.assertRaises(ConnectionClosed, read_http, fp)
-
-
-class TcpEchoTests(EchoTests, TcpTests, unittest.TestCase):
-    pass
-
-
-class TcpPipeliningTests(PipeliningTests, TcpTests, unittest.TestCase):
-    pass
-
-
-class TcpExpectContinueTests(ExpectContinueTests, TcpTests, unittest.TestCase):
-    pass
-
-
-class TcpBadContentLengthTests(BadContentLengthTests, TcpTests, unittest.TestCase):
-    pass
-
-
-class TcpNoContentLengthTests(NoContentLengthTests, TcpTests, unittest.TestCase):
-    pass
-
-
-class TcpWriteCallbackTests(WriteCallbackTests, TcpTests, unittest.TestCase):
-    pass
-
-
-class TcpTooLargeTests(TooLargeTests, TcpTests, unittest.TestCase):
-    pass
-
-
-class TcpInternalServerErrorTests(
-    InternalServerErrorTests, TcpTests, unittest.TestCase
-):
-    pass
-
-
-class TcpFileWrapperTests(FileWrapperTests, TcpTests, unittest.TestCase):
-    pass
-
-
-if hasattr(socket, "AF_UNIX"):
-
-    class FixtureUnixWSGIServer(server.UnixWSGIServer):
-        """A version of UnixWSGIServer that relays back what it's bound to.
-        """
-
-        family = socket.AF_UNIX  # Testing
-
-        def __init__(self, application, queue, **kw):  # pragma: no cover
-            # Coverage doesn't see this as it's run in a separate process.
-            # To permit parallel testing, use a PID-dependent socket.
-            kw["unix_socket"] = "/tmp/waitress.test-%d.sock" % os.getpid()
-            super(FixtureUnixWSGIServer, self).__init__(application, **kw)
-            queue.put(self.socket.getsockname())
-
-    class UnixTests(SubprocessTests):
-
-        server = FixtureUnixWSGIServer
-
-        def make_http_connection(self):
-            return UnixHTTPConnection(self.bound_to)
-
-        def stop_subprocess(self):
-            super(UnixTests, self).stop_subprocess()
-            cleanup_unix_socket(self.bound_to)
-
-        def send_check_error(self, to_send):
-            # Unlike inet domain sockets, Unix domain sockets can trigger a
-            # 'Broken pipe' error when the socket is closed.
-            try:
-                self.sock.send(to_send)
-            except socket.error as exc:
-                self.assertEqual(get_errno(exc), errno.EPIPE)
-
-    class UnixEchoTests(EchoTests, UnixTests, unittest.TestCase):
-        pass
-
-    class UnixPipeliningTests(PipeliningTests, UnixTests, unittest.TestCase):
-        pass
-
-    class UnixExpectContinueTests(ExpectContinueTests, UnixTests, unittest.TestCase):
-        pass
-
-    class UnixBadContentLengthTests(
-        BadContentLengthTests, UnixTests, unittest.TestCase
-    ):
-        pass
-
-    class UnixNoContentLengthTests(NoContentLengthTests, UnixTests, unittest.TestCase):
-        pass
-
-    class UnixWriteCallbackTests(WriteCallbackTests, UnixTests, unittest.TestCase):
-        pass
-
-    class UnixTooLargeTests(TooLargeTests, UnixTests, unittest.TestCase):
-        pass
-
-    class UnixInternalServerErrorTests(
-        InternalServerErrorTests, UnixTests, unittest.TestCase
-    ):
-        pass
-
-    class UnixFileWrapperTests(FileWrapperTests, UnixTests, unittest.TestCase):
-        pass
-
-
-def parse_headers(fp):
-    """Parses only RFC2822 headers from a file pointer.
-    """
-    headers = {}
-    while True:
-        line = fp.readline()
-        if line in (b"\r\n", b"\n", b""):
-            break
-        line = line.decode("iso-8859-1")
-        name, value = line.strip().split(":", 1)
-        headers[name.lower().strip()] = value.lower().strip()
-    return headers
-
-
-class UnixHTTPConnection(httplib.HTTPConnection):
-    """Patched version of HTTPConnection that uses Unix domain sockets.
-    """
-
-    def __init__(self, path):
-        httplib.HTTPConnection.__init__(self, "localhost")
-        self.path = path
-
-    def connect(self):
-        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
-        sock.connect(self.path)
-        self.sock = sock
-
-
-class ConnectionClosed(Exception):
-    pass
-
-
-# stolen from gevent
-def read_http(fp):  # pragma: no cover
-    try:
-        response_line = fp.readline()
-    except socket.error as exc:
-        fp.close()
-        # errno 104 is ENOTRECOVERABLE, In WinSock 10054 is ECONNRESET
-        if get_errno(exc) in (errno.ECONNABORTED, errno.ECONNRESET, 104, 10054):
-            raise ConnectionClosed
-        raise
-    if not response_line:
-        raise ConnectionClosed
-
-    header_lines = []
-    while True:
-        line = fp.readline()
-        if line in (b"\r\n", b"\n", b""):
-            break
-        else:
-            header_lines.append(line)
-    headers = dict()
-    for x in header_lines:
-        x = x.strip()
-        if not x:
-            continue
-        key, value = x.split(b": ", 1)
-        key = key.decode("iso-8859-1").lower()
-        value = value.decode("iso-8859-1")
-        assert key not in headers, "%s header duplicated" % key
-        headers[key] = value
-
-    if "content-length" in headers:
-        num = int(headers["content-length"])
-        body = b""
-        left = num
-        while left > 0:
-            data = fp.read(left)
-            if not data:
-                break
-            body += data
-            left -= len(data)
-    else:
-        # read until EOF
-        body = fp.read()
-
-    return response_line, headers, body
-
-
-# stolen from gevent
-def get_errno(exc):  # pragma: no cover
-    """ Get the error code out of socket.error objects.
-    socket.error in <2.5 does not have errno attribute
-    socket.error in 3.x does not allow indexing access
-    e.args[0] works for all.
-    There are cases when args[0] is not errno.
-    i.e. http://bugs.python.org/issue6471
-    Maybe there are cases when errno is set, but it is not the first argument?
-    """
-    try:
-        if exc.errno is not None:
-            return exc.errno
-    except AttributeError:
-        pass
-    try:
-        return exc.args[0]
-    except IndexError:
-        return None
-
-
-def chunks(l, n):
-    """ Yield successive n-sized chunks from l.
-    """
-    for i in range(0, len(l), n):
-        yield l[i : i + n]
diff --git a/libs/waitress/tests/test_init.py b/libs/waitress/tests/test_init.py
deleted file mode 100644
index f9b91d762..000000000
--- a/libs/waitress/tests/test_init.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import unittest
-
-
-class Test_serve(unittest.TestCase):
-    def _callFUT(self, app, **kw):
-        from waitress import serve
-
-        return serve(app, **kw)
-
-    def test_it(self):
-        server = DummyServerFactory()
-        app = object()
-        result = self._callFUT(app, _server=server, _quiet=True)
-        self.assertEqual(server.app, app)
-        self.assertEqual(result, None)
-        self.assertEqual(server.ran, True)
-
-
-class Test_serve_paste(unittest.TestCase):
-    def _callFUT(self, app, **kw):
-        from waitress import serve_paste
-
-        return serve_paste(app, None, **kw)
-
-    def test_it(self):
-        server = DummyServerFactory()
-        app = object()
-        result = self._callFUT(app, _server=server, _quiet=True)
-        self.assertEqual(server.app, app)
-        self.assertEqual(result, 0)
-        self.assertEqual(server.ran, True)
-
-
-class DummyServerFactory(object):
-    ran = False
-
-    def __call__(self, app, **kw):
-        self.adj = DummyAdj(kw)
-        self.app = app
-        self.kw = kw
-        return self
-
-    def run(self):
-        self.ran = True
-
-
-class DummyAdj(object):
-    verbose = False
-
-    def __init__(self, kw):
-        self.__dict__.update(kw)
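The deleted test_init.py above covered the serve()/serve_paste() entry points by injecting a dummy server factory. For reference, the public API those tests exercise boils down to this (a minimal usage sketch; the app and address are placeholders, not part of the diff):

# Minimal usage sketch of the entry point the deleted tests exercised.
from waitress import serve

def app(environ, start_response):
    # trivial WSGI app used only for illustration
    start_response("200 OK", [("Content-Type", "text/plain")])
    return [b"Hello from waitress\n"]

if __name__ == "__main__":
    serve(app, host="127.0.0.1", port=8080)  # blocks and runs the server loop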
-# -############################################################################## -"""HTTP Request Parser tests -""" -import unittest - -from waitress.compat import text_, tobytes - - -class TestHTTPRequestParser(unittest.TestCase): - def setUp(self): - from waitress.parser import HTTPRequestParser - from waitress.adjustments import Adjustments - - my_adj = Adjustments() - self.parser = HTTPRequestParser(my_adj) - - def test_get_body_stream_None(self): - self.parser.body_recv = None - result = self.parser.get_body_stream() - self.assertEqual(result.getvalue(), b"") - - def test_get_body_stream_nonNone(self): - body_rcv = DummyBodyStream() - self.parser.body_rcv = body_rcv - result = self.parser.get_body_stream() - self.assertEqual(result, body_rcv) - - def test_received_get_no_headers(self): - data = b"HTTP/1.0 GET /foobar\r\n\r\n" - result = self.parser.received(data) - self.assertEqual(result, 24) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.headers, {}) - - def test_received_bad_host_header(self): - from waitress.utilities import BadRequest - - data = b"HTTP/1.0 GET /foobar\r\n Host: foo\r\n\r\n" - result = self.parser.received(data) - self.assertEqual(result, 36) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.error.__class__, BadRequest) - - def test_received_bad_transfer_encoding(self): - from waitress.utilities import ServerNotImplemented - - data = ( - b"GET /foobar HTTP/1.1\r\n" - b"Transfer-Encoding: foo\r\n" - b"\r\n" - b"1d;\r\n" - b"This string has 29 characters\r\n" - b"0\r\n\r\n" - ) - result = self.parser.received(data) - self.assertEqual(result, 48) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.error.__class__, ServerNotImplemented) - - def test_received_nonsense_nothing(self): - data = b"\r\n\r\n" - result = self.parser.received(data) - self.assertEqual(result, 4) - self.assertTrue(self.parser.completed) - self.assertEqual(self.parser.headers, {}) - - def test_received_no_doublecr(self): - data = b"GET /foobar HTTP/8.4\r\n" - result = self.parser.received(data) - self.assertEqual(result, 22) - self.assertFalse(self.parser.completed) - self.assertEqual(self.parser.headers, {}) - - def test_received_already_completed(self): - self.parser.completed = True - result = self.parser.received(b"a") - self.assertEqual(result, 0) - - def test_received_cl_too_large(self): - from waitress.utilities import RequestEntityTooLarge - - self.parser.adj.max_request_body_size = 2 - data = b"GET /foobar HTTP/8.4\r\nContent-Length: 10\r\n\r\n" - result = self.parser.received(data) - self.assertEqual(result, 44) - self.assertTrue(self.parser.completed) - self.assertTrue(isinstance(self.parser.error, RequestEntityTooLarge)) - - def test_received_headers_too_large(self): - from waitress.utilities import RequestHeaderFieldsTooLarge - - self.parser.adj.max_request_header_size = 2 - data = b"GET /foobar HTTP/8.4\r\nX-Foo: 1\r\n\r\n" - result = self.parser.received(data) - self.assertEqual(result, 34) - self.assertTrue(self.parser.completed) - self.assertTrue(isinstance(self.parser.error, RequestHeaderFieldsTooLarge)) - - def test_received_body_too_large(self): - from waitress.utilities import RequestEntityTooLarge - - self.parser.adj.max_request_body_size = 2 - data = ( - b"GET /foobar HTTP/1.1\r\n" - b"Transfer-Encoding: chunked\r\n" - b"X-Foo: 1\r\n" - b"\r\n" - b"1d;\r\n" - b"This string has 29 characters\r\n" - b"0\r\n\r\n" - ) - - result = self.parser.received(data) - self.assertEqual(result, 62) - 
self.parser.received(data[result:]) - self.assertTrue(self.parser.completed) - self.assertTrue(isinstance(self.parser.error, RequestEntityTooLarge)) - - def test_received_error_from_parser(self): - from waitress.utilities import BadRequest - - data = ( - b"GET /foobar HTTP/1.1\r\n" - b"Transfer-Encoding: chunked\r\n" - b"X-Foo: 1\r\n" - b"\r\n" - b"garbage\r\n" - ) - # header - result = self.parser.received(data) - # body - result = self.parser.received(data[result:]) - self.assertEqual(result, 9) - self.assertTrue(self.parser.completed) - self.assertTrue(isinstance(self.parser.error, BadRequest)) - - def test_received_chunked_completed_sets_content_length(self): - data = ( - b"GET /foobar HTTP/1.1\r\n" - b"Transfer-Encoding: chunked\r\n" - b"X-Foo: 1\r\n" - b"\r\n" - b"1d;\r\n" - b"This string has 29 characters\r\n" - b"0\r\n\r\n" - ) - result = self.parser.received(data) - self.assertEqual(result, 62) - data = data[result:] - result = self.parser.received(data) - self.assertTrue(self.parser.completed) - self.assertTrue(self.parser.error is None) - self.assertEqual(self.parser.headers["CONTENT_LENGTH"], "29") - - def test_parse_header_gardenpath(self): - data = b"GET /foobar HTTP/8.4\r\nfoo: bar\r\n" - self.parser.parse_header(data) - self.assertEqual(self.parser.first_line, b"GET /foobar HTTP/8.4") - self.assertEqual(self.parser.headers["FOO"], "bar") - - def test_parse_header_no_cr_in_headerplus(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4" - - try: - self.parser.parse_header(data) - except ParsingError: - pass - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_bad_content_length(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4\r\ncontent-length: abc\r\n" - - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Content-Length is invalid", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_multiple_content_length(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4\r\ncontent-length: 10\r\ncontent-length: 20\r\n" - - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Content-Length is invalid", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_11_te_chunked(self): - # NB: test that capitalization of header value is unimportant - data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: ChUnKed\r\n" - self.parser.parse_header(data) - self.assertEqual(self.parser.body_rcv.__class__.__name__, "ChunkedReceiver") - - def test_parse_header_transfer_encoding_invalid(self): - from waitress.parser import TransferEncodingNotImplemented - - data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: gzip\r\n" - - try: - self.parser.parse_header(data) - except TransferEncodingNotImplemented as e: - self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_transfer_encoding_invalid_multiple(self): - from waitress.parser import TransferEncodingNotImplemented - - data = b"GET /foobar HTTP/1.1\r\ntransfer-encoding: gzip\r\ntransfer-encoding: chunked\r\n" - - try: - self.parser.parse_header(data) - except TransferEncodingNotImplemented as e: - self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_transfer_encoding_invalid_whitespace(self): - from waitress.parser import 
TransferEncodingNotImplemented - - data = b"GET /foobar HTTP/1.1\r\nTransfer-Encoding:\x85chunked\r\n" - - try: - self.parser.parse_header(data) - except TransferEncodingNotImplemented as e: - self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_transfer_encoding_invalid_unicode(self): - from waitress.parser import TransferEncodingNotImplemented - - # This is the binary encoding for the UTF-8 character - # https://www.compart.com/en/unicode/U+212A "unicode character "K"" - # which if waitress were to accidentally do the wrong thing get - # lowercased to just the ascii "k" due to unicode collisions during - # transformation - data = b"GET /foobar HTTP/1.1\r\nTransfer-Encoding: chun\xe2\x84\xaaed\r\n" - - try: - self.parser.parse_header(data) - except TransferEncodingNotImplemented as e: - self.assertIn("Transfer-Encoding requested is not supported.", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_11_expect_continue(self): - data = b"GET /foobar HTTP/1.1\r\nexpect: 100-continue\r\n" - self.parser.parse_header(data) - self.assertEqual(self.parser.expect_continue, True) - - def test_parse_header_connection_close(self): - data = b"GET /foobar HTTP/1.1\r\nConnection: close\r\n" - self.parser.parse_header(data) - self.assertEqual(self.parser.connection_close, True) - - def test_close_with_body_rcv(self): - body_rcv = DummyBodyStream() - self.parser.body_rcv = body_rcv - self.parser.close() - self.assertTrue(body_rcv.closed) - - def test_close_with_no_body_rcv(self): - self.parser.body_rcv = None - self.parser.close() # doesn't raise - - def test_parse_header_lf_only(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4\nfoo: bar" - - try: - self.parser.parse_header(data) - except ParsingError: - pass - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_cr_only(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4\rfoo: bar" - try: - self.parser.parse_header(data) - except ParsingError: - pass - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_extra_lf_in_header(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4\r\nfoo: \nbar\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Bare CR or LF found in header line", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_extra_lf_in_first_line(self): - from waitress.parser import ParsingError - - data = b"GET /foobar\n HTTP/8.4\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Bare CR or LF found in HTTP message", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_invalid_whitespace(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/8.4\r\nfoo : bar\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Invalid header", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_invalid_whitespace_vtab(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo:\x0bbar\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Invalid header", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_invalid_no_colon(self): - from waitress.parser 
import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nnotvalid\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Invalid header", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_invalid_folding_spacing(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\n\t\x0bbaz\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Invalid header", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_invalid_chars(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nfoo: \x0bbaz\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Invalid header", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_empty(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nempty:\r\n" - self.parser.parse_header(data) - - self.assertIn("EMPTY", self.parser.headers) - self.assertIn("FOO", self.parser.headers) - self.assertEqual(self.parser.headers["EMPTY"], "") - self.assertEqual(self.parser.headers["FOO"], "bar") - - def test_parse_header_multiple_values(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever, more, please, yes\r\n" - self.parser.parse_header(data) - - self.assertIn("FOO", self.parser.headers) - self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") - - def test_parse_header_multiple_values_header_folded(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever,\r\n more, please, yes\r\n" - self.parser.parse_header(data) - - self.assertIn("FOO", self.parser.headers) - self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") - - def test_parse_header_multiple_values_header_folded_multiple(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar, whatever,\r\n more\r\nfoo: please, yes\r\n" - self.parser.parse_header(data) - - self.assertIn("FOO", self.parser.headers) - self.assertEqual(self.parser.headers["FOO"], "bar, whatever, more, please, yes") - - def test_parse_header_multiple_values_extra_space(self): - # Tests errata from: https://www.rfc-editor.org/errata_search.php?rfc=7230&eid=4189 - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: abrowser/0.001 (C O M M E N T)\r\n" - self.parser.parse_header(data) - - self.assertIn("FOO", self.parser.headers) - self.assertEqual(self.parser.headers["FOO"], "abrowser/0.001 (C O M M E N T)") - - def test_parse_header_invalid_backtrack_bad(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\nfoo: bar\r\nfoo: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\x10\r\n" - try: - self.parser.parse_header(data) - except ParsingError as e: - self.assertIn("Invalid header", e.args[0]) - else: # pragma: nocover - self.assertTrue(False) - - def test_parse_header_short_values(self): - from waitress.parser import ParsingError - - data = b"GET /foobar HTTP/1.1\r\none: 1\r\ntwo: 22\r\n" - self.parser.parse_header(data) - - self.assertIn("ONE", self.parser.headers) - self.assertIn("TWO", self.parser.headers) - self.assertEqual(self.parser.headers["ONE"], "1") - self.assertEqual(self.parser.headers["TWO"], "22") - - -class Test_split_uri(unittest.TestCase): - def 
_callFUT(self, uri): - from waitress.parser import split_uri - - ( - self.proxy_scheme, - self.proxy_netloc, - self.path, - self.query, - self.fragment, - ) = split_uri(uri) - - def test_split_uri_unquoting_unneeded(self): - self._callFUT(b"http://localhost:8080/abc def") - self.assertEqual(self.path, "/abc def") - - def test_split_uri_unquoting_needed(self): - self._callFUT(b"http://localhost:8080/abc%20def") - self.assertEqual(self.path, "/abc def") - - def test_split_url_with_query(self): - self._callFUT(b"http://localhost:8080/abc?a=1&b=2") - self.assertEqual(self.path, "/abc") - self.assertEqual(self.query, "a=1&b=2") - - def test_split_url_with_query_empty(self): - self._callFUT(b"http://localhost:8080/abc?") - self.assertEqual(self.path, "/abc") - self.assertEqual(self.query, "") - - def test_split_url_with_fragment(self): - self._callFUT(b"http://localhost:8080/#foo") - self.assertEqual(self.path, "/") - self.assertEqual(self.fragment, "foo") - - def test_split_url_https(self): - self._callFUT(b"https://localhost:8080/") - self.assertEqual(self.path, "/") - self.assertEqual(self.proxy_scheme, "https") - self.assertEqual(self.proxy_netloc, "localhost:8080") - - def test_split_uri_unicode_error_raises_parsing_error(self): - # See https://github.com/Pylons/waitress/issues/64 - from waitress.parser import ParsingError - - # Either pass or throw a ParsingError, just don't throw another type of - # exception as that will cause the connection to close badly: - try: - self._callFUT(b"/\xd0") - except ParsingError: - pass - - def test_split_uri_path(self): - self._callFUT(b"//testing/whatever") - self.assertEqual(self.path, "//testing/whatever") - self.assertEqual(self.proxy_scheme, "") - self.assertEqual(self.proxy_netloc, "") - self.assertEqual(self.query, "") - self.assertEqual(self.fragment, "") - - def test_split_uri_path_query(self): - self._callFUT(b"//testing/whatever?a=1&b=2") - self.assertEqual(self.path, "//testing/whatever") - self.assertEqual(self.proxy_scheme, "") - self.assertEqual(self.proxy_netloc, "") - self.assertEqual(self.query, "a=1&b=2") - self.assertEqual(self.fragment, "") - - def test_split_uri_path_query_fragment(self): - self._callFUT(b"//testing/whatever?a=1&b=2#fragment") - self.assertEqual(self.path, "//testing/whatever") - self.assertEqual(self.proxy_scheme, "") - self.assertEqual(self.proxy_netloc, "") - self.assertEqual(self.query, "a=1&b=2") - self.assertEqual(self.fragment, "fragment") - - -class Test_get_header_lines(unittest.TestCase): - def _callFUT(self, data): - from waitress.parser import get_header_lines - - return get_header_lines(data) - - def test_get_header_lines(self): - result = self._callFUT(b"slam\r\nslim") - self.assertEqual(result, [b"slam", b"slim"]) - - def test_get_header_lines_folded(self): - # From RFC2616: - # HTTP/1.1 header field values can be folded onto multiple lines if the - # continuation line begins with a space or horizontal tab. All linear - # white space, including folding, has the same semantics as SP. A - # recipient MAY replace any linear white space with a single SP before - # interpreting the field value or forwarding the message downstream. - - # We are just preserving the whitespace that indicates folding. 
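As the RFC2616 comment in test_get_header_lines_folded notes, a header line continued with leading SP or HTAB is logically part of the previous line, and waitress preserves the folding whitespace when joining. A standalone sketch of that unfolding (illustrative only, not waitress's actual implementation):

# Illustrative sketch: unfolding obs-fold header continuations while
# preserving the folding whitespace, as the deleted tests expect.
def unfold_header_lines(raw: bytes):
    lines = []
    for line in raw.split(b"\r\n"):
        if line[:1] in (b" ", b"\t") and lines:
            lines[-1] += line  # continuation: append, keeping the SP/HTAB
        else:
            lines.append(line)
    return [l for l in lines if l]

assert unfold_header_lines(b"slim\r\n slam") == [b"slim slam"]
assert unfold_header_lines(b"slam\r\n\tslim") == [b"slam\tslim"]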
-
-    def test_get_header_lines_tabbed(self):
-        result = self._callFUT(b"slam\r\n\tslim")
-        self.assertEqual(result, [b"slam\tslim"])
-
-    def test_get_header_lines_malformed(self):
-        # https://corte.si/posts/code/pathod/pythonservers/index.html
-        from waitress.parser import ParsingError
-
-        self.assertRaises(ParsingError, self._callFUT, b" Host: localhost\r\n\r\n")
-
-
-class Test_crack_first_line(unittest.TestCase):
-    def _callFUT(self, line):
-        from waitress.parser import crack_first_line
-
-        return crack_first_line(line)
-
-    def test_crack_first_line_matchok(self):
-        result = self._callFUT(b"GET / HTTP/1.0")
-        self.assertEqual(result, (b"GET", b"/", b"1.0"))
-
-    def test_crack_first_line_lowercase_method(self):
-        from waitress.parser import ParsingError
-
-        self.assertRaises(ParsingError, self._callFUT, b"get / HTTP/1.0")
-
-    def test_crack_first_line_nomatch(self):
-        result = self._callFUT(b"GET / bleh")
-        self.assertEqual(result, (b"", b"", b""))
-
-        result = self._callFUT(b"GET /info?txtAirPlay&txtRAOP RTSP/1.0")
-        self.assertEqual(result, (b"", b"", b""))
-
-    def test_crack_first_line_missing_version(self):
-        result = self._callFUT(b"GET /")
-        self.assertEqual(result, (b"GET", b"/", b""))
-
-
-class TestHTTPRequestParserIntegration(unittest.TestCase):
-    def setUp(self):
-        from waitress.parser import HTTPRequestParser
-        from waitress.adjustments import Adjustments
-
-        my_adj = Adjustments()
-        self.parser = HTTPRequestParser(my_adj)
-
-    def feed(self, data):
-        parser = self.parser
-
-        for n in range(100):  # make sure we never loop forever
-            consumed = parser.received(data)
-            data = data[consumed:]
-
-            if parser.completed:
-                return
-        raise ValueError("Looping")  # pragma: no cover
-
-    def testSimpleGET(self):
-        data = (
-            b"GET /foobar HTTP/8.4\r\n"
-            b"FirstName: mickey\r\n"
-            b"lastname: Mouse\r\n"
-            b"content-length: 6\r\n"
-            b"\r\n"
-            b"Hello."
-        )
-        parser = self.parser
-        self.feed(data)
-        self.assertTrue(parser.completed)
-        self.assertEqual(parser.version, "8.4")
-        self.assertFalse(parser.empty)
-        self.assertEqual(
-            parser.headers,
-            {"FIRSTNAME": "mickey", "LASTNAME": "Mouse", "CONTENT_LENGTH": "6",},
-        )
-        self.assertEqual(parser.path, "/foobar")
-        self.assertEqual(parser.command, "GET")
-        self.assertEqual(parser.query, "")
-        self.assertEqual(parser.proxy_scheme, "")
-        self.assertEqual(parser.proxy_netloc, "")
-        self.assertEqual(parser.get_body_stream().getvalue(), b"Hello.")
-
-    def testComplexGET(self):
-        data = (
-            b"GET /foo/a+%2B%2F%C3%A4%3D%26a%3Aint?d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6 HTTP/8.4\r\n"
-            b"FirstName: mickey\r\n"
-            b"lastname: Mouse\r\n"
-            b"content-length: 10\r\n"
-            b"\r\n"
-            b"Hello mickey."
-        )
-        parser = self.parser
-        self.feed(data)
-        self.assertEqual(parser.command, "GET")
-        self.assertEqual(parser.version, "8.4")
-        self.assertFalse(parser.empty)
-        self.assertEqual(
-            parser.headers,
-            {"FIRSTNAME": "mickey", "LASTNAME": "Mouse", "CONTENT_LENGTH": "10"},
-        )
-        # path should be utf-8 encoded
-        self.assertEqual(
-            tobytes(parser.path).decode("utf-8"),
-            text_(b"/foo/a++/\xc3\xa4=&a:int", "utf-8"),
-        )
-        self.assertEqual(
-            parser.query, "d=b+%2B%2F%3D%26b%3Aint&c+%2B%2F%3D%26c%3Aint=6"
-        )
-        self.assertEqual(parser.get_body_stream().getvalue(), b"Hello mick")
-
-    def testProxyGET(self):
-        data = (
-            b"GET https://example.com:8080/foobar HTTP/8.4\r\n"
-            b"content-length: 6\r\n"
-            b"\r\n"
-            b"Hello."
-        )
-        parser = self.parser
-        self.feed(data)
-        self.assertTrue(parser.completed)
-        self.assertEqual(parser.version, "8.4")
-        self.assertFalse(parser.empty)
-        self.assertEqual(parser.headers, {"CONTENT_LENGTH": "6"})
-        self.assertEqual(parser.path, "/foobar")
-        self.assertEqual(parser.command, "GET")
-        self.assertEqual(parser.proxy_scheme, "https")
-        self.assertEqual(parser.proxy_netloc, "example.com:8080")
-        self.assertEqual(parser.command, "GET")
-        self.assertEqual(parser.query, "")
-        self.assertEqual(parser.get_body_stream().getvalue(), b"Hello.")
-
-    def testDuplicateHeaders(self):
-        # Ensure that headers with the same key get concatenated as per
-        # RFC2616.
-        data = (
-            b"GET /foobar HTTP/8.4\r\n"
-            b"x-forwarded-for: 10.11.12.13\r\n"
-            b"x-forwarded-for: unknown,127.0.0.1\r\n"
-            b"X-Forwarded_for: 255.255.255.255\r\n"
-            b"content-length: 6\r\n"
-            b"\r\n"
-            b"Hello."
-        )
-        self.feed(data)
-        self.assertTrue(self.parser.completed)
-        self.assertEqual(
-            self.parser.headers,
-            {
-                "CONTENT_LENGTH": "6",
-                "X_FORWARDED_FOR": "10.11.12.13, unknown,127.0.0.1",
-            },
-        )
-
-    def testSpoofedHeadersDropped(self):
-        data = (
-            b"GET /foobar HTTP/8.4\r\n"
-            b"x-auth_user: bob\r\n"
-            b"content-length: 6\r\n"
-            b"\r\n"
-            b"Hello."
-        )
-        self.feed(data)
-        self.assertTrue(self.parser.completed)
-        self.assertEqual(self.parser.headers, {"CONTENT_LENGTH": "6",})
-
-
-class DummyBodyStream(object):
-    def getfile(self):
-        return self
-
-    def getbuf(self):
-        return self
-
-    def close(self):
-        self.closed = True
diff --git a/libs/waitress/tests/test_proxy_headers.py b/libs/waitress/tests/test_proxy_headers.py
deleted file mode 100644
index 15b4a0828..000000000
--- a/libs/waitress/tests/test_proxy_headers.py
+++ /dev/null
@@ -1,724 +0,0 @@
-import unittest
-
-from waitress.compat import tobytes
-
-
-class TestProxyHeadersMiddleware(unittest.TestCase):
-    def _makeOne(self, app, **kw):
-        from waitress.proxy_headers import proxy_headers_middleware
-
-        return proxy_headers_middleware(app, **kw)
-
-    def _callFUT(self, app, **kw):
-        response = DummyResponse()
-        environ = DummyEnviron(**kw)
-
-        def start_response(status, response_headers):
-            response.status = status
-            response.headers = response_headers
-
-        response.steps = list(app(environ, start_response))
-        response.body = b"".join(tobytes(s) for s in response.steps)
-        return response
-
-    def test_get_environment_values_w_scheme_override_untrusted(self):
-        inner = DummyApp()
-        app = self._makeOne(inner)
-        response = self._callFUT(
-            app, headers={"X_FOO": "BAR", "X_FORWARDED_PROTO": "https",}
-        )
-        self.assertEqual(response.status, "200 OK")
-        self.assertEqual(inner.environ["wsgi.url_scheme"], "http")
-
-    def test_get_environment_values_w_scheme_override_trusted(self):
-        inner = DummyApp()
-        app = self._makeOne(
-            inner,
-            trusted_proxy="192.168.1.1",
-            trusted_proxy_headers={"x-forwarded-proto"},
-        )
-        response = self._callFUT(
-            app,
-            addr=["192.168.1.1", 8080],
-            headers={"X_FOO": "BAR", "X_FORWARDED_PROTO": "https",},
-        )
-
-        environ = inner.environ
-        self.assertEqual(response.status, "200 OK")
-        self.assertEqual(environ["SERVER_PORT"], "443")
-        self.assertEqual(environ["SERVER_NAME"], "localhost")
-        self.assertEqual(environ["REMOTE_ADDR"], "192.168.1.1")
-        self.assertEqual(environ["HTTP_X_FOO"], "BAR")
-
80], - headers={ - "X_FOO": "BAR", - "X_FORWARDED_PROTO": "http://p02n3e.com?url=http", - }, - ) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "X-Forwarded-Proto" malformed', response.body) - - def test_get_environment_warning_other_proxy_headers(self): - inner = DummyApp() - logger = DummyLogger() - app = self._makeOne( - inner, - trusted_proxy="192.168.1.1", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - log_untrusted=True, - logger=logger, - ) - response = self._callFUT( - app, - addr=["192.168.1.1", 80], - headers={ - "X_FORWARDED_FOR": "[2001:db8::1]", - "FORWARDED": "For=198.51.100.2;host=example.com:8080;proto=https", - }, - ) - self.assertEqual(response.status, "200 OK") - - self.assertEqual(len(logger.logged), 1) - - environ = inner.environ - self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - self.assertEqual(environ["wsgi.url_scheme"], "https") - - def test_get_environment_contains_all_headers_including_untrusted(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="192.168.1.1", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-by"}, - clear_untrusted=False, - ) - headers_orig = { - "X_FORWARDED_FOR": "198.51.100.2", - "X_FORWARDED_BY": "Waitress", - "X_FORWARDED_PROTO": "https", - "X_FORWARDED_HOST": "example.org", - } - response = self._callFUT( - app, addr=["192.168.1.1", 80], headers=headers_orig.copy(), - ) - self.assertEqual(response.status, "200 OK") - environ = inner.environ - for k, expected in headers_orig.items(): - result = environ["HTTP_%s" % k] - self.assertEqual(result, expected) - - def test_get_environment_contains_only_trusted_headers(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="192.168.1.1", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-by"}, - clear_untrusted=True, - ) - response = self._callFUT( - app, - addr=["192.168.1.1", 80], - headers={ - "X_FORWARDED_FOR": "198.51.100.2", - "X_FORWARDED_BY": "Waitress", - "X_FORWARDED_PROTO": "https", - "X_FORWARDED_HOST": "example.org", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["HTTP_X_FORWARDED_BY"], "Waitress") - self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) - self.assertNotIn("HTTP_X_FORWARDED_PROTO", environ) - self.assertNotIn("HTTP_X_FORWARDED_HOST", environ) - - def test_get_environment_clears_headers_if_untrusted_proxy(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="192.168.1.1", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-by"}, - clear_untrusted=True, - ) - response = self._callFUT( - app, - addr=["192.168.1.255", 80], - headers={ - "X_FORWARDED_FOR": "198.51.100.2", - "X_FORWARDED_BY": "Waitress", - "X_FORWARDED_PROTO": "https", - "X_FORWARDED_HOST": "example.org", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertNotIn("HTTP_X_FORWARDED_BY", environ) - self.assertNotIn("HTTP_X_FORWARDED_FOR", environ) - self.assertNotIn("HTTP_X_FORWARDED_PROTO", environ) - self.assertNotIn("HTTP_X_FORWARDED_HOST", environ) - - def test_parse_proxy_headers_forwarded_for(self): - inner = DummyApp() - app = self._makeOne( - inner, trusted_proxy="*", trusted_proxy_headers={"x-forwarded-for"}, - ) - 
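# A minimal standalone sketch of the setup these tests drive: wrapping a
# WSGI app in waitress.proxy_headers.proxy_headers_middleware. The keyword
# names are the ones exercised in this class; hello_app is an illustrative
# stand-in, not part of waitress.
from waitress.proxy_headers import proxy_headers_middleware

def hello_app(environ, start_response):
    start_response("200 OK", [("Content-Type", "text/plain")])
    # Behind a trusted proxy these reflect the X-Forwarded-* values
    # rather than the socket-level peer.
    body = "client=%s scheme=%s" % (environ["REMOTE_ADDR"],
                                    environ["wsgi.url_scheme"])
    return [body.encode("ascii")]

wrapped = proxy_headers_middleware(
    hello_app,
    trusted_proxy="192.168.1.1",                  # trust only this peer
    trusted_proxy_headers={"x-forwarded-proto"},  # honour only this header
)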
response = self._callFUT(app, headers={"X_FORWARDED_FOR": "192.0.2.1"}) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "192.0.2.1") - - def test_parse_proxy_headers_forwarded_for_v6_missing_brackets(self): - inner = DummyApp() - app = self._makeOne( - inner, trusted_proxy="*", trusted_proxy_headers={"x-forwarded-for"}, - ) - response = self._callFUT(app, headers={"X_FORWARDED_FOR": "2001:db8::0"}) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::0") - - def test_parse_proxy_headers_forwared_for_multiple(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={"x-forwarded-for"}, - ) - response = self._callFUT( - app, headers={"X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1"} - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - - def test_parse_forwarded_multiple_proxies_trust_only_two(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, - headers={ - "FORWARDED": ( - "For=192.0.2.1;host=fake.com, " - "For=198.51.100.2;host=example.com:8080, " - "For=203.0.113.1" - ), - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - - def test_parse_forwarded_multiple_proxies(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, - headers={ - "FORWARDED": ( - 'for="[2001:db8::1]:3821";host="example.com:8443";proto="https", ' - 'for=192.0.2.1;host="example.internal:8080"' - ), - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::1") - self.assertEqual(environ["REMOTE_PORT"], "3821") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8443") - self.assertEqual(environ["SERVER_PORT"], "8443") - self.assertEqual(environ["wsgi.url_scheme"], "https") - - def test_parse_forwarded_multiple_proxies_minimal(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, - headers={ - "FORWARDED": ( - 'for="[2001:db8::1]";proto="https", ' - 'for=192.0.2.1;host="example.org"' - ), - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "2001:db8::1") - self.assertEqual(environ["SERVER_NAME"], "example.org") - self.assertEqual(environ["HTTP_HOST"], "example.org") - self.assertEqual(environ["SERVER_PORT"], "443") - self.assertEqual(environ["wsgi.url_scheme"], "https") - - def test_parse_proxy_headers_forwarded_host_with_port(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={ - "x-forwarded-for", - "x-forwarded-proto", - "x-forwarded-host", - }, - ) - response = self._callFUT( - app, - 
headers={ - "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", - "X_FORWARDED_PROTO": "http", - "X_FORWARDED_HOST": "example.com:8080", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - - def test_parse_proxy_headers_forwarded_host_without_port(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={ - "x-forwarded-for", - "x-forwarded-proto", - "x-forwarded-host", - }, - ) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", - "X_FORWARDED_PROTO": "http", - "X_FORWARDED_HOST": "example.com", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com") - self.assertEqual(environ["SERVER_PORT"], "80") - - def test_parse_proxy_headers_forwarded_host_with_forwarded_port(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={ - "x-forwarded-for", - "x-forwarded-proto", - "x-forwarded-host", - "x-forwarded-port", - }, - ) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", - "X_FORWARDED_PROTO": "http", - "X_FORWARDED_HOST": "example.com", - "X_FORWARDED_PORT": "8080", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - - def test_parse_proxy_headers_forwarded_host_multiple_with_forwarded_port(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=2, - trusted_proxy_headers={ - "x-forwarded-for", - "x-forwarded-proto", - "x-forwarded-host", - "x-forwarded-port", - }, - ) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", - "X_FORWARDED_PROTO": "http", - "X_FORWARDED_HOST": "example.com, example.org", - "X_FORWARDED_PORT": "8080", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - - def test_parse_proxy_headers_forwarded_host_multiple_with_forwarded_port_limit_one_trusted( - self, - ): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={ - "x-forwarded-for", - "x-forwarded-proto", - "x-forwarded-host", - "x-forwarded-port", - }, - ) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1", - "X_FORWARDED_PROTO": "http", - "X_FORWARDED_HOST": "example.com, example.org", - "X_FORWARDED_PORT": "8080", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "203.0.113.1") - 
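# The _callFUT helper above boils down to invoking the wrapped app with a
# hand-built environ. A runnable sketch of the host/port resolution these
# tests assert, assuming the vendored waitress (1.4.x); capture_app and the
# environ literal mirror DummyApp/DummyEnviron at the end of this file.
from waitress.proxy_headers import proxy_headers_middleware

seen = {}

def capture_app(environ, start_response):
    seen.update(environ)  # record what the middleware handed us
    start_response("200 OK", [])
    return [b""]

wrapped = proxy_headers_middleware(
    capture_app,
    trusted_proxy="*",
    trusted_proxy_count=2,
    trusted_proxy_headers={"x-forwarded-for", "x-forwarded-proto",
                           "x-forwarded-host", "x-forwarded-port"},
)
environ = {
    "REMOTE_ADDR": "127.0.0.1", "REMOTE_HOST": "127.0.0.1",
    "REMOTE_PORT": 8080, "SERVER_NAME": "localhost", "SERVER_PORT": "8080",
    "wsgi.url_scheme": "http", "HTTP_HOST": "192.168.1.1:80",
    "HTTP_X_FORWARDED_FOR": "192.0.2.1, 198.51.100.2, 203.0.113.1",
    "HTTP_X_FORWARDED_PROTO": "http",
    "HTTP_X_FORWARDED_HOST": "example.com",
    "HTTP_X_FORWARDED_PORT": "8080",
}
list(wrapped(environ, lambda status, headers: None))
assert seen["REMOTE_ADDR"] == "198.51.100.2"     # two trusted hops back
assert seen["HTTP_HOST"] == "example.com:8080"   # forwarded host + port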
self.assertEqual(environ["SERVER_NAME"], "example.org") - self.assertEqual(environ["HTTP_HOST"], "example.org:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - - def test_parse_forwarded(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, - headers={ - "FORWARDED": "For=198.51.100.2:5858;host=example.com:8080;proto=https", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["REMOTE_PORT"], "5858") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - self.assertEqual(environ["wsgi.url_scheme"], "https") - - def test_parse_forwarded_empty_pair(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, headers={"FORWARDED": "For=198.51.100.2;;proto=https;by=_unused",} - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - - def test_parse_forwarded_pair_token_whitespace(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, headers={"FORWARDED": "For=198.51.100.2; proto =https",} - ) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "Forwarded" malformed', response.body) - - def test_parse_forwarded_pair_value_whitespace(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT( - app, headers={"FORWARDED": 'For= "198.51.100.2"; proto =https',} - ) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "Forwarded" malformed', response.body) - - def test_parse_forwarded_pair_no_equals(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - ) - response = self._callFUT(app, headers={"FORWARDED": "For"}) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "Forwarded" malformed', response.body) - - def test_parse_forwarded_warning_unknown_token(self): - inner = DummyApp() - logger = DummyLogger() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"forwarded"}, - logger=logger, - ) - response = self._callFUT( - app, - headers={ - "FORWARDED": ( - "For=198.51.100.2;host=example.com:8080;proto=https;" - 'unknown="yolo"' - ), - }, - ) - self.assertEqual(response.status, "200 OK") - - self.assertEqual(len(logger.logged), 1) - self.assertIn("Unknown Forwarded token", logger.logged[0]) - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "198.51.100.2") - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:8080") - self.assertEqual(environ["SERVER_PORT"], "8080") - self.assertEqual(environ["wsgi.url_scheme"], "https") - - def test_parse_no_valid_proxy_headers(self): - inner = DummyApp() - app = self._makeOne(inner, trusted_proxy="*", trusted_proxy_count=1,) - response = self._callFUT( - 
app, - headers={ - "X_FORWARDED_FOR": "198.51.100.2", - "FORWARDED": "For=198.51.100.2;host=example.com:8080;proto=https", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["REMOTE_ADDR"], "127.0.0.1") - self.assertEqual(environ["SERVER_NAME"], "localhost") - self.assertEqual(environ["HTTP_HOST"], "192.168.1.1:80") - self.assertEqual(environ["SERVER_PORT"], "8080") - self.assertEqual(environ["wsgi.url_scheme"], "http") - - def test_parse_multiple_x_forwarded_proto(self): - inner = DummyApp() - logger = DummyLogger() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-proto"}, - logger=logger, - ) - response = self._callFUT(app, headers={"X_FORWARDED_PROTO": "http, https",}) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "X-Forwarded-Proto" malformed', response.body) - - def test_parse_multiple_x_forwarded_port(self): - inner = DummyApp() - logger = DummyLogger() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-port"}, - logger=logger, - ) - response = self._callFUT(app, headers={"X_FORWARDED_PORT": "443, 80",}) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "X-Forwarded-Port" malformed', response.body) - - def test_parse_forwarded_port_wrong_proto_port_80(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={ - "x-forwarded-port", - "x-forwarded-host", - "x-forwarded-proto", - }, - ) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_PORT": "80", - "X_FORWARDED_PROTO": "https", - "X_FORWARDED_HOST": "example.com", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:80") - self.assertEqual(environ["SERVER_PORT"], "80") - self.assertEqual(environ["wsgi.url_scheme"], "https") - - def test_parse_forwarded_port_wrong_proto_port_443(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={ - "x-forwarded-port", - "x-forwarded-host", - "x-forwarded-proto", - }, - ) - response = self._callFUT( - app, - headers={ - "X_FORWARDED_PORT": "443", - "X_FORWARDED_PROTO": "http", - "X_FORWARDED_HOST": "example.com", - }, - ) - self.assertEqual(response.status, "200 OK") - - environ = inner.environ - self.assertEqual(environ["SERVER_NAME"], "example.com") - self.assertEqual(environ["HTTP_HOST"], "example.com:443") - self.assertEqual(environ["SERVER_PORT"], "443") - self.assertEqual(environ["wsgi.url_scheme"], "http") - - def test_parse_forwarded_for_bad_quote(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-for"}, - ) - response = self._callFUT(app, headers={"X_FORWARDED_FOR": '"foo'}) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header "X-Forwarded-For" malformed', response.body) - - def test_parse_forwarded_host_bad_quote(self): - inner = DummyApp() - app = self._makeOne( - inner, - trusted_proxy="*", - trusted_proxy_count=1, - trusted_proxy_headers={"x-forwarded-host"}, - ) - response = self._callFUT(app, headers={"X_FORWARDED_HOST": '"foo'}) - self.assertEqual(response.status, "400 Bad Request") - self.assertIn(b'Header 
"X-Forwarded-Host" malformed', response.body) - - -class DummyLogger(object): - def __init__(self): - self.logged = [] - - def warning(self, msg, *args): - self.logged.append(msg % args) - - -class DummyApp(object): - def __call__(self, environ, start_response): - self.environ = environ - start_response("200 OK", [("Content-Type", "text/plain")]) - yield "hello" - - -class DummyResponse(object): - status = None - headers = None - body = None - - -def DummyEnviron( - addr=("127.0.0.1", 8080), scheme="http", server="localhost", headers=None, -): - environ = { - "REMOTE_ADDR": addr[0], - "REMOTE_HOST": addr[0], - "REMOTE_PORT": addr[1], - "SERVER_PORT": str(addr[1]), - "SERVER_NAME": server, - "wsgi.url_scheme": scheme, - "HTTP_HOST": "192.168.1.1:80", - } - if headers: - environ.update( - { - "HTTP_" + key.upper().replace("-", "_"): value - for key, value in headers.items() - } - ) - return environ diff --git a/libs/waitress/tests/test_receiver.py b/libs/waitress/tests/test_receiver.py deleted file mode 100644 index b4910bba8..000000000 --- a/libs/waitress/tests/test_receiver.py +++ /dev/null @@ -1,242 +0,0 @@ -import unittest - - -class TestFixedStreamReceiver(unittest.TestCase): - def _makeOne(self, cl, buf): - from waitress.receiver import FixedStreamReceiver - - return FixedStreamReceiver(cl, buf) - - def test_received_remain_lt_1(self): - buf = DummyBuffer() - inst = self._makeOne(0, buf) - result = inst.received("a") - self.assertEqual(result, 0) - self.assertEqual(inst.completed, True) - - def test_received_remain_lte_datalen(self): - buf = DummyBuffer() - inst = self._makeOne(1, buf) - result = inst.received("aa") - self.assertEqual(result, 1) - self.assertEqual(inst.completed, True) - self.assertEqual(inst.completed, 1) - self.assertEqual(inst.remain, 0) - self.assertEqual(buf.data, ["a"]) - - def test_received_remain_gt_datalen(self): - buf = DummyBuffer() - inst = self._makeOne(10, buf) - result = inst.received("aa") - self.assertEqual(result, 2) - self.assertEqual(inst.completed, False) - self.assertEqual(inst.remain, 8) - self.assertEqual(buf.data, ["aa"]) - - def test_getfile(self): - buf = DummyBuffer() - inst = self._makeOne(10, buf) - self.assertEqual(inst.getfile(), buf) - - def test_getbuf(self): - buf = DummyBuffer() - inst = self._makeOne(10, buf) - self.assertEqual(inst.getbuf(), buf) - - def test___len__(self): - buf = DummyBuffer(["1", "2"]) - inst = self._makeOne(10, buf) - self.assertEqual(inst.__len__(), 2) - - -class TestChunkedReceiver(unittest.TestCase): - def _makeOne(self, buf): - from waitress.receiver import ChunkedReceiver - - return ChunkedReceiver(buf) - - def test_alreadycompleted(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.completed = True - result = inst.received(b"a") - self.assertEqual(result, 0) - self.assertEqual(inst.completed, True) - - def test_received_remain_gt_zero(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.chunk_remainder = 100 - result = inst.received(b"a") - self.assertEqual(inst.chunk_remainder, 99) - self.assertEqual(result, 1) - self.assertEqual(inst.completed, False) - - def test_received_control_line_notfinished(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - result = inst.received(b"a") - self.assertEqual(inst.control_line, b"a") - self.assertEqual(result, 1) - self.assertEqual(inst.completed, False) - - def test_received_control_line_finished_garbage_in_input(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - result = inst.received(b"garbage\r\n") - self.assertEqual(result, 
9) - self.assertTrue(inst.error) - - def test_received_control_line_finished_all_chunks_not_received(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - result = inst.received(b"a;discard\r\n") - self.assertEqual(inst.control_line, b"") - self.assertEqual(inst.chunk_remainder, 10) - self.assertEqual(inst.all_chunks_received, False) - self.assertEqual(result, 11) - self.assertEqual(inst.completed, False) - - def test_received_control_line_finished_all_chunks_received(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - result = inst.received(b"0;discard\r\n") - self.assertEqual(inst.control_line, b"") - self.assertEqual(inst.all_chunks_received, True) - self.assertEqual(result, 11) - self.assertEqual(inst.completed, False) - - def test_received_trailer_startswith_crlf(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.all_chunks_received = True - result = inst.received(b"\r\n") - self.assertEqual(result, 2) - self.assertEqual(inst.completed, True) - - def test_received_trailer_startswith_lf(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.all_chunks_received = True - result = inst.received(b"\n") - self.assertEqual(result, 1) - self.assertEqual(inst.completed, False) - - def test_received_trailer_not_finished(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.all_chunks_received = True - result = inst.received(b"a") - self.assertEqual(result, 1) - self.assertEqual(inst.completed, False) - - def test_received_trailer_finished(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - inst.all_chunks_received = True - result = inst.received(b"abc\r\n\r\n") - self.assertEqual(inst.trailer, b"abc\r\n\r\n") - self.assertEqual(result, 7) - self.assertEqual(inst.completed, True) - - def test_getfile(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - self.assertEqual(inst.getfile(), buf) - - def test_getbuf(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - self.assertEqual(inst.getbuf(), buf) - - def test___len__(self): - buf = DummyBuffer(["1", "2"]) - inst = self._makeOne(buf) - self.assertEqual(inst.__len__(), 2) - - def test_received_chunk_is_properly_terminated(self): - buf = DummyBuffer() - inst = self._makeOne(buf) - data = b"4\r\nWiki\r\n" - result = inst.received(data) - self.assertEqual(result, len(data)) - self.assertEqual(inst.completed, False) - self.assertEqual(buf.data[0], b"Wiki") - - def test_received_chunk_not_properly_terminated(self): - from waitress.utilities import BadRequest - - buf = DummyBuffer() - inst = self._makeOne(buf) - data = b"4\r\nWikibadchunk\r\n" - result = inst.received(data) - self.assertEqual(result, len(data)) - self.assertEqual(inst.completed, False) - self.assertEqual(buf.data[0], b"Wiki") - self.assertEqual(inst.error.__class__, BadRequest) - - def test_received_multiple_chunks(self): - from waitress.utilities import BadRequest - - buf = DummyBuffer() - inst = self._makeOne(buf) - data = ( - b"4\r\n" - b"Wiki\r\n" - b"5\r\n" - b"pedia\r\n" - b"E\r\n" - b" in\r\n" - b"\r\n" - b"chunks.\r\n" - b"0\r\n" - b"\r\n" - ) - result = inst.received(data) - self.assertEqual(result, len(data)) - self.assertEqual(inst.completed, True) - self.assertEqual(b"".join(buf.data), b"Wikipedia in\r\n\r\nchunks.") - self.assertEqual(inst.error, None) - - def test_received_multiple_chunks_split(self): - from waitress.utilities import BadRequest - - buf = DummyBuffer() - inst = self._makeOne(buf) - data1 = b"4\r\nWiki\r" - result = inst.received(data1) - self.assertEqual(result, len(data1)) - - data2 = ( - 
b"\n5\r\n" - b"pedia\r\n" - b"E\r\n" - b" in\r\n" - b"\r\n" - b"chunks.\r\n" - b"0\r\n" - b"\r\n" - ) - - result = inst.received(data2) - self.assertEqual(result, len(data2)) - - self.assertEqual(inst.completed, True) - self.assertEqual(b"".join(buf.data), b"Wikipedia in\r\n\r\nchunks.") - self.assertEqual(inst.error, None) - - -class DummyBuffer(object): - def __init__(self, data=None): - if data is None: - data = [] - self.data = data - - def append(self, s): - self.data.append(s) - - def getfile(self): - return self - - def __len__(self): - return len(self.data) diff --git a/libs/waitress/tests/test_regression.py b/libs/waitress/tests/test_regression.py deleted file mode 100644 index 3c4c6c202..000000000 --- a/libs/waitress/tests/test_regression.py +++ /dev/null @@ -1,147 +0,0 @@ -############################################################################## -# -# Copyright (c) 2005 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. -# -############################################################################## -"""Tests for waitress.channel maintenance logic -""" -import doctest - - -class FakeSocket: # pragma: no cover - data = "" - setblocking = lambda *_: None - close = lambda *_: None - - def __init__(self, no): - self.no = no - - def fileno(self): - return self.no - - def getpeername(self): - return ("localhost", self.no) - - def send(self, data): - self.data += data - return len(data) - - def recv(self, data): - return "data" - - -def zombies_test(): - """Regression test for HTTPChannel.maintenance method - - Bug: This method checks for channels that have been "inactive" for a - configured time. The bug was that last_activity is set at creation time - but never updated during async channel activity (reads and writes), so - any channel older than the configured timeout will be closed when a new - channel is created, regardless of activity. - - >>> import time - >>> import waitress.adjustments - >>> config = waitress.adjustments.Adjustments() - - >>> from waitress.server import HTTPServer - >>> class TestServer(HTTPServer): - ... def bind(self, (ip, port)): - ... print "Listening on %s:%d" % (ip or '*', port) - >>> sb = TestServer('127.0.0.1', 80, start=False, verbose=True) - Listening on 127.0.0.1:80 - - First we confirm the correct behavior, where a channel with no activity - for the timeout duration gets closed. - - >>> from waitress.channel import HTTPChannel - >>> socket = FakeSocket(42) - >>> channel = HTTPChannel(sb, socket, ('localhost', 42)) - - >>> channel.connected - True - - >>> channel.last_activity -= int(config.channel_timeout) + 1 - - >>> channel.next_channel_cleanup[0] = channel.creation_time - int( - ... config.cleanup_interval) - 1 - - >>> socket2 = FakeSocket(7) - >>> channel2 = HTTPChannel(sb, socket2, ('localhost', 7)) - - >>> channel.connected - False - - Write Activity - -------------- - - Now we make sure that if there is activity the channel doesn't get closed - incorrectly. 
- - >>> channel2.connected - True - - >>> channel2.last_activity -= int(config.channel_timeout) + 1 - - >>> channel2.handle_write() - - >>> channel2.next_channel_cleanup[0] = channel2.creation_time - int( - ... config.cleanup_interval) - 1 - - >>> socket3 = FakeSocket(3) - >>> channel3 = HTTPChannel(sb, socket3, ('localhost', 3)) - - >>> channel2.connected - True - - Read Activity - -------------- - - We should test to see that read activity will update a channel as well. - - >>> channel3.connected - True - - >>> channel3.last_activity -= int(config.channel_timeout) + 1 - - >>> import waitress.parser - >>> channel3.parser_class = ( - ... waitress.parser.HTTPRequestParser) - >>> channel3.handle_read() - - >>> channel3.next_channel_cleanup[0] = channel3.creation_time - int( - ... config.cleanup_interval) - 1 - - >>> socket4 = FakeSocket(4) - >>> channel4 = HTTPChannel(sb, socket4, ('localhost', 4)) - - >>> channel3.connected - True - - Main loop window - ---------------- - - There is also a corner case we'll do a shallow test for where a - channel can be closed waiting for the main loop. - - >>> channel4.last_activity -= 1 - - >>> last_active = channel4.last_activity - - >>> channel4.set_async() - - >>> channel4.last_activity != last_active - True - -""" - - -def test_suite(): - return doctest.DocTestSuite() diff --git a/libs/waitress/tests/test_runner.py b/libs/waitress/tests/test_runner.py deleted file mode 100644 index 127757e15..000000000 --- a/libs/waitress/tests/test_runner.py +++ /dev/null @@ -1,191 +0,0 @@ -import contextlib -import os -import sys - -if sys.version_info[:2] == (2, 6): # pragma: no cover - import unittest2 as unittest -else: # pragma: no cover - import unittest - -from waitress import runner - - -class Test_match(unittest.TestCase): - def test_empty(self): - self.assertRaisesRegexp( - ValueError, "^Malformed application ''$", runner.match, "" - ) - - def test_module_only(self): - self.assertRaisesRegexp( - ValueError, r"^Malformed application 'foo\.bar'$", runner.match, "foo.bar" - ) - - def test_bad_module(self): - self.assertRaisesRegexp( - ValueError, - r"^Malformed application 'foo#bar:barney'$", - runner.match, - "foo#bar:barney", - ) - - def test_module_obj(self): - self.assertTupleEqual( - runner.match("foo.bar:fred.barney"), ("foo.bar", "fred.barney") - ) - - -class Test_resolve(unittest.TestCase): - def test_bad_module(self): - self.assertRaises( - ImportError, runner.resolve, "nonexistent", "nonexistent_function" - ) - - def test_nonexistent_function(self): - self.assertRaisesRegexp( - AttributeError, - r"has no attribute 'nonexistent_function'", - runner.resolve, - "os.path", - "nonexistent_function", - ) - - def test_simple_happy_path(self): - from os.path import exists - - self.assertIs(runner.resolve("os.path", "exists"), exists) - - def test_complex_happy_path(self): - # Ensure we can recursively resolve object attributes if necessary. 
- self.assertEquals(runner.resolve("os.path", "exists.__name__"), "exists") - - -class Test_run(unittest.TestCase): - def match_output(self, argv, code, regex): - argv = ["waitress-serve"] + argv - with capture() as captured: - self.assertEqual(runner.run(argv=argv), code) - self.assertRegexpMatches(captured.getvalue(), regex) - captured.close() - - def test_bad(self): - self.match_output(["--bad-opt"], 1, "^Error: option --bad-opt not recognized") - - def test_help(self): - self.match_output(["--help"], 0, "^Usage:\n\n waitress-serve") - - def test_no_app(self): - self.match_output([], 1, "^Error: Specify one application only") - - def test_multiple_apps_app(self): - self.match_output(["a:a", "b:b"], 1, "^Error: Specify one application only") - - def test_bad_apps_app(self): - self.match_output(["a"], 1, "^Error: Malformed application 'a'") - - def test_bad_app_module(self): - self.match_output(["nonexistent:a"], 1, "^Error: Bad module 'nonexistent'") - - self.match_output( - ["nonexistent:a"], - 1, - ( - r"There was an exception \((ImportError|ModuleNotFoundError)\) " - "importing your module.\n\nIt had these arguments: \n" - "1. No module named '?nonexistent'?" - ), - ) - - def test_cwd_added_to_path(self): - def null_serve(app, **kw): - pass - - sys_path = sys.path - current_dir = os.getcwd() - try: - os.chdir(os.path.dirname(__file__)) - argv = [ - "waitress-serve", - "fixtureapps.runner:app", - ] - self.assertEqual(runner.run(argv=argv, _serve=null_serve), 0) - finally: - sys.path = sys_path - os.chdir(current_dir) - - def test_bad_app_object(self): - self.match_output( - ["waitress.tests.fixtureapps.runner:a"], 1, "^Error: Bad object name 'a'" - ) - - def test_simple_call(self): - import waitress.tests.fixtureapps.runner as _apps - - def check_server(app, **kw): - self.assertIs(app, _apps.app) - self.assertDictEqual(kw, {"port": "80"}) - - argv = [ - "waitress-serve", - "--port=80", - "waitress.tests.fixtureapps.runner:app", - ] - self.assertEqual(runner.run(argv=argv, _serve=check_server), 0) - - def test_returned_app(self): - import waitress.tests.fixtureapps.runner as _apps - - def check_server(app, **kw): - self.assertIs(app, _apps.app) - self.assertDictEqual(kw, {"port": "80"}) - - argv = [ - "waitress-serve", - "--port=80", - "--call", - "waitress.tests.fixtureapps.runner:returns_app", - ] - self.assertEqual(runner.run(argv=argv, _serve=check_server), 0) - - -class Test_helper(unittest.TestCase): - def test_exception_logging(self): - from waitress.runner import show_exception - - regex = ( - r"There was an exception \(ImportError\) importing your module." - r"\n\nIt had these arguments: \n1. My reason" - ) - - with capture() as captured: - try: - raise ImportError("My reason") - except ImportError: - self.assertEqual(show_exception(sys.stderr), None) - self.assertRegexpMatches(captured.getvalue(), regex) - captured.close() - - regex = ( - r"There was an exception \(ImportError\) importing your module." - r"\n\nIt had no arguments." 
- ) - - with capture() as captured: - try: - raise ImportError - except ImportError: - self.assertEqual(show_exception(sys.stderr), None) - self.assertRegexpMatches(captured.getvalue(), regex) - captured.close() - - -@contextlib.contextmanager -def capture(): - from waitress.compat import NativeIO - - fd = NativeIO() - sys.stdout = fd - sys.stderr = fd - yield fd - sys.stdout = sys.__stdout__ - sys.stderr = sys.__stderr__ diff --git a/libs/waitress/tests/test_server.py b/libs/waitress/tests/test_server.py deleted file mode 100644 index 9134fb8c1..000000000 --- a/libs/waitress/tests/test_server.py +++ /dev/null @@ -1,533 +0,0 @@ -import errno -import socket -import unittest - -dummy_app = object() - - -class TestWSGIServer(unittest.TestCase): - def _makeOne( - self, - application=dummy_app, - host="127.0.0.1", - port=0, - _dispatcher=None, - adj=None, - map=None, - _start=True, - _sock=None, - _server=None, - ): - from waitress.server import create_server - - self.inst = create_server( - application, - host=host, - port=port, - map=map, - _dispatcher=_dispatcher, - _start=_start, - _sock=_sock, - ) - return self.inst - - def _makeOneWithMap( - self, adj=None, _start=True, host="127.0.0.1", port=0, app=dummy_app - ): - sock = DummySock() - task_dispatcher = DummyTaskDispatcher() - map = {} - return self._makeOne( - app, - host=host, - port=port, - map=map, - _sock=sock, - _dispatcher=task_dispatcher, - _start=_start, - ) - - def _makeOneWithMulti( - self, adj=None, _start=True, app=dummy_app, listen="127.0.0.1:0 127.0.0.1:0" - ): - sock = DummySock() - task_dispatcher = DummyTaskDispatcher() - map = {} - from waitress.server import create_server - - self.inst = create_server( - app, - listen=listen, - map=map, - _dispatcher=task_dispatcher, - _start=_start, - _sock=sock, - ) - return self.inst - - def _makeWithSockets( - self, - application=dummy_app, - _dispatcher=None, - map=None, - _start=True, - _sock=None, - _server=None, - sockets=None, - ): - from waitress.server import create_server - - _sockets = [] - if sockets is not None: - _sockets = sockets - self.inst = create_server( - application, - map=map, - _dispatcher=_dispatcher, - _start=_start, - _sock=_sock, - sockets=_sockets, - ) - return self.inst - - def tearDown(self): - if self.inst is not None: - self.inst.close() - - def test_ctor_app_is_None(self): - self.inst = None - self.assertRaises(ValueError, self._makeOneWithMap, app=None) - - def test_ctor_start_true(self): - inst = self._makeOneWithMap(_start=True) - self.assertEqual(inst.accepting, True) - self.assertEqual(inst.socket.listened, 1024) - - def test_ctor_makes_dispatcher(self): - inst = self._makeOne(_start=False, map={}) - self.assertEqual( - inst.task_dispatcher.__class__.__name__, "ThreadedTaskDispatcher" - ) - - def test_ctor_start_false(self): - inst = self._makeOneWithMap(_start=False) - self.assertEqual(inst.accepting, False) - - def test_get_server_name_empty(self): - inst = self._makeOneWithMap(_start=False) - self.assertRaises(ValueError, inst.get_server_name, "") - - def test_get_server_name_with_ip(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name("127.0.0.1") - self.assertTrue(result) - - def test_get_server_name_with_hostname(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name("fred.flintstone.com") - self.assertEqual(result, "fred.flintstone.com") - - def test_get_server_name_0000(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name("0.0.0.0") - 
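# The runner tests above exercise the waitress-serve console entry point,
# whose shell form is "waitress-serve --port=80 module:callable". The
# spec-parsing half is pure and can be shown directly; mypackage.wsgi:app
# is an illustrative spec, not a real module.
from waitress import runner

# match() splits the "module:object" spec; resolve() then imports it.
module, obj = runner.match("mypackage.wsgi:app")
assert (module, obj) == ("mypackage.wsgi", "app")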
self.assertTrue(len(result) != 0) - - def test_get_server_name_double_colon(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name("::") - self.assertTrue(len(result) != 0) - - def test_get_server_name_ipv6(self): - inst = self._makeOneWithMap(_start=False) - result = inst.get_server_name("2001:DB8::ffff") - self.assertEqual("[2001:DB8::ffff]", result) - - def test_get_server_multi(self): - inst = self._makeOneWithMulti() - self.assertEqual(inst.__class__.__name__, "MultiSocketServer") - - def test_run(self): - inst = self._makeOneWithMap(_start=False) - inst.asyncore = DummyAsyncore() - inst.task_dispatcher = DummyTaskDispatcher() - inst.run() - self.assertTrue(inst.task_dispatcher.was_shutdown) - - def test_run_base_server(self): - inst = self._makeOneWithMulti(_start=False) - inst.asyncore = DummyAsyncore() - inst.task_dispatcher = DummyTaskDispatcher() - inst.run() - self.assertTrue(inst.task_dispatcher.was_shutdown) - - def test_pull_trigger(self): - inst = self._makeOneWithMap(_start=False) - inst.trigger.close() - inst.trigger = DummyTrigger() - inst.pull_trigger() - self.assertEqual(inst.trigger.pulled, True) - - def test_add_task(self): - task = DummyTask() - inst = self._makeOneWithMap() - inst.add_task(task) - self.assertEqual(inst.task_dispatcher.tasks, [task]) - self.assertFalse(task.serviced) - - def test_readable_not_accepting(self): - inst = self._makeOneWithMap() - inst.accepting = False - self.assertFalse(inst.readable()) - - def test_readable_maplen_gt_connection_limit(self): - inst = self._makeOneWithMap() - inst.accepting = True - inst.adj = DummyAdj - inst._map = {"a": 1, "b": 2} - self.assertFalse(inst.readable()) - - def test_readable_maplen_lt_connection_limit(self): - inst = self._makeOneWithMap() - inst.accepting = True - inst.adj = DummyAdj - inst._map = {} - self.assertTrue(inst.readable()) - - def test_readable_maintenance_false(self): - import time - - inst = self._makeOneWithMap() - then = time.time() + 1000 - inst.next_channel_cleanup = then - L = [] - inst.maintenance = lambda t: L.append(t) - inst.readable() - self.assertEqual(L, []) - self.assertEqual(inst.next_channel_cleanup, then) - - def test_readable_maintenance_true(self): - inst = self._makeOneWithMap() - inst.next_channel_cleanup = 0 - L = [] - inst.maintenance = lambda t: L.append(t) - inst.readable() - self.assertEqual(len(L), 1) - self.assertNotEqual(inst.next_channel_cleanup, 0) - - def test_writable(self): - inst = self._makeOneWithMap() - self.assertFalse(inst.writable()) - - def test_handle_read(self): - inst = self._makeOneWithMap() - self.assertEqual(inst.handle_read(), None) - - def test_handle_connect(self): - inst = self._makeOneWithMap() - self.assertEqual(inst.handle_connect(), None) - - def test_handle_accept_wouldblock_socket_error(self): - inst = self._makeOneWithMap() - ewouldblock = socket.error(errno.EWOULDBLOCK) - inst.socket = DummySock(toraise=ewouldblock) - inst.handle_accept() - self.assertEqual(inst.socket.accepted, False) - - def test_handle_accept_other_socket_error(self): - inst = self._makeOneWithMap() - eaborted = socket.error(errno.ECONNABORTED) - inst.socket = DummySock(toraise=eaborted) - inst.adj = DummyAdj - - def foo(): - raise socket.error - - inst.accept = foo - inst.logger = DummyLogger() - inst.handle_accept() - self.assertEqual(inst.socket.accepted, False) - self.assertEqual(len(inst.logger.logged), 1) - - def test_handle_accept_noerror(self): - inst = self._makeOneWithMap() - innersock = DummySock() - inst.socket = 
DummySock(acceptresult=(innersock, None)) - inst.adj = DummyAdj - L = [] - inst.channel_class = lambda *arg, **kw: L.append(arg) - inst.handle_accept() - self.assertEqual(inst.socket.accepted, True) - self.assertEqual(innersock.opts, [("level", "optname", "value")]) - self.assertEqual(L, [(inst, innersock, None, inst.adj)]) - - def test_maintenance(self): - inst = self._makeOneWithMap() - - class DummyChannel(object): - requests = [] - - zombie = DummyChannel() - zombie.last_activity = 0 - zombie.running_tasks = False - inst.active_channels[100] = zombie - inst.maintenance(10000) - self.assertEqual(zombie.will_close, True) - - def test_backward_compatibility(self): - from waitress.server import WSGIServer, TcpWSGIServer - from waitress.adjustments import Adjustments - - self.assertTrue(WSGIServer is TcpWSGIServer) - self.inst = WSGIServer(None, _start=False, port=1234) - # Ensure the adjustment was actually applied. - self.assertNotEqual(Adjustments.port, 1234) - self.assertEqual(self.inst.adj.port, 1234) - - def test_create_with_one_tcp_socket(self): - from waitress.server import TcpWSGIServer - - sockets = [socket.socket(socket.AF_INET, socket.SOCK_STREAM)] - sockets[0].bind(("127.0.0.1", 0)) - inst = self._makeWithSockets(_start=False, sockets=sockets) - self.assertTrue(isinstance(inst, TcpWSGIServer)) - - def test_create_with_multiple_tcp_sockets(self): - from waitress.server import MultiSocketServer - - sockets = [ - socket.socket(socket.AF_INET, socket.SOCK_STREAM), - socket.socket(socket.AF_INET, socket.SOCK_STREAM), - ] - sockets[0].bind(("127.0.0.1", 0)) - sockets[1].bind(("127.0.0.1", 0)) - inst = self._makeWithSockets(_start=False, sockets=sockets) - self.assertTrue(isinstance(inst, MultiSocketServer)) - self.assertEqual(len(inst.effective_listen), 2) - - def test_create_with_one_socket_should_not_bind_socket(self): - innersock = DummySock() - sockets = [DummySock(acceptresult=(innersock, None))] - sockets[0].bind(("127.0.0.1", 80)) - sockets[0].bind_called = False - inst = self._makeWithSockets(_start=False, sockets=sockets) - self.assertEqual(inst.socket.bound, ("127.0.0.1", 80)) - self.assertFalse(inst.socket.bind_called) - - def test_create_with_one_socket_handle_accept_noerror(self): - innersock = DummySock() - sockets = [DummySock(acceptresult=(innersock, None))] - sockets[0].bind(("127.0.0.1", 80)) - inst = self._makeWithSockets(sockets=sockets) - L = [] - inst.channel_class = lambda *arg, **kw: L.append(arg) - inst.adj = DummyAdj - inst.handle_accept() - self.assertEqual(sockets[0].accepted, True) - self.assertEqual(innersock.opts, [("level", "optname", "value")]) - self.assertEqual(L, [(inst, innersock, None, inst.adj)]) - - -if hasattr(socket, "AF_UNIX"): - - class TestUnixWSGIServer(unittest.TestCase): - unix_socket = "/tmp/waitress.test.sock" - - def _makeOne(self, _start=True, _sock=None): - from waitress.server import create_server - - self.inst = create_server( - dummy_app, - map={}, - _start=_start, - _sock=_sock, - _dispatcher=DummyTaskDispatcher(), - unix_socket=self.unix_socket, - unix_socket_perms="600", - ) - return self.inst - - def _makeWithSockets( - self, - application=dummy_app, - _dispatcher=None, - map=None, - _start=True, - _sock=None, - _server=None, - sockets=None, - ): - from waitress.server import create_server - - _sockets = [] - if sockets is not None: - _sockets = sockets - self.inst = create_server( - application, - map=map, - _dispatcher=_dispatcher, - _start=_start, - _sock=_sock, - sockets=_sockets, - ) - return self.inst - - def 
tearDown(self): - self.inst.close() - - def _makeDummy(self, *args, **kwargs): - sock = DummySock(*args, **kwargs) - sock.family = socket.AF_UNIX - return sock - - def test_unix(self): - inst = self._makeOne(_start=False) - self.assertEqual(inst.socket.family, socket.AF_UNIX) - self.assertEqual(inst.socket.getsockname(), self.unix_socket) - - def test_handle_accept(self): - # Working on the assumption that we only have to test the happy path - # for Unix domain sockets as the other paths should've been covered - # by inet sockets. - client = self._makeDummy() - listen = self._makeDummy(acceptresult=(client, None)) - inst = self._makeOne(_sock=listen) - self.assertEqual(inst.accepting, True) - self.assertEqual(inst.socket.listened, 1024) - L = [] - inst.channel_class = lambda *arg, **kw: L.append(arg) - inst.handle_accept() - self.assertEqual(inst.socket.accepted, True) - self.assertEqual(client.opts, []) - self.assertEqual(L, [(inst, client, ("localhost", None), inst.adj)]) - - def test_creates_new_sockinfo(self): - from waitress.server import UnixWSGIServer - - self.inst = UnixWSGIServer( - dummy_app, unix_socket=self.unix_socket, unix_socket_perms="600" - ) - - self.assertEqual(self.inst.sockinfo[0], socket.AF_UNIX) - - def test_create_with_unix_socket(self): - from waitress.server import ( - MultiSocketServer, - BaseWSGIServer, - TcpWSGIServer, - UnixWSGIServer, - ) - - sockets = [ - socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), - socket.socket(socket.AF_UNIX, socket.SOCK_STREAM), - ] - inst = self._makeWithSockets(sockets=sockets, _start=False) - self.assertTrue(isinstance(inst, MultiSocketServer)) - server = list( - filter(lambda s: isinstance(s, BaseWSGIServer), inst.map.values()) - ) - self.assertTrue(isinstance(server[0], UnixWSGIServer)) - self.assertTrue(isinstance(server[1], UnixWSGIServer)) - - -class DummySock(socket.socket): - accepted = False - blocking = False - family = socket.AF_INET - type = socket.SOCK_STREAM - proto = 0 - - def __init__(self, toraise=None, acceptresult=(None, None)): - self.toraise = toraise - self.acceptresult = acceptresult - self.bound = None - self.opts = [] - self.bind_called = False - - def bind(self, addr): - self.bind_called = True - self.bound = addr - - def accept(self): - if self.toraise: - raise self.toraise - self.accepted = True - return self.acceptresult - - def setblocking(self, x): - self.blocking = True - - def fileno(self): - return 10 - - def getpeername(self): - return "127.0.0.1" - - def setsockopt(self, *arg): - self.opts.append(arg) - - def getsockopt(self, *arg): - return 1 - - def listen(self, num): - self.listened = num - - def getsockname(self): - return self.bound - - def close(self): - pass - - -class DummyTaskDispatcher(object): - def __init__(self): - self.tasks = [] - - def add_task(self, task): - self.tasks.append(task) - - def shutdown(self): - self.was_shutdown = True - - -class DummyTask(object): - serviced = False - start_response_called = False - wrote_header = False - status = "200 OK" - - def __init__(self): - self.response_headers = {} - self.written = "" - - def service(self): # pragma: no cover - self.serviced = True - - -class DummyAdj: - connection_limit = 1 - log_socket_errors = True - socket_options = [("level", "optname", "value")] - cleanup_interval = 900 - channel_timeout = 300 - - -class DummyAsyncore(object): - def loop(self, timeout=30.0, use_poll=False, map=None, count=None): - raise SystemExit - - -class DummyTrigger(object): - def pull_trigger(self): - self.pulled = True - - def close(self): 
- pass - - -class DummyLogger(object): - def __init__(self): - self.logged = [] - - def warning(self, msg, **kw): - self.logged.append(msg) diff --git a/libs/waitress/tests/test_task.py b/libs/waitress/tests/test_task.py deleted file mode 100644 index 1a86245ab..000000000 --- a/libs/waitress/tests/test_task.py +++ /dev/null @@ -1,1001 +0,0 @@ -import unittest -import io - - -class TestThreadedTaskDispatcher(unittest.TestCase): - def _makeOne(self): - from waitress.task import ThreadedTaskDispatcher - - return ThreadedTaskDispatcher() - - def test_handler_thread_task_raises(self): - inst = self._makeOne() - inst.threads.add(0) - inst.logger = DummyLogger() - - class BadDummyTask(DummyTask): - def service(self): - super(BadDummyTask, self).service() - inst.stop_count += 1 - raise Exception - - task = BadDummyTask() - inst.logger = DummyLogger() - inst.queue.append(task) - inst.active_count += 1 - inst.handler_thread(0) - self.assertEqual(inst.stop_count, 0) - self.assertEqual(inst.active_count, 0) - self.assertEqual(inst.threads, set()) - self.assertEqual(len(inst.logger.logged), 1) - - def test_set_thread_count_increase(self): - inst = self._makeOne() - L = [] - inst.start_new_thread = lambda *x: L.append(x) - inst.set_thread_count(1) - self.assertEqual(L, [(inst.handler_thread, (0,))]) - - def test_set_thread_count_increase_with_existing(self): - inst = self._makeOne() - L = [] - inst.threads = {0} - inst.start_new_thread = lambda *x: L.append(x) - inst.set_thread_count(2) - self.assertEqual(L, [(inst.handler_thread, (1,))]) - - def test_set_thread_count_decrease(self): - inst = self._makeOne() - inst.threads = {0, 1} - inst.set_thread_count(1) - self.assertEqual(inst.stop_count, 1) - - def test_set_thread_count_same(self): - inst = self._makeOne() - L = [] - inst.start_new_thread = lambda *x: L.append(x) - inst.threads = {0} - inst.set_thread_count(1) - self.assertEqual(L, []) - - def test_add_task_with_idle_threads(self): - task = DummyTask() - inst = self._makeOne() - inst.threads.add(0) - inst.queue_logger = DummyLogger() - inst.add_task(task) - self.assertEqual(len(inst.queue), 1) - self.assertEqual(len(inst.queue_logger.logged), 0) - - def test_add_task_with_all_busy_threads(self): - task = DummyTask() - inst = self._makeOne() - inst.queue_logger = DummyLogger() - inst.add_task(task) - self.assertEqual(len(inst.queue_logger.logged), 1) - inst.add_task(task) - self.assertEqual(len(inst.queue_logger.logged), 2) - - def test_shutdown_one_thread(self): - inst = self._makeOne() - inst.threads.add(0) - inst.logger = DummyLogger() - task = DummyTask() - inst.queue.append(task) - self.assertEqual(inst.shutdown(timeout=0.01), True) - self.assertEqual( - inst.logger.logged, - ["1 thread(s) still running", "Canceling 1 pending task(s)",], - ) - self.assertEqual(task.cancelled, True) - - def test_shutdown_no_threads(self): - inst = self._makeOne() - self.assertEqual(inst.shutdown(timeout=0.01), True) - - def test_shutdown_no_cancel_pending(self): - inst = self._makeOne() - self.assertEqual(inst.shutdown(cancel_pending=False, timeout=0.01), False) - - -class TestTask(unittest.TestCase): - def _makeOne(self, channel=None, request=None): - if channel is None: - channel = DummyChannel() - if request is None: - request = DummyParser() - from waitress.task import Task - - return Task(channel, request) - - def test_ctor_version_not_in_known(self): - request = DummyParser() - request.version = "8.4" - inst = self._makeOne(request=request) - self.assertEqual(inst.version, "1.0") - - def 
test_build_response_header_bad_http_version(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "8.4" - self.assertRaises(AssertionError, inst.build_response_header) - - def test_build_response_header_v10_keepalive_no_content_length(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.request.headers["CONNECTION"] = "keep-alive" - inst.version = "1.0" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b"HTTP/1.0 200 OK") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(inst.close_on_finish, True) - self.assertTrue(("Connection", "close") in inst.response_headers) - - def test_build_response_header_v10_keepalive_with_content_length(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.request.headers["CONNECTION"] = "keep-alive" - inst.response_headers = [("Content-Length", "10")] - inst.version = "1.0" - inst.content_length = 0 - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b"HTTP/1.0 200 OK") - self.assertEqual(lines[1], b"Connection: Keep-Alive") - self.assertEqual(lines[2], b"Content-Length: 10") - self.assertTrue(lines[3].startswith(b"Date:")) - self.assertEqual(lines[4], b"Server: waitress") - self.assertEqual(inst.close_on_finish, False) - - def test_build_response_header_v11_connection_closed_by_client(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.1" - inst.request.headers["CONNECTION"] = "close" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b"HTTP/1.1 200 OK") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(lines[4], b"Transfer-Encoding: chunked") - self.assertTrue(("Connection", "close") in inst.response_headers) - self.assertEqual(inst.close_on_finish, True) - - def test_build_response_header_v11_connection_keepalive_by_client(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.request.headers["CONNECTION"] = "keep-alive" - inst.version = "1.1" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b"HTTP/1.1 200 OK") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(lines[4], b"Transfer-Encoding: chunked") - self.assertTrue(("Connection", "close") in inst.response_headers) - self.assertEqual(inst.close_on_finish, True) - - def test_build_response_header_v11_200_no_content_length(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.1" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b"HTTP/1.1 200 OK") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(lines[4], b"Transfer-Encoding: chunked") - self.assertEqual(inst.close_on_finish, True) - self.assertTrue(("Connection", "close") in inst.response_headers) - - def 
test_build_response_header_v11_204_no_content_length_or_transfer_encoding(self): - # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length - # for any response with a status code of 1xx or 204. - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.1" - inst.status = "204 No Content" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b"HTTP/1.1 204 No Content") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(inst.close_on_finish, True) - self.assertTrue(("Connection", "close") in inst.response_headers) - - def test_build_response_header_v11_1xx_no_content_length_or_transfer_encoding(self): - # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length - # for any response with a status code of 1xx or 204. - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.1" - inst.status = "100 Continue" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b"HTTP/1.1 100 Continue") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(inst.close_on_finish, True) - self.assertTrue(("Connection", "close") in inst.response_headers) - - def test_build_response_header_v11_304_no_content_length_or_transfer_encoding(self): - # RFC 7230: MUST NOT send Transfer-Encoding or Content-Length - # for any response with a status code of 1xx, 204 or 304. - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.1" - inst.status = "304 Not Modified" - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b"HTTP/1.1 304 Not Modified") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - self.assertEqual(inst.close_on_finish, True) - self.assertTrue(("Connection", "close") in inst.response_headers) - - def test_build_response_header_via_added(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.0" - inst.response_headers = [("Server", "abc")] - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 5) - self.assertEqual(lines[0], b"HTTP/1.0 200 OK") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: abc") - self.assertEqual(lines[4], b"Via: waitress") - - def test_build_response_header_date_exists(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.0" - inst.response_headers = [("Date", "date")] - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], b"HTTP/1.0 200 OK") - self.assertEqual(lines[1], b"Connection: close") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - - def test_build_response_header_preexisting_content_length(self): - inst = self._makeOne() - inst.request = DummyParser() - inst.version = "1.1" - inst.content_length = 100 - result = inst.build_response_header() - lines = filter_lines(result) - self.assertEqual(len(lines), 4) - self.assertEqual(lines[0], 
b"HTTP/1.1 200 OK") - self.assertEqual(lines[1], b"Content-Length: 100") - self.assertTrue(lines[2].startswith(b"Date:")) - self.assertEqual(lines[3], b"Server: waitress") - - def test_remove_content_length_header(self): - inst = self._makeOne() - inst.response_headers = [("Content-Length", "70")] - inst.remove_content_length_header() - self.assertEqual(inst.response_headers, []) - - def test_remove_content_length_header_with_other(self): - inst = self._makeOne() - inst.response_headers = [ - ("Content-Length", "70"), - ("Content-Type", "text/html"), - ] - inst.remove_content_length_header() - self.assertEqual(inst.response_headers, [("Content-Type", "text/html")]) - - def test_start(self): - inst = self._makeOne() - inst.start() - self.assertTrue(inst.start_time) - - def test_finish_didnt_write_header(self): - inst = self._makeOne() - inst.wrote_header = False - inst.complete = True - inst.finish() - self.assertTrue(inst.channel.written) - - def test_finish_wrote_header(self): - inst = self._makeOne() - inst.wrote_header = True - inst.finish() - self.assertFalse(inst.channel.written) - - def test_finish_chunked_response(self): - inst = self._makeOne() - inst.wrote_header = True - inst.chunked_response = True - inst.finish() - self.assertEqual(inst.channel.written, b"0\r\n\r\n") - - def test_write_wrote_header(self): - inst = self._makeOne() - inst.wrote_header = True - inst.complete = True - inst.content_length = 3 - inst.write(b"abc") - self.assertEqual(inst.channel.written, b"abc") - - def test_write_header_not_written(self): - inst = self._makeOne() - inst.wrote_header = False - inst.complete = True - inst.write(b"abc") - self.assertTrue(inst.channel.written) - self.assertEqual(inst.wrote_header, True) - - def test_write_start_response_uncalled(self): - inst = self._makeOne() - self.assertRaises(RuntimeError, inst.write, b"") - - def test_write_chunked_response(self): - inst = self._makeOne() - inst.wrote_header = True - inst.chunked_response = True - inst.complete = True - inst.write(b"abc") - self.assertEqual(inst.channel.written, b"3\r\nabc\r\n") - - def test_write_preexisting_content_length(self): - inst = self._makeOne() - inst.wrote_header = True - inst.complete = True - inst.content_length = 1 - inst.logger = DummyLogger() - inst.write(b"abc") - self.assertTrue(inst.channel.written) - self.assertEqual(inst.logged_write_excess, True) - self.assertEqual(len(inst.logger.logged), 1) - - -class TestWSGITask(unittest.TestCase): - def _makeOne(self, channel=None, request=None): - if channel is None: - channel = DummyChannel() - if request is None: - request = DummyParser() - from waitress.task import WSGITask - - return WSGITask(channel, request) - - def test_service(self): - inst = self._makeOne() - - def execute(): - inst.executed = True - - inst.execute = execute - inst.complete = True - inst.service() - self.assertTrue(inst.start_time) - self.assertTrue(inst.close_on_finish) - self.assertTrue(inst.channel.written) - self.assertEqual(inst.executed, True) - - def test_service_server_raises_socket_error(self): - import socket - - inst = self._makeOne() - - def execute(): - raise socket.error - - inst.execute = execute - self.assertRaises(socket.error, inst.service) - self.assertTrue(inst.start_time) - self.assertTrue(inst.close_on_finish) - self.assertFalse(inst.channel.written) - - def test_execute_app_calls_start_response_twice_wo_exc_info(self): - def app(environ, start_response): - start_response("200 OK", []) - start_response("200 OK", []) - - inst = self._makeOne() - 
inst.channel.server.application = app - self.assertRaises(AssertionError, inst.execute) - - def test_execute_app_calls_start_response_w_exc_info_complete(self): - def app(environ, start_response): - start_response("200 OK", [], [ValueError, ValueError(), None]) - return [b"a"] - - inst = self._makeOne() - inst.complete = True - inst.channel.server.application = app - inst.execute() - self.assertTrue(inst.complete) - self.assertEqual(inst.status, "200 OK") - self.assertTrue(inst.channel.written) - - def test_execute_app_calls_start_response_w_excinf_headers_unwritten(self): - def app(environ, start_response): - start_response("200 OK", [], [ValueError, None, None]) - return [b"a"] - - inst = self._makeOne() - inst.wrote_header = False - inst.channel.server.application = app - inst.response_headers = [("a", "b")] - inst.execute() - self.assertTrue(inst.complete) - self.assertEqual(inst.status, "200 OK") - self.assertTrue(inst.channel.written) - self.assertFalse(("a", "b") in inst.response_headers) - - def test_execute_app_calls_start_response_w_excinf_headers_written(self): - def app(environ, start_response): - start_response("200 OK", [], [ValueError, ValueError(), None]) - - inst = self._makeOne() - inst.complete = True - inst.wrote_header = True - inst.channel.server.application = app - self.assertRaises(ValueError, inst.execute) - - def test_execute_bad_header_key(self): - def app(environ, start_response): - start_response("200 OK", [(None, "a")]) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(AssertionError, inst.execute) - - def test_execute_bad_header_value(self): - def app(environ, start_response): - start_response("200 OK", [("a", None)]) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(AssertionError, inst.execute) - - def test_execute_hopbyhop_header(self): - def app(environ, start_response): - start_response("200 OK", [("Connection", "close")]) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(AssertionError, inst.execute) - - def test_execute_bad_header_value_control_characters(self): - def app(environ, start_response): - start_response("200 OK", [("a", "\n")]) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(ValueError, inst.execute) - - def test_execute_bad_header_name_control_characters(self): - def app(environ, start_response): - start_response("200 OK", [("a\r", "value")]) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(ValueError, inst.execute) - - def test_execute_bad_status_control_characters(self): - def app(environ, start_response): - start_response("200 OK\r", []) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(ValueError, inst.execute) - - def test_preserve_header_value_order(self): - def app(environ, start_response): - write = start_response("200 OK", [("C", "b"), ("A", "b"), ("A", "a")]) - write(b"abc") - return [] - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertTrue(b"A: b\r\nA: a\r\nC: b\r\n" in inst.channel.written) - - def test_execute_bad_status_value(self): - def app(environ, start_response): - start_response(None, []) - - inst = self._makeOne() - inst.channel.server.application = app - self.assertRaises(AssertionError, inst.execute) - - def test_execute_with_content_length_header(self): - def app(environ, start_response): - start_response("200 OK", [("Content-Length", "1")]) - return 
[b"a"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertEqual(inst.content_length, 1) - - def test_execute_app_calls_write(self): - def app(environ, start_response): - write = start_response("200 OK", [("Content-Length", "3")]) - write(b"abc") - return [] - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertEqual(inst.channel.written[-3:], b"abc") - - def test_execute_app_returns_len1_chunk_without_cl(self): - def app(environ, start_response): - start_response("200 OK", []) - return [b"abc"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertEqual(inst.content_length, 3) - - def test_execute_app_returns_empty_chunk_as_first(self): - def app(environ, start_response): - start_response("200 OK", []) - return ["", b"abc"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertEqual(inst.content_length, None) - - def test_execute_app_returns_too_many_bytes(self): - def app(environ, start_response): - start_response("200 OK", [("Content-Length", "1")]) - return [b"abc"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.logger = DummyLogger() - inst.execute() - self.assertEqual(inst.close_on_finish, True) - self.assertEqual(len(inst.logger.logged), 1) - - def test_execute_app_returns_too_few_bytes(self): - def app(environ, start_response): - start_response("200 OK", [("Content-Length", "3")]) - return [b"a"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.logger = DummyLogger() - inst.execute() - self.assertEqual(inst.close_on_finish, True) - self.assertEqual(len(inst.logger.logged), 1) - - def test_execute_app_do_not_warn_on_head(self): - def app(environ, start_response): - start_response("200 OK", [("Content-Length", "3")]) - return [b""] - - inst = self._makeOne() - inst.request.command = "HEAD" - inst.channel.server.application = app - inst.logger = DummyLogger() - inst.execute() - self.assertEqual(inst.close_on_finish, True) - self.assertEqual(len(inst.logger.logged), 0) - - def test_execute_app_without_body_204_logged(self): - def app(environ, start_response): - start_response("204 No Content", [("Content-Length", "3")]) - return [b"abc"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.logger = DummyLogger() - inst.execute() - self.assertEqual(inst.close_on_finish, True) - self.assertNotIn(b"abc", inst.channel.written) - self.assertNotIn(b"Content-Length", inst.channel.written) - self.assertNotIn(b"Transfer-Encoding", inst.channel.written) - self.assertEqual(len(inst.logger.logged), 1) - - def test_execute_app_without_body_304_logged(self): - def app(environ, start_response): - start_response("304 Not Modified", [("Content-Length", "3")]) - return [b"abc"] - - inst = self._makeOne() - inst.channel.server.application = app - inst.logger = DummyLogger() - inst.execute() - self.assertEqual(inst.close_on_finish, True) - self.assertNotIn(b"abc", inst.channel.written) - self.assertNotIn(b"Content-Length", inst.channel.written) - self.assertNotIn(b"Transfer-Encoding", inst.channel.written) - self.assertEqual(len(inst.logger.logged), 1) - - def test_execute_app_returns_closeable(self): - class closeable(list): - def close(self): - self.closed = True - - foo = closeable([b"abc"]) - - def app(environ, start_response): - start_response("200 OK", [("Content-Length", "3")]) - return foo - - inst = self._makeOne() - inst.channel.server.application = app 
- inst.execute() - self.assertEqual(foo.closed, True) - - def test_execute_app_returns_filewrapper_prepare_returns_True(self): - from waitress.buffers import ReadOnlyFileBasedBuffer - - f = io.BytesIO(b"abc") - app_iter = ReadOnlyFileBasedBuffer(f, 8192) - - def app(environ, start_response): - start_response("200 OK", [("Content-Length", "3")]) - return app_iter - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertTrue(inst.channel.written) # header - self.assertEqual(inst.channel.otherdata, [app_iter]) - - def test_execute_app_returns_filewrapper_prepare_returns_True_nocl(self): - from waitress.buffers import ReadOnlyFileBasedBuffer - - f = io.BytesIO(b"abc") - app_iter = ReadOnlyFileBasedBuffer(f, 8192) - - def app(environ, start_response): - start_response("200 OK", []) - return app_iter - - inst = self._makeOne() - inst.channel.server.application = app - inst.execute() - self.assertTrue(inst.channel.written) # header - self.assertEqual(inst.channel.otherdata, [app_iter]) - self.assertEqual(inst.content_length, 3) - - def test_execute_app_returns_filewrapper_prepare_returns_True_badcl(self): - from waitress.buffers import ReadOnlyFileBasedBuffer - - f = io.BytesIO(b"abc") - app_iter = ReadOnlyFileBasedBuffer(f, 8192) - - def app(environ, start_response): - start_response("200 OK", []) - return app_iter - - inst = self._makeOne() - inst.channel.server.application = app - inst.content_length = 10 - inst.response_headers = [("Content-Length", "10")] - inst.execute() - self.assertTrue(inst.channel.written) # header - self.assertEqual(inst.channel.otherdata, [app_iter]) - self.assertEqual(inst.content_length, 3) - self.assertEqual(dict(inst.response_headers)["Content-Length"], "3") - - def test_get_environment_already_cached(self): - inst = self._makeOne() - inst.environ = object() - self.assertEqual(inst.get_environment(), inst.environ) - - def test_get_environment_path_startswith_more_than_one_slash(self): - inst = self._makeOne() - request = DummyParser() - request.path = "///abc" - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["PATH_INFO"], "/abc") - - def test_get_environment_path_empty(self): - inst = self._makeOne() - request = DummyParser() - request.path = "" - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["PATH_INFO"], "") - - def test_get_environment_no_query(self): - inst = self._makeOne() - request = DummyParser() - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["QUERY_STRING"], "") - - def test_get_environment_with_query(self): - inst = self._makeOne() - request = DummyParser() - request.query = "abc" - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["QUERY_STRING"], "abc") - - def test_get_environ_with_url_prefix_miss(self): - inst = self._makeOne() - inst.channel.server.adj.url_prefix = "/foo" - request = DummyParser() - request.path = "/bar" - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["PATH_INFO"], "/bar") - self.assertEqual(environ["SCRIPT_NAME"], "/foo") - - def test_get_environ_with_url_prefix_hit(self): - inst = self._makeOne() - inst.channel.server.adj.url_prefix = "/foo" - request = DummyParser() - request.path = "/foo/fuz" - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["PATH_INFO"], "/fuz") - self.assertEqual(environ["SCRIPT_NAME"], "/foo") - - def 
test_get_environ_with_url_prefix_empty_path(self): - inst = self._makeOne() - inst.channel.server.adj.url_prefix = "/foo" - request = DummyParser() - request.path = "/foo" - inst.request = request - environ = inst.get_environment() - self.assertEqual(environ["PATH_INFO"], "") - self.assertEqual(environ["SCRIPT_NAME"], "/foo") - - def test_get_environment_values(self): - import sys - - inst = self._makeOne() - request = DummyParser() - request.headers = { - "CONTENT_TYPE": "abc", - "CONTENT_LENGTH": "10", - "X_FOO": "BAR", - "CONNECTION": "close", - } - request.query = "abc" - inst.request = request - environ = inst.get_environment() - - # nail the keys of environ - self.assertEqual( - sorted(environ.keys()), - [ - "CONTENT_LENGTH", - "CONTENT_TYPE", - "HTTP_CONNECTION", - "HTTP_X_FOO", - "PATH_INFO", - "QUERY_STRING", - "REMOTE_ADDR", - "REMOTE_HOST", - "REMOTE_PORT", - "REQUEST_METHOD", - "SCRIPT_NAME", - "SERVER_NAME", - "SERVER_PORT", - "SERVER_PROTOCOL", - "SERVER_SOFTWARE", - "wsgi.errors", - "wsgi.file_wrapper", - "wsgi.input", - "wsgi.input_terminated", - "wsgi.multiprocess", - "wsgi.multithread", - "wsgi.run_once", - "wsgi.url_scheme", - "wsgi.version", - ], - ) - - self.assertEqual(environ["REQUEST_METHOD"], "GET") - self.assertEqual(environ["SERVER_PORT"], "80") - self.assertEqual(environ["SERVER_NAME"], "localhost") - self.assertEqual(environ["SERVER_SOFTWARE"], "waitress") - self.assertEqual(environ["SERVER_PROTOCOL"], "HTTP/1.0") - self.assertEqual(environ["SCRIPT_NAME"], "") - self.assertEqual(environ["HTTP_CONNECTION"], "close") - self.assertEqual(environ["PATH_INFO"], "/") - self.assertEqual(environ["QUERY_STRING"], "abc") - self.assertEqual(environ["REMOTE_ADDR"], "127.0.0.1") - self.assertEqual(environ["REMOTE_HOST"], "127.0.0.1") - self.assertEqual(environ["REMOTE_PORT"], "39830") - self.assertEqual(environ["CONTENT_TYPE"], "abc") - self.assertEqual(environ["CONTENT_LENGTH"], "10") - self.assertEqual(environ["HTTP_X_FOO"], "BAR") - self.assertEqual(environ["wsgi.version"], (1, 0)) - self.assertEqual(environ["wsgi.url_scheme"], "http") - self.assertEqual(environ["wsgi.errors"], sys.stderr) - self.assertEqual(environ["wsgi.multithread"], True) - self.assertEqual(environ["wsgi.multiprocess"], False) - self.assertEqual(environ["wsgi.run_once"], False) - self.assertEqual(environ["wsgi.input"], "stream") - self.assertEqual(environ["wsgi.input_terminated"], True) - self.assertEqual(inst.environ, environ) - - -class TestErrorTask(unittest.TestCase): - def _makeOne(self, channel=None, request=None): - if channel is None: - channel = DummyChannel() - if request is None: - request = DummyParser() - request.error = self._makeDummyError() - from waitress.task import ErrorTask - - return ErrorTask(channel, request) - - def _makeDummyError(self): - from waitress.utilities import Error - - e = Error("body") - e.code = 432 - e.reason = "Too Ugly" - return e - - def test_execute_http_10(self): - inst = self._makeOne() - inst.execute() - lines = filter_lines(inst.channel.written) - self.assertEqual(len(lines), 9) - self.assertEqual(lines[0], b"HTTP/1.0 432 Too Ugly") - self.assertEqual(lines[1], b"Connection: close") - self.assertEqual(lines[2], b"Content-Length: 43") - self.assertEqual(lines[3], b"Content-Type: text/plain") - self.assertTrue(lines[4]) - self.assertEqual(lines[5], b"Server: waitress") - self.assertEqual(lines[6], b"Too Ugly") - self.assertEqual(lines[7], b"body") - self.assertEqual(lines[8], b"(generated by waitress)") - - def test_execute_http_11(self): - inst = 
self._makeOne() - inst.version = "1.1" - inst.execute() - lines = filter_lines(inst.channel.written) - self.assertEqual(len(lines), 9) - self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") - self.assertEqual(lines[1], b"Connection: close") - self.assertEqual(lines[2], b"Content-Length: 43") - self.assertEqual(lines[3], b"Content-Type: text/plain") - self.assertTrue(lines[4]) - self.assertEqual(lines[5], b"Server: waitress") - self.assertEqual(lines[6], b"Too Ugly") - self.assertEqual(lines[7], b"body") - self.assertEqual(lines[8], b"(generated by waitress)") - - def test_execute_http_11_close(self): - inst = self._makeOne() - inst.version = "1.1" - inst.request.headers["CONNECTION"] = "close" - inst.execute() - lines = filter_lines(inst.channel.written) - self.assertEqual(len(lines), 9) - self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") - self.assertEqual(lines[1], b"Connection: close") - self.assertEqual(lines[2], b"Content-Length: 43") - self.assertEqual(lines[3], b"Content-Type: text/plain") - self.assertTrue(lines[4]) - self.assertEqual(lines[5], b"Server: waitress") - self.assertEqual(lines[6], b"Too Ugly") - self.assertEqual(lines[7], b"body") - self.assertEqual(lines[8], b"(generated by waitress)") - - def test_execute_http_11_keep_forces_close(self): - inst = self._makeOne() - inst.version = "1.1" - inst.request.headers["CONNECTION"] = "keep-alive" - inst.execute() - lines = filter_lines(inst.channel.written) - self.assertEqual(len(lines), 9) - self.assertEqual(lines[0], b"HTTP/1.1 432 Too Ugly") - self.assertEqual(lines[1], b"Connection: close") - self.assertEqual(lines[2], b"Content-Length: 43") - self.assertEqual(lines[3], b"Content-Type: text/plain") - self.assertTrue(lines[4]) - self.assertEqual(lines[5], b"Server: waitress") - self.assertEqual(lines[6], b"Too Ugly") - self.assertEqual(lines[7], b"body") - self.assertEqual(lines[8], b"(generated by waitress)") - - -class DummyTask(object): - serviced = False - cancelled = False - - def service(self): - self.serviced = True - - def cancel(self): - self.cancelled = True - - -class DummyAdj(object): - log_socket_errors = True - ident = "waitress" - host = "127.0.0.1" - port = 80 - url_prefix = "" - - -class DummyServer(object): - server_name = "localhost" - effective_port = 80 - - def __init__(self): - self.adj = DummyAdj() - - -class DummyChannel(object): - closed_when_done = False - adj = DummyAdj() - creation_time = 0 - addr = ("127.0.0.1", 39830) - - def __init__(self, server=None): - if server is None: - server = DummyServer() - self.server = server - self.written = b"" - self.otherdata = [] - - def write_soon(self, data): - if isinstance(data, bytes): - self.written += data - else: - self.otherdata.append(data) - return len(data) - - -class DummyParser(object): - version = "1.0" - command = "GET" - path = "/" - query = "" - url_scheme = "http" - expect_continue = False - headers_finished = False - - def __init__(self): - self.headers = {} - - def get_body_stream(self): - return "stream" - - -def filter_lines(s): - return list(filter(None, s.split(b"\r\n"))) - - -class DummyLogger(object): - def __init__(self): - self.logged = [] - - def warning(self, msg, *args): - self.logged.append(msg % args) - - def exception(self, msg, *args): - self.logged.append(msg % args) diff --git a/libs/waitress/tests/test_trigger.py b/libs/waitress/tests/test_trigger.py deleted file mode 100644 index af740f68d..000000000 --- a/libs/waitress/tests/test_trigger.py +++ /dev/null @@ -1,111 +0,0 @@ -import unittest -import os -import sys - -if 
not sys.platform.startswith("win"): - - class Test_trigger(unittest.TestCase): - def _makeOne(self, map): - from waitress.trigger import trigger - - self.inst = trigger(map) - return self.inst - - def tearDown(self): - self.inst.close() # prevent __del__ warning from file_dispatcher - - def test__close(self): - map = {} - inst = self._makeOne(map) - fd1, fd2 = inst._fds - inst.close() - self.assertRaises(OSError, os.read, fd1, 1) - self.assertRaises(OSError, os.read, fd2, 1) - - def test__physical_pull(self): - map = {} - inst = self._makeOne(map) - inst._physical_pull() - r = os.read(inst._fds[0], 1) - self.assertEqual(r, b"x") - - def test_readable(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.readable(), True) - - def test_writable(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.writable(), False) - - def test_handle_connect(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.handle_connect(), None) - - def test_close(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.close(), None) - self.assertEqual(inst._closed, True) - - def test_handle_close(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.handle_close(), None) - self.assertEqual(inst._closed, True) - - def test_pull_trigger_nothunk(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.pull_trigger(), None) - r = os.read(inst._fds[0], 1) - self.assertEqual(r, b"x") - - def test_pull_trigger_thunk(self): - map = {} - inst = self._makeOne(map) - self.assertEqual(inst.pull_trigger(True), None) - self.assertEqual(len(inst.thunks), 1) - r = os.read(inst._fds[0], 1) - self.assertEqual(r, b"x") - - def test_handle_read_socket_error(self): - map = {} - inst = self._makeOne(map) - result = inst.handle_read() - self.assertEqual(result, None) - - def test_handle_read_no_socket_error(self): - map = {} - inst = self._makeOne(map) - inst.pull_trigger() - result = inst.handle_read() - self.assertEqual(result, None) - - def test_handle_read_thunk(self): - map = {} - inst = self._makeOne(map) - inst.pull_trigger() - L = [] - inst.thunks = [lambda: L.append(True)] - result = inst.handle_read() - self.assertEqual(result, None) - self.assertEqual(L, [True]) - self.assertEqual(inst.thunks, []) - - def test_handle_read_thunk_error(self): - map = {} - inst = self._makeOne(map) - - def errorthunk(): - raise ValueError - - inst.pull_trigger(errorthunk) - L = [] - inst.log_info = lambda *arg: L.append(arg) - result = inst.handle_read() - self.assertEqual(result, None) - self.assertEqual(len(L), 1) - self.assertEqual(inst.thunks, []) diff --git a/libs/waitress/tests/test_utilities.py b/libs/waitress/tests/test_utilities.py deleted file mode 100644 index 15cd24f5a..000000000 --- a/libs/waitress/tests/test_utilities.py +++ /dev/null @@ -1,140 +0,0 @@ -############################################################################## -# -# Copyright (c) 2002 Zope Foundation and Contributors. -# All Rights Reserved. -# -# This software is subject to the provisions of the Zope Public License, -# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution. -# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED -# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS -# FOR A PARTICULAR PURPOSE. 
-#
-##############################################################################
-
-import unittest
-
-
-class Test_parse_http_date(unittest.TestCase):
-    def _callFUT(self, v):
-        from waitress.utilities import parse_http_date
-
-        return parse_http_date(v)
-
-    def test_rfc850(self):
-        val = "Tuesday, 08-Feb-94 14:15:29 GMT"
-        result = self._callFUT(val)
-        self.assertEqual(result, 760716929)
-
-    def test_rfc822(self):
-        val = "Sun, 08 Feb 1994 14:15:29 GMT"
-        result = self._callFUT(val)
-        self.assertEqual(result, 760716929)
-
-    def test_neither(self):
-        val = ""
-        result = self._callFUT(val)
-        self.assertEqual(result, 0)
-
-
-class Test_build_http_date(unittest.TestCase):
-    def test_rountdrip(self):
-        from waitress.utilities import build_http_date, parse_http_date
-        from time import time
-
-        t = int(time())
-        self.assertEqual(t, parse_http_date(build_http_date(t)))
-
-
-class Test_unpack_rfc850(unittest.TestCase):
-    def _callFUT(self, val):
-        from waitress.utilities import unpack_rfc850, rfc850_reg
-
-        return unpack_rfc850(rfc850_reg.match(val.lower()))
-
-    def test_it(self):
-        val = "Tuesday, 08-Feb-94 14:15:29 GMT"
-        result = self._callFUT(val)
-        self.assertEqual(result, (1994, 2, 8, 14, 15, 29, 0, 0, 0))
-
-
-class Test_unpack_rfc_822(unittest.TestCase):
-    def _callFUT(self, val):
-        from waitress.utilities import unpack_rfc822, rfc822_reg
-
-        return unpack_rfc822(rfc822_reg.match(val.lower()))
-
-    def test_it(self):
-        val = "Sun, 08 Feb 1994 14:15:29 GMT"
-        result = self._callFUT(val)
-        self.assertEqual(result, (1994, 2, 8, 14, 15, 29, 0, 0, 0))
-
-
-class Test_find_double_newline(unittest.TestCase):
-    def _callFUT(self, val):
-        from waitress.utilities import find_double_newline
-
-        return find_double_newline(val)
-
-    def test_empty(self):
-        self.assertEqual(self._callFUT(b""), -1)
-
-    def test_one_linefeed(self):
-        self.assertEqual(self._callFUT(b"\n"), -1)
-
-    def test_double_linefeed(self):
-        self.assertEqual(self._callFUT(b"\n\n"), -1)
-
-    def test_one_crlf(self):
-        self.assertEqual(self._callFUT(b"\r\n"), -1)
-
-    def test_double_crfl(self):
-        self.assertEqual(self._callFUT(b"\r\n\r\n"), 4)
-
-    def test_mixed(self):
-        self.assertEqual(self._callFUT(b"\n\n00\r\n\r\n"), 8)
-
-
-class TestBadRequest(unittest.TestCase):
-    def _makeOne(self):
-        from waitress.utilities import BadRequest
-
-        return BadRequest(1)
-
-    def test_it(self):
-        inst = self._makeOne()
-        self.assertEqual(inst.body, 1)
-
-
-class Test_undquote(unittest.TestCase):
-    def _callFUT(self, value):
-        from waitress.utilities import undquote
-
-        return undquote(value)
-
-    def test_empty(self):
-        self.assertEqual(self._callFUT(""), "")
-
-    def test_quoted(self):
-        self.assertEqual(self._callFUT('"test"'), "test")
-
-    def test_unquoted(self):
-        self.assertEqual(self._callFUT("test"), "test")
-
-    def test_quoted_backslash_quote(self):
-        self.assertEqual(self._callFUT('"\\""'), '"')
-
-    def test_quoted_htab(self):
-        self.assertEqual(self._callFUT('"\t"'), "\t")
-
-    def test_quoted_backslash_htab(self):
-        self.assertEqual(self._callFUT('"\\\t"'), "\t")
-
-    def test_quoted_backslash_invalid(self):
-        self.assertRaises(ValueError, self._callFUT, '"\\"')
-
-    def test_invalid_quoting(self):
-        self.assertRaises(ValueError, self._callFUT, '"test')
-
-    def test_invalid_quoting_single_quote(self):
-        self.assertRaises(ValueError, self._callFUT, '"')
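The utilities module these deleted tests covered is still used by the vendored server. A quick sketch of the helpers exercised above, assuming the module remains importable as waitress.utilities; the values come straight from the assertions in the tests:

import time

from waitress.utilities import build_http_date, parse_http_date, undquote

assert parse_http_date("Sun, 08 Feb 1994 14:15:29 GMT") == 760716929   # RFC 822
assert parse_http_date("Tuesday, 08-Feb-94 14:15:29 GMT") == 760716929  # RFC 850
assert parse_http_date("") == 0  # per test_neither: unparseable input yields 0

# build_http_date/parse_http_date round-trip whole-second timestamps.
now = int(time.time())
assert parse_http_date(build_http_date(now)) == now

# undquote strips one level of HTTP quoted-string quoting.
assert undquote('"\\""') == '"'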
diff --git a/libs/waitress/tests/test_wasyncore.py b/libs/waitress/tests/test_wasyncore.py
deleted file mode 100644
index 9c235092f..000000000
--- a/libs/waitress/tests/test_wasyncore.py
+++ /dev/null
@@ -1,1761 +0,0 @@
-from waitress import wasyncore as asyncore
-from waitress import compat
-import contextlib
-import functools
-import gc
-import unittest
-import select
-import os
-import socket
-import sys
-import time
-import errno
-import re
-import struct
-import threading
-import warnings
-
-from io import BytesIO
-
-TIMEOUT = 3
-HAS_UNIX_SOCKETS = hasattr(socket, "AF_UNIX")
-HOST = "localhost"
-HOSTv4 = "127.0.0.1"
-HOSTv6 = "::1"
-
-# Filename used for testing
-if os.name == "java":  # pragma: no cover
-    # Jython disallows @ in module names
-    TESTFN = "$test"
-else:
-    TESTFN = "@test"
-
-TESTFN = "{}_{}_tmp".format(TESTFN, os.getpid())
-
-
-class DummyLogger(object):  # pragma: no cover
-    def __init__(self):
-        self.messages = []
-
-    def log(self, severity, message):
-        self.messages.append((severity, message))
-
-
-class WarningsRecorder(object):  # pragma: no cover
-    """Convenience wrapper for the warnings list returned on
-    entry to the warnings.catch_warnings() context manager.
-    """
-
-    def __init__(self, warnings_list):
-        self._warnings = warnings_list
-        self._last = 0
-
-    @property
-    def warnings(self):
-        return self._warnings[self._last :]
-
-    def reset(self):
-        self._last = len(self._warnings)
-
-
-def _filterwarnings(filters, quiet=False):  # pragma: no cover
-    """Catch the warnings, then check if all the expected
-    warnings have been raised and re-raise unexpected warnings.
-    If 'quiet' is True, only re-raise the unexpected warnings.
-    """
-    # Clear the warning registry of the calling module
-    # in order to re-raise the warnings.
-    frame = sys._getframe(2)
-    registry = frame.f_globals.get("__warningregistry__")
-    if registry:
-        registry.clear()
-    with warnings.catch_warnings(record=True) as w:
-        # Set filter "always" to record all warnings.  Because
-        # test_warnings swap the module, we need to look up in
-        # the sys.modules dictionary.
-        sys.modules["warnings"].simplefilter("always")
-        yield WarningsRecorder(w)
-    # Filter the recorded warnings
-    reraise = list(w)
-    missing = []
-    for msg, cat in filters:
-        seen = False
-        for w in reraise[:]:
-            warning = w.message
-            # Filter out the matching messages
-            if re.match(msg, str(warning), re.I) and issubclass(warning.__class__, cat):
-                seen = True
-                reraise.remove(w)
-        if not seen and not quiet:
-            # This filter caught nothing
-            missing.append((msg, cat.__name__))
-    if reraise:
-        raise AssertionError("unhandled warning %s" % reraise[0])
-    if missing:
-        raise AssertionError("filter (%r, %s) did not catch any warning" % missing[0])
-
-
-@contextlib.contextmanager
-def check_warnings(*filters, **kwargs):  # pragma: no cover
-    """Context manager to silence warnings.
-
-    Accept 2-tuples as positional arguments:
-        ("message regexp", WarningCategory)
-
-    Optional argument:
-     - if 'quiet' is True, it does not fail if a filter catches nothing
-        (default True without argument,
-         default False if some filters are defined)
-
-    Without argument, it defaults to:
-        check_warnings(("", Warning), quiet=True)
-    """
-    quiet = kwargs.get("quiet")
-    if not filters:
-        filters = (("", Warning),)
-        # Preserve backward compatibility
-        if quiet is None:
-            quiet = True
-    return _filterwarnings(filters, quiet)
-
-
-def gc_collect():  # pragma: no cover
-    """Force as many objects as possible to be collected.
-
-    In non-CPython implementations of Python, this is needed because timely
-    deallocation is not guaranteed by the garbage collector.  (Even in CPython
-    this can be the case in case of reference cycles.)
This means that __del__ - methods may be called later than expected and weakrefs may remain alive for - longer than expected. This function tries its best to force all garbage - objects to disappear. - """ - gc.collect() - if sys.platform.startswith("java"): - time.sleep(0.1) - gc.collect() - gc.collect() - - -def threading_setup(): # pragma: no cover - return (compat.thread._count(), None) - - -def threading_cleanup(*original_values): # pragma: no cover - global environment_altered - - _MAX_COUNT = 100 - - for count in range(_MAX_COUNT): - values = (compat.thread._count(), None) - if values == original_values: - break - - if not count: - # Display a warning at the first iteration - environment_altered = True - sys.stderr.write( - "Warning -- threading_cleanup() failed to cleanup " - "%s threads" % (values[0] - original_values[0]) - ) - sys.stderr.flush() - - values = None - - time.sleep(0.01) - gc_collect() - - -def reap_threads(func): # pragma: no cover - """Use this function when threads are being used. This will - ensure that the threads are cleaned up even when the test fails. - """ - - @functools.wraps(func) - def decorator(*args): - key = threading_setup() - try: - return func(*args) - finally: - threading_cleanup(*key) - - return decorator - - -def join_thread(thread, timeout=30.0): # pragma: no cover - """Join a thread. Raise an AssertionError if the thread is still alive - after timeout seconds. - """ - thread.join(timeout) - if thread.is_alive(): - msg = "failed to join the thread in %.1f seconds" % timeout - raise AssertionError(msg) - - -def bind_port(sock, host=HOST): # pragma: no cover - """Bind the socket to a free port and return the port number. Relies on - ephemeral ports in order to ensure we are using an unbound port. This is - important as many tests may be running simultaneously, especially in a - buildbot environment. This method raises an exception if the sock.family - is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR - or SO_REUSEPORT set on it. Tests should *never* set these socket options - for TCP/IP sockets. The only case for setting these options is testing - multicasting via multiple UDP sockets. - - Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e. - on Windows), it will be set on the socket. This will prevent anyone else - from bind()'ing to our host/port for the duration of the test. - """ - - if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM: - if hasattr(socket, "SO_REUSEADDR"): - if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1: - raise RuntimeError( - "tests should never set the SO_REUSEADDR " - "socket option on TCP/IP sockets!" - ) - if hasattr(socket, "SO_REUSEPORT"): - try: - if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1: - raise RuntimeError( - "tests should never set the SO_REUSEPORT " - "socket option on TCP/IP sockets!" - ) - except OSError: - # Python's socket module was compiled using modern headers - # thus defining SO_REUSEPORT but this process is running - # under an older kernel that does not support SO_REUSEPORT. 
- pass - if hasattr(socket, "SO_EXCLUSIVEADDRUSE"): - sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1) - - sock.bind((host, 0)) - port = sock.getsockname()[1] - return port - - -@contextlib.contextmanager -def closewrapper(sock): # pragma: no cover - try: - yield sock - finally: - sock.close() - - -class dummysocket: # pragma: no cover - def __init__(self): - self.closed = False - - def close(self): - self.closed = True - - def fileno(self): - return 42 - - def setblocking(self, yesno): - self.isblocking = yesno - - def getpeername(self): - return "peername" - - -class dummychannel: # pragma: no cover - def __init__(self): - self.socket = dummysocket() - - def close(self): - self.socket.close() - - -class exitingdummy: # pragma: no cover - def __init__(self): - pass - - def handle_read_event(self): - raise asyncore.ExitNow() - - handle_write_event = handle_read_event - handle_close = handle_read_event - handle_expt_event = handle_read_event - - -class crashingdummy: - def __init__(self): - self.error_handled = False - - def handle_read_event(self): - raise Exception() - - handle_write_event = handle_read_event - handle_close = handle_read_event - handle_expt_event = handle_read_event - - def handle_error(self): - self.error_handled = True - - -# used when testing senders; just collects what it gets until newline is sent -def capture_server(evt, buf, serv): # pragma no cover - try: - serv.listen(0) - conn, addr = serv.accept() - except socket.timeout: - pass - else: - n = 200 - start = time.time() - while n > 0 and time.time() - start < 3.0: - r, w, e = select.select([conn], [], [], 0.1) - if r: - n -= 1 - data = conn.recv(10) - # keep everything except for the newline terminator - buf.write(data.replace(b"\n", b"")) - if b"\n" in data: - break - time.sleep(0.01) - - conn.close() - finally: - serv.close() - evt.set() - - -def bind_unix_socket(sock, addr): # pragma: no cover - """Bind a unix socket, raising SkipTest if PermissionError is raised.""" - assert sock.family == socket.AF_UNIX - try: - sock.bind(addr) - except PermissionError: - sock.close() - raise unittest.SkipTest("cannot bind AF_UNIX sockets") - - -def bind_af_aware(sock, addr): - """Helper function to bind a socket according to its family.""" - if HAS_UNIX_SOCKETS and sock.family == socket.AF_UNIX: - # Make sure the path doesn't exist. - unlink(addr) - bind_unix_socket(sock, addr) - else: - sock.bind(addr) - - -if sys.platform.startswith("win"): # pragma: no cover - - def _waitfor(func, pathname, waitall=False): - # Perform the operation - func(pathname) - # Now setup the wait loop - if waitall: - dirname = pathname - else: - dirname, name = os.path.split(pathname) - dirname = dirname or "." - # Check for `pathname` to be removed from the filesystem. - # The exponential backoff of the timeout amounts to a total - # of ~1 second after which the deletion is probably an error - # anyway. - # Testing on an i7@4.3GHz shows that usually only 1 iteration is - # required when contention occurs. - timeout = 0.001 - while timeout < 1.0: - # Note we are only testing for the existence of the file(s) in - # the contents of the directory regardless of any security or - # access rights. If we have made it this far, we have sufficient - # permissions to do that much using Python's equivalent of the - # Windows API FindFirstFile. - # Other Windows APIs can fail or give incorrect results when - # dealing with files that are pending deletion. 
- L = os.listdir(dirname) - if not (L if waitall else name in L): - return - # Increase the timeout and try again - time.sleep(timeout) - timeout *= 2 - warnings.warn( - "tests may fail, delete still pending for " + pathname, - RuntimeWarning, - stacklevel=4, - ) - - def _unlink(filename): - _waitfor(os.unlink, filename) - - -else: - _unlink = os.unlink - - -def unlink(filename): - try: - _unlink(filename) - except OSError: - pass - - -def _is_ipv6_enabled(): # pragma: no cover - """Check whether IPv6 is enabled on this host.""" - if compat.HAS_IPV6: - sock = None - try: - sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) - sock.bind(("::1", 0)) - return True - except socket.error: - pass - finally: - if sock: - sock.close() - return False - - -IPV6_ENABLED = _is_ipv6_enabled() - - -class HelperFunctionTests(unittest.TestCase): - def test_readwriteexc(self): - # Check exception handling behavior of read, write and _exception - - # check that ExitNow exceptions in the object handler method - # bubbles all the way up through asyncore read/write/_exception calls - tr1 = exitingdummy() - self.assertRaises(asyncore.ExitNow, asyncore.read, tr1) - self.assertRaises(asyncore.ExitNow, asyncore.write, tr1) - self.assertRaises(asyncore.ExitNow, asyncore._exception, tr1) - - # check that an exception other than ExitNow in the object handler - # method causes the handle_error method to get called - tr2 = crashingdummy() - asyncore.read(tr2) - self.assertEqual(tr2.error_handled, True) - - tr2 = crashingdummy() - asyncore.write(tr2) - self.assertEqual(tr2.error_handled, True) - - tr2 = crashingdummy() - asyncore._exception(tr2) - self.assertEqual(tr2.error_handled, True) - - # asyncore.readwrite uses constants in the select module that - # are not present in Windows systems (see this thread: - # http://mail.python.org/pipermail/python-list/2001-October/109973.html) - # These constants should be present as long as poll is available - - @unittest.skipUnless(hasattr(select, "poll"), "select.poll required") - def test_readwrite(self): - # Check that correct methods are called by readwrite() - - attributes = ("read", "expt", "write", "closed", "error_handled") - - expected = ( - (select.POLLIN, "read"), - (select.POLLPRI, "expt"), - (select.POLLOUT, "write"), - (select.POLLERR, "closed"), - (select.POLLHUP, "closed"), - (select.POLLNVAL, "closed"), - ) - - class testobj: - def __init__(self): - self.read = False - self.write = False - self.closed = False - self.expt = False - self.error_handled = False - - def handle_read_event(self): - self.read = True - - def handle_write_event(self): - self.write = True - - def handle_close(self): - self.closed = True - - def handle_expt_event(self): - self.expt = True - - # def handle_error(self): - # self.error_handled = True - - for flag, expectedattr in expected: - tobj = testobj() - self.assertEqual(getattr(tobj, expectedattr), False) - asyncore.readwrite(tobj, flag) - - # Only the attribute modified by the routine we expect to be - # called should be True. 
-            for attr in attributes:
-                self.assertEqual(getattr(tobj, attr), attr == expectedattr)
-
-            # check that ExitNow exceptions in the object handler method
-            # bubbles all the way up through asyncore readwrite call
-            tr1 = exitingdummy()
-            self.assertRaises(asyncore.ExitNow, asyncore.readwrite, tr1, flag)
-
-            # check that an exception other than ExitNow in the object handler
-            # method causes the handle_error method to get called
-            tr2 = crashingdummy()
-            self.assertEqual(tr2.error_handled, False)
-            asyncore.readwrite(tr2, flag)
-            self.assertEqual(tr2.error_handled, True)
-
-    def test_closeall(self):
-        self.closeall_check(False)
-
-    def test_closeall_default(self):
-        self.closeall_check(True)
-
-    def closeall_check(self, usedefault):
-        # Check that close_all() closes everything in a given map
-
-        l = []
-        testmap = {}
-        for i in range(10):
-            c = dummychannel()
-            l.append(c)
-            self.assertEqual(c.socket.closed, False)
-            testmap[i] = c
-
-        if usedefault:
-            socketmap = asyncore.socket_map
-            try:
-                asyncore.socket_map = testmap
-                asyncore.close_all()
-            finally:
-                testmap, asyncore.socket_map = asyncore.socket_map, socketmap
-        else:
-            asyncore.close_all(testmap)
-
-        self.assertEqual(len(testmap), 0)
-
-        for c in l:
-            self.assertEqual(c.socket.closed, True)
-
-    def test_compact_traceback(self):
-        try:
-            raise Exception("I don't like spam!")
-        except:
-            real_t, real_v, real_tb = sys.exc_info()
-            r = asyncore.compact_traceback()
-
-        (f, function, line), t, v, info = r
-        self.assertEqual(os.path.split(f)[-1], "test_wasyncore.py")
-        self.assertEqual(function, "test_compact_traceback")
-        self.assertEqual(t, real_t)
-        self.assertEqual(v, real_v)
-        self.assertEqual(info, "[%s|%s|%s]" % (f, function, line))
-
-
-class DispatcherTests(unittest.TestCase):
-    def setUp(self):
-        pass
-
-    def tearDown(self):
-        asyncore.close_all()
-
-    def test_basic(self):
-        d = asyncore.dispatcher()
-        self.assertEqual(d.readable(), True)
-        self.assertEqual(d.writable(), True)
-
-    def test_repr(self):
-        d = asyncore.dispatcher()
-        self.assertEqual(repr(d), "<waitress.wasyncore.dispatcher at %#x>" % id(d))
-
-    def test_log_info(self):
-        import logging
-
-        inst = asyncore.dispatcher(map={})
-        logger = DummyLogger()
-        inst.logger = logger
-        inst.log_info("message", "warning")
-        self.assertEqual(logger.messages, [(logging.WARN, "message")])
-
-    def test_log(self):
-        import logging
-
-        inst = asyncore.dispatcher()
-        logger = DummyLogger()
-        inst.logger = logger
-        inst.log("message")
-        self.assertEqual(logger.messages, [(logging.DEBUG, "message")])
-
-    def test_unhandled(self):
-        import logging
-
-        inst = asyncore.dispatcher()
-        logger = DummyLogger()
-        inst.logger = logger
-
-        inst.handle_expt()
-        inst.handle_read()
-        inst.handle_write()
-        inst.handle_connect()
-
-        expected = [
-            (logging.WARN, "unhandled incoming priority event"),
-            (logging.WARN, "unhandled read event"),
-            (logging.WARN, "unhandled write event"),
-            (logging.WARN, "unhandled connect event"),
-        ]
-        self.assertEqual(logger.messages, expected)
-
-    def test_strerror(self):
-        # refers to bug #8573
-        err = asyncore._strerror(errno.EPERM)
-        if hasattr(os, "strerror"):
-            self.assertEqual(err, os.strerror(errno.EPERM))
-        err = asyncore._strerror(-1)
-        self.assertTrue(err != "")
-
-
-class dispatcherwithsend_noread(asyncore.dispatcher_with_send):  # pragma: no cover
-    def readable(self):
-        return False
-
-    def handle_connect(self):
-        pass
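DispatcherTests above works by swapping a recording logger into the dispatcher, since wasyncore routes log()/log_info() through self.logger. A minimal sketch of that hook, assuming the vendored module stays importable as waitress.wasyncore; RecordingLogger is an illustrative stand-in for the suite's DummyLogger:

import logging

from waitress import wasyncore


class RecordingLogger:
    def __init__(self):
        self.messages = []

    def log(self, severity, message):
        # Same (severity, message) shape DummyLogger records above.
        self.messages.append((severity, message))


disp = wasyncore.dispatcher(map={})  # a private map keeps the global socket_map clean
disp.logger = RecordingLogger()
disp.log_info("something happened", "warning")
assert disp.logger.messages == [(logging.WARN, "something happened")]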
-class DispatcherWithSendTests(unittest.TestCase):
-    def setUp(self):
-        pass
-
-    def tearDown(self):
-        asyncore.close_all()
-
-    @reap_threads
-    def test_send(self):
-        evt = threading.Event()
-        sock = socket.socket()
-        sock.settimeout(3)
-        port = bind_port(sock)
-
-        cap = BytesIO()
-        args = (evt, cap, sock)
-        t = threading.Thread(target=capture_server, args=args)
-        t.start()
-        try:
-            # wait a little longer for the server to initialize (it sometimes
-            # refuses connections on slow machines without this wait)
-            time.sleep(0.2)
-
-            data = b"Suppose there isn't a 16-ton weight?"
-            d = dispatcherwithsend_noread()
-            d.create_socket()
-            d.connect((HOST, port))
-
-            # give time for socket to connect
-            time.sleep(0.1)
-
-            d.send(data)
-            d.send(data)
-            d.send(b"\n")
-
-            n = 1000
-            while d.out_buffer and n > 0:  # pragma: no cover
-                asyncore.poll()
-                n -= 1
-
-            evt.wait()
-
-            self.assertEqual(cap.getvalue(), data * 2)
-        finally:
-            join_thread(t, timeout=TIMEOUT)
-
-
-@unittest.skipUnless(
-    hasattr(asyncore, "file_wrapper"), "asyncore.file_wrapper required"
-)
-class FileWrapperTest(unittest.TestCase):
-    def setUp(self):
-        self.d = b"It's not dead, it's sleeping!"
-        with open(TESTFN, "wb") as file:
-            file.write(self.d)
-
-    def tearDown(self):
-        unlink(TESTFN)
-
-    def test_recv(self):
-        fd = os.open(TESTFN, os.O_RDONLY)
-        w = asyncore.file_wrapper(fd)
-        os.close(fd)
-
-        self.assertNotEqual(w.fd, fd)
-        self.assertNotEqual(w.fileno(), fd)
-        self.assertEqual(w.recv(13), b"It's not dead")
-        self.assertEqual(w.read(6), b", it's")
-        w.close()
-        self.assertRaises(OSError, w.read, 1)
-
-    def test_send(self):
-        d1 = b"Come again?"
-        d2 = b"I want to buy some cheese."
-        fd = os.open(TESTFN, os.O_WRONLY | os.O_APPEND)
-        w = asyncore.file_wrapper(fd)
-        os.close(fd)
-
-        w.write(d1)
-        w.send(d2)
-        w.close()
-        with open(TESTFN, "rb") as file:
-            self.assertEqual(file.read(), self.d + d1 + d2)
-
-    @unittest.skipUnless(
-        hasattr(asyncore, "file_dispatcher"), "asyncore.file_dispatcher required"
-    )
-    def test_dispatcher(self):
-        fd = os.open(TESTFN, os.O_RDONLY)
-        data = []
-
-        class FileDispatcher(asyncore.file_dispatcher):
-            def handle_read(self):
-                data.append(self.recv(29))
-
-        FileDispatcher(fd)
-        os.close(fd)
-        asyncore.loop(timeout=0.01, use_poll=True, count=2)
-        self.assertEqual(b"".join(data), self.d)
-
-    def test_resource_warning(self):
-        # Issue #11453
-        got_warning = False
-        while got_warning is False:
-            # we try until we get the outcome we want because this
-            # test is not deterministic (gc_collect() may not
-            fd = os.open(TESTFN, os.O_RDONLY)
-            f = asyncore.file_wrapper(fd)
-
-            os.close(fd)
-
-            try:
-                with check_warnings(("", compat.ResourceWarning)):
-                    f = None
-                    gc_collect()
-            except AssertionError:  # pragma: no cover
-                pass
-            else:
-                got_warning = True
-
-    def test_close_twice(self):
-        fd = os.open(TESTFN, os.O_RDONLY)
-        f = asyncore.file_wrapper(fd)
-        os.close(fd)
-
-        os.close(f.fd)  # file_wrapper dupped fd
-        with self.assertRaises(OSError):
-            f.close()
-
-        self.assertEqual(f.fd, -1)
-        # calling close twice should not fail
-        f.close()
-
-
-class BaseTestHandler(asyncore.dispatcher):  # pragma: no cover
-    def __init__(self, sock=None):
-        asyncore.dispatcher.__init__(self, sock)
-        self.flag = False
-
-    def handle_accept(self):
-        raise Exception("handle_accept not supposed to be called")
-
-    def handle_accepted(self):
-        raise Exception("handle_accepted not supposed to be called")
-
-    def handle_connect(self):
-        raise Exception("handle_connect not supposed to be called")
-
-    def handle_expt(self):
-        raise Exception("handle_expt not supposed to be called")
-
-    def handle_close(self):
-        raise Exception("handle_close not supposed to be called")
-
-    def handle_error(self):
-        raise
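The BaseServer/BaseClient harness that follows boils down to the classic dispatcher pattern: a listening dispatcher spawns a handler per accepted connection, and everything is pumped by one socket map. A self-contained sketch under the assumption that waitress.wasyncore preserves asyncore's dispatcher/loop API; the Echo* classes are illustrative, not part of the suite:

import socket

from waitress import wasyncore


class EchoHandler(wasyncore.dispatcher):
    def handle_read(self):
        data = self.recv(1024)
        if data:
            self.send(data)  # echo straight back


class EchoServer(wasyncore.dispatcher):
    def __init__(self):
        wasyncore.dispatcher.__init__(self)
        self.create_socket(socket.AF_INET)
        self.set_reuse_addr()
        self.bind(("127.0.0.1", 0))  # port 0: ephemeral port, as in bind_port()
        self.listen(5)

    def handle_accepted(self, sock, addr):
        EchoHandler(sock)  # one handler dispatcher per accepted connection


class EchoClient(wasyncore.dispatcher):
    def __init__(self, address):
        wasyncore.dispatcher.__init__(self)
        self.received = b""
        self.create_socket(socket.AF_INET)
        self.connect(address)

    def handle_connect(self):
        self.send(b"hello")

    def handle_read(self):
        self.received += self.recv(1024)
        if self.received:
            self.close()


server = EchoServer()
client = EchoClient(server.socket.getsockname())
# Pump the shared socket map a bounded number of times, the same trick
# loop_waiting_for_flag() below uses.
wasyncore.loop(timeout=0.01, count=100)
wasyncore.close_all()
assert client.received == b"hello"  # timing permitting, as in the real tests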
- -class BaseServer(asyncore.dispatcher): - """A server which listens on an address and dispatches the - connection to a handler. - """ - - def __init__(self, family, addr, handler=BaseTestHandler): - asyncore.dispatcher.__init__(self) - self.create_socket(family) - self.set_reuse_addr() - bind_af_aware(self.socket, addr) - self.listen(5) - self.handler = handler - - @property - def address(self): - return self.socket.getsockname() - - def handle_accepted(self, sock, addr): - self.handler(sock) - - def handle_error(self): # pragma: no cover - raise - - -class BaseClient(BaseTestHandler): - def __init__(self, family, address): - BaseTestHandler.__init__(self) - self.create_socket(family) - self.connect(address) - - def handle_connect(self): - pass - - -class BaseTestAPI: - def tearDown(self): - asyncore.close_all(ignore_all=True) - - def loop_waiting_for_flag(self, instance, timeout=5): # pragma: no cover - timeout = float(timeout) / 100 - count = 100 - while asyncore.socket_map and count > 0: - asyncore.loop(timeout=0.01, count=1, use_poll=self.use_poll) - if instance.flag: - return - count -= 1 - time.sleep(timeout) - self.fail("flag not set") - - def test_handle_connect(self): - # make sure handle_connect is called on connect() - - class TestClient(BaseClient): - def handle_connect(self): - self.flag = True - - server = BaseServer(self.family, self.addr) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - def test_handle_accept(self): - # make sure handle_accept() is called when a client connects - - class TestListener(BaseTestHandler): - def __init__(self, family, addr): - BaseTestHandler.__init__(self) - self.create_socket(family) - bind_af_aware(self.socket, addr) - self.listen(5) - self.address = self.socket.getsockname() - - def handle_accept(self): - self.flag = True - - server = TestListener(self.family, self.addr) - client = BaseClient(self.family, server.address) - self.loop_waiting_for_flag(server) - - def test_handle_accepted(self): - # make sure handle_accepted() is called when a client connects - - class TestListener(BaseTestHandler): - def __init__(self, family, addr): - BaseTestHandler.__init__(self) - self.create_socket(family) - bind_af_aware(self.socket, addr) - self.listen(5) - self.address = self.socket.getsockname() - - def handle_accept(self): - asyncore.dispatcher.handle_accept(self) - - def handle_accepted(self, sock, addr): - sock.close() - self.flag = True - - server = TestListener(self.family, self.addr) - client = BaseClient(self.family, server.address) - self.loop_waiting_for_flag(server) - - def test_handle_read(self): - # make sure handle_read is called on data received - - class TestClient(BaseClient): - def handle_read(self): - self.flag = True - - class TestHandler(BaseTestHandler): - def __init__(self, conn): - BaseTestHandler.__init__(self, conn) - self.send(b"x" * 1024) - - server = BaseServer(self.family, self.addr, TestHandler) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - def test_handle_write(self): - # make sure handle_write is called - - class TestClient(BaseClient): - def handle_write(self): - self.flag = True - - server = BaseServer(self.family, self.addr) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - def test_handle_close(self): - # make sure handle_close is called when the other end closes - # the connection - - class TestClient(BaseClient): - def handle_read(self): - # in order to make handle_close be called we 
are supposed - # to make at least one recv() call - self.recv(1024) - - def handle_close(self): - self.flag = True - self.close() - - class TestHandler(BaseTestHandler): - def __init__(self, conn): - BaseTestHandler.__init__(self, conn) - self.close() - - server = BaseServer(self.family, self.addr, TestHandler) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - def test_handle_close_after_conn_broken(self): - # Check that ECONNRESET/EPIPE is correctly handled (issues #5661 and - # #11265). - - data = b"\0" * 128 - - class TestClient(BaseClient): - def handle_write(self): - self.send(data) - - def handle_close(self): - self.flag = True - self.close() - - def handle_expt(self): # pragma: no cover - # needs to exist for MacOS testing - self.flag = True - self.close() - - class TestHandler(BaseTestHandler): - def handle_read(self): - self.recv(len(data)) - self.close() - - def writable(self): - return False - - server = BaseServer(self.family, self.addr, TestHandler) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - @unittest.skipIf( - sys.platform.startswith("sunos"), "OOB support is broken on Solaris" - ) - def test_handle_expt(self): - # Make sure handle_expt is called on OOB data received. - # Note: this might fail on some platforms as OOB data is - # tenuously supported and rarely used. - if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: - self.skipTest("Not applicable to AF_UNIX sockets.") - - if sys.platform == "darwin" and self.use_poll: # pragma: no cover - self.skipTest("poll may fail on macOS; see issue #28087") - - class TestClient(BaseClient): - def handle_expt(self): - self.socket.recv(1024, socket.MSG_OOB) - self.flag = True - - class TestHandler(BaseTestHandler): - def __init__(self, conn): - BaseTestHandler.__init__(self, conn) - self.socket.send(compat.tobytes(chr(244)), socket.MSG_OOB) - - server = BaseServer(self.family, self.addr, TestHandler) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - def test_handle_error(self): - class TestClient(BaseClient): - def handle_write(self): - 1.0 / 0 - - def handle_error(self): - self.flag = True - try: - raise - except ZeroDivisionError: - pass - else: # pragma: no cover - raise Exception("exception not raised") - - server = BaseServer(self.family, self.addr) - client = TestClient(self.family, server.address) - self.loop_waiting_for_flag(client) - - def test_connection_attributes(self): - server = BaseServer(self.family, self.addr) - client = BaseClient(self.family, server.address) - - # we start disconnected - self.assertFalse(server.connected) - self.assertTrue(server.accepting) - # this can't be taken for granted across all platforms - # self.assertFalse(client.connected) - self.assertFalse(client.accepting) - - # execute some loops so that client connects to server - asyncore.loop(timeout=0.01, use_poll=self.use_poll, count=100) - self.assertFalse(server.connected) - self.assertTrue(server.accepting) - self.assertTrue(client.connected) - self.assertFalse(client.accepting) - - # disconnect the client - client.close() - self.assertFalse(server.connected) - self.assertTrue(server.accepting) - self.assertFalse(client.connected) - self.assertFalse(client.accepting) - - # stop serving - server.close() - self.assertFalse(server.connected) - self.assertFalse(server.accepting) - - def test_create_socket(self): - s = asyncore.dispatcher() - s.create_socket(self.family) - # self.assertEqual(s.socket.type, 
socket.SOCK_STREAM) - self.assertEqual(s.socket.family, self.family) - self.assertEqual(s.socket.gettimeout(), 0) - # self.assertFalse(s.socket.get_inheritable()) - - def test_bind(self): - if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: - self.skipTest("Not applicable to AF_UNIX sockets.") - s1 = asyncore.dispatcher() - s1.create_socket(self.family) - s1.bind(self.addr) - s1.listen(5) - port = s1.socket.getsockname()[1] - - s2 = asyncore.dispatcher() - s2.create_socket(self.family) - # EADDRINUSE indicates the socket was correctly bound - self.assertRaises(socket.error, s2.bind, (self.addr[0], port)) - - def test_set_reuse_addr(self): # pragma: no cover - if HAS_UNIX_SOCKETS and self.family == socket.AF_UNIX: - self.skipTest("Not applicable to AF_UNIX sockets.") - - with closewrapper(socket.socket(self.family)) as sock: - try: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - except OSError: - unittest.skip("SO_REUSEADDR not supported on this platform") - else: - # if SO_REUSEADDR succeeded for sock we expect asyncore - # to do the same - s = asyncore.dispatcher(socket.socket(self.family)) - self.assertFalse( - s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) - ) - s.socket.close() - s.create_socket(self.family) - s.set_reuse_addr() - self.assertTrue( - s.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) - ) - - @reap_threads - def test_quick_connect(self): # pragma: no cover - # see: http://bugs.python.org/issue10340 - if self.family not in (socket.AF_INET, getattr(socket, "AF_INET6", object())): - self.skipTest("test specific to AF_INET and AF_INET6") - - server = BaseServer(self.family, self.addr) - # run the thread 500 ms: the socket should be connected in 200 ms - t = threading.Thread(target=lambda: asyncore.loop(timeout=0.1, count=5)) - t.start() - try: - sock = socket.socket(self.family, socket.SOCK_STREAM) - with closewrapper(sock) as s: - s.settimeout(0.2) - s.setsockopt( - socket.SOL_SOCKET, socket.SO_LINGER, struct.pack("ii", 1, 0) - ) - - try: - s.connect(server.address) - except OSError: - pass - finally: - join_thread(t, timeout=TIMEOUT) - - -class TestAPI_UseIPv4Sockets(BaseTestAPI): - family = socket.AF_INET - addr = (HOST, 0) - - -@unittest.skipUnless(IPV6_ENABLED, "IPv6 support required") -class TestAPI_UseIPv6Sockets(BaseTestAPI): - family = socket.AF_INET6 - addr = (HOSTv6, 0) - - -@unittest.skipUnless(HAS_UNIX_SOCKETS, "Unix sockets required") -class TestAPI_UseUnixSockets(BaseTestAPI): - if HAS_UNIX_SOCKETS: - family = socket.AF_UNIX - addr = TESTFN - - def tearDown(self): - unlink(self.addr) - BaseTestAPI.tearDown(self) - - -class TestAPI_UseIPv4Select(TestAPI_UseIPv4Sockets, unittest.TestCase): - use_poll = False - - -@unittest.skipUnless(hasattr(select, "poll"), "select.poll required") -class TestAPI_UseIPv4Poll(TestAPI_UseIPv4Sockets, unittest.TestCase): - use_poll = True - - -class TestAPI_UseIPv6Select(TestAPI_UseIPv6Sockets, unittest.TestCase): - use_poll = False - - -@unittest.skipUnless(hasattr(select, "poll"), "select.poll required") -class TestAPI_UseIPv6Poll(TestAPI_UseIPv6Sockets, unittest.TestCase): - use_poll = True - - -class TestAPI_UseUnixSocketsSelect(TestAPI_UseUnixSockets, unittest.TestCase): - use_poll = False - - -@unittest.skipUnless(hasattr(select, "poll"), "select.poll required") -class TestAPI_UseUnixSocketsPoll(TestAPI_UseUnixSockets, unittest.TestCase): - use_poll = True - - -class Test__strerror(unittest.TestCase): - def _callFUT(self, err): - from waitress.wasyncore import _strerror - - return 
-
-    def test_gardenpath(self):
-        self.assertEqual(self._callFUT(1), "Operation not permitted")
-
-    def test_unknown(self):
-        self.assertEqual(self._callFUT("wut"), "Unknown error wut")
-
-
-class Test_read(unittest.TestCase):
-    def _callFUT(self, dispatcher):
-        from waitress.wasyncore import read
-
-        return read(dispatcher)
-
-    def test_gardenpath(self):
-        inst = DummyDispatcher()
-        self._callFUT(inst)
-        self.assertTrue(inst.read_event_handled)
-        self.assertFalse(inst.error_handled)
-
-    def test_reraised(self):
-        from waitress.wasyncore import ExitNow
-
-        inst = DummyDispatcher(ExitNow)
-        self.assertRaises(ExitNow, self._callFUT, inst)
-        self.assertTrue(inst.read_event_handled)
-        self.assertFalse(inst.error_handled)
-
-    def test_non_reraised(self):
-        inst = DummyDispatcher(OSError)
-        self._callFUT(inst)
-        self.assertTrue(inst.read_event_handled)
-        self.assertTrue(inst.error_handled)
-
-
-class Test_write(unittest.TestCase):
-    def _callFUT(self, dispatcher):
-        from waitress.wasyncore import write
-
-        return write(dispatcher)
-
-    def test_gardenpath(self):
-        inst = DummyDispatcher()
-        self._callFUT(inst)
-        self.assertTrue(inst.write_event_handled)
-        self.assertFalse(inst.error_handled)
-
-    def test_reraised(self):
-        from waitress.wasyncore import ExitNow
-
-        inst = DummyDispatcher(ExitNow)
-        self.assertRaises(ExitNow, self._callFUT, inst)
-        self.assertTrue(inst.write_event_handled)
-        self.assertFalse(inst.error_handled)
-
-    def test_non_reraised(self):
-        inst = DummyDispatcher(OSError)
-        self._callFUT(inst)
-        self.assertTrue(inst.write_event_handled)
-        self.assertTrue(inst.error_handled)
-
-
-class Test__exception(unittest.TestCase):
-    def _callFUT(self, dispatcher):
-        from waitress.wasyncore import _exception
-
-        return _exception(dispatcher)
-
-    def test_gardenpath(self):
-        inst = DummyDispatcher()
-        self._callFUT(inst)
-        self.assertTrue(inst.expt_event_handled)
-        self.assertFalse(inst.error_handled)
-
-    def test_reraised(self):
-        from waitress.wasyncore import ExitNow
-
-        inst = DummyDispatcher(ExitNow)
-        self.assertRaises(ExitNow, self._callFUT, inst)
-        self.assertTrue(inst.expt_event_handled)
-        self.assertFalse(inst.error_handled)
-
-    def test_non_reraised(self):
-        inst = DummyDispatcher(OSError)
-        self._callFUT(inst)
-        self.assertTrue(inst.expt_event_handled)
-        self.assertTrue(inst.error_handled)
-
-
-@unittest.skipUnless(hasattr(select, "poll"), "select.poll required")
-class Test_readwrite(unittest.TestCase):
-    def _callFUT(self, obj, flags):
-        from waitress.wasyncore import readwrite
-
-        return readwrite(obj, flags)
-
-    def test_handle_read_event(self):
-        flags = 0
-        flags |= select.POLLIN
-        inst = DummyDispatcher()
-        self._callFUT(inst, flags)
-        self.assertTrue(inst.read_event_handled)
-
-    def test_handle_write_event(self):
-        flags = 0
-        flags |= select.POLLOUT
-        inst = DummyDispatcher()
-        self._callFUT(inst, flags)
-        self.assertTrue(inst.write_event_handled)
-
-    def test_handle_expt_event(self):
-        flags = 0
-        flags |= select.POLLPRI
-        inst = DummyDispatcher()
-        self._callFUT(inst, flags)
-        self.assertTrue(inst.expt_event_handled)
-
-    def test_handle_close(self):
-        flags = 0
-        flags |= select.POLLHUP
-        inst = DummyDispatcher()
-        self._callFUT(inst, flags)
-        self.assertTrue(inst.close_handled)
-
-    def test_socketerror_not_in_disconnected(self):
-        flags = 0
-        flags |= select.POLLIN
-        inst = DummyDispatcher(socket.error(errno.EALREADY, "EALREADY"))
-        self._callFUT(inst, flags)
-        self.assertTrue(inst.read_event_handled)
-        self.assertTrue(inst.error_handled)
-
-    def test_socketerror_in_disconnected(self):
-        flags = 0
-        flags |= select.POLLIN
-        inst = DummyDispatcher(socket.error(errno.ECONNRESET, "ECONNRESET"))
-        self._callFUT(inst, flags)
-        self.assertTrue(inst.read_event_handled)
-        self.assertTrue(inst.close_handled)
-
-    def test_exception_in_reraised(self):
-        from waitress import wasyncore
-
-        flags = 0
-        flags |= select.POLLIN
-        inst = DummyDispatcher(wasyncore.ExitNow)
-        self.assertRaises(wasyncore.ExitNow, self._callFUT, inst, flags)
-        self.assertTrue(inst.read_event_handled)
-
-    def test_exception_not_in_reraised(self):
-        flags = 0
-        flags |= select.POLLIN
-        inst = DummyDispatcher(ValueError)
-        self._callFUT(inst, flags)
-        self.assertTrue(inst.error_handled)
-
-
-class Test_poll(unittest.TestCase):
-    def _callFUT(self, timeout=0.0, map=None):
-        from waitress.wasyncore import poll
-
-        return poll(timeout, map)
-
-    def test_nothing_writable_nothing_readable_but_map_not_empty(self):
-        # i read the mock.patch docs. nerp.
-        dummy_time = DummyTime()
-        map = {0: DummyDispatcher()}
-        try:
-            from waitress import wasyncore
-
-            old_time = wasyncore.time
-            wasyncore.time = dummy_time
-            result = self._callFUT(map=map)
-        finally:
-            wasyncore.time = old_time
-        self.assertEqual(result, None)
-        self.assertEqual(dummy_time.sleepvals, [0.0])
-
-    def test_select_raises_EINTR(self):
-        # i read the mock.patch docs. nerp.
-        dummy_select = DummySelect(select.error(errno.EINTR))
-        disp = DummyDispatcher()
-        disp.readable = lambda: True
-        map = {0: disp}
-        try:
-            from waitress import wasyncore
-
-            old_select = wasyncore.select
-            wasyncore.select = dummy_select
-            result = self._callFUT(map=map)
-        finally:
-            wasyncore.select = old_select
-        self.assertEqual(result, None)
-        self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)])
-
-    def test_select_raises_non_EINTR(self):
-        # i read the mock.patch docs. nerp.
-        dummy_select = DummySelect(select.error(errno.EBADF))
-        disp = DummyDispatcher()
-        disp.readable = lambda: True
-        map = {0: disp}
-        try:
-            from waitress import wasyncore
-
-            old_select = wasyncore.select
-            wasyncore.select = dummy_select
-            self.assertRaises(select.error, self._callFUT, map=map)
-        finally:
-            wasyncore.select = old_select
-        self.assertEqual(dummy_select.selected, [([0], [], [0], 0.0)])
-
-
-class Test_poll2(unittest.TestCase):
-    def _callFUT(self, timeout=0.0, map=None):
-        from waitress.wasyncore import poll2
-
-        return poll2(timeout, map)
-
-    def test_select_raises_EINTR(self):
-        # i read the mock.patch docs. nerp.
-        pollster = DummyPollster(exc=select.error(errno.EINTR))
-        dummy_select = DummySelect(pollster=pollster)
-        disp = DummyDispatcher()
-        map = {0: disp}
-        try:
-            from waitress import wasyncore
-
-            old_select = wasyncore.select
-            wasyncore.select = dummy_select
-            self._callFUT(map=map)
-        finally:
-            wasyncore.select = old_select
-        self.assertEqual(pollster.polled, [0.0])
-
-    def test_select_raises_non_EINTR(self):
-        # i read the mock.patch docs. nerp.
-        pollster = DummyPollster(exc=select.error(errno.EBADF))
-        dummy_select = DummySelect(pollster=pollster)
-        disp = DummyDispatcher()
-        map = {0: disp}
-        try:
-            from waitress import wasyncore
-
-            old_select = wasyncore.select
-            wasyncore.select = dummy_select
-            self.assertRaises(select.error, self._callFUT, map=map)
-        finally:
-            wasyncore.select = old_select
-        self.assertEqual(pollster.polled, [0.0])
-
-
-class Test_dispatcher(unittest.TestCase):
-    def _makeOne(self, sock=None, map=None):
-        from waitress.wasyncore import dispatcher
-
-        return dispatcher(sock=sock, map=map)
-
-    def test_unexpected_getpeername_exc(self):
-        sock = dummysocket()
-
-        def getpeername():
-            raise socket.error(errno.EBADF)
-
-        map = {}
-        sock.getpeername = getpeername
-        self.assertRaises(socket.error, self._makeOne, sock=sock, map=map)
-        self.assertEqual(map, {})
-
-    def test___repr__accepting(self):
-        sock = dummysocket()
-        map = {}
-        inst = self._makeOne(sock=sock, map=map)
-        inst.accepting = True
-        inst.addr = ("localhost", 8080)
-        result = repr(inst)
-        expected = "= 10:  # I've never seen it go above 2
-                    a.close()
-                    w.close()
-                    raise RuntimeError("Cannot bind trigger!")
-                # Close `a` and try again. Note: I originally put a short
-                # sleep() here, but it didn't appear to help or hurt.
-                a.close()
-
-            r, addr = a.accept()  # r becomes wasyncore's (self.)socket
-            a.close()
-            self.trigger = w
-            wasyncore.dispatcher.__init__(self, r, map=map)
-
-    def _close(self):
-        # self.socket is r, and self.trigger is w, from __init__
-        self.socket.close()
-        self.trigger.close()
-
-    def _physical_pull(self):
-        self.trigger.send(b"x")
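The trigger helper removed above is the classic loopback/self-pipe trick: one end of a local socket pair sits in the select() map, and any thread can wake the event loop early by writing a byte to the other end (`_physical_pull`). A minimal standalone sketch of the same idea — not Bazarr or waitress code, and using os.pipe() instead of the deleted socket-pair setup:

import os
import select
import threading
import time

# One end of the pipe is watched by select(); writing a byte to the other
# end wakes the loop immediately instead of waiting out the full timeout.
r, w = os.pipe()

def event_loop():
    readable, _, _ = select.select([r], [], [], 30.0)  # would otherwise block up to 30s
    if r in readable:
        os.read(r, 1)  # drain the wake-up byte, as the deleted handler did
        print("woken early by trigger")

t = threading.Thread(target=event_loop)
t.start()
time.sleep(0.1)
os.write(w, b"x")  # the moral equivalent of _physical_pull()
t.join()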
diff --git a/libs/waitress/utilities.py b/libs/waitress/utilities.py
deleted file mode 100644
index 556bed20a..000000000
--- a/libs/waitress/utilities.py
+++ /dev/null
@@ -1,320 +0,0 @@
-##############################################################################
-#
-# Copyright (c) 2004 Zope Foundation and Contributors.
-# All Rights Reserved.
-#
-# This software is subject to the provisions of the Zope Public License,
-# Version 2.1 (ZPL). A copy of the ZPL should accompany this distribution.
-# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
-# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
-# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
-# FOR A PARTICULAR PURPOSE.
-#
-##############################################################################
-"""Utility functions
-"""
-
-import calendar
-import errno
-import logging
-import os
-import re
-import stat
-import time
-
-from .rfc7230 import OBS_TEXT, VCHAR
-
-logger = logging.getLogger("waitress")
-queue_logger = logging.getLogger("waitress.queue")
-
-
-def find_double_newline(s):
-    """Returns the position just after a double newline in the given string."""
-    pos = s.find(b"\r\n\r\n")
-
-    if pos >= 0:
-        pos += 4
-
-    return pos
-
-
-def concat(*args):
-    return "".join(args)
-
-
-def join(seq, field=" "):
-    return field.join(seq)
-
-
-def group(s):
-    return "(" + s + ")"
-
-
-short_days = ["sun", "mon", "tue", "wed", "thu", "fri", "sat"]
-long_days = [
-    "sunday",
-    "monday",
-    "tuesday",
-    "wednesday",
-    "thursday",
-    "friday",
-    "saturday",
-]
-
-short_day_reg = group(join(short_days, "|"))
-long_day_reg = group(join(long_days, "|"))
-
-daymap = {}
-
-for i in range(7):
-    daymap[short_days[i]] = i
-    daymap[long_days[i]] = i
-
-hms_reg = join(3 * [group("[0-9][0-9]")], ":")
-
-months = [
-    "jan",
-    "feb",
-    "mar",
-    "apr",
-    "may",
-    "jun",
-    "jul",
-    "aug",
-    "sep",
-    "oct",
-    "nov",
-    "dec",
-]
-
-monmap = {}
-
-for i in range(12):
-    monmap[months[i]] = i + 1
-
-months_reg = group(join(months, "|"))
-
-# From draft-ietf-http-v11-spec-07.txt/3.3.1
-# Sun, 06 Nov 1994 08:49:37 GMT  ; RFC 822, updated by RFC 1123
-# Sunday, 06-Nov-94 08:49:37 GMT ; RFC 850, obsoleted by RFC 1036
-# Sun Nov  6 08:49:37 1994       ; ANSI C's asctime() format
-
-# rfc822 format
-rfc822_date = join(
-    [
-        concat(short_day_reg, ","),  # day
-        group("[0-9][0-9]?"),  # date
-        months_reg,  # month
-        group("[0-9]+"),  # year
-        hms_reg,  # hour minute second
-        "gmt",
-    ],
-    " ",
-)
-
-rfc822_reg = re.compile(rfc822_date)
-
-
-def unpack_rfc822(m):
-    g = m.group
-
-    return (
-        int(g(4)),  # year
-        monmap[g(3)],  # month
-        int(g(2)),  # day
-        int(g(5)),  # hour
-        int(g(6)),  # minute
-        int(g(7)),  # second
-        0,
-        0,
-        0,
-    )
-
-
-# rfc850 format
-rfc850_date = join(
-    [
-        concat(long_day_reg, ","),
-        join([group("[0-9][0-9]?"), months_reg, group("[0-9]+")], "-"),
-        hms_reg,
-        "gmt",
-    ],
-    " ",
-)
-
-rfc850_reg = re.compile(rfc850_date)
-# they actually unpack the same way
-def unpack_rfc850(m):
-    g = m.group
-    yr = g(4)
-
-    if len(yr) == 2:
-        yr = "19" + yr
-
-    return (
-        int(yr),  # year
-        monmap[g(3)],  # month
-        int(g(2)),  # day
-        int(g(5)),  # hour
-        int(g(6)),  # minute
-        int(g(7)),  # second
-        0,
-        0,
-        0,
-    )
-
-
-# parsdate.parsedate - ~700/sec.
-# parse_http_date    - ~1333/sec.
-
-weekdayname = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
-monthname = [
-    None,
-    "Jan",
-    "Feb",
-    "Mar",
-    "Apr",
-    "May",
-    "Jun",
-    "Jul",
-    "Aug",
-    "Sep",
-    "Oct",
-    "Nov",
-    "Dec",
-]
-
-
-def build_http_date(when):
-    year, month, day, hh, mm, ss, wd, y, z = time.gmtime(when)
-
-    return "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
-        weekdayname[wd],
-        day,
-        monthname[month],
-        year,
-        hh,
-        mm,
-        ss,
-    )
-
-
-def parse_http_date(d):
-    d = d.lower()
-    m = rfc850_reg.match(d)
-
-    if m and m.end() == len(d):
-        retval = int(calendar.timegm(unpack_rfc850(m)))
-    else:
-        m = rfc822_reg.match(d)
-
-        if m and m.end() == len(d):
-            retval = int(calendar.timegm(unpack_rfc822(m)))
-        else:
-            return 0
-
-    return retval
-
-
-# RFC 5234 Appendix B.1 "Core Rules":
-# VCHAR = %x21-7E
-#       ; visible (printing) characters
-vchar_re = VCHAR
-
-# RFC 7230 Section 3.2.6 "Field Value Components":
-# quoted-string = DQUOTE *( qdtext / quoted-pair ) DQUOTE
-# qdtext = HTAB / SP /%x21 / %x23-5B / %x5D-7E / obs-text
-# obs-text = %x80-FF
-# quoted-pair = "\" ( HTAB / SP / VCHAR / obs-text )
-obs_text_re = OBS_TEXT
-
-# The '\\' between \x5b and \x5d is needed to escape \x5d (']')
-qdtext_re = "[\t \x21\x23-\x5b\\\x5d-\x7e" + obs_text_re + "]"
-
-quoted_pair_re = r"\\" + "([\t " + vchar_re + obs_text_re + "])"
-quoted_string_re = '"(?:(?:' + qdtext_re + ")|(?:" + quoted_pair_re + '))*"'
-
-quoted_string = re.compile(quoted_string_re)
-quoted_pair = re.compile(quoted_pair_re)
-
-
-def undquote(value):
-    if value.startswith('"') and value.endswith('"'):
-        # So it claims to be DQUOTE'ed, let's validate that
-        matches = quoted_string.match(value)
-
-        if matches and matches.end() == len(value):
-            # Remove the DQUOTE's from the value
-            value = value[1:-1]
-
-            # Remove all backslashes that are followed by a valid vchar or
-            # obs-text
-            value = quoted_pair.sub(r"\1", value)
-
-            return value
-    elif not value.startswith('"') and not value.endswith('"'):
-        return value
-
-    raise ValueError("Invalid quoting in value")
-
-
-def cleanup_unix_socket(path):
-    try:
-        st = os.stat(path)
-    except OSError as exc:
-        if exc.errno != errno.ENOENT:
-            raise  # pragma: no cover
-    else:
-        if stat.S_ISSOCK(st.st_mode):
-            try:
-                os.remove(path)
-            except OSError:  # pragma: no cover
-                # avoid race condition error during tests
-                pass
-
-
-class Error(object):
-    code = 500
-    reason = "Internal Server Error"
-
-    def __init__(self, body):
-        self.body = body
-
-    def to_response(self):
-        status = "%s %s" % (self.code, self.reason)
-        body = "%s\r\n\r\n%s" % (self.reason, self.body)
-        tag = "\r\n\r\n(generated by waitress)"
-        body = body + tag
-        headers = [("Content-Type", "text/plain")]
-
-        return status, headers, body
-
-    def wsgi_response(self, environ, start_response):
-        status, headers, body = self.to_response()
-        start_response(status, headers)
-        yield body
-
-
-class BadRequest(Error):
-    code = 400
-    reason = "Bad Request"
-
-
-class RequestHeaderFieldsTooLarge(BadRequest):
-    code = 431
-    reason = "Request Header Fields Too Large"
-
-
-class RequestEntityTooLarge(BadRequest):
-    code = 413
-    reason = "Request Entity Too Large"
-
-
-class InternalServerError(Error):
-    code = 500
-    reason = "Internal Server Error"
-
-
-class ServerNotImplemented(Error):
-    code = 501
-    reason = "Not Implemented"
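With utilities.py deleted, its hand-rolled HTTP-date helpers (build_http_date()/parse_http_date() above) go away too. For reference only — this is an illustration, not code from this changeset — the standard library covers the same RFC 7231 date handling:

from datetime import datetime, timezone
from email.utils import format_datetime, parsedate_to_datetime
import calendar

# Roughly what the deleted build_http_date()/parse_http_date() pair did:
now = datetime.now(timezone.utc)
http_date = format_datetime(now, usegmt=True)   # e.g. 'Sun, 06 Nov 1994 08:49:37 GMT'
parsed = parsedate_to_datetime(http_date)       # back to an aware datetime
epoch = calendar.timegm(parsed.utctimetuple())  # parse_http_date() returned an epoch int
print(http_date, epoch)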
diff --git a/libs/waitress/wasyncore.py b/libs/waitress/wasyncore.py
deleted file mode 100644
index 09bcafaa0..000000000
--- a/libs/waitress/wasyncore.py
+++ /dev/null
@@ -1,693 +0,0 @@
-# -*- Mode: Python -*-
-# Id: asyncore.py,v 2.51 2000/09/07 22:29:26 rushing Exp
-# Author: Sam Rushing
-
-# ======================================================================
-# Copyright 1996 by Sam Rushing
-#
-# All Rights Reserved
-#
-# Permission to use, copy, modify, and distribute this software and
-# its documentation for any purpose and without fee is hereby
-# granted, provided that the above copyright notice appear in all
-# copies and that both that copyright notice and this permission
-# notice appear in supporting documentation, and that the name of Sam
-# Rushing not be used in advertising or publicity pertaining to
-# distribution of the software without specific, written prior
-# permission.
-#
-# SAM RUSHING DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE,
-# INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN
-# NO EVENT SHALL SAM RUSHING BE LIABLE FOR ANY SPECIAL, INDIRECT OR
-# CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS
-# OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
-# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
-# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
-# ======================================================================
-
-"""Basic infrastructure for asynchronous socket service clients and servers.
-
-There are only two ways to have a program on a single processor do "more
-than one thing at a time". Multi-threaded programming is the simplest and
-most popular way to do it, but there is another very different technique,
-that lets you have nearly all the advantages of multi-threading, without
-actually using multiple threads. it's really only practical if your program
-is largely I/O bound. If your program is CPU bound, then pre-emptive
-scheduled threads are probably what you really need. Network servers are
-rarely CPU-bound, however.
-
-If your operating system supports the select() system call in its I/O
-library (and nearly all do), then you can use it to juggle multiple
-communication channels at once; doing other work while your I/O is taking
-place in the "background." Although this strategy can seem strange and
-complex, especially at first, it is in many ways easier to understand and
-control than multi-threaded programming. The module documented here solves
-many of the difficult problems for you, making the task of building
-sophisticated high-performance network servers and clients a snap.
-
-NB: this is a fork of asyncore from the stdlib that we've (the waitress
-developers) named 'wasyncore' to ensure forward compatibility, as asyncore
-in the stdlib will be dropped soon. It is neither a copy of the 2.7 asyncore
-nor the 3.X asyncore; it is a version compatible with either 2.7 or 3.X.
-"""
-
-from . import compat
-from . import utilities
-
-import logging
-import select
-import socket
-import sys
-import time
-import warnings
-
-import os
-from errno import (
-    EALREADY,
-    EINPROGRESS,
-    EWOULDBLOCK,
-    ECONNRESET,
-    EINVAL,
-    ENOTCONN,
-    ESHUTDOWN,
-    EISCONN,
-    EBADF,
-    ECONNABORTED,
-    EPIPE,
-    EAGAIN,
-    EINTR,
-    errorcode,
-)
-
-_DISCONNECTED = frozenset({ECONNRESET, ENOTCONN, ESHUTDOWN, ECONNABORTED, EPIPE, EBADF})
-
-try:
-    socket_map
-except NameError:
-    socket_map = {}
-
-
-def _strerror(err):
-    try:
-        return os.strerror(err)
-    except (TypeError, ValueError, OverflowError, NameError):
-        return "Unknown error %s" % err
-
-
-class ExitNow(Exception):
-    pass
-
-
-_reraised_exceptions = (ExitNow, KeyboardInterrupt, SystemExit)
-
-
-def read(obj):
-    try:
-        obj.handle_read_event()
-    except _reraised_exceptions:
-        raise
-    except:
-        obj.handle_error()
-
-
-def write(obj):
-    try:
-        obj.handle_write_event()
-    except _reraised_exceptions:
-        raise
-    except:
-        obj.handle_error()
-
-
-def _exception(obj):
-    try:
-        obj.handle_expt_event()
-    except _reraised_exceptions:
-        raise
-    except:
-        obj.handle_error()
-
-
-def readwrite(obj, flags):
-    try:
-        if flags & select.POLLIN:
-            obj.handle_read_event()
-        if flags & select.POLLOUT:
-            obj.handle_write_event()
-        if flags & select.POLLPRI:
-            obj.handle_expt_event()
-        if flags & (select.POLLHUP | select.POLLERR | select.POLLNVAL):
-            obj.handle_close()
-    except socket.error as e:
-        if e.args[0] not in _DISCONNECTED:
-            obj.handle_error()
-        else:
-            obj.handle_close()
-    except _reraised_exceptions:
-        raise
-    except:
-        obj.handle_error()
-
-
-def poll(timeout=0.0, map=None):
-    if map is None:  # pragma: no cover
-        map = socket_map
-    if map:
-        r = []
-        w = []
-        e = []
-        for fd, obj in list(map.items()):  # list() call FBO py3
-            is_r = obj.readable()
-            is_w = obj.writable()
-            if is_r:
-                r.append(fd)
-            # accepting sockets should not be writable
-            if is_w and not obj.accepting:
-                w.append(fd)
-            if is_r or is_w:
-                e.append(fd)
-        if [] == r == w == e:
-            time.sleep(timeout)
-            return
-
-        try:
-            r, w, e = select.select(r, w, e, timeout)
-        except select.error as err:
-            if err.args[0] != EINTR:
-                raise
-            else:
-                return
-
-        for fd in r:
-            obj = map.get(fd)
-            if obj is None:  # pragma: no cover
-                continue
-            read(obj)
-
-        for fd in w:
-            obj = map.get(fd)
-            if obj is None:  # pragma: no cover
-                continue
-            write(obj)
-
-        for fd in e:
-            obj = map.get(fd)
-            if obj is None:  # pragma: no cover
-                continue
-            _exception(obj)
-
-
-def poll2(timeout=0.0, map=None):
-    # Use the poll() support added to the select module in Python 2.0
-    if map is None:  # pragma: no cover
-        map = socket_map
-    if timeout is not None:
-        # timeout is in milliseconds
-        timeout = int(timeout * 1000)
-    pollster = select.poll()
-    if map:
-        for fd, obj in list(map.items()):
-            flags = 0
-            if obj.readable():
-                flags |= select.POLLIN | select.POLLPRI
-            # accepting sockets should not be writable
-            if obj.writable() and not obj.accepting:
-                flags |= select.POLLOUT
-            if flags:
-                pollster.register(fd, flags)
-
-        try:
-            r = pollster.poll(timeout)
-        except select.error as err:
-            if err.args[0] != EINTR:
-                raise
-            r = []
-
-        for fd, flags in r:
-            obj = map.get(fd)
-            if obj is None:  # pragma: no cover
-                continue
-            readwrite(obj, flags)
-
-
-poll3 = poll2  # Alias for backward compatibility
-
-
-def loop(timeout=30.0, use_poll=False, map=None, count=None):
-    if map is None:  # pragma: no cover
-        map = socket_map
-
-    if use_poll and hasattr(select, "poll"):
-        poll_fun = poll2
-    else:
-        poll_fun = poll
-
-    if count is None:  # pragma: no cover
-        while map:
-            poll_fun(timeout, map)
-
-    else:
-        while map and count > 0:
-            poll_fun(timeout, map)
-            count = count - 1
-
-
-def compact_traceback():
-    t, v, tb = sys.exc_info()
-    tbinfo = []
-    if not tb:  # pragma: no cover
-        raise AssertionError("traceback does not exist")
-    while tb:
-        tbinfo.append(
-            (
-                tb.tb_frame.f_code.co_filename,
-                tb.tb_frame.f_code.co_name,
-                str(tb.tb_lineno),
-            )
-        )
-        tb = tb.tb_next
-
-    # just to be safe
-    del tb
-
-    file, function, line = tbinfo[-1]
-    info = " ".join(["[%s|%s|%s]" % x for x in tbinfo])
-    return (file, function, line), t, v, info
-
-
-class dispatcher:
-
-    debug = False
-    connected = False
-    accepting = False
-    connecting = False
-    closing = False
-    addr = None
-    ignore_log_types = frozenset({"warning"})
-    logger = utilities.logger
-    compact_traceback = staticmethod(compact_traceback)  # for testing
-
-    def __init__(self, sock=None, map=None):
-        if map is None:  # pragma: no cover
-            self._map = socket_map
-        else:
-            self._map = map
-
-        self._fileno = None
-
-        if sock:
-            # Set to nonblocking just to make sure for cases where we
-            # get a socket from a blocking source.
-            sock.setblocking(0)
-            self.set_socket(sock, map)
-            self.connected = True
-            # The constructor no longer requires that the socket
-            # passed be connected.
-            try:
-                self.addr = sock.getpeername()
-            except socket.error as err:
-                if err.args[0] in (ENOTCONN, EINVAL):
-                    # To handle the case where we got an unconnected
-                    # socket.
-                    self.connected = False
-                else:
-                    # The socket is broken in some unknown way, alert
-                    # the user and remove it from the map (to prevent
-                    # polling of broken sockets).
-                    self.del_channel(map)
-                    raise
-        else:
-            self.socket = None
-
-    def __repr__(self):
-        status = [self.__class__.__module__ + "." + compat.qualname(self.__class__)]
-        if self.accepting and self.addr:
-            status.append("listening")
-        elif self.connected:
-            status.append("connected")
-        if self.addr is not None:
-            try:
-                status.append("%s:%d" % self.addr)
-            except TypeError:  # pragma: no cover
-                status.append(repr(self.addr))
-        return "<%s at %#x>" % (" ".join(status), id(self))
-
-    __str__ = __repr__
-
-    def add_channel(self, map=None):
-        # self.log_info('adding channel %s' % self)
-        if map is None:
-            map = self._map
-        map[self._fileno] = self
-
-    def del_channel(self, map=None):
-        fd = self._fileno
-        if map is None:
-            map = self._map
-        if fd in map:
-            # self.log_info('closing channel %d:%s' % (fd, self))
-            del map[fd]
-        self._fileno = None
-
-    def create_socket(self, family=socket.AF_INET, type=socket.SOCK_STREAM):
-        self.family_and_type = family, type
-        sock = socket.socket(family, type)
-        sock.setblocking(0)
-        self.set_socket(sock)
-
-    def set_socket(self, sock, map=None):
-        self.socket = sock
-        self._fileno = sock.fileno()
-        self.add_channel(map)
-
-    def set_reuse_addr(self):
-        # try to re-use a server port if possible
-        try:
-            self.socket.setsockopt(
-                socket.SOL_SOCKET,
-                socket.SO_REUSEADDR,
-                self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) | 1,
-            )
-        except socket.error:
-            pass
-
-    # ==================================================
-    # predicates for select()
-    # these are used as filters for the lists of sockets
-    # to pass to select().
-    # ==================================================
-
-    def readable(self):
-        return True
-
-    def writable(self):
-        return True
-
-    # ==================================================
-    # socket object methods.
-    # ==================================================
-
-    def listen(self, num):
-        self.accepting = True
-        if os.name == "nt" and num > 5:  # pragma: no cover
-            num = 5
-        return self.socket.listen(num)
-
-    def bind(self, addr):
-        self.addr = addr
-        return self.socket.bind(addr)
-
-    def connect(self, address):
-        self.connected = False
-        self.connecting = True
-        err = self.socket.connect_ex(address)
-        if (
-            err in (EINPROGRESS, EALREADY, EWOULDBLOCK)
-            or err == EINVAL
-            and os.name == "nt"
-        ):  # pragma: no cover
-            self.addr = address
-            return
-        if err in (0, EISCONN):
-            self.addr = address
-            self.handle_connect_event()
-        else:
-            raise socket.error(err, errorcode[err])
-
-    def accept(self):
-        # XXX can return either an address pair or None
-        try:
-            conn, addr = self.socket.accept()
-        except TypeError:
-            return None
-        except socket.error as why:
-            if why.args[0] in (EWOULDBLOCK, ECONNABORTED, EAGAIN):
-                return None
-            else:
-                raise
-        else:
-            return conn, addr
-
-    def send(self, data):
-        try:
-            result = self.socket.send(data)
-            return result
-        except socket.error as why:
-            if why.args[0] == EWOULDBLOCK:
-                return 0
-            elif why.args[0] in _DISCONNECTED:
-                self.handle_close()
-                return 0
-            else:
-                raise
-
-    def recv(self, buffer_size):
-        try:
-            data = self.socket.recv(buffer_size)
-            if not data:
-                # a closed connection is indicated by signaling
-                # a read condition, and having recv() return 0.
-                self.handle_close()
-                return b""
-            else:
-                return data
-        except socket.error as why:
-            # winsock sometimes raises ENOTCONN
-            if why.args[0] in _DISCONNECTED:
-                self.handle_close()
-                return b""
-            else:
-                raise
-
-    def close(self):
-        self.connected = False
-        self.accepting = False
-        self.connecting = False
-        self.del_channel()
-        if self.socket is not None:
-            try:
-                self.socket.close()
-            except socket.error as why:
-                if why.args[0] not in (ENOTCONN, EBADF):
-                    raise
-
-    # log and log_info may be overridden to provide more sophisticated
-    # logging and warning methods. In general, log is for 'hit' logging
-    # and 'log_info' is for informational, warning and error logging.
-
-    def log(self, message):
-        self.logger.log(logging.DEBUG, message)
-
-    def log_info(self, message, type="info"):
-        severity = {
-            "info": logging.INFO,
-            "warning": logging.WARN,
-            "error": logging.ERROR,
-        }
-        self.logger.log(severity.get(type, logging.INFO), message)
-
-    def handle_read_event(self):
-        if self.accepting:
-            # accepting sockets are never connected, they "spawn" new
-            # sockets that are connected
-            self.handle_accept()
-        elif not self.connected:
-            if self.connecting:
-                self.handle_connect_event()
-            self.handle_read()
-        else:
-            self.handle_read()
-
-    def handle_connect_event(self):
-        err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
-        if err != 0:
-            raise socket.error(err, _strerror(err))
-        self.handle_connect()
-        self.connected = True
-        self.connecting = False
-
-    def handle_write_event(self):
-        if self.accepting:
-            # Accepting sockets shouldn't get a write event.
-            # We will pretend it didn't happen.
-            return
-
-        if not self.connected:
-            if self.connecting:
-                self.handle_connect_event()
-        self.handle_write()
-
-    def handle_expt_event(self):
-        # handle_expt_event() is called if there might be an error on the
-        # socket, or if there is OOB data
-        # check for the error condition first
-        err = self.socket.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
-        if err != 0:
-            # we can get here when select.select() says that there is an
-            # exceptional condition on the socket
-            # since there is an error, we'll go ahead and close the socket
-            # like we would in a subclassed handle_read() that received no
-            # data
-            self.handle_close()
-        else:
-            self.handle_expt()
-
-    def handle_error(self):
-        nil, t, v, tbinfo = self.compact_traceback()
-
-        # sometimes a user repr method will crash.
-        try:
-            self_repr = repr(self)
-        except:  # pragma: no cover
-            self_repr = "<__repr__(self) failed for object at %0x>" % id(self)
-
-        self.log_info(
-            "uncaptured python exception, closing channel %s (%s:%s %s)"
-            % (self_repr, t, v, tbinfo),
-            "error",
-        )
-        self.handle_close()
-
-    def handle_expt(self):
-        self.log_info("unhandled incoming priority event", "warning")
-
-    def handle_read(self):
-        self.log_info("unhandled read event", "warning")
-
-    def handle_write(self):
-        self.log_info("unhandled write event", "warning")
-
-    def handle_connect(self):
-        self.log_info("unhandled connect event", "warning")
-
-    def handle_accept(self):
-        pair = self.accept()
-        if pair is not None:
-            self.handle_accepted(*pair)
-
-    def handle_accepted(self, sock, addr):
-        sock.close()
-        self.log_info("unhandled accepted event", "warning")
-
-    def handle_close(self):
-        self.log_info("unhandled close event", "warning")
-        self.close()
-
-
-# ---------------------------------------------------------------------------
-# adds simple buffered output capability, useful for simple clients.
-# [for more sophisticated usage use asynchat.async_chat]
-# ---------------------------------------------------------------------------
-
-
-class dispatcher_with_send(dispatcher):
-    def __init__(self, sock=None, map=None):
-        dispatcher.__init__(self, sock, map)
-        self.out_buffer = b""
-
-    def initiate_send(self):
-        num_sent = 0
-        num_sent = dispatcher.send(self, self.out_buffer[:65536])
-        self.out_buffer = self.out_buffer[num_sent:]
-
-    handle_write = initiate_send
-
-    def writable(self):
-        return (not self.connected) or len(self.out_buffer)
-
-    def send(self, data):
-        if self.debug:  # pragma: no cover
-            self.log_info("sending %s" % repr(data))
-        self.out_buffer = self.out_buffer + data
-        self.initiate_send()
-
-
-def close_all(map=None, ignore_all=False):
-    if map is None:  # pragma: no cover
-        map = socket_map
-    for x in list(map.values()):  # list() FBO py3
-        try:
-            x.close()
-        except socket.error as x:
-            if x.args[0] == EBADF:
-                pass
-            elif not ignore_all:
-                raise
-        except _reraised_exceptions:
-            raise
-        except:
-            if not ignore_all:
-                raise
-    map.clear()
-
-
-# Asynchronous File I/O:
-#
-# After a little research (reading man pages on various unixen, and
-# digging through the linux kernel), I've determined that select()
-# isn't meant for doing asynchronous file i/o.
-# Heartening, though - reading linux/mm/filemap.c shows that linux
-# supports asynchronous read-ahead. So _MOST_ of the time, the data
-# will be sitting in memory for us already when we go to read it.
-#
-# What other OS's (besides NT) support async file i/o? [VMS?]
-#
-# Regardless, this is useful for pipes, and stdin/stdout...
-
-if os.name == "posix":
-
-    class file_wrapper:
-        # Here we override just enough to make a file
-        # look like a socket for the purposes of asyncore.
-        # The passed fd is automatically os.dup()'d
-
-        def __init__(self, fd):
-            self.fd = os.dup(fd)
-
-        def __del__(self):
-            if self.fd >= 0:
-                warnings.warn("unclosed file %r" % self, compat.ResourceWarning)
-            self.close()
-
-        def recv(self, *args):
-            return os.read(self.fd, *args)
-
-        def send(self, *args):
-            return os.write(self.fd, *args)
-
-        def getsockopt(self, level, optname, buflen=None):  # pragma: no cover
-            if level == socket.SOL_SOCKET and optname == socket.SO_ERROR and not buflen:
-                return 0
-            raise NotImplementedError(
-                "Only asyncore specific behaviour " "implemented."
-            )
-
-        read = recv
-        write = send
-
-        def close(self):
-            if self.fd < 0:
-                return
-            fd = self.fd
-            self.fd = -1
-            os.close(fd)
-
-        def fileno(self):
-            return self.fd
-
-    class file_dispatcher(dispatcher):
-        def __init__(self, fd, map=None):
-            dispatcher.__init__(self, None, map)
-            self.connected = True
-            try:
-                fd = fd.fileno()
-            except AttributeError:
-                pass
-            self.set_file(fd)
-            # set it to non-blocking mode
-            compat.set_nonblocking(fd)
-
-        def set_file(self, fd):
-            self.socket = file_wrapper(fd)
-            self._fileno = self.socket.fileno()
-            self.add_channel()
diff --git a/requirements.txt b/requirements.txt
index 983acb62e..93c85925f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,5 @@
+gevent>=21
+gevent-websocket>=0.10.1
 lxml>=4.3.0
 numpy>=1.12.0
 webrtcvad-wheels>=2.0.10
\ No newline at end of file
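Taken together with the waitress deletions above, the new gevent and gevent-websocket pins indicate the direction of the change: the WSGI app is now served through gevent rather than the vendored waitress copy. As a rough sketch only — the app object, host, and port below are illustrative placeholders, not taken from this diff:

from gevent import monkey
monkey.patch_all()  # gevent servers generally want the stdlib patched first

from gevent.pywsgi import WSGIServer
from geventwebsocket.handler import WebSocketHandler
from flask import Flask

app = Flask(__name__)  # stand-in for the real WSGI application

@app.route("/")
def index():
    return "ok"

# WebSocketHandler upgrades websocket requests; ordinary HTTP passes through.
server = WSGIServer(("0.0.0.0", 6767), app, handler_class=WebSocketHandler)
server.serve_forever()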