Skip to content

Commit 25c08fa

Browse files
kaviramrwgk
authored andcommitted
Modify function.py to support standard containers.
Modify utils.py to add more tests for when cpp_exact_type is faulty. C++ functions with return parameters require the code generator to generate a C++ lambda function. PiperOrigin-RevId: 363927286
1 parent c39aa51 commit 25c08fa

File tree

3 files changed

+106
-43
lines changed

3 files changed

+106
-43
lines changed

clif/pybind11/function.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ def generate_from(module_name: str, func_decl: ast_pb2.FuncDecl,
3838
pybind11 function bindings code.
3939
"""
4040

41+
if len(func_decl.returns) >= 2 or (len(func_decl.returns) >= 1 and
42+
func_decl.cpp_void_return):
43+
yield from _generate_return_args_lambda(func_decl, module_name)
44+
return
45+
4146
cpp_lambda_return_type = _has_bytes_return(func_decl)
4247
if cpp_lambda_return_type:
4348
yield from _generate_cpp_lambda(func_decl, cpp_lambda_return_type,
@@ -86,7 +91,7 @@ def _generate_cpp_function_cast(func_decl: ast_pb2.FuncDecl,
8691
params_list_types = []
8792
for param in func_decl.params:
8893
if param.HasField('cpp_exact_type'):
89-
if utils.is_nested_template(param.cpp_exact_type):
94+
if not utils.is_usable_cpp_exact_type(param.cpp_exact_type):
9095
params_list_types.append(param.type.cpp_type)
9196
else:
9297
params_list_types.append(param.cpp_exact_type)
@@ -156,7 +161,7 @@ def get_params_strings(func: ast_pb2.FuncDecl):
156161
lang_types.append(param.type.lang_type)
157162
cpp_names.append(param.name.cpp_name)
158163
params_str_with_types_list.append(
159-
f'{param.type.lang_type} {param.name.cpp_name}')
164+
f'{param.type.cpp_type} {param.name.cpp_name}')
160165
if param.default_value:
161166
default_values.append(
162167
f'py::arg("{param.name.cpp_name}") = {param.default_value}')
@@ -202,3 +207,39 @@ def _generate_cpp_lambda(func_decl: ast_pb2.FuncDecl, return_type: str,
202207
yield I + I + '}'
203208
yield I + ');'
204209
return
210+
211+
212+
def _generate_return_args_lambda(func_decl: ast_pb2.FuncDecl, module_name: str):
213+
"""Generates C++ lambda functions with return parameters."""
214+
params_strings = get_params_strings(func_decl)
215+
yield (f'{module_name}.def("{func_decl.name.native}",'
216+
f'[]({params_strings.names_with_types}) {{')
217+
218+
main_return = ''
219+
main_return_cpp_name = ''
220+
other_returns_cpp_names = []
221+
for i, r in enumerate(func_decl.returns):
222+
if i == 0:
223+
main_return_cpp_name = r.name.cpp_name
224+
main_return = f'{r.type.cpp_type} {r.name.cpp_name}'
225+
continue
226+
other_returns_cpp_names.append(r.name.cpp_name)
227+
yield I + f'{r.type.cpp_type} {r.name.cpp_name};'
228+
other_returns = ', '.join(other_returns_cpp_names)
229+
other_returns_params_list = [f'&{r}' for r in other_returns_cpp_names]
230+
231+
if not func_decl.cpp_void_return:
232+
yield I + (f'{main_return} = {func_decl.name.cpp_name}'
233+
f'({params_strings.cpp_names}, &y);')
234+
yield I + f'return std::make_tuple({main_return_cpp_name}, {other_returns});'
235+
else:
236+
yield I + f'{main_return};'
237+
if not other_returns_cpp_names:
238+
yield I + (f'{func_decl.name.cpp_name}({params_strings.cpp_names},'
239+
f'&{main_return_cpp_name});')
240+
yield I + f'return {main_return_cpp_name};'
241+
else:
242+
yield I + (f'{func_decl.name.cpp_name}({params_strings.cpp_names},'
243+
f'&{main_return_cpp_name}, {other_returns_params_list});')
244+
yield I + f'return std::make_tuple({main_return_cpp_name}, {other_returns});'
245+
yield '});'

clif/pybind11/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ def is_nested_template(s: str) -> bool:
6565
return '>>' in s
6666

6767

68+
def is_usable_cpp_exact_type(s: str) -> bool:
69+
return not is_nested_template(s) or '&' in s
70+
71+
6872
# Dataclass used to group together and pass around parameter lists.
6973
@dataclasses.dataclass
7074
class ParamsStrings:

clif/testing/python/std_containers_test.py

Lines changed: 59 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -14,72 +14,90 @@
1414

1515
"""Tests for clif.testing.python.std_containers."""
1616

17-
import unittest
18-
from clif.testing.python import std_containers
19-
20-
21-
class StdContainersTest(unittest.TestCase):
22-
23-
def testVector(self):
24-
self.assertEqual(std_containers.Mul([1, 2, 3], 2), [2, 4, 6])
25-
26-
def testArray(self):
27-
self.assertEqual(std_containers.Div([2, 4], 2), [1, 2])
17+
from absl.testing import absltest
18+
from absl.testing import parameterized
2819

20+
from clif.testing.python import std_containers
21+
# TODO: Restore simple import after OSS setup includes pybind11.
22+
# pylint: disable=g-import-not-at-top
23+
try:
24+
from clif.testing.python import std_containers_pybind11
25+
except ImportError:
26+
std_containers_pybind11 = None
27+
# pylint: enable=g-import-not-at-top
28+
29+
30+
@parameterized.named_parameters([
31+
np for np in zip(('c_api', 'pybind11'), (std_containers,
32+
std_containers_pybind11))
33+
if np[1] is not None
34+
])
35+
class StdContainersTest(absltest.TestCase):
36+
37+
def testVector(self, wrapper_lib):
38+
self.assertEqual(wrapper_lib.Mul([1, 2, 3], 2), [2, 4, 6])
39+
40+
def testArray(self, wrapper_lib):
41+
self.assertEqual(wrapper_lib.Div([2, 4], 2), [1, 2])
42+
43+
error = ValueError
44+
if wrapper_lib is std_containers_pybind11:
45+
error = TypeError
2946
# Exceed bounds of array.
30-
self.assertRaises(ValueError, lambda: std_containers.Div([2, 4, 6], 2))
47+
self.assertRaises(error, lambda: wrapper_lib.Div([2, 4, 6], 2))
3148

32-
def testVectorBool(self):
33-
self.assertEqual(std_containers.Odd([1, 2, 3]), [True, False, True])
34-
self.assertEqual(std_containers.Even([1, 2, 3]), [False, True, False])
49+
def testVectorBool(self, wrapper_lib):
50+
self.assertEqual(wrapper_lib.Odd([1, 2, 3]), [True, False, True])
51+
self.assertEqual(wrapper_lib.Even([1, 2, 3]), [False, True, False])
3552

36-
def testMap(self):
53+
def testMap(self, wrapper_lib):
3754
# pylint: disable=bad-whitespace
38-
self.assertEqual(std_containers.Find(0, {1: 2, 3: 4, 5: 0}), (True, 5))
39-
r = std_containers.Find(6, {1: 2, 3: 4, 5: 0})
55+
self.assertEqual(wrapper_lib.Find(0, {1: 2, 3: 4, 5: 0}), (True, 5))
56+
r = wrapper_lib.Find(6, {1: 2, 3: 4, 5: 0})
4057
self.assertIsInstance(r, tuple)
4158
self.assertIs(r[0], False)
4259

43-
def testOnes(self):
44-
ones = std_containers.Ones(4, 5)
45-
self.assertEqual(len(ones), 4)
60+
def testOnes(self, wrapper_lib):
61+
ones = wrapper_lib.Ones(4, 5)
62+
self.assertLen(ones, 4)
4663
for row in ones:
47-
self.assertEqual(len(row), 5)
64+
self.assertLen(row, 5)
4865
self.assertTrue(all(elem == 1 for elem in row))
4966

50-
def testCapitals(self):
51-
capitals = std_containers.Capitals()
52-
self.assertEqual(len(capitals), 3)
67+
def testCapitals(self, wrapper_lib):
68+
capitals = wrapper_lib.Capitals()
69+
self.assertLen(capitals, 3)
5370
self.assertEqual(capitals[0], ('CA', 'Sacramento'))
5471
self.assertEqual(capitals[1], ('OR', 'Salem'))
5572
self.assertEqual(capitals[2], ('WA', 'Olympia'))
5673

57-
def testMatrixSum(self):
74+
def testMatrixSum(self, wrapper_lib):
5875
a = [[1, 2, 3], [3, 2, 1]]
5976
b = [[4, 3, 2], [2, 3, 4]]
60-
s = std_containers.MatrixSum(a, b)
61-
self.assertEqual(len(s), 2)
77+
s = wrapper_lib.MatrixSum(a, b)
78+
self.assertLen(s, 2)
6279
for row in s:
63-
self.assertEqual(len(row), 3)
80+
self.assertLen(row, 3)
6481
self.assertTrue(all(elem == 5 for elem in row))
6582

66-
def testMake2By3(self):
67-
m = std_containers.Make2By3()
68-
self.assertEqual(len(m), 2)
83+
def testMake2By3(self, wrapper_lib):
84+
m = wrapper_lib.Make2By3()
85+
self.assertLen(m, 2)
6986
for c in m:
70-
self.assertEqual(len(c), 3)
87+
self.assertLen(c, 3)
7188

72-
def testFlatten2By3(self):
89+
def testFlatten2By3(self, wrapper_lib):
7390
m = ((1, 2, 3), (4, 5, 6))
74-
self.assertEqual(len(m), 2)
75-
t = std_containers.Flatten2By3(m)
76-
self.assertEqual(len(t), 6)
91+
self.assertLen(m, 2)
92+
t = wrapper_lib.Flatten2By3(m)
93+
self.assertLen(t, 6)
7794
for e, i in zip(t, range(1, 7)):
7895
self.assertEqual(e, i)
7996

80-
def testTakeVectorOfStrings(self):
81-
self.assertEqual(std_containers.LastStringInVector(['hello', 'world']),
82-
'world')
97+
def testTakeVectorOfStrings(self, wrapper_lib):
98+
self.assertEqual(
99+
wrapper_lib.LastStringInVector(['hello', 'world']), 'world')
100+
83101

84102
if __name__ == '__main__':
85-
unittest.main()
103+
absltest.main()

0 commit comments

Comments
 (0)