|  | 
| 1 | 1 | from functools import reduce, wraps | 
| 2 | 2 | from operator import add as add_operator | 
|  | 3 | +from collections.abc import Mapping | 
| 3 | 4 | 
 | 
| 4 | 5 | from django.core.exceptions import EmptyResultSet, FullResultSet | 
| 5 |  | -from django.db import DatabaseError, IntegrityError, NotSupportedError | 
|  | 6 | +from django.db import DatabaseError, IntegrityError, NotSupportedError, connections | 
|  | 7 | +from django.db.models import QuerySet | 
| 6 | 8 | from django.db.models.expressions import Case, Col, When | 
| 7 | 9 | from django.db.models.functions import Mod | 
| 8 | 10 | from django.db.models.lookups import Exact | 
| 9 |  | -from django.db.models.sql.constants import INNER | 
|  | 11 | +from django.db.models.query import BaseIterable | 
|  | 12 | +from django.db.models.sql.constants import INNER, GET_ITERATOR_CHUNK_SIZE | 
| 10 | 13 | from django.db.models.sql.datastructures import Join | 
| 11 | 14 | from django.db.models.sql.where import AND, OR, XOR, ExtraWhere, NothingNode, WhereNode | 
|  | 15 | +from django.utils.functional import cached_property | 
| 12 | 16 | from pymongo.errors import BulkWriteError, DuplicateKeyError, PyMongoError | 
| 13 | 17 | 
 | 
| 14 | 18 | 
 | 
