# Source code for superdjango.db.polymorphic.managers

"""
This code was copied and adapted from `django-model-utils`_, and originally contributed to that package by
`Jeff Elmore`_.

.. _django-model-utils: https://django-model-utils.readthedocs.io/en/latest/managers.html#inheritancemanager
.. _Jeff Elmore: http://jeffelmore.org/2010/11/11/automatic-downcasting-of-inherited-models-in-django/

The original license is preserved below:

.. code-block:: text

    Copyright (c) 2009-2019, Carl Meyer and contributors
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions are
    met:

        * Redistributions of source code must retain the above copyright
          notice, this list of conditions and the following disclaimer.
        * Redistributions in binary form must reproduce the above
          copyright notice, this list of conditions and the following
          disclaimer in the documentation and/or other materials provided
          with the distribution.
        * Neither the name of the author nor the names of other
          contributors may be used to endorse or promote products derived
          from this software without specific prior written permission.

    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
    OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
    SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
    LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
    DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
    THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
    (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
# Imports

from django.db import models
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields.related import OneToOneField, OneToOneRel
from django.db.models.query import QuerySet
from django.db.models.query import ModelIterable
from django.core.exceptions import ObjectDoesNotExist

# Exports

# Public API of this module; star-imports expose only the manager class.
__all__ = ("PolymorphicManager",)

# Classes


class PolymorphicIterable(ModelIterable):
    """Iterable that yields a polymorphic model instance for each row.

    When the queryset carries a non-empty ``subclasses`` attribute (set by ``select_subclasses()``), each base instance
    is replaced by the deepest related subclass instance that can be resolved, and the queryset's recorded annotations
    and ``extra()`` select values are copied onto it. Otherwise the base instances are yielded unchanged.
    """

    def __iter__(self):
        queryset = self.queryset

        # Delegate to the parent implementation instead of building a fresh ModelIterable(queryset): a fresh instance
        # would drop the ``chunked_fetch``/``chunk_size`` settings that QuerySet.iterator() passes to this iterable.
        base_objects = super().__iter__()

        if not getattr(queryset, 'subclasses', False):
            yield from base_objects
            return

        extras = tuple(queryset.query.extra.keys())
        annotated = getattr(queryset, '_annotated', None) or ()

        # Sort the subclass names longest first, so with 'a' and 'a__b' it goes as deep as possible.
        subclasses = sorted(queryset.subclasses, key=len, reverse=True)
        for obj in base_objects:
            sub_obj = None
            for path in subclasses:
                sub_obj = queryset._get_sub_obj_recurse(obj, path)
                if sub_obj is not None:
                    break

            if sub_obj is None:
                sub_obj = obj

            # Annotation and extra() values are set on the base instance only; copy them to the resolved subclass.
            for name in annotated:
                setattr(sub_obj, name, getattr(obj, name))

            for name in extras:
                setattr(sub_obj, name, getattr(obj, name))

            yield sub_obj


class PolymorphicQuerySetMixin:
    """A queryset mixin which adds support for acquiring the subclassed model instances of a polymorphic model.

    Two attributes are carried across clones: ``subclasses`` (the ``select_related()`` paths chosen by
    ``select_subclasses()``) and ``_annotated`` (the annotation names recorded by ``annotate()``).
    """

    def __init__(self, *args, **kwargs):
        """Initialize the queryset to also assign ``_iterable_class`` as
        :py:class:`superdjango.db.polymorphic.managers.PolymorphicIterable`.
        """
        super().__init__(*args, **kwargs)
        self._iterable_class = PolymorphicIterable

    def select_subclasses(self, *subclasses):
        """Get the queryset for the given subclasses.

        :param subclasses: Model classes or ``select_related()`` path strings of the subclasses to resolve. When none
                           are given, every discovered subclass is selected. The queryset's own model is accepted and
                           skipped.

        :raises ValueError: If a model class is not a subclass of the queryset's model, or a subclass is not among the
                            discovered subclasses.

        """
        levels = None

        # noinspection PyUnresolvedReferences
        calculated_subclasses = self._get_subclasses_recurse(self.model, levels=levels)

        # If no sub classes were provided, select all classes.
        if not subclasses:
            subclasses = calculated_subclasses
        else:
            verified_subclasses = []
            for subclass in subclasses:
                # Special case for passing in the same model as the queryset is bound against. Rather than raise an
                # error later, we know we can allow this through.
                # noinspection PyUnresolvedReferences
                if subclass is self.model:
                    continue

                if not isinstance(subclass, (str,)):
                    subclass = self._get_ancestors_path(subclass, levels=levels)

                if subclass in calculated_subclasses:
                    verified_subclasses.append(subclass)
                else:
                    error = '{!r} is not in the discovered subclasses, tried: {}'
                    raise ValueError(error.format(subclass, ', '.join(calculated_subclasses)))

            subclasses = verified_subclasses

        if not subclasses:
            # Nothing to join. Calling select_related() without arguments would follow every non-null foreign key
            # (joins the caller never asked for), so just copy the queryset instead.
            # noinspection PyUnresolvedReferences
            new_qs = self.all()
        else:
            # Workaround for: https://code.djangoproject.com/ticket/16855
            # noinspection PyUnresolvedReferences
            previous_select_related = self.query.select_related
            # noinspection PyUnresolvedReferences
            new_qs = self.select_related(*subclasses)
            previous_is_dict = isinstance(previous_select_related, dict)
            new_is_dict = isinstance(new_qs.query.select_related, dict)
            if previous_is_dict and new_is_dict:
                new_qs.query.select_related.update(previous_select_related)

        new_qs.subclasses = subclasses

        return new_qs

    def _chain(self, **kwargs):
        """Handle chaining for polymorphic subclasses.

        The polymorphic attributes are assigned to the chained queryset after the fact rather than being forwarded as
        keyword arguments, because newer Django versions define ``QuerySet._chain()`` without keyword arguments.
        """
        # noinspection PyProtectedMember,PyUnresolvedReferences
        chained = super()._chain(**kwargs)
        for name in ('subclasses', '_annotated'):
            if hasattr(self, name):
                setattr(chained, name, getattr(self, name))

        return chained

    # noinspection PyUnusedLocal
    def _clone(self, klass=None, setup=False, **kwargs):
        """Handle cloning for polymorphic subclasses.

        ``klass``, ``setup`` and ``kwargs`` are accepted for signature compatibility only and are ignored.
        """
        # noinspection PyProtectedMember,PyUnresolvedReferences
        qs = super()._clone()
        for name in ('subclasses', '_annotated'):
            if hasattr(self, name):
                setattr(qs, name, getattr(self, name))

        return qs

    def annotate(self, *args, **kwargs):
        """Handle annotation for polymorphic instances.

        The annotation names are recorded in ``_annotated`` so the iterable can copy their values from the base
        instance onto the resolved subclass instance. Names from earlier ``annotate()`` calls are kept, so chained
        annotations are all copied.
        """
        # noinspection PyUnresolvedReferences
        qset = super().annotate(*args, **kwargs)
        new_names = [a.default_alias for a in args] + list(kwargs.keys())
        previous_names = list(getattr(self, '_annotated', None) or [])
        # dict.fromkeys() drops duplicates (re-annotating the same alias) while preserving order.
        qset._annotated = list(dict.fromkeys(previous_names + new_names))

        return qset

    def _get_subclasses_recurse(self, model, levels=None):
        """Given a model class, find the reverse one-to-one parent-link relations to its subclasses, exploring children
        recursively, and return a list of strings representing the relations for ``select_related``.

        For each child, its descendants' paths (``child__grandchild``) are listed before the child's own path. When
        ``levels`` is given, it limits how many levels deep the recursion goes.
        """
        related_objects = [
            f for f in model._meta.get_fields()
            if isinstance(f, OneToOneRel)]

        # Keep only parent links from concrete subclasses of ``model`` (multi-table inheritance children).
        # noinspection PyUnresolvedReferences
        rels = [
            rel for rel in related_objects
            if isinstance(rel.field, OneToOneField)
               and issubclass(rel.field.model, model)
               and model is not rel.field.model
               and rel.parent_link
        ]

        subclasses = []
        if levels:
            levels -= 1

        for rel in rels:
            if levels or levels is None:
                for subclass in self._get_subclasses_recurse(rel.field.model, levels=levels):
                    subclasses.append(rel.get_accessor_name() + LOOKUP_SEP + subclass)

            subclasses.append(rel.get_accessor_name())

        return subclasses

    def _get_ancestors_path(self, model, levels=None):
        """Serves as an opposite to ``_get_subclasses_recurse()``, instead walking from the model class up the model's
        ancestry and constructing the desired select_related string backwards.

        :raises ValueError: If ``model`` is not a subclass of the queryset's model.
        """
        # noinspection PyUnresolvedReferences
        if not issubclass(model, self.model):
            # noinspection PyUnresolvedReferences
            raise ValueError("{!r} is not a subclass of {!r}".format(model, self.model))

        ancestry = []

        # Should be a OneToOneField or None.
        # noinspection PyProtectedMember,PyUnresolvedReferences
        parent_link = model._meta.get_ancestor_link(self.model)
        if levels:
            levels -= 1

        while parent_link is not None:
            related = parent_link.remote_field
            # Prepend, because the walk goes from the child up towards the queryset's model.
            ancestry.insert(0, related.get_accessor_name())
            if levels or levels is None:
                parent_model = related.model
                # noinspection PyUnresolvedReferences,PyProtectedMember
                parent_link = parent_model._meta.get_ancestor_link(self.model)
            else:
                parent_link = None

        return LOOKUP_SEP.join(ancestry)

    def _get_sub_obj_recurse(self, obj, s):
        """Follow the ``LOOKUP_SEP``-separated accessor path ``s`` starting at ``obj``.

        :returns: The object at the end of the path, or ``None`` if any step raises ``ObjectDoesNotExist``.
        """
        rel, _, s = s.partition(LOOKUP_SEP)

        try:
            node = getattr(obj, rel)
        except ObjectDoesNotExist:
            return None

        if s:
            child = self._get_sub_obj_recurse(node, s)
            return child
        else:
            return node

    def get_subclass(self, *args, **kwargs):
        """Get the subclass instance for the given parameters (the same lookups accepted by ``get()``)."""
        return self.select_subclasses().get(*args, **kwargs)


class PolymorphicQuerySet(PolymorphicQuerySetMixin, QuerySet):
    """A polymorphic queryset that implements support for fetching based on model class."""

    def instance_of(self, *models):
        """Fetch only objects that are instances of the provided model(s).

        :param models: Direct child model classes of the queryset's model.

        :returns: A queryset that resolves subclass instances (via ``select_subclasses(*models)``) and is restricted
                  to rows whose parent-link column in at least one of the given models' tables is not NULL. When no
                  models are given, an empty queryset is returned.

        """
        from django.db import connections

        # "An instance of none of these models" matches nothing; building the WHERE clause below from an empty list
        # would otherwise produce invalid SQL.
        if not models:
            return self.none()

        # Quote identifiers with the backend's own rules; a literal double quote is not identifier quoting on every
        # backend (e.g. MySQL without ANSI_QUOTES).
        quote_name = connections[self.db].ops.quote_name

        # select_subclasses(*models) is always applied below because it joins the child tables that the WHERE clause
        # refers to.
        # Due to https://code.djangoproject.com/ticket/16572, we can't really do this for anything other than children,
        # i.e. no grandchildren, great grandchildren, etc.
        where_queries = []
        for model in models:
            conditions = [
                '{}.{} IS NOT NULL'.format(quote_name(model._meta.db_table), quote_name(field.attname))
                for field in model._meta.parents.values()
            ]
            where_queries.append('(' + ' AND '.join(conditions) + ')')

        return self.select_subclasses(*models).extra(where=[' OR '.join(where_queries)])


class PolymorphicManagerMixin:
    """A mixin for implementing the :py:class:`superdjango.db.polymorphic.managers.PolymorphicQuerySet`."""
    _queryset_class = PolymorphicQuerySet

    def get_queryset(self):
        """Overridden to use :py:class:`superdjango.db.polymorphic.managers.PolymorphicQuerySet`.

        The manager's database alias and routing hints are forwarded, as Django's own ``Manager.get_queryset()``
        does, so managers obtained via ``db_manager()`` query the intended database.
        """
        # noinspection PyUnresolvedReferences
        return self._queryset_class(model=self.model, using=self._db, hints=self._hints)

    def select_subclasses(self, *subclasses):
        """Get the queryset while resolving subclass instances."""
        return self.get_queryset().select_subclasses(*subclasses)

    def get_subclass(self, *args, **kwargs):
        """Get a specific instance of a child model."""
        return self.get_queryset().get_subclass(*args, **kwargs)

    def instance_of(self, *models):
        """Get a queryset restricted to instances of the given model(s)."""
        return self.get_queryset().instance_of(*models)


class PolymorphicManager(PolymorphicManagerMixin, models.Manager):
    """A polymorphic manager.

    Combines ``PolymorphicManagerMixin`` with Django's ``models.Manager`` so the manager provides
    ``select_subclasses()``, ``get_subclass()`` and ``instance_of()``.
    """