Skip to content

Commit 966901a

Browse files
committed
Added optimization for copying arrays of simple types from python to C
1 parent 5b9fe9c commit 966901a

File tree

1 file changed

+141
-110
lines changed

1 file changed

+141
-110
lines changed

rosidl_generator_py/resource/_msg_support.c.em

Lines changed: 141 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -248,156 +248,187 @@ nested_type = '__'.join(type_.namespaced_name())
248248
}
249249
@[ end if]@
250250
@[ elif isinstance(member.type, AbstractNestedType)]@
251-
@[ if isinstance(member.type, Array) and isinstance(member.type.value_type, BasicType) and member.type.value_type.typename in SPECIAL_NESTED_BASIC_TYPES]@
252-
// TODO(dirk-thomas) use a better way to check the type before casting
253-
assert(field->ob_type != NULL);
254-
assert(field->ob_type->tp_name != NULL);
255-
assert(strcmp(field->ob_type->tp_name, "numpy.ndarray") == 0);
256-
PyArrayObject * seq_field = (PyArrayObject *)field;
257-
Py_INCREF(seq_field);
258-
assert(PyArray_NDIM(seq_field) == 1);
259-
assert(PyArray_TYPE(seq_field) == @(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'].replace('numpy.', 'NPY_').upper()));
260-
@[ else]@
261-
PyObject * seq_field = PySequence_Fast(field, "expected a sequence in '@(member.name)'");
262-
if (!seq_field) {
263-
Py_DECREF(field);
264-
return false;
265-
}
266-
@[ end if]@
267-
@[ if isinstance(member.type, AbstractSequence)]@
268-
Py_ssize_t size = PySequence_Size(field);
269-
if (-1 == size) {
270-
Py_DECREF(seq_field);
271-
Py_DECREF(field);
272-
return false;
273-
}
274-
@[ if isinstance(member.type.value_type, AbstractString)]@
275-
if (!rosidl_runtime_c__String__Sequence__init(&(ros_message->@(member.name)), size)) {
276-
PyErr_SetString(PyExc_RuntimeError, "unable to create String__Sequence ros_message");
277-
Py_DECREF(seq_field);
278-
Py_DECREF(field);
279-
return false;
280-
}
281-
@[ elif isinstance(member.type.value_type, AbstractWString)]@
282-
if (!rosidl_runtime_c__U16String__Sequence__init(&(ros_message->@(member.name)), size)) {
283-
PyErr_SetString(PyExc_RuntimeError, "unable to create U16String__Sequence ros_message");
284-
Py_DECREF(seq_field);
285-
Py_DECREF(field);
286-
return false;
287-
}
288-
@[ else]@
289-
if (!rosidl_runtime_c__@(member.type.value_type.typename)__Sequence__init(&(ros_message->@(member.name)), size)) {
290-
PyErr_SetString(PyExc_RuntimeError, "unable to create @(member.type.value_type.typename)__Sequence ros_message");
291-
Py_DECREF(seq_field);
292-
Py_DECREF(field);
293-
return false;
294-
}
295-
@[ end if]@
296-
@primitive_msg_type_to_c(member.type.value_type) * dest = ros_message->@(member.name).data;
297-
@[ else]@
298-
Py_ssize_t size = @(member.type.size);
299-
@primitive_msg_type_to_c(member.type.value_type) * dest = ros_message->@(member.name);
300-
@[ end if]@
301-
for (Py_ssize_t i = 0; i < size; ++i) {
302-
@[ if not isinstance(member.type, Array) or not isinstance(member.type.value_type, BasicType) or member.type.value_type.typename not in SPECIAL_NESTED_BASIC_TYPES]@
303-
PyObject * item = PySequence_Fast_GET_ITEM(seq_field, i);
304-
if (!item) {
305-
Py_DECREF(seq_field);
251+
@[ if isinstance(member.type, AbstractSequence) and isinstance(member.type.value_type, BasicType)]@
252+
if (PyObject_CheckBuffer(field)) {
253+
// Optimization for converting arrays of primitives
254+
Py_buffer view;
255+
int rc = PyObject_GetBuffer(field, &view, PyBUF_SIMPLE);
256+
if (rc) {
257+
PyErr_SetString(PyExc_RuntimeError, "unable to get buffer");
306258
Py_DECREF(field);
307259
return false;
308260
}
261+
Py_ssize_t size = view.len / sizeof(@primitive_msg_type_to_c(member.type.value_type));
262+
if (!rosidl_runtime_c__@(member.type.value_type.typename)__Sequence__init(&(ros_message->@(member.name)), size)) {
263+
PyErr_SetString(PyExc_RuntimeError, "unable to create @(member.type.value_type.typename)__Sequence ros_message");
264+
PyBuffer_Release(&view);
265+
Py_DECREF(field);
266+
return false;
267+
}
268+
@primitive_msg_type_to_c(member.type.value_type) * dest = ros_message->@(member.name).data;
269+
rc = PyBuffer_ToContiguous(dest, &view, view.len, 'C');
270+
if (rc) {
271+
PyErr_SetString(PyExc_RuntimeError, "unable to copy buffer");
272+
PyBuffer_Release(&view);
273+
Py_DECREF(field);
274+
return false;
275+
}
276+
PyBuffer_Release(&view);
277+
} else {
278+
@[ else]@
279+
{
309280
@[ end if]@
310281
@[ if isinstance(member.type, Array) and isinstance(member.type.value_type, BasicType) and member.type.value_type.typename in SPECIAL_NESTED_BASIC_TYPES]@
311-
@primitive_msg_type_to_c(member.type.value_type) tmp = *(@(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'].replace('numpy.', 'npy_')) *)PyArray_GETPTR1(seq_field, i);
312-
@[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'char']@
313-
assert(PyUnicode_Check(item));
314-
PyObject * encoded_item = PyUnicode_AsUTF8String(item);
315-
if (!encoded_item) {
316-
Py_DECREF(seq_field);
282+
// TODO(dirk-thomas) use a better way to check the type before casting
283+
assert(field->ob_type != NULL);
284+
assert(field->ob_type->tp_name != NULL);
285+
assert(strcmp(field->ob_type->tp_name, "numpy.ndarray") == 0);
286+
PyArrayObject * seq_field = (PyArrayObject *)field;
287+
Py_INCREF(seq_field);
288+
assert(PyArray_NDIM(seq_field) == 1);
289+
assert(PyArray_TYPE(seq_field) == @(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'].replace('numpy.', 'NPY_').upper()));
290+
@[ else]@
291+
PyObject * seq_field = PySequence_Fast(field, "expected a sequence in '@(member.name)'");
292+
if (!seq_field) {
317293
Py_DECREF(field);
318294
return false;
319295
}
320-
@primitive_msg_type_to_c(member.type.value_type) tmp = PyBytes_AS_STRING(encoded_item)[0];
321-
Py_DECREF(encoded_item);
322-
@[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'octet']@
323-
assert(PyBytes_Check(item));
324-
@primitive_msg_type_to_c(member.type.value_type) tmp = PyBytes_AS_STRING(item)[0];
325-
@[ elif isinstance(member.type.value_type, AbstractString)]@
326-
assert(PyUnicode_Check(item));
327-
PyObject * encoded_item = PyUnicode_AsUTF8String(item);
328-
if (!encoded_item) {
296+
@[ end if]@
297+
@[ if isinstance(member.type, AbstractSequence)]@
298+
Py_ssize_t size = PySequence_Size(field);
299+
if (-1 == size) {
329300
Py_DECREF(seq_field);
330301
Py_DECREF(field);
331302
return false;
332303
}
333-
rosidl_runtime_c__String__assign(&dest[i], PyBytes_AS_STRING(encoded_item));
334-
Py_DECREF(encoded_item);
335-
@[ elif isinstance(member.type.value_type, AbstractWString)]@
336-
assert(PyUnicode_Check(item));
337-
// the returned string starts with a BOM mark and uses native byte order
338-
PyObject * encoded_item = PyUnicode_AsUTF16String(item);
339-
if (!encoded_item) {
304+
@[ if isinstance(member.type.value_type, AbstractString)]@
305+
if (!rosidl_runtime_c__String__Sequence__init(&(ros_message->@(member.name)), size)) {
306+
PyErr_SetString(PyExc_RuntimeError, "unable to create String__Sequence ros_message");
340307
Py_DECREF(seq_field);
341308
Py_DECREF(field);
342309
return false;
343310
}
344-
char * buffer;
345-
Py_ssize_t length;
346-
int rc = PyBytes_AsStringAndSize(encoded_item, &buffer, &length);
347-
if (rc) {
348-
Py_DECREF(encoded_item);
311+
@[ elif isinstance(member.type.value_type, AbstractWString)]@
312+
if (!rosidl_runtime_c__U16String__Sequence__init(&(ros_message->@(member.name)), size)) {
313+
PyErr_SetString(PyExc_RuntimeError, "unable to create U16String__Sequence ros_message");
349314
Py_DECREF(seq_field);
350315
Py_DECREF(field);
351316
return false;
352317
}
353-
// use offset of 2 to skip BOM mark
354-
bool succeeded = rosidl_runtime_c__U16String__assignn_from_char(&dest[i], buffer + 2, length - 2);
355-
Py_DECREF(encoded_item);
356-
if (!succeeded) {
318+
@[ else]@
319+
if (!rosidl_runtime_c__@(member.type.value_type.typename)__Sequence__init(&(ros_message->@(member.name)), size)) {
320+
PyErr_SetString(PyExc_RuntimeError, "unable to create @(member.type.value_type.typename)__Sequence ros_message");
357321
Py_DECREF(seq_field);
358322
Py_DECREF(field);
359323
return false;
360324
}
325+
@[ end if]@
326+
@primitive_msg_type_to_c(member.type.value_type) * dest = ros_message->@(member.name).data;
327+
@[ else]@
328+
Py_ssize_t size = @(member.type.size);
329+
@primitive_msg_type_to_c(member.type.value_type) * dest = ros_message->@(member.name);
330+
@[ end if]@
331+
for (Py_ssize_t i = 0; i < size; ++i) {
332+
@[ if not isinstance(member.type, Array) or not isinstance(member.type.value_type, BasicType) or member.type.value_type.typename not in SPECIAL_NESTED_BASIC_TYPES]@
333+
PyObject * item = PySequence_Fast_GET_ITEM(seq_field, i);
334+
if (!item) {
335+
Py_DECREF(seq_field);
336+
Py_DECREF(field);
337+
return false;
338+
}
339+
@[ end if]@
340+
@[ if isinstance(member.type, Array) and isinstance(member.type.value_type, BasicType) and member.type.value_type.typename in SPECIAL_NESTED_BASIC_TYPES]@
341+
@primitive_msg_type_to_c(member.type.value_type) tmp = *(@(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'].replace('numpy.', 'npy_')) *)PyArray_GETPTR1(seq_field, i);
342+
@[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'char']@
343+
assert(PyUnicode_Check(item));
344+
PyObject * encoded_item = PyUnicode_AsUTF8String(item);
345+
if (!encoded_item) {
346+
Py_DECREF(seq_field);
347+
Py_DECREF(field);
348+
return false;
349+
}
350+
@primitive_msg_type_to_c(member.type.value_type) tmp = PyBytes_AS_STRING(encoded_item)[0];
351+
Py_DECREF(encoded_item);
352+
@[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'octet']@
353+
assert(PyBytes_Check(item));
354+
@primitive_msg_type_to_c(member.type.value_type) tmp = PyBytes_AS_STRING(item)[0];
355+
@[ elif isinstance(member.type.value_type, AbstractString)]@
356+
assert(PyUnicode_Check(item));
357+
PyObject * encoded_item = PyUnicode_AsUTF8String(item);
358+
if (!encoded_item) {
359+
Py_DECREF(seq_field);
360+
Py_DECREF(field);
361+
return false;
362+
}
363+
rosidl_runtime_c__String__assign(&dest[i], PyBytes_AS_STRING(encoded_item));
364+
Py_DECREF(encoded_item);
365+
@[ elif isinstance(member.type.value_type, AbstractWString)]@
366+
assert(PyUnicode_Check(item));
367+
// the returned string starts with a BOM mark and uses native byte order
368+
PyObject * encoded_item = PyUnicode_AsUTF16String(item);
369+
if (!encoded_item) {
370+
Py_DECREF(seq_field);
371+
Py_DECREF(field);
372+
return false;
373+
}
374+
char * buffer;
375+
Py_ssize_t length;
376+
int rc = PyBytes_AsStringAndSize(encoded_item, &buffer, &length);
377+
if (rc) {
378+
Py_DECREF(encoded_item);
379+
Py_DECREF(seq_field);
380+
Py_DECREF(field);
381+
return false;
382+
}
383+
// use offset of 2 to skip BOM mark
384+
bool succeeded = rosidl_runtime_c__U16String__assignn_from_char(&dest[i], buffer + 2, length - 2);
385+
Py_DECREF(encoded_item);
386+
if (!succeeded) {
387+
Py_DECREF(seq_field);
388+
Py_DECREF(field);
389+
return false;
390+
}
361391
@[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'boolean']@
362-
assert(PyBool_Check(item));
363-
@primitive_msg_type_to_c(member.type.value_type) tmp = (item == Py_True);
392+
assert(PyBool_Check(item));
393+
@primitive_msg_type_to_c(member.type.value_type) tmp = (item == Py_True);
364394
@[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename in ('float', 'double')]@
365-
assert(PyFloat_Check(item));
395+
assert(PyFloat_Check(item));
366396
@[ if member.type.value_type.typename == 'float']@
367-
@primitive_msg_type_to_c(member.type.value_type) tmp = (float)PyFloat_AS_DOUBLE(item);
397+
@primitive_msg_type_to_c(member.type.value_type) tmp = (float)PyFloat_AS_DOUBLE(item);
368398
@[ else]@
369-
@primitive_msg_type_to_c(member.type.value_type) tmp = PyFloat_AS_DOUBLE(item);
399+
@primitive_msg_type_to_c(member.type.value_type) tmp = PyFloat_AS_DOUBLE(item);
370400
@[ end if]@
371401
@[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename in (
372-
'int8',
373-
'int16',
374-
'int32',
375-
)]@
376-
assert(PyLong_Check(item));
377-
@primitive_msg_type_to_c(member.type.value_type) tmp = (@(primitive_msg_type_to_c(member.type.value_type)))PyLong_AsLong(item);
402+
'int8',
403+
'int16',
404+
'int32',
405+
)]@
406+
assert(PyLong_Check(item));
407+
@primitive_msg_type_to_c(member.type.value_type) tmp = (@(primitive_msg_type_to_c(member.type.value_type)))PyLong_AsLong(item);
378408
@[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename in (
379-
'uint8',
380-
'uint16',
381-
'uint32',
382-
)]@
383-
assert(PyLong_Check(item));
409+
'uint8',
410+
'uint16',
411+
'uint32',
412+
)]@
413+
assert(PyLong_Check(item));
384414
@[ if isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'uint32']@
385-
@primitive_msg_type_to_c(member.type.value_type) tmp = PyLong_AsUnsignedLong(item);
415+
@primitive_msg_type_to_c(member.type.value_type) tmp = PyLong_AsUnsignedLong(item);
386416
@[ else]@
387-
@primitive_msg_type_to_c(member.type.value_type) tmp = (@(primitive_msg_type_to_c(member.type.value_type)))PyLong_AsUnsignedLong(item);
417+
@primitive_msg_type_to_c(member.type.value_type) tmp = (@(primitive_msg_type_to_c(member.type.value_type)))PyLong_AsUnsignedLong(item);
388418
@[ end if]
389419
@[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'int64']@
390-
assert(PyLong_Check(item));
391-
@primitive_msg_type_to_c(member.type.value_type) tmp = PyLong_AsLongLong(item);
420+
assert(PyLong_Check(item));
421+
@primitive_msg_type_to_c(member.type.value_type) tmp = PyLong_AsLongLong(item);
392422
@[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'uint64']@
393-
assert(PyLong_Check(item));
394-
@primitive_msg_type_to_c(member.type.value_type) tmp = PyLong_AsUnsignedLongLong(item);
423+
assert(PyLong_Check(item));
424+
@primitive_msg_type_to_c(member.type.value_type) tmp = PyLong_AsUnsignedLongLong(item);
395425
@[ end if]@
396426
@[ if isinstance(member.type.value_type, BasicType)]@
397-
memcpy(&dest[i], &tmp, sizeof(@primitive_msg_type_to_c(member.type.value_type)));
427+
memcpy(&dest[i], &tmp, sizeof(@primitive_msg_type_to_c(member.type.value_type)));
398428
@[ end if]@
429+
}
430+
Py_DECREF(seq_field);
399431
}
400-
Py_DECREF(seq_field);
401432
@[ elif isinstance(member.type, BasicType) and member.type.typename == 'char']@
402433
assert(PyUnicode_Check(field));
403434
PyObject * encoded_field = PyUnicode_AsUTF8String(field);

0 commit comments

Comments
 (0)