Added Swagger documentation for Bazarr API

This commit is contained in:
morpheus65535 2022-09-21 23:51:34 -04:00 committed by GitHub
parent c3f43f0e42
commit 131b4e5cde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
161 changed files with 26157 additions and 2098 deletions

View File

@ -1,25 +1,51 @@
# coding=utf-8
from .badges import api_bp_badges
from .system import api_bp_system
from .series import api_bp_series
from .episodes import api_bp_episodes
from .providers import api_bp_providers
from .subtitles import api_bp_subtitles
from .webhooks import api_bp_webhooks
from .history import api_bp_history
from .files import api_bp_files
from .movies import api_bp_movies
from flask import Blueprint, url_for
from flask_restx import Api, apidoc
api_bp_list = [
api_bp_badges,
api_bp_system,
api_bp_series,
api_bp_episodes,
api_bp_providers,
api_bp_subtitles,
api_bp_webhooks,
api_bp_history,
api_bp_files,
api_bp_movies
from .badges import api_ns_list_badges
from .episodes import api_ns_list_episodes
from .files import api_ns_list_files
from .history import api_ns_list_history
from .movies import api_ns_list_movies
from .providers import api_ns_list_providers
from .series import api_ns_list_series
from .subtitles import api_ns_list_subtitles
from .system import api_ns_list_system
from .webhooks import api_ns_list_webhooks
from .swaggerui import swaggerui_api_params
api_ns_list = [
api_ns_list_badges,
api_ns_list_episodes,
api_ns_list_files,
api_ns_list_history,
api_ns_list_movies,
api_ns_list_providers,
api_ns_list_series,
api_ns_list_subtitles,
api_ns_list_system,
api_ns_list_webhooks,
]
authorizations = {
'apikey': {
'type': 'apiKey',
'in': 'header',
'name': 'X-API-KEY'
}
}
api_bp = Blueprint('api', __name__, url_prefix='/api')
@apidoc.apidoc.add_app_template_global
def swagger_static(filename):
return url_for('ui.swaggerui_static', filename=filename)
api = Api(api_bp, authorizations=authorizations, security='apikey', validate=True, **swaggerui_api_params)
for api_ns in api_ns_list:
for item in api_ns:
api.add_namespace(item, "/")

View File

@ -1,12 +1,7 @@
# coding=utf-8
from flask import Blueprint
from flask_restful import Api
from .badges import api_ns_badges
from .badges import Badges
api_bp_badges = Blueprint('api_badges', __name__)
api = Api(api_bp_badges)
api.add_resource(Badges, '/badges')
api_ns_list_badges = [
api_ns_badges
]

View File

@ -3,8 +3,7 @@
import operator
from functools import reduce
from flask import jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, fields
from app.database import get_exclusion_clause, TableEpisodes, TableShows, TableMovies
from app.get_providers import get_throttled_providers
@ -12,10 +11,25 @@ from utilities.health import get_health_issues
from ..utils import authenticate
api_ns_badges = Namespace('Badges', description='Get badges count to update the UI (episodes and movies wanted '
'subtitles, providers with issues and health issues.')
@api_ns_badges.route('badges')
class Badges(Resource):
get_model = api_ns_badges.model('BadgesGet', {
'episodes': fields.Integer(),
'movies': fields.Integer(),
'providers': fields.Integer(),
'status': fields.Integer(),
})
@authenticate
@api_ns_badges.marshal_with(get_model, code=200)
@api_ns_badges.response(401, 'Not Authenticated')
@api_ns_badges.doc(parser=None)
def get(self):
"""Get badges count to update the UI"""
episodes_conditions = [(TableEpisodes.missing_subtitles.is_null(False)),
(TableEpisodes.missing_subtitles != '[]')]
episodes_conditions += get_exclusion_clause('series')
@ -44,4 +58,4 @@ class Badges(Resource):
"providers": throttled_providers,
"status": health_issues
}
return jsonify(result)
return result

View File

@ -1,20 +1,16 @@
# coding=utf-8
from flask import Blueprint
from flask_restful import Api
from .episodes import Episodes
from .episodes_subtitles import EpisodesSubtitles
from .history import EpisodesHistory
from .wanted import EpisodesWanted
from .blacklist import EpisodesBlacklist
from .episodes import api_ns_episodes
from .episodes_subtitles import api_ns_episodes_subtitles
from .history import api_ns_episodes_history
from .wanted import api_ns_episodes_wanted
from .blacklist import api_ns_episodes_blacklist
api_bp_episodes = Blueprint('api_episodes', __name__)
api = Api(api_bp_episodes)
api.add_resource(Episodes, '/episodes')
api.add_resource(EpisodesWanted, '/episodes/wanted')
api.add_resource(EpisodesSubtitles, '/episodes/subtitles')
api.add_resource(EpisodesHistory, '/episodes/history')
api.add_resource(EpisodesBlacklist, '/episodes/blacklist')
api_ns_list_episodes = [
api_ns_episodes,
api_ns_episodes_blacklist,
api_ns_episodes_history,
api_ns_episodes_subtitles,
api_ns_episodes_wanted,
]

View File

