Skip to content

Commit 3b00d58

Browse files
junzh0ufacebook-github-bot
authored andcommitted
python: any registry (5/n) support more primitive types
Summary: `bool`, `str` and `bytes` Reviewed By: nanshu Differential Revision: D43367484 fbshipit-source-id: 1162370c3e9a5163e1c9a7619fc232924ec2adce
1 parent ab63e21 commit 3b00d58

File tree

6 files changed

+129
-52
lines changed

6 files changed

+129
-52
lines changed

thrift/lib/python/any/any_registry.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,22 @@ def _standard_protocol_to_serializer_protocol(
5656

5757

5858
def _infer_type_name_from_value(value: PrimitiveType) -> TypeName:
59+
if isinstance(value, bool):
60+
return TypeName(boolType=Void.Unused)
5961
if isinstance(value, int):
6062
return TypeName(i64Type=Void.Unused)
6163
if isinstance(value, float):
6264
return TypeName(doubleType=Void.Unused)
65+
if isinstance(value, str):
66+
return TypeName(stringType=Void.Unused)
67+
if isinstance(value, bytes):
68+
return TypeName(binaryType=Void.Unused)
6369
raise ValueError(f"Can not infer thrift type from: {value}")
6470

6571

6672
def _type_name_to_primitive_type(type_name: TypeName) -> typing.Type[PrimitiveType]:
73+
if type_name.type is TypeName.Type.boolType:
74+
return bool
6775
if type_name.type in (
6876
TypeName.Type.i16Type,
6977
TypeName.Type.i32Type,
@@ -75,6 +83,10 @@ def _type_name_to_primitive_type(type_name: TypeName) -> typing.Type[PrimitiveTy
7583
TypeName.Type.doubleType,
7684
):
7785
return float
86+
if type_name.type is TypeName.Type.stringType:
87+
return str
88+
if type_name.type is TypeName.Type.binaryType:
89+
return bytes
7890
raise ValueError(f"Unsupported primitive type: {type_name}")
7991

8092

@@ -113,7 +125,7 @@ def store(
113125
)
114126
if isinstance(obj, StructOrUnion):
115127
return self._store_struct(obj, protocol=protocol)
116-
if isinstance(obj, (int, float)):
128+
if isinstance(obj, (bool, int, float, str, bytes)):
117129
return self._store_primitive(obj, protocol=protocol)
118130
raise ValueError(f"Unsupported type: f{type(obj)}")
119131

@@ -166,14 +178,17 @@ def load(self, any_obj: Any) -> SupportedType:
166178
if any_obj.type.name.type is TypeName.Type.structType:
167179
return self._load_struct(any_obj)
168180
if any_obj.type.name.type in [
181+
TypeName.Type.boolType,
169182
TypeName.Type.i16Type,
170183
TypeName.Type.i32Type,
171184
TypeName.Type.i64Type,
172185
TypeName.Type.floatType,
173186
TypeName.Type.doubleType,
187+
TypeName.Type.stringType,
188+
TypeName.Type.binaryType,
174189
]:
175190
return self._load_primitive(any_obj)
176-
raise NotImplementedError(f"Unsupported type: {any_obj.type.name.value}")
191+
raise NotImplementedError(f"Unsupported type: {any_obj.type.name}")
177192

178193
def _load_struct(self, any_obj: Any) -> StructOrUnion:
179194
return serializer.deserialize(

thrift/lib/python/any/serializer.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ from apache.thrift.type.type.thrift_types import Type
1919
from folly.iobuf import IOBuf
2020
from thrift.python.serializer import Protocol
2121

22-
PrimitiveType = typing.Union[int, float]
22+
PrimitiveType = typing.Union[bool, int, float, str, bytes]
2323
Primitive = typing.TypeVar("Primitive", bound=PrimitiveType)
2424

2525
def serialize_primitive(

thrift/lib/python/any/serializer.pyx

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,53 +16,90 @@ from apache.thrift.type.standard.thrift_types import TypeName
1616
from apache.thrift.type.type.thrift_types import Type
1717
from cython.view cimport memoryview
1818
from folly.iobuf cimport IOBuf, from_unique_ptr
19+
from libcpp cimport bool as cbool
1920
from libcpp.utility cimport move as cmove
20-
from thrift.python.types cimport Struct, StructOrUnion, Union, doubleTypeInfo, floatTypeInfo, i16TypeInfo, i32TypeInfo, i64TypeInfo
21+
from thrift.python.types cimport Struct, StructOrUnion, Union, getCTypeInfo
22+
from thrift.python.types import (
23+
typeinfo_bool,
24+
typeinfo_byte,
25+
typeinfo_i16,
26+
typeinfo_i32,
27+
typeinfo_i64,
28+
typeinfo_double,
29+
typeinfo_float,
30+
typeinfo_string,
31+
typeinfo_binary,
32+
)
33+
from cython.operator cimport dereference as deref
2134

2235
import cython
2336
import typing
2437

2538
Buf = cython.fused_type(IOBuf, bytes, bytearray, memoryview)
26-
Primitive = cython.fused_type(int, float)
2739

28-
cdef cTypeInfo _thrift_type_to_type_info(thrift_type):
40+
41+
def _thrift_type_to_type_info(thrift_type):
42+
if thrift_type.name.type is TypeName.Type.boolType:
43+
return typeinfo_bool
44+
if thrift_type.name.type is TypeName.Type.byteType:
45+
return typeinfo_byte
2946
if thrift_type.name.type is TypeName.Type.i16Type:
30-
return i16TypeInfo
47+
return typeinfo_i16
3148
if thrift_type.name.type is TypeName.Type.i32Type:
32-
return i32TypeInfo
49+
return typeinfo_i32
3350
if thrift_type.name.type is TypeName.Type.i64Type:
34-
return i64TypeInfo
51+
return typeinfo_i64
3552
if thrift_type.name.type is TypeName.Type.floatType:
36-
return floatTypeInfo
53+
return typeinfo_float
3754
if thrift_type.name.type is TypeName.Type.doubleType:
38-
return doubleTypeInfo
55+
return typeinfo_double
56+
if thrift_type.name.type is TypeName.Type.stringType:
57+
return typeinfo_string
58+
if thrift_type.name.type is TypeName.Type.binaryType:
59+
return typeinfo_binary
3960
raise NotImplementedError(f"Unsupported type: {thrift_type}")
4061

4162

42-
cdef cTypeInfo _infer_type_info_from_cls(cls):
63+
def _infer_type_info_from_cls(cls):
64+
if issubclass(cls, bool):
65+
return typeinfo_bool
4366
if issubclass(cls, int):
44-
return i64TypeInfo
67+
return typeinfo_i64
4568
if issubclass(cls, float):
46-
return doubleTypeInfo
69+
return typeinfo_double
70+
if issubclass(cls, str):
71+
return typeinfo_string
72+
if issubclass(cls, bytes):
73+
return typeinfo_binary
4774
raise NotImplementedError(f"Can not infer thrift type from: {cls}")
4875

4976

50-
def serialize_primitive(Primitive obj, Protocol protocol=Protocol.COMPACT, thrift_type=None):
51-
cdef cTypeInfo type_info
77+
def serialize_primitive(obj, Protocol protocol=Protocol.COMPACT, thrift_type=None):
5278
if thrift_type is None:
5379
type_info = _infer_type_info_from_cls(type(obj))
5480
else:
5581
type_info = _thrift_type_to_type_info(thrift_type)
5682
return folly.iobuf.from_unique_ptr(
57-
cmove(cserialize_type(type_info, obj, protocol))
83+
cmove(
84+
cserialize_type(
85+
deref(getCTypeInfo(type_info)),
86+
type_info.to_internal_data(obj),
87+
protocol,
88+
)
89+
)
5890
)
5991

6092

6193
def deserialize_primitive(cls, Buf buf, Protocol protocol=Protocol.COMPACT, thrift_type=None):
62-
cdef cTypeInfo type_info
6394
if thrift_type is None:
6495
type_info = _infer_type_info_from_cls(cls)
6596
else:
6697
type_info = _thrift_type_to_type_info(thrift_type)
6798
cdef IOBuf iobuf = buf if isinstance(buf, IOBuf) else IOBuf(buf)
68-
return cdeserialize_type(type_info, iobuf._this, protocol)
99+
return type_info.to_python_value(
100+
cdeserialize_type(
101+
deref(getCTypeInfo(type_info)),
102+
iobuf._this,
103+
protocol,
104+
)
105+
)

thrift/lib/python/any/test/any_registry.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@
3232
)
3333

3434
TEST_PRIMITIVES = [
35+
True,
3536
42,
3637
123456.789,
38+
"thrift-python",
39+
b"raw bytes",
3740
]
3841

3942

@@ -81,7 +84,9 @@ def test_primitive_round_trip(self) -> None:
8184
)
8285
loaded = registry.load(any_obj)
8386
if isinstance(loaded, float):
84-
self.assertAlmostEqual(primitive, loaded, places=3)
87+
self.assertAlmostEqual(
88+
float(primitive), float(loaded), places=3
89+
)
8590
else:
8691
self.assertEqual(primitive, loaded)
8792

thrift/lib/python/any/test/serializer.py

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import typing
1920
import unittest
2021

2122
from apache.thrift.type.standard.thrift_types import TypeName, Void
@@ -24,42 +25,59 @@
2425
from thrift.python.any.serializer import deserialize_primitive, serialize_primitive
2526

2627

28+
if typing.TYPE_CHECKING:
29+
from thrift.python.any.serializer import PrimitiveType
30+
31+
2732
class SerializerTests(unittest.TestCase):
33+
def _test_round_trip(
34+
self, value: PrimitiveType, thrift_type: typing.Optional[Type] = None
35+
) -> None:
36+
iobuf = serialize_primitive(value, thrift_type=thrift_type)
37+
decoded = deserialize_primitive(type(value), iobuf, thrift_type=thrift_type)
38+
if isinstance(value, float):
39+
self.assertAlmostEqual(value, float(decoded), places=3)
40+
else:
41+
self.assertEqual(value, decoded)
42+
43+
def test_bool_round_trip(self) -> None:
44+
self._test_round_trip(True)
45+
2846
def test_int_round_trip(self) -> None:
29-
i = 42
30-
iobuf = serialize_primitive(i)
31-
decoded = deserialize_primitive(int, iobuf)
32-
self.assertEqual(i, decoded)
47+
self._test_round_trip(42)
3348

34-
def test_int_round_trip_with_type_name(self) -> None:
35-
i = 42
36-
for type_name in [
37-
TypeName(i16Type=Void.Unused),
38-
TypeName(i32Type=Void.Unused),
39-
TypeName(i64Type=Void.Unused),
40-
]:
49+
def test_float_round_trip(self) -> None:
50+
self._test_round_trip(123456.789)
51+
52+
def test_str_round_trip(self) -> None:
53+
self._test_round_trip("thrift-python")
54+
55+
def test_bytes_round_trip(self) -> None:
56+
self._test_round_trip(b"raw bytes")
57+
58+
def _test_round_trip_with_type_names(
59+
self, value: PrimitiveType, type_names: typing.Sequence[TypeName]
60+
) -> None:
61+
for type_name in type_names:
4162
with self.subTest(type_name=type_name):
42-
iobuf = serialize_primitive(i, thrift_type=Type(name=type_name))
43-
decoded = deserialize_primitive(
44-
int, iobuf, thrift_type=Type(name=type_name)
45-
)
46-
self.assertEqual(i, decoded)
63+
self._test_round_trip(value, thrift_type=Type(name=type_name))
4764

48-
def test_float_round_trip(self) -> None:
49-
f = 123456.789
50-
iobuf = serialize_primitive(f)
51-
decoded = deserialize_primitive(float, iobuf)
52-
self.assertAlmostEqual(f, decoded, delta=0.001)
65+
def test_int_round_trip_with_type_name(self) -> None:
66+
self._test_round_trip_with_type_names(
67+
42,
68+
[
69+
TypeName(byteType=Void.Unused),
70+
TypeName(i16Type=Void.Unused),
71+
TypeName(i32Type=Void.Unused),
72+
TypeName(i64Type=Void.Unused),
73+
],
74+
)
5375

5476
def test_float_round_trip_with_type_name(self) -> None:
55-
f = 123456.789
56-
for type_name in [
57-
TypeName(floatType=Void.Unused),
58-
TypeName(doubleType=Void.Unused),
59-
]:
60-
with self.subTest(type_name=type_name):
61-
iobuf = serialize_primitive(f, thrift_type=Type(name=type_name))
62-
decoded = deserialize_primitive(
63-
float, iobuf, thrift_type=Type(name=type_name)
64-
)
65-
self.assertAlmostEqual(f, decoded, delta=0.001)
77+
self._test_round_trip_with_type_names(
78+
123456.789,
79+
[
80+
TypeName(floatType=Void.Unused),
81+
TypeName(doubleType=Void.Unused),
82+
],
83+
)

thrift/lib/python/types.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,5 @@ cdef void set_struct_field(tuple struct_tuple, int16_t index, value) except *
178178

179179
cdef class ServiceInterface:
180180
pass
181+
182+
cdef const cTypeInfo* getCTypeInfo(type_info)

0 commit comments

Comments
 (0)