Skip to content

Commit 9ebcece

Browse files
author
Erlend Egeberg Aasland
authored
gh-79097: Add support for aggregate window functions in sqlite3 (GH-20903)
1 parent f45aa8f commit 9ebcece

File tree

10 files changed

+477
-13
lines changed

10 files changed

+477
-13
lines changed

Doc/includes/sqlite3/sumintwindow.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Example taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc
2+
import sqlite3
3+
4+
5+
class WindowSumInt:
6+
def __init__(self):
7+
self.count = 0
8+
9+
def step(self, value):
10+
"""Adds a row to the current window."""
11+
self.count += value
12+
13+
def value(self):
14+
"""Returns the current value of the aggregate."""
15+
return self.count
16+
17+
def inverse(self, value):
18+
"""Removes a row from the current window."""
19+
self.count -= value
20+
21+
def finalize(self):
22+
"""Returns the final value of the aggregate.
23+
24+
Any clean-up actions should be placed here.
25+
"""
26+
return self.count
27+
28+
29+
con = sqlite3.connect(":memory:")
30+
cur = con.execute("create table test(x, y)")
31+
values = [
32+
("a", 4),
33+
("b", 5),
34+
("c", 3),
35+
("d", 8),
36+
("e", 1),
37+
]
38+
cur.executemany("insert into test values(?, ?)", values)
39+
con.create_window_function("sumint", 1, WindowSumInt)
40+
cur.execute("""
41+
select x, sumint(y) over (
42+
order by x rows between 1 preceding and 1 following
43+
) as sum_y
44+
from test order by x
45+
""")
46+
print(cur.fetchall())

Doc/library/sqlite3.rst

+29
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,35 @@ Connection Objects
473473
.. literalinclude:: ../includes/sqlite3/mysumaggr.py
474474

475475

476+
.. method:: create_window_function(name, num_params, aggregate_class, /)
477+
478+
Creates user-defined aggregate window function *name*.
479+
480+
*aggregate_class* must implement the following methods:
481+
482+
* ``step``: adds a row to the current window
483+
* ``value``: returns the current value of the aggregate
484+
* ``inverse``: removes a row from the current window
485+
* ``finalize``: returns the final value of the aggregate
486+
487+
``step`` and ``value`` accept *num_params* number of parameters,
488+
unless *num_params* is ``-1``, in which case they may take any number of
489+
arguments. ``finalize`` and ``value`` can return any of the types
490+
supported by SQLite:
491+
:class:`bytes`, :class:`str`, :class:`int`, :class:`float`, and
492+
:const:`None`. Call :meth:`create_window_function` with
493+
*aggregate_class* set to :const:`None` to clear window function *name*.
494+
495+
Aggregate window functions are supported by SQLite 3.25.0 and higher.
496+
:exc:`NotSupportedError` will be raised if used with older versions.
497+
498+
.. versionadded:: 3.11
499+
500+
Example:
501+
502+
.. literalinclude:: ../includes/sqlite3/sumintwindow.py
503+
504+
476505
.. method:: create_collation(name, callable)
477506

478507
Creates a collation with the specified *name* and *callable*. The callable will

Doc/whatsnew/3.11.rst

+4
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,10 @@ sqlite3
389389
serializing and deserializing databases.
390390
(Contributed by Erlend E. Aasland in :issue:`41930`.)
391391

392+
* Add :meth:`~sqlite3.Connection.create_window_function` to
393+
:class:`sqlite3.Connection` for creating aggregate window functions.
394+
(Contributed by Erlend E. Aasland in :issue:`34916`.)
395+
392396

393397
sys
394398
---

Lib/test/test_sqlite3/test_dbapi.py

+2
Original file line numberDiff line numberDiff line change
@@ -1084,6 +1084,8 @@ def test_check_connection_thread(self):
10841084
if hasattr(sqlite.Connection, "serialize"):
10851085
fns.append(lambda: self.con.serialize())
10861086
fns.append(lambda: self.con.deserialize(b""))
1087+
if sqlite.sqlite_version_info >= (3, 25, 0):
1088+
fns.append(lambda: self.con.create_window_function("foo", 0, None))
10871089

10881090
for fn in fns:
10891091
with self.subTest(fn=fn):

