@@ -438,3 +438,192 @@ PyArray_AssignArray(PyArrayObject *dst, PyArrayObject *src,
438
438
}
439
439
return -1 ;
440
440
}
441
+
442
+ NPY_NO_EXPORT int
443
+ HPyArray_AssignArray (HPyContext * ctx , HPy h_dst , HPy h_src ,
444
+ HPy h_wheremask ,
445
+ NPY_CASTING casting )
446
+ {
447
+ PyArrayObject * src = PyArrayObject_AsStruct (ctx , h_src );
448
+ PyArrayObject * dst = PyArrayObject_AsStruct (ctx , h_dst );
449
+ int copied_src = 0 ;
450
+
451
+ npy_intp src_strides [NPY_MAXDIMS ];
452
+
453
+ /* Use array_assign_scalar if 'src' NDIM is 0 */
454
+ if (PyArray_NDIM (src ) == 0 ) {
455
+ capi_warn ("HPyArray_AssignArray: PyArray_AssignRawScalar" );
456
+ return PyArray_AssignRawScalar (
457
+ dst , PyArray_DESCR (src ), PyArray_DATA (src ),
458
+ PyArrayObject_AsStruct (ctx , h_wheremask ), casting );
459
+ }
460
+
461
+ HPy h_src_descr = HPyArray_DESCR (ctx , h_src , src );
462
+ HPy h_dst_descr = HPyArray_DESCR (ctx , h_dst , dst );
463
+ /*
464
+ * Performance fix for expressions like "a[1000:6000] += x". In this
465
+ * case, first an in-place add is done, followed by an assignment,
466
+ * equivalently expressed like this:
467
+ *
468
+ * tmp = a[1000:6000] # Calls array_subscript in mapping.c
469
+ * np.add(tmp, x, tmp)
470
+ * a[1000:6000] = tmp # Calls array_assign_subscript in mapping.c
471
+ *
472
+ * In the assignment the underlying data type, shape, strides, and
473
+ * data pointers are identical, but src != dst because they are separately
474
+ * generated slices. By detecting this and skipping the redundant
475
+ * copy of values to themselves, we potentially give a big speed boost.
476
+ *
477
+ * Note that we don't call EquivTypes, because usually the exact same
478
+ * dtype object will appear, and we don't want to slow things down
479
+ * with a complicated comparison. The comparisons are ordered to
480
+ * try and reject this with as little work as possible.
481
+ */
482
+ if (PyArray_DATA (src ) == PyArray_DATA (dst ) &&
483
+ HPy_Is (ctx , h_src_descr , h_dst_descr ) &&
484
+ PyArray_NDIM (src ) == PyArray_NDIM (dst ) &&
485
+ PyArray_CompareLists (PyArray_DIMS (src ),
486
+ PyArray_DIMS (dst ),
487
+ PyArray_NDIM (src )) &&
488
+ PyArray_CompareLists (PyArray_STRIDES (src ),
489
+ PyArray_STRIDES (dst ),
490
+ PyArray_NDIM (src ))) {
491
+ /*printf("Redundant copy operation detected\n");*/
492
+ return 0 ;
493
+ }
494
+
495
+ if (PyArray_FailUnlessWriteable (dst , "assignment destination" ) < 0 ) {
496
+ goto fail ;
497
+ }
498
+
499
+ /* Check the casting rule */
500
+ if (!HPyArray_CanCastTypeTo (ctx , h_src_descr ,
501
+ h_dst_descr , casting )) {
502
+ capi_warn ("HPyArray_AssignArray: npy_set_invalid_cast_error" );
503
+ npy_set_invalid_cast_error (
504
+ PyArray_DESCR (src ), PyArray_DESCR (dst ), casting , NPY_FALSE );
505
+ goto fail ;
506
+ }
507
+
508
+ /*
509
+ * When ndim is 1 and the strides point in the same direction,
510
+ * the lower-level inner loop handles copying
511
+ * of overlapping data. For bigger ndim and opposite-strided 1D
512
+ * data, we make a temporary copy of 'src' if 'src' and 'dst' overlap.'
513
+ */
514
+ capi_warn ("HPyArray_AssignArray: arrays_overlap and reminder of this function..." );
515
+ if (((PyArray_NDIM (dst ) == 1 && PyArray_NDIM (src ) >= 1 &&
516
+ PyArray_STRIDES (dst )[0 ] *
517
+ PyArray_STRIDES (src )[PyArray_NDIM (src ) - 1 ] < 0 ) ||
518
+ PyArray_NDIM (dst ) > 1 || PyArray_HASFIELDS (dst )) &&
519
+ arrays_overlap (src , dst )) {
520
+ PyArrayObject * tmp ;
521
+
522
+ /*
523
+ * Allocate a temporary copy array.
524
+ */
525
+ tmp = (PyArrayObject * )PyArray_NewLikeArray (dst ,
526
+ NPY_KEEPORDER , NULL , 0 );
527
+ if (tmp == NULL ) {
528
+ goto fail ;
529
+ }
530
+
531
+ if (PyArray_AssignArray (tmp , src , NULL , NPY_UNSAFE_CASTING ) < 0 ) {
532
+ Py_DECREF (tmp );
533
+ goto fail ;
534
+ }
535
+
536
+ src = tmp ;
537
+ copied_src = 1 ;
538
+ }
539
+
540
+ /* Broadcast 'src' to 'dst' for raw iteration */
541
+ if (PyArray_NDIM (src ) > PyArray_NDIM (dst )) {
542
+ int ndim_tmp = PyArray_NDIM (src );
543
+ npy_intp * src_shape_tmp = PyArray_DIMS (src );
544
+ npy_intp * src_strides_tmp = PyArray_STRIDES (src );
545
+ /*
546
+ * As a special case for backwards compatibility, strip
547
+ * away unit dimensions from the left of 'src'
548
+ */
549
+ while (ndim_tmp > PyArray_NDIM (dst ) && src_shape_tmp [0 ] == 1 ) {
550
+ -- ndim_tmp ;
551
+ ++ src_shape_tmp ;
552
+ ++ src_strides_tmp ;
553
+ }
554
+
555
+ if (broadcast_strides (PyArray_NDIM (dst ), PyArray_DIMS (dst ),
556
+ ndim_tmp , src_shape_tmp ,
557
+ src_strides_tmp , "input array" ,
558
+ src_strides ) < 0 ) {
559
+ goto fail ;
560
+ }
561
+ }
562
+ else {
563
+ if (broadcast_strides (PyArray_NDIM (dst ), PyArray_DIMS (dst ),
564
+ PyArray_NDIM (src ), PyArray_DIMS (src ),
565
+ PyArray_STRIDES (src ), "input array" ,
566
+ src_strides ) < 0 ) {
567
+ goto fail ;
568
+ }
569
+ }
570
+
571
+ PyArrayObject * wheremask = PyArrayObject_AsStruct (ctx , h_wheremask );
572
+ /* optimization: scalar boolean mask */
573
+ if (wheremask != NULL &&
574
+ PyArray_NDIM (wheremask ) == 0 &&
575
+ PyArray_DESCR (wheremask )-> type_num == NPY_BOOL ) {
576
+ npy_bool value = * (npy_bool * )PyArray_DATA (wheremask );
577
+ if (value ) {
578
+ /* where=True is the same as no where at all */
579
+ wheremask = NULL ;
580
+ }
581
+ else {
582
+ /* where=False copies nothing */
583
+ return 0 ;
584
+ }
585
+ }
586
+
587
+ if (wheremask == NULL ) {
588
+ /* A straightforward value assignment */
589
+ /* Do the assignment with raw array iteration */
590
+ if (raw_array_assign_array (PyArray_NDIM (dst ), PyArray_DIMS (dst ),
591
+ PyArray_DESCR (dst ), PyArray_DATA (dst ), PyArray_STRIDES (dst ),
592
+ PyArray_DESCR (src ), PyArray_DATA (src ), src_strides ) < 0 ) {
593
+ goto fail ;
594
+ }
595
+ }
596
+ else {
597
+ npy_intp wheremask_strides [NPY_MAXDIMS ];
598
+
599
+ /* Broadcast the wheremask to 'dst' for raw iteration */
600
+ if (broadcast_strides (PyArray_NDIM (dst ), PyArray_DIMS (dst ),
601
+ PyArray_NDIM (wheremask ), PyArray_DIMS (wheremask ),
602
+ PyArray_STRIDES (wheremask ), "where mask" ,
603
+ wheremask_strides ) < 0 ) {
604
+ goto fail ;
605
+ }
606
+
607
+ /* A straightforward where-masked assignment */
608
+ /* Do the masked assignment with raw array iteration */
609
+ if (raw_array_wheremasked_assign_array (
610
+ PyArray_NDIM (dst ), PyArray_DIMS (dst ),
611
+ PyArray_DESCR (dst ), PyArray_DATA (dst ), PyArray_STRIDES (dst ),
612
+ PyArray_DESCR (src ), PyArray_DATA (src ), src_strides ,
613
+ PyArray_DESCR (wheremask ), PyArray_DATA (wheremask ),
614
+ wheremask_strides ) < 0 ) {
615
+ goto fail ;
616
+ }
617
+ }
618
+
619
+ if (copied_src ) {
620
+ Py_DECREF (src );
621
+ }
622
+ return 0 ;
623
+
624
+ fail :
625
+ if (copied_src ) {
626
+ Py_DECREF (src );
627
+ }
628
+ return -1 ;
629
+ }
0 commit comments