Skip to content

Commit 38a2608

Browse files
fix: fix json types
1 parent 779b1eb commit 38a2608

File tree

4 files changed

+89
-35
lines changed

4 files changed

+89
-35
lines changed

pydantic_aioredis/abstract.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,18 @@
22
import json
33
from datetime import date
44
from datetime import datetime
5-
from ipaddress import IPv4Address
6-
from ipaddress import IPv4Network
7-
from ipaddress import IPv6Address
8-
from ipaddress import IPv6Network
95
from typing import Any
106
from typing import Dict
117
from typing import List
128
from typing import Optional
139
from typing import Union
14-
from uuid import UUID
1510

1611
from pydantic import BaseModel
1712
from pydantic_aioredis.config import RedisConfig
1813
from redis import asyncio as aioredis
1914

20-
# STR_DUMP_SHAPES are object types that are serialized to strings using str(obj)
21-
# They are stored in redis as strings and rely on pydantic to deserialize them
22-
STR_DUMP_SHAPES = (IPv4Address, IPv4Network, IPv6Address, IPv6Network, UUID)
15+
from .types import JSON_DUMP_SHAPES
16+
from .types import STR_DUMP_SHAPES
2317

2418

2519
class _AbstractStore(BaseModel):
@@ -89,6 +83,8 @@ def serialize_partially(cls, data: Dict[str, Any]):
8983
continue
9084
if cls.__fields__[field].type_ not in [str, float, int]:
9185
data[field] = json.dumps(data[field], default=cls.json_default)
86+
if getattr(cls.__fields__[field], "shape", None) in JSON_DUMP_SHAPES:
87+
data[field] = json.dumps(data[field], default=cls.json_default)
9288
if getattr(cls.__fields__[field], "allow_none", False):
9389
if data[field] is None:
9490
data[field] = "None"
@@ -107,6 +103,8 @@ def deserialize_partially(cls, data: Dict[bytes, Any]):
107103
continue
108104
if cls.__fields__[field].type_ not in [str, float, int]:
109105
data[field] = json.loads(data[field], object_hook=cls.json_object_hook)
106+
if getattr(cls.__fields__[field], "shape", None) in JSON_DUMP_SHAPES:
107+
data[field] = json.loads(data[field], object_hook=cls.json_object_hook)
110108
if getattr(cls.__fields__[field], "allow_none", False):
111109
if data[field] == "None":
112110
data[field] = None

pydantic_aioredis/types.py

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from enum import Enum
2+
from ipaddress import IPv4Address
3+
from ipaddress import IPv4Network
4+
from ipaddress import IPv6Address
5+
from ipaddress import IPv6Network
6+
from uuid import UUID
7+
8+
from pydantic.fields import SHAPE_DEFAULTDICT
9+
from pydantic.fields import SHAPE_DICT
10+
from pydantic.fields import SHAPE_FROZENSET
11+
from pydantic.fields import SHAPE_LIST
12+
from pydantic.fields import SHAPE_MAPPING
13+
from pydantic.fields import SHAPE_SEQUENCE
14+
from pydantic.fields import SHAPE_SET
15+
from pydantic.fields import SHAPE_TUPLE
16+
from pydantic.fields import SHAPE_TUPLE_ELLIPSIS
17+
18+
# JSON_DUMP_SHAPES are object types that are serialized to JSON using json.dumps
19+
JSON_DUMP_SHAPES = (
20+
SHAPE_LIST,
21+
SHAPE_SET,
22+
SHAPE_MAPPING,
23+
SHAPE_TUPLE,
24+
SHAPE_TUPLE_ELLIPSIS,
25+
SHAPE_SEQUENCE,
26+
SHAPE_FROZENSET,
27+
SHAPE_DICT,
28+
SHAPE_DEFAULTDICT,
29+
Enum,
30+
)
31+
32+
# STR_DUMP_SHAPES are object types that are serialized to strings using str(obj)
33+
# They are stored in redis as strings and rely on pydantic to deserialize them
34+
STR_DUMP_SHAPES = (IPv4Address, IPv4Network, IPv6Address, IPv6Network, UUID)

test/test_abstract.py

+27-27
Original file line numberDiff line numberDiff line change
@@ -29,48 +29,48 @@ class SimpleModel(Model):
2929
test_tuple: Tuple[str]
3030

3131