@ -3,8 +3,7 @@
import datetime
import pretty
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from app.database import TableEpisodes, TableShows, TableBlacklist
from subtitles.tools.delete import delete_subtitles
@ -12,18 +11,43 @@ from sonarr.blacklist import blacklist_log, blacklist_delete_all, blacklist_dele
from utilities.path_mappings import path_mappings
from subtitles.mass_download import episode_download_subtitles
from app.event_handler import event_stream
from api.swaggerui import subtitles_language_model
from ..utils import authenticate, postprocessEpisode
api_ns_episodes_blacklist = Namespace('Episodes Blacklist', description='List, add or remove subtitles to or from '
'episodes blacklist')
# GET: get blacklist
# POST: add blacklist
# DELETE: remove blacklist
@api_ns_episodes_blacklist.route('episodes/blacklist')
class EpisodesBlacklist(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('start', type=int, required=False, default=0, help='Paging start integer')
get_request_parser.add_argument('length', type=int, required=False, default=-1, help='Paging length integer')
get_language_model = api_ns_episodes_blacklist.model('subtitles_language_model', subtitles_language_model)
get_response_model = api_ns_episodes_blacklist.model('EpisodeBlacklistGetResponse', {
'seriesTitle': fields.String(),
'episode_number': fields.String(),
'episodeTitle': fields.String(),
'sonarrSeriesId': fields.Integer(),
'provider': fields.String(),
'subs_id': fields.String(),
'language': fields.Nested(get_language_model),
'timestamp': fields.String(),
'parsed_timestamp': fields.String(),
})
@authenticate
@api_ns_episodes_blacklist.marshal_with(get_response_model, envelope='data', code=200)
@api_ns_episodes_blacklist.response(401, 'Not Authenticated')
@api_ns_episodes_blacklist.doc(parser=get_request_parser)
def get(self):
start = request.args.get('start') or 0
length = request.args.get('length') or -1
"""List blacklisted episodes subtitles"""
args = self.get_request_parser.parse_args()
start = args.get('start')
length = args.get('length')
data = TableBlacklist.select(TableShows.title.alias('seriesTitle'),
TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'),
@ -48,15 +72,29 @@ class EpisodesBlacklist(Resource):
postprocessEpisode(item)
return jsonify(data=data)
return data
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('seriesid', type=int, required=True, help='Series ID')
post_request_parser.add_argument('episodeid', type=int, required=True, help='Episode ID')
post_request_parser.add_argument('provider', type=str, required=True, help='Provider name')
post_request_parser.add_argument('subs_id', type=str, required=True, help='Subtitles ID')
post_request_parser.add_argument('language', type=str, required=True, help='Subtitles language')
post_request_parser.add_argument('subtitles_path', type=str, required=True, help='Subtitles file path')
@authenticate
@api_ns_episodes_blacklist.doc(parser=post_request_parser)
@api_ns_episodes_blacklist.response(200, 'Success')
@api_ns_episodes_blacklist.response(401, 'Not Authenticated')
@api_ns_episodes_blacklist.response(404, 'Episode not found')
def post(self):
sonarr_series_id = int(request.args.get('seriesid'))
sonarr_episode_id = int(request.args.get('episodeid'))
provider = request.form.get('provider')
subs_id = request.form.get('subs_id')
language = request.form.get('language')
"""Add an episodes subtitles to blacklist"""
args = self.post_request_parser.parse_args()
sonarr_series_id = args.get('seriesid')
sonarr_episode_id = args.get('episodeid')
provider = args.get('provider')
subs_id = args.get('subs_id')
language = args.get('language')
episodeInfo = TableEpisodes.select(TableEpisodes.path)\
.where(TableEpisodes.sonarrEpisodeId == sonarr_episode_id)\
@ -67,7 +105,7 @@ class EpisodesBlacklist(Resource):
return 'Episode not found', 404
media_path = episodeInfo['path']
subtitles_path = request.form.get('subtitles_path')
subtitles_path = args.get('subtitles_path')
blacklist_log(sonarr_series_id=sonarr_series_id,
sonarr_episode_id=sonarr_episode_id,
@ -86,12 +124,22 @@ class EpisodesBlacklist(Resource):
event_stream(type='episode-history')
return '', 200
delete_request_parser = reqparse.RequestParser()
delete_request_parser.add_argument('all', type=str, required=False, help='Empty episodes subtitles blacklist')
delete_request_parser.add_argument('provider', type=str, required=True, help='Provider name')
delete_request_parser.add_argument('subs_id', type=str, required=True, help='Subtitles ID')
@authenticate
@api_ns_episodes_blacklist.doc(parser=delete_request_parser)
@api_ns_episodes_blacklist.response(204, 'Success')
@api_ns_episodes_blacklist.response(401, 'Not Authenticated')
def delete(self):
if request.args.get("all") == "true":
"""Delete an episodes subtitles from blacklist"""
args = self.post_request_parser.parse_args()
if args.get("all") == "true":
blacklist_delete_all()
else:
provider = request.form.get('provider')
subs_id = request.form.get('subs_id')
provider = args.get('provider')
subs_id = args.get('subs_id')
blacklist_delete(provider=provider, subs_id=subs_id)
return '', 204

View File

@ -1,18 +1,60 @@
# coding=utf-8
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from app.database import TableEpisodes
from api.swaggerui import subtitles_model, subtitles_language_model, audio_language_model
from ..utils import authenticate, postprocessEpisode
api_ns_episodes = Namespace('Episodes', description='List episodes metadata for specific series or episodes.')
@api_ns_episodes.route('episodes')
class Episodes(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('seriesid[]', type=int, action='append', required=False, default=[],
help='Series IDs to list episodes for')
get_request_parser.add_argument('episodeid[]', type=int, action='append', required=False, default=[],
help='Episodes ID to list')
get_subtitles_model = api_ns_episodes.model('subtitles_model', subtitles_model)
get_subtitles_language_model = api_ns_episodes.model('subtitles_language_model', subtitles_language_model)
get_audio_language_model = api_ns_episodes.model('audio_language_model', audio_language_model)
get_response_model = api_ns_episodes.model('EpisodeGetResponse', {
'rowid': fields.Integer(),
'audio_codec': fields.String(),
'audio_language': fields.Nested(get_audio_language_model),
'episode': fields.Integer(),
'episode_file_id': fields.Integer(),
'failedAttempts': fields.String(),
'file_size': fields.Integer(),
'format': fields.String(),
'missing_subtitles': fields.Nested(get_subtitles_language_model),
'monitored': fields.Boolean(),
'path': fields.String(),
'resolution': fields.String(),
'season': fields.Integer(),
'sonarrEpisodeId': fields.Integer(),
'sonarrSeriesId': fields.Integer(),
'subtitles': fields.Nested(get_subtitles_model),
'title': fields.String(),
'video_codec': fields.String(),
'sceneName': fields.String(),
})
@authenticate
@api_ns_episodes.marshal_with(get_response_model, envelope='data', code=200)
@api_ns_episodes.doc(parser=get_request_parser)
@api_ns_episodes.response(200, 'Success')
@api_ns_episodes.response(401, 'Not Authenticated')
@api_ns_episodes.response(404, 'Series or Episode ID not provided')
def get(self):
seriesId = request.args.getlist('seriesid[]')
episodeId = request.args.getlist('episodeid[]')
"""List episodes metadata for specific series or episodes"""
args = self.get_request_parser.parse_args()
seriesId = args.get('seriesid[]')
episodeId = args.get('episodeid[]')
if len(episodeId) > 0:
result = TableEpisodes.select().where(TableEpisodes.sonarrEpisodeId.in_(episodeId)).dicts()
@ -28,4 +70,4 @@ class Episodes(Resource):
for item in result:
postprocessEpisode(item)
return jsonify(data=result)
return result

View File

@ -3,9 +3,9 @@
import os
import logging
from flask import request
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse
from subliminal_patch.core import SUBTITLE_EXTENSIONS
from werkzeug.datastructures import FileStorage
from app.database import TableShows, TableEpisodes, get_audio_profile_languages, get_profile_id
from utilities.path_mappings import path_mappings
@ -20,15 +20,28 @@ from app.config import settings
from ..utils import authenticate
api_ns_episodes_subtitles = Namespace('Episodes Subtitles', description='Download, upload or delete episodes subtitles')
# PATCH: Download Subtitles
# POST: Upload Subtitles
# DELETE: Delete Subtitles
@api_ns_episodes_subtitles.route('episodes/subtitles')
class EpisodesSubtitles(Resource):
patch_request_parser = reqparse.RequestParser()
patch_request_parser.add_argument('seriesid', type=int, required=True, help='Series ID')
patch_request_parser.add_argument('episodeid', type=int, required=True, help='Episode ID')
patch_request_parser.add_argument('language', type=str, required=True, help='Language code2')
patch_request_parser.add_argument('forced', type=str, required=True, help='Forced true/false as string')
patch_request_parser.add_argument('hi', type=str, required=True, help='HI true/false as string')
@authenticate
@api_ns_episodes_subtitles.doc(parser=patch_request_parser)
@api_ns_episodes_subtitles.response(204, 'Success')
@api_ns_episodes_subtitles.response(401, 'Not Authenticated')
@api_ns_episodes_subtitles.response(404, 'Episode not found')
def patch(self):
sonarrSeriesId = request.args.get('seriesid')
sonarrEpisodeId = request.args.get('episodeid')
"""Download an episode subtitles"""
args = self.patch_request_parser.parse_args()
sonarrSeriesId = args.get('seriesid')
sonarrEpisodeId = args.get('episodeid')
episodeInfo = TableEpisodes.select(TableEpisodes.path,
TableEpisodes.scene_name,
TableEpisodes.audio_language,
@ -45,9 +58,9 @@ class EpisodesSubtitles(Resource):
episodePath = path_mappings.path_replace(episodeInfo['path'])
sceneName = episodeInfo['scene_name'] or "None"
language = request.form.get('language')
hi = request.form.get('hi').capitalize()
forced = request.form.get('forced').capitalize()
language = args.get('language')
hi = args.get('hi').capitalize()
forced = args.get('forced').capitalize()
audio_language_list = get_audio_profile_languages(episode_id=sonarrEpisodeId)
if len(audio_language_list) > 0:
@ -85,10 +98,25 @@ class EpisodesSubtitles(Resource):
return '', 204
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('seriesid', type=int, required=True, help='Series ID')
post_request_parser.add_argument('episodeid', type=int, required=True, help='Episode ID')
post_request_parser.add_argument('language', type=str, required=True, help='Language code2')
post_request_parser.add_argument('forced', type=str, required=True, help='Forced true/false as string')
post_request_parser.add_argument('hi', type=str, required=True, help='HI true/false as string')
post_request_parser.add_argument('file', type=FileStorage, location='files', required=True,
help='Subtitles file as file upload object')
@authenticate
@api_ns_episodes_subtitles.doc(parser=post_request_parser)
@api_ns_episodes_subtitles.response(204, 'Success')
@api_ns_episodes_subtitles.response(401, 'Not Authenticated')
@api_ns_episodes_subtitles.response(404, 'Episode not found')
def post(self):
sonarrSeriesId = request.args.get('seriesid')
sonarrEpisodeId = request.args.get('episodeid')
"""Upload an episode subtitles"""
args = self.post_request_parser.parse_args()
sonarrSeriesId = args.get('seriesid')
sonarrEpisodeId = args.get('episodeid')
episodeInfo = TableEpisodes.select(TableEpisodes.title,
TableEpisodes.path,
TableEpisodes.scene_name,
@ -105,10 +133,10 @@ class EpisodesSubtitles(Resource):
sceneName = episodeInfo['scene_name'] or "None"
audio_language = episodeInfo['audio_language']
language = request.form.get('language')
forced = True if request.form.get('forced') == 'true' else False
hi = True if request.form.get('hi') == 'true' else False
subFile = request.files.get('file')
language = args.get('language')
forced = True if args.get('forced') == 'true' else False
hi = True if args.get('hi') == 'true' else False
subFile = args.get('file')
_, ext = os.path.splitext(subFile.filename)
@ -151,10 +179,24 @@ class EpisodesSubtitles(Resource):
return '', 204
delete_request_parser = reqparse.RequestParser()
delete_request_parser.add_argument('seriesid', type=int, required=True, help='Series ID')
delete_request_parser.add_argument('episodeid', type=int, required=True, help='Episode ID')
delete_request_parser.add_argument('language', type=str, required=True, help='Language code2')
delete_request_parser.add_argument('forced', type=str, required=True, help='Forced true/false as string')
delete_request_parser.add_argument('hi', type=str, required=True, help='HI true/false as string')
delete_request_parser.add_argument('path', type=str, required=True, help='Path of the subtitles file')
@authenticate
@api_ns_episodes_subtitles.doc(parser=delete_request_parser)
@api_ns_episodes_subtitles.response(204, 'Success')
@api_ns_episodes_subtitles.response(401, 'Not Authenticated')
@api_ns_episodes_subtitles.response(404, 'Episode not found')
def delete(self):
sonarrSeriesId = request.args.get('seriesid')
sonarrEpisodeId = request.args.get('episodeid')
"""Delete an episode subtitles"""
args = self.delete_request_parser.parse_args()
sonarrSeriesId = args.get('seriesid')
sonarrEpisodeId = args.get('episodeid')
episodeInfo = TableEpisodes.select(TableEpisodes.title,
TableEpisodes.path,
TableEpisodes.scene_name,
@ -168,10 +210,10 @@ class EpisodesSubtitles(Resource):
episodePath = path_mappings.path_replace(episodeInfo['path'])
language = request.form.get('language')
forced = request.form.get('forced')
hi = request.form.get('hi')
subtitlesPath = request.form.get('path')
language = args.get('language')
forced = args.get('forced')
hi = args.get('hi')
subtitlesPath = args.get('path')
subtitlesPath = path_mappings.path_replace_reverse(subtitlesPath)

View File

@ -5,8 +5,7 @@ import os
import operator
import pretty
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from functools import reduce
from peewee import fn
from datetime import timedelta
@ -14,16 +13,62 @@ from datetime import timedelta
from app.database import get_exclusion_clause, TableEpisodes, TableShows, TableHistory, TableBlacklist
from app.config import settings
from utilities.path_mappings import path_mappings
from api.swaggerui import subtitles_language_model
from ..utils import authenticate, postprocessEpisode
api_ns_episodes_history = Namespace('Episodes History', description='List episodes history events')
@api_ns_episodes_history.route('episodes/history')
class EpisodesHistory(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('start', type=int, required=False, default=0, help='Paging start integer')
get_request_parser.add_argument('length', type=int, required=False, default=-1, help='Paging length integer')
get_request_parser.add_argument('episodeid', type=int, required=False, help='Episode ID')
get_language_model = api_ns_episodes_history.model('subtitles_language_model', subtitles_language_model)
data_model = api_ns_episodes_history.model('history_episodes_data_model', {
'id': fields.Integer(),
'seriesTitle': fields.String(),
'monitored': fields.Boolean(),
'episode_number': fields.String(),
'episodeTitle': fields.String(),
'timestamp': fields.String(),
'subs_id': fields.String(),
'description': fields.String(),
'sonarrSeriesId': fields.Integer(),
'language': fields.Nested(get_language_model),
'score': fields.String(),
'tags': fields.List(fields.String),
'action': fields.Integer(),
'video_path': fields.String(),
'subtitles_path': fields.String(),
'sonarrEpisodeId': fields.Integer(),
'provider': fields.String(),
'seriesType': fields.String(),
'upgradable': fields.Boolean(),
'raw_timestamp': fields.Integer(),
'parsed_timestamp': fields.String(),
'blacklisted': fields.Boolean(),
})
get_response_model = api_ns_episodes_history.model('EpisodeHistoryGetResponse', {
'data': fields.Nested(data_model),
'total': fields.Integer(),
})
@authenticate
@api_ns_episodes_history.marshal_with(get_response_model, code=200)
@api_ns_episodes_history.response(401, 'Not Authenticated')
@api_ns_episodes_history.doc(parser=get_request_parser)
def get(self):
start = request.args.get('start') or 0
length = request.args.get('length') or -1
episodeid = request.args.get('episodeid')
"""List episodes history events"""
args = self.get_request_parser.parse_args()
start = args.get('start')
length = args.get('length')
episodeid = args.get('episodeid')
upgradable_episodes_not_perfect = []
if settings.general.getboolean('upgrade_subs'):
@ -133,4 +178,4 @@ class EpisodesHistory(Resource):
.join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId))\
.where(TableEpisodes.title.is_null(False)).count()
return jsonify(data=episode_history, total=count)
return {'data': episode_history, 'total': count}

View File

@ -2,20 +2,54 @@
import operator
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from functools import reduce
from app.database import get_exclusion_clause, TableEpisodes, TableShows
from api.swaggerui import subtitles_language_model
from ..utils import authenticate, postprocessEpisode
api_ns_episodes_wanted = Namespace('Episodes Wanted', description='List episodes wanted subtitles')
# GET: Get Wanted Episodes
@api_ns_episodes_wanted.route('episodes/wanted')
class EpisodesWanted(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('start', type=int, required=False, default=0, help='Paging start integer')
get_request_parser.add_argument('length', type=int, required=False, default=-1, help='Paging length integer')
get_request_parser.add_argument('episodeid[]', type=int, action='append', required=False, default=[],
help='Episodes ID to list')
get_subtitles_language_model = api_ns_episodes_wanted.model('subtitles_language_model', subtitles_language_model)
data_model = api_ns_episodes_wanted.model('wanted_episodes_data_model', {
'seriesTitle': fields.String(),
'monitored': fields.Boolean(),
'episode_number': fields.String(),
'episodeTitle': fields.String(),
'missing_subtitles': fields.Nested(get_subtitles_language_model),
'sonarrSeriesId': fields.Integer(),
'sonarrEpisodeId': fields.Integer(),
'sceneName': fields.String(),
'tags': fields.List(fields.String),
'failedAttempts': fields.String(),
'seriesType': fields.String(),
})
get_response_model = api_ns_episodes_wanted.model('EpisodeWantedGetResponse', {
'data': fields.Nested(data_model),
'total': fields.Integer(),
})
@authenticate
@api_ns_episodes_wanted.marshal_with(get_response_model, code=200)
@api_ns_episodes_wanted.response(401, 'Not Authenticated')
@api_ns_episodes_wanted.doc(parser=get_request_parser)
def get(self):
episodeid = request.args.getlist('episodeid[]')
"""List episodes wanted subtitles"""
args = self.get_request_parser.parse_args()
episodeid = args.get('episodeid[]')
wanted_conditions = [(TableEpisodes.missing_subtitles != '[]')]
if len(episodeid) > 0:
@ -39,8 +73,8 @@ class EpisodesWanted(Resource):
.where(wanted_condition)\
.dicts()
else:
start = request.args.get('start') or 0
length = request.args.get('length') or -1
start = args.get('start')
length = args.get('length')
data = TableEpisodes.select(TableShows.title.alias('seriesTitle'),
TableEpisodes.monitored,
TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'),
@ -72,4 +106,4 @@ class EpisodesWanted(Resource):
.where(reduce(operator.and_, count_conditions))\
.count()
return jsonify(data=data, total=count)
return {'data': data, 'total': count}

View File

@ -1,16 +1,11 @@
# coding=utf-8
from flask import Blueprint
from flask_restful import Api
from .files import api_ns_files
from .files_sonarr import api_ns_files_sonarr
from .files_radarr import api_ns_files_radarr
from .files import BrowseBazarrFS
from .files_sonarr import BrowseSonarrFS
from .files_radarr import BrowseRadarrFS
api_bp_files = Blueprint('api_files', __name__)
api = Api(api_bp_files)
api.add_resource(BrowseBazarrFS, '/files')
api.add_resource(BrowseSonarrFS, '/files/sonarr')
api.add_resource(BrowseRadarrFS, '/files/radarr')
api_ns_list_files = [
api_ns_files,
api_ns_files_radarr,
api_ns_files_sonarr,
]

View File

@ -1,24 +1,40 @@
# coding=utf-8
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from utilities.filesystem import browse_bazarr_filesystem
from ..utils import authenticate
api_ns_files = Namespace('Files Browser for Bazarr', description='Browse content of file system as seen by Bazarr')
@api_ns_files.route('files')
class BrowseBazarrFS(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('path', type=str, default='', help='Path to browse')
get_response_model = api_ns_files.model('BazarrFileBrowserGetResponse', {
'name': fields.String(),
'children': fields.Boolean(),
'path': fields.String(),
})
@authenticate
@api_ns_files.marshal_with(get_response_model, code=200)
@api_ns_files.response(401, 'Not Authenticated')
@api_ns_files.doc(parser=get_request_parser)
def get(self):
path = request.args.get('path') or ''
"""List Bazarr file system content"""
args = self.get_request_parser.parse_args()
path = args.get('path')
data = []
try:
result = browse_bazarr_filesystem(path)
if result is None:
raise ValueError
except Exception:
return jsonify([])
return []
for item in result['directories']:
data.append({'name': item['name'], 'children': True, 'path': item['path']})
return jsonify(data)
return data

View File

@ -1,24 +1,41 @@
# coding=utf-8
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from radarr.filesystem import browse_radarr_filesystem
from ..utils import authenticate
api_ns_files_radarr = Namespace('Files Browser for Radarr', description='Browse content of file system as seen by '
'Radarr')
@api_ns_files_radarr.route('files/radarr')
class BrowseRadarrFS(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('path', type=str, default='', help='Path to browse')
get_response_model = api_ns_files_radarr.model('RadarrFileBrowserGetResponse', {
'name': fields.String(),
'children': fields.Boolean(),
'path': fields.String(),
})
@authenticate
@api_ns_files_radarr.marshal_with(get_response_model, code=200)
@api_ns_files_radarr.response(401, 'Not Authenticated')
@api_ns_files_radarr.doc(parser=get_request_parser)
def get(self):
path = request.args.get('path') or ''
"""List Radarr file system content"""
args = self.get_request_parser.parse_args()
path = args.get('path')
data = []
try:
result = browse_radarr_filesystem(path)
if result is None:
raise ValueError
except Exception:
return jsonify([])
return []
for item in result['directories']:
data.append({'name': item['name'], 'children': True, 'path': item['path']})
return jsonify(data)
return data

View File

@ -1,24 +1,41 @@
# coding=utf-8
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from sonarr.filesystem import browse_sonarr_filesystem
from ..utils import authenticate
api_ns_files_sonarr = Namespace('Files Browser for Sonarr', description='Browse content of file system as seen by '
'Sonarr')
@api_ns_files_sonarr.route('files/sonarr')
class BrowseSonarrFS(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('path', type=str, default='', help='Path to browse')
get_response_model = api_ns_files_sonarr.model('SonarrFileBrowserGetResponse', {
'name': fields.String(),
'children': fields.Boolean(),
'path': fields.String(),
})
@authenticate
@api_ns_files_sonarr.marshal_with(get_response_model, code=200)
@api_ns_files_sonarr.response(401, 'Not Authenticated')
@api_ns_files_sonarr.doc(parser=get_request_parser)
def get(self):
path = request.args.get('path') or ''
"""List Sonarr file system content"""
args = self.get_request_parser.parse_args()
path = args.get('path')
data = []
try:
result = browse_sonarr_filesystem(path)
if result is None:
raise ValueError
except Exception:
return jsonify([])
return []
for item in result['directories']:
data.append({'name': item['name'], 'children': True, 'path': item['path']})
return jsonify(data)
return data

View File

@ -1,12 +1,8 @@
# coding=utf-8
from flask import Blueprint
from flask_restful import Api
from .stats import HistoryStats
from .stats import api_ns_history_stats
api_bp_history = Blueprint('api_history', __name__)
api = Api(api_bp_history)
api.add_resource(HistoryStats, '/history/stats')
api_ns_list_history = [
api_ns_history_stats
]

View File

@ -5,8 +5,7 @@ import datetime
import operator
from dateutil import rrule
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from functools import reduce
from peewee import fn
@ -14,14 +13,45 @@ from app.database import TableHistory, TableHistoryMovie
from ..utils import authenticate
api_ns_history_stats = Namespace('History Statistics', description='Get history statistics')
@api_ns_history_stats.route('history/stats')
class HistoryStats(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('timeFrame', type=str, default='month',
help='Timeframe to get stats for. Must be in ["week", "month", "trimester", '
'"year"]')
get_request_parser.add_argument('action', type=str, default='All', help='Action type to filter for.')
get_request_parser.add_argument('provider', type=str, default='All', help='Provider name to filter for.')
get_request_parser.add_argument('language', type=str, default='All', help='Language name to filter for')
series_data_model = api_ns_history_stats.model('history_series_stats_data_model', {
'date': fields.String(),
'count': fields.Integer(),
})
movies_data_model = api_ns_history_stats.model('history_movies_stats_data_model', {
'date': fields.String(),
'count': fields.Integer(),
})
get_response_model = api_ns_history_stats.model('HistoryStatsGetResponse', {
'series': fields.Nested(series_data_model),
'movies': fields.Nested(movies_data_model),
})
@authenticate
@api_ns_history_stats.marshal_with(get_response_model, code=200)
@api_ns_history_stats.response(401, 'Not Authenticated')
@api_ns_history_stats.doc(parser=get_request_parser)
def get(self):
timeframe = request.args.get('timeFrame') or 'month'
action = request.args.get('action') or 'All'
provider = request.args.get('provider') or 'All'
language = request.args.get('language') or 'All'
"""Get history statistics"""
args = self.get_request_parser.parse_args()
timeframe = args.get('timeFrame')
action = args.get('action')
provider = args.get('provider')
language = args.get('language')
# timeframe must be in ['week', 'month', 'trimester', 'year']
if timeframe == 'year':
@ -82,4 +112,4 @@ class HistoryStats(Resource):
sorted_data_series = sorted(data_series, key=lambda i: i['date'])
sorted_data_movies = sorted(data_movies, key=lambda i: i['date'])
return jsonify(series=sorted_data_series, movies=sorted_data_movies)
return {'series': sorted_data_series, 'movies': sorted_data_movies}

View File

@ -1,20 +1,16 @@
# coding=utf-8
from flask import Blueprint
from flask_restful import Api
from .movies import Movies
from .movies_subtitles import MoviesSubtitles
from .history import MoviesHistory
from .wanted import MoviesWanted
from .blacklist import MoviesBlacklist
from .movies import api_ns_movies
from .movies_subtitles import api_ns_movies_subtitles
from .history import api_ns_movies_history
from .wanted import api_ns_movies_wanted
from .blacklist import api_ns_movies_blacklist
api_bp_movies = Blueprint('api_movies', __name__)
api = Api(api_bp_movies)
api.add_resource(Movies, '/movies')
api.add_resource(MoviesWanted, '/movies/wanted')
api.add_resource(MoviesSubtitles, '/movies/subtitles')
api.add_resource(MoviesHistory, '/movies/history')
api.add_resource(MoviesBlacklist, '/movies/blacklist')
api_ns_list_movies = [
api_ns_movies,
api_ns_movies_blacklist,
api_ns_movies_history,
api_ns_movies_subtitles,
api_ns_movies_wanted,
]

View File

@ -3,8 +3,7 @@
import datetime
import pretty
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from app.database import TableMovies, TableBlacklistMovie
from subtitles.tools.delete import delete_subtitles
@ -12,18 +11,41 @@ from radarr.blacklist import blacklist_log_movie, blacklist_delete_all_movie, bl
from utilities.path_mappings import path_mappings
from subtitles.mass_download import movies_download_subtitles
from app.event_handler import event_stream
from api.swaggerui import subtitles_language_model
from ..utils import authenticate, postprocessMovie
api_ns_movies_blacklist = Namespace('Movies Blacklist', description='List, add or remove subtitles to or from '
'movies blacklist')
# GET: get blacklist
# POST: add blacklist
# DELETE: remove blacklist
@api_ns_movies_blacklist.route('movies/blacklist')
class MoviesBlacklist(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('start', type=int, required=False, default=0, help='Paging start integer')
get_request_parser.add_argument('length', type=int, required=False, default=-1, help='Paging length integer')
get_language_model = api_ns_movies_blacklist.model('subtitles_language_model', subtitles_language_model)
get_response_model = api_ns_movies_blacklist.model('MovieBlacklistGetResponse', {
'title': fields.String(),
'radarrId': fields.Integer(),
'provider': fields.String(),
'subs_id': fields.String(),
'language': fields.Nested(get_language_model),
'timestamp': fields.String(),
'parsed_timestamp': fields.String(),
})
@authenticate
@api_ns_movies_blacklist.marshal_with(get_response_model, envelope='data', code=200)
@api_ns_movies_blacklist.response(401, 'Not Authenticated')
@api_ns_movies_blacklist.doc(parser=get_request_parser)
def get(self):
start = request.args.get('start') or 0
length = request.args.get('length') or -1
"""List blacklisted movies subtitles"""
args = self.get_request_parser.parse_args()
start = args.get('start')
length = args.get('length')
data = TableBlacklistMovie.select(TableMovies.title,
TableMovies.radarrId,
@ -45,14 +67,27 @@ class MoviesBlacklist(Resource):
item["parsed_timestamp"] = datetime.datetime.fromtimestamp(int(item['timestamp'])).strftime('%x %X')
item.update({'timestamp': pretty.date(datetime.datetime.fromtimestamp(item['timestamp']))})
return jsonify(data=data)
return data
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('radarrid', type=int, required=True, help='Radarr ID')
post_request_parser.add_argument('provider', type=str, required=True, help='Provider name')
post_request_parser.add_argument('subs_id', type=str, required=True, help='Subtitles ID')
post_request_parser.add_argument('language', type=str, required=True, help='Subtitles language')
post_request_parser.add_argument('subtitles_path', type=str, required=True, help='Subtitles file path')
@authenticate
@api_ns_movies_blacklist.doc(parser=post_request_parser)
@api_ns_movies_blacklist.response(200, 'Success')
@api_ns_movies_blacklist.response(401, 'Not Authenticated')
@api_ns_movies_blacklist.response(404, 'Movie not found')
def post(self):
radarr_id = int(request.args.get('radarrid'))
provider = request.form.get('provider')
subs_id = request.form.get('subs_id')
language = request.form.get('language')
"""Add a movies subtitles to blacklist"""
args = self.post_request_parser.parse_args()
radarr_id = args.get('radarrid')
provider = args.get('provider')
subs_id = args.get('subs_id')
language = args.get('language')
# TODO
forced = False
hi = False
@ -63,7 +98,7 @@ class MoviesBlacklist(Resource):
return 'Movie not found', 404
media_path = data['path']
subtitles_path = request.form.get('subtitles_path')
subtitles_path = args.get('subtitles_path')
blacklist_log_movie(radarr_id=radarr_id,
provider=provider,
@ -80,12 +115,22 @@ class MoviesBlacklist(Resource):
event_stream(type='movie-history')
return '', 200
delete_request_parser = reqparse.RequestParser()
delete_request_parser.add_argument('all', type=str, required=False, help='Empty movies subtitles blacklist')
delete_request_parser.add_argument('provider', type=str, required=True, help='Provider name')
delete_request_parser.add_argument('subs_id', type=str, required=True, help='Subtitles ID')
@authenticate
@api_ns_movies_blacklist.doc(parser=delete_request_parser)
@api_ns_movies_blacklist.response(204, 'Success')
@api_ns_movies_blacklist.response(401, 'Not Authenticated')
def delete(self):
if request.args.get("all") == "true":
"""Delete a movies subtitles from blacklist"""
args = self.post_request_parser.parse_args()
if args.get("all") == "true":
blacklist_delete_all_movie()
else:
provider = request.form.get('provider')
subs_id = request.form.get('subs_id')
provider = args.get('provider')
subs_id = args.get('subs_id')
blacklist_delete_movie(provider=provider, subs_id=subs_id)
return '', 200

View File

@ -5,8 +5,7 @@ import os
import operator
import pretty
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from functools import reduce
from peewee import fn
from datetime import timedelta
@ -14,16 +13,58 @@ from datetime import timedelta
from app.database import get_exclusion_clause, TableMovies, TableHistoryMovie, TableBlacklistMovie
from app.config import settings
from utilities.path_mappings import path_mappings
from api.swaggerui import subtitles_language_model
from ..utils import authenticate, postprocessMovie
api_ns_movies_history = Namespace('Movies History', description='List movies history events')
@api_ns_movies_history.route('movies/history')
class MoviesHistory(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('start', type=int, required=False, default=0, help='Paging start integer')
get_request_parser.add_argument('length', type=int, required=False, default=-1, help='Paging length integer')
get_request_parser.add_argument('radarrid', type=int, required=False, help='Movie ID')
get_language_model = api_ns_movies_history.model('subtitles_language_model', subtitles_language_model)
data_model = api_ns_movies_history.model('history_movies_data_model', {
'id': fields.Integer(),
'action': fields.Integer(),
'title': fields.String(),
'timestamp': fields.String(),
'description': fields.String(),
'radarrId': fields.Integer(),
'monitored': fields.Boolean(),
'path': fields.String(),
'language': fields.Nested(get_language_model),
'tags': fields.List(fields.String),
'score': fields.String(),
'subs_id': fields.String(),
'provider': fields.String(),
'subtitles_path': fields.String(),
'upgradable': fields.Boolean(),
'raw_timestamp': fields.Integer(),
'parsed_timestamp': fields.String(),
'blacklisted': fields.Boolean(),
})
get_response_model = api_ns_movies_history.model('MovieHistoryGetResponse', {
'data': fields.Nested(data_model),
'total': fields.Integer(),
})
@authenticate
@api_ns_movies_history.marshal_with(get_response_model, code=200)
@api_ns_movies_history.response(401, 'Not Authenticated')
@api_ns_movies_history.doc(parser=get_request_parser)
def get(self):
start = request.args.get('start') or 0
length = request.args.get('length') or -1
radarrid = request.args.get('radarrid')
"""List movies history events"""
args = self.get_request_parser.parse_args()
start = args.get('start')
length = args.get('length')
radarrid = args.get('radarrid')
upgradable_movies = []
upgradable_movies_not_perfect = []
@ -129,4 +170,4 @@ class MoviesHistory(Resource):
.where(TableMovies.title.is_null(False))\
.count()
return jsonify(data=movie_history, total=count)
return {'data': movie_history, 'total': count}

View File

@ -1,23 +1,78 @@
# coding=utf-8
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from app.database import TableMovies
from subtitles.indexer.movies import list_missing_subtitles_movies, movies_scan_subtitles
from app.event_handler import event_stream
from subtitles.wanted import wanted_search_missing_subtitles_movies
from subtitles.mass_download import movies_download_subtitles
from api.swaggerui import subtitles_model, subtitles_language_model, audio_language_model
from ..utils import authenticate, postprocessMovie, None_Keys
api_ns_movies = Namespace('Movies', description='List movies metadata, update movie languages profile or run actions '
'for specific movies.')
@api_ns_movies.route('movies')
class Movies(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('start', type=int, required=False, default=0, help='Paging start integer')
get_request_parser.add_argument('length', type=int, required=False, default=-1, help='Paging length integer')
get_request_parser.add_argument('radarrid[]', type=int, action='append', required=False, default=[],
help='Movies IDs to get metadata for')
get_subtitles_model = api_ns_movies.model('subtitles_model', subtitles_model)
get_subtitles_language_model = api_ns_movies.model('subtitles_language_model', subtitles_language_model)
get_audio_language_model = api_ns_movies.model('audio_language_model', audio_language_model)
data_model = api_ns_movies.model('movies_data_model', {
'alternativeTitles': fields.List(fields.String),
'audio_codec': fields.String(),
'audio_language': fields.Nested(get_audio_language_model),
'failedAttempts': fields.String(),
'fanart': fields.String(),
'file_size': fields.Integer(),
'format': fields.String(),
'imdbId': fields.String(),
'missing_subtitles': fields.Nested(get_subtitles_language_model),
'monitored': fields.Boolean(),
'movie_file_id': fields.Integer(),
'overview': fields.String(),
'path': fields.String(),
'poster': fields.String(),
'profileId': fields.Integer(),
'radarrId': fields.Integer(),
'resolution': fields.String(),
'rowid': fields.Integer(),
'sceneName': fields.String(),
'sortTitle': fields.String(),
'subtitles': fields.Nested(get_subtitles_model),
'tags': fields.List(fields.String),
'title': fields.String(),
'tmdbId': fields.String(),
'video_codec': fields.String(),
'year': fields.String(),
})
get_response_model = api_ns_movies.model('MoviesGetResponse', {
'data': fields.Nested(data_model),
'total': fields.Integer(),
})
@authenticate
@api_ns_movies.marshal_with(get_response_model, code=200)
@api_ns_movies.doc(parser=get_request_parser)
@api_ns_movies.response(200, 'Success')
@api_ns_movies.response(401, 'Not Authenticated')
def get(self):
start = request.args.get('start') or 0
length = request.args.get('length') or -1
radarrId = request.args.getlist('radarrid[]')
"""List movies metadata for specific movies"""
args = self.get_request_parser.parse_args()
start = args.get('start')
length = args.get('length')
radarrId = args.get('radarrid[]')
count = TableMovies.select().count()
@ -32,12 +87,24 @@ class Movies(Resource):
for item in result:
postprocessMovie(item)
return jsonify(data=result, total=count)
return {'data': result, 'total': count}
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('radarrid', type=int, action='append', required=False, default=[],
help='Radarr movie(s) ID')
post_request_parser.add_argument('profileid', type=str, action='append', required=False, default=[],
help='Languages profile(s) ID or "none"')
@authenticate
@api_ns_movies.doc(parser=post_request_parser)
@api_ns_movies.response(204, 'Success')
@api_ns_movies.response(401, 'Not Authenticated')
@api_ns_movies.response(404, 'Languages profile not found')
def post(self):
radarrIdList = request.form.getlist('radarrid')
profileIdList = request.form.getlist('profileid')
"""Update specific movies languages profile"""
args = self.post_request_parser.parse_args()
radarrIdList = args.get('radarrid')
profileIdList = args.get('profileid')
for idx in range(len(radarrIdList)):
radarrId = radarrIdList[idx]
@ -65,10 +132,21 @@ class Movies(Resource):
return '', 204
patch_request_parser = reqparse.RequestParser()
patch_request_parser.add_argument('radarrid', type=int, required=False, help='Radarr movie ID')
patch_request_parser.add_argument('action', type=str, required=False, help='Action to perform from ["scan-disk", '
'"search-missing", "search-wanted"]')
@authenticate
@api_ns_movies.doc(parser=patch_request_parser)
@api_ns_movies.response(204, 'Success')
@api_ns_movies.response(400, 'Unknown action')
@api_ns_movies.response(401, 'Not Authenticated')
def patch(self):
radarrid = request.form.get('radarrid')
action = request.form.get('action')
"""Run actions on specific movies"""
args = self.patch_request_parser.parse_args()
radarrid = args.get('radarrid')
action = args.get('action')
if action == "scan-disk":
movies_scan_subtitles(radarrid)
return '', 204

View File

@ -3,9 +3,9 @@
import os
import logging
from flask import request
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse
from subliminal_patch.core import SUBTITLE_EXTENSIONS
from werkzeug.datastructures import FileStorage
from app.database import TableMovies, get_audio_profile_languages, get_profile_id
from utilities.path_mappings import path_mappings
@ -21,14 +21,26 @@ from app.config import settings
from ..utils import authenticate
# PATCH: Download Subtitles
# POST: Upload Subtitles
# DELETE: Delete Subtitles
api_ns_movies_subtitles = Namespace('Movies Subtitles', description='Download, upload or delete movies subtitles')
@api_ns_movies_subtitles.route('movies/subtitles')
class MoviesSubtitles(Resource):
patch_request_parser = reqparse.RequestParser()
patch_request_parser.add_argument('radarrid', type=int, required=True, help='Movie ID')
patch_request_parser.add_argument('language', type=str, required=True, help='Language code2')
patch_request_parser.add_argument('forced', type=str, required=True, help='Forced true/false as string')
patch_request_parser.add_argument('hi', type=str, required=True, help='HI true/false as string')
@authenticate
@api_ns_movies_subtitles.doc(parser=patch_request_parser)
@api_ns_movies_subtitles.response(204, 'Success')
@api_ns_movies_subtitles.response(401, 'Not Authenticated')
@api_ns_movies_subtitles.response(404, 'Movie not found')
def patch(self):
# Download
radarrId = request.args.get('radarrid')
"""Download a movie subtitles"""
args = self.patch_request_parser.parse_args()
radarrId = args.get('radarrid')
movieInfo = TableMovies.select(TableMovies.title,
TableMovies.path,
@ -47,9 +59,9 @@ class MoviesSubtitles(Resource):
title = movieInfo['title']
audio_language = movieInfo['audio_language']
language = request.form.get('language')
hi = request.form.get('hi').capitalize()
forced = request.form.get('forced').capitalize()
language = args.get('language')
hi = args.get('hi').capitalize()
forced = args.get('forced').capitalize()
audio_language_list = get_audio_profile_languages(movie_id=radarrId)
if len(audio_language_list) > 0:
@ -85,11 +97,25 @@ class MoviesSubtitles(Resource):
return '', 204
# POST: Upload Subtitles
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('radarrid', type=int, required=True, help='Movie ID')
post_request_parser.add_argument('language', type=str, required=True, help='Language code2')
post_request_parser.add_argument('forced', type=str, required=True, help='Forced true/false as string')
post_request_parser.add_argument('hi', type=str, required=True, help='HI true/false as string')
post_request_parser.add_argument('file', type=FileStorage, location='files', required=True,
help='Subtitles file as file upload object')
@authenticate
@api_ns_movies_subtitles.doc(parser=post_request_parser)
@api_ns_movies_subtitles.response(204, 'Success')
@api_ns_movies_subtitles.response(401, 'Not Authenticated')
@api_ns_movies_subtitles.response(404, 'Movie not found')
def post(self):
# Upload
"""Upload a movie subtitles"""
# TODO: Support Multiply Upload
radarrId = request.args.get('radarrid')
args = self.post_request_parser.parse_args()
radarrId = args.get('radarrid')
movieInfo = TableMovies.select(TableMovies.title,
TableMovies.path,
TableMovies.sceneName,
@ -107,10 +133,10 @@ class MoviesSubtitles(Resource):
title = movieInfo['title']
audioLanguage = movieInfo['audio_language']
language = request.form.get('language')
forced = True if request.form.get('forced') == 'true' else False
hi = True if request.form.get('hi') == 'true' else False
subFile = request.files.get('file')
language = args.get('language')
forced = True if args.get('forced') == 'true' else False
hi = True if args.get('hi') == 'true' else False
subFile = args.get('file')
_, ext = os.path.splitext(subFile.filename)
@ -151,10 +177,23 @@ class MoviesSubtitles(Resource):
return '', 204
# DELETE: Delete Subtitles
delete_request_parser = reqparse.RequestParser()
delete_request_parser.add_argument('radarrid', type=int, required=True, help='Movie ID')
delete_request_parser.add_argument('language', type=str, required=True, help='Language code2')
delete_request_parser.add_argument('forced', type=str, required=True, help='Forced true/false as string')
delete_request_parser.add_argument('hi', type=str, required=True, help='HI true/false as string')
delete_request_parser.add_argument('path', type=str, required=True, help='Path of the subtitles file')
@authenticate
@api_ns_movies_subtitles.doc(parser=delete_request_parser)
@api_ns_movies_subtitles.response(204, 'Success')
@api_ns_movies_subtitles.response(401, 'Not Authenticated')
@api_ns_movies_subtitles.response(404, 'Movie not found')
def delete(self):
# Delete
radarrId = request.args.get('radarrid')
"""Delete a movie subtitles"""
args = self.delete_request_parser.parse_args()
radarrId = args.get('radarrid')
movieInfo = TableMovies.select(TableMovies.path) \
.where(TableMovies.radarrId == radarrId) \
.dicts() \
@ -165,21 +204,19 @@ class MoviesSubtitles(Resource):
moviePath = path_mappings.path_replace_movie(movieInfo['path'])
language = request.form.get('language')
forced = request.form.get('forced')
hi = request.form.get('hi')
subtitlesPath = request.form.get('path')
language = args.get('language')
forced = args.get('forced')
hi = args.get('hi')
subtitlesPath = args.get('path')
subtitlesPath = path_mappings.path_replace_reverse_movie(subtitlesPath)
result = delete_subtitles(media_type='movie',
language=language,
forced=forced,
hi=hi,
media_path=moviePath,
subtitles_path=subtitlesPath,
radarr_id=radarrId)
if result:
return '', 202
else:
return '', 204
delete_subtitles(media_type='movie',
language=language,
forced=forced,
hi=hi,
media_path=moviePath,
subtitles_path=subtitlesPath,
radarr_id=radarrId)
return '', 204

View File

@ -2,20 +2,51 @@
import operator
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from functools import reduce
from app.database import get_exclusion_clause, TableMovies
from api.swaggerui import subtitles_language_model
from ..utils import authenticate, postprocessMovie
# GET: Get Wanted Movies
api_ns_movies_wanted = Namespace('Movies Wanted', description='List movies wanted subtitles')
@api_ns_movies_wanted.route('movies/wanted')
class MoviesWanted(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('start', type=int, required=False, default=0, help='Paging start integer')
get_request_parser.add_argument('length', type=int, required=False, default=-1, help='Paging length integer')
get_request_parser.add_argument('radarrid[]', type=int, action='append', required=False, default=[],
help='Movies ID to list')
get_subtitles_language_model = api_ns_movies_wanted.model('subtitles_language_model', subtitles_language_model)
data_model = api_ns_movies_wanted.model('wanted_movies_data_model', {
'title': fields.String(),
'monitored': fields.Boolean(),
'missing_subtitles': fields.Nested(get_subtitles_language_model),
'radarrId': fields.Integer(),
'sceneName': fields.String(),
'tags': fields.List(fields.String),
'failedAttempts': fields.String(),
})
get_response_model = api_ns_movies_wanted.model('MovieWantedGetResponse', {
'data': fields.Nested(data_model),
'total': fields.Integer(),
})
@authenticate
@api_ns_movies_wanted.marshal_with(get_response_model, code=200)
@api_ns_movies_wanted.response(401, 'Not Authenticated')
@api_ns_movies_wanted.doc(parser=get_request_parser)
def get(self):
radarrid = request.args.getlist("radarrid[]")
"""List movies wanted subtitles"""
args = self.get_request_parser.parse_args()
radarrid = args.get("radarrid[]")
wanted_conditions = [(TableMovies.missing_subtitles != '[]')]
if len(radarrid) > 0:
@ -34,8 +65,8 @@ class MoviesWanted(Resource):
.where(wanted_condition)\
.dicts()
else:
start = request.args.get('start') or 0
length = request.args.get('length') or -1
start = args.get('start')
length = args.get('length')
result = TableMovies.select(TableMovies.title,
TableMovies.missing_subtitles,
TableMovies.radarrId,
@ -60,4 +91,4 @@ class MoviesWanted(Resource):
.where(reduce(operator.and_, count_conditions))\
.count()
return jsonify(data=result, total=count)
return {'data': result, 'total': count}

View File

@ -1,16 +1,12 @@
# coding=utf-8
from flask import Blueprint
from flask_restful import Api
from .providers import Providers
from .providers_episodes import ProviderEpisodes
from .providers_movies import ProviderMovies
from .providers import api_ns_providers
from .providers_episodes import api_ns_providers_episodes
from .providers_movies import api_ns_providers_movies
api_bp_providers = Blueprint('api_providers', __name__)
api = Api(api_bp_providers)
api.add_resource(Providers, '/providers')
api.add_resource(ProviderMovies, '/providers/movies')
api.add_resource(ProviderEpisodes, '/providers/episodes')
api_ns_list_providers = [
api_ns_providers,
api_ns_providers_episodes,
api_ns_providers_movies,
]

View File

@ -1,7 +1,6 @@
# coding=utf-8
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from operator import itemgetter
from app.database import TableHistory, TableHistoryMovie
@ -9,11 +8,29 @@ from app.get_providers import list_throttled_providers, reset_throttled_provider
from ..utils import authenticate, False_Keys
api_ns_providers = Namespace('Providers', description='Get and reset providers status')
@api_ns_providers.route('providers')
class Providers(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('history', type=str, required=False, help='Provider name for history stats')
get_response_model = api_ns_providers.model('MovieBlacklistGetResponse', {
'name': fields.String(),
'status': fields.String(),
'retry': fields.String(),
})
@authenticate
@api_ns_providers.marshal_with(get_response_model, envelope='data', code=200)
@api_ns_providers.response(200, 'Success')
@api_ns_providers.response(401, 'Not Authenticated')
@api_ns_providers.doc(parser=get_request_parser)
def get(self):
history = request.args.get('history')
"""Get providers status"""
args = self.get_request_parser.parse_args()
history = args.get('history')
if history and history not in False_Keys:
providers = list(TableHistory.select(TableHistory.provider)
.where(TableHistory.provider is not None and TableHistory.provider != "manual")
@ -29,22 +46,30 @@ class Providers(Resource):
'status': 'History',
'retry': '-'
})
return jsonify(data=sorted(providers_dicts, key=itemgetter('name')))
else:
throttled_providers = list_throttled_providers()
throttled_providers = list_throttled_providers()
providers_dicts = list()
for provider in throttled_providers:
providers_dicts.append({
"name": provider[0],
"status": provider[1] if provider[1] is not None else "Good",
"retry": provider[2] if provider[2] != "now" else "-"
})
return sorted(providers_dicts, key=itemgetter('name'))
providers = list()
for provider in throttled_providers:
providers.append({
"name": provider[0],
"status": provider[1] if provider[1] is not None else "Good",
"retry": provider[2] if provider[2] != "now" else "-"
})
return jsonify(data=providers)
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('action', type=str, required=True, help='Action to perform from ["reset"]')
@authenticate
@api_ns_providers.doc(parser=post_request_parser)
@api_ns_providers.response(204, 'Success')
@api_ns_providers.response(401, 'Not Authenticated')
@api_ns_providers.response(400, 'Unknown action')
def post(self):
action = request.form.get('action')
"""Reset providers status"""
args = self.post_request_parser.parse_args()
action = args.get('action')
if action == 'reset':
reset_throttled_providers()

View File

@ -1,7 +1,6 @@
# coding=utf-8
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from app.database import TableEpisodes, TableShows, get_audio_profile_languages, get_profile_id
from utilities.path_mappings import path_mappings
@ -15,11 +14,40 @@ from subtitles.indexer.series import store_subtitles
from ..utils import authenticate
api_ns_providers_episodes = Namespace('Providers Episodes', description='List and download episodes subtitles manually')
@api_ns_providers_episodes.route('providers/episodes')
class ProviderEpisodes(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('episodeid', type=int, required=True, help='Episode ID')
get_response_model = api_ns_providers_episodes.model('ProviderEpisodesGetResponse', {
'dont_matches': fields.List(fields.String),
'forced': fields.String(),
'hearing_impaired': fields.String(),
'language': fields.String(),
'matches': fields.List(fields.String),
'original_format': fields.String(),
'orig_score': fields.Integer(),
'provider': fields.String(),
'release_info': fields.List(fields.String),
'score': fields.Integer(),
'score_without_hash': fields.Integer(),
'subtitle': fields.String(),
'uploader': fields.String(),
'url': fields.String(),
})
@authenticate
@api_ns_providers_episodes.marshal_with(get_response_model, envelope='data', code=200)
@api_ns_providers_episodes.response(401, 'Not Authenticated')
@api_ns_providers_episodes.response(404, 'Episode not found')
@api_ns_providers_episodes.doc(parser=get_request_parser)
def get(self):
# Manual Search
sonarrEpisodeId = request.args.get('episodeid')
"""Search manually for an episode subtitles"""
args = self.get_request_parser.parse_args()
sonarrEpisodeId = args.get('episodeid')
episodeInfo = TableEpisodes.select(TableEpisodes.path,
TableEpisodes.scene_name,
TableShows.title,
@ -42,13 +70,28 @@ class ProviderEpisodes(Resource):
data = manual_search(episodePath, profileId, providers_list, sceneName, title, 'series')
if not data:
data = []
return jsonify(data=data)
return data
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('seriesid', type=int, required=True, help='Series ID')
post_request_parser.add_argument('episodeid', type=int, required=True, help='Episode ID')
post_request_parser.add_argument('hi', type=str, required=True, help='HI subtitles from ["True", "False"]')
post_request_parser.add_argument('forced', type=str, required=True, help='Forced subtitles from ["True", "False"]')
post_request_parser.add_argument('original_format', type=str, required=True,
help='Use original subtitles format from ["True", "False"]')
post_request_parser.add_argument('provider', type=str, required=True, help='Provider name')
post_request_parser.add_argument('subtitle', type=str, required=True, help='Pickled subtitles as return by GET')
@authenticate
@api_ns_providers_episodes.doc(parser=post_request_parser)
@api_ns_providers_episodes.response(204, 'Success')
@api_ns_providers_episodes.response(401, 'Not Authenticated')
@api_ns_providers_episodes.response(404, 'Episode not found')
def post(self):
# Manual Download
sonarrSeriesId = request.args.get('seriesid')
sonarrEpisodeId = request.args.get('episodeid')
"""Manually download an episode subtitles"""
args = self.post_request_parser.parse_args()
sonarrSeriesId = args.get('seriesid')
sonarrEpisodeId = args.get('episodeid')
episodeInfo = TableEpisodes.select(TableEpisodes.path,
TableEpisodes.scene_name,
TableShows.title) \
@ -64,11 +107,11 @@ class ProviderEpisodes(Resource):
episodePath = path_mappings.path_replace(episodeInfo['path'])
sceneName = episodeInfo['scene_name'] or "None"
hi = request.form.get('hi').capitalize()
forced = request.form.get('forced').capitalize()
use_original_format = request.form.get('original_format').capitalize()
selected_provider = request.form.get('provider')
subtitle = request.form.get('subtitle')
hi = args.get('hi').capitalize()
forced = args.get('forced').capitalize()
use_original_format = args.get('original_format').capitalize()
selected_provider = args.get('provider')
subtitle = args.get('subtitle')
audio_language_list = get_audio_profile_languages(episode_id=sonarrEpisodeId)
if len(audio_language_list) > 0:
@ -78,7 +121,8 @@ class ProviderEpisodes(Resource):
try:
result = manual_download_subtitle(episodePath, audio_language, hi, forced, subtitle, selected_provider,
sceneName, title, 'series', use_original_format, profile_id=get_profile_id(episode_id=sonarrEpisodeId))
sceneName, title, 'series', use_original_format,
profile_id=get_profile_id(episode_id=sonarrEpisodeId))
if result is not None:
message = result[0]
path = result[1]

View File

@ -1,9 +1,6 @@
# coding=utf-8
import logging
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from app.database import TableMovies, get_audio_profile_languages, get_profile_id
from utilities.path_mappings import path_mappings
@ -17,11 +14,40 @@ from subtitles.indexer.movies import store_subtitles_movie
from ..utils import authenticate
api_ns_providers_movies = Namespace('Providers Movies', description='List and download movies subtitles manually')
@api_ns_providers_movies.route('providers/movies')
class ProviderMovies(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('radarrid', type=int, required=True, help='Movie ID')
get_response_model = api_ns_providers_movies.model('ProviderMoviesGetResponse', {
'dont_matches': fields.List(fields.String),
'forced': fields.String(),
'hearing_impaired': fields.String(),
'language': fields.String(),
'matches': fields.List(fields.String),
'original_format': fields.String(),
'orig_score': fields.Integer(),
'provider': fields.String(),
'release_info': fields.List(fields.String),
'score': fields.Integer(),
'score_without_hash': fields.Integer(),
'subtitle': fields.String(),
'uploader': fields.String(),
'url': fields.String(),
})
@authenticate
@api_ns_providers_movies.marshal_with(get_response_model, envelope='data', code=200)
@api_ns_providers_movies.response(401, 'Not Authenticated')
@api_ns_providers_movies.response(404, 'Movie not found')
@api_ns_providers_movies.doc(parser=get_request_parser)
def get(self):
# Manual Search
radarrId = request.args.get('radarrid')
"""Search manually for a movie subtitles"""
args = self.get_request_parser.parse_args()
radarrId = args.get('radarrid')
movieInfo = TableMovies.select(TableMovies.title,
TableMovies.path,
TableMovies.sceneName,
@ -43,12 +69,26 @@ class ProviderMovies(Resource):
data = manual_search(moviePath, profileId, providers_list, sceneName, title, 'movie')
if not data:
data = []
return jsonify(data=data)
return data
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('radarrid', type=int, required=True, help='Movie ID')
post_request_parser.add_argument('hi', type=str, required=True, help='HI subtitles from ["True", "False"]')
post_request_parser.add_argument('forced', type=str, required=True, help='Forced subtitles from ["True", "False"]')
post_request_parser.add_argument('original_format', type=str, required=True,
help='Use original subtitles format from ["True", "False"]')
post_request_parser.add_argument('provider', type=str, required=True, help='Provider name')
post_request_parser.add_argument('subtitle', type=str, required=True, help='Pickled subtitles as return by GET')
@authenticate
@api_ns_providers_movies.doc(parser=post_request_parser)
@api_ns_providers_movies.response(204, 'Success')
@api_ns_providers_movies.response(401, 'Not Authenticated')
@api_ns_providers_movies.response(404, 'Movie not found')
def post(self):
# Manual Download
radarrId = request.args.get('radarrid')
"""Manually download a movie subtitles"""
args = self.post_request_parser.parse_args()
radarrId = args.get('radarrid')
movieInfo = TableMovies.select(TableMovies.title,
TableMovies.path,
TableMovies.sceneName,
@ -64,12 +104,11 @@ class ProviderMovies(Resource):
moviePath = path_mappings.path_replace_movie(movieInfo['path'])
sceneName = movieInfo['sceneName'] or "None"
hi = request.form.get('hi').capitalize()
forced = request.form.get('forced').capitalize()
use_original_format = request.form.get('original_format').capitalize()
logging.debug(f"use_original_format {use_original_format}")
selected_provider = request.form.get('provider')
subtitle = request.form.get('subtitle')
hi = args.get('hi').capitalize()
forced = args.get('forced').capitalize()
use_original_format = args.get('original_format').capitalize()
selected_provider = args.get('provider')
subtitle = args.get('subtitle')
audio_language_list = get_audio_profile_languages(movie_id=radarrId)
if len(audio_language_list) > 0:
@ -79,7 +118,8 @@ class ProviderMovies(Resource):
try:
result = manual_download_subtitle(moviePath, audio_language, hi, forced, subtitle, selected_provider,
sceneName, title, 'movie', use_original_format, profile_id=get_profile_id(movie_id=radarrId))
sceneName, title, 'movie', use_original_format,
profile_id=get_profile_id(movie_id=radarrId))
if result is not None:
message = result[0]
path = result[1]

View File

@ -1,12 +1,8 @@
# coding=utf-8
from flask import Blueprint
from flask_restful import Api
from .series import Series
from .series import api_ns_series
api_bp_series = Blueprint('api_series', __name__)
api = Api(api_bp_series)
api.add_resource(Series, '/series')
api_ns_list_series = [
api_ns_series,
]

View File

@ -2,8 +2,7 @@
import operator
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from functools import reduce
from app.database import get_exclusion_clause, TableEpisodes, TableShows
@ -11,16 +10,63 @@ from subtitles.indexer.series import list_missing_subtitles, series_scan_subtitl
from subtitles.mass_download import series_download_subtitles
from subtitles.wanted import wanted_search_missing_subtitles_series
from app.event_handler import event_stream
from api.swaggerui import subtitles_model, subtitles_language_model, audio_language_model
from ..utils import authenticate, postprocessSeries, None_Keys
api_ns_series = Namespace('Series', description='List series metadata, update series languages profile or run actions '
'for specific series.')
@api_ns_series.route('series')
class Series(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('start', type=int, required=False, default=0, help='Paging start integer')
get_request_parser.add_argument('length', type=int, required=False, default=-1, help='Paging length integer')
get_request_parser.add_argument('seriesid[]', type=int, action='append', required=False, default=[],
help='Series IDs to get metadata for')
get_subtitles_model = api_ns_series.model('subtitles_model', subtitles_model)
get_subtitles_language_model = api_ns_series.model('subtitles_language_model', subtitles_language_model)
get_audio_language_model = api_ns_series.model('audio_language_model', audio_language_model)
data_model = api_ns_series.model('series_data_model', {
'alternativeTitles': fields.List(fields.String),
'audio_language': fields.Nested(get_audio_language_model),
'episodeFileCount': fields.Integer(),
'episodeMissingCount': fields.Integer(),
'fanart': fields.String(),
'imdbId': fields.String(),
'overview': fields.String(),
'path': fields.String(),
'poster': fields.String(),
'profileId': fields.Integer(),
'seriesType': fields.String(),
'sonarrSeriesId': fields.Integer(),
'sortTitle': fields.String(),
'tags': fields.List(fields.String),
'title': fields.String(),
'tvdbId': fields.Integer(),
'year': fields.String(),
})
get_response_model = api_ns_series.model('SeriesGetResponse', {
'data': fields.Nested(data_model),
'total': fields.Integer(),
})
@authenticate
@api_ns_series.marshal_with(get_response_model, code=200)
@api_ns_series.doc(parser=get_request_parser)
@api_ns_series.response(200, 'Success')
@api_ns_series.response(401, 'Not Authenticated')
def get(self):
start = request.args.get('start') or 0
length = request.args.get('length') or -1
seriesId = request.args.getlist('seriesid[]')
"""List series metadata for specific series"""
args = self.get_request_parser.parse_args()
start = args.get('start')
length = args.get('length')
seriesId = args.get('seriesid[]')
count = TableShows.select().count()
@ -58,12 +104,24 @@ class Series(Resource):
.count()
item.update({"episodeFileCount": episodeFileCount})
return jsonify(data=result, total=count)
return {'data': result, 'total': count}
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('seriesid', type=int, action='append', required=False, default=[],
help='Sonarr series ID')
post_request_parser.add_argument('profileid', type=str, action='append', required=False, default=[],
help='Languages profile(s) ID or "none"')
@authenticate
@api_ns_series.doc(parser=post_request_parser)
@api_ns_series.response(204, 'Success')
@api_ns_series.response(401, 'Not Authenticated')
@api_ns_series.response(404, 'Languages profile not found')
def post(self):
seriesIdList = request.form.getlist('seriesid')
profileIdList = request.form.getlist('profileid')
"""Update specific series languages profile"""
args = self.post_request_parser.parse_args()
seriesIdList = args.get('seriesid')
profileIdList = args.get('profileid')
for idx in range(len(seriesIdList)):
seriesId = seriesIdList[idx]
@ -99,10 +157,21 @@ class Series(Resource):
return '', 204
patch_request_parser = reqparse.RequestParser()
patch_request_parser.add_argument('seriesid', type=int, required=False, help='Sonarr series ID')
patch_request_parser.add_argument('action', type=str, required=False, help='Action to perform from ["scan-disk", '
'"search-missing", "search-wanted"]')
@authenticate
@api_ns_series.doc(parser=patch_request_parser)
@api_ns_series.response(204, 'Success')
@api_ns_series.response(400, 'Unknown action')
@api_ns_series.response(401, 'Not Authenticated')
def patch(self):
seriesid = request.form.get('seriesid')
action = request.form.get('action')
"""Run actions on specific series"""
args = self.patch_request_parser.parse_args()
seriesid = args.get('seriesid')
action = args.get('action')
if action == "scan-disk":
series_scan_subtitles(seriesid)
return '', 204

View File

@ -1,14 +1,10 @@
# coding=utf-8
from flask import Blueprint
from flask_restful import Api
from .subtitles import Subtitles
from .subtitles_info import SubtitleNameInfo
from .subtitles import api_ns_subtitles
from .subtitles_info import api_ns_subtitles_info
api_bp_subtitles = Blueprint('api_subtitles', __name__)
api = Api(api_bp_subtitles)
api.add_resource(Subtitles, '/subtitles')
api.add_resource(SubtitleNameInfo, '/subtitles/info')
api_ns_list_subtitles = [
api_ns_subtitles,
api_ns_subtitles_info,
]

View File

@ -4,8 +4,7 @@ import os
import sys
import gc
from flask import request
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse
from app.database import TableEpisodes, TableMovies
from utilities.path_mappings import path_mappings
@ -20,15 +19,37 @@ from app.event_handler import event_stream
from ..utils import authenticate
class Subtitles(Resource):
@authenticate
def patch(self):
action = request.args.get('action')
api_ns_subtitles = Namespace('Subtitles', description='Apply mods/tools on external subtitles')
language = request.form.get('language')
subtitles_path = request.form.get('path')
media_type = request.form.get('type')
id = request.form.get('id')
@api_ns_subtitles.route('subtitles')
class Subtitles(Resource):
patch_request_parser = reqparse.RequestParser()
patch_request_parser.add_argument('action', type=str, required=True,
help='Action from ["sync", "translate" or mods name]')
patch_request_parser.add_argument('language', type=str, required=True, help='Language code2')
patch_request_parser.add_argument('path', type=str, required=True, help='Subtitles file path')
patch_request_parser.add_argument('type', type=str, required=True, help='Media type from ["episode", "movie"]')
patch_request_parser.add_argument('id', type=int, required=True, help='Episode ID')
patch_request_parser.add_argument('forced', type=str, required=False, help='Forced subtitles from ["True", "False"]')
patch_request_parser.add_argument('hi', type=str, required=False, help='HI subtitles from ["True", "False"]')
patch_request_parser.add_argument('original_format', type=str, required=False,
help='Use original subtitles format from ["True", "False"]')
@authenticate
@api_ns_subtitles.doc(parser=patch_request_parser)
@api_ns_subtitles.response(204, 'Success')
@api_ns_subtitles.response(401, 'Not Authenticated')
@api_ns_subtitles.response(404, 'Episode/movie not found')
def patch(self):
"""Apply mods/tools on external subtitles"""
args = self.patch_request_parser.parse_args()
action = args.get('action')
language = args.get('language')
subtitles_path = args.get('path')
media_type = args.get('type')
id = args.get('id')
if media_type == 'episode':
metadata = TableEpisodes.select(TableEpisodes.path, TableEpisodes.sonarrSeriesId)\
@ -62,8 +83,8 @@ class Subtitles(Resource):
elif action == 'translate':
from_language = os.path.splitext(subtitles_path)[0].rsplit(".", 1)[1].replace('_', '-')
dest_language = language
forced = True if request.form.get('forced') == 'true' else False
hi = True if request.form.get('hi') == 'true' else False
forced = True if args.get('forced') == 'true' else False
hi = True if args.get('hi') == 'true' else False
translate_subtitles_file(video_path=video_path, source_srt_file=subtitles_path,
from_lang=from_language, to_lang=dest_language, forced=forced, hi=hi,
media_type="series" if media_type == "episode" else "movies",
@ -71,7 +92,7 @@ class Subtitles(Resource):
sonarr_episode_id=int(id),
radarr_id=id)
else:
use_original_format = True if request.form.get('original_format') == 'true' else False
use_original_format = True if args.get('original_format') == 'true' else False
subtitles_apply_mods(language, subtitles_path, [action], use_original_format)
# apply chmod if required

View File

@ -1,16 +1,37 @@
# coding=utf-8
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from subliminal_patch.core import guessit
from ..utils import authenticate
api_ns_subtitles_info = Namespace('Subtitles Info', description='Guess season number, episode number or language from '
'uploaded subtitles filename')
@api_ns_subtitles_info.route('subtitles/info')
class SubtitleNameInfo(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('filenames[]', type=str, required=True, action='append',
help='Subtitles filenames')
get_response_model = api_ns_subtitles_info.model('SubtitlesInfoGetResponse', {
'filename': fields.String(),
'subtitle_language': fields.String(),
'season': fields.Integer(),
'episode': fields.Integer(),
})
@authenticate
@api_ns_subtitles_info.marshal_with(get_response_model, envelope='data', code=200)
@api_ns_subtitles_info.response(200, 'Success')
@api_ns_subtitles_info.response(401, 'Not Authenticated')
@api_ns_subtitles_info.doc(parser=get_request_parser)
def get(self):
names = request.args.getlist('filenames[]')
"""Guessit over subtitles filename"""
args = self.get_request_parser.parse_args()
names = args.get('filenames[]')
results = []
for name in names:
opts = dict()
@ -27,16 +48,16 @@ class SubtitleNameInfo(Resource):
# for multiple episodes file, choose the first episode number
if len(guessit_result['episode']):
# make sure that guessit returned a list of more than 0 items
result['episode'] = int(guessit_result['episode'][0])
elif isinstance(guessit_result['episode'], (str, int)):
# if single episode (should be int but just in case we cast it to int)
result['episode'] = int(guessit_result['episode'])
result['episode'] = guessit_result['episode'][0]
elif isinstance(guessit_result['episode'], int):
# if single episode
result['episode'] = guessit_result['episode']
if 'season' in guessit_result:
result['season'] = int(guessit_result['season'])
result['season'] = guessit_result['season']
else:
result['season'] = 0
results.append(result)
return jsonify(data=results)
return results

33
bazarr/api/swaggerui.py Normal file
View File

@ -0,0 +1,33 @@
# coding=utf-8
import os
from flask_restx import fields
swaggerui_api_params = {"version": os.environ["BAZARR_VERSION"],
"description": "API docs for Bazarr",
"title": "Bazarr",
}
subtitles_model = {
"name": fields.String(),
"code2": fields.String(),
"code3": fields.String(),
"path": fields.String(),
"forced": fields.Boolean(),
"hi": fields.Boolean()
}
subtitles_language_model = {
"name": fields.String(),
"code2": fields.String(),
"code3": fields.String(),
"forced": fields.Boolean(),
"hi": fields.Boolean()
}
audio_language_model = {
"name": fields.String(),
"code2": fields.String(),
"code3": fields.String()
}

View File

@ -1,35 +1,31 @@
# coding=utf-8
from flask import Blueprint
from flask_restful import Api
from .system import api_ns_system
from .searches import api_ns_system_searches
from .account import api_ns_system_account
from .backups import api_ns_system_backups
from .tasks import api_ns_system_tasks
from .logs import api_ns_system_logs
from .status import api_ns_system_status
from .health import api_ns_system_health
from .releases import api_ns_system_releases
from .settings import api_ns_system_settings
from .languages import api_ns_system_languages
from .languages_profiles import api_ns_system_languages_profiles
from .notifications import api_ns_system_notifications
from .system import System
from .searches import Searches
from .account import SystemAccount
from .backups import SystemBackups
from .tasks import SystemTasks
from .logs import SystemLogs
from .status import SystemStatus
from .health import SystemHealth
from .releases import SystemReleases
from .settings import SystemSettings
from .languages import Languages
from .languages_profiles import LanguagesProfiles
from .notifications import Notifications
api_bp_system = Blueprint('api_system', __name__)
api = Api(api_bp_system)
api.add_resource(System, '/system')
api.add_resource(Searches, '/system/searches')
api.add_resource(SystemAccount, '/system/account')
api.add_resource(SystemBackups, '/system/backups')
api.add_resource(SystemTasks, '/system/tasks')
api.add_resource(SystemLogs, '/system/logs')
api.add_resource(SystemStatus, '/system/status')
api.add_resource(SystemHealth, '/system/health')
api.add_resource(SystemReleases, '/system/releases')
api.add_resource(SystemSettings, '/system/settings')
api.add_resource(Languages, '/system/languages')
api.add_resource(LanguagesProfiles, '/system/languages/profiles')
api.add_resource(Notifications, '/system/notifications')
api_ns_list_system = [
api_ns_system,
api_ns_system_account,
api_ns_system_backups,
api_ns_system_health,
api_ns_system_languages,
api_ns_system_languages_profiles,
api_ns_system_logs,
api_ns_system_notifications,
api_ns_system_releases,
api_ns_system_searches,
api_ns_system_settings,
api_ns_system_status,
api_ns_system_tasks,
]

View File

@ -2,22 +2,37 @@
import gc
from flask import request, session
from flask_restful import Resource
from flask import session
from flask_restx import Resource, Namespace, reqparse
from app.config import settings
from utilities.helper import check_credentials
api_ns_system_account = Namespace('System Account', description='Login or logout from Bazarr UI')
@api_ns_system_account.hide
@api_ns_system_account.route('system/account')
class SystemAccount(Resource):
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('action', type=str, required=True, help='Action from ["login", "logout"]')
post_request_parser.add_argument('username', type=str, required=False, help='Bazarr username')
post_request_parser.add_argument('password', type=str, required=False, help='Bazarr password')
@api_ns_system_account.doc(parser=post_request_parser)
@api_ns_system_account.response(204, 'Success')
@api_ns_system_account.response(400, 'Unknown action')
@api_ns_system_account.response(404, 'Unknown authentication type define in config.ini')
def post(self):
"""Login or logout from Bazarr UI when using form login"""
args = self.patch_request_parser.parse_args()
if settings.auth.type != 'form':
return 'Unknown authentication type define in config.ini', 404
action = request.args.get('action')
action = args.get('action')
if action == 'login':
username = request.form.get('username')
password = request.form.get('password')
username = args.get('username')
password = args.get('password')
if check_credentials(username, password):
session['logged_in'] = True
return '', 204

View File

@ -1,36 +1,72 @@
# coding=utf-8
from flask import jsonify, request
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from utilities.backup import get_backup_files, prepare_restore, delete_backup_file, backup_to_zip
from ..utils import authenticate
api_ns_system_backups = Namespace('System Backups', description='List, create, restore or delete backups')
@api_ns_system_backups.route('system/backups')
class SystemBackups(Resource):
@authenticate
def get(self):
backups = get_backup_files(fullpath=False)
return jsonify(data=backups)
get_response_model = api_ns_system_backups.model('SystemBackupsGetResponse', {
'date': fields.String(),
'filename': fields.String(),
'size': fields.String(),
'type': fields.String(),
})
@authenticate
@api_ns_system_backups.marshal_with(get_response_model, envelope='data', code=200)
@api_ns_system_backups.doc(parser=None)
@api_ns_system_backups.response(204, 'Success')
@api_ns_system_backups.response(401, 'Not Authenticated')
def get(self):
"""List backup files"""
backups = get_backup_files(fullpath=False)
return backups
@authenticate
@api_ns_system_backups.doc(parser=None)
@api_ns_system_backups.response(204, 'Success')
@api_ns_system_backups.response(401, 'Not Authenticated')
def post(self):
"""Create a new backup"""
backup_to_zip()
return '', 204
patch_request_parser = reqparse.RequestParser()
patch_request_parser.add_argument('filename', type=str, required=True, help='Backups to restore filename')
@authenticate
@api_ns_system_backups.doc(parser=patch_request_parser)
@api_ns_system_backups.response(204, 'Success')
@api_ns_system_backups.response(400, 'Filename not provided')
@api_ns_system_backups.response(401, 'Not Authenticated')
def patch(self):
filename = request.form.get('filename')
"""Restore a backup file"""
args = self.patch_request_parser.parse_args()
filename = args.get('filename')
if filename:
restored = prepare_restore(filename)
if restored:
return '', 204
return 'Filename not provided', 400
delete_request_parser = reqparse.RequestParser()
delete_request_parser.add_argument('filename', type=str, required=True, help='Backups to delete filename')
@authenticate
@api_ns_system_backups.doc(parser=delete_request_parser)
@api_ns_system_backups.response(204, 'Success')
@api_ns_system_backups.response(400, 'Filename not provided')
@api_ns_system_backups.response(401, 'Not Authenticated')
def delete(self):
filename = request.form.get('filename')
"""Delete a backup file"""
args = self.delete_request_parser.parse_args()
filename = args.get('filename')
if filename:
deleted = delete_backup_file(filename)
if deleted:

View File

@ -1,14 +1,20 @@
# coding=utf-8
from flask import jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse
from utilities.health import get_health_issues
from ..utils import authenticate
api_ns_system_health = Namespace('System Health', description='List health issues')
@api_ns_system_health.route('system/health')
class SystemHealth(Resource):
@authenticate
@api_ns_system_health.doc(parser=None)
@api_ns_system_health.response(200, 'Success')
@api_ns_system_health.response(401, 'Not Authenticated')
def get(self):
return jsonify(data=get_health_issues())
"""List health issues"""
return {'data': get_health_issues()}

View File

@ -1,7 +1,6 @@
# coding=utf-8
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from operator import itemgetter
from app.database import TableHistory, TableHistoryMovie, TableSettingsLanguages
@ -9,11 +8,22 @@ from languages.get_languages import alpha2_from_alpha3, language_from_alpha2
from ..utils import authenticate, False_Keys
api_ns_system_languages = Namespace('System Languages', description='Get languages list')
@api_ns_system_languages.route('system/languages')
class Languages(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('history', type=str, required=False, help='Language name for history stats')
@authenticate
@api_ns_system_languages.doc(parser=get_request_parser)
@api_ns_system_languages.response(200, 'Success')
@api_ns_system_languages.response(401, 'Not Authenticated')
def get(self):
history = request.args.get('history')
"""List languages for history filter or for language filter menu"""
args = self.get_request_parser.parse_args()
history = args.get('history')
if history and history not in False_Keys:
languages = list(TableHistory.select(TableHistory.language)
.where(TableHistory.language.is_null(False))
@ -42,13 +52,13 @@ class Languages(Resource):
})
except Exception:
continue
return jsonify(sorted(languages_dicts, key=itemgetter('name')))
else:
languages_dicts = TableSettingsLanguages.select(TableSettingsLanguages.name,
TableSettingsLanguages.code2,
TableSettingsLanguages.enabled)\
.order_by(TableSettingsLanguages.name).dicts()
languages_dicts = list(languages_dicts)
for item in languages_dicts:
item['enabled'] = item['enabled'] == 1
result = TableSettingsLanguages.select(TableSettingsLanguages.name,
TableSettingsLanguages.code2,
TableSettingsLanguages.enabled)\
.order_by(TableSettingsLanguages.name).dicts()
result = list(result)
for item in result:
item['enabled'] = item['enabled'] == 1
return jsonify(result)
return sorted(languages_dicts, key=itemgetter('name'))

View File

@ -1,14 +1,20 @@
# coding=utf-8
from flask import jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse
from app.database import get_profiles_list
from ..utils import authenticate
api_ns_system_languages_profiles = Namespace('System Languages Profiles', description='List languages profiles')
@api_ns_system_languages_profiles.route('system/languages/profiles')
class LanguagesProfiles(Resource):
@authenticate
@api_ns_system_languages_profiles.doc(parser=None)
@api_ns_system_languages_profiles.response(200, 'Success')
@api_ns_system_languages_profiles.response(401, 'Not Authenticated')
def get(self):
return jsonify(get_profiles_list())
"""List languages profiles"""
return get_profiles_list()

View File

@ -3,18 +3,32 @@
import io
import os
from flask import jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, fields
from app.logger import empty_log
from app.get_args import args
from ..utils import authenticate
api_ns_system_logs = Namespace('System Logs', description='List log file entries or empty log file')
@api_ns_system_logs.route('system/logs')
class SystemLogs(Resource):
get_response_model = api_ns_system_logs.model('SystemBackupsGetResponse', {
'timestamp': fields.String(),
'type': fields.String(),
'message': fields.String(),
'exception': fields.String(),
})
@authenticate
@api_ns_system_logs.marshal_with(get_response_model, envelope='data', code=200)
@api_ns_system_logs.doc(parser=None)
@api_ns_system_logs.response(200, 'Success')
@api_ns_system_logs.response(401, 'Not Authenticated')
def get(self):
"""List log entries"""
logs = []
with io.open(os.path.join(args.config_dir, 'log', 'bazarr.log'), encoding='UTF-8') as file:
raw_lines = file.read()
@ -31,12 +45,18 @@ class SystemLogs(Resource):
log["message"] = raw_message[3]
if raw_message_len > 4 and raw_message[4] != '\n':
log['exception'] = raw_message[4].strip('\'').replace(' ', '\u2003\u2003')
else:
log['exception'] = None
logs.append(log)
logs.reverse()
return jsonify(data=logs)
return logs
@authenticate
@api_ns_system_logs.doc(parser=None)
@api_ns_system_logs.response(204, 'Success')
@api_ns_system_logs.response(401, 'Not Authenticated')
def delete(self):
"""Force log rotation and create a new log file"""
empty_log()
return '', 204

View File

@ -2,16 +2,27 @@
import apprise
from flask import request
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse
from ..utils import authenticate
api_ns_system_notifications = Namespace('System Notifications', description='Send test notifications provider message')
@api_ns_system_notifications.hide
@api_ns_system_notifications.route('system/notifications')
class Notifications(Resource):
patch_request_parser = reqparse.RequestParser()
patch_request_parser.add_argument('url', type=str, required=True, help='Notifications provider URL')
@authenticate
@api_ns_system_notifications.doc(parser=patch_request_parser)
@api_ns_system_notifications.response(204, 'Success')
@api_ns_system_notifications.response(401, 'Not Authenticated')
def patch(self):
url = request.form.get("url")
"""Test a notifications provider URL"""
args = self.patch_request_parser.parse_args()
url = args.get("url")
asset = apprise.AppriseAsset(async_mode=False)

View File

@ -5,18 +5,33 @@ import json
import os
import logging
from flask import jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, fields
from app.config import settings
from app.get_args import args
from ..utils import authenticate
api_ns_system_releases = Namespace('System Releases', description='List Bazarr releases from Github')
@api_ns_system_releases.route('system/releases')
class SystemReleases(Resource):
get_response_model = api_ns_system_releases.model('SystemBackupsGetResponse', {
'body': fields.List(fields.String),
'name': fields.String(),
'date': fields.String(),
'prerelease': fields.Boolean(),
'current': fields.Boolean(),
})
@authenticate
@api_ns_system_releases.marshal_with(get_response_model, envelope='data', code=200)
@api_ns_system_releases.doc(parser=None)
@api_ns_system_releases.response(200, 'Success')
@api_ns_system_releases.response(401, 'Not Authenticated')
def get(self):
"""Get Bazarr releases"""
filtered_releases = []
try:
with io.open(os.path.join(args.config_dir, 'config', 'releases.txt'), 'r', encoding='UTF-8') as f:
@ -45,4 +60,4 @@ class SystemReleases(Resource):
except Exception:
logging.exception(
'BAZARR cannot parse releases caching file: ' + os.path.join(args.config_dir, 'config', 'releases.txt'))
return jsonify(data=filtered_releases)
return filtered_releases

View File

@ -1,18 +1,28 @@
# coding=utf-8
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse
from app.config import settings
from app.database import TableShows, TableMovies
from ..utils import authenticate
api_ns_system_searches = Namespace('System Searches', description='Search for series or movies by name')
@api_ns_system_searches.route('system/searches')
class Searches(Resource):
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('query', type=str, required=True, help='Series or movie name to search for')
@authenticate
@api_ns_system_searches.doc(parser=get_request_parser)
@api_ns_system_searches.response(200, 'Success')
@api_ns_system_searches.response(401, 'Not Authenticated')
def get(self):
query = request.args.get('query')
"""List results from query"""
args = self.get_request_parser.parse_args()
query = args.get('query')
search_list = []
if query:
@ -38,4 +48,4 @@ class Searches(Resource):
movies = list(movies)
search_list += movies
return jsonify(search_list)
return search_list

View File

@ -3,7 +3,7 @@
import json
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace
from app.database import TableLanguagesProfiles, TableSettingsLanguages, TableShows, TableMovies, \
TableSettingsNotifier
@ -15,7 +15,11 @@ from subtitles.indexer.movies import list_missing_subtitles_movies
from ..utils import authenticate
api_ns_system_settings = Namespace('systemSettings', description='System settings API endpoint')
@api_ns_system_settings.hide
@api_ns_system_settings.route('system/settings')
class SystemSettings(Resource):
@authenticate
def get(self):

View File

@ -4,8 +4,7 @@ import os
import platform
import logging
from flask import jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace
from tzlocal import get_localzone_name
from radarr.info import get_radarr_info
@ -15,10 +14,14 @@ from init import startTime
from ..utils import authenticate
api_ns_system_status = Namespace('System Status', description='List environment information and versions')
@api_ns_system_status.route('system/status')
class SystemStatus(Resource):
@authenticate
def get(self):
"""Return environment information and versions"""
package_version = ''
if 'BAZARR_PACKAGE_VERSION' in os.environ:
package_version = os.environ['BAZARR_PACKAGE_VERSION']
@ -44,4 +47,4 @@ class SystemStatus(Resource):
system_status.update({'start_time': startTime})
system_status.update({'timezone': timezone})
return jsonify(data=system_status)
return {'data': system_status}

View File

@ -1,16 +1,28 @@
# coding=utf-8
from flask import request
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse
from ..utils import authenticate
api_ns_system = Namespace('System', description='Shutdown or restart Bazarr')
@api_ns_system.hide
@api_ns_system.route('system')
class System(Resource):
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('action', type=str, required=True,
help='Action to perform from ["shutdown", "restart"]')
@authenticate
@api_ns_system.doc(parser=post_request_parser)
@api_ns_system.response(204, 'Success')
@api_ns_system.response(401, 'Not Authenticated')
def post(self):
"""Shutdown or restart Bazarr"""
args = self.post_request_parser.parse_args()
from app.server import webserver
action = request.args.get('action')
action = args.get('action')
if action == "shutdown":
webserver.shutdown()
elif action == "restart":

View File

@ -1,17 +1,37 @@
# coding=utf-8
from flask import request, jsonify
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse, fields
from app.scheduler import scheduler
from ..utils import authenticate
api_ns_system_tasks = Namespace('System Tasks', description='List or execute tasks')
@api_ns_system_tasks.route('system/tasks')
class SystemTasks(Resource):
get_response_model = api_ns_system_tasks.model('SystemBackupsGetResponse', {
'interval': fields.String(),
'job_id': fields.String(),
'job_running': fields.Boolean(),
'name': fields.String(),
'next_run_in': fields.String(),
'next_run_time': fields.String(),
})
get_request_parser = reqparse.RequestParser()
get_request_parser.add_argument('taskid', type=str, required=False, help='List tasks or a single task properties')
@authenticate
@api_ns_system_tasks.marshal_with(get_response_model, envelope='data', code=200)
@api_ns_system_tasks.doc(parser=None)
@api_ns_system_tasks.response(200, 'Success')
@api_ns_system_tasks.response(401, 'Not Authenticated')
def get(self):
taskid = request.args.get('taskid')
"""List tasks"""
args = self.get_request_parser.parse_args()
taskid = args.get('taskid')
task_list = scheduler.get_task_list()
@ -21,11 +41,19 @@ class SystemTasks(Resource):
task_list = [item]
continue
return jsonify(data=task_list)
return task_list
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('taskid', type=str, required=True, help='Task id of the task to run')
@authenticate
@api_ns_system_tasks.doc(parser=post_request_parser)
@api_ns_system_tasks.response(204, 'Success')
@api_ns_system_tasks.response(401, 'Not Authenticated')
def post(self):
taskid = request.form.get('taskid')
"""Run task"""
args = self.post_request_parser.parse_args()
taskid = args.get('taskid')
scheduler.execute_job_now(taskid)

View File

@ -1,16 +1,12 @@
# coding=utf-8
from flask import Blueprint
from flask_restful import Api
from .plex import WebHooksPlex
from .sonarr import WebHooksSonarr
from .radarr import WebHooksRadarr
from .plex import api_ns_webhooks_plex
from .sonarr import api_ns_webhooks_sonarr
from .radarr import api_ns_webhooks_radarr
api_bp_webhooks = Blueprint('api_webhooks', __name__)
api = Api(api_bp_webhooks)
api.add_resource(WebHooksPlex, '/webhooks/plex')
api.add_resource(WebHooksSonarr, '/webhooks/sonarr')
api.add_resource(WebHooksRadarr, '/webhooks/radarr')
api_ns_list_webhooks = [
api_ns_webhooks_plex,
api_ns_webhooks_radarr,
api_ns_webhooks_sonarr,
]

View File

@ -5,8 +5,7 @@ import requests
import os
import logging
from flask import request
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse
from bs4 import BeautifulSoup as bso
from app.database import TableEpisodes, TableShows, TableMovies
@ -15,15 +14,31 @@ from subtitles.mass_download import episode_download_subtitles, movies_download_
from ..utils import authenticate
api_ns_webhooks_plex = Namespace('Webhooks Plex', description='Webhooks endpoint that can be configured in Plex to '
'trigger a subtitles search when playback start.')
@api_ns_webhooks_plex.route('webhooks/plex')
class WebHooksPlex(Resource):
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('payload', type=str, required=True, help='Webhook payload')
@authenticate
@api_ns_webhooks_plex.doc(parser=post_request_parser)
@api_ns_webhooks_plex.response(200, 'Success')
@api_ns_webhooks_plex.response(204, 'Unhandled event')
@api_ns_webhooks_plex.response(400, 'No GUID found')
@api_ns_webhooks_plex.response(401, 'Not Authenticated')
@api_ns_webhooks_plex.response(404, 'IMDB series/movie ID not found')
def post(self):
json_webhook = request.form.get('payload')
"""Trigger subtitles search on play media event in Plex"""
args = self.post_request_parser.parse_args()
json_webhook = args.get('payload')
parsed_json_webhook = json.loads(json_webhook)
event = parsed_json_webhook['event']
if event not in ['media.play']:
return '', 204
return 'Unhandled event', 204
media_type = parsed_json_webhook['Metadata']['type']

View File

@ -1,7 +1,6 @@
# coding=utf-8
from flask import request
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse
from app.database import TableMovies
from subtitles.mass_download import movies_download_subtitles
@ -11,10 +10,23 @@ from utilities.path_mappings import path_mappings
from ..utils import authenticate
api_ns_webhooks_radarr = Namespace('Webhooks Radarr', description='Webhooks to trigger subtitles search based on '
'Radarr movie file ID')
@api_ns_webhooks_radarr.route('webhooks/radarr')
class WebHooksRadarr(Resource):
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('radarr_moviefile_id', type=int, required=True, help='Movie file ID')
@authenticate
@api_ns_webhooks_radarr.doc(parser=post_request_parser)
@api_ns_webhooks_radarr.response(200, 'Success')
@api_ns_webhooks_radarr.response(401, 'Not Authenticated')
def post(self):
movie_file_id = request.form.get('radarr_moviefile_id')
"""Search for missing subtitles for a specific movie file id"""
args = self.post_request_parser.parse_args()
movie_file_id = args.get('radarr_moviefile_id')
radarrMovieId = TableMovies.select(TableMovies.radarrId,
TableMovies.path) \

View File

@ -1,7 +1,6 @@
# coding=utf-8
from flask import request
from flask_restful import Resource
from flask_restx import Resource, Namespace, reqparse
from app.database import TableEpisodes, TableShows
from subtitles.mass_download import episode_download_subtitles
@ -11,10 +10,23 @@ from utilities.path_mappings import path_mappings
from ..utils import authenticate
api_ns_webhooks_sonarr = Namespace('Webhooks Sonarr', description='Webhooks to trigger subtitles search based on '
'Sonarr episode file ID')
@api_ns_webhooks_sonarr.route('webhooks/sonarr')
class WebHooksSonarr(Resource):
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('sonarr_episodefile_id', type=int, required=True, help='Episode file ID')
@authenticate
@api_ns_webhooks_sonarr.doc(parser=post_request_parser)
@api_ns_webhooks_sonarr.response(200, 'Success')
@api_ns_webhooks_sonarr.response(401, 'Not Authenticated')
def post(self):
episode_file_id = request.form.get('sonarr_episodefile_id')
"""Search for missing subtitles for a specific episode file id"""
args = self.post_request_parser.parse_args()
episode_file_id = args.get('sonarr_episodefile_id')
sonarrEpisodeId = TableEpisodes.select(TableEpisodes.sonarrEpisodeId,
TableEpisodes.path) \

View File

@ -20,6 +20,8 @@ def create_app():
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = True
app.config['JSON_AS_ASCII'] = False
app.config['RESTX_MASK_SWAGGER'] = False
if settings.get('cors', 'enabled'):
CORS(app)

View File

@ -9,7 +9,7 @@ import errno
from waitress.server import create_server
from time import sleep
from api import api_bp_list
from api import api_bp
from .ui import ui_bp
from .get_args import args
from .config import settings, base_url
@ -17,10 +17,7 @@ from .database import database
from .app import create_app
app = create_app()
for item in api_bp_list:
ui_bp.register_blueprint(item, url_prefix='/api')
ui_bp.register_blueprint(api_bp, url_prefix='/api')
app.register_blueprint(ui_bp, url_prefix=base_url.rstrip('/'))

View File

@ -137,6 +137,12 @@ def backup_download(filename):
return send_file(os.path.join(settings.backup.folder, filename), cache_timeout=0, as_attachment=True)
@ui_bp.route('/api/swaggerui/static/<path:filename>', methods=['GET'])
def swaggerui_static(filename):
return send_file(os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'libs', 'flask_restx',
'static', filename))
def configured():
System.update({System.configured: '1'}).execute()

View File

@ -18,6 +18,10 @@ from radarr.notify import notify_radarr
def delete_subtitles(media_type, language, forced, hi, media_path, subtitles_path, sonarr_series_id=None,
sonarr_episode_id=None, radarr_id=None):
if not subtitles_path:
logging.error('No subtitles to delete.')
return False
if not os.path.splitext(subtitles_path)[1] in SUBTITLE_EXTENSIONS:
logging.error('BAZARR can only delete subtitles files.')
return False

View File

@ -10,7 +10,6 @@ import {
faQuestion,
} from "@fortawesome/free-solid-svg-icons";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import { isUndefined } from "lodash";
import { FunctionComponent, useMemo } from "react";
import { Column } from "react-table";
import SystemLogModal from "./modal";
@ -55,7 +54,7 @@ const Table: FunctionComponent<Props> = ({ logs }) => {
accessor: "exception",
Cell: ({ value }) => {
const modals = useModals();
if (!isUndefined(value)) {
if (value) {
return (
<Action
label="Detail"

View File

@ -1,14 +1,14 @@
import { useSystemHealth, useSystemStatus } from "@/apis/hooks";
import { QueryOverlay } from "@/components/async";
import { GithubRepoRoot } from "@/constants";
import { useInterval } from "@/utilities";
import { Environment, useInterval } from "@/utilities";
import { IconDefinition } from "@fortawesome/fontawesome-common-types";
import {
faDiscord,
faGithub,
faWikipediaW,
} from "@fortawesome/free-brands-svg-icons";
import { faPaperPlane } from "@fortawesome/free-solid-svg-icons";
import { faCode, faPaperPlane } from "@fortawesome/free-solid-svg-icons";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import { Anchor, Container, Divider, Grid, Stack, Text } from "@mantine/core";
import { useDocumentTitle } from "@mantine/hooks";
@ -131,6 +131,11 @@ const SystemStatusView: FunctionComponent = () => {
Bazarr Wiki
</Label>
</Row>
<Row title="API documentation">
<Label icon={faCode} link={`${Environment.baseUrl}/api/`}>
Swagger UI
</Label>
</Row>
<Row title="Discord">
<Label icon={faDiscord} link="https://discord.gg/MH2e2eb">
Bazarr on Discord

View File

@ -1,716 +0,0 @@
from __future__ import absolute_import
from functools import wraps, partial
from flask import request, url_for, current_app
from flask import abort as original_flask_abort
from flask import make_response as original_flask_make_response
from flask.views import MethodView
from flask.signals import got_request_exception
from werkzeug.datastructures import Headers
from werkzeug.exceptions import HTTPException, MethodNotAllowed, NotFound, NotAcceptable, InternalServerError
from werkzeug.wrappers import Response as ResponseBase
from flask_restful.utils import http_status_message, unpack, OrderedDict
from flask_restful.representations.json import output_json
import sys
from types import MethodType
import operator
try:
from collections.abc import Mapping
except ImportError:
from collections import Mapping
__all__ = ('Api', 'Resource', 'marshal', 'marshal_with', 'marshal_with_field', 'abort')
def abort(http_status_code, **kwargs):
"""Raise a HTTPException for the given http_status_code. Attach any keyword
arguments to the exception for later processing.
"""
#noinspection PyUnresolvedReferences
try:
original_flask_abort(http_status_code)
except HTTPException as e:
if len(kwargs):
e.data = kwargs
raise
DEFAULT_REPRESENTATIONS = [('application/json', output_json)]
class Api(object):
"""
The main entry point for the application.
You need to initialize it with a Flask Application: ::
>>> app = Flask(__name__)
>>> api = restful.Api(app)
Alternatively, you can use :meth:`init_app` to set the Flask application
after it has been constructed.
:param app: the Flask application object
:type app: flask.Flask or flask.Blueprint
:param prefix: Prefix all routes with a value, eg v1 or 2010-04-01
:type prefix: str
:param default_mediatype: The default media type to return
:type default_mediatype: str
:param decorators: Decorators to attach to every resource
:type decorators: list
:param catch_all_404s: Use :meth:`handle_error`
to handle 404 errors throughout your app
:param serve_challenge_on_401: Whether to serve a challenge response to
clients on receiving 401. This usually leads to a username/password
popup in web browsers.
:param url_part_order: A string that controls the order that the pieces
of the url are concatenated when the full url is constructed. 'b'
is the blueprint (or blueprint registration) prefix, 'a' is the api
prefix, and 'e' is the path component the endpoint is added with
:type catch_all_404s: bool
:param errors: A dictionary to define a custom response for each
exception or error raised during a request
:type errors: dict
"""
def __init__(self, app=None, prefix='',
default_mediatype='application/json', decorators=None,
catch_all_404s=False, serve_challenge_on_401=False,
url_part_order='bae', errors=None):
self.representations = OrderedDict(DEFAULT_REPRESENTATIONS)
self.urls = {}
self.prefix = prefix
self.default_mediatype = default_mediatype
self.decorators = decorators if decorators else []
self.catch_all_404s = catch_all_404s
self.serve_challenge_on_401 = serve_challenge_on_401
self.url_part_order = url_part_order
self.errors = errors or {}
self.blueprint_setup = None
self.endpoints = set()
self.resources = []
self.app = None
self.blueprint = None
if app is not None:
self.app = app
self.init_app(app)
def init_app(self, app):
"""Initialize this class with the given :class:`flask.Flask`
application or :class:`flask.Blueprint` object.
:param app: the Flask application or blueprint object
:type app: flask.Flask
:type app: flask.Blueprint
Examples::
api = Api()
api.add_resource(...)
api.init_app(app)
"""
# If app is a blueprint, defer the initialization
try:
app.record(self._deferred_blueprint_init)
# Flask.Blueprint has a 'record' attribute, Flask.Api does not
except AttributeError:
self._init_app(app)
else:
self.blueprint = app
def _complete_url(self, url_part, registration_prefix):
"""This method is used to defer the construction of the final url in
the case that the Api is created with a Blueprint.
:param url_part: The part of the url the endpoint is registered with
:param registration_prefix: The part of the url contributed by the
blueprint. Generally speaking, BlueprintSetupState.url_prefix
"""
parts = {
'b': registration_prefix,
'a': self.prefix,
'e': url_part
}
return ''.join(parts[key] for key in self.url_part_order if parts[key])
@staticmethod
def _blueprint_setup_add_url_rule_patch(blueprint_setup, rule, endpoint=None, view_func=None, **options):
"""Method used to patch BlueprintSetupState.add_url_rule for setup
state instance corresponding to this Api instance. Exists primarily
to enable _complete_url's function.
:param blueprint_setup: The BlueprintSetupState instance (self)
:param rule: A string or callable that takes a string and returns a
string(_complete_url) that is the url rule for the endpoint
being registered
:param endpoint: See BlueprintSetupState.add_url_rule
:param view_func: See BlueprintSetupState.add_url_rule
:param **options: See BlueprintSetupState.add_url_rule
"""
if callable(rule):
rule = rule(blueprint_setup.url_prefix)
elif blueprint_setup.url_prefix:
rule = blueprint_setup.url_prefix + rule
options.setdefault('subdomain', blueprint_setup.subdomain)
if endpoint is None:
endpoint = view_func.__name__
defaults = blueprint_setup.url_defaults
if 'defaults' in options:
defaults = dict(defaults, **options.pop('defaults'))
blueprint_setup.app.add_url_rule(rule, '%s.%s' % (blueprint_setup.blueprint.name, endpoint),
view_func, defaults=defaults, **options)
def _deferred_blueprint_init(self, setup_state):
"""Synchronize prefix between blueprint/api and registration options, then
perform initialization with setup_state.app :class:`flask.Flask` object.
When a :class:`flask_restful.Api` object is initialized with a blueprint,
this method is recorded on the blueprint to be run when the blueprint is later
registered to a :class:`flask.Flask` object. This method also monkeypatches
BlueprintSetupState.add_url_rule with _blueprint_setup_add_url_rule_patch.
:param setup_state: The setup state object passed to deferred functions
during blueprint registration
:type setup_state: flask.blueprints.BlueprintSetupState
"""
self.blueprint_setup = setup_state
if setup_state.add_url_rule.__name__ != '_blueprint_setup_add_url_rule_patch':
setup_state._original_add_url_rule = setup_state.add_url_rule
setup_state.add_url_rule = MethodType(Api._blueprint_setup_add_url_rule_patch,
setup_state)
if not setup_state.first_registration:
raise ValueError('flask-restful blueprints can only be registered once.')
self._init_app(setup_state.app)
def _init_app(self, app):
"""Perform initialization actions with the given :class:`flask.Flask`
object.
:param app: The flask application object
:type app: flask.Flask
"""
app.handle_exception = partial(self.error_router, app.handle_exception)
app.handle_user_exception = partial(self.error_router, app.handle_user_exception)
if len(self.resources) > 0:
for resource, urls, kwargs in self.resources:
self._register_view(app, resource, *urls, **kwargs)
def owns_endpoint(self, endpoint):
"""Tests if an endpoint name (not path) belongs to this Api. Takes
in to account the Blueprint name part of the endpoint name.
:param endpoint: The name of the endpoint being checked
:return: bool
"""
if self.blueprint:
if endpoint.startswith(self.blueprint.name):
endpoint = endpoint.split(self.blueprint.name + '.', 1)[-1]
else:
return False
return endpoint in self.endpoints
def _should_use_fr_error_handler(self):
""" Determine if error should be handled with FR or default Flask
The goal is to return Flask error handlers for non-FR-related routes,
and FR errors (with the correct media type) for FR endpoints. This
method currently handles 404 and 405 errors.
:return: bool
"""
adapter = current_app.create_url_adapter(request)
try:
adapter.match()
except MethodNotAllowed as e:
# Check if the other HTTP methods at this url would hit the Api
valid_route_method = e.valid_methods[0]
rule, _ = adapter.match(method=valid_route_method, return_rule=True)
return self.owns_endpoint(rule.endpoint)
except NotFound:
return self.catch_all_404s
except:
# Werkzeug throws other kinds of exceptions, such as Redirect
pass
def _has_fr_route(self):
"""Encapsulating the rules for whether the request was to a Flask endpoint"""
# 404's, 405's, which might not have a url_rule
if self._should_use_fr_error_handler():
return True
# for all other errors, just check if FR dispatched the route
if not request.url_rule:
return False
return self.owns_endpoint(request.url_rule.endpoint)
def error_router(self, original_handler, e):
"""This function decides whether the error occured in a flask-restful
endpoint or not. If it happened in a flask-restful endpoint, our
handler will be dispatched. If it happened in an unrelated view, the
app's original error handler will be dispatched.
In the event that the error occurred in a flask-restful endpoint but
the local handler can't resolve the situation, the router will fall
back onto the original_handler as last resort.
:param original_handler: the original Flask error handler for the app
:type original_handler: function
:param e: the exception raised while handling the request
:type e: Exception
"""
if self._has_fr_route():
try:
return self.handle_error(e)
except Exception:
pass # Fall through to original handler
return original_handler(e)
def handle_error(self, e):
"""Error handler for the API transforms a raised exception into a Flask
response, with the appropriate HTTP status code and body.
:param e: the raised Exception object
:type e: Exception
"""
got_request_exception.send(current_app._get_current_object(), exception=e)
if not isinstance(e, HTTPException) and current_app.propagate_exceptions:
exc_type, exc_value, tb = sys.exc_info()
if exc_value is e:
raise
else:
raise e
headers = Headers()
if isinstance(e, HTTPException):
if e.response is not None:
# If HTTPException is initialized with a response, then return e.get_response().
# This prevents specified error response from being overridden.
# eg. HTTPException(response=Response("Hello World"))
resp = e.get_response()
return resp
code = e.code
default_data = {
'message': getattr(e, 'description', http_status_message(code))
}
headers = e.get_response().headers
else:
code = 500
default_data = {
'message': http_status_message(code),
}
# Werkzeug exceptions generate a content-length header which is added
# to the response in addition to the actual content-length header
# https://github.com/flask-restful/flask-restful/issues/534
remove_headers = ('Content-Length',)
for header in remove_headers:
headers.pop(header, None)
data = getattr(e, 'data', default_data)
if code and code >= 500:
exc_info = sys.exc_info()
if exc_info[1] is None:
exc_info = None
current_app.log_exception(exc_info)
error_cls_name = type(e).__name__
if error_cls_name in self.errors:
custom_data = self.errors.get(error_cls_name, {})
code = custom_data.get('status', 500)
data.update(custom_data)
if code == 406 and self.default_mediatype is None:
# if we are handling NotAcceptable (406), make sure that
# make_response uses a representation we support as the
# default mediatype (so that make_response doesn't throw
# another NotAcceptable error).
supported_mediatypes = list(self.representations.keys())
fallback_mediatype = supported_mediatypes[0] if supported_mediatypes else "text/plain"
resp = self.make_response(
data,
code,
headers,
fallback_mediatype = fallback_mediatype
)
else:
resp = self.make_response(data, code, headers)
if code == 401:
resp = self.unauthorized(resp)
return resp
def mediatypes_method(self):
"""Return a method that returns a list of mediatypes
"""
return lambda resource_cls: self.mediatypes() + [self.default_mediatype]
def add_resource(self, resource, *urls, **kwargs):
"""Adds a resource to the api.
:param resource: the class name of your resource
:type resource: :class:`Type[Resource]`
:param urls: one or more url routes to match for the resource, standard
flask routing rules apply. Any url variables will be
passed to the resource method as args.
:type urls: str
:param endpoint: endpoint name (defaults to :meth:`Resource.__name__.lower`
Can be used to reference this route in :class:`fields.Url` fields
:type endpoint: str
:param resource_class_args: args to be forwarded to the constructor of
the resource.
:type resource_class_args: tuple
:param resource_class_kwargs: kwargs to be forwarded to the constructor
of the resource.
:type resource_class_kwargs: dict
Additional keyword arguments not specified above will be passed as-is
to :meth:`flask.Flask.add_url_rule`.
Examples::
api.add_resource(HelloWorld, '/', '/hello')
api.add_resource(Foo, '/foo', endpoint="foo")
api.add_resource(FooSpecial, '/special/foo', endpoint="foo")
"""
if self.app is not None:
self._register_view(self.app, resource, *urls, **kwargs)
else:
self.resources.append((resource, urls, kwargs))
def resource(self, *urls, **kwargs):
"""Wraps a :class:`~flask_restful.Resource` class, adding it to the
api. Parameters are the same as :meth:`~flask_restful.Api.add_resource`.
Example::
app = Flask(__name__)
api = restful.Api(app)
@api.resource('/foo')
class Foo(Resource):
def get(self):
return 'Hello, World!'
"""
def decorator(cls):
self.add_resource(cls, *urls, **kwargs)
return cls
return decorator
def _register_view(self, app, resource, *urls, **kwargs):
endpoint = kwargs.pop('endpoint', None) or resource.__name__.lower()
self.endpoints.add(endpoint)
resource_class_args = kwargs.pop('resource_class_args', ())
resource_class_kwargs = kwargs.pop('resource_class_kwargs', {})
# NOTE: 'view_functions' is cleaned up from Blueprint class in Flask 1.0
if endpoint in getattr(app, 'view_functions', {}):
previous_view_class = app.view_functions[endpoint].__dict__['view_class']
# if you override the endpoint with a different class, avoid the collision by raising an exception
if previous_view_class != resource:
raise ValueError('This endpoint (%s) is already set to the class %s.' % (endpoint, previous_view_class.__name__))
resource.mediatypes = self.mediatypes_method() # Hacky
resource.endpoint = endpoint
resource_func = self.output(resource.as_view(endpoint, *resource_class_args,
**resource_class_kwargs))
for decorator in self.decorators:
resource_func = decorator(resource_func)
for url in urls:
# If this Api has a blueprint
if self.blueprint:
# And this Api has been setup
if self.blueprint_setup:
# Set the rule to a string directly, as the blueprint is already
# set up.
self.blueprint_setup.add_url_rule(url, view_func=resource_func, **kwargs)
continue
else:
# Set the rule to a function that expects the blueprint prefix
# to construct the final url. Allows deferment of url finalization
# in the case that the associated Blueprint has not yet been
# registered to an application, so we can wait for the registration
# prefix
rule = partial(self._complete_url, url)
else:
# If we've got no Blueprint, just build a url with no prefix
rule = self._complete_url(url, '')
# Add the url to the application or blueprint
app.add_url_rule(rule, view_func=resource_func, **kwargs)
def output(self, resource):
"""Wraps a resource (as a flask view function), for cases where the
resource does not directly return a response object
:param resource: The resource as a flask view function
"""
@wraps(resource)
def wrapper(*args, **kwargs):
resp = resource(*args, **kwargs)
if isinstance(resp, ResponseBase): # There may be a better way to test
return resp
data, code, headers = unpack(resp)
return self.make_response(data, code, headers=headers)
return wrapper
def url_for(self, resource, **values):
"""Generates a URL to the given resource.
Works like :func:`flask.url_for`."""
endpoint = resource.endpoint
if self.blueprint:
endpoint = '{0}.{1}'.format(self.blueprint.name, endpoint)
return url_for(endpoint, **values)
def make_response(self, data, *args, **kwargs):
"""Looks up the representation transformer for the requested media
type, invoking the transformer to create a response object. This
defaults to default_mediatype if no transformer is found for the
requested mediatype. If default_mediatype is None, a 406 Not
Acceptable response will be sent as per RFC 2616 section 14.1
:param data: Python object containing response data to be transformed
"""
default_mediatype = kwargs.pop('fallback_mediatype', None) or self.default_mediatype
mediatype = request.accept_mimetypes.best_match(
self.representations,
default=default_mediatype,
)
if mediatype is None:
raise NotAcceptable()
if mediatype in self.representations:
resp = self.representations[mediatype](data, *args, **kwargs)
resp.headers['Content-Type'] = mediatype
return resp
elif mediatype == 'text/plain':
resp = original_flask_make_response(str(data), *args, **kwargs)
resp.headers['Content-Type'] = 'text/plain'
return resp
else:
raise InternalServerError()
def mediatypes(self):
"""Returns a list of requested mediatypes sent in the Accept header"""
return [h for h, q in sorted(request.accept_mimetypes,
key=operator.itemgetter(1), reverse=True)]
def representation(self, mediatype):
"""Allows additional representation transformers to be declared for the
api. Transformers are functions that must be decorated with this
method, passing the mediatype the transformer represents. Three
arguments are passed to the transformer:
* The data to be represented in the response body
* The http status code
* A dictionary of headers
The transformer should convert the data appropriately for the mediatype
and return a Flask response object.
Ex::
@api.representation('application/xml')
def xml(data, code, headers):
resp = make_response(convert_data_to_xml(data), code)
resp.headers.extend(headers)
return resp
"""
def wrapper(func):
self.representations[mediatype] = func
return func
return wrapper
def unauthorized(self, response):
""" Given a response, change it to ask for credentials """
if self.serve_challenge_on_401:
realm = current_app.config.get("HTTP_BASIC_AUTH_REALM", "flask-restful")
challenge = u"{0} realm=\"{1}\"".format("Basic", realm)
response.headers['WWW-Authenticate'] = challenge
return response
class Resource(MethodView):
"""
Represents an abstract RESTful resource. Concrete resources should
extend from this class and expose methods for each supported HTTP
method. If a resource is invoked with an unsupported HTTP method,
the API will return a response with status 405 Method Not Allowed.
Otherwise the appropriate method is called and passed all arguments
from the url rule used when adding the resource to an Api instance. See
:meth:`~flask_restful.Api.add_resource` for details.
"""
representations = None
method_decorators = []
def dispatch_request(self, *args, **kwargs):
# Taken from flask
#noinspection PyUnresolvedReferences
meth = getattr(self, request.method.lower(), None)
if meth is None and request.method == 'HEAD':
meth = getattr(self, 'get', None)
assert meth is not None, 'Unimplemented method %r' % request.method
if isinstance(self.method_decorators, Mapping):
decorators = self.method_decorators.get(request.method.lower(), [])
else:
decorators = self.method_decorators
for decorator in decorators:
meth = decorator(meth)
resp = meth(*args, **kwargs)
if isinstance(resp, ResponseBase): # There may be a better way to test
return resp
representations = self.representations or OrderedDict()
#noinspection PyUnresolvedReferences
mediatype = request.accept_mimetypes.best_match(representations, default=None)
if mediatype in representations:
data, code, headers = unpack(resp)
resp = representations[mediatype](data, code, headers)
resp.headers['Content-Type'] = mediatype
return resp
return resp
def marshal(data, fields, envelope=None):
"""Takes raw data (in the form of a dict, list, object) and a dict of
fields to output and filters the data based on those fields.
:param data: the actual object(s) from which the fields are taken from
:param fields: a dict of whose keys will make up the final serialized
response output
:param envelope: optional key that will be used to envelop the serialized
response
>>> from flask_restful import fields, marshal
>>> data = { 'a': 100, 'b': 'foo' }
>>> mfields = { 'a': fields.Raw }
>>> marshal(data, mfields)
OrderedDict([('a', 100)])
>>> marshal(data, mfields, envelope='data')
OrderedDict([('data', OrderedDict([('a', 100)]))])
"""
def make(cls):
if isinstance(cls, type):
return cls()
return cls
if isinstance(data, (list, tuple)):
return (OrderedDict([(envelope, [marshal(d, fields) for d in data])])
if envelope else [marshal(d, fields) for d in data])
items = ((k, marshal(data, v) if isinstance(v, dict)
else make(v).output(k, data))
for k, v in fields.items())
return OrderedDict([(envelope, OrderedDict(items))]) if envelope else OrderedDict(items)
class marshal_with(object):
"""A decorator that apply marshalling to the return values of your methods.
>>> from flask_restful import fields, marshal_with
>>> mfields = { 'a': fields.Raw }
>>> @marshal_with(mfields)
... def get():
... return { 'a': 100, 'b': 'foo' }
...
...
>>> get()
OrderedDict([('a', 100)])
>>> @marshal_with(mfields, envelope='data')
... def get():
... return { 'a': 100, 'b': 'foo' }
...
...
>>> get()
OrderedDict([('data', OrderedDict([('a', 100)]))])
see :meth:`flask_restful.marshal`
"""
def __init__(self, fields, envelope=None):
"""
:param fields: a dict of whose keys will make up the final
serialized response output
:param envelope: optional key that will be used to envelop the serialized
response
"""
self.fields = fields
self.envelope = envelope
def __call__(self, f):
@wraps(f)
def wrapper(*args, **kwargs):
resp = f(*args, **kwargs)
if isinstance(resp, tuple):
data, code, headers = unpack(resp)
return marshal(data, self.fields, self.envelope), code, headers
else:
return marshal(resp, self.fields, self.envelope)
return wrapper
class marshal_with_field(object):
"""
A decorator that formats the return values of your methods with a single field.
>>> from flask_restful import marshal_with_field, fields
>>> @marshal_with_field(fields.List(fields.Integer))
... def get():
... return ['1', 2, 3.0]
...
>>> get()
[1, 2, 3]
see :meth:`flask_restful.marshal_with`
"""
def __init__(self, field):
"""
:param field: a single field with which to marshal the output.
"""
if isinstance(field, type):
self.field = field()
else:
self.field = field
def __call__(self, f):
@wraps(f)
def wrapper(*args, **kwargs):
resp = f(*args, **kwargs)
if isinstance(resp, tuple):
data, code, headers = unpack(resp)
return self.field.format(data), code, headers
return self.field.format(resp)
return wrapper

View File

@ -1,3 +0,0 @@
#!/usr/bin/env python
__version__ = '0.3.9'

View File

@ -1,414 +0,0 @@
from calendar import timegm
from decimal import Decimal as MyDecimal, ROUND_HALF_EVEN
from email.utils import formatdate
import six
try:
from urlparse import urlparse, urlunparse
except ImportError:
# python3
from urllib.parse import urlparse, urlunparse
from flask_restful import marshal
from flask import url_for, request
__all__ = ["String", "FormattedString", "Url", "DateTime", "Float",
"Integer", "Arbitrary", "Nested", "List", "Raw", "Boolean",
"Fixed", "Price"]
class MarshallingException(Exception):
"""
This is an encapsulating Exception in case of marshalling error.
"""
def __init__(self, underlying_exception):
# just put the contextual representation of the error to hint on what
# went wrong without exposing internals
super(MarshallingException, self).__init__(six.text_type(underlying_exception))
def is_indexable_but_not_string(obj):
return not hasattr(obj, "strip") and hasattr(obj, "__iter__")
def get_value(key, obj, default=None):
"""Helper for pulling a keyed value off various types of objects"""
if isinstance(key, int):
return _get_value_for_key(key, obj, default)
elif callable(key):
return key(obj)
else:
return _get_value_for_keys(key.split('.'), obj, default)
def _get_value_for_keys(keys, obj, default):
if len(keys) == 1:
return _get_value_for_key(keys[0], obj, default)
else:
return _get_value_for_keys(
keys[1:], _get_value_for_key(keys[0], obj, default), default)
def _get_value_for_key(key, obj, default):
if is_indexable_but_not_string(obj):
try:
return obj[key]
except (IndexError, TypeError, KeyError):
pass
return getattr(obj, key, default)
def to_marshallable_type(obj):
"""Helper for converting an object to a dictionary only if it is not
dictionary already or an indexable object nor a simple type"""
if obj is None:
return None # make it idempotent for None
if hasattr(obj, '__marshallable__'):
return obj.__marshallable__()
if hasattr(obj, '__getitem__'):
return obj # it is indexable it is ok
return dict(obj.__dict__)
class Raw(object):
"""Raw provides a base field class from which others should extend. It
applies no formatting by default, and should only be used in cases where
data does not need to be formatted before being serialized. Fields should
throw a :class:`MarshallingException` in case of parsing problem.
:param default: The default value for the field, if no value is
specified.
:param attribute: If the public facing value differs from the internal
value, use this to retrieve a different attribute from the response
than the publicly named value.
"""
def __init__(self, default=None, attribute=None):
self.attribute = attribute
self.default = default
def format(self, value):
"""Formats a field's value. No-op by default - field classes that
modify how the value of existing object keys should be presented should
override this and apply the appropriate formatting.
:param value: The value to format
:exception MarshallingException: In case of formatting problem
Ex::
class TitleCase(Raw):
def format(self, value):
return unicode(value).title()
"""
return value
def output(self, key, obj):
"""Pulls the value for the given key from the object, applies the
field's formatting and returns the result. If the key is not found
in the object, returns the default value. Field classes that create
values which do not require the existence of the key in the object
should override this and return the desired value.
:exception MarshallingException: In case of formatting problem
"""
value = get_value(key if self.attribute is None else self.attribute, obj)
if value is None:
return self.default
return self.format(value)
class Nested(Raw):
"""Allows you to nest one set of fields inside another.
See :ref:`nested-field` for more information
:param dict nested: The dictionary to nest
:param bool allow_null: Whether to return None instead of a dictionary
with null keys, if a nested dictionary has all-null keys
:param kwargs: If ``default`` keyword argument is present, a nested
dictionary will be marshaled as its value if nested dictionary is
all-null keys (e.g. lets you return an empty JSON object instead of
null)
"""
def __init__(self, nested, allow_null=False, **kwargs):
self.nested = nested
self.allow_null = allow_null
super(Nested, self).__init__(**kwargs)
def output(self, key, obj):
value = get_value(key if self.attribute is None else self.attribute, obj)
if value is None:
if self.allow_null:
return None
elif self.default is not None:
return self.default
return marshal(value, self.nested)
class List(Raw):
"""
Field for marshalling lists of other fields.
See :ref:`list-field` for more information.
:param cls_or_instance: The field type the list will contain.
"""
def __init__(self, cls_or_instance, **kwargs):
super(List, self).__init__(**kwargs)
error_msg = ("The type of the list elements must be a subclass of "
"flask_restful.fields.Raw")
if isinstance(cls_or_instance, type):
if not issubclass(cls_or_instance, Raw):
raise MarshallingException(error_msg)
self.container = cls_or_instance()
else:
if not isinstance(cls_or_instance, Raw):
raise MarshallingException(error_msg)
self.container = cls_or_instance
def format(self, value):
# Convert all instances in typed list to container type
if isinstance(value, set):
value = list(value)
return [
self.container.output(idx,
val if (isinstance(val, dict)
or (self.container.attribute
and hasattr(val, self.container.attribute)))
and not isinstance(self.container, Nested)
and not type(self.container) is Raw
else value)
for idx, val in enumerate(value)
]
def output(self, key, data):
value = get_value(key if self.attribute is None else self.attribute, data)
# we cannot really test for external dict behavior
if is_indexable_but_not_string(value) and not isinstance(value, dict):
return self.format(value)
if value is None:
return self.default
return [marshal(value, self.container.nested)]
class String(Raw):
"""
Marshal a value as a string. Uses ``six.text_type`` so values will
be converted to :class:`unicode` in python2 and :class:`str` in
python3.
"""
def format(self, value):
try:
return six.text_type(value)
except ValueError as ve:
raise MarshallingException(ve)
class Integer(Raw):
""" Field for outputting an integer value.
:param int default: The default value for the field, if no value is
specified.
"""
def __init__(self, default=0, **kwargs):
super(Integer, self).__init__(default=default, **kwargs)
def format(self, value):
try:
if value is None:
return self.default
return int(value)
except ValueError as ve:
raise MarshallingException(ve)
class Boolean(Raw):
"""
Field for outputting a boolean value.
Empty collections such as ``""``, ``{}``, ``[]``, etc. will be converted to
``False``.
"""
def format(self, value):
return bool(value)
class FormattedString(Raw):
"""
FormattedString is used to interpolate other values from
the response into this field. The syntax for the source string is
the same as the string :meth:`~str.format` method from the python
stdlib.
Ex::
fields = {
'name': fields.String,
'greeting': fields.FormattedString("Hello {name}")
}
data = {
'name': 'Doug',
}
marshal(data, fields)
"""
def __init__(self, src_str):
"""
:param string src_str: the string to format with the other
values from the response.
"""
super(FormattedString, self).__init__()
self.src_str = six.text_type(src_str)
def output(self, key, obj):
try:
data = to_marshallable_type(obj)
return self.src_str.format(**data)
except (TypeError, IndexError) as error:
raise MarshallingException(error)
class Url(Raw):
"""
A string representation of a Url
:param endpoint: Endpoint name. If endpoint is ``None``,
``request.endpoint`` is used instead
:type endpoint: str
:param absolute: If ``True``, ensures that the generated urls will have the
hostname included
:type absolute: bool
:param scheme: URL scheme specifier (e.g. ``http``, ``https``)
:type scheme: str
"""
def __init__(self, endpoint=None, absolute=False, scheme=None, **kwargs):
super(Url, self).__init__(**kwargs)
self.endpoint = endpoint
self.absolute = absolute
self.scheme = scheme
def output(self, key, obj):
try:
data = to_marshallable_type(obj)
endpoint = self.endpoint if self.endpoint is not None else request.endpoint
o = urlparse(url_for(endpoint, _external=self.absolute, **data))
if self.absolute:
scheme = self.scheme if self.scheme is not None else o.scheme
return urlunparse((scheme, o.netloc, o.path, "", "", ""))
return urlunparse(("", "", o.path, "", "", ""))
except TypeError as te:
raise MarshallingException(te)
class Float(Raw):
"""
A double as IEEE-754 double precision.
ex : 3.141592653589793 3.1415926535897933e-06 3.141592653589793e+24 nan inf
-inf
"""
def format(self, value):
try:
return float(value)
except ValueError as ve:
raise MarshallingException(ve)
class Arbitrary(Raw):
"""
A floating point number with an arbitrary precision
ex: 634271127864378216478362784632784678324.23432
"""
def format(self, value):
return six.text_type(MyDecimal(value))
class DateTime(Raw):
"""
Return a formatted datetime string in UTC. Supported formats are RFC 822
and ISO 8601.
See :func:`email.utils.formatdate` for more info on the RFC 822 format.
See :meth:`datetime.datetime.isoformat` for more info on the ISO 8601
format.
:param dt_format: ``'rfc822'`` or ``'iso8601'``
:type dt_format: str
"""
def __init__(self, dt_format='rfc822', **kwargs):
super(DateTime, self).__init__(**kwargs)
self.dt_format = dt_format
def format(self, value):
try:
if self.dt_format == 'rfc822':
return _rfc822(value)
elif self.dt_format == 'iso8601':
return _iso8601(value)
else:
raise MarshallingException(
'Unsupported date format %s' % self.dt_format
)
except AttributeError as ae:
raise MarshallingException(ae)
ZERO = MyDecimal()
class Fixed(Raw):
"""
A decimal number with a fixed precision.
"""
def __init__(self, decimals=5, **kwargs):
super(Fixed, self).__init__(**kwargs)
self.precision = MyDecimal('0.' + '0' * (decimals - 1) + '1')
def format(self, value):
dvalue = MyDecimal(value)
if not dvalue.is_normal() and dvalue != ZERO:
raise MarshallingException('Invalid Fixed precision number.')
return six.text_type(dvalue.quantize(self.precision, rounding=ROUND_HALF_EVEN))
"""Alias for :class:`~fields.Fixed`"""
Price = Fixed
def _rfc822(dt):
"""Turn a datetime object into a formatted date.
Example::
fields._rfc822(datetime(2011, 1, 1)) => "Sat, 01 Jan 2011 00:00:00 -0000"
:param dt: The datetime to transform
:type dt: datetime
:return: A RFC 822 formatted date string
"""
return formatdate(timegm(dt.utctimetuple()))
def _iso8601(dt):
"""Turn a datetime object into an ISO8601 formatted date.
Example::
fields._iso8601(datetime(2012, 1, 1, 0, 0)) => "2012-01-01T00:00:00"
:param dt: The datetime to transform
:type dt: datetime
:return: A ISO 8601 formatted date string
"""
return dt.isoformat()

View File

@ -1,282 +0,0 @@
from calendar import timegm
from datetime import datetime, time, timedelta
from email.utils import parsedate_tz, mktime_tz
import re
import aniso8601
import pytz
# Constants for upgrading date-based intervals to full datetimes.
START_OF_DAY = time(0, 0, 0, tzinfo=pytz.UTC)
END_OF_DAY = time(23, 59, 59, 999999, tzinfo=pytz.UTC)
# https://code.djangoproject.com/browser/django/trunk/django/core/validators.py
# basic auth added by frank
url_regex = re.compile(
r'^(?:http|ftp)s?://' # http:// or https://
r'(?:[^:@]+?:[^:@]*?@|)' # basic auth
r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+'
r'(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' # domain...
r'localhost|' # localhost...
r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|' # ...or ipv4
r'\[?[A-F0-9]*:[A-F0-9:]+\]?)' # ...or ipv6
r'(?::\d+)?' # optional port
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
def url(value):
"""Validate a URL.
:param string value: The URL to validate
:returns: The URL if valid.
:raises: ValueError
"""
if not url_regex.search(value):
message = u"{0} is not a valid URL".format(value)
if url_regex.search('http://' + value):
message += u". Did you mean: http://{0}".format(value)
raise ValueError(message)
return value
class regex(object):
"""Validate a string based on a regular expression.
Example::
parser = reqparse.RequestParser()
parser.add_argument('example', type=inputs.regex('^[0-9]+$'))
Input to the ``example`` argument will be rejected if it contains anything
but numbers.
:param pattern: The regular expression the input must match
:type pattern: str
:param flags: Flags to change expression behavior
:type flags: int
"""
def __init__(self, pattern, flags=0):
self.pattern = pattern
self.re = re.compile(pattern, flags)
def __call__(self, value):
if not self.re.search(value):
message = 'Value does not match pattern: "{0}"'.format(self.pattern)
raise ValueError(message)
return value
def __deepcopy__(self, memo):
return regex(self.pattern)
def _normalize_interval(start, end, value):
"""Normalize datetime intervals.
Given a pair of datetime.date or datetime.datetime objects,
returns a 2-tuple of tz-aware UTC datetimes spanning the same interval.
For datetime.date objects, the returned interval starts at 00:00:00.0
on the first date and ends at 00:00:00.0 on the second.
Naive datetimes are upgraded to UTC.
Timezone-aware datetimes are normalized to the UTC tzdata.
Params:
- start: A date or datetime
- end: A date or datetime
"""
if not isinstance(start, datetime):
start = datetime.combine(start, START_OF_DAY)
end = datetime.combine(end, START_OF_DAY)
if start.tzinfo is None:
start = pytz.UTC.localize(start)
end = pytz.UTC.localize(end)
else:
start = start.astimezone(pytz.UTC)
end = end.astimezone(pytz.UTC)
return start, end
def _expand_datetime(start, value):
if not isinstance(start, datetime):
# Expand a single date object to be the interval spanning
# that entire day.
end = start + timedelta(days=1)
else:
# Expand a datetime based on the finest resolution provided
# in the original input string.
time = value.split('T')[1]
time_without_offset = re.sub('[+-].+', '', time)
num_separators = time_without_offset.count(':')
if num_separators == 0:
# Hour resolution
end = start + timedelta(hours=1)
elif num_separators == 1:
# Minute resolution:
end = start + timedelta(minutes=1)
else:
# Second resolution
end = start + timedelta(seconds=1)
return end
def _parse_interval(value):
"""Do some nasty try/except voodoo to get some sort of datetime
object(s) out of the string.
"""
try:
return sorted(aniso8601.parse_interval(value))
except ValueError:
try:
return aniso8601.parse_datetime(value), None
except ValueError:
return aniso8601.parse_date(value), None
def iso8601interval(value, argument='argument'):
"""Parses ISO 8601-formatted datetime intervals into tuples of datetimes.
Accepts both a single date(time) or a full interval using either start/end
or start/duration notation, with the following behavior:
- Intervals are defined as inclusive start, exclusive end
- Single datetimes are translated into the interval spanning the
largest resolution not specified in the input value, up to the day.
- The smallest accepted resolution is 1 second.
- All timezones are accepted as values; returned datetimes are
localized to UTC. Naive inputs and date inputs will are assumed UTC.
Examples::
"2013-01-01" -> datetime(2013, 1, 1), datetime(2013, 1, 2)
"2013-01-01T12" -> datetime(2013, 1, 1, 12), datetime(2013, 1, 1, 13)
"2013-01-01/2013-02-28" -> datetime(2013, 1, 1), datetime(2013, 2, 28)
"2013-01-01/P3D" -> datetime(2013, 1, 1), datetime(2013, 1, 4)
"2013-01-01T12:00/PT30M" -> datetime(2013, 1, 1, 12), datetime(2013, 1, 1, 12, 30)
"2013-01-01T06:00/2013-01-01T12:00" -> datetime(2013, 1, 1, 6), datetime(2013, 1, 1, 12)
:param str value: The ISO8601 date time as a string
:return: Two UTC datetimes, the start and the end of the specified interval
:rtype: A tuple (datetime, datetime)
:raises: ValueError, if the interval is invalid.
"""
try:
start, end = _parse_interval(value)
if end is None:
end = _expand_datetime(start, value)
start, end = _normalize_interval(start, end, value)
except ValueError:
raise ValueError(
"Invalid {arg}: {value}. {arg} must be a valid ISO8601 "
"date/time interval.".format(arg=argument, value=value),
)
return start, end
def date(value):
"""Parse a valid looking date in the format YYYY-mm-dd"""
date = datetime.strptime(value, "%Y-%m-%d")
return date
def _get_integer(value):
try:
return int(value)
except (TypeError, ValueError):
raise ValueError('{0} is not a valid integer'.format(value))
def natural(value, argument='argument'):
""" Restrict input type to the natural numbers (0, 1, 2, 3...) """
value = _get_integer(value)
if value < 0:
error = ('Invalid {arg}: {value}. {arg} must be a non-negative '
'integer'.format(arg=argument, value=value))
raise ValueError(error)
return value
def positive(value, argument='argument'):
""" Restrict input type to the positive integers (1, 2, 3...) """
value = _get_integer(value)
if value < 1:
error = ('Invalid {arg}: {value}. {arg} must be a positive '
'integer'.format(arg=argument, value=value))
raise ValueError(error)
return value
class int_range(object):
""" Restrict input to an integer in a range (inclusive) """
def __init__(self, low, high, argument='argument'):
self.low = low
self.high = high
self.argument = argument
def __call__(self, value):
value = _get_integer(value)
if value < self.low or value > self.high:
error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'
.format(arg=self.argument, val=value, lo=self.low, hi=self.high))
raise ValueError(error)
return value
def boolean(value):
"""Parse the string ``"true"`` or ``"false"`` as a boolean (case
insensitive). Also accepts ``"1"`` and ``"0"`` as ``True``/``False``
(respectively). If the input is from the request JSON body, the type is
already a native python boolean, and will be passed through without
further parsing.
"""
if isinstance(value, bool):
return value
if not value:
raise ValueError("boolean type must be non-null")
value = value.lower()
if value in ('true', '1',):
return True
if value in ('false', '0',):
return False
raise ValueError("Invalid literal for boolean(): {0}".format(value))
def datetime_from_rfc822(datetime_str):
"""Turns an RFC822 formatted date into a datetime object.
Example::
inputs.datetime_from_rfc822("Wed, 02 Oct 2002 08:00:00 EST")
:param datetime_str: The RFC822-complying string to transform
:type datetime_str: str
:return: A datetime
"""
return datetime.fromtimestamp(mktime_tz(parsedate_tz(datetime_str)), pytz.utc)
def datetime_from_iso8601(datetime_str):
"""Turns an ISO8601 formatted datetime into a datetime object.
Example::
inputs.datetime_from_iso8601("2012-01-01T23:30:00+02:00")
:param datetime_str: The ISO8601-complying string to transform
:type datetime_str: str
:return: A datetime
"""
return aniso8601.parse_datetime(datetime_str)

View File

@ -1,35 +0,0 @@
import sys
try:
from collections.abc import OrderedDict
except ImportError:
from collections import OrderedDict
from werkzeug.http import HTTP_STATUS_CODES
PY3 = sys.version_info > (3,)
def http_status_message(code):
"""Maps an HTTP status code to the textual status"""
return HTTP_STATUS_CODES.get(code, '')
def unpack(value):
"""Return a three tuple of data, code, and headers"""
if not isinstance(value, tuple):
return value, 200, {}
try:
data, code, headers = value
return data, code, headers
except ValueError:
pass
try:
data, code = value
return data, code, {}
except ValueError:
pass
return value, 200, {}

View File

@ -1,35 +0,0 @@
import pickle
from Crypto.Cipher import AES
from base64 import b64encode, b64decode
__all__ = "encrypt", "decrypt"
BLOCK_SIZE = 16
INTERRUPT = b'\0' # something impossible to put in a string
PADDING = b'\1'
def pad(data):
return data + INTERRUPT + PADDING * (BLOCK_SIZE - (len(data) + 1) % BLOCK_SIZE)
def strip(data):
return data.rstrip(PADDING).rstrip(INTERRUPT)
def create_cipher(key, seed):
if len(seed) != 16:
raise ValueError("Choose a seed of 16 bytes")
if len(key) != 32:
raise ValueError("Choose a key of 32 bytes")
return AES.new(key, AES.MODE_CBC, seed)
def encrypt(plaintext_data, key, seed):
plaintext_data = pickle.dumps(plaintext_data, pickle.HIGHEST_PROTOCOL) # whatever you give me I need to be able to restitute it
return b64encode(create_cipher(key, seed).encrypt(pad(plaintext_data)))
def decrypt(encrypted_data, key, seed):
return pickle.loads(strip(create_cipher(key, seed).decrypt(b64decode(encrypted_data))))

View File

@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-
__version__ = "0.5.1"
__description__ = (
"Fully featured framework for fast, easy and documented API development with Flask"
)

View File

@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from . import fields, reqparse, apidoc, inputs, cors
from .api import Api # noqa
from .marshalling import marshal, marshal_with, marshal_with_field # noqa
from .mask import Mask
from .model import Model, OrderedModel, SchemaModel # noqa
from .namespace import Namespace # noqa
from .resource import Resource # noqa
from .errors import abort, RestError, SpecsError, ValidationError
from .swagger import Swagger
from .__about__ import __version__, __description__
__all__ = (
"__version__",
"__description__",
"Api",
"Resource",
"apidoc",
"marshal",
"marshal_with",
"marshal_with_field",
"Mask",
"Model",
"Namespace",
"OrderedModel",
"SchemaModel",
"abort",
"cors",
"fields",
"inputs",
"reqparse",
"RestError",
"SpecsError",
"Swagger",
"ValidationError",
)

186
libs/flask_restx/_http.py Normal file
View File

@ -0,0 +1,186 @@
# encoding: utf-8
"""
This file is backported from Python 3.5 http built-in module.
"""
from enum import IntEnum
class HTTPStatus(IntEnum):
"""HTTP status codes and reason phrases
Status codes from the following RFCs are all observed:
* RFC 7231: Hypertext Transfer Protocol (HTTP/1.1), obsoletes 2616
* RFC 6585: Additional HTTP Status Codes
* RFC 3229: Delta encoding in HTTP
* RFC 4918: HTTP Extensions for WebDAV, obsoletes 2518
* RFC 5842: Binding Extensions to WebDAV
* RFC 7238: Permanent Redirect
* RFC 2295: Transparent Content Negotiation in HTTP
* RFC 2774: An HTTP Extension Framework
"""
def __new__(cls, value, phrase, description=""):
obj = int.__new__(cls, value)
obj._value_ = value
obj.phrase = phrase
obj.description = description
return obj
def __str__(self):
return str(self.value)
# informational
CONTINUE = 100, "Continue", "Request received, please continue"
SWITCHING_PROTOCOLS = (
101,
"Switching Protocols",
"Switching to new protocol; obey Upgrade header",
)
PROCESSING = 102, "Processing"
# success
OK = 200, "OK", "Request fulfilled, document follows"
CREATED = 201, "Created", "Document created, URL follows"
ACCEPTED = (202, "Accepted", "Request accepted, processing continues off-line")
NON_AUTHORITATIVE_INFORMATION = (
203,
"Non-Authoritative Information",
"Request fulfilled from cache",
)
NO_CONTENT = 204, "No Content", "Request fulfilled, nothing follows"
RESET_CONTENT = 205, "Reset Content", "Clear input form for further input"
PARTIAL_CONTENT = 206, "Partial Content", "Partial content follows"
MULTI_STATUS = 207, "Multi-Status"
ALREADY_REPORTED = 208, "Already Reported"
IM_USED = 226, "IM Used"
# redirection
MULTIPLE_CHOICES = (
300,
"Multiple Choices",
"Object has several resources -- see URI list",
)
MOVED_PERMANENTLY = (
301,
"Moved Permanently",
"Object moved permanently -- see URI list",
)
FOUND = 302, "Found", "Object moved temporarily -- see URI list"
SEE_OTHER = 303, "See Other", "Object moved -- see Method and URL list"
NOT_MODIFIED = (304, "Not Modified", "Document has not changed since given time")
USE_PROXY = (
305,
"Use Proxy",
"You must use proxy specified in Location to access this resource",
)
TEMPORARY_REDIRECT = (
307,
"Temporary Redirect",
"Object moved temporarily -- see URI list",
)
PERMANENT_REDIRECT = (
308,
"Permanent Redirect",
"Object moved temporarily -- see URI list",
)
# client error
BAD_REQUEST = (400, "Bad Request", "Bad request syntax or unsupported method")
UNAUTHORIZED = (401, "Unauthorized", "No permission -- see authorization schemes")
PAYMENT_REQUIRED = (402, "Payment Required", "No payment -- see charging schemes")
FORBIDDEN = (403, "Forbidden", "Request forbidden -- authorization will not help")
NOT_FOUND = (404, "Not Found", "Nothing matches the given URI")
METHOD_NOT_ALLOWED = (
405,
"Method Not Allowed",
"Specified method is invalid for this resource",
)
NOT_ACCEPTABLE = (406, "Not Acceptable", "URI not available in preferred format")
PROXY_AUTHENTICATION_REQUIRED = (
407,
"Proxy Authentication Required",
"You must authenticate with this proxy before proceeding",
)
REQUEST_TIMEOUT = (408, "Request Timeout", "Request timed out; try again later")
CONFLICT = 409, "Conflict", "Request conflict"
GONE = (410, "Gone", "URI no longer exists and has been permanently removed")
LENGTH_REQUIRED = (411, "Length Required", "Client must specify Content-Length")
PRECONDITION_FAILED = (
412,
"Precondition Failed",
"Precondition in headers is false",
)
REQUEST_ENTITY_TOO_LARGE = (413, "Request Entity Too Large", "Entity is too large")
REQUEST_URI_TOO_LONG = (414, "Request-URI Too Long", "URI is too long")
UNSUPPORTED_MEDIA_TYPE = (
415,
"Unsupported Media Type",
"Entity body in unsupported format",
)
REQUESTED_RANGE_NOT_SATISFIABLE = (
416,
"Requested Range Not Satisfiable",
"Cannot satisfy request range",
)
EXPECTATION_FAILED = (
417,
"Expectation Failed",
"Expect condition could not be satisfied",
)
UNPROCESSABLE_ENTITY = 422, "Unprocessable Entity"
LOCKED = 423, "Locked"
FAILED_DEPENDENCY = 424, "Failed Dependency"
UPGRADE_REQUIRED = 426, "Upgrade Required"
PRECONDITION_REQUIRED = (
428,
"Precondition Required",
"The origin server requires the request to be conditional",
)
TOO_MANY_REQUESTS = (
429,
"Too Many Requests",
"The user has sent too many requests in "
'a given amount of time ("rate limiting")',
)
REQUEST_HEADER_FIELDS_TOO_LARGE = (
431,
"Request Header Fields Too Large",
"The server is unwilling to process the request because its header "
"fields are too large",
)
# server errors
INTERNAL_SERVER_ERROR = (
500,
"Internal Server Error",
"Server got itself in trouble",
)
NOT_IMPLEMENTED = (501, "Not Implemented", "Server does not support this operation")
BAD_GATEWAY = (502, "Bad Gateway", "Invalid responses from another server/proxy")
SERVICE_UNAVAILABLE = (
503,
"Service Unavailable",
"The server cannot process the request due to a high load",
)
GATEWAY_TIMEOUT = (
504,
"Gateway Timeout",
"The gateway server did not receive a timely response",
)
HTTP_VERSION_NOT_SUPPORTED = (
505,
"HTTP Version Not Supported",
"Cannot fulfill request",
)
VARIANT_ALSO_NEGOTIATES = 506, "Variant Also Negotiates"
INSUFFICIENT_STORAGE = 507, "Insufficient Storage"
LOOP_DETECTED = 508, "Loop Detected"
NOT_EXTENDED = 510, "Not Extended"
NETWORK_AUTHENTICATION_REQUIRED = (
511,
"Network Authentication Required",
"The client needs to authenticate to gain network access",
)

962
libs/flask_restx/api.py Normal file
View File

@ -0,0 +1,962 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import difflib
import inspect
from itertools import chain
import logging
import operator
import re
import six
import sys
import warnings
from collections import OrderedDict
from functools import wraps, partial
from types import MethodType
from flask import url_for, request, current_app
from flask import make_response as original_flask_make_response
try:
from flask.helpers import _endpoint_from_view_func
except ImportError:
from flask.scaffold import _endpoint_from_view_func
from flask.signals import got_request_exception
from jsonschema import RefResolver
from werkzeug.utils import cached_property
from werkzeug.datastructures import Headers
from werkzeug.exceptions import (
HTTPException,
MethodNotAllowed,
NotFound,
NotAcceptable,
InternalServerError,
)
from werkzeug import __version__ as werkzeug_version
if werkzeug_version.split('.')[0] >= '2':
from werkzeug.wrappers import Response as BaseResponse
else:
from werkzeug.wrappers import BaseResponse
from . import apidoc
from .mask import ParseError, MaskError
from .namespace import Namespace
from .postman import PostmanCollectionV1
from .resource import Resource
from .swagger import Swagger
from .utils import default_id, camel_to_dash, unpack
from .representations import output_json
from ._http import HTTPStatus
RE_RULES = re.compile("(<.*>)")
# List headers that should never be handled by Flask-RESTX
HEADERS_BLACKLIST = ("Content-Length",)
DEFAULT_REPRESENTATIONS = [("application/json", output_json)]
log = logging.getLogger(__name__)
class Api(object):
"""
The main entry point for the application.
You need to initialize it with a Flask Application: ::
>>> app = Flask(__name__)
>>> api = Api(app)
Alternatively, you can use :meth:`init_app` to set the Flask application
after it has been constructed.
The endpoint parameter prefix all views and resources:
- The API root/documentation will be ``{endpoint}.root``
- A resource registered as 'resource' will be available as ``{endpoint}.resource``
:param flask.Flask|flask.Blueprint app: the Flask application object or a Blueprint
:param str version: The API version (used in Swagger documentation)
:param str title: The API title (used in Swagger documentation)
:param str description: The API description (used in Swagger documentation)
:param str terms_url: The API terms page URL (used in Swagger documentation)
:param str contact: A contact email for the API (used in Swagger documentation)
:param str license: The license associated to the API (used in Swagger documentation)
:param str license_url: The license page URL (used in Swagger documentation)
:param str endpoint: The API base endpoint (default to 'api).
:param str default: The default namespace base name (default to 'default')
:param str default_label: The default namespace label (used in Swagger documentation)
:param str default_mediatype: The default media type to return
:param bool validate: Whether or not the API should perform input payload validation.
:param bool ordered: Whether or not preserve order models and marshalling.
:param str doc: The documentation path. If set to a false value, documentation is disabled.
(Default to '/')
:param list decorators: Decorators to attach to every resource
:param bool catch_all_404s: Use :meth:`handle_error`
to handle 404 errors throughout your app
:param dict authorizations: A Swagger Authorizations declaration as dictionary
:param bool serve_challenge_on_401: Serve basic authentication challenge with 401
responses (default 'False')
:param FormatChecker format_checker: A jsonschema.FormatChecker object that is hooked into
the Model validator. A default or a custom FormatChecker can be provided (e.g., with custom
checkers), otherwise the default action is to not enforce any format validation.
:param url_scheme: If set to a string (e.g. http, https), then the specs_url and base_url will explicitly use this
scheme regardless of how the application is deployed. This is necessary for some deployments behind a reverse
proxy.
"""
def __init__(
self,
app=None,
version="1.0",
title=None,
description=None,
terms_url=None,
license=None,
license_url=None,
contact=None,
contact_url=None,
contact_email=None,
authorizations=None,
security=None,
doc="/",
default_id=default_id,
default="default",
default_label="Default namespace",
validate=None,
tags=None,
prefix="",
ordered=False,
default_mediatype="application/json",
decorators=None,
catch_all_404s=False,
serve_challenge_on_401=False,
format_checker=None,
url_scheme=None,
**kwargs
):
self.version = version
self.title = title or "API"
self.description = description
self.terms_url = terms_url
self.contact = contact
self.contact_email = contact_email
self.contact_url = contact_url
self.license = license
self.license_url = license_url
self.authorizations = authorizations
self.security = security
self.default_id = default_id
self.ordered = ordered
self._validate = validate
self._doc = doc
self._doc_view = None
self._default_error_handler = None
self.tags = tags or []
self.error_handlers = OrderedDict({
ParseError: mask_parse_error_handler,
MaskError: mask_error_handler,
})
self._schema = None
self.models = {}
self._refresolver = None
self.format_checker = format_checker
self.namespaces = []
self.ns_paths = dict()
self.representations = OrderedDict(DEFAULT_REPRESENTATIONS)
self.urls = {}
self.prefix = prefix
self.default_mediatype = default_mediatype
self.decorators = decorators if decorators else []
self.catch_all_404s = catch_all_404s
self.serve_challenge_on_401 = serve_challenge_on_401
self.blueprint_setup = None
self.endpoints = set()
self.resources = []
self.app = None
self.blueprint = None
# must come after self.app initialisation to prevent __getattr__ recursion
# in self._configure_namespace_logger
self.default_namespace = self.namespace(
default,
default_label,
endpoint="{0}-declaration".format(default),
validate=validate,
api=self,
path="/",
)
self.url_scheme = url_scheme
if app is not None:
self.app = app
self.init_app(app)
# super(Api, self).__init__(app, **kwargs)
def init_app(self, app, **kwargs):
"""
Allow to lazy register the API on a Flask application::
>>> app = Flask(__name__)
>>> api = Api()
>>> api.init_app(app)
:param flask.Flask app: the Flask application object
:param str title: The API title (used in Swagger documentation)
:param str description: The API description (used in Swagger documentation)
:param str terms_url: The API terms page URL (used in Swagger documentation)
:param str contact: A contact email for the API (used in Swagger documentation)
:param str license: The license associated to the API (used in Swagger documentation)
:param str license_url: The license page URL (used in Swagger documentation)
:param url_scheme: If set to a string (e.g. http, https), then the specs_url and base_url will explicitly use
this scheme regardless of how the application is deployed. This is necessary for some deployments behind a
reverse proxy.
"""
self.app = app
self.title = kwargs.get("title", self.title)
self.description = kwargs.get("description", self.description)
self.terms_url = kwargs.get("terms_url", self.terms_url)
self.contact = kwargs.get("contact", self.contact)
self.contact_url = kwargs.get("contact_url", self.contact_url)
self.contact_email = kwargs.get("contact_email", self.contact_email)
self.license = kwargs.get("license", self.license)
self.license_url = kwargs.get("license_url", self.license_url)
self.url_scheme = kwargs.get("url_scheme", self.url_scheme)
self._add_specs = kwargs.get("add_specs", True)
# If app is a blueprint, defer the initialization
try:
app.record(self._deferred_blueprint_init)
# Flask.Blueprint has a 'record' attribute, Flask.Api does not
except AttributeError:
self._init_app(app)
else:
self.blueprint = app
def _init_app(self, app):
"""
Perform initialization actions with the given :class:`flask.Flask` object.
:param flask.Flask app: The flask application object
"""
self._register_specs(self.blueprint or app)
self._register_doc(self.blueprint or app)
app.handle_exception = partial(self.error_router, app.handle_exception)
app.handle_user_exception = partial(
self.error_router, app.handle_user_exception
)
if len(self.resources) > 0:
for resource, namespace, urls, kwargs in self.resources:
self._register_view(app, resource, namespace, *urls, **kwargs)
for ns in self.namespaces:
self._configure_namespace_logger(app, ns)
self._register_apidoc(app)
self._validate = (
self._validate
if self._validate is not None
else app.config.get("RESTX_VALIDATE", False)
)
app.config.setdefault("RESTX_MASK_HEADER", "X-Fields")
app.config.setdefault("RESTX_MASK_SWAGGER", True)
app.config.setdefault("RESTX_INCLUDE_ALL_MODELS", False)
# check for deprecated config variable names
if "ERROR_404_HELP" in app.config:
app.config['RESTX_ERROR_404_HELP'] = app.config['ERROR_404_HELP']
warnings.warn(
"'ERROR_404_HELP' config setting is deprecated and will be "
"removed in the future. Use 'RESTX_ERROR_404_HELP' instead.",
DeprecationWarning
)
def __getattr__(self, name):
try:
return getattr(self.default_namespace, name)
except AttributeError:
raise AttributeError("Api does not have {0} attribute".format(name))
def _complete_url(self, url_part, registration_prefix):
"""
This method is used to defer the construction of the final url in
the case that the Api is created with a Blueprint.
:param url_part: The part of the url the endpoint is registered with
:param registration_prefix: The part of the url contributed by the
blueprint. Generally speaking, BlueprintSetupState.url_prefix
"""
parts = (registration_prefix, self.prefix, url_part)
return "".join(part for part in parts if part)
def _register_apidoc(self, app):
conf = app.extensions.setdefault("restx", {})
if not conf.get("apidoc_registered", False):
app.register_blueprint(apidoc.apidoc)
conf["apidoc_registered"] = True
def _register_specs(self, app_or_blueprint):
if self._add_specs:
endpoint = str("specs")
self._register_view(
app_or_blueprint,
SwaggerView,
self.default_namespace,
"/swagger.json",
endpoint=endpoint,
resource_class_args=(self,),
)
self.endpoints.add(endpoint)
def _register_doc(self, app_or_blueprint):
if self._add_specs and self._doc:
# Register documentation before root if enabled
app_or_blueprint.add_url_rule(self._doc, "doc", self.render_doc)
app_or_blueprint.add_url_rule(self.prefix or "/", "root", self.render_root)
def register_resource(self, namespace, resource, *urls, **kwargs):
endpoint = kwargs.pop("endpoint", None)
endpoint = str(endpoint or self.default_endpoint(resource, namespace))
kwargs["endpoint"] = endpoint
self.endpoints.add(endpoint)
if self.app is not None:
self._register_view(self.app, resource, namespace, *urls, **kwargs)
else:
self.resources.append((resource, namespace, urls, kwargs))
return endpoint
def _configure_namespace_logger(self, app, namespace):
for handler in app.logger.handlers:
namespace.logger.addHandler(handler)
namespace.logger.setLevel(app.logger.level)
def _register_view(self, app, resource, namespace, *urls, **kwargs):
endpoint = kwargs.pop("endpoint", None) or camel_to_dash(resource.__name__)
resource_class_args = kwargs.pop("resource_class_args", ())
resource_class_kwargs = kwargs.pop("resource_class_kwargs", {})
# NOTE: 'view_functions' is cleaned up from Blueprint class in Flask 1.0
if endpoint in getattr(app, "view_functions", {}):
previous_view_class = app.view_functions[endpoint].__dict__["view_class"]
# if you override the endpoint with a different class, avoid the
# collision by raising an exception
if previous_view_class != resource:
msg = "This endpoint (%s) is already set to the class %s."
raise ValueError(msg % (endpoint, previous_view_class.__name__))
resource.mediatypes = self.mediatypes_method() # Hacky
resource.endpoint = endpoint
resource_func = self.output(
resource.as_view(
endpoint, self, *resource_class_args, **resource_class_kwargs
)
)
# Apply Namespace and Api decorators to a resource
for decorator in chain(namespace.decorators, self.decorators):
resource_func = decorator(resource_func)
for url in urls:
# If this Api has a blueprint
if self.blueprint:
# And this Api has been setup
if self.blueprint_setup:
# Set the rule to a string directly, as the blueprint is already
# set up.
self.blueprint_setup.add_url_rule(
url, view_func=resource_func, **kwargs
)
continue
else:
# Set the rule to a function that expects the blueprint prefix
# to construct the final url. Allows deferment of url finalization
# in the case that the associated Blueprint has not yet been
# registered to an application, so we can wait for the registration
# prefix
rule = partial(self._complete_url, url)
else:
# If we've got no Blueprint, just build a url with no prefix
rule = self._complete_url(url, "")
# Add the url to the application or blueprint
app.add_url_rule(rule, view_func=resource_func, **kwargs)
def output(self, resource):
"""
Wraps a resource (as a flask view function),
for cases where the resource does not directly return a response object
:param resource: The resource as a flask view function
"""
@wraps(resource)
def wrapper(*args, **kwargs):
resp = resource(*args, **kwargs)
if isinstance(resp, BaseResponse):
return resp
data, code, headers = unpack(resp)
return self.make_response(data, code, headers=headers)
return wrapper
def make_response(self, data, *args, **kwargs):
"""
Looks up the representation transformer for the requested media
type, invoking the transformer to create a response object. This
defaults to default_mediatype if no transformer is found for the
requested mediatype. If default_mediatype is None, a 406 Not
Acceptable response will be sent as per RFC 2616 section 14.1
:param data: Python object containing response data to be transformed
"""
default_mediatype = (
kwargs.pop("fallback_mediatype", None) or self.default_mediatype
)
mediatype = request.accept_mimetypes.best_match(
self.representations, default=default_mediatype,
)
if mediatype is None:
raise NotAcceptable()
if mediatype in self.representations:
resp = self.representations[mediatype](data, *args, **kwargs)
resp.headers["Content-Type"] = mediatype
return resp
elif mediatype == "text/plain":
resp = original_flask_make_response(str(data), *args, **kwargs)
resp.headers["Content-Type"] = "text/plain"
return resp
else:
raise InternalServerError()
def documentation(self, func):
"""A decorator to specify a view function for the documentation"""
self._doc_view = func
return func
def render_root(self):
self.abort(HTTPStatus.NOT_FOUND)
def render_doc(self):
"""Override this method to customize the documentation page"""
if self._doc_view:
return self._doc_view()
elif not self._doc:
self.abort(HTTPStatus.NOT_FOUND)
return apidoc.ui_for(self)
def default_endpoint(self, resource, namespace):
"""
Provide a default endpoint for a resource on a given namespace.
Endpoints are ensured not to collide.
Override this method specify a custom algorithm for default endpoint.
:param Resource resource: the resource for which we want an endpoint
:param Namespace namespace: the namespace holding the resource
:returns str: An endpoint name
"""
endpoint = camel_to_dash(resource.__name__)
if namespace is not self.default_namespace:
endpoint = "{ns.name}_{endpoint}".format(ns=namespace, endpoint=endpoint)
if endpoint in self.endpoints:
suffix = 2
while True:
new_endpoint = "{base}_{suffix}".format(base=endpoint, suffix=suffix)
if new_endpoint not in self.endpoints:
endpoint = new_endpoint
break
suffix += 1
return endpoint
def get_ns_path(self, ns):
return self.ns_paths.get(ns)
def ns_urls(self, ns, urls):
path = self.get_ns_path(ns) or ns.path
return [path + url for url in urls]
def add_namespace(self, ns, path=None):
"""
This method registers resources from namespace for current instance of api.
You can use argument path for definition custom prefix url for namespace.
:param Namespace ns: the namespace
:param path: registration prefix of namespace
"""
if ns not in self.namespaces:
self.namespaces.append(ns)
if self not in ns.apis:
ns.apis.append(self)
# Associate ns with prefix-path
if path is not None:
self.ns_paths[ns] = path
# Register resources
for r in ns.resources:
urls = self.ns_urls(ns, r.urls)
self.register_resource(ns, r.resource, *urls, **r.kwargs)
# Register models
for name, definition in six.iteritems(ns.models):
self.models[name] = definition
if not self.blueprint and self.app is not None:
self._configure_namespace_logger(self.app, ns)
def namespace(self, *args, **kwargs):
"""
A namespace factory.
:returns Namespace: a new namespace instance
"""
kwargs["ordered"] = kwargs.get("ordered", self.ordered)
ns = Namespace(*args, **kwargs)
self.add_namespace(ns)
return ns
def endpoint(self, name):
if self.blueprint:
return "{0}.{1}".format(self.blueprint.name, name)
else:
return name
@property
def specs_url(self):
"""
The Swagger specifications relative url (ie. `swagger.json`). If
the spec_url_scheme attribute is set, then the full url is provided instead
(e.g. http://localhost/swaggger.json).
:rtype: str
"""
external = None if self.url_scheme is None else True
return url_for(
self.endpoint("specs"), _scheme=self.url_scheme, _external=external
)
@property
def base_url(self):
"""
The API base absolute url
:rtype: str
"""
return url_for(self.endpoint("root"), _scheme=self.url_scheme, _external=True)
@property
def base_path(self):
"""
The API path
:rtype: str
"""
return url_for(self.endpoint("root"), _external=False)
@cached_property
def __schema__(self):
"""
The Swagger specifications/schema for this API
:returns dict: the schema as a serializable dict
"""
if not self._schema:
try:
self._schema = Swagger(self).as_dict()
except Exception:
# Log the source exception for debugging purpose
# and return an error message
msg = "Unable to render schema"
log.exception(msg) # This will provide a full traceback
return {"error": msg}
return self._schema
@property
def _own_and_child_error_handlers(self):
rv = OrderedDict()
rv.update(self.error_handlers)
for ns in self.namespaces:
for exception, handler in six.iteritems(ns.error_handlers):
rv[exception] = handler
return rv
def errorhandler(self, exception):
"""A decorator to register an error handler for a given exception"""
if inspect.isclass(exception) and issubclass(exception, Exception):
# Register an error handler for a given exception
def wrapper(func):
self.error_handlers[exception] = func
return func
return wrapper
else:
# Register the default error handler
self._default_error_handler = exception
return exception
def owns_endpoint(self, endpoint):
"""
Tests if an endpoint name (not path) belongs to this Api.
Takes into account the Blueprint name part of the endpoint name.
:param str endpoint: The name of the endpoint being checked
:return: bool
"""
if self.blueprint:
if endpoint.startswith(self.blueprint.name):
endpoint = endpoint.split(self.blueprint.name + ".", 1)[-1]
else:
return False
return endpoint in self.endpoints
def _should_use_fr_error_handler(self):
"""
Determine if error should be handled with FR or default Flask
The goal is to return Flask error handlers for non-FR-related routes,
and FR errors (with the correct media type) for FR endpoints. This
method currently handles 404 and 405 errors.
:return: bool
"""
adapter = current_app.create_url_adapter(request)
try:
adapter.match()
except MethodNotAllowed as e:
# Check if the other HTTP methods at this url would hit the Api
valid_route_method = e.valid_methods[0]
rule, _ = adapter.match(method=valid_route_method, return_rule=True)
return self.owns_endpoint(rule.endpoint)
except NotFound:
return self.catch_all_404s
except Exception:
# Werkzeug throws other kinds of exceptions, such as Redirect
pass
def _has_fr_route(self):
"""Encapsulating the rules for whether the request was to a Flask endpoint"""
# 404's, 405's, which might not have a url_rule
if self._should_use_fr_error_handler():
return True
# for all other errors, just check if FR dispatched the route
if not request.url_rule:
return False
return self.owns_endpoint(request.url_rule.endpoint)
def error_router(self, original_handler, e):
"""
This function decides whether the error occurred in a flask-restx
endpoint or not. If it happened in a flask-restx endpoint, our
handler will be dispatched. If it happened in an unrelated view, the
app's original error handler will be dispatched.
In the event that the error occurred in a flask-restx endpoint but
the local handler can't resolve the situation, the router will fall
back onto the original_handler as last resort.
:param function original_handler: the original Flask error handler for the app
:param Exception e: the exception raised while handling the request
"""
if self._has_fr_route():
try:
return self.handle_error(e)
except Exception as f:
return original_handler(f)
return original_handler(e)
def handle_error(self, e):
"""
Error handler for the API transforms a raised exception into a Flask response,
with the appropriate HTTP status code and body.
:param Exception e: the raised Exception object
"""
# When propagate_exceptions is set, do not return the exception to the
# client if a handler is configured for the exception.
if (
not isinstance(e, HTTPException)
and current_app.propagate_exceptions
and not isinstance(e, tuple(self._own_and_child_error_handlers.keys()))
):
exc_type, exc_value, tb = sys.exc_info()
if exc_value is e:
raise
else:
raise e
include_message_in_response = current_app.config.get(
"ERROR_INCLUDE_MESSAGE", True
)
default_data = {}
headers = Headers()
for typecheck, handler in six.iteritems(self._own_and_child_error_handlers):
if isinstance(e, typecheck):
result = handler(e)
default_data, code, headers = unpack(
result, HTTPStatus.INTERNAL_SERVER_ERROR
)
break
else:
# Flask docs say: "This signal is not sent for HTTPException or other exceptions that have error handlers
# registered, unless the exception was raised from an error handler."
got_request_exception.send(current_app._get_current_object(), exception=e)
if isinstance(e, HTTPException):
code = HTTPStatus(e.code)
if include_message_in_response:
default_data = {"message": getattr(e, "description", code.phrase)}
headers = e.get_response().headers
elif self._default_error_handler:
result = self._default_error_handler(e)
default_data, code, headers = unpack(
result, HTTPStatus.INTERNAL_SERVER_ERROR
)
else:
code = HTTPStatus.INTERNAL_SERVER_ERROR
if include_message_in_response:
default_data = {
"message": code.phrase,
}
if include_message_in_response:
default_data["message"] = default_data.get("message", str(e))
data = getattr(e, "data", default_data)
fallback_mediatype = None
if code >= HTTPStatus.INTERNAL_SERVER_ERROR:
exc_info = sys.exc_info()
if exc_info[1] is None:
exc_info = None
current_app.log_exception(exc_info)
elif (
code == HTTPStatus.NOT_FOUND
and current_app.config.get("RESTX_ERROR_404_HELP", True)
and include_message_in_response
):
data["message"] = self._help_on_404(data.get("message", None))
elif code == HTTPStatus.NOT_ACCEPTABLE and self.default_mediatype is None:
# if we are handling NotAcceptable (406), make sure that
# make_response uses a representation we support as the
# default mediatype (so that make_response doesn't throw
# another NotAcceptable error).
supported_mediatypes = list(self.representations.keys())
fallback_mediatype = (
supported_mediatypes[0] if supported_mediatypes else "text/plain"
)
# Remove blacklisted headers
for header in HEADERS_BLACKLIST:
headers.pop(header, None)
resp = self.make_response(
data, code, headers, fallback_mediatype=fallback_mediatype
)
if code == HTTPStatus.UNAUTHORIZED:
resp = self.unauthorized(resp)
return resp
def _help_on_404(self, message=None):
rules = dict(
[
(RE_RULES.sub("", rule.rule), rule.rule)
for rule in current_app.url_map.iter_rules()
]
)
close_matches = difflib.get_close_matches(request.path, rules.keys())
if close_matches:
# If we already have a message, add punctuation and continue it.
message = "".join(
(
(message.rstrip(".") + ". ") if message else "",
"You have requested this URI [",
request.path,
"] but did you mean ",
" or ".join((rules[match] for match in close_matches)),
" ?",
)
)
return message
def as_postman(self, urlvars=False, swagger=False):
"""
Serialize the API as Postman collection (v1)
:param bool urlvars: whether to include or not placeholders for query strings
:param bool swagger: whether to include or not the swagger.json specifications
"""
return PostmanCollectionV1(self, swagger=swagger).as_dict(urlvars=urlvars)
@property
def payload(self):
"""Store the input payload in the current request context"""
return request.get_json()
@property
def refresolver(self):
if not self._refresolver:
self._refresolver = RefResolver.from_schema(self.__schema__)
return self._refresolver
@staticmethod
def _blueprint_setup_add_url_rule_patch(
blueprint_setup, rule, endpoint=None, view_func=None, **options
):
"""
Method used to patch BlueprintSetupState.add_url_rule for setup
state instance corresponding to this Api instance. Exists primarily
to enable _complete_url's function.
:param blueprint_setup: The BlueprintSetupState instance (self)
:param rule: A string or callable that takes a string and returns a
string(_complete_url) that is the url rule for the endpoint
being registered
:param endpoint: See BlueprintSetupState.add_url_rule
:param view_func: See BlueprintSetupState.add_url_rule
:param **options: See BlueprintSetupState.add_url_rule
"""
if callable(rule):
rule = rule(blueprint_setup.url_prefix)
elif blueprint_setup.url_prefix:
rule = blueprint_setup.url_prefix + rule
options.setdefault("subdomain", blueprint_setup.subdomain)
if endpoint is None:
endpoint = _endpoint_from_view_func(view_func)
defaults = blueprint_setup.url_defaults
if "defaults" in options:
defaults = dict(defaults, **options.pop("defaults"))
blueprint_setup.app.add_url_rule(
rule,
"%s.%s" % (blueprint_setup.blueprint.name, endpoint),
view_func,
defaults=defaults,
**options
)
def _deferred_blueprint_init(self, setup_state):
"""
Synchronize prefix between blueprint/api and registration options, then
perform initialization with setup_state.app :class:`flask.Flask` object.
When a :class:`flask_restx.Api` object is initialized with a blueprint,
this method is recorded on the blueprint to be run when the blueprint is later
registered to a :class:`flask.Flask` object. This method also monkeypatches
BlueprintSetupState.add_url_rule with _blueprint_setup_add_url_rule_patch.
:param setup_state: The setup state object passed to deferred functions
during blueprint registration
:type setup_state: flask.blueprints.BlueprintSetupState
"""
self.blueprint_setup = setup_state
if setup_state.add_url_rule.__name__ != "_blueprint_setup_add_url_rule_patch":
setup_state._original_add_url_rule = setup_state.add_url_rule
setup_state.add_url_rule = MethodType(
Api._blueprint_setup_add_url_rule_patch, setup_state
)
if not setup_state.first_registration:
raise ValueError("flask-restx blueprints can only be registered once.")
self._init_app(setup_state.app)
def mediatypes_method(self):
"""Return a method that returns a list of mediatypes"""
return lambda resource_cls: self.mediatypes() + [self.default_mediatype]
def mediatypes(self):
"""Returns a list of requested mediatypes sent in the Accept header"""
return [
h
for h, q in sorted(
request.accept_mimetypes, key=operator.itemgetter(1), reverse=True
)
]
def representation(self, mediatype):
"""
Allows additional representation transformers to be declared for the
api. Transformers are functions that must be decorated with this
method, passing the mediatype the transformer represents. Three
arguments are passed to the transformer:
* The data to be represented in the response body
* The http status code
* A dictionary of headers
The transformer should convert the data appropriately for the mediatype
and return a Flask response object.
Ex::
@api.representation('application/xml')
def xml(data, code, headers):
resp = make_response(convert_data_to_xml(data), code)
resp.headers.extend(headers)
return resp
"""
def wrapper(func):
self.representations[mediatype] = func
return func
return wrapper
def unauthorized(self, response):
"""Given a response, change it to ask for credentials"""
if self.serve_challenge_on_401:
realm = current_app.config.get("HTTP_BASIC_AUTH_REALM", "flask-restx")
challenge = '{0} realm="{1}"'.format("Basic", realm)
response.headers["WWW-Authenticate"] = challenge
return response
def url_for(self, resource, **values):
"""
Generates a URL to the given resource.
Works like :func:`flask.url_for`.
"""
endpoint = resource.endpoint
if self.blueprint:
endpoint = "{0}.{1}".format(self.blueprint.name, endpoint)
return url_for(endpoint, **values)
class SwaggerView(Resource):
"""Render the Swagger specifications as JSON"""
def get(self):
schema = self.api.__schema__
return (
schema,
HTTPStatus.INTERNAL_SERVER_ERROR if "error" in schema else HTTPStatus.OK,
)
def mediatypes(self):
return ["application/json"]
def mask_parse_error_handler(error):
"""When a mask can't be parsed"""
return {"message": "Mask parse error: {0}".format(error)}, HTTPStatus.BAD_REQUEST
def mask_error_handler(error):
"""When any error occurs on mask"""
return {"message": "Mask error: {0}".format(error)}, HTTPStatus.BAD_REQUEST

View File

@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from flask import url_for, Blueprint, render_template
class Apidoc(Blueprint):
"""
Allow to know if the blueprint has already been registered
until https://github.com/mitsuhiko/flask/pull/1301 is merged
"""
def __init__(self, *args, **kwargs):
self.registered = False
super(Apidoc, self).__init__(*args, **kwargs)
def register(self, *args, **kwargs):
super(Apidoc, self).register(*args, **kwargs)
self.registered = True
apidoc = Apidoc(
"restx_doc",
__name__,
template_folder="templates",
static_folder="static",
static_url_path="/swaggerui",
)
@apidoc.add_app_template_global
def swagger_static(filename):
return url_for("restx_doc.static", filename=filename)
def ui_for(api):
"""Render a SwaggerUI for a given API"""
return render_template("swagger-ui.html", title=api.title, specs_url=api.specs_url)

View File

@ -1,22 +1,32 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from datetime import timedelta
from flask import make_response, request, current_app
from functools import update_wrapper
def crossdomain(origin=None, methods=None, headers=None, expose_headers=None,
max_age=21600, attach_to_all=True,
automatic_options=True, credentials=False):
def crossdomain(
origin=None,
methods=None,
headers=None,
expose_headers=None,
max_age=21600,
attach_to_all=True,
automatic_options=True,
credentials=False,
):
"""
http://flask.pocoo.org/snippets/56/
https://web.archive.org/web/20190128010149/http://flask.pocoo.org/snippets/56/
"""
if methods is not None:
methods = ', '.join(sorted(x.upper() for x in methods))
methods = ", ".join(sorted(x.upper() for x in methods))
if headers is not None and not isinstance(headers, str):
headers = ', '.join(x.upper() for x in headers)
headers = ", ".join(x.upper() for x in headers)
if expose_headers is not None and not isinstance(expose_headers, str):
expose_headers = ', '.join(x.upper() for x in expose_headers)
expose_headers = ", ".join(x.upper() for x in expose_headers)
if not isinstance(origin, str):
origin = ', '.join(origin)
origin = ", ".join(origin)
if isinstance(max_age, timedelta):
max_age = max_age.total_seconds()
@ -25,30 +35,31 @@ def crossdomain(origin=None, methods=None, headers=None, expose_headers=None,
return methods
options_resp = current_app.make_default_options_response()
return options_resp.headers['allow']
return options_resp.headers["allow"]
def decorator(f):
def wrapped_function(*args, **kwargs):
if automatic_options and request.method == 'OPTIONS':
if automatic_options and request.method == "OPTIONS":
resp = current_app.make_default_options_response()
else:
resp = make_response(f(*args, **kwargs))
if not attach_to_all and request.method != 'OPTIONS':
if not attach_to_all and request.method != "OPTIONS":
return resp
h = resp.headers
h['Access-Control-Allow-Origin'] = origin
h['Access-Control-Allow-Methods'] = get_methods()
h['Access-Control-Max-Age'] = str(max_age)
h["Access-Control-Allow-Origin"] = origin
h["Access-Control-Allow-Methods"] = get_methods()
h["Access-Control-Max-Age"] = str(max_age)
if credentials:
h['Access-Control-Allow-Credentials'] = 'true'
h["Access-Control-Allow-Credentials"] = "true"
if headers is not None:
h['Access-Control-Allow-Headers'] = headers
h["Access-Control-Allow-Headers"] = headers
if expose_headers is not None:
h['Access-Control-Expose-Headers'] = expose_headers
h["Access-Control-Expose-Headers"] = expose_headers
return resp
f.provide_automatic_options = False
return update_wrapper(wrapped_function, f)
return decorator

View File

@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import flask
from werkzeug.exceptions import HTTPException
from ._http import HTTPStatus
__all__ = (
"abort",
"RestError",
"ValidationError",
"SpecsError",
)
def abort(code=HTTPStatus.INTERNAL_SERVER_ERROR, message=None, **kwargs):
"""
Properly abort the current request.
Raise a `HTTPException` for the given status `code`.
Attach any keyword arguments to the exception for later processing.
:param int code: The associated HTTP status code
:param str message: An optional details message
:param kwargs: Any additional data to pass to the error payload
:raise HTTPException:
"""
try:
flask.abort(code)
except HTTPException as e:
if message:
kwargs["message"] = str(message)
if kwargs:
e.data = kwargs
raise
class RestError(Exception):
"""Base class for all Flask-RESTX Errors"""
def __init__(self, msg):
self.msg = msg
def __str__(self):
return self.msg
class ValidationError(RestError):
"""A helper class for validation errors."""
pass
class SpecsError(RestError):
"""A helper class for incoherent specifications."""
pass

908
libs/flask_restx/fields.py Normal file
View File

@ -0,0 +1,908 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import re
import fnmatch
import inspect
from calendar import timegm
from datetime import date, datetime
from decimal import Decimal, ROUND_HALF_EVEN
from email.utils import formatdate
from six import iteritems, itervalues, text_type, string_types
from six.moves.urllib.parse import urlparse, urlunparse
from flask import url_for, request
from werkzeug.utils import cached_property
from .inputs import (
date_from_iso8601,
datetime_from_iso8601,
datetime_from_rfc822,
boolean,
)
from .errors import RestError
from .marshalling import marshal
from .utils import camel_to_dash, not_none
__all__ = (
"Raw",
"String",
"FormattedString",
"Url",
"DateTime",
"Date",
"Boolean",
"Integer",
"Float",
"Arbitrary",
"Fixed",
"Nested",
"List",
"ClassName",
"Polymorph",
"Wildcard",
"StringMixin",
"MinMaxMixin",
"NumberMixin",
"MarshallingError",
)
class MarshallingError(RestError):
"""
This is an encapsulating Exception in case of marshalling error.
"""
def __init__(self, underlying_exception):
# just put the contextual representation of the error to hint on what
# went wrong without exposing internals
super(MarshallingError, self).__init__(text_type(underlying_exception))
def is_indexable_but_not_string(obj):
return not hasattr(obj, "strip") and hasattr(obj, "__iter__")
def is_integer_indexable(obj):
return isinstance(obj, list) or isinstance(obj, tuple)
def get_value(key, obj, default=None):
"""Helper for pulling a keyed value off various types of objects"""
if isinstance(key, int):
return _get_value_for_key(key, obj, default)
elif callable(key):
return key(obj)
else:
return _get_value_for_keys(key.split("."), obj, default)
def _get_value_for_keys(keys, obj, default):
if len(keys) == 1:
return _get_value_for_key(keys[0], obj, default)
else:
return _get_value_for_keys(
keys[1:], _get_value_for_key(keys[0], obj, default), default
)
def _get_value_for_key(key, obj, default):
if is_indexable_but_not_string(obj):
try:
return obj[key]
except (IndexError, TypeError, KeyError):
pass
if is_integer_indexable(obj):
try:
return obj[int(key)]
except (IndexError, TypeError, ValueError):
pass
return getattr(obj, key, default)
def to_marshallable_type(obj):
"""
Helper for converting an object to a dictionary only if it is not
dictionary already or an indexable object nor a simple type
"""
if obj is None:
return None # make it idempotent for None
if hasattr(obj, "__marshallable__"):
return obj.__marshallable__()
if hasattr(obj, "__getitem__"):
return obj # it is indexable it is ok
return dict(obj.__dict__)
class Raw(object):
"""
Raw provides a base field class from which others should extend. It
applies no formatting by default, and should only be used in cases where
data does not need to be formatted before being serialized. Fields should
throw a :class:`MarshallingError` in case of parsing problem.
:param default: The default value for the field, if no value is
specified.
:param attribute: If the public facing value differs from the internal
value, use this to retrieve a different attribute from the response
than the publicly named value.
:param str title: The field title (for documentation purpose)
:param str description: The field description (for documentation purpose)
:param bool required: Is the field required ?
:param bool readonly: Is the field read only ? (for documentation purpose)
:param example: An optional data example (for documentation purpose)
:param callable mask: An optional mask function to be applied to output
"""
#: The JSON/Swagger schema type
__schema_type__ = "object"
#: The JSON/Swagger schema format
__schema_format__ = None
#: An optional JSON/Swagger schema example
__schema_example__ = None
def __init__(
self,
default=None,
attribute=None,
title=None,
description=None,
required=None,
readonly=None,
example=None,
mask=None,
**kwargs
):
self.attribute = attribute
self.default = default
self.title = title
self.description = description
self.required = required
self.readonly = readonly
self.example = example if example is not None else self.__schema_example__
self.mask = mask
def format(self, value):
"""
Formats a field's value. No-op by default - field classes that
modify how the value of existing object keys should be presented should
override this and apply the appropriate formatting.
:param value: The value to format
:raises MarshallingError: In case of formatting problem
Ex::
class TitleCase(Raw):
def format(self, value):
return unicode(value).title()
"""
return value
def output(self, key, obj, **kwargs):
"""
Pulls the value for the given key from the object, applies the
field's formatting and returns the result. If the key is not found
in the object, returns the default value. Field classes that create
values which do not require the existence of the key in the object
should override this and return the desired value.
:raises MarshallingError: In case of formatting problem
"""
value = get_value(key if self.attribute is None else self.attribute, obj)
if value is None:
default = self._v("default")
return self.format(default) if default else default
try:
data = self.format(value)
except MarshallingError as e:
msg = 'Unable to marshal field "{0}" value "{1}": {2}'.format(
key, value, str(e)
)
raise MarshallingError(msg)
return self.mask.apply(data) if self.mask else data
def _v(self, key):
"""Helper for getting a value from attribute allowing callable"""
value = getattr(self, key)
return value() if callable(value) else value
@cached_property
def __schema__(self):
return not_none(self.schema())
def schema(self):
return {
"type": self.__schema_type__,
"format": self.__schema_format__,
"title": self.title,
"description": self.description,
"readOnly": self.readonly,
"default": self._v("default"),
"example": self.example,
}
class Nested(Raw):
"""
Allows you to nest one set of fields inside another.
See :ref:`nested-field` for more information
:param dict model: The model dictionary to nest
:param bool allow_null: Whether to return None instead of a dictionary
with null keys, if a nested dictionary has all-null keys
:param bool skip_none: Optional key will be used to eliminate inner fields
which value is None or the inner field's key not
exist in data
:param kwargs: If ``default`` keyword argument is present, a nested
dictionary will be marshaled as its value if nested dictionary is
all-null keys (e.g. lets you return an empty JSON object instead of
null)
"""
__schema_type__ = None
def __init__(
self, model, allow_null=False, skip_none=False, as_list=False, **kwargs
):
self.model = model
self.as_list = as_list
self.allow_null = allow_null
self.skip_none = skip_none
super(Nested, self).__init__(**kwargs)
@property
def nested(self):
return getattr(self.model, "resolved", self.model)
def output(self, key, obj, ordered=False, **kwargs):
value = get_value(key if self.attribute is None else self.attribute, obj)
if value is None:
if self.allow_null:
return None
elif self.default is not None:
return self.default
return marshal(value, self.nested, skip_none=self.skip_none, ordered=ordered)
def schema(self):
schema = super(Nested, self).schema()
ref = "#/definitions/{0}".format(self.nested.name)
if self.as_list:
schema["type"] = "array"
schema["items"] = {"$ref": ref}
elif any(schema.values()):
# There is already some properties in the schema
allOf = schema.get("allOf", [])
allOf.append({"$ref": ref})
schema["allOf"] = allOf
else:
schema["$ref"] = ref
return schema
def clone(self, mask=None):
kwargs = self.__dict__.copy()
model = kwargs.pop("model")
if mask:
model = mask.apply(model.resolved if hasattr(model, "resolved") else model)
return self.__class__(model, **kwargs)
class List(Raw):
"""
Field for marshalling lists of other fields.
See :ref:`list-field` for more information.
:param cls_or_instance: The field type the list will contain.
"""
def __init__(self, cls_or_instance, **kwargs):
self.min_items = kwargs.pop("min_items", None)
self.max_items = kwargs.pop("max_items", None)
self.unique = kwargs.pop("unique", None)
super(List, self).__init__(**kwargs)
error_msg = "The type of the list elements must be a subclass of fields.Raw"
if isinstance(cls_or_instance, type):
if not issubclass(cls_or_instance, Raw):
raise MarshallingError(error_msg)
self.container = cls_or_instance()
else:
if not isinstance(cls_or_instance, Raw):
raise MarshallingError(error_msg)
self.container = cls_or_instance
def format(self, value):
# Convert all instances in typed list to container type
if isinstance(value, set):
value = list(value)
is_nested = isinstance(self.container, Nested) or type(self.container) is Raw
def is_attr(val):
return self.container.attribute and hasattr(val, self.container.attribute)
if value is None:
return []
return [
self.container.output(
idx,
val
if (isinstance(val, dict) or is_attr(val)) and not is_nested
else value,
)
for idx, val in enumerate(value)
]
def output(self, key, data, ordered=False, **kwargs):
value = get_value(key if self.attribute is None else self.attribute, data)
# we cannot really test for external dict behavior
if is_indexable_but_not_string(value) and not isinstance(value, dict):
return self.format(value)
if value is None:
return self._v("default")
return [marshal(value, self.container.nested)]
def schema(self):
schema = super(List, self).schema()
schema.update(
minItems=self._v("min_items"),
maxItems=self._v("max_items"),
uniqueItems=self._v("unique"),
)
schema["type"] = "array"
schema["items"] = self.container.__schema__
return schema
def clone(self, mask=None):
kwargs = self.__dict__.copy()
model = kwargs.pop("container")
if mask:
model = mask.apply(model)
return self.__class__(model, **kwargs)
class StringMixin(object):
__schema_type__ = "string"
def __init__(self, *args, **kwargs):
self.min_length = kwargs.pop("min_length", None)
self.max_length = kwargs.pop("max_length", None)
self.pattern = kwargs.pop("pattern", None)
super(StringMixin, self).__init__(*args, **kwargs)
def schema(self):
schema = super(StringMixin, self).schema()
schema.update(
minLength=self._v("min_length"),
maxLength=self._v("max_length"),
pattern=self._v("pattern"),
)
return schema
class MinMaxMixin(object):
def __init__(self, *args, **kwargs):
self.minimum = kwargs.pop("min", None)
self.exclusiveMinimum = kwargs.pop("exclusiveMin", None)
self.maximum = kwargs.pop("max", None)
self.exclusiveMaximum = kwargs.pop("exclusiveMax", None)
super(MinMaxMixin, self).__init__(*args, **kwargs)
def schema(self):
schema = super(MinMaxMixin, self).schema()
schema.update(
minimum=self._v("minimum"),
exclusiveMinimum=self._v("exclusiveMinimum"),
maximum=self._v("maximum"),
exclusiveMaximum=self._v("exclusiveMaximum"),
)
return schema
class NumberMixin(MinMaxMixin):
__schema_type__ = "number"
def __init__(self, *args, **kwargs):
self.multiple = kwargs.pop("multiple", None)
super(NumberMixin, self).__init__(*args, **kwargs)
def schema(self):
schema = super(NumberMixin, self).schema()
schema.update(multipleOf=self._v("multiple"))
return schema
class String(StringMixin, Raw):
"""
Marshal a value as a string. Uses ``six.text_type`` so values will
be converted to :class:`unicode` in python2 and :class:`str` in
python3.
"""
def __init__(self, *args, **kwargs):
self.enum = kwargs.pop("enum", None)
self.discriminator = kwargs.pop("discriminator", None)
super(String, self).__init__(*args, **kwargs)
self.required = self.discriminator or self.required
def format(self, value):
try:
return text_type(value)
except ValueError as ve:
raise MarshallingError(ve)
def schema(self):
enum = self._v("enum")
schema = super(String, self).schema()
if enum:
schema.update(enum=enum)
if enum and schema["example"] is None:
schema["example"] = enum[0]
return schema
class Integer(NumberMixin, Raw):
"""
Field for outputting an integer value.
:param int default: The default value for the field, if no value is specified.
"""
__schema_type__ = "integer"
def format(self, value):
try:
if value is None:
return self.default
return int(value)
except (ValueError, TypeError) as ve:
raise MarshallingError(ve)
class Float(NumberMixin, Raw):
"""
A double as IEEE-754 double precision.
ex : 3.141592653589793 3.1415926535897933e-06 3.141592653589793e+24 nan inf -inf
"""
def format(self, value):
try:
if value is None:
return self.default
return float(value)
except (ValueError, TypeError) as ve:
raise MarshallingError(ve)
class Arbitrary(NumberMixin, Raw):
"""
A floating point number with an arbitrary precision.
ex: 634271127864378216478362784632784678324.23432
"""
def format(self, value):
return text_type(Decimal(value))
ZERO = Decimal()
class Fixed(NumberMixin, Raw):
"""
A decimal number with a fixed precision.
"""
def __init__(self, decimals=5, **kwargs):
super(Fixed, self).__init__(**kwargs)
self.precision = Decimal("0." + "0" * (decimals - 1) + "1")
def format(self, value):
dvalue = Decimal(value)
if not dvalue.is_normal() and dvalue != ZERO:
raise MarshallingError("Invalid Fixed precision number.")
return text_type(dvalue.quantize(self.precision, rounding=ROUND_HALF_EVEN))
class Boolean(Raw):
"""
Field for outputting a boolean value.
Empty collections such as ``""``, ``{}``, ``[]``, etc. will be converted to ``False``.
"""
__schema_type__ = "boolean"
def format(self, value):
return boolean(value)
class DateTime(MinMaxMixin, Raw):
"""
Return a formatted datetime string in UTC. Supported formats are RFC 822 and ISO 8601.
See :func:`email.utils.formatdate` for more info on the RFC 822 format.
See :meth:`datetime.datetime.isoformat` for more info on the ISO 8601 format.
:param str dt_format: ``rfc822`` or ``iso8601``
"""
__schema_type__ = "string"
__schema_format__ = "date-time"
def __init__(self, dt_format="iso8601", **kwargs):
super(DateTime, self).__init__(**kwargs)
self.dt_format = dt_format
def parse(self, value):
if value is None:
return None
elif isinstance(value, string_types):
parser = (
datetime_from_iso8601
if self.dt_format == "iso8601"
else datetime_from_rfc822
)
return parser(value)
elif isinstance(value, datetime):
return value
elif isinstance(value, date):
return datetime(value.year, value.month, value.day)
else:
raise ValueError("Unsupported DateTime format")
def format(self, value):
try:
value = self.parse(value)
if self.dt_format == "iso8601":
return self.format_iso8601(value)
elif self.dt_format == "rfc822":
return self.format_rfc822(value)
else:
raise MarshallingError("Unsupported date format %s" % self.dt_format)
except (AttributeError, ValueError) as e:
raise MarshallingError(e)
def format_rfc822(self, dt):
"""
Turn a datetime object into a formatted date.
:param datetime dt: The datetime to transform
:return: A RFC 822 formatted date string
"""
return formatdate(timegm(dt.utctimetuple()))
def format_iso8601(self, dt):
"""
Turn a datetime object into an ISO8601 formatted date.
:param datetime dt: The datetime to transform
:return: A ISO 8601 formatted date string
"""
return dt.isoformat()
def _for_schema(self, name):
value = self.parse(self._v(name))
return self.format(value) if value else None
def schema(self):
schema = super(DateTime, self).schema()
schema["default"] = self._for_schema("default")
schema["minimum"] = self._for_schema("minimum")
schema["maximum"] = self._for_schema("maximum")
return schema
class Date(DateTime):
"""
Return a formatted date string in UTC in ISO 8601.
See :meth:`datetime.date.isoformat` for more info on the ISO 8601 format.
"""
__schema_format__ = "date"
def __init__(self, **kwargs):
kwargs.pop("dt_format", None)
super(Date, self).__init__(dt_format="iso8601", **kwargs)
def parse(self, value):
if value is None:
return None
elif isinstance(value, string_types):
return date_from_iso8601(value)
elif isinstance(value, datetime):
return value.date()
elif isinstance(value, date):
return value
else:
raise ValueError("Unsupported Date format")
class Url(StringMixin, Raw):
"""
A string representation of a Url
:param str endpoint: Endpoint name. If endpoint is ``None``, ``request.endpoint`` is used instead
:param bool absolute: If ``True``, ensures that the generated urls will have the hostname included
:param str scheme: URL scheme specifier (e.g. ``http``, ``https``)
"""
def __init__(self, endpoint=None, absolute=False, scheme=None, **kwargs):
super(Url, self).__init__(**kwargs)
self.endpoint = endpoint
self.absolute = absolute
self.scheme = scheme
def output(self, key, obj, **kwargs):
try:
data = to_marshallable_type(obj)
endpoint = self.endpoint if self.endpoint is not None else request.endpoint
o = urlparse(url_for(endpoint, _external=self.absolute, **data))
if self.absolute:
scheme = self.scheme if self.scheme is not None else o.scheme
return urlunparse((scheme, o.netloc, o.path, "", "", ""))
return urlunparse(("", "", o.path, "", "", ""))
except TypeError as te:
raise MarshallingError(te)
class FormattedString(StringMixin, Raw):
"""
FormattedString is used to interpolate other values from
the response into this field. The syntax for the source string is
the same as the string :meth:`~str.format` method from the python
stdlib.
Ex::
fields = {
'name': fields.String,
'greeting': fields.FormattedString("Hello {name}")
}
data = {
'name': 'Doug',
}
marshal(data, fields)
:param str src_str: the string to format with the other values from the response.
"""
def __init__(self, src_str, **kwargs):
super(FormattedString, self).__init__(**kwargs)
self.src_str = text_type(src_str)
def output(self, key, obj, **kwargs):
try:
data = to_marshallable_type(obj)
return self.src_str.format(**data)
except (TypeError, IndexError) as error:
raise MarshallingError(error)
class ClassName(String):
"""
Return the serialized object class name as string.
:param bool dash: If `True`, transform CamelCase to kebab_case.
"""
def __init__(self, dash=False, **kwargs):
super(ClassName, self).__init__(**kwargs)
self.dash = dash
def output(self, key, obj, **kwargs):
classname = obj.__class__.__name__
if classname == "dict":
return "object"
return camel_to_dash(classname) if self.dash else classname
class Polymorph(Nested):
"""
A Nested field handling inheritance.
Allows you to specify a mapping between Python classes and fields specifications.
.. code-block:: python
mapping = {
Child1: child1_fields,
Child2: child2_fields,
}
fields = api.model('Thing', {
owner: fields.Polymorph(mapping)
})
:param dict mapping: Maps classes to their model/fields representation
"""
def __init__(self, mapping, required=False, **kwargs):
self.mapping = mapping
parent = self.resolve_ancestor(list(itervalues(mapping)))
super(Polymorph, self).__init__(parent, allow_null=not required, **kwargs)
def output(self, key, obj, ordered=False, **kwargs):
# Copied from upstream NestedField
value = get_value(key if self.attribute is None else self.attribute, obj)
if value is None:
if self.allow_null:
return None
elif self.default is not None:
return self.default
# Handle mappings
if not hasattr(value, "__class__"):
raise ValueError("Polymorph field only accept class instances")
candidates = [
fields for cls, fields in iteritems(self.mapping) if type(value) == cls
]
if len(candidates) <= 0:
raise ValueError("Unknown class: " + value.__class__.__name__)
elif len(candidates) > 1:
raise ValueError(
"Unable to determine a candidate for: " + value.__class__.__name__
)
else:
return marshal(
value, candidates[0].resolved, mask=self.mask, ordered=ordered
)
def resolve_ancestor(self, models):
"""
Resolve the common ancestor for all models.
Assume there is only one common ancestor.
"""
ancestors = [m.ancestors for m in models]
candidates = set.intersection(*ancestors)
if len(candidates) != 1:
field_names = [f.name for f in models]
raise ValueError(
"Unable to determine the common ancestor for: " + ", ".join(field_names)
)
parent_name = candidates.pop()
return models[0].get_parent(parent_name)
def clone(self, mask=None):
data = self.__dict__.copy()
mapping = data.pop("mapping")
for field in ("allow_null", "model"):
data.pop(field, None)
data["mask"] = mask
return Polymorph(mapping, **data)
class Wildcard(Raw):
"""
Field for marshalling list of "unkown" fields.
:param cls_or_instance: The field type the list will contain.
"""
exclude = set()
# cache the flat object
_flat = None
_obj = None
_cache = set()
_last = None
def __init__(self, cls_or_instance, **kwargs):
super(Wildcard, self).__init__(**kwargs)
error_msg = "The type of the wildcard elements must be a subclass of fields.Raw"
if isinstance(cls_or_instance, type):
if not issubclass(cls_or_instance, Raw):
raise MarshallingError(error_msg)
self.container = cls_or_instance()
else:
if not isinstance(cls_or_instance, Raw):
raise MarshallingError(error_msg)
self.container = cls_or_instance
def _flatten(self, obj):
if obj is None:
return None
if obj == self._obj and self._flat is not None:
return self._flat
if isinstance(obj, dict):
self._flat = [x for x in iteritems(obj)]
else:
def __match_attributes(attribute):
attr_name, attr_obj = attribute
if inspect.isroutine(attr_obj) or (
attr_name.startswith("__") and attr_name.endswith("__")
):
return False
return True
attributes = inspect.getmembers(obj)
self._flat = [x for x in attributes if __match_attributes(x)]
self._cache = set()
self._obj = obj
return self._flat
@property
def key(self):
return self._last
def reset(self):
self.exclude = set()
self._flat = None
self._obj = None
self._cache = set()
self._last = None
def output(self, key, obj, ordered=False):
value = None
reg = fnmatch.translate(key)
if self._flatten(obj):
while True:
try:
# we are using pop() so that we don't
# loop over the whole object every time dropping the
# complexity to O(n)
if ordered:
# Get first element if respecting order
(objkey, val) = self._flat.pop(0)
else:
# Previous default retained
(objkey, val) = self._flat.pop()
if (
objkey not in self._cache
and objkey not in self.exclude
and re.match(reg, objkey, re.IGNORECASE)
):
value = val
self._cache.add(objkey)
self._last = objkey
break
except IndexError:
break
if value is None:
if self.default is not None:
return self.container.format(self.default)
return None
if isinstance(self.container, Nested):
return marshal(
value,
self.container.nested,
skip_none=self.container.skip_none,
ordered=ordered,
)
return self.container.format(value)
def schema(self):
schema = super(Wildcard, self).schema()
schema["type"] = "object"
schema["additionalProperties"] = self.container.__schema__
return schema
def clone(self):
kwargs = self.__dict__.copy()
model = kwargs.pop("container")
return self.__class__(model, **kwargs)

610
libs/flask_restx/inputs.py Normal file
View File

@ -0,0 +1,610 @@
# -*- coding: utf-8 -*-
"""
This module provide some helpers for advanced types parsing.
You can define you own parser using the same pattern:
.. code-block:: python
def my_type(value):
if not condition:
raise ValueError('This is not my type')
return parse(value)
# Swagger documentation
my_type.__schema__ = {'type': 'string', 'format': 'my-custom-format'}
The last line allows you to document properly the type in the Swagger documentation.
"""
from __future__ import unicode_literals
import re
import socket
from datetime import datetime, time, timedelta
from email.utils import parsedate_tz, mktime_tz
from six.moves.urllib.parse import urlparse
import aniso8601
import pytz
# Constants for upgrading date-based intervals to full datetimes.
START_OF_DAY = time(0, 0, 0, tzinfo=pytz.UTC)
END_OF_DAY = time(23, 59, 59, 999999, tzinfo=pytz.UTC)
netloc_regex = re.compile(
r"(?:(?P<auth>[^:@]+?(?::[^:@]*?)?)@)?" # basic auth
r"(?:"
r"(?P<localhost>localhost)|" # localhost...
r"(?P<ipv4>\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})|" # ...or ipv4
r"(?:\[?(?P<ipv6>[A-F0-9]*:[A-F0-9:]+)\]?)|" # ...or ipv6
r"(?P<domain>(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?))" # domain...
r")"
r"(?::(?P<port>\d+))?" # optional port
r"$",
re.IGNORECASE,
)
email_regex = re.compile(
r"^" "(?P<local>[^@]*[^@.])" r"@" r"(?P<server>[^@\.]+(?:\.[^@\.]+)*)" r"$",
re.IGNORECASE,
)
time_regex = re.compile(r"\d{2}:\d{2}")
def ipv4(value):
"""Validate an IPv4 address"""
try:
socket.inet_aton(value)
if value.count(".") == 3:
return value
except socket.error:
pass
raise ValueError("{0} is not a valid ipv4 address".format(value))
ipv4.__schema__ = {"type": "string", "format": "ipv4"}
def ipv6(value):
"""Validate an IPv6 address"""
try:
socket.inet_pton(socket.AF_INET6, value)
return value
except socket.error:
raise ValueError("{0} is not a valid ipv4 address".format(value))
ipv6.__schema__ = {"type": "string", "format": "ipv6"}
def ip(value):
"""Validate an IP address (both IPv4 and IPv6)"""
try:
return ipv4(value)
except ValueError:
pass
try:
return ipv6(value)
except ValueError:
raise ValueError("{0} is not a valid ip".format(value))
ip.__schema__ = {"type": "string", "format": "ip"}
class URL(object):
"""
Validate an URL.
Example::
parser = reqparse.RequestParser()
parser.add_argument('url', type=inputs.URL(schemes=['http', 'https']))
Input to the ``URL`` argument will be rejected
if it does not match an URL with specified constraints.
If ``check`` is True it will also be rejected if the domain does not exists.
:param bool check: Check the domain exists (perform a DNS resolution)
:param bool ip: Allow IP (both ipv4/ipv6) as domain
:param bool local: Allow localhost (both string or ip) as domain
:param bool port: Allow a port to be present
:param bool auth: Allow authentication to be present
:param list|tuple schemes: Restrict valid schemes to this list
:param list|tuple domains: Restrict valid domains to this list
:param list|tuple exclude: Exclude some domains
"""
def __init__(
self,
check=False,
ip=False,
local=False,
port=False,
auth=False,
schemes=None,
domains=None,
exclude=None,
):
self.check = check
self.ip = ip
self.local = local
self.port = port
self.auth = auth
self.schemes = schemes
self.domains = domains
self.exclude = exclude
def error(self, value, details=None):
msg = "{0} is not a valid URL"
if details:
msg = ". ".join((msg, details))
raise ValueError(msg.format(value))
def __call__(self, value):
parsed = urlparse(value)
netloc_match = netloc_regex.match(parsed.netloc)
if not all((parsed.scheme, parsed.netloc)):
if netloc_regex.match(
parsed.netloc or parsed.path.split("/", 1)[0].split("?", 1)[0]
):
self.error(value, "Did you mean: http://{0}")
self.error(value)
if parsed.scheme and self.schemes and parsed.scheme not in self.schemes:
self.error(value, "Protocol is not allowed")
if not netloc_match:
self.error(value)
data = netloc_match.groupdict()
if data["ipv4"] or data["ipv6"]:
if not self.ip:
self.error(value, "IP is not allowed")
else:
try:
ip(data["ipv4"] or data["ipv6"])
except ValueError as e:
self.error(value, str(e))
if not self.local:
if data["ipv4"] and data["ipv4"].startswith("127."):
self.error(value, "Localhost is not allowed")
elif data["ipv6"] == "::1":
self.error(value, "Localhost is not allowed")
if self.check:
pass
if data["auth"] and not self.auth:
self.error(value, "Authentication is not allowed")
if data["localhost"] and not self.local:
self.error(value, "Localhost is not allowed")
if data["port"]:
if not self.port:
self.error(value, "Custom port is not allowed")
else:
port = int(data["port"])
if not 0 < port < 65535:
self.error(value, "Port is out of range")
if data["domain"]:
if self.domains and data["domain"] not in self.domains:
self.error(value, "Domain is not allowed")
elif self.exclude and data["domain"] in self.exclude:
self.error(value, "Domain is not allowed")
if self.check:
try:
socket.getaddrinfo(data["domain"], None)
except socket.error:
self.error(value, "Domain does not exists")
return value
@property
def __schema__(self):
return {
"type": "string",
"format": "url",
}
#: Validate an URL
#:
#: Legacy validator, allows, auth, port, ip and local
#: Only allows schemes 'http', 'https', 'ftp' and 'ftps'
url = URL(
ip=True, auth=True, port=True, local=True, schemes=("http", "https", "ftp", "ftps")
)
class email(object):
"""
Validate an email.
Example::
parser = reqparse.RequestParser()
parser.add_argument('email', type=inputs.email(dns=True))
Input to the ``email`` argument will be rejected if it does not match an email
and if domain does not exists.
:param bool check: Check the domain exists (perform a DNS resolution)
:param bool ip: Allow IP (both ipv4/ipv6) as domain
:param bool local: Allow localhost (both string or ip) as domain
:param list|tuple domains: Restrict valid domains to this list
:param list|tuple exclude: Exclude some domains
"""
def __init__(self, check=False, ip=False, local=False, domains=None, exclude=None):
self.check = check
self.ip = ip
self.local = local
self.domains = domains
self.exclude = exclude
def error(self, value, msg=None):
msg = msg or "{0} is not a valid email"
raise ValueError(msg.format(value))
def is_ip(self, value):
try:
ip(value)
return True
except ValueError:
return False
def __call__(self, value):
match = email_regex.match(value)
if not match or ".." in value:
self.error(value)
server = match.group("server")
if self.check:
try:
socket.getaddrinfo(server, None)
except socket.error:
self.error(value)
if self.domains and server not in self.domains:
self.error(value, "{0} does not belong to the authorized domains")
if self.exclude and server in self.exclude:
self.error(value, "{0} belongs to a forbidden domain")
if not self.local and (
server in ("localhost", "::1") or server.startswith("127.")
):
self.error(value)
if self.is_ip(server) and not self.ip:
self.error(value)
return value
@property
def __schema__(self):
return {
"type": "string",
"format": "email",
}
class regex(object):
"""
Validate a string based on a regular expression.
Example::
parser = reqparse.RequestParser()
parser.add_argument('example', type=inputs.regex('^[0-9]+$'))
Input to the ``example`` argument will be rejected if it contains anything
but numbers.
:param str pattern: The regular expression the input must match
"""
def __init__(self, pattern):
self.pattern = pattern
self.re = re.compile(pattern)
def __call__(self, value):
if not self.re.search(value):
message = 'Value does not match pattern: "{0}"'.format(self.pattern)
raise ValueError(message)
return value
def __deepcopy__(self, memo):
return regex(self.pattern)
@property
def __schema__(self):
return {
"type": "string",
"pattern": self.pattern,
}
def _normalize_interval(start, end, value):
"""
Normalize datetime intervals.
Given a pair of datetime.date or datetime.datetime objects,
returns a 2-tuple of tz-aware UTC datetimes spanning the same interval.
For datetime.date objects, the returned interval starts at 00:00:00.0
on the first date and ends at 00:00:00.0 on the second.
Naive datetimes are upgraded to UTC.
Timezone-aware datetimes are normalized to the UTC tzdata.
Params:
- start: A date or datetime
- end: A date or datetime
"""
if not isinstance(start, datetime):
start = datetime.combine(start, START_OF_DAY)
end = datetime.combine(end, START_OF_DAY)
if start.tzinfo is None:
start = pytz.UTC.localize(start)
end = pytz.UTC.localize(end)
else:
start = start.astimezone(pytz.UTC)
end = end.astimezone(pytz.UTC)
return start, end
def _expand_datetime(start, value):
if not isinstance(start, datetime):
# Expand a single date object to be the interval spanning
# that entire day.
end = start + timedelta(days=1)
else:
# Expand a datetime based on the finest resolution provided
# in the original input string.
time = value.split("T")[1]
time_without_offset = re.sub("[+-].+", "", time)
num_separators = time_without_offset.count(":")
if num_separators == 0:
# Hour resolution
end = start + timedelta(hours=1)
elif num_separators == 1:
# Minute resolution:
end = start + timedelta(minutes=1)
else:
# Second resolution
end = start + timedelta(seconds=1)
return end
def _parse_interval(value):
"""
Do some nasty try/except voodoo to get some sort of datetime
object(s) out of the string.
"""
try:
return sorted(aniso8601.parse_interval(value))
except ValueError:
try:
return aniso8601.parse_datetime(value), None
except ValueError:
return aniso8601.parse_date(value), None
def iso8601interval(value, argument="argument"):
"""
Parses ISO 8601-formatted datetime intervals into tuples of datetimes.
Accepts both a single date(time) or a full interval using either start/end
or start/duration notation, with the following behavior:
- Intervals are defined as inclusive start, exclusive end
- Single datetimes are translated into the interval spanning the
largest resolution not specified in the input value, up to the day.
- The smallest accepted resolution is 1 second.
- All timezones are accepted as values; returned datetimes are
localized to UTC. Naive inputs and date inputs will are assumed UTC.
Examples::
"2013-01-01" -> datetime(2013, 1, 1), datetime(2013, 1, 2)
"2013-01-01T12" -> datetime(2013, 1, 1, 12), datetime(2013, 1, 1, 13)
"2013-01-01/2013-02-28" -> datetime(2013, 1, 1), datetime(2013, 2, 28)
"2013-01-01/P3D" -> datetime(2013, 1, 1), datetime(2013, 1, 4)
"2013-01-01T12:00/PT30M" -> datetime(2013, 1, 1, 12), datetime(2013, 1, 1, 12, 30)
"2013-01-01T06:00/2013-01-01T12:00" -> datetime(2013, 1, 1, 6), datetime(2013, 1, 1, 12)
:param str value: The ISO8601 date time as a string
:return: Two UTC datetimes, the start and the end of the specified interval
:rtype: A tuple (datetime, datetime)
:raises ValueError: if the interval is invalid.
"""
if not value:
raise ValueError("Expected a valid ISO8601 date/time interval.")
try:
start, end = _parse_interval(value)
if end is None:
end = _expand_datetime(start, value)
start, end = _normalize_interval(start, end, value)
except ValueError:
msg = (
"Invalid {arg}: {value}. {arg} must be a valid ISO8601 date/time interval."
)
raise ValueError(msg.format(arg=argument, value=value))
return start, end
iso8601interval.__schema__ = {"type": "string", "format": "iso8601-interval"}
def date(value):
"""Parse a valid looking date in the format YYYY-mm-dd"""
date = datetime.strptime(value, "%Y-%m-%d")
return date
date.__schema__ = {"type": "string", "format": "date"}
def _get_integer(value):
try:
return int(value)
except (TypeError, ValueError):
raise ValueError("{0} is not a valid integer".format(value))
def natural(value, argument="argument"):
"""Restrict input type to the natural numbers (0, 1, 2, 3...)"""
value = _get_integer(value)
if value < 0:
msg = "Invalid {arg}: {value}. {arg} must be a non-negative integer"
raise ValueError(msg.format(arg=argument, value=value))
return value
natural.__schema__ = {"type": "integer", "minimum": 0}
def positive(value, argument="argument"):
"""Restrict input type to the positive integers (1, 2, 3...)"""
value = _get_integer(value)
if value < 1:
msg = "Invalid {arg}: {value}. {arg} must be a positive integer"
raise ValueError(msg.format(arg=argument, value=value))
return value
positive.__schema__ = {"type": "integer", "minimum": 0, "exclusiveMinimum": True}
class int_range(object):
"""Restrict input to an integer in a range (inclusive)"""
def __init__(self, low, high, argument="argument"):
self.low = low
self.high = high
self.argument = argument
def __call__(self, value):
value = _get_integer(value)
if value < self.low or value > self.high:
msg = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}"
raise ValueError(
msg.format(arg=self.argument, val=value, lo=self.low, hi=self.high)
)
return value
@property
def __schema__(self):
return {
"type": "integer",
"minimum": self.low,
"maximum": self.high,
}
def boolean(value):
"""
Parse the string ``"true"`` or ``"false"`` as a boolean (case insensitive).
Also accepts ``"1"`` and ``"0"`` as ``True``/``False`` (respectively).
If the input is from the request JSON body, the type is already a native python boolean,
and will be passed through without further parsing.
:raises ValueError: if the boolean value is invalid
"""
if isinstance(value, bool):
return value
if value is None:
raise ValueError("boolean type must be non-null")
elif not value:
return False
value = str(value).lower()
if value in ("true", "1", "on",):
return True
if value in ("false", "0",):
return False
raise ValueError("Invalid literal for boolean(): {0}".format(value))
boolean.__schema__ = {"type": "boolean"}
def datetime_from_rfc822(value):
"""
Turns an RFC822 formatted date into a datetime object.
Example::
inputs.datetime_from_rfc822('Wed, 02 Oct 2002 08:00:00 EST')
:param str value: The RFC822-complying string to transform
:return: The parsed datetime
:rtype: datetime
:raises ValueError: if value is an invalid date literal
"""
raw = value
if not time_regex.search(value):
value = " ".join((value, "00:00:00"))
try:
timetuple = parsedate_tz(value)
timestamp = mktime_tz(timetuple)
if timetuple[-1] is None:
return datetime.fromtimestamp(timestamp).replace(tzinfo=pytz.utc)
else:
return datetime.fromtimestamp(timestamp, pytz.utc)
except Exception:
raise ValueError('Invalid date literal "{0}"'.format(raw))
def datetime_from_iso8601(value):
"""
Turns an ISO8601 formatted date into a datetime object.
Example::
inputs.datetime_from_iso8601("2012-01-01T23:30:00+02:00")
:param str value: The ISO8601-complying string to transform
:return: A datetime
:rtype: datetime
:raises ValueError: if value is an invalid date literal
"""
try:
try:
return aniso8601.parse_datetime(value)
except ValueError:
date = aniso8601.parse_date(value)
return datetime(date.year, date.month, date.day)
except Exception:
raise ValueError('Invalid date literal "{0}"'.format(value))
datetime_from_iso8601.__schema__ = {"type": "string", "format": "date-time"}
def date_from_iso8601(value):
"""
Turns an ISO8601 formatted date into a date object.
Example::
inputs.date_from_iso8601("2012-01-01")
:param str value: The ISO8601-complying string to transform
:return: A date
:rtype: date
:raises ValueError: if value is an invalid date literal
"""
return datetime_from_iso8601(value).date()
date_from_iso8601.__schema__ = {"type": "string", "format": "date"}

View File

@ -0,0 +1,309 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from collections import OrderedDict
from functools import wraps
from six import iteritems
from flask import request, current_app, has_app_context
from .mask import Mask, apply as apply_mask
from .utils import unpack
def make(cls):
if isinstance(cls, type):
return cls()
return cls
def marshal(data, fields, envelope=None, skip_none=False, mask=None, ordered=False):
"""Takes raw data (in the form of a dict, list, object) and a dict of
fields to output and filters the data based on those fields.
:param data: the actual object(s) from which the fields are taken from
:param fields: a dict of whose keys will make up the final serialized
response output
:param envelope: optional key that will be used to envelop the serialized
response
:param bool skip_none: optional key will be used to eliminate fields
which value is None or the field's key not
exist in data
:param bool ordered: Wether or not to preserve order
>>> from flask_restx import fields, marshal
>>> data = { 'a': 100, 'b': 'foo', 'c': None }
>>> mfields = { 'a': fields.Raw, 'c': fields.Raw, 'd': fields.Raw }
>>> marshal(data, mfields)
{'a': 100, 'c': None, 'd': None}
>>> marshal(data, mfields, envelope='data')
{'data': {'a': 100, 'c': None, 'd': None}}
>>> marshal(data, mfields, skip_none=True)
{'a': 100}
>>> marshal(data, mfields, ordered=True)
OrderedDict([('a', 100), ('c', None), ('d', None)])
>>> marshal(data, mfields, envelope='data', ordered=True)
OrderedDict([('data', OrderedDict([('a', 100), ('c', None), ('d', None)]))])
>>> marshal(data, mfields, skip_none=True, ordered=True)
OrderedDict([('a', 100)])
"""
out, has_wildcards = _marshal(data, fields, envelope, skip_none, mask, ordered)
if has_wildcards:
# ugly local import to avoid dependency loop
from .fields import Wildcard
items = []
keys = []
for dkey, val in fields.items():
key = dkey
if isinstance(val, dict):
value = marshal(data, val, skip_none=skip_none, ordered=ordered)
else:
field = make(val)
is_wildcard = isinstance(field, Wildcard)
# exclude already parsed keys from the wildcard
if is_wildcard:
field.reset()
if keys:
field.exclude |= set(keys)
keys = []
value = field.output(dkey, data, ordered=ordered)
if is_wildcard:
def _append(k, v):
if skip_none and (v is None or v == OrderedDict() or v == {}):
return
items.append((k, v))
key = field.key or dkey
_append(key, value)
while True:
value = field.output(dkey, data, ordered=ordered)
if value is None or value == field.container.format(
field.default
):
break
key = field.key
_append(key, value)
continue
keys.append(key)
if skip_none and (value is None or value == OrderedDict() or value == {}):
continue
items.append((key, value))
items = tuple(items)
out = OrderedDict(items) if ordered else dict(items)
if envelope:
out = OrderedDict([(envelope, out)]) if ordered else {envelope: out}
return out
return out
def _marshal(data, fields, envelope=None, skip_none=False, mask=None, ordered=False):
"""Takes raw data (in the form of a dict, list, object) and a dict of
fields to output and filters the data based on those fields.
:param data: the actual object(s) from which the fields are taken from
:param fields: a dict of whose keys will make up the final serialized
response output
:param envelope: optional key that will be used to envelop the serialized
response
:param bool skip_none: optional key will be used to eliminate fields
which value is None or the field's key not
exist in data
:param bool ordered: Wether or not to preserve order
>>> from flask_restx import fields, marshal
>>> data = { 'a': 100, 'b': 'foo', 'c': None }
>>> mfields = { 'a': fields.Raw, 'c': fields.Raw, 'd': fields.Raw }
>>> marshal(data, mfields)
{'a': 100, 'c': None, 'd': None}
>>> marshal(data, mfields, envelope='data')
{'data': {'a': 100, 'c': None, 'd': None}}
>>> marshal(data, mfields, skip_none=True)
{'a': 100}
>>> marshal(data, mfields, ordered=True)
OrderedDict([('a', 100), ('c', None), ('d', None)])
>>> marshal(data, mfields, envelope='data', ordered=True)
OrderedDict([('data', OrderedDict([('a', 100), ('c', None), ('d', None)]))])
>>> marshal(data, mfields, skip_none=True, ordered=True)
OrderedDict([('a', 100)])
"""
# ugly local import to avoid dependency loop
from .fields import Wildcard
mask = mask or getattr(fields, "__mask__", None)
fields = getattr(fields, "resolved", fields)
if mask:
fields = apply_mask(fields, mask, skip=True)
if isinstance(data, (list, tuple)):
out = [marshal(d, fields, skip_none=skip_none, ordered=ordered) for d in data]
if envelope:
out = OrderedDict([(envelope, out)]) if ordered else {envelope: out}
return out, False
has_wildcards = {"present": False}
def __format_field(key, val):
field = make(val)
if isinstance(field, Wildcard):
has_wildcards["present"] = True
value = field.output(key, data, ordered=ordered)
return (key, value)
items = (
(k, marshal(data, v, skip_none=skip_none, ordered=ordered))
if isinstance(v, dict)
else __format_field(k, v)
for k, v in iteritems(fields)
)
if skip_none:
items = (
(k, v) for k, v in items if v is not None and v != OrderedDict() and v != {}
)
out = OrderedDict(items) if ordered else dict(items)
if envelope:
out = OrderedDict([(envelope, out)]) if ordered else {envelope: out}
return out, has_wildcards["present"]
class marshal_with(object):
"""A decorator that apply marshalling to the return values of your methods.
>>> from flask_restx import fields, marshal_with
>>> mfields = { 'a': fields.Raw }
>>> @marshal_with(mfields)
... def get():
... return { 'a': 100, 'b': 'foo' }
...
...
>>> get()
OrderedDict([('a', 100)])
>>> @marshal_with(mfields, envelope='data')
... def get():
... return { 'a': 100, 'b': 'foo' }
...
...
>>> get()
OrderedDict([('data', OrderedDict([('a', 100)]))])
>>> mfields = { 'a': fields.Raw, 'c': fields.Raw, 'd': fields.Raw }
>>> @marshal_with(mfields, skip_none=True)
... def get():
... return { 'a': 100, 'b': 'foo', 'c': None }
...
...
>>> get()
OrderedDict([('a', 100)])
see :meth:`flask_restx.marshal`
"""
def __init__(
self, fields, envelope=None, skip_none=False, mask=None, ordered=False
):
"""
:param fields: a dict of whose keys will make up the final
serialized response output
:param envelope: optional key that will be used to envelop the serialized
response
"""
self.fields = fields
self.envelope = envelope
self.skip_none = skip_none
self.ordered = ordered
self.mask = Mask(mask, skip=True)
def __call__(self, f):
@wraps(f)
def wrapper(*args, **kwargs):
resp = f(*args, **kwargs)
mask = self.mask
if has_app_context():
mask_header = current_app.config["RESTX_MASK_HEADER"]
mask = request.headers.get(mask_header) or mask
if isinstance(resp, tuple):
data, code, headers = unpack(resp)
return (
marshal(
data,
self.fields,
self.envelope,
self.skip_none,
mask,
self.ordered,
),
code,
headers,
)
else:
return marshal(
resp, self.fields, self.envelope, self.skip_none, mask, self.ordered
)
return wrapper
class marshal_with_field(object):
"""
A decorator that formats the return values of your methods with a single field.
>>> from flask_restx import marshal_with_field, fields
>>> @marshal_with_field(fields.List(fields.Integer))
... def get():
... return ['1', 2, 3.0]
...
>>> get()
[1, 2, 3]
see :meth:`flask_restx.marshal_with`
"""
def __init__(self, field):
"""
:param field: a single field with which to marshal the output.
"""
if isinstance(field, type):
self.field = field()
else:
self.field = field
def __call__(self, f):
@wraps(f)
def wrapper(*args, **kwargs):
resp = f(*args, **kwargs)
if isinstance(resp, tuple):
data, code, headers = unpack(resp)
return self.field.format(data), code, headers
return self.field.format(resp)
return wrapper

191
libs/flask_restx/mask.py Normal file
View File

@ -0,0 +1,191 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import
import logging
import re
import six
from collections import OrderedDict
from inspect import isclass
from .errors import RestError
log = logging.getLogger(__name__)
LEXER = re.compile(r"\{|\}|\,|[\w_:\-\*]+")
class MaskError(RestError):
"""Raised when an error occurs on mask"""
pass
class ParseError(MaskError):
"""Raised when the mask parsing failed"""
pass
class Mask(OrderedDict):
"""
Hold a parsed mask.
:param str|dict|Mask mask: A mask, parsed or not
:param bool skip: If ``True``, missing fields won't appear in result
"""
def __init__(self, mask=None, skip=False, **kwargs):
self.skip = skip
if isinstance(mask, six.string_types):
super(Mask, self).__init__()
self.parse(mask)
elif isinstance(mask, (dict, OrderedDict)):
super(Mask, self).__init__(mask, **kwargs)
else:
self.skip = skip
super(Mask, self).__init__(**kwargs)
def parse(self, mask):
"""
Parse a fields mask.
Expect something in the form::
{field,nested{nested_field,another},last}
External brackets are optionals so it can also be written::
field,nested{nested_field,another},last
All extras characters will be ignored.
:param str mask: the mask string to parse
:raises ParseError: when a mask is unparseable/invalid
"""
if not mask:
return
mask = self.clean(mask)
fields = self
previous = None
stack = []
for token in LEXER.findall(mask):
if token == "{":
if previous not in fields:
raise ParseError("Unexpected opening bracket")
fields[previous] = Mask(skip=self.skip)
stack.append(fields)
fields = fields[previous]
elif token == "}":
if not stack:
raise ParseError("Unexpected closing bracket")
fields = stack.pop()
elif token == ",":
if previous in (",", "{", None):
raise ParseError("Unexpected comma")
else:
fields[token] = True
previous = token
if stack:
raise ParseError("Missing closing bracket")
def clean(self, mask):
"""Remove unnecessary characters"""
mask = mask.replace("\n", "").strip()
# External brackets are optional
if mask[0] == "{":
if mask[-1] != "}":
raise ParseError("Missing closing bracket")
mask = mask[1:-1]
return mask
def apply(self, data):
"""
Apply a fields mask to the data.
:param data: The data or model to apply mask on
:raises MaskError: when unable to apply the mask
"""
from . import fields
# Should handle lists
if isinstance(data, (list, tuple, set)):
return [self.apply(d) for d in data]
elif isinstance(data, (fields.Nested, fields.List, fields.Polymorph)):
return data.clone(self)
elif type(data) == fields.Raw:
return fields.Raw(default=data.default, attribute=data.attribute, mask=self)
elif data == fields.Raw:
return fields.Raw(mask=self)
elif (
isinstance(data, fields.Raw)
or isclass(data)
and issubclass(data, fields.Raw)
):
# Not possible to apply a mask on these remaining fields types
raise MaskError("Mask is inconsistent with model")
# Should handle objects
elif not isinstance(data, (dict, OrderedDict)) and hasattr(data, "__dict__"):
data = data.__dict__
return self.filter_data(data)
def filter_data(self, data):
"""
Handle the data filtering given a parsed mask
:param dict data: the raw data to filter
:param list mask: a parsed mask to filter against
:param bool skip: whether or not to skip missing fields
"""
out = {}
for field, content in six.iteritems(self):
if field == "*":
continue
elif isinstance(content, Mask):
nested = data.get(field, None)
if self.skip and nested is None:
continue
elif nested is None:
out[field] = None
else:
out[field] = content.apply(nested)
elif self.skip and field not in data:
continue
else:
out[field] = data.get(field, None)
if "*" in self.keys():
for key, value in six.iteritems(data):
if key not in out:
out[key] = value
return out
def __str__(self):
return "{{{0}}}".format(
",".join(
[
"".join((k, str(v))) if isinstance(v, Mask) else k
for k, v in six.iteritems(self)
]
)
)
def apply(data, mask, skip=False):
"""
Apply a fields mask to the data.
:param data: The data or model to apply mask on
:param str|Mask mask: the mask (parsed or not) to apply on data
:param bool skip: If rue, missing field won't appear in result
:raises MaskError: when unable to apply the mask
"""
return Mask(mask, skip).apply(data)

295
libs/flask_restx/model.py Normal file
View File

@ -0,0 +1,295 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import copy
import re
import warnings
from collections import OrderedDict
try:
from collections.abc import MutableMapping
except ImportError:
# TODO Remove this to drop Python2 support
from collections import MutableMapping
from six import iteritems, itervalues
from werkzeug.utils import cached_property
from .mask import Mask
from .errors import abort
from jsonschema import Draft4Validator
from jsonschema.exceptions import ValidationError
from .utils import not_none
from ._http import HTTPStatus
RE_REQUIRED = re.compile(r"u?\'(?P<name>.*)\' is a required property", re.I | re.U)
def instance(cls):
if isinstance(cls, type):
return cls()
return cls
class ModelBase(object):
"""
Handles validation and swagger style inheritance for both subclasses.
Subclass must define `schema` attribute.
:param str name: The model public name
"""
def __init__(self, name, *args, **kwargs):
super(ModelBase, self).__init__(*args, **kwargs)
self.__apidoc__ = {"name": name}
self.name = name
self.__parents__ = []
def instance_inherit(name, *parents):
return self.__class__.inherit(name, self, *parents)
self.inherit = instance_inherit
@property
def ancestors(self):
"""
Return the ancestors tree
"""
ancestors = [p.ancestors for p in self.__parents__]
return set.union(set([self.name]), *ancestors)
def get_parent(self, name):
if self.name == name:
return self
else:
for parent in self.__parents__:
found = parent.get_parent(name)
if found:
return found
raise ValueError("Parent " + name + " not found")
@property
def __schema__(self):
schema = self._schema
if self.__parents__:
refs = [
{"$ref": "#/definitions/{0}".format(parent.name)}
for parent in self.__parents__
]
return {"allOf": refs + [schema]}
else:
return schema
@classmethod
def inherit(cls, name, *parents):
"""
Inherit this model (use the Swagger composition pattern aka. allOf)
:param str name: The new model name
:param dict fields: The new model extra fields
"""
model = cls(name, parents[-1])
model.__parents__ = parents[:-1]
return model
def validate(self, data, resolver=None, format_checker=None):
validator = Draft4Validator(
self.__schema__, resolver=resolver, format_checker=format_checker
)
try:
validator.validate(data)
except ValidationError:
abort(
HTTPStatus.BAD_REQUEST,
message="Input payload validation failed",
errors=dict(self.format_error(e) for e in validator.iter_errors(data)),
)
def format_error(self, error):
path = list(error.path)
if error.validator == "required":
name = RE_REQUIRED.match(error.message).group("name")
path.append(name)
key = ".".join(str(p) for p in path)
return key, error.message
def __unicode__(self):
return "Model({name},{{{fields}}})".format(
name=self.name, fields=",".join(self.keys())
)
__str__ = __unicode__
class RawModel(ModelBase):
"""
A thin wrapper on ordered fields dict to store API doc metadata.
Can also be used for response marshalling.
:param str name: The model public name
:param str mask: an optional default model mask
:param bool strict: validation should raise error when there is param not provided in schema
"""
wrapper = dict
def __init__(self, name, *args, **kwargs):
self.__mask__ = kwargs.pop("mask", None)
self.__strict__ = kwargs.pop("strict", False)
if self.__mask__ and not isinstance(self.__mask__, Mask):
self.__mask__ = Mask(self.__mask__)
super(RawModel, self).__init__(name, *args, **kwargs)
def instance_clone(name, *parents):
return self.__class__.clone(name, self, *parents)
self.clone = instance_clone
@property
def _schema(self):
properties = self.wrapper()
required = set()
discriminator = None
for name, field in iteritems(self):
field = instance(field)
properties[name] = field.__schema__
if field.required:
required.add(name)
if getattr(field, "discriminator", False):
discriminator = name
definition = {
"required": sorted(list(required)) or None,
"properties": properties,
"discriminator": discriminator,
"x-mask": str(self.__mask__) if self.__mask__ else None,
"type": "object",
}
if self.__strict__:
definition['additionalProperties'] = False
return not_none(definition)
@cached_property
def resolved(self):
"""
Resolve real fields before submitting them to marshal
"""
# Duplicate fields
resolved = copy.deepcopy(self)
# Recursively copy parent fields if necessary
for parent in self.__parents__:
resolved.update(parent.resolved)
# Handle discriminator
candidates = [
f for f in itervalues(resolved) if getattr(f, "discriminator", None)
]
# Ensure the is only one discriminator
if len(candidates) > 1:
raise ValueError("There can only be one discriminator by schema")
# Ensure discriminator always output the model name
elif len(candidates) == 1:
candidates[0].default = self.name
return resolved
def extend(self, name, fields):
"""
Extend this model (Duplicate all fields)
:param str name: The new model name
:param dict fields: The new model extra fields
:deprecated: since 0.9. Use :meth:`clone` instead.
"""
warnings.warn(
"extend is is deprecated, use clone instead",
DeprecationWarning,
stacklevel=2,
)
if isinstance(fields, (list, tuple)):
return self.clone(name, *fields)
else:
return self.clone(name, fields)
@classmethod
def clone(cls, name, *parents):
"""
Clone these models (Duplicate all fields)
It can be used from the class
>>> model = Model.clone(fields_1, fields_2)
or from an Instanciated model
>>> new_model = model.clone(fields_1, fields_2)
:param str name: The new model name
:param dict parents: The new model extra fields
"""
fields = cls.wrapper()
for parent in parents:
fields.update(copy.deepcopy(parent))
return cls(name, fields)
def __deepcopy__(self, memo):
obj = self.__class__(
self.name,
[(key, copy.deepcopy(value, memo)) for key, value in iteritems(self)],
mask=self.__mask__,
strict=self.__strict__,
)
obj.__parents__ = self.__parents__
return obj
class Model(RawModel, dict, MutableMapping):
"""
A thin wrapper on fields dict to store API doc metadata.
Can also be used for response marshalling.
:param str name: The model public name
:param str mask: an optional default model mask
"""
pass
class OrderedModel(RawModel, OrderedDict, MutableMapping):
"""
A thin wrapper on ordered fields dict to store API doc metadata.
Can also be used for response marshalling.
:param str name: The model public name
:param str mask: an optional default model mask
"""
wrapper = OrderedDict
class SchemaModel(ModelBase):
"""
Stores API doc metadata based on a json schema.
:param str name: The model public name
:param dict schema: The json schema we are documenting
"""
def __init__(self, name, schema=None):
super(SchemaModel, self).__init__(name)
self._schema = schema or {}
def __unicode__(self):
return "SchemaModel({name},{schema})".format(
name=self.name, schema=self._schema
)
__str__ = __unicode__

View File

@ -0,0 +1,379 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
import inspect
import warnings
import logging
from collections import namedtuple, OrderedDict
import six
from flask import request
from flask.views import http_method_funcs
from ._http import HTTPStatus
from .errors import abort
from .marshalling import marshal, marshal_with
from .model import Model, OrderedModel, SchemaModel
from .reqparse import RequestParser
from .utils import merge
# Container for each route applied to a Resource using @ns.route decorator
ResourceRoute = namedtuple("ResourceRoute", "resource urls route_doc kwargs")
class Namespace(object):
"""
Group resources together.
Namespace is to API what :class:`flask:flask.Blueprint` is for :class:`flask:flask.Flask`.
:param str name: The namespace name
:param str description: An optional short description
:param str path: An optional prefix path. If not provided, prefix is ``/+name``
:param list decorators: A list of decorators to apply to each resources
:param bool validate: Whether or not to perform validation on this namespace
:param bool ordered: Whether or not to preserve order on models and marshalling
:param Api api: an optional API to attache to the namespace
"""
def __init__(
self,
name,
description=None,
path=None,
decorators=None,
validate=None,
authorizations=None,
ordered=False,
**kwargs
):
self.name = name
self.description = description
self._path = path
self._schema = None
self._validate = validate
self.models = {}
self.urls = {}
self.decorators = decorators if decorators else []
self.resources = [] # List[ResourceRoute]
self.error_handlers = OrderedDict()
self.default_error_handler = None
self.authorizations = authorizations
self.ordered = ordered
self.apis = []
if "api" in kwargs:
self.apis.append(kwargs["api"])
self.logger = logging.getLogger(__name__ + "." + self.name)
@property
def path(self):
return (self._path or ("/" + self.name)).rstrip("/")
def add_resource(self, resource, *urls, **kwargs):
"""
Register a Resource for a given API Namespace
:param Resource resource: the resource ro register
:param str urls: one or more url routes to match for the resource,
standard flask routing rules apply.
Any url variables will be passed to the resource method as args.
:param str endpoint: endpoint name (defaults to :meth:`Resource.__name__.lower`
Can be used to reference this route in :class:`fields.Url` fields
:param list|tuple resource_class_args: args to be forwarded to the constructor of the resource.
:param dict resource_class_kwargs: kwargs to be forwarded to the constructor of the resource.
Additional keyword arguments not specified above will be passed as-is
to :meth:`flask.Flask.add_url_rule`.
Examples::
namespace.add_resource(HelloWorld, '/', '/hello')
namespace.add_resource(Foo, '/foo', endpoint="foo")
namespace.add_resource(FooSpecial, '/special/foo', endpoint="foo")
"""
route_doc = kwargs.pop("route_doc", {})
self.resources.append(ResourceRoute(resource, urls, route_doc, kwargs))
for api in self.apis:
ns_urls = api.ns_urls(self, urls)
api.register_resource(self, resource, *ns_urls, **kwargs)
def route(self, *urls, **kwargs):
"""
A decorator to route resources.
"""
def wrapper(cls):
doc = kwargs.pop("doc", None)
if doc is not None:
# build api doc intended only for this route
kwargs["route_doc"] = self._build_doc(cls, doc)
self.add_resource(cls, *urls, **kwargs)
return cls
return wrapper
def _build_doc(self, cls, doc):
if doc is False:
return False
unshortcut_params_description(doc)
handle_deprecations(doc)
for http_method in http_method_funcs:
if http_method in doc:
if doc[http_method] is False:
continue
unshortcut_params_description(doc[http_method])
handle_deprecations(doc[http_method])
if "expect" in doc[http_method] and not isinstance(
doc[http_method]["expect"], (list, tuple)
):
doc[http_method]["expect"] = [doc[http_method]["expect"]]
return merge(getattr(cls, "__apidoc__", {}), doc)
def doc(self, shortcut=None, **kwargs):
"""A decorator to add some api documentation to the decorated object"""
if isinstance(shortcut, six.text_type):
kwargs["id"] = shortcut
show = shortcut if isinstance(shortcut, bool) else True
def wrapper(documented):
documented.__apidoc__ = self._build_doc(
documented, kwargs if show else False
)
return documented
return wrapper
def hide(self, func):
"""A decorator to hide a resource or a method from specifications"""
return self.doc(False)(func)
def abort(self, *args, **kwargs):
"""
Properly abort the current request
See: :func:`~flask_restx.errors.abort`
"""
abort(*args, **kwargs)
def add_model(self, name, definition):
self.models[name] = definition
for api in self.apis:
api.models[name] = definition
return definition
def model(self, name=None, model=None, mask=None, strict=False, **kwargs):
"""
Register a model
:param bool strict - should model validation raise error when non-specified param
is provided?
.. seealso:: :class:`Model`
"""
cls = OrderedModel if self.ordered else Model
model = cls(name, model, mask=mask, strict=strict)
model.__apidoc__.update(kwargs)
return self.add_model(name, model)
def schema_model(self, name=None, schema=None):
"""
Register a model
.. seealso:: :class:`Model`
"""
model = SchemaModel(name, schema)
return self.add_model(name, model)
def extend(self, name, parent, fields):
"""
Extend a model (Duplicate all fields)
:deprecated: since 0.9. Use :meth:`clone` instead
"""
if isinstance(parent, list):
parents = parent + [fields]
model = Model.extend(name, *parents)
else:
model = Model.extend(name, parent, fields)
return self.add_model(name, model)
def clone(self, name, *specs):
"""
Clone a model (Duplicate all fields)
:param str name: the resulting model name
:param specs: a list of models from which to clone the fields
.. seealso:: :meth:`Model.clone`
"""
model = Model.clone(name, *specs)
return self.add_model(name, model)
def inherit(self, name, *specs):
"""
Inherit a model (use the Swagger composition pattern aka. allOf)
.. seealso:: :meth:`Model.inherit`
"""
model = Model.inherit(name, *specs)
return self.add_model(name, model)
def expect(self, *inputs, **kwargs):
"""
A decorator to Specify the expected input model
:param ModelBase|Parse inputs: An expect model or request parser
:param bool validate: whether to perform validation or not
"""
expect = []
params = {"validate": kwargs.get("validate", self._validate), "expect": expect}
for param in inputs:
expect.append(param)
return self.doc(**params)
def parser(self):
"""Instanciate a :class:`~RequestParser`"""
return RequestParser()
def as_list(self, field):
"""Allow to specify nested lists for documentation"""
field.__apidoc__ = merge(getattr(field, "__apidoc__", {}), {"as_list": True})
return field
def marshal_with(
self, fields, as_list=False, code=HTTPStatus.OK, description=None, **kwargs
):
"""
A decorator specifying the fields to use for serialization.
:param bool as_list: Indicate that the return type is a list (for the documentation)
:param int code: Optionally give the expected HTTP response code if its different from 200
"""
def wrapper(func):
doc = {
"responses": {
str(code): (description, [fields], kwargs)
if as_list
else (description, fields, kwargs)
},
"__mask__": kwargs.get(
"mask", True
), # Mask values can't be determined outside app context
}
func.__apidoc__ = merge(getattr(func, "__apidoc__", {}), doc)
return marshal_with(fields, ordered=self.ordered, **kwargs)(func)
return wrapper
def marshal_list_with(self, fields, **kwargs):
"""A shortcut decorator for :meth:`~Api.marshal_with` with ``as_list=True``"""
return self.marshal_with(fields, True, **kwargs)
def marshal(self, *args, **kwargs):
"""A shortcut to the :func:`marshal` helper"""
return marshal(*args, **kwargs)
def errorhandler(self, exception):
"""A decorator to register an error handler for a given exception"""
if inspect.isclass(exception) and issubclass(exception, Exception):
# Register an error handler for a given exception
def wrapper(func):
self.error_handlers[exception] = func
return func
return wrapper
else:
# Register the default error handler
self.default_error_handler = exception
return exception
def param(self, name, description=None, _in="query", **kwargs):
"""
A decorator to specify one of the expected parameters
:param str name: the parameter name
:param str description: a small description
:param str _in: the parameter location `(query|header|formData|body|cookie)`
"""
param = kwargs
param["in"] = _in
param["description"] = description
return self.doc(params={name: param})
def response(self, code, description, model=None, **kwargs):
"""
A decorator to specify one of the expected responses
:param int code: the HTTP status code
:param str description: a small description about the response
:param ModelBase model: an optional response model
"""
return self.doc(responses={str(code): (description, model, kwargs)})
def header(self, name, description=None, **kwargs):
"""
A decorator to specify one of the expected headers
:param str name: the HTTP header name
:param str description: a description about the header
"""
header = {"description": description}
header.update(kwargs)
return self.doc(headers={name: header})
def produces(self, mimetypes):
"""A decorator to specify the MIME types the API can produce"""
return self.doc(produces=mimetypes)
def deprecated(self, func):
"""A decorator to mark a resource or a method as deprecated"""
return self.doc(deprecated=True)(func)
def vendor(self, *args, **kwargs):
"""
A decorator to expose vendor extensions.
Extensions can be submitted as dict or kwargs.
The ``x-`` prefix is optionnal and will be added if missing.
See: http://swagger.io/specification/#specification-extensions-128
"""
for arg in args:
kwargs.update(arg)
return self.doc(vendor=kwargs)
@property
def payload(self):
"""Store the input payload in the current request context"""
return request.get_json()
def unshortcut_params_description(data):
if "params" in data:
for name, description in six.iteritems(data["params"]):
if isinstance(description, six.string_types):
data["params"][name] = {"description": description}
def handle_deprecations(doc):
if "parser" in doc:
warnings.warn(
"The parser attribute is deprecated, use expect instead",
DeprecationWarning,
stacklevel=2,
)
doc["expect"] = doc.get("expect", []) + [doc.pop("parser")]
if "body" in doc:
warnings.warn(
"The body attribute is deprecated, use expect instead",
DeprecationWarning,
stacklevel=2,
)
doc["expect"] = doc.get("expect", []) + [doc.pop("body")]

207
libs/flask_restx/postman.py Normal file
View File

@ -0,0 +1,207 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import
from time import time
from uuid import uuid5, NAMESPACE_URL
from six import iteritems
from six.moves.urllib.parse import urlencode
def clean(data):
"""Remove all keys where value is None"""
return dict((k, v) for k, v in iteritems(data) if v is not None)
DEFAULT_VARS = {
"string": "",
"integer": 0,
"number": 0,
}
class Request(object):
"""Wraps a Swagger operation into a Postman Request"""
def __init__(self, collection, path, params, method, operation):
self.collection = collection
self.path = path
self.params = params
self.method = method.upper()
self.operation = operation
@property
def id(self):
seed = str(" ".join((self.method, self.url)))
return str(uuid5(self.collection.uuid, seed))
@property
def url(self):
return self.collection.api.base_url.rstrip("/") + self.path
@property
def headers(self):
headers = {}
# Handle content-type
if self.method != "GET":
consumes = self.collection.api.__schema__.get("consumes", [])
consumes = self.operation.get("consumes", consumes)
if len(consumes):
headers["Content-Type"] = consumes[-1]
# Add all parameters headers
for param in self.operation.get("parameters", []):
if param["in"] == "header":
headers[param["name"]] = param.get("default", "")
# Add security headers if needed (global then local)
for security in self.collection.api.__schema__.get("security", []):
for key, header in iteritems(self.collection.apikeys):
if key in security:
headers[header] = ""
for security in self.operation.get("security", []):
for key, header in iteritems(self.collection.apikeys):
if key in security:
headers[header] = ""
lines = [":".join(line) for line in iteritems(headers)]
return "\n".join(lines)
@property
def folder(self):
if "tags" not in self.operation or len(self.operation["tags"]) == 0:
return
tag = self.operation["tags"][0]
for folder in self.collection.folders:
if folder.tag == tag:
return folder.id
def as_dict(self, urlvars=False):
url, variables = self.process_url(urlvars)
return clean(
{
"id": self.id,
"method": self.method,
"name": self.operation["operationId"],
"description": self.operation.get("summary"),
"url": url,
"headers": self.headers,
"collectionId": self.collection.id,
"folder": self.folder,
"pathVariables": variables,
"time": int(time()),
}
)
def process_url(self, urlvars=False):
url = self.url
path_vars = {}
url_vars = {}
params = dict((p["name"], p) for p in self.params)
params.update(
dict((p["name"], p) for p in self.operation.get("parameters", []))
)
if not params:
return url, None
for name, param in iteritems(params):
if param["in"] == "path":
url = url.replace("{%s}" % name, ":%s" % name)
path_vars[name] = DEFAULT_VARS.get(param["type"], "")
elif param["in"] == "query" and urlvars:
default = DEFAULT_VARS.get(param["type"], "")
url_vars[name] = param.get("default", default)
if url_vars:
url = "?".join((url, urlencode(url_vars)))
return url, path_vars
class Folder(object):
def __init__(self, collection, tag):
self.collection = collection
self.tag = tag["name"]
self.description = tag["description"]
@property
def id(self):
return str(uuid5(self.collection.uuid, str(self.tag)))
@property
def order(self):
return [r.id for r in self.collection.requests if r.folder == self.id]
def as_dict(self):
return clean(
{
"id": self.id,
"name": self.tag,
"description": self.description,
"order": self.order,
"collectionId": self.collection.id,
}
)
class PostmanCollectionV1(object):
"""Postman Collection (V1 format) serializer"""
def __init__(self, api, swagger=False):
self.api = api
self.swagger = swagger
@property
def uuid(self):
return uuid5(NAMESPACE_URL, self.api.base_url)
@property
def id(self):
return str(self.uuid)
@property
def requests(self):
if self.swagger:
# First request is Swagger specifications
yield Request(
self,
"/swagger.json",
{},
"get",
{
"operationId": "Swagger specifications",
"summary": "The API Swagger specifications as JSON",
},
)
# Then iter over API paths and methods
for path, operations in iteritems(self.api.__schema__["paths"]):
path_params = operations.get("parameters", [])
for method, operation in iteritems(operations):
if method != "parameters":
yield Request(self, path, path_params, method, operation)
@property
def folders(self):
for tag in self.api.__schema__["tags"]:
yield Folder(self, tag)
@property
def apikeys(self):
return dict(
(name, secdef["name"])
for name, secdef in iteritems(
self.api.__schema__.get("securityDefinitions")
)
if secdef.get("in") == "header" and secdef.get("type") == "apiKey"
)
def as_dict(self, urlvars=False):
return clean(
{
"id": self.id,
"name": " ".join((self.api.title, self.api.version)),
"description": self.api.description,
"order": [r.id for r in self.requests if not r.folder],
"requests": [r.as_dict(urlvars=urlvars) for r in self.requests],
"folders": [f.as_dict() for f in self.folders],
"timestamp": int(time()),
}
)

View File

@ -1,20 +1,24 @@
from __future__ import absolute_import
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import
try:
from ujson import dumps
except ImportError:
from json import dumps
from flask import make_response, current_app
from flask_restful.utils import PY3
from json import dumps
def output_json(data, code, headers=None):
"""Makes a Flask response with a JSON encoded body"""
settings = current_app.config.get('RESTFUL_JSON', {})
settings = current_app.config.get("RESTX_JSON", {})
# If we're in debug mode, and the indent is not set, we set it to a
# reasonable value here. Note that this won't override any existing value
# that was set. We also set the "sort_keys" value.
# that was set.
if current_app.debug:
settings.setdefault('indent', 4)
settings.setdefault('sort_keys', not PY3)
settings.setdefault("indent", 4)
# always end the json dumps with a new line
# see https://github.com/mitsuhiko/flask/pull/1262

View File

@ -1,18 +1,30 @@
from copy import deepcopy
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
try:
from collections.abc import MutableSequence
except ImportError:
from collections import MutableSequence
from flask import current_app, request
from werkzeug.datastructures import MultiDict, FileStorage
from werkzeug import exceptions
import flask_restful
import decimal
import six
try:
from collections.abc import Hashable
except ImportError:
from collections import Hashable
from copy import deepcopy
from flask import current_app, request
from werkzeug.datastructures import MultiDict, FileStorage
from werkzeug import exceptions
from .errors import abort, SpecsError
from .marshalling import marshal
from .model import Model
from ._http import HTTPStatus
class ParseResult(dict):
"""
The default result container as an Object dict.
"""
class Namespace(dict):
def __getattr__(self, name):
try:
return self[name]
@ -24,37 +36,52 @@ class Namespace(dict):
_friendly_location = {
u'json': u'the JSON body',
u'form': u'the post body',
u'args': u'the query string',
u'values': u'the post body or the query string',
u'headers': u'the HTTP headers',
u'cookies': u'the request\'s cookies',
u'files': u'an uploaded file',
"json": "the JSON body",
"form": "the post body",
"args": "the query string",
"values": "the post body or the query string",
"headers": "the HTTP headers",
"cookies": "the request's cookies",
"files": "an uploaded file",
}
text_type = lambda x: six.text_type(x)
#: Maps Flask-RESTX RequestParser locations to Swagger ones
LOCATIONS = {
"args": "query",
"form": "formData",
"headers": "header",
"json": "body",
"values": "query",
"files": "formData",
}
#: Maps Python primitives types to Swagger ones
PY_TYPES = {
int: "integer",
str: "string",
bool: "boolean",
float: "number",
None: "void",
}
SPLIT_CHAR = ","
text_type = lambda x: six.text_type(x) # noqa
class Argument(object):
"""
:param name: Either a name or a list of option strings, e.g. foo or
-f, --foo.
:param default: The value produced if the argument is absent from the
request.
:param name: Either a name or a list of option strings, e.g. foo or -f, --foo.
:param default: The value produced if the argument is absent from the request.
:param dest: The name of the attribute to be added to the object
returned by :meth:`~reqparse.RequestParser.parse_args()`.
:param bool required: Whether or not the argument may be omitted (optionals
only).
:param action: The basic type of action to be taken when this argument
:param bool required: Whether or not the argument may be omitted (optionals only).
:param string action: The basic type of action to be taken when this argument
is encountered in the request. Valid options are "store" and "append".
:param ignore: Whether to ignore cases where the argument fails type
conversion
:param type: The type to which the request argument should be
converted. If a type raises an exception, the message in the
error will be returned in the response. Defaults to :class:`unicode`
in python2 and :class:`str` in python3.
:param bool ignore: Whether to ignore cases where the argument fails type conversion
:param type: The type to which the request argument should be converted.
If a type raises an exception, the message in the error will be returned in the response.
Defaults to :class:`unicode` in python2 and :class:`str` in python3.
:param location: The attributes of the :class:`flask.Request` object
to source the arguments from (ex: headers, args, etc.), can be an
iterator. The last item listed takes precedence in the result set.
@ -71,11 +98,24 @@ class Argument(object):
:param bool nullable: If enabled, allows null value in argument.
"""
def __init__(self, name, default=None, dest=None, required=False,
ignore=False, type=text_type, location=('json', 'values',),
choices=(), action='store', help=None, operators=('=',),
case_sensitive=True, store_missing=True, trim=False,
nullable=True):
def __init__(
self,
name,
default=None,
dest=None,
required=False,
ignore=False,
type=text_type,
location=("json", "values",),
choices=(),
action="store",
help=None,
operators=("=",),
case_sensitive=True,
store_missing=True,
trim=False,
nullable=True,
):
self.name = name
self.default = default
self.dest = dest
@ -92,25 +132,9 @@ class Argument(object):
self.trim = trim
self.nullable = nullable
def __str__(self):
if len(self.choices) > 5:
choices = self.choices[0:3]
choices.append('...')
choices.append(self.choices[-1])
else:
choices = self.choices
return 'Name: {0}, type: {1}, choices: {2}'.format(self.name, self.type, choices)
def __repr__(self):
return "{0}('{1}', default={2}, dest={3}, required={4}, ignore={5}, location={6}, " \
"type=\"{7}\", choices={8}, action='{9}', help={10}, case_sensitive={11}, " \
"operators={12}, store_missing={13}, trim={14}, nullable={15})".format(
self.__class__.__name__, self.name, self.default, self.dest, self.required, self.ignore, self.location,
self.type, self.choices, self.action, self.help, self.case_sensitive,
self.operators, self.store_missing, self.trim, self.nullable)
def source(self, request):
"""Pulls values off the request in the provided location
"""
Pulls values off the request in the provided location
:param request: The flask request object to parse arguments from
"""
if isinstance(self.location, six.string_types):
@ -134,10 +158,12 @@ class Argument(object):
def convert(self, value, op):
# Don't cast None
if value is None:
if self.nullable:
return None
else:
raise ValueError('Must not be null!')
if not self.nullable:
raise ValueError("Must not be null!")
return None
elif isinstance(self.type, Model) and isinstance(value, dict):
return marshal(value, self.type)
# and check if we're expecting a filestorage and haven't overridden `type`
# (required because the below instantiation isn't valid for FileStorage)
@ -149,38 +175,43 @@ class Argument(object):
except TypeError:
try:
if self.type is decimal.Decimal:
return self.type(str(value))
return self.type(str(value), self.name)
else:
return self.type(value, self.name)
except TypeError:
return self.type(value)
def handle_validation_error(self, error, bundle_errors):
"""Called when an error is raised while parsing. Aborts the request
"""
Called when an error is raised while parsing. Aborts the request
with a 400 status and an error message
:param error: the error that was raised
:param bundle_errors: do not abort when first error occurs, return a
:param bool bundle_errors: do not abort when first error occurs, return a
dict with the name of the argument and the error message to be
bundled
"""
error_str = six.text_type(error)
error_msg = self.help.format(error_msg=error_str) if self.help else error_str
msg = {self.name: error_msg}
error_msg = (
" ".join([six.text_type(self.help), error_str]) if self.help else error_str
)
errors = {self.name: error_msg}
if current_app.config.get("BUNDLE_ERRORS", False) or bundle_errors:
return error, msg
flask_restful.abort(400, message=msg)
if bundle_errors:
return ValueError(error), errors
abort(HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=errors)
def parse(self, request, bundle_errors=False):
"""Parses argument value(s) from the request, converting according to
"""
Parses argument value(s) from the request, converting according to
the argument's type.
:param request: The flask request object to parse arguments from
:param bundle_errors: Do not abort when first error occurs, return a
:param bool bundle_errors: do not abort when first error occurs, return a
dict with the name of the argument and the error message to be
bundled
"""
bundle_errors = current_app.config.get("BUNDLE_ERRORS", False) or bundle_errors
source = self.source(request)
results = []
@ -196,9 +227,7 @@ class Argument(object):
if hasattr(source, "getlist"):
values = source.getlist(name)
else:
values = source.get(name)
if not (isinstance(values, MutableSequence) and self.action == 'append'):
values = [values]
values = [source.get(name)]
for value in values:
if hasattr(value, "strip") and self.trim:
@ -207,24 +236,26 @@ class Argument(object):
value = value.lower()
if hasattr(self.choices, "__iter__"):
self.choices = [choice.lower()
for choice in self.choices]
self.choices = [choice.lower() for choice in self.choices]
try:
value = self.convert(value, operator)
if self.action == "split":
value = [
self.convert(v, operator)
for v in value.split(SPLIT_CHAR)
]
else:
value = self.convert(value, operator)
except Exception as error:
if self.ignore:
continue
return self.handle_validation_error(error, bundle_errors)
if self.choices and value not in self.choices:
if current_app.config.get("BUNDLE_ERRORS", False) or bundle_errors:
return self.handle_validation_error(
ValueError(u"{0} is not a valid choice".format(
value)), bundle_errors)
self.handle_validation_error(
ValueError(u"{0} is not a valid choice".format(
value)), bundle_errors)
msg = "The value '{0}' is not a valid choice for '{1}'.".format(
value, name
)
return self.handle_validation_error(msg, bundle_errors)
if name in request.unparsed_arguments:
request.unparsed_arguments.pop(name)
@ -232,18 +263,12 @@ class Argument(object):
if not results and self.required:
if isinstance(self.location, six.string_types):
error_msg = u"Missing required parameter in {0}".format(
_friendly_location.get(self.location, self.location)
)
location = _friendly_location.get(self.location, self.location)
else:
friendly_locations = [_friendly_location.get(loc, loc)
for loc in self.location]
error_msg = u"Missing required parameter in {0}".format(
' or '.join(friendly_locations)
)
if current_app.config.get("BUNDLE_ERRORS", False) or bundle_errors:
return self.handle_validation_error(ValueError(error_msg), bundle_errors)
self.handle_validation_error(ValueError(error_msg), bundle_errors)
locations = [_friendly_location.get(loc, loc) for loc in self.location]
location = " or ".join(locations)
error_msg = "Missing required parameter in {0}".format(location)
return self.handle_validation_error(error_msg, bundle_errors)
if not results:
if callable(self.default):
@ -251,48 +276,79 @@ class Argument(object):
else:
return self.default, _not_found
if self.action == 'append':
if self.action == "append":
return results, _found
if self.action == 'store' or len(results) == 1:
if self.action == "store" or len(results) == 1:
return results[0], _found
return results, _found
@property
def __schema__(self):
if self.location == "cookie":
return
param = {"name": self.name, "in": LOCATIONS.get(self.location, "query")}
_handle_arg_type(self, param)
if self.required:
param["required"] = True
if self.help:
param["description"] = self.help
if self.default is not None:
param["default"] = (
self.default() if callable(self.default) else self.default
)
if self.action == "append":
param["items"] = {"type": param["type"]}
param["type"] = "array"
param["collectionFormat"] = "multi"
if self.action == "split":
param["items"] = {"type": param["type"]}
param["type"] = "array"
param["collectionFormat"] = "csv"
if self.choices:
param["enum"] = self.choices
return param
class RequestParser(object):
"""Enables adding and parsing of multiple arguments in the context of a
single request. Ex::
"""
Enables adding and parsing of multiple arguments in the context of a single request.
Ex::
from flask_restful import reqparse
from flask_restx import RequestParser
parser = reqparse.RequestParser()
parser = RequestParser()
parser.add_argument('foo')
parser.add_argument('int_bar', type=int)
args = parser.parse_args()
:param bool trim: If enabled, trims whitespace on all arguments in this
parser
:param bool trim: If enabled, trims whitespace on all arguments in this parser
:param bool bundle_errors: If enabled, do not abort when first error occurs,
return a dict with the name of the argument and the error message to be
bundled and return all validation errors
"""
def __init__(self, argument_class=Argument, namespace_class=Namespace,
trim=False, bundle_errors=False):
def __init__(
self,
argument_class=Argument,
result_class=ParseResult,
trim=False,
bundle_errors=False,
):
self.args = []
self.argument_class = argument_class
self.namespace_class = namespace_class
self.result_class = result_class
self.trim = trim
self.bundle_errors = bundle_errors
def add_argument(self, *args, **kwargs):
"""Adds an argument to be parsed.
"""
Adds an argument to be parsed.
Accepts either a single instance of Argument or arguments to be passed
into :class:`Argument`'s constructor.
See :class:`Argument`'s constructor for documentation on the
available options.
See :class:`Argument`'s constructor for documentation on the available options.
"""
if len(args) == 1 and isinstance(args[0], self.argument_class):
@ -303,26 +359,28 @@ class RequestParser(object):
# Do not know what other argument classes are out there
if self.trim and self.argument_class is Argument:
# enable trim for appended element
self.args[-1].trim = kwargs.get('trim', self.trim)
self.args[-1].trim = kwargs.get("trim", self.trim)
return self
def parse_args(self, req=None, strict=False, http_error_code=400):
"""Parse all arguments from the provided request and return the results
as a Namespace
def parse_args(self, req=None, strict=False):
"""
Parse all arguments from the provided request and return the results as a ParseResult
:param req: Can be used to overwrite request from Flask
:param strict: if req includes args not in parser, throw 400 BadRequest exception
:param http_error_code: use custom error code for `flask_restful.abort()`
:param bool strict: if req includes args not in parser, throw 400 BadRequest exception
:return: the parsed results as :class:`ParseResult` (or any class defined as :attr:`result_class`)
:rtype: ParseResult
"""
if req is None:
req = request
namespace = self.namespace_class()
result = self.result_class()
# A record of arguments not yet parsed; as each is found
# among self.args, it will be popped out
req.unparsed_arguments = dict(self.argument_class('').source(req)) if strict else {}
req.unparsed_arguments = (
dict(self.argument_class("").source(req)) if strict else {}
)
errors = {}
for arg in self.args:
value, found = arg.parse(req, self.bundle_errors)
@ -330,26 +388,29 @@ class RequestParser(object):
errors.update(found)
found = None
if found or arg.store_missing:
namespace[arg.dest or arg.name] = value
result[arg.dest or arg.name] = value
if errors:
flask_restful.abort(http_error_code, message=errors)
abort(
HTTPStatus.BAD_REQUEST, "Input payload validation failed", errors=errors
)
if strict and req.unparsed_arguments:
raise exceptions.BadRequest('Unknown arguments: %s'
% ', '.join(req.unparsed_arguments.keys()))
arguments = ", ".join(req.unparsed_arguments.keys())
msg = "Unknown arguments: {0}".format(arguments)
raise exceptions.BadRequest(msg)
return namespace
return result
def copy(self):
""" Creates a copy of this RequestParser with the same set of arguments """
parser_copy = self.__class__(self.argument_class, self.namespace_class)
"""Creates a copy of this RequestParser with the same set of arguments"""
parser_copy = self.__class__(self.argument_class, self.result_class)
parser_copy.args = deepcopy(self.args)
parser_copy.trim = self.trim
parser_copy.bundle_errors = self.bundle_errors
return parser_copy
def replace_argument(self, name, *args, **kwargs):
""" Replace the argument matching the given name with a new version. """
"""Replace the argument matching the given name with a new version."""
new_arg = self.argument_class(name, *args, **kwargs)
for index, arg in enumerate(self.args[:]):
if new_arg.name == arg.name:
@ -359,9 +420,36 @@ class RequestParser(object):
return self
def remove_argument(self, name):
""" Remove the argument matching the given name. """
"""Remove the argument matching the given name."""
for index, arg in enumerate(self.args[:]):
if name == arg.name:
del self.args[index]
break
return self
@property
def __schema__(self):
params = []
locations = set()
for arg in self.args:
param = arg.__schema__
if param:
params.append(param)
locations.add(param["in"])
if "body" in locations and "formData" in locations:
raise SpecsError("Can't use formData and body at the same time")
return params
def _handle_arg_type(arg, param):
if isinstance(arg.type, Hashable) and arg.type in PY_TYPES:
param["type"] = PY_TYPES[arg.type]
elif hasattr(arg.type, "__apidoc__"):
param["type"] = arg.type.__apidoc__["name"]
param["in"] = "body"
elif hasattr(arg.type, "__schema__"):
param.update(arg.type.__schema__)
elif arg.location == "files":
param["type"] = "file"
else:
param["type"] = "string"

View File

@ -0,0 +1,93 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals
from flask import request
from flask.views import MethodView
from werkzeug import __version__ as werkzeug_version
if werkzeug_version.split('.')[0] >= '2':
from werkzeug.wrappers import Response as BaseResponse
else:
from werkzeug.wrappers import BaseResponse
from .model import ModelBase
from .utils import unpack
class Resource(MethodView):
"""
Represents an abstract RESTX resource.
Concrete resources should extend from this class
and expose methods for each supported HTTP method.
If a resource is invoked with an unsupported HTTP method,
the API will return a response with status 405 Method Not Allowed.
Otherwise the appropriate method is called and passed all arguments
from the url rule used when adding the resource to an Api instance.
See :meth:`~flask_restx.Api.add_resource` for details.
"""
representations = None
method_decorators = []
def __init__(self, api=None, *args, **kwargs):
self.api = api
def dispatch_request(self, *args, **kwargs):
# Taken from flask
meth = getattr(self, request.method.lower(), None)
if meth is None and request.method == "HEAD":
meth = getattr(self, "get", None)
assert meth is not None, "Unimplemented method %r" % request.method
for decorator in self.method_decorators:
meth = decorator(meth)
self.validate_payload(meth)
resp = meth(*args, **kwargs)
if isinstance(resp, BaseResponse):
return resp
representations = self.representations or {}
mediatype = request.accept_mimetypes.best_match(representations, default=None)
if mediatype in representations:
data, code, headers = unpack(resp)
resp = representations[mediatype](data, code, headers)
resp.headers["Content-Type"] = mediatype
return resp
return resp
def __validate_payload(self, expect, collection=False):
"""
:param ModelBase expect: the expected model for the input payload
:param bool collection: False if a single object of a resource is
expected, True if a collection of objects of a resource is expected.
"""
# TODO: proper content negotiation
data = request.get_json()
if collection:
data = data if isinstance(data, list) else [data]
for obj in data:
expect.validate(obj, self.api.refresolver, self.api.format_checker)
else:
expect.validate(data, self.api.refresolver, self.api.format_checker)
def validate_payload(self, func):
"""Perform a payload validation on expected model if necessary"""
if getattr(func, "__apidoc__", False) is not False:
doc = func.__apidoc__
validate = doc.get("validate", None)
validate = validate if validate is not None else self.api._validate
if validate:
for expect in doc.get("expect", []):
# TODO: handle third party handlers
if isinstance(expect, list) and len(expect) == 1:
if isinstance(expect[0], ModelBase):
self.__validate_payload(expect[0], collection=True)
if isinstance(expect, ModelBase):
self.__validate_payload(expect, collection=False)

View File

@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-
"""
This module give access to OpenAPI specifications schemas
and allows to validate specs against them.
.. versionadded:: 0.12.1
"""
from __future__ import unicode_literals
import io
import json
import pkg_resources
try:
from collections.abc import Mapping
except ImportError:
# TODO Remove this to drop Python2 support
from collections import Mapping
from jsonschema import Draft4Validator
from flask_restx import errors
class SchemaValidationError(errors.ValidationError):
"""
Raised when specification is not valid
.. versionadded:: 0.12.1
"""
def __init__(self, msg, errors=None):
super(SchemaValidationError, self).__init__(msg)
self.errors = errors
def __str__(self):
msg = [self.msg]
for error in sorted(self.errors, key=lambda e: e.path):
path = ".".join(error.path)
msg.append("- {}: {}".format(path, error.message))
for suberror in sorted(error.context, key=lambda e: e.schema_path):
path = ".".join(suberror.schema_path)
msg.append(" - {}: {}".format(path, suberror.message))
return "\n".join(msg)
__unicode__ = __str__
class LazySchema(Mapping):
"""
A thin wrapper around schema file lazy loading the data on first access
:param filename str: The package relative json schema filename
:param validator: The jsonschema validator class version
.. versionadded:: 0.12.1
"""
def __init__(self, filename, validator=Draft4Validator):
super(LazySchema, self).__init__()
self.filename = filename
self._schema = None
self._validator = validator
def _load(self):
if not self._schema:
filename = pkg_resources.resource_filename(__name__, self.filename)
with io.open(filename) as infile:
self._schema = json.load(infile)
def __getitem__(self, key):
self._load()
return self._schema.__getitem__(key)
def __iter__(self):
self._load()
return self._schema.__iter__()
def __len__(self):
self._load()
return self._schema.__len__()
@property
def validator(self):
"""The jsonschema validator to validate against"""
return self._validator(self)
#: OpenAPI 2.0 specification schema
OAS_20 = LazySchema("oas-2.0.json")
#: Map supported OpenAPI versions to their JSON schema
VERSIONS = {
"2.0": OAS_20,
}
def validate(data):
"""
Validate an OpenAPI specification.
Supported OpenAPI versions: 2.0
:param data dict: The specification to validate
:returns boolean: True if the specification is valid
:raises SchemaValidationError: when the specification is invalid
:raises flask_restx.errors.SpecsError: when it's not possible to determinate
the schema to validate against
.. versionadded:: 0.12.1
"""
if "swagger" not in data:
raise errors.SpecsError("Unable to determinate OpenAPI schema version")
version = data["swagger"]
if version not in VERSIONS:
raise errors.SpecsError('Unknown OpenAPI schema version "{}"'.format(version))
validator = VERSIONS[version].validator
validation_errors = list(validator.iter_errors(data))
if validation_errors:
raise SchemaValidationError(
"OpenAPI {} validation failed".format(version), errors=validation_errors
)
return True

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,26 @@
/* droid-sans-400normal - latin */
@font-face {
font-family: 'Droid Sans';
font-style: normal;
font-display: swap;
font-weight: 400;
src:
local('Droid Sans Regular '),
local('Droid Sans-Regular'),
url('./files/droid-sans-latin-400.woff2') format('woff2'), /* Super Modern Browsers */
url('./files/droid-sans-latin-400.woff') format('woff'); /* Modern Browsers */
}
/* droid-sans-700normal - latin */
@font-face {
font-family: 'Droid Sans';
font-style: normal;
font-display: swap;
font-weight: 700;
src:
local('Droid Sans Bold '),
local('Droid Sans-Bold'),
url('./files/droid-sans-latin-700.woff2') format('woff2'), /* Super Modern Browsers */
url('./files/droid-sans-latin-700.woff') format('woff'); /* Modern Browsers */
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 665 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 628 B

View File

View File

@ -0,0 +1,75 @@
<!doctype html>
<html lang="en-US">
<head>
<title>Swagger UI: OAuth2 Redirect</title>
</head>
<body>
<script>
'use strict';
function run () {
var oauth2 = window.opener.swaggerUIRedirectOauth2;
var sentState = oauth2.state;
var redirectUrl = oauth2.redirectUrl;
var isValid, qp, arr;
if (/code|token|error/.test(window.location.hash)) {
qp = window.location.hash.substring(1);
} else {
qp = location.search.substring(1);
}
arr = qp.split("&");
arr.forEach(function (v,i,_arr) { _arr[i] = '"' + v.replace('=', '":"') + '"';});
qp = qp ? JSON.parse('{' + arr.join() + '}',
function (key, value) {
return key === "" ? value : decodeURIComponent(value);
}
) : {};
isValid = qp.state === sentState;
if ((
oauth2.auth.schema.get("flow") === "accessCode" ||
oauth2.auth.schema.get("flow") === "authorizationCode" ||
oauth2.auth.schema.get("flow") === "authorization_code"
) && !oauth2.auth.code) {
if (!isValid) {
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "warning",
message: "Authorization may be unsafe, passed state was changed in server Passed state wasn't returned from auth server"
});
}
if (qp.code) {
delete oauth2.state;
oauth2.auth.code = qp.code;
oauth2.callback({auth: oauth2.auth, redirectUrl: redirectUrl});
} else {
let oauthErrorMsg;
if (qp.error) {
oauthErrorMsg = "["+qp.error+"]: " +
(qp.error_description ? qp.error_description+ ". " : "no accessCode received from the server. ") +
(qp.error_uri ? "More info: "+qp.error_uri : "");
}
oauth2.errCb({
authId: oauth2.auth.name,
source: "auth",
level: "error",
message: oauthErrorMsg || "[Authorization failed]: no accessCode received from the server"
});
}
} else {
oauth2.callback({auth: oauth2.auth, token: qp, isValid: isValid, redirectUrl: redirectUrl});
}
window.close();
}
window.addEventListener('DOMContentLoaded', function () {
run();
});
</script>
</body>
</html>

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff Show More