# Source code for concurrency.fields

import copy
import functools
import hashlib
import inspect
import logging
import time
from collections import OrderedDict
from functools import update_wrapper

from django.db import models
from django.db.models import signals
from django.db.models.fields import Field
from django.db.models.signals import class_prepared, post_migrate
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _

from concurrency import forms
from concurrency.api import get_revision_of_object
from concurrency.config import conf
from concurrency.core import ConcurrencyOptions
from concurrency.utils import fqn, refetch

from .triggers import _TRIGGERS

logger = logging.getLogger(name=__name__)

# Seconds since the epoch for 2000-01-01 00:00:00; ``time.mktime`` interprets the
# 9-tuple as local time with the DST flag set to 0, so the value depends on the
# host timezone. Subtracted from timestamps by ``IntegerVersionField``.
OFFSET = int(time.mktime((2000, 1, 1) + (0,) * 6))


def _accepts_argument(func, argument_name: str) -> bool:
    """Return whether ``func`` accepts ``argument_name`` (or generic ``**kwargs``)."""
    try:
        parameters = inspect.signature(func).parameters.values()
    except (TypeError, ValueError):
        return False

    return any(
        parameter.name == argument_name or parameter.kind == inspect.Parameter.VAR_KEYWORD for parameter in parameters
    )


def class_prepared_concurrency_handler(sender, **kwargs) -> None:
    """``class_prepared`` receiver that finalises concurrency options for a model.

    Models without ``_concurrencymeta`` are ignored. A model that inherited the
    options object from its base gets a private deep copy. Settings declared on an
    inner ``ConcurrencyMeta`` class are then applied. Unless the options are marked
    ``manually``, the version field wraps the model's methods. Finally,
    ``get_concurrency_version`` is attached to the model class.

    Raises:
        ValueError: if ``ConcurrencyMeta`` sets both ``check_fields`` and
            ``ignore_fields``.
    """
    if not hasattr(sender, "_concurrencymeta"):
        return

    base = sender._concurrencymeta.base
    if sender != base:
        # The options object was inherited from ``base``; give this model its own copy.
        sender._concurrencymeta = copy.deepcopy(base._concurrencymeta)

    options = sender._concurrencymeta

    if hasattr(sender, "ConcurrencyMeta"):
        user_meta = sender.ConcurrencyMeta
        options.enabled = getattr(user_meta, "enabled", True)
        checked = getattr(user_meta, "check_fields", None)
        ignored = getattr(user_meta, "ignore_fields", None)
        if checked and ignored:
            raise ValueError("Cannot set both 'check_fields' and 'ignore_fields'")

        options.check_fields = checked
        options.ignore_fields = ignored
        options.increment = getattr(user_meta, "increment", True)
        options.skip = False

    if not options.manually:
        options.field.wrap_model(sender)

    sender.get_concurrency_version = get_revision_of_object


def post_syncdb_concurrency_handler(sender, **kwargs) -> None:
    """Signal receiver that creates concurrency triggers on every configured database alias.

    ``sender`` and any extra signal keyword arguments are ignored.
    """
    from django.db import connections  # noqa: PLC0415

    from concurrency.triggers import create_triggers  # noqa: PLC0415

    create_triggers([alias for alias in connections])


# Finalise concurrency options whenever a model class is prepared.
class_prepared.connect(
    receiver=class_prepared_concurrency_handler,
    dispatch_uid="class_prepared_concurrency_handler",
)


# Creating triggers after ``migrate`` is opt-in via ``conf.AUTO_CREATE_TRIGGERS``.
if conf.AUTO_CREATE_TRIGGERS:
    post_migrate.connect(
        receiver=post_syncdb_concurrency_handler,
        dispatch_uid="post_syncdb_concurrency_handler",
    )


