Skip to content

Conversation

@a-sidorova
Copy link

@a-sidorova a-sidorova commented Nov 20, 2025

Description:

  • Added the direct lowering pass for torch.aten.convolution_backward from Torch to Linalg. Enabled this pass by default. The pass generates linalg.generic ops instead of linalg.conv_<> for better lowering.
  • Removed the previous pass DecomposeAtenConvolutionBackwardOp from Torch/Transforms/DecomposeComplexOps.cpp.
  • Created new lit tests for backward convolution in the separate file convolution_backward.mlir. Also added more test cases for better test coverage.
  • Added new e2e tests for backward convolution for better test coverage.

Issue:

@a-sidorova a-sidorova force-pushed the feature/linalg_conv_bwd branch from 4f1cb20 to 8e2b616 Compare November 21, 2025 13:31
@a-sidorova a-sidorova marked this pull request as ready for review November 21, 2025 13:36
@a-sidorova
Copy link
Author

@zjgarvey hey! May I ask you to take a look when you're available? Thank you in advance for the review.

@zjgarvey zjgarvey self-requested a review November 21, 2025 19:36
Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! This is an excellent start.

We need to keep the existing decomposition for other backends. I have a few other comments for you to look at, but that's the biggest blocker right now.

rewriter, loc, op.getResultTypes()[2], gradOutput, gradIntList,
cstFalse, cstNone);
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should keep the decomposition, E.g., TOSA and StableHLO still rely on this pattern. The purpose of the backend_legal_ops option in torch-decompose-complex-ops is specifically to prevent selected decomposition patterns.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, thank you for the explanation - I didn't know about this mechanism on the backend sides.

Returned this pass and the lit test in decompose-complex-ops.mlir

SmallVector<int64_t> weightFlipDims;
weightFlipDims.reserve(numSpatialDims);
for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
weightFlipDims.push_back(spatialStartDimIdx + i);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the weight shape is static at index i, and the dim size is 1 there, don't add to the flip. We definitely see a lot of 1x1 filter convs and the noop flip doesn't get folded easily IIRC.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. I also noticed that I forgot to add condition for numSpatialDims == 1 to not insert flip.

So now we flip kernel dims only when numSpatialDims > 1 and this is not 1x1 kernel. + added lit test

Thanks!

createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy);
SmallVector<ReassociationIndices> gradWeightCollapseIndices;
if (isGroupedConvBwd) {
auto gradWeightInitExpanded = expandGroups(gradWeightInit, 0);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the init just be made on the expanded shape here (instead of expanding the init)? This probably gets folded, but I think it would be better to generate simpler IR when possible.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably gets folded, but I think it would be better to generate simpler IR when possible.

I will know! I thought that it will be fold by any canonicalization pass further. But let's do it ourself here.

Thanks!

// `c` is the input channel dimension, `f` is the output channel
// dimension, `o` is the input spatial dimension, `k` is the kernel
// dimension, `d0` is dilation. `x` is the input tensor, `dLdy` is the
// gradient of the output tensor. `dLdx` is the data-gradient tensor.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be good to mention that dLdy is the stride/padding modified grad output tensor here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And that w is flipped along spatial dims.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added, thanks

}

