import functools import re import sys from peewee import * from peewee import _atomic from peewee import _manual from peewee import ColumnMetadata # (name, data_type, null, primary_key, table, default) from peewee import EnclosedNodeList from peewee import Entity from peewee import ForeignKeyMetadata # (column, dest_table, dest_column, table). from peewee import IndexMetadata from peewee import NodeList from playhouse.pool import _PooledPostgresqlDatabase try: from playhouse.postgres_ext import ArrayField from playhouse.postgres_ext import BinaryJSONField from playhouse.postgres_ext import IntervalField JSONField = BinaryJSONField except ImportError: # psycopg2 not installed, ignore. ArrayField = BinaryJSONField = IntervalField = JSONField = None if sys.version_info[0] > 2: basestring = str NESTED_TX_MIN_VERSION = 200100 TXN_ERR_MSG = ('CockroachDB does not support nested transactions. You may ' 'alternatively use the @transaction context-manager/decorator, ' 'which only wraps the outer-most block in transactional logic. ' 'To run a transaction with automatic retries, use the ' 'run_transaction() helper.') class ExceededMaxAttempts(OperationalError): pass class UUIDKeyField(UUIDField): auto_increment = True def __init__(self, *args, **kwargs): if kwargs.get('constraints'): raise ValueError('%s cannot specify constraints.' % type(self)) kwargs['constraints'] = [SQL('DEFAULT gen_random_uuid()')] kwargs.setdefault('primary_key', True) super(UUIDKeyField, self).__init__(*args, **kwargs) class RowIDField(AutoField): field_type = 'INT' def __init__(self, *args, **kwargs): if kwargs.get('constraints'): raise ValueError('%s cannot specify constraints.' % type(self)) kwargs['constraints'] = [SQL('DEFAULT unique_rowid()')] super(RowIDField, self).__init__(*args, **kwargs) class CockroachDatabase(PostgresqlDatabase): field_types = PostgresqlDatabase.field_types.copy() field_types.update({ 'BLOB': 'BYTES', }) for_update = False nulls_ordering = False release_after_rollback = True def __init__(self, database, *args, **kwargs): # Unless a DSN or database connection-url were specified, provide # convenient defaults for the user and port. if 'dsn' not in kwargs and (database and not database.startswith('postgresql://')): kwargs.setdefault('user', 'root') kwargs.setdefault('port', 26257) super(CockroachDatabase, self).__init__(database, *args, **kwargs) def _set_server_version(self, conn): curs = conn.cursor() curs.execute('select version()') raw, = curs.fetchone() match_obj = re.match(r'^CockroachDB.+?v(\d+)\.(\d+)\.(\d+)', raw) if match_obj is not None: clean = '%d%02d%02d' % tuple(int(i) for i in match_obj.groups()) self.server_version = int(clean) # 19.1.5 -> 190105. else: # Fallback to use whatever cockroachdb tells us via protocol. super(CockroachDatabase, self)._set_server_version(conn) def _get_pk_constraint(self, table, schema=None): query = ('SELECT constraint_name ' 'FROM information_schema.table_constraints ' 'WHERE table_name = %s AND table_schema = %s ' 'AND constraint_type = %s') cursor = self.execute_sql(query, (table, schema or 'public', 'PRIMARY KEY')) row = cursor.fetchone() return row and row[0] or None def get_indexes(self, table, schema=None): # The primary-key index is returned by default, so we will just strip # it out here. indexes = super(CockroachDatabase, self).get_indexes(table, schema) pkc = self._get_pk_constraint(table, schema) return [idx for idx in indexes if (not pkc) or (idx.name != pkc)] def conflict_statement(self, on_conflict, query): if not on_conflict._action: return action = on_conflict._action.lower() if action in ('replace', 'upsert'): return SQL('UPSERT') elif action not in ('ignore', 'nothing', 'update'): raise ValueError('Un-supported action for conflict resolution. ' 'CockroachDB supports REPLACE (UPSERT), IGNORE ' 'and UPDATE.') def conflict_update(self, oc, query): action = oc._action.lower() if oc._action else '' if action in ('ignore', 'nothing'): parts = [SQL('ON CONFLICT')] if oc._conflict_target: parts.append(EnclosedNodeList([ Entity(col) if isinstance(col, basestring) else col for col in oc._conflict_target])) parts.append(SQL('DO NOTHING')) return NodeList(parts) elif action in ('replace', 'upsert'): # No special stuff is necessary, this is just indicated by starting # the statement with UPSERT instead of INSERT. return elif oc._conflict_constraint: raise ValueError('CockroachDB does not support the usage of a ' 'constraint name. Use the column(s) instead.') return super(CockroachDatabase, self).conflict_update(oc, query) def extract_date(self, date_part, date_field): return fn.extract(date_part, date_field) def from_timestamp(self, date_field): # CRDB does not allow casting a decimal/float to timestamp, so we first # cast to int, then to timestamptz. return date_field.cast('int').cast('timestamptz') def begin(self, system_time=None, priority=None): super(CockroachDatabase, self).begin() if system_time is not None: self.execute_sql('SET TRANSACTION AS OF SYSTEM TIME %s', (system_time,), commit=False) if priority is not None: priority = priority.lower() if priority not in ('low', 'normal', 'high'): raise ValueError('priority must be low, normal or high') self.execute_sql('SET TRANSACTION PRIORITY %s' % priority, commit=False) def atomic(self, system_time=None, priority=None): if self.is_closed(): self.connect() # Side-effect, set server version. if self.server_version < NESTED_TX_MIN_VERSION: return _crdb_atomic(self, system_time, priority) return super(CockroachDatabase, self).atomic(system_time, priority) def savepoint(self): if self.is_closed(): self.connect() # Side-effect, set server version. if self.server_version < NESTED_TX_MIN_VERSION: raise NotImplementedError(TXN_ERR_MSG) return super(CockroachDatabase, self).savepoint() def retry_transaction(self, max_attempts=None, system_time=None, priority=None): def deco(cb): @functools.wraps(cb) def new_fn(): return run_transaction(self, cb, max_attempts, system_time, priority) return new_fn return deco def run_transaction(self, cb, max_attempts=None, system_time=None, priority=None): return run_transaction(self, cb, max_attempts, system_time, priority) class _crdb_atomic(_atomic): def __enter__(self): if self.db.transaction_depth() > 0: if not isinstance(self.db.top_transaction(), _manual): raise NotImplementedError(TXN_ERR_MSG) return super(_crdb_atomic, self).__enter__() def run_transaction(db, callback, max_attempts=None, system_time=None, priority=None): """ Run transactional SQL in a transaction with automatic retries. User-provided `callback`: * Must accept one parameter, the `db` instance representing the connection the transaction is running under. * Must not attempt to commit, rollback or otherwise manage transactions. * May be called more than once. * Should ideally only contain SQL operations. Additionally, the database must not have any open transaction at the time this function is called, as CRDB does not support nested transactions. """ max_attempts = max_attempts or -1 with db.atomic(system_time=system_time, priority=priority) as txn: db.execute_sql('SAVEPOINT cockroach_restart') while max_attempts != 0: try: result = callback(db) db.execute_sql('RELEASE SAVEPOINT cockroach_restart') return result except OperationalError as exc: if exc.orig.pgcode == '40001': max_attempts -= 1 db.execute_sql('ROLLBACK TO SAVEPOINT cockroach_restart') continue raise raise ExceededMaxAttempts(None, 'unable to commit transaction') class PooledCockroachDatabase(_PooledPostgresqlDatabase, CockroachDatabase): pass