import math
import sys

from flask import abort
from flask import render_template
from flask import request
from peewee import Database
from peewee import DoesNotExist
from peewee import Model
from peewee import Proxy
from peewee import SelectQuery
from playhouse.db_url import connect as db_url_connect


class PaginatedQuery(object):
    """Paginate a peewee query using a page number taken from the request.

    :param query_or_model: a ``SelectQuery``, or a model class whose
        ``select()`` query is paginated.
    :param paginate_by: number of rows per page.
    :param page_var: name of the query-string argument holding the page.
    :param page: explicit page number; when given, the request is ignored.
    :param check_bounds: when true, ``get_object_list()`` aborts with a 404
        for pages past the last one.
    """

    def __init__(self, query_or_model, paginate_by, page_var='page',
                 page=None, check_bounds=False):
        self.paginate_by = paginate_by
        self.page_var = page_var
        # A falsy page (None or 0) means "read the page from the request".
        self.page = page or None
        self.check_bounds = check_bounds

        if isinstance(query_or_model, SelectQuery):
            self.query = query_or_model
            self.model = self.query.model
        else:
            self.model = query_or_model
            self.query = self.model.select()

    def get_page(self):
        """Return the current 1-based page number.

        Uses the explicit ``page`` if one was given, otherwise the
        ``page_var`` query-string argument. Missing or malformed values
        fall back to page 1.
        """
        if self.page is not None:
            return self.page
        curr_page = request.args.get(self.page_var)
        if curr_page and curr_page.isdigit():
            try:
                return max(1, int(curr_page))
            except ValueError:
                # str.isdigit() also accepts characters such as superscript
                # digits ('\u00b2') that int() rejects. Treat them as
                # malformed input instead of raising a server error.
                pass
        return 1

    def get_page_count(self):
        """Return the total number of pages (computed once, then cached)."""
        if not hasattr(self, '_page_count'):
            self._page_count = int(math.ceil(
                float(self.query.count()) / self.paginate_by))
        return self._page_count

    def get_object_list(self):
        """Return the query restricted to the current page.

        With ``check_bounds`` enabled, requesting a page beyond the last one
        aborts with 404. Page 1 is always considered in bounds, so an empty
        result set yields an empty first page rather than a 404.
        """
        page = self.get_page()
        if self.check_bounds and page > max(1, self.get_page_count()):
            abort(404)
        return self.query.paginate(page, self.paginate_by)


def get_object_or_404(query_or_model, *query):
    """Return the single row matching ``query``, or abort with 404.

    :param query_or_model: a ``SelectQuery`` or a model class (its
        ``select()`` query is used).
    :param query: expressions passed to ``where()``.
    """
    if not isinstance(query_or_model, SelectQuery):
        query_or_model = query_or_model.select()
    try:
        return query_or_model.where(*query).get()
    except DoesNotExist:
        abort(404)


def object_list(template_name, query, context_variable='object_list',
                paginate_by=20, page_var='page', page=None, check_bounds=True,
                **kwargs):
    """Render ``template_name`` with one page of ``query``.

    The page's rows are exposed to the template as ``context_variable``, the
    ``PaginatedQuery`` as ``pagination`` and the current page number as
    ``page``. Any extra keyword arguments are passed through as template
    context.
    """
    paginated_query = PaginatedQuery(
        query,
        paginate_by=paginate_by,
        page_var=page_var,
        page=page,
        check_bounds=check_bounds)
    kwargs[context_variable] = paginated_query.get_object_list()
    return render_template(
        template_name,
        pagination=paginated_query,
        page=paginated_query.get_page(),
        **kwargs)


def get_current_url():
    """Return the current request path, plus the query string if present."""
    query_string = request.query_string
    if not query_string:
        return request.path
    # Werkzeug exposes the raw query string as bytes. Interpolating bytes
    # with '%s' on Python 3 would produce "/path?b'...'", so decode first.
    # (On Python 2, bytes is str and no decoding is needed.)
    if isinstance(query_string, bytes) and not isinstance(query_string, str):
        query_string = query_string.decode('utf-8', 'replace')
    return '%s?%s' % (request.path, query_string)


def get_next_url(default='/'):
    """Return the ``next`` value from the query string or form, else ``default``.

    NOTE(review): the value comes straight from the client and is not
    validated. Callers that redirect to it should check that it is a local
    URL to avoid an open redirect.
    """
    if request.args.get('next'):
        return request.args['next']
    elif request.form.get('next'):
        return request.form['next']
    return default