32+
def test_serialize_partially_skip_missing_field():
33+
serialized = SimpleModel.serialize_partially({"unknown": "test"})
34+
assert serialized["unknown"] == "test"
35+
36+
3237
parameters = [
33-
(st.text, [], {}, "test_str", None),
34-
(st.integers, [], {}, "test_int", None),
35-
(st.floats, [], {"allow_nan": False}, "test_float", None),
36-
(st.dates, [], {}, "test_date", lambda x: json.dumps(x.isoformat())),
37-
(st.datetimes, [], {}, "test_datetime", lambda x: json.dumps(x.isoformat())),
38-
(st.ip_addresses, [], {"v": 4}, "test_ip_v4", lambda x: json.dumps(str(x))),
39-
(st.ip_addresses, [], {"v": 6}, "test_ip_v4", lambda x: json.dumps(str(x))),
40-
(
41-
st.lists,
42-
[st.tuples(st.integers(), st.floats())],
43-
{},
44-
"test_list",
45-
lambda x: json.dumps(x),
46-
),
38+
(st.text, [], {}, "test_str", str, False),
39+
(st.dates, [], {}, "test_date", str, False),
40+
(st.datetimes, [], {}, "test_datetime", str, False),
41+
(st.ip_addresses, [], {"v": 4}, "test_ip_v4", str, False),
42+
(st.ip_addresses, [], {"v": 6}, "test_ip_v4", str, False),
43+
(st.lists, [st.tuples(st.integers(), st.floats())], {}, "test_list", str, False),
4744
(
4845
st.dictionaries,
4946
[st.text(), st.tuples(st.integers(), st.floats())],
5047
{},
5148
"test_dict",
52-
lambda x: json.dumps(x),
49+
str,
50+
False,
5351
),
54-
(st.tuples, [st.text()], {}, "test_tuple", lambda x: json.dumps(x)),
52+
(st.tuples, [st.text()], {}, "test_tuple", str, False),
53+
(st.floats, [], {"allow_nan": False}, "test_float", float, True),
54+
(st.integers, [], {}, "test_int", int, True),
5555
]
5656

5757

5858
@pytest.mark.parametrize(
59-
"strategy, strategy_args, strategy_kwargs, model_field, serialize_callable",
59+
"strategy, strategy_args, strategy_kwargs, model_field, expected_type, equality_expected",
6060
parameters,
6161
)
6262
@given(st.data())
6363
def test_serialize_partially(
64-
strategy, strategy_args, strategy_kwargs, model_field, serialize_callable, data
64+
strategy,
65+
strategy_args,
66+
strategy_kwargs,
67+
model_field,
68+
expected_type,
69+
equality_expected,
70+
data,
6571
):
6672
value = data.draw(strategy(*strategy_args, **strategy_kwargs))
6773
serialized = SimpleModel.serialize_partially({model_field: value})
68-
if serialize_callable is None:
69-
assert serialized[model_field] == value
70-
else:
71-
assert serialized[model_field] == serialize_callable(value)
72-
73-
74-
def test_serialize_partially_skip_missing_filed():
75-
serialized = SimpleModel.serialize_partially({"unknown": "test"})
76-
assert serialized["unknown"] == "test"
74+
assert isinstance(serialized.get(model_field), expected_type)
75+
if equality_expected:
76+
assert serialized.get(model_field) == value

test/test_model.py

+22
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from ipaddress import IPv6Address
77
from typing import Dict
88
from typing import List
9+
from typing import Optional
910
from typing import Tuple
1011
from typing import Union
1112

@@ -93,3 +94,24 @@ class UpdateModel(Model):
9394

9495
redis_model = await UpdateModel.select(ids=[test_str])
9596
assert redis_model[0].test_int == update_int
97+
98+
99+
async def test_storing_list(redis_store):
100+
# https://github.com/andrewthetechie/pydantic-aioredis/issues/403
101+
class DataTypeTest(Model):
102+
_primary_key_field: str = "key"
103+
104+
key: str
105+
value: List[int]
106+
107+
redis_store.register_model(DataTypeTest)
108+
key = "test_list_storage"
109+
instance = DataTypeTest(
110+
key=key,
111+
value=[1, 2, 3],
112+
)
113+
await instance.save()
114+
115+
instance_in_redis = await DataTypeTest.select()
116+
assert instance_in_redis[0].key == instance.key
117+
assert len(instance_in_redis[0].value) == len(instance.value)

0 commit comments

Comments
 (0)