Skip to content

Commit 5222b83

Browse files
committed
HPyArray_AssignArray
1 parent a5ce64f commit 5222b83

File tree

6 files changed

+241
-8
lines changed

6 files changed

+241
-8
lines changed

numpy/core/src/common/array_assign.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ PyArray_AssignArray(PyArrayObject *dst, PyArrayObject *src,
3030
PyArrayObject *wheremask,
3131
NPY_CASTING casting);
3232

33+
34+
/*
 * HPy flavour of PyArray_AssignArray: the destination, source and optional
 * where-mask arrays are passed as HPy handles instead of PyArrayObject
 * pointers. Callers that have no mask pass HPy_NULL for 'h_wheremask'.
 *
 * Returns 0 on success and -1 on failure.
 */
NPY_NO_EXPORT int
35+
HPyArray_AssignArray(HPyContext *ctx, HPy h_dst, HPy h_src,
36+
HPy h_wheremask,
37+
NPY_CASTING casting);
38+
3339
NPY_NO_EXPORT int
3440
PyArray_AssignRawScalar(PyArrayObject *dst,
3541
PyArray_Descr *src_dtype, char *src_data,

numpy/core/src/multiarray/array_assign_array.c

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,3 +438,192 @@ PyArray_AssignArray(PyArrayObject *dst, PyArrayObject *src,
438438
}
439439
return -1;
440440
}
441+
442+
NPY_NO_EXPORT int
443+
HPyArray_AssignArray(HPyContext *ctx, HPy h_dst, HPy h_src,
444+
HPy h_wheremask,
445+
NPY_CASTING casting)
446+
{
447+
PyArrayObject *src = PyArrayObject_AsStruct(ctx, h_src);
448+
PyArrayObject *dst = PyArrayObject_AsStruct(ctx, h_dst);
449+
int copied_src = 0;
450+
451+
npy_intp src_strides[NPY_MAXDIMS];
452+
453+
/* Use array_assign_scalar if 'src' NDIM is 0 */
454+
if (PyArray_NDIM(src) == 0) {
455+
capi_warn("HPyArray_AssignArray: PyArray_AssignRawScalar");
456+
return PyArray_AssignRawScalar(
457+
dst, PyArray_DESCR(src), PyArray_DATA(src),
458+
PyArrayObject_AsStruct(ctx, h_wheremask), casting);
459+
}
460+
461+
HPy h_src_descr = HPyArray_DESCR(ctx, h_src, src);
462+
HPy h_dst_descr = HPyArray_DESCR(ctx, h_dst, dst);
463+
/*
464+
* Performance fix for expressions like "a[1000:6000] += x". In this
465+
* case, first an in-place add is done, followed by an assignment,
466+
* equivalently expressed like this:
467+
*
468+
* tmp = a[1000:6000] # Calls array_subscript in mapping.c
469+
* np.add(tmp, x, tmp)
470+
* a[1000:6000] = tmp # Calls array_assign_subscript in mapping.c
471+
*
472+
* In the assignment the underlying data type, shape, strides, and
473+
* data pointers are identical, but src != dst because they are separately
474+
* generated slices. By detecting this and skipping the redundant
475+
* copy of values to themselves, we potentially give a big speed boost.
476+
*
477+
* Note that we don't call EquivTypes, because usually the exact same
478+
* dtype object will appear, and we don't want to slow things down
479+
* with a complicated comparison. The comparisons are ordered to
480+
* try and reject this with as little work as possible.
481+
*/
482+
if (PyArray_DATA(src) == PyArray_DATA(dst) &&
483+
HPy_Is(ctx, h_src_descr, h_dst_descr) &&
484+
PyArray_NDIM(src) == PyArray_NDIM(dst) &&
485+
PyArray_CompareLists(PyArray_DIMS(src),
486+
PyArray_DIMS(dst),
487+
PyArray_NDIM(src)) &&
488+
PyArray_CompareLists(PyArray_STRIDES(src),
489+
PyArray_STRIDES(dst),
490+
PyArray_NDIM(src))) {
491+
/*printf("Redundant copy operation detected\n");*/
492+
return 0;
493+
}
494+
495+
if (PyArray_FailUnlessWriteable(dst, "assignment destination") < 0) {
496+
goto fail;
497+
}
498+
499+
/* Check the casting rule */
500+
if (!HPyArray_CanCastTypeTo(ctx, h_src_descr,
501+
h_dst_descr, casting)) {
502+
capi_warn("HPyArray_AssignArray: npy_set_invalid_cast_error");
503+
npy_set_invalid_cast_error(
504+
PyArray_DESCR(src), PyArray_DESCR(dst), casting, NPY_FALSE);
505+
goto fail;
506+
}
507+
508+
/*
509+
* When ndim is 1 and the strides point in the same direction,
510+
* the lower-level inner loop handles copying
511+
* of overlapping data. For bigger ndim and opposite-strided 1D
512+
* data, we make a temporary copy of 'src' if 'src' and 'dst' overlap.'
513+
*/
514+
capi_warn("HPyArray_AssignArray: arrays_overlap and reminder of this function...");
515+
if (((PyArray_NDIM(dst) == 1 && PyArray_NDIM(src) >= 1 &&
516+
PyArray_STRIDES(dst)[0] *
517+
PyArray_STRIDES(src)[PyArray_NDIM(src) - 1] < 0) ||
518+
PyArray_NDIM(dst) > 1 || PyArray_HASFIELDS(dst)) &&
519+
arrays_overlap(src, dst)) {
520+
PyArrayObject *tmp;
521+
522+
/*
523+
* Allocate a temporary copy array.
524+
*/
525+
tmp = (PyArrayObject *)PyArray_NewLikeArray(dst,
526+
NPY_KEEPORDER, NULL, 0);
527+
if (tmp == NULL) {
528+
goto fail;
529+
}
530+
531+
if (PyArray_AssignArray(tmp, src, NULL, NPY_UNSAFE_CASTING) < 0) {
532+
Py_DECREF(tmp);
533+
goto fail;
534+
}
535+
536+
src = tmp;
537+
copied_src = 1;
538+
}
539+
540+
/* Broadcast 'src' to 'dst' for raw iteration */
541+
if (PyArray_NDIM(src) > PyArray_NDIM(dst)) {
542+
int ndim_tmp = PyArray_NDIM(src);
543+
npy_intp *src_shape_tmp = PyArray_DIMS(src);
544+
npy_intp *src_strides_tmp = PyArray_STRIDES(src);
545+
/*
546+
* As a special case for backwards compatibility, strip
547+
* away unit dimensions from the left of 'src'
548+
*/
549+
while (ndim_tmp > PyArray_NDIM(dst) && src_shape_tmp[0] == 1) {
550+
--ndim_tmp;
551+
++src_shape_tmp;
552+
++src_strides_tmp;
553+
}
554+
555+
if (broadcast_strides(PyArray_NDIM(dst), PyArray_DIMS(dst),
556+
ndim_tmp, src_shape_tmp,
557+
src_strides_tmp, "input array",
558+
src_strides) < 0) {
559+
goto fail;
560+
}
561+
}
562+
else {
563+
if (broadcast_strides(PyArray_NDIM(dst), PyArray_DIMS(dst),
564+
PyArray_NDIM(src), PyArray_DIMS(src),
565+
PyArray_STRIDES(src), "input array",
566+
src_strides) < 0) {
567+
goto fail;
568+
}
569+
}
570+
571+
PyArrayObject *wheremask = PyArrayObject_AsStruct(ctx, h_wheremask);
572+
/* optimization: scalar boolean mask */
573+
if (wheremask != NULL &&
574+
PyArray_NDIM(wheremask) == 0 &&
575+
PyArray_DESCR(wheremask)->type_num == NPY_BOOL) {
576+
npy_bool value = *(npy_bool *)PyArray_DATA(wheremask);
577+
if (value) {
578+
/* where=True is the same as no where at all */
579+
wheremask = NULL;
580+
}
581+
else {
582+
/* where=False copies nothing */
583+
return 0;
584+
}
585+
}
586+
587+
if (wheremask == NULL) {
588+
/* A straightforward value assignment */
589+
/* Do the assignment with raw array iteration */
590+
if (raw_array_assign_array(PyArray_NDIM(dst), PyArray_DIMS(dst),
591+
PyArray_DESCR(dst), PyArray_DATA(dst), PyArray_STRIDES(dst),
592+
PyArray_DESCR(src), PyArray_DATA(src), src_strides) < 0) {
593+
goto fail;
594+
}
595+
}
596+
else {
597+
npy_intp wheremask_strides[NPY_MAXDIMS];
598+
599+
/* Broadcast the wheremask to 'dst' for raw iteration */
600+
if (broadcast_strides(PyArray_NDIM(dst), PyArray_DIMS(dst),
601+
PyArray_NDIM(wheremask), PyArray_DIMS(wheremask),
602+
PyArray_STRIDES(wheremask), "where mask",
603+
wheremask_strides) < 0) {
604+
goto fail;
605+
}
606+
607+
/* A straightforward where-masked assignment */
608+
/* Do the masked assignment with raw array iteration */
609+
if (raw_array_wheremasked_assign_array(
610+
PyArray_NDIM(dst), PyArray_DIMS(dst),
611+
PyArray_DESCR(dst), PyArray_DATA(dst), PyArray_STRIDES(dst),
612+
PyArray_DESCR(src), PyArray_DATA(src), src_strides,
613+
PyArray_DESCR(wheremask), PyArray_DATA(wheremask),
614+
wheremask_strides) < 0) {
615+
goto fail;
616+
}
617+
}
618+
619+
if (copied_src) {
620+
Py_DECREF(src);
621+
}
622+
return 0;
623+
624+
fail:
625+
if (copied_src) {
626+
Py_DECREF(src);
627+
}
628+
return -1;
629+
}

