Replaced peewee with sqlalchemy as ORM. This is a major change, please report related issues on Discord.

This commit is contained in:
morpheus65535 2023-07-26 19:34:49 -04:00 committed by GitHub
parent 486d2f9481
commit bccded275c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
693 changed files with 313863 additions and 55131 deletions

View File

@ -1,11 +1,12 @@
# coding=utf-8
import operator
import ast
from functools import reduce
from flask_restx import Resource, Namespace, fields
from app.database import get_exclusion_clause, TableEpisodes, TableShows, TableMovies
from app.database import get_exclusion_clause, TableEpisodes, TableShows, TableMovies, database, select
from app.get_providers import get_throttled_providers
from app.signalr_client import sonarr_signalr_client, radarr_signalr_client
from app.announcements import get_all_announcements
@ -35,31 +36,38 @@ class Badges(Resource):
@api_ns_badges.doc(parser=None)
def get(self):
"""Get badges count to update the UI"""
episodes_conditions = [(TableEpisodes.missing_subtitles.is_null(False)),
episodes_conditions = [(TableEpisodes.missing_subtitles.is_not(None)),
(TableEpisodes.missing_subtitles != '[]')]
episodes_conditions += get_exclusion_clause('series')
missing_episodes = TableEpisodes.select(TableShows.tags,
TableShows.seriesType,
TableEpisodes.monitored)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(reduce(operator.and_, episodes_conditions))\
.count()
missing_episodes = database.execute(
select(TableEpisodes.missing_subtitles)
.select_from(TableEpisodes)
.join(TableShows)
.where(reduce(operator.and_, episodes_conditions))) \
.all()
missing_episodes_count = 0
for episode in missing_episodes:
missing_episodes_count += len(ast.literal_eval(episode.missing_subtitles))
movies_conditions = [(TableMovies.missing_subtitles.is_null(False)),
movies_conditions = [(TableMovies.missing_subtitles.is_not(None)),
(TableMovies.missing_subtitles != '[]')]
movies_conditions += get_exclusion_clause('movie')
missing_movies = TableMovies.select(TableMovies.tags,
TableMovies.monitored)\
.where(reduce(operator.and_, movies_conditions))\
.count()
missing_movies = database.execute(
select(TableMovies.missing_subtitles)
.select_from(TableMovies)
.where(reduce(operator.and_, movies_conditions))) \
.all()
missing_movies_count = 0
for movie in missing_movies:
missing_movies_count += len(ast.literal_eval(movie.missing_subtitles))
throttled_providers = len(get_throttled_providers())
health_issues = len(get_health_issues())
result = {
"episodes": missing_episodes,
"movies": missing_movies,
"episodes": missing_episodes_count,
"movies": missing_movies_count,
"providers": throttled_providers,
"status": health_issues,
'sonarr_signalr': "LIVE" if sonarr_signalr_client.connected else "",

View File

@ -4,7 +4,7 @@ import pretty
from flask_restx import Resource, Namespace, reqparse, fields
from app.database import TableEpisodes, TableShows, TableBlacklist
from app.database import TableEpisodes, TableShows, TableBlacklist, database, select
from subtitles.tools.delete import delete_subtitles
from sonarr.blacklist import blacklist_log, blacklist_delete_all, blacklist_delete
from utilities.path_mappings import path_mappings
@ -48,29 +48,32 @@ class EpisodesBlacklist(Resource):
start = args.get('start')
length = args.get('length')
data = TableBlacklist.select(TableShows.title.alias('seriesTitle'),
TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'),
TableEpisodes.title.alias('episodeTitle'),
TableEpisodes.sonarrSeriesId,
TableBlacklist.provider,
TableBlacklist.subs_id,
TableBlacklist.language,
TableBlacklist.timestamp)\
.join(TableEpisodes, on=(TableBlacklist.sonarr_episode_id == TableEpisodes.sonarrEpisodeId))\
.join(TableShows, on=(TableBlacklist.sonarr_series_id == TableShows.sonarrSeriesId))\
stmt = select(TableShows.title.label('seriesTitle'),
TableEpisodes.season.concat('x').concat(TableEpisodes.episode).label('episode_number'),
TableEpisodes.title.label('episodeTitle'),
TableEpisodes.sonarrSeriesId,
TableBlacklist.provider,
TableBlacklist.subs_id,
TableBlacklist.language,
TableBlacklist.timestamp) \
.select_from(TableBlacklist) \
.join(TableShows, onclause=TableBlacklist.sonarr_series_id == TableShows.sonarrSeriesId) \
.join(TableEpisodes, onclause=TableBlacklist.sonarr_episode_id == TableEpisodes.sonarrEpisodeId) \
.order_by(TableBlacklist.timestamp.desc())
if length > 0:
data = data.limit(length).offset(start)
data = list(data.dicts())
stmt = stmt.limit(length).offset(start)
for item in data:
# Make timestamp pretty
item["parsed_timestamp"] = item['timestamp'].strftime('%x %X')
item.update({'timestamp': pretty.date(item['timestamp'])})
postprocess(item)
return data
return [postprocess({
'seriesTitle': x.seriesTitle,
'episode_number': x.episode_number,
'episodeTitle': x.episodeTitle,
'sonarrSeriesId': x.sonarrSeriesId,
'provider': x.provider,
'subs_id': x.subs_id,
'language': x.language,
'timestamp': pretty.date(x.timestamp),
'parsed_timestamp': x.timestamp.strftime('%x %X')
}) for x in database.execute(stmt).all()]
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('seriesid', type=int, required=True, help='Series ID')
@ -94,15 +97,15 @@ class EpisodesBlacklist(Resource):
subs_id = args.get('subs_id')
language = args.get('language')
episodeInfo = TableEpisodes.select(TableEpisodes.path)\
.where(TableEpisodes.sonarrEpisodeId == sonarr_episode_id)\
.dicts()\
.get_or_none()
episodeInfo = database.execute(
select(TableEpisodes.path)
.where(TableEpisodes.sonarrEpisodeId == sonarr_episode_id)) \
.first()
if not episodeInfo:
return 'Episode not found', 404
media_path = episodeInfo['path']
media_path = episodeInfo.path
subtitles_path = args.get('subtitles_path')
blacklist_log(sonarr_series_id=sonarr_series_id,

View File

@ -2,7 +2,7 @@
from flask_restx import Resource, Namespace, reqparse, fields
from app.database import TableEpisodes
from app.database import TableEpisodes, database, select
from api.swaggerui import subtitles_model, subtitles_language_model, audio_language_model
from ..utils import authenticate, postprocess
@ -23,24 +23,16 @@ class Episodes(Resource):
get_audio_language_model = api_ns_episodes.model('audio_language_model', audio_language_model)
get_response_model = api_ns_episodes.model('EpisodeGetResponse', {
'rowid': fields.Integer(),
'audio_codec': fields.String(),
'audio_language': fields.Nested(get_audio_language_model),
'episode': fields.Integer(),
'episode_file_id': fields.Integer(),
'failedAttempts': fields.String(),
'file_size': fields.Integer(),
'format': fields.String(),
'missing_subtitles': fields.Nested(get_subtitles_language_model),
'monitored': fields.Boolean(),
'path': fields.String(),
'resolution': fields.String(),
'season': fields.Integer(),
'sonarrEpisodeId': fields.Integer(),
'sonarrSeriesId': fields.Integer(),
'subtitles': fields.Nested(get_subtitles_model),
'title': fields.String(),
'video_codec': fields.String(),
'sceneName': fields.String(),
})
@ -56,18 +48,44 @@ class Episodes(Resource):
seriesId = args.get('seriesid[]')
episodeId = args.get('episodeid[]')
stmt = select(
TableEpisodes.audio_language,
TableEpisodes.episode,
TableEpisodes.missing_subtitles,
TableEpisodes.monitored,
TableEpisodes.path,
TableEpisodes.season,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.sonarrSeriesId,
TableEpisodes.subtitles,
TableEpisodes.title,
TableEpisodes.sceneName,
)
if len(episodeId) > 0:
result = TableEpisodes.select().where(TableEpisodes.sonarrEpisodeId.in_(episodeId)).dicts()
stmt_query = database.execute(
stmt
.where(TableEpisodes.sonarrEpisodeId.in_(episodeId)))\
.all()
elif len(seriesId) > 0:
result = TableEpisodes.select()\
.where(TableEpisodes.sonarrSeriesId.in_(seriesId))\
.order_by(TableEpisodes.season.desc(), TableEpisodes.episode.desc())\
.dicts()
stmt_query = database.execute(
stmt
.where(TableEpisodes.sonarrSeriesId.in_(seriesId))
.order_by(TableEpisodes.season.desc(), TableEpisodes.episode.desc()))\
.all()
else:
return "Series or Episode ID not provided", 404
result = list(result)
for item in result:
postprocess(item)
return result
return [postprocess({
'audio_language': x.audio_language,
'episode': x.episode,
'missing_subtitles': x.missing_subtitles,
'monitored': x.monitored,
'path': x.path,
'season': x.season,
'sonarrEpisodeId': x.sonarrEpisodeId,
'sonarrSeriesId': x.sonarrSeriesId,
'subtitles': x.subtitles,
'title': x.title,
'sceneName': x.sceneName,
}) for x in stmt_query]

View File

@ -7,7 +7,7 @@ from flask_restx import Resource, Namespace, reqparse
from subliminal_patch.core import SUBTITLE_EXTENSIONS
from werkzeug.datastructures import FileStorage
from app.database import TableShows, TableEpisodes, get_audio_profile_languages, get_profile_id
from app.database import TableShows, TableEpisodes, get_audio_profile_languages, get_profile_id, database, select
from utilities.path_mappings import path_mappings
from subtitles.upload import manual_upload_subtitle
from subtitles.download import generate_subtitles
@ -42,28 +42,28 @@ class EpisodesSubtitles(Resource):
args = self.patch_request_parser.parse_args()
sonarrSeriesId = args.get('seriesid')
sonarrEpisodeId = args.get('episodeid')
episodeInfo = TableEpisodes.select(
TableEpisodes.path,
TableEpisodes.sceneName,
TableEpisodes.audio_language,
TableShows.title) \
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \
.dicts() \
.get_or_none()
episodeInfo = database.execute(
select(TableEpisodes.path,
TableEpisodes.sceneName,
TableEpisodes.audio_language,
TableShows.title)
.select_from(TableEpisodes)
.join(TableShows)
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)) \
.first()
if not episodeInfo:
return 'Episode not found', 404
title = episodeInfo['title']
episodePath = path_mappings.path_replace(episodeInfo['path'])
sceneName = episodeInfo['sceneName'] or "None"
title = episodeInfo.title
episodePath = path_mappings.path_replace(episodeInfo.path)
sceneName = episodeInfo.sceneName or "None"
language = args.get('language')
hi = args.get('hi').capitalize()
forced = args.get('forced').capitalize()
audio_language_list = get_audio_profile_languages(episodeInfo["audio_language"])
audio_language_list = get_audio_profile_languages(episodeInfo.audio_language)
if len(audio_language_list) > 0:
audio_language = audio_language_list[0]['name']
else:
@ -104,18 +104,18 @@ class EpisodesSubtitles(Resource):
args = self.post_request_parser.parse_args()
sonarrSeriesId = args.get('seriesid')
sonarrEpisodeId = args.get('episodeid')
episodeInfo = TableEpisodes.select(TableEpisodes.path,
TableEpisodes.audio_language) \
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \
.dicts() \
.get_or_none()
episodeInfo = database.execute(
select(TableEpisodes.path,
TableEpisodes.audio_language)
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)) \
.first()
if not episodeInfo:
return 'Episode not found', 404
episodePath = path_mappings.path_replace(episodeInfo['path'])
episodePath = path_mappings.path_replace(episodeInfo.path)
audio_language = get_audio_profile_languages(episodeInfo['audio_language'])
audio_language = get_audio_profile_languages(episodeInfo.audio_language)
if len(audio_language) and isinstance(audio_language[0], dict):
audio_language = audio_language[0]
else:
@ -173,18 +173,15 @@ class EpisodesSubtitles(Resource):
args = self.delete_request_parser.parse_args()
sonarrSeriesId = args.get('seriesid')
sonarrEpisodeId = args.get('episodeid')
episodeInfo = TableEpisodes.select(TableEpisodes.title,
TableEpisodes.path,
TableEpisodes.sceneName,
TableEpisodes.audio_language) \
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \
.dicts() \
.get_or_none()
episodeInfo = database.execute(
select(TableEpisodes.path)
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)) \
.first()
if not episodeInfo:
return 'Episode not found', 404
episodePath = path_mappings.path_replace(episodeInfo['path'])
episodePath = path_mappings.path_replace(episodeInfo.path)
language = args.get('language')
forced = args.get('forced')

View File

@ -1,17 +1,15 @@
# coding=utf-8
import os
import operator
import pretty
from flask_restx import Resource, Namespace, reqparse, fields
import ast
from functools import reduce
from app.database import TableEpisodes, TableShows, TableHistory, TableBlacklist
from subtitles.upgrade import get_upgradable_episode_subtitles
from utilities.path_mappings import path_mappings
from api.swaggerui import subtitles_language_model
from app.database import TableEpisodes, TableShows, TableHistory, TableBlacklist, database, select, func
from subtitles.upgrade import get_upgradable_episode_subtitles, _language_still_desired
import pretty
from flask_restx import Resource, Namespace, reqparse, fields
from ..utils import authenticate, postprocess
api_ns_episodes_history = Namespace('Episodes History', description='List episodes history events')
@ -27,7 +25,6 @@ class EpisodesHistory(Resource):
get_language_model = api_ns_episodes_history.model('subtitles_language_model', subtitles_language_model)
data_model = api_ns_episodes_history.model('history_episodes_data_model', {
'id': fields.Integer(),
'seriesTitle': fields.String(),
'monitored': fields.Boolean(),
'episode_number': fields.String(),
@ -40,15 +37,14 @@ class EpisodesHistory(Resource):
'score': fields.String(),
'tags': fields.List(fields.String),
'action': fields.Integer(),
'video_path': fields.String(),
'subtitles_path': fields.String(),
'sonarrEpisodeId': fields.Integer(),
'provider': fields.String(),
'seriesType': fields.String(),
'upgradable': fields.Boolean(),
'raw_timestamp': fields.Integer(),
'parsed_timestamp': fields.String(),
'blacklisted': fields.Boolean(),
'matches': fields.List(fields.String),
'dont_matches': fields.List(fields.String),
})
get_response_model = api_ns_episodes_history.model('EpisodeHistoryGetResponse', {
@ -68,84 +64,116 @@ class EpisodesHistory(Resource):
episodeid = args.get('episodeid')
upgradable_episodes_not_perfect = get_upgradable_episode_subtitles()
if len(upgradable_episodes_not_perfect):
upgradable_episodes_not_perfect = [{"video_path": x['video_path'],
"timestamp": x['timestamp'],
"score": x['score'],
"tags": x['tags'],
"monitored": x['monitored'],
"seriesType": x['seriesType']}
for x in upgradable_episodes_not_perfect]
query_conditions = [(TableEpisodes.title.is_null(False))]
blacklisted_subtitles = select(TableBlacklist.provider,
TableBlacklist.subs_id) \
.subquery()
query_conditions = [(TableEpisodes.title.is_not(None))]
if episodeid:
query_conditions.append((TableEpisodes.sonarrEpisodeId == episodeid))
query_condition = reduce(operator.and_, query_conditions)
episode_history = TableHistory.select(TableHistory.id,
TableShows.title.alias('seriesTitle'),
TableEpisodes.monitored,
TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias(
'episode_number'),
TableEpisodes.title.alias('episodeTitle'),
TableHistory.timestamp,
TableHistory.subs_id,
TableHistory.description,
TableHistory.sonarrSeriesId,
TableEpisodes.path,
TableHistory.language,
TableHistory.score,
TableShows.tags,
TableHistory.action,
TableHistory.video_path,
TableHistory.subtitles_path,
TableHistory.sonarrEpisodeId,
TableHistory.provider,
TableShows.seriesType) \
.join(TableShows, on=(TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId)) \
.join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId)) \
.where(query_condition) \
stmt = select(TableHistory.id,
TableShows.title.label('seriesTitle'),
TableEpisodes.monitored,
TableEpisodes.season.concat('x').concat(TableEpisodes.episode).label('episode_number'),
TableEpisodes.title.label('episodeTitle'),
TableHistory.timestamp,
TableHistory.subs_id,
TableHistory.description,
TableHistory.sonarrSeriesId,
TableEpisodes.path,
TableHistory.language,
TableHistory.score,
TableShows.tags,
TableHistory.action,
TableHistory.video_path,
TableHistory.subtitles_path,
TableHistory.sonarrEpisodeId,
TableHistory.provider,
TableShows.seriesType,
TableShows.profileId,
TableHistory.matched,
TableHistory.not_matched,
TableEpisodes.subtitles.label('external_subtitles'),
upgradable_episodes_not_perfect.c.id.label('upgradable'),
blacklisted_subtitles.c.subs_id.label('blacklisted')) \
.select_from(TableHistory) \
.join(TableShows, onclause=TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId) \
.join(TableEpisodes, onclause=TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId) \
.join(upgradable_episodes_not_perfect, onclause=TableHistory.id == upgradable_episodes_not_perfect.c.id,
isouter=True) \
.join(blacklisted_subtitles, onclause=TableHistory.subs_id == blacklisted_subtitles.c.subs_id,
isouter=True) \
.where(reduce(operator.and_, query_conditions)) \
.order_by(TableHistory.timestamp.desc())
if length > 0:
episode_history = episode_history.limit(length).offset(start)
episode_history = list(episode_history.dicts())
blacklist_db = TableBlacklist.select(TableBlacklist.provider, TableBlacklist.subs_id).dicts()
blacklist_db = list(blacklist_db)
stmt = stmt.limit(length).offset(start)
episode_history = [{
'id': x.id,
'seriesTitle': x.seriesTitle,
'monitored': x.monitored,
'episode_number': x.episode_number,
'episodeTitle': x.episodeTitle,
'timestamp': x.timestamp,
'subs_id': x.subs_id,
'description': x.description,
'sonarrSeriesId': x.sonarrSeriesId,
'path': x.path,
'language': x.language,
'score': x.score,
'tags': x.tags,
'action': x.action,
'video_path': x.video_path,
'subtitles_path': x.subtitles_path,
'sonarrEpisodeId': x.sonarrEpisodeId,
'provider': x.provider,
'matches': x.matched,
'dont_matches': x.not_matched,
'external_subtitles': [y[1] for y in ast.literal_eval(x.external_subtitles) if y[1]],
'upgradable': bool(x.upgradable) if _language_still_desired(x.language, x.profileId) else False,
'blacklisted': bool(x.blacklisted),
} for x in database.execute(stmt).all()]
for item in episode_history:
# Mark episode as upgradable or not
item.update({"upgradable": False})
if {"video_path": str(item['path']), "timestamp": item['timestamp'], "score": item['score'],
"tags": str(item['tags']), "monitored": str(item['monitored']),
"seriesType": str(item['seriesType'])} in upgradable_episodes_not_perfect: # noqa: E129
if os.path.exists(path_mappings.path_replace(item['subtitles_path'])) and \
os.path.exists(path_mappings.path_replace(item['video_path'])):
item.update({"upgradable": True})
original_video_path = item['path']
original_subtitle_path = item['subtitles_path']
item.update(postprocess(item))
# Mark not upgradable if score is perfect or if video/subtitles file doesn't exist anymore
if item['upgradable']:
if original_subtitle_path not in item['external_subtitles'] or \
not item['video_path'] == original_video_path:
item.update({"upgradable": False})
del item['path']
postprocess(item)
del item['video_path']
del item['external_subtitles']
if item['score']:
item['score'] = str(round((int(item['score']) * 100 / 360), 2)) + "%"
# Make timestamp pretty
if item['timestamp']:
item["raw_timestamp"] = item['timestamp'].timestamp()
item["parsed_timestamp"] = item['timestamp'].strftime('%x %X')
item['timestamp'] = pretty.date(item["timestamp"])
# Check if subtitles is blacklisted
item.update({"blacklisted": False})
if item['action'] not in [0, 4, 5]:
for blacklisted_item in blacklist_db:
if blacklisted_item['provider'] == item['provider'] and \
blacklisted_item['subs_id'] == item['subs_id']:
item.update({"blacklisted": True})
break
# Parse matches and dont_matches
if item['matches']:
item.update({'matches': ast.literal_eval(item['matches'])})
else:
item.update({'matches': []})
count = TableHistory.select() \
.join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId)) \
.where(TableEpisodes.title.is_null(False)).count()
if item['dont_matches']:
item.update({'dont_matches': ast.literal_eval(item['dont_matches'])})
else:
item.update({'dont_matches': []})
count = database.execute(
select(func.count())
.select_from(TableHistory)
.join(TableEpisodes)
.where(TableEpisodes.title.is_not(None))) \
.scalar()
return {'data': episode_history, 'total': count}

View File

@ -5,7 +5,7 @@ import operator
from flask_restx import Resource, Namespace, reqparse, fields
from functools import reduce
from app.database import get_exclusion_clause, TableEpisodes, TableShows
from app.database import get_exclusion_clause, TableEpisodes, TableShows, database, select, func
from api.swaggerui import subtitles_language_model
from ..utils import authenticate, postprocess
@ -25,7 +25,6 @@ class EpisodesWanted(Resource):
data_model = api_ns_episodes_wanted.model('wanted_episodes_data_model', {
'seriesTitle': fields.String(),
'monitored': fields.Boolean(),
'episode_number': fields.String(),
'episodeTitle': fields.String(),
'missing_subtitles': fields.Nested(get_subtitles_language_model),
@ -33,7 +32,6 @@ class EpisodesWanted(Resource):
'sonarrEpisodeId': fields.Integer(),
'sceneName': fields.String(),
'tags': fields.List(fields.String),
'failedAttempts': fields.String(),
'seriesType': fields.String(),
})
@ -54,56 +52,48 @@ class EpisodesWanted(Resource):
wanted_conditions = [(TableEpisodes.missing_subtitles != '[]')]
if len(episodeid) > 0:
wanted_conditions.append((TableEpisodes.sonarrEpisodeId in episodeid))
wanted_conditions += get_exclusion_clause('series')
wanted_condition = reduce(operator.and_, wanted_conditions)
if len(episodeid) > 0:
data = TableEpisodes.select(TableShows.title.alias('seriesTitle'),
TableEpisodes.monitored,
TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'),
TableEpisodes.title.alias('episodeTitle'),
TableEpisodes.missing_subtitles,
TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.sceneName,
TableShows.tags,
TableEpisodes.failedAttempts,
TableShows.seriesType)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(wanted_condition)\
.dicts()
start = 0
length = 0
else:
start = args.get('start')
length = args.get('length')
data = TableEpisodes.select(TableShows.title.alias('seriesTitle'),
TableEpisodes.monitored,
TableEpisodes.season.concat('x').concat(TableEpisodes.episode).alias('episode_number'),
TableEpisodes.title.alias('episodeTitle'),
TableEpisodes.missing_subtitles,
TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.sceneName,
TableShows.tags,
TableEpisodes.failedAttempts,
TableShows.seriesType)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(wanted_condition)\
.order_by(TableEpisodes.rowid.desc())
if length > 0:
data = data.limit(length).offset(start)
data = data.dicts()
data = list(data)
for item in data:
postprocess(item)
wanted_conditions += get_exclusion_clause('series')
wanted_condition = reduce(operator.and_, wanted_conditions)
count_conditions = [(TableEpisodes.missing_subtitles != '[]')]
count_conditions += get_exclusion_clause('series')
count = TableEpisodes.select(TableShows.tags,
TableShows.seriesType,
TableEpisodes.monitored)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(reduce(operator.and_, count_conditions))\
.count()
stmt = select(TableShows.title.label('seriesTitle'),
TableEpisodes.season.concat('x').concat(TableEpisodes.episode).label('episode_number'),
TableEpisodes.title.label('episodeTitle'),
TableEpisodes.missing_subtitles,
TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.sceneName,
TableShows.tags,
TableShows.seriesType) \
.select_from(TableEpisodes) \
.join(TableShows) \
.where(wanted_condition)
return {'data': data, 'total': count}
if length > 0:
stmt = stmt.order_by(TableEpisodes.sonarrEpisodeId.desc()).limit(length).offset(start)
results = [postprocess({
'seriesTitle': x.seriesTitle,
'episode_number': x.episode_number,
'episodeTitle': x.episodeTitle,
'missing_subtitles': x.missing_subtitles,
'sonarrSeriesId': x.sonarrSeriesId,
'sonarrEpisodeId': x.sonarrEpisodeId,
'sceneName': x.sceneName,
'tags': x.tags,
'seriesType': x.seriesType,
}) for x in database.execute(stmt).all()]
count = database.execute(
select(func.count())
.select_from(TableEpisodes)
.join(TableShows)
.where(wanted_condition)) \
.scalar()
return {'data': results, 'total': count}

View File

@ -8,7 +8,7 @@ from dateutil import rrule
from flask_restx import Resource, Namespace, reqparse, fields
from functools import reduce
from app.database import TableHistory, TableHistoryMovie
from app.database import TableHistory, TableHistoryMovie, database, select
from ..utils import authenticate
@ -86,17 +86,25 @@ class HistoryStats(Resource):
history_where_clause = reduce(operator.and_, history_where_clauses)
history_where_clause_movie = reduce(operator.and_, history_where_clauses_movie)
data_series = TableHistory.select(TableHistory.timestamp, TableHistory.id)\
.where(history_where_clause) \
.dicts()
data_series = [{
'timestamp': x.timestamp,
'id': x.id,
} for x in database.execute(
select(TableHistory.timestamp, TableHistory.id)
.where(history_where_clause))
.all()]
data_series = [{'date': date[0], 'count': sum(1 for item in date[1])} for date in
itertools.groupby(list(data_series),
key=lambda x: x['timestamp'].strftime(
'%Y-%m-%d'))]
data_movies = TableHistoryMovie.select(TableHistoryMovie.timestamp, TableHistoryMovie.id) \
.where(history_where_clause_movie) \
.dicts()
data_movies = [{
'timestamp': x.timestamp,
'id': x.id,
} for x in database.execute(
select(TableHistoryMovie.timestamp, TableHistoryMovie.id)
.where(history_where_clause_movie))
.all()]
data_movies = [{'date': date[0], 'count': sum(1 for item in date[1])} for date in
itertools.groupby(list(data_movies),
key=lambda x: x['timestamp'].strftime(

View File

@ -4,7 +4,7 @@ import pretty
from flask_restx import Resource, Namespace, reqparse, fields
from app.database import TableMovies, TableBlacklistMovie
from app.database import TableMovies, TableBlacklistMovie, database, select
from subtitles.tools.delete import delete_subtitles
from radarr.blacklist import blacklist_log_movie, blacklist_delete_all_movie, blacklist_delete_movie
from utilities.path_mappings import path_mappings
@ -46,26 +46,28 @@ class MoviesBlacklist(Resource):
start = args.get('start')
length = args.get('length')
data = TableBlacklistMovie.select(TableMovies.title,
TableMovies.radarrId,
TableBlacklistMovie.provider,
TableBlacklistMovie.subs_id,
TableBlacklistMovie.language,
TableBlacklistMovie.timestamp)\
.join(TableMovies, on=(TableBlacklistMovie.radarr_id == TableMovies.radarrId))\
.order_by(TableBlacklistMovie.timestamp.desc())
data = database.execute(
select(TableMovies.title,
TableMovies.radarrId,
TableBlacklistMovie.provider,
TableBlacklistMovie.subs_id,
TableBlacklistMovie.language,
TableBlacklistMovie.timestamp)
.select_from(TableBlacklistMovie)
.join(TableMovies)
.order_by(TableBlacklistMovie.timestamp.desc()))
if length > 0:
data = data.limit(length).offset(start)
data = list(data.dicts())
for item in data:
postprocess(item)
# Make timestamp pretty
item["parsed_timestamp"] = item['timestamp'].strftime('%x %X')
item.update({'timestamp': pretty.date(item['timestamp'])})
return data
return [postprocess({
'title': x.title,
'radarrId': x.radarrId,
'provider': x.provider,
'subs_id': x.subs_id,
'language': x.language,
'timestamp': pretty.date(x.timestamp),
'parsed_timestamp': x.timestamp.strftime('%x %X'),
}) for x in data.all()]
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('radarrid', type=int, required=True, help='Radarr ID')
@ -90,12 +92,15 @@ class MoviesBlacklist(Resource):
forced = False
hi = False
data = TableMovies.select(TableMovies.path).where(TableMovies.radarrId == radarr_id).dicts().get_or_none()
data = database.execute(
select(TableMovies.path)
.where(TableMovies.radarrId == radarr_id))\
.first()
if not data:
return 'Movie not found', 404
media_path = data['path']
media_path = data.path
subtitles_path = args.get('subtitles_path')
blacklist_log_movie(radarr_id=radarr_id,

View File

@ -1,15 +1,14 @@
# coding=utf-8
import os
import operator
import pretty
import ast
from flask_restx import Resource, Namespace, reqparse, fields
from functools import reduce
from app.database import TableMovies, TableHistoryMovie, TableBlacklistMovie
from subtitles.upgrade import get_upgradable_movies_subtitles
from utilities.path_mappings import path_mappings
from app.database import TableMovies, TableHistoryMovie, TableBlacklistMovie, database, select, func
from subtitles.upgrade import get_upgradable_movies_subtitles, _language_still_desired
from api.swaggerui import subtitles_language_model
from api.utils import authenticate, postprocess
@ -27,7 +26,6 @@ class MoviesHistory(Resource):
get_language_model = api_ns_movies_history.model('subtitles_language_model', subtitles_language_model)
data_model = api_ns_movies_history.model('history_movies_data_model', {
'id': fields.Integer(),
'action': fields.Integer(),
'title': fields.String(),
'timestamp': fields.String(),
@ -42,9 +40,10 @@ class MoviesHistory(Resource):
'provider': fields.String(),
'subtitles_path': fields.String(),
'upgradable': fields.Boolean(),
'raw_timestamp': fields.Integer(),
'parsed_timestamp': fields.String(),
'blacklisted': fields.Boolean(),
'matches': fields.List(fields.String),
'dont_matches': fields.List(fields.String),
})
get_response_model = api_ns_movies_history.model('MovieHistoryGetResponse', {
@ -64,79 +63,108 @@ class MoviesHistory(Resource):
radarrid = args.get('radarrid')
upgradable_movies_not_perfect = get_upgradable_movies_subtitles()
if len(upgradable_movies_not_perfect):
upgradable_movies_not_perfect = [{"video_path": x['video_path'],
"timestamp": x['timestamp'],
"score": x['score'],
"tags": x['tags'],
"monitored": x['monitored']}
for x in upgradable_movies_not_perfect]
query_conditions = [(TableMovies.title.is_null(False))]
blacklisted_subtitles = select(TableBlacklistMovie.provider,
TableBlacklistMovie.subs_id) \
.subquery()
query_conditions = [(TableMovies.title.is_not(None))]
if radarrid:
query_conditions.append((TableMovies.radarrId == radarrid))
query_condition = reduce(operator.and_, query_conditions)
movie_history = TableHistoryMovie.select(TableHistoryMovie.id,
TableHistoryMovie.action,
TableMovies.title,
TableHistoryMovie.timestamp,
TableHistoryMovie.description,
TableHistoryMovie.radarrId,
TableMovies.monitored,
TableHistoryMovie.video_path.alias('path'),
TableHistoryMovie.language,
TableMovies.tags,
TableHistoryMovie.score,
TableHistoryMovie.subs_id,
TableHistoryMovie.provider,
TableHistoryMovie.subtitles_path,
TableHistoryMovie.video_path) \
.join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId)) \
.where(query_condition) \
stmt = select(TableHistoryMovie.id,
TableHistoryMovie.action,
TableMovies.title,
TableHistoryMovie.timestamp,
TableHistoryMovie.description,
TableHistoryMovie.radarrId,
TableMovies.monitored,
TableMovies.path,
TableHistoryMovie.language,
TableMovies.tags,
TableHistoryMovie.score,
TableHistoryMovie.subs_id,
TableHistoryMovie.provider,
TableHistoryMovie.subtitles_path,
TableHistoryMovie.video_path,
TableHistoryMovie.matched,
TableHistoryMovie.not_matched,
TableMovies.profileId,
TableMovies.subtitles.label('external_subtitles'),
upgradable_movies_not_perfect.c.id.label('upgradable'),
blacklisted_subtitles.c.subs_id.label('blacklisted')) \
.select_from(TableHistoryMovie) \
.join(TableMovies) \
.join(upgradable_movies_not_perfect, onclause=TableHistoryMovie.id == upgradable_movies_not_perfect.c.id,
isouter=True) \
.join(blacklisted_subtitles, onclause=TableHistoryMovie.subs_id == blacklisted_subtitles.c.subs_id,
isouter=True) \
.where(reduce(operator.and_, query_conditions)) \
.order_by(TableHistoryMovie.timestamp.desc())
if length > 0:
movie_history = movie_history.limit(length).offset(start)
movie_history = list(movie_history.dicts())
blacklist_db = TableBlacklistMovie.select(TableBlacklistMovie.provider, TableBlacklistMovie.subs_id).dicts()
blacklist_db = list(blacklist_db)
stmt = stmt.limit(length).offset(start)
movie_history = [{
'id': x.id,
'action': x.action,
'title': x.title,
'timestamp': x.timestamp,
'description': x.description,
'radarrId': x.radarrId,
'monitored': x.monitored,
'path': x.path,
'language': x.language,
'tags': x.tags,
'score': x.score,
'subs_id': x.subs_id,
'provider': x.provider,
'subtitles_path': x.subtitles_path,
'video_path': x.video_path,
'matches': x.matched,
'dont_matches': x.not_matched,
'external_subtitles': [y[1] for y in ast.literal_eval(x.external_subtitles) if y[1]],
'upgradable': bool(x.upgradable) if _language_still_desired(x.language, x.profileId) else False,
'blacklisted': bool(x.blacklisted),
} for x in database.execute(stmt).all()]
for item in movie_history:
# Mark movies as upgradable or not
item.update({"upgradable": False})
if {"video_path": str(item['path']), "timestamp": item['timestamp'], "score": item['score'],
"tags": str(item['tags']),
"monitored": str(item['monitored'])} in upgradable_movies_not_perfect: # noqa: E129
if os.path.exists(path_mappings.path_replace_movie(item['subtitles_path'])) and \
os.path.exists(path_mappings.path_replace_movie(item['video_path'])):
item.update({"upgradable": True})
original_video_path = item['path']
original_subtitle_path = item['subtitles_path']
item.update(postprocess(item))
# Mark not upgradable if score or if video/subtitles file doesn't exist anymore
if item['upgradable']:
if original_subtitle_path not in item['external_subtitles'] or \
not item['video_path'] == original_video_path:
item.update({"upgradable": False})
del item['path']
postprocess(item)
del item['video_path']
del item['external_subtitles']
if item['score']:
item['score'] = str(round((int(item['score']) * 100 / 120), 2)) + "%"
# Make timestamp pretty
if item['timestamp']:
item["raw_timestamp"] = item['timestamp'].timestamp()
item["parsed_timestamp"] = item['timestamp'].strftime('%x %X')
item['timestamp'] = pretty.date(item["timestamp"])
# Check if subtitles is blacklisted
item.update({"blacklisted": False})
if item['action'] not in [0, 4, 5]:
for blacklisted_item in blacklist_db:
if blacklisted_item['provider'] == item['provider'] and blacklisted_item['subs_id'] == item[
'subs_id']: # noqa: E125
item.update({"blacklisted": True})
break
# Parse matches and dont_matches
if item['matches']:
item.update({'matches': ast.literal_eval(item['matches'])})
else:
item.update({'matches': []})
count = TableHistoryMovie.select() \
.join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId)) \
.where(TableMovies.title.is_null(False)) \
.count()
if item['dont_matches']:
item.update({'dont_matches': ast.literal_eval(item['dont_matches'])})
else:
item.update({'dont_matches': []})
count = database.execute(
select(func.count())
.select_from(TableHistoryMovie)
.join(TableMovies)
.where(TableMovies.title.is_not(None))) \
.scalar()
return {'data': movie_history, 'total': count}

View File

@ -2,7 +2,7 @@
from flask_restx import Resource, Namespace, reqparse, fields
from app.database import TableMovies
from app.database import TableMovies, database, update, select, func
from subtitles.indexer.movies import list_missing_subtitles_movies, movies_scan_subtitles
from app.event_handler import event_stream
from subtitles.wanted import wanted_search_missing_subtitles_movies
@ -29,30 +29,20 @@ class Movies(Resource):
data_model = api_ns_movies.model('movies_data_model', {
'alternativeTitles': fields.List(fields.String),
'audio_codec': fields.String(),
'audio_language': fields.Nested(get_audio_language_model),
'failedAttempts': fields.String(),
'fanart': fields.String(),
'file_size': fields.Integer(),
'format': fields.String(),
'imdbId': fields.String(),
'missing_subtitles': fields.Nested(get_subtitles_language_model),
'monitored': fields.Boolean(),
'movie_file_id': fields.Integer(),
'overview': fields.String(),
'path': fields.String(),
'poster': fields.String(),
'profileId': fields.Integer(),
'radarrId': fields.Integer(),
'resolution': fields.String(),
'rowid': fields.Integer(),
'sceneName': fields.String(),
'sortTitle': fields.String(),
'subtitles': fields.Nested(get_subtitles_model),
'tags': fields.List(fields.String),
'title': fields.String(),
'tmdbId': fields.String(),
'video_codec': fields.String(),
'year': fields.String(),
})
@ -73,23 +63,56 @@ class Movies(Resource):
length = args.get('length')
radarrId = args.get('radarrid[]')
count = TableMovies.select().count()
stmt = select(TableMovies.alternativeTitles,
TableMovies.audio_language,
TableMovies.fanart,
TableMovies.imdbId,
TableMovies.missing_subtitles,
TableMovies.monitored,
TableMovies.overview,
TableMovies.path,
TableMovies.poster,
TableMovies.profileId,
TableMovies.radarrId,
TableMovies.sceneName,
TableMovies.subtitles,
TableMovies.tags,
TableMovies.title,
TableMovies.year,
)\
.order_by(TableMovies.sortTitle)
if len(radarrId) != 0:
result = TableMovies.select()\
.where(TableMovies.radarrId.in_(radarrId))\
.order_by(TableMovies.sortTitle)\
.dicts()
else:
result = TableMovies.select().order_by(TableMovies.sortTitle)
if length > 0:
result = result.limit(length).offset(start)
result = result.dicts()
result = list(result)
for item in result:
postprocess(item)
stmt = stmt.where(TableMovies.radarrId.in_(radarrId))
return {'data': result, 'total': count}
if length > 0:
stmt = stmt.limit(length).offset(start)
results = [postprocess({
'alternativeTitles': x.alternativeTitles,
'audio_language': x.audio_language,
'fanart': x.fanart,
'imdbId': x.imdbId,
'missing_subtitles': x.missing_subtitles,
'monitored': x.monitored,
'overview': x.overview,
'path': x.path,
'poster': x.poster,
'profileId': x.profileId,
'radarrId': x.radarrId,
'sceneName': x.sceneName,
'subtitles': x.subtitles,
'tags': x.tags,
'title': x.title,
'year': x.year,
}) for x in database.execute(stmt).all()]
count = database.execute(
select(func.count())
.select_from(TableMovies)) \
.scalar()
return {'data': results, 'total': count}
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('radarrid', type=int, action='append', required=False, default=[],
@ -120,11 +143,10 @@ class Movies(Resource):
except Exception:
return 'Languages profile not found', 404
TableMovies.update({
TableMovies.profileId: profileId
})\
.where(TableMovies.radarrId == radarrId)\
.execute()
database.execute(
update(TableMovies)
.values(profileId=profileId)
.where(TableMovies.radarrId == radarrId))
list_missing_subtitles_movies(no=radarrId, send_event=False)

View File

@ -8,7 +8,7 @@ from flask_restx import Resource, Namespace, reqparse
from subliminal_patch.core import SUBTITLE_EXTENSIONS
from werkzeug.datastructures import FileStorage
from app.database import TableMovies, get_audio_profile_languages, get_profile_id
from app.database import TableMovies, get_audio_profile_languages, get_profile_id, database, select
from utilities.path_mappings import path_mappings
from subtitles.upload import manual_upload_subtitle
from subtitles.download import generate_subtitles
@ -42,28 +42,28 @@ class MoviesSubtitles(Resource):
args = self.patch_request_parser.parse_args()
radarrId = args.get('radarrid')
movieInfo = TableMovies.select(
TableMovies.title,
TableMovies.path,
TableMovies.sceneName,
TableMovies.audio_language) \
.where(TableMovies.radarrId == radarrId) \
.dicts() \
.get_or_none()
movieInfo = database.execute(
select(
TableMovies.title,
TableMovies.path,
TableMovies.sceneName,
TableMovies.audio_language)
.where(TableMovies.radarrId == radarrId)) \
.first()
if not movieInfo:
return 'Movie not found', 404
moviePath = path_mappings.path_replace_movie(movieInfo['path'])
sceneName = movieInfo['sceneName'] or 'None'
moviePath = path_mappings.path_replace_movie(movieInfo.path)
sceneName = movieInfo.sceneName or 'None'
title = movieInfo['title']
title = movieInfo.title
language = args.get('language')
hi = args.get('hi').capitalize()
forced = args.get('forced').capitalize()
audio_language_list = get_audio_profile_languages(movieInfo["audio_language"])
audio_language_list = get_audio_profile_languages(movieInfo.audio_language)
if len(audio_language_list) > 0:
audio_language = audio_language_list[0]['name']
else:
@ -99,18 +99,17 @@ class MoviesSubtitles(Resource):
# TODO: Support Multiply Upload
args = self.post_request_parser.parse_args()
radarrId = args.get('radarrid')
movieInfo = TableMovies.select(TableMovies.path,
TableMovies.audio_language) \
.where(TableMovies.radarrId == radarrId) \
.dicts() \
.get_or_none()
movieInfo = database.execute(
select(TableMovies.path, TableMovies.audio_language)
.where(TableMovies.radarrId == radarrId)) \
.first()
if not movieInfo:
return 'Movie not found', 404
moviePath = path_mappings.path_replace_movie(movieInfo['path'])
moviePath = path_mappings.path_replace_movie(movieInfo.path)
audio_language = get_audio_profile_languages(movieInfo['audio_language'])
audio_language = get_audio_profile_languages(movieInfo.audio_language)
if len(audio_language) and isinstance(audio_language[0], dict):
audio_language = audio_language[0]
else:
@ -163,15 +162,15 @@ class MoviesSubtitles(Resource):
"""Delete a movie subtitles"""
args = self.delete_request_parser.parse_args()
radarrId = args.get('radarrid')
movieInfo = TableMovies.select(TableMovies.path) \
.where(TableMovies.radarrId == radarrId) \
.dicts() \
.get_or_none()
movieInfo = database.execute(
select(TableMovies.path)
.where(TableMovies.radarrId == radarrId)) \
.first()
if not movieInfo:
return 'Movie not found', 404
moviePath = path_mappings.path_replace_movie(movieInfo['path'])
moviePath = path_mappings.path_replace_movie(movieInfo.path)
language = args.get('language')
forced = args.get('forced')

View File

@ -5,7 +5,7 @@ import operator
from flask_restx import Resource, Namespace, reqparse, fields
from functools import reduce
from app.database import get_exclusion_clause, TableMovies
from app.database import get_exclusion_clause, TableMovies, database, select, func
from api.swaggerui import subtitles_language_model
from api.utils import authenticate, postprocess
@ -26,12 +26,10 @@ class MoviesWanted(Resource):
data_model = api_ns_movies_wanted.model('wanted_movies_data_model', {
'title': fields.String(),
'monitored': fields.Boolean(),
'missing_subtitles': fields.Nested(get_subtitles_language_model),
'radarrId': fields.Integer(),
'sceneName': fields.String(),
'tags': fields.List(fields.String),
'failedAttempts': fields.String(),
})
get_response_model = api_ns_movies_wanted.model('MovieWantedGetResponse', {
@ -51,44 +49,36 @@ class MoviesWanted(Resource):
wanted_conditions = [(TableMovies.missing_subtitles != '[]')]
if len(radarrid) > 0:
wanted_conditions.append((TableMovies.radarrId.in_(radarrid)))
wanted_conditions += get_exclusion_clause('movie')
wanted_condition = reduce(operator.and_, wanted_conditions)
if len(radarrid) > 0:
result = TableMovies.select(TableMovies.title,
TableMovies.missing_subtitles,
TableMovies.radarrId,
TableMovies.sceneName,
TableMovies.failedAttempts,
TableMovies.tags,
TableMovies.monitored)\
.where(wanted_condition)\
.dicts()
start = 0
length = 0
else:
start = args.get('start')
length = args.get('length')
result = TableMovies.select(TableMovies.title,
TableMovies.missing_subtitles,
TableMovies.radarrId,
TableMovies.sceneName,
TableMovies.failedAttempts,
TableMovies.tags,
TableMovies.monitored)\
.where(wanted_condition)\
.order_by(TableMovies.rowid.desc())
if length > 0:
result = result.limit(length).offset(start)
result = result.dicts()
result = list(result)
for item in result:
postprocess(item)
wanted_conditions += get_exclusion_clause('movie')
wanted_condition = reduce(operator.and_, wanted_conditions)
count_conditions = [(TableMovies.missing_subtitles != '[]')]
count_conditions += get_exclusion_clause('movie')
count = TableMovies.select(TableMovies.monitored,
TableMovies.tags)\
.where(reduce(operator.and_, count_conditions))\
.count()
stmt = select(TableMovies.title,
TableMovies.missing_subtitles,
TableMovies.radarrId,
TableMovies.sceneName,
TableMovies.tags) \
.where(wanted_condition)
if length > 0:
stmt = stmt.order_by(TableMovies.radarrId.desc()).limit(length).offset(start)
return {'data': result, 'total': count}
results = [postprocess({
'title': x.title,
'missing_subtitles': x.missing_subtitles,
'radarrId': x.radarrId,
'sceneName': x.sceneName,
'tags': x.tags,
}) for x in database.execute(stmt).all()]
count = database.execute(
select(func.count())
.select_from(TableMovies)
.where(wanted_condition)) \
.scalar()
return {'data': results, 'total': count}

View File

@ -3,7 +3,7 @@
from flask_restx import Resource, Namespace, reqparse, fields
from operator import itemgetter
from app.database import TableHistory, TableHistoryMovie
from app.database import TableHistory, TableHistoryMovie, database, select
from app.get_providers import list_throttled_providers, reset_throttled_providers
from ..utils import authenticate, False_Keys
@ -32,20 +32,25 @@ class Providers(Resource):
args = self.get_request_parser.parse_args()
history = args.get('history')
if history and history not in False_Keys:
providers = list(TableHistory.select(TableHistory.provider)
.where(TableHistory.provider is not None and TableHistory.provider != "manual")
.dicts())
providers += list(TableHistoryMovie.select(TableHistoryMovie.provider)
.where(TableHistoryMovie.provider is not None and TableHistoryMovie.provider != "manual")
.dicts())
providers_list = list(set([x['provider'] for x in providers]))
providers = database.execute(
select(TableHistory.provider)
.where(TableHistory.provider and TableHistory.provider != "manual")
.distinct())\
.all()
providers += database.execute(
select(TableHistoryMovie.provider)
.where(TableHistoryMovie.provider and TableHistoryMovie.provider != "manual")
.distinct())\
.all()
providers_list = [x.provider for x in providers]
providers_dicts = []
for provider in providers_list:
providers_dicts.append({
'name': provider,
'status': 'History',
'retry': '-'
})
if provider not in [x['name'] for x in providers_dicts]:
providers_dicts.append({
'name': provider,
'status': 'History',
'retry': '-'
})
else:
throttled_providers = list_throttled_providers()

View File

@ -2,7 +2,7 @@
from flask_restx import Resource, Namespace, reqparse, fields
from app.database import TableEpisodes, TableShows, get_audio_profile_languages, get_profile_id
from app.database import TableEpisodes, TableShows, get_audio_profile_languages, get_profile_id, database, select
from utilities.path_mappings import path_mappings
from app.get_providers import get_providers
from subtitles.manual import manual_search, manual_download_subtitle
@ -47,22 +47,23 @@ class ProviderEpisodes(Resource):
"""Search manually for an episode subtitles"""
args = self.get_request_parser.parse_args()
sonarrEpisodeId = args.get('episodeid')
episodeInfo = TableEpisodes.select(TableEpisodes.path,
TableEpisodes.sceneName,
TableShows.title,
TableShows.profileId) \
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \
.dicts() \
.get_or_none()
episodeInfo = database.execute(
select(TableEpisodes.path,
TableEpisodes.sceneName,
TableShows.title,
TableShows.profileId)
.select_from(TableEpisodes)
.join(TableShows)
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)) \
.first()
if not episodeInfo:
return 'Episode not found', 404
title = episodeInfo['title']
episodePath = path_mappings.path_replace(episodeInfo['path'])
sceneName = episodeInfo['sceneName'] or "None"
profileId = episodeInfo['profileId']
title = episodeInfo.title
episodePath = path_mappings.path_replace(episodeInfo.path)
sceneName = episodeInfo.sceneName or "None"
profileId = episodeInfo.profileId
providers_list = get_providers()
@ -91,22 +92,23 @@ class ProviderEpisodes(Resource):
args = self.post_request_parser.parse_args()
sonarrSeriesId = args.get('seriesid')
sonarrEpisodeId = args.get('episodeid')
episodeInfo = TableEpisodes.select(
TableEpisodes.audio_language,
TableEpisodes.path,
TableEpisodes.sceneName,
TableShows.title) \
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId) \
.dicts() \
.get_or_none()
episodeInfo = database.execute(
select(
TableEpisodes.audio_language,
TableEpisodes.path,
TableEpisodes.sceneName,
TableShows.title)
.select_from(TableEpisodes)
.join(TableShows)
.where(TableEpisodes.sonarrEpisodeId == sonarrEpisodeId)) \
.first()
if not episodeInfo:
return 'Episode not found', 404
title = episodeInfo['title']
episodePath = path_mappings.path_replace(episodeInfo['path'])
sceneName = episodeInfo['sceneName'] or "None"
title = episodeInfo.title
episodePath = path_mappings.path_replace(episodeInfo.path)
sceneName = episodeInfo.sceneName or "None"
hi = args.get('hi').capitalize()
forced = args.get('forced').capitalize()
@ -114,7 +116,7 @@ class ProviderEpisodes(Resource):
selected_provider = args.get('provider')
subtitle = args.get('subtitle')
audio_language_list = get_audio_profile_languages(episodeInfo["audio_language"])
audio_language_list = get_audio_profile_languages(episodeInfo.audio_language)
if len(audio_language_list) > 0:
audio_language = audio_language_list[0]['name']
else:

View File

@ -2,7 +2,7 @@
from flask_restx import Resource, Namespace, reqparse, fields
from app.database import TableMovies, get_audio_profile_languages, get_profile_id
from app.database import TableMovies, get_audio_profile_languages, get_profile_id, database, select
from utilities.path_mappings import path_mappings
from app.get_providers import get_providers
from subtitles.manual import manual_search, manual_download_subtitle
@ -48,21 +48,21 @@ class ProviderMovies(Resource):
"""Search manually for a movie subtitles"""
args = self.get_request_parser.parse_args()
radarrId = args.get('radarrid')
movieInfo = TableMovies.select(TableMovies.title,
TableMovies.path,
TableMovies.sceneName,
TableMovies.profileId) \
.where(TableMovies.radarrId == radarrId) \
.dicts() \
.get_or_none()
movieInfo = database.execute(
select(TableMovies.title,
TableMovies.path,
TableMovies.sceneName,
TableMovies.profileId)
.where(TableMovies.radarrId == radarrId)) \
.first()
if not movieInfo:
return 'Movie not found', 404
title = movieInfo['title']
moviePath = path_mappings.path_replace_movie(movieInfo['path'])
sceneName = movieInfo['sceneName'] or "None"
profileId = movieInfo['profileId']
title = movieInfo.title
moviePath = path_mappings.path_replace_movie(movieInfo.path)
sceneName = movieInfo.sceneName or "None"
profileId = movieInfo.profileId
providers_list = get_providers()
@ -89,20 +89,20 @@ class ProviderMovies(Resource):
"""Manually download a movie subtitles"""
args = self.post_request_parser.parse_args()
radarrId = args.get('radarrid')
movieInfo = TableMovies.select(TableMovies.title,
TableMovies.path,
TableMovies.sceneName,
TableMovies.audio_language) \
.where(TableMovies.radarrId == radarrId) \
.dicts() \
.get_or_none()
movieInfo = database.execute(
select(TableMovies.title,
TableMovies.path,
TableMovies.sceneName,
TableMovies.audio_language)
.where(TableMovies.radarrId == radarrId)) \
.first()
if not movieInfo:
return 'Movie not found', 404
title = movieInfo['title']
moviePath = path_mappings.path_replace_movie(movieInfo['path'])
sceneName = movieInfo['sceneName'] or "None"
title = movieInfo.title
moviePath = path_mappings.path_replace_movie(movieInfo.path)
sceneName = movieInfo.sceneName or "None"
hi = args.get('hi').capitalize()
forced = args.get('forced').capitalize()
@ -110,7 +110,7 @@ class ProviderMovies(Resource):
selected_provider = args.get('provider')
subtitle = args.get('subtitle')
audio_language_list = get_audio_profile_languages(movieInfo["audio_language"])
audio_language_list = get_audio_profile_languages(movieInfo.audio_language)
if len(audio_language_list) > 0:
audio_language = audio_language_list[0]['name']
else:

View File

@ -4,9 +4,8 @@ import operator
from flask_restx import Resource, Namespace, reqparse, fields
from functools import reduce
from peewee import fn, JOIN
from app.database import get_exclusion_clause, TableEpisodes, TableShows
from app.database import get_exclusion_clause, TableEpisodes, TableShows, database, select, update, func
from subtitles.indexer.series import list_missing_subtitles, series_scan_subtitles
from subtitles.mass_download import series_download_subtitles
from subtitles.wanted import wanted_search_missing_subtitles_series
@ -45,7 +44,6 @@ class Series(Resource):
'profileId': fields.Integer(),
'seriesType': fields.String(),
'sonarrSeriesId': fields.Integer(),
'sortTitle': fields.String(),
'tags': fields.List(fields.String),
'title': fields.String(),
'tvdbId': fields.Integer(),
@ -69,40 +67,77 @@ class Series(Resource):
length = args.get('length')
seriesId = args.get('seriesid[]')
count = TableShows.select().count()
episodeFileCount = TableEpisodes.select(TableShows.sonarrSeriesId,
fn.COUNT(TableEpisodes.sonarrSeriesId).coerce(False).alias('episodeFileCount')) \
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \
.group_by(TableShows.sonarrSeriesId).alias('episodeFileCount')
episodeFileCount = select(TableShows.sonarrSeriesId,
func.count(TableEpisodes.sonarrSeriesId).label('episodeFileCount')) \
.select_from(TableEpisodes) \
.join(TableShows) \
.group_by(TableShows.sonarrSeriesId)\
.subquery()
episodes_missing_conditions = [(TableEpisodes.missing_subtitles != '[]')]
episodes_missing_conditions += get_exclusion_clause('series')
episodeMissingCount = (TableEpisodes.select(TableShows.sonarrSeriesId,
fn.COUNT(TableEpisodes.sonarrSeriesId).coerce(False).alias('episodeMissingCount'))
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))
.where(reduce(operator.and_, episodes_missing_conditions)).group_by(
TableShows.sonarrSeriesId).alias('episodeMissingCount'))
episodeMissingCount = select(TableShows.sonarrSeriesId,
func.count(TableEpisodes.sonarrSeriesId).label('episodeMissingCount')) \
.select_from(TableEpisodes) \
.join(TableShows) \
.where(reduce(operator.and_, episodes_missing_conditions)) \
.group_by(TableShows.sonarrSeriesId)\
.subquery()
result = TableShows.select(TableShows, episodeFileCount.c.episodeFileCount,
episodeMissingCount.c.episodeMissingCount).join(episodeFileCount,
join_type=JOIN.LEFT_OUTER, on=(
TableShows.sonarrSeriesId ==
episodeFileCount.c.sonarrSeriesId)
) \
.join(episodeMissingCount, join_type=JOIN.LEFT_OUTER,
on=(TableShows.sonarrSeriesId == episodeMissingCount.c.sonarrSeriesId)).order_by(TableShows.sortTitle)
stmt = select(TableShows.tvdbId,
TableShows.alternativeTitles,
TableShows.audio_language,
TableShows.fanart,
TableShows.imdbId,
TableShows.monitored,
TableShows.overview,
TableShows.path,
TableShows.poster,
TableShows.profileId,
TableShows.seriesType,
TableShows.sonarrSeriesId,
TableShows.tags,
TableShows.title,
TableShows.year,
episodeFileCount.c.episodeFileCount,
episodeMissingCount.c.episodeMissingCount) \
.select_from(TableShows) \
.join(episodeFileCount, TableShows.sonarrSeriesId == episodeFileCount.c.sonarrSeriesId, isouter=True) \
.join(episodeMissingCount, TableShows.sonarrSeriesId == episodeMissingCount.c.sonarrSeriesId, isouter=True)\
.order_by(TableShows.sortTitle)
if len(seriesId) != 0:
result = result.where(TableShows.sonarrSeriesId.in_(seriesId))
stmt = stmt.where(TableShows.sonarrSeriesId.in_(seriesId))
elif length > 0:
result = result.limit(length).offset(start)
result = list(result.dicts())
stmt = stmt.limit(length).offset(start)
for item in result:
postprocess(item)
results = [postprocess({
'tvdbId': x.tvdbId,
'alternativeTitles': x.alternativeTitles,
'audio_language': x.audio_language,
'fanart': x.fanart,
'imdbId': x.imdbId,
'monitored': x.monitored,
'overview': x.overview,
'path': x.path,
'poster': x.poster,
'profileId': x.profileId,
'seriesType': x.seriesType,
'sonarrSeriesId': x.sonarrSeriesId,
'tags': x.tags,
'title': x.title,
'year': x.year,
'episodeFileCount': x.episodeFileCount,
'episodeMissingCount': x.episodeMissingCount,
}) for x in database.execute(stmt).all()]
return {'data': result, 'total': count}
count = database.execute(
select(func.count())
.select_from(TableShows)) \
.scalar()
return {'data': results, 'total': count}
post_request_parser = reqparse.RequestParser()
post_request_parser.add_argument('seriesid', type=int, action='append', required=False, default=[],
@ -133,23 +168,22 @@ class Series(Resource):
except Exception:
return 'Languages profile not found', 404
TableShows.update({
TableShows.profileId: profileId
}) \
.where(TableShows.sonarrSeriesId == seriesId) \
.execute()
database.execute(
update(TableShows)
.values(profileId=profileId)
.where(TableShows.sonarrSeriesId == seriesId))
list_missing_subtitles(no=seriesId, send_event=False)
event_stream(type='series', payload=seriesId)
episode_id_list = TableEpisodes \
.select(TableEpisodes.sonarrEpisodeId) \
.where(TableEpisodes.sonarrSeriesId == seriesId) \
.dicts()
episode_id_list = database.execute(
select(TableEpisodes.sonarrEpisodeId)
.where(TableEpisodes.sonarrSeriesId == seriesId))\
.all()
for item in episode_id_list:
event_stream(type='episode-wanted', payload=item['sonarrEpisodeId'])
event_stream(type='episode-wanted', payload=item.sonarrEpisodeId)
event_stream(type='badges')

View File

@ -6,7 +6,7 @@ import gc
from flask_restx import Resource, Namespace, reqparse
from app.database import TableEpisodes, TableMovies
from app.database import TableEpisodes, TableMovies, database, select
from languages.get_languages import alpha3_from_alpha2
from utilities.path_mappings import path_mappings
from subtitles.tools.subsyncer import SubSyncer
@ -53,28 +53,31 @@ class Subtitles(Resource):
id = args.get('id')
if media_type == 'episode':
metadata = TableEpisodes.select(TableEpisodes.path, TableEpisodes.sonarrSeriesId)\
.where(TableEpisodes.sonarrEpisodeId == id)\
.dicts()\
.get_or_none()
metadata = database.execute(
select(TableEpisodes.path, TableEpisodes.sonarrSeriesId)
.where(TableEpisodes.sonarrEpisodeId == id)) \
.first()
if not metadata:
return 'Episode not found', 404
video_path = path_mappings.path_replace(metadata['path'])
video_path = path_mappings.path_replace(metadata.path)
else:
metadata = TableMovies.select(TableMovies.path).where(TableMovies.radarrId == id).dicts().get_or_none()
metadata = database.execute(
select(TableMovies.path)
.where(TableMovies.radarrId == id))\
.first()
if not metadata:
return 'Movie not found', 404
video_path = path_mappings.path_replace_movie(metadata['path'])
video_path = path_mappings.path_replace_movie(metadata.path)
if action == 'sync':
subsync = SubSyncer()
if media_type == 'episode':
subsync.sync(video_path=video_path, srt_path=subtitles_path,
srt_lang=language, media_type='series', sonarr_series_id=metadata['sonarrSeriesId'],
srt_lang=language, media_type='series', sonarr_series_id=metadata.sonarrSeriesId,
sonarr_episode_id=int(id))
else:
subsync.sync(video_path=video_path, srt_path=subtitles_path,
@ -89,7 +92,7 @@ class Subtitles(Resource):
translate_subtitles_file(video_path=video_path, source_srt_file=subtitles_path,
from_lang=from_language, to_lang=dest_language, forced=forced, hi=hi,
media_type="series" if media_type == "episode" else "movies",
sonarr_series_id=metadata.get('sonarrSeriesId'),
sonarr_series_id=metadata.sonarrSeriesId,
sonarr_episode_id=int(id),
radarr_id=id)
else:
@ -105,7 +108,7 @@ class Subtitles(Resource):
if media_type == 'episode':
store_subtitles(path_mappings.path_replace_reverse(video_path), video_path)
event_stream(type='series', payload=metadata['sonarrSeriesId'])
event_stream(type='series', payload=metadata.sonarrSeriesId)
event_stream(type='episode', payload=int(id))
else:
store_subtitles_movie(path_mappings.path_replace_reverse_movie(video_path), video_path)

View File

@ -3,7 +3,7 @@
from flask_restx import Resource, Namespace, reqparse
from operator import itemgetter
from app.database import TableHistory, TableHistoryMovie, TableSettingsLanguages
from app.database import TableHistory, TableHistoryMovie, TableSettingsLanguages, database, select
from languages.get_languages import alpha2_from_alpha3, language_from_alpha2, alpha3_from_alpha2
from ..utils import authenticate, False_Keys
@ -25,13 +25,15 @@ class Languages(Resource):
args = self.get_request_parser.parse_args()
history = args.get('history')
if history and history not in False_Keys:
languages = list(TableHistory.select(TableHistory.language)
.where(TableHistory.language.is_null(False))
.dicts())
languages += list(TableHistoryMovie.select(TableHistoryMovie.language)
.where(TableHistoryMovie.language.is_null(False))
.dicts())
languages_list = list(set([lang['language'].split(':')[0] for lang in languages]))
languages = database.execute(
select(TableHistory.language)
.where(TableHistory.language.is_not(None)))\
.all()
languages += database.execute(
select(TableHistoryMovie.language)
.where(TableHistoryMovie.language.is_not(None)))\
.all()
languages_list = [lang.language.split(':')[0] for lang in languages]
languages_dicts = []
for language in languages_list:
code2 = None
@ -54,13 +56,17 @@ class Languages(Resource):
except Exception:
continue
else:
languages_dicts = TableSettingsLanguages.select(TableSettingsLanguages.name,
TableSettingsLanguages.code2,
TableSettingsLanguages.code3,
TableSettingsLanguages.enabled)\
.order_by(TableSettingsLanguages.name).dicts()
languages_dicts = list(languages_dicts)
for item in languages_dicts:
item['enabled'] = item['enabled'] == 1
languages_dicts = [{
'name': x.name,
'code2': x.code2,
'code3': x.code3,
'enabled': x.enabled == 1
} for x in database.execute(
select(TableSettingsLanguages.name,
TableSettingsLanguages.code2,
TableSettingsLanguages.code3,
TableSettingsLanguages.enabled)
.order_by(TableSettingsLanguages.name))
.all()]
return sorted(languages_dicts, key=itemgetter('name'))

View File

@ -1,9 +1,10 @@
# coding=utf-8
from flask_restx import Resource, Namespace, reqparse
from unidecode import unidecode
from app.config import settings
from app.database import TableShows, TableMovies
from app.database import TableShows, TableMovies, database, select
from ..utils import authenticate
@ -22,30 +23,42 @@ class Searches(Resource):
def get(self):
"""List results from query"""
args = self.get_request_parser.parse_args()
query = args.get('query')
query = unidecode(args.get('query')).lower()
search_list = []
if query:
if settings.general.getboolean('use_sonarr'):
# Get matching series
series = TableShows.select(TableShows.title,
TableShows.sonarrSeriesId,
TableShows.year)\
.where(TableShows.title.contains(query))\
.order_by(TableShows.title)\
.dicts()
series = list(series)
search_list += series
search_list += database.execute(
select(TableShows.title,
TableShows.sonarrSeriesId,
TableShows.year)
.order_by(TableShows.title)) \
.all()
if settings.general.getboolean('use_radarr'):
# Get matching movies
movies = TableMovies.select(TableMovies.title,
TableMovies.radarrId,
TableMovies.year) \
.where(TableMovies.title.contains(query)) \
.order_by(TableMovies.title) \
.dicts()
movies = list(movies)
search_list += movies
search_list += database.execute(
select(TableMovies.title,
TableMovies.radarrId,
TableMovies.year)
.order_by(TableMovies.title)) \
.all()
return search_list
results = []
for x in search_list:
if query in unidecode(x.title).lower():
result = {
'title': x.title,
'year': x.year,
}
if hasattr(x, 'sonarrSeriesId'):
result['sonarrSeriesId'] = x.sonarrSeriesId
else:
result['radarrId'] = x.radarrId
results.append(result)
return results

View File

@ -5,8 +5,8 @@ import json
from flask import request, jsonify
from flask_restx import Resource, Namespace
from app.database import TableLanguagesProfiles, TableSettingsLanguages, TableShows, TableMovies, \
TableSettingsNotifier, update_profile_id_list
from app.database import TableLanguagesProfiles, TableSettingsLanguages, TableSettingsNotifier, \
update_profile_id_list, database, insert, update, delete, select
from app.event_handler import event_stream
from app.config import settings, save_settings, get_settings
from app.scheduler import scheduler
@ -24,15 +24,17 @@ class SystemSettings(Resource):
@authenticate
def get(self):
data = get_settings()
notifications = TableSettingsNotifier.select().order_by(TableSettingsNotifier.name).dicts()
notifications = list(notifications)
for i, item in enumerate(notifications):
item["enabled"] = item["enabled"] == 1
notifications[i] = item
data['notifications'] = dict()
data['notifications']['providers'] = notifications
data['notifications']['providers'] = [{
'name': x.name,
'enabled': x.enabled == 1,
'url': x.url
} for x in database.execute(
select(TableSettingsNotifier.name,
TableSettingsNotifier.enabled,
TableSettingsNotifier.url)
.order_by(TableSettingsNotifier.name))
.all()]
return jsonify(data)
@ -40,57 +42,55 @@ class SystemSettings(Resource):
def post(self):
enabled_languages = request.form.getlist('languages-enabled')
if len(enabled_languages) != 0:
TableSettingsLanguages.update({
TableSettingsLanguages.enabled: 0
}).execute()
database.execute(
update(TableSettingsLanguages)
.values(enabled=0))
for code in enabled_languages:
TableSettingsLanguages.update({
TableSettingsLanguages.enabled: 1
})\
.where(TableSettingsLanguages.code2 == code)\
.execute()
database.execute(
update(TableSettingsLanguages)
.values(enabled=1)
.where(TableSettingsLanguages.code2 == code))
event_stream("languages")
languages_profiles = request.form.get('languages-profiles')
if languages_profiles:
existing_ids = TableLanguagesProfiles.select(TableLanguagesProfiles.profileId).dicts()
existing_ids = list(existing_ids)
existing = [x['profileId'] for x in existing_ids]
existing_ids = database.execute(
select(TableLanguagesProfiles.profileId))\
.all()
existing = [x.profileId for x in existing_ids]
for item in json.loads(languages_profiles):
if item['profileId'] in existing:
# Update existing profiles
TableLanguagesProfiles.update({
TableLanguagesProfiles.name: item['name'],
TableLanguagesProfiles.cutoff: item['cutoff'] if item['cutoff'] != 'null' else None,
TableLanguagesProfiles.items: json.dumps(item['items']),
TableLanguagesProfiles.mustContain: item['mustContain'],
TableLanguagesProfiles.mustNotContain: item['mustNotContain'],
TableLanguagesProfiles.originalFormat: item['originalFormat'] if item['originalFormat'] != 'null' else None,
})\
.where(TableLanguagesProfiles.profileId == item['profileId'])\
.execute()
database.execute(
update(TableLanguagesProfiles)
.values(
name=item['name'],
cutoff=item['cutoff'] if item['cutoff'] != 'null' else None,
items=json.dumps(item['items']),
mustContain=str(item['mustContain']),
mustNotContain=str(item['mustNotContain']),
originalFormat=item['originalFormat'] if item['originalFormat'] != 'null' else None,
)
.where(TableLanguagesProfiles.profileId == item['profileId']))
existing.remove(item['profileId'])
else:
# Add new profiles
TableLanguagesProfiles.insert({
TableLanguagesProfiles.profileId: item['profileId'],
TableLanguagesProfiles.name: item['name'],
TableLanguagesProfiles.cutoff: item['cutoff'] if item['cutoff'] != 'null' else None,
TableLanguagesProfiles.items: json.dumps(item['items']),
TableLanguagesProfiles.mustContain: item['mustContain'],
TableLanguagesProfiles.mustNotContain: item['mustNotContain'],
TableLanguagesProfiles.originalFormat: item['originalFormat'] if item['originalFormat'] != 'null' else None,
}).execute()
database.execute(
insert(TableLanguagesProfiles)
.values(
profileId=item['profileId'],
name=item['name'],
cutoff=item['cutoff'] if item['cutoff'] != 'null' else None,
items=json.dumps(item['items']),
mustContain=str(item['mustContain']),
mustNotContain=str(item['mustNotContain']),
originalFormat=item['originalFormat'] if item['originalFormat'] != 'null' else None,
))
for profileId in existing:
# Unassign this profileId from series and movies
TableShows.update({
TableShows.profileId: None
}).where(TableShows.profileId == profileId).execute()
TableMovies.update({
TableMovies.profileId: None
}).where(TableMovies.profileId == profileId).execute()
# Remove deleted profiles
TableLanguagesProfiles.delete().where(TableLanguagesProfiles.profileId == profileId).execute()
database.execute(
delete(TableLanguagesProfiles)
.where(TableLanguagesProfiles.profileId == profileId))
# invalidate cache
update_profile_id_list.invalidate()
@ -106,10 +106,11 @@ class SystemSettings(Resource):
notifications = request.form.getlist('notifications-providers')
for item in notifications:
item = json.loads(item)
TableSettingsNotifier.update({
TableSettingsNotifier.enabled: item['enabled'],
TableSettingsNotifier.url: item['url']
}).where(TableSettingsNotifier.name == item['name']).execute()
database.execute(
update(TableSettingsNotifier).values(
enabled=item['enabled'],
url=item['url'])
.where(TableSettingsNotifier.name == item['name']))
save_settings(zip(request.form.keys(), request.form.listvalues()))
event_stream("settings")

View File

@ -36,7 +36,7 @@ def authenticate(actual_method):
def postprocess(item):
# Remove ffprobe_cache
if item.get('movie_file_id'):
if item.get('radarrId'):
path_replace = path_mappings.path_replace_movie
else:
path_replace = path_mappings.path_replace
@ -57,12 +57,6 @@ def postprocess(item):
else:
item['alternativeTitles'] = []
# Parse failed attempts
if item.get('failedAttempts'):
item['failedAttempts'] = ast.literal_eval(item['failedAttempts'])
else:
item['failedAttempts'] = []
# Parse subtitles
if item.get('subtitles'):
item['subtitles'] = ast.literal_eval(item['subtitles'])
@ -135,10 +129,6 @@ def postprocess(item):
"hi": bool(item['language'].endswith(':hi')),
}
# Parse seriesType
if item.get('seriesType'):
item['seriesType'] = item['seriesType'].capitalize()
if item.get('path'):
item['path'] = path_replace(item['path'])
@ -149,8 +139,10 @@ def postprocess(item):
# map poster and fanart to server proxy
if item.get('poster') is not None:
poster = item['poster']
item['poster'] = f"{base_url}/images/{'movies' if item.get('movie_file_id') else 'series'}{poster}" if poster else None
item['poster'] = f"{base_url}/images/{'movies' if item.get('radarrId') else 'series'}{poster}" if poster else None
if item.get('fanart') is not None:
fanart = item['fanart']
item['fanart'] = f"{base_url}/images/{'movies' if item.get('movie_file_id') else 'series'}{fanart}" if fanart else None
item['fanart'] = f"{base_url}/images/{'movies' if item.get('radarrId') else 'series'}{fanart}" if fanart else None
return item

View File

@ -8,7 +8,7 @@ import logging
from flask_restx import Resource, Namespace, reqparse
from bs4 import BeautifulSoup as bso
from app.database import TableEpisodes, TableShows, TableMovies
from app.database import TableEpisodes, TableShows, TableMovies, database, select
from subtitles.mass_download import episode_download_subtitles, movies_download_subtitles
from ..utils import authenticate
@ -73,16 +73,17 @@ class WebHooksPlex(Resource):
logging.debug('BAZARR is unable to get series IMDB id.')
return 'IMDB series ID not found', 404
else:
sonarrEpisodeId = TableEpisodes.select(TableEpisodes.sonarrEpisodeId) \
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \
sonarrEpisodeId = database.execute(
select(TableEpisodes.sonarrEpisodeId)
.select_from(TableEpisodes)
.join(TableShows)
.where(TableShows.imdbId == series_imdb_id,
TableEpisodes.season == season,
TableEpisodes.episode == episode) \
.dicts() \
.get_or_none()
TableEpisodes.episode == episode)) \
.first()
if sonarrEpisodeId:
episode_download_subtitles(no=sonarrEpisodeId['sonarrEpisodeId'], send_progress=True)
episode_download_subtitles(no=sonarrEpisodeId.sonarrEpisodeId, send_progress=True)
else:
try:
movie_imdb_id = [x['imdb'] for x in ids if 'imdb' in x][0]
@ -90,12 +91,12 @@ class WebHooksPlex(Resource):
logging.debug('BAZARR is unable to get movie IMDB id.')
return 'IMDB movie ID not found', 404
else:
radarrId = TableMovies.select(TableMovies.radarrId)\
.where(TableMovies.imdbId == movie_imdb_id)\
.dicts()\
.get_or_none()
radarrId = database.execute(
select(TableMovies.radarrId)
.where(TableMovies.imdbId == movie_imdb_id)) \
.first()
if radarrId:
movies_download_subtitles(no=radarrId['radarrId'])
movies_download_subtitles(no=radarrId.radarrId)
return '', 200

View File

@ -2,7 +2,7 @@
from flask_restx import Resource, Namespace, reqparse
from app.database import TableMovies
from app.database import TableMovies, database, select
from subtitles.mass_download import movies_download_subtitles
from subtitles.indexer.movies import store_subtitles_movie
from utilities.path_mappings import path_mappings
@ -28,14 +28,13 @@ class WebHooksRadarr(Resource):
args = self.post_request_parser.parse_args()
movie_file_id = args.get('radarr_moviefile_id')
radarrMovieId = TableMovies.select(TableMovies.radarrId,
TableMovies.path) \
.where(TableMovies.movie_file_id == movie_file_id) \
.dicts() \
.get_or_none()
radarrMovieId = database.execute(
select(TableMovies.radarrId, TableMovies.path)
.where(TableMovies.movie_file_id == movie_file_id)) \
.first()
if radarrMovieId:
store_subtitles_movie(radarrMovieId['path'], path_mappings.path_replace_movie(radarrMovieId['path']))
movies_download_subtitles(no=radarrMovieId['radarrId'])
store_subtitles_movie(radarrMovieId.path, path_mappings.path_replace_movie(radarrMovieId.path))
movies_download_subtitles(no=radarrMovieId.radarrId)
return '', 200

View File

@ -2,7 +2,7 @@
from flask_restx import Resource, Namespace, reqparse
from app.database import TableEpisodes, TableShows
from app.database import TableEpisodes, TableShows, database, select
from subtitles.mass_download import episode_download_subtitles
from subtitles.indexer.series import store_subtitles
from utilities.path_mappings import path_mappings
@ -28,15 +28,15 @@ class WebHooksSonarr(Resource):
args = self.post_request_parser.parse_args()
episode_file_id = args.get('sonarr_episodefile_id')
sonarrEpisodeId = TableEpisodes.select(TableEpisodes.sonarrEpisodeId,
TableEpisodes.path) \
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \
.where(TableEpisodes.episode_file_id == episode_file_id) \
.dicts() \
.get_or_none()
sonarrEpisodeId = database.execute(
select(TableEpisodes.sonarrEpisodeId, TableEpisodes.path)
.select_from(TableEpisodes)
.join(TableShows)
.where(TableEpisodes.episode_file_id == episode_file_id)) \
.first()
if sonarrEpisodeId:
store_subtitles(sonarrEpisodeId['path'], path_mappings.path_replace(sonarrEpisodeId['path']))
episode_download_subtitles(no=sonarrEpisodeId['sonarrEpisodeId'], send_progress=True)
store_subtitles(sonarrEpisodeId.path, path_mappings.path_replace(sonarrEpisodeId.path))
episode_download_subtitles(no=sonarrEpisodeId.sonarrEpisodeId, send_progress=True)
return '', 200

View File

@ -11,7 +11,7 @@ from datetime import datetime
from operator import itemgetter
from app.get_providers import get_enabled_providers
from app.database import TableAnnouncements
from app.database import TableAnnouncements, database, insert, select
from .get_args import args
from sonarr.info import get_sonarr_info
from radarr.info import get_radarr_info
@ -116,9 +116,12 @@ def get_local_announcements():
def get_all_announcements():
# get announcements that haven't been dismissed yet
announcements = [parse_announcement_dict(x) for x in get_online_announcements() + get_local_announcements() if
x['enabled'] and (not x['dismissible'] or not TableAnnouncements.select()
.where(TableAnnouncements.hash ==
hashlib.sha256(x['text'].encode('UTF8')).hexdigest()).get_or_none())]
x['enabled'] and (not x['dismissible'] or not
database.execute(
select(TableAnnouncements)
.where(TableAnnouncements.hash ==
hashlib.sha256(x['text'].encode('UTF8')).hexdigest()))
.first())]
return sorted(announcements, key=itemgetter('timestamp'), reverse=True)
@ -126,8 +129,9 @@ def get_all_announcements():
def mark_announcement_as_dismissed(hashed_announcement):
text = [x['text'] for x in get_all_announcements() if x['hash'] == hashed_announcement]
if text:
TableAnnouncements.insert({TableAnnouncements.hash: hashed_announcement,
TableAnnouncements.timestamp: datetime.now(),
TableAnnouncements.text: text[0]})\
.on_conflict_ignore(ignore=True)\
.execute()
database.execute(
insert(TableAnnouncements)
.values(hash=hashed_announcement,
timestamp=datetime.now(),
text=text[0])
.on_conflict_do_nothing())

View File

@ -45,14 +45,13 @@ def create_app():
# generated by the request.
@app.before_request
def _db_connect():
database.connect()
database.begin()
# This hook ensures that the connection is closed when we've finished
# processing the request.
@app.teardown_request
def _db_close(exc):
if not database.is_closed():
database.close()
database.close()
return app

View File

@ -594,10 +594,8 @@ def save_settings(settings_items):
if audio_tracks_parsing_changed:
from .scheduler import scheduler
if settings.general.getboolean('use_sonarr'):
from sonarr.sync.episodes import sync_episodes
from sonarr.sync.series import update_series
scheduler.add_job(update_series, kwargs={'send_event': True}, max_instances=1)
scheduler.add_job(sync_episodes, kwargs={'send_event': True}, max_instances=1)
if settings.general.getboolean('use_radarr'):
from radarr.sync.movies import update_movies
scheduler.add_job(update_movies, kwargs={'send_event': True}, max_instances=1)

View File

@ -4,19 +4,20 @@ import atexit
import json
import logging
import os
import time
from datetime import datetime
import flask_migrate
from dogpile.cache import make_region
from peewee import Model, AutoField, TextField, IntegerField, ForeignKeyField, BlobField, BooleanField, BigIntegerField, \
DateTimeField, OperationalError, PostgresqlDatabase
from playhouse.migrate import PostgresqlMigrator
from playhouse.migrate import SqliteMigrator, migrate
from playhouse.shortcuts import ThreadSafeDatabaseMetadata, ReconnectMixin
from playhouse.sqlite_ext import RowIDField
from playhouse.sqliteq import SqliteQueueDatabase
from datetime import datetime
from sqlalchemy import create_engine, inspect, DateTime, ForeignKey, Integer, LargeBinary, Text, func, text
# importing here to be indirectly imported in other modules later
from sqlalchemy import update, delete, select, func # noqa W0611
from sqlalchemy.orm import scoped_session, sessionmaker, relationship, mapped_column
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.pool import NullPool
from flask_sqlalchemy import SQLAlchemy
from utilities.path_mappings import path_mappings
from .config import settings, get_array_from
from .get_args import args
@ -27,10 +28,9 @@ postgresql = (os.getenv("POSTGRES_ENABLED", settings.postgresql.enabled).lower()
region = make_region().configure('dogpile.cache.memory')
if postgresql:
class ReconnectPostgresqlDatabase(ReconnectMixin, PostgresqlDatabase):
reconnect_errors = (
(OperationalError, 'server closed the connection unexpectedly'),
)
# insert is different between database types
from sqlalchemy.dialects.postgresql import insert # noqa E402
from sqlalchemy.engine import URL # noqa E402
postgres_database = os.getenv("POSTGRES_DATABASE", settings.postgresql.database)
postgres_username = os.getenv("POSTGRES_USERNAME", settings.postgresql.username)
@ -38,520 +38,307 @@ if postgresql:
postgres_host = os.getenv("POSTGRES_HOST", settings.postgresql.host)
postgres_port = os.getenv("POSTGRES_PORT", settings.postgresql.port)
logger.debug(
f"Connecting to PostgreSQL database: {postgres_host}:{postgres_port}/{postgres_database}")
database = ReconnectPostgresqlDatabase(postgres_database,
user=postgres_username,
password=postgres_password,
host=postgres_host,
port=postgres_port,
autocommit=True,
autorollback=True,
autoconnect=True,
)
migrator = PostgresqlMigrator(database)
logger.debug(f"Connecting to PostgreSQL database: {postgres_host}:{postgres_port}/{postgres_database}")
url = URL.create(
drivername="postgresql",
username=postgres_username,
password=postgres_password,
host=postgres_host,
port=postgres_port,
database=postgres_database
)
engine = create_engine(url, poolclass=NullPool, isolation_level="AUTOCOMMIT")
else:
db_path = os.path.join(args.config_dir, 'db', 'bazarr.db')
logger.debug(f"Connecting to SQLite database: {db_path}")
database = SqliteQueueDatabase(db_path,
use_gevent=False,
autostart=True,
queue_max_size=256)
migrator = SqliteMigrator(database)
# insert is different between database types
from sqlalchemy.dialects.sqlite import insert # noqa E402
url = f'sqlite:///{os.path.join(args.config_dir, "db", "bazarr.db")}'
logger.debug(f"Connecting to SQLite database: {url}")
engine = create_engine(url, poolclass=NullPool, isolation_level="AUTOCOMMIT")
from sqlalchemy.engine import Engine
from sqlalchemy import event
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
session_factory = sessionmaker(bind=engine)
database = scoped_session(session_factory)
@atexit.register
def _stop_worker_threads():
if not postgresql:
database.stop()
database.remove()
class UnknownField(object):
def __init__(self, *_, **__): pass
Base = declarative_base()
metadata = Base.metadata
class BaseModel(Model):
class Meta:
database = database
model_metadata_class = ThreadSafeDatabaseMetadata
class System(Base):
__tablename__ = 'system'
id = mapped_column(Integer, primary_key=True)
configured = mapped_column(Text)
updated = mapped_column(Text)
class System(BaseModel):
configured = TextField(null=True)
updated = TextField(null=True)
class TableAnnouncements(Base):
__tablename__ = 'table_announcements'
class Meta:
table_name = 'system'
primary_key = False
id = mapped_column(Integer, primary_key=True)
timestamp = mapped_column(DateTime, nullable=False, default=datetime.now)
hash = mapped_column(Text)
text = mapped_column(Text)
class TableBlacklist(BaseModel):
language = TextField(null=True)
provider = TextField(null=True)
sonarr_episode_id = IntegerField(null=True)
sonarr_series_id = IntegerField(null=True)
subs_id = TextField(null=True)
timestamp = DateTimeField(null=True)
class TableBlacklist(Base):
__tablename__ = 'table_blacklist'
class Meta:
table_name = 'table_blacklist'
primary_key = False
id = mapped_column(Integer, primary_key=True)
language = mapped_column(Text)
provider = mapped_column(Text)
sonarr_episode_id = mapped_column(Integer, ForeignKey('table_episodes.sonarrEpisodeId', ondelete='CASCADE'))
sonarr_series_id = mapped_column(Integer, ForeignKey('table_shows.sonarrSeriesId', ondelete='CASCADE'))
subs_id = mapped_column(Text)
timestamp = mapped_column(DateTime, default=datetime.now)
class TableBlacklistMovie(BaseModel):
language = TextField(null=True)
provider = TextField(null=True)
radarr_id = IntegerField(null=True)
subs_id = TextField(null=True)
timestamp = DateTimeField(null=True)
class TableBlacklistMovie(Base):
__tablename__ = 'table_blacklist_movie'
class Meta:
table_name = 'table_blacklist_movie'
primary_key = False
id = mapped_column(Integer, primary_key=True)
language = mapped_column(Text)
provider = mapped_column(Text)
radarr_id = mapped_column(Integer, ForeignKey('table_movies.radarrId', ondelete='CASCADE'))
subs_id = mapped_column(Text)
timestamp = mapped_column(DateTime, default=datetime.now)
class TableEpisodes(BaseModel):
rowid = RowIDField()
audio_codec = TextField(null=True)
audio_language = TextField(null=True)
episode = IntegerField()
episode_file_id = IntegerField(null=True)
failedAttempts = TextField(null=True)
ffprobe_cache = BlobField(null=True)
file_size = BigIntegerField(default=0, null=True)
format = TextField(null=True)
missing_subtitles = TextField(null=True)
monitored = TextField(null=True)
path = TextField()
resolution = TextField(null=True)
sceneName = TextField(null=True)
season = IntegerField()
sonarrEpisodeId = IntegerField(unique=True)
sonarrSeriesId = IntegerField()
subtitles = TextField(null=True)
title = TextField()
video_codec = TextField(null=True)
class TableCustomScoreProfiles(Base):
__tablename__ = 'table_custom_score_profiles'
class Meta:
table_name = 'table_episodes'
primary_key = False
id = mapped_column(Integer, primary_key=True)
name = mapped_column(Text)
media = mapped_column(Text)
score = mapped_column(Integer)
class TableHistory(BaseModel):
action = IntegerField()
description = TextField()
id = AutoField()
language = TextField(null=True)
provider = TextField(null=True)
score = IntegerField(null=True)
sonarrEpisodeId = IntegerField()
sonarrSeriesId = IntegerField()
subs_id = TextField(null=True)
subtitles_path = TextField(null=True)
timestamp = DateTimeField()
video_path = TextField(null=True)
class TableEpisodes(Base):
__tablename__ = 'table_episodes'
class Meta:
table_name = 'table_history'
audio_codec = mapped_column(Text)
audio_language = mapped_column(Text)
episode = mapped_column(Integer, nullable=False)
episode_file_id = mapped_column(Integer)
failedAttempts = mapped_column(Text)
ffprobe_cache = mapped_column(LargeBinary)
file_size = mapped_column(Integer)
format = mapped_column(Text)
missing_subtitles = mapped_column(Text)
monitored = mapped_column(Text)
path = mapped_column(Text, nullable=False)
resolution = mapped_column(Text)
sceneName = mapped_column(Text)
season = mapped_column(Integer, nullable=False)
sonarrEpisodeId = mapped_column(Integer, primary_key=True)
sonarrSeriesId = mapped_column(Integer, ForeignKey('table_shows.sonarrSeriesId', ondelete='CASCADE'))
subtitles = mapped_column(Text)
title = mapped_column(Text, nullable=False)
video_codec = mapped_column(Text)
class TableHistoryMovie(BaseModel):
action = IntegerField()
description = TextField()
id = AutoField()
language = TextField(null=True)
provider = TextField(null=True)
radarrId = IntegerField()
score = IntegerField(null=True)
subs_id = TextField(null=True)
subtitles_path = TextField(null=True)
timestamp = DateTimeField()
video_path = TextField(null=True)
class TableHistory(Base):
__tablename__ = 'table_history'
class Meta:
table_name = 'table_history_movie'
id = mapped_column(Integer, primary_key=True)
action = mapped_column(Integer, nullable=False)
description = mapped_column(Text, nullable=False)
language = mapped_column(Text)
provider = mapped_column(Text)
score = mapped_column(Integer)
sonarrEpisodeId = mapped_column(Integer, ForeignKey('table_episodes.sonarrEpisodeId', ondelete='CASCADE'))
sonarrSeriesId = mapped_column(Integer, ForeignKey('table_shows.sonarrSeriesId', ondelete='CASCADE'))
subs_id = mapped_column(Text)
subtitles_path = mapped_column(Text)
timestamp = mapped_column(DateTime, nullable=False, default=datetime.now)
video_path = mapped_column(Text)
matched = mapped_column(Text)
not_matched = mapped_column(Text)
class TableLanguagesProfiles(BaseModel):
cutoff = IntegerField(null=True)
originalFormat = BooleanField(null=True)
items = TextField()
name = TextField()
profileId = AutoField()
mustContain = TextField(null=True)
mustNotContain = TextField(null=True)
class TableHistoryMovie(Base):
__tablename__ = 'table_history_movie'
class Meta:
table_name = 'table_languages_profiles'
id = mapped_column(Integer, primary_key=True)
action = mapped_column(Integer, nullable=False)
description = mapped_column(Text, nullable=False)
language = mapped_column(Text)
provider = mapped_column(Text)
radarrId = mapped_column(Integer, ForeignKey('table_movies.radarrId', ondelete='CASCADE'))
score = mapped_column(Integer)
subs_id = mapped_column(Text)
subtitles_path = mapped_column(Text)
timestamp = mapped_column(DateTime, nullable=False, default=datetime.now)
video_path = mapped_column(Text)
matched = mapped_column(Text)
not_matched = mapped_column(Text)
class TableMovies(BaseModel):
rowid = RowIDField()
alternativeTitles = TextField(null=True)
audio_codec = TextField(null=True)
audio_language = TextField(null=True)
failedAttempts = TextField(null=True)
fanart = TextField(null=True)
ffprobe_cache = BlobField(null=True)
file_size = BigIntegerField(default=0, null=True)
format = TextField(null=True)
imdbId = TextField(null=True)
missing_subtitles = TextField(null=True)
monitored = TextField(null=True)
movie_file_id = IntegerField(null=True)
overview = TextField(null=True)
path = TextField(unique=True)
poster = TextField(null=True)
profileId = IntegerField(null=True)
radarrId = IntegerField(unique=True)
resolution = TextField(null=True)
sceneName = TextField(null=True)
sortTitle = TextField(null=True)
subtitles = TextField(null=True)
tags = TextField(null=True)
title = TextField()
tmdbId = TextField(unique=True)
video_codec = TextField(null=True)
year = TextField(null=True)
class TableLanguagesProfiles(Base):
__tablename__ = 'table_languages_profiles'
class Meta:
table_name = 'table_movies'
profileId = mapped_column(Integer, primary_key=True)
cutoff = mapped_column(Integer)
originalFormat = mapped_column(Integer)
items = mapped_column(Text, nullable=False)
name = mapped_column(Text, nullable=False)
mustContain = mapped_column(Text)
mustNotContain = mapped_column(Text)
class TableMoviesRootfolder(BaseModel):
accessible = IntegerField(null=True)
error = TextField(null=True)
id = IntegerField(null=True)
path = TextField(null=True)
class TableMovies(Base):
__tablename__ = 'table_movies'
class Meta:
table_name = 'table_movies_rootfolder'
primary_key = False
alternativeTitles = mapped_column(Text)
audio_codec = mapped_column(Text)
audio_language = mapped_column(Text)
failedAttempts = mapped_column(Text)
fanart = mapped_column(Text)
ffprobe_cache = mapped_column(LargeBinary)
file_size = mapped_column(Integer)
format = mapped_column(Text)
imdbId = mapped_column(Text)
missing_subtitles = mapped_column(Text)
monitored = mapped_column(Text)
movie_file_id = mapped_column(Integer)
overview = mapped_column(Text)
path = mapped_column(Text, nullable=False, unique=True)
poster = mapped_column(Text)
profileId = mapped_column(Integer, ForeignKey('table_languages_profiles.profileId', ondelete='SET NULL'))
radarrId = mapped_column(Integer, primary_key=True)
resolution = mapped_column(Text)
sceneName = mapped_column(Text)
sortTitle = mapped_column(Text)
subtitles = mapped_column(Text)
tags = mapped_column(Text)
title = mapped_column(Text, nullable=False)
tmdbId = mapped_column(Text, nullable=False, unique=True)
video_codec = mapped_column(Text)
year = mapped_column(Text)
class TableSettingsLanguages(BaseModel):
code2 = TextField(null=True)
code3 = TextField(primary_key=True)
code3b = TextField(null=True)
enabled = IntegerField(null=True)
name = TextField()
class TableMoviesRootfolder(Base):
__tablename__ = 'table_movies_rootfolder'
class Meta:
table_name = 'table_settings_languages'
accessible = mapped_column(Integer)
error = mapped_column(Text)
id = mapped_column(Integer, primary_key=True)
path = mapped_column(Text)
class TableSettingsNotifier(BaseModel):
enabled = IntegerField(null=True)
name = TextField(null=True, primary_key=True)
url = TextField(null=True)
class TableSettingsLanguages(Base):
__tablename__ = 'table_settings_languages'
class Meta:
table_name = 'table_settings_notifier'
code3 = mapped_column(Text, primary_key=True)
code2 = mapped_column(Text)
code3b = mapped_column(Text)
enabled = mapped_column(Integer)
name = mapped_column(Text, nullable=False)
class TableShows(BaseModel):
alternativeTitles = TextField(null=True)
audio_language = TextField(null=True)
fanart = TextField(null=True)
imdbId = TextField(default='""', null=True)
monitored = TextField(null=True)
overview = TextField(null=True)
path = TextField(unique=True)
poster = TextField(null=True)
profileId = IntegerField(null=True)
seriesType = TextField(null=True)
sonarrSeriesId = IntegerField(unique=True)
sortTitle = TextField(null=True)
tags = TextField(null=True)
title = TextField()
tvdbId = AutoField()
year = TextField(null=True)
class TableSettingsNotifier(Base):
__tablename__ = 'table_settings_notifier'
class Meta:
table_name = 'table_shows'
name = mapped_column(Text, primary_key=True)
enabled = mapped_column(Integer)
url = mapped_column(Text)
class TableShowsRootfolder(BaseModel):
accessible = IntegerField(null=True)
error = TextField(null=True)
id = IntegerField(null=True)
path = TextField(null=True)
class TableShows(Base):
__tablename__ = 'table_shows'
class Meta:
table_name = 'table_shows_rootfolder'
primary_key = False
tvdbId = mapped_column(Integer)
alternativeTitles = mapped_column(Text)
audio_language = mapped_column(Text)
fanart = mapped_column(Text)
imdbId = mapped_column(Text)
monitored = mapped_column(Text)
overview = mapped_column(Text)
path = mapped_column(Text, nullable=False, unique=True)
poster = mapped_column(Text)
profileId = mapped_column(Integer, ForeignKey('table_languages_profiles.profileId', ondelete='SET NULL'))
seriesType = mapped_column(Text)
sonarrSeriesId = mapped_column(Integer, primary_key=True)
sortTitle = mapped_column(Text)
tags = mapped_column(Text)
title = mapped_column(Text, nullable=False)
year = mapped_column(Text)
class TableCustomScoreProfiles(BaseModel):
id = AutoField()
name = TextField(null=True)
media = TextField(null=True)
score = IntegerField(null=True)
class TableShowsRootfolder(Base):
__tablename__ = 'table_shows_rootfolder'
class Meta:
table_name = 'table_custom_score_profiles'
accessible = mapped_column(Integer)
error = mapped_column(Text)
id = mapped_column(Integer, primary_key=True)
path = mapped_column(Text)
class TableCustomScoreProfileConditions(BaseModel):
profile_id = ForeignKeyField(TableCustomScoreProfiles, to_field="id")
type = TextField(null=True) # provider, uploader, regex, etc
value = TextField(null=True) # opensubtitles, jane_doe, [a-z], etc
required = BooleanField(default=False)
negate = BooleanField(default=False)
class TableCustomScoreProfileConditions(Base):
__tablename__ = 'table_custom_score_profile_conditions'
class Meta:
table_name = 'table_custom_score_profile_conditions'
id = mapped_column(Integer, primary_key=True)
profile_id = mapped_column(Integer, ForeignKey('table_custom_score_profiles.id'), nullable=False)
type = mapped_column(Text)
value = mapped_column(Text)
required = mapped_column(Integer, nullable=False)
negate = mapped_column(Integer, nullable=False)
class TableAnnouncements(BaseModel):
timestamp = DateTimeField()
hash = TextField(null=True, unique=True)
text = TextField(null=True)
class Meta:
table_name = 'table_announcements'
profile = relationship('TableCustomScoreProfiles')
def init_db():
# Create tables if they don't exists.
database.create_tables([System,
TableBlacklist,
TableBlacklistMovie,
TableEpisodes,
TableHistory,
TableHistoryMovie,
TableLanguagesProfiles,
TableMovies,
TableMoviesRootfolder,
TableSettingsLanguages,
TableSettingsNotifier,
TableShows,
TableShowsRootfolder,
TableCustomScoreProfiles,
TableCustomScoreProfileConditions,
TableAnnouncements])
database.begin()
# Create tables if they don't exist.
metadata.create_all(engine)
def create_db_revision(app):
logging.info("Creating a new database revision for future migration")
app.config["SQLALCHEMY_DATABASE_URI"] = url
db = SQLAlchemy(app, metadata=metadata)
with app.app_context():
flask_migrate.Migrate(app, db, render_as_batch=True)
flask_migrate.migrate()
db.engine.dispose()
def migrate_db(app):
logging.debug("Upgrading database schema")
app.config["SQLALCHEMY_DATABASE_URI"] = url
db = SQLAlchemy(app, metadata=metadata)
insp = inspect(engine)
alembic_temp_tables_list = [x for x in insp.get_table_names() if x.startswith('_alembic_tmp_')]
for table in alembic_temp_tables_list:
database.execute(text(f"DROP TABLE IF EXISTS {table}"))
with app.app_context():
flask_migrate.Migrate(app, db, render_as_batch=True)
flask_migrate.upgrade()
db.engine.dispose()
# add the system table single row if it's not existing
# we must retry until the tables are created
tables_created = False
while not tables_created:
try:
if not System.select().count():
System.insert({System.configured: '0', System.updated: '0'}).execute()
except Exception:
time.sleep(0.1)
else:
tables_created = True
def migrate_db():
table_shows = [t.name for t in database.get_columns('table_shows')]
table_episodes = [t.name for t in database.get_columns('table_episodes')]
table_movies = [t.name for t in database.get_columns('table_movies')]
table_history = [t.name for t in database.get_columns('table_history')]
table_history_movie = [t.name for t in database.get_columns('table_history_movie')]
table_languages_profiles = [t.name for t in database.get_columns('table_languages_profiles')]
if "year" not in table_shows:
migrate(migrator.add_column('table_shows', 'year', TextField(null=True)))
if "alternativeTitle" not in table_shows:
migrate(migrator.add_column('table_shows', 'alternativeTitle', TextField(null=True)))
if "tags" not in table_shows:
migrate(migrator.add_column('table_shows', 'tags', TextField(default='[]', null=True)))
if "seriesType" not in table_shows:
migrate(migrator.add_column('table_shows', 'seriesType', TextField(default='""', null=True)))
if "imdbId" not in table_shows:
migrate(migrator.add_column('table_shows', 'imdbId', TextField(default='""', null=True)))
if "profileId" not in table_shows:
migrate(migrator.add_column('table_shows', 'profileId', IntegerField(null=True)))
if "profileId" not in table_shows:
migrate(migrator.add_column('table_shows', 'profileId', IntegerField(null=True)))
if "monitored" not in table_shows:
migrate(migrator.add_column('table_shows', 'monitored', TextField(null=True)))
if "format" not in table_episodes:
migrate(migrator.add_column('table_episodes', 'format', TextField(null=True)))
if "resolution" not in table_episodes:
migrate(migrator.add_column('table_episodes', 'resolution', TextField(null=True)))
if "video_codec" not in table_episodes:
migrate(migrator.add_column('table_episodes', 'video_codec', TextField(null=True)))
if "audio_codec" not in table_episodes:
migrate(migrator.add_column('table_episodes', 'audio_codec', TextField(null=True)))
if "episode_file_id" not in table_episodes:
migrate(migrator.add_column('table_episodes', 'episode_file_id', IntegerField(null=True)))
if "audio_language" not in table_episodes:
migrate(migrator.add_column('table_episodes', 'audio_language', TextField(null=True)))
if "file_size" not in table_episodes:
migrate(migrator.add_column('table_episodes', 'file_size', BigIntegerField(default=0, null=True)))
if "ffprobe_cache" not in table_episodes:
migrate(migrator.add_column('table_episodes', 'ffprobe_cache', BlobField(null=True)))
if "sortTitle" not in table_movies:
migrate(migrator.add_column('table_movies', 'sortTitle', TextField(null=True)))
if "year" not in table_movies:
migrate(migrator.add_column('table_movies', 'year', TextField(null=True)))
if "alternativeTitles" not in table_movies:
migrate(migrator.add_column('table_movies', 'alternativeTitles', TextField(null=True)))
if "format" not in table_movies:
migrate(migrator.add_column('table_movies', 'format', TextField(null=True)))
if "resolution" not in table_movies:
migrate(migrator.add_column('table_movies', 'resolution', TextField(null=True)))
if "video_codec" not in table_movies:
migrate(migrator.add_column('table_movies', 'video_codec', TextField(null=True)))
if "audio_codec" not in table_movies:
migrate(migrator.add_column('table_movies', 'audio_codec', TextField(null=True)))
if "imdbId" not in table_movies:
migrate(migrator.add_column('table_movies', 'imdbId', TextField(null=True)))
if "movie_file_id" not in table_movies:
migrate(migrator.add_column('table_movies', 'movie_file_id', IntegerField(null=True)))
if "tags" not in table_movies:
migrate(migrator.add_column('table_movies', 'tags', TextField(default='[]', null=True)))
if "profileId" not in table_movies:
migrate(migrator.add_column('table_movies', 'profileId', IntegerField(null=True)))
if "file_size" not in table_movies:
migrate(migrator.add_column('table_movies', 'file_size', BigIntegerField(default=0, null=True)))
if "ffprobe_cache" not in table_movies:
migrate(migrator.add_column('table_movies', 'ffprobe_cache', BlobField(null=True)))
if "video_path" not in table_history:
migrate(migrator.add_column('table_history', 'video_path', TextField(null=True)))
if "language" not in table_history:
migrate(migrator.add_column('table_history', 'language', TextField(null=True)))
if "provider" not in table_history:
migrate(migrator.add_column('table_history', 'provider', TextField(null=True)))
if "score" not in table_history:
migrate(migrator.add_column('table_history', 'score', TextField(null=True)))
if "subs_id" not in table_history:
migrate(migrator.add_column('table_history', 'subs_id', TextField(null=True)))
if "subtitles_path" not in table_history:
migrate(migrator.add_column('table_history', 'subtitles_path', TextField(null=True)))
if "video_path" not in table_history_movie:
migrate(migrator.add_column('table_history_movie', 'video_path', TextField(null=True)))
if "language" not in table_history_movie:
migrate(migrator.add_column('table_history_movie', 'language', TextField(null=True)))
if "provider" not in table_history_movie:
migrate(migrator.add_column('table_history_movie', 'provider', TextField(null=True)))
if "score" not in table_history_movie:
migrate(migrator.add_column('table_history_movie', 'score', TextField(null=True)))
if "subs_id" not in table_history_movie:
migrate(migrator.add_column('table_history_movie', 'subs_id', TextField(null=True)))
if "subtitles_path" not in table_history_movie:
migrate(migrator.add_column('table_history_movie', 'subtitles_path', TextField(null=True)))
if "mustContain" not in table_languages_profiles:
migrate(migrator.add_column('table_languages_profiles', 'mustContain', TextField(null=True)))
if "mustNotContain" not in table_languages_profiles:
migrate(migrator.add_column('table_languages_profiles', 'mustNotContain', TextField(null=True)))
if "originalFormat" not in table_languages_profiles:
migrate(migrator.add_column('table_languages_profiles', 'originalFormat', BooleanField(null=True)))
if "languages" in table_shows:
migrate(migrator.drop_column('table_shows', 'languages'))
if "hearing_impaired" in table_shows:
migrate(migrator.drop_column('table_shows', 'hearing_impaired'))
if "languages" in table_movies:
migrate(migrator.drop_column('table_movies', 'languages'))
if "hearing_impaired" in table_movies:
migrate(migrator.drop_column('table_movies', 'hearing_impaired'))
if not any(
x
for x in database.get_columns('table_blacklist')
if x.name == "timestamp" and x.data_type in ["DATETIME", "timestamp without time zone"]
):
migrate(migrator.alter_column_type('table_blacklist', 'timestamp', DateTimeField(default=datetime.now)))
update = TableBlacklist.select()
for item in update:
item.update({"timestamp": datetime.fromtimestamp(int(item.timestamp))}).execute()
if not any(
x
for x in database.get_columns('table_blacklist_movie')
if x.name == "timestamp" and x.data_type in ["DATETIME", "timestamp without time zone"]
):
migrate(migrator.alter_column_type('table_blacklist_movie', 'timestamp', DateTimeField(default=datetime.now)))
update = TableBlacklistMovie.select()
for item in update:
item.update({"timestamp": datetime.fromtimestamp(int(item.timestamp))}).execute()
if not any(
x for x in database.get_columns('table_history') if x.name == "score" and x.data_type.lower() == "integer"):
migrate(migrator.alter_column_type('table_history', 'score', IntegerField(null=True)))
if not any(
x
for x in database.get_columns('table_history')
if x.name == "timestamp" and x.data_type in ["DATETIME", "timestamp without time zone"]
):
migrate(migrator.alter_column_type('table_history', 'timestamp', DateTimeField(default=datetime.now)))
update = TableHistory.select()
list_to_update = []
for i, item in enumerate(update):
item.timestamp = datetime.fromtimestamp(int(item.timestamp))
list_to_update.append(item)
if i % 100 == 0:
TableHistory.bulk_update(list_to_update, fields=[TableHistory.timestamp])
list_to_update = []
if list_to_update:
TableHistory.bulk_update(list_to_update, fields=[TableHistory.timestamp])
if not any(x for x in database.get_columns('table_history_movie') if
x.name == "score" and x.data_type.lower() == "integer"):
migrate(migrator.alter_column_type('table_history_movie', 'score', IntegerField(null=True)))
if not any(
x
for x in database.get_columns('table_history_movie')
if x.name == "timestamp" and x.data_type in ["DATETIME", "timestamp without time zone"]
):
migrate(migrator.alter_column_type('table_history_movie', 'timestamp', DateTimeField(default=datetime.now)))
update = TableHistoryMovie.select()
list_to_update = []
for i, item in enumerate(update):
item.timestamp = datetime.fromtimestamp(int(item.timestamp))
list_to_update.append(item)
if i % 100 == 0:
TableHistoryMovie.bulk_update(list_to_update, fields=[TableHistoryMovie.timestamp])
list_to_update = []
if list_to_update:
TableHistoryMovie.bulk_update(list_to_update, fields=[TableHistoryMovie.timestamp])
# if not any(x for x in database.get_columns('table_movies') if x.name == "monitored" and x.data_type == "BOOLEAN"):
# migrate(migrator.alter_column_type('table_movies', 'monitored', BooleanField(null=True)))
if database.get_columns('table_settings_providers'):
database.execute_sql('drop table if exists table_settings_providers;')
if "alternateTitles" in table_shows:
migrate(migrator.rename_column('table_shows', 'alternateTitles', "alternativeTitles"))
if "scene_name" in table_episodes:
migrate(migrator.rename_column('table_episodes', 'scene_name', "sceneName"))
class SqliteDictPathMapper:
def __init__(self):
pass
@staticmethod
def path_replace(values_dict):
if type(values_dict) is list:
for item in values_dict:
item['path'] = path_mappings.path_replace(item['path'])
elif type(values_dict) is dict:
values_dict['path'] = path_mappings.path_replace(values_dict['path'])
else:
return path_mappings.path_replace(values_dict)
@staticmethod
def path_replace_movie(values_dict):
if type(values_dict) is list:
for item in values_dict:
item['path'] = path_mappings.path_replace_movie(item['path'])
elif type(values_dict) is dict:
values_dict['path'] = path_mappings.path_replace_movie(values_dict['path'])
else:
return path_mappings.path_replace_movie(values_dict)
dict_mapper = SqliteDictPathMapper()
if not database.execute(
select(System)) \
.first():
database.execute(
insert(System)
.values(configured='0', updated='0'))
def get_exclusion_clause(exclusion_type):
@ -568,12 +355,12 @@ def get_exclusion_clause(exclusion_type):
if exclusion_type == 'series':
monitoredOnly = settings.sonarr.getboolean('only_monitored')
if monitoredOnly:
where_clause.append((TableEpisodes.monitored == True)) # noqa E712
where_clause.append((TableShows.monitored == True)) # noqa E712
where_clause.append((TableEpisodes.monitored == 'True')) # noqa E712
where_clause.append((TableShows.monitored == 'True')) # noqa E712
else:
monitoredOnly = settings.radarr.getboolean('only_monitored')
if monitoredOnly:
where_clause.append((TableMovies.monitored == True)) # noqa E712
where_clause.append((TableMovies.monitored == 'True')) # noqa E712
if exclusion_type == 'series':
typesList = get_array_from(settings.sonarr.excluded_series_types)
@ -589,20 +376,24 @@ def get_exclusion_clause(exclusion_type):
@region.cache_on_arguments()
def update_profile_id_list():
profile_id_list = TableLanguagesProfiles.select(TableLanguagesProfiles.profileId,
TableLanguagesProfiles.name,
TableLanguagesProfiles.cutoff,
TableLanguagesProfiles.items,
TableLanguagesProfiles.mustContain,
TableLanguagesProfiles.mustNotContain,
TableLanguagesProfiles.originalFormat).dicts()
profile_id_list = list(profile_id_list)
for profile in profile_id_list:
profile['items'] = json.loads(profile['items'])
profile['mustContain'] = ast.literal_eval(profile['mustContain']) if profile['mustContain'] else []
profile['mustNotContain'] = ast.literal_eval(profile['mustNotContain']) if profile['mustNotContain'] else []
return profile_id_list
return [{
'profileId': x.profileId,
'name': x.name,
'cutoff': x.cutoff,
'items': json.loads(x.items),
'mustContain': ast.literal_eval(x.mustContain) if x.mustContain else [],
'mustNotContain': ast.literal_eval(x.mustNotContain) if x.mustNotContain else [],
'originalFormat': x.originalFormat,
} for x in database.execute(
select(TableLanguagesProfiles.profileId,
TableLanguagesProfiles.name,
TableLanguagesProfiles.cutoff,
TableLanguagesProfiles.items,
TableLanguagesProfiles.mustContain,
TableLanguagesProfiles.mustNotContain,
TableLanguagesProfiles.originalFormat))
.all()
]
def get_profiles_list(profile_id=None):
@ -617,36 +408,15 @@ def get_profiles_list(profile_id=None):
def get_desired_languages(profile_id):
languages = []
profile_id_list = update_profile_id_list()
if profile_id and profile_id != 'null':
for profile in profile_id_list:
profileId, name, cutoff, items, mustContain, mustNotContain, originalFormat = profile.values()
try:
profile_id_int = int(profile_id)
except ValueError:
continue
else:
if profileId == profile_id_int:
languages = [x['language'] for x in items]
break
return languages
for profile in update_profile_id_list():
if profile['profileId'] == profile_id:
return [x['language'] for x in profile['items']]
def get_profile_id_name(profile_id):
name_from_id = None
profile_id_list = update_profile_id_list()
if profile_id and profile_id != 'null':
for profile in profile_id_list:
profileId, name, cutoff, items, mustContain, mustNotContain, originalFormat = profile.values()
if profileId == int(profile_id):
name_from_id = name
break
return name_from_id
for profile in update_profile_id_list():
if profile['profileId'] == profile_id:
return profile['name']
def get_profile_cutoff(profile_id):
@ -703,23 +473,27 @@ def get_audio_profile_languages(audio_languages_list_str):
def get_profile_id(series_id=None, episode_id=None, movie_id=None):
if series_id:
data = TableShows.select(TableShows.profileId) \
.where(TableShows.sonarrSeriesId == series_id) \
.get_or_none()
data = database.execute(
select(TableShows.profileId)
.where(TableShows.sonarrSeriesId == series_id))\
.first()
if data:
return data.profileId
elif episode_id:
data = TableShows.select(TableShows.profileId) \
.join(TableEpisodes, on=(TableShows.sonarrSeriesId == TableEpisodes.sonarrSeriesId)) \
.where(TableEpisodes.sonarrEpisodeId == episode_id) \
.get_or_none()
data = database.execute(
select(TableShows.profileId)
.select_from(TableShows)
.join(TableEpisodes)
.where(TableEpisodes.sonarrEpisodeId == episode_id)) \
.first()
if data:
return data.profileId
elif movie_id:
data = TableMovies.select(TableMovies.profileId) \
.where(TableMovies.radarrId == movie_id) \
.get_or_none()
data = database.execute(
select(TableMovies.profileId)
.where(TableMovies.radarrId == movie_id))\
.first()
if data:
return data.profileId

View File

@ -30,6 +30,8 @@ parser.add_argument('--no-tasks', default=False, type=bool, const=True, metavar=
help="Disable all tasks (default: False)")
parser.add_argument('--no-signalr', default=False, type=bool, const=True, metavar="BOOL", nargs="?",
help="Disable SignalR connections to Sonarr and/or Radarr (default: False)")
parser.add_argument('--create-db-revision', default=False, type=bool, const=True, metavar="BOOL", nargs="?",
help="Create a new database revision that will be used to migrate database")
if not no_cli:

View File

@ -59,6 +59,7 @@ class NoExceptionFormatter(logging.Formatter):
def configure_logging(debug=False):
warnings.simplefilter('ignore', category=ResourceWarning)
warnings.simplefilter('ignore', category=PytzUsageWarning)
# warnings.simplefilter('ignore', category=SAWarning)
if not debug:
log_level = "INFO"
@ -93,7 +94,7 @@ def configure_logging(debug=False):
logger.addHandler(fh)
if debug:
logging.getLogger("peewee").setLevel(logging.DEBUG)
logging.getLogger("alembic.runtime.migration").setLevel(logging.DEBUG)
logging.getLogger("apscheduler").setLevel(logging.DEBUG)
logging.getLogger("subliminal").setLevel(logging.DEBUG)
logging.getLogger("subliminal_patch").setLevel(logging.DEBUG)
@ -111,7 +112,7 @@ def configure_logging(debug=False):
logging.debug('Operating system: %s', platform.platform())
logging.debug('Python version: %s', platform.python_version())
else:
logging.getLogger("peewee").setLevel(logging.CRITICAL)
logging.getLogger("alembic.runtime.migration").setLevel(logging.CRITICAL)
logging.getLogger("apscheduler").setLevel(logging.WARNING)
logging.getLogger("apprise").setLevel(logging.WARNING)
logging.getLogger("subliminal").setLevel(logging.CRITICAL)

View File

@ -3,99 +3,70 @@
import apprise
import logging
from .database import TableSettingsNotifier, TableEpisodes, TableShows, TableMovies
from .database import TableSettingsNotifier, TableEpisodes, TableShows, TableMovies, database, insert, delete, select
def update_notifier():
# define apprise object
a = apprise.Apprise()
# Retrieve all of the details
# Retrieve all the details
results = a.details()
notifiers_new = []
notifiers_old = []
notifiers_added = []
notifiers_kept = []
notifiers_current_db = TableSettingsNotifier.select(TableSettingsNotifier.name).dicts()
notifiers_current = []
for notifier in notifiers_current_db:
notifiers_current.append([notifier['name']])
notifiers_in_db = [row.name for row in
database.execute(
select(TableSettingsNotifier.name))
.all()]
for x in results['schemas']:
if [str(x['service_name'])] not in notifiers_current:
notifiers_new.append({'name': str(x['service_name']), 'enabled': 0})
if x['service_name'] not in notifiers_in_db:
notifiers_added.append({'name': str(x['service_name']), 'enabled': 0})
logging.debug('Adding new notifier agent: ' + str(x['service_name']))
else:
notifiers_old.append([str(x['service_name'])])
notifiers_kept.append(x['service_name'])
notifiers_to_delete = [item for item in notifiers_current if item not in notifiers_old]
TableSettingsNotifier.insert_many(notifiers_new).execute()
notifiers_to_delete = [item for item in notifiers_in_db if item not in notifiers_kept]
for item in notifiers_to_delete:
TableSettingsNotifier.delete().where(TableSettingsNotifier.name == item).execute()
database.execute(
delete(TableSettingsNotifier)
.where(TableSettingsNotifier.name == item))
database.execute(
insert(TableSettingsNotifier)
.values(notifiers_added))
def get_notifier_providers():
providers = TableSettingsNotifier.select(TableSettingsNotifier.name,
TableSettingsNotifier.url)\
.where(TableSettingsNotifier.enabled == 1)\
.dicts()
return providers
def get_series(sonarr_series_id):
data = TableShows.select(TableShows.title, TableShows.year)\
.where(TableShows.sonarrSeriesId == sonarr_series_id)\
.dicts()\
.get_or_none()
if not data:
return
return {'title': data['title'], 'year': data['year']}
def get_episode_name(sonarr_episode_id):
data = TableEpisodes.select(TableEpisodes.title, TableEpisodes.season, TableEpisodes.episode)\
.where(TableEpisodes.sonarrEpisodeId == sonarr_episode_id)\
.dicts()\
.get_or_none()
if not data:
return
return data['title'], data['season'], data['episode']
def get_movie(radarr_id):
data = TableMovies.select(TableMovies.title, TableMovies.year)\
.where(TableMovies.radarrId == radarr_id)\
.dicts()\
.get_or_none()
if not data:
return
return {'title': data['title'], 'year': data['year']}
return database.execute(
select(TableSettingsNotifier.name, TableSettingsNotifier.url)
.where(TableSettingsNotifier.enabled == 1))\
.all()
def send_notifications(sonarr_series_id, sonarr_episode_id, message):
providers = get_notifier_providers()
if not len(providers):
return
series = get_series(sonarr_series_id)
series = database.execute(
select(TableShows.title, TableShows.year)
.where(TableShows.sonarrSeriesId == sonarr_series_id))\
.first()
if not series:
return
series_title = series['title']
series_year = series['year']
series_title = series.title
series_year = series.year
if series_year not in [None, '', '0']:
series_year = ' ({})'.format(series_year)
else:
series_year = ''
episode = get_episode_name(sonarr_episode_id)
episode = database.execute(
select(TableEpisodes.title, TableEpisodes.season, TableEpisodes.episode)
.where(TableEpisodes.sonarrEpisodeId == sonarr_episode_id))\
.first()
if not episode:
return
@ -109,8 +80,8 @@ def send_notifications(sonarr_series_id, sonarr_episode_id, message):
apobj.notify(
title='Bazarr notification',
body="{}{} - S{:02d}E{:02d} - {} : {}".format(series_title, series_year, episode[1], episode[2], episode[0],
message),
body="{}{} - S{:02d}E{:02d} - {} : {}".format(series_title, series_year, episode.season, episode.episode,
episode.title, message),
)
@ -118,11 +89,14 @@ def send_notifications_movie(radarr_id, message):
providers = get_notifier_providers()
if not len(providers):
return
movie = get_movie(radarr_id)
movie = database.execute(
select(TableMovies.title, TableMovies.year)
.where(TableMovies.radarrId == radarr_id))\
.first()
if not movie:
return
movie_title = movie['title']
movie_year = movie['year']
movie_title = movie.title
movie_year = movie.year
if movie_year not in [None, '', '0']:
movie_year = ' ({})'.format(movie_year)
else:

View File

@ -19,7 +19,7 @@ import logging
from app.announcements import get_announcements_to_file
from sonarr.sync.series import update_series
from sonarr.sync.episodes import sync_episodes, update_all_episodes
from sonarr.sync.episodes import update_all_episodes
from radarr.sync.movies import update_movies, update_all_movies
from subtitles.wanted import wanted_search_missing_subtitles_series, wanted_search_missing_subtitles_movies
from subtitles.upgrade import upgrade_subtitles
@ -163,18 +163,14 @@ class Scheduler:
if settings.general.getboolean('use_sonarr'):
self.aps_scheduler.add_job(
update_series, IntervalTrigger(minutes=int(settings.sonarr.series_sync)), max_instances=1,
coalesce=True, misfire_grace_time=15, id='update_series', name='Update Series list from Sonarr',
replace_existing=True)
self.aps_scheduler.add_job(
sync_episodes, IntervalTrigger(minutes=int(settings.sonarr.episodes_sync)), max_instances=1,
coalesce=True, misfire_grace_time=15, id='sync_episodes', name='Sync episodes with Sonarr',
coalesce=True, misfire_grace_time=15, id='update_series', name='Sync with Sonarr',
replace_existing=True)
def __radarr_update_task(self):
if settings.general.getboolean('use_radarr'):
self.aps_scheduler.add_job(
update_movies, IntervalTrigger(minutes=int(settings.radarr.movies_sync)), max_instances=1,
coalesce=True, misfire_grace_time=15, id='update_movies', name='Update Movie list from Radarr',
coalesce=True, misfire_grace_time=15, id='update_movies', name='Sync with Radarr',
replace_existing=True)
def __cache_cleanup_task(self):
@ -210,18 +206,18 @@ class Scheduler:
self.aps_scheduler.add_job(
update_all_episodes, CronTrigger(hour=settings.sonarr.full_update_hour), max_instances=1,
coalesce=True, misfire_grace_time=15, id='update_all_episodes',
name='Update all Episode Subtitles from disk', replace_existing=True)
name='Index all Episode Subtitles from disk', replace_existing=True)
elif full_update == "Weekly":
self.aps_scheduler.add_job(
update_all_episodes,
CronTrigger(day_of_week=settings.sonarr.full_update_day, hour=settings.sonarr.full_update_hour),
max_instances=1, coalesce=True, misfire_grace_time=15, id='update_all_episodes',
name='Update all Episode Subtitles from disk', replace_existing=True)
name='Index all Episode Subtitles from disk', replace_existing=True)
elif full_update == "Manually":
self.aps_scheduler.add_job(
update_all_episodes, CronTrigger(year='2100'), max_instances=1, coalesce=True,
misfire_grace_time=15, id='update_all_episodes',
name='Update all Episode Subtitles from disk', replace_existing=True)
name='Index all Episode Subtitles from disk', replace_existing=True)
def __radarr_full_update_task(self):
if settings.general.getboolean('use_radarr'):
@ -230,17 +226,17 @@ class Scheduler:
self.aps_scheduler.add_job(
update_all_movies, CronTrigger(hour=settings.radarr.full_update_hour), max_instances=1,
coalesce=True, misfire_grace_time=15,
id='update_all_movies', name='Update all Movie Subtitles from disk', replace_existing=True)
id='update_all_movies', name='Index all Movie Subtitles from disk', replace_existing=True)
elif full_update == "Weekly":
self.aps_scheduler.add_job(
update_all_movies,
CronTrigger(day_of_week=settings.radarr.full_update_day, hour=settings.radarr.full_update_hour),
max_instances=1, coalesce=True, misfire_grace_time=15, id='update_all_movies',
name='Update all Movie Subtitles from disk', replace_existing=True)
name='Index all Movie Subtitles from disk', replace_existing=True)
elif full_update == "Manually":
self.aps_scheduler.add_job(
update_all_movies, CronTrigger(year='2100'), max_instances=1, coalesce=True, misfire_grace_time=15,
id='update_all_movies', name='Update all Movie Subtitles from disk', replace_existing=True)
id='update_all_movies', name='Index all Movie Subtitles from disk', replace_existing=True)
def __update_bazarr_task(self):
if not args.no_update and os.environ["BAZARR_VERSION"] != '':

View File

@ -19,7 +19,7 @@ from sonarr.sync.series import update_series, update_one_series
from radarr.sync.movies import update_movies, update_one_movie
from sonarr.info import get_sonarr_info, url_sonarr
from radarr.info import url_radarr
from .database import TableShows
from .database import TableShows, database, select
from .config import settings
from .scheduler import scheduler
@ -73,7 +73,6 @@ class SonarrSignalrClientLegacy:
logging.info('BAZARR SignalR client for Sonarr is connected and waiting for events.')
if not args.dev:
scheduler.add_job(update_series, kwargs={'send_event': True}, max_instances=1)
scheduler.add_job(sync_episodes, kwargs={'send_event': True}, max_instances=1)
def stop(self, log=True):
try:
@ -150,7 +149,6 @@ class SonarrSignalrClient:
logging.info('BAZARR SignalR client for Sonarr is connected and waiting for events.')
if not args.dev:
scheduler.add_job(update_series, kwargs={'send_event': True}, max_instances=1)
scheduler.add_job(sync_episodes, kwargs={'send_event': True}, max_instances=1)
def on_reconnect_handler(self):
self.connected = False
@ -266,10 +264,10 @@ def dispatcher(data):
series_title = data['body']['resource']['series']['title']
series_year = data['body']['resource']['series']['year']
else:
series_metadata = TableShows.select(TableShows.title, TableShows.year)\
.where(TableShows.sonarrSeriesId == data['body']['resource']['seriesId'])\
.dicts()\
.get_or_none()
series_metadata = database.execute(
select(TableShows.title, TableShows.year)
.where(TableShows.sonarrSeriesId == data['body']['resource']['seriesId']))\
.first()
if series_metadata:
series_title = series_metadata['title']
series_year = series_metadata['year']
@ -284,10 +282,10 @@ def dispatcher(data):
if topic == 'series':
logging.debug(f'Event received from Sonarr for series: {series_title} ({series_year})')
update_one_series(series_id=media_id, action=action, send_event=False)
update_one_series(series_id=media_id, action=action)
if episodesChanged:
# this will happen if a season monitored status is changed.
sync_episodes(series_id=media_id, send_event=False)
sync_episodes(series_id=media_id, send_event=True)
elif topic == 'episode':
logging.debug(f'Event received from Sonarr for episode: {series_title} ({series_year}) - '
f'S{season_number:0>2}E{episode_number:0>2} - {episode_title}')

View File

@ -18,6 +18,8 @@ from utilities.binaries import get_binary, BinaryNotFound
from utilities.path_mappings import path_mappings
from utilities.backup import restore_from_backup
from app.database import init_db
# set start time global variable as epoch
global startTime
startTime = time.time()
@ -233,9 +235,6 @@ def init_binaries():
return exe
# keep this import at the end to prevent peewee.OperationalError: unable to open database file
from app.database import init_db, migrate_db # noqa E402
init_db()
migrate_db()
init_binaries()
path_mappings.update()

View File

@ -5,6 +5,8 @@ import os
from subzero.language import Language
from app.database import database, insert
logger = logging.getLogger(__name__)
@ -46,9 +48,13 @@ class CustomLanguage:
"Register the custom language subclasses in the database."
for sub in cls.__subclasses__():
table.insert(
{table.code3: sub.alpha3, table.code2: sub.alpha2, table.name: sub.name}
).on_conflict(action="IGNORE").execute()
database.execute(
insert(table)
.values(code3=sub.alpha3,
code2=sub.alpha2,
name=sub.name,
enabled=0)
.on_conflict_do_nothing())
@classmethod
def found_external(cls, subtitle, subtitle_path):

View File

@ -5,32 +5,29 @@ import pycountry
from subzero.language import Language
from .custom_lang import CustomLanguage
from app.database import TableSettingsLanguages
from app.database import TableSettingsLanguages, database, insert, update, select
def load_language_in_db():
# Get languages list in langs tuple
langs = [[lang.alpha_3, lang.alpha_2, lang.name]
langs = [{'code3': lang.alpha_3, 'code2': lang.alpha_2, 'name': lang.name, 'enabled': 0}
for lang in pycountry.languages
if hasattr(lang, 'alpha_2')]
# Insert standard languages in database table
TableSettingsLanguages.insert_many(langs,
fields=[TableSettingsLanguages.code3, TableSettingsLanguages.code2,
TableSettingsLanguages.name]) \
.on_conflict(action='IGNORE') \
.execute()
database.execute(
insert(TableSettingsLanguages)
.values(langs)
.on_conflict_do_nothing())
# Update standard languages with code3b if available
langs = [[lang.bibliographic, lang.alpha_3]
langs = [{'code3b': lang.bibliographic, 'code3': lang.alpha_3}
for lang in pycountry.languages
if hasattr(lang, 'alpha_2') and hasattr(lang, 'bibliographic')]
# Update languages in database table
for lang in langs:
TableSettingsLanguages.update({TableSettingsLanguages.code3b: lang[0]}) \
.where(TableSettingsLanguages.code3 == lang[1]) \
.execute()
database.execute(
update(TableSettingsLanguages), langs)
# Insert custom languages in database table
CustomLanguage.register(TableSettingsLanguages)
@ -42,52 +39,58 @@ def load_language_in_db():
def create_languages_dict():
global languages_dict
# replace chinese by chinese simplified
TableSettingsLanguages.update({TableSettingsLanguages.name: 'Chinese Simplified'}) \
.where(TableSettingsLanguages.code3 == 'zho') \
.execute()
database.execute(
update(TableSettingsLanguages)
.values(name='Chinese Simplified')
.where(TableSettingsLanguages.code3 == 'zho'))
languages_dict = TableSettingsLanguages.select(TableSettingsLanguages.name,
TableSettingsLanguages.code2,
TableSettingsLanguages.code3,
TableSettingsLanguages.code3b).dicts()
languages_dict = [{
'code3': x.code3,
'code2': x.code2,
'name': x.name,
'code3b': x.code3b,
} for x in database.execute(
select(TableSettingsLanguages.code3, TableSettingsLanguages.code2, TableSettingsLanguages.name,
TableSettingsLanguages.code3b))
.all()]
def language_from_alpha2(lang):
return next((item["name"] for item in languages_dict if item["code2"] == lang[:2]), None)
return next((item['name'] for item in languages_dict if item['code2'] == lang[:2]), None)
def language_from_alpha3(lang):
return next((item["name"] for item in languages_dict if item["code3"] == lang[:3] or item["code3b"] == lang[:3]),
None)
return next((item['name'] for item in languages_dict if lang[:3] in [item['code3'], item['code3b']]), None)
def alpha2_from_alpha3(lang):
return next((item["code2"] for item in languages_dict if item["code3"] == lang[:3] or item["code3b"] == lang[:3]),
None)
return next((item['code2'] for item in languages_dict if lang[:3] in [item['code3'], item['code3b']]), None)
def alpha2_from_language(lang):
return next((item["code2"] for item in languages_dict if item["name"] == lang), None)
return next((item['code2'] for item in languages_dict if item['name'] == lang), None)
def alpha3_from_alpha2(lang):
return next((item["code3"] for item in languages_dict if item["code2"] == lang[:2]), None)
return next((item['code3'] for item in languages_dict if item['code2'] == lang[:2]), None)
def alpha3_from_language(lang):
return next((item["code3"] for item in languages_dict if item["name"] == lang), None)
return next((item['code3'] for item in languages_dict if item['name'] == lang), None)
def get_language_set():
languages = TableSettingsLanguages.select(TableSettingsLanguages.code3) \
.where(TableSettingsLanguages.enabled == 1).dicts()
languages = database.execute(
select(TableSettingsLanguages.code3)
.where(TableSettingsLanguages.enabled == 1))\
.all()
language_set = set()
for lang in languages:
custom = CustomLanguage.from_value(lang["code3"], "alpha3")
custom = CustomLanguage.from_value(lang.code3, "alpha3")
if custom is None:
language_set.add(Language(lang["code3"]))
language_set.add(Language(lang.code3))
else:
language_set.add(custom.subzero_language())

View File

@ -1,6 +1,8 @@
# coding=utf-8
import os
import io
import logging
from threading import Thread
@ -34,19 +36,35 @@ else:
# there's missing embedded packages after a commit
check_if_new_update()
from app.database import System # noqa E402
from app.database import System, database, update, migrate_db, create_db_revision # noqa E402
from app.notifier import update_notifier # noqa E402
from languages.get_languages import load_language_in_db # noqa E402
from app.signalr_client import sonarr_signalr_client, radarr_signalr_client # noqa E402
from app.server import webserver # noqa E402
from app.server import webserver, app # noqa E402
from app.announcements import get_announcements_to_file # noqa E402
if args.create_db_revision:
try:
stop_file = io.open(os.path.join(args.config_dir, "bazarr.stop"), "w", encoding='UTF-8')
except Exception as e:
logging.error('BAZARR Cannot create stop file: ' + repr(e))
else:
create_db_revision(app)
logging.info('Bazarr is being shutdown...')
stop_file.write(str(''))
stop_file.close()
os._exit(0)
else:
migrate_db(app)
configure_proxy_func()
get_announcements_to_file()
# Reset the updated once Bazarr have been restarted after an update
System.update({System.updated: '0'}).execute()
database.execute(
update(System)
.values(updated='0'))
# Load languages in database
load_language_in_db()

View File

@ -2,38 +2,38 @@
from datetime import datetime
from app.database import TableBlacklistMovie
from app.database import TableBlacklistMovie, database, insert, delete, select
from app.event_handler import event_stream
def get_blacklist_movie():
blacklist_db = TableBlacklistMovie.select(TableBlacklistMovie.provider, TableBlacklistMovie.subs_id).dicts()
blacklist_list = []
for item in blacklist_db:
blacklist_list.append((item['provider'], item['subs_id']))
return blacklist_list
return [(item.provider, item.subs_id) for item in
database.execute(
select(TableBlacklistMovie.provider, TableBlacklistMovie.subs_id))
.all()]
def blacklist_log_movie(radarr_id, provider, subs_id, language):
TableBlacklistMovie.insert({
TableBlacklistMovie.radarr_id: radarr_id,
TableBlacklistMovie.timestamp: datetime.now(),
TableBlacklistMovie.provider: provider,
TableBlacklistMovie.subs_id: subs_id,
TableBlacklistMovie.language: language
}).execute()
database.execute(
insert(TableBlacklistMovie)
.values(
radarr_id=radarr_id,
timestamp=datetime.now(),
provider=provider,
subs_id=subs_id,
language=language
))
event_stream(type='movie-blacklist')
def blacklist_delete_movie(provider, subs_id):
TableBlacklistMovie.delete().where((TableBlacklistMovie.provider == provider) and
(TableBlacklistMovie.subs_id == subs_id))\
.execute()
database.execute(
delete(TableBlacklistMovie)
.where((TableBlacklistMovie.provider == provider) and (TableBlacklistMovie.subs_id == subs_id)))
event_stream(type='movie-blacklist', action='delete')
def blacklist_delete_all_movie():
TableBlacklistMovie.delete().execute()
database.execute(
delete(TableBlacklistMovie))
event_stream(type='movie-blacklist', action='delete')

View File

@ -2,7 +2,7 @@
from datetime import datetime
from app.database import TableHistoryMovie
from app.database import TableHistoryMovie, database, insert
from app.event_handler import event_stream
@ -14,17 +14,23 @@ def history_log_movie(action, radarr_id, result, fake_provider=None, fake_score=
score = fake_score or result.score
subs_id = result.subs_id
subtitles_path = result.subs_path
matched = result.matched
not_matched = result.not_matched
TableHistoryMovie.insert({
TableHistoryMovie.action: action,
TableHistoryMovie.radarrId: radarr_id,
TableHistoryMovie.timestamp: datetime.now(),
TableHistoryMovie.description: description,
TableHistoryMovie.video_path: video_path,
TableHistoryMovie.language: language,
TableHistoryMovie.provider: provider,
TableHistoryMovie.score: score,
TableHistoryMovie.subs_id: subs_id,
TableHistoryMovie.subtitles_path: subtitles_path
}).execute()
database.execute(
insert(TableHistoryMovie)
.values(
action=action,
radarrId=radarr_id,
timestamp=datetime.now(),
description=description,
video_path=video_path,
language=language,
provider=provider,
score=score,
subs_id=subs_id,
subtitles_path=subtitles_path,
matched=str(matched) if matched else None,
not_matched=str(not_matched) if not_matched else None
))
event_stream(type='movie-history')

View File

@ -6,7 +6,7 @@ import logging
from app.config import settings
from utilities.path_mappings import path_mappings
from app.database import TableMoviesRootfolder, TableMovies
from app.database import TableMoviesRootfolder, TableMovies, database, delete, update, insert, select
from radarr.info import get_radarr_info, url_radarr
from constants import headers
@ -33,52 +33,61 @@ def get_radarr_rootfolder():
logging.exception("BAZARR Error trying to get rootfolder from Radarr.")
return []
else:
radarr_movies_paths = list(TableMovies.select(TableMovies.path).dicts())
for folder in rootfolder.json():
if any(item['path'].startswith(folder['path']) for item in radarr_movies_paths):
if any(item.path.startswith(folder['path']) for item in database.execute(
select(TableMovies.path))
.all()):
radarr_rootfolder.append({'id': folder['id'], 'path': folder['path']})
db_rootfolder = TableMoviesRootfolder.select(TableMoviesRootfolder.id, TableMoviesRootfolder.path).dicts()
db_rootfolder = database.execute(
select(TableMoviesRootfolder.id, TableMoviesRootfolder.path))\
.all()
rootfolder_to_remove = [x for x in db_rootfolder if not
next((item for item in radarr_rootfolder if item['id'] == x['id']), False)]
next((item for item in radarr_rootfolder if item['id'] == x.id), False)]
rootfolder_to_update = [x for x in radarr_rootfolder if
next((item for item in db_rootfolder if item['id'] == x['id']), False)]
next((item for item in db_rootfolder if item.id == x['id']), False)]
rootfolder_to_insert = [x for x in radarr_rootfolder if not
next((item for item in db_rootfolder if item['id'] == x['id']), False)]
next((item for item in db_rootfolder if item.id == x['id']), False)]
for item in rootfolder_to_remove:
TableMoviesRootfolder.delete().where(TableMoviesRootfolder.id == item['id']).execute()
database.execute(
delete(TableMoviesRootfolder)
.where(TableMoviesRootfolder.id == item.id))
for item in rootfolder_to_update:
TableMoviesRootfolder.update({TableMoviesRootfolder.path: item['path']})\
.where(TableMoviesRootfolder.id == item['id']).execute()
database.execute(
update(TableMoviesRootfolder)
.values(path=item['path'])
.where(TableMoviesRootfolder.id == item['id']))
for item in rootfolder_to_insert:
TableMoviesRootfolder.insert({TableMoviesRootfolder.id: item['id'],
TableMoviesRootfolder.path: item['path']}).execute()
database.execute(
insert(TableMoviesRootfolder)
.values(id=item['id'], path=item['path']))
def check_radarr_rootfolder():
get_radarr_rootfolder()
rootfolder = TableMoviesRootfolder.select(TableMoviesRootfolder.id, TableMoviesRootfolder.path).dicts()
rootfolder = database.execute(
select(TableMoviesRootfolder.id, TableMoviesRootfolder.path))\
.all()
for item in rootfolder:
root_path = item['path']
root_path = item.path
if not root_path.endswith(('/', '\\')):
if root_path.startswith('/'):
root_path += '/'
else:
root_path += '\\'
if not os.path.isdir(path_mappings.path_replace_movie(root_path)):
TableMoviesRootfolder.update({TableMoviesRootfolder.accessible: 0,
TableMoviesRootfolder.error: 'This Radarr root directory does not seems to '
'be accessible by Please check path '
'mapping.'}) \
.where(TableMoviesRootfolder.id == item['id']) \
.execute()
database.execute(
update(TableMoviesRootfolder)
.values(accessible=0, error='This Radarr root directory does not seems to be accessible by Please '
'check path mapping.')
.where(TableMoviesRootfolder.id == item.id))
elif not os.access(path_mappings.path_replace_movie(root_path), os.W_OK):
TableMoviesRootfolder.update({TableMoviesRootfolder.accessible: 0,
TableMoviesRootfolder.error: 'Bazarr cannot write to this directory'}) \
.where(TableMoviesRootfolder.id == item['id']) \
.execute()
database.execute(
update(TableMoviesRootfolder)
.values(accessible=0, error='Bazarr cannot write to this directory')
.where(TableMoviesRootfolder.id == item.id))
else:
TableMoviesRootfolder.update({TableMoviesRootfolder.accessible: 1,
TableMoviesRootfolder.error: ''}) \
.where(TableMoviesRootfolder.id == item['id']) \
.execute()
database.execute(
update(TableMoviesRootfolder)
.values(accessible=1, error='')
.where(TableMoviesRootfolder.id == item.id))

View File

@ -3,7 +3,7 @@
import os
import logging
from peewee import IntegrityError
from sqlalchemy.exc import IntegrityError
from app.config import settings
from radarr.info import url_radarr
@ -11,7 +11,7 @@ from utilities.path_mappings import path_mappings
from subtitles.indexer.movies import store_subtitles_movie, movies_full_scan_subtitles
from radarr.rootfolder import check_radarr_rootfolder
from subtitles.mass_download import movies_download_subtitles
from app.database import TableMovies
from app.database import TableMovies, database, insert, update, delete, select
from app.event_handler import event_stream, show_progress, hide_progress
from .utils import get_profile_list, get_tags, get_movies_from_radarr_api
@ -49,9 +49,10 @@ def update_movies(send_event=True):
return
else:
# Get current movies in DB
current_movies_db = TableMovies.select(TableMovies.tmdbId, TableMovies.path, TableMovies.radarrId).dicts()
current_movies_db_list = [x['tmdbId'] for x in current_movies_db]
current_movies_db = [x.tmdbId for x in
database.execute(
select(TableMovies.tmdbId))
.all()]
current_movies_radarr = []
movies_to_update = []
@ -79,7 +80,7 @@ def update_movies(send_event=True):
# Add movies in radarr to current movies list
current_movies_radarr.append(str(movie['tmdbId']))
if str(movie['tmdbId']) in current_movies_db_list:
if str(movie['tmdbId']) in current_movies_db:
movies_to_update.append(movieParser(movie, action='update',
tags_dict=tagsDict,
movie_default_profile=movie_default_profile,
@ -94,51 +95,25 @@ def update_movies(send_event=True):
hide_progress(id='movies_progress')
# Remove old movies from DB
removed_movies = list(set(current_movies_db_list) - set(current_movies_radarr))
removed_movies = list(set(current_movies_db) - set(current_movies_radarr))
for removed_movie in removed_movies:
try:
TableMovies.delete().where(TableMovies.tmdbId == removed_movie).execute()
except Exception as e:
logging.error(f"BAZARR cannot remove movie tmdbId {removed_movie} because of {e}")
continue
database.execute(
delete(TableMovies)
.where(TableMovies.tmdbId == removed_movie))
# Update movies in DB
movies_in_db_list = []
movies_in_db = TableMovies.select(TableMovies.radarrId,
TableMovies.title,
TableMovies.path,
TableMovies.tmdbId,
TableMovies.overview,
TableMovies.poster,
TableMovies.fanart,
TableMovies.audio_language,
TableMovies.sceneName,
TableMovies.monitored,
TableMovies.sortTitle,
TableMovies.year,
TableMovies.alternativeTitles,
TableMovies.format,
TableMovies.resolution,
TableMovies.video_codec,
TableMovies.audio_codec,
TableMovies.imdbId,
TableMovies.movie_file_id,
TableMovies.tags,
TableMovies.file_size).dicts()
for item in movies_in_db:
movies_in_db_list.append(item)
movies_to_update_list = [i for i in movies_to_update if i not in movies_in_db_list]
for updated_movie in movies_to_update_list:
try:
TableMovies.update(updated_movie).where(TableMovies.tmdbId == updated_movie['tmdbId']).execute()
except IntegrityError as e:
logging.error(f"BAZARR cannot update movie {updated_movie['path']} because of {e}")
for updated_movie in movies_to_update:
if database.execute(
select(TableMovies)
.filter_by(**updated_movie))\
.first():
continue
else:
database.execute(
update(TableMovies).values(updated_movie)
.where(TableMovies.tmdbId == updated_movie['tmdbId']))
altered_movies.append([updated_movie['tmdbId'],
updated_movie['path'],
updated_movie['radarrId'],
@ -147,21 +122,19 @@ def update_movies(send_event=True):
# Insert new movies in DB
for added_movie in movies_to_add:
try:
result = TableMovies.insert(added_movie).on_conflict_ignore().execute()
database.execute(
insert(TableMovies)
.values(added_movie))
except IntegrityError as e:
logging.error(f"BAZARR cannot insert movie {added_movie['path']} because of {e}")
logging.error(f"BAZARR cannot update movie {added_movie['path']} because of {e}")
continue
else:
if result and result > 0:
altered_movies.append([added_movie['tmdbId'],
added_movie['path'],
added_movie['radarrId'],
added_movie['monitored']])
if send_event:
event_stream(type='movie', action='update', payload=int(added_movie['radarrId']))
else:
logging.debug('BAZARR unable to insert this movie into the database:',
path_mappings.path_replace_movie(added_movie['path']))
altered_movies.append([added_movie['tmdbId'],
added_movie['path'],
added_movie['radarrId'],
added_movie['monitored']])
if send_event:
event_stream(type='movie', action='update', payload=int(added_movie['radarrId']))
# Store subtitles for added or modified movies
for i, altered_movie in enumerate(altered_movies, 1):
@ -174,22 +147,21 @@ def update_one_movie(movie_id, action, defer_search=False):
logging.debug('BAZARR syncing this specific movie from Radarr: {}'.format(movie_id))
# Check if there's a row in database for this movie ID
existing_movie = TableMovies.select(TableMovies.path)\
.where(TableMovies.radarrId == movie_id)\
.dicts()\
.get_or_none()
existing_movie = database.execute(
select(TableMovies.path)
.where(TableMovies.radarrId == movie_id))\
.first()
# Remove movie from DB
if action == 'deleted':
if existing_movie:
try:
TableMovies.delete().where(TableMovies.radarrId == movie_id).execute()
except Exception as e:
logging.error(f"BAZARR cannot delete movie {existing_movie['path']} because of {e}")
else:
event_stream(type='movie', action='delete', payload=int(movie_id))
logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie(
existing_movie['path'])))
database.execute(
delete(TableMovies)
.where(TableMovies.radarrId == movie_id))
event_stream(type='movie', action='delete', payload=int(movie_id))
logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie(
existing_movie.path)))
return
movie_default_enabled = settings.general.getboolean('movie_default_enabled')
@ -228,33 +200,34 @@ def update_one_movie(movie_id, action, defer_search=False):
# Remove movie from DB
if not movie and existing_movie:
try:
TableMovies.delete().where(TableMovies.radarrId == movie_id).execute()
except Exception as e:
logging.error(f"BAZARR cannot insert episode {existing_movie['path']} because of {e}")
else:
event_stream(type='movie', action='delete', payload=int(movie_id))
logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie(
existing_movie['path'])))
return
database.execute(
delete(TableMovies)
.where(TableMovies.radarrId == movie_id))
event_stream(type='movie', action='delete', payload=int(movie_id))
logging.debug('BAZARR deleted this movie from the database:{}'.format(path_mappings.path_replace_movie(
existing_movie.path)))
return
# Update existing movie in DB
elif movie and existing_movie:
try:
TableMovies.update(movie).where(TableMovies.radarrId == movie['radarrId']).execute()
except IntegrityError as e:
logging.error(f"BAZARR cannot insert episode {movie['path']} because of {e}")
else:
event_stream(type='movie', action='update', payload=int(movie_id))
logging.debug('BAZARR updated this movie into the database:{}'.format(path_mappings.path_replace_movie(
movie['path'])))
database.execute(
update(TableMovies)
.values(movie)
.where(TableMovies.radarrId == movie['radarrId']))
event_stream(type='movie', action='update', payload=int(movie_id))
logging.debug('BAZARR updated this movie into the database:{}'.format(path_mappings.path_replace_movie(
movie['path'])))
# Insert new movie in DB
elif movie and not existing_movie:
try:
TableMovies.insert(movie).on_conflict(action='IGNORE').execute()
database.execute(
insert(TableMovies)
.values(movie))
except IntegrityError as e:
logging.error(f"BAZARR cannot insert movie {movie['path']} because of {e}")
logging.error(f"BAZARR cannot update movie {movie['path']} because of {e}")
else:
event_stream(type='movie', action='update', payload=int(movie_id))
logging.debug('BAZARR inserted this movie into the database:{}'.format(path_mappings.path_replace_movie(

View File

@ -2,39 +2,38 @@
from datetime import datetime
from app.database import TableBlacklist
from app.database import TableBlacklist, database, insert, delete, select
from app.event_handler import event_stream
def get_blacklist():
blacklist_db = TableBlacklist.select(TableBlacklist.provider, TableBlacklist.subs_id).dicts()
blacklist_list = []
for item in blacklist_db:
blacklist_list.append((item['provider'], item['subs_id']))
return blacklist_list
return [(item.provider, item.subs_id) for item in
database.execute(
select(TableBlacklist.provider, TableBlacklist.subs_id))
.all()]
def blacklist_log(sonarr_series_id, sonarr_episode_id, provider, subs_id, language):
TableBlacklist.insert({
TableBlacklist.sonarr_series_id: sonarr_series_id,
TableBlacklist.sonarr_episode_id: sonarr_episode_id,
TableBlacklist.timestamp: datetime.now(),
TableBlacklist.provider: provider,
TableBlacklist.subs_id: subs_id,
TableBlacklist.language: language
}).execute()
database.execute(
insert(TableBlacklist)
.values(
sonarr_series_id=sonarr_series_id,
sonarr_episode_id=sonarr_episode_id,
timestamp=datetime.now(),
provider=provider,
subs_id=subs_id,
language=language
))
event_stream(type='episode-blacklist')
def blacklist_delete(provider, subs_id):
TableBlacklist.delete().where((TableBlacklist.provider == provider) and
(TableBlacklist.subs_id == subs_id))\
.execute()
database.execute(
delete(TableBlacklist)
.where((TableBlacklist.provider == provider) and (TableBlacklist.subs_id == subs_id)))
event_stream(type='episode-blacklist', action='delete')
def blacklist_delete_all():
TableBlacklist.delete().execute()
database.execute(delete(TableBlacklist))
event_stream(type='episode-blacklist', action='delete')

View File

@ -2,7 +2,7 @@
from datetime import datetime
from app.database import TableHistory
from app.database import TableHistory, database, insert
from app.event_handler import event_stream
@ -14,18 +14,24 @@ def history_log(action, sonarr_series_id, sonarr_episode_id, result, fake_provid
score = fake_score or result.score
subs_id = result.subs_id
subtitles_path = result.subs_path
matched = result.matched
not_matched = result.not_matched
TableHistory.insert({
TableHistory.action: action,
TableHistory.sonarrSeriesId: sonarr_series_id,
TableHistory.sonarrEpisodeId: sonarr_episode_id,
TableHistory.timestamp: datetime.now(),
TableHistory.description: description,
TableHistory.video_path: video_path,
TableHistory.language: language,
TableHistory.provider: provider,
TableHistory.score: score,
TableHistory.subs_id: subs_id,
TableHistory.subtitles_path: subtitles_path
}).execute()
database.execute(
insert(TableHistory)
.values(
action=action,
sonarrSeriesId=sonarr_series_id,
sonarrEpisodeId=sonarr_episode_id,
timestamp=datetime.now(),
description=description,
video_path=video_path,
language=language,
provider=provider,
score=score,
subs_id=subs_id,
subtitles_path=subtitles_path,
matched=str(matched) if matched else None,
not_matched=str(not_matched) if not_matched else None
))
event_stream(type='episode-history')

View File

@ -5,7 +5,7 @@ import requests
import logging
from app.config import settings
from app.database import TableShowsRootfolder, TableShows
from app.database import TableShowsRootfolder, TableShows, database, insert, update, delete, select
from utilities.path_mappings import path_mappings
from sonarr.info import get_sonarr_info, url_sonarr
from constants import headers
@ -33,53 +33,61 @@ def get_sonarr_rootfolder():
logging.exception("BAZARR Error trying to get rootfolder from Sonarr.")
return []
else:
sonarr_movies_paths = list(TableShows.select(TableShows.path).dicts())
for folder in rootfolder.json():
if any(item['path'].startswith(folder['path']) for item in sonarr_movies_paths):
if any(item.path.startswith(folder['path']) for item in database.execute(
select(TableShows.path))
.all()):
sonarr_rootfolder.append({'id': folder['id'], 'path': folder['path']})
db_rootfolder = TableShowsRootfolder.select(TableShowsRootfolder.id, TableShowsRootfolder.path).dicts()
db_rootfolder = database.execute(
select(TableShowsRootfolder.id, TableShowsRootfolder.path))\
.all()
rootfolder_to_remove = [x for x in db_rootfolder if not
next((item for item in sonarr_rootfolder if item['id'] == x['id']), False)]
next((item for item in sonarr_rootfolder if item['id'] == x.id), False)]
rootfolder_to_update = [x for x in sonarr_rootfolder if
next((item for item in db_rootfolder if item['id'] == x['id']), False)]
next((item for item in db_rootfolder if item.id == x['id']), False)]
rootfolder_to_insert = [x for x in sonarr_rootfolder if not
next((item for item in db_rootfolder if item['id'] == x['id']), False)]
next((item for item in db_rootfolder if item.id == x['id']), False)]
for item in rootfolder_to_remove:
TableShowsRootfolder.delete().where(TableShowsRootfolder.id == item['id']).execute()
database.execute(
delete(TableShowsRootfolder)
.where(TableShowsRootfolder.id == item.id))
for item in rootfolder_to_update:
TableShowsRootfolder.update({TableShowsRootfolder.path: item['path']})\
.where(TableShowsRootfolder.id == item['id'])\
.execute()
database.execute(
update(TableShowsRootfolder)
.values(path=item['path'])
.where(TableShowsRootfolder.id == item['id']))
for item in rootfolder_to_insert:
TableShowsRootfolder.insert({TableShowsRootfolder.id: item['id'], TableShowsRootfolder.path: item['path']})\
.execute()
database.execute(
insert(TableShowsRootfolder)
.values(id=item['id'], path=item['path']))
def check_sonarr_rootfolder():
get_sonarr_rootfolder()
rootfolder = TableShowsRootfolder.select(TableShowsRootfolder.id, TableShowsRootfolder.path).dicts()
rootfolder = database.execute(
select(TableShowsRootfolder.id, TableShowsRootfolder.path))\
.all()
for item in rootfolder:
root_path = item['path']
root_path = item.path
if not root_path.endswith(('/', '\\')):
if root_path.startswith('/'):
root_path += '/'
else:
root_path += '\\'
if not os.path.isdir(path_mappings.path_replace(root_path)):
TableShowsRootfolder.update({TableShowsRootfolder.accessible: 0,
TableShowsRootfolder.error: 'This Sonarr root directory does not seems to '
'be accessible by Please check path '
'mapping.'})\
.where(TableShowsRootfolder.id == item['id'])\
.execute()
database.execute(
update(TableShowsRootfolder)
.values(accessible=0, error='This Sonarr root directory does not seems to be accessible by Bazarr. '
'Please check path mapping.')
.where(TableShowsRootfolder.id == item.id))
elif not os.access(path_mappings.path_replace(root_path), os.W_OK):
TableShowsRootfolder.update({TableShowsRootfolder.accessible: 0,
TableShowsRootfolder.error: 'Bazarr cannot write to this directory.'}) \
.where(TableShowsRootfolder.id == item['id']) \
.execute()
database.execute(
update(TableShowsRootfolder)
.values(accessible=0, error='Bazarr cannot write to this directory.')
.where(TableShowsRootfolder.id == item.id))
else:
TableShowsRootfolder.update({TableShowsRootfolder.accessible: 1,
TableShowsRootfolder.error: ''}) \
.where(TableShowsRootfolder.id == item['id']) \
.execute()
database.execute(
update(TableShowsRootfolder)
.values(accessible=1, error='')
.where(TableShowsRootfolder.id == item.id))

View File

@ -3,18 +3,18 @@
import os
import logging
from peewee import IntegrityError
from sqlalchemy.exc import IntegrityError
from app.database import TableEpisodes
from app.database import database, TableEpisodes, delete, update, insert, select
from app.config import settings
from utilities.path_mappings import path_mappings
from subtitles.indexer.series import store_subtitles, series_full_scan_subtitles
from subtitles.mass_download import episode_download_subtitles
from app.event_handler import event_stream, show_progress, hide_progress
from app.event_handler import event_stream
from sonarr.info import get_sonarr_info, url_sonarr
from .parser import episodeParser
from .utils import get_series_from_sonarr_api, get_episodes_from_sonarr_api, get_episodesFiles_from_sonarr_api
from .utils import get_episodes_from_sonarr_api, get_episodesFiles_from_sonarr_api
def update_all_episodes():
@ -22,147 +22,118 @@ def update_all_episodes():
logging.info('BAZARR All existing episode subtitles indexed from disk.')
def sync_episodes(series_id=None, send_event=True):
def sync_episodes(series_id, send_event=True):
logging.debug('BAZARR Starting episodes sync from Sonarr.')
apikey_sonarr = settings.sonarr.apikey
# Get current episodes id in DB
current_episodes_db = TableEpisodes.select(TableEpisodes.sonarrEpisodeId,
if series_id:
current_episodes_db_list = [row.sonarrEpisodeId for row in
database.execute(
select(TableEpisodes.sonarrEpisodeId,
TableEpisodes.path,
TableEpisodes.sonarrSeriesId)\
.where((TableEpisodes.sonarrSeriesId == series_id) if series_id else None)\
.dicts()
current_episodes_db_list = [x['sonarrEpisodeId'] for x in current_episodes_db]
TableEpisodes.sonarrSeriesId)
.where(TableEpisodes.sonarrSeriesId == series_id)).all()]
else:
return
current_episodes_sonarr = []
episodes_to_update = []
episodes_to_add = []
altered_episodes = []
# Get sonarrId for each series from database
seriesIdList = get_series_from_sonarr_api(url=url_sonarr(), apikey_sonarr=apikey_sonarr, sonarr_series_id=series_id)
series_count = len(seriesIdList)
for i, seriesId in enumerate(seriesIdList):
if send_event:
show_progress(id='episodes_progress',
header='Syncing episodes...',
name=seriesId['title'],
value=i,
count=series_count)
# Get episodes data for a series from Sonarr
episodes = get_episodes_from_sonarr_api(url=url_sonarr(), apikey_sonarr=apikey_sonarr,
series_id=seriesId['id'])
if not episodes:
continue
else:
# For Sonarr v3, we need to update episodes to integrate the episodeFile API endpoint results
if not get_sonarr_info.is_legacy():
episodeFiles = get_episodesFiles_from_sonarr_api(url=url_sonarr(), apikey_sonarr=apikey_sonarr,
series_id=seriesId['id'])
for episode in episodes:
if episodeFiles and episode['hasFile']:
item = [x for x in episodeFiles if x['id'] == episode['episodeFileId']]
if item:
episode['episodeFile'] = item[0]
# Get episodes data for a series from Sonarr
episodes = get_episodes_from_sonarr_api(url=url_sonarr(), apikey_sonarr=apikey_sonarr,
series_id=series_id)
if episodes:
# For Sonarr v3, we need to update episodes to integrate the episodeFile API endpoint results
if not get_sonarr_info.is_legacy():
episodeFiles = get_episodesFiles_from_sonarr_api(url=url_sonarr(), apikey_sonarr=apikey_sonarr,
series_id=series_id)
for episode in episodes:
if 'hasFile' in episode:
if episode['hasFile'] is True:
if 'episodeFile' in episode:
try:
bazarr_file_size = \
os.path.getsize(path_mappings.path_replace(episode['episodeFile']['path']))
except OSError:
bazarr_file_size = 0
if episode['episodeFile']['size'] > 20480 or bazarr_file_size > 20480:
# Add episodes in sonarr to current episode list
current_episodes_sonarr.append(episode['id'])
if episodeFiles and episode['hasFile']:
item = [x for x in episodeFiles if x['id'] == episode['episodeFileId']]
if item:
episode['episodeFile'] = item[0]
# Parse episode data
if episode['id'] in current_episodes_db_list:
episodes_to_update.append(episodeParser(episode))
else:
episodes_to_add.append(episodeParser(episode))
for episode in episodes:
if 'hasFile' in episode:
if episode['hasFile'] is True:
if 'episodeFile' in episode:
try:
bazarr_file_size = \
os.path.getsize(path_mappings.path_replace(episode['episodeFile']['path']))
except OSError:
bazarr_file_size = 0
if episode['episodeFile']['size'] > 20480 or bazarr_file_size > 20480:
# Add episodes in sonarr to current episode list
current_episodes_sonarr.append(episode['id'])
if send_event:
hide_progress(id='episodes_progress')
# Parse episode data
if episode['id'] in current_episodes_db_list:
episodes_to_update.append(episodeParser(episode))
else:
episodes_to_add.append(episodeParser(episode))
# Remove old episodes from DB
removed_episodes = list(set(current_episodes_db_list) - set(current_episodes_sonarr))
stmt = select(TableEpisodes.path,
TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId)
for removed_episode in removed_episodes:
episode_to_delete = TableEpisodes.select(TableEpisodes.path,
TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId)\
.where(TableEpisodes.sonarrEpisodeId == removed_episode)\
.dicts()\
.get_or_none()
episode_to_delete = database.execute(stmt.where(TableEpisodes.sonarrEpisodeId == removed_episode)).first()
if not episode_to_delete:
continue
try:
TableEpisodes.delete().where(TableEpisodes.sonarrEpisodeId == removed_episode).execute()
database.execute(
delete(TableEpisodes)
.where(TableEpisodes.sonarrEpisodeId == removed_episode))
except Exception as e:
logging.error(f"BAZARR cannot delete episode {episode_to_delete['path']} because of {e}")
logging.error(f"BAZARR cannot delete episode {episode_to_delete.path} because of {e}")
continue
else:
if send_event:
event_stream(type='episode', action='delete', payload=episode_to_delete['sonarrEpisodeId'])
event_stream(type='episode', action='delete', payload=episode_to_delete.sonarrEpisodeId)
# Update existing episodes in DB
episode_in_db_list = []
episodes_in_db = TableEpisodes.select(TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.title,
TableEpisodes.path,
TableEpisodes.season,
TableEpisodes.episode,
TableEpisodes.sceneName,
TableEpisodes.monitored,
TableEpisodes.format,
TableEpisodes.resolution,
TableEpisodes.video_codec,
TableEpisodes.audio_codec,
TableEpisodes.episode_file_id,
TableEpisodes.audio_language,
TableEpisodes.file_size).dicts()
for item in episodes_in_db:
episode_in_db_list.append(item)
episodes_to_update_list = [i for i in episodes_to_update if i not in episode_in_db_list]
for updated_episode in episodes_to_update_list:
try:
TableEpisodes.update(updated_episode).where(TableEpisodes.sonarrEpisodeId ==
updated_episode['sonarrEpisodeId']).execute()
except IntegrityError as e:
logging.error(f"BAZARR cannot update episode {updated_episode['path']} because of {e}")
for updated_episode in episodes_to_update:
if database.execute(
select(TableEpisodes)
.filter_by(**updated_episode))\
.first():
continue
else:
altered_episodes.append([updated_episode['sonarrEpisodeId'],
updated_episode['path'],
updated_episode['sonarrSeriesId']])
try:
database.execute(
update(TableEpisodes)
.values(updated_episode)
.where(TableEpisodes.sonarrEpisodeId == updated_episode['sonarrEpisodeId']))
except IntegrityError as e:
logging.error(f"BAZARR cannot update episode {updated_episode['path']} because of {e}")
continue
else:
altered_episodes.append([updated_episode['sonarrEpisodeId'],
updated_episode['path'],
updated_episode['sonarrSeriesId']])
if send_event:
event_stream(type='episode', action='update', payload=updated_episode['sonarrEpisodeId'])
# Insert new episodes in DB
for added_episode in episodes_to_add:
try:
result = TableEpisodes.insert(added_episode).on_conflict_ignore().execute()
database.execute(
insert(TableEpisodes)
.values(added_episode))
except IntegrityError as e:
logging.error(f"BAZARR cannot insert episode {added_episode['path']} because of {e}")
continue
else:
if result and result > 0:
altered_episodes.append([added_episode['sonarrEpisodeId'],
added_episode['path'],
added_episode['monitored']])
if send_event:
event_stream(type='episode', payload=added_episode['sonarrEpisodeId'])
else:
logging.debug('BAZARR unable to insert this episode into the database:{}'.format(
path_mappings.path_replace(added_episode['path'])))
altered_episodes.append([added_episode['sonarrEpisodeId'],
added_episode['path'],
added_episode['monitored']])
if send_event:
event_stream(type='episode', payload=added_episode['sonarrEpisodeId'])
# Store subtitles for added or modified episodes
for i, altered_episode in enumerate(altered_episodes, 1):
@ -177,10 +148,10 @@ def sync_one_episode(episode_id, defer_search=False):
apikey_sonarr = settings.sonarr.apikey
# Check if there's a row in database for this episode ID
existing_episode = TableEpisodes.select(TableEpisodes.path, TableEpisodes.episode_file_id)\
.where(TableEpisodes.sonarrEpisodeId == episode_id)\
.dicts()\
.get_or_none()
existing_episode = database.execute(
select(TableEpisodes.path, TableEpisodes.episode_file_id)
.where(TableEpisodes.sonarrEpisodeId == episode_id)) \
.first()
try:
# Get episode data from sonarr api
@ -207,20 +178,22 @@ def sync_one_episode(episode_id, defer_search=False):
# Remove episode from DB
if not episode and existing_episode:
try:
TableEpisodes.delete().where(TableEpisodes.sonarrEpisodeId == episode_id).execute()
except Exception as e:
logging.error(f"BAZARR cannot delete episode {existing_episode['path']} because of {e}")
else:
event_stream(type='episode', action='delete', payload=int(episode_id))
logging.debug('BAZARR deleted this episode from the database:{}'.format(path_mappings.path_replace(
existing_episode['path'])))
return
database.execute(
delete(TableEpisodes)
.where(TableEpisodes.sonarrEpisodeId == episode_id))
event_stream(type='episode', action='delete', payload=int(episode_id))
logging.debug('BAZARR deleted this episode from the database:{}'.format(path_mappings.path_replace(
existing_episode['path'])))
return
# Update existing episodes in DB
elif episode and existing_episode:
try:
TableEpisodes.update(episode).where(TableEpisodes.sonarrEpisodeId == episode_id).execute()
database.execute(
update(TableEpisodes)
.values(episode)
.where(TableEpisodes.sonarrEpisodeId == episode_id))
except IntegrityError as e:
logging.error(f"BAZARR cannot update episode {episode['path']} because of {e}")
else:
@ -231,7 +204,9 @@ def sync_one_episode(episode_id, defer_search=False):
# Insert new episodes in DB
elif episode and not existing_episode:
try:
TableEpisodes.insert(episode).on_conflict(action='IGNORE').execute()
database.execute(
insert(TableEpisodes)
.values(episode))
except IntegrityError as e:
logging.error(f"BAZARR cannot insert episode {episode['path']} because of {e}")
else:

View File

@ -3,7 +3,7 @@
import os
from app.config import settings
from app.database import TableShows
from app.database import TableShows, database, select
from utilities.path_mappings import path_mappings
from utilities.video_analyzer import embedded_audio_reader
from sonarr.info import get_sonarr_info
@ -118,8 +118,10 @@ def episodeParser(episode):
if 'name' in item:
audio_language.append(item['name'])
else:
audio_language = TableShows.get(
TableShows.sonarrSeriesId == episode['seriesId']).audio_language
audio_language = database.execute(
select(TableShows.audio_language)
.where(TableShows.sonarrSeriesId == episode['seriesId']))\
.first().audio_language
if 'mediaInfo' in episode['episodeFile']:
if 'videoCodec' in episode['episodeFile']['mediaInfo']:

View File

@ -2,13 +2,13 @@
import logging
from peewee import IntegrityError
from sqlalchemy.exc import IntegrityError
from app.config import settings
from sonarr.info import url_sonarr
from subtitles.indexer.series import list_missing_subtitles
from sonarr.rootfolder import check_sonarr_rootfolder
from app.database import TableShows, TableEpisodes
from app.database import TableShows, database, insert, update, delete, select
from utilities.path_mappings import path_mappings
from app.event_handler import event_stream, show_progress, hide_progress
@ -41,12 +41,11 @@ def update_series(send_event=True):
return
else:
# Get current shows in DB
current_shows_db = TableShows.select(TableShows.sonarrSeriesId).dicts()
current_shows_db_list = [x['sonarrSeriesId'] for x in current_shows_db]
current_shows_db = [x.sonarrSeriesId for x in
database.execute(
select(TableShows.sonarrSeriesId))
.all()]
current_shows_sonarr = []
series_to_update = []
series_to_add = []
series_count = len(series)
for i, show in enumerate(series):
@ -60,82 +59,60 @@ def update_series(send_event=True):
# Add shows in Sonarr to current shows list
current_shows_sonarr.append(show['id'])
if show['id'] in current_shows_db_list:
series_to_update.append(seriesParser(show, action='update', tags_dict=tagsDict,
serie_default_profile=serie_default_profile,
audio_profiles=audio_profiles))
if show['id'] in current_shows_db:
updated_series = seriesParser(show, action='update', tags_dict=tagsDict,
serie_default_profile=serie_default_profile,
audio_profiles=audio_profiles)
if not database.execute(
select(TableShows)
.filter_by(**updated_series))\
.first():
try:
database.execute(
update(TableShows)
.values(updated_series)
.where(TableShows.sonarrSeriesId == show['id']))
except IntegrityError as e:
logging.error(f"BAZARR cannot update series {updated_series['path']} because of {e}")
continue
if send_event:
event_stream(type='series', payload=show['id'])
else:
series_to_add.append(seriesParser(show, action='insert', tags_dict=tagsDict,
serie_default_profile=serie_default_profile,
audio_profiles=audio_profiles))
added_series = seriesParser(show, action='insert', tags_dict=tagsDict,
serie_default_profile=serie_default_profile,
audio_profiles=audio_profiles)
try:
database.execute(
insert(TableShows)
.values(added_series))
except IntegrityError as e:
logging.error(f"BAZARR cannot insert series {added_series['path']} because of {e}")
continue
else:
list_missing_subtitles(no=show['id'])
if send_event:
event_stream(type='series', action='update', payload=show['id'])
sync_episodes(series_id=show['id'], send_event=send_event)
# Remove old series from DB
removed_series = list(set(current_shows_db) - set(current_shows_sonarr))
for series in removed_series:
database.execute(
delete(TableShows)
.where(TableShows.sonarrSeriesId == series))
if send_event:
event_stream(type='series', action='delete', payload=series)
if send_event:
hide_progress(id='series_progress')
# Remove old series from DB
removed_series = list(set(current_shows_db_list) - set(current_shows_sonarr))
for series in removed_series:
try:
TableShows.delete().where(TableShows.sonarrSeriesId == series).execute()
except Exception as e:
logging.error(f"BAZARR cannot delete series with sonarrSeriesId {series} because of {e}")
continue
else:
if send_event:
event_stream(type='series', action='delete', payload=series)
# Update existing series in DB
series_in_db_list = []
series_in_db = TableShows.select(TableShows.title,
TableShows.path,
TableShows.tvdbId,
TableShows.sonarrSeriesId,
TableShows.overview,
TableShows.poster,
TableShows.fanart,
TableShows.audio_language,
TableShows.sortTitle,
TableShows.year,
TableShows.alternativeTitles,
TableShows.tags,
TableShows.seriesType,
TableShows.imdbId,
TableShows.monitored).dicts()
for item in series_in_db:
series_in_db_list.append(item)
series_to_update_list = [i for i in series_to_update if i not in series_in_db_list]
for updated_series in series_to_update_list:
try:
TableShows.update(updated_series).where(TableShows.sonarrSeriesId ==
updated_series['sonarrSeriesId']).execute()
except IntegrityError as e:
logging.error(f"BAZARR cannot update series {updated_series['path']} because of {e}")
continue
else:
if send_event:
event_stream(type='series', payload=updated_series['sonarrSeriesId'])
# Insert new series in DB
for added_series in series_to_add:
try:
result = TableShows.insert(added_series).on_conflict(action='IGNORE').execute()
except IntegrityError as e:
logging.error(f"BAZARR cannot insert series {added_series['path']} because of {e}")
continue
else:
if result:
list_missing_subtitles(no=added_series['sonarrSeriesId'])
else:
logging.debug('BAZARR unable to insert this series into the database:',
path_mappings.path_replace(added_series['path']))
if send_event:
event_stream(type='series', action='update', payload=added_series['sonarrSeriesId'])
logging.debug('BAZARR All series synced from Sonarr into database.')
@ -143,21 +120,19 @@ def update_one_series(series_id, action):
logging.debug('BAZARR syncing this specific series from Sonarr: {}'.format(series_id))
# Check if there's a row in database for this series ID
existing_series = TableShows.select(TableShows.path)\
.where(TableShows.sonarrSeriesId == series_id)\
.dicts()\
.get_or_none()
existing_series = database.execute(
select(TableShows)
.where(TableShows.sonarrSeriesId == series_id))\
.first()
# Delete series from DB
if action == 'deleted' and existing_series:
try:
TableShows.delete().where(TableShows.sonarrSeriesId == int(series_id)).execute()
except Exception as e:
logging.error(f"BAZARR cannot delete series with sonarrSeriesId {series_id} because of {e}")
else:
TableEpisodes.delete().where(TableEpisodes.sonarrSeriesId == int(series_id)).execute()
event_stream(type='series', action='delete', payload=int(series_id))
return
database.execute(
delete(TableShows)
.where(TableShows.sonarrSeriesId == int(series_id)))
event_stream(type='series', action='delete', payload=int(series_id))
return
serie_default_enabled = settings.general.getboolean('serie_default_enabled')
@ -196,7 +171,10 @@ def update_one_series(series_id, action):
# Update existing series in DB
if action == 'updated' and existing_series:
try:
TableShows.update(series).where(TableShows.sonarrSeriesId == series['sonarrSeriesId']).execute()
database.execute(
update(TableShows)
.values(series)
.where(TableShows.sonarrSeriesId == series['sonarrSeriesId']))
except IntegrityError as e:
logging.error(f"BAZARR cannot update series {series['path']} because of {e}")
else:
@ -208,7 +186,9 @@ def update_one_series(series_id, action):
# Insert new series in DB
elif action == 'updated' and not existing_series:
try:
TableShows.insert(series).on_conflict(action='IGNORE').execute()
database.execute(
insert(TableShows)
.values(series))
except IntegrityError as e:
logging.error(f"BAZARR cannot insert series {series['path']} because of {e}")
else:

View File

@ -13,7 +13,7 @@ from subliminal_patch.core_persistent import download_best_subtitles
from subliminal_patch.score import ComputeScore
from app.config import settings, get_array_from, get_scores
from app.database import TableEpisodes, TableMovies
from app.database import TableEpisodes, TableMovies, database, select
from utilities.path_mappings import path_mappings
from utilities.helper import get_target_folder, force_unicode
from languages.get_languages import alpha3_from_alpha2
@ -163,15 +163,15 @@ def parse_language_object(language):
def check_missing_languages(path, media_type):
# confirm if language is still missing or if cutoff has been reached
if media_type == 'series':
confirmed_missing_subs = TableEpisodes.select(TableEpisodes.missing_subtitles) \
.where(TableEpisodes.path == path_mappings.path_replace_reverse(path)) \
.dicts() \
.get_or_none()
confirmed_missing_subs = database.execute(
select(TableEpisodes.missing_subtitles)
.where(TableEpisodes.path == path_mappings.path_replace_reverse(path)))\
.first()
else:
confirmed_missing_subs = TableMovies.select(TableMovies.missing_subtitles) \
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path)) \
.dicts() \
.get_or_none()
confirmed_missing_subs = database.execute(
select(TableMovies.missing_subtitles)
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path)))\
.first()
if not confirmed_missing_subs:
reversed_path = path_mappings.path_replace_reverse(path) if media_type == 'series' else \
@ -180,7 +180,7 @@ def check_missing_languages(path, media_type):
return []
languages = []
for language in ast.literal_eval(confirmed_missing_subs['missing_subtitles']):
for language in ast.literal_eval(confirmed_missing_subs.missing_subtitles):
if language is not None:
hi_ = "True" if language.endswith(':hi') else "False"
forced_ = "True" if language.endswith(':forced') else "False"

View File

@ -8,7 +8,8 @@ import ast
from subliminal_patch import core, search_external_subtitles
from languages.custom_lang import CustomLanguage
from app.database import get_profiles_list, get_profile_cutoff, TableMovies, get_audio_profile_languages
from app.database import get_profiles_list, get_profile_cutoff, TableMovies, get_audio_profile_languages, database, \
update, select
from languages.get_languages import alpha2_from_alpha3, get_language_set
from app.config import settings
from utilities.helper import get_subtitle_destination_folder
@ -26,17 +27,17 @@ def store_subtitles_movie(original_path, reversed_path, use_cache=True):
if os.path.exists(reversed_path):
if settings.general.getboolean('use_embedded_subs'):
logging.debug("BAZARR is trying to index embedded subtitles.")
item = TableMovies.select(TableMovies.movie_file_id, TableMovies.file_size)\
.where(TableMovies.path == original_path)\
.dicts()\
.get_or_none()
item = database.execute(
select(TableMovies.movie_file_id, TableMovies.file_size)
.where(TableMovies.path == original_path)) \
.first()
if not item:
logging.exception(f"BAZARR error when trying to select this movie from database: {reversed_path}")
else:
try:
subtitle_languages = embedded_subs_reader(reversed_path,
file_size=item['file_size'],
movie_file_id=item['movie_file_id'],
file_size=item.file_size,
movie_file_id=item.movie_file_id,
use_cache=use_cache)
for subtitle_language, subtitle_forced, subtitle_hi, subtitle_codec in subtitle_languages:
try:
@ -56,35 +57,35 @@ def store_subtitles_movie(original_path, reversed_path, use_cache=True):
lang = lang + ':hi'
logging.debug("BAZARR embedded subtitles detected: " + lang)
actual_subtitles.append([lang, None, None])
except Exception:
logging.debug("BAZARR unable to index this unrecognized language: " + subtitle_language)
pass
except Exception as error:
logging.debug("BAZARR unable to index this unrecognized language: %s (%s)",
subtitle_language, error)
except Exception:
logging.exception(
"BAZARR error when trying to analyze this %s file: %s" % (os.path.splitext(reversed_path)[1],
reversed_path))
pass
try:
dest_folder = get_subtitle_destination_folder() or ''
dest_folder = get_subtitle_destination_folder()
core.CUSTOM_PATHS = [dest_folder] if dest_folder else []
# get previously indexed subtitles that haven't changed:
item = TableMovies.select(TableMovies.subtitles) \
.where(TableMovies.path == original_path) \
.dicts() \
.get_or_none()
item = database.execute(
select(TableMovies.subtitles)
.where(TableMovies.path == original_path))\
.first()
if not item:
previously_indexed_subtitles_to_exclude = []
else:
previously_indexed_subtitles = ast.literal_eval(item['subtitles']) if item['subtitles'] else []
previously_indexed_subtitles = ast.literal_eval(item.subtitles) if item.subtitles else []
previously_indexed_subtitles_to_exclude = [x for x in previously_indexed_subtitles
if len(x) == 3 and
x[1] and
os.path.isfile(path_mappings.path_replace(x[1])) and
os.stat(path_mappings.path_replace(x[1])).st_size == x[2]]
subtitles = search_external_subtitles(reversed_path, languages=get_language_set())
subtitles = search_external_subtitles(reversed_path, languages=get_language_set(),
only_one=settings.general.getboolean('single_language'))
full_dest_folder_path = os.path.dirname(reversed_path)
if dest_folder:
if settings.general.subfolder == "absolute":
@ -95,7 +96,6 @@ def store_subtitles_movie(original_path, reversed_path, use_cache=True):
previously_indexed_subtitles_to_exclude)
except Exception:
logging.exception("BAZARR unable to index external subtitles.")
pass
else:
for subtitle, language in subtitles.items():
valid_language = False
@ -127,15 +127,19 @@ def store_subtitles_movie(original_path, reversed_path, use_cache=True):
actual_subtitles.append([language_str, path_mappings.path_replace_reverse_movie(subtitle_path),
os.stat(subtitle_path).st_size])
TableMovies.update({TableMovies.subtitles: str(actual_subtitles)})\
.where(TableMovies.path == original_path)\
.execute()
matching_movies = TableMovies.select(TableMovies.radarrId).where(TableMovies.path == original_path).dicts()
database.execute(
update(TableMovies)
.values(subtitles=str(actual_subtitles))
.where(TableMovies.path == original_path))
matching_movies = database.execute(
select(TableMovies.radarrId)
.where(TableMovies.path == original_path))\
.all()
for movie in matching_movies:
if movie:
logging.debug("BAZARR storing those languages to DB: " + str(actual_subtitles))
list_missing_subtitles_movies(no=movie['radarrId'])
list_missing_subtitles_movies(no=movie.radarrId)
else:
logging.debug("BAZARR haven't been able to update existing subtitles to DB : " + str(actual_subtitles))
else:
@ -147,39 +151,45 @@ def store_subtitles_movie(original_path, reversed_path, use_cache=True):
def list_missing_subtitles_movies(no=None, send_event=True):
movies_subtitles = TableMovies.select(TableMovies.radarrId,
TableMovies.subtitles,
TableMovies.profileId,
TableMovies.audio_language)\
.where((TableMovies.radarrId == no) if no else None)\
.dicts()
if isinstance(movies_subtitles, str):
logging.error("BAZARR list missing subtitles query to DB returned this instead of rows: " + movies_subtitles)
return
if no:
movies_subtitles = database.execute(
select(TableMovies.radarrId,
TableMovies.subtitles,
TableMovies.profileId,
TableMovies.audio_language)
.where(TableMovies.radarrId == no)) \
.all()
else:
movies_subtitles = database.execute(
select(TableMovies.radarrId,
TableMovies.subtitles,
TableMovies.profileId,
TableMovies.audio_language)) \
.all()
use_embedded_subs = settings.general.getboolean('use_embedded_subs')
for movie_subtitles in movies_subtitles:
missing_subtitles_text = '[]'
if movie_subtitles['profileId']:
if movie_subtitles.profileId:
# get desired subtitles
desired_subtitles_temp = get_profiles_list(profile_id=movie_subtitles['profileId'])
desired_subtitles_temp = get_profiles_list(profile_id=movie_subtitles.profileId)
desired_subtitles_list = []
if desired_subtitles_temp:
for language in desired_subtitles_temp['items']:
if language['audio_exclude'] == "True":
if any(x['code2'] == language['language'] for x in get_audio_profile_languages(
movie_subtitles['audio_language'])):
movie_subtitles.audio_language)):
continue
desired_subtitles_list.append([language['language'], language['forced'], language['hi']])
# get existing subtitles
actual_subtitles_list = []
if movie_subtitles['subtitles'] is not None:
if movie_subtitles.subtitles is not None:
if use_embedded_subs:
actual_subtitles_temp = ast.literal_eval(movie_subtitles['subtitles'])
actual_subtitles_temp = ast.literal_eval(movie_subtitles.subtitles)
else:
actual_subtitles_temp = [x for x in ast.literal_eval(movie_subtitles['subtitles']) if x[1]]
actual_subtitles_temp = [x for x in ast.literal_eval(movie_subtitles.subtitles) if x[1]]
for subtitles in actual_subtitles_temp:
subtitles = subtitles[0].split(':')
@ -197,14 +207,14 @@ def list_missing_subtitles_movies(no=None, send_event=True):
# check if cutoff is reached and skip any further check
cutoff_met = False
cutoff_temp_list = get_profile_cutoff(profile_id=movie_subtitles['profileId'])
cutoff_temp_list = get_profile_cutoff(profile_id=movie_subtitles.profileId)
if cutoff_temp_list:
for cutoff_temp in cutoff_temp_list:
cutoff_language = [cutoff_temp['language'], cutoff_temp['forced'], cutoff_temp['hi']]
if cutoff_temp['audio_exclude'] == 'True' and \
any(x['code2'] == cutoff_temp['language'] for x in
get_audio_profile_languages(movie_subtitles['audio_language'])):
get_audio_profile_languages(movie_subtitles.audio_language)):
cutoff_met = True
elif cutoff_language in actual_subtitles_list:
cutoff_met = True
@ -241,19 +251,22 @@ def list_missing_subtitles_movies(no=None, send_event=True):
missing_subtitles_text = str(missing_subtitles_output_list)
TableMovies.update({TableMovies.missing_subtitles: missing_subtitles_text})\
.where(TableMovies.radarrId == movie_subtitles['radarrId'])\
.execute()
database.execute(
update(TableMovies)
.values(missing_subtitles=missing_subtitles_text)
.where(TableMovies.radarrId == movie_subtitles.radarrId))
if send_event:
event_stream(type='movie', payload=movie_subtitles['radarrId'])
event_stream(type='movie-wanted', action='update', payload=movie_subtitles['radarrId'])
event_stream(type='movie', payload=movie_subtitles.radarrId)
event_stream(type='movie-wanted', action='update', payload=movie_subtitles.radarrId)
if send_event:
event_stream(type='badges')
def movies_full_scan_subtitles(use_cache=settings.radarr.getboolean('use_ffprobe_cache')):
movies = TableMovies.select(TableMovies.path).dicts()
movies = database.execute(
select(TableMovies.path))\
.all()
count_movies = len(movies)
for i, movie in enumerate(movies):
@ -262,7 +275,7 @@ def movies_full_scan_subtitles(use_cache=settings.radarr.getboolean('use_ffprobe
name='Movies subtitles',
value=i,
count=count_movies)
store_subtitles_movie(movie['path'], path_mappings.path_replace_movie(movie['path']), use_cache=use_cache)
store_subtitles_movie(movie.path, path_mappings.path_replace_movie(movie.path), use_cache=use_cache)
hide_progress(id='movies_disk_scan')
@ -270,10 +283,11 @@ def movies_full_scan_subtitles(use_cache=settings.radarr.getboolean('use_ffprobe
def movies_scan_subtitles(no):
movies = TableMovies.select(TableMovies.path)\
.where(TableMovies.radarrId == no)\
.order_by(TableMovies.radarrId)\
.dicts()
movies = database.execute(
select(TableMovies.path)
.where(TableMovies.radarrId == no)
.order_by(TableMovies.radarrId)) \
.all()
for movie in movies:
store_subtitles_movie(movie['path'], path_mappings.path_replace_movie(movie['path']), use_cache=False)
store_subtitles_movie(movie.path, path_mappings.path_replace_movie(movie.path), use_cache=False)

View File

@ -8,7 +8,8 @@ import ast
from subliminal_patch import core, search_external_subtitles
from languages.custom_lang import CustomLanguage
from app.database import get_profiles_list, get_profile_cutoff, TableEpisodes, TableShows, get_audio_profile_languages
from app.database import get_profiles_list, get_profile_cutoff, TableEpisodes, TableShows, \
get_audio_profile_languages, database, update, select
from languages.get_languages import alpha2_from_alpha3, get_language_set
from app.config import settings
from utilities.helper import get_subtitle_destination_folder
@ -26,17 +27,17 @@ def store_subtitles(original_path, reversed_path, use_cache=True):
if os.path.exists(reversed_path):
if settings.general.getboolean('use_embedded_subs'):
logging.debug("BAZARR is trying to index embedded subtitles.")
item = TableEpisodes.select(TableEpisodes.episode_file_id, TableEpisodes.file_size)\
.where(TableEpisodes.path == original_path)\
.dicts()\
.get_or_none()
item = database.execute(
select(TableEpisodes.episode_file_id, TableEpisodes.file_size)
.where(TableEpisodes.path == original_path))\
.first()
if not item:
logging.exception(f"BAZARR error when trying to select this episode from database: {reversed_path}")
else:
try:
subtitle_languages = embedded_subs_reader(reversed_path,
file_size=item['file_size'],
episode_file_id=item['episode_file_id'],
file_size=item.file_size,
episode_file_id=item.episode_file_id,
use_cache=use_cache)
for subtitle_language, subtitle_forced, subtitle_hi, subtitle_codec in subtitle_languages:
try:
@ -68,14 +69,14 @@ def store_subtitles(original_path, reversed_path, use_cache=True):
core.CUSTOM_PATHS = [dest_folder] if dest_folder else []
# get previously indexed subtitles that haven't changed:
item = TableEpisodes.select(TableEpisodes.subtitles) \
.where(TableEpisodes.path == original_path) \
.dicts() \
.get_or_none()
item = database.execute(
select(TableEpisodes.subtitles)
.where(TableEpisodes.path == original_path)) \
.first()
if not item:
previously_indexed_subtitles_to_exclude = []
else:
previously_indexed_subtitles = ast.literal_eval(item['subtitles']) if item['subtitles'] else []
previously_indexed_subtitles = ast.literal_eval(item.subtitles) if item.subtitles else []
previously_indexed_subtitles_to_exclude = [x for x in previously_indexed_subtitles
if len(x) == 3 and
x[1] and
@ -114,7 +115,7 @@ def store_subtitles(original_path, reversed_path, use_cache=True):
if custom is not None:
actual_subtitles.append([custom, path_mappings.path_replace_reverse(subtitle_path)])
elif str(language) != 'und':
elif str(language.basename) != 'und':
if language.forced:
language_str = str(language)
elif language.hi:
@ -125,17 +126,19 @@ def store_subtitles(original_path, reversed_path, use_cache=True):
actual_subtitles.append([language_str, path_mappings.path_replace_reverse(subtitle_path),
os.stat(subtitle_path).st_size])
TableEpisodes.update({TableEpisodes.subtitles: str(actual_subtitles)})\
.where(TableEpisodes.path == original_path)\
.execute()
matching_episodes = TableEpisodes.select(TableEpisodes.sonarrEpisodeId, TableEpisodes.sonarrSeriesId)\
.where(TableEpisodes.path == original_path)\
.dicts()
database.execute(
update(TableEpisodes)
.values(subtitles=str(actual_subtitles))
.where(TableEpisodes.path == original_path))
matching_episodes = database.execute(
select(TableEpisodes.sonarrEpisodeId, TableEpisodes.sonarrSeriesId)
.where(TableEpisodes.path == original_path))\
.all()
for episode in matching_episodes:
if episode:
logging.debug("BAZARR storing those languages to DB: " + str(actual_subtitles))
list_missing_subtitles(epno=episode['sonarrEpisodeId'])
list_missing_subtitles(epno=episode.sonarrEpisodeId)
else:
logging.debug("BAZARR haven't been able to update existing subtitles to DB : " + str(actual_subtitles))
else:
@ -153,41 +156,40 @@ def list_missing_subtitles(no=None, epno=None, send_event=True):
episodes_subtitles_clause = (TableEpisodes.sonarrSeriesId == no)
else:
episodes_subtitles_clause = None
episodes_subtitles = TableEpisodes.select(TableShows.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.subtitles,
TableShows.profileId,
TableEpisodes.audio_language)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(episodes_subtitles_clause)\
.dicts()
if isinstance(episodes_subtitles, str):
logging.error("BAZARR list missing subtitles query to DB returned this instead of rows: " + episodes_subtitles)
return
episodes_subtitles = database.execute(
select(TableShows.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.subtitles,
TableShows.profileId,
TableEpisodes.audio_language)
.select_from(TableEpisodes)
.join(TableShows)
.where(episodes_subtitles_clause))\
.all()
use_embedded_subs = settings.general.getboolean('use_embedded_subs')
for episode_subtitles in episodes_subtitles:
missing_subtitles_text = '[]'
if episode_subtitles['profileId']:
if episode_subtitles.profileId:
# get desired subtitles
desired_subtitles_temp = get_profiles_list(profile_id=episode_subtitles['profileId'])
desired_subtitles_temp = get_profiles_list(profile_id=episode_subtitles.profileId)
desired_subtitles_list = []
if desired_subtitles_temp:
for language in desired_subtitles_temp['items']:
if language['audio_exclude'] == "True":
if any(x['code2'] == language['language'] for x in get_audio_profile_languages(
episode_subtitles['audio_language'])):
episode_subtitles.audio_language)):
continue
desired_subtitles_list.append([language['language'], language['forced'], language['hi']])
# get existing subtitles
actual_subtitles_list = []
if episode_subtitles['subtitles'] is not None:
if episode_subtitles.subtitles is not None:
if use_embedded_subs:
actual_subtitles_temp = ast.literal_eval(episode_subtitles['subtitles'])
actual_subtitles_temp = ast.literal_eval(episode_subtitles.subtitles)
else:
actual_subtitles_temp = [x for x in ast.literal_eval(episode_subtitles['subtitles']) if x[1]]
actual_subtitles_temp = [x for x in ast.literal_eval(episode_subtitles.subtitles) if x[1]]
for subtitles in actual_subtitles_temp:
subtitles = subtitles[0].split(':')
@ -205,14 +207,14 @@ def list_missing_subtitles(no=None, epno=None, send_event=True):
# check if cutoff is reached and skip any further check
cutoff_met = False
cutoff_temp_list = get_profile_cutoff(profile_id=episode_subtitles['profileId'])
cutoff_temp_list = get_profile_cutoff(profile_id=episode_subtitles.profileId)
if cutoff_temp_list:
for cutoff_temp in cutoff_temp_list:
cutoff_language = [cutoff_temp['language'], cutoff_temp['forced'], cutoff_temp['hi']]
if cutoff_temp['audio_exclude'] == 'True' and \
any(x['code2'] == cutoff_temp['language'] for x in
get_audio_profile_languages(episode_subtitles['audio_language'])):
get_audio_profile_languages(episode_subtitles.audio_language)):
cutoff_met = True
elif cutoff_language in actual_subtitles_list:
cutoff_met = True
@ -251,19 +253,22 @@ def list_missing_subtitles(no=None, epno=None, send_event=True):
missing_subtitles_text = str(missing_subtitles_output_list)
TableEpisodes.update({TableEpisodes.missing_subtitles: missing_subtitles_text})\
.where(TableEpisodes.sonarrEpisodeId == episode_subtitles['sonarrEpisodeId'])\
.execute()
database.execute(
update(TableEpisodes)
.values(missing_subtitles=missing_subtitles_text)
.where(TableEpisodes.sonarrEpisodeId == episode_subtitles.sonarrEpisodeId))
if send_event:
event_stream(type='episode', payload=episode_subtitles['sonarrEpisodeId'])
event_stream(type='episode-wanted', action='update', payload=episode_subtitles['sonarrEpisodeId'])
event_stream(type='episode', payload=episode_subtitles.sonarrEpisodeId)
event_stream(type='episode-wanted', action='update', payload=episode_subtitles.sonarrEpisodeId)
if send_event:
event_stream(type='badges')
def series_full_scan_subtitles(use_cache=settings.sonarr.getboolean('use_ffprobe_cache')):
episodes = TableEpisodes.select(TableEpisodes.path).dicts()
episodes = database.execute(
select(TableEpisodes.path))\
.all()
count_episodes = len(episodes)
for i, episode in enumerate(episodes):
@ -272,7 +277,7 @@ def series_full_scan_subtitles(use_cache=settings.sonarr.getboolean('use_ffprobe
name='Episodes subtitles',
value=i,
count=count_episodes)
store_subtitles(episode['path'], path_mappings.path_replace(episode['path']), use_cache=use_cache)
store_subtitles(episode.path, path_mappings.path_replace(episode.path), use_cache=use_cache)
hide_progress(id='episodes_disk_scan')
@ -280,10 +285,11 @@ def series_full_scan_subtitles(use_cache=settings.sonarr.getboolean('use_ffprobe
def series_scan_subtitles(no):
episodes = TableEpisodes.select(TableEpisodes.path)\
.where(TableEpisodes.sonarrSeriesId == no)\
.order_by(TableEpisodes.sonarrEpisodeId)\
.dicts()
episodes = database.execute(
select(TableEpisodes.path)
.where(TableEpisodes.sonarrSeriesId == no)
.order_by(TableEpisodes.sonarrEpisodeId))\
.all()
for episode in episodes:
store_subtitles(episode['path'], path_mappings.path_replace(episode['path']), use_cache=False)
store_subtitles(episode.path, path_mappings.path_replace(episode.path), use_cache=False)

View File

@ -12,7 +12,7 @@ from subtitles.indexer.movies import store_subtitles_movie
from radarr.history import history_log_movie
from app.notifier import send_notifications_movie
from app.get_providers import get_providers
from app.database import get_exclusion_clause, get_audio_profile_languages, TableMovies
from app.database import get_exclusion_clause, get_audio_profile_languages, TableMovies, database, select
from app.event_handler import show_progress, hide_progress
from ..download import generate_subtitles
@ -21,28 +21,27 @@ from ..download import generate_subtitles
def movies_download_subtitles(no):
conditions = [(TableMovies.radarrId == no)]
conditions += get_exclusion_clause('movie')
movies = TableMovies.select(TableMovies.path,
TableMovies.missing_subtitles,
TableMovies.audio_language,
TableMovies.radarrId,
TableMovies.sceneName,
TableMovies.title,
TableMovies.tags,
TableMovies.monitored)\
.where(reduce(operator.and_, conditions))\
.dicts()
if not len(movies):
movie = database.execute(
select(TableMovies.path,
TableMovies.missing_subtitles,
TableMovies.audio_language,
TableMovies.radarrId,
TableMovies.sceneName,
TableMovies.title,
TableMovies.tags,
TableMovies.monitored)
.where(reduce(operator.and_, conditions))) \
.first()
if not len(movie):
logging.debug("BAZARR no movie with that radarrId can be found in database:", str(no))
return
else:
movie = movies[0]
if ast.literal_eval(movie['missing_subtitles']):
count_movie = len(ast.literal_eval(movie['missing_subtitles']))
if ast.literal_eval(movie.missing_subtitles):
count_movie = len(ast.literal_eval(movie.missing_subtitles))
else:
count_movie = 0
audio_language_list = get_audio_profile_languages(movie['audio_language'])
audio_language_list = get_audio_profile_languages(movie.audio_language)
if len(audio_language_list) > 0:
audio_language = audio_language_list[0]['name']
else:
@ -50,7 +49,7 @@ def movies_download_subtitles(no):
languages = []
for language in ast.literal_eval(movie['missing_subtitles']):
for language in ast.literal_eval(movie.missing_subtitles):
providers_list = get_providers()
if providers_list:
@ -64,20 +63,20 @@ def movies_download_subtitles(no):
show_progress(id='movie_search_progress_{}'.format(no),
header='Searching missing subtitles...',
name=movie['title'],
name=movie.title,
value=0,
count=count_movie)
for result in generate_subtitles(path_mappings.path_replace_movie(movie['path']),
for result in generate_subtitles(path_mappings.path_replace_movie(movie.path),
languages,
audio_language,
str(movie['sceneName']),
movie['title'],
str(movie.sceneName),
movie.title,
'movie',
check_if_still_required=True):
if result:
store_subtitles_movie(movie['path'], path_mappings.path_replace_movie(movie['path']))
store_subtitles_movie(movie.path, path_mappings.path_replace_movie(movie.path))
history_log_movie(1, no, result)
send_notifications_movie(no, result.message)

View File

@ -12,7 +12,7 @@ from subtitles.indexer.series import store_subtitles
from sonarr.history import history_log
from app.notifier import send_notifications
from app.get_providers import get_providers
from app.database import get_exclusion_clause, get_audio_profile_languages, TableShows, TableEpisodes
from app.database import get_exclusion_clause, get_audio_profile_languages, TableShows, TableEpisodes, database, select
from app.event_handler import show_progress, hide_progress
from ..download import generate_subtitles
@ -22,21 +22,23 @@ def series_download_subtitles(no):
conditions = [(TableEpisodes.sonarrSeriesId == no),
(TableEpisodes.missing_subtitles != '[]')]
conditions += get_exclusion_clause('series')
episodes_details = TableEpisodes.select(TableEpisodes.path,
TableEpisodes.missing_subtitles,
TableEpisodes.monitored,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.sceneName,
TableShows.tags,
TableShows.seriesType,
TableEpisodes.audio_language,
TableShows.title,
TableEpisodes.season,
TableEpisodes.episode,
TableEpisodes.title.alias('episodeTitle')) \
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \
.where(reduce(operator.and_, conditions)) \
.dicts()
episodes_details = database.execute(
select(TableEpisodes.path,
TableEpisodes.missing_subtitles,
TableEpisodes.monitored,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.sceneName,
TableShows.tags,
TableShows.seriesType,
TableEpisodes.audio_language,
TableShows.title,
TableEpisodes.season,
TableEpisodes.episode,
TableEpisodes.title.label('episodeTitle'))
.select_from(TableEpisodes)
.join(TableShows)
.where(reduce(operator.and_, conditions))) \
.all()
if not episodes_details:
logging.debug("BAZARR no episode for that sonarrSeriesId have been found in database or they have all been "
"ignored because of monitored status, series type or series tags: {}".format(no))
@ -50,21 +52,21 @@ def series_download_subtitles(no):
if providers_list:
show_progress(id='series_search_progress_{}'.format(no),
header='Searching missing subtitles...',
name='{0} - S{1:02d}E{2:02d} - {3}'.format(episode['title'],
episode['season'],
episode['episode'],
episode['episodeTitle']),
name='{0} - S{1:02d}E{2:02d} - {3}'.format(episode.title,
episode.season,
episode.episode,
episode.episodeTitle),
value=i,
count=count_episodes_details)
audio_language_list = get_audio_profile_languages(episode['audio_language'])
audio_language_list = get_audio_profile_languages(episode.audio_language)
if len(audio_language_list) > 0:
audio_language = audio_language_list[0]['name']
else:
audio_language = 'None'
languages = []
for language in ast.literal_eval(episode['missing_subtitles']):
for language in ast.literal_eval(episode.missing_subtitles):
if language is not None:
hi_ = "True" if language.endswith(':hi') else "False"
forced_ = "True" if language.endswith(':forced') else "False"
@ -73,17 +75,17 @@ def series_download_subtitles(no):
if not languages:
continue
for result in generate_subtitles(path_mappings.path_replace(episode['path']),
for result in generate_subtitles(path_mappings.path_replace(episode.path),
languages,
audio_language,
str(episode['sceneName']),
episode['title'],
str(episode.sceneName),
episode.title,
'series',
check_if_still_required=True):
if result:
store_subtitles(episode['path'], path_mappings.path_replace(episode['path']))
history_log(1, no, episode['sonarrEpisodeId'], result)
send_notifications(no, episode['sonarrEpisodeId'], result.message)
store_subtitles(episode.path, path_mappings.path_replace(episode.path))
history_log(1, no, episode.sonarrEpisodeId, result)
send_notifications(no, episode.sonarrEpisodeId, result.message)
else:
logging.info("BAZARR All providers are throttled")
break
@ -94,22 +96,24 @@ def series_download_subtitles(no):
def episode_download_subtitles(no, send_progress=False):
conditions = [(TableEpisodes.sonarrEpisodeId == no)]
conditions += get_exclusion_clause('series')
episodes_details = TableEpisodes.select(TableEpisodes.path,
TableEpisodes.missing_subtitles,
TableEpisodes.monitored,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.sceneName,
TableShows.tags,
TableShows.title,
TableShows.sonarrSeriesId,
TableEpisodes.audio_language,
TableShows.seriesType,
TableEpisodes.title.alias('episodeTitle'),
TableEpisodes.season,
TableEpisodes.episode) \
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \
.where(reduce(operator.and_, conditions)) \
.dicts()
episodes_details = database.execute(
select(TableEpisodes.path,
TableEpisodes.missing_subtitles,
TableEpisodes.monitored,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.sceneName,
TableShows.tags,
TableShows.title,
TableShows.sonarrSeriesId,
TableEpisodes.audio_language,
TableShows.seriesType,
TableEpisodes.title.label('episodeTitle'),
TableEpisodes.season,
TableEpisodes.episode)
.select_from(TableEpisodes)
.join(TableShows)
.where(reduce(operator.and_, conditions))) \
.all()
if not episodes_details:
logging.debug("BAZARR no episode with that sonarrEpisodeId can be found in database:", str(no))
return
@ -121,21 +125,21 @@ def episode_download_subtitles(no, send_progress=False):
if send_progress:
show_progress(id='episode_search_progress_{}'.format(no),
header='Searching missing subtitles...',
name='{0} - S{1:02d}E{2:02d} - {3}'.format(episode['title'],
episode['season'],
episode['episode'],
episode['episodeTitle']),
name='{0} - S{1:02d}E{2:02d} - {3}'.format(episode.title,
episode.season,
episode.episode,
episode.episodeTitle),
value=0,
count=1)
audio_language_list = get_audio_profile_languages(episode['audio_language'])
audio_language_list = get_audio_profile_languages(episode.audio_language)
if len(audio_language_list) > 0:
audio_language = audio_language_list[0]['name']
else:
audio_language = 'None'
languages = []
for language in ast.literal_eval(episode['missing_subtitles']):
for language in ast.literal_eval(episode.missing_subtitles):
if language is not None:
hi_ = "True" if language.endswith(':hi') else "False"
forced_ = "True" if language.endswith(':forced') else "False"
@ -144,17 +148,17 @@ def episode_download_subtitles(no, send_progress=False):
if not languages:
continue
for result in generate_subtitles(path_mappings.path_replace(episode['path']),
for result in generate_subtitles(path_mappings.path_replace(episode.path),
languages,
audio_language,
str(episode['sceneName']),
episode['title'],
str(episode.sceneName),
episode.title,
'series',
check_if_still_required=True):
if result:
store_subtitles(episode['path'], path_mappings.path_replace(episode['path']))
history_log(1, episode['sonarrSeriesId'], episode['sonarrEpisodeId'], result)
send_notifications(episode['sonarrSeriesId'], episode['sonarrEpisodeId'], result.message)
store_subtitles(episode.path, path_mappings.path_replace(episode.path))
history_log(1, episode.sonarrSeriesId, episode.sonarrEpisodeId, result)
send_notifications(episode.sonarrSeriesId, episode.sonarrEpisodeId, result.message)
if send_progress:
hide_progress(id='episode_search_progress_{}'.format(no))

View File

@ -7,7 +7,7 @@ from app.config import settings
from utilities.path_mappings import path_mappings
from utilities.post_processing import pp_replace, set_chmod
from languages.get_languages import alpha2_from_alpha3, alpha2_from_language, alpha3_from_language, language_from_alpha3
from app.database import TableEpisodes, TableMovies
from app.database import TableEpisodes, TableMovies, database, select
from utilities.analytics import event_tracker
from radarr.notify import notify_radarr
from sonarr.notify import notify_sonarr
@ -15,17 +15,20 @@ from app.event_handler import event_stream
from .utils import _get_download_code3
from .post_processing import postprocessing
from .utils import _get_scores
class ProcessSubtitlesResult:
def __init__(self, message, reversed_path, downloaded_language_code2, downloaded_provider, score, forced,
subtitle_id, reversed_subtitles_path, hearing_impaired):
subtitle_id, reversed_subtitles_path, hearing_impaired, matched=None, not_matched=None):
self.message = message
self.path = reversed_path
self.provider = downloaded_provider
self.score = score
self.subs_id = subtitle_id
self.subs_path = reversed_subtitles_path
self.matched = matched
self.not_matched = not_matched
if hearing_impaired:
self.language_code = downloaded_language_code2 + ":hi"
@ -67,39 +70,38 @@ def process_subtitle(subtitle, media_type, audio_language, path, max_score, is_u
downloaded_provider + " with a score of " + str(percent_score) + "%."
if media_type == 'series':
episode_metadata = TableEpisodes.select(TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId) \
.where(TableEpisodes.path == path_mappings.path_replace_reverse(path)) \
.dicts() \
.get_or_none()
episode_metadata = database.execute(
select(TableEpisodes.sonarrSeriesId, TableEpisodes.sonarrEpisodeId)
.where(TableEpisodes.path == path_mappings.path_replace_reverse(path)))\
.first()
if not episode_metadata:
return
series_id = episode_metadata['sonarrSeriesId']
episode_id = episode_metadata['sonarrEpisodeId']
series_id = episode_metadata.sonarrSeriesId
episode_id = episode_metadata.sonarrEpisodeId
from .sync import sync_subtitles
sync_subtitles(video_path=path, srt_path=downloaded_path,
forced=subtitle.language.forced,
srt_lang=downloaded_language_code2, media_type=media_type,
percent_score=percent_score,
sonarr_series_id=episode_metadata['sonarrSeriesId'],
sonarr_episode_id=episode_metadata['sonarrEpisodeId'])
sonarr_series_id=episode_metadata.sonarrSeriesId,
sonarr_episode_id=episode_metadata.sonarrEpisodeId)
else:
movie_metadata = TableMovies.select(TableMovies.radarrId) \
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path)) \
.dicts() \
.get_or_none()
movie_metadata = database.execute(
select(TableMovies.radarrId)
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path)))\
.first()
if not movie_metadata:
return
series_id = ""
episode_id = movie_metadata['radarrId']
episode_id = movie_metadata.radarrId
from .sync import sync_subtitles
sync_subtitles(video_path=path, srt_path=downloaded_path,
forced=subtitle.language.forced,
srt_lang=downloaded_language_code2, media_type=media_type,
percent_score=percent_score,
radarr_id=movie_metadata['radarrId'])
radarr_id=movie_metadata.radarrId)
if use_postprocessing is True:
command = pp_replace(postprocessing_cmd, path, downloaded_path, downloaded_language, downloaded_language_code2,
@ -124,16 +126,16 @@ def process_subtitle(subtitle, media_type, audio_language, path, max_score, is_u
if media_type == 'series':
reversed_path = path_mappings.path_replace_reverse(path)
reversed_subtitles_path = path_mappings.path_replace_reverse(downloaded_path)
notify_sonarr(episode_metadata['sonarrSeriesId'])
event_stream(type='series', action='update', payload=episode_metadata['sonarrSeriesId'])
notify_sonarr(episode_metadata.sonarrSeriesId)
event_stream(type='series', action='update', payload=episode_metadata.sonarrSeriesId)
event_stream(type='episode-wanted', action='delete',
payload=episode_metadata['sonarrEpisodeId'])
payload=episode_metadata.sonarrEpisodeId)
else:
reversed_path = path_mappings.path_replace_reverse_movie(path)
reversed_subtitles_path = path_mappings.path_replace_reverse_movie(downloaded_path)
notify_radarr(movie_metadata['radarrId'])
event_stream(type='movie-wanted', action='delete', payload=movie_metadata['radarrId'])
notify_radarr(movie_metadata.radarrId)
event_stream(type='movie-wanted', action='delete', payload=movie_metadata.radarrId)
event_tracker.track(provider=downloaded_provider, action=action, language=downloaded_language)
@ -145,4 +147,15 @@ def process_subtitle(subtitle, media_type, audio_language, path, max_score, is_u
forced=subtitle.language.forced,
subtitle_id=subtitle.id,
reversed_subtitles_path=reversed_subtitles_path,
hearing_impaired=subtitle.language.hi)
hearing_impaired=subtitle.language.hi,
matched=list(subtitle.matches),
not_matched=_get_not_matched(subtitle, media_type))
def _get_not_matched(subtitle, media_type):
_, _, scores = _get_scores(media_type)
if 'hash' not in subtitle.matches:
return list(set(scores) - set(subtitle.matches))
else:
return []

View File

@ -7,7 +7,7 @@ import re
from subliminal import Episode, Movie
from utilities.path_mappings import path_mappings
from app.database import TableShows, TableEpisodes, TableMovies
from app.database import TableShows, TableEpisodes, TableMovies, database, select
from .utils import convert_to_guessit
@ -17,84 +17,85 @@ _TITLE_RE = re.compile(r'\s(\(\d{4}\))')
def refine_from_db(path, video):
if isinstance(video, Episode):
data = TableEpisodes.select(TableShows.title.alias('seriesTitle'),
TableEpisodes.season,
TableEpisodes.episode,
TableEpisodes.title.alias('episodeTitle'),
TableShows.year,
TableShows.tvdbId,
TableShows.alternativeTitles,
TableEpisodes.format,
TableEpisodes.resolution,
TableEpisodes.video_codec,
TableEpisodes.audio_codec,
TableEpisodes.path,
TableShows.imdbId)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where((TableEpisodes.path == path_mappings.path_replace_reverse(path)))\
.dicts()
data = database.execute(
select(TableShows.title.label('seriesTitle'),
TableEpisodes.season,
TableEpisodes.episode,
TableEpisodes.title.label('episodeTitle'),
TableShows.year,
TableShows.tvdbId,
TableShows.alternativeTitles,
TableEpisodes.format,
TableEpisodes.resolution,
TableEpisodes.video_codec,
TableEpisodes.audio_codec,
TableEpisodes.path,
TableShows.imdbId)
.select_from(TableEpisodes)
.join(TableShows)
.where((TableEpisodes.path == path_mappings.path_replace_reverse(path)))) \
.first()
if len(data):
data = data[0]
video.series = _TITLE_RE.sub('', data['seriesTitle'])
video.season = int(data['season'])
video.episode = int(data['episode'])
video.title = data['episodeTitle']
if data:
video.series = _TITLE_RE.sub('', data.seriesTitle)
video.season = int(data.season)
video.episode = int(data.episode)
video.title = data.episodeTitle
# Only refine year as a fallback
if not video.year and data['year']:
if int(data['year']) > 0:
video.year = int(data['year'])
if not video.year and data.year:
if int(data.year) > 0:
video.year = int(data.year)
video.series_tvdb_id = int(data['tvdbId'])
video.alternative_series = ast.literal_eval(data['alternativeTitles'])
if data['imdbId'] and not video.series_imdb_id:
video.series_imdb_id = data['imdbId']
video.series_tvdb_id = int(data.tvdbId)
video.alternative_series = ast.literal_eval(data.alternativeTitles)
if data.imdbId and not video.series_imdb_id:
video.series_imdb_id = data.imdbId
if not video.source:
video.source = convert_to_guessit('source', str(data['format']))
video.source = convert_to_guessit('source', str(data.format))
if not video.resolution:
video.resolution = str(data['resolution'])
video.resolution = str(data.resolution)
if not video.video_codec:
if data['video_codec']:
video.video_codec = convert_to_guessit('video_codec', data['video_codec'])
if data.video_codec:
video.video_codec = convert_to_guessit('video_codec', data.video_codec)
if not video.audio_codec:
if data['audio_codec']:
video.audio_codec = convert_to_guessit('audio_codec', data['audio_codec'])
if data.audio_codec:
video.audio_codec = convert_to_guessit('audio_codec', data.audio_codec)
elif isinstance(video, Movie):
data = TableMovies.select(TableMovies.title,
TableMovies.year,
TableMovies.alternativeTitles,
TableMovies.format,
TableMovies.resolution,
TableMovies.video_codec,
TableMovies.audio_codec,
TableMovies.imdbId)\
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\
.dicts()
data = database.execute(
select(TableMovies.title,
TableMovies.year,
TableMovies.alternativeTitles,
TableMovies.format,
TableMovies.resolution,
TableMovies.video_codec,
TableMovies.audio_codec,
TableMovies.imdbId)
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))) \
.first()
if len(data):
data = data[0]
video.title = _TITLE_RE.sub('', data['title'])
if data:
video.title = _TITLE_RE.sub('', data.title)
# Only refine year as a fallback
if not video.year and data['year']:
if int(data['year']) > 0:
video.year = int(data['year'])
if not video.year and data.year:
if int(data.year) > 0:
video.year = int(data.year)
if data['imdbId'] and not video.imdb_id:
video.imdb_id = data['imdbId']
video.alternative_titles = ast.literal_eval(data['alternativeTitles'])
if data.imdbId and not video.imdb_id:
video.imdb_id = data.imdbId
video.alternative_titles = ast.literal_eval(data.alternativeTitles)
if not video.source:
if data['format']:
video.source = convert_to_guessit('source', data['format'])
if data.format:
video.source = convert_to_guessit('source', data.format)
if not video.resolution:
if data['resolution']:
video.resolution = data['resolution']
if data.resolution:
video.resolution = data.resolution
if not video.video_codec:
if data['video_codec']:
video.video_codec = convert_to_guessit('video_codec', data['video_codec'])
if data.video_codec:
video.video_codec = convert_to_guessit('video_codec', data.video_codec)
if not video.audio_codec:
if data['audio_codec']:
video.audio_codec = convert_to_guessit('audio_codec', data['audio_codec'])
if data.audio_codec:
video.audio_codec = convert_to_guessit('audio_codec', data.audio_codec)
return video

View File

@ -6,31 +6,31 @@ import logging
from subliminal import Movie
from utilities.path_mappings import path_mappings
from app.database import TableEpisodes, TableMovies
from app.database import TableEpisodes, TableMovies, database, select
from utilities.video_analyzer import parse_video_metadata
def refine_from_ffprobe(path, video):
if isinstance(video, Movie):
file_id = TableMovies.select(TableMovies.movie_file_id, TableMovies.file_size)\
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))\
.dicts()\
.get_or_none()
file_id = database.execute(
select(TableMovies.movie_file_id, TableMovies.file_size)
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))) \
.first()
else:
file_id = TableEpisodes.select(TableEpisodes.episode_file_id, TableEpisodes.file_size)\
.where(TableEpisodes.path == path_mappings.path_replace_reverse(path))\
.dicts()\
.get_or_none()
file_id = database.execute(
select(TableEpisodes.episode_file_id, TableEpisodes.file_size)
.where(TableEpisodes.path == path_mappings.path_replace_reverse(path)))\
.first()
if not file_id:
return video
if isinstance(video, Movie):
data = parse_video_metadata(file=path, file_size=file_id['file_size'],
movie_file_id=file_id['movie_file_id'])
data = parse_video_metadata(file=path, file_size=file_id.file_size,
movie_file_id=file_id.movie_file_id)
else:
data = parse_video_metadata(file=path, file_size=file_id['file_size'],
episode_file_id=file_id['episode_file_id'])
data = parse_video_metadata(file=path, file_size=file_id.file_size,
episode_file_id=file_id.episode_file_id)
if not data or ('ffprobe' not in data and 'mediainfo' not in data):
logging.debug("No cache available for this file: {}".format(path))

View File

@ -7,7 +7,7 @@ import re
from app.config import get_settings
from app.database import TableCustomScoreProfileConditions as conditions_table, TableCustomScoreProfiles as \
profiles_table
profiles_table, database, select
logger = logging.getLogger(__name__)
@ -91,14 +91,14 @@ class CustomScoreProfile:
self._conditions_loaded = False
def load_conditions(self):
try:
self._conditions = [
Condition.from_dict(item)
for item in self.conditions_table.select()
.where(self.conditions_table.profile_id == self.id)
.dicts()
]
except self.conditions_table.DoesNotExist:
self._conditions = [
Condition.from_dict(item)
for item in database.execute(
select(self.conditions_table)
.where(self.conditions_table.profile_id == self.id))
.all()
]
if not self._conditions:
logger.debug("Conditions not found for %s", self)
self._conditions = []
@ -164,15 +164,16 @@ class Score:
def load_profiles(self):
"""Load the profiles associated with the class. This method must be called
after every custom profile creation or update."""
try:
self._profiles = [
CustomScoreProfile(**item)
for item in self.profiles_table.select()
.where(self.profiles_table.media == self.media)
.dicts()
]
self._profiles = [
CustomScoreProfile(**item)
for item in database.execute(
select(self.profiles_table)
.where(self.profiles_table.media == self.media))
.all()
]
if self._profiles:
logger.debug("Loaded profiles: %s", self._profiles)
except self.profiles_table.DoesNotExist:
else:
logger.debug("No score profiles found")
self._profiles = []

View File

@ -3,13 +3,15 @@
import logging
import operator
import os
import ast
from datetime import datetime, timedelta
from functools import reduce
from sqlalchemy import and_
from app.config import settings
from app.database import get_exclusion_clause, get_audio_profile_languages, TableShows, TableEpisodes, TableMovies, \
TableHistory, TableHistoryMovie, get_profiles_list
TableHistory, TableHistoryMovie, database, select, func, get_profiles_list
from app.event_handler import show_progress, hide_progress
from app.get_providers import get_providers
from app.notifier import send_notifications, send_notifications_movie
@ -27,9 +29,60 @@ def upgrade_subtitles():
if use_sonarr:
episodes_to_upgrade = get_upgradable_episode_subtitles()
count_episode_to_upgrade = len(episodes_to_upgrade)
episodes_data = [{
'id': x.id,
'seriesTitle': x.seriesTitle,
'season': x.season,
'episode': x.episode,
'title': x.title,
'language': x.language,
'audio_language': x.audio_language,
'video_path': x.video_path,
'sceneName': x.sceneName,
'score': x.score,
'sonarrEpisodeId': x.sonarrEpisodeId,
'sonarrSeriesId': x.sonarrSeriesId,
'subtitles_path': x.subtitles_path,
'path': x.path,
'external_subtitles': [y[1] for y in ast.literal_eval(x.external_subtitles) if y[1]],
'upgradable': bool(x.upgradable),
} for x in database.execute(
select(TableHistory.id,
TableShows.title.label('seriesTitle'),
TableEpisodes.season,
TableEpisodes.episode,
TableEpisodes.title,
TableHistory.language,
TableEpisodes.audio_language,
TableHistory.video_path,
TableEpisodes.sceneName,
TableHistory.score,
TableHistory.sonarrEpisodeId,
TableHistory.sonarrSeriesId,
TableHistory.subtitles_path,
TableEpisodes.path,
TableShows.profileId,
TableEpisodes.subtitles.label('external_subtitles'),
episodes_to_upgrade.c.id.label('upgradable'))
.select_from(TableHistory)
.join(TableShows, onclause=TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId)
.join(TableEpisodes, onclause=TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId)
.join(episodes_to_upgrade, onclause=TableHistory.id == episodes_to_upgrade.c.id, isouter=True)
.where(episodes_to_upgrade.c.id.is_not(None)))
.all() if _language_still_desired(x.language, x.profileId)]
for i, episode in enumerate(episodes_to_upgrade):
for item in episodes_data:
if item['upgradable']:
if item['subtitles_path'] not in item['external_subtitles'] or \
not item['video_path'] == item['path']:
item.update({"upgradable": False})
del item['path']
del item['external_subtitles']
count_episode_to_upgrade = len(episodes_data)
for i, episode in enumerate(episodes_data):
providers_list = get_providers()
show_progress(id='upgrade_episodes_progress',
@ -72,9 +125,49 @@ def upgrade_subtitles():
if use_radarr:
movies_to_upgrade = get_upgradable_movies_subtitles()
count_movie_to_upgrade = len(movies_to_upgrade)
movies_data = [{
'title': x.title,
'language': x.language,
'audio_language': x.audio_language,
'video_path': x.video_path,
'sceneName': x.sceneName,
'score': x.score,
'radarrId': x.radarrId,
'path': x.path,
'subtitles_path': x.subtitles_path,
'external_subtitles': [y[1] for y in ast.literal_eval(x.external_subtitles) if y[1]],
'upgradable': bool(x.upgradable),
} for x in database.execute(
select(TableMovies.title,
TableHistoryMovie.language,
TableMovies.audio_language,
TableHistoryMovie.video_path,
TableMovies.sceneName,
TableHistoryMovie.score,
TableHistoryMovie.radarrId,
TableHistoryMovie.subtitles_path,
TableMovies.path,
TableMovies.profileId,
TableMovies.subtitles.label('external_subtitles'),
movies_to_upgrade.c.id.label('upgradable'))
.select_from(TableHistoryMovie)
.join(TableMovies, onclause=TableHistoryMovie.radarrId == TableMovies.radarrId)
.join(movies_to_upgrade, onclause=TableHistoryMovie.id == movies_to_upgrade.c.id, isouter=True)
.where(movies_to_upgrade.c.id.is_not(None)))
.all() if _language_still_desired(x.language, x.profileId)]
for i, movie in enumerate(movies_to_upgrade):
for item in movies_data:
if item['upgradable']:
if item['subtitles_path'] not in item['external_subtitles'] or \
not item['video_path'] == item['path']:
item.update({"upgradable": False})
del item['path']
del item['external_subtitles']
count_movie_to_upgrade = len(movies_data)
for i, movie in enumerate(movies_data):
providers_list = get_providers()
show_progress(id='upgrade_movies_progress',
@ -127,45 +220,6 @@ def get_queries_condition_parameters():
return [minimum_timestamp, query_actions]
def parse_upgradable_list(upgradable_list, perfect_score, media_type):
if media_type == 'series':
path_replace_method = path_mappings.path_replace
else:
path_replace_method = path_mappings.path_replace_movie
items_to_upgrade = []
for item in upgradable_list:
logging.debug(f"Trying to validate eligibility to upgrade for this subtitles: "
f"{item['subtitles_path']}")
if not os.path.exists(path_replace_method(item['subtitles_path'])):
logging.debug("Subtitles file doesn't exists anymore, we skip this one.")
continue
if (item['video_path'], item['language']) in \
[(x['video_path'], x['language']) for x in items_to_upgrade]:
logging.debug("Newer video path and subtitles language combination already in list of subtitles to "
"upgrade, we skip this one.")
continue
if os.path.exists(path_replace_method(item['subtitles_path'])) and \
os.path.exists(path_replace_method(item['video_path'])):
logging.debug("Video and subtitles file are still there, we continue the eligibility validation.")
pass
items_to_upgrade.append(item)
if not settings.general.getboolean('upgrade_manual'):
logging.debug("Removing history items for manually downloaded or translated subtitles.")
items_to_upgrade = [x for x in items_to_upgrade if x['action'] in [2, 4, 6]]
logging.debug("Removing history items for already perfectly scored subtitles.")
items_to_upgrade = [x for x in items_to_upgrade if x['score'] < perfect_score]
logging.debug(f"Bazarr will try to upgrade {len(items_to_upgrade)} subtitles.")
return items_to_upgrade
def parse_language_string(language_string):
if language_string.endswith('forced'):
language = language_string.split(':')[0]
@ -187,82 +241,59 @@ def get_upgradable_episode_subtitles():
if not settings.general.getboolean('upgrade_subs'):
return []
max_id_timestamp = select(TableHistory.video_path,
TableHistory.language,
func.max(TableHistory.timestamp).label('timestamp')) \
.group_by(TableHistory.video_path, TableHistory.language) \
.distinct() \
.subquery()
minimum_timestamp, query_actions = get_queries_condition_parameters()
upgradable_episodes_conditions = [(TableHistory.action << query_actions),
upgradable_episodes_conditions = [(TableHistory.action.in_(query_actions)),
(TableHistory.timestamp > minimum_timestamp),
(TableHistory.score.is_null(False))]
TableHistory.score.is_not(None),
(TableHistory.score < 357)]
upgradable_episodes_conditions += get_exclusion_clause('series')
upgradable_episodes = TableHistory.select(TableHistory.video_path,
TableHistory.language,
TableHistory.score,
TableShows.tags,
TableShows.profileId,
TableEpisodes.audio_language,
TableEpisodes.sceneName,
TableEpisodes.title,
TableEpisodes.sonarrSeriesId,
TableHistory.action,
TableHistory.subtitles_path,
TableEpisodes.sonarrEpisodeId,
TableHistory.timestamp.alias('timestamp'),
TableEpisodes.monitored,
TableEpisodes.season,
TableEpisodes.episode,
TableShows.title.alias('seriesTitle'),
TableShows.seriesType) \
.join(TableShows, on=(TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId)) \
.join(TableEpisodes, on=(TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId)) \
return select(TableHistory.id)\
.select_from(TableHistory) \
.join(max_id_timestamp, onclause=and_(TableHistory.video_path == max_id_timestamp.c.video_path,
TableHistory.language == max_id_timestamp.c.language,
max_id_timestamp.c.timestamp == TableHistory.timestamp)) \
.join(TableShows, onclause=TableHistory.sonarrSeriesId == TableShows.sonarrSeriesId) \
.join(TableEpisodes, onclause=TableHistory.sonarrEpisodeId == TableEpisodes.sonarrEpisodeId) \
.where(reduce(operator.and_, upgradable_episodes_conditions)) \
.order_by(TableHistory.timestamp.desc()) \
.dicts()
if not upgradable_episodes:
return []
else:
upgradable_episodes = [x for x in upgradable_episodes if _language_still_desired(x['language'], x['profileId'])]
logging.debug(f"{len(upgradable_episodes)} potentially upgradable episode subtitles have been found, let's "
f"filter them...")
return parse_upgradable_list(upgradable_list=upgradable_episodes, perfect_score=357, media_type='series')
.order_by(TableHistory.timestamp.desc())\
.subquery()
def get_upgradable_movies_subtitles():
if not settings.general.getboolean('upgrade_subs'):
return []
max_id_timestamp = select(TableHistoryMovie.video_path,
TableHistoryMovie.language,
func.max(TableHistoryMovie.timestamp).label('timestamp')) \
.group_by(TableHistoryMovie.video_path, TableHistoryMovie.language) \
.distinct() \
.subquery()
minimum_timestamp, query_actions = get_queries_condition_parameters()
upgradable_movies_conditions = [(TableHistoryMovie.action << query_actions),
upgradable_movies_conditions = [(TableHistoryMovie.action.in_(query_actions)),
(TableHistoryMovie.timestamp > minimum_timestamp),
(TableHistoryMovie.score.is_null(False))]
TableHistoryMovie.score.is_not(None),
(TableHistoryMovie.score < 117)]
upgradable_movies_conditions += get_exclusion_clause('movie')
upgradable_movies = TableHistoryMovie.select(TableHistoryMovie.video_path,
TableHistoryMovie.language,
TableHistoryMovie.score,
TableMovies.profileId,
TableHistoryMovie.action,
TableHistoryMovie.subtitles_path,
TableMovies.audio_language,
TableMovies.sceneName,
TableHistoryMovie.timestamp.alias('timestamp'),
TableMovies.monitored,
TableMovies.tags,
TableMovies.radarrId,
TableMovies.title) \
.join(TableMovies, on=(TableHistoryMovie.radarrId == TableMovies.radarrId)) \
return select(TableHistoryMovie.id) \
.select_from(TableHistoryMovie) \
.join(max_id_timestamp, onclause=and_(TableHistoryMovie.video_path == max_id_timestamp.c.video_path,
TableHistoryMovie.language == max_id_timestamp.c.language,
max_id_timestamp.c.timestamp == TableHistoryMovie.timestamp)) \
.join(TableMovies, onclause=TableHistoryMovie.radarrId == TableMovies.radarrId) \
.where(reduce(operator.and_, upgradable_movies_conditions)) \
.order_by(TableHistoryMovie.timestamp.desc()) \
.dicts()
if not upgradable_movies:
return []
else:
upgradable_movies = [x for x in upgradable_movies if _language_still_desired(x['language'], x['profileId'])]
logging.debug(f"{len(upgradable_movies)} potentially upgradable movie subtitles have been found, let's filter "
f"them...")
return parse_upgradable_list(upgradable_list=upgradable_movies, perfect_score=117, media_type='movie')
.subquery()
def _language_still_desired(language, profile_id):

View File

@ -18,7 +18,7 @@ from utilities.path_mappings import path_mappings
from radarr.notify import notify_radarr
from sonarr.notify import notify_sonarr
from languages.custom_lang import CustomLanguage
from app.database import TableEpisodes, TableMovies, TableShows, get_profiles_list
from app.database import TableEpisodes, TableMovies, TableShows, get_profiles_list, database, select
from app.event_handler import event_stream
from subtitles.processing import ProcessSubtitlesResult
@ -52,26 +52,27 @@ def manual_upload_subtitle(path, language, forced, hi, media_type, subtitle, aud
lang_obj = Language.rebuild(lang_obj, forced=True)
if media_type == 'series':
episode_metadata = TableEpisodes.select(TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableShows.profileId) \
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId)) \
.where(TableEpisodes.path == path_mappings.path_replace_reverse(path)) \
.dicts() \
.get_or_none()
episode_metadata = database.execute(
select(TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableShows.profileId)
.select_from(TableEpisodes)
.join(TableShows)
.where(TableEpisodes.path == path_mappings.path_replace_reverse(path))) \
.first()
if episode_metadata:
use_original_format = bool(get_profiles_list(episode_metadata["profileId"])["originalFormat"])
use_original_format = bool(get_profiles_list(episode_metadata.profileId)["originalFormat"])
else:
use_original_format = False
else:
movie_metadata = TableMovies.select(TableMovies.radarrId, TableMovies.profileId) \
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path)) \
.dicts() \
.get_or_none()
movie_metadata = database.execute(
select(TableMovies.radarrId, TableMovies.profileId)
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(path))) \
.first()
if movie_metadata:
use_original_format = bool(get_profiles_list(movie_metadata["profileId"])["originalFormat"])
use_original_format = bool(get_profiles_list(movie_metadata.profileId)["originalFormat"])
else:
use_original_format = False
@ -134,18 +135,18 @@ def manual_upload_subtitle(path, language, forced, hi, media_type, subtitle, aud
if media_type == 'series':
if not episode_metadata:
return
series_id = episode_metadata['sonarrSeriesId']
episode_id = episode_metadata['sonarrEpisodeId']
series_id = episode_metadata.sonarrSeriesId
episode_id = episode_metadata.sonarrEpisodeId
sync_subtitles(video_path=path, srt_path=subtitle_path, srt_lang=uploaded_language_code2, media_type=media_type,
percent_score=100, sonarr_series_id=episode_metadata['sonarrSeriesId'], forced=forced,
sonarr_episode_id=episode_metadata['sonarrEpisodeId'])
percent_score=100, sonarr_series_id=episode_metadata.sonarrSeriesId, forced=forced,
sonarr_episode_id=episode_metadata.sonarrEpisodeId)
else:
if not movie_metadata:
return
series_id = ""
episode_id = movie_metadata['radarrId']
episode_id = movie_metadata.radarrId
sync_subtitles(video_path=path, srt_path=subtitle_path, srt_lang=uploaded_language_code2, media_type=media_type,
percent_score=100, radarr_id=movie_metadata['radarrId'], forced=forced)
percent_score=100, radarr_id=movie_metadata.radarrId, forced=forced)
if use_postprocessing:
command = pp_replace(postprocessing_cmd, path, subtitle_path, uploaded_language, uploaded_language_code2,
@ -157,15 +158,15 @@ def manual_upload_subtitle(path, language, forced, hi, media_type, subtitle, aud
if media_type == 'series':
reversed_path = path_mappings.path_replace_reverse(path)
reversed_subtitles_path = path_mappings.path_replace_reverse(subtitle_path)
notify_sonarr(episode_metadata['sonarrSeriesId'])
event_stream(type='series', action='update', payload=episode_metadata['sonarrSeriesId'])
event_stream(type='episode-wanted', action='delete', payload=episode_metadata['sonarrEpisodeId'])
notify_sonarr(episode_metadata.sonarrSeriesId)
event_stream(type='series', action='update', payload=episode_metadata.sonarrSeriesId)
event_stream(type='episode-wanted', action='delete', payload=episode_metadata.sonarrEpisodeId)
else:
reversed_path = path_mappings.path_replace_reverse_movie(path)
reversed_subtitles_path = path_mappings.path_replace_reverse_movie(subtitle_path)
notify_radarr(movie_metadata['radarrId'])
event_stream(type='movie', action='update', payload=movie_metadata['radarrId'])
event_stream(type='movie-wanted', action='delete', payload=movie_metadata['radarrId'])
notify_radarr(movie_metadata.radarrId)
event_stream(type='movie', action='update', payload=movie_metadata.radarrId)
event_stream(type='movie-wanted', action='delete', payload=movie_metadata.radarrId)
result = ProcessSubtitlesResult(message=language_from_alpha3(language) + modifier_string + " Subtitles manually "
"uploaded.",

View File

@ -12,7 +12,7 @@ from subtitles.indexer.movies import store_subtitles_movie
from radarr.history import history_log_movie
from app.notifier import send_notifications_movie
from app.get_providers import get_providers
from app.database import get_exclusion_clause, get_audio_profile_languages, TableMovies
from app.database import get_exclusion_clause, get_audio_profile_languages, TableMovies, database, update, select
from app.event_handler import event_stream, show_progress, hide_progress
from ..adaptive_searching import is_search_active, updateFailedAttempts
@ -20,7 +20,7 @@ from ..download import generate_subtitles
def _wanted_movie(movie):
audio_language_list = get_audio_profile_languages(movie['audio_language'])
audio_language_list = get_audio_profile_languages(movie.audio_language)
if len(audio_language_list) > 0:
audio_language = audio_language_list[0]['name']
else:
@ -28,48 +28,48 @@ def _wanted_movie(movie):
languages = []
for language in ast.literal_eval(movie['missing_subtitles']):
if is_search_active(desired_language=language, attempt_string=movie['failedAttempts']):
TableMovies.update({TableMovies.failedAttempts:
updateFailedAttempts(desired_language=language,
attempt_string=movie['failedAttempts'])}) \
.where(TableMovies.radarrId == movie['radarrId']) \
.execute()
for language in ast.literal_eval(movie.missing_subtitles):
if is_search_active(desired_language=language, attempt_string=movie.failedAttempts):
database.execute(
update(TableMovies)
.values(failedAttempts=updateFailedAttempts(desired_language=language,
attempt_string=movie.failedAttempts))
.where(TableMovies.radarrId == movie.radarrId))
hi_ = "True" if language.endswith(':hi') else "False"
forced_ = "True" if language.endswith(':forced') else "False"
languages.append((language.split(":")[0], hi_, forced_))
else:
logging.info(f"BAZARR Search is throttled by adaptive search for this movie {movie['path']} and "
logging.info(f"BAZARR Search is throttled by adaptive search for this movie {movie.path} and "
f"language: {language}")
for result in generate_subtitles(path_mappings.path_replace_movie(movie['path']),
for result in generate_subtitles(path_mappings.path_replace_movie(movie.path),
languages,
audio_language,
str(movie['sceneName']),
movie['title'],
str(movie.sceneName),
movie.title,
'movie',
check_if_still_required=True):
if result:
store_subtitles_movie(movie['path'], path_mappings.path_replace_movie(movie['path']))
history_log_movie(1, movie['radarrId'], result)
event_stream(type='movie-wanted', action='delete', payload=movie['radarrId'])
send_notifications_movie(movie['radarrId'], result.message)
store_subtitles_movie(movie.path, path_mappings.path_replace_movie(movie.path))
history_log_movie(1, movie.radarrId, result)
event_stream(type='movie-wanted', action='delete', payload=movie.radarrId)
send_notifications_movie(movie.radarrId, result.message)
def wanted_download_subtitles_movie(radarr_id):
movies_details = TableMovies.select(TableMovies.path,
TableMovies.missing_subtitles,
TableMovies.radarrId,
TableMovies.audio_language,
TableMovies.sceneName,
TableMovies.failedAttempts,
TableMovies.title)\
.where((TableMovies.radarrId == radarr_id))\
.dicts()
movies_details = list(movies_details)
movies_details = database.execute(
select(TableMovies.path,
TableMovies.missing_subtitles,
TableMovies.radarrId,
TableMovies.audio_language,
TableMovies.sceneName,
TableMovies.failedAttempts,
TableMovies.title)
.where(TableMovies.radarrId == radarr_id)) \
.all()
for movie in movies_details:
providers_list = get_providers()
@ -84,25 +84,25 @@ def wanted_download_subtitles_movie(radarr_id):
def wanted_search_missing_subtitles_movies():
conditions = [(TableMovies.missing_subtitles != '[]')]
conditions += get_exclusion_clause('movie')
movies = TableMovies.select(TableMovies.radarrId,
TableMovies.tags,
TableMovies.monitored,
TableMovies.title) \
.where(reduce(operator.and_, conditions)) \
.dicts()
movies = list(movies)
movies = database.execute(
select(TableMovies.radarrId,
TableMovies.tags,
TableMovies.monitored,
TableMovies.title)
.where(reduce(operator.and_, conditions))) \
.all()
count_movies = len(movies)
for i, movie in enumerate(movies):
show_progress(id='wanted_movies_progress',
header='Searching subtitles...',
name=movie['title'],
name=movie.title,
value=i,
count=count_movies)
providers = get_providers()
if providers:
wanted_download_subtitles_movie(movie['radarrId'])
wanted_download_subtitles_movie(movie.radarrId)
else:
logging.info("BAZARR All providers are throttled")
break

View File

@ -12,7 +12,8 @@ from subtitles.indexer.series import store_subtitles
from sonarr.history import history_log
from app.notifier import send_notifications
from app.get_providers import get_providers
from app.database import get_exclusion_clause, get_audio_profile_languages, TableShows, TableEpisodes
from app.database import get_exclusion_clause, get_audio_profile_languages, TableShows, TableEpisodes, database, \
update, select
from app.event_handler import event_stream, show_progress, hide_progress
from ..adaptive_searching import is_search_active, updateFailedAttempts
@ -20,20 +21,20 @@ from ..download import generate_subtitles
def _wanted_episode(episode):
audio_language_list = get_audio_profile_languages(episode['audio_language'])
audio_language_list = get_audio_profile_languages(episode.audio_language)
if len(audio_language_list) > 0:
audio_language = audio_language_list[0]['name']
else:
audio_language = 'None'
languages = []
for language in ast.literal_eval(episode['missing_subtitles']):
if is_search_active(desired_language=language, attempt_string=episode['failedAttempts']):
TableEpisodes.update({TableEpisodes.failedAttempts:
updateFailedAttempts(desired_language=language,
attempt_string=episode['failedAttempts'])}) \
.where(TableEpisodes.sonarrEpisodeId == episode['sonarrEpisodeId']) \
.execute()
for language in ast.literal_eval(episode.missing_subtitles):
if is_search_active(desired_language=language, attempt_string=episode.failedAttempts):
database.execute(
update(TableEpisodes)
.values(failedAttempts=updateFailedAttempts(desired_language=language,
attempt_string=episode.failedAttempts))
.where(TableEpisodes.sonarrEpisodeId == episode.sonarrEpisodeId))
hi_ = "True" if language.endswith(':hi') else "False"
forced_ = "True" if language.endswith(':forced') else "False"
@ -41,37 +42,38 @@ def _wanted_episode(episode):
else:
logging.debug(
f"BAZARR Search is throttled by adaptive search for this episode {episode['path']} and "
f"BAZARR Search is throttled by adaptive search for this episode {episode.path} and "
f"language: {language}")
for result in generate_subtitles(path_mappings.path_replace(episode['path']),
for result in generate_subtitles(path_mappings.path_replace(episode.path),
languages,
audio_language,
str(episode['sceneName']),
episode['title'],
str(episode.sceneName),
episode.title,
'series',
check_if_still_required=True):
if result:
store_subtitles(episode['path'], path_mappings.path_replace(episode['path']))
history_log(1, episode['sonarrSeriesId'], episode['sonarrEpisodeId'], result)
event_stream(type='series', action='update', payload=episode['sonarrSeriesId'])
event_stream(type='episode-wanted', action='delete', payload=episode['sonarrEpisodeId'])
send_notifications(episode['sonarrSeriesId'], episode['sonarrEpisodeId'], result.message)
store_subtitles(episode.path, path_mappings.path_replace(episode.path))
history_log(1, episode.sonarrSeriesId, episode.sonarrEpisodeId, result)
event_stream(type='series', action='update', payload=episode.sonarrSeriesId)
event_stream(type='episode-wanted', action='delete', payload=episode.sonarrEpisodeId)
send_notifications(episode.sonarrSeriesId, episode.sonarrEpisodeId, result.message)
def wanted_download_subtitles(sonarr_episode_id):
episodes_details = TableEpisodes.select(TableEpisodes.path,
TableEpisodes.missing_subtitles,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.sonarrSeriesId,
TableEpisodes.audio_language,
TableEpisodes.sceneName,
TableEpisodes.failedAttempts,
TableShows.title)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where((TableEpisodes.sonarrEpisodeId == sonarr_episode_id))\
.dicts()
episodes_details = list(episodes_details)
episodes_details = database.execute(
select(TableEpisodes.path,
TableEpisodes.missing_subtitles,
TableEpisodes.sonarrEpisodeId,
TableEpisodes.sonarrSeriesId,
TableEpisodes.audio_language,
TableEpisodes.sceneName,
TableEpisodes.failedAttempts,
TableShows.title)
.select_from(TableEpisodes)
.join(TableShows)
.where((TableEpisodes.sonarrEpisodeId == sonarr_episode_id))) \
.all()
for episode in episodes_details:
providers_list = get_providers()
@ -86,34 +88,35 @@ def wanted_download_subtitles(sonarr_episode_id):
def wanted_search_missing_subtitles_series():
conditions = [(TableEpisodes.missing_subtitles != '[]')]
conditions += get_exclusion_clause('series')
episodes = TableEpisodes.select(TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableShows.tags,
TableEpisodes.monitored,
TableShows.title,
TableEpisodes.season,
TableEpisodes.episode,
TableEpisodes.title.alias('episodeTitle'),
TableShows.seriesType)\
.join(TableShows, on=(TableEpisodes.sonarrSeriesId == TableShows.sonarrSeriesId))\
.where(reduce(operator.and_, conditions))\
.dicts()
episodes = list(episodes)
episodes = database.execute(
select(TableEpisodes.sonarrSeriesId,
TableEpisodes.sonarrEpisodeId,
TableShows.tags,
TableEpisodes.monitored,
TableShows.title,
TableEpisodes.season,
TableEpisodes.episode,
TableEpisodes.title.label('episodeTitle'),
TableShows.seriesType)
.select_from(TableEpisodes)
.join(TableShows)
.where(reduce(operator.and_, conditions))) \
.all()
count_episodes = len(episodes)
for i, episode in enumerate(episodes):
show_progress(id='wanted_episodes_progress',
header='Searching subtitles...',
name='{0} - S{1:02d}E{2:02d} - {3}'.format(episode['title'],
episode['season'],
episode['episode'],
episode['episodeTitle']),
name='{0} - S{1:02d}E{2:02d} - {3}'.format(episode.title,
episode.season,
episode.episode,
episode.episodeTitle),
value=i,
count=count_episodes)
providers = get_providers()
if providers:
wanted_download_subtitles(episode['sonarrEpisodeId'])
wanted_download_subtitles(episode.sonarrEpisodeId)
else:
logging.info("BAZARR All providers are throttled")
break

View File

@ -1,7 +1,7 @@
# coding=utf-8
from app.config import settings
from app.database import TableShowsRootfolder, TableMoviesRootfolder
from app.database import TableShowsRootfolder, TableMoviesRootfolder, database, select
from app.event_handler import event_stream
from .path_mappings import path_mappings
from sonarr.rootfolder import check_sonarr_rootfolder
@ -25,24 +25,26 @@ def get_health_issues():
# get Sonarr rootfolder issues
if settings.general.getboolean('use_sonarr'):
rootfolder = TableShowsRootfolder.select(TableShowsRootfolder.path,
TableShowsRootfolder.accessible,
TableShowsRootfolder.error)\
.where(TableShowsRootfolder.accessible == 0)\
.dicts()
rootfolder = database.execute(
select(TableShowsRootfolder.path,
TableShowsRootfolder.accessible,
TableShowsRootfolder.error)
.where(TableShowsRootfolder.accessible == 0)) \
.all()
for item in rootfolder:
health_issues.append({'object': path_mappings.path_replace(item['path']),
'issue': item['error']})
health_issues.append({'object': path_mappings.path_replace(item.path),
'issue': item.error})
# get Radarr rootfolder issues
if settings.general.getboolean('use_radarr'):
rootfolder = TableMoviesRootfolder.select(TableMoviesRootfolder.path,
TableMoviesRootfolder.accessible,
TableMoviesRootfolder.error)\
.where(TableMoviesRootfolder.accessible == 0)\
.dicts()
rootfolder = database.execute(
select(TableMoviesRootfolder.path,
TableMoviesRootfolder.accessible,
TableMoviesRootfolder.error)
.where(TableMoviesRootfolder.accessible == 0)) \
.all()
for item in rootfolder:
health_issues.append({'object': path_mappings.path_replace_movie(item['path']),
'issue': item['error']})
health_issues.append({'object': path_mappings.path_replace_movie(item.path),
'issue': item.error})
return health_issues

View File

@ -7,7 +7,7 @@ from knowit.api import know, KnowitException
from languages.custom_lang import CustomLanguage
from languages.get_languages import language_from_alpha3, alpha3_from_alpha2
from app.database import TableEpisodes, TableMovies
from app.database import TableEpisodes, TableMovies, database, update, select
from utilities.path_mappings import path_mappings
from app.config import settings
@ -116,22 +116,22 @@ def parse_video_metadata(file, file_size, episode_file_id=None, movie_file_id=No
if use_cache:
# Get the actual cache value form database
if episode_file_id:
cache_key = TableEpisodes.select(TableEpisodes.ffprobe_cache)\
.where(TableEpisodes.path == path_mappings.path_replace_reverse(file))\
.dicts()\
.get_or_none()
cache_key = database.execute(
select(TableEpisodes.ffprobe_cache)
.where(TableEpisodes.path == path_mappings.path_replace_reverse(file))) \
.first()
elif movie_file_id:
cache_key = TableMovies.select(TableMovies.ffprobe_cache)\
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(file))\
.dicts()\
.get_or_none()
cache_key = database.execute(
select(TableMovies.ffprobe_cache)
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(file))) \
.first()
else:
cache_key = None
# check if we have a value for that cache key
try:
# Unpickle ffprobe cache
cached_value = pickle.loads(cache_key['ffprobe_cache'])
cached_value = pickle.loads(cache_key.ffprobe_cache)
except Exception:
pass
else:
@ -144,7 +144,7 @@ def parse_video_metadata(file, file_size, episode_file_id=None, movie_file_id=No
# no valid cache
pass
else:
# cache mut be renewed
# cache must be renewed
pass
# if not, we retrieve the metadata from the file
@ -180,11 +180,13 @@ def parse_video_metadata(file, file_size, episode_file_id=None, movie_file_id=No
# we write to db the result and return the newly cached ffprobe dict
if episode_file_id:
TableEpisodes.update({TableEpisodes.ffprobe_cache: pickle.dumps(data, pickle.HIGHEST_PROTOCOL)})\
.where(TableEpisodes.path == path_mappings.path_replace_reverse(file))\
.execute()
database.execute(
update(TableEpisodes)
.values(ffprobe_cache=pickle.dumps(data, pickle.HIGHEST_PROTOCOL))
.where(TableEpisodes.path == path_mappings.path_replace_reverse(file)))
elif movie_file_id:
TableMovies.update({TableEpisodes.ffprobe_cache: pickle.dumps(data, pickle.HIGHEST_PROTOCOL)})\
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(file))\
.execute()
database.execute(
update(TableMovies)
.values(ffprobe_cache=pickle.dumps(data, pickle.HIGHEST_PROTOCOL))
.where(TableMovies.path == path_mappings.path_replace_reverse_movie(file)))
return data

View File

@ -87,6 +87,14 @@ const Search: FunctionComponent = () => {
value={query}
onChange={setQuery}
onBlur={() => setQuery("")}
filter={(value, item) =>
item.value.toLowerCase().includes(value.toLowerCase().trim()) ||
item.value
.normalize("NFD")
.replace(/[\u0300-\u036f]/g, "")
.toLowerCase()
.includes(value.trim())
}
></Autocomplete>
);
};

View File

@ -0,0 +1,78 @@
import { BuildKey } from "@/utilities";
import {
faCheck,
faCheckCircle,
faExclamationCircle,
faListCheck,
faTimes,
} from "@fortawesome/free-solid-svg-icons";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import { Group, List, Popover, Stack, Text } from "@mantine/core";
import { useHover } from "@mantine/hooks";
import { FunctionComponent } from "react";
interface StateIconProps {
matches: string[];
dont: string[];
isHistory: boolean;
}
const StateIcon: FunctionComponent<StateIconProps> = ({
matches,
dont,
isHistory,
}) => {
const hasIssues = dont.length > 0;
const { hovered, ref } = useHover();
const PopoverTarget: FunctionComponent = () => {
if (isHistory) {
return <FontAwesomeIcon icon={faListCheck} />;
} else {
return (
<Text color={hasIssues ? "yellow" : "green"}>
<FontAwesomeIcon
icon={hasIssues ? faExclamationCircle : faCheckCircle}
/>
</Text>
);
}
};
return (
<Popover opened={hovered} position="top" width={360} withArrow withinPortal>
<Popover.Target>
<Text ref={ref}>
<PopoverTarget />
</Text>
</Popover.Target>
<Popover.Dropdown>
<Group position="left" spacing="xl" noWrap grow>
<Stack align="flex-start" justify="flex-start" spacing="xs" mb="auto">
<Text color="green">
<FontAwesomeIcon icon={faCheck}></FontAwesomeIcon>
</Text>
<List>
{matches.map((v, idx) => (
<List.Item key={BuildKey(idx, v, "match")}>{v}</List.Item>
))}
</List>
</Stack>
<Stack align="flex-start" justify="flex-start" spacing="xs" mb="auto">
<Text color="yellow">
<FontAwesomeIcon icon={faTimes}></FontAwesomeIcon>
</Text>
<List>
{dont.map((v, idx) => (
<List.Item key={BuildKey(idx, v, "miss")}>{v}</List.Item>
))}
</List>
</Stack>
</Group>
</Popover.Dropdown>
</Popover>
);
};
export default StateIcon;

View File

@ -5,6 +5,7 @@ import {
useMovieAddBlacklist,
useMovieHistory,
} from "@/apis/hooks";
import StateIcon from "@/components/StateIcon";
import { withModal } from "@/modules/modals";
import { faFileExcel, faInfoCircle } from "@fortawesome/free-solid-svg-icons";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
@ -62,6 +63,23 @@ const MovieHistoryView: FunctionComponent<MovieHistoryViewProps> = ({
Header: "Score",
accessor: "score",
},
{
accessor: "matches",
Cell: (row) => {
const { matches, dont_matches: dont } = row.row.original;
if (matches.length || dont.length) {
return (
<StateIcon
matches={matches}
dont={dont}
isHistory={true}
></StateIcon>
);
} else {
return null;
}
},
},
{
Header: "Date",
accessor: "timestamp",
@ -168,6 +186,23 @@ const EpisodeHistoryView: FunctionComponent<EpisodeHistoryViewProps> = ({
Header: "Score",
accessor: "score",
},
{
accessor: "matches",
Cell: (row) => {
const { matches, dont_matches: dont } = row.row.original;
if (matches.length || dont.length) {
return (
<StateIcon
matches={matches}
dont={dont}
isHistory={true}
></StateIcon>
);
} else {
return null;
}
},
},
{
Header: "Date",
accessor: "timestamp",

View File

@ -1,15 +1,11 @@
import { withModal } from "@/modules/modals";
import { task, TaskGroup } from "@/modules/task";
import { useTableStyles } from "@/styles";
import { BuildKey, GetItemId } from "@/utilities";
import { GetItemId } from "@/utilities";
import {
faCaretDown,
faCheck,
faCheckCircle,
faDownload,
faExclamationCircle,
faInfoCircle,
faTimes,
} from "@fortawesome/free-solid-svg-icons";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import {
@ -20,19 +16,16 @@ import {
Code,
Collapse,
Divider,
Group,
List,
Popover,
Stack,
Text,
} from "@mantine/core";
import { useHover } from "@mantine/hooks";
import { isString } from "lodash";
import { FunctionComponent, useCallback, useMemo, useState } from "react";
import { useCallback, useMemo, useState } from "react";
import { UseQueryResult } from "react-query";
import { Column } from "react-table";
import { Action, PageTable } from "..";
import Language from "../bazarr/Language";
import StateIcon from "../StateIcon";
type SupportType = Item.Movie | Item.Episode;
@ -155,7 +148,13 @@ function ManualSearchView<T extends SupportType>(props: Props<T>) {
accessor: "matches",
Cell: (row) => {
const { matches, dont_matches: dont } = row.row.original;
return <StateIcon matches={matches} dont={dont}></StateIcon>;
return (
<StateIcon
matches={matches}
dont={dont}
isHistory={false}
></StateIcon>
);
},
},
{
@ -227,48 +226,3 @@ export const EpisodeSearchModal = withModal<Props<Item.Episode>>(
"episode-manual-search",
{ title: "Search Subtitles", size: "calc(100vw - 4rem)" }
);
const StateIcon: FunctionComponent<{ matches: string[]; dont: string[] }> = ({
matches,
dont,
}) => {
const hasIssues = dont.length > 0;
const { ref, hovered } = useHover();
return (
<Popover opened={hovered} position="top" width={360} withArrow withinPortal>
<Popover.Target>
<Text color={hasIssues ? "yellow" : "green"} ref={ref}>
<FontAwesomeIcon
icon={hasIssues ? faExclamationCircle : faCheckCircle}
></FontAwesomeIcon>
</Text>
</Popover.Target>
<Popover.Dropdown>
<Group position="left" spacing="xl" noWrap grow>
<Stack align="flex-start" justify="flex-start" spacing="xs" mb="auto">
<Text color="green">
<FontAwesomeIcon icon={faCheck}></FontAwesomeIcon>
</Text>
<List>
{matches.map((v, idx) => (
<List.Item key={BuildKey(idx, v, "match")}>{v}</List.Item>
))}
</List>
</Stack>
<Stack align="flex-start" justify="flex-start" spacing="xs" mb="auto">
<Text color="yellow">
<FontAwesomeIcon icon={faTimes}></FontAwesomeIcon>
</Text>
<List>
{dont.map((v, idx) => (
<List.Item key={BuildKey(idx, v, "miss")}>{v}</List.Item>
))}
</List>
</Stack>
</Group>
</Popover.Dropdown>
</Popover>
);
};

View File

@ -3,6 +3,7 @@ import { useMovieAddBlacklist, useMovieHistoryPagination } from "@/apis/hooks";
import { MutateAction } from "@/components/async";
import { HistoryIcon } from "@/components/bazarr";
import Language from "@/components/bazarr/Language";
import StateIcon from "@/components/StateIcon";
import TextPopover from "@/components/TextPopover";
import HistoryView from "@/pages/views/HistoryView";
import { useTableStyles } from "@/styles";
@ -56,6 +57,23 @@ const MoviesHistoryView: FunctionComponent = () => {
Header: "Score",
accessor: "score",
},
{
accessor: "matches",
Cell: (row) => {
const { matches, dont_matches: dont } = row.row.original;
if (matches.length || dont.length) {
return (
<StateIcon
matches={matches}
dont={dont}
isHistory={true}
></StateIcon>
);
} else {
return null;
}
},
},
{
Header: "Date",
accessor: "timestamp",

View File

@ -6,6 +6,7 @@ import {
import { MutateAction } from "@/components/async";
import { HistoryIcon } from "@/components/bazarr";
import Language from "@/components/bazarr/Language";
import StateIcon from "@/components/StateIcon";
import TextPopover from "@/components/TextPopover";
import HistoryView from "@/pages/views/HistoryView";
import { useTableStyles } from "@/styles";
@ -72,6 +73,23 @@ const SeriesHistoryView: FunctionComponent = () => {
Header: "Score",
accessor: "score",
},
{
accessor: "matches",
Cell: (row) => {
const { matches, dont_matches: dont } = row.row.original;
if (matches.length || dont.length) {
return (
<StateIcon
matches={matches}
dont={dont}
isHistory={true}
></StateIcon>
);
} else {
return null;
}
},
},
{
Header: "Date",
accessor: "timestamp",

View File

@ -126,7 +126,6 @@ declare namespace Item {
type Series = Base &
SeriesIdType & {
hearing_impaired: boolean;
episodeFileCount: number;
episodeMissingCount: number;
seriesType: SonarrSeriesType;
@ -137,12 +136,7 @@ declare namespace Item {
MovieIdType &
SubtitleType &
MissingSubtitleType &
SceneNameType & {
hearing_impaired: boolean;
audio_codec: string;
// movie_file_id: number;
tmdbId: number;
};
SceneNameType;
type Episode = PathType &
TitleType &
@ -152,13 +146,8 @@ declare namespace Item {
MissingSubtitleType &
SceneNameType &
AudioLanguageType & {
audio_codec: string;
video_codec: string;
season: number;
episode: number;
resolution: string;
format: string;
// episode_file_id: number;
};
}
@ -166,7 +155,6 @@ declare namespace Wanted {
type Base = MonitoredType &
TagType &
SceneNameType & {
// failedAttempts?: any;
hearing_impaired: boolean;
missing_subtitles: Subtitle[];
};
@ -202,16 +190,16 @@ declare namespace History {
TagType &
MonitoredType &
Partial<ItemHistoryType> & {
id: number;
action: number;
blacklisted: boolean;
score?: string;
subs_id?: string;
raw_timestamp: number;
parsed_timestamp: string;
timestamp: string;
description: string;
upgradable: boolean;
matches: string[];
dont_matches: string[];
};
type Movie = History.Base & MovieIdType & TitleType;

6
libs/alembic/__init__.py Normal file
View File

@ -0,0 +1,6 @@
import sys
from . import context
from . import op
__version__ = "1.10.3"

4
libs/alembic/__main__.py Normal file
View File

@ -0,0 +1,4 @@
from .config import main
if __name__ == "__main__":
main(prog="alembic")

View File

@ -0,0 +1,10 @@
from .api import _render_migration_diffs
from .api import compare_metadata
from .api import produce_migrations
from .api import render_python_code
from .api import RevisionContext
from .compare import _produce_net_changes
from .compare import comparators
from .render import render_op_text
from .render import renderers
from .rewriter import Rewriter

View File

@ -0,0 +1,605 @@
from __future__ import annotations
import contextlib
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterator
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import inspect
from . import compare
from . import render
from .. import util
from ..operations import ops
"""Provide the 'autogenerate' feature which can produce migration operations
automatically."""
if TYPE_CHECKING:
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine import Inspector
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.schema import UniqueConstraint
from alembic.config import Config
from alembic.operations.ops import MigrationScript
from alembic.operations.ops import UpgradeOps
from alembic.runtime.migration import MigrationContext
from alembic.script.base import Script
from alembic.script.base import ScriptDirectory
def compare_metadata(context: MigrationContext, metadata: MetaData) -> Any:
"""Compare a database schema to that given in a
:class:`~sqlalchemy.schema.MetaData` instance.
The database connection is presented in the context
of a :class:`.MigrationContext` object, which
provides database connectivity as well as optional
comparison functions to use for datatypes and
server defaults - see the "autogenerate" arguments
at :meth:`.EnvironmentContext.configure`
for details on these.
The return format is a list of "diff" directives,
each representing individual differences::
from alembic.migration import MigrationContext
from alembic.autogenerate import compare_metadata
from sqlalchemy.schema import SchemaItem
from sqlalchemy.types import TypeEngine
from sqlalchemy import (create_engine, MetaData, Column,
Integer, String, Table, text)
import pprint
engine = create_engine("sqlite://")
with engine.begin() as conn:
conn.execute(text('''
create table foo (
id integer not null primary key,
old_data varchar,
x integer
)'''))
conn.execute(text('''
create table bar (
data varchar
)'''))
metadata = MetaData()
Table('foo', metadata,
Column('id', Integer, primary_key=True),
Column('data', Integer),
Column('x', Integer, nullable=False)
)
Table('bat', metadata,
Column('info', String)
)
mc = MigrationContext.configure(engine.connect())
diff = compare_metadata(mc, metadata)
pprint.pprint(diff, indent=2, width=20)
Output::
[ ( 'add_table',
Table('bat', MetaData(bind=None),
Column('info', String(), table=<bat>), schema=None)),
( 'remove_table',
Table(u'bar', MetaData(bind=None),
Column(u'data', VARCHAR(), table=<bar>), schema=None)),
( 'add_column',
None,
'foo',
Column('data', Integer(), table=<foo>)),
( 'remove_column',
None,
'foo',
Column(u'old_data', VARCHAR(), table=None)),
[ ( 'modify_nullable',
None,
'foo',
u'x',
{ 'existing_server_default': None,
'existing_type': INTEGER()},
True,
False)]]
:param context: a :class:`.MigrationContext`
instance.
:param metadata: a :class:`~sqlalchemy.schema.MetaData`
instance.
.. seealso::
:func:`.produce_migrations` - produces a :class:`.MigrationScript`
structure based on metadata comparison.
"""
migration_script = produce_migrations(context, metadata)
return migration_script.upgrade_ops.as_diffs()
def produce_migrations(
context: MigrationContext, metadata: MetaData
) -> MigrationScript:
"""Produce a :class:`.MigrationScript` structure based on schema
comparison.
This function does essentially what :func:`.compare_metadata` does,
but then runs the resulting list of diffs to produce the full
:class:`.MigrationScript` object. For an example of what this looks like,
see the example in :ref:`customizing_revision`.
.. seealso::
:func:`.compare_metadata` - returns more fundamental "diff"
data from comparing a schema.
"""
autogen_context = AutogenContext(context, metadata=metadata)
migration_script = ops.MigrationScript(
rev_id=None,
upgrade_ops=ops.UpgradeOps([]),
downgrade_ops=ops.DowngradeOps([]),
)
compare._populate_migration_script(autogen_context, migration_script)
return migration_script
def render_python_code(
up_or_down_op: UpgradeOps,
sqlalchemy_module_prefix: str = "sa.",
alembic_module_prefix: str = "op.",
render_as_batch: bool = False,
imports: Tuple[str, ...] = (),
render_item: None = None,
migration_context: Optional[MigrationContext] = None,
) -> str:
"""Render Python code given an :class:`.UpgradeOps` or
:class:`.DowngradeOps` object.
This is a convenience function that can be used to test the
autogenerate output of a user-defined :class:`.MigrationScript` structure.
"""
opts = {
"sqlalchemy_module_prefix": sqlalchemy_module_prefix,
"alembic_module_prefix": alembic_module_prefix,
"render_item": render_item,
"render_as_batch": render_as_batch,
}
if migration_context is None:
from ..runtime.migration import MigrationContext
from sqlalchemy.engine.default import DefaultDialect
migration_context = MigrationContext.configure(
dialect=DefaultDialect()
)
autogen_context = AutogenContext(migration_context, opts=opts)
autogen_context.imports = set(imports)
return render._indent(
render._render_cmd_body(up_or_down_op, autogen_context)
)
def _render_migration_diffs(
context: MigrationContext, template_args: Dict[Any, Any]
) -> None:
"""legacy, used by test_autogen_composition at the moment"""
autogen_context = AutogenContext(context)
upgrade_ops = ops.UpgradeOps([])
compare._produce_net_changes(autogen_context, upgrade_ops)
migration_script = ops.MigrationScript(
rev_id=None,
upgrade_ops=upgrade_ops,
downgrade_ops=upgrade_ops.reverse(),
)
render._render_python_into_templatevars(
autogen_context, migration_script, template_args
)
class AutogenContext:
"""Maintains configuration and state that's specific to an
autogenerate operation."""
metadata: Optional[MetaData] = None
"""The :class:`~sqlalchemy.schema.MetaData` object
representing the destination.
This object is the one that is passed within ``env.py``
to the :paramref:`.EnvironmentContext.configure.target_metadata`
parameter. It represents the structure of :class:`.Table` and other
objects as stated in the current database model, and represents the
destination structure for the database being examined.
While the :class:`~sqlalchemy.schema.MetaData` object is primarily
known as a collection of :class:`~sqlalchemy.schema.Table` objects,
it also has an :attr:`~sqlalchemy.schema.MetaData.info` dictionary
that may be used by end-user schemes to store additional schema-level
objects that are to be compared in custom autogeneration schemes.
"""
connection: Optional[Connection] = None
"""The :class:`~sqlalchemy.engine.base.Connection` object currently
connected to the database backend being compared.
This is obtained from the :attr:`.MigrationContext.bind` and is
ultimately set up in the ``env.py`` script.
"""
dialect: Optional[Dialect] = None
"""The :class:`~sqlalchemy.engine.Dialect` object currently in use.
This is normally obtained from the
:attr:`~sqlalchemy.engine.base.Connection.dialect` attribute.
"""
imports: Set[str] = None # type: ignore[assignment]
"""A ``set()`` which contains string Python import directives.
The directives are to be rendered into the ``${imports}`` section
of a script template. The set is normally empty and can be modified
within hooks such as the
:paramref:`.EnvironmentContext.configure.render_item` hook.
.. seealso::
:ref:`autogen_render_types`
"""
migration_context: MigrationContext = None # type: ignore[assignment]
"""The :class:`.MigrationContext` established by the ``env.py`` script."""
def __init__(
self,
migration_context: MigrationContext,
metadata: Optional[MetaData] = None,
opts: Optional[dict] = None,
autogenerate: bool = True,
) -> None:
if (
autogenerate
and migration_context is not None
and migration_context.as_sql
):
raise util.CommandError(
"autogenerate can't use as_sql=True as it prevents querying "
"the database for schema information"
)
if opts is None:
opts = migration_context.opts
self.metadata = metadata = (
opts.get("target_metadata", None) if metadata is None else metadata
)
if (
autogenerate
and metadata is None
and migration_context is not None
and migration_context.script is not None
):
raise util.CommandError(
"Can't proceed with --autogenerate option; environment "
"script %s does not provide "
"a MetaData object or sequence of objects to the context."
% (migration_context.script.env_py_location)
)
include_object = opts.get("include_object", None)
include_name = opts.get("include_name", None)
object_filters = []
name_filters = []
if include_object:
object_filters.append(include_object)
if include_name:
name_filters.append(include_name)
self._object_filters = object_filters
self._name_filters = name_filters
self.migration_context = migration_context
if self.migration_context is not None:
self.connection = self.migration_context.bind
self.dialect = self.migration_context.dialect
self.imports = set()
self.opts: Dict[str, Any] = opts
self._has_batch: bool = False
@util.memoized_property
def inspector(self) -> Inspector:
if self.connection is None:
raise TypeError(
"can't return inspector as this "
"AutogenContext has no database connection"
)
return inspect(self.connection)
@contextlib.contextmanager
def _within_batch(self) -> Iterator[None]:
self._has_batch = True
yield
self._has_batch = False
def run_name_filters(
self,
name: Optional[str],
type_: str,
parent_names: Dict[str, Optional[str]],
) -> bool:
"""Run the context's name filters and return True if the targets
should be part of the autogenerate operation.
This method should be run for every kind of name encountered within the
reflection side of an autogenerate operation, giving the environment
the chance to filter what names should be reflected as database
objects. The filters here are produced directly via the
:paramref:`.EnvironmentContext.configure.include_name` parameter.
"""
if "schema_name" in parent_names:
if type_ == "table":
table_name = name
else:
table_name = parent_names.get("table_name", None)
if table_name:
schema_name = parent_names["schema_name"]
if schema_name:
parent_names["schema_qualified_table_name"] = "%s.%s" % (
schema_name,
table_name,
)
else:
parent_names["schema_qualified_table_name"] = table_name
for fn in self._name_filters:
if not fn(name, type_, parent_names):
return False
else:
return True
def run_object_filters(
self,
object_: Union[
Table,
Index,
Column,
UniqueConstraint,
ForeignKeyConstraint,
],
name: Optional[str],
type_: str,
reflected: bool,
compare_to: Optional[Union[Table, Index, Column, UniqueConstraint]],
) -> bool:
"""Run the context's object filters and return True if the targets
should be part of the autogenerate operation.
This method should be run for every kind of object encountered within
an autogenerate operation, giving the environment the chance
to filter what objects should be included in the comparison.
The filters here are produced directly via the
:paramref:`.EnvironmentContext.configure.include_object` parameter.
"""
for fn in self._object_filters:
if not fn(object_, name, type_, reflected, compare_to):
return False
else:
return True
run_filters = run_object_filters
@util.memoized_property
def sorted_tables(self):
"""Return an aggregate of the :attr:`.MetaData.sorted_tables`
collection(s).
For a sequence of :class:`.MetaData` objects, this
concatenates the :attr:`.MetaData.sorted_tables` collection
for each individual :class:`.MetaData` in the order of the
sequence. It does **not** collate the sorted tables collections.
"""
result = []
for m in util.to_list(self.metadata):
result.extend(m.sorted_tables)
return result
@util.memoized_property
def table_key_to_table(self):
"""Return an aggregate of the :attr:`.MetaData.tables` dictionaries.
The :attr:`.MetaData.tables` collection is a dictionary of table key
to :class:`.Table`; this method aggregates the dictionary across
multiple :class:`.MetaData` objects into one dictionary.
Duplicate table keys are **not** supported; if two :class:`.MetaData`
objects contain the same table key, an exception is raised.
"""
result = {}
for m in util.to_list(self.metadata):
intersect = set(result).intersection(set(m.tables))
if intersect:
raise ValueError(
"Duplicate table keys across multiple "
"MetaData objects: %s"
% (", ".join('"%s"' % key for key in sorted(intersect)))
)
result.update(m.tables)
return result
class RevisionContext:
"""Maintains configuration and state that's specific to a revision
file generation operation."""
def __init__(
self,
config: Config,
script_directory: ScriptDirectory,
command_args: Dict[str, Any],
process_revision_directives: Optional[Callable] = None,
) -> None:
self.config = config
self.script_directory = script_directory
self.command_args = command_args
self.process_revision_directives = process_revision_directives
self.template_args = {
"config": config # Let templates use config for
# e.g. multiple databases
}
self.generated_revisions = [self._default_revision()]
def _to_script(
self, migration_script: MigrationScript
) -> Optional[Script]:
template_args: Dict[str, Any] = self.template_args.copy()
if getattr(migration_script, "_needs_render", False):
autogen_context = self._last_autogen_context
# clear out existing imports if we are doing multiple
# renders
autogen_context.imports = set()
if migration_script.imports:
autogen_context.imports.update(migration_script.imports)
render._render_python_into_templatevars(
autogen_context, migration_script, template_args
)
assert migration_script.rev_id is not None
return self.script_directory.generate_revision(
migration_script.rev_id,
migration_script.message,
refresh=True,
head=migration_script.head,
splice=migration_script.splice,
branch_labels=migration_script.branch_label,
version_path=migration_script.version_path,
depends_on=migration_script.depends_on,
**template_args,
)
def run_autogenerate(
self, rev: tuple, migration_context: MigrationContext
) -> None:
self._run_environment(rev, migration_context, True)
def run_no_autogenerate(
self, rev: tuple, migration_context: MigrationContext
) -> None:
self._run_environment(rev, migration_context, False)
def _run_environment(
self,
rev: tuple,
migration_context: MigrationContext,
autogenerate: bool,
) -> None:
if autogenerate:
if self.command_args["sql"]:
raise util.CommandError(
"Using --sql with --autogenerate does not make any sense"
)
if set(self.script_directory.get_revisions(rev)) != set(
self.script_directory.get_revisions("heads")
):
raise util.CommandError("Target database is not up to date.")
upgrade_token = migration_context.opts["upgrade_token"]
downgrade_token = migration_context.opts["downgrade_token"]
migration_script = self.generated_revisions[-1]
if not getattr(migration_script, "_needs_render", False):
migration_script.upgrade_ops_list[-1].upgrade_token = upgrade_token
migration_script.downgrade_ops_list[
-1
].downgrade_token = downgrade_token
migration_script._needs_render = True
else:
migration_script._upgrade_ops.append(
ops.UpgradeOps([], upgrade_token=upgrade_token)
)
migration_script._downgrade_ops.append(
ops.DowngradeOps([], downgrade_token=downgrade_token)
)
autogen_context = AutogenContext(
migration_context, autogenerate=autogenerate
)
self._last_autogen_context: AutogenContext = autogen_context
if autogenerate:
compare._populate_migration_script(
autogen_context, migration_script
)
if self.process_revision_directives:
self.process_revision_directives(
migration_context, rev, self.generated_revisions
)
hook = migration_context.opts["process_revision_directives"]
if hook:
hook(migration_context, rev, self.generated_revisions)
for migration_script in self.generated_revisions:
migration_script._needs_render = True
def _default_revision(self) -> MigrationScript:
command_args: Dict[str, Any] = self.command_args
op = ops.MigrationScript(
rev_id=command_args["rev_id"] or util.rev_id(),
message=command_args["message"],
upgrade_ops=ops.UpgradeOps([]),
downgrade_ops=ops.DowngradeOps([]),
head=command_args["head"],
splice=command_args["splice"],
branch_label=command_args["branch_label"],
version_path=command_args["version_path"],
depends_on=command_args["depends_on"],
)
return op
def generate_scripts(self) -> Iterator[Optional[Script]]:
for generated_revision in self.generated_revisions:
yield self._to_script(generated_revision)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,227 @@
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from alembic import util
from alembic.operations import ops
if TYPE_CHECKING:
from alembic.operations.ops import AddColumnOp
from alembic.operations.ops import AlterColumnOp
from alembic.operations.ops import CreateTableOp
from alembic.operations.ops import MigrateOperation
from alembic.operations.ops import MigrationScript
from alembic.operations.ops import ModifyTableOps
from alembic.operations.ops import OpContainer
from alembic.runtime.migration import MigrationContext
from alembic.script.revision import Revision
class Rewriter:
"""A helper object that allows easy 'rewriting' of ops streams.
The :class:`.Rewriter` object is intended to be passed along
to the
:paramref:`.EnvironmentContext.configure.process_revision_directives`
parameter in an ``env.py`` script. Once constructed, any number
of "rewrites" functions can be associated with it, which will be given
the opportunity to modify the structure without having to have explicit
knowledge of the overall structure.
The function is passed the :class:`.MigrationContext` object and
``revision`` tuple that are passed to the :paramref:`.Environment
Context.configure.process_revision_directives` function normally,
and the third argument is an individual directive of the type
noted in the decorator. The function has the choice of returning
a single op directive, which normally can be the directive that
was actually passed, or a new directive to replace it, or a list
of zero or more directives to replace it.
.. seealso::
:ref:`autogen_rewriter` - usage example
"""
_traverse = util.Dispatcher()
_chained: Optional[Rewriter] = None
def __init__(self) -> None:
self.dispatch = util.Dispatcher()
def chain(self, other: Rewriter) -> Rewriter:
"""Produce a "chain" of this :class:`.Rewriter` to another.
This allows two rewriters to operate serially on a stream,
e.g.::
writer1 = autogenerate.Rewriter()
writer2 = autogenerate.Rewriter()
@writer1.rewrites(ops.AddColumnOp)
def add_column_nullable(context, revision, op):
op.column.nullable = True
return op
@writer2.rewrites(ops.AddColumnOp)
def add_column_idx(context, revision, op):
idx_op = ops.CreateIndexOp(
'ixc', op.table_name, [op.column.name])
return [
op,
idx_op
]
writer = writer1.chain(writer2)
:param other: a :class:`.Rewriter` instance
:return: a new :class:`.Rewriter` that will run the operations
of this writer, then the "other" writer, in succession.
"""
wr = self.__class__.__new__(self.__class__)
wr.__dict__.update(self.__dict__)
wr._chained = other
return wr
def rewrites(
self,
operator: Union[
Type[AddColumnOp],
Type[MigrateOperation],
Type[AlterColumnOp],
Type[CreateTableOp],
Type[ModifyTableOps],
],
) -> Callable:
"""Register a function as rewriter for a given type.
The function should receive three arguments, which are
the :class:`.MigrationContext`, a ``revision`` tuple, and
an op directive of the type indicated. E.g.::
@writer1.rewrites(ops.AddColumnOp)
def add_column_nullable(context, revision, op):
op.column.nullable = True
return op
"""
return self.dispatch.dispatch_for(operator)
def _rewrite(
self,
context: MigrationContext,
revision: Revision,
directive: MigrateOperation,
) -> Iterator[MigrateOperation]:
try:
_rewriter = self.dispatch.dispatch(directive)
except ValueError:
_rewriter = None
yield directive
else:
if self in directive._mutations:
yield directive
else:
for r_directive in util.to_list(
_rewriter(context, revision, directive), []
):
r_directive._mutations = r_directive._mutations.union(
[self]
)
yield r_directive
def __call__(
self,
context: MigrationContext,
revision: Revision,
directives: List[MigrationScript],
) -> None:
self.process_revision_directives(context, revision, directives)
if self._chained:
self._chained(context, revision, directives)
@_traverse.dispatch_for(ops.MigrationScript)
def _traverse_script(
self,
context: MigrationContext,
revision: Revision,
directive: MigrationScript,
) -> None:
upgrade_ops_list = []
for upgrade_ops in directive.upgrade_ops_list:
ret = self._traverse_for(context, revision, upgrade_ops)
if len(ret) != 1:
raise ValueError(
"Can only return single object for UpgradeOps traverse"
)
upgrade_ops_list.append(ret[0])
directive.upgrade_ops = upgrade_ops_list
downgrade_ops_list = []
for downgrade_ops in directive.downgrade_ops_list:
ret = self._traverse_for(context, revision, downgrade_ops)
if len(ret) != 1:
raise ValueError(
"Can only return single object for DowngradeOps traverse"
)
downgrade_ops_list.append(ret[0])
directive.downgrade_ops = downgrade_ops_list
@_traverse.dispatch_for(ops.OpContainer)
def _traverse_op_container(
self,
context: MigrationContext,
revision: Revision,
directive: OpContainer,
) -> None:
self._traverse_list(context, revision, directive.ops)
@_traverse.dispatch_for(ops.MigrateOperation)
def _traverse_any_directive(
self,
context: MigrationContext,
revision: Revision,
directive: MigrateOperation,
) -> None:
pass
def _traverse_for(
self,
context: MigrationContext,
revision: Revision,
directive: MigrateOperation,
) -> Any:
directives = list(self._rewrite(context, revision, directive))
for directive in directives:
traverser = self._traverse.dispatch(directive)
traverser(self, context, revision, directive)
return directives
def _traverse_list(
self,
context: MigrationContext,
revision: Revision,
directives: Any,
) -> None:
dest = []
for directive in directives:
dest.extend(self._traverse_for(context, revision, directive))
directives[:] = dest
def process_revision_directives(
self,
context: MigrationContext,
revision: Revision,
directives: List[MigrationScript],
) -> None:
self._traverse_list(context, revision, directives)

730
libs/alembic/command.py Normal file
View File

@ -0,0 +1,730 @@
from __future__ import annotations
import os
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from . import autogenerate as autogen
from . import util
from .runtime.environment import EnvironmentContext
from .script import ScriptDirectory
if TYPE_CHECKING:
from alembic.config import Config
from alembic.script.base import Script
from .runtime.environment import ProcessRevisionDirectiveFn
def list_templates(config):
"""List available templates.
:param config: a :class:`.Config` object.
"""
config.print_stdout("Available templates:\n")
for tempname in os.listdir(config.get_template_directory()):
with open(
os.path.join(config.get_template_directory(), tempname, "README")
) as readme:
synopsis = next(readme).rstrip()
config.print_stdout("%s - %s", tempname, synopsis)
config.print_stdout("\nTemplates are used via the 'init' command, e.g.:")
config.print_stdout("\n alembic init --template generic ./scripts")
def init(
config: Config,
directory: str,
template: str = "generic",
package: bool = False,
) -> None:
"""Initialize a new scripts directory.
:param config: a :class:`.Config` object.
:param directory: string path of the target directory
:param template: string name of the migration environment template to
use.
:param package: when True, write ``__init__.py`` files into the
environment location as well as the versions/ location.
.. versionadded:: 1.2
"""
if os.access(directory, os.F_OK) and os.listdir(directory):
raise util.CommandError(
"Directory %s already exists and is not empty" % directory
)
template_dir = os.path.join(config.get_template_directory(), template)
if not os.access(template_dir, os.F_OK):
raise util.CommandError("No such template %r" % template)
if not os.access(directory, os.F_OK):
util.status(
"Creating directory %s" % os.path.abspath(directory),
os.makedirs,
directory,
)
versions = os.path.join(directory, "versions")
util.status(
"Creating directory %s" % os.path.abspath(versions),
os.makedirs,
versions,
)
script = ScriptDirectory(directory)
for file_ in os.listdir(template_dir):
file_path = os.path.join(template_dir, file_)
if file_ == "alembic.ini.mako":
assert config.config_file_name is not None
config_file = os.path.abspath(config.config_file_name)
if os.access(config_file, os.F_OK):
util.msg("File %s already exists, skipping" % config_file)
else:
script._generate_template(
file_path, config_file, script_location=directory
)
elif os.path.isfile(file_path):
output_file = os.path.join(directory, file_)
script._copy_file(file_path, output_file)
if package:
for path in [
os.path.join(os.path.abspath(directory), "__init__.py"),
os.path.join(os.path.abspath(versions), "__init__.py"),
]:
file_ = util.status("Adding %s" % path, open, path, "w")
file_.close() # type:ignore[attr-defined]
util.msg(
"Please edit configuration/connection/logging "
"settings in %r before proceeding." % config_file
)
def revision(
config: Config,
message: Optional[str] = None,
autogenerate: bool = False,
sql: bool = False,
head: str = "head",
splice: bool = False,
branch_label: Optional[str] = None,
version_path: Optional[str] = None,
rev_id: Optional[str] = None,
depends_on: Optional[str] = None,
process_revision_directives: Optional[ProcessRevisionDirectiveFn] = None,
) -> Union[Optional[Script], List[Optional[Script]]]:
"""Create a new revision file.
:param config: a :class:`.Config` object.
:param message: string message to apply to the revision; this is the
``-m`` option to ``alembic revision``.
:param autogenerate: whether or not to autogenerate the script from
the database; this is the ``--autogenerate`` option to
``alembic revision``.
:param sql: whether to dump the script out as a SQL string; when specified,
the script is dumped to stdout. This is the ``--sql`` option to
``alembic revision``.
:param head: head revision to build the new revision upon as a parent;
this is the ``--head`` option to ``alembic revision``.
:param splice: whether or not the new revision should be made into a
new head of its own; is required when the given ``head`` is not itself
a head. This is the ``--splice`` option to ``alembic revision``.
:param branch_label: string label to apply to the branch; this is the
``--branch-label`` option to ``alembic revision``.
:param version_path: string symbol identifying a specific version path
from the configuration; this is the ``--version-path`` option to
``alembic revision``.
:param rev_id: optional revision identifier to use instead of having
one generated; this is the ``--rev-id`` option to ``alembic revision``.
:param depends_on: optional list of "depends on" identifiers; this is the
``--depends-on`` option to ``alembic revision``.
:param process_revision_directives: this is a callable that takes the
same form as the callable described at
:paramref:`.EnvironmentContext.configure.process_revision_directives`;
will be applied to the structure generated by the revision process
where it can be altered programmatically. Note that unlike all
the other parameters, this option is only available via programmatic
use of :func:`.command.revision`
"""
script_directory = ScriptDirectory.from_config(config)
command_args = dict(
message=message,
autogenerate=autogenerate,
sql=sql,
head=head,
splice=splice,
branch_label=branch_label,
version_path=version_path,
rev_id=rev_id,
depends_on=depends_on,
)
revision_context = autogen.RevisionContext(
config,
script_directory,
command_args,
process_revision_directives=process_revision_directives,
)
environment = util.asbool(config.get_main_option("revision_environment"))
if autogenerate:
environment = True
if sql:
raise util.CommandError(
"Using --sql with --autogenerate does not make any sense"
)
def retrieve_migrations(rev, context):
revision_context.run_autogenerate(rev, context)
return []
elif environment:
def retrieve_migrations(rev, context):
revision_context.run_no_autogenerate(rev, context)
return []
elif sql:
raise util.CommandError(
"Using --sql with the revision command when "
"revision_environment is not configured does not make any sense"
)
if environment:
with EnvironmentContext(
config,
script_directory,
fn=retrieve_migrations,
as_sql=sql,
template_args=revision_context.template_args,
revision_context=revision_context,
):
script_directory.run_env()
# the revision_context now has MigrationScript structure(s) present.
# these could theoretically be further processed / rewritten *here*,
# in addition to the hooks present within each run_migrations() call,
# or at the end of env.py run_migrations_online().
scripts = [script for script in revision_context.generate_scripts()]
if len(scripts) == 1:
return scripts[0]
else:
return scripts
def check(
config: "Config",
) -> None:
"""Check if revision command with autogenerate has pending upgrade ops.
:param config: a :class:`.Config` object.
.. versionadded:: 1.9.0
"""
script_directory = ScriptDirectory.from_config(config)
command_args = dict(
message=None,
autogenerate=True,
sql=False,
head="head",
splice=False,
branch_label=None,
version_path=None,
rev_id=None,
depends_on=None,
)
revision_context = autogen.RevisionContext(
config,
script_directory,
command_args,
)
def retrieve_migrations(rev, context):
revision_context.run_autogenerate(rev, context)
return []
with EnvironmentContext(
config,
script_directory,
fn=retrieve_migrations,
as_sql=False,
template_args=revision_context.template_args,
revision_context=revision_context,
):
script_directory.run_env()
# the revision_context now has MigrationScript structure(s) present.
migration_script = revision_context.generated_revisions[-1]
diffs = migration_script.upgrade_ops.as_diffs()
if diffs:
raise util.AutogenerateDiffsDetected(
f"New upgrade operations detected: {diffs}"
)
else:
config.print_stdout("No new upgrade operations detected.")
def merge(
config: Config,
revisions: str,
message: Optional[str] = None,
branch_label: Optional[str] = None,
rev_id: Optional[str] = None,
) -> Optional[Script]:
"""Merge two revisions together. Creates a new migration file.
:param config: a :class:`.Config` instance
:param message: string message to apply to the revision
:param branch_label: string label name to apply to the new revision
:param rev_id: hardcoded revision identifier instead of generating a new
one.
.. seealso::
:ref:`branches`
"""
script = ScriptDirectory.from_config(config)
template_args = {
"config": config # Let templates use config for
# e.g. multiple databases
}
return script.generate_revision(
rev_id or util.rev_id(),
message,
refresh=True,
head=revisions,
branch_labels=branch_label,
**template_args, # type:ignore[arg-type]
)
def upgrade(
config: Config,
revision: str,
sql: bool = False,
tag: Optional[str] = None,
) -> None:
"""Upgrade to a later version.
:param config: a :class:`.Config` instance.
:param revision: string revision target or range for --sql mode
:param sql: if True, use ``--sql`` mode
:param tag: an arbitrary "tag" that can be intercepted by custom
``env.py`` scripts via the :meth:`.EnvironmentContext.get_tag_argument`
method.
"""
script = ScriptDirectory.from_config(config)
starting_rev = None
if ":" in revision:
if not sql:
raise util.CommandError("Range revision not allowed")
starting_rev, revision = revision.split(":", 2)
def upgrade(rev, context):
return script._upgrade_revs(revision, rev)
with EnvironmentContext(
config,
script,
fn=upgrade,
as_sql=sql,
starting_rev=starting_rev,
destination_rev=revision,
tag=tag,
):
script.run_env()
def downgrade(
config: Config,
revision: str,
sql: bool = False,
tag: Optional[str] = None,
) -> None:
"""Revert to a previous version.
:param config: a :class:`.Config` instance.
:param revision: string revision target or range for --sql mode
:param sql: if True, use ``--sql`` mode
:param tag: an arbitrary "tag" that can be intercepted by custom
``env.py`` scripts via the :meth:`.EnvironmentContext.get_tag_argument`
method.
"""
script = ScriptDirectory.from_config(config)
starting_rev = None
if ":" in revision:
if not sql:
raise util.CommandError("Range revision not allowed")
starting_rev, revision = revision.split(":", 2)
elif sql:
raise util.CommandError(
"downgrade with --sql requires <fromrev>:<torev>"
)
def downgrade(rev, context):
return script._downgrade_revs(revision, rev)
with EnvironmentContext(
config,
script,
fn=downgrade,
as_sql=sql,
starting_rev=starting_rev,
destination_rev=revision,
tag=tag,
):
script.run_env()
def show(config, rev):
"""Show the revision(s) denoted by the given symbol.
:param config: a :class:`.Config` instance.
:param revision: string revision target
"""
script = ScriptDirectory.from_config(config)
if rev == "current":
def show_current(rev, context):
for sc in script.get_revisions(rev):
config.print_stdout(sc.log_entry)
return []
with EnvironmentContext(config, script, fn=show_current):
script.run_env()
else:
for sc in script.get_revisions(rev):
config.print_stdout(sc.log_entry)
def history(
config: Config,
rev_range: Optional[str] = None,
verbose: bool = False,
indicate_current: bool = False,
) -> None:
"""List changeset scripts in chronological order.
:param config: a :class:`.Config` instance.
:param rev_range: string revision range
:param verbose: output in verbose mode.
:param indicate_current: indicate current revision.
"""
base: Optional[str]
head: Optional[str]
script = ScriptDirectory.from_config(config)
if rev_range is not None:
if ":" not in rev_range:
raise util.CommandError(
"History range requires [start]:[end], " "[start]:, or :[end]"
)
base, head = rev_range.strip().split(":")
else:
base = head = None
environment = (
util.asbool(config.get_main_option("revision_environment"))
or indicate_current
)
def _display_history(config, script, base, head, currents=()):
for sc in script.walk_revisions(
base=base or "base", head=head or "heads"
):
if indicate_current:
sc._db_current_indicator = sc.revision in currents
config.print_stdout(
sc.cmd_format(
verbose=verbose,
include_branches=True,
include_doc=True,
include_parents=True,
)
)
def _display_history_w_current(config, script, base, head):
def _display_current_history(rev, context):
if head == "current":
_display_history(config, script, base, rev, rev)
elif base == "current":
_display_history(config, script, rev, head, rev)
else:
_display_history(config, script, base, head, rev)
return []
with EnvironmentContext(config, script, fn=_display_current_history):
script.run_env()
if base == "current" or head == "current" or environment:
_display_history_w_current(config, script, base, head)
else:
_display_history(config, script, base, head)
def heads(config, verbose=False, resolve_dependencies=False):
"""Show current available heads in the script directory.
:param config: a :class:`.Config` instance.
:param verbose: output in verbose mode.
:param resolve_dependencies: treat dependency version as down revisions.
"""
script = ScriptDirectory.from_config(config)
if resolve_dependencies:
heads = script.get_revisions("heads")
else:
heads = script.get_revisions(script.get_heads())
for rev in heads:
config.print_stdout(
rev.cmd_format(
verbose, include_branches=True, tree_indicators=False
)
)
def branches(config, verbose=False):
"""Show current branch points.
:param config: a :class:`.Config` instance.
:param verbose: output in verbose mode.
"""
script = ScriptDirectory.from_config(config)
for sc in script.walk_revisions():
if sc.is_branch_point:
config.print_stdout(
"%s\n%s\n",
sc.cmd_format(verbose, include_branches=True),
"\n".join(
"%s -> %s"
% (
" " * len(str(sc.revision)),
rev_obj.cmd_format(
False, include_branches=True, include_doc=verbose
),
)
for rev_obj in (
script.get_revision(rev) for rev in sc.nextrev
)
),
)
def current(config: Config, verbose: bool = False) -> None:
"""Display the current revision for a database.
:param config: a :class:`.Config` instance.
:param verbose: output in verbose mode.
"""
script = ScriptDirectory.from_config(config)
def display_version(rev, context):
if verbose:
config.print_stdout(
"Current revision(s) for %s:",
util.obfuscate_url_pw(context.connection.engine.url),
)
for rev in script.get_all_current(rev):
config.print_stdout(rev.cmd_format(verbose))
return []
with EnvironmentContext(
config, script, fn=display_version, dont_mutate=True
):
script.run_env()
def stamp(
config: Config,
revision: str,
sql: bool = False,
tag: Optional[str] = None,
purge: bool = False,
) -> None:
"""'stamp' the revision table with the given revision; don't
run any migrations.
:param config: a :class:`.Config` instance.
:param revision: target revision or list of revisions. May be a list
to indicate stamping of multiple branch heads.
.. note:: this parameter is called "revisions" in the command line
interface.
.. versionchanged:: 1.2 The revision may be a single revision or
list of revisions when stamping multiple branch heads.
:param sql: use ``--sql`` mode
:param tag: an arbitrary "tag" that can be intercepted by custom
``env.py`` scripts via the :class:`.EnvironmentContext.get_tag_argument`
method.
:param purge: delete all entries in the version table before stamping.
.. versionadded:: 1.2
"""
script = ScriptDirectory.from_config(config)
if sql:
destination_revs = []
starting_rev = None
for _revision in util.to_list(revision):
if ":" in _revision:
srev, _revision = _revision.split(":", 2)
if starting_rev != srev:
if starting_rev is None:
starting_rev = srev
else:
raise util.CommandError(
"Stamp operation with --sql only supports a "
"single starting revision at a time"
)
destination_revs.append(_revision)
else:
destination_revs = util.to_list(revision)
def do_stamp(rev, context):
return script._stamp_revs(util.to_tuple(destination_revs), rev)
with EnvironmentContext(
config,
script,
fn=do_stamp,
as_sql=sql,
starting_rev=starting_rev if sql else None,
destination_rev=util.to_tuple(destination_revs),
tag=tag,
purge=purge,
):
script.run_env()
def edit(config: Config, rev: str) -> None:
"""Edit revision script(s) using $EDITOR.
:param config: a :class:`.Config` instance.
:param rev: target revision.
"""
script = ScriptDirectory.from_config(config)
if rev == "current":
def edit_current(rev, context):
if not rev:
raise util.CommandError("No current revisions")
for sc in script.get_revisions(rev):
util.open_in_editor(sc.path)
return []
with EnvironmentContext(config, script, fn=edit_current):
script.run_env()
else:
revs = script.get_revisions(rev)
if not revs:
raise util.CommandError(
"No revision files indicated by symbol '%s'" % rev
)
for sc in revs:
assert sc
util.open_in_editor(sc.path)
def ensure_version(config: Config, sql: bool = False) -> None:
"""Create the alembic version table if it doesn't exist already .
:param config: a :class:`.Config` instance.
:param sql: use ``--sql`` mode
.. versionadded:: 1.7.6
"""
script = ScriptDirectory.from_config(config)
def do_ensure_version(rev, context):
context._ensure_version_table()
return []
with EnvironmentContext(
config,
script,
fn=do_ensure_version,
as_sql=sql,
):
script.run_env()

595
libs/alembic/config.py Normal file
View File

@ -0,0 +1,595 @@
from __future__ import annotations
from argparse import ArgumentParser
from argparse import Namespace
from configparser import ConfigParser
import inspect
import os
import sys
from typing import Dict
from typing import Optional
from typing import overload
from typing import TextIO
from typing import Union
from . import __version__
from . import command
from . import util
from .util import compat
class Config:
r"""Represent an Alembic configuration.
Within an ``env.py`` script, this is available
via the :attr:`.EnvironmentContext.config` attribute,
which in turn is available at ``alembic.context``::
from alembic import context
some_param = context.config.get_main_option("my option")
When invoking Alembic programatically, a new
:class:`.Config` can be created by passing
the name of an .ini file to the constructor::
from alembic.config import Config
alembic_cfg = Config("/path/to/yourapp/alembic.ini")
With a :class:`.Config` object, you can then
run Alembic commands programmatically using the directives
in :mod:`alembic.command`.
The :class:`.Config` object can also be constructed without
a filename. Values can be set programmatically, and
new sections will be created as needed::
from alembic.config import Config
alembic_cfg = Config()
alembic_cfg.set_main_option("script_location", "myapp:migrations")
alembic_cfg.set_main_option("sqlalchemy.url", "postgresql://foo/bar")
alembic_cfg.set_section_option("mysection", "foo", "bar")
.. warning::
When using programmatic configuration, make sure the
``env.py`` file in use is compatible with the target configuration;
including that the call to Python ``logging.fileConfig()`` is
omitted if the programmatic configuration doesn't actually include
logging directives.
For passing non-string values to environments, such as connections and
engines, use the :attr:`.Config.attributes` dictionary::
with engine.begin() as connection:
alembic_cfg.attributes['connection'] = connection
command.upgrade(alembic_cfg, "head")
:param file\_: name of the .ini file to open.
:param ini_section: name of the main Alembic section within the
.ini file
:param output_buffer: optional file-like input buffer which
will be passed to the :class:`.MigrationContext` - used to redirect
the output of "offline generation" when using Alembic programmatically.
:param stdout: buffer where the "print" output of commands will be sent.
Defaults to ``sys.stdout``.
:param config_args: A dictionary of keys and values that will be used
for substitution in the alembic config file. The dictionary as given
is **copied** to a new one, stored locally as the attribute
``.config_args``. When the :attr:`.Config.file_config` attribute is
first invoked, the replacement variable ``here`` will be added to this
dictionary before the dictionary is passed to ``ConfigParser()``
to parse the .ini file.
:param attributes: optional dictionary of arbitrary Python keys/values,
which will be populated into the :attr:`.Config.attributes` dictionary.
.. seealso::
:ref:`connection_sharing`
"""
def __init__(
self,
file_: Union[str, os.PathLike[str], None] = None,
ini_section: str = "alembic",
output_buffer: Optional[TextIO] = None,
stdout: TextIO = sys.stdout,
cmd_opts: Optional[Namespace] = None,
config_args: util.immutabledict = util.immutabledict(),
attributes: Optional[dict] = None,
) -> None:
"""Construct a new :class:`.Config`"""
self.config_file_name = file_
self.config_ini_section = ini_section
self.output_buffer = output_buffer
self.stdout = stdout
self.cmd_opts = cmd_opts
self.config_args = dict(config_args)
if attributes:
self.attributes.update(attributes)
cmd_opts: Optional[Namespace] = None
"""The command-line options passed to the ``alembic`` script.
Within an ``env.py`` script this can be accessed via the
:attr:`.EnvironmentContext.config` attribute.
.. seealso::
:meth:`.EnvironmentContext.get_x_argument`
"""
config_file_name: Union[str, os.PathLike[str], None] = None
"""Filesystem path to the .ini file in use."""
config_ini_section: str = None # type:ignore[assignment]
"""Name of the config file section to read basic configuration
from. Defaults to ``alembic``, that is the ``[alembic]`` section
of the .ini file. This value is modified using the ``-n/--name``
option to the Alembic runner.
"""
@util.memoized_property
def attributes(self):
"""A Python dictionary for storage of additional state.
This is a utility dictionary which can include not just strings but
engines, connections, schema objects, or anything else.
Use this to pass objects into an env.py script, such as passing
a :class:`sqlalchemy.engine.base.Connection` when calling
commands from :mod:`alembic.command` programmatically.
.. seealso::
:ref:`connection_sharing`
:paramref:`.Config.attributes`
"""
return {}
def print_stdout(self, text: str, *arg) -> None:
"""Render a message to standard out.
When :meth:`.Config.print_stdout` is called with additional args
those arguments will formatted against the provided text,
otherwise we simply output the provided text verbatim.
e.g.::
>>> config.print_stdout('Some text %s', 'arg')
Some Text arg
"""
if arg:
output = str(text) % arg
else:
output = str(text)
util.write_outstream(self.stdout, output, "\n")
@util.memoized_property
def file_config(self):
"""Return the underlying ``ConfigParser`` object.
Direct access to the .ini file is available here,
though the :meth:`.Config.get_section` and
:meth:`.Config.get_main_option`
methods provide a possibly simpler interface.
"""
if self.config_file_name:
here = os.path.abspath(os.path.dirname(self.config_file_name))
else:
here = ""
self.config_args["here"] = here
file_config = ConfigParser(self.config_args)
if self.config_file_name:
file_config.read([self.config_file_name])
else:
file_config.add_section(self.config_ini_section)
return file_config
def get_template_directory(self) -> str:
"""Return the directory where Alembic setup templates are found.
This method is used by the alembic ``init`` and ``list_templates``
commands.
"""
import alembic
package_dir = os.path.abspath(os.path.dirname(alembic.__file__))
return os.path.join(package_dir, "templates")
@overload
def get_section(
self, name: str, default: Dict[str, str]
) -> Dict[str, str]:
...
@overload
def get_section(
self, name: str, default: Optional[Dict[str, str]] = ...
) -> Optional[Dict[str, str]]:
...
def get_section(self, name: str, default=None):
"""Return all the configuration options from a given .ini file section
as a dictionary.
"""
if not self.file_config.has_section(name):
return default
return dict(self.file_config.items(name))
def set_main_option(self, name: str, value: str) -> None:
"""Set an option programmatically within the 'main' section.
This overrides whatever was in the .ini file.
:param name: name of the value
:param value: the value. Note that this value is passed to
``ConfigParser.set``, which supports variable interpolation using
pyformat (e.g. ``%(some_value)s``). A raw percent sign not part of
an interpolation symbol must therefore be escaped, e.g. ``%%``.
The given value may refer to another value already in the file
using the interpolation format.
"""
self.set_section_option(self.config_ini_section, name, value)
def remove_main_option(self, name: str) -> None:
self.file_config.remove_option(self.config_ini_section, name)
def set_section_option(self, section: str, name: str, value: str) -> None:
"""Set an option programmatically within the given section.
The section is created if it doesn't exist already.
The value here will override whatever was in the .ini
file.
:param section: name of the section
:param name: name of the value
:param value: the value. Note that this value is passed to
``ConfigParser.set``, which supports variable interpolation using
pyformat (e.g. ``%(some_value)s``). A raw percent sign not part of
an interpolation symbol must therefore be escaped, e.g. ``%%``.
The given value may refer to another value already in the file
using the interpolation format.
"""
if not self.file_config.has_section(section):
self.file_config.add_section(section)
self.file_config.set(section, name, value)
def get_section_option(
self, section: str, name: str, default: Optional[str] = None
) -> Optional[str]:
"""Return an option from the given section of the .ini file."""
if not self.file_config.has_section(section):
raise util.CommandError(
"No config file %r found, or file has no "
"'[%s]' section" % (self.config_file_name, section)
)
if self.file_config.has_option(section, name):
return self.file_config.get(section, name)
else:
return default
@overload
def get_main_option(self, name: str, default: str) -> str:
...
@overload
def get_main_option(
self, name: str, default: Optional[str] = None
) -> Optional[str]:
...
def get_main_option(self, name, default=None):
"""Return an option from the 'main' section of the .ini file.
This defaults to being a key from the ``[alembic]``
section, unless the ``-n/--name`` flag were used to
indicate a different section.
"""
return self.get_section_option(self.config_ini_section, name, default)
class CommandLine:
def __init__(self, prog: Optional[str] = None) -> None:
self._generate_args(prog)
def _generate_args(self, prog: Optional[str]) -> None:
def add_options(fn, parser, positional, kwargs):
kwargs_opts = {
"template": (
"-t",
"--template",
dict(
default="generic",
type=str,
help="Setup template for use with 'init'",
),
),
"message": (
"-m",
"--message",
dict(
type=str, help="Message string to use with 'revision'"
),
),
"sql": (
"--sql",
dict(
action="store_true",
help="Don't emit SQL to database - dump to "
"standard output/file instead. See docs on "
"offline mode.",
),
),
"tag": (
"--tag",
dict(
type=str,
help="Arbitrary 'tag' name - can be used by "
"custom env.py scripts.",
),
),
"head": (
"--head",
dict(
type=str,
help="Specify head revision or <branchname>@head "
"to base new revision on.",
),
),
"splice": (
"--splice",
dict(
action="store_true",
help="Allow a non-head revision as the "
"'head' to splice onto",
),
),
"depends_on": (
"--depends-on",
dict(
action="append",
help="Specify one or more revision identifiers "
"which this revision should depend on.",
),
),
"rev_id": (
"--rev-id",
dict(
type=str,
help="Specify a hardcoded revision id instead of "
"generating one",
),
),
"version_path": (
"--version-path",
dict(
type=str,
help="Specify specific path from config for "
"version file",
),
),
"branch_label": (
"--branch-label",
dict(
type=str,
help="Specify a branch label to apply to the "
"new revision",
),
),
"verbose": (
"-v",
"--verbose",
dict(action="store_true", help="Use more verbose output"),
),
"resolve_dependencies": (
"--resolve-dependencies",
dict(
action="store_true",
help="Treat dependency versions as down revisions",
),
),
"autogenerate": (
"--autogenerate",
dict(
action="store_true",
help="Populate revision script with candidate "
"migration operations, based on comparison "
"of database to model.",
),
),
"rev_range": (
"-r",
"--rev-range",
dict(
action="store",
help="Specify a revision range; "
"format is [start]:[end]",
),
),
"indicate_current": (
"-i",
"--indicate-current",
dict(
action="store_true",
help="Indicate the current revision",
),
),
"purge": (
"--purge",
dict(
action="store_true",
help="Unconditionally erase the version table "
"before stamping",
),
),
"package": (
"--package",
dict(
action="store_true",
help="Write empty __init__.py files to the "
"environment and version locations",
),
),
}
positional_help = {
"directory": "location of scripts directory",
"revision": "revision identifier",
"revisions": "one or more revisions, or 'heads' for all heads",
}
for arg in kwargs:
if arg in kwargs_opts:
args = kwargs_opts[arg]
args, kw = args[0:-1], args[-1]
parser.add_argument(*args, **kw)
for arg in positional:
if (
arg == "revisions"
or fn in positional_translations
and positional_translations[fn][arg] == "revisions"
):
subparser.add_argument(
"revisions",
nargs="+",
help=positional_help.get("revisions"),
)
else:
subparser.add_argument(arg, help=positional_help.get(arg))
parser = ArgumentParser(prog=prog)
parser.add_argument(
"--version", action="version", version="%%(prog)s %s" % __version__
)
parser.add_argument(
"-c",
"--config",
type=str,
default=os.environ.get("ALEMBIC_CONFIG", "alembic.ini"),
help="Alternate config file; defaults to value of "
'ALEMBIC_CONFIG environment variable, or "alembic.ini"',
)
parser.add_argument(
"-n",
"--name",
type=str,
default="alembic",
help="Name of section in .ini file to " "use for Alembic config",
)
parser.add_argument(
"-x",
action="append",
help="Additional arguments consumed by "
"custom env.py scripts, e.g. -x "
"setting1=somesetting -x setting2=somesetting",
)
parser.add_argument(
"--raiseerr",
action="store_true",
help="Raise a full stack trace on error",
)
subparsers = parser.add_subparsers()
positional_translations = {command.stamp: {"revision": "revisions"}}
for fn in [getattr(command, n) for n in dir(command)]:
if (
inspect.isfunction(fn)
and fn.__name__[0] != "_"
and fn.__module__ == "alembic.command"
):
spec = compat.inspect_getfullargspec(fn)
if spec[3] is not None:
positional = spec[0][1 : -len(spec[3])]
kwarg = spec[0][-len(spec[3]) :]
else:
positional = spec[0][1:]
kwarg = []
if fn in positional_translations:
positional = [
positional_translations[fn].get(name, name)
for name in positional
]
# parse first line(s) of helptext without a line break
help_ = fn.__doc__
if help_:
help_text = []
for line in help_.split("\n"):
if not line.strip():
break
else:
help_text.append(line.strip())
else:
help_text = []
subparser = subparsers.add_parser(
fn.__name__, help=" ".join(help_text)
)
add_options(fn, subparser, positional, kwarg)
subparser.set_defaults(cmd=(fn, positional, kwarg))
self.parser = parser
def run_cmd(self, config: Config, options: Namespace) -> None:
fn, positional, kwarg = options.cmd
try:
fn(
config,
*[getattr(options, k, None) for k in positional],
**{k: getattr(options, k, None) for k in kwarg},
)
except util.CommandError as e:
if options.raiseerr:
raise
else:
util.err(str(e))
def main(self, argv=None):
options = self.parser.parse_args(argv)
if not hasattr(options, "cmd"):
# see http://bugs.python.org/issue9253, argparse
# behavior changed incompatibly in py3.3
self.parser.error("too few arguments")
else:
cfg = Config(
file_=options.config,
ini_section=options.name,
cmd_opts=options,
)
self.run_cmd(cfg, options)
def main(argv=None, prog=None, **kwargs):
"""The console runner function for Alembic."""
CommandLine(prog=prog).main(argv=argv)
if __name__ == "__main__":
main()

5
libs/alembic/context.py Normal file
View File

@ -0,0 +1,5 @@
from .runtime.environment import EnvironmentContext
# create proxy functions for
# each method on the EnvironmentContext class.
EnvironmentContext.create_module_class_proxy(globals(), locals())

753
libs/alembic/context.pyi Normal file
View File

@ -0,0 +1,753 @@
# ### this file stubs are generated by tools/write_pyi.py - do not edit ###
# ### imports are manually managed
from __future__ import annotations
from typing import Any
from typing import Callable
from typing import ContextManager
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import overload
from typing import TextIO
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
if TYPE_CHECKING:
from sqlalchemy.engine.base import Connection
from sqlalchemy.engine.url import URL
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.schema import MetaData
from .config import Config
from .operations import MigrateOperation
from .runtime.migration import _ProxyTransaction
from .runtime.migration import MigrationContext
from .script import ScriptDirectory
### end imports ###
def begin_transaction() -> Union[_ProxyTransaction, ContextManager[None]]:
r"""Return a context manager that will
enclose an operation within a "transaction",
as defined by the environment's offline
and transactional DDL settings.
e.g.::
with context.begin_transaction():
context.run_migrations()
:meth:`.begin_transaction` is intended to
"do the right thing" regardless of
calling context:
* If :meth:`.is_transactional_ddl` is ``False``,
returns a "do nothing" context manager
which otherwise produces no transactional
state or directives.
* If :meth:`.is_offline_mode` is ``True``,
returns a context manager that will
invoke the :meth:`.DefaultImpl.emit_begin`
and :meth:`.DefaultImpl.emit_commit`
methods, which will produce the string
directives ``BEGIN`` and ``COMMIT`` on
the output stream, as rendered by the
target backend (e.g. SQL Server would
emit ``BEGIN TRANSACTION``).
* Otherwise, calls :meth:`sqlalchemy.engine.Connection.begin`
on the current online connection, which
returns a :class:`sqlalchemy.engine.Transaction`
object. This object demarcates a real
transaction and is itself a context manager,
which will roll back if an exception
is raised.
Note that a custom ``env.py`` script which
has more specific transactional needs can of course
manipulate the :class:`~sqlalchemy.engine.Connection`
directly to produce transactional state in "online"
mode.
"""
config: Config
def configure(
connection: Optional[Connection] = None,
url: Union[str, URL, None] = None,
dialect_name: Optional[str] = None,
dialect_opts: Optional[Dict[str, Any]] = None,
transactional_ddl: Optional[bool] = None,
transaction_per_migration: bool = False,
output_buffer: Optional[TextIO] = None,
starting_rev: Optional[str] = None,
tag: Optional[str] = None,
template_args: Optional[Dict[str, Any]] = None,
render_as_batch: bool = False,
target_metadata: Optional[MetaData] = None,
include_name: Optional[Callable[..., bool]] = None,
include_object: Optional[Callable[..., bool]] = None,
include_schemas: bool = False,
process_revision_directives: Optional[
Callable[
[MigrationContext, Tuple[str, str], List[MigrateOperation]], None
]
] = None,
compare_type: bool = False,
compare_server_default: bool = False,
render_item: Optional[Callable[..., bool]] = None,
literal_binds: bool = False,
upgrade_token: str = "upgrades",
downgrade_token: str = "downgrades",
alembic_module_prefix: str = "op.",
sqlalchemy_module_prefix: str = "sa.",
user_module_prefix: Optional[str] = None,
on_version_apply: Optional[Callable[..., None]] = None,
**kw: Any,
) -> None:
r"""Configure a :class:`.MigrationContext` within this
:class:`.EnvironmentContext` which will provide database
connectivity and other configuration to a series of
migration scripts.
Many methods on :class:`.EnvironmentContext` require that
this method has been called in order to function, as they
ultimately need to have database access or at least access
to the dialect in use. Those which do are documented as such.
The important thing needed by :meth:`.configure` is a
means to determine what kind of database dialect is in use.
An actual connection to that database is needed only if
the :class:`.MigrationContext` is to be used in
"online" mode.
If the :meth:`.is_offline_mode` function returns ``True``,
then no connection is needed here. Otherwise, the
``connection`` parameter should be present as an
instance of :class:`sqlalchemy.engine.Connection`.
This function is typically called from the ``env.py``
script within a migration environment. It can be called
multiple times for an invocation. The most recent
:class:`~sqlalchemy.engine.Connection`
for which it was called is the one that will be operated upon
by the next call to :meth:`.run_migrations`.
General parameters:
:param connection: a :class:`~sqlalchemy.engine.Connection`
to use
for SQL execution in "online" mode. When present, is also
used to determine the type of dialect in use.
:param url: a string database url, or a
:class:`sqlalchemy.engine.url.URL` object.
The type of dialect to be used will be derived from this if
``connection`` is not passed.
:param dialect_name: string name of a dialect, such as
"postgresql", "mssql", etc.
The type of dialect to be used will be derived from this if
``connection`` and ``url`` are not passed.
:param dialect_opts: dictionary of options to be passed to dialect
constructor.
.. versionadded:: 1.0.12
:param transactional_ddl: Force the usage of "transactional"
DDL on or off;
this otherwise defaults to whether or not the dialect in
use supports it.
:param transaction_per_migration: if True, nest each migration script
in a transaction rather than the full series of migrations to
run.
:param output_buffer: a file-like object that will be used
for textual output
when the ``--sql`` option is used to generate SQL scripts.
Defaults to
``sys.stdout`` if not passed here and also not present on
the :class:`.Config`
object. The value here overrides that of the :class:`.Config`
object.
:param output_encoding: when using ``--sql`` to generate SQL
scripts, apply this encoding to the string output.
:param literal_binds: when using ``--sql`` to generate SQL
scripts, pass through the ``literal_binds`` flag to the compiler
so that any literal values that would ordinarily be bound
parameters are converted to plain strings.
.. warning:: Dialects can typically only handle simple datatypes
like strings and numbers for auto-literal generation. Datatypes
like dates, intervals, and others may still require manual
formatting, typically using :meth:`.Operations.inline_literal`.
.. note:: the ``literal_binds`` flag is ignored on SQLAlchemy
versions prior to 0.8 where this feature is not supported.
.. seealso::
:meth:`.Operations.inline_literal`
:param starting_rev: Override the "starting revision" argument
when using ``--sql`` mode.
:param tag: a string tag for usage by custom ``env.py`` scripts.
Set via the ``--tag`` option, can be overridden here.
:param template_args: dictionary of template arguments which
will be added to the template argument environment when
running the "revision" command. Note that the script environment
is only run within the "revision" command if the --autogenerate
option is used, or if the option "revision_environment=true"
is present in the alembic.ini file.
:param version_table: The name of the Alembic version table.
The default is ``'alembic_version'``.
:param version_table_schema: Optional schema to place version
table within.
:param version_table_pk: boolean, whether the Alembic version table
should use a primary key constraint for the "value" column; this
only takes effect when the table is first created.
Defaults to True; setting to False should not be necessary and is
here for backwards compatibility reasons.
:param on_version_apply: a callable or collection of callables to be
run for each migration step.
The callables will be run in the order they are given, once for
each migration step, after the respective operation has been
applied but before its transaction is finalized.
Each callable accepts no positional arguments and the following
keyword arguments:
* ``ctx``: the :class:`.MigrationContext` running the migration,
* ``step``: a :class:`.MigrationInfo` representing the
step currently being applied,
* ``heads``: a collection of version strings representing the
current heads,
* ``run_args``: the ``**kwargs`` passed to :meth:`.run_migrations`.
Parameters specific to the autogenerate feature, when
``alembic revision`` is run with the ``--autogenerate`` feature:
:param target_metadata: a :class:`sqlalchemy.schema.MetaData`
object, or a sequence of :class:`~sqlalchemy.schema.MetaData`
objects, that will be consulted during autogeneration.
The tables present in each :class:`~sqlalchemy.schema.MetaData`
will be compared against
what is locally available on the target
:class:`~sqlalchemy.engine.Connection`
to produce candidate upgrade/downgrade operations.
:param compare_type: Indicates type comparison behavior during
an autogenerate
operation. Defaults to ``False`` which disables type
comparison. Set to
``True`` to turn on default type comparison, which has varied
accuracy depending on backend. See :ref:`compare_types`
for an example as well as information on other type
comparison options.
.. seealso::
:ref:`compare_types`
:paramref:`.EnvironmentContext.configure.compare_server_default`
:param compare_server_default: Indicates server default comparison
behavior during
an autogenerate operation. Defaults to ``False`` which disables
server default
comparison. Set to ``True`` to turn on server default comparison,
which has
varied accuracy depending on backend.
To customize server default comparison behavior, a callable may
be specified
which can filter server default comparisons during an
autogenerate operation.
defaults during an autogenerate operation. The format of this
callable is::
def my_compare_server_default(context, inspected_column,
metadata_column, inspected_default, metadata_default,
rendered_metadata_default):
# return True if the defaults are different,
# False if not, or None to allow the default implementation
# to compare these defaults
return None
context.configure(
# ...
compare_server_default = my_compare_server_default
)
``inspected_column`` is a dictionary structure as returned by
:meth:`sqlalchemy.engine.reflection.Inspector.get_columns`, whereas
``metadata_column`` is a :class:`sqlalchemy.schema.Column` from
the local model environment.
A return value of ``None`` indicates to allow default server default
comparison
to proceed. Note that some backends such as Postgresql actually
execute
the two defaults on the database side to compare for equivalence.
.. seealso::
:paramref:`.EnvironmentContext.configure.compare_type`
:param include_name: A callable function which is given
the chance to return ``True`` or ``False`` for any database reflected
object based on its name, including database schema names when
the :paramref:`.EnvironmentContext.configure.include_schemas` flag
is set to ``True``.
The function accepts the following positional arguments:
* ``name``: the name of the object, such as schema name or table name.
Will be ``None`` when indicating the default schema name of the
database connection.
* ``type``: a string describing the type of object; currently
``"schema"``, ``"table"``, ``"column"``, ``"index"``,
``"unique_constraint"``, or ``"foreign_key_constraint"``
* ``parent_names``: a dictionary of "parent" object names, that are
relative to the name being given. Keys in this dictionary may
include: ``"schema_name"``, ``"table_name"``.
E.g.::
def include_name(name, type_, parent_names):
if type_ == "schema":
return name in ["schema_one", "schema_two"]
else:
return True
context.configure(
# ...
include_schemas = True,
include_name = include_name
)
.. versionadded:: 1.5
.. seealso::
:ref:`autogenerate_include_hooks`
:paramref:`.EnvironmentContext.configure.include_object`
:paramref:`.EnvironmentContext.configure.include_schemas`
:param include_object: A callable function which is given
the chance to return ``True`` or ``False`` for any object,
indicating if the given object should be considered in the
autogenerate sweep.
The function accepts the following positional arguments:
* ``object``: a :class:`~sqlalchemy.schema.SchemaItem` object such
as a :class:`~sqlalchemy.schema.Table`,
:class:`~sqlalchemy.schema.Column`,
:class:`~sqlalchemy.schema.Index`
:class:`~sqlalchemy.schema.UniqueConstraint`,
or :class:`~sqlalchemy.schema.ForeignKeyConstraint` object
* ``name``: the name of the object. This is typically available
via ``object.name``.
* ``type``: a string describing the type of object; currently
``"table"``, ``"column"``, ``"index"``, ``"unique_constraint"``,
or ``"foreign_key_constraint"``
* ``reflected``: ``True`` if the given object was produced based on
table reflection, ``False`` if it's from a local :class:`.MetaData`
object.
* ``compare_to``: the object being compared against, if available,
else ``None``.
E.g.::
def include_object(object, name, type_, reflected, compare_to):
if (type_ == "column" and
not reflected and
object.info.get("skip_autogenerate", False)):
return False
else:
return True
context.configure(
# ...
include_object = include_object
)
For the use case of omitting specific schemas from a target database
when :paramref:`.EnvironmentContext.configure.include_schemas` is
set to ``True``, the :attr:`~sqlalchemy.schema.Table.schema`
attribute can be checked for each :class:`~sqlalchemy.schema.Table`
object passed to the hook, however it is much more efficient
to filter on schemas before reflection of objects takes place
using the :paramref:`.EnvironmentContext.configure.include_name`
hook.
.. seealso::
:ref:`autogenerate_include_hooks`
:paramref:`.EnvironmentContext.configure.include_name`
:paramref:`.EnvironmentContext.configure.include_schemas`
:param render_as_batch: if True, commands which alter elements
within a table will be placed under a ``with batch_alter_table():``
directive, so that batch migrations will take place.
.. seealso::
:ref:`batch_migrations`
:param include_schemas: If True, autogenerate will scan across
all schemas located by the SQLAlchemy
:meth:`~sqlalchemy.engine.reflection.Inspector.get_schema_names`
method, and include all differences in tables found across all
those schemas. When using this option, you may want to also
use the :paramref:`.EnvironmentContext.configure.include_name`
parameter to specify a callable which
can filter the tables/schemas that get included.
.. seealso::
:ref:`autogenerate_include_hooks`
:paramref:`.EnvironmentContext.configure.include_name`
:paramref:`.EnvironmentContext.configure.include_object`
:param render_item: Callable that can be used to override how
any schema item, i.e. column, constraint, type,
etc., is rendered for autogenerate. The callable receives a
string describing the type of object, the object, and
the autogen context. If it returns False, the
default rendering method will be used. If it returns None,
the item will not be rendered in the context of a Table
construct, that is, can be used to skip columns or constraints
within op.create_table()::
def my_render_column(type_, col, autogen_context):
if type_ == "column" and isinstance(col, MySpecialCol):
return repr(col)
else:
return False
context.configure(
# ...
render_item = my_render_column
)
Available values for the type string include: ``"column"``,
``"primary_key"``, ``"foreign_key"``, ``"unique"``, ``"check"``,
``"type"``, ``"server_default"``.
.. seealso::
:ref:`autogen_render_types`
:param upgrade_token: When autogenerate completes, the text of the
candidate upgrade operations will be present in this template
variable when ``script.py.mako`` is rendered. Defaults to
``upgrades``.
:param downgrade_token: When autogenerate completes, the text of the
candidate downgrade operations will be present in this
template variable when ``script.py.mako`` is rendered. Defaults to
``downgrades``.
:param alembic_module_prefix: When autogenerate refers to Alembic
:mod:`alembic.operations` constructs, this prefix will be used
(i.e. ``op.create_table``) Defaults to "``op.``".
Can be ``None`` to indicate no prefix.
:param sqlalchemy_module_prefix: When autogenerate refers to
SQLAlchemy
:class:`~sqlalchemy.schema.Column` or type classes, this prefix
will be used
(i.e. ``sa.Column("somename", sa.Integer)``) Defaults to "``sa.``".
Can be ``None`` to indicate no prefix.
Note that when dialect-specific types are rendered, autogenerate
will render them using the dialect module name, i.e. ``mssql.BIT()``,
``postgresql.UUID()``.
:param user_module_prefix: When autogenerate refers to a SQLAlchemy
type (e.g. :class:`.TypeEngine`) where the module name is not
under the ``sqlalchemy`` namespace, this prefix will be used
within autogenerate. If left at its default of
``None``, the ``__module__`` attribute of the type is used to
render the import module. It's a good practice to set this
and to have all custom types be available from a fixed module space,
in order to future-proof migration files against reorganizations
in modules.
.. seealso::
:ref:`autogen_module_prefix`
:param process_revision_directives: a callable function that will
be passed a structure representing the end result of an autogenerate
or plain "revision" operation, which can be manipulated to affect
how the ``alembic revision`` command ultimately outputs new
revision scripts. The structure of the callable is::
def process_revision_directives(context, revision, directives):
pass
The ``directives`` parameter is a Python list containing
a single :class:`.MigrationScript` directive, which represents
the revision file to be generated. This list as well as its
contents may be freely modified to produce any set of commands.
The section :ref:`customizing_revision` shows an example of
doing this. The ``context`` parameter is the
:class:`.MigrationContext` in use,
and ``revision`` is a tuple of revision identifiers representing the
current revision of the database.
The callable is invoked at all times when the ``--autogenerate``
option is passed to ``alembic revision``. If ``--autogenerate``
is not passed, the callable is invoked only if the
``revision_environment`` variable is set to True in the Alembic
configuration, in which case the given ``directives`` collection
will contain empty :class:`.UpgradeOps` and :class:`.DowngradeOps`
collections for ``.upgrade_ops`` and ``.downgrade_ops``. The
``--autogenerate`` option itself can be inferred by inspecting
``context.config.cmd_opts.autogenerate``.
The callable function may optionally be an instance of
a :class:`.Rewriter` object. This is a helper object that
assists in the production of autogenerate-stream rewriter functions.
.. seealso::
:ref:`customizing_revision`
:ref:`autogen_rewriter`
:paramref:`.command.revision.process_revision_directives`
Parameters specific to individual backends:
:param mssql_batch_separator: The "batch separator" which will
be placed between each statement when generating offline SQL Server
migrations. Defaults to ``GO``. Note this is in addition to the
customary semicolon ``;`` at the end of each statement; SQL Server
considers the "batch separator" to denote the end of an
individual statement execution, and cannot group certain
dependent operations in one step.
:param oracle_batch_separator: The "batch separator" which will
be placed between each statement when generating offline
Oracle migrations. Defaults to ``/``. Oracle doesn't add a
semicolon between statements like most other backends.
"""
def execute(
sql: Union[ClauseElement, str], execution_options: Optional[dict] = None
) -> None:
r"""Execute the given SQL using the current change context.
The behavior of :meth:`.execute` is the same
as that of :meth:`.Operations.execute`. Please see that
function's documentation for full detail including
caveats and limitations.
This function requires that a :class:`.MigrationContext` has
first been made available via :meth:`.configure`.
"""
def get_bind() -> Connection:
r"""Return the current 'bind'.
In "online" mode, this is the
:class:`sqlalchemy.engine.Connection` currently being used
to emit SQL to the database.
This function requires that a :class:`.MigrationContext`
has first been made available via :meth:`.configure`.
"""
def get_context() -> MigrationContext:
r"""Return the current :class:`.MigrationContext` object.
If :meth:`.EnvironmentContext.configure` has not been
called yet, raises an exception.
"""
def get_head_revision() -> Union[str, Tuple[str, ...], None]:
r"""Return the hex identifier of the 'head' script revision.
If the script directory has multiple heads, this
method raises a :class:`.CommandError`;
:meth:`.EnvironmentContext.get_head_revisions` should be preferred.
This function does not require that the :class:`.MigrationContext`
has been configured.
.. seealso:: :meth:`.EnvironmentContext.get_head_revisions`
"""
def get_head_revisions() -> Union[str, Tuple[str, ...], None]:
r"""Return the hex identifier of the 'heads' script revision(s).
This returns a tuple containing the version number of all
heads in the script directory.
This function does not require that the :class:`.MigrationContext`
has been configured.
"""
def get_revision_argument() -> Union[str, Tuple[str, ...], None]:
r"""Get the 'destination' revision argument.
This is typically the argument passed to the
``upgrade`` or ``downgrade`` command.
If it was specified as ``head``, the actual
version number is returned; if specified
as ``base``, ``None`` is returned.
This function does not require that the :class:`.MigrationContext`
has been configured.
"""
def get_starting_revision_argument() -> Union[str, Tuple[str, ...], None]:
r"""Return the 'starting revision' argument,
if the revision was passed using ``start:end``.
This is only meaningful in "offline" mode.
Returns ``None`` if no value is available
or was configured.
This function does not require that the :class:`.MigrationContext`
has been configured.
"""
def get_tag_argument() -> Optional[str]:
r"""Return the value passed for the ``--tag`` argument, if any.
The ``--tag`` argument is not used directly by Alembic,
but is available for custom ``env.py`` configurations that
wish to use it; particularly for offline generation scripts
that wish to generate tagged filenames.
This function does not require that the :class:`.MigrationContext`
has been configured.
.. seealso::
:meth:`.EnvironmentContext.get_x_argument` - a newer and more
open ended system of extending ``env.py`` scripts via the command
line.
"""
@overload
def get_x_argument(as_dictionary: Literal[False]) -> List[str]: ...
@overload
def get_x_argument(as_dictionary: Literal[True]) -> Dict[str, str]: ...
@overload
def get_x_argument(
as_dictionary: bool = ...,
) -> Union[List[str], Dict[str, str]]:
r"""Return the value(s) passed for the ``-x`` argument, if any.
The ``-x`` argument is an open ended flag that allows any user-defined
value or values to be passed on the command line, then available
here for consumption by a custom ``env.py`` script.
The return value is a list, returned directly from the ``argparse``
structure. If ``as_dictionary=True`` is passed, the ``x`` arguments
are parsed using ``key=value`` format into a dictionary that is
then returned.
For example, to support passing a database URL on the command line,
the standard ``env.py`` script can be modified like this::
cmd_line_url = context.get_x_argument(
as_dictionary=True).get('dbname')
if cmd_line_url:
engine = create_engine(cmd_line_url)
else:
engine = engine_from_config(
config.get_section(config.config_ini_section),
prefix='sqlalchemy.',
poolclass=pool.NullPool)
This then takes effect by running the ``alembic`` script as::
alembic -x dbname=postgresql://user:pass@host/dbname upgrade head
This function does not require that the :class:`.MigrationContext`
has been configured.
.. seealso::
:meth:`.EnvironmentContext.get_tag_argument`
:attr:`.Config.cmd_opts`
"""
def is_offline_mode() -> bool:
r"""Return True if the current migrations environment
is running in "offline mode".
This is ``True`` or ``False`` depending
on the ``--sql`` flag passed.
This function does not require that the :class:`.MigrationContext`
has been configured.
"""
def is_transactional_ddl():
r"""Return True if the context is configured to expect a
transactional DDL capable backend.
This defaults to the type of database in use, and
can be overridden by the ``transactional_ddl`` argument
to :meth:`.configure`
This function requires that a :class:`.MigrationContext`
has first been made available via :meth:`.configure`.
"""
def run_migrations(**kw: Any) -> None:
r"""Run migrations as determined by the current command line
configuration
as well as versioning information present (or not) in the current
database connection (if one is present).
The function accepts optional ``**kw`` arguments. If these are
passed, they are sent directly to the ``upgrade()`` and
``downgrade()``
functions within each target revision file. By modifying the
``script.py.mako`` file so that the ``upgrade()`` and ``downgrade()``
functions accept arguments, parameters can be passed here so that
contextual information, usually information to identify a particular
database in use, can be passed from a custom ``env.py`` script
to the migration functions.
This function requires that a :class:`.MigrationContext` has
first been made available via :meth:`.configure`.
"""
script: ScriptDirectory
def static_output(text: str) -> None:
r"""Emit text directly to the "offline" SQL stream.
Typically this is for emitting comments that
start with --. The statement is not treated
as a SQL execution, no ; or batch separator
is added, etc.
"""

View File

@ -0,0 +1,6 @@
from . import mssql
from . import mysql
from . import oracle
from . import postgresql
from . import sqlite
from .impl import DefaultImpl

332
libs/alembic/ddl/base.py Normal file
View File

@ -0,0 +1,332 @@
from __future__ import annotations
import functools
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import exc
from sqlalchemy import Integer
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Column
from sqlalchemy.schema import DDLElement
from sqlalchemy.sql.elements import quoted_name
from ..util.sqla_compat import _columns_for_constraint # noqa
from ..util.sqla_compat import _find_columns # noqa
from ..util.sqla_compat import _fk_spec # noqa
from ..util.sqla_compat import _is_type_bound # noqa
from ..util.sqla_compat import _table_for_constraint # noqa
if TYPE_CHECKING:
from typing import Any
from sqlalchemy.sql.compiler import Compiled
from sqlalchemy.sql.compiler import DDLCompiler
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.schema import FetchedValue
from sqlalchemy.sql.type_api import TypeEngine
from .impl import DefaultImpl
from ..util.sqla_compat import Computed
from ..util.sqla_compat import Identity
_ServerDefault = Union["TextClause", "FetchedValue", "Function[Any]", str]
class AlterTable(DDLElement):
"""Represent an ALTER TABLE statement.
Only the string name and optional schema name of the table
is required, not a full Table object.
"""
def __init__(
self,
table_name: str,
schema: Optional[Union[quoted_name, str]] = None,
) -> None:
self.table_name = table_name
self.schema = schema
class RenameTable(AlterTable):
def __init__(
self,
old_table_name: str,
new_table_name: Union[quoted_name, str],
schema: Optional[Union[quoted_name, str]] = None,
) -> None:
super().__init__(old_table_name, schema=schema)
self.new_table_name = new_table_name
class AlterColumn(AlterTable):
def __init__(
self,
name: str,
column_name: str,
schema: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_nullable: Optional[bool] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_comment: Optional[str] = None,
) -> None:
super().__init__(name, schema=schema)
self.column_name = column_name
self.existing_type = (
sqltypes.to_instance(existing_type)
if existing_type is not None
else None
)
self.existing_nullable = existing_nullable
self.existing_server_default = existing_server_default
self.existing_comment = existing_comment
class ColumnNullable(AlterColumn):
def __init__(
self, name: str, column_name: str, nullable: bool, **kw
) -> None:
super().__init__(name, column_name, **kw)
self.nullable = nullable
class ColumnType(AlterColumn):
def __init__(
self, name: str, column_name: str, type_: TypeEngine, **kw
) -> None:
super().__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
class ColumnName(AlterColumn):
def __init__(
self, name: str, column_name: str, newname: str, **kw
) -> None:
super().__init__(name, column_name, **kw)
self.newname = newname
class ColumnDefault(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
default: Optional[_ServerDefault],
**kw,
) -> None:
super().__init__(name, column_name, **kw)
self.default = default
class ComputedColumnDefault(AlterColumn):
def __init__(
self, name: str, column_name: str, default: Optional[Computed], **kw
) -> None:
super().__init__(name, column_name, **kw)
self.default = default
class IdentityColumnDefault(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
default: Optional[Identity],
impl: DefaultImpl,
**kw,
) -> None:
super().__init__(name, column_name, **kw)
self.default = default
self.impl = impl
class AddColumn(AlterTable):
def __init__(
self,
name: str,
column: Column,
schema: Optional[Union[quoted_name, str]] = None,
) -> None:
super().__init__(name, schema=schema)
self.column = column
class DropColumn(AlterTable):
def __init__(
self, name: str, column: Column, schema: Optional[str] = None
) -> None:
super().__init__(name, schema=schema)
self.column = column
class ColumnComment(AlterColumn):
def __init__(
self, name: str, column_name: str, comment: Optional[str], **kw
) -> None:
super().__init__(name, column_name, **kw)
self.comment = comment
@compiles(RenameTable)
def visit_rename_table(
element: RenameTable, compiler: DDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, element.schema),
)
@compiles(AddColumn)
def visit_add_column(element: AddColumn, compiler: DDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
)
@compiles(DropColumn)
def visit_drop_column(element: DropColumn, compiler: DDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
drop_column(compiler, element.column.name, **kw),
)
@compiles(ColumnNullable)
def visit_column_nullable(
element: ColumnNullable, compiler: DDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"DROP NOT NULL" if element.nullable else "SET NOT NULL",
)
@compiles(ColumnType)
def visit_column_type(element: ColumnType, compiler: DDLCompiler, **kw) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"TYPE %s" % format_type(compiler, element.type_),
)
@compiles(ColumnName)
def visit_column_name(element: ColumnName, compiler: DDLCompiler, **kw) -> str:
return "%s RENAME %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)
@compiles(ColumnDefault)
def visit_column_default(
element: ColumnDefault, compiler: DDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT",
)
@compiles(ComputedColumnDefault)
def visit_computed_column(
element: ComputedColumnDefault, compiler: DDLCompiler, **kw
):
raise exc.CompileError(
'Adding or removing a "computed" construct, e.g. GENERATED '
"ALWAYS AS, to or from an existing column is not supported."
)
@compiles(IdentityColumnDefault)
def visit_identity_column(
element: IdentityColumnDefault, compiler: DDLCompiler, **kw
):
raise exc.CompileError(
'Adding, removing or modifying an "identity" construct, '
"e.g. GENERATED AS IDENTITY, to or from an existing "
"column is not supported in this dialect."
)
def quote_dotted(
name: Union[quoted_name, str], quote: functools.partial
) -> Union[quoted_name, str]:
"""quote the elements of a dotted name"""
if isinstance(name, quoted_name):
return quote(name)
result = ".".join([quote(x) for x in name.split(".")])
return result
def format_table_name(
compiler: Compiled,
name: Union[quoted_name, str],
schema: Optional[Union[quoted_name, str]],
) -> Union[quoted_name, str]:
quote = functools.partial(compiler.preparer.quote)
if schema:
return quote_dotted(schema, quote) + "." + quote(name)
else:
return quote(name)
def format_column_name(
compiler: DDLCompiler, name: Optional[Union[quoted_name, str]]
) -> Union[quoted_name, str]:
return compiler.preparer.quote(name) # type: ignore[arg-type]
def format_server_default(
compiler: DDLCompiler,
default: Optional[_ServerDefault],
) -> str:
return compiler.get_column_default_string(
Column("x", Integer, server_default=default)
)
def format_type(compiler: DDLCompiler, type_: TypeEngine) -> str:
return compiler.dialect.type_compiler.process(type_)
def alter_table(
compiler: DDLCompiler,
name: str,
schema: Optional[str],
) -> str:
return "ALTER TABLE %s" % format_table_name(compiler, name, schema)
def drop_column(compiler: DDLCompiler, name: str, **kw) -> str:
return "DROP COLUMN %s" % format_column_name(compiler, name)
def alter_column(compiler: DDLCompiler, name: str) -> str:
return "ALTER COLUMN %s" % format_column_name(compiler, name)
def add_column(compiler: DDLCompiler, column: Column, **kw) -> str:
text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
const = " ".join(
compiler.process(constraint) for constraint in column.constraints
)
if const:
text += " " + const
return text

707
libs/alembic/ddl/impl.py Normal file
View File

@ -0,0 +1,707 @@
from __future__ import annotations
from collections import namedtuple
import re
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Set
from typing import Tuple
from typing import Type
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import cast
from sqlalchemy import schema
from sqlalchemy import text
from . import base
from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
from typing import Literal
from typing import TextIO
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Dialect
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.schema import UniqueConstraint
from sqlalchemy.sql.selectable import TableClause
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
from ..autogenerate.api import AutogenContext
from ..operations.batch import ApplyBatchImpl
from ..operations.batch import BatchOperationsImpl
class ImplMeta(type):
def __init__(
cls,
classname: str,
bases: Tuple[Type[DefaultImpl]],
dict_: Dict[str, Any],
):
newtype = type.__init__(cls, classname, bases, dict_)
if "__dialect__" in dict_:
_impls[dict_["__dialect__"]] = cls # type: ignore[assignment]
return newtype
_impls: Dict[str, Type[DefaultImpl]] = {}
Params = namedtuple("Params", ["token0", "tokens", "args", "kwargs"])
class DefaultImpl(metaclass=ImplMeta):
"""Provide the entrypoint for major migration operations,
including database-specific behavioral variances.
While individual SQL/DDL constructs already provide
for database-specific implementations, variances here
allow for entirely different sequences of operations
to take place for a particular migration, such as
SQL Server's special 'IDENTITY INSERT' step for
bulk inserts.
"""
__dialect__ = "default"
transactional_ddl = False
command_terminator = ";"
type_synonyms: Tuple[Set[str], ...] = ({"NUMERIC", "DECIMAL"},)
type_arg_extract: Sequence[str] = ()
# on_null is known to be supported only by oracle
identity_attrs_ignore: Tuple[str, ...] = ("on_null",)
def __init__(
self,
dialect: Dialect,
connection: Optional[Connection],
as_sql: bool,
transactional_ddl: Optional[bool],
output_buffer: Optional[TextIO],
context_opts: Dict[str, Any],
) -> None:
self.dialect = dialect
self.connection = connection
self.as_sql = as_sql
self.literal_binds = context_opts.get("literal_binds", False)
self.output_buffer = output_buffer
self.memo: dict = {}
self.context_opts = context_opts
if transactional_ddl is not None:
self.transactional_ddl = transactional_ddl
if self.literal_binds:
if not self.as_sql:
raise util.CommandError(
"Can't use literal_binds setting without as_sql mode"
)
@classmethod
def get_by_dialect(cls, dialect: Dialect) -> Type[DefaultImpl]:
return _impls[dialect.name]
def static_output(self, text: str) -> None:
assert self.output_buffer is not None
self.output_buffer.write(text + "\n\n")
self.output_buffer.flush()
def requires_recreate_in_batch(
self, batch_op: BatchOperationsImpl
) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
Normally, only returns True on SQLite when operations other
than add_column are present.
"""
return False
def prep_table_for_batch(
self, batch_impl: ApplyBatchImpl, table: Table
) -> None:
"""perform any operations needed on a table before a new
one is created to replace it in batch mode.
the PG dialect uses this to drop constraints on the table
before the new one uses those same names.
"""
@property
def bind(self) -> Optional[Connection]:
return self.connection
def _exec(
self,
construct: Union[ClauseElement, str],
execution_options: Optional[dict] = None,
multiparams: Sequence[dict] = (),
params: Dict[str, int] = util.immutabledict(),
) -> Optional[CursorResult]:
if isinstance(construct, str):
construct = text(construct)
if self.as_sql:
if multiparams or params:
# TODO: coverage
raise Exception("Execution arguments not allowed with as_sql")
if self.literal_binds and not isinstance(
construct, schema.DDLElement
):
compile_kw = dict(compile_kwargs={"literal_binds": True})
else:
compile_kw = {}
compiled = construct.compile(
dialect=self.dialect, **compile_kw # type: ignore[arg-type]
)
self.static_output(
str(compiled).replace("\t", " ").strip()
+ self.command_terminator
)
return None
else:
conn = self.connection
assert conn is not None
if execution_options:
conn = conn.execution_options(**execution_options)
if params:
assert isinstance(multiparams, tuple)
multiparams += (params,)
return conn.execute( # type: ignore[call-overload]
construct, multiparams
)
def execute(
self,
sql: Union[ClauseElement, str],
execution_options: None = None,
) -> None:
self._exec(sql, execution_options)
def alter_column(
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
autoincrement: Optional[bool] = None,
comment: Optional[Union[str, Literal[False]]] = False,
existing_comment: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
**kw: Any,
) -> None:
if autoincrement is not None or existing_autoincrement is not None:
util.warn(
"autoincrement and existing_autoincrement "
"only make sense for MySQL",
stacklevel=3,
)
if nullable is not None:
self._exec(
base.ColumnNullable(
table_name,
column_name,
nullable,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
)
)
if server_default is not False:
kw = {}
cls_: Type[
Union[
base.ComputedColumnDefault,
base.IdentityColumnDefault,
base.ColumnDefault,
]
]
if sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
cls_ = base.ComputedColumnDefault
elif sqla_compat._server_default_is_identity(
server_default, existing_server_default
):
cls_ = base.IdentityColumnDefault
kw["impl"] = self
else:
cls_ = base.ColumnDefault
self._exec(
cls_(
table_name,
column_name,
server_default, # type:ignore[arg-type]
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
**kw,
)
)
if type_ is not None:
self._exec(
base.ColumnType(
table_name,
column_name,
type_,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
)
)
if comment is not False:
self._exec(
base.ColumnComment(
table_name,
column_name,
comment,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_comment=existing_comment,
)
)
# do the new name last ;)
if name is not None:
self._exec(
base.ColumnName(
table_name,
column_name,
name,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
)
)
def add_column(
self,
table_name: str,
column: Column,
schema: Optional[Union[str, quoted_name]] = None,
) -> None:
self._exec(base.AddColumn(table_name, column, schema=schema))
def drop_column(
self,
table_name: str,
column: Column,
schema: Optional[str] = None,
**kw,
) -> None:
self._exec(base.DropColumn(table_name, column, schema=schema))
def add_constraint(self, const: Any) -> None:
if const._create_rule is None or const._create_rule(self):
self._exec(schema.AddConstraint(const))
def drop_constraint(self, const: Constraint) -> None:
self._exec(schema.DropConstraint(const))
def rename_table(
self,
old_table_name: str,
new_table_name: Union[str, quoted_name],
schema: Optional[Union[str, quoted_name]] = None,
) -> None:
self._exec(
base.RenameTable(old_table_name, new_table_name, schema=schema)
)
def create_table(self, table: Table) -> None:
table.dispatch.before_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
self._exec(schema.CreateTable(table))
table.dispatch.after_create(
table, self.connection, checkfirst=False, _ddl_runner=self
)
for index in table.indexes:
self._exec(schema.CreateIndex(index))
with_comment = (
self.dialect.supports_comments and not self.dialect.inline_comments
)
comment = table.comment
if comment and with_comment:
self.create_table_comment(table)
for column in table.columns:
comment = column.comment
if comment and with_comment:
self.create_column_comment(column)
def drop_table(self, table: Table) -> None:
table.dispatch.before_drop(
table, self.connection, checkfirst=False, _ddl_runner=self
)
self._exec(schema.DropTable(table))
table.dispatch.after_drop(
table, self.connection, checkfirst=False, _ddl_runner=self
)
def create_index(self, index: Index) -> None:
self._exec(schema.CreateIndex(index))
def create_table_comment(self, table: Table) -> None:
self._exec(schema.SetTableComment(table))
def drop_table_comment(self, table: Table) -> None:
self._exec(schema.DropTableComment(table))
def create_column_comment(self, column: ColumnElement) -> None:
self._exec(schema.SetColumnComment(column))
def drop_index(self, index: Index) -> None:
self._exec(schema.DropIndex(index))
def bulk_insert(
self,
table: Union[TableClause, Table],
rows: List[dict],
multiinsert: bool = True,
) -> None:
if not isinstance(rows, list):
raise TypeError("List expected")
elif rows and not isinstance(rows[0], dict):
raise TypeError("List of dictionaries expected")
if self.as_sql:
for row in rows:
self._exec(
sqla_compat._insert_inline(table).values(
**{
k: sqla_compat._literal_bindparam(
k, v, type_=table.c[k].type
)
if not isinstance(
v, sqla_compat._literal_bindparam
)
else v
for k, v in row.items()
}
)
)
else:
if rows:
if multiinsert:
self._exec(
sqla_compat._insert_inline(table), multiparams=rows
)
else:
for row in rows:
self._exec(
sqla_compat._insert_inline(table).values(**row)
)
def _tokenize_column_type(self, column: Column) -> Params:
definition = self.dialect.type_compiler.process(column.type).lower()
# tokenize the SQLAlchemy-generated version of a type, so that
# the two can be compared.
#
# examples:
# NUMERIC(10, 5)
# TIMESTAMP WITH TIMEZONE
# INTEGER UNSIGNED
# INTEGER (10) UNSIGNED
# INTEGER(10) UNSIGNED
# varchar character set utf8
#
tokens = re.findall(r"[\w\-_]+|\(.+?\)", definition)
term_tokens = []
paren_term = None
for token in tokens:
if re.match(r"^\(.*\)$", token):
paren_term = token
else:
term_tokens.append(token)
params = Params(term_tokens[0], term_tokens[1:], [], {})
if paren_term:
for term in re.findall("[^(),]+", paren_term):
if "=" in term:
key, val = term.split("=")
params.kwargs[key.strip()] = val.strip()
else:
params.args.append(term.strip())
return params
def _column_types_match(
self, inspector_params: Params, metadata_params: Params
) -> bool:
if inspector_params.token0 == metadata_params.token0:
return True
synonyms = [{t.lower() for t in batch} for batch in self.type_synonyms]
inspector_all_terms = " ".join(
[inspector_params.token0] + inspector_params.tokens
)
metadata_all_terms = " ".join(
[metadata_params.token0] + metadata_params.tokens
)
for batch in synonyms:
if {inspector_all_terms, metadata_all_terms}.issubset(batch) or {
inspector_params.token0,
metadata_params.token0,
}.issubset(batch):
return True
return False
def _column_args_match(
self, inspected_params: Params, meta_params: Params
) -> bool:
"""We want to compare column parameters. However, we only want
to compare parameters that are set. If they both have `collation`,
we want to make sure they are the same. However, if only one
specifies it, dont flag it for being less specific
"""
if (
len(meta_params.tokens) == len(inspected_params.tokens)
and meta_params.tokens != inspected_params.tokens
):
return False
if (
len(meta_params.args) == len(inspected_params.args)
and meta_params.args != inspected_params.args
):
return False
insp = " ".join(inspected_params.tokens).lower()
meta = " ".join(meta_params.tokens).lower()
for reg in self.type_arg_extract:
mi = re.search(reg, insp)
mm = re.search(reg, meta)
if mi and mm and mi.group(1) != mm.group(1):
return False
return True
def compare_type(
self, inspector_column: Column, metadata_column: Column
) -> bool:
"""Returns True if there ARE differences between the types of the two
columns. Takes impl.type_synonyms into account between retrospected
and metadata types
"""
inspector_params = self._tokenize_column_type(inspector_column)
metadata_params = self._tokenize_column_type(metadata_column)
if not self._column_types_match(inspector_params, metadata_params):
return True
if not self._column_args_match(inspector_params, metadata_params):
return True
return False
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
return rendered_inspector_default != rendered_metadata_default
def correct_for_autogen_constraints(
self,
conn_uniques: Set[UniqueConstraint],
conn_indexes: Set[Index],
metadata_unique_constraints: Set[UniqueConstraint],
metadata_indexes: Set[Index],
) -> None:
pass
def cast_for_batch_migrate(self, existing, existing_transfer, new_type):
if existing.type._type_affinity is not new_type._type_affinity:
existing_transfer["expr"] = cast(
existing_transfer["expr"], new_type
)
def render_ddl_sql_expr(
self, expr: ClauseElement, is_server_default: bool = False, **kw: Any
) -> str:
"""Render a SQL expression that is typically a server default,
index expression, etc.
.. versionadded:: 1.0.11
"""
compile_kw = {
"compile_kwargs": {"literal_binds": True, "include_table": False}
}
return str(
expr.compile(
dialect=self.dialect, **compile_kw # type: ignore[arg-type]
)
)
def _compat_autogen_column_reflect(self, inspector: Inspector) -> Callable:
return self.autogen_column_reflect
def correct_for_autogen_foreignkeys(
self,
conn_fks: Set[ForeignKeyConstraint],
metadata_fks: Set[ForeignKeyConstraint],
) -> None:
pass
def autogen_column_reflect(self, inspector, table, column_info):
"""A hook that is attached to the 'column_reflect' event for when
a Table is reflected from the database during the autogenerate
process.
Dialects can elect to modify the information gathered here.
"""
def start_migrations(self) -> None:
"""A hook called when :meth:`.EnvironmentContext.run_migrations`
is called.
Implementations can set up per-migration-run state here.
"""
def emit_begin(self) -> None:
"""Emit the string ``BEGIN``, or the backend-specific
equivalent, on the current connection context.
This is used in offline mode and typically
via :meth:`.EnvironmentContext.begin_transaction`.
"""
self.static_output("BEGIN" + self.command_terminator)
def emit_commit(self) -> None:
"""Emit the string ``COMMIT``, or the backend-specific
equivalent, on the current connection context.
This is used in offline mode and typically
via :meth:`.EnvironmentContext.begin_transaction`.
"""
self.static_output("COMMIT" + self.command_terminator)
def render_type(
self, type_obj: TypeEngine, autogen_context: AutogenContext
) -> Union[str, Literal[False]]:
return False
def _compare_identity_default(self, metadata_identity, inspector_identity):
# ignored contains the attributes that were not considered
# because assumed to their default values in the db.
diff, ignored = _compare_identity_options(
sqla_compat._identity_attrs,
metadata_identity,
inspector_identity,
sqla_compat.Identity(),
)
meta_always = getattr(metadata_identity, "always", None)
inspector_always = getattr(inspector_identity, "always", None)
# None and False are the same in this comparison
if bool(meta_always) != bool(inspector_always):
diff.add("always")
diff.difference_update(self.identity_attrs_ignore)
# returns 3 values:
return (
# different identity attributes
diff,
# ignored identity attributes
ignored,
# if the two identity should be considered different
bool(diff) or bool(metadata_identity) != bool(inspector_identity),
)
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
# order of col matters in an index
return tuple(col.name for col in index.columns)
def _skip_functional_indexes(self, metadata_indexes, conn_indexes):
conn_indexes_by_name = {c.name: c for c in conn_indexes}
for idx in list(metadata_indexes):
if idx.name in conn_indexes_by_name:
continue
iex = sqla_compat.is_expression_index(idx)
if iex:
util.warn(
"autogenerate skipping metadata-specified "
"expression-based index "
f"{idx.name!r}; dialect {self.__dialect__!r} under "
f"SQLAlchemy {sqla_compat.sqlalchemy_version} can't "
"reflect these indexes so they can't be compared"
)
metadata_indexes.discard(idx)
def _compare_identity_options(
attributes, metadata_io, inspector_io, default_io
):
# this can be used for identity or sequence compare.
# default_io is an instance of IdentityOption with all attributes to the
# default value.
diff = set()
ignored_attr = set()
for attr in attributes:
meta_value = getattr(metadata_io, attr, None)
default_value = getattr(default_io, attr, None)
conn_value = getattr(inspector_io, attr, None)
if conn_value != meta_value:
if meta_value == default_value:
ignored_attr.add(attr)
else:
diff.add(attr)
return diff, ignored_attr

408
libs/alembic/ddl/mssql.py Normal file
View File

@ -0,0 +1,408 @@
from __future__ import annotations
import re
from typing import Any
from typing import List
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import Column
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql.base import Executable
from sqlalchemy.sql.elements import ClauseElement
from .base import AddColumn
from .base import alter_column
from .base import alter_table
from .base import ColumnDefault
from .base import ColumnName
from .base import ColumnNullable
from .base import ColumnType
from .base import format_column_name
from .base import format_server_default
from .base import format_table_name
from .base import format_type
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.dialects.mssql.base import MSDDLCompiler
from sqlalchemy.dialects.mssql.base import MSSQLCompiler
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.sql.schema import Index
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.selectable import TableClause
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
class MSSQLImpl(DefaultImpl):
__dialect__ = "mssql"
transactional_ddl = True
batch_separator = "GO"
type_synonyms = DefaultImpl.type_synonyms + ({"VARCHAR", "NVARCHAR"},)
identity_attrs_ignore = (
"minvalue",
"maxvalue",
"nominvalue",
"nomaxvalue",
"cycle",
"cache",
"order",
"on_null",
"order",
)
def __init__(self, *arg, **kw) -> None:
super().__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"mssql_batch_separator", self.batch_separator
)
def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]:
result = super()._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
def emit_begin(self) -> None:
self.static_output("BEGIN TRANSACTION" + self.command_terminator)
def emit_commit(self) -> None:
super().emit_commit()
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Optional[
Union[_ServerDefault, Literal[False]]
] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
**kw: Any,
) -> None:
if nullable is not None:
if type_ is not None:
# the NULL/NOT NULL alter will handle
# the type alteration
existing_type = type_
type_ = None
elif existing_type is None:
raise util.CommandError(
"MS-SQL ALTER COLUMN operations "
"with NULL or NOT NULL require the "
"existing_type or a new type_ be passed."
)
elif existing_nullable is not None and type_ is not None:
nullable = existing_nullable
# the NULL/NOT NULL alter will handle
# the type alteration
existing_type = type_
type_ = None
elif type_ is not None:
util.warn(
"MS-SQL ALTER COLUMN operations that specify type_= "
"should also specify a nullable= or "
"existing_nullable= argument to avoid implicit conversion "
"of NOT NULL columns to NULL."
)
used_default = False
if sqla_compat._server_default_is_identity(
server_default, existing_server_default
) or sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
used_default = True
kw["server_default"] = server_default
kw["existing_server_default"] = existing_server_default
super().alter_column(
table_name,
column_name,
nullable=nullable,
type_=type_,
schema=schema,
existing_type=existing_type,
existing_nullable=existing_nullable,
**kw,
)
if server_default is not False and used_default is False:
if existing_server_default is not False or server_default is None:
self._exec(
_ExecDropConstraint(
table_name,
column_name,
"sys.default_constraints",
schema,
)
)
if server_default is not None:
super().alter_column(
table_name,
column_name,
schema=schema,
server_default=server_default,
)
if name is not None:
super().alter_column(
table_name, column_name, schema=schema, name=name
)
def create_index(self, index: Index) -> None:
# this likely defaults to None if not present, so get()
# should normally not return the default value. being
# defensive in any case
mssql_include = index.kwargs.get("mssql_include", None) or ()
assert index.table is not None
for col in mssql_include:
if col not in index.table.c:
index.table.append_column(Column(col, sqltypes.NullType))
self._exec(CreateIndex(index))
def bulk_insert( # type:ignore[override]
self, table: Union[TableClause, Table], rows: List[dict], **kw: Any
) -> None:
if self.as_sql:
self._exec(
"SET IDENTITY_INSERT %s ON"
% self.dialect.identifier_preparer.format_table(table)
)
super().bulk_insert(table, rows, **kw)
self._exec(
"SET IDENTITY_INSERT %s OFF"
% self.dialect.identifier_preparer.format_table(table)
)
else:
super().bulk_insert(table, rows, **kw)
def drop_column(
self,
table_name: str,
column: Column,
schema: Optional[str] = None,
**kw,
) -> None:
drop_default = kw.pop("mssql_drop_default", False)
if drop_default:
self._exec(
_ExecDropConstraint(
table_name, column, "sys.default_constraints", schema
)
)
drop_check = kw.pop("mssql_drop_check", False)
if drop_check:
self._exec(
_ExecDropConstraint(
table_name, column, "sys.check_constraints", schema
)
)
drop_fks = kw.pop("mssql_drop_foreign_key", False)
if drop_fks:
self._exec(_ExecDropFKConstraint(table_name, column, schema))
super().drop_column(table_name, column, schema=schema, **kw)
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
r"[\(\) \"\']", "", rendered_metadata_default
)
if rendered_inspector_default is not None:
# SQL Server collapses whitespace and adds arbitrary parenthesis
# within expressions. our only option is collapse all of it
rendered_inspector_default = re.sub(
r"[\(\) \"\']", "", rendered_inspector_default
)
return rendered_inspector_default != rendered_metadata_default
def _compare_identity_default(self, metadata_identity, inspector_identity):
diff, ignored, is_alter = super()._compare_identity_default(
metadata_identity, inspector_identity
)
if (
metadata_identity is None
and inspector_identity is not None
and not diff
and inspector_identity.column is not None
and inspector_identity.column.primary_key
):
# mssql reflect primary keys with autoincrement as identity
# columns. if no different attributes are present ignore them
is_alter = False
return diff, ignored, is_alter
class _ExecDropConstraint(Executable, ClauseElement):
inherit_cache = False
def __init__(
self,
tname: str,
colname: Union[Column, str],
type_: str,
schema: Optional[str],
) -> None:
self.tname = tname
self.colname = colname
self.type_ = type_
self.schema = schema
class _ExecDropFKConstraint(Executable, ClauseElement):
inherit_cache = False
def __init__(
self, tname: str, colname: Column, schema: Optional[str]
) -> None:
self.tname = tname
self.colname = colname
self.schema = schema
@compiles(_ExecDropConstraint, "mssql")
def _exec_drop_col_constraint(
element: _ExecDropConstraint, compiler: MSSQLCompiler, **kw
) -> str:
schema, tname, colname, type_ = (
element.schema,
element.tname,
element.colname,
element.type_,
)
# from http://www.mssqltips.com/sqlservertip/1425/\
# working-with-default-constraints-in-sql-server/
return """declare @const_name varchar(256)
select @const_name = QUOTENAME([name]) from %(type)s
where parent_object_id = object_id('%(schema_dot)s%(tname)s')
and col_name(parent_object_id, parent_column_id) = '%(colname)s'
exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
"type": type_,
"tname": tname,
"colname": colname,
"tname_quoted": format_table_name(compiler, tname, schema),
"schema_dot": schema + "." if schema else "",
}
@compiles(_ExecDropFKConstraint, "mssql")
def _exec_drop_col_fk_constraint(
element: _ExecDropFKConstraint, compiler: MSSQLCompiler, **kw
) -> str:
schema, tname, colname = element.schema, element.tname, element.colname
return """declare @const_name varchar(256)
select @const_name = QUOTENAME([name]) from
sys.foreign_keys fk join sys.foreign_key_columns fkc
on fk.object_id=fkc.constraint_object_id
where fkc.parent_object_id = object_id('%(schema_dot)s%(tname)s')
and col_name(fkc.parent_object_id, fkc.parent_column_id) = '%(colname)s'
exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % {
"tname": tname,
"colname": colname,
"tname_quoted": format_table_name(compiler, tname, schema),
"schema_dot": schema + "." if schema else "",
}
@compiles(AddColumn, "mssql")
def visit_add_column(element: AddColumn, compiler: MSDDLCompiler, **kw) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
mssql_add_column(compiler, element.column, **kw),
)
def mssql_add_column(compiler: MSDDLCompiler, column: Column, **kw) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(ColumnNullable, "mssql")
def visit_column_nullable(
element: ColumnNullable, compiler: MSDDLCompiler, **kw
) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
format_type(compiler, element.existing_type), # type: ignore[arg-type]
"NULL" if element.nullable else "NOT NULL",
)
@compiles(ColumnDefault, "mssql")
def visit_column_default(
element: ColumnDefault, compiler: MSDDLCompiler, **kw
) -> str:
# TODO: there can also be a named constraint
# with ADD CONSTRAINT here
return "%s ADD DEFAULT %s FOR %s" % (
alter_table(compiler, element.table_name, element.schema),
format_server_default(compiler, element.default),
format_column_name(compiler, element.column_name),
)
@compiles(ColumnName, "mssql")
def visit_rename_column(
element: ColumnName, compiler: MSDDLCompiler, **kw
) -> str:
return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % (
format_table_name(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)
@compiles(ColumnType, "mssql")
def visit_column_type(
element: ColumnType, compiler: MSDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
format_type(compiler, element.type_),
)
@compiles(RenameTable, "mssql")
def visit_rename_table(
element: RenameTable, compiler: MSDDLCompiler, **kw
) -> str:
return "EXEC sp_rename '%s', %s" % (
format_table_name(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)

463
libs/alembic/ddl/mysql.py Normal file
View File

@ -0,0 +1,463 @@
from __future__ import annotations
import re
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import schema
from sqlalchemy import types as sqltypes
from sqlalchemy.ext.compiler import compiles
from .base import alter_table
from .base import AlterColumn
from .base import ColumnDefault
from .base import ColumnName
from .base import ColumnNullable
from .base import ColumnType
from .base import format_column_name
from .base import format_server_default
from .impl import DefaultImpl
from .. import util
from ..autogenerate import compare
from ..util import sqla_compat
from ..util.sqla_compat import _is_mariadb
from ..util.sqla_compat import _is_type_bound
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
from sqlalchemy.sql.ddl import DropConstraint
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
class MySQLImpl(DefaultImpl):
__dialect__ = "mysql"
transactional_ddl = False
type_synonyms = DefaultImpl.type_synonyms + (
{"BOOL", "TINYINT"},
{"JSON", "LONGTEXT"},
)
type_arg_extract = [r"character set ([\w\-_]+)", r"collate ([\w\-_]+)"]
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
autoincrement: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
comment: Optional[Union[str, Literal[False]]] = False,
existing_comment: Optional[str] = None,
**kw: Any,
) -> None:
if sqla_compat._server_default_is_identity(
server_default, existing_server_default
) or sqla_compat._server_default_is_computed(
server_default, existing_server_default
):
# modifying computed or identity columns is not supported
# the default will raise
super().alter_column(
table_name,
column_name,
nullable=nullable,
type_=type_,
schema=schema,
existing_type=existing_type,
existing_nullable=existing_nullable,
server_default=server_default,
existing_server_default=existing_server_default,
**kw,
)
if name is not None or self._is_mysql_allowed_functional_default(
type_ if type_ is not None else existing_type, server_default
):
self._exec(
MySQLChangeColumn(
table_name,
column_name,
schema=schema,
newname=name if name is not None else column_name,
nullable=nullable
if nullable is not None
else existing_nullable
if existing_nullable is not None
else True,
type_=type_ if type_ is not None else existing_type,
default=server_default
if server_default is not False
else existing_server_default,
autoincrement=autoincrement
if autoincrement is not None
else existing_autoincrement,
comment=comment
if comment is not False
else existing_comment,
)
)
elif (
nullable is not None
or type_ is not None
or autoincrement is not None
or comment is not False
):
self._exec(
MySQLModifyColumn(
table_name,
column_name,
schema=schema,
newname=name if name is not None else column_name,
nullable=nullable
if nullable is not None
else existing_nullable
if existing_nullable is not None
else True,
type_=type_ if type_ is not None else existing_type,
default=server_default
if server_default is not False
else existing_server_default,
autoincrement=autoincrement
if autoincrement is not None
else existing_autoincrement,
comment=comment
if comment is not False
else existing_comment,
)
)
elif server_default is not False:
self._exec(
MySQLAlterDefault(
table_name, column_name, server_default, schema=schema
)
)
def drop_constraint(
self,
const: Constraint,
) -> None:
if isinstance(const, schema.CheckConstraint) and _is_type_bound(const):
return
super().drop_constraint(const)
def _is_mysql_allowed_functional_default(
self,
type_: Optional[TypeEngine],
server_default: Union[_ServerDefault, Literal[False]],
) -> bool:
return (
type_ is not None
and type_._type_affinity # type:ignore[attr-defined]
is sqltypes.DateTime
and server_default is not None
)
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
# partially a workaround for SQLAlchemy issue #3023; if the
# column were created without "NOT NULL", MySQL may have added
# an implicit default of '0' which we need to skip
# TODO: this is not really covered anymore ?
if (
metadata_column.type._type_affinity is sqltypes.Integer
and inspector_column.primary_key
and not inspector_column.autoincrement
and not rendered_metadata_default
and rendered_inspector_default == "'0'"
):
return False
elif inspector_column.type._type_affinity is sqltypes.Integer:
rendered_inspector_default = (
re.sub(r"^'|'$", "", rendered_inspector_default)
if rendered_inspector_default is not None
else None
)
return rendered_inspector_default != rendered_metadata_default
elif rendered_inspector_default and rendered_metadata_default:
# adjust for "function()" vs. "FUNCTION" as can occur particularly
# for the CURRENT_TIMESTAMP function on newer MariaDB versions
# SQLAlchemy MySQL dialect bundles ON UPDATE into the server
# default; adjust for this possibly being present.
onupdate_ins = re.match(
r"(.*) (on update.*?)(?:\(\))?$",
rendered_inspector_default.lower(),
)
onupdate_met = re.match(
r"(.*) (on update.*?)(?:\(\))?$",
rendered_metadata_default.lower(),
)
if onupdate_ins:
if not onupdate_met:
return True
elif onupdate_ins.group(2) != onupdate_met.group(2):
return True
rendered_inspector_default = onupdate_ins.group(1)
rendered_metadata_default = onupdate_met.group(1)
return re.sub(
r"(.*?)(?:\(\))?$", r"\1", rendered_inspector_default.lower()
) != re.sub(
r"(.*?)(?:\(\))?$", r"\1", rendered_metadata_default.lower()
)
else:
return rendered_inspector_default != rendered_metadata_default
def correct_for_autogen_constraints(
self,
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
metadata_indexes,
):
# TODO: if SQLA 1.0, make use of "duplicates_index"
# metadata
removed = set()
for idx in list(conn_indexes):
if idx.unique:
continue
# MySQL puts implicit indexes on FK columns, even if
# composite and even if MyISAM, so can't check this too easily.
# the name of the index may be the column name or it may
# be the name of the FK constraint.
for col in idx.columns:
if idx.name == col.name:
conn_indexes.remove(idx)
removed.add(idx.name)
break
for fk in col.foreign_keys:
if fk.name == idx.name:
conn_indexes.remove(idx)
removed.add(idx.name)
break
if idx.name in removed:
break
# then remove indexes from the "metadata_indexes"
# that we've removed from reflected, otherwise they come out
# as adds (see #202)
for idx in list(metadata_indexes):
if idx.name in removed:
metadata_indexes.remove(idx)
def correct_for_autogen_foreignkeys(self, conn_fks, metadata_fks):
conn_fk_by_sig = {
compare._fk_constraint_sig(fk).sig: fk for fk in conn_fks
}
metadata_fk_by_sig = {
compare._fk_constraint_sig(fk).sig: fk for fk in metadata_fks
}
for sig in set(conn_fk_by_sig).intersection(metadata_fk_by_sig):
mdfk = metadata_fk_by_sig[sig]
cnfk = conn_fk_by_sig[sig]
# MySQL considers RESTRICT to be the default and doesn't
# report on it. if the model has explicit RESTRICT and
# the conn FK has None, set it to RESTRICT
if (
mdfk.ondelete is not None
and mdfk.ondelete.lower() == "restrict"
and cnfk.ondelete is None
):
cnfk.ondelete = "RESTRICT"
if (
mdfk.onupdate is not None
and mdfk.onupdate.lower() == "restrict"
and cnfk.onupdate is None
):
cnfk.onupdate = "RESTRICT"
class MariaDBImpl(MySQLImpl):
__dialect__ = "mariadb"
class MySQLAlterDefault(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
default: _ServerDefault,
schema: Optional[str] = None,
) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.default = default
class MySQLChangeColumn(AlterColumn):
def __init__(
self,
name: str,
column_name: str,
schema: Optional[str] = None,
newname: Optional[str] = None,
type_: Optional[TypeEngine] = None,
nullable: Optional[bool] = None,
default: Optional[Union[_ServerDefault, Literal[False]]] = False,
autoincrement: Optional[bool] = None,
comment: Optional[Union[str, Literal[False]]] = False,
) -> None:
super(AlterColumn, self).__init__(name, schema=schema)
self.column_name = column_name
self.nullable = nullable
self.newname = newname
self.default = default
self.autoincrement = autoincrement
self.comment = comment
if type_ is None:
raise util.CommandError(
"All MySQL CHANGE/MODIFY COLUMN operations "
"require the existing type."
)
self.type_ = sqltypes.to_instance(type_)
class MySQLModifyColumn(MySQLChangeColumn):
pass
@compiles(ColumnNullable, "mysql", "mariadb")
@compiles(ColumnName, "mysql", "mariadb")
@compiles(ColumnDefault, "mysql", "mariadb")
@compiles(ColumnType, "mysql", "mariadb")
def _mysql_doesnt_support_individual(element, compiler, **kw):
raise NotImplementedError(
"Individual alter column constructs not supported by MySQL"
)
@compiles(MySQLAlterDefault, "mysql", "mariadb")
def _mysql_alter_default(
element: MySQLAlterDefault, compiler: MySQLDDLCompiler, **kw
) -> str:
return "%s ALTER COLUMN %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
"SET DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DROP DEFAULT",
)
@compiles(MySQLModifyColumn, "mysql", "mariadb")
def _mysql_modify_column(
element: MySQLModifyColumn, compiler: MySQLDDLCompiler, **kw
) -> str:
return "%s MODIFY %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
_mysql_colspec(
compiler,
nullable=element.nullable,
server_default=element.default,
type_=element.type_,
autoincrement=element.autoincrement,
comment=element.comment,
),
)
@compiles(MySQLChangeColumn, "mysql", "mariadb")
def _mysql_change_column(
element: MySQLChangeColumn, compiler: MySQLDDLCompiler, **kw
) -> str:
return "%s CHANGE %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
_mysql_colspec(
compiler,
nullable=element.nullable,
server_default=element.default,
type_=element.type_,
autoincrement=element.autoincrement,
comment=element.comment,
),
)
def _mysql_colspec(
compiler: MySQLDDLCompiler,
nullable: Optional[bool],
server_default: Optional[Union[_ServerDefault, Literal[False]]],
type_: TypeEngine,
autoincrement: Optional[bool],
comment: Optional[Union[str, Literal[False]]],
) -> str:
spec = "%s %s" % (
compiler.dialect.type_compiler.process(type_),
"NULL" if nullable else "NOT NULL",
)
if autoincrement:
spec += " AUTO_INCREMENT"
if server_default is not False and server_default is not None:
spec += " DEFAULT %s" % format_server_default(compiler, server_default)
if comment:
spec += " COMMENT %s" % compiler.sql_compiler.render_literal_value(
comment, sqltypes.String()
)
return spec
@compiles(schema.DropConstraint, "mysql", "mariadb")
def _mysql_drop_constraint(
element: DropConstraint, compiler: MySQLDDLCompiler, **kw
) -> str:
"""Redefine SQLAlchemy's drop constraint to
raise errors for invalid constraint type."""
constraint = element.element
if isinstance(
constraint,
(
schema.ForeignKeyConstraint,
schema.PrimaryKeyConstraint,
schema.UniqueConstraint,
),
):
assert not kw
return compiler.visit_drop_constraint(element)
elif isinstance(constraint, schema.CheckConstraint):
# note that SQLAlchemy as of 1.2 does not yet support
# DROP CONSTRAINT for MySQL/MariaDB, so we implement fully
# here.
if _is_mariadb(compiler.dialect):
return "ALTER TABLE %s DROP CONSTRAINT %s" % (
compiler.preparer.format_table(constraint.table),
compiler.preparer.format_constraint(constraint),
)
else:
return "ALTER TABLE %s DROP CHECK %s" % (
compiler.preparer.format_table(constraint.table),
compiler.preparer.format_constraint(constraint),
)
else:
raise NotImplementedError(
"No generic 'DROP CONSTRAINT' in MySQL - "
"please specify constraint type"
)

197
libs/alembic/ddl/oracle.py Normal file
View File

@ -0,0 +1,197 @@
from __future__ import annotations
import re
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes
from .base import AddColumn
from .base import alter_table
from .base import ColumnComment
from .base import ColumnDefault
from .base import ColumnName
from .base import ColumnNullable
from .base import ColumnType
from .base import format_column_name
from .base import format_server_default
from .base import format_table_name
from .base import format_type
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import DefaultImpl
if TYPE_CHECKING:
from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
from sqlalchemy.engine.cursor import CursorResult
from sqlalchemy.sql.schema import Column
class OracleImpl(DefaultImpl):
__dialect__ = "oracle"
transactional_ddl = False
batch_separator = "/"
command_terminator = ""
type_synonyms = DefaultImpl.type_synonyms + (
{"VARCHAR", "VARCHAR2"},
{"BIGINT", "INTEGER", "SMALLINT", "DECIMAL", "NUMERIC", "NUMBER"},
{"DOUBLE", "FLOAT", "DOUBLE_PRECISION"},
)
identity_attrs_ignore = ()
def __init__(self, *arg, **kw) -> None:
super().__init__(*arg, **kw)
self.batch_separator = self.context_opts.get(
"oracle_batch_separator", self.batch_separator
)
def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]:
result = super()._exec(construct, *args, **kw)
if self.as_sql and self.batch_separator:
self.static_output(self.batch_separator)
return result
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_metadata_default
)
rendered_metadata_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default
)
if rendered_inspector_default is not None:
rendered_inspector_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_inspector_default
)
rendered_inspector_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default
)
rendered_inspector_default = rendered_inspector_default.strip()
return rendered_inspector_default != rendered_metadata_default
def emit_begin(self) -> None:
self._exec("SET TRANSACTION READ WRITE")
def emit_commit(self) -> None:
self._exec("COMMIT")
@compiles(AddColumn, "oracle")
def visit_add_column(
element: AddColumn, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s" % (
alter_table(compiler, element.table_name, element.schema),
add_column(compiler, element.column, **kw),
)
@compiles(ColumnNullable, "oracle")
def visit_column_nullable(
element: ColumnNullable, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"NULL" if element.nullable else "NOT NULL",
)
@compiles(ColumnType, "oracle")
def visit_column_type(
element: ColumnType, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"%s" % format_type(compiler, element.type_),
)
@compiles(ColumnName, "oracle")
def visit_column_name(
element: ColumnName, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s RENAME COLUMN %s TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_column_name(compiler, element.column_name),
format_column_name(compiler, element.newname),
)
@compiles(ColumnDefault, "oracle")
def visit_column_default(
element: ColumnDefault, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"DEFAULT %s" % format_server_default(compiler, element.default)
if element.default is not None
else "DEFAULT NULL",
)
@compiles(ColumnComment, "oracle")
def visit_column_comment(
element: ColumnComment, compiler: OracleDDLCompiler, **kw
) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = compiler.sql_compiler.render_literal_value(
(element.comment if element.comment is not None else ""),
sqltypes.String(),
)
return ddl.format(
table_name=element.table_name,
column_name=element.column_name,
comment=comment,
)
@compiles(RenameTable, "oracle")
def visit_rename_table(
element: RenameTable, compiler: OracleDDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
def alter_column(compiler: OracleDDLCompiler, name: str) -> str:
return "MODIFY %s" % format_column_name(compiler, name)
def add_column(compiler: OracleDDLCompiler, column: Column, **kw) -> str:
return "ADD %s" % compiler.get_column_specification(column, **kw)
@compiles(IdentityColumnDefault, "oracle")
def visit_identity_column(
element: IdentityColumnDefault, compiler: OracleDDLCompiler, **kw
):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
)
if element.default is None:
# drop identity
text += "DROP IDENTITY"
return text
else:
text += compiler.visit_identity_column(element.default)
return text

View File

@ -0,0 +1,688 @@
from __future__ import annotations
import logging
import re
from typing import Any
from typing import cast
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import Column
from sqlalchemy import Index
from sqlalchemy import literal_column
from sqlalchemy import Numeric
from sqlalchemy import text
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.postgresql import BIGINT
from sqlalchemy.dialects.postgresql import ExcludeConstraint
from sqlalchemy.dialects.postgresql import INTEGER
from sqlalchemy.schema import CreateIndex
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.types import NULLTYPE
from .base import alter_column
from .base import alter_table
from .base import AlterColumn
from .base import ColumnComment
from .base import compiles
from .base import format_column_name
from .base import format_table_name
from .base import format_type
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
from ..autogenerate import render
from ..operations import ops
from ..operations import schemaobj
from ..operations.base import BatchOperations
from ..operations.base import Operations
from ..util import sqla_compat
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.dialects.postgresql.array import ARRAY
from sqlalchemy.dialects.postgresql.base import PGDDLCompiler
from sqlalchemy.dialects.postgresql.hstore import HSTORE
from sqlalchemy.dialects.postgresql.json import JSON
from sqlalchemy.dialects.postgresql.json import JSONB
from sqlalchemy.sql.elements import BinaryExpression
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.type_api import TypeEngine
from .base import _ServerDefault
from ..autogenerate.api import AutogenContext
from ..autogenerate.render import _f_name
from ..runtime.migration import MigrationContext
log = logging.getLogger(__name__)
class PostgresqlImpl(DefaultImpl):
__dialect__ = "postgresql"
transactional_ddl = True
type_synonyms = DefaultImpl.type_synonyms + (
{"FLOAT", "DOUBLE PRECISION"},
)
identity_attrs_ignore = ("on_null", "order")
def create_index(self, index):
# this likely defaults to None if not present, so get()
# should normally not return the default value. being
# defensive in any case
postgresql_include = index.kwargs.get("postgresql_include", None) or ()
for col in postgresql_include:
if col not in index.table.c:
index.table.append_column(Column(col, sqltypes.NullType))
self._exec(CreateIndex(index))
def prep_table_for_batch(self, batch_impl, table):
for constraint in table.constraints:
if (
constraint.name is not None
and constraint.name in batch_impl.named_constraints
):
self.drop_constraint(constraint)
def compare_server_default(
self,
inspector_column,
metadata_column,
rendered_metadata_default,
rendered_inspector_default,
):
# don't do defaults for SERIAL columns
if (
metadata_column.primary_key
and metadata_column is metadata_column.table._autoincrement_column
):
return False
conn_col_default = rendered_inspector_default
defaults_equal = conn_col_default == rendered_metadata_default
if defaults_equal:
return False
if None in (
conn_col_default,
rendered_metadata_default,
metadata_column.server_default,
):
return not defaults_equal
metadata_default = metadata_column.server_default.arg
if isinstance(metadata_default, str):
if not isinstance(inspector_column.type, Numeric):
metadata_default = re.sub(r"^'|'$", "", metadata_default)
metadata_default = f"'{metadata_default}'"
metadata_default = literal_column(metadata_default)
# run a real compare against the server
return not self.connection.scalar(
sqla_compat._select(
literal_column(conn_col_default) == metadata_default
)
)
def alter_column( # type:ignore[override]
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Union[_ServerDefault, Literal[False]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
schema: Optional[str] = None,
autoincrement: Optional[bool] = None,
existing_type: Optional[TypeEngine] = None,
existing_server_default: Optional[_ServerDefault] = None,
existing_nullable: Optional[bool] = None,
existing_autoincrement: Optional[bool] = None,
**kw: Any,
) -> None:
using = kw.pop("postgresql_using", None)
if using is not None and type_ is None:
raise util.CommandError(
"postgresql_using must be used with the type_ parameter"
)
if type_ is not None:
self._exec(
PostgresqlColumnType(
table_name,
column_name,
type_,
schema=schema,
using=using,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
)
)
super().alter_column(
table_name,
column_name,
nullable=nullable,
server_default=server_default,
name=name,
schema=schema,
autoincrement=autoincrement,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
existing_autoincrement=existing_autoincrement,
**kw,
)
def autogen_column_reflect(self, inspector, table, column_info):
if column_info.get("default") and isinstance(
column_info["type"], (INTEGER, BIGINT)
):
seq_match = re.match(
r"nextval\('(.+?)'::regclass\)", column_info["default"]
)
if seq_match:
info = sqla_compat._exec_on_inspector(
inspector,
text(
"select c.relname, a.attname "
"from pg_class as c join "
"pg_depend d on d.objid=c.oid and "
"d.classid='pg_class'::regclass and "
"d.refclassid='pg_class'::regclass "
"join pg_class t on t.oid=d.refobjid "
"join pg_attribute a on a.attrelid=t.oid and "
"a.attnum=d.refobjsubid "
"where c.relkind='S' and c.relname=:seqname"
),
seqname=seq_match.group(1),
).first()
if info:
seqname, colname = info
if colname == column_info["name"]:
log.info(
"Detected sequence named '%s' as "
"owned by integer column '%s(%s)', "
"assuming SERIAL and omitting",
seqname,
table.name,
colname,
)
# sequence, and the owner is this column,
# its a SERIAL - whack it!
del column_info["default"]
def correct_for_autogen_constraints(
self,
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
metadata_indexes,
):
doubled_constraints = {
index
for index in conn_indexes
if index.info.get("duplicates_constraint")
}
for ix in doubled_constraints:
conn_indexes.remove(ix)
if not sqla_compat.sqla_2:
self._skip_functional_indexes(metadata_indexes, conn_indexes)
def _cleanup_index_expr(self, index: Index, expr: str) -> str:
# start = expr
expr = expr.lower()
expr = expr.replace('"', "")
if index.table is not None:
expr = expr.replace(f"{index.table.name.lower()}.", "")
while expr and expr[0] == "(" and expr[-1] == ")":
expr = expr[1:-1]
if "::" in expr:
# strip :: cast. types can have spaces in them
expr = re.sub(r"(::[\w ]+\w)", "", expr)
# print(f"START: {start} END: {expr}")
return expr
def create_index_sig(self, index: Index) -> Tuple[Any, ...]:
if sqla_compat.is_expression_index(index):
return tuple(
self._cleanup_index_expr(
index,
e
if isinstance(e, str)
else e.compile(
dialect=self.dialect,
compile_kwargs={"literal_binds": True},
).string,
)
for e in index.expressions
)
else:
return super().create_index_sig(index)
def render_type(
self, type_: TypeEngine, autogen_context: AutogenContext
) -> Union[str, Literal[False]]:
mod = type(type_).__module__
if not mod.startswith("sqlalchemy.dialects.postgresql"):
return False
if hasattr(self, "_render_%s_type" % type_.__visit_name__):
meth = getattr(self, "_render_%s_type" % type_.__visit_name__)
return meth(type_, autogen_context)
return False
def _render_HSTORE_type(
self, type_: HSTORE, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "text_type", r"(.+?\(.*text_type=)"
),
)
def _render_ARRAY_type(
self, type_: ARRAY, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "item_type", r"(.+?\()"
),
)
def _render_JSON_type(
self, type_: JSON, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
),
)
def _render_JSONB_type(
self, type_: JSONB, autogen_context: AutogenContext
) -> str:
return cast(
str,
render._render_type_w_subtype(
type_, autogen_context, "astext_type", r"(.+?\(.*astext_type=)"
),
)
class PostgresqlColumnType(AlterColumn):
def __init__(
self, name: str, column_name: str, type_: TypeEngine, **kw
) -> None:
using = kw.pop("using", None)
super().__init__(name, column_name, **kw)
self.type_ = sqltypes.to_instance(type_)
self.using = using
@compiles(RenameTable, "postgresql")
def visit_rename_table(
element: RenameTable, compiler: PGDDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
@compiles(PostgresqlColumnType, "postgresql")
def visit_column_type(
element: PostgresqlColumnType, compiler: PGDDLCompiler, **kw
) -> str:
return "%s %s %s %s" % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
"TYPE %s" % format_type(compiler, element.type_),
"USING %s" % element.using if element.using else "",
)
@compiles(ColumnComment, "postgresql")
def visit_column_comment(
element: ColumnComment, compiler: PGDDLCompiler, **kw
) -> str:
ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"
comment = (
compiler.sql_compiler.render_literal_value(
element.comment, sqltypes.String()
)
if element.comment is not None
else "NULL"
)
return ddl.format(
table_name=format_table_name(
compiler, element.table_name, element.schema
),
column_name=format_column_name(compiler, element.column_name),
comment=comment,
)
@compiles(IdentityColumnDefault, "postgresql")
def visit_identity_column(
element: IdentityColumnDefault, compiler: PGDDLCompiler, **kw
):
text = "%s %s " % (
alter_table(compiler, element.table_name, element.schema),
alter_column(compiler, element.column_name),
)
if element.default is None:
# drop identity
text += "DROP IDENTITY"
return text
elif element.existing_server_default is None:
# add identity options
text += "ADD "
text += compiler.visit_identity_column(element.default)
return text
else:
# alter identity
diff, _, _ = element.impl._compare_identity_default(
element.default, element.existing_server_default
)
identity = element.default
for attr in sorted(diff):
if attr == "always":
text += "SET GENERATED %s " % (
"ALWAYS" if identity.always else "BY DEFAULT"
)
else:
text += "SET %s " % compiler.get_identity_options(
sqla_compat.Identity(**{attr: getattr(identity, attr)})
)
return text
@Operations.register_operation("create_exclude_constraint")
@BatchOperations.register_operation(
"create_exclude_constraint", "batch_create_exclude_constraint"
)
@ops.AddConstraintOp.register_add_constraint("exclude_constraint")
class CreateExcludeConstraintOp(ops.AddConstraintOp):
"""Represent a create exclude constraint operation."""
constraint_type = "exclude"
def __init__(
self,
constraint_name: sqla_compat._ConstraintName,
table_name: Union[str, quoted_name],
elements: Union[
Sequence[Tuple[str, str]],
Sequence[Tuple[ColumnClause, str]],
],
where: Optional[Union[BinaryExpression, str]] = None,
schema: Optional[str] = None,
_orig_constraint: Optional[ExcludeConstraint] = None,
**kw,
) -> None:
self.constraint_name = constraint_name
self.table_name = table_name
self.elements = elements
self.where = where
self.schema = schema
self._orig_constraint = _orig_constraint
self.kw = kw
@classmethod
def from_constraint( # type:ignore[override]
cls, constraint: ExcludeConstraint
) -> CreateExcludeConstraintOp:
constraint_table = sqla_compat._table_for_constraint(constraint)
return cls(
constraint.name,
constraint_table.name,
[
(expr, op)
for expr, name, op in constraint._render_exprs # type:ignore[attr-defined] # noqa
],
where=cast(
"Optional[Union[BinaryExpression, str]]", constraint.where
),
schema=constraint_table.schema,
_orig_constraint=constraint,
deferrable=constraint.deferrable,
initially=constraint.initially,
using=constraint.using,
)
def to_constraint(
self, migration_context: Optional[MigrationContext] = None
) -> ExcludeConstraint:
if self._orig_constraint is not None:
return self._orig_constraint
schema_obj = schemaobj.SchemaObjects(migration_context)
t = schema_obj.table(self.table_name, schema=self.schema)
excl = ExcludeConstraint(
*self.elements,
name=self.constraint_name,
where=self.where,
**self.kw,
)
for (
expr,
name,
oper,
) in excl._render_exprs: # type:ignore[attr-defined]
t.append_column(Column(name, NULLTYPE))
t.append_constraint(excl)
return excl
@classmethod
def create_exclude_constraint(
cls,
operations: Operations,
constraint_name: str,
table_name: str,
*elements: Any,
**kw: Any,
) -> Optional[Table]:
"""Issue an alter to create an EXCLUDE constraint using the
current migration context.
.. note:: This method is Postgresql specific, and additionally
requires at least SQLAlchemy 1.0.
e.g.::
from alembic import op
op.create_exclude_constraint(
"user_excl",
"user",
("period", '&&'),
("group", '='),
where=("group != 'some group'")
)
Note that the expressions work the same way as that of
the ``ExcludeConstraint`` object itself; if plain strings are
passed, quoting rules must be applied manually.
:param name: Name of the constraint.
:param table_name: String name of the source table.
:param elements: exclude conditions.
:param where: SQL expression or SQL string with optional WHERE
clause.
:param deferrable: optional bool. If set, emit DEFERRABLE or
NOT DEFERRABLE when issuing DDL for this constraint.
:param initially: optional string. If set, emit INITIALLY <value>
when issuing DDL for this constraint.
:param schema: Optional schema name to operate within.
"""
op = cls(constraint_name, table_name, elements, **kw)
return operations.invoke(op)
@classmethod
def batch_create_exclude_constraint(
cls, operations, constraint_name, *elements, **kw
):
"""Issue a "create exclude constraint" instruction using the
current batch migration context.
.. note:: This method is Postgresql specific, and additionally
requires at least SQLAlchemy 1.0.
.. seealso::
:meth:`.Operations.create_exclude_constraint`
"""
kw["schema"] = operations.impl.schema
op = cls(constraint_name, operations.impl.table_name, elements, **kw)
return operations.invoke(op)
@render.renderers.dispatch_for(CreateExcludeConstraintOp)
def _add_exclude_constraint(
autogen_context: AutogenContext, op: CreateExcludeConstraintOp
) -> str:
return _exclude_constraint(op.to_constraint(), autogen_context, alter=True)
@render._constraint_renderers.dispatch_for(ExcludeConstraint)
def _render_inline_exclude_constraint(
constraint: ExcludeConstraint,
autogen_context: AutogenContext,
namespace_metadata: MetaData,
) -> str:
rendered = render._user_defined_render(
"exclude", constraint, autogen_context
)
if rendered is not False:
return rendered
return _exclude_constraint(constraint, autogen_context, False)
def _postgresql_autogenerate_prefix(autogen_context: AutogenContext) -> str:
imports = autogen_context.imports
if imports is not None:
imports.add("from sqlalchemy.dialects import postgresql")
return "postgresql."
def _exclude_constraint(
constraint: ExcludeConstraint,
autogen_context: AutogenContext,
alter: bool,
) -> str:
opts: List[Tuple[str, Union[quoted_name, str, _f_name, None]]] = []
has_batch = autogen_context._has_batch
if constraint.deferrable:
opts.append(("deferrable", str(constraint.deferrable)))
if constraint.initially:
opts.append(("initially", str(constraint.initially)))
if constraint.using:
opts.append(("using", str(constraint.using)))
if not has_batch and alter and constraint.table.schema:
opts.append(("schema", render._ident(constraint.table.schema)))
if not alter and constraint.name:
opts.append(
("name", render._render_gen_name(autogen_context, constraint.name))
)
if alter:
args = [
repr(render._render_gen_name(autogen_context, constraint.name))
]
if not has_batch:
args += [repr(render._ident(constraint.table.name))]
args.extend(
[
"(%s, %r)"
% (
_render_potential_column(sqltext, autogen_context),
opstring,
)
for sqltext, name, opstring in constraint._render_exprs # type:ignore[attr-defined] # noqa
]
)
if constraint.where is not None:
args.append(
"where=%s"
% render._render_potential_expr(
constraint.where, autogen_context
)
)
args.extend(["%s=%r" % (k, v) for k, v in opts])
return "%(prefix)screate_exclude_constraint(%(args)s)" % {
"prefix": render._alembic_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
}
else:
args = [
"(%s, %r)"
% (_render_potential_column(sqltext, autogen_context), opstring)
for sqltext, name, opstring in constraint._render_exprs
]
if constraint.where is not None:
args.append(
"where=%s"
% render._render_potential_expr(
constraint.where, autogen_context
)
)
args.extend(["%s=%r" % (k, v) for k, v in opts])
return "%(prefix)sExcludeConstraint(%(args)s)" % {
"prefix": _postgresql_autogenerate_prefix(autogen_context),
"args": ", ".join(args),
}
def _render_potential_column(
value: Union[ColumnClause, Column, TextClause],
autogen_context: AutogenContext,
) -> str:
if isinstance(value, ColumnClause):
if value.is_literal:
# like literal_column("int8range(from, to)") in ExcludeConstraint
template = "%(prefix)sliteral_column(%(name)r)"
else:
template = "%(prefix)scolumn(%(name)r)"
return template % {
"prefix": render._sqlalchemy_autogenerate_prefix(autogen_context),
"name": value.name,
}
else:
return render._render_potential_expr(
value, autogen_context, wrap_in_text=isinstance(value, TextClause)
)

225
libs/alembic/ddl/sqlite.py Normal file
View File

@ -0,0 +1,225 @@
from __future__ import annotations
import re
from typing import Any
from typing import Dict
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import cast
from sqlalchemy import JSON
from sqlalchemy import schema
from sqlalchemy import sql
from sqlalchemy.ext.compiler import compiles
from .base import alter_table
from .base import format_table_name
from .base import RenameTable
from .impl import DefaultImpl
from .. import util
if TYPE_CHECKING:
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.compiler import DDLCompiler
from sqlalchemy.sql.elements import Cast
from sqlalchemy.sql.elements import ClauseElement
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.type_api import TypeEngine
from ..operations.batch import BatchOperationsImpl
class SQLiteImpl(DefaultImpl):
__dialect__ = "sqlite"
transactional_ddl = False
"""SQLite supports transactional DDL, but pysqlite does not:
see: http://bugs.python.org/issue10740
"""
def requires_recreate_in_batch(
self, batch_op: BatchOperationsImpl
) -> bool:
"""Return True if the given :class:`.BatchOperationsImpl`
would need the table to be recreated and copied in order to
proceed.
Normally, only returns True on SQLite when operations other
than add_column are present.
"""
for op in batch_op.batch:
if op[0] == "add_column":
col = op[1][1]
if isinstance(
col.server_default, schema.DefaultClause
) and isinstance(col.server_default.arg, sql.ClauseElement):
return True
elif (
isinstance(col.server_default, util.sqla_compat.Computed)
and col.server_default.persisted
):
return True
elif op[0] not in ("create_index", "drop_index"):
return True
else:
return False
def add_constraint(self, const: Constraint):
# attempt to distinguish between an
# auto-gen constraint and an explicit one
if const._create_rule is None: # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
elif const._create_rule(self): # type:ignore[attr-defined]
util.warn(
"Skipping unsupported ALTER for "
"creation of implicit constraint. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
def drop_constraint(self, const: Constraint):
if const._create_rule is None: # type:ignore[attr-defined]
raise NotImplementedError(
"No support for ALTER of constraints in SQLite dialect. "
"Please refer to the batch mode feature which allows for "
"SQLite migrations using a copy-and-move strategy."
)
def compare_server_default(
self,
inspector_column: Column,
metadata_column: Column,
rendered_metadata_default: Optional[str],
rendered_inspector_default: Optional[str],
) -> bool:
if rendered_metadata_default is not None:
rendered_metadata_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_metadata_default
)
rendered_metadata_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_metadata_default
)
if rendered_inspector_default is not None:
rendered_inspector_default = re.sub(
r"^\((.+)\)$", r"\1", rendered_inspector_default
)
rendered_inspector_default = re.sub(
r"^\"?'(.+)'\"?$", r"\1", rendered_inspector_default
)
return rendered_inspector_default != rendered_metadata_default
def _guess_if_default_is_unparenthesized_sql_expr(
self, expr: Optional[str]
) -> bool:
"""Determine if a server default is a SQL expression or a constant.
There are too many assertions that expect server defaults to round-trip
identically without parenthesis added so we will add parens only in
very specific cases.
"""
if not expr:
return False
elif re.match(r"^[0-9\.]$", expr):
return False
elif re.match(r"^'.+'$", expr):
return False
elif re.match(r"^\(.+\)$", expr):
return False
else:
return True
def autogen_column_reflect(
self,
inspector: Inspector,
table: Table,
column_info: Dict[str, Any],
) -> None:
# SQLite expression defaults require parenthesis when sent
# as DDL
if self._guess_if_default_is_unparenthesized_sql_expr(
column_info.get("default", None)
):
column_info["default"] = "(%s)" % (column_info["default"],)
def render_ddl_sql_expr(
self, expr: ClauseElement, is_server_default: bool = False, **kw
) -> str:
# SQLite expression defaults require parenthesis when sent
# as DDL
str_expr = super().render_ddl_sql_expr(
expr, is_server_default=is_server_default, **kw
)
if (
is_server_default
and self._guess_if_default_is_unparenthesized_sql_expr(str_expr)
):
str_expr = "(%s)" % (str_expr,)
return str_expr
def cast_for_batch_migrate(
self,
existing: Column,
existing_transfer: Dict[str, Union[TypeEngine, Cast]],
new_type: TypeEngine,
) -> None:
if (
existing.type._type_affinity # type:ignore[attr-defined]
is not new_type._type_affinity # type:ignore[attr-defined]
and not isinstance(new_type, JSON)
):
existing_transfer["expr"] = cast(
existing_transfer["expr"], new_type
)
def correct_for_autogen_constraints(
self,
conn_unique_constraints,
conn_indexes,
metadata_unique_constraints,
metadata_indexes,
):
self._skip_functional_indexes(metadata_indexes, conn_indexes)
@compiles(RenameTable, "sqlite")
def visit_rename_table(
element: RenameTable, compiler: DDLCompiler, **kw
) -> str:
return "%s RENAME TO %s" % (
alter_table(compiler, element.table_name, element.schema),
format_table_name(compiler, element.new_table_name, None),
)
# @compiles(AddColumn, 'sqlite')
# def visit_add_column(element, compiler, **kw):
# return "%s %s" % (
# alter_table(compiler, element.table_name, element.schema),
# add_column(compiler, element.column, **kw)
# )
# def add_column(compiler, column, **kw):
# text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw)
# need to modify SQLAlchemy so that the CHECK associated with a Boolean
# or Enum gets placed as part of the column constraints, not the Table
# see ticket 98
# for const in column.constraints:
# text += compiler.process(AddConstraint(const))
# return text

View File

@ -0,0 +1 @@
from .runtime.environment import * # noqa

View File

@ -0,0 +1 @@
from .runtime.migration import * # noqa

5
libs/alembic/op.py Normal file
View File

@ -0,0 +1,5 @@
from .operations.base import Operations
# create proxy functions for
# each method on the Operations class.
Operations.create_module_class_proxy(globals(), locals())

1184
libs/alembic/op.pyi Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,7 @@
from . import toimpl
from .base import BatchOperations
from .base import Operations
from .ops import MigrateOperation
__all__ = ["Operations", "BatchOperations", "MigrateOperation"]

View File

@ -0,0 +1,523 @@
from __future__ import annotations
from contextlib import contextmanager
import re
import textwrap
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterator
from typing import List # noqa
from typing import Mapping
from typing import Optional
from typing import Sequence # noqa
from typing import Tuple
from typing import Type # noqa
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy.sql.elements import conv
from . import batch
from . import schemaobj
from .. import util
from ..util import sqla_compat
from ..util.compat import formatannotation_fwdref
from ..util.compat import inspect_formatargspec
from ..util.compat import inspect_getfullargspec
from ..util.sqla_compat import _literal_bindparam
NoneType = type(None)
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy import Table # noqa
from sqlalchemy.engine import Connection
from .batch import BatchOperationsImpl
from .ops import MigrateOperation
from ..ddl import DefaultImpl
from ..runtime.migration import MigrationContext
__all__ = ("Operations", "BatchOperations")
class Operations(util.ModuleClsProxy):
"""Define high level migration operations.
Each operation corresponds to some schema migration operation,
executed against a particular :class:`.MigrationContext`
which in turn represents connectivity to a database,
or a file output stream.
While :class:`.Operations` is normally configured as
part of the :meth:`.EnvironmentContext.run_migrations`
method called from an ``env.py`` script, a standalone
:class:`.Operations` instance can be
made for use cases external to regular Alembic
migrations by passing in a :class:`.MigrationContext`::
from alembic.migration import MigrationContext
from alembic.operations import Operations
conn = myengine.connect()
ctx = MigrationContext.configure(conn)
op = Operations(ctx)
op.alter_column("t", "c", nullable=True)
Note that as of 0.8, most of the methods on this class are produced
dynamically using the :meth:`.Operations.register_operation`
method.
"""
impl: Union[DefaultImpl, BatchOperationsImpl]
_to_impl = util.Dispatcher()
def __init__(
self,
migration_context: MigrationContext,
impl: Optional[BatchOperationsImpl] = None,
) -> None:
"""Construct a new :class:`.Operations`
:param migration_context: a :class:`.MigrationContext`
instance.
"""
self.migration_context = migration_context
if impl is None:
self.impl = migration_context.impl
else:
self.impl = impl
self.schema_obj = schemaobj.SchemaObjects(migration_context)
@classmethod
def register_operation(
cls, name: str, sourcename: Optional[str] = None
) -> Callable[..., Any]:
"""Register a new operation for this class.
This method is normally used to add new operations
to the :class:`.Operations` class, and possibly the
:class:`.BatchOperations` class as well. All Alembic migration
operations are implemented via this system, however the system
is also available as a public API to facilitate adding custom
operations.
.. seealso::
:ref:`operation_plugins`
"""
def register(op_cls):
if sourcename is None:
fn = getattr(op_cls, name)
source_name = fn.__name__
else:
fn = getattr(op_cls, sourcename)
source_name = fn.__name__
spec = inspect_getfullargspec(fn)
name_args = spec[0]
assert name_args[0:2] == ["cls", "operations"]
name_args[0:2] = ["self"]
args = inspect_formatargspec(
*spec, formatannotation=formatannotation_fwdref
)
num_defaults = len(spec[3]) if spec[3] else 0
if num_defaults:
defaulted_vals = name_args[0 - num_defaults :]
else:
defaulted_vals = ()
apply_kw = inspect_formatargspec(
name_args,
spec[1],
spec[2],
defaulted_vals,
formatvalue=lambda x: "=" + x,
formatannotation=formatannotation_fwdref,
)
args = re.sub(
r'[_]?ForwardRef\(([\'"].+?[\'"])\)',
lambda m: m.group(1),
args,
)
func_text = textwrap.dedent(
"""\
def %(name)s%(args)s:
%(doc)r
return op_cls.%(source_name)s%(apply_kw)s
"""
% {
"name": name,
"source_name": source_name,
"args": args,
"apply_kw": apply_kw,
"doc": fn.__doc__,
}
)
globals_ = dict(globals())
globals_.update({"op_cls": op_cls})
lcl = {}
exec(func_text, globals_, lcl)
setattr(cls, name, lcl[name])
fn.__func__.__doc__ = (
"This method is proxied on "
"the :class:`.%s` class, via the :meth:`.%s.%s` method."
% (cls.__name__, cls.__name__, name)
)
if hasattr(fn, "_legacy_translations"):
lcl[name]._legacy_translations = fn._legacy_translations
return op_cls
return register
@classmethod
def implementation_for(cls, op_cls: Any) -> Callable[..., Any]:
"""Register an implementation for a given :class:`.MigrateOperation`.
This is part of the operation extensibility API.
.. seealso::
:ref:`operation_plugins` - example of use
"""
def decorate(fn):
cls._to_impl.dispatch_for(op_cls)(fn)
return fn
return decorate
@classmethod
@contextmanager
def context(
cls, migration_context: MigrationContext
) -> Iterator[Operations]:
op = Operations(migration_context)
op._install_proxy()
yield op
op._remove_proxy()
@contextmanager
def batch_alter_table(
self,
table_name: str,
schema: Optional[str] = None,
recreate: Literal["auto", "always", "never"] = "auto",
partial_reordering: Optional[tuple] = None,
copy_from: Optional[Table] = None,
table_args: Tuple[Any, ...] = (),
table_kwargs: Mapping[str, Any] = util.immutabledict(),
reflect_args: Tuple[Any, ...] = (),
reflect_kwargs: Mapping[str, Any] = util.immutabledict(),
naming_convention: Optional[Dict[str, str]] = None,
) -> Iterator[BatchOperations]:
"""Invoke a series of per-table migrations in batch.
Batch mode allows a series of operations specific to a table
to be syntactically grouped together, and allows for alternate
modes of table migration, in particular the "recreate" style of
migration required by SQLite.
"recreate" style is as follows:
1. A new table is created with the new specification, based on the
migration directives within the batch, using a temporary name.
2. the data copied from the existing table to the new table.
3. the existing table is dropped.
4. the new table is renamed to the existing table name.
The directive by default will only use "recreate" style on the
SQLite backend, and only if directives are present which require
this form, e.g. anything other than ``add_column()``. The batch
operation on other backends will proceed using standard ALTER TABLE
operations.
The method is used as a context manager, which returns an instance
of :class:`.BatchOperations`; this object is the same as
:class:`.Operations` except that table names and schema names
are omitted. E.g.::
with op.batch_alter_table("some_table") as batch_op:
batch_op.add_column(Column('foo', Integer))
batch_op.drop_column('bar')
The operations within the context manager are invoked at once
when the context is ended. When run against SQLite, if the
migrations include operations not supported by SQLite's ALTER TABLE,
the entire table will be copied to a new one with the new
specification, moving all data across as well.
The copy operation by default uses reflection to retrieve the current
structure of the table, and therefore :meth:`.batch_alter_table`
in this mode requires that the migration is run in "online" mode.
The ``copy_from`` parameter may be passed which refers to an existing
:class:`.Table` object, which will bypass this reflection step.
.. note:: The table copy operation will currently not copy
CHECK constraints, and may not copy UNIQUE constraints that are
unnamed, as is possible on SQLite. See the section
:ref:`sqlite_batch_constraints` for workarounds.
:param table_name: name of table
:param schema: optional schema name.
:param recreate: under what circumstances the table should be
recreated. At its default of ``"auto"``, the SQLite dialect will
recreate the table if any operations other than ``add_column()``,
``create_index()``, or ``drop_index()`` are
present. Other options include ``"always"`` and ``"never"``.
:param copy_from: optional :class:`~sqlalchemy.schema.Table` object
that will act as the structure of the table being copied. If omitted,
table reflection is used to retrieve the structure of the table.
.. seealso::
:ref:`batch_offline_mode`
:paramref:`~.Operations.batch_alter_table.reflect_args`
:paramref:`~.Operations.batch_alter_table.reflect_kwargs`
:param reflect_args: a sequence of additional positional arguments that
will be applied to the table structure being reflected / copied;
this may be used to pass column and constraint overrides to the
table that will be reflected, in lieu of passing the whole
:class:`~sqlalchemy.schema.Table` using
:paramref:`~.Operations.batch_alter_table.copy_from`.
:param reflect_kwargs: a dictionary of additional keyword arguments
that will be applied to the table structure being copied; this may be
used to pass additional table and reflection options to the table that
will be reflected, in lieu of passing the whole
:class:`~sqlalchemy.schema.Table` using
:paramref:`~.Operations.batch_alter_table.copy_from`.
:param table_args: a sequence of additional positional arguments that
will be applied to the new :class:`~sqlalchemy.schema.Table` when
created, in addition to those copied from the source table.
This may be used to provide additional constraints such as CHECK
constraints that may not be reflected.
:param table_kwargs: a dictionary of additional keyword arguments
that will be applied to the new :class:`~sqlalchemy.schema.Table`
when created, in addition to those copied from the source table.
This may be used to provide for additional table options that may
not be reflected.
:param naming_convention: a naming convention dictionary of the form
described at :ref:`autogen_naming_conventions` which will be applied
to the :class:`~sqlalchemy.schema.MetaData` during the reflection
process. This is typically required if one wants to drop SQLite
constraints, as these constraints will not have names when
reflected on this backend. Requires SQLAlchemy **0.9.4** or greater.
.. seealso::
:ref:`dropping_sqlite_foreign_keys`
:param partial_reordering: a list of tuples, each suggesting a desired
ordering of two or more columns in the newly created table. Requires
that :paramref:`.batch_alter_table.recreate` is set to ``"always"``.
Examples, given a table with columns "a", "b", "c", and "d":
Specify the order of all columns::
with op.batch_alter_table(
"some_table", recreate="always",
partial_reordering=[("c", "d", "a", "b")]
) as batch_op:
pass
Ensure "d" appears before "c", and "b", appears before "a"::
with op.batch_alter_table(
"some_table", recreate="always",
partial_reordering=[("d", "c"), ("b", "a")]
) as batch_op:
pass
The ordering of columns not included in the partial_reordering
set is undefined. Therefore it is best to specify the complete
ordering of all columns for best results.
.. versionadded:: 1.4.0
.. note:: batch mode requires SQLAlchemy 0.8 or above.
.. seealso::
:ref:`batch_migrations`
"""
impl = batch.BatchOperationsImpl(
self,
table_name,
schema,
recreate,
copy_from,
table_args,
table_kwargs,
reflect_args,
reflect_kwargs,
naming_convention,
partial_reordering,
)
batch_op = BatchOperations(self.migration_context, impl=impl)
yield batch_op
impl.flush()
def get_context(self) -> MigrationContext:
"""Return the :class:`.MigrationContext` object that's
currently in use.
"""
return self.migration_context
def invoke(self, operation: MigrateOperation) -> Any:
"""Given a :class:`.MigrateOperation`, invoke it in terms of
this :class:`.Operations` instance.
"""
fn = self._to_impl.dispatch(
operation, self.migration_context.impl.__dialect__
)
return fn(self, operation)
def f(self, name: str) -> conv:
"""Indicate a string name that has already had a naming convention
applied to it.
This feature combines with the SQLAlchemy ``naming_convention`` feature
to disambiguate constraint names that have already had naming
conventions applied to them, versus those that have not. This is
necessary in the case that the ``"%(constraint_name)s"`` token
is used within a naming convention, so that it can be identified
that this particular name should remain fixed.
If the :meth:`.Operations.f` is used on a constraint, the naming
convention will not take effect::
op.add_column('t', 'x', Boolean(name=op.f('ck_bool_t_x')))
Above, the CHECK constraint generated will have the name
``ck_bool_t_x`` regardless of whether or not a naming convention is
in use.
Alternatively, if a naming convention is in use, and 'f' is not used,
names will be converted along conventions. If the ``target_metadata``
contains the naming convention
``{"ck": "ck_bool_%(table_name)s_%(constraint_name)s"}``, then the
output of the following:
op.add_column('t', 'x', Boolean(name='x'))
will be::
CONSTRAINT ck_bool_t_x CHECK (x in (1, 0)))
The function is rendered in the output of autogenerate when
a particular constraint name is already converted.
"""
return conv(name)
def inline_literal(
self, value: Union[str, int], type_: None = None
) -> _literal_bindparam:
r"""Produce an 'inline literal' expression, suitable for
using in an INSERT, UPDATE, or DELETE statement.
When using Alembic in "offline" mode, CRUD operations
aren't compatible with SQLAlchemy's default behavior surrounding
literal values,
which is that they are converted into bound values and passed
separately into the ``execute()`` method of the DBAPI cursor.
An offline SQL
script needs to have these rendered inline. While it should
always be noted that inline literal values are an **enormous**
security hole in an application that handles untrusted input,
a schema migration is not run in this context, so
literals are safe to render inline, with the caveat that
advanced types like dates may not be supported directly
by SQLAlchemy.
See :meth:`.execute` for an example usage of
:meth:`.inline_literal`.
The environment can also be configured to attempt to render
"literal" values inline automatically, for those simple types
that are supported by the dialect; see
:paramref:`.EnvironmentContext.configure.literal_binds` for this
more recently added feature.
:param value: The value to render. Strings, integers, and simple
numerics should be supported. Other types like boolean,
dates, etc. may or may not be supported yet by various
backends.
:param type\_: optional - a :class:`sqlalchemy.types.TypeEngine`
subclass stating the type of this value. In SQLAlchemy
expressions, this is usually derived automatically
from the Python type of the value itself, as well as
based on the context in which the value is used.
.. seealso::
:paramref:`.EnvironmentContext.configure.literal_binds`
"""
return sqla_compat._literal_bindparam(None, value, type_=type_)
def get_bind(self) -> Connection:
"""Return the current 'bind'.
Under normal circumstances, this is the
:class:`~sqlalchemy.engine.Connection` currently being used
to emit SQL to the database.
In a SQL script context, this value is ``None``. [TODO: verify this]
"""
return self.migration_context.impl.bind # type: ignore[return-value]
class BatchOperations(Operations):
"""Modifies the interface :class:`.Operations` for batch mode.
This basically omits the ``table_name`` and ``schema`` parameters
from associated methods, as these are a given when running under batch
mode.
.. seealso::
:meth:`.Operations.batch_alter_table`
Note that as of 0.8, most of the methods on this class are produced
dynamically using the :meth:`.Operations.register_operation`
method.
"""
impl: BatchOperationsImpl
def _noop(self, operation):
raise NotImplementedError(
"The %s method does not apply to a batch table alter operation."
% operation
)

View File

@ -0,0 +1,716 @@
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import CheckConstraint
from sqlalchemy import Column
from sqlalchemy import ForeignKeyConstraint
from sqlalchemy import Index
from sqlalchemy import MetaData
from sqlalchemy import PrimaryKeyConstraint
from sqlalchemy import schema as sql_schema
from sqlalchemy import Table
from sqlalchemy import types as sqltypes
from sqlalchemy.events import SchemaEventTarget
from sqlalchemy.util import OrderedDict
from sqlalchemy.util import topological
from ..util import exc
from ..util.sqla_compat import _columns_for_constraint
from ..util.sqla_compat import _copy
from ..util.sqla_compat import _copy_expression
from ..util.sqla_compat import _ensure_scope_for_ddl
from ..util.sqla_compat import _fk_is_self_referential
from ..util.sqla_compat import _idx_table_bound_expressions
from ..util.sqla_compat import _insert_inline
from ..util.sqla_compat import _is_type_bound
from ..util.sqla_compat import _remove_column_from_collection
from ..util.sqla_compat import _resolve_for_variant
from ..util.sqla_compat import _select
from ..util.sqla_compat import constraint_name_defined
from ..util.sqla_compat import constraint_name_string
if TYPE_CHECKING:
from typing import Literal
from sqlalchemy.engine import Dialect
from sqlalchemy.sql.elements import ColumnClause
from sqlalchemy.sql.elements import quoted_name
from sqlalchemy.sql.functions import Function
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.type_api import TypeEngine
from ..ddl.impl import DefaultImpl
class BatchOperationsImpl:
def __init__(
self,
operations,
table_name,
schema,
recreate,
copy_from,
table_args,
table_kwargs,
reflect_args,
reflect_kwargs,
naming_convention,
partial_reordering,
):
self.operations = operations
self.table_name = table_name
self.schema = schema
if recreate not in ("auto", "always", "never"):
raise ValueError(
"recreate may be one of 'auto', 'always', or 'never'."
)
self.recreate = recreate
self.copy_from = copy_from
self.table_args = table_args
self.table_kwargs = dict(table_kwargs)
self.reflect_args = reflect_args
self.reflect_kwargs = dict(reflect_kwargs)
self.reflect_kwargs.setdefault(
"listeners", list(self.reflect_kwargs.get("listeners", ()))
)
self.reflect_kwargs["listeners"].append(
("column_reflect", operations.impl.autogen_column_reflect)
)
self.naming_convention = naming_convention
self.partial_reordering = partial_reordering
self.batch = []
@property
def dialect(self) -> Dialect:
return self.operations.impl.dialect
@property
def impl(self) -> DefaultImpl:
return self.operations.impl
def _should_recreate(self) -> bool:
if self.recreate == "auto":
return self.operations.impl.requires_recreate_in_batch(self)
elif self.recreate == "always":
return True
else:
return False
def flush(self) -> None:
should_recreate = self._should_recreate()
with _ensure_scope_for_ddl(self.impl.connection):
if not should_recreate:
for opname, arg, kw in self.batch:
fn = getattr(self.operations.impl, opname)
fn(*arg, **kw)
else:
if self.naming_convention:
m1 = MetaData(naming_convention=self.naming_convention)
else:
m1 = MetaData()
if self.copy_from is not None:
existing_table = self.copy_from
reflected = False
else:
if self.operations.migration_context.as_sql:
raise exc.CommandError(
f"This operation cannot proceed in --sql mode; "
f"batch mode with dialect "
f"{self.operations.migration_context.dialect.name} " # noqa: E501
f"requires a live database connection with which "
f'to reflect the table "{self.table_name}". '
f"To generate a batch SQL migration script using "
"table "
'"move and copy", a complete Table object '
f'should be passed to the "copy_from" argument '
"of the batch_alter_table() method so that table "
"reflection can be skipped."
)
existing_table = Table(
self.table_name,
m1,
schema=self.schema,
autoload_with=self.operations.get_bind(),
*self.reflect_args,
**self.reflect_kwargs,
)
reflected = True
batch_impl = ApplyBatchImpl(
self.impl,
existing_table,
self.table_args,
self.table_kwargs,
reflected,
partial_reordering=self.partial_reordering,
)
for opname, arg, kw in self.batch:
fn = getattr(batch_impl, opname)
fn(*arg, **kw)
batch_impl._create(self.impl)
def alter_column(self, *arg, **kw) -> None:
self.batch.append(("alter_column", arg, kw))
def add_column(self, *arg, **kw) -> None:
if (
"insert_before" in kw or "insert_after" in kw
) and not self._should_recreate():
raise exc.CommandError(
"Can't specify insert_before or insert_after when using "
"ALTER; please specify recreate='always'"
)
self.batch.append(("add_column", arg, kw))
def drop_column(self, *arg, **kw) -> None:
self.batch.append(("drop_column", arg, kw))
def add_constraint(self, const: Constraint) -> None:
self.batch.append(("add_constraint", (const,), {}))
def drop_constraint(self, const: Constraint) -> None:
self.batch.append(("drop_constraint", (const,), {}))
def rename_table(self, *arg, **kw):
self.batch.append(("rename_table", arg, kw))
def create_index(self, idx: Index) -> None:
self.batch.append(("create_index", (idx,), {}))
def drop_index(self, idx: Index) -> None:
self.batch.append(("drop_index", (idx,), {}))
def create_table_comment(self, table):
self.batch.append(("create_table_comment", (table,), {}))
def drop_table_comment(self, table):
self.batch.append(("drop_table_comment", (table,), {}))
def create_table(self, table):
raise NotImplementedError("Can't create table in batch mode")
def drop_table(self, table):
raise NotImplementedError("Can't drop table in batch mode")
def create_column_comment(self, column):
self.batch.append(("create_column_comment", (column,), {}))
class ApplyBatchImpl:
def __init__(
self,
impl: DefaultImpl,
table: Table,
table_args: tuple,
table_kwargs: Dict[str, Any],
reflected: bool,
partial_reordering: tuple = (),
) -> None:
self.impl = impl
self.table = table # this is a Table object
self.table_args = table_args
self.table_kwargs = table_kwargs
self.temp_table_name = self._calc_temp_name(table.name)
self.new_table: Optional[Table] = None
self.partial_reordering = partial_reordering # tuple of tuples
self.add_col_ordering: Tuple[
Tuple[str, str], ...
] = () # tuple of tuples
self.column_transfers = OrderedDict(
(c.name, {"expr": c}) for c in self.table.c
)
self.existing_ordering = list(self.column_transfers)
self.reflected = reflected
self._grab_table_elements()
@classmethod
def _calc_temp_name(cls, tablename: Union[quoted_name, str]) -> str:
return ("_alembic_tmp_%s" % tablename)[0:50]
def _grab_table_elements(self) -> None:
schema = self.table.schema
self.columns: Dict[str, Column] = OrderedDict()
for c in self.table.c:
c_copy = _copy(c, schema=schema)
c_copy.unique = c_copy.index = False
# ensure that the type object was copied,
# as we may need to modify it in-place
if isinstance(c.type, SchemaEventTarget):
assert c_copy.type is not c.type
self.columns[c.name] = c_copy
self.named_constraints: Dict[str, Constraint] = {}
self.unnamed_constraints = []
self.col_named_constraints = {}
self.indexes: Dict[str, Index] = {}
self.new_indexes: Dict[str, Index] = {}
for const in self.table.constraints:
if _is_type_bound(const):
continue
elif (
self.reflected
and isinstance(const, CheckConstraint)
and not const.name
):
# TODO: we are skipping unnamed reflected CheckConstraint
# because
# we have no way to determine _is_type_bound() for these.
pass
elif constraint_name_string(const.name):
self.named_constraints[const.name] = const
else:
self.unnamed_constraints.append(const)
if not self.reflected:
for col in self.table.c:
for const in col.constraints:
if const.name:
self.col_named_constraints[const.name] = (col, const)
for idx in self.table.indexes:
self.indexes[idx.name] = idx # type: ignore[index]
for k in self.table.kwargs:
self.table_kwargs.setdefault(k, self.table.kwargs[k])
def _adjust_self_columns_for_partial_reordering(self) -> None:
pairs = set()
col_by_idx = list(self.columns)
if self.partial_reordering:
for tuple_ in self.partial_reordering:
for index, elem in enumerate(tuple_):
if index > 0:
pairs.add((tuple_[index - 1], elem))
else:
for index, elem in enumerate(self.existing_ordering):
if index > 0:
pairs.add((col_by_idx[index - 1], elem))
pairs.update(self.add_col_ordering)
# this can happen if some columns were dropped and not removed
# from existing_ordering. this should be prevented already, but
# conservatively making sure this didn't happen
pairs_list = [p for p in pairs if p[0] != p[1]]
sorted_ = list(
topological.sort(pairs_list, col_by_idx, deterministic_order=True)
)
self.columns = OrderedDict((k, self.columns[k]) for k in sorted_)
self.column_transfers = OrderedDict(
(k, self.column_transfers[k]) for k in sorted_
)
def _transfer_elements_to_new_table(self) -> None:
assert self.new_table is None, "Can only create new table once"
m = MetaData()
schema = self.table.schema
if self.partial_reordering or self.add_col_ordering:
self._adjust_self_columns_for_partial_reordering()
self.new_table = new_table = Table(
self.temp_table_name,
m,
*(list(self.columns.values()) + list(self.table_args)),
schema=schema,
**self.table_kwargs,
)
for const in (
list(self.named_constraints.values()) + self.unnamed_constraints
):
const_columns = {c.key for c in _columns_for_constraint(const)}
if not const_columns.issubset(self.column_transfers):
continue
const_copy: Constraint
if isinstance(const, ForeignKeyConstraint):
if _fk_is_self_referential(const):
# for self-referential constraint, refer to the
# *original* table name, and not _alembic_batch_temp.
# This is consistent with how we're handling
# FK constraints from other tables; we assume SQLite
# no foreign keys just keeps the names unchanged, so
# when we rename back, they match again.
const_copy = _copy(
const, schema=schema, target_table=self.table
)
else:
# "target_table" for ForeignKeyConstraint.copy() is
# only used if the FK is detected as being
# self-referential, which we are handling above.
const_copy = _copy(const, schema=schema)
else:
const_copy = _copy(
const, schema=schema, target_table=new_table
)
if isinstance(const, ForeignKeyConstraint):
self._setup_referent(m, const)
new_table.append_constraint(const_copy)
def _gather_indexes_from_both_tables(self) -> List[Index]:
assert self.new_table is not None
idx: List[Index] = []
for idx_existing in self.indexes.values():
# this is a lift-and-move from Table.to_metadata
if idx_existing._column_flag: # type: ignore
continue
idx_copy = Index(
idx_existing.name,
unique=idx_existing.unique,
*[
_copy_expression(expr, self.new_table)
for expr in _idx_table_bound_expressions(idx_existing)
],
_table=self.new_table,
**idx_existing.kwargs,
)
idx.append(idx_copy)
for index in self.new_indexes.values():
idx.append(
Index(
index.name,
unique=index.unique,
*[self.new_table.c[col] for col in index.columns.keys()],
**index.kwargs,
)
)
return idx
def _setup_referent(
self, metadata: MetaData, constraint: ForeignKeyConstraint
) -> None:
spec = constraint.elements[
0
]._get_colspec() # type:ignore[attr-defined]
parts = spec.split(".")
tname = parts[-2]
if len(parts) == 3:
referent_schema = parts[0]
else:
referent_schema = None
if tname != self.temp_table_name:
key = sql_schema._get_table_key(tname, referent_schema)
def colspec(elem: Any):
return elem._get_colspec()
if key in metadata.tables:
t = metadata.tables[key]
for elem in constraint.elements:
colname = colspec(elem).split(".")[-1]
if colname not in t.c:
t.append_column(Column(colname, sqltypes.NULLTYPE))
else:
Table(
tname,
metadata,
*[
Column(n, sqltypes.NULLTYPE)
for n in [
colspec(elem).split(".")[-1]
for elem in constraint.elements
]
],
schema=referent_schema,
)
def _create(self, op_impl: DefaultImpl) -> None:
self._transfer_elements_to_new_table()
op_impl.prep_table_for_batch(self, self.table)
assert self.new_table is not None
op_impl.create_table(self.new_table)
try:
op_impl._exec(
_insert_inline(self.new_table).from_select(
list(
k
for k, transfer in self.column_transfers.items()
if "expr" in transfer
),
_select(
*[
transfer["expr"]
for transfer in self.column_transfers.values()
if "expr" in transfer
]
),
)
)
op_impl.drop_table(self.table)
except:
op_impl.drop_table(self.new_table)
raise
else:
op_impl.rename_table(
self.temp_table_name, self.table.name, schema=self.table.schema
)
self.new_table.name = self.table.name
try:
for idx in self._gather_indexes_from_both_tables():
op_impl.create_index(idx)
finally:
self.new_table.name = self.temp_table_name
def alter_column(
self,
table_name: str,
column_name: str,
nullable: Optional[bool] = None,
server_default: Optional[Union[Function[Any], str, bool]] = False,
name: Optional[str] = None,
type_: Optional[TypeEngine] = None,
autoincrement: None = None,
comment: Union[str, Literal[False]] = False,
**kw,
) -> None:
existing = self.columns[column_name]
existing_transfer: Dict[str, Any] = self.column_transfers[column_name]
if name is not None and name != column_name:
# note that we don't change '.key' - we keep referring
# to the renamed column by its old key in _create(). neat!
existing.name = name
existing_transfer["name"] = name
existing_type = kw.get("existing_type", None)
if existing_type:
resolved_existing_type = _resolve_for_variant(
kw["existing_type"], self.impl.dialect
)
# pop named constraints for Boolean/Enum for rename
if (
isinstance(resolved_existing_type, SchemaEventTarget)
and resolved_existing_type.name # type:ignore[attr-defined] # noqa E501
):
self.named_constraints.pop(
resolved_existing_type.name, # type:ignore[attr-defined] # noqa E501
None,
)
if type_ is not None:
type_ = sqltypes.to_instance(type_)
# old type is being discarded so turn off eventing
# rules. Alternatively we can
# erase the events set up by this type, but this is simpler.
# we also ignore the drop_constraint that will come here from
# Operations.implementation_for(alter_column)
if isinstance(existing.type, SchemaEventTarget):
existing.type._create_events = ( # type:ignore[attr-defined]
existing.type.create_constraint # type:ignore[attr-defined] # noqa
) = False
self.impl.cast_for_batch_migrate(
existing, existing_transfer, type_
)
existing.type = type_
# we *dont* however set events for the new type, because
# alter_column is invoked from
# Operations.implementation_for(alter_column) which already
# will emit an add_constraint()
if nullable is not None:
existing.nullable = nullable
if server_default is not False:
if server_default is None:
existing.server_default = None
else:
sql_schema.DefaultClause(
server_default # type: ignore[arg-type]
)._set_parent( # type:ignore[attr-defined]
existing
)
if autoincrement is not None:
existing.autoincrement = bool(autoincrement)
if comment is not False:
existing.comment = comment
def _setup_dependencies_for_add_column(
self,
colname: str,
insert_before: Optional[str],
insert_after: Optional[str],
) -> None:
index_cols = self.existing_ordering
col_indexes = {name: i for i, name in enumerate(index_cols)}
if not self.partial_reordering:
if insert_after:
if not insert_before:
if insert_after in col_indexes:
# insert after an existing column
idx = col_indexes[insert_after] + 1
if idx < len(index_cols):
insert_before = index_cols[idx]
else:
# insert after a column that is also new
insert_before = dict(self.add_col_ordering)[
insert_after
]
if insert_before:
if not insert_after:
if insert_before in col_indexes:
# insert before an existing column
idx = col_indexes[insert_before] - 1
if idx >= 0:
insert_after = index_cols[idx]
else:
# insert before a column that is also new
insert_after = {
b: a for a, b in self.add_col_ordering
}[insert_before]
if insert_before:
self.add_col_ordering += ((colname, insert_before),)
if insert_after:
self.add_col_ordering += ((insert_after, colname),)
if (
not self.partial_reordering
and not insert_before
and not insert_after
and col_indexes
):
self.add_col_ordering += ((index_cols[-1], colname),)
def add_column(
self,
table_name: str,
column: Column,
insert_before: Optional[str] = None,
insert_after: Optional[str] = None,
**kw,
) -> None:
self._setup_dependencies_for_add_column(
column.name, insert_before, insert_after
)
# we copy the column because operations.add_column()
# gives us a Column that is part of a Table already.
self.columns[column.name] = _copy(column, schema=self.table.schema)
self.column_transfers[column.name] = {}
def drop_column(
self, table_name: str, column: Union[ColumnClause, Column], **kw
) -> None:
if column.name in self.table.primary_key.columns:
_remove_column_from_collection(
self.table.primary_key.columns, column
)
del self.columns[column.name]
del self.column_transfers[column.name]
self.existing_ordering.remove(column.name)
# pop named constraints for Boolean/Enum for rename
if (
"existing_type" in kw
and isinstance(kw["existing_type"], SchemaEventTarget)
and kw["existing_type"].name # type:ignore[attr-defined]
):
self.named_constraints.pop(
kw["existing_type"].name, None # type:ignore[attr-defined]
)
def create_column_comment(self, column):
"""the batch table creation function will issue create_column_comment
on the real "impl" as part of the create table process.
That is, the Column object will have the comment on it already,
so when it is received by add_column() it will be a normal part of
the CREATE TABLE and doesn't need an extra step here.
"""
def create_table_comment(self, table):
"""the batch table creation function will issue create_table_comment
on the real "impl" as part of the create table process.
"""
def drop_table_comment(self, table):
"""the batch table creation function will issue drop_table_comment
on the real "impl" as part of the create table process.
"""
def add_constraint(self, const: Constraint) -> None:
if not constraint_name_defined(const.name):
raise ValueError("Constraint must have a name")
if isinstance(const, sql_schema.PrimaryKeyConstraint):
if self.table.primary_key in self.unnamed_constraints:
self.unnamed_constraints.remove(self.table.primary_key)
if constraint_name_string(const.name):
self.named_constraints[const.name] = const
else:
self.unnamed_constraints.append(const)
def drop_constraint(self, const: Constraint) -> None:
if not const.name:
raise ValueError("Constraint must have a name")
try:
if const.name in self.col_named_constraints:
col, const = self.col_named_constraints.pop(const.name)
for col_const in list(self.columns[col.name].constraints):
if col_const.name == const.name:
self.columns[col.name].constraints.remove(col_const)
elif constraint_name_string(const.name):
const = self.named_constraints.pop(const.name)
elif const in self.unnamed_constraints:
self.unnamed_constraints.remove(const)
except KeyError:
if _is_type_bound(const):
# type-bound constraints are only included in the new
# table via their type object in any case, so ignore the
# drop_constraint() that comes here via the
# Operations.implementation_for(alter_column)
return
raise ValueError("No such constraint: '%s'" % const.name)
else:
if isinstance(const, PrimaryKeyConstraint):
for col in const.columns:
self.columns[col.name].primary_key = False
def create_index(self, idx: Index) -> None:
self.new_indexes[idx.name] = idx # type: ignore[index]
def drop_index(self, idx: Index) -> None:
try:
del self.indexes[idx.name] # type: ignore[arg-type]
except KeyError:
raise ValueError("No such index: '%s'" % idx.name)
def rename_table(self, *arg, **kw):
raise NotImplementedError("TODO")

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,284 @@
from __future__ import annotations
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Tuple
from typing import TYPE_CHECKING
from typing import Union
from sqlalchemy import schema as sa_schema
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.schema import Constraint
from sqlalchemy.sql.schema import Index
from sqlalchemy.types import Integer
from sqlalchemy.types import NULLTYPE
from .. import util
from ..util import sqla_compat
if TYPE_CHECKING:
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import TextClause
from sqlalchemy.sql.schema import CheckConstraint
from sqlalchemy.sql.schema import ForeignKey
from sqlalchemy.sql.schema import ForeignKeyConstraint
from sqlalchemy.sql.schema import MetaData
from sqlalchemy.sql.schema import PrimaryKeyConstraint
from sqlalchemy.sql.schema import Table
from sqlalchemy.sql.schema import UniqueConstraint
from sqlalchemy.sql.type_api import TypeEngine
from ..runtime.migration import MigrationContext
class SchemaObjects:
def __init__(
self, migration_context: Optional[MigrationContext] = None
) -> None:
self.migration_context = migration_context
def primary_key_constraint(
self,
name: Optional[sqla_compat._ConstraintNameDefined],
table_name: str,
cols: Sequence[str],
schema: Optional[str] = None,
**dialect_kw,
) -> PrimaryKeyConstraint:
m = self.metadata()
columns = [sa_schema.Column(n, NULLTYPE) for n in cols]
t = sa_schema.Table(table_name, m, *columns, schema=schema)
# SQLAlchemy primary key constraint name arg is wrongly typed on
# the SQLAlchemy side through 2.0.5 at least
p = sa_schema.PrimaryKeyConstraint(
*[t.c[n] for n in cols], name=name, **dialect_kw # type: ignore
)
return p
def foreign_key_constraint(
self,
name: Optional[sqla_compat._ConstraintNameDefined],
source: str,
referent: str,
local_cols: List[str],
remote_cols: List[str],
onupdate: Optional[str] = None,
ondelete: Optional[str] = None,
deferrable: Optional[bool] = None,
source_schema: Optional[str] = None,
referent_schema: Optional[str] = None,
initially: Optional[str] = None,
match: Optional[str] = None,
**dialect_kw,
) -> ForeignKeyConstraint:
m = self.metadata()
if source == referent and source_schema == referent_schema:
t1_cols = local_cols + remote_cols
else:
t1_cols = local_cols
sa_schema.Table(
referent,
m,
*[sa_schema.Column(n, NULLTYPE) for n in remote_cols],
schema=referent_schema,
)
t1 = sa_schema.Table(
source,
m,
*[sa_schema.Column(n, NULLTYPE) for n in t1_cols],
schema=source_schema,
)
tname = (
"%s.%s" % (referent_schema, referent)
if referent_schema
else referent
)
dialect_kw["match"] = match
f = sa_schema.ForeignKeyConstraint(
local_cols,
["%s.%s" % (tname, n) for n in remote_cols],
name=name,
onupdate=onupdate,
ondelete=ondelete,
deferrable=deferrable,
initially=initially,
**dialect_kw,
)
t1.append_constraint(f)
return f
def unique_constraint(
self,
name: Optional[sqla_compat._ConstraintNameDefined],
source: str,
local_cols: Sequence[str],
schema: Optional[str] = None,
**kw,
) -> UniqueConstraint:
t = sa_schema.Table(
source,
self.metadata(),
*[sa_schema.Column(n, NULLTYPE) for n in local_cols],
schema=schema,
)
kw["name"] = name
uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw)
# TODO: need event tests to ensure the event
# is fired off here
t.append_constraint(uq)
return uq
def check_constraint(
self,
name: Optional[sqla_compat._ConstraintNameDefined],
source: str,
condition: Union[str, TextClause, ColumnElement[Any]],
schema: Optional[str] = None,
**kw,
) -> Union[CheckConstraint]:
t = sa_schema.Table(
source,
self.metadata(),
sa_schema.Column("x", Integer),
schema=schema,
)
ck = sa_schema.CheckConstraint(condition, name=name, **kw)
t.append_constraint(ck)
return ck
def generic_constraint(
self,
name: Optional[sqla_compat._ConstraintNameDefined],
table_name: str,
type_: Optional[str],
schema: Optional[str] = None,
**kw,
) -> Any:
t = self.table(table_name, schema=schema)
types: Dict[Optional[str], Any] = {
"foreignkey": lambda name: sa_schema.ForeignKeyConstraint(
[], [], name=name
),
"primary": sa_schema.PrimaryKeyConstraint,
"unique": sa_schema.UniqueConstraint,
"check": lambda name: sa_schema.CheckConstraint("", name=name),
None: sa_schema.Constraint,
}
try:
const = types[type_]
except KeyError as ke:
raise TypeError(
"'type' can be one of %s"
% ", ".join(sorted(repr(x) for x in types))
) from ke
else:
const = const(name=name)
t.append_constraint(const)
return const
def metadata(self) -> MetaData:
kw = {}
if (
self.migration_context is not None
and "target_metadata" in self.migration_context.opts
):
mt = self.migration_context.opts["target_metadata"]
if hasattr(mt, "naming_convention"):
kw["naming_convention"] = mt.naming_convention
return sa_schema.MetaData(**kw)
def table(self, name: str, *columns, **kw) -> Table:
m = self.metadata()
cols = [
sqla_compat._copy(c) if c.table is not None else c
for c in columns
if isinstance(c, Column)
]
# these flags have already added their UniqueConstraint /
# Index objects to the table, so flip them off here.
# SQLAlchemy tometadata() avoids this instead by preserving the
# flags and skipping the constraints that have _type_bound on them,
# but for a migration we'd rather list out the constraints
# explicitly.
_constraints_included = kw.pop("_constraints_included", False)
if _constraints_included:
for c in cols:
c.unique = c.index = False
t = sa_schema.Table(name, m, *cols, **kw)
constraints = [
sqla_compat._copy(elem, target_table=t)
if getattr(elem, "parent", None) is not t
and getattr(elem, "parent", None) is not None
else elem
for elem in columns
if isinstance(elem, (Constraint, Index))
]
for const in constraints:
t.append_constraint(const)
for f in t.foreign_keys:
self._ensure_table_for_fk(m, f)
return t
def column(self, name: str, type_: TypeEngine, **kw) -> Column:
return sa_schema.Column(name, type_, **kw)
def index(
self,
name: Optional[str],
tablename: Optional[str],
columns: Sequence[Union[str, TextClause, ColumnElement[Any]]],
schema: Optional[str] = None,
**kw,
) -> Index:
t = sa_schema.Table(
tablename or "no_table",
self.metadata(),
schema=schema,
)
kw["_table"] = t
idx = sa_schema.Index(
name,
*[util.sqla_compat._textual_index_column(t, n) for n in columns],
**kw,
)
return idx
def _parse_table_key(self, table_key: str) -> Tuple[Optional[str], str]:
if "." in table_key:
tokens = table_key.split(".")
sname: Optional[str] = ".".join(tokens[0:-1])
tname = tokens[-1]
else:
tname = table_key
sname = None
return (sname, tname)
def _ensure_table_for_fk(self, metadata: MetaData, fk: ForeignKey) -> None:
"""create a placeholder Table object for the referent of a
ForeignKey.
"""
if isinstance(fk._colspec, str): # type:ignore[attr-defined]
table_key, cname = fk._colspec.rsplit( # type:ignore[attr-defined]
".", 1
)
sname, tname = self._parse_table_key(table_key)
if table_key not in metadata.tables:
rel_t = sa_schema.Table(tname, metadata, schema=sname)
else:
rel_t = metadata.tables[table_key]
if cname not in rel_t.c:
rel_t.append_column(sa_schema.Column(cname, NULLTYPE))

View File

@ -0,0 +1,209 @@
from typing import TYPE_CHECKING
from sqlalchemy import schema as sa_schema
from . import ops
from .base import Operations
from ..util.sqla_compat import _copy
if TYPE_CHECKING:
from sqlalchemy.sql.schema import Table
@Operations.implementation_for(ops.AlterColumnOp)
def alter_column(
operations: "Operations", operation: "ops.AlterColumnOp"
) -> None:
compiler = operations.impl.dialect.statement_compiler(
operations.impl.dialect, None
)
existing_type = operation.existing_type
existing_nullable = operation.existing_nullable
existing_server_default = operation.existing_server_default
type_ = operation.modify_type
column_name = operation.column_name
table_name = operation.table_name
schema = operation.schema
server_default = operation.modify_server_default
new_column_name = operation.modify_name
nullable = operation.modify_nullable
comment = operation.modify_comment
existing_comment = operation.existing_comment
def _count_constraint(constraint):
return not isinstance(constraint, sa_schema.PrimaryKeyConstraint) and (
not constraint._create_rule or constraint._create_rule(compiler)
)
if existing_type and type_:
t = operations.schema_obj.table(
table_name,
sa_schema.Column(column_name, existing_type),
schema=schema,
)
for constraint in t.constraints:
if _count_constraint(constraint):
operations.impl.drop_constraint(constraint)
operations.impl.alter_column(
table_name,
column_name,
nullable=nullable,
server_default=server_default,
name=new_column_name,
type_=type_,
schema=schema,
existing_type=existing_type,
existing_server_default=existing_server_default,
existing_nullable=existing_nullable,
comment=comment,
existing_comment=existing_comment,
**operation.kw
)
if type_:
t = operations.schema_obj.table(
table_name,
operations.schema_obj.column(column_name, type_),
schema=schema,
)
for constraint in t.constraints:
if _count_constraint(constraint):
operations.impl.add_constraint(constraint)
@Operations.implementation_for(ops.DropTableOp)
def drop_table(operations: "Operations", operation: "ops.DropTableOp") -> None:
operations.impl.drop_table(
operation.to_table(operations.migration_context)
)
@Operations.implementation_for(ops.DropColumnOp)
def drop_column(
operations: "Operations", operation: "ops.DropColumnOp"
) -> None:
column = operation.to_column(operations.migration_context)
operations.impl.drop_column(
operation.table_name, column, schema=operation.schema, **operation.kw
)
@Operations.implementation_for(ops.CreateIndexOp)
def create_index(
operations: "Operations", operation: "ops.CreateIndexOp"
) -> None:
idx = operation.to_index(operations.migration_context)
operations.impl.create_index(idx)
@Operations.implementation_for(ops.DropIndexOp)
def drop_index(operations: "Operations", operation: "ops.DropIndexOp") -> None:
operations.impl.drop_index(
operation.to_index(operations.migration_context)
)
@Operations.implementation_for(ops.CreateTableOp)
def create_table(
operations: "Operations", operation: "ops.CreateTableOp"
) -> "Table":
table = operation.to_table(operations.migration_context)
operations.impl.create_table(table)
return table
@Operations.implementation_for(ops.RenameTableOp)
def rename_table(
operations: "Operations", operation: "ops.RenameTableOp"
) -> None:
operations.impl.rename_table(
operation.table_name, operation.new_table_name, schema=operation.schema
)
@Operations.implementation_for(ops.CreateTableCommentOp)
def create_table_comment(
operations: "Operations", operation: "ops.CreateTableCommentOp"
) -> None:
table = operation.to_table(operations.migration_context)
operations.impl.create_table_comment(table)
@Operations.implementation_for(ops.DropTableCommentOp)
def drop_table_comment(
operations: "Operations", operation: "ops.DropTableCommentOp"
) -> None:
table = operation.to_table(operations.migration_context)
operations.impl.drop_table_comment(table)
@Operations.implementation_for(ops.AddColumnOp)
def add_column(operations: "Operations", operation: "ops.AddColumnOp") -> None:
table_name = operation.table_name
column = operation.column
schema = operation.schema
kw = operation.kw
if column.table is not None:
column = _copy(column)
t = operations.schema_obj.table(table_name, column, schema=schema)
operations.impl.add_column(table_name, column, schema=schema, **kw)
for constraint in t.constraints:
if not isinstance(constraint, sa_schema.PrimaryKeyConstraint):
operations.impl.add_constraint(constraint)
for index in t.indexes:
operations.impl.create_index(index)
with_comment = (
operations.impl.dialect.supports_comments
and not operations.impl.dialect.inline_comments
)
comment = column.comment
if comment and with_comment:
operations.impl.create_column_comment(column)
@Operations.implementation_for(ops.AddConstraintOp)
def create_constraint(
operations: "Operations", operation: "ops.AddConstraintOp"
) -> None:
operations.impl.add_constraint(
operation.to_constraint(operations.migration_context)
)
@Operations.implementation_for(ops.DropConstraintOp)
def drop_constraint(
operations: "Operations", operation: "ops.DropConstraintOp"
) -> None:
operations.impl.drop_constraint(
operations.schema_obj.generic_constraint(
operation.constraint_name,
operation.table_name,
operation.constraint_type,
schema=operation.schema,
)
)
@Operations.implementation_for(ops.BulkInsertOp)
def bulk_insert(
operations: "Operations", operation: "ops.BulkInsertOp"
) -> None:
operations.impl.bulk_insert( # type: ignore[union-attr]
operation.table, operation.rows, multiinsert=operation.multiinsert
)
@Operations.implementation_for(ops.ExecuteSQLOp)
def execute_sql(
operations: "Operations", operation: "ops.ExecuteSQLOp"
) -> None:
operations.migration_context.impl.execute(
operation.sqltext, execution_options=operation.execution_options
)

View File

Some files were not shown because too many files have changed in this diff Show More