Lib/test/test_sqlite3/test_userfunctions.py

+163-5
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
import re
2828
import sys
2929
import unittest
30-
import unittest.mock
3130
import sqlite3 as sqlite
3231

32+
from unittest.mock import Mock, patch
3333
from test.support import bigmemtest, catch_unraisable_exception, gc_collect
3434

3535
from test.test_sqlite3.test_dbapi import cx_limit
@@ -393,7 +393,7 @@ def append_result(arg):
393393
# indices, which allows testing based on syntax, iso. the query optimizer.
394394
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
395395
def test_func_non_deterministic(self):
396-
mock = unittest.mock.Mock(return_value=None)
396+
mock = Mock(return_value=None)
397397
self.con.create_function("nondeterministic", 0, mock, deterministic=False)
398398
if sqlite.sqlite_version_info < (3, 15, 0):
399399
self.con.execute("select nondeterministic() = nondeterministic()")
@@ -404,7 +404,7 @@ def test_func_non_deterministic(self):
404404

405405
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
406406
def test_func_deterministic(self):
407-
mock = unittest.mock.Mock(return_value=None)
407+
mock = Mock(return_value=None)
408408
self.con.create_function("deterministic", 0, mock, deterministic=True)
409409
if sqlite.sqlite_version_info < (3, 15, 0):
410410
self.con.execute("select deterministic() = deterministic()")
@@ -482,6 +482,164 @@ def test_func_return_illegal_value(self):
482482
self.con.execute, "select badreturn()")
483483

484484