class VersionField(Field):
    """Base class for the concurrency version fields.

    The column is stored as a ``BigIntegerField`` with a default of ``0``. The first
    concrete model that declares the field receives a ``ConcurrencyOptions`` instance
    as ``_concurrencymeta``. :meth:`wrap_model` then replaces the model's
    ``_do_update``, so that an UPDATE is additionally filtered on the version read via
    ``get_revision_of_object`` before the save.

    Subclasses must define ``form_class`` and ``_get_next_version``.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Django fields accept ``verbose_name`` and ``name`` as their first two
        # positional arguments; honour them instead of silently discarding ``args``.
        # Explicit keyword arguments still take precedence. Any other option is
        # deliberately not forwarded.
        positional = dict(zip(("verbose_name", "name"), args))
        verbose_name = kwargs.get("verbose_name", positional.get("verbose_name"))
        name = kwargs.get("name", positional.get("name"))
        db_tablespace = kwargs.get("db_tablespace")
        db_column = kwargs.get("db_column")
        help_text = kwargs.get("help_text", _("record revision number"))
        # ``default=0`` is forced; a falsy version is what ``pre_save`` treats as
        # "not yet assigned".
        super().__init__(
            verbose_name,
            name,
            help_text=help_text,
            default=0,
            db_tablespace=db_tablespace,
            db_column=db_column,
        )

    def get_internal_type(self) -> str:
        return "BigIntegerField"

    def to_python(self, value):
        return int(value)

    def validate(self, value, model_instance) -> None:
        # Intentionally a no-op: the version is assigned by ``pre_save`` and the
        # wrapped ``_do_update``, not validated as user input.
        pass

    def formfield(self, **kwargs):
        # ``form_class`` is provided by the concrete subclasses.
        kwargs["form_class"] = self.form_class
        kwargs["widget"] = forms.VersionField.widget
        return super().formfield(**kwargs)

    def contribute_to_class(self, cls, *args, **kwargs) -> None:
        super().contribute_to_class(cls, *args, **kwargs)
        # Abstract models and models that already carry options (for example,
        # inherited from a base class) are left untouched here.
        if hasattr(cls, "_concurrencymeta") or cls._meta.abstract:
            return
        cls._concurrencymeta = ConcurrencyOptions()
        cls._concurrencymeta.field = self
        cls._concurrencymeta.base = cls
        cls._concurrencymeta.triggers = []

    def _set_version_value(self, model_instance, value) -> None:
        setattr(model_instance, self.attname, int(value))

    def pre_save(self, model_instance, add):
        # Django 6 can evaluate insert values twice during a single save() call.
        # Version assignment must be idempotent for new objects to avoid 0 -> 1 -> 2 on first insert.
        if add and not getattr(model_instance, self.attname):
            value = self._get_next_version(model_instance)
            self._set_version_value(model_instance, value)
        return getattr(model_instance, self.attname)

    @classmethod
    def wrap_model(cls, model, force=False) -> None:
        """Install the versioned ``_do_update`` on ``model`` (once, unless ``force``)."""
        if not force and model._concurrencymeta.versioned_save:
            return
        cls._wrap_model_methods(model)
        model._concurrencymeta.versioned_save = True

    @staticmethod
    def _wrap_model_methods(model) -> None:
        old_do_update = model._do_update
        model._do_update = model._concurrencymeta.field._wrap_do_update(old_do_update)

    def _wrap_do_update(self, func):  # noqa: C901
        """Return a replacement for the model's ``_do_update`` that enforces versioning.

        Results follow ``func``'s own contract. When ``func`` accepts
        ``returning_fields``, results are lists, which are truthy when non-empty.
        Otherwise, results are booleans.
        """
        supports_returning_fields = _accepts_argument(func, "returning_fields")

        def _updated_with_no_values():
            return [()] if supports_returning_fields else True

        def _not_updated():
            return [] if supports_returning_fields else False

        def _update_with_filtered_queryset(filtered_queryset, values, returning_fields):
            if supports_returning_fields:
                return filtered_queryset._update(values, returning_fields=returning_fields)
            return filtered_queryset._update(values) >= 1

        def _perform_update(model_instance, filtered_queryset, values, forced_update, returning_fields):
            # With ``select_on_save`` (and no forced update), confirm the row exists
            # before updating. If the UPDATE reports nothing, re-check existence
            # before declaring failure.
            if model_instance._meta.select_on_save and not forced_update:
                if not filtered_queryset.exists():
                    return _not_updated()
                updated = _update_with_filtered_queryset(filtered_queryset, values, returning_fields)
                if updated:
                    return updated
                return _updated_with_no_values() if filtered_queryset.exists() else _not_updated()
            return _update_with_filtered_queryset(filtered_queryset, values, returning_fields)

        def _do_update(  # noqa
            model_instance,
            base_qs,
            using,
            pk_val,
            values,
            update_fields,
            forced_update,
            returning_fields=None,
        ):
            version_field = model_instance._concurrencymeta.field
            old_version = get_revision_of_object(model_instance)

            # The version field lives on a different concrete model than the one being
            # updated (presumably another table of a multi-table inheritance chain):
            # delegate to Django unchanged.
            if not version_field.model._meta.abstract and version_field.model is not base_qs.model:
                if supports_returning_fields:
                    return func(
                        model_instance,
                        base_qs,
                        using,
                        pk_val,
                        values,
                        update_fields,
                        forced_update,
                        returning_fields=returning_fields,
                    )
                return func(
                    model_instance,
                    base_qs,
                    using,
                    pk_val,
                    values,
                    update_fields,
                    forced_update,
                )

            filtered = base_qs.filter(pk=pk_val)

            # This provides a default if either (1) no values were provided or (2) we reached this code as part of a
            # create. We don't need to worry about a race condition because a competing create should produce an
            # error anyway.
            if not values:
                if update_fields is not None or filtered.exists():
                    return _updated_with_no_values()
                return _not_updated()

            # This second situation can occur because `Model.save_base` calls `Model._save_parent` without relaying
            # the `force_insert` flag that marks the process as a create. Eventually, `Model._save_table` will call
            # this function as-if it were in the middle of an update. The update is expected to fail because there
            # is no object to update and the caller will fall back on the create logic instead. We need to ensure
            # the update fails (but does not raise an exception) under this circumstance by skipping the concurrency
            # logic.
            if filtered.exists():
                # Swap the outgoing version value for the next version (unless
                # incrementing is disabled) and mirror it on the instance.
                for i, (field, _1, _value) in enumerate(values):
                    if field == version_field:
                        if model_instance._concurrencymeta.increment and not getattr(
                            model_instance,
                            "_concurrency_disable_increment",
                            False,
                        ):
                            new_version = field._get_next_version(model_instance)
                            values[i] = (field, _1, new_version)
                            field._set_version_value(model_instance, new_version)
                        break

                if (
                    model_instance._concurrencymeta.enabled
                    and conf.ENABLED
                    and not getattr(model_instance, "_concurrency_disabled", False)
                    and (old_version or conf.VERSION_FIELD_REQUIRED)
                ):
                    # Optimistic lock: only update the row still holding the old version.
                    filter_kwargs = {"pk": pk_val, version_field.attname: old_version}
                    updated = _perform_update(
                        model_instance,
                        base_qs.filter(**filter_kwargs),
                        values,
                        forced_update,
                        returning_fields,
                    )
                    if not updated:
                        # Conflict: restore the pre-save version and let the configured
                        # callback decide; its truthiness becomes the update result.
                        version_field._set_version_value(model_instance, old_version)
                        callback_result = conf._callback(model_instance)
                        if supports_returning_fields:
                            return _updated_with_no_values() if callback_result else _not_updated()
                        return callback_result
                    return updated

                # Concurrency checking is off for this save: plain pk-filtered update.
                return _perform_update(
                    model_instance,
                    base_qs.filter(pk=pk_val),
                    values,
                    forced_update,
                    returning_fields,
                )

            return _not_updated()

        return update_wrapper(_do_update, func)
class IntegerVersionField(VersionField):
    """
    Version field that assigns a time-based, practically "unique" revision number.

    Each new version is the current time in microseconds (``time.time() * 1000000``,
    so sub-second resolution is used when the system clock provides it) minus
    ``OFFSET``. It is never lower than the previous version plus one.
    """

    form_class = forms.VersionField

    def _get_next_version(self, model_instance):
        current = int(getattr(model_instance, self.attname, 0))
        # NOTE(review): OFFSET is expressed in seconds, while this timestamp is in
        # microseconds; kept as-is because changing it would shift generated versions.
        timestamp = int(time.time() * 1000000) - OFFSET
        # Stay strictly increasing even if the clock moves backwards.
        return max(current + 1, timestamp)
class AutoIncVersionField(VersionField):
    """Version field whose next revision number is simply the current one plus one."""

    form_class = forms.VersionField

    def _get_next_version(self, model_instance):
        current = getattr(model_instance, self.attname, 0)
        return int(current) + 1
class TriggerVersionField(VersionField):
    """Version field whose stored value is maintained by a database trigger.

    Instances are registered in ``_TRIGGERS``. ``check`` warns when the trigger is
    missing. Python-side version handling is a no-op (``pre_save`` always returns 1).
    The wrapped ``save`` only mirrors the increment on the in-memory instance.
    """

    form_class = forms.VersionField

    def __init__(self, *args, **kwargs) -> None:
        self._trigger_name = kwargs.pop("trigger_name", None)
        self._trigger_exists = False
        super().__init__(*args, **kwargs)

    def contribute_to_class(self, cls, *args, **kwargs) -> None:
        super().contribute_to_class(cls, *args, **kwargs)
        if (not cls._meta.abstract or cls._meta.proxy) and self not in _TRIGGERS:
            _TRIGGERS.append(self)

    def check(self, **kwargs):
        """Run Django's standard field checks plus a warning for a missing trigger."""
        # Start from the base checks; overriding ``check`` without ``super()`` would
        # silently drop Django's built-in field validation.
        errors = [*super().check(**kwargs)]
        model = self.model
        from django.core.checks import Warning  # noqa: PLC0415 A004
        from django.db import connections, router  # noqa: PLC0415

        from concurrency.triggers import factory  # noqa: PLC0415

        alias = router.db_for_write(model)
        connection = connections[alias]
        f = factory(connection)
        if not f.get_trigger(self):
            errors.append(
                Warning(
                    f"Missed trigger for field {self}",
                    hint=None,
                    obj=None,
                    id="concurrency.W001",
                )
            )
        return errors

    @property
    def trigger_name(self):
        from concurrency.triggers import get_trigger_name  # noqa: PLC0415

        return get_trigger_name(self)

    def _get_next_version(self, model_instance):
        # always returns the same value
        return int(getattr(model_instance, self.attname, 1))

    def pre_save(self, model_instance, add) -> int:
        # always returns the same value
        return 1

    @staticmethod
    def _increment_version_number(obj) -> None:
        old_value = get_revision_of_object(obj)
        setattr(obj, obj._concurrencymeta.field.attname, int(old_value) + 1)

    @staticmethod
    def _wrap_model_methods(model) -> None:
        super(TriggerVersionField, TriggerVersionField)._wrap_model_methods(model)
        old_save = model.save
        model.save = model._concurrencymeta.field._wrap_save(old_save)

    @staticmethod
    def _wrap_save(func):
        """Wrap ``Model.save`` to mirror the trigger's increment on the instance.

        An extra ``refetch=True`` keyword reloads the row after saving. The wrapper
        then returns the reloaded object and copies its version onto ``self``.
        """

        def inner(self, *args, **kwargs):
            reload = kwargs.pop("refetch", False)
            # Forward positional arguments too: older Django versions still accept them.
            ret = func(self, *args, **kwargs)
            TriggerVersionField._increment_version_number(self)
            if reload:
                ret = refetch(self)
                setattr(
                    self,
                    self._concurrencymeta.field.attname,
                    get_revision_of_object(ret),
                )
            return ret

        return update_wrapper(inner, func)