numpy/core/src/multiarray/arrayobject.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,7 @@ array_might_be_written(PyArrayObject *obj)
728728
"overlapping memory from np.broadcast_arrays. If this is intentional\n"
729729
"set the WRITEABLE flag True or make a copy immediately before writing.";
730730
if (PyArray_FLAGS(obj) & NPY_ARRAY_WARN_ON_WRITE) {
731+
capi_warn("array_might_be_written: warning...");
731732
if (DEPRECATE(msg) < 0) {
732733
return -1;
733734
}

numpy/core/src/multiarray/convert.c

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -580,18 +580,11 @@ HPyArray_NewCopy(HPyContext *ctx, HPy obj, NPY_ORDER order)
580580
return HPy_NULL;
581581
}
582582

583-
PyObject *py_obj = HPy_AsPyObject(ctx, obj);
584-
PyObject *py_ret = HPy_AsPyObject(ctx, ret);
585-
capi_warn("HPyArray_NewCopy: PyArray_AssignArray");
586-
if (PyArray_AssignArray(py_ret, (PyArrayObject*) py_obj, NULL, NPY_UNSAFE_CASTING) < 0) {
587-
Py_DECREF(py_ret);
588-
Py_DECREF(py_obj);
583+
if (HPyArray_AssignArray(ctx, ret, obj, HPy_NULL, NPY_UNSAFE_CASTING) < 0) {
589584
HPy_Close(ctx, ret);
590585
return HPy_NULL;
591586
}
592587

593-
Py_DECREF(py_ret);
594-
Py_DECREF(py_obj);
595588
return ret;
596589
}
597590

