|
2 | 2 |
|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
| 5 | +import typing |
| 6 | + |
5 | 7 | import re
|
6 | 8 | from collections import OrderedDict, UserList, defaultdict
|
7 | 9 | from copy import deepcopy
|
@@ -413,65 +415,67 @@ def _restore_dataclass_field(
|
413 | 415 | Returns:
|
414 | 416 | Union[:const:`Primitive`, CBORSerializable]: A CBOR primitive or a CBORSerializable.
|
415 | 417 | """
|
| 418 | + |
416 | 419 | if "object_hook" in f.metadata:
|
417 | 420 | return f.metadata["object_hook"](v)
|
418 |
| - elif isclass(f.type) and issubclass(f.type, CBORSerializable): |
419 |
| - return f.type.from_primitive(v) |
420 |
| - elif hasattr(f.type, "__origin__") and (f.type.__origin__ is list): |
421 |
| - t_args = f.type.__args__ |
| 421 | + return _restore_typed_primitive(f.type, v) |
| 422 | + |
| 423 | + |
| 424 | +def _restore_typed_primitive( |
| 425 | + t: typing.Type, v: Primitive |
| 426 | +) -> Union[Primitive, CBORSerializable]: |
| 427 | + """Try to restore a value back to its original type based on information given in field. |
| 428 | +
|
| 429 | + Args: |
| 430 | + f (type): A type |
| 431 | + v (:const:`Primitive`): A CBOR primitive. |
| 432 | +
|
| 433 | + Returns: |
| 434 | + Union[:const:`Primitive`, CBORSerializable]: A CBOR primitive or a CBORSerializable. |
| 435 | + """ |
| 436 | + if t in PRIMITIVE_TYPES and isinstance(v, t): |
| 437 | + return v |
| 438 | + elif isclass(t) and issubclass(t, CBORSerializable): |
| 439 | + return t.from_primitive(v) |
| 440 | + elif hasattr(t, "__origin__") and (t.__origin__ is list): |
| 441 | + t_args = t.__args__ |
422 | 442 | if len(t_args) != 1:
|
423 | 443 | raise DeserializeException(
|
424 | 444 | f"List types need exactly one type argument, but got {t_args}"
|
425 | 445 | )
|
426 | 446 | t = t_args[0]
|
427 | 447 | if not isinstance(v, list):
|
428 | 448 | raise DeserializeException(f"Expected type list but got {type(v)}")
|
429 |
| - if isclass(t) and issubclass(t, CBORSerializable): |
430 |
| - return IndefiniteList([t.from_primitive(w) for w in v]) |
431 |
| - else: |
432 |
| - return IndefiniteList(v) |
433 |
| - elif isclass(f.type) and issubclass(f.type, IndefiniteList): |
| 449 | + return IndefiniteList([_restore_typed_primitive(t, w) for w in v]) |
| 450 | + elif isclass(t) and issubclass(t, IndefiniteList): |
434 | 451 | return IndefiniteList(v)
|
435 |
| - elif hasattr(f.type, "__origin__") and (f.type.__origin__ is dict): |
436 |
| - t_args = f.type.__args__ |
| 452 | + elif hasattr(t, "__origin__") and (t.__origin__ is dict): |
| 453 | + t_args = t.__args__ |
437 | 454 | if len(t_args) != 2:
|
438 | 455 | raise DeserializeException(
|
439 | 456 | f"Dict types need exactly two type arguments, but got {t_args}"
|
440 | 457 | )
|
441 | 458 | key_t = t_args[0]
|
442 | 459 | val_t = t_args[1]
|
443 |
| - if isclass(key_t) and issubclass(key_t, CBORSerializable): |
444 |
| - key_converter = key_t.from_primitive |
445 |
| - else: |
446 |
| - key_converter = _identity |
447 |
| - if isclass(val_t) and issubclass(val_t, CBORSerializable): |
448 |
| - val_converter = val_t.from_primitive |
449 |
| - else: |
450 |
| - val_converter = _identity |
451 | 460 | if not isinstance(v, dict):
|
452 | 461 | raise DeserializeException(f"Expected dict type but got {type(v)}")
|
453 |
| - return {key_converter(key): val_converter(val) for key, val in v.items()} |
454 |
| - elif hasattr(f.type, "__origin__") and ( |
455 |
| - f.type.__origin__ is Union or f.type.__origin__ is Optional |
| 462 | + return { |
| 463 | + _restore_typed_primitive(key_t, key): _restore_typed_primitive(val_t, val) |
| 464 | + for key, val in v.items() |
| 465 | + } |
| 466 | + elif hasattr(t, "__origin__") and ( |
| 467 | + t.__origin__ is Union or t.__origin__ is Optional |
456 | 468 | ):
|
457 |
| - t_args = f.type.__args__ |
| 469 | + t_args = t.__args__ |
458 | 470 | for t in t_args:
|
459 |
| - if isclass(t) and issubclass(t, IndefiniteList): |
460 |
| - return IndefiniteList(v) |
461 |
| - elif isclass(t) and issubclass(t, CBORSerializable): |
462 |
| - try: |
463 |
| - return t.from_primitive(v) |
464 |
| - except DeserializeException: |
465 |
| - pass |
466 |
| - else: |
467 |
| - if not isclass(t) and hasattr(t, "__origin__"): |
468 |
| - t = t.__origin__ |
469 |
| - if t in PRIMITIVE_TYPES and isinstance(v, t): |
470 |
| - return v |
| 471 | + try: |
| 472 | + return _restore_typed_primitive(t, v) |
| 473 | + except DeserializeException: |
| 474 | + pass |
471 | 475 | raise DeserializeException(
|
472 | 476 | f"Cannot deserialize object: \n{v}\n in any valid type from {t_args}."
|
473 | 477 | )
|
474 |
| - return v |
| 478 | + raise DeserializeException(f"Cannot deserialize object: \n{v}\n to type {t}.") |
475 | 479 |
|
476 | 480 |
|
477 | 481 | ArrayBase = TypeVar("ArrayBase", bound="ArrayCBORSerializable")
|
@@ -556,8 +560,8 @@ def to_shallow_primitive(self) -> List[Primitive]:
|
556 | 560 | return primitives
|
557 | 561 |
|
558 | 562 | @classmethod
|
559 |
| - @limit_primitive_type(list) |
560 |
| - def from_primitive(cls: Type[ArrayBase], values: list) -> ArrayBase: |
| 563 | + @limit_primitive_type(list, tuple) |
| 564 | + def from_primitive(cls: Type[ArrayBase], values: Union[list, tuple]) -> ArrayBase: |
561 | 565 | """Restore a primitive value to its original class type.
|
562 | 566 |
|
563 | 567 | Args:
|
|
0 commit comments