"""Tests for cast.""" import time from absl import logging from google3.experimental.users.jblespiau.pybind11 import cast from google3.testing.pybase import googletest class CastTest(googletest.TestCase): def test_give_me_a_name(self): # import pdb # pdb.set_trace() num_iterations = 1000000 logging.info("Creating %d objects from Python", num_iterations) start = time.time() for _ in range(num_iterations): obj = cast.MakeJaxCompiledFunction(1) end = time.time() logging.info("Raw API took %.2fs", end - start) start = time.time() for _ in range(num_iterations): obj = cast.MakeJaxCompiledFunctionPybind11Fast(1) end = time.time() logging.info("Pybind11 FAST took %.2fs", end - start) start = time.time() for _ in range(num_iterations): obj = cast.JaxCompiledFunction(1) end = time.time() logging.info("Pybind11 took %.2fs", end - start) logging.info("") logging.info("Accessing the C++ object from the Python wrapped object") c_api_obj = cast.MakeJaxCompiledFunction(1) start = time.time() for _ in range(num_iterations): cast.AccessFieldPython(c_api_obj) end = time.time() logging.info("Raw C API cast took %.2fs", end - start) pybind11_obj = cast.JaxCompiledFunction(1) start = time.time() for _ in range(num_iterations): cast.AccessFieldPybind11Fast(pybind11_obj) end = time.time() logging.info("Pybind11Fast took %.2fs", end - start) pybind11_obj = cast.JaxCompiledFunction(1) start = time.time() for _ in range(num_iterations): cast.AccessFieldPybind11(pybind11_obj) end = time.time() logging.info("Pybind11 took %.2fs", end - start) logging.info("") logging.info("Accessing many C++ objects from the Python wrapped objects") c_api_objs = [cast.MakeJaxCompiledFunction(i) for i in range(100)] start = time.time() for _ in range(num_iterations): cast.AccessFieldsPython(c_api_objs) end = time.time() logging.info("Raw C API cast took %.2fs", end - start) pybind11_objs = [cast.JaxCompiledFunction(i) for i in range(100)] start = time.time() for _ in range(num_iterations): cast.AccessFieldsPybind11Fast(pybind11_objs) end = time.time() logging.info("Pybind11 Fast took %.2fs", end - start) pybind11_objs = [cast.JaxCompiledFunction(i) for i in range(100)] start = time.time() for _ in range(num_iterations): cast.AccessFieldsPybind11(pybind11_objs) end = time.time() logging.info("Pybind11 took %.2fs", end - start) logging.info("") logging.info("Creating many C++ objects and returning them to Python") num_iterations = 100 start = time.time() for _ in range(num_iterations): cast.CreateManyObjectsPython(100000) end = time.time() logging.info("Raw C API cast took %.2fs", end - start) start = time.time() for _ in range(num_iterations): cast.CreateManyObjectsPybind11Fast(100000) end = time.time() logging.info("Pybind11 FAST took %.2fs", end - start) start = time.time() for _ in range(num_iterations): cast.CreateManyObjectsPybind11(100000) end = time.time() logging.info("Pybind11 took %.2fs", end - start) import pdb pdb.set_trace() if __name__ == "__main__": googletest.main()