numpy/core/src/multiarray/convert_datatype.c

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,46 @@ PyArray_CanCastTypeTo(PyArray_Descr *from, PyArray_Descr *to,
696696
return is_valid;
697697
}
698698

699+
NPY_NO_EXPORT npy_bool
700+
HPyArray_CanCastTypeTo(HPyContext *ctx, HPy h_from, HPy h_to,
701+
NPY_CASTING casting)
702+
{
703+
PyArray_Descr *to = PyArray_Descr_AsStruct(ctx, h_to);
704+
705+
/*
706+
* NOTE: This code supports U and S, this is identical to the code
707+
* in `ctors.c` which does not allow these dtypes to be attached
708+
* to an array. Unlike the code for `np.array(..., dtype=)`
709+
* which uses `PyArray_ExtractDTypeAndDescriptor` it rejects "m8"
710+
* as a flexible dtype instance representing a DType.
711+
*/
712+
/*
713+
* TODO: We should grow support for `np.can_cast("d", "S")` being
714+
* different from `np.can_cast("d", "S0")` here, at least for
715+
* the python side API.
716+
* The `to = NULL` branch, which considers "S0" to be "flexible"
717+
* should probably be deprecated.
718+
* (This logic is duplicated in `PyArray_CanCastArrayTo`)
719+
*/
720+
if (PyDataType_ISUNSIZED(to) && to->subarray == NULL) {
721+
to = NULL; /* consider mainly S0 and U0 as S and U */
722+
}
723+
724+
capi_warn("HPyArray_CanCastTypeTo -> PyArray_CheckCastSafety");
725+
HPy to_meta = HPy_Type(ctx, h_to);
726+
int is_valid = PyArray_CheckCastSafety(casting,
727+
PyArray_Descr_AsStruct(ctx, h_from),
728+
PyArray_Descr_AsStruct(ctx, h_to),
729+
PyArray_DTypeMeta_AsStruct(ctx, to_meta));
730+
HPy_Close(ctx, to_meta);
731+
/* Clear any errors and consider this unsafe (should likely be changed) */
732+
if (is_valid < 0) {
733+
HPyErr_Clear(ctx);
734+
return 0;
735+
}
736+
return is_valid;
737+
}
738+
699739

700740
/* CanCastArrayTo needs this function */
701741
static int min_scalar_type_num(char *valueptr, int type_num,

numpy/core/src/multiarray/convert_datatype.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,8 @@ simple_cast_resolve_descriptors(
102102
NPY_NO_EXPORT int
103103
PyArray_InitializeCasts(void);
104104

105+
/*
 * HPy flavour of PyArray_CanCastTypeTo: the source and target descriptors
 * are passed as HPy handles. Returns non-zero when the cast is permitted
 * under 'casting' and 0 otherwise; an error raised by the underlying
 * safety check is cleared and reported as 0.
 */
NPY_NO_EXPORT npy_bool
106+
HPyArray_CanCastTypeTo(HPyContext *ctx, HPy h_from, HPy h_to,
107+
NPY_CASTING casting);
108+
105109
#endif /* NUMPY_CORE_SRC_MULTIARRAY_CONVERT_DATATYPE_H_ */

0 commit comments

Comments
 (0)