class FlaskDB(object):
    """Bind a peewee database to a Flask application.

    Opens a connection before each request and closes it on teardown, and
    provides a ``Model`` base class bound to the database.

    :param app: Flask application; may be supplied later via ``init_app``.
    :param database: a ``Database`` instance, a config dict, a database URL,
        or ``None`` to read ``DATABASE``/``DATABASE_URL`` from the app config.
    :param model_class: base class for the generated ``Model``.
    """

    def __init__(self, app=None, database=None, model_class=Model):
        self.database = None  # Reference to actual Peewee database instance.
        self.base_model_class = model_class
        self._app = app
        self._db = database  # dict, url, Database, or None (default).
        if app is not None:
            self.init_app(app)

    def init_app(self, app):
        """Attach to ``app``: load the database and register request hooks.

        :raises ValueError: if no database was given to the constructor and
            the app config has neither ``DATABASE`` nor ``DATABASE_URL``.
        """
        self._app = app

        if self._db is None:
            if 'DATABASE' in app.config:
                initial_db = app.config['DATABASE']
            elif 'DATABASE_URL' in app.config:
                initial_db = app.config['DATABASE_URL']
            else:
                raise ValueError('Missing required configuration data for '
                                 'database: DATABASE or DATABASE_URL.')
        else:
            initial_db = self._db

        self._load_database(app, initial_db)
        self._register_handlers(app)

    def _load_database(self, app, config_value):
        """Build a database from ``config_value`` and store it.

        ``config_value`` may be a ``Database``, a config dict, or a URL. If
        ``self.database`` is already a ``Proxy`` (created by accessing
        ``Model`` before an app was attached), the proxy is initialized so
        existing models pick up the real database.
        """
        if isinstance(config_value, Database):
            database = config_value
        elif isinstance(config_value, dict):
            database = self._load_from_config_dict(dict(config_value))
        else:
            # Assume a database connection URL.
            database = db_url_connect(config_value)

        if isinstance(self.database, Proxy):
            self.database.initialize(database)
        else:
            self.database = database

    def _load_from_config_dict(self, config_dict):
        """Instantiate a database class from a config dict.

        ``name`` and ``engine`` are required and are removed from the dict.
        ``engine`` is either a dotted path (``module.ClassName``) or a class
        name looked up in ``peewee``. Remaining keys become constructor
        keyword arguments.

        :raises RuntimeError: if a required key is missing, the engine
            cannot be imported or found, or it is not a ``Database`` subclass.
        """
        try:
            name = config_dict.pop('name')
            engine = config_dict.pop('engine')
        except KeyError:
            raise RuntimeError('DATABASE configuration must specify a '
                               '`name` and `engine`.')

        if '.' in engine:
            path, class_name = engine.rsplit('.', 1)
        else:
            path, class_name = 'peewee', engine

        try:
            __import__(path)
            module = sys.modules[path]
            database_class = getattr(module, class_name)
        except ImportError:
            raise RuntimeError('Unable to import %s' % engine)
        except AttributeError:
            raise RuntimeError('Database engine not found %s' % engine)

        # Explicit check rather than ``assert``: assertions are stripped under
        # ``python -O``, and issubclass() raises TypeError for non-classes.
        if not (isinstance(database_class, type) and
                issubclass(database_class, Database)):
            raise RuntimeError('Database engine not a subclass of '
                               'peewee.Database: %s' % engine)

        return database_class(name, **config_dict)

    def _register_handlers(self, app):
        """Open a connection before each request; close it on teardown."""
        app.before_request(self.connect_db)
        app.teardown_request(self.close_db)

    def get_model_class(self):
        """Return a new subclass of ``base_model_class`` bound to the database.

        :raises RuntimeError: if no database has been set yet.
        """
        if self.database is None:
            raise RuntimeError('Database must be initialized.')

        class BaseModel(self.base_model_class):
            class Meta:
                database = self.database

        return BaseModel

    @property
    def Model(self):
        """Base model class bound to this wrapper's database (cached).

        If no app is attached yet and no database is set, a ``Proxy`` is
        installed as the database so models can be declared before
        ``init_app``. ``_load_database`` initializes that proxy later.
        """
        if self._app is None:
            database = getattr(self, 'database', None)
            if database is None:
                self.database = Proxy()

        if not hasattr(self, '_model_class'):
            self._model_class = self.get_model_class()
        return self._model_class

    def connect_db(self):
        """Before-request hook: open a database connection."""
        self.database.connect()

    def close_db(self, exc):
        """Teardown hook: close the connection if it is still open.

        ``exc`` is the argument Flask passes to teardown handlers; unused.
        """
        if not self.database.is_closed():
            self.database.close()