[TorchToLinalg] Direct lowering from Torch to Linalg for torch.aten.convolution_backward #4384
base: main
Conversation
(force-pushed from 4f1cb20 to 8e2b616)
@zjgarvey hey! May I ask you to take a look when you're available? Thank you in advance for the review.
zjgarvey left a comment:
Nice! This is an excellent start.
We need to keep the existing decomposition for other backends. I have a few other comments for you to look at, but that's the biggest blocker right now.
```cpp
        rewriter, loc, op.getResultTypes()[2], gradOutput, gradIntList,
        cstFalse, cstNone);
  }
```
We should keep the decomposition; e.g., TOSA and StableHLO still rely on this pattern. The purpose of the backend_legal_ops option in torch-decompose-complex-ops is specifically to prevent selected decomposition patterns from running.
Got it, thank you for the explanation - I didn't know about this mechanism on the backend side.
Restored this pattern and the lit test in decompose-complex-ops.mlir.
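For reference, a minimal sketch of how a backend could wire this up, assuming the `createDecomposeComplexOpsPass(legalOps)` signature from torch-mlir's Torch transforms headers (treat the exact signature as an assumption):

```cpp
// Hypothetical pipeline wiring: listing an op as backend-legal makes
// torch-decompose-complex-ops skip its decomposition pattern, so the
// direct TorchToLinalg lowering can handle it instead.
SmallVector<std::string> backendLegalOps = {"aten.convolution_backward"};
pm.addNestedPass<func::FuncOp>(
    Torch::createDecomposeComplexOpsPass(backendLegalOps));
```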
```cpp
  SmallVector<int64_t> weightFlipDims;
  weightFlipDims.reserve(numSpatialDims);
  for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
    weightFlipDims.push_back(spatialStartDimIdx + i);
```
If the weight shape is static at index i and the dim size there is 1, don't add it to the flip. We definitely see a lot of 1x1 filter convs, and the no-op flip doesn't get folded easily IIRC.
Agree. I also noticed that I forgot to add a condition to skip the flip when numSpatialDims == 1.
So now we flip kernel dims only when numSpatialDims > 1 and the kernel is not 1x1. Also added a lit test.
Thanks!
```cpp
      createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy);
  SmallVector<ReassociationIndices> gradWeightCollapseIndices;
  if (isGroupedConvBwd) {
    auto gradWeightInitExpanded = expandGroups(gradWeightInit, 0);
```
Should the init just be made on the expanded shape here (instead of expanding the init)? This probably gets folded, but I think it would be better to generate simpler IR when possible.
> This probably gets folded, but I think it would be better to generate simpler IR when possible.

Good to know! I thought it would be folded by a later canonicalization pass, but let's do it ourselves here.
Thanks!
```cpp
  // `c` is the input channel dimension, `f` is the output channel
  // dimension, `o` is the input spatial dimension, `k` is the kernel
  // dimension, `d0` is dilation. `x` is the input tensor, `dLdy` is the
  // gradient of the output tensor. `dLdx` is the data-gradient tensor.
```
Might be good to mention that dLdy is the stride/padding modified grad output tensor here.
And that w is flipped along spatial dims.
Added, thanks
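Putting the three comments together, the data-gradient computation being documented is, schematically, in the same pseudocode style the weight-gradient comment uses below:

```
dLdx[n, c, o] = sum(dLdy[n, f, d0 * k + o] * w[f, c, k]
                for f in range(output_channels) for k in range(kernel_dims))
```

where `dLdy` is the stride/padding-adjusted gradient of the output and `w` is the weight flipped along its spatial dims.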
```cpp
}

static linalg::GenericOp
createGenericOp(OpBuilder &b, Location loc, Value in0, Value in1, Value out,
```
There might be a util for this already, like `createReductionGeneric` or something. In any case, might be good to call this something a little more specific (pun intended).
I'd prefer to have our own impl here for better control and understanding.
Renamed to createConvAsGenericOp.
Thanks!
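For context, a minimal sketch of what such a helper builds, assuming float element types and the newer `OpTy::create` builder style (the PR's actual implementation may differ in details; `utils::IteratorType` is what the `IT` alias below refers to):

```cpp
// Sketch: a conv-style linalg.generic whose region multiply-accumulates
// the two inputs into the output, with caller-provided maps/iterators.
static linalg::GenericOp
createConvAsGenericOp(OpBuilder &b, Location loc, Value in0, Value in1,
                      Value out, ArrayRef<AffineMap> indexingMaps,
                      ArrayRef<utils::IteratorType> iteratorTypes) {
  return linalg::GenericOp::create(
      b, loc, TypeRange{out.getType()}, ValueRange{in0, in1}, ValueRange{out},
      indexingMaps, iteratorTypes,
      [](OpBuilder &nb, Location nloc, ValueRange args) {
        // args = {in0 element, in1 element, current out element}.
        Value mul = arith::MulFOp::create(nb, nloc, args[0], args[1]);
        Value acc = arith::AddFOp::create(nb, nloc, mul, args[2]);
        linalg::YieldOp::create(nb, nloc, acc);
      });
}
```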
```cpp
  if (!isGrouped) {
    if (numSpatialDims == 1) {
      AffineExpr n, c, o, f, k;
      bindDims(context, n, c, o, f, k);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      SmallVector<AffineExpr> goExprs = {n, f, d0 * k + o};
      SmallVector<AffineExpr> weiExprs = {f, c, k};
      SmallVector<AffineExpr> outExprs = {n, c, o};
      indexingMaps = {AffineMap::get(5, 0, goExprs, context),
                      AffineMap::get(5, 0, weiExprs, context),
                      AffineMap::get(5, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::reduction, IT::reduction};
    } else if (numSpatialDims == 2) {
      AffineExpr n, c, oh, ow, f, kh, kw;
      bindDims(context, n, c, oh, ow, f, kh, kw);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      SmallVector<AffineExpr> goExprs = {n, f, d0 * kh + oh, d1 * kw + ow};
      SmallVector<AffineExpr> weiExprs = {f, c, kh, kw};
      SmallVector<AffineExpr> outExprs = {n, c, oh, ow};
      indexingMaps = {AffineMap::get(7, 0, goExprs, context),
                      AffineMap::get(7, 0, weiExprs, context),
                      AffineMap::get(7, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::reduction, IT::reduction,
                       IT::reduction};
    } else {
      AffineExpr n, c, od, oh, ow, f, kd, kh, kw;
      bindDims(context, n, c, od, oh, ow, f, kd, kh, kw);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
      SmallVector<AffineExpr> goExprs = {n, f, d0 * kd + od, d1 * kh + oh,
                                         d2 * kw + ow};
      SmallVector<AffineExpr> weiExprs = {f, c, kd, kh, kw};
      SmallVector<AffineExpr> outExprs = {n, c, od, oh, ow};
      indexingMaps = {AffineMap::get(9, 0, goExprs, context),
                      AffineMap::get(9, 0, weiExprs, context),
                      AffineMap::get(9, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::reduction,
                       IT::reduction, IT::reduction, IT::reduction};
    }
  } else {
    if (numSpatialDims == 1) {
      AffineExpr n, g, cg, o, fg, k;
      bindDims(context, n, g, cg, o, fg, k);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * k + o};
      SmallVector<AffineExpr> weiExprs = {g, fg, cg, k};
      SmallVector<AffineExpr> outExprs = {n, g, cg, o};
      indexingMaps = {AffineMap::get(6, 0, goExprs, context),
                      AffineMap::get(6, 0, weiExprs, context),
                      AffineMap::get(6, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::reduction, IT::reduction};
    } else if (numSpatialDims == 2) {
      AffineExpr n, g, cg, oh, ow, fg, kh, kw;
      bindDims(context, n, g, cg, oh, ow, fg, kh, kw);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * kh + oh,
                                         d1 * kw + ow};
      SmallVector<AffineExpr> weiExprs = {g, fg, cg, kh, kw};
      SmallVector<AffineExpr> outExprs = {n, g, cg, oh, ow};
      indexingMaps = {AffineMap::get(8, 0, goExprs, context),
                      AffineMap::get(8, 0, weiExprs, context),
                      AffineMap::get(8, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::reduction,
                       IT::reduction, IT::reduction};
    } else {
      AffineExpr n, g, cg, od, oh, ow, fg, kd, kh, kw;
      bindDims(context, n, g, cg, od, oh, ow, fg, kd, kh, kw);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
      SmallVector<AffineExpr> goExprs = {
          n, g, fg, d0 * kd + od, d1 * kh + oh, d2 * kw + ow};
      SmallVector<AffineExpr> weiExprs = {g, fg, cg, kd, kh, kw};
      SmallVector<AffineExpr> outExprs = {n, g, cg, od, oh, ow};
      indexingMaps = {AffineMap::get(10, 0, goExprs, context),
                      AffineMap::get(10, 0, weiExprs, context),
                      AffineMap::get(10, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::parallel,
                       IT::reduction, IT::reduction, IT::reduction,
                       IT::reduction};
    }
  }
}
```
```cpp
static void initIndexingMapsAndIteratorTypesForWeightBwd(
    OpBuilder &rewriter, MLIRContext *context, bool isGrouped,
    int numSpatialDims, const SmallVector<int64_t> &strideInts,
    const SmallVector<int64_t> &dilationInts,
    SmallVector<AffineMap> &indexingMaps, SmallVector<IT> &iteratorTypes) {
  // To calculate convolution backward-weight, we use generic operation.
  // The generic operation is a generalization of the convolution operation
  // that can handle any number of spatial dimensions.
  // The generic operation is defined as follows:
  // ```
  // dLdw[f, g, c, k] = sum(x[n, g, c, d0 * k + s0 * o] * dLdy[n, g, f, o]
  //                    for n in range(batch_size) for o in range(output_spatial_dims))
  // ```
  // where `n` is the batch dimension, `g` is the group dimension,
  // `c` is the input channel dimension, `f` is the output channel
  // dimension, `o` is the output spatial dimension, `k` is the kernel
  // dimension, `d0` is dilation and `s0` is stride. `x` is the input
  // tensor, `dLdy` is the gradient of the output tensor. `dLdw` is the
  // weight-gradient tensor.
  if (!isGrouped) {
    if (numSpatialDims == 1) {
      AffineExpr f, c, k, n, o;
      bindDims(context, f, c, k, n, o);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      SmallVector<AffineExpr> inExprs = {n, c, d0 * k + s0 * o};
      SmallVector<AffineExpr> goExprs = {n, f, o};
      SmallVector<AffineExpr> outExprs = {f, c, k};
      indexingMaps = {AffineMap::get(5, 0, inExprs, context),
                      AffineMap::get(5, 0, goExprs, context),
                      AffineMap::get(5, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::reduction, IT::reduction};
    } else if (numSpatialDims == 2) {
      AffineExpr f, c, kh, kw, n, oh, ow;
      bindDims(context, f, c, kh, kw, n, oh, ow);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      SmallVector<AffineExpr> inExprs = {n, c, d0 * kh + s0 * oh,
                                         d1 * kw + s1 * ow};
      SmallVector<AffineExpr> goExprs = {n, f, oh, ow};
      SmallVector<AffineExpr> outExprs = {f, c, kh, kw};
      indexingMaps = {AffineMap::get(7, 0, inExprs, context),
                      AffineMap::get(7, 0, goExprs, context),
                      AffineMap::get(7, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::reduction, IT::reduction,
                       IT::reduction};
    } else {
      AffineExpr f, c, kd, kh, kw, n, od, oh, ow;
      bindDims(context, f, c, kd, kh, kw, n, od, oh, ow);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
      AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
      SmallVector<AffineExpr> inExprs = {
          n, c, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow};
      SmallVector<AffineExpr> goExprs = {n, f, od, oh, ow};
      SmallVector<AffineExpr> outExprs = {f, c, kd, kh, kw};
      indexingMaps = {AffineMap::get(9, 0, inExprs, context),
                      AffineMap::get(9, 0, goExprs, context),
                      AffineMap::get(9, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::reduction,
                       IT::reduction, IT::reduction, IT::reduction};
    }
  } else {
    if (numSpatialDims == 1) {
      AffineExpr g, fg, cg, k, n, o;
      bindDims(context, g, fg, cg, k, n, o);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      SmallVector<AffineExpr> inExprs = {n, g, cg, d0 * k + s0 * o};
      SmallVector<AffineExpr> goExprs = {n, g, fg, o};
      SmallVector<AffineExpr> outExprs = {g, fg, cg, k};
      indexingMaps = {AffineMap::get(6, 0, inExprs, context),
                      AffineMap::get(6, 0, goExprs, context),
                      AffineMap::get(6, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::reduction, IT::reduction};
    } else if (numSpatialDims == 2) {
      AffineExpr g, fg, cg, kh, kw, n, oh, ow;
      bindDims(context, g, fg, cg, kh, kw, n, oh, ow);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      SmallVector<AffineExpr> inExprs = {n, g, cg, d0 * kh + s0 * oh,
                                         d1 * kw + s1 * ow};
      SmallVector<AffineExpr> goExprs = {n, g, fg, oh, ow};
      SmallVector<AffineExpr> outExprs = {g, fg, cg, kh, kw};
      indexingMaps = {AffineMap::get(8, 0, inExprs, context),
                      AffineMap::get(8, 0, goExprs, context),
                      AffineMap::get(8, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::reduction,
                       IT::reduction, IT::reduction};
    } else {
      AffineExpr g, fg, cg, kd, kh, kw, n, od, oh, ow;
      bindDims(context, g, fg, cg, kd, kh, kw, n, od, oh, ow);
      AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
      AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
      AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]);
      AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
      AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
      AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
      SmallVector<AffineExpr> inExprs = {
          n, g, cg, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow};
      SmallVector<AffineExpr> goExprs = {n, g, fg, od, oh, ow};
      SmallVector<AffineExpr> outExprs = {g, fg, cg, kd, kh, kw};
      indexingMaps = {AffineMap::get(10, 0, inExprs, context),
                      AffineMap::get(10, 0, goExprs, context),
                      AffineMap::get(10, 0, outExprs, context)};
      iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
                       IT::parallel, IT::parallel, IT::parallel,
                       IT::reduction, IT::reduction, IT::reduction,
                       IT::reduction};
    }
  }
}
```
There must be a better way.
E.g., you could make the AffineExprs for stride, dilation, spatial dims, etc. SmallVector<AffineExpr>. I don't even think there need to be conditionals on anything other than like:

```cpp
SmallVector<AffineExpr> lhsExprs =
    isGrouped ? SmallVector<AffineExpr>{n, g, c} : SmallVector<AffineExpr>{n, c};
// loop over spatial dims and add expressions...
```

Everything else can be like:

```cpp
int64_t numIterators = 3; // batch, parallel channel, reduction channel
numIterators += static_cast<int64_t>(isGrouped);
numIterators += numSpatialDims * 2; // parallel spatial dims, reduction spatial dims
indexingMaps = {AffineMap::get(numIterators, 0, lhsExprs, context),
                AffineMap::get(numIterators, 0, rhsExprs, context),
                AffineMap::get(numIterators, 0, outExprs, context)};
```
Great idea, thanks for that. Implemented.
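A fleshed-out version of that suggestion might look like the following for the data-backward maps (a sketch only; the dim order and names mirror the per-rank versions above but are not necessarily the exact code that landed):

```cpp
// Build the indexing maps for any rank/grouping in one pass. Iterator
// order: n [, g], c, output spatial dims (parallel), then f and the
// kernel spatial dims (reduction).
int64_t numIterators = 3 + (isGrouped ? 1 : 0) + 2 * numSpatialDims;
int64_t idx = 0;
AffineExpr n = getAffineDimExpr(idx++, context);
AffineExpr g = isGrouped ? getAffineDimExpr(idx++, context) : AffineExpr();
AffineExpr c = getAffineDimExpr(idx++, context);
SmallVector<AffineExpr> o, k;
for (int i = 0; i < numSpatialDims; ++i)
  o.push_back(getAffineDimExpr(idx++, context));
AffineExpr f = getAffineDimExpr(idx++, context);
for (int i = 0; i < numSpatialDims; ++i)
  k.push_back(getAffineDimExpr(idx++, context));

SmallVector<AffineExpr> goExprs = isGrouped
                                      ? SmallVector<AffineExpr>{n, g, f}
                                      : SmallVector<AffineExpr>{n, f};
for (int i = 0; i < numSpatialDims; ++i)
  goExprs.push_back(k[i] * dilationInts[i] + o[i]);
SmallVector<AffineExpr> weiExprs;
if (isGrouped)
  weiExprs.push_back(g);
weiExprs.push_back(f);
weiExprs.push_back(c);
weiExprs.append(k.begin(), k.end());
SmallVector<AffineExpr> outExprs = isGrouped
                                       ? SmallVector<AffineExpr>{n, g, c}
                                       : SmallVector<AffineExpr>{n, c};
outExprs.append(o.begin(), o.end());
indexingMaps = {AffineMap::get(numIterators, 0, goExprs, context),
                AffineMap::get(numIterators, 0, weiExprs, context),
                AffineMap::get(numIterators, 0, outExprs, context)};
iteratorTypes.assign(numIterators - (numSpatialDims + 1), IT::parallel);
iteratorTypes.append(numSpatialDims + 1, IT::reduction);
```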
a-sidorova left a comment:
@zjgarvey thank you for the review! I have addressed your comments in the latest commit. Could you please take a look at the changes one more time?
(force-pushed from b20062d to 30dac4b)
(force-pushed from 30dac4b to 31eac08)
zjgarvey left a comment:
A few more comments before we merge. Sorry it took a while to get back to this.
```cpp
}

  auto retType = inType.clone(makeShapeLLVMCompatible(outShape));
  return tensor::ExpandShapeOp::create(rewriter, loc, retType, tensor,
```
Although I think the rewriter gets properly updated these days, it's definitely preferred to use the rewriter hooks for any IR modification instead of Operation methods, i.e., use rewriter.create<OpTy>(...) instead. Sorry, I forgot to include this in my previous review.
Some info here:
https://mlir.llvm.org/docs/PatternRewriter/#matchandrewrite-implementation
This goes for the other ::create methods being called below (there are a few of these).
When I replace Op::create(...) with rewriter.create<Op>(...), I get the following warning during build (deprecated method):

```
cmake --build . -j128 --target torch-mlir-opt
[1/3] Building CXX object tools/torch-mlir/lib/Conversion/TorchToLinalg/CMakeFiles/obj.TorchMLIRTorchToLinalg.dir/Linear.cpp.o
/opt/torch-mlir/lib/Conversion/TorchToLinalg/Linear.cpp:1800:23: warning: 'create' is deprecated: Use OpTy::create instead [-Wdeprecated-declarations]
 1800 |   return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor, indices);
      |                   ^
/opt/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/Builders.h:506:8: note: 'create' has been explicitly marked deprecated here
  506 |   OpTy create(Location location, Args &&...args) {
      |        ^
/opt/torch-mlir/lib/Conversion/TorchToLinalg/Linear.cpp:1800:23: warning: 'create<mlir::tensor::ExpandShapeOp, mlir::RankedTensorType &, mlir::Value &, llvm::SmallVector<llvm::SmallVector<long, 2>> &>' is deprecated: Use OpTy::create instead [-Wdeprecated-declarations]
 1800 |   return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor, indices);
      |                   ^
/opt/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/Builders.h:505:5: note: 'create<mlir::tensor::ExpandShapeOp, mlir::RankedTensorType &, mlir::Value &, llvm::SmallVector<llvm::SmallVector<long, 2>> &>' has been explicitly marked deprecated here
  505 |   [[deprecated("Use OpTy::create instead")]]
      |   ^
2 warnings generated.
[3/3] Linking CXX executable bin/torch-mlir-opt
```

Anyway, I replaced these with rewriter.create<Op>(...) in a separate commit 7ed7b5b. If needed, we can revert this commit.
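In short, the two spellings at issue (both taken verbatim from the build log above; newer LLVM deprecates the second form in Builders.h):

```cpp
// Newer style: static create on the op class, taking the builder/rewriter.
tensor::ExpandShapeOp::create(rewriter, loc, retType, tensor, indices);
// Older style, now marked deprecated:
rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor, indices);
```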
```cpp
  SmallVector<int64_t> weightDimsInt = makeShapeTorchCompatible(
      cast<RankedTensorType>(weightExpanded.getType()).getShape());
  bool is1x1Kernel = std::all_of(weightDimsInt.rbegin(),
                                 weightDimsInt.rbegin() + numSpatialDims,
                                 [](int64_t dim) { return dim == 1; });
  if (numSpatialDims > 1 && !is1x1Kernel) {
    SmallVector<int64_t> weightFlipDims;
    weightFlipDims.reserve(numSpatialDims);
    for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
      weightFlipDims.push_back(spatialStartDimIdx + i);
    weightExpanded = torch_to_linalg::flipTensor(
        rewriter, loc, weightExpanded, weightFlipDims);
  }
```
Okay, we also care about 3x1 kernels, and we still won't need to flip the second spatial dim. It may be simpler to do something like:
```cpp
  SmallVector<int64_t> weightDimsInt = makeShapeTorchCompatible(
      cast<RankedTensorType>(weightExpanded.getType()).getShape());
  // Collect any non-unit spatial dim indices.
  SmallVector<int64_t> weightFlipDims;
  for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i) {
    if (weightDimsInt[spatialStartDimIdx + i] != 1)
      weightFlipDims.push_back(spatialStartDimIdx + i);
  }
  // Perform a flip if we have any non-trivial spatial dims.
  if (!weightFlipDims.empty()) {
    weightExpanded = torch_to_linalg::flipTensor(
        rewriter, loc, weightExpanded, weightFlipDims);
  }
```
Thanks for the idea! I missed the other use cases. I took your pseudo-code and updated it a little. Also added lit and func tests with a [3,1,3] kernel just in case.
```cpp
  SmallVector<IT> iteratorTypes = SmallVector<IT>(numIterators, IT::parallel);
  std::fill(iteratorTypes.rbegin(),
            iteratorTypes.rbegin() + (numSpatialDims + 1), IT::reduction);
```
Nit question: Is this better than using append? Like one fewer allocation, but it has to rewrite values?
To align with other impls, I replaced std::fill with iteratorTypes.append. Thanks!
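Presumably something like the following (a sketch; the counts are taken from the std::fill version quoted above):

```cpp
// Leading parallel iterators, then the trailing reduction iterators
// (the reduction channel plus the kernel spatial dims).
SmallVector<IT> iteratorTypes(numIterators - (numSpatialDims + 1), IT::parallel);
iteratorTypes.append(numSpatialDims + 1, IT::reduction);
```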
```cpp
  if (!isGrouped)
    llvm_unreachable("g() called for non-grouped convolution.");
```
This is really more of a curiosity for me: does the compiler actually do something with this information? E.g., maybe if the condition isGrouped ends up getting hoisted, it helps the compiler figure out whether this function should get loaded at all. Again, just a curiosity. If you don't know, feel free to disregard.
I just wanted to throw an exception when we execute code that shouldn't be reachable. I found only llvm_unreachable in the code base and assumed the compiler automatically handles such exceptions. But it turns out this is not an exception: it aborts the program.
I removed these checks to avoid confusion on the developer side, since we call g() only when isGrouped = true anyway and would never reach llvm_unreachable("g() called for non-grouped convolution.").
Please let me know if there is still a way to handle such exceptions in MLIR; I would then be able to use it in the future.
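For what it's worth, the usual recoverable alternative inside a conversion pattern is to bail out of the match instead of aborting; a minimal sketch (not this PR's code):

```cpp
// Inside matchAndRewrite: report why the pattern does not apply so the
// driver can try other patterns, rather than aborting the process.
if (!isGrouped)
  return rewriter.notifyMatchFailure(op, "expected a grouped convolution");
```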
| "aten.flatten.using_ints", | ||
| "aten.adaptive_avg_pool1d", | ||
| "aten.adaptive_avg_pool2d", | ||
| "aten.convolution_backward", |
We don't really use this path in the CI. The fx_importer config is the one which gets called in the linalg path these days. However, I'm not sure if changing this configuration would negatively impact other tests, like those testing the over-padded case (i.e., negative bwd data padding).
I would be fine with you checking this e2e testing out locally and reporting on which ConvolutionBackward tests fail the fx_importer config with this op as "backend legal" (see the link below) + the lit tests you've already added. We don't need to modify this in the torch-mlir e2e testing framework, since this pattern is mostly added as an optimization option that our downstream project will use. If we fully flesh out support for all the different options that are supported by the current decomposition, then we can update the backend legal ops here:
projects/pt1/python/torch_mlir_e2e_test/configs/fx_importer_backend.py, line 23 (commit 7712b97): `BACKEND_LEGAL_OPS = {`
Got it, thank you for the detailed explanation!
Then I removed this change from jit_importer_backend.py.
As for local testing, the results of the func tests:

```
python -m e2e_testing.main -v -f ConvolutionBackward
PASS - "ConvolutionBackwardModule2DDilated_basic"
XFAIL - "ConvolutionBackwardModule2DPadded_basic"
PASS - "ConvolutionBackwardModule2DStatic_basic"
PASS - "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic"
PASS - "ConvolutionBackwardModule2DStrided_basic"
XFAIL - "ConvolutionBackwardModule2D_basic"
PASS - "ConvolutionBackwardModule3DStatic_basic"

Summary:
    Passed: 5
    Expectedly Failed: 2
```

The additional changes to pass the tests:

- Removed ConvolutionBackward tests from the xfail set, except the 2 tests with dynamic shapes; they fail due to `TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0` (bias_sizes=None).
- Replaced `bias_sizes=[2]` with `bias_sizes=[output_channels]`. Otherwise I got the following error: `error: 'memref.cast' op operand type 'memref<16xf32>' and result type 'memref<4xf32>' are cast incompatible`. I suppose this is a mistake in the tests (since the bias shape should be broadcastable to the output shape), so I decided to fix it. Please let me know if I should revert this change.
- Added `convolution_backward` to `BACKEND_LEGAL_OPS` as you suggested.

As for the lit tests:

```
/opt/torch-mlir/build/bin/llvm-lit /opt/torch-mlir/test -v --filter=convolution_bwd.mlir
Enabling Torch v2.3+ tests
-- Testing: 1 of 120 tests, 1 workers --
PASS: TORCH_MLIR :: Conversion/TorchToLinalg/convolution_bwd.mlir (1 of 1)

Testing Time: 0.10s

Total Discovered Tests: 120
  Excluded: 119 (99.17%)
  Passed  :   1 (0.83%)
```
(force-pushed from 31eac08 to 3a30a55)
a-sidorova left a comment:
@zjgarvey thank you for one more review! I addressed your comments; may I ask you to take a look again?
Description:
- Implemented a direct lowering of torch.aten.convolution_backward from Torch to Linalg and enabled it by default. The pass generates linalg.generic ops instead of linalg.conv_<> ops for better lowering.
- Kept the DecomposeAtenConvolutionBackwardOp pattern from Torch/Transforms/DecomposeComplexOps.cpp, since other backends (e.g., TOSA and StableHLO) still rely on it.
- Added lit tests in convolution_backward.mlir. Also added more test cases for better test coverage.

Issue:
torch.aten.convolution_backwardfrom Torch to Linalg. Enabled this pass by default. The pass generateslinalg.genericops instead oflinalg.conv_<>for better lowering.DecomposeAtenConvolutionBackwardOpfromTorch/Transforms/DecomposeComplexOps.cpp.convolution_backward.mlir. Also added more test cases for better test coverage.Issue: