Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions seal/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from enum import Enum


class Seal(Enum):
SINGLE = 'single'
MULTIPLE = 'multiple'
17 changes: 9 additions & 8 deletions seal/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ def _fetch_all(self):
class SealedPrefetchMixin(object):
def get_prefetch_queryset(self, instances, queryset=None):
prefetch = super(SealedPrefetchMixin, self).get_prefetch_queryset(instances, queryset)
if getattr(instances[0]._state, 'sealed', False) and isinstance(prefetch[0], SealableQuerySet):
prefetch = (prefetch[0].seal(),) + prefetch[1:]
seal = getattr(instances[0]._state, 'seal', None)
if seal is not None and isinstance(prefetch[0], SealableQuerySet):
prefetch = (prefetch[0].seal(seal=seal),) + prefetch[1:]
return prefetch


Expand All @@ -81,7 +82,7 @@ def seal_related_queryset(queryset, warning):
def create_sealable_related_manager(related_manager_cls, field_name):
class SealableRelatedManager(SealedPrefetchMixin, related_manager_cls):
def get_queryset(self):
if getattr(self.instance._state, 'sealed', False):
if hasattr(self.instance._state, 'seal'):
try:
prefetch_cache_name = self.prefetch_cache_name
except AttributeError:
Expand Down Expand Up @@ -110,7 +111,7 @@ def _check_parent_chain(self, instance, field_name=None):
def __get__(self, instance, cls=None):
if instance is None:
return self
if (getattr(instance._state, 'sealed', False) and
if (hasattr(instance._state, 'seal') and
instance.__dict__.get(self.field_name, self) is self and
self._check_parent_chain(instance, self.field_name) is None):
message = 'Attempt to fetch deferred field "%s" on sealed %s.' % (self.field_name, _bare_repr(instance))
Expand All @@ -120,7 +121,7 @@ def __get__(self, instance, cls=None):

class SealableForwardOneToOneDescriptor(SealedPrefetchMixin, ForwardOneToOneDescriptor):
def get_object(self, instance):
sealed = getattr(instance._state, 'sealed', False)
sealed = hasattr(instance._state, 'seal')
if sealed:
from .models import SealableModel
rel_model = self.field.remote_field.model
Expand Down Expand Up @@ -154,15 +155,15 @@ def get_object(self, instance):

class SealableReverseOneToOneDescriptor(SealedPrefetchMixin, ReverseOneToOneDescriptor):
def get_queryset(self, instance, **hints):
if getattr(instance._state, 'sealed', False):
if hasattr(instance._state, 'seal'):
message = 'Attempt to fetch related field "%s" on sealed %s.' % (self.related.name, _bare_repr(instance))
warnings.warn(message, category=UnsealedAttributeAccess, stacklevel=3)
return super(SealableReverseOneToOneDescriptor, self).get_queryset(instance=instance, **hints)


class SealableForwardManyToOneDescriptor(ForwardManyToOneDescriptor):
def get_object(self, instance):
if getattr(instance._state, 'sealed', False):
if getattr(instance._state, 'seal', False):
message = 'Attempt to fetch related field "%s" on sealed %s.' % (self.field.name, _bare_repr(instance))
warnings.warn(message, category=UnsealedAttributeAccess, stacklevel=3)
return super(SealableForwardManyToOneDescriptor, self).get_object(instance)
Expand All @@ -188,7 +189,7 @@ def __get__(self, instance, cls=None):
if instance is None:
return self

if getattr(instance._state, 'sealed', False) and not self.is_cached(instance):
if hasattr(instance._state, 'seal') and not self.is_cached(instance):
message = 'Attempt to fetch related field "%s" on sealed %s.' % (self.name, _bare_repr(instance))
warnings.warn(message, category=UnsealedAttributeAccess, stacklevel=2)

Expand Down
5 changes: 3 additions & 2 deletions seal/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.db.models.fields.related import lazy_related_operation
from django.dispatch import receiver

from .constants import Seal
from .descriptors import sealable_descriptor_classes
from .query import SealableQuerySet

Expand Down Expand Up @@ -43,12 +44,12 @@ class SealableModel(models.Model):
class Meta:
abstract = True

def seal(self):
def seal(self, seal=Seal.SINGLE):
"""
Seal the instance to turn deferred and related fields access that would
required fetching from the database into exceptions.
"""
self._state.sealed = True
self._state.seal = seal


def make_descriptor_sealable(model, attname):
Expand Down
44 changes: 39 additions & 5 deletions seal/query.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import unicode_literals

from functools import partial
from functools import partial, wraps
from operator import attrgetter

import django
from django.db import models

from .constants import Seal

if django.VERSION >= (2, 0):
cached_value_getter = attrgetter('get_cached_value')
else:
Expand Down Expand Up @@ -40,18 +42,24 @@ def walk_select_relateds(obj, getters):


class SealedModelIterable(models.query.ModelIterable):
def __init__(self, queryset, **kwargs):
self.seal = queryset._seal
super(SealedModelIterable, self).__init__(queryset, **kwargs)

def _sealed_iterator(self):
"""Iterate over objects and seal them."""
objs = super(SealedModelIterable, self).__iter__()
seal = self.seal
for obj in objs:
obj._state.sealed = True
obj._state.seal = seal
yield obj

def _sealed_related_iterator(self, related_walker):
"""Iterate over objects and seal them and their select related."""
seal = self.seal
for obj in self._sealed_iterator():
for related_obj in related_walker(obj):
related_obj._state.sealed = True
related_obj._state.seal = seal
yield obj

def __iter__(self):
Expand All @@ -68,8 +76,20 @@ def __iter__(self):
yield obj


def single_result_method(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
queryset = self
iterable_class = queryset._iterable_class
if issubclass(iterable_class, SealedModelIterable):
queryset = queryset._clone(_seal=Seal.SINGLE)
return func(queryset, *args, **kwargs)
return wrapper


class SealableQuerySet(models.QuerySet):
_base_manager_class = None
_seal = None

def as_manager(cls):
manager = cls._base_manager_class.from_queryset(cls)()
Expand All @@ -78,11 +98,25 @@ def as_manager(cls):
as_manager.queryset_only = True
as_manager = classmethod(as_manager)

def seal(self, iterable_class=SealedModelIterable):
def _clone(self, **kwargs):
seal = kwargs.pop('_seal', self._seal)
clone = super(SealableQuerySet, self)._clone(**kwargs)
clone._seal = seal
return clone

def seal(self, iterable_class=SealedModelIterable, seal=Seal.MULTIPLE):
if self._fields is not None:
raise TypeError('Cannot call seal() after .values() or .values_list()')
if not issubclass(iterable_class, SealedModelIterable):
raise TypeError('iterable_class %r is not a subclass of SealedModelIterable' % iterable_class)
clone = self._clone()
clone = self._clone(_seal=seal)
clone._iterable_class = iterable_class
return clone

get = single_result_method(models.QuerySet.get)
first = single_result_method(models.QuerySet.first)
last = single_result_method(models.QuerySet.last)
latest = single_result_method(models.QuerySet.latest)
earliest = single_result_method(models.QuerySet.earliest)
get_or_create = single_result_method(models.QuerySet.get_or_create)
update_or_create = single_result_method(models.QuerySet.update_or_create)
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
author_email='[email protected]',
install_requires=[
'Django>=1.11',
'enum34;python_version<"3.4"',
],
packages=find_packages(exclude=['tests', 'tests.*']),
license='MIT License',
Expand Down
13 changes: 7 additions & 6 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from django.test import SimpleTestCase, TestCase
from django.test.utils import isolate_apps

from seal.constants import Seal
from seal.descriptors import _SealedRelatedQuerySet
from seal.exceptions import UnsealedAttributeAccess
from seal.models import make_model_sealable
Expand Down Expand Up @@ -41,7 +42,7 @@ def setUp(self):

def test_state_sealed_assigned(self):
instance = SeaLion.objects.seal().get()
self.assertTrue(instance._state.sealed)
self.assertIs(instance._state.seal, Seal.SINGLE)

def test_sealed_deferred_field(self):
instance = SeaLion.objects.seal().defer('weight').get()
Expand Down Expand Up @@ -399,7 +400,7 @@ class Meta:
db_table = Location._meta.db_table
queryset = SealableQuerySet(model=NonSealableLocation)
instance = queryset.seal().get()
self.assertTrue(instance._state.sealed)
self.assertTrue(instance._state.seal)

@isolate_apps('tests')
def test_sealed_select_related_non_sealable_model(self):
Expand All @@ -414,8 +415,8 @@ class Meta:
db_table = SeaLion._meta.db_table
queryset = SealableQuerySet(model=NonSealableSeaLion)
instance = queryset.select_related('location').seal().get()
self.assertTrue(instance._state.sealed)
self.assertTrue(instance.location._state.sealed)
self.assertIs(instance._state.seal, Seal.SINGLE)
self.assertIs(instance.location._state.seal, Seal.SINGLE)

@isolate_apps('tests')
def test_sealed_prefetch_related_non_sealable_model(self):
Expand All @@ -440,6 +441,6 @@ class Meta:
make_model_sealable(NonSealableLocation)
queryset = SealableQuerySet(model=NonSealableLocation)
instance = queryset.prefetch_related('climates').seal().get()
self.assertTrue(instance._state.sealed)
self.assertIs(instance._state.seal, Seal.SINGLE)
with self.assertNumQueries(0):
self.assertTrue(instance.climates.all()[0]._state.sealed)
self.assertIs(instance.climates.all()[0]._state.seal, Seal.SINGLE)