Skip to content

Commit e635520

Browse files
committed
feat: change update and fix test
1 parent 5d82d51 commit e635520

File tree

5 files changed

+31
-20
lines changed

5 files changed

+31
-20
lines changed

tests/fields/test_array.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
1-
from tests import testmodels
1+
from tests import testmodels_postgres as testmodels
22
from tortoise.contrib import test
3-
from tortoise.exceptions import IntegrityError
3+
from tortoise.exceptions import IntegrityError, OperationalError
44

55

6-
class TestArrayFields(test.TestCase):
7-
@test.requireCapability(dialect="postgres")
6+
@test.requireCapability(dialect="postgres")
7+
class TestArrayFields(test.IsolatedTestCase):
8+
tortoise_test_modules = ["tests.testmodels_postgres"]
9+
10+
async def _setUpDB(self) -> None:
11+
try:
12+
await super()._setUpDB()
13+
except OperationalError:
14+
raise test.SkipTest("Works only with PostgreSQL")
15+
816
async def test_empty(self):
917
with self.assertRaises(IntegrityError):
1018
await testmodels.ArrayFields.create()
1119

12-
@test.requireCapability(dialect="postgres")
1320
async def test_create(self):
1421
obj0 = await testmodels.ArrayFields.create(array=[0])
1522
obj = await testmodels.ArrayFields.get(id=obj0.id)
@@ -19,21 +26,18 @@ async def test_create(self):
1926
obj2 = await testmodels.ArrayFields.get(id=obj.id)
2027
self.assertEqual(obj, obj2)
2128

22-
@test.requireCapability(dialect="postgres")
2329
async def test_update(self):
2430
obj0 = await testmodels.ArrayFields.create(array=[0])
2531
await testmodels.ArrayFields.filter(id=obj0.id).update(array=[1])
2632
obj = await testmodels.ArrayFields.get(id=obj0.id)
2733
self.assertEqual(obj.array, [1])
2834
self.assertIs(obj.array_null, None)
2935

30-
@test.requireCapability(dialect="postgres")
3136
async def test_values(self):
3237
obj0 = await testmodels.ArrayFields.create(array=[0])
3338
values = await testmodels.ArrayFields.get(id=obj0.id).values("array")
3439
self.assertEqual(values["array"], [0])
3540

36-
@test.requireCapability(dialect="postgres")
3741
async def test_values_list(self):
3842
obj0 = await testmodels.ArrayFields.create(array=[0])
3943
values = await testmodels.ArrayFields.get(id=obj0.id).values_list("array", flat=True)

tests/schema/test_generate_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ async def init_for(self, module: str, safe=False) -> None:
417417
)
418418
self.sqls = get_schema_sql(connections.get("default"), safe).split("; ")
419419
except ImportError:
420-
raise test.SkipTest("aiomysql not installed")
420+
raise test.SkipTest("asyncmy not installed")
421421

422422
async def test_noid(self):
423423
await self.init_for("tests.testmodels")

tests/testmodels.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import pytz
1414

1515
from tortoise import fields
16-
from tortoise.contrib.postgres.fields import ArrayField
1716
from tortoise.exceptions import NoValuesFetched, ValidationError
1817
from tortoise.manager import Manager
1918
from tortoise.models import Model
@@ -222,12 +221,6 @@ class BooleanFields(Model):
222221
boolean_null = fields.BooleanField(null=True)
223222

224223

225-
class ArrayFields(Model):
226-
id = fields.IntField(pk=True)
227-
array = ArrayField()
228-
array_null = ArrayField(null=True)
229-
230-
231224
class BinaryFields(Model):
232225
id = fields.IntField(pk=True)
233226
binary = fields.BinaryField()

tests/testmodels_postgres.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from tortoise import Model, fields
2+
from tortoise.contrib.postgres.fields import ArrayField
3+
4+
5+
class ArrayFields(Model):
6+
id = fields.IntField(pk=True)
7+
array = ArrayField()
8+
array_null = ArrayField(null=True)

tortoise/queryset.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,6 +1029,7 @@ class UpdateQuery(AwaitableQuery):
10291029
"custom_filters",
10301030
"orderings",
10311031
"limit",
1032+
"values",
10321033
)
10331034

10341035
def __init__(
@@ -1050,6 +1051,7 @@ def __init__(
10501051
self._db = db
10511052
self.limit = limit
10521053
self.orderings = orderings
1054+
self.values: List[Any] = []
10531055

10541056
def _make_query(self) -> None:
10551057
table = self.model._meta.basetable
@@ -1066,7 +1068,7 @@ def _make_query(self) -> None:
10661068
)
10671069
# Need to get executor to get correct column_map
10681070
executor = self._db.executor_class(model=self.model, db=self._db)
1069-
1071+
count = 0
10701072
for key, value in self.update_kwargs.items():
10711073
field_object = self.model._meta.fields_map.get(key)
10721074
if not field_object:
@@ -1090,8 +1092,12 @@ def _make_query(self) -> None:
10901092
value = value.resolve(self.model, table)["field"]
10911093
else:
10921094
value = executor.column_map[key](value, None)
1093-
1094-
self.query = self.query.set(db_field, value)
1095+
if isinstance(value, Term):
1096+
self.query = self.query.set(db_field, value)
1097+
else:
1098+
self.query = self.query.set(db_field, executor.parameter(count))
1099+
self.values.append(value)
1100+
count += 1
10951101

10961102
def __await__(self) -> Generator[Any, None, int]:
10971103
if self._db is None:
@@ -1100,7 +1106,7 @@ def __await__(self) -> Generator[Any, None, int]:
11001106
return self._execute().__await__()
11011107

11021108
async def _execute(self) -> int:
1103-
return (await self._db.execute_query(str(self.query)))[0]
1109+
return (await self._db.execute_query(str(self.query), self.values))[0]
11041110

11051111

11061112
class DeleteQuery(AwaitableQuery):

0 commit comments

Comments
 (0)