Source code for concurrency.triggers
from collections import defaultdict
from django.apps import apps
from django.db import connections, router
from django.db.utils import DatabaseError
# from .fields import _TRIGGERS # noqa
class TriggerRegistry:
    """Registry of the models that own a trigger-backed field.

    Each entry is stored as a two-item list ``[app_label, model_name]``
    built from the model of the registered field; iterating the registry
    yields those lists in registration order.
    """

    def __init__(self) -> None:
        self._fields = []

    @staticmethod
    def _entry_for(field) -> list:
        """Build the ``[app_label, model_name]`` entry that identifies *field*."""
        owner = field.model
        return [owner._meta.app_label, owner.__name__]

    def append(self, field) -> None:
        """Record the model that owns *field* (duplicates are not filtered)."""
        self._fields.append(self._entry_for(field))

    def __iter__(self):
        """Iterate over the stored ``[app_label, model_name]`` entries."""
        return iter(self._fields)

    def __contains__(self, field) -> bool:
        """Return True when the model that owns *field* has been registered."""
        return self._entry_for(field) in self._fields


# Module-level registry iterated by drop_triggers() and create_triggers().
_TRIGGERS = TriggerRegistry()
def get_trigger_name(field) -> str:
    """Build the database trigger name for a field.

    The explicit ``field._trigger_name`` is used when it is truthy;
    otherwise the name is derived from the model's table name and the
    field name. The result is always prefixed with ``concurrency_``.

    :param field: Field instance
    :return: the trigger name as ``str``
    """
    explicit = field._trigger_name
    if explicit:
        suffix = explicit
    else:
        table = field.model._meta.db_table
        suffix = f"{table}_{field.name}"
    return f"concurrency_{suffix}"
def get_triggers(databases=None):
    """List the triggers found on each database.

    :param databases: iterable of connection aliases; when ``None`` every
        alias known to ``django.db.connections`` is inspected
    :return: dict mapping each alias to the list returned by the
        ``get_list()`` of the factory built for that connection
    """
    aliases = list(connections) if databases is None else databases
    return {alias: factory(connections[alias]).get_list() for alias in aliases}
def drop_triggers(*databases):
    """Drop the triggers of every registered model that writes to one of *databases*.

    For each ``(app_label, model_name)`` entry in the module registry the
    model is resolved, its write alias is obtained from the router, and,
    when that alias was requested, the trigger of the model's concurrency
    field is dropped and the field is flagged as having no trigger.

    NOTE(review): unlike ``create_triggers`` this does not touch
    ``model._concurrencymeta.triggers`` - confirm whether a dropped field
    is expected to stay in that storage.

    :param databases: connection aliases, passed as positional arguments
    :return: ``defaultdict(list)`` mapping alias to a list of
        ``[model, field, trigger_name]`` entries, one per dropped trigger
    """
    dropped = defaultdict(list)
    for app_label, model_name in _TRIGGERS:
        model = apps.get_model(app_label, model_name)
        field = model._concurrencymeta.field
        alias = router.db_for_write(model)
        if alias not in databases:  # pragma: no cover
            continue
        factory(connections[alias]).drop(field)
        field._trigger_exists = False
        dropped[alias].append([model, field, field.trigger_name])
    return dropped
def create_triggers(databases):
    """Create the triggers of every registered model that writes to one of *databases*.

    A model is skipped when its write alias is not in *databases* or when
    its concurrency field is already present in
    ``model._concurrencymeta.triggers``. Otherwise the field is added to
    that storage first and the trigger is then created through the
    factory of the target connection.

    :param databases: container of connection aliases
    :return: ``defaultdict(list)`` mapping alias to a list of
        ``[model, field, trigger_name]`` entries, one per processed field
    """
    created = defaultdict(list)
    for app_label, model_name in _TRIGGERS:
        model = apps.get_model(app_label, model_name)
        options = model._concurrencymeta
        field = options.field
        storage = options.triggers
        alias = router.db_for_write(model)
        if alias not in databases or field in storage:  # pragma: no cover
            continue
        storage.append(field)
        factory(connections[alias]).create(field)
        created[alias].append([model, field, field.trigger_name])
    return created