| @@ -307,3 +311,283 @@ def register_nodes(): | 
| 307 | 311 |     Join.as_mql = join | 
| 308 | 312 |     NothingNode.as_mql = NothingNode.as_sql | 
| 309 | 313 |     WhereNode.as_mql = where_node | 
|  | 314 | + | 
|  | 315 | + | 
|  | 316 | +class MongoQuerySet(QuerySet): | 
|  | 317 | +    def raw_mql(self, raw_query, params=(), translations=None, using=None): | 
|  | 318 | +        if using is None: | 
|  | 319 | +            using = self.db | 
|  | 320 | +        qs = RawQuerySet( | 
|  | 321 | +            raw_query, | 
|  | 322 | +            model=self.model, | 
|  | 323 | +            params=params, | 
|  | 324 | +            translations=translations, | 
|  | 325 | +            using=using, | 
|  | 326 | +        ) | 
|  | 327 | +        return qs | 
|  | 328 | + | 
|  | 329 | + | 
|  | 330 | +class RawQuerySet: | 
|  | 331 | +    """ | 
|  | 332 | +    Provide an iterator which converts the results of raw SQL queries into | 
|  | 333 | +    annotated model instances. | 
|  | 334 | +    """ | 
|  | 335 | + | 
|  | 336 | +    def __init__( | 
|  | 337 | +        self, | 
|  | 338 | +        raw_query, | 
|  | 339 | +        model=None, | 
|  | 340 | +        query=None, | 
|  | 341 | +        params=(), | 
|  | 342 | +        translations=None, | 
|  | 343 | +        using=None, | 
|  | 344 | +        hints=None, | 
|  | 345 | +    ): | 
|  | 346 | +        self.raw_query = raw_query | 
|  | 347 | +        self.model = model | 
|  | 348 | +        self._db = using | 
|  | 349 | +        self._hints = hints or {} | 
|  | 350 | +        self.query = query or RawQuery(sql=raw_query, using=self.db, params=params) | 
|  | 351 | +        self.params = params | 
|  | 352 | +        self.translations = translations or {} | 
|  | 353 | +        self._result_cache = None | 
|  | 354 | +        self._prefetch_related_lookups = () | 
|  | 355 | +        self._prefetch_done = False | 
|  | 356 | + | 
|  | 357 | +    def resolve_model_init_order(self): | 
|  | 358 | +        """Resolve the init field names and value positions.""" | 
|  | 359 | +        converter = connections[self.db].introspection.identifier_converter | 
|  | 360 | +        model_init_fields = [ | 
|  | 361 | +            f for f in self.model._meta.fields if converter(f.column) in self.columns | 
|  | 362 | +        ] | 
|  | 363 | +        annotation_fields = [ | 
|  | 364 | +            (column, pos) | 
|  | 365 | +            for pos, column in enumerate(self.columns) | 
|  | 366 | +            if column not in self.model_fields | 
|  | 367 | +        ] | 
|  | 368 | +        model_init_order = [self.columns.index(converter(f.column)) for f in model_init_fields] | 
|  | 369 | +        model_init_names = [f.attname for f in model_init_fields] | 
|  | 370 | +        return model_init_names, model_init_order, annotation_fields | 
|  | 371 | + | 
|  | 372 | +    def prefetch_related(self, *lookups): | 
|  | 373 | +        """Same as QuerySet.prefetch_related()""" | 
|  | 374 | +        clone = self._clone() | 
|  | 375 | +        if lookups == (None,): | 
|  | 376 | +            clone._prefetch_related_lookups = () | 
|  | 377 | +        else: | 
|  | 378 | +            clone._prefetch_related_lookups = clone._prefetch_related_lookups + lookups | 
|  | 379 | +        return clone | 
|  | 380 | + | 
|  | 381 | +    def _prefetch_related_objects(self): | 
|  | 382 | +        prefetch_related_objects(self._result_cache, *self._prefetch_related_lookups) | 
|  | 383 | +        self._prefetch_done = True | 
|  | 384 | + | 
|  | 385 | +    def _clone(self): | 
|  | 386 | +        """Same as QuerySet._clone()""" | 
|  | 387 | +        c = self.__class__( | 
|  | 388 | +            self.raw_query, | 
|  | 389 | +            model=self.model, | 
|  | 390 | +            query=self.query, | 
|  | 391 | +            params=self.params, | 
|  | 392 | +            translations=self.translations, | 
|  | 393 | +            using=self._db, | 
|  | 394 | +            hints=self._hints, | 
|  | 395 | +        ) | 
|  | 396 | +        c._prefetch_related_lookups = self._prefetch_related_lookups[:] | 
|  | 397 | +        return c | 
|  | 398 | + | 
|  | 399 | +    def _fetch_all(self): | 
|  | 400 | +        if self._result_cache is None: | 
|  | 401 | +            self._result_cache = list(self.iterator()) | 
|  | 402 | +        if self._prefetch_related_lookups and not self._prefetch_done: | 
|  | 403 | +            self._prefetch_related_objects() | 
|  | 404 | + | 
|  | 405 | +    def __len__(self): | 
|  | 406 | +        self._fetch_all() | 
|  | 407 | +        return len(self._result_cache) | 
|  | 408 | + | 
|  | 409 | +    def __bool__(self): | 
|  | 410 | +        self._fetch_all() | 
|  | 411 | +        return bool(self._result_cache) | 
|  | 412 | + | 
|  | 413 | +    def __iter__(self): | 
|  | 414 | +        self._fetch_all() | 
|  | 415 | +        return iter(self._result_cache) | 
|  | 416 | + | 
|  | 417 | +    def __aiter__(self): | 
|  | 418 | +        # Remember, __aiter__ itself is synchronous, it's the thing it returns | 
|  | 419 | +        # that is async! | 
|  | 420 | +        async def generator(): | 
|  | 421 | +            await sync_to_async(self._fetch_all)() | 
|  | 422 | +            for item in self._result_cache: | 
|  | 423 | +                yield item | 
|  | 424 | + | 
|  | 425 | +        return generator() | 
|  | 426 | + | 
|  | 427 | +    def iterator(self): | 
|  | 428 | +        yield from RawModelIterable(self) | 
|  | 429 | + | 
|  | 430 | +    def __repr__(self): | 
|  | 431 | +        return "<%s: %s>" % (self.__class__.__name__, self.query) | 
|  | 432 | + | 
|  | 433 | +    def __getitem__(self, k): | 
|  | 434 | +        return list(self)[k] | 
|  | 435 | + | 
|  | 436 | +    @property | 
|  | 437 | +    def db(self): | 
|  | 438 | +        """Return the database used if this query is executed now.""" | 
|  | 439 | +        return self._db or router.db_for_read(self.model, **self._hints) | 
|  | 440 | + | 
|  | 441 | +    def using(self, alias): | 
|  | 442 | +        """Select the database this RawQuerySet should execute against.""" | 
|  | 443 | +        return RawQuerySet( | 
|  | 444 | +            self.raw_query, | 
|  | 445 | +            model=self.model, | 
|  | 446 | +            query=self.query.chain(using=alias), | 
|  | 447 | +            params=self.params, | 
|  | 448 | +            translations=self.translations, | 
|  | 449 | +            using=alias, | 
|  | 450 | +        ) | 
|  | 451 | + | 
|  | 452 | +    @cached_property | 
|  | 453 | +    def columns(self): | 
|  | 454 | +        """ | 
|  | 455 | +        A list of model field names in the order they'll appear in the | 
|  | 456 | +        query results. | 
|  | 457 | +        """ | 
|  | 458 | +        columns = self.query.get_columns() | 
|  | 459 | +        # Adjust any column names which don't match field names | 
|  | 460 | +        for query_name, model_name in self.translations.items(): | 
|  | 461 | +            # Ignore translations for nonexistent column names | 
|  | 462 | +            try: | 
|  | 463 | +                index = columns.index(query_name) | 
|  | 464 | +            except ValueError: | 
|  | 465 | +                pass | 
|  | 466 | +            else: | 
|  | 467 | +                columns[index] = model_name | 
|  | 468 | +        return columns | 
|  | 469 | + | 
|  | 470 | +    @cached_property | 
|  | 471 | +    def model_fields(self): | 
|  | 472 | +        """A dict mapping column names to model field names.""" | 
|  | 473 | +        converter = connections[self.db].introspection.identifier_converter | 
|  | 474 | +        model_fields = {} | 
|  | 475 | +        for field in self.model._meta.fields: | 
|  | 476 | +            name, column = field.get_attname_column() | 
|  | 477 | +            model_fields[converter(column)] = field | 
|  | 478 | +        return model_fields | 
|  | 479 | + | 
|  | 480 | + | 
|  | 481 | +class RawQuery: | 
|  | 482 | +    """A single raw SQL query.""" | 
|  | 483 | + | 
|  | 484 | +    def __init__(self, sql, using, params=()): | 
|  | 485 | +        self.params = params | 
|  | 486 | +        self.sql = sql | 
|  | 487 | +        self.using = using | 
|  | 488 | +        self.cursor = None | 
|  | 489 | + | 
|  | 490 | +        # Mirror some properties of a normal query so that | 
|  | 491 | +        # the compiler can be used to process results. | 
|  | 492 | +        self.low_mark, self.high_mark = 0, None  # Used for offset/limit | 
|  | 493 | +        self.extra_select = {} | 
|  | 494 | +        self.annotation_select = {} | 
|  | 495 | + | 
|  | 496 | +    def chain(self, using): | 
|  | 497 | +        return self.clone(using) | 
|  | 498 | + | 
|  | 499 | +    def clone(self, using): | 
|  | 500 | +        return RawQuery(self.sql, using, params=self.params) | 
|  | 501 | + | 
|  | 502 | +    def get_columns(self): | 
|  | 503 | +        if self.cursor is None: | 
|  | 504 | +            self._execute_query() | 
|  | 505 | +        converter = connections[self.using].introspection.identifier_converter | 
|  | 506 | +        return [converter(column_meta[0]) for column_meta in self.cursor.description] | 
|  | 507 | + | 
|  | 508 | +    def __iter__(self): | 
|  | 509 | +        # Always execute a new query for a new iterator. | 
|  | 510 | +        # This could be optimized with a cache at the expense of RAM. | 
|  | 511 | +        self._execute_query() | 
|  | 512 | +        if not connections[self.using].features.can_use_chunked_reads: | 
|  | 513 | +            # If the database can't use chunked reads we need to make sure we | 
|  | 514 | +            # evaluate the entire query up front. | 
|  | 515 | +            result = list(self.cursor) | 
|  | 516 | +        else: | 
|  | 517 | +            result = self.cursor | 
|  | 518 | +        return iter(result) | 
|  | 519 | + | 
|  | 520 | +    def __repr__(self): | 
|  | 521 | +        return "<%s: %s>" % (self.__class__.__name__, self) | 
|  | 522 | + | 
|  | 523 | +    @property | 
|  | 524 | +    def params_type(self): | 
|  | 525 | +        if self.params is None: | 
|  | 526 | +            return None | 
|  | 527 | +        return dict if isinstance(self.params, Mapping) else tuple | 
|  | 528 | + | 
|  | 529 | +    def __str__(self): | 
|  | 530 | +        if self.params_type is None: | 
|  | 531 | +            return self.sql | 
|  | 532 | +        return self.sql % self.params_type(self.params) | 
|  | 533 | + | 
|  | 534 | +    def _execute_query(self): | 
|  | 535 | +        connection = connections[self.using] | 
|  | 536 | + | 
|  | 537 | +        # Adapt parameters to the database, as much as possible considering | 
|  | 538 | +        # that the target type isn't known. See #17755. | 
|  | 539 | +        params_type = self.params_type | 
|  | 540 | +        adapter = connection.ops.adapt_unknown_value | 
|  | 541 | +        if params_type is tuple: | 
|  | 542 | +            params = tuple(adapter(val) for val in self.params) | 
|  | 543 | +        elif params_type is dict: | 
|  | 544 | +            params = {key: adapter(val) for key, val in self.params.items()} | 
|  | 545 | +        elif params_type is None: | 
|  | 546 | +            params = None | 
|  | 547 | +        else: | 
|  | 548 | +            raise RuntimeError("Unexpected params type: %s" % params_type) | 
|  | 549 | + | 
|  | 550 | +        self.cursor = connection.cursor() | 
|  | 551 | +        self.cursor.execute(self.sql, params) | 
|  | 552 | + | 
|  | 553 | + | 
|  | 554 | +class RawModelIterable(BaseIterable): | 
|  | 555 | +    """ | 
|  | 556 | +    Iterable that yields a model instance for each row from a raw queryset. | 
|  | 557 | +    """ | 
|  | 558 | + | 
|  | 559 | +    def __iter__(self): | 
|  | 560 | +        # Cache some things for performance reasons outside the loop. | 
|  | 561 | +        db = self.queryset.db | 
|  | 562 | +        query = self.queryset.query | 
|  | 563 | +        connection = connections[db] | 
|  | 564 | +        compiler = connection.ops.compiler("SQLCompiler")(query, connection, db) | 
|  | 565 | +        query_iterator = iter(query) | 
|  | 566 | + | 
|  | 567 | +        try: | 
|  | 568 | +            ( | 
|  | 569 | +                model_init_names, | 
|  | 570 | +                model_init_pos, | 
|  | 571 | +                annotation_fields, | 
|  | 572 | +            ) = self.queryset.resolve_model_init_order() | 
|  | 573 | +            model_cls = self.queryset.model | 
|  | 574 | +            if model_cls._meta.pk.attname not in model_init_names: | 
|  | 575 | +                raise exceptions.FieldDoesNotExist("Raw query must include the primary key") | 
|  | 576 | +            fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns] | 
|  | 577 | +            converters = compiler.get_converters( | 
|  | 578 | +                [f.get_col(f.model._meta.db_table) if f else None for f in fields] | 
|  | 579 | +            ) | 
|  | 580 | +            if converters: | 
|  | 581 | +                query_iterator = compiler.apply_converters(query_iterator, converters) | 
|  | 582 | +            for values in query_iterator: | 
|  | 583 | +                # Associate fields to values | 
|  | 584 | +                model_init_values = [values[pos] for pos in model_init_pos] | 
|  | 585 | +                instance = model_cls.from_db(db, model_init_names, model_init_values) | 
|  | 586 | +                if annotation_fields: | 
|  | 587 | +                    for column, pos in annotation_fields: | 
|  | 588 | +                        setattr(instance, column, values[pos]) | 
|  | 589 | +                yield instance | 
|  | 590 | +        finally: | 
|  | 591 | +            # Done iterating the Query. If it has its own cursor, close it. | 
|  | 592 | +            if hasattr(query, "cursor") and query.cursor: | 
|  | 593 | +                query.cursor.close() | 
0 commit comments