"""
This code was copied and adapted from `django-model-utils`_, and originally contributed to that package by
`Jeff Elmore`_.
.. _django-model-utils: https://django-model-utils.readthedocs.io/en/latest/managers.html#inheritancemanager
.. _Jeff Elmore: http://jeffelmore.org/2010/11/11/automatic-downcasting-of-inherited-models-in-django/
The original license is preserved below:
.. code-block:: text
Copyright (c) 2009-2019, Carl Meyer and contributors
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.
* Neither the name of the author nor the names of other
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
# Imports
from django.core.exceptions import ObjectDoesNotExist
from django.db import connections, models
from django.db.models.constants import LOOKUP_SEP
from django.db.models.fields.related import OneToOneField, OneToOneRel
from django.db.models.query import ModelIterable, QuerySet
# Exports
# Names exported by ``from <module> import *``.
__all__ = ("PolymorphicManager",)
# Classes
class PolymorphicIterable(ModelIterable):
    """Iterable that yields a polymorphic model instance for each row.

    When the queryset carries a non-empty ``subclasses`` attribute (set by ``select_subclasses()``), each row's base
    instance is replaced by the deepest subclass instance reachable through those ``select_related`` paths, and the
    queryset's annotations and ``extra()`` selections are copied onto it. Without ``subclasses`` the rows produced by
    a plain ``ModelIterable`` are yielded unchanged.
    """

    def __iter__(self):
        queryset = self.queryset
        # Forward the chunking options so that ``QuerySet.iterator(chunk_size=...)`` and chunked fetching are honoured
        # instead of silently falling back to ModelIterable's defaults.
        rows = ModelIterable(queryset, chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size)

        subclasses = getattr(queryset, 'subclasses', None)
        if not subclasses:
            yield from rows
            return

        extras = tuple(queryset.query.extra.keys())
        annotated = getattr(queryset, '_annotated', None) or ()
        # Sort the subclass names longest first, so with 'a' and 'a__b' it goes as deep as possible.
        paths = sorted(subclasses, key=len, reverse=True)

        for obj in rows:
            sub_obj = None
            for path in paths:
                sub_obj = queryset._get_sub_obj_recurse(obj, path)
                # Compare with None rather than truthiness: a model may define __bool__/__len__ and be falsy.
                if sub_obj is not None:
                    break
            if sub_obj is None:
                sub_obj = obj

            # The subclass instance is reached through the related-object cache, so copy the annotation and extra()
            # values that were loaded onto the base row instance.
            for name in annotated:
                setattr(sub_obj, name, getattr(obj, name))
            for name in extras:
                setattr(sub_obj, name, getattr(obj, name))

            yield sub_obj
class PolymorphicQuerySetMixin:
    """A queryset mixin which adds support for acquiring the subclassed model instances of a polymorphic model.

    Two attributes are kept on the queryset and carried across cloning/chaining:

    - ``subclasses``: the ``select_related`` lookup paths chosen by ``select_subclasses()``.
    - ``_annotated``: the names of annotations added through ``annotate()``.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the queryset to also assign ``_iterable_class`` as
        :py:class:`superdjango.db.polymorphic.managers.PolymorphicIterable`.
        """
        super().__init__(*args, **kwargs)
        self._iterable_class = PolymorphicIterable

    def select_subclasses(self, *subclasses):
        """Get the queryset for the given subclasses (model classes).

        :param subclasses: Subclass model classes and/or ``select_related`` lookup paths such as
                           ``'child__grandchild'``. When none are given, every discovered subclass is selected. The
                           queryset's own model is accepted and skipped.
        :returns: A new queryset with the subclass relations selected and its ``subclasses`` attribute set to the
                  selected lookup paths.
        :raises ValueError: If a given subclass is not among the discovered subclasses, or a model class is not a
                            subclass of the queryset's model.
        """
        # No depth limit on subclass discovery.
        levels = None
        # noinspection PyUnresolvedReferences
        calculated_subclasses = self._get_subclasses_recurse(self.model, levels=levels)

        if not subclasses:
            # If no sub classes were provided, select all classes.
            selected = calculated_subclasses
        else:
            selected = []
            for subclass in subclasses:
                # Special case for passing in the same model as the queryset is bound against. Rather than raise an
                # error later, we know we can allow this through.
                # noinspection PyUnresolvedReferences
                if subclass is self.model:
                    continue
                if not isinstance(subclass, str):
                    subclass = self._get_ancestors_path(subclass, levels=levels)
                if subclass not in calculated_subclasses:
                    error = '{!r} is not in the discovered subclasses, tried: {}'
                    raise ValueError(error.format(subclass, ', '.join(calculated_subclasses)))
                selected.append(subclass)

        if selected:
            # Workaround for: https://code.djangoproject.com/ticket/16855
            # noinspection PyUnresolvedReferences
            previous_select_related = self.query.select_related
            # noinspection PyUnresolvedReferences
            new_qs = self.select_related(*selected)
            if isinstance(previous_select_related, dict) and isinstance(new_qs.query.select_related, dict):
                new_qs.query.select_related.update(previous_select_related)
        else:
            # select_related() with no arguments would follow every non-null foreign key, adding joins unrelated to
            # subclasses; a plain clone is enough when there is nothing to select.
            # noinspection PyUnresolvedReferences
            new_qs = self.all()

        new_qs.subclasses = selected
        return new_qs

    def _chain(self, **kwargs):
        """Handle chaining for polymorphic subclasses.

        The polymorphic attributes are set on the chained queryset afterwards rather than passed to
        ``QuerySet._chain()`` as keyword arguments, because newer Django versions define ``_chain()`` without
        ``**kwargs``.
        """
        # noinspection PyProtectedMember,PyUnresolvedReferences
        chained = super()._chain(**kwargs)
        for name in ('subclasses', '_annotated'):
            if hasattr(self, name):
                setattr(chained, name, getattr(self, name))
        return chained

    # noinspection PyUnusedLocal
    def _clone(self, klass=None, setup=False, **kwargs):
        """Handle cloning for polymorphic subclasses.

        The arguments are accepted for compatibility with older call signatures and are ignored.
        """
        # noinspection PyProtectedMember,PyUnresolvedReferences
        qs = super()._clone()
        for name in ('subclasses', '_annotated'):
            if hasattr(self, name):
                setattr(qs, name, getattr(self, name))
        return qs

    def annotate(self, *args, **kwargs):
        """Handle annotation for polymorphic instances.

        Records the annotation names so that the polymorphic iterable can copy their values onto subclass instances.

        :returns: The annotated queryset, whose ``_annotated`` lists the names from this and every earlier
                  ``annotate()`` call in the chain.
        """
        # noinspection PyUnresolvedReferences
        qset = super().annotate(*args, **kwargs)
        # Accumulate rather than replace, so annotations from earlier chained calls are still copied. A new list is
        # built so that the queryset this one was derived from is not mutated.
        annotated = list(getattr(self, '_annotated', None) or [])
        for name in [a.default_alias for a in args] + list(kwargs.keys()):
            if name not in annotated:
                annotated.append(name)
        qset._annotated = annotated
        return qset

    def _get_subclasses_recurse(self, model, levels=None):
        """Given a model class, find all related objects, exploring children recursively, returning a list of strings
        representing the relations for ``select_related``.

        Only reverse one-to-one relations that are parent links from a strict subclass of ``model`` are followed. For
        each child, its deeper paths are listed before the child's own accessor name.

        :param model: The model class to start from.
        :param levels: Optional depth limit (``1`` lists direct children only); ``None`` means unlimited.
        """
        related_objects = [
            f for f in model._meta.get_fields()
            if isinstance(f, OneToOneRel)]
        # noinspection PyUnresolvedReferences
        rels = [
            rel for rel in related_objects
            if isinstance(rel.field, OneToOneField)
            and issubclass(rel.field.model, model)
            and model is not rel.field.model
            and rel.parent_link
        ]
        subclasses = []
        if levels:
            levels -= 1
        for rel in rels:
            if levels or levels is None:
                for subclass in self._get_subclasses_recurse(rel.field.model, levels=levels):
                    subclasses.append(rel.get_accessor_name() + LOOKUP_SEP + subclass)
            subclasses.append(rel.get_accessor_name())
        return subclasses

    def _get_ancestors_path(self, model, levels=None):
        """Serves as an opposite to ``_get_subclasses_recurse()``, instead walking from the model class up the model's
        ancestry and constructing the desired select_related string backwards.

        :param model: A subclass of the queryset's model.
        :param levels: Optional limit on the walk; ``None`` walks all the way up to the queryset's model.
        :returns: The ``select_related`` lookup path from the queryset's model down to ``model``.
        :raises ValueError: If ``model`` is not a subclass of the queryset's model.
        """
        # noinspection PyUnresolvedReferences
        if not issubclass(model, self.model):
            # noinspection PyUnresolvedReferences
            raise ValueError("{!r} is not a subclass of {!r}".format(model, self.model))
        ancestry = []
        # Should be a OneToOneField or None.
        # noinspection PyProtectedMember,PyUnresolvedReferences
        parent_link = model._meta.get_ancestor_link(self.model)
        # NOTE(review): levels is decremented only once, before the walk, so any value above 1 does not limit the
        # walk; select_subclasses() passes None.
        if levels:
            levels -= 1
        while parent_link is not None:
            related = parent_link.remote_field
            ancestry.insert(0, related.get_accessor_name())
            if levels or levels is None:
                parent_model = related.model
                # noinspection PyUnresolvedReferences,PyProtectedMember
                parent_link = parent_model._meta.get_ancestor_link(self.model)
            else:
                parent_link = None
        return LOOKUP_SEP.join(ancestry)

    def _get_sub_obj_recurse(self, obj, s):
        """Follow the ``select_related`` lookup path ``s`` from ``obj`` and return the object at its end.

        :returns: The related (subclass) instance, or ``None`` when any step raises ``ObjectDoesNotExist``
                  (presumably because the row is not an instance of that subclass).
        """
        rel, _, s = s.partition(LOOKUP_SEP)
        try:
            node = getattr(obj, rel)
        except ObjectDoesNotExist:
            return None
        if s:
            return self._get_sub_obj_recurse(node, s)
        return node

    def get_subclass(self, *args, **kwargs):
        """Get the subclass instance for the given parameters."""
        return self.select_subclasses().get(*args, **kwargs)
class PolymorphicQuerySet(PolymorphicQuerySetMixin, QuerySet):
    """A polymorphic queryset that implements support for fetching based on model class."""

    def instance_of(self, *models):
        """Fetch only objects that are instances of the provided model(s).

        :param models: Model classes to match; a row matches when it is an instance of any of them.
        :returns: A queryset with the given subclasses selected (see ``select_subclasses()``) and restricted to rows
                  whose parent-link columns are present. With no models the result is empty. If any model has no
                  parent link (e.g. the root model itself), every row matches and no restriction is added.
        :raises ValueError: Propagated from ``select_subclasses()`` for models that are not valid subclasses.
        """
        # If we aren't already selecting the subclasess, we need to in order to get this to work.
        # How can we tell if we are not selecting subclasses?
        # Is it safe to just apply .select_subclasses(*models)?
        # Due to https://code.djangoproject.com/ticket/16572, we can't really do this for anything other than children,
        # i.e. no grandchildren, great grandchildren, etc.
        if not models:
            # Mirrors isinstance(obj, ()): nothing is an instance of an empty set of classes. Without this guard the
            # generated WHERE clause would be empty, which is invalid SQL.
            return self.none()

        queryset = self.select_subclasses(*models)
        # Quote identifiers with the active backend's rules instead of hard-coded ANSI double quotes, which e.g.
        # MySQL rejects unless ANSI_QUOTES is enabled.
        quote_name = connections[queryset.db].ops.quote_name

        where_queries = []
        for model in models:
            # Proxy models have no parent-link field of their own (the link is None); use the concrete model.
            concrete = model._meta.concrete_model
            conditions = [
                '{}.{} IS NOT NULL'.format(
                    quote_name(concrete._meta.db_table),
                    # ``column`` is the database column; ``attname`` differs from it when db_column is set.
                    quote_name(field.column),
                )
                for field in concrete._meta.parents.values()
                if field is not None
            ]
            if not conditions:
                # A model without parent links matches every row, so the OR-ed filter would be always true. Emitting
                # it would also produce "()", which is invalid SQL.
                return queryset
            where_queries.append('(' + ' AND '.join(conditions) + ')')

        return queryset.extra(where=[' OR '.join(where_queries)])
class PolymorphicManagerMixin:
    """A mixin for implementing the :py:class:`superdjango.db.polymorphic.managers.PolymorphicQuerySet`."""

    _queryset_class = PolymorphicQuerySet

    def get_queryset(self):
        """Overridden to use :py:class:`superdjango.db.polymorphic.managers.PolymorphicQuerySet`.

        Like Django's own ``Manager.get_queryset()``, the manager's database alias and routing hints are passed on,
        so managers obtained through ``db_manager()`` query the intended database.
        """
        # getattr() keeps the mixin usable with classes that do not define these BaseManager attributes.
        # noinspection PyUnresolvedReferences
        return self._queryset_class(
            self.model,
            using=getattr(self, '_db', None),
            hints=getattr(self, '_hints', None),
        )

    def select_subclasses(self, *subclasses):
        """Get the queryset while resolving subclass instances."""
        return self.get_queryset().select_subclasses(*subclasses)

    def get_subclass(self, *args, **kwargs):
        """Get a specific instance of a child model."""
        return self.get_queryset().get_subclass(*args, **kwargs)

    def instance_of(self, *models):
        """Get a queryset of the objects that are instances of the given model(s)."""
        return self.get_queryset().instance_of(*models)
class PolymorphicManager(PolymorphicManagerMixin, models.Manager):
    """A polymorphic manager.

    Combines :py:class:`PolymorphicManagerMixin` with Django's ``models.Manager`` so that ``select_subclasses()``,
    ``get_subclass()`` and ``instance_of()`` are available directly on the manager.
    """