static linalg::GenericOp
createGenericOp(OpBuilder &b, Location loc, Value in0, Value in1, Value out,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There might be a util for this already like "createReductionGeneric` or something. In any case, might be good to call this something a little more specific (pun intended).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to have the own impl here for better control and understanding.
Renamed to createConvAsGenericOp.

Thanks!

Comment on lines 2061 to 2277
if (!isGrouped) {
if (numSpatialDims == 1) {
AffineExpr n, c, o, f, k;
bindDims(context, n, c, o, f, k);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
SmallVector<AffineExpr> goExprs = {n, f, d0 * k + o};
SmallVector<AffineExpr> weiExprs = {f, c, k};
SmallVector<AffineExpr> outExprs = {n, c, o};
indexingMaps = {AffineMap::get(5, 0, goExprs, context),
AffineMap::get(5, 0, weiExprs, context),
AffineMap::get(5, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::reduction, IT::reduction};
} else if (numSpatialDims == 2) {
AffineExpr n, c, oh, ow, f, kh, kw;
bindDims(context, n, c, oh, ow, f, kh, kw);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
SmallVector<AffineExpr> goExprs = {n, f, d0 * kh + oh, d1 * kw + ow};
SmallVector<AffineExpr> weiExprs = {f, c, kh, kw};
SmallVector<AffineExpr> outExprs = {n, c, oh, ow};
indexingMaps = {AffineMap::get(7, 0, goExprs, context),
AffineMap::get(7, 0, weiExprs, context),
AffineMap::get(7, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::reduction, IT::reduction,
IT::reduction};
} else {
AffineExpr n, c, od, oh, ow, f, kd, kh, kw;
bindDims(context, n, c, od, oh, ow, f, kd, kh, kw);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
SmallVector<AffineExpr> goExprs = {n, f, d0 * kd + od, d1 * kh + oh,
d2 * kw + ow};
SmallVector<AffineExpr> weiExprs = {f, c, kd, kh, kw};
SmallVector<AffineExpr> outExprs = {n, c, od, oh, ow};
indexingMaps = {AffineMap::get(9, 0, goExprs, context),
AffineMap::get(9, 0, weiExprs, context),
AffineMap::get(9, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::reduction,
IT::reduction, IT::reduction, IT::reduction};
}
} else {
if (numSpatialDims == 1) {
AffineExpr n, g, cg, o, fg, k;
bindDims(context, n, g, cg, o, fg, k);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * k + o};
SmallVector<AffineExpr> weiExprs = {g, fg, cg, k};
SmallVector<AffineExpr> outExprs = {n, g, cg, o};
indexingMaps = {AffineMap::get(6, 0, goExprs, context),
AffineMap::get(6, 0, weiExprs, context),
AffineMap::get(6, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::reduction, IT::reduction};
} else if (numSpatialDims == 2) {
AffineExpr n, g, cg, oh, ow, fg, kh, kw;
bindDims(context, n, g, cg, oh, ow, fg, kh, kw);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * kh + oh,
d1 * kw + ow};
SmallVector<AffineExpr> weiExprs = {g, fg, cg, kh, kw};
SmallVector<AffineExpr> outExprs = {n, g, cg, oh, ow};
indexingMaps = {AffineMap::get(8, 0, goExprs, context),
AffineMap::get(8, 0, weiExprs, context),
AffineMap::get(8, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::reduction,
IT::reduction, IT::reduction};
} else {
AffineExpr n, g, cg, od, oh, ow, fg, kd, kh, kw;
bindDims(context, n, g, cg, od, oh, ow, fg, kd, kh, kw);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
SmallVector<AffineExpr> goExprs = {
n, g, fg, d0 * kd + od, d1 * kh + oh, d2 * kw + ow};
SmallVector<AffineExpr> weiExprs = {g, fg, cg, kd, kh, kw};
SmallVector<AffineExpr> outExprs = {n, g, cg, od, oh, ow};
indexingMaps = {AffineMap::get(10, 0, goExprs, context),
AffineMap::get(10, 0, weiExprs, context),
AffineMap::get(10, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::parallel,
IT::reduction, IT::reduction, IT::reduction,
IT::reduction};
}
}
}

static void initIndexingMapsAndIteratorTypesForWeightBwd(
OpBuilder &rewriter, MLIRContext *context, bool isGrouped,
int numSpatialDims, const SmallVector<int64_t> &strideInts,
const SmallVector<int64_t> &dilationInts,
SmallVector<AffineMap> &indexingMaps, SmallVector<IT> &iteratorTypes) {
// To calculate convolution backward-weight, we use generic operation.
// The generic operation is a generalization of the convolution operation
// that can handle any number of spatial dimensions.
// The generic operation is defined as follows:
// ```
// dLdw[f, g, c, k] = sum(x[n, g, c, d0 * k + s0 * o] * dLdy[n, g, f, o]
// for n in range(batch_size) for o in range(output_spatial_dims))
// ```
// where `n` is the batch dimension, `g` is the group dimension,
// `c` is the input channel dimension, `f` is the output channel
// dimension, `o` is the output spatial dimension, `k` is the kernel
// dimension, `d0` is dilation and `s0` is stride. `x` is the input
// tensor, `dLdy` is the gradient of the output tensor. `dLdw` is the
// weight-gradient tensor.
if (!isGrouped) {
if (numSpatialDims == 1) {
AffineExpr f, c, k, n, o;
bindDims(context, f, c, k, n, o);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
SmallVector<AffineExpr> inExprs = {n, c, d0 * k + s0 * o};
SmallVector<AffineExpr> goExprs = {n, f, o};
SmallVector<AffineExpr> outExprs = {f, c, k};
indexingMaps = {AffineMap::get(5, 0, inExprs, context),
AffineMap::get(5, 0, goExprs, context),
AffineMap::get(5, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::reduction, IT::reduction};
} else if (numSpatialDims == 2) {
AffineExpr f, c, kh, kw, n, oh, ow;
bindDims(context, f, c, kh, kw, n, oh, ow);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
SmallVector<AffineExpr> inExprs = {n, c, d0 * kh + s0 * oh,
d1 * kw + s1 * ow};
SmallVector<AffineExpr> goExprs = {n, f, oh, ow};
SmallVector<AffineExpr> outExprs = {f, c, kh, kw};
indexingMaps = {AffineMap::get(7, 0, inExprs, context),
AffineMap::get(7, 0, goExprs, context),
AffineMap::get(7, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::reduction, IT::reduction,
IT::reduction};
} else {
AffineExpr f, c, kd, kh, kw, n, od, oh, ow;
bindDims(context, f, c, kd, kh, kw, n, od, oh, ow);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
SmallVector<AffineExpr> inExprs = {
n, c, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow};
SmallVector<AffineExpr> goExprs = {n, f, od, oh, ow};
SmallVector<AffineExpr> outExprs = {f, c, kd, kh, kw};
indexingMaps = {AffineMap::get(9, 0, inExprs, context),
AffineMap::get(9, 0, goExprs, context),
AffineMap::get(9, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::reduction,
IT::reduction, IT::reduction, IT::reduction};
}
} else {
if (numSpatialDims == 1) {
AffineExpr g, fg, cg, k, n, o;
bindDims(context, g, fg, cg, k, n, o);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
SmallVector<AffineExpr> inExprs = {n, g, cg, d0 * k + s0 * o};
SmallVector<AffineExpr> goExprs = {n, g, fg, o};
SmallVector<AffineExpr> outExprs = {g, fg, cg, k};
indexingMaps = {AffineMap::get(6, 0, inExprs, context),
AffineMap::get(6, 0, goExprs, context),
AffineMap::get(6, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::reduction, IT::reduction};
} else if (numSpatialDims == 2) {
AffineExpr g, fg, cg, kh, kw, n, oh, ow;
bindDims(context, g, fg, cg, kh, kw, n, oh, ow);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
SmallVector<AffineExpr> inExprs = {n, g, cg, d0 * kh + s0 * oh,
d1 * kw + s1 * ow};
SmallVector<AffineExpr> goExprs = {n, g, fg, oh, ow};
SmallVector<AffineExpr> outExprs = {g, fg, cg, kh, kw};
indexingMaps = {AffineMap::get(8, 0, inExprs, context),
AffineMap::get(8, 0, goExprs, context),
AffineMap::get(8, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::reduction,
IT::reduction, IT::reduction};
} else {
AffineExpr g, fg, cg, kd, kh, kw, n, od, oh, ow;
bindDims(context, g, fg, cg, kd, kh, kw, n, od, oh, ow);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
SmallVector<AffineExpr> inExprs = {
n, g, cg, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow};
SmallVector<AffineExpr> goExprs = {n, g, fg, od, oh, ow};
SmallVector<AffineExpr> outExprs = {g, fg, cg, kd, kh, kw};
indexingMaps = {AffineMap::get(10, 0, inExprs, context),
AffineMap::get(10, 0, goExprs, context),
AffineMap::get(10, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::parallel,
IT::reduction, IT::reduction, IT::reduction,
IT::reduction};
}
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There must be a better way.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E.g., you could make the AffineExprs for stride, dilation, spatial dims, etc. SmallVector<AffineExpr>. I don't even think there need to be conditionals on anything other than like:

SmallVector<AffineExpr> lhsExprs = isGrouped ? {n, g, c} : {n, c};
// loop over spatial dims and add expressions...

Everything else can be like:

int64_t numIterators = 3; // batch, parallel channel, reduction channel
numIterators += static_cast<int64_t>(isGrouped);
numIterators += numSpatialDims*2 // parallel spatial dims, reduction spatial dims
indexingMaps = {
    AffineMap::get(numIterators, lhsExprs, context),
    AffineMap::get(numIterators, rhsExprs, context),
    AffineMap::get(numIterators, outExprs, context)
};

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea, thanks for that. Implemented.

Copy link
Author

@a-sidorova a-sidorova left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zjgarvey thank you for review! I have applied your comments in the latest commit. Could you please take a look at the changes one more time?

rewriter, loc, op.getResultTypes()[2], gradOutput, gradIntList,
cstFalse, cstNone);
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, thank you for the explanation - I didn't know about this mechanism on the backend sides.

Returned this pass and the lit test in decompose-complex-ops.mlir

SmallVector<int64_t> weightFlipDims;
weightFlipDims.reserve(numSpatialDims);
for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
weightFlipDims.push_back(spatialStartDimIdx + i);
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. I also noticed that I forgot to add condition for numSpatialDims == 1 to not insert flip.

So now we flip kernel dims only when numSpatialDims > 1 and this is not 1x1 kernel. + added lit test

Thanks!

createZeroInitTensor(rewriter, loc, gradWeightSizes, weightDTy);
SmallVector<ReassociationIndices> gradWeightCollapseIndices;
if (isGroupedConvBwd) {
auto gradWeightInitExpanded = expandGroups(gradWeightInit, 0);
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This probably gets folded, but I think it would be better to generate simpler IR when possible.

I will know! I thought that it will be fold by any canonicalization pass further. But let's do it ourself here.

Thanks!

Comment on lines 2061 to 2277
if (!isGrouped) {
if (numSpatialDims == 1) {
AffineExpr n, c, o, f, k;
bindDims(context, n, c, o, f, k);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
SmallVector<AffineExpr> goExprs = {n, f, d0 * k + o};
SmallVector<AffineExpr> weiExprs = {f, c, k};
SmallVector<AffineExpr> outExprs = {n, c, o};
indexingMaps = {AffineMap::get(5, 0, goExprs, context),
AffineMap::get(5, 0, weiExprs, context),
AffineMap::get(5, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::reduction, IT::reduction};
} else if (numSpatialDims == 2) {
AffineExpr n, c, oh, ow, f, kh, kw;
bindDims(context, n, c, oh, ow, f, kh, kw);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
SmallVector<AffineExpr> goExprs = {n, f, d0 * kh + oh, d1 * kw + ow};
SmallVector<AffineExpr> weiExprs = {f, c, kh, kw};
SmallVector<AffineExpr> outExprs = {n, c, oh, ow};
indexingMaps = {AffineMap::get(7, 0, goExprs, context),
AffineMap::get(7, 0, weiExprs, context),
AffineMap::get(7, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::reduction, IT::reduction,
IT::reduction};
} else {
AffineExpr n, c, od, oh, ow, f, kd, kh, kw;
bindDims(context, n, c, od, oh, ow, f, kd, kh, kw);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
SmallVector<AffineExpr> goExprs = {n, f, d0 * kd + od, d1 * kh + oh,
d2 * kw + ow};
SmallVector<AffineExpr> weiExprs = {f, c, kd, kh, kw};
SmallVector<AffineExpr> outExprs = {n, c, od, oh, ow};
indexingMaps = {AffineMap::get(9, 0, goExprs, context),
AffineMap::get(9, 0, weiExprs, context),
AffineMap::get(9, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::reduction,
IT::reduction, IT::reduction, IT::reduction};
}
} else {
if (numSpatialDims == 1) {
AffineExpr n, g, cg, o, fg, k;
bindDims(context, n, g, cg, o, fg, k);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * k + o};
SmallVector<AffineExpr> weiExprs = {g, fg, cg, k};
SmallVector<AffineExpr> outExprs = {n, g, cg, o};
indexingMaps = {AffineMap::get(6, 0, goExprs, context),
AffineMap::get(6, 0, weiExprs, context),
AffineMap::get(6, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::reduction, IT::reduction};
} else if (numSpatialDims == 2) {
AffineExpr n, g, cg, oh, ow, fg, kh, kw;
bindDims(context, n, g, cg, oh, ow, fg, kh, kw);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
SmallVector<AffineExpr> goExprs = {n, g, fg, d0 * kh + oh,
d1 * kw + ow};
SmallVector<AffineExpr> weiExprs = {g, fg, cg, kh, kw};
SmallVector<AffineExpr> outExprs = {n, g, cg, oh, ow};
indexingMaps = {AffineMap::get(8, 0, goExprs, context),
AffineMap::get(8, 0, weiExprs, context),
AffineMap::get(8, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::reduction,
IT::reduction, IT::reduction};
} else {
AffineExpr n, g, cg, od, oh, ow, fg, kd, kh, kw;
bindDims(context, n, g, cg, od, oh, ow, fg, kd, kh, kw);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
SmallVector<AffineExpr> goExprs = {
n, g, fg, d0 * kd + od, d1 * kh + oh, d2 * kw + ow};
SmallVector<AffineExpr> weiExprs = {g, fg, cg, kd, kh, kw};
SmallVector<AffineExpr> outExprs = {n, g, cg, od, oh, ow};
indexingMaps = {AffineMap::get(10, 0, goExprs, context),
AffineMap::get(10, 0, weiExprs, context),
AffineMap::get(10, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::parallel,
IT::reduction, IT::reduction, IT::reduction,
IT::reduction};
}
}
}

static void initIndexingMapsAndIteratorTypesForWeightBwd(
OpBuilder &rewriter, MLIRContext *context, bool isGrouped,
int numSpatialDims, const SmallVector<int64_t> &strideInts,
const SmallVector<int64_t> &dilationInts,
SmallVector<AffineMap> &indexingMaps, SmallVector<IT> &iteratorTypes) {
// To calculate convolution backward-weight, we use generic operation.
// The generic operation is a generalization of the convolution operation
// that can handle any number of spatial dimensions.
// The generic operation is defined as follows:
// ```
// dLdw[f, g, c, k] = sum(x[n, g, c, d0 * k + s0 * o] * dLdy[n, g, f, o]
// for n in range(batch_size) for o in range(output_spatial_dims))
// ```
// where `n` is the batch dimension, `g` is the group dimension,
// `c` is the input channel dimension, `f` is the output channel
// dimension, `o` is the output spatial dimension, `k` is the kernel
// dimension, `d0` is dilation and `s0` is stride. `x` is the input
// tensor, `dLdy` is the gradient of the output tensor. `dLdw` is the
// weight-gradient tensor.
if (!isGrouped) {
if (numSpatialDims == 1) {
AffineExpr f, c, k, n, o;
bindDims(context, f, c, k, n, o);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
SmallVector<AffineExpr> inExprs = {n, c, d0 * k + s0 * o};
SmallVector<AffineExpr> goExprs = {n, f, o};
SmallVector<AffineExpr> outExprs = {f, c, k};
indexingMaps = {AffineMap::get(5, 0, inExprs, context),
AffineMap::get(5, 0, goExprs, context),
AffineMap::get(5, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::reduction, IT::reduction};
} else if (numSpatialDims == 2) {
AffineExpr f, c, kh, kw, n, oh, ow;
bindDims(context, f, c, kh, kw, n, oh, ow);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
SmallVector<AffineExpr> inExprs = {n, c, d0 * kh + s0 * oh,
d1 * kw + s1 * ow};
SmallVector<AffineExpr> goExprs = {n, f, oh, ow};
SmallVector<AffineExpr> outExprs = {f, c, kh, kw};
indexingMaps = {AffineMap::get(7, 0, inExprs, context),
AffineMap::get(7, 0, goExprs, context),
AffineMap::get(7, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::reduction, IT::reduction,
IT::reduction};
} else {
AffineExpr f, c, kd, kh, kw, n, od, oh, ow;
bindDims(context, f, c, kd, kh, kw, n, od, oh, ow);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
SmallVector<AffineExpr> inExprs = {
n, c, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow};
SmallVector<AffineExpr> goExprs = {n, f, od, oh, ow};
SmallVector<AffineExpr> outExprs = {f, c, kd, kh, kw};
indexingMaps = {AffineMap::get(9, 0, inExprs, context),
AffineMap::get(9, 0, goExprs, context),
AffineMap::get(9, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::reduction,
IT::reduction, IT::reduction, IT::reduction};
}
} else {
if (numSpatialDims == 1) {
AffineExpr g, fg, cg, k, n, o;
bindDims(context, g, fg, cg, k, n, o);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
SmallVector<AffineExpr> inExprs = {n, g, cg, d0 * k + s0 * o};
SmallVector<AffineExpr> goExprs = {n, g, fg, o};
SmallVector<AffineExpr> outExprs = {g, fg, cg, k};
indexingMaps = {AffineMap::get(6, 0, inExprs, context),
AffineMap::get(6, 0, goExprs, context),
AffineMap::get(6, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::reduction, IT::reduction};
} else if (numSpatialDims == 2) {
AffineExpr g, fg, cg, kh, kw, n, oh, ow;
bindDims(context, g, fg, cg, kh, kw, n, oh, ow);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
SmallVector<AffineExpr> inExprs = {n, g, cg, d0 * kh + s0 * oh,
d1 * kw + s1 * ow};
SmallVector<AffineExpr> goExprs = {n, g, fg, oh, ow};
SmallVector<AffineExpr> outExprs = {g, fg, cg, kh, kw};
indexingMaps = {AffineMap::get(8, 0, inExprs, context),
AffineMap::get(8, 0, goExprs, context),
AffineMap::get(8, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::reduction,
IT::reduction, IT::reduction};
} else {
AffineExpr g, fg, cg, kd, kh, kw, n, od, oh, ow;
bindDims(context, g, fg, cg, kd, kh, kw, n, od, oh, ow);
AffineExpr s0 = rewriter.getAffineConstantExpr(strideInts[0]);
AffineExpr s1 = rewriter.getAffineConstantExpr(strideInts[1]);
AffineExpr s2 = rewriter.getAffineConstantExpr(strideInts[2]);
AffineExpr d0 = rewriter.getAffineConstantExpr(dilationInts[0]);
AffineExpr d1 = rewriter.getAffineConstantExpr(dilationInts[1]);
AffineExpr d2 = rewriter.getAffineConstantExpr(dilationInts[2]);
SmallVector<AffineExpr> inExprs = {
n, g, cg, d0 * kd + s0 * od, d1 * kh + s1 * oh, d2 * kw + s2 * ow};
SmallVector<AffineExpr> goExprs = {n, g, fg, od, oh, ow};
SmallVector<AffineExpr> outExprs = {g, fg, cg, kd, kh, kw};
indexingMaps = {AffineMap::get(10, 0, inExprs, context),
AffineMap::get(10, 0, goExprs, context),
AffineMap::get(10, 0, outExprs, context)};
iteratorTypes = {IT::parallel, IT::parallel, IT::parallel,
IT::parallel, IT::parallel, IT::parallel,
IT::reduction, IT::reduction, IT::reduction,
IT::reduction};
}
}
}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea, thanks for that. Implemented.

// `c` is the input channel dimension, `f` is the output channel
// dimension, `o` is the input spatial dimension, `k` is the kernel
// dimension, `d0` is dilation. `x` is the input tensor, `dLdy` is the
// gradient of the output tensor. `dLdx` is the data-gradient tensor.
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added, thanks

}

static linalg::GenericOp
createGenericOp(OpBuilder &b, Location loc, Value in0, Value in1, Value out,
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to have the own impl here for better control and understanding.
Renamed to createConvAsGenericOp.

Thanks!

@a-sidorova a-sidorova requested a review from zjgarvey November 25, 2025 05:09
@a-sidorova a-sidorova force-pushed the feature/linalg_conv_bwd branch from b20062d to 30dac4b Compare November 25, 2025 08:19
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zjgarvey FYI I added my new func tests to xfail list to fix CI and align with already xfail-marked ConvolutionBackwardModule2D tests.

@a-sidorova a-sidorova force-pushed the feature/linalg_conv_bwd branch from 30dac4b to 31eac08 Compare November 25, 2025 08:25
Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few more comments before we merge. Sorry if this took a while to get back to.

}

auto retType = inType.clone(makeShapeLLVMCompatible(outShape));
return tensor::ExpandShapeOp::create(rewriter, loc, retType, tensor,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although I think the rewriter gets properly updated these days, It's definitely preferred to use the rewriter hooks for any IR modification instead of Operation methods. I.e., use rewriter.create<OpTy>(...) instead. Sorry, I forgot to include this in my previous review.

Some info here:

https://mlir.llvm.org/docs/PatternRewriter/#matchandrewrite-implementation

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This goes for the other ::create methods being called below (there are a few of these).

Copy link
Author

@a-sidorova a-sidorova Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I replace Op::create(...) with rewriter.create<Op>(...), I get the following warning during build (deprecated method):

cmake --build . -j128 --target torch-mlir-opt
[1/3] Building CXX object tools/torch-mlir/lib/Conversion/TorchToLinalg/CMakeFiles/obj.TorchMLIRTorchToLinalg.dir/Linear.cpp.o
/opt/torch-mlir/lib/Conversion/TorchToLinalg/Linear.cpp:1800:23: warning: 'create' is deprecated: Use OpTy::create instead [-Wdeprecated-declarations]
 1800 |       return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor, indices);
      |                       ^
/opt/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/Builders.h:506:8: note: 'create' has been explicitly marked deprecated here
  506 |   OpTy create(Location location, Args &&...args) {
      |        ^
/opt/torch-mlir/lib/Conversion/TorchToLinalg/Linear.cpp:1800:23: warning: 'create<mlir::tensor::ExpandShapeOp, mlir::RankedTensorType &, mlir::Value &, llvm::SmallVector<llvm::SmallVector<long, 2>> &>' is deprecated: Use OpTy::create instead [-Wdeprecated-declarations]
 1800 |       return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor, indices);
      |                       ^
/opt/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/Builders.h:505:5: note: 'create<mlir::tensor::ExpandShapeOp, mlir::RankedTensorType &, mlir::Value &, llvm::SmallVector<llvm::SmallVector<long, 2>> &>' has been explicitly marked deprecated here
  505 |   [[deprecated("Use OpTy::create instead")]]
      |     ^
2 warnings generated.
[3/3] Linking CXX executable bin/torch-mlir-opt

Anyway I replaced with rewriter.create<Op>(...) in the separate commit 7ed7b5b. If needed, we will be able to revert this commit

Comment on lines 1849 to 1860
SmallVector<int64_t> weightDimsInt = makeShapeTorchCompatible(
cast<RankedTensorType>(weightExpanded.getType()).getShape());
bool is1x1Kernel = std::all_of(weightDimsInt.rbegin(),
weightDimsInt.rbegin() + numSpatialDims,
[](int64_t dim) { return dim == 1; });
if (numSpatialDims > 1 && !is1x1Kernel) {
SmallVector<int64_t> weightFlipDims;
weightFlipDims.reserve(numSpatialDims);
for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
weightFlipDims.push_back(spatialStartDimIdx + i);
weightExpanded = torch_to_linalg::flipTensor(
rewriter, loc, weightExpanded, weightFlipDims);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, we also care about 3x1 kernels, and we still won't need to flip the second spatial dim. It may be simpler to do something like:

Suggested change
SmallVector<int64_t> weightDimsInt = makeShapeTorchCompatible(
cast<RankedTensorType>(weightExpanded.getType()).getShape());
bool is1x1Kernel = std::all_of(weightDimsInt.rbegin(),
weightDimsInt.rbegin() + numSpatialDims,
[](int64_t dim) { return dim == 1; });
if (numSpatialDims > 1 && !is1x1Kernel) {
SmallVector<int64_t> weightFlipDims;
weightFlipDims.reserve(numSpatialDims);
for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
weightFlipDims.push_back(spatialStartDimIdx + i);
weightExpanded = torch_to_linalg::flipTensor(
rewriter, loc, weightExpanded, weightFlipDims);
}
SmallVector<int64_t> weightDimsInt = makeShapeTorchCompatible(
cast<RankedTensorType>(weightExpanded.getType()).getShape());
# Collect any non-unit spatial dim indices.
SmallVector<int64_t> weightFlipDims;
for (auto e in llvm::enumerate(weightDimsInt)) {
if (e.value() == 1)
weightFlipDims.push_back(static_cast<int64_t>(e.index());
}
# Perform a flip if we have any non-trivial spatial dims.
if (!weightFlipDims.empty()) {
weightExpanded = torch_to_linalg::flipTensor(
rewriter, loc, weightExpanded, weightFlipDims);
}

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the idea! Missed the other usecases. I took your pseudo-code and updated a little bit. Also I added lit and func tests just in case with kernel [3,1,3].

Comment on lines 2137 to 2139
SmallVector<IT> iteratorTypes = SmallVector<IT>(numIterators, IT::parallel);
std::fill(iteratorTypes.rbegin(),
iteratorTypes.rbegin() + (numSpatialDims + 1), IT::reduction);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit Question: Is this better than using append? Like one fewer allocation, but has to rewrite values?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To align with other impls, I replaced std::fill with iteratorTypes.append. Thanks!

Comment on lines 2178 to 2179
if (!isGrouped)
llvm_unreachable("g() called for non-grouped convolution.");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really more of a curiosity for me: does the compiler actually do something with this information? E.g., maybe if the condition isGrouped ends up getting hoisted, then it helps the compiler figure out if this function should get loaded at all. Again, just a curiousity. If you don't know, feel free to disregard.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just wanna to throw exception when we execute code which shouldn't be reachable. I found only llvm_unreachable in code base and decided that compiler automatically handle such exceptions. But I found that this is not exception - this is program aborting.

I removed these checks to avoid confusion from developer side since we call g() only when isGrouped = true anyway and won't reach llvm_unreachable("g() called for non-grouped convolution.").

Please let me know if we still have opportunities to handle such exception is MLIR? I will be able to use them in the future then.

"aten.flatten.using_ints",
"aten.adaptive_avg_pool1d",
"aten.adaptive_avg_pool2d",
"aten.convolution_backward",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't really use this path in the CI. The fx_importer config is the one which gets called in the linalg path these days. However, I'm not sure if changing this configuration would negatively impact other tests like those testing over-padded case (i.e., negative bwd data padding).

I would be fine with you checking this e2e testing out locally and reporting on which ConvolutionBackward tests fail the fx_importer config with this op as "backend legal" (see the link below) + the lit tests you've already added. We don't need to modify this in the torch-mlir e2e testing framework, since this pattern is mostly added as an optimization option that our downstream project will use. If we fully flesh out support for all the different options that are supported by the current decomposition, then we can update the backend legal ops here:

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, thank you for the detailed explanation!

Then I removed this change from jit_importer_backend.py.

As for local testing, the results of func tests:

python -m e2e_testing.main   -v -f ConvolutionBackward
PASS - "ConvolutionBackwardModule2DDilated_basic"
XFAIL - "ConvolutionBackwardModule2DPadded_basic"
PASS - "ConvolutionBackwardModule2DStatic_basic"
PASS - "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic"
PASS - "ConvolutionBackwardModule2DStrided_basic"
XFAIL - "ConvolutionBackwardModule2D_basic"
PASS - "ConvolutionBackwardModule3DStatic_basic"

Summary:
    Passed: 5
    Expectedly Failed: 2

The additional changes to pass tests:

  • Removed ConvoltuionBackward tests from xfail set (except 2 tests with dynamic shapes - they are failed due to TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0 (bias_sizes=None)
  • Replaced bias_sizes=[2] with bias_sizes=[output_channels]. Otherwise I got the following error: error: 'memref.cast' op operand type 'memref<16xf32>' and result type 'memref<4xf32>' are cast incompatible. I suppose this is mistake in tests (since bias shape should be broadcastable to output shape) - decided to fix it. Please let me know If i should revert this change.
  • Added convolution_backward to BACKEND_LEGAL_OPS as you suggested.

As for lit tests,

/opt/torch-mlir/build/bin/llvm-lit /opt/torch-mlir/test -v --filter=convolution_bwd.mlir
Enabling Torch v2.3+ tests
-- Testing: 1 of 120 tests, 1 workers --
PASS: TORCH_MLIR :: Conversion/TorchToLinalg/convolution_bwd.mlir (1 of 1)

Testing Time: 0.10s

Total Discovered Tests: 120
  Excluded: 119 (99.17%)
  Passed  :   1 (0.83%)

@a-sidorova a-sidorova force-pushed the feature/linalg_conv_bwd branch from 31eac08 to 3a30a55 Compare December 9, 2025 10:08
Copy link
Author

@a-sidorova a-sidorova left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zjgarvey thank you for the one more review! I applied your comments - may I ask you to take a look again?

}

auto retType = inType.clone(makeShapeLLVMCompatible(outShape));
return tensor::ExpandShapeOp::create(rewriter, loc, retType, tensor,
Copy link
Author

@a-sidorova a-sidorova Dec 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I replace Op::create(...) with rewriter.create<Op>(...), I get the following warning during build (deprecated method):

cmake --build . -j128 --target torch-mlir-opt
[1/3] Building CXX object tools/torch-mlir/lib/Conversion/TorchToLinalg/CMakeFiles/obj.TorchMLIRTorchToLinalg.dir/Linear.cpp.o
/opt/torch-mlir/lib/Conversion/TorchToLinalg/Linear.cpp:1800:23: warning: 'create' is deprecated: Use OpTy::create instead [-Wdeprecated-declarations]
 1800 |       return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor, indices);
      |                       ^
/opt/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/Builders.h:506:8: note: 'create' has been explicitly marked deprecated here
  506 |   OpTy create(Location location, Args &&...args) {
      |        ^
/opt/torch-mlir/lib/Conversion/TorchToLinalg/Linear.cpp:1800:23: warning: 'create<mlir::tensor::ExpandShapeOp, mlir::RankedTensorType &, mlir::Value &, llvm::SmallVector<llvm::SmallVector<long, 2>> &>' is deprecated: Use OpTy::create instead [-Wdeprecated-declarations]
 1800 |       return rewriter.create<tensor::ExpandShapeOp>(loc, retType, tensor, indices);
      |                       ^
/opt/torch-mlir/externals/llvm-project/llvm/../mlir/include/mlir/IR/Builders.h:505:5: note: 'create<mlir::tensor::ExpandShapeOp, mlir::RankedTensorType &, mlir::Value &, llvm::SmallVector<llvm::SmallVector<long, 2>> &>' has been explicitly marked deprecated here
  505 |   [[deprecated("Use OpTy::create instead")]]
      |     ^
2 warnings generated.
[3/3] Linking CXX executable bin/torch-mlir-opt

Anyway I replaced with rewriter.create<Op>(...) in the separate commit 7ed7b5b. If needed, we will be able to revert this commit

Comment on lines 1849 to 1860
SmallVector<int64_t> weightDimsInt = makeShapeTorchCompatible(
cast<RankedTensorType>(weightExpanded.getType()).getShape());
bool is1x1Kernel = std::all_of(weightDimsInt.rbegin(),
weightDimsInt.rbegin() + numSpatialDims,
[](int64_t dim) { return dim == 1; });
if (numSpatialDims > 1 && !is1x1Kernel) {
SmallVector<int64_t> weightFlipDims;
weightFlipDims.reserve(numSpatialDims);
for (int64_t i = 0; i < static_cast<int64_t>(numSpatialDims); ++i)
weightFlipDims.push_back(spatialStartDimIdx + i);
weightExpanded = torch_to_linalg::flipTensor(
rewriter, loc, weightExpanded, weightFlipDims);
}
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the idea! Missed the other usecases. I took your pseudo-code and updated a little bit. Also I added lit and func tests just in case with kernel [3,1,3].

Comment on lines 2137 to 2139
SmallVector<IT> iteratorTypes = SmallVector<IT>(numIterators, IT::parallel);
std::fill(iteratorTypes.rbegin(),
iteratorTypes.rbegin() + (numSpatialDims + 1), IT::reduction);
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To align with other impls, I replaced std::fill with iteratorTypes.append. Thanks!

Comment on lines 2178 to 2179
if (!isGrouped)
llvm_unreachable("g() called for non-grouped convolution.");
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just wanna to throw exception when we execute code which shouldn't be reachable. I found only llvm_unreachable in code base and decided that compiler automatically handle such exceptions. But I found that this is not exception - this is program aborting.

I removed these checks to avoid confusion from developer side since we call g() only when isGrouped = true anyway and won't reach llvm_unreachable("g() called for non-grouped convolution.").

Please let me know if we still have opportunities to handle such exception is MLIR? I will be able to use them in the future then.

"aten.flatten.using_ints",
"aten.adaptive_avg_pool1d",
"aten.adaptive_avg_pool2d",
"aten.convolution_backward",
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it, thank you for the detailed explanation!

Then I removed this change from jit_importer_backend.py.

As for local testing, the results of func tests:

python -m e2e_testing.main   -v -f ConvolutionBackward
PASS - "ConvolutionBackwardModule2DDilated_basic"
XFAIL - "ConvolutionBackwardModule2DPadded_basic"
PASS - "ConvolutionBackwardModule2DStatic_basic"
PASS - "ConvolutionBackwardModule2DStridedPaddedDilatedGrouped_basic"
PASS - "ConvolutionBackwardModule2DStrided_basic"
XFAIL - "ConvolutionBackwardModule2D_basic"
PASS - "ConvolutionBackwardModule3DStatic_basic"

Summary:
    Passed: 5
    Expectedly Failed: 2

The additional changes to pass tests:

  • Removed ConvoltuionBackward tests from xfail set (except 2 tests with dynamic shapes - they are failed due to TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0 (bias_sizes=None)
  • Replaced bias_sizes=[2] with bias_sizes=[output_channels]. Otherwise I got the following error: error: 'memref.cast' op operand type 'memref<16xf32>' and result type 'memref<4xf32>' are cast incompatible. I suppose this is mistake in tests (since bias shape should be broadcastable to output shape) - decided to fix it. Please let me know If i should revert this change.
  • Added convolution_backward to BACKEND_LEGAL_OPS as you suggested.

As for lit tests,

/opt/torch-mlir/build/bin/llvm-lit /opt/torch-mlir/test -v --filter=convolution_bwd.mlir
Enabling Torch v2.3+ tests
-- Testing: 1 of 120 tests, 1 workers --
PASS: TORCH_MLIR :: Conversion/TorchToLinalg/convolution_bwd.mlir (1 of 1)

Testing Time: 0.10s

Total Discovered Tests: 120
  Excluded: 119 (99.17%)
  Passed  :   1 (0.83%)

@a-sidorova a-sidorova requested a review from zjgarvey December 9, 2025 11:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants