[mlir][vector] Better handle rank-preserving shape_cast #135855

Closed · wants to merge 6 commits
18 changes: 10 additions & 8 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2244,18 +2244,20 @@ def Vector_ShapeCastOp :
     Results<(outs AnyVectorOfAnyRank:$result)> {
   let summary = "shape_cast casts between vector shapes";
   let description = [{
-    The shape_cast operation casts between an n-D source vector shape and
-    a k-D result vector shape (the element type remains the same).
+    The shape_cast operation casts from an n-D source vector to a k-D result
+    vector. The element type remains the same, as does the number of elements
+    (product of dimensions).
+
+    A shape_cast must be either collapsing or expanding. Collapsing means all
+    result dimension sizes are products of contiguous source dimension sizes.
+    Expanding means source dimensions all factor into contiguous sequences of
+    destination dimension sizes. Size 1 dimensions in source and destination
+    are ignored.
 
-    If reducing rank (n > k), result dimension sizes must be a product
-    of contiguous source dimension sizes.
-    If expanding rank (n < k), source dimensions must factor into a
-    contiguous sequence of destination dimension sizes.
     Each source dim is expanded (or contiguous sequence of source dims combined)
     in source dimension list order (i.e. 0 <= i < n), to produce a contiguous
     sequence of result dims (or a single result dim), in result dimension list
-    order (i.e. 0 <= j < k). The product of all source dimension sizes and all
-    result dimension sizes must match.
+    order (i.e. 0 <= j < k).
 
     It is currently assumed that this operation does not require moving data,
     and that it will be folded away before lowering vector operations.
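To make the documented rule concrete before the implementation below, here is a minimal standalone C++ sketch (helper names such as groupsInto and isValidShapeCastSketch are hypothetical, and the sketch assumes, as the real verifier guarantees, that both shapes have the same total element count):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// True if every non-1 element of `a` equals the product of a contiguous run
// of elements of `b`, consumed left to right; size-1 dims are skipped. A
// bounds guard on `j` is added here, which the in-tree version instead
// derives from the equal-element-count assumption.
static bool groupsInto(const std::vector<int64_t> &a,
                       const std::vector<int64_t> &b) {
  std::size_t j = 0;
  for (int64_t dimA : a) {
    if (dimA == 1)
      continue;
    int64_t dimB = 1;
    while (dimB < dimA && j < b.size())
      dimB *= b[j++];
    if (dimB != dimA)
      return false;
  }
  return true;
}

// A shape_cast is valid iff it is uniformly expanding or uniformly collapsing.
static bool isValidShapeCastSketch(const std::vector<int64_t> &src,
                                   const std::vector<int64_t> &dst) {
  return groupsInto(src, dst) || groupsInto(dst, src);
}

int main() {
  assert(isValidShapeCastSketch({2, 3, 4}, {6, 4}));     // collapsing
  assert(isValidShapeCastSketch({6, 4}, {2, 3, 4}));     // expanding
  assert(isValidShapeCastSketch({1, 4}, {4, 1}));        // rank-preserving, 1s ignored
  assert(!isValidShapeCastSketch({3, 2}, {2, 3}));       // transpose-like
  assert(!isValidShapeCastSketch({2, 2, 9}, {4, 3, 3})); // collapse/expand mix
  return 0;
}

The last case is the situation the new diagnostic calls out: 4 would have to collapse 2x2 while 9 expands into 3x3, so the cast is neither uniformly collapsing nor uniformly expanding.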
107 changes: 50 additions & 57 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -5532,75 +5532,72 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }
 
-/// Returns true if each element of 'a' is equal to the product of a contiguous
-/// sequence of the elements of 'b'. Returns false otherwise.
-static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
-  unsigned rankA = a.size();
-  unsigned rankB = b.size();
-  assert(rankA < rankB);
-
-  auto isOne = [](int64_t v) { return v == 1; };
-
-  // Special-case for n-D to 0-d shape cast. 'b' must be all ones to be shape
-  // casted to a 0-d vector.
-  if (rankA == 0 && llvm::all_of(b, isOne))
-    return true;
+/// Returns true if each element of 'a' is either 1 or equal to the product of a
+/// contiguous sequence of the elements of 'b'. Returns false otherwise.
+///
+/// This function assumes that the product of the elements in a and b is the
+/// same.
+static bool isExpandingShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
 
+  unsigned rankA = a.size();
   unsigned i = 0;
   unsigned j = 0;
-  while (i < rankA && j < rankB) {
+  while (i < rankA) {
+    if (a[i] == 1) {
+      ++i;
+      continue;
+    }
+
     int64_t dimA = a[i];
     int64_t dimB = 1;
-    while (dimB < dimA && j < rankB)
+
+    while (dimB < dimA) {
       dimB *= b[j++];
-    if (dimA != dimB)
-      break;
-    ++i;
+    }
 
-    // Handle the case when trailing dimensions are of size 1.
-    // Include them into the contiguous sequence.
-    if (i < rankA && llvm::all_of(a.slice(i), isOne))
-      i = rankA;
-    if (j < rankB && llvm::all_of(b.slice(j), isOne))
-      j = rankB;
+    if (dimA != dimB) {
+      return false;
+    }
+    ++i;
   }
+  return true;
+}
 
-  return i == rankA && j == rankB;
+static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
+  return isExpandingShapeCast(a, b) || isExpandingShapeCast(b, a);
 }
 
 static LogicalResult verifyVectorShapeCast(Operation *op,
                                            VectorType sourceVectorType,
                                            VectorType resultVectorType) {
   // Check that element type is the same.
   if (sourceVectorType.getElementType() != resultVectorType.getElementType())
-    return op->emitOpError("source/result vectors must have same element type");
-  auto sourceShape = sourceVectorType.getShape();
-  auto resultShape = resultVectorType.getShape();
-
-  // Check that product of source dim sizes matches product of result dim sizes.
-  int64_t sourceDimProduct = std::accumulate(
-      sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
-  int64_t resultDimProduct = std::accumulate(
-      resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
-  if (sourceDimProduct != resultDimProduct)
-    return op->emitOpError("source/result number of elements must match");
-
-  // Check that expanding/contracting rank cases.
-  unsigned sourceRank = sourceVectorType.getRank();
-  unsigned resultRank = resultVectorType.getRank();
-  if (sourceRank < resultRank) {
-    if (!isValidShapeCast(sourceShape, resultShape))
-      return op->emitOpError("invalid shape cast");
-  } else if (sourceRank > resultRank) {
-    if (!isValidShapeCast(resultShape, sourceShape))
-      return op->emitOpError("invalid shape cast");
+    return op->emitOpError("has different source and result element types");
+  ArrayRef<int64_t> inShape = sourceVectorType.getShape();
+  ArrayRef<int64_t> outShape = resultVectorType.getShape();
 
-  }
+  // Check that product of source dim sizes matches product of result dim
+  // sizes.
+  int64_t nInElms = std::accumulate(inShape.begin(), inShape.end(), 1LL,
+                                    std::multiplies<int64_t>{});
+  int64_t nOutElms = std::accumulate(outShape.begin(), outShape.end(), 1LL,
+                                     std::multiplies<int64_t>{});
+
+  if (nInElms != nOutElms) {
+    return op->emitOpError(
+        "has a different number of source and result elements");
+  }
+
+  if (!isValidShapeCast(inShape, outShape)) {
+    return op->emitOpError(
+        "is invalid (does not uniformly collapse or expand)");
+  }
 
   // Check that (non-)scalability is preserved
   int64_t sourceNScalableDims = sourceVectorType.getNumScalableDims();
   int64_t resultNScalableDims = resultVectorType.getNumScalableDims();
   if (sourceNScalableDims != resultNScalableDims)
-    return op->emitOpError("different number of scalable dims at source (")
+    return op->emitOpError(
+               "has a different number of scalable dims at source (")
            << sourceNScalableDims << ") and result (" << resultNScalableDims
            << ")";
   sourceVectorType.getNumDynamicDims();
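As a quick illustration of the element-count guard above (hypothetical variable names; the shapes are those of the @shape_cast_wrong_num_elements test further down):

#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int main() {
  // 5x1x3x2 holds 30 elements but 10x2 holds 20, so the verifier rejects
  // the cast with "has a different number of source and result elements".
  std::vector<int64_t> inShape{5, 1, 3, 2}, outShape{10, 2};
  int64_t nInElms = std::accumulate(inShape.begin(), inShape.end(),
                                    int64_t{1}, std::multiplies<int64_t>{});
  int64_t nOutElms = std::accumulate(outShape.begin(), outShape.end(),
                                     int64_t{1}, std::multiplies<int64_t>{});
  assert(nInElms == 30 && nOutElms == 20);
  return 0;
}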
@@ -5634,17 +5631,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // Only allows valid transitive folding (expand/collapse dimensions).
   VectorType srcType = otherOp.getSource().getType();
+
   if (resultType == srcType)
     return otherOp.getSource();
-  if (srcType.getRank() < resultType.getRank()) {
-    if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
-      return {};
-  } else if (srcType.getRank() > resultType.getRank()) {
-    if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
-      return {};
-  } else {
+
+  if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
     return {};
-  }
+
   setOperand(otherOp.getSource());
   return getResult();
 }
@@ -6459,8 +6452,8 @@ void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
 void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
   OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
       MaskOp>::ensureTerminator(region, builder, loc);
-  // Keep the default yield terminator if the number of masked operations is not
-  // the expected. This case will trigger a verification failure.
+  // Keep the default yield terminator if the number of masked operations is
+  // not as expected. This case will trigger a verification failure.
   Block &block = region.front();
   if (block.getOperations().size() != 2)
     return;
26 changes: 14 additions & 12 deletions mlir/test/Dialect/Vector/canonicalize.mlir
@@ -950,14 +950,16 @@ func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
 
 // -----
 
+// The definition of shape_cast stipulates that it must be either expanding
+// or collapsing; it cannot be a mixture of both.
 // CHECK-LABEL: dont_fold_expand_collapse
-// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
-// CHECK: return %[[B]] : vector<8x8xf32>
-func.func @dont_fold_expand_collapse(%arg0: vector<1x1x64xf32>) -> vector<8x8xf32> {
-  %0 = vector.shape_cast %arg0 : vector<1x1x64xf32> to vector<1x1x8x8xf32>
-  %1 = vector.shape_cast %0 : vector<1x1x8x8xf32> to vector<8x8xf32>
-  return %1 : vector<8x8xf32>
+// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<2x2x9xf32> to vector<2x2x3x3xf32>
+// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<2x2x3x3xf32> to vector<4x3x3xf32>
+// CHECK: return %[[B]] : vector<4x3x3xf32>
+func.func @dont_fold_expand_collapse(%arg0: vector<2x2x9xf32>) -> vector<4x3x3xf32> {
+  %0 = vector.shape_cast %arg0 : vector<2x2x9xf32> to vector<2x2x3x3xf32>
+  %1 = vector.shape_cast %0 : vector<2x2x3x3xf32> to vector<4x3x3xf32>
+  return %1 : vector<4x3x3xf32>
 }
 
 // -----
@@ -1290,12 +1292,12 @@ func.func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
 // -----
 
 // CHECK-LABEL: consecutive_shape_cast
-// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
-// CHECK-NEXT: return %[[C]] : vector<4x4xf16>
-func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<4x4xf16> {
+// CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<2x2x4xf16>
+// CHECK-NEXT: return %[[C]] : vector<2x2x4xf16>
+func.func @consecutive_shape_cast(%arg0: vector<16xf16>) -> vector<2x2x4xf16> {
   %0 = vector.shape_cast %arg0 : vector<16xf16> to vector<2x8xf16>
-  %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<4x4xf16>
-  return %1 : vector<4x4xf16>
+  %1 = vector.shape_cast %0 : vector<2x8xf16> to vector<2x2x4xf16>
+  return %1 : vector<2x2x4xf16>
 }
 
 // -----
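The two canonicalization tests above differ in exactly one way: a pair of shape_casts folds only if the composed end-to-end cast itself passes the uniform collapse/expand check. A compact, self-contained restatement of that arithmetic (same grouping rule as the earlier sketch; the helper name is hypothetical):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// True if every non-1 element of `a` is the product of a contiguous run of `b`.
static bool groups(const std::vector<int64_t> &a,
                   const std::vector<int64_t> &b) {
  std::size_t j = 0;
  for (int64_t d : a) {
    if (d == 1)
      continue;
    int64_t p = 1;
    while (p < d && j < b.size())
      p *= b[j++];
    if (p != d)
      return false;
  }
  return true;
}

int main() {
  // @consecutive_shape_cast: 16 -> 2x8 -> 2x2x4 composes to 16 -> 2x2x4,
  // which is uniformly expanding, so the two casts fold into one.
  assert(groups({16}, {2, 2, 4}));

  // @dont_fold_expand_collapse: 2x2x9 -> 2x2x3x3 -> 4x3x3 would compose to
  // 2x2x9 -> 4x3x3, where 2x2 collapses into 4 but 9 expands into 3x3.
  // That mixture fails the check in both grouping directions, so no fold.
  assert(!groups({2, 2, 9}, {4, 3, 3}) && !groups({4, 3, 3}, {2, 2, 9}));
  return 0;
}

Note that @consecutive_shape_cast previously targeted vector<4x4xf16>; under the uniform rule, 2x8 -> 4x4 is not a legal shape_cast in the first place (neither shape groups into the other), which appears to be why the test now casts to 2x2x4 instead.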
15 changes: 11 additions & 4 deletions mlir/test/Dialect/Vector/invalid.mlir
@@ -1132,28 +1132,35 @@ func.func @cannot_print_string_with_source_set(%vec: vector<[4]xf32>) {
 // -----
 
 func.func @shape_cast_wrong_element_type(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{op source/result vectors must have same element type}}
+  // expected-error@+1 {{op has different source and result element types}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<15x2xi32>
 }
 
 // -----
 
 func.func @shape_cast_wrong_num_elements(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{op source/result number of elements must match}}
+  // expected-error@+1 {{op has a different number of source and result elements}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<10x2xf32>
 }
 
 // -----
 
+func.func @shape_cast_invalid_rank_preserving(%arg0 : vector<3x2xf32>) {
+  // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
+  %0 = vector.shape_cast %arg0 : vector<3x2xf32> to vector<2x3xf32>
+}
+
+// -----
+
 func.func @shape_cast_invalid_rank_reduction(%arg0 : vector<5x1x3x2xf32>) {
-  // expected-error@+1 {{invalid shape cast}}
+  // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
   %0 = vector.shape_cast %arg0 : vector<5x1x3x2xf32> to vector<2x15xf32>
 }
 
 // -----
 
 func.func @shape_cast_invalid_rank_expansion(%arg0 : vector<15x2xf32>) {
-  // expected-error@+1 {{invalid shape cast}}
+  // expected-error@+1 {{op is invalid (does not uniformly collapse or expand)}}
   %0 = vector.shape_cast %arg0 : vector<15x2xf32> to vector<5x2x3x1xf32>
 }
28 changes: 28 additions & 0 deletions mlir/test/Dialect/Vector/ops.mlir
@@ -576,6 +576,34 @@ func.func @shape_cast_0d(%arg0 : vector<1x1x1x1xf32>) -> (vector<1x1x1x1xf32>) {
   return %1 : vector<1x1x1x1xf32>
 }
 
+// CHECK-LABEL: @shape_cast_rank_preserving
+func.func @shape_cast_rank_preserving(%arg0 : vector<1x4xf32>) -> vector<4x1xf32> {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<1x4xf32> to vector<4x1xf32>
+  %0 = vector.shape_cast %arg0 : vector<1x4xf32> to vector<4x1xf32>
+
+  return %0 : vector<4x1xf32>
+}
+
+
+// CHECK-LABEL: @collapse_but_increase_rank
+func.func @collapse_but_increase_rank(%arg0 : vector<2x3x5x7xf32>) -> vector<1x6x1x35x1xf32> {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>
+  %0 = vector.shape_cast %arg0 : vector<2x3x5x7xf32> to vector<1x6x1x35x1xf32>
+
+  return %0 : vector<1x6x1x35x1xf32>
+}
+
+// CHECK-LABEL: @expand_but_decrease_rank
+func.func @expand_but_decrease_rank(%arg0 : vector<1x1x6xi8>) -> vector<2x3xi8> {
+
+  // CHECK: vector.shape_cast %{{.*}} : vector<1x1x6xi8> to vector<2x3xi8>
+  %0 = vector.shape_cast %arg0 : vector<1x1x6xi8> to vector<2x3xi8>
+
+  return %0 : vector<2x3xi8>
+}
+
 // CHECK-LABEL: @bitcast
 func.func @bitcast(%arg0 : vector<5x1x3x2xf32>,
                    %arg1 : vector<8x1xi32>,