[docs]
class TriggerFactory:
    """Abstract Factory class to create triggers.

    Implementations need to set the following attributes:
    `update_clause`, `drop_clause` and `list_clause`.

    `update_clause` and `drop_clause` are formatted using standard python
    `format()` as::

        self.update_clause.format(trigger_name=field.trigger_name,
                                  opts=field.model._meta,
                                  field=field)

    So as example::

        update_clause = \"\"\"CREATE TRIGGER {trigger_name}
        AFTER UPDATE ON {opts.db_table}
        BEGIN UPDATE {opts.db_table}
        SET {field.column} = {field.column}+1
        WHERE {opts.pk.column} = NEW.{opts.pk.column};
        END;
        \"\"\"

    `list_clause` is executed as-is (no formatting); the first column of
    every returned row is taken as a trigger name.
    """

    update_clause = ""
    drop_clause = ""
    list_clause = ""

    def __init__(self, connection) -> None:
        """Bind the factory to *connection* (must provide ``cursor()``)."""
        self.connection = connection

    def get_trigger(self, field):
        """Return ``field.trigger_name`` if that trigger exists, else ``None``."""
        if field.trigger_name in self.get_list():
            return field.trigger_name
        return None

    def create(self, field) -> None:
        """Create the trigger for *field* unless it is already listed.

        ``field._trigger_exists`` is set to ``True`` in both cases.

        :raises DatabaseError: when executing the statement fails; the
            original exception is attached as ``__cause__``.
        """
        if field.trigger_name not in self.get_list():
            stm = self.update_clause.format(trigger_name=field.trigger_name, opts=field.model._meta, field=field)
            try:
                # The context manager releases the cursor even on failure.
                with self.connection.cursor() as cursor:
                    cursor.execute(stm)
            except Exception as exc:  # noqa pragma: no cover
                msg = f"Error executing:\n{stm}\n{exc}"
                raise DatabaseError(msg) from exc
        field._trigger_exists = True

    def drop(self, field):
        """Drop the trigger of *field*.

        :return: single-item list holding the dropped trigger name
        """
        stm = self.drop_clause.format(trigger_name=field.trigger_name, opts=field.model._meta, field=field)
        with self.connection.cursor() as cursor:
            cursor.execute(stm)
        return [field.trigger_name]

    def _list(self):
        """Execute `list_clause` and return all the resulting rows."""
        with self.connection.cursor() as cursor:
            cursor.execute(self.list_clause)
            return cursor.fetchall()

    def get_list(self):
        """Return the sorted list of trigger names (first column of each row)."""
        return sorted(row[0] for row in self._list())
class Sqlite3(TriggerFactory):
    # SQL templates for SQLite.
    # The clauses are runtime strings: their text (including the line breaks
    # inside the triple-quoted literal) is sent to the database as written.

    # Dropping is safe to repeat thanks to IF EXISTS.
    drop_clause = """DROP TRIGGER IF EXISTS {trigger_name};"""
    # AFTER UPDATE trigger that increments the field column on the row whose
    # primary key equals the primary key of the updated (NEW) row.
    update_clause = """CREATE TRIGGER {trigger_name}
AFTER UPDATE ON {opts.db_table}
BEGIN UPDATE {opts.db_table} SET {field.column} = {field.column}+1 WHERE {opts.pk.column} = NEW.{opts.pk.column};
END;"""
    # Selects the name of every trigger recorded in sqlite_master (no
    # filtering on the "concurrency_" prefix).
    list_clause = "select name from sqlite_master where type='trigger';"
class PostgreSQL(TriggerFactory):
    # SQL templates for PostgreSQL.
    # The clauses are raw runtime strings: their text is sent to the database
    # as written after ``format()`` substitution.

    # Drops only the trigger; the companion function func_{trigger_name}
    # created by update_clause is not removed by this statement.
    drop_clause = r"""DROP TRIGGER IF EXISTS {trigger_name} ON {opts.db_table};"""
    # Two statements: (1) create or replace a plpgsql trigger function that
    # sets NEW.<column> to OLD.<column> + 1, (2) create a row-level
    # BEFORE UPDATE trigger on the table that executes that function.
    update_clause = r"""CREATE OR REPLACE FUNCTION func_{trigger_name}()
RETURNS TRIGGER as
'
BEGIN
NEW.{field.column} = OLD.{field.column} +1;
RETURN NEW;
END;
' language 'plpgsql';
CREATE TRIGGER {trigger_name} BEFORE UPDATE
ON {opts.db_table} FOR EACH ROW
EXECUTE PROCEDURE func_{trigger_name}();
"""
    # Lists trigger names starting with "concurrency_".
    # NOTE(review): the doubled "%%" presumably guards against DB-API
    # parameter interpolation - confirm against the driver in use.
    list_clause = "select tgname from pg_trigger where tgname LIKE 'concurrency_%%'; "
class MySQL(TriggerFactory):
    # SQL templates for MySQL.
    # The clauses are runtime strings: their text (including the leading and
    # trailing newlines of update_clause) is sent to the database as written.

    # Dropping is safe to repeat thanks to IF EXISTS.
    drop_clause = """DROP TRIGGER IF EXISTS {trigger_name};"""
    # Row-level BEFORE UPDATE trigger that sets NEW.<column> to
    # OLD.<column> + 1.
    update_clause = """
CREATE TRIGGER {trigger_name} BEFORE UPDATE ON {opts.db_table}
FOR EACH ROW SET NEW.{field.column} = OLD.{field.column}+1;
"""
    # Returns every trigger, not only "concurrency_" ones.
    # NOTE(review): TriggerFactory.get_list() reads the first column of each
    # row - presumably the trigger name in SHOW TRIGGERS output; confirm.
    list_clause = "SHOW TRIGGERS"
def factory(conn):
    """Build the trigger factory that matches the vendor of *conn*.

    The factory class is looked up by ``conn.vendor`` in
    ``conf.TRIGGERS_FACTORY`` and instantiated with the connection.

    :param conn: database connection exposing a ``vendor`` attribute
    :return: factory instance bound to *conn*
    :raises ValueError: when no factory is configured for the vendor
    """
    # Imported here rather than at module level, presumably to avoid a
    # circular import with the configuration module - confirm.
    from concurrency.config import conf  # noqa: PLC0415

    mapping = conf.TRIGGERS_FACTORY
    # Only the lookup is guarded, so a KeyError raised while constructing the
    # factory is not misreported as an unsupported vendor.
    try:
        factory_class = mapping[conn.vendor]
    except KeyError as exc:  # pragma: no cover
        msg = f"{conn} is not supported by TriggerVersionField"
        raise ValueError(msg) from exc
    return factory_class(conn)