Skip to content

Commit 50c38a5

Browse files
lovasoasloria
andauthored
Fix inconsistency in post_load behavior (#76)
* Fix inconsistency in post_load behavior * Remove unneeded import * Update marshmallow_dataclass/__init__.py Co-Authored-By: Steven Loria <[email protected]> * Updated tests Co-authored-by: Steven Loria <[email protected]>
1 parent 69caec4 commit 50c38a5

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,9 @@ class User:
3434
})
3535
Schema: ClassVar[Type[Schema]] = Schema # For the type checker
3636
"""
37-
import dataclasses
3837
import inspect
39-
from functools import lru_cache
4038
from enum import EnumMeta
39+
from functools import lru_cache
4140
from typing import (
4241
overload,
4342
Dict,
@@ -54,6 +53,7 @@ class User:
5453
Set,
5554
)
5655

56+
import dataclasses
5757
import marshmallow
5858
import typing_inspect
5959

@@ -457,12 +457,17 @@ def _base_schema(
457457
Base schema factory that creates a schema for `clazz` derived either from `base_schema`
458458
or `BaseSchema`
459459
"""
460+
460461
# Remove `type: ignore` when mypy handles dynamic base classes
461462
# https://github.com/python/mypy/issues/2813
462463
class BaseSchema(base_schema or marshmallow.Schema): # type: ignore
463-
@marshmallow.post_load
464-
def make_data_class(self, data, **_):
465-
return clazz(**data)
464+
def load(self, data: Mapping, *, many: bool = None, **kwargs):
465+
all_loaded = super().load(data, many=many, **kwargs)
466+
many = self.many if many is None else bool(many)
467+
if many:
468+
return [clazz(**loaded) for loaded in all_loaded]
469+
else:
470+
return clazz(**all_loaded)
466471

467472
return BaseSchema
468473

tests/test_post_load.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import unittest
2+
3+
import marshmallow
4+
5+
import marshmallow_dataclass
6+
7+
8+
# Regression test for https://github.com/lovasoa/marshmallow_dataclass/issues/75
9+
class TestPostLoad(unittest.TestCase):
10+
@marshmallow_dataclass.dataclass
11+
class Named:
12+
first_name: str
13+
last_name: str
14+
15+
@marshmallow.post_load
16+
def a(self, data, **_kwargs):
17+
data["first_name"] = data["first_name"].capitalize()
18+
return data
19+
20+
@marshmallow.post_load
21+
def z(self, data, **_kwargs):
22+
data["last_name"] = data["last_name"].capitalize()
23+
return data
24+
25+
def test_post_load_method_naming_does_not_affect_data(self):
26+
actual = self.Named.Schema().load(
27+
{"first_name": "matt", "last_name": "groening"}
28+
)
29+
expected = self.Named(first_name="Matt", last_name="Groening")
30+
self.assertEqual(actual, expected)
31+
32+
def test_load_many(self):
33+
actual = self.Named.Schema().load(
34+
[
35+
{"first_name": "matt", "last_name": "groening"},
36+
{"first_name": "bart", "last_name": "simpson"},
37+
],
38+
many=True,
39+
)
40+
expected = [
41+
self.Named(first_name="Matt", last_name="Groening"),
42+
self.Named(first_name="Bart", last_name="Simpson"),
43+
]
44+
self.assertEqual(actual, expected)

0 commit comments

Comments
 (0)