def filter_fields(instance, field) -> bool:
    """Return whether ``field`` should contribute to the change-detection hash of ``instance``."""
    if not field.concrete:  # reverse relation
        return False
    if field.is_relation and field.related_model is None:  # generic foreignkeys
        return False
    if field.many_to_many and instance.pk is None:  # noqa
        # can't load remote object yet
        return False
    return True


class ConditionalVersionField(AutoIncVersionField):
    """Auto-increment version field that only bumps the version when tracked values changed.

    A hash of the instance's field values is stored as ``_concurrency_initial`` on
    ``post_init`` and again on ``post_save``. ``_get_next_version`` increments only
    when the current hash differs from that snapshot. ``_concurrencymeta.check_fields``
    and ``_concurrencymeta.ignore_fields`` narrow the hashed fields.
    """

    def contribute_to_class(self, cls, *args, **kwargs) -> None:
        super().contribute_to_class(cls, *args, **kwargs)
        signals.post_init.connect(self._load_model, sender=cls, dispatch_uid=fqn(cls))
        signals.post_save.connect(self._save_model, sender=cls, dispatch_uid=fqn(cls))

    def _load_model(self, *args, **kwargs) -> None:
        # Snapshot the state the instance was constructed/loaded with.
        instance = kwargs["instance"]
        instance._concurrency_initial = self._get_hash(instance)

    def _save_model(self, *args, **kwargs) -> None:
        # Re-snapshot after saving so the next save compares against the saved state.
        instance = kwargs["instance"]
        instance._concurrency_initial = self._get_hash(instance)

    def _get_hash(self, instance):
        """Return a SHA-1 hex digest of the instance's tracked field values."""
        values = OrderedDict()
        opts = instance._meta
        check_fields = instance._concurrencymeta.check_fields
        ignore_fields = instance._concurrencymeta.ignore_fields
        filter_ = functools.partial(filter_fields, instance)

        if check_fields is None and ignore_fields is None:
            fields = sorted([f.name for f in filter(filter_, instance._meta.get_fields())])
        elif check_fields is None:
            fields = sorted(
                [f.name for f in filter(filter_, instance._meta.get_fields()) if f.name not in ignore_fields]
            )
        else:
            fields = instance._concurrencymeta.check_fields

        for field_name in fields:
            # do not use getattr here. we do not need extra sql to retrieve
            # FK. the raw value of the FK is enough
            field = opts.get_field(field_name)
            if isinstance(field, models.ManyToManyField):
                # Hash a fully evaluated, pk-ordered list. Hashing the bare QuerySet
                # hashes its repr, which Django truncates after 20 rows (changes beyond
                # that went undetected), and its row order is not guaranteed.
                values[field_name] = list(
                    getattr(instance, field_name).order_by("pk").values_list("pk", flat=True)
                )
            else:
                values[field_name] = field.value_from_object(instance)
        return hashlib.sha1(force_str(values).encode("utf-8")).hexdigest()  # noqa

    def _get_next_version(self, model_instance):
        # Unsaved instances always advance.
        if not model_instance.pk:
            return int(getattr(model_instance, self.attname) + 1)

        old = getattr(model_instance, "_concurrency_initial", None)
        new = self._get_hash(model_instance)
        if old != new:
            return int(getattr(model_instance, self.attname, 0) + 1)
        return int(getattr(model_instance, self.attname, 0))