|
1 |
| -// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file | FileCheck %s |
| 1 | +// RUN: mlir-opt %s -convert-vector-to-arm-sme -convert-arm-sme-to-scf -convert-vector-to-llvm="enable-arm-sme" -cse -canonicalize -split-input-file -allow-unregistered-dialect -verify-diagnostics | FileCheck %s |
| 2 | + |
| 3 | +//===----------------------------------------------------------------------===// |
| 4 | +// vector.transfer_write |
| 5 | +//===----------------------------------------------------------------------===// |
2 | 6 |
|
3 | 7 | // CHECK-LABEL: @transfer_write_2d_zero_i8(
|
4 | 8 | // CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
|
@@ -33,6 +37,10 @@ func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
|
33 | 37 | return
|
34 | 38 | }
|
35 | 39 |
|
| 40 | +//===----------------------------------------------------------------------===// |
| 41 | +// vector.load |
| 42 | +//===----------------------------------------------------------------------===// |
| 43 | + |
36 | 44 | // -----
|
37 | 45 |
|
38 | 46 | // Load an 8-bit tile from a rank 2 memref with a non-zero offset for the first
|
@@ -232,6 +240,10 @@ func.func @vector_load_i128(%arg0 : memref<?x?xi128>) -> vector<[1]x[1]xi128> {
|
232 | 240 | return %tile : vector<[1]x[1]xi128>
|
233 | 241 | }
|
234 | 242 |
|
| 243 | +//===----------------------------------------------------------------------===// |
| 244 | +// vector.store |
| 245 | +//===----------------------------------------------------------------------===// |
| 246 | + |
235 | 247 | // -----
|
236 | 248 |
|
237 | 249 | // CHECK-LABEL: @vector_store_i8(
|
@@ -391,3 +403,96 @@ func.func @vector_store_i128(%tile : vector<[1]x[1]xi128>, %arg0 : memref<?x?xi1
|
391 | 403 | vector.store %tile, %arg0[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
|
392 | 404 | return
|
393 | 405 | }
|
| 406 | + |
| 407 | +//===----------------------------------------------------------------------===// |
| 408 | +// vector.outerproduct |
| 409 | +//===----------------------------------------------------------------------===// |
| 410 | + |
| 411 | +// ----- |
| 412 | + |
| 413 | +// CHECK-LABEL: @vector_outerproduct_add_f16 |
| 414 | +// CHECK-SAME: (%[[LHS:.*]]: vector<[8]xf16>, %[[RHS:.*]]: vector<[8]xf16>, %[[ACC:.*]]: vector<[8]x[8]xf16>) |
| 415 | +func.func @vector_outerproduct_add_f16(%lhs : vector<[8]xf16>, %rhs : vector<[8]xf16>, %acc : vector<[8]x[8]xf16>) { |
| 416 | + // CHECK: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[8]xi1> |
| 417 | + // CHECK: %[[CAST_VECTOR_TO_TILE:.*]] = arm_sme.cast_vector_to_tile %[[ACC]] : vector<[8]x[8]xf16> to i16 |
| 418 | + // CHECK: %[[CAST_VECTOR_TO_TILE_I32:.*]] = arith.extui %[[CAST_VECTOR_TO_TILE]] : i16 to i32 |
| 419 | + // CHECK: "arm_sme.intr.mopa"(%[[CAST_VECTOR_TO_TILE_I32]], %[[PTRUE_ALL]], %[[PTRUE_ALL]], %[[LHS]], %[[RHS]]) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xf16>, vector<[8]xf16>) |
| 420 | + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xf16>, vector<[8]xf16> |
| 421 | + "prevent.dce"(%0) : (vector<[8]x[8]xf16>) -> () |
| 422 | +} |
| 423 | + |
| 424 | +// ----- |
| 425 | + |
| 426 | +// CHECK-LABEL: @vector_outerproduct_add_bf16 |
| 427 | +func.func @vector_outerproduct_add_bf16(%lhs : vector<[8]xbf16>, %rhs : vector<[8]xbf16>, %acc : vector<[8]x[8]xbf16>) { |
| 428 | + // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[8]xi1>, vector<[8]xi1>, vector<[8]xbf16>, vector<[8]xbf16>) |
| 429 | + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[8]xbf16>, vector<[8]xbf16> |
| 430 | + "prevent.dce"(%0) : (vector<[8]x[8]xbf16>) -> () |
| 431 | +} |
| 432 | + |
| 433 | +// ----- |
| 434 | + |
| 435 | +// CHECK-LABEL: @vector_outerproduct_add_f32 |
| 436 | +func.func @vector_outerproduct_add_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>) { |
| 437 | + // CHECK-NOT: arith.extui |
| 438 | + // CHECK-NOT: arith.trunci |
| 439 | + // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>) |
| 440 | + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> |
| 441 | + "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> () |
| 442 | +} |
| 443 | + |
| 444 | +// ----- |
| 445 | + |
| 446 | +// CHECK-LABEL: @vector_outerproduct_add_f64 |
| 447 | +func.func @vector_outerproduct_add_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) { |
| 448 | + // CHECK: arith.trunci {{.*}} : i64 to i32 |
| 449 | + // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) |
| 450 | + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> |
| 451 | + "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () |
| 452 | +} |
| 453 | + |
| 454 | +// ----- |
| 455 | + |
| 456 | +// CHECK-LABEL: @vector_outerproduct_no_accumulator |
| 457 | +func.func @vector_outerproduct_no_accumulator(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>) { |
| 458 | + // CHECK: "arm_sme.intr.zero"({{.*}}) : (i32) -> () |
| 459 | + // CHECK: "arm_sme.intr.mopa"({{.*}}, {{.*}}, {{.*}}) : (i32, vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>) |
| 460 | + %0 = vector.outerproduct %lhs, %rhs {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64> |
| 461 | + "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () |
| 462 | +} |
| 463 | + |
| 464 | +// ----- |
| 465 | + |
| 466 | +// CHECK-LABEL: @vector_outerproduct_unsupported_axpy |
| 467 | +func.func @vector_outerproduct_unsupported_axpy(%lhs : vector<[2]xf64>, %rhs : f64, %acc : vector<[2]xf64>) -> vector<[2]xf64> { |
| 468 | + // CHECK-NOT: arm_sme |
| 469 | + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, f64 |
| 470 | + return %0 : vector<[2]xf64> |
| 471 | +} |
| 472 | + |
| 473 | +// ----- |
| 474 | + |
| 475 | +func.func @vector_outerproduct_unsupported_type(%lhs : vector<[16]xi8>, %rhs : vector<[16]xi8>, %acc : vector<[16]x[16]xi8>) { |
| 476 | + // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}} |
| 477 | + // expected-error@+1 {{unsupported type}} |
| 478 | + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[16]xi8>, vector<[16]xi8> |
| 479 | + "prevent.dce"(%0) : (vector<[16]x[16]xi8>) -> () |
| 480 | +} |
| 481 | + |
| 482 | +// ----- |
| 483 | + |
| 484 | +func.func @vector_outerproduct_unsupported_kind(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64>, %acc : vector<[2]x[2]xf64>) { |
| 485 | + // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}} |
| 486 | + // expected-error@+1 {{unsupported kind}} |
| 487 | + %0 = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<mul>} : vector<[2]xf64>, vector<[2]xf64> |
| 488 | + "prevent.dce"(%0) : (vector<[2]x[2]xf64>) -> () |
| 489 | +} |
| 490 | + |
| 491 | +// ----- |
| 492 | + |
| 493 | +func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vector<[4]xf32>, %acc : vector<[4]x[4]xf32>, %mask : vector<[4]x[4]xi1>) { |
| 494 | + // expected-error@+2 {{failed to legalize operation 'vector.outerproduct'}} |
| 495 | + // expected-error@+1 {{masking is currently unsupported}} |
| 496 | + %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32> |
| 497 | + "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> () |
| 498 | +} |
0 commit comments