485+
class WindowSumInt:
486+
def __init__(self):
487+
self.count = 0
488+
489+
def step(self, value):
490+
self.count += value
491+
492+
def value(self):
493+
return self.count
494+
495+
def inverse(self, value):
496+
self.count -= value
497+
498+
def finalize(self):
499+
return self.count
500+
501+
class BadWindow(Exception):
502+
pass
503+
504+
505+
@unittest.skipIf(sqlite.sqlite_version_info < (3, 25, 0),
506+
"Requires SQLite 3.25.0 or newer")
507+
class WindowFunctionTests(unittest.TestCase):
508+
def setUp(self):
509+
self.con = sqlite.connect(":memory:")
510+
self.cur = self.con.cursor()
511+
512+
# Test case taken from https://www.sqlite.org/windowfunctions.html#udfwinfunc
513+
values = [
514+
("a", 4),
515+
("b", 5),
516+
("c", 3),
517+
("d", 8),
518+
("e", 1),
519+
]
520+
with self.con:
521+
self.con.execute("create table test(x, y)")
522+
self.con.executemany("insert into test values(?, ?)", values)
523+
self.expected = [
524+
("a", 9),
525+
("b", 12),
526+
("c", 16),
527+
("d", 12),
528+
("e", 9),
529+
]
530+
self.query = """
531+
select x, %s(y) over (
532+
order by x rows between 1 preceding and 1 following
533+
) as sum_y
534+
from test order by x
535+
"""
536+
self.con.create_window_function("sumint", 1, WindowSumInt)
537+
538+
def test_win_sum_int(self):
539+
self.cur.execute(self.query % "sumint")
540+
self.assertEqual(self.cur.fetchall(), self.expected)
541+
542+
def test_win_error_on_create(self):
543+
self.assertRaises(sqlite.ProgrammingError,
544+
self.con.create_window_function,
545+
"shouldfail", -100, WindowSumInt)
546+
547+
@with_tracebacks(BadWindow)
548+
def test_win_exception_in_method(self):
549+
for meth in "__init__", "step", "value", "inverse":
550+
with self.subTest(meth=meth):
551+
with patch.object(WindowSumInt, meth, side_effect=BadWindow):
552+
name = f"exc_{meth}"
553+
self.con.create_window_function(name, 1, WindowSumInt)
554+
msg = f"'{meth}' method raised error"
555+
with self.assertRaisesRegex(sqlite.OperationalError, msg):
556+
self.cur.execute(self.query % name)
557+
self.cur.fetchall()
558+
559+
@with_tracebacks(BadWindow)
560+
def test_win_exception_in_finalize(self):
561+
# Note: SQLite does not (as of version 3.38.0) propagate finalize
562+
# callback errors to sqlite3_step(); this implies that OperationalError
563+
# is _not_ raised.
564+
with patch.object(WindowSumInt, "finalize", side_effect=BadWindow):
565+
name = f"exception_in_finalize"
566+
self.con.create_window_function(name, 1, WindowSumInt)
567+
self.cur.execute(self.query % name)
568+
self.cur.fetchall()
569+
570+
@with_tracebacks(AttributeError)
571+
def test_win_missing_method(self):
572+
class MissingValue:
573+
def step(self, x): pass
574+
def inverse(self, x): pass
575+
def finalize(self): return 42
576+
577+
class MissingInverse:
578+
def step(self, x): pass
579+
def value(self): return 42
580+
def finalize(self): return 42
581+
582+
class MissingStep:
583+
def value(self): return 42
584+
def inverse(self, x): pass
585+
def finalize(self): return 42
586+
587+
dataset = (
588+
("step", MissingStep),
589+
("value", MissingValue),
590+
("inverse", MissingInverse),
591+
)
592+
for meth, cls in dataset:
593+
with self.subTest(meth=meth, cls=cls):
594+
name = f"exc_{meth}"
595+
self.con.create_window_function(name, 1, cls)
596+
with self.assertRaisesRegex(sqlite.OperationalError,
597+
f"'{meth}' method not defined"):
598+
self.cur.execute(self.query % name)
599+
self.cur.fetchall()
600+
601+
@with_tracebacks(AttributeError)
602+
def test_win_missing_finalize(self):
603+
# Note: SQLite does not (as of version 3.38.0) propagate finalize
604+
# callback errors to sqlite3_step(); this implies that OperationalError
605+
# is _not_ raised.
606+
class MissingFinalize:
607+
def step(self, x): pass
608+
def value(self): return 42
609+
def inverse(self, x): pass
610+
611+
name = "missing_finalize"
612+
self.con.create_window_function(name, 1, MissingFinalize)
613+
self.cur.execute(self.query % name)
614+
self.cur.fetchall()
615+
616+
def test_win_clear_function(self):
617+
self.con.create_window_function("sumint", 1, None)
618+
self.assertRaises(sqlite.OperationalError, self.cur.execute,
619+
self.query % "sumint")
620+
621+
def test_win_redefine_function(self):
622+
# Redefine WindowSumInt; adjust the expected results accordingly.
623+
class Redefined(WindowSumInt):
624+
def step(self, value): self.count += value * 2
625+
def inverse(self, value): self.count -= value * 2
626+
expected = [(v[0], v[1]*2) for v in self.expected]
627+
628+
self.con.create_window_function("sumint", 1, Redefined)
629+
self.cur.execute(self.query % "sumint")
630+
self.assertEqual(self.cur.fetchall(), expected)
631+
632+
def test_win_error_value_return(self):
633+
class ErrorValueReturn:
634+
def __init__(self): pass
635+
def step(self, x): pass
636+
def value(self): return 1 << 65
637+
638+
self.con.create_window_function("err_val_ret", 1, ErrorValueReturn)
639+
self.assertRaisesRegex(sqlite.DataError, "string or blob too big",
640+
self.cur.execute, self.query % "err_val_ret")
641+
642+
485643
class AggregateTests(unittest.TestCase):
486644
def setUp(self):
487645
self.con = sqlite.connect(":memory:")
@@ -527,10 +685,10 @@ def test_aggr_no_step(self):
527685

528686
def test_aggr_no_finalize(self):
529687
cur = self.con.cursor()
530-
with self.assertRaises(sqlite.OperationalError) as cm:
688+
msg = "user-defined aggregate's 'finalize' method not defined"
689+
with self.assertRaisesRegex(sqlite.OperationalError, msg):
531690
cur.execute("select nofinalize(t) from test")
532691
val = cur.fetchone()[0]
533-
self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
534692

535693
@with_tracebacks(ZeroDivisionError, name="AggrExceptionInInit")
536694
def test_aggr_exception_in_init(self):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Add :meth:`~sqlite3.Connection.create_window_function` to
2+
:class:`sqlite3.Connection` for creating aggregate window functions.
3+
Patch by Erlend E. Aasland.

Modules/_sqlite/clinic/connection.c.